diff --git a/.bazelrc b/.bazelrc index fb938169b3c0..0006b953cfe7 100644 --- a/.bazelrc +++ b/.bazelrc @@ -31,6 +31,9 @@ build -c opt build --output_filter=DONT_MATCH_ANYTHING build --copt=-DMLIR_PYTHON_PACKAGE_PREFIX=jaxlib.mlir. +build --copt=-DNB_DOMAIN=jax + +build --legacy_external_runfiles=false # ############################################################################# # Platform Specific configs below. These are automatically picked up by Bazel @@ -97,6 +100,7 @@ build:windows --incompatible_strict_action_env=true # ############################################################################# build:nonccl --define=no_nccl_support=true +build --repo_env USE_PYWRAP_RULES=1 build:posix --copt=-fvisibility=hidden build:posix --copt=-Wno-sign-compare build:posix --cxxopt=-std=c++17 @@ -130,23 +134,27 @@ build:clang --copt=-Wno-gnu-offsetof-extensions build:clang --copt=-Qunused-arguments # Error on struct/class mismatches, since this causes link failures on Windows. build:clang --copt=-Werror=mismatched-tags +# Required when building with clang>=19, see jax-ml/jax#27091 +build:clang --copt=-Wno-error=c23-extensions # Configs for CUDA build:cuda --repo_env TF_NEED_CUDA=1 build:cuda --repo_env TF_NCCL_USE_STUB=1 # "sm" means we emit only cubin, which is forward compatible within a GPU generation. # "compute" means we emit both cubin and PTX, which is larger but also forward compatible to future GPU generations. -build:cuda --repo_env HERMETIC_CUDA_COMPUTE_CAPABILITIES="sm_50,sm_60,sm_70,sm_80,compute_90" +build:cuda --repo_env HERMETIC_CUDA_COMPUTE_CAPABILITIES="sm_50,sm_60,sm_70,sm_80,sm_90,sm_100,compute_120" build:cuda --crosstool_top=@local_config_cuda//crosstool:toolchain build:cuda --@local_config_cuda//:enable_cuda -# Default hermetic CUDA and CUDNN versions. -build:cuda --repo_env=HERMETIC_CUDA_VERSION="12.3.2" -build:cuda --repo_env=HERMETIC_CUDNN_VERSION="9.1.1" +# Default hermetic CUDA, CUDNN and NVSHMEM versions. +build:cuda --repo_env=HERMETIC_CUDA_VERSION="12.8.0" +build:cuda --repo_env=HERMETIC_CUDNN_VERSION="9.8.0" +build:cuda --repo_env=HERMETIC_NVSHMEM_VERSION="3.2.5" build:cuda --@local_config_cuda//cuda:include_cuda_libs=true -# This config is used for building targets with CUDA libraries from stubs. +# This config is used for building targets with CUDA/NVSHMEM libraries from stubs. build:cuda_libraries_from_stubs --@local_config_cuda//cuda:include_cuda_libs=false +build:cuda_libraries_from_stubs --@local_config_nvshmem//:include_nvshmem_libs=false # Force the linker to set RPATH, not RUNPATH. When resolving dynamic libraries, # ld.so prefers in order: RPATH, LD_LIBRARY_PATH, RUNPATH. JAX sets RPATH to @@ -238,6 +246,9 @@ build:ci_linux_aarch64_base --config=clang --verbose_failures=true build:ci_linux_aarch64_base --action_env=TF_SYSROOT="/dt10" build:ci_linux_aarch64_base --color=yes +# This appears to help avoid a timeout in CI for linalg_test. 
+build:ci_linux_aarch64_base --test_env=OMP_NUM_THREADS=8 + build:ci_linux_aarch64 --config=ci_linux_aarch64_base build:ci_linux_aarch64 --host_crosstool_top="@ml2014_clang_aarch64_config_aarch64//crosstool:toolchain" build:ci_linux_aarch64 --crosstool_top="@ml2014_clang_aarch64_config_aarch64//crosstool:toolchain" @@ -260,8 +271,8 @@ build:ci_darwin_arm64 --color=yes # Windows x86 CI configs build:ci_windows_amd64 --config=avx_windows build:ci_windows_amd64 --compiler=clang-cl --config=clang --verbose_failures=true -build:ci_windows_amd64 --crosstool_top="@xla//tools/toolchains/win/20240424:toolchain" -build:ci_windows_amd64 --extra_toolchains="@xla//tools/toolchains/win/20240424:cc-toolchain-x64_windows-clang-cl" +build:ci_windows_amd64 --crosstool_top="@xla//tools/toolchains/win2022/20241118:toolchain" +build:ci_windows_amd64 --extra_toolchains="@xla//tools/toolchains/win2022/20241118:cc-toolchain-x64_windows-clang-cl" build:ci_windows_amd64 --host_linkopt=/FORCE:MULTIPLE --linkopt=/FORCE:MULTIPLE build:ci_windows_amd64 --color=yes @@ -321,6 +332,9 @@ build:rbe_linux_x86_64 --config=ci_linux_x86_64 build:rbe_linux_x86_64_cuda --config=rbe_linux_x86_64_base build:rbe_linux_x86_64_cuda --config=ci_linux_x86_64_cuda build:rbe_linux_x86_64_cuda --repo_env=REMOTE_GPU_TESTING=1 +# Speed up CUDA repos creation by downloading ".tar" dists from the mirror. +build:rbe_linux_x86_64_cuda --repo_env=USE_CUDA_TAR_ARCHIVE_FILES=1 +build:rbe_linux_x86_64_cuda --repo_env=USE_NVSHMEM_TAR_ARCHIVE_FILES=1 # RBE configs for Windows # Set the remote worker pool @@ -329,9 +343,9 @@ common:rbe_windows_amd64 --remote_instance_name=projects/tensorflow-testing/inst build:rbe_windows_amd64 --config=rbe # Set the host, execution, and target platform -build:rbe_windows_amd64 --host_platform="@xla//tools/toolchains/win:x64_windows-clang-cl" -build:rbe_windows_amd64 --extra_execution_platforms="@xla//tools/toolchains/win:x64_windows-clang-cl" -build:rbe_windows_amd64 --platforms="@xla//tools/toolchains/win:x64_windows-clang-cl" +build:rbe_windows_amd64 --host_platform="@xla//tools/toolchains/win2022:windows_ltsc2022_clang" +build:rbe_windows_amd64 --extra_execution_platforms="@xla//tools/toolchains/win2022:windows_ltsc2022_clang" +build:rbe_windows_amd64 --platforms="@xla//tools/toolchains/win2022:windows_ltsc2022_clang" build:rbe_windows_amd64 --shell_executable=C:\\tools\\msys64\\usr\\bin\\bash.exe build:rbe_windows_amd64 --enable_runfiles @@ -371,6 +385,9 @@ build:rbe_cross_compile_base --remote_instance_name=projects/tensorflow-testing/ build:rbe_cross_compile_linux_aarch64 --config=cross_compile_linux_aarch64 build:rbe_cross_compile_linux_aarch64 --config=rbe_cross_compile_base +# Avoids a timeout in linalg_test on ARM. 
+build:rbe_cross_compile_linux_aarch64 --test_env=OMP_NUM_THREADS=8 + # Mac x86 build:cross_compile_darwin_x86_64 --config=cross_compile_base build:cross_compile_darwin_x86_64 --config=nonccl @@ -410,7 +427,7 @@ build:rbe_cross_compile_darwin_x86_64 --config=rbe_cross_compile_base ############################################################################# build:debug_symbols --strip=never --per_file_copt="xla/pjrt|xla/python@-g3" -build:debug --config debug_symbols -c fastbuild +build:debug --config=debug_symbols -c fastbuild # Load `.jax_configure.bazelrc` file written by build.py try-import %workspace%/.jax_configure.bazelrc diff --git a/.github/ISSUE_TEMPLATE/bug-report.yml b/.github/ISSUE_TEMPLATE/bug-report.yml index 628310519b66..1f8c2b2ac254 100644 --- a/.github/ISSUE_TEMPLATE/bug-report.yml +++ b/.github/ISSUE_TEMPLATE/bug-report.yml @@ -24,7 +24,7 @@ body: [issue search]: https://github.com/jax-ml/jax/search?q=is%3Aissue&type=issues - [Raw report]: http://github.com/jax-ml/jax/issues/new + [Raw report]: https://github.com/jax-ml/jax/issues/new?template=none - type: textarea attributes: label: Description diff --git a/.github/actionlint.yaml b/.github/actionlint.yaml new file mode 100644 index 000000000000..e7ee1a086558 --- /dev/null +++ b/.github/actionlint.yaml @@ -0,0 +1,20 @@ +# Configuration related to self-hosted runner. +self-hosted-runner: + labels: + - "linux-x86-n2-32" # Linux X86 runner using the 32 vcpu n2-standard-32 machine. + - "linux-x86-n2-64" # Linux X86 runner using the 64 vcpu n2-standard-64 machine. + - "linux-x86-g2-16-l4-1gpu" # Linux X86 GPU runner using g2-standard-16 machine with 1 NVIDIA L4 GPU attached. + - "linux-x86-g2-48-l4-4gpu" # Linux X86 GPU runner using g2-standard-48 machine with 4 NVIDIA L4 GPUs attached. + - "linux-x86-ct5lp-224-8tpu" # Linux X86 TPU runner using ct5lp-hightpu-8t machine with 2x4 topology. + - "linux-arm64-c4a-16" # Linux ARM64 CPU Runner using the 16 vcpu c4a-standard-16 machine. + - "linux-arm64-c4a-64" # Linux ARM64 CPU Runner using the 64 vcpu c4a-standard-64 machine. + - "windows-x86-n2-16" # Windows X86 runner using n2-standard-16 machine. + - "windows-x86-n2-64" # Windows X86 runner using n2-standard-64 machine. + - "linux-x86-a4-224-b200-1gpu" # Linux X86 GPU runner using 1 B200 GPU and 1/8 the resources of a a4-highgpu-8g machine + - "linux-x86-a3-8g-h100-8gpu" # Linux X86 GPU runner using a3-highgpu-8g machine with 8 NVIDIA H100 GPUs attached. + - "linux-x86-ct6e-180-8tpu" # Linux X86 TPU runner using ct6e-hightpu-8t machine with 2x4 topology. + - "linux-x86-ct6e-180-4tpu" # Linux X86 TPU runner using ct6e-hightpu-4t machine with 2x2 topology. + - "linux-x86-ct4p-240-4tpu" # Linux X86 TPU runner using ct4p-hightpu-4t machine with 2x2x1 topology. + - "linux-x86-n2-128" # Linux X86 runner using the 128 vcpu n2-standard-128 machine. + - "linux-x86-n2-16" # Linux X86 runner using the 16 vcpu n2-standard-16 machine. 
+ - "linux-x86_64-cirrascale-64-8gpu-amd-mi250" # AMD runner diff --git a/.github/workflows/asan.yaml b/.github/workflows/asan.yaml index ea69d92e552e..533d4381f474 100644 --- a/.github/workflows/asan.yaml +++ b/.github/workflows/asan.yaml @@ -13,7 +13,7 @@ on: - main paths: - '**/workflows/asan.yaml' - +permissions: {} jobs: asan: # Don't execute in fork due to runner type @@ -41,11 +41,13 @@ jobs: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: path: jax + persist-credentials: false - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: repository: python/cpython path: cpython ref: v3.13.0 + persist-credentials: false - name: Build CPython with ASAN enabled env: ASAN_OPTIONS: detect_leaks=0 diff --git a/.github/workflows/bazel_cpu_py_import_rbe.yml b/.github/workflows/bazel_cpu_py_import_rbe.yml new file mode 100644 index 000000000000..65a7b7b6a01f --- /dev/null +++ b/.github/workflows/bazel_cpu_py_import_rbe.yml @@ -0,0 +1,60 @@ +# CI - Bazel CPU tests with py_import (RBE) +# +# This workflow runs the Bazel CPU tests with py_import dependency. It can only be triggered by +# other workflows via `workflow_call`. It is used by the `CI - Wheel Tests (Continuous)` workflows +# to run the Bazel CPU tests. +# +# It consists of the following job: +# run-tests: +# - Executes the `run_bazel_test_cpu_py_import_rbe.sh` script, which performs the following actions: +# - Runs the Bazel CPU tests with py_import dependency. +name: CI - Bazel CPU tests with py_import (RBE) +permissions: {} +on: + workflow_call: + inputs: + runner: + description: "Which runner should the workflow run on?" + type: string + default: "linux-x86-n2-16" + python: + description: "Which python version to test?" + type: string + default: "3.12" + enable-x64: + description: "Should x64 mode be enabled?" + type: string + default: "0" + halt-for-connection: + description: 'Should this workflow run wait for a remote connection?' 
+ type: string + default: 'no' + +jobs: + run-tests: + defaults: + run: + # Explicitly set the shell to bash + shell: bash + runs-on: ${{ inputs.runner }} + container: ${{ (contains(inputs.runner, 'linux-x86') && 'us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build:latest') || + (contains(inputs.runner, 'linux-arm64') && 'us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build-arm64:latest') }} + env: + JAXCI_HERMETIC_PYTHON_VERSION: ${{ inputs.python }} + JAXCI_ENABLE_X64: ${{ inputs.enable-x64 }} + + name: "${{ (contains(inputs.runner, 'linux-x86') && 'linux x86') || + (contains(inputs.runner, 'linux-arm64') && 'linux arm64') }}, py ${{ inputs.python }}, x64=${{ inputs.enable-x64 }}" + + steps: + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + persist-credentials: false + # Halt for testing + - name: Wait For Connection + uses: google-ml-infra/actions/ci_connection@7f5ca0c263a81ed09ea276524c1b9192f1304e3c + with: + halt-dispatch-input: ${{ inputs.halt-for-connection }} + - name: Run Bazel CPU tests with py_import (RBE) + timeout-minutes: 60 + run: ./ci/run_bazel_test_cpu_py_import_rbe.sh diff --git a/.github/workflows/bazel_cpu_rbe.yml b/.github/workflows/bazel_cpu_rbe.yml index d6816d492d1d..99071974bd00 100644 --- a/.github/workflows/bazel_cpu_rbe.yml +++ b/.github/workflows/bazel_cpu_rbe.yml @@ -18,7 +18,7 @@ on: branches: - main - 'release/**' - +permissions: {} concurrency: group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} # Don't cancel in-progress jobs for main/release branches. @@ -28,31 +28,38 @@ jobs: run_tests: if: github.event.repository.fork == false runs-on: ${{ matrix.runner }} - container: ${{ (contains(matrix.runner, 'linux-x86') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest') || - (contains(matrix.runner, 'linux-arm64') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-arm64:latest') }} + container: ${{ (contains(matrix.runner, 'linux-x86') && 'us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build:latest') || + (contains(matrix.runner, 'linux-arm64') && 'us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build-arm64:latest') }} env: JAXCI_HERMETIC_PYTHON_VERSION: ${{ matrix.python }} JAXCI_ENABLE_X64: ${{ matrix.enable-x_64 }} # Begin Presubmit Naming Check - name modification requires internal check to be updated strategy: matrix: - python: ["3.10", "3.13"] + python: ["3.11", "3.13"] runner: ["linux-x86-n2-16", "linux-arm64-c4a-16"] enable-x_64: [1, 0] exclude: # Exclude x64=1 on the oldest Python and x64=0 on the newest Python. As long as we have # coverage for one of each, we don't need to run both. - - python: "3.10" + - python: "3.11" enable-x_64: 1 - python: "3.13" enable-x_64: 0 - name: "Bazel CPU tests (${{ matrix.runner }}, Python ${{ matrix.python }}, x64=${{ matrix.enable-x_64 }})" + # Only test a single Python version on Arm64 as we don't run the tests. 
+ - python: "3.11" + runner: "linux-arm64-c4a-16" + name: "Bazel CPU ${{ (contains(matrix.runner, 'linux-arm64') && 'build only' || 'tests') }} (${{ matrix.runner }}, Python ${{ matrix.python }}, x64=${{ matrix.enable-x_64 }})" # End Presubmit Naming Check github-cpu-presubmits steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + persist-credentials: false - name: Wait For Connection - uses: google-ml-infra/actions/ci_connection@main + uses: google-ml-infra/actions/ci_connection@7f5ca0c263a81ed09ea276524c1b9192f1304e3c with: halt-dispatch-input: ${{ inputs.halt-for-connection }} - - name: Run Bazel CPU Tests with RBE + # Since we do not have a Linux Arm64 RBE pool, we do not run the tests on Arm64. Instead, we + # cross-compile the tests on the Linux x86 RBE pool. + - name: ${{ (contains(matrix.runner, 'linux-arm64') && 'Build' || 'Run') }} Bazel CPU Tests with RBE run: ./ci/run_bazel_test_cpu_rbe.sh \ No newline at end of file diff --git a/.github/workflows/bazel_cuda_non_rbe.yml b/.github/workflows/bazel_cuda_non_rbe.yml index 0b0e1cb62497..5168dc6d002e 100644 --- a/.github/workflows/bazel_cuda_non_rbe.yml +++ b/.github/workflows/bazel_cuda_non_rbe.yml @@ -17,29 +17,28 @@ on: runner: description: "Which runner should the workflow run on?" type: string - required: true default: "linux-x86-n2-16" python: description: "Which python version to test?" type: string - required: true default: "3.12" enable-x64: description: "Should x64 mode be enabled?" type: string - required: true default: "0" + jaxlib-version: + description: "Which jaxlib version to test? (head/pypi_latest)" + type: string + default: "head" gcs_download_uri: description: "GCS location URI from where the artifacts should be downloaded" - required: true default: 'gs://general-ml-ci-transient/jax-github-actions/jax/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}' type: string halt-for-connection: description: 'Should this workflow run wait for a remote connection?' - type: boolean - required: false - default: false - + type: string + default: 'no' +permissions: {} jobs: run-tests: defaults: @@ -47,7 +46,7 @@ jobs: # Explicitly set the shell to bash shell: bash runs-on: ${{ inputs.runner }} - container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-cuda12.3-cudnn9.1:latest" + container: "us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build-cuda12.8-cudnn9.8:latest" env: JAXCI_HERMETIC_PYTHON_VERSION: ${{ inputs.python }} @@ -55,17 +54,22 @@ jobs: # Enable writing to the Bazel remote cache bucket. JAXCI_WRITE_TO_BAZEL_REMOTE_CACHE: "1" - name: "Bazel single accelerator and multi-accelerator CUDA tests (${{ inputs.runner }}, Python ${{ inputs.python }}, x64=${{ inputs.enable-x64 }})" + name: "jaxlib=${{ inputs.jaxlib-version }}, + ${{ (contains(inputs.runner, 'h100') && 'h100') || + (contains(inputs.runner, 'b200') && 'b200') || + (contains(inputs.runner, 'l4') && 'l4') }}, py ${{ inputs.python }}, x64=${{ inputs.enable-x64 }}" steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + persist-credentials: false - name: Set env vars for use in artifact download URL run: | os=$(uname -s | awk '{print tolower($0)}') arch=$(uname -m) # Get the major and minor version of Python. 
- # E.g if JAXCI_HERMETIC_PYTHON_VERSION=3.10, then python_major_minor=310 + # E.g if JAXCI_HERMETIC_PYTHON_VERSION=3.11, then python_major_minor=311 python_major_minor=$(echo "$JAXCI_HERMETIC_PYTHON_VERSION" | tr -d '.') echo "OS=${os}" >> $GITHUB_ENV @@ -77,11 +81,21 @@ jobs: # fails. Instead, we verify the outcome in the next step so that we can print a more # informative error message. continue-on-error: true - run: >- - mkdir -p $(pwd)/dist && - gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/ && - gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jax*cuda*plugin*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/ && - gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jax*cuda*pjrt*${OS}*${ARCH}*.whl" $(pwd)/dist/ + run: | + mkdir -p $(pwd)/dist + gcloud storage cp -r "${{ inputs.gcs_download_uri }}"/jax*py3*none*any.whl $(pwd)/dist/ + + if [[ ${{ inputs.jaxlib-version }} == "head" ]]; then + gcloud storage cp -r "${{ inputs.gcs_download_uri }}/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/ && + gcloud storage cp -r "${{ inputs.gcs_download_uri }}/jax*cuda*plugin*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/ && + gcloud storage cp -r "${{ inputs.gcs_download_uri }}/jax*cuda*pjrt*${OS}*${ARCH}*.whl" $(pwd)/dist/ + elif [[ ${{ inputs.jaxlib-version }} == "pypi_latest" ]]; then + PYTHON=python${{ inputs.python }} + $PYTHON -m pip download jaxlib jax-cuda12-pjrt jax-cuda12-plugin --dest $(pwd)/dist/ + else + echo "Invalid jaxlib version: ${{ inputs.jaxlib-version }}" + exit 1 + fi - name: Skip the test run if the wheel artifacts were not downloaded successfully if: steps.download-wheel-artifacts.outcome == 'failure' run: | @@ -91,7 +105,7 @@ jobs: exit 1 # Halt for testing - name: Wait For Connection - uses: google-ml-infra/actions/ci_connection@main + uses: google-ml-infra/actions/ci_connection@7f5ca0c263a81ed09ea276524c1b9192f1304e3c with: halt-dispatch-input: ${{ inputs.halt-for-connection }} - name: Run Bazel CUDA tests (Non-RBE) diff --git a/.github/workflows/bazel_cuda_rbe.yml b/.github/workflows/bazel_cuda_rbe.yml index 5a2c94c4db47..3aaf2a485e77 100644 --- a/.github/workflows/bazel_cuda_rbe.yml +++ b/.github/workflows/bazel_cuda_rbe.yml @@ -23,25 +23,25 @@ concurrency: group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} # Don't cancel in-progress jobs for main/release branches. cancel-in-progress: ${{ !contains(github.ref, 'release/') && github.ref != 'main' }} - +permissions: {} jobs: run_tests: if: github.event.repository.fork == false runs-on: ${{ matrix.runner }} - container: 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest' + container: 'us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build:latest' env: JAXCI_HERMETIC_PYTHON_VERSION: ${{ matrix.python }} JAXCI_ENABLE_X64: ${{ matrix.enable-x_64 }} # Begin Presubmit Naming Check - name modification requires internal check to be updated strategy: matrix: - python: ["3.10", "3.13"] + python: ["3.11", "3.13"] runner: ["linux-x86-n2-16"] enable-x_64: [1, 0] exclude: # Exclude x64=1 on the oldest Python and x64=0 on the newest Python. As long as we have # coverage for one of each, we don't need to run both. 
- - python: "3.10" + - python: "3.11" enable-x_64: 1 - python: "3.13" enable-x_64: 0 @@ -49,8 +49,10 @@ jobs: # End Presubmit Naming Check github-cuda-presubmits steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + persist-credentials: false - name: Wait For Connection - uses: google-ml-infra/actions/ci_connection@main + uses: google-ml-infra/actions/ci_connection@7f5ca0c263a81ed09ea276524c1b9192f1304e3c with: halt-dispatch-input: ${{ inputs.halt-for-connection }} - name: Run Bazel CUDA Tests with RBE diff --git a/.github/workflows/bazel_optional_h100_b200.yml b/.github/workflows/bazel_optional_h100_b200.yml new file mode 100644 index 000000000000..16c7bb95c16b --- /dev/null +++ b/.github/workflows/bazel_optional_h100_b200.yml @@ -0,0 +1,113 @@ +name: CI - Bazel Optional H100 and B200 CUDA tests +on: + # Runs on PR if label "CI Optional GPU Presubmit" is present. + workflow_dispatch: + inputs: + halt-for-connection: + description: 'Should this workflow run wait for a remote connection?' + type: choice + required: true + default: 'no' + options: + - 'yes' + - 'no' + pull_request: + branches: + - main + types: [ labeled, synchronize ] + schedule: + - cron: "0 */2 * * *" # Run once every 2 hours +permissions: + contents: read +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} + # Don't cancel in-progress jobs for main/release branches. + cancel-in-progress: ${{ !contains(github.ref, 'release/') && github.ref != 'main' }} +jobs: + run_tests: + if: ${{ github.event.repository.fork == false && (github.event_name == 'schedule' || github.event_name == 'workflow_dispatch' || contains(github.event.pull_request.labels.*.name, 'CI Optional GPU Presubmit')) }} + runs-on: linux-x86-a4-224-b200-1gpu + container: 'us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build-cuda12.8-cudnn9.8:latest' + name: "Bazel single B200 CUDA tests" +# End Presubmit Naming Check github-cuda-presubmits + steps: + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + persist-credentials: false + - name: Wait For Connection + uses: google-ml-infra/actions/ci_connection@7f5ca0c263a81ed09ea276524c1b9192f1304e3c + with: + halt-dispatch-input: ${{ inputs.halt-for-connection }} + - name: Run Bazel single B200 CUDA Tests + run: | + nvidia-smi + bazel test --config=rbe_linux_x86_64_cuda \ + --config=resultstore \ + --config=rbe_cache \ + --repo_env=HERMETIC_CUDA_VERSION="12.8.0" \ + --repo_env=HERMETIC_CUDNN_VERSION="9.8.0" \ + --repo_env=HERMETIC_PYTHON_VERSION="3.13" \ + --test_env=XLA_PYTHON_CLIENT_ALLOCATOR=platform \ + --run_under "$(pwd)/build/parallel_accelerator_execute.sh" \ + --test_output=errors \ + --test_tag_filters=-multiaccelerator \ + --test_env=JAX_ACCELERATOR_COUNT=1 \ + --test_env=JAX_TESTS_PER_ACCELERATOR=8 \ + --strategy=TestRunner=local \ + --local_test_jobs=8 \ + --test_env=JAX_EXCLUDE_TEST_TARGETS='PmapTest.testSizeOverflow|.*InterpretTest.*' \ + --test_env=TF_CPP_MIN_LOG_LEVEL=0 \ + --test_env=JAX_SKIP_SLOW_TESTS=true \ + --action_env=JAX_ENABLE_X64="1" \ + --action_env=NCCL_DEBUG=WARN \ + --flaky_test_attempts=1 \ + --test_timeout=420 \ + --color=yes \ + //tests:cudnn_fusion_test_gpu \ + //tests:scaled_matmul_stablehlo_test_gpu \ + //tests:fused_attention_stablehlo_test_gpu \ + //tests:nn_test_gpu \ + //tests/pallas:gpu_tests \ + //tests/mosaic:gpu_tests + run_multiaccelerator_tests: + if: ${{ github.event.repository.fork == false && (github.event_name == 'schedule' || github.event_name 
== 'workflow_dispatch' || contains(github.event.pull_request.labels.*.name, 'CI Optional GPU Presubmit')) }} + runs-on: linux-x86-a3-8g-h100-8gpu + container: 'us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build-cuda12.8-cudnn9.8:latest' + name: "Bazel multiple H100 CUDA tests" + steps: + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + persist-credentials: false + - name: Wait For Connection + uses: google-ml-infra/actions/ci_connection@7f5ca0c263a81ed09ea276524c1b9192f1304e3c + with: + halt-dispatch-input: ${{ inputs.halt-for-connection }} + - name: Run Bazel multiple H100 CUDA Tests + run: | + nvidia-smi + bazel test --config=rbe_linux_x86_64_cuda \ + --config=resultstore \ + --config=rbe_cache \ + --repo_env=HERMETIC_CUDA_VERSION="12.8.0" \ + --repo_env=HERMETIC_CUDNN_VERSION="9.8.0" \ + --repo_env=HERMETIC_PYTHON_VERSION="3.13" \ + --test_env=XLA_PYTHON_CLIENT_ALLOCATOR=platform \ + --test_output=errors \ + --strategy=TestRunner=local \ + --local_test_jobs=8 \ + --test_env=JAX_EXCLUDE_TEST_TARGETS='PmapTest.testSizeOverflow|.*InterpretTest.*' \ + --test_tag_filters=multiaccelerator \ + --test_env=TF_CPP_MIN_LOG_LEVEL=0 \ + --test_env=JAX_SKIP_SLOW_TESTS=true \ + --action_env=JAX_ENABLE_X64="1" \ + --action_env=NCCL_DEBUG=WARN \ + --flaky_test_attempts=1 \ + --color=yes \ + //tests/mosaic:gpu_tests \ + //tests/pallas:gpu_tests \ + //tests:array_interoperability_test_gpu \ + //tests:cudnn_fusion_test_gpu \ + //tests:fused_attention_stablehlo_test_gpu \ + //tests:gpu_tests \ + //tests:python_callback_test_gpu \ + //tests:ragged_collective_test_gpu \ No newline at end of file diff --git a/.github/workflows/build_artifacts.yml b/.github/workflows/build_artifacts.yml index c2e7acb91f7a..7459953c37a1 100644 --- a/.github/workflows/build_artifacts.yml +++ b/.github/workflows/build_artifacts.yml @@ -12,16 +12,14 @@ on: runner: description: "Which runner should the workflow run on?" type: choice - required: true default: "linux-x86-n2-16" options: - "linux-x86-n2-16" - - "linux-arm64-c4a-64" - - "windows-x86-n2-64" + - "linux-arm64-t2a-48" + - "windows-x86-n2-16" artifact: description: "Which JAX artifact to build?" type: choice - required: true default: "jaxlib" options: - "jax" @@ -31,17 +29,14 @@ on: python: description: "Which python version should the artifact be built for?" type: choice - required: false default: "3.12" options: - - "3.10" - "3.11" - "3.12" - "3.13" clone_main_xla: description: "Should latest XLA be used?" type: choice - required: false default: "0" options: - "1" @@ -49,7 +44,6 @@ on: halt-for-connection: description: 'Should this workflow run wait for a remote connection?' type: choice - required: false default: 'no' options: - 'yes' @@ -59,41 +53,32 @@ on: runner: description: "Which runner should the workflow run on?" type: string - required: true default: "linux-x86-n2-16" artifact: description: "Which JAX artifact to build?" type: string - required: true default: "jaxlib" python: description: "Which python version should the artifact be built for?" type: string - required: false default: "3.12" clone_main_xla: description: "Should latest XLA be used?" type: string - required: false default: "0" upload_artifacts_to_gcs: description: "Should the artifacts be uploaded to a GCS bucket?" 
- required: true default: true type: boolean gcs_upload_uri: description: "GCS location prefix to where the artifacts should be uploaded" - required: true default: 'gs://general-ml-ci-transient/jax-github-actions/jax/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}' type: string outputs: gcs_upload_uri: description: "GCS location prefix to where the artifacts were uploaded" value: ${{ jobs.build-artifacts.outputs.gcs_upload_uri }} - -permissions: - contents: read - +permissions: {} jobs: build-artifacts: defaults: @@ -103,15 +88,18 @@ jobs: runs-on: ${{ inputs.runner }} - container: ${{ (contains(inputs.runner, 'linux-x86') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest') || - (contains(inputs.runner, 'linux-arm64') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-arm64:latest') || + container: ${{ (contains(inputs.runner, 'linux-x86') && 'us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build:latest') || + (contains(inputs.runner, 'linux-arm64') && 'us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build-arm64:latest') || (contains(inputs.runner, 'windows-x86') && null) }} env: JAXCI_HERMETIC_PYTHON_VERSION: "${{ inputs.python }}" JAXCI_CLONE_MAIN_XLA: "${{ inputs.clone_main_xla }}" - name: Build ${{ inputs.artifact }} (${{ inputs.runner }}, Python ${{ inputs.python }}, clone main XLA=${{ inputs.clone_main_xla }}) + name: "${{ inputs.artifact }}, + ${{ (contains(inputs.runner, 'linux-x86') && 'linux x86') || + (contains(inputs.runner, 'linux-arm64') && 'linux arm64') || + (contains(inputs.runner, 'windows-x86') && 'windows x86') }}, py ${{ inputs.python }}, clone main XLA=${{ inputs.clone_main_xla }}" # Map the job outputs to step outputs outputs: @@ -119,15 +107,17 @@ jobs: steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - - name: Enable RBE if building on Linux x86 - if: contains(inputs.runner, 'linux-x86') + with: + persist-credentials: false + - name: Enable RBE if building on Linux x86 or Windows x86 + if: contains(inputs.runner, 'linux-x86') || contains(inputs.runner, 'windows-x86') run: echo "JAXCI_BUILD_ARTIFACT_WITH_RBE=1" >> $GITHUB_ENV - - name: Enable Bazel remote cache (with writes enabled) if building on Linux Aarch64 or Windows x86 - if: contains(inputs.runner, 'linux-arm64') || contains(inputs.runner, 'windows-x86') + - name: Enable Bazel remote cache (with writes enabled) if building on Linux Aarch64 + if: contains(inputs.runner, 'linux-arm64') run: echo "JAXCI_WRITE_TO_BAZEL_REMOTE_CACHE=1" >> $GITHUB_ENV # Halt for testing - name: Wait For Connection - uses: google-ml-infra/actions/ci_connection@main + uses: google-ml-infra/actions/ci_connection@7f5ca0c263a81ed09ea276524c1b9192f1304e3c with: halt-dispatch-input: ${{ inputs.halt-for-connection }} - name: Build ${{ inputs.artifact }} @@ -136,13 +126,13 @@ jobs: - name: Upload artifacts to a GCS bucket (non-Windows runs) if: >- ${{ inputs.upload_artifacts_to_gcs && !contains(inputs.runner, 'windows-x86') }} - run: gsutil -m cp -r "$(pwd)/dist/*.whl" "${{ inputs.gcs_upload_uri }}"/ + run: gcloud storage cp -r "$(pwd)/dist/*.whl" "${{ inputs.gcs_upload_uri }}"/ # Set shell to cmd to avoid path errors when using gcloud commands on Windows - name: Upload artifacts to a GCS bucket (Windows runs) if: >- ${{ inputs.upload_artifacts_to_gcs && contains(inputs.runner, 'windows-x86') }} shell: cmd - run: gsutil -m cp -r "dist/*.whl" "${{ inputs.gcs_upload_uri }}"/ + run: gcloud storage cp -r 
"dist/*.whl" "${{ inputs.gcs_upload_uri }}"/ - name: Store the GCS upload URI as an output id: store-gcs-upload-uri if: ${{ inputs.upload_artifacts_to_gcs }} diff --git a/.github/workflows/ci-build.yaml b/.github/workflows/ci-build.yaml index f43407af2ed9..dbd51373a3ac 100644 --- a/.github/workflows/ci-build.yaml +++ b/.github/workflows/ci-build.yaml @@ -1,11 +1,5 @@ name: CI -# We test all supported Python versions as follows: -# - 3.10 : Documentation build -# - 3.10 : Part of Matrix with NumPy dispatch -# - 3.10 : Part of Matrix -# - 3.11 : Part of Matrix - on: # Trigger the workflow on push or pull request, # but only for the main branch @@ -16,10 +10,7 @@ on: branches: - main -permissions: - contents: read # to fetch code - actions: write # to cancel previous workflows - +permissions: {} concurrency: group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} cancel-in-progress: true @@ -30,12 +21,14 @@ jobs: timeout-minutes: 5 steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + persist-credentials: false - name: Set up Python 3.11 - uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0 + uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0 with: python-version: 3.11 - run: python -m pip install pre-commit - - uses: actions/cache@1bd1e32a3bdc45362d1e726936510720a7c30a57 # v4.2.0 + - uses: actions/cache@5a3ec84eff668545956fd18022155c47e93e2684 # v4.2.3 with: path: ~/.cache/pre-commit key: pre-commit-${{ env.pythonLocation }}-${{ hashFiles('.pre-commit-config.yaml', 'setup.py') }} @@ -53,8 +46,8 @@ jobs: matrix: # Test the oldest and newest supported Python versions here. include: - - name-prefix: "with 3.10" - python-version: "3.10" + - name-prefix: "with 3.11" + python-version: "3.11" enable-x64: 1 prng-upgrade: 1 num_generated_cases: 1 @@ -65,12 +58,14 @@ jobs: num_generated_cases: 1 steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + persist-credentials: false - name: Image Setup run: | apt update apt install -y libssl-dev - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0 + uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0 with: python-version: ${{ matrix.python-version }} - name: Install dependencies @@ -88,7 +83,6 @@ jobs: JAX_SKIP_SLOW_TESTS: true PY_COLORS: 1 run: | - uv pip install --system -e . 
echo "JAX_NUM_GENERATED_CASES=$JAX_NUM_GENERATED_CASES" echo "JAX_ENABLE_X64=$JAX_ENABLE_X64" echo "JAX_ENABLE_CUSTOM_PRNG=$JAX_ENABLE_CUSTOM_PRNG" @@ -104,11 +98,13 @@ jobs: timeout-minutes: 10 strategy: matrix: - python-version: ['3.10'] + python-version: ['3.12'] steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + persist-credentials: false - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0 + uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0 with: python-version: ${{ matrix.python-version }} - name: Install dependencies @@ -134,15 +130,17 @@ jobs: timeout-minutes: 10 strategy: matrix: - python-version: ['3.10'] + python-version: ['3.11'] steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + persist-credentials: false - name: Image Setup run: | apt update - apt install -y libssl-dev libsqlite3-dev + apt install -y libssl-dev libsqlite3-dev build-essential - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0 + uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0 with: python-version: ${{ matrix.python-version }} - name: Install dependencies @@ -151,7 +149,7 @@ jobs: uv pip install --system -r docs/requirements.txt - name: Render documentation run: | - sphinx-build -j auto --color -W --keep-going -b html -D nb_execution_mode=off docs docs/build/html + sphinx-build -j auto --color -W --keep-going -b html docs docs/build/html jax2tf_test: name: "jax2tf_test (py ${{ matrix.python-version }} on ${{ matrix.os }}, x64=${{ matrix.enable-x64}})" @@ -161,21 +159,23 @@ jobs: matrix: # Test the oldest supported Python version here. include: - - python-version: "3.10" + - python-version: "3.11" os: ubuntu-latest enable-x64: 0 num_generated_cases: 10 steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + persist-credentials: false - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0 + uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0 with: python-version: ${{ matrix.python-version }} - name: Install dependencies run: | pip install uv~=0.5.30 uv pip install --system .[minimum-jaxlib] -r build/test-requirements.txt - uv pip install --system --pre tensorflow==2.19.0rc0 + uv pip install --system --pre tensorflow==2.19.0 - name: Run tests env: @@ -185,7 +185,6 @@ jobs: JAX_SKIP_SLOW_TESTS: true PY_COLORS: 1 run: | - uv pip install --system -e . echo "JAX_NUM_GENERATED_CASES=$JAX_NUM_GENERATED_CASES" echo "JAX_ENABLE_X64=$JAX_ENABLE_X64" echo "JAX_ENABLE_CHECKS=$JAX_ENABLE_CHECKS" @@ -200,8 +199,10 @@ jobs: timeout-minutes: 30 steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + persist-credentials: false - name: Set up Python - uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0 + uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0 with: python-version: 3.12 - name: Install JAX diff --git a/.github/workflows/cloud-tpu-ci-nightly.yml b/.github/workflows/cloud-tpu-ci-nightly.yml index 099f4ad5c520..5a97999c2b23 100644 --- a/.github/workflows/cloud-tpu-ci-nightly.yml +++ b/.github/workflows/cloud-tpu-ci-nightly.yml @@ -11,10 +11,13 @@ # Github Actions environment). 
name: CI - Cloud TPU (nightly) +# Disable the schedule; Slated for removal, the new test workflow is in +# "wheel_tests_nightly_release.yml" on: - schedule: - - cron: "0 2,14 * * *" # Run at 7am and 7pm PST +# schedule: +# - cron: "0 2,14 * * *" # Run at 7am and 7pm PST workflow_dispatch: # allows triggering the workflow run manually + # This should also be set to read-only in the project settings, but it's nice to # document and enforce the permissions here. permissions: @@ -26,17 +29,25 @@ jobs: matrix: jaxlib-version: ["head", "pypi_latest", "nightly", "nightly+oldest_supported_libtpu"] tpu: [ - # {type: "v3-8", cores: "4"}, # Enable when we have the v3 type available {type: "v4-8", cores: "4", runner: "linux-x86-ct4p-240-4tpu"}, - {type: "v5e-8", cores: "8", runner: "linux-x86-ct5lp-224-8tpu"} + {type: "v5e-8", cores: "8", runner: "linux-x86-ct5lp-224-8tpu"}, + {type: "v6e-8", cores: "8", runner: "linux-x86-ct6e-180-8tpu"} ] - python-version: ["3.10"] + python-version: ["3.11"] + # Exclude v6e-8 tests for nightly+oldest_supported_libtpu and pypi_latest for resource constraints. + exclude: + - tpu: + type: "v6e-8" + jaxlib-version: "nightly+oldest_supported_libtpu" + - tpu: + type: "v6e-8" + jaxlib-version: "pypi_latest" name: "TPU test (jaxlib=${{ matrix.jaxlib-version }}, ${{ matrix.tpu.type }})" env: - LIBTPU_OLDEST_VERSION_DATE: 20241205 + LIBTPU_OLDEST_VERSION_DATE: 20250228 PYTHON: python${{ matrix.python-version }} runs-on: ${{ matrix.tpu.runner }} - container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest" + container: "us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build:latest" timeout-minutes: 180 defaults: run: @@ -46,6 +57,8 @@ jobs: # mandates using a specific commit for non-Google actions. We use # https://github.com/sethvargo/ratchet to pin specific versions. - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + persist-credentials: false # Checkout XLA at head, if we're building jaxlib at head. - name: Checkout XLA at head uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 @@ -53,6 +66,7 @@ jobs: with: repository: openxla/xla path: xla + persist-credentials: false # We need to mark the GitHub workspace as safe as otherwise git commands will fail. - name: Mark GitHub workspace as safe run: | @@ -80,14 +94,14 @@ jobs: elif [ "${{ matrix.jaxlib-version }}" == "nightly" ]; then $PYTHON -m uv pip install \ - --pre . -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \ + --pre . -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \ libtpu -f https://storage.googleapis.com/jax-releases/libtpu_releases.html \ requests elif [ "${{ matrix.jaxlib-version }}" == "nightly+oldest_supported_libtpu" ]; then # TODO(phawkins): switch to libtpu, when the oldest release we support is a libtpu release. $PYTHON -m uv pip install \ - --pre . -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \ + --pre . 
-i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \ libtpu-nightly==0.1.dev${{ env.LIBTPU_OLDEST_VERSION_DATE }} \ -f https://storage.googleapis.com/jax-releases/libtpu_releases.html \ requests diff --git a/.github/workflows/cloud-tpu-ci-presubmit.yml b/.github/workflows/cloud-tpu-ci-presubmit.yml index a92e3cc19313..c6988f198675 100644 --- a/.github/workflows/cloud-tpu-ci-presubmit.yml +++ b/.github/workflows/cloud-tpu-ci-presubmit.yml @@ -25,9 +25,7 @@ on: # This should also be set to read-only in the project settings, but it's nice to # document and enforce the permissions here. -permissions: - contents: read - +permissions: {} concurrency: group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} # Don't cancel in-progress jobs for main/release branches. @@ -44,7 +42,7 @@ jobs: with: runner: "linux-x86-n2-16" artifact: ${{ matrix.artifact }} - python: "3.10" + python: "3.11" clone_main_xla: 1 upload_artifacts_to_gcs: true gcs_upload_uri: 'gs://general-ml-ci-transient/jax-github-actions/jax/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}' @@ -54,12 +52,13 @@ jobs: needs: [build-jax-artifacts] uses: ./.github/workflows/pytest_tpu.yml # Begin Presubmit Naming Check - name modification requires internal check to be updated - name: "TPU test (jaxlib=head, v5e-8)" + name: "TPU test (jaxlib=head)" with: runner: "linux-x86-ct5lp-224-8tpu" cores: "8" tpu-type: "v5e-8" - python: "3.10" + python: "3.11" libtpu-version-type: "nightly" gcs_download_uri: ${{ needs.build-jax-artifacts.outputs.gcs_upload_uri }} + halt-for-connection: ${{ inputs.halt-for-connection || false }} # End Presubmit Naming Check github-tpu-presubmits \ No newline at end of file diff --git a/.github/workflows/community_release_actions.yml b/.github/workflows/community_release_actions.yml new file mode 100644 index 000000000000..1110cbad9475 --- /dev/null +++ b/.github/workflows/community_release_actions.yml @@ -0,0 +1,34 @@ +name: Release Actions + +on: + release: + types: [published] + +permissions: + contents: read + +jobs: + discord_release: + if: github.repository_owner == 'jax-ml' + runs-on: ubuntu-latest + steps: + - name: Get release URL + id: get-release-url + run: | + URL="https://docs.jax.dev/en/latest/changelog.html" + echo "::set-output name=URL::$URL" + - name: Get content + uses: 2428392/gh-truncate-string-action@b3ff790d21cf42af3ca7579146eedb93c8fb0757 # v1.4.1 + id: get-content + with: + stringToTruncate: | + JAX [${{ github.event.release.tag_name }}](<${{ steps.get-release-url.outputs.URL }}>) was just released! + + ${{ github.event.release.body }} + maxLength: 2000 + truncationSymbol: "..." 
+ - name: Discord Webhook Action + uses: tsickert/discord-webhook@b217a69502f52803de774ded2b1ab7c282e99645 # v7.0.0 + with: + webhook-url: ${{ secrets.DISCORD_WEBHOOK_URL }} + content: ${{ steps.get-content.outputs.string }} diff --git a/.github/workflows/jax-array-api.yml b/.github/workflows/jax-array-api.yml index 2b97c5a05c1c..eaabc54368de 100644 --- a/.github/workflows/jax-array-api.yml +++ b/.github/workflows/jax-array-api.yml @@ -11,38 +11,36 @@ on: concurrency: group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} cancel-in-progress: true - +permissions: {} jobs: build: - - runs-on: ubuntu-latest + runs-on: linux-x86-n2-16 + container: us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build:latest strategy: matrix: python-version: [3.11] - + env: + PYTHON: "python${{ matrix.python-version }}" steps: - name: Checkout jax uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + persist-credentials: false - name: Checkout array-api-tests uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: repository: data-apis/array-api-tests - # TODO(jakevdp) update this to a stable release/tag when available. - ref: '0b89c5268e4e4a352223a487b8f63dbd1023872d' # Latest commit as of 2025-03-04 + ref: '2025.05.23' submodules: 'true' path: 'array-api-tests' - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0 - with: - python-version: ${{ matrix.python-version }} + persist-credentials: false - name: Install dependencies run: | - pip install uv~=0.5.30 - uv pip install --system .[ci] pytest-xdist -r array-api-tests/requirements.txt + $PYTHON -m uv pip install --system .[ci] pytest-xdist -r array-api-tests/requirements.txt - name: Run the test suite env: ARRAY_API_TESTS_MODULE: jax.numpy JAX_ENABLE_X64: 'true' run: | cd ${GITHUB_WORKSPACE}/array-api-tests - pytest -n auto array_api_tests --derandomize --disable-deadline --skips-file ${GITHUB_WORKSPACE}/tests/array_api_skips.txt + $PYTHON -m pytest -n auto array_api_tests --derandomize --disable-deadline --skips-file ${GITHUB_WORKSPACE}/tests/array_api_skips.txt diff --git a/.github/workflows/k8s.yaml b/.github/workflows/k8s.yaml new file mode 100644 index 000000000000..86bc5e6c168b --- /dev/null +++ b/.github/workflows/k8s.yaml @@ -0,0 +1,116 @@ +name: Multi-process run using K8s +on: + push: + branches: + - main + paths: + - '.github/workflows/k8s.yaml' + - 'ci/k8s/**' + - 'jax/distributed.py' + - 'jax/_src/distributed.py' + - 'jax/_src/clusters/**' + pull_request: + branches: + - main + paths: + - '.github/workflows/k8s.yaml' + - 'ci/k8s/**' + - 'jax/distributed.py' + - 'jax/_src/distributed.py' + - 'jax/_src/clusters/**' + +permissions: {} +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} + cancel-in-progress: true +defaults: + run: + shell: bash -ex -o pipefail {0} +jobs: + distributed-initialize: + runs-on: ubuntu-22.04 + strategy: + fail-fast: false + matrix: + controller: [jobset, indexed-job] + steps: + - name: Checkout + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # ratchet:actions/checkout@v4 + with: + path: jax + persist-credentials: false + + - name: Start Minikube cluster + uses: medyagh/setup-minikube@cea33675329b799adccc9526aa5daccc26cd5052 # ratchet:medyagh/setup-minikube@v0.0.19 + + - name: Install K8s Jobset + if: matrix.controller == 'jobset' + run: | + kubectl apply --server-side -f 
https://github.com/kubernetes-sigs/jobset/releases/download/v0.8.0/manifests.yaml + kubectl wait --for=condition=established crd/jobsets.jobset.x-k8s.io --timeout=60s + kubectl rollout status -n jobset-system deploy/jobset-controller-manager --timeout=120s + + - name: Build image + run: | + cat > Dockerfile <> $GITHUB_ENV + else + # Install the PJRT, JAX CUDA Plugin, and Nvidia CUDA packages from PyPI. + echo "JAXCI_JAX_PYPI_EXTRAS=cuda12">> $GITHUB_ENV + fi + else + gcloud storage cp -r "${{ inputs.gcs_download_uri }}/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/ && + gcloud storage cp -r "${{ inputs.gcs_download_uri }}/jax*cuda*plugin*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/ && + gcloud storage cp -r "${{ inputs.gcs_download_uri }}/jax*cuda*pjrt*${OS}*${ARCH}*.whl" $(pwd)/dist/ + + if [[ "${{ inputs.use-nvidia-pip-wheels }}" == true ]]; then + # Install the Nvidia CUDA packages from PyPI. The wheels downloaded in the previous + # step will be used for the PJRT and JAX CUDA Plugin packages. + echo "JAXCI_JAX_PYPI_EXTRAS=cuda12">> $GITHUB_ENV + fi + fi - name: Skip the test run if the wheel artifacts were not downloaded successfully if: steps.download-wheel-artifacts.outcome == 'failure' run: | @@ -101,12 +131,28 @@ jobs: echo "Skipping the test run." exit 1 - name: Install Python dependencies - run: $JAXCI_PYTHON -m uv pip install -r build/test-requirements.txt + run: | + # For prerelease python 3.14, some pre-built dependency wheels aren't available, + # so we need to download their deps or build them from source. + if [[ $JAXCI_PYTHON == "python3.14" ]]; then + # Build numpy from source + # Need to include fixes for https://github.com/numpy/numpy/issues/28681. + $JAXCI_PYTHON -m uv pip install "git+https://github.com/numpy/numpy@v2.3.0" + + # Install build requirements for scipy + apt update && apt upgrade -y && apt-get install -y gfortran libopenblas-dev liblapack-dev pkg-config --no-install-recommends + $JAXCI_PYTHON -m uv pip install "git+https://github.com/scipy/scipy@main" + + # Install build requirements for pillow + apt install -q -y libjpeg-dev --no-install-recommends + fi + + $JAXCI_PYTHON -m uv pip install -r build/test-requirements.txt # Halt for testing - name: Wait For Connection - uses: google-ml-infra/actions/ci_connection@main + uses: google-ml-infra/actions/ci_connection@7f5ca0c263a81ed09ea276524c1b9192f1304e3c with: halt-dispatch-input: ${{ inputs.halt-for-connection }} - name: Run Pytest CUDA tests - timeout-minutes: 60 + timeout-minutes: 120 run: ./ci/run_pytest_cuda.sh \ No newline at end of file diff --git a/.github/workflows/pytest_tpu.yml b/.github/workflows/pytest_tpu.yml index a105a2feb347..3bb88eef2e3b 100644 --- a/.github/workflows/pytest_tpu.yml +++ b/.github/workflows/pytest_tpu.yml @@ -11,7 +11,6 @@ # - Installs the downloaded jaxlib wheel. # - Runs the TPU tests with Pytest. name: CI - Pytest TPU - on: workflow_call: inputs: @@ -23,61 +22,57 @@ on: runner: description: "Which runner should the workflow run on?" type: string - required: true default: "linux-x86-ct5lp-224-8tpu" cores: description: "How many TPU cores should the test use?" type: string - required: true default: "8" tpu-type: description: "Which TPU type is used for testing?" type: string - required: true default: "v5e-8" python: description: "Which Python version should be used for testing?" type: string - required: true default: "3.12" run-full-tpu-test-suite: description: "Should the full TPU test suite be run?" 
type: string - required: false default: "0" libtpu-version-type: description: "Which libtpu version should be used for testing?" type: string - required: false # Choices are: # - "nightly": Use the nightly libtpu wheel. # - "pypi_latest": Use the latest libtpu wheel from PyPI. # - "oldest_supported_libtpu": Use the oldest supported libtpu wheel. default: "nightly" + download-jax-only-from-gcs: + description: "Whether to download only the jax wheel from GCS (e.g for testing a jax only release)" + default: '0' + type: string gcs_download_uri: description: "GCS location prefix from where the artifacts should be downloaded" - required: true default: 'gs://general-ml-ci-transient/jax-github-actions/jax/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}' type: string halt-for-connection: description: 'Should this workflow run wait for a remote connection?' - type: boolean - required: false - default: false - + type: string + default: 'no' +permissions: {} jobs: run-tests: defaults: run: shell: bash runs-on: ${{ inputs.runner }} - container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest" + container: "us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build:latest" # Begin Presubmit Naming Check - name modification requires internal check to be updated - name: "Pytest TPU (${{ inputs.tpu-type }}, Python ${{ inputs.python }}, libtpu=${{ inputs.libtpu-version-type }})" + name: "${{ inputs.tpu-type }}, py ${{ inputs.python }}, libtpu=${{ inputs.libtpu-version-type }}" # End Presubmit Naming Check github-tpu-presubmits env: - LIBTPU_OLDEST_VERSION_DATE: 20241205 + LIBTPU_OLDEST_VERSION_DATE: 20250228 JAXCI_HERMETIC_PYTHON_VERSION: "${{ inputs.python }}" JAXCI_PYTHON: "python${{ inputs.python }}" JAXCI_RUN_FULL_TPU_TEST_SUITE: "${{ inputs.run-full-tpu-test-suite }}" @@ -85,13 +80,15 @@ jobs: steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + persist-credentials: false - name: Set env vars for use in artifact download URL run: | os=$(uname -s | awk '{print tolower($0)}') arch=$(uname -m) # Get the major and minor version of Python. - # E.g if JAXCI_HERMETIC_PYTHON_VERSION=3.10, then python_major_minor=310 + # E.g if JAXCI_HERMETIC_PYTHON_VERSION=3.11, then python_major_minor=311 # E.g if JAXCI_HERMETIC_PYTHON_VERSION=3.13-nogil, then python_major_minor=313t python_major_minor=$(echo "${JAXCI_HERMETIC_PYTHON_VERSION//-nogil/t}" | tr -d '.') @@ -110,7 +107,11 @@ jobs: run: | mkdir -p $(pwd)/dist gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jax*py3*none*any.whl $(pwd)/dist/ - gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/ + if [[ "${{ inputs.download-jax-only-from-gcs }}" == "1" ]]; then + echo "JAX only release. Only downloading the jax wheel from the release bucket." + else + gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/ + fi - name: Skip the test run if the wheel artifacts were not downloaded successfully if: steps.download-wheel-artifacts.outcome == 'failure' run: | @@ -128,9 +129,9 @@ jobs: $JAXCI_PYTHON -m uv pip install --pre libtpu -f https://storage.googleapis.com/jax-releases/libtpu_releases.html elif [[ "${{ inputs.libtpu-version-type }}" == "pypi_latest" ]]; then echo "Using latest libtpu from PyPI" - # Set JAXCI_ADDITIONAL_WHEELS_INSTALL_FROM_PYPI to "tpu_pypi". The `run_pytest_tpu.sh` - # script will install the latest libtpu wheel from PyPI. 
- echo "JAXCI_ADDITIONAL_WHEELS_INSTALL_FROM_PYPI=tpu_pypi" >> $GITHUB_ENV + # Set JAXCI_JAX_PYPI_EXTRAS to "tpu". The `run_pytest_tpu.sh` script will install the + # latest libtpu wheel from PyPI. + echo "JAXCI_JAX_PYPI_EXTRAS=tpu" >> $GITHUB_ENV elif [[ "${{ inputs.libtpu-version-type }}" == "oldest_supported_libtpu" ]]; then echo "Using oldest supported libtpu" $JAXCI_PYTHON -m uv pip install --pre libtpu-nightly==0.1.dev${{ env.LIBTPU_OLDEST_VERSION_DATE }} \ @@ -143,7 +144,7 @@ jobs: fi # Halt for testing - name: Wait For Connection - uses: google-ml-infra/actions/ci_connection@main + uses: google-ml-infra/actions/ci_connection@7f5ca0c263a81ed09ea276524c1b9192f1304e3c with: halt-dispatch-input: ${{ inputs.halt-for-connection }} - name: Run Pytest TPU tests diff --git a/.github/workflows/release-notification.yml b/.github/workflows/release-notification.yml index a4a342ef6de7..6d68bf922655 100644 --- a/.github/workflows/release-notification.yml +++ b/.github/workflows/release-notification.yml @@ -2,14 +2,21 @@ name: Google Chat Release Notification on: release: types: [published] +permissions: {} jobs: build: + env: + WEBHOOK_URL: ${{ secrets.RELEASES_WEBHOOK }} + RELEASE_NAME: ${{github.event.release.name}} + PUBLISHED_AT: ${{github.event.release.published_at}} + AUTHOR_LOGIN: ${{github.event.release.author.login}} + RELEASE_URL: ${{github.event.release.url}} runs-on: ubuntu-latest steps: - name: Google Chat Notification run: | - curl --location --request POST '${{ secrets.RELEASES_WEBHOOK }}' \ + curl --location --request POST '${WEBHOOK_URL}' \ --header 'Content-Type: application/json' \ --data-raw '{ - "text": "Release ${{github.event.release.name}} at ${{github.event.release.published_at}} by ${{github.event.release.author.login}}. <${{github.event.release.url}}|[github]>" + "text": "Release $RELEASE_NAME at $PUBLISHED_AT by $AUTHOR_LOGIN. 
<$RELEASE_URL|[github]>" }' diff --git a/.github/workflows/requirements_lock_3_13_ft.patch b/.github/workflows/requirements_lock_3_13_ft.patch deleted file mode 100644 index 0b63cb5b8711..000000000000 --- a/.github/workflows/requirements_lock_3_13_ft.patch +++ /dev/null @@ -1,85 +0,0 @@ -diff --git a/build/requirements_lock_3_13_ft.txt b/build/requirements_lock_3_13_ft.txt -index e7a2968e9..d37e11ee3 100644 ---- a/build/requirements_lock_3_13_ft.txt -+++ b/build/requirements_lock_3_13_ft.txt -@@ -4,6 +4,11 @@ - # - # pip-compile --allow-unsafe --generate-hashes --output-file=build/requirements_lock_3_13_ft.txt build/requirements.in - # -+ -+--pre -+--extra-index-url https://pypi.anaconda.org/scientific-python-nightly-wheels/simple -+numpy -+ - absl-py==2.1.0 \ - --hash=sha256:526a04eadab8b4ee719ce68f204172ead1027549089702d99b9059f129ff1308 \ - --hash=sha256:7820790efbb316739cde8b4e19357243fc3608a152024288513dd968d7d959ff -@@ -328,68 +333,6 @@ mpmath==1.3.0 \ - --hash=sha256:7a28eb2a9774d00c7bc92411c19a89209d5da7c4c9a9e227be8330a23a25b91f \ - --hash=sha256:a0b2b9fe80bbcd81a6647ff13108738cfb482d481d826cc0e02f5b35e5c88d2c - # via -r build/test-requirements.txt --numpy==2.2.1 ; python_version >= "3.13" \ -- --hash=sha256:059e6a747ae84fce488c3ee397cee7e5f905fd1bda5fb18c66bc41807ff119b2 \ -- --hash=sha256:08ef779aed40dbc52729d6ffe7dd51df85796a702afbf68a4f4e41fafdc8bda5 \ -- --hash=sha256:164a829b6aacf79ca47ba4814b130c4020b202522a93d7bff2202bfb33b61c60 \ -- --hash=sha256:26c9c4382b19fcfbbed3238a14abf7ff223890ea1936b8890f058e7ba35e8d71 \ -- --hash=sha256:27f5cdf9f493b35f7e41e8368e7d7b4bbafaf9660cba53fb21d2cd174ec09631 \ -- --hash=sha256:31b89fa67a8042e96715c68e071a1200c4e172f93b0fbe01a14c0ff3ff820fc8 \ -- --hash=sha256:32cb94448be47c500d2c7a95f93e2f21a01f1fd05dd2beea1ccd049bb6001cd2 \ -- --hash=sha256:360137f8fb1b753c5cde3ac388597ad680eccbbbb3865ab65efea062c4a1fd16 \ -- --hash=sha256:3683a8d166f2692664262fd4900f207791d005fb088d7fdb973cc8d663626faa \ -- --hash=sha256:38efc1e56b73cc9b182fe55e56e63b044dd26a72128fd2fbd502f75555d92591 \ -- --hash=sha256:3d03883435a19794e41f147612a77a8f56d4e52822337844fff3d4040a142964 \ -- --hash=sha256:3ecc47cd7f6ea0336042be87d9e7da378e5c7e9b3c8ad0f7c966f714fc10d821 \ -- --hash=sha256:40f9e544c1c56ba8f1cf7686a8c9b5bb249e665d40d626a23899ba6d5d9e1484 \ -- --hash=sha256:4250888bcb96617e00bfa28ac24850a83c9f3a16db471eca2ee1f1714df0f957 \ -- --hash=sha256:4511d9e6071452b944207c8ce46ad2f897307910b402ea5fa975da32e0102800 \ -- --hash=sha256:45681fd7128c8ad1c379f0ca0776a8b0c6583d2f69889ddac01559dfe4390918 \ -- --hash=sha256:48fd472630715e1c1c89bf1feab55c29098cb403cc184b4859f9c86d4fcb6a95 \ -- --hash=sha256:4c86e2a209199ead7ee0af65e1d9992d1dce7e1f63c4b9a616500f93820658d0 \ -- --hash=sha256:4dfda918a13cc4f81e9118dea249e192ab167a0bb1966272d5503e39234d694e \ -- --hash=sha256:5062dc1a4e32a10dc2b8b13cedd58988261416e811c1dc4dbdea4f57eea61b0d \ -- --hash=sha256:51faf345324db860b515d3f364eaa93d0e0551a88d6218a7d61286554d190d73 \ -- --hash=sha256:526fc406ab991a340744aad7e25251dd47a6720a685fa3331e5c59fef5282a59 \ -- --hash=sha256:53c09385ff0b72ba79d8715683c1168c12e0b6e84fb0372e97553d1ea91efe51 \ -- --hash=sha256:55ba24ebe208344aa7a00e4482f65742969a039c2acfcb910bc6fcd776eb4355 \ -- --hash=sha256:5b6c390bfaef8c45a260554888966618328d30e72173697e5cabe6b285fb2348 \ -- --hash=sha256:5c5cc0cbabe9452038ed984d05ac87910f89370b9242371bd9079cb4af61811e \ -- --hash=sha256:5edb4e4caf751c1518e6a26a83501fda79bff41cc59dac48d70e6d65d4ec4440 \ -- 
--hash=sha256:61048b4a49b1c93fe13426e04e04fdf5a03f456616f6e98c7576144677598675 \ -- --hash=sha256:676f4eebf6b2d430300f1f4f4c2461685f8269f94c89698d832cdf9277f30b84 \ -- --hash=sha256:67d4cda6fa6ffa073b08c8372aa5fa767ceb10c9a0587c707505a6d426f4e046 \ -- --hash=sha256:694f9e921a0c8f252980e85bce61ebbd07ed2b7d4fa72d0e4246f2f8aa6642ab \ -- --hash=sha256:733585f9f4b62e9b3528dd1070ec4f52b8acf64215b60a845fa13ebd73cd0712 \ -- --hash=sha256:7671dc19c7019103ca44e8d94917eba8534c76133523ca8406822efdd19c9308 \ -- --hash=sha256:780077d95eafc2ccc3ced969db22377b3864e5b9a0ea5eb347cc93b3ea900315 \ -- --hash=sha256:7ba9cc93a91d86365a5d270dee221fdc04fb68d7478e6bf6af650de78a8339e3 \ -- --hash=sha256:89b16a18e7bba224ce5114db863e7029803c179979e1af6ad6a6b11f70545008 \ -- --hash=sha256:9036d6365d13b6cbe8f27a0eaf73ddcc070cae584e5ff94bb45e3e9d729feab5 \ -- --hash=sha256:93cf4e045bae74c90ca833cba583c14b62cb4ba2cba0abd2b141ab52548247e2 \ -- --hash=sha256:9ad014faa93dbb52c80d8f4d3dcf855865c876c9660cb9bd7553843dd03a4b1e \ -- --hash=sha256:9b1d07b53b78bf84a96898c1bc139ad7f10fda7423f5fd158fd0f47ec5e01ac7 \ -- --hash=sha256:a7746f235c47abc72b102d3bce9977714c2444bdfaea7888d241b4c4bb6a78bf \ -- --hash=sha256:aa3017c40d513ccac9621a2364f939d39e550c542eb2a894b4c8da92b38896ab \ -- --hash=sha256:b34d87e8a3090ea626003f87f9392b3929a7bbf4104a05b6667348b6bd4bf1cd \ -- --hash=sha256:b541032178a718c165a49638d28272b771053f628382d5e9d1c93df23ff58dbf \ -- --hash=sha256:ba5511d8f31c033a5fcbda22dd5c813630af98c70b2661f2d2c654ae3cdfcfc8 \ -- --hash=sha256:bc8a37ad5b22c08e2dbd27df2b3ef7e5c0864235805b1e718a235bcb200cf1cb \ -- --hash=sha256:bff7d8ec20f5f42607599f9994770fa65d76edca264a87b5e4ea5629bce12268 \ -- --hash=sha256:c1ad395cf254c4fbb5b2132fee391f361a6e8c1adbd28f2cd8e79308a615fe9d \ -- --hash=sha256:f1d09e520217618e76396377c81fba6f290d5f926f50c35f3a5f72b01a0da780 \ -- --hash=sha256:f3eac17d9ec51be534685ba877b6ab5edc3ab7ec95c8f163e5d7b39859524716 \ -- --hash=sha256:f419290bc8968a46c4933158c91a0012b7a99bb2e465d5ef5293879742f8797e \ -- --hash=sha256:f62aa6ee4eb43b024b0e5a01cf65a0bb078ef8c395e8713c6e8a12a697144528 \ -- --hash=sha256:f74e6fdeb9a265624ec3a3918430205dff1df7e95a230779746a6af78bc615af \ -- --hash=sha256:f9b57eaa3b0cd8db52049ed0330747b0364e899e8a606a624813452b8203d5f7 \ -- --hash=sha256:fce4f615f8ca31b2e61aa0eb5865a21e14f5629515c9151850aa936c02a1ee51 -- # via -- # -r build/requirements.in -- # contourpy -- # matplotlib -- # ml-dtypes -- # scipy - nvidia-cublas-cu12==12.8.3.14 ; sys_platform == "linux" \ - --hash=sha256:3f0e05e7293598cf61933258b73e66a160c27d59c4422670bf0b79348c04be44 \ - --hash=sha256:93a4e0e386cc7f6e56c822531396de8170ed17068a1e18f987574895044cd8c3 \ diff --git a/.github/workflows/rocm-ci.yml b/.github/workflows/rocm-ci.yml index 713e9099e381..ab4016f301a8 100644 --- a/.github/workflows/rocm-ci.yml +++ b/.github/workflows/rocm-ci.yml @@ -6,9 +6,7 @@ on: branches: - main -permissions: - contents: read - +permissions: {} concurrency: group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} @@ -18,8 +16,8 @@ jobs: env: BASE_IMAGE: "ubuntu:22.04" TEST_IMAGE: ubuntu-jax-upstream-${{ github.run_id }}_${{ github.run_number }}_${{ github.run_attempt }} - PYTHON_VERSION: "3.10" - ROCM_VERSION: "6.2.4" + PYTHON_VERSION: "3.11" + ROCM_VERSION: "6.3.3" WORKSPACE_DIR: workdir_${{ github.run_id }}_${{ github.run_number }}_${{ github.run_attempt }} steps: - name: Clean up old runs @@ -36,6 +34,7 @@ jobs: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: path: ${{ env.WORKSPACE_DIR }} + 
persist-credentials: false - name: Build JAX run: | pushd $WORKSPACE_DIR diff --git a/.github/workflows/tsan-suppressions.txt b/.github/workflows/tsan-suppressions_3.13.txt similarity index 50% rename from .github/workflows/tsan-suppressions.txt rename to .github/workflows/tsan-suppressions_3.13.txt index 7b713b2da194..3095eacf8060 100644 --- a/.github/workflows/tsan-suppressions.txt +++ b/.github/workflows/tsan-suppressions_3.13.txt @@ -2,34 +2,29 @@ # are racing on a call to __register_frame_info(), but that function appears to be correctly locked internally. race:llvm::RuntimeDyldELF::registerEHFrames -# https://github.com/python/cpython/issues/128050 -race:partial_vectorcall_fallback - # https://github.com/openxla/xla/issues/20686 race:dnnl_sgemm -# https://github.com/python/cpython/issues/128130 -race_top:run_eval_code_obj - -# Likely only happens when the process is crashing. -race:dump_traceback +# https://github.com/python/cpython/issues/128050 +race:partial_vectorcall_fallback # https://github.com/python/cpython/issues/128137 # Fixed in Python 3.14, but not backported to 3.13. race:immortalize_interned race:_PyUnicode_InternMortal +race:_PyUnicode_InternImmortal # https://github.com/python/cpython/issues/128144 # Fixed in Python 3.14, but not backported to 3.13. race_top:PyMember_GetOne -# https://github.com/python/cpython/issues/129547 -race:type_get_annotations - - -# https://github.com/python/cpython/issues/129748 -race:mi_block_set_nextx +# https://github.com/python/cpython/issues/131680 +# Fixed in Python 3.14, but not backported to 3.13. +race_top:new_reference +race:_Py_IsOwnedByCurrentThread +# https://github.com/python/cpython/issues/128130 +race_top:run_eval_code_obj # Races because the LAPACK and BLAS in our scipy isn't TSAN instrumented. race:heevd_ffi @@ -39,29 +34,6 @@ race:scal_k_ race:gemm_beta race:gemm_oncopy - - -# Races below this point are likely fixed. -# TODO(phawkins): remove these if they don't show up in CI again. - -# https://github.com/python/cpython/issues/128100 -# race:ensure_nonmanaged_dict - -# https://github.com/python/cpython/issues/128657 -# race:py_digest_by_name - -# https://github.com/python/cpython/issues/128714 -# race:func_get_annotations - -# https://github.com/python/cpython/issues/129533 -# race:PyGC_Disable -# race:PyGC_Enable - -# https://github.com/python/cpython/issues/128133 -# race:bytes_hash - -# https://github.com/python/cpython/issues/130571 -# race:_PyObject_GetMethod - -# https://github.com/python/cpython/issues/130547 -# race:split_keys_entry_added +# https://github.com/python/cpython/issues/132214 +# Fixed in Python 3.15, but not backported to 3.13, 3.14. +race:type_update_dict diff --git a/.github/workflows/tsan-suppressions_3.14.txt b/.github/workflows/tsan-suppressions_3.14.txt new file mode 100644 index 000000000000..d987879cab58 --- /dev/null +++ b/.github/workflows/tsan-suppressions_3.14.txt @@ -0,0 +1,21 @@ +# false-positive caused because we haven't tsan-instrumented libgcc_s. Multiple threads +# are racing on a call to __register_frame_info(), but that function appears to be correctly locked internally. +race:llvm::RuntimeDyldELF::registerEHFrames + +# https://github.com/openxla/xla/issues/20686 +race:dnnl_sgemm + +# https://github.com/python/cpython/issues/128050 +race:partial_vectorcall_fallback + +# Races because the LAPACK and BLAS in our scipy isn't TSAN instrumented. 
+race:heevd_ffi +race:gesdd_ffi +race:dscal_k_ +race:scal_k_ +race:gemm_beta +race:gemm_oncopy + +# https://github.com/python/cpython/issues/132214 +# Fixed in Python 3.15, but not backported to 3.13, 3.14. +race:type_update_dict diff --git a/.github/workflows/tsan.yaml b/.github/workflows/tsan.yaml index 7d93707e4e92..c3ee37dd82f4 100644 --- a/.github/workflows/tsan.yaml +++ b/.github/workflows/tsan.yaml @@ -3,7 +3,6 @@ name: CI - Free-threading and Thread Sanitizer (nightly) concurrency: group: ${{ github.workflow }}-${{ github.ref }} cancel-in-progress: true - on: schedule: - cron: "0 5 * * *" # Daily at 05:00 UTC == 00:00 EST == 21:00 PST @@ -13,7 +12,8 @@ on: - main paths: - '**/workflows/tsan.yaml' - + - '**/workflows/tsan-suppressions*.txt' +permissions: {} jobs: tsan: runs-on: linux-x86-n2-64 @@ -21,6 +21,16 @@ jobs: image: index.docker.io/library/ubuntu@sha256:b359f1067efa76f37863778f7b6d0e8d911e3ee8efa807ad01fbf5dc1ef9006b # ratchet:ubuntu:24.04 strategy: fail-fast: false + matrix: + include: + - name-prefix: "with 3.13" + python-version: "3.13" + github_branch: "3.13" + requirements_lock_name: "requirements_lock_3_13_ft" + - name-prefix: "with 3.14" + python-version: "3.14" + github_branch: "3.14" + requirements_lock_name: "requirements_lock_3_14_ft" defaults: run: shell: bash -l {0} @@ -32,33 +42,51 @@ jobs: DEBIAN_FRONTEND: noninteractive run: | apt update - apt install -y clang-18 libstdc++-14-dev build-essential libssl-dev \ + apt install -q -y clang-18 libstdc++-14-dev build-essential libssl-dev \ zlib1g-dev libbz2-dev libreadline-dev libsqlite3-dev curl git \ libncursesw5-dev xz-utils tk-dev libxml2-dev libxmlsec1-dev \ libffi-dev liblzma-dev file zip - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: path: jax - - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - with: - repository: python/cpython - path: cpython - ref: "3.13" + persist-credentials: false - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: repository: numpy/numpy path: numpy submodules: true + persist-credentials: false + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + if: ${{ matrix.python-version == '3.14' }} + with: + repository: scipy/scipy + path: scipy + submodules: true + persist-credentials: false + + - name: Get year & week number + id: get-date + run: echo "date=$(/bin/date "+%Y-%U")" >> $GITHUB_OUTPUT + shell: bash -l {0} - - name: Restore cached TSAN CPython + - name: Restore cached TSAN CPython ${{ matrix.python-version }} id: cache-cpython-tsan-restore - uses: actions/cache/restore@1bd1e32a3bdc45362d1e726936510720a7c30a57 # v4.2.0 + uses: actions/cache/restore@5a3ec84eff668545956fd18022155c47e93e2684 # v4.2.3 with: path: | ./python-tsan.tgz - key: ${{ runner.os }}-cpython-tsan-${{ hashFiles('cpython/configure.ac') }} + key: ${{ runner.os }}-cpython-tsan-${{ matrix.python-version }}-${{ steps.get-date.outputs.date }} - - name: Build CPython with enabled TSAN + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + if: steps.cache-cpython-tsan-restore.outputs.cache-hit != 'true' + with: + repository: python/cpython + path: cpython + ref: ${{ matrix.github_branch }} + persist-credentials: false + + + - name: Build TSAN CPython ${{ matrix.python-version }} if: steps.cache-cpython-tsan-restore.outputs.cache-hit != 'true' run: | cd cpython @@ -72,31 +100,27 @@ jobs: # Create archive to be used with bazel as hermetic python: cd ${GITHUB_WORKSPACE} && tar -czpf 
python-tsan.tgz cpython-tsan - - name: Save TSAN CPython + - name: Save TSAN CPython ${{ matrix.python-version }} id: cache-cpython-tsan-save if: steps.cache-cpython-tsan-restore.outputs.cache-hit != 'true' - uses: actions/cache/save@1bd1e32a3bdc45362d1e726936510720a7c30a57 # v4.2.0 + uses: actions/cache/save@5a3ec84eff668545956fd18022155c47e93e2684 # v4.2.3 with: path: | ./python-tsan.tgz - key: ${{ runner.os }}-cpython-tsan-${{ hashFiles('cpython/configure.ac') }} - - - name: Get year & week number - id: get-date - run: echo "date=$(/bin/date "+%Y-%U")" >> $GITHUB_OUTPUT - shell: bash -l {0} + key: ${{ runner.os }}-cpython-tsan-${{ matrix.python-version }}-${{ steps.get-date.outputs.date }} - name: Restore cached TSAN Numpy id: cache-numpy-tsan-restore - uses: actions/cache/restore@1bd1e32a3bdc45362d1e726936510720a7c30a57 # v4.2.0 + uses: actions/cache/restore@5a3ec84eff668545956fd18022155c47e93e2684 # v4.2.3 with: path: | ./wheelhouse - key: ${{ runner.os }}-numpy-tsan-${{ hashFiles('numpy/pyproject.toml') }}-${{ steps.get-date.outputs.date }} + key: ${{ runner.os }}-numpy-tsan-${{ matrix.python-version }}-${{ hashFiles('numpy/pyproject.toml') }}-${{ steps.get-date.outputs.date }} - name: Build TSAN Numpy wheel if: steps.cache-numpy-tsan-restore.outputs.cache-hit != 'true' run: | + set -eux cd numpy # If we restored cpython from cache, we need to get python interpreter from python-tsan.tgz @@ -112,8 +136,7 @@ jobs: export PATH=${GITHUB_WORKSPACE}/cpython-tsan/bin/:$PATH python3 -m pip install uv~=0.5.30 - # Make sure to install a compatible Cython version (master branch is best for now) - python3 -m uv pip install -r requirements/build_requirements.txt -U git+https://github.com/cython/cython + python3 -m uv pip install -r requirements/build_requirements.txt CC=clang-18 CXX=clang++-18 python3 -m pip wheel --wheel-dir dist -v . --no-build-isolation -Csetup-args=-Db_sanitize=thread -Csetup-args=-Dbuildtype=debugoptimized @@ -142,11 +165,86 @@ jobs: - name: Save TSAN Numpy wheel id: cache-numpy-tsan-save if: steps.cache-numpy-tsan-restore.outputs.cache-hit != 'true' - uses: actions/cache/save@1bd1e32a3bdc45362d1e726936510720a7c30a57 # v4.2.0 + uses: actions/cache/save@5a3ec84eff668545956fd18022155c47e93e2684 # v4.2.3 + with: + path: | + ./wheelhouse + key: ${{ runner.os }}-numpy-tsan-${{ matrix.python-version }}-${{ hashFiles('numpy/pyproject.toml') }}-${{ steps.get-date.outputs.date }} + + - name: Restore cached Scipy + if: ${{ matrix.python-version == '3.14' }} + id: cache-scipy-restore + uses: actions/cache/restore@5a3ec84eff668545956fd18022155c47e93e2684 # v4.2.3 with: path: | ./wheelhouse - key: ${{ runner.os }}-numpy-tsan-${{ hashFiles('numpy/pyproject.toml') }}-${{ steps.get-date.outputs.date }} + key: ${{ runner.os }}-scipy-${{ matrix.python-version }}-${{ hashFiles('scipy/pyproject.toml') }}-${{ steps.get-date.outputs.date }} + + - name: Build Scipy wheel + if: ${{ steps.cache-scipy-restore.outputs.cache-hit != 'true' && matrix.python-version == '3.14' }} + env: + DEBIAN_FRONTEND: noninteractive + run: | + # Install scipy dependencies: + apt install -q -y gfortran libopenblas-dev liblapack-dev pkg-config --no-install-recommends + + cd scipy + + # If we restored cpython from cache, we need to get python interpreter from python-tsan.tgz + if [ ! -d ${GITHUB_WORKSPACE}/cpython-tsan/bin/ ]; then + echo "Extract cpython from python-tsan.tgz" + pushd . 
+ ls ${GITHUB_WORKSPACE}/python-tsan.tgz + cd ${GITHUB_WORKSPACE} && tar -xzf python-tsan.tgz + ls ${GITHUB_WORKSPACE}/cpython-tsan/bin/ + popd + fi + + export PATH=${GITHUB_WORKSPACE}/cpython-tsan/bin/:$PATH + + python3 -m pip install uv~=0.5.30 + + python3 -m uv pip install -U --pre numpy --extra-index-url file://${GITHUB_WORKSPACE}/wheelhouse/ + python3 -m uv pip install cython pythran pybind11 meson-python ninja + + python3 -m uv pip list | grep -E "(numpy|pythran|cython|pybind11)" + + export CC=clang-18 + export CXX=clang++-18 + python3 -m pip wheel --wheel-dir dist -vvv . --no-build-isolation --no-deps -Csetup-args=-Dbuildtype=debugoptimized + + # Create simple index and copy the wheel + mkdir -p ${GITHUB_WORKSPACE}/wheelhouse/scipy + + scipy_whl_name=($(cd dist && ls scipy*.whl)) + if [ -z "${scipy_whl_name}" ]; then exit 1; fi + + echo "Built TSAN Scipy wheel: ${scipy_whl_name}" + + cp dist/${scipy_whl_name} ${GITHUB_WORKSPACE}/wheelhouse/scipy + + # Recreate wheelhouse index with Numpy and Scipy + cat << EOF > ${GITHUB_WORKSPACE}/wheelhouse/index.html + + numpy>
+ scipy>
+ + EOF + + cat << EOF > ${GITHUB_WORKSPACE}/wheelhouse/scipy/index.html + + ${scipy_whl_name}
+ + EOF + + - name: Save Scipy wheel + id: cache-scipy-save + if: ${{ steps.cache-scipy-restore.outputs.cache-hit != 'true' && matrix.python-version == '3.14' }} + uses: actions/cache/save@5a3ec84eff668545956fd18022155c47e93e2684 # v4.2.3 + with: + path: | + ./wheelhouse + key: ${{ runner.os }}-scipy-${{ matrix.python-version }}-${{ hashFiles('scipy/pyproject.toml') }}-${{ steps.get-date.outputs.date }} - name: Build Jax and run tests timeout-minutes: 120 @@ -155,15 +253,16 @@ jobs: JAX_ENABLE_X64: true JAX_SKIP_SLOW_TESTS: true PY_COLORS: 1 + DEBIAN_FRONTEND: noninteractive run: | + set -x cd jax export PYTHON_SHA256=($(sha256sum ${GITHUB_WORKSPACE}/python-tsan.tgz)) echo "Python sha256: ${PYTHON_SHA256}" - python3 -VV python3 build/build.py build --configure_only \ - --python_version=3.13-ft \ + --python_version=${{ matrix.python-version }}-ft \ --bazel_options=--repo_env=HERMETIC_PYTHON_URL="file://${GITHUB_WORKSPACE}/python-tsan.tgz" \ --bazel_options=--repo_env=HERMETIC_PYTHON_SHA256=${PYTHON_SHA256} \ --bazel_options=--repo_env=HERMETIC_PYTHON_PREFIX="cpython-tsan/" \ @@ -173,18 +272,24 @@ jobs: --bazel_options=--copt=-g \ --clang_path=/usr/bin/clang-18 - # Patch build/requirements_lock_3_13_ft.txt to use TSAN instrumented NumPy - sed -i "s|+--extra-index-url.*|+--extra-index-url file://${GITHUB_WORKSPACE}/wheelhouse/|" .github/workflows/requirements_lock_3_13_ft.patch - cat .github/workflows/requirements_lock_3_13_ft.patch - git apply .github/workflows/requirements_lock_3_13_ft.patch || exit 1 + mkdir -p dist + # Check whether we have numpy wheel or exit with error + ls ${GITHUB_WORKSPACE}/wheelhouse/numpy/*.whl || exit 1 + cp -v ${GITHUB_WORKSPACE}/wheelhouse/numpy/*.whl dist/ + if [ "${{ matrix.python-version }}" == "3.14" ]; then + # Check whether we have scipy wheel or exit with error + ls ${GITHUB_WORKSPACE}/wheelhouse/scipy/*.whl || exit 1 + cp -v ${GITHUB_WORKSPACE}/wheelhouse/scipy/*.whl dist/ + + # Patch build/requirements_lock_3_14_ft.txt to use TSAN instrumented NumPy and Scipy + sed -i "s|--extra-index-url.*|--extra-index-url file://${GITHUB_WORKSPACE}/wheelhouse/|" build/${{ matrix.requirements_lock_name }}.txt - # Display the content for debugging in logs - cat build/requirements_lock_3_13_ft.txt | head -15 - # Check the patch - cat build/requirements_lock_3_13_ft.txt | head -15 | grep -E "(--pre|.*${GITHUB_WORKSPACE}/wheelhouse/|numpy)" - if [ "$?" == "1" ]; then echo "Could not find the patch in the requirements_lock_3_13_ft.txt"; exit 1; fi - cat build/requirements_lock_3_13_ft.txt | grep -E "(numpy==)" - if [ "$?" 
== "0" ]; then "Found original numpy dependency in the requirements_lock_3_13_ft.txt"; exit 1; fi + # We should install jpeg dev package to be able to build Pillow from source: + apt install -q -y libjpeg-dev --no-install-recommends + + # Install scipy runtime dependencies (in case we restore scipy wheel from cache): + apt install -q -y libopenblas-dev liblapack-dev --no-install-recommends + fi echo "JAX_NUM_GENERATED_CASES=$JAX_NUM_GENERATED_CASES" echo "JAX_ENABLE_X64=$JAX_ENABLE_X64" @@ -200,17 +305,22 @@ jobs: # Check numpy version ./bazel cquery @pypi_numpy//:* | grep whl + if [ "${{ matrix.python-version }}" == "3.14" ]; then + # Check scipy version + ./bazel cquery @pypi_scipy//:* | grep whl + fi + # Build JAX and run tests ./bazel test \ --test_env=JAX_NUM_GENERATED_CASES=$JAX_NUM_GENERATED_CASES \ --test_env=JAX_ENABLE_X64=$JAX_ENABLE_X64 \ --test_env=JAX_SKIP_SLOW_TESTS=$JAX_SKIP_SLOW_TESTS \ --test_env=PYTHON_GIL=0 \ - --test_env=TSAN_OPTIONS=halt_on_error=1,suppressions=$PWD/.github/workflows/tsan-suppressions.txt \ + --test_env=TSAN_OPTIONS=halt_on_error=1,suppressions=$PWD/.github/workflows/tsan-suppressions_${{ matrix.python-version }}.txt \ --test_env=JAX_TEST_NUM_THREADS=8 \ --test_output=errors \ --local_test_jobs=32 \ - --test_timeout=600 \ + --test_timeout=1800 \ --config=resultstore \ --config=rbe_cache \ //tests:cpu_tests diff --git a/.github/workflows/upstream-nightly.yml b/.github/workflows/upstream-nightly.yml index 5132a12cf16f..23b8ac32d844 100644 --- a/.github/workflows/upstream-nightly.yml +++ b/.github/workflows/upstream-nightly.yml @@ -19,10 +19,11 @@ on: - main paths: - '**workflows/upstream-nightly.yml' - +permissions: {} jobs: upstream-dev: - runs-on: ubuntu-latest + runs-on: linux-x86-n2-64 + container: index.docker.io/library/ubuntu@sha256:b359f1067efa76f37863778f7b6d0e8d911e3ee8efa807ad01fbf5dc1ef9006b # ratchet:ubuntu:24.04 permissions: contents: read issues: write # for failed-build-issue @@ -32,8 +33,10 @@ jobs: python-version: ["3.13"] steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + persist-credentials: false - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0 + uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0 with: python-version: ${{ matrix.python-version }} - name: Install JAX test requirements @@ -66,7 +69,7 @@ jobs: echo "JAX_ENABLE_X64=$JAX_ENABLE_X64" echo "JAX_ENABLE_CHECKS=$JAX_ENABLE_CHECKS" echo "JAX_SKIP_SLOW_TESTS=$JAX_SKIP_SLOW_TESTS" - pytest -n 2 --tb=short --maxfail=20 tests examples + pytest -n auto --tb=short --maxfail=20 tests examples - name: Notify failed build uses: jayqi/failed-build-issue-action@1a893bbf43ef1c2a8705e2b115cd4f0fe3c5649b # v1.2.0 if: failure() && github.event.pull_request == null diff --git a/.github/workflows/wheel_tests_continuous.yml b/.github/workflows/wheel_tests_continuous.yml index ecdf43b133cc..99caad6325a0 100644 --- a/.github/workflows/wheel_tests_continuous.yml +++ b/.github/workflows/wheel_tests_continuous.yml @@ -9,17 +9,21 @@ # that was built in the previous step and runs CPU tests. # 3. build-cuda-artifacts: Calls the `build_artifacts.yml` workflow to build CUDA artifacts and # uploads them to a GCS bucket. -# 4. run-pytest-cuda: Calls the `pytest_cuda.yml` workflow which downloads the jaxlib and CUDA +# 4. run-bazel-test-cpu-py-import: Calls the `bazel_cpu_py_import_rbe.yml` workflow which +# runs Bazel CPU tests with py_import on RBE. +# 5. 
run-pytest-cuda: Calls the `pytest_cuda.yml` workflow which downloads the jaxlib and CUDA # artifacts that were built in the previous steps and runs the CUDA tests. -# 5. run-bazel-test-cuda: Calls the `bazel_cuda_non_rbe.yml` workflow which downloads the jaxlib +# 6. run-bazel-test-cuda: Calls the `bazel_cuda_non_rbe.yml` workflow which downloads the jaxlib # and CUDA artifacts that were built in the previous steps and runs the # CUDA tests using Bazel. name: CI - Wheel Tests (Continuous) +permissions: + contents: read on: schedule: - - cron: "0 */2 * * *" # Run once every 2 hours + - cron: "0 */3 * * *" # Run once every 3 hours workflow_dispatch: # allows triggering the workflow run manually concurrency: @@ -44,9 +48,9 @@ jobs: fail-fast: false # don't cancel all jobs on failure matrix: # Runner OS and Python values need to match the matrix stategy in the CPU tests job - runner: ["linux-x86-n2-16", "linux-arm64-t2a-48", "windows-x86-n2-64"] + runner: ["linux-x86-n2-16", "linux-arm64-t2a-48", "windows-x86-n2-16"] artifact: ["jaxlib"] - python: ["3.10"] + python: ["3.11"] # Note: For reasons unknown, Github actions groups jobs with the same top-level name in the # dashboard only if we use an expression in the "name" field. Otherwise, it appends the matrix # values to the name and creates a separate entry for each matrix combination. @@ -67,7 +71,7 @@ jobs: # Python values need to match the matrix stategy in the CUDA tests job below runner: ["linux-x86-n2-16"] artifact: ["jax-cuda-plugin", "jax-cuda-pjrt"] - python: ["3.10",] + python: ["3.11",] name: "Build ${{ format('{0}', 'CUDA') }} artifacts" with: runner: ${{ matrix.runner }} @@ -90,7 +94,7 @@ jobs: # Runner OS and Python values need to match the matrix stategy in the # build_jaxlib_artifact job above runner: ["linux-x86-n2-64", "linux-arm64-t2a-48", "windows-x86-n2-64"] - python: ["3.10",] + python: ["3.11",] enable-x64: [1, 0] name: "Pytest CPU (JAX artifacts version = ${{ format('{0}', 'head') }})" with: @@ -111,57 +115,72 @@ jobs: matrix: # Python values need to match the matrix stategy in the artifact build jobs above # See exlusions for what is fully tested - runner: ["linux-x86-g2-48-l4-4gpu", "linux-x86-a3-8g-h100-8gpu","linux-x86-a4-224-b200-1gpu"] - python: ["3.10",] - cuda: ["12.1","12.3","12.8"] + runner: ["linux-x86-g2-48-l4-4gpu", "linux-x86-a3-8g-h100-8gpu", "linux-x86-a4-224-b200-1gpu"] + python: ["3.11",] + cuda: [ + {version: "12.1", use-nvidia-pip-wheels: false}, + {version: "12.8", use-nvidia-pip-wheels: true}, + ] enable-x64: [1, 0] exclude: - # L4 does not run on cuda 12.8 but tests other configs - - runner: "linux-x86-g2-48-l4-4gpu" - cuda: "12.8" - # H100 runs only a single config, CUDA 12.3 Enable x64 1 + # H100 runs only a single config, CUDA 12.8 Enable x64 1 - runner: "linux-x86-a3-8g-h100-8gpu" - cuda: "12.8" - - runner: "linux-x86-a3-8g-h100-8gpu" - cuda: "12.1" + cuda: + version: "12.1" - runner: "linux-x86-a3-8g-h100-8gpu" enable-x64: "0" # B200 runs only a single config, CUDA 12.8 Enable x64 1 - runner: "linux-x86-a4-224-b200-1gpu" - enable-x64: "0" - - runner: "linux-x86-a4-224-b200-1gpu" - cuda: "12.1" + cuda: + version: "12.1" - runner: "linux-x86-a4-224-b200-1gpu" - cuda: "12.3" + enable-x64: "0" - name: "Pytest CUDA (JAX artifacts version = ${{ format('{0}', 'head') }})" + name: "Pytest CUDA (JAX artifacts version = ${{ format('{0}', 'head') }}, CUDA Pip packages = ${{ matrix.cuda.use-nvidia-pip-wheels }})" with: runner: ${{ matrix.runner }} python: ${{ matrix.python }} - cuda: ${{ matrix.cuda }} 
+ cuda-version: ${{ matrix.cuda.version }} + use-nvidia-pip-wheels: ${{ matrix.cuda.use-nvidia-pip-wheels }} enable-x64: ${{ matrix.enable-x64 }} # GCS upload URI is the same for both artifact build jobs gcs_download_uri: ${{ needs.build-jaxlib-artifact.outputs.gcs_upload_uri }} + run-bazel-test-cpu-py-import: + uses: ./.github/workflows/bazel_cpu_py_import_rbe.yml + strategy: + fail-fast: false # don't cancel all jobs on failure + matrix: + runner: ["linux-x86-n2-16", "linux-arm64-t2a-48"] + python: ["3.11",] + enable-x64: [1, 0] + name: "Bazel CPU tests with ${{ format('{0}', 'py_import') }}" + with: + runner: ${{ matrix.runner }} + python: ${{ matrix.python }} + enable-x64: ${{ matrix.enable-x64 }} + run-bazel-test-cuda: # Run test jobs even if the build job fails. Avoids losing test coverage if a single unrelated # build job fails. E.g Windows build job fails but everything else succeeds. In this case, we # still want to run the tests for other platforms. if: ${{ !cancelled() }} - needs: [build-jaxlib-artifact, build-cuda-artifacts] + needs: [build-jax-artifact, build-jaxlib-artifact, build-cuda-artifacts] uses: ./.github/workflows/bazel_cuda_non_rbe.yml strategy: fail-fast: false # don't cancel all jobs on failure matrix: # Python values need to match the matrix stategy in the build artifacts job above runner: ["linux-x86-g2-48-l4-4gpu",] - python: ["3.10",] + python: ["3.11",] + jaxlib-version: ["head", "pypi_latest"] enable-x64: [1, 0] - name: "Bazel CUDA Non-RBE (JAX artifacts version = ${{ format('{0}', 'head') }})" + name: "Bazel CUDA Non-RBE (jax version = ${{ format('{0}', 'head') }})" with: runner: ${{ matrix.runner }} python: ${{ matrix.python }} enable-x64: ${{ matrix.enable-x64 }} + jaxlib-version: ${{ matrix.jaxlib-version }} # GCS upload URI is the same for both artifact build jobs gcs_download_uri: ${{ needs.build-jaxlib-artifact.outputs.gcs_upload_uri }} @@ -175,12 +194,22 @@ jobs: strategy: fail-fast: false # don't cancel all jobs on failure matrix: - python: ["3.10",] + python: ["3.11",] tpu-specs: [ # {type: "v3-8", cores: "4"}, # Enable when we have the v3 type available {type: "v4-8", cores: "4", runner: "linux-x86-ct4p-240-4tpu"}, - {type: "v5e-8", cores: "8", runner: "linux-x86-ct5lp-224-8tpu"} + {type: "v5e-8", cores: "8", runner: "linux-x86-ct5lp-224-8tpu"}, + {type: "v6e-8", cores: "8", runner: "linux-x86-ct6e-180-8tpu"} ] + libtpu-version-type: ["nightly", "oldest_supported_libtpu"] + exclude: + # Run a single config for oldest_supported_libtpu + - libtpu-version-type: "oldest_supported_libtpu" + tpu-specs: + type: "v4-8" + - libtpu-version-type: "oldest_supported_libtpu" + tpu-specs: + type: "v6e-8" name: "Pytest TPU (JAX artifacts version = ${{ format('{0}', 'head') }})" with: runner: ${{ matrix.tpu-specs.runner }} @@ -188,5 +217,5 @@ jobs: tpu-type: ${{ matrix.tpu-specs.type }} python: ${{ matrix.python }} run-full-tpu-test-suite: "1" - libtpu-version-type: "nightly" + libtpu-version-type: ${{ matrix.libtpu-version-type }} gcs_download_uri: ${{ needs.build-jaxlib-artifact.outputs.gcs_upload_uri }} \ No newline at end of file diff --git a/.github/workflows/wheel_tests_nightly_release.yml b/.github/workflows/wheel_tests_nightly_release.yml index adb678be9d9d..da6d87495b21 100644 --- a/.github/workflows/wheel_tests_nightly_release.yml +++ b/.github/workflows/wheel_tests_nightly_release.yml @@ -1,12 +1,14 @@ # CI - Wheel Tests (Nightly/Release) # -# This workflow builds JAX artifacts and runs CPU/CUDA tests. 
+# This workflow is used to test the JAX wheels that was built by internal CI jobs. # -# It orchestrates the following: -# 1. run-pytest-cpu: Calls the `pytest_cpu.yml` workflow which downloads the jaxlib wheel that was +# 1. run-pytest-cpu: Calls the `pytest_cpu.yml` workflow which downloads the JAX wheels that was # built by internal CI jobs and runs CPU tests. -# 2. run-pytest-cuda: Calls the `pytest_cuda.yml` workflow which downloads the jaxlib and CUDA -# artifacts that were built by internal CI jobs and runs the CUDA tests. +# 2. run-pytest-cuda: Calls the `pytest_cuda.yml` workflow which downloads the JAX wheels that was +# built by internal CI jobs and runs CUDA tests. +# 3. run-pytest-tpu: Calls the `pytest_tpu.yml` workflow which downloads the JAX wheels that was +# built by internal CI jobs and runs TPU tests. +# 4. verify-release-wheels-install: Verifies that JAX's release wheels can be installed. name: CI - Wheel Tests (Nightly/Release) on: @@ -15,15 +17,26 @@ on: gcs_download_uri: description: "GCS location URI from where the artifacts should be downloaded" required: true - default: 'gs://jax-nightly-release-transient/nightly/latest' + default: 'gs://jax-nightly-artifacts/latest' + type: string + download-jax-only-from-gcs: + description: "Whether to download only the jax wheel from GCS (e.g for testing a jax only release)" + required: true + default: '0' + type: string + halt-for-connection: + description: 'Should this workflow run wait for a remote connection? (yes/no)' + required: false + default: 'no' type: string concurrency: group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} cancel-in-progress: true - +permissions: {} jobs: run-pytest-cpu: + if: ! ( matrix.python == '3.14' && startsWith(github.ref_name, 'release/') ) uses: ./.github/workflows/pytest_cpu.yml strategy: fail-fast: false # don't cancel all jobs on failure @@ -31,19 +44,24 @@ jobs: # Runner OS and Python values need to match the matrix stategy of our internal CI jobs # that build the wheels. runner: ["linux-x86-n2-64", "linux-arm64-t2a-48", "windows-x86-n2-64"] - python: ["3.10","3.11", "3.12", "3.13", "3.13-nogil"] + python: ["3.11", "3.12", "3.13", "3.13-nogil", "3.14"] enable-x64: [0] exclude: - runner: "windows-x86-n2-64" python: "3.13-nogil" + - runner: "windows-x86-n2-64" + python: "3.14" name: "Pytest CPU (JAX artifacts version = ${{ startsWith(github.ref_name, 'release/') && 'latest release' || 'nightly' }})" with: runner: ${{ matrix.runner }} python: ${{ matrix.python }} enable-x64: ${{ matrix.enable-x64 }} + download-jax-only-from-gcs: ${{inputs.download-jax-only-from-gcs}} gcs_download_uri: ${{inputs.gcs_download_uri}} + halt-for-connection: ${{inputs.halt-for-connection}} run-pytest-cuda: + if: ! ( matrix.python == '3.14' && startsWith(github.ref_name, 'release/') ) uses: ./.github/workflows/pytest_cuda.yml strategy: fail-fast: false # don't cancel all jobs on failure @@ -51,45 +69,66 @@ jobs: # Runner OS and Python values need to match the matrix stategy of our internal CI jobs # that build the wheels. 
runner: ["linux-x86-g2-48-l4-4gpu"] - python: ["3.10","3.11", "3.12", "3.13", "3.13-nogil"] - cuda: ["12.3", "12.1"] + python: ["3.11", "3.12", "3.13", "3.13-nogil", "3.14"] + cuda: [ + {cuda-version: "12.1", use-nvidia-pip-wheels: false}, + {cuda-version: "12.8", use-nvidia-pip-wheels: true} + ] enable-x64: [0] - name: "Pytest CUDA (JAX artifacts version = ${{ startsWith(github.ref_name, 'release/') && 'latest release' || 'nightly' }})" + name: "Pytest CUDA (JAX artifacts version = ${{ startsWith(github.ref_name, 'release/') && 'latest release' || 'nightly' }}, CUDA Pip packages = ${{ matrix.cuda.use-nvidia-pip-wheels }})" with: runner: ${{ matrix.runner }} python: ${{ matrix.python }} - cuda: ${{ matrix.cuda }} + cuda-version: ${{ matrix.cuda.cuda-version }} + use-nvidia-pip-wheels: ${{ matrix.cuda.use-nvidia-pip-wheels }} enable-x64: ${{ matrix.enable-x64 }} + download-jax-only-from-gcs: ${{inputs.download-jax-only-from-gcs}} gcs_download_uri: ${{inputs.gcs_download_uri}} + halt-for-connection: ${{inputs.halt-for-connection}} run-pytest-tpu: uses: ./.github/workflows/pytest_tpu.yml strategy: fail-fast: false # don't cancel all jobs on failure matrix: - # Skip Python 3.13 as it fails due to missing TensorFlow wheels (used for - # profiler_test.py, build/collect-profile-requirements.txt) for that version (b/402590302) - python: ["3.10", "3.11", "3.12"] + python: ["3.11", "3.12", "3.13", "3.13-nogil"] tpu-specs: [ # {type: "v3-8", cores: "4"}, # Enable when we have the v3 type available {type: "v4-8", cores: "4", runner: "linux-x86-ct4p-240-4tpu"}, - {type: "v5e-8", cores: "8", runner: "linux-x86-ct5lp-224-8tpu"} + {type: "v5e-8", cores: "8", runner: "linux-x86-ct5lp-224-8tpu"}, + {type: "v6e-8", cores: "8", runner: "linux-x86-ct6e-180-8tpu"} ] - libtpu-version-type: ["pypi_latest", "nightly", "oldest_supported_libtpu"] + libtpu-version-type: ["pypi_latest", "nightly"] exclude: + # Exclude nightly for releases - libtpu-version-type: ${{ startsWith(github.ref_name, 'release/') && 'nightly' }} + # Exclude pypi_latest for nightly releases - libtpu-version-type: ${{ !startsWith(github.ref_name, 'release/') && 'pypi_latest' }} - # Run a single Python version for v4-8. 
+ # Run a single Python version for v4-8 - tpu-specs: type: "v4-8" - python: "3.10" + python: "3.11" - tpu-specs: type: "v4-8" - python: "3.11" + python: "3.12" + - tpu-specs: + type: "v4-8" + python: "3.13-nogil" + # Run Python versions in between min and max for v6e-8 + - tpu-specs: + type: "v6e-8" + python: "3.13" + - tpu-specs: + type: "v6e-8" + python: "3.13-nogil" # Run min and max Python versions for v5e-8 - tpu-specs: type: "v5e-8" python: "3.11" + - tpu-specs: + type: "v5e-8" + python: "3.12" + name: "Pytest TPU (JAX artifacts version = ${{ startsWith(github.ref_name, 'release/') && 'latest release' || 'nightly' }})" with: runner: ${{ matrix.tpu-specs.runner }} @@ -98,4 +137,90 @@ jobs: python: ${{ matrix.python }} run-full-tpu-test-suite: "1" libtpu-version-type: ${{ matrix.libtpu-version-type }} - gcs_download_uri: ${{inputs.gcs_download_uri}} \ No newline at end of file + download-jax-only-from-gcs: ${{inputs.download-jax-only-from-gcs}} + gcs_download_uri: ${{inputs.gcs_download_uri}} + halt-for-connection: ${{inputs.halt-for-connection}} + + verify-release-wheels-install: + if: ${{ startsWith(github.ref_name, 'release/')}} + defaults: + run: + # Set the shell to bash as GitHub actions runs with /bin/sh by default + shell: bash + runs-on: linux-x86-n2-16 + strategy: + fail-fast: false # don't cancel all jobs on failure + matrix: + python: ["3.11", "3.13", "3.13-nogil"] + container: "us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build:latest" + + # Verifies that JAX's release wheels can be installed + name: "Verify release wheels install (Python ${{ matrix.python }})" + + env: + PYTHON: "python${{ matrix.python }}" + + steps: + - name: Download release wheels from GCS + run: | + mkdir -p $(pwd)/dist + final_gcs_download_uri=${{ inputs.gcs_download_uri }} + + # Get the major and minor version of Python. 
+ # E.g if python=3.11, then python_major_minor=311 + # E.g if python=3.13-nogil, then python_major_minor=313t + python_major_minor=${{ matrix.python }} + python_major_minor=$(echo "${python_major_minor//-nogil/t}" | tr -d '.') + python_major_minor="cp${python_major_minor%t}-cp${python_major_minor}-" + + gcloud storage cp -r "${final_gcs_download_uri}"/jax*py3*none*any.whl $(pwd)/dist/ + + jax_wheel=$(ls dist/jax*py3*none*any.whl 2>/dev/null) + echo "JAX_WHEEL=$jax_wheel" >> $GITHUB_ENV + + if [[ "${{ inputs.download-jax-only-from-gcs }}" != "1" ]]; then + gcloud storage cp -r "${final_gcs_download_uri}/jaxlib*${python_major_minor}*linux*x86_64*.whl" $(pwd)/dist/ + gcloud storage cp -r "${final_gcs_download_uri}/jax*cuda*plugin*${python_major_minor}*linux*x86_64*.whl" $(pwd)/dist/ + gcloud storage cp -r "${final_gcs_download_uri}/jax*cuda*pjrt*linux*x86_64*.whl" $(pwd)/dist/ + + jaxlib_wheel=$(ls dist/jaxlib*${python_major_minor}*linux*x86_64*.whl 2>/dev/null) + jax_cuda_plugin_wheel=$(ls dist/jax*cuda*plugin*${python_major_minor}*linux*x86_64*.whl 2>/dev/null) + jax_cuda_pjrt_wheel=$(ls dist/jax*cuda*pjrt*linux*x86_64*.whl 2>/dev/null) + + echo "JAXLIB_WHEEL=$jaxlib_wheel" >> $GITHUB_ENV + echo "JAX_CUDA_PLUGIN_WHEEL=$jax_cuda_plugin_wheel" >> $GITHUB_ENV + echo "JAX_CUDA_PJRT_WHEEL=$jax_cuda_pjrt_wheel" >> $GITHUB_ENV + fi + - name: Verify JAX CPU packages can be installed + run: | + $PYTHON -m uv venv ~/test_cpu && source ~/test_cpu/bin/activate + if [[ "${{ inputs.download-jax-only-from-gcs }}" == "1" ]]; then + uv pip install $JAX_WHEEL + else + uv pip install $JAX_WHEEL $JAXLIB_WHEEL + fi + - name: Verify JAX TPU packages can be installed + run: | + $PYTHON -m uv venv ~/test_tpu && source ~/test_tpu/bin/activate + + if [[ "${{ inputs.download-jax-only-from-gcs }}" == "1" ]]; then + uv pip install $JAX_WHEEL[tpu] + else + uv pip install $JAX_WHEEL[tpu] $JAXLIB_WHEEL + fi + - name: Verify JAX CUDA packages can be installed (Nvidia Pip Packages) + run: | + $PYTHON -m uv venv ~/test_cuda_pip && source ~/test_cuda_pip/bin/activate + if [[ "${{ inputs.download-jax-only-from-gcs }}" == "1" ]]; then + uv pip install $JAX_WHEEL[cuda] + else + uv pip install $JAX_WHEEL[cuda] $JAXLIB_WHEEL $JAX_CUDA_PJRT_WHEEL $JAX_CUDA_PLUGIN_WHEEL[with-cuda] + fi + - name: Verify JAX CUDA packages can be installed (CUDA local) + run: | + $PYTHON -m uv venv ~/test_cuda_local && source ~/test_cuda_local/bin/activate + if [[ "${{ inputs.download-jax-only-from-gcs }}" == "1" ]]; then + uv pip install $JAX_WHEEL[cuda12-local] + else + uv pip install $JAX_WHEEL $JAXLIB_WHEEL $JAX_CUDA_PJRT_WHEEL $JAX_CUDA_PLUGIN_WHEEL + fi \ No newline at end of file diff --git a/.github/workflows/wheel_win_x64.yml b/.github/workflows/wheel_win_x64.yml deleted file mode 100644 index 444bc83f2889..000000000000 --- a/.github/workflows/wheel_win_x64.yml +++ /dev/null @@ -1,64 +0,0 @@ -name: Wheel build - Windows CPU x86_64 -on: - workflow_dispatch: # allows triggering the workflow run manually - -concurrency: - group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} - cancel-in-progress: true - -env: - DISTUTILS_USE_SDK: 1 - MSSdk: 1 - -jobs: - win-wheels: - strategy: - fail-fast: false # Don't stop all wheel builds if one has a test failure. 
- matrix: - os: [windows-2019-32core] - arch: [AMD64] - pyver: ['3.10', '3.11', '3.12', '3.13'] - name: ${{ matrix.os }} ${{ matrix.pyver }} jaxlib wheel build - runs-on: ${{ matrix.os }} - - steps: - - name: Install LLVM/Clang - run: choco install llvm --version=18.1.4 --yes --no-progress --allow-downgrade - - - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - - - uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0 - with: - python-version: ${{ matrix.pyver }} - cache: 'pip' - - - name: Build wheels - env: - BAZEL_VC: "C:\\Program Files (x86)\\Microsoft Visual Studio\\2019\\Enterprise\\VC" - JAXLIB_RELEASE: true - run: | - python -m pip install uv~=0.5.30 - python -m uv pip install -r build/test-requirements.txt \ - --upgrade numpy==2.0.0 scipy==1.13.1 - "C:\\msys64\\;C:\\msys64\\usr\\bin\\;" >> $env:GITHUB_PATH - python.exe build\build.py build --wheels=jaxlib ` - --bazel_options=--color=yes ` - --bazel_options=--config=win_clang ` - --verbose - - - uses: actions/upload-artifact@6f51ac03b9356f520e9adb1b1b7802705f340c2b # v4.5.0 - with: - name: wheels-${{ matrix.os }}-${{ matrix.pyver }} - path: ${{ github.workspace }}\dist\*.whl - retention-days: 5 - - - name: Run tests - env: - JAX_ENABLE_CHECKS: true - JAX_SKIP_SLOW_TESTS: true - PY_COLORS: 1 - run: | - python -m uv pip install --find-links ${{ github.workspace }}\dist jaxlib \ - -e ${{ github.workspace }} - echo "JAX_ENABLE_CHECKS=$JAX_ENABLE_CHECKS" - pytest -n auto --tb=short tests examples diff --git a/.github/workflows/windows_ci.yml b/.github/workflows/windows_ci.yml deleted file mode 100644 index fc2b63396f56..000000000000 --- a/.github/workflows/windows_ci.yml +++ /dev/null @@ -1,73 +0,0 @@ -name: CI - Windows CPU -on: - schedule: - - cron: "0 12 * * *" # Daily at 12:00 UTC - workflow_dispatch: # allows triggering the workflow run manually - pull_request: - types: [ labeled ] # allow force-windows-run label - -concurrency: - group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} - cancel-in-progress: true - -env: - DISTUTILS_USE_SDK: 1 - MSSdk: 1 - -jobs: - win-wheels: - if: ${{ (github.event.action != 'labeled') || (github.event.label.name == 'windows:force-run')}} - strategy: - fail-fast: true - matrix: - os: [windows-2019-32core] - arch: [AMD64] - pyver: ['3.10'] - name: Windows CI build - runs-on: ${{ matrix.os }} - - steps: - - - name: Install LLVM/Clang - run: choco install llvm --version=18.1.4 --yes --no-progress --allow-downgrade - - - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - with: - path: jax - - - uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0 - with: - python-version: ${{ matrix.pyver }} - cache: 'pip' - - - name: Build wheels - env: - BAZEL_VC: "C:\\Program Files (x86)\\Microsoft Visual Studio\\2019\\Enterprise\\VC" - JAXLIB_NIGHTLY: true # Tag the wheels as dev versions - run: | - cd jax - python -m pip install uv~=0.5.30 - python -m uv pip install -r build/test-requirements.txt --upgrade numpy==2.0.0 scipy==1.13.1 - "C:\\msys64\\;C:\\msys64\\usr\\bin\\;" >> $env:GITHUB_PATH - python.exe build\build.py build --wheels=jaxlib ` - --bazel_options=--color=yes ` - --bazel_options=--config=win_clang ` - --verbose - - - uses: actions/upload-artifact@6f51ac03b9356f520e9adb1b1b7802705f340c2b # v4.5.0 - with: - name: wheels - path: ${{ github.workspace }}\jax\dist\*.whl - retention-days: 5 - - - name: Run tests - env: - JAX_ENABLE_CHECKS: true - JAX_SKIP_SLOW_TESTS: true - PY_COLORS: 1 - run: | - cd 
jax - python -m uv pip install --pre --find-links ${{ github.workspace }}\jax\dist jaxlib ` - -e ${{ github.workspace }}\jax - echo "JAX_ENABLE_CHECKS=$JAX_ENABLE_CHECKS" - pytest -n auto --tb=short tests examples diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 27ccc6d831f3..2312c88579d6 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -9,12 +9,17 @@ repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: 2c9f875913ee60ca25ce70243dc24d5b6415598c # frozen: v4.6.0 + rev: cef0300fd0fc4d2a87a85fa2093c6b283ea36f4b # frozen: v5.0.0 hooks: - id: check-ast - id: check-merge-conflict - id: check-toml - id: check-yaml + exclude: | + (?x)^( + examples/k8s/svc-acct\.yaml | + ci/k8s/indexed-job\.yaml + )$ - id: end-of-file-fixer # only include python files files: \.py$ @@ -26,17 +31,17 @@ repos: files: \.py$ - repo: https://github.com/astral-sh/ruff-pre-commit - rev: 8983acb92ee4b01924893632cf90af926fa608f0 # frozen: v0.7.0 + rev: 24e02b24b8ab2b7c76225602d13fa60e12d114e6 # frozen: v0.11.9 hooks: - id: ruff - repo: https://github.com/pre-commit/mirrors-mypy - rev: 'bbc3dc1f890007061f18f17e2334f216ea9e5df7' # frozen: v1.14.1 + rev: '7010b10a09f65cd60a23c207349b539aa36dbec1' # frozen: v1.16.0 hooks: - id: mypy files: (jax/|tests/typing_test\.py) - exclude: jax/_src/basearray.py|jax/numpy/__init__.py # Use pyi instead - additional_dependencies: [types-requests==2.31.0, jaxlib, numpy>=2.2.0] + exclude: jax/_src/basearray.py|jax/numpy/__init__.py|jaxlib/_jax/.* # Use pyi instead + additional_dependencies: [types-requests==2.31.0, numpy>=2.2.0] args: [--config=pyproject.toml] - repo: https://github.com/mwouts/jupytext diff --git a/.readthedocs.yml b/.readthedocs.yml index 6f807aa82377..0ac20301cee2 100644 --- a/.readthedocs.yml +++ b/.readthedocs.yml @@ -6,9 +6,23 @@ version: 2 build: - os: "ubuntu-22.04" + os: "ubuntu-24.04" tools: - python: "3.10" + python: "3.12" + jobs: + post_checkout: + # Skip building PRs unless tagged with the "documentation" label. + - | + [ "${READTHEDOCS_VERSION_TYPE}" != "external" ] && echo "Building latest" && exit 0 + (curl -sL https://api.github.com/repos/jax-ml/jax/issues/${READTHEDOCS_VERSION}/labels | grep -q "https://api.github.com/repos/jax-ml/jax/labels/documentation") && echo "Building PR with label" || exit 183 + create_environment: + - asdf plugin add uv + - asdf install uv latest + - asdf global uv latest + - uv venv $READTHEDOCS_VIRTUALENV_PATH + - UV_PROJECT_ENVIRONMENT=$READTHEDOCS_VIRTUALENV_PATH uv pip install -r docs/requirements.txt + install: + - "true" # skip # Build documentation in the docs/ directory with Sphinx sphinx: @@ -18,8 +32,3 @@ sphinx: # Optionally build your docs in additional formats such as PDF and ePub formats: - htmlzip - -# Optionally set the version of Python and requirements required to build your docs -python: - install: - - requirements: docs/requirements.txt diff --git a/BUILD.bazel b/BUILD.bazel index 33cbefd29f0b..d51b9f8c9cef 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -12,25 +12,28 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-load("@xla//third_party/py:python_wheel.bzl", "collect_data_files", "transitive_py_deps") +load( + "@xla//third_party/py:py_import.bzl", + "py_import", +) load( "//jaxlib:jax.bzl", + "jax_source_package", "jax_wheel", + "pytype_test", + "wheel_sources", ) -collect_data_files( - name = "transitive_py_data", - deps = ["//jax"], -) - -transitive_py_deps( - name = "transitive_py_deps", - deps = [ +wheel_sources( + name = "jax_sources", + data_srcs = ["//jax"], + py_srcs = [ "//jax", "//jax:compilation_cache", "//jax:experimental", "//jax:experimental_colocated_python", "//jax:experimental_sparse", + "//jax:experimental_buffer_callback", "//jax:lax_reference", "//jax:pallas_experimental_gpu_ops", "//jax:pallas_gpu_ops", @@ -39,11 +42,20 @@ transitive_py_deps( "//jax:pallas_triton", "//jax:source_mapper", "//jax:sparse_test_util", + "//jax:test_multiprocess", "//jax:test_util", + "//jax:internal_export_back_compat_test_util", + "//jax:internal_export_back_compat_test_data", + "//jax:internal_test_harnesses", + "//jax:internal_test_util", "//jax/_src/lib", + "//jax/_src/pallas/fuser", "//jax/_src/pallas/mosaic_gpu", "//jax/experimental/array_serialization:serialization", + "//jax/experimental/array_serialization:pytree_serialization", "//jax/experimental/jax2tf", + "//jax/experimental/mosaic/gpu/examples:flash_attention", + "//jax/experimental/mosaic/gpu/examples:matmul", "//jax/extend", "//jax/extend:ifrt_programs", "//jax/extend/mlir", @@ -52,6 +64,14 @@ transitive_py_deps( "//jax/tools:jax_to_ir", "//jax/tools:pgo_nsys_converter", ], + static_srcs = [ + "//jax:py.typed", + "AUTHORS", + "LICENSE", + "README.md", + "pyproject.toml", + "setup.py", + ], ) py_binary( @@ -59,26 +79,78 @@ py_binary( srcs = ["build_wheel.py"], deps = [ "//jaxlib/tools:build_utils", - "@pypi_build//:pkg", - "@pypi_setuptools//:pkg", - "@pypi_wheel//:pkg", + "@pypi//build", + "@pypi//setuptools", + "@pypi//wheel", ], ) jax_wheel( name = "jax_wheel", - build_wheel_only = False, platform_independent = True, - source_files = [ - ":transitive_py_data", - ":transitive_py_deps", - "//jax:py.typed", - "AUTHORS", - "LICENSE", - "README.md", - "pyproject.toml", - "setup.py", - ], + source_files = [":jax_sources"], wheel_binary = ":build_wheel", wheel_name = "jax", ) + +jax_wheel( + name = "jax_wheel_editable", + editable = True, + platform_independent = True, + source_files = [":jax_sources"], + wheel_binary = ":build_wheel", + wheel_name = "jax", +) + +jax_source_package( + name = "jax_source_package", + source_files = [":jax_sources"], + source_package_binary = ":build_wheel", + source_package_name = "jax", +) + +genrule( + name = "wheel_additives", + srcs = [ + "//jax:internal_export_back_compat_test_util", + "//jax:internal_test_harnesses", + "//jax:internal_test_util", + "//jax:internal_export_back_compat_test_data", + "//jax/experimental/mosaic/gpu/examples:flash_attention.py", + "//jax/experimental/mosaic/gpu/examples:matmul.py", + "//jax:test_multiprocess", + ], + outs = ["wheel_additives.zip"], + cmd = "$(location @bazel_tools//tools/zip:zipper) c $@ $(SRCS)", + tools = ["@bazel_tools//tools/zip:zipper"], +) + +py_import( + name = "jax_py_import", + wheel = ":jax_wheel", + wheel_deps = [":wheel_additives"], +) + +# This target is used to add more sources to the jax wheel. +# This is needed for the tests that depend on jax and use modules that are not part of +# the jax wheel, but share the same package paths as the modules in the jax wheel. 
+py_import( + name = "jax_wheel_with_internal_test_util", + wheel = "@pypi_jax//:whl", + wheel_deps = [":wheel_additives"], +) + +pytype_test( + name = "jax_wheel_size_test", + srcs = ["//jaxlib/tools:wheel_size_test.py"], + args = [ + "--wheel-path=$(location :jax_wheel)", + "--max-size-mib=5", + ], + data = [":jax_wheel"], + main = "wheel_size_test.py", + tags = [ + "manual", + "notap", + ], +) diff --git a/CHANGELOG.md b/CHANGELOG.md index c30877ecae14..32730a2355cd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,6 @@ # Change log -Best viewed [here](https://jax.readthedocs.io/en/latest/changelog.html). +Best viewed [here](https://docs.jax.dev/en/latest/changelog.html). For the changes specific to the experimental Pallas APIs, see {ref}`pallas-changelog`. @@ -16,6 +16,124 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. ## Unreleased +* Breaking changes: + * {func}`jax.jit` now requires `fun` to be passed by position, and additional + arguments to be passed by keyword. Doing otherwise will result in an error + starting in v0.7.x. This raised a DeprecationWarning in v0.6.x. + * The minimum Python version is now 3.11. 3.11 will remain the minimum + supported version until July 2026. + * `Layout`, `.layout`, `.input_layouts` and `.output_layouts` have been + renamed to `Format`, `.format`, `.input_formats` and `.output_formats` in JAX + +## JAX 0.6.2 (June 17, 2025) + +* New features: + * Added {func}`jax.tree.broadcast` which implements a pytree prefix broadcasting helper. + +* Changes + * The minimum NumPy version is 1.26 and the minimum SciPy version is 1.12. + +## JAX 0.6.1 (May 21, 2025) + +* New features: + * Added {func}`jax.lax.axis_size` which returns the size of the mapped axis + given its name. + +* Changes + * Additional checking for the versions of CUDA package dependencies was + re-enabled, having been accidentally disabled in a previous release. + * JAX nightly packages are now published to artifact registry. To install + these packages, see the [JAX installation guide](https://docs.jax.dev/en/latest/installation.html#jax-nightly-installation). + * `jax.sharding.PartitionSpec` no longer inherits from a tuple. + * `jax.ShapeDtypeStruct` is immutable now. Please use `.update` method to + update your `ShapeDtypeStruct` instead of doing in-place updates. + +* Deprecations + * `jax.custom_derivatives.custom_jvp_call_jaxpr_p` is deprecated, and will be + removed in JAX v0.7.0. + +## JAX 0.6.0 (April 16, 2025) + +* Breaking changes + + * {func}`jax.numpy.array` no longer accepts `None`. This behavior was + deprecated since November 2023 and is now removed. + * Removed the `config.jax_data_dependent_tracing_fallback` config option, + which was added temporarily in v0.4.36 to allow users to opt out of the + new "stackless" tracing machinery. + * Removed the `config.jax_eager_pmap` config option. + * Disallow the calling of `lower` and `trace` AOT APIs on the result + of `jax.jit` if there have been subsequent wrappers applied. + Previously this worked, but silently ignored the wrappers. + The workaround is to apply `jax.jit` last among the wrappers, + and similarly for `jax.pmap`. + See {jax-issue}`#27873`. + * The `cuda12_pip` extra for `jax` has been removed; use `pip install jax[cuda12]` + instead. + +* Changes + * The minimum CuDNN version is v9.8. + * JAX is now built using CUDA 12.8. All versions of CUDA 12.1 or newer remain + supported. + * JAX package extras are now updated to use dash instead of underscore to + align with PEP 685. 
For instance, if you were previously using `pip install jax[cuda12_local]` + to install JAX, run `pip install jax[cuda12-local]` instead. + * {func}`jax.jit` now requires `fun` to be passed by position, and additional + arguments to be passed by keyword. Doing otherwise will result in a + DeprecationWarning in v0.6.X, and an error in starting in v0.7.X. + +* Deprecations + + * {func}`jax.tree_util.build_tree` is deprecated. Use {func}`jax.tree.unflatten` + instead. + * Implemented host callback handlers for CPU and GPU devices using XLA's FFI + and removed existing CPU/GPU handlers using XLA's custom call. + * All APIs in `jax.lib.xla_extension` are now deprecated. + * `jax.interpreters.mlir.hlo` and `jax.interpreters.mlir.func_dialect`, + which were accidental exports, have been removed. If needed, they are + available from `jax.extend.mlir`. + * `jax.interpreters.mlir.custom_call` is deprecated. The APIs provided by + {mod}`jax.ffi` should be used instead. + * The deprecated use of {func}`jax.ffi.ffi_call` with inline arguments is no + longer supported. {func}`~jax.ffi.ffi_call` now unconditionally returns a + callable. + * The following exports in `jax.lib.xla_client` are deprecated: + `get_topology_for_devices`, `heap_profile`, `mlir_api_version`, `Client`, + `CompileOptions`, `DeviceAssignment`, `Frame`, `HloSharding`, `OpSharding`, + `Traceback`. + * The following internal APIs in `jax.util` are deprecated: + `HashableFunction`, `as_hashable_function`, `cache`, `safe_map`, `safe_zip`, + `split_dict`, `split_list`, `split_list_checked`, `split_merge`, `subvals`, + `toposort`, `unzip2`, `wrap_name`, and `wraps`. + * `jax.dlpack.to_dlpack` has been deprecated. You can usually pass a JAX + `Array` directly to the `from_dlpack` function of another framework. If you + need the functionality of `to_dlpack`, use the `__dlpack__` attribute of an + array. + * `jax.lax.infeed`, `jax.lax.infeed_p`, `jax.lax.outfeed`, and + `jax.lax.outfeed_p` are deprecated and will be removed in JAX v0.7.0. + * Several previously-deprecated APIs have been removed, including: + * From `jax.lib.xla_client`: `ArrayImpl`, `FftType`, `PaddingType`, + `PrimitiveType`, `XlaBuilder`, `dtype_to_etype`, + `ops`, `register_custom_call_target`, `shape_from_pyval`, `Shape`, + `XlaComputation`. + * From `jax.lib.xla_extension`: `ArrayImpl`, `XlaRuntimeError`. + * From `jax`: `jax.treedef_is_leaf`, `jax.tree_flatten`, `jax.tree_map`, + `jax.tree_leaves`, `jax.tree_structure`, `jax.tree_transpose`, and + `jax.tree_unflatten`. Replacements can be found in {mod}`jax.tree` or + {mod}`jax.tree_util`. + * From `jax.core`: `AxisSize`, `ClosedJaxpr`, `EvalTrace`, `InDBIdx`, `InputType`, + `Jaxpr`, `JaxprEqn`, `Literal`, `MapPrimitive`, `OpaqueTraceState`, `OutDBIdx`, + `Primitive`, `Token`, `TRACER_LEAK_DEBUGGER_WARNING`, `Var`, `concrete_aval`, + `dedup_referents`, `escaped_tracer_error`, `extend_axis_env_nd`, `full_lower`, `get_referent`, `jaxpr_as_fun`, `join_effects`, `lattice_join`, + `leaked_tracer_error`, `maybe_find_leaked_tracers`, `raise_to_shaped`, + `raise_to_shaped_mappings`, `reset_trace_state`, `str_eqn_compact`, + `substitute_vars_in_output_ty`, `typecompat`, and `used_axis_names_jaxpr`. Most + have no public replacement, though a few are available at {mod}`jax.extend.core`. + * The `vectorized` argument to {func}`~jax.pure_callback` and + {func}`~jax.ffi.ffi_call`. Use the `vmap_method` parameter instead. 
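+    For example, a minimal migration sketch (the callback body here is
+    illustrative; see the {func}`jax.pure_callback` documentation for the full
+    set of `vmap_method` values):
+    ```python
+    import numpy as np
+    import jax
+    import jax.numpy as jnp
+
+    def host_sin(x):
+      return np.sin(x)  # runs on the host via NumPy
+
+    x = jnp.arange(4.0)
+    # Previously: jax.pure_callback(host_sin, out_spec, x, vectorized=True)
+    out = jax.pure_callback(
+        host_sin,
+        jax.ShapeDtypeStruct(x.shape, x.dtype),
+        x,
+        vmap_method="legacy_vectorized",  # closest analogue of vectorized=True
+    )
+    ```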
+ +## jax 0.5.3 (Mar 19, 2025) + * New Features * Added a `allow_negative_indices` option to {func}`jax.lax.dynamic_slice`, @@ -34,6 +152,30 @@ Patch release of 0.5.1 ## jax 0.5.1 (Feb 24, 2025) +* Breaking changes + * The jit tracing cache now keys on input NamedShardings. Previously, the + tracing cache did not include sharding information at all + (although subsequent jit caches did like lowering and compilation caches), + so two equivalent shardings of different types would not retrace, + but now they do. For example: + ```python + @jax.jit + def f(x): + return x + + # inp1.sharding is of type SingleDeviceSharding + inp1 = jnp.arange(8) + f(inp1) + + mesh = jax.make_mesh((1,), ('x',)) + # inp2.sharding is of type NamedSharding + inp2 = jax.device_put(jnp.arange(8), NamedSharding(mesh, P('x'))) + f(inp2) # tracing cache miss + ``` + In the above example, calling `f(inp1)` and then `f(inp2)` will lead to a + tracing cache miss because the shardings have changed on the abstract values + while tracing. + * New Features * Added an experimental {func}`jax.experimental.custom_dce.custom_dce` decorator to support customizing the behavior of opaque functions under @@ -81,7 +223,7 @@ Patch release of 0.5.1 ## jax 0.5.0 (Jan 17, 2025) As of this release, JAX now uses -[effort-based versioning](https://jax.readthedocs.io/en/latest/jep/25516-effver.html). +[effort-based versioning](https://docs.jax.dev/en/latest/jep/25516-effver.html). Since this release makes a breaking change to PRNG key semantics that may require users to update their code, we are bumping the "meso" version of JAX to signify this. @@ -101,7 +243,7 @@ to signify this. developers at this point. So it is difficult for us to fix this kind of problem even if we wanted to. - We are open to readding support for Mac x86 if the community is willing + We are open to re-adding support for Mac x86 if the community is willing to help support that platform: in particular, we would need the JAX test suite to pass cleanly on Mac x86 before we could ship releases again. @@ -172,7 +314,7 @@ to signify this. * New Features * {func}`jax.export.export` can be used for device-polymorphic export with shardings constructed with {func}`jax.sharding.AbstractMesh`. - See the [jax.export documentation](https://jax.readthedocs.io/en/latest/export/export.html#device-polymorphic-export). + See the [jax.export documentation](https://docs.jax.dev/en/latest/export/export.html#device-polymorphic-export). * Added {func}`jax.lax.split`. This is a primitive version of {func}`jax.numpy.split`, added because it yields a more compact transpose during automatic differentiation. @@ -214,7 +356,7 @@ This is a patch release of jax 0.4.36. Only "jax" was released at this version. after being deprecated in JAX v0.4.31. Instead use `xb = jax.lib.xla_bridge`, `xc = jax.lib.xla_client`, and `xe = jax.lib.xla_extension`. * The deprecated module `jax.experimental.export` has been removed. It was replaced - by {mod}`jax.export` in JAX v0.4.30. See the [migration guide](https://jax.readthedocs.io/en/latest/export/export.html#migration-guide-from-jax-experimental-export) + by {mod}`jax.export` in JAX v0.4.30. See the [migration guide](https://docs.jax.dev/en/latest/export/export.html#migration-guide-from-jax-experimental-export) for information on migrating to the new API. * The `initial` argument to {func}`jax.nn.softmax` and {func}`jax.nn.log_softmax` has been removed, after being deprecated in v0.4.27. @@ -252,7 +394,7 @@ This is a patch release of jax 0.4.36. 
Only "jax" was released at this version. call that we guarantee export stability. This is because this custom call relies on Triton IR, which is not guaranteed to be stable. If you need to export code that uses this custom call, you can use the `disabled_checks` - parameter. See more details in the [documentation](https://jax.readthedocs.io/en/latest/export/export.html#compatibility-guarantees-for-custom-calls). + parameter. See more details in the [documentation](https://docs.jax.dev/en/latest/export/export.html#compatibility-guarantees-for-custom-calls). * New Features * {func}`jax.jit` got a new `compiler_options: dict[str, Any]` argument, for @@ -326,7 +468,7 @@ This is a patch release of jax 0.4.36. Only "jax" was released at this version. * `jax_pmap_no_rank_reduction` flag is set to `True` by default. * array[0] on a pmap result now introduces a reshape (use array[0:1] instead). - * The per-shard shape (accessable via jax_array.addressable_shards or + * The per-shard shape (accessible via jax_array.addressable_shards or jax_array.addressable_data(0)) now has a leading (1, ...). Update code that directly accesses shards accordingly. The rank of the per-shard-shape now matches that of the global shape which is the same behavior as jit. @@ -532,7 +674,7 @@ See the 0.4.33 release notes for more details. * Added an API for exporting and serializing JAX functions. This used to exist in `jax.experimental.export` (which is being deprecated), and will now live in `jax.export`. - See the [documentation](https://jax.readthedocs.io/en/latest/export/index.html). + See the [documentation](https://docs.jax.dev/en/latest/export/index.html). * Deprecations * Internal pretty-printing tools `jax.core.pp_*` are deprecated, and will be removed @@ -541,7 +683,7 @@ See the 0.4.33 release notes for more details. release. This previously was the case, but there was an inadvertent regression in the last several JAX releases. * `jax.experimental.export` is deprecated. Use {mod}`jax.export` instead. - See the [migration guide](https://jax.readthedocs.io/en/latest/export/export.html#migration-guide-from-jax-experimental-export). + See the [migration guide](https://docs.jax.dev/en/latest/export/export.html#migration-guide-from-jax-experimental-export). * Passing an array in place of a dtype is now deprecated in most cases; e.g. for arrays `x` and `y`, `x.astype(y)` will raise a warning. To silence it use `x.astype(y.dtype)`. * `jax.xla_computation` is deprecated and will be removed in a future release. @@ -753,7 +895,7 @@ See the 0.4.33 release notes for more details. deprecated. Use `jax.experimental.shard_map` or `jax.vmap` with the `spmd_axis_name` argument for expressing SPMD device-parallel computations. * The `jax.experimental.host_callback` module is deprecated. - Use instead the [new JAX external callbacks](https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html). + Use instead the [new JAX external callbacks](https://docs.jax.dev/en/latest/notebooks/external_callbacks.html). Added `JAX_HOST_CALLBACK_LEGACY` flag to assist in the transition to the new callbacks. See {jax-issue}`#20385` for a discussion. * Passing arguments to {func}`jax.numpy.array_equal` and {func}`jax.numpy.array_equiv` @@ -1225,9 +1367,9 @@ See the 0.4.33 release notes for more details. 
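For example, a minimal sketch of the external-callback replacement for the deprecated `host_callback`-style taps mentioned above (the host-side logging function is hypothetical):

```python
import jax
import jax.numpy as jnp

def log_on_host(x):  # hypothetical host-side logging function
  print("value seen on host:", x)

@jax.jit
def f(x):
  # jax.debug.callback is one of the external callbacks that replace
  # host_callback for side-effecting taps (id_tap/id_print-style usage).
  jax.debug.callback(log_on_host, x)
  return jnp.sin(x)

f(jnp.arange(3.0))
```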
* Deprecations * Python 3.8 support has been dropped as per - https://jax.readthedocs.io/en/latest/deprecation.html + https://docs.jax.dev/en/latest/deprecation.html * JAX now requires NumPy 1.22 or newer as per - https://jax.readthedocs.io/en/latest/deprecation.html + https://docs.jax.dev/en/latest/deprecation.html * Passing optional arguments to {func}`jax.numpy.ndarray.at` by position is no longer supported, after being deprecated in JAX version 0.4.7. For example, instead of `x.at[i].get(True)`, use `x.at[i].get(indices_are_sorted=True)` @@ -1272,7 +1414,7 @@ See the 0.4.33 release notes for more details. * Deprecations * Python 3.8 support has been dropped as per - https://jax.readthedocs.io/en/latest/deprecation.html + https://docs.jax.dev/en/latest/deprecation.html ## jax 0.4.13 (June 22, 2023) @@ -1382,7 +1524,7 @@ See the 0.4.33 release notes for more details. dict of string stat names with int values, e.g. `"bytes_in_use"`, or None if the platform doesn't support memory statistics. The exact stats returned may vary across platforms. Currently only implemented on Cloud TPU. - * Readded support for the Python buffer protocol (`memoryview`) on CPU + * Re-added support for the Python buffer protocol (`memoryview`) on CPU devices. ## jax 0.4.10 (May 11, 2023) @@ -1451,7 +1593,7 @@ See the 0.4.33 release notes for more details. ## jax 0.4.7 (March 27, 2023) * Changes - * As per https://jax.readthedocs.io/en/latest/jax_array_migration.html#jax-array-migration + * As per https://docs.jax.dev/en/latest/jax_array_migration.html#jax-array-migration `jax.config.jax_array` cannot be disabled anymore. * `jax.config.jax_jit_pjit_api_merge` cannot be disabled anymore. * {func}`jax.experimental.jax2tf.convert` now supports the `native_serialization` @@ -1535,7 +1677,7 @@ Changes: on top of each other. With the `jit`-`pjit` implementation merge, `jit` becomes an initial style primitive which means that we trace to jaxpr as early as possible. For more information see - [this section in autodidax](https://jax.readthedocs.io/en/latest/autodidax.html#on-the-fly-final-style-and-staged-initial-style-processing). + [this section in autodidax](https://docs.jax.dev/en/latest/autodidax.html#on-the-fly-final-style-and-staged-initial-style-processing). Moving to initial style should simplify JAX's internals and make development of features like dynamic shapes, etc easier. You can disable it only via the environment variable i.e. @@ -1620,9 +1762,9 @@ Changes: simplifies and unifies JAX internals, and allows us to unify `jit` and `pjit`. `jax.Array` has been enabled by default in JAX 0.4 and makes some breaking change to the `pjit` API. The [jax.Array migration - guide](https://jax.readthedocs.io/en/latest/jax_array_migration.html) can + guide](https://docs.jax.dev/en/latest/jax_array_migration.html) can help you migrate your codebase to `jax.Array`. You can also look at the - [Distributed arrays and automatic parallelization](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html) + [Distributed arrays and automatic parallelization](https://docs.jax.dev/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html) tutorial to understand the new concepts. * `PartitionSpec` and `Mesh` are now out of experimental. The new API endpoints are `jax.sharding.PartitionSpec` and `jax.sharding.Mesh`. 
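For example, a minimal sketch using the stable `jax.sharding` endpoints mentioned above (a 1-D mesh with an arbitrary axis name; assumes the array length is divisible by the number of available devices):

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Build a 1-D mesh over whatever devices are available.
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("x",))

# Shard an array along the mesh axis.
sharding = NamedSharding(mesh, P("x"))
x = jax.device_put(jnp.arange(8.0), sharding)
print(x.sharding)
```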
@@ -1651,7 +1793,7 @@ Changes: * The behavior of `XLA_PYTHON_CLIENT_MEM_FRACTION=.XX` has been changed to allocate XX% of the total GPU memory instead of the previous behavior of using currently available GPU memory to calculate preallocation. Please refer to - [GPU memory allocation](https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html) for + [GPU memory allocation](https://docs.jax.dev/en/latest/gpu_memory_allocation.html) for more details. * The deprecated method `.block_host_until_ready()` has been removed. Use `.block_until_ready()` instead. @@ -1765,7 +1907,7 @@ Changes: * Changes * Ahead-of-time lowering and compilation functionality (tracked in {jax-issue}`#7733`) is stable and public. See [the - overview](https://jax.readthedocs.io/en/latest/aot.html) and the API docs + overview](https://docs.jax.dev/en/latest/aot.html) and the API docs for {mod}`jax.stages`. * Introduced {class}`jax.Array`, intended to be used for both `isinstance` checks and type annotations for array types in JAX. Notice that this included some subtle @@ -1786,7 +1928,7 @@ Changes: * Breaking changes * {func}`jax.checkpoint`, also known as {func}`jax.remat`, no longer supports the `concrete` option, following the previous version's deprecation; see - [JEP 11830](https://jax.readthedocs.io/en/latest/jep/11830-new-remat-checkpoint.html). + [JEP 11830](https://docs.jax.dev/en/latest/jep/11830-new-remat-checkpoint.html). * Changes * Added {func}`jax.pure_callback` that enables calling back to pure Python functions from compiled functions (e.g. functions decorated with `jax.jit` or `jax.pmap`). * Deprecations: @@ -1798,7 +1940,7 @@ Changes: * [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.3.15...main). * Breaking changes * Support for NumPy 1.19 has been dropped, per the - [deprecation policy](https://jax.readthedocs.io/en/latest/deprecation.html). + [deprecation policy](https://docs.jax.dev/en/latest/deprecation.html). Please upgrade to NumPy 1.20 or newer. * Changes * Added {mod}`jax.debug` that includes utilities for runtime value debugging such at {func}`jax.debug.print` and {func}`jax.debug.breakpoint`. @@ -1816,7 +1958,7 @@ Changes: {mod}`jax.example_libraries.optimizers`. * {func}`jax.checkpoint`, also known as {func}`jax.remat`, has a new implementation switched on by default, meaning the old implementation is - deprecated; see [JEP 11830](https://jax.readthedocs.io/en/latest/jep/11830-new-remat-checkpoint.html). + deprecated; see [JEP 11830](https://docs.jax.dev/en/latest/jep/11830-new-remat-checkpoint.html). ## jax 0.3.15 (July 22, 2022) * [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.3.14...jax-v0.3.15). @@ -1948,7 +2090,7 @@ Changes: * {func}`jax.numpy.linalg.matrix_rank` on TPUs now accepts complex input. * {func}`jax.scipy.cluster.vq.vq` has been added. * `jax.experimental.maps.mesh` has been deleted. - Please use `jax.experimental.maps.Mesh`. Please see https://jax.readthedocs.io/en/latest/_autosummary/jax.experimental.maps.Mesh.html#jax.experimental.maps.Mesh + Please use `jax.experimental.maps.Mesh`. Please see https://docs.jax.dev/en/latest/_autosummary/jax.experimental.maps.Mesh.html#jax.experimental.maps.Mesh for more information. * {func}`jax.scipy.linalg.qr` now returns a length-1 tuple rather than the raw array when `mode='r'`, in order to match the behavior of `scipy.linalg.qr` ({jax-issue}`#10452`) @@ -2064,7 +2206,7 @@ Changes: * Changes: * The functions `jax.ops.index_update`, `jax.ops.index_add`, which were deprecated in 0.2.22, have been removed. 
Please use - [the `.at` property on JAX arrays](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html) + [the `.at` property on JAX arrays](https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html) instead, e.g., `x.at[idx].set(y)`. * Moved `jax.experimental.ann.approx_*_k` into `jax.lax`. These functions are optimized alternatives to `jax.lax.top_k`. @@ -2110,13 +2252,13 @@ Changes: commits](https://github.com/jax-ml/jax/compare/jax-v0.2.28...jax-v0.3.0). * Changes - * jax version has been bumped to 0.3.0. Please see the [design doc](https://jax.readthedocs.io/en/latest/design_notes/jax_versioning.html) + * jax version has been bumped to 0.3.0. Please see the [design doc](https://docs.jax.dev/en/latest/design_notes/jax_versioning.html) for the explanation. ## jaxlib 0.3.0 (Feb 10, 2022) * Changes * Bazel 5.0.0 is now required to build jaxlib. - * jaxlib version has been bumped to 0.3.0. Please see the [design doc](https://jax.readthedocs.io/en/latest/design_notes/jax_versioning.html) + * jaxlib version has been bumped to 0.3.0. Please see the [design doc](https://docs.jax.dev/en/latest/design_notes/jax_versioning.html) for the explanation. ## jax 0.2.28 (Feb 1, 2022) @@ -2138,7 +2280,7 @@ Changes: by default. * Breaking changes * Support for NumPy 1.18 has been dropped, per the - [deprecation policy](https://jax.readthedocs.io/en/latest/deprecation.html). + [deprecation policy](https://docs.jax.dev/en/latest/deprecation.html). Please upgrade to a supported NumPy version. * Bug fixes * Fixed a bug where apparently identical pytreedef objects constructed by different routes @@ -2150,7 +2292,7 @@ Changes: * Breaking changes: * Support for NumPy 1.18 has been dropped, per the - [deprecation policy](https://jax.readthedocs.io/en/latest/deprecation.html). + [deprecation policy](https://docs.jax.dev/en/latest/deprecation.html). Please upgrade to a supported NumPy version. * The host_callback primitives have been simplified to drop the special autodiff handling for hcb.id_tap and id_print. @@ -2277,7 +2419,7 @@ Changes: * Deprecations * The functions `jax.ops.index_update`, `jax.ops.index_add` etc. are deprecated and will be removed in a future JAX release. Please use - [the `.at` property on JAX arrays](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html) + [the `.at` property on JAX arrays](https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html) instead, e.g., `x.at[idx].set(y)`. For now, these functions produce a `DeprecationWarning`. * New features: @@ -2341,7 +2483,7 @@ Changes: commits](https://github.com/jax-ml/jax/compare/jax-v0.2.18...jax-v0.2.19). * Breaking changes: * Support for NumPy 1.17 has been dropped, per the - [deprecation policy](https://jax.readthedocs.io/en/latest/deprecation.html). + [deprecation policy](https://docs.jax.dev/en/latest/deprecation.html). Please upgrade to a supported NumPy version. * The `jit` decorator has been added around the implementation of a number of operators on JAX arrays. This speeds up dispatch times for common @@ -2362,10 +2504,10 @@ Changes: ## jaxlib 0.1.70 (Aug 9, 2021) * Breaking changes: * Support for Python 3.6 has been dropped, per the - [deprecation policy](https://jax.readthedocs.io/en/latest/deprecation.html). + [deprecation policy](https://docs.jax.dev/en/latest/deprecation.html). Please upgrade to a supported Python version. * Support for NumPy 1.17 has been dropped, per the - [deprecation policy](https://jax.readthedocs.io/en/latest/deprecation.html). 
+ [deprecation policy](https://docs.jax.dev/en/latest/deprecation.html). Please upgrade to a supported NumPy version. * The host_callback mechanism now uses one thread per local device for @@ -2379,7 +2521,7 @@ Changes: * Breaking changes: * Support for Python 3.6 has been dropped, per the - [deprecation policy](https://jax.readthedocs.io/en/latest/deprecation.html). + [deprecation policy](https://docs.jax.dev/en/latest/deprecation.html). Please upgrade to a supported Python version. * The minimum jaxlib version is now 0.1.69. * The `backend` argument to {py:func}`jax.dlpack.from_dlpack` has been @@ -2428,7 +2570,7 @@ Changes: * Breaking changes: * Support for NumPy 1.16 has been dropped, per the - [deprecation policy](https://jax.readthedocs.io/en/latest/deprecation.html). + [deprecation policy](https://docs.jax.dev/en/latest/deprecation.html). * Bug fixes: * Fixed bug that prevented round-tripping from JAX to TF and back: @@ -2968,7 +3110,7 @@ Changes: * Support for reduction over subsets of a pmapped axis using `axis_index_groups` {jax-issue}`#2382`. * Experimental support for printing and calling host-side Python function from - compiled code. See [id_print and id_tap](https://jax.readthedocs.io/en/latest/jax.experimental.host_callback.html) + compiled code. See [id_print and id_tap](https://docs.jax.dev/en/latest/jax.experimental.host_callback.html) ({jax-issue}`#3006`). * Notable changes: * The visibility of names exported from {mod}`jax.numpy` has been @@ -3040,7 +3182,7 @@ Changes: ## jax 0.1.63 (April 12, 2020) * [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.1.62...jax-v0.1.63). -* Added `jax.custom_jvp` and `jax.custom_vjp` from {jax-issue}`#2026`, see the [tutorial notebook](https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html). Deprecated `jax.custom_transforms` and removed it from the docs (though it still works). +* Added `jax.custom_jvp` and `jax.custom_vjp` from {jax-issue}`#2026`, see the [tutorial notebook](https://docs.jax.dev/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html). Deprecated `jax.custom_transforms` and removed it from the docs (though it still works). * Add `scipy.sparse.linalg.cg` {jax-issue}`#2566`. * Changed how Tracers are printed to show more useful information for debugging {jax-issue}`#2591`. * Made `jax.numpy.isclose` handle `nan` and `inf` correctly {jax-issue}`#2501`. 
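For example, a minimal `jax.custom_jvp` sketch in the spirit of the tutorial notebook linked above; defining the JVP by hand yields the numerically stable `sigmoid(x)` derivative instead of the naive `exp(x) / (1 + exp(x))`, which overflows for large `x`:

```python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def log1pexp(x):
  return jnp.log1p(jnp.exp(x))

@log1pexp.defjvp
def log1pexp_jvp(primals, tangents):
  x, = primals
  x_dot, = tangents
  y = log1pexp(x)
  return y, x_dot * jax.nn.sigmoid(x)

print(jax.grad(log1pexp)(10.0))  # ~0.9999546
```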
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 314d4387a044..046d3df3195c 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -1,4 +1,4 @@ # Contributing to JAX For information on how to contribute to JAX, see -[Contributing to JAX](https://jax.readthedocs.io/en/latest/contributing.html) +[Contributing to JAX](https://docs.jax.dev/en/latest/contributing.html) diff --git a/README.md b/README.md index 0aca7cf58e6e..e6af1b344f24 100644 --- a/README.md +++ b/README.md @@ -7,12 +7,11 @@ [![Continuous integration](https://github.com/jax-ml/jax/actions/workflows/ci-build.yaml/badge.svg)](https://github.com/jax-ml/jax/actions/workflows/ci-build.yaml) [![PyPI version](https://img.shields.io/pypi/v/jax)](https://pypi.org/project/jax/) -[**Quickstart**](#quickstart-colab-in-the-cloud) -| [**Transformations**](#transformations) +[**Transformations**](#transformations) +| [**Scaling**](#scaling) | [**Install guide**](#installation) -| [**Neural net libraries**](#neural-network-libraries) -| [**Change logs**](https://jax.readthedocs.io/en/latest/changelog.html) -| [**Reference docs**](https://jax.readthedocs.io/en/latest/) +| [**Change logs**](https://docs.jax.dev/en/latest/changelog.html) +| [**Reference docs**](https://docs.jax.dev/en/latest/) ## What is JAX? @@ -20,42 +19,29 @@ JAX is a Python library for accelerator-oriented array computation and program transformation, designed for high-performance numerical computing and large-scale machine learning. -With its updated version of [Autograd](https://github.com/hips/autograd), JAX can automatically differentiate native Python and NumPy functions. It can differentiate through loops, branches, recursion, and closures, and it can take derivatives of derivatives of derivatives. It supports reverse-mode differentiation (a.k.a. backpropagation) -via [`grad`](#automatic-differentiation-with-grad) as well as forward-mode differentiation, +via [`jax.grad`](#automatic-differentiation-with-grad) as well as forward-mode differentiation, and the two can be composed arbitrarily to any order. -What’s new is that JAX uses [XLA](https://www.tensorflow.org/xla) -to compile and run your NumPy programs on GPUs and TPUs. Compilation happens -under the hood by default, with library calls getting just-in-time compiled and -executed. But JAX also lets you just-in-time compile your own Python functions -into XLA-optimized kernels using a one-function API, -[`jit`](#compilation-with-jit). Compilation and automatic differentiation can be -composed arbitrarily, so you can express sophisticated algorithms and get -maximal performance without leaving Python. You can even program multiple GPUs -or TPU cores at once using [`pmap`](#spmd-programming-with-pmap), and -differentiate through the whole thing. +JAX uses [XLA](https://www.tensorflow.org/xla) +to compile and scale your NumPy programs on TPUs, GPUs, and other hardware accelerators. +You can compile your own pure functions with [`jax.jit`](#compilation-with-jit). +Compilation and automatic differentiation can be composed arbitrarily. Dig a little deeper, and you'll see that JAX is really an extensible system for -[composable function transformations](#transformations). Both -[`grad`](#automatic-differentiation-with-grad) and [`jit`](#compilation-with-jit) -are instances of such transformations. 
Others are -[`vmap`](#auto-vectorization-with-vmap) for automatic vectorization and -[`pmap`](#spmd-programming-with-pmap) for single-program multiple-data (SPMD) -parallel programming of multiple accelerators, with more to come. +[composable function transformations](#transformations) at [scale](#scaling). This is a research project, not an official Google product. Expect -[sharp edges](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html). -Please help by trying it out, [reporting -bugs](https://github.com/jax-ml/jax/issues), and letting us know what you -think! +[sharp edges](https://docs.jax.dev/en/latest/notebooks/Common_Gotchas_in_JAX.html). +Please help by trying it out, [reporting bugs](https://github.com/jax-ml/jax/issues), +and letting us know what you think! ```python +import jax import jax.numpy as jnp -from jax import grad, jit, vmap def predict(params, inputs): for W, b in params: @@ -67,85 +53,50 @@ def loss(params, inputs, targets): preds = predict(params, inputs) return jnp.sum((preds - targets)**2) -grad_loss = jit(grad(loss)) # compiled gradient evaluation function -perex_grads = jit(vmap(grad_loss, in_axes=(None, 0, 0))) # fast per-example grads +grad_loss = jax.jit(jax.grad(loss)) # compiled gradient evaluation function +perex_grads = jax.jit(jax.vmap(grad_loss, in_axes=(None, 0, 0))) # fast per-example grads ``` ### Contents -* [Quickstart: Colab in the Cloud](#quickstart-colab-in-the-cloud) * [Transformations](#transformations) -* [Current gotchas](#current-gotchas) +* [Scaling](#scaling) +* [Current gotchas](#gotchas-and-sharp-bits) * [Installation](#installation) * [Neural net libraries](#neural-network-libraries) * [Citing JAX](#citing-jax) * [Reference documentation](#reference-documentation) -## Quickstart: Colab in the Cloud -Jump right in using a notebook in your browser, connected to a Google Cloud GPU. -Here are some starter notebooks: -- [The basics: NumPy on accelerators, `grad` for differentiation, `jit` for compilation, and `vmap` for vectorization](https://jax.readthedocs.io/en/latest/quickstart.html) -- [Training a Simple Neural Network, with TensorFlow Dataset Data Loading](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/neural_network_with_tfds_data.ipynb) - -**JAX now runs on Cloud TPUs.** To try out the preview, see the [Cloud TPU -Colabs](https://github.com/jax-ml/jax/tree/main/cloud_tpu_colabs). - -For a deeper dive into JAX: -- [The Autodiff Cookbook, Part 1: easy and powerful automatic differentiation in JAX](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html) -- [Common gotchas and sharp edges](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html) -- See the [full list of -notebooks](https://github.com/jax-ml/jax/tree/main/docs/notebooks). - ## Transformations At its core, JAX is an extensible system for transforming numerical functions. -Here are four transformations of primary interest: `grad`, `jit`, `vmap`, and -`pmap`. +Here are three: `jax.grad`, `jax.jit`, and `jax.vmap`. ### Automatic differentiation with `grad` -JAX has roughly the same API as [Autograd](https://github.com/hips/autograd). 
-The most popular function is -[`grad`](https://jax.readthedocs.io/en/latest/jax.html#jax.grad) -for reverse-mode gradients: +Use [`jax.grad`](https://docs.jax.dev/en/latest/jax.html#jax.grad) +to efficiently compute reverse-mode gradients: ```python -from jax import grad +import jax import jax.numpy as jnp -def tanh(x): # Define a function +def tanh(x): y = jnp.exp(-2.0 * x) return (1.0 - y) / (1.0 + y) -grad_tanh = grad(tanh) # Obtain its gradient function -print(grad_tanh(1.0)) # Evaluate it at x = 1.0 +grad_tanh = jax.grad(tanh) +print(grad_tanh(1.0)) # prints 0.4199743 ``` -You can differentiate to any order with `grad`. +You can differentiate to any order with `grad`: ```python -print(grad(grad(grad(tanh)))(1.0)) +print(jax.grad(jax.grad(jax.grad(tanh)))(1.0)) # prints 0.62162673 ``` -For more advanced autodiff, you can use -[`jax.vjp`](https://jax.readthedocs.io/en/latest/jax.html#jax.vjp) for -reverse-mode vector-Jacobian products and -[`jax.jvp`](https://jax.readthedocs.io/en/latest/jax.html#jax.jvp) for -forward-mode Jacobian-vector products. The two can be composed arbitrarily with -one another, and with other JAX transformations. Here's one way to compose those -to make a function that efficiently computes [full Hessian -matrices](https://jax.readthedocs.io/en/latest/_autosummary/jax.hessian.html#jax.hessian): - -```python -from jax import jit, jacfwd, jacrev - -def hessian(fun): - return jit(jacfwd(jacrev(fun))) -``` - -As with [Autograd](https://github.com/hips/autograd), you're free to use -differentiation with Python control structures: +You're free to use differentiation with Python control flow: ```python def abs_val(x): @@ -154,242 +105,134 @@ def abs_val(x): else: return -x -abs_val_grad = grad(abs_val) +abs_val_grad = jax.grad(abs_val) print(abs_val_grad(1.0)) # prints 1.0 print(abs_val_grad(-1.0)) # prints -1.0 (abs_val is re-evaluated) ``` -See the [reference docs on automatic -differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) -and the [JAX Autodiff -Cookbook](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html) +See the [JAX Autodiff +Cookbook](https://docs.jax.dev/en/latest/notebooks/autodiff_cookbook.html) +and the [reference docs on automatic +differentiation](https://docs.jax.dev/en/latest/jax.html#automatic-differentiation) for more. ### Compilation with `jit` -You can use XLA to compile your functions end-to-end with -[`jit`](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit), +Use XLA to compile your functions end-to-end with +[`jit`](https://docs.jax.dev/en/latest/jax.html#just-in-time-compilation-jit), used either as an `@jit` decorator or as a higher-order function. ```python +import jax import jax.numpy as jnp -from jax import jit def slow_f(x): # Element-wise ops see a large benefit from fusion return x * x + x * 2.0 x = jnp.ones((5000, 5000)) -fast_f = jit(slow_f) -%timeit -n10 -r3 fast_f(x) # ~ 4.5 ms / loop on Titan X -%timeit -n10 -r3 slow_f(x) # ~ 14.5 ms / loop (also on GPU via JAX) +fast_f = jax.jit(slow_f) +%timeit -n10 -r3 fast_f(x) +%timeit -n10 -r3 slow_f(x) ``` -You can mix `jit` and `grad` and any other JAX transformation however you like. 
- -Using `jit` puts constraints on the kind of Python control flow +Using `jax.jit` constrains the kind of Python control flow the function can use; see -the tutorial on [Control Flow and Logical Operators with JIT](https://jax.readthedocs.io/en/latest/control-flow.html) +the tutorial on [Control Flow and Logical Operators with JIT](https://docs.jax.dev/en/latest/control-flow.html) for more. ### Auto-vectorization with `vmap` -[`vmap`](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) is -the vectorizing map. -It has the familiar semantics of mapping a function along array axes, but -instead of keeping the loop on the outside, it pushes the loop down into a -function’s primitive operations for better performance. +[`vmap`](https://docs.jax.dev/en/latest/jax.html#vectorization-vmap) maps +a function along array axes. +But instead of just looping over function applications, it pushes the loop down +onto the function’s primitive operations, e.g. turning matrix-vector multiplies into +matrix-matrix multiplies for better performance. Using `vmap` can save you from having to carry around batch dimensions in your -code. For example, consider this simple *unbatched* neural network prediction -function: - -```python -def predict(params, input_vec): - assert input_vec.ndim == 1 - activations = input_vec - for W, b in params: - outputs = jnp.dot(W, activations) + b # `activations` on the right-hand side! - activations = jnp.tanh(outputs) # inputs to the next layer - return outputs # no activation on last layer -``` - -We often instead write `jnp.dot(activations, W)` to allow for a batch dimension on the -left side of `activations`, but we’ve written this particular prediction function to -apply only to single input vectors. If we wanted to apply this function to a -batch of inputs at once, semantically we could just write - -```python -from functools import partial -predictions = jnp.stack(list(map(partial(predict, params), input_batch))) -``` - -But pushing one example through the network at a time would be slow! It’s better -to vectorize the computation, so that at every layer we’re doing matrix-matrix -multiplication rather than matrix-vector multiplication. - -The `vmap` function does that transformation for us. That is, if we write +code: ```python -from jax import vmap -predictions = vmap(partial(predict, params))(input_batch) -# or, alternatively -predictions = vmap(predict, in_axes=(None, 0))(params, input_batch) -``` +import jax +import jax.numpy as jnp -then the `vmap` function will push the outer loop inside the function, and our -machine will end up executing matrix-matrix multiplications exactly as if we’d -done the batching by hand. +def l1_distance(x, y): + assert x.ndim == y.ndim == 1 # only works on 1D inputs + return jnp.sum(jnp.abs(x - y)) -It’s easy enough to manually batch a simple neural network without `vmap`, but -in other cases manual vectorization can be impractical or impossible. Take the -problem of efficiently computing per-example gradients: that is, for a fixed set -of parameters, we want to compute the gradient of our loss function evaluated -separately at each example in a batch. 
With `vmap`, it’s easy: +def pairwise_distances(dist1D, xs): + return jax.vmap(jax.vmap(dist1D, (0, None)), (None, 0))(xs, xs) -```python -per_example_gradients = vmap(partial(grad(loss), params))(inputs, targets) +xs = jax.random.normal(jax.random.key(0), (100, 3)) +dists = pairwise_distances(l1_distance, xs) +dists.shape # (100, 100) ``` -Of course, `vmap` can be arbitrarily composed with `jit`, `grad`, and any other -JAX transformation! We use `vmap` with both forward- and reverse-mode automatic -differentiation for fast Jacobian and Hessian matrix calculations in -`jax.jacfwd`, `jax.jacrev`, and `jax.hessian`. - -### SPMD programming with `pmap` - -For parallel programming of multiple accelerators, like multiple GPUs, use -[`pmap`](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap). -With `pmap` you write single-program multiple-data (SPMD) programs, including -fast parallel collective communication operations. Applying `pmap` will mean -that the function you write is compiled by XLA (similarly to `jit`), then -replicated and executed in parallel across devices. - -Here's an example on an 8-GPU machine: +By composing `jax.vmap` with `jax.grad` and `jax.jit`, we can get efficient +Jacobian matrices, or per-example gradients: ```python -from jax import random, pmap -import jax.numpy as jnp - -# Create 8 random 5000 x 6000 matrices, one per GPU -keys = random.split(random.key(0), 8) -mats = pmap(lambda key: random.normal(key, (5000, 6000)))(keys) - -# Run a local matmul on each device in parallel (no data transfer) -result = pmap(lambda x: jnp.dot(x, x.T))(mats) # result.shape is (8, 5000, 5000) - -# Compute the mean on each device in parallel and print the result -print(pmap(jnp.mean)(result)) -# prints [1.1566595 1.1805978 ... 1.2321935 1.2015157] +per_example_grads = jax.jit(jax.vmap(jax.grad(loss), in_axes=(None, 0, 0))) ``` -In addition to expressing pure maps, you can use fast [collective communication -operations](https://jax.readthedocs.io/en/latest/jax.lax.html#parallel-operators) -between devices: +## Scaling + +To scale your computations across thousands of devices, you can use any +composition of these: +* [**Compiler-based automatic parallelization**](https://docs.jax.dev/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html) +where you program as if using a single global machine, and the compiler chooses +how to shard data and partition computation (with some user-provided constraints); +* [**Explicit sharding and automatic partitioning**](https://docs.jax.dev/en/latest/notebooks/explicit-sharding.html) +where you still have a global view but data shardings are +explicit in JAX types, inspectable using `jax.typeof`; +* [**Manual per-device programming**](https://docs.jax.dev/en/latest/notebooks/shard_map.html) +where you have a per-device view of data +and computation, and can communicate with explicit collectives. + +| Mode | View? | Explicit sharding? | Explicit Collectives? 
| +|---|---|---|---| +| Auto | Global | ❌ | ❌ | +| Explicit | Global | ✅ | ❌ | +| Manual | Per-device | ✅ | ✅ | ```python -from functools import partial -from jax import lax +from jax.sharding import set_mesh, AxisType, PartitionSpec as P +mesh = jax.make_mesh((8,), ('data',), axis_types=(AxisType.Explicit,)) +set_mesh(mesh) -@partial(pmap, axis_name='i') -def normalize(x): - return x / lax.psum(x, 'i') +# parameters are sharded for FSDP: +for W, b in params: + print(f'{jax.typeof(W)}') # f32[512@data,512] + print(f'{jax.typeof(b)}') # f32[512] -print(normalize(jnp.arange(4.))) -# prints [0. 0.16666667 0.33333334 0.5 ] -``` +# shard data for batch parallelism: +inputs, targets = jax.device_put((inputs, targets), P('data')) -You can even [nest `pmap` functions](https://colab.research.google.com/github/jax-ml/jax/blob/main/cloud_tpu_colabs/Pmap_Cookbook.ipynb#scrollTo=MdRscR5MONuN) for more -sophisticated communication patterns. - -It all composes, so you're free to differentiate through parallel computations: - -```python -from jax import grad - -@pmap -def f(x): - y = jnp.sin(x) - @pmap - def g(z): - return jnp.cos(z) * jnp.tan(y.sum()) * jnp.tanh(x).sum() - return grad(lambda w: jnp.sum(g(w)))(x) - -print(f(x)) -# [[ 0. , -0.7170853 ], -# [-3.1085174 , -0.4824318 ], -# [10.366636 , 13.135289 ], -# [ 0.22163185, -0.52112055]] - -print(grad(lambda x: jnp.sum(f(x)))(x)) -# [[ -3.2369726, -1.6356447], -# [ 4.7572474, 11.606951 ], -# [-98.524414 , 42.76499 ], -# [ -1.6007166, -1.2568436]] +# evaluate gradients, automatically parallelized! +gradfun = jax.jit(jax.grad(loss)) +param_grads = gradfun(params, (inputs, targets)) ``` -When reverse-mode differentiating a `pmap` function (e.g. with `grad`), the -backward pass of the computation is parallelized just like the forward pass. +See the [tutorial](https://docs.jax.dev/en/latest/sharded-computation.html) and +[advanced guides](https://docs.jax.dev/en/latest/advanced_guide.html) for more. -See the [SPMD -Cookbook](https://colab.research.google.com/github/jax-ml/jax/blob/main/cloud_tpu_colabs/Pmap_Cookbook.ipynb) -and the [SPMD MNIST classifier from scratch -example](https://github.com/jax-ml/jax/blob/main/examples/spmd_mnist_classifier_fromscratch.py) -for more. +## Gotchas and sharp bits -## Current gotchas - -For a more thorough survey of current gotchas, with examples and explanations, -we highly recommend reading the [Gotchas -Notebook](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html). -Some standouts: - -1. JAX transformations only work on [pure functions](https://en.wikipedia.org/wiki/Pure_function), which don't have side-effects and respect [referential transparency](https://en.wikipedia.org/wiki/Referential_transparency) (i.e. object identity testing with `is` isn't preserved). If you use a JAX transformation on an impure Python function, you might see an error like `Exception: Can't lift Traced...` or `Exception: Different traces at same level`. -1. [In-place mutating updates of - arrays](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#in-place-updates), like `x[i] += y`, aren't supported, but [there are functional alternatives](https://jax.readthedocs.io/en/latest/jax.ops.html). Under a `jit`, those functional alternatives will reuse buffers in-place automatically. -1. [Random numbers are - different](https://jax.readthedocs.io/en/latest/random-numbers.html), but for [good reasons](https://github.com/jax-ml/jax/blob/main/docs/jep/263-prng.md). -1. 
If you're looking for [convolution - operators](https://jax.readthedocs.io/en/latest/notebooks/convolutions.html), - they're in the `jax.lax` package. -1. JAX enforces single-precision (32-bit, e.g. `float32`) values by default, and - [to enable - double-precision](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision) - (64-bit, e.g. `float64`) one needs to set the `jax_enable_x64` variable at - startup (or set the environment variable `JAX_ENABLE_X64=True`). - On TPU, JAX uses 32-bit values by default for everything _except_ internal - temporary variables in 'matmul-like' operations, such as `jax.numpy.dot` and `lax.conv`. - Those ops have a `precision` parameter which can be used to approximate 32-bit operations - via three bfloat16 passes, with a cost of possibly slower runtime. - Non-matmul operations on TPU lower to implementations that often emphasize speed over - accuracy, so in practice computations on TPU will be less precise than similar - computations on other backends. -1. Some of NumPy's dtype promotion semantics involving a mix of Python scalars - and NumPy types aren't preserved, namely `np.add(1, np.array([2], - np.float32)).dtype` is `float64` rather than `float32`. -1. Some transformations, like `jit`, [constrain how you can use Python control - flow](https://jax.readthedocs.io/en/latest/control-flow.html). - You'll always get loud errors if something goes wrong. You might have to use - [`jit`'s `static_argnums` - parameter](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit), - [structured control flow - primitives](https://jax.readthedocs.io/en/latest/jax.lax.html#control-flow-operators) - like - [`lax.scan`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html#jax.lax.scan), - or just use `jit` on smaller subfunctions. +See the [Gotchas +Notebook](https://docs.jax.dev/en/latest/notebooks/Common_Gotchas_in_JAX.html). ## Installation ### Supported platforms -| | Linux x86_64 | Linux aarch64 | Mac x86_64 | Mac aarch64 | Windows x86_64 | Windows WSL2 x86_64 | -|------------|--------------|---------------|--------------|--------------|----------------|---------------------| -| CPU | yes | yes | yes | yes | yes | yes | -| NVIDIA GPU | yes | yes | no | n/a | no | experimental | -| Google TPU | yes | n/a | n/a | n/a | n/a | n/a | -| AMD GPU | yes | no | experimental | n/a | no | no | -| Apple GPU | n/a | no | n/a | experimental | n/a | n/a | -| Intel GPU | experimental | n/a | n/a | n/a | no | no | +| | Linux x86_64 | Linux aarch64 | Mac aarch64 | Windows x86_64 | Windows WSL2 x86_64 | +|------------|--------------|---------------|--------------|----------------|---------------------| +| CPU | yes | yes | yes | yes | yes | +| NVIDIA GPU | yes | yes | n/a | no | experimental | +| Google TPU | yes | n/a | n/a | n/a | n/a | +| AMD GPU | yes | no | n/a | no | no | +| Apple GPU | n/a | no | experimental | n/a | n/a | +| Intel GPU | experimental | n/a | n/a | no | no | ### Instructions @@ -403,28 +246,11 @@ Some standouts: | Mac GPU | Follow [Apple's instructions](https://developer.apple.com/metal/jax/). | | Intel GPU | Follow [Intel's instructions](https://github.com/intel/intel-extension-for-openxla/blob/main/docs/acc_jax.md). | -See [the documentation](https://jax.readthedocs.io/en/latest/installation.html) +See [the documentation](https://docs.jax.dev/en/latest/installation.html) for information on alternative installation strategies. 
These include compiling from source, installing with Docker, using other versions of CUDA, a community-supported conda build, and answers to some frequently-asked questions. - - -## Neural network libraries - -Multiple Google research groups at Google DeepMind and Alphabet develop and share libraries -for training neural networks in JAX. If you want a fully featured library for neural network -training with examples and how-to guides, try -[Flax](https://github.com/google/flax) and its [documentation site](https://flax.readthedocs.io/en/latest/nnx/index.html). - -Check out the [JAX Ecosystem section](https://jax.readthedocs.io/en/latest/#ecosystem) -on the JAX documentation site for a list of JAX-based network libraries, which includes -[Optax](https://github.com/deepmind/optax) for gradient processing and -optimization, [chex](https://github.com/deepmind/chex) for reliable code and testing, and -[Equinox](https://github.com/patrick-kidger/equinox) for neural networks. -(Watch the NeurIPS 2020 JAX Ecosystem at DeepMind talk -[here](https://www.youtube.com/watch?v=iDxJxIyzSiM) for additional details.) - ## Citing JAX To cite this repository: @@ -452,7 +278,7 @@ paper. ## Reference documentation For details about the JAX API, see the -[reference documentation](https://jax.readthedocs.io/). +[reference documentation](https://docs.jax.dev/). For getting started as a JAX developer, see the -[developer documentation](https://jax.readthedocs.io/en/latest/developer.html). +[developer documentation](https://docs.jax.dev/en/latest/developer.html). diff --git a/WORKSPACE b/WORKSPACE index 129488281ea9..33be3c6e0452 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -1,68 +1,88 @@ # The XLA commit is determined by third_party/xla/workspace.bzl. load("//third_party/xla:workspace.bzl", jax_xla_workspace = "repo") + jax_xla_workspace() # Initialize hermetic Python load("@xla//third_party/py:python_init_rules.bzl", "python_init_rules") + python_init_rules() load("@xla//third_party/py:python_init_repositories.bzl", "python_init_repositories") + python_init_repositories( - requirements = { - "3.10": "//build:requirements_lock_3_10.txt", - "3.11": "//build:requirements_lock_3_11.txt", - "3.12": "//build:requirements_lock_3_12.txt", - "3.13": "//build:requirements_lock_3_13.txt", - "3.13-ft": "//build:requirements_lock_3_13_ft.txt", - }, + default_python_version = "system", + local_wheel_dist_folder = "../dist", local_wheel_inclusion_list = [ + "ml_dtypes*", + "ml-dtypes*", + "numpy*", + "scipy*", + "jax-*", "jaxlib*", "jax_cuda*", "jax-cuda*", ], local_wheel_workspaces = ["//jaxlib:jax.bzl"], - local_wheel_dist_folder = "../dist", - default_python_version = "system", + requirements = { + "3.11": "//build:requirements_lock_3_11.txt", + "3.12": "//build:requirements_lock_3_12.txt", + "3.13": "//build:requirements_lock_3_13.txt", + "3.14": "//build:requirements_lock_3_14.txt", + "3.13-ft": "//build:requirements_lock_3_13_ft.txt", + "3.14-ft": "//build:requirements_lock_3_14_ft.txt", + }, ) load("@xla//third_party/py:python_init_toolchains.bzl", "python_init_toolchains") + python_init_toolchains() load("@xla//third_party/py:python_init_pip.bzl", "python_init_pip") + python_init_pip() load("@pypi//:requirements.bzl", "install_deps") + install_deps() # Optional, to facilitate testing against newest versions of Python load("@xla//third_party/py:python_repo.bzl", "custom_python_interpreter") + custom_python_interpreter( name = "python_dev", - urls = 
["https://www.python.org/ftp/python/{version}/Python-{version_variant}.tgz"], strip_prefix = "Python-{version_variant}", + urls = ["https://www.python.org/ftp/python/{version}/Python-{version_variant}.tgz"], version = "3.13.0", version_variant = "3.13.0rc2", ) load("@xla//:workspace4.bzl", "xla_workspace4") + xla_workspace4() load("@xla//:workspace3.bzl", "xla_workspace3") + xla_workspace3() load("@xla//:workspace2.bzl", "xla_workspace2") + xla_workspace2() load("@xla//:workspace1.bzl", "xla_workspace1") + xla_workspace1() load("@xla//:workspace0.bzl", "xla_workspace0") + xla_workspace0() load("//third_party/flatbuffers:workspace.bzl", flatbuffers = "repo") + flatbuffers() load("//jaxlib:jax_python_wheel.bzl", "jax_python_wheel_repository") + jax_python_wheel_repository( name = "jax_wheel", version_key = "_version", @@ -73,6 +93,7 @@ load( "@xla//third_party/py:python_wheel.bzl", "python_wheel_version_suffix_repository", ) + python_wheel_version_suffix_repository( name = "jax_wheel_version_suffix", ) @@ -123,3 +144,30 @@ load( ) nccl_configure(name = "local_config_nccl") + +load( + "@xla//third_party/nvshmem/hermetic:nvshmem_json_init_repository.bzl", + "nvshmem_json_init_repository", +) + +nvshmem_json_init_repository() + +load( + "@nvshmem_redist_json//:distributions.bzl", + "NVSHMEM_REDISTRIBUTIONS", +) +load( + "@xla//third_party/nvshmem/hermetic:nvshmem_redist_init_repository.bzl", + "nvshmem_redist_init_repository", +) + +nvshmem_redist_init_repository( + nvshmem_redistributions = NVSHMEM_REDISTRIBUTIONS, +) + +load( + "@xla//third_party/nvshmem/hermetic:nvshmem_configure.bzl", + "nvshmem_configure", +) + +nvshmem_configure(name = "local_config_nvshmem") diff --git a/benchmarks/api_benchmark.py b/benchmarks/api_benchmark.py index cabebce2227c..a62b78d66ced 100644 --- a/benchmarks/api_benchmark.py +++ b/benchmarks/api_benchmark.py @@ -847,7 +847,7 @@ def safe_map(state): args = tuple(list(range(state.range(0))) for _ in range(state.range(1))) def f(*args): return tuple(args) while state: - jax.util.safe_map(f, *args) + jax._src.util.safe_map(f, *args) @google_benchmark.register @google_benchmark.option.arg_names(['arg_lengths', 'num_args']) @@ -855,7 +855,7 @@ def f(*args): return tuple(args) def safe_zip(state): args = tuple(list(range(state.range(0))) for _ in range(state.range(1))) while state: - jax.util.safe_zip(*args) + jax._src.util.safe_zip(*args) @google_benchmark.register diff --git a/benchmarks/mosaic/matmul_bench.py b/benchmarks/mosaic/matmul_bench.py index 32c147916407..fd3fcd6da315 100644 --- a/benchmarks/mosaic/matmul_bench.py +++ b/benchmarks/mosaic/matmul_bench.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-"""Microbenchmarks for mosaic gpu matrix mutliplication.""" +"""Microbenchmarks for mosaic gpu matrix multiplication.""" import functools import sys diff --git a/benchmarks/sparse_benchmark.py b/benchmarks/sparse_benchmark.py index d6328881d5c6..0ffb2aed5125 100644 --- a/benchmarks/sparse_benchmark.py +++ b/benchmarks/sparse_benchmark.py @@ -21,7 +21,13 @@ import jax from jax.experimental import sparse -def _sparse_bcoo_fromdense(state, jit: bool = False, compile: bool = False): + +def _sparse_fromdense( + state, + bcsr: bool = False, + jit: bool = False, + compile: bool = False, +): shape = (2000, 2000) nse = 10000 size = math.prod(shape) @@ -32,7 +38,7 @@ def _sparse_bcoo_fromdense(state, jit: bool = False, compile: bool = False): ) mat = jnp.zeros(shape).at[indices].set(data) - f = sparse.BCOO.fromdense + f = sparse.BCSR.fromdense if bcsr else sparse.BCOO.fromdense if compile or jit: # Note: nse must be specified for JIT. f = jax.jit(partial(f, nse=nse)) @@ -49,22 +55,12 @@ def _sparse_bcoo_fromdense(state, jit: bool = False, compile: bool = False): f(mat).block_until_ready() -@google_benchmark.register -def sparse_bcoo_fromdense(state): - return _sparse_bcoo_fromdense(state) - - -@google_benchmark.register -def sparse_bcoo_fromdense_jit(state): - return _sparse_bcoo_fromdense(state, jit=True) - - -@google_benchmark.register -def sparse_bcoo_fromdense_compile(state): - return _sparse_bcoo_fromdense(state, compile=True) - - -def _sparse_bcoo_todense(state, jit: bool = False, compile: bool = False): +def _sparse_todense( + state, + bcsr: bool = False, + jit: bool = False, + compile: bool = False, +): shape = (2000, 2000) nse = 10000 size = math.prod(shape) @@ -74,6 +70,8 @@ def _sparse_bcoo_todense(state, jit: bool = False, compile: bool = False): rng.choice(size, size=nse, replace=False), shape=shape ) mat = sparse.BCOO((jnp.array(data), jnp.column_stack(indices)), shape=shape) + if bcsr: + mat = sparse.BCSR.from_bcoo(mat) f = lambda mat: mat.todense() if jit or compile: @@ -91,22 +89,12 @@ def _sparse_bcoo_todense(state, jit: bool = False, compile: bool = False): f(mat).block_until_ready() -@google_benchmark.register -def sparse_bcoo_todense(state): - return _sparse_bcoo_todense(state) - - -@google_benchmark.register -def sparse_bcoo_todense_jit(state): - return _sparse_bcoo_todense(state, jit=True) - - -@google_benchmark.register -def sparse_bcoo_todense_compile(state): - return _sparse_bcoo_todense(state, compile=True) - - -def _sparse_bcoo_matvec(state, jit: bool = False, compile: bool = False): +def _sparse_matvec( + state, + bcsr: bool = False, + jit: bool = False, + compile: bool = False, +): shape = (2000, 2000) nse = 10000 key = jax.random.key(1701) @@ -118,6 +106,9 @@ def _sparse_bcoo_matvec(state, jit: bool = False, compile: bool = False): indices_dtype=jnp.int32, sorted_indices=True, ) + if bcsr: + mat = sparse.BCSR.from_bcoo(mat) + vec = jax.random.uniform(key, shape=(shape[1],), dtype=jnp.float32) f = lambda mat, vec: mat @ vec @@ -136,19 +127,94 @@ def _sparse_bcoo_matvec(state, jit: bool = False, compile: bool = False): f(mat, vec).block_until_ready() +@google_benchmark.register +def sparse_bcoo_fromdense(state): + return _sparse_fromdense(state) + + +@google_benchmark.register +def sparse_bcoo_fromdense_jit(state): + return _sparse_fromdense(state, jit=True) + + +@google_benchmark.register +def sparse_bcoo_fromdense_compile(state): + return _sparse_fromdense(state, compile=True) + + +@google_benchmark.register +def sparse_bcoo_todense(state): + return 
_sparse_todense(state) + + +@google_benchmark.register +def sparse_bcoo_todense_jit(state): + return _sparse_todense(state, jit=True) + + +@google_benchmark.register +def sparse_bcoo_todense_compile(state): + return _sparse_todense(state, compile=True) + + @google_benchmark.register def sparse_bcoo_matvec(state): - return _sparse_bcoo_matvec(state) + return _sparse_matvec(state) @google_benchmark.register def sparse_bcoo_matvec_jit(state): - return _sparse_bcoo_matvec(state, jit=True) + return _sparse_matvec(state, jit=True) @google_benchmark.register def sparse_bcoo_matvec_compile(state): - return _sparse_bcoo_matvec(state, compile=True) + return _sparse_matvec(state, compile=True) + + +@google_benchmark.register +def sparse_bscr_fromdense(state): + return _sparse_fromdense(state, bcsr=True) + + +@google_benchmark.register +def sparse_bscr_fromdense_jit(state): + return _sparse_fromdense(state, bcsr=True, jit=True) + + +@google_benchmark.register +def sparse_bscr_fromdense_compile(state): + return _sparse_fromdense(state, bcsr=True, compile=True) + + +@google_benchmark.register +def sparse_bscr_todense(state): + return _sparse_todense(state, bcsr=True) + + +@google_benchmark.register +def sparse_bscr_todense_jit(state): + return _sparse_todense(state, bcsr=True, jit=True) + + +@google_benchmark.register +def sparse_bscr_todense_compile(state): + return _sparse_todense(state, bcsr=True, compile=True) + + +@google_benchmark.register +def sparse_bcsr_matvec(state): + return _sparse_matvec(state, bcsr=True) + + +@google_benchmark.register +def sparse_bcsr_matvec_jit(state): + return _sparse_matvec(state, bcsr=True, jit=True) + + +@google_benchmark.register +def sparse_bcsr_matvec_compile(state): + return _sparse_matvec(state, bcsr=True, compile=True) if __name__ == "__main__": diff --git a/benchmarks/tracing_benchmark.py b/benchmarks/tracing_benchmark.py new file mode 100644 index 000000000000..e06ad538d476 --- /dev/null +++ b/benchmarks/tracing_benchmark.py @@ -0,0 +1,76 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Benchmarks for Jax tracing.""" + +import google_benchmark +import jax +from jax import random +from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel as splash +from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_mask as mask_lib +import numpy as np + + +def make_mqa_splash_attention_fn_and_args(): + seed = 0 + key = random.key(seed) + k1, k2, k3 = random.split(key, 3) + + q_seq_len = 1024 + kv_seq_len = 1024 + num_q_heads = 2 + head_dim_qk = 128 + head_dim_v = 128 + dtype = np.dtype("float32") + + q = random.uniform(k1, (num_q_heads, q_seq_len, head_dim_qk), dtype=dtype) + k = random.uniform(k2, (kv_seq_len, head_dim_qk), dtype=dtype) + v = random.uniform(k3, (kv_seq_len, head_dim_v), dtype=dtype) + + mask = mask_lib.NumpyMask( + mask_lib.make_random_mask((q_seq_len, kv_seq_len), sparsity=0.5, seed=0) + ) + mask = mask_lib.MultiHeadMask(tuple(mask for _ in range(num_q_heads))) + block_sizes = splash.BlockSizes.get_default() + + return ( + jax.jit( + splash.make_splash_mqa_single_device(mask, block_sizes=block_sizes) + ) + ), (q, k, v) + + +@google_benchmark.register +@google_benchmark.option.unit(google_benchmark.kMillisecond) +def test_pallas_mqa_splash_attention_trace(state): + attn, (q, k, v) = make_mqa_splash_attention_fn_and_args() + + while state: + _ = attn.trace(q, k, v) + jax.clear_caches() + + +@google_benchmark.register +@google_benchmark.option.unit(google_benchmark.kMillisecond) +def test_pallas_mqa_splash_attention_lower(state): + attn, (q, k, v) = make_mqa_splash_attention_fn_and_args() + traced = attn.trace(q, k, v) + + while state: + _ = traced.lower(lowering_platforms=("tpu",)) + jax.clear_caches() + + +if __name__ == "__main__": + google_benchmark.main() diff --git a/build/BUILD.bazel b/build/BUILD.bazel index f088cd58aa74..a3d347d9209a 100644 --- a/build/BUILD.bazel +++ b/build/BUILD.bazel @@ -13,55 +13,82 @@ # limitations under the License. # ============================================================================== -licenses(["notice"]) - load("@python//:defs.bzl", "compile_pip_requirements") load("@python_version_repo//:py_version.bzl", "REQUIREMENTS") +load("@rules_python//python:py_library.bzl", "py_library") load("//jaxlib:jax.bzl", "all_py_deps") -compile_pip_requirements( - name = "requirements", - extra_args = [ - "--allow-unsafe", - "--build-isolation", - "--rebuild", - ], - requirements_in = "requirements.in", - requirements_txt = REQUIREMENTS, - generate_hashes = True, - data = ["test-requirements.txt", "gpu-test-requirements.txt"] -) +licenses(["notice"]) -compile_pip_requirements( - name = "requirements_nightly", - extra_args = [ - "--allow-unsafe", - "--build-isolation", - "--extra-index-url=https://pypi.anaconda.org/scientific-python-nightly-wheels/simple", - "--pre", - "--upgrade" - ], - requirements_in = "requirements.in", - requirements_txt = REQUIREMENTS, - generate_hashes = False, - data = ["test-requirements.txt", "gpu-test-requirements.txt"] -) +COMMON_REQUIREMENTS = [ + "requirements.in", + "test-requirements.txt", +] -compile_pip_requirements( - name = "requirements_dev", - extra_args = [ - "--allow-unsafe", - "--build-isolation", - "--upgrade", - "--rebuild", - ], - requirements_in = "requirements.in", - requirements_txt = REQUIREMENTS, - generate_hashes = False, - data = ["test-requirements.txt", "gpu-test-requirements.txt"] -) +# It isn't possible to constraint based on free-threaded vs non-free threaded +# in a requirements file. 
So we do it by having two separate sets of requirement +# files and two sets of build rules. +FREETHREADING_REQUIREMENTS = COMMON_REQUIREMENTS + [ + "freethreading-requirements.txt", +] +NON_FREETHREADING_REQUIREMENTS = COMMON_REQUIREMENTS + [ + "nonfreethreading-requirements.txt", +] + +COMBOS = [ + ("", NON_FREETHREADING_REQUIREMENTS), + ("_ft", FREETHREADING_REQUIREMENTS), +] + +[ + compile_pip_requirements( + name = "requirements" + suffix, + extra_args = [ + "--allow-unsafe", + "--build-isolation", + "--rebuild", + ], + srcs = requirements, + requirements_txt = REQUIREMENTS, + generate_hashes = True, + ) + for suffix, requirements in COMBOS +] + +[ + compile_pip_requirements( + name = "requirements_nightly" + suffix, + extra_args = [ + "--allow-unsafe", + "--build-isolation", + "--extra-index-url=https://pypi.anaconda.org/scientific-python-nightly-wheels/simple", + "--pre", + "--upgrade", + ], + srcs = requirements, + requirements_txt = REQUIREMENTS, + generate_hashes = False, + ) + for suffix, requirements in COMBOS +] + +[ + compile_pip_requirements( + name = "requirements_dev" + suffix, + extra_args = [ + "--allow-unsafe", + "--build-isolation", + "--upgrade", + "--rebuild", + ], + srcs = requirements, + requirements_txt = REQUIREMENTS, + generate_hashes = False, + ) + for suffix, requirements in COMBOS +] py_library( name = "all_py_deps", - deps = all_py_deps(["zstandard"]), -) \ No newline at end of file + deps = all_py_deps(["zstandard", "tensorstore"]), +) diff --git a/build/build.py b/build/build.py index d38b911bb904..40e02a100d98 100755 --- a/build/build.py +++ b/build/build.py @@ -1,4 +1,4 @@ -#!/usr/bin/python +#!/usr/bin/env python3 # # Copyright 2018 The JAX Authors. # @@ -68,13 +68,20 @@ # rule as the default. WHEEL_BUILD_TARGET_DICT_NEW = { "jax": "//:jax_wheel", + "jax_editable": "//:jax_wheel_editable", + "jax_source_package": "//:jax_source_package", "jaxlib": "//jaxlib/tools:jaxlib_wheel", + "jaxlib_editable": "//jaxlib/tools:jaxlib_wheel_editable", "jax-cuda-plugin": "//jaxlib/tools:jax_cuda_plugin_wheel", + "jax-cuda-plugin_editable": "//jaxlib/tools:jax_cuda_plugin_wheel_editable", "jax-cuda-pjrt": "//jaxlib/tools:jax_cuda_pjrt_wheel", + "jax-cuda-pjrt_editable": "//jaxlib/tools:jax_cuda_pjrt_wheel_editable", "jax-rocm-plugin": "//jaxlib/tools:jax_rocm_plugin_wheel", "jax-rocm-pjrt": "//jaxlib/tools:jax_rocm_pjrt_wheel", } +_JAX_CUDA_VERSION = "12" + def add_global_arguments(parser: argparse.ArgumentParser): """Adds all the global arguments that applies to all the CLI subcommands.""" parser.add_argument( @@ -382,6 +389,11 @@ async def main(): arch = platform.machine() os_name = platform.system().lower() + custom_wheel_version_suffix = "" + wheel_build_date = "" + wheel_git_hash = "" + wheel_type = "snapshot" + args = parser.parse_args() logger.info("%s", BANNER) @@ -407,11 +419,12 @@ async def main(): for option in args.bazel_startup_options: bazel_command_base.append(option) - if not args.use_new_wheel_build_rule or args.command == "requirements_update": + if args.command == "requirements_update" or not args.use_new_wheel_build_rule: bazel_command_base.append("run") else: bazel_command_base.append("build") + freethreaded = False if args.python_version: # Do not add --repo_env=HERMETIC_PYTHON_VERSION with default args.python_version # if bazel_options override it @@ -427,6 +440,7 @@ async def main(): ) # Let's interpret X.YY-ft version as free-threading python and set rules_python config flag: if args.python_version.endswith("-ft"): + freethreaded = True 
bazel_command_base.append( "--@rules_python//python/config_settings:py_freethreaded='yes'" ) @@ -444,14 +458,15 @@ async def main(): for option in args.bazel_options: requirements_command.append(option) + ft_suffix = "_ft" if freethreaded else "" if args.nightly_update: logging.info( "--nightly_update is set. Bazel will run" " //build:requirements_nightly.update" ) - requirements_command.append("//build:requirements_nightly.update") + requirements_command.append(f"//build:requirements{ft_suffix}_nightly.update") else: - requirements_command.append("//build:requirements.update") + requirements_command.append(f"//build:requirements{ft_suffix}.update") result = await executor.run(requirements_command.get_command_as_string(), args.dry_run, args.detailed_timestamped_log) if result.return_code != 0: @@ -489,6 +504,7 @@ async def main(): if args.use_clang: clang_path = args.clang_path or utils.get_clang_path_or_exit() clang_major_version = utils.get_clang_major_version(clang_path) + clangpp_path = utils.get_clangpp_path(clang_path) logging.debug( "Using Clang as the compiler, clang path: %s, clang version: %s", clang_path, @@ -498,6 +514,7 @@ async def main(): # Use double quotes around clang path to avoid path issues on Windows. wheel_build_command_base.append(f"--action_env=CLANG_COMPILER_PATH=\"{clang_path}\"") wheel_build_command_base.append(f"--repo_env=CC=\"{clang_path}\"") + wheel_build_command_base.append(f"--repo_env=CXX=\"{clangpp_path}\"") wheel_build_command_base.append(f"--repo_env=BAZEL_COMPILER=\"{clang_path}\"") if clang_major_version >= 16: @@ -556,7 +573,6 @@ async def main(): if "cuda" in args.wheels: wheel_build_command_base.append("--config=cuda") - wheel_build_command_base.append("--config=cuda_libraries_from_stubs") if args.use_clang: wheel_build_command_base.append( f"--action_env=CLANG_CUDA_COMPILER_PATH=\"{clang_path}\"" @@ -596,7 +612,7 @@ async def main(): wheel_build_command_base.append("--config=rocm") wheel_build_command_base.append(f"--action_env=CLANG_COMPILER_PATH=\"{clang_path}\"") if args.rocm_path: - logging.debug("ROCm tookit path: %s", args.rocm_path) + logging.debug("ROCm toolkit path: %s", args.rocm_path) wheel_build_command_base.append(f"--action_env=ROCM_PATH=\"{args.rocm_path}\"") if args.rocm_amdgpu_targets: logging.debug("ROCm AMD GPU targets: %s", args.rocm_amdgpu_targets) @@ -612,8 +628,19 @@ async def main(): ) for option in args.bazel_options: wheel_build_command_base.append(option) - if "cuda" in args.wheels: - wheel_build_command_base.append("--config=cuda_libraries_from_stubs") + + # Parse the build options for the wheel version suffix. 
+ if "ML_WHEEL_TYPE" in option: + wheel_type = option.split("=")[-1] + if "ML_WHEEL_VERSION_SUFFIX" in option: + custom_wheel_version_suffix = option.split("=")[-1].replace("-", "") + if "ML_WHEEL_BUILD_DATE" in option: + wheel_build_date = option.split("=")[-1].replace("-", "") + if "ML_WHEEL_GIT_HASH" in option: + # Strip leading zeros as they end up being stripped by setuptools, + # which leads to a mismatch between expected and actual wheel names + # https://peps.python.org/pep-0440/ + wheel_git_hash = option.split("=")[-1].lstrip('0')[:9] with open(".jax_configure.bazelrc", "w") as f: jax_configure_options = utils.get_jax_configure_bazel_options(wheel_build_command_base.get_command_as_list(), args.use_new_wheel_build_rule) @@ -649,7 +676,9 @@ async def main(): ) sys.exit(1) - wheel_build_command = copy.deepcopy(wheel_build_command_base) + wheel_build_command = copy.deepcopy(bazel_command_base) + if "cuda" in args.wheels: + wheel_build_command.append("--config=cuda_libraries_from_stubs") print("\n") logger.info( "Building %s for %s %s...", @@ -659,8 +688,13 @@ async def main(): ) # Append the build target to the Bazel command. - build_target = wheel_build_targets[wheel] + if args.use_new_wheel_build_rule and args.editable: + build_target = wheel_build_targets[wheel + "_editable"] + else: + build_target = wheel_build_targets[wheel] wheel_build_command.append(build_target) + if args.use_new_wheel_build_rule and wheel == "jax" and not args.editable: + wheel_build_command.append(wheel_build_targets["jax_source_package"]) if not args.use_new_wheel_build_rule: wheel_build_command.append("--") @@ -692,6 +726,54 @@ async def main(): if result.return_code != 0: raise RuntimeError(f"Command failed with return code {result.return_code}") + if args.use_new_wheel_build_rule: + output_path = args.output_path + jax_bazel_dir = os.path.join("bazel-bin", "dist") + jaxlib_and_plugins_bazel_dir = os.path.join( + "bazel-bin", "jaxlib", "tools", "dist" + ) + for wheel in args.wheels.split(","): + if wheel == "jax": + bazel_dir = jax_bazel_dir + else: + bazel_dir = jaxlib_and_plugins_bazel_dir + if "cuda" in wheel: + wheel_dir = wheel.replace("cuda", f"cuda{_JAX_CUDA_VERSION}").replace( + "-", "_" + ) + else: + wheel_dir = wheel + + if args.editable: + src_dir = os.path.join(bazel_dir, wheel_dir) + dst_dir = os.path.join(output_path, wheel_dir) + utils.copy_dir_recursively(src_dir, dst_dir) + else: + wheel_version_suffix = "dev0+selfbuilt" + if wheel_type == "release": + wheel_version_suffix = custom_wheel_version_suffix + elif wheel_type in ["nightly", "custom"]: + wheel_version_suffix = f".dev{wheel_build_date}" + if wheel_type == "custom": + wheel_version_suffix += ( + f"+{wheel_git_hash}{custom_wheel_version_suffix}" + ) + if wheel in ["jax", "jax-cuda-pjrt"]: + python_tag = "py" + else: + python_tag = "cp" + utils.copy_individual_files( + bazel_dir, + output_path, + f"{wheel_dir}*{wheel_version_suffix}-{python_tag}*.whl", + ) + if wheel == "jax": + utils.copy_individual_files( + bazel_dir, + output_path, + f"{wheel_dir}*{wheel_version_suffix}.tar.gz", + ) + # Exit with success if all wheels in the list were built successfully. 
sys.exit(0) diff --git a/build/collect-profile-requirements.txt b/build/collect-profile-requirements.txt index da25d4b6ffe1..a334f408e271 100644 --- a/build/collect-profile-requirements.txt +++ b/build/collect-profile-requirements.txt @@ -1,4 +1,5 @@ -tensorflow -tensorboard-plugin-profile -# Needed for the profile plugin to work without error +# TF hasn't released 3.13 wheels yet (b/402590302) +tensorflow; python_version<"3.13" +xprof>=2.19.0 +# Needed for XProf to work without error protobuf diff --git a/build/freethreading-requirements.txt b/build/freethreading-requirements.txt new file mode 100644 index 000000000000..467578870ee9 --- /dev/null +++ b/build/freethreading-requirements.txt @@ -0,0 +1,3 @@ +# Under free-threading, we need an up-to-date numpy at least for the moment. +numpy~=2.2.6; python_version=="3.13" +numpy>=2.2.6; python_version>="3.14" diff --git a/build/gpu-test-requirements.txt b/build/gpu-test-requirements.txt deleted file mode 100644 index ff43f91ba90f..000000000000 --- a/build/gpu-test-requirements.txt +++ /dev/null @@ -1,13 +0,0 @@ -# NVIDIA CUDA dependencies -# Note that the wheels are downloaded only when the targets in bazel command -# contain dependencies on these wheels. -nvidia-cublas-cu12>=12.1.3.1 ; sys_platform == "linux" -nvidia-cuda-cupti-cu12>=12.1.105 ; sys_platform == "linux" -nvidia-cuda-nvcc-cu12>=12.6.85 ; sys_platform == "linux" -nvidia-cuda-runtime-cu12>=12.1.105 ; sys_platform == "linux" -nvidia-cudnn-cu12>=9.1,<10.0 ; sys_platform == "linux" -nvidia-cufft-cu12>=11.0.2.54 ; sys_platform == "linux" -nvidia-cusolver-cu12>=11.4.5.107 ; sys_platform == "linux" -nvidia-cusparse-cu12>=12.1.0.106 ; sys_platform == "linux" -nvidia-nccl-cu12>=2.18.1 ; sys_platform == "linux" -nvidia-nvjitlink-cu12>=12.1.105 ; sys_platform == "linux" diff --git a/build/nonfreethreading-requirements.txt b/build/nonfreethreading-requirements.txt new file mode 100644 index 000000000000..8bd139bf99ac --- /dev/null +++ b/build/nonfreethreading-requirements.txt @@ -0,0 +1,11 @@ +numpy~=2.0.0; python_version<="3.12" +numpy~=2.1.0; python_version=="3.13" +numpy>=2.2.6; python_version>="3.14" + +# These packages have not released free-threaded wheels. +zstandard +tensorstore + +# portpicker is architecture independent, but it depends on psutil which has not +# released a 3.13-ft wheel. +portpicker diff --git a/build/requirements.in b/build/requirements.in index ec7fc71b07e1..a88c194f7b8e 100644 --- a/build/requirements.in +++ b/build/requirements.in @@ -1,22 +1,33 @@ -# -# test deps -# --r test-requirements.txt --r gpu-test-requirements.txt - -# -# build deps -# -numpy~=2.0.0; python_version<="3.12" -numpy~=2.1.0; python_version>="3.13" - # # runtime deps # -scipy>=1.13.1 +scipy>=1.13.1; python_version<="3.12" +scipy>=1.15.2; python_version>="3.13" ml_dtypes>=0.4.0 -opt_einsum -zstandard etils[epath] +opt-einsum + +# Needed to build wheels +build setuptools +wheel + +# JAX's own libraries. We include these in the requirements so you can +# bazel test without building jaxlib and without manually updating the +# the requirements files. +jaxlib==0.6.2 + +# The with-cuda extra also includes NVIDIA's pip packages. 
+jax-cuda12-plugin[with-cuda]==0.6.2 ; sys_platform == "linux" +jax-cuda12-pjrt==0.6.2 ; sys_platform == "linux" + +# TPU dependencies +libtpu ; sys_platform == "linux" and platform_machine == "x86_64" + +# For Mosaic GPU collectives +nvidia-cuda-nvrtc-cu12>=12.1.55 ; sys_platform == "linux" +nvidia-nvshmem-cu12>=3.2.5 ; sys_platform == "linux" + +# Platform-specific dependencies that are being ignored by pip-compile +colorama>=0.4.4 diff --git a/build/requirements_lock_3_10.txt b/build/requirements_lock_3_10.txt deleted file mode 100644 index 6ed6b59aa584..000000000000 --- a/build/requirements_lock_3_10.txt +++ /dev/null @@ -1,707 +0,0 @@ -# -# This file is autogenerated by pip-compile with Python 3.10 -# by the following command: -# -# bazel run //build:requirements.update -# -absl-py==2.1.0 \ - --hash=sha256:526a04eadab8b4ee719ce68f204172ead1027549089702d99b9059f129ff1308 \ - --hash=sha256:7820790efbb316739cde8b4e19357243fc3608a152024288513dd968d7d959ff - # via -r build/test-requirements.txt -attrs==23.2.0 \ - --hash=sha256:935dc3b529c262f6cf76e50877d35a4bd3c1de194fd41f47a2b7ae8f19971f30 \ - --hash=sha256:99b87a485a5820b23b879f04c2305b44b951b502fd64be915879d77a7e8fc6f1 - # via hypothesis -auditwheel==6.1.0 \ - --hash=sha256:3bdc686e774cf9e355e924b0fe5a562d55caa385d72234ffe7b81b378dba360f \ - --hash=sha256:e52f734861859e3743eb29fcac7da9c4921a1e4bea58f954b52f2926f8e9e364 - # via -r build/test-requirements.txt -build==1.2.1 \ - --hash=sha256:526263f4870c26f26c433545579475377b2b7588b6f1eac76a001e873ae3e19d \ - --hash=sha256:75e10f767a433d9a86e50d83f418e83efc18ede923ee5ff7df93b6cb0306c5d4 - # via -r build/test-requirements.txt -cloudpickle==3.0.0 \ - --hash=sha256:246ee7d0c295602a036e86369c77fecda4ab17b506496730f2f576d9016fd9c7 \ - --hash=sha256:996d9a482c6fb4f33c1a35335cf8afd065d2a56e973270364840712d9131a882 - # via -r build/test-requirements.txt -colorama==0.4.6 \ - --hash=sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44 \ - --hash=sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6 - # via -r build/test-requirements.txt -contourpy==1.2.1 \ - --hash=sha256:00e5388f71c1a0610e6fe56b5c44ab7ba14165cdd6d695429c5cd94021e390b2 \ - --hash=sha256:10a37ae557aabf2509c79715cd20b62e4c7c28b8cd62dd7d99e5ed3ce28c3fd9 \ - --hash=sha256:11959f0ce4a6f7b76ec578576a0b61a28bdc0696194b6347ba3f1c53827178b9 \ - --hash=sha256:187fa1d4c6acc06adb0fae5544c59898ad781409e61a926ac7e84b8f276dcef4 \ - --hash=sha256:1a07fc092a4088ee952ddae19a2b2a85757b923217b7eed584fdf25f53a6e7ce \ - --hash=sha256:1cac0a8f71a041aa587410424ad46dfa6a11f6149ceb219ce7dd48f6b02b87a7 \ - --hash=sha256:1d59e739ab0e3520e62a26c60707cc3ab0365d2f8fecea74bfe4de72dc56388f \ - --hash=sha256:2855c8b0b55958265e8b5888d6a615ba02883b225f2227461aa9127c578a4922 \ - --hash=sha256:2e785e0f2ef0d567099b9ff92cbfb958d71c2d5b9259981cd9bee81bd194c9a4 \ - --hash=sha256:309be79c0a354afff9ff7da4aaed7c3257e77edf6c1b448a779329431ee79d7e \ - --hash=sha256:39f3ecaf76cd98e802f094e0d4fbc6dc9c45a8d0c4d185f0f6c2234e14e5f75b \ - --hash=sha256:457499c79fa84593f22454bbd27670227874cd2ff5d6c84e60575c8b50a69619 \ - --hash=sha256:49e70d111fee47284d9dd867c9bb9a7058a3c617274900780c43e38d90fe1205 \ - --hash=sha256:4c75507d0a55378240f781599c30e7776674dbaf883a46d1c90f37e563453480 \ - --hash=sha256:4c863140fafc615c14a4bf4efd0f4425c02230eb8ef02784c9a156461e62c965 \ - --hash=sha256:4d8908b3bee1c889e547867ca4cdc54e5ab6be6d3e078556814a22457f49423c \ - --hash=sha256:5b9eb0ca724a241683c9685a484da9d35c872fd42756574a7cfbf58af26677fd \ - 
--hash=sha256:6022cecf8f44e36af10bd9118ca71f371078b4c168b6e0fab43d4a889985dbb5 \ - --hash=sha256:6150ffa5c767bc6332df27157d95442c379b7dce3a38dff89c0f39b63275696f \ - --hash=sha256:62828cada4a2b850dbef89c81f5a33741898b305db244904de418cc957ff05dc \ - --hash=sha256:7b4182299f251060996af5249c286bae9361fa8c6a9cda5efc29fe8bfd6062ec \ - --hash=sha256:94b34f32646ca0414237168d68a9157cb3889f06b096612afdd296003fdd32fd \ - --hash=sha256:9ce6889abac9a42afd07a562c2d6d4b2b7134f83f18571d859b25624a331c90b \ - --hash=sha256:9cffe0f850e89d7c0012a1fb8730f75edd4320a0a731ed0c183904fe6ecfc3a9 \ - --hash=sha256:a12a813949e5066148712a0626895c26b2578874e4cc63160bb007e6df3436fe \ - --hash=sha256:a1eea9aecf761c661d096d39ed9026574de8adb2ae1c5bd7b33558af884fb2ce \ - --hash=sha256:a31f94983fecbac95e58388210427d68cd30fe8a36927980fab9c20062645609 \ - --hash=sha256:ac58bdee53cbeba2ecad824fa8159493f0bf3b8ea4e93feb06c9a465d6c87da8 \ - --hash=sha256:af3f4485884750dddd9c25cb7e3915d83c2db92488b38ccb77dd594eac84c4a0 \ - --hash=sha256:b33d2bc4f69caedcd0a275329eb2198f560b325605810895627be5d4b876bf7f \ - --hash=sha256:b59c0ffceff8d4d3996a45f2bb6f4c207f94684a96bf3d9728dbb77428dd8cb8 \ - --hash=sha256:bb6834cbd983b19f06908b45bfc2dad6ac9479ae04abe923a275b5f48f1a186b \ - --hash=sha256:bd3db01f59fdcbce5b22afad19e390260d6d0222f35a1023d9adc5690a889364 \ - --hash=sha256:bd7c23df857d488f418439686d3b10ae2fbf9bc256cd045b37a8c16575ea1040 \ - --hash=sha256:c2528d60e398c7c4c799d56f907664673a807635b857df18f7ae64d3e6ce2d9f \ - --hash=sha256:d31a63bc6e6d87f77d71e1abbd7387ab817a66733734883d1fc0021ed9bfa083 \ - --hash=sha256:d4492d82b3bc7fbb7e3610747b159869468079fe149ec5c4d771fa1f614a14df \ - --hash=sha256:ddcb8581510311e13421b1f544403c16e901c4e8f09083c881fab2be80ee31ba \ - --hash=sha256:e1d59258c3c67c865435d8fbeb35f8c59b8bef3d6f46c1f29f6123556af28445 \ - --hash=sha256:eb3315a8a236ee19b6df481fc5f997436e8ade24a9f03dfdc6bd490fea20c6da \ - --hash=sha256:ef2b055471c0eb466033760a521efb9d8a32b99ab907fc8358481a1dd29e3bd3 \ - --hash=sha256:ef5adb9a3b1d0c645ff694f9bca7702ec2c70f4d734f9922ea34de02294fdf72 \ - --hash=sha256:f32c38afb74bd98ce26de7cc74a67b40afb7b05aae7b42924ea990d51e4dac02 \ - --hash=sha256:fe0ccca550bb8e5abc22f530ec0466136379c01321fd94f30a22231e8a48d985 - # via matplotlib -cycler==0.12.1 \ - --hash=sha256:85cef7cff222d8644161529808465972e51340599459b8ac3ccbac5a854e0d30 \ - --hash=sha256:88bb128f02ba341da8ef447245a9e138fae777f6a23943da4540077d3601eb1c - # via matplotlib -etils[epath,epy]==1.7.0 \ - --hash=sha256:61af8f7c242171de15e22e5da02d527cb9e677d11f8bcafe18fcc3548eee3e60 \ - --hash=sha256:97b68fd25e185683215286ef3a54e38199b6245f5fe8be6bedc1189be4256350 - # via -r build/requirements.in -exceptiongroup==1.2.1 \ - --hash=sha256:5258b9ed329c5bbdd31a309f53cbfb0b155341807f6ff7606a1e801a891b29ad \ - --hash=sha256:a4785e48b045528f5bfe627b6ad554ff32def154f42372786903b7abcfe1aa16 - # via - # hypothesis - # pytest -execnet==2.1.1 \ - --hash=sha256:26dee51f1b80cebd6d0ca8e74dd8745419761d3bef34163928cbebbdc4749fdc \ - --hash=sha256:5189b52c6121c24feae288166ab41b32549c7e2348652736540b9e6e7d4e72e3 - # via pytest-xdist -filelock==3.14.0 \ - --hash=sha256:43339835842f110ca7ae60f1e1c160714c5a6afd15a2873419ab185334975c0f \ - --hash=sha256:6ea72da3be9b8c82afd3edcf99f2fffbb5076335a5ae4d03248bb5b6c3eae78a - # via -r build/test-requirements.txt -flatbuffers==24.3.25 \ - --hash=sha256:8dbdec58f935f3765e4f7f3cf635ac3a77f83568138d6a2311f524ec96364812 \ - --hash=sha256:de2ec5b203f21441716617f38443e0a8ebf3d25bf0d9c0bb0ce68fa00ad546a4 - # via -r 
build/test-requirements.txt -fonttools==4.51.0 \ - --hash=sha256:0118ef998a0699a96c7b28457f15546815015a2710a1b23a7bf6c1be60c01636 \ - --hash=sha256:0d145976194a5242fdd22df18a1b451481a88071feadf251221af110ca8f00ce \ - --hash=sha256:0e19bd9e9964a09cd2433a4b100ca7f34e34731e0758e13ba9a1ed6e5468cc0f \ - --hash=sha256:0f08c901d3866a8905363619e3741c33f0a83a680d92a9f0e575985c2634fcc1 \ - --hash=sha256:1250e818b5f8a679ad79660855528120a8f0288f8f30ec88b83db51515411fcc \ - --hash=sha256:15c94eeef6b095831067f72c825eb0e2d48bb4cea0647c1b05c981ecba2bf39f \ - --hash=sha256:1621ee57da887c17312acc4b0e7ac30d3a4fb0fec6174b2e3754a74c26bbed1e \ - --hash=sha256:180194c7fe60c989bb627d7ed5011f2bef1c4d36ecf3ec64daec8302f1ae0716 \ - --hash=sha256:278e50f6b003c6aed19bae2242b364e575bcb16304b53f2b64f6551b9c000e15 \ - --hash=sha256:32b17504696f605e9e960647c5f64b35704782a502cc26a37b800b4d69ff3c77 \ - --hash=sha256:3bee3f3bd9fa1d5ee616ccfd13b27ca605c2b4270e45715bd2883e9504735034 \ - --hash=sha256:4060acc2bfa2d8e98117828a238889f13b6f69d59f4f2d5857eece5277b829ba \ - --hash=sha256:54dcf21a2f2d06ded676e3c3f9f74b2bafded3a8ff12f0983160b13e9f2fb4a7 \ - --hash=sha256:56fc244f2585d6c00b9bcc59e6593e646cf095a96fe68d62cd4da53dd1287b55 \ - --hash=sha256:599bdb75e220241cedc6faebfafedd7670335d2e29620d207dd0378a4e9ccc5a \ - --hash=sha256:5f6bc991d1610f5c3bbe997b0233cbc234b8e82fa99fc0b2932dc1ca5e5afec0 \ - --hash=sha256:60a3409c9112aec02d5fb546f557bca6efa773dcb32ac147c6baf5f742e6258b \ - --hash=sha256:68b3fb7775a923be73e739f92f7e8a72725fd333eab24834041365d2278c3671 \ - --hash=sha256:76f1777d8b3386479ffb4a282e74318e730014d86ce60f016908d9801af9ca2a \ - --hash=sha256:806e7912c32a657fa39d2d6eb1d3012d35f841387c8fc6cf349ed70b7c340039 \ - --hash=sha256:84d7751f4468dd8cdd03ddada18b8b0857a5beec80bce9f435742abc9a851a74 \ - --hash=sha256:865a58b6e60b0938874af0968cd0553bcd88e0b2cb6e588727117bd099eef836 \ - --hash=sha256:8ac27f436e8af7779f0bb4d5425aa3535270494d3bc5459ed27de3f03151e4c2 \ - --hash=sha256:8b4850fa2ef2cfbc1d1f689bc159ef0f45d8d83298c1425838095bf53ef46308 \ - --hash=sha256:8b5ad456813d93b9c4b7ee55302208db2b45324315129d85275c01f5cb7e61a2 \ - --hash=sha256:8e2f1a4499e3b5ee82c19b5ee57f0294673125c65b0a1ff3764ea1f9db2f9ef5 \ - --hash=sha256:9696fe9f3f0c32e9a321d5268208a7cc9205a52f99b89479d1b035ed54c923f1 \ - --hash=sha256:96a48e137c36be55e68845fc4284533bda2980f8d6f835e26bca79d7e2006438 \ - --hash=sha256:a8feca65bab31479d795b0d16c9a9852902e3a3c0630678efb0b2b7941ea9c74 \ - --hash=sha256:aefa011207ed36cd280babfaa8510b8176f1a77261833e895a9d96e57e44802f \ - --hash=sha256:b2b92381f37b39ba2fc98c3a45a9d6383bfc9916a87d66ccb6553f7bdd129097 \ - --hash=sha256:b3c61423f22165541b9403ee39874dcae84cd57a9078b82e1dce8cb06b07fa2e \ - --hash=sha256:b5b48a1121117047d82695d276c2af2ee3a24ffe0f502ed581acc2673ecf1037 \ - --hash=sha256:c18b49adc721a7d0b8dfe7c3130c89b8704baf599fb396396d07d4aa69b824a1 \ - --hash=sha256:c5b8cab0c137ca229433570151b5c1fc6af212680b58b15abd797dcdd9dd5051 \ - --hash=sha256:c7e91abdfae1b5c9e3a543f48ce96013f9a08c6c9668f1e6be0beabf0a569c1b \ - --hash=sha256:cadf4e12a608ef1d13e039864f484c8a968840afa0258b0b843a0556497ea9ed \ - --hash=sha256:dc0673361331566d7a663d7ce0f6fdcbfbdc1f59c6e3ed1165ad7202ca183c68 \ - --hash=sha256:de7c29bdbdd35811f14493ffd2534b88f0ce1b9065316433b22d63ca1cd21f14 \ - --hash=sha256:e9d9298be7a05bb4801f558522adbe2feea1b0b103d5294ebf24a92dd49b78e5 \ - --hash=sha256:ee1af4be1c5afe4c96ca23badd368d8dc75f611887fb0c0dac9f71ee5d6f110e \ - --hash=sha256:f7e89853d8bea103c8e3514b9f9dc86b5b4120afb4583b57eb10dfa5afbe0936 - # via matplotlib 
-fsspec==2024.5.0 \ - --hash=sha256:1d021b0b0f933e3b3029ed808eb400c08ba101ca2de4b3483fbc9ca23fcee94a \ - --hash=sha256:e0fdbc446d67e182f49a70b82cf7889028a63588fde6b222521f10937b2b670c - # via etils -hypothesis==6.102.4 \ - --hash=sha256:013df31b04a4daede13756f497e60e451963d86f426395a79f99c5d692919bbd \ - --hash=sha256:59b4d144346d5cffb482cc1bafbd21b13ff31608e8c4b3e4630339aee3e87763 - # via -r build/test-requirements.txt -importlib-resources==6.4.0 \ - --hash=sha256:50d10f043df931902d4194ea07ec57960f66a80449ff867bfe782b4c486ba78c \ - --hash=sha256:cdb2b453b8046ca4e3798eb1d84f3cce1446a0e8e7b5ef4efb600f19fc398145 - # via etils -iniconfig==2.0.0 \ - --hash=sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3 \ - --hash=sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374 - # via pytest -kiwisolver==1.4.5 \ - --hash=sha256:00bd361b903dc4bbf4eb165f24d1acbee754fce22ded24c3d56eec268658a5cf \ - --hash=sha256:040c1aebeda72197ef477a906782b5ab0d387642e93bda547336b8957c61022e \ - --hash=sha256:05703cf211d585109fcd72207a31bb170a0f22144d68298dc5e61b3c946518af \ - --hash=sha256:06f54715b7737c2fecdbf140d1afb11a33d59508a47bf11bb38ecf21dc9ab79f \ - --hash=sha256:0dc9db8e79f0036e8173c466d21ef18e1befc02de8bf8aa8dc0813a6dc8a7046 \ - --hash=sha256:0f114aa76dc1b8f636d077979c0ac22e7cd8f3493abbab152f20eb8d3cda71f3 \ - --hash=sha256:11863aa14a51fd6ec28688d76f1735f8f69ab1fabf388851a595d0721af042f5 \ - --hash=sha256:11c7de8f692fc99816e8ac50d1d1aef4f75126eefc33ac79aac02c099fd3db71 \ - --hash=sha256:11d011a7574eb3b82bcc9c1a1d35c1d7075677fdd15de527d91b46bd35e935ee \ - --hash=sha256:146d14bebb7f1dc4d5fbf74f8a6cb15ac42baadee8912eb84ac0b3b2a3dc6ac3 \ - --hash=sha256:15568384086b6df3c65353820a4473575dbad192e35010f622c6ce3eebd57af9 \ - --hash=sha256:19df6e621f6d8b4b9c4d45f40a66839294ff2bb235e64d2178f7522d9170ac5b \ - --hash=sha256:1b04139c4236a0f3aff534479b58f6f849a8b351e1314826c2d230849ed48985 \ - --hash=sha256:210ef2c3a1f03272649aff1ef992df2e724748918c4bc2d5a90352849eb40bea \ - --hash=sha256:2270953c0d8cdab5d422bee7d2007f043473f9d2999631c86a223c9db56cbd16 \ - --hash=sha256:2400873bccc260b6ae184b2b8a4fec0e4082d30648eadb7c3d9a13405d861e89 \ - --hash=sha256:2a40773c71d7ccdd3798f6489aaac9eee213d566850a9533f8d26332d626b82c \ - --hash=sha256:2c5674c4e74d939b9d91dda0fae10597ac7521768fec9e399c70a1f27e2ea2d9 \ - --hash=sha256:3195782b26fc03aa9c6913d5bad5aeb864bdc372924c093b0f1cebad603dd712 \ - --hash=sha256:31a82d498054cac9f6d0b53d02bb85811185bcb477d4b60144f915f3b3126342 \ - --hash=sha256:32d5cf40c4f7c7b3ca500f8985eb3fb3a7dfc023215e876f207956b5ea26632a \ - --hash=sha256:346f5343b9e3f00b8db8ba359350eb124b98c99efd0b408728ac6ebf38173958 \ - --hash=sha256:378a214a1e3bbf5ac4a8708304318b4f890da88c9e6a07699c4ae7174c09a68d \ - --hash=sha256:39b42c68602539407884cf70d6a480a469b93b81b7701378ba5e2328660c847a \ - --hash=sha256:3a2b053a0ab7a3960c98725cfb0bf5b48ba82f64ec95fe06f1d06c99b552e130 \ - --hash=sha256:3aba7311af82e335dd1e36ffff68aaca609ca6290c2cb6d821a39aa075d8e3ff \ - --hash=sha256:3cd32d6c13807e5c66a7cbb79f90b553642f296ae4518a60d8d76243b0ad2898 \ - --hash=sha256:3edd2fa14e68c9be82c5b16689e8d63d89fe927e56debd6e1dbce7a26a17f81b \ - --hash=sha256:4c380469bd3f970ef677bf2bcba2b6b0b4d5c75e7a020fb863ef75084efad66f \ - --hash=sha256:4e66e81a5779b65ac21764c295087de82235597a2293d18d943f8e9e32746265 \ - --hash=sha256:53abb58632235cd154176ced1ae8f0d29a6657aa1aa9decf50b899b755bc2b93 \ - --hash=sha256:5794cf59533bc3f1b1c821f7206a3617999db9fbefc345360aafe2e067514929 \ - 
--hash=sha256:59415f46a37f7f2efeec758353dd2eae1b07640d8ca0f0c42548ec4125492635 \ - --hash=sha256:59ec7b7c7e1a61061850d53aaf8e93db63dce0c936db1fda2658b70e4a1be709 \ - --hash=sha256:59edc41b24031bc25108e210c0def6f6c2191210492a972d585a06ff246bb79b \ - --hash=sha256:5a580c91d686376f0f7c295357595c5a026e6cbc3d77b7c36e290201e7c11ecb \ - --hash=sha256:5b94529f9b2591b7af5f3e0e730a4e0a41ea174af35a4fd067775f9bdfeee01a \ - --hash=sha256:5c7b3b3a728dc6faf3fc372ef24f21d1e3cee2ac3e9596691d746e5a536de920 \ - --hash=sha256:5c90ae8c8d32e472be041e76f9d2f2dbff4d0b0be8bd4041770eddb18cf49a4e \ - --hash=sha256:5e7139af55d1688f8b960ee9ad5adafc4ac17c1c473fe07133ac092310d76544 \ - --hash=sha256:5ff5cf3571589b6d13bfbfd6bcd7a3f659e42f96b5fd1c4830c4cf21d4f5ef45 \ - --hash=sha256:620ced262a86244e2be10a676b646f29c34537d0d9cc8eb26c08f53d98013390 \ - --hash=sha256:6512cb89e334e4700febbffaaa52761b65b4f5a3cf33f960213d5656cea36a77 \ - --hash=sha256:6c08e1312a9cf1074d17b17728d3dfce2a5125b2d791527f33ffbe805200a355 \ - --hash=sha256:6c3bd3cde54cafb87d74d8db50b909705c62b17c2099b8f2e25b461882e544ff \ - --hash=sha256:6ef7afcd2d281494c0a9101d5c571970708ad911d028137cd558f02b851c08b4 \ - --hash=sha256:7269d9e5f1084a653d575c7ec012ff57f0c042258bf5db0954bf551c158466e7 \ - --hash=sha256:72d40b33e834371fd330fb1472ca19d9b8327acb79a5821d4008391db8e29f20 \ - --hash=sha256:74d1b44c6cfc897df648cc9fdaa09bc3e7679926e6f96df05775d4fb3946571c \ - --hash=sha256:74db36e14a7d1ce0986fa104f7d5637aea5c82ca6326ed0ec5694280942d1162 \ - --hash=sha256:763773d53f07244148ccac5b084da5adb90bfaee39c197554f01b286cf869228 \ - --hash=sha256:76c6a5964640638cdeaa0c359382e5703e9293030fe730018ca06bc2010c4437 \ - --hash=sha256:76d9289ed3f7501012e05abb8358bbb129149dbd173f1f57a1bf1c22d19ab7cc \ - --hash=sha256:7931d8f1f67c4be9ba1dd9c451fb0eeca1a25b89e4d3f89e828fe12a519b782a \ - --hash=sha256:7b8b454bac16428b22560d0a1cf0a09875339cab69df61d7805bf48919415901 \ - --hash=sha256:7e5bab140c309cb3a6ce373a9e71eb7e4873c70c2dda01df6820474f9889d6d4 \ - --hash=sha256:83d78376d0d4fd884e2c114d0621624b73d2aba4e2788182d286309ebdeed770 \ - --hash=sha256:852542f9481f4a62dbb5dd99e8ab7aedfeb8fb6342349a181d4036877410f525 \ - --hash=sha256:85267bd1aa8880a9c88a8cb71e18d3d64d2751a790e6ca6c27b8ccc724bcd5ad \ - --hash=sha256:88a2df29d4724b9237fc0c6eaf2a1adae0cdc0b3e9f4d8e7dc54b16812d2d81a \ - --hash=sha256:88b9f257ca61b838b6f8094a62418421f87ac2a1069f7e896c36a7d86b5d4c29 \ - --hash=sha256:8ab3919a9997ab7ef2fbbed0cc99bb28d3c13e6d4b1ad36e97e482558a91be90 \ - --hash=sha256:92dea1ffe3714fa8eb6a314d2b3c773208d865a0e0d35e713ec54eea08a66250 \ - --hash=sha256:9407b6a5f0d675e8a827ad8742e1d6b49d9c1a1da5d952a67d50ef5f4170b18d \ - --hash=sha256:9408acf3270c4b6baad483865191e3e582b638b1654a007c62e3efe96f09a9a3 \ - --hash=sha256:955e8513d07a283056b1396e9a57ceddbd272d9252c14f154d450d227606eb54 \ - --hash=sha256:9db8ea4c388fdb0f780fe91346fd438657ea602d58348753d9fb265ce1bca67f \ - --hash=sha256:9eaa8b117dc8337728e834b9c6e2611f10c79e38f65157c4c38e9400286f5cb1 \ - --hash=sha256:a51a263952b1429e429ff236d2f5a21c5125437861baeed77f5e1cc2d2c7c6da \ - --hash=sha256:a6aa6315319a052b4ee378aa171959c898a6183f15c1e541821c5c59beaa0238 \ - --hash=sha256:aa12042de0171fad672b6c59df69106d20d5596e4f87b5e8f76df757a7c399aa \ - --hash=sha256:aaf7be1207676ac608a50cd08f102f6742dbfc70e8d60c4db1c6897f62f71523 \ - --hash=sha256:b0157420efcb803e71d1b28e2c287518b8808b7cf1ab8af36718fd0a2c453eb0 \ - --hash=sha256:b3f7e75f3015df442238cca659f8baa5f42ce2a8582727981cbfa15fee0ee205 \ - 
--hash=sha256:b9098e0049e88c6a24ff64545cdfc50807818ba6c1b739cae221bbbcbc58aad3 \ - --hash=sha256:ba55dce0a9b8ff59495ddd050a0225d58bd0983d09f87cfe2b6aec4f2c1234e4 \ - --hash=sha256:bb86433b1cfe686da83ce32a9d3a8dd308e85c76b60896d58f082136f10bffac \ - --hash=sha256:bbea0db94288e29afcc4c28afbf3a7ccaf2d7e027489c449cf7e8f83c6346eb9 \ - --hash=sha256:bbf1d63eef84b2e8c89011b7f2235b1e0bf7dacc11cac9431fc6468e99ac77fb \ - --hash=sha256:c7940c1dc63eb37a67721b10d703247552416f719c4188c54e04334321351ced \ - --hash=sha256:c9bf3325c47b11b2e51bca0824ea217c7cd84491d8ac4eefd1e409705ef092bd \ - --hash=sha256:cdc8a402aaee9a798b50d8b827d7ecf75edc5fb35ea0f91f213ff927c15f4ff0 \ - --hash=sha256:ceec1a6bc6cab1d6ff5d06592a91a692f90ec7505d6463a88a52cc0eb58545da \ - --hash=sha256:cfe6ab8da05c01ba6fbea630377b5da2cd9bcbc6338510116b01c1bc939a2c18 \ - --hash=sha256:d099e745a512f7e3bbe7249ca835f4d357c586d78d79ae8f1dcd4d8adeb9bda9 \ - --hash=sha256:d0ef46024e6a3d79c01ff13801cb19d0cad7fd859b15037aec74315540acc276 \ - --hash=sha256:d2e5a98f0ec99beb3c10e13b387f8db39106d53993f498b295f0c914328b1333 \ - --hash=sha256:da4cfb373035def307905d05041c1d06d8936452fe89d464743ae7fb8371078b \ - --hash=sha256:da802a19d6e15dffe4b0c24b38b3af68e6c1a68e6e1d8f30148c83864f3881db \ - --hash=sha256:dced8146011d2bc2e883f9bd68618b8247387f4bbec46d7392b3c3b032640126 \ - --hash=sha256:dfdd7c0b105af050eb3d64997809dc21da247cf44e63dc73ff0fd20b96be55a9 \ - --hash=sha256:e368f200bbc2e4f905b8e71eb38b3c04333bddaa6a2464a6355487b02bb7fb09 \ - --hash=sha256:e391b1f0a8a5a10ab3b9bb6afcfd74f2175f24f8975fb87ecae700d1503cdee0 \ - --hash=sha256:e57e563a57fb22a142da34f38acc2fc1a5c864bc29ca1517a88abc963e60d6ec \ - --hash=sha256:e5d706eba36b4c4d5bc6c6377bb6568098765e990cfc21ee16d13963fab7b3e7 \ - --hash=sha256:ec20916e7b4cbfb1f12380e46486ec4bcbaa91a9c448b97023fde0d5bbf9e4ff \ - --hash=sha256:f1d072c2eb0ad60d4c183f3fb44ac6f73fb7a8f16a2694a91f988275cbf352f9 \ - --hash=sha256:f846c260f483d1fd217fe5ed7c173fb109efa6b1fc8381c8b7552c5781756192 \ - --hash=sha256:f91de7223d4c7b793867797bacd1ee53bfe7359bd70d27b7b58a04efbb9436c8 \ - --hash=sha256:faae4860798c31530dd184046a900e652c95513796ef51a12bc086710c2eec4d \ - --hash=sha256:fc579bf0f502e54926519451b920e875f433aceb4624a3646b3252b5caa9e0b6 \ - --hash=sha256:fcc700eadbbccbf6bc1bcb9dbe0786b4b1cb91ca0dcda336eef5c2beed37b797 \ - --hash=sha256:fd32ea360bcbb92d28933fc05ed09bffcb1704ba3fc7942e81db0fd4f81a7892 \ - --hash=sha256:fdb7adb641a0d13bdcd4ef48e062363d8a9ad4a182ac7647ec88f695e719ae9f - # via matplotlib -markdown-it-py==3.0.0 \ - --hash=sha256:355216845c60bd96232cd8d8c40e8f9765cc86f46880e43a8fd22dc1a1a8cab1 \ - --hash=sha256:e3f60a94fa066dc52ec76661e37c851cb232d92f9886b15cb560aaada2df8feb - # via rich -matplotlib==3.8.4 ; python_version <= "3.10" \ - --hash=sha256:1c13f041a7178f9780fb61cc3a2b10423d5e125480e4be51beaf62b172413b67 \ - --hash=sha256:232ce322bfd020a434caaffbd9a95333f7c2491e59cfc014041d95e38ab90d1c \ - --hash=sha256:493e9f6aa5819156b58fce42b296ea31969f2aab71c5b680b4ea7a3cb5c07d94 \ - --hash=sha256:50bac6e4d77e4262c4340d7a985c30912054745ec99756ce213bfbc3cb3808eb \ - --hash=sha256:606e3b90897554c989b1e38a258c626d46c873523de432b1462f295db13de6f9 \ - --hash=sha256:6209e5c9aaccc056e63b547a8152661324404dd92340a6e479b3a7f24b42a5d0 \ - --hash=sha256:6485ac1f2e84676cff22e693eaa4fbed50ef5dc37173ce1f023daef4687df616 \ - --hash=sha256:6addbd5b488aedb7f9bc19f91cd87ea476206f45d7116fcfe3d31416702a82fa \ - --hash=sha256:72f9322712e4562e792b2961971891b9fbbb0e525011e09ea0d1f416c4645661 \ - 
--hash=sha256:7a6769f58ce51791b4cb8b4d7642489df347697cd3e23d88266aaaee93b41d9a \ - --hash=sha256:8080d5081a86e690d7688ffa542532e87f224c38a6ed71f8fbed34dd1d9fedae \ - --hash=sha256:843cbde2f0946dadd8c5c11c6d91847abd18ec76859dc319362a0964493f0ba6 \ - --hash=sha256:8aac397d5e9ec158960e31c381c5ffc52ddd52bd9a47717e2a694038167dffea \ - --hash=sha256:8f65c9f002d281a6e904976007b2d46a1ee2bcea3a68a8c12dda24709ddc9106 \ - --hash=sha256:90df07db7b599fe7035d2f74ab7e438b656528c68ba6bb59b7dc46af39ee48ef \ - --hash=sha256:9bb0189011785ea794ee827b68777db3ca3f93f3e339ea4d920315a0e5a78d54 \ - --hash=sha256:a0e47eda4eb2614300fc7bb4657fced3e83d6334d03da2173b09e447418d499f \ - --hash=sha256:abc9d838f93583650c35eca41cfcec65b2e7cb50fd486da6f0c49b5e1ed23014 \ - --hash=sha256:ac24233e8f2939ac4fd2919eed1e9c0871eac8057666070e94cbf0b33dd9c338 \ - --hash=sha256:b12ba985837e4899b762b81f5b2845bd1a28f4fdd1a126d9ace64e9c4eb2fb25 \ - --hash=sha256:b7a2a253d3b36d90c8993b4620183b55665a429da8357a4f621e78cd48b2b30b \ - --hash=sha256:c7064120a59ce6f64103c9cefba8ffe6fba87f2c61d67c401186423c9a20fd35 \ - --hash=sha256:c89ee9314ef48c72fe92ce55c4e95f2f39d70208f9f1d9db4e64079420d8d732 \ - --hash=sha256:cc4ccdc64e3039fc303defd119658148f2349239871db72cd74e2eeaa9b80b71 \ - --hash=sha256:ce1edd9f5383b504dbc26eeea404ed0a00656c526638129028b758fd43fc5f10 \ - --hash=sha256:ecd79298550cba13a43c340581a3ec9c707bd895a6a061a78fa2524660482fc0 \ - --hash=sha256:f51c4c869d4b60d769f7b4406eec39596648d9d70246428745a681c327a8ad30 \ - --hash=sha256:fb44f53af0a62dc80bba4443d9b27f2fde6acfdac281d95bc872dc148a6509cc - # via -r build/test-requirements.txt -mdurl==0.1.2 \ - --hash=sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8 \ - --hash=sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba - # via markdown-it-py -ml-dtypes==0.5.1 \ - --hash=sha256:023ce2f502efd4d6c1e0472cc58ce3640d051d40e71e27386bed33901e201327 \ - --hash=sha256:05f23447a1c20ddf4dc7c2c661aa9ed93fcb2658f1017c204d1e758714dc28a8 \ - --hash=sha256:12651420130ee7cc13059fc56dac6ad300c3af3848b802d475148c9defd27c23 \ - --hash=sha256:141b2ea2f20bb10802ddca55d91fe21231ef49715cfc971998e8f2a9838f3dbe \ - --hash=sha256:15ad0f3b0323ce96c24637a88a6f44f6713c64032f27277b069f285c3cf66478 \ - --hash=sha256:1b7fbe5571fdf28fd3aaab3ef4aafc847de9ebf263be959958c1ca58ec8eadf5 \ - --hash=sha256:26ebcc69d7b779c8f129393e99732961b5cc33fcff84090451f448c89b0e01b4 \ - --hash=sha256:6f462f5eca22fb66d7ff9c4744a3db4463af06c49816c4b6ac89b16bfcdc592e \ - --hash=sha256:6f76232163b5b9c34291b54621ee60417601e2e4802a188a0ea7157cd9b323f4 \ - --hash=sha256:7000b6e4d8ef07542c05044ec5d8bbae1df083b3f56822c3da63993a113e716f \ - --hash=sha256:810512e2eccdfc3b41eefa3a27402371a3411453a1efc7e9c000318196140fed \ - --hash=sha256:8f2c028954f16ede77902b223a8da2d9cbb3892375b85809a5c3cfb1587960c4 \ - --hash=sha256:9626d0bca1fb387d5791ca36bacbba298c5ef554747b7ebeafefb4564fc83566 \ - --hash=sha256:ac5b58559bb84a95848ed6984eb8013249f90b6bab62aa5acbad876e256002c9 \ - --hash=sha256:ad4953c5eb9c25a56d11a913c2011d7e580a435ef5145f804d98efa14477d390 \ - --hash=sha256:aefedc579ece2f8fb38f876aa7698204ee4c372d0e54f1c1ffa8ca580b54cc60 \ - --hash=sha256:afb2009ac98da274e893e03162f6269398b2b00d947e7057ee2469a921d58135 \ - --hash=sha256:b8a9d46b4df5ae2135a8e8e72b465448ebbc1559997f4f9304a9ecc3413efb5b \ - --hash=sha256:bd73f51957949069573ff783563486339a9285d72e2f36c18e0c1aa9ca7eb190 \ - --hash=sha256:bf9975bda82a99dc935f2ae4c83846d86df8fd6ba179614acac8e686910851da \ - 
--hash=sha256:c09526488c3a9e8b7a23a388d4974b670a9a3dd40c5c8a61db5593ce9b725bab \ - --hash=sha256:c9945669d3dadf8acb40ec2e57d38c985d8c285ea73af57fc5b09872c516106d \ - --hash=sha256:d13755f8e8445b3870114e5b6240facaa7cb0c3361e54beba3e07fa912a6e12b \ - --hash=sha256:fd918d4e6a4e0c110e2e05be7a7814d10dc1b95872accbf6512b80a109b71ae1 - # via -r build/requirements.in -mpmath==1.4.0a1 \ - --hash=sha256:78884400f439f500fa76be0121a8f9598313d87664863a192e1185ddbd7ae97f \ - --hash=sha256:f8b7b5a3a1726ab6e8c898eb2157426b82c482ab1ab8ffed9f88bb9e07c6e9c1 - # via -r build/test-requirements.txt -numpy==2.0.0 ; python_version <= "3.12" \ - --hash=sha256:04494f6ec467ccb5369d1808570ae55f6ed9b5809d7f035059000a37b8d7e86f \ - --hash=sha256:0a43f0974d501842866cc83471bdb0116ba0dffdbaac33ec05e6afed5b615238 \ - --hash=sha256:0e50842b2295ba8414c8c1d9d957083d5dfe9e16828b37de883f51fc53c4016f \ - --hash=sha256:0ec84b9ba0654f3b962802edc91424331f423dcf5d5f926676e0150789cb3d95 \ - --hash=sha256:17067d097ed036636fa79f6a869ac26df7db1ba22039d962422506640314933a \ - --hash=sha256:1cde1753efe513705a0c6d28f5884e22bdc30438bf0085c5c486cdaff40cd67a \ - --hash=sha256:1e72728e7501a450288fc8e1f9ebc73d90cfd4671ebbd631f3e7857c39bd16f2 \ - --hash=sha256:2635dbd200c2d6faf2ef9a0d04f0ecc6b13b3cad54f7c67c61155138835515d2 \ - --hash=sha256:2ce46fd0b8a0c947ae047d222f7136fc4d55538741373107574271bc00e20e8f \ - --hash=sha256:34f003cb88b1ba38cb9a9a4a3161c1604973d7f9d5552c38bc2f04f829536609 \ - --hash=sha256:354f373279768fa5a584bac997de6a6c9bc535c482592d7a813bb0c09be6c76f \ - --hash=sha256:38ecb5b0582cd125f67a629072fed6f83562d9dd04d7e03256c9829bdec027ad \ - --hash=sha256:3e8e01233d57639b2e30966c63d36fcea099d17c53bf424d77f088b0f4babd86 \ - --hash=sha256:3f6bed7f840d44c08ebdb73b1825282b801799e325bcbdfa6bc5c370e5aecc65 \ - --hash=sha256:4554eb96f0fd263041baf16cf0881b3f5dafae7a59b1049acb9540c4d57bc8cb \ - --hash=sha256:46e161722e0f619749d1cd892167039015b2c2817296104487cd03ed4a955995 \ - --hash=sha256:49d9f7d256fbc804391a7f72d4a617302b1afac1112fac19b6c6cec63fe7fe8a \ - --hash=sha256:4d2f62e55a4cd9c58c1d9a1c9edaedcd857a73cb6fda875bf79093f9d9086f85 \ - --hash=sha256:5f64641b42b2429f56ee08b4f427a4d2daf916ec59686061de751a55aafa22e4 \ - --hash=sha256:63b92c512d9dbcc37f9d81b123dec99fdb318ba38c8059afc78086fe73820275 \ - --hash=sha256:6d7696c615765091cc5093f76fd1fa069870304beaccfd58b5dcc69e55ef49c1 \ - --hash=sha256:79e843d186c8fb1b102bef3e2bc35ef81160ffef3194646a7fdd6a73c6b97196 \ - --hash=sha256:821eedb7165ead9eebdb569986968b541f9908979c2da8a4967ecac4439bae3d \ - --hash=sha256:84554fc53daa8f6abf8e8a66e076aff6ece62de68523d9f665f32d2fc50fd66e \ - --hash=sha256:8d83bb187fb647643bd56e1ae43f273c7f4dbcdf94550d7938cfc32566756514 \ - --hash=sha256:903703372d46bce88b6920a0cd86c3ad82dae2dbef157b5fc01b70ea1cfc430f \ - --hash=sha256:9416a5c2e92ace094e9f0082c5fd473502c91651fb896bc17690d6fc475128d6 \ - --hash=sha256:9a1712c015831da583b21c5bfe15e8684137097969c6d22e8316ba66b5baabe4 \ - --hash=sha256:9c27f0946a3536403efb0e1c28def1ae6730a72cd0d5878db38824855e3afc44 \ - --hash=sha256:a356364941fb0593bb899a1076b92dfa2029f6f5b8ba88a14fd0984aaf76d0df \ - --hash=sha256:a7039a136017eaa92c1848152827e1424701532ca8e8967fe480fe1569dae581 \ - --hash=sha256:acd3a644e4807e73b4e1867b769fbf1ce8c5d80e7caaef0d90dcdc640dfc9787 \ - --hash=sha256:ad0c86f3455fbd0de6c31a3056eb822fc939f81b1618f10ff3406971893b62a5 \ - --hash=sha256:b4c76e3d4c56f145d41b7b6751255feefae92edbc9a61e1758a98204200f30fc \ - --hash=sha256:b6f6a8f45d0313db07d6d1d37bd0b112f887e1369758a5419c0370ba915b3871 \ - 
--hash=sha256:c5a59996dc61835133b56a32ebe4ef3740ea5bc19b3983ac60cc32be5a665d54 \ - --hash=sha256:c73aafd1afca80afecb22718f8700b40ac7cab927b8abab3c3e337d70e10e5a2 \ - --hash=sha256:cee6cc0584f71adefe2c908856ccc98702baf95ff80092e4ca46061538a2ba98 \ - --hash=sha256:cef04d068f5fb0518a77857953193b6bb94809a806bd0a14983a8f12ada060c9 \ - --hash=sha256:cf5d1c9e6837f8af9f92b6bd3e86d513cdc11f60fd62185cc49ec7d1aba34864 \ - --hash=sha256:e61155fae27570692ad1d327e81c6cf27d535a5d7ef97648a17d922224b216de \ - --hash=sha256:e7f387600d424f91576af20518334df3d97bc76a300a755f9a8d6e4f5cadd289 \ - --hash=sha256:ed08d2703b5972ec736451b818c2eb9da80d66c3e84aed1deeb0c345fefe461b \ - --hash=sha256:fbd6acc766814ea6443628f4e6751d0da6593dae29c08c0b2606164db026970c \ - --hash=sha256:feff59f27338135776f6d4e2ec7aeeac5d5f7a08a83e80869121ef8164b74af9 - # via - # -r build/requirements.in - # contourpy - # matplotlib - # ml-dtypes - # opt-einsum - # scipy -nvidia-cublas-cu12==12.8.3.14 ; sys_platform == "linux" \ - --hash=sha256:3f0e05e7293598cf61933258b73e66a160c27d59c4422670bf0b79348c04be44 \ - --hash=sha256:93a4e0e386cc7f6e56c822531396de8170ed17068a1e18f987574895044cd8c3 \ - --hash=sha256:9ae5eae500aead01fc4bdfc458209df638b1a3551557ce11a78eea9ece602ae9 - # via - # via -r build/test-requirements.txt - # nvidia-cudnn-cu12 - # nvidia-cusolver-cu12 -nvidia-cuda-cupti-cu12==12.8.57 ; sys_platform == "linux" \ - --hash=sha256:8e0b2eb847de260739bee4a3f66fac31378f4ff49538ff527a38a01a9a39f950 \ - --hash=sha256:bbed719c52a476958a74cfc42f2b95a3fd6b3fd94eb40134acc4601feb4acac3 \ - --hash=sha256:ff154211724fd824e758ce176b66007b558eea19c9a5135fc991827ee147e317 - # via -r build/test-requirements.txt -nvidia-cuda-nvcc-cu12==12.8.61 ; sys_platform == "linux" \ - --hash=sha256:171f605044ba17bc455d19cad289946c3dbea029a90c60dfa7b88e545bc8e329 \ - --hash=sha256:28604ec42aaa09035b0fb7111432e5121bc385580b30c55d2acfb7d644b16548 \ - --hash=sha256:4524739cfc080e9c9e53032912be8f020058e0a7186746d19acef3b6d916ea0b - # via -r build/test-requirements.txt -nvidia-cuda-runtime-cu12==12.8.57 ; sys_platform == "linux" \ - --hash=sha256:534ccebd967b6a44292678fa5da4f00666029cb2ed07a79515ea41ef31fe3ec7 \ - --hash=sha256:75342e28567340b7428ce79a5d6bb6ca5ff9d07b69e7ce00d2c7b4dc23eff0be \ - --hash=sha256:89be637e3ee967323865b85e0f147d75f9a5bd98360befa37481b02dd57af8f5 - # via -r build/test-requirements.txt -nvidia-cudnn-cu12==9.7.1.26 ; sys_platform == "linux" \ - --hash=sha256:6d011159a158f3cfc47bf851aea79e31bcff60d530b70ef70474c84cac484d07 \ - --hash=sha256:7b805b9a4cf9f3da7c5f4ea4a9dff7baf62d1a612d6154a7e0d2ea51ed296241 \ - --hash=sha256:848a61d40ef3b32bd4e1fadb599f0cf04a4b942fbe5fb3be572ad75f9b8c53ef - # via -r build/test-requirements.txt -nvidia-cufft-cu12==11.3.3.41 ; sys_platform == "linux" \ - --hash=sha256:68509dcd7e3306e69d0e2d8a6d21c8b25ed62e6df8aac192ce752f17677398b5 \ - --hash=sha256:da650080ab79fcdf7a4b06aa1b460e99860646b176a43f6208099bdc17836b6a \ - --hash=sha256:f9760612886786601d27a0993bb29ce1f757e6b8b173499d0ecfa850d31b50f8 - # via -r build/test-requirements.txt -nvidia-cusolver-cu12==11.7.2.55 ; sys_platform == "linux" \ - --hash=sha256:0fd9e98246f43c15bee5561147ad235dfdf2d037f5d07c9d41af3f7f72feb7cc \ - --hash=sha256:4d1354102f1e922cee9db51920dba9e2559877cf6ff5ad03a00d853adafb191b \ - --hash=sha256:a5a516c55da5c5aba98420d9bc9bcab18245f21ec87338cc1f930eb18dd411ac - # via -r build/test-requirements.txt -nvidia-cusparse-cu12==12.5.7.53 ; sys_platform == "linux" \ - --hash=sha256:3c1b61eb8c85257ea07e9354606b26397612627fdcd327bfd91ccf6155e7c86d \ - 
--hash=sha256:82c201d6781bacf6bb7c654f0446728d0fe596dfdd82ef4a04c204ce3e107441 \ - --hash=sha256:d869c6146ca80f4305b62e02d924b4aaced936f8173e3cef536a67eed2a91af1 - # via - # via -r build/test-requirements.txt - # nvidia-cusolver-cu12 -nvidia-nccl-cu12==2.25.1 ; sys_platform == "linux" \ - --hash=sha256:362aed5963fb9ea2ed2f264409baae30143498fd0e5c503aeaa1badd88cdc54a \ - --hash=sha256:4ab428bc915785cc66e8c57cb34c7a64cf739c46702b8db748b6ad6cc7180cf8 - # via -r build/test-requirements.txt -nvidia-nvjitlink-cu12==12.8.61 ; sys_platform == "linux" \ - --hash=sha256:1166a964d25fdc0eae497574d38824305195a5283324a21ccb0ce0c802cbf41c \ - --hash=sha256:45fd79f2ae20bd67e8bc411055939049873bfd8fac70ff13bd4865e0b9bdab17 \ - --hash=sha256:9b80ecab31085dda3ce3b41d043be0ec739216c3fc633b8abe212d5a30026df0 - # via - # via -r build/test-requirements.txt - # nvidia-cufft-cu12 - # nvidia-cusolver-cu12 - # nvidia-cusparse-cu12 -opt-einsum==3.3.0 \ - --hash=sha256:2455e59e3947d3c275477df7f5205b30635e266fe6dc300e3d9f9646bfcea147 \ - --hash=sha256:59f6475f77bbc37dcf7cd748519c0ec60722e91e63ca114e68821c0c54a46549 - # via - # -r build/requirements.in - # -r build/test-requirements.txt -packaging==24.0 \ - --hash=sha256:2ddfb553fdf02fb784c234c7ba6ccc288296ceabec964ad2eae3777778130bc5 \ - --hash=sha256:eb82c5e3e56209074766e6885bb04b8c38a0c015d0a30036ebe7ece34c9989e9 - # via - # auditwheel - # build - # matplotlib - # pytest -pillow==11.0.0 \ - --hash=sha256:00177a63030d612148e659b55ba99527803288cea7c75fb05766ab7981a8c1b7 \ - --hash=sha256:006bcdd307cc47ba43e924099a038cbf9591062e6c50e570819743f5607404f5 \ - --hash=sha256:084a07ef0821cfe4858fe86652fffac8e187b6ae677e9906e192aafcc1b69903 \ - --hash=sha256:0ae08bd8ffc41aebf578c2af2f9d8749d91f448b3bfd41d7d9ff573d74f2a6b2 \ - --hash=sha256:0e038b0745997c7dcaae350d35859c9715c71e92ffb7e0f4a8e8a16732150f38 \ - --hash=sha256:1187739620f2b365de756ce086fdb3604573337cc28a0d3ac4a01ab6b2d2a6d2 \ - --hash=sha256:16095692a253047fe3ec028e951fa4221a1f3ed3d80c397e83541a3037ff67c9 \ - --hash=sha256:1a61b54f87ab5786b8479f81c4b11f4d61702830354520837f8cc791ebba0f5f \ - --hash=sha256:1c1d72714f429a521d8d2d018badc42414c3077eb187a59579f28e4270b4b0fc \ - --hash=sha256:1e2688958a840c822279fda0086fec1fdab2f95bf2b717b66871c4ad9859d7e8 \ - --hash=sha256:20ec184af98a121fb2da42642dea8a29ec80fc3efbaefb86d8fdd2606619045d \ - --hash=sha256:21a0d3b115009ebb8ac3d2ebec5c2982cc693da935f4ab7bb5c8ebe2f47d36f2 \ - --hash=sha256:224aaa38177597bb179f3ec87eeefcce8e4f85e608025e9cfac60de237ba6316 \ - --hash=sha256:2679d2258b7f1192b378e2893a8a0a0ca472234d4c2c0e6bdd3380e8dfa21b6a \ - --hash=sha256:27a7860107500d813fcd203b4ea19b04babe79448268403172782754870dac25 \ - --hash=sha256:290f2cc809f9da7d6d622550bbf4c1e57518212da51b6a30fe8e0a270a5b78bd \ - --hash=sha256:2e46773dc9f35a1dd28bd6981332fd7f27bec001a918a72a79b4133cf5291dba \ - --hash=sha256:3107c66e43bda25359d5ef446f59c497de2b5ed4c7fdba0894f8d6cf3822dafc \ - --hash=sha256:375b8dd15a1f5d2feafff536d47e22f69625c1aa92f12b339ec0b2ca40263273 \ - --hash=sha256:45c566eb10b8967d71bf1ab8e4a525e5a93519e29ea071459ce517f6b903d7fa \ - --hash=sha256:499c3a1b0d6fc8213519e193796eb1a86a1be4b1877d678b30f83fd979811d1a \ - --hash=sha256:4ad70c4214f67d7466bea6a08061eba35c01b1b89eaa098040a35272a8efb22b \ - --hash=sha256:4b60c9520f7207aaf2e1d94de026682fc227806c6e1f55bba7606d1c94dd623a \ - --hash=sha256:5178952973e588b3f1360868847334e9e3bf49d19e169bbbdfaf8398002419ae \ - --hash=sha256:52a2d8323a465f84faaba5236567d212c3668f2ab53e1c74c15583cf507a0291 \ - 
--hash=sha256:598b4e238f13276e0008299bd2482003f48158e2b11826862b1eb2ad7c768b97 \ - --hash=sha256:5bd2d3bdb846d757055910f0a59792d33b555800813c3b39ada1829c372ccb06 \ - --hash=sha256:5c39ed17edea3bc69c743a8dd3e9853b7509625c2462532e62baa0732163a904 \ - --hash=sha256:5d203af30149ae339ad1b4f710d9844ed8796e97fda23ffbc4cc472968a47d0b \ - --hash=sha256:5ddbfd761ee00c12ee1be86c9c0683ecf5bb14c9772ddbd782085779a63dd55b \ - --hash=sha256:607bbe123c74e272e381a8d1957083a9463401f7bd01287f50521ecb05a313f8 \ - --hash=sha256:61b887f9ddba63ddf62fd02a3ba7add935d053b6dd7d58998c630e6dbade8527 \ - --hash=sha256:6619654954dc4936fcff82db8eb6401d3159ec6be81e33c6000dfd76ae189947 \ - --hash=sha256:674629ff60030d144b7bca2b8330225a9b11c482ed408813924619c6f302fdbb \ - --hash=sha256:6ec0d5af64f2e3d64a165f490d96368bb5dea8b8f9ad04487f9ab60dc4bb6003 \ - --hash=sha256:6f4dba50cfa56f910241eb7f883c20f1e7b1d8f7d91c750cd0b318bad443f4d5 \ - --hash=sha256:70fbbdacd1d271b77b7721fe3cdd2d537bbbd75d29e6300c672ec6bb38d9672f \ - --hash=sha256:72bacbaf24ac003fea9bff9837d1eedb6088758d41e100c1552930151f677739 \ - --hash=sha256:7326a1787e3c7b0429659e0a944725e1b03eeaa10edd945a86dead1913383944 \ - --hash=sha256:73853108f56df97baf2bb8b522f3578221e56f646ba345a372c78326710d3830 \ - --hash=sha256:73e3a0200cdda995c7e43dd47436c1548f87a30bb27fb871f352a22ab8dcf45f \ - --hash=sha256:75acbbeb05b86bc53cbe7b7e6fe00fbcf82ad7c684b3ad82e3d711da9ba287d3 \ - --hash=sha256:8069c5179902dcdce0be9bfc8235347fdbac249d23bd90514b7a47a72d9fecf4 \ - --hash=sha256:846e193e103b41e984ac921b335df59195356ce3f71dcfd155aa79c603873b84 \ - --hash=sha256:8594f42df584e5b4bb9281799698403f7af489fba84c34d53d1c4bfb71b7c4e7 \ - --hash=sha256:86510e3f5eca0ab87429dd77fafc04693195eec7fd6a137c389c3eeb4cfb77c6 \ - --hash=sha256:8853a3bf12afddfdf15f57c4b02d7ded92c7a75a5d7331d19f4f9572a89c17e6 \ - --hash=sha256:88a58d8ac0cc0e7f3a014509f0455248a76629ca9b604eca7dc5927cc593c5e9 \ - --hash=sha256:8ba470552b48e5835f1d23ecb936bb7f71d206f9dfeee64245f30c3270b994de \ - --hash=sha256:8c676b587da5673d3c75bd67dd2a8cdfeb282ca38a30f37950511766b26858c4 \ - --hash=sha256:8ec4a89295cd6cd4d1058a5e6aec6bf51e0eaaf9714774e1bfac7cfc9051db47 \ - --hash=sha256:94f3e1780abb45062287b4614a5bc0874519c86a777d4a7ad34978e86428b8dd \ - --hash=sha256:9a0f748eaa434a41fccf8e1ee7a3eed68af1b690e75328fd7a60af123c193b50 \ - --hash=sha256:a5629742881bcbc1f42e840af185fd4d83a5edeb96475a575f4da50d6ede337c \ - --hash=sha256:a65149d8ada1055029fcb665452b2814fe7d7082fcb0c5bed6db851cb69b2086 \ - --hash=sha256:b3c5ac4bed7519088103d9450a1107f76308ecf91d6dabc8a33a2fcfb18d0fba \ - --hash=sha256:b4fd7bd29610a83a8c9b564d457cf5bd92b4e11e79a4ee4716a63c959699b306 \ - --hash=sha256:bcd1fb5bb7b07f64c15618c89efcc2cfa3e95f0e3bcdbaf4642509de1942a699 \ - --hash=sha256:c12b5ae868897c7338519c03049a806af85b9b8c237b7d675b8c5e089e4a618e \ - --hash=sha256:c26845094b1af3c91852745ae78e3ea47abf3dbcd1cf962f16b9a5fbe3ee8488 \ - --hash=sha256:c6a660307ca9d4867caa8d9ca2c2658ab685de83792d1876274991adec7b93fa \ - --hash=sha256:c809a70e43c7977c4a42aefd62f0131823ebf7dd73556fa5d5950f5b354087e2 \ - --hash=sha256:c8b2351c85d855293a299038e1f89db92a2f35e8d2f783489c6f0b2b5f3fe8a3 \ - --hash=sha256:cb929ca942d0ec4fac404cbf520ee6cac37bf35be479b970c4ffadf2b6a1cad9 \ - --hash=sha256:d2c0a187a92a1cb5ef2c8ed5412dd8d4334272617f532d4ad4de31e0495bd923 \ - --hash=sha256:d69bfd8ec3219ae71bcde1f942b728903cad25fafe3100ba2258b973bd2bc1b2 \ - --hash=sha256:daffdf51ee5db69a82dd127eabecce20729e21f7a3680cf7cbb23f0829189790 \ - 
--hash=sha256:e58876c91f97b0952eb766123bfef372792ab3f4e3e1f1a2267834c2ab131734 \ - --hash=sha256:eda2616eb2313cbb3eebbe51f19362eb434b18e3bb599466a1ffa76a033fb916 \ - --hash=sha256:ee217c198f2e41f184f3869f3e485557296d505b5195c513b2bfe0062dc537f1 \ - --hash=sha256:f02541ef64077f22bf4924f225c0fd1248c168f86e4b7abdedd87d6ebaceab0f \ - --hash=sha256:f1b82c27e89fffc6da125d5eb0ca6e68017faf5efc078128cfaa42cf5cb38798 \ - --hash=sha256:fba162b8872d30fea8c52b258a542c5dfd7b235fb5cb352240c8d63b414013eb \ - --hash=sha256:fbbcb7b57dc9c794843e3d1258c0fbf0f48656d46ffe9e09b63bbd6e8cd5d0a2 \ - --hash=sha256:fcb4621042ac4b7865c179bb972ed0da0218a076dc1820ffc48b1d74c1e37fe9 - # via - # -r build/test-requirements.txt - # matplotlib -pluggy==1.5.0 \ - --hash=sha256:2cffa88e94fdc978c4c574f15f9e59b7f4201d439195c3715ca9e2486f1d0cf1 \ - --hash=sha256:44e1ad92c8ca002de6377e165f3e0f1be63266ab4d554740532335b9d75ea669 - # via pytest -portpicker==1.6.0 \ - --hash=sha256:b2787a41404cf7edbe29b07b9e0ed863b09f2665dcc01c1eb0c2261c1e7d0755 \ - --hash=sha256:bd507fd6f96f65ee02781f2e674e9dc6c99bbfa6e3c39992e3916204c9d431fa - # via -r build/test-requirements.txt -psutil==5.9.8 \ - --hash=sha256:02615ed8c5ea222323408ceba16c60e99c3f91639b07da6373fb7e6539abc56d \ - --hash=sha256:05806de88103b25903dff19bb6692bd2e714ccf9e668d050d144012055cbca73 \ - --hash=sha256:26bd09967ae00920df88e0352a91cff1a78f8d69b3ecabbfe733610c0af486c8 \ - --hash=sha256:27cc40c3493bb10de1be4b3f07cae4c010ce715290a5be22b98493509c6299e2 \ - --hash=sha256:36f435891adb138ed3c9e58c6af3e2e6ca9ac2f365efe1f9cfef2794e6c93b4e \ - --hash=sha256:50187900d73c1381ba1454cf40308c2bf6f34268518b3f36a9b663ca87e65e36 \ - --hash=sha256:611052c4bc70432ec770d5d54f64206aa7203a101ec273a0cd82418c86503bb7 \ - --hash=sha256:6be126e3225486dff286a8fb9a06246a5253f4c7c53b475ea5f5ac934e64194c \ - --hash=sha256:7d79560ad97af658a0f6adfef8b834b53f64746d45b403f225b85c5c2c140eee \ - --hash=sha256:8cb6403ce6d8e047495a701dc7c5bd788add903f8986d523e3e20b98b733e421 \ - --hash=sha256:8db4c1b57507eef143a15a6884ca10f7c73876cdf5d51e713151c1236a0e68cf \ - --hash=sha256:aee678c8720623dc456fa20659af736241f575d79429a0e5e9cf88ae0605cc81 \ - --hash=sha256:bc56c2a1b0d15aa3eaa5a60c9f3f8e3e565303b465dbf57a1b730e7a2b9844e0 \ - --hash=sha256:bd1184ceb3f87651a67b2708d4c3338e9b10c5df903f2e3776b62303b26cb631 \ - --hash=sha256:d06016f7f8625a1825ba3732081d77c94589dca78b7a3fc072194851e88461a4 \ - --hash=sha256:d16bbddf0693323b8c6123dd804100241da461e41d6e332fb0ba6058f630f8c8 - # via portpicker -pyelftools==0.31 \ - --hash=sha256:c774416b10310156879443b81187d182d8d9ee499660380e645918b50bc88f99 \ - --hash=sha256:f52de7b3c7e8c64c8abc04a79a1cf37ac5fb0b8a49809827130b858944840607 - # via auditwheel -pygments==2.18.0 \ - --hash=sha256:786ff802f32e91311bff3889f6e9a86e81505fe99f2735bb6d60ae0c5004f199 \ - --hash=sha256:b8e6aca0523f3ab76fee51799c488e38782ac06eafcf95e7ba832985c8e7b13a - # via rich -pyparsing==3.1.2 \ - --hash=sha256:a1bac0ce561155ecc3ed78ca94d3c9378656ad4c94c1270de543f621420f94ad \ - --hash=sha256:f9db75911801ed778fe61bb643079ff86601aca99fcae6345aa67292038fb742 - # via matplotlib -pyproject-hooks==1.1.0 \ - --hash=sha256:4b37730834edbd6bd37f26ece6b44802fb1c1ee2ece0e54ddff8bfc06db86965 \ - --hash=sha256:7ceeefe9aec63a1064c18d939bdc3adf2d8aa1988a510afec15151578b232aa2 - # via build -pytest==8.2.0 \ - --hash=sha256:1733f0620f6cda4095bbf0d9ff8022486e91892245bb9e7d5542c018f612f233 \ - --hash=sha256:d507d4482197eac0ba2bae2e9babf0672eb333017bcedaa5fb1a3d42c1174b3f - # via pytest-xdist -pytest-xdist==3.6.1 \ - 
--hash=sha256:9ed4adfb68a016610848639bb7e02c9352d5d9f03d04809919e2dafc3be4cca7 \ - --hash=sha256:ead156a4db231eec769737f57668ef58a2084a34b2e55c4a8fa20d861107300d - # via -r build/test-requirements.txt -python-dateutil==2.9.0.post0 \ - --hash=sha256:37dd54208da7e1cd875388217d5e00ebd4179249f90fb72437e91a35459a0ad3 \ - --hash=sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427 - # via matplotlib -rich==13.7.1 \ - --hash=sha256:4edbae314f59eb482f54e9e30bf00d33350aaa94f4bfcd4e9e3110e64d0d7222 \ - --hash=sha256:9be308cb1fe2f1f57d67ce99e95af38a1e2bc71ad9813b0e247cf7ffbcc3a432 - # via -r build/test-requirements.txt -scipy==1.13.1 \ - --hash=sha256:017367484ce5498445aade74b1d5ab377acdc65e27095155e448c88497755a5d \ - --hash=sha256:095a87a0312b08dfd6a6155cbbd310a8c51800fc931b8c0b84003014b874ed3c \ - --hash=sha256:20335853b85e9a49ff7572ab453794298bcf0354d8068c5f6775a0eabf350aca \ - --hash=sha256:27e52b09c0d3a1d5b63e1105f24177e544a222b43611aaf5bc44d4a0979e32f9 \ - --hash=sha256:2831f0dc9c5ea9edd6e51e6e769b655f08ec6db6e2e10f86ef39bd32eb11da54 \ - --hash=sha256:2ac65fb503dad64218c228e2dc2d0a0193f7904747db43014645ae139c8fad16 \ - --hash=sha256:392e4ec766654852c25ebad4f64e4e584cf19820b980bc04960bca0b0cd6eaa2 \ - --hash=sha256:436bbb42a94a8aeef855d755ce5a465479c721e9d684de76bf61a62e7c2b81d5 \ - --hash=sha256:45484bee6d65633752c490404513b9ef02475b4284c4cfab0ef946def50b3f59 \ - --hash=sha256:54f430b00f0133e2224c3ba42b805bfd0086fe488835effa33fa291561932326 \ - --hash=sha256:5713f62f781eebd8d597eb3f88b8bf9274e79eeabf63afb4a737abc6c84ad37b \ - --hash=sha256:5d72782f39716b2b3509cd7c33cdc08c96f2f4d2b06d51e52fb45a19ca0c86a1 \ - --hash=sha256:637e98dcf185ba7f8e663e122ebf908c4702420477ae52a04f9908707456ba4d \ - --hash=sha256:8335549ebbca860c52bf3d02f80784e91a004b71b059e3eea9678ba994796a24 \ - --hash=sha256:949ae67db5fa78a86e8fa644b9a6b07252f449dcf74247108c50e1d20d2b4627 \ - --hash=sha256:a014c2b3697bde71724244f63de2476925596c24285c7a637364761f8710891c \ - --hash=sha256:a78b4b3345f1b6f68a763c6e25c0c9a23a9fd0f39f5f3d200efe8feda560a5fa \ - --hash=sha256:cdd7dacfb95fea358916410ec61bbc20440f7860333aee6d882bb8046264e949 \ - --hash=sha256:cfa31f1def5c819b19ecc3a8b52d28ffdcc7ed52bb20c9a7589669dd3c250989 \ - --hash=sha256:d533654b7d221a6a97304ab63c41c96473ff04459e404b83275b60aa8f4b7004 \ - --hash=sha256:d605e9c23906d1994f55ace80e0125c587f96c020037ea6aa98d01b4bd2e222f \ - --hash=sha256:de3ade0e53bc1f21358aa74ff4830235d716211d7d077e340c7349bc3542e884 \ - --hash=sha256:e89369d27f9e7b0884ae559a3a956e77c02114cc60a6058b4e5011572eea9299 \ - --hash=sha256:eccfa1906eacc02de42d70ef4aecea45415f5be17e72b61bafcfd329bdc52e94 \ - --hash=sha256:f26264b282b9da0952a024ae34710c2aff7d27480ee91a2e82b7b7073c24722f - # via -r build/requirements.in -six==1.16.0 \ - --hash=sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926 \ - --hash=sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254 - # via python-dateutil -sortedcontainers==2.4.0 \ - --hash=sha256:25caa5a06cc30b6b83d11423433f65d1f9d76c4c6a0c90e3379eaa43b9bfdb88 \ - --hash=sha256:a163dcaede0f1c021485e957a39245190e74249897e2ae4b2aa38595db237ee0 - # via hypothesis -tomli==2.0.1 \ - --hash=sha256:939de3e7a6161af0c887ef91b7d41a53e7c5a1ca976325f429cb46ea9bc30ecc \ - --hash=sha256:de526c12914f0c550d15924c62d72abc48d6fe7364aa87328337a31007fe8a4f - # via - # build - # pytest -typing-extensions==4.12.0rc1 \ - --hash=sha256:be199d06d8f09ca2c9425e3aa04a9afba33e892fe079dea959e72df7f8442343 \ - 
--hash=sha256:f933a7b288a919ca97adbff656e52ff81f7ff25d98a2aabb9355ca4090f772fe - # via etils -wheel==0.43.0 \ - --hash=sha256:465ef92c69fa5c5da2d1cf8ac40559a8c940886afcef87dcf14b9470862f1d85 \ - --hash=sha256:55c570405f142630c6b9f72fe09d9b67cf1477fcf543ae5b8dcb1f5b7377da81 - # via -r build/test-requirements.txt -zipp==3.18.2 \ - --hash=sha256:6278d9ddbcfb1f1089a88fde84481528b07b0e10474e09dcfe53dad4069fa059 \ - --hash=sha256:dce197b859eb796242b0622af1b8beb0a722d52aa2f57133ead08edd5bf5374e - # via etils -zstandard==0.22.0 \ - --hash=sha256:11f0d1aab9516a497137b41e3d3ed4bbf7b2ee2abc79e5c8b010ad286d7464bd \ - --hash=sha256:1958100b8a1cc3f27fa21071a55cb2ed32e9e5df4c3c6e661c193437f171cba2 \ - --hash=sha256:1a90ba9a4c9c884bb876a14be2b1d216609385efb180393df40e5172e7ecf356 \ - --hash=sha256:1d43501f5f31e22baf822720d82b5547f8a08f5386a883b32584a185675c8fbf \ - --hash=sha256:23d2b3c2b8e7e5a6cb7922f7c27d73a9a615f0a5ab5d0e03dd533c477de23004 \ - --hash=sha256:2612e9bb4977381184bb2463150336d0f7e014d6bb5d4a370f9a372d21916f69 \ - --hash=sha256:275df437ab03f8c033b8a2c181e51716c32d831082d93ce48002a5227ec93019 \ - --hash=sha256:2ac9957bc6d2403c4772c890916bf181b2653640da98f32e04b96e4d6fb3252a \ - --hash=sha256:2b11ea433db22e720758cba584c9d661077121fcf60ab43351950ded20283440 \ - --hash=sha256:2fdd53b806786bd6112d97c1f1e7841e5e4daa06810ab4b284026a1a0e484c0b \ - --hash=sha256:33591d59f4956c9812f8063eff2e2c0065bc02050837f152574069f5f9f17775 \ - --hash=sha256:36a47636c3de227cd765e25a21dc5dace00539b82ddd99ee36abae38178eff9e \ - --hash=sha256:39b2853efc9403927f9065cc48c9980649462acbdf81cd4f0cb773af2fd734bc \ - --hash=sha256:3db41c5e49ef73641d5111554e1d1d3af106410a6c1fb52cf68912ba7a343a0d \ - --hash=sha256:445b47bc32de69d990ad0f34da0e20f535914623d1e506e74d6bc5c9dc40bb09 \ - --hash=sha256:466e6ad8caefb589ed281c076deb6f0cd330e8bc13c5035854ffb9c2014b118c \ - --hash=sha256:48f260e4c7294ef275744210a4010f116048e0c95857befb7462e033f09442fe \ - --hash=sha256:4ac59d5d6910b220141c1737b79d4a5aa9e57466e7469a012ed42ce2d3995e88 \ - --hash=sha256:53866a9d8ab363271c9e80c7c2e9441814961d47f88c9bc3b248142c32141d94 \ - --hash=sha256:589402548251056878d2e7c8859286eb91bd841af117dbe4ab000e6450987e08 \ - --hash=sha256:68953dc84b244b053c0d5f137a21ae8287ecf51b20872eccf8eaac0302d3e3b0 \ - --hash=sha256:6c25b8eb733d4e741246151d895dd0308137532737f337411160ff69ca24f93a \ - --hash=sha256:7034d381789f45576ec3f1fa0e15d741828146439228dc3f7c59856c5bcd3292 \ - --hash=sha256:73a1d6bd01961e9fd447162e137ed949c01bdb830dfca487c4a14e9742dccc93 \ - --hash=sha256:8226a33c542bcb54cd6bd0a366067b610b41713b64c9abec1bc4533d69f51e70 \ - --hash=sha256:888196c9c8893a1e8ff5e89b8f894e7f4f0e64a5af4d8f3c410f0319128bb2f8 \ - --hash=sha256:88c5b4b47a8a138338a07fc94e2ba3b1535f69247670abfe422de4e0b344aae2 \ - --hash=sha256:8a1b2effa96a5f019e72874969394edd393e2fbd6414a8208fea363a22803b45 \ - --hash=sha256:93e1856c8313bc688d5df069e106a4bc962eef3d13372020cc6e3ebf5e045202 \ - --hash=sha256:9501f36fac6b875c124243a379267d879262480bf85b1dbda61f5ad4d01b75a3 \ - --hash=sha256:959665072bd60f45c5b6b5d711f15bdefc9849dd5da9fb6c873e35f5d34d8cfb \ - --hash=sha256:a1d67d0d53d2a138f9e29d8acdabe11310c185e36f0a848efa104d4e40b808e4 \ - --hash=sha256:a493d470183ee620a3df1e6e55b3e4de8143c0ba1b16f3ded83208ea8ddfd91d \ - --hash=sha256:a7ccf5825fd71d4542c8ab28d4d482aace885f5ebe4b40faaa290eed8e095a4c \ - --hash=sha256:a88b7df61a292603e7cd662d92565d915796b094ffb3d206579aaebac6b85d5f \ - --hash=sha256:a97079b955b00b732c6f280d5023e0eefe359045e8b83b08cf0333af9ec78f26 \ - 
--hash=sha256:d22fdef58976457c65e2796e6730a3ea4a254f3ba83777ecfc8592ff8d77d303 \ - --hash=sha256:d75f693bb4e92c335e0645e8845e553cd09dc91616412d1d4650da835b5449df \ - --hash=sha256:d8593f8464fb64d58e8cb0b905b272d40184eac9a18d83cf8c10749c3eafcd7e \ - --hash=sha256:d8fff0f0c1d8bc5d866762ae95bd99d53282337af1be9dc0d88506b340e74b73 \ - --hash=sha256:de20a212ef3d00d609d0b22eb7cc798d5a69035e81839f549b538eff4105d01c \ - --hash=sha256:e9e9d4e2e336c529d4c435baad846a181e39a982f823f7e4495ec0b0ec8538d2 \ - --hash=sha256:f058a77ef0ece4e210bb0450e68408d4223f728b109764676e1a13537d056bb0 \ - --hash=sha256:f1a4b358947a65b94e2501ce3e078bbc929b039ede4679ddb0460829b12f7375 \ - --hash=sha256:f9b2cde1cd1b2a10246dbc143ba49d942d14fb3d2b4bccf4618d475c65464912 \ - --hash=sha256:fe3390c538f12437b859d815040763abc728955a52ca6ff9c5d4ac707c4ad98e - # via -r build/requirements.in - -# The following packages are considered to be unsafe in a requirements file: -setuptools==76.0.0 \ - --hash=sha256:199466a166ff664970d0ee145839f5582cb9bca7a0a3a2e795b6a9cb2308e9c6 \ - --hash=sha256:43b4ee60e10b0d0ee98ad11918e114c70701bc6051662a9a675a0496c1a158f4 - # via - # -r build/requirements.in - # -r build/test-requirements.txt diff --git a/build/requirements_lock_3_11.txt b/build/requirements_lock_3_11.txt index 8446e8361505..3560da350aa5 100644 --- a/build/requirements_lock_3_11.txt +++ b/build/requirements_lock_3_11.txt @@ -19,7 +19,7 @@ auditwheel==6.1.0 \ build==1.2.1 \ --hash=sha256:526263f4870c26f26c433545579475377b2b7588b6f1eac76a001e873ae3e19d \ --hash=sha256:75e10f767a433d9a86e50d83f418e83efc18ede923ee5ff7df93b6cb0306c5d4 - # via -r build/test-requirements.txt + # via -r build/requirements.in cloudpickle==3.0.0 \ --hash=sha256:246ee7d0c295602a036e86369c77fecda4ab17b506496730f2f576d9016fd9c7 \ --hash=sha256:996d9a482c6fb4f33c1a35335cf8afd065d2a56e973270364840712d9131a882 @@ -27,7 +27,7 @@ cloudpickle==3.0.0 \ colorama==0.4.6 \ --hash=sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44 \ --hash=sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6 - # via -r build/test-requirements.txt + # via -r build/requirements.in contourpy==1.2.1 \ --hash=sha256:00e5388f71c1a0610e6fe56b5c44ab7ba14165cdd6d695429c5cd94021e390b2 \ --hash=sha256:10a37ae557aabf2509c79715cd20b62e4c7c28b8cd62dd7d99e5ed3ce28c3fd9 \ @@ -154,6 +154,44 @@ iniconfig==2.0.0 \ --hash=sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3 \ --hash=sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374 # via pytest +jax-cuda12-pjrt==0.6.2 ; sys_platform == "linux" \ + --hash=sha256:22faf020d2e8f7ca1e2915633241f7df7678b73c7078f5f0b2f113248337f7de \ + --hash=sha256:8cd9ead7948ea2c778a508fef5d1159e8b7abf4fccc7037c3fe1dbfcd95012dc + # via + # -r build/requirements.in + # jax-cuda12-plugin +jax-cuda12-plugin[with-cuda]==0.6.2 ; sys_platform == "linux" \ + --hash=sha256:0896cbb308d95291e205cd89d254029dee3a1df43d66e9831331a9afd2d27870 \ + --hash=sha256:1751f88989269b3cdb0dfe4f7b072a6442149818c9bc98c3a395c8acaf910a79 \ + --hash=sha256:2cd8e279a59a38ba0c978a831e13adeb6ee9e4572fba387c7975ba3ad535dd38 \ + --hash=sha256:6c9b002d13b1fcb9403713eedd3876a227ad1ffbdfb3811b1f9f89af4c25a5f7 \ + --hash=sha256:773efa8b55a837406c561f0ef02144dda9019181193760ec5419eec9dd2b9aac \ + --hash=sha256:83345f52f610cdb8e90044566d8e120864150b8090968c8ab6dd8e0bfb9a6a9f \ + --hash=sha256:bc5c3a75d05519b4d326e4669d0f7ad0fe0f0acf875f9313d913748ccca5a9ea \ + --hash=sha256:db4c6103c912d8cd1adf94c34d313bb4760ca7f01c897ca7cd62e65f27994199 \ + 
--hash=sha256:ed5316ca1818db7ef53230ee0a41398d3a60942e361dfb857a952eb4d92fc8d7 \ + --hash=sha256:febd099f970d350eb8fa5a2c9a2fb4b0ea7b3d6a89df1496663edfa7afe590e5 + # via -r build/requirements.in +jaxlib==0.6.2 \ + --hash=sha256:11eae7e05bc5a79875da36324afb9eddd4baeaef2a0386caf6d4f3720b9aef28 \ + --hash=sha256:153eaa51f778b60851720729d4f461a91edd9ba3932f6f3bc598d4413870038b \ + --hash=sha256:335d7e3515ce78b52a410136f46aa4a7ea14d0e7d640f34e1e137409554ad0ac \ + --hash=sha256:34d8a684a8be949dd87dd4acc97101b4106a0dc9ad151ec891da072319a57b99 \ + --hash=sha256:39cf9555f85ae1ce2e2c1a59fc71f2eca4f9867a7cb934fef881ba56b11371d1 \ + --hash=sha256:3abd536e44b05fb1657507e3ff1fc3691f99613bae3921ecab9e82f27255f784 \ + --hash=sha256:4205d098ce8efb5f7fe2fe5098bae6036094dc8d8829f5e0e0d7a9b155326336 \ + --hash=sha256:70498837caf538bd458ff6858c8bfd404db82015aba8f663670197fa9900ff02 \ + --hash=sha256:87ec2dc9c3ed9ab936eec8535160c5fbd2c849948559f1c5daa75f63fabe5942 \ + --hash=sha256:921dbd4db214eba19a29ba9f2450d880e08b2b2c7b968f28cc89da3e62366af4 \ + --hash=sha256:a208ff61c58128d306bb4e5ad0858bd2b0960f2c1c10ad42c548f74a60c0020e \ + --hash=sha256:b977604cd36c74b174d25ed685017379468138eb747d865f75e466cb273c801d \ + --hash=sha256:bff67b188133ce1f0111c7b163ac321fd646b59ed221ea489063e2e0f85cb967 \ + --hash=sha256:c087a0eb6fb7f6f8f54d56f4730328dfde5040dd3b5ddfa810e7c28ea7102b42 \ + --hash=sha256:c6815509997d6b05e5c9daa7994b9ad473ce3e8c8a17bdbbcacc3c744f76f7a0 \ + --hash=sha256:da4601b2b5dc8c23d6afb293eacfb9aec4e1d1871cb2f29c5a151d103e73b0f8 \ + --hash=sha256:f1dd09b481a93c1d4c750013f467f74194493ba7bd29fcd4d1cec16e3a214f65 \ + --hash=sha256:f94163f14c8fd3ba93ae14b631abacf14cb031bba0b59138869984b4d10375f8 + # via -r build/requirements.in kiwisolver==1.4.5 \ --hash=sha256:00bd361b903dc4bbf4eb165f24d1acbee754fce22ded24c3d56eec268658a5cf \ --hash=sha256:040c1aebeda72197ef477a906782b5ab0d387642e93bda547336b8957c61022e \ @@ -260,11 +298,14 @@ kiwisolver==1.4.5 \ --hash=sha256:fd32ea360bcbb92d28933fc05ed09bffcb1704ba3fc7942e81db0fd4f81a7892 \ --hash=sha256:fdb7adb641a0d13bdcd4ef48e062363d8a9ad4a182ac7647ec88f695e719ae9f # via matplotlib +libtpu==0.0.13 ; sys_platform == "linux" and platform_machine == "x86_64" \ + --hash=sha256:2b4fcd3b902433ef2c22760a3a13b1474491bb4daf88a2670c6c72b295ebe750 + # via -r build/requirements.in markdown-it-py==3.0.0 \ --hash=sha256:355216845c60bd96232cd8d8c40e8f9765cc86f46880e43a8fd22dc1a1a8cab1 \ --hash=sha256:e3f60a94fa066dc52ec76661e37c851cb232d92f9886b15cb560aaada2df8feb # via rich -matplotlib==3.9.0 ; python_version >= "3.11" \ +matplotlib==3.9.0 \ --hash=sha256:063af8587fceeac13b0936c42a2b6c732c2ab1c98d38abc3337e430e1ff75e38 \ --hash=sha256:06a478f0d67636554fa78558cfbcd7b9dba85b51f5c3b5a0c9be49010cf5f321 \ --hash=sha256:0a490715b3b9984fa609116481b22178348c1a220a4499cda79132000a79b4db \ @@ -324,7 +365,10 @@ ml-dtypes==0.5.1 \ --hash=sha256:c9945669d3dadf8acb40ec2e57d38c985d8c285ea73af57fc5b09872c516106d \ --hash=sha256:d13755f8e8445b3870114e5b6240facaa7cb0c3361e54beba3e07fa912a6e12b \ --hash=sha256:fd918d4e6a4e0c110e2e05be7a7814d10dc1b95872accbf6512b80a109b71ae1 - # via -r build/requirements.in + # via + # -r build/requirements.in + # jaxlib + # tensorstore mpmath==1.4.0a1 \ --hash=sha256:78884400f439f500fa76be0121a8f9598313d87664863a192e1185ddbd7ae97f \ --hash=sha256:f8b7b5a3a1726ab6e8c898eb2157426b82c482ab1ab8ffed9f88bb9e07c6e9c1 @@ -376,76 +420,89 @@ numpy==2.0.0 ; python_version <= "3.12" \ --hash=sha256:fbd6acc766814ea6443628f4e6751d0da6593dae29c08c0b2606164db026970c \ 
--hash=sha256:feff59f27338135776f6d4e2ec7aeeac5d5f7a08a83e80869121ef8164b74af9 # via - # -r build/requirements.in + # -r build/nonfreethreading-requirements.txt # contourpy + # jaxlib # matplotlib # ml-dtypes # opt-einsum # scipy -nvidia-cublas-cu12==12.8.3.14 ; sys_platform == "linux" \ + # tensorstore +nvidia-cublas-cu12==12.8.3.14 \ --hash=sha256:3f0e05e7293598cf61933258b73e66a160c27d59c4422670bf0b79348c04be44 \ --hash=sha256:93a4e0e386cc7f6e56c822531396de8170ed17068a1e18f987574895044cd8c3 \ --hash=sha256:9ae5eae500aead01fc4bdfc458209df638b1a3551557ce11a78eea9ece602ae9 # via - # -r build/test-requirements.txt + # jax-cuda12-plugin # nvidia-cudnn-cu12 # nvidia-cusolver-cu12 -nvidia-cuda-cupti-cu12==12.8.57 ; sys_platform == "linux" \ +nvidia-cuda-cupti-cu12==12.8.57 \ --hash=sha256:8e0b2eb847de260739bee4a3f66fac31378f4ff49538ff527a38a01a9a39f950 \ --hash=sha256:bbed719c52a476958a74cfc42f2b95a3fd6b3fd94eb40134acc4601feb4acac3 \ --hash=sha256:ff154211724fd824e758ce176b66007b558eea19c9a5135fc991827ee147e317 - # via -r build/test-requirements.txt -nvidia-cuda-nvcc-cu12==12.8.61 ; sys_platform == "linux" \ + # via jax-cuda12-plugin +nvidia-cuda-nvcc-cu12==12.8.61 \ --hash=sha256:171f605044ba17bc455d19cad289946c3dbea029a90c60dfa7b88e545bc8e329 \ --hash=sha256:28604ec42aaa09035b0fb7111432e5121bc385580b30c55d2acfb7d644b16548 \ --hash=sha256:4524739cfc080e9c9e53032912be8f020058e0a7186746d19acef3b6d916ea0b - # via -r build/test-requirements.txt -nvidia-cuda-runtime-cu12==12.8.57 ; sys_platform == "linux" \ + # via jax-cuda12-plugin +nvidia-cuda-nvrtc-cu12==12.9.86 ; sys_platform == "linux" \ + --hash=sha256:096d4de6bda726415dfaf3198d4f5c522b8e70139c97feef5cd2ca6d4cd9cead \ + --hash=sha256:210cf05005a447e29214e9ce50851e83fc5f4358df8b453155d5e1918094dcb4 \ + --hash=sha256:72972ebdcf504d69462d3bcd67e7b81edd25d0fb85a2c46d3ea3517666636349 + # via + # -r build/requirements.in + # jax-cuda12-plugin +nvidia-cuda-runtime-cu12==12.8.57 \ --hash=sha256:534ccebd967b6a44292678fa5da4f00666029cb2ed07a79515ea41ef31fe3ec7 \ --hash=sha256:75342e28567340b7428ce79a5d6bb6ca5ff9d07b69e7ce00d2c7b4dc23eff0be \ --hash=sha256:89be637e3ee967323865b85e0f147d75f9a5bd98360befa37481b02dd57af8f5 - # via -r build/test-requirements.txt -nvidia-cudnn-cu12==9.7.1.26 ; sys_platform == "linux" \ - --hash=sha256:6d011159a158f3cfc47bf851aea79e31bcff60d530b70ef70474c84cac484d07 \ - --hash=sha256:7b805b9a4cf9f3da7c5f4ea4a9dff7baf62d1a612d6154a7e0d2ea51ed296241 \ - --hash=sha256:848a61d40ef3b32bd4e1fadb599f0cf04a4b942fbe5fb3be572ad75f9b8c53ef - # via -r build/test-requirements.txt -nvidia-cufft-cu12==11.3.3.41 ; sys_platform == "linux" \ + # via jax-cuda12-plugin +nvidia-cudnn-cu12==9.8.0.87 \ + --hash=sha256:b4b5cfddc32aa4180f9d390ee99e9a9f55a89e7087329b41aba4319327e22466 \ + --hash=sha256:b883faeb2f6f15dba7bbb6756eab6a0d9cecb59db5b0fa07577b9cfa24cd99f4 \ + --hash=sha256:d6b02cd0e3e24aa31d0193a8c39fec239354360d7d81055edddb69f35d53a4c8 + # via jax-cuda12-plugin +nvidia-cufft-cu12==11.3.3.41 \ --hash=sha256:68509dcd7e3306e69d0e2d8a6d21c8b25ed62e6df8aac192ce752f17677398b5 \ --hash=sha256:da650080ab79fcdf7a4b06aa1b460e99860646b176a43f6208099bdc17836b6a \ --hash=sha256:f9760612886786601d27a0993bb29ce1f757e6b8b173499d0ecfa850d31b50f8 - # via -r build/test-requirements.txt -nvidia-cusolver-cu12==11.7.2.55 ; sys_platform == "linux" \ + # via jax-cuda12-plugin +nvidia-cusolver-cu12==11.7.2.55 \ --hash=sha256:0fd9e98246f43c15bee5561147ad235dfdf2d037f5d07c9d41af3f7f72feb7cc \ 
--hash=sha256:4d1354102f1e922cee9db51920dba9e2559877cf6ff5ad03a00d853adafb191b \ --hash=sha256:a5a516c55da5c5aba98420d9bc9bcab18245f21ec87338cc1f930eb18dd411ac - # via -r build/test-requirements.txt -nvidia-cusparse-cu12==12.5.7.53 ; sys_platform == "linux" \ + # via jax-cuda12-plugin +nvidia-cusparse-cu12==12.5.7.53 \ --hash=sha256:3c1b61eb8c85257ea07e9354606b26397612627fdcd327bfd91ccf6155e7c86d \ --hash=sha256:82c201d6781bacf6bb7c654f0446728d0fe596dfdd82ef4a04c204ce3e107441 \ --hash=sha256:d869c6146ca80f4305b62e02d924b4aaced936f8173e3cef536a67eed2a91af1 # via - # -r build/test-requirements.txt + # jax-cuda12-plugin # nvidia-cusolver-cu12 -nvidia-nccl-cu12==2.25.1 ; sys_platform == "linux" \ +nvidia-nccl-cu12==2.25.1 \ --hash=sha256:362aed5963fb9ea2ed2f264409baae30143498fd0e5c503aeaa1badd88cdc54a \ --hash=sha256:4ab428bc915785cc66e8c57cb34c7a64cf739c46702b8db748b6ad6cc7180cf8 - # via -r build/test-requirements.txt -nvidia-nvjitlink-cu12==12.8.61 ; sys_platform == "linux" \ + # via jax-cuda12-plugin +nvidia-nvjitlink-cu12==12.8.61 \ --hash=sha256:1166a964d25fdc0eae497574d38824305195a5283324a21ccb0ce0c802cbf41c \ --hash=sha256:45fd79f2ae20bd67e8bc411055939049873bfd8fac70ff13bd4865e0b9bdab17 \ --hash=sha256:9b80ecab31085dda3ce3b41d043be0ec739216c3fc633b8abe212d5a30026df0 # via - # -r build/test-requirements.txt + # jax-cuda12-plugin # nvidia-cufft-cu12 # nvidia-cusolver-cu12 # nvidia-cusparse-cu12 +nvidia-nvshmem-cu12==3.2.5 ; sys_platform == "linux" \ + --hash=sha256:2f5798d65f1a08f9878aae17cf4d3dcbfe884d1f12cf170556cd40f2be90ca96 \ + --hash=sha256:e076957d5cc72e51061a04f2d46f55df477be53e8a55d0d621be08f7aefe1d00 + # via + # -r build/requirements.in + # jax-cuda12-plugin opt-einsum==3.3.0 \ --hash=sha256:2455e59e3947d3c275477df7f5205b30635e266fe6dc300e3d9f9646bfcea147 \ --hash=sha256:59f6475f77bbc37dcf7cd748519c0ec60722e91e63ca114e68821c0c54a46549 - # via - # -r build/requirements.in - # -r build/test-requirements.txt + # via -r build/requirements.in packaging==24.0 \ --hash=sha256:2ddfb553fdf02fb784c234c7ba6ccc288296ceabec964ad2eae3777778130bc5 \ --hash=sha256:eb82c5e3e56209074766e6885bb04b8c38a0c015d0a30036ebe7ece34c9989e9 @@ -537,10 +594,12 @@ pluggy==1.5.0 \ --hash=sha256:2cffa88e94fdc978c4c574f15f9e59b7f4201d439195c3715ca9e2486f1d0cf1 \ --hash=sha256:44e1ad92c8ca002de6377e165f3e0f1be63266ab4d554740532335b9d75ea669 # via pytest -portpicker==1.6.0 \ +portpicker==1.6.0 ; python_version < "3.13" \ --hash=sha256:b2787a41404cf7edbe29b07b9e0ed863b09f2665dcc01c1eb0c2261c1e7d0755 \ --hash=sha256:bd507fd6f96f65ee02781f2e674e9dc6c99bbfa6e3c39992e3916204c9d431fa - # via -r build/test-requirements.txt + # via + # -r build/nonfreethreading-requirements.txt + # -r build/test-requirements.txt psutil==5.9.8 \ --hash=sha256:02615ed8c5ea222323408ceba16c60e99c3f91639b07da6373fb7e6539abc56d \ --hash=sha256:05806de88103b25903dff19bb6692bd2e714ccf9e668d050d144012055cbca73 \ @@ -591,7 +650,7 @@ rich==13.7.1 \ --hash=sha256:4edbae314f59eb482f54e9e30bf00d33350aaa94f4bfcd4e9e3110e64d0d7222 \ --hash=sha256:9be308cb1fe2f1f57d67ce99e95af38a1e2bc71ad9813b0e247cf7ffbcc3a432 # via -r build/test-requirements.txt -scipy==1.13.1 \ +scipy==1.13.1 ; python_version <= "3.12" \ --hash=sha256:017367484ce5498445aade74b1d5ab377acdc65e27095155e448c88497755a5d \ --hash=sha256:095a87a0312b08dfd6a6155cbbd310a8c51800fc931b8c0b84003014b874ed3c \ --hash=sha256:20335853b85e9a49ff7572ab453794298bcf0354d8068c5f6775a0eabf350aca \ @@ -617,7 +676,9 @@ scipy==1.13.1 \ 
--hash=sha256:e89369d27f9e7b0884ae559a3a956e77c02114cc60a6058b4e5011572eea9299 \ --hash=sha256:eccfa1906eacc02de42d70ef4aecea45415f5be17e72b61bafcfd329bdc52e94 \ --hash=sha256:f26264b282b9da0952a024ae34710c2aff7d27480ee91a2e82b7b7073c24722f - # via -r build/requirements.in + # via + # -r build/requirements.in + # jaxlib six==1.16.0 \ --hash=sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926 \ --hash=sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254 @@ -626,6 +687,29 @@ sortedcontainers==2.4.0 \ --hash=sha256:25caa5a06cc30b6b83d11423433f65d1f9d76c4c6a0c90e3379eaa43b9bfdb88 \ --hash=sha256:a163dcaede0f1c021485e957a39245190e74249897e2ae4b2aa38595db237ee0 # via hypothesis +tensorstore==0.1.73 \ + --hash=sha256:03cec5141a27d2e65e4ff604641cfb1f7989d66c361534392e810b80cbda617d \ + --hash=sha256:0429bf781ce3ed45be761b46f4bc5979412dadf063f509cb7e9581981a1e097b \ + --hash=sha256:05f7fdcb063f08f40f74c49f92c0f0136c5b715d49e111950bf025b12a72a907 \ + --hash=sha256:0eb83a2526e211a721842c3e98293e4bc9e1fdb9dac37ecf37d6ccbde84b8ee3 \ + --hash=sha256:192feb8a8fd0f37fa298588d037d4889d2f9d07b18b3295488f05ee268f57b70 \ + --hash=sha256:2aed43498b00d37df583da9e06328751cfe695bb166043aa9ef7183174cf7e29 \ + --hash=sha256:421a3f87864a0a8837b4f9f0c8ee86079b46b112de902496d3b90c72f51d02ea \ + --hash=sha256:440569458b91974e0ffa210654a01f2721758476c48240f7c925fc0d107056be \ + --hash=sha256:4433dcfcb943e100b90b0fc8e0b1d174e8c2c1cedb1fcc86e6d20b6a2e961831 \ + --hash=sha256:44d70dd0c000db8c0d2386e788c5e91d3b37ebee8f629f3848d7a012c85d1e11 \ + --hash=sha256:5fc9feab09de9e99c381145adeef5ff9e01f898e509b851ff2edd940c8b2384a \ + --hash=sha256:70d57b63706de4a3a9c1c217b338658fa160b2d41f5b399e6926f9eaf29b2a4d \ + --hash=sha256:7a812e8297a4ed70109057628b767c1a12b535f2db657635f0ed1517b23b990b \ + --hash=sha256:7b4e08bfa61880863bedb90499a23c63d9493cf9310207c230086b0a3700c75d \ + --hash=sha256:83c6ca5cb39ffeeb4a562942e3b9e2f32b026f362b2b7266c44201bd7c3116a5 \ + --hash=sha256:87fb7879af73a5b7ded9c9de3e2014baf6468d9d7c47edfc19490907b346e0a6 \ + --hash=sha256:a11d2e496d7442c68b35cd222a8c8df3fdee9e30fb2984c91546d81faff8bf61 \ + --hash=sha256:be3f5ef6f359486ee52785e8a302819152e51286c50181c6c35f316b7568ce60 \ + --hash=sha256:dd7fa6d7e9579a1a75e6185d7df10e28fcc7db2e14190ed60261a71b9c09e1df \ + --hash=sha256:e99ae99ac48f41c4e36b1e3717c6dbdab96dd27fc91618dd01afb9ad848a9293 \ + --hash=sha256:f24b325385fd30be612ab8494a29d3bfef37b9444357912ba184f30f325f093b + # via -r build/nonfreethreading-requirements.txt typing-extensions==4.12.0rc1 \ --hash=sha256:be199d06d8f09ca2c9425e3aa04a9afba33e892fe079dea959e72df7f8442343 \ --hash=sha256:f933a7b288a919ca97adbff656e52ff81f7ff25d98a2aabb9355ca4090f772fe @@ -633,7 +717,7 @@ typing-extensions==4.12.0rc1 \ wheel==0.43.0 \ --hash=sha256:465ef92c69fa5c5da2d1cf8ac40559a8c940886afcef87dcf14b9470862f1d85 \ --hash=sha256:55c570405f142630c6b9f72fe09d9b67cf1477fcf543ae5b8dcb1f5b7377da81 - # via -r build/test-requirements.txt + # via -r build/requirements.in zipp==3.18.2 \ --hash=sha256:6278d9ddbcfb1f1089a88fde84481528b07b0e10474e09dcfe53dad4069fa059 \ --hash=sha256:dce197b859eb796242b0622af1b8beb0a722d52aa2f57133ead08edd5bf5374e @@ -685,12 +769,10 @@ zstandard==0.22.0 \ --hash=sha256:f1a4b358947a65b94e2501ce3e078bbc929b039ede4679ddb0460829b12f7375 \ --hash=sha256:f9b2cde1cd1b2a10246dbc143ba49d942d14fb3d2b4bccf4618d475c65464912 \ --hash=sha256:fe3390c538f12437b859d815040763abc728955a52ca6ff9c5d4ac707c4ad98e - # via -r build/requirements.in + # via -r 
build/nonfreethreading-requirements.txt # The following packages are considered to be unsafe in a requirements file: setuptools==76.0.0 \ --hash=sha256:199466a166ff664970d0ee145839f5582cb9bca7a0a3a2e795b6a9cb2308e9c6 \ --hash=sha256:43b4ee60e10b0d0ee98ad11918e114c70701bc6051662a9a675a0496c1a158f4 - # via - # -r build/requirements.in - # -r build/test-requirements.txt + # via -r build/requirements.in diff --git a/build/requirements_lock_3_12.txt b/build/requirements_lock_3_12.txt index 0436ab6dd486..743fbbba325f 100644 --- a/build/requirements_lock_3_12.txt +++ b/build/requirements_lock_3_12.txt @@ -19,7 +19,7 @@ auditwheel==6.1.0 \ build==1.2.1 \ --hash=sha256:526263f4870c26f26c433545579475377b2b7588b6f1eac76a001e873ae3e19d \ --hash=sha256:75e10f767a433d9a86e50d83f418e83efc18ede923ee5ff7df93b6cb0306c5d4 - # via -r build/test-requirements.txt + # via -r build/requirements.in cloudpickle==3.0.0 \ --hash=sha256:246ee7d0c295602a036e86369c77fecda4ab17b506496730f2f576d9016fd9c7 \ --hash=sha256:996d9a482c6fb4f33c1a35335cf8afd065d2a56e973270364840712d9131a882 @@ -27,7 +27,7 @@ cloudpickle==3.0.0 \ colorama==0.4.6 \ --hash=sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44 \ --hash=sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6 - # via -r build/test-requirements.txt + # via -r build/requirements.in contourpy==1.2.1 \ --hash=sha256:00e5388f71c1a0610e6fe56b5c44ab7ba14165cdd6d695429c5cd94021e390b2 \ --hash=sha256:10a37ae557aabf2509c79715cd20b62e4c7c28b8cd62dd7d99e5ed3ce28c3fd9 \ @@ -154,6 +154,44 @@ iniconfig==2.0.0 \ --hash=sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3 \ --hash=sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374 # via pytest +jax-cuda12-pjrt==0.6.2 ; sys_platform == "linux" \ + --hash=sha256:22faf020d2e8f7ca1e2915633241f7df7678b73c7078f5f0b2f113248337f7de \ + --hash=sha256:8cd9ead7948ea2c778a508fef5d1159e8b7abf4fccc7037c3fe1dbfcd95012dc + # via + # -r build/requirements.in + # jax-cuda12-plugin +jax-cuda12-plugin[with-cuda]==0.6.2 ; sys_platform == "linux" \ + --hash=sha256:0896cbb308d95291e205cd89d254029dee3a1df43d66e9831331a9afd2d27870 \ + --hash=sha256:1751f88989269b3cdb0dfe4f7b072a6442149818c9bc98c3a395c8acaf910a79 \ + --hash=sha256:2cd8e279a59a38ba0c978a831e13adeb6ee9e4572fba387c7975ba3ad535dd38 \ + --hash=sha256:6c9b002d13b1fcb9403713eedd3876a227ad1ffbdfb3811b1f9f89af4c25a5f7 \ + --hash=sha256:773efa8b55a837406c561f0ef02144dda9019181193760ec5419eec9dd2b9aac \ + --hash=sha256:83345f52f610cdb8e90044566d8e120864150b8090968c8ab6dd8e0bfb9a6a9f \ + --hash=sha256:bc5c3a75d05519b4d326e4669d0f7ad0fe0f0acf875f9313d913748ccca5a9ea \ + --hash=sha256:db4c6103c912d8cd1adf94c34d313bb4760ca7f01c897ca7cd62e65f27994199 \ + --hash=sha256:ed5316ca1818db7ef53230ee0a41398d3a60942e361dfb857a952eb4d92fc8d7 \ + --hash=sha256:febd099f970d350eb8fa5a2c9a2fb4b0ea7b3d6a89df1496663edfa7afe590e5 + # via -r build/requirements.in +jaxlib==0.6.2 \ + --hash=sha256:11eae7e05bc5a79875da36324afb9eddd4baeaef2a0386caf6d4f3720b9aef28 \ + --hash=sha256:153eaa51f778b60851720729d4f461a91edd9ba3932f6f3bc598d4413870038b \ + --hash=sha256:335d7e3515ce78b52a410136f46aa4a7ea14d0e7d640f34e1e137409554ad0ac \ + --hash=sha256:34d8a684a8be949dd87dd4acc97101b4106a0dc9ad151ec891da072319a57b99 \ + --hash=sha256:39cf9555f85ae1ce2e2c1a59fc71f2eca4f9867a7cb934fef881ba56b11371d1 \ + --hash=sha256:3abd536e44b05fb1657507e3ff1fc3691f99613bae3921ecab9e82f27255f784 \ + --hash=sha256:4205d098ce8efb5f7fe2fe5098bae6036094dc8d8829f5e0e0d7a9b155326336 \ + 
--hash=sha256:70498837caf538bd458ff6858c8bfd404db82015aba8f663670197fa9900ff02 \ + --hash=sha256:87ec2dc9c3ed9ab936eec8535160c5fbd2c849948559f1c5daa75f63fabe5942 \ + --hash=sha256:921dbd4db214eba19a29ba9f2450d880e08b2b2c7b968f28cc89da3e62366af4 \ + --hash=sha256:a208ff61c58128d306bb4e5ad0858bd2b0960f2c1c10ad42c548f74a60c0020e \ + --hash=sha256:b977604cd36c74b174d25ed685017379468138eb747d865f75e466cb273c801d \ + --hash=sha256:bff67b188133ce1f0111c7b163ac321fd646b59ed221ea489063e2e0f85cb967 \ + --hash=sha256:c087a0eb6fb7f6f8f54d56f4730328dfde5040dd3b5ddfa810e7c28ea7102b42 \ + --hash=sha256:c6815509997d6b05e5c9daa7994b9ad473ce3e8c8a17bdbbcacc3c744f76f7a0 \ + --hash=sha256:da4601b2b5dc8c23d6afb293eacfb9aec4e1d1871cb2f29c5a151d103e73b0f8 \ + --hash=sha256:f1dd09b481a93c1d4c750013f467f74194493ba7bd29fcd4d1cec16e3a214f65 \ + --hash=sha256:f94163f14c8fd3ba93ae14b631abacf14cb031bba0b59138869984b4d10375f8 + # via -r build/requirements.in kiwisolver==1.4.5 \ --hash=sha256:00bd361b903dc4bbf4eb165f24d1acbee754fce22ded24c3d56eec268658a5cf \ --hash=sha256:040c1aebeda72197ef477a906782b5ab0d387642e93bda547336b8957c61022e \ @@ -260,11 +298,14 @@ kiwisolver==1.4.5 \ --hash=sha256:fd32ea360bcbb92d28933fc05ed09bffcb1704ba3fc7942e81db0fd4f81a7892 \ --hash=sha256:fdb7adb641a0d13bdcd4ef48e062363d8a9ad4a182ac7647ec88f695e719ae9f # via matplotlib +libtpu==0.0.13 ; sys_platform == "linux" and platform_machine == "x86_64" \ + --hash=sha256:2b4fcd3b902433ef2c22760a3a13b1474491bb4daf88a2670c6c72b295ebe750 + # via -r build/requirements.in markdown-it-py==3.0.0 \ --hash=sha256:355216845c60bd96232cd8d8c40e8f9765cc86f46880e43a8fd22dc1a1a8cab1 \ --hash=sha256:e3f60a94fa066dc52ec76661e37c851cb232d92f9886b15cb560aaada2df8feb # via rich -matplotlib==3.9.0 ; python_version >= "3.11" \ +matplotlib==3.9.0 \ --hash=sha256:063af8587fceeac13b0936c42a2b6c732c2ab1c98d38abc3337e430e1ff75e38 \ --hash=sha256:06a478f0d67636554fa78558cfbcd7b9dba85b51f5c3b5a0c9be49010cf5f321 \ --hash=sha256:0a490715b3b9984fa609116481b22178348c1a220a4499cda79132000a79b4db \ @@ -324,7 +365,10 @@ ml-dtypes==0.5.1 \ --hash=sha256:c9945669d3dadf8acb40ec2e57d38c985d8c285ea73af57fc5b09872c516106d \ --hash=sha256:d13755f8e8445b3870114e5b6240facaa7cb0c3361e54beba3e07fa912a6e12b \ --hash=sha256:fd918d4e6a4e0c110e2e05be7a7814d10dc1b95872accbf6512b80a109b71ae1 - # via -r build/requirements.in + # via + # -r build/requirements.in + # jaxlib + # tensorstore mpmath==1.4.0a1 \ --hash=sha256:78884400f439f500fa76be0121a8f9598313d87664863a192e1185ddbd7ae97f \ --hash=sha256:f8b7b5a3a1726ab6e8c898eb2157426b82c482ab1ab8ffed9f88bb9e07c6e9c1 @@ -376,76 +420,89 @@ numpy==2.0.0 ; python_version <= "3.12" \ --hash=sha256:fbd6acc766814ea6443628f4e6751d0da6593dae29c08c0b2606164db026970c \ --hash=sha256:feff59f27338135776f6d4e2ec7aeeac5d5f7a08a83e80869121ef8164b74af9 # via - # -r build/requirements.in + # -r build/nonfreethreading-requirements.txt # contourpy + # jaxlib # matplotlib # ml-dtypes # opt-einsum # scipy -nvidia-cublas-cu12==12.8.3.14 ; sys_platform == "linux" \ + # tensorstore +nvidia-cublas-cu12==12.8.3.14 \ --hash=sha256:3f0e05e7293598cf61933258b73e66a160c27d59c4422670bf0b79348c04be44 \ --hash=sha256:93a4e0e386cc7f6e56c822531396de8170ed17068a1e18f987574895044cd8c3 \ --hash=sha256:9ae5eae500aead01fc4bdfc458209df638b1a3551557ce11a78eea9ece602ae9 # via - # -r build/test-requirements.txt + # jax-cuda12-plugin # nvidia-cudnn-cu12 # nvidia-cusolver-cu12 -nvidia-cuda-cupti-cu12==12.8.57 ; sys_platform == "linux" \ +nvidia-cuda-cupti-cu12==12.8.57 \ 
--hash=sha256:8e0b2eb847de260739bee4a3f66fac31378f4ff49538ff527a38a01a9a39f950 \ --hash=sha256:bbed719c52a476958a74cfc42f2b95a3fd6b3fd94eb40134acc4601feb4acac3 \ --hash=sha256:ff154211724fd824e758ce176b66007b558eea19c9a5135fc991827ee147e317 - # via -r build/test-requirements.txt -nvidia-cuda-nvcc-cu12==12.8.61 ; sys_platform == "linux" \ + # via jax-cuda12-plugin +nvidia-cuda-nvcc-cu12==12.8.61 \ --hash=sha256:171f605044ba17bc455d19cad289946c3dbea029a90c60dfa7b88e545bc8e329 \ --hash=sha256:28604ec42aaa09035b0fb7111432e5121bc385580b30c55d2acfb7d644b16548 \ --hash=sha256:4524739cfc080e9c9e53032912be8f020058e0a7186746d19acef3b6d916ea0b - # via -r build/test-requirements.txt -nvidia-cuda-runtime-cu12==12.8.57 ; sys_platform == "linux" \ + # via jax-cuda12-plugin +nvidia-cuda-nvrtc-cu12==12.9.86 ; sys_platform == "linux" \ + --hash=sha256:096d4de6bda726415dfaf3198d4f5c522b8e70139c97feef5cd2ca6d4cd9cead \ + --hash=sha256:210cf05005a447e29214e9ce50851e83fc5f4358df8b453155d5e1918094dcb4 \ + --hash=sha256:72972ebdcf504d69462d3bcd67e7b81edd25d0fb85a2c46d3ea3517666636349 + # via + # -r build/requirements.in + # jax-cuda12-plugin +nvidia-cuda-runtime-cu12==12.8.57 \ --hash=sha256:534ccebd967b6a44292678fa5da4f00666029cb2ed07a79515ea41ef31fe3ec7 \ --hash=sha256:75342e28567340b7428ce79a5d6bb6ca5ff9d07b69e7ce00d2c7b4dc23eff0be \ --hash=sha256:89be637e3ee967323865b85e0f147d75f9a5bd98360befa37481b02dd57af8f5 - # via -r build/test-requirements.txt -nvidia-cudnn-cu12==9.7.1.26 ; sys_platform == "linux" \ - --hash=sha256:6d011159a158f3cfc47bf851aea79e31bcff60d530b70ef70474c84cac484d07 \ - --hash=sha256:7b805b9a4cf9f3da7c5f4ea4a9dff7baf62d1a612d6154a7e0d2ea51ed296241 \ - --hash=sha256:848a61d40ef3b32bd4e1fadb599f0cf04a4b942fbe5fb3be572ad75f9b8c53ef - # via -r build/test-requirements.txt -nvidia-cufft-cu12==11.3.3.41 ; sys_platform == "linux" \ + # via jax-cuda12-plugin +nvidia-cudnn-cu12==9.8.0.87 \ + --hash=sha256:b4b5cfddc32aa4180f9d390ee99e9a9f55a89e7087329b41aba4319327e22466 \ + --hash=sha256:b883faeb2f6f15dba7bbb6756eab6a0d9cecb59db5b0fa07577b9cfa24cd99f4 \ + --hash=sha256:d6b02cd0e3e24aa31d0193a8c39fec239354360d7d81055edddb69f35d53a4c8 + # via jax-cuda12-plugin +nvidia-cufft-cu12==11.3.3.41 \ --hash=sha256:68509dcd7e3306e69d0e2d8a6d21c8b25ed62e6df8aac192ce752f17677398b5 \ --hash=sha256:da650080ab79fcdf7a4b06aa1b460e99860646b176a43f6208099bdc17836b6a \ --hash=sha256:f9760612886786601d27a0993bb29ce1f757e6b8b173499d0ecfa850d31b50f8 - # via -r build/test-requirements.txt -nvidia-cusolver-cu12==11.7.2.55 ; sys_platform == "linux" \ + # via jax-cuda12-plugin +nvidia-cusolver-cu12==11.7.2.55 \ --hash=sha256:0fd9e98246f43c15bee5561147ad235dfdf2d037f5d07c9d41af3f7f72feb7cc \ --hash=sha256:4d1354102f1e922cee9db51920dba9e2559877cf6ff5ad03a00d853adafb191b \ --hash=sha256:a5a516c55da5c5aba98420d9bc9bcab18245f21ec87338cc1f930eb18dd411ac - # via -r build/test-requirements.txt -nvidia-cusparse-cu12==12.5.7.53 ; sys_platform == "linux" \ + # via jax-cuda12-plugin +nvidia-cusparse-cu12==12.5.7.53 \ --hash=sha256:3c1b61eb8c85257ea07e9354606b26397612627fdcd327bfd91ccf6155e7c86d \ --hash=sha256:82c201d6781bacf6bb7c654f0446728d0fe596dfdd82ef4a04c204ce3e107441 \ --hash=sha256:d869c6146ca80f4305b62e02d924b4aaced936f8173e3cef536a67eed2a91af1 # via - # -r build/test-requirements.txt + # jax-cuda12-plugin # nvidia-cusolver-cu12 -nvidia-nccl-cu12==2.25.1 ; sys_platform == "linux" \ +nvidia-nccl-cu12==2.25.1 \ --hash=sha256:362aed5963fb9ea2ed2f264409baae30143498fd0e5c503aeaa1badd88cdc54a \ 
--hash=sha256:4ab428bc915785cc66e8c57cb34c7a64cf739c46702b8db748b6ad6cc7180cf8 - # via -r build/test-requirements.txt -nvidia-nvjitlink-cu12==12.8.61 ; sys_platform == "linux" \ + # via jax-cuda12-plugin +nvidia-nvjitlink-cu12==12.8.61 \ --hash=sha256:1166a964d25fdc0eae497574d38824305195a5283324a21ccb0ce0c802cbf41c \ --hash=sha256:45fd79f2ae20bd67e8bc411055939049873bfd8fac70ff13bd4865e0b9bdab17 \ --hash=sha256:9b80ecab31085dda3ce3b41d043be0ec739216c3fc633b8abe212d5a30026df0 # via - # -r build/test-requirements.txt + # jax-cuda12-plugin # nvidia-cufft-cu12 # nvidia-cusolver-cu12 # nvidia-cusparse-cu12 +nvidia-nvshmem-cu12==3.2.5 ; sys_platform == "linux" \ + --hash=sha256:2f5798d65f1a08f9878aae17cf4d3dcbfe884d1f12cf170556cd40f2be90ca96 \ + --hash=sha256:e076957d5cc72e51061a04f2d46f55df477be53e8a55d0d621be08f7aefe1d00 + # via + # -r build/requirements.in + # jax-cuda12-plugin opt-einsum==3.3.0 \ --hash=sha256:2455e59e3947d3c275477df7f5205b30635e266fe6dc300e3d9f9646bfcea147 \ --hash=sha256:59f6475f77bbc37dcf7cd748519c0ec60722e91e63ca114e68821c0c54a46549 - # via - # -r build/requirements.in - # -r build/test-requirements.txt + # via -r build/requirements.in packaging==24.0 \ --hash=sha256:2ddfb553fdf02fb784c234c7ba6ccc288296ceabec964ad2eae3777778130bc5 \ --hash=sha256:eb82c5e3e56209074766e6885bb04b8c38a0c015d0a30036ebe7ece34c9989e9 @@ -537,10 +594,12 @@ pluggy==1.5.0 \ --hash=sha256:2cffa88e94fdc978c4c574f15f9e59b7f4201d439195c3715ca9e2486f1d0cf1 \ --hash=sha256:44e1ad92c8ca002de6377e165f3e0f1be63266ab4d554740532335b9d75ea669 # via pytest -portpicker==1.6.0 \ +portpicker==1.6.0 ; python_version < "3.13" \ --hash=sha256:b2787a41404cf7edbe29b07b9e0ed863b09f2665dcc01c1eb0c2261c1e7d0755 \ --hash=sha256:bd507fd6f96f65ee02781f2e674e9dc6c99bbfa6e3c39992e3916204c9d431fa - # via -r build/test-requirements.txt + # via + # -r build/nonfreethreading-requirements.txt + # -r build/test-requirements.txt psutil==5.9.8 \ --hash=sha256:02615ed8c5ea222323408ceba16c60e99c3f91639b07da6373fb7e6539abc56d \ --hash=sha256:05806de88103b25903dff19bb6692bd2e714ccf9e668d050d144012055cbca73 \ @@ -591,7 +650,7 @@ rich==13.7.1 \ --hash=sha256:4edbae314f59eb482f54e9e30bf00d33350aaa94f4bfcd4e9e3110e64d0d7222 \ --hash=sha256:9be308cb1fe2f1f57d67ce99e95af38a1e2bc71ad9813b0e247cf7ffbcc3a432 # via -r build/test-requirements.txt -scipy==1.13.1 \ +scipy==1.13.1 ; python_version <= "3.12" \ --hash=sha256:017367484ce5498445aade74b1d5ab377acdc65e27095155e448c88497755a5d \ --hash=sha256:095a87a0312b08dfd6a6155cbbd310a8c51800fc931b8c0b84003014b874ed3c \ --hash=sha256:20335853b85e9a49ff7572ab453794298bcf0354d8068c5f6775a0eabf350aca \ @@ -617,7 +676,9 @@ scipy==1.13.1 \ --hash=sha256:e89369d27f9e7b0884ae559a3a956e77c02114cc60a6058b4e5011572eea9299 \ --hash=sha256:eccfa1906eacc02de42d70ef4aecea45415f5be17e72b61bafcfd329bdc52e94 \ --hash=sha256:f26264b282b9da0952a024ae34710c2aff7d27480ee91a2e82b7b7073c24722f - # via -r build/requirements.in + # via + # -r build/requirements.in + # jaxlib six==1.16.0 \ --hash=sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926 \ --hash=sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254 @@ -626,6 +687,29 @@ sortedcontainers==2.4.0 \ --hash=sha256:25caa5a06cc30b6b83d11423433f65d1f9d76c4c6a0c90e3379eaa43b9bfdb88 \ --hash=sha256:a163dcaede0f1c021485e957a39245190e74249897e2ae4b2aa38595db237ee0 # via hypothesis +tensorstore==0.1.73 \ + --hash=sha256:03cec5141a27d2e65e4ff604641cfb1f7989d66c361534392e810b80cbda617d \ + 
--hash=sha256:0429bf781ce3ed45be761b46f4bc5979412dadf063f509cb7e9581981a1e097b \ + --hash=sha256:05f7fdcb063f08f40f74c49f92c0f0136c5b715d49e111950bf025b12a72a907 \ + --hash=sha256:0eb83a2526e211a721842c3e98293e4bc9e1fdb9dac37ecf37d6ccbde84b8ee3 \ + --hash=sha256:192feb8a8fd0f37fa298588d037d4889d2f9d07b18b3295488f05ee268f57b70 \ + --hash=sha256:2aed43498b00d37df583da9e06328751cfe695bb166043aa9ef7183174cf7e29 \ + --hash=sha256:421a3f87864a0a8837b4f9f0c8ee86079b46b112de902496d3b90c72f51d02ea \ + --hash=sha256:440569458b91974e0ffa210654a01f2721758476c48240f7c925fc0d107056be \ + --hash=sha256:4433dcfcb943e100b90b0fc8e0b1d174e8c2c1cedb1fcc86e6d20b6a2e961831 \ + --hash=sha256:44d70dd0c000db8c0d2386e788c5e91d3b37ebee8f629f3848d7a012c85d1e11 \ + --hash=sha256:5fc9feab09de9e99c381145adeef5ff9e01f898e509b851ff2edd940c8b2384a \ + --hash=sha256:70d57b63706de4a3a9c1c217b338658fa160b2d41f5b399e6926f9eaf29b2a4d \ + --hash=sha256:7a812e8297a4ed70109057628b767c1a12b535f2db657635f0ed1517b23b990b \ + --hash=sha256:7b4e08bfa61880863bedb90499a23c63d9493cf9310207c230086b0a3700c75d \ + --hash=sha256:83c6ca5cb39ffeeb4a562942e3b9e2f32b026f362b2b7266c44201bd7c3116a5 \ + --hash=sha256:87fb7879af73a5b7ded9c9de3e2014baf6468d9d7c47edfc19490907b346e0a6 \ + --hash=sha256:a11d2e496d7442c68b35cd222a8c8df3fdee9e30fb2984c91546d81faff8bf61 \ + --hash=sha256:be3f5ef6f359486ee52785e8a302819152e51286c50181c6c35f316b7568ce60 \ + --hash=sha256:dd7fa6d7e9579a1a75e6185d7df10e28fcc7db2e14190ed60261a71b9c09e1df \ + --hash=sha256:e99ae99ac48f41c4e36b1e3717c6dbdab96dd27fc91618dd01afb9ad848a9293 \ + --hash=sha256:f24b325385fd30be612ab8494a29d3bfef37b9444357912ba184f30f325f093b + # via -r build/nonfreethreading-requirements.txt typing-extensions==4.12.0rc1 \ --hash=sha256:be199d06d8f09ca2c9425e3aa04a9afba33e892fe079dea959e72df7f8442343 \ --hash=sha256:f933a7b288a919ca97adbff656e52ff81f7ff25d98a2aabb9355ca4090f772fe @@ -633,7 +717,7 @@ typing-extensions==4.12.0rc1 \ wheel==0.43.0 \ --hash=sha256:465ef92c69fa5c5da2d1cf8ac40559a8c940886afcef87dcf14b9470862f1d85 \ --hash=sha256:55c570405f142630c6b9f72fe09d9b67cf1477fcf543ae5b8dcb1f5b7377da81 - # via -r build/test-requirements.txt + # via -r build/requirements.in zipp==3.18.2 \ --hash=sha256:6278d9ddbcfb1f1089a88fde84481528b07b0e10474e09dcfe53dad4069fa059 \ --hash=sha256:dce197b859eb796242b0622af1b8beb0a722d52aa2f57133ead08edd5bf5374e @@ -685,12 +769,10 @@ zstandard==0.22.0 \ --hash=sha256:f1a4b358947a65b94e2501ce3e078bbc929b039ede4679ddb0460829b12f7375 \ --hash=sha256:f9b2cde1cd1b2a10246dbc143ba49d942d14fb3d2b4bccf4618d475c65464912 \ --hash=sha256:fe3390c538f12437b859d815040763abc728955a52ca6ff9c5d4ac707c4ad98e - # via -r build/requirements.in + # via -r build/nonfreethreading-requirements.txt # The following packages are considered to be unsafe in a requirements file: setuptools==76.0.0 \ --hash=sha256:199466a166ff664970d0ee145839f5582cb9bca7a0a3a2e795b6a9cb2308e9c6 \ --hash=sha256:43b4ee60e10b0d0ee98ad11918e114c70701bc6051662a9a675a0496c1a158f4 - # via - # -r build/requirements.in - # -r build/test-requirements.txt + # via -r build/requirements.in diff --git a/build/requirements_lock_3_13.txt b/build/requirements_lock_3_13.txt index e74d40b798f4..aa45b473d9ae 100644 --- a/build/requirements_lock_3_13.txt +++ b/build/requirements_lock_3_13.txt @@ -19,7 +19,7 @@ auditwheel==6.1.0 \ build==1.2.2.post1 \ --hash=sha256:1d61c0887fa860c01971625baae8bdd338e517b836a2f70dd1f7aa3a6b2fc5b5 \ --hash=sha256:b36993e92ca9375a219c99e606a122ff365a760a2d4bba0caa09bd5278b608b7 - # via -r 
build/test-requirements.txt + # via -r build/requirements.in cloudpickle==3.0.0 \ --hash=sha256:246ee7d0c295602a036e86369c77fecda4ab17b506496730f2f576d9016fd9c7 \ --hash=sha256:996d9a482c6fb4f33c1a35335cf8afd065d2a56e973270364840712d9131a882 @@ -27,7 +27,7 @@ cloudpickle==3.0.0 \ colorama==0.4.6 \ --hash=sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44 \ --hash=sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6 - # via -r build/test-requirements.txt + # via -r build/requirements.in contourpy==1.3.0 \ --hash=sha256:00ccd0dbaad6d804ab259820fa7cb0b8036bda0686ef844d24125d8287178ce0 \ --hash=sha256:0be4d8425bfa755e0fd76ee1e019636ccc7c29f77a7c86b4328a9eb6a26d0639 \ @@ -181,6 +181,44 @@ iniconfig==2.0.0 \ --hash=sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3 \ --hash=sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374 # via pytest +jax-cuda12-pjrt==0.6.2 ; sys_platform == "linux" \ + --hash=sha256:22faf020d2e8f7ca1e2915633241f7df7678b73c7078f5f0b2f113248337f7de \ + --hash=sha256:8cd9ead7948ea2c778a508fef5d1159e8b7abf4fccc7037c3fe1dbfcd95012dc + # via + # -r build/requirements.in + # jax-cuda12-plugin +jax-cuda12-plugin[with-cuda]==0.6.2 ; sys_platform == "linux" \ + --hash=sha256:0896cbb308d95291e205cd89d254029dee3a1df43d66e9831331a9afd2d27870 \ + --hash=sha256:1751f88989269b3cdb0dfe4f7b072a6442149818c9bc98c3a395c8acaf910a79 \ + --hash=sha256:2cd8e279a59a38ba0c978a831e13adeb6ee9e4572fba387c7975ba3ad535dd38 \ + --hash=sha256:6c9b002d13b1fcb9403713eedd3876a227ad1ffbdfb3811b1f9f89af4c25a5f7 \ + --hash=sha256:773efa8b55a837406c561f0ef02144dda9019181193760ec5419eec9dd2b9aac \ + --hash=sha256:83345f52f610cdb8e90044566d8e120864150b8090968c8ab6dd8e0bfb9a6a9f \ + --hash=sha256:bc5c3a75d05519b4d326e4669d0f7ad0fe0f0acf875f9313d913748ccca5a9ea \ + --hash=sha256:db4c6103c912d8cd1adf94c34d313bb4760ca7f01c897ca7cd62e65f27994199 \ + --hash=sha256:ed5316ca1818db7ef53230ee0a41398d3a60942e361dfb857a952eb4d92fc8d7 \ + --hash=sha256:febd099f970d350eb8fa5a2c9a2fb4b0ea7b3d6a89df1496663edfa7afe590e5 + # via -r build/requirements.in +jaxlib==0.6.2 \ + --hash=sha256:11eae7e05bc5a79875da36324afb9eddd4baeaef2a0386caf6d4f3720b9aef28 \ + --hash=sha256:153eaa51f778b60851720729d4f461a91edd9ba3932f6f3bc598d4413870038b \ + --hash=sha256:335d7e3515ce78b52a410136f46aa4a7ea14d0e7d640f34e1e137409554ad0ac \ + --hash=sha256:34d8a684a8be949dd87dd4acc97101b4106a0dc9ad151ec891da072319a57b99 \ + --hash=sha256:39cf9555f85ae1ce2e2c1a59fc71f2eca4f9867a7cb934fef881ba56b11371d1 \ + --hash=sha256:3abd536e44b05fb1657507e3ff1fc3691f99613bae3921ecab9e82f27255f784 \ + --hash=sha256:4205d098ce8efb5f7fe2fe5098bae6036094dc8d8829f5e0e0d7a9b155326336 \ + --hash=sha256:70498837caf538bd458ff6858c8bfd404db82015aba8f663670197fa9900ff02 \ + --hash=sha256:87ec2dc9c3ed9ab936eec8535160c5fbd2c849948559f1c5daa75f63fabe5942 \ + --hash=sha256:921dbd4db214eba19a29ba9f2450d880e08b2b2c7b968f28cc89da3e62366af4 \ + --hash=sha256:a208ff61c58128d306bb4e5ad0858bd2b0960f2c1c10ad42c548f74a60c0020e \ + --hash=sha256:b977604cd36c74b174d25ed685017379468138eb747d865f75e466cb273c801d \ + --hash=sha256:bff67b188133ce1f0111c7b163ac321fd646b59ed221ea489063e2e0f85cb967 \ + --hash=sha256:c087a0eb6fb7f6f8f54d56f4730328dfde5040dd3b5ddfa810e7c28ea7102b42 \ + --hash=sha256:c6815509997d6b05e5c9daa7994b9ad473ce3e8c8a17bdbbcacc3c744f76f7a0 \ + --hash=sha256:da4601b2b5dc8c23d6afb293eacfb9aec4e1d1871cb2f29c5a151d103e73b0f8 \ + --hash=sha256:f1dd09b481a93c1d4c750013f467f74194493ba7bd29fcd4d1cec16e3a214f65 \ 
+ --hash=sha256:f94163f14c8fd3ba93ae14b631abacf14cb031bba0b59138869984b4d10375f8 + # via -r build/requirements.in kiwisolver==1.4.7 \ --hash=sha256:073a36c8273647592ea332e816e75ef8da5c303236ec0167196793eb1e34657a \ --hash=sha256:08471d4d86cbaec61f86b217dd938a83d85e03785f51121e791a6e6689a3be95 \ @@ -297,11 +335,14 @@ kiwisolver==1.4.7 \ --hash=sha256:f816dd2277f8d63d79f9c8473a79fe54047bc0467754962840782c575522224d \ --hash=sha256:f9a9e8a507420fe35992ee9ecb302dab68550dedc0da9e2880dd88071c5fb052 # via matplotlib +libtpu==0.0.13 ; sys_platform == "linux" and platform_machine == "x86_64" \ + --hash=sha256:2b4fcd3b902433ef2c22760a3a13b1474491bb4daf88a2670c6c72b295ebe750 + # via -r build/requirements.in markdown-it-py==3.0.0 \ --hash=sha256:355216845c60bd96232cd8d8c40e8f9765cc86f46880e43a8fd22dc1a1a8cab1 \ --hash=sha256:e3f60a94fa066dc52ec76661e37c851cb232d92f9886b15cb560aaada2df8feb # via rich -matplotlib==3.9.2 ; python_version >= "3.11" \ +matplotlib==3.9.2 \ --hash=sha256:039082812cacd6c6bec8e17a9c1e6baca230d4116d522e81e1f63a74d01d2e21 \ --hash=sha256:03ba9c1299c920964e8d3857ba27173b4dbb51ca4bab47ffc2c2ba0eb5e2cbc5 \ --hash=sha256:050598c2b29e0b9832cde72bcf97627bf00262adbc4a54e2b856426bb2ef0697 \ @@ -372,12 +413,15 @@ ml-dtypes==0.5.1 \ --hash=sha256:c9945669d3dadf8acb40ec2e57d38c985d8c285ea73af57fc5b09872c516106d \ --hash=sha256:d13755f8e8445b3870114e5b6240facaa7cb0c3361e54beba3e07fa912a6e12b \ --hash=sha256:fd918d4e6a4e0c110e2e05be7a7814d10dc1b95872accbf6512b80a109b71ae1 - # via -r build/requirements.in + # via + # -r build/requirements.in + # jaxlib + # tensorstore mpmath==1.3.0 \ --hash=sha256:7a28eb2a9774d00c7bc92411c19a89209d5da7c4c9a9e227be8330a23a25b91f \ --hash=sha256:a0b2b9fe80bbcd81a6647ff13108738cfb482d481d826cc0e02f5b35e5c88d2c # via -r build/test-requirements.txt -numpy==2.1.2 ; python_version >= "3.13" \ +numpy==2.1.2 ; python_version == "3.13" \ --hash=sha256:05b2d4e667895cc55e3ff2b56077e4c8a5604361fc21a042845ea3ad67465aa8 \ --hash=sha256:12edb90831ff481f7ef5f6bc6431a9d74dc0e5ff401559a71e5e4611d4f2d466 \ --hash=sha256:13311c2db4c5f7609b462bc0f43d3c465424d25c626d95040f073e30f7570e35 \ @@ -432,75 +476,88 @@ numpy==2.1.2 ; python_version >= "3.13" \ --hash=sha256:faa88bc527d0f097abdc2c663cddf37c05a1c2f113716601555249805cf573f1 \ --hash=sha256:fc44e3c68ff00fd991b59092a54350e6e4911152682b4782f68070985aa9e648 # via - # -r build/requirements.in + # -r build/nonfreethreading-requirements.txt # contourpy + # jaxlib # matplotlib # ml-dtypes # scipy -nvidia-cublas-cu12==12.8.3.14 ; sys_platform == "linux" \ + # tensorstore +nvidia-cublas-cu12==12.8.3.14 \ --hash=sha256:3f0e05e7293598cf61933258b73e66a160c27d59c4422670bf0b79348c04be44 \ --hash=sha256:93a4e0e386cc7f6e56c822531396de8170ed17068a1e18f987574895044cd8c3 \ --hash=sha256:9ae5eae500aead01fc4bdfc458209df638b1a3551557ce11a78eea9ece602ae9 # via - # -r build/test-requirements.txt + # jax-cuda12-plugin # nvidia-cudnn-cu12 # nvidia-cusolver-cu12 -nvidia-cuda-cupti-cu12==12.8.57 ; sys_platform == "linux" \ +nvidia-cuda-cupti-cu12==12.8.57 \ --hash=sha256:8e0b2eb847de260739bee4a3f66fac31378f4ff49538ff527a38a01a9a39f950 \ --hash=sha256:bbed719c52a476958a74cfc42f2b95a3fd6b3fd94eb40134acc4601feb4acac3 \ --hash=sha256:ff154211724fd824e758ce176b66007b558eea19c9a5135fc991827ee147e317 - # via -r build/test-requirements.txt -nvidia-cuda-nvcc-cu12==12.8.61 ; sys_platform == "linux" \ + # via jax-cuda12-plugin +nvidia-cuda-nvcc-cu12==12.8.61 \ --hash=sha256:171f605044ba17bc455d19cad289946c3dbea029a90c60dfa7b88e545bc8e329 \ 
--hash=sha256:28604ec42aaa09035b0fb7111432e5121bc385580b30c55d2acfb7d644b16548 \ --hash=sha256:4524739cfc080e9c9e53032912be8f020058e0a7186746d19acef3b6d916ea0b - # via -r build/test-requirements.txt -nvidia-cuda-runtime-cu12==12.8.57 ; sys_platform == "linux" \ + # via jax-cuda12-plugin +nvidia-cuda-nvrtc-cu12==12.9.86 ; sys_platform == "linux" \ + --hash=sha256:096d4de6bda726415dfaf3198d4f5c522b8e70139c97feef5cd2ca6d4cd9cead \ + --hash=sha256:210cf05005a447e29214e9ce50851e83fc5f4358df8b453155d5e1918094dcb4 \ + --hash=sha256:72972ebdcf504d69462d3bcd67e7b81edd25d0fb85a2c46d3ea3517666636349 + # via + # -r build/requirements.in + # jax-cuda12-plugin +nvidia-cuda-runtime-cu12==12.8.57 \ --hash=sha256:534ccebd967b6a44292678fa5da4f00666029cb2ed07a79515ea41ef31fe3ec7 \ --hash=sha256:75342e28567340b7428ce79a5d6bb6ca5ff9d07b69e7ce00d2c7b4dc23eff0be \ --hash=sha256:89be637e3ee967323865b85e0f147d75f9a5bd98360befa37481b02dd57af8f5 - # via -r build/test-requirements.txt -nvidia-cudnn-cu12==9.7.1.26 ; sys_platform == "linux" \ - --hash=sha256:6d011159a158f3cfc47bf851aea79e31bcff60d530b70ef70474c84cac484d07 \ - --hash=sha256:7b805b9a4cf9f3da7c5f4ea4a9dff7baf62d1a612d6154a7e0d2ea51ed296241 \ - --hash=sha256:848a61d40ef3b32bd4e1fadb599f0cf04a4b942fbe5fb3be572ad75f9b8c53ef - # via -r build/test-requirements.txt -nvidia-cufft-cu12==11.3.3.41 ; sys_platform == "linux" \ + # via jax-cuda12-plugin +nvidia-cudnn-cu12==9.8.0.87 \ + --hash=sha256:b4b5cfddc32aa4180f9d390ee99e9a9f55a89e7087329b41aba4319327e22466 \ + --hash=sha256:b883faeb2f6f15dba7bbb6756eab6a0d9cecb59db5b0fa07577b9cfa24cd99f4 \ + --hash=sha256:d6b02cd0e3e24aa31d0193a8c39fec239354360d7d81055edddb69f35d53a4c8 + # via jax-cuda12-plugin +nvidia-cufft-cu12==11.3.3.41 \ --hash=sha256:68509dcd7e3306e69d0e2d8a6d21c8b25ed62e6df8aac192ce752f17677398b5 \ --hash=sha256:da650080ab79fcdf7a4b06aa1b460e99860646b176a43f6208099bdc17836b6a \ --hash=sha256:f9760612886786601d27a0993bb29ce1f757e6b8b173499d0ecfa850d31b50f8 - # via -r build/test-requirements.txt -nvidia-cusolver-cu12==11.7.2.55 ; sys_platform == "linux" \ + # via jax-cuda12-plugin +nvidia-cusolver-cu12==11.7.2.55 \ --hash=sha256:0fd9e98246f43c15bee5561147ad235dfdf2d037f5d07c9d41af3f7f72feb7cc \ --hash=sha256:4d1354102f1e922cee9db51920dba9e2559877cf6ff5ad03a00d853adafb191b \ --hash=sha256:a5a516c55da5c5aba98420d9bc9bcab18245f21ec87338cc1f930eb18dd411ac - # via -r build/test-requirements.txt -nvidia-cusparse-cu12==12.5.7.53 ; sys_platform == "linux" \ + # via jax-cuda12-plugin +nvidia-cusparse-cu12==12.5.7.53 \ --hash=sha256:3c1b61eb8c85257ea07e9354606b26397612627fdcd327bfd91ccf6155e7c86d \ --hash=sha256:82c201d6781bacf6bb7c654f0446728d0fe596dfdd82ef4a04c204ce3e107441 \ --hash=sha256:d869c6146ca80f4305b62e02d924b4aaced936f8173e3cef536a67eed2a91af1 # via - # -r build/test-requirements.txt + # jax-cuda12-plugin # nvidia-cusolver-cu12 -nvidia-nccl-cu12==2.25.1 ; sys_platform == "linux" \ +nvidia-nccl-cu12==2.25.1 \ --hash=sha256:362aed5963fb9ea2ed2f264409baae30143498fd0e5c503aeaa1badd88cdc54a \ --hash=sha256:4ab428bc915785cc66e8c57cb34c7a64cf739c46702b8db748b6ad6cc7180cf8 - # via -r build/test-requirements.txt -nvidia-nvjitlink-cu12==12.8.61 ; sys_platform == "linux" \ + # via jax-cuda12-plugin +nvidia-nvjitlink-cu12==12.8.61 \ --hash=sha256:1166a964d25fdc0eae497574d38824305195a5283324a21ccb0ce0c802cbf41c \ --hash=sha256:45fd79f2ae20bd67e8bc411055939049873bfd8fac70ff13bd4865e0b9bdab17 \ --hash=sha256:9b80ecab31085dda3ce3b41d043be0ec739216c3fc633b8abe212d5a30026df0 # via - # -r build/test-requirements.txt + # 
jax-cuda12-plugin # nvidia-cufft-cu12 # nvidia-cusolver-cu12 # nvidia-cusparse-cu12 +nvidia-nvshmem-cu12==3.2.5 ; sys_platform == "linux" \ + --hash=sha256:2f5798d65f1a08f9878aae17cf4d3dcbfe884d1f12cf170556cd40f2be90ca96 \ + --hash=sha256:e076957d5cc72e51061a04f2d46f55df477be53e8a55d0d621be08f7aefe1d00 + # via + # -r build/requirements.in + # jax-cuda12-plugin opt-einsum==3.4.0 \ --hash=sha256:69bb92469f86a1565195ece4ac0323943e83477171b91d24c35afe028a90d7cd \ --hash=sha256:96ca72f1b886d148241348783498194c577fa30a8faac108586b14f1ba4473ac - # via - # -r build/requirements.in - # -r build/test-requirements.txt + # via -r build/requirements.in packaging==24.1 \ --hash=sha256:026ed72c8ed3fcce5bf8950572258698927fd1dbda10a5e981cdf0ac37f4f002 \ --hash=sha256:5b8f2217dbdbd2f7f384c41c628544e6d52f2d0f53c6d0c3ea61aa5d1d7ff124 @@ -600,25 +657,18 @@ pluggy==1.5.0 \ portpicker==1.6.0 \ --hash=sha256:b2787a41404cf7edbe29b07b9e0ed863b09f2665dcc01c1eb0c2261c1e7d0755 \ --hash=sha256:bd507fd6f96f65ee02781f2e674e9dc6c99bbfa6e3c39992e3916204c9d431fa - # via -r build/test-requirements.txt -psutil==6.0.0 \ - --hash=sha256:02b69001f44cc73c1c5279d02b30a817e339ceb258ad75997325e0e6169d8b35 \ - --hash=sha256:1287c2b95f1c0a364d23bc6f2ea2365a8d4d9b726a3be7294296ff7ba97c17f0 \ - --hash=sha256:1e7c870afcb7d91fdea2b37c24aeb08f98b6d67257a5cb0a8bc3ac68d0f1a68c \ - --hash=sha256:21f1fb635deccd510f69f485b87433460a603919b45e2a324ad65b0cc74f8fb1 \ - --hash=sha256:33ea5e1c975250a720b3a6609c490db40dae5d83a4eb315170c4fe0d8b1f34b3 \ - --hash=sha256:34859b8d8f423b86e4385ff3665d3f4d94be3cdf48221fbe476e883514fdb71c \ - --hash=sha256:5fd9a97c8e94059b0ef54a7d4baf13b405011176c3b6ff257c247cae0d560ecd \ - --hash=sha256:6ec7588fb3ddaec7344a825afe298db83fe01bfaaab39155fa84cf1c0d6b13c3 \ - --hash=sha256:6ed2440ada7ef7d0d608f20ad89a04ec47d2d3ab7190896cd62ca5fc4fe08bf0 \ - --hash=sha256:8faae4f310b6d969fa26ca0545338b21f73c6b15db7c4a8d934a5482faa818f2 \ - --hash=sha256:a021da3e881cd935e64a3d0a20983bda0bb4cf80e4f74fa9bfcb1bc5785360c6 \ - --hash=sha256:a495580d6bae27291324fe60cea0b5a7c23fa36a7cd35035a16d93bdcf076b9d \ - --hash=sha256:a9a3dbfb4de4f18174528d87cc352d1f788b7496991cca33c6996f40c9e3c92c \ - --hash=sha256:c588a7e9b1173b6e866756dde596fd4cad94f9399daf99ad8c3258b3cb2b47a0 \ - --hash=sha256:e2e8d0054fc88153ca0544f5c4d554d42e33df2e009c4ff42284ac9ebdef4132 \ - --hash=sha256:fc8c9510cde0146432bbdb433322861ee8c3efbf8589865c8bf8d21cb30c4d14 \ - --hash=sha256:ffe7fc9b6b36beadc8c322f84e1caff51e8703b88eee1da46d1e3a6ae11b4fd0 + # via -r build/nonfreethreading-requirements.txt +psutil==7.0.0 \ + --hash=sha256:101d71dc322e3cffd7cea0650b09b3d08b8e7c4109dd6809fe452dfd00e58b25 \ + --hash=sha256:1e744154a6580bc968a0195fd25e80432d3afec619daf145b9e5ba16cc1d688e \ + --hash=sha256:1fcee592b4c6f146991ca55919ea3d1f8926497a713ed7faaf8225e174581e91 \ + --hash=sha256:39db632f6bb862eeccf56660871433e111b6ea58f2caea825571951d4b6aa3da \ + --hash=sha256:4b1388a4f6875d7e2aff5c4ca1cc16c545ed41dd8bb596cefea80111db353a34 \ + --hash=sha256:4cf3d4eb1aa9b348dec30105c55cd9b7d4629285735a102beb4441e38db90553 \ + --hash=sha256:7be9c3eba38beccb6495ea33afd982a44074b78f28c434a1f51cc07fd315c456 \ + --hash=sha256:84df4eb63e16849689f76b1ffcb36db7b8de703d1bc1fe41773db487621b6c17 \ + --hash=sha256:a5f098451abc2828f7dc6b58d44b532b22f2088f4999a937557b603ce72b1993 \ + --hash=sha256:ba3fcef7523064a6c9da440fc4d6bd07da93ac726b5733c29027d7dc95b39d99 # via portpicker pyelftools==0.31 \ --hash=sha256:c774416b10310156879443b81187d182d8d9ee499660380e645918b50bc88f99 \ @@ -652,41 +702,56 @@ 
rich==13.9.2 \ --hash=sha256:51a2c62057461aaf7152b4d611168f93a9fc73068f8ded2790f29fe2b5366d0c \ --hash=sha256:8c82a3d3f8dcfe9e734771313e606b39d8247bb6b826e196f4914b333b743cf1 # via -r build/test-requirements.txt -scipy==1.14.1 \ - --hash=sha256:0c2f95de3b04e26f5f3ad5bb05e74ba7f68b837133a4492414b3afd79dfe540e \ - --hash=sha256:1729560c906963fc8389f6aac023739ff3983e727b1a4d87696b7bf108316a79 \ - --hash=sha256:278266012eb69f4a720827bdd2dc54b2271c97d84255b2faaa8f161a158c3b37 \ - --hash=sha256:2843f2d527d9eebec9a43e6b406fb7266f3af25a751aa91d62ff416f54170bc5 \ - --hash=sha256:2da0469a4ef0ecd3693761acbdc20f2fdeafb69e6819cc081308cc978153c675 \ - --hash=sha256:2ff0a7e01e422c15739ecd64432743cf7aae2b03f3084288f399affcefe5222d \ - --hash=sha256:2ff38e22128e6c03ff73b6bb0f85f897d2362f8c052e3b8ad00532198fbdae3f \ - --hash=sha256:30ac8812c1d2aab7131a79ba62933a2a76f582d5dbbc695192453dae67ad6310 \ - --hash=sha256:3a1b111fac6baec1c1d92f27e76511c9e7218f1695d61b59e05e0fe04dc59617 \ - --hash=sha256:4079b90df244709e675cdc8b93bfd8a395d59af40b72e339c2287c91860deb8e \ - --hash=sha256:5149e3fd2d686e42144a093b206aef01932a0059c2a33ddfa67f5f035bdfe13e \ - --hash=sha256:5a275584e726026a5699459aa72f828a610821006228e841b94275c4a7c08417 \ - --hash=sha256:631f07b3734d34aced009aaf6fedfd0eb3498a97e581c3b1e5f14a04164a456d \ - --hash=sha256:716e389b694c4bb564b4fc0c51bc84d381735e0d39d3f26ec1af2556ec6aad94 \ - --hash=sha256:8426251ad1e4ad903a4514712d2fa8fdd5382c978010d1c6f5f37ef286a713ad \ - --hash=sha256:8475230e55549ab3f207bff11ebfc91c805dc3463ef62eda3ccf593254524ce8 \ - --hash=sha256:8bddf15838ba768bb5f5083c1ea012d64c9a444e16192762bd858f1e126196d0 \ - --hash=sha256:8e32dced201274bf96899e6491d9ba3e9a5f6b336708656466ad0522d8528f69 \ - --hash=sha256:8f9ea80f2e65bdaa0b7627fb00cbeb2daf163caa015e59b7516395fe3bd1e066 \ - --hash=sha256:97c5dddd5932bd2a1a31c927ba5e1463a53b87ca96b5c9bdf5dfd6096e27efc3 \ - --hash=sha256:a49f6ed96f83966f576b33a44257d869756df6cf1ef4934f59dd58b25e0327e5 \ - --hash=sha256:af29a935803cc707ab2ed7791c44288a682f9c8107bc00f0eccc4f92c08d6e07 \ - --hash=sha256:b05d43735bb2f07d689f56f7b474788a13ed8adc484a85aa65c0fd931cf9ccd2 \ - --hash=sha256:b28d2ca4add7ac16ae8bb6632a3c86e4b9e4d52d3e34267f6e1b0c1f8d87e389 \ - --hash=sha256:b99722ea48b7ea25e8e015e8341ae74624f72e5f21fc2abd45f3a93266de4c5d \ - --hash=sha256:baff393942b550823bfce952bb62270ee17504d02a1801d7fd0719534dfb9c84 \ - --hash=sha256:c0ee987efa6737242745f347835da2cc5bb9f1b42996a4d97d5c7ff7928cb6f2 \ - --hash=sha256:d0d2821003174de06b69e58cef2316a6622b60ee613121199cb2852a873f8cf3 \ - --hash=sha256:e0cf28db0f24a38b2a0ca33a85a54852586e43cf6fd876365c86e0657cfe7d73 \ - --hash=sha256:e4f5a7c49323533f9103d4dacf4e4f07078f360743dec7f7596949149efeec06 \ - --hash=sha256:eb58ca0abd96911932f688528977858681a59d61a7ce908ffd355957f7025cfc \ - --hash=sha256:edaf02b82cd7639db00dbff629995ef185c8df4c3ffa71a5562a595765a06ce1 \ - --hash=sha256:fef8c87f8abfb884dac04e97824b61299880c43f4ce675dd2cbeadd3c9b466d2 - # via -r build/requirements.in +scipy==1.15.2 ; python_version >= "3.13" \ + --hash=sha256:01edfac9f0798ad6b46d9c4c9ca0e0ad23dbf0b1eb70e96adb9fa7f525eff0bf \ + --hash=sha256:03205d57a28e18dfd39f0377d5002725bf1f19a46f444108c29bdb246b6c8a11 \ + --hash=sha256:08b57a9336b8e79b305a143c3655cc5bdbe6d5ece3378578888d2afbb51c4e37 \ + --hash=sha256:11e7ad32cf184b74380f43d3c0a706f49358b904fa7d5345f16ddf993609184d \ + --hash=sha256:28a0d2c2075946346e4408b211240764759e0fabaeb08d871639b5f3b1aca8a0 \ + --hash=sha256:2b871df1fe1a3ba85d90e22742b93584f8d2b8e6124f8372ab15c71b73e428b8 \ + 
--hash=sha256:302093e7dfb120e55515936cb55618ee0b895f8bcaf18ff81eca086c17bd80af \ + --hash=sha256:42dabaaa798e987c425ed76062794e93a243be8f0f20fff6e7a89f4d61cb3d40 \ + --hash=sha256:447ce30cee6a9d5d1379087c9e474628dab3db4a67484be1b7dc3196bfb2fac9 \ + --hash=sha256:4c6676490ad76d1c2894d77f976144b41bd1a4052107902238047fb6a473e971 \ + --hash=sha256:54c462098484e7466362a9f1672d20888f724911a74c22ae35b61f9c5919183d \ + --hash=sha256:597a0c7008b21c035831c39927406c6181bcf8f60a73f36219b69d010aa04737 \ + --hash=sha256:5a6fd6eac1ce74a9f77a7fc724080d507c5812d61e72bd5e4c489b042455865e \ + --hash=sha256:5ea7ed46d437fc52350b028b1d44e002646e28f3e8ddc714011aaf87330f2f32 \ + --hash=sha256:601881dfb761311045b03114c5fe718a12634e5608c3b403737ae463c9885d53 \ + --hash=sha256:62ca1ff3eb513e09ed17a5736929429189adf16d2d740f44e53270cc800ecff1 \ + --hash=sha256:69ea6e56d00977f355c0f84eba69877b6df084516c602d93a33812aa04d90a3d \ + --hash=sha256:6a8e34cf4c188b6dd004654f88586d78f95639e48a25dfae9c5e34a6dc34547e \ + --hash=sha256:6d0194c37037707b2afa7a2f2a924cf7bac3dc292d51b6a925e5fcb89bc5c776 \ + --hash=sha256:6f223753c6ea76983af380787611ae1291e3ceb23917393079dcc746ba60cfb5 \ + --hash=sha256:6f5e296ec63c5da6ba6fa0343ea73fd51b8b3e1a300b0a8cae3ed4b1122c7462 \ + --hash=sha256:7cd5b77413e1855351cdde594eca99c1f4a588c2d63711388b6a1f1c01f62274 \ + --hash=sha256:869269b767d5ee7ea6991ed7e22b3ca1f22de73ab9a49c44bad338b725603301 \ + --hash=sha256:87994da02e73549dfecaed9e09a4f9d58a045a053865679aeb8d6d43747d4df3 \ + --hash=sha256:888307125ea0c4466287191e5606a2c910963405ce9671448ff9c81c53f85f58 \ + --hash=sha256:92233b2df6938147be6fa8824b8136f29a18f016ecde986666be5f4d686a91a4 \ + --hash=sha256:9412f5e408b397ff5641080ed1e798623dbe1ec0d78e72c9eca8992976fa65aa \ + --hash=sha256:9b18aa747da280664642997e65aab1dd19d0c3d17068a04b3fe34e2559196cb9 \ + --hash=sha256:9de9d1416b3d9e7df9923ab23cd2fe714244af10b763975bea9e4f2e81cebd27 \ + --hash=sha256:a2ec871edaa863e8213ea5df811cd600734f6400b4af272e1c011e69401218e9 \ + --hash=sha256:a5080a79dfb9b78b768cebf3c9dcbc7b665c5875793569f48bf0e2b1d7f68f6f \ + --hash=sha256:a8bf5cb4a25046ac61d38f8d3c3426ec11ebc350246a4642f2f315fe95bda655 \ + --hash=sha256:b09ae80010f52efddb15551025f9016c910296cf70adbf03ce2a8704f3a5ad20 \ + --hash=sha256:b5e025e903b4f166ea03b109bb241355b9c42c279ea694d8864d033727205e65 \ + --hash=sha256:bad78d580270a4d32470563ea86c6590b465cb98f83d760ff5b0990cb5518a93 \ + --hash=sha256:bae43364d600fdc3ac327db99659dcb79e6e7ecd279a75fe1266669d9a652828 \ + --hash=sha256:c4697a10da8f8765bb7c83e24a470da5797e37041edfd77fd95ba3811a47c4fd \ + --hash=sha256:c90ebe8aaa4397eaefa8455a8182b164a6cc1d59ad53f79943f266d99f68687f \ + --hash=sha256:cd58a314d92838f7e6f755c8a2167ead4f27e1fd5c1251fd54289569ef3495ec \ + --hash=sha256:cf72ff559a53a6a6d77bd8eefd12a17995ffa44ad86c77a5df96f533d4e6c6bb \ + --hash=sha256:def751dd08243934c884a3221156d63e15234a3155cf25978b0a668409d45eb6 \ + --hash=sha256:e7c68b6a43259ba0aab737237876e5c2c549a031ddb7abc28c7b47f22e202ded \ + --hash=sha256:ecf797d2d798cf7c838c6d98321061eb3e72a74710e6c40540f0e8087e3b499e \ + --hash=sha256:f031846580d9acccd0044efd1a90e6f4df3a6e12b4b6bd694a7bc03a89892b28 \ + --hash=sha256:fb530e4794fc8ea76a4a21ccb67dea33e5e0e60f07fc38a49e821e1eae3b71a0 \ + --hash=sha256:fe8a9eb875d430d81755472c5ba75e84acc980e4a8f6204d402849234d3017db + # via + # -r build/requirements.in + # jaxlib six==1.16.0 \ --hash=sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926 \ --hash=sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254 @@ -695,6 
+760,29 @@ sortedcontainers==2.4.0 \ --hash=sha256:25caa5a06cc30b6b83d11423433f65d1f9d76c4c6a0c90e3379eaa43b9bfdb88 \ --hash=sha256:a163dcaede0f1c021485e957a39245190e74249897e2ae4b2aa38595db237ee0 # via hypothesis +tensorstore==0.1.73 \ + --hash=sha256:03cec5141a27d2e65e4ff604641cfb1f7989d66c361534392e810b80cbda617d \ + --hash=sha256:0429bf781ce3ed45be761b46f4bc5979412dadf063f509cb7e9581981a1e097b \ + --hash=sha256:05f7fdcb063f08f40f74c49f92c0f0136c5b715d49e111950bf025b12a72a907 \ + --hash=sha256:0eb83a2526e211a721842c3e98293e4bc9e1fdb9dac37ecf37d6ccbde84b8ee3 \ + --hash=sha256:192feb8a8fd0f37fa298588d037d4889d2f9d07b18b3295488f05ee268f57b70 \ + --hash=sha256:2aed43498b00d37df583da9e06328751cfe695bb166043aa9ef7183174cf7e29 \ + --hash=sha256:421a3f87864a0a8837b4f9f0c8ee86079b46b112de902496d3b90c72f51d02ea \ + --hash=sha256:440569458b91974e0ffa210654a01f2721758476c48240f7c925fc0d107056be \ + --hash=sha256:4433dcfcb943e100b90b0fc8e0b1d174e8c2c1cedb1fcc86e6d20b6a2e961831 \ + --hash=sha256:44d70dd0c000db8c0d2386e788c5e91d3b37ebee8f629f3848d7a012c85d1e11 \ + --hash=sha256:5fc9feab09de9e99c381145adeef5ff9e01f898e509b851ff2edd940c8b2384a \ + --hash=sha256:70d57b63706de4a3a9c1c217b338658fa160b2d41f5b399e6926f9eaf29b2a4d \ + --hash=sha256:7a812e8297a4ed70109057628b767c1a12b535f2db657635f0ed1517b23b990b \ + --hash=sha256:7b4e08bfa61880863bedb90499a23c63d9493cf9310207c230086b0a3700c75d \ + --hash=sha256:83c6ca5cb39ffeeb4a562942e3b9e2f32b026f362b2b7266c44201bd7c3116a5 \ + --hash=sha256:87fb7879af73a5b7ded9c9de3e2014baf6468d9d7c47edfc19490907b346e0a6 \ + --hash=sha256:a11d2e496d7442c68b35cd222a8c8df3fdee9e30fb2984c91546d81faff8bf61 \ + --hash=sha256:be3f5ef6f359486ee52785e8a302819152e51286c50181c6c35f316b7568ce60 \ + --hash=sha256:dd7fa6d7e9579a1a75e6185d7df10e28fcc7db2e14190ed60261a71b9c09e1df \ + --hash=sha256:e99ae99ac48f41c4e36b1e3717c6dbdab96dd27fc91618dd01afb9ad848a9293 \ + --hash=sha256:f24b325385fd30be612ab8494a29d3bfef37b9444357912ba184f30f325f093b + # via -r build/nonfreethreading-requirements.txt typing-extensions==4.12.2 \ --hash=sha256:04e5ca0351e0f3f85c6853954072df659d0d13fac324d0072316b67d7794700d \ --hash=sha256:1a7ead55c7e559dd4dee8856e3a88b41225abfe1ce8df57b7c13915fe121ffb8 @@ -702,7 +790,7 @@ typing-extensions==4.12.2 \ wheel==0.44.0 \ --hash=sha256:2376a90c98cc337d18623527a97c31797bd02bad0033d41547043a1cbfbe448f \ --hash=sha256:a29c3f2817e95ab89aa4660681ad547c0e9547f20e75b0562fe7723c9a2a9d49 - # via -r build/test-requirements.txt + # via -r build/requirements.in zipp==3.20.2 \ --hash=sha256:a817ac80d6cf4b23bf7f2828b7cabf326f15a001bea8b1f9b49631780ba28350 \ --hash=sha256:bc9eb26f4506fda01b81bcde0ca78103b6e62f991b381fec825435c836edbc29 @@ -805,12 +893,10 @@ zstandard==0.23.0 \ --hash=sha256:fd30d9c67d13d891f2360b2a120186729c111238ac63b43dbd37a5a40670b8ca \ --hash=sha256:fd7699e8fd9969f455ef2926221e0233f81a2542921471382e77a9e2f2b57f4b \ --hash=sha256:fe3b385d996ee0822fd46528d9f0443b880d4d05528fd26a9119a54ec3f91c69 - # via -r build/requirements.in + # via -r build/nonfreethreading-requirements.txt # The following packages are considered to be unsafe in a requirements file: setuptools==76.0.0 \ --hash=sha256:199466a166ff664970d0ee145839f5582cb9bca7a0a3a2e795b6a9cb2308e9c6 \ --hash=sha256:43b4ee60e10b0d0ee98ad11918e114c70701bc6051662a9a675a0496c1a158f4 - # via - # -r build/requirements.in - # -r build/test-requirements.txt + # via -r build/requirements.in diff --git a/build/requirements_lock_3_13_ft.txt b/build/requirements_lock_3_13_ft.txt index e7a2968e981e..348fa9628a76 100644 --- 
a/build/requirements_lock_3_13_ft.txt +++ b/build/requirements_lock_3_13_ft.txt @@ -2,7 +2,7 @@ # This file is autogenerated by pip-compile with Python 3.13 # by the following command: # -# pip-compile --allow-unsafe --generate-hashes --output-file=build/requirements_lock_3_13_ft.txt build/requirements.in +# bazel run //build:requirements_ft.update # absl-py==2.1.0 \ --hash=sha256:526a04eadab8b4ee719ce68f204172ead1027549089702d99b9059f129ff1308 \ @@ -19,7 +19,7 @@ auditwheel==6.2.0 \ build==1.2.2.post1 \ --hash=sha256:1d61c0887fa860c01971625baae8bdd338e517b836a2f70dd1f7aa3a6b2fc5b5 \ --hash=sha256:b36993e92ca9375a219c99e606a122ff365a760a2d4bba0caa09bd5278b608b7 - # via -r build/test-requirements.txt + # via -r build/requirements.in cloudpickle==3.1.0 \ --hash=sha256:81a929b6e3c7335c863c771d673d105f02efdb89dfaba0c90495d1c64796601b \ --hash=sha256:fe11acda67f61aaaec473e3afe030feb131d78a43461b718185363384f1ba12e @@ -27,7 +27,7 @@ cloudpickle==3.1.0 \ colorama==0.4.6 \ --hash=sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44 \ --hash=sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6 - # via -r build/test-requirements.txt + # via -r build/requirements.in contourpy==1.3.1 \ --hash=sha256:041b640d4ec01922083645a94bb3b2e777e6b626788f4095cf21abbe266413c1 \ --hash=sha256:05e806338bfeaa006acbdeba0ad681a10be63b26e1b17317bfac3c5d98f36cda \ @@ -172,6 +172,44 @@ iniconfig==2.0.0 \ --hash=sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3 \ --hash=sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374 # via pytest +jax-cuda12-pjrt==0.6.2 ; sys_platform == "linux" \ + --hash=sha256:22faf020d2e8f7ca1e2915633241f7df7678b73c7078f5f0b2f113248337f7de \ + --hash=sha256:8cd9ead7948ea2c778a508fef5d1159e8b7abf4fccc7037c3fe1dbfcd95012dc + # via + # -r build/requirements.in + # jax-cuda12-plugin +jax-cuda12-plugin[with-cuda]==0.6.2 ; sys_platform == "linux" \ + --hash=sha256:0896cbb308d95291e205cd89d254029dee3a1df43d66e9831331a9afd2d27870 \ + --hash=sha256:1751f88989269b3cdb0dfe4f7b072a6442149818c9bc98c3a395c8acaf910a79 \ + --hash=sha256:2cd8e279a59a38ba0c978a831e13adeb6ee9e4572fba387c7975ba3ad535dd38 \ + --hash=sha256:6c9b002d13b1fcb9403713eedd3876a227ad1ffbdfb3811b1f9f89af4c25a5f7 \ + --hash=sha256:773efa8b55a837406c561f0ef02144dda9019181193760ec5419eec9dd2b9aac \ + --hash=sha256:83345f52f610cdb8e90044566d8e120864150b8090968c8ab6dd8e0bfb9a6a9f \ + --hash=sha256:bc5c3a75d05519b4d326e4669d0f7ad0fe0f0acf875f9313d913748ccca5a9ea \ + --hash=sha256:db4c6103c912d8cd1adf94c34d313bb4760ca7f01c897ca7cd62e65f27994199 \ + --hash=sha256:ed5316ca1818db7ef53230ee0a41398d3a60942e361dfb857a952eb4d92fc8d7 \ + --hash=sha256:febd099f970d350eb8fa5a2c9a2fb4b0ea7b3d6a89df1496663edfa7afe590e5 + # via -r build/requirements.in +jaxlib==0.6.2 \ + --hash=sha256:11eae7e05bc5a79875da36324afb9eddd4baeaef2a0386caf6d4f3720b9aef28 \ + --hash=sha256:153eaa51f778b60851720729d4f461a91edd9ba3932f6f3bc598d4413870038b \ + --hash=sha256:335d7e3515ce78b52a410136f46aa4a7ea14d0e7d640f34e1e137409554ad0ac \ + --hash=sha256:34d8a684a8be949dd87dd4acc97101b4106a0dc9ad151ec891da072319a57b99 \ + --hash=sha256:39cf9555f85ae1ce2e2c1a59fc71f2eca4f9867a7cb934fef881ba56b11371d1 \ + --hash=sha256:3abd536e44b05fb1657507e3ff1fc3691f99613bae3921ecab9e82f27255f784 \ + --hash=sha256:4205d098ce8efb5f7fe2fe5098bae6036094dc8d8829f5e0e0d7a9b155326336 \ + --hash=sha256:70498837caf538bd458ff6858c8bfd404db82015aba8f663670197fa9900ff02 \ + 
--hash=sha256:87ec2dc9c3ed9ab936eec8535160c5fbd2c849948559f1c5daa75f63fabe5942 \ + --hash=sha256:921dbd4db214eba19a29ba9f2450d880e08b2b2c7b968f28cc89da3e62366af4 \ + --hash=sha256:a208ff61c58128d306bb4e5ad0858bd2b0960f2c1c10ad42c548f74a60c0020e \ + --hash=sha256:b977604cd36c74b174d25ed685017379468138eb747d865f75e466cb273c801d \ + --hash=sha256:bff67b188133ce1f0111c7b163ac321fd646b59ed221ea489063e2e0f85cb967 \ + --hash=sha256:c087a0eb6fb7f6f8f54d56f4730328dfde5040dd3b5ddfa810e7c28ea7102b42 \ + --hash=sha256:c6815509997d6b05e5c9daa7994b9ad473ce3e8c8a17bdbbcacc3c744f76f7a0 \ + --hash=sha256:da4601b2b5dc8c23d6afb293eacfb9aec4e1d1871cb2f29c5a151d103e73b0f8 \ + --hash=sha256:f1dd09b481a93c1d4c750013f467f74194493ba7bd29fcd4d1cec16e3a214f65 \ + --hash=sha256:f94163f14c8fd3ba93ae14b631abacf14cb031bba0b59138869984b4d10375f8 + # via -r build/requirements.in kiwisolver==1.4.8 \ --hash=sha256:01c3d31902c7db5fb6182832713d3b4122ad9317c2c5877d0539227d96bb2e50 \ --hash=sha256:034d2c891f76bd3edbdb3ea11140d8510dca675443da7304205a2eaa45d8334c \ @@ -254,11 +292,14 @@ kiwisolver==1.4.8 \ --hash=sha256:ed33ca2002a779a2e20eeb06aea7721b6e47f2d4b8a8ece979d8ba9e2a167e34 \ --hash=sha256:fc2ace710ba7c1dfd1a3b42530b62b9ceed115f19a1656adefce7b1782a37794 # via matplotlib +libtpu==0.0.13 ; sys_platform == "linux" and platform_machine == "x86_64" \ + --hash=sha256:2b4fcd3b902433ef2c22760a3a13b1474491bb4daf88a2670c6c72b295ebe750 + # via -r build/requirements.in markdown-it-py==3.0.0 \ --hash=sha256:355216845c60bd96232cd8d8c40e8f9765cc86f46880e43a8fd22dc1a1a8cab1 \ --hash=sha256:e3f60a94fa066dc52ec76661e37c851cb232d92f9886b15cb560aaada2df8feb # via rich -matplotlib==3.10.0 ; python_version >= "3.11" \ +matplotlib==3.10.0 \ --hash=sha256:01d2b19f13aeec2e759414d3bfe19ddfb16b13a1250add08d46d5ff6f9be83c6 \ --hash=sha256:12eaf48463b472c3c0f8dbacdbf906e573013df81a0ab82f0616ea4b11281908 \ --hash=sha256:2c5829a5a1dd5a71f0e31e6e8bb449bc0ee9dbfb05ad28fc0c6b55101b3a4be6 \ @@ -323,137 +364,151 @@ ml-dtypes==0.5.1 \ --hash=sha256:c9945669d3dadf8acb40ec2e57d38c985d8c285ea73af57fc5b09872c516106d \ --hash=sha256:d13755f8e8445b3870114e5b6240facaa7cb0c3361e54beba3e07fa912a6e12b \ --hash=sha256:fd918d4e6a4e0c110e2e05be7a7814d10dc1b95872accbf6512b80a109b71ae1 - # via -r build/requirements.in + # via + # -r build/requirements.in + # jaxlib mpmath==1.3.0 \ --hash=sha256:7a28eb2a9774d00c7bc92411c19a89209d5da7c4c9a9e227be8330a23a25b91f \ --hash=sha256:a0b2b9fe80bbcd81a6647ff13108738cfb482d481d826cc0e02f5b35e5c88d2c # via -r build/test-requirements.txt -numpy==2.2.1 ; python_version >= "3.13" \ - --hash=sha256:059e6a747ae84fce488c3ee397cee7e5f905fd1bda5fb18c66bc41807ff119b2 \ - --hash=sha256:08ef779aed40dbc52729d6ffe7dd51df85796a702afbf68a4f4e41fafdc8bda5 \ - --hash=sha256:164a829b6aacf79ca47ba4814b130c4020b202522a93d7bff2202bfb33b61c60 \ - --hash=sha256:26c9c4382b19fcfbbed3238a14abf7ff223890ea1936b8890f058e7ba35e8d71 \ - --hash=sha256:27f5cdf9f493b35f7e41e8368e7d7b4bbafaf9660cba53fb21d2cd174ec09631 \ - --hash=sha256:31b89fa67a8042e96715c68e071a1200c4e172f93b0fbe01a14c0ff3ff820fc8 \ - --hash=sha256:32cb94448be47c500d2c7a95f93e2f21a01f1fd05dd2beea1ccd049bb6001cd2 \ - --hash=sha256:360137f8fb1b753c5cde3ac388597ad680eccbbbb3865ab65efea062c4a1fd16 \ - --hash=sha256:3683a8d166f2692664262fd4900f207791d005fb088d7fdb973cc8d663626faa \ - --hash=sha256:38efc1e56b73cc9b182fe55e56e63b044dd26a72128fd2fbd502f75555d92591 \ - --hash=sha256:3d03883435a19794e41f147612a77a8f56d4e52822337844fff3d4040a142964 \ - 
--hash=sha256:3ecc47cd7f6ea0336042be87d9e7da378e5c7e9b3c8ad0f7c966f714fc10d821 \ - --hash=sha256:40f9e544c1c56ba8f1cf7686a8c9b5bb249e665d40d626a23899ba6d5d9e1484 \ - --hash=sha256:4250888bcb96617e00bfa28ac24850a83c9f3a16db471eca2ee1f1714df0f957 \ - --hash=sha256:4511d9e6071452b944207c8ce46ad2f897307910b402ea5fa975da32e0102800 \ - --hash=sha256:45681fd7128c8ad1c379f0ca0776a8b0c6583d2f69889ddac01559dfe4390918 \ - --hash=sha256:48fd472630715e1c1c89bf1feab55c29098cb403cc184b4859f9c86d4fcb6a95 \ - --hash=sha256:4c86e2a209199ead7ee0af65e1d9992d1dce7e1f63c4b9a616500f93820658d0 \ - --hash=sha256:4dfda918a13cc4f81e9118dea249e192ab167a0bb1966272d5503e39234d694e \ - --hash=sha256:5062dc1a4e32a10dc2b8b13cedd58988261416e811c1dc4dbdea4f57eea61b0d \ - --hash=sha256:51faf345324db860b515d3f364eaa93d0e0551a88d6218a7d61286554d190d73 \ - --hash=sha256:526fc406ab991a340744aad7e25251dd47a6720a685fa3331e5c59fef5282a59 \ - --hash=sha256:53c09385ff0b72ba79d8715683c1168c12e0b6e84fb0372e97553d1ea91efe51 \ - --hash=sha256:55ba24ebe208344aa7a00e4482f65742969a039c2acfcb910bc6fcd776eb4355 \ - --hash=sha256:5b6c390bfaef8c45a260554888966618328d30e72173697e5cabe6b285fb2348 \ - --hash=sha256:5c5cc0cbabe9452038ed984d05ac87910f89370b9242371bd9079cb4af61811e \ - --hash=sha256:5edb4e4caf751c1518e6a26a83501fda79bff41cc59dac48d70e6d65d4ec4440 \ - --hash=sha256:61048b4a49b1c93fe13426e04e04fdf5a03f456616f6e98c7576144677598675 \ - --hash=sha256:676f4eebf6b2d430300f1f4f4c2461685f8269f94c89698d832cdf9277f30b84 \ - --hash=sha256:67d4cda6fa6ffa073b08c8372aa5fa767ceb10c9a0587c707505a6d426f4e046 \ - --hash=sha256:694f9e921a0c8f252980e85bce61ebbd07ed2b7d4fa72d0e4246f2f8aa6642ab \ - --hash=sha256:733585f9f4b62e9b3528dd1070ec4f52b8acf64215b60a845fa13ebd73cd0712 \ - --hash=sha256:7671dc19c7019103ca44e8d94917eba8534c76133523ca8406822efdd19c9308 \ - --hash=sha256:780077d95eafc2ccc3ced969db22377b3864e5b9a0ea5eb347cc93b3ea900315 \ - --hash=sha256:7ba9cc93a91d86365a5d270dee221fdc04fb68d7478e6bf6af650de78a8339e3 \ - --hash=sha256:89b16a18e7bba224ce5114db863e7029803c179979e1af6ad6a6b11f70545008 \ - --hash=sha256:9036d6365d13b6cbe8f27a0eaf73ddcc070cae584e5ff94bb45e3e9d729feab5 \ - --hash=sha256:93cf4e045bae74c90ca833cba583c14b62cb4ba2cba0abd2b141ab52548247e2 \ - --hash=sha256:9ad014faa93dbb52c80d8f4d3dcf855865c876c9660cb9bd7553843dd03a4b1e \ - --hash=sha256:9b1d07b53b78bf84a96898c1bc139ad7f10fda7423f5fd158fd0f47ec5e01ac7 \ - --hash=sha256:a7746f235c47abc72b102d3bce9977714c2444bdfaea7888d241b4c4bb6a78bf \ - --hash=sha256:aa3017c40d513ccac9621a2364f939d39e550c542eb2a894b4c8da92b38896ab \ - --hash=sha256:b34d87e8a3090ea626003f87f9392b3929a7bbf4104a05b6667348b6bd4bf1cd \ - --hash=sha256:b541032178a718c165a49638d28272b771053f628382d5e9d1c93df23ff58dbf \ - --hash=sha256:ba5511d8f31c033a5fcbda22dd5c813630af98c70b2661f2d2c654ae3cdfcfc8 \ - --hash=sha256:bc8a37ad5b22c08e2dbd27df2b3ef7e5c0864235805b1e718a235bcb200cf1cb \ - --hash=sha256:bff7d8ec20f5f42607599f9994770fa65d76edca264a87b5e4ea5629bce12268 \ - --hash=sha256:c1ad395cf254c4fbb5b2132fee391f361a6e8c1adbd28f2cd8e79308a615fe9d \ - --hash=sha256:f1d09e520217618e76396377c81fba6f290d5f926f50c35f3a5f72b01a0da780 \ - --hash=sha256:f3eac17d9ec51be534685ba877b6ab5edc3ab7ec95c8f163e5d7b39859524716 \ - --hash=sha256:f419290bc8968a46c4933158c91a0012b7a99bb2e465d5ef5293879742f8797e \ - --hash=sha256:f62aa6ee4eb43b024b0e5a01cf65a0bb078ef8c395e8713c6e8a12a697144528 \ - --hash=sha256:f74e6fdeb9a265624ec3a3918430205dff1df7e95a230779746a6af78bc615af \ - 
--hash=sha256:f9b57eaa3b0cd8db52049ed0330747b0364e899e8a606a624813452b8203d5f7 \ - --hash=sha256:fce4f615f8ca31b2e61aa0eb5865a21e14f5629515c9151850aa936c02a1ee51 +numpy==2.2.6 ; python_version == "3.13" \ + --hash=sha256:038613e9fb8c72b0a41f025a7e4c3f0b7a1b5d768ece4796b674c8f3fe13efff \ + --hash=sha256:0678000bb9ac1475cd454c6b8c799206af8107e310843532b04d49649c717a47 \ + --hash=sha256:0811bb762109d9708cca4d0b13c4f67146e3c3b7cf8d34018c722adb2d957c84 \ + --hash=sha256:0b605b275d7bd0c640cad4e5d30fa701a8d59302e127e5f79138ad62762c3e3d \ + --hash=sha256:0bca768cd85ae743b2affdc762d617eddf3bcf8724435498a1e80132d04879e6 \ + --hash=sha256:1bc23a79bfabc5d056d106f9befb8d50c31ced2fbc70eedb8155aec74a45798f \ + --hash=sha256:287cc3162b6f01463ccd86be154f284d0893d2b3ed7292439ea97eafa8170e0b \ + --hash=sha256:37c0ca431f82cd5fa716eca9506aefcabc247fb27ba69c5062a6d3ade8cf8f49 \ + --hash=sha256:37e990a01ae6ec7fe7fa1c26c55ecb672dd98b19c3d0e1d1f326fa13cb38d163 \ + --hash=sha256:389d771b1623ec92636b0786bc4ae56abafad4a4c513d36a55dce14bd9ce8571 \ + --hash=sha256:3d70692235e759f260c3d837193090014aebdf026dfd167834bcba43e30c2a42 \ + --hash=sha256:41c5a21f4a04fa86436124d388f6ed60a9343a6f767fced1a8a71c3fbca038ff \ + --hash=sha256:481b49095335f8eed42e39e8041327c05b0f6f4780488f61286ed3c01368d491 \ + --hash=sha256:4eeaae00d789f66c7a25ac5f34b71a7035bb474e679f410e5e1a94deb24cf2d4 \ + --hash=sha256:55a4d33fa519660d69614a9fad433be87e5252f4b03850642f88993f7b2ca566 \ + --hash=sha256:5a6429d4be8ca66d889b7cf70f536a397dc45ba6faeb5f8c5427935d9592e9cf \ + --hash=sha256:5bd4fc3ac8926b3819797a7c0e2631eb889b4118a9898c84f585a54d475b7e40 \ + --hash=sha256:5beb72339d9d4fa36522fc63802f469b13cdbe4fdab4a288f0c441b74272ebfd \ + --hash=sha256:6031dd6dfecc0cf9f668681a37648373bddd6421fff6c66ec1624eed0180ee06 \ + --hash=sha256:71594f7c51a18e728451bb50cc60a3ce4e6538822731b2933209a1f3614e9282 \ + --hash=sha256:74d4531beb257d2c3f4b261bfb0fc09e0f9ebb8842d82a7b4209415896adc680 \ + --hash=sha256:7befc596a7dc9da8a337f79802ee8adb30a552a94f792b9c9d18c840055907db \ + --hash=sha256:894b3a42502226a1cac872f840030665f33326fc3dac8e57c607905773cdcde3 \ + --hash=sha256:8e41fd67c52b86603a91c1a505ebaef50b3314de0213461c7a6e99c9a3beff90 \ + --hash=sha256:8e9ace4a37db23421249ed236fdcdd457d671e25146786dfc96835cd951aa7c1 \ + --hash=sha256:8fc377d995680230e83241d8a96def29f204b5782f371c532579b4f20607a289 \ + --hash=sha256:9551a499bf125c1d4f9e250377c1ee2eddd02e01eac6644c080162c0c51778ab \ + --hash=sha256:b0544343a702fa80c95ad5d3d608ea3599dd54d4632df855e4c8d24eb6ecfa1c \ + --hash=sha256:b093dd74e50a8cba3e873868d9e93a85b78e0daf2e98c6797566ad8044e8363d \ + --hash=sha256:b412caa66f72040e6d268491a59f2c43bf03eb6c96dd8f0307829feb7fa2b6fb \ + --hash=sha256:b4f13750ce79751586ae2eb824ba7e1e8dba64784086c98cdbbcc6a42112ce0d \ + --hash=sha256:b64d8d4d17135e00c8e346e0a738deb17e754230d7e0810ac5012750bbd85a5a \ + --hash=sha256:ba10f8411898fc418a521833e014a77d3ca01c15b0c6cdcce6a0d2897e6dbbdf \ + --hash=sha256:bd48227a919f1bafbdda0583705e547892342c26fb127219d60a5c36882609d1 \ + --hash=sha256:c1f9540be57940698ed329904db803cf7a402f3fc200bfe599334c9bd84a40b2 \ + --hash=sha256:c820a93b0255bc360f53eca31a0e676fd1101f673dda8da93454a12e23fc5f7a \ + --hash=sha256:ce47521a4754c8f4593837384bd3424880629f718d87c5d44f8ed763edd63543 \ + --hash=sha256:d042d24c90c41b54fd506da306759e06e568864df8ec17ccc17e9e884634fd00 \ + --hash=sha256:de749064336d37e340f640b05f24e9e3dd678c57318c7289d222a8a2f543e90c \ + --hash=sha256:e1dda9c7e08dc141e0247a5b8f49cf05984955246a327d4c48bda16821947b2f \ + 
--hash=sha256:e29554e2bef54a90aa5cc07da6ce955accb83f21ab5de01a62c8478897b264fd \ + --hash=sha256:e3143e4451880bed956e706a3220b4e5cf6172ef05fcc397f6f36a550b1dd868 \ + --hash=sha256:e8213002e427c69c45a52bbd94163084025f533a55a59d6f9c5b820774ef3303 \ + --hash=sha256:efd28d4e9cd7d7a8d39074a4d44c63eda73401580c5c76acda2ce969e0a38e83 \ + --hash=sha256:f0fd6321b839904e15c46e0d257fdd101dd7f530fe03fd6359c1ea63738703f3 \ + --hash=sha256:f1372f041402e37e5e633e586f62aa53de2eac8d98cbfb822806ce4bbefcb74d \ + --hash=sha256:f2618db89be1b4e05f7a1a847a9c1c0abd63e63a1607d892dd54668dd92faf87 \ + --hash=sha256:f447e6acb680fd307f40d3da4852208af94afdfab89cf850986c3ca00562f4fa \ + --hash=sha256:f92729c95468a2f4f15e9bb94c432a9229d0d50de67304399627a943201baa2f \ + --hash=sha256:f9f1adb22318e121c5c69a09142811a201ef17ab257a1e66ca3025065b7f53ae \ + --hash=sha256:fc0c5673685c508a142ca65209b4e79ed6740a4ed6b2267dbba90f34b0b3cfda \ + --hash=sha256:fc7b73d02efb0e18c000e9ad8b83480dfcd5dfd11065997ed4c6747470ae8915 \ + --hash=sha256:fd83c01228a688733f1ded5201c678f0c53ecc1006ffbc404db9f7a899ac6249 \ + --hash=sha256:fe27749d33bb772c80dcd84ae7e8df2adc920ae8297400dabec45f0dedb3f6de \ + --hash=sha256:fee4236c876c4e8369388054d02d0e9bb84821feb1a64dd59e137e6511a551f8 # via - # -r build/requirements.in + # -r build/freethreading-requirements.txt # contourpy + # jaxlib # matplotlib # ml-dtypes # scipy -nvidia-cublas-cu12==12.8.3.14 ; sys_platform == "linux" \ +nvidia-cublas-cu12==12.8.3.14 \ --hash=sha256:3f0e05e7293598cf61933258b73e66a160c27d59c4422670bf0b79348c04be44 \ --hash=sha256:93a4e0e386cc7f6e56c822531396de8170ed17068a1e18f987574895044cd8c3 \ --hash=sha256:9ae5eae500aead01fc4bdfc458209df638b1a3551557ce11a78eea9ece602ae9 # via - # -r build/test-requirements.txt + # jax-cuda12-plugin # nvidia-cudnn-cu12 # nvidia-cusolver-cu12 -nvidia-cuda-cupti-cu12==12.8.57 ; sys_platform == "linux" \ +nvidia-cuda-cupti-cu12==12.8.57 \ --hash=sha256:8e0b2eb847de260739bee4a3f66fac31378f4ff49538ff527a38a01a9a39f950 \ --hash=sha256:bbed719c52a476958a74cfc42f2b95a3fd6b3fd94eb40134acc4601feb4acac3 \ --hash=sha256:ff154211724fd824e758ce176b66007b558eea19c9a5135fc991827ee147e317 - # via -r build/test-requirements.txt -nvidia-cuda-nvcc-cu12==12.8.61 ; sys_platform == "linux" \ + # via jax-cuda12-plugin +nvidia-cuda-nvcc-cu12==12.8.61 \ --hash=sha256:171f605044ba17bc455d19cad289946c3dbea029a90c60dfa7b88e545bc8e329 \ --hash=sha256:28604ec42aaa09035b0fb7111432e5121bc385580b30c55d2acfb7d644b16548 \ --hash=sha256:4524739cfc080e9c9e53032912be8f020058e0a7186746d19acef3b6d916ea0b - # via -r build/test-requirements.txt -nvidia-cuda-runtime-cu12==12.8.57 ; sys_platform == "linux" \ + # via jax-cuda12-plugin +nvidia-cuda-nvrtc-cu12==12.9.86 ; sys_platform == "linux" \ + --hash=sha256:096d4de6bda726415dfaf3198d4f5c522b8e70139c97feef5cd2ca6d4cd9cead \ + --hash=sha256:210cf05005a447e29214e9ce50851e83fc5f4358df8b453155d5e1918094dcb4 \ + --hash=sha256:72972ebdcf504d69462d3bcd67e7b81edd25d0fb85a2c46d3ea3517666636349 + # via + # -r build/requirements.in + # jax-cuda12-plugin +nvidia-cuda-runtime-cu12==12.8.57 \ --hash=sha256:534ccebd967b6a44292678fa5da4f00666029cb2ed07a79515ea41ef31fe3ec7 \ --hash=sha256:75342e28567340b7428ce79a5d6bb6ca5ff9d07b69e7ce00d2c7b4dc23eff0be \ --hash=sha256:89be637e3ee967323865b85e0f147d75f9a5bd98360befa37481b02dd57af8f5 - # via -r build/test-requirements.txt -nvidia-cudnn-cu12==9.7.1.26 ; sys_platform == "linux" \ - --hash=sha256:6d011159a158f3cfc47bf851aea79e31bcff60d530b70ef70474c84cac484d07 \ - 
--hash=sha256:7b805b9a4cf9f3da7c5f4ea4a9dff7baf62d1a612d6154a7e0d2ea51ed296241 \ - --hash=sha256:848a61d40ef3b32bd4e1fadb599f0cf04a4b942fbe5fb3be572ad75f9b8c53ef - # via -r build/test-requirements.txt -nvidia-cufft-cu12==11.3.3.41 ; sys_platform == "linux" \ + # via jax-cuda12-plugin +nvidia-cudnn-cu12==9.8.0.87 \ + --hash=sha256:b4b5cfddc32aa4180f9d390ee99e9a9f55a89e7087329b41aba4319327e22466 \ + --hash=sha256:b883faeb2f6f15dba7bbb6756eab6a0d9cecb59db5b0fa07577b9cfa24cd99f4 \ + --hash=sha256:d6b02cd0e3e24aa31d0193a8c39fec239354360d7d81055edddb69f35d53a4c8 + # via jax-cuda12-plugin +nvidia-cufft-cu12==11.3.3.41 \ --hash=sha256:68509dcd7e3306e69d0e2d8a6d21c8b25ed62e6df8aac192ce752f17677398b5 \ --hash=sha256:da650080ab79fcdf7a4b06aa1b460e99860646b176a43f6208099bdc17836b6a \ --hash=sha256:f9760612886786601d27a0993bb29ce1f757e6b8b173499d0ecfa850d31b50f8 - # via -r build/test-requirements.txt -nvidia-cusolver-cu12==11.7.2.55 ; sys_platform == "linux" \ + # via jax-cuda12-plugin +nvidia-cusolver-cu12==11.7.2.55 \ --hash=sha256:0fd9e98246f43c15bee5561147ad235dfdf2d037f5d07c9d41af3f7f72feb7cc \ --hash=sha256:4d1354102f1e922cee9db51920dba9e2559877cf6ff5ad03a00d853adafb191b \ --hash=sha256:a5a516c55da5c5aba98420d9bc9bcab18245f21ec87338cc1f930eb18dd411ac - # via -r build/test-requirements.txt -nvidia-cusparse-cu12==12.5.7.53 ; sys_platform == "linux" \ + # via jax-cuda12-plugin +nvidia-cusparse-cu12==12.5.7.53 \ --hash=sha256:3c1b61eb8c85257ea07e9354606b26397612627fdcd327bfd91ccf6155e7c86d \ --hash=sha256:82c201d6781bacf6bb7c654f0446728d0fe596dfdd82ef4a04c204ce3e107441 \ --hash=sha256:d869c6146ca80f4305b62e02d924b4aaced936f8173e3cef536a67eed2a91af1 # via - # -r build/test-requirements.txt + # jax-cuda12-plugin # nvidia-cusolver-cu12 -nvidia-nccl-cu12==2.25.1 ; sys_platform == "linux" \ +nvidia-nccl-cu12==2.25.1 \ --hash=sha256:362aed5963fb9ea2ed2f264409baae30143498fd0e5c503aeaa1badd88cdc54a \ --hash=sha256:4ab428bc915785cc66e8c57cb34c7a64cf739c46702b8db748b6ad6cc7180cf8 - # via -r build/test-requirements.txt -nvidia-nvjitlink-cu12==12.8.61 ; sys_platform == "linux" \ + # via jax-cuda12-plugin +nvidia-nvjitlink-cu12==12.8.61 \ --hash=sha256:1166a964d25fdc0eae497574d38824305195a5283324a21ccb0ce0c802cbf41c \ --hash=sha256:45fd79f2ae20bd67e8bc411055939049873bfd8fac70ff13bd4865e0b9bdab17 \ --hash=sha256:9b80ecab31085dda3ce3b41d043be0ec739216c3fc633b8abe212d5a30026df0 # via - # -r build/test-requirements.txt + # jax-cuda12-plugin # nvidia-cufft-cu12 # nvidia-cusolver-cu12 # nvidia-cusparse-cu12 +nvidia-nvshmem-cu12==3.2.5 ; sys_platform == "linux" \ + --hash=sha256:2f5798d65f1a08f9878aae17cf4d3dcbfe884d1f12cf170556cd40f2be90ca96 \ + --hash=sha256:e076957d5cc72e51061a04f2d46f55df477be53e8a55d0d621be08f7aefe1d00 + # via + # -r build/requirements.in + # jax-cuda12-plugin opt-einsum==3.4.0 \ --hash=sha256:69bb92469f86a1565195ece4ac0323943e83477171b91d24c35afe028a90d7cd \ --hash=sha256:96ca72f1b886d148241348783498194c577fa30a8faac108586b14f1ba4473ac - # via - # -r build/test-requirements.txt - # -r build/requirements.in + # via -r build/requirements.in packaging==24.2 \ --hash=sha256:09abb1bccd265c01f4a3aa3f7a7db064b36514d2cba19a2f694fe6150451a759 \ --hash=sha256:c228a6dc5e932d346bc5739379109d49e8853dd8223571c7c5b55260edc0b97f @@ -541,29 +596,6 @@ pluggy==1.5.0 \ --hash=sha256:2cffa88e94fdc978c4c574f15f9e59b7f4201d439195c3715ca9e2486f1d0cf1 \ --hash=sha256:44e1ad92c8ca002de6377e165f3e0f1be63266ab4d554740532335b9d75ea669 # via pytest -portpicker==1.6.0 \ - 
--hash=sha256:b2787a41404cf7edbe29b07b9e0ed863b09f2665dcc01c1eb0c2261c1e7d0755 \ - --hash=sha256:bd507fd6f96f65ee02781f2e674e9dc6c99bbfa6e3c39992e3916204c9d431fa - # via -r build/test-requirements.txt -psutil==6.1.1 \ - --hash=sha256:018aeae2af92d943fdf1da6b58665124897cfc94faa2ca92098838f83e1b1bca \ - --hash=sha256:0bdd4eab935276290ad3cb718e9809412895ca6b5b334f5a9111ee6d9aff9377 \ - --hash=sha256:1924e659d6c19c647e763e78670a05dbb7feaf44a0e9c94bf9e14dfc6ba50468 \ - --hash=sha256:33431e84fee02bc84ea36d9e2c4a6d395d479c9dd9bba2376c1f6ee8f3a4e0b3 \ - --hash=sha256:384636b1a64b47814437d1173be1427a7c83681b17a450bfc309a1953e329603 \ - --hash=sha256:6d4281f5bbca041e2292be3380ec56a9413b790579b8e593b1784499d0005dac \ - --hash=sha256:8be07491f6ebe1a693f17d4f11e69d0dc1811fa082736500f649f79df7735303 \ - --hash=sha256:8df0178ba8a9e5bc84fed9cfa61d54601b371fbec5c8eebad27575f1e105c0d4 \ - --hash=sha256:97f7cb9921fbec4904f522d972f0c0e1f4fabbdd4e0287813b21215074a0f160 \ - --hash=sha256:9ccc4316f24409159897799b83004cb1e24f9819b0dcf9c0b68bdcb6cefee6a8 \ - --hash=sha256:b6e06c20c05fe95a3d7302d74e7097756d4ba1247975ad6905441ae1b5b66003 \ - --hash=sha256:c777eb75bb33c47377c9af68f30e9f11bc78e0f07fbf907be4a5d70b2fe5f030 \ - --hash=sha256:ca9609c77ea3b8481ab005da74ed894035936223422dc591d6772b147421f777 \ - --hash=sha256:cf8496728c18f2d0b45198f06895be52f36611711746b7f30c464b422b50e2f5 \ - --hash=sha256:eaa912e0b11848c4d9279a93d7e2783df352b082f40111e078388701fd479e53 \ - --hash=sha256:f35cfccb065fff93529d2afb4a2e89e363fe63ca1e4a5da22b603a85833c2649 \ - --hash=sha256:fc0ed7fe2231a444fc219b9c42d0376e0a9a1a72f16c5cfa0f68d19f1a0663e8 - # via portpicker pyelftools==0.31 \ --hash=sha256:c774416b10310156879443b81187d182d8d9ee499660380e645918b50bc88f99 \ --hash=sha256:f52de7b3c7e8c64c8abc04a79a1cf37ac5fb0b8a49809827130b858944840607 @@ -596,48 +628,56 @@ rich==13.9.4 \ --hash=sha256:439594978a49a09530cff7ebc4b5c7103ef57baf48d5ea3184f21d9a2befa098 \ --hash=sha256:6049d5e6ec054bf2779ab3358186963bac2ea89175919d699e378b99738c2a90 # via -r build/test-requirements.txt -scipy==1.15.0 \ - --hash=sha256:0e5b34f8894f9904cc578008d1a9467829c1817e9f9cb45e6d6eeb61d2ab7731 \ - --hash=sha256:0fcb16eb04d84670722ce8d93b05257df471704c913cb0ff9dc5a1c31d1e9422 \ - --hash=sha256:129f899ed275c0515d553b8d31696924e2ca87d1972421e46c376b9eb87de3d2 \ - --hash=sha256:161f80a98047c219c257bf5ce1777c574bde36b9d962a46b20d0d7e531f86863 \ - --hash=sha256:1b29e4fc02e155a5fd1165f1e6a73edfdd110470736b0f48bcbe48083f0eee37 \ - --hash=sha256:1e2448acd79c6374583581a1ded32ac71a00c2b9c62dfa87a40e1dd2520be111 \ - --hash=sha256:2240e1fd0782e62e1aacdc7234212ee271d810f67e9cd3b8d521003a82603ef8 \ - --hash=sha256:300742e2cc94e36a2880ebe464a1c8b4352a7b0f3e36ec3d2ac006cdbe0219ac \ - --hash=sha256:327163ad73e54541a675240708244644294cb0a65cca420c9c79baeb9648e479 \ - --hash=sha256:351899dd2a801edd3691622172bc8ea01064b1cada794f8641b89a7dc5418db6 \ - --hash=sha256:35c68f7044b4e7ad73a3e68e513dda946989e523df9b062bd3cf401a1a882192 \ - --hash=sha256:36be480e512d38db67f377add5b759fb117edd987f4791cdf58e59b26962bee4 \ - --hash=sha256:37ce9394cdcd7c5f437583fc6ef91bd290014993900643fdfc7af9b052d1613b \ - --hash=sha256:46e91b5b16909ff79224b56e19cbad65ca500b3afda69225820aa3afbf9ec020 \ - --hash=sha256:4e08c6a36f46abaedf765dd2dfcd3698fa4bd7e311a9abb2d80e33d9b2d72c34 \ - --hash=sha256:52475011be29dfcbecc3dfe3060e471ac5155d72e9233e8d5616b84e2b542054 \ - --hash=sha256:5972e3f96f7dda4fd3bb85906a17338e65eaddfe47f750e240f22b331c08858e \ - 
--hash=sha256:5abbdc6ede5c5fed7910cf406a948e2c0869231c0db091593a6b2fa78be77e5d \ - --hash=sha256:5beb0a2200372b7416ec73fdae94fe81a6e85e44eb49c35a11ac356d2b8eccc6 \ - --hash=sha256:61513b989ee8d5218fbeb178b2d51534ecaddba050db949ae99eeb3d12f6825d \ - --hash=sha256:6d26f17c64abd6c6c2dfb39920f61518cc9e213d034b45b2380e32ba78fde4c0 \ - --hash=sha256:6f376d7c767731477bac25a85d0118efdc94a572c6b60decb1ee48bf2391a73b \ - --hash=sha256:767e8cf6562931f8312f4faa7ddea412cb783d8df49e62c44d00d89f41f9bbe8 \ - --hash=sha256:82bff2eb01ccf7cea8b6ee5274c2dbeadfdac97919da308ee6d8e5bcbe846443 \ - --hash=sha256:952d2e9eaa787f0a9e95b6e85da3654791b57a156c3e6609e65cc5176ccfe6f2 \ - --hash=sha256:9c8254fe21dd2c6c8f7757035ec0c31daecf3bb3cffd93bc1ca661b731d28136 \ - --hash=sha256:aeac60d3562a7bf2f35549bdfdb6b1751c50590f55ce7322b4b2fc821dc27fca \ - --hash=sha256:b1432102254b6dc7766d081fa92df87832ac25ff0b3d3a940f37276e63eb74ff \ - --hash=sha256:bdca4c7bb8dc41307e5f39e9e5d19c707d8e20a29845e7533b3bb20a9d4ccba0 \ - --hash=sha256:c9624eeae79b18cab1a31944b5ef87aa14b125d6ab69b71db22f0dbd962caf1e \ - --hash=sha256:ccb6248a9987193fe74363a2d73b93bc2c546e0728bd786050b7aef6e17db03c \ - --hash=sha256:cd9d9198a7fd9a77f0eb5105ea9734df26f41faeb2a88a0e62e5245506f7b6df \ - --hash=sha256:d13bbc0658c11f3d19df4138336e4bce2c4fbd78c2755be4bf7b8e235481557f \ - --hash=sha256:d35aef233b098e4de88b1eac29f0df378278e7e250a915766786b773309137c4 \ - --hash=sha256:de112c2dae53107cfeaf65101419662ac0a54e9a088c17958b51c95dac5de56d \ - --hash=sha256:e9baff912ea4f78a543d183ed6f5b3bea9784509b948227daaf6f10727a0e2e5 \ - --hash=sha256:eb1533c59f0ec6c55871206f15a5c72d1fae7ad3c0a8ca33ca88f7c309bbbf8c \ - --hash=sha256:ec915cd26d76f6fc7ae8522f74f5b2accf39546f341c771bb2297f3871934a52 \ - --hash=sha256:fde0f3104dfa1dfbc1f230f65506532d0558d43188789eaf68f97e106249a913 \ - --hash=sha256:fe00169cf875bed0b3c40e4da45b57037dc21d7c7bf0c85ed75f210c281488f1 - # via -r build/requirements.in +scipy==1.15.2 ; python_version >= "3.13" \ + --hash=sha256:01edfac9f0798ad6b46d9c4c9ca0e0ad23dbf0b1eb70e96adb9fa7f525eff0bf \ + --hash=sha256:03205d57a28e18dfd39f0377d5002725bf1f19a46f444108c29bdb246b6c8a11 \ + --hash=sha256:08b57a9336b8e79b305a143c3655cc5bdbe6d5ece3378578888d2afbb51c4e37 \ + --hash=sha256:11e7ad32cf184b74380f43d3c0a706f49358b904fa7d5345f16ddf993609184d \ + --hash=sha256:28a0d2c2075946346e4408b211240764759e0fabaeb08d871639b5f3b1aca8a0 \ + --hash=sha256:2b871df1fe1a3ba85d90e22742b93584f8d2b8e6124f8372ab15c71b73e428b8 \ + --hash=sha256:302093e7dfb120e55515936cb55618ee0b895f8bcaf18ff81eca086c17bd80af \ + --hash=sha256:42dabaaa798e987c425ed76062794e93a243be8f0f20fff6e7a89f4d61cb3d40 \ + --hash=sha256:447ce30cee6a9d5d1379087c9e474628dab3db4a67484be1b7dc3196bfb2fac9 \ + --hash=sha256:4c6676490ad76d1c2894d77f976144b41bd1a4052107902238047fb6a473e971 \ + --hash=sha256:54c462098484e7466362a9f1672d20888f724911a74c22ae35b61f9c5919183d \ + --hash=sha256:597a0c7008b21c035831c39927406c6181bcf8f60a73f36219b69d010aa04737 \ + --hash=sha256:5a6fd6eac1ce74a9f77a7fc724080d507c5812d61e72bd5e4c489b042455865e \ + --hash=sha256:5ea7ed46d437fc52350b028b1d44e002646e28f3e8ddc714011aaf87330f2f32 \ + --hash=sha256:601881dfb761311045b03114c5fe718a12634e5608c3b403737ae463c9885d53 \ + --hash=sha256:62ca1ff3eb513e09ed17a5736929429189adf16d2d740f44e53270cc800ecff1 \ + --hash=sha256:69ea6e56d00977f355c0f84eba69877b6df084516c602d93a33812aa04d90a3d \ + --hash=sha256:6a8e34cf4c188b6dd004654f88586d78f95639e48a25dfae9c5e34a6dc34547e \ + 
--hash=sha256:6d0194c37037707b2afa7a2f2a924cf7bac3dc292d51b6a925e5fcb89bc5c776 \ + --hash=sha256:6f223753c6ea76983af380787611ae1291e3ceb23917393079dcc746ba60cfb5 \ + --hash=sha256:6f5e296ec63c5da6ba6fa0343ea73fd51b8b3e1a300b0a8cae3ed4b1122c7462 \ + --hash=sha256:7cd5b77413e1855351cdde594eca99c1f4a588c2d63711388b6a1f1c01f62274 \ + --hash=sha256:869269b767d5ee7ea6991ed7e22b3ca1f22de73ab9a49c44bad338b725603301 \ + --hash=sha256:87994da02e73549dfecaed9e09a4f9d58a045a053865679aeb8d6d43747d4df3 \ + --hash=sha256:888307125ea0c4466287191e5606a2c910963405ce9671448ff9c81c53f85f58 \ + --hash=sha256:92233b2df6938147be6fa8824b8136f29a18f016ecde986666be5f4d686a91a4 \ + --hash=sha256:9412f5e408b397ff5641080ed1e798623dbe1ec0d78e72c9eca8992976fa65aa \ + --hash=sha256:9b18aa747da280664642997e65aab1dd19d0c3d17068a04b3fe34e2559196cb9 \ + --hash=sha256:9de9d1416b3d9e7df9923ab23cd2fe714244af10b763975bea9e4f2e81cebd27 \ + --hash=sha256:a2ec871edaa863e8213ea5df811cd600734f6400b4af272e1c011e69401218e9 \ + --hash=sha256:a5080a79dfb9b78b768cebf3c9dcbc7b665c5875793569f48bf0e2b1d7f68f6f \ + --hash=sha256:a8bf5cb4a25046ac61d38f8d3c3426ec11ebc350246a4642f2f315fe95bda655 \ + --hash=sha256:b09ae80010f52efddb15551025f9016c910296cf70adbf03ce2a8704f3a5ad20 \ + --hash=sha256:b5e025e903b4f166ea03b109bb241355b9c42c279ea694d8864d033727205e65 \ + --hash=sha256:bad78d580270a4d32470563ea86c6590b465cb98f83d760ff5b0990cb5518a93 \ + --hash=sha256:bae43364d600fdc3ac327db99659dcb79e6e7ecd279a75fe1266669d9a652828 \ + --hash=sha256:c4697a10da8f8765bb7c83e24a470da5797e37041edfd77fd95ba3811a47c4fd \ + --hash=sha256:c90ebe8aaa4397eaefa8455a8182b164a6cc1d59ad53f79943f266d99f68687f \ + --hash=sha256:cd58a314d92838f7e6f755c8a2167ead4f27e1fd5c1251fd54289569ef3495ec \ + --hash=sha256:cf72ff559a53a6a6d77bd8eefd12a17995ffa44ad86c77a5df96f533d4e6c6bb \ + --hash=sha256:def751dd08243934c884a3221156d63e15234a3155cf25978b0a668409d45eb6 \ + --hash=sha256:e7c68b6a43259ba0aab737237876e5c2c549a031ddb7abc28c7b47f22e202ded \ + --hash=sha256:ecf797d2d798cf7c838c6d98321061eb3e72a74710e6c40540f0e8087e3b499e \ + --hash=sha256:f031846580d9acccd0044efd1a90e6f4df3a6e12b4b6bd694a7bc03a89892b28 \ + --hash=sha256:fb530e4794fc8ea76a4a21ccb67dea33e5e0e60f07fc38a49e821e1eae3b71a0 \ + --hash=sha256:fe8a9eb875d430d81755472c5ba75e84acc980e4a8f6204d402849234d3017db + # via + # -r build/requirements.in + # jaxlib six==1.17.0 \ --hash=sha256:4721f391ed90541fddacab5acf947aa0d3dc7d27b2e1e8eda2be8970586c3274 \ --hash=sha256:ff70335d468e7eb6ec65b95b99d3a2836546063f63acc5171de367e834932a81 @@ -653,117 +693,14 @@ typing-extensions==4.12.2 \ wheel==0.45.1 \ --hash=sha256:661e1abd9198507b1409a20c02106d9670b2576e916d58f520316666abca6729 \ --hash=sha256:708e7481cc80179af0e556bbf0cc00b8444c7321e2700b8d8580231d13017248 - # via -r build/test-requirements.txt + # via -r build/requirements.in zipp==3.21.0 \ --hash=sha256:2c9958f6430a2040341a52eb608ed6dd93ef4392e02ffe219417c1b28b5dd1f4 \ --hash=sha256:ac1bbe05fd2991f160ebce24ffbac5f6d11d83dc90891255885223d42b3cd931 # via etils -# python 3.13t can compile 0.23.0 -# due to https://github.com/indygreg/python-zstandard/issues/231 -# zstandard==0.23.0 \ -# --hash=sha256:034b88913ecc1b097f528e42b539453fa82c3557e414b3de9d5632c80439a473 \ -# --hash=sha256:0a7f0804bb3799414af278e9ad51be25edf67f78f916e08afdb983e74161b916 \ -# --hash=sha256:11e3bf3c924853a2d5835b24f03eeba7fc9b07d8ca499e247e06ff5676461a15 \ -# --hash=sha256:12a289832e520c6bd4dcaad68e944b86da3bad0d339ef7989fb7e88f92e96072 \ -# 
--hash=sha256:1516c8c37d3a053b01c1c15b182f3b5f5eef19ced9b930b684a73bad121addf4 \ -# --hash=sha256:157e89ceb4054029a289fb504c98c6a9fe8010f1680de0201b3eb5dc20aa6d9e \ -# --hash=sha256:1bfe8de1da6d104f15a60d4a8a768288f66aa953bbe00d027398b93fb9680b26 \ -# --hash=sha256:1e172f57cd78c20f13a3415cc8dfe24bf388614324d25539146594c16d78fcc8 \ -# --hash=sha256:1fd7e0f1cfb70eb2f95a19b472ee7ad6d9a0a992ec0ae53286870c104ca939e5 \ -# --hash=sha256:203d236f4c94cd8379d1ea61db2fce20730b4c38d7f1c34506a31b34edc87bdd \ -# --hash=sha256:27d3ef2252d2e62476389ca8f9b0cf2bbafb082a3b6bfe9d90cbcbb5529ecf7c \ -# --hash=sha256:29a2bc7c1b09b0af938b7a8343174b987ae021705acabcbae560166567f5a8db \ -# --hash=sha256:2ef230a8fd217a2015bc91b74f6b3b7d6522ba48be29ad4ea0ca3a3775bf7dd5 \ -# --hash=sha256:2ef3775758346d9ac6214123887d25c7061c92afe1f2b354f9388e9e4d48acfc \ -# --hash=sha256:2f146f50723defec2975fb7e388ae3a024eb7151542d1599527ec2aa9cacb152 \ -# --hash=sha256:2fb4535137de7e244c230e24f9d1ec194f61721c86ebea04e1581d9d06ea1269 \ -# --hash=sha256:32ba3b5ccde2d581b1e6aa952c836a6291e8435d788f656fe5976445865ae045 \ -# --hash=sha256:34895a41273ad33347b2fc70e1bff4240556de3c46c6ea430a7ed91f9042aa4e \ -# --hash=sha256:379b378ae694ba78cef921581ebd420c938936a153ded602c4fea612b7eaa90d \ -# --hash=sha256:38302b78a850ff82656beaddeb0bb989a0322a8bbb1bf1ab10c17506681d772a \ -# --hash=sha256:3aa014d55c3af933c1315eb4bb06dd0459661cc0b15cd61077afa6489bec63bb \ -# --hash=sha256:4051e406288b8cdbb993798b9a45c59a4896b6ecee2f875424ec10276a895740 \ -# --hash=sha256:40b33d93c6eddf02d2c19f5773196068d875c41ca25730e8288e9b672897c105 \ -# --hash=sha256:43da0f0092281bf501f9c5f6f3b4c975a8a0ea82de49ba3f7100e64d422a1274 \ -# --hash=sha256:445e4cb5048b04e90ce96a79b4b63140e3f4ab5f662321975679b5f6360b90e2 \ -# --hash=sha256:48ef6a43b1846f6025dde6ed9fee0c24e1149c1c25f7fb0a0585572b2f3adc58 \ -# --hash=sha256:50a80baba0285386f97ea36239855f6020ce452456605f262b2d33ac35c7770b \ -# --hash=sha256:519fbf169dfac1222a76ba8861ef4ac7f0530c35dd79ba5727014613f91613d4 \ -# --hash=sha256:53dd9d5e3d29f95acd5de6802e909ada8d8d8cfa37a3ac64836f3bc4bc5512db \ -# --hash=sha256:53ea7cdc96c6eb56e76bb06894bcfb5dfa93b7adcf59d61c6b92674e24e2dd5e \ -# --hash=sha256:576856e8594e6649aee06ddbfc738fec6a834f7c85bf7cadd1c53d4a58186ef9 \ -# --hash=sha256:59556bf80a7094d0cfb9f5e50bb2db27fefb75d5138bb16fb052b61b0e0eeeb0 \ -# --hash=sha256:5d41d5e025f1e0bccae4928981e71b2334c60f580bdc8345f824e7c0a4c2a813 \ -# --hash=sha256:61062387ad820c654b6a6b5f0b94484fa19515e0c5116faf29f41a6bc91ded6e \ -# --hash=sha256:61f89436cbfede4bc4e91b4397eaa3e2108ebe96d05e93d6ccc95ab5714be512 \ -# --hash=sha256:62136da96a973bd2557f06ddd4e8e807f9e13cbb0bfb9cc06cfe6d98ea90dfe0 \ -# --hash=sha256:64585e1dba664dc67c7cdabd56c1e5685233fbb1fc1966cfba2a340ec0dfff7b \ -# --hash=sha256:65308f4b4890aa12d9b6ad9f2844b7ee42c7f7a4fd3390425b242ffc57498f48 \ -# --hash=sha256:66b689c107857eceabf2cf3d3fc699c3c0fe8ccd18df2219d978c0283e4c508a \ -# --hash=sha256:6a41c120c3dbc0d81a8e8adc73312d668cd34acd7725f036992b1b72d22c1772 \ -# --hash=sha256:6f77fa49079891a4aab203d0b1744acc85577ed16d767b52fc089d83faf8d8ed \ -# --hash=sha256:72c68dda124a1a138340fb62fa21b9bf4848437d9ca60bd35db36f2d3345f373 \ -# --hash=sha256:752bf8a74412b9892f4e5b58f2f890a039f57037f52c89a740757ebd807f33ea \ -# --hash=sha256:76e79bc28a65f467e0409098fa2c4376931fd3207fbeb6b956c7c476d53746dd \ -# --hash=sha256:774d45b1fac1461f48698a9d4b5fa19a69d47ece02fa469825b442263f04021f \ -# --hash=sha256:77da4c6bfa20dd5ea25cbf12c76f181a8e8cd7ea231c673828d0386b1740b8dc \ -# 
--hash=sha256:77ea385f7dd5b5676d7fd943292ffa18fbf5c72ba98f7d09fc1fb9e819b34c23 \ -# --hash=sha256:80080816b4f52a9d886e67f1f96912891074903238fe54f2de8b786f86baded2 \ -# --hash=sha256:80a539906390591dd39ebb8d773771dc4db82ace6372c4d41e2d293f8e32b8db \ -# --hash=sha256:82d17e94d735c99621bf8ebf9995f870a6b3e6d14543b99e201ae046dfe7de70 \ -# --hash=sha256:837bb6764be6919963ef41235fd56a6486b132ea64afe5fafb4cb279ac44f259 \ -# --hash=sha256:84433dddea68571a6d6bd4fbf8ff398236031149116a7fff6f777ff95cad3df9 \ -# --hash=sha256:8c24f21fa2af4bb9f2c492a86fe0c34e6d2c63812a839590edaf177b7398f700 \ -# --hash=sha256:8ed7d27cb56b3e058d3cf684d7200703bcae623e1dcc06ed1e18ecda39fee003 \ -# --hash=sha256:9206649ec587e6b02bd124fb7799b86cddec350f6f6c14bc82a2b70183e708ba \ -# --hash=sha256:983b6efd649723474f29ed42e1467f90a35a74793437d0bc64a5bf482bedfa0a \ -# --hash=sha256:98da17ce9cbf3bfe4617e836d561e433f871129e3a7ac16d6ef4c680f13a839c \ -# --hash=sha256:9c236e635582742fee16603042553d276cca506e824fa2e6489db04039521e90 \ -# --hash=sha256:9da6bc32faac9a293ddfdcb9108d4b20416219461e4ec64dfea8383cac186690 \ -# --hash=sha256:a05e6d6218461eb1b4771d973728f0133b2a4613a6779995df557f70794fd60f \ -# --hash=sha256:a0817825b900fcd43ac5d05b8b3079937073d2b1ff9cf89427590718b70dd840 \ -# --hash=sha256:a4ae99c57668ca1e78597d8b06d5af837f377f340f4cce993b551b2d7731778d \ -# --hash=sha256:a8c86881813a78a6f4508ef9daf9d4995b8ac2d147dcb1a450448941398091c9 \ -# --hash=sha256:a8fffdbd9d1408006baaf02f1068d7dd1f016c6bcb7538682622c556e7b68e35 \ -# --hash=sha256:a9b07268d0c3ca5c170a385a0ab9fb7fdd9f5fd866be004c4ea39e44edce47dd \ -# --hash=sha256:ab19a2d91963ed9e42b4e8d77cd847ae8381576585bad79dbd0a8837a9f6620a \ -# --hash=sha256:ac184f87ff521f4840e6ea0b10c0ec90c6b1dcd0bad2f1e4a9a1b4fa177982ea \ -# --hash=sha256:b0e166f698c5a3e914947388c162be2583e0c638a4703fc6a543e23a88dea3c1 \ -# --hash=sha256:b2170c7e0367dde86a2647ed5b6f57394ea7f53545746104c6b09fc1f4223573 \ -# --hash=sha256:b2d8c62d08e7255f68f7a740bae85b3c9b8e5466baa9cbf7f57f1cde0ac6bc09 \ -# --hash=sha256:b4567955a6bc1b20e9c31612e615af6b53733491aeaa19a6b3b37f3b65477094 \ -# --hash=sha256:b69bb4f51daf461b15e7b3db033160937d3ff88303a7bc808c67bbc1eaf98c78 \ -# --hash=sha256:b8c0bd73aeac689beacd4e7667d48c299f61b959475cdbb91e7d3d88d27c56b9 \ -# --hash=sha256:be9b5b8659dff1f913039c2feee1aca499cfbc19e98fa12bc85e037c17ec6ca5 \ -# --hash=sha256:bf0a05b6059c0528477fba9054d09179beb63744355cab9f38059548fedd46a9 \ -# --hash=sha256:c16842b846a8d2a145223f520b7e18b57c8f476924bda92aeee3a88d11cfc391 \ -# --hash=sha256:c363b53e257246a954ebc7c488304b5592b9c53fbe74d03bc1c64dda153fb847 \ -# --hash=sha256:c7c517d74bea1a6afd39aa612fa025e6b8011982a0897768a2f7c8ab4ebb78a2 \ -# --hash=sha256:d20fd853fbb5807c8e84c136c278827b6167ded66c72ec6f9a14b863d809211c \ -# --hash=sha256:d2240ddc86b74966c34554c49d00eaafa8200a18d3a5b6ffbf7da63b11d74ee2 \ -# --hash=sha256:d477ed829077cd945b01fc3115edd132c47e6540ddcd96ca169facff28173057 \ -# --hash=sha256:d50d31bfedd53a928fed6707b15a8dbeef011bb6366297cc435accc888b27c20 \ -# --hash=sha256:dc1d33abb8a0d754ea4763bad944fd965d3d95b5baef6b121c0c9013eaf1907d \ -# --hash=sha256:dc5d1a49d3f8262be192589a4b72f0d03b72dcf46c51ad5852a4fdc67be7b9e4 \ -# --hash=sha256:e2d1a054f8f0a191004675755448d12be47fa9bebbcffa3cdf01db19f2d30a54 \ -# --hash=sha256:e7792606d606c8df5277c32ccb58f29b9b8603bf83b48639b7aedf6df4fe8171 \ -# --hash=sha256:ed1708dbf4d2e3a1c5c69110ba2b4eb6678262028afd6c6fbcc5a8dac9cda68e \ -# --hash=sha256:f2d4380bf5f62daabd7b751ea2339c1a21d1c9463f1feb7fc2bdcea2c29c3160 \ -# 
--hash=sha256:f3513916e8c645d0610815c257cbfd3242adfd5c4cfa78be514e5a3ebb42a41b \ -# --hash=sha256:f8346bfa098532bc1fb6c7ef06783e969d87a99dd1d2a5a18a892c1d7a643c58 \ -# --hash=sha256:f83fa6cae3fff8e98691248c9320356971b59678a17f20656a9e59cd32cee6d8 \ -# --hash=sha256:fa6ce8b52c5987b3e34d5674b0ab529a4602b632ebab0a93b07bfb4dfc8f8a33 \ -# --hash=sha256:fb2b1ecfef1e67897d336de3a0e3f52478182d6a47eda86cbd42504c5cbd009a \ -# --hash=sha256:fc9ca1c9718cb3b06634c7c8dec57d24e9438b2aa9a0f02b8bb36bf478538880 \ -# --hash=sha256:fd30d9c67d13d891f2360b2a120186729c111238ac63b43dbd37a5a40670b8ca \ -# --hash=sha256:fd7699e8fd9969f455ef2926221e0233f81a2542921471382e77a9e2f2b57f4b \ -# --hash=sha256:fe3b385d996ee0822fd46528d9f0443b880d4d05528fd26a9119a54ec3f91c69 -# # via -r build/requirements.in # The following packages are considered to be unsafe in a requirements file: setuptools==70.3.0 \ --hash=sha256:f171bab1dfbc86b132997f26a119f6056a57950d058587841a0082e8830f9dc5 \ --hash=sha256:fe384da74336c398e0d956d1cae0669bc02eed936cdb1d49b57de1990dc11ffc - # via - # -r build/test-requirements.txt - # -r build/requirements.in + # via -r build/requirements.in diff --git a/build/requirements_lock_3_14.txt b/build/requirements_lock_3_14.txt new file mode 100644 index 000000000000..6a91caa9bbad --- /dev/null +++ b/build/requirements_lock_3_14.txt @@ -0,0 +1,139 @@ +# This file was autogenerated by uv via the following command: +# uv pip compile --output-file=build/requirements_lock_3_14.txt build/requirements.in build/nonfreethreading-requirements.txt build/test-requirements.txt build/gpu-test-requirements.txt +absl-py==2.2.2 + # via -r build/test-requirements.txt +attrs==25.3.0 + # via hypothesis +auditwheel==6.3.0 + # via -r build/test-requirements.txt +build==1.2.2.post1 + # via -r build/test-requirements.txt +cloudpickle==3.1.1 + # via -r build/test-requirements.txt +contourpy==1.3.2 + # via matplotlib +cycler==0.12.1 + # via matplotlib +etils==1.12.2 + # via -r build/requirements.in +execnet==2.1.1 + # via pytest-xdist +filelock==3.18.0 + # via -r build/test-requirements.txt +flatbuffers==25.2.10 + # via -r build/test-requirements.txt +fonttools==4.57.0 + # via matplotlib +fsspec==2025.3.2 + # via etils +hypothesis==6.131.9 + # via -r build/test-requirements.txt +importlib-resources==6.5.2 + # via etils +iniconfig==2.1.0 + # via pytest +kiwisolver==1.4.8 + # via matplotlib +markdown-it-py==3.0.0 + # via rich +matplotlib==3.10.1 + # via -r build/test-requirements.txt +mdurl==0.1.2 + # via markdown-it-py +ml-dtypes==0.5.1 + # via + # -r build/requirements.in + # tensorstore +mpmath==1.4.0a4 + # via -r build/test-requirements.txt +numpy==2.2.6 + # via + # -r build/nonfreethreading-requirements.txt + # contourpy + # matplotlib + # ml-dtypes + # scipy + # tensorstore +nvidia-cublas-cu12==12.8.4.1 + # via + # -r build/gpu-test-requirements.txt + # nvidia-cudnn-cu12 + # nvidia-cusolver-cu12 +nvidia-cuda-cupti-cu12==12.8.90 + # via -r build/gpu-test-requirements.txt +nvidia-cuda-nvcc-cu12==12.8.93 + # via -r build/gpu-test-requirements.txt +nvidia-cuda-runtime-cu12==12.8.90 + # via -r build/gpu-test-requirements.txt +nvidia-cudnn-cu12==9.8.0.87 + # via -r build/gpu-test-requirements.txt +nvidia-cufft-cu12==11.3.3.83 + # via -r build/gpu-test-requirements.txt +nvidia-cusolver-cu12==11.7.3.90 + # via -r build/gpu-test-requirements.txt +nvidia-cusparse-cu12==12.5.8.93 + # via + # -r build/gpu-test-requirements.txt + # nvidia-cusolver-cu12 +nvidia-nccl-cu12==2.26.2.post1 + # via -r build/gpu-test-requirements.txt 
+nvidia-nvjitlink-cu12==12.8.93 + # via + # -r build/gpu-test-requirements.txt + # nvidia-cufft-cu12 + # nvidia-cusolver-cu12 + # nvidia-cusparse-cu12 +opt-einsum==3.4.0 + # via -r build/test-requirements.txt +packaging==25.0 + # via + # auditwheel + # build + # matplotlib + # pytest +pillow==11.2.1 + # via + # -r build/test-requirements.txt + # matplotlib +pluggy==1.5.0 + # via pytest +portpicker==1.6.0 + # via -r build/nonfreethreading-requirements.txt +psutil==7.0.0 + # via portpicker +pyelftools==0.32 + # via auditwheel +pygments==2.19.1 + # via rich +pyparsing==3.2.3 + # via matplotlib +pyproject-hooks==1.2.0 + # via build +pytest==8.3.5 + # via pytest-xdist +pytest-xdist==3.6.1 + # via -r build/test-requirements.txt +python-dateutil==2.9.0.post0 + # via matplotlib +rich==14.0.0 + # via -r build/test-requirements.txt +scipy==1.15.2 + # via -r build/requirements.in +setuptools==80.0.0 + # via + # -r build/requirements.in + # -r build/test-requirements.txt +six==1.17.0 + # via python-dateutil +sortedcontainers==2.4.0 + # via hypothesis +tensorstore==0.1.74 + # via -r build/nonfreethreading-requirements.txt +typing-extensions==4.13.2 + # via etils +wheel==0.45.1 + # via -r build/test-requirements.txt +zipp==3.21.0 + # via etils +zstandard==0.23.0 + # via -r build/nonfreethreading-requirements.txt diff --git a/build/requirements_lock_3_14_ft.txt b/build/requirements_lock_3_14_ft.txt new file mode 100644 index 000000000000..6eedf149f5fa --- /dev/null +++ b/build/requirements_lock_3_14_ft.txt @@ -0,0 +1,27 @@ +--pre +--extra-index-url https://pypi.anaconda.org/scientific-python-nightly-wheels/simple +numpy + +--pre +--extra-index-url https://pypi.anaconda.org/scientific-python-nightly-wheels/simple +scipy + +absl-py==2.1.0 + +attrs==24.3.0 + +hypothesis==6.123.9 + +sortedcontainers==2.4.0 + +flatbuffers==24.12.23 + +ml-dtypes==0.5.1 + +opt-einsum==3.4.0 + +build==1.2.2.post1 +setuptools==80.0.0 +wheel==0.45.1 +pyproject-hooks==1.2.0 +packaging==25.0 diff --git a/build/rocm/Dockerfile.ms b/build/rocm/Dockerfile.ms index a084045256de..40b4decaafb4 100644 --- a/build/rocm/Dockerfile.ms +++ b/build/rocm/Dockerfile.ms @@ -40,7 +40,7 @@ RUN --mount=type=cache,target=/var/cache/apt \ liblzma-dev # Install pyenv with different python versions -ARG PYTHON_VERSION=3.10.14 +ARG PYTHON_VERSION=3.11.13 RUN git clone https://github.com/pyenv/pyenv.git /pyenv ENV PYENV_ROOT /pyenv ENV PATH $PYENV_ROOT/shims:$PYENV_ROOT/bin:$PATH diff --git a/build/rocm/build_wheels/Dockerfile.manylinux_2_28_x86_64.rocm b/build/rocm/build_wheels/Dockerfile.manylinux_2_28_x86_64.rocm index 08b6bd3ff8d6..3ca491568911 100644 --- a/build/rocm/build_wheels/Dockerfile.manylinux_2_28_x86_64.rocm +++ b/build/rocm/build_wheels/Dockerfile.manylinux_2_28_x86_64.rocm @@ -9,7 +9,7 @@ ARG ROCM_BUILD_NUM # manylinux base image. However, adding this does fix an issue where Bazel isn't able # to find them. 
RUN --mount=type=cache,target=/var/cache/dnf \ - dnf install -y gcc-c++-8.5.0-22.el8_10.x86_64 + dnf install -y numactl-devel RUN --mount=type=cache,target=/var/cache/dnf \ --mount=type=bind,source=build/rocm/tools/get_rocm.py,target=get_rocm.py \ @@ -25,5 +25,11 @@ RUN mkdir /tmp/llvm-project && wget -qO - https://github.com/llvm/llvm-project/a mkdir /tmp/llvm-project/build && cd /tmp/llvm-project/build && cmake -DLLVM_ENABLE_PROJECTS='clang;lld' -DCMAKE_BUILD_TYPE=Release -DCMAKE_INSTALL_PREFIX=/usr/lib/llvm-18/ ../llvm && \ make -j$(nproc) && make -j$(nproc) install && rm -rf /tmp/llvm-project +# Set some clang config +COPY ./build/rocm/build_wheels/clang.cfg /usr/lib/llvm-18/bin/clang++.cfg +COPY ./build/rocm/build_wheels/clang.cfg /usr/lib/llvm-18/bin/clang.cfg +COPY ./build/rocm/build_wheels/clang.cfg /opt/rocm/llvm/bin/clang++.cfg +COPY ./build/rocm/build_wheels/clang.cfg /opt/rocm/llvm/bin/clang.cfg + # Stop git from erroring out when we don't own the repo RUN git config --global --add safe.directory '*' diff --git a/build/rocm/build_wheels/clang.cfg b/build/rocm/build_wheels/clang.cfg new file mode 100644 index 000000000000..767c04c03ae7 --- /dev/null +++ b/build/rocm/build_wheels/clang.cfg @@ -0,0 +1,3 @@ +# Tell clang where it can find gcc so that it can use gcc's standard libraries +--gcc-toolchain=/opt/rh/gcc-toolset-14/root/usr/ + diff --git a/build/rocm/ci_build b/build/rocm/ci_build index ef43a95044d8..71ce747d7e86 100755 --- a/build/rocm/ci_build +++ b/build/rocm/ci_build @@ -98,7 +98,10 @@ def dist_wheels( bw_cmd.append("/jax") - cmd = ["docker", "run"] + cmd = [ + "docker", + "run", + ] mounts = [ "-v", diff --git a/build/rocm/ci_build.sh b/build/rocm/ci_build.sh index 386f70ee1a96..847d4e9b4b93 100755 --- a/build/rocm/ci_build.sh +++ b/build/rocm/ci_build.sh @@ -44,7 +44,7 @@ CONTAINER_TYPE="rocm" DOCKERFILE_PATH="${SCRIPT_DIR}/Dockerfile.ms" DOCKER_CONTEXT_PATH="${SCRIPT_DIR}" KEEP_IMAGE="--rm" -PYTHON_VERSION="3.10" +PYTHON_VERSION="3.11" ROCM_VERSION="6.1.3" ROCM_BUILD_JOB="" ROCM_BUILD_NUM="" diff --git a/build/rocm/docker/Dockerfile.jax-ubu22 b/build/rocm/docker/Dockerfile.jax-ubu22 index 70b16f9e9677..b6e90f2183d2 100644 --- a/build/rocm/docker/Dockerfile.jax-ubu22 +++ b/build/rocm/docker/Dockerfile.jax-ubu22 @@ -60,7 +60,7 @@ ARG JAX_COMMIT ARG XLA_COMMIT LABEL com.amdgpu.rocm_version="$ROCM_VERSION" \ - com.amdgpu.python_version="3.10" \ + com.amdgpu.python_version="3.11" \ com.amdgpu.jax_version="$JAX_VERSION" \ com.amdgpu.jax_commit="$JAX_COMMIT" \ com.amdgpu.xla_commit="$XLA_COMMIT" diff --git a/build/rocm/setup.rocm.sh b/build/rocm/setup.rocm.sh index 3893d817e3a8..faa79d2ce1fd 100755 --- a/build/rocm/setup.rocm.sh +++ b/build/rocm/setup.rocm.sh @@ -13,7 +13,7 @@ ROCM_BUILD_NAME=ubuntu ROCM_BUILD_NUM=main # Adjust the ROCM repo location -# Intial release don't have the trialing '.0' +# Initial releases don't have the trailing '.0' # For example ROCM 5.7.0 is at https://repo.radeon.com/rocm/apt/5.7/ if [ ${ROCM_VERSION##*[^0-9]} -eq '0' ]; then ROCM_VERS=${ROCM_VERSION%.*} diff --git a/build/rocm/tools/build_wheels.py b/build/rocm/tools/build_wheels.py index fd98bbb8ec04..9fdffe6cfa03 100644 --- a/build/rocm/tools/build_wheels.py +++ b/build/rocm/tools/build_wheels.py @@ -226,7 +226,10 @@ def fix_wheel(path, jax_path): py_bin = "/opt/python/cp310-cp310/bin" env["PATH"] = "%s:%s" % (py_bin, env["PATH"]) - cmd = ["pip", "install", "auditwheel>=6"] + # NOTE(mrodden): auditwheel 6.0 added lddtree module, but 6.3.0 changed + the function to ldd and also changed
its behavior + # constrain range to 6.0 to 6.2.x + cmd = ["pip", "install", "auditwheel>=6,<6.3"] subprocess.run(cmd, check=True, env=env) fixwheel_path = os.path.join(jax_path, "build/rocm/tools/fixwheel.py") @@ -248,7 +251,7 @@ def parse_args(): ) p.add_argument( "--python-versions", - default=["3.10.19,3.12"], + default=["3.11.13,3.12"], help="Comma separated CPython versions that wheels will be built and output for", ) p.add_argument( @@ -322,7 +325,7 @@ def main(): shutil.rmtree(os.path.join(args.jax_path, "jax.egg-info")) shutil.rmtree(os.path.join(args.jax_path, "jax", "__pycache__")) - # Make the wheels deleteable by the runner + # Make the wheels deletable by the runner whl_house = os.path.join(args.jax_path, "wheelhouse") logging.info("Changing permissions for %s" % whl_house) mode = 0o664 diff --git a/build/rocm/tools/fixwheel.py b/build/rocm/tools/fixwheel.py index ea77162728d5..7d8c1fcce055 100644 --- a/build/rocm/tools/fixwheel.py +++ b/build/rocm/tools/fixwheel.py @@ -87,7 +87,7 @@ def fix_wheel(path): exclude = list(ext_libs.keys()) # call auditwheel repair with excludes - cmd = ["auditwheel", "repair", "--plat", plat, "--only-plat"] + cmd = ["auditwheel", "-v", "repair", "--plat", plat, "--only-plat"] for ex in exclude: cmd.append("--exclude") diff --git a/build/test-requirements.txt b/build/test-requirements.txt index f0b315771cbb..50311faebde6 100644 --- a/build/test-requirements.txt +++ b/build/test-requirements.txt @@ -1,7 +1,5 @@ absl-py -build cloudpickle -colorama>=0.4.4 filelock flatbuffers hypothesis @@ -10,12 +8,8 @@ pillow>=10.4.0 # TODO(kanglan): Remove once psutil from portpicker supports python 3.13t portpicker; python_version<"3.13" pytest-xdist -wheel rich -setuptools # matplotlib 3.9.0 pins NumPy 1.23, which is incompatible with the requirement # below. -matplotlib~=3.8.4; python_version=="3.10" -matplotlib; python_version>="3.11" -opt-einsum +matplotlib auditwheel \ No newline at end of file diff --git a/build/tools/utils.py b/build/tools/utils.py index 7e375169827b..7ed7f74d07a5 100644 --- a/build/tools/utils.py +++ b/build/tools/utils.py @@ -14,6 +14,7 @@ # ============================================================================== # Helper script for tools/utilities used by the JAX build CLI. import collections +import glob import hashlib import logging import os @@ -201,6 +202,25 @@ def get_clang_major_version(clang_path): return major_version +def get_clangpp_path(clang_path): + clang_path = pathlib.Path(clang_path) + clang_exec_name = clang_path.name + clangpp_exec_name = clang_exec_name + clangpp_path = clang_path.parent / clang_exec_name + # Try and match what the user passed in (either clang-18 or clang) + if "clang++" not in clangpp_exec_name: + clangpp_exec_name = clangpp_exec_name.replace("clang", "clang++") + clangpp_path = clang_path.parent / clangpp_exec_name + if not clangpp_path.exists(): + clangpp_exec_name = "clang++" + clangpp_path = clang_path.parent / clangpp_exec_name + if not clangpp_path.exists(): + raise FileNotFoundError( + f"Failed to get clang++ path from clang path: '{clang_path!s}'. " + f"Tried the path: '{clangpp_path!s}'." 
+ ) + return str(clangpp_path) + def get_gcc_major_version(gcc_path: str): gcc_version_proc = subprocess.run( [gcc_path, "-dumpversion"], @@ -256,3 +276,31 @@ def _parse_string_as_bool(s): return False else: raise ValueError(f"Expected either 'true' or 'false'; got {s}") + + +def copy_dir_recursively(src, dst): + if os.path.exists(dst): + shutil.rmtree(dst) + os.makedirs(dst, exist_ok=True) + for root, dirs, files in os.walk(src): + relative_path = os.path.relpath(root, src) + dst_dir = os.path.join(dst, relative_path) + os.makedirs(dst_dir, exist_ok=True) + for f in files: + src_file = os.path.join(root, f) + dst_file = os.path.join(dst_dir, f) + shutil.copy2(src_file, dst_file) + logging.info("Editable wheel path: %s" % dst) + + +def copy_individual_files(src: str, dst: str, glob_pattern: str): + os.makedirs(dst, exist_ok=True) + logging.debug( + f"Copying files matching pattern {glob_pattern!r} from {src!r} to {dst!r}" + ) + for f in glob.glob(os.path.join(src, glob_pattern)): + dst_file = os.path.join(dst, os.path.basename(f)) + if os.path.exists(dst_file): + os.remove(dst_file) + shutil.copy2(f, dst_file) + logging.info("Distribution path: %s" % dst_file) diff --git a/build_wheel.py b/build_wheel.py index f8e1595d3c3a..793523e8e3b2 100644 --- a/build_wheel.py +++ b/build_wheel.py @@ -47,6 +47,25 @@ parser.add_argument( "--srcs", help="source files for the wheel", action="append" ) +parser.add_argument( + "--build-wheel-only", + default=False, + help=( + "Whether to build the wheel only. Optional." + ), +) +parser.add_argument( + "--build-source-package-only", + default=False, + help=( + "Whether to build the source package only. Optional." + ), +) +parser.add_argument( + "--editable", + action="store_true", + help="Create an 'editable' jax build instead of a wheel.", +) args = parser.parse_args() @@ -76,7 +95,11 @@ def prepare_srcs(deps: list[str], srcs_dir: str) -> None: """ for file in deps: - if not (file.startswith("bazel-out") or file.startswith("external")): + if not ( + file.startswith("bazel-out") + or file.startswith("external") + or file.startswith("jaxlib") + ): copy_file(file, srcs_dir) @@ -89,13 +112,18 @@ def prepare_srcs(deps: list[str], srcs_dir: str) -> None: try: os.makedirs(args.output_path, exist_ok=True) prepare_srcs(args.srcs, pathlib.Path(sources_path)) - build_utils.build_wheel( - sources_path, - args.output_path, - package_name="jax", - git_hash=args.jaxlib_git_hash, - build_wheel_only=False, - ) + package_name = "jax" + if args.editable: + build_utils.build_editable(sources_path, args.output_path, package_name) + else: + build_utils.build_wheel( + sources_path, + args.output_path, + package_name, + git_hash=args.jaxlib_git_hash, + build_wheel_only=args.build_wheel_only, + build_source_package_only=args.build_source_package_only, + ) finally: if tmpdir: tmpdir.cleanup() diff --git a/ci/README.md b/ci/README.md index ea867df52f97..31af3ec0ef87 100644 --- a/ci/README.md +++ b/ci/README.md @@ -1,10 +1,254 @@ -# JAX continuous integration +# JAX Continuous Integration -> [!WARNING] -> This folder is still under construction. It is part of an ongoing -> effort to improve the structure of CI and build related files within the -> JAX repo. This warning will be removed when the contents of this -> directory are stable and appropriate documentation around its usage is in -> place. +This folder contains the configuration files and scripts used to build and test +JAX. 
It is typically used by continuous integration (CI) jobs to automate builds +and run comprehensive tests across various platforms and configurations. This +page provides an overview of the JAX CI system, its components, and the +different workflows it supports. -******************************************************************************** \ No newline at end of file +******************************************************************************** + +## JAX's CI System + +![Overview of JAX's CI System](jax_ci_system.png) + +JAX's CI system is composed of several interacting components and orchestrates +builds and tests using a hybrid approach, leveraging both an internal CI system +and GitHub Actions, along with an internal build orchestrator for managing +nightly and release flows. It encompasses several distinct workflows, including +comprehensive presubmit checks triggered on pull requests and branch pushes, +bi-hourly continuous builds, extensive nightly builds with broad platform +coverage, and a controlled release process that culminates in PyPI publication. + +These flows build four packages: `jax`, `jaxlib`, `jax-cuda-plugin`, and +`jax-cuda-pjrt`, and support a range of environments, including: + +* **Linux x86:** CPU, TPU, CUDA +* **Linux aarch64:** CPU, CUDA +* **Windows x86:** CPU +* **Mac Arm64:** CPU + +### Architecture Overview + +1. **Internal CI System:** An internal CI system is used for specific build and + test tasks, such as nightly builds, release candidate (RC) builds, and + Mac-specific testing. + +2. **GitHub Actions:** Used for presubmit checks, continuous integration builds + and tests, and nightly/release artifact testing. + +3. **Build Orchestrator:** An internal tool used to manage complex workflows + such as nightly / release flows, promoting RC builds to release, etc. + +4. **Artifact Storage:** + +* Google Cloud Storage (GCS) Buckets: Used for temporary storage of artifacts + between jobs in GitHub Actions workflows and for storing packages built + during nightly and release flows before testing. +* Artifact Registry: Used to store nightly packages, RC packages, and final + releases. +* PyPI: Where final releases are published. + +### CI Workflows and Where They Run + +JAX's CI system consists of the following workflows: + +1. **Presubmits:** Presubmits are run in GitHub Actions and are triggered on + pull requests that target the `main` branch and on pushes to the `main` and + `release` branches. JAX's presubmit run time SLO is about 10 minutes, so these + are typically run using Bazel with remote build execution + ([RBE](https://bazel.build/remote/rbe)). RBE allows us to execute build and + test actions on a distributed system, separate from the local machine, + instead of solely on the local machine. This enables faster build and test + times by utilizing parallel computing resources and caching across a cluster + of machines (see the example invocation sketched after this list). However, we also use Pytest in workflows where we are not able + to use RBE, such as the TPU presubmit. In such presubmits, we usually run a + subset of tests to be able to satisfy the presubmit run time SLO. To see the + list of the presubmit workflows, + [click here](https://github.com/search?q=repo%3Ajax-ml%2Fjax+path%3A.github%2Fworkflows%2F+%28path%3A**%2F*.yml+OR+path%3A**%2F*.yaml%29+%22pull_request%22&type=code). + +2. **Continuous:** These jobs are run in GitHub Actions and are scheduled to + run once every 2 hours on the `main` branch.
They build JAX packages and run
+   a wide range of tests targeting different environments such as CPU, CUDA
+   (L4, H100, B200, etc.), and TPU (v4-8, v5e-8, etc.). For more information,
+   see
+   [wheel_tests_continuous.yml](https://github.com/jax-ml/jax/blob/main/.github/workflows/wheel_tests_continuous.yml)
+   ([an example run](https://github.com/jax-ml/jax/actions/workflows/wheel_tests_continuous.yml)).
+
+3. **Nightly Builds and Tests:** These jobs use a hybrid approach that combines
+   the internal CI system and GitHub Actions. The jobs are triggered once every
+   night by the internal build orchestrator tool. It first triggers the jobs in
+   the internal CI system to build the JAX packages for different
+   configurations (Python versions, CUDA versions, etc.) and uploads them to a
+   staging bucket in GCS as well as to the nightly artifact registry. Next,
+   testing jobs are triggered that download the artifacts from the staging
+   bucket and run tests. Mac testing jobs are run in the internal CI system.
+   For non-Mac testing, a trigger job is run that invokes the
+   [wheel_tests_nightly_release.yml](https://github.com/jax-ml/jax/blob/main/.github/workflows/wheel_tests_nightly_release.yml)
+   workflow in GitHub Actions. JAX's nightly artifacts can be found here:
+   [jax](https://us-python.pkg.dev/ml-oss-artifacts-published/jax-public-nightly-artifacts-registry/simple/jax),
+   [jaxlib](https://us-python.pkg.dev/ml-oss-artifacts-published/jax-public-nightly-artifacts-registry/simple/jaxlib),
+   [jax-cuda-plugin](https://us-python.pkg.dev/ml-oss-artifacts-published/jax-public-nightly-artifacts-registry/simple/jax-cuda12-plugin),
+   [jax-cuda-pjrt](https://us-python.pkg.dev/ml-oss-artifacts-published/jax-public-nightly-artifacts-registry/simple/jax-cuda12-pjrt).
+
+4. **Release Builds and Tests:** The release flow is similar to the nightly
+   flow except for a few differences. First, the release process has to be
+   triggered manually in the internal build orchestrator and should be done
+   only after a release branch (e.g. `release/0.5.3`) has been created. The
+   build jobs build two sets of artifacts for each package: RC wheels and
+   final version wheels. These two sets are essentially the same package
+   except for their metadata and wheel tags. The RC wheels are then uploaded
+   to the staging bucket and the release artifact registry. After the uploads
+   are done, the test jobs are triggered. As with the nightly flow, Mac test
+   jobs are run in the internal CI system while non-Mac test jobs are run in
+   GitHub Actions. To see the GitHub Actions runs for a particular release,
+   filter the workflow runs by its branch name.
+
+5. **Promote RC to Final and Publish to PyPI:** If the RC wheels pass all
+   testing, then we are ready to promote them to the final version and publish
+   them to PyPI. This entire flow is internal and is run in our internal CI
+   system. Final versions of the packages are published to PyPI and to JAX's
+   release artifact registry. JAX's release artifacts (RC and final versions)
+   can be found here:
+   [jax](https://us-python.pkg.dev/ml-oss-artifacts-published/jax-public-release-artifacts-registry/simple/jax),
+   [jaxlib](https://us-python.pkg.dev/ml-oss-artifacts-published/jax-public-release-artifacts-registry/simple/jaxlib),
+   [jax-cuda-plugin](https://us-python.pkg.dev/ml-oss-artifacts-published/jax-public-release-artifacts-registry/simple/jax-cuda12-plugin),
+   [jax-cuda-pjrt](https://us-python.pkg.dev/ml-oss-artifacts-published/jax-public-release-artifacts-registry/simple/jax-cuda12-pjrt).
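+
+As a quick illustration of how the registries listed above can be consumed (a
+sketch, not an official installation guide; it assumes the registry serves a
+standard PEP 503 "simple" index and that `pip` is available), a nightly build
+can be installed with:
+
+```bash
+# Pull the latest pre-release (nightly) jax and jaxlib from the public
+# nightly artifact registry; add the CUDA plugin packages as needed.
+pip install -U --pre jax jaxlib \
+  -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax-public-nightly-artifacts-registry/simple/
+```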
+
+### JAX's Official CI and Build/Test Scripts
+
+JAX's CI jobs (both internal and those on GitHub Actions) run the scripts in
+this folder. An overview of the different folders and their purpose is given
+below:
+
+- **ci/**: Contains all build scripts, environment files, and utility scripts.
+- **ci/utilities/**: Contains helper scripts used throughout the build/test
+  process. See
+  [README.md](https://github.com/jax-ml/jax/blob/main/ci/utilities/README.md)
+  for a brief overview of these utility scripts and their behavior.
+- **ci/envs/**: Holds environment files that set `JAXCI` environment variables
+  that control build and test configurations. See
+  [README.md](https://github.com/jax-ml/jax/blob/main/ci/envs/README.md) for
+  the complete list of these variables and their behavior.
+
+Every build script in this folder first sources the `JAXCI` envs in
+[default.env](https://github.com/jax-ml/jax/blob/main/ci/envs/default.env) and
+then runs the
+[setup_build_environment.sh](https://github.com/jax-ml/jax/blob/main/ci/utilities/setup_build_environment.sh)
+script to set up the build environment.
+
+A brief overview of each build script in this folder is given below:
+
+> [!NOTE]
+> Both internal and GitHub Actions jobs run under the
+> [ml-build](https://github.com/tensorflow/tensorflow/tree/master/ci/official/containers)
+> Docker image, which contains build tools such as Python, Bazelisk, LLVM/Clang,
+> manylinux-compliant libraries (in Linux images), etc.
+
+- **build_artifacts.sh:** Builds the various JAX artifacts. We build three
+  different types of artifacts based on the type of job: Nightly, RC/Release,
+  or at HEAD.
+- **run_bazel_test_cpu_rbe.sh/run_bazel_test_cuda_rbe.sh**: These run Bazel
+  tests with RBE on every GitHub PR. We test compatibility with both CPU and
+  CUDA. On platforms where RBE is not natively supported (e.g. Linux Arm64),
+  we cross-compile the test targets for Linux Aarch64 on Linux x86. Because
+  the tests would still have to run on the host machines, and running them on
+  a single machine can take a long time, we build the test targets but skip
+  running them on these platforms.
+- **run_bazel_test_cuda_non_rbe.sh**: Runs the following Bazel CUDA tests:
+  single-accelerator tests with one GPU apiece and multi-accelerator tests
+  with all GPUs. These jobs depend on local JAX wheels and therefore require
+  the following wheels to be present in the `../dist` folder: `jax`, `jaxlib`,
+  `jax-cuda-plugin`, and `jax-cuda-pjrt`. In CI builds, we first build these
+  wheels from source and then run the `bazel test` command.
+- **run_pytest_*.sh**: These run tests with Pytest and use the JAX wheel
+  packages installed on the system. In CI builds, we build the wheels first
+  from source and then run the `pytest` commands. We test compatibility with
+  CPU, CUDA, and TPU. These are primarily run as part of the continuous and
+  nightly/release test jobs, except for TPU, which also runs as a presubmit
+  with a subset of the tests.
+
+## Different Test Configurations
+
+JAX's CI test jobs run under different test configurations. These
+configurations are described briefly in the sections below.
+
+### XLA Versions
+
+JAX's CI builds rely on XLA, but use different versions depending on the type
+of build. To ensure stability and reproducibility, nightly and release builds
+use a pinned XLA version specified in the JAX
+[workspace](https://github.com/jax-ml/jax/blob/34a2f0ca4a8f8a26d9a056f8785f412bd156dc23/third_party/xla/workspace.bzl#L24-L25).
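+
+For illustration, the override described in the remainder of this section
+boils down to pointing Bazel at a local XLA checkout (a sketch; the clone
+location is a placeholder, and the test target is one of the suites used
+elsewhere in this folder):
+
+```bash
+# Clone XLA at HEAD and test against it instead of the pinned version.
+git clone https://github.com/openxla/xla.git /tmp/xla
+bazel test --override_repository=xla=/tmp/xla //tests:cpu_tests
+```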
+ +However, to keep JAX compatible with the latest XLA developments, presubmit and +postsubmit builds utilize the most recent XLA version. This is done by +overriding the default XLA dependency with a local copy of the XLA repository. +We do this by passing `--override_repository=xla=/path/to/local/xla` which +instructs Bazel to depend on the XLA in the local system instead of the version +in the workspace. + +The CI system uses the `JAXCI` environment variables to manage this process. +When running jobs that need to use XLA at head, we set `JAXCI_CLONE_MAIN_XLA=1`. +This clones the XLA repository at head and sets `JAXCI_XLA_GIT_DIR` to its path. +[JAX build CLI](https://github.com/jax-ml/jax/blob/main/build/build.py) +automatically adds the necessary Bazel flag (`--override_repository`) to point +to this local XLA version during the build process if `JAXCI_XLA_GIT_DIR` is +set. In jobs where the build CLI is not used such as the RBE presubmits, we +explicitly include `--override_repository=xla="${JAXCI_XLA_GIT_DIR}"` as part +of the test command. + +### Enabling/Disabling 64-bit Data Types + +By default, JAX enforces single-precision numbers to mitigate the Numpy API’s +tendency to aggressively promote operands to `double`. In order to use +double-precision numbers, we need to set the `JAX_ENABLE_X64` environment +variable. In CI, we test both configurations in presubmits and postsubmits by +using the `JAXCI_ENABLE_X64` environment variable. + + + +## [Googlers Only] Connecting to CI Runners for Debugging + +If you are a Googler, you can connect to one of the self-hosted runners we have +on GitHub to debug your workflow. For more information, see +go/ml-github-actions:connect. + +## Running These Scripts Locally on Your Machine + +> [!IMPORTANT] +> If you are a Linux / Windows user, you need to have Docker installed as a +> prerequisite. Additionally, if running on Windows, please run these commands +> in a bash environment as all the scripts are written in Shell. + +Follow the steps below to run a CI script locally on your machine. + +1. [Optional] Set `JAXCI` variables in your shell environment. See + [ci/envs/README.md](https://github.com/jax-ml/jax/blob/main/ci/envs/README.md) + for the list of `JAXCI` variables and their behavior. + +2. [Linux/Windows] + + Start the Docker container by running: + + ```bash + ./ci/utilities/run_docker_container.sh + ``` + + This will start a Docker container named "jax". Note that if you set any + `JAXCI` variables in step 1, they will also be be set in the container. + + Run the script under the Docker container. + + ```bash + # docker exec jax + docker exec jax ./ci/build_artifacts.sh jaxlib + ``` + +3. [Mac] Execute the build script directly. 
+ + ```bash + # ./ + ./ci/build_artifacts.sh jaxlib + ``` diff --git a/ci/build_artifacts.sh b/ci/build_artifacts.sh index 84b8d35a2a50..d7ffe82eb699 100755 --- a/ci/build_artifacts.sh +++ b/ci/build_artifacts.sh @@ -96,6 +96,7 @@ if [[ "${allowed_artifacts[@]}" =~ "${artifact}" ]]; then --bazel_options=--config="$bazelrc_config" $bazel_remote_cache \ --python_version=$JAXCI_HERMETIC_PYTHON_VERSION \ --verbose --detailed_timestamped_log --use_new_wheel_build_rule \ + --output_path="$JAXCI_OUTPUT_DIR" \ $artifact_tag_flags # If building release artifacts, we also build a release candidate ("rc") @@ -105,18 +106,10 @@ if [[ "${allowed_artifacts[@]}" =~ "${artifact}" ]]; then --bazel_options=--config="$bazelrc_config" $bazel_remote_cache \ --python_version=$JAXCI_HERMETIC_PYTHON_VERSION \ --verbose --detailed_timestamped_log --use_new_wheel_build_rule \ + --output_path="$JAXCI_OUTPUT_DIR" \ $artifact_tag_flags --bazel_options=--repo_env=ML_WHEEL_VERSION_SUFFIX="$JAXCI_WHEEL_RC_VERSION" fi - # Move the built artifacts from the Bazel cache directory to the output - # directory. - if [[ "$artifact" == "jax" ]]; then - mv bazel-bin/dist/*.whl "$JAXCI_OUTPUT_DIR" - mv bazel-bin/dist/*.tar.gz "$JAXCI_OUTPUT_DIR" - else - mv bazel-bin/jaxlib/tools/dist/*.whl "$JAXCI_OUTPUT_DIR" - fi - # If building `jaxlib` or `jax-cuda-plugin` or `jax-cuda-pjrt` for Linux, we # run `auditwheel show` to verify manylinux compliance. if [[ "$os" == "linux" ]] && [[ "$artifact" != "jax" ]]; then diff --git a/ci/envs/README.md b/ci/envs/README.md new file mode 100644 index 000000000000..2a81d0f3240d --- /dev/null +++ b/ci/envs/README.md @@ -0,0 +1,41 @@ +# JAXCI Environment Variables + +This docpage describes the various `JAXCI` environment variables that are used +in the CI scripts and their behaviors. These variables are used to control the +behavior of the CI scripts such as the Python version used, path to JAX/XLA +repo, if to clone XLA repo, etc. + +Name | Default Value | Behavior | Usage +------------------------------------------- | ---------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | ----- +`JAXCI_JAX_GIT_DIR` | Present working directory: `$(pwd)` | Path to the JAX's Git directory. | [Usage](https://github.com/search?q=repo%3Ajax-ml%2Fjax%20JAXCI_JAX_GIT_DIR&type=code) +`JAXCI_HERMETIC_PYTHON_VERSION` | System default | Controls the version of hermetic Python to use. This affects the Bazel commands only such as when building artifacts or when running the Bazel test scripts. | [Usage](https://github.com/search?q=repo%3Ajax-ml%2Fjax%20JAXCI_HERMETIC_PYTHON_VERSION&type=code) +`JAXCI_XLA_GIT_DIR` | Unset | When using a local copy of XLA, this points to the root of the XLA git repository. | [Usage](https://github.com/search?q=repo%3Ajax-ml%2Fjax%20JAXCI_XLA_GIT_DIR&type=code) +`JAXCI_CLONE_MAIN_XLA` | 0 | If set to 1, the XLA repository is cloned at HEAD and its path is set in `JAXCI_XLA_GIT_DIR` | [Usage](https://github.com/search?q=repo%3Ajax-ml%2Fjax%20JAXCI_CLONE_MAIN_XLA&type=code) +`JAXCI_XLA_COMMIT` | Unset | Allows overriding the XLA commit that is used when using a local copy of XLA. 
| [Usage](https://github.com/search?q=repo%3Ajax-ml%2Fjax%20JAXCI_XLA_COMMIT&type=code) +`JAXCI_OUTPUT_DIR` | `$(pwd)/dist` | Controls the location where the artifacts are written to. The directory will be automatically created if it does not exist. | [Usage](https://github.com/search?q=repo%3Ajax-ml%2Fjax%20JAXCI_OUTPUT_DIR&type=code) +`JAXCI_BUILD_ARTIFACT_WITH_RBE` | 0 | When set to 1, Bazel will use RBE to build the artifacts. Requires gcloud authentication and only certain platforms support RBE so this typically only set in CI builds | [Usage](https://github.com/search?q=repo%3Ajax-ml%2Fjax%20JAXCI_BUILD_ARTIFACT_WITH_RBE&type=code) +`JAXCI_WRITE_TO_BAZEL_REMOTE_CACHE` | 0 | When set to 1, Bazel will also try to push new cache entries to the cache bucket. Since writes to the bucket require authentication, this flag is enabled only for CI builds. Note that the builds using RBE use the RBE cache and not Bazel's remote cache, therefore this variable is a no-op if `JAXCI_BUILD_ARTIFACT_WITH_RBE` is set to 1. When `JAXCI_BUILD_ARTIFACT_WITH_RBE` and `JAXCI_WRITE_TO_BAZEL_REMOTE_CACHE` are both not set, Bazel will still read from the public cache bucket to try to speed up the build. | [Usage](https://github.com/search?q=repo%3Ajax-ml%2Fjax%20JAXCI_WRITE_TO_BAZEL_REMOTE_CACHE&type=code) +`JAXCI_ARTIFACT_TYPE` | "default" | Controls the type of artifacts to build. Valid values are "default", "release", "nightly". This affects the wheel tag and metadata, see [ci/build_artifacts.sh](https://github.com/jax-ml/jax/blob/main/ci/build_artifacts.sh) to understand how. | [Usage](https://github.com/search?q=repo%3Ajax-ml%2Fjax%20JAXCI_ARTIFACT_TYPE&type=code) +`JAXCI_WHEEL_RC_VERSION` | Unset | During the release process, we build a Release Candidate (RC) wheel in addition to the release wheel. This environment variable sets the version of the RC wheel to build. Values are set internally. | [Usage](https://github.com/search?q=repo%3Ajax-ml%2Fjax%20JAXCI_WHEEL_RC_VERSION&type=code) +`JAXCI_PYTHON` | `python${JAXCI_HERMETIC_PYTHON_VERSION}` | Points to the system Python binary to use. It used by scripts that make use of the system Python such as the Pytest scripts. | [Usage](https://github.com/search?q=repo%3Ajax-ml%2Fjax%20JAXCI_PYTHON&type=code) +`JAXCI_ENABLE_X64` | 0 | By default, JAX enforces single-precision numbers to mitigate the Numpy API’s tendency to aggressively promote operands to `double`. When set to 1, the tests will use double-precision numbers. | [Usage](https://github.com/search?q=repo%3Ajax-ml%2Fjax%20JAXCI_ENABLE_X64&type=code) +`JAXCI_TPU_CORES` | Unset | Sets the number of TPU cores for the TPU machine type. Values are set in the workflow files. | [Usage](https://github.com/search?q=repo%3Ajax-ml%2Fjax%20JAXCI_TPU_CORES&type=code) +`JAXCI_RUN_FULL_TPU_TEST_SUITE` | 0 | When set to 1, the full TPU test suite is run. Otherwise, a subset of tests is run. | [Usage](https://github.com/search?q=repo%3Ajax-ml%2Fjax%20JAXCI_RUN_FULL_TPU_TEST_SUITE&type=code) +`JAXCI_JAX_PYPI_EXTRAS` | Unset | Used to control the installation of JAX extras from PyPI. See JAX's [setup.py](https://github.com/jax-ml/jax/blob/c9934912885bb7c4b72c5a9271598235a6789a81/setup.py#L71) for the list of valid values. 
| [Usage](https://github.com/search?q=repo%3Ajax-ml%2Fjax%20JAXCI_JAX_PYPI_EXTRAS&type=code) + +## Docker Specific Environment Variables + +> [!NOTE] +> The following environment variables only affect the build if the +> [run_docker_container.sh](https://github.com/jax-ml/jax/blob/main/ci/utilities/run_docker_container.sh) +> script was invoked to start a Docker container and the build is running inside +> that container. Typically, this would be the internal CI builds and local +> builds. Note that while GitHub actions use the same Docker images, they do not +> invoke "run_docker_container.sh" as they leverage built-in containerization +> features to run jobs within a container. + +Name | Default Value | Behavior | Usage +----------------------- | ------------------------------------------------------------------------------------------------------------ | ---------------------------------------------------------------------------------------------------- | ----- +`JAXCI_DOCKER_WORK_DIR` | "/jax" | The path on the container where the JAX Git repository is mounted to. | [Usage](https://github.com/search?q=repo%3Ajax-ml%2Fjax%20JAXCI_DOCKER_WORK_DIR&type=code) +`JAXCI_DOCKER_ARGS` | Empty String | Space separated string of additional arguments that will be passed when starting the Docker container | [Usage](https://github.com/search?q=repo%3Ajax-ml%2Fjax%20JAXCI_DOCKER_ARGS&type=code) +`JAXCI_DOCKER_IMAGE` | Depends on the system (see [ci/envs/docker.env](https://github.com/jax-ml/jax/blob/main/ci/envs/docker.env)) | Docker image to pull | [Usage](https://github.com/search?q=repo%3Ajax-ml%2Fjax%20JAXCI_DOCKER_IMAGE&type=code) diff --git a/ci/envs/default.env b/ci/envs/default.env index a5a5d56eb8b3..09594af89cbe 100644 --- a/ci/envs/default.env +++ b/ci/envs/default.env @@ -13,9 +13,8 @@ # limitations under the License. # ============================================================================== # This file contains all the default values for the "JAXCI_" environment -# variables used in the CI scripts. These variables are used to control the -# behavior of the CI scripts such as the Python version used, path to JAX/XLA -# repo, if to clone XLA repo, etc. +# variables used in the CI scripts. See ci/envs/README.md for more details on +# the behavior of these variables and their usage in the CI scripts. # The path to the JAX git repository. export JAXCI_JAX_GIT_DIR=$(pwd) @@ -25,12 +24,10 @@ export JAXCI_JAX_GIT_DIR=$(pwd) export JAXCI_HERMETIC_PYTHON_VERSION=${JAXCI_HERMETIC_PYTHON_VERSION:-$(python3 -V | awk '{print $2}' | awk -F. '{print $1"."$2}')} # Set JAXCI_XLA_GIT_DIR to the root of the XLA git repository to use a local -# copy of XLA instead of the pinned version in the WORKSPACE. When -# JAXCI_CLONE_MAIN_XLA=1, this gets set automatically. +# copy of XLA instead of the pinned version in the WORKSPACE. export JAXCI_XLA_GIT_DIR=${JAXCI_XLA_GIT_DIR:-} -# If set to 1, the builds will clone the XLA repository at HEAD and set its -# path in JAXCI_XLA_GIT_DIR. +# If set to 1, the builds will clone the XLA repository at HEAD. export JAXCI_CLONE_MAIN_XLA=${JAXCI_CLONE_MAIN_XLA:-0} # Allows overriding the XLA commit that is used. @@ -39,49 +36,35 @@ export JAXCI_XLA_COMMIT=${JAXCI_XLA_COMMIT:-} # Controls the location where the artifacts are written to. export JAXCI_OUTPUT_DIR="$(pwd)/dist" -# When enabled, artifacts will be built with RBE. Requires gcloud authentication -# and only certain platforms support RBE. 
Therefore, this flag is enabled only -# for CI builds where RBE is supported. +# Whether to use RBE to build the artifacts. export JAXCI_BUILD_ARTIFACT_WITH_RBE=${JAXCI_BUILD_ARTIFACT_WITH_RBE:-0} -# On platforms where RBE is not supported, we use Bazel remote cache to speed up -# builds. When this flag is enabled, Bazel will also try to push new cache -# entries to the bucket. Since writes to the bucket require authentication, this -# flag is enabled only for CI builds. +# Whether to write new cache entries to the remote cache bucket. export JAXCI_WRITE_TO_BAZEL_REMOTE_CACHE=${JAXCI_WRITE_TO_BAZEL_REMOTE_CACHE:-0} -# Type of artifacts to build. Valid values are "default", "release", "nightly". -# This affects the wheel naming/tag. +# Controls the type of artifacts to build. Valid values are "default", "release", "nightly". export JAXCI_ARTIFACT_TYPE=${JAXCI_ARTIFACT_TYPE:-"default"} -# When building release artifacts, we build a release candidate wheel ("rc" -# tagged wheel) in addition to the release wheel. This environment variable -# sets the version of the release candidate ("RC") artifact to build. +# Controls the version of the Release Candidate wheel to build during the +# release process. export JAXCI_WHEEL_RC_VERSION=${JAXCI_WHEEL_RC_VERSION:-} # ############################################################################# # Test script specific environment variables. # ############################################################################# -# Sets the value of `JAX_ENABLE_X64` in the test scripts. CI builds override -# this value in the Github action workflow files. +# Whether to use double-precision numbers in the tests. export JAXCI_ENABLE_X64=${JAXCI_ENABLE_X64:-0} -# Pytest specific environment variables below. Used in run_pytest_*.sh scripts. -# Sets the number of TPU cores for the TPU machine type. These values are -# defined in the TPU GitHub Actions workflow. +# Sets the number of TPU cores for the TPU machine type. export JAXCI_TPU_CORES=${JAXCI_TPU_CORES:-} -# JAXCI_PYTHON points to the Python interpreter to use for installing JAX wheels -# on the system. By default, it is set to match the version of the hermetic -# Python used by Bazel for building the wheels. +# JAXCI_PYTHON points to the Python binary on the system that should be used +# for installing the JAX wheels on the system and running Pytest scripts. export JAXCI_PYTHON=${JAXCI_PYTHON:-python${JAXCI_HERMETIC_PYTHON_VERSION}} # When set to 1, the full TPU test suite is run. Otherwise, a subset of tests # is run. export JAXCI_RUN_FULL_TPU_TEST_SUITE=${JAXCI_RUN_FULL_TPU_TEST_SUITE:-0} -# We use this environment variable to control which additional wheels to install -# from PyPI. For instance, it can be set to "tpu_pypi" to install the latest -# libtpu wheel from PyPI. See ci/utilities/install_wheels_locally.sh for the -# list of valid values and their behavior. -export JAXCI_ADDITIONAL_WHEELS_INSTALL_FROM_PYPI=${JAXCI_ADDITIONAL_WHEELS_INSTALL_FROM_PYPI:-""} \ No newline at end of file +# Controls which additional extras for JAX to install from PyPI. +export JAXCI_JAX_PYPI_EXTRAS=${JAXCI_JAX_PYPI_EXTRAS:-""} \ No newline at end of file diff --git a/ci/envs/docker.env b/ci/envs/docker.env index 82a76d33350c..cef2cda27bf4 100644 --- a/ci/envs/docker.env +++ b/ci/envs/docker.env @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================== -# This file contains all the docker specifc envs that are needed by the +# This file contains all the docker specific envs that are needed by the # ci/utilities/run_docker_container.sh script. os=$(uname -s | awk '{print tolower($0)}') @@ -29,17 +29,17 @@ export JAXCI_DOCKER_ARGS="" # Linux x86 image for building JAX artifacts, running Pytests CPU/TPU tests, and # Bazel tests if [[ $os == "linux" ]] && [[ $arch == "x86_64" ]]; then - export JAXCI_DOCKER_IMAGE="us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest" + export JAXCI_DOCKER_IMAGE="us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build:latest" fi # Linux Aarch64 image for building JAX artifacts, running Pytests CPU tests, and # Bazel tests if [[ $os == "linux" ]] && [[ $arch == "aarch64" ]]; then - export JAXCI_DOCKER_IMAGE="us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-arm64:latest" + export JAXCI_DOCKER_IMAGE="us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build-arm64:latest" fi # Windows image for building JAX artifacts, running Pytests CPU tests, and Bazel # tests if [[ $os =~ "msys_nt" ]]; then - export JAXCI_DOCKER_IMAGE="us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/tf-test-windows@sha256:6e2b299f12418d70ea522646b3dd618042a102f2ac2e4f8b1e423638549ea801" + export JAXCI_DOCKER_IMAGE="us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/tf-test-windows:latest" fi \ No newline at end of file diff --git a/ci/jax_ci_system.png b/ci/jax_ci_system.png new file mode 100644 index 000000000000..19efe62ae59e Binary files /dev/null and b/ci/jax_ci_system.png differ diff --git a/ci/k8s/indexed-job.yaml b/ci/k8s/indexed-job.yaml new file mode 100644 index 000000000000..c38a8c9991a2 --- /dev/null +++ b/ci/k8s/indexed-job.yaml @@ -0,0 +1,42 @@ +apiVersion: v1 +kind: Service +metadata: + name: jaxpods +spec: + publishNotReadyAddresses: true + clusterIP: None + selector: + job-name: jaxjob +--- +apiVersion: batch/v1 +kind: Job +metadata: + name: jaxjob +spec: + parallelism: 8 + completions: 8 + completionMode: Indexed + backoffLimit: 0 + template: + spec: + subdomain: jaxpods # must match headless service name + serviceAccountName: jax-job-sa + restartPolicy: Never + containers: + - name: main + image: local/jax:latest + imagePullPolicy: IfNotPresent + resources: + limits: + cpu: 100m + command: + - python + args: + - -c + - | + import jax + jax.distributed.initialize() + print(jax.devices()) + print(jax.local_devices()) + assert jax.process_count() > 1 + assert len(jax.devices()) > len(jax.local_devices()) diff --git a/ci/k8s/jobset.yaml b/ci/k8s/jobset.yaml new file mode 100644 index 000000000000..00150d0a9095 --- /dev/null +++ b/ci/k8s/jobset.yaml @@ -0,0 +1,34 @@ +apiVersion: jobset.x-k8s.io/v1alpha2 +kind: JobSet +metadata: + name: jaxjob +spec: + replicatedJobs: + - name: workers + template: + spec: + parallelism: 8 + completions: 8 + backoffLimit: 0 + template: + spec: + serviceAccountName: jax-job-sa + restartPolicy: Never + containers: + - name: main + image: local/jax:latest + imagePullPolicy: Never + resources: + limits: + cpu: 100m + command: + - python + args: + - -c + - | + import jax + jax.distributed.initialize() + print(jax.devices()) + print(jax.local_devices()) + assert jax.process_count() > 1 + assert len(jax.devices()) > len(jax.local_devices()) diff --git a/ci/run_bazel_test_cpu_py_import_rbe.sh b/ci/run_bazel_test_cpu_py_import_rbe.sh new file 
mode 100755 index 000000000000..9a17397c47ff --- /dev/null +++ b/ci/run_bazel_test_cpu_py_import_rbe.sh @@ -0,0 +1,69 @@ +#!/bin/bash +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# Runs Bazel CPU tests with py_import on RBE. +# +# -e: abort script if one command fails +# -u: error if undefined variable used +# -x: log all commands +# -o history: record shell history +# -o allexport: export all functions and variables to be available to subscripts +set -exu -o history -o allexport + +# Source default JAXCI environment variables. +source ci/envs/default.env + +# Clone XLA at HEAD if path to local XLA is not provided +if [[ -z "$JAXCI_XLA_GIT_DIR" ]]; then + export JAXCI_CLONE_MAIN_XLA=1 +fi + +# Set up the build environment. +source "ci/utilities/setup_build_environment.sh" + +# Run Bazel CPU tests with RBE. +os=$(uname -s | awk '{print tolower($0)}') +arch=$(uname -m) + +echo "Running CPU tests..." +# When running on Mac or Linux Aarch64, we build the test targets on RBE +# and run the tests locally. These platforms do not have native RBE support so +# we RBE cross-compile them on remote Linux x86 machines. +if [[ $os == "darwin" ]] || ( [[ $os == "linux" ]] && [[ $arch == "aarch64" ]] ); then + bazel test --config=rbe_cross_compile_${os}_${arch} \ + --repo_env=HERMETIC_PYTHON_VERSION="$JAXCI_HERMETIC_PYTHON_VERSION" \ + --override_repository=xla="${JAXCI_XLA_GIT_DIR}" \ + --test_env=JAX_NUM_GENERATED_CASES=25 \ + --test_env=JAX_SKIP_SLOW_TESTS=true \ + --action_env=JAX_ENABLE_X64="$JAXCI_ENABLE_X64" \ + --test_output=errors \ + --color=yes \ + --strategy=TestRunner=local \ + --//jax:build_jaxlib=wheel \ + --//jax:build_jax=wheel \ + //tests:cpu_tests //tests:backend_independent_tests +else + bazel test --config=rbe_${os}_${arch} \ + --repo_env=HERMETIC_PYTHON_VERSION="$JAXCI_HERMETIC_PYTHON_VERSION" \ + --override_repository=xla="${JAXCI_XLA_GIT_DIR}" \ + --test_env=JAX_NUM_GENERATED_CASES=25 \ + --test_env=JAX_SKIP_SLOW_TESTS=true \ + --action_env=JAX_ENABLE_X64="$JAXCI_ENABLE_X64" \ + --test_output=errors \ + --color=yes \ + --//jax:build_jaxlib=wheel \ + --//jax:build_jax=wheel \ + //tests:cpu_tests //tests:backend_independent_tests +fi \ No newline at end of file diff --git a/ci/run_bazel_test_cpu_rbe.sh b/ci/run_bazel_test_cpu_rbe.sh index 248111e0247a..7eeb2adef0b3 100755 --- a/ci/run_bazel_test_cpu_rbe.sh +++ b/ci/run_bazel_test_cpu_rbe.sh @@ -53,7 +53,9 @@ if [[ $os == "darwin" ]] || ( [[ $os == "linux" ]] && [[ $arch == "aarch64" ]] ) --action_env=JAX_ENABLE_X64="$JAXCI_ENABLE_X64" \ --test_output=errors \ --color=yes \ - //tests:cpu_tests //tests:backend_independent_tests + //tests:cpu_tests //tests:backend_independent_tests \ + //jaxlib/tools:jaxlib_wheel_size_test \ + //:jax_wheel_size_test else echo "Running RBE CPU tests..." 
bazel test --config=rbe_${os}_${arch} \ @@ -64,5 +66,7 @@ else --action_env=JAX_ENABLE_X64="$JAXCI_ENABLE_X64" \ --test_output=errors \ --color=yes \ - //tests:cpu_tests //tests:backend_independent_tests + //tests:cpu_tests //tests:backend_independent_tests \ + //jaxlib/tools:jaxlib_wheel_size_test \ + //:jax_wheel_size_test fi \ No newline at end of file diff --git a/ci/run_bazel_test_cuda_non_rbe.sh b/ci/run_bazel_test_cuda_non_rbe.sh index 176efd3444c9..ce3a7562fea4 100755 --- a/ci/run_bazel_test_cuda_non_rbe.sh +++ b/ci/run_bazel_test_cuda_non_rbe.sh @@ -76,6 +76,7 @@ bazel test --config=ci_linux_x86_64_cuda \ --config=rbe_cache \ --repo_env=HERMETIC_PYTHON_VERSION="$JAXCI_HERMETIC_PYTHON_VERSION" \ --//jax:build_jaxlib=false \ + --//jax:build_jax=false \ --test_env=XLA_PYTHON_CLIENT_ALLOCATOR=platform \ --run_under "$(pwd)/build/parallel_accelerator_execute.sh" \ --test_output=errors \ @@ -102,6 +103,7 @@ bazel test --config=ci_linux_x86_64_cuda \ --config=rbe_cache \ --repo_env=HERMETIC_PYTHON_VERSION="$JAXCI_HERMETIC_PYTHON_VERSION" \ --//jax:build_jaxlib=false \ + --//jax:build_jax=false \ --test_env=XLA_PYTHON_CLIENT_ALLOCATOR=platform \ --test_output=errors \ --jobs=8 \ diff --git a/ci/run_bazel_test_cuda_rbe.sh b/ci/run_bazel_test_cuda_rbe.sh index 17bd8d9db4f8..94c6a89fdb8c 100755 --- a/ci/run_bazel_test_cuda_rbe.sh +++ b/ci/run_bazel_test_cuda_rbe.sh @@ -48,4 +48,10 @@ bazel test --config=rbe_linux_x86_64_cuda \ --test_env=JAX_SKIP_SLOW_TESTS=true \ --action_env=JAX_ENABLE_X64="$JAXCI_ENABLE_X64" \ --color=yes \ - //tests:gpu_tests //tests:backend_independent_tests //tests/pallas:gpu_tests //tests/pallas:backend_independent_tests \ No newline at end of file + --@local_config_cuda//cuda:override_include_cuda_libs=true \ + //tests:gpu_tests //tests:backend_independent_tests \ + //tests/pallas:gpu_tests //tests/pallas:backend_independent_tests \ + //jaxlib/tools:jax_cuda_plugin_wheel_size_test \ + //jaxlib/tools:jax_cuda_pjrt_wheel_size_test \ + //jaxlib/tools:jaxlib_wheel_size_test \ + //:jax_wheel_size_test \ No newline at end of file diff --git a/ci/run_pytest_cpu.sh b/ci/run_pytest_cpu.sh index 43581ef2c96c..9de29691f753 100755 --- a/ci/run_pytest_cpu.sh +++ b/ci/run_pytest_cpu.sh @@ -26,13 +26,13 @@ set -exu -o history -o allexport # Source default JAXCI environment variables. source ci/envs/default.env +# Set up the build environment. +source "ci/utilities/setup_build_environment.sh" + # Install jaxlib wheel inside the $JAXCI_OUTPUT_DIR directory on the system. echo "Installing wheels locally..." source ./ci/utilities/install_wheels_locally.sh -# Set up the build environment. 
-source "ci/utilities/setup_build_environment.sh" - # Print all the installed packages echo "Installed packages:" "$JAXCI_PYTHON" -m uv pip list diff --git a/ci/run_pytest_tpu.sh b/ci/run_pytest_tpu.sh index 5d8aa9ed648f..ef5a8cbef943 100755 --- a/ci/run_pytest_tpu.sh +++ b/ci/run_pytest_tpu.sh @@ -41,7 +41,9 @@ echo "Installed packages:" "$JAXCI_PYTHON" -c 'import sys; print("python version:", sys.version)' "$JAXCI_PYTHON" -c 'import jax; print("jax version:", jax.__version__)' "$JAXCI_PYTHON" -c 'import jaxlib; print("jaxlib version:", jaxlib.__version__)' -strings /usr/local/lib/"$JAXCI_PYTHON"/dist-packages/libtpu/libtpu.so | grep 'Built on' +# Free-threaded builds use "-nogil" as the suffix for the binary and "t" for its +# dist-packages path +strings /usr/local/lib/"${JAXCI_PYTHON//-nogil/t}"/dist-packages/libtpu/libtpu.so | grep 'Built on' "$JAXCI_PYTHON" -c 'import jax; print("libtpu version:",jax.lib.xla_bridge.get_backend().platform_version)' # Set up all common test environment variables diff --git a/ci/utilities/README.md b/ci/utilities/README.md new file mode 100644 index 000000000000..35af5241767b --- /dev/null +++ b/ci/utilities/README.md @@ -0,0 +1,16 @@ +# JAX CI Utility Scripts + +This docpage gives a brief overview of the different utility scripts and what +they are used for. + +- **setup_build_environment.sh**: Sets up the build environment such as + cloning the latest XLA, adjusting file paths (for Windows), etc. +- **convert_msys_paths_to_win_paths.py**: Converts MSYS Linux-like paths + stored in env variables to Windows paths. +- **install_wheels_locally.sh**: Used by Pytest scripts to install JAX wheels + and any additional extras on the system. +- **run_auditwheel.sh**: Verifies that the Linux artifacts are "manylinux" + compliant. +- **run_docker_container.sh**: Runs a Docker container called "jax". Images + are read from the `JAXCI_DOCKER_IMAGE` environment variable in + [ci/envs/docker.env](https://github.com/jax-ml/jax/blob/main/ci/envs/docker.env). diff --git a/ci/utilities/install_wheels_locally.sh b/ci/utilities/install_wheels_locally.sh index f98f7658ad18..d66e1fea967b 100644 --- a/ci/utilities/install_wheels_locally.sh +++ b/ci/utilities/install_wheels_locally.sh @@ -22,31 +22,34 @@ WHEELS=( $(/usr/bin/find "$JAXCI_OUTPUT_DIR/" -type f \( -name "*jax*py3*" -o - for i in "${!WHEELS[@]}"; do if [[ "${WHEELS[$i]}" == *jax*py3*none*any.whl ]]; then - if [[ "$JAXCI_ADDITIONAL_WHEELS_INSTALL_FROM_PYPI" == "tpu_pypi" ]]; then - # Append [tpu] to the jax wheel name to download the latest libtpu wheel - # from PyPI. - WHEELS[$i]="${WHEELS[$i]}[tpu]" + # Append an extra to the end of the JAX wheel path to install those + # packages as well from PyPI. E.g. jax[tpu] will install the libtpu package + # from PyPI. See ci/envs/README.md for more details. + if [[ -n "$JAXCI_JAX_PYPI_EXTRAS" ]]; then + WHEELS[$i]="${WHEELS[$i]}[$JAXCI_JAX_PYPI_EXTRAS]" fi fi done -if [[ -z "${WHEELS[@]}" ]]; then - echo "ERROR: No wheels found under $JAXCI_OUTPUT_DIR" - exit 1 -fi +if [[ -n "${WHEELS[@]}" ]]; then + echo "Installing the following wheels:" + echo "${WHEELS[@]}" -echo "Installing the following wheels:" -echo "${WHEELS[@]}" - -# Install `uv` if it's not already installed. `uv` is much faster than pip for -# installing Python packages. -if ! command -v uv >/dev/null 2>&1; then - pip install uv~=0.5.30 -fi + # Install `uv` if it's not already installed. `uv` is much faster than pip for + # installing Python packages. + if ! 
command -v uv >/dev/null 2>&1; then + pip install uv~=0.5.30 + fi -# On Windows, convert MSYS Linux-like paths to Windows paths. -if [[ $(uname -s) =~ "MSYS_NT" ]]; then - "$JAXCI_PYTHON" -m uv pip install $(cygpath -w "${WHEELS[@]}") + # On Windows, convert MSYS Linux-like paths to Windows paths. + if [[ $(uname -s) =~ "MSYS_NT" ]]; then + "$JAXCI_PYTHON" -m uv pip install $(cygpath -w "${WHEELS[@]}") + else + "$JAXCI_PYTHON" -m uv pip install "${WHEELS[@]}" + fi else - "$JAXCI_PYTHON" -m uv pip install "${WHEELS[@]}" + # Note that we don't exit here because the wheels may have been installed + # earlier in a different step in the CI job. + echo "INFO: No wheels found under $JAXCI_OUTPUT_DIR" + echo "INFO: Skipping local wheel installation." fi \ No newline at end of file diff --git a/ci/utilities/run_auditwheel.sh b/ci/utilities/run_auditwheel.sh index 30b6a3b51865..b8f80c3e6778 100755 --- a/ci/utilities/run_auditwheel.sh +++ b/ci/utilities/run_auditwheel.sh @@ -26,6 +26,10 @@ if [[ -z "$WHEELS" ]]; then fi for wheel in $WHEELS; do + # Skip checking manylinux compliance for jax wheel. + if [[ "$wheel" =~ 'jax-' ]]; then + continue + fi printf "\nRunning auditwheel on the following wheel:" ls $wheel OUTPUT_FULL=$(python -m auditwheel show $wheel) diff --git a/ci/utilities/run_docker_container.sh b/ci/utilities/run_docker_container.sh index b12566182331..e0a4592cdf6f 100755 --- a/ci/utilities/run_docker_container.sh +++ b/ci/utilities/run_docker_container.sh @@ -56,6 +56,8 @@ if ! docker container inspect jax >/dev/null 2>&1 ; then # variables to the container. JAXCI_TEMP_ENVFILE_DIR=$(mktemp) env | grep -e "JAXCI_" -e "JAX_" -e "JAXLIB_" > "$JAXCI_TEMP_ENVFILE_DIR" + # TODO(kanglan): Remove this once the rules python debug is done. + echo "RULES_PYTHON_REPO_DEBUG=${RULES_PYTHON_REPO_DEBUG:-0}" >> "$JAXCI_TEMP_ENVFILE_DIR" # On Windows, convert MSYS Linux-like paths to Windows paths. if [[ "$(uname -s)" =~ "MSYS_NT" ]]; then diff --git a/ci/utilities/setup_build_environment.sh b/ci/utilities/setup_build_environment.sh index 114acf2479ff..246665cd2f9f 100644 --- a/ci/utilities/setup_build_environment.sh +++ b/ci/utilities/setup_build_environment.sh @@ -16,7 +16,7 @@ # Set up the build environment for JAX CI jobs. This script depends on the # "JAXCI_" environment variables set or sourced in the build script. -# Pre-emptively mark the JAX git directory as safe. This is necessary for JAX CI +# Preemptively mark the JAX git directory as safe. This is necessary for JAX CI # jobs running on Linux runners in GitHub Actions. Without this, git complains # that the directory has dubious ownership and refuses to run any commands. # Avoid running on Windows runners as git runs into issues with not being able diff --git a/cloud_tpu_colabs/JAX_NeurIPS_2020_demo.ipynb b/cloud_tpu_colabs/JAX_NeurIPS_2020_demo.ipynb index edaa71b93e85..5bc045d0f606 100644 --- a/cloud_tpu_colabs/JAX_NeurIPS_2020_demo.ipynb +++ b/cloud_tpu_colabs/JAX_NeurIPS_2020_demo.ipynb @@ -225,7 +225,7 @@ "* Jacobian pre-accumulation for elementwise operations (like `gelu`)\n", "\n", "\n", - "For much more, see the [JAX Autodiff Cookbook (Part 1)](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html)." + "For much more, see the [JAX Autodiff Cookbook (Part 1)](https://docs.jax.dev/en/latest/notebooks/autodiff_cookbook.html)." 
] }, { diff --git a/cloud_tpu_colabs/JAX_demo.ipynb b/cloud_tpu_colabs/JAX_demo.ipynb index d7ba5ed334f4..b69246c57e0b 100644 --- a/cloud_tpu_colabs/JAX_demo.ipynb +++ b/cloud_tpu_colabs/JAX_demo.ipynb @@ -315,7 +315,7 @@ "* Jacobian pre-accumulation for elementwise operations (like `gelu`)\n", "\n", "\n", - "For much more, see the [JAX Autodiff Cookbook (Part 1)](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html)." + "For much more, see the [JAX Autodiff Cookbook (Part 1)](https://docs.jax.dev/en/latest/notebooks/autodiff_cookbook.html)." ] }, { diff --git a/cloud_tpu_colabs/Pmap_Cookbook.ipynb b/cloud_tpu_colabs/Pmap_Cookbook.ipynb index ea126ac4f1e7..8b16cd7694eb 100644 --- a/cloud_tpu_colabs/Pmap_Cookbook.ipynb +++ b/cloud_tpu_colabs/Pmap_Cookbook.ipynb @@ -59,7 +59,7 @@ "id": "2e_06-OAJNyi" }, "source": [ - "A basic starting point is expressing parallel maps with [`pmap`](https://jax.readthedocs.io/en/latest/jax.html#jax.pmap):" + "A basic starting point is expressing parallel maps with [`pmap`](https://docs.jax.dev/en/latest/jax.html#jax.pmap):" ] }, { @@ -407,7 +407,7 @@ "source": [ "When writing nested `pmap` functions in the decorator style, axis names are resolved according to lexical scoping.\n", "\n", - "Check [the JAX reference documentation](https://jax.readthedocs.io/en/latest/jax.lax.html#parallel-operators) for a complete list of the parallel operators. More are being added!\n", + "Check [the JAX reference documentation](https://docs.jax.dev/en/latest/jax.lax.html#parallel-operators) for a complete list of the parallel operators. More are being added!\n", "\n", "Here's how to use `lax.ppermute` to implement a simple halo exchange for a [Rule 30](https://en.wikipedia.org/wiki/Rule_30) simulation:" ] diff --git a/cloud_tpu_colabs/README.md b/cloud_tpu_colabs/README.md index db3dc5f30814..6e5501584da0 100644 --- a/cloud_tpu_colabs/README.md +++ b/cloud_tpu_colabs/README.md @@ -4,7 +4,7 @@ The same JAX code that runs on CPU and GPU can also be run on TPU. Cloud TPUs have the advantage of quickly giving you access to multiple TPU accelerators, including in [Colab](https://research.google.com/colaboratory/). All of the example notebooks here use -[`jax.pmap`](https://jax.readthedocs.io/en/latest/jax.html#jax.pmap) to run JAX +[`jax.pmap`](https://docs.jax.dev/en/latest/jax.html#jax.pmap) to run JAX computation across multiple TPU cores from Colab. You can also run the same code directly on a [Cloud TPU VM](https://cloud.google.com/tpu/docs/jax-quickstart-tpu-vm). diff --git a/conftest.py b/conftest.py index fed4564bbc1c..fa0e6de94346 100644 --- a/conftest.py +++ b/conftest.py @@ -21,6 +21,7 @@ def add_imports(doctest_namespace): import jax import numpy + doctest_namespace["jax"] = jax doctest_namespace["lax"] = jax.lax doctest_namespace["jnp"] = jax.numpy @@ -29,8 +30,8 @@ def add_imports(doctest_namespace): # A pytest hook that runs immediately before test collection (i.e. when pytest # loads all the test cases to run). When running parallel tests via xdist on -# Cloud TPU, we use this hook to set the env vars needed to run multiple test -# processes across different TPU chips. +# GPU or Cloud TPU, we use this hook to set the env vars needed to run multiple +# test processes across different chips. # # It's important that the hook runs before test collection, since jax tests end # up initializing the TPU runtime on import (e.g. 
to query supported test @@ -43,17 +44,31 @@ def add_imports(doctest_namespace): # https://docs.pytest.org/en/latest/how-to/writing_hook_functions.html#firstresult-stop-at-first-non-none-result # for details. # -# The env var JAX_ENABLE_TPU_XDIST must be set for this hook to have an +# For TPU, the env var JAX_ENABLE_TPU_XDIST must be set for this hook to have an # effect. We do this to minimize any effect on non-TPU tests, and as a pointer # in test code to this "magic" hook. TPU tests should not specify more xdist # workers than the number of TPU chips. +# +# For GPU, the env var JAX_ENABLE_CUDA_XDIST must be set equal to the number of +# CUDA devices. Test processes will be assigned in round robin fashion across +# the devices. def pytest_collection() -> None: - if not os.environ.get("JAX_ENABLE_TPU_XDIST", None): - return - # When running as an xdist worker, will be something like "gw0" - xdist_worker_name = os.environ.get("PYTEST_XDIST_WORKER", "") - if not xdist_worker_name.startswith("gw"): - return - xdist_worker_number = int(xdist_worker_name[len("gw"):]) - os.environ.setdefault("TPU_VISIBLE_CHIPS", str(xdist_worker_number)) - os.environ.setdefault("ALLOW_MULTIPLE_LIBTPU_LOAD", "true") + if os.environ.get("JAX_ENABLE_TPU_XDIST", None): + # When running as an xdist worker, will be something like "gw0" + xdist_worker_name = os.environ.get("PYTEST_XDIST_WORKER", "") + if not xdist_worker_name.startswith("gw"): + return + xdist_worker_number = int(xdist_worker_name[len("gw") :]) + os.environ.setdefault("TPU_VISIBLE_CHIPS", str(xdist_worker_number)) + os.environ.setdefault("ALLOW_MULTIPLE_LIBTPU_LOAD", "true") + + elif num_cuda_devices := os.environ.get("JAX_ENABLE_CUDA_XDIST", None): + num_cuda_devices = int(num_cuda_devices) + # When running as an xdist worker, will be something like "gw0" + xdist_worker_name = os.environ.get("PYTEST_XDIST_WORKER", "") + if not xdist_worker_name.startswith("gw"): + return + xdist_worker_number = int(xdist_worker_name[len("gw") :]) + os.environ.setdefault( + "CUDA_VISIBLE_DEVICES", str(xdist_worker_number % num_cuda_devices) + ) diff --git a/docs/README.md b/docs/README.md index 12e00425592f..54b8a67477b0 100644 --- a/docs/README.md +++ b/docs/README.md @@ -1,2 +1,2 @@ To rebuild the documentation, -see [Update Documentation](https://jax.readthedocs.io/en/latest/developer.html#update-documentation). +see [Update Documentation](https://docs.jax.dev/en/latest/developer.html#update-documentation). 
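For the GPU sharding hook added to `conftest.py` above, a hypothetical local
invocation would look like the following (a sketch; the worker count and test
file are placeholders, and `pytest-xdist` must be installed):

```bash
# Round-robin 8 xdist workers across 8 CUDA devices via the conftest.py hook.
JAX_ENABLE_CUDA_XDIST=8 pytest -n 8 tests/pjit_test.py
```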
diff --git a/docs/_static/multi_process/controller_and_local_devices.png b/docs/_static/multi_process/controller_and_local_devices.png new file mode 100644 index 000000000000..ad74cad65417 Binary files /dev/null and b/docs/_static/multi_process/controller_and_local_devices.png differ diff --git a/docs/_static/multi_process/mcjax_overview.png b/docs/_static/multi_process/mcjax_overview.png new file mode 100644 index 000000000000..dae947ff9df7 Binary files /dev/null and b/docs/_static/multi_process/mcjax_overview.png differ diff --git a/docs/_static/pallas/gpu/grid_tiling_off.svg b/docs/_static/pallas/gpu/grid_tiling_off.svg new file mode 100644 index 000000000000..b11d85759ce4 --- /dev/null +++ b/docs/_static/pallas/gpu/grid_tiling_off.svg @@ -0,0 +1,175 @@ + + + + + A (6x16 tiles) + B (16x16 tiles) + C = A @ B (6x16 tiles) + + + + + + + + diff --git a/docs/_static/pallas/gpu/grid_tiling_on.svg b/docs/_static/pallas/gpu/grid_tiling_on.svg new file mode 100644 index 000000000000..9d24a8187179 --- /dev/null +++ b/docs/_static/pallas/gpu/grid_tiling_on.svg @@ -0,0 +1,183 @@ + + + + + A (6x16 tiles) + B (16x16 tiles) + C = A @ B (6x16 tiles) + + + + + + + + diff --git a/docs/_static/pallas/gpu/memory_spaces.svg b/docs/_static/pallas/gpu/memory_spaces.svg new file mode 100644 index 000000000000..73dc31a12406 --- /dev/null +++ b/docs/_static/pallas/gpu/memory_spaces.svg @@ -0,0 +1,96 @@ + + + + + + Faster / Smaller Capacity + + + Slower / Larger Capacity + + + + + + Registers (RMEM) + Fastest Latency & BW + Smallest Capacity + + Holds arrays (in Pallas). + Spills if full! + + + + + Tensor Memory (TMEM) + Fastest Latency & BW + Smallest Capacity + + Explicitly managed. + Blackwell specific. + + + + + + Shared Memory (SMEM) + Fast (close to compute) + Small Capacity (per SM) + Partitioned into private slices for each CUDA block/cluster. + + + + L2 Cache + Moderate Speed + Moderate Capacity (~100MBs) + Shared betwen SMs, not directly programmable. + + + + Global Memory (GMEM) + Slowest Latency & Bandwidth + Largest Capacity (GBs) + Main GPU memory (HBM/GDDR technology). 
+ + + + + diff --git a/docs/_static/pallas/gpu/nvidia_sm.svg b/docs/_static/pallas/gpu/nvidia_sm.svg new file mode 100644 index 000000000000..76b4edb2afad --- /dev/null +++ b/docs/_static/pallas/gpu/nvidia_sm.svg @@ -0,0 +1,99 @@ + + + + + Streaming Multiprocessor + + + + + + Warp Scheduler + + TensorCore + + ALU + (Float/Int) + + Load/Store + + + Special + Functions + + + + + + + Warp Scheduler + + TensorCore + + ALU + (Float/Int) + + Load/Store + + + Special + Functions + + + + + + + Warp Scheduler + + TensorCore + + ALU + (Float/Int) + + Load/Store + + + Special + Functions + + + + + + + Warp Scheduler + + TensorCore + + ALU + (Float/Int) + + Load/Store + + + Special + Functions + + + + + + Shared Memory / L1 Cache + + + diff --git a/docs/_static/pallas/gpu/pipeline_matmul.svg b/docs/_static/pallas/gpu/pipeline_matmul.svg new file mode 100644 index 000000000000..7037695e33e9 --- /dev/null +++ b/docs/_static/pallas/gpu/pipeline_matmul.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/docs/_static/pallas/gpu/pipeline_matmul_ws.svg b/docs/_static/pallas/gpu/pipeline_matmul_ws.svg new file mode 100644 index 000000000000..3a07ba7e9ece --- /dev/null +++ b/docs/_static/pallas/gpu/pipeline_matmul_ws.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/docs/_static/pallas/gpu/warp_specialization.svg b/docs/_static/pallas/gpu/warp_specialization.svg new file mode 100644 index 000000000000..85fbce49fa0b --- /dev/null +++ b/docs/_static/pallas/gpu/warp_specialization.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/docs/_static/pallas/pipelining_bandwidth_bound.svg b/docs/_static/pallas/pipelining_bandwidth_bound.svg new file mode 100644 index 000000000000..45b78a7ce35e --- /dev/null +++ b/docs/_static/pallas/pipelining_bandwidth_bound.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/docs/_static/pallas/pipelining_compute_bound.svg b/docs/_static/pallas/pipelining_compute_bound.svg new file mode 100644 index 000000000000..cb3b58eaef99 --- /dev/null +++ b/docs/_static/pallas/pipelining_compute_bound.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/docs/_static/pallas/pipelining_example.svg b/docs/_static/pallas/pipelining_example.svg new file mode 100644 index 000000000000..59ca5b433b11 --- /dev/null +++ b/docs/_static/pallas/pipelining_example.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/docs/_static/pallas/pipelining_latency_multistage.svg b/docs/_static/pallas/pipelining_latency_multistage.svg new file mode 100644 index 000000000000..2c40f1692b9a --- /dev/null +++ b/docs/_static/pallas/pipelining_latency_multistage.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/docs/_static/pallas/pipelining_mem_hierarchy.svg b/docs/_static/pallas/pipelining_mem_hierarchy.svg new file mode 100644 index 000000000000..d7a2e6cbabd8 --- /dev/null +++ b/docs/_static/pallas/pipelining_mem_hierarchy.svg @@ -0,0 +1,30 @@ + + + + + + + + + + + + Registers + SRAM/Caches + DRAM/HBM + Network + + Fastest + Fast + Slow + Slowest + + Lowest Capacity + Low Capacity + High Capacity + Highest Capacity + + diff --git a/docs/about.md b/docs/about.md index 58e1703842b9..baeed941c8c3 100644 --- a/docs/about.md +++ b/docs/about.md @@ -19,7 +19,7 @@ technology stack](#components). 
First, we design the `jax` module to be [composable](https://github.com/jax-ml/jax?tab=readme-ov-file#transformations) and -[extensible](https://jax.readthedocs.io/en/latest/jax.extend.html), so +[extensible](https://docs.jax.dev/en/latest/jax.extend.html), so that a wide variety of domain-specific libraries can thrive outside of it in a decentralized manner. Second, we lean heavily on a modular backend stack (compiler and runtime) to target different @@ -42,10 +42,10 @@ scale. JAX's day-to-day development takes place in the open on GitHub, using pull requests, the issue tracker, discussions, and [JAX Enhancement Proposals -(JEPs)](https://jax.readthedocs.io/en/latest/jep/index.html). Reading +(JEPs)](https://docs.jax.dev/en/latest/jep/index.html). Reading and participating in these is a good way to get involved. We also maintain [developer -notes](https://jax.readthedocs.io/en/latest/contributor_guide.html) +notes](https://docs.jax.dev/en/latest/contributor_guide.html) that cover JAX's internal design. The JAX core team determines whether to accept changes and @@ -56,7 +56,7 @@ intricate decision structure over time (e.g. with designated area owners) if/when it becomes useful to do so. For more see [contributing to -JAX](https://jax.readthedocs.io/en/latest/contributing.html). +JAX](https://docs.jax.dev/en/latest/contributing.html). (components)= ## A modular stack @@ -71,7 +71,7 @@ and (b) an advancing hardware landscape, we lean heavily on While the JAX core library focuses on the fundamentals, we want to encourage domain-specific libraries and tools to be built on top of JAX. Indeed, [many -libraries](https://jax.readthedocs.io/en/latest/#ecosystem) have +libraries](https://docs.jax.dev/en/latest/#ecosystem) have emerged around JAX to offer higher-level features and extensions. How do we encourage such decentralized development? We guide it with @@ -80,11 +80,11 @@ building blocks (e.g. numerical primitives, NumPy operations, arrays, and transformations), encouraging auxiliary libraries to develop utilities as needed for their domain. In addition, JAX exposes a handful of more advanced APIs for -[customization](https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html) +[customization](https://docs.jax.dev/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html) and -[extensibility](https://jax.readthedocs.io/en/latest/jax.extend.html). Libraries +[extensibility](https://docs.jax.dev/en/latest/jax.extend.html). Libraries can [lean on these -APIs](https://jax.readthedocs.io/en/latest/building_on_jax.html) in +APIs](https://docs.jax.dev/en/latest/building_on_jax.html) in order to use JAX as an internal means of implementation, to integrate more with its transformations like autodiff, and more. diff --git a/docs/advanced-autodiff.md b/docs/advanced-autodiff.md index eaa3bc7317c8..bef2fd088a3a 100644 --- a/docs/advanced-autodiff.md +++ b/docs/advanced-autodiff.md @@ -876,7 +876,7 @@ There are two ways to define differentiation rules in JAX: 1. Using {func}`jax.custom_jvp` and {func}`jax.custom_vjp` to define custom differentiation rules for Python functions that are already JAX-transformable; and 2. Defining new `core.Primitive` instances along with all their transformation rules, for example to call into functions from other systems like solvers, simulators, or general numerical computing systems. -This notebook is about #1. 
To read instead about #2, refer to the [notebook on adding primitives](https://jax.readthedocs.io/en/latest/notebooks/How_JAX_primitives_work.html). +This notebook is about #1. To read instead about #2, refer to the [notebook on adding primitives](https://docs.jax.dev/en/latest/notebooks/How_JAX_primitives_work.html). ### TL;DR: Custom JVPs with {func}`jax.custom_jvp` @@ -1608,7 +1608,7 @@ Array(-0.91113025, dtype=float32) #### Working with `list` / `tuple` / `dict` containers (and other pytrees) -You should expect standard Python containers like lists, tuples, namedtuples, and dicts to just work, along with nested versions of those. In general, any [pytrees](https://jax.readthedocs.io/en/latest/pytrees.html) are permissible, so long as their structures are consistent according to the type constraints. +You should expect standard Python containers like lists, tuples, namedtuples, and dicts to just work, along with nested versions of those. In general, any [pytrees](https://docs.jax.dev/en/latest/pytrees.html) are permissible, so long as their structures are consistent according to the type constraints. Here's a contrived example with {func}`jax.custom_jvp`: diff --git a/docs/advanced_guide.rst b/docs/advanced_guide.rst index db2e83ae2720..1cc48b8959dd 100644 --- a/docs/advanced_guide.rst +++ b/docs/advanced_guide.rst @@ -14,6 +14,7 @@ operations. notebooks/Distributed_arrays_and_automatic_parallelization notebooks/explicit-sharding notebooks/shard_map + notebooks/host-offloading multi_process distributed_data_loading diff --git a/docs/aot.md b/docs/aot.md index 1fcf11ab945d..1870f8c55093 100644 --- a/docs/aot.md +++ b/docs/aot.md @@ -26,7 +26,7 @@ are arrays, JAX does the following in order: carries out this specialization by a process that we call _tracing_. During tracing, JAX stages the specialization of `F` to a jaxpr, which is a function in the [Jaxpr intermediate - language](https://jax.readthedocs.io/en/latest/jaxpr.html). + language](https://docs.jax.dev/en/latest/jaxpr.html). 2. **Lower** this specialized, staged-out computation to the XLA compiler's input language, StableHLO. @@ -49,7 +49,10 @@ some other features along the way. An example: >>> # Print the specialized, staged-out representation (as Jaxpr IR) >>> print(traced.jaxpr) -{ lambda ; a:i32[] b:i32[]. let c:i32[] = mul 2 a; d:i32[] = add c b in (d,) } +{ lambda ; a:i32[] b:i32[]. let + c:i32[] = mul 2:i32[] a + d:i32[] = add c b + in (d,) } >>> lowered = traced.lower() diff --git a/docs/api_compatibility.md b/docs/api_compatibility.md index 749c5907bc6b..985b2145c5c4 100644 --- a/docs/api_compatibility.md +++ b/docs/api_compatibility.md @@ -59,6 +59,11 @@ Any API or import path prefixed with an underscore is explicitly private, and may change without warning between JAX releases. We are working to move all private APIs into `jax._src` to make these expectations more clear. +### jaxlib +Any import path in the `jaxlib` package is considered private, and may change +without warning between releases. Some APIs defined in `jaxlib` have public +aliases in the `jax` package. + ### Legacy internal APIs In addition, there are several legacy modules that currently expose some private APIs without an underscore, including: @@ -91,7 +96,7 @@ guarantees of the main JAX package. If you have code that uses `jax.extend`, we would strongly recommend CI tests against JAX's nightly releases, so as to catch potential changes before they are released. 
-For details on `jax.extend`, see the [`jax.extend` module docuementation](https://jax.readthedocs.io/en/latest/jax.extend.html), or the design document, {ref}`jax-extend-jep`. +For details on `jax.extend`, see the [`jax.extend` module documentation](https://docs.jax.dev/en/latest/jax.extend.html), or the design document, {ref}`jax-extend-jep`. ## Numerics and randomness diff --git a/docs/autodidax.ipynb b/docs/autodidax.ipynb index 7ec91affa05d..f57ce09e0bf6 100644 --- a/docs/autodidax.ipynb +++ b/docs/autodidax.ipynb @@ -72,7 +72,7 @@ "outputs, we want to override primitive application and let different values\n", "flow through our program. For example, we might want to replace the\n", "application of every primitive with an application of [its JVP\n", - "rule](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html),\n", + "rule](https://docs.jax.dev/en/latest/notebooks/autodiff_cookbook.html),\n", "and let primal-tangent pairs flow through our program. Moreover, we want to be\n", "able to compose multiple transformations, leading to stacks of interpreters." ] @@ -2019,7 +2019,8 @@ "\n", " output = io.StringIO()\n", " c.module.operation.print(file=output)\n", - " compiled = xb.get_backend(None).compile(output.getvalue())\n", + " backend = xb.get_backend(None)\n", + " compiled = backend.compile_and_load(output.getvalue(), backend.devices()[:1])\n", " return partial(execute_compiled, compiled, [v.aval for v in jaxpr.outs])\n", "\n", "def _mlir_dtype(dtype: np.dtype) -> ir.Type:\n", @@ -3620,7 +3621,7 @@ "source": [ "Notice that we're not currently supporting the case where the predicate value\n", "itself is batched. In mainline JAX, we handle this case by transforming the\n", - "conditional to a [select primitive](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.select.html).\n", + "conditional to a [select primitive](https://docs.jax.dev/en/latest/_autosummary/jax.lax.select.html).\n", "That transformation is semantically correct so long as `true_fun` and\n", "`false_fun` do not involve any side-effecting primitives.\n", "\n", diff --git a/docs/autodidax.md b/docs/autodidax.md index 2d4d6cd528af..5bf0e8f78e12 100644 --- a/docs/autodidax.md +++ b/docs/autodidax.md @@ -72,7 +72,7 @@ where we apply primitive operations to numerical inputs to produce numerical outputs, we want to override primitive application and let different values flow through our program. For example, we might want to replace the application of every primitive with an application of [its JVP -rule](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html), +rule](https://docs.jax.dev/en/latest/notebooks/autodiff_cookbook.html), and let primal-tangent pairs flow through our program. Moreover, we want to be able to compose multiple transformations, leading to stacks of interpreters. @@ -1589,7 +1589,8 @@ def xla_callable(hashable_jaxpr: IDHashable, output = io.StringIO() c.module.operation.print(file=output) - compiled = xb.get_backend(None).compile(output.getvalue()) + backend = xb.get_backend(None) + compiled = backend.compile_and_load(output.getvalue(), backend.devices()[:1]) return partial(execute_compiled, compiled, [v.aval for v in jaxpr.outs]) def _mlir_dtype(dtype: np.dtype) -> ir.Type: @@ -2843,7 +2844,7 @@ print(out) Notice that we're not currently supporting the case where the predicate value itself is batched. In mainline JAX, we handle this case by transforming the -conditional to a [select primitive](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.select.html). 
+conditional to a [select primitive](https://docs.jax.dev/en/latest/_autosummary/jax.lax.select.html). That transformation is semantically correct so long as `true_fun` and `false_fun` do not involve any side-effecting primitives. diff --git a/docs/autodidax.py b/docs/autodidax.py index f8c6372fe30d..695fc9993df5 100644 --- a/docs/autodidax.py +++ b/docs/autodidax.py @@ -62,7 +62,7 @@ # outputs, we want to override primitive application and let different values # flow through our program. For example, we might want to replace the # application of every primitive with an application of [its JVP -# rule](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html), +# rule](https://docs.jax.dev/en/latest/notebooks/autodiff_cookbook.html), # and let primal-tangent pairs flow through our program. Moreover, we want to be # able to compose multiple transformations, leading to stacks of interpreters. @@ -1581,7 +1581,8 @@ def main(*params): output = io.StringIO() c.module.operation.print(file=output) - compiled = xb.get_backend(None).compile(output.getvalue()) + backend = xb.get_backend(None) + compiled = backend.compile_and_load(output.getvalue(), backend.devices()[:1]) return partial(execute_compiled, compiled, [v.aval for v in jaxpr.outs]) def _mlir_dtype(dtype: np.dtype) -> ir.Type: @@ -2837,7 +2838,7 @@ def cond_vmap_rule(axis_size, vals_in, dims_in, *, true_jaxpr, false_jaxpr): # Notice that we're not currently supporting the case where the predicate value # itself is batched. In mainline JAX, we handle this case by transforming the -# conditional to a [select primitive](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.select.html). +# conditional to a [select primitive](https://docs.jax.dev/en/latest/_autosummary/jax.lax.select.html). # That transformation is semantically correct so long as `true_fun` and # `false_fun` do not involve any side-effecting primitives. # diff --git a/docs/autodidax2_part1.ipynb b/docs/autodidax2_part1.ipynb index 0a5a89c8ed98..7a58f54b16c8 100644 --- a/docs/autodidax2_part1.ipynb +++ b/docs/autodidax2_part1.ipynb @@ -674,7 +674,7 @@ "something is constant with respect to differentiation? It's tempting to say\n", "\"it's a constant if and only if it's not a dual number\". But actually dual\n", "numbers created by a *different* JVPInterpreter also need to be considered\n", - "constants with resepect to the JVPInterpreter we're currently handling. That's\n", + "constants with respect to the JVPInterpreter we're currently handling. That's\n", "why we need the `x.interpreter is self` check in `JVPInterpreter.lift`. This\n", "comes up in higher order differentiation when there are multiple JVPInterprers\n", "in scope. The sort of bug where you accidentally interpret a dual number from\n", @@ -1046,7 +1046,7 @@ "That's it for part one of this tutorial. We've done two primitives, three\n", "interpreters and the tracing mechanism that weaves them together. In the next\n", "part we'll add types other than floats, error handling, compilation,\n", - "reverse-mode AD and higher-order primtives. Note that the second part is\n", + "reverse-mode AD and higher-order primitives. Note that the second part is\n", "structured differently. Rather than trying to have a top-to-bottom order that\n", "obeys both code dependencies (e.g. 
data structures need to be defined before\n", "they're used) and pedagogical dependencies (concepts need to be introduced\n", diff --git a/docs/autodidax2_part1.md b/docs/autodidax2_part1.md index 70dd0e4b696b..a4af594fb253 100644 --- a/docs/autodidax2_part1.md +++ b/docs/autodidax2_part1.md @@ -348,7 +348,7 @@ There are some subtleties worth discussing. First, how do you tell if something is constant with respect to differentiation? It's tempting to say "it's a constant if and only if it's not a dual number". But actually dual numbers created by a *different* JVPInterpreter also need to be considered -constants with resepect to the JVPInterpreter we're currently handling. That's +constants with respect to the JVPInterpreter we're currently handling. That's why we need the `x.interpreter is self` check in `JVPInterpreter.lift`. This comes up in higher order differentiation when there are multiple JVPInterprers in scope. The sort of bug where you accidentally interpret a dual number from @@ -539,7 +539,7 @@ print(jvp(lambda x: eval_jaxpr(build_jaxpr(foo, 1), (x,)), 2.0, 1.0)) That's it for part one of this tutorial. We've done two primitives, three interpreters and the tracing mechanism that weaves them together. In the next part we'll add types other than floats, error handling, compilation, -reverse-mode AD and higher-order primtives. Note that the second part is +reverse-mode AD and higher-order primitives. Note that the second part is structured differently. Rather than trying to have a top-to-bottom order that obeys both code dependencies (e.g. data structures need to be defined before they're used) and pedagogical dependencies (concepts need to be introduced diff --git a/docs/autodidax2_part1.py b/docs/autodidax2_part1.py index bfe59df359d3..44bf843c91b3 100644 --- a/docs/autodidax2_part1.py +++ b/docs/autodidax2_part1.py @@ -307,7 +307,7 @@ def nth_order_derivative(n, f, x): # something is constant with respect to differentiation? It's tempting to say # "it's a constant if and only if it's not a dual number". But actually dual # numbers created by a *different* JVPInterpreter also need to be considered -# constants with resepect to the JVPInterpreter we're currently handling. That's +# constants with respect to the JVPInterpreter we're currently handling. That's # why we need the `x.interpreter is self` check in `JVPInterpreter.lift`. This # comes up in higher order differentiation when there are multiple JVPInterprers # in scope. The sort of bug where you accidentally interpret a dual number from @@ -483,7 +483,7 @@ def eval_atom(x): return env[x] if isinstance(x, Var) else x # That's it for part one of this tutorial. We've done two primitives, three # interpreters and the tracing mechanism that weaves them together. In the next # part we'll add types other than floats, error handling, compilation, -# reverse-mode AD and higher-order primtives. Note that the second part is +# reverse-mode AD and higher-order primitives. Note that the second part is # structured differently. Rather than trying to have a top-to-bottom order that # obeys both code dependencies (e.g. data structures need to be defined before # they're used) and pedagogical dependencies (concepts need to be introduced diff --git a/docs/building_on_jax.md b/docs/building_on_jax.md index 9416b16cde10..6d13f517f50b 100644 --- a/docs/building_on_jax.md +++ b/docs/building_on_jax.md @@ -45,8 +45,8 @@ Here are more specific examples of each pattern. 
### Direct usage Jax can be directly imported and utilized to build models “from scratch” as shown across this website, -for example in [JAX Tutorials](https://jax.readthedocs.io/en/latest/tutorials.html) -or [Neural Network with JAX](https://jax.readthedocs.io/en/latest/notebooks/neural_network_with_tfds_data.html). +for example in [JAX Tutorials](https://docs.jax.dev/en/latest/tutorials.html) +or [Neural Network with JAX](https://docs.jax.dev/en/latest/notebooks/neural_network_with_tfds_data.html). This may be the best option if you are unable to find prebuilt code for your particular challenge, or if you're looking to reduce the number of dependencies in your codebase. diff --git a/docs/conf.py b/docs/conf.py index 45964b6d8d7e..3cd3b8ea8776 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -29,6 +29,7 @@ import inspect import operator import os +from pathlib import Path import sys sys.path.insert(0, os.path.abspath('..')) @@ -38,11 +39,11 @@ from typing import ForwardRef def _do_not_evaluate_in_jax( - self, globalns, *args, _evaluate=ForwardRef._evaluate, + self, globalns, *args, _evaluate=ForwardRef._evaluate, **kwargs, ): if globalns.get('__name__', '').startswith('jax'): return self - return _evaluate(self, globalns, *args) + return _evaluate(self, globalns, *args, **kwargs) ForwardRef._evaluate = _do_not_evaluate_in_jax @@ -80,6 +81,7 @@ def _do_not_evaluate_in_jax( "sphinx_remove_toctrees", 'sphinx_copybutton', 'jax_extensions', + 'jax_list_config_options', 'sphinx_design', 'sphinxext.rediraffe', ] @@ -132,6 +134,8 @@ def _do_not_evaluate_in_jax( # These are kept in sync using the jupytext pre-commit hook. 'notebooks/*.md', 'pallas/quickstart.md', + 'pallas/pipelining.md', + 'pallas/gpu/pipelining.md', 'pallas/tpu/pipelining.md', 'pallas/tpu/distributed.md', 'pallas/tpu/sparse.md', @@ -203,6 +207,8 @@ def _do_not_evaluate_in_jax( # -- Options for myst ---------------------------------------------- myst_heading_anchors = 3 # auto-generate 3 levels of heading anchors myst_enable_extensions = ['dollarmath'] +myst_ref_domains = ["py"] +myst_all_links_external = False nb_execution_mode = "force" nb_execution_allow_errors = False nb_merge_streams = True @@ -222,18 +228,19 @@ def _do_not_evaluate_in_jax( 'jep/9407-type-promotion.*', # TODO(jakevdp): enable execution on the following if possible: 'notebooks/Distributed_arrays_and_automatic_parallelization.*', - 'notebooks/explicit-sharding.*', 'notebooks/autodiff_remat.*', # Fails on readthedocs with Kernel Died 'notebooks/convolutions.ipynb', # Requires accelerators 'pallas/quickstart.*', + 'pallas/pipelining.*', + 'pallas/gpu/pipelining.*', 'pallas/tpu/pipelining.*', 'pallas/tpu/distributed.*', 'pallas/tpu/sparse.*', 'pallas/tpu/matmul.*', - 'sharded-computation.*', - 'distributed_data_loading.*' + 'distributed_data_loading.*', + 'notebooks/host-offloading.*', ] # -- Options for HTMLHelp output --------------------------------------------- @@ -352,7 +359,11 @@ def linkcode_resolve(domain, info): source, linenum = inspect.getsourcelines(obj) except: return None - filename = os.path.relpath(filename, start=os.path.dirname(jax.__file__)) + try: + filename = Path(filename).relative_to(Path(jax.__file__).parent) + except ValueError: + # Source file is not a relative to jax; this must be a re-exported function. 
+ return None lines = f"#L{linenum}-L{linenum + len(source)}" if linenum else "" return f"https://github.com/jax-ml/jax/blob/main/jax/{filename}{lines}" diff --git a/docs/config_options.rst b/docs/config_options.rst new file mode 100644 index 000000000000..a8ef4e93a834 --- /dev/null +++ b/docs/config_options.rst @@ -0,0 +1,66 @@ +.. _jax: + +.. This target is required to prevent the Sphinx build error "Unknown target name: jax". +.. The custom directive list_config_options imports JAX to extract real configuration +.. data, which causes Sphinx to look for a target named "jax". This dummy target +.. satisfies that requirement while allowing the actual JAX import to work. + +Configuration Options +===================== + +JAX provides various configuration options to customize its behavior. These options control everything from numerical precision to debugging features. + +How to Use Configuration Options +-------------------------------- + +JAX configuration options can be set in several ways: + +1. **Environment variables** (set before running your program): + + .. code-block:: bash + + export JAX_ENABLE_X64=True + python my_program.py + +2. **Runtime configuration** (in your Python code): + + .. code-block:: python + + import jax + jax.config.update("jax_enable_x64", True) + +3. **Command-line flags** (using Abseil): + + .. code-block:: python + + # In your code: + import jax + jax.config.parse_flags_with_absl() + + .. code-block:: bash + + # When running: + python my_program.py --jax_enable_x64=True + +Common Configuration Options +---------------------------- + +Here are some of the most frequently used configuration options: + +- ``jax_enable_x64`` -- Enable 64-bit floating-point precision +- ``jax_disable_jit`` -- Disable JIT compilation for debugging +- ``jax_debug_nans`` -- Check for and raise errors on NaNs +- ``jax_platforms`` -- Control which backends (CPU/GPU/TPU) JAX will initialize +- ``jax_numpy_rank_promotion`` -- Control automatic rank promotion behavior +- ``jax_default_matmul_precision`` -- Set default precision for matrix multiplication operations + +.. raw:: html + +
+ +All Configuration Options +------------------------- + +Below is a complete list of all available JAX configuration options: + +.. list_config_options:: diff --git a/docs/contributing.md b/docs/contributing.md index 99d78453c436..087432f1f771 100644 --- a/docs/contributing.md +++ b/docs/contributing.md @@ -6,7 +6,7 @@ Everyone can contribute to JAX, and we value everyone's contributions. There are ways to contribute, including: - Answering questions on JAX's [discussions page](https://github.com/jax-ml/jax/discussions) -- Improving or expanding JAX's [documentation](http://jax.readthedocs.io/) +- Improving or expanding JAX's [documentation](http://docs.jax.dev/) - Contributing to JAX's [code-base](http://github.com/jax-ml/jax/) - Contributing in any of the above ways to the broader ecosystem of [libraries built on JAX](https://github.com/jax-ml/jax#neural-network-libraries) @@ -30,13 +30,13 @@ We do all of our development using git, so basic knowledge is assumed. Follow these steps to contribute code: 1. Sign the [Google Contributor License Agreement (CLA)](https://cla.developers.google.com/). - For more information, see the Pull Request Checklist below. + For more information, see the {ref}`pr-checklist` below. 2. Fork the JAX repository by clicking the **Fork** button on the [repository page](http://www.github.com/jax-ml/jax). This creates a copy of the JAX repository in your own account. -3. Install Python >= 3.10 locally in order to run tests. +3. Install Python >= 3.11 locally in order to run tests. 4. `pip` installing your fork from source. This allows you to modify the code and immediately test it out: diff --git a/docs/control-flow.md b/docs/control-flow.md index 7cb959f3e434..8f59bd92add7 100644 --- a/docs/control-flow.md +++ b/docs/control-flow.md @@ -244,19 +244,19 @@ lax.cond(False, lambda x: x+1, lambda x: x-1, operand) `jax.lax` provides two other functions that allow branching on dynamic predicates: -- [`lax.select`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.select.html) is +- [`lax.select`](https://docs.jax.dev/en/latest/_autosummary/jax.lax.select.html) is like a batched version of `lax.cond`, with the choices expressed as pre-computed arrays rather than as functions. -- [`lax.switch`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.switch.html) is +- [`lax.switch`](https://docs.jax.dev/en/latest/_autosummary/jax.lax.switch.html) is like `lax.cond`, but allows switching between any number of callable choices. In addition, `jax.numpy` provides several numpy-style interfaces to these functions: -- [`jnp.where`](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.where.html) with +- [`jnp.where`](https://docs.jax.dev/en/latest/_autosummary/jax.numpy.where.html) with three arguments is the numpy-style wrapper of `lax.select`. -- [`jnp.piecewise`](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.piecewise.html) +- [`jnp.piecewise`](https://docs.jax.dev/en/latest/_autosummary/jax.numpy.piecewise.html) is a numpy-style wrapper of `lax.switch`, but switches on a list of boolean conditions rather than a single scalar index. -- [`jnp.select`](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.select.html) has +- [`jnp.select`](https://docs.jax.dev/en/latest/_autosummary/jax.numpy.select.html) has an API similar to `jnp.piecewise`, but the choices are given as pre-computed arrays rather than as functions. It is implemented in terms of multiple calls to `lax.select`. 
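A minimal, self-contained sketch of the distinction drawn in the list above (using only the public APIs named there); unlike Python `if` statements, all of these constructs accept traced values under `jax.jit`:

```python
import jax
import jax.numpy as jnp
from jax import lax

@jax.jit
def branch_demo(x, index):
    # lax.cond: pick one of two traced functions via a dynamic scalar predicate.
    a = lax.cond(x.sum() > 0, lambda v: v + 1, lambda v: v - 1, x)
    # lax.select: element-wise choice between two pre-computed arrays.
    b = lax.select(x > 0, x + 1, x - 1)
    # jnp.where with three arguments is the NumPy-style wrapper of lax.select.
    c = jnp.where(x > 0, x + 1, x - 1)
    # lax.switch: choose among any number of callables with a dynamic index.
    d = lax.switch(index, [jnp.sin, jnp.cos, jnp.tanh], x)
    return a, b, c, d

a, b, c, d = branch_demo(jnp.linspace(-1.0, 1.0, 5), 2)
```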
diff --git a/docs/default_dtypes.md b/docs/default_dtypes.md new file mode 100644 index 000000000000..629f7fb5c314 --- /dev/null +++ b/docs/default_dtypes.md @@ -0,0 +1,82 @@ +(default-dtypes)= +# Default dtypes and the X64 flag +JAX strives to meet the needs of a range of numerical computing practitioners, who +sometimes have conflicting preferences. When it comes to default dtypes, there are +two different camps: + +- Classic scientific computing practitioners (i.e. users of tools like {mod}`numpy` or + {mod}`scipy`) tend to value accuracy of computations foremost: such users would + prefer that computations default to the **widest available representation**: e.g. + floating point values should default to `float64`, integers to `int64`, etc. +- AI researchers (i.e. folks implementing and training neural networks) tend to value + speed over accuracy, to the point where they have developed special data types like + [bfloat16](https://en.wikipedia.org/wiki/Bfloat16_floating-point_format) and others + which deliberately discard the least significant bits in order to speed up computation. + For these users, the mere presence of a float64 value in their computation can lead + to programs that are slow at best, and incompatible with their hardware at worst! + These users would prefer that computations default to `float32` or `int32`. + +The main mechanism JAX offers for this is the `jax_enable_x64` flag, which controls +whether 64-bit values can be created at all. By default this flag is set to `False` +(serving the needs of AI researchers and practitioners), but can be set to `True` +by users who value accuracy over computational speed. + +## Default setting: 32-bits everywhere +By default `jax_enable_x64` is set to False, and so {mod}`jax.numpy` array creation +functions will default to returning 32-bit values. + +For example: +```python +>>> import jax.numpy as jnp + +>>> jnp.arange(5) +Array([0, 1, 2, 3, 4], dtype=int32) + +>>> jnp.zeros(5) +Array([0., 0., 0., 0., 0.], dtype=float32) + +>>> jnp.ones(5, dtype=int) +Array([1, 1, 1, 1, 1], dtype=int32) + +``` + +Beyond defaults, because 64-bit values can be so poisonous to AI workflows, having +this flag set to False prevents you from creating 64-bit arrays at all! For example: +``` +>>> jnp.arange(5, dtype='float64') # doctest: +SKIP +UserWarning: Explicitly requested dtype float64 requested in arange is not available, and will be +truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the +JAX_ENABLE_X64 shell environment variable. See https://github.com/jax-ml/jax#current-gotchas for more. +Array([0., 1., 2., 3., 4.], dtype=float32) +``` + +## The X64 flag: enabling 64-bit values +To work in the "other mode" where functions default to producing 64-bit values, you can set the +`jax_enable_x64` flag to `True`: +```python +import jax +import jax.numpy as jnp + +jax.config.update('jax_enable_x64', True) + +print(repr(jnp.arange(5))) +print(repr(jnp.zeros(5))) +print(repr(jnp.ones(5, dtype=int))) +``` +``` +Array([0, 1, 2, 3, 4], dtype=int64) +Array([0., 0., 0., 0., 0.], dtype=float64) +Array([1, 1, 1, 1, 1], dtype=int64) +``` + +The X64 configuration can also be set via the `JAX_ENABLE_X64` shell environment variable, +for example: +```bash +$ JAX_ENABLE_X64=1 python main.py +``` +The X64 flag is intended as a **global setting** that should have one value for your whole +program, set at the top of your main file. A common feature request is for the flag to +be contextually configurable (e.g. 
enabling X64 just for one section of a long program): +this turns out to be difficult to implement within JAX's programming model, where code +execution may happen in a different context than code compilation. There is ongoing work +exploring the feasibility of relaxing this constraint, so stay tuned! diff --git a/docs/developer.md b/docs/developer.md index 0affbba9ed36..e219e8517075 100644 --- a/docs/developer.md +++ b/docs/developer.md @@ -1,7 +1,7 @@ (building-from-source)= # Building from source - + First, obtain the JAX source code: @@ -374,7 +374,7 @@ in terms of files, not installations): --repo_env=HERMETIC_PYTHON_URL="https://remote/url/to/my_python.tgz" --repo_env=HERMETIC_PYTHON_SHA256= - # We assume that top-level folder in the tarbal is called "python", if it is + # We assume that top-level folder in the tarball is called "python", if it is # something different just pass additional HERMETIC_PYTHON_PREFIX parameter --repo_env=HERMETIC_PYTHON_URL="https://remote/url/to/my_python.tgz" --repo_env=HERMETIC_PYTHON_SHA256= @@ -455,7 +455,6 @@ which one is selected by specifying `HERMETIC_PYTHON_VERSION`. For example in `WORKSPACE` file: ``` requirements = { - "3.10": "//build:requirements_lock_3_10.txt", "3.11": "//build:requirements_lock_3_11.txt", "3.12": "//build:requirements_lock_3_12.txt", "3.13": "//build:requirements_lock_3_13.txt", @@ -466,16 +465,16 @@ requirements = { Then you can build and test different combinations of stuff without changing anything in your environment: ``` -# To build with scenario1 dependendencies: +# To build with scenario1 dependencies: bazel test --repo_env=HERMETIC_PYTHON_VERSION=3.13-scenario1 -# To build with scenario2 dependendencies: +# To build with scenario2 dependencies: bazel test --repo_env=HERMETIC_PYTHON_VERSION=3.13-scenario2 -# To build with default dependendencies: +# To build with default dependencies: bazel test --repo_env=HERMETIC_PYTHON_VERSION=3.13 -# To build with scenario1 dependendencies and custom Python 3.13 interpreter: +# To build with scenario1 dependencies and custom Python 3.13 interpreter: bazel test --repo_env=HERMETIC_PYTHON_VERSION=3.13-scenario1 --repo_env=HERMETIC_PYTHON_URL="file:///path/to/cpython.tar.gz" @@ -526,6 +525,11 @@ bazel test //tests:cpu_tests //tests:backend_independent_tests `//tests:gpu_tests` and `//tests:tpu_tests` are also available, if you have the necessary hardware. +You need to configure `cuda` to run `gpu` tests: +``` +python build/build.py build --wheels=jaxlib,jax-cuda-plugin,jax-cuda-pjrt --configure_only +``` + To use a preinstalled `jaxlib` instead of building it you first need to make it available in the hermetic Python. To install a specific version of `jaxlib` within hermetic Python run (using `jaxlib >= 0.4.26` as an example): @@ -785,7 +789,7 @@ desired formats, and which the `jupytext --sync` command recognizes when invoked #### Notebooks within the Sphinx build Some of the notebooks are built automatically as part of the pre-submit checks and -as part of the [Read the docs](https://jax.readthedocs.io/en/latest) build. +as part of the [Read the docs](https://docs.jax.dev/en/latest) build. The build will fail if cells raise errors. If the errors are intentional, you can either catch them, or tag the cell with `raises-exceptions` metadata ([example PR](https://github.com/jax-ml/jax/pull/2402/files)). You have to add this metadata by hand in the `.ipynb` file. 
It will be preserved when somebody else @@ -796,7 +800,7 @@ See `exclude_patterns` in [conf.py](https://github.com/jax-ml/jax/blob/main/docs ### Documentation building on `readthedocs.io` -JAX's auto-generated documentation is at . +JAX's auto-generated documentation is at . The documentation building is controlled for the entire project by the [readthedocs JAX settings](https://readthedocs.org/dashboard/jax). The current settings @@ -809,7 +813,7 @@ For each automated documentation build you can see the If you want to test the documentation generation on Readthedocs, you can push code to the `test-docs` branch. That branch is also built automatically, and you can -see the generated documentation [here](https://jax.readthedocs.io/en/test-docs/). If the documentation build +see the generated documentation [here](https://docs.jax.dev/en/test-docs/). If the documentation build fails you may want to [wipe the build environment for test-docs](https://docs.readthedocs.io/en/stable/guides/wipe-environment.html). For a local test, I was able to do it in a fresh directory by replaying the commands diff --git a/docs/export/export.md b/docs/export/export.md index 18cdcc6c51d0..95e47385997c 100644 --- a/docs/export/export.md +++ b/docs/export/export.md @@ -161,7 +161,7 @@ e.g., the inference system.) What **matters is when the exporting and consuming components were built**, not the time when the exporting and the compilation happen. For external JAX users, it is -[possible to run JAX and jaxlib at different versions](https://jax.readthedocs.io/en/latest/jep/9419-jax-versioning.html#how-are-jax-and-jaxlib-versioned); +[possible to run JAX and jaxlib at different versions](https://docs.jax.dev/en/latest/jep/9419-jax-versioning.html#how-are-jax-and-jaxlib-versioned); what matters is when the jaxlib release was built. To reduce chances of incompatibility, internal JAX users should: @@ -710,10 +710,7 @@ total 32 -rw-rw-r--@ 1 necula wheel 2333 Jun 19 11:04 jax_ir3_jit_my_fun_export.mlir ``` -Inside Google, you can turn on logging by using the `--vmodule` argument to -specify the logging levels for different modules, -e.g., `--vmodule=_export=3`. - +Set [`JAX_DEBUG_LOG_MODULES=jax._src.export`](https://docs.jax.dev/en/latest/config_options.html#jax_debug_log_modules) to enable extra debugging logging. (export_ensuring_compat)= ### Ensuring forward and backward compatibility diff --git a/docs/export/shape_poly.md b/docs/export/shape_poly.md index 9254030a4e1c..68da231c4a68 100644 --- a/docs/export/shape_poly.md +++ b/docs/export/shape_poly.md @@ -86,7 +86,7 @@ matching the structure of the arguments passed to it. The polymorphic shapes specification can be a pytree prefix in cases where one specification should apply to multiple arguments, as in the above example. -See [how optional parameters are matched to arguments](https://jax.readthedocs.io/en/latest/pytrees.html#applying-optional-parameters-to-pytrees). +See [how optional parameters are matched to arguments](https://docs.jax.dev/en/latest/pytrees.html#applying-optional-parameters-to-pytrees). A few examples of shape specifications: @@ -441,7 +441,7 @@ to {func}`jax.export.symbolic_shape` share a scope and can be mixed up in arithmetic operations. The result would also share the same scope. -You can re-use scopes: +You can reuse scopes: ```python >>> a, = export.symbolic_shape("a,", constraints=("a >= 8",)) @@ -609,7 +609,7 @@ Division had remainder 1 when computing the value of 'd'. 
Using the following polymorphic shapes specifications: args[0].shape = (b, b, 2*d). Obtained dimension variables: 'b' = 3 from specification 'b' for dimension args[0].shape[0] (= 3), . -Please see https://jax.readthedocs.io/en/latest/export/shape_poly.html#shape-assertion-errors for more details. +Please see https://docs.jax.dev/en/latest/export/shape_poly.html#shape-assertion-errors for more details. ``` diff --git a/docs/faq.rst b/docs/faq.rst index 44267f6f5f7d..25d1d9ffab57 100644 --- a/docs/faq.rst +++ b/docs/faq.rst @@ -4,7 +4,7 @@ Frequently asked questions (FAQ) .. comment RST primer for Sphinx: https://thomas-cokelaer.info/tutorials/sphinx/rest_syntax.html .. comment Some links referenced here. Use `JAX - The Sharp Bits`_ (underscore at the end) to reference -.. _JAX - The Sharp Bits: https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html +.. _JAX - The Sharp Bits: https://docs.jax.dev/en/latest/notebooks/Common_Gotchas_in_JAX.html We are collecting answers to frequently asked questions here. Contributions welcome! @@ -116,7 +116,7 @@ code in JAX's internal representation, typically because it makes heavy use of Python control flow such as ``for`` loops. For a handful of loop iterations, Python is OK, but if you need *many* loop iterations, you should rewrite your code to make use of JAX's -`structured control flow primitives `_ +`structured control flow primitives `_ (such as :func:`lax.scan`) or avoid wrapping the loop with ``jit`` (you can still use ``jit`` decorated functions *inside* the loop). @@ -422,7 +422,6 @@ for comparing JAX versus NumPy, making using of IPython's convenient `%time and %timeit magics`_:: import numpy as np - import jax.numpy as jnp import jax def f(x): # function we're benchmarking (works in both NumPy & JAX) @@ -431,7 +430,9 @@ for comparing JAX versus NumPy, making using of IPython's convenient x_np = np.ones((1000, 1000), dtype=np.float32) # same as JAX default dtype %timeit f(x_np) # measure NumPy runtime - %time x_jax = jax.device_put(x_np) # measure JAX device transfer time + # measure JAX device transfer time + %time x_jax = jax.device_put(x_np).block_until_ready() + f_jit = jax.jit(f) %time f_jit(x_jax).block_until_ready() # measure JAX compilation time %timeit f_jit(x_jax).block_until_ready() # measure JAX runtime @@ -454,8 +455,8 @@ performing matrix-matrix multiplication) to amortize the increased overhead of JAX/accelerators vs NumPy/CPU. For example, if we switch this example to use 10x10 input instead, JAX/GPU runs 10x slower than NumPy/CPU (100 µs vs 10 µs). -.. _To JIT or not to JIT: https://jax.readthedocs.io/en/latest/notebooks/thinking_in_jax.html#to-jit-or-not-to-jit -.. _Double (64 bit) precision: https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision +.. _To JIT or not to JIT: https://docs.jax.dev/en/latest/notebooks/thinking_in_jax.html#to-jit-or-not-to-jit +.. _Double (64 bit) precision: https://docs.jax.dev/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision .. _`%time and %timeit magics`: https://ipython.readthedocs.io/en/stable/interactive/magics.html#magic-time .. _Colab: https://colab.research.google.com/ @@ -841,12 +842,12 @@ reducing :code:`XLA_PYTHON_CLIENT_MEM_FRACTION` from the default of :code:`.75`, or setting :code:`XLA_PYTHON_CLIENT_PREALLOCATE=false`. For more details, please see the page on `JAX GPU memory allocation`_. -.. 
_JIT mechanics: https://jax.readthedocs.io/en/latest/notebooks/thinking_in_jax.html#jit-mechanics-tracing-and-static-variables -.. _External callbacks in JAX: https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html -.. _Pure callback example: https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html#example-pure-callback-with-custom-jvp -.. _IO callback example: https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html#exploring-jax-experimental-io-callback +.. _JIT mechanics: https://docs.jax.dev/en/latest/notebooks/thinking_in_jax.html#jit-mechanics-tracing-and-static-variables +.. _External callbacks in JAX: https://docs.jax.dev/en/latest/notebooks/external_callbacks.html +.. _Pure callback example: https://docs.jax.dev/en/latest/notebooks/external_callbacks.html#example-pure-callback-with-custom-jvp +.. _IO callback example: https://docs.jax.dev/en/latest/notebooks/external_callbacks.html#exploring-jax-experimental-io-callback .. _Heaviside Step Function: https://en.wikipedia.org/wiki/Heaviside_step_function .. _Sigmoid Function: https://en.wikipedia.org/wiki/Sigmoid_function .. _algebraic_simplifier.cc: https://github.com/openxla/xla/blob/33f815e190982dac4f20d1f35adb98497a382377/xla/hlo/transforms/simplifiers/algebraic_simplifier.cc#L4851 -.. _JAX GPU memory allocation: https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html +.. _JAX GPU memory allocation: https://docs.jax.dev/en/latest/gpu_memory_allocation.html .. _dynamic linker search pattern: https://man7.org/linux/man-pages/man8/ld.so.8.html diff --git a/docs/ffi.ipynb b/docs/ffi.ipynb index b622fba9d5bc..aafe9d56e82b 100644 --- a/docs/ffi.ipynb +++ b/docs/ffi.ipynb @@ -439,7 +439,7 @@ "As far as JAX is concerned, the foreign function is a black box that can't be inspected to determine the appropriate behavior when differentiated.\n", "Therefore, it is the {func}`~jax.ffi.ffi_call` user's responsibility to define a custom derivative rule.\n", "\n", - "More details about custom derivative rules can be found in the [custom derivatives tutorial](https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html), but the most common pattern used for implementing differentiation for foreign functions is to define a {func}`~jax.custom_vjp` which itself calls a foreign function.\n", + "More details about custom derivative rules can be found in the [custom derivatives tutorial](https://docs.jax.dev/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html), but the most common pattern used for implementing differentiation for foreign functions is to define a {func}`~jax.custom_vjp` which itself calls a foreign function.\n", "In this case, we actually define two new FFI calls:\n", "\n", "1. `rms_norm_fwd` returns two outputs: (a) the \"primal\" result, and (b) the \"residuals\" which are used in the backwards pass.\n", @@ -730,13 +730,13 @@ "source": [ "This clearly (to us!) 
isn't the optimal partitioning of this function, but it's the best that JAX/XLA can do with the information given.\n", "\n", - "To generate better partitioning logic, we can use {func}`~jax.experimental.shard_map.shard_map` or {func}`~jax.experimental.custom_partitioning.custom_partitioning`, and we discuss both options here.\n", + "To generate better partitioning logic, we can use {func}`~jax.shard_map` or {func}`~jax.experimental.custom_partitioning.custom_partitioning`, and we discuss both options here.\n", "That being said, it's not straightforward to generate _optimal_ partitioning for all inputs, because sometimes this would require algorithmic changes.\n", "Specifically, let's add support for \"batch partitioning\", which handles the case where the data are sharded on batch dimensions, but sharding on the last dimension will always require in re-sharding.\n", "\n", "### Using `shard_map`\n", "\n", - "If you are using manual sharding control via {func}`~jax.experimental.shard_map.shard_map`, any FFI calls in your program should already partition appropriately:" + "If you are using manual sharding control via {func}`~jax.shard_map`, any FFI calls in your program should already partition appropriately:" ] }, { @@ -746,9 +746,8 @@ "outputs": [], "source": [ "from functools import partial\n", - "from jax.experimental.shard_map import shard_map\n", "\n", - "@partial(shard_map, mesh=mesh, in_specs=P(\"x\", None), out_specs=P(\"x\", None))\n", + "@partial(jax.shard_map, mesh=mesh, in_specs=P(\"x\", None), out_specs=P(\"x\", None))\n", "def rms_norm_shmap(x):\n", " return rms_norm(x)\n", "\n", @@ -781,11 +780,11 @@ "source": [ "### Using `custom partitioning`\n", "\n", - "If you can't use {func}`~jax.experimental.shard_map.shard_map`, an alternative approach is to use {func}`~jax.experimental.custom_partitioning.custom_partitioning`, which supports automatic parallelization via {func}`jax.jit`.\n", + "If you can't use {func}`~jax.shard_map`, an alternative approach is to use {func}`~jax.experimental.custom_partitioning.custom_partitioning`, which supports automatic parallelization via {func}`jax.jit`.\n", "{func}`~jax.experimental.custom_partitioning.custom_partitioning` works by adding Python callbacks into the XLA compiler's partitioning pass, which allows very flexible logic, but also comes with some rough edges.\n", "We won't go into too much detail on the caveats here, but the main issues that you should be aware of are:\n", "\n", - "1. `custom_partitioning` can cause unexpected cache misses when used with the JAX's [Persistent compilation cache](https://jax.readthedocs.io/en/latest/persistent_compilation_cache.html). This can be mitigated using the `jax_remove_custom_partitioning_ptr_from_cache_key` configuration flag, but that isn't always appropriate either.\n", + "1. `custom_partitioning` can cause unexpected cache misses when used with the JAX's [Persistent compilation cache](https://docs.jax.dev/en/latest/persistent_compilation_cache.html). This can be mitigated using the `jax_remove_custom_partitioning_ptr_from_cache_key` configuration flag, but that isn't always appropriate either.\n", "2. Debugging `custom_partitioning` logic can be tedious because Python errors don't always get propagated, instead causing your Python process to exit. 
That being said, any exceptions will show up in the process logs, so you should be able to track them down there.\n", "\n", "All that being said, here's how we can wrap our FFI implementation of `rms_norm` using {func}`~jax.experimental.custom_partitioning.custom_partitioning`:" diff --git a/docs/ffi.md b/docs/ffi.md index 4aa03c217855..106b8118f1ab 100644 --- a/docs/ffi.md +++ b/docs/ffi.md @@ -353,7 +353,7 @@ Unlike with batching, {func}`~jax.ffi.ffi_call` doesn't provide any default supp As far as JAX is concerned, the foreign function is a black box that can't be inspected to determine the appropriate behavior when differentiated. Therefore, it is the {func}`~jax.ffi.ffi_call` user's responsibility to define a custom derivative rule. -More details about custom derivative rules can be found in the [custom derivatives tutorial](https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html), but the most common pattern used for implementing differentiation for foreign functions is to define a {func}`~jax.custom_vjp` which itself calls a foreign function. +More details about custom derivative rules can be found in the [custom derivatives tutorial](https://docs.jax.dev/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html), but the most common pattern used for implementing differentiation for foreign functions is to define a {func}`~jax.custom_vjp` which itself calls a foreign function. In this case, we actually define two new FFI calls: 1. `rms_norm_fwd` returns two outputs: (a) the "primal" result, and (b) the "residuals" which are used in the backwards pass. @@ -556,19 +556,18 @@ print(hlo.split("\n\n")[-1]) This clearly (to us!) isn't the optimal partitioning of this function, but it's the best that JAX/XLA can do with the information given. -To generate better partitioning logic, we can use {func}`~jax.experimental.shard_map.shard_map` or {func}`~jax.experimental.custom_partitioning.custom_partitioning`, and we discuss both options here. +To generate better partitioning logic, we can use {func}`~jax.shard_map` or {func}`~jax.experimental.custom_partitioning.custom_partitioning`, and we discuss both options here. That being said, it's not straightforward to generate _optimal_ partitioning for all inputs, because sometimes this would require algorithmic changes. Specifically, let's add support for "batch partitioning", which handles the case where the data are sharded on batch dimensions, but sharding on the last dimension will always require in re-sharding. ### Using `shard_map` -If you are using manual sharding control via {func}`~jax.experimental.shard_map.shard_map`, any FFI calls in your program should already partition appropriately: +If you are using manual sharding control via {func}`~jax.shard_map`, any FFI calls in your program should already partition appropriately: ```{code-cell} ipython3 from functools import partial -from jax.experimental.shard_map import shard_map -@partial(shard_map, mesh=mesh, in_specs=P("x", None), out_specs=P("x", None)) +@partial(jax.shard_map, mesh=mesh, in_specs=P("x", None), out_specs=P("x", None)) def rms_norm_shmap(x): return rms_norm(x) @@ -587,11 +586,11 @@ assert "all-to-all" in hlo_data_shmap ### Using `custom partitioning` -If you can't use {func}`~jax.experimental.shard_map.shard_map`, an alternative approach is to use {func}`~jax.experimental.custom_partitioning.custom_partitioning`, which supports automatic parallelization via {func}`jax.jit`. 
+If you can't use {func}`~jax.shard_map`, an alternative approach is to use {func}`~jax.experimental.custom_partitioning.custom_partitioning`, which supports automatic parallelization via {func}`jax.jit`. {func}`~jax.experimental.custom_partitioning.custom_partitioning` works by adding Python callbacks into the XLA compiler's partitioning pass, which allows very flexible logic, but also comes with some rough edges. We won't go into too much detail on the caveats here, but the main issues that you should be aware of are: -1. `custom_partitioning` can cause unexpected cache misses when used with the JAX's [Persistent compilation cache](https://jax.readthedocs.io/en/latest/persistent_compilation_cache.html). This can be mitigated using the `jax_remove_custom_partitioning_ptr_from_cache_key` configuration flag, but that isn't always appropriate either. +1. `custom_partitioning` can cause unexpected cache misses when used with the JAX's [Persistent compilation cache](https://docs.jax.dev/en/latest/persistent_compilation_cache.html). This can be mitigated using the `jax_remove_custom_partitioning_ptr_from_cache_key` configuration flag, but that isn't always appropriate either. 2. Debugging `custom_partitioning` logic can be tedious because Python errors don't always get propagated, instead causing your Python process to exit. That being said, any exceptions will show up in the process logs, so you should be able to track them down there. All that being said, here's how we can wrap our FFI implementation of `rms_norm` using {func}`~jax.experimental.custom_partitioning.custom_partitioning`: diff --git a/docs/ffi/CMakeLists.txt b/docs/ffi/CMakeLists.txt index 9d3e9df7d3bf..b7f1af5c1a1b 100644 --- a/docs/ffi/CMakeLists.txt +++ b/docs/ffi/CMakeLists.txt @@ -4,7 +4,7 @@ project(rms_norm LANGUAGES CXX) find_package(Python 3.8 COMPONENTS Interpreter Development.Module REQUIRED) execute_process( COMMAND "${Python_EXECUTABLE}" - "-c" "from jax.extend import ffi; print(ffi.include_dir())" + "-c" "from jax import ffi; print(ffi.include_dir())" OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_VARIABLE XLA_DIR) message(STATUS "XLA include directory: ${XLA_DIR}") diff --git a/docs/gpu_memory_allocation.rst b/docs/gpu_memory_allocation.rst index 6667589e7b72..be40dfc8004c 100644 --- a/docs/gpu_memory_allocation.rst +++ b/docs/gpu_memory_allocation.rst @@ -69,7 +69,7 @@ Common causes of OOM failures disabling the automatic remat pass produces different trade-offs between compute and memory. Note however, that the algorithm is basic and you can often get better trade-off between compute and memory by disabling the automatic remat pass and doing - it manually with `the jax.remat API `_ + it manually with `the jax.remat API `_ Experimental features diff --git a/docs/gpu_performance_tips.md b/docs/gpu_performance_tips.md index bf032dccff88..f62523631872 100644 --- a/docs/gpu_performance_tips.md +++ b/docs/gpu_performance_tips.md @@ -1,6 +1,6 @@ # GPU performance tips - + This document focuses on performance tips for neural network workloads @@ -58,7 +58,173 @@ training on Nvidia GPUs](https://github.com/NVIDIA/JAX-Toolbox/blob/main/rosetta * **--xla_gpu_triton_gemm_any** Use the Triton-based GEMM (matmul) emitter for any GEMM that it supports. The default value is False. 
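XLA GPU flags like this one are typically supplied through the `XLA_FLAGS` environment variable, either exported in the shell before launching Python or set from Python before JAX initializes its backends. A minimal sketch, assuming the flag is set before the first JAX computation runs:

```python
import os

# Append the flag to any flags already present; XLA_FLAGS should be set before
# JAX initializes its backends, so do this at the very top of your entry point.
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_gpu_triton_gemm_any=true"
)

import jax
import jax.numpy as jnp

x = jnp.ones((1024, 1024))
y = jax.jit(lambda a: a @ a)(x)  # this GEMM may now use the Triton emitter
```

Setting the variable in the shell (`export XLA_FLAGS=...`) before launching the program works equally well.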
-### Communication flags
+## Communication tips
+
+### Auto and manual PGLE
+
+The Profile Guided Latency Estimator (PGLE) workflow measures the actual running time
+of compute and collectives, and feeds that profile information back into the XLA
+compiler so that it can make better scheduling decisions.
+
+The Profile Guided Latency Estimator can be used manually or automatically. In auto mode,
+JAX collects profile information and recompiles a module within a single run. In manual
+mode you need to run the task twice: the first time to collect and save the profiles,
+and the second to compile and run with the provided data.
+
+**Important**: the JAX profiler, which is used by both of the PGLE workflows documented
+below, cannot co-exist with the NVIDIA Nsight Systems profiler. This limitation can be
+avoided by using the JAX compilation cache, as described below.
+
+### Auto PGLE
+Auto PGLE can be turned on by setting the following environment variables:
+
+Mandatory:
+```bash
+export JAX_ENABLE_PGLE=true
+
+# For JAX version <= 0.5.0 make sure to include:
+export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true"
+```
+
+Optional:
+```bash
+export JAX_PGLE_PROFILING_RUNS=3
+export JAX_PGLE_AGGREGATION_PERCENTILE=85
+
+# Right now auto PGLE profile collection doesn't work with the command buffer.
+# If the command buffer is enabled, auto PGLE will disable it during profile
+# collection and re-enable it after the recompilation. If you need consistent
+# command buffer logic with and without the PGLE profile, you can disable it
+# manually:
+export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_enable_command_buffer=''"
+```
+
+Or this can be set from within JAX as follows:
+
+```
+import jax
+from jax._src import config
+
+with config.enable_pgle(True), config.pgle_profiling_runs(1):
+  # Run with the profiler collecting performance information.
+  train_step()
+  # Automatically re-compile with PGLE profile results
+  train_step()
+  ...
+```
+
+You can control the number of reruns used to collect profile data by changing
+`JAX_PGLE_PROFILING_RUNS`. Increasing this parameter leads to better profile
+information, but it also increases the number of non-optimized training steps.
+
+Decreasing the `JAX_PGLE_AGGREGATION_PERCENTILE` parameter can help filter out
+irrelevant measurements when performance between steps is too noisy.
+
+**Attention:** Auto PGLE doesn't work for pre-compiled modules. Since JAX needs to
+recompile the module during execution, auto PGLE works neither for AoT compilation
+nor for the following case:
+
+```
+import jax
+from jax._src import config
+
+train_step_compiled = train_step().lower().compile()
+
+with config.enable_pgle(True), config.pgle_profiling_runs(1):
+  train_step_compiled()
+  # No effect since module was pre-compiled.
+  train_step_compiled()
+```
+
+#### Collecting NVIDIA Nsight Systems profiles when using AutoPGLE
+[jax#24910](https://github.com/jax-ml/jax/pull/24910) (JAX v0.5.1 and newer) added a
+new JAX configuration option, `JAX_COMPILATION_CACHE_EXPECT_PGLE`, which tells JAX to
+attempt to load PGLE-optimized compiled functions from the persistent compilation
+cache.
+
+This allows a two-step process, where the first step writes a PGLE-optimized function
+to the cache:
+```bash
+export JAX_ENABLE_COMPILATION_CACHE=yes # not strictly needed, on by default
+export JAX_COMPILATION_CACHE_DIR=/root/jax_cache
+JAX_ENABLE_PGLE=yes python my-model.py
+```
+And the second step uses Nsight Systems and loads the PGLE-optimized function from the
+cache:
+```bash
+JAX_COMPILATION_CACHE_EXPECT_PGLE=yes nsys profile python my-model.py
+```
+See also [this page](
+https://docs.jax.dev/en/latest/persistent_compilation_cache.html#pitfalls) for more
+information about the persistent compilation cache and possible pitfalls.
+
+### Manual PGLE
+
+If you still want to use the manual Profile Guided Latency Estimator, the workflow in XLA/GPU is:
+
+- 1. Run your workload once, with async collectives and latency hiding scheduler enabled.
+
+You could do so by setting:
+
+```bash
+export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true"
+```
+
+- 2. Collect and post-process a profile by using the JAX profiler, saving the extracted instruction latencies into a binary protobuf file.
+
+```python
+import logging
+import os
+from etils import epath
+import jax
+from jax.experimental import profiler as exp_profiler
+
+# Define your profile directory
+profile_dir = 'gs://my_bucket/profile'
+jax.profiler.start_trace(profile_dir)
+
+# run your workflow
+# for i in range(10):
+#   train_step()
+
+# Stop trace
+jax.profiler.stop_trace()
+profile_dir = epath.Path(profile_dir)
+directories = profile_dir.glob('plugins/profile/*/')
+directories = [d for d in directories if d.is_dir()]
+rundir = directories[-1]
+logging.info('rundir: %s', rundir)
+
+# Post process the profile
+fdo_profile = exp_profiler.get_profiled_instructions_proto(os.fspath(rundir))
+
+# Save the profile proto to a file.
+dump_dir = rundir / 'profile.pb'
+dump_dir.parent.mkdir(parents=True, exist_ok=True)
+dump_dir.write_bytes(fdo_profile)
+
+```
+
+After this step, you will get a `profile.pb` file under the `rundir` printed in the code.
+
+- 3. Run the workload again, feeding that file into the compilation.
+
+You need to pass the `profile.pb` file to the `--xla_gpu_pgle_profile_file_or_directory_path` flag.
+
+```bash
+ export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_pgle_profile_file_or_directory_path=/path/to/profile/profile.pb"
+```
+
+To enable logging in XLA and check whether the profile is good, set the logging level to include `INFO`:
+
+```bash
+export TF_CPP_MIN_LOG_LEVEL=0
+```
+
+Run the real workflow again; if you find these lines in the running log, it means the profile is being used by the latency hiding scheduler:
+
+```
+2023-07-21 16:09:43.551600: I external/xla/xla/service/gpu/gpu_hlo_schedule.cc:478] Using PGLE profile from /tmp/profile/plugins/profile/2023_07_20_18_29_30/profile.pb
+2023-07-21 16:09:43.551741: I external/xla/xla/service/gpu/gpu_hlo_schedule.cc:573] Found profile, using profile guided latency estimator
+```
+
+#### Flags
 
 * **--xla_gpu_enable_latency_hiding_scheduler** This flag enables latency hiding
   schedulers to overlap asynchronous communication with computation efficiently.
@@ -77,20 +243,6 @@ training on Nvidia GPUs](https://github.com/NVIDIA/JAX-Toolbox/blob/main/rosetta
   By adjusting this factor, users can fine-tune the trade-off between memory
   efficiency and performance optimizations.
 
-* **--xla_gpu_enable_pipelined_collectives** When using pipeline parallelism,
-  this flag enables overlapping the (i+1)-th layer weight `AllGather` with the
-  i-th layer computation.
It also enables overlapping (i+1)-th layer - weight `Reduce`/`ReduceScatter` with i-th layer's computation. The default - value is False. **There are some bugs when this flag is turned on.** -* **--xla_gpu_collective_permute_decomposer_threshold** This flag is useful when - performing [GSPMD pipelining](https://arxiv.org/abs/2105.04663). Setting a - nonzero threshold decomposes `CollectivePermute`s into - `CollectivePermuteReceiveDone` and `CollectivePermuteSendDone` pairs, so that - computation can be performed between each corresponding - `ReceiveDone`/`SendDone` pair and hence achieve more overlap. By default the - threshold is 0 and there is no decomposition. Setting it to threshold > 0 such - as `--xla_gpu_collective_permute_decomposer_threshold=1024` can enable this - feature. * **--xla_gpu_all_gather_combine_threshold_bytes** **--xla_gpu_reduce_scatter_combine_threshold_bytes** **--xla_gpu_all_reduce_combine_threshold_bytes** @@ -102,6 +254,228 @@ training on Nvidia GPUs](https://github.com/NVIDIA/JAX-Toolbox/blob/main/rosetta combine at least a Transformer Layer's weight `AllGather`/`ReduceScatter`. By default, the `combine_threshold_bytes` is set to 256. +### Pipeline Parallelism on GPU + +XLA implements SPMD-based pipeline parallelism optimizations. This is a scaling +technique where the forward and backward pass are split into multiple pipeline +stages. Each device (or device group) processes the result of the previous +pipeline stage (or the pipeline input) and sends its partial result to the next +stage until the end of the pipeline is reached. This optimization works best +when the latency of the computation is larger than communication. At compile +time, the operations will be rearranged to overlap communication with +computation. + +For an optimized schedule, we recommend these XLA flags: +``` +--xla_gpu_enable_latency_hiding_scheduler=true +--xla_gpu_enable_command_buffer='' +--xla_disable_hlo_passes=collective-permute-motion +--xla_gpu_experimental_pipeline_parallelism_opt_level=PIPELINE_PARALLELISM_OPT_LEVEL_ENABLE +``` + +The following JAX example demonstrates a pattern where communication operations +are scheduled to overlap with computations. In this example we will illustrate +how to set up an optimized pipeline parallelism scheduling using 4 GPUs that +form a communication ring (device 0 -> device 1 -> device 2 -> device 3 -> +device 0). We refer to the pattern `0 -> 1 -> 2 -> 3` as the forward edge, and +`3 -> 0` as the back edge. + +``` +# Imports and setup +import functools +import jax +from jax import sharding +from jax.experimental import mesh_utils +import jax.numpy as jnp +import jax.random + +NUM_DEVICES = 4 +NUM_MICROBATCHES = 5 +NUM_CIRC_REPEATS = 2 +CONTRACTING_DIM_SIZE = 4096 +NON_CONTRACTING_DIM_SIZE = 8192 +COMPUTE_INTENSITY = 32 + +# Creates a collective permute for the "forward edge". +# 0->1, 1->2, ... (N-2)->(N-1) +def shift_right(arr): + padding = [[1, 0]] + [[0, 0]] * (arr.ndim - 1) + # Use lax.slice to guarantee the gradient is a pad. + return jax.lax.slice(jnp.pad(arr, padding), [0] * arr.ndim, arr.shape) + + +# Creates a collective permute for the "back edge". 
+# (N-1)->0 +def cycle_back(arr): + padding = [[0, NUM_DEVICES - 1]] + [[0, 0]] * (arr.ndim - 1) + return jax.lax.slice( + jnp.pad(arr, padding), + [NUM_DEVICES - 1] + [0] * (arr.ndim - 1), + (NUM_DEVICES - 1 + arr.shape[0],) + arr.shape[1:], + ) + + +def select_on_first_device(then_value, else_value): + assert then_value.shape == else_value.shape + is_first_device = jax.lax.broadcasted_iota("int32", then_value.shape, 0) == 0 + return jnp.where(is_first_device, then_value, else_value) + + +def select_on_last_device(then_value, else_value): + assert then_value.shape == else_value.shape + is_last_device = ( + jax.lax.broadcasted_iota("int32", then_value.shape, 0) == NUM_DEVICES - 1 + ) + return jnp.where(is_last_device, then_value, else_value) + + +def select_on_first_cycle(i, then_value, else_value): + assert then_value.shape == else_value.shape + is_first_cycle = i < NUM_MICROBATCHES + return jnp.where(is_first_cycle, then_value, else_value) + + +def while_body(carry, i): + """Body of the pipeline while loop.""" + weights, input_buffer, output_buffer, fwd_edge_data, bwd_edge_data = carry + + # Read input data from input buffer. + input_data = jax.lax.dynamic_slice( + input_buffer, + (0, (i + 0) % NUM_MICROBATCHES, 0, 0), + (NUM_DEVICES, 1, CONTRACTING_DIM_SIZE, NON_CONTRACTING_DIM_SIZE), + ) + + # Collective permute on the "forward edge" shifts data to the next stage. + fwd_edge_data = shift_right(fwd_edge_data) + + # Select compute argument based on device and pipeline cycle. + compute_argument = select_on_first_device( + select_on_first_cycle(i, input_data, bwd_edge_data), + fwd_edge_data, + ).reshape((NUM_DEVICES, CONTRACTING_DIM_SIZE, NON_CONTRACTING_DIM_SIZE)) + + # A few matmuls to simulate compute. + tmp = compute_argument + for _ in range(COMPUTE_INTENSITY): + tmp = jax.lax.dot_general(weights, tmp, (((2,), (1,)), ((0,), (0,)))) + compute_result = tmp.reshape( + (NUM_DEVICES, 1, CONTRACTING_DIM_SIZE, NON_CONTRACTING_DIM_SIZE) + ) + + # Read data from buffer to pass it to the first device of the pipeline on the + # "back edge". + bwd_edge_data = jax.lax.dynamic_slice( + output_buffer, + (0, (1 + i) % NUM_MICROBATCHES, 0, 0), + (NUM_DEVICES, 1, CONTRACTING_DIM_SIZE, NON_CONTRACTING_DIM_SIZE), + ) + + # Collective permute on the "back edge" passes data to the first device. + bwd_edge_data = cycle_back(bwd_edge_data) + + # Update output buffer. We do this after reading from it to avoid the data + # dependency. + output_buffer = jax.lax.dynamic_update_slice( + output_buffer, + compute_result, + (0, (2 + i) % NUM_MICROBATCHES, 0, 0), + ) + + fwd_edge_data = compute_result + carry = ( + weights, + input_buffer, + output_buffer, + fwd_edge_data, + bwd_edge_data, + ) + return carry, i + + +@functools.partial(jax.jit, static_argnames=["mesh"]) +def entry_computation(weights, input_buffer, mesh): + + # Init output buffer. + output_buffer = jnp.zeros_like(input_buffer) + + # Init dummy data for forward and backward edge passed through the while loop. + dummy_data = jnp.zeros( + shape=(NUM_DEVICES, 1, CONTRACTING_DIM_SIZE, NON_CONTRACTING_DIM_SIZE) + ).astype(jnp.float32) + dummy_data = jax.device_put( + dummy_data, + sharding.NamedSharding( + mesh, sharding.PartitionSpec("the_one_and_only_axis") + ), + ) + + # Start pipeline. 
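+  # The scan below runs NUM_CIRC_REPEATS * NUM_MICROBATCHES productive steps plus
+  # NUM_DEVICES - 1 extra steps, which roughly account for filling and draining
+  # the circular pipeline before and after all microbatches are in flight.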
+ carry = weights, input_buffer, output_buffer, dummy_data, dummy_data + num_iterations = NUM_CIRC_REPEATS * NUM_MICROBATCHES + NUM_DEVICES - 1 + carry, _ = jax.lax.scan(while_body, carry, xs=jnp.arange(num_iterations)) + _, _, output_buffer, _, _ = carry + + return output_buffer + + +def main(_): + + # Expect constant number of devices. + assert NUM_DEVICES == jax.local_device_count() + + # Create mesh. + mesh = sharding.Mesh( + mesh_utils.create_device_mesh([NUM_DEVICES]), + axis_names=["the_one_and_only_axis"], + ) + + # Init weights. + weights = 1.0 / CONTRACTING_DIM_SIZE + weights = jax.lax.broadcast_in_dim( + weights, + shape=(NUM_DEVICES, CONTRACTING_DIM_SIZE, CONTRACTING_DIM_SIZE), + broadcast_dimensions=(), + ) + weights = jax.device_put( + weights, + sharding.NamedSharding( + mesh, sharding.PartitionSpec("the_one_and_only_axis") + ), + ) + + # Init random input and replicate it across all devices. + random_key = jax.random.key(0) + input_buffer = jax.random.uniform( + random_key, + shape=( + NUM_MICROBATCHES, + CONTRACTING_DIM_SIZE, + NON_CONTRACTING_DIM_SIZE, + ), + ) + input_buffer = jax.lax.broadcast_in_dim( + input_buffer, + shape=( + NUM_DEVICES, + NUM_MICROBATCHES, + CONTRACTING_DIM_SIZE, + NON_CONTRACTING_DIM_SIZE, + ), + broadcast_dimensions=[1, 2, 3], + ) + input_buffer = jax.device_put( + input_buffer, + sharding.NamedSharding( + mesh, sharding.PartitionSpec("the_one_and_only_axis") + ), + ) + + # Run computation. + output_buffer = entry_computation(weights, input_buffer, mesh) + print(f"output_buffer = \n{output_buffer}") +``` + ## NCCL flags These Nvidia NCCL flag values may be useful for single-host multi-device diff --git a/docs/gradient-checkpointing.md b/docs/gradient-checkpointing.md index 0938a5da944f..e4e842df49f0 100644 --- a/docs/gradient-checkpointing.md +++ b/docs/gradient-checkpointing.md @@ -341,7 +341,7 @@ def predict(params, x): return x ``` -By itself, {func}`jax.ad_checkpoint import.checkpoint_name` is just an identity function. But because some policy functions know to look for them, you can use the names to control whether certain values output by {func}`jax.ad_checkpoint import.checkpoint_name` are considered saveable: +By itself, {func}`jax.ad_checkpoint.checkpoint_name` is just an identity function. But because some policy functions know to look for them, you can use the names to control whether certain values output by {func}`jax.ad_checkpoint.checkpoint_name` are considered saveable: ```{code-cell} print_saved_residuals(loss, params, x, y) diff --git a/docs/index.rst b/docs/index.rst index ba8ebcbdd128..93fc6c284685 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -63,8 +63,9 @@ JAX: High performance array computing :link-type: ref :class-card: user-guides -If you're looking to train neural networks, use Flax_ and start with its tutorials. -For an end-to-end transformer library built on JAX, see MaxText_. +If you're looking to use JAX to train neural networks, start with the +`JAX AI Stack Tutorials`_, and then check out the `JAX AI Stack Examples`_ +to see how JAX models can be implemented using the Flax_ framework. Ecosystem --------- @@ -107,7 +108,7 @@ numerical computing tools; the following is just a small sample of what is out t .. grid-item:: :material-regular:`bar_chart;2em` **Probabilistic modeling** - - `TensorFlow Probabilty`_ + - `TensorFlow Probability`_ - Distrax_ .. 
grid-item:: :material-outlined:`animation;2em` **Physics & simulation** @@ -121,6 +122,7 @@ numerical computing tools; the following is just a small sample of what is out t - AXLearn_ - Levanter_ - EasyLM_ + - Marin_ Many more JAX-based libraries have been developed; the community-run `Awesome JAX`_ page @@ -165,6 +167,11 @@ maintains an up-to-date list. changelog glossary +.. toctree:: + :hidden: + :maxdepth: 2 + + config_options .. _Awesome JAX: https://github.com/n2cholas/awesome-jax .. _AXLearn: https://github.com/apple/axlearn @@ -179,8 +186,11 @@ maintains an up-to-date list. .. _Grain: https://github.com/google/grain .. _Hugging Face Datasets: https://huggingface.co/docs/datasets/ .. _JAX MD: https://jax-md.readthedocs.io/ +.. _JAX AI Stack Tutorials: https://docs.jaxstack.ai/en/latest/tutorials.html +.. _JAX AI Stack Examples: https://docs.jaxstack.ai/en/latest/examples.html .. _Keras: https://keras.io/ .. _Levanter: https://github.com/stanford-crfm/levanter +.. _Marin: https://github.com/marin-community/marin .. _Lineax: https://github.com/patrick-kidger/lineax .. _MaxText: https://github.com/google/maxtext/ .. _Numpyro: https://num.pyro.ai/en/latest/index.html @@ -189,4 +199,4 @@ maintains an up-to-date list. .. _Orbax: https://orbax.readthedocs.io/ .. _PyMC: https://www.pymc.io/ .. _TensorFlow Datasets: https://www.tensorflow.org/datasets -.. _TensorFlow Probabilty: https://www.tensorflow.org/probability +.. _TensorFlow Probability: https://www.tensorflow.org/probability diff --git a/docs/installation.md b/docs/installation.md index ee675dd1e586..4019f6461473 100644 --- a/docs/installation.md +++ b/docs/installation.md @@ -28,14 +28,14 @@ different builds for different operating systems and accelerators. The table below shows all supported platforms and installation options. Check if your setup is supported; and if it says _"yes"_ or _"experimental"_, then click on the corresponding link to learn how to install JAX in greater detail. 
-| | Linux, x86_64 | Linux, aarch64 | Mac, x86_64 | Mac, aarch64 | Windows, x86_64 | Windows WSL2, x86_64 | -|------------------|---------------------------------------|---------------------------------|---------------------------------------|---------------------------------------|--------------------------|------------------------------------------| -| CPU | {ref}`yes ` | {ref}`yes ` | {ref}`jax≤0.4.38 only ` | {ref}`yes ` | {ref}`yes ` | {ref}`yes ` | -| NVIDIA GPU | {ref}`yes ` | {ref}`yes ` | no | n/a | no | {ref}`experimental ` | -| Google Cloud TPU | {ref}`yes ` | n/a | n/a | n/a | n/a | n/a | -| AMD GPU | {ref}`experimental ` | no | {ref}`experimental ` | n/a | no | no | -| Apple GPU | n/a | no | n/a | {ref}`experimental ` | n/a | n/a | -| Intel GPU | {ref}`experimental `| n/a | n/a | n/a | no | no | +| | Linux, x86_64 | Linux, aarch64 | Mac, aarch64 | Windows, x86_64 | Windows WSL2, x86_64 | +|------------------|---------------------------------------|---------------------------------|---------------------------------------|--------------------------|------------------------------------------| +| CPU | {ref}`yes ` | {ref}`yes ` | {ref}`yes ` | {ref}`yes ` | {ref}`yes ` | +| NVIDIA GPU | {ref}`yes ` | {ref}`yes ` | n/a | no | {ref}`experimental ` | +| Google Cloud TPU | {ref}`yes ` | n/a | n/a | n/a | n/a | +| AMD GPU | {ref}`yes ` | no | n/a | no | no | +| Apple GPU | n/a | no | {ref}`experimental ` | n/a | n/a | +| Intel GPU | {ref}`experimental `| n/a | n/a | no | no | (install-cpu)= @@ -48,7 +48,6 @@ operating systems and architectures: - Linux, x86_64 - Linux, aarch64 -- macOS, Intel - macOS, Apple ARM-based - Windows, x86_64 (*experimental*) @@ -158,7 +157,7 @@ pip install --upgrade pip # Installs the wheel compatible with NVIDIA CUDA 12 and cuDNN 9.0 or newer. # Note: wheels only available on linux. -pip install --upgrade "jax[cuda12_local]" +pip install --upgrade "jax[cuda12-local]" ``` **These `pip` installations do not work with Windows, and may fail silently; refer to the table @@ -226,10 +225,10 @@ refer to (install-amd-gpu)= ## AMD GPU (Linux) -JAX has experimental ROCm support. There are two ways to install JAX: +AMD GPU support is provided by a ROCm JAX plugin supported by AMD. -* Use [AMD's Docker container](https://hub.docker.com/r/rocm/jax-community/tags); or -* Build from source. Refer to the section [Additional notes for building a ROCm jaxlib for AMD GPUs](https://jax.readthedocs.io/en/latest/developer.html#additional-notes-for-building-a-rocm-jaxlib-for-amd-gpus). +There are several ways to use JAX on AMDGPU devices. +Please see [AMD's instructions](https://github.com/jax-ml/jax/blob/main/build/rocm/README.md) for details. (install-intel-gpu)= ## Intel GPU @@ -281,22 +280,34 @@ Unlike the instructions for installing a JAX release, here we name all of JAX's packages explicitly on the command line, so `pip` will upgrade them if a newer version is available. +JAX publishes nightlies, release candidates(RCs), and releases to several non-pypi [PEP 503](https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/) indexes. + +All JAX packages can be reached from the index `https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/` +as well as PyPI mirrored packages. This additional mirroring enables nightly +installation to use --index (-i) as the install method with pip. + +**Note:** The unified index could return an RC or release as the newest version +even with `--pre` immediately after a release before the newest nightly is +rebuilt. 
If automation or testing must be done against nightlies or you cannot +use our full index, use the extra index `https://us-python.pkg.dev/ml-oss-artifacts-published/jax-public-nightly-artifacts-registry/simple/` +which only contains nightly artifacts. + - CPU only: ```bash -pip install -U --pre jax jaxlib -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html +pip install -U --pre jax jaxlib -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ ``` - Google Cloud TPU: ```bash -pip install -U --pre jax jaxlib libtpu requests -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/libtpu_releases.html +pip install -U --pre jax jaxlib libtpu requests -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ -f https://storage.googleapis.com/jax-releases/libtpu_releases.html ``` - NVIDIA GPU (CUDA 12): ```bash -pip install -U --pre jax jaxlib "jax-cuda12-plugin[with_cuda]" jax-cuda12-pjrt -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html +pip install -U --pre jax jaxlib "jax-cuda12-plugin[with-cuda]" jax-cuda12-pjrt -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ ``` - NVIDIA GPU (CUDA 12) legacy: @@ -322,10 +333,10 @@ still be installed directly via the URLs here. For example: ```bash # Install jaxlib on CPU via the wheel archive -pip install "jax[cpu]==0.3.25" -f https://storage.googleapis.com/jax-releases/jax_releases.html +pip install "jax[cpu]==0.3.25" -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ # Install the jaxlib 0.3.25 CPU wheel directly -pip install jaxlib==0.3.25 -f https://storage.googleapis.com/jax-releases/jax_releases.html +pip install jaxlib==0.3.25 -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ ``` For specific older GPU wheels, be sure to use the `jax_cuda_releases.html` URL; for example ```bash diff --git a/docs/jax-primitives.md b/docs/jax-primitives.md index abdc8be6d0a8..fab5334b4010 100644 --- a/docs/jax-primitives.md +++ b/docs/jax-primitives.md @@ -21,7 +21,7 @@ kernelspec: A JAX primitive is the basic computational unit of a JAX program. This document explains the interface that a JAX primitive must support to allow JAX to perform all its transformations (this is not a how-to guide). -For example, the multiply-add operation can be implemented in terms of the low-level `jax.lax.*` primitives (which are like XLA operator wrappers) or `jax.core.Primitive("multiply_add")`, as demonstrated further below. +For example, the multiply-add operation can be implemented in terms of the low-level `jax.lax.*` primitives (which are like XLA operator wrappers) or `jax.extend.core.Primitive("multiply_add")`, as demonstrated further below. And JAX is able to take sequences of such primitive operations, and transform them via its composable transformations of Python functions, such as {func}`jax.jit`, {func}`jax.grad` and {func}`jax.vmap`. JAX implements these transforms in a *JAX-traceable* way. 
This means that when a Python function is executed, the only operations it applies to the data are either: @@ -100,7 +100,7 @@ def trace(name): vtype = str(type(v)) if "jax._src.xla_bridge._JaxComputationBuilder" in vtype: return "" - elif "jaxlib.xla_extension.XlaOp" in vtype: + elif "jaxlib._jax_.XlaOp" in vtype: return "".format(id(v)) elif ("partial_eval.JaxprTracer" in vtype or "batching.BatchTracer" in vtype or @@ -171,7 +171,7 @@ The JAX traceability property is satisfied as long as the function is written in The right way to add support for multiply-add is in terms of existing JAX primitives, as shown above. However, to demonstrate how JAX primitives work, pretend that you want to add a new primitive to JAX for the multiply-add functionality. ```{code-cell} -from jax import core +from jax.extend import core multiply_add_p = core.Primitive("multiply_add") # Create the primitive @@ -300,7 +300,7 @@ def multiply_add_lowering(ctx, xc, yc, zc): return [hlo.AddOp(hlo.MulOp(xc, yc), zc).result] # Now, register the lowering rule with JAX. -# For GPU, refer to the https://jax.readthedocs.io/en/latest/Custom_Operation_for_GPUs.html +# For GPU, refer to the https://docs.jax.dev/en/latest/Custom_Operation_for_GPUs.html from jax.interpreters import mlir mlir.register_lowering(multiply_add_p, multiply_add_lowering, platform='cpu') diff --git a/docs/jax.dlpack.rst b/docs/jax.dlpack.rst index 4a679052775e..eba3ecf62954 100644 --- a/docs/jax.dlpack.rst +++ b/docs/jax.dlpack.rst @@ -9,4 +9,3 @@ :toctree: _autosummary from_dlpack - to_dlpack \ No newline at end of file diff --git a/docs/jax.experimental.pallas.mosaic_gpu.rst b/docs/jax.experimental.pallas.mosaic_gpu.rst index 2d3452609c75..4191dde74df7 100644 --- a/docs/jax.experimental.pallas.mosaic_gpu.rst +++ b/docs/jax.experimental.pallas.mosaic_gpu.rst @@ -10,9 +10,9 @@ Classes :toctree: _autosummary Barrier - GPUBlockSpec - GPUCompilerParams - GPUMemorySpace + BlockSpec + CompilerParams + MemorySpace Layout SwizzleTransform TilingTransform diff --git a/docs/jax.experimental.pallas.triton.rst b/docs/jax.experimental.pallas.triton.rst index 76b0896ccf17..023a33bb0909 100644 --- a/docs/jax.experimental.pallas.triton.rst +++ b/docs/jax.experimental.pallas.triton.rst @@ -9,7 +9,7 @@ Classes .. 
autosummary:: :toctree: _autosummary - TritonCompilerParams + CompilerParams Functions --------- @@ -19,4 +19,4 @@ Functions approx_tanh debug_barrier - elementwise_inline_asm \ No newline at end of file + elementwise_inline_asm diff --git a/docs/jax.lax.rst b/docs/jax.lax.rst index 9db79f591a4e..43937130e5f4 100644 --- a/docs/jax.lax.rst +++ b/docs/jax.lax.rst @@ -222,6 +222,7 @@ Parallel operators pshuffle pswapaxes axis_index + axis_size Sharding-related operators -------------------------- diff --git a/docs/jax.nn.rst b/docs/jax.nn.rst index adb13f89903d..339f07f4cdcc 100644 --- a/docs/jax.nn.rst +++ b/docs/jax.nn.rst @@ -40,6 +40,7 @@ Activation functions glu squareplus mish + identity Other functions --------------- @@ -53,3 +54,6 @@ Other functions standardize one_hot dot_product_attention + scaled_matmul + get_scaled_dot_general_config + scaled_dot_general diff --git a/docs/jax.rst b/docs/jax.rst index 98cd464cda15..de901caf9414 100644 --- a/docs/jax.rst +++ b/docs/jax.rst @@ -57,6 +57,7 @@ Configuration enable_custom_prng enable_custom_vjp_by_custom_transpose log_compiles + no_tracing numpy_rank_promotion transfer_guard @@ -105,6 +106,31 @@ Automatic differentiation closure_convert checkpoint +Vectorization (:code:`vmap`) +---------------------------- + +.. autosummary:: + :toctree: _autosummary + + vmap + numpy.vectorize + +Parallelization (:code:`pmap`) +------------------------------ + +.. autosummary:: + :toctree: _autosummary + + shard_map + pmap + devices + local_devices + process_index + device_count + local_device_count + process_count + process_indices + Customization ------------- @@ -216,30 +242,6 @@ Array properties and methods Array.T Array.mT -Vectorization (:code:`vmap`) ----------------------------- - -.. autosummary:: - :toctree: _autosummary - - vmap - numpy.vectorize - -Parallelization (:code:`pmap`) ------------------------------- - -.. autosummary:: - :toctree: _autosummary - - pmap - devices - local_devices - process_index - device_count - local_device_count - process_count - process_indices - Callbacks --------- diff --git a/docs/jax.scipy.rst b/docs/jax.scipy.rst index dcbb673997ad..3c436697e1be 100644 --- a/docs/jax.scipy.rst +++ b/docs/jax.scipy.rst @@ -69,6 +69,7 @@ jax.scipy.linalg lu lu_factor lu_solve + pascal polar qr rsf2csf diff --git a/docs/jax.sharding.rst b/docs/jax.sharding.rst index 954f62b8a52d..12760d62ddb3 100644 --- a/docs/jax.sharding.rst +++ b/docs/jax.sharding.rst @@ -16,15 +16,9 @@ Classes .. autoclass:: NamedSharding :members: :show-inheritance: -.. autoclass:: PositionalSharding - :members: - :show-inheritance: .. autoclass:: PmapSharding :members: :show-inheritance: -.. autoclass:: GSPMDSharding - :members: - :show-inheritance: .. autoclass:: PartitionSpec :members: .. autoclass:: Mesh diff --git a/docs/jax.tree.rst b/docs/jax.tree.rst index e65c77c757c1..1a0ddaec86d0 100644 --- a/docs/jax.tree.rst +++ b/docs/jax.tree.rst @@ -12,6 +12,7 @@ List of Functions :toctree: _autosummary all + broadcast flatten flatten_with_path leaves diff --git a/docs/jax.tree_util.rst b/docs/jax.tree_util.rst index 73fd1f376e9f..a17a947af320 100644 --- a/docs/jax.tree_util.rst +++ b/docs/jax.tree_util.rst @@ -13,7 +13,6 @@ List of Functions Partial all_leaves - build_tree register_dataclass register_pytree_node register_pytree_node_class @@ -38,6 +37,7 @@ These APIs are now accessed via :mod:`jax.tree`. 
:toctree: _autosummary tree_all + tree_broadcast tree_flatten tree_leaves tree_map diff --git a/docs/jax_array_migration.md b/docs/jax_array_migration.md index 95d4a632a295..3cc1629b2068 100644 --- a/docs/jax_array_migration.md +++ b/docs/jax_array_migration.md @@ -1,3 +1,6 @@ +--- +orphan: true +--- (jax-array-migration)= # jax.Array migration @@ -24,7 +27,7 @@ the unified jax.Array After the migration is complete `jax.Array` will be the only type of array in JAX. -This doc explains how to migrate existing codebases to `jax.Array`. For more information on using `jax.Array` and JAX parallelism APIs, see the [Distributed arrays and automatic parallelization](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html) tutorial. +This doc explains how to migrate existing codebases to `jax.Array`. For more information on using `jax.Array` and JAX parallelism APIs, see the [Distributed arrays and automatic parallelization](https://docs.jax.dev/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html) tutorial. ### How to enable jax.Array? diff --git a/docs/jep/10657-sequencing-effects.md b/docs/jep/10657-sequencing-effects.md index 5f7eb0da4c04..ac3024519101 100644 --- a/docs/jep/10657-sequencing-effects.md +++ b/docs/jep/10657-sequencing-effects.md @@ -47,7 +47,7 @@ g() In many cases, JAX will execute `f` and `g` *in parallel*, dispatching the computations onto different threads -- `g` might actually be executed before `f`. Parallel execution is a nice performance optimization, especially if copying -to and from a device is expensive (see the [asynchronous dispatch note](https://jax.readthedocs.io/en/latest/async_dispatch.html) for more details). +to and from a device is expensive (see the [asynchronous dispatch note](https://docs.jax.dev/en/latest/async_dispatch.html) for more details). In practice, however, we often don't need to think about asynchronous dispatch because we're writing pure functions and only care about the inputs and outputs of functions -- we'll naturally block on future diff --git a/docs/jep/12049-type-annotations.md b/docs/jep/12049-type-annotations.md index 7a20958c5cab..bf6123b2bc7f 100644 --- a/docs/jep/12049-type-annotations.md +++ b/docs/jep/12049-type-annotations.md @@ -35,7 +35,7 @@ def slice(operand: Array, start_indices: Sequence[int], For the purposes of static type checking, this use of `Array = Any` for array type annotations puts no constraint on the argument values (`Any` is equivalent to no annotation at all), but it does serve as a form of useful in-code documentation for the developer. -For the sake of generated documentation, the name of the alias gets lost (the [HTML docs](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.slice.html) for `jax.lax.slice` report operand as type `Any`), so the documentation benefit does not go beyond the source code (though we could enable some `sphinx-autodoc` options to improve this: See [autodoc_type_aliases](https://www.sphinx-doc.org/en/master/usage/extensions/autodoc.html#confval-autodoc_type_aliases)). 
+For the sake of generated documentation, the name of the alias gets lost (the [HTML docs](https://docs.jax.dev/en/latest/_autosummary/jax.lax.slice.html) for `jax.lax.slice` report operand as type `Any`), so the documentation benefit does not go beyond the source code (though we could enable some `sphinx-autodoc` options to improve this: See [autodoc_type_aliases](https://www.sphinx-doc.org/en/master/usage/extensions/autodoc.html#confval-autodoc_type_aliases)). A benefit of this level of type annotation is that it is never wrong to annotate a value with `Any`, so it will provide a concrete benefit to developers and users in the form of documentation, without added complexity of satisfying the stricter needs of any particular static type checker. @@ -122,7 +122,7 @@ All told, the array-type-granularity challenge is less of an issue than the othe ### Challenge 5: imprecise APIs inherited from NumPy A large part of JAX’s user-facing API is inherited from NumPy within the {mod}`jax.numpy` submodule. -NumPy’s API was developed years before static type checking became part of the Python language, and follows Python’s historic recommendations to use a [duck-typing](https://docs.python.org/3/glossary.html#term-duck-typing)/[EAFP](https://docs.python.org/3/glossary.html#term-eafp) coding style, in which strict type-checking at runtime is discouraged. As a concrete example of this, consider the {func}`numpy.tile` function, which is defined like this: +NumPy’s API was developed years before static type checking became part of the Python language, and follows Python’s historic recommendations to use a [duck-typing](https://docs.python.org/3/glossary.html#term-duck-typing)/[EAFP](https://docs.python.org/3/glossary.html#term-EAFP) coding style, in which strict type-checking at runtime is discouraged. As a concrete example of this, consider the {func}`numpy.tile` function, which is defined like this: ```python def tile(A, reps): diff --git a/docs/jep/14273-shard-map.md b/docs/jep/14273-shard-map.md index 63742bc852c6..fa6681551d17 100644 --- a/docs/jep/14273-shard-map.md +++ b/docs/jep/14273-shard-map.md @@ -4,7 +4,7 @@ *January 2023* **This was the design doc proposing `shard_map`. You may instead want -[the up-to-date user docs](https://jax.readthedocs.io/en/latest/notebooks/shard_map.html).** +[the up-to-date user docs](https://docs.jax.dev/en/latest/notebooks/shard_map.html).** ## Motivation @@ -18,7 +18,7 @@ We need great APIs for both, and rather than being mutually exclusive alternatives, they need to compose with each other. With `pjit` (now just `jit`) we have [a next-gen -API](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html) +API](https://docs.jax.dev/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html) for the first school. But we haven't quite leveled-up the second school. `pmap` follows the second school, but over time we found it has [fatal flaws](#why-dont-pmap-or-xmap-already-solve-this). `xmap` solved those flaws, diff --git a/docs/jep/15856-jex.md b/docs/jep/15856-jex.md index a5625abf8930..a821405c399e 100644 --- a/docs/jep/15856-jex.md +++ b/docs/jep/15856-jex.md @@ -14,13 +14,13 @@ import jax.extend as jex Several projects depend on JAX's codebase internals, often to use its core machinery (e.g. 
to write a -[transformation over its IR](https://jax.readthedocs.io/en/latest/notebooks/Writing_custom_interpreters_in_Jax.html)) +[transformation over its IR](https://docs.jax.dev/en/latest/notebooks/Writing_custom_interpreters_in_Jax.html)) or to extend it (e.g. to [define new primitives](https://github.com/dfm/extending-jax)). Two challenges for these dependencies are (a) that our internals aren't all solidly designed for external use, and (b) that circumventing JAX's public API is -[unsupported](https://jax.readthedocs.io/en/latest/api_compatibility.html). +[unsupported](https://docs.jax.dev/en/latest/api_compatibility.html). In other words, our internals are often used like a library, but are neither structured nor updated like one. @@ -50,12 +50,12 @@ removed altogether. To keep development overhead low, `jax.extend` would not follow the public -[API compatibility](https://jax.readthedocs.io/en/latest/api_compatibility.html) +[API compatibility](https://docs.jax.dev/en/latest/api_compatibility.html) policy. It would promise no deprecation windows nor backwards compatibility between releases. Every release may break existing callers without simple recourse (e.g. without a flag reintroducing prior behavior). We would rely on the -[changelog](https://jax.readthedocs.io/en/latest/changelog.html) +[changelog](https://docs.jax.dev/en/latest/changelog.html) to call out such changes. Callers of `jax.extend` that need to upgrade their code regularly @@ -108,7 +108,7 @@ to process the Jaxpr IR (the output of At initialization, this module will contain many more symbols than what's needed to define primitives and rules, including various names used in setting up -["final-style transformations"](https://jax.readthedocs.io/en/latest/autodidax.html#on-the-fly-final-style-and-staged-initial-style-processing), +["final-style transformations"](https://docs.jax.dev/en/latest/autodidax.html#on-the-fly-final-style-and-staged-initial-style-processing), such as the current `jax._src.core.Trace` and `Tracer` classes. We can revisit whether `jex.core` should also support final-style extensions alongside initial style approaches, and whether it can do so by a more @@ -137,7 +137,7 @@ tracer types from `jex`. This module plus `jex.core` ought to suffice for replicating today's custom primitive tutorials (e.g. -[ours](https://jax.readthedocs.io/en/latest/notebooks/How_JAX_primitives_work.html) +[ours](https://docs.jax.dev/en/latest/notebooks/How_JAX_primitives_work.html) and [dfm's](https://github.com/dfm/extending-jax)). For instance, defining a primitive and its behavior under `jax.jit` @@ -184,6 +184,6 @@ arrays. We have only one item in mind for now. The XLA compiler's array sharding format is more expressive than [those provided by -JAX](https://jax.readthedocs.io/en/latest/jax.sharding.html). We could +JAX](https://docs.jax.dev/en/latest/jax.sharding.html). We could provide this as `jex.sharding.XlaOpShardingProto`, corresponding to today's `jax._src.lib.xla_client.OpSharding` internally. diff --git a/docs/jep/17111-shmap-transpose.md b/docs/jep/17111-shmap-transpose.md index 2fdf5f822835..00d8a3f383fd 100644 --- a/docs/jep/17111-shmap-transpose.md +++ b/docs/jep/17111-shmap-transpose.md @@ -497,7 +497,7 @@ of every function instance along which the outputs are mapped, whereas for mesh axes over which the output is unmapped only one copy of the value is used. 
See [the `shmap` -JEP](https://jax.readthedocs.io/en/latest/jep/14273-shard-map.html) for examples +JEP](https://docs.jax.dev/en/latest/jep/14273-shard-map.html) for examples of unmapped inputs and outputs. For comparison, in `vmap` unmapped inputs/outputs are indicated by using `in_axes` / `out_axes` of `None` (rather than an `int`). diff --git a/docs/jep/2026-custom-derivatives.md b/docs/jep/2026-custom-derivatives.md index ce149fa6fb35..b09926425667 100644 --- a/docs/jep/2026-custom-derivatives.md +++ b/docs/jep/2026-custom-derivatives.md @@ -2,7 +2,7 @@ This is a design document, explaining some of the thinking behind the design and implementation of `jax.custom_jvp` and `jax.custom_vjp`. For user-oriented -documentation, see [the tutorial notebook](https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html). +documentation, see [the tutorial notebook](https://docs.jax.dev/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html). There are two ways to define differentiation rules in JAX: 1. using `jax.custom_jvp` and `jax.custom_vjp` to define custom differentiation diff --git a/docs/jep/28661-jax-array-protocol.md b/docs/jep/28661-jax-array-protocol.md new file mode 100644 index 000000000000..e05d69d2822d --- /dev/null +++ b/docs/jep/28661-jax-array-protocol.md @@ -0,0 +1,214 @@ +# JEP 28661: Supporting the `__jax_array__` protocol + +[@jakevdp](http://github.com/jakevdp), *May 2025* + +An occasional user request is for the ability to define custom array-like objects that +work with jax APIs. JAX currently has a partial implementation of a mechanism that does +this via a `__jax_array__` method defined on the custom object. This was never intended +to be a load-bearing public API (see the discussion at {jax-issue}`#4725`), but has +become essential to packages like Keras and flax, which explicitly document the ability +to use their custom array objects with jax functions. This JEP proposes a design for +full, documented support of the `__jax_array__` protocol. + +## Levels of array extensibility +Requests for extensibility of JAX arrays come in a few flavors: + +### Level 1 Extensibility: polymorphic inputs +What I’ll call "Level 1" extensibility is the desire that JAX APIs accept polymorphic inputs. +That is, a user desires behavior like this: + +```python +class CustomArray: + data: numpy.ndarray + ... + +x = CustomArray(np.arange(5)) +result = jnp.sin(x) # Converts `x` to JAX array and returns a JAX array +``` + +Under this extensibility model, JAX functions would accept CustomArray objects as inputs, +implicitly converting them to `jax.Array` objects for the sake of computation. +This is similar to the functionality offered by NumPy via the `__array__` method, and in +JAX (in many but not all cases) via the `__jax_array__` method. + +This is the mode of extensibility that has been requested by the maintainers of `flax.nnx` +and others. The current implementation is also used by JAX internally for the case of +symbolic dimensions. + +### Level 2 extensibility: polymorphic outputs +What I’ll call "Level 2" extensibility is the desire that JAX APIs should not only accept +polymorphic inputs, but also wrap outputs to match the class of the input. +That is, a user desires behavior like this: + +```python +class CustomArray: + data: numpy.ndarray + ... 
+ +x = CustomArray(np.arange(5)) +result = jnp.sin(x) # returns a new CustomArray +``` + +Under this extensibility model, JAX functions would not only accept custom objects +as inputs, but have some protocol to determine how to correctly re-wrap outputs with +the same class. In NumPy, this sort of functionality is offered in varying degrees by +the special `__array_ufunc__`, `__array_wrap__`, and `__array_function__` protocols, +which allow user-defined objects to customize how NumPy API functions operate on +arbitrary inputs and map input types to outputs. +JAX does not currently have any equivalent to these interfaces in NumPy. + +This is the mode of extensibility that has been requested by the maintainers of `keras`, +among others. + +### Level 3 extensibility: subclassing `Array` + +What I’ll call "Level 3" extensibility is the desire that the JAX array object itself +could be subclassable. NumPy provides some APIs that allow this +(see [Subclassing ndarray](https://numpy.org/devdocs/user/basics.subclassing.html)) but +this sort of approach would take some extra thought in JAX due to the need for +representing array objects abstractly via tracing. + +This mode of extensibility has occasionally been requested by users who want to add +special metadata to JAX arrays, such as units of measurement. + +## Synopsis + +For the sake of this proposal, we will stick with the simplest, level 1 extensibility +model. The proposed interface is the one currently non-uniformly supported by a number +of JAX APIs, the `__jax_array__` method. Its usage looks something like this: + +```python +import jax +import jax.numpy as jnp +import numpy as np + +class CustomArray: + data: np.ndarray + + def __init__(self, data: np.ndarray): + self.data = data + + def __jax_array__(self) -> jax.Array: + return jnp.asarray(self.data) + +arr = CustomArray(np.arange(5)) +result = jnp.multiply(arr, 2) +print(repr(result)) +# Array([0, 2, 4, 6, 8], dtype=int32) +``` + +We may revisit other extensibility levels in the future. + +## Design challenges + +JAX presents some interesting design challenges related to this kind of extensibility, +which have not been fully explored previously. We’ll discuss them in turn here: + +### Priority of `__jax_array__` vs. PyTree flattening +JAX already has a supported mechanism for registering custom objects, namely pytree +registration (see [Extending pytrees](https://docs.jax.dev/en/latest/pytrees.html#extending-pytrees)). +If we also support __jax_array__, which one should take precedence? + +To put this more concretely, what should be the result of this code? + +```python +@jax.jit +def f(x): + print("is JAX array:", isinstance(x, jax.Array)) + +f(CustomArray(...)) +``` + +If we choose to prioritize `__jax_array__` at the JIT boundary, then the output of this +function would be: +``` +is JAX array: True +``` +That is, at the JIT boundary, the `CustomArray` object would be converted into a +`__jax_array__`, and its shape and dtype would be used to construct a standard JAX +tracer for the function. + +If we choose to prioritize pytree flattening at the JIT boundary, then the output of +this function would be: +``` +type(x)=CustomArray +``` +That is, at the JIT boundary, the `CustomArray` object is flattened, and then unflattened +before being passed to the JIT-compiled function for tracing. 
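+For concreteness, the kind of pytree registration referred to here could look like
+the following sketch (reusing the `CustomArray` fields from the examples above;
+`jax.tree_util.register_pytree_node_class` is the standard class-registration
+decorator, and the exact method bodies are illustrative rather than part of this
+proposal):
+
+```python
+import jax
+import jax.numpy as jnp
+import numpy as np
+
+@jax.tree_util.register_pytree_node_class
+class CustomArray:
+  def __init__(self, data: np.ndarray):
+    self.data = data
+
+  def tree_flatten(self):
+    # The wrapped array is the only child; there is no static auxiliary data.
+    return (self.data,), None
+
+  @classmethod
+  def tree_unflatten(cls, aux_data, children):
+    return cls(*children)
+
+  def __jax_array__(self) -> jax.Array:
+    return jnp.asarray(self.data)
+```
+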
If `CustomArray` has been +registered as a pytree, it will generally contain traced arrays as its attributes, and +when x is passed to any JAX API that supports `__jax_array__`, these traced attributes +will be converted to a single traced array according to the logic specified in the method. + +There are deeper consequences here for how other transformations like vmap and grad work +when encountering custom objects: for example, if we prioritize pytree flattening, vmap +would operate over the dimensions of the flattened contents of the custom object, while +if we prioritize `__jax_array__`, vmap would operate over the converted array dimensions. + +This also has consequences when it comes to JIT invariance: consider a function like this: +```python +def f(x): + if isinstance(x, CustomArray): + return x.custom_method() + else: + # do something else + ... + +result1 = f(x) +result2 = jax.jit(f)(x) +``` +If `jit` consumes `x` via pytree flattening, the results should agree for a well-specified +flattening rule. If `jit` consumes `x` via `__jax_array__`, the results will differ because +`x` is no longer a CustomArray within the JIT-compiled version of the function. + +#### Synopsis +As of JAX v0.6.0, transformations prioritize `__jax_array__` when it is available. This status +quo can lead to confusion around lack of JIT invariance, and the current implementation in practice +leads to subtle bugs in the case of automatic differentiation, where the forward and backward pass +do not treat inputs consistently. + +Because the pytree extensibility mechanism already exists for the case of customizing +transformations, it seems most straightforward if transformations act only via this +mechanism: that is, **we propose to remove `__jax_array__` parsing during abstractification.** +This approach will preserve object identity through transformations, and give the user the +most possible flexibility. If the user wants to opt-in to array conversion semantics, that +is always possible by explicitly casting their input via jnp.asarray, which will trigger the +`__jax_array__` protocol. + +### Which APIs should support `__jax_array__`? +JAX has a number of different levels of API, from the level of explicit primitive binding +(e.g. `jax.lax.add_p.bind(x, y)`) to the `jax.lax` APIs (e.g. `jax.lax.add(x, y)`) to the +`jax.numpy` APIs (e.g. `jax.numpy.add(x, y)`). Which of these API categories should handle +implicit conversion via `__jax_array__`? + +In order to limit the scope of the change and the required testing, I propose that `__jax_array__` +only be explicitly supported in `jax.numpy` APIs: after all, it is inspired by the` __array__` +protocol which is supported by the NumPy package. We could always expand this in the future to +`jax.lax` APIs if needed. + +This is in line with the current state of the package, where `__jax_array__` handling is mainly +within the input validation utilities used by `jax.numpy` APIs. + +## Implementation +With these design choices in mind, we plan to implement this as follows: + +- **Adding runtime support to `jax.numpy`**: This is likely the easiest part, as most + `jax.numpy` functions use a common internal utility (`ensure_arraylike`) to validate + inputs and convert them to array. This utility already supports `__jax_array__`, and + so most jax.numpy APIs are already compliant. +- **Adding test coverage**: To ensure compliance across the APIs, we should add a new + test scaffold that calls every `jax.numpy` API with custom inputs and validates correct + behavior. 
+- **Deprecating `__jax_array__` during abstractification**: Currently JAX's abstractification + pass, used in `jit` and other transformations, does parse the `__jax_array__` protocol, + and this is not the behavior we want long-term. We need to deprecate this behavior, and + ensure that downstream packages that rely on it can move toward pytree registration or + explicit array conversion where necessary. +- **Adding type annotations**: the type interface for jax.numpy functions is in + `jax/numpy/__init__.pyi`, and we’ll need to change each input type from `ArrayLike` to + `ArrayLike | SupportsJAXArray`, where the latter is a protocol with a `__jax_array__` + method. We cannot add this directly to the `ArrayLike` definition, because `ArrayLike` + is used in contexts where `__jax_array__` should not be supported. +- **Documentation**: once the above support is added, we should add a documentation section + on array extensibility that outlines exactly what to expect regarding the `__jax_array__` + protocol, with examples of how it can be used in conjunction with pytree registration + in order to effectively work with user-defined types. diff --git a/docs/jep/4008-custom-vjp-update.md b/docs/jep/4008-custom-vjp-update.md index 1e2270e052a6..c3f2be151ef7 100644 --- a/docs/jep/4008-custom-vjp-update.md +++ b/docs/jep/4008-custom-vjp-update.md @@ -4,7 +4,7 @@ _Oct 14 2020_ This doc assumes familiarity with `jax.custom_vjp`, as described in the [Custom derivative rules for JAX-transformable Python -functions](https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html) +functions](https://docs.jax.dev/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html) notebook. ## What to update diff --git a/docs/jep/4410-omnistaging.md b/docs/jep/4410-omnistaging.md index f95c15f404b6..5b4536864ac2 100644 --- a/docs/jep/4410-omnistaging.md +++ b/docs/jep/4410-omnistaging.md @@ -266,7 +266,7 @@ While tracing the function ex1 at ex1.py:4, this value became a tracer due to JA You can use transformation parameters such as `static_argnums` for `jit` to avoid tracing particular arguments of transformed functions. -See https://jax.readthedocs.io/en/latest/faq.html#abstract-tracer-value-encountered-where-concrete-value-is-expected-error for more information. +See https://docs.jax.dev/en/latest/faq.html#abstract-tracer-value-encountered-where-concrete-value-is-expected-error for more information. Encountered tracer value: Tracedwith ``` diff --git a/docs/jep/9407-type-promotion.ipynb b/docs/jep/9407-type-promotion.ipynb index a1ede3177a3a..5f12877c97a9 100644 --- a/docs/jep/9407-type-promotion.ipynb +++ b/docs/jep/9407-type-promotion.ipynb @@ -12,7 +12,7 @@ "\n", "*Jake VanderPlas, December 2021*\n", "\n", - "One of the challenges faced in the design of any numerical computing library is the choice of how to handle operations between values of different types. This document outlines the thought process behind the promotion semantics used by JAX, summarized in [JAX Type Promotion Semantics](https://jax.readthedocs.io/en/latest/type_promotion.html)." + "One of the challenges faced in the design of any numerical computing library is the choice of how to handle operations between values of different types. This document outlines the thought process behind the promotion semantics used by JAX, summarized in [JAX Type Promotion Semantics](https://docs.jax.dev/en/latest/type_promotion.html)." 
] }, { @@ -1335,7 +1335,7 @@ "However, these advantages comes with a few tradeoffs:\n", "\n", "- mixed float/integer promotion is very prone to precision loss: for example, `int64` (with a maximum value of $9.2 \\times 10^{18}$) can be promoted to `float16` (with a maximum value of $6.5 \\times 10^4$), meaning most representable values will become `inf`.\n", - "- as mentioned above, `f*` can no longer be thought of as a \"scalar type\", but as a different flavor of float64. In JAX's parlance, this is referred to as a [*weak type*](https://jax.readthedocs.io/en/latest/type_promotion.html#weakly-typed-values-in-jax), in that it is represented as 64-bit, but only weakly holds to this bit width in promotion with other values.\n", + "- as mentioned above, `f*` can no longer be thought of as a \"scalar type\", but as a different flavor of float64. In JAX's parlance, this is referred to as a [*weak type*](https://docs.jax.dev/en/latest/type_promotion.html#weakly-typed-values-in-jax), in that it is represented as 64-bit, but only weakly holds to this bit width in promotion with other values.\n", "\n", "Note that also, this approach still leaves the `uint64` promotion question unanswered, although it is perhaps reasonable to close the lattice by connecting `u64` to `f*`." ] @@ -1413,7 +1413,7 @@ "id": "o0-E2KWjYEXO" }, "source": [ - "The behavior resulting from this choice is summarized in [JAX Type Promotion Semantics](https://jax.readthedocs.io/en/latest/type_promotion.html). Notably, aside from the inclusion of larger unsigned types (`u16`, `u32`, `u64`) and some details about the behavior of scalar/weak types (`i*`, `f*`, `c*`), this type promotion scheme turns out to be very close to that chosen by PyTorch.\n", + "The behavior resulting from this choice is summarized in [JAX Type Promotion Semantics](https://docs.jax.dev/en/latest/type_promotion.html). Notably, aside from the inclusion of larger unsigned types (`u16`, `u32`, `u64`) and some details about the behavior of scalar/weak types (`i*`, `f*`, `c*`), this type promotion scheme turns out to be very close to that chosen by PyTorch.\n", "\n", "For those interested, the appendix below prints the full promotion tables used by NumPy, Tensorflow, PyTorch, and JAX." ] @@ -2883,7 +2883,7 @@ "source": [ "### JAX Type Promotion: `jax.numpy`\n", "\n", - "`jax.numpy` follows type promotion rules laid out at https://jax.readthedocs.io/en/latest/type_promotion.html. Here we use `i*`, `f*`, `c*` to indicate both Python scalars and weakly-typed arrays." + "`jax.numpy` follows type promotion rules laid out at https://docs.jax.dev/en/latest/type_promotion.html. Here we use `i*`, `f*`, `c*` to indicate both Python scalars and weakly-typed arrays." ] }, { diff --git a/docs/jep/9407-type-promotion.md b/docs/jep/9407-type-promotion.md index ff67a8c21399..c047d76c1b18 100644 --- a/docs/jep/9407-type-promotion.md +++ b/docs/jep/9407-type-promotion.md @@ -20,7 +20,7 @@ kernelspec: *Jake VanderPlas, December 2021* -One of the challenges faced in the design of any numerical computing library is the choice of how to handle operations between values of different types. This document outlines the thought process behind the promotion semantics used by JAX, summarized in [JAX Type Promotion Semantics](https://jax.readthedocs.io/en/latest/type_promotion.html). +One of the challenges faced in the design of any numerical computing library is the choice of how to handle operations between values of different types. 
This document outlines the thought process behind the promotion semantics used by JAX, summarized in [JAX Type Promotion Semantics](https://docs.jax.dev/en/latest/type_promotion.html). +++ {"id": "Rod6OOyUVbQ8"} @@ -680,7 +680,7 @@ This is important because `f16` and `bf16` are not comparable because they utili However, these advantages comes with a few tradeoffs: - mixed float/integer promotion is very prone to precision loss: for example, `int64` (with a maximum value of $9.2 \times 10^{18}$) can be promoted to `float16` (with a maximum value of $6.5 \times 10^4$), meaning most representable values will become `inf`. -- as mentioned above, `f*` can no longer be thought of as a "scalar type", but as a different flavor of float64. In JAX's parlance, this is referred to as a [*weak type*](https://jax.readthedocs.io/en/latest/type_promotion.html#weakly-typed-values-in-jax), in that it is represented as 64-bit, but only weakly holds to this bit width in promotion with other values. +- as mentioned above, `f*` can no longer be thought of as a "scalar type", but as a different flavor of float64. In JAX's parlance, this is referred to as a [*weak type*](https://docs.jax.dev/en/latest/type_promotion.html#weakly-typed-values-in-jax), in that it is represented as 64-bit, but only weakly holds to this bit width in promotion with other values. Note that also, this approach still leaves the `uint64` promotion question unanswered, although it is perhaps reasonable to close the lattice by connecting `u64` to `f*`. @@ -730,7 +730,7 @@ nx.draw(graph, with_labels=True, node_size=1500, node_color='lightgray', pos=pos +++ {"id": "o0-E2KWjYEXO"} -The behavior resulting from this choice is summarized in [JAX Type Promotion Semantics](https://jax.readthedocs.io/en/latest/type_promotion.html). Notably, aside from the inclusion of larger unsigned types (`u16`, `u32`, `u64`) and some details about the behavior of scalar/weak types (`i*`, `f*`, `c*`), this type promotion scheme turns out to be very close to that chosen by PyTorch. +The behavior resulting from this choice is summarized in [JAX Type Promotion Semantics](https://docs.jax.dev/en/latest/type_promotion.html). Notably, aside from the inclusion of larger unsigned types (`u16`, `u32`, `u64`) and some details about the behavior of scalar/weak types (`i*`, `f*`, `c*`), this type promotion scheme turns out to be very close to that chosen by PyTorch. For those interested, the appendix below prints the full promotion tables used by NumPy, Tensorflow, PyTorch, and JAX. @@ -900,7 +900,7 @@ display.HTML(table.to_html()) ### JAX Type Promotion: `jax.numpy` -`jax.numpy` follows type promotion rules laid out at https://jax.readthedocs.io/en/latest/type_promotion.html. Here we use `i*`, `f*`, `c*` to indicate both Python scalars and weakly-typed arrays. +`jax.numpy` follows type promotion rules laid out at https://docs.jax.dev/en/latest/type_promotion.html. Here we use `i*`, `f*`, `c*` to indicate both Python scalars and weakly-typed arrays. ```{code-cell} :cellView: form diff --git a/docs/jep/9419-jax-versioning.md b/docs/jep/9419-jax-versioning.md index b964aa2af45d..85b95257ebae 100644 --- a/docs/jep/9419-jax-versioning.md +++ b/docs/jep/9419-jax-versioning.md @@ -167,16 +167,16 @@ We maintain an additional version number (`_version`) in [`xla_client.py` in the XLA repository](https://github.com/openxla/xla/blob/main/xla/python/xla_client.py). 
The idea is that this version number, is defined in `xla/python` together with the C++ parts of JAX, is also accessible to JAX Python as -`jax._src.lib.xla_extension_version`, and must +`jax._src.lib.jaxlib_extension_version`, and must be incremented every time that a change is made to the XLA/Python code that has backwards compatibility implications for `jax`. The JAX Python code can then use this version number to maintain backwards compatibility, e.g.: ``` -from jax._src.lib import xla_extension_version +from jax._src.lib import jaxlib_extension_version # 123 is the new version number for _version in xla_client.py -if xla_extension_version >= 123: +if jaxlib_extension_version >= 123: # Use new code path ... else: diff --git a/docs/jep/index.rst b/docs/jep/index.rst index 1c4ecbb3411f..2ba85a5f4a8d 100644 --- a/docs/jep/index.rst +++ b/docs/jep/index.rst @@ -52,6 +52,7 @@ Then create a pull request that adds a file named 17111: Efficient transposition of `shard_map` (and other maps) <17111-shmap-transpose> 18137: Scope of JAX NumPy & SciPy Wrappers <18137-numpy-scipy-scope> 25516: Effort-based versioning <25516-effver> + 28661: Supporting the `__jax_array__` protocol <28661-jax-array-protocol> Several early JEPs were converted in hindsight from other documentation, diff --git a/docs/jit-compilation.md b/docs/jit-compilation.md index 5e5be308068a..a4e2f8b41f0d 100644 --- a/docs/jit-compilation.md +++ b/docs/jit-compilation.md @@ -55,7 +55,7 @@ The {ref}`jax-internals-jaxpr` section of the documentation provides more inform Importantly, notice that the jaxpr does not capture the side-effect present in the function: there is nothing in it corresponding to `global_list.append(x)`. This is a feature, not a bug: JAX transformations are designed to understand side-effect-free (a.k.a. functionally pure) code. -If *pure function* and *side-effect* are unfamiliar terms, this is explained in a little more detail in [🔪 JAX - The Sharp Bits 🔪: Pure Functions](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#pure-functions). +If *pure function* and *side-effect* are unfamiliar terms, this is explained in a little more detail in [🔪 JAX - The Sharp Bits 🔪: Pure Functions](https://docs.jax.dev/en/latest/notebooks/Common_Gotchas_in_JAX.html#pure-functions). Impure functions are dangerous because under JAX transformations they are likely not to behave as intended; they might fail silently, or produce surprising downstream errors like leaked Tracers. Moreover, JAX often can't detect when side effects are present. diff --git a/docs/multi_process.md b/docs/multi_process.md index 32cfae126784..f8c2566ca872 100644 --- a/docs/multi_process.md +++ b/docs/multi_process.md @@ -1,176 +1,667 @@ -# Multi-host and multi-process environments - - - -## Introduction - -This guide explains how to use JAX in environments such as -GPU clusters and [Cloud TPU](https://cloud.google.com/tpu) pods where -accelerators are spread across multiple CPU hosts or JAX processes. We’ll refer -to these as “multi-process” environments. - -This guide specifically focuses on how to use collective communication -operations (e.g. {func}`jax.lax.psum` ) in multi-process settings, although -other communication methods may be useful too depending on your use case (e.g. -RPC, [mpi4jax](https://github.com/mpi4jax/mpi4jax)). If you’re not already -familiar with JAX’s collective operations, we recommend starting with the -{doc}`/sharded-computation` section. 
An important requirement of -multi-process environments in JAX is direct communication links between -accelerators, e.g. the high-speed interconnects for Cloud TPUs or -[NCCL](https://developer.nvidia.com/nccl) for GPUs. These links allow -collective operations to run across multiple processes’ worth of accelerators -with high performance. - -## Multi-process programming model - -Key concepts: - - * You must run at least one JAX process per host. - * You should initialize the cluster with {func}`jax.distributed.initialize`. - * Each process has a - distinct set of *local* devices it can address. The *global* devices are the set - of all devices across all processes. - * Use standard JAX parallelism APIs like {func}`~jax.jit` (see - {doc}`/sharded-computation` tutorial) and - {func}`~jax.experimental.shard_map.shard_map`. jax.jit only accepts - globally shaped arrays. shard_map allows you to drop to per-device - shape. - * Make sure all processes run the same parallel computations in the same - order. - * Make sure all processes has the same number of local devices. - * Make sure all devices are the same (e.g., all V100, or all H100). - -### Launching JAX processes - -Unlike other distributed systems where a single controller node manages many -worker nodes, JAX uses a “multi-controller” programming model where each JAX -Python process runs independently, sometimes referred to as a {term}`Single -Program, Multiple Data (SPMD)` model. Generally, the same JAX Python -program is run in each process, with only slight differences between each -process’s execution (e.g. different processes will load different input data). -Furthermore, **you must manually run your JAX program on each host!** JAX -doesn’t automatically start multiple processes from a single program invocation. - -(The requirement for multiple processes is why this guide isn’t offered as a -notebook -- we don’t currently have a good way to manage multiple Python -processes from a single notebook.) - -### Initializing the cluster - -To initialize the cluster, you should call {func}`jax.distributed.initialize` at -the start of each process. {func}`jax.distributed.initialize` must be called -early in the program, before any JAX computations are executed. - -The API {func}`jax.distributed.initialize` takes several arguments, namely: - - * `coordinator_address`: the IP address of process 0 in your cluster, together - with a port available on that process. Process 0 will start a JAX service - exposed via that IP address and port, to which the other processes in the - cluster will connect. - * `coordinator_bind_address`: the IP address and port to which the JAX service - on process 0 in your cluster will bind. By default, it will bind to all - available interfaces using the same port as `coordinator_address`. - * `num_processes`: the number of processes in the cluster - * `process_id`: the ID number of this process, in the range `[0 .. - num_processes)`. - * `local_device_ids`: Restricts the visible devices of the current process to - ``local_device_ids``. - -For example on GPU, a typical usage is: +# Introduction to multi-controller JAX (aka multi-process/multi-host JAX) + + + +By reading this tutorial, you'll learn how to scale JAX computations to more +devices than can fit in a single host machine, e.g. when running on a GPU +cluster, Cloud TPU pod, or multiple CPU-only machines. + +The main idea + +- **Run multiple Python processes**, which we sometimes call "controllers." We + can run one (or more) process per host machine. 
+- **Initialize the cluster with {func}`jax.distributed.initialize`**. +- **A {class}`jax.Array` can span all processes**, and if each process applies + the same JAX function to it, it's like programming against one big device. +- **Use the same [unified sharding mechanism][unified_sharding]** as in + single-controller JAX to control how data is distributed and computation is + parallelized. XLA automatically exploits high-speed networking links like TPU + ICI or NVLink between hosts when available, and otherwise uses available host + networking (e.g. Ethernet, InfiniBand). +- **All processes (usually) run the same Python script**. You write this Python + code almost exactly the same as you would for a single process — just run + multiple instances of it and JAX takes care of the rest. In other words, + except for array creation, you can write your JAX code as if there were one + giant machine with all devices attached to it. + +This tutorial assumes you've read [Distributed arrays and automatic +parallelization][distributed_arrays], which is about single-controller JAX. + +```{figure} _static/multi_process/mcjax_overview.png +:alt: Illustration of a multi-host TPU pod. Each host in the pod is attached via PCI to a board of four TPU chips. The TPUs chips themselves are connected via high-speed inter-chip interconnects. + +Illustration of a multi-host TPU pod. Each host in the pod (green) is attached +via PCI to a board of four TPU chips (blue). The TPUs chips themselves are +connected via high-speed inter-chip interconnects (ICI). JAX Python code runs on +each host, e.g. via ssh. The JAX processes on each host are aware of each other, +allowing you to orchestrate computation across the entire pods' worth of chips. +The principle is the same for GPU, CPU, and other platforms with JAX support! +``` + +## Toy example + +Before we define terms and walk through the details, here's a toy example: +making a process-spanning {class}`jax.Array` of values and applying +{mod}`jax.numpy` functions to it. ```python +# call this file toy.py, to be run in each process simultaneously + import jax +import jax.numpy as jnp +from jax.sharding import NamedSharding, PartitionSpec as P +import numpy as np + +# in this example, get multi-process parameters from sys.argv +import sys +proc_id = int(sys.argv[1]) +num_procs = int(sys.argv[2]) + +# initialize the distributed system +jax.distributed.initialize('localhost:10000', num_procs, proc_id) + +# this example assumes 8 devices total +assert jax.device_count() == 8 + +# make a 2D mesh that refers to devices from all processes +mesh = jax.make_mesh((4, 2), ('i', 'j')) -jax.distributed.initialize(coordinator_address="192.168.0.1:1234", - num_processes=2, - process_id=0) +# create some toy data +global_data = np.arange(32).reshape((4, 8)) + +# make a process- and device-spanning array from our toy data +sharding = NamedSharding(mesh, P('i', 'j')) +global_array = jax.device_put(global_data, sharding) +assert global_array.shape == global_data.shape + +# each process has different shards of the global array +for shard in global_array.addressable_shards: + print(f"device {shard.device} has local data {shard.data}") + +# apply a simple computation, automatically partitioned +global_result = jnp.sum(jnp.sin(global_array)) +print(f'process={proc_id} got result: {global_result}') ``` -On Cloud TPU, Slurm and Open MPI environments, you can simply call {func}`jax.distributed.initialize()` with no -arguments. Default values for the arguments will be chosen automatically. 
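+As a quick sanity check (a sketch, not part of `toy.py` itself), after
+{func}`jax.distributed.initialize` each process can also query the global and
+local device lists and its own process index directly; with the 4-process,
+2-devices-per-process CPU setup used below, you would expect:
+
+```python
+import jax
+
+print(len(jax.devices()))         # 8: all devices, across all processes
+print(len(jax.local_devices()))   # 2: only this process's devices
+print(jax.process_index(), jax.process_count())  # e.g. "0 4" in the first process
+```
+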
-When running on GPUs with Slurm and Open MPI, it is assumed that one process is started per GPU, i.e. each process will -be assigned only one visible local device. Otherwise it is assumed that one process is started per host, -i.e. each process will be assigned all local devices. -The Open MPI auto-initialization is only used when the JAX processes are launched via `mpirun`/`mpiexec`. +Here, `mesh` contains devices from all processes. We use it to create +`global_array`, logically a single shared array, stored distributed across +devices from all processes. + +Every process must apply the same operations, in the same order, to +`global_array`. XLA automatically partitions those computations, for example +inserting communication collectives to compute the `jnp.sum` over the full +array. We can print the final result because its value is replicated across +processes. + +We can run this code locally on CPU, e.g. using 4 processes and 2 CPU devices +per process: + +```bash +export JAX_NUM_CPU_DEVICES=2 +num_processes=4 + +range=$(seq 0 $(($num_processes - 1))) + +for i in $range; do + python toy.py $i $num_processes > /tmp/toy_$i.out & +done + +wait + +for i in $range; do + echo "=================== process $i output ===================" + cat /tmp/toy_$i.out + echo +done +``` + +Outputs: + +```text +=================== process 0 output =================== +device TFRT_CPU_0 has local data [[0 1 2 3]] +device TFRT_CPU_1 has local data [[4 5 6 7]] +process=0 got result: -0.12398731708526611 + +=================== process 1 output =================== +device TFRT_CPU_131072 has local data [[ 8 9 10 11]] +device TFRT_CPU_131073 has local data [[12 13 14 15]] +process=1 got result: -0.12398731708526611 + +=================== process 2 output =================== +device TFRT_CPU_262144 has local data [[16 17 18 19]] +device TFRT_CPU_262145 has local data [[20 21 22 23]] +process=2 got result: -0.12398731708526611 + +=================== process 3 output =================== +device TFRT_CPU_393216 has local data [[24 25 26 27]] +device TFRT_CPU_393217 has local data [[28 29 30 31]] +process=3 got result: -0.12398731708526611 +``` + +This might not look so different from single-controller JAX code, and in fact, +this is exactly how you'd write the single-controller version of the same +program! (We don't technically need to call {func}`jax.distributed.initialize` +for single-controller, but it doesn't hurt.) Let's run the same code from a +single process: + +```text +JAX_NUM_CPU_DEVICES=8 python toy.py 0 1 +``` + +Outputs: + +```text +device TFRT_CPU_0 has local data [[0 1 2 3]] +device TFRT_CPU_1 has local data [[4 5 6 7]] +device TFRT_CPU_2 has local data [[ 8 9 10 11]] +device TFRT_CPU_3 has local data [[12 13 14 15]] +device TFRT_CPU_4 has local data [[16 17 18 19]] +device TFRT_CPU_5 has local data [[20 21 22 23]] +device TFRT_CPU_6 has local data [[24 25 26 27]] +device TFRT_CPU_7 has local data [[28 29 30 31]] +process=0 got result: -0.12398731708526611 +``` + +The data is sharded across eight devices on one process rather than eight +devices across four processes, but otherwise we're running the same operations +over the same data. + +## Terminology + +It's worth pinning down some terminology. + +We sometimes call each Python process running JAX computations a **controller**, +but the two terms are essentially synonymous. 
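+
+For instance, here's a minimal sketch (assuming {func}`jax.distributed.initialize`
+has already been called, as in the toy example above) of how each
+process/controller can report its own identity:
+
+```python
+import jax
+
+# every controller runs the same code; each one reports its own process index
+print(f"process {jax.process_index()} of {jax.process_count()}, "
+      f"with {jax.local_device_count()} local devices")
+```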
+ +Each process has a set of **local devices**, meaning it can transfer data to and +from those devices' memories and run computation on those devices without +involving any other processes. The local devices are usually physically attached +to the process's corresponding host, e.g. via PCI. A device can only be local to +one process; that is, the local device sets are disjoint. A process's local +devices can be queried by evaluating {func}`jax.local_devices()`. We sometimes +use the term **addressable** to mean the same thing as local. + +```{figure} _static/multi_process/controller_and_local_devices.png +:alt: Illustration of how a process/controller and local devices fit into a larger multi-host cluster. The "global devices" are all devices in the cluster. + +Illustration of how a process/controller and local devices fit into a larger +multi-host cluster. The "global devices" are all devices in the cluster. +``` + +The devices across all processes are called the **global devices**. The list of +global devices is queried by {func}`jax.devices()`. That list of all devices is +populated by running {func}`jax.distributed.initialize` on all processes, which +sets up a simple distributed system connecting the processes. + +We often use the terms **global** and **local** to describe process-spanning and +process-local concepts in general. For example, a "local array" could be a numpy +array that's only visible to a single process, vs. a JAX "global array" is +conceptually visible to all processes. + +## Setting up multiple JAX processes + +In practice, setting up multiple JAX processes looks a bit different from the +toy example, which is run from a single host machine. We usually launch each +process on a separate host, or have multiple hosts with multiple processes each. +We can do that directly using `ssh`, or with a cluster manager like Slurm or +Kubernetes. In any case, **you must manually run your JAX program on each +host!** JAX doesn’t automatically start multiple processes from a single program +invocation. + +However they're launched, the Python processes need to run +{func}`jax.distributed.initialize`. When using Slurm, Kubernetes, or any Cloud +TPU deployment, we can run {func}`jax.distributed.initialize` with no arguments +as they're automatically populated. Initializing the system means we can run +{func}`jax.devices()` to report all devices across all processes. + +```{warning} +{func}`jax.distributed.initialize` must be called before running +{func}`jax.devices()`, {func}`jax.local_devices()`, or running any computations +on devices (e.g. with {mod}`jax.numpy`). Otherwise the JAX process won't be +aware of any non-local devices. (Using {func}`jax.config` or other +non-device-accessing functionality is ok.) {func}`jax.distributed.initialize` +will raise an error if you accidentally call it after accessing any devices. +``` + +### GPU Example + +We can run multi-controller JAX on a cluster of [GPU machines][gpu_machines]. +For example, after creating four VMs on Google Cloud with two GPUs per VM, we +can run the following JAX program on every VM. In this example, we provide +arguments to {func}`jax.distributed.initialize` explicitly. The coordinator +address, process id, and number of processes are read from the command line. ```python +# In file gpu_example.py... + import jax +import sys + +# Get the coordinator_address, process_id, and num_processes from the command line. 
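+# (Alternatively, on Slurm, Kubernetes, or Cloud TPU deployments you could call
+# jax.distributed.initialize() with no arguments, as noted above, and let these
+# values be populated automatically; we pass them explicitly here.)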
+coord_addr = sys.argv[1] +proc_id = int(sys.argv[2]) +num_procs = int(sys.argv[3]) + +# Initialize the GPU machines. +jax.distributed.initialize(coordinator_address=coord_addr, + num_processes=num_procs, + process_id=proc_id) +print("process id =", jax.process_index()) +print("global devices =", jax.devices()) +print("local devices =", jax.local_devices()) +``` + +For example, if the first VM has address `192.168.0.1`, then you would run +`python3 gpu_example.py 192.168.0.1:8000 0 4` on the first VM, `python3 +gpu_example.py 192.168.0.1:8000 1 4` on the second VM, and so on. After running +the JAX program on all four VMs, the first process prints the following. + +```text +process id = 0 +global devices = [CudaDevice(id=0), CudaDevice(id=1), CudaDevice(id=2), CudaDevice(id=3), CudaDevice(id=4), CudaDevice(id=5), CudaDevice(id=6), CudaDevice(id=7)] +local devices = [CudaDevice(id=0), CudaDevice(id=1)] +``` + +The process successfully sees all eight GPUs as global devices, as well as its +two local devices. Similarly, the second process prints the following. + +```text +process id = 1 +global devices = [CudaDevice(id=0), CudaDevice(id=1), CudaDevice(id=2), CudaDevice(id=3), CudaDevice(id=4), CudaDevice(id=5), CudaDevice(id=6), CudaDevice(id=7)] +local devices = [CudaDevice(id=2), CudaDevice(id=3)] +``` +This VM sees the same global devices, but has a different set of local devices. + +### TPU Example + +As another example, we can run on [Cloud TPU][cloud_tpu]. After creating a +`v5litepod-16` (which has 4 host machines), we might want to test that we can +connect the processes and list all devices: + +```text +$ TPU_NAME=jax-demo +$ EXTERNAL_IPS=$(gcloud compute tpus tpu-vm describe $TPU_NAME --zone 'us-central1-a' \ + | grep externalIp | cut -d: -f2) +$ cat << EOF > demo.py +import jax jax.distributed.initialize() +if jax.process_index() == 0: + print(jax.devices()) +EOF +$ echo $EXTERNAL_IPS | xargs -n 1 -P 0 bash -c ' +scp demo.py $0: +ssh $0 "pip -q install -U jax[tpu]" +ssh $0 "python demo.py" ' ``` -On TPU at present calling {func}`jax.distributed.initialize` is optional, but -recommended since it enables additional checkpointing and health checking features. +Here we're using `xargs` to run multiple `ssh` commands in parallel, each one +running the same Python program on one of the TPU host machines. In the Python +code, we use {func}`jax.process_index()` to print only on one process. Here's +what it prints: -### Local vs. 
global devices +```text +[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), TpuDevice(id=1, process_index=0, coords=(1,0,0), core_on_chip=0), TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0), TpuDevice(id=5, process_index=0, coords=(1,1,0), core_on_chip=0), TpuDevice(id=2, process_index=1, coords=(2,0,0), core_on_chip=0), TpuDevice(id=3, process_index=1, coords=(3,0,0), core_on_chip=0), TpuDevice(id=6, process_index=1, coords=(2,1,0), core_on_chip=0), TpuDevice(id=7, process_index=1, coords=(3,1,0), core_on_chip=0), TpuDevice(id=8, process_index=2, coords=(0,2,0), core_on_chip=0), TpuDevice(id=9, process_index=2, coords=(1,2,0), core_on_chip=0), TpuDevice(id=12, process_index=2, coords=(0,3,0), core_on_chip=0), TpuDevice(id=13, process_index=2, coords=(1,3,0), core_on_chip=0), TpuDevice(id=10, process_index=3, coords=(2,2,0), core_on_chip=0), TpuDevice(id=11, process_index=3, coords=(3,2,0), core_on_chip=0), TpuDevice(id=14, process_index=3, coords=(2,3,0), core_on_chip=0), TpuDevice(id=15, process_index=3, coords=(3,3,0), core_on_chip=0)] +``` -Before we get to running multi-process computations from your program, it’s -important to understand the distinction between *local* and *global* devices. +Woohoo, look at all those TPU cores! + +### Kubernetes Example + +Running multi-controller JAX on a Kubernetes cluster is almost identical in spirit to the GPU and TPU examples above: every pod runs the same Python program, JAX discovers its peers, and the cluster behaves like one giant machine. + +1. **Container image** - start from a JAX-enabled image, e.g. one of the public JAX AI images on Google Artifact Registry ([TPU][google-artifact-tpu] / [GPU][google-artifact-gpu]) or NVIDIA ([NGC][nvidia-ngc] / [JAX-Toolbox][nvidia-jax-toolbox]). + +2. **Workload type** - use either a [JobSet][k8s-jobset] or an [indexed Job][k8s-indexed-job]. Each replica corresponds to one JAX process. + +3. **Service Account** - JAX needs permission to list the pods that belong to the job so that processes discover their peers. A minimal RBAC setup is provided in [examples/k8s/svc-acct.yaml][rbac-svc-acct]. + +Below is a [minimal JobSet][minimal-jobset] that launches two replicas. Replace the placeholders - +image, GPU count, and any private registry secrets - with values that match your environment. + +```yaml +apiVersion: jobset.x-k8s.io/v1alpha2 +kind: JobSet +metadata: + name: jaxjob +spec: + replicatedJobs: + - name: workers + template: + spec: + parallelism: 2 + completions: 2 + backoffLimit: 0 + template: + spec: + serviceAccountName: jax-job-sa # kubectl apply -f svc-acct.yaml + restartPolicy: Never + imagePullSecrets: + # https://k8s.io/docs/tasks/configure-pod-container/pull-image-private-registry/ + - name: null + containers: + - name: main + image: null # e.g. ghcr.io/nvidia/jax:jax + imagePullPolicy: Always + resources: + limits: + cpu: 1 + # https://k8s.io/docs/tasks/manage-gpus/scheduling-gpus/ + nvidia.com/gpu: null + command: + - python + args: + - -c + - | + import jax + jax.distributed.initialize() + print(jax.devices()) + print(jax.local_devices()) + assert jax.process_count() > 1 + assert len(jax.devices()) > len(jax.local_devices()) +``` -**A process’s *local* devices are those that it can directly address and launch -computations on.** For example, on a GPU cluster, each host can only launch -computations on the directly attached GPUs. 
On a Cloud TPU pod, each host can -only launch computations on the 8 TPU cores attached directly to that host (see -the -[Cloud TPU System Architecture](https://cloud.google.com/tpu/docs/system-architecture) -documentation for more details). You can see a process’s local devices via -{func}`jax.local_devices()`. +Apply the manifest and watch the pods complete: -**The *global* devices are the devices across all processes.** A computation can -span devices across processes and perform collective operations via the direct -communication links between devices, as long as each process launches the -computation on its local devices. You can see all available global devices via -{func}`jax.devices()`. A process’s local devices are always a subset of the -global devices. +```bash +$ kubectl apply -f example.yaml +$ kubectl get pods -l jobset.sigs.k8s.io/jobset-name=jaxjob +NAME READY STATUS RESTARTS AGE +jaxjob-workers-0-0-xpx8l 0/1 Completed 0 8m32s +jaxjob-workers-0-1-ddkq8 0/1 Completed 0 8m32s +``` -### Running multi-process computations +When the job finishes, inspect the logs to confirm that every process saw all accelerators: -So how do you actually run a computation involving cross-process communication? -**Use the same parallel evaluation APIs that you would in a single process!** +```bash +$ kubectl logs -l jobset.sigs.k8s.io/jobset-name=jaxjob +[CudaDevice(id=0), CudaDevice(id=1)] +[CudaDevice(id=0)] +[CudaDevice(id=0), CudaDevice(id=1)] +[CudaDevice(id=1)] +``` + +Every pod should have the same set of global devices and a different set of local devices. At this point, you can replace the inline script with your real JAX program. + +Once the processes are set up, we can start building global {class}`jax.Array`s +and running computations. The remaining Python code examples in this tutorial +are meant to be run on all processes simultaneously, after running +{func}`jax.distributed.initialize`. + +## Meshes, shardings, and computations can span processes and hosts + +Programming multiple processes from JAX usually looks just like programming a +single process, just with more devices! The main exceptions to this are around +data coming in or out of JAX, e.g. when loading from external data sources. +We'll first go over the basics of multi-process computations here, which largely +look the same as their single-process counterparts. The next section goes over +some data loading fundamentals, i.e. how to create JAX Arrays from non-JAX +sources. + +Recall a {class}`jax.sharding.Mesh` pairs an array of {class}`jax.Device`s with +a sequence of names, with one name per array axis. By creating a `Mesh` using +devices from multiple processes, then using that mesh in a +{class}`jax.sharding.Sharding`, we can construct {class}`jax.Array`s sharded +over devices from multiple processes. + +Here's an example that directly constructs a `Mesh` using {func}`jax.devices()` +to get devices from all processes: + +```python +from jax.sharding import Mesh +mesh = Mesh(jax.devices(), ('a',)) + +# in this case, the same as +mesh = jax.make_mesh((jax.device_count(),), ('a',)) # use this in practice +``` + +You should probably use the {func}`jax.make_mesh` helper in practice, not only +because it's simpler but also because it can choose more performant device +orderings automatically, but we're spelling it out here. By default it includes +all devices across processes, just like {func}`jax.devices()`. + +Once we have a mesh, we can shard arrays over it. 
There are a few ways to +efficiently build process-spanning arrays, detailed in the next section, but for +now we'll stick to `jax.device_put` for simplicity: + +```python +arr = jax.device_put(jnp.ones((32, 32)), NamedSharding(mesh, P('a'))) +if jax.process_index() == 0: + jax.debug.visualize_array_sharding(arr) +``` + +On process 0, this is printed: + +``` +┌───────────────────────┐ +│ TPU 0 │ +├───────────────────────┤ +│ TPU 1 │ +├───────────────────────┤ +│ TPU 4 │ +├───────────────────────┤ +│ TPU 5 │ +├───────────────────────┤ +│ TPU 2 │ +├───────────────────────┤ +│ TPU 3 │ +├───────────────────────┤ +│ TPU 6 │ +├───────────────────────┤ +│ TPU 7 │ +├───────────────────────┤ +│ TPU 8 │ +├───────────────────────┤ +│ TPU 9 │ +├───────────────────────┤ +│ TPU 12 │ +├───────────────────────┤ +│ TPU 13 │ +├───────────────────────┤ +│ TPU 10 │ +├───────────────────────┤ +│ TPU 11 │ +├───────────────────────┤ +│ TPU 14 │ +├───────────────────────┤ +│ TPU 15 │ +└───────────────────────┘ +``` + +Let's try a slightly more interesting computation! + +```python +mesh = jax.make_mesh((jax.device_count() // 2, 2), ('a', 'b')) + +def device_put(x, spec): + return jax.device_put(x, NamedSharding(mesh, spec)) + +# construct global arrays by sharding over the global mesh +x = device_put(jnp.ones((4096, 2048)), P('a', 'b')) +y = device_put(jnp.ones((2048, 4096)), P('b', None)) + +# run a distributed matmul +z = jax.nn.relu(x @ y) + +# inspect the sharding of the result +if jax.process_index() == 0: + jax.debug.visualize_array_sharding(z) + print() + print(z.sharding) +``` -For example, {func}`~jax.experimental.shard_map.shard_map` can be used -to run a parallel computation across multiple processes. (If you’re -not already familiar with how to use `shard_map` to run across -multiple devices within a single process, check out the -{doc}`/sharded-computation` tutorial.) Conceptually, this can be -thought of as running a pmap over a single array sharded across hosts, -where each host “sees” only its local shard of the input and output. +On process 0, this is printed: -Here’s an example of multi-process pmap in action: +``` +┌───────────────────────┐ +│ TPU 0,1 │ +├───────────────────────┤ +│ TPU 4,5 │ +├───────────────────────┤ +│ TPU 8,9 │ +├───────────────────────┤ +│ TPU 12,13 │ +├───────────────────────┤ +│ TPU 2,3 │ +├───────────────────────┤ +│ TPU 6,7 │ +├───────────────────────┤ +│ TPU 10,11 │ +├───────────────────────┤ +│ TPU 14,15 │ +└───────────────────────┘ + +NamedSharding(mesh=Mesh('a': 8, 'b': 2), spec=PartitionSpec('a',), memory_kind=device) +``` + +Here, just from evaluating `x @ y` on all processes, XLA is automatically +generating and running a distributed matrix multiplication. The result is +sharded against the mesh like `P('a', None)`, since in this case the matmul +included a `psum` over the `'b'` axis. + +```{warning} +When applying JAX computations to process-spanning arrays, to avoid deadlocks +and hangs, **it's crucial that all processes with participating devices run the +same computation in the same order**. That's because the computation may +involve collective communication barriers. If a device over which an array is +sharded does not join in the collective because its controller didn't issue the +same computation, the other devices are left waiting. For example, if only the +first three processes evaluated `x @ y`, while the last process evaluated `y @ +x`, the computation would likely hang indefinitely. 
This assumption, +computations on process-spanning arrays are run on all participating processes +in the same order, is mostly unchecked. + +So the easiest way to avoid deadlocks in multi-process JAX is to run the same +Python code on every process, and beware of any control flow that depends on +{func}`jax.process_index()` and includes communication. +``` + +If a process-spanning array is sharded over devices on different processes, it +is an error to perform operations on the array that require the data to be +available locally to a process, like printing. For example, if we run `print(z)` +in the preceding example, we see + +``` +RuntimeError: Fetching value for `jax.Array` that spans non-addressable (non process local) devices is not possible. You can use `jax.experimental.multihost_utils.process_allgather` to print the global array or use `.addressable_shards` method of jax.Array to inspect the addressable (process local) shards. +``` + +To print the full array value, we must first ensure it's replicated over +processes (but not necessarily over each process's local devices), e.g. using +`jax.device_put`. In the above example, we can write at the end: + +``` +w = device_put(z, P(None, None)) +if jax.process_index() == 0: + print(w) +``` + +Be careful not to write the {func}`jax.device_put` under the `if process_index() +== 0`, because that would lead to a deadlock as only process 0 initiates the +collective communication and waits indefinitely for the other processes. +The {mod}`jax.experimental.multihost_utils` module has some functions that +make it easier to process global {class}`jax.Array`s (e.g., +{func}`jax.experimental.multihost_utils.process_allgather`). + +Alternatively, to print or otherwise perform Python operations on only +process-local data, we can access `z.addressable_shards`. Accessing that +attribute does not require any communication, so any subset of processes can do +it without needing the others. That attribute is not available under a +{func}`jax.jit`. + +## Making process-spanning arrays from external data + +There are three main ways to create process-spanning {class}`jax.Array`s from +external data sources (e.g. numpy arrays from a data loader): + +1. Create or load the full array on all processes, then shard onto devices using + {func}`jax.device_put`; + +2. Create or load on each process an array representing just the data that will + be locally sharded and stored on that process's devices, then shard onto + devices using {func}`jax.make_array_from_process_local_data`; + +3. Create or load on each process's devices separate arrays, each representing + the data to be stored on that device, then assemble them without any data + movement using {func}`jax.make_array_from_single_device_arrays`. + +The latter two are most often used in practice, since it's often too expensive +to materialize the full global data in every process. + +The toy example above uses {func}`jax.device_put`. + +{func}`jax.make_array_from_process_local_data` is often used for distributed data +loading. It's not as general as {func}`jax.make_array_from_single_device_arrays`, +because it doesn't directly specify which slice of the process-local data goes +on each local device. This is convenient when loading data-parallel batches, +because it doesn't matter exactly which microbatch goes on each device. For +example: ```python -# The following is run in parallel on each host on a GPU cluster or TPU pod slice. 
->>> import jax ->>> jax.distributed.initialize() # On GPU, see above for the necessary arguments. ->>> jax.device_count() # total number of accelerator devices in the cluster -32 ->>> jax.local_device_count() # number of accelerator devices attached to this host -8 -# The psum is performed over all mapped devices across the pod slice ->>> xs = jax.numpy.ones(jax.local_device_count()) ->>> jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(xs) -ShardedDeviceArray([32., 32., 32., 32., 32., 32., 32., 32.], dtype=float32) -``` - -**It’s very important that all processes run the same cross-process computations -in the same order.** Running the same JAX Python program in each process is -usually sufficient. Some common pitfalls to look out for that may cause -differently-ordered computations despite running the same program: - -* Processes passing differently-shaped inputs to the same parallel function - can cause hangs or incorrect return values. Differently-shaped inputs are - safe so long as they result in identically-shaped per-device data shards - across processes; e.g. passing in different leading batch sizes in order to - run on different numbers of local devices per process is ok, but having each - process pad its batch to a different max example length is not. - -* “Last batch” issues where a parallel function is called in a (training) - loop, and one or more processes exit the loop earlier than the rest. This - will cause the rest to hang waiting for the already-finished processes to - start the computation. - -* Conditions based on non-deterministic ordering of collections can cause code - processes to hang. For example, iterating over - `set` on current Python versions or `dict` [before Python 3.7](https://mail.python.org/pipermail/python-dev/2017-December/151283.html) - may result in a different ordering on different processes, even with the - same insertion order. +# target (micro)batch size across the whole cluster +batch_size = 1024 +# how many examples each process should load per batch +per_process_batch_size = batch_size // jax.process_count() +# how many examples each device will process per batch +per_device_batch_size = batch_size // jax.device_count() + +# make a data-parallel mesh and sharding +mesh = jax.make_mesh((jax.device_count(),), ('batch',)) +sharding = NamedSharding(mesh, P('batch')) + +# our "data loader". each process loads a different set of "examples". +process_batch = np.random.rand(per_process_batch_size, 2048, 42) + +# assemble a global array containing the per-process batches from all processes +global_batch = jax.make_array_from_process_local_data(sharding, process_batch) + +# sanity check that everything got sharded correctly +assert global_batch.shape[0] == batch_size +assert process_batch.shape[0] == per_process_batch_size +assert global_batch.addressable_shards[0].data.shape[0] == per_device_batch_size +``` + +{func}`jax.make_array_from_single_device_arrays` is the most general way to +build a process-spanning array. It's often used after performing +{func}`jax.device_put`s to send each device its required data. This is the +lowest-level option, since all data movement is performed manually (via e.g. +{func}`jax.device_put`). Here's an example: + +```python +shape = (jax.process_count(), jax.local_device_count()) +mesh = jax.make_mesh(shape, ('i', 'j')) +sharding = NamedSharding(mesh, P('i', 'j')) + +# manually create per-device data equivalent to np.arange(jax.device_count()) +# i.e.
each device will get a single scalar value from 0..N +local_arrays = [ + jax.device_put( + jnp.array([[jax.process_index() * jax.local_device_count() + i]]), + device) + for i, device in enumerate(jax.local_devices()) +] + +# assemble a global array from the local_arrays across all processes +global_array = jax.make_array_from_single_device_arrays( + shape=shape, + sharding=sharding, + arrays=local_arrays) + +# sanity check +assert (np.all( + jax.experimental.multihost_utils.process_allgather(global_array) == + np.arange(jax.device_count()).reshape(global_array.shape))) +``` + +[cloud_tpu]: https://cloud.google.com/tpu?hl=en +[distributed_arrays]: https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html +[gpu_machines]: https://cloud.google.com/compute/docs/gpus +[unified_sharding]: https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html +[google-artifact-tpu]: https://console.cloud.google.com/artifacts/docker/cloud-tpu-images/us/jax-ai-image/tpu +[google-artifact-gpu]: https://console.cloud.google.com/artifacts/docker/deeplearning-images/us-central1/jax-ai-image/gpu +[nvidia-ngc]: https://catalog.ngc.nvidia.com/orgs/nvidia/containers/jax +[nvidia-jax-toolbox]: https://github.com/NVIDIA/JAX-Toolbox +[k8s-jobset]: https://github.com/kubernetes-sigs/jobset +[k8s-indexed-job]: https://kubernetes.io/docs/concepts/workloads/controllers/job/#parallel-jobs +[rbac-svc-acct]: https://github.com/jax-ml/jax/blob/main/examples/k8s/svc-acct.yaml +[minimal-jobset]: https://github.com/jax-ml/jax/blob/main/examples/k8s/example.yaml diff --git a/docs/notebooks/Common_Gotchas_in_JAX.ipynb b/docs/notebooks/Common_Gotchas_in_JAX.ipynb index a1435c4e557e..5879630ac818 100644 --- a/docs/notebooks/Common_Gotchas_in_JAX.ipynb +++ b/docs/notebooks/Common_Gotchas_in_JAX.ipynb @@ -307,7 +307,7 @@ "id": "go3L4x3w4-9p" }, "source": [ - "If we try to update a JAX device array in-place, however, we get an __error__! (☉_☉)" + "If we try to do in-place indexed updating on a `jax.Array`, however, we get an __error__! (☉_☉)" ] }, { @@ -346,7 +346,7 @@ "evalue": "ignored", "output_type": "error", "traceback": [ - "\u001b[0;31mTypeError\u001b[0m\u001b[0;31m:\u001b[0m '' object does not support item assignment. JAX arrays are immutable. Instead of ``x[idx] = y``, use ``x = x.at[idx].set(y)`` or another .at[] method: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html\n" + "\u001b[0;31mTypeError\u001b[0m\u001b[0;31m:\u001b[0m '' object does not support item assignment. JAX arrays are immutable. Instead of ``x[idx] = y``, use ``x = x.at[idx].set(y)`` or another .at[] method: https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html\n" ] } ], @@ -357,6 +357,45 @@ "jax_array[1, :] = 1.0" ] }, + { + "cell_type": "markdown", + "id": "8f520bec", + "metadata": {}, + "source": [ + "And if we try to do `__iadd__`-style in-place updating, we get __different behavior than NumPy__! 
(☉_☉) (☉_☉)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "20fbed45", + "metadata": {}, + "outputs": [], + "source": [ + "jax_array = jnp.array([10, 20])\n", + "jax_array_new = jax_array\n", + "jax_array_new += 10\n", + "print(jax_array_new) # `jax_array_new` is rebound to a new value [20, 30], but...\n", + "print(jax_array) # the original value is unmodified as [10, 20] !\n", + "\n", + "numpy_array = np.array([10, 20])\n", + "numpy_array_new = numpy_array\n", + "numpy_array_new += 10\n", + "print(numpy_array_new) # `numpy_array_new is numpy_array`, and it was updated\n", + "print(numpy_array) # in-place, so both are [20, 30] !" + ] + }, + { + "cell_type": "markdown", + "id": "2604e220", + "metadata": {}, + "source": [ + "That's because NumPy defines `__iadd__` to perform in-place mutation. In\n", + "contrast, `jax.Array` doesn't define an `__iadd__`, so Python treats\n", + "`jax_array_new += 10` as syntactic sugar for `jax_array_new = jax_array_new +\n", + "10`, rebinding the variable without mutating any arrays." + ] + }, { "cell_type": "markdown", "metadata": { @@ -365,7 +404,7 @@ "source": [ "Allowing mutation of variables in-place makes program analysis and transformation difficult. JAX requires that programs are pure functions.\n", "\n", - "Instead, JAX offers a _functional_ array update using the [`.at` property on JAX arrays](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html#jax.numpy.ndarray.at)." + "Instead, JAX offers a _functional_ array update using the [`.at` property on JAX arrays](https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html#jax.numpy.ndarray.at)." ] }, { @@ -415,6 +454,7 @@ } ], "source": [ + "jax_array = jnp.zeros((3,3), dtype=jnp.float32)\n", "updated_array = jax_array.at[1, :].set(1.0)\n", "print(\"updated array:\\n\", updated_array)" ] @@ -521,7 +561,7 @@ "id": "sTjJ3WuaDyqU" }, "source": [ - "For more details on indexed array updates, see the [documentation for the `.at` property](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html#jax.numpy.ndarray.at)." + "For more details on indexed array updates, see the [documentation for the `.at` property](https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html#jax.numpy.ndarray.at)."
] }, { @@ -604,7 +644,7 @@ "id": "NAcXJNAcDi_v" }, "source": [ - "If you would like finer-grained control over the behavior for out-of-bound indices, you can use the optional parameters of [`ndarray.at`](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html); for example:" + "If you would like finer-grained control over the behavior for out-of-bound indices, you can use the optional parameters of [`ndarray.at`](https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html); for example:" ] }, { @@ -971,7 +1011,7 @@ "evalue": "ignored", "output_type": "error", "traceback": [ - "\u001b[0;31mNonConcreteBooleanIndexError\u001b[0m\u001b[0;31m:\u001b[0m Array boolean indices must be concrete; got ShapedArray(bool[5])\n\nSee https://jax.readthedocs.io/en/latest/errors.html#jax.errors.NonConcreteBooleanIndexError\n" + "\u001b[0;31mNonConcreteBooleanIndexError\u001b[0m\u001b[0;31m:\u001b[0m Array boolean indices must be concrete; got ShapedArray(bool[5])\n\nSee https://docs.jax.dev/en/latest/errors.html#jax.errors.NonConcreteBooleanIndexError\n" ] } ], @@ -1296,7 +1336,7 @@ "While `jax.numpy` makes every attempt to replicate the behavior of numpy's API, there do exist corner cases where the behaviors differ.\n", "Many such cases are discussed in detail in the sections above; here we list several other known places where the APIs diverge.\n", "\n", - "- For binary operations, JAX's type promotion rules differ somewhat from those used by NumPy. See [Type Promotion Semantics](https://jax.readthedocs.io/en/latest/type_promotion.html) for more details.\n", + "- For binary operations, JAX's type promotion rules differ somewhat from those used by NumPy. See [Type Promotion Semantics](https://docs.jax.dev/en/latest/type_promotion.html) for more details.\n", "- When performing unsafe type casts (i.e. casts in which the target dtype cannot represent the input value), JAX's behavior may be backend dependent, and in general may diverge from NumPy's behavior. Numpy allows control over the result in these scenarios via the `casting` argument (see [`np.ndarray.astype`](https://numpy.org/devdocs/reference/generated/numpy.ndarray.astype.html)); JAX does not provide any such configuration, instead directly inheriting the behavior of [XLA:ConvertElementType](https://www.tensorflow.org/xla/operation_semantics#convertelementtype).\n", "\n", " Here is an example of an unsafe cast with differing results between NumPy and JAX:\n", diff --git a/docs/notebooks/Common_Gotchas_in_JAX.md b/docs/notebooks/Common_Gotchas_in_JAX.md index 80ab69be1ed8..0857edc132fa 100644 --- a/docs/notebooks/Common_Gotchas_in_JAX.md +++ b/docs/notebooks/Common_Gotchas_in_JAX.md @@ -177,7 +177,7 @@ print(numpy_array) +++ {"id": "go3L4x3w4-9p"} -If we try to update a JAX device array in-place, however, we get an __error__! (☉_☉) +If we try to do in-place indexed updating on a `jax.Array`, however, we get an __error__! (☉_☉) ```{code-cell} ipython3 :id: iOscaa_GecEK @@ -197,11 +197,32 @@ jax_array = jnp.zeros((3,3), dtype=jnp.float32) jax_array[1, :] = 1.0 ``` +And if we try to do `__iadd__`-style in-place updating, we get __different behavior than NumPy__! (☉_☉) (☉_☉) + +```{code-cell} ipython3 +jax_array = jnp.array([10, 20]) +jax_array_new = jax_array +jax_array_new += 10 +print(jax_array_new) # `jax_array_new` is rebound to a new value [20, 30], but... +print(jax_array) # the original value is unmodified as [10, 20] !
+ +numpy_array = np.array([10, 20]) +numpy_array_new = numpy_array +numpy_array_new += 10 +print(numpy_array_new) # `numpy_array_new is numpy_array`, and it was updated +print(numpy_array) # in-place, so both are [20, 30] ! +``` + +That's because NumPy defines `__iadd__` to perform in-place mutation. In +contrast, `jax.Array` doesn't define an `__iadd__`, so Python treats +`jax_array_new += 10` as syntactic sugar for `jax_array_new = jax_array_new + +10`, rebinding the variable without mutating any arrays. + +++ {"id": "7mo76sS25Wco"} Allowing mutation of variables in-place makes program analysis and transformation difficult. JAX requires that programs are pure functions. -Instead, JAX offers a _functional_ array update using the [`.at` property on JAX arrays](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html#jax.numpy.ndarray.at). +Instead, JAX offers a _functional_ array update using the [`.at` property on JAX arrays](https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html#jax.numpy.ndarray.at). +++ {"id": "hfloZ1QXCS_J"} @@ -219,6 +240,7 @@ For example, the update above can be written as: :id: PBGI-HIeCP_s :outputId: de13f19a-2066-4df1-d503-764c34585529 +jax_array = jnp.zeros((3,3), dtype=jnp.float32) updated_array = jax_array.at[1, :].set(1.0) print("updated array:\n", updated_array) ``` @@ -261,7 +283,7 @@ print(new_jax_array) +++ {"id": "sTjJ3WuaDyqU"} -For more details on indexed array updates, see the [documentation for the `.at` property](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html#jax.numpy.ndarray.at). +For more details on indexed array updates, see the [documentation for the `.at` property](https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html#jax.numpy.ndarray.at). +++ {"id": "oZ_jE2WAypdL"} @@ -292,7 +314,7 @@ jnp.arange(10)[11] +++ {"id": "NAcXJNAcDi_v"} -If you would like finer-grained control over the behavior for out-of-bound indices, you can use the optional parameters of [`ndarray.at`](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html); for example: +If you would like finer-grained control over the behavior for out-of-bound indices, you can use the optional parameters of [`ndarray.at`](https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html); for example: ```{code-cell} ipython3 :id: -0-MaFddO-xy @@ -664,7 +686,7 @@ x.dtype # --> dtype('float64') While `jax.numpy` makes every attempt to replicate the behavior of numpy's API, there do exist corner cases where the behaviors differ. Many such cases are discussed in detail in the sections above; here we list several other known places where the APIs diverge. -- For binary operations, JAX's type promotion rules differ somewhat from those used by NumPy. See [Type Promotion Semantics](https://jax.readthedocs.io/en/latest/type_promotion.html) for more details. +- For binary operations, JAX's type promotion rules differ somewhat from those used by NumPy. See [Type Promotion Semantics](https://docs.jax.dev/en/latest/type_promotion.html) for more details. - When performing unsafe type casts (i.e. casts in which the target dtype cannot represent the input value), JAX's behavior may be backend dependent, and in general may diverge from NumPy's behavior. 
Numpy allows control over the result in these scenarios via the `casting` argument (see [`np.ndarray.astype`](https://numpy.org/devdocs/reference/generated/numpy.ndarray.astype.html)); JAX does not provide any such configuration, instead directly inheriting the behavior of [XLA:ConvertElementType](https://www.tensorflow.org/xla/operation_semantics#convertelementtype). Here is an example of an unsafe cast with differing results between NumPy and JAX: diff --git a/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb b/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb index e550cbf36da3..e80c7ae94687 100644 --- a/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb +++ b/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb @@ -17,9 +17,9 @@ "1. using `jax.custom_jvp` and `jax.custom_vjp` to define custom differentiation rules for Python functions that are already JAX-transformable; and\n", "2. defining new `core.Primitive` instances along with all their transformation rules, for example to call into functions from other systems like solvers, simulators, or general numerical computing systems.\n", "\n", - "This notebook is about #1. To read instead about #2, see the [notebook on adding primitives](https://jax.readthedocs.io/en/latest/notebooks/How_JAX_primitives_work.html).\n", + "This notebook is about #1. To read instead about #2, see the [notebook on adding primitives](https://docs.jax.dev/en/latest/notebooks/How_JAX_primitives_work.html).\n", "\n", - "For an introduction to JAX's automatic differentiation API, see [The Autodiff Cookbook](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html). This notebook assumes some familiarity with [jax.jvp](https://jax.readthedocs.io/en/latest/jax.html#jax.jvp) and [jax.grad](https://jax.readthedocs.io/en/latest/jax.html#jax.grad), and the mathematical meaning of JVPs and VJPs." + "For an introduction to JAX's automatic differentiation API, see [The Autodiff Cookbook](https://docs.jax.dev/en/latest/notebooks/autodiff_cookbook.html). This notebook assumes some familiarity with [jax.jvp](https://docs.jax.dev/en/latest/jax.html#jax.jvp) and [jax.grad](https://docs.jax.dev/en/latest/jax.html#jax.grad), and the mathematical meaning of JVPs and VJPs." ] }, { @@ -2035,7 +2035,7 @@ "source": [ "### Working with `list` / `tuple` / `dict` containers (and other pytrees)\n", "\n", - "You should expect standard Python containers like lists, tuples, namedtuples, and dicts to just work, along with nested versions of those. In general, any [pytrees](https://jax.readthedocs.io/en/latest/pytrees.html) are permissible, so long as their structures are consistent according to the type constraints. \n", + "You should expect standard Python containers like lists, tuples, namedtuples, and dicts to just work, along with nested versions of those. In general, any [pytrees](https://docs.jax.dev/en/latest/pytrees.html) are permissible, so long as their structures are consistent according to the type constraints. \n", "\n", "Here's a contrived example with `jax.custom_jvp`:" ] diff --git a/docs/notebooks/Custom_derivative_rules_for_Python_code.md b/docs/notebooks/Custom_derivative_rules_for_Python_code.md index 8a63f142693e..82b97e195bd9 100644 --- a/docs/notebooks/Custom_derivative_rules_for_Python_code.md +++ b/docs/notebooks/Custom_derivative_rules_for_Python_code.md @@ -24,9 +24,9 @@ There are two ways to define differentiation rules in JAX: 1. 
using `jax.custom_jvp` and `jax.custom_vjp` to define custom differentiation rules for Python functions that are already JAX-transformable; and 2. defining new `core.Primitive` instances along with all their transformation rules, for example to call into functions from other systems like solvers, simulators, or general numerical computing systems. -This notebook is about #1. To read instead about #2, see the [notebook on adding primitives](https://jax.readthedocs.io/en/latest/notebooks/How_JAX_primitives_work.html). +This notebook is about #1. To read instead about #2, see the [notebook on adding primitives](https://docs.jax.dev/en/latest/notebooks/How_JAX_primitives_work.html). -For an introduction to JAX's automatic differentiation API, see [The Autodiff Cookbook](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html). This notebook assumes some familiarity with [jax.jvp](https://jax.readthedocs.io/en/latest/jax.html#jax.jvp) and [jax.grad](https://jax.readthedocs.io/en/latest/jax.html#jax.grad), and the mathematical meaning of JVPs and VJPs. +For an introduction to JAX's automatic differentiation API, see [The Autodiff Cookbook](https://docs.jax.dev/en/latest/notebooks/autodiff_cookbook.html). This notebook assumes some familiarity with [jax.jvp](https://docs.jax.dev/en/latest/jax.html#jax.jvp) and [jax.grad](https://docs.jax.dev/en/latest/jax.html#jax.grad), and the mathematical meaning of JVPs and VJPs. +++ {"id": "9Fg3NFNY-2RY"} @@ -1048,7 +1048,7 @@ Array(-0.91113025, dtype=float32) ### Working with `list` / `tuple` / `dict` containers (and other pytrees) -You should expect standard Python containers like lists, tuples, namedtuples, and dicts to just work, along with nested versions of those. In general, any [pytrees](https://jax.readthedocs.io/en/latest/pytrees.html) are permissible, so long as their structures are consistent according to the type constraints. +You should expect standard Python containers like lists, tuples, namedtuples, and dicts to just work, along with nested versions of those. In general, any [pytrees](https://docs.jax.dev/en/latest/pytrees.html) are permissible, so long as their structures are consistent according to the type constraints. Here's a contrived example with `jax.custom_jvp`: diff --git a/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb b/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb index 8abee469d552..90d92c4ea241 100644 --- a/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb +++ b/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb @@ -1276,7 +1276,7 @@ "id": "3qfPjJdhgerc" }, "source": [ - "So computation follows data placement: when we explicitly shard data with `jax.device_put`, and apply functions to that data, the compiler attempts to parallelize the computation and decide the output sharding. This policy for sharded data is a generalization of [JAX's policy of following explicit device placement](https://jax.readthedocs.io/en/latest/faq.html#controlling-data-and-computation-placement-on-devices)." + "So computation follows data placement: when we explicitly shard data with `jax.device_put`, and apply functions to that data, the compiler attempts to parallelize the computation and decide the output sharding. This policy for sharded data is a generalization of [JAX's policy of following explicit device placement](https://docs.jax.dev/en/latest/faq.html#controlling-data-and-computation-placement-on-devices)." 
] }, { @@ -1382,7 +1382,7 @@ "id": "6ZYcK8eXrn0p" }, "source": [ - "We say arrays that have been explicitly placed or sharded with `jax.device_put` are _committed_ to their device(s), and so won't be automatically moved. See the [device placement FAQ](https://jax.readthedocs.io/en/latest/faq.html#controlling-data-and-computation-placement-on-devices) for more information.\n", + "We say arrays that have been explicitly placed or sharded with `jax.device_put` are _committed_ to their device(s), and so won't be automatically moved. See the [device placement FAQ](https://docs.jax.dev/en/latest/faq.html#controlling-data-and-computation-placement-on-devices) for more information.\n", "\n", "When arrays are _not_ explicitly placed or sharded with `jax.device_put`, they are placed _uncommitted_ on the default device.\n", "Unlike committed arrays, uncommitted arrays can be moved and resharded automatically: that is, uncommitted arrays can be arguments to a computation even if other arguments are explicitly placed on different devices.\n", @@ -2339,7 +2339,7 @@ "source": [ "### Generating random numbers\n", "\n", - "JAX comes with a functional, deterministic [random number generator](https://jax.readthedocs.io/en/latest/jep/263-prng.html). It underlies the various sampling functions in the [`jax.random` module](https://jax.readthedocs.io/en/latest/jax.random.html), such as `jax.random.uniform`.\n", + "JAX comes with a functional, deterministic [random number generator](https://docs.jax.dev/en/latest/jep/263-prng.html). It underlies the various sampling functions in the [`jax.random` module](https://docs.jax.dev/en/latest/jax.random.html), such as `jax.random.uniform`.\n", "\n", "JAX's random numbers are produced by a counter-based PRNG, so in principle, random number generation should be a pure map over counter values. A pure map is a trivially partitionable operation in principle. It should require no cross-device communication, nor any redundant computation across devices.\n", "\n", diff --git a/docs/notebooks/Distributed_arrays_and_automatic_parallelization.md b/docs/notebooks/Distributed_arrays_and_automatic_parallelization.md index c207f0ae4a00..79990fefb95d 100644 --- a/docs/notebooks/Distributed_arrays_and_automatic_parallelization.md +++ b/docs/notebooks/Distributed_arrays_and_automatic_parallelization.md @@ -427,7 +427,7 @@ jax.debug.visualize_array_sharding(w_copy) +++ {"id": "3qfPjJdhgerc"} -So computation follows data placement: when we explicitly shard data with `jax.device_put`, and apply functions to that data, the compiler attempts to parallelize the computation and decide the output sharding. This policy for sharded data is a generalization of [JAX's policy of following explicit device placement](https://jax.readthedocs.io/en/latest/faq.html#controlling-data-and-computation-placement-on-devices). +So computation follows data placement: when we explicitly shard data with `jax.device_put`, and apply functions to that data, the compiler attempts to parallelize the computation and decide the output sharding. This policy for sharded data is a generalization of [JAX's policy of following explicit device placement](https://docs.jax.dev/en/latest/faq.html#controlling-data-and-computation-placement-on-devices). +++ {"id": "QRB95LaWuT80"} @@ -484,7 +484,7 @@ except ValueError as e: print_exception(e) +++ {"id": "6ZYcK8eXrn0p"} -We say arrays that have been explicitly placed or sharded with `jax.device_put` are _committed_ to their device(s), and so won't be automatically moved. 
See the [device placement FAQ](https://jax.readthedocs.io/en/latest/faq.html#controlling-data-and-computation-placement-on-devices) for more information. +We say arrays that have been explicitly placed or sharded with `jax.device_put` are _committed_ to their device(s), and so won't be automatically moved. See the [device placement FAQ](https://docs.jax.dev/en/latest/faq.html#controlling-data-and-computation-placement-on-devices) for more information. When arrays are _not_ explicitly placed or sharded with `jax.device_put`, they are placed _uncommitted_ on the default device. Unlike committed arrays, uncommitted arrays can be moved and resharded automatically: that is, uncommitted arrays can be arguments to a computation even if other arguments are explicitly placed on different devices. @@ -854,7 +854,7 @@ outputId: 479c4d81-cb0b-40a5-89ba-394c10dc3297 ### Generating random numbers -JAX comes with a functional, deterministic [random number generator](https://jax.readthedocs.io/en/latest/jep/263-prng.html). It underlies the various sampling functions in the [`jax.random` module](https://jax.readthedocs.io/en/latest/jax.random.html), such as `jax.random.uniform`. +JAX comes with a functional, deterministic [random number generator](https://docs.jax.dev/en/latest/jep/263-prng.html). It underlies the various sampling functions in the [`jax.random` module](https://docs.jax.dev/en/latest/jax.random.html), such as `jax.random.uniform`. JAX's random numbers are produced by a counter-based PRNG, so in principle, random number generation should be a pure map over counter values. A pure map is a trivially partitionable operation in principle. It should require no cross-device communication, nor any redundant computation across devices. diff --git a/docs/notebooks/Neural_Network_and_Data_Loading.ipynb b/docs/notebooks/Neural_Network_and_Data_Loading.ipynb index a7ef2a017048..4c9b6c5e48a7 100644 --- a/docs/notebooks/Neural_Network_and_Data_Loading.ipynb +++ b/docs/notebooks/Neural_Network_and_Data_Loading.ipynb @@ -41,7 +41,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 1, "metadata": { "id": "OksHydJDtbbI" }, @@ -64,7 +64,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 2, "metadata": { "id": "-fmWA06xYE7d" }, @@ -102,7 +102,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 3, "metadata": { "id": "7APc6tD7TiuZ" }, @@ -136,7 +136,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 4, "metadata": { "id": "4sW2A5mnXHc5", "outputId": "9d3b29e8-fab3-4ecb-9f63-bc8c092f9006" @@ -159,7 +159,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 5, "metadata": { "id": "PpyQxuedXfhp", "outputId": "d5d20211-b6da-44e9-f71e-946f2a9d0fc4" @@ -184,7 +184,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 6, "metadata": { "id": "oJOOncKMXbwK", "outputId": "31285fab-7667-4871-fcba-28e86adc3fc6" @@ -229,7 +229,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 7, "metadata": { "id": "6lTI6I4lWdh5" }, @@ -268,21 +268,37 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 8, "metadata": { "id": "gEvWt8_u2pqG", "outputId": "2c83a679-9ce5-4c67-bccb-9ea835a8eaf6" }, "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/m/.opt/miniforge3/envs/jax/lib/python3.12/pty.py:95: RuntimeWarning: os.fork() was called. 
os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n", + " pid, fd = os.forkpty()\n" + ] + }, { "name": "stdout", "output_type": "stream", "text": [ - "Requirement already satisfied: torch in /opt/anaconda3/lib/python3.7/site-packages (1.4.0)\n", - "Requirement already satisfied: torchvision in /opt/anaconda3/lib/python3.7/site-packages (0.5.0)\n", - "Requirement already satisfied: numpy in /opt/anaconda3/lib/python3.7/site-packages (from torchvision) (1.17.2)\n", - "Requirement already satisfied: six in /opt/anaconda3/lib/python3.7/site-packages (from torchvision) (1.12.0)\n", - "Requirement already satisfied: pillow>=4.1.1 in /opt/anaconda3/lib/python3.7/site-packages (from torchvision) (6.2.0)\n" + "Requirement already satisfied: torch in /home/m/.opt/miniforge3/envs/jax/lib/python3.12/site-packages (2.4.1)\n", + "Requirement already satisfied: torchvision in /home/m/.opt/miniforge3/envs/jax/lib/python3.12/site-packages (0.19.1)\n", + "Requirement already satisfied: filelock in /home/m/.opt/miniforge3/envs/jax/lib/python3.12/site-packages (from torch) (3.16.0)\n", + "Requirement already satisfied: typing-extensions>=4.8.0 in /home/m/.opt/miniforge3/envs/jax/lib/python3.12/site-packages (from torch) (4.12.2)\n", + "Requirement already satisfied: sympy in /home/m/.opt/miniforge3/envs/jax/lib/python3.12/site-packages (from torch) (1.13.2)\n", + "Requirement already satisfied: networkx in /home/m/.opt/miniforge3/envs/jax/lib/python3.12/site-packages (from torch) (3.3)\n", + "Requirement already satisfied: jinja2 in /home/m/.opt/miniforge3/envs/jax/lib/python3.12/site-packages (from torch) (3.1.4)\n", + "Requirement already satisfied: fsspec in /home/m/.opt/miniforge3/envs/jax/lib/python3.12/site-packages (from torch) (2024.9.0)\n", + "Requirement already satisfied: setuptools in /home/m/.opt/miniforge3/envs/jax/lib/python3.12/site-packages (from torch) (73.0.1)\n", + "Requirement already satisfied: numpy in /home/m/.opt/miniforge3/envs/jax/lib/python3.12/site-packages (from torchvision) (1.26.4)\n", + "Requirement already satisfied: pillow!=8.3.*,>=5.3.0 in /home/m/.opt/miniforge3/envs/jax/lib/python3.12/site-packages (from torchvision) (10.4.0)\n", + "Requirement already satisfied: MarkupSafe>=2.0 in /home/m/.opt/miniforge3/envs/jax/lib/python3.12/site-packages (from jinja2->torch) (2.1.5)\n", + "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /home/m/.opt/miniforge3/envs/jax/lib/python3.12/site-packages (from sympy->torch) (1.3.0)\n" ] } ], @@ -292,7 +308,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 9, "metadata": { "cellView": "both", "id": "94PjXZ8y3dVF" @@ -301,38 +317,24 @@ "source": [ "import numpy as np\n", "from jax.tree_util import tree_map\n", - "from torch.utils import data\n", + "from torch.utils.data import DataLoader, default_collate\n", "from torchvision.datasets import MNIST\n", "\n", "def numpy_collate(batch):\n", - " return tree_map(np.asarray, data.default_collate(batch))\n", - "\n", - "class NumpyLoader(data.DataLoader):\n", - " def __init__(self, dataset, batch_size=1,\n", - " shuffle=False, sampler=None,\n", - " batch_sampler=None, num_workers=0,\n", - " pin_memory=False, drop_last=False,\n", - " timeout=0, worker_init_fn=None):\n", - " super(self.__class__, self).__init__(dataset,\n", - " batch_size=batch_size,\n", - " shuffle=shuffle,\n", - " sampler=sampler,\n", - " batch_sampler=batch_sampler,\n", - " num_workers=num_workers,\n", - " collate_fn=numpy_collate,\n", 
- " pin_memory=pin_memory,\n", - " drop_last=drop_last,\n", - " timeout=timeout,\n", - " worker_init_fn=worker_init_fn)\n", + " \"\"\"\n", + " Collate function specifies how to combine a list of data samples into a batch.\n", + " default_collate creates pytorch tensors, then tree_map converts them into numpy arrays.\n", + " \"\"\"\n", + " return tree_map(np.asarray, default_collate(batch))\n", "\n", - "class FlattenAndCast(object):\n", - " def __call__(self, pic):\n", - " return np.ravel(np.array(pic, dtype=jnp.float32))" + "def flatten_and_cast(pic):\n", + " \"\"\"Convert PIL image to flat (1-dimensional) numpy array.\"\"\"\n", + " return np.ravel(np.array(pic, dtype=jnp.float32))" ] }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 10, "metadata": { "id": "l314jsfP4TN4" }, @@ -341,108 +343,110 @@ "name": "stdout", "output_type": "stream", "text": [ - "Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to /tmp/mnist/MNIST/raw/train-images-idx3-ubyte.gz\n" + "Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz\n", + "Failed to download (trying next):\n", + "HTTP Error 404: Not Found\n", + "\n", + "Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz\n", + "Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to /tmp/mnist/MNIST/raw/train-images-idx3-ubyte.gz\n" ] }, { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "75806ce83ace4f69b81bbc4251c5573f", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))" - ] - }, - "metadata": {}, - "output_type": "display_data" + "name": "stderr", + "output_type": "stream", + "text": [ + "100.0%\n" + ] }, { "name": "stdout", "output_type": "stream", "text": [ "Extracting /tmp/mnist/MNIST/raw/train-images-idx3-ubyte.gz to /tmp/mnist/MNIST/raw\n", - "Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to /tmp/mnist/MNIST/raw/train-labels-idx1-ubyte.gz\n" + "\n", + "Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz\n", + "Failed to download (trying next):\n", + "HTTP Error 404: Not Found\n", + "\n", + "Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz\n", + "Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to /tmp/mnist/MNIST/raw/train-labels-idx1-ubyte.gz\n" ] }, { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "274ed4ab05f34f70b7a5bb6cf427ffd0", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))" - ] - }, - "metadata": {}, - "output_type": "display_data" + "name": "stderr", + "output_type": "stream", + "text": [ + "100.0%\n" + ] }, { "name": "stdout", "output_type": "stream", "text": [ "Extracting /tmp/mnist/MNIST/raw/train-labels-idx1-ubyte.gz to /tmp/mnist/MNIST/raw\n", - "Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to /tmp/mnist/MNIST/raw/t10k-images-idx3-ubyte.gz\n" + "\n", + "Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz\n", + "Failed to download (trying next):\n", + "HTTP Error 404: Not Found\n", + "\n", + "Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz\n", + "Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to 
/tmp/mnist/MNIST/raw/t10k-images-idx3-ubyte.gz\n" ] }, { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "d38fa4eabf3c4d4494eb59e078ac94e8", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))" - ] - }, - "metadata": {}, - "output_type": "display_data" + "name": "stderr", + "output_type": "stream", + "text": [ + "100.0%\n" + ] }, { "name": "stdout", "output_type": "stream", "text": [ "Extracting /tmp/mnist/MNIST/raw/t10k-images-idx3-ubyte.gz to /tmp/mnist/MNIST/raw\n", - "Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to /tmp/mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz\n" + "\n", + "Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz\n", + "Failed to download (trying next):\n", + "HTTP Error 404: Not Found\n", + "\n", + "Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz\n", + "Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to /tmp/mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz\n" ] }, { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "523ac9565c5f4509a1ee8fdbb1e6d66d", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))" - ] - }, - "metadata": {}, - "output_type": "display_data" + "name": "stderr", + "output_type": "stream", + "text": [ + "100.0%" + ] }, { "name": "stdout", "output_type": "stream", "text": [ "Extracting /tmp/mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz to /tmp/mnist/MNIST/raw\n", - "Processing...\n", - "Done!\n" + "\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" ] } ], "source": [ "# Define our dataset, using torch datasets\n", - "mnist_dataset = MNIST('/tmp/mnist/', download=True, transform=FlattenAndCast())\n", - "training_generator = NumpyLoader(mnist_dataset, batch_size=batch_size, num_workers=0)" + "mnist_dataset = MNIST('/tmp/mnist/', download=True, transform=flatten_and_cast)\n", + "# Create pytorch data loader with custom collate function\n", + "training_generator = DataLoader(mnist_dataset, batch_size=batch_size, collate_fn=numpy_collate)" ] }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 11, "metadata": { "id": "FTNo4beUvb6t", "outputId": "65a9087c-c326-49e5-cbfc-e0839212fa31" @@ -452,27 +456,13 @@ "name": "stderr", "output_type": "stream", "text": [ - "/opt/anaconda3/lib/python3.7/site-packages/torchvision/datasets/mnist.py:55: UserWarning: train_data has been renamed data\n", + "/home/m/.opt/miniforge3/envs/jax/lib/python3.12/site-packages/torchvision/datasets/mnist.py:76: UserWarning: train_data has been renamed data\n", " warnings.warn(\"train_data has been renamed data\")\n", - "/opt/anaconda3/lib/python3.7/site-packages/torchvision/datasets/mnist.py:45: UserWarning: train_labels has been renamed targets\n", - " warnings.warn(\"train_labels has been renamed targets\")\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/opt/anaconda3/lib/python3.7/site-packages/torchvision/datasets/mnist.py:60: UserWarning: test_data has been renamed data\n", + "/home/m/.opt/miniforge3/envs/jax/lib/python3.12/site-packages/torchvision/datasets/mnist.py:66: UserWarning: train_labels has been renamed targets\n", + " warnings.warn(\"train_labels has been renamed 
targets\")\n", + "/home/m/.opt/miniforge3/envs/jax/lib/python3.12/site-packages/torchvision/datasets/mnist.py:81: UserWarning: test_data has been renamed data\n", " warnings.warn(\"test_data has been renamed data\")\n", - "/opt/anaconda3/lib/python3.7/site-packages/torchvision/datasets/mnist.py:50: UserWarning: test_labels has been renamed targets\n", + "/home/m/.opt/miniforge3/envs/jax/lib/python3.12/site-packages/torchvision/datasets/mnist.py:71: UserWarning: test_labels has been renamed targets\n", " warnings.warn(\"test_labels has been renamed targets\")\n" ] } @@ -499,7 +489,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 12, "metadata": { "id": "X2DnZo3iYj18", "outputId": "0eba3ca2-24a1-4cba-aaf4-3ac61d0c650e" @@ -509,30 +499,30 @@ "name": "stdout", "output_type": "stream", "text": [ - "Epoch 0 in 55.15 sec\n", - "Training set accuracy 0.9157500267028809\n", - "Test set accuracy 0.9195000529289246\n", - "Epoch 1 in 42.26 sec\n", - "Training set accuracy 0.9372166991233826\n", - "Test set accuracy 0.9384000301361084\n", - "Epoch 2 in 44.37 sec\n", - "Training set accuracy 0.9491666555404663\n", - "Test set accuracy 0.9469000697135925\n", - "Epoch 3 in 41.75 sec\n", - "Training set accuracy 0.9568166732788086\n", - "Test set accuracy 0.9534000158309937\n", - "Epoch 4 in 41.16 sec\n", - "Training set accuracy 0.9631333351135254\n", - "Test set accuracy 0.9577000737190247\n", - "Epoch 5 in 38.89 sec\n", + "Epoch 0 in 5.53 sec\n", + "Training set accuracy 0.9156666994094849\n", + "Test set accuracy 0.9199000000953674\n", + "Epoch 1 in 1.13 sec\n", + "Training set accuracy 0.9370499849319458\n", + "Test set accuracy 0.9383999705314636\n", + "Epoch 2 in 1.12 sec\n", + "Training set accuracy 0.9490833282470703\n", + "Test set accuracy 0.9467999935150146\n", + "Epoch 3 in 1.21 sec\n", + "Training set accuracy 0.9568833708763123\n", + "Test set accuracy 0.9532999992370605\n", + "Epoch 4 in 1.17 sec\n", + "Training set accuracy 0.9631666541099548\n", + "Test set accuracy 0.9574999809265137\n", + "Epoch 5 in 1.17 sec\n", "Training set accuracy 0.9675000309944153\n", - "Test set accuracy 0.9616000652313232\n", - "Epoch 6 in 40.68 sec\n", - "Training set accuracy 0.9708333611488342\n", - "Test set accuracy 0.9650000333786011\n", - "Epoch 7 in 41.50 sec\n", - "Training set accuracy 0.973716676235199\n", - "Test set accuracy 0.9672000408172607\n" + "Test set accuracy 0.9615999460220337\n", + "Epoch 6 in 1.11 sec\n", + "Training set accuracy 0.9709500074386597\n", + "Test set accuracy 0.9652999639511108\n", + "Epoch 7 in 1.17 sec\n", + "Training set accuracy 0.9736999869346619\n", + "Test set accuracy 0.967199981212616\n" ] } ], @@ -576,7 +566,7 @@ "formats": "ipynb,md:myst" }, "kernelspec": { - "display_name": "Python 3", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, @@ -590,9 +580,9 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.4" + "version": "3.12.3" } }, "nbformat": 4, - "nbformat_minor": 1 + "nbformat_minor": 4 } diff --git a/docs/notebooks/Neural_Network_and_Data_Loading.md b/docs/notebooks/Neural_Network_and_Data_Loading.md index cd98022e7421..bcc4019d6da0 100644 --- a/docs/notebooks/Neural_Network_and_Data_Loading.md +++ b/docs/notebooks/Neural_Network_and_Data_Loading.md @@ -7,7 +7,7 @@ jupytext: format_version: 0.13 jupytext_version: 1.16.4 kernelspec: - display_name: Python 3 + display_name: Python 3 (ipykernel) language: python name: python3 --- @@ -192,41 +192,28 @@ 
JAX is laser-focused on program transformations and accelerator-backed NumPy, so import numpy as np from jax.tree_util import tree_map -from torch.utils import data +from torch.utils.data import DataLoader, default_collate from torchvision.datasets import MNIST def numpy_collate(batch): - return tree_map(np.asarray, data.default_collate(batch)) - -class NumpyLoader(data.DataLoader): - def __init__(self, dataset, batch_size=1, - shuffle=False, sampler=None, - batch_sampler=None, num_workers=0, - pin_memory=False, drop_last=False, - timeout=0, worker_init_fn=None): - super(self.__class__, self).__init__(dataset, - batch_size=batch_size, - shuffle=shuffle, - sampler=sampler, - batch_sampler=batch_sampler, - num_workers=num_workers, - collate_fn=numpy_collate, - pin_memory=pin_memory, - drop_last=drop_last, - timeout=timeout, - worker_init_fn=worker_init_fn) - -class FlattenAndCast(object): - def __call__(self, pic): - return np.ravel(np.array(pic, dtype=jnp.float32)) + """ + Collate function specifies how to combine a list of data samples into a batch. + default_collate creates pytorch tensors, then tree_map converts them into numpy arrays. + """ + return tree_map(np.asarray, default_collate(batch)) + +def flatten_and_cast(pic): + """Convert PIL image to flat (1-dimensional) numpy array.""" + return np.ravel(np.array(pic, dtype=jnp.float32)) ``` ```{code-cell} ipython3 :id: l314jsfP4TN4 # Define our dataset, using torch datasets -mnist_dataset = MNIST('/tmp/mnist/', download=True, transform=FlattenAndCast()) -training_generator = NumpyLoader(mnist_dataset, batch_size=batch_size, num_workers=0) +mnist_dataset = MNIST('/tmp/mnist/', download=True, transform=flatten_and_cast) +# Create pytorch data loader with custom collate function +training_generator = DataLoader(mnist_dataset, batch_size=batch_size, collate_fn=numpy_collate) ``` ```{code-cell} ipython3 diff --git a/docs/notebooks/README.md b/docs/notebooks/README.md index 07be4441ade8..c945c197ad19 100644 --- a/docs/notebooks/README.md +++ b/docs/notebooks/README.md @@ -1,2 +1,2 @@ For instructions on how to change and test notebooks, see -[Update Documentation](https://jax.readthedocs.io/en/latest/developer.html#update-documentation). +[Update Documentation](https://docs.jax.dev/en/latest/developer.html#update-documentation). diff --git a/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb b/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb index 00ba9186eeec..d22457c5d718 100644 --- a/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb +++ b/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb @@ -24,7 +24,7 @@ "\n", "Here we show how to add your own function transformations to the system, by writing a custom Jaxpr interpreter. And we'll get composability with all the other transformations for free.\n", "\n", - "**This example uses internal JAX APIs, which may break at any time. Anything not in [the API Documentation](https://jax.readthedocs.io/en/latest/jax.html) should be assumed internal.**" + "**This example uses internal JAX APIs, which may break at any time. 
Anything not in [the API Documentation](https://docs.jax.dev/en/latest/jax.html) should be assumed internal.**" ] }, { @@ -215,8 +215,8 @@ "# Importing Jax functions useful for tracing/interpreting.\n", "from functools import wraps\n", "\n", - "from jax import core\n", "from jax import lax\n", + "from jax.extend import core\n", "from jax._src.util import safe_map" ] }, diff --git a/docs/notebooks/Writing_custom_interpreters_in_Jax.md b/docs/notebooks/Writing_custom_interpreters_in_Jax.md index 10c4e7cb6e3b..ad707a9746fc 100644 --- a/docs/notebooks/Writing_custom_interpreters_in_Jax.md +++ b/docs/notebooks/Writing_custom_interpreters_in_Jax.md @@ -27,7 +27,7 @@ etc.) that enable writing concise, accelerated code. Here we show how to add your own function transformations to the system, by writing a custom Jaxpr interpreter. And we'll get composability with all the other transformations for free. -**This example uses internal JAX APIs, which may break at any time. Anything not in [the API Documentation](https://jax.readthedocs.io/en/latest/jax.html) should be assumed internal.** +**This example uses internal JAX APIs, which may break at any time. Anything not in [the API Documentation](https://docs.jax.dev/en/latest/jax.html) should be assumed internal.** ```{code-cell} ipython3 :id: s27RDKvKXFL8 @@ -147,8 +147,8 @@ Let's use `make_jaxpr` to trace a function into a Jaxpr. # Importing Jax functions useful for tracing/interpreting. from functools import wraps -from jax import core from jax import lax +from jax.extend import core from jax._src.util import safe_map ``` diff --git a/docs/notebooks/autodiff_remat.ipynb b/docs/notebooks/autodiff_remat.ipynb index feb906546341..d8a74e4b15fd 100644 --- a/docs/notebooks/autodiff_remat.ipynb +++ b/docs/notebooks/autodiff_remat.ipynb @@ -348,7 +348,7 @@ "source": [ "### Let's think step by step\n", "\n", - "You might want to first (re)read [the Autodiff Cookbook Part 1](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html)." + "You might want to first (re)read [the Autodiff Cookbook Part 1](https://docs.jax.dev/en/latest/notebooks/autodiff_cookbook.html)." ] }, { diff --git a/docs/notebooks/autodiff_remat.md b/docs/notebooks/autodiff_remat.md index 8ba87dcfee18..12564bd91f30 100644 --- a/docs/notebooks/autodiff_remat.md +++ b/docs/notebooks/autodiff_remat.md @@ -156,7 +156,7 @@ print_fwd_bwd(f3, W1, W2, W3, x) ### Let's think step by step -You might want to first (re)read [the Autodiff Cookbook Part 1](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html). +You might want to first (re)read [the Autodiff Cookbook Part 1](https://docs.jax.dev/en/latest/notebooks/autodiff_cookbook.html). +++ {"id": "VMfwm_yinvoZ"} diff --git a/docs/notebooks/explicit-sharding.ipynb b/docs/notebooks/explicit-sharding.ipynb index d656e12d4068..37010b4ab3d3 100644 --- a/docs/notebooks/explicit-sharding.ipynb +++ b/docs/notebooks/explicit-sharding.ipynb @@ -28,7 +28,7 @@ "of work and it's also easy to make mistakes that way because there's no way to\n", "check that the shardings make sense together. More commonly, people add just\n", "enough sharding annotations to constrain the compiler. But this is a slow\n", - "iterative process. It's hard to know ahead of time what XLA's gSPMD pass will\n", + "iterative process. 
It's hard to know ahead of time what XLA's GSPMD pass will\n", "do (it's a whole-program optimization) so all you can do is add annotations,\n", "inspect XLA's sharding choices to see what happened, and repeat.\n", "\n", @@ -59,7 +59,7 @@ "import numpy as np\n", "import jax.numpy as jnp\n", "from jax.sharding import PartitionSpec as P, AxisType, set_mesh, get_abstract_mesh\n", - "from jax.experimental.shard import reshard, auto_axes\n", + "from jax.experimental.shard import reshard, auto_axes, explicit_axes\n", "\n", "jax.config.update('jax_num_cpu_devices', 8)" ] @@ -397,7 +397,7 @@ " which the split/merged axes are sharded as None then we shard the\n", " resulting split/merged axes as None and the other axes according to their\n", " corresponding input axis shardings. In all other cases we throw an error\n", - " and require the user to provide an `out_shardings` argument." + " and require the user to provide an `out_sharding` argument." ] }, { @@ -414,7 +414,7 @@ "wherever types need to match. For example, the two sides of a `lax.cond` need to\n", "have results with matching shardings. And the carry of `lax.scan` needs to have the\n", "same sharding at the input and the output of the scan body. And when you\n", - "contruct a jaxpr without concrete arguments using `make_jaxpr` you need to\n", + "construct a jaxpr without concrete arguments using `make_jaxpr` you need to\n", "provide shardings too. Certain JAX transformations perform type-level\n", "operations. Automatic differentation constructs a tangent type for each primal\n", "type in the original computation (e.g. `TangentOf(float) == float`,\n", @@ -494,7 +494,7 @@ " print(f\"We're in auto-sharding mode here. This is the current mesh: {get_abstract_mesh()}\")\n", " return x + y\n", "\n", - "result = add_with_out_sharding_kwarg(some_x, some_y, out_shardings=P(\"X\", None))\n", + "result = add_with_out_sharding_kwarg(some_x, some_y, out_sharding=P(\"X\", None))\n", "print(f\"Result type: {jax.typeof(result)}\")" ] }, @@ -527,11 +527,11 @@ "\n", "A summary table:\n", "\n", - "| Mode | Explicit sharding? | Explicit Collectives? |\n", - "|---|---|---|\n", - "| Auto | No | No |\n", - "| Explicit (new) | Yes | No |\n", - "| Manual | Yes | Yes |\n", + "| Mode | View? | Explicit sharding? | Explicit Collectives? |\n", + "|---|---|---|---|\n", + "| Auto | Global | ❌ | ❌ |\n", + "| Explicit | Global | ✅ | ❌ |\n", + "| Manual | Per-device | ✅ | ✅ |\n", "\n", "The current mesh tells us which sharding mode we're in. We can query it with\n", "`get_abstract_mesh`:" @@ -637,7 +637,7 @@ " x = jnp.sin(arr1)\n", " print(f'x.sharding: {jax.typeof(x)}', end='\\n\\n')\n", "\n", - " z = g(x, out_shardings=P(\"X\", \"Y\"))\n", + " z = g(x, out_sharding=P(\"X\", \"Y\"))\n", "\n", " print(f'z.sharding: {jax.typeof(z)}', end=\"\\n\\n\")\n", " return z + 1\n", @@ -652,7 +652,51 @@ "id": "_3sfJjRq8w9f" }, "source": [ - "As you can see, inside `g`, the type of `arr1` is `ShapedArray(float32[4,4@Y])` which indicates it's Explicit over `Y` mesh axis while auto over `X`." + "As you can see, inside `g`, the type of `arr1` is `ShapedArray(float32[4,4@Y])` which indicates it's Explicit over `Y` mesh axis while auto over `X`.\n", + "\n", + "\n", + "You can also use the `explicit_axes` API to drop into `Explicit` mode over some or all mesh axes." 
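The tutorial's own example, which makes both mesh axes explicit, follows below. As a complementary sketch (not from the tutorial), the same decorator can plausibly be restricted to a subset of axes via its `axes` argument, mirroring the `auto_axes` subset used earlier; the names `partly_explicit`, `f2`, and `auto_mesh2` are illustrative, and the single-axis call pattern (including passing `in_sharding=P(None, "Y")`) is an assumption rather than documented behavior.

```python
# Hedged sketch, not from the tutorial: assume `axes` may name a subset of the
# mesh axes, mirroring the auto_axes subset shown earlier. Names here
# (partly_explicit, f2, auto_mesh2) are illustrative only.
import functools

auto_mesh2 = jax.make_mesh((2, 4), ("X", "Y"),
                           axis_types=(AxisType.Auto, AxisType.Auto))

@functools.partial(explicit_axes, axes=("Y",))
def partly_explicit(y):
  # Only "Y" is explicit here, so only "Y" should appear in the array types.
  print(f"{jax.typeof(y) = }")
  return y * 2

@jax.jit
def f2(arr):
  x = jnp.sin(arr)
  # Assumption: in_sharding mentions only the axes being made explicit ("Y").
  return partly_explicit(x, in_sharding=P(None, "Y")) + 1

with jax.sharding.use_mesh(auto_mesh2):
  f2(jax.device_put(np.arange(16.).reshape(4, 4), P("X", "Y")))
```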
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a102e9c7", + "metadata": {}, + "outputs": [], + "source": [ + "auto_mesh = jax.make_mesh((2, 4), (\"X\", \"Y\"),\n", + " axis_types=(AxisType.Auto, AxisType.Auto))\n", + "\n", + "@functools.partial(explicit_axes, axes=('X', 'Y'))\n", + "def explicit_g(y):\n", + " print(f'mesh inside g: {get_abstract_mesh()}')\n", + " print(f'y.sharding inside g: {jax.typeof(y) = }')\n", + " z = y * 2\n", + " print(f'z.sharding inside g: {jax.typeof(z) = }', end='\\n\\n')\n", + " return z\n", + "\n", + "@jax.jit\n", + "def f(arr1):\n", + " print(f'mesh inside f: {get_abstract_mesh()}', end='\\n\\n')\n", + " x = jnp.sin(arr1)\n", + "\n", + " z = explicit_g(x, in_sharding=P(\"X\", \"Y\"))\n", + "\n", + " return z + 1\n", + "\n", + "with jax.sharding.use_mesh(auto_mesh):\n", + " some_x = jax.device_put(np.arange(16).reshape(4, 4), P(\"X\", \"Y\"))\n", + " f(some_x)" + ] + }, + { + "cell_type": "markdown", + "id": "e64d40de", + "metadata": {}, + "source": [ + "As you can see, all axes of mesh inside `f` are of type `Auto` while inside `g`, they are of type `Explicit`.\n", + "Because of that, sharding is visible on the type of arrays inside `g`." ] }, { @@ -734,7 +778,7 @@ " compare_shardings(x)\n", " return x\n", "\n", - "check_in_auto_context(my_array, out_shardings=P(\"X\"))" + "check_in_auto_context(my_array, out_sharding=P(\"X\"))" ] }, { diff --git a/docs/notebooks/explicit-sharding.md b/docs/notebooks/explicit-sharding.md index 7c59a675d8ec..8989d426ffbc 100644 --- a/docs/notebooks/explicit-sharding.md +++ b/docs/notebooks/explicit-sharding.md @@ -31,7 +31,7 @@ constraints? You could put them on every single intermediate but that's a lot of work and it's also easy to make mistakes that way because there's no way to check that the shardings make sense together. More commonly, people add just enough sharding annotations to constrain the compiler. But this is a slow -iterative process. It's hard to know ahead of time what XLA's gSPMD pass will +iterative process. It's hard to know ahead of time what XLA's GSPMD pass will do (it's a whole-program optimization) so all you can do is add annotations, inspect XLA's sharding choices to see what happened, and repeat. @@ -56,7 +56,7 @@ import jax import numpy as np import jax.numpy as jnp from jax.sharding import PartitionSpec as P, AxisType, set_mesh, get_abstract_mesh -from jax.experimental.shard import reshard, auto_axes +from jax.experimental.shard import reshard, auto_axes, explicit_axes jax.config.update('jax_num_cpu_devices', 8) ``` @@ -239,7 +239,7 @@ Here are some example sharding rules: which the split/merged axes are sharded as None then we shard the resulting split/merged axes as None and the other axes according to their corresponding input axis shardings. In all other cases we throw an error - and require the user to provide an `out_shardings` argument. + and require the user to provide an `out_sharding` argument. +++ {"id": "jZMp6w48Xmd7"} @@ -251,7 +251,7 @@ sharding is part of that type. This means that shardings need to match wherever types need to match. For example, the two sides of a `lax.cond` need to have results with matching shardings. And the carry of `lax.scan` needs to have the same sharding at the input and the output of the scan body. And when you -contruct a jaxpr without concrete arguments using `make_jaxpr` you need to +construct a jaxpr without concrete arguments using `make_jaxpr` you need to provide shardings too. 
Certain JAX transformations perform type-level operations. Automatic differentation constructs a tangent type for each primal type in the original computation (e.g. `TangentOf(float) == float`, @@ -308,7 +308,7 @@ def add_with_out_sharding_kwarg(x, y): print(f"We're in auto-sharding mode here. This is the current mesh: {get_abstract_mesh()}") return x + y -result = add_with_out_sharding_kwarg(some_x, some_y, out_shardings=P("X", None)) +result = add_with_out_sharding_kwarg(some_x, some_y, out_sharding=P("X", None)) print(f"Result type: {jax.typeof(result)}") ``` @@ -337,11 +337,11 @@ JAX now has three styles of parallelism: A summary table: -| Mode | Explicit sharding? | Explicit Collectives? | -|---|---|---| -| Auto | No | No | -| Explicit (new) | Yes | No | -| Manual | Yes | Yes | +| Mode | View? | Explicit sharding? | Explicit Collectives? | +|---|---|---|---| +| Auto | Global | ❌ | ❌ | +| Explicit | Global | ✅ | ❌ | +| Manual | Per-device | ✅ | ✅ | The current mesh tells us which sharding mode we're in. We can query it with `get_abstract_mesh`: @@ -390,7 +390,7 @@ def f(arr1): x = jnp.sin(arr1) print(f'x.sharding: {jax.typeof(x)}', end='\n\n') - z = g(x, out_shardings=P("X", "Y")) + z = g(x, out_sharding=P("X", "Y")) print(f'z.sharding: {jax.typeof(z)}', end="\n\n") return z + 1 @@ -403,6 +403,38 @@ f(some_x) As you can see, inside `g`, the type of `arr1` is `ShapedArray(float32[4,4@Y])` which indicates it's Explicit over `Y` mesh axis while auto over `X`. + +You can also use the `explicit_axes` API to drop into `Explicit` mode over some or all mesh axes. + +```{code-cell} ipython3 +auto_mesh = jax.make_mesh((2, 4), ("X", "Y"), + axis_types=(AxisType.Auto, AxisType.Auto)) + +@functools.partial(explicit_axes, axes=('X', 'Y')) +def explicit_g(y): + print(f'mesh inside g: {get_abstract_mesh()}') + print(f'y.sharding inside g: {jax.typeof(y) = }') + z = y * 2 + print(f'z.sharding inside g: {jax.typeof(z) = }', end='\n\n') + return z + +@jax.jit +def f(arr1): + print(f'mesh inside f: {get_abstract_mesh()}', end='\n\n') + x = jnp.sin(arr1) + + z = explicit_g(x, in_sharding=P("X", "Y")) + + return z + 1 + +with jax.sharding.use_mesh(auto_mesh): + some_x = jax.device_put(np.arange(16).reshape(4, 4), P("X", "Y")) + f(some_x) +``` + +As you can see, all axes of mesh inside `f` are of type `Auto` while inside `g`, they are of type `Explicit`. +Because of that, sharding is visible on the type of arrays inside `g`. + +++ {"id": "sJcWbfAh7UcO"} ## Concrete array shardings can mention `Auto` mesh axis @@ -437,7 +469,7 @@ def check_in_auto_context(x): compare_shardings(x) return x -check_in_auto_context(my_array, out_shardings=P("X")) +check_in_auto_context(my_array, out_sharding=P("X")) ``` +++ {"id": "MRFccsi5X8so"} diff --git a/docs/notebooks/host-offloading.ipynb b/docs/notebooks/host-offloading.ipynb new file mode 100644 index 000000000000..f56cb90ff77e --- /dev/null +++ b/docs/notebooks/host-offloading.ipynb @@ -0,0 +1,522 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "bQbS50fIdHw1" + }, + "source": [ + "(host-offloading)=\n", + "# JAX Memories and Host Offloading\n", + "\n", + "\n", + "\n", + "This tutorial provides a practical introduction to host offloading techniques in JAX, focusing on:\n", + "\n", + "- Activation offloading\n", + "- Parameter offloading\n", + "\n", + "By applying offloading strategies, you can better manage memory resources and reduce memory pressure on your devices. 
To implement these strategies effectively, you'll need to understand JAX's core mechanisms for data placement and movement.\n", + "\n", + "## Building Blocks for Offloading\n", + "\n", + "JAX provides several key components for controlling where and how data are stored and moved between the host and the device memory. In the following sections, you'll explore:\n", + "\n", + "- How to specify data distribution with sharding\n", + "- How to control memory placement between host and device\n", + "- How to manage data movement in jitted functions\n", + "\n", + "### NamedSharding and Memory Kinds\n", + "\n", + "{class}`~jax.sharding.NamedSharding` defines how data are distributed across devices. It includes:\n", + "\n", + "- Basic data distribution configuration\n", + "- `memory_kind` parameter for specifying memory type (`device` or `pinned_host`)\n", + "- By default, `memory_kind` is set to `device` memory\n", + "- `with_memory_kind` method for creating new sharding with modified memory type" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "f-6sxUlqrlBn", + "outputId": "691a3df2-8341-44a9-a4a0-5521c2d891e3" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "NamedSharding(mesh=Mesh('x': 1, 'y': 1), spec=PartitionSpec('x', 'y'), memory_kind=device)\n", + "NamedSharding(mesh=Mesh('x': 1, 'y': 1), spec=PartitionSpec('x', 'y'), memory_kind=pinned_host)\n" + ] + } + ], + "source": [ + "import jax\n", + "import jax.numpy as jnp\n", + "from jax.sharding import Mesh, NamedSharding, PartitionSpec as P\n", + "import numpy as np\n", + "\n", + "# Create mesh\n", + "# 1x1 mesh represents a single device with two named dimensions (x and y)\n", + "mesh = Mesh(np.array(jax.devices()[0]).reshape(1, 1), ('x', 'y'))\n", + "\n", + "# Device sharding - partitions data along x and y dimensions\n", + "s_dev = NamedSharding(mesh, P('x', 'y'), memory_kind=\"device\")\n", + "\n", + "# Host sharding - same partitioning but in pinned host memory\n", + "s_host = s_dev.with_memory_kind('pinned_host')\n", + "\n", + "print(s_dev) # Shows device memory sharding\n", + "print(s_host) # Shows pinned host memory sharding" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "R_pB9465VoMP" + }, + "source": [ + "### Data Placement with device_put\n", + "\n", + "{func}`jax.device_put` is a function that explicitly transfers arrays to a specified memory location according to a sharding specification." 
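The next cell shows the basic single-array case. As a small additional sketch (not part of the tutorial), `jax.device_put` also accepts a pytree, so a whole set of parameters can be placed in pinned host memory in one call; the `params` dictionary below is illustrative, while `s_host` is the sharding defined above.

```python
# Illustrative sketch: device_put applied to a pytree. A single sharding is
# applied to every leaf, so all of these parameters land in pinned host memory.
params = {
    "w": jnp.ones((2, 4)),
    "b": jnp.zeros((2, 4)),
}
params_host = jax.device_put(params, s_host)
print(jax.tree.map(lambda a: a.sharding.memory_kind, params_host))
# expected: {'w': 'pinned_host', 'b': 'pinned_host'}
```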
+ ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "OJFnf7FGp6Lj", + "outputId": "c762e1df-2453-4ed9-9d53-0defb6a05ce2" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "pinned_host\n", + "device\n" + ] + } + ], + "source": [ + "# Create a 2x4 array\n", + "arr = jnp.arange(8.0).reshape(2, 4)\n", + "\n", + "# Move arrays to different memory locations based on sharding objects\n", + "arr_host = jax.device_put(arr, s_host) # Places in pinned host memory\n", + "arr_dev = jax.device_put(arr, s_dev) # Places in device memory\n", + "\n", + "# Verify memory locations\n", + "print(arr_host.sharding.memory_kind) # Output: pinned_host\n", + "print(arr_dev.sharding.memory_kind) # Output: device" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "HHXvBpQKTMCR" + }, + "source": [ + "### Output Sharding Controls\n", + "\n", + "Shardings determine how data is split across devices. JAX provides `out_shardings` to control how output arrays are partitioned when leaving a jitted function.\n", + "\n", + "Key Features:\n", + " - Can differ from input sharding\n", + " - Allows different memory kinds for outputs\n", + "\n", + "Examples:\n", + "\n", + "#### Device Output Sharding" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "ZXNj9NUeaIdX", + "outputId": "399321ef-082a-4a77-c33a-9de3421f429b" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Result value of H2D: \n", + " [[0. 1. 2. 3.]\n", + " [4. 5. 6. 7.]]\n" + ] + } + ], + "source": [ + "f = jax.jit(lambda x:x, out_shardings=s_dev)\n", + "out_dev = f(arr_host)\n", + "print(\"Result value of H2D: \\n\", out_dev)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "iYXC5ix384XP" + }, + "source": [ + "Moving data from host to device memory when needed for computation is the essence of host offloading. Use {func}`jax.device_put` to perform this transfer in this example to optimize performance." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "cmM6tJTS84XQ", + "outputId": "40c353a1-fb55-44bc-bac9-dffc09852f49" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Result value of H2D and add 1 in device memory: \n", + " [[1. 2. 3. 4.]\n", + " [5. 6. 7. 8.]]\n" + ] + } + ], + "source": [ + "# Instead of the lambda function, you can define add_func to explicitly\n", + "# move data to device before computation\n", + "def add_func(x): # Move data to device and add one\n", + " x = jax.device_put(x, s_dev)\n", + " return x + 1\n", + "\n", + "f = jax.jit(add_func, out_shardings=s_dev)\n", + "out_dev = f(arr_host)\n", + "print(\"Result value of H2D and add 1 in device memory: \\n\", out_dev)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "EbE-eBrJTBuS" + }, + "source": [ + "#### Host Output Sharding" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "FjZzkxI8ky4r", + "outputId": "2a1b6e7a-1c29-4347-c020-7b47c27a5cc3" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Result value of D2H: \n", + " [[0. 1. 2. 3.]\n", + " [4. 5. 6. 
7.]]\n" + ] + } + ], + "source": [ + "f = jax.jit(lambda x: x, out_shardings=s_dev)\n", + "out_host = f(arr_host) # Input arrays in the device memory while output arrays in the host memory\n", + "print(\"Result value of D2H: \\n\", out_host)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "UhLVvRO2p6Lj" + }, + "source": [ + "## Activation Offloading\n", + "\n", + "The detailed coverage of activation offloading can be found in the {ref}`gradient-checkpointing` tutorial. Activation offloading helps manage memory by moving intermediate activations to host memory after the forward pass, and bringing them back to device memory during the backward pass when needed for gradient computation.\n", + "\n", + "To implement activation offloading effectively, you need to understand checkpoint names and policies. Here's how they work in a simple example:\n", + "\n", + "### Checkpoint Names\n", + "\n", + "The {func}`checkpoint_name` function allows you to label activations for memory management during computation. Here's a simple example:" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "id": "sLO9ceS6p6Lj" + }, + "outputs": [], + "source": [ + "from jax.ad_checkpoint import checkpoint_name\n", + "\n", + "def layer(x, w):\n", + " w1, w2 = w\n", + " x = checkpoint_name(x, \"x\")\n", + " y = x @ w1\n", + " return y @ w2, None" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "-_T92oCOp6Lk" + }, + "source": [ + "This example shows:\n", + "\n", + "* A simple neural network layer with two matrix multiplications\n", + "* Labeling of input activation x with identifier `\"x\"`\n", + "* Sequential operations:\n", + " 1. First multiplication: `x @ w1`\n", + " 2. Second multiplication: `y @ w2`\n", + "\n", + "The checkpoint name helps the system decide whether to:\n", + "* Keep the activation in device memory or\n", + "* Offload it to host memory during computation\n", + "\n", + "This pattern is common in neural networks, where multiple transformations are applied sequentially to input data.\n", + "\n", + "\n", + "### Checkpoint Policies\n", + "\n", + "The {func}`jax.remat` transformation manages memory by handling intermediate values through three strategies:\n", + "\n", + "1. Recomputing during backward pass (default behavior)\n", + "2. Storing on device\n", + "3. 
Offloading to host memory after forward pass and loading back during backward pass\n", + "\n", + "Example of setting an offloading checkpoint policy:" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "id": "W8Usw_wOp6Lk" + }, + "outputs": [], + "source": [ + "from jax import checkpoint_policies as cp\n", + "\n", + "policy = cp.save_and_offload_only_these_names(\n", + " names_which_can_be_saved=[], # No values stored on device\n", + " names_which_can_be_offloaded=[\"x\"], # Offload activations labeled \"x\"\n", + " offload_src=\"device\", # Move from device memory\n", + " offload_dst=\"pinned_host\" # To pinned host memory\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "iuDRCXu7ky4r" + }, + "source": [ + "Since {func}`jax.lax.scan` is commonly used in JAX for handling sequential operations (like RNNs or transformers), you need to know how to apply your offloading strategy in this context.\n", + "\n", + "Key components:\n", + "* {func}`jax.remat` applies our checkpoint policy to the layer function\n", + "* `prevent_cse=False` enables XLA's common subexpression elimination for better performance\n", + "* {func}`jax.lax.scan` iterates the rematerialized layer along an axis" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "xCrxjTx_p6Lk", + "outputId": "13d46584-9b25-4622-b3c3-f50c1dac02c2" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Sample of results: [3.7363498e-07 3.7363498e-07 3.7363498e-07 3.7363498e-07 3.7363498e-07]\n" + ] + } + ], + "source": [ + "def scanned(w, x):\n", + " remat_layer = jax.remat(layer,\n", + " policy=policy, # Use our offloading policy\n", + " prevent_cse=False) # Allow CSE optimizations\n", + " result = jax.lax.scan(remat_layer, x, w)[0]\n", + " return jnp.sum(result)\n", + "\n", + "# Initialize input and weights with small values (0.0001)\n", + "input = jnp.ones((256, 256), dtype=jnp.float32) * 0.001 # Input matrix: 256 x 256\n", + "w1 = jnp.ones((10, 256, 1024), dtype=jnp.float32) * 0.001 # 10 layers of 256 x 1024 matrices\n", + "w2 = jnp.ones((10, 1024, 256), dtype=jnp.float32) * 0.001 # 10 layers of 1024 x 256 matrices\n", + "\n", + "# Compile and compute gradients of the scanned function\n", + "f = jax.jit(jax.grad(scanned)) # Apply JIT compilation to gradient computation\n", + "result_activation = f((w1, w2), input) # Execute the function with weights and input\n", + "print(\"Sample of results: \", result_activation[0][0, 0, :5])" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "0tx7aara42pY" + }, + "source": [ + "### Summary of Activation Offloading\n", + "\n", + "Activation offloading provides a powerful way to manage memory in large computations by:\n", + "\n", + "* Using checkpoint names to mark specific activations\n", + "* Applying policies to control where and how activations are stored\n", + "* Supporting common JAX patterns like scan operations\n", + "* Moving selected activations to host memory when device memory is under budget\n", + "\n", + "This approach is particularly useful when working with large models that would otherwise exceed device memory capacity.\n", + "\n", + "## Parameter Offloading\n", + "\n", + "Model parameters (also known as weights) can be offloaded to the host memory to optimize device memory usage during initialization. 
This is achieved by using {func}`jax.jit` with a sharding strategy that specifies host memory kind.\n", + "\n", + "While parameter offloading and activation offloading are distinct memory optimization techniques, the following example demonstrates parameter offloading built upon the activation offloading implementation shown earlier.\n", + "\n", + "### Parameter Placement for Computation\n", + "\n", + "Different from the earlier `layer` function, {func}`jax.device_put` is applied to move parameter `w1` and `w2` to the device before the matrix multiplications. This ensures the parameters are available on the device for both forward and backward passes.\n", + "\n", + "Note that the activation offloading implementation remains unchanged, using the same:\n", + "* Checkpoint name `\"x\"`\n", + "* Checkpoint policy\n", + "* `scanned` function combining {func}`jax.remat` and {func}`jax.lax.scan`\n", + "\n", + "### Parameter Initialization with Host Offloading\n", + "\n", + "During the initialization, parameter `w1` and `w2` are placed on host memory before being passed to the {func}`jax.jit` function `f`, while keeping the `input` variable on the device." + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "1qGN2hBQdheo", + "outputId": "48c09658-f8b6-4be3-ef0e-02e0e2566e10" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Results match within tolerance: True\n" + ] + } + ], + "source": [ + "# Hybrid version: Both activation and parameter offloading\n", + "def hybrid_layer(x, w):\n", + " # Move model parameters w1 and w2 to host memory via device_put\n", + " w1, w2 = jax.tree.map(lambda x: jax.device_put(x, s_dev), w)\n", + " x = checkpoint_name(x, \"x\") # Offload activation x to host memory\n", + " y = x @ w1\n", + " return y @ w2, None\n", + "\n", + "def hybrid_scanned(w, x):\n", + " remat_layer = jax.remat(hybrid_layer, # Use hybrid_layer instead of layer\n", + " policy=policy, # Use offloading policy\n", + " prevent_cse=False) # Allow CSE optimizations\n", + " result = jax.lax.scan(remat_layer, x, w)[0]\n", + " return jnp.sum(result)\n", + "\n", + "# Move model parameters w1 and w2 to the host via device_put\n", + "# Initialize input and weights with small values (0.0001)\n", + "wh1 = jax.device_put(w1, s_host)\n", + "wh2 = jax.device_put(w2, s_host)\n", + "\n", + "# Compile and compute gradients of the scanned function\n", + "f = jax.jit(jax.grad(hybrid_scanned)) # Apply JIT compilation to gradient computation\n", + "result_both = f((wh1, wh2), input) # Execute with both activation and parameter offloading\n", + "\n", + "# Verify numerical correctness\n", + "are_close = jnp.allclose(\n", + " result_activation[0], # Result from activation offloading only\n", + " result_both[0], # Result from both activation and parameter offloading\n", + " rtol=1e-5,\n", + " atol=1e-5\n", + ")\n", + "print(f\"Results match within tolerance: {are_close}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "SVpozzwHflQk" + }, + "source": [ + "The matching results verify that initializing parameters on host memory maintains computational correctness.\n", + "\n", + "### Limitation of Parameter Offloading\n", + "\n", + "{func}`jax.lax.scan` is crucial for effective parameter management. Using an explicit for loop would cause parameters to continuously occupy device memory, resulting in the same memory usage as without parameter offloading. 
While {func}`jax.lax.scan` allows specifying the scan axis, parameter offloading currently works only when scanning over axis 0. Scanning over other axes generates a `transpose` operation during compilation before returning parameters to the device, which is expensive and not supported on all platforms.\n", + "\n", + "## Tools for Host Offloading\n", + "\n", + "For device memory analysis, refer to :doc:`device_memory_profiling`. The profiling tools described in {ref}`profiling` can help measure memory savings and performance impact from host offloading." + ] + } + ], + "metadata": { + "accelerator": "TPU", + "colab": { + "gpuType": "V28", + "provenance": [], + "toc_visible": true + }, + "jupytext": { + "formats": "ipynb,md:myst" + }, + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.3" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/docs/notebooks/host-offloading.md b/docs/notebooks/host-offloading.md new file mode 100644 index 000000000000..cffe8b4340fe --- /dev/null +++ b/docs/notebooks/host-offloading.md @@ -0,0 +1,342 @@ +--- +jupytext: + formats: ipynb,md:myst + text_representation: + extension: .md + format_name: myst + format_version: 0.13 + jupytext_version: 1.16.4 +kernelspec: + display_name: Python 3 (ipykernel) + language: python + name: python3 +--- + ++++ {"id": "bQbS50fIdHw1"} + +(host-offloading)= +# JAX Memories and Host Offloading + + + +This tutorial provides a practical introduction to host offloading techniques in JAX, focusing on: + +- Activation offloading +- Parameter offloading + +By applying offloading strategies, you can better manage memory resources and reduce memory pressure on your devices. To implement these strategies effectively, you'll need to understand JAX's core mechanisms for data placement and movement. + +## Building Blocks for Offloading + +JAX provides several key components for controlling where and how data are stored and moved between the host and the device memory. In the following sections, you'll explore: + +- How to specify data distribution with sharding +- How to control memory placement between host and device +- How to manage data movement in jitted functions + +### NamedSharding and Memory Kinds + +{class}`~jax.sharding.NamedSharding` defines how data are distributed across devices. 
It includes: + +- Basic data distribution configuration +- `memory_kind` parameter for specifying memory type (`device` or `pinned_host`) +- By default, `memory_kind` is set to `device` memory +- `with_memory_kind` method for creating new sharding with modified memory type + +```{code-cell} ipython3 +--- +colab: + base_uri: https://localhost:8080/ +id: f-6sxUlqrlBn +outputId: 691a3df2-8341-44a9-a4a0-5521c2d891e3 +--- +import jax +import jax.numpy as jnp +from jax.sharding import Mesh, NamedSharding, PartitionSpec as P +import numpy as np + +# Create mesh +# 1x1 mesh represents a single device with two named dimensions (x and y) +mesh = Mesh(np.array(jax.devices()[0]).reshape(1, 1), ('x', 'y')) + +# Device sharding - partitions data along x and y dimensions +s_dev = NamedSharding(mesh, P('x', 'y'), memory_kind="device") + +# Host sharding - same partitioning but in pinned host memory +s_host = s_dev.with_memory_kind('pinned_host') + +print(s_dev) # Shows device memory sharding +print(s_host) # Shows pinned host memory sharding +``` + ++++ {"id": "R_pB9465VoMP"} + +### Data Placement with device_put + +{func}`jax.device_put` is a function that explicitly transfers arrays to a specified memory location according to a sharding specification. + +```{code-cell} ipython3 +--- +colab: + base_uri: https://localhost:8080/ +id: OJFnf7FGp6Lj +outputId: c762e1df-2453-4ed9-9d53-0defb6a05ce2 +--- +# Create a 2x4 array +arr = jnp.arange(8.0).reshape(2, 4) + +# Move arrays to different memory locations based on sharding objects +arr_host = jax.device_put(arr, s_host) # Places in pinned host memory +arr_dev = jax.device_put(arr, s_dev) # Places in device memory + +# Verify memory locations +print(arr_host.sharding.memory_kind) # Output: pinned_host +print(arr_dev.sharding.memory_kind) # Output: device +``` + ++++ {"id": "HHXvBpQKTMCR"} + +### Output Sharding Controls + +Shardings determine how data is split across devices. JAX provides `out_shardings` to control how output arrays are partitioned when leaving a jitted function. + +Key Features: + - Can differ from input sharding + - Allows different memory kinds for outputs + +Examples: + +#### Device Output Sharding + +```{code-cell} ipython3 +--- +colab: + base_uri: https://localhost:8080/ +id: ZXNj9NUeaIdX +outputId: 399321ef-082a-4a77-c33a-9de3421f429b +--- +f = jax.jit(lambda x:x, out_shardings=s_dev) +out_dev = f(arr_host) +print("Result value of H2D: \n", out_dev) +``` + ++++ {"id": "iYXC5ix384XP"} + +Moving data from host to device memory when needed for computation is the essence of host offloading. Use {func}`jax.device_put` to perform this transfer in this example to optimize performance. 
+ +```{code-cell} ipython3 +--- +colab: + base_uri: https://localhost:8080/ +id: cmM6tJTS84XQ +outputId: 40c353a1-fb55-44bc-bac9-dffc09852f49 +--- +# Instead of the lambda function, you can define add_func to explicitly +# move data to device before computation +def add_func(x): # Move data to device and add one + x = jax.device_put(x, s_dev) + return x + 1 + +f = jax.jit(add_func, out_shardings=s_dev) +out_dev = f(arr_host) +print("Result value of H2D and add 1 in device memory: \n", out_dev) +``` + ++++ {"id": "EbE-eBrJTBuS"} + +#### Host Output Sharding + +```{code-cell} ipython3 +--- +colab: + base_uri: https://localhost:8080/ +id: FjZzkxI8ky4r +outputId: 2a1b6e7a-1c29-4347-c020-7b47c27a5cc3 +--- +f = jax.jit(lambda x: x, out_shardings=s_dev) +out_host = f(arr_host) # Input arrays in the device memory while output arrays in the host memory +print("Result value of D2H: \n", out_host) +``` + ++++ {"id": "UhLVvRO2p6Lj"} + +## Activation Offloading + +The detailed coverage of activation offloading can be found in the {ref}`gradient-checkpointing` tutorial. Activation offloading helps manage memory by moving intermediate activations to host memory after the forward pass, and bringing them back to device memory during the backward pass when needed for gradient computation. + +To implement activation offloading effectively, you need to understand checkpoint names and policies. Here's how they work in a simple example: + +### Checkpoint Names + +The {func}`checkpoint_name` function allows you to label activations for memory management during computation. Here's a simple example: + +```{code-cell} ipython3 +:id: sLO9ceS6p6Lj + +from jax.ad_checkpoint import checkpoint_name + +def layer(x, w): + w1, w2 = w + x = checkpoint_name(x, "x") + y = x @ w1 + return y @ w2, None +``` + ++++ {"id": "-_T92oCOp6Lk"} + +This example shows: + +* A simple neural network layer with two matrix multiplications +* Labeling of input activation x with identifier `"x"` +* Sequential operations: + 1. First multiplication: `x @ w1` + 2. Second multiplication: `y @ w2` + +The checkpoint name helps the system decide whether to: +* Keep the activation in device memory or +* Offload it to host memory during computation + +This pattern is common in neural networks, where multiple transformations are applied sequentially to input data. + + +### Checkpoint Policies + +The {func}`jax.remat` transformation manages memory by handling intermediate values through three strategies: + +1. Recomputing during backward pass (default behavior) +2. Storing on device +3. Offloading to host memory after forward pass and loading back during backward pass + +Example of setting an offloading checkpoint policy: + +```{code-cell} ipython3 +:id: W8Usw_wOp6Lk + +from jax import checkpoint_policies as cp + +policy = cp.save_and_offload_only_these_names( + names_which_can_be_saved=[], # No values stored on device + names_which_can_be_offloaded=["x"], # Offload activations labeled "x" + offload_src="device", # Move from device memory + offload_dst="pinned_host" # To pinned host memory +) +``` + ++++ {"id": "iuDRCXu7ky4r"} + +Since {func}`jax.lax.scan` is commonly used in JAX for handling sequential operations (like RNNs or transformers), you need to know how to apply your offloading strategy in this context. 
+ +Key components: +* {func}`jax.remat` applies our checkpoint policy to the layer function +* `prevent_cse=False` enables XLA's common subexpression elimination for better performance +* {func}`jax.lax.scan` iterates the rematerialized layer along an axis + +```{code-cell} ipython3 +--- +colab: + base_uri: https://localhost:8080/ +id: xCrxjTx_p6Lk +outputId: 13d46584-9b25-4622-b3c3-f50c1dac02c2 +--- +def scanned(w, x): + remat_layer = jax.remat(layer, + policy=policy, # Use our offloading policy + prevent_cse=False) # Allow CSE optimizations + result = jax.lax.scan(remat_layer, x, w)[0] + return jnp.sum(result) + +# Initialize input and weights with small values (0.0001) +input = jnp.ones((256, 256), dtype=jnp.float32) * 0.001 # Input matrix: 256 x 256 +w1 = jnp.ones((10, 256, 1024), dtype=jnp.float32) * 0.001 # 10 layers of 256 x 1024 matrices +w2 = jnp.ones((10, 1024, 256), dtype=jnp.float32) * 0.001 # 10 layers of 1024 x 256 matrices + +# Compile and compute gradients of the scanned function +f = jax.jit(jax.grad(scanned)) # Apply JIT compilation to gradient computation +result_activation = f((w1, w2), input) # Execute the function with weights and input +print("Sample of results: ", result_activation[0][0, 0, :5]) +``` + ++++ {"id": "0tx7aara42pY"} + +### Summary of Activation Offloading + +Activation offloading provides a powerful way to manage memory in large computations by: + +* Using checkpoint names to mark specific activations +* Applying policies to control where and how activations are stored +* Supporting common JAX patterns like scan operations +* Moving selected activations to host memory when device memory is under budget + +This approach is particularly useful when working with large models that would otherwise exceed device memory capacity. + +## Parameter Offloading + +Model parameters (also known as weights) can be offloaded to the host memory to optimize device memory usage during initialization. This is achieved by using {func}`jax.jit` with a sharding strategy that specifies host memory kind. + +While parameter offloading and activation offloading are distinct memory optimization techniques, the following example demonstrates parameter offloading built upon the activation offloading implementation shown earlier. + +### Parameter Placement for Computation + +Different from the earlier `layer` function, {func}`jax.device_put` is applied to move parameter `w1` and `w2` to the device before the matrix multiplications. This ensures the parameters are available on the device for both forward and backward passes. + +Note that the activation offloading implementation remains unchanged, using the same: +* Checkpoint name `"x"` +* Checkpoint policy +* `scanned` function combining {func}`jax.remat` and {func}`jax.lax.scan` + +### Parameter Initialization with Host Offloading + +During the initialization, parameter `w1` and `w2` are placed on host memory before being passed to the {func}`jax.jit` function `f`, while keeping the `input` variable on the device. 
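Before the full hybrid example below, here is a minimal sketch (not from the tutorial) of parameter offloading on its own, with no activation offloading involved; `tiny_layer`, `w_host`, and `x_dev` are illustrative names, and `s_host`/`s_dev` are the shardings defined earlier.

```python
# Minimal parameter-offloading-only sketch (illustrative, not the tutorial's code).
# The weights live in pinned host memory between steps and are brought back to
# device memory inside the jitted function just before they are used.
def tiny_layer(x, w):
  w = jax.device_put(w, s_dev)   # host -> device transfer inside jit
  return x @ w

w_host = jax.device_put(jnp.ones((256, 256)) * 0.001, s_host)  # parameters on host
x_dev = jnp.ones((256, 256)) * 0.001                           # input stays on device

y = jax.jit(tiny_layer)(x_dev, w_host)
print(y.sharding.memory_kind)  # on an accelerator this reports 'device'
```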
+ +```{code-cell} ipython3 +--- +colab: + base_uri: https://localhost:8080/ +id: 1qGN2hBQdheo +outputId: 48c09658-f8b6-4be3-ef0e-02e0e2566e10 +--- +# Hybrid version: Both activation and parameter offloading +def hybrid_layer(x, w): + # Move model parameters w1 and w2 to host memory via device_put + w1, w2 = jax.tree.map(lambda x: jax.device_put(x, s_dev), w) + x = checkpoint_name(x, "x") # Offload activation x to host memory + y = x @ w1 + return y @ w2, None + +def hybrid_scanned(w, x): + remat_layer = jax.remat(hybrid_layer, # Use hybrid_layer instead of layer + policy=policy, # Use offloading policy + prevent_cse=False) # Allow CSE optimizations + result = jax.lax.scan(remat_layer, x, w)[0] + return jnp.sum(result) + +# Move model parameters w1 and w2 to the host via device_put +# Initialize input and weights with small values (0.0001) +wh1 = jax.device_put(w1, s_host) +wh2 = jax.device_put(w2, s_host) + +# Compile and compute gradients of the scanned function +f = jax.jit(jax.grad(hybrid_scanned)) # Apply JIT compilation to gradient computation +result_both = f((wh1, wh2), input) # Execute with both activation and parameter offloading + +# Verify numerical correctness +are_close = jnp.allclose( + result_activation[0], # Result from activation offloading only + result_both[0], # Result from both activation and parameter offloading + rtol=1e-5, + atol=1e-5 +) +print(f"Results match within tolerance: {are_close}") +``` + ++++ {"id": "SVpozzwHflQk"} + +The matching results verify that initializing parameters on host memory maintains computational correctness. + +### Limitation of Parameter Offloading + +{func}`jax.lax.scan` is crucial for effective parameter management. Using an explicit for loop would cause parameters to continuously occupy device memory, resulting in the same memory usage as without parameter offloading. While {func}`jax.lax.scan` allows specifying the scan axis, parameter offloading currently works only when scanning over axis 0. Scanning over other axes generates a `transpose` operation during compilation before returning parameters to the device, which is expensive and not supported on all platforms. + +## Tools for Host Offloading + +For device memory analysis, refer to :doc:`device_memory_profiling`. The profiling tools described in {ref}`profiling` can help measure memory savings and performance impact from host offloading. diff --git a/docs/notebooks/neural_network_with_tfds_data.ipynb b/docs/notebooks/neural_network_with_tfds_data.ipynb index c31a99746866..a909d9329e24 100644 --- a/docs/notebooks/neural_network_with_tfds_data.ipynb +++ b/docs/notebooks/neural_network_with_tfds_data.ipynb @@ -46,7 +46,7 @@ "\n", "![JAX](https://raw.githubusercontent.com/jax-ml/jax/main/images/jax_logo_250px.png)\n", "\n", - "Let's combine everything we showed in the [quickstart](https://jax.readthedocs.io/en/latest/quickstart.html) to train a simple neural network. We will first specify and train a simple MLP on MNIST using JAX for the computation. We will use `tensorflow/datasets` data loading API to load images and labels (because it's pretty great, and the world doesn't need yet another data loading library :P).\n", + "Let's combine everything we showed in the [quickstart](https://docs.jax.dev/en/latest/quickstart.html) to train a simple neural network. We will first specify and train a simple MLP on MNIST using JAX for the computation. 
We will use `tensorflow/datasets` data loading API to load images and labels (because it's pretty great, and the world doesn't need yet another data loading library :P).\n", "\n", "Of course, you can use JAX with any API that is compatible with NumPy to make specifying the model a bit more plug-and-play. Here, just for explanatory purposes, we won't use any neural network libraries or special APIs for building our model." ] diff --git a/docs/notebooks/neural_network_with_tfds_data.md b/docs/notebooks/neural_network_with_tfds_data.md index 53b7d47358c2..9c153d704763 100644 --- a/docs/notebooks/neural_network_with_tfds_data.md +++ b/docs/notebooks/neural_network_with_tfds_data.md @@ -44,7 +44,7 @@ _Forked from_ `neural_network_and_data_loading.ipynb` ![JAX](https://raw.githubusercontent.com/jax-ml/jax/main/images/jax_logo_250px.png) -Let's combine everything we showed in the [quickstart](https://jax.readthedocs.io/en/latest/quickstart.html) to train a simple neural network. We will first specify and train a simple MLP on MNIST using JAX for the computation. We will use `tensorflow/datasets` data loading API to load images and labels (because it's pretty great, and the world doesn't need yet another data loading library :P). +Let's combine everything we showed in the [quickstart](https://docs.jax.dev/en/latest/quickstart.html) to train a simple neural network. We will first specify and train a simple MLP on MNIST using JAX for the computation. We will use `tensorflow/datasets` data loading API to load images and labels (because it's pretty great, and the world doesn't need yet another data loading library :P). Of course, you can use JAX with any API that is compatible with NumPy to make specifying the model a bit more plug-and-play. Here, just for explanatory purposes, we won't use any neural network libraries or special APIs for building our model. diff --git a/docs/notebooks/shard_map.ipynb b/docs/notebooks/shard_map.ipynb index d73b0d4c0f3e..d04b7583a4a0 100644 --- a/docs/notebooks/shard_map.ipynb +++ b/docs/notebooks/shard_map.ipynb @@ -13,9 +13,9 @@ "\n", "`shard_map` is a single-program multiple-data (SPMD) multi-device parallelism API to map a function over shards of data. Mapped function applications, or _instances_, communicate with each other via explicit collective communication operations.\n", "\n", - "`shard_map` is complementary to, and composable with, the automatic compiler-based parallelization built into `jit`. With `jit` you write code as if for a single device, and [the compiler can automatically partition computation over multiple devices](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html), generating per-device code and communication collectives behind the scenes. With `shard_map` you take control, writing your own partitioned code and explicit collectives. Or you can do a bit of both: take manual control across groups of devices while leaving within-group device partitioning up to the compiler. The two approaches can be mixed, matched, and composed as needed.\n", + "`shard_map` is complementary to, and composable with, the automatic compiler-based parallelization built into `jit`. With `jit` you write code as if for a single device, and [the compiler can automatically partition computation over multiple devices](https://docs.jax.dev/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html), generating per-device code and communication collectives behind the scenes. 
With `shard_map` you take control, writing your own partitioned code and explicit collectives. Or you can do a bit of both: take manual control across groups of devices while leaving within-group device partitioning up to the compiler. The two approaches can be mixed, matched, and composed as needed.\n", "\n", - "If you're familiar with `pmap`, think of `shard_map` as an evolution. It's more expressive, performant, and composable with other JAX APIs. It even works eagerly, for easier debugging! (For more, see [a detailed comparison to `pmap`.](https://jax.readthedocs.io/en/latest/jep/14273-shard-map.html#why-don-t-pmap-or-xmap-already-solve-this))\n", + "If you're familiar with `pmap`, think of `shard_map` as an evolution. It's more expressive, performant, and composable with other JAX APIs. It even works eagerly, for easier debugging! (For more, see [a detailed comparison to `pmap`.](https://docs.jax.dev/en/latest/jep/14273-shard-map.html#why-don-t-pmap-or-xmap-already-solve-this))\n", "\n", "By reading this tutorial, you'll learn how to use `shard_map` to get full control over your multi-device code. You'll see in detail how it composes with `jax.jit`'s automatic parallelization and `jax.grad`'s automatic differentiation. We'll also give some basic examples of neural network parallelization strategies.\n", "\n", @@ -55,8 +55,7 @@ "import jax\n", "import jax.numpy as jnp\n", "\n", - "from jax.sharding import Mesh, PartitionSpec as P\n", - "from jax.experimental.shard_map import shard_map" + "from jax.sharding import Mesh, PartitionSpec as P" ] }, { @@ -71,7 +70,7 @@ "a = jnp.arange( 8 * 16.).reshape(8, 16)\n", "b = jnp.arange(16 * 4.).reshape(16, 4)\n", "\n", - "@partial(shard_map, mesh=mesh, in_specs=(P('x', 'y'), P('y', None)),\n", + "@partial(jax.shard_map, mesh=mesh, in_specs=(P('x', 'y'), P('y', None)),\n", " out_specs=P('x', None))\n", "def matmul_basic(a_block, b_block):\n", " # a_block: f32[2, 8]\n", @@ -249,7 +248,7 @@ "mesh = Mesh(devices, ('i',)) # mesh.shape['i'] = 4\n", "\n", "def check_shmap(f, y):\n", - " ans = shard_map(f, mesh, in_specs=P('i'), out_specs=P('i'))(y)\n", + " ans = jax.shard_map(f, mesh=mesh, in_specs=P('i'), out_specs=P('i'))(y)\n", " expected = jnp.concatenate([f(y_blk) for y_blk in jnp.split(y, mesh.shape['i'])])\n", " print(allclose(ans, expected))\n", "\n", @@ -296,7 +295,7 @@ "source": [ "mesh = jax.make_mesh((4, 2), ('i', 'j'))\n", "\n", - "@partial(shard_map, mesh=mesh, in_specs=P('i', None), out_specs=P('i', 'j'))\n", + "@partial(jax.shard_map, mesh=mesh, in_specs=P('i', None), out_specs=P('i', 'j'))\n", "def f1(x_block):\n", " print(x_block.shape) # prints (3, 12)\n", " return x_block\n", @@ -327,7 +326,7 @@ "metadata": {}, "outputs": [], "source": [ - "@partial(shard_map, mesh=mesh, in_specs=P('i', 'j'), out_specs=P('i', 'j'))\n", + "@partial(jax.shard_map, mesh=mesh, in_specs=P('i', 'j'), out_specs=P('i', 'j'))\n", "def f2(x_block):\n", " print(x_block.shape)\n", " return x_block\n", @@ -383,13 +382,13 @@ "source": [ "x = jnp.array([[3.]])\n", "\n", - "z = shard_map(lambda: x, mesh=mesh, in_specs=(), out_specs=P('i', 'j'))()\n", + "z = jax.shard_map(lambda: x, mesh=mesh, in_specs=(), out_specs=P('i', 'j'))()\n", "print(z) # prints the same as jnp.tile(x, (4, 2))\n", "\n", - "z = shard_map(lambda: x, mesh=mesh, in_specs=(), out_specs=P('i', None))()\n", + "z = jax.shard_map(lambda: x, mesh=mesh, in_specs=(), out_specs=P('i', None))()\n", "print(z) # prints the same as jnp.tile(x, (4, 1)), or just jnp.tile(x, (4,))\n", "\n", - "z = 
shard_map(lambda: x, mesh=mesh, in_specs=(), out_specs=P(None, None))()\n", + "z = jax.shard_map(lambda: x, mesh=mesh, in_specs=(), out_specs=P(None, None))()\n", "print(z) # prints the same as jnp.tile(x, (1, 1)), or just x" ] }, @@ -410,7 +409,7 @@ "metadata": {}, "outputs": [], "source": [ - "@partial(shard_map, mesh=mesh, in_specs=P('i', 'j'), out_specs=P('i', None))\n", + "@partial(jax.shard_map, mesh=mesh, in_specs=P('i', 'j'), out_specs=P('i', None))\n", "def f3(x_block):\n", " return jax.lax.psum(x_block, 'j')\n", "\n", @@ -439,7 +438,7 @@ "metadata": {}, "outputs": [], "source": [ - "@partial(shard_map, mesh=mesh, in_specs=P('i', 'j'), out_specs=P(None, 'j'))\n", + "@partial(jax.shard_map, mesh=mesh, in_specs=P('i', 'j'), out_specs=P(None, 'j'))\n", "def f4(x_block):\n", " return jax.lax.psum(x_block, 'i')\n", "\n", @@ -448,7 +447,7 @@ "print(y4.shape) # (3,12)\n", "\n", "\n", - "@partial(shard_map, mesh=mesh, in_specs=P('i', 'j'), out_specs=P(None, None))\n", + "@partial(jax.shard_map, mesh=mesh, in_specs=P('i', 'j'), out_specs=P(None, None))\n", "def f5(x_block):\n", " return jax.lax.psum(x_block, ('i', 'j'))\n", "\n", @@ -481,6 +480,346 @@ "`Array`s, or physically how to interpret the buffers across devices as the\n", "physical layout of a single logical `Array`.\n", "\n", + "#### Tracking how values vary over manual mesh axes, and `check_vma=True`\n", + "\n", + "Under a `shard_map`, values can vary across function instances, or they can be\n", + "the same. For example, when we use `in_specs` to split an argument over a mesh\n", + "axis, each function instance along that mesh axis gets a different value:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "38668c79", + "metadata": {}, + "outputs": [], + "source": [ + "mesh = jax.make_mesh((2,), ('i',))\n", + "\n", + "@partial(jax.shard_map, mesh=mesh, in_specs=P('i'), out_specs=P('i'))\n", + "def f(x):\n", + " print(x)\n", + " return 2 * x\n", + "\n", + "x = jnp.arange(6.)\n", + "f(x)" + ] + }, + { + "cell_type": "markdown", + "id": "00b66850", + "metadata": {}, + "source": [ + "If instead `in_specs` does not split the argument over a mesh axis, the value\n", + "is the same for each function instance along that axis:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9d0dfa6d", + "metadata": {}, + "outputs": [], + "source": [ + "@partial(jax.shard_map, mesh=mesh, in_specs=P(), out_specs=P())\n", + "def f(x):\n", + " print(x)\n", + " return 2 * x\n", + "\n", + "x = jnp.arange(6.)\n", + "f(x)" + ] + }, + { + "cell_type": "markdown", + "id": "594b4574", + "metadata": {}, + "source": [ + "A collective's output may have a different variance than its input. For\n", + "example, applying a `psum` produces the same output on each function instance\n", + "along an axis:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "df486b2f", + "metadata": {}, + "outputs": [], + "source": [ + "@partial(jax.shard_map, mesh=mesh, in_specs=P('i'), out_specs=P())\n", + "def f(x):\n", + " y = jax.lax.psum(x, 'i')\n", + " print(y)\n", + " return y\n", + "\n", + "x = jnp.arange(6.)\n", + "f(x)" + ] + }, + { + "cell_type": "markdown", + "id": "bf6a17ad", + "metadata": {}, + "source": [ + "In general, each intermediate value in a `shard_map` can be either unvarying or\n", + "possibly-varying over each manual mesh axis. 
That information can be tracked in\n", + "the JAX type system, enabled by the `check_vma=True` argument to `shard_map`:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d7f32190", + "metadata": {}, + "outputs": [], + "source": [ + "@partial(jax.shard_map, mesh=mesh, in_specs=P('i'), out_specs=P())\n", + "def f(x):\n", + " print(jax.typeof(x)) # f32[3]{i}\n", + " y = jax.lax.psum(x, 'i')\n", + " print(jax.typeof(y)) # f32[3]\n", + " return y\n", + "\n", + "x = jnp.arange(6.)\n", + "f(x)" + ] + }, + { + "cell_type": "markdown", + "id": "f76cc47f", + "metadata": {}, + "source": [ + "Here, the type `f32[3]{i}` means that the value of `x` is varying over mesh\n", + "axis `'i'`. The type of `y` printing as `f32[3]` indicates it is unvarying over\n", + "all mesh axes; that is, empty sets are not printed. We call this part of the\n", + "type the _varying manual axes_ (VMA), and it can be accessed via\n", + "`jax.typeof(x).vma`.\n", + "\n", + "In general, the VMA type of a value can include any subset of the manual mesh\n", + "axes over which the `shard_map` is acting:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e69a02d3", + "metadata": {}, + "outputs": [], + "source": [ + "mesh = jax.make_mesh((4, 2), ('i', 'j'))\n", + "\n", + "@partial(jax.shard_map, mesh=mesh, in_specs=P('i', 'j'), out_specs=P('i'))\n", + "def f(x):\n", + " print(jax.typeof(x)) # f32[2,2]{i,j}\n", + " y = jax.lax.psum(x, 'j')\n", + " assert jax.typeof(y).vma == {'i'}\n", + " print(jax.typeof(y)) # f32[2,2]{i}\n", + " return y\n", + "\n", + "x = jnp.arange(8 * 4.).reshape(8, 4)\n", + "f(x)" + ] + }, + { + "cell_type": "markdown", + "id": "a36f1654", + "metadata": {}, + "source": [ + "Tracking varying manual axes can be useful:\n", + "1. Your code can include prints, assertions, or conditionals about whether\n", + " values are varying over expected mesh axes;\n", + "2. It enables efficient reverse-mode autodiff that doesn't require defensive\n", + " `psum`s (see [JEP](https://docs.jax.dev/en/latest/jep/17111-shmap-transpose.html));\n", + "3. The correctness of `out_specs` can be checked, ruling out the potential bug\n", + " example below.\n", + "\n", + "For example, this `out_specs` bug is caught with `check_vma=True`, but uncaught\n", + "without it:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c92c1d4d", + "metadata": {}, + "outputs": [], + "source": [ + "mesh = jax.make_mesh((2,), ('i',))\n", + "\n", + "x = jnp.arange(6.)\n", + "try:\n", + " y = jax.shard_map(lambda x: x, mesh=mesh, in_specs=P('i'), out_specs=P())(x)\n", + "except Exception as e:\n", + " print(e)" + ] + }, + { + "cell_type": "markdown", + "id": "68bc33af", + "metadata": {}, + "source": [ + "Here the `out_specs` incorrectly promise that each function instance along mesh\n", + "axis `'i'` produces the same value and thus we can choose just one of them.\n", + "With `check_vma=True` (the default) it raises an exception, while with\n", + "`check_vma=False` there is no exception and instead we get silent undefined\n", + "behavior.\n", + "\n", + "Sometimes we want to treat a value that is unvarying over a mesh axis as\n", + "varying over that mesh axis. 
That's what `jax.lax.pvary` does:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "21276d78", + "metadata": {}, + "outputs": [], + "source": [ + "@partial(jax.shard_map, mesh=mesh, in_specs=P(), out_specs=None)\n", + "def f(x):\n", + " print(jax.typeof(x)) # f32[6]\n", + " y = jax.lax.pvary(x, 'i')\n", + " print(jax.typeof(y)) # f32[6]{i}\n", + "\n", + "x = jnp.arange(6.)\n", + "f(x)" + ] + }, + { + "cell_type": "markdown", + "id": "8f766c1a", + "metadata": {}, + "source": [ + "Think of `jax.lax.pvary` as applying a type cast: it's a no-op at runtime,\n", + "though under reverse-mode autodiff it transposes to a `jax.lax.psum` (see\n", + "[JEP](https://docs.jax.dev/en/latest/jep/17111-shmap-transpose.html)). That\n", + "makes sense because they do opposite things to the VMA: where `y: f32[3]{i} =\n", + "jax.lax.pvary(x: f32[3], 'i')`, we correspondingly have `x_grad: f32[3] =\n", + "jax.lax.psum(y_grad: f32[3]{i}, 'i')`.\n", + "\n", + "JAX implicitly inserts `jax.lax.pvary` calls in many cases, especially for\n", + "binary operations:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e22d52a4", + "metadata": {}, + "outputs": [], + "source": [ + "@partial(jax.shard_map, mesh=mesh, in_specs=(P('i'), P()), out_specs=P('i'))\n", + "def f(x, y):\n", + " return x * y\n", + "\n", + "x = jnp.arange(6.)\n", + "y = jnp.arange(3.)\n", + "print(jax.make_jaxpr(f)(x, y))" + ] + }, + { + "cell_type": "markdown", + "id": "1bd7f6a5", + "metadata": {}, + "source": [ + "In a jaxpr, the multiplication operation requires the VMA types of its\n", + "arguments to match, but for convenience the `jax.numpy` and `jax.lax` APIs\n", + "automatically apply `jax.lax.pvary` to make argument VMA types agree.\n", + "\n", + "\n", + "\n", + "In some cases, like with `jax.lax.scan`, you might need to apply\n", + "`jax.lax.pvary` yourself to ensure VMA types match as required. 
For example,\n", + "this code raises an error:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6e33a5fb", + "metadata": {}, + "outputs": [], + "source": [ + "mesh = jax.make_mesh((2,), ('i',))\n", + "\n", + "@partial(jax.shard_map, mesh=mesh, in_specs=(P('i'), P()), out_specs=P('i'))\n", + "def f(x, y):\n", + " def body(carry, _):\n", + " c1, c2 = carry\n", + " return (c2, c1), () # swap the carry\n", + " (x_, y_), _ = jax.lax.scan(body, (x, y), (), length=2)\n", + " return x_, y_\n", + "\n", + "x = jnp.arange(6.)\n", + "y = jnp.arange(3.)\n", + "\n", + "try:\n", + " f(x, y)\n", + "except Exception as e:\n", + " print(e)" + ] + }, + { + "cell_type": "markdown", + "id": "7b6fef36", + "metadata": {}, + "source": [ + "To make the types match, we need to apply `jax.lax.pvary` to some arguments to\n", + "the `scan`:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8c8dbd11", + "metadata": {}, + "outputs": [], + "source": [ + "mesh = jax.make_mesh((2,), ('i',))\n", + "\n", + "@partial(jax.shard_map, mesh=mesh, in_specs=(P('i'), P()), out_specs=P('i'))\n", + "def f(x, y):\n", + " def body(carry, _):\n", + " c1, c2 = carry\n", + " return (c2, c1), () # swap the carry\n", + "\n", + " y = jax.lax.pvary(y, 'i') # apply pvary to fix the error\n", + " (x_, y_), _ = jax.lax.scan(body, (x, y), (), length=2)\n", + " return x_, y_\n", + "\n", + "x = jnp.arange(6.)\n", + "y = jnp.arange(3.)\n", + "\n", + "f(x, y)" + ] + }, + { + "cell_type": "markdown", + "id": "10271c3c", + "metadata": {}, + "source": [ + "Here's a summary of collective primitives and how they affect varying manual axis types:\n", + "\n", + "| Name | Device variance type | Example | Lowers to HLO | Transpose |\n", + "| --- | --- | --- | --- | --- |\n", + "| `psum_invariant` | `Varying -> Invariant` | `y:f32[3]{j} = psum(x:f32[3]{i,j}, axis='i')` | `AllReduceSum` (communication) | `pvary` |\n", + "| `pvary` | `Invariant -> Varying` | `y:f32[3]{i} = pvary(x:f32[3], 'i')` | no-op (no communication) | `psum_invariant` |\n", + "| `all_to_all` | `Varying -> Varying` | `y:f32[16]{i} = all_to_all(x:f32[16]{i}, 'i', 0, 0)` `AllToAll` (communication) | `all_to_all` |\n", + "| `axis_index` | `() -> Varying` | `idx:i32[]{i} = axis_index('i')` | `ReplicaId` and some arithmetic (no communication) | n/a |\n", + "| `psum_scatter` | `Varying -> Varying` | `y:f32[2]{i} = psum_scatter(x:f32[16]{i}, 'i')` | `ReduceScatterSum` (communication) | `all_gather` |\n", + "| `all_gather` | `Varying -> Varying` | `y:f32[16]{i} = all_gather(x:f32[2]{i}, 'i')` | `AllGather` (communication) | `psum_scatter` |\n", + "| `pscatter` | `Invariant -> Varying` | `y:f32[2]{i} = pscatter(x:f32[16], 'i')` | `lambda x: x[axis_index('i'), None]` (no communication) | `all_gather_invariant` |\n", + "| `all_gather_invariant` | `Varying -> Invariant` | `y:f32[16] = all_gather_invariant(x:f32[2]{i}, 'i')` | `AllGather` (communication) | `pscatter` |\n", + "\n", + "A few notes on the table:\n", + "* The function `jax.lax.psum` is a convenience wrapper around `psum_invariant`.\n", + "* It's surprising that `all_gather` is `Varying -> Varying`, but that's because\n", + " it's really the transpose of `psum_scatter` which is `Varying -> Varying`.\n", + "* Neither `pscatter` nor `all_gather_invariant` have user APIs at the time of\n", + " writing, but they're described here for completeness.\n", + "\n", + "\n", "## API Specification\n", "\n", "```python\n", @@ -488,18 +827,21 @@ "Specs = PyTree[PartitionSpec]\n", "\n", "def shard_map(\n", - " f: 
Callable, mesh: Mesh, in_specs: Specs, out_specs: Specs,\n", - " auto: collections.abc.Set[AxisName] = frozenset([]),\n", - " check_rep: bool = True,\n", + " f: Callable, /, *, out_specs: Specs, mesh: Mesh | None = None,\n", + " in_specs: Specs | None = None,\n", + " axis_names: collections.abc.Set[AxisName] = set(),\n", + " check_vma: bool = True,\n", ") -> Callable:\n", " ...\n", "```\n", "where:\n", "* communication collectives like `psum` in the body of `f` can mention the axis names of `mesh`;\n", - "* `mesh` encodes devices arranged in an array and with associated axis names, just like it does for `sharding.NamedSharding`;\n", - "* `in_specs` and `out_specs` are `PartitionSpec`s which can affinely mention axis names from `mesh` to express slicing/unconcatenation and concatenation of inputs and outputs, respectively, with unmentioned names corresponding to replication and untiling (assert-replicated-so-give-me-one-copy), respectively;\n", - "* `auto` is an optional set of axis names corresponding to the subset of names of `mesh` to treat automatically in the body, as in the caller, rather than manually;\n", - "* `check_rep` is an optional boolean indicating whether to check statically for any replication errors in `out_specs`, and also whether to enable a related automatic differentiation optimization (see [JEP](https://jax.readthedocs.io/en/latest/jep/17111-shmap-transpose.html)).\n", + "* `mesh` encodes devices arranged in an array and with associated axis names, just like it does for `sharding.NamedSharding`; If None, mesh will be inferred from the\n", + "context which can be set via the `jax.sharding.use_mesh` context manager.\n", + "* `in_specs` are `PartitionSpec`s which can zero or one times mention axis names from `mesh` to express slicing/unconcatenation of inputs, respectively, with unmentioned names corresponding to replication and untiling (assert-replicated-so-give-me-one-copy). If None, all mesh axes must be of type `Explicit`, in which case the in_specs are inferred from the argument types;\n", + "* `out_specs` are `PartitionSpec`s which can zero or one times mention axis names from `mesh` to express concatenation of outputs, with unmentioned names corresponding to replication and untiling (assert-replicated-so-give-me-one-copy), respectively;\n", + "* `axis_names` is an optional set of axis names corresponding to the subset of names of `mesh` to treat manual in the body. 
If empty, `f` is manual over all axes of the mesh.\n", + "* `check_vma` is an optional boolean indicating whether to check statically for any replication errors in `out_specs`, and also whether to enable a related automatic differentiation optimization (see [JEP](https://docs.jax.dev/en/latest/jep/17111-shmap-transpose.html)).\n", "\n", "The shapes of the arguments passed to `f` have the same ranks as the arguments\n", "passed to `shard_map`-of-`f`, and the shape of an argument to `f` is computed\n", @@ -521,7 +863,7 @@ "```python\n", "mesh = Mesh(jax.devices(), ('i',))\n", "x = jnp.arange(16.)\n", - "f_shmapped = shard_map(f, mesh, in_specs=P('i'), out_specs=P('i'))\n", + "f_shmapped = jax.shard_map(f, mesh=mesh, in_specs=P('i'), out_specs=P('i'))\n", "y = f_shmapped(x)\n", "```\n", "\n", @@ -593,8 +935,7 @@ "import jax.numpy as jnp\n", "from jax import lax\n", "\n", - "from jax.sharding import Mesh, NamedSharding, PartitionSpec as P\n", - "from jax.experimental.shard_map import shard_map" + "from jax.sharding import Mesh, NamedSharding, PartitionSpec as P" ] }, { @@ -606,7 +947,7 @@ "source": [ "mesh1d = Mesh(jax.devices()[:4], ('i',))\n", "\n", - "@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P(None))\n", + "@partial(jax.shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P(None))\n", "def f1(x_block):\n", " print('BEFORE:\\n', x_block)\n", " y_block = jax.lax.psum(x_block, 'i')\n", @@ -662,7 +1003,7 @@ "source": [ "mesh2d = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('i', 'j'))\n", "\n", - "@partial(shard_map, mesh=mesh2d, in_specs=P('i', 'j'), out_specs=P(None, 'j'))\n", + "@partial(jax.shard_map, mesh=mesh2d, in_specs=P('i', 'j'), out_specs=P(None, 'j'))\n", "def f2(x_block):\n", " print('BEFORE:\\n', x_block)\n", " y_block = jax.lax.psum(x_block, 'i')\n", @@ -693,7 +1034,7 @@ "metadata": {}, "outputs": [], "source": [ - "@partial(shard_map, mesh=mesh2d, in_specs=P('i', 'j'), out_specs=P(None, None))\n", + "@partial(jax.shard_map, mesh=mesh2d, in_specs=P('i', 'j'), out_specs=P(None, None))\n", "def f3(x_block):\n", " print('BEFORE:\\n', x_block)\n", " y_block = jax.lax.psum(x_block, ('i', 'j'))\n", @@ -730,7 +1071,7 @@ "metadata": {}, "outputs": [], "source": [ - "@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))\n", + "@partial(jax.shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))\n", "def f4(x_block):\n", " print('BEFORE:\\n', x_block)\n", " y_block = jax.lax.all_gather(x_block, 'i', tiled=True)\n", @@ -769,7 +1110,7 @@ "metadata": {}, "outputs": [], "source": [ - "@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))\n", + "@partial(jax.shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))\n", "def f5(x_block):\n", " print('BEFORE:\\n', x_block)\n", " y_block = jax.lax.all_gather(x_block, 'i', tiled=False)\n", @@ -812,7 +1153,7 @@ "metadata": {}, "outputs": [], "source": [ - "@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))\n", + "@partial(jax.shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))\n", "def f6(x_block):\n", " print('BEFORE:\\n', x_block)\n", " y_block = jax.lax.psum_scatter(x_block, 'i', tiled=True)\n", @@ -888,9 +1229,9 @@ "metadata": {}, "outputs": [], "source": [ - "@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))\n", + "@partial(jax.shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))\n", "def f7(x_block):\n", - " sz = jax.lax.psum(1, 'i')\n", + " sz = jax.lax.axis_size('i')\n", " print('BEFORE:\\n', x_block)\n", " y_block = jax.lax.ppermute(x_block, 
'i', [(i, (i + 1) % sz) for i in range(sz)])\n", " print('AFTER:\\n', y_block)\n", @@ -947,7 +1288,7 @@ "outputs": [], "source": [ "def psum_scatter(x, axis_name, *, tiled=False):\n", - " size = jax.lax.psum(1, axis_name)\n", + " size = jax.lax.axis_size(axis_name)\n", " idx = jax.lax.axis_index(axis_name) # function instance index along axis_name\n", " if tiled:\n", " x = x.reshape(size, -1, *x.shape[1:]) # split leading axis\n", @@ -966,7 +1307,7 @@ "metadata": {}, "outputs": [], "source": [ - "@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))\n", + "@partial(jax.shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))\n", "def f8(x_block):\n", " print('BEFORE:\\n', x_block)\n", " y_block = psum_scatter(x_block, 'i', tiled=True)\n", @@ -1014,7 +1355,7 @@ "metadata": {}, "outputs": [], "source": [ - "@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))\n", + "@partial(jax.shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))\n", "def f9(x_block):\n", " print('BEFORE:\\n', x_block)\n", " y_block = jax.lax.all_to_all(x_block, 'i', split_axis=0, concat_axis=0,\n", @@ -1086,8 +1427,7 @@ "import jax\n", "import jax.numpy as jnp\n", "\n", - "from jax.sharding import Mesh, NamedSharding, PartitionSpec as P\n", - "from jax.experimental.shard_map import shard_map" + "from jax.sharding import Mesh, NamedSharding, PartitionSpec as P" ] }, { @@ -1163,7 +1503,7 @@ "outputs": [], "source": [ "@jax.jit\n", - "@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec),\n", + "@partial(jax.shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec),\n", " out_specs=rhs_spec)\n", "def matmul_allgather(lhs_block, rhs_block):\n", " rhs = jax.lax.all_gather(rhs_block, 'i', tiled=True)\n", @@ -1207,10 +1547,10 @@ "outputs": [], "source": [ "@jax.jit\n", - "@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec),\n", + "@partial(jax.shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec),\n", " out_specs=rhs_spec)\n", "def matmul_allgather_overlapped(lhs_block, rhs_block):\n", - " size = jax.lax.psum(1, 'i')\n", + " size = jax.lax.axis_size('i')\n", " idx = jax.lax.axis_index('i')\n", " shift = partial(jax.lax.ppermute, axis_name='i',\n", " perm=[(i, (i + 1) % size) for i in range(size)])\n", @@ -1256,10 +1596,10 @@ "outputs": [], "source": [ "@jax.jit\n", - "@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec),\n", + "@partial(jax.shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec),\n", " out_specs=rhs_spec)\n", "def matmul_allgather_overlapped_bidi(lhs_block, rhs_block):\n", - " size = jax.lax.psum(1, 'i')\n", + " size = jax.lax.axis_size('i')\n", " idx = jax.lax.axis_index('i')\n", " shift_up = partial(jax.lax.ppermute, axis_name='i',\n", " perm=[(i, (i + 1) % size) for i in range(size)])\n", @@ -1337,7 +1677,7 @@ "metadata": {}, "outputs": [], "source": [ - "@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec),\n", + "@partial(jax.shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec),\n", " out_specs=rhs_spec)\n", "def matmul_psumscatter(lhs_block, rhs_block):\n", " out_summand = lhs_block @ rhs_block\n", @@ -1365,10 +1705,10 @@ "metadata": {}, "outputs": [], "source": [ - "@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec),\n", + "@partial(jax.shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec),\n", " out_specs=rhs_spec)\n", "def matmul_psumscatter_overlapped(lhs_block, rhs_block):\n", - " size = jax.lax.psum(1, 'i')\n", + " size = jax.lax.axis_size('i')\n", " idx = jax.lax.axis_index('i')\n", " shift = partial(jax.lax.ppermute, axis_name='i',\n", 
" perm=[(i, (i - 1) % size) for i in range(size)])\n", @@ -1408,10 +1748,10 @@ "metadata": {}, "outputs": [], "source": [ - "@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec),\n", + "@partial(jax.shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec),\n", " out_specs=rhs_spec)\n", "def matmul_psumscatter_overlapped_bidi(lhs_block, rhs_block):\n", - " size = jax.lax.psum(1, 'i')\n", + " size = jax.lax.axis_size('i')\n", " idx = jax.lax.axis_index('i')\n", " shift_up = partial(jax.lax.ppermute, axis_name='i',\n", " perm=[(i, (i + 1) % size) for i in range(size)])\n", @@ -1520,7 +1860,7 @@ "source": [ "Compare these examples with the purely [automatic partitioning examples in the\n", "\"Distributed arrays and automatic partitioning\"\n", - "doc](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html).\n", + "doc](https://docs.jax.dev/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html).\n", "While in those automatic partitioning examples we don't need to edit the model\n", "functions to use different parallelization strategies, with `shard_map` we\n", "often do.\n", @@ -1545,7 +1885,6 @@ "from functools import partial\n", "\n", "from jax.sharding import NamedSharding, Mesh, PartitionSpec as P\n", - "from jax.experimental.shard_map import shard_map\n", "\n", "mesh = jax.make_mesh((8,), ('batch',))\n", "\n", @@ -1555,7 +1894,7 @@ "\n", "# adapt the loss function to sum the losses across devices\n", "def loss_dp(params, batch):\n", - " @partial(shard_map, mesh=mesh, in_specs=P('batch', None), out_specs=P())\n", + " @partial(jax.shard_map, mesh=mesh, in_specs=P('batch', None), out_specs=P())\n", " def loss_spmd(local_batch):\n", " inputs, targets = local_batch\n", " predictions = predict(params, inputs) # use reference 'predict`\n", @@ -1626,7 +1965,7 @@ "parameters from the forward pass for use on the backward pass. Instead, we want\n", "to gather them again on the backward pass. 
We can express that by using\n", "`jax.remat` with a [custom\n", - "policy](https://jax.readthedocs.io/en/latest/notebooks/autodiff_remat.html#custom-policies-for-what-s-saveable)\n", + "policy](https://docs.jax.dev/en/latest/notebooks/autodiff_remat.html#custom-policies-for-what-s-saveable)\n", "(or a `custom_vjp`), though XLA typically does that rematerialization\n", "automatically.\n", "\n", @@ -1660,7 +1999,7 @@ " return outputs\n", "\n", "def loss_fsdp(params, batch):\n", - " @partial(shard_map, mesh=mesh, in_specs=P('batch'), out_specs=P())\n", + " @partial(jax.shard_map, mesh=mesh, in_specs=P('batch'), out_specs=P())\n", " def loss_spmd(local_params, local_batch):\n", " inputs, targets = local_batch\n", " predictions = predict_fsdp(local_params, inputs)\n", @@ -1729,7 +2068,7 @@ " inputs = jax.nn.relu(outputs)\n", " return outputs\n", "\n", - "@partial(shard_map, mesh=mesh,\n", + "@partial(jax.shard_map, mesh=mesh,\n", " in_specs=(P(None, 'feats'), P('feats', None), P('feats')),\n", " out_specs=P(None, 'feats'))\n", "def gemm_tp(inputs, W, b):\n", @@ -1777,7 +2116,7 @@ " inputs = jax.nn.relu(outputs)\n", " return outputs\n", "\n", - "@partial(shard_map, mesh=mesh,\n", + "@partial(jax.shard_map, mesh=mesh,\n", " in_specs=(P(('feats', 'batch')), P('batch', 'feats')),\n", " out_specs=P())\n", "def loss_fsdp_tp(local_params, local_batch):\n", @@ -1887,7 +2226,7 @@ " outputs = jnp.dot(inputs, W_last) + b_last\n", " return outputs\n", "\n", - "@partial(shard_map, mesh=mesh, in_specs=((P(), P('stages'), P()), P('stages')),\n", + "@partial(jax.shard_map, mesh=mesh, in_specs=((P(), P('stages'), P()), P('stages')),\n", " out_specs=P())\n", "def loss_pp(params, batch):\n", " inputs, targets = batch\n", diff --git a/docs/notebooks/shard_map.md b/docs/notebooks/shard_map.md index c52cf0e6d22b..bf139b48d6f3 100644 --- a/docs/notebooks/shard_map.md +++ b/docs/notebooks/shard_map.md @@ -22,9 +22,9 @@ kernelspec: `shard_map` is a single-program multiple-data (SPMD) multi-device parallelism API to map a function over shards of data. Mapped function applications, or _instances_, communicate with each other via explicit collective communication operations. -`shard_map` is complementary to, and composable with, the automatic compiler-based parallelization built into `jit`. With `jit` you write code as if for a single device, and [the compiler can automatically partition computation over multiple devices](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html), generating per-device code and communication collectives behind the scenes. With `shard_map` you take control, writing your own partitioned code and explicit collectives. Or you can do a bit of both: take manual control across groups of devices while leaving within-group device partitioning up to the compiler. The two approaches can be mixed, matched, and composed as needed. +`shard_map` is complementary to, and composable with, the automatic compiler-based parallelization built into `jit`. With `jit` you write code as if for a single device, and [the compiler can automatically partition computation over multiple devices](https://docs.jax.dev/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html), generating per-device code and communication collectives behind the scenes. With `shard_map` you take control, writing your own partitioned code and explicit collectives. 
Or you can do a bit of both: take manual control across groups of devices while leaving within-group device partitioning up to the compiler. The two approaches can be mixed, matched, and composed as needed. -If you're familiar with `pmap`, think of `shard_map` as an evolution. It's more expressive, performant, and composable with other JAX APIs. It even works eagerly, for easier debugging! (For more, see [a detailed comparison to `pmap`.](https://jax.readthedocs.io/en/latest/jep/14273-shard-map.html#why-don-t-pmap-or-xmap-already-solve-this)) +If you're familiar with `pmap`, think of `shard_map` as an evolution. It's more expressive, performant, and composable with other JAX APIs. It even works eagerly, for easier debugging! (For more, see [a detailed comparison to `pmap`.](https://docs.jax.dev/en/latest/jep/14273-shard-map.html#why-don-t-pmap-or-xmap-already-solve-this)) By reading this tutorial, you'll learn how to use `shard_map` to get full control over your multi-device code. You'll see in detail how it composes with `jax.jit`'s automatic parallelization and `jax.grad`'s automatic differentiation. We'll also give some basic examples of neural network parallelization strategies. @@ -46,7 +46,6 @@ import jax import jax.numpy as jnp from jax.sharding import Mesh, PartitionSpec as P -from jax.experimental.shard_map import shard_map ``` ```{code-cell} @@ -55,7 +54,7 @@ mesh = jax.make_mesh((4, 2), ('x', 'y')) a = jnp.arange( 8 * 16.).reshape(8, 16) b = jnp.arange(16 * 4.).reshape(16, 4) -@partial(shard_map, mesh=mesh, in_specs=(P('x', 'y'), P('y', None)), +@partial(jax.shard_map, mesh=mesh, in_specs=(P('x', 'y'), P('y', None)), out_specs=P('x', None)) def matmul_basic(a_block, b_block): # a_block: f32[2, 8] @@ -161,7 +160,7 @@ devices = np.array(jax.devices()[:4]) mesh = Mesh(devices, ('i',)) # mesh.shape['i'] = 4 def check_shmap(f, y): - ans = shard_map(f, mesh, in_specs=P('i'), out_specs=P('i'))(y) + ans = jax.shard_map(f, mesh=mesh, in_specs=P('i'), out_specs=P('i'))(y) expected = jnp.concatenate([f(y_blk) for y_blk in jnp.split(y, mesh.shape['i'])]) print(allclose(ans, expected)) @@ -196,7 +195,7 @@ then there's no splitting over that mesh axis. 
For example: ```{code-cell} mesh = jax.make_mesh((4, 2), ('i', 'j')) -@partial(shard_map, mesh=mesh, in_specs=P('i', None), out_specs=P('i', 'j')) +@partial(jax.shard_map, mesh=mesh, in_specs=P('i', None), out_specs=P('i', 'j')) def f1(x_block): print(x_block.shape) # prints (3, 12) return x_block @@ -215,7 +214,7 @@ less efficient program where all mesh axes are mentioned but the caller performs a `jnp.tile`, for example: ```{code-cell} -@partial(shard_map, mesh=mesh, in_specs=P('i', 'j'), out_specs=P('i', 'j')) +@partial(jax.shard_map, mesh=mesh, in_specs=P('i', 'j'), out_specs=P('i', 'j')) def f2(x_block): print(x_block.shape) return x_block @@ -259,13 +258,13 @@ using the same mesh as above: ```{code-cell} x = jnp.array([[3.]]) -z = shard_map(lambda: x, mesh=mesh, in_specs=(), out_specs=P('i', 'j'))() +z = jax.shard_map(lambda: x, mesh=mesh, in_specs=(), out_specs=P('i', 'j'))() print(z) # prints the same as jnp.tile(x, (4, 2)) -z = shard_map(lambda: x, mesh=mesh, in_specs=(), out_specs=P('i', None))() +z = jax.shard_map(lambda: x, mesh=mesh, in_specs=(), out_specs=P('i', None))() print(z) # prints the same as jnp.tile(x, (4, 1)), or just jnp.tile(x, (4,)) -z = shard_map(lambda: x, mesh=mesh, in_specs=(), out_specs=P(None, None))() +z = jax.shard_map(lambda: x, mesh=mesh, in_specs=(), out_specs=P(None, None))() print(z) # prints the same as jnp.tile(x, (1, 1)), or just x ``` @@ -274,7 +273,7 @@ augment with a corresponding input pspec of P(None, None). As another example, following more closely to the other examples above: ```{code-cell} -@partial(shard_map, mesh=mesh, in_specs=P('i', 'j'), out_specs=P('i', None)) +@partial(jax.shard_map, mesh=mesh, in_specs=P('i', 'j'), out_specs=P('i', None)) def f3(x_block): return jax.lax.psum(x_block, 'j') @@ -291,7 +290,7 @@ two more examples where we vary which mesh axes are mentioned in the output pspec: ```{code-cell} -@partial(shard_map, mesh=mesh, in_specs=P('i', 'j'), out_specs=P(None, 'j')) +@partial(jax.shard_map, mesh=mesh, in_specs=P('i', 'j'), out_specs=P(None, 'j')) def f4(x_block): return jax.lax.psum(x_block, 'i') @@ -300,7 +299,7 @@ y4 = f4(x) print(y4.shape) # (3,12) -@partial(shard_map, mesh=mesh, in_specs=P('i', 'j'), out_specs=P(None, None)) +@partial(jax.shard_map, mesh=mesh, in_specs=P('i', 'j'), out_specs=P(None, None)) def f5(x_block): return jax.lax.psum(x_block, ('i', 'j')) @@ -328,6 +327,226 @@ Instead, `out_specs` just encodes how to assemble the block outputs into `Array`s, or physically how to interpret the buffers across devices as the physical layout of a single logical `Array`. +#### Tracking how values vary over manual mesh axes, and `check_vma=True` + +Under a `shard_map`, values can vary across function instances, or they can be +the same. For example, when we use `in_specs` to split an argument over a mesh +axis, each function instance along that mesh axis gets a different value: + +```{code-cell} +mesh = jax.make_mesh((2,), ('i',)) + +@partial(jax.shard_map, mesh=mesh, in_specs=P('i'), out_specs=P('i')) +def f(x): + print(x) + return 2 * x + +x = jnp.arange(6.) +f(x) +``` + +If instead `in_specs` does not split the argument over a mesh axis, the value +is the same for each function instance along that axis: + +```{code-cell} +@partial(jax.shard_map, mesh=mesh, in_specs=P(), out_specs=P()) +def f(x): + print(x) + return 2 * x + +x = jnp.arange(6.) +f(x) +``` + +A collective's output may have a different variance than its input. 
For +example, applying a `psum` produces the same output on each function instance +along an axis: + +```{code-cell} +@partial(jax.shard_map, mesh=mesh, in_specs=P('i'), out_specs=P()) +def f(x): + y = jax.lax.psum(x, 'i') + print(y) + return y + +x = jnp.arange(6.) +f(x) +``` + +In general, each intermediate value in a `shard_map` can be either unvarying or +possibly-varying over each manual mesh axis. That information can be tracked in +the JAX type system, enabled by the `check_vma=True` argument to `shard_map`: + +```{code-cell} +@partial(jax.shard_map, mesh=mesh, in_specs=P('i'), out_specs=P()) +def f(x): + print(jax.typeof(x)) # f32[3]{i} + y = jax.lax.psum(x, 'i') + print(jax.typeof(y)) # f32[3] + return y + +x = jnp.arange(6.) +f(x) +``` + +Here, the type `f32[3]{i}` means that the value of `x` is varying over mesh +axis `'i'`. The type of `y` printing as `f32[3]` indicates it is unvarying over +all mesh axes; that is, empty sets are not printed. We call this part of the +type the _varying manual axes_ (VMA), and it can be accessed via +`jax.typeof(x).vma`. + +In general, the VMA type of a value can include any subset of the manual mesh +axes over which the `shard_map` is acting: + +```{code-cell} +mesh = jax.make_mesh((4, 2), ('i', 'j')) + +@partial(jax.shard_map, mesh=mesh, in_specs=P('i', 'j'), out_specs=P('i')) +def f(x): + print(jax.typeof(x)) # f32[2,2]{i,j} + y = jax.lax.psum(x, 'j') + assert jax.typeof(y).vma == {'i'} + print(jax.typeof(y)) # f32[2,2]{i} + return y + +x = jnp.arange(8 * 4.).reshape(8, 4) +f(x) +``` + +Tracking varying manual axes can be useful: +1. Your code can include prints, assertions, or conditionals about whether + values are varying over expected mesh axes; +2. It enables efficient reverse-mode autodiff that doesn't require defensive + `psum`s (see [JEP](https://docs.jax.dev/en/latest/jep/17111-shmap-transpose.html)); +3. The correctness of `out_specs` can be checked, ruling out the potential bug + example below. + +For example, this `out_specs` bug is caught with `check_vma=True`, but uncaught +without it: + +```{code-cell} +mesh = jax.make_mesh((2,), ('i',)) + +x = jnp.arange(6.) +try: + y = jax.shard_map(lambda x: x, mesh=mesh, in_specs=P('i'), out_specs=P())(x) +except Exception as e: + print(e) +``` + +Here the `out_specs` incorrectly promise that each function instance along mesh +axis `'i'` produces the same value and thus we can choose just one of them. +With `check_vma=True` (the default) it raises an exception, while with +`check_vma=False` there is no exception and instead we get silent undefined +behavior. + +Sometimes we want to treat a value that is unvarying over a mesh axis as +varying over that mesh axis. That's what `jax.lax.pvary` does: + +```{code-cell} +@partial(jax.shard_map, mesh=mesh, in_specs=P(), out_specs=None) +def f(x): + print(jax.typeof(x)) # f32[6] + y = jax.lax.pvary(x, 'i') + print(jax.typeof(y)) # f32[6]{i} + +x = jnp.arange(6.) +f(x) +``` + +Think of `jax.lax.pvary` as applying a type cast: it's a no-op at runtime, +though under reverse-mode autodiff it transposes to a `jax.lax.psum` (see +[JEP](https://docs.jax.dev/en/latest/jep/17111-shmap-transpose.html)). That +makes sense because they do opposite things to the VMA: where `y: f32[3]{i} = +jax.lax.pvary(x: f32[3], 'i')`, we correspondingly have `x_grad: f32[3] = +jax.lax.psum(y_grad: f32[3]{i}, 'i')`. 
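+
+As a quick, minimal check of that transpose rule (an illustrative sketch added here,
+reusing the two-device mesh defined above), differentiating through `jax.lax.pvary`
+sums the cotangent over `'i'`:
+
+```{code-cell}
+# assumes the 2-device mesh with axis 'i' defined above
+def loss(x):
+  f = jax.shard_map(lambda x: jax.lax.pvary(x, 'i'),
+                    mesh=mesh, in_specs=P(), out_specs=P('i'))
+  return f(x).sum()  # the output is x repeated once per instance along 'i'
+
+print(jax.grad(loss)(jnp.ones(3)))  # [2. 2. 2.]: each instance's cotangent is psum'd over 'i'
+```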
+ +JAX implicitly inserts `jax.lax.pvary` calls in many cases, especially for +binary operations: + +```{code-cell} +@partial(jax.shard_map, mesh=mesh, in_specs=(P('i'), P()), out_specs=P('i')) +def f(x, y): + return x * y + +x = jnp.arange(6.) +y = jnp.arange(3.) +print(jax.make_jaxpr(f)(x, y)) +``` + +In a jaxpr, the multiplication operation requires the VMA types of its +arguments to match, but for convenience the `jax.numpy` and `jax.lax` APIs +automatically apply `jax.lax.pvary` to make argument VMA types agree. + + + +In some cases, like with `jax.lax.scan`, you might need to apply +`jax.lax.pvary` yourself to ensure VMA types match as required. For example, +this code raises an error: + +```{code-cell} +mesh = jax.make_mesh((2,), ('i',)) + +@partial(jax.shard_map, mesh=mesh, in_specs=(P('i'), P()), out_specs=P('i')) +def f(x, y): + def body(carry, _): + c1, c2 = carry + return (c2, c1), () # swap the carry + (x_, y_), _ = jax.lax.scan(body, (x, y), (), length=2) + return x_, y_ + +x = jnp.arange(6.) +y = jnp.arange(3.) + +try: + f(x, y) +except Exception as e: + print(e) +``` + +To make the types match, we need to apply `jax.lax.pvary` to some arguments to +the `scan`: + +```{code-cell} +mesh = jax.make_mesh((2,), ('i',)) + +@partial(jax.shard_map, mesh=mesh, in_specs=(P('i'), P()), out_specs=P('i')) +def f(x, y): + def body(carry, _): + c1, c2 = carry + return (c2, c1), () # swap the carry + + y = jax.lax.pvary(y, 'i') # apply pvary to fix the error + (x_, y_), _ = jax.lax.scan(body, (x, y), (), length=2) + return x_, y_ + +x = jnp.arange(6.) +y = jnp.arange(3.) + +f(x, y) +``` + +Here's a summary of collective primitives and how they affect varying manual axis types: + +| Name | Device variance type | Example | Lowers to HLO | Transpose | +| --- | --- | --- | --- | --- | +| `psum_invariant` | `Varying -> Invariant` | `y:f32[3]{j} = psum(x:f32[3]{i,j}, axis='i')` | `AllReduceSum` (communication) | `pvary` | +| `pvary` | `Invariant -> Varying` | `y:f32[3]{i} = pvary(x:f32[3], 'i')` | no-op (no communication) | `psum_invariant` | +| `all_to_all` | `Varying -> Varying` | `y:f32[16]{i} = all_to_all(x:f32[16]{i}, 'i', 0, 0)` `AllToAll` (communication) | `all_to_all` | +| `axis_index` | `() -> Varying` | `idx:i32[]{i} = axis_index('i')` | `ReplicaId` and some arithmetic (no communication) | n/a | +| `psum_scatter` | `Varying -> Varying` | `y:f32[2]{i} = psum_scatter(x:f32[16]{i}, 'i')` | `ReduceScatterSum` (communication) | `all_gather` | +| `all_gather` | `Varying -> Varying` | `y:f32[16]{i} = all_gather(x:f32[2]{i}, 'i')` | `AllGather` (communication) | `psum_scatter` | +| `pscatter` | `Invariant -> Varying` | `y:f32[2]{i} = pscatter(x:f32[16], 'i')` | `lambda x: x[axis_index('i'), None]` (no communication) | `all_gather_invariant` | +| `all_gather_invariant` | `Varying -> Invariant` | `y:f32[16] = all_gather_invariant(x:f32[2]{i}, 'i')` | `AllGather` (communication) | `pscatter` | + +A few notes on the table: +* The function `jax.lax.psum` is a convenience wrapper around `psum_invariant`. +* It's surprising that `all_gather` is `Varying -> Varying`, but that's because + it's really the transpose of `psum_scatter` which is `Varying -> Varying`. +* Neither `pscatter` nor `all_gather_invariant` have user APIs at the time of + writing, but they're described here for completeness. 
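+
+As one more illustrative sketch (not part of the table discussion above, and again
+using the two-device mesh from earlier), `jax.typeof` confirms two of the rows:
+`axis_index` produces a varying value without communication, and `psum_scatter`
+keeps its operand varying:
+
+```{code-cell}
+# assumes the 2-device mesh with axis 'i' defined above
+@partial(jax.shard_map, mesh=mesh, in_specs=P('i'), out_specs=P('i'))
+def g(x_block):
+  idx = jax.lax.axis_index('i')                       # () -> Varying
+  y = jax.lax.psum_scatter(x_block, 'i', tiled=True)  # Varying -> Varying
+  print(jax.typeof(idx), jax.typeof(y))               # i32[]{i} f32[2]{i}
+  return y
+
+g(jnp.arange(8.))
+```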
+ + ## API Specification ```python @@ -335,18 +554,21 @@ from jax.sharding import Mesh Specs = PyTree[PartitionSpec] def shard_map( - f: Callable, mesh: Mesh, in_specs: Specs, out_specs: Specs, - auto: collections.abc.Set[AxisName] = frozenset([]), - check_rep: bool = True, + f: Callable, /, *, out_specs: Specs, mesh: Mesh | None = None, + in_specs: Specs | None = None, + axis_names: collections.abc.Set[AxisName] = set(), + check_vma: bool = True, ) -> Callable: ... ``` where: * communication collectives like `psum` in the body of `f` can mention the axis names of `mesh`; -* `mesh` encodes devices arranged in an array and with associated axis names, just like it does for `sharding.NamedSharding`; -* `in_specs` and `out_specs` are `PartitionSpec`s which can affinely mention axis names from `mesh` to express slicing/unconcatenation and concatenation of inputs and outputs, respectively, with unmentioned names corresponding to replication and untiling (assert-replicated-so-give-me-one-copy), respectively; -* `auto` is an optional set of axis names corresponding to the subset of names of `mesh` to treat automatically in the body, as in the caller, rather than manually; -* `check_rep` is an optional boolean indicating whether to check statically for any replication errors in `out_specs`, and also whether to enable a related automatic differentiation optimization (see [JEP](https://jax.readthedocs.io/en/latest/jep/17111-shmap-transpose.html)). +* `mesh` encodes devices arranged in an array and with associated axis names, just like it does for `sharding.NamedSharding`; If None, mesh will be inferred from the +context which can be set via the `jax.sharding.use_mesh` context manager. +* `in_specs` are `PartitionSpec`s which can zero or one times mention axis names from `mesh` to express slicing/unconcatenation of inputs, respectively, with unmentioned names corresponding to replication and untiling (assert-replicated-so-give-me-one-copy). If None, all mesh axes must be of type `Explicit`, in which case the in_specs are inferred from the argument types; +* `out_specs` are `PartitionSpec`s which can zero or one times mention axis names from `mesh` to express concatenation of outputs, with unmentioned names corresponding to replication and untiling (assert-replicated-so-give-me-one-copy), respectively; +* `axis_names` is an optional set of axis names corresponding to the subset of names of `mesh` to treat manual in the body. If empty, `f` is manual over all axes of the mesh. +* `check_vma` is an optional boolean indicating whether to check statically for any replication errors in `out_specs`, and also whether to enable a related automatic differentiation optimization (see [JEP](https://docs.jax.dev/en/latest/jep/17111-shmap-transpose.html)). The shapes of the arguments passed to `f` have the same ranks as the arguments passed to `shard_map`-of-`f`, and the shape of an argument to `f` is computed @@ -368,7 +590,7 @@ so that this: ```python mesh = Mesh(jax.devices(), ('i',)) x = jnp.arange(16.) 
-f_shmapped = shard_map(f, mesh, in_specs=P('i'), out_specs=P('i')) +f_shmapped = jax.shard_map(f, mesh=mesh, in_specs=P('i'), out_specs=P('i')) y = f_shmapped(x) ``` @@ -434,13 +656,12 @@ import jax.numpy as jnp from jax import lax from jax.sharding import Mesh, NamedSharding, PartitionSpec as P -from jax.experimental.shard_map import shard_map ``` ```{code-cell} mesh1d = Mesh(jax.devices()[:4], ('i',)) -@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P(None)) +@partial(jax.shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P(None)) def f1(x_block): print('BEFORE:\n', x_block) y_block = jax.lax.psum(x_block, 'i') @@ -478,7 +699,7 @@ each one separately, or over multiple axes at once: ```{code-cell} mesh2d = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('i', 'j')) -@partial(shard_map, mesh=mesh2d, in_specs=P('i', 'j'), out_specs=P(None, 'j')) +@partial(jax.shard_map, mesh=mesh2d, in_specs=P('i', 'j'), out_specs=P(None, 'j')) def f2(x_block): print('BEFORE:\n', x_block) y_block = jax.lax.psum(x_block, 'i') @@ -497,7 +718,7 @@ If we apply the `psum` over both axes, the `y_block` value is equal along both axes: ```{code-cell} -@partial(shard_map, mesh=mesh2d, in_specs=P('i', 'j'), out_specs=P(None, None)) +@partial(jax.shard_map, mesh=mesh2d, in_specs=P('i', 'j'), out_specs=P(None, None)) def f3(x_block): print('BEFORE:\n', x_block) y_block = jax.lax.psum(x_block, ('i', 'j')) @@ -522,7 +743,7 @@ each function application has a full copy of the data along that axis: Illustration of an all_gather computation. ```{code-cell} -@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i')) +@partial(jax.shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i')) def f4(x_block): print('BEFORE:\n', x_block) y_block = jax.lax.all_gather(x_block, 'i', tiled=True) @@ -549,7 +770,7 @@ When `tiled=False` (the default), results are stacked along a new axis instead of concatenated: ```{code-cell} -@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i')) +@partial(jax.shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i')) def f5(x_block): print('BEFORE:\n', x_block) y_block = jax.lax.all_gather(x_block, 'i', tiled=False) @@ -580,7 +801,7 @@ The `jax.lax.psum_scatter` collective is a bit less intuitive. It's like Illustration of a psum_scatter computation. ```{code-cell} -@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i')) +@partial(jax.shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i')) def f6(x_block): print('BEFORE:\n', x_block) y_block = jax.lax.psum_scatter(x_block, 'i', tiled=True) @@ -644,9 +865,9 @@ that mesh axis, `ppermute` sends its argument value from each source function instance to each destination: ```{code-cell} -@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i')) +@partial(jax.shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i')) def f7(x_block): - sz = jax.lax.psum(1, 'i') + sz = jax.lax.axis_size('i') print('BEFORE:\n', x_block) y_block = jax.lax.ppermute(x_block, 'i', [(i, (i + 1) % sz) for i in range(sz)]) print('AFTER:\n', y_block) @@ -691,7 +912,7 @@ this iteration. 
In code, it might look like this: ```{code-cell} def psum_scatter(x, axis_name, *, tiled=False): - size = jax.lax.psum(1, axis_name) + size = jax.lax.axis_size(axis_name) idx = jax.lax.axis_index(axis_name) # function instance index along axis_name if tiled: x = x.reshape(size, -1, *x.shape[1:]) # split leading axis @@ -704,7 +925,7 @@ def psum_scatter(x, axis_name, *, tiled=False): ``` ```{code-cell} -@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i')) +@partial(jax.shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i')) def f8(x_block): print('BEFORE:\n', x_block) y_block = psum_scatter(x_block, 'i', tiled=True) @@ -740,7 +961,7 @@ transpose operating along one positional axis and one cross-device axis: Illustration of an all_to_all computation. ```{code-cell} -@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i')) +@partial(jax.shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i')) def f9(x_block): print('BEFORE:\n', x_block) y_block = jax.lax.all_to_all(x_block, 'i', split_axis=0, concat_axis=0, @@ -801,7 +1022,6 @@ import jax import jax.numpy as jnp from jax.sharding import Mesh, NamedSharding, PartitionSpec as P -from jax.experimental.shard_map import shard_map ``` ```{code-cell} @@ -835,7 +1055,7 @@ side: ```{code-cell} @jax.jit -@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec), +@partial(jax.shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec), out_specs=rhs_spec) def matmul_allgather(lhs_block, rhs_block): rhs = jax.lax.all_gather(rhs_block, 'i', tiled=True) @@ -861,10 +1081,10 @@ multiplies: ```{code-cell} @jax.jit -@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec), +@partial(jax.shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec), out_specs=rhs_spec) def matmul_allgather_overlapped(lhs_block, rhs_block): - size = jax.lax.psum(1, 'i') + size = jax.lax.axis_size('i') idx = jax.lax.axis_index('i') shift = partial(jax.lax.ppermute, axis_name='i', perm=[(i, (i + 1) % size) for i in range(size)]) @@ -892,10 +1112,10 @@ each half in each direction: ```{code-cell} @jax.jit -@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec), +@partial(jax.shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec), out_specs=rhs_spec) def matmul_allgather_overlapped_bidi(lhs_block, rhs_block): - size = jax.lax.psum(1, 'i') + size = jax.lax.axis_size('i') idx = jax.lax.axis_index('i') shift_up = partial(jax.lax.ppermute, axis_name='i', perm=[(i, (i + 1) % size) for i in range(size)]) @@ -943,7 +1163,7 @@ rhs = device_put(rhs, rhs_spec) Here we can use a `reduce_scatter` to perform the contraction sum over shards: ```{code-cell} -@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec), +@partial(jax.shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec), out_specs=rhs_spec) def matmul_psumscatter(lhs_block, rhs_block): out_summand = lhs_block @ rhs_block @@ -959,10 +1179,10 @@ inline an implementation of `psum_scatter` in terms of `ppermute`, then interleave the communication steps with local matrix multiplies: ```{code-cell} -@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec), +@partial(jax.shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec), out_specs=rhs_spec) def matmul_psumscatter_overlapped(lhs_block, rhs_block): - size = jax.lax.psum(1, 'i') + size = jax.lax.axis_size('i') idx = jax.lax.axis_index('i') shift = partial(jax.lax.ppermute, axis_name='i', perm=[(i, (i - 1) % size) for i in range(size)]) @@ -984,10 +1204,10 @@ As in the previous example, to fully utilize interconnects on TPU, we'd run a bidirectional version: 
```{code-cell} -@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec), +@partial(jax.shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec), out_specs=rhs_spec) def matmul_psumscatter_overlapped_bidi(lhs_block, rhs_block): - size = jax.lax.psum(1, 'i') + size = jax.lax.axis_size('i') idx = jax.lax.axis_index('i') shift_up = partial(jax.lax.ppermute, axis_name='i', perm=[(i, (i + 1) % size) for i in range(size)]) @@ -1061,7 +1281,7 @@ params, batch = init(jax.random.key(0), layer_sizes, batch_size) Compare these examples with the purely [automatic partitioning examples in the "Distributed arrays and automatic partitioning" -doc](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html). +doc](https://docs.jax.dev/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html). While in those automatic partitioning examples we don't need to edit the model functions to use different parallelization strategies, with `shard_map` we often do. @@ -1079,7 +1299,6 @@ all-reduce-sums of parameter gradients in the backward pass.) from functools import partial from jax.sharding import NamedSharding, Mesh, PartitionSpec as P -from jax.experimental.shard_map import shard_map mesh = jax.make_mesh((8,), ('batch',)) @@ -1089,7 +1308,7 @@ params = jax.device_put(params, NamedSharding(mesh, P())) # adapt the loss function to sum the losses across devices def loss_dp(params, batch): - @partial(shard_map, mesh=mesh, in_specs=P('batch', None), out_specs=P()) + @partial(jax.shard_map, mesh=mesh, in_specs=P('batch', None), out_specs=P()) def loss_spmd(local_batch): inputs, targets = local_batch predictions = predict(params, inputs) # use reference 'predict` @@ -1137,7 +1356,7 @@ There's one other ingredient we need: we don't want to store the fully gathered parameters from the forward pass for use on the backward pass. Instead, we want to gather them again on the backward pass. We can express that by using `jax.remat` with a [custom -policy](https://jax.readthedocs.io/en/latest/notebooks/autodiff_remat.html#custom-policies-for-what-s-saveable) +policy](https://docs.jax.dev/en/latest/notebooks/autodiff_remat.html#custom-policies-for-what-s-saveable) (or a `custom_vjp`), though XLA typically does that rematerialization automatically. 
@@ -1164,7 +1383,7 @@ def predict_fsdp(params_frag, inputs): return outputs def loss_fsdp(params, batch): - @partial(shard_map, mesh=mesh, in_specs=P('batch'), out_specs=P()) + @partial(jax.shard_map, mesh=mesh, in_specs=P('batch'), out_specs=P()) def loss_spmd(local_params, local_batch): inputs, targets = local_batch predictions = predict_fsdp(local_params, inputs) @@ -1209,7 +1428,7 @@ def predict_tp(params, inputs): inputs = jax.nn.relu(outputs) return outputs -@partial(shard_map, mesh=mesh, +@partial(jax.shard_map, mesh=mesh, in_specs=(P(None, 'feats'), P('feats', None), P('feats')), out_specs=P(None, 'feats')) def gemm_tp(inputs, W, b): @@ -1245,7 +1464,7 @@ def predict_fsdp_tp(params_frag, inputs): inputs = jax.nn.relu(outputs) return outputs -@partial(shard_map, mesh=mesh, +@partial(jax.shard_map, mesh=mesh, in_specs=(P(('feats', 'batch')), P('batch', 'feats')), out_specs=P()) def loss_fsdp_tp(local_params, local_batch): @@ -1325,7 +1544,7 @@ def predict_pp(params, inputs): outputs = jnp.dot(inputs, W_last) + b_last return outputs -@partial(shard_map, mesh=mesh, in_specs=((P(), P('stages'), P()), P('stages')), +@partial(jax.shard_map, mesh=mesh, in_specs=((P(), P('stages'), P()), P('stages')), out_specs=P()) def loss_pp(params, batch): inputs, targets = batch diff --git a/docs/notebooks/thinking_in_jax.ipynb b/docs/notebooks/thinking_in_jax.ipynb index 5ddcdd32e2b4..28d5f20deab6 100644 --- a/docs/notebooks/thinking_in_jax.ipynb +++ b/docs/notebooks/thinking_in_jax.ipynb @@ -139,7 +139,7 @@ { "data": { "text/plain": [ - "jaxlib.xla_extension.ArrayImpl" + "jaxlib._jax.ArrayImpl" ] }, "execution_count": 4, @@ -248,7 +248,7 @@ "id": "yRYF0YgO3F4H" }, "source": [ - "For updating individual elements, JAX provides an [indexed update syntax](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html#jax-numpy-ndarray-at) that returns an updated copy:" + "For updating individual elements, JAX provides an [indexed update syntax](https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html#jax-numpy-ndarray-at) that returns an updated copy:" ] }, { @@ -423,7 +423,7 @@ "id": "0GPqgT7S0q8r" }, "source": [ - "Under the hood, this NumPy operation is translated to a much more general convolution implemented by [`lax.conv_general_dilated`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv_general_dilated.html):" + "Under the hood, this NumPy operation is translated to a much more general convolution implemented by [`lax.conv_general_dilated`](https://docs.jax.dev/en/latest/_autosummary/jax.lax.conv_general_dilated.html):" ] }, { @@ -461,7 +461,7 @@ "id": "7mdo6ycczlbd" }, "source": [ - "This is a batched convolution operation designed to be efficient for the types of convolutions often used in deep neural nets. It requires much more boilerplate, but is far more flexible and scalable than the convolution provided by NumPy (See [Convolutions in JAX](https://jax.readthedocs.io/en/latest/notebooks/convolutions.html) for more detail on JAX convolutions).\n", + "This is a batched convolution operation designed to be efficient for the types of convolutions often used in deep neural nets. 
It requires much more boilerplate, but is far more flexible and scalable than the convolution provided by NumPy (See [Convolutions in JAX](https://docs.jax.dev/en/latest/notebooks/convolutions.html) for more detail on JAX convolutions).\n", "\n", "At their heart, all `jax.lax` operations are Python wrappers for operations in XLA; here, for example, the convolution implementation is provided by [XLA:ConvWithGeneralPadding](https://www.tensorflow.org/xla/operation_semantics#convwithgeneralpadding_convolution).\n", "Every JAX operation is eventually expressed in terms of these fundamental XLA operations, which is what enables just-in-time (JIT) compilation." @@ -562,7 +562,7 @@ "id": "3GvisB-CA9M8" }, "source": [ - "But due to the compilation (which includes fusing of operations, avoidance of allocating temporary arrays, and a host of other tricks), execution times can be orders of magnitude faster in the JIT-compiled case (note the use of `block_until_ready()` to account for JAX's [asynchronous dispatch](https://jax.readthedocs.io/en/latest/async_dispatch.html)):" + "But due to the compilation (which includes fusing of operations, avoidance of allocating temporary arrays, and a host of other tricks), execution times can be orders of magnitude faster in the JIT-compiled case (note the use of `block_until_ready()` to account for JAX's [asynchronous dispatch](https://docs.jax.dev/en/latest/async_dispatch.html)):" ] }, { @@ -650,7 +650,7 @@ "evalue": "ignored", "output_type": "error", "traceback": [ - "\u001b[0;31mNonConcreteBooleanIndexError\u001b[0m\u001b[0;31m:\u001b[0m Array boolean indices must be concrete; got ShapedArray(bool[10])\n\nSee https://jax.readthedocs.io/en/latest/errors.html#jax.errors.NonConcreteBooleanIndexError\n" + "\u001b[0;31mNonConcreteBooleanIndexError\u001b[0m\u001b[0;31m:\u001b[0m Array boolean indices must be concrete; got ShapedArray(bool[10])\n\nSee https://docs.jax.dev/en/latest/errors.html#jax.errors.NonConcreteBooleanIndexError\n" ] } ], @@ -835,7 +835,7 @@ "evalue": "ignored", "output_type": "error", "traceback": [ - "\u001b[0;31mTracerBoolConversionError\u001b[0m\u001b[0;31m:\u001b[0m Attempted boolean conversion of traced array with shape bool[]..\nThe error occurred while tracing the function f at :1 for jit. This concrete value was not available in Python because it depends on the value of the argument neg.\nSee https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError\n" + "\u001b[0;31mTracerBoolConversionError\u001b[0m\u001b[0;31m:\u001b[0m Attempted boolean conversion of traced array with shape bool[]..\nThe error occurred while tracing the function f at :1 for jit. 
This concrete value was not available in Python because it depends on the value of the argument neg.\nSee https://docs.jax.dev/en/latest/errors.html#jax.errors.TracerBoolConversionError\n" ] } ], diff --git a/docs/notebooks/thinking_in_jax.md b/docs/notebooks/thinking_in_jax.md index 0693f6ba8579..7b0bb0d9b8ce 100644 --- a/docs/notebooks/thinking_in_jax.md +++ b/docs/notebooks/thinking_in_jax.md @@ -117,7 +117,7 @@ x[0] = 10 +++ {"id": "yRYF0YgO3F4H"} -For updating individual elements, JAX provides an [indexed update syntax](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html#jax-numpy-ndarray-at) that returns an updated copy: +For updating individual elements, JAX provides an [indexed update syntax](https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html#jax-numpy-ndarray-at) that returns an updated copy: ```{code-cell} ipython3 :id: 8zqPEAeP3UK5 @@ -189,7 +189,7 @@ jnp.convolve(x, y) +++ {"id": "0GPqgT7S0q8r"} -Under the hood, this NumPy operation is translated to a much more general convolution implemented by [`lax.conv_general_dilated`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv_general_dilated.html): +Under the hood, this NumPy operation is translated to a much more general convolution implemented by [`lax.conv_general_dilated`](https://docs.jax.dev/en/latest/_autosummary/jax.lax.conv_general_dilated.html): ```{code-cell} ipython3 :id: pi4f6ikjzc3l @@ -206,7 +206,7 @@ result[0, 0] +++ {"id": "7mdo6ycczlbd"} -This is a batched convolution operation designed to be efficient for the types of convolutions often used in deep neural nets. It requires much more boilerplate, but is far more flexible and scalable than the convolution provided by NumPy (See [Convolutions in JAX](https://jax.readthedocs.io/en/latest/notebooks/convolutions.html) for more detail on JAX convolutions). +This is a batched convolution operation designed to be efficient for the types of convolutions often used in deep neural nets. It requires much more boilerplate, but is far more flexible and scalable than the convolution provided by NumPy (See [Convolutions in JAX](https://docs.jax.dev/en/latest/notebooks/convolutions.html) for more detail on JAX convolutions). At their heart, all `jax.lax` operations are Python wrappers for operations in XLA; here, for example, the convolution implementation is provided by [XLA:ConvWithGeneralPadding](https://www.tensorflow.org/xla/operation_semantics#convwithgeneralpadding_convolution). Every JAX operation is eventually expressed in terms of these fundamental XLA operations, which is what enables just-in-time (JIT) compilation. 
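As a quick, standalone illustration of the indexed update syntax mentioned above (not taken from the notebook), `.at[...]` returns an updated copy and leaves the original array untouched:

```python
import jax.numpy as jnp

x = jnp.arange(5)
y = x.at[0].set(10)  # functional update: returns a new array

print(x)  # [0 1 2 3 4]   (unchanged)
print(y)  # [10  1  2  3  4]
```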
@@ -261,7 +261,7 @@ np.allclose(norm(X), norm_compiled(X), atol=1E-6) +++ {"id": "3GvisB-CA9M8"} -But due to the compilation (which includes fusing of operations, avoidance of allocating temporary arrays, and a host of other tricks), execution times can be orders of magnitude faster in the JIT-compiled case (note the use of `block_until_ready()` to account for JAX's [asynchronous dispatch](https://jax.readthedocs.io/en/latest/async_dispatch.html)): +But due to the compilation (which includes fusing of operations, avoidance of allocating temporary arrays, and a host of other tricks), execution times can be orders of magnitude faster in the JIT-compiled case (note the use of `block_until_ready()` to account for JAX's [asynchronous dispatch](https://docs.jax.dev/en/latest/async_dispatch.html)): ```{code-cell} ipython3 :id: 6mUB6VdDAEIY diff --git a/docs/notes.rst b/docs/notes.rst index 08265638000e..502385142b16 100644 --- a/docs/notes.rst +++ b/docs/notes.rst @@ -9,9 +9,6 @@ Dependencies and version compatibility: - :doc:`api_compatibility` outlines JAX's policies with regard to API compatibility across releases. - :doc:`deprecation` outlines JAX's policies with regard to compatibility with Python and NumPy. -Migrations and deprecations: - - :doc:`jax_array_migration` summarizes the changes to the default array type in jax v 0.4.1 - Memory and computation usage: - :doc:`async_dispatch` describes JAX's asynchronous dispatch model. - :doc:`concurrency` describes how JAX interacts with other Python concurrency. @@ -20,6 +17,10 @@ Memory and computation usage: Programmer guardrails: - :doc:`rank_promotion_warning` describes how to configure :mod:`jax.numpy` to avoid implicit rank promotion. +Arrays and data types: + - :doc:`type_promotion` describes JAX's implicit type promotion for functions of two or more values. + - :doc:`default_dtypes` describes how JAX determines the default dtype for array creation functions. + .. toctree:: :hidden: @@ -27,8 +28,9 @@ Programmer guardrails: api_compatibility deprecation - jax_array_migration async_dispatch concurrency gpu_memory_allocation - rank_promotion_warning \ No newline at end of file + rank_promotion_warning + type_promotion + default_dtypes diff --git a/docs/pallas/CHANGELOG.md b/docs/pallas/CHANGELOG.md index 2b1cad7c9a66..e3589b87b720 100644 --- a/docs/pallas/CHANGELOG.md +++ b/docs/pallas/CHANGELOG.md @@ -2,15 +2,54 @@ # Pallas Changelog - + This is the list of changes specific to {class}`jax.experimental.pallas`. -For the overall JAX change log see [here](https://jax.readthedocs.io/en/latest/changelog.html). +For the overall JAX change log see [here](https://docs.jax.dev/en/latest/changelog.html). +## Unreleased + +* New functionality + + * Added a new decorator {func}`jax.experimental.pallas.loop` which allows + to write stateless loops as functions. + +* Deprecations + + * {class}`jax.experimental.pallas.triton.TritonCompilerParams` has been + renamed to {class}`jax.experimental.pallas.triton.CompilerParams`. The + old name is deprecated and will be removed in a future release. + * {class}`jax.experimental.pallas.tpu.TPUCompilerParams` + and {class}`jax.experimental.pallas.tpu.TPUMemorySpace` have been + renamed to {class}`jax.experimental.pallas.tpu.CompilerParams` + and {class}`jax.experimental.pallas.tpu.MemorySpace`. The + old names are deprecated and will be removed in a future release. + +## Released with jax 0.6.1 + +* Removals + + * Removed previously deprecated {mod}`jax.experimental.pallas.gpu`. 
To use + the Triton backend import {mod}`jax.experimental.pallas.triton`. + +* Changes + + * {func}`jax.experimental.pallas.BlockSpec` now takes in special types in + addition to ints/None in the `block_shape`. `indexing_mode` has been + removed. To achieve "Unblocked", pass a `pl.Element(size)` into + `block_shape` for each entry that needs unblocked indexing. + * {func}`jax.experimental.pallas.pallas_call` now requires `compiler_params` + to be a backend-specific dataclass instead of a param to value mapping. + * {func}`jax.experimental.pallas.debug_check` is now supported both on + TPU and Mosaic GPU. Previously, this functionality was only supported + on TPU and required using the APIs from {mod}`jax.experimental.checkify`. + Note that debug checks are not executed unless + {data}`jax.experimental.pallas.enable_debug_checks` is set. + ## Released with jax 0.5.0 * New functionality diff --git a/docs/pallas/design/async_note.md b/docs/pallas/design/async_note.md index 42e32a074fd7..b255a91d3ec8 100644 --- a/docs/pallas/design/async_note.md +++ b/docs/pallas/design/async_note.md @@ -1,3 +1,4 @@ +(pallas_async)= # Pallas Async Operations ## Background \+ Motivation @@ -463,7 +464,7 @@ def f(x): return fori_loop(0, 8, body, x) ``` -If you run the alias analysis, you’ll find that all of the buffers have been colored the same\! Intuitively, this is problematic because if we are doing a loop of `ppermute`s, we can’t write into the same buffer we are sending into. We generally need an extra (i.e. a “double”) buffer to receive, and then usually we will switch the send/recv buffers on the next iteration. What XLA will do in practice is that it will observe the buffer re-use and defensively insert a copy. +If you run the alias analysis, you’ll find that all of the buffers have been colored the same\! Intuitively, this is problematic because if we are doing a loop of `ppermute`s, we can’t write into the same buffer we are sending into. We generally need an extra (i.e. a “double”) buffer to receive, and then usually we will switch the send/recv buffers on the next iteration. What XLA will do in practice is that it will observe the buffer reuse and defensively insert a copy. ```py def f(x): diff --git a/docs/pallas/design/design.md b/docs/pallas/design/design.md index 17c7a6dbdc0f..53a5eb209510 100644 --- a/docs/pallas/design/design.md +++ b/docs/pallas/design/design.md @@ -71,7 +71,7 @@ A JAX-based kernel language offers several advantages: * JAX as a tracing-based frontend for numerical computing is both mature and well-used. By embedding the kernel programming language in JAX itself, - we can re-use JAX’s tracing infrastructure and provide a + we can reuse JAX’s tracing infrastructure and provide a NumPy-like frontend that’s already familiar to users. * JAX transformations are key to its success, allowing users to express simple programs but transform them to achieve complex @@ -551,7 +551,7 @@ along that dimension. `grad` of `pallas_call` enables automatic differentiation of kernels. `jax.grad` breaks down into applications of three distinct transforms: `jvp`, `partial_eval` and `transpose`. -In principle, we can re-use most of JAX’s infrastructure when +In principle, we can reuse most of JAX’s infrastructure when implementing these rules for `pallas_call` (since it behaves much like existing JAX higher order primitives). 
diff --git a/docs/pallas/gpu/index.rst b/docs/pallas/gpu/index.rst new file mode 100644 index 000000000000..3fec14832337 --- /dev/null +++ b/docs/pallas/gpu/index.rst @@ -0,0 +1,15 @@ +Pallas:Mosaic GPU +================= +Backend specific documentation for the Mosaic GPU backend. + +.. toctree:: + :caption: Reference documentation + :maxdepth: 2 + + reference + pipelining + +.. toctree:: + :caption: Guides + :maxdepth: 2 + diff --git a/docs/pallas/gpu/pipelining.ipynb b/docs/pallas/gpu/pipelining.ipynb new file mode 100644 index 000000000000..c1bcc27c2dbf --- /dev/null +++ b/docs/pallas/gpu/pipelining.ipynb @@ -0,0 +1,428 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "9552ee76", + "lines_to_next_cell": 0 + }, + "source": [ + "(pallas_mgpu_pipelining)=" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "bJ5yuIr-M0x0" + }, + "source": [ + "\n", + "## Mosaic GPU Pipelining\n", + "\n", + "This guide covers software pipelining using the Mosaic GPU backend for Pallas.\n", + "\n", + "For a general overview of the pipelining API in Pallas, we recommend that users first read {ref}`pallas_software_pipelining`. Pipelining in Pallas is programmed explicitly. For those who are familiar with Triton, this is a significant difference in programming model because in Triton, pipelining is an optimization that is done automatically by the compiler.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "dGAa3iO5DoRT" + }, + "outputs": [], + "source": [ + "import jax\n", + "from jax import lax\n", + "from jax import numpy as jnp\n", + "from jax.experimental.pallas import mosaic_gpu as plgpu\n", + "from jax.experimental import pallas as pl\n", + "import numpy as np" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Pv9j90hVyswo" + }, + "source": [ + "\n", + "### Pipelining with Mosaic GPU\n", + "\n", + "The recommended approach to pipeline using Mosaic GPU is to use the `plgpu.emit_pipeline` function to pipeline over sequential loops (and to use `plgpu.kernel` to partition the problem in parallel over the CUDA grid). `emit_pipeline` follows a similar API as `pl.pallas_call` except it exposes a few additional GPU-specific options.\n", + "\n", + "- `body`, `grid` have similar semantics as in `pl.pallas_call`. The `grid` denotes how many invocations of the `body` function to run. In contrast with a CUDA grid, the pipeline grid is guaranteed to run sequentially.\n", + "- `in_specs` and `out_specs` also work similarly to `pl.pallas_call`, except they also accept `plgpu.BlockSpec` instances that can be used specify GPU-specific transforms, such as swizzling. See [memory reference transforms](https://docs.jax.dev/en/latest/pallas/gpu/reference.html#memory-reference-transforms) for more detail on available transformations.\n", + "- `max_concurrent_steps` controls the maximum number of concurrent memory transfers. Using additional concurrent steps will consume more SMEM to hold temporary buffers, but it can improve the utilization of the memory subsystem. We recommend autotuning this parameter. Low values (e.g. 2) can sometimes achieve higher occupancy (due to lower SMEM usage) which can improve throughput in ALU-heavy kernels, but will introduce more noise due to the hardware taking care of scheduling. 
Larger values (between 4 and 6) will work best for kernels that can't take advantage of extra occupancy.\n", + "- `delay_release` allows the user to specify an additional number of iterations to wait before the buffer is reused by the pipeline. For example, a buffer copied into SMEM on iteration 0 with `delay_release=1` and `max_concurrent_steps=2` will not be reused until iteration 3, as opposed to iteration 2 for a standard double-buffered strategy. `delay_release=1` is necessary if you don't await a `plgpu.wgmma` operation on the pipeline operands, as otherwise the pipeline will begin overwriting the buffers while the WGMMA is still reading them. This is useful for certain optimizations such as allowing multiple async matmuls in flight to keep the tensor core pipeline filled, but care must be taken when using such a strategy, as **omitting this parameter will lead to silent data races**, and it reduces the efficiency of `emit_pipeline`, since we overlap fewer memory transfers.\n", + "\n", + "#### Compatibility API using `pl.pallas_call`\n", + "\n", + "As an alternative to `emit_pipeline`, and to maintain compatibility with Pallas TPU, Mosaic GPU also implements the existing `pl.pallas_call` API. By default, `pl.pallas_call` on Mosaic GPU will partition your kernel in parallel over the CUDA grid. You can opt in to pipelining by passing in a `plgpu.GPUCompilerParams` object as the `compiler_params` argument, which specifies the following options that are relevant for pipelining:\n", + "- `dimension_semantics`: A tuple of `Literal['parallel', 'sequential']` that specifies iteration semantics for each grid dimension. `parallel` will partition the corresponding dimension over the CUDA grid, and `sequential` dimensions will be pipelined sequentially. **Note that if no dimensions are marked `sequential`, no pipelining will happen!**\n", + "- `max_concurrent_steps`: identical to the option in `plgpu.emit_pipeline`.\n", + "- `delay_release`: identical to the option in `plgpu.emit_pipeline`.\n", + "\n", + "Pipelining lets you reuse scratch buffers across the sequential iterations of the grid (e.g. for implementing reductions). Additionally, `pallas_call` supports using `plgpu.BlockSpec` objects in place of `pl.BlockSpec` objects when using the Mosaic GPU backend, allowing you to specify GPU-specific memory transformations.\n", + "\n", + "We recommend that users use `plgpu.kernel` rather than `pl.pallas_call`, as `plgpu.kernel` supports more features (such as specifying the number of warpgroups and warp specialization).\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Qp3X6wylJtoa" + }, + "source": [ + "### GPU Memory Spaces\n", + "\n", + "Refs exist primarily in one of two memory spaces, which can be explicitly specified by the `memory_space` argument of `BlockSpec`, i.e. `BlockSpec(memory_space=plgpu.GPUMemorySpace.GMEM)`.\n", + "\n", + "- `plgpu.GPUMemorySpace.SMEM` allocates a Ref in Shared Memory (SMEM). SMEM Refs can be dereferenced using array indexing syntax to store values in registers for compute, i.e. `x = y_ref[...]`. This is the memory space used for a Ref when using `emit_pipeline`.\n", + "\n", + "- `plgpu.GPUMemorySpace.GMEM` allocates a Ref in Global Memory (GMEM/HBM). Any Refs allocated in GMEM are not pipelined, and values cannot be accessed directly using array indexing operations. 
Instead, GMEM must be accessed via SMEM using `plgpu.copy_gmem_to_smem` for reading, or `plgpu.copy_smem_to_gmem` for writing, or pipelined into SMEM using `plgpu.emit_pipeline`.\n", + "\n", + "The primary purpose of `emit_pipeline` is to overlap TensorCore computation with data transfers between GMEM and SMEM, since asynchronous copies between GMEM/SMEM have a long latency, but all TensorCore computation must operate on registers (or SMEM Refs in the case of matrix multiplication)." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "0uzcrDCtKABQ" + }, + "source": [ + "### Example: Matmul Kernel on Hopper GPUs" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "vILVdlqEdoEK" + }, + "source": [ + "Let's begin with a matrix multiplication example designed to run on Hopper GPUs. This kernel utilizes the Hopper-specific `wgmma` (warpgroup matrix multiply accumulate) instruction. `wgmma` is issued by a single Mosaic GPU thread and runs asynchronously on the TensorCore.\n", + "\n", + "Our example kernel implements a blockwise matrix multiplication of two matrices of shape `[M, K] @ [K, N] = [M, N]`, where each output block is computed in parallel over the CUDA grid. This grid is specified as the `grid` argument to the outer `plgpu.kernel`, and parallelizes over the non-contracting dimensions M, N of the matrix multiplication." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "KSvqVNdy726B" + }, + "source": [ + "\n", + "
\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "10ebHCQ571Fn" + }, + "source": [ + "\n", + "Within a program instance, we run a sequential pipeline using `plgpu.emit_pipeline` that reduces over the contracting dimension K of the matrix multiplication. On each iteration of the pipeline, we load one tile from each input matrix, multiply them, and then store the result in an accumulator Ref (`plgpu.ACC`). `plgpu.ACC` is a special type of Ref that lives in registers and holds the intermediate results of WGMMA. Once we have accumulated over the entire contracting dimension, we write out the result to the output Ref.\n", + "\n", + "To perform the actual matrix multiplication, we call `plgpu.wgmma` with the accumulator, LHS, and RHS Refs as arguments in order to push the arguments into the TensorCore pipeline. All WGMMA operations are executed in order, so this can be viewed as pushing operations into a queue. Since `wgmma` is an asynchronous instruction, `plgpu.wgmma_wait(N)` is used to wait until there are no more than N `wgmma` operations left in-flight. In this particular implementation we wait for 1 in-flight WGMMA, meaning that the WGMMA we queue on the current iteration will be waited for on the next iteration.\n", + "- `wgmma` wants it's arguments to be in a specific format, defined in the [CUDA documentation](https://docs.nvidia.com/cuda/parallel-thread-execution/#register-fragments-and-shared-memory-matrix-layouts). These are implemented by the `TilingTransform` and `SwizzleTransform` transformations on the input BlockSpecs. Note that in the future transforms will be inferred automatically by Mosaic GPU and these will not need to be manually specified. See the [wgmma reference](https://docs.jax.dev/en/latest/pallas/gpu/reference.html#hopper-wgmma) for full details on using this instruction.\n", + "- We use the `delay_release` parameter in conjunction with `plgpu.wgmma_wait(1)` to always allow one `WGMMA` operation to stay in-flight in order to ensure good TensorCore utilization. Without this, we would be flushing the TensorCore pipeline on every iteration of the kernel." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "6Vf5_VA9iCD1" + }, + "outputs": [], + "source": [ + "def matmul(a, b, tile_m=128, tile_n=128, swizzle=128):\n", + " dtype = jnp.float16\n", + " swizzle_elems = swizzle // jnp.dtype(dtype).itemsize\n", + " tile_k = swizzle_elems\n", + " grid_m = m // tile_m\n", + " grid_k = k // tile_k\n", + " grid_n = n // tile_n\n", + " assert tile_m % swizzle_elems == 0\n", + "\n", + " # Note: Transforms will be inferred automatically\n", + " # by Mosaic GPU in the future.\n", + " transforms = (\n", + " plgpu.TilingTransform((8, swizzle_elems)),\n", + " plgpu.SwizzleTransform(swizzle),\n", + " )\n", + "\n", + " def kernel(a_gmem, b_gmem, o_gmem, o_smem, acc):\n", + " def pipeline_step(_, a_smem, b_smem):\n", + " plgpu.wgmma(acc, a_smem, b_smem)\n", + " plgpu.wgmma_wait(1)\n", + "\n", + " # pl.program_id obtains the index into the grid.\n", + " pid_m = pl.program_id(0)\n", + " pid_n = pl.program_id(1)\n", + "\n", + " pipeline = plgpu.emit_pipeline(\n", + " pipeline_step,\n", + " in_specs=[\n", + " plgpu.BlockSpec(\n", + " (tile_m, tile_k), lambda k: (pid_m, k), transforms=transforms\n", + " ),\n", + " plgpu.BlockSpec(\n", + " (tile_k, tile_n), lambda k: (k, pid_n), transforms=transforms\n", + " ),\n", + " ],\n", + " grid=(grid_k,),\n", + " max_concurrent_steps=2,\n", + " delay_release=1,\n", + " )\n", + "\n", + " pipeline(a_gmem, b_gmem)\n", + " # Store WGMMA accumulator to SMEM and then to GMEM.\n", + " o_smem[...] = acc[...].astype(dtype)\n", + " plgpu.commit_smem()\n", + " m_slice = pl.ds(pid_m * tile_m, tile_m)\n", + " n_slice = pl.ds(pid_n * tile_n, tile_n)\n", + " plgpu.copy_smem_to_gmem(o_smem, o_gmem.at[m_slice, n_slice])\n", + " plgpu.wait_smem_to_gmem(0)\n", + "\n", + " return plgpu.kernel(\n", + " kernel,\n", + " out_shape=jax.ShapeDtypeStruct((m, n), jnp.float16),\n", + " scratch_shapes=[\n", + " plgpu.SMEM((tile_m, tile_n), jnp.float16),\n", + " plgpu.ACC((tile_m, tile_n), jnp.float32)\n", + " ],\n", + " # grid specifies the CUDA grid.\n", + " # Instances of `kernel` will be executed in parallel over this grid.\n", + " grid=(grid_m, grid_n),\n", + " grid_names=(\"m\", \"n\"),\n", + " )(a, b)\n", + "\n", + "m = 132 * 128\n", + "n = 4 * 128\n", + "k = 10 * 64\n", + "key1, key2 = jax.random.split(jax.random.key(42), 2)\n", + "a = jax.random.uniform(key1, shape=(m, k), dtype=jnp.float16)\n", + "b = jax.random.uniform(key2, shape=(k, n), dtype=jnp.float16)\n", + "\n", + "result = matmul(a, b)\n", + "\n", + "np.testing.assert_allclose(result, a @ b)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "lIYV7PN9J8Px" + }, + "source": [ + "### Warp Specialization\n", + "\n", + "Warp specialization is a technique where we program each warp/warpgroup to perform a single task in order to give the GPU hardware the flexibility to schedule them at runtime. Recall that each streaming multiprocessor (SM) in a GPU contains warp schedulers that can swap execution between warps, so for example when one warp is stalling it can begin executing a different warp. 
In practice, this can be more performant than programming a single instruction stream where the compiler must statically schedule the operations and attempt to overlap them optimally.\n", + "\n", + "In particular, we are interested in warpgroup specialization on Hopper+ GPUs, where it can be useful to have a separate warpgroup issuing TMAs (GMEM/SMEM copies) from the warpgroups performing arithmetic, since indexing calculations and issuing TMAs can take up a significant amount of time and potentially leave the TensorCore idle. The figure below depicts a standard, non-specialized kernel on the left where TMAs (async copies) and matrix multiplication are issued from a single instruction stream, and a warp-specialized version on the right where communication and arithmetic are handled on separate warpgroups. A *consumed barrier* is used to synchronize between the specialized warpgroups that signals to the memory warpgroup when it is safe to begin the next TMA.\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "n-y90IC7v7vL" + }, + "source": [ + "\n", + "
\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ZH0Pui5kFSdD" + }, + "source": [ + "Warp specialization can be enabled in Pallas by using the `plgpu.emit_pipeline_warp_specialized` helper. This pipeline helper handles all of the logic in the memory thread, and the user only needs to specify the work done in the compute threads. It shares the a similar API as the standard `emit_pipeline`, and currently supports the following arguments:\n", + "\n", + "```python\n", + "plgpu.emit_pipeline_warp_specialized(\n", + " body: Callable,\n", + " *\n", + " grid: tuple[int, ...],\n", + " in_specs: Sequence[pallas_core.BlockSpec] = (),\n", + " out_specs: Sequence[pallas_core.BlockSpec] = (),\n", + " max_concurrent_steps: int,\n", + " compute_context: Callable\n", + " num_compute_wgs: int,\n", + " memory_registers: int\n", + " wg_axis: str,\n", + " memory_thread_idx: int | None = None,\n", + ")\n", + "```\n", + "\n", + "There are a few arguments specific to this pipeline emitter, which are:\n", + "- `num_compute_wgs` specifies how many compute threads/warpgroups to use. The pipeline emitter always uses a single memory thread, so in `plgpu.kernel` you should specify `num_threads=num_compute_wgs+1`.\n", + "- `memory_registers` controls how many registers to allocate to the memory thread. The remaining registers are partitioned evenly among the compute threads. The default value is 40 and should be adjusted up or down depending on whether register spills are encountered.\n", + "- `wg_axis` the name of the thread/warpgroup axis (as specified by the `thead_name` argument of `plgpu.kernel`).\n", + "- `memory_thread_idx` specifies which Pallas thread to designate as the memory thread. Defaults to the last thread.\n", + "- `compute_context` is a enables you to specify a prologue/epilogue to the pipeline that only runs in the compute thread. The function allows you to define the initialization and consumption of a loop carry through the pipeline. All compute thread specific arrays should be instantiated here so the memory thread does not materialize them in registers -- otherwise, you may experience slowdowns due to register spills.\n", + "\n", + "The pipeline body of the warp specialized pipeline is run in parallel by all compute threads, and SMEM is shared between compute threads since they are scheduled within the same CUDA block.`lax.axis_index` can be used inside the kernel to obtain the Pallas thread index in order to divide up work amongst compute threads.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ZGbK5gIvFZKy" + }, + "source": [ + "### Example: Matrix Multiplication with Warp Specialization\n", + "\n", + "The following example extends the previous matrix multiplication example to use warp specialization. This particular kernel uses 2 compute threads, which operate on separate columns of the RHS matrix but share the same LHS. Each invocation of the pipeline therefore computes 2 adjacent blocks in the output matrix.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "NYWBqa9-bp2p" + }, + "source": [ + "\n", + "
\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "OkWmfqn7b53M" + }, + "source": [ + "We use the `compute_context` pattern to initialize the WGMMA accumulator, and copy the final accumulator from registers into SMEM. Here, the compute context is defined in the function `compute_thread`. It is critical that the accumulator be created inside of the `compute_thread` function to avoid allocating it in the memory thread which would waste registers. To perform the WGMMA, we wrap the `wgmma` instruction in a `pl.run_state` in order to create an accumulator ref that is initialized to the carry value.\n", + "\n", + "Instead of using `pl.pallas_call` to call the kernel, we instead use the GPU-specific `plgpu.kernel` entry point. `plgpu.kernel` allows us to specify the number of threads to launch per CUDA block via the `num_threads` argument, and allows us to specify a `thread_name` we can use to query the Pallas thread index inside of the kernel.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "EJhWnwJlFGaT" + }, + "outputs": [], + "source": [ + "def matmul_warp_specialized(a, b, tile_m=128, tile_n=128, swizzle=128,\n", + " compute_wgs=2):\n", + " dtype = jnp.float16\n", + " elems_128b = swizzle // jnp.dtype(dtype).itemsize\n", + " tile_k = elems_128b\n", + " grid_m = m // tile_m\n", + " grid_k = k // tile_k\n", + " grid_n = n // tile_n\n", + " assert tile_m % elems_128b == 0\n", + "\n", + " transforms = (\n", + " plgpu.TilingTransform((8, elems_128b)),\n", + " plgpu.SwizzleTransform(128),\n", + " )\n", + "\n", + " def kernel(a_gmem, b_gmem, o_gmem, o_smem):\n", + " wg_idx = lax.axis_index(\"wg\")\n", + " wg_slice = pl.ds(wg_idx * tile_n, tile_n)\n", + " # pl.program_id obtains the index into the pallas_call grid.\n", + " pid_m = pl.program_id(0)\n", + " pid_n = pl.program_id(1)\n", + "\n", + " def compute_thread(pipeline):\n", + " acc = plgpu.layout_cast(\n", + " jnp.full((tile_m, tile_n), 0, dtype=jnp.float32), plgpu.Layout.WGMMA,\n", + " )\n", + " # yield marks the place where the pipelined loop will be inserted.\n", + " # Its argument are the initial carry values, and its result is the carry\n", + " # value after the loop completes.\n", + " final_acc = pipeline(acc)\n", + " o_smem[:, wg_slice] = final_acc[...].astype(dtype)\n", + "\n", + " def kernel_body(_, a_smem, b_smem, carry):\n", + " acc = carry\n", + " b_smem_wg = b_smem.at[:, wg_slice]\n", + " def do_wgmma(acc_ref):\n", + " plgpu.wgmma(acc_ref, a_smem, b_smem_wg)\n", + " acc = pl.run_state(do_wgmma)(\n", + " plgpu.ACC.init(acc))\n", + " return acc\n", + "\n", + " pipeline = plgpu.emit_pipeline_warp_specialized(\n", + " kernel_body,\n", + " in_specs=[\n", + " plgpu.BlockSpec(\n", + " (tile_m, tile_k), lambda k: (pid_m, k), transforms=transforms\n", + " ),\n", + " plgpu.BlockSpec(\n", + " (tile_k, tile_n * 2), lambda k: (k, pid_n),transforms=transforms\n", + " ),\n", + " ],\n", + " grid=(grid_k,),\n", + " compute_context=compute_thread,\n", + " max_concurrent_steps=2,\n", + " num_compute_wgs=compute_wgs,\n", + " memory_registers=40,\n", + " memory_thread_idx=2,\n", + " wg_axis=\"wg\",\n", + " )\n", + " # Call the pipeline\n", + " pipeline(a_gmem, b_gmem)\n", + " # Copy the output from SMEM to GMEM.\n", + " plgpu.commit_smem()\n", + " m_slice = pl.ds(pid_m * tile_m, tile_m)\n", + " n_slice = pl.ds(pid_n * tile_n * 2, tile_n * 2)\n", + " plgpu.copy_smem_to_gmem(o_smem, o_gmem.at[m_slice, n_slice])\n", + " plgpu.wait_smem_to_gmem(0)\n", + "\n", + " return plgpu.kernel(\n", + " 
kernel,\n", + " out_shape=jax.ShapeDtypeStruct((m, n), jnp.float16),\n", + " scratch_shapes=[\n", + " plgpu.SMEM((tile_m, tile_n * 2), jnp.float16)\n", + " ],\n", + " grid=(grid_m, grid_n // 2),\n", + " grid_names=(\"m\", \"n\"),\n", + " num_threads=3, # 2 compute, 1 memory.\n", + " thread_name=\"wg\"\n", + " )(a, b)\n", + "\n", + "m = 132 * 128\n", + "n = 4 * 128\n", + "k = 10 * 64\n", + "key1, key2 = jax.random.split(jax.random.key(42), 2)\n", + "a = jax.random.uniform(key1, shape=(m, k), dtype=jnp.float16)\n", + "b = jax.random.uniform(key2, shape=(k, n), dtype=jnp.float16)\n", + "\n", + "result = matmul_warp_specialized(a, b)\n", + "\n", + "np.testing.assert_allclose(result, a @ b)" + ] + } + ], + "metadata": { + "colab": { + "last_runtime": { + "build_target": "//experimental/users/justinfu/pallas:colab_gpu", + "kind": "private" + }, + "provenance": [] + }, + "jupytext": { + "formats": "ipynb,md", + "main_language": "python" + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/docs/pallas/gpu/pipelining.md b/docs/pallas/gpu/pipelining.md new file mode 100644 index 000000000000..a2b361f181e1 --- /dev/null +++ b/docs/pallas/gpu/pipelining.md @@ -0,0 +1,332 @@ +--- +jupyter: + jupytext: + formats: ipynb,md + main_language: python + text_representation: + extension: .md + format_name: markdown + format_version: '1.3' + jupytext_version: 1.16.4 + kernelspec: + display_name: Python 3 + name: python3 +--- + + +(pallas_mgpu_pipelining)= + + + +## Mosaic GPU Pipelining + +This guide covers software pipelining using the Mosaic GPU backend for Pallas. + +For a general overview of the pipelining API in Pallas, we recommend that users first read {ref}`pallas_software_pipelining`. Pipelining in Pallas is programmed explicitly. For those who are familiar with Triton, this is a significant difference in programming model because in Triton, pipelining is an optimization that is done automatically by the compiler. + + + +```python id="dGAa3iO5DoRT" +import jax +from jax import lax +from jax import numpy as jnp +from jax.experimental.pallas import mosaic_gpu as plgpu +from jax.experimental import pallas as pl +import numpy as np +``` + + + +### Pipelining with Mosaic GPU + +The recommended approach to pipeline using Mosaic GPU is to use the `plgpu.emit_pipeline` function to pipeline over sequential loops (and to use `plgpu.kernel` to partition the problem in parallel over the CUDA grid). `emit_pipeline` follows a similar API as `pl.pallas_call` except it exposes a few additional GPU-specific options. + +- `body`, `grid` have similar semantics as in `pl.pallas_call`. The `grid` denotes how many invocations of the `body` function to run. In contrast with a CUDA grid, the pipeline grid is guaranteed to run sequentially. +- `in_specs` and `out_specs` also work similarly to `pl.pallas_call`, except they also accept `plgpu.BlockSpec` instances that can be used specify GPU-specific transforms, such as swizzling. See [memory reference transforms](https://docs.jax.dev/en/latest/pallas/gpu/reference.html#memory-reference-transforms) for more detail on available transformations. +- `max_concurrent_steps` controls the maximum number of concurrent memory transfers. Using additional concurrent steps will consume more SMEM to hold temporary buffers, but it can improve the utilization of the memory subsystem. We recommend autotuning this parameter. Low values (e.g. 
2) can sometimes achieve higher occupancy (due to lower SMEM usage) which can improve throughput in ALU-heavy kernels, but will introduce more noise due to the hardware taking care of scheduling. Larger values (between 4 and 6) will work best for kernels that can't take advantage of extra occupancy. +- `delay_release` allows the user to specify an additional number of iterations to wait before the buffer is reused by the pipeline. For example, a buffer copied into SMEM on iteration 0 with `delay_release=1` and `max_concurrent_steps=2` will not be reused until iteration 3, as opposed to iteration 2 for a standard double-buffered strategy. `delay_release=1` is necessary if you don't await a `plgpu.wgmma` operation on the pipeline operands, as otherwise the pipeline will begin overwriting the buffers while the WGMMA is still reading them. This is useful for certain optimizations such as allowing multiple async matmuls in flight to keep the tensor core pipeline filled, but care must be taken when using such a strategy, as **omitting this parameter will lead to silent data races**, and it reduces the efficiency of `emit_pipeline`, since we overlap fewer memory transfers. + +#### Compatibility API using `pl.pallas_call` + +As an alternative to `emit_pipeline`, and to maintain compatibility with Pallas TPU, Mosaic GPU also implements the existing `pl.pallas_call` API. By default, `pl.pallas_call` on Mosaic GPU will partition your kernel in parallel over the CUDA grid. You can opt in to pipelining by passing in a `plgpu.GPUCompilerParams` object as the `compiler_params` argument, which specifies the following options that are relevant for pipelining: +- `dimension_semantics`: A tuple of `Literal['parallel', 'sequential']` that specifies iteration semantics for each grid dimension. `parallel` will partition the corresponding dimension over the CUDA grid, and `sequential` dimensions will be pipelined sequentially. **Note that if no dimensions are marked `sequential`, no pipelining will happen!** +- `max_concurrent_steps`: identical to the option in `plgpu.emit_pipeline`. +- `delay_release`: identical to the option in `plgpu.emit_pipeline`. + +Pipelining lets you reuse scratch buffers across the sequential iterations of the grid (e.g. for implementing reductions). Additionally, `pallas_call` supports using `plgpu.BlockSpec` objects in place of `pl.BlockSpec` objects when using the Mosaic GPU backend, allowing you to specify GPU-specific memory transformations. + +We recommend that users use `plgpu.kernel` rather than `pl.pallas_call`, as `plgpu.kernel` supports more features (such as specifying the number of warpgroups and warp specialization). + + + + +### GPU Memory Spaces + +Refs exist primarily in one of two memory spaces, which can be explicitly specified by the `memory_space` argument of `BlockSpec`, i.e. `BlockSpec(memory_space=plgpu.GPUMemorySpace.GMEM)`. + +- `plgpu.GPUMemorySpace.SMEM` allocates a Ref in Shared Memory (SMEM). SMEM Refs can be dereferenced using array indexing syntax to store values in registers for compute, i.e. `x = y_ref[...]`. This is the memory space used for a Ref when using `emit_pipeline`. + +- `plgpu.GPUMemorySpace.GMEM` allocates a Ref in Global Memory (GMEM/HBM). Any Refs allocated in GMEM are not pipelined, and values cannot be accessed directly using array indexing operations. Instead, GMEM must be accessed via SMEM using `plgpu.copy_gmem_to_smem` for reading, or `plgpu.copy_smem_to_gmem` for writing, or pipelined into SMEM using `plgpu.emit_pipeline`. 
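+As a small sketch of how the two memory spaces interact (a hypothetical `double` example, not taken from the sections above, and assuming `emit_pipeline` stages the outputs listed in `out_specs` back to GMEM automatically): the kernel receives GMEM refs from `plgpu.kernel`, while the pipeline body only ever sees SMEM refs whose contents can be loaded into registers.
+
+```python
+# Assumes the imports from the import cell at the top of this guide.
+def double(x, tile=256):
+  def body(_, x_smem, o_smem):
+    # Both refs live in SMEM here; indexing loads the tile into registers.
+    o_smem[...] = x_smem[...] * 2.0
+
+  def kernel(x_gmem, o_gmem):
+    # x_gmem and o_gmem are GMEM refs; the pipeline stages tiles through SMEM.
+    plgpu.emit_pipeline(
+        body,
+        in_specs=[plgpu.BlockSpec((tile,), lambda i: (i,))],
+        out_specs=[plgpu.BlockSpec((tile,), lambda i: (i,))],
+        grid=(x.shape[0] // tile,),
+        max_concurrent_steps=2,
+    )(x_gmem, o_gmem)
+
+  return plgpu.kernel(
+      kernel,
+      out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
+      grid=(1,),
+      grid_names=("block",),
+  )(x)
+```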
+ +The primary purpose of `emit_pipeline` is to overlap TensorCore computation with data transfers between GMEM and SMEM, since asynchronous copies between GMEM/SMEM have a long latency, but all TensorCore computation must operate on registers (or SMEM Refs in the case of matrix multiplication). + + + +### Example: Matmul Kernel on Hopper GPUs + + + +Let's begin with a matrix multiplication example designed to run on Hopper GPUs. This kernel utilizes the Hopper-specific `wgmma` (warpgroup matrix multiply accumulate) instruction. `wgmma` is issued by a single Mosaic GPU thread and runs asynchronously on the TensorCore. + +Our example kernel implements a blockwise matrix multiplication of two matrices of shape `[M, K] @ [K, N] = [M, N]`, where each output block is computed in parallel over the CUDA grid. This grid is specified as the `grid` argument to the outer `plgpu.kernel`, and parallelizes over the non-contracting dimensions M, N of the matrix multiplication. + + + +
+ + + + + +Within a program instance, we run a sequential pipeline using `plgpu.emit_pipeline` that reduces over the contracting dimension K of the matrix multiplication. On each iteration of the pipeline, we load one tile from each input matrix, multiply them, and then store the result in an accumulator Ref (`plgpu.ACC`). `plgpu.ACC` is a special type of Ref that lives in registers and holds the intermediate results of WGMMA. Once we have accumulated over the entire contracting dimension, we write out the result to the output Ref. + +To perform the actual matrix multiplication, we call `plgpu.wgmma` with the accumulator, LHS, and RHS Refs as arguments in order to push the arguments into the TensorCore pipeline. All WGMMA operations are executed in order, so this can be viewed as pushing operations into a queue. Since `wgmma` is an asynchronous instruction, `plgpu.wgmma_wait(N)` is used to wait until there are no more than N `wgmma` operations left in-flight. In this particular implementation we wait for 1 in-flight WGMMA, meaning that the WGMMA we queue on the current iteration will be waited for on the next iteration. +- `wgmma` wants it's arguments to be in a specific format, defined in the [CUDA documentation](https://docs.nvidia.com/cuda/parallel-thread-execution/#register-fragments-and-shared-memory-matrix-layouts). These are implemented by the `TilingTransform` and `SwizzleTransform` transformations on the input BlockSpecs. Note that in the future transforms will be inferred automatically by Mosaic GPU and these will not need to be manually specified. See the [wgmma reference](https://docs.jax.dev/en/latest/pallas/gpu/reference.html#hopper-wgmma) for full details on using this instruction. +- We use the `delay_release` parameter in conjunction with `plgpu.wgmma_wait(1)` to always allow one `WGMMA` operation to stay in-flight in order to ensure good TensorCore utilization. Without this, we would be flushing the TensorCore pipeline on every iteration of the kernel. + + +```python id="6Vf5_VA9iCD1" +def matmul(a, b, tile_m=128, tile_n=128, swizzle=128): + dtype = jnp.float16 + swizzle_elems = swizzle // jnp.dtype(dtype).itemsize + tile_k = swizzle_elems + grid_m = m // tile_m + grid_k = k // tile_k + grid_n = n // tile_n + assert tile_m % swizzle_elems == 0 + + # Note: Transforms will be inferred automatically + # by Mosaic GPU in the future. + transforms = ( + plgpu.TilingTransform((8, swizzle_elems)), + plgpu.SwizzleTransform(swizzle), + ) + + def kernel(a_gmem, b_gmem, o_gmem, o_smem, acc): + def pipeline_step(_, a_smem, b_smem): + plgpu.wgmma(acc, a_smem, b_smem) + plgpu.wgmma_wait(1) + + # pl.program_id obtains the index into the grid. + pid_m = pl.program_id(0) + pid_n = pl.program_id(1) + + pipeline = plgpu.emit_pipeline( + pipeline_step, + in_specs=[ + plgpu.BlockSpec( + (tile_m, tile_k), lambda k: (pid_m, k), transforms=transforms + ), + plgpu.BlockSpec( + (tile_k, tile_n), lambda k: (k, pid_n), transforms=transforms + ), + ], + grid=(grid_k,), + max_concurrent_steps=2, + delay_release=1, + ) + + pipeline(a_gmem, b_gmem) + # Store WGMMA accumulator to SMEM and then to GMEM. + o_smem[...] 
= acc[...].astype(dtype) + plgpu.commit_smem() + m_slice = pl.ds(pid_m * tile_m, tile_m) + n_slice = pl.ds(pid_n * tile_n, tile_n) + plgpu.copy_smem_to_gmem(o_smem, o_gmem.at[m_slice, n_slice]) + plgpu.wait_smem_to_gmem(0) + + return plgpu.kernel( + kernel, + out_shape=jax.ShapeDtypeStruct((m, n), jnp.float16), + scratch_shapes=[ + plgpu.SMEM((tile_m, tile_n), jnp.float16), + plgpu.ACC((tile_m, tile_n), jnp.float32) + ], + # grid specifies the CUDA grid. + # Instances of `kernel` will be executed in parallel over this grid. + grid=(grid_m, grid_n), + grid_names=("m", "n"), + )(a, b) + +m = 132 * 128 +n = 4 * 128 +k = 10 * 64 +key1, key2 = jax.random.split(jax.random.key(42), 2) +a = jax.random.uniform(key1, shape=(m, k), dtype=jnp.float16) +b = jax.random.uniform(key2, shape=(k, n), dtype=jnp.float16) + +result = matmul(a, b) + +np.testing.assert_allclose(result, a @ b) +``` + + +### Warp Specialization + +Warp specialization is a technique where we program each warp/warpgroup to perform a single task in order to give the GPU hardware the flexibility to schedule them at runtime. Recall that each streaming multiprocessor (SM) in a GPU contains warp schedulers that can swap execution between warps, so for example when one warp is stalling it can begin executing a different warp. In practice, this can be more performant than programming a single instruction stream where the compiler must statically schedule the operations and attempt to overlap them optimally. + +In particular, we are interested in warpgroup specialization on Hopper+ GPUs, where it can be useful to have a separate warpgroup issuing TMAs (GMEM/SMEM copies) from the warpgroups performing arithmetic, since indexing calculations and issuing TMAs can take up a significant amount of time and potentially leave the TensorCore idle. The figure below depicts a standard, non-specialized kernel on the left where TMAs (async copies) and matrix multiplication are issued from a single instruction stream, and a warp-specialized version on the right where communication and arithmetic are handled on separate warpgroups. A *consumed barrier* is used to synchronize between the specialized warpgroups that signals to the memory warpgroup when it is safe to begin the next TMA. + + + + + + +
+ + + + + +Warp specialization can be enabled in Pallas by using the `plgpu.emit_pipeline_warp_specialized` helper. This pipeline helper handles all of the logic in the memory thread, and the user only needs to specify the work done in the compute threads. It shares a similar API with the standard `emit_pipeline`, and currently supports the following arguments: + +```python +plgpu.emit_pipeline_warp_specialized( + body: Callable, + *, + grid: tuple[int, ...], + in_specs: Sequence[pallas_core.BlockSpec] = (), + out_specs: Sequence[pallas_core.BlockSpec] = (), + max_concurrent_steps: int, + compute_context: Callable, + num_compute_wgs: int, + memory_registers: int, + wg_axis: str, + memory_thread_idx: int | None = None, +) +``` + +There are a few arguments specific to this pipeline emitter, which are: +- `num_compute_wgs` specifies how many compute threads/warpgroups to use. The pipeline emitter always uses a single memory thread, so in `plgpu.kernel` you should specify `num_threads=num_compute_wgs+1`. +- `memory_registers` controls how many registers to allocate to the memory thread. The remaining registers are partitioned evenly among the compute threads. The default value is 40 and should be adjusted up or down depending on whether register spills are encountered. +- `wg_axis` is the name of the thread/warpgroup axis (as specified by the `thread_name` argument of `plgpu.kernel`). +- `memory_thread_idx` specifies which Pallas thread to designate as the memory thread. Defaults to the last thread. +- `compute_context` enables you to specify a prologue/epilogue to the pipeline that only runs in the compute thread. The function allows you to define the initialization and consumption of a loop carry through the pipeline. All compute-thread-specific arrays should be instantiated here so the memory thread does not materialize them in registers -- otherwise, you may experience slowdowns due to register spills. + +The pipeline body of the warp-specialized pipeline is run in parallel by all compute threads, and SMEM is shared between compute threads since they are scheduled within the same CUDA block. `lax.axis_index` can be used inside the kernel to obtain the Pallas thread index in order to divide up work amongst compute threads. + + + + +### Example: Matrix Multiplication with Warp Specialization + +The following example extends the previous matrix multiplication example to use warp specialization. This particular kernel uses 2 compute threads, which operate on separate columns of the RHS matrix but share the same LHS. Each invocation of the pipeline therefore computes 2 adjacent blocks in the output matrix. + + + + + +
+ + + + +We use the `compute_context` pattern to initialize the WGMMA accumulator, and copy the final accumulator from registers into SMEM. Here, the compute context is defined in the function `compute_thread`. It is critical that the accumulator be created inside of the `compute_thread` function to avoid allocating it in the memory thread which would waste registers. To perform the WGMMA, we wrap the `wgmma` instruction in a `pl.run_state` in order to create an accumulator ref that is initialized to the carry value. + +Instead of using `pl.pallas_call` to call the kernel, we instead use the GPU-specific `plgpu.kernel` entry point. `plgpu.kernel` allows us to specify the number of threads to launch per CUDA block via the `num_threads` argument, and allows us to specify a `thread_name` we can use to query the Pallas thread index inside of the kernel. + + + +```python id="EJhWnwJlFGaT" +def matmul_warp_specialized(a, b, tile_m=128, tile_n=128, swizzle=128, + compute_wgs=2): + dtype = jnp.float16 + elems_128b = swizzle // jnp.dtype(dtype).itemsize + tile_k = elems_128b + grid_m = m // tile_m + grid_k = k // tile_k + grid_n = n // tile_n + assert tile_m % elems_128b == 0 + + transforms = ( + plgpu.TilingTransform((8, elems_128b)), + plgpu.SwizzleTransform(128), + ) + + def kernel(a_gmem, b_gmem, o_gmem, o_smem): + wg_idx = lax.axis_index("wg") + wg_slice = pl.ds(wg_idx * tile_n, tile_n) + # pl.program_id obtains the index into the pallas_call grid. + pid_m = pl.program_id(0) + pid_n = pl.program_id(1) + + def compute_thread(pipeline): + acc = plgpu.layout_cast( + jnp.full((tile_m, tile_n), 0, dtype=jnp.float32), plgpu.Layout.WGMMA, + ) + # yield marks the place where the pipelined loop will be inserted. + # Its argument are the initial carry values, and its result is the carry + # value after the loop completes. + final_acc = pipeline(acc) + o_smem[:, wg_slice] = final_acc[...].astype(dtype) + + def kernel_body(_, a_smem, b_smem, carry): + acc = carry + b_smem_wg = b_smem.at[:, wg_slice] + def do_wgmma(acc_ref): + plgpu.wgmma(acc_ref, a_smem, b_smem_wg) + acc = pl.run_state(do_wgmma)( + plgpu.ACC.init(acc)) + return acc + + pipeline = plgpu.emit_pipeline_warp_specialized( + kernel_body, + in_specs=[ + plgpu.BlockSpec( + (tile_m, tile_k), lambda k: (pid_m, k), transforms=transforms + ), + plgpu.BlockSpec( + (tile_k, tile_n * 2), lambda k: (k, pid_n),transforms=transforms + ), + ], + grid=(grid_k,), + compute_context=compute_thread, + max_concurrent_steps=2, + num_compute_wgs=compute_wgs, + memory_registers=40, + memory_thread_idx=2, + wg_axis="wg", + ) + # Call the pipeline + pipeline(a_gmem, b_gmem) + # Copy the output from SMEM to GMEM. + plgpu.commit_smem() + m_slice = pl.ds(pid_m * tile_m, tile_m) + n_slice = pl.ds(pid_n * tile_n * 2, tile_n * 2) + plgpu.copy_smem_to_gmem(o_smem, o_gmem.at[m_slice, n_slice]) + plgpu.wait_smem_to_gmem(0) + + return plgpu.kernel( + kernel, + out_shape=jax.ShapeDtypeStruct((m, n), jnp.float16), + scratch_shapes=[ + plgpu.SMEM((tile_m, tile_n * 2), jnp.float16) + ], + grid=(grid_m, grid_n // 2), + grid_names=("m", "n"), + num_threads=3, # 2 compute, 1 memory. 
+ thread_name="wg" + )(a, b) + +m = 132 * 128 +n = 4 * 128 +k = 10 * 64 +key1, key2 = jax.random.split(jax.random.key(42), 2) +a = jax.random.uniform(key1, shape=(m, k), dtype=jnp.float16) +b = jax.random.uniform(key2, shape=(k, n), dtype=jnp.float16) + +result = matmul_warp_specialized(a, b) + +np.testing.assert_allclose(result, a @ b) +``` diff --git a/docs/pallas/gpu/reference.md b/docs/pallas/gpu/reference.md new file mode 100644 index 000000000000..d68730619b06 --- /dev/null +++ b/docs/pallas/gpu/reference.md @@ -0,0 +1,783 @@ +# Writing Mosaic GPU kernels with Pallas + +This page is a reference for the most important features of the Pallas:MGPU backend. +It's not a tutorial and as such we do not expect everyone to read it top to bottom. +Still, it is worth going over +just to familiarise yourself with some patterns you can find in other tutorials. + +In the following examples, we're going to assume the following imports are in scope: +```python +import jax.experimental.pallas as pl +import jax.experimental.pallas.mosaic_gpu as plgpu +``` + +## What is a GPU? + +Technically, the NVIDIA GPU architecture looks as follows: the GPU is partitioned into +_streaming multiprocessors_ (SMs). The way this manifests in the CUDA programming model +is that each _CUDA thread block_ (or CTA) is scheduled on exactly one SM, but multiple +blocks can be scheduled onto a single SM at a time. + +Each SM contains a chunk of fast memory called _shared memory_ (SMEM) and 4 subdivisions, +each containing a _warp scheduler_ and compute units (ALU, TensorCore, ...). +This is also reflected in the CUDA programs: each _warp_ (a group of consecutive 32 CUDA +threads in a block) is assigned to one of those subdivisions in a round-robin fashion. +Similarly to blocks, each warp is assigned to exactly one subdivision (it never migrates), +but multiple warps can be assigned to the same SM subdivision. At each clock cycle, the +warp scheduler from each subdivision tries to select one of its resident warps to execute +the next instruction. + +
A diagram of one NVIDIA SM
+ +Going further, recent CUDA versions also outline the concept of a _warpgroup_, which is a group of +4 consecutive warps. Knowing what the hardware looks like, we can see where this is coming +from: 4 consecutive warps occupy the 4 quarters of an SM and let us issue instructions +that utilize the whole SM. + +```{note} +A GPU can be viewed in many different ways and here we want to focus on a slightly +simplified model that is very TensorCore-centric. This should help you navigate the +complexities of writing kernels involving the TensorCore, but keep in mind that the +real picture is more complicated. +``` + +For our purposes, TensorCore operations have grown so big that it no longer makes much +sense to follow the CUDA model. As such, to us, a GPU is a collection of single-threaded cores +(SMs) with one thread of Pallas:MGPU corresponding to a CUDA warpgroup. In this model, each +operation you perform in the kernel occupies the whole CUDA warpgroup, and its constituent +warps always run in lockstep (modulo the jitter from hardware scheduling) and never take +different paths through control flow (with the small exception of `core_map` that we will +discuss later). One notable addition here is that we still allow you to co-schedule multiple +of those Pallas-level threads on the same SM so that they can cooperate and communicate +through shared memory (we realize that by putting them in the same CUDA block). + +```{note} +From now on, whenever we say "thread", we refer to the Pallas thread, not a CUDA thread/lane. +``` + +```{note} +This is very similar to a programming model popularized by [Triton](https://triton-lang.org/), +but as you will see there are a few differences. Mosaic GPU tends to be lower level, +which usually means you will have to put in more work, but it also puts you more in control. +In our view both approaches have their merits and we encourage you to pick the backend that +suits your needs the best! Pallas supports and will continue to support Triton as an alternative +GPU backend. +``` + +### In-order execution & using multiple hardware units + +Unlike more complicated CPU architectures, GPUs only support in-order execution. That, however, +does not mean that at any given time only a single instruction is running! Each SM quarter +has multiple independent functional units: TensorCore, Arithmetic logic unit (ALU), +Load/Store (LSU), Special function unit (SFU). If the first instruction targets one of the +units and is followed by another one (that does not use the result of the first one), then the +warp scheduler can issue the second one before the first one completes. This is often referred +to as instruction-level parallelism (ILP) and is a common theme in modern TensorCore kernels: +TensorCore operations are so big and take so many cycles to complete that it is a waste not to +try to use other units in the meantime. + +To extend this even further, we can take advantage of this hardware-unit-level parallelism by +allowing multiple Pallas threads to run concurrently. If one of the threads primarily +occupies the ALU, while another one primarily issues TensorCore-related instructions, we can +take advantage of the efficient context switching built into the warp schedulers to keep both +units busy. This is one of the core ideas behind algorithms such as [FlashAttention 3](https://arxiv.org/abs/2407.08608) +or [CUTLASS ping-pong matmul kernels](https://pytorch.org/blog/cutlass-ping-pong-gemm-kernel/). 
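+As a minimal sketch of the co-scheduling mechanics only (a hypothetical `per_thread_fill` example with arbitrary sizes; it does not implement a ping-pong schedule), two Pallas threads are placed in the same block via `num_threads`, and each uses `lax.axis_index` to pick the slice of the output it owns:
+
+```python
+# In addition to the `pl`/`plgpu` imports above:
+import jax
+import jax.numpy as jnp
+from jax import lax
+
+def per_thread_fill(n_per_thread=256):
+  def kernel(o_gmem, o_smem):
+    wg = lax.axis_index("wg")                    # 0 or 1: this Pallas thread
+    sl = pl.ds(wg * n_per_thread, n_per_thread)  # the slice this thread owns
+    o_smem[sl] = jnp.full((n_per_thread,), 1.0, jnp.float32) * (wg + 1)
+    plgpu.commit_smem()
+    # Each thread copies only the slice it wrote, so no cross-thread
+    # synchronization is needed before the copy.
+    plgpu.copy_smem_to_gmem(o_smem.at[sl], o_gmem.at[sl])
+    plgpu.wait_smem_to_gmem(0)
+
+  return plgpu.kernel(
+      kernel,
+      out_shape=jax.ShapeDtypeStruct((2 * n_per_thread,), jnp.float32),
+      scratch_shapes=[plgpu.SMEM((2 * n_per_thread,), jnp.float32)],
+      grid=(1,),
+      grid_names=("block",),
+      num_threads=2,
+      thread_name="wg",
+  )()
+```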
+
+
+For more information on how warp scheduling and instruction issue works, we recommend reading
+[Analyzing Modern NVIDIA GPU cores](https://arxiv.org/abs/2503.20481).
+
+### Memory spaces
+
+The GPU features a few different memory spaces that can be totally ordered, from the largest
+(in terms of capacity) and slowest (in both total bandwidth and single-access latency) down to
+the smallest and fastest.
+
+<figure>
A diagram of memory spaces of an NVIDIA GPU
+
+
+The biggest memory space is `plgpu.GMEM`, for _global memory_. In recent data-center grade GPUs
+this memory space is often measured in tens or even hundreds of gigabytes, but it is also the
+slowest one.
+
+The next memory space, used for the L2 cache, is also more or less global in the
+sense that it is shared by the whole GPU, but its use can only be influenced indirectly through
+cache hints. As such, there's no way to manually place values in there and so this memory space
+is not exposed in Pallas:MGPU. While only about 100MB in size, this memory has considerably
+higher bandwidth than GMEM, and so it is still often recommended to take advantage of it while
+writing high-performance kernels.
+
+Next in line is _shared memory_, or `plgpu.SMEM`. This memory is located directly inside each SM
+and so it is partitioned. Unless block clusters are used (see the section on clusters below),
+each block is only allowed to access its own SMEM allocations.
+
+Finally, the lowest level memory space is the _register memory_. This is where every single value
+(i.e. JAX array) in a Pallas kernel will be located. If the compiler runs out of registers to
+store those arrays, it will insert _spills_, meaning that it will periodically store and reload
+values to memory. Those spills often introduce significant performance degradations and so
+we recommend avoiding them. The warning messages about spills can be clearly seen in the `ptxas`
+messages during kernel compilation. To make them visible, run with `MOSAIC_GPU_DUMP_PTXAS=1`
+in your environment.
+
+The Blackwell GPU generation has one additional memory space called _tensor memory_ or `plgpu.TMEM`.
+TMEM is very similar to register memory, only it is explicitly allocated and managed by you.
+It is used to store the MMA accumulator, operand metadata (for sparsity or scaling),
+and optionally the left MMA operand. See the Blackwell MMA section for more information about TMEM.
+
+#### Requesting/allocating memory in specific memory spaces
+
+Kernel inputs or outputs are placed in SMEM by default. If you want to access them as GMEM references,
+add `memory_space=plgpu.GMEM` to their `BlockSpec`. If you want the kernel to be called with the whole
+input or output array in GMEM, it is sufficient to specify `BlockSpec(memory_space=plgpu.GMEM)`.
+
+`SMEM` and `TMEM` can be allocated explicitly in the `scratch_shapes` argument of `pl.pallas_call`,
+or using `pl.run_scoped`. To allocate a reference, simply call the memory space object with the
+requested shape and dtype. For example: `plgpu.SMEM((128, 128), jnp.float16)` will allocate a 128x128
+array of float16 elements in shared memory.
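+
+To make this concrete, here is a minimal sketch of a kernel that stages data through an
+explicitly allocated SMEM scratch buffer. It uses the `plgpu.kernel` entry point, the mesh
+abstraction and the asynchronous copy and barrier helpers that are described in detail later
+in this guide; the kernel name, the shapes and the trivial doubling computation are just for
+illustration (as in the other examples, `functools`, `jax` and `jnp` are assumed to be imported):
+
+```python
+mesh = plgpu.Mesh(grid=(1,), grid_names=("block",))
+
+@functools.partial(
+  plgpu.kernel,
+  out_shape=jax.ShapeDtypeStruct((128, 128), jnp.float16),
+  mesh=mesh,
+  scratch_shapes=[plgpu.SMEM((128, 128), jnp.float16), plgpu.Barrier()],
+)
+def double_kernel(x_ref, y_ref, scratch_ref, barrier_ref):
+  # x_ref and y_ref are GMEM references; scratch_ref lives in SMEM.
+  plgpu.copy_gmem_to_smem(x_ref, scratch_ref, barrier_ref)  # Async GMEM -> SMEM copy
+  plgpu.barrier_wait(barrier_ref)                           # Wait until the data has arrived
+  scratch_ref[...] = scratch_ref[...] * 2                   # Compute on the SMEM copy
+  plgpu.commit_smem()                                       # Make the SMEM write visible to the copy engine
+  plgpu.copy_smem_to_gmem(scratch_ref, y_ref)               # Async SMEM -> GMEM copy
+  plgpu.wait_smem_to_gmem(0)                                # Wait for it before the kernel exits
+```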
+
+#### Taking advantage of the L2 cache
+
+While the L2 cache cannot be managed manually, its noticeably higher bandwidth compared to global
+memory makes it worth thinking about. The simplest way to take advantage of it is to reorder
+the parallel grid dimensions so that invocations that are scheduled in similar time periods also
+access the same input data.
+
+While the CUDA programming model does not guarantee anything about the order in which the blocks
+are assigned to SMs, in recent generations the heuristic seems to simply iterate over the
+`(x, y, z)` CUDA grids in column-major order (i.e. `x` is the fastest-changing dimension and
+`z` is the slowest). Similarly, Pallas:MGPU does not guarantee how a user-specified grid is mapped to
+the CUDA grid (Pallas supports grids of arbitrary rank, not just up to 3D). However, you can assume that
+the iteration will happen in _row-major_ order. That is, if a grid has dimensions `(a, b)`, then
+`b` will be the fastest-changing dimension and `a` will be the slower one.
+
+To give a practical example of this, consider a plain matrix multiplication kernel. There, one
+usually uses two parallel grid dimensions `(m, n)`, corresponding to tiling the two non-contracting
+dimensions. If we use this simple scheme, in Pallas:MGPU all programs with id `(0, ...)` will be
+scheduled before any program with id `(1, ...)`. And, collectively, the programs with `m=0` have to
+read all of the `B` operand! If the `n` or `k` dimensions are very large, there is no chance that
+we'll be able to get cache hits in the `(1, ...)` programs from accesses made by the `(0, ...)`
+programs. For simplicity, assuming we can only run 16 blocks at a time, we see this access pattern
+from the first scheduled wave:
+
+<figure>
+ + Your browser does not support SVGs or scripting is disabled. + This would be an image showing the access pattern of first 16 blocks without grid tiling. + +
+
+
+However, if we simply rearrange the grid to be `(m // mt, n, mt)` (and then replace `pl.program_id(0)`
+with `pl.program_id(0) * mt + pl.program_id(2)` in the kernel), it is straightforward to see that a
+band of programs along both dimensions will be scheduled concurrently (instead of scheduling a single
+row). This greatly increases the number of concurrent programs that load similar slices of data, which
+usually significantly improves the L2 utilization and hence the overall performance of the kernel
+(if it is memory bound). Continuing our example with 16 blocks and using `mt=4`, we get the following
+access pattern:
+
+<figure>
+ + Your browser does not support SVGs or scripting is disabled. + This would be an image showing the access pattern of first 16 blocks with grid tiling. + +
+ +Note that even though the number of active blocks hasn't changed, the total footprint of the data they +access has halved! We get a much higher chance of getting L2 hits now. + +## Array layouts and memory reference transforms + +In Pallas, the data structures you work with (arrays and references) have a +**logical shape** (e.g., a 128x128 matrix). This +logical shape must be mapped to a **physical representation** (how the data is +actually represented in the GPU's memory). The specific mapping depends on where the +data resides: + +1. **Array Layouts:** Arrays are stored in register memory and we call this mapping + a _layout_. Layouts define how the elements of an array are + distributed across the registers available to the CUDA lanes that form a Pallas thread. +2. **Memory Reference Transforms:** For mutable references pointing + to `SMEM`, this mapping is called a _transform_. + Transforms describe how the logical data structure is arranged within that + block of memory. + +These concepts are crucial for performance, especially when interacting with +specialized hardware units like TensorCores or optimizing memory access +patterns. + +```{note} +We are working on a mode that will deal with assigning layouts and transforms fully +automatically (although with way to provide hints and more control). The APIs listed +below will likely continue to function, but will become optional. +``` + +### Memory reference transforms + +Transforms are applied when a memory reference is first allocated. Pallas +primitives that operate on these references will automatically account for their +associated transforms. + +``` +def body(..., scratch_ref): + # Asynchronous copy will reformat the GMEM data to match the SMEM transforms + plgpu.copy_gmem_to_smem(..., scratch_ref, barrier) + barrier.wait() + plgpu.wgmma(..., scratch_ref) # wgmma only accepts properly transformed refs + ... +``` + +There are two ways in which references are allocated and each has a way to select +the desired transforms: + +**1. Using `plgpu.BlockSpec`** + +```python +transforms = (plgpu.TileTransform((8, 64)), plgpu.SwizzleTransform(128)) +f = pl.pallas_call( + in_specs=plgpu.BlockSpec(in_block_shape, in_index_map, transforms=transforms), + out_specs=plgpu.BlockSpec(out_block_shape, out_index_map, transforms=transforms), + ... +) +``` + +Note that unlike `plgpu.BlockSpec`, `pl.BlockSpec` does *not* allow specifying +transforms. + +**2. Specifying the `transforms` argument on the allocated `SMEM`** + +```python +transforms = (plgpu.TileTransform((8, 64)), plgpu.SwizzleTransform(128)) +f = pl.pallas_call( + scratch_shapes=plgpu.SMEM((128, 128), jnp.float16, transforms=transforms), + ... +) +``` + +The available transforms are: +* `plgpu.TileTransform(tile_shape)`, which organizes the data into contiguous, + non-overlapping tiles of shape `tile_shape`. The data of one tile is always + fully linearized (row-major), before another tile begins (tiles are also + traversed in row-major order). As an example, applying `TileTransform((8, + 64))` to a `(128, 128)` reference means the data corresponding to the logical + slice `[0:8, 0:64]` will be stored first (row-major), followed by + `[0:8, 64:128], [8:16, 0:64], [8:16, 64:128]`, and so on. A different way to achieve + this would be to take the input array `x` and traverse + `x.reshape(128 // 8, 128 // 64, 8, 64).transpose(0, 2, 1, 3)` in row-major order. 
+* `plgpu.SwizzleTransform(swizzle_in_bytes)`, which transforms the data as described in the + [PTX docs](https://docs.nvidia.com/cuda/parallel-thread-execution/#tensor-swizzling-modes) and + [CUDA docs](https://docs.nvidia.com/cuda/cuda-c-programming-guide/#the-swizzle-modes). + Swizzling is useful, because it allows transferring data in MMA-related layouts + between register and shared memory without bank conflicts. The exact details + of how the memory looks like after swizzling _are not that important_, since + all primitives will account for it automatically. Note that the swizzle amount + is specified in bytes (only 128, 64, 32 and 16 are supported), and is usually + accompanied by a `TileTransform` (which uses elements in its shape!). +* `plgpu.TransposeTransform(permutation)`, which permutes the dimensions of the array before it is linearized. + This is primarily useful in that it lets you change the layout during the GMEM-SMEM copies (only + do keep in mind that changing the minormost/last dimension is not supported by the hardware). + +### Array layouts + +There are a few useful layouts we have defined for you so far: +* `plgpu.Layout.WGMMA`, which is the layout in which the Hopper-generation TensorCore + expects the MMA accumulator or 16-bit input operands to have in registers. +* `plgpu.Layout.WGMMA_ROW`, which is the layout obtained after the above after reducing + it along the rows. Re-broadcasting the rows is free and will produce a value with `WGMMA` + layout. +* `plgpu.Layout.WGMMA_COL`, which is an analogue of the one above, only reduced along + columns instead of rows. +* `plgpu.Layout.WG_STRIDED`, where the value is partitioned equally among the 128 + CUDA lanes making up a Pallas thread. The consecutive elements (after vectorization) + are assigned to the lanes in a round-robin fashion. Very simple and effective when + no interaction with TensorCores is needed. +* `plgpu.Layout.WG_SPLAT`, indicating that the value is constant. Each CUDA lane will + hold a single register that contains the value. You normally never have to interact + with this layout, as it is implicitly used when constant values are created and + is always implicitly convertible to other layouts. + +At the moment, in the default mode of operation, array layout propagation happens +only in a forward direction and there is little implicit support for reconciling +layout conflicts: only splat layouts can be implicitly converted into any other +layout. If you e.g. try to add two arrays that have a different layout, the lowering +will complain and fail. There are very limited facilities that let you convert between +layouts, and we usually recommend storing the value to SMEM and reading it back in +the target layout. + +## MMA (TensorCore) + +In this section, we focus on how Pallas:MGPU kernels can utilize the TensorCore unit. +The programming interface of the TensorCore changes significantly between different +NVIDIA GPU generations, which is why the lowest-level interfaces differ in Pallas:MGPU as well. + +Each MMA operation is associated with three operands: +* the accumulator `D` of shape `(M, N)`, +* the left input `A` of shape `(M, K)`, +* the right input `B` of shape `(K, N)`. +All operands must have the same element type. + +Each use of MMA involves a few steps: +1. Allocating the space for the accumulator (MMA implicitly performs `D += A @ B`) +2. Preparing the `A` and `B` operands +3. Issuing the operation +4. Waiting for the operation to complete +5. Reading out the result + +Steps 2.-4. 
are usually performed in a loop over the contraction dimension (`K`). + +### Memory space of `A` and `B` operands + +The `A` and `B` operands are generally best passed in through SMEM, where they can +be conveniently loaded using `plgpu.copy_gmem_to_smem`. For those operands to be +compatible with MMA operations, they need to have the appropriate tiling and swizzling +transforms specified upon their allocation. For all currently supported generations, +the TensorCore requires the data to be laid out into row-major 2D tiles of shape +`(8, swizzle_elems)`, where `swizzle_elems` is derived by dividing the swizzle by the +element type bytewidth. The currently supported swizzles are: 128, 64, and 32. Larger +swizzles are preferable as they improve the performance of GMEM-to-SMEM copies. + +```python +def mma_transforms(shape_dtype: jax.ShapeDtypeStruct): + assert len(shape_dtype.shape) == 2 + if shape_dtype.shape[0] % 8: + raise ValueError("Number of rows must be divisible by 8") + for swizzle_bytes in (128, 64, 32): + swizzle_elems = swizzle_bytes // shape_dtype.dtype.itemsize + if shape_dtype.shape[-1] % swizzle_elems == 0: + return (plgpu.TilingTransform((8, swizzle_elems)), + plgpu.SwizzleTransform(swizzle_bytes)) + raise ValueError("Failed to find transforms for the specified window type") +``` + +If the operands need to be transformed, the `A` operand can be passed in through a different +memory space (architecture dependent, see below). The `B` operand _must_ be located in SMEM. + +### Transposed operands + +When performing MMA on 16-bit operands, the TensorCore can automatically transpose the +input data. For example, the `A` reference is allowed to be of shape `(K, M)`, but it +has to be transposed before passing it into the mma function. For example: +```python +assert acc_ref.shape == (M, N) and a_ref.shape == (K, M) and b_ref.shape == (K, N) +a_ref_t = plgpu.transpose_ref(a_ref, (1, 0)) +assert a_ref_t.shape == (M, K) # The shape expected by plgpu.wgmma +plgpu.wgmma(acc, a_ref_t, b_ref) +``` +An analogous operation is allowed on the `B` reference in this case too. + +### Hopper (`wgmma`) + +In this section, we cover the basics of using the Hopper-generation TensorCores, exposed in +PTX as the [`wgmma.mma_async` instruction](https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-warpgroup-level-matrix-instructions-wgmma-mma). + +#### Allocating the accumulator + +In the Hopper hardware architecture the accumulator is allocated in registers, but in Pallas +it is modeled as a mutable reference, as each MMA operation accumulates in-place. +There are two ways to allocate the accumulator. + +To create a zero-initialized accumulator you can use `pl.run_scoped` with a +`plgpu.ACC((m, n), dtype)` type. +```python +def compute(acc_ref): + ... + return acc_ref[...] +output = pl.run_scoped(compute, plgpu.ACC((m, n), jnp.float32)) +``` +Dereferencing the accumulator reference, as seen in the end of the `compute` function will +implicitly await all outstanding WGMMA operations. + +If you'd like to initialize it with an existing array, you can use `pl.run_state` with +`plgpu.ACC.init(init_array)`: +```python +def compute(acc_ref): + ... + return # pl.run_state only returns the final value of the accumulator +output = pl.run_state(compute)(plgpu.ACC.init(init_array)) +``` +If `pl.run_state` has accumulator operands, it implicitly awaits all outstanding WGMMA +operations before returning the final values. 
+
+
+#### Preparing the `A` and `B` operands
+
+As discussed above, we recommend passing in `A` and `B` through shared memory. In this
+case the correct tiling and swizzling transforms must be specified.
+
+`plgpu.wgmma` additionally allows passing in `A` through registers (i.e. not an SMEM
+reference but as a regular JAX array). This mode, however, comes with a number of
+significant drawbacks and it is very difficult to ensure sufficient synchronization to
+make this safe.
+
+TODO: Explain the conditions under which it is acceptable to do this.
+
+#### Issuing the operation
+
+The supported MMA shapes are such that:
+* `M` is divisible by 64
+* `N` is divisible by 8 and smaller than 256
+* `K` is a multiple of `swizzle` divided by the bytewidth of the element type
+
+The currently supported data types are: `jnp.float32`, `jnp.bfloat16` and `jnp.float16`.
+The accumulator `D` must be a `jnp.float32`, with the exception of `jnp.float16` inputs,
+in which case it is allowed to be `jnp.float16` as well.
+
+#### Waiting for the operation to complete
+
+Each `plgpu.wgmma` call implicitly synchronizes with all previous `plgpu.wgmma` calls, such
+that once control returns from it, we guarantee that no WGMMA other than the last issued
+one is still running. As such, any SMEM regions that were read by previously issued WGMMA
+instructions can be reused. This is especially relevant for pipelining WGMMA with async memory copies:
+```python
+buffers = 3  # In reality you might want even more
+assert a_smem.shape == (buffers, m, k)
+assert b_smem.shape == (buffers, k, n)
+assert acc_ref.shape == (m, n)
+
+def fetch_a_b(ki, slot):
+  a_slice = ...  # Replace with the right M/K slice
+  b_slice = ...  # Replace with the right K/N slice
+  plgpu.copy_gmem_to_smem(a_gmem.at[a_slice], a_smem.at[slot], a_loaded.at[slot])
+  plgpu.copy_gmem_to_smem(b_gmem.at[b_slice], b_smem.at[slot], b_loaded.at[slot])
+
+def loop_body(i, _):
+  slot = jax.lax.rem(i, buffers)
+  plgpu.barrier_wait(a_loaded.at[slot])
+  plgpu.barrier_wait(b_loaded.at[slot])
+  plgpu.wgmma(acc_ref, a_smem.at[slot], b_smem.at[slot])
+  # We know that only the last issued WGMMA can still be running, so we can issue an
+  # async load into the buffer that the previous step has finished reading.
+  load_i = i + buffers - 1
+  load_slot = jax.lax.rem(load_i, buffers)
+  @pl.when(jnp.logical_and(load_i >= buffers, load_i < num_steps))
+  def _do_fetch():
+    fetch_a_b(load_i, load_slot)
+for slot in range(buffers):
+  fetch_a_b(slot, slot)
+jax.lax.fori_loop(0, num_steps, loop_body, None)
+```
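+
+Putting the accumulator, operand and synchronization pieces together, here is a sketch of how a
+single output tile could be computed once `a_smem` and `b_smem` (allocated with the transforms
+discussed above) have been filled. The names and shapes are illustrative, and `m`, `n` and `k`
+are assumed to satisfy the constraints listed earlier:
+
+```python
+def compute_tile(a_smem, b_smem, o_smem):
+  # a_smem: (m, k), b_smem: (k, n), o_smem: (m, n), all in SMEM.
+  def do_mma(acc_ref):
+    plgpu.wgmma(acc_ref, a_smem, b_smem)  # acc_ref += a_smem @ b_smem
+    # Dereferencing the accumulator awaits all outstanding WGMMA operations.
+    return acc_ref[...]
+  result = pl.run_scoped(do_mma, plgpu.ACC((m, n), jnp.float32))
+  o_smem[...] = result.astype(o_smem.dtype)
+```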
+
+### Blackwell (`tcgen05`)
+
+While Mosaic GPU supports `tcgen05` MMA instructions, exposing this capability to Pallas
+is still work in progress. Stay tuned!
+
+## Using `core_map`
+
+`pl.pallas_call` is suitable for kernels where a single Pallas thread can
+perform the whole computation for an entire CUDA block. The `pl.core_map`
+function relaxes this restriction, allowing you to use multiple threads within a
+single block (e.g. for warp specialization) or multiple blocks in a block
+cluster (e.g. to utilize multicast TMA).
+
+### Replacing `pl.pallas_call` with `pl.core_map` or `plgpu.kernel`
+
+Let us begin with a simple Pallas kernel that increments an array:
+
+```python
+@functools.partial(
+  pl.pallas_call,
+  grid=(2,),
+  in_specs=[pl.BlockSpec(block_shape=(128,), index_map=lambda i: (i,))],
+  out_specs=pl.BlockSpec(block_shape=(128,), index_map=lambda i: (i,)),
+  out_shape=jax.ShapeDtypeStruct((256,), jnp.float32),  # Total output shape
+)
+def run_kernel(x_ref, y_ref):
+  # x_ref and y_ref are in SMEM!
+  y_ref[...] = x_ref[...] + 1
+
+x = jnp.arange(256, dtype=jnp.float32)
+y = run_kernel(x)
+np.testing.assert_array_equal(y, x + 1)
+```
+
+We can write a similar kernel using `pl.core_map`. One big difference is that
+unlike `pl.pallas_call`, no GMEM<->SMEM copies will be inserted automatically.
+If you want them, you can either insert them yourself or use the
+{py:func}`plgpu.emit_pipeline <jax.experimental.pallas.mosaic_gpu.emit_pipeline>`
+helper.
+
+```python
+@pl.run_state
+def run_kernel(x_ref, y_ref):
+  # Here, we're not in the kernel yet! pl.run_state simply changes the JAX
+  # immutable arrays into mutable GMEM (not SMEM!) references.
+
+  # Define the mesh: 2 CUDA blocks over 1 axis called "x"
+  mesh = plgpu.Mesh(grid=(2,), grid_names=("x",))
+
+  @pl.core_map(mesh)  # core_map executes the body
+  def kernel_body():
+    # Once we enter the pl.core_map scope, we are in the body of the kernel.
+    block_slice = pl.ds(lax.axis_index("x") * 128, 128)
+    y_ref[block_slice] = x_ref[block_slice] + 1
+
+x = jnp.arange(256, dtype=jnp.float32)
+y_init = jnp.zeros_like(x)
+y = run_kernel(x, y_init)
+np.testing.assert_array_equal(y, x + 1)
+```
+
+While `pl.core_map` is a powerful API, it is also quite low-level and is pretty
+much always used under `pl.run_state` (to turn JAX arrays into refs) or
+`pl.run_scoped` (to allocate scratch refs). For that reason, we also
+provide a convenience API `plgpu.kernel`:
+
+```python
+mesh = plgpu.Mesh(grid=(2,), grid_names=("x",))
+
+@functools.partial(
+  plgpu.kernel,
+  out_shape=jax.ShapeDtypeStruct((256,), jnp.float32),
+  mesh=mesh
+)
+def increment_kernel_core_map(x_ref, y_ref):
+  # x_ref and y_ref are in GMEM!
+  block_slice = pl.ds(lax.axis_index("x") * 128, 128)
+  y_ref[block_slice] = x_ref[block_slice] + 1
+
+x = jnp.arange(256, dtype=jnp.float32)
+y = increment_kernel_core_map(x)  # No need to preallocate outputs as in pl.core_map.
+np.testing.assert_array_equal(y, x + 1)
+```
+
+```{note}
+The `plgpu.Mesh` used with `pl.core_map` defines a topology for computation
+*within a single GPU*, specifying how work is distributed across CUDA blocks
+(the `grid`), Pallas threads within a block (`num_threads`), and potentially
+CUDA block clusters (`cluster`). This is analogous to how `jax.sharding.Mesh`
+defines a topology for distributed computation *across multiple devices* in JAX.
+Both involve SPMD programs executing across the defined topology. Furthermore,
+you can run "collectives" over the Pallas threads and cluster (e.g., using
+`plgpu.ClusterBarrier` or collective async copies), similar to how JAX
+collectives (`psum`, `all_gather`, etc.) operate across devices in a JAX `Mesh`.
+Both also use named axes, and `lax.axis_index(axis_name)` can be used to get a
+thread's or block's coordinate.
+```
+
+### Using multiple Pallas threads per CUDA block
+
+Below, you can find an example of two Pallas threads within a single block
+synchronizing through a barrier and even exchanging data through SMEM.
+
+```python
+mesh = plgpu.Mesh(num_threads=2, thread_name="pallas_thread")
+x = jnp.arange(128, dtype=jnp.float32)
+
+@functools.partial(
+  plgpu.kernel, out_shape=x, mesh=mesh,
+  scratch_shapes=[plgpu.SMEM(x.shape, x.dtype), plgpu.Barrier()]
+)
+def run_kernel(x_ref, y_ref, smem_ref, barrier_ref):
+  thread_id = jax.lax.axis_index("pallas_thread")
+
+  @pl.when(thread_id == 0)
+  def producer_thread():
+    smem_ref[...] = x_ref[...] + 1  # Write x + 1 into the shared SMEM buffer
+    plgpu.barrier_arrive(barrier_ref)  # Signal the consumer thread
+
+  @pl.when(thread_id == 1)
+  def consumer_thread():
+    plgpu.barrier_wait(barrier_ref)  # Wait for the producer thread
+    y_ref[...] = smem_ref[...] + 1  # Add 1 again, for a total of x + 2
+
+y = run_kernel(x)  # There's no need to preallocate the output anymore.
+np.testing.assert_array_equal(y, x + 2)
+```
+
+While this example is simple, you can find a more complicated example in the
+[synchronization section](#cross-thread-synchronization).
+
+Multiple threads are frequently used in high-performance kernels such as the
+latest flash attention variants or ping-pong matrix multiplication. In both of
+those, there are 2 compute threads in the program that use the SM's ALU
+and TensorCore in an alternating fashion to ensure no execution conflicts.
+
+Another common technique is to allocate one Pallas thread and devote it entirely
+to scheduling asynchronous copies for data consumed by other threads. While
+implementing this scheme from scratch can be complicated, we provide a
+convenient helper API: `plgpu.emit_pipeline_warp_specialized`.
+
+### Using CUDA block clusters
+
+The kernel below launches a single cluster of 2 CUDA blocks and uses the TMA
+multicast feature to collectively perform a copy of GMEM into the SMEM of both
+blocks. All blocks participating in the collective copy must schedule the exact
+same copy for the program to be valid.
+
+```python
+mesh = plgpu.Mesh(cluster=(2,), cluster_names=("cluster",))
+
+@functools.partial(
+  plgpu.kernel,
+  out_shape=jax.ShapeDtypeStruct((2, 128), jnp.float32),
+  mesh=mesh,
+  scratch_shapes=[plgpu.SMEM((128,), jnp.float32), plgpu.Barrier()]
+)
+def run_kernel(x_ref, y_ref, smem_ref, barrier_ref):
+  # Specifying collective_axes will enable TMA multicast automatically.
+  plgpu.copy_gmem_to_smem(x_ref, smem_ref, barrier_ref, collective_axes="cluster")
+  plgpu.barrier_wait(barrier_ref)
+  plgpu.copy_smem_to_gmem(smem_ref, y_ref.at[lax.axis_index("cluster")])
+  plgpu.wait_smem_to_gmem(0)
+
+x = jnp.arange(128, dtype=jnp.float32)
+y = run_kernel(x)
+# Each block gets the same data and writes it out.
+np.testing.assert_array_equal(y, jnp.stack([x, x], axis=0))
+```
+
+### Collective allocations in `pl.run_scoped`
+
+When using `pl.core_map` with multiple Pallas threads (i.e., `num_threads > 1`
+in `plgpu.Mesh`), allocations made via `pl.run_scoped` (for SMEM or Barriers)
+must be performed _collectively by all threads_. This is indicated by specifying
+a `collective_axes` argument to `run_scoped`, which has two effects:
+1. it promises that all threads will call the same allocation, and
+2. all threads will receive the exact same allocation.
+
+If `collective_axes` is not specified or does not include the Pallas thread axis,
+each thread would get its own private copy of the scratch variable. This is
+usually undesired and not supported at the moment.
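+
+As a sketch of what a collective allocation looks like inside a kernel launched with
+`num_threads=2` and `thread_name="pallas_thread"` (assuming the `collective_axes` argument
+described above; the scoped body is elided):
+
+```python
+def kernel_body():  # Runs in every Pallas thread.
+  def scoped(smem_ref, barrier_ref):
+    ...  # Both Pallas threads receive the exact same smem_ref and barrier_ref.
+  pl.run_scoped(
+    scoped,
+    plgpu.SMEM((128,), jnp.float32),
+    plgpu.Barrier(),
+    collective_axes="pallas_thread",
+  )
+```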
+
+## Synchronization structures and primitives
+
+In this section, we go over the most important functions and data structures
+used for synchronizing Pallas threads with each other and with asynchronous operations.
+
+### `commit_smem`
+
+Regular reads/writes to references are guaranteed to produce values consistent
+with the sequential program order. For example, in the following program, it is
+guaranteed that `value` is equal to `value2`.
+```python
+ref[...] = value
+value2 = ref[...]
+```
+
+This guarantee, however, does not extend to asynchronous primitives such as async
+copies or MMA operations. To make the SMEM writes visible to those primitives, you
+are required to explicitly synchronize with them using the `plgpu.commit_smem()` function.
+
+For example:
+```python
+smem_ref[...] = value
+plgpu.commit_smem()
+plgpu.copy_smem_to_gmem(smem_ref, ...)
+``` +or: +```python +smem_ref[...] = value +plgpu.commit_smem() +plgpu.wgmma(smem_ref, ...) +``` + +Failing to call this function is likely to cause subtle data races, due to those asynchronous +hardware units reading stale data from SMEM. Unfortunately, this function is relatively expensive, +which is why we rely on you, the user, to insert it in the minimal number of places where it's necessary. + +### `Barrier` + +This is essentially a thin wrapper around an array of PTX `mbarrier` types and is +passed in as a reference. All functions involving barriers expect to only get a single +barrier argument, and so if the reference contains multiple, you have to extract one +of them explicitly using `barriers.at[index]`. `Barrier`s are always allocated in SMEM +and as such have relatively low overheads. Each barrier can be configured to complete +after a fixed number of "arrivals" (by default 1). + +To block a thread until a barrier completes, use the following function: +```python +plgpu.barrier_wait(barrier) +``` + +There are three operations that can complete a barrier: + +```{warning} +It is critical to ensure that the synchronization scheme makes it impossible for two +barrier completions to happen without a call to `plgpu.barrier_wait` in between them. +For example, if you use `Barrier`s to synchronize two producer/consumer threads, you +need to perform barrier synchronization going both ways to introduce "backpressure" +that will stop one thread from arriving twice before the other one had a chance to await. +Failing to satisfy this will corrupt the data structure and can cause surprising failures +(including CUDA runtime errors). See below for an example of a valid program with two threads. +``` + +```{warning} +Another critical restriction is that the number of barrier completions must equal the +number of barrier waits throughout the barrier's lifetime. It is not allowed to end a scoped +allocation of a barrier when it has an unawaited completion. Otherwise, when it is +reused by the compiler, leaving it in this state can cause problems downstream. +``` + +```{warning} +Finally, it is crucial to ensure that each thread that ever waits on a `Barrier` +takes part in all `wait` operations on it. It is not allowed to e.g. await every +other completion of a barrier from one thread, and all other completions from another +one. Doing so will lead to deadlocks. To recap: when a `Barrier` is used to wait in +some thread, it must observe every single completion of that barrier (by waiting on it). + +Note that the `Barrier` can receive arrivals from any source, without restrictions. +``` + +#### Asynchronous GMEM-to-SMEM copies + +When an asynchronous GMEM-to-SMEM copy is being executed by the TMA engine, it will +post progress updates to the barrier given to `plgpu.copy_gmem_to_smem`. Once the copy +is complete, the barrier will complete one arrival as well. + +(cross-thread-synchronization)= +#### Explicit arrival (cross-thread synchronization) + +Any thread can explicitly arrival on a barrier using the following function: +```python +plgpu.barrier_arrive(barrier) +``` + +This is especially useful when synchronizing two threads that are in producer/consumer +roles. In this case, we recommend allocating two arrays of `Barrier`s, with size equal +to the size of the "queue" used to pass data between the two threads. For example, +assume one thread continues writing tiles of an array to SMEM while another thread +reads them. 
We triple-buffer the SMEM region to allow more asynchrony between the two +threads: + +```python +tid = jax.lax.axis_index("thread") +assert queue.shape == (buffering, *item_shape) +assert produced.shape == consumed.shape == (buffering,) + +def thread0_body(i, _): + slot = jax.lax.rem(i, buffering) + @pl.when(i >= buffering) + def _await_consumed(): + plgpu.barrier_wait(consumed.at[slot]) # Wait for consumption of the value before overwriting it + # Option 1: Compute the next value + queue[slot] = produce() + plgpu.barrier_arrive(produced.at[slot]) # Signal the value is ready + # Option 2: Produce the value through async_copy + # plgpu.copy_gmem_to_smem(..., queue.at[slot], barrier=produced.at[slot]) +pl.when(tid == 0)(lambda: jax.lax.fori_loop(0, steps, thread0_body, None)) + +def thread1_body(i, _): + slot = jax.lax.rem(i, buffering) + plgpu.barrier_wait(produced.at[slot]) # Wait for the value to be ready + consume(queue[slot]) # Load and compute + plgpu.barrier_arrive(consumed.at[slot]) # Signal that the value is consumed +pl.when(tid == 1)(lambda: jax.lax.fori_loop(0, steps, thread1_body, None)) +``` + +#### Awaiting `tcgen05` TensorCore instructions + +While Mosaic GPU supports `tcgen05` MMA instructions, exposing this capability to Pallas +is still work in progress. Stay tuned! + +### `ClusterBarrier` + +TODO + +### `Semaphore` + +TODO + +## Asynchronous copies + +TODO + +## Inline Mosaic GPU + +TODO + +## Compiler parameters + +TODO diff --git a/docs/pallas/grid_blockspec.md b/docs/pallas/grid_blockspec.md index ea1df15f2fd4..d360e3e660b5 100644 --- a/docs/pallas/grid_blockspec.md +++ b/docs/pallas/grid_blockspec.md @@ -80,8 +80,14 @@ Not all block shapes are supported. must be equal to the array dimension, or be divisible by `128 * (32 / bitwidth(dtype))`. - * On GPU, the size of the blocks themselves is not restricted, but each - operation must operate on arrays whose size is a power of 2. + * On GPU, when using the Mosaic GPU backend, the size of the blocks is + unrestricted. However, due to hardware limitations, the size of the minormost + array dimension must by such that it is a multiple of 16 bytes. For example, + it must be a multiple of 8 if the input is `jnp.float16`. + + * On GPU, when using the Triton backend, the size of the blocks themselves is + unrestricted, but each operation (including a load or store) must operate + on arrays whose size is a power of 2. ``` If the block shape does not divide evenly the overall shape then the @@ -151,8 +157,7 @@ over the second axis: ```python >>> def show_program_ids(x_shape, block_shape, grid, -... index_map=lambda i, j: (i, j), -... indexing_mode=pl.Blocked()): +... index_map=lambda i, j: (i, j)): ... def program_ids_kernel(o_ref): # Fill the output block with 10*program_id(1) + program_id(0) ... axes = 0 ... for axis in range(len(grid)): @@ -162,7 +167,7 @@ over the second axis: ... out_shape=jax.ShapeDtypeStruct(x_shape, dtype=np.int32), ... grid=grid, ... in_specs=[], -... out_specs=pl.BlockSpec(block_shape, index_map, indexing_mode=indexing_mode), +... out_specs=pl.BlockSpec(block_shape, index_map), ... interpret=True)() ... print(res) @@ -227,7 +232,8 @@ See {ref}`pallas_tpu_noteworthy_properties`. A `None` value appearing as a dimension value in the `block_shape` behaves as the value `1`, except that the corresponding -block axis is squeezed. In the example below, observe that the +block axis is squeezed (you could also pass in `pl.Squeezed()` instead of +`None`). 
In the example below, observe that the shape of the `o_ref` is (2,) when the block shape was specified as `(None, 2)` (the leading dimension was squeezed). @@ -269,27 +275,33 @@ used: `index_map=lambda *invocation_indices: (0,) * len(block_shape)`. ``` -### The "unblocked" indexing mode +### The "element" indexing mode -The behavior documented above applies to the `indexing_mode=pl.Blocked()`. -When using the `pl.Unblocked` indexing mode the values returned by the +The behavior documented above applies to the default "blocked" indexing mode. +When integers are used in the `block_shape` tuple e.g. `(4, 8)`, it is +equivalent to passing in a `pl.Blocked(block_size)` object instead, e.g. +`(pl.Blocked(4), pl.Blocked(8))`. Blocked indexing mode means the indices +returned by `index_map` are *block indices*. We can pass in objects other than +`pl.Blocked` to change the semantics of `index_map`, most notably, +`pl.Element(block_size)`.. +When using the `pl.Element` indexing mode the values returned by the index map function are used directly as the array indices, without first scaling them by the block size. -When using the unblocked mode you can specify virtual padding -of the array as a tuple of low-high paddings for each dimension: the +When using the `pl.Element` mode you can specify virtual padding +of the array as a tuple of low-high paddings for the dimension: the behavior is as if the overall array is padded on input. No guarantees -are made for the padding values in the unblocked mode, similarly to the padding +are made for the padding values in element mode, similarly to the padding values for the blocked indexing mode when the block shape does not divide the overall array shape. -The unblocked mode is currently supported only on TPUs. +The `Element` mode is currently supported only on TPUs. ```python ->>> # unblocked without padding ->>> show_program_ids(x_shape=(8, 6), block_shape=(2, 3), grid=(4, 2), -... index_map=lambda i, j: (2*i, 3*j), -... indexing_mode=pl.Unblocked()) +>>> # element without padding +>>> show_program_ids(x_shape=(8, 6), block_shape=(pl.Element(2), pl.Element(3)), +... grid=(4, 2), +... index_map=lambda i, j: (2*i, 3*j)) [[ 0 0 0 1 1 1] [ 0 0 0 1 1 1] [10 10 10 11 11 11] @@ -299,10 +311,12 @@ The unblocked mode is currently supported only on TPUs. [30 30 30 31 31 31] [30 30 30 31 31 31]] ->>> # unblocked, first pad the array with 1 row and 2 columns. ->>> show_program_ids(x_shape=(7, 7), block_shape=(2, 3), grid=(4, 3), -... index_map=lambda i, j: (2*i, 3*j), -... indexing_mode=pl.Unblocked(((1, 0), (2, 0)))) +>>> # element, first pad the array with 1 row and 2 columns. +>>> show_program_ids(x_shape=(7, 7), +... block_shape=(pl.Element(2, (1, 0)), +... pl.Element(3, (2, 0))), +... grid=(4, 3), +... index_map=lambda i, j: (2*i, 3*j)) [[ 0 1 1 1 2 2 2] [10 11 11 11 12 12 12] [10 11 11 11 12 12 12] diff --git a/docs/pallas/index.rst b/docs/pallas/index.rst index b2e2fca6c82e..8e1a9816212c 100644 --- a/docs/pallas/index.rst +++ b/docs/pallas/index.rst @@ -22,15 +22,22 @@ See also the :class:`jax.experimental.pallas` module API documentation. :maxdepth: 2 quickstart + pipelining grid_blockspec .. toctree:: - :caption: Platform Features + :caption: TPU backend guide :maxdepth: 2 tpu/index +.. toctree:: + :caption: Mosaic GPU backend guide + :maxdepth: 2 + + gpu/index + .. 
toctree:: :caption: Design Notes :maxdepth: 2 diff --git a/docs/pallas/pipelining.ipynb b/docs/pallas/pipelining.ipynb new file mode 100644 index 000000000000..6a4158001813 --- /dev/null +++ b/docs/pallas/pipelining.ipynb @@ -0,0 +1,870 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "C93Xlf0DRW9H" + }, + "source": [ + "\n", + "(pallas_software_pipelining)=\n", + "\n", + "# Software Pipelining\n", + "\n", + "Software pipelining is an important technique in performance optimization by overlapping multiple asynchronous operations even if there are data dependencies between them. In the context of kernel writing, the most common form of pipelining involves overlapping communication and memory transfers with compute such that the hardware accelerator never stalls while waiting for data to arrive. Therefore, we will solely focus on the problem of communication-compute pipelining in this tutorial. We will begin by covering the problem conceptually, outlining the Pallas API for writing pipelines, and going over some realistic examples using the API.\n", + "\n", + "This tutorial only covers the conceptual foundations of pipelining. For platform-specific references, please see {ref}`pallas_tpu_pipelining`, or {ref}`pallas_mgpu_pipelining`.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "YkOjspo5BKPD" + }, + "outputs": [], + "source": [ + "import jax\n", + "from jax import numpy as jnp\n", + "from jax.experimental import pallas as pl\n", + "import numpy as np" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "shnVghWUSvpx" + }, + "source": [ + "## Memory Hierarchies\n", + "\n", + "The first step in understanding pipelining conceptually involves understanding the different forms of memory available and the tradeoffs between them. Most hardware architectures (including CPUs, GPUs, and TPUs) utilize a wide variety of memory spaces that tradeoff capacity vs latency/bandwidth. For the purpose of Pallas, we are typically interested in registers, SRAM, DRAM, and potentially network communication:\n", + "- **Registers** are the the memory physically closest to the processor, and typically values must be loaded directly into registers before doing any compute on them.\n", + "- **SRAM** (also known as Shared Memory/L1 and L2 cache on GPUs, or VMEM on TPUs) also lives fairly close to the processor, but has larger capacity than registers.\n", + "SRAM on modern ML accelerators typically range in the 10-100MB range (TPU v5p contains 96MB of VMEM, and H100 GPUs contain ~30MB of L1 cache and 50MB of L2).\n", + "It's reasonable to expect the latency to access SRAM to be on the order of 10x longer than accessing a register.\n", + "- **DRAM** (also known as HBM) has much higher capacity than SRAM, typically in the 10-100GB range for modern ML accelerators. However, the latency is roughly on the order of 10x longer to access compared to SRAM.\n", + "- **Network** communication becomes crucial for larger workloads when the size of DRAM on a single device becomes insufficient or when we'd like to take advantage of parallel computations. 
We do not cover distributed pipelining in this tutorial, but see the [distributed TPU kernels](https://docs.jax.dev/en/latest/pallas/tpu/distributed.html) guide for writing pipelines across multiple devices.\n", + "\n", + "\n", + "\n", + "\n", + "![memory_hierarchy](../_static/pallas/pipelining_mem_hierarchy.svg)\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "WvW6Lo7d2jfb" + }, + "source": [ + "\n", + "In order to perform computation on values X and Y that live in HBM, we need to:\n", + "\n", + "1. Copy the values x and y into SRAM.\n", + "2. Load the values from SRAM into registers.\n", + "3. Execute the computation and store the result into registers.\n", + "4. Store the values in the output registers into SRAM.\n", + "5. Copy the output values in SRAM back to HBM.\n", + "\n", + "Let’s implement a Pallas function that does just that!" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "executionInfo": { + "elapsed": 108, + "status": "ok", + "timestamp": 1744764235906, + "user": { + "displayName": "Justin Fu", + "userId": "17543197034567316452" + }, + "user_tz": 420 + }, + "id": "IrPhDFnT3Nvw", + "outputId": "8bc03872-fd9f-4610-9d53-d4b46be560f4" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "Array([[2., 2., 2., ..., 2., 2., 2.],\n", + " [2., 2., 2., ..., 2., 2., 2.],\n", + " [2., 2., 2., ..., 2., 2., 2.],\n", + " ...,\n", + " [2., 2., 2., ..., 2., 2., 2.],\n", + " [2., 2., 2., ..., 2., 2., 2.],\n", + " [2., 2., 2., ..., 2., 2., 2.]], dtype=float32)" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Note: This is a TPU example.\n", + "\n", + "def add_matrices_kernel(x_sram_ref, y_sram_ref, z_sram_ref):\n", + " # Load x and y from SRAM into registers\n", + " x_regs = x_sram_ref[:, :]\n", + " y_regs = y_sram_ref[:, :]\n", + " # Execute a vectorized add\n", + " z_regs = x_regs + y_regs\n", + " # Store the output values in registers back into SRAM\n", + " z_sram_ref[:, :] = z_regs\n", + "\n", + "\n", + "def add_matrices(x: jax.Array, y: jax.Array) -> jax.Array:\n", + " # pallas_call will first allocate scratch buffers for `x` and `y` in SRAM.\n", + " # It will then copy `x` and `y` from HBM into SRAM.\n", + " z = pl.pallas_call(\n", + " add_matrices_kernel, out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype)\n", + " )(x, y)\n", + " # pallas_call will also copy the output from SRAM back into HBM.\n", + " return z\n", + "\n", + "\n", + "x, y = jnp.ones((512, 512)), jnp.ones((512, 512))\n", + "add_matrices(x, y)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "gGjtwv9u3UNK" + }, + "source": [ + "We've written two functions: `add_matrices_kernel` and `add_matrices`.\n", + "\n", + "`add_matrices_kernel` operates using `Refs` that live in SRAM. Loading from a SRAM Ref produces a value that lives in registers. Values in registers behave like jax.Arrays in that we can use `jnp` and `jax.lax` operations on them to produce new values that live in registers. When we produce the values we'd like to return, we store them in the output SRAM `Ref`.\n", + "\n", + "The `add_matrices` function acts on `jax.Array`s and returns a `jax.Array`. Inside it, we pass `x` and `y` into pallas_call. `pallas_call` is responsible for copying `x` and `y` into SRAM and for allocating the SRAM buffers that the kernel operates on (including allocating `z_vmem_ref`, the output SRAM buffer). 
After the kernel function is finished running, `pallas_call` will also copy the value in `z_vmem_ref` to HBM, resulting in an output `jax.Array`.\n", + "\n", + "Pallas exposes access to lower level memory spaces like SRAM but writing performant kernels requires more care in utilizing the various memory spaces. For example, we need to consider both:\n", + "\n", + "- **Memory capacity**. SRAM is small! If our arrays are too big, the above kernel would not work because we cannot fit the input into SRAM. For reference, an `f32[2048, 2048]` array is 16MiB, so our above kernel won't scale beyond moderately sized arrays.\n", + "\n", + "- **Memory bandwidth**. Copying to/from HBM and SRAM takes a long time, at least compared to most compute instructions. The `add_matrices` function above will likely spend more time copying between HBM and SRAM than actually performing the addition itself.\n", + "\n", + "With these two constraints in mind, we'll have to rethink our strategy for getting performance out of our accelerators.\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "0Ebs2pCDgsEW" + }, + "source": [ + "## Pipelining Basics\n", + "\n", + "\n", + "How can we take advantage of the strengths of each form of type memory in the hierarchy, and be able to operate on large arrays stored in HBM while still utilizing fast SRAM for compute? Pipelining is a very general programming pattern which will allow us to do exactly this, but it requires transforming your problem into smaller sub-problems that can be overlapped in parallel.\n", + "\n", + "The first step in pipelining is to divide our problem into smaller subproblems that can fit inside of SRAM. For example, an elementwise operation is can be trivially transformed by operating on one slice of the source array at a time, which results in the following 3 steps (also known as stages): \n", + "\n", + "1. **copy_in**: Copy a slice `A[i]` from HBM to SRAM `X`.\n", + "2. **compute**: Load `X` into registers, compute a result, and store in SRAM `Y`\n", + "3. **copy_out**: Copy result `Y` back into HBM `A[i]`.\n", + "\n", + "Note that there is a data-dependence between steps 1-3, and we cannot trivially overlap them since we need step (1) to complete before starting step (2), and so on. However, there is no data dependence across multiple invocations of the subproblem - that is, we can execute step (1) for block `A[i+1]` while executing step (2) for block `A[i]` and step (3) for block `A[i-1]`.\n", + "\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "8vCtShhBjzTd" + }, + "source": [ + "\n", + "![pipelining_example](../_static/pallas/pipelining_example.svg)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Qs3F--kwiOJm" + }, + "source": [ + "The diagram above depicts how an idealized pipelined program can be scheduled across time. The key insight is that in the majority of the kernel, the copy operations are executed in parallel with compute operations, meaning we can ideally \"hide\" the cost of transferring between HBM/SRAM with computation and keep the processor busy with as much uptime as possible.\n", + "\n", + "The initial startup time and final teardown time known as \"bubbles\", where only a subset of the stages are being executed while the pipeline is being \"filled\" or \"drained\". The bulk of the time is spent in the \"steady-state\" phase of the pipeline, where each pipeline stage is being executed in parallel across different iterations of the subproblem. 
While with more general pipelining approaches the goal is to achieve N-way parallelism (where N is the number of stages), with kernel pipelining we are usually bottlenecked either by memory bandwidth or processing speed. Therefore, our goal with kernel pipelining is typically to achieve full utilization of the FLOPs/s of our processor, meaning that at any point in time there is always a `compute` block active. In the figure above, the compute block is active in 6/8 timeslots, and assuming we are fully utilizing the processor in each compute timeslot, we would have achieved 75% utilization of the processor." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ZcSzl4N6pPbG" + }, + "source": [ + "### Deriving a Double-Buffered Pipeline\n", + "\n", + "Now lets look at how we could implement a pipeline in pseudocode. Consider the following elementwise program, where we load values from HBM (`A[i]`) with a `copy_in` instruction, add 1 to the result, and store the result back to HBM with `copy_out`:\n", + "\n", + "
\n",
+    "for i in range(N):\n",
+    "  copy_in(A[i], X)\n",
+    "  Y = X + 1\n",
+    "  copy_out(Y, A[i])\n",
+    "
\n", + "The issue with this approach is that `copy_in` and `copy_out` are typically blocking operations. So we are forced to wait for the copies to finish while the GPU/TPU is idle, then perform compute while the memory is idle. What we would like to do is to \"pre-fetch\" the input value that is required on the next iteration of the loop asynchronously while performing the computation for the current loop, so that compute and memory communication are happening simultaneously.\n", + "\n", + "In order to reason about the code transformation we will make, lets unroll the loop for N=4, and decompose the copy instructions into separate `copy_start` and `copy_wait` operations to be able to express asynchrony:\n", + "
\n",
+    "  # Itr 1\n",
+    "  copy_in_start(A[0], X)\n",
+    "  copy_in_wait(X)\n",
+    "  Y = X + 1\n",
+    "  copy_out_start(Y, A[0])\n",
+    "  copy_out_wait(Y)\n",
+    "\n",
+    "  # Itr 2\n",
+    "  copy_in_start(A[1], X)\n",
+    "  copy_in_wait(X)\n",
+    "  Y = X + 1\n",
+    "  copy_out_start(Y, A[1])\n",
+    "  copy_out_wait(Y)\n",
+    "\n",
+    "  # Itr 3\n",
+    "  copy_in_start(A[2], X)\n",
+    "  copy_in_wait(X)\n",
+    "  Y = X + 1\n",
+    "  copy_out_start(Y, A[2])\n",
+    "  copy_out_wait(Y)\n",
+    "\n",
+    "  # Itr 4\n",
+    "  copy_in_start(A[3], X)\n",
+    "  copy_in_wait(X)\n",
+    "  Y = X + 1\n",
+    "  copy_out_start(Y, A[3])\n",
+    "  copy_out_wait(Y)\n",
+    "
\n", + "\n", + "Once the loop has been unrolled, the pipelining transformation simply involves issuing `copy_start` instructions as early as possible, and `copy_wait` values as late as possible (right before we need the value). However, in the current state of the loop there is a fake data dependency through X - we cannot simultaneously perform an async copy into X while using it for computation or else we may have a race condition. Therefore, we can use a **multiple-buffering** technique where we keep 2 buffers for each input X and each output Y. With 2 buffers, we can push the `copy_in_start` one iteration ahead (with 3 buffers you can push 2 iterations, and so on) and we rewrite our loop as follows:\n", + "
\n",
+    "  # Prologue\n",
+    "  copy_in_start(A[0], X[0])\n",
+    "  \n",
+    "  # Itr 1\n",
+    "  copy_in_start(A[1], X[1])\n",
+    "  copy_in_wait(X[0])\n",
+    "  Y[0] = X[0] + 1\n",
+    "  copy_out_start(Y[0], A[0])\n",
+    "  copy_out_wait(Y[0])\n",
+    "\n",
+    "  # Itr 2 - Steady state\n",
+    "  copy_in_start(A[2], X[0])\n",
+    "  copy_in_wait(X[1])\n",
+    "  Y[1] = X[1] + 1\n",
+    "  copy_out_start(Y[1], A[1])\n",
+    "  copy_out_wait(Y[1])\n",
+    "\n",
+    "  # Itr 3 - Steady state\n",
+    "  copy_in_start(A[3], X[1])\n",
+    "  copy_in_wait(X[0])\n",
+    "  Y[0] = X[0] + 1\n",
+    "  copy_out_start(Y[0], A[2])\n",
+    "  copy_out_wait(Y[0])\n",
+    "\n",
+    "  # Itr 4 - No copy-in\n",
+    "  copy_in_wait(X[1])\n",
+    "  Y[1] = X[1] + 1\n",
+    "  copy_out_start(Y[1], A[3])\n",
+    "  copy_out_wait(Y[1])\n",
+    "
\n", + "\n", + "Next, we can push the `copy_out_wait` as late as possible, right before we need to write into Y on the subsequent loop iteration.\n", + "\n", + "
\n",
+    "  # Prologue\n",
+    "  copy_in_start(A[0], X[0])\n",
+    "  \n",
+    "  # Itr 1\n",
+    "  copy_in_start(A[1], X[1])\n",
+    "  copy_in_wait(X[0])\n",
+    "  Y[0] = X[0] + 1\n",
+    "  copy_out_start(Y[0], A[0])\n",
+    "\n",
+    "  # Itr 2 - Steady state\n",
+    "  copy_in_start(A[2], X[0])\n",
+    "  copy_in_wait(X[1])\n",
+    "  Y[1] = X[1] + 1\n",
+    "  copy_out_start(Y[1], A[1])\n",
+    "  copy_out_wait(Y[0])\n",
+    "\n",
+    "  # Itr 3 - Steady state\n",
+    "  copy_in_start(A[3], X[1])\n",
+    "  copy_in_wait(X[0])\n",
+    "  Y[0] = X[0] + 1\n",
+    "  copy_out_start(Y[0], A[2])\n",
+    "  copy_out_wait(Y[1])\n",
+    "\n",
+    "  # Itr 4 - No copy-in\n",
+    "  copy_in_wait(X[1])\n",
+    "  Y[1] = X[1] + 1\n",
+    "  copy_out_start(Y[1], A[3])\n",
+    "  copy_out_wait(Y[0])\n",
+    "\n",
+    "  # Epilogue\n",
+    "  copy_out_wait(Y[1])\n",
+    "
\n", + "\n", + "Finally, re-rolling our loop back into a for loop, we obtain the following pipelined loop:\n", + "\n", + "```\n", + "# Prologue\n", + "copy_in_start(A[0], X[0])\n", + "\n", + "# Main loop\n", + "for i in range(N):\n", + " cur_slot = i % 2\n", + " next_slot = (i + 1) % 2\n", + "\n", + " if i < N:\n", + " copy_in_start(A[i+1], X[next_slot])\n", + " \n", + " copy_in_wait(X[cur_slot])\n", + " Y[cur_slot] = X[cur_slot] + 1\n", + " copy_out_start(Y[cur_slot], A[i])\n", + "\n", + " if i > 0:\n", + " copy_out_wait(Y[next_slot])\n", + "\n", + "# Epilogue\n", + "copy_out_wait(Y[1])\n", + "```\n", + "\n", + "If we want to generalize this loop to handle a broader set of computations, notice that we essentially need to specify 3 pieces of information to the pipeline:\n", + "\n", + "- The **grid**, or the bounds of the for loop that specifies the number of subproblems to compute. In our example we had a 1-dimensional grid with size `(N,)`.\n", + "- The **kernel**, or the actual computation happening once the inputs have been loaded into SRAM. In our example we performed an elementwise addition `Y = X + 1`.\n", + "- The **data_slices**, which map a subproblem to corresponding slices into the HBM buffer. In our example the data slice was the identity function `lambda i: i`.\n", + "\n", + "By allowing the user to specify these pieces of information we can write a wide variety of programs following this pattern:\n", + "```python\n", + "def double_buffered_pipeline(\n", + " grid: tuple[int, ...],\n", + " kernel: Callable,\n", + " in_slices: Callable,\n", + " out_slices: Callable):\n", + " # Prologue\n", + " copy_in_start(in_hbm[in_slices(0)], in_sram[0])\n", + "\n", + " # Main loop\n", + " grid_size = prod(grid)\n", + " for i in range(grid_size):\n", + " cur_slot = i % 2\n", + " next_slot = (i + 1) % 2\n", + " if (i + 1) < grid_size:\n", + " copy_in_start(in_hbm[in_slices(i+1)], in_sram[next_slot])\n", + " copy_in_wait(in_sram[cur_slot])\n", + "\n", + " kernel(in_sram[cur_slot], out_ram[cur_slot])\n", + "\n", + " copy_out_start(out_sram[cur_slot], out_hbm[out_slices(i)])\n", + " if i > 0:\n", + " copy_out_wait(out_sram[next_slot])\n", + "\n", + " # Epilogue\n", + " last_slot = (grid_size - 1) % 2\n", + " copy_out_wait(out_sram[last_slot])\n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ziBuvv8jDgxo" + }, + "source": [ + "Now that we've seen how to manually implement a pipelined loop, let's look into how to use the Pallas API." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "niMr39cPkJ2m" + }, + "source": [ + "## Pallas Pipelining API\n", + "\n", + "Pallas offers a pipelining API that abstracts away the boilerplate of maintaining multiple buffers and overlapping asynchronous communication with computation. The basics of this API are covered in {ref}`pallas_quickstart`, so we will go over the API briefly here for completeness and discuss some sharp edges that arise from the use of pipelining.\n", + "\n", + "\n", + "### Grid\n", + "\n", + "The program **grid** is a tuple of integers specifying the number of subproblems as an array. The structure of the pipeline can be interpreted as a nested for-loop where the bounds of each loop.\n", + "\n", + "```\n", + "# For grid (N, M, K)\n", + "for n in range (N):\n", + " for m in range(M):\n", + " for k in range(K):\n", + " kernel()\n", + "```\n", + "\n", + "The kernel will be invoked a total of `prod(grid)` times. 
For more details, see [grid and blockspecs](https://docs.jax.dev/en/latest/pallas/grid_blockspec.html#grid-a-k-a-kernels-in-a-loop).\n", + "\n", + "### BlockSpecs\n", + "\n", + "A BlockSpec specifies the size and slice of data copied to the kernel on each subproblem. The basic constructor to `pl.BlockSpec` involves specifying the `block_shape`, the size of a slice of data, and `index_map`, a function that takes in the program ids of the current subproblem and outputs _blocked_ indices into the source buffer. Blocked indices specify which block to copy on each iteration, assuming the source buffer has been carved into blocks of shape as `block_shape`. The `memory_space` argument specifies what memory space to copy the inputs to - be default this will be SRAM.\n", + "\n", + "```python\n", + "pl.BlockSpec(\n", + " block_shape: tuple[int, ...],\n", + " index_map: Callable,\n", + " memory_space: pl.MemorySpace\n", + ")\n", + "```\n", + "There should be one BlockSpec for each input and each output to the kernel. For more details, see [grid and blockspecs](https://docs.jax.dev/en/latest/pallas/grid_blockspec.html#grid-a-k-a-kernels-in-a-loop).\n", + "\n", + "### Kernel\n", + "\n", + "The kernel function specifies what compute to perform on each subproblem. The kernel function should return no outputs, and instead all outputs should be written into the output buffers that are passed into the kernel. All inputs and output buffers are SRAM buffers by default (unless the user has overridden the behavior by specifying a `memory_space` on the corresponding `BlockSpec`).\n", + "\n", + "```python\n", + "def kernel(*input_buffers, *output_buffers):\n", + " # ... perform compute\n", + " # ... store result into output buffers\n", + "```\n", + "\n", + "The index of the current subproblem can be queried inside the kernel using `pl.program_id(grid_axis: int)`.\n", + "\n", + "\n", + "### Pallas Call\n", + "\n", + "The `pl.pallas_call` function is the main entry point to Pallas and performs pipelined execution when a grid and BlockSpecs are supplied. It has the following signature:\n", + "```python\n", + "def pallas_call(\n", + " kernel,\n", + " grid: tuple[int, ...],\n", + " in_specs: Sequence[PyTree[BlockSpec]],\n", + " out_specs: PyTree[BlockSpec],\n", + " out_shape: PyTree[jax.ShapeDtypeStruct],\n", + ") -> Callable:\n", + "```\n", + "`pallas_call` will return a callable function that when invoked with input values, will return outputs of the same shape as `out_shape`.\n", + "\n", + "`in_specs`, `out_specs`, and `out_shape` are PyTrees of their respective element type. The PyTrees for `in_specs` and the input buffers supplied to the kernel should match, and the PyTrees for `out_specs` and `out_shape` should also match.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "0mHZ63eAq_8j" + }, + "source": [ + "### Example - Elementwise Kernel revisited\n", + "\n", + "Let's revisit the initial `add_matrices_kernel` from the beginning of the tutorial, except using pipelining. We will add two input arrays of shape `f32[4096, 4096]` that live in HBM. As subproblems, we will carve up the inputs into `block_shape=(512, 512)` blocks and only add two blocks together at a time in the kernel. Because addition is elementwise, each `index_map` is identical and selects out the `i, j`th block on the `i, j`th iteration." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "iqr_qjONAHN9" + }, + "outputs": [], + "source": [ + "# Note: This is a TPU example.\n", + "\n", + "total_shape = (4096, 4096)\n", + "block_shape = (512, 512)\n", + "\n", + "def add_matrices_pipelined_kernel(x_ref, y_ref, o_ref):\n", + " o_ref[...] = x_ref[...] + y_ref[...]\n", + "\n", + "def add_matrices_pipelined(x: jax.Array, y: jax.Array):\n", + " return pl.pallas_call(\n", + " add_matrices_pipelined_kernel,\n", + " grid=tuple(total // block for (total, block) in zip(total_shape, block_shape)),\n", + " in_specs=[\n", + " pl.BlockSpec(block_shape, index_map=lambda i, j: (i, j)),\n", + " pl.BlockSpec(block_shape, index_map=lambda i, j: (i, j))\n", + " ],\n", + " out_specs=pl.BlockSpec(block_shape, index_map=lambda i, j: (i, j)),\n", + " out_shape=jax.ShapeDtypeStruct(total_shape, dtype=jnp.float32),\n", + " )(x, y)\n", + "\n", + "x = jax.random.uniform(jax.random.key(0), total_shape, dtype=jnp.float32)\n", + "y = jax.random.uniform(jax.random.key(1), total_shape, dtype=jnp.float32)\n", + "result = add_matrices_pipelined(x, y)\n", + "np.testing.assert_array_equal(\n", + " result, x + y\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "UWHD0_qm6DL7" + }, + "source": [ + "It turns out that with this API, writing a pipelined kernel is not much more lines of code than writing our original naive addition kernel!" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "BZ-4U6Cv6cvU" + }, + "source": [ + "### Parameterizing a Kernel\n", + "\n", + "It's common to parameterize the block shapes in our kernel. Block sizes are perhaps the most important parameter to tune when optimizing the performance of Pallas kernels! They give us control over the pipeline (for example, picking smaller blocks adds more iterations to our pipelined loop where each iteration has less work to do). 
Let's write a a function that does so:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "RZTAiwrZ6srD" + }, + "outputs": [], + "source": [ + "def add_matrices_pipelined_param(\n", + " x: jax.Array, y: jax.Array, *, bm: int = 256, bn: int = 256\n", + ") -> jax.Array:\n", + " m, n = x.shape\n", + " block_spec = pl.BlockSpec((bm, bn), lambda i, j: (i, j))\n", + " return pl.pallas_call(\n", + " add_matrices_kernel,\n", + " out_shape=x,\n", + " in_specs=[block_spec, block_spec],\n", + " out_specs=block_spec,\n", + " grid=(m // bm, n // bn),\n", + " )(x, y)\n", + "\n", + "np.testing.assert_array_equal(\n", + " add_matrices_pipelined_param(x, y, bm=256, bn=256), x + y\n", + ")\n", + "np.testing.assert_array_equal(\n", + " add_matrices_pipelined_param(x, y, bm=128, bn=128), x + y\n", + ")\n", + "np.testing.assert_array_equal(\n", + " add_matrices_pipelined_param(x, y, bm=512, bn=512), x + y\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "vO8VkbYj_ral" + }, + "source": [ + "## Sharp edges\n", + "\n", + "While pipelining provides a close approximation to the mental model of simply calling a kernel function in a loop, there are a number of sharp edges that arise from the use of intermediate buffers that are not fully hidden from the user and can result in subtle bugs.\n", + "\n", + "### Buffer Revisiting\n", + "\n", + "In general, a good rule-of-thumb to follow is that **the input buffers passed into the kernel function should be interpreted as read-only, and output buffers are write only**.\n", + "\n", + "Writing to inputs and reading from outputs will in most cases result in incorrectness. This is because the SRAM buffers passed to a kernel only contain copies of the data contained in the underlying HBM buffer. If an input SRAM buffer is updated, the updated results will never be written back out to HBM, and if an output buffer is updated, it's updated value is never read into SRAM. This issue is analogous to staleness issues encountered when using caches in general.\n", + "\n", + "There are two cases where a buffer supports both reads and writes - accumulation (discussed next), and marking a pair of input and output buffers as input-output aliased by passing in the `input_output_aliases` argument to `pallas_call`.\n", + "\n", + "\n", + "### Reductions and accumulation\n", + "\n", + "**Reduction/accumulation should only be performed over the last (innermost) dimensions of the grid, and the buffer should be initialized manually first.**\n", + "\n", + "Reductions are one of the few cases where the pipeline supports both reading and writing to an output buffer, but the reason it works is subtle.\n", + "The Pallas pipeline emitter performs an optimization where if the data slices between two consecutive iterations are the same, the pipeline will not issue a `copy_in`/`copy_out` on that buffer. This means the same SRAM buffer used in a previous iteration will be passed into the kernel again on the following iteration, and thus any writes that were issued to the output buffer will become visible on the next iteration. Once the data slice changes, the final accumulated SRAM buffer will be written out to HBM. 
This is also why reductions must be performed over the last dimensions of the grid -- we want to finish all of the accumulation while the output buffer is in SRAM in the innermost loop, then write it to HBM and never touch that output block again.\n", + "\n", + "As a concrete example, let's consider performing the following computation for reducing an `(8, 1024, 1024)` array along the first axies into a `(1024, 1024)` array.\n", + "\n", + "\n", + "\n", + "\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "executionInfo": { + "elapsed": 244, + "status": "ok", + "timestamp": 1744763773938, + "user": { + "displayName": "Justin Fu", + "userId": "17543197034567316452" + }, + "user_tz": 420 + }, + "id": "4qz1ET-_f9fJ", + "outputId": "e43067ef-933a-45a5-912a-e224151cfa60" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "Array([[8., 8., 8., ..., 8., 8., 8.],\n", + " [8., 8., 8., ..., 8., 8., 8.],\n", + " [8., 8., 8., ..., 8., 8., 8.],\n", + " ...,\n", + " [8., 8., 8., ..., 8., 8., 8.],\n", + " [8., 8., 8., ..., 8., 8., 8.],\n", + " [8., 8., 8., ..., 8., 8., 8.]], dtype=float32)" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "x = jnp.ones((8, 1024, 1024))\n", + "jnp.sum(x, axis=0)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "yX762DRrgCOG" + }, + "source": [ + "To do this using `pallas_call`, we could use a grid of size `(8,)` and in each iteration i load `x[i]` into SRAM. Then we could add `x[i]` to an output SRAM buffer. Let's implement this naively first." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "executionInfo": { + "elapsed": 79, + "status": "ok", + "timestamp": 1744763774254, + "user": { + "displayName": "Justin Fu", + "userId": "17543197034567316452" + }, + "user_tz": 420 + }, + "id": "ZEi1_vQVf-81", + "outputId": "581744b7-ddc1-4dc1-98ec-03c852772eda" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[[65. 65. 65. ... 66. 66. 66.]\n", + " [65. 65. 65. ... 66. 66. 66.]\n", + " [65. 65. 65. ... 66. 66. 66.]\n", + " ...\n", + " [71. 71. 71. ... 72. 72. 72.]\n", + " [71. 71. 71. ... 72. 72. 72.]\n", + " [71. 71. 71. ... 72. 72. 72.]]\n" + ] + } + ], + "source": [ + "# Note: This is a TPU example.\n", + "\n", + "# Warning: this implementation is incorrect!\n", + "def incorrect_sum_kernel(x_ref, o_ref):\n", + " o_ref[...] += x_ref[...]\n", + "\n", + "def incorrect_sum(x: jax.Array,\n", + " block_size: tuple[int, ...] = (256, 256)) -> jax.Array:\n", + " reduction_size, *out_shape = x.shape\n", + " grid = (reduction_size, *(out // blk for out, blk in zip(out_shape, block_size)))\n", + " return pl.pallas_call(\n", + " incorrect_sum_kernel,\n", + " grid=grid,\n", + " # None in `block_shape` means we pick a size of 1 and squeeze it away\n", + " in_specs=[pl.BlockSpec((None, *block_size), lambda i, j, k: (i, j, k))],\n", + " out_specs=pl.BlockSpec(block_size, lambda i, j, k: (j, k)),\n", + " out_shape=jax.ShapeDtypeStruct(out_shape, x.dtype),\n", + " )(x)\n", + "\n", + "result = incorrect_sum(x)\n", + "print(result)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "MglScPDD9618" + }, + "source": [ + "This result is completely wrong!\n", + "\n", + "There are two errors inside this kernel. First, we are accumulating along the first grid dimension instead of the last grid dimension. 
Second, `o_ref` initially contains garbage values and thus we need to initialize it to zeros before we begin accumulation.\n", + "\n", + "After fixing these two issues, we obtain the following corrected kernel. In this new kernel, we use `@pl.when` to create a conditional that checks when the program ID is `0` along the reduction axis, indicating we are beginning to accumulate into a new output block. We have also moved the reduction dimension to the last axis of the `grid`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "executionInfo": { + "elapsed": 104, + "status": "ok", + "timestamp": 1744763774523, + "user": { + "displayName": "Justin Fu", + "userId": "17543197034567316452" + }, + "user_tz": 420 + }, + "id": "XtgD4nMa9_Bd", + "outputId": "9ef07cdf-9e22-4dc8-c17f-c96172639801" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[[8. 8. 8. ... 8. 8. 8.]\n", + " [8. 8. 8. ... 8. 8. 8.]\n", + " [8. 8. 8. ... 8. 8. 8.]\n", + " ...\n", + " [8. 8. 8. ... 8. 8. 8.]\n", + " [8. 8. 8. ... 8. 8. 8.]\n", + " [8. 8. 8. ... 8. 8. 8.]]\n" + ] + } + ], + "source": [ + "# Note: This is a TPU example.\n", + "\n", + "def correct_sum_kernel(x_ref, o_ref):\n", + " @pl.when(pl.program_id(2) == 0)\n", + " def _():\n", + " o_ref[...] = jnp.zeros_like(o_ref)\n", + " o_ref[...] += x_ref[...]\n", + "\n", + "def correct_sum(x: jax.Array,\n", + " block_size: tuple[int, ...] = (256, 256)) -> jax.Array:\n", + " reduction_size, *out_shape = x.shape\n", + " # We moved the reduction to the last axis of the grid.\n", + " grid = (*(out // blk for out, blk in zip(out_shape, block_size)), reduction_size)\n", + " return pl.pallas_call(\n", + " correct_sum_kernel,\n", + " grid=grid,\n", + " # None in `block_shape` means we pick a size of 1 and squeeze it away\n", + " in_specs=[pl.BlockSpec((None, *block_size), lambda i, j, k: (k, i, j))],\n", + " out_specs=pl.BlockSpec(block_size, lambda i, j, k: (i, j)),\n", + " out_shape=jax.ShapeDtypeStruct(out_shape, x.dtype),\n", + " )(x)\n", + "\n", + "result = correct_sum(x)\n", + "print(result)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "BckuFg6qcnVw" + }, + "source": [ + "\n", + "## Analyzing the performance\n", + "\n", + "What is the performance of a pipelined kernel? This question can vary depending on where the bottleneck is the hardware is. We are typically interested in 3 quantities:\n", + "- **Memory latency** $α$, the minimum latency of a memory transfer.\n", + "- **Memory bandwidth** $β$, the rate in bytes/second that we can transfer from HBM to SRAM.\n", + "- **FLOP/s** $F$, or floating-point-operations per second, the number of calculations per second that the processor can perform.\n", + "\n", + "We refer to a program as **compute-bound** if the processing speed FLOPs/s is the bottleneck, and as **memory-bound** if the bandwidth or latency are the bottleneck. Generally, our goal is to optimize a kernel such that it is compute-bound, meaning we are utilizing all of the available processing power of our hardware.\n", + "\n", + "Suppose we are running a program that requires $X$ bytes of memory transfers per kernel iteration, and runs $Y$ floating-point operations per iteration. The ratio of $X$ to $Y$ varies depending on the type of compute -- for elementwise operations such as addition or multiplication, they will both scale equally. 
However, for operations such as matrix multiplication, compute scales cubically with the size of the problem while memory scales quadratically.\n", + "\n", + "In a **compute-bound** regime, a pipeline running $N$ iterations would take $(\\alpha + X/\\beta) + N (Y/F)$ seconds, where the first term represents the cost of the initial bubble (multiply by a factor of 2 if there is also a bubble at the end), and the second term represents the total time of the steady-state of the pipeline. Assuming that N is large and there is enough work to produce a long pipeline, the dominating term in the runtime is $F$, the processing speed of the accelerator.\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "NDY4mcae_nMO" + }, + "source": [ + "\n", + "![pipelining_compute](../_static/pallas/pipelining_compute_bound.svg)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "HFWcaAudW4z1" + }, + "source": [ + "In a **memory-bound** regime it is useful to identify if the problem is the latency versus the bandwidth. If the bandwidth is the bottleneck, then the total runtime would take $\\alpha + N(X / \\beta)$ seconds. In contrast with a latency-bound regime, the memory copies happen serially because the bandwidth is already saturated. Being memory-bound is generally not ideal as there will be gaps in time where the processor is idle, and in most hardware configurations the memory bandwidth $\\beta$ is orders of magnitude slower than the processing speed $F$." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "gqcCDsGg_sca" + }, + "source": [ + "\n", + "![pipelining_bandwidth](../_static/pallas/pipelining_bandwidth_bound.svg)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "V4YQCZf1W7X5" + }, + "source": [ + "If the bottleneck is specifically the latency and not the bandwidth, it is possible to fix the problem by inserting additional pipeline stages at the cost of additional SRAM required to store more buffers. With sufficient stages, the problem will either become compute or bandwidth bound again depending on which bottleneck we hit first during the steady-stage stage of the pipeline. The downside, however, of a multi-stage pipeline is that the size of the bubble is proportional to the number of stages so it is important to make sure the pipeline is long enough such that the bubble does not take up a substantial amount of the total runtime.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Sj5PFl0s_yc6" + }, + "source": [ + "\n", + "![pipelining_latency](../_static/pallas/pipelining_latency_multistage.svg)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ar4NVxxFfKEb" + }, + "source": [ + "Pallas on TPU only supports double-buffering, as TPU programs can operate on larger block sizes and double-buffering is typically enough to cover the latency. On GPU, the number of pipeline stages can be specified in both the Triton (via `TritonCompilerParams`) and Mosaic GPU backends (via argument to the pipeline emitter). See the platform-specific pipelining documentation for more details." 
+ ]
+ }
+ ],
+ "metadata": {
+ "colab": {
+ "last_runtime": {
+ "build_target": "//experimental/users/justinfu/pallas:colab",
+ "kind": "private"
+ },
+ "provenance": []
+ },
+ "jupytext": {
+ "formats": "ipynb,md",
+ "main_language": "python"
+ },
+ "kernelspec": {
+ "display_name": "Python 3",
+ "name": "python3"
+ },
+ "language_info": {
+ "name": "python"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}
diff --git a/docs/pallas/pipelining.md b/docs/pallas/pipelining.md
new file mode 100644
index 000000000000..2bf21f0d8c27
--- /dev/null
+++ b/docs/pallas/pipelining.md
@@ -0,0 +1,600 @@
+---
+jupyter:
+  jupytext:
+    formats: ipynb,md
+    main_language: python
+    text_representation:
+      extension: .md
+      format_name: markdown
+      format_version: '1.3'
+    jupytext_version: 1.16.4
+  kernelspec:
+    display_name: Python 3
+    name: python3
+---
+
+
+(pallas_software_pipelining)=
+
+# Software Pipelining
+
+Software pipelining is an important performance optimization technique that overlaps multiple asynchronous operations, even when there are data dependencies between them. In the context of kernel writing, the most common form of pipelining involves overlapping communication and memory transfers with compute such that the hardware accelerator never stalls while waiting for data to arrive. Therefore, we will solely focus on the problem of communication-compute pipelining in this tutorial. We will begin by covering the problem conceptually, outlining the Pallas API for writing pipelines, and going over some realistic examples using the API.
+
+This tutorial only covers the conceptual foundations of pipelining. For platform-specific references, please see {ref}`pallas_tpu_pipelining`, or {ref}`pallas_mgpu_pipelining`.
+
+
+```python id="YkOjspo5BKPD"
+import jax
+from jax import numpy as jnp
+from jax.experimental import pallas as pl
+import numpy as np
+```
+
+
+## Memory Hierarchies
+
+The first step in understanding pipelining conceptually involves understanding the different forms of memory available and the tradeoffs between them. Most hardware architectures (including CPUs, GPUs, and TPUs) utilize a wide variety of memory spaces that trade off capacity against latency/bandwidth. For the purpose of Pallas, we are typically interested in registers, SRAM, DRAM, and potentially network communication:
+- **Registers** are the memory physically closest to the processor, and typically values must be loaded directly into registers before doing any compute on them.
+- **SRAM** (also known as Shared Memory/L1 and L2 cache on GPUs, or VMEM on TPUs) also lives fairly close to the processor, but has larger capacity than registers.
+SRAM on modern ML accelerators typically ranges from 10-100MB (TPU v5p contains 96MB of VMEM, and H100 GPUs contain ~30MB of L1 cache and 50MB of L2).
+It's reasonable to expect the latency to access SRAM to be on the order of 10x longer than accessing a register.
+- **DRAM** (also known as HBM) has much higher capacity than SRAM, typically in the 10-100GB range for modern ML accelerators. However, its access latency is roughly on the order of 10x longer than that of SRAM.
+- **Network** communication becomes crucial for larger workloads when the size of DRAM on a single device becomes insufficient or when we'd like to take advantage of parallel computations. We do not cover distributed pipelining in this tutorial, but see the [distributed TPU kernels](https://docs.jax.dev/en/latest/pallas/tpu/distributed.html) guide for writing pipelines across multiple devices. 
+
+
+
+![memory_hierarchy](../_static/pallas/pipelining_mem_hierarchy.svg)
+
+
+
+
+
+In order to perform computation on values X and Y that live in HBM, we need to:
+
+1. Copy the values x and y into SRAM.
+2. Load the values from SRAM into registers.
+3. Execute the computation and store the result into registers.
+4. Store the values in the output registers into SRAM.
+5. Copy the output values in SRAM back to HBM.
+
+Let’s implement a Pallas function that does just that!
+
+
+```python executionInfo={"elapsed": 108, "status": "ok", "timestamp": 1744764235906, "user": {"displayName": "Justin Fu", "userId": "17543197034567316452"}, "user_tz": 420} id="IrPhDFnT3Nvw" outputId="8bc03872-fd9f-4610-9d53-d4b46be560f4"
+# Note: This is a TPU example.
+
+def add_matrices_kernel(x_sram_ref, y_sram_ref, z_sram_ref):
+  # Load x and y from SRAM into registers
+  x_regs = x_sram_ref[:, :]
+  y_regs = y_sram_ref[:, :]
+  # Execute a vectorized add
+  z_regs = x_regs + y_regs
+  # Store the output values in registers back into SRAM
+  z_sram_ref[:, :] = z_regs
+
+
+def add_matrices(x: jax.Array, y: jax.Array) -> jax.Array:
+  # pallas_call will first allocate scratch buffers for `x` and `y` in SRAM.
+  # It will then copy `x` and `y` from HBM into SRAM.
+  z = pl.pallas_call(
+      add_matrices_kernel, out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype)
+  )(x, y)
+  # pallas_call will also copy the output from SRAM back into HBM.
+  return z
+
+
+x, y = jnp.ones((512, 512)), jnp.ones((512, 512))
+add_matrices(x, y)
+```
+
+
+We've written two functions: `add_matrices_kernel` and `add_matrices`.
+
+`add_matrices_kernel` operates using `Refs` that live in SRAM. Loading from a SRAM Ref produces a value that lives in registers. Values in registers behave like jax.Arrays in that we can use `jnp` and `jax.lax` operations on them to produce new values that live in registers. When we produce the values we'd like to return, we store them in the output SRAM `Ref`.
+
+The `add_matrices` function acts on `jax.Array`s and returns a `jax.Array`. Inside it, we pass `x` and `y` into pallas_call. `pallas_call` is responsible for copying `x` and `y` into SRAM and for allocating the SRAM buffers that the kernel operates on (including allocating `z_sram_ref`, the output SRAM buffer). After the kernel function is finished running, `pallas_call` will also copy the value in `z_sram_ref` to HBM, resulting in an output `jax.Array`.
+
+Pallas exposes access to lower level memory spaces like SRAM, but writing performant kernels requires more care in utilizing the various memory spaces. For example, we need to consider both:
+
+- **Memory capacity**. SRAM is small! If our arrays are too big, the above kernel would not work because we cannot fit the input into SRAM. For reference, an `f32[2048, 2048]` array is 16MiB, so our above kernel won't scale beyond moderately sized arrays.
+
+- **Memory bandwidth**. Copying to/from HBM and SRAM takes a long time, at least compared to most compute instructions. The `add_matrices` function above will likely spend more time copying between HBM and SRAM than actually performing the addition itself.
+
+With these two constraints in mind, we'll have to rethink our strategy for getting performance out of our accelerators.
+
+
+
+
+## Pipelining Basics
+
+
+How can we take advantage of the strengths of each type of memory in the hierarchy, and be able to operate on large arrays stored in HBM while still utilizing fast SRAM for compute?
Pipelining is a very general programming pattern which will allow us to do exactly this, but it requires transforming your problem into smaller sub-problems that can be overlapped in parallel.
+
+The first step in pipelining is to divide our problem into smaller subproblems that can fit inside of SRAM. For example, an elementwise operation can be trivially transformed by operating on one slice of the source array at a time, which results in the following 3 steps (also known as stages):
+
+1. **copy_in**: Copy a slice `A[i]` from HBM to SRAM `X`.
+2. **compute**: Load `X` into registers, compute a result, and store it in SRAM `Y`.
+3. **copy_out**: Copy the result `Y` back into HBM `A[i]`.
+
+Note that there is a data-dependence between steps 1-3, and we cannot trivially overlap them since we need step (1) to complete before starting step (2), and so on. However, there is no data dependence across multiple invocations of the subproblem - that is, we can execute step (1) for block `A[i+1]` while executing step (2) for block `A[i]` and step (3) for block `A[i-1]`.
+
+
+
+
+
+![pipelining_example](../_static/pallas/pipelining_example.svg)
+
+
+
+The diagram above depicts how an idealized pipelined program can be scheduled across time. The key insight is that in the majority of the kernel, the copy operations are executed in parallel with compute operations, meaning we can ideally "hide" the cost of transferring between HBM/SRAM with computation and keep the processor busy with as much uptime as possible.
+
+The initial startup time and final teardown time are known as "bubbles", where only a subset of the stages are being executed while the pipeline is being "filled" or "drained". The bulk of the time is spent in the "steady-state" phase of the pipeline, where each pipeline stage is being executed in parallel across different iterations of the subproblem. While with more general pipelining approaches the goal is to achieve N-way parallelism (where N is the number of stages), with kernel pipelining we are usually bottlenecked either by memory bandwidth or processing speed. Therefore, our goal with kernel pipelining is typically to achieve full utilization of the FLOPs/s of our processor, meaning that at any point in time there is always a `compute` block active. In the figure above, the compute block is active in 6/8 timeslots, and assuming we are fully utilizing the processor in each compute timeslot, we would have achieved 75% utilization of the processor.
+
+
+
+### Deriving a Double-Buffered Pipeline
+
+Now let's look at how we could implement a pipeline in pseudocode. Consider the following elementwise program, where we load values from HBM (`A[i]`) with a `copy_in` instruction, add 1 to the result, and store the result back to HBM with `copy_out`:
+
+for i in range(N):
+  copy_in(A[i], X)
+  Y = X + 1
+  copy_out(Y, A[i])
+
+The issue with this approach is that `copy_in` and `copy_out` are typically blocking operations. So we are forced to wait for the copies to finish while the GPU/TPU is idle, then perform compute while the memory is idle. What we would like to do is to "pre-fetch" the input value that is required on the next iteration of the loop asynchronously while performing the computation for the current loop, so that compute and memory communication are happening simultaneously.
+
+In order to reason about the code transformation we will make, let's unroll the loop for N=4, and decompose the copy instructions into separate `copy_start` and `copy_wait` operations to be able to express asynchrony:
+
+  # Itr 1
+  copy_in_start(A[0], X)
+  copy_in_wait(X)
+  Y = X + 1
+  copy_out_start(Y, A[0])
+  copy_out_wait(Y)
+
+  # Itr 2
+  copy_in_start(A[1], X)
+  copy_in_wait(X)
+  Y = X + 1
+  copy_out_start(Y, A[1])
+  copy_out_wait(Y)
+
+  # Itr 3
+  copy_in_start(A[2], X)
+  copy_in_wait(X)
+  Y = X + 1
+  copy_out_start(Y, A[2])
+  copy_out_wait(Y)
+
+  # Itr 4
+  copy_in_start(A[3], X)
+  copy_in_wait(X)
+  Y = X + 1
+  copy_out_start(Y, A[3])
+  copy_out_wait(Y)
+
+ +Once the loop has been unrolled, the pipelining transformation simply involves issuing `copy_start` instructions as early as possible, and `copy_wait` values as late as possible (right before we need the value). However, in the current state of the loop there is a fake data dependency through X - we cannot simultaneously perform an async copy into X while using it for computation or else we may have a race condition. Therefore, we can use a **multiple-buffering** technique where we keep 2 buffers for each input X and each output Y. With 2 buffers, we can push the `copy_in_start` one iteration ahead (with 3 buffers you can push 2 iterations, and so on) and we rewrite our loop as follows: +
+  # Prologue
+  copy_in_start(A[0], X[0])
+  
+  # Itr 1
+  copy_in_start(A[1], X[1])
+  copy_in_wait(X[0])
+  Y[0] = X[0] + 1
+  copy_out_start(Y[0], A[0])
+  copy_out_wait(Y[0])
+
+  # Itr 2 - Steady state
+  copy_in_start(A[2], X[0])
+  copy_in_wait(X[1])
+  Y[1] = X[1] + 1
+  copy_out_start(Y[1], A[1])
+  copy_out_wait(Y[1])
+
+  # Itr 3 - Steady state
+  copy_in_start(A[3], X[1])
+  copy_in_wait(X[0])
+  Y[0] = X[0] + 1
+  copy_out_start(Y[0], A[2])
+  copy_out_wait(Y[0])
+
+  # Itr 4 - No copy-in
+  copy_in_wait(X[1])
+  Y[1] = X[1] + 1
+  copy_out_start(Y[1], A[3])
+  copy_out_wait(Y[1])
+
+ +Next, we can push the `copy_out_wait` as late as possible, right before we need to write into Y on the subsequent loop iteration. + +
+  # Prologue
+  copy_in_start(A[0], X[0])
+  
+  # Itr 1
+  copy_in_start(A[1], X[1])
+  copy_in_wait(X[0])
+  Y[0] = X[0] + 1
+  copy_out_start(Y[0], A[0])
+
+  # Itr 2 - Steady state
+  copy_in_start(A[2], X[0])
+  copy_in_wait(X[1])
+  Y[1] = X[1] + 1
+  copy_out_start(Y[1], A[1])
+  copy_out_wait(Y[0])
+
+  # Itr 3 - Steady state
+  copy_in_start(A[3], X[1])
+  copy_in_wait(X[0])
+  Y[0] = X[0] + 1
+  copy_out_start(Y[0], A[2])
+  copy_out_wait(Y[1])
+
+  # Itr 4 - No copy-in
+  copy_in_wait(X[1])
+  Y[1] = X[1] + 1
+  copy_out_start(Y[1], A[3])
+  copy_out_wait(Y[0])
+
+  # Epilogue
+  copy_out_wait(Y[1])
+
+
+Finally, re-rolling our loop back into a for loop, we obtain the following pipelined loop:
+
+```
+# Prologue
+copy_in_start(A[0], X[0])
+
+# Main loop
+for i in range(N):
+  cur_slot = i % 2
+  next_slot = (i + 1) % 2
+
+  if i + 1 < N:
+    copy_in_start(A[i+1], X[next_slot])
+
+  copy_in_wait(X[cur_slot])
+  Y[cur_slot] = X[cur_slot] + 1
+  copy_out_start(Y[cur_slot], A[i])
+
+  if i > 0:
+    copy_out_wait(Y[next_slot])
+
+# Epilogue
+copy_out_wait(Y[(N-1) % 2])
+```
+
+If we want to generalize this loop to handle a broader set of computations, notice that we essentially need to specify 3 pieces of information to the pipeline:
+
+- The **grid**, or the bounds of the for loop that specifies the number of subproblems to compute. In our example we had a 1-dimensional grid with size `(N,)`.
+- The **kernel**, or the actual computation happening once the inputs have been loaded into SRAM. In our example we performed an elementwise addition `Y = X + 1`.
+- The **data_slices**, which map a subproblem to corresponding slices into the HBM buffer. In our example the data slice was the identity function `lambda i: i`.
+
+By allowing the user to specify these pieces of information we can write a wide variety of programs following this pattern:
+```python
+def double_buffered_pipeline(
+    grid: tuple[int, ...],
+    kernel: Callable,
+    in_slices: Callable,
+    out_slices: Callable):
+  # Prologue
+  copy_in_start(in_hbm[in_slices(0)], in_sram[0])
+
+  # Main loop
+  grid_size = prod(grid)
+  for i in range(grid_size):
+    cur_slot = i % 2
+    next_slot = (i + 1) % 2
+    if (i + 1) < grid_size:
+      copy_in_start(in_hbm[in_slices(i+1)], in_sram[next_slot])
+    copy_in_wait(in_sram[cur_slot])
+
+    kernel(in_sram[cur_slot], out_sram[cur_slot])
+
+    copy_out_start(out_sram[cur_slot], out_hbm[out_slices(i)])
+    if i > 0:
+      copy_out_wait(out_sram[next_slot])
+
+  # Epilogue
+  last_slot = (grid_size - 1) % 2
+  copy_out_wait(out_sram[last_slot])
+```
+
+
+
+Now that we've seen how to manually implement a pipelined loop, let's look into how to use the Pallas API.
+
+
+
+## Pallas Pipelining API
+
+Pallas offers a pipelining API that abstracts away the boilerplate of maintaining multiple buffers and overlapping asynchronous communication with computation. The basics of this API are covered in {ref}`pallas_quickstart`, so we will go over the API briefly here for completeness and discuss some sharp edges that arise from the use of pipelining.
+
+
+### Grid
+
+The program **grid** is a tuple of integers specifying the number of subproblems along each dimension. The structure of the pipeline can be interpreted as a nested for-loop where the bounds of each loop are given by the corresponding entry of the grid.
+
+```
+# For grid (N, M, K)
+for n in range(N):
+  for m in range(M):
+    for k in range(K):
+      kernel()
+```
+
+The kernel will be invoked a total of `prod(grid)` times. For more details, see [grid and blockspecs](https://docs.jax.dev/en/latest/pallas/grid_blockspec.html#grid-a-k-a-kernels-in-a-loop).
+
+### BlockSpecs
+
+A BlockSpec specifies the size and slice of data copied to the kernel on each subproblem. The basic constructor to `pl.BlockSpec` involves specifying the `block_shape`, the size of a slice of data, and `index_map`, a function that takes in the program ids of the current subproblem and outputs _blocked_ indices into the source buffer. Blocked indices specify which block to copy on each iteration, assuming the source buffer has been carved into blocks of shape `block_shape`. The `memory_space` argument specifies what memory space to copy the inputs to - by default this will be SRAM. 
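+
+For example (an illustrative sketch, not text from the API reference): with `block_shape=(256, 256)` and `index_map=lambda i, j: (i, j)`, the subproblem with program ids `(i, j)` is given the elements `x[i*256:(i+1)*256, j*256:(j+1)*256]` of the source buffer. A minimal, hypothetical helper that performs this blocked-to-element translation:
+
+```python
+# Hypothetical helper, for illustration only -- not part of the Pallas API.
+# Translates the blocked index returned by an index_map into element slices,
+# assuming block_shape evenly divides the source array.
+def block_to_element_slices(blocked_idx, block_shape):
+  return tuple(
+      slice(i * b, (i + 1) * b) for i, b in zip(blocked_idx, block_shape)
+  )
+
+# Subproblem (1, 2) with block_shape (256, 256) reads rows 256:512 and
+# columns 512:768 of the source buffer.
+print(block_to_element_slices((1, 2), (256, 256)))
+```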
+ +```python +pl.BlockSpec( + block_shape: tuple[int, ...], + index_map: Callable, + memory_space: pl.MemorySpace +) +``` +There should be one BlockSpec for each input and each output to the kernel. For more details, see [grid and blockspecs](https://docs.jax.dev/en/latest/pallas/grid_blockspec.html#grid-a-k-a-kernels-in-a-loop). + +### Kernel + +The kernel function specifies what compute to perform on each subproblem. The kernel function should return no outputs, and instead all outputs should be written into the output buffers that are passed into the kernel. All inputs and output buffers are SRAM buffers by default (unless the user has overridden the behavior by specifying a `memory_space` on the corresponding `BlockSpec`). + +```python +def kernel(*input_buffers, *output_buffers): + # ... perform compute + # ... store result into output buffers +``` + +The index of the current subproblem can be queried inside the kernel using `pl.program_id(grid_axis: int)`. + + +### Pallas Call + +The `pl.pallas_call` function is the main entry point to Pallas and performs pipelined execution when a grid and BlockSpecs are supplied. It has the following signature: +```python +def pallas_call( + kernel, + grid: tuple[int, ...], + in_specs: Sequence[PyTree[BlockSpec]], + out_specs: PyTree[BlockSpec], + out_shape: PyTree[jax.ShapeDtypeStruct], +) -> Callable: +``` +`pallas_call` will return a callable function that when invoked with input values, will return outputs of the same shape as `out_shape`. + +`in_specs`, `out_specs`, and `out_shape` are PyTrees of their respective element type. The PyTrees for `in_specs` and the input buffers supplied to the kernel should match, and the PyTrees for `out_specs` and `out_shape` should also match. + + + + +### Example - Elementwise Kernel revisited + +Let's revisit the initial `add_matrices_kernel` from the beginning of the tutorial, except using pipelining. We will add two input arrays of shape `f32[4096, 4096]` that live in HBM. As subproblems, we will carve up the inputs into `block_shape=(512, 512)` blocks and only add two blocks together at a time in the kernel. Because addition is elementwise, each `index_map` is identical and selects out the `i, j`th block on the `i, j`th iteration. + + +```python id="iqr_qjONAHN9" +# Note: This is a TPU example. + +total_shape = (4096, 4096) +block_shape = (512, 512) + +def add_matrices_pipelined_kernel(x_ref, y_ref, o_ref): + o_ref[...] = x_ref[...] + y_ref[...] + +def add_matrices_pipelined(x: jax.Array, y: jax.Array): + return pl.pallas_call( + add_matrices_pipelined_kernel, + grid=tuple(total // block for (total, block) in zip(total_shape, block_shape)), + in_specs=[ + pl.BlockSpec(block_shape, index_map=lambda i, j: (i, j)), + pl.BlockSpec(block_shape, index_map=lambda i, j: (i, j)) + ], + out_specs=pl.BlockSpec(block_shape, index_map=lambda i, j: (i, j)), + out_shape=jax.ShapeDtypeStruct(total_shape, dtype=jnp.float32), + )(x, y) + +x = jax.random.uniform(jax.random.key(0), total_shape, dtype=jnp.float32) +y = jax.random.uniform(jax.random.key(1), total_shape, dtype=jnp.float32) +result = add_matrices_pipelined(x, y) +np.testing.assert_array_equal( + result, x + y +) +``` + + +It turns out that with this API, writing a pipelined kernel is not much more lines of code than writing our original naive addition kernel! + + + +### Parameterizing a Kernel + +It's common to parameterize the block shapes in our kernel. 
Block sizes are perhaps the most important parameter to tune when optimizing the performance of Pallas kernels! They give us control over the pipeline (for example, picking smaller blocks adds more iterations to our pipelined loop where each iteration has less work to do). Let's write a function that does so:
+
+
+```python id="RZTAiwrZ6srD"
+def add_matrices_pipelined_param(
+    x: jax.Array, y: jax.Array, *, bm: int = 256, bn: int = 256
+) -> jax.Array:
+  m, n = x.shape
+  block_spec = pl.BlockSpec((bm, bn), lambda i, j: (i, j))
+  return pl.pallas_call(
+      add_matrices_kernel,
+      out_shape=x,
+      in_specs=[block_spec, block_spec],
+      out_specs=block_spec,
+      grid=(m // bm, n // bn),
+  )(x, y)
+
+np.testing.assert_array_equal(
+    add_matrices_pipelined_param(x, y, bm=256, bn=256), x + y
+)
+np.testing.assert_array_equal(
+    add_matrices_pipelined_param(x, y, bm=128, bn=128), x + y
+)
+np.testing.assert_array_equal(
+    add_matrices_pipelined_param(x, y, bm=512, bn=512), x + y
+)
+```
+
+
+## Sharp edges
+
+While pipelining provides a close approximation to the mental model of simply calling a kernel function in a loop, there are a number of sharp edges that arise from the use of intermediate buffers that are not fully hidden from the user and can result in subtle bugs.
+
+### Buffer Revisiting
+
+In general, a good rule-of-thumb to follow is that **the input buffers passed into the kernel function should be interpreted as read-only, and output buffers as write-only**.
+
+Writing to inputs and reading from outputs will in most cases result in incorrectness. This is because the SRAM buffers passed to a kernel only contain copies of the data contained in the underlying HBM buffer. If an input SRAM buffer is updated, the updated results will never be written back out to HBM, and if an output buffer is updated, its updated value is never read into SRAM. This issue is analogous to staleness issues encountered when using caches in general.
+
+There are two cases where a buffer supports both reads and writes - accumulation (discussed next), and marking a pair of input and output buffers as input-output aliased by passing in the `input_output_aliases` argument to `pallas_call`.
+
+
+### Reductions and accumulation
+
+**Reduction/accumulation should only be performed over the last (innermost) dimensions of the grid, and the buffer should be initialized manually first.**
+
+Reductions are one of the few cases where the pipeline supports both reading and writing to an output buffer, but the reason it works is subtle.
+The Pallas pipeline emitter performs an optimization where if the data slices between two consecutive iterations are the same, the pipeline will not issue a `copy_in`/`copy_out` on that buffer. This means the same SRAM buffer used in a previous iteration will be passed into the kernel again on the following iteration, and thus any writes that were issued to the output buffer will become visible on the next iteration. Once the data slice changes, the final accumulated SRAM buffer will be written out to HBM. This is also why reductions must be performed over the last dimensions of the grid -- we want to finish all of the accumulation while the output buffer is in SRAM in the innermost loop, then write it to HBM and never touch that output block again.
+
+As a concrete example, let's consider performing the following computation for reducing an `(8, 1024, 1024)` array along the first axis into a `(1024, 1024)` array. 
+
+
+
+
+
+
+
+```python executionInfo={"elapsed": 244, "status": "ok", "timestamp": 1744763773938, "user": {"displayName": "Justin Fu", "userId": "17543197034567316452"}, "user_tz": 420} id="4qz1ET-_f9fJ" outputId="e43067ef-933a-45a5-912a-e224151cfa60"
+x = jnp.ones((8, 1024, 1024))
+jnp.sum(x, axis=0)
+```
+
+
+To do this using `pallas_call`, we could use a grid of size `(8,)` and in each iteration `i` load `x[i]` into SRAM. Then we could add `x[i]` to an output SRAM buffer. Let's implement this naively first.
+
+
+```python executionInfo={"elapsed": 79, "status": "ok", "timestamp": 1744763774254, "user": {"displayName": "Justin Fu", "userId": "17543197034567316452"}, "user_tz": 420} id="ZEi1_vQVf-81" outputId="581744b7-ddc1-4dc1-98ec-03c852772eda"
+# Note: This is a TPU example.
+
+# Warning: this implementation is incorrect!
+def incorrect_sum_kernel(x_ref, o_ref):
+  o_ref[...] += x_ref[...]
+
+def incorrect_sum(x: jax.Array,
+                  block_size: tuple[int, ...] = (256, 256)) -> jax.Array:
+  reduction_size, *out_shape = x.shape
+  grid = (reduction_size, *(out // blk for out, blk in zip(out_shape, block_size)))
+  return pl.pallas_call(
+      incorrect_sum_kernel,
+      grid=grid,
+      # None in `block_shape` means we pick a size of 1 and squeeze it away
+      in_specs=[pl.BlockSpec((None, *block_size), lambda i, j, k: (i, j, k))],
+      out_specs=pl.BlockSpec(block_size, lambda i, j, k: (j, k)),
+      out_shape=jax.ShapeDtypeStruct(out_shape, x.dtype),
+  )(x)
+
+result = incorrect_sum(x)
+print(result)
+```
+
+
+This result is completely wrong!
+
+There are two errors inside this kernel. First, we are accumulating along the first grid dimension instead of the last grid dimension. Second, `o_ref` initially contains garbage values and thus we need to initialize it to zeros before we begin accumulation.
+
+After fixing these two issues, we obtain the following corrected kernel. In this new kernel, we use `@pl.when` to create a conditional that checks when the program ID is `0` along the reduction axis, indicating we are beginning to accumulate into a new output block. We have also moved the reduction dimension to the last axis of the `grid`.
+
+
+```python executionInfo={"elapsed": 104, "status": "ok", "timestamp": 1744763774523, "user": {"displayName": "Justin Fu", "userId": "17543197034567316452"}, "user_tz": 420} id="XtgD4nMa9_Bd" outputId="9ef07cdf-9e22-4dc8-c17f-c96172639801"
+# Note: This is a TPU example.
+
+def correct_sum_kernel(x_ref, o_ref):
+  @pl.when(pl.program_id(2) == 0)
+  def _():
+    o_ref[...] = jnp.zeros_like(o_ref)
+  o_ref[...] += x_ref[...]
+
+def correct_sum(x: jax.Array,
+                block_size: tuple[int, ...] = (256, 256)) -> jax.Array:
+  reduction_size, *out_shape = x.shape
+  # We moved the reduction to the last axis of the grid.
+  grid = (*(out // blk for out, blk in zip(out_shape, block_size)), reduction_size)
+  return pl.pallas_call(
+      correct_sum_kernel,
+      grid=grid,
+      # None in `block_shape` means we pick a size of 1 and squeeze it away
+      in_specs=[pl.BlockSpec((None, *block_size), lambda i, j, k: (k, i, j))],
+      out_specs=pl.BlockSpec(block_size, lambda i, j, k: (i, j)),
+      out_shape=jax.ShapeDtypeStruct(out_shape, x.dtype),
+  )(x)
+
+result = correct_sum(x)
+print(result)
+```
+
+
+
+## Analyzing the performance
+
+What is the performance of a pipelined kernel? The answer can vary depending on where the bottleneck in the hardware is. We are typically interested in 3 quantities:
+- **Memory latency** $α$, the minimum latency of a memory transfer. 
+- **Memory bandwidth** $β$, the rate in bytes/second that we can transfer from HBM to SRAM.
+- **FLOP/s** $F$, or floating-point operations per second, the number of calculations per second that the processor can perform.
+
+We refer to a program as **compute-bound** if the processing speed (FLOP/s) is the bottleneck, and as **memory-bound** if the bandwidth or latency is the bottleneck. Generally, our goal is to optimize a kernel such that it is compute-bound, meaning we are utilizing all of the available processing power of our hardware.
+
+Suppose we are running a program that requires $X$ bytes of memory transfers per kernel iteration, and runs $Y$ floating-point operations per iteration. The ratio of $X$ to $Y$ varies depending on the type of compute -- for elementwise operations such as addition or multiplication, they will both scale equally. However, for operations such as matrix multiplication, compute scales cubically with the size of the problem while memory scales quadratically.
+
+In a **compute-bound** regime, a pipeline running $N$ iterations would take $(\alpha + X/\beta) + N (Y/F)$ seconds, where the first term represents the cost of the initial bubble (multiply by a factor of 2 if there is also a bubble at the end), and the second term represents the total time of the steady-state of the pipeline. Assuming that $N$ is large and there is enough work to produce a long pipeline, the runtime is dominated by the second term, $N(Y/F)$, which is set by the processing speed $F$ of the accelerator.
+
+
+
+
+
+![pipelining_compute](../_static/pallas/pipelining_compute_bound.svg)
+
+
+
+
+In a **memory-bound** regime it is useful to identify whether the bottleneck is the latency or the bandwidth. If the bandwidth is the bottleneck, then the total runtime would take $\alpha + N(X / \beta)$ seconds. In contrast with a latency-bound regime, the memory copies happen serially because the bandwidth is already saturated. Being memory-bound is generally not ideal as there will be gaps in time where the processor is idle, and in most hardware configurations the memory bandwidth $\beta$ is orders of magnitude slower than the processing speed $F$.
+
+
+
+
+![pipelining_bandwidth](../_static/pallas/pipelining_bandwidth_bound.svg)
+
+
+
+
+If the bottleneck is specifically the latency and not the bandwidth, it is possible to fix the problem by inserting additional pipeline stages at the cost of additional SRAM required to store more buffers. With sufficient stages, the problem will either become compute- or bandwidth-bound again depending on which bottleneck we hit first during the steady-state stage of the pipeline. The downside, however, of a multi-stage pipeline is that the size of the bubble is proportional to the number of stages, so it is important to make sure the pipeline is long enough such that the bubble does not take up a substantial amount of the total runtime.
+
+
+
+
+
+![pipelining_latency](../_static/pallas/pipelining_latency_multistage.svg)
+
+
+
+
+Pallas on TPU only supports double-buffering, as TPU programs can operate on larger block sizes and double-buffering is typically enough to cover the latency. On GPU, the number of pipeline stages can be specified in both the Triton (via `TritonCompilerParams`) and Mosaic GPU backends (via argument to the pipeline emitter). See the platform-specific pipelining documentation for more details. 
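+
+As a rough worked example of the compute-bound versus memory-bound distinction, consider one `(512, 512)` float32 block iteration of the elementwise addition kernel from earlier. The sketch below uses illustrative placeholder values for $\beta$ and $F$ (they are not the specs of any particular accelerator); substitute the numbers for your own hardware:
+
+```python
+# Back-of-the-envelope estimate for one (512, 512) float32 block iteration of
+# the elementwise addition kernel. The hardware numbers are placeholders for
+# illustration, not the specs of any real chip.
+bytes_per_elem = 4
+block_elems = 512 * 512
+X = 3 * block_elems * bytes_per_elem  # bytes moved per iteration (2 inputs + 1 output)
+Y = block_elems                       # FLOPs per iteration (one add per element)
+
+beta = 1e12  # assumed HBM<->SRAM bandwidth, bytes/s
+F = 1e14     # assumed peak compute, FLOP/s
+
+memory_time = X / beta  # seconds of memory traffic per iteration
+compute_time = Y / F    # seconds of compute per iteration
+
+# If memory_time exceeds compute_time, the steady state is memory-bound: the
+# copies cannot be fully hidden behind compute no matter how well we pipeline.
+print(f"memory: {memory_time:.2e}s, compute: {compute_time:.2e}s")
+print("memory-bound" if memory_time > compute_time else "compute-bound")
+```
+
+With these assumed numbers the addition kernel is firmly memory-bound, which is consistent with the earlier observation that `add_matrices` spends more time copying between HBM and SRAM than performing the addition; operations like matrix multiplication, whose compute grows faster than their memory traffic, are the ones that can reach the compute-bound regime with large enough blocks.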
+ diff --git a/docs/pallas/quickstart.ipynb b/docs/pallas/quickstart.ipynb index 11dd2108e405..ffdf715e984a 100644 --- a/docs/pallas/quickstart.ipynb +++ b/docs/pallas/quickstart.ipynb @@ -5,6 +5,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ + "(pallas_quickstart)=\n", "# Pallas Quickstart\n", "\n", "\n", @@ -279,7 +280,7 @@ "metadata": {}, "source": [ "TPUs distinguish between vector and scalar memory spaces and in this case the\n", - "output must be placed in scalar memory (`TPUMemorySpace.SMEM`) since `i` is\n", + "output must be placed in scalar memory (`MemorySpace.SMEM`) since `i` is\n", "a scalar. For more details read {ref}`tpu_and_its_memory_spaces`.\n", "To call the above kernel on TPU, run:" ] @@ -296,7 +297,7 @@ "\n", "def iota(size: int):\n", " return pl.pallas_call(iota_kernel,\n", - " out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM),\n", + " out_specs=pl.BlockSpec(memory_space=pltpu.MemorySpace.SMEM),\n", " out_shape=jax.ShapeDtypeStruct((size,), jnp.int32),\n", " grid=(size,))()\n", "iota(8)" diff --git a/docs/pallas/quickstart.md b/docs/pallas/quickstart.md index fff1dcb730f3..5f1832f2a2f0 100644 --- a/docs/pallas/quickstart.md +++ b/docs/pallas/quickstart.md @@ -12,6 +12,7 @@ kernelspec: name: python3 --- +(pallas_quickstart)= # Pallas Quickstart @@ -185,7 +186,7 @@ iota(8) ``` TPUs distinguish between vector and scalar memory spaces and in this case the -output must be placed in scalar memory (`TPUMemorySpace.SMEM`) since `i` is +output must be placed in scalar memory (`MemorySpace.SMEM`) since `i` is a scalar. For more details read {ref}`tpu_and_its_memory_spaces`. To call the above kernel on TPU, run: @@ -195,7 +196,7 @@ from jax.experimental.pallas import tpu as pltpu def iota(size: int): return pl.pallas_call(iota_kernel, - out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM), + out_specs=pl.BlockSpec(memory_space=pltpu.MemorySpace.SMEM), out_shape=jax.ShapeDtypeStruct((size,), jnp.int32), grid=(size,))() iota(8) diff --git a/docs/pallas/tpu/details.rst b/docs/pallas/tpu/details.rst index 0575806e6037..91aefd52d2e8 100644 --- a/docs/pallas/tpu/details.rst +++ b/docs/pallas/tpu/details.rst @@ -99,8 +99,8 @@ for exceptions). This unlocks some interesting capabilities: output, without any risk of race conditions. However, we do require that all invocations that write to a particular slice are consecutive. -The "consecutive" restriction on the output usually means that the some prefix -of the grid dimensions always vary the slice of the output an invocation needs +The "consecutive" restriction on the output usually means that some prefix +of the grid dimensions always varies the slice of the output an invocation needs to access, while the output window remains constant for the remaining suffix. For example, when implementing a Pallas TPU kernel for matrix multiplication, @@ -128,7 +128,7 @@ has no impact on performance, as the compiler is free to rearrange them. However, as Pallas is meant to expose lower-level capabilities, the dimension order can have great impact on the quality of generated code. -TPUs perform bulk of the computation on 2D vector registers, which are typically of +TPUs perform the bulk of the computation on 2D vector registers, which are typically of size 8x128 for 32-bit values (as of TPU v6). When a vector value is loaded from VMEM into registers (e.g. ``x = x_ref[...]``), the last two dimensions of the array will be tiled into the registers. 
@@ -167,10 +167,11 @@ sequential grid execution guarantees, and will need to parallelize one of the grid axes over cores. This is an opt-in procedure. To allow that, ``pallas_call`` requires an extra parameter named ``dimension_semantics``: -.. +.. code:: python + pallas_call( ..., - compiler_params=pltpu.TPUCompilerParams( + compiler_params=pltpu.CompilerParams( dimension_semantics=["parallel", "parallel", "arbitrary"] ), ) diff --git a/docs/pallas/tpu/distributed.ipynb b/docs/pallas/tpu/distributed.ipynb index b52ec579f508..31db839f8d0b 100644 --- a/docs/pallas/tpu/distributed.ipynb +++ b/docs/pallas/tpu/distributed.ipynb @@ -8,21 +8,21 @@ "source": [ "# Distributed Computing in Pallas for TPUs\n", "\n", - "In this tutorial, we will cover the basics of distributed computing in Pallas on TPUs. We will learn about TPU topologies, communication using the remote DMA primitive, and calling a distributed kernel from JAX using `shard_map`. We will also cover some more advanced kernel writing techniques, such as double-buffering, bi-directional bandwidth optimization, and nested pipelining. As educational examples, we will learn how to implement various collective primitives from JAX, such as `lax.ppermute`, `lax.all_gather`, `lax.psum`, and `lax.psum_scatter`.\n", + "In this tutorial, we will cover the basics of distributed computing in Pallas on TPUs. We will learn about TPU topologies, communication using the remote DMA primitive, and calling a distributed kernel from JAX using `jax.shard_map`. We will also cover some more advanced kernel writing techniques, such as double-buffering, bi-directional bandwidth optimization, and nested pipelining. As educational examples, we will learn how to implement various collective primitives from JAX, such as `lax.ppermute`, `lax.all_gather`, `lax.psum`, and `lax.psum_scatter`.\n", "\n", "Some recommended readings beforehand:\n", " - [Pallas Pipelining on TPU](pallas_tpu_pipelining)\n", - " - [Collectives with `shard_map`](shard_map_collectives_tutorial)" + " - [Collectives with `jax.shard_map`](shard_map_collectives_tutorial)" ] }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 6, "metadata": { "executionInfo": { - "elapsed": 1978, + "elapsed": 52, "status": "ok", - "timestamp": 1722904801801, + "timestamp": 1744390458993, "user": { "displayName": "Justin Fu", "userId": "17543197034567316452" @@ -30,23 +30,23 @@ "user_tz": 420 }, "id": "PyAGnWc9yI8T", - "outputId": "1d8229bd-cab5-495f-93e9-fff2e41db480" + "outputId": "c5912653-c34b-4810-c373-4a2787691317" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Running with 4 TPU v5 lite devices.\n" + "Running with 4 TPU v4 devices.\n" ] } ], "source": [ + "import functools\n", "import jax\n", "from jax import lax\n", "from jax import numpy as jnp\n", "from jax.experimental import pallas as pl\n", - "from jax.experimental import shard_map\n", "from jax.experimental.pallas import tpu as pltpu\n", "\n", "P = jax.sharding.PartitionSpec\n", @@ -71,7 +71,7 @@ "\n", "![tpu_topologies](https://cloud.google.com/static/tpu/docs/images/v4-topologies.png)\n", "\n", - "Flattened as a graph, the torus can be visualized as follows. Each edge (orange or black) is a bidirectional connection between two devices. You will commonly hear about rings in conjunction with discussion about device toplogies — a key feature of a torus is that when taking a slice along an axis of the pod, such as the nodes `[(0,1), (1, 1), (2, 1), (3, 1)]` or `[(0, 1), (1, 1)]`, we have a ring of devices. 
This is a feature we can use to simplify communication patterns within the pod.\n", + "Flattened as a graph, the torus can be visualized as follows. Each edge (orange or black) is a bidirectional connection between two devices. You will commonly hear about rings in conjunction with discussion about device topologies — a key feature of a torus is that when taking a slice along an axis of the pod, such as the nodes `[(0,1), (1, 1), (2, 1), (3, 1)]` or `[(0, 1), (1, 1)]`, we have a ring of devices. This is a feature we can use to simplify communication patterns within the pod.\n", "\n", "![tpu_torus](https://cloud.google.com/static/tpu/docs/images/untwisted-tori.png)" ] @@ -178,7 +178,7 @@ "\n", "`send_sem` and `recv_sem` are instances of a special type of semaphore reserved exclusively for use with DMAs. They must be allocated with the `tpu.SemaphoreType.DMA` type when specifying input specs to `pallas_call`.\n", "\n", - "Internally, DMA semaphores can be thought of as integer-valued progress trackers. On DMA start, the local device will begin to increment the value of `send_sem` and the receiver's `recv_sem` asynchronously. Waiting on a semaphore will block until the value of the semaphore reaches the total bytes of data sent/received; when the value is reached, waiting threads are released and the sempahore's value is decremented by the same amount. This means that either all data has been sent (for `send_sem`) or all data has been received (for `dst_sem`). The value of the semaphore can be read with `pl.semaphore_read`, but note that the underlying semantics of the value could change between hardware generations (e.g. the value may not represent exactly the number of bytes sent, although this is a useful mental model to have when reasoning about the behavior of the semaphore).\n", + "Internally, DMA semaphores can be thought of as integer-valued progress trackers. On DMA start, the local device will begin to increment the value of `send_sem` and the receiver's `recv_sem` asynchronously. Waiting on a semaphore will block until the value of the semaphore reaches the total bytes of data sent/received; when the value is reached, waiting threads are released and the semaphore's value is decremented by the same amount. This means that either all data has been sent (for `send_sem`) or all data has been received (for `recv_sem`). The value of the semaphore can be read with `pl.semaphore_read`, but note that the underlying semantics of the value could change between hardware generations (e.g. 
the value may not represent exactly the number of bytes sent, although this is a useful mental model to have when reasoning about the behavior of the semaphore).\n", "\n", "### Routing\n", "\n", @@ -215,12 +215,12 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 7, "metadata": { "executionInfo": { - "elapsed": 1606, + "elapsed": 152, "status": "ok", - "timestamp": 1722904803566, + "timestamp": 1744390459367, "user": { "displayName": "Justin Fu", "userId": "17543197034567316452" @@ -228,7 +228,7 @@ "user_tz": 420 }, "id": "YkyIKN2thZ-V", - "outputId": "9b7ed142-d161-4237-fed8-cbce41adc5f0" + "outputId": "26719bb9-87ff-46dd-af90-a114ce332417" }, "outputs": [ { @@ -271,11 +271,11 @@ "out_shape = jax.ShapeDtypeStruct((8, 128), jnp.float32)\n", "grid_spec = pltpu.PrefetchScalarGridSpec(\n", " num_scalar_prefetch=0,\n", - " # TPUMemorySpace.ANY will (usually) place the tensor in HBM.\n", + " # MemorySpace.ANY will (usually) place the tensor in HBM.\n", " in_specs=[\n", - " pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),\n", + " pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY),\n", " ],\n", - " out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),\n", + " out_specs=pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY),\n", " scratch_shapes=(\n", " # We allocate DMA semaphores in scratch memory.\n", " [pltpu.SemaphoreType.DMA] * 2\n", @@ -288,12 +288,12 @@ ")\n", "# Wrap the kernel within a shard_map to call.\n", "pallas_result = jax.jit(\n", - " shard_map.shard_map(\n", + " jax.shard_map(\n", " right_permute,\n", " mesh=mesh,\n", " in_specs=partition,\n", " out_specs=partition,\n", - " check_rep=False,\n", + " check_vma=False,\n", " )\n", ")(input_arr)\n", "\n", @@ -301,7 +301,7 @@ "perm = tuple((src, (src + 1) % num_devices) for src in range(num_devices))\n", "\n", "xla_result = jax.jit(\n", - " shard_map.shard_map(\n", + " jax.shard_map(\n", " lambda x: lax.ppermute(x, 'x', perm),\n", " mesh=mesh, in_specs=partition, out_specs=partition)\n", ")(input_arr)\n", @@ -338,12 +338,12 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 8, "metadata": { "executionInfo": { - "elapsed": 812, + "elapsed": 209, "status": "ok", - "timestamp": 1722904804531, + "timestamp": 1744390459789, "user": { "displayName": "Justin Fu", "userId": "17543197034567316452" @@ -351,7 +351,7 @@ "user_tz": 420 }, "id": "ojQEZB5mBRqM", - "outputId": "e1648f54-737c-4921-ca3b-b4c639a38d2b" + "outputId": "3a4373f8-1fb5-4a6b-b88e-3461c2609021" }, "outputs": [ { @@ -420,10 +420,10 @@ "grid_spec = pltpu.PrefetchScalarGridSpec(\n", " num_scalar_prefetch=0,\n", " in_specs=[\n", - " # TPUMemorySpace.ANY will (usually) place the tensor in HBM.\n", - " pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),\n", + " # MemorySpace.ANY will (usually) place the tensor in HBM.\n", + " pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY),\n", " ],\n", - " out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),\n", + " out_specs=pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY),\n", " scratch_shapes=(\n", " # DMA semaphores are allocated in scratch memory.\n", " # We allocated one semaphore for a local HBM-VMEM copy,\n", @@ -447,18 +447,18 @@ "\n", "# Wrap the kernel within a shard_map to call.\n", "pallas_result = jax.jit(\n", - " shard_map.shard_map(\n", + " jax.shard_map(\n", " all_gather,\n", " mesh=mesh,\n", " in_specs=partition,\n", " out_specs=partition,\n", - " check_rep=False\n", + " check_vma=False\n", " )\n", ")(input_arr)\n", "\n", "# Compare Pallas result to XLA shard_map 
result.\n", "xla_result = jax.jit(\n", - " shard_map.shard_map(\n", + " jax.shard_map(\n", " lambda x: lax.all_gather(x, 'x'),\n", " mesh=mesh, in_specs=partition, out_specs=partition\n", " )\n", @@ -477,13 +477,13 @@ "id": "KgU7HI2pS4om" }, "source": [ - "A detail worth mentioning here is the use of multiple receive semaphores. Because we only block on the receiving device, it is still possible for a sender to have sent multiple DMAs in flight before the receiver has finished processing the first one (see the next section and reduce-sum example which discusses race conditions in more detail). In this situation we may hit a situation where the same semaphore is being used for multiple DMAs occurring simultaneously. To avoid this, we allocate `num_devices-1` semaphores so there is no risk of re-use. While this race condition is unlikely to happen on such a small kernel, on larger kernels there is more chance for devices to fall out of sync and potentially cause a silent failure." + "A detail worth mentioning here is the use of multiple receive semaphores. Because we only block on the receiving device, it is still possible for a sender to have sent multiple DMAs in flight before the receiver has finished processing the first one (see the next section and reduce-sum example which discusses race conditions in more detail). In this situation we may hit a situation where the same semaphore is being used for multiple DMAs occurring simultaneously. To avoid this, we allocate `num_devices-1` semaphores so there is no risk of reuse. While this race condition is unlikely to happen on such a small kernel, on larger kernels there is more chance for devices to fall out of sync and potentially cause a silent failure." ] }, { "cell_type": "markdown", "metadata": { - "id": "KgU7HI2pS4om" + "id": "EDCmAaHVtY7x" }, "source": [ "## Advanced Techniques\n", @@ -529,9 +529,9 @@ "\n", "In order to use regular semaphores, they can be allocated in the same way as a DMA semaphore, but by specifying `pltpu.SemaphoreType.REGULAR` rather than `pltpu.SemaphoreType.DMA`.\n", "\n", - "Semaphores must be zero at the end of a Pallas program to complete succesfully. There are two error cases where this may happen:\n", + "Semaphores must be zero at the end of a Pallas program to complete successfully. There are two error cases where this may happen:\n", " - If a semaphore is over-signaled, the program will end with non-zero (>0) semaphores. In this case, the program will crash upon completion. This is useful for debugging as non-zero semaphores typically means there is a bug somewhere inside of the program.\n", - " - If a semaphore is over-waited, the program will hang on the blocking `semaphore_wait` call while it waits for the sempahore to be incremented. In this case the device or program will need to be restarted.\n", + " - If a semaphore is over-waited, the program will hang on the blocking `semaphore_wait` call while it waits for the semaphore to be incremented. In this case the device or program will need to be restarted.\n", "\n", "#### Barrier Semaphores\n", "\n", @@ -569,7 +569,7 @@ "kernel = pl.pallas_call(\n", " example_kernel,\n", " ...,\n", - " compiler_params=pltpu.TPUCompilerParams(collective_id=0),\n", + " compiler_params=pltpu.CompilerParams(collective_id=0),\n", ")\n", "```" ] @@ -644,19 +644,19 @@ "\n", "The main body assumes that a value has already been copied into our local working slot, either from the previous iteration or from the prologue. 
A complicating factor is that our destination buffers live in HBM, but we need to load values to VMEM before we perform arithmetic. Therefore, we simultaneously copy the working slot value into our VMEM (`receive_scratch`) and pass the value on to our right neighbor's receiving slot. Once the value has been copied into our VMEM, we can accumulate it into our result (contained in `o_ref`).\n", "\n", - "A subtle race condition can occur if one device runs one loop ahead of it's right neighbor. In this case, it could copy into the receiver's `working_slot` at the same time the receiver is reading from it. In order to avoid this, each device will block on a `REGULAR` semaphore before copying into the right neighbor's `dst_ref` until it has signaled that it is done reading from its `working_slot`. This race condition is rarely triggered for a small kernel such as this example, but can it can be explicitly triggered if for example using a `pltpu.delay` instruction to artifically hang a device.\n", + "A subtle race condition can occur if one device runs one loop ahead of it's right neighbor. In this case, it could copy into the receiver's `working_slot` at the same time the receiver is reading from it. In order to avoid this, each device will block on a `REGULAR` semaphore before copying into the right neighbor's `dst_ref` until it has signaled that it is done reading from its `working_slot`. This race condition is rarely triggered for a small kernel such as this example, but can it can be explicitly triggered if for example using a `pltpu.delay` instruction to artificially hang a device.\n", "\n", "Note that this is not an optimal or fully general kernel, as the block sizes must entirely fit in VMEM and we could better interleave communication and accumulation. We will discuss these optimizations in later sections." 
] }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 9, "metadata": { "executionInfo": { - "elapsed": 254, + "elapsed": 248, "status": "ok", - "timestamp": 1722904804952, + "timestamp": 1744390460289, "user": { "displayName": "Justin Fu", "userId": "17543197034567316452" @@ -664,7 +664,7 @@ "user_tz": 420 }, "id": "XrY5bMlvBroQ", - "outputId": "77497000-4496-462e-cc3c-73fb640cc14c" + "outputId": "9216e749-48d2-43ff-d64b-bd419acf3e11" }, "outputs": [ { @@ -674,7 +674,7 @@ "Input = [0.9858954 0.11763906 0.9955574 0.775211 ]\n", "Pallas result = [2.8743029 2.8743029 2.8743029 2.8743029]\n", "lax.psum result = [2.8743029 2.8743029 2.8743029 2.8743029]\n", - "Difference |Pallas - lax.psum| = 1.4959369e-08\n" + "Difference |Pallas - lax.psum| = 1.0535587e-08\n" ] } ], @@ -687,6 +687,41 @@ "input_arr = jax.device_put(input_arr, sharding)\n", "\n", "\n", + "def local_barrier(left_neighbor, right_neighbor, double_barrier=True):\n", + " \"\"\"Performs a barrier with neighbors on the global barrier semaphore.\n", + "\n", + " Optionally performs a second barrier, which prevents a potential race\n", + " when reusing the same collective_id across kernel invocations.\n", + " \"\"\"\n", + " barrier_sem = pltpu.get_barrier_semaphore()\n", + " for neighbor in [left_neighbor, right_neighbor]:\n", + " pltpu.semaphore_signal(\n", + " barrier_sem,\n", + " inc=1,\n", + " device_id=(neighbor,),\n", + " device_id_type=pltpu.DeviceIdType.MESH,\n", + " )\n", + " pltpu.semaphore_wait(barrier_sem, 2)\n", + " if double_barrier:\n", + " # The double-barrier prevents a race condition where one neighbor can\n", + " # re-enter the kernel again on a subsequent call and increment the\n", + " # barrier semaphore a second time. This would unblock the current device\n", + " # even if the other neighbor is not ready yet.\n", + " # To implement a double-barrier, we stack-allocate a second REGULAR\n", + " # semaphore using run_scoped.\n", + " @functools.partial(pl.run_scoped,\n", + " second_barrier=pltpu.SemaphoreType.REGULAR)\n", + " def _(second_barrier):\n", + " for neighbor in [left_neighbor, right_neighbor]:\n", + " pltpu.semaphore_signal(\n", + " second_barrier,\n", + " inc=1,\n", + " device_id=(neighbor,),\n", + " device_id_type=pltpu.DeviceIdType.MESH,\n", + " )\n", + " pltpu.semaphore_wait(second_barrier, 2)\n", + "\n", + "\n", "def all_reduce_kernel(\n", " x_ref,\n", " o_ref,\n", @@ -709,20 +744,7 @@ " def _():\n", " # Barrier with both neighbors at the start, since we will be\n", " # communicating with both.\n", - " barrier_sem = pltpu.get_barrier_semaphore()\n", - " pltpu.semaphore_signal(\n", - " barrier_sem,\n", - " inc=1,\n", - " device_id=(left_neighbor,),\n", - " device_id_type=pltpu.DeviceIdType.MESH,\n", - " )\n", - " pltpu.semaphore_signal(\n", - " barrier_sem,\n", - " inc=1,\n", - " device_id=(right_neighbor,),\n", - " device_id_type=pltpu.DeviceIdType.MESH,\n", - " )\n", - " pltpu.semaphore_wait(barrier_sem, 2)\n", + " local_barrier(left_neighbor, right_neighbor)\n", "\n", " # Initialize o_ref, acc_scratch, and hbm_scratch.\n", " o_ref[...] 
= jnp.zeros_like(o_ref)\n", @@ -787,13 +809,13 @@ " num_scalar_prefetch=0,\n", " in_specs=[\n", " # Our input lives in VMEM\n", - " pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM),\n", + " pl.BlockSpec(memory_space=pltpu.MemorySpace.VMEM),\n", " ],\n", " out_specs=[\n", " # Our output lives in VMEM\n", - " pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM),\n", + " pl.BlockSpec(memory_space=pltpu.MemorySpace.VMEM),\n", " # Our double-buffer lives in HBM\n", - " pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),\n", + " pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY),\n", " ],\n", " grid=(num_devices,),\n", " scratch_shapes=(\n", @@ -807,16 +829,16 @@ " all_reduce_kernel,\n", " out_shape=out_shape,\n", " grid_spec=grid_spec,\n", - " compiler_params=pltpu.TPUCompilerParams(collective_id=0),\n", + " compiler_params=pltpu.CompilerParams(collective_id=0),\n", ")\n", "\n", "pallas_result = jax.jit(\n", - " shard_map.shard_map(\n", + " jax.shard_map(\n", " kernel,\n", " mesh=mesh,\n", " in_specs=partition,\n", " out_specs=partition,\n", - " check_rep=False,\n", + " check_vma=False,\n", " )\n", ")(input_arr)\n", "pallas_result = jax.block_until_ready(pallas_result)[0]\n", @@ -827,7 +849,7 @@ "\n", "\n", "xla_result = jax.jit(\n", - " shard_map.shard_map(\n", + " jax.shard_map(\n", " lax_sum, mesh=mesh, in_specs=P(None, 'x'), out_specs=P(None, 'x')\n", " )\n", ")(input_arr)\n", @@ -892,12 +914,12 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 10, "metadata": { "executionInfo": { - "elapsed": 544, + "elapsed": 362, "status": "ok", - "timestamp": 1722904805699, + "timestamp": 1744390460871, "user": { "displayName": "Justin Fu", "userId": "17543197034567316452" @@ -1017,20 +1039,7 @@ " def _():\n", " # Barrier with both neighbors at the start, since we will be\n", " # communicating with both.\n", - " barrier_sem = pltpu.get_barrier_semaphore()\n", - " pltpu.semaphore_signal(\n", - " barrier_sem,\n", - " inc=1,\n", - " device_id=(left_neighbor,),\n", - " device_id_type=pltpu.DeviceIdType.MESH,\n", - " )\n", - " pltpu.semaphore_signal(\n", - " barrier_sem,\n", - " inc=1,\n", - " device_id=(right_neighbor,),\n", - " device_id_type=pltpu.DeviceIdType.MESH,\n", - " )\n", - " pltpu.semaphore_wait(barrier_sem, 2)\n", + " local_barrier(left_neighbor, right_neighbor)\n", "\n", " # Initialize o_ref, acc_scratch, and hbm_scratch with initial copies.\n", " o_ref[...] 
= jnp.zeros_like(o_ref[...])\n", @@ -1137,11 +1146,11 @@ "grid_spec = pltpu.PrefetchScalarGridSpec(\n", " num_scalar_prefetch=0,\n", " in_specs=[\n", - " pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM),\n", + " pl.BlockSpec(memory_space=pltpu.MemorySpace.VMEM),\n", " ],\n", " out_specs=[\n", - " pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM),\n", - " pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),\n", + " pl.BlockSpec(memory_space=pltpu.MemorySpace.VMEM),\n", + " pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY),\n", " ],\n", " grid=(num_devices, 2),\n", " scratch_shapes=(\n", @@ -1160,17 +1169,17 @@ " reduce_scatter_kernel,\n", " out_shape=out_shape,\n", " grid_spec=grid_spec,\n", - " compiler_params=pltpu.TPUCompilerParams(collective_id=0),\n", + " compiler_params=pltpu.CompilerParams(collective_id=0),\n", " )(input_arr)[0]\n", "\n", "\n", "pallas_result = jax.jit(\n", - " shard_map.shard_map(\n", + " jax.shard_map(\n", " pallas_reduce_scatter,\n", " mesh=mesh,\n", " in_specs=P(None, 'x'),\n", " out_specs=P('x', None),\n", - " check_rep=False,\n", + " check_vma=False,\n", " )\n", ")(input_arr)\n", "\n", @@ -1179,12 +1188,12 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 11, "metadata": { "executionInfo": { - "elapsed": 596, + "elapsed": 917, "status": "ok", - "timestamp": 1722904806442, + "timestamp": 1744390461967, "user": { "displayName": "Justin Fu", "userId": "17543197034567316452" @@ -1192,7 +1201,7 @@ "user_tz": 420 }, "id": "E-NMh-_teoi4", - "outputId": "24beb42f-1bdd-4c34-e8d2-681dd7f2e9c0" + "outputId": "6c8b82bc-ed64-4cc1-8c5f-65e29cdb333c" }, "outputs": [ { @@ -1220,7 +1229,7 @@ "\n", "\n", "xla_result = jax.jit(\n", - " shard_map.shard_map(\n", + " jax.shard_map(\n", " lax_reduce_sum_scatter,\n", " mesh=mesh,\n", " in_specs=P(None, 'x'),\n", @@ -1298,7 +1307,7 @@ "\n", "In this next example we will modify our previous reduce-scatter example to utilize a nested inner pipeline. Note that the communication and computation costs of `reduce_scatter` both scale linearly with the size of the input, so we do not necessarily expect to see the operation become compute-bound with larger block sizes. This example is purely for demonstration purposes on how to use the pipeline emitter.\n", "\n", - "We will increase the block sizes of the outer kernel such that they would be undesirable to place inside of VMEM, and allocate all inputs and outputs in HBM (`memory_space=TPUMemorySpace.Any`). The only major change from our previous kernel is the body of the kernel where accumulation is done. Rather than manually copying from HBM to VMEM, accumulating, and copying back to HBM, we use `emit_pipeline` to handle the memory transfers for us. Accumulation is done in an inner kernel with a much smaller, VMEM-friendly block size.\n", + "We will increase the block sizes of the outer kernel such that they would be undesirable to place inside of VMEM, and allocate all inputs and outputs in HBM (`memory_space=MemorySpace.ANY`). The only major change from our previous kernel is the body of the kernel where accumulation is done. Rather than manually copying from HBM to VMEM, accumulating, and copying back to HBM, we use `emit_pipeline` to handle the memory transfers for us. 
Accumulation is done in an inner kernel with a much smaller, VMEM-friendly block size.\n", "\n", "In our previous kernel we had the following kernel body to copy data from HBM to the VMEM accumulator, increment, and then copy the results back to HBM:\n", "\n", @@ -1356,12 +1365,12 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 12, "metadata": { "executionInfo": { - "elapsed": 1341, + "elapsed": 997, "status": "ok", - "timestamp": 1722904807930, + "timestamp": 1744390463178, "user": { "displayName": "Justin Fu", "userId": "17543197034567316452" @@ -1399,7 +1408,7 @@ "inner_block_spec = pl.BlockSpec(\n", " index_map=lambda i, j: (i, j),\n", " block_shape=inner_block_size,\n", - " memory_space=pltpu.TPUMemorySpace.ANY,\n", + " memory_space=pltpu.MemorySpace.ANY,\n", ")\n", "\n", "\n", @@ -1474,20 +1483,7 @@ " def _():\n", " # Barrier with both neighbors at the start, since we will be\n", " # communicating with both.\n", - " barrier_sem = pltpu.get_barrier_semaphore()\n", - " pltpu.semaphore_signal(\n", - " barrier_sem,\n", - " inc=1,\n", - " device_id=(left_neighbor,),\n", - " device_id_type=pltpu.DeviceIdType.MESH,\n", - " )\n", - " pltpu.semaphore_signal(\n", - " barrier_sem,\n", - " inc=1,\n", - " device_id=(right_neighbor,),\n", - " device_id_type=pltpu.DeviceIdType.MESH,\n", - " )\n", - " pltpu.semaphore_wait(barrier_sem, 2)\n", + " local_barrier(left_neighbor, right_neighbor)\n", "\n", " initial_left_copy.start()\n", " initial_left_copy.wait()\n", @@ -1594,11 +1590,11 @@ "grid_spec = pltpu.PrefetchScalarGridSpec(\n", " num_scalar_prefetch=0,\n", " in_specs=[\n", - " pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),\n", + " pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY),\n", " ],\n", " out_specs=[\n", - " pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),\n", - " pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),\n", + " pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY),\n", + " pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY),\n", " ],\n", " grid=(num_devices, 2),\n", " scratch_shapes=(\n", @@ -1616,17 +1612,17 @@ " reduce_scatter_kernel,\n", " out_shape=out_shape,\n", " grid_spec=grid_spec,\n", - " compiler_params=pltpu.TPUCompilerParams(collective_id=0),\n", + " compiler_params=pltpu.CompilerParams(collective_id=0),\n", " )(input_arr)[0]\n", "\n", "\n", "pallas_result = jax.jit(\n", - " shard_map.shard_map(\n", + " jax.shard_map(\n", " pallas_reduce_scatter,\n", " mesh=mesh,\n", " in_specs=P(None, 'x'),\n", " out_specs=P('x', None),\n", - " check_rep=False,\n", + " check_vma=False,\n", " )\n", ")(input_arr)\n", "\n", @@ -1635,12 +1631,12 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 13, "metadata": { "executionInfo": { - "elapsed": 768, + "elapsed": 1132, "status": "ok", - "timestamp": 1722904808851, + "timestamp": 1744390464532, "user": { "displayName": "Justin Fu", "userId": "17543197034567316452" @@ -1648,7 +1644,7 @@ "user_tz": 420 }, "id": "cTEyiMDyx9Y0", - "outputId": "1de26695-3713-430e-9ab4-4ea646691680" + "outputId": "70ce154e-dab2-4ae0-e297-c4774d29da85" }, "outputs": [ { @@ -1670,7 +1666,7 @@ "\n", "\n", "xla_result = jax.jit(\n", - " shard_map.shard_map(\n", + " jax.shard_map(\n", " lax_reduce_sum_scatter,\n", " mesh=mesh,\n", " in_specs=P(None, 'x'),\n", @@ -1705,11 +1701,18 @@ "\n", "### Next Steps\n", "\n", - "Excellent follow-up excercises for the reader could include implementing a distributed matrix multiplication, implementing `lax.all_to_all`, and relaxing synchronization to allow for additional 
run-ahead." + "Excellent follow-up exercises for the reader could include implementing a distributed matrix multiplication, implementing `lax.all_to_all`, and relaxing synchronization to allow for additional run-ahead." ] } ], "metadata": { + "colab": { + "last_runtime": { + "build_target": "//experimental/users/justinfu/pallas:colab", + "kind": "private" + }, + "provenance": [] + }, "jupytext": { "formats": "ipynb,md:myst", "main_language": "python" @@ -1733,5 +1736,5 @@ } }, "nbformat": 4, - "nbformat_minor": 4 + "nbformat_minor": 0 } diff --git a/docs/pallas/tpu/distributed.md b/docs/pallas/tpu/distributed.md index c1f216c6153e..9d4efd3195f4 100644 --- a/docs/pallas/tpu/distributed.md +++ b/docs/pallas/tpu/distributed.md @@ -17,30 +17,30 @@ kernelspec: # Distributed Computing in Pallas for TPUs -In this tutorial, we will cover the basics of distributed computing in Pallas on TPUs. We will learn about TPU topologies, communication using the remote DMA primitive, and calling a distributed kernel from JAX using `shard_map`. We will also cover some more advanced kernel writing techniques, such as double-buffering, bi-directional bandwidth optimization, and nested pipelining. As educational examples, we will learn how to implement various collective primitives from JAX, such as `lax.ppermute`, `lax.all_gather`, `lax.psum`, and `lax.psum_scatter`. +In this tutorial, we will cover the basics of distributed computing in Pallas on TPUs. We will learn about TPU topologies, communication using the remote DMA primitive, and calling a distributed kernel from JAX using `jax.shard_map`. We will also cover some more advanced kernel writing techniques, such as double-buffering, bi-directional bandwidth optimization, and nested pipelining. As educational examples, we will learn how to implement various collective primitives from JAX, such as `lax.ppermute`, `lax.all_gather`, `lax.psum`, and `lax.psum_scatter`. Some recommended readings beforehand: - [Pallas Pipelining on TPU](pallas_tpu_pipelining) - - [Collectives with `shard_map`](shard_map_collectives_tutorial) + - [Collectives with `jax.shard_map`](shard_map_collectives_tutorial) ```{code-cell} ipython3 --- executionInfo: - elapsed: 1978 + elapsed: 52 status: ok - timestamp: 1722904801801 + timestamp: 1744390458993 user: displayName: Justin Fu userId: '17543197034567316452' user_tz: 420 id: PyAGnWc9yI8T -outputId: 1d8229bd-cab5-495f-93e9-fff2e41db480 +outputId: c5912653-c34b-4810-c373-4a2787691317 --- +import functools import jax from jax import lax from jax import numpy as jnp from jax.experimental import pallas as pl -from jax.experimental import shard_map from jax.experimental.pallas import tpu as pltpu P = jax.sharding.PartitionSpec @@ -61,7 +61,7 @@ TPUs pods are typically arranged in an ND torus topology. The following graphic ![tpu_topologies](https://cloud.google.com/static/tpu/docs/images/v4-topologies.png) -Flattened as a graph, the torus can be visualized as follows. Each edge (orange or black) is a bidirectional connection between two devices. You will commonly hear about rings in conjunction with discussion about device toplogies — a key feature of a torus is that when taking a slice along an axis of the pod, such as the nodes `[(0,1), (1, 1), (2, 1), (3, 1)]` or `[(0, 1), (1, 1)]`, we have a ring of devices. This is a feature we can use to simplify communication patterns within the pod. +Flattened as a graph, the torus can be visualized as follows. Each edge (orange or black) is a bidirectional connection between two devices. 
You will commonly hear about rings in conjunction with discussion about device topologies — a key feature of a torus is that when taking a slice along an axis of the pod, such as the nodes `[(0,1), (1, 1), (2, 1), (3, 1)]` or `[(0, 1), (1, 1)]`, we have a ring of devices. This is a feature we can use to simplify communication patterns within the pod. ![tpu_torus](https://cloud.google.com/static/tpu/docs/images/untwisted-tori.png) @@ -163,7 +163,7 @@ def example_kernel(input_ref, output_ref, send_sem, recv_sem): `send_sem` and `recv_sem` are instances of a special type of semaphore reserved exclusively for use with DMAs. They must be allocated with the `tpu.SemaphoreType.DMA` type when specifying input specs to `pallas_call`. -Internally, DMA semaphores can be thought of as integer-valued progress trackers. On DMA start, the local device will begin to increment the value of `send_sem` and the receiver's `recv_sem` asynchronously. Waiting on a semaphore will block until the value of the semaphore reaches the total bytes of data sent/received; when the value is reached, waiting threads are released and the sempahore's value is decremented by the same amount. This means that either all data has been sent (for `send_sem`) or all data has been received (for `dst_sem`). The value of the semaphore can be read with `pl.semaphore_read`, but note that the underlying semantics of the value could change between hardware generations (e.g. the value may not represent exactly the number of bytes sent, although this is a useful mental model to have when reasoning about the behavior of the semaphore). +Internally, DMA semaphores can be thought of as integer-valued progress trackers. On DMA start, the local device will begin to increment the value of `send_sem` and the receiver's `recv_sem` asynchronously. Waiting on a semaphore will block until the value of the semaphore reaches the total bytes of data sent/received; when the value is reached, waiting threads are released and the semaphore's value is decremented by the same amount. This means that either all data has been sent (for `send_sem`) or all data has been received (for `recv_sem`). The value of the semaphore can be read with `pl.semaphore_read`, but note that the underlying semantics of the value could change between hardware generations (e.g. the value may not represent exactly the number of bytes sent, although this is a useful mental model to have when reasoning about the behavior of the semaphore). ### Routing @@ -195,15 +195,15 @@ In order to call the kernel in distributed mode, we wrap the `pallas_call` in a ```{code-cell} ipython3 --- executionInfo: - elapsed: 1606 + elapsed: 152 status: ok - timestamp: 1722904803566 + timestamp: 1744390459367 user: displayName: Justin Fu userId: '17543197034567316452' user_tz: 420 id: YkyIKN2thZ-V -outputId: 9b7ed142-d161-4237-fed8-cbce41adc5f0 +outputId: 26719bb9-87ff-46dd-af90-a114ce332417 --- partition = P(None, 'x') mesh = jax.make_mesh((num_devices,), ('x',)) @@ -233,11 +233,11 @@ def right_permute_kernel(input_ref, output_ref, send_sem, recv_sem): out_shape = jax.ShapeDtypeStruct((8, 128), jnp.float32) grid_spec = pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=0, - # TPUMemorySpace.ANY will (usually) place the tensor in HBM. + # MemorySpace.ANY will (usually) place the tensor in HBM. 
in_specs=[ - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY), ], - out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + out_specs=pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY), scratch_shapes=( # We allocate DMA semaphores in scratch memory. [pltpu.SemaphoreType.DMA] * 2 @@ -250,12 +250,12 @@ right_permute = pl.pallas_call( ) # Wrap the kernel within a shard_map to call. pallas_result = jax.jit( - shard_map.shard_map( + jax.shard_map( right_permute, mesh=mesh, in_specs=partition, out_specs=partition, - check_rep=False, + check_vma=False, ) )(input_arr) @@ -263,7 +263,7 @@ pallas_result = jax.jit( perm = tuple((src, (src + 1) % num_devices) for src in range(num_devices)) xla_result = jax.jit( - shard_map.shard_map( + jax.shard_map( lambda x: lax.ppermute(x, 'x', perm), mesh=mesh, in_specs=partition, out_specs=partition) )(input_arr) @@ -296,15 +296,15 @@ We can re-purpose Pallas's `grid` argument to implement the loop. Rather than it ```{code-cell} ipython3 --- executionInfo: - elapsed: 812 + elapsed: 209 status: ok - timestamp: 1722904804531 + timestamp: 1744390459789 user: displayName: Justin Fu userId: '17543197034567316452' user_tz: 420 id: ojQEZB5mBRqM -outputId: e1648f54-737c-4921-ca3b-b4c639a38d2b +outputId: 3a4373f8-1fb5-4a6b-b88e-3461c2609021 --- partition = P('x', None) mesh = jax.make_mesh((num_devices,), ('x',)) @@ -356,10 +356,10 @@ out_shape = jax.ShapeDtypeStruct((num_devices, 8, 128), jnp.float32) grid_spec = pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=0, in_specs=[ - # TPUMemorySpace.ANY will (usually) place the tensor in HBM. - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + # MemorySpace.ANY will (usually) place the tensor in HBM. + pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY), ], - out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + out_specs=pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY), scratch_shapes=( # DMA semaphores are allocated in scratch memory. # We allocated one semaphore for a local HBM-VMEM copy, @@ -383,18 +383,18 @@ all_gather = pl.pallas_call( # Wrap the kernel within a shard_map to call. pallas_result = jax.jit( - shard_map.shard_map( + jax.shard_map( all_gather, mesh=mesh, in_specs=partition, out_specs=partition, - check_rep=False + check_vma=False ) )(input_arr) # Compare Pallas result to XLA shard_map result. xla_result = jax.jit( - shard_map.shard_map( + jax.shard_map( lambda x: lax.all_gather(x, 'x'), mesh=mesh, in_specs=partition, out_specs=partition ) @@ -409,9 +409,9 @@ print('Difference |Pallas - lax.all_gather| = ', +++ {"id": "KgU7HI2pS4om"} -A detail worth mentioning here is the use of multiple receive semaphores. Because we only block on the receiving device, it is still possible for a sender to have sent multiple DMAs in flight before the receiver has finished processing the first one (see the next section and reduce-sum example which discusses race conditions in more detail). In this situation we may hit a situation where the same semaphore is being used for multiple DMAs occurring simultaneously. To avoid this, we allocate `num_devices-1` semaphores so there is no risk of re-use. While this race condition is unlikely to happen on such a small kernel, on larger kernels there is more chance for devices to fall out of sync and potentially cause a silent failure. +A detail worth mentioning here is the use of multiple receive semaphores. 
Because we only block on the receiving device, it is still possible for a sender to have sent multiple DMAs in flight before the receiver has finished processing the first one (see the next section and reduce-sum example which discusses race conditions in more detail). In this situation we may hit a situation where the same semaphore is being used for multiple DMAs occurring simultaneously. To avoid this, we allocate `num_devices-1` semaphores so there is no risk of reuse. While this race condition is unlikely to happen on such a small kernel, on larger kernels there is more chance for devices to fall out of sync and potentially cause a silent failure. -+++ {"id": "KgU7HI2pS4om"} ++++ {"id": "EDCmAaHVtY7x"} ## Advanced Techniques @@ -451,9 +451,9 @@ def semaphore_read( In order to use regular semaphores, they can be allocated in the same way as a DMA semaphore, but by specifying `pltpu.SemaphoreType.REGULAR` rather than `pltpu.SemaphoreType.DMA`. -Semaphores must be zero at the end of a Pallas program to complete succesfully. There are two error cases where this may happen: +Semaphores must be zero at the end of a Pallas program to complete successfully. There are two error cases where this may happen: - If a semaphore is over-signaled, the program will end with non-zero (>0) semaphores. In this case, the program will crash upon completion. This is useful for debugging as non-zero semaphores typically means there is a bug somewhere inside of the program. - - If a semaphore is over-waited, the program will hang on the blocking `semaphore_wait` call while it waits for the sempahore to be incremented. In this case the device or program will need to be restarted. + - If a semaphore is over-waited, the program will hang on the blocking `semaphore_wait` call while it waits for the semaphore to be incremented. In this case the device or program will need to be restarted. #### Barrier Semaphores @@ -491,7 +491,7 @@ When using barrier semaphores, the `collective_id` compiler parameter must be pa kernel = pl.pallas_call( example_kernel, ..., - compiler_params=pltpu.TPUCompilerParams(collective_id=0), + compiler_params=pltpu.CompilerParams(collective_id=0), ) ``` @@ -556,22 +556,22 @@ The prologue (executed when `outer_step==0`) first initiates a barrier with both The main body assumes that a value has already been copied into our local working slot, either from the previous iteration or from the prologue. A complicating factor is that our destination buffers live in HBM, but we need to load values to VMEM before we perform arithmetic. Therefore, we simultaneously copy the working slot value into our VMEM (`receive_scratch`) and pass the value on to our right neighbor's receiving slot. Once the value has been copied into our VMEM, we can accumulate it into our result (contained in `o_ref`). -A subtle race condition can occur if one device runs one loop ahead of it's right neighbor. In this case, it could copy into the receiver's `working_slot` at the same time the receiver is reading from it. In order to avoid this, each device will block on a `REGULAR` semaphore before copying into the right neighbor's `dst_ref` until it has signaled that it is done reading from its `working_slot`. This race condition is rarely triggered for a small kernel such as this example, but can it can be explicitly triggered if for example using a `pltpu.delay` instruction to artifically hang a device. +A subtle race condition can occur if one device runs one loop ahead of it's right neighbor. 
In this case, it could copy into the receiver's `working_slot` at the same time the receiver is reading from it. In order to avoid this, each device will block on a `REGULAR` semaphore before copying into the right neighbor's `dst_ref` until it has signaled that it is done reading from its `working_slot`. This race condition is rarely triggered for a small kernel such as this example, but can it can be explicitly triggered if for example using a `pltpu.delay` instruction to artificially hang a device. Note that this is not an optimal or fully general kernel, as the block sizes must entirely fit in VMEM and we could better interleave communication and accumulation. We will discuss these optimizations in later sections. ```{code-cell} ipython3 --- executionInfo: - elapsed: 254 + elapsed: 248 status: ok - timestamp: 1722904804952 + timestamp: 1744390460289 user: displayName: Justin Fu userId: '17543197034567316452' user_tz: 420 id: XrY5bMlvBroQ -outputId: 77497000-4496-462e-cc3c-73fb640cc14c +outputId: 9216e749-48d2-43ff-d64b-bd419acf3e11 --- partition = P(None, 'x') mesh = jax.make_mesh((num_devices,), ('x',)) @@ -581,6 +581,41 @@ input_arr = jax.random.uniform(jax.random.key(0), shape=(8, 128 * num_devices)) input_arr = jax.device_put(input_arr, sharding) +def local_barrier(left_neighbor, right_neighbor, double_barrier=True): + """Performs a barrier with neighbors on the global barrier semaphore. + + Optionally performs a second barrier, which prevents a potential race + when reusing the same collective_id across kernel invocations. + """ + barrier_sem = pltpu.get_barrier_semaphore() + for neighbor in [left_neighbor, right_neighbor]: + pltpu.semaphore_signal( + barrier_sem, + inc=1, + device_id=(neighbor,), + device_id_type=pltpu.DeviceIdType.MESH, + ) + pltpu.semaphore_wait(barrier_sem, 2) + if double_barrier: + # The double-barrier prevents a race condition where one neighbor can + # re-enter the kernel again on a subsequent call and increment the + # barrier semaphore a second time. This would unblock the current device + # even if the other neighbor is not ready yet. + # To implement a double-barrier, we stack-allocate a second REGULAR + # semaphore using run_scoped. + @functools.partial(pl.run_scoped, + second_barrier=pltpu.SemaphoreType.REGULAR) + def _(second_barrier): + for neighbor in [left_neighbor, right_neighbor]: + pltpu.semaphore_signal( + second_barrier, + inc=1, + device_id=(neighbor,), + device_id_type=pltpu.DeviceIdType.MESH, + ) + pltpu.semaphore_wait(second_barrier, 2) + + def all_reduce_kernel( x_ref, o_ref, @@ -603,20 +638,7 @@ def all_reduce_kernel( def _(): # Barrier with both neighbors at the start, since we will be # communicating with both. - barrier_sem = pltpu.get_barrier_semaphore() - pltpu.semaphore_signal( - barrier_sem, - inc=1, - device_id=(left_neighbor,), - device_id_type=pltpu.DeviceIdType.MESH, - ) - pltpu.semaphore_signal( - barrier_sem, - inc=1, - device_id=(right_neighbor,), - device_id_type=pltpu.DeviceIdType.MESH, - ) - pltpu.semaphore_wait(barrier_sem, 2) + local_barrier(left_neighbor, right_neighbor) # Initialize o_ref, acc_scratch, and hbm_scratch. o_ref[...] 
= jnp.zeros_like(o_ref) @@ -681,13 +703,13 @@ grid_spec = pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=0, in_specs=[ # Our input lives in VMEM - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM), + pl.BlockSpec(memory_space=pltpu.MemorySpace.VMEM), ], out_specs=[ # Our output lives in VMEM - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM), + pl.BlockSpec(memory_space=pltpu.MemorySpace.VMEM), # Our double-buffer lives in HBM - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY), ], grid=(num_devices,), scratch_shapes=( @@ -701,16 +723,16 @@ kernel = pl.pallas_call( all_reduce_kernel, out_shape=out_shape, grid_spec=grid_spec, - compiler_params=pltpu.TPUCompilerParams(collective_id=0), + compiler_params=pltpu.CompilerParams(collective_id=0), ) pallas_result = jax.jit( - shard_map.shard_map( + jax.shard_map( kernel, mesh=mesh, in_specs=partition, out_specs=partition, - check_rep=False, + check_vma=False, ) )(input_arr) pallas_result = jax.block_until_ready(pallas_result)[0] @@ -721,7 +743,7 @@ def lax_sum(x): xla_result = jax.jit( - shard_map.shard_map( + jax.shard_map( lax_sum, mesh=mesh, in_specs=P(None, 'x'), out_specs=P(None, 'x') ) )(input_arr) @@ -772,9 +794,9 @@ In terms of construction of the kernel, we introduce an additional `phase` dimen ```{code-cell} ipython3 --- executionInfo: - elapsed: 544 + elapsed: 362 status: ok - timestamp: 1722904805699 + timestamp: 1744390460871 user: displayName: Justin Fu userId: '17543197034567316452' @@ -890,20 +912,7 @@ def reduce_scatter_kernel( def _(): # Barrier with both neighbors at the start, since we will be # communicating with both. - barrier_sem = pltpu.get_barrier_semaphore() - pltpu.semaphore_signal( - barrier_sem, - inc=1, - device_id=(left_neighbor,), - device_id_type=pltpu.DeviceIdType.MESH, - ) - pltpu.semaphore_signal( - barrier_sem, - inc=1, - device_id=(right_neighbor,), - device_id_type=pltpu.DeviceIdType.MESH, - ) - pltpu.semaphore_wait(barrier_sem, 2) + local_barrier(left_neighbor, right_neighbor) # Initialize o_ref, acc_scratch, and hbm_scratch with initial copies. o_ref[...] 
= jnp.zeros_like(o_ref[...]) @@ -1010,11 +1019,11 @@ out_shape = ( grid_spec = pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=0, in_specs=[ - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM), + pl.BlockSpec(memory_space=pltpu.MemorySpace.VMEM), ], out_specs=[ - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM), - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + pl.BlockSpec(memory_space=pltpu.MemorySpace.VMEM), + pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY), ], grid=(num_devices, 2), scratch_shapes=( @@ -1033,17 +1042,17 @@ def pallas_reduce_scatter(input_arr): reduce_scatter_kernel, out_shape=out_shape, grid_spec=grid_spec, - compiler_params=pltpu.TPUCompilerParams(collective_id=0), + compiler_params=pltpu.CompilerParams(collective_id=0), )(input_arr)[0] pallas_result = jax.jit( - shard_map.shard_map( + jax.shard_map( pallas_reduce_scatter, mesh=mesh, in_specs=P(None, 'x'), out_specs=P('x', None), - check_rep=False, + check_vma=False, ) )(input_arr) @@ -1053,15 +1062,15 @@ pallas_result = jax.block_until_ready(pallas_result) ```{code-cell} ipython3 --- executionInfo: - elapsed: 596 + elapsed: 917 status: ok - timestamp: 1722904806442 + timestamp: 1744390461967 user: displayName: Justin Fu userId: '17543197034567316452' user_tz: 420 id: E-NMh-_teoi4 -outputId: 24beb42f-1bdd-4c34-e8d2-681dd7f2e9c0 +outputId: 6c8b82bc-ed64-4cc1-8c5f-65e29cdb333c --- # Compare our result to XLA. def lax_reduce_sum_scatter(x): @@ -1070,7 +1079,7 @@ def lax_reduce_sum_scatter(x): xla_result = jax.jit( - shard_map.shard_map( + jax.shard_map( lax_reduce_sum_scatter, mesh=mesh, in_specs=P(None, 'x'), @@ -1139,7 +1148,7 @@ pl.pallas_call( In this next example we will modify our previous reduce-scatter example to utilize a nested inner pipeline. Note that the communication and computation costs of `reduce_scatter` both scale linearly with the size of the input, so we do not necessarily expect to see the operation become compute-bound with larger block sizes. This example is purely for demonstration purposes on how to use the pipeline emitter. -We will increase the block sizes of the outer kernel such that they would be undesirable to place inside of VMEM, and allocate all inputs and outputs in HBM (`memory_space=TPUMemorySpace.Any`). The only major change from our previous kernel is the body of the kernel where accumulation is done. Rather than manually copying from HBM to VMEM, accumulating, and copying back to HBM, we use `emit_pipeline` to handle the memory transfers for us. Accumulation is done in an inner kernel with a much smaller, VMEM-friendly block size. +We will increase the block sizes of the outer kernel such that they would be undesirable to place inside of VMEM, and allocate all inputs and outputs in HBM (`memory_space=MemorySpace.ANY`). The only major change from our previous kernel is the body of the kernel where accumulation is done. Rather than manually copying from HBM to VMEM, accumulating, and copying back to HBM, we use `emit_pipeline` to handle the memory transfers for us. Accumulation is done in an inner kernel with a much smaller, VMEM-friendly block size. 
In our previous kernel we had the following kernel body to copy data from HBM to the VMEM accumulator, increment, and then copy the results back to HBM: @@ -1197,9 +1206,9 @@ The full kernel is as follows: ```{code-cell} ipython3 --- executionInfo: - elapsed: 1341 + elapsed: 997 status: ok - timestamp: 1722904807930 + timestamp: 1744390463178 user: displayName: Justin Fu userId: '17543197034567316452' @@ -1233,7 +1242,7 @@ inner_grid = ( inner_block_spec = pl.BlockSpec( index_map=lambda i, j: (i, j), block_shape=inner_block_size, - memory_space=pltpu.TPUMemorySpace.ANY, + memory_space=pltpu.MemorySpace.ANY, ) @@ -1308,20 +1317,7 @@ def reduce_scatter_kernel( def _(): # Barrier with both neighbors at the start, since we will be # communicating with both. - barrier_sem = pltpu.get_barrier_semaphore() - pltpu.semaphore_signal( - barrier_sem, - inc=1, - device_id=(left_neighbor,), - device_id_type=pltpu.DeviceIdType.MESH, - ) - pltpu.semaphore_signal( - barrier_sem, - inc=1, - device_id=(right_neighbor,), - device_id_type=pltpu.DeviceIdType.MESH, - ) - pltpu.semaphore_wait(barrier_sem, 2) + local_barrier(left_neighbor, right_neighbor) initial_left_copy.start() initial_left_copy.wait() @@ -1428,11 +1424,11 @@ out_shape = ( grid_spec = pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=0, in_specs=[ - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY), ], out_specs=[ - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY), + pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY), ], grid=(num_devices, 2), scratch_shapes=( @@ -1450,17 +1446,17 @@ def pallas_reduce_scatter(input_arr): reduce_scatter_kernel, out_shape=out_shape, grid_spec=grid_spec, - compiler_params=pltpu.TPUCompilerParams(collective_id=0), + compiler_params=pltpu.CompilerParams(collective_id=0), )(input_arr)[0] pallas_result = jax.jit( - shard_map.shard_map( + jax.shard_map( pallas_reduce_scatter, mesh=mesh, in_specs=P(None, 'x'), out_specs=P('x', None), - check_rep=False, + check_vma=False, ) )(input_arr) @@ -1470,15 +1466,15 @@ pallas_result = jax.block_until_ready(pallas_result) ```{code-cell} ipython3 --- executionInfo: - elapsed: 768 + elapsed: 1132 status: ok - timestamp: 1722904808851 + timestamp: 1744390464532 user: displayName: Justin Fu userId: '17543197034567316452' user_tz: 420 id: cTEyiMDyx9Y0 -outputId: 1de26695-3713-430e-9ab4-4ea646691680 +outputId: 70ce154e-dab2-4ae0-e297-c4774d29da85 --- # Now we compare our result to XLA. def lax_reduce_sum_scatter(x): @@ -1487,7 +1483,7 @@ def lax_reduce_sum_scatter(x): xla_result = jax.jit( - shard_map.shard_map( + jax.shard_map( lax_reduce_sum_scatter, mesh=mesh, in_specs=P(None, 'x'), @@ -1518,4 +1514,4 @@ In this tutorial we covered several kernel examples which replicate the function ### Next Steps -Excellent follow-up excercises for the reader could include implementing a distributed matrix multiplication, implementing `lax.all_to_all`, and relaxing synchronization to allow for additional run-ahead. +Excellent follow-up exercises for the reader could include implementing a distributed matrix multiplication, implementing `lax.all_to_all`, and relaxing synchronization to allow for additional run-ahead. 
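For context on the documentation hunks above: the recurring API updates all follow one pattern. `jax.experimental.shard_map.shard_map(..., check_rep=...)` becomes `jax.shard_map(..., check_vma=...)`, `pltpu.TPUMemorySpace` becomes `pltpu.MemorySpace`, and `pltpu.TPUCompilerParams` becomes `pltpu.CompilerParams`. A minimal sketch of the new-style spelling follows; the trivial copy kernel, array sizes, and `collective_id` value are placeholders for illustration only and are not taken from this patch.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu

P = jax.sharding.PartitionSpec


def copy_kernel(x_ref, o_ref):
  # Placeholder body; the real kernels in these docs perform remote DMAs.
  o_ref[...] = x_ref[...]


def per_shard_copy(x):
  return pl.pallas_call(
      copy_kernel,
      out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
      # pltpu.MemorySpace replaces the old pltpu.TPUMemorySpace enum.
      in_specs=[pl.BlockSpec(memory_space=pltpu.MemorySpace.VMEM)],
      out_specs=pl.BlockSpec(memory_space=pltpu.MemorySpace.VMEM),
      # pltpu.CompilerParams replaces pltpu.TPUCompilerParams; collective_id
      # is included here only to illustrate the rename.
      compiler_params=pltpu.CompilerParams(collective_id=0),
  )(x)


num_devices = jax.device_count()
mesh = jax.make_mesh((num_devices,), ('x',))
x = jnp.ones((8, 128 * num_devices), jnp.float32)

# jax.shard_map replaces jax.experimental.shard_map.shard_map, and the
# check_vma flag replaces the old check_rep flag.
result = jax.jit(
    jax.shard_map(
        per_shard_copy,
        mesh=mesh,
        in_specs=P(None, 'x'),
        out_specs=P(None, 'x'),
        check_vma=False,
    )
)(x)
```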
diff --git a/docs/pallas/tpu/matmul.ipynb b/docs/pallas/tpu/matmul.ipynb index 9c90add16ab0..dbe9747c4884 100644 --- a/docs/pallas/tpu/matmul.ipynb +++ b/docs/pallas/tpu/matmul.ipynb @@ -210,7 +210,7 @@ " pl.BlockSpec((bk, bn), lambda i, j, k: (k, j))],\n", " out_specs=pl.BlockSpec((bm, bn), lambda i, j, k: (i, j)),\n", " grid=(m // bm, n // bn, k // bk),\n", - " compiler_params=pltpu.TPUCompilerParams(\n", + " compiler_params=pltpu.CompilerParams(\n", " dimension_semantics=(\"parallel\", \"parallel\", \"arbitrary\")),\n", " )(x, y)" ] @@ -466,7 +466,7 @@ " grid=(m // bm, n // bn, k // bk),\n", " ),\n", " out_shape=jax.ShapeDtypeStruct((m, n), x.dtype),\n", - " compiler_params=pltpu.TPUCompilerParams(\n", + " compiler_params=pltpu.CompilerParams(\n", " dimension_semantics=(\"parallel\", \"parallel\", \"arbitrary\")),\n", " )(x, y)" ] @@ -496,7 +496,14 @@ "\n", "Our above analysis about FLOPs vs memory usage applies at a coarse scale i.e. when we are looking at the the size of a the total matrix multiplication. However, remember that in practice, we are pipelining the execution of a blocked matrix multiplication, meaning we have a loop in which we are doing matrix multiplication with smaller blocks.\n", "\n", - "This means that we actually care about the FLOPs vs memory bandwidth usage of each individual instance of the kernel, not the global FLOPs vs memory bandwidth usage. Therefore, the block sizes `bm`, `bk`, `bn` are extremely important for performance. Even if we have the largest matrices in the world, if we pick very small `bm`, `bk`, and `bn`, we will be memory bound because each time we invoke the kernel we will have too few FLOPs to hide the memory transfers happening in the background.\n", + "This means that we actually care about the FLOPs vs memory bandwidth usage of each individual instance of the kernel, not the global FLOPs vs memory bandwidth usage.\n", + "\n", + "In addition, when tiling the matmul operation, the same values could be read multiple times from memory.\n", + "Specifically the memory bandwidth for the first operand of the kernel is `(bm * bk)`, which needs to be multiplied by the grid dimensions, that is `(bm * bk) * m // bm * n // bn * k // bk = m * k * n // bn`.\n", + "Similarly for the second operand, yielding a total bandwidth usage `(m * k * n // bn + k * n * m // bm + m * n) * element_size`.\n", + "\n", + "Therefore, the block sizes `bm`, `bk`, `bn` are extremely important for performance.\n", + " Even if we have the largest matrices in the world, if we pick very small `bm`, `bk`, and `bn`, we will be memory bound because each time we invoke the kernel we will have too few FLOPs to hide the memory transfers happening in the background.\n", "\n", "The intuition should therefore be: to be compute bound, make the blocks as big as possible! 
There are two main constraints:\n", "\n", @@ -741,7 +748,7 @@ " grid=(m // bm, n // bn, k // bk),\n", " ),\n", " out_shape=jax.ShapeDtypeStruct((m, n), x.dtype),\n", - " compiler_params=pltpu.TPUCompilerParams(\n", + " compiler_params=pltpu.CompilerParams(\n", " dimension_semantics=(\"parallel\", \"parallel\", \"arbitrary\")),\n", " )(x, y)" ] @@ -929,7 +936,7 @@ " grid=(m // bm, n // bn, k // bk),\n", " ),\n", " out_shape=jax.ShapeDtypeStruct((m, n), x.dtype),\n", - " compiler_params=pltpu.TPUCompilerParams(\n", + " compiler_params=pltpu.CompilerParams(\n", " dimension_semantics=(\"parallel\", \"parallel\", \"arbitrary\")),\n", " )(x, y)" ] diff --git a/docs/pallas/tpu/matmul.md b/docs/pallas/tpu/matmul.md index 42084f12d5f5..509d47093af7 100644 --- a/docs/pallas/tpu/matmul.md +++ b/docs/pallas/tpu/matmul.md @@ -167,7 +167,7 @@ def matmul( pl.BlockSpec((bk, bn), lambda i, j, k: (k, j))], out_specs=pl.BlockSpec((bm, bn), lambda i, j, k: (i, j)), grid=(m // bm, n // bn, k // bk), - compiler_params=pltpu.TPUCompilerParams( + compiler_params=pltpu.CompilerParams( dimension_semantics=("parallel", "parallel", "arbitrary")), )(x, y) ``` @@ -321,7 +321,7 @@ def matmul( grid=(m // bm, n // bn, k // bk), ), out_shape=jax.ShapeDtypeStruct((m, n), x.dtype), - compiler_params=pltpu.TPUCompilerParams( + compiler_params=pltpu.CompilerParams( dimension_semantics=("parallel", "parallel", "arbitrary")), )(x, y) ``` @@ -342,7 +342,14 @@ np.testing.assert_array_equal(x @ y, matmul(x, y)) Our above analysis about FLOPs vs memory usage applies at a coarse scale i.e. when we are looking at the the size of a the total matrix multiplication. However, remember that in practice, we are pipelining the execution of a blocked matrix multiplication, meaning we have a loop in which we are doing matrix multiplication with smaller blocks. -This means that we actually care about the FLOPs vs memory bandwidth usage of each individual instance of the kernel, not the global FLOPs vs memory bandwidth usage. Therefore, the block sizes `bm`, `bk`, `bn` are extremely important for performance. Even if we have the largest matrices in the world, if we pick very small `bm`, `bk`, and `bn`, we will be memory bound because each time we invoke the kernel we will have too few FLOPs to hide the memory transfers happening in the background. +This means that we actually care about the FLOPs vs memory bandwidth usage of each individual instance of the kernel, not the global FLOPs vs memory bandwidth usage. + +In addition, when tiling the matmul operation, the same values could be read multiple times from memory. +Specifically the memory bandwidth for the first operand of the kernel is `(bm * bk)`, which needs to be multiplied by the grid dimensions, that is `(bm * bk) * m // bm * n // bn * k // bk = m * k * n // bn`. +Similarly for the second operand, yielding a total bandwidth usage `(m * k * n // bn + k * n * m // bm + m * n) * element_size`. + +Therefore, the block sizes `bm`, `bk`, `bn` are extremely important for performance. + Even if we have the largest matrices in the world, if we pick very small `bm`, `bk`, and `bn`, we will be memory bound because each time we invoke the kernel we will have too few FLOPs to hide the memory transfers happening in the background. The intuition should therefore be: to be compute bound, make the blocks as big as possible! 
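As a quick numeric check of the bandwidth expression added above, the following back-of-the-envelope calculation uses hypothetical sizes chosen only for illustration (they are not taken from this document):

```python
# Hypothetical sizes: a square f32 matmul tiled with 512-sized blocks.
m = n = k = 4096
bm = bn = bk = 512
element_size = 4  # bytes per float32

tiled = (m * k * n // bn + k * n * m // bm + m * n) * element_size
ideal = (m * k + k * n + m * n) * element_size
print(tiled / 1e9)    # ~1.14 GB moved with these block sizes
print(ideal / 1e9)    # ~0.20 GB if each operand were read exactly once
print(tiled / ideal)  # ~5.7x, since each operand is re-read n // bn = m // bm = 8 times
```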
There are two main constraints: @@ -489,7 +496,7 @@ def matmul( grid=(m // bm, n // bn, k // bk), ), out_shape=jax.ShapeDtypeStruct((m, n), x.dtype), - compiler_params=pltpu.TPUCompilerParams( + compiler_params=pltpu.CompilerParams( dimension_semantics=("parallel", "parallel", "arbitrary")), )(x, y) ``` @@ -613,7 +620,7 @@ def matmul( grid=(m // bm, n // bn, k // bk), ), out_shape=jax.ShapeDtypeStruct((m, n), x.dtype), - compiler_params=pltpu.TPUCompilerParams( + compiler_params=pltpu.CompilerParams( dimension_semantics=("parallel", "parallel", "arbitrary")), )(x, y) ``` diff --git a/docs/pallas/tpu/pipelining.ipynb b/docs/pallas/tpu/pipelining.ipynb index 10de587105f2..829cda000e5d 100644 --- a/docs/pallas/tpu/pipelining.ipynb +++ b/docs/pallas/tpu/pipelining.ipynb @@ -2,8 +2,9 @@ "cells": [ { "cell_type": "markdown", - "id": "7704d3bb", - "metadata": {}, + "metadata": { + "id": "7704d3bb" + }, "source": [ "(pallas_tpu_pipelining)=" ] @@ -14,7 +15,7 @@ "id": "teoJ_fUwlu0l" }, "source": [ - "# Pipelining\n", + "# TPU Pipelining\n", "\n", "" ] @@ -25,14 +26,24 @@ "id": "gAJDZh1gBh-h" }, "source": [ - "In this guide we'll cover how memory spaces in TPU work and how to write\n", - "pipelines in Pallas that overlap memory I/O with compute." + "This guide serves as a reference for TPU-specific pipelining concerns.\n", + "We'll review the memory hierarchy and compute units on TPUs, and TPU-specific features of the pipelining API. For a more general-purpose overview of pipelining, see the {ref}`pallas_software_pipelining`." ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "metadata": { + "executionInfo": { + "elapsed": 54, + "status": "ok", + "timestamp": 1744908474512, + "user": { + "displayName": "Justin Fu", + "userId": "17543197034567316452" + }, + "user_tz": 420 + }, "id": "ejAVO6ikUUuF" }, "outputs": [], @@ -48,9 +59,8 @@ }, { "cell_type": "markdown", - "id": "0e212a5e", "metadata": { - "id": "TWKESTKAlyjT" + "id": "0e212a5e" }, "source": [ "(tpu_and_its_memory_spaces)=\n", @@ -60,7 +70,9 @@ }, { "cell_type": "markdown", - "metadata": {}, + "metadata": { + "id": "NnWW9GV4kW6P" + }, "source": [ "A TPU and its TensorCore consist of memory spaces (where arrays can reside),\n", "registers (which temporarily store scalar and array values) and compute units\n", @@ -83,568 +95,81 @@ " Values can be loaded into memory from their respective caches (VMEM for\n", " VREGs and SMEM for SREGs).\n", "* **Compute units**: A TensorCore has a scalar unit, vector unit (VPU) and\n", - " matrix unit (MXU) that can do numerical computation.\n", + " matrix unit (MXU) that can do numerical computation. Each of these compute units can operate asynchronously, but this is managed by the TPU compiler and thus from the programmer's perspective a TPU program is single-threaded.\n", " Compute units operate on values that live in SREGs and VREGs and output\n", - " values into those registers as well.\n", - "\n", - "In order to do a vectorized computation on our values `x` and `y` that live\n", - "in HBM, we need to:\n", - "\n", - "1. Copy the values `x` and `y` into VMEM.\n", - "2. Load the values from VMEM into VREGs.\n", - "3. Execute the computation using the VPU or MXU, storing the output in VREGs.\n", - "4. Store the values in the output VREGs into VMEM.\n", - "5. Copy the output values in VMEM back to HBM." + " values into those registers as well." 
] }, { "cell_type": "markdown", "metadata": { - "id": "TzctMbNsn3vc" + "id": "8Tl3wt5Wk3Ek" }, "source": [ - "Let's implement a Pallas function that does just that!" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "2IXQxNWrKJyb", - "outputId": "d62eb493-5f92-4496-f113-d3cd24cb0b9f" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "Array([[2., 2., 2., ..., 2., 2., 2.],\n", - " [2., 2., 2., ..., 2., 2., 2.],\n", - " [2., 2., 2., ..., 2., 2., 2.],\n", - " ...,\n", - " [2., 2., 2., ..., 2., 2., 2.],\n", - " [2., 2., 2., ..., 2., 2., 2.],\n", - " [2., 2., 2., ..., 2., 2., 2.]], dtype=float32)" - ] - }, - "execution_count": 3, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "def add_matrices_kernel(x_vmem_ref, y_vmem_ref, z_vmem_ref):\n", - " # Load x and y from VMEM into VREGs\n", - " x_vregs = x_vmem_ref[:, :]\n", - " y_vregs = y_vmem_ref[:, :]\n", - " # Execute a vectorized add\n", - " z_vregs = x_vregs + y_vregs\n", - " # Store the output values in VREGs back into VMEM\n", - " z_vmem_ref[:, :] = z_vregs\n", - "\n", - "\n", - "def add_matrices(x: jax.Array, y: jax.Array) -> jax.Array:\n", - " # pallas_call will first allocate scratch buffers for `x` and `y` in VMEM.\n", - " # It will then copy `x` and `y` from HBM into VMEM.\n", - " z = pl.pallas_call(\n", - " add_matrices_kernel, out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype)\n", - " )(x, y)\n", - " # pallas_call will also copy the output from VMEM back into HBM.\n", - " return z\n", - "\n", - "\n", - "x, y = jnp.ones((512, 512)), jnp.ones((512, 512))\n", - "add_matrices(x, y)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "HMENNLy8okCL" - }, - "source": [ - "We've written two functions: `add_matrices_kernel` and `add_matrices`.\n", + "## TPU-specific Pipelining Features\n", "\n", - "`add_matrices_kernel` operates using `Ref`s that live in VMEM.\n", - "Loading from a VMEM `Ref` produces a value that lives in VREGs.\n", - "Values in VREGs behave like `jax.Array`s in that we can use `jnp` and\n", - "`jax.lax` operations on them to produce new values that live in VREGs.\n", - "When we produce the values we'd like to return, we store them in the output\n", - "VMEM `Ref`.\n", - "\n", - "The `add_matrices` function acts on `jax.Array`s and returns a `jax.Array`.\n", - "Inside it, we pass `x` and `y` into `pallas_call`.\n", - "`pallas_call` is responsible for copying `x` and `y` into VMEM and for\n", - "allocating the VMEM buffers that the kernel operates on (including allocating\n", - "`z_vmem_ref`, the output VMEM buffer).\n", - "After the kernel function is finished running, `pallas_call` will also copy\n", - "the value in `z_vmem_ref` to HBM, resulting in an output `jax.Array`." + "Pallas TPU supports the following platform-specific features." ] }, { "cell_type": "markdown", "metadata": { - "id": "5kWr-1tKpYro" + "id": "1jg5WmExk47l" }, "source": [ - "## Constraints of using VMEM/SMEM\n", - "\n", - "Pallas exposes access to lower level memory spaces like VMEM and SMEM but\n", - "writing kernels utilizing them adds some considerations.\n", + "### TPU Memory Spaces\n", "\n", - "1. Memory capacity. VMEM and SMEM are *small*! 
VMEM on v4 TPUs is only 16MiB\n", - " and SMEM ranges in the tens to hundreds of KiB.\n", - " If our arrays are too big, we won't even be able to fit them into VMEM at all.\n", - " For reference, a `f32[2048, 2048]` array is 16MiB, so our above kernel won't\n", - " scale beyond moderately sized arrays.\n", + "Pallas exposes all levels of the TPU memory hierarchy to users. The following table maps from Pallas TPU memory spaces to their standard memory types (DRAM/SRAM):\n", "\n", - "2. Memory bandwidth. Copying to/from HBM and VMEM takes a long time, at least\n", - " compared to most compute instructions.\n", - " The `add_matrices` function above will likely spend more time copying\n", - " between HBM and VMEM than actually performing the addition itself.\n", + "| Pallas Enum | TPU Memory Space | Type (DRAM/SRAM) |\n", + "| --- | --- | --- |\n", + "| `pltpu.MemorySpace.ANY` | HBM (usually) or VMEM | DRAM |\n", + "| `pltpu.MemorySpace.VMEM` | VMEM | SRAM |\n", + "| `pltpu.MemorySpace.SMEM` | SMEM | SRAM |\n", + "| `pltpu.MemorySpace.SEMAPHORE` | Semaphore | SRAM |\n", "\n", - "With these two constraints in mind, we'll have to rethink our strategy for\n", - "getting performance out of our TPUs." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "_NTqvlbetB3P" - }, - "source": [ - "## Primer: Pipelining\n", + "- `MemorySpace.VMEM` denotes vector SRAM. It is the default memory space if nothing is specified.\n", + "- `MemorySpace.SMEM` denotes scalar SRAM. Only scalar loads and stores can be performed to/from SMEM.\n", + "- `MemorySpace.ANY` is a hint to the compiler that the memory space is unconstrained. In most cases, XLA will place this buffer in HBM. A buffer assigned to the `ANY` memory space cannot be dereferenced normally using array indexing syntax (e.g. `x[...]`). Instead, we must first copy the values into a VMEM or SMEM buffer using `pltpu.sync_copy` or `pltpu.async_copy`.\n", + "- `MemorySpace.SEMAPHORE` is used to allocate semaphores for constructing barriers or tracking asynchronous operations. It is also possible to return semaphores from the kernel for building asynchronous kernels - this is an experimental feature; see {ref}`pallas_async` for more details.\n", "\n", - "Pipelining our computation offers a way of dealing with both the memory\n", - "capacity and bandwidth constraints in one fell swoop.\n", - "What do we mean by pipelining?\n", + "Pipelining on TPUs is typically done between HBM (DRAM) to VMEM (Vector SRAM). The default behavior for `pallas_call` on TPU is that arguments to `pallas_call` are assumed to live in HBM, and inputs to the user kernel body are stored in VMEM.\n", "\n", - "The goal is: *in parallel* copy to/from HBM and VMEM *while* utilizing our\n", - "compute units.\n", - "Naively this is difficult because in our program above we copy *all* of `x`\n", - "and `y` before we start doing any compute with them, creating a dependence\n", - "between the copy and the compute.\n", + "While not specific to pipelining, it is possible to gain manual control over the memory space of input and output buffers, you can specify the `memory_space` argument on a `BlockSpec`. Note that pipelining is not allowed unless the `memory_space` is marked as `VMEM`. Memory spaces can also be used to specify scratch arguments to a kernel via the `scratch_shapes` argument on `pallas_call`. Scratch buffers are persistent across kernel iterations and are useful for storing intermediate results such as partial accumulations and reductions. 
A scratch buffer must reside in `VMEM`, `SMEM`, or `SEMAPHORE`.\n", "\n", - "However, if we can chunk up our computation into several subcomputations\n", - "(e.g. when we add two matrices, we can express that as addition of \"blocks\"\n", - "of the original matrices together), we can now overlap the copies of one of\n", - "those subcomputations with the compute of the other. Let's walk through a\n", - "simple example:\n", - "\n", - "Let's say we split our arrays `x` and `y` into `x1, x2` and `y1, y2` (for\n", - "example, split along the leading axis, resulting in two `(256, 512)` arrays\n", - "for each input.\n", - "We can now execute the following pipelined computation.\n", - "\n", - "1. Copy `x1` and `y1` into VMEM.\n", - "1. Start copying `x2` and `y2` into VMEM\n", - "2. Load `x1, y1` from VMEM into VREGs.\n", - "3. Execute the `z1 = x1 + y1` using the compute units.\n", - "4. Store `z1` into VMEM.\n", - "5. Start copying `z1` from VMEM back into HBM.\n", - "6. Wait until `x2, y2` have been copied into VMEM.\n", - "7. Load `x2, y2` from VMEM into VREGs.\n", - "8. Execute the `z2 = x2 + y2` using the compute units.\n", - "9. Store `z2` into VMEM.\n", - "10. Wait until `z1` is copied into HBM.\n", - "10. Start copying `z2` from VMEM back into HBM.\n", - "10. Wait until `z2` is copied into HBM.\n", - "\n", - "Any time we are doing compute here, we are asynchronously copying something.\n", - "This means that some of the time spent copying is not wasted.\n", - "\n", - "The two most important numbers for determining how efficient a pipelined\n", - "computation are a) how many floating point operations (FLOPs) we need to\n", - "execute and b) how many bytes we need to copy to execute that computation.\n", - "The ratio of these two (FLOPs/memory usage) is called the\n", - "*arithmetic intensity* of an operation and determines if our pipeline will\n", - "be compute bound or memory bound." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "gutx7y8uvZKH" - }, - "source": [ - "## Pipelining in Pallas" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "U-dPTjlBverB" - }, - "source": [ - "How do we implement a pipeline like the one above in Pallas?\n", - "It seems like a complex sequence of asynchronous data operations and\n", - "executing kernels that would be a pain to implement manually.\n", - "Fear not! 
Pallas offers an API for expressing pipelines without too much\n", - "boilerplate, namely through `grid`s and `BlockSpec`s.\n", - "\n", - "See how in the above pipelined example, we are executing the same logic\n", - "multiple times: steps 3-5 and 8-10 both execute the same operations,\n", - "only on different inputs.\n", - "The {func}`jax.experimental.pallas.pallas_call` provides a way to\n", - "execute a kernel multiple times, by using the `grid` argument.\n", - "See {ref}`pallas_grid`.\n", - "\n", - "We also use {class}`jax.experimental.pallas.BlockSpec` to specify\n", - "how to construct the input of each kernel invocation.\n", - "See {ref}`pallas_blockspec`.\n", - "\n", - "In the pipelining example above, we had `(512, 512)`-shaped arrays and\n", - "split them along the leading dimension into two `(256, 512)`-shaped arrays.\n", - "In this pipeline, our `BlockSpec.block_shape` would be `(256, 512)`.\n", - "On the 1st iteration we'd\n", - "like to select `x1` and on the second iteration we'd like to use `x2`.\n", - "This can be expressed with the following `index_map`:\n", - "\n", - "```python\n", - "def x_index_map(i):\n", - " return (i, 0)\n", - "```\n", - "\n", - "We'd then construct the `BlockSpec`:\n", - "```python\n", - "block_spec = pl.BlockSpec((256, 512), x_index_map)\n", - "```\n", - "\n", - "The `BlockSpec`s for `y` and `z` will be the same as the one for `x`." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "noybOKghzjwG" - }, - "source": [ - "### Putting it together\n", - "\n", - "We provide these arguments to `pallas_call` via `grid`, `in_specs` and\n", - "`out_specs` (`in_specs` corresponds to the tuple of positional arguments,\n", - "and `out_specs` corresponds to the output)." + "As an example for using multiple manual memory space assignments in a kernel, the following program copies a slice of an HBM buffer `x_hbm_ref` into a scratch VMEM buffer `scratch_vmem_ref` before using it for arithmetic and storing the result into an output VMEM buffer:" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 8, "metadata": { - "id": "ehKAYAwIojfv", - "outputId": "504bab29-83f3-4e1f-8664-1860ad15b6de" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "Array([[2., 2., 2., ..., 2., 2., 2.],\n", - " [2., 2., 2., ..., 2., 2., 2.],\n", - " [2., 2., 2., ..., 2., 2., 2.],\n", - " ...,\n", - " [2., 2., 2., ..., 2., 2., 2.],\n", - " [2., 2., 2., ..., 2., 2., 2.],\n", - " [2., 2., 2., ..., 2., 2., 2.]], dtype=float32)" - ] + "executionInfo": { + "elapsed": 65, + "status": "ok", + "timestamp": 1744908591430, + "user": { + "displayName": "Justin Fu", + "userId": "17543197034567316452" }, - "execution_count": 4, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "def add_matrices_pipelined(x: jax.Array, y: jax.Array) -> jax.Array:\n", - " block_spec = pl.BlockSpec((256, 512), lambda i: (i, 0))\n", - " return pl.pallas_call(\n", - " add_matrices_kernel,\n", - " out_shape=x,\n", - " in_specs=[block_spec, block_spec],\n", - " out_specs=block_spec,\n", - " grid=(2,)\n", - " )(x, y)\n", - "\n", - "add_matrices_pipelined(x, y)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "rkytgIZYzz4t" - }, - "source": [ - "We've only added a little bit of code to our original function to add\n", - "automatic pipelining but the `BlockSpec`s and `grid` do a lot of heavy\n", - "lifting!\n", - "\n", - "How does it work? 
Well, the `BlockSpec`s provide enough information to start\n", - "*prefetching* blocks of our input from HBM into VMEM.\n", - "For example, if we are starting iteration `i` of our `grid`, we can pass\n", - "`i + 1` into the `index_map` functions to obtain the blocks needed for the\n", - "next iteration. We can then start an asynchronous copy for those blocks.\n", - "Similarly for outputs, we can wait for the outputs of the previous iteration\n", - "to be copied before starting the copy for the current iteration's outputs." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "7Xtz9oMs0ZRL" - }, - "source": [ - "### Parameterizing a pipeline" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "esY4GcIB0bqQ" - }, - "source": [ - "It's common to parameterize the block shapes in our kernel. Block sizes are\n", - "perhaps the most important parameter to tune when optimizing the performance\n", - "of Pallas kernels! They give us control over the pipeline (for example,\n", - "picking smaller blocks adds more iterations to our pipelined loop where each\n", - "iteration has less work to do).\n", - "\n", - "Furthermore, we could also carve up the inputs and outputs along the 2nd\n", - "dimension (we are only splitting along the first right now). Let's write a\n", - "more general kernel that handles both of these features." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "VartelFd0YfY" + "user_tz": 420 + }, + "id": "zcqz1CA_o50a" }, "outputs": [], "source": [ - "def add_matrices_pipelined_2d(\n", - " x: jax.Array, y: jax.Array, *, bm: int = 256, bn: int = 256\n", - ") -> jax.Array:\n", - " m, n = x.shape\n", - " block_spec = pl.BlockSpec((bm, bn), lambda i, j: (i, j))\n", - " return pl.pallas_call(\n", - " add_matrices_kernel,\n", - " out_shape=x,\n", - " in_specs=[block_spec, block_spec],\n", - " out_specs=block_spec,\n", - " grid=(m // bm, n // bn),\n", - " )(x, y)\n", - "\n", - "np.testing.assert_array_equal(\n", - " add_matrices_pipelined_2d(x, y, bm=256, bn=256), x + y\n", - ")\n", - "np.testing.assert_array_equal(\n", - " add_matrices_pipelined_2d(x, y, bm=128, bn=128), x + y\n", - ")\n", - "np.testing.assert_array_equal(\n", - " add_matrices_pipelined_2d(x, y, bm=512, bn=512), x + y\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "KrfeYwaW1QA-" - }, - "source": [ - "## Handling reductions" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "P3SqEKDe3Mar" - }, - "source": [ - "How would you implement something like `jnp.sum` using `pallas_call`?\n", - "Specifically, we'd like to pipeline across the reduction dimension.\n", - "\n", - "Take the example of reducing a `(8, 512, 512)`-shaped array to a\n", - "`(512, 512)`-shaped one." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "JoT-ZKEk1R7l", - "outputId": "fd842223-98a5-4e5c-87fc-5dadc94da4fa" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "Array([[8., 8., 8., ..., 8., 8., 8.],\n", - " [8., 8., 8., ..., 8., 8., 8.],\n", - " [8., 8., 8., ..., 8., 8., 8.],\n", - " ...,\n", - " [8., 8., 8., ..., 8., 8., 8.],\n", - " [8., 8., 8., ..., 8., 8., 8.],\n", - " [8., 8., 8., ..., 8., 8., 8.]], dtype=float32)" - ] - }, - "execution_count": 6, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "x = jnp.ones((8, 512, 512))\n", - "jnp.sum(x, axis=0)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "5O3ByvuT3iyC" - }, - "source": [ - "To do this using `pallas_call`, we could use a grid of size `(8,)` and in\n", - "each iteration `i` load `x[i]` into VMEM.\n", - "Then we could add `x[i]` to an output VMEM buffer. Let's implement this\n", - "naively first." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "hqvv_WRQ3bvP", - "outputId": "200648d2-3f4d-4d1a-b95a-d2c1352cd7b8" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "Array([[9., 9., 9., ..., 9., 9., 9.],\n", - " [9., 9., 9., ..., 9., 9., 9.],\n", - " [9., 9., 9., ..., 9., 9., 9.],\n", - " ...,\n", - " [9., 9., 9., ..., 9., 9., 9.],\n", - " [9., 9., 9., ..., 9., 9., 9.],\n", - " [9., 9., 9., ..., 9., 9., 9.]], dtype=float32)" - ] - }, - "execution_count": 7, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# Warning: this implementation is incorrect!\n", - "\n", - "def naive_sum_kernel(x_ref, o_ref):\n", - " o_ref[...] += x_ref[...]\n", - "\n", - "def naive_sum(x: jax.Array) -> jax.Array:\n", - " grid, *out_shape = x.shape\n", - " return pl.pallas_call(\n", - " naive_sum_kernel,\n", - " grid=grid,\n", - " # None in `block_shape` means we pick a size of 1 and squeeze it away\n", - " in_specs=[pl.BlockSpec((None, *out_shape), lambda i: (i, 0, 0))],\n", - " out_specs=pl.BlockSpec(out_shape, lambda i: (0, 0)),\n", - " out_shape=jax.ShapeDtypeStruct(out_shape, x.dtype),\n", - " )(x)\n", - "naive_sum(x)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Kv9qJYJY4jbK" - }, - "source": [ - "Notice how we've set up the `BlockSpec`s: we're loading the entirety of\n", - "the `(512, 512)` dimension into VMEM (no pipelining there) but selecting\n", - "the `i`-th dimension of `x` each iteration in the `index_map`.\n", - "We are using a `None` for that dimension in the block shape, which indicates\n", - "that we are selecting a singleton dimension from `x` that we would like\n", - "to squeeze away in the kernel.\n", - "Therefore, `x_ref` is `(512, 512)`-shaped in VMEM as well.\n", - "\n", - "`out_spec` uses `lambda i: (0, 0)` as its `index_map`, indicating that\n", - "`o_ref` is unchanged over the course of the pipeline.\n", - "This means that we can update its value each iteration by reading from and\n", - "writing to it. 
Or can it?\n", - "Actually there is one catch: *`o_ref` is initially garbage*, meaning we'll\n", - "be accumulating into garbage.\n", - "This will result in the overall function outputting the incorrect value!\n", - "\n", - "Therefore, **whenever we do a reduction in a kernel, we need to make sure\n", - "to initialize the `Ref` that is storing the reduced value**.\n", - "We can accomplish this by conditionally writing a value to `out_ref`\n", - "when we're on iteration 0.\n", - "We can do this with the helper function `pl.when`, a convenience wrapper\n", - "around `jax.lax.cond`, and `pl.program_id`,\n", - "which queries which iteration in a grid axis we are in." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "JXN2RthX5cSw", - "outputId": "195df19b-a889-479b-95b6-1fb7281f1518" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "Array([[8., 8., 8., ..., 8., 8., 8.],\n", - " [8., 8., 8., ..., 8., 8., 8.],\n", - " [8., 8., 8., ..., 8., 8., 8.],\n", - " ...,\n", - " [8., 8., 8., ..., 8., 8., 8.],\n", - " [8., 8., 8., ..., 8., 8., 8.],\n", - " [8., 8., 8., ..., 8., 8., 8.]], dtype=float32)" - ] - }, - "execution_count": 8, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "def sum_kernel(x_ref, o_ref):\n", - " @pl.when(pl.program_id(axis=0) == 0)\n", - " def _():\n", - " o_ref[...] = jnp.zeros_like(o_ref)\n", + "def hbm_vmem_kernel(x_hbm_ref, out_vmem_ref, scratch_vmem_ref):\n", + " pltpu.sync_copy(x_hbm_ref.at[0:1], scratch_vmem_ref)\n", + " out_vmem_ref[...] = scratch_vmem_ref[...] + 1\n", "\n", - " o_ref[...] += x_ref[...]\n", - "\n", - "def sum(x: jax.Array) -> jax.Array:\n", - " grid, *out_shape = x.shape\n", - " return pl.pallas_call(\n", - " sum_kernel,\n", - " grid=grid,\n", - " # None in `block_shape` means we pick a size of 1 and squeeze it away\n", - " in_specs=[pl.BlockSpec((None, *out_shape), lambda i: (i, 0, 0))],\n", - " out_specs=pl.BlockSpec(out_shape, lambda i: (0, 0)),\n", - " out_shape=jax.ShapeDtypeStruct(out_shape, x.dtype)\n", - " )(x)\n", + "x = jax.random.uniform(jax.random.key(0), (8, 128), jnp.float32)\n", + "out = pl.pallas_call(hbm_vmem_kernel,\n", + " in_specs=[pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY)],\n", + " out_shape=jax.ShapeDtypeStruct((1, 128), jnp.float32),\n", + " scratch_shapes=(pltpu.MemorySpace.VMEM(shape=(1, 128), dtype=jnp.float32),)\n", + ")(x)\n", "\n", - "sum(x)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "2828qXBI5ksZ" - }, - "source": [ - "This `sum` function now outputs the correct values!\n", - "\n", - "One last thing to note about reductions in Pallas are that **they must be\n", - "done in the minormost (rightmost) dimensions of our grid** (our grid is\n", - "1-dimensional in the above example so we are reducing over its minormost\n", - "dimension). This is because the pipeline that Pallas generates using\n", - "the `BlockSpec`s, `grid` and kernel function *does not read outputs back\n", - "from HBM*.\n", - "Once you've written an output value back to HBM you cannot revisit it.\n", - "Therefore, you cannot do a reduction across a grid dimension that has any\n", - "revisiting and therefore all reductions need to happen in the rightmost\n", - "dimensions." 
+ "np.testing.assert_allclose(out, x[0:1] + 1)" ] }, { @@ -655,7 +180,7 @@ "source": [ "(pallas_tpu_megacore)=\n", "\n", - "## TPUs in Megacore configuration" + "### TPUs in Megacore configuration" ] }, { @@ -683,10 +208,20 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 10, "metadata": { + "executionInfo": { + "elapsed": 106, + "status": "ok", + "timestamp": 1744910274556, + "user": { + "displayName": "Justin Fu", + "userId": "17543197034567316452" + }, + "user_tz": 420 + }, "id": "nQNa8RaQ-TR1", - "outputId": "385ed87c-d95c-466c-af77-df3845c979f2" + "outputId": "29c0b574-3528-49a5-8a88-b6987efc69ce" }, "outputs": [ { @@ -701,12 +236,21 @@ " [2., 2., 2., ..., 2., 2., 2.]], dtype=float32)" ] }, - "execution_count": 9, + "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ + "def add_matrices_kernel(x_vmem_ref, y_vmem_ref, z_vmem_ref):\n", + " # Load x and y from VMEM into VREGs\n", + " x_vregs = x_vmem_ref[:, :]\n", + " y_vregs = y_vmem_ref[:, :]\n", + " # Execute a vectorized add\n", + " z_vregs = x_vregs + y_vregs\n", + " # Store the output values in VREGs back into VMEM\n", + " z_vmem_ref[:, :] = z_vregs\n", + "\n", "def add_matrices_pipelined_megacore(x: jax.Array, y: jax.Array) -> jax.Array:\n", " block_spec = pl.BlockSpec((256, 512), lambda i: (i, 0))\n", " return pl.pallas_call(\n", @@ -715,7 +259,8 @@ " in_specs=[block_spec, block_spec],\n", " out_specs=block_spec,\n", " grid=(2,),\n", - " compiler_params=pltpu.TPUCompilerParams(dimension_semantics=(\"parallel\",))\n", + " compiler_params=pltpu.CompilerParams(\n", + " dimension_semantics=(\"parallel\",))\n", " )(x, y)\n", "\n", "x, y = jnp.ones((512, 512)), jnp.ones((512, 512))\n", @@ -737,28 +282,16 @@ "\n", "> Note that Megacore is only currently available on TPU `v4` and TPU `v5p`. Supplying `dimension_semantics` annotations is a no-op on other platforms, but *not* specifying it will result in only one TensorCore being used (even if there are more than one available)." ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "1ZJ2rV5W8FAe" - }, - "source": [ - "## Conclusion\n", - "\n", - "In this guide we covered how to express TPU pipelines using `pallas_call`,\n", - "`grid` and `BlockSpec`s. We covered how to express nested loops via a\n", - "multi-dimensional grid and how to handle reductions by initialize our\n", - "accumulators at the beginning of the reduction.\n", - "We also learned how to handle Megacore by adding annotations to the kernel.\n", - "\n", - "Exercises left to the reader:\n", - "* Try implementing a `sum` kernel that pipelines the other dimensions as well\n", - "* Add megacore support to the `add` kernel and the `sum` kernel as well." - ] } ], "metadata": { + "colab": { + "last_runtime": { + "build_target": "//experimental/users/justinfu/pallas:colab", + "kind": "private" + }, + "provenance": [] + }, "jupytext": { "formats": "ipynb,md:myst" }, diff --git a/docs/pallas/tpu/pipelining.md b/docs/pallas/tpu/pipelining.md index df570cf0806c..44a252410151 100644 --- a/docs/pallas/tpu/pipelining.md +++ b/docs/pallas/tpu/pipelining.md @@ -11,22 +11,33 @@ kernelspec: name: python3 --- ++++ {"id": "7704d3bb"} + (pallas_tpu_pipelining)= +++ {"id": "teoJ_fUwlu0l"} -# Pipelining +# TPU Pipelining +++ {"id": "gAJDZh1gBh-h"} -In this guide we'll cover how memory spaces in TPU work and how to write -pipelines in Pallas that overlap memory I/O with compute. +This guide serves as a reference for TPU-specific pipelining concerns. 
+We'll review the memory hierarchy and compute units on TPUs, and TPU-specific features of the pipelining API. For a more general-purpose overview of pipelining, see the {ref}`pallas_software_pipelining`. ```{code-cell} -:id: ejAVO6ikUUuF - +--- +executionInfo: + elapsed: 54 + status: ok + timestamp: 1744908474512 + user: + displayName: Justin Fu + userId: '17543197034567316452' + user_tz: 420 +id: ejAVO6ikUUuF +--- #@title Imports import jax @@ -36,13 +47,13 @@ import jax.numpy as jnp import numpy as np ``` -+++ {"id": "TWKESTKAlyjT"} ++++ {"id": "0e212a5e"} (tpu_and_its_memory_spaces)= ## TPU and its memory spaces -+++ ++++ {"id": "NnWW9GV4kW6P"} A TPU and its TensorCore consist of memory spaces (where arrays can reside), registers (which temporarily store scalar and array values) and compute units @@ -65,384 +76,71 @@ Let's talk about the components of this diagram in more detail: Values can be loaded into memory from their respective caches (VMEM for VREGs and SMEM for SREGs). * **Compute units**: A TensorCore has a scalar unit, vector unit (VPU) and - matrix unit (MXU) that can do numerical computation. + matrix unit (MXU) that can do numerical computation. Each of these compute units can operate asynchronously, but this is managed by the TPU compiler and thus from the programmer's perspective a TPU program is single-threaded. Compute units operate on values that live in SREGs and VREGs and output values into those registers as well. -In order to do a vectorized computation on our values `x` and `y` that live -in HBM, we need to: - -1. Copy the values `x` and `y` into VMEM. -2. Load the values from VMEM into VREGs. -3. Execute the computation using the VPU or MXU, storing the output in VREGs. -4. Store the values in the output VREGs into VMEM. -5. Copy the output values in VMEM back to HBM. - -+++ {"id": "TzctMbNsn3vc"} - -Let's implement a Pallas function that does just that! - -```{code-cell} -:id: 2IXQxNWrKJyb -:outputId: d62eb493-5f92-4496-f113-d3cd24cb0b9f - -def add_matrices_kernel(x_vmem_ref, y_vmem_ref, z_vmem_ref): - # Load x and y from VMEM into VREGs - x_vregs = x_vmem_ref[:, :] - y_vregs = y_vmem_ref[:, :] - # Execute a vectorized add - z_vregs = x_vregs + y_vregs - # Store the output values in VREGs back into VMEM - z_vmem_ref[:, :] = z_vregs - - -def add_matrices(x: jax.Array, y: jax.Array) -> jax.Array: - # pallas_call will first allocate scratch buffers for `x` and `y` in VMEM. - # It will then copy `x` and `y` from HBM into VMEM. - z = pl.pallas_call( - add_matrices_kernel, out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype) - )(x, y) - # pallas_call will also copy the output from VMEM back into HBM. - return z - - -x, y = jnp.ones((512, 512)), jnp.ones((512, 512)) -add_matrices(x, y) -``` - -+++ {"id": "HMENNLy8okCL"} - -We've written two functions: `add_matrices_kernel` and `add_matrices`. - -`add_matrices_kernel` operates using `Ref`s that live in VMEM. -Loading from a VMEM `Ref` produces a value that lives in VREGs. -Values in VREGs behave like `jax.Array`s in that we can use `jnp` and -`jax.lax` operations on them to produce new values that live in VREGs. -When we produce the values we'd like to return, we store them in the output -VMEM `Ref`. - -The `add_matrices` function acts on `jax.Array`s and returns a `jax.Array`. -Inside it, we pass `x` and `y` into `pallas_call`. -`pallas_call` is responsible for copying `x` and `y` into VMEM and for -allocating the VMEM buffers that the kernel operates on (including allocating -`z_vmem_ref`, the output VMEM buffer). 
-After the kernel function is finished running, `pallas_call` will also copy -the value in `z_vmem_ref` to HBM, resulting in an output `jax.Array`. - -+++ {"id": "5kWr-1tKpYro"} - -## Constraints of using VMEM/SMEM - -Pallas exposes access to lower level memory spaces like VMEM and SMEM but -writing kernels utilizing them adds some considerations. - -1. Memory capacity. VMEM and SMEM are *small*! VMEM on v4 TPUs is only 16MiB - and SMEM ranges in the tens to hundreds of KiB. - If our arrays are too big, we won't even be able to fit them into VMEM at all. - For reference, a `f32[2048, 2048]` array is 16MiB, so our above kernel won't - scale beyond moderately sized arrays. - -2. Memory bandwidth. Copying to/from HBM and VMEM takes a long time, at least - compared to most compute instructions. - The `add_matrices` function above will likely spend more time copying - between HBM and VMEM than actually performing the addition itself. - -With these two constraints in mind, we'll have to rethink our strategy for -getting performance out of our TPUs. - -+++ {"id": "_NTqvlbetB3P"} - -## Primer: Pipelining - -Pipelining our computation offers a way of dealing with both the memory -capacity and bandwidth constraints in one fell swoop. -What do we mean by pipelining? - -The goal is: *in parallel* copy to/from HBM and VMEM *while* utilizing our -compute units. -Naively this is difficult because in our program above we copy *all* of `x` -and `y` before we start doing any compute with them, creating a dependence -between the copy and the compute. - -However, if we can chunk up our computation into several subcomputations -(e.g. when we add two matrices, we can express that as addition of "blocks" -of the original matrices together), we can now overlap the copies of one of -those subcomputations with the compute of the other. Let's walk through a -simple example: - -Let's say we split our arrays `x` and `y` into `x1, x2` and `y1, y2` (for -example, split along the leading axis, resulting in two `(256, 512)` arrays -for each input. -We can now execute the following pipelined computation. - -1. Copy `x1` and `y1` into VMEM. -1. Start copying `x2` and `y2` into VMEM -2. Load `x1, y1` from VMEM into VREGs. -3. Execute the `z1 = x1 + y1` using the compute units. -4. Store `z1` into VMEM. -5. Start copying `z1` from VMEM back into HBM. -6. Wait until `x2, y2` have been copied into VMEM. -7. Load `x2, y2` from VMEM into VREGs. -8. Execute the `z2 = x2 + y2` using the compute units. -9. Store `z2` into VMEM. -10. Wait until `z1` is copied into HBM. -10. Start copying `z2` from VMEM back into HBM. -10. Wait until `z2` is copied into HBM. - -Any time we are doing compute here, we are asynchronously copying something. -This means that some of the time spent copying is not wasted. - -The two most important numbers for determining how efficient a pipelined -computation are a) how many floating point operations (FLOPs) we need to -execute and b) how many bytes we need to copy to execute that computation. -The ratio of these two (FLOPs/memory usage) is called the -*arithmetic intensity* of an operation and determines if our pipeline will -be compute bound or memory bound. - -+++ {"id": "gutx7y8uvZKH"} - -## Pipelining in Pallas - -+++ {"id": "U-dPTjlBverB"} - -How do we implement a pipeline like the one above in Pallas? -It seems like a complex sequence of asynchronous data operations and -executing kernels that would be a pain to implement manually. -Fear not! 
Pallas offers an API for expressing pipelines without too much -boilerplate, namely through `grid`s and `BlockSpec`s. - -See how in the above pipelined example, we are executing the same logic -multiple times: steps 3-5 and 8-10 both execute the same operations, -only on different inputs. -The {func}`jax.experimental.pallas.pallas_call` provides a way to -execute a kernel multiple times, by using the `grid` argument. -See {ref}`pallas_grid`. - -We also use {class}`jax.experimental.pallas.BlockSpec` to specify -how to construct the input of each kernel invocation. -See {ref}`pallas_blockspec`. - -In the pipelining example above, we had `(512, 512)`-shaped arrays and -split them along the leading dimension into two `(256, 512)`-shaped arrays. -In this pipeline, our `BlockSpec.block_shape` would be `(256, 512)`. -On the 1st iteration we'd -like to select `x1` and on the second iteration we'd like to use `x2`. -This can be expressed with the following `index_map`: - -```python -def x_index_map(i): - return (i, 0) -``` - -We'd then construct the `BlockSpec`: -```python -block_spec = pl.BlockSpec((256, 512), x_index_map) -``` - -The `BlockSpec`s for `y` and `z` will be the same as the one for `x`. - -+++ {"id": "noybOKghzjwG"} - -### Putting it together - -We provide these arguments to `pallas_call` via `grid`, `in_specs` and -`out_specs` (`in_specs` corresponds to the tuple of positional arguments, -and `out_specs` corresponds to the output). - -```{code-cell} -:id: ehKAYAwIojfv -:outputId: 504bab29-83f3-4e1f-8664-1860ad15b6de - -def add_matrices_pipelined(x: jax.Array, y: jax.Array) -> jax.Array: - block_spec = pl.BlockSpec((256, 512), lambda i: (i, 0)) - return pl.pallas_call( - add_matrices_kernel, - out_shape=x, - in_specs=[block_spec, block_spec], - out_specs=block_spec, - grid=(2,) - )(x, y) - -add_matrices_pipelined(x, y) -``` - -+++ {"id": "rkytgIZYzz4t"} - -We've only added a little bit of code to our original function to add -automatic pipelining but the `BlockSpec`s and `grid` do a lot of heavy -lifting! - -How does it work? Well, the `BlockSpec`s provide enough information to start -*prefetching* blocks of our input from HBM into VMEM. -For example, if we are starting iteration `i` of our `grid`, we can pass -`i + 1` into the `index_map` functions to obtain the blocks needed for the -next iteration. We can then start an asynchronous copy for those blocks. -Similarly for outputs, we can wait for the outputs of the previous iteration -to be copied before starting the copy for the current iteration's outputs. ++++ {"id": "8Tl3wt5Wk3Ek"} -+++ {"id": "7Xtz9oMs0ZRL"} +## TPU-specific Pipelining Features -### Parameterizing a pipeline +Pallas TPU supports the following platform-specific features. -+++ {"id": "esY4GcIB0bqQ"} ++++ {"id": "1jg5WmExk47l"} -It's common to parameterize the block shapes in our kernel. Block sizes are -perhaps the most important parameter to tune when optimizing the performance -of Pallas kernels! They give us control over the pipeline (for example, -picking smaller blocks adds more iterations to our pipelined loop where each -iteration has less work to do). +### TPU Memory Spaces -Furthermore, we could also carve up the inputs and outputs along the 2nd -dimension (we are only splitting along the first right now). Let's write a -more general kernel that handles both of these features. +Pallas exposes all levels of the TPU memory hierarchy to users. 
The following table maps from Pallas TPU memory spaces to their standard memory types (DRAM/SRAM): -```{code-cell} -:id: VartelFd0YfY - -def add_matrices_pipelined_2d( - x: jax.Array, y: jax.Array, *, bm: int = 256, bn: int = 256 -) -> jax.Array: - m, n = x.shape - block_spec = pl.BlockSpec((bm, bn), lambda i, j: (i, j)) - return pl.pallas_call( - add_matrices_kernel, - out_shape=x, - in_specs=[block_spec, block_spec], - out_specs=block_spec, - grid=(m // bm, n // bn), - )(x, y) - -np.testing.assert_array_equal( - add_matrices_pipelined_2d(x, y, bm=256, bn=256), x + y -) -np.testing.assert_array_equal( - add_matrices_pipelined_2d(x, y, bm=128, bn=128), x + y -) -np.testing.assert_array_equal( - add_matrices_pipelined_2d(x, y, bm=512, bn=512), x + y -) -``` - -+++ {"id": "KrfeYwaW1QA-"} - -## Handling reductions +| Pallas Enum | TPU Memory Space | Type (DRAM/SRAM) | +| --- | --- | --- | +| `pltpu.MemorySpace.ANY` | HBM (usually) or VMEM | DRAM | +| `pltpu.MemorySpace.VMEM` | VMEM | SRAM | +| `pltpu.MemorySpace.SMEM` | SMEM | SRAM | +| `pltpu.MemorySpace.SEMAPHORE` | Semaphore | SRAM | -+++ {"id": "P3SqEKDe3Mar"} +- `MemorySpace.VMEM` denotes vector SRAM. It is the default memory space if nothing is specified. +- `MemorySpace.SMEM` denotes scalar SRAM. Only scalar loads and stores can be performed to/from SMEM. +- `MemorySpace.ANY` is a hint to the compiler that the memory space is unconstrained. In most cases, XLA will place this buffer in HBM. A buffer assigned to the `ANY` memory space cannot be dereferenced normally using array indexing syntax (e.g. `x[...]`). Instead, we must first copy the values into a VMEM or SMEM buffer using `pltpu.sync_copy` or `pltpu.async_copy`. +- `MemorySpace.SEMAPHORE` is used to allocate semaphores for constructing barriers or tracking asynchronous operations. It is also possible to return semaphores from the kernel for building asynchronous kernels - this is an experimental feature; see {ref}`pallas_async` for more details. -How would you implement something like `jnp.sum` using `pallas_call`? -Specifically, we'd like to pipeline across the reduction dimension. +Pipelining on TPUs is typically done between HBM (DRAM) to VMEM (Vector SRAM). The default behavior for `pallas_call` on TPU is that arguments to `pallas_call` are assumed to live in HBM, and inputs to the user kernel body are stored in VMEM. -Take the example of reducing a `(8, 512, 512)`-shaped array to a -`(512, 512)`-shaped one. - -```{code-cell} -:id: JoT-ZKEk1R7l -:outputId: fd842223-98a5-4e5c-87fc-5dadc94da4fa - -x = jnp.ones((8, 512, 512)) -jnp.sum(x, axis=0) -``` - -+++ {"id": "5O3ByvuT3iyC"} - -To do this using `pallas_call`, we could use a grid of size `(8,)` and in -each iteration `i` load `x[i]` into VMEM. -Then we could add `x[i]` to an output VMEM buffer. Let's implement this -naively first. - -```{code-cell} -:id: hqvv_WRQ3bvP -:outputId: 200648d2-3f4d-4d1a-b95a-d2c1352cd7b8 - -# Warning: this implementation is incorrect! - -def naive_sum_kernel(x_ref, o_ref): - o_ref[...] += x_ref[...] 
- -def naive_sum(x: jax.Array) -> jax.Array: - grid, *out_shape = x.shape - return pl.pallas_call( - naive_sum_kernel, - grid=grid, - # None in `block_shape` means we pick a size of 1 and squeeze it away - in_specs=[pl.BlockSpec((None, *out_shape), lambda i: (i, 0, 0))], - out_specs=pl.BlockSpec(out_shape, lambda i: (0, 0)), - out_shape=jax.ShapeDtypeStruct(out_shape, x.dtype), - )(x) -naive_sum(x) -``` +While not specific to pipelining, it is possible to gain manual control over the memory space of input and output buffers, you can specify the `memory_space` argument on a `BlockSpec`. Note that pipelining is not allowed unless the `memory_space` is marked as `VMEM`. Memory spaces can also be used to specify scratch arguments to a kernel via the `scratch_shapes` argument on `pallas_call`. Scratch buffers are persistent across kernel iterations and are useful for storing intermediate results such as partial accumulations and reductions. A scratch buffer must reside in `VMEM`, `SMEM`, or `SEMAPHORE`. -+++ {"id": "Kv9qJYJY4jbK"} - -Notice how we've set up the `BlockSpec`s: we're loading the entirety of -the `(512, 512)` dimension into VMEM (no pipelining there) but selecting -the `i`-th dimension of `x` each iteration in the `index_map`. -We are using a `None` for that dimension in the block shape, which indicates -that we are selecting a singleton dimension from `x` that we would like -to squeeze away in the kernel. -Therefore, `x_ref` is `(512, 512)`-shaped in VMEM as well. - -`out_spec` uses `lambda i: (0, 0)` as its `index_map`, indicating that -`o_ref` is unchanged over the course of the pipeline. -This means that we can update its value each iteration by reading from and -writing to it. Or can it? -Actually there is one catch: *`o_ref` is initially garbage*, meaning we'll -be accumulating into garbage. -This will result in the overall function outputting the incorrect value! - -Therefore, **whenever we do a reduction in a kernel, we need to make sure -to initialize the `Ref` that is storing the reduced value**. -We can accomplish this by conditionally writing a value to `out_ref` -when we're on iteration 0. -We can do this with the helper function `pl.when`, a convenience wrapper -around `jax.lax.cond`, and `pl.program_id`, -which queries which iteration in a grid axis we are in. +As an example for using multiple manual memory space assignments in a kernel, the following program copies a slice of an HBM buffer `x_hbm_ref` into a scratch VMEM buffer `scratch_vmem_ref` before using it for arithmetic and storing the result into an output VMEM buffer: ```{code-cell} -:id: JXN2RthX5cSw -:outputId: 195df19b-a889-479b-95b6-1fb7281f1518 - -def sum_kernel(x_ref, o_ref): - @pl.when(pl.program_id(axis=0) == 0) - def _(): - o_ref[...] = jnp.zeros_like(o_ref) - - o_ref[...] += x_ref[...] - -def sum(x: jax.Array) -> jax.Array: - grid, *out_shape = x.shape - return pl.pallas_call( - sum_kernel, - grid=grid, - # None in `block_shape` means we pick a size of 1 and squeeze it away - in_specs=[pl.BlockSpec((None, *out_shape), lambda i: (i, 0, 0))], - out_specs=pl.BlockSpec(out_shape, lambda i: (0, 0)), - out_shape=jax.ShapeDtypeStruct(out_shape, x.dtype) - )(x) - -sum(x) +--- +executionInfo: + elapsed: 65 + status: ok + timestamp: 1744908591430 + user: + displayName: Justin Fu + userId: '17543197034567316452' + user_tz: 420 +id: zcqz1CA_o50a +--- +def hbm_vmem_kernel(x_hbm_ref, out_vmem_ref, scratch_vmem_ref): + pltpu.sync_copy(x_hbm_ref.at[0:1], scratch_vmem_ref) + out_vmem_ref[...] 
= scratch_vmem_ref[...] + 1 + +x = jax.random.uniform(jax.random.key(0), (8, 128), jnp.float32) +out = pl.pallas_call(hbm_vmem_kernel, + in_specs=[pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY)], + out_shape=jax.ShapeDtypeStruct((1, 128), jnp.float32), + scratch_shapes=(pltpu.MemorySpace.VMEM(shape=(1, 128), dtype=jnp.float32),) +)(x) + +np.testing.assert_allclose(out, x[0:1] + 1) ``` -+++ {"id": "2828qXBI5ksZ"} - -This `sum` function now outputs the correct values! - -One last thing to note about reductions in Pallas are that **they must be -done in the minormost (rightmost) dimensions of our grid** (our grid is -1-dimensional in the above example so we are reducing over its minormost -dimension). This is because the pipeline that Pallas generates using -the `BlockSpec`s, `grid` and kernel function *does not read outputs back -from HBM*. -Once you've written an output value back to HBM you cannot revisit it. -Therefore, you cannot do a reduction across a grid dimension that has any -revisiting and therefore all reductions need to happen in the rightmost -dimensions. - +++ {"id": "KvPFez9N8cKJ"} (pallas_tpu_megacore)= -## TPUs in Megacore configuration +### TPUs in Megacore configuration +++ {"id": "0f4HAVzQ8n71"} @@ -463,8 +161,26 @@ We can indicate which dimensions are parallelizable by providing an annotation to `pallas_call` called `dimension_semantics`. ```{code-cell} -:id: nQNa8RaQ-TR1 -:outputId: 385ed87c-d95c-466c-af77-df3845c979f2 +--- +executionInfo: + elapsed: 106 + status: ok + timestamp: 1744910274556 + user: + displayName: Justin Fu + userId: '17543197034567316452' + user_tz: 420 +id: nQNa8RaQ-TR1 +outputId: 29c0b574-3528-49a5-8a88-b6987efc69ce +--- +def add_matrices_kernel(x_vmem_ref, y_vmem_ref, z_vmem_ref): + # Load x and y from VMEM into VREGs + x_vregs = x_vmem_ref[:, :] + y_vregs = y_vmem_ref[:, :] + # Execute a vectorized add + z_vregs = x_vregs + y_vregs + # Store the output values in VREGs back into VMEM + z_vmem_ref[:, :] = z_vregs def add_matrices_pipelined_megacore(x: jax.Array, y: jax.Array) -> jax.Array: block_spec = pl.BlockSpec((256, 512), lambda i: (i, 0)) @@ -474,7 +190,8 @@ def add_matrices_pipelined_megacore(x: jax.Array, y: jax.Array) -> jax.Array: in_specs=[block_spec, block_spec], out_specs=block_spec, grid=(2,), - compiler_params=pltpu.TPUCompilerParams(dimension_semantics=("parallel",)) + compiler_params=pltpu.CompilerParams( + dimension_semantics=("parallel",)) )(x, y) x, y = jnp.ones((512, 512)), jnp.ones((512, 512)) @@ -491,17 +208,3 @@ simultaneously on each TensorCore. Pallas will handle splitting up the grid automatically. > Note that Megacore is only currently available on TPU `v4` and TPU `v5p`. Supplying `dimension_semantics` annotations is a no-op on other platforms, but *not* specifying it will result in only one TensorCore being used (even if there are more than one available). - -+++ {"id": "1ZJ2rV5W8FAe"} - -## Conclusion - -In this guide we covered how to express TPU pipelines using `pallas_call`, -`grid` and `BlockSpec`s. We covered how to express nested loops via a -multi-dimensional grid and how to handle reductions by initialize our -accumulators at the beginning of the reduction. -We also learned how to handle Megacore by adding annotations to the kernel. - -Exercises left to the reader: -* Try implementing a `sum` kernel that pipelines the other dimensions as well -* Add megacore support to the `add` kernel and the `sum` kernel as well. 
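+
+The `dimension_semantics` tuple has one entry per grid dimension, so the same
+annotation extends to multi-dimensional grids. As a minimal illustrative sketch
+(reusing `add_matrices_kernel` and the `(512, 512)` inputs above; the block
+sizes, the helper name, and the "parallel"/"arbitrary" split are arbitrary
+choices for illustration, not a prescribed configuration), the leading grid
+axis can be partitioned across TensorCores while the trailing axis runs
+sequentially. For an elementwise add both axes could equally be marked
+"parallel".
+
+```python
+def add_matrices_pipelined_2d_megacore(
+    x: jax.Array, y: jax.Array, *, bm: int = 256, bn: int = 256
+) -> jax.Array:
+  m, n = x.shape
+  block_spec = pl.BlockSpec((bm, bn), lambda i, j: (i, j))
+  return pl.pallas_call(
+      add_matrices_kernel,
+      out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
+      in_specs=[block_spec, block_spec],
+      out_specs=block_spec,
+      grid=(m // bm, n // bn),
+      # One semantics entry per grid axis: split the row-block axis across
+      # TensorCores, run the column-block axis sequentially.
+      compiler_params=pltpu.CompilerParams(
+          dimension_semantics=("parallel", "arbitrary")),
+  )(x, y)
+
+np.testing.assert_array_equal(add_matrices_pipelined_2d_megacore(x, y), x + y)
+```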
diff --git a/docs/pallas/tpu/sparse.ipynb b/docs/pallas/tpu/sparse.ipynb index ac3a0dad2404..6834f2d7d930 100644 --- a/docs/pallas/tpu/sparse.ipynb +++ b/docs/pallas/tpu/sparse.ipynb @@ -62,7 +62,7 @@ "source": [ "## Dynamic Block Indexing with Scalar Prefetch\n", "\n", - "We will be exploiting the \"scalar prefetch\" feature of Pallas to enable us to write sparse kernels. Scalar prefetch allows you to pass in a small amount of data into SMEM (\"scalar memory\") that is loaded before the start of the pipeline (\"prefetch\"). Because this data is loaded before the pipeline, it is available for use in the `index_map` for each BlockSpec, allowing the you to perform data-dependent indexing calculations. The main goal of this tutorial is to go over common programming patterns that utilize this feature.\n", + "We will be exploiting the \"scalar prefetch\" feature of Pallas to enable us to write sparse kernels. Scalar prefetch allows you to pass in a small amount of data into SMEM (\"scalar memory\") that is loaded before the start of the pipeline (\"prefetch\"). Because this data is loaded before the pipeline, it is available for use in the `index_map` for each BlockSpec, allowing you to perform data-dependent indexing calculations. The main goal of this tutorial is to go over common programming patterns that utilize this feature.\n", "\n", "To use scalar prefetch, use `pltpu.PrefetchScalarGridSpec` in place of the standard `pl.GridSpec`:\n", "\n", @@ -253,13 +253,13 @@ "source": [ "## Example: Sparse @ Dense Matrix Multiplication\n", "\n", - "In our first example, we will multiple a sparse LHS matrix with a dense RHS matrix to produce a dense output.\n", + "In our first example, we will multiply a sparse LHS matrix with a dense RHS matrix to produce a dense output.\n", "\n", "We will structure our kernel grid with 2 loops - the outer loop over the columns of the RHS/output, and inner loop over the sparse blocks of the LHS. During each inner loop iteration, we load one block from the LHS and lookup the corresponding block on in the RHS using the block index of the contracting dimension (K). We multiply the two blocks together and accumulate into the correct output block. One outer loop iteration will compute a result for an entire column as depicted by the following diagram:\n", "\n", "![sparse_matmul](../../_static/pallas/sparse/sparse_matmul.svg)\n", "\n", - "It is important that we group the block indices by row (e.g. `[0, 0, 1, 2, 3, 3]`) before we pass them into the kernel for two reasons. First, in our kernel we need to know when to initially zero-out the accumulator in the output ref, and it is easy to do so if the row indices are grouped. Second, the pipelining logic for Pallas does not allow us to re-visit blocks in the output `Ref` on non-consecutive iterations, and therefore we need to do all accumulation into an output block in consecutive kernel iterations. This is because the pipeline emitter will realize that we loading the same output block on consecutive iterations and keep the block in VMEM. When we change output block Pallas will finally store the output into HBM and assume we never touch it again. Failure to access output blocks consecutively will result in incorrect values even though the kernel is otherwise logically correct." + "It is important that we group the block indices by row (e.g. `[0, 0, 1, 2, 3, 3]`) before we pass them into the kernel for two reasons. 
First, in our kernel we need to know when to initially zero-out the accumulator in the output ref, and it is easy to do so if the row indices are grouped. Second, the pipelining logic for Pallas does not allow us to re-visit blocks in the output `Ref` on non-consecutive iterations, and therefore we need to do all accumulation into an output block in consecutive kernel iterations. This is because the pipeline emitter will realize that we are loading the same output block on consecutive iterations and keep the block in VMEM. When we change output block Pallas will finally store the output into HBM and assume we never touch it again. Failure to access output blocks consecutively will result in incorrect values even though the kernel is otherwise logically correct." ] }, { @@ -437,7 +437,7 @@ "\n", "In our previous example we considered the case when the data itself is sparse. This manifested itself in the kernel structure as a dimension in the kernel grid that was dynamic and looped over the number of nonzero blocks (`num_blocks`).\n", "\n", - "A second useful programming pattern emerges when the underlying is data is dense, but we wish to perform sparse computation over it. Our kernel grid in this case will be dense, but we wish to skip over some blocks in the grid as indicated by a block-sparse mask. This type of programming pattern is commonly arises when using masks in many machine learning applications, such as causal or local masks in self-attention. In these cases, we can entirely skip over computation in blocks where the mask is zeroed-out. Examples of this programming pattern can be found in the Splash Attention and Grouped Matrix Multiplication kernels located in `jax/experimental/pallas/ops/tpu`, or in PyTorch's [FlexAttention](https://pytorch.org/blog/flexattention/).\n", + "A second useful programming pattern emerges when the underlying data is dense, but we wish to perform sparse computation over it. Our kernel grid in this case will be dense, but we wish to skip over some blocks in the grid as indicated by a block-sparse mask. This type of programming pattern commonly arises when using masks in many machine learning applications, such as causal or local masks in self-attention. In these cases, we can entirely skip over computation in blocks where the mask is zeroed-out. Examples of this programming pattern can be found in the Splash Attention and Grouped Matrix Multiplication kernels located in `jax/experimental/pallas/ops/tpu`, or in PyTorch's [FlexAttention](https://pytorch.org/blog/flexattention/).\n", "\n", "The main performance consideration with dealing with a sparse access pattern on dense data is the interaction with pipelining. On any given kernel iteration, the Pallas pipeline emitter will attempt to prefetch the next block of data by calling the `index_map` for each `BlockSpec` on the next iteration of the grid. However, if our computation is sparse we may be skipping the computation for the next block in the grid, so we need some method to tell the pipeline instead begin fetching the *next block that we are not skipping*. In order to do this, we need to construct *prefetch maps* which contains indices to the next non-skipped block of data for each kernel input. 
The following diagram illustrates how a prefetch map could be constructed for a block-sparse mask that is stored in a COO-like format.\n", "\n", @@ -491,7 +491,7 @@ "source": [ "def sparsify_mask(mask: jax.Array,\n", " block_shape: tuple[int, int]):\n", - " \"\"\"Preprocesses a mask into a sparse reprentation.\n", + " \"\"\"Preprocesses a mask into a sparse representation.\n", "\n", " Args:\n", " mask: A boolean array of shape [M, N]\n", @@ -511,7 +511,6 @@ " block_mask = jnp.zeros((M // bm, N // bn), dtype=mask.dtype)\n", " mask_types_finder = []\n", " mask_data = []\n", - " mask_type_idxs = []\n", "\n", " next_mask_type_idx = 0\n", " prefetch_mask = jnp.zeros_like(block_mask)\n", @@ -536,7 +535,6 @@ " next_j = j\n", " else:\n", " type_index = -1\n", - " mask_type_idxs.append(type_index)\n", " block_mask = block_mask.at[i, j].set(is_nonzero)\n", " prefetch_mask = prefetch_mask.at[i, j].set(next_mask_type_idx)\n", " prefetch_i = prefetch_i.at[i, j].set(next_i)\n", @@ -665,7 +663,7 @@ "\n", "We would generally expect performance to get closer to the theoretical peak as our inputs get larger, since a few of the main reasons why we don't exactly reach theoretical performance are:\n", "- We skip slightly less than half of computation since the blocks along the diagonal are mixed 0s and 1s, and for mixed blocks we need to compute the entire block. With larger inputs, our overhead for mixed blocks becomes smaller relative to the overall computation.\n", - "- The pipeline bubble also becomes accounts for a less percentage of the overall runtime as inputs become larger." + "- The pipeline bubble also accounts for a less percentage of the overall runtime as inputs become larger." ] }, { diff --git a/docs/pallas/tpu/sparse.md b/docs/pallas/tpu/sparse.md index 113f31d8bab2..e9a4bb143a2f 100644 --- a/docs/pallas/tpu/sparse.md +++ b/docs/pallas/tpu/sparse.md @@ -51,7 +51,7 @@ print("Running on", jax.devices()[0].device_kind) ## Dynamic Block Indexing with Scalar Prefetch -We will be exploiting the "scalar prefetch" feature of Pallas to enable us to write sparse kernels. Scalar prefetch allows you to pass in a small amount of data into SMEM ("scalar memory") that is loaded before the start of the pipeline ("prefetch"). Because this data is loaded before the pipeline, it is available for use in the `index_map` for each BlockSpec, allowing the you to perform data-dependent indexing calculations. The main goal of this tutorial is to go over common programming patterns that utilize this feature. +We will be exploiting the "scalar prefetch" feature of Pallas to enable us to write sparse kernels. Scalar prefetch allows you to pass in a small amount of data into SMEM ("scalar memory") that is loaded before the start of the pipeline ("prefetch"). Because this data is loaded before the pipeline, it is available for use in the `index_map` for each BlockSpec, allowing you to perform data-dependent indexing calculations. The main goal of this tutorial is to go over common programming patterns that utilize this feature. To use scalar prefetch, use `pltpu.PrefetchScalarGridSpec` in place of the standard `pl.GridSpec`: @@ -208,13 +208,13 @@ def generate_block_sparse_mat(key, M, N, blk_M, blk_N, p=0.2, dtype=jnp.float32) ## Example: Sparse @ Dense Matrix Multiplication -In our first example, we will multiple a sparse LHS matrix with a dense RHS matrix to produce a dense output. +In our first example, we will multiply a sparse LHS matrix with a dense RHS matrix to produce a dense output. 
We will structure our kernel grid with 2 loops - the outer loop over the columns of the RHS/output, and inner loop over the sparse blocks of the LHS. During each inner loop iteration, we load one block from the LHS and lookup the corresponding block on in the RHS using the block index of the contracting dimension (K). We multiply the two blocks together and accumulate into the correct output block. One outer loop iteration will compute a result for an entire column as depicted by the following diagram: ![sparse_matmul](../../_static/pallas/sparse/sparse_matmul.svg) -It is important that we group the block indices by row (e.g. `[0, 0, 1, 2, 3, 3]`) before we pass them into the kernel for two reasons. First, in our kernel we need to know when to initially zero-out the accumulator in the output ref, and it is easy to do so if the row indices are grouped. Second, the pipelining logic for Pallas does not allow us to re-visit blocks in the output `Ref` on non-consecutive iterations, and therefore we need to do all accumulation into an output block in consecutive kernel iterations. This is because the pipeline emitter will realize that we loading the same output block on consecutive iterations and keep the block in VMEM. When we change output block Pallas will finally store the output into HBM and assume we never touch it again. Failure to access output blocks consecutively will result in incorrect values even though the kernel is otherwise logically correct. +It is important that we group the block indices by row (e.g. `[0, 0, 1, 2, 3, 3]`) before we pass them into the kernel for two reasons. First, in our kernel we need to know when to initially zero-out the accumulator in the output ref, and it is easy to do so if the row indices are grouped. Second, the pipelining logic for Pallas does not allow us to re-visit blocks in the output `Ref` on non-consecutive iterations, and therefore we need to do all accumulation into an output block in consecutive kernel iterations. This is because the pipeline emitter will realize that we are loading the same output block on consecutive iterations and keep the block in VMEM. When we change output block Pallas will finally store the output into HBM and assume we never touch it again. Failure to access output blocks consecutively will result in incorrect values even though the kernel is otherwise logically correct. ```{code-cell} --- @@ -353,7 +353,7 @@ print("Reference: %.3f ms (avg over %d trials)" % (time * 1000, n_trials)) In our previous example we considered the case when the data itself is sparse. This manifested itself in the kernel structure as a dimension in the kernel grid that was dynamic and looped over the number of nonzero blocks (`num_blocks`). -A second useful programming pattern emerges when the underlying is data is dense, but we wish to perform sparse computation over it. Our kernel grid in this case will be dense, but we wish to skip over some blocks in the grid as indicated by a block-sparse mask. This type of programming pattern is commonly arises when using masks in many machine learning applications, such as causal or local masks in self-attention. In these cases, we can entirely skip over computation in blocks where the mask is zeroed-out. Examples of this programming pattern can be found in the Splash Attention and Grouped Matrix Multiplication kernels located in `jax/experimental/pallas/ops/tpu`, or in PyTorch's [FlexAttention](https://pytorch.org/blog/flexattention/). 
+A second useful programming pattern emerges when the underlying data is dense, but we wish to perform sparse computation over it. Our kernel grid in this case will be dense, but we wish to skip over some blocks in the grid as indicated by a block-sparse mask. This type of programming pattern commonly arises when using masks in many machine learning applications, such as causal or local masks in self-attention. In these cases, we can entirely skip over computation in blocks where the mask is zeroed-out. Examples of this programming pattern can be found in the Splash Attention and Grouped Matrix Multiplication kernels located in `jax/experimental/pallas/ops/tpu`, or in PyTorch's [FlexAttention](https://pytorch.org/blog/flexattention/). The main performance consideration with dealing with a sparse access pattern on dense data is the interaction with pipelining. On any given kernel iteration, the Pallas pipeline emitter will attempt to prefetch the next block of data by calling the `index_map` for each `BlockSpec` on the next iteration of the grid. However, if our computation is sparse we may be skipping the computation for the next block in the grid, so we need some method to tell the pipeline instead begin fetching the *next block that we are not skipping*. In order to do this, we need to construct *prefetch maps* which contains indices to the next non-skipped block of data for each kernel input. The following diagram illustrates how a prefetch map could be constructed for a block-sparse mask that is stored in a COO-like format. @@ -391,7 +391,7 @@ As we will be working with a sparse mask, we will begin by implementing a functi def sparsify_mask(mask: jax.Array, block_shape: tuple[int, int]): - """Preprocesses a mask into a sparse reprentation. + """Preprocesses a mask into a sparse representation. Args: mask: A boolean array of shape [M, N] @@ -411,7 +411,6 @@ def sparsify_mask(mask: jax.Array, block_mask = jnp.zeros((M // bm, N // bn), dtype=mask.dtype) mask_types_finder = [] mask_data = [] - mask_type_idxs = [] next_mask_type_idx = 0 prefetch_mask = jnp.zeros_like(block_mask) @@ -436,7 +435,6 @@ def sparsify_mask(mask: jax.Array, next_j = j else: type_index = -1 - mask_type_idxs.append(type_index) block_mask = block_mask.at[i, j].set(is_nonzero) prefetch_mask = prefetch_mask.at[i, j].set(next_mask_type_idx) prefetch_i = prefetch_i.at[i, j].set(next_i) @@ -542,7 +540,7 @@ Now let's compare performance versus a naive dense implementation. On TPU v5e, w We would generally expect performance to get closer to the theoretical peak as our inputs get larger, since a few of the main reasons why we don't exactly reach theoretical performance are: - We skip slightly less than half of computation since the blocks along the diagonal are mixed 0s and 1s, and for mixed blocks we need to compute the entire block. With larger inputs, our overhead for mixed blocks becomes smaller relative to the overall computation. -- The pipeline bubble also becomes accounts for a less percentage of the overall runtime as inputs become larger. +- The pipeline bubble also accounts for a less percentage of the overall runtime as inputs become larger. 
```{code-cell} --- diff --git a/docs/persistent_compilation_cache.md b/docs/persistent_compilation_cache.md index 0a5a89abe26d..d795a054bc87 100644 --- a/docs/persistent_compilation_cache.md +++ b/docs/persistent_compilation_cache.md @@ -260,14 +260,13 @@ If we were to merely compile this function without shard_map, the cache key for layernorm_matmul_without_shard_map = jax.jit(F, in_shardings=(...), out_sharding=(...))(x1, x2, gamma, beta) ``` -However, if we were to wrap the layernorm primitive in shard_map and define a function G that performs the same computation, the cache key for `layernorm_matmul_with_shard_map` will be the same everytime despite `LayerNorm` being implementing `custom_partitioning`: +However, if we were to wrap the layernorm primitive in shard_map and define a function G that performs the same computation, the cache key for `layernorm_matmul_with_shard_map` will be the same every time despite `LayerNorm` being implementing `custom_partitioning`: ```python import jax -from jax.experimental.shard_map import shard_map def G(x1, x2, gamma, beta, mesh, ispecs, ospecs): - ln_out = shard_map(LayerNorm, mesh, in_specs=ispecs, out_specs=ospecs, check_rep=False)(x1, x2, gamma, beta) + ln_out = jax.shard_map(LayerNorm, mesh=mesh, in_specs=ispecs, out_specs=ospecs, check_vma=False)(x1, x2, gamma, beta) return ln_out @ x2 ispecs = jax.sharding.PartitionSpec(...) diff --git a/docs/profiling.md b/docs/profiling.md index ac992b3a05da..3800a7ce140e 100644 --- a/docs/profiling.md +++ b/docs/profiling.md @@ -8,7 +8,7 @@ We can use the JAX profiler to generate traces of a JAX program that can be visualized using the [Perfetto visualizer](https://ui.perfetto.dev). Currently, this method blocks the program until a link is clicked and the Perfetto UI loads the trace. If you wish to get profiling information without any interaction, -check out the Tensorboard profiler below. +check out the XProf profiler below. ```python with jax.profiler.trace("/tmp/jax-trace", create_perfetto_link=True): @@ -64,48 +64,46 @@ Also, by default, the program will prompt you to open a link to file and open a visualizer. This feature is disabled by passing in `--no_perfetto_link` into the command. Alternatively, you can also point Tensorboard to the `log_dir` to analyze the trace (see the -"Tensorboard Profiling" section below). +"XProf (Tensorboard Profiling)" section below). (tensorboard-profiling)= -## TensorBoard profiling +## XProf (TensorBoard profiling) -[TensorBoard's -profiler](https://www.tensorflow.org/tensorboard/tensorboard_profiling_keras) -can be used to profile JAX programs. Tensorboard is a great way to acquire and +[XProf](https://www.tensorflow.org/tensorboard/tensorboard_profiling_keras) +can be used to profile JAX programs. XProf is a great way to acquire and visualize performance traces and profiles of your program, including activity on GPU and TPU. The end result looks something like this: -![TensorBoard profiler example](_static/tensorboard_profiler.png) +![XProf example](_static/tensorboard_profiler.png) ### Installation -The TensorBoard profiler is only available with the version of TensorBoard -bundled with TensorFlow. - +XProf is available as a plugin to TensorBoard, as well as an independently +run program. ```shell -pip install tensorflow tensorboard-plugin-profile +pip install xprof ``` -If you already have TensorFlow installed, you only need to install the -`tensorboard-plugin-profile` pip package. 
Be careful to only install one version -of TensorFlow or TensorBoard, otherwise you may encounter the "duplicate -plugins" error described {ref}`below `. See +If you have TensorBoard installed, the `xprof` pip package will also install +the TensorBoard Profiler plugin. Be careful to only install one version of +TensorFlow or TensorBoard, otherwise you may encounter the "duplicate plugins" +error described {ref}`below `. See for more information on installing TensorBoard. -Nightly version of TensorBoard profiler requires nightly tensorflow and -tensorboard +Profiling with the nightly version of TensorBoard requires the nightly +XProf. ```shell -pip install tf-nightly tb-nightly tbp-nightly +pip install tb-nightly xprof-nightly ``` ### Programmatic capture You can instrument your code to capture a profiler trace via the -{func}`jax.profiler.start_trace` and {func}`jax.profiler.stop_trace` -methods. Call {func}`~jax.profiler.start_trace` with the directory to write -trace files to. This should be the same `--logdir` directory used to start -TensorBoard. Then, you can use TensorBoard to view the traces. +{func}`jax.profiler.start_trace` and {func}`jax.profiler.stop_trace` methods. +Call {func}`~jax.profiler.start_trace` with the directory to write trace files +to. This should be the same `--logdir` directory used to start TensorBoard. +Then, you can use TensorBoard to view the traces. For example, to take a profiler trace: @@ -140,29 +138,54 @@ with jax.profiler.trace("/tmp/tensorboard"): y.block_until_ready() ``` +### Viewing the trace + +After capturing a trace, you can view it using either the standalone XProf +tool or the TensorBoard UI. The profiler interface is the same in both cases. + +#### Using Standalone XProf + +You can launch the profiler UI directly using the standalone XProf command by +pointing it to your log directory: + +``` +$ xprof --port 8791 /tmp/tensorboard +Attempting to start XProf server: + Log Directory: /tmp/tensorboard + Port: 8791 +XProf at http://localhost:8791/ (Press CTRL+C to quit) +``` + +Navigate to the provided URL (e.g., http://localhost:8791/) in your browser +to view the profile. + +Available traces appear in the "Runs" dropdown menu on the left. Select the +run you're interested in, and then under the "Tools" dropdown, select +trace_viewer. You should now see a timeline of the execution. You can use the +WASD keys to navigate the trace, and click or drag to select events for more +details. See +[these TensorFlow docs](https://www.tensorflow.org/tensorboard/tensorboard_profiling_keras#use_the_tensorflow_profiler_to_profile_model_training_performance)= +for more details on using the trace viewer. + +#### With TensorBoard + To view the trace, first start TensorBoard if you haven't already: ```shell $ tensorboard --logdir=/tmp/tensorboard [...] Serving TensorBoard on localhost; to expose to the network, use a proxy or pass --bind_all -TensorBoard 2.5.0 at http://localhost:6006/ (Press CTRL+C to quit) +TensorBoard 2.20.0 at http://localhost:6006/ (Press CTRL+C to quit) ``` -You should be able to load TensorBoard at in this -example. You can specify a different port with the `--port` flag. See -{ref}`remote_profiling` below if running JAX on a remote server. - -Then, either select "Profile" in the upper-right dropdown menu, or go directly -to . Available traces appear in the "Runs" -dropdown menu on the left. Select the run you're interested in, and then under -"Tools", select `trace_viewer`. You should now see a timeline of the -execution. 
You can use the WASD keys to navigate the trace, and click or drag to -select events to see more details at the bottom. See [these TensorFlow -docs](https://www.tensorflow.org/tensorboard/tensorboard_profiling_keras#use_the_tensorflow_profiler_to_profile_model_training_performance) -for more details on using the trace viewer. +You should be able to load TensorBoard at http://localhost:6006/ in this +example. Then, select "Profile" from the dropdown menu in the upper-right, +or navigate directly to http://localhost:6006/#profile. -You can also use the `memory_viewer`, `op_profile`, and `graph_viewer` tools. +From there, the experience is the same as the standalone tool: available +traces appear in the "Runs" dropdown menu on the left. Select the run +you're interested in, and then under "Tools", select trace_viewer to see the +timeline. ### Manual capture via TensorBoard @@ -231,6 +254,95 @@ functions. You can add your own events and functions by using {class}`jax.profiler.TraceAnnotation` and {func}`jax.profiler.annotate_function` in your code. +### Configuring profiler options + +The `start_trace` method accepts an optional `profiler_options` parameter, which +allows for fine-grained control over the profiler's behavior. This parameter +should be an instance of `jax.profiler.ProfileOptions`. + + +For example, to disable all python and host traces: + +```python +import jax + +options = jax.profiler.ProfileOptions() +options.python_tracer_level = 0 +options.host_tracer_level = 0 +jax.profiler.start_trace("/tmp/tensorboard", profiler_options=options) + +# Run the operations to be profiled +key = jax.random.key(0) +x = jax.random.normal(key, (5000, 5000)) +y = x @ x +y.block_until_ready() + +jax.profiler.stop_trace() +``` + +#### General options + +1. `host_tracer_level`: Sets the trace level for host-side activities. + + Supported Values: + + `0`: Disables host (CPU) tracing entirely. + + `1`: Enables tracing of only user-instrumented TraceMe events (this is the + default). + + `2`: Includes level 1 traces plus high-level program execution details like + expensive XLA operations. + + `3`: Includes level 2 traces plus more verbose, low-level program execution + details such as cheap XLA operations. + +2. `python_tracer_level`: Controls whether Python tracing is enabled. + + Supported Values: + + `0`: Disables Python function call tracing. + + `1`: Enables Python tracing (this is the default). + +#### Advanced configuration options + +1. `tpu_trace_mode`: Specifies the mode for TPU tracing. + + Supported Values: + + `TRACE_ONLY_HOST`: This means only host-side (CPU) activities are traced, + and no device (TPU/GPU) traces are collected. + + `TRACE_ONLY_XLA`: This means only XLA-level operations on the device are + traced. + + `TRACE_COMPUTE`: This traces compute operations on the device. + + `TRACE_COMPUTE_AND_SYNC`: This traces both compute operations and + synchronization events on the device. + + If "tpu_trace_mode" is not provided the trace_mode defaults to + TRACE_ONLY_XLA. + +2. `tpu_num_sparse_cores_to_trace`: Specifies the number of sparse cores to + trace on the TPU. +3. `tpu_num_sparse_core_tiles_to_trace`: Specifies the number of tiles within + each sparse core to trace on the TPU. +4. `tpu_num_chips_to_profile_per_task`: Specifies the number of TPU chips to + profile per task. 
+ +For example: + +```python +options = jax.profiler.ProfileOptions() +options.advanced_configuration = {"tpu_trace_mode": "TRACE_ONLY_HOST", "tpu_num_sparse_cores_to_trace": 2} + +``` + +Setting unrecognized keys or option values results in an `InvalidArgumentError`. + ### Troubleshooting #### GPU profiling @@ -308,8 +420,8 @@ replace, so it may be necessary to uninstall everything and reinstall a single version: ```shell -pip uninstall tensorflow tf-nightly tensorboard tb-nightly -pip install tensorflow +pip uninstall tensorflow tf-nightly tensorboard tb-nightly xprof xprof-nightly tensorboard-plugin-profile tbp-nightly +pip install tensorboard xprof ``` ## Nsight diff --git a/docs/quickstart.md b/docs/quickstart.md index 77cbb9d46ab8..40c50dba3dbd 100644 --- a/docs/quickstart.md +++ b/docs/quickstart.md @@ -58,7 +58,7 @@ print(selu(x)) ``` You'll find a few differences between JAX arrays and NumPy arrays once you begin digging-in; -these are explored in [🔪 JAX - The Sharp Bits 🔪](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html). +these are explored in [🔪 JAX - The Sharp Bits 🔪](https://docs.jax.dev/en/latest/notebooks/Common_Gotchas_in_JAX.html). ## Just-in-time compilation with {func}`jax.jit` JAX runs transparently on the GPU or TPU (falling back to CPU if you don't have one). However, in the above example, JAX is dispatching kernels to the chip one operation at a time. If we have a sequence of operations, we can use the {func}`jax.jit` function to compile this sequence of operations together using XLA. @@ -121,7 +121,7 @@ print(first_finite_differences(sum_logistic, x_small)) ``` The {func}`~jax.grad` and {func}`~jax.jit` transformations compose and can be mixed arbitrarily. -In the above example we jitted `sum_logistic` and then took its derivative. We can go further: +For instance, while the `sum_logistic` function was differentiated directly in the previous example, it could also be JIT-compiled, and these operations can be combined. We can go further: ```{code-cell} print(grad(jit(grad(jit(grad(sum_logistic)))))(1.0)) diff --git a/docs/random-numbers.md b/docs/random-numbers.md index 00f77e3473bb..5562dc3f43d5 100644 --- a/docs/random-numbers.md +++ b/docs/random-numbers.md @@ -150,9 +150,9 @@ print(random.normal(key)) print(random.normal(key)) ``` -Re-using the same key, even with different {mod}`~jax.random` APIs, can result in correlated outputs, which is generally undesirable. +Reusing the same key, even with different {mod}`~jax.random` APIs, can result in correlated outputs, which is generally undesirable. -**The rule of thumb is: never reuse keys (unless you want identical outputs).** +**The rule of thumb is: never reuse keys (unless you want identical outputs). Reusing the same state will cause __sadness__ and __monotony__, depriving the end user of __lifegiving chaos__.** JAX uses a modern [Threefry counter-based PRNG](https://github.com/jax-ml/jax/blob/main/docs/jep/263-prng.md) that's splittable. That is, its design allows us to fork the PRNG state into new PRNGs for use with parallel stochastic generation. In order to generate different and independent samples, you must {func}`~jax.random.split` the key explicitly before passing it to a random function: diff --git a/docs/rank_promotion_warning.rst b/docs/rank_promotion_warning.rst index 5e4e7ec65cbc..6ec0000e2ffc 100644 --- a/docs/rank_promotion_warning.rst +++ b/docs/rank_promotion_warning.rst @@ -9,14 +9,14 @@ surprising bugs where a silent rank promotion masks an underlying shape error. 
Here's an example of rank promotion: ->>> import numpy as np ->>> x = np.arange(12).reshape(4, 3) ->>> y = np.array([0, 1, 0]) +>>> from jax import numpy as jnp +>>> x = jnp.arange(12).reshape(4, 3) +>>> y = jnp.array([0, 1, 0]) >>> x + y -array([[ 0, 2, 2], +Array([[ 0, 2, 2], [ 3, 5, 5], [ 6, 8, 8], - [ 9, 11, 11]]) + [ 9, 11, 11]], dtype=int32) To avoid potential surprises, :code:`jax.numpy` is configurable so that expressions requiring rank promotion can lead to a warning, error, or can be diff --git a/docs/requirements.txt b/docs/requirements.txt index 5d49222bbb42..1fd706ab01a5 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,6 +1,7 @@ absl-py ipython>=8.8.0 # 8.7.0 has ipython3 lexer error pydata-sphinx-theme==0.14.4 # v0.15 breaks sidebar toggling +snowballstemmer<3.0.0 # v3.0.0 incompatible with older sphinx; missing stemmer sphinx>=7.3.2,<8.0 # 7.3.0 breaks sphinx-book-theme; 8.0 breaks myst-nb 1.1 sphinx-book-theme==1.1.1 # v1.1.2 requires pydata-sphinx-theme v0.15 sphinx-copybutton>=0.5.0 diff --git a/docs/sharded-computation.ipynb b/docs/sharded-computation.ipynb index d3ddac4edbdb..72cc2d193bfd 100644 --- a/docs/sharded-computation.ipynb +++ b/docs/sharded-computation.ipynb @@ -7,27 +7,50 @@ "(sharded-computation)=\n", "# Introduction to parallel programming\n", "\n", - "\n", + "\n", "\n", "This tutorial serves as an introduction to device parallelism for Single-Program Multi-Data (SPMD) code in JAX. SPMD is a parallelism technique where the same computation, such as the forward pass of a neural network, can be run on different input data (for example, different inputs in a batch) in parallel on different devices, such as several GPUs or Google TPUs.\n", "\n", "The tutorial covers three modes of parallel computation:\n", "\n", - "- _Automatic parallelism via {func}`jax.jit`_: The compiler chooses the optimal computation strategy (a.k.a. \"the compiler takes the wheel\").\n", - "- _Semi-automated parallelism_ using {func}`jax.jit` and {func}`jax.lax.with_sharding_constraint`\n", - "- _Fully manual parallelism with manual control using {func}`jax.experimental.shard_map.shard_map`_: `shard_map` enables per-device code and explicit communication collectives\n", + "- _Automatic sharding via {func}`jax.jit`_: The compiler chooses the optimal computation strategy (a.k.a. \"the compiler takes the wheel\").\n", + "- *Explicit Sharding* (\\*new\\*) is similar to automatic sharding in that\n", + " you're writing a global-view program. The difference is that the sharding\n", + " of each array is part of the array's JAX-level type making it an explicit\n", + " part of the programming model. These shardings are propagated at the JAX\n", + " level and queryable at trace time. 
It's still the compiler's responsibility\n", + " to turn the whole-array program into per-device programs (turning `jnp.sum`\n", + " into `psum` for example) but the compiler is heavily constrained by the\n", + " user-supplied shardings.\n", + "- _Fully manual sharding with manual control using {func}`jax.shard_map`_: `shard_map` enables per-device code and explicit communication collectives\n", "\n", - "Using these schools of thought for SPMD, you can transform a function written for one device into a function that can run in parallel on multiple devices.\n", + "A summary table:\n", "\n", - "If you are running these examples in a Google Colab notebook, make sure that your hardware accelerator is the latest Google TPU by checking your notebook settings: **Runtime** > **Change runtime type** > **Hardware accelerator** > **TPU v2** (which provides eight devices to work with)." + "| Mode | View? | Explicit sharding? | Explicit Collectives? |\n", + "|---|---|---|---|\n", + "| Auto | Global | ❌ | ❌ |\n", + "| Explicit | Global | ✅ | ❌ |\n", + "| Manual | Per-device | ✅ | ✅ |\n", + "\n", + "Using these schools of thought for SPMD, you can transform a function written for one device into a function that can run in parallel on multiple devices." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7efa1e66", + "metadata": {}, + "outputs": [], + "source": [ + "import jax\n", + "\n", + "jax.config.update('jax_num_cpu_devices', 8)" ] }, { "cell_type": "code", "execution_count": 1, - "metadata": { - "outputId": "18905ae4-7b5e-4bb9-acb4-d8ab914cb456" - }, + "metadata": {}, "outputs": [ { "data": { @@ -48,7 +71,6 @@ } ], "source": [ - "import jax\n", "jax.devices()" ] }, @@ -84,7 +106,9 @@ } ], "source": [ + "import numpy as np\n", "import jax.numpy as jnp\n", + "\n", "arr = jnp.arange(32.0).reshape(4, 8)\n", "arr.devices()" ] @@ -264,51 +288,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "UEObolTqw4pp" - }, - "source": [ - "The device numbers here are not in numerical order, because the mesh reflects the underlying toroidal topology of the device.\n", - "\n", - "The {class}`~jax.sharding.NamedSharding` includes a parameter called `memory_kind`. This parameter determines the type of memory to be used and defaults to `device`. You can set this parameter to `pinned_host` if you prefer to place it on the host.\n", - "\n", - "To create a new sharding that only differs from an existing sharding in terms of its memory kind, you can use the `with_memory_kind` method on the existing sharding." - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "aKNeOHTJnqmS", - "outputId": "847c53ec-8b2e-4be0-f993-7fde7d77c0f2" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "pinned_host\n", - "device\n" - ] - } - ], - "source": [ - "s_host = jax.NamedSharding(mesh, P('x', 'y'), memory_kind='pinned_host')\n", - "s_dev = s_host.with_memory_kind('device')\n", - "arr_host = jax.device_put(arr, s_host)\n", - "arr_dev = jax.device_put(arr, s_dev)\n", - "print(arr_host.sharding.memory_kind)\n", - "print(arr_dev.sharding.memory_kind)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "jDHYnVqHwaST" - }, + "metadata": {}, "source": [ "## 1. 
Automatic parallelism via `jit`\n", "\n", @@ -400,159 +380,170 @@ "id": "Q4N5mrr9i_ki" }, "source": [ - "The result is partially replicated: that is, the first two elements of the array are replicated on devices `0` and `6`, the second on `1` and `7`, and so on.\n", - "\n", - "### 1.1 Sharding transformation between memory types\n", + "The result is partially replicated: that is, the first two elements of the array are replicated on devices `0` and `4`, the second on `1` and `5`, and so on.\n", "\n", - "The output sharding of a {func}`jax.jit` function can differ from the input sharding if you specify the output sharding using the `out_shardings` parameter. Specifically, the `memory_kind` of the output can be different from that of the input array.\n", + "## 2. Explicit sharding\n", "\n", - "#### Example 1: Pinned host to device memory\n", - "\n", - "In the example below, the {func}`jax.jit` function `f` takes an array sharded in `pinned_host` memory and generates an array in `device` memory." + "The main idea behind explicit shardings, (a.k.a. sharding-in-types), is that\n", + "the JAX-level _type_ of a value includes a description of how the value is sharded.\n", + "We can query the JAX-level type of any JAX value (or Numpy array, or Python\n", + "scalar) using `jax.typeof`:" ] }, { "cell_type": "code", - "execution_count": 11, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "PXu3MhafyRHo", - "outputId": "7bc6821f-a4a9-4cf8-8b21-e279d516d27b" - }, + "execution_count": 9, + "metadata": {}, "outputs": [ + { + "data": { + "text/html": [ + "
  TPU 0    TPU 1    TPU 2    TPU 3    TPU 6    TPU 7    TPU 4    TPU 5  \n",
+       "                                                                        \n",
+       "
\n" + ], + "text/plain": [ + "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;57;59;121mTPU 0\u001b[0m\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;214;97;107m \u001b[0m\u001b[38;2;255;255;255;48;2;214;97;107mTPU 1\u001b[0m\u001b[38;2;255;255;255;48;2;214;97;107m \u001b[0m\u001b[38;2;255;255;255;48;2;140;162;82m \u001b[0m\u001b[38;2;255;255;255;48;2;140;162;82mTPU 2\u001b[0m\u001b[38;2;255;255;255;48;2;140;162;82m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214mTPU 3\u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;0;0;0;48;2;231;203;148m \u001b[0m\u001b[38;2;0;0;0;48;2;231;203;148mTPU 6\u001b[0m\u001b[38;2;0;0;0;48;2;231;203;148m \u001b[0m\u001b[38;2;255;255;255;48;2;107;110;207m \u001b[0m\u001b[38;2;255;255;255;48;2;107;110;207mTPU 7\u001b[0m\u001b[38;2;255;255;255;48;2;107;110;207m \u001b[0m\u001b[38;2;255;255;255;48;2;165;81;148m \u001b[0m\u001b[38;2;255;255;255;48;2;165;81;148mTPU 4\u001b[0m\u001b[38;2;255;255;255;48;2;165;81;148m \u001b[0m\u001b[38;2;255;255;255;48;2;140;109;49m \u001b[0m\u001b[38;2;255;255;255;48;2;140;109;49mTPU 5\u001b[0m\u001b[38;2;255;255;255;48;2;140;109;49m \u001b[0m\n", + "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;214;97;107m \u001b[0m\u001b[38;2;255;255;255;48;2;140;162;82m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;0;0;0;48;2;231;203;148m \u001b[0m\u001b[38;2;255;255;255;48;2;107;110;207m \u001b[0m\u001b[38;2;255;255;255;48;2;165;81;148m \u001b[0m\u001b[38;2;255;255;255;48;2;140;109;49m \u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, { "name": "stdout", "output_type": "stream", "text": [ - "[[ 0. 1. 2. 3. 4. 5. 6. 7.]\n", - " [ 8. 9. 10. 11. 12. 13. 14. 15.]\n", - " [16. 17. 18. 19. 20. 21. 22. 23.]\n", - " [24. 25. 26. 27. 28. 29. 30. 31.]]\n", - "device\n" + "[48. 52. 56. 60. 64. 68. 72. 76.]\n" ] } ], "source": [ - "f = jax.jit(lambda x: x, out_shardings=s_dev)\n", - "out_dev = f(arr_host)\n", - "print(out_dev)\n", - "print(out_dev.sharding.memory_kind)" + "some_array = np.arange(8)\n", + "print(f\"JAX-level type of some_array: {jax.typeof(some_array)}\")" ] }, { "cell_type": "markdown", - "metadata": { - "id": "LuYFqpcBySiX" - }, + "metadata": {}, + "source": [ + "Importantly, we can query the type even while tracing under a `jit` (the JAX-level type\n", + "is almost _defined_ as \"the information about a value we have access to while\n", + "under a jit)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ffe62839", + "metadata": {}, + "outputs": [], "source": [ - "#### Example 2: Device to pinned_host memory\n", + "@jax.jit\n", + "def foo(x):\n", + " print(f\"JAX-level type of x during tracing: {jax.typeof(x)}\")\n", + " return x + x\n", "\n", - "In the example below, the {func}`jax.jit` function `g` takes an array sharded in `device` memory and generates an array in `pinned_host` memory." + "foo(some_array)" + ] + }, + { + "cell_type": "markdown", + "id": "74995421", + "metadata": {}, + "source": [ + "To start seeing shardings in the type we need to set up an explicit-sharding mesh." ] }, { "cell_type": "code", "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "qLsgNlKfybRw", - "outputId": "a16448b9-7e39-408f-b200-505f65ad4464" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[[ 0. 1. 2. 3. 4. 5. 6. 
7.]\n", - " [ 8. 9. 10. 11. 12. 13. 14. 15.]\n", - " [16. 17. 18. 19. 20. 21. 22. 23.]\n", - " [24. 25. 26. 27. 28. 29. 30. 31.]]\n", - "pinned_host\n" - ] - } - ], + "id": "e785a694", + "metadata": {}, + "outputs": [], "source": [ - "g = jax.jit(lambda x: x, out_shardings=s_host)\n", - "out_host = g(arr_dev)\n", - "print(out_host)\n", - "print(out_host.sharding.memory_kind)" + "from jax.sharding import AxisType\n", + "\n", + "mesh = jax.make_mesh((2, 4), (\"X\", \"Y\"),\n", + " axis_types=(AxisType.Explicit, AxisType.Explicit))" ] }, { "cell_type": "markdown", - "metadata": { - "id": "7BGD31-owaSU" - }, + "id": "8d81409c", + "metadata": {}, + "source": [ + "Now we can create some sharded arrays:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4969cabd", + "metadata": {}, + "outputs": [], "source": [ - "## 2. Semi-automated sharding with constraints\n", + "replicated_array = np.arange(8).reshape(4, 2)\n", + "sharded_array = jax.device_put(replicated_array, jax.NamedSharding(mesh, P(\"X\", None)))\n", "\n", - "If you'd like to have some control over the sharding used within a particular computation, JAX offers the {func}`~jax.lax.with_sharding_constraint` function. You can use {func}`jax.lax.with_sharding_constraint` (in place of {func}`jax.device_put()`) together with {func}`jax.jit` for more control over how the compiler constraints how the intermediate values and outputs are distributed.\n", + "print(f\"replicated_array type: {jax.typeof(replicated_array)}\")\n", + "print(f\"sharded_array type: {jax.typeof(sharded_array)}\")" + ] + }, + { + "cell_type": "markdown", + "id": "c09acf7d", + "metadata": {}, + "source": [ + "We should read the type `int32[4@X, 2]` as \"a 4-by-2 array of 32-bit ints whose first dimension\n", + "is sharded along mesh axis 'X'. The array is replicated along all other mesh\n", + "axes\"\n", "\n", - "For example, suppose that within `f_contract` above, you'd prefer the output not to be partially-replicated, but rather to be fully sharded across the eight devices:" + "These shardings associated with JAX-level types propagate through operations. For example:" ] }, { "cell_type": "code", - "execution_count": 9, - "metadata": { - "outputId": "8468f5c6-76ca-4367-c9f2-93c723687cfd" - }, - "outputs": [ - { - "data": { - "text/html": [ - "
  TPU 0    TPU 1    TPU 2    TPU 3    TPU 6    TPU 7    TPU 4    TPU 5  \n",
-       "                                                                        \n",
-       "
\n" - ], - "text/plain": [ - "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;57;59;121mTPU 0\u001b[0m\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;214;97;107m \u001b[0m\u001b[38;2;255;255;255;48;2;214;97;107mTPU 1\u001b[0m\u001b[38;2;255;255;255;48;2;214;97;107m \u001b[0m\u001b[38;2;255;255;255;48;2;140;162;82m \u001b[0m\u001b[38;2;255;255;255;48;2;140;162;82mTPU 2\u001b[0m\u001b[38;2;255;255;255;48;2;140;162;82m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214mTPU 3\u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;0;0;0;48;2;231;203;148m \u001b[0m\u001b[38;2;0;0;0;48;2;231;203;148mTPU 6\u001b[0m\u001b[38;2;0;0;0;48;2;231;203;148m \u001b[0m\u001b[38;2;255;255;255;48;2;107;110;207m \u001b[0m\u001b[38;2;255;255;255;48;2;107;110;207mTPU 7\u001b[0m\u001b[38;2;255;255;255;48;2;107;110;207m \u001b[0m\u001b[38;2;255;255;255;48;2;165;81;148m \u001b[0m\u001b[38;2;255;255;255;48;2;165;81;148mTPU 4\u001b[0m\u001b[38;2;255;255;255;48;2;165;81;148m \u001b[0m\u001b[38;2;255;255;255;48;2;140;109;49m \u001b[0m\u001b[38;2;255;255;255;48;2;140;109;49mTPU 5\u001b[0m\u001b[38;2;255;255;255;48;2;140;109;49m \u001b[0m\n", - "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;214;97;107m \u001b[0m\u001b[38;2;255;255;255;48;2;140;162;82m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;0;0;0;48;2;231;203;148m \u001b[0m\u001b[38;2;255;255;255;48;2;107;110;207m \u001b[0m\u001b[38;2;255;255;255;48;2;165;81;148m \u001b[0m\u001b[38;2;255;255;255;48;2;140;109;49m \u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[48. 52. 56. 60. 64. 68. 72. 76.]\n" - ] - } - ], + "execution_count": null, + "id": "ab2f9500", + "metadata": {}, + "outputs": [], "source": [ + "arg0 = jax.device_put(np.arange(4).reshape(4, 1),\n", + " jax.NamedSharding(mesh, P(\"X\", None)))\n", + "arg1 = jax.device_put(np.arange(8).reshape(1, 8),\n", + " jax.NamedSharding(mesh, P(None, \"Y\")))\n", + "\n", "@jax.jit\n", - "def f_contract_2(x):\n", - " out = x.sum(axis=0)\n", - " sharding = jax.sharding.NamedSharding(mesh, P('x'))\n", - " return jax.lax.with_sharding_constraint(out, sharding)\n", + "def add_arrays(x, y):\n", + " ans = x + y\n", + " print(f\"x sharding: {jax.typeof(x)}\")\n", + " print(f\"y sharding: {jax.typeof(y)}\")\n", + " print(f\"ans sharding: {jax.typeof(ans)}\")\n", + " return ans\n", "\n", - "result = f_contract_2(arr_sharded)\n", - "jax.debug.visualize_array_sharding(result)\n", - "print(result)" + "with jax.sharding.use_mesh(mesh):\n", + " add_arrays(arg0, arg1)" ] }, { "cell_type": "markdown", + "id": "dda3d0c5", "metadata": {}, "source": [ - "This gives you a function with the particular output sharding you'd like.\n", + "That's the gist of it. Shardings propagate deterministically at trace time and\n", + "we can query them at trace time.\n", "\n", "## 3. Manual parallelism with `shard_map`\n", "\n", - "In the automatic parallelism methods explored above, you can write a function as if you're operating on the full dataset, and `jit` will split that computation across multiple devices. 
By contrast, with {func}`jax.experimental.shard_map.shard_map` you write the function that will handle a single shard of data, and `shard_map` will construct the full function.\n", + "In the automatic parallelism methods explored above, you can write a function as if you're operating on the full dataset, and `jit` will split that computation across multiple devices. By contrast, with {func}`jax.shard_map` you write the function that will handle a single shard of data, and `shard_map` will construct the full function.\n", "\n", "`shard_map` works by mapping a function across a particular *mesh* of devices (`shard_map` maps over shards). In the example below:\n", "\n", "- As before, {class}`jax.sharding.Mesh` allows for precise device placement, with the axis names parameter for logical and physical axis names.\n", "- The `in_specs` argument determines the shard sizes. The `out_specs` argument identifies how the blocks are assembled back together.\n", "\n", - "**Note:** {func}`jax.experimental.shard_map.shard_map` code can work inside {func}`jax.jit` if you need it." + "**Note:** {func}`jax.shard_map` code can work inside {func}`jax.jit` if you need it." ] }, { @@ -580,10 +571,9 @@ } ], "source": [ - "from jax.experimental.shard_map import shard_map\n", "mesh = jax.make_mesh((8,), ('x',))\n", "\n", - "f_elementwise_sharded = shard_map(\n", + "f_elementwise_sharded = jax.shard_map(\n", " f_elementwise,\n", " mesh=mesh,\n", " in_specs=P('x'),\n", @@ -624,7 +614,7 @@ " print(f\"device local shape: {x.shape=}\")\n", " return x * 2\n", "\n", - "y = shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P('x'))(x)" + "y = jax.shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P('x'))(x)" ] }, { @@ -658,7 +648,7 @@ "def f(x):\n", " return jnp.sum(x, keepdims=True)\n", "\n", - "shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P('x'))(x)" + "jax.shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P('x'))(x)" ] }, { @@ -693,7 +683,7 @@ " sum_in_shard = x.sum()\n", " return jax.lax.psum(sum_in_shard, 'x')\n", "\n", - "shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P())(x)" + "jax.shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P())(x)" ] }, { @@ -757,7 +747,8 @@ "source": [ "You can automatically run this in a distributed manner using {func}`jax.jit` and passing appropriately sharded data.\n", "\n", - "If you shard the leading axis of both `x` and `weights` in the same way, then the matrix multiplication will automatically happen in parallel:" + "If you shard the leading axis of both `x` and make `weights` fully replicated,\n", + "then the matrix multiplication will automatically happen in parallel:" ] }, { @@ -780,10 +771,8 @@ ], "source": [ "mesh = jax.make_mesh((8,), ('x',))\n", - "sharding = jax.sharding.NamedSharding(mesh, P('x'))\n", - "\n", - "x_sharded = jax.device_put(x, sharding)\n", - "weights_sharded = jax.device_put(weights, sharding)\n", + "x_sharded = jax.device_put(x, jax.NamedSharding(mesh, P('x')))\n", + "weights_sharded = jax.device_put(weights, jax.NamedSharding(mesh, P()))\n", "\n", "layer(x_sharded, weights_sharded, bias)" ] @@ -792,15 +781,13 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Alternatively, you can use {func}`jax.lax.with_sharding_constraint` in the function to automatically distribute unsharded inputs:" + "Alternatively, you can use explicit sharding mode too:" ] }, { "cell_type": "code", "execution_count": 17, - "metadata": { - "outputId": "bb63e8da-ff4f-4e95-f083-10584882daf4" - }, + "metadata": {}, "outputs": [ { "data": { @@ -814,13 +801,22 @@ } 
], "source": [ + "explicit_mesh = jax.make_mesh((8,), ('X',), axis_types=(AxisType.Explicit,))\n", + "\n", + "x_sharded = jax.device_put(x, jax.NamedSharding(explicit_mesh, P('X')))\n", + "weights_sharded = jax.device_put(weights, jax.NamedSharding(explicit_mesh, P()))\n", + "\n", "@jax.jit\n", "def layer_auto(x, weights, bias):\n", - " x = jax.lax.with_sharding_constraint(x, sharding)\n", - " weights = jax.lax.with_sharding_constraint(weights, sharding)\n", - " return layer(x, weights, bias)\n", + " print(f\"x sharding: {jax.typeof(x)}\")\n", + " print(f\"weights sharding: {jax.typeof(weights)}\")\n", + " print(f\"bias sharding: {jax.typeof(bias)}\")\n", + " out = layer(x, weights, bias)\n", + " print(f\"out sharding: {jax.typeof(out)}\")\n", + " return out\n", "\n", - "layer_auto(x, weights, bias) # pass in unsharded inputs" + "with jax.sharding.use_mesh(explicit_mesh):\n", + " layer_auto(x_sharded, weights_sharded, bias)" ] }, { @@ -852,7 +848,7 @@ "from functools import partial\n", "\n", "@jax.jit\n", - "@partial(shard_map, mesh=mesh,\n", + "@partial(jax.shard_map, mesh=mesh,\n", " in_specs=(P('x'), P('x', None), P(None)),\n", " out_specs=P(None))\n", "def layer_sharded(x, weights, bias):\n", @@ -871,6 +867,7 @@ "\n", "To learn about each SPMD method in-depth, check out these docs:\n", "- {doc}`../notebooks/Distributed_arrays_and_automatic_parallelization`\n", + "- {doc}`../notebooks/explicit-sharding`\n", "- {doc}`../notebooks/shard_map`" ] } diff --git a/docs/sharded-computation.md b/docs/sharded-computation.md index b05eb8d5f66e..89ffbc07da38 100644 --- a/docs/sharded-computation.md +++ b/docs/sharded-computation.md @@ -14,24 +14,40 @@ kernelspec: (sharded-computation)= # Introduction to parallel programming - + This tutorial serves as an introduction to device parallelism for Single-Program Multi-Data (SPMD) code in JAX. SPMD is a parallelism technique where the same computation, such as the forward pass of a neural network, can be run on different input data (for example, different inputs in a batch) in parallel on different devices, such as several GPUs or Google TPUs. The tutorial covers three modes of parallel computation: -- _Automatic parallelism via {func}`jax.jit`_: The compiler chooses the optimal computation strategy (a.k.a. "the compiler takes the wheel"). -- _Semi-automated parallelism_ using {func}`jax.jit` and {func}`jax.lax.with_sharding_constraint` -- _Fully manual parallelism with manual control using {func}`jax.experimental.shard_map.shard_map`_: `shard_map` enables per-device code and explicit communication collectives +- _Automatic sharding via {func}`jax.jit`_: The compiler chooses the optimal computation strategy (a.k.a. "the compiler takes the wheel"). +- *Explicit Sharding* (\*new\*) is similar to automatic sharding in that + you're writing a global-view program. The difference is that the sharding + of each array is part of the array's JAX-level type making it an explicit + part of the programming model. These shardings are propagated at the JAX + level and queryable at trace time. It's still the compiler's responsibility + to turn the whole-array program into per-device programs (turning `jnp.sum` + into `psum` for example) but the compiler is heavily constrained by the + user-supplied shardings. +- _Fully manual sharding with manual control using {func}`jax.shard_map`_: `shard_map` enables per-device code and explicit communication collectives + +A summary table: + +| Mode | View? | Explicit sharding? | Explicit Collectives? 
| +|---|---|---|---| +| Auto | Global | ❌ | ❌ | +| Explicit | Global | ✅ | ❌ | +| Manual | Per-device | ✅ | ✅ | Using these schools of thought for SPMD, you can transform a function written for one device into a function that can run in parallel on multiple devices. -If you are running these examples in a Google Colab notebook, make sure that your hardware accelerator is the latest Google TPU by checking your notebook settings: **Runtime** > **Change runtime type** > **Hardware accelerator** > **TPU v2** (which provides eight devices to work with). - ```{code-cell} -:outputId: 18905ae4-7b5e-4bb9-acb4-d8ab914cb456 - import jax + +jax.config.update('jax_num_cpu_devices', 8) +``` + +```{code-cell} jax.devices() ``` @@ -46,7 +62,9 @@ In the simplest cases, arrays are sharded on a single device, as demonstrated be ```{code-cell} :outputId: 39fdbb79-d5c0-4ea6-8b20-88b2c502a27a +import numpy as np import jax.numpy as jnp + arr = jnp.arange(32.0).reshape(4, 8) arr.devices() ``` @@ -90,31 +108,6 @@ print(arr_sharded) jax.debug.visualize_array_sharding(arr_sharded) ``` -+++ {"id": "UEObolTqw4pp"} - -The device numbers here are not in numerical order, because the mesh reflects the underlying toroidal topology of the device. - -The {class}`~jax.sharding.NamedSharding` includes a parameter called `memory_kind`. This parameter determines the type of memory to be used and defaults to `device`. You can set this parameter to `pinned_host` if you prefer to place it on the host. - -To create a new sharding that only differs from an existing sharding in terms of its memory kind, you can use the `with_memory_kind` method on the existing sharding. - -```{code-cell} ---- -colab: - base_uri: https://localhost:8080/ -id: aKNeOHTJnqmS -outputId: 847c53ec-8b2e-4be0-f993-7fde7d77c0f2 ---- -s_host = jax.NamedSharding(mesh, P('x', 'y'), memory_kind='pinned_host') -s_dev = s_host.with_memory_kind('device') -arr_host = jax.device_put(arr, s_host) -arr_dev = jax.device_put(arr, s_dev) -print(arr_host.sharding.memory_kind) -print(arr_dev.sharding.memory_kind) -``` - -+++ {"id": "jDHYnVqHwaST"} - ## 1. Automatic parallelism via `jit` Once you have sharded data, the easiest way to do parallel computation is to simply pass the data to a {func}`jax.jit`-compiled function! In JAX, you need to only specify how you want the input and output of your code to be partitioned, and the compiler will figure out how to: 1) partition everything inside; and 2) compile inter-device communications. @@ -154,90 +147,96 @@ print(result) +++ {"id": "Q4N5mrr9i_ki"} -The result is partially replicated: that is, the first two elements of the array are replicated on devices `0` and `6`, the second on `1` and `7`, and so on. - -### 1.1 Sharding transformation between memory types +The result is partially replicated: that is, the first two elements of the array are replicated on devices `0` and `4`, the second on `1` and `5`, and so on. -The output sharding of a {func}`jax.jit` function can differ from the input sharding if you specify the output sharding using the `out_shardings` parameter. Specifically, the `memory_kind` of the output can be different from that of the input array. +## 2. Explicit sharding -#### Example 1: Pinned host to device memory - -In the example below, the {func}`jax.jit` function `f` takes an array sharded in `pinned_host` memory and generates an array in `device` memory. +The main idea behind explicit shardings, (a.k.a. 
sharding-in-types), is that +the JAX-level _type_ of a value includes a description of how the value is sharded. +We can query the JAX-level type of any JAX value (or Numpy array, or Python +scalar) using `jax.typeof`: ```{code-cell} ---- -colab: - base_uri: https://localhost:8080/ -id: PXu3MhafyRHo -outputId: 7bc6821f-a4a9-4cf8-8b21-e279d516d27b ---- -f = jax.jit(lambda x: x, out_shardings=s_dev) -out_dev = f(arr_host) -print(out_dev) -print(out_dev.sharding.memory_kind) +some_array = np.arange(8) +print(f"JAX-level type of some_array: {jax.typeof(some_array)}") ``` -+++ {"id": "LuYFqpcBySiX"} +Importantly, we can query the type even while tracing under a `jit` (the JAX-level type +is almost _defined_ as "the information about a value we have access to while +under a jit). + +```{code-cell} +@jax.jit +def foo(x): + print(f"JAX-level type of x during tracing: {jax.typeof(x)}") + return x + x -#### Example 2: Device to pinned_host memory +foo(some_array) +``` -In the example below, the {func}`jax.jit` function `g` takes an array sharded in `device` memory and generates an array in `pinned_host` memory. +To start seeing shardings in the type we need to set up an explicit-sharding mesh. ```{code-cell} ---- -colab: - base_uri: https://localhost:8080/ -id: qLsgNlKfybRw -outputId: a16448b9-7e39-408f-b200-505f65ad4464 ---- -g = jax.jit(lambda x: x, out_shardings=s_host) -out_host = g(arr_dev) -print(out_host) -print(out_host.sharding.memory_kind) +from jax.sharding import AxisType + +mesh = jax.make_mesh((2, 4), ("X", "Y"), + axis_types=(AxisType.Explicit, AxisType.Explicit)) ``` -+++ {"id": "7BGD31-owaSU"} +Now we can create some sharded arrays: -## 2. Semi-automated sharding with constraints +```{code-cell} +replicated_array = np.arange(8).reshape(4, 2) +sharded_array = jax.device_put(replicated_array, jax.NamedSharding(mesh, P("X", None))) + +print(f"replicated_array type: {jax.typeof(replicated_array)}") +print(f"sharded_array type: {jax.typeof(sharded_array)}") +``` -If you'd like to have some control over the sharding used within a particular computation, JAX offers the {func}`~jax.lax.with_sharding_constraint` function. You can use {func}`jax.lax.with_sharding_constraint` (in place of {func}`jax.device_put()`) together with {func}`jax.jit` for more control over how the compiler constraints how the intermediate values and outputs are distributed. +We should read the type `int32[4@X, 2]` as "a 4-by-2 array of 32-bit ints whose first dimension +is sharded along mesh axis 'X'. The array is replicated along all other mesh +axes" -For example, suppose that within `f_contract` above, you'd prefer the output not to be partially-replicated, but rather to be fully sharded across the eight devices: +These shardings associated with JAX-level types propagate through operations. 
For example: ```{code-cell} -:outputId: 8468f5c6-76ca-4367-c9f2-93c723687cfd +arg0 = jax.device_put(np.arange(4).reshape(4, 1), + jax.NamedSharding(mesh, P("X", None))) +arg1 = jax.device_put(np.arange(8).reshape(1, 8), + jax.NamedSharding(mesh, P(None, "Y"))) @jax.jit -def f_contract_2(x): - out = x.sum(axis=0) - sharding = jax.sharding.NamedSharding(mesh, P('x')) - return jax.lax.with_sharding_constraint(out, sharding) - -result = f_contract_2(arr_sharded) -jax.debug.visualize_array_sharding(result) -print(result) +def add_arrays(x, y): + ans = x + y + print(f"x sharding: {jax.typeof(x)}") + print(f"y sharding: {jax.typeof(y)}") + print(f"ans sharding: {jax.typeof(ans)}") + return ans + +with jax.sharding.use_mesh(mesh): + add_arrays(arg0, arg1) ``` -This gives you a function with the particular output sharding you'd like. +That's the gist of it. Shardings propagate deterministically at trace time and +we can query them at trace time. ## 3. Manual parallelism with `shard_map` -In the automatic parallelism methods explored above, you can write a function as if you're operating on the full dataset, and `jit` will split that computation across multiple devices. By contrast, with {func}`jax.experimental.shard_map.shard_map` you write the function that will handle a single shard of data, and `shard_map` will construct the full function. +In the automatic parallelism methods explored above, you can write a function as if you're operating on the full dataset, and `jit` will split that computation across multiple devices. By contrast, with {func}`jax.shard_map` you write the function that will handle a single shard of data, and `shard_map` will construct the full function. `shard_map` works by mapping a function across a particular *mesh* of devices (`shard_map` maps over shards). In the example below: - As before, {class}`jax.sharding.Mesh` allows for precise device placement, with the axis names parameter for logical and physical axis names. - The `in_specs` argument determines the shard sizes. The `out_specs` argument identifies how the blocks are assembled back together. -**Note:** {func}`jax.experimental.shard_map.shard_map` code can work inside {func}`jax.jit` if you need it. +**Note:** {func}`jax.shard_map` code can work inside {func}`jax.jit` if you need it. ```{code-cell} :outputId: 435c32f3-557a-4676-c11b-17e6bab8c1e2 -from jax.experimental.shard_map import shard_map mesh = jax.make_mesh((8,), ('x',)) -f_elementwise_sharded = shard_map( +f_elementwise_sharded = jax.shard_map( f_elementwise, mesh=mesh, in_specs=P('x'), @@ -259,7 +258,7 @@ def f(x): print(f"device local shape: {x.shape=}") return x * 2 -y = shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P('x'))(x) +y = jax.shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P('x'))(x) ``` Because each of your functions only "sees" the device-local part of the data, it means that aggregation-like functions require some extra thought. @@ -272,7 +271,7 @@ For example, here's what a `shard_map` of a {func}`jax.numpy.sum` looks like: def f(x): return jnp.sum(x, keepdims=True) -shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P('x'))(x) +jax.shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P('x'))(x) ``` Your function `f` operates separately on each shard, and the resulting summation reflects this. 
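As a quick sanity check (a sketch added here, not part of the tutorial), the per-shard behaviour described above can be reproduced without `shard_map`: reshaping the global array into one row per device and summing each row yields the same partial sums that the `out_specs=P('x')` version concatenates together. The input `x = jnp.arange(32.)` and the `jax_num_cpu_devices` setting below mirror the tutorial's 8-device setup and are assumptions of this sketch.

```python
# Sketch: verify that shard_map(jnp.sum, keepdims=True) with out_specs=P('x')
# returns one partial sum per device-local block of the input.
import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P

jax.config.update('jax_num_cpu_devices', 8)  # assumes a fresh process with 8 CPU devices

mesh = jax.make_mesh((8,), ('x',))
x = jnp.arange(32.)

def f(x):
  return jnp.sum(x, keepdims=True)

per_shard_sums = jax.shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P('x'))(x)
manual_sums = x.reshape(8, -1).sum(axis=1)  # one sum per device-local block
print(per_shard_sums)                             # shape (8,): eight partial sums
print(jnp.allclose(per_shard_sums, manual_sums))  # True
```

Swapping `out_specs=P('x')` for `out_specs=P()` and adding a `jax.lax.psum` over `'x'`, as the document goes on to show, turns these partial sums into a single global sum.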
@@ -286,7 +285,7 @@ def f(x): sum_in_shard = x.sum() return jax.lax.psum(sum_in_shard, 'x') -shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P())(x) +jax.shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P())(x) ``` Because the output no longer has a sharded dimension, set `out_specs=P()` (recall that the `out_specs` argument identifies how the blocks are assembled back together in `shard_map`). @@ -320,32 +319,38 @@ layer(x, weights, bias) You can automatically run this in a distributed manner using {func}`jax.jit` and passing appropriately sharded data. -If you shard the leading axis of both `x` and `weights` in the same way, then the matrix multiplication will automatically happen in parallel: +If you shard the leading axis of both `x` and make `weights` fully replicated, +then the matrix multiplication will automatically happen in parallel: ```{code-cell} :outputId: 80be899e-8dbc-4bfc-acd2-0f3d554a0aa5 mesh = jax.make_mesh((8,), ('x',)) -sharding = jax.sharding.NamedSharding(mesh, P('x')) - -x_sharded = jax.device_put(x, sharding) -weights_sharded = jax.device_put(weights, sharding) +x_sharded = jax.device_put(x, jax.NamedSharding(mesh, P('x'))) +weights_sharded = jax.device_put(weights, jax.NamedSharding(mesh, P())) layer(x_sharded, weights_sharded, bias) ``` -Alternatively, you can use {func}`jax.lax.with_sharding_constraint` in the function to automatically distribute unsharded inputs: +Alternatively, you can use explicit sharding mode too: ```{code-cell} -:outputId: bb63e8da-ff4f-4e95-f083-10584882daf4 +explicit_mesh = jax.make_mesh((8,), ('X',), axis_types=(AxisType.Explicit,)) + +x_sharded = jax.device_put(x, jax.NamedSharding(explicit_mesh, P('X'))) +weights_sharded = jax.device_put(weights, jax.NamedSharding(explicit_mesh, P())) @jax.jit def layer_auto(x, weights, bias): - x = jax.lax.with_sharding_constraint(x, sharding) - weights = jax.lax.with_sharding_constraint(weights, sharding) - return layer(x, weights, bias) - -layer_auto(x, weights, bias) # pass in unsharded inputs + print(f"x sharding: {jax.typeof(x)}") + print(f"weights sharding: {jax.typeof(weights)}") + print(f"bias sharding: {jax.typeof(bias)}") + out = layer(x, weights, bias) + print(f"out sharding: {jax.typeof(out)}") + return out + +with jax.sharding.use_mesh(explicit_mesh): + layer_auto(x_sharded, weights_sharded, bias) ``` Finally, you can do the same thing with `shard_map`, using {func}`jax.lax.psum` to indicate the cross-shard collective required for the matrix product: @@ -356,7 +361,7 @@ Finally, you can do the same thing with `shard_map`, using {func}`jax.lax.psum` from functools import partial @jax.jit -@partial(shard_map, mesh=mesh, +@partial(jax.shard_map, mesh=mesh, in_specs=(P('x'), P('x', None), P(None)), out_specs=P(None)) def layer_sharded(x, weights, bias): @@ -371,4 +376,5 @@ This tutorial serves as a brief introduction of sharded and parallel computation To learn about each SPMD method in-depth, check out these docs: - {doc}`../notebooks/Distributed_arrays_and_automatic_parallelization` +- {doc}`../notebooks/explicit-sharding` - {doc}`../notebooks/shard_map` diff --git a/docs/sphinxext/jax_list_config_options.py b/docs/sphinxext/jax_list_config_options.py new file mode 100644 index 000000000000..54f7f6eebe85 --- /dev/null +++ b/docs/sphinxext/jax_list_config_options.py @@ -0,0 +1,160 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from operator import itemgetter +from typing import Any, List + +from docutils import nodes +from sphinx.util import logging +from sphinx.util.docutils import SphinxDirective + +logger = logging.getLogger(__name__) + +_deprecations = ( + 'jax_default_dtype_bits', # an experiment that we never documented, but we can't remove it because Keras depends on its existing broken behavior + 'jax_serialization_version' +) + +def create_field_item(label, content): + """Create a field list item with a label and content side by side. + + Args: + label: The label text for the field name + content: The content to add (a node or text) + + Returns: + A field list item with the label and content side by side. + """ + # Create a field list item + field = nodes.field() + + # Create the field name (label) + field_name = nodes.field_name() + field_name += nodes.Text(label) + field += field_name + + # Create the field body (content) + field_body = nodes.field_body() + + if isinstance(content, str): + para = nodes.paragraph() + para += nodes.Text(content) + field_body += para + elif isinstance(content, nodes.Node): + field_body += content + + field += field_body + return field + +class ConfigOptionDirective(SphinxDirective): + required_arguments = 0 + optional_arguments = 0 + has_content = False + + def run(self) -> List[nodes.Node]: + from jax._src.config import config as jax_config + + config_options = sorted(jax_config.meta.items(), key=itemgetter(0)) + result = [] + + for name, (opt_type, meta_args, meta_kwargs) in config_options: + if name in _deprecations: + continue + + holder = jax_config._value_holders[name] + + # Create target for linking + target = nodes.target() + target['ids'].append(name) + result.append(target) + + # Create a section for this option + option_section = nodes.section() + option_section['ids'].append(name) + option_section['classes'].append('config-option-section') + + # Create a title with the option name (important for TOC) + title = nodes.title() + title['classes'] = ['h4'] + title += nodes.Text(name.replace("jax_", "").replace("_", " ").title()) + option_section += title + + # Create a field list for side-by-side display + field_list = nodes.field_list() + field_list['classes'].append('config-field-list') + + # Add type information as a field item + if opt_type == "enum": + type_para = nodes.paragraph() + emphasis_node = nodes.emphasis() + emphasis_node += nodes.Text("Enum values: ") + type_para += emphasis_node + + for i, value in enumerate(enum_values := meta_kwargs.get('enum_values', [])): + type_para += nodes.literal(text=repr(value)) + if i < len(enum_values) - 1: + type_para += nodes.Text(", ") + else: + type_para = nodes.paragraph() + type_para += nodes.literal(text=opt_type.__name__) + + field_list += create_field_item("Type", type_para) + + # Add default value information + default_para = nodes.paragraph() + default_para += nodes.literal(text=repr(holder.value)) + field_list += create_field_item("Default Value", default_para) + + # Add configuration string information + string_para = nodes.paragraph() + string_para += nodes.literal(text=repr(name)) + 
field_list += create_field_item("Configuration String", string_para) + + string_para = nodes.paragraph() + string_para += nodes.literal(text=name.upper()) + field_list += create_field_item("Environment Variable", string_para) + + # Add the field list to the section + option_section += field_list + + # Add help text in a description box + if (help_text := meta_kwargs.get('help')): + help_para = nodes.paragraph() + # logger.error(name) + # logger.warning(help_text) + + # If we get here, help text seems valid - proceed with normal parsing + # parsed = nodes.Text(help_text) + help_para += self.parse_text_to_nodes(help_text) + + option_section += help_para + + result.append(option_section) + # Add an extra paragraph to ensure proper separation + result.append(nodes.paragraph()) + result.append(nodes.paragraph()) # ensure new line + + return result + + def get_location(self) -> Any: + return (self.env.docname, self.lineno) + +def setup(app): + app.add_directive("list_config_options", ConfigOptionDirective) + + return { + "version": "0.1", + "parallel_read_safe": True, + "parallel_write_safe": True, + } diff --git a/docs/stateful-computations.md b/docs/stateful-computations.md index 30c626bec4e3..1bd719aa2df2 100644 --- a/docs/stateful-computations.md +++ b/docs/stateful-computations.md @@ -20,7 +20,7 @@ kernelspec: JAX transformations like {func}`~jax.jit`, {func}`~jax.vmap`, {func}`~jax.grad`, require the functions they wrap to be pure: that is, functions whose outputs depend *solely* on the inputs, and which have no side effects such as updating of global state. -You can find a discussion of this in [JAX sharp bits: Pure functions](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#pure-functions). +You can find a discussion of this in [JAX sharp bits: Pure functions](https://docs.jax.dev/en/latest/notebooks/Common_Gotchas_in_JAX.html#pure-functions). This constraint can pose some challenges in the context of machine learning, where state may exist in many forms. For example: diff --git a/docs/type_promotion.rst b/docs/type_promotion.rst index d3724745fe08..8227aff384aa 100644 --- a/docs/type_promotion.rst +++ b/docs/type_promotion.rst @@ -4,7 +4,7 @@ Type promotion semantics ======================== This document describes JAX's type promotion rules–i.e., the result of :func:`jax.numpy.promote_types` for each pair of types. -For some background on the considerations that went into the design of what is described below, see `Design of Type Promotion Semantics for JAX `_. +For some background on the considerations that went into the design of what is described below, see `Design of Type Promotion Semantics for JAX `_. JAX's type promotion behavior is determined via the following type promotion lattice: diff --git a/docs/user_guides.rst b/docs/user_guides.rst index 6481da7a31dd..47984fc493f4 100644 --- a/docs/user_guides.rst +++ b/docs/user_guides.rst @@ -26,7 +26,6 @@ or deployed codebases. errors aot export/index - type_promotion transfer_guard .. toctree:: diff --git a/docs/xla_flags.md b/docs/xla_flags.md index 1e374abea005..24bb8a96c91c 100644 --- a/docs/xla_flags.md +++ b/docs/xla_flags.md @@ -85,4 +85,4 @@ XLA_FLAGS='--flag1=value1 --flag2=value2' python3 source.py | `xla_gpu_enable_reduce_scatter_combine_by_dim` | Boolean (true/false) | Combine reduce-scatter ops with the same dimension or irrespective of their dimension. 
| **Additional reading:** -* [GPU performance tips](https://jax.readthedocs.io/en/latest/gpu_performance_tips.html#xla-performance-flags) +* [GPU performance tips](https://docs.jax.dev/en/latest/gpu_performance_tips.html#xla-performance-flags) diff --git a/examples/ffi/CMakeLists.txt b/examples/ffi/CMakeLists.txt index ea7670b81ccc..4a93cc490d33 100644 --- a/examples/ffi/CMakeLists.txt +++ b/examples/ffi/CMakeLists.txt @@ -3,10 +3,10 @@ project(${SKBUILD_PROJECT_NAME} LANGUAGES CXX) option(JAX_FFI_EXAMPLE_ENABLE_CUDA "Enable CUDA support" OFF) -find_package(Python 3.10 REQUIRED COMPONENTS Interpreter Development.Module) +find_package(Python 3.11 REQUIRED COMPONENTS Interpreter Development.Module) execute_process( COMMAND "${Python_EXECUTABLE}" - "-c" "from jax.extend import ffi; print(ffi.include_dir())" + "-c" "from jax import ffi; print(ffi.include_dir())" OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_VARIABLE XLA_DIR) message(STATUS "XLA include directory: ${XLA_DIR}") diff --git a/examples/ffi/README.md b/examples/ffi/README.md index bd45408e50d8..c490f014859b 100644 --- a/examples/ffi/README.md +++ b/examples/ffi/README.md @@ -2,7 +2,7 @@ This directory includes an example project demonstrating the use of JAX's foreign function interface (FFI). The JAX docs provide more information about -this interface in [the FFI tutorial](https://jax.readthedocs.io/en/latest/ffi.html), +this interface in [the FFI tutorial](https://docs.jax.dev/en/latest/ffi.html), but the example in this directory complements that document by demonstrating (and testing!) the full packaging workflow, and some more advanced use cases. Within the example project, there are several example calls: diff --git a/examples/ffi/pyproject.toml b/examples/ffi/pyproject.toml index 130dd91bbc70..84e2c4700500 100644 --- a/examples/ffi/pyproject.toml +++ b/examples/ffi/pyproject.toml @@ -5,7 +5,7 @@ build-backend = "scikit_build_core.build" [project] name = "jax_ffi_example" version = "0.0.1" -requires-python = ">=3.10" +requires-python = ">=3.11" dependencies = ["jax"] [project.optional-dependencies] diff --git a/examples/ffi/src/jax_ffi_example/gpu_examples.cc b/examples/ffi/src/jax_ffi_example/gpu_examples.cc index 921039debe5d..79a4ee91e8c6 100644 --- a/examples/ffi/src/jax_ffi_example/gpu_examples.cc +++ b/examples/ffi/src/jax_ffi_example/gpu_examples.cc @@ -16,8 +16,8 @@ limitations under the License. #include #include -#include "nanobind/nanobind.h" #include "cuda_runtime_api.h" +#include "nanobind/nanobind.h" #include "xla/ffi/api/ffi.h" namespace nb = nanobind; diff --git a/examples/ffi/src/jax_ffi_example/rms_norm.cc b/examples/ffi/src/jax_ffi_example/rms_norm.cc index 819f3b9f868d..bcfc1eb67aa4 100644 --- a/examples/ffi/src/jax_ffi_example/rms_norm.cc +++ b/examples/ffi/src/jax_ffi_example/rms_norm.cc @@ -16,8 +16,6 @@ limitations under the License. #include #include #include -#include -#include #include #include diff --git a/examples/ffi/src/jax_ffi_example/rms_norm.py b/examples/ffi/src/jax_ffi_example/rms_norm.py index 6dbfe5043ddf..996eb9e5d935 100644 --- a/examples/ffi/src/jax_ffi_example/rms_norm.py +++ b/examples/ffi/src/jax_ffi_example/rms_norm.py @@ -14,9 +14,9 @@ """An example demontrating the basic end-to-end use of the JAX FFI. This example is exactly the same as the one in the `FFI tutorial -`, so more details can be found +`, so more details can be found on that page. 
But, the high level summary is that we implement our custom -extension in ``rms_norm.cc``, then call it usin ``jax.ffi.ffi_call`` in +extension in ``rms_norm.cc``, then call it using ``jax.ffi.ffi_call`` in this module. The behavior under autodiff is implemented using ``jax.custom_vjp``. """ diff --git a/examples/jax_cpp/BUILD b/examples/jax_cpp/BUILD index b3cb995aae21..86f3129c9876 100644 --- a/examples/jax_cpp/BUILD +++ b/examples/jax_cpp/BUILD @@ -21,6 +21,7 @@ cc_binary( srcs = ["main.cc"], tags = ["manual"], deps = [ + "@com_google_absl//absl/log", "@com_google_absl//absl/status:statusor", "@tsl//tsl/platform:logging", "@tsl//tsl/platform:platform_port", @@ -33,6 +34,7 @@ cc_binary( "@xla//xla/pjrt/plugin/xla_cpu:cpu_client_options", "@xla//xla/pjrt/plugin/xla_cpu:xla_cpu_pjrt_client", "@xla//xla/service:hlo_module_config", + "@xla//xla/service:hlo_proto_cc", "@xla//xla/tools:hlo_module_loader", ], ) diff --git a/examples/jax_cpp/main.cc b/examples/jax_cpp/main.cc index 0a1d3a63acfd..8deea5448fec 100644 --- a/examples/jax_cpp/main.cc +++ b/examples/jax_cpp/main.cc @@ -41,7 +41,8 @@ limitations under the License. #include #include -#include "third_party/absl/status/statusor.h" +#include "absl/log/log.h" +#include "absl/status/statusor.h" #include "xla/hlo/builder/xla_computation.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/literal.h" @@ -50,6 +51,7 @@ limitations under the License. #include "xla/pjrt/pjrt_executable.h" #include "xla/pjrt/plugin/xla_cpu/cpu_client_options.h" #include "xla/pjrt/plugin/xla_cpu/xla_cpu_pjrt_client.h" +#include "xla/service/hlo.pb.h" #include "xla/service/hlo_module_config.h" #include "xla/tools/hlo_module_loader.h" #include "tsl/platform/init_main.h" diff --git a/examples/k8s/example.yaml b/examples/k8s/example.yaml new file mode 100644 index 000000000000..9039626e9c82 --- /dev/null +++ b/examples/k8s/example.yaml @@ -0,0 +1,39 @@ +apiVersion: jobset.x-k8s.io/v1alpha2 +kind: JobSet +metadata: + name: jaxjob +spec: + replicatedJobs: + - name: workers + template: + spec: + parallelism: 2 + completions: 2 + backoffLimit: 0 + template: + spec: + serviceAccountName: jax-job-sa # kubectl apply -f svc-acct.yaml + restartPolicy: Never + imagePullSecrets: + # https://k8s.io/docs/tasks/configure-pod-container/pull-image-private-registry/ + - name: null + containers: + - name: main + image: null # e.g. 
ghcr.io/nvidia/jax:jax + imagePullPolicy: Always + resources: + limits: + cpu: 900m + # https://k8s.io/docs/tasks/manage-gpus/scheduling-gpus/ + nvidia.com/gpu: null + command: + - python + args: + - -c + - | + import jax + jax.distributed.initialize() + print(jax.devices()) + print(jax.local_devices()) + assert jax.process_count() > 1 + assert len(jax.devices()) > len(jax.local_devices()) diff --git a/examples/k8s/svc-acct.yaml b/examples/k8s/svc-acct.yaml new file mode 100644 index 000000000000..c1523964c515 --- /dev/null +++ b/examples/k8s/svc-acct.yaml @@ -0,0 +1,31 @@ +apiVersion: v1 +kind: ServiceAccount +metadata: + name: jax-job-sa + namespace: default +--- +apiVersion: rbac.authorization.k8s.io/v1 +kind: Role +metadata: + name: pod-reader +rules: + - apiGroups: [""] + resources: ["pods", "services"] + verbs: ["get", "list", "watch"] + - apiGroups: ["batch"] + resources: ["jobs"] + verbs: ["get", "list", "watch"] +--- +apiVersion: rbac.authorization.k8s.io/v1 +kind: RoleBinding +metadata: + name: pod-reader-binding + namespace: default +subjects: + - kind: ServiceAccount + name: jax-job-sa + namespace: default +roleRef: + kind: Role + name: pod-reader + apiGroup: rbac.authorization.k8s.io diff --git a/examples/spmd_mnist_classifier_fromscratch.py b/examples/spmd_mnist_classifier_fromscratch.py index 3698314708c7..c5c85b2aff37 100644 --- a/examples/spmd_mnist_classifier_fromscratch.py +++ b/examples/spmd_mnist_classifier_fromscratch.py @@ -12,33 +12,30 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""An MNIST example with single-program multiple-data (SPMD) data parallelism. - -The aim here is to illustrate how to use JAX's `pmap` to express and execute -SPMD programs for data parallelism along a batch dimension, while also -minimizing dependencies by avoiding the use of higher-level layers and -optimizers libraries. 
-""" - - from functools import partial import time +from jax import NamedSharding import numpy as np import numpy.random as npr - import jax -from jax import jit, grad, pmap +from jax import jit, grad +from jax.experimental.shard import reshard +from jax.sharding import ( + PartitionSpec as P, + AxisType, +) from jax.scipy.special import logsumexp -from jax.tree_util import tree_map -from jax import lax import jax.numpy as jnp -from examples import datasets +import datasets def init_random_params(scale, layer_sizes, rng=npr.RandomState(0)): - return [(scale * rng.randn(m, n), scale * rng.randn(n)) - for m, n, in zip(layer_sizes[:-1], layer_sizes[1:])] + return [ + (scale * rng.randn(m, n), scale * rng.randn(n)) + for m, n in zip(layer_sizes[:-1], layer_sizes[1:]) + ] + def predict(params, inputs): activations = inputs @@ -50,11 +47,21 @@ def predict(params, inputs): logits = jnp.dot(activations, final_w) + final_b return logits - logsumexp(logits, axis=1, keepdims=True) + def loss(params, batch): inputs, targets = batch preds = predict(params, inputs) return -jnp.mean(jnp.sum(preds * targets, axis=1)) + +@partial(jax.jit, donate_argnums=0) +def train_step(params, batch): + grads = grad(loss)(params, batch) + return [ + (w - step_size * dw, b - step_size * db) for (w, b), (dw, db) in zip(params, grads) + ] + + @jit def accuracy(params, batch): inputs, targets = batch @@ -72,57 +79,72 @@ def accuracy(params, batch): train_images, train_labels, test_images, test_labels = datasets.mnist() num_train = train_images.shape[0] + + num_devices = jax.device_count() + print(f"Using {num_devices} devices") + + if batch_size % num_devices != 0: + batch_size = (batch_size // num_devices) * num_devices + print(f"Adjusting batch size to {batch_size} for divisibility") + num_complete_batches, leftover = divmod(num_train, batch_size) num_batches = num_complete_batches + bool(leftover) - # For this manual SPMD example, we get the number of devices (e.g. GPUs or - # TPU cores) that we're using, and use it to reshape data minibatches. - num_devices = jax.device_count() + devices = np.array(jax.devices()) + mesh = jax.make_mesh( + (jax.device_count(),), ("batch",), axis_types=(AxisType.Explicit,) + ) + + replicated_sharding = NamedSharding(mesh, P()) + data_sharding = NamedSharding(mesh, P("batch")) + def data_stream(): rng = npr.RandomState(0) while True: perm = rng.permutation(num_train) for i in range(num_batches): - batch_idx = perm[i * batch_size:(i + 1) * batch_size] - images, labels = train_images[batch_idx], train_labels[batch_idx] - # For this SPMD example, we reshape the data batch dimension into two - # batch dimensions, one of which is mapped over parallel devices. - batch_size_per_device, ragged = divmod(images.shape[0], num_devices) - if ragged: - msg = "batch size must be divisible by device count, got {} and {}." 
- raise ValueError(msg.format(batch_size, num_devices)) - shape_prefix = (num_devices, batch_size_per_device) - images = images.reshape(shape_prefix + images.shape[1:]) - labels = labels.reshape(shape_prefix + labels.shape[1:]) + batch_idx = perm[i * batch_size : (i + 1) * batch_size] + images_np, labels_np = train_images[batch_idx], train_labels[batch_idx] + + current_batch_size = images_np.shape[0] + if current_batch_size < batch_size: + pad_len = batch_size - current_batch_size + images_np = np.concatenate([images_np, images_np[:pad_len]], axis=0) + labels_np = np.concatenate([labels_np, labels_np[:pad_len]], axis=0) + + images = jax.device_put(images_np, data_sharding) + labels = jax.device_put(labels_np, data_sharding) yield images, labels + batches = data_stream() - @partial(pmap, axis_name='batch') - def spmd_update(params, batch): - grads = grad(loss)(params, batch) - # We compute the total gradients, summing across the device-mapped axis, - # using the `lax.psum` SPMD primitive, which does a fast all-reduce-sum. - grads = [(lax.psum(dw, 'batch'), lax.psum(db, 'batch')) for dw, db in grads] - return [(w - step_size * dw, b - step_size * db) - for (w, b), (dw, db) in zip(params, grads)] - - # We replicate the parameters so that the constituent arrays have a leading - # dimension of size equal to the number of devices we're pmapping over. - init_params = init_random_params(param_scale, layer_sizes) - replicate_array = lambda x: np.broadcast_to(x, (num_devices,) + x.shape) - replicated_params = tree_map(replicate_array, init_params) + params = init_random_params(param_scale, layer_sizes) + replicated_params = jax.device_put(params, replicated_sharding) for epoch in range(num_epochs): start_time = time.time() - for _ in range(num_batches): - replicated_params = spmd_update(replicated_params, next(batches)) + for i in range(num_batches - 1): + print(f"Batch no {i+1} of {num_batches}") + batch = next(batches) + with jax.sharding.use_mesh(mesh): + replicated_params = train_step(replicated_params, batch) epoch_time = time.time() - start_time - # We evaluate using the jitted `accuracy` function (not using pmap) by - # grabbing just one of the replicated parameter values. 
- params = tree_map(lambda x: x[0], replicated_params) - train_acc = accuracy(params, (train_images, train_labels)) - test_acc = accuracy(params, (test_images, test_labels)) + # Reshard train_images, train_labels, test_images, test_labels + sharded_train_images = reshard(train_images, data_sharding) + sharded_train_labels = reshard(train_labels, data_sharding) + sharded_test_images = reshard(test_images, data_sharding) + sharded_test_labels = reshard(test_labels, data_sharding) + + train_acc = accuracy( + replicated_params, (sharded_train_images, sharded_train_labels) + ) + test_acc = accuracy(replicated_params, (sharded_test_images, sharded_test_labels)) print(f"Epoch {epoch} in {epoch_time:0.2f} sec") print(f"Training set accuracy {train_acc}") print(f"Test set accuracy {test_acc}") + + if epoch < num_epochs - 1: + batches = data_stream() + print(f"Batch no {0} of {num_batches}") + replicated_params = train_step(replicated_params, next(batches)) diff --git a/jax/BUILD b/jax/BUILD index 12eae4afdcf7..44e6faf896c3 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -18,6 +18,7 @@ load("@bazel_skylib//rules:common_settings.bzl", "string_flag") load("@rules_python//python:defs.bzl", "py_library") load( "//jaxlib:jax.bzl", + "buffer_callback_internal_users", "if_building_jaxlib", "jax_export_file_visibility", "jax_extend_internal_users", @@ -63,12 +64,63 @@ string_flag( ) config_setting( - name = "enable_jaxlib_build", + name = "config_build_jaxlib_true", flag_values = { ":build_jaxlib": "true", }, ) +config_setting( + name = "config_build_jaxlib_false", + flag_values = { + ":build_jaxlib": "false", + }, +) + +config_setting( + name = "config_build_jaxlib_wheel", + flag_values = { + ":build_jaxlib": "wheel", + }, +) + +# The flag controls whether jax should be built by Bazel. +# If ":build_jax=true", then jax will be built. +# If ":build_jax=false", then jax is not built. It is assumed that the pre-built jax wheel +# is available in the "dist" folder. +# If ":build_jax=wheel", then jax wheel will be built as a py_import rule attribute. +# The py_import rule unpacks the wheel and provides its content as a py_library. +string_flag( + name = "build_jax", + build_setting_default = "true", + values = [ + "true", + "false", + "wheel", + ], +) + +config_setting( + name = "config_build_jax_true", + flag_values = { + ":build_jax": "true", + }, +) + +config_setting( + name = "config_build_jax_false", + flag_values = { + ":build_jax": "false", + }, +) + +config_setting( + name = "config_build_jax_wheel", + flag_values = { + ":build_jax": "wheel", + }, +) + exports_files([ "LICENSE", "version.py", @@ -93,7 +145,7 @@ package_group( includes = [":internal"], packages = [ # Intentionally avoid jax dependencies on jax.extend. - # See https://jax.readthedocs.io/en/latest/jep/15856-jex.html + # See https://docs.jax.dev/en/latest/jep/15856-jex.html "//tests/...", ] + jax_extend_internal_users, ) @@ -134,6 +186,12 @@ package_group( packages = serialize_executable_internal_users, ) +package_group( + name = "buffer_callback_users", + includes = [":internal"], + packages = buffer_callback_internal_users, +) + # JAX-private test utilities. py_library( # This build target is required in order to use private test utilities in jax._src.test_util, @@ -142,6 +200,7 @@ py_library( # these are available in jax.test_util via the standard :jax target. 
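Stepping back from the patch for a moment: the `examples/spmd_mnist_classifier_fromscratch.py` rewrite above replaces `pmap`/`lax.psum` data parallelism with explicit sharding (`jax.make_mesh` with an explicit "batch" axis, `NamedSharding`, `jax.device_put`, and `jax.sharding.use_mesh`). Below is a minimal sketch of that same pattern distilled from the example; the toy least-squares model and all names (`train_step`, `w`, `x`, `y`) are illustrative, not part of the diff.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import NamedSharding, PartitionSpec as P, AxisType

# A 1-D device mesh whose "batch" axis is tracked explicitly in array types.
mesh = jax.make_mesh((jax.device_count(),), ("batch",),
                     axis_types=(AxisType.Explicit,))
replicated = NamedSharding(mesh, P())            # same value on every device
batch_sharded = NamedSharding(mesh, P("batch"))  # leading axis split over devices

def loss(w, x, y):
  return jnp.mean((x @ w - y) ** 2)

@jax.jit
def train_step(w, x, y, lr=1e-2):
  # Gradients and the update inherit shardings from the explicitly typed inputs,
  # so no manual psum/all-reduce is written here (mirroring the example above).
  return w - lr * jax.grad(loss)(w, x, y)

n = 8 * jax.device_count()  # keep the batch divisible by the mesh size
w = jax.device_put(np.zeros((4, 1), np.float32), replicated)
x = jax.device_put(np.ones((n, 4), np.float32), batch_sharded)
y = jax.device_put(np.ones((n, 1), np.float32), batch_sharded)

with jax.sharding.use_mesh(mesh):
  w = train_step(w, x, y)
```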
name = "test_util", srcs = [ + "_src/test_loader.py", "_src/test_util.py", "_src/test_warning_util.py", ], @@ -151,6 +210,7 @@ py_library( deps = [ ":compilation_cache_internal", ":jax", + ":public_test_util", ] + py_deps("absl/testing") + py_deps("numpy"), ) @@ -167,22 +227,41 @@ py_library( ], ), visibility = [":internal"], - deps = [ - ":jax", - ] + py_deps("numpy"), + deps = if_building_jaxlib( + if_building = [ + ":jax", + ], + if_not_building = [], + ) + py_deps("numpy"), ) py_library( name = "internal_test_harnesses", srcs = ["_src/internal_test_util/test_harnesses.py"], visibility = [":internal"] + jax_internal_test_harnesses_visibility, - deps = [ - ":ad_util", - ":config", - ":jax", - ":test_util", - "//jax/_src/lib", - ] + py_deps("numpy"), + deps = if_building_jaxlib( + if_building = [ + ":ad_util", + ":config", + ":jax", + ":test_util", + "//jax/_src/lib", + ], + if_not_building = [], + ) + py_deps("numpy"), +) + +py_library( + name = "test_multiprocess", + srcs = ["_src/test_multiprocess.py"], + visibility = [":internal"], + deps = if_building_jaxlib( + if_building = [ + ":jax", + ":test_util", + ], + if_not_building = [], + ), ) py_library( @@ -191,15 +270,17 @@ py_library( visibility = [ ":internal", ] + jax_internal_export_back_compat_test_util_visibility, - deps = [ - ":jax", - ":test_util", - ] + py_deps("numpy"), + deps = if_building_jaxlib( + if_building = [ + ":jax", + ":test_util", + ], + if_not_building = [], + ) + py_deps("numpy"), ) py_library( name = "internal_export_back_compat_test_data", - testonly = 1, srcs = glob([ "_src/internal_test_util/export_back_compat_test_data/*.py", "_src/internal_test_util/export_back_compat_test_data/pallas/*.py", @@ -214,52 +295,31 @@ py_library_providing_imports_info( name = "jax", srcs = [ "_src/__init__.py", - "_src/ad_checkpoint.py", - "_src/api.py", - "_src/array.py", + "_src/ad_checkpoint.py", # TODO(vanderplas): remove once downstream users depend on :lax "_src/blocked_sampler.py", - "_src/callback.py", "_src/checkify.py", - "_src/custom_batching.py", - "_src/custom_dce.py", - "_src/custom_derivatives.py", - "_src/custom_partitioning.py", - "_src/custom_partitioning_sharding_rule.py", - "_src/custom_transpose.py", "_src/debugging.py", - "_src/dispatch.py", "_src/dlpack.py", - "_src/earray.py", "_src/error_check.py", - "_src/ffi.py", "_src/flatten_util.py", "_src/interpreters/__init__.py", - "_src/interpreters/ad.py", - "_src/interpreters/batching.py", - "_src/interpreters/pxla.py", - "_src/pjit.py", "_src/prng.py", - "_src/public_test_util.py", "_src/random.py", - "_src/shard_alike.py", - "_src/sourcemap.py", - "_src/stages.py", - "_src/tree.py", + "_src/shard_map.py", ] + glob( [ "*.py", "_src/cudnn/**/*.py", "_src/debugger/**/*.py", - "_src/extend/**/*.py", "_src/image/**/*.py", - "_src/export/**/*.py", - "_src/lax/**/*.py", + "_src/lax/**/*.py", # TODO(vanderplas): remove once downstream users depend on :lax "_src/nn/**/*.py", "_src/numpy/**/*.py", "_src/ops/**/*.py", "_src/scipy/**/*.py", - "_src/state/**/*.py", + "_src/state/**/*.py", # TODO(vanderplas): remove once downstream users depend on :lax and :state_types "_src/third_party/**/*.py", + "_src/tpu/**/*.py", "experimental/key_reuse/**/*.py", "experimental/roofline/**/*.py", "image/**/*.py", @@ -299,9 +359,15 @@ py_library_providing_imports_info( visibility = ["//visibility:public"], deps = [ ":abstract_arrays", + ":ad", ":ad_util", + ":api", ":api_util", + ":attrs", ":basearray", + ":batching", + ":buffer_callback", + ":callback", ":cloud_tpu_init", 
":compilation_cache_internal", ":compiler", @@ -309,12 +375,23 @@ py_library_providing_imports_info( ":config", ":core", ":custom_api_util", + ":custom_batching", + ":custom_dce", + ":custom_derivatives", + ":custom_partitioning", + ":custom_partitioning_sharding_rule", + ":custom_transpose", ":deprecations", ":dtypes", + ":earray", ":effects", ":environment_info", + ":export", + ":ffi", + ":hashable_array", ":internal_mesh_utils", ":jaxpr_util", + ":lax", ":layout", ":lazy_loader", ":mesh", @@ -328,11 +405,16 @@ py_library_providing_imports_info( ":pickle_util", ":pretty_printer", ":profiler", + ":public_test_util", + ":shard_alike", ":sharding", ":sharding_impls", ":sharding_specs", ":source_info_util", + ":sourcemap", + ":stages", ":traceback_util", + ":tree", ":tree_util", ":typing", ":util", @@ -367,6 +449,53 @@ pytype_strict_library( ], ) +pytype_strict_library( + name = "api", + srcs = [ + "_src/api.py", + "_src/array.py", + "_src/dispatch.py", + "_src/interpreters/pxla.py", + "_src/pjit.py", + ], + visibility = [":internal"] + jax_visibility("api"), + deps = [ + ":abstract_arrays", + ":ad", + ":api_util", + ":attrs", + ":basearray", + ":batching", + ":compiler", + ":config", + ":core", + ":deprecations", + ":dtypes", + ":effects", + ":layout", + ":mesh", + ":mlir", + ":monitoring", + ":op_shardings", + ":partial_eval", + ":partition_spec", + ":profiler", + ":sharding", + ":sharding_impls", + ":sharding_specs", + ":source_info_util", + ":stages", + ":state_types", + ":traceback_util", + ":tree_util", + ":typing", + ":util", + ":xla", + ":xla_bridge", + "//jax/_src/lib", + ] + py_deps("numpy"), +) + pytype_strict_library( name = "api_util", srcs = ["_src/api_util.py"], @@ -382,13 +511,73 @@ pytype_strict_library( ] + py_deps("numpy"), ) +pytype_strict_library( + name = "attrs", + srcs = ["_src/attrs.py"], + deps = [ + ":ad", + ":ad_util", + ":api_util", + ":core", + ":dtypes", + ":partial_eval", + ":source_info_util", + ":tree_util", + ":util", + ], +) + pytype_strict_library( name = "basearray", srcs = ["_src/basearray.py"], pytype_srcs = ["_src/basearray.pyi"], deps = [ + ":named_sharding", ":partition_spec", ":sharding", + ":util", + "//jax/_src/lib", + ] + py_deps("numpy"), +) + +pytype_strict_library( + name = "buffer_callback", + srcs = ["_src/buffer_callback.py"], + deps = [ + ":ad", + ":api", + ":batching", + ":core", + ":effects", + ":ffi", + ":mlir", + ":tree_util", + ":util", + "//jax/_src/lib", + ] + py_deps("numpy"), +) + +pytype_strict_library( + name = "callback", + srcs = ["_src/callback.py"], + deps = [ + ":ad", + ":api", + ":batching", + ":config", + ":core", + ":dtypes", + ":effects", + ":ffi", + ":mlir", + ":pickle_util", + ":sharding", + ":sharding_impls", + ":tree_util", + ":typing", + ":util", + ":xla", + ":xla_bridge", "//jax/_src/lib", ] + py_deps("numpy"), ) @@ -437,6 +626,59 @@ pytype_strict_library( ], ) +py_library_providing_imports_info( + name = "lax", + srcs = glob( + [ + "_src/lax/**/*.py", + "_src/state/**/*.py", + ], + exclude = [ + # These are included in :state_types. 
+ "_src/state/__init__.py", + "_src/state/indexing.py", + "_src/state/types.py", + ], + ) + [ + "_src/ad_checkpoint.py", + ], + visibility = [":internal"] + jax_visibility("lax"), + deps = [ + ":abstract_arrays", + ":ad", + ":ad_util", + ":api", + ":api_util", + ":attrs", + ":batching", + ":callback", + ":config", + ":core", + ":custom_derivatives", + ":custom_partitioning_sharding_rule", + ":dtypes", + ":effects", + ":ffi", + ":mesh", + ":mlir", + ":named_sharding", + ":partial_eval", + ":partition_spec", + ":pretty_printer", + ":sharding", + ":sharding_impls", + ":source_info_util", + ":state_types", + ":traceback_util", + ":tree_util", + ":typing", + ":util", + ":xla", + ":xla_bridge", + "//jax/_src/lib", + ] + py_deps("numpy"), +) + pytype_strict_library( name = "lru_cache", srcs = ["_src/lru_cache.py"], @@ -463,6 +705,7 @@ pytype_strict_library( pytype_strict_library( name = "compiler", srcs = ["_src/compiler.py"], + visibility = [":internal"] + jax_visibility("compiler"), deps = [ ":cache_key", ":compilation_cache_internal", @@ -509,6 +752,116 @@ pytype_strict_library( srcs = ["_src/custom_api_util.py"], ) +pytype_strict_library( + name = "custom_batching", + srcs = ["_src/custom_batching.py"], + deps = [ + ":ad", + ":api", + ":api_util", + ":batching", + ":core", + ":custom_api_util", + ":mlir", + ":partial_eval", + ":source_info_util", + ":traceback_util", + ":tree_util", + ":util", + ":xla", + ], +) + +pytype_strict_library( + name = "custom_dce", + srcs = ["_src/custom_dce.py"], + deps = [ + ":ad", + ":api_util", + ":batching", + ":core", + ":custom_api_util", + ":mlir", + ":partial_eval", + ":source_info_util", + ":traceback_util", + ":tree_util", + ":util", + ], +) + +pytype_strict_library( + name = "custom_derivatives", + srcs = ["_src/custom_derivatives.py"], + deps = [ + ":ad", + ":ad_util", + ":api_util", + ":batching", + ":config", + ":core", + ":custom_api_util", + ":custom_transpose", + ":dtypes", + ":effects", + ":mlir", + ":partial_eval", + ":state_types", + ":traceback_util", + ":tree_util", + ":util", + ":xla", + ], +) + +pytype_strict_library( + name = "custom_partitioning", + srcs = ["_src/custom_partitioning.py"], + deps = [ + ":api", + ":api_util", + ":config", + ":core", + ":custom_api_util", + ":custom_partitioning_sharding_rule", + ":mesh", + ":mlir", + ":partial_eval", + ":sharding", + ":sharding_impls", + ":tree_util", + ":xla_bridge", + "//jax/_src/lib", + ] + py_deps("numpy"), +) + +pytype_strict_library( + name = "custom_partitioning_sharding_rule", + srcs = ["_src/custom_partitioning_sharding_rule.py"], + deps = [ + "//jax/_src/lib", + ], +) + +pytype_strict_library( + name = "custom_transpose", + srcs = ["_src/custom_transpose.py"], + deps = [ + ":ad", + ":ad_util", + ":api_util", + ":core", + ":custom_api_util", + ":mlir", + ":partial_eval", + ":source_info_util", + ":traceback_util", + ":tree_util", + ":util", + ":xla", + ], +) + pytype_strict_library( name = "deprecations", srcs = ["_src/deprecations.py"], @@ -528,6 +881,21 @@ pytype_strict_library( ] + py_deps("ml_dtypes") + py_deps("numpy"), ) +pytype_strict_library( + name = "earray", + srcs = ["_src/earray.py"], + deps = [ + ":api", + ":basearray", + ":core", + ":sharding_impls", + ":tree_util", + ":util", + ":xla", + "//jax/_src/lib", + ] + py_deps("numpy"), +) + pytype_strict_library( name = "effects", srcs = ["_src/effects.py"], @@ -543,11 +911,69 @@ pytype_strict_library( ] + py_deps("numpy"), ) +pytype_strict_library( + name = "export", + srcs = glob([ + "_src/export/**/*.py", + ]), + 
visibility = [":internal"] + jax_visibility("export"), + deps = [ + ":ad_util", + ":api", + ":config", + ":core", + ":custom_derivatives", + ":dtypes", + ":effects", + ":mesh", + ":mlir", + ":sharding", + ":sharding_impls", + ":source_info_util", + ":stages", + ":tree_util", + ":typing", + ":util", + ":xla_bridge", + "//jax/_src/lib", + ] + py_deps("flatbuffers") + py_deps("numpy") + py_deps("opt_einsum"), +) + +pytype_strict_library( + name = "ffi", + srcs = ["_src/ffi.py"], + deps = [ + ":ad", + ":api", + ":batching", + ":core", + ":effects", + ":hashable_array", + ":layout", + ":mlir", + ":typing", + ":util", + ":xla_bridge", + "//jax/_src/lib", + ] + py_deps("numpy"), +) + +pytype_strict_library( + name = "frozen_dict", + srcs = ["_src/frozen_dict.py"], +) + pytype_strict_library( name = "hardware_utils", srcs = ["_src/hardware_utils.py"], ) +pytype_strict_library( + name = "hashable_array", + srcs = ["_src/hashable_array.py"], + deps = py_deps("numpy"), +) + pytype_library( name = "lax_reference", srcs = ["_src/lax_reference.py"], @@ -585,6 +1011,41 @@ pytype_strict_library( ] + py_deps("numpy"), ) +pytype_strict_library( + name = "ad", + srcs = ["_src/interpreters/ad.py"], + deps = [ + ":ad_util", + ":api_util", + ":config", + ":core", + ":dtypes", + ":mesh", + ":partial_eval", + ":source_info_util", + ":tree_util", + ":util", + ], +) + +pytype_strict_library( + name = "batching", + srcs = ["_src/interpreters/batching.py"], + deps = [ + ":ad_util", + ":config", + ":core", + ":mesh", + ":partial_eval", + ":partition_spec", + ":sharding_impls", + ":source_info_util", + ":tree_util", + ":typing", + ":util", + ] + py_deps("numpy"), +) + pytype_strict_library( name = "mlir", srcs = ["_src/interpreters/mlir.py"], @@ -595,7 +1056,10 @@ pytype_strict_library( ":core", ":dtypes", ":effects", + ":hashable_array", + ":jaxpr_util", ":layout", + ":mesh", ":op_shardings", ":partial_eval", ":partition_spec", @@ -635,6 +1099,11 @@ pytype_strict_library( ], ) +pytype_strict_library( + name = "sourcemap", + srcs = ["_src/sourcemap.py"], +) + pytype_strict_library( name = "source_mapper", srcs = glob(include = ["experimental/source_mapper/**/*.py"]), @@ -646,6 +1115,7 @@ pytype_strict_library( ":core", ":jax", ":source_info_util", + ":sourcemap", ] + py_deps("absl/flags"), ) @@ -656,7 +1126,6 @@ pytype_strict_library( "experimental/pallas/**/*.py", ], exclude = [ - "experimental/pallas/gpu.py", "experimental/pallas/mosaic_gpu.py", "experimental/pallas/ops/gpu/**/*.py", "experimental/pallas/ops/tpu/**/*.py", @@ -671,7 +1140,9 @@ pytype_strict_library( deps = [ ":deprecations", ":jax", + ":lax", ":source_info_util", + ":state_types", "//jax/_src/pallas", ] + py_deps("numpy"), ) @@ -708,7 +1179,7 @@ pytype_strict_library( ":pallas", # build_cleaner: keep "//jax/_src/pallas/fuser:block_spec", "//jax/_src/pallas/fuser:custom_evaluate", - "//jax/_src/pallas/fuser:fusable", + "//jax/_src/pallas/fuser:fusible", "//jax/_src/pallas/fuser:fusion", "//jax/_src/pallas/fuser:jaxpr_fusion", ], @@ -739,6 +1210,7 @@ pytype_strict_library( ":pallas", ":pallas_mosaic_gpu", ":test_util", # This is only to make them runnable as jax_multiplatform_test... 
+ "//jax/_src/lib", ] + py_deps("numpy"), ) @@ -749,6 +1221,7 @@ pytype_strict_library( ":pallas_tpu_users", ], deps = [ + ":dtypes", ":jax", ":pallas", ":pallas_tpu", @@ -769,7 +1242,6 @@ pytype_strict_library( pytype_strict_library( name = "pallas_triton", srcs = [ - "experimental/pallas/gpu.py", "experimental/pallas/triton.py", ], visibility = [ @@ -792,6 +1264,7 @@ pytype_strict_library( deps = [ ":mosaic_gpu", "//jax/_src/pallas/mosaic_gpu:core", + "//jax/_src/pallas/mosaic_gpu:helpers", "//jax/_src/pallas/mosaic_gpu:pallas_call_registration", # build_cleaner: keep "//jax/_src/pallas/mosaic_gpu:pipeline", "//jax/_src/pallas/mosaic_gpu:primitives", @@ -802,6 +1275,12 @@ pytype_strict_library( py_library_providing_imports_info( name = "mosaic_gpu", srcs = glob(["experimental/mosaic/gpu/*.py"]), + data = [ + "@cuda_nvcc//:nvdisasm", + "@cuda_nvcc//:nvvm", + "@cuda_nvcc//:ptxas", + "@nvidia_nvshmem//:libnvshmem_device", + ], visibility = [ ":mosaic_gpu_users", ], @@ -813,6 +1292,7 @@ py_library_providing_imports_info( "//jax/_src/lib", "//jaxlib/mlir:arithmetic_dialect", "//jaxlib/mlir:builtin_dialect", + "//jaxlib/mlir:control_flow_dialect", "//jaxlib/mlir:func_dialect", "//jaxlib/mlir:gpu_dialect", "//jaxlib/mlir:ir", @@ -851,6 +1331,10 @@ pytype_strict_library( pytype_strict_library( name = "partition_spec", srcs = ["_src/partition_spec.py"], + deps = [ + ":util", + "//jax/_src/lib", + ], ) pytype_strict_library( @@ -889,7 +1373,8 @@ pytype_strict_library( deps = [ ":config", ":util", - ] + py_deps("colorama"), + "//jax/_src/lib", + ], ) pytype_strict_library( @@ -902,6 +1387,19 @@ pytype_strict_library( ], ) +pytype_strict_library( + name = "public_test_util", + srcs = [ + "_src/public_test_util.py", + ], + deps = [ + ":api", + ":config", + ":dtypes", + ":tree_util", + ] + py_deps("numpy"), +) + pytype_strict_library( name = "sharding", srcs = ["_src/sharding.py"], @@ -913,6 +1411,44 @@ pytype_strict_library( ], ) +pytype_strict_library( + name = "shard_alike", + srcs = [ + "_src/shard_alike.py", + ], + deps = [ + ":ad", + ":api", + ":batching", + ":config", + ":core", + ":mlir", + ":tree_util", + ":util", + "//jax/_src/lib", + ], +) + +pytype_strict_library( + name = "stages", + srcs = ["_src/stages.py"], + visibility = [":internal"] + jax_visibility("stages"), + deps = [ + ":config", + ":core", + ":layout", + ":mlir", + ":sharding", + ":sharding_impls", + ":source_info_util", + ":traceback_util", + ":tree_util", + ":typing", + ":util", + "//jax/_src/lib", + ], +) + pytype_strict_library( name = "compute_on", srcs = ["_src/compute_on.py"], @@ -945,6 +1481,7 @@ pytype_strict_library( pytype_strict_library( name = "sharding_impls", srcs = ["_src/sharding_impls.py"], + visibility = [":internal"] + jax_visibility("sharding_impls"), deps = [ ":config", ":core", @@ -1014,6 +1551,7 @@ pytype_strict_library( "_src/state/indexing.py", "_src/state/types.py", ], + visibility = [":internal"] + jax_visibility("state_types"), deps = [ ":core", ":dtypes", @@ -1026,6 +1564,15 @@ pytype_strict_library( ] + py_deps("numpy"), ) +pytype_strict_library( + name = "tree", + srcs = ["_src/tree.py"], + deps = [ + ":tree_util", + "//jax/_src/lib", + ], +) + pytype_strict_library( name = "tree_util", srcs = ["_src/tree_util.py"], @@ -1061,6 +1608,7 @@ pytype_strict_library( srcs = ["_src/tpu_custom_call.py"], visibility = [":internal"], deps = [ + ":cloud_tpu_init", ":config", ":core", ":jax", @@ -1097,6 +1645,7 @@ pytype_strict_library( ":abstract_arrays", ":config", ":core", + ":deprecations", ":dtypes", 
":sharding_impls", ":source_info_util", @@ -1129,7 +1678,7 @@ pytype_strict_library( ":traceback_util", ":util", "//jax/_src/lib", - ], + ] + py_deps("numpy"), ) # Public JAX libraries below this point. @@ -1141,9 +1690,14 @@ py_library_providing_imports_info( "experimental/*.py", "example_libraries/*.py", ], + [ + "experimental/buffer_callback.py", + "experimental/mosaic/gpu/*.py", + ], ), visibility = ["//visibility:public"], deps = [ + ":buffer_callback", ":jax", ] + py_deps("absl/logging") + py_deps("numpy"), ) @@ -1157,6 +1711,12 @@ pytype_library( deps = [":jax"], ) +pytype_library( + name = "experimental_shard", + srcs = ["experimental/shard.py"], + deps = [":jax"], +) + pytype_library( name = "experimental_sparse", srcs = glob( @@ -1166,7 +1726,10 @@ pytype_library( exclude = ["experimental/sparse/test_util.py"], ), visibility = ["//visibility:public"], - deps = [":jax"], + deps = [ + ":ffi", + ":jax", + ], ) pytype_library( @@ -1198,17 +1761,6 @@ pytype_library( deps = [":jax"], ) -# TODO(apaszke): Remove this target -pytype_library( - name = "pjit", - srcs = ["experimental/pjit.py"], - visibility = ["//visibility:public"], - deps = [ - ":experimental", - ":jax", - ], -) - pytype_library( name = "jet", srcs = ["experimental/jet.py"], @@ -1248,6 +1800,12 @@ pytype_library( ], ) +pytype_strict_library( + name = "extend_src", + srcs = glob(include = ["_src/extend/**/*.py"]), + deps = [":jax"], +) + # TODO(phawkins): remove this target in favor of the finer-grained targets in jax/extend/... pytype_strict_library( name = "extend", @@ -1305,3 +1863,14 @@ pytype_library( "//jax/extend:ifrt_programs", ] + py_deps("numpy") + py_deps("cloudpickle"), ) + +pytype_library( + name = "experimental_buffer_callback", + srcs = [ + "experimental/buffer_callback.py", + ], + visibility = [":buffer_callback_users"], + deps = [ + ":jax", + ], +) diff --git a/jax/__init__.py b/jax/__init__.py index ae3bac4ad3fa..18465c28bc84 100644 --- a/jax/__init__.py +++ b/jax/__init__.py @@ -70,7 +70,6 @@ transfer_guard_host_to_device as transfer_guard_host_to_device, transfer_guard_device_to_device as transfer_guard_device_to_device, transfer_guard_device_to_host as transfer_guard_device_to_host, - spmd_mode as spmd_mode, ) from jax._src.core import ensure_compile_time_eval as ensure_compile_time_eval from jax._src.environment_info import print_environment_info as print_environment_info @@ -100,6 +99,7 @@ from jax._src.api import disable_jit as disable_jit from jax._src.api import eval_shape as eval_shape from jax._src.dtypes import float0 as float0 +from jax._src.api import fwd_and_bwd as fwd_and_bwd from jax._src.api import grad as grad from jax._src.api import hessian as hessian from jax._src.xla_bridge import host_count as host_count @@ -131,6 +131,8 @@ from jax._src.sharding_impls import NamedSharding as NamedSharding from jax._src.sharding_impls import make_mesh as make_mesh +from jax._src.shard_map import shard_map as shard_map + # Force import, allowing jax.interpreters.* to be used after import jax. 
from jax.interpreters import ad, batching, mlir, partial_eval, pxla, xla del ad, batching, mlir, partial_eval, pxla, xla @@ -141,16 +143,6 @@ make_array_from_process_local_data as make_array_from_process_local_data, ) -from jax._src.tree_util import ( - tree_map as _deprecated_tree_map, - treedef_is_leaf as _deprecated_treedef_is_leaf, - tree_flatten as _deprecated_tree_flatten, - tree_leaves as _deprecated_tree_leaves, - tree_structure as _deprecated_tree_structure, - tree_transpose as _deprecated_tree_transpose, - tree_unflatten as _deprecated_tree_unflatten, -) - # These submodules are separate because they are in an import cycle with # jax and rely on the names imported above. from jax import custom_derivatives as custom_derivatives @@ -184,59 +176,46 @@ del _ccache _deprecations = { - # Added July 2022 + # Finalized 2025-03-25; remove after 2025-06-25 "treedef_is_leaf": ( - "jax.treedef_is_leaf is deprecated: use jax.tree_util.treedef_is_leaf.", - _deprecated_treedef_is_leaf + "jax.treedef_is_leaf was removed in JAX v0.6.0: use jax.tree_util.treedef_is_leaf.", + None ), "tree_flatten": ( - "jax.tree_flatten is deprecated: use jax.tree.flatten (jax v0.4.25 or newer) " + "jax.tree_flatten was removed in JAX v0.6.0: use jax.tree.flatten (jax v0.4.25 or newer) " "or jax.tree_util.tree_flatten (any JAX version).", - _deprecated_tree_flatten + None ), "tree_leaves": ( - "jax.tree_leaves is deprecated: use jax.tree.leaves (jax v0.4.25 or newer) " + "jax.tree_leaves was removed in JAX v0.6.0: use jax.tree.leaves (jax v0.4.25 or newer) " "or jax.tree_util.tree_leaves (any JAX version).", - _deprecated_tree_leaves + None ), "tree_structure": ( - "jax.tree_structure is deprecated: use jax.tree.structure (jax v0.4.25 or newer) " + "jax.tree_structure was removed in JAX v0.6.0: use jax.tree.structure (jax v0.4.25 or newer) " "or jax.tree_util.tree_structure (any JAX version).", - _deprecated_tree_structure + None ), "tree_transpose": ( - "jax.tree_transpose is deprecated: use jax.tree.transpose (jax v0.4.25 or newer) " + "jax.tree_transpose was removed in JAX v0.6.0: use jax.tree.transpose (jax v0.4.25 or newer) " "or jax.tree_util.tree_transpose (any JAX version).", - _deprecated_tree_transpose + None ), "tree_unflatten": ( - "jax.tree_unflatten is deprecated: use jax.tree.unflatten (jax v0.4.25 or newer) " + "jax.tree_unflatten was removed in JAX v0.6.0: use jax.tree.unflatten (jax v0.4.25 or newer) " "or jax.tree_util.tree_unflatten (any JAX version).", - _deprecated_tree_unflatten + None ), - # Added Feb 28, 2024 "tree_map": ( - "jax.tree_map is deprecated: use jax.tree.map (jax v0.4.25 or newer) " + "jax.tree_map was removed in JAX v0.6.0: use jax.tree.map (jax v0.4.25 or newer) " "or jax.tree_util.tree_map (any JAX version).", - _deprecated_tree_map - ), - # Finalized Nov 12 2024; remove after Feb 12 2025 - "clear_backends": ( - "jax.clear_backends was removed in JAX v0.4.36", None ), } import typing as _typing if _typing.TYPE_CHECKING: - from jax._src.tree_util import treedef_is_leaf as treedef_is_leaf - from jax._src.tree_util import tree_flatten as tree_flatten - from jax._src.tree_util import tree_leaves as tree_leaves - from jax._src.tree_util import tree_map as tree_map - from jax._src.tree_util import tree_structure as tree_structure - from jax._src.tree_util import tree_transpose as tree_transpose - from jax._src.tree_util import tree_unflatten as tree_unflatten - + pass else: from jax._src.deprecations import deprecation_getattr as _deprecation_getattr __getattr__ = 
_deprecation_getattr(__name__, _deprecations) diff --git a/jax/_src/ad_checkpoint.py b/jax/_src/ad_checkpoint.py index c2868cf7c078..c49614521a1c 100644 --- a/jax/_src/ad_checkpoint.py +++ b/jax/_src/ad_checkpoint.py @@ -430,7 +430,7 @@ def _trace_to_jaxpr(fun: Callable, "Consider using the `static_argnums` parameter for `jax.remat` or " "`jax.checkpoint`. See the `jax.checkpoint` docstring and its example " "involving `static_argnums`:\n" - "https://jax.readthedocs.io/en/latest/_autosummary/jax.checkpoint.html" + "https://docs.jax.dev/en/latest/_autosummary/jax.checkpoint.html" "\n") e.args = msg, raise @@ -578,7 +578,7 @@ def remat_partial_eval(trace: pe.JaxprTrace, *tracers: core.Tracer, out_jaxpr_tracers = [pe.JaxprTracer(trace, pe.PartialVal.unknown(x.aval), None) for x in jaxpr_unknown.outvars] new_params = dict(params, jaxpr=jaxpr_unknown, differentiated=True) - recipe = pe.new_eqn_recipe(in_jaxpr_tracers, out_jaxpr_tracers, remat_p, + recipe = pe.new_eqn_recipe(trace, in_jaxpr_tracers, out_jaxpr_tracers, remat_p, new_params, jaxpr_unknown.effects, source_info_util.current()) @@ -621,7 +621,7 @@ def _insert_reduce_precision(jaxpr: core.Jaxpr, num_res: int) -> core.Jaxpr: if v not in used_vars: continue assert isinstance(v, core.Var) - newvar = core.Var(v.suffix, v.aval) + newvar = core.Var(v.aval) finfo = dtypes.finfo(v.aval.dtype) params = dict(exponent_bits=finfo.nexp, mantissa_bits=finfo.nmant) if v in constvars or v in invars: @@ -757,89 +757,34 @@ def _has_effects(effects) -> bool: return bool({e for e in effects if not isinstance(e, core.NamedAxisEffect)}) -def remat_expansion(*args, jaxpr: core.Jaxpr, prevent_cse: bool, - differentiated: bool, is_gpu_platform: bool = False, - **_): +def remat_expansion( + *args, jaxpr: core.Jaxpr, prevent_cse: bool, differentiated: bool, **_ +): assert not jaxpr.constvars if differentiated and prevent_cse: - if config.remat_opt_barrier.value: - translation_rule = _remat_translation_using_opt_barrier - elif is_gpu_platform: - translation_rule = _remat_translation_using_while - else: - translation_rule = _remat_translation_using_cond + translation_rule = _remat_translation_using_opt_barrier else: translation_rule = lambda *args, jaxpr: core.eval_jaxpr(jaxpr, (), *args) return api.named_call(translation_rule, name="checkpoint")(*args, jaxpr=jaxpr) + def _remat_translation_using_opt_barrier(*args, jaxpr: core.Jaxpr): args = lax_internal.optimization_barrier(args) return core.eval_jaxpr(jaxpr, (), *args) -# TODO(mattjj): add core utility for 'create dummy value for this type'? 
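For reference, the `jax/__init__.py` change above turns the old `jax.tree_*` aliases into tombstone entries (message plus `None`), so accessing them now fails with a pointer to the replacement. A small illustration of the intended migration; the exact error wording comes from the `_deprecations` table above, and the `try`/`except` is only there to show the failure mode.

```python
import jax
import jax.numpy as jnp

tree = {"a": jnp.arange(3), "b": (1.0, 2.0)}

# New spelling (jax v0.4.25+); jax.tree_util.tree_map also still works.
doubled = jax.tree.map(lambda x: 2 * x, tree)

# The removed alias now raises AttributeError with the registered message.
try:
  jax.tree_map(lambda x: 2 * x, tree)
except AttributeError as err:
  print(err)
```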
-def _dummy_like(aval: core.AbstractValue) -> Any: - if aval is core.abstract_token: - return lax_internal.create_token() - elif isinstance(aval, (core.ShapedArray, core.DShapedArray)): - return lax_internal.broadcast(lax_internal.empty(aval.dtype), aval.shape) # type: ignore - else: - raise ValueError(aval) - -def _remat_translation_using_while(*args, jaxpr: core.Jaxpr): - # Implements: - # for(counter=0, result=0; counter < rng(1, 2); counter ++) { - # result = eval_jaxpr(*args) - # } - # The loop carry is a tuple: (counter, result, args) - from jax._src.lax import control_flow as lax_control_flow - - avals_out = tuple(v.aval for v in jaxpr.outvars) - carry_init = (np.int32(0), tuple(map(_dummy_like, avals_out)), args) - def cond(carry): - counter, _, _ = carry - unif = lax_internal.rng_uniform(np.int32(1), np.int32(2), shape=()) - return counter < unif - - def body(carry): - counter, _, args = carry - results = core.eval_jaxpr(jaxpr, (), *args) - return (counter + 1, tuple(results), args) - - carry_res = lax_control_flow.while_loop(cond, body, carry_init) - return carry_res[1] - -def _remat_translation_using_cond(*args, jaxpr: core.Jaxpr): - # Implements: - # if(rng(0, 1) < 2) - # return eval_jaxpr(*args) - # else: - # return 0 - from jax._src.lax import control_flow as lax_control_flow - - avals_out = tuple(v.aval for v in jaxpr.outvars) - - def remat_comp(*args): - return tuple(core.eval_jaxpr(jaxpr, (), *args)) - def dummy_comp(*args): - return tuple(map(_dummy_like, avals_out)) - - unif = lax_internal.rng_uniform(np.float32(0), np.float32(1), shape=()) - return lax_control_flow.cond(unif < np.float32(2), remat_comp, dummy_comp, *args) - -def _remat_lowering(ctx, *args, jaxpr: core.Jaxpr, prevent_cse: bool, - differentiated: bool, policy, is_gpu_platform=False): + +def _remat_lowering( + ctx, + *args, + jaxpr: core.Jaxpr, + prevent_cse: bool, + differentiated: bool, + policy, +): jaxpr_args: Sequence[mlir.IrValues] if differentiated and prevent_cse: - # If we're using the loop or cond lowerings, use the slower lower_fun - # based path. - if not config.remat_opt_barrier.value: - return mlir.lower_fun(remat_expansion, multiple_results=True)( - ctx, *args, jaxpr=jaxpr, prevent_cse=prevent_cse, - differentiated=differentiated, policy=policy, - is_gpu_platform=is_gpu_platform) - arg_types = map(mlir.aval_to_ir_type, ctx.avals_in) flat_args = mlir.flatten_ir_values(args) barrier_op = hlo.OptimizationBarrierOp(flat_args) @@ -853,9 +798,8 @@ def _remat_lowering(ctx, *args, jaxpr: core.Jaxpr, prevent_cse: bool, ctx.set_tokens_out(tokens_out) return outs + mlir.register_lowering(remat_p, _remat_lowering) -mlir.register_lowering(remat_p, partial(_remat_lowering, is_gpu_platform=True), - platform="gpu") def checkpoint_name(x, name): @@ -931,10 +875,7 @@ def checkpoint_wrapper( " else:\n" " return g(x)\n" "\n" - "See https://jax.readthedocs.io/en/latest/jep/11830-new-remat-checkpoint.html\n") + "See https://docs.jax.dev/en/latest/jep/11830-new-remat-checkpoint.html\n") raise NotImplementedError(msg) return checkpoint(fun, prevent_cse=prevent_cse, policy=policy, static_argnums=static_argnums) - -# TODO(phawkins): update users to refer to the public name. 
-_optimization_barrier = lax_internal.optimization_barrier diff --git a/jax/_src/ad_util.py b/jax/_src/ad_util.py index c729a57cfb11..4e9616e48375 100644 --- a/jax/_src/ad_util.py +++ b/jax/_src/ad_util.py @@ -31,6 +31,10 @@ map = safe_map def add_jaxvals(x: ArrayLike, y: ArrayLike) -> Array: + ty = core.typeof(x) + if hasattr(ty, 'vspace_add'): # TODO(mattjj,dougalm): revise away hasattr + return ty.vspace_add(x, y) + x, y = core.standard_insert_pvary(x, y) return add_jaxvals_p.bind(x, y) add_jaxvals_p = Primitive('add_any') @@ -47,6 +51,8 @@ def add_abstract(x, y): return x def zeros_like_aval(aval: core.AbstractValue) -> Array: + if hasattr(aval, 'vspace_zero'): # TODO(mattjj,dougalm): revise away hasattr + return aval.vspace_zero() return aval_zeros_likers[type(aval)](aval) aval_zeros_likers: dict[type, Callable[[Any], Array]] = {} diff --git a/jax/_src/api.py b/jax/_src/api.py index cdcc3e534e74..ff3414e82f53 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -25,6 +25,7 @@ import atexit import collections from collections.abc import Callable, Hashable, Iterable, Sequence +import dataclasses from functools import partial, lru_cache import inspect import math @@ -36,12 +37,14 @@ import numpy as np from contextlib import contextmanager +from jax._src import api_util from jax._src import linear_util as lu from jax._src import stages from jax._src.tree_util import ( tree_map, tree_flatten, tree_unflatten, tree_structure, tree_transpose, tree_leaves, Partial, PyTreeDef, all_leaves, keystr, broadcast_prefix, - prefix_errors, generate_key_paths, tree_flatten_with_path) + prefix_errors, generate_key_paths, tree_flatten_with_path, + equality_errors_pytreedef) from jax._src import config from jax._src import core from jax._src import dispatch @@ -62,7 +65,6 @@ rebase_donate_argnums, _ensure_index, _ensure_index_tuple, apply_flat_fun_nokwargs, check_callable, debug_info, flat_out_axes) -from jax._src.lax import lax as lax_internal from jax._src.lib import jax_jit from jax._src.lib import xla_client as xc from jax._src.lib import pmap_lib @@ -70,7 +72,7 @@ from jax._src.mesh import get_concrete_mesh from jax._src.sharding_impls import ( PmapSharding, TransferToMemoryKind, PartitionSpec as P, NamedSharding) -from jax._src.layout import Layout, AutoLayout +from jax._src.layout import Format, AutoLayout from jax._src.traceback_util import api_boundary from jax._src import tree_util from jax._src.util import unzip2, safe_map, safe_zip, wraps, split_list @@ -80,7 +82,6 @@ from jax._src.interpreters import batching from jax._src.interpreters import partial_eval as pe from jax._src.interpreters import pxla -from jax._src.interpreters import xla traceback_util.register_exclusion(__file__) @@ -111,14 +112,14 @@ def _nan_check_posthook(fun, args, kwargs, output): try: dispatch.check_special(pjit.pjit_p.name, buffers) - except dispatch.InternalFloatingPointError as e: + except api_util.InternalFloatingPointError as e: assert config.debug_nans.value or config.debug_infs.value if hasattr(fun, '_fun'): f = fun._fun if getattr(f, '_apply_primitive', False): raise FloatingPointError(f"invalid value ({e.ty}) encountered in {f.__qualname__}") from None # compiled_fun can only raise in this case - dispatch.maybe_recursive_nan_check(e, f, args, kwargs) + api_util.maybe_recursive_nan_check(e, f, args, kwargs) raise AssertionError("Unreachable") from e else: # TODO(emilyaf): Shouldn't need this fallback. 
@@ -147,7 +148,7 @@ def _update_debug_special_thread_local(_): def jit( - fun: Callable, + fun: Callable, /, *, in_shardings: Any = sharding_impls.UNSPECIFIED, out_shardings: Any = sharding_impls.UNSPECIFIED, static_argnums: int | Sequence[int] | None = None, @@ -191,7 +192,7 @@ def jit( constant). Static arguments should be hashable, meaning both ``__hash__`` and - ``__eq__`` are implemented, and immutable. Otherwise they can be arbitrary + ``__eq__`` are implemented, and immutable. Otherwise, they can be arbitrary Python objects. Calling the jitted function with different values for these constants will trigger recompilation. Arguments that are not array-like or containers thereof must be marked as static. @@ -231,7 +232,7 @@ def jit( be donated. For more details on buffer donation see the - `FAQ `_. + `FAQ `_. donate_argnames: optional, a string or collection of strings specifying which named arguments are donated to the computation. See the comment on ``donate_argnums`` for details. If not @@ -287,16 +288,19 @@ def jit( Array([ 0, 1, 256, 6561], dtype=int32) """ return pjit.make_jit( - fun, in_shardings, out_shardings, donate_argnums, donate_argnames, - static_argnums, static_argnames, device, backend, abstracted_axes, - keep_unused, inline, compiler_options, use_resource_env=False) + fun, in_shardings=in_shardings, out_shardings=out_shardings, + static_argnums=static_argnums, static_argnames=static_argnames, + donate_argnums=donate_argnums, donate_argnames=donate_argnames, + keep_unused=keep_unused, device=device, backend=backend, inline=inline, + abstracted_axes=abstracted_axes, compiler_options=compiler_options, + use_resource_env=False) @contextmanager def disable_jit(disable: bool = True): """Context manager that disables :py:func:`jit` behavior under its dynamic context. - For debugging it is useful to have a mechanism that disables :py:func:`jit` + For debugging, it is useful to have a mechanism that disables :py:func:`jit` everywhere in a dynamic context. Note that this not only disables explicit uses of :func:`jit` by the user, but will also remove any implicit JIT compilation used by the JAX library: this includes implicit JIT computation of `body` and @@ -322,7 +326,7 @@ def disable_jit(disable: bool = True): ... return y + 3 ... >>> print(f(jax.numpy.array([1, 2, 3]))) # doctest:+ELLIPSIS - Value of y is Tracedwith + Value of y is Tracedwith [5 7 9] Here ``y`` has been abstracted by :py:func:`jit` to a :py:class:`ShapedArray`, @@ -440,6 +444,8 @@ def value_and_grad(fun: Callable, argnums: int | Sequence[int] = 0, shapes and types as the corresponding arguments. If ``has_aux`` is True then a tuple of ((value, auxiliary_data), gradient) is returned. 
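One small API consequence of the `jax.jit` signature change earlier in this hunk: `fun` is now positional-only (the `/` in `def jit(fun: Callable, /, *, ...)`). A quick sketch of what that means for callers; `square` is just an illustrative function name.

```python
import jax
import jax.numpy as jnp

def square(x):
  return x * x

jitted = jax.jit(square)       # OK: the function is passed positionally
jitted(jnp.arange(3.0))

# jax.jit(fun=square)          # would now raise TypeError, since 'fun' is
#                              # positional-only in the new signature
```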
""" + from jax._src.lax import lax as lax_internal # pytype: disable=import-error + if reduce_axes: raise NotImplementedError("reduce_axes argument to grad is deprecated") del reduce_axes @@ -471,8 +477,7 @@ def value_and_grad_f(*args, **kwargs): if not has_aux: ans, vjp_py = _vjp(f_partial, *dyn_args) else: - ans, vjp_py, aux = _vjp( - f_partial, *dyn_args, has_aux=True) + ans, vjp_py, aux = _vjp(f_partial, *dyn_args, has_aux=True) _check_scalar(ans) tree_map(partial(_check_output_dtype_grad, holomorphic), ans) g = vjp_py(lax_internal._one(ans)) @@ -504,17 +509,18 @@ def _check_input_dtype_revderiv(name, holomorphic, allow_int, x): if not dtypes.issubdtype(aval.dtype, np.complexfloating): raise TypeError(f"{name} with holomorphic=True requires inputs with complex dtype, " f"but got {aval.dtype.name}.") - if (dtypes.issubdtype(aval.dtype, dtypes.extended) or - dtypes.issubdtype(aval.dtype, np.integer) or - dtypes.issubdtype(aval.dtype, np.bool_)): - if not allow_int: - raise TypeError(f"{name} requires real- or complex-valued inputs (input dtype " - f"that is a sub-dtype of np.inexact), but got {aval.dtype.name}. " - "If you want to use Boolean- or integer-valued inputs, use vjp " - "or set allow_int to True.") - elif not dtypes.issubdtype(aval.dtype, np.inexact): - raise TypeError(f"{name} requires numerical-valued inputs (input dtype that is a " - f"sub-dtype of np.bool_ or np.number), but got {aval.dtype.name}.") + if isinstance(aval, ShapedArray): + if (dtypes.issubdtype(aval.dtype, dtypes.extended) or + dtypes.issubdtype(aval.dtype, np.integer) or + dtypes.issubdtype(aval.dtype, np.bool_)): + if not allow_int: + raise TypeError(f"{name} requires real- or complex-valued inputs (input dtype " + f"that is a sub-dtype of np.inexact), but got {aval.dtype.name}. " + "If you want to use Boolean- or integer-valued inputs, use vjp " + "or set allow_int to True.") + elif not dtypes.issubdtype(aval.dtype, np.inexact): + raise TypeError(f"{name} requires numerical-valued inputs (input dtype that is a " + f"sub-dtype of np.bool_ or np.number), but got {aval.dtype.name}.") _check_input_dtype_grad = partial(_check_input_dtype_revderiv, "grad") def _check_output_dtype_revderiv(name, holomorphic, x): @@ -539,6 +545,79 @@ def _check_output_dtype_revderiv(name, holomorphic, x): "jax.vjp directly.") _check_output_dtype_grad = partial(_check_output_dtype_revderiv, "grad") +def fwd_and_bwd( + fun: Callable, argnums: int | Sequence[int], has_aux: bool = False, + jitted: bool = True, +) -> tuple[Callable, Callable]: + """Creates functions ``fwd`` and ``bwd`` corresponding to the forward and + backward pass of a given function ``fun``. The forward function ``fwd(*args)`` + functionally behaves much like ``y, fun_vjp = jax.vjp(fun, *args)``, but allows + reuse of the backward function ``bwd`` across multiple iterations, which is + useful to avoid recompilation when the forward and backward do not end up in a + single jitted function: + + >>> import jax + >>> + >>> x = W = cot_out = jax.numpy.ones((4,4)) + >>> + >>> def f(x, W): + ... return x @ W + ... + >>> f_jitted = jax.jit(f) + >>> for i in range(3): + ... y, f_vjp = jax.vjp(f_jitted, x, W) + ... cot_x, cot_W = f_vjp(cot_out) # not jitted + ... cot_x, cot_W = jax.jit(f_vjp)(cot_out) # recompiles on every iteration + ... + >>> fwd, bwd = jax.fwd_and_bwd(f, argnums=(0,1)) + >>> for i in range(3): + ... y, residuals = fwd(x, W) + ... cot_x, cot_W = bwd(residuals, cot_out) # jitted, compiles once + ... 
+ + Args: + fun: Function to produce a forward and backward of. + argnums: Integer or sequence of integers. Specifies which positional argument(s) + to differentiate with respect to. + has_aux: Optional, bool. Indicates whether ``fun`` returns a pair where the + first element is considered the output of the mathematical function to be + differentiated and the second element is auxiliary data. Default False. + jitted: Optional, bool. Indicates whether to return the ``jax.jit`` of + forward and backward. Note that jit-ing only the backward but not the + forward will result in the backward recompiling on every invocation, so we + default to jit-ing both. + + Returns: + The two functions, ``fwd`` and ``bwd``. + + If ``has_aux`` is ``False``, ``fwd(*primals)`` returns a tuple + ``(primals_out, residuals)``, where ``primals_out`` is ``fun(*primals)``. + If ``has_aux`` is ``True``, returns a ``(primals_out, residuals, aux)`` tuple + where ``aux`` is the auxiliary data returned by ``fun``. + + ``bwd`` is a function from ``residuals`` and a cotangent vector with the same + shape as ``primals_out`` to a tuple of cotangent vectors with the same number + and shapes as the ``primals`` designated by ``argnums``, representing the + vector-Jacobian product of ``fun`` evaluated at ``primals``. + """ + check_callable(fun) + argnums = _ensure_index(argnums) + + def fwd(*args, **kwargs): + dbg = debug_info('fwd_and_bwd', fun, args, kwargs) + f = lu.wrap_init(fun, params=kwargs, debug_info=dbg) + f_partial, dyn_args = argnums_partial( + f, argnums, args, require_static_args_hashable=False) + return _vjp(f_partial, *dyn_args, has_aux=has_aux) # type: ignore + def bwd(f_vjp, outgrad): + g = f_vjp(outgrad) + g = g[0] if isinstance(argnums, int) else g + return g + if jitted: + fwd = jit(fwd) + bwd = jit(bwd) + return fwd, bwd + def jacfwd(fun: Callable, argnums: int | Sequence[int] = 0, has_aux: bool = False, holomorphic: bool = False) -> Callable: @@ -778,7 +857,7 @@ def hessian(fun: Callable, argnums: int | Sequence[int] = 0, argnums, has_aux=has_aux, holomorphic=holomorphic) def _std_basis(pytree): - import jax.numpy as jnp + import jax.numpy as jnp # pytype: disable=import-error leaves, _ = tree_flatten(pytree) ndim = sum(map(np.size, leaves)) dtype = dtypes.result_type(*leaves) @@ -794,6 +873,7 @@ def _jacrev_unravel(output_pytree, input_pytree_leaf, arr): output_pytree, 0, input_pytree_leaf, arr) def _possible_downcast(x, example): + from jax._src.lax import lax as lax_internal # pytype: disable=import-error if (dtypes.issubdtype(x.dtype, np.complexfloating) and not dtypes.issubdtype(_dtype(example), np.complexfloating)): x = x.real @@ -855,7 +935,7 @@ def vmap(fun: F, be a container with a matching pytree structure specifying the mapping of its container elements. In other words, ``in_axes`` must be a container tree prefix of the positional argument tuple passed to ``fun``. See this link for more detail: - https://jax.readthedocs.io/en/latest/pytrees.html#applying-optional-parameters-to-pytrees + https://docs.jax.dev/en/latest/pytrees.html#applying-optional-parameters-to-pytrees Either ``axis_size`` must be provided explicitly, or at least one positional argument must have ``in_axes`` not None. The sizes of the @@ -1241,7 +1321,7 @@ def pmap( arguments will not be donated. For more details on buffer donation see the - `FAQ `_. + `FAQ `_. Returns: A parallelized version of ``fun`` with arguments that correspond to those of @@ -1371,10 +1451,8 @@ def pmap( " removed from JAX. 
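The `fwd_and_bwd` docstring above already demonstrates the basic case; the sketch below shows how the `has_aux=True` path described in its Args/Returns sections might look. The toy function and names are illustrative, and the `(primals_out, residuals, aux)` return layout is taken from the docstring rather than verified against the implementation.

```python
import jax
import jax.numpy as jnp

def f(x, w):
  y = x @ w
  return jnp.sum(y ** 2), y          # (value to differentiate, auxiliary output)

# Differentiate with respect to w only, keeping the auxiliary output around.
fwd, bwd = jax.fwd_and_bwd(f, argnums=1, has_aux=True)

x = jnp.ones((4, 3))
w = jnp.ones((3, 2))

loss, residuals, y = fwd(x, w)            # aux comes back alongside residuals
dw = bwd(residuals, jnp.ones_like(loss))  # single argnum -> single cotangent out
```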
Please migrate to pjit and remove global_arg_shapes" " from pmap.") - # TODO(yashkatariya): Move this out after shard_map is out of experimental and - # in _src if config.pmap_shmap_merge.value: - from jax.experimental.shard_map import pmap + from jax._src.shard_map import pmap # pytype: disable=import-error return pmap(fun, axis_name, in_axes=in_axes, out_axes=out_axes, static_broadcasted_argnums=static_broadcasted_argnums, devices=devices, backend=backend, @@ -1488,7 +1566,7 @@ def _prepare_pmap(fun: Callable, in_axes, out_axes, static_broadcasted_tuple, "Instead, each argument passed by keyword is mapped over its " "leading axis. See the description of `in_axes` in the `pmap` " "docstring: " - "https://jax.readthedocs.io/en/latest/_autosummary/jax.pmap.html#jax.pmap") + "https://docs.jax.dev/en/latest/_autosummary/jax.pmap.html#jax.pmap") msg += ("\n\nCheck that the value of the `in_axes` argument to `pmap` " "is a tree prefix of the tuple of arguments passed positionally to " "the pmapped function.") @@ -1568,11 +1646,13 @@ def _cpp_pmap( out_axes) del static_broadcasted_argnums, donate_argnums + prepare_pmap_fn = partial(_prepare_pmap, + fun, in_axes, out_axes, static_broadcasted_tuple, donate_tuple, + devices, backend, axis_size) + @api_boundary def cache_miss(*args, **kwargs): - p = _prepare_pmap(fun, in_axes, out_axes, static_broadcasted_tuple, - donate_tuple, devices, backend, - axis_size, args, kwargs) + p = prepare_pmap_fn(args, kwargs) for arg in p.flat_args: dispatch.check_arg(arg) @@ -1597,7 +1677,7 @@ def cache_miss(*args, **kwargs): out = execute(*p.flat_args) else: out = pxla.xla_pmap_p.bind_with_trace(trace, (p.flat_fun, *p.flat_args), params) - except dispatch.InternalFloatingPointError as e: + except api_util.InternalFloatingPointError as e: raise FloatingPointError(f'Invalid value ({e.ty}) encountered in parallel computation.') out_tree, out_flat = p.out_tree, out @@ -1649,48 +1729,56 @@ def cache_miss(*args, **kwargs): _pmap_cache_clears.add(cpp_mapped_f) pmap_f = wraps(fun)(cpp_mapped_f) + # Store some data for the `lower` and `trace` methods pmap_f._fun = fun + pmap_f._prepare_pmap = prepare_pmap_fn + pmap_f._backend = backend + pmap_f._axis_name = axis_name + pmap_f._donate_tuple = donate_tuple + + # TODO(necula): move these to top-level; we don't need to do this for + # every pmap + cpp_mapped_f_class = type(pmap_f) + cpp_mapped_f_class.lower = _cpp_mapped_lower + cpp_mapped_f_class.trace = _cpp_mapped_trace + # We return directly the function produced by pmap_lib.pmap, because we do not + # want to have Python in the dispatch path. 
+ return pmap_f - @api_boundary - def lower(*args, **kwargs): - return trace(*args, **kwargs).lower() +@api_boundary +def _cpp_mapped_trace(pmap_f, *args, **kwargs): + p = pmap_f._prepare_pmap(args, kwargs) + abstract_args = list(map(shaped_abstractify, p.flat_args)) + closed_jaxpr, xc_backend, replicas, shards, pci = pxla.get_pmap_jaxpr( + p.flat_fun, pmap_f._backend, pmap_f._axis_name, + axis_size=p.local_axis_size, global_axis_size=p.global_axis_size, + devices=p.devices, + name=p.flat_fun.__name__, + in_axes=p.in_axes_flat, + out_axes_thunk=p.out_axes_thunk, + avals=abstract_args) + lower_callable = partial( + pxla.lower_parallel_callable, p.flat_fun, pmap_f._axis_name, + axis_size=p.local_axis_size, global_axis_size=p.global_axis_size, + devices=p.devices, + name=p.flat_fun.__name__, + in_axes=p.in_axes_flat, + donated_invars=p.donated_invars, + is_explicit_global_axis_size=p.is_explicit_global_axis_size, + avals=abstract_args, + closed_jaxpr=closed_jaxpr, + backend=xc_backend, + replicas=replicas, + shards=shards, + pci=pci) + args_info = stages.make_args_info(p.in_tree, abstract_args, pmap_f._donate_tuple) + return stages.Traced(closed_jaxpr, args_info, p.flat_fun.__name__, + p.out_tree(), lower_callable) - @api_boundary - def trace(*args, **kwargs): - p = _prepare_pmap( - fun, in_axes, out_axes, static_broadcasted_tuple, donate_tuple, - devices, backend, axis_size, args, kwargs) - abstract_args = list(map(shaped_abstractify, p.flat_args)) - closed_jaxpr, xc_backend, replicas, shards, pci = pxla.get_pmap_jaxpr( - p.flat_fun, backend, axis_name, - axis_size=p.local_axis_size, global_axis_size=p.global_axis_size, - devices=p.devices, - name=p.flat_fun.__name__, - in_axes=p.in_axes_flat, - out_axes_thunk=p.out_axes_thunk, - avals=abstract_args) - lower_callable = partial( - pxla.lower_parallel_callable, p.flat_fun, axis_name, - axis_size=p.local_axis_size, global_axis_size=p.global_axis_size, - devices=p.devices, - name=p.flat_fun.__name__, - in_axes=p.in_axes_flat, - donated_invars=p.donated_invars, - is_explicit_global_axis_size=p.is_explicit_global_axis_size, - avals=abstract_args, - closed_jaxpr=closed_jaxpr, - backend=xc_backend, - replicas=replicas, - shards=shards, - pci=pci) - args_info = stages.make_args_info(p.in_tree, abstract_args, donate_tuple) - return stages.Traced(closed_jaxpr, args_info, p.flat_fun.__name__, - p.out_tree(), lower_callable) - - pmap_f.lower = lower - pmap_f.trace = trace +@api_boundary +def _cpp_mapped_lower(pmap_f, *args, **kwargs): + return _cpp_mapped_trace(pmap_f, *args, **kwargs).lower() - return pmap_f _pmap_cache_clears = weakref.WeakSet() # type: ignore @@ -1746,13 +1834,17 @@ def jvp( def _jvp(fun: lu.WrappedFun, primals, tangents, has_aux=False): """Variant of jvp() that takes an lu.WrappedFun.""" - ps_flat, tree_def = tree_flatten(primals) - ts_flat, tree_def_2 = tree_flatten(tangents) + primals_, (), primal_box_data = pjit._flatten_boxes(fun.debug_info, primals, {}) + tangents_, (), tangent_box_data = pjit._flatten_boxes(fun.debug_info, tangents, {}) + fun = pjit._handle_boxes(fun, fun.debug_info) + ps_flat, tree_def = tree_flatten(primals_) + ts_flat, tree_def_2 = tree_flatten(tangents_) if tree_def != tree_def_2: raise TypeError("primal and tangent arguments to jax.jvp must have the same tree " f"structure; primals have tree structure {tree_def} whereas tangents have " f"tree structure {tree_def_2}.") for p, t in zip(ps_flat, ts_flat): + if not isinstance(core.typeof(p), ShapedArray): continue if 
core.primal_dtype_to_tangent_dtype(_dtype(p)) != _dtype(t): raise TypeError("primal and tangent arguments to jax.jvp do not match; " "dtypes must be equal, or in case of int/bool primal dtype " @@ -1768,9 +1860,27 @@ def _jvp(fun: lu.WrappedFun, primals, tangents, has_aux=False): flat_fun, out_tree = flatten_fun_nokwargs(fun, tree_def) out_primals, out_tangents = ad.jvp(flat_fun).call_wrapped(ps_flat, ts_flat) out_tree = out_tree() + if primal_box_data or tangent_box_data: + assert primal_box_data and tangent_box_data + box_treedef, out_tree = out_tree.children() + box_out_flat, out_primals = split_list(out_primals, [box_treedef.num_leaves]) + box_dot_out_flat, out_tangents = split_list(out_tangents, [box_treedef.num_leaves]) + box_out = tree_unflatten(box_treedef, box_out_flat) + box_dot_out = tree_unflatten(box_treedef, box_dot_out_flat) + for (i, kind), b in zip(primal_box_data, box_out): + if kind is pe.BoxAttr: + primals[i].set(tree_unflatten(b.treedef, b.leaves)) + else: + assert False + for (i, kind), b in zip(tangent_box_data, box_dot_out): + if kind is pe.BoxAttr: + tangents[i].set(tree_unflatten(b.treedef, b.leaves)) + else: + assert False return (tree_unflatten(out_tree, out_primals), tree_unflatten(out_tree, out_tangents)) else: + if primal_box_data or tangent_box_data: raise NotImplementedError flat_fun, out_aux_trees = flatten_fun_nokwargs2(fun, tree_def) jvp_fun, aux = ad.jvp(flat_fun, has_aux=True) out_primals, out_tangents = jvp_fun.call_wrapped(ps_flat, ts_flat) @@ -1887,10 +1997,27 @@ def fun(*tangents): for primal_aval, tangent_aval in zip(primal_avals, tangent_avals): expected_tangent_aval = primal_aval.to_tangent_aval() if not core.typecompat(expected_tangent_aval, tangent_aval): - raise ValueError("linearized function called on tangent values inconsistent with " - "the original primal values: " - f"got tangent aval {tangent_aval} for primal aval {primal_aval} " - f"but expected {expected_tangent_aval}") + extra_msg = '' + if (isinstance(primal_aval, core.ShapedArray) and + isinstance(tangent_aval, core.ShapedArray) and + primal_aval.vma != tangent_aval.vma): + pvary_applications = [] + if left := tangent_aval.vma - primal_aval.vma: + pvary_applications.append( + f"applying `jax.lax.pvary(..., {tuple(left)})` to the primal" + " value passed to `jax.linearize`") + if left := primal_aval.vma - tangent_aval.vma: + pvary_applications.append( + f"applying `jax.lax.pvary(..., {tuple(left)})` to the tangent" + " value passed to the callable `f_jvp` returned by" + " `jax.linearize`") + extra_msg = " \nThis might be fixed by:\n" + "\n".join( + f" * {d};" for d in pvary_applications) + raise ValueError( + "linearized function called on tangent values inconsistent with " + "the original primal values:\n" + f"Got tangent aval {tangent_aval} for primal aval {primal_aval} " + f"but expected {expected_tangent_aval}.{extra_msg}") tangents_out = eval_jaxpr(jaxpr, consts, *tangents) tangents_out_ = iter(tangents_out) full_out = [pval.get_known() if pval.is_known() else next(tangents_out_) @@ -2033,6 +2160,84 @@ def _vjp(fun: lu.WrappedFun, *primals, has_aux=False): return out_primal_py, vjp_py, tree_unflatten(aux_tree, aux) +def saved_input_vjp(f: Callable, which: Sequence[bool], *primals, + allow_unused: bool = True, allow_opaque: bool = True): + if len(which) != len(primals): + raise ValueError( + "length of 'which' argument must equal the number of primal input values, " + f"but got {len(which)=} and {len(primals)=}") + + dbg = debug_info("saved_input_vjp", f, primals, {}) + fun = 
lu.wrap_init(f, debug_info=dbg) + primals_flat, in_tree = tree_flatten(primals) + fun, out_tree = flatten_fun_nokwargs(fun, in_tree) + out_primals_flat, out_pvals, jaxpr, residuals = ad.linearize(fun, *primals_flat) + out_known = [pval.is_known() for pval in out_pvals] + primals_filt, filt_tree = tree_flatten(tuple(p for w, p in zip(which, primals) if w)) + id_map = {id(x): i for i, x in enumerate(primals_filt)} + opaque_residuals = [] + res_spec = [RSpec(id_map[id(r)], True) if id(r) in id_map else + RSpec(opaque_residuals.append(r) or (len(opaque_residuals) - 1), False) # type: ignore + for r in residuals] + f_vjp = Partial(partial(_saved_input_vjpfun, res_spec, filt_tree, in_tree, + out_tree(), out_known, jaxpr), opaque_residuals) + + if not allow_unused and not set(id_map).issubset(res_ids := {id(r) for r in residuals}): + unused = [(i, core.get_aval(x)) for i, (x, w) in enumerate(zip(primals, which)) + if w and id(x) not in res_ids] + assert unused + if len(unused) == 1: + (i, a), = unused + start, was = "an input value", "was" + msg = f" {dbg.arg_names[i]} of type {a.str_short()}" + else: + start, was = "multiple input values", "were" + msg = "\n" + "\n".join(f" * {dbg.arg_names[i]} of type {a.str_short()}" + for i, a in unused) + raise Exception(f"with {allow_unused=}, {start} marked to be saved {was} " + f"not used by the backward pass:{msg}") + + if not allow_opaque and opaque_residuals: + msg = ", ".join(core.get_aval(x).str_short() for x in opaque_residuals) + raise Exception(f"with {allow_opaque=}, the backward pass requires opaque " + f"(non-input) residuals: {msg}") + + out_primals = tree_unflatten(out_tree(), out_primals_flat) + return out_primals, f_vjp + +def _saved_input_vjpfun(res_spec, filtered_tree, in_tree, out_tree, out_known, + jaxpr, opaque_residuals, ct, *saved_primals): + primals_filtered, filtered_tree_ = tree_flatten(saved_primals) + if filtered_tree != filtered_tree_: + raise ValueError( + "inputs passed to f_vjp must be a tuple of (pytrees of) " + "arrays with the same structure as\n" + " tuple(x for x, w in zip(inputs, which) if w)\n" + "given the original call\n" + " _, f_vjp = saved_input_vjp(f, which, *inputs, ...)\n" + "but the structures differ:\n" + + "\n".join(f" * inputs{keystr(path)} was a {thing1} in the original " + f"call, but a {thing2} here, so {explanation}" + for path, thing1, thing2, explanation + in equality_errors_pytreedef(filtered_tree, filtered_tree_))) + + residuals = [primals_filtered[i.idx] if i.primal else opaque_residuals[i.idx] + for i in res_spec] + dummy_args = [ad.UndefinedPrimal(v.aval) for v in jaxpr.invars] + cts_flat, out_tree_ = tree_flatten(ct) + assert out_tree_ == out_tree + cts_flat = [ct for ct, k in zip(cts_flat, out_known) if not k] + arg_cts = ad.backward_pass(jaxpr, True, residuals, dummy_args, cts_flat) + return tree_unflatten(in_tree, map(ad.instantiate_zeros, arg_cts)) + +@dataclasses.dataclass(frozen=True) +class RSpec: + idx: int + primal: bool + +si_vjp = saved_input_vjp + + def linear_transpose(fun: Callable, *primals, reduce_axes=()) -> Callable: """Transpose a function that is promised to be linear. @@ -2147,7 +2352,7 @@ def make_jaxpr( return_shape: bool = False, abstracted_axes: Any | None = None, ) -> Callable[..., core.ClosedJaxpr | tuple[core.ClosedJaxpr, Any]]: - """Creates a function that produces its jaxpr given example args. + """Create a function that returns the jaxpr of ``fun`` given example args. Args: fun: The function whose ``jaxpr`` is to be computed. 
Its positional @@ -2198,7 +2403,7 @@ def make_jaxpr( c:f32[] = sin a _:f32[] = sin b d:f32[] = cos b - e:f32[] = mul 1.0 d + e:f32[] = mul 1.0:f32[] d f:f32[] = neg e g:f32[] = mul f c in (g,) } @@ -2265,10 +2470,10 @@ def _check_string_compatible_sharding(s): @lru_cache(maxsize=2048) def _check_sharding(aval, s): if (s is not None and - not isinstance(s, (xc.Device, Sharding, Layout, TransferToMemoryKind))): + not isinstance(s, (xc.Device, Sharding, Format, TransferToMemoryKind))): raise ValueError( "`jax.device_put` only accepts `None`, `jax.sharding.Sharding`," - " `jax.Device`, `Layout` or a pytree of these values. Received" + " `jax.Device`, `Format` or a pytree of these values. Received" f" invalid value: {s}") if isinstance(aval, core.ShapedArray) and dtypes.is_string_dtype(aval.dtype): @@ -2294,8 +2499,8 @@ def pspec_to_sharding(val): def device_put( x, - device: None | xc.Device | Sharding | P | Layout | Any | TransferToMemoryKind = None, - *, src: None | xc.Device | Sharding | P | Layout | Any | TransferToMemoryKind = None, + device: None | xc.Device | Sharding | P | Format | Any | TransferToMemoryKind = None, + *, src: None | xc.Device | Sharding | P | Format | Any | TransferToMemoryKind = None, donate: bool | Any = False, may_alias: bool | None | Any = None): """Transfers ``x`` to ``device``. @@ -2374,8 +2579,8 @@ def device_put( for xf, d in zip(x_flat, device_flat): _check_sharding(shaped_abstractify(xf), d) out_flat = dispatch.device_put_p.bind( - *x_flat, devices=device_flat, srcs=src_flat, - copy_semantics=copy_semantics) + *x_flat, devices=tuple(device_flat), srcs=tuple(src_flat), + copy_semantics=tuple(copy_semantics)) return tree_unflatten(treedef, out_flat) @@ -2510,7 +2715,6 @@ def _device_put_replicated(x): sharding = PmapSharding(np.array(devices), sharding_spec) if dtypes.issubdtype(aval.dtype, dtypes.extended): return aval.dtype._rules.device_put_replicated(buf, aval, sharding, devices) - assert len(xla.aval_to_xla_shapes(aval)) == 1 return pxla.batched_device_put(aval, sharding, [buf] * len(devices), devices) with config.explicit_device_put_scope(): @@ -2585,33 +2789,53 @@ class ShapeDtypeStruct: dtype: a dtype-like object sharding: (optional) a :class:`jax.Sharding` object """ - __slots__ = ["shape", "dtype", "sharding", "_dll", "weak_type"] + __slots__ = ["shape", "dtype", "sharding", "_dll", "weak_type", "vma"] - def __init__(self, shape, dtype, *, sharding=None, weak_type=False): + def __init__(self, shape, dtype, *, sharding=None, weak_type=False, + vma=None): self.shape = tuple(shape) if dtype is None: raise ValueError("ShapeDtypeStruct: dtype must be specified.") self.dtype = dtype if dtypes.issubdtype(dtype, dtypes.extended) else np.dtype(dtype) - if sharding is not None and not isinstance(sharding, (Sharding, Layout)): + if sharding is not None and not isinstance(sharding, (Sharding, Format, P)): raise ValueError( - "sharding should be an instance of `jax.sharding.Sharding` or" - f" `jax.experimental.layout.Layout`. Got {sharding} of type" + "sharding should be an instance of `jax.sharding.Sharding`, " + "`jax.sharding.PartitionSpec` or" + f" `jax.experimental.layout.Format`. Got {sharding} of type" f" {type(sharding)}.") - if (isinstance(sharding, Layout) and + if (isinstance(sharding, Format) and isinstance(sharding.device_local_layout, AutoLayout)): raise TypeError( "`DeviceLocalLayout.AUTO` cannot be used in place of a device-local" f" layout in a `ShapeDtypeStruct`. 
Got {sharding}") - self.sharding = sharding.sharding if isinstance(sharding, Layout) else sharding - self._dll = sharding.device_local_layout if isinstance(sharding, Layout) else None + if isinstance(sharding, Format): + self.sharding = sharding.sharding + elif isinstance(sharding, P): + # TODO(yashkatariya): Should this be abstract mesh? + cur_mesh = get_concrete_mesh() + if cur_mesh is None: + raise TypeError( + "When specifying PartitionSpec to `ShapeDtypeStruct`, the context" + " mesh cannot be empty. Please use `jax.sharding.use_mesh` to set" + " the mesh context.") + self.sharding = NamedSharding(cur_mesh, sharding) + else: + self.sharding = sharding + self._dll = (sharding.device_local_layout if isinstance(sharding, Format) + else None) self.weak_type = weak_type + if vma is not None and not isinstance(vma, (set, frozenset)): + raise TypeError( + "`vma` argument passed to ShapeDtypeStruct should be of type `set`" + f" or `frozenset`. Got type {type(vma)}") + self.vma = None if vma is None else frozenset(vma) size = property(lambda self: math.prod(self.shape)) ndim = property(lambda self: len(self.shape)) @property - def layout(self): - return Layout(self._dll, self.sharding) + def format(self): + return Format(self._dll, self.sharding) def __len__(self): try: @@ -2621,10 +2845,11 @@ def __len__(self): def __repr__(self): sh = f", sharding={self.sharding}" if self.sharding is not None else "" - l = f", layout={self.layout}" if self._dll is not None else "" + l = f", format={self._dll}" if self._dll is not None else "" wt = f", weak_type={self.weak_type}" if self.weak_type else "" + vma = f", vma={self.vma}" if self.vma else "" return (f"{type(self).__name__}(shape={self.shape}, " - f"dtype={self.dtype.name}{sh}{l}{wt})") + f"dtype={self.dtype.name}{sh}{l}{wt}{vma})") __str__ = __repr__ @@ -2632,15 +2857,51 @@ def __eq__(self, other): if not isinstance(other, ShapeDtypeStruct): return False else: - return ((self.shape, self.dtype, self.sharding, self.layout, self.weak_type) == - (other.shape, other.dtype, other.sharding, other.layout, other.weak_type)) + return ((self.shape, self.dtype, self.sharding, self._dll, + self.weak_type, self.vma) == + (other.shape, other.dtype, other.sharding, other._dll, + other.weak_type, other.vma)) def __hash__(self): # TODO(frostig): avoid the conversion from dict by addressing # https://github.com/jax-ml/jax/issues/8182 - return hash((self.shape, self.dtype, self.sharding, self.layout, self.weak_type)) + return hash((self.shape, self.dtype, self.sharding, self._dll, + self.weak_type, self.vma)) + + def __setattr__(self, name, value): + if hasattr(self, name): + if getattr(self, name) == value: + # This can happen if two threads race, for example if two threads + # are trying to hash the same SDS instance. + return + raise RuntimeError( + f"Cannot reassign attributes ({name}) of immutable ShapeDtypeStruct" + " objects") + super().__setattr__(name, value) + + def update(self, **kwargs): + if 'sharding' in kwargs: + s = kwargs['sharding'] + if self._dll is not None and isinstance(s, Sharding): + raise ValueError( + f"You are updating ShapeDtypeStruct with a {type(s)} when the" + f" original ShapeDtypeStruct had a concrete layout {self.format}." + " This might lead to bugs. 
If you want to do this, create a new" + " ShapeDtypeStruct via the constructor.") + sharding = s + else: + sharding = self.format + return ShapeDtypeStruct( + shape=kwargs.pop('shape', self.shape), + dtype=kwargs.pop('dtype', self.dtype), + sharding=sharding, + weak_type=kwargs.pop('weak_type', self.weak_type), + vma=kwargs.pop('vma', self.vma)) + def _sds_aval_mapping(x): + # TODO(yashkatariya): Propagate vma to ShapedArray? This is only used for + # pallas right now and pallas doesn't use pytype_aval_mappings. aval = ShapedArray( x.shape, dtypes.canonicalize_dtype(x.dtype, allow_extended_dtype=True), weak_type=x.weak_type) @@ -2880,6 +3141,7 @@ def clear_backends(): dispatch.xla_primitive_callable.cache_clear() util.clear_all_caches() pjit._infer_params_cached.cache_clear() + pjit._pjit_lower_cached.cache_clear() pjit._create_pjit_jaxpr.cache_clear() # pytype: disable=attribute-error pjit._cpp_pjit_cache_fun_only.clear() pjit._cpp_pjit_cache_explicit_attributes.clear() diff --git a/jax/_src/api_util.py b/jax/_src/api_util.py index a42141b96fbd..5261764d0bf8 100644 --- a/jax/_src/api_util.py +++ b/jax/_src/api_util.py @@ -28,12 +28,12 @@ from jax._src.tree_util import ( PyTreeDef, tree_flatten, tree_unflatten, tree_map, treedef_children, generate_key_paths, broadcast_prefix, - prefix_errors) -from jax._src.tree_util import _replace_nones + prefix_errors, _replace_nones) from jax._src import linear_util as lu from jax._src.util import (safe_map, WrapKwArgs, Hashable, HashableFunction, - Unhashable, safe_zip) + Unhashable, safe_zip as zip) from jax._src import traceback_util + traceback_util.register_exclusion(__file__) map = safe_map @@ -201,9 +201,11 @@ def _validate_argnames( f"in {argnames_name}. Function does not take these args.") -def argnums_partial(f, dyn_argnums, args, require_static_args_hashable=True): +def argnums_partial(f: lu.WrappedFun, dyn_argnums: int | Sequence[int], + args: Sequence, require_static_args_hashable=True): dyn_argnums = _ensure_index_tuple(dyn_argnums) dyn_argnums = _ensure_inbounds(False, len(args), dyn_argnums) + fixed_args: list if require_static_args_hashable: fixed_args = [] for i, arg in enumerate(args): @@ -257,7 +259,7 @@ def argnums_partial_except(f: lu.WrappedFun, static_argnums: tuple[int, ...], dyn_args = tuple(args[i] for i in dyn_argnums) fixed_args = [] - for i in static_argnums: + for i in sorted(static_argnums): # TODO(shoyer): set allow_invalid=True permanently after static_argnames. 
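The `ShapeDtypeStruct` changes in the api.py hunk above let `sharding` be a `PartitionSpec` (resolved against the concrete mesh installed by `jax.sharding.use_mesh`), carry an optional `vma` set, and add an `update()` method that returns a modified copy. A small sketch of the intended use; the mesh axis name here is purely illustrative:

    import jax
    import numpy as np
    from jax.sharding import PartitionSpec as P

    mesh = jax.make_mesh((jax.device_count(),), ("x",))
    with jax.sharding.use_mesh(mesh):
      sds = jax.ShapeDtypeStruct((8, 4), np.float32, sharding=P("x", None))
      print(sds.sharding)                   # NamedSharding over the context mesh
      sds64 = sds.update(dtype=np.float64)  # new struct; shape and sharding kept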
if allow_invalid and i >= len(args): continue @@ -273,7 +275,9 @@ def argnums_partial_except(f: lu.WrappedFun, static_argnums: tuple[int, ...], return _argnums_partial(f, dyn_argnums, tuple(fixed_args)), dyn_args @lu.transformation2 -def _argnums_partial(_fun, _dyn_argnums, _fixed_args, *dyn_args, **kwargs): +def _argnums_partial(_fun: Callable, + _dyn_argnums: Sequence[int], + _fixed_args: Sequence, *dyn_args, **kwargs): sentinel = object() args = [sentinel] * (len(_fixed_args) + len(dyn_args)) for i, arg in zip(_dyn_argnums, dyn_args): @@ -334,7 +338,7 @@ def donation_vector(donate_argnums, donate_argnames, in_tree, donate = bool(i in donate_argnums) res.extend((donate,) * arg.num_leaves) if kwargs_tree is not None: - for key, val in safe_zip(kwargs_tree.node_data()[1], kwargs_tree.children()): # type: ignore + for key, val in zip(kwargs_tree.node_data()[1], kwargs_tree.children()): # type: ignore donate = key in donate_argnames res.extend((donate,) * val.num_leaves) return tuple(res) @@ -602,7 +606,7 @@ def debug_info( """Constructd core.DebugInfo for a function given example args and kwargs. `args` and `kwargs` are example positional and keyword arguments, users with - `inspect.Signature` to get the names of argments. The arguments that are + `inspect.Signature` to get the names of arguments. The arguments that are considered static for tracing purposes should be included, and designated using `static_argnums` and `static_argnames`. @@ -673,28 +677,45 @@ def _non_static_arg_names(fn_signature: inspect.Signature | None, top-level arguments. In other cases, including when the `args` and `kwargs` do not match the signature, we use names like `args[0[]`, `args[1]`, etc. """ + # Use the same argument parsing as jit: positional followed by kwargs + # sorted by keys. static = object() static_argnums_ = _ensure_inbounds(True, len(args), static_argnums) static_argnames_ = set(static_argnames) args_ = [static if i in static_argnums_ else x for i, x in enumerate(args)] - kwargs_ = {k:static if k in static_argnames_ else x for k, x in kwargs.items()} + kwargs_ = {k: static if k in static_argnames_ else x for k, x in kwargs.items()} + ordered_args: Sequence[tuple[str, Any]] | None = None if fn_signature is not None: try: ba = fn_signature.bind(*args_, **kwargs_) except (ValueError, TypeError): pass else: - return tuple(f'{name}{lu._clean_keystr_arg_names(path)}' - for name, x in ba.arguments.items() - for path, l in generate_key_paths(x) if l is not static) - args_arg_names = tuple(f'args{lu._clean_keystr_arg_names(path)}' - for path, l in generate_key_paths(args_) - if l is not static) - kwargs_arg_names = tuple(f'kwargs{lu._clean_keystr_arg_names(path)}' - for path, l in generate_key_paths(kwargs_) - if l is not static) - arg_names = args_arg_names + kwargs_arg_names - return arg_names + # Do we have a **kwargs + kwargs_name = next((name for name, p in fn_signature.parameters.items() + if p.kind == inspect.Parameter.VAR_KEYWORD), None) + # Positional argument are those not passed by keyword and not passed + # by **kwargs. 
+ positional = [(name, x) for name, x in ba.arguments.items() + if name not in kwargs and name != kwargs_name] + # Keyword arguments are passed sorted by actual kwarg keyword + sorted_kwargs = sorted(((name, x) for name, x in kwargs_.items()), + key=lambda name_x: name_x[0]) + sorted_kwargs = [(name if name in ba.arguments else f"{kwargs_name}['{name}']", + x) + for name, x in sorted_kwargs] + ordered_args = positional + sorted_kwargs + + if ordered_args is None: + positional = [("args", args_)] + keyword = sorted([(f"kwargs['{name}']", x) for name, x in kwargs_.items() if x is not static], + key=lambda name_x: name_x[0]) + ordered_args = positional + keyword + + return tuple(f'{name}{lu._clean_keystr_arg_names(path)}' + for name, x in ordered_args + for path, l in generate_key_paths(x) if l is not static) + def hoist_obj_attrs(f, flat_args): idxs, objs, flat_args_ = [], [], [] @@ -746,3 +767,41 @@ def _check_no_aliased_closed_over_refs(dbg: core.DebugInfo, consts, args) -> Non f"array reference of type {a.str_short()} was both closed over and " f"passed as the argument " f"{dbg.safe_arg_names(len(args))[i]}" if dbg else "at flat index {i}") + +class InternalFloatingPointError(Exception): + name: str + ty: str + + def __init__(self, name: str, ty: str): + self.name = name + self.ty = ty + +def maybe_recursive_nan_check(e: Exception, fun: Callable, args, kwargs, +) -> None: # always raises an exception + print("Invalid nan value encountered in the output of a jax.jit " + "function. Calling the de-optimized version.") + try: + _ = fun(*args, **kwargs) + except (FloatingPointError, ZeroDivisionError) as e2: + raise e2 from None + else: + _raise_no_nan_in_deoptimized(e) + + +def _raise_no_nan_in_deoptimized(e) -> None: + msg = (f"{str(e)}. Because " + "jax_config.debug_nans.value and/or config.jax_debug_infs is set, the " + "de-optimized function (i.e., the function as if the `jit` " + "decorator were removed) was called in an attempt to get a more " + "precise error message. However, the de-optimized function did not " + "produce invalid values during its execution. This behavior can " + "result from `jit` optimizations causing the invalid value to be " + "produced. It may also arise from having nan/inf literals as " + "inputs or outputs, like `jax.jit(lambda ...: jax.numpy.nan)(...)`. " + "\n\n" + "It may be possible to avoid the invalid value by removing the " + "`jit` decorator, at the cost of losing optimizations. 
" + "\n\n" + "If you see this error, consider opening a bug report at " + "https://github.com/jax-ml/jax.") + raise FloatingPointError(msg) from None diff --git a/jax/_src/array.py b/jax/_src/array.py index b0793d2c3330..61ad8a7f4405 100644 --- a/jax/_src/array.py +++ b/jax/_src/array.py @@ -36,14 +36,14 @@ from jax._src.interpreters import mlir from jax._src.interpreters import pxla from jax._src.interpreters import xla -from jax._src.layout import AutoLayout, DeviceLocalLayout, Layout +from jax._src.layout import AutoLayout, DeviceLocalLayout, Format from jax._src.lib import xla_client as xc -from jax._src.lib import xla_extension as xe +from jax._src.lib import _jax from jax._src.sharding import Sharding from jax._src.sharding_impls import ( PmapSharding, SingleDeviceSharding, device_replica_id_map, hashed_index, num_addressable_indices, - local_to_global_shape, use_concrete_mesh) # pyformat: disable + local_to_global_shape, _internal_use_concrete_mesh) # pyformat: disable from jax._src.typing import ArrayLike, DLDeviceType, DTypeLike from jax._src.util import safe_zip, unzip3, use_cpp_class, use_cpp_method, cache import numpy as np @@ -343,8 +343,8 @@ def __format__(self, format_spec): return format(self._value, format_spec) def __getitem__(self, idx): - from jax._src.lax import lax - from jax._src.numpy import indexing + from jax._src.lax import lax # pytype: disable=import-error + from jax._src.numpy import indexing # pytype: disable=import-error self._check_if_deleted() if isinstance(self.sharding, PmapSharding): @@ -444,7 +444,7 @@ def __dlpack__(self, *, stream: int | Any | None = None, max_version: tuple[int, int] | None = None, dl_device: tuple[DLDeviceType, int] | None = None, copy: bool | None = None): - from jax._src.dlpack import to_dlpack # pylint: disable=g-import-not-at-top + from jax._src.dlpack import to_dlpack # pytype: disable=import-error # pylint: disable=g-import-not-at-top device_set = self.sharding.device_set if len(device_set) > 1: @@ -464,7 +464,7 @@ def __dlpack_device__(self) -> tuple[enum.Enum, int]: if len(self._arrays) != 1: raise BufferError("__dlpack__ only supported for unsharded arrays.") - from jax._src.dlpack import DLDeviceType # pylint: disable=g-import-not-at-top + from jax._src.dlpack import DLDeviceType # pytype: disable=import-error # pylint: disable=g-import-not-at-top if self.platform() == "cpu": return DLDeviceType.kDLCPU, 0 @@ -547,17 +547,17 @@ def addressable_shards(self) -> Sequence[Shard]: return out @property - def layout(self): + def format(self): # TODO(yashkatariya): Remove the deleted check from here. 
if self.is_deleted(): - return Layout(None, self.sharding) + return Format(None, self.sharding) try: - return Layout(DeviceLocalLayout.from_pjrt_layout(self._pjrt_layout), + return Format(DeviceLocalLayout.from_pjrt_layout(self._pjrt_layout), self.sharding) - except xe.XlaRuntimeError as e: + except _jax.XlaRuntimeError as e: msg, *_ = e.args if type(msg) is str and msg.startswith("UNIMPLEMENTED"): - return Layout(None, self.sharding) + return Format(None, self.sharding) else: raise @@ -636,7 +636,8 @@ def _value(self) -> np.ndarray: self._check_if_deleted() if self._npy_value is None: - if self.is_fully_replicated: + if (self.is_fully_replicated and + self.sharding._internal_device_list.addressable_device_list): # type: ignore npy_value, did_copy = self._single_device_array_to_np_array_did_copy() npy_value.flags.writeable = False if did_copy: @@ -710,7 +711,7 @@ def _get_and_check_dtype(arrays: Sequence[basearray.Array | np.ndarray], # TODO(yashkatariya): Remove None from callback input type. def make_array_from_callback( - shape: Shape, sharding: Sharding | Layout, + shape: Shape, sharding: Sharding | Format, data_callback: Callable[[Index | None], ArrayLike], dtype: DTypeLike | None = None) -> ArrayImpl: # pyformat: disable @@ -755,12 +756,12 @@ def make_array_from_callback( (4, 2) """ # pyformat: enable - dll = sharding.device_local_layout if isinstance(sharding, Layout) else None + dll = sharding.device_local_layout if isinstance(sharding, Format) else None if isinstance(dll, AutoLayout): raise TypeError( "`DeviceLocalLayout.AUTO` cannot be used in place of a device-local" f" layout when calling `jax.make_array_from_callback`. Got {sharding}") - sharding = sharding.sharding if isinstance(sharding, Layout) else sharding + sharding = sharding.sharding if isinstance(sharding, Format) else sharding if not isinstance(sharding, Sharding): raise TypeError( f"sharding should be an instance of `jax.sharding`. Got {sharding} of" @@ -811,7 +812,7 @@ def get_data(index: Index | None) -> ArrayImpl | np.ndarray: and sharding.is_fully_replicated and first_value.is_fully_replicated and first_value.sharding._device_assignment == tuple(devices) - and first_value.layout.device_local_layout == dll): + and first_value.format.device_local_layout == dll): return first_value if dtypes.issubdtype(aval.dtype, dtypes.extended): @@ -822,7 +823,7 @@ def get_data(index: Index | None) -> ArrayImpl | np.ndarray: ) if dll is not None: - devices = [Layout(dll, SingleDeviceSharding(d)) for d in devices] + devices = [Format(dll, SingleDeviceSharding(d)) for d in devices] # pxla.batched_device_put doesn't support Layout... Take the slow route arrays = api.device_put(per_device_values, devices) return ArrayImpl(aval, sharding, arrays, committed=True) @@ -1024,7 +1025,7 @@ def make_array_from_single_device_arrays( shape : Shape of the output ``jax.Array``. This conveys information already included with ``sharding`` and ``arrays`` and serves as a double check. sharding: Sharding: A global Sharding instance which describes how the output jax.Array is laid out across devices. - arrays: Sequence of ``jax.Array``\s that are each single device addressable. ``len(arrays)`` + arrays: `list` or `tuple` of ``jax.Array``\s that are each single device addressable. ``len(arrays)`` must equal ``len(sharding.addressable_devices)`` and the shape of each array must be the same. For multiprocess code, each process will call with a different ``arrays`` argument that corresponds to that processes' data. 
These arrays are commonly created via ``jax.device_put``. @@ -1071,14 +1072,15 @@ def make_array_from_single_device_arrays( if dtypes.issubdtype(aval.dtype, dtypes.extended): return aval.dtype._rules.make_sharded_array(aval, sharding, arrays, committed=True) + arrays = list(arrays) if isinstance(arrays, tuple) else arrays # TODO(phawkins): ideally the cast() could be checked. try: return ArrayImpl(aval, sharding, cast(Sequence[ArrayImpl], arrays), committed=True) except TypeError: - if not isinstance(arrays, Sequence): + if not isinstance(arrays, list): raise TypeError("jax.make_array_from_single_device_arrays `arrays` " - "argument must be a Sequence (list or tuple), but got " + "argument must be a list or tuple, but got " f"{type(arrays)}.") if any(isinstance(arr, core.Tracer) for arr in arrays): raise ValueError( @@ -1092,9 +1094,6 @@ def _get_aval_array(self): return core.update_aval_with_sharding(self.aval, self.sharding) core.pytype_aval_mappings[ArrayImpl] = _get_aval_array -# TODO(jakevdp) replace this with true inheritance at the C++ level. -basearray.Array.register(ArrayImpl) - def _array_mlir_constant_handler(val): try: @@ -1149,7 +1148,7 @@ def shard_device_array(x, devices, indices, sharding): else: # TODO(yashkatariya): Maybe this should be set when we call the handler in # InputsHandler.__call__? - with use_concrete_mesh(None): + with _internal_use_concrete_mesh(None): shards = x._multi_slice(start_indices, limit_indices, removed_dims) aval = core.shaped_abstractify(x) return pxla.batched_device_put(aval, sharding, shards, devices) @@ -1198,7 +1197,7 @@ def _array_shard_arg(xs, shardings, layouts, copy_semantics): x._check_if_deleted() indices, same_indices = _sharding_indices_and_eq(x.sharding, x.shape, sharding) same_layout = (True if layout is None else - x.layout.device_local_layout == layout) + x.format.device_local_layout == layout) if not x.is_fully_addressable: if same_indices and same_layout: @@ -1219,7 +1218,7 @@ def _array_shard_arg(xs, shardings, layouts, copy_semantics): batch_cs.append(cs) # Resharding starts here: elif not same_layout: - results.append(api.device_put(x, Layout(layout, sharding))) + results.append(api.device_put(x, Format(layout, sharding))) elif dispatch.is_single_device_sharding(x.sharding): results.append(shard_device_array(x, devices, indices, sharding)) else: diff --git a/jax/_src/attrs.py b/jax/_src/attrs.py new file mode 100644 index 000000000000..6ace51a091e4 --- /dev/null +++ b/jax/_src/attrs.py @@ -0,0 +1,403 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
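Before the new attrs.py module begins, a short single-process sketch of `jax.make_array_from_single_device_arrays` as described in the array.py hunk above; per that hunk, `arrays` must be a list or tuple with one single-device shard per addressable device:

    import jax
    import numpy as np
    from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

    devices = jax.devices()
    mesh = Mesh(np.array(devices), ("x",))
    sharding = NamedSharding(mesh, P("x"))

    global_data = np.arange(8.0 * len(devices))
    # One shard per addressable device, placed with jax.device_put.
    arrays = [
        jax.device_put(global_data[idx], d)
        for d, idx in sharding.addressable_devices_indices_map(global_data.shape).items()
    ]
    arr = jax.make_array_from_single_device_arrays(global_data.shape, sharding, arrays)
    print(arr.shape, len(arr.addressable_shards))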
+ +from __future__ import annotations + +from typing import Any +from collections.abc import Callable + +from jax._src import core +from jax._src import source_info_util +from jax._src import api_util +from jax._src import linear_util as lu +from jax._src.ad_util import (Zero) +from jax._src.api_util import flatten_fun_nokwargs +from jax._src.interpreters import ad +from jax._src.interpreters import partial_eval as pe +from jax._src.tree_util import (tree_flatten, tree_unflatten, tree_structure, + treedef_tuple) +from jax._src.util import unzip2, safe_map, safe_zip, split_list +from jax._src.dtypes import dtype, float0 + +map, unsafe_map = safe_map, map +zip, unsafe_zip = safe_zip, zip + +Array = Any +JaxVal = Any +PyTree = Any +PyTreeDef = Any + +ReadWrite = pe.ReadWrite +Append = pe.Append + +register = api_util.register_class_with_attrs +dne_sentinel = pe.dne_sentinel + +def jax_getattr(obj: Any, attr: str) -> PyTree: + with core.take_current_trace() as t: + return t.process_getattr(obj, attr) + +def jax_setattr(obj: Any, attr: str, val: PyTree) -> None: + with core.take_current_trace() as t: + return t.process_setattr(obj, attr, val) + +def jax_appendattr(obj: Any, attr: str, val: Array) -> None: + import jax.numpy as jnp # pytype: disable=import-error + return jax_extendattr(obj, attr, jnp.expand_dims(val, 0)) + +def jax_extendattr(obj: Any, attr: str, val: Array) -> None: + with core.take_current_trace() as t: + return t.process_extendattr(obj, attr, val) + +def _getattr_impl(_, obj, attr): + return getattr(obj, attr) +core.EvalTrace.process_getattr = _getattr_impl + +def _setattr_impl(_, obj, attr, val): + setattr(obj, attr, val) +core.EvalTrace.process_setattr = _setattr_impl + +def _extendattr_impl(_, obj, attr, val): + import jax.numpy as jnp # pytype: disable=import-error + cur = getattr(obj, attr, dne_sentinel) + if cur is dne_sentinel: + new = val + else: + _check_append_type_agreement(obj, attr, core.typeof(cur), core.typeof(val)) + new = jnp.concatenate([cur, val]) + setattr(obj, attr, new) +core.EvalTrace.process_extendattr = _extendattr_impl + +def _check_append_type_agreement(_, attr, curtype, valtype): + expected = core.mapped_aval(curtype.shape[0], 0, curtype) + got = core.mapped_aval(valtype.shape[0], 0, valtype) + if not core.typematch(expected, got): + raise TypeError( + f"can only append to attr {attr} with values of trailing shape " + f"{expected.str_short()}, but appendattr got value of type " + f"{valtype.str_short()} which has trailing shape {got.str_short()}.") + +def _ensure_tracked(trace: pe.DynamicJaxprTrace, obj: Any, attr: str, + kind: pe.AttrKind): + frame = trace.frame + source_info = source_info_util.current() + + def new_tracer(x): + aval = core.get_aval(x) + tracer = pe.DynamicJaxprTracer(trace, aval, source_info) + var = frame.tracer_to_var[id(tracer)] = frame.newvar(aval) + frame.attrs_vars.append(var) + frame.tracers.append(tracer) + return tracer + + if (obj, attr, Append) in frame.attrs_tracked: + raise TypeError(f"can't read/write to append-only attr {attr}") + + if (obj, attr, kind) not in frame.attrs_tracked: + init_val = getattr(obj, attr, dne_sentinel) + frame.attrs_inits.append(init_val) + init_vals, init_tree = tree_flatten(init_val) + tracers = map(new_tracer, init_vals) + setattr(obj, attr, tree_unflatten(init_tree, tracers)) + frame.attrs_tracked.append((obj, attr, kind)) +pe.DynamicJaxprTrace._ensure_tracked = _ensure_tracked + +def _getattr_staging(trace, obj, attr): + trace._ensure_tracked(obj, attr, ReadWrite) + return getattr(obj, 
attr)
+pe.DynamicJaxprTrace.process_getattr = _getattr_staging
+
+def _setattr_staging(trace, obj, attr, val):
+  trace._ensure_tracked(obj, attr, ReadWrite)
+  setattr(obj, attr, val)
+pe.DynamicJaxprTrace.process_setattr = _setattr_staging
+
+def _extendattr_staging(trace, obj, attr, val):
+  import jax.numpy as jnp  # pytype: disable=import-error
+  frame = trace.frame
+
+  if (obj, attr, ReadWrite) in frame.attrs_tracked:
+    raise TypeError(f"can't append to read/write-only attr {attr}")
+
+  first_write = (obj, attr, Append) not in frame.attrs_tracked
+  init_val = getattr(obj, attr, dne_sentinel)
+  if init_val is not dne_sentinel:
+    _check_append_type_agreement(obj, attr, core.typeof(init_val), core.typeof(val))
+  if first_write:
+    frame.attrs_inits.append(init_val)
+    frame.attrs_tracked.append((obj, attr, Append))
+    tracer = val
+  else:
+    assert init_val is not dne_sentinel
+    with core.set_current_trace(trace):
+      tracer = jnp.concatenate([init_val, val])
+  setattr(obj, attr, tracer)
+pe.DynamicJaxprTrace.process_extendattr = _extendattr_staging
+
+
+def jvp(f, primals, tangents, attr_tangents):
+  attrs, attr_tangents = unzip2(((o, a), t) for o, a, t in attr_tangents)
+  attr_primals = tuple(jax_getattr(o, a) for o, a in attrs)
+  primals_flat, in_tree = tree_flatten((attr_primals, *primals))
+  tangents_flat, in_tree_ = tree_flatten((attr_tangents, *tangents))
+  if in_tree != in_tree_: raise Exception
+  dbg = api_util.debug_info("attrs_jvp", f, primals, {})
+  f_, out_tree = flatten_fun_nokwargs(
+      _set_attrs(lu.wrap_init(f, debug_info=dbg), attrs), in_tree)
+  out_primals_flat, out_tangents_flat, tangent_attrs_out = _jvp(f_).call_wrapped(
+      primals_flat, tangents_flat)
+  out_primals = tree_unflatten(out_tree(), out_primals_flat)
+  out_tangents = tree_unflatten(out_tree(), out_tangents_flat)
+  return out_primals, out_tangents, tangent_attrs_out
+
+@lu.transformation2
+def _set_attrs(f, attrs, attr_vals, *args):
+  for (o, a), x in zip(attrs, attr_vals):
+    jax_setattr(o, a, x)
+  return f(*args)
+
+def _jvp(fun: lu.WrappedFun):
+  return jvpfun2(jvp_subtrace2(fun))
+
+@lu.transformation2
+def jvpfun2(f, primals, tangents):
+  tag = core.TraceTag()
+  tangents = [Zero.from_primal_value(t) if not isinstance(t, Zero)
+              and dtype(t) == float0 else t for t in tangents]
+  ctx = source_info_util.transform_name_stack('jvp')
+  with ctx:
+    out_primals, out_tangents, tangent_attrs_out = f(tag, primals, tangents)
+  return out_primals, out_tangents, tangent_attrs_out
+
+@lu.transformation2
+def jvp_subtrace2(f, tag, primals, tangents):
+  with core.take_current_trace() as parent_trace:
+    trace = ad.JVPTrace(parent_trace, tag)
+    tag.attrs_tracked = []  # attrs written to
+    in_tracers = [ad.JVPTracer(trace, x, t) if type(t) is not ad.Zero else x
+                  for x, t in zip(primals, tangents)]
+    with core.set_current_trace(trace):
+      ans = f(*in_tracers)
+    out_primals, out_tangents = unzip2(map(trace.to_primal_tangent_pair, ans))
+    tangent_attrs_out = []
+    for (obj, name) in tag.attrs_tracked:
+      primal, tangent = trace.to_primal_tangent_pair(jax_getattr(obj, name))
+      jax_setattr(obj, name, primal)
+      if type(tangent) is not ad.Zero:
+        tangent_attrs_out.append((obj, name, tangent))
+    del tag.attrs_tracked
+  return out_primals, out_tangents, tangent_attrs_out
+
+def _setattr_jvp(trace, obj, attr, maybe_tracer):
+  primal, tangent = trace.to_primal_tangent_pair(maybe_tracer)
+  if isinstance(tangent, ad.Zero):
+    return setattr(obj, attr, primal)
+  if (obj, attr) not in trace.tag.attrs_tracked:
+    trace.tag.attrs_tracked.append((obj, 
attr)) + return setattr(obj, attr, ad.JVPTracer(trace, primal, tangent)) +ad.JVPTrace.process_setattr = _setattr_jvp + +def _getattr_jvp(trace, obj, attr): + return getattr(obj, attr) +ad.JVPTrace.process_getattr = _getattr_jvp + +ad.LinearizeTrace.process_setattr = _setattr_jvp +ad.LinearizeTrace.process_getattr = _getattr_jvp + +def linearize(f: Callable, *primals, attrs: list[tuple[Any, str]] = []): + attr_primals = [jax_getattr(o, a) for o, a in attrs] + attr_avals = [core.get_aval(p) for p in attr_primals] + primals_flat, in_tree = tree_flatten(primals) + tree = treedef_tuple((tree_structure(attr_primals), *in_tree.children())) + dbg = api_util.debug_info("attrs linearize", f, primals, {}) + f_, out_tree = flatten_fun_nokwargs( + _set_attrs(lu.wrap_init(f, debug_info=dbg), attrs), tree) + primal_out, out_pvals, jaxpr, consts, attrs_out = _linearize( + f_, *attr_primals, *primals_flat) + f_lin = _lin_wrap(jaxpr, consts, out_pvals, attr_avals, (in_tree, out_tree()), + attrs, attrs_out) + return tree_unflatten(out_tree(), primal_out), f_lin + +def _linearize(traceable: lu.WrappedFun, *primals): + jvpfun, attrs = _split_attrs(_jvp(traceable)) + in_pvals = (tuple(pe.PartialVal.known(p) for p in primals) + + tuple(pe.PartialVal.unknown(core.get_aval(p).to_tangent_aval()) + for p in primals)) + _, in_tree = tree_flatten((primals, primals)) + jvpfun_flat, out_tree = flatten_fun_nokwargs(jvpfun, in_tree) + jaxpr, out_pvals, consts = pe.trace_to_jaxpr_nounits(jvpfun_flat, in_pvals) + out_primals_pvals, out_tangents_pvals, out_tangent_attr_pvals = \ + tree_unflatten(out_tree(), out_pvals) + out_primals_consts = [pval.get_known() for pval in out_primals_pvals] + return (out_primals_consts, [*out_tangents_pvals, *out_tangent_attr_pvals], + jaxpr, consts, attrs()) + +@lu.transformation_with_aux2 +def _split_attrs(f, store, *args, **kwargs): + primals, tangents, tangent_attrs = f(*args, **kwargs) + attrs, tangent_attr_vals = unzip2(((o, a), t) for o, a, t in tangent_attrs) + store.store(attrs) + return primals, tangents, tangent_attr_vals + +def _lin_wrap(jaxpr, consts, out_pvals, attr_avals, io_tree, in_attrs, out_attrs): + in_tree, out_tree = io_tree + def f_lin(*tangents, attr_tangents): + if set(attr_tangents) - set(in_attrs): raise Exception + tangents_, in_tree_ = tree_flatten(tangents) + assert in_tree == in_tree_ + attr_tangents_ = [attr_tangents.get(a, ad.Zero(aval)) + for a, aval in zip(in_attrs, attr_avals)] + out = core.eval_jaxpr(jaxpr, consts, *attr_tangents_, *tangents_) + out_ = iter(out) + out = [p.get_known() if p.is_known() else next(out_) for p in out_pvals] + assert next(out_, None) is None + tangents_out, attr_tangents_out = split_list(out, [len(out)-len(out_attrs)]) + out_ct = tree_unflatten(out_tree, tangents_out) + return out_ct, dict(zip(out_attrs, attr_tangents_out)) + return f_lin + + +def vjp(f, *primals, attrs: list[tuple[Any, str]] = []): + attr_primals = [jax_getattr(o, a) for o, a in attrs] + primals_flat, in_tree = tree_flatten(primals) + tree = treedef_tuple((tree_structure(attr_primals), *in_tree.children())) + dbg = api_util.debug_info("attrs vjp", f, primals, {}) + f_, out_tree = flatten_fun_nokwargs( + _set_attrs(lu.wrap_init(f, debug_info=dbg), attrs), tree) + primal_out, out_pvals, jaxpr, consts, attrs_out = _linearize( + f_, *attr_primals, *primals_flat) + attr_avals = [core.get_aval(jax_getattr(o, a)).to_tangent_aval() + for o, a in attrs_out] + f_vjp = _vjp_wrap(jaxpr, consts, out_pvals, attr_avals, (in_tree, out_tree()), + attrs, attrs_out) + return 
tree_unflatten(out_tree(), primal_out), f_vjp + +def _vjp_wrap(jaxpr, consts, out_pvals, attr_avals, io_tree, in_attrs, out_attrs): + in_tree, out_tree = io_tree + dummies = [ad.UndefinedPrimal(v.aval) for v in jaxpr.invars] + def f_vjp(out_ct, *, attr_cotangents: dict[tuple[Any, str], JaxVal] = {}): + out_cts, out_tree_ = tree_flatten(out_ct) + assert out_tree == out_tree_ + attr_cts = [attr_cotangents.get(a, ad.Zero(aval)) + for a, aval in zip(out_attrs, attr_avals)] + out = ad.backward_pass(jaxpr, (), consts, dummies, (*out_cts, *attr_cts)) + in_attr_bars, arg_cts = split_list(out, [len(in_attrs)]) + args_ct = tree_unflatten(in_tree, map(ad.instantiate_zeros, arg_cts)) + return args_ct, dict(zip(in_attrs, in_attr_bars)) + return f_vjp + + +class Box: + _val: PyTree + _tag: core.OpaqueTraceState + def __init__(self, val): + self._val = val + self._tag = core.get_opaque_trace_state() + def get(self): + with core.take_current_trace() as t: + return t.process_box_get(self) + def set(self, val): + with core.take_current_trace() as t: + return t.process_box_set(self, val) + +def _box_get_impl(trace, box): + return box._val +core.EvalTrace.process_box_get = _box_get_impl + +def _box_set_impl(trace, box, val): + box._val = val +core.EvalTrace.process_box_set = _box_set_impl + +def _is_local(trace, box): + is_arg = box._tag._trace_ref() is trace + if is_arg: assert box._tag._trace_ref() is trace + return is_arg + +def _box_get_staging(trace, box): + if not _is_local(trace, box): + trace._ensure_tracked(box, '_val', pe.BoxAttr) + return box._val +pe.DynamicJaxprTrace.process_box_get = _box_get_staging + +def _box_set_staging(trace, box, val): + if not _is_local(trace, box): + trace._ensure_tracked(box, '_val', pe.BoxAttr) + box._val = val +pe.DynamicJaxprTrace.process_box_set = _box_set_staging + +def _box_get_jvp(trace, box): + return box._val +ad.JVPTrace.process_box_get = _box_get_jvp + +def _box_set_jvp(trace, box, val): + primal, tangent = trace.to_primal_tangent_pair(val) + if not (isinstance(tangent, ad.Zero) or _is_local(trace, box)): + raise Exception + if isinstance(tangent, ad.Zero): + box._val = primal + else: + box._val = ad.JVPTracer(trace, primal, tangent) +ad.JVPTrace.process_box_set = _box_set_jvp + +def _box_get_linearize(trace, box): + return box._val +ad.LinearizeTrace.process_box_get = _box_get_linearize + +def _box_set_linearize(trace, box, val): + primal, tangent = trace.to_primal_tangent_pair(val) + if not (isinstance(tangent, ad.Zero) or _is_local(trace, box)): + raise Exception + if isinstance(tangent, ad.Zero): + box._val = primal + else: + raise NotImplementedError # TODO + box._val = ad.LinearizeTracer(trace, primal, tangent) +ad.LinearizeTrace.process_box_set = _box_set_linearize + + +class List: + _val: PyTree + _tag: core.OpaqueTraceState + _is_arg: bool + def __init__(self, val=None): + self._val = [] if val is None else val[:] + self._tag = core.get_opaque_trace_state() + self._is_arg = False + def append(self, val): + with core.take_current_trace() as t: + return t.process_list_append(self, val) + def get(self): + with core.take_current_trace() as t: + if _is_local(t, self) and not self._is_arg: + return self._val[:] # defensive copy in case caller erroneously mutates + raise Exception("can't read the value of a List that was not created in " + "this scope") +AppendList = List + +def _list_append_impl(trace, lst, val): + lst._val.append(val) +core.EvalTrace.process_list_append = _list_append_impl + +def _list_append_staging(trace, lst, val): + if not 
_is_local(trace, lst): + _ensure_list_tracked(trace, lst) + return _list_append_impl(trace, lst, val) +pe.DynamicJaxprTrace.process_list_append = _list_append_staging + +def _ensure_list_tracked(trace, lst): + frame = trace.frame + if (lst, '_val', pe.ListAttr) not in frame.attrs_tracked: + frame.attrs_inits.append(lst._val) + frame.attrs_tracked.append((lst, '_val', pe.ListAttr)) + lst._val = [] diff --git a/jax/_src/basearray.py b/jax/_src/basearray.py index a89d4a2949be..01a988782671 100644 --- a/jax/_src/basearray.py +++ b/jax/_src/basearray.py @@ -16,10 +16,14 @@ from __future__ import annotations -import abc -import numpy as np -from typing import Any, Union from collections.abc import Sequence +import sys +from typing import Any, Union + +from jax._src.lib import xla_client as xc +from jax._src.util import use_cpp_class +import numpy as np + # TODO(jakevdp): fix import cycles and define these. Device = Any @@ -29,7 +33,9 @@ # Array is a type annotation for standard JAX arrays and tracers produced by # core functions in jax.lax and jax.numpy; it is not meant to include # future non-standard array types like KeyArray and BInt. -class Array(abc.ABC): + + +class Array: """Array base class for JAX ``jax.Array`` is the public interface for instance checks and type annotation @@ -47,8 +53,6 @@ def f(x: Array) -> Array: # type annotations are valid for traced and non-trace :func:`jax.numpy.array`, :func:`jax.numpy.zeros`, :func:`jax.numpy.ones`, :func:`jax.numpy.full`, :func:`jax.numpy.arange`, etc. """ - # Note: abstract methods for this class are defined dynamically in - # lax_numpy.py # For the sake of static type analysis, these definitions are mirrored in the # associated basearray.pyi file. @@ -56,42 +60,41 @@ def f(x: Array) -> Array: # type annotations are valid for traced and non-trace __hash__ = None @property - @abc.abstractmethod def dtype(self) -> np.dtype: """The data type (:class:`numpy.dtype`) of the array.""" + raise NotImplementedError @property - @abc.abstractmethod def ndim(self) -> int: """The number of dimensions in the array.""" + raise NotImplementedError @property - @abc.abstractmethod def size(self) -> int: """The total number of elements in the array.""" + raise NotImplementedError @property - @abc.abstractmethod def shape(self) -> tuple[int, ...]: """The shape of the array.""" + raise NotImplementedError # Documentation for sharding-related methods and properties defined on ArrayImpl: - @abc.abstractmethod def addressable_data(self, index: int) -> Array: """Return an array of the addressable data at a particular index.""" + raise NotImplementedError @property - @abc.abstractmethod def addressable_shards(self) -> Sequence[Shard]: """List of addressable shards.""" + raise NotImplementedError @property - @abc.abstractmethod def global_shards(self) -> Sequence[Shard]: """List of global shards.""" + raise NotImplementedError @property - @abc.abstractmethod def is_fully_addressable(self) -> bool: """Is this Array fully addressable? @@ -103,19 +106,19 @@ def is_fully_addressable(self) -> bool: a jax.Array which is fully replicated can span across multiple hosts and is not fully addressable. 
""" + raise NotImplementedError @property - @abc.abstractmethod def is_fully_replicated(self) -> bool: """Is this Array fully replicated?""" + raise NotImplementedError @property - @abc.abstractmethod def sharding(self) -> Sharding: """The sharding for the array.""" + raise NotImplementedError @property - @abc.abstractmethod def committed(self) -> bool: """Whether the array is committed or not. @@ -137,20 +140,20 @@ def committed(self) -> bool: a + b # Raises an error ``` - See https://jax.readthedocs.io/en/latest/faq.html#controlling-data-and-computation-placement-on-devices + See https://docs.jax.dev/en/latest/faq.html#controlling-data-and-computation-placement-on-devices for more information. """ + raise NotImplementedError @property - @abc.abstractmethod def device(self) -> Device | Sharding: """Array API-compatible device attribute. For single-device arrays, this returns a Device. For sharded arrays, this returns a Sharding. """ + raise NotImplementedError - @abc.abstractmethod def copy_to_host_async(self): """Copies an ``Array`` to the host asynchronously. @@ -165,17 +168,24 @@ def copy_to_host_async(self): array, but does not wait for the copy to complete. This may speed up a future on-host access to the array's contents. """ + raise NotImplementedError +Array = use_cpp_class(xc.Array)(Array) Array.__module__ = "jax" + # StaticScalar is the Union of all scalar types that can be converted to # JAX arrays, and are possible to mark as static arguments. StaticScalar = Union[ np.bool_, np.number, # NumPy scalar types bool, int, float, complex, # Python scalar types ] -StaticScalar.__doc__ = "Type annotation for JAX-compatible static scalars." + +if sys.version_info[:2] < (3, 14): + # Python 3.14 raises + # AttributeError: 'typing.Union' object attribute '__doc__' is read-only + StaticScalar.__doc__ = "Type annotation for JAX-compatible static scalars." # ArrayLike is a Union of all objects that can be implicitly converted to a @@ -187,4 +197,8 @@ def copy_to_host_async(self): np.ndarray, # NumPy array type StaticScalar, # valid scalars ] -ArrayLike.__doc__ = "Type annotation for JAX array-like objects." + +if sys.version_info[:2] < (3, 14): + # Python 3.14 raises + # AttributeError: 'typing.Union' object attribute '__doc__' is read-only + ArrayLike.__doc__ = "Type annotation for JAX array-like objects." diff --git a/jax/_src/basearray.pyi b/jax/_src/basearray.pyi index a368b593332d..54098a081f39 100644 --- a/jax/_src/basearray.pyi +++ b/jax/_src/basearray.pyi @@ -14,11 +14,13 @@ import abc from collections.abc import Callable, Sequence from types import ModuleType -from typing import Any, Protocol, Union, runtime_checkable +from typing import Any, Protocol, runtime_checkable, Union import numpy as np +from jax._src.partition_spec import PartitionSpec as P +from jax._src.named_sharding import NamedSharding from jax._src.sharding import Sharding -from jax._src.partition_spec import PartitionSpec + # TODO(jakevdp) de-duplicate this with the DTypeLike definition in typing.py. # We redefine these here to prevent circular imports. @@ -39,7 +41,8 @@ Traceback = Any PrecisionLike = Any -class Array(abc.ABC): +# TODO(slebedev): Remove the metaclass once ``jax_extension_version >= 325``. +class Array(metaclass=abc.ABCMeta): aval: Any @property @@ -181,12 +184,15 @@ class Array(abc.ABC): promote_integers: bool = True) -> Array: ... def ptp(self, axis: Axis = None, out: None = None, keepdims: bool = False) -> Array: ... - def ravel(self, order: str = 'C') -> Array: ... 
+ def ravel(self, order: str = 'C', *, + out_sharding: NamedSharding | P | None = ...) -> Array: ... @property def real(self) -> Array: ... def repeat(self, repeats: ArrayLike, axis: int | None = None, *, - total_repeat_length: int | None = None) -> Array: ... - def reshape(self, *args: Any, order: str = "C") -> Array: ... + total_repeat_length: int | None = None, + out_sharding: NamedSharding | P | None = None) -> Array: ... + def reshape(self, *args: Any, order: str = "C", + out_sharding: NamedSharding | P | None = ...) -> Array: ... def round(self, decimals: int = 0, out: None = None) -> Array: ... def searchsorted(self, v: ArrayLike, side: str = 'left', sorter: ArrayLike | None = None, *, method: str = 'scan') -> Array: ... @@ -280,25 +286,35 @@ class _IndexUpdateHelper: class _IndexUpdateRef: def get(self, indices_are_sorted: bool = False, unique_indices: bool = False, mode: str | None = None, fill_value: StaticScalar | None = None, - out_spec: Sharding | PartitionSpec | None = None) -> Array: ... + out_sharding: Sharding | P | None = None, wrap_negative_indices: bool = True) -> Array: ... def set(self, values: Any, indices_are_sorted: bool = False, unique_indices: bool = False, - mode: str | None = None, fill_value: StaticScalar | None = None) -> Array: ... + mode: str | None = None, fill_value: StaticScalar | None = None, + wrap_negative_indices: bool = True) -> Array: ... def add(self, values: Any, indices_are_sorted: bool = False, - unique_indices: bool = False, mode: str | None = None) -> Array: ... + unique_indices: bool = False, mode: str | None = None, + wrap_negative_indices: bool = True) -> Array: ... def subtract(self, values: Any, *, indices_are_sorted: bool = False, - unique_indices: bool = False, mode: str | None = None) -> Array: ... + unique_indices: bool = False, mode: str | None = None, + wrap_negative_indices: bool = True) -> Array: ... def mul(self, values: Any, indices_are_sorted: bool = False, - unique_indices: bool = False, mode: str | None = None) -> Array: ... + unique_indices: bool = False, mode: str | None = None, + wrap_negative_indices: bool = True) -> Array: ... def multiply(self, values: Any, indices_are_sorted: bool = False, - unique_indices: bool = False, mode: str | None = None) -> Array: ... + unique_indices: bool = False, mode: str | None = None, + wrap_negative_indices: bool = True) -> Array: ... def divide(self, values: Any, indices_are_sorted: bool = False, - unique_indices: bool = False, mode: str | None = None) -> Array: ... + unique_indices: bool = False, mode: str | None = None, + wrap_negative_indices: bool = True) -> Array: ... def power(self, values: Any, indices_are_sorted: bool = False, - unique_indices: bool = False, mode: str | None = None) -> Array: ... + unique_indices: bool = False, mode: str | None = None, + wrap_negative_indices: bool = True) -> Array: ... def min(self, values: Any, indices_are_sorted: bool = False, - unique_indices: bool = False, mode: str | None = None) -> Array: ... + unique_indices: bool = False, mode: str | None = None, + wrap_negative_indices: bool = True) -> Array: ... def max(self, values: Any, indices_are_sorted: bool = False, - unique_indices: bool = False, mode: str | None = None) -> Array: ... + unique_indices: bool = False, mode: str | None = None, + wrap_negative_indices: bool = True) -> Array: ... def apply(self, func: Callable[[ArrayLike], ArrayLike], indices_are_sorted: bool = False, - unique_indices: bool = False, mode: str | None = None) -> Array: ... 
+ unique_indices: bool = False, mode: str | None = None, + wrap_negative_indices: bool = True) -> Array: ... diff --git a/jax/_src/buffer_callback.py b/jax/_src/buffer_callback.py new file mode 100644 index 000000000000..a1dfb5c2ff18 --- /dev/null +++ b/jax/_src/buffer_callback.py @@ -0,0 +1,267 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections.abc import Callable, Sequence +import functools +from typing import Any + +import numpy as np + +from jax._src import core +from jax._src import dispatch +from jax._src import effects +from jax._src import ffi +from jax._src import tree_util +from jax._src import util +from jax._src.interpreters import ad +from jax._src.interpreters import batching +from jax._src.interpreters import mlir +from jax._src.lib import ffi as ffi_lib + +export = util.set_module("jax.experimental.buffer_callback") +Buffer = export(ffi_lib.Buffer) +ExecutionStage = export(ffi_lib.ExecutionStage) +ExecutionContext = export(ffi_lib.ExecutionContext) + + +def buffer_callback( + callback: Callable[..., None], + result_shape_dtypes: object, + *, + has_side_effect: bool = False, + vmap_method: str | None = None, + input_output_aliases: dict[int, int] | None = None, + command_buffer_compatible: bool = False, +): + """An experimental callback that operates in place on device buffers. + + Only supported on CPU and GPU backends. + + Note that the plan is for this to eventually be replaced by a consolidated + callback API built using JAX mutable arrays, but for now this provides a + mechanism for prototyping computational kernels using other Python libraries + including Numpy, PyTorch, Cupy, and others. + + Let's start with a simple example: + + >>> def py_add_one_inplace(ctx, out, x): + ... np.asarray(out)[...] = np.asarray(x) + 1 + ... + >>> x = jnp.array(41, dtype=jnp.int32) + >>> out_type = jax.ShapeDtypeStruct(x.shape, x.dtype) + >>> add_one = buffer_callback(py_add_one_inplace, out_type) + >>> add_one(x) # doctest: +SKIP + Array(42, dtype=int32) + + In this example, we're executing a numpy computation via JAX, and this could + have been implemented using :func:`jax.pure_callback`, but in this case, the + output is being populated in-place. This means that JAX doesn't need to copy + the output arrays upon returning from the callback. Note that even though the + callback function operates on mutable buffers, JAX still sees this as an + operation that consumes and produces regular immutable JAX arrays. + + Unlike the other JAX callback APIs, ``buffer_callback`` requires that the + user-defined Python function have the following signature: + + .. code-block:: python + + def callback(ctx: ExecutionContext, out, *args) -> None: + ... 
+ + where ``ctx`` is an instance of + :class:`~jax.experimental.buffer_callback.ExecutionContext`, which mainly + provides access to XLA's computation stream when running on GPU, ``out`` is a + pytree of mutable :class:`~jax.experimental.buffer_callback.Buffer` objects, + and the ``args`` arguments have the same pytree structure as the inputs, but + each leaf is :class:`~jax.experimental.buffer_callback.Buffer`. This callback + should not return any values, and it should overwrite the ``out`` buffers in + place to output values back to JAX. + + It's important to note that this Python function can't really be called + except via ``buffer_callback`` itself, because it's not (yet!) possible to + construct mutable JAX buffers directly in Python. + + The bespoke :class:`~jax.experimental.buffer_callback.Buffer` type is an + array-like object that supports the ``__array__`` protocol on CPU, the + ``__cuda_array_interface__`` protocol on GPU, and the ``__dlpack__`` protocol + on both CPU and GPU. + + Args: + callback: A Python function with the signature and behavior described above. + result_shape_dtypes: A pytree whose leaves have ``shape`` and ``dtype`` + attributes, with a structure that matches the expected output of the + callback function at runtime. :class:`jax.ShapeDtypeStruct` is often used + to define leaf values. + has_side_effect: Whether the callback has side effects. + vmap_method: A string specifying how the callback transforms under + :func:`~jax.vmap` as described in the docs for :func:`~jax.pure_callback`. + input_output_aliases: a dictionary mapping the index of some inputs to + the index of the output that aliases them. These indices are in the + flattened inputs and outputs. + command_buffer_compatible: if ``True``, the callback will be traced into + the command buffer. This means that the Python code should only be + executed once, and then the operations will be replayed for every + subsequent call. + + Returns: + A new callable that accepts :class:`jax.Array` inputs (and pytrees thereof), + and returns a pytree of :class:`jax.Array` objects whose structure matches that + of ``result_shape_dtypes``. + + See Also: + - :func:`jax.pure_callback`: callback designed for pure host functions. + - :func:`jax.experimental.io_callback`: callback designed for impure host + functions. + - :func:`jax.debug.callback`: callback designed for general-purpose + debugging. + - :func:`jax.debug.print`: callback designed for printing. + """ + flat_shape_dtypes, out_tree = tree_util.tree_flatten(result_shape_dtypes) + flat_result_avals = tuple( + core.ShapedArray(x.shape, x.dtype) for x in flat_shape_dtypes + ) + + def wrapped_callback(*args, **kwargs): + flat_args, in_tree = tree_util.tree_flatten((args, kwargs)) + + in_avals = [core.get_aval(x) for x in flat_args] + static_input_output_aliases: tuple[tuple[int, int], ...]
= () + if input_output_aliases is not None: + for i_idx, o_idx in sorted(input_output_aliases.items()): + i_idx, o_idx = int(i_idx), int(o_idx) + if i_idx >= len(args): + raise ValueError( + f"input_output_aliases contains the mapping '{i_idx}:{o_idx}' " + f"with input index {i_idx} outside the range [0, " + f"{len(args)}).") + if o_idx >= len(flat_result_avals): + raise ValueError( + f"input_output_aliases contains the mapping '{i_idx}:{o_idx}' " + f"with output index {o_idx} outside the range [0, " + f"{len(flat_result_avals)}).") + in_aval = in_avals[i_idx] + out_aval = flat_result_avals[o_idx] + if not ffi._check_compatible_avals(in_aval, out_aval): + raise ValueError( + f"input_output_aliases contains the mapping '{i_idx}:{o_idx}' " + f"referring to an input with abstract value {in_aval} and an " + f"output with a different abstract value {out_aval}.") + static_input_output_aliases += ((i_idx, o_idx),) + + out_flat = buffer_callback_p.bind( + *flat_args, + callback=callback, + result_avals=flat_result_avals, + in_tree=in_tree, + out_tree=out_tree, + vmap_method=vmap_method, + has_side_effect=has_side_effect, + input_output_aliases=static_input_output_aliases, + command_buffer_compatible=command_buffer_compatible, + ) + return tree_util.tree_unflatten(out_tree, out_flat) + + return wrapped_callback + + +buffer_callback_p = core.Primitive("buffer_callback") +buffer_callback_p.multiple_results = True +dispatch.prim_requires_devices_during_lowering.add(buffer_callback_p) +dispatch.simple_impl(buffer_callback_p) + + +class BufferCallbackEffect(effects.Effect): + def __str__(self): + return "BufferCallback" + +_BufferCallbackEffect = BufferCallbackEffect() +effects.lowerable_effects.add_type(BufferCallbackEffect) +effects.control_flow_allowed_effects.add_type(BufferCallbackEffect) + + +@buffer_callback_p.def_effectful_abstract_eval +def _buffer_callback_abstract_eval( + *args, + result_avals: tuple[core.ShapedArray, ...], + has_side_effect: bool, + **_, +): + del args + effects = {_BufferCallbackEffect} if has_side_effect else core.no_effects + return result_avals, effects + + +def _buffer_callback_jvp_rule(*args, **kwargs): + del args, kwargs + raise ValueError( + "Buffer callbacks do not support JVP. " + "Please use `jax.custom_jvp` to use callbacks while taking gradients.") +ad.primitive_jvps[buffer_callback_p] = _buffer_callback_jvp_rule + + +def _buffer_callback_transpose_rule(*args, **kwargs): + del args, kwargs + raise ValueError( + "Buffer callbacks do not support transpose. 
" + "Please use `jax.custom_vjp` to use callbacks while taking gradients.") +ad.primitive_transposes[buffer_callback_p] = _buffer_callback_transpose_rule + +batching.primitive_batchers[buffer_callback_p] = functools.partial( + ffi.ffi_batching_rule, buffer_callback_p +) + + +def _buffer_callback_lowering( + ctx: mlir.LoweringRuleContext, + *args: Any, + callback, + in_tree: Any, + out_tree: Any, + has_side_effect: bool, + input_output_aliases: Sequence[tuple[int, int]], + command_buffer_compatible: bool, + **_, +): + + if len(ctx.module_context.platforms) > 1: + raise NotImplementedError("multi-platform lowering for buffer_callback") + platform = ctx.module_context.platforms[0] + target_name = { + "cpu": "xla_buffer_python_cpu_callback", + "cuda": "xla_buffer_python_gpu_callback", + "rocm": "xla_buffer_python_gpu_callback", + }.get(platform) + if target_name is None: + raise ValueError(f"`buffer_callback` not supported on {platform} backend.") + + if command_buffer_compatible and platform in ("cuda", "rocm"): + target_name += "_cmd_buffer" + + def wrapped_callback(exec_ctx, *args: Any): + args_in, args_out = util.split_list(args, [in_tree.num_leaves]) + py_args_in, py_kwargs_in = tree_util.tree_unflatten(in_tree, args_in) + py_args_out = tree_util.tree_unflatten(out_tree, args_out) + if callback(exec_ctx, py_args_out, *py_args_in, **py_kwargs_in) is not None: + raise ValueError("buffer_callback callback must not return any values.") + return () + + ctx.module_context.add_host_callback(wrapped_callback) + index = np.uint64(len(ctx.module_context.host_callbacks) - 1) + rule = ffi.ffi_lowering( + target_name, + has_side_effect=has_side_effect, + operand_output_aliases=dict(input_output_aliases), + ) + return rule(ctx, *args, index=index) +mlir.register_lowering(buffer_callback_p, _buffer_callback_lowering) diff --git a/jax/_src/cache_key.py b/jax/_src/cache_key.py index e4b6e7a2669c..906e686727ef 100644 --- a/jax/_src/cache_key.py +++ b/jax/_src/cache_key.py @@ -56,7 +56,7 @@ def get_flag_prefixes() -> list[str]: def custom_hook() -> str: """Custom hook for any addition to the cache key. - The custom hook will be called everytime get() is called and can be + The custom hook will be called every time get() is called and can be defined to return a string that will be hashed into the cache key. """ return "" @@ -110,6 +110,10 @@ def get( bytes(jaxlib_version_str.encode("utf-8")) ), ), + ( + "backend version", + lambda hash_obj: _hash_platform(hash_obj, backend) + ), ( "XLA flags", lambda hash_obj: _hash_xla_flags(hash_obj, get_flag_prefixes()), @@ -126,7 +130,7 @@ def get( ), ( "accelerator_config", - lambda hash_obj: _hash_accelerator_config(hash_obj, devices, backend), + lambda hash_obj: _hash_accelerator_config(hash_obj, devices), ), ( "compression", @@ -220,7 +224,7 @@ def _hash_devices(hash_obj, devices: np.ndarray) -> None: _hash_string(hash_obj, device.device_kind) -def _hash_accelerator_config(hash_obj, accelerators: np.ndarray, backend): +def _hash_accelerator_config(hash_obj, accelerators: np.ndarray): accelerator_devices = [] for accelerator in accelerators.flat: accelerator_devices.append(accelerator) @@ -233,9 +237,8 @@ def _hash_accelerator_config(hash_obj, accelerators: np.ndarray, backend): # PjRtTopologyDescription as yet. 
logger.info("get (_hash_accelerator_config): unable to hash " "accelerator config, falling back to hashing " - "devices + platform: %s (type %s)", ex, type(ex)) + "devices %s (type %s)", ex, type(ex)) _hash_devices(hash_obj, accelerators) - _hash_platform(hash_obj, backend) # LINT.IfChange(xla_flags) xla_flags_to_exclude_from_cache_key = [ diff --git a/jax/_src/callback.py b/jax/_src/callback.py index bdceb98d92b7..8c0bc8f3c6ec 100644 --- a/jax/_src/callback.py +++ b/jax/_src/callback.py @@ -20,10 +20,9 @@ import logging from typing import Any -import jax +from jax._src import api from jax._src import config from jax._src import core -from jax._src import deprecations from jax._src import dispatch from jax._src import dtypes from jax._src import effects @@ -37,20 +36,15 @@ from jax._src.interpreters import batching from jax._src.interpreters import mlir from jax._src.interpreters import xla -from jax._src.lax.control_flow.loops import map as lax_map from jax._src.lib import xla_client as xc from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import hlo -from jax._src.sharding_impls import SdyArraySharding, SdyArrayShardingList, SingleDeviceSharding -from jax._src.typing import DeprecatedArg +from jax._src.sharding_impls import SdyArray, SdyArrayList, SdyDim, SingleDeviceSharding +from jax._src.typing import Array, DeprecatedArg import numpy as np logger = logging.getLogger(__name__) -# TODO(dfm): Remove after 6 months. -# Added Oct 1, 2024 -deprecations.register("jax-callback-vectorized") - # `pure_callback_p` is the main primitive for staging out Python pure callbacks. pure_callback_p = core.Primitive("pure_callback") pure_callback_p.multiple_results = True @@ -72,7 +66,7 @@ class _FlatCallback: callback_func: Callable[..., Any] in_tree: tree_util.PyTreeDef # (args, kwargs) pytree for `callback_func`. - def __call__(self, *flat_args: jax.Array) -> Sequence[jax.Array]: + def __call__(self, *flat_args: Array) -> Sequence[Array]: args, kwargs = tree_util.tree_unflatten(self.in_tree, flat_args) return tree_util.tree_leaves(self.callback_func(*args, **kwargs)) @@ -82,20 +76,19 @@ def pure_callback_impl( result_avals, callback: _FlatCallback, sharding: SingleDeviceSharding | None, - vectorized: bool | DeprecatedArg, vmap_method: str | None, ): - del sharding, vectorized, vmap_method, result_avals + del sharding, vmap_method, result_avals try: - cpu_device, *_ = jax.local_devices(backend="cpu") + cpu_device, *_ = xb.local_devices(backend="cpu") except RuntimeError as e: raise RuntimeError( "jax.pure_callback failed to find a local CPU device to place the" " inputs on. Make sure \"cpu\" is listed in --jax_platforms or the" " JAX_PLATFORMS environment variable." 
) from e - args = jax.device_put(args, cpu_device) - with jax.default_device(cpu_device): + args = api.device_put(args, cpu_device) + with config.default_device(cpu_device): try: return tree_util.tree_map(np.asarray, callback(*args)) except BaseException: @@ -113,10 +106,9 @@ def pure_callback_abstract_eval( callback: _FlatCallback, result_avals, sharding: SingleDeviceSharding | None, - vectorized: bool | DeprecatedArg, vmap_method: str | None, ): - del avals, callback, sharding, vectorized, vmap_method + del avals, callback, sharding, vmap_method return result_avals @@ -161,13 +153,15 @@ def _callback_op_sharding( " computations" ) if config.use_shardy_partitioner.value: - assert len(avals_out) == 1 - op_sharding = sharding_impls.SdyArrayShardingList([ - sharding_impls.SdyArraySharding( + ndim = 0 + if avals_out and isinstance(avals_out[0], core.ShapedArray): + ndim = avals_out[0].ndim + op_sharding = SdyArrayList([ + SdyArray( mesh_shape=(), - dimension_shardings=[ - sharding_impls.SdyDimSharding(axes=[], is_closed=True) - ] * avals_out[0].ndim, + dim_shardings=[ + SdyDim(axes=[], is_open=False) + ] * ndim, logical_device_ids=())]) else: op_sharding = xc.OpSharding() # type: ignore[assignment] @@ -200,10 +194,14 @@ def _callback_op_sharding( # program has bulk array semantics, so we run the callback with a MAXIMAL # sharding and hence execute it only once on the full logical value). if config.use_shardy_partitioner.value: - op_sharding = sharding_impls.SdyArrayShardingList([ - sharding_impls.SdyArraySharding( + # For shardy, we need to have the same number of shardy annotations as the + # number of result ops. If there are no result ops, we need 1 shardy + # annotation. + num_sdy_shardings = max(1, len(avals_out)) + op_sharding = SdyArrayList(num_sdy_shardings * [ + SdyArray( mesh_shape=(), - dimension_shardings=[], + dim_shardings=[], logical_device_ids=(device_index,))]) else: op_sharding = xc.OpSharding() # type: ignore[assignment] @@ -287,7 +285,7 @@ def pure_callback( When `vmap`-ed the behavior will depend on the value of the ``vmap_method``. * Calling :func:`~jax.vmap` on a callback without an explicit ``vmap_method`` - is deprecated and it will eventually raise ``NotImplementedError``. + raises a ``NotImplementedError``. * ``vmap_method="sequential"`` uses :func:`~jax.lax.map` to loop over the batched arguments, calling ``callback`` once for each batch element. * ``vmap_method="sequential_unrolled"`` is like ``sequential``, but the loop @@ -297,9 +295,8 @@ def pure_callback( * ``vmap_method="broadcast_all"`` behaves like ``expand_dims``, but the inputs are tiled to the expected batched shape. - If necessary, the legacy behavior provided by the deprecated - ``vectorized=True`` argument can be recovered using - ``vmap_method="legacy_vectorized"``. + If necessary, the legacy behavior provided by the removed ``vectorized=True`` + argument can be recovered using ``vmap_method="legacy_vectorized"``. The current default behavior is to use ``vmap_method="sequential"`` when not specified, but this behavior is deprecated, and in the future, the @@ -366,20 +363,13 @@ def pure_callback( (4,) (4,) Array([1., 2., 3., 4.], dtype=float32) - .. _External Callbacks: https://jax.readthedocs.io/en/latest/external-callbacks.html + .. 
_External Callbacks: https://docs.jax.dev/en/latest/external-callbacks.html """ - if not isinstance(vectorized, DeprecatedArg) and not vectorized is None: - deprecations.warn( - "jax-callback-vectorized", - "The vectorized argument of jax.pure_callback is deprecated and setting " - "it will soon raise an error. To avoid an error in the future, and to " - "suppress this warning, please use the vmap_method argument instead.", - stacklevel=2) - if vmap_method is not None: - raise ValueError( - "the vectorized and vmap_method arguments of jax.pure_callback cannot " - "be used together. Please use the vmap_method argument.") - vmap_method = "legacy_vectorized" if vectorized else "sequential" + # TODO(danfm): Remove this check 3 months after v0.6.0 is released. + if not isinstance(vectorized, DeprecatedArg): + raise ValueError( + "The 'vectorized' argument of jax.pure_callback was removed in JAX " + "v0.6.0. Use 'vmap_method' instead.") allowed_vmap_methods = ["sequential", "sequential_unrolled", "expand_dims", "broadcast_all", "legacy_vectorized", None] if vmap_method not in allowed_vmap_methods: @@ -397,7 +387,6 @@ def pure_callback( callback=_FlatCallback(callback, in_tree), result_avals=tuple(flat_result_avals), sharding=sharding, - vectorized=vectorized, vmap_method=vmap_method, ) return tree_util.tree_unflatten(out_tree, out_flat) @@ -434,15 +423,15 @@ def io_callback_impl( ): del result_avals, sharding, ordered try: - cpu_device, *_ = jax.local_devices(backend="cpu") + cpu_device, *_ = xb.local_devices(backend="cpu") except RuntimeError as e: raise RuntimeError( "jax.io_callback failed to find a local CPU device to place the" " inputs on. Make sure \"cpu\" is listed in --jax_platforms or the" " JAX_PLATFORMS environment variable." ) from e - args = jax.device_put(args, cpu_device) - with jax.default_device(cpu_device): + args = api.device_put(args, cpu_device) + with config.default_device(cpu_device): try: return tree_util.tree_map(np.asarray, callback(*args)) except BaseException: @@ -482,6 +471,7 @@ def io_callback_transpose_rule(*args, **kwargs): def io_callback_batching_rule( args, dims, callback, result_avals, sharding, ordered ): + from jax._src.lax.control_flow.loops import map as lax_map # pytype: disable=import-error if ordered: raise ValueError("Cannot `vmap` ordered IO callback.") is_batched = [d is not batching.not_mapped for d in dims] @@ -575,7 +565,7 @@ def io_callback( - :func:`jax.debug.callback`: callback designed for general-purpose debugging. - :func:`jax.debug.print`: callback designed for printing. - .. _External Callbacks: https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html + .. _External Callbacks: https://docs.jax.dev/en/latest/notebooks/external_callbacks.html """ flat_args, in_tree = tree_util.tree_flatten((args, kwargs)) tree_util.tree_map(_check_shape_dtype, result_shape_dtypes) @@ -592,7 +582,6 @@ def io_callback( return tree_util.tree_unflatten(out_tree, out_flat) - def is_empty_shape(s: core.Shape) -> bool: return any(d == 0 for d in s) @@ -603,7 +592,7 @@ def send_to_host( operand: Any, name: str, *, - sharding: SdyArrayShardingList | xc.OpSharding | None = None, + sharding: SdyArrayList | xc.OpSharding | None = None, ) -> ir.Value: channel_handle = hlo.ChannelHandle.get(channel, mlir.SEND_TO_HOST_TYPE) send_op = hlo.SendOp([operand], token, channel_handle, @@ -619,11 +608,11 @@ def send_to_host( # we need to create an equivalent sharding with no dimensions. 
If there # are multiple shardings, just grab the first one since all these # shardings should be the same. - assert isinstance(sharding, SdyArrayShardingList) + assert isinstance(sharding, SdyArrayList) assert len(sharding.shardings) >= 1 - sharding = SdyArrayShardingList([ - SdyArraySharding( - mesh_shape=(), dimension_shardings=[], + sharding = SdyArrayList([ + SdyArray( + mesh_shape=(), dim_shardings=[], logical_device_ids=sharding.shardings[0].logical_device_ids)]) mlir.set_sharding(send_op, sharding) return send_op.result @@ -635,7 +624,7 @@ def receive_from_host( out_aval: core.ShapedArray, name: str, *, - sharding: SdyArrayShardingList | xc.OpSharding | None = None, + sharding: SdyArrayList | xc.OpSharding | None = None, ) -> tuple[ir.Value, ir.Value]: channel_handle = hlo.ChannelHandle.get(channel, mlir.RECV_FROM_HOST_TYPE) recv_op = hlo.RecvOp([mlir.aval_to_ir_type(out_aval), @@ -647,7 +636,7 @@ def receive_from_host( _xla_host_transfer_rendezvous=ir.StringAttr.get(str(name)))) if sharding is not None: if config.use_shardy_partitioner.value: - assert isinstance(sharding, SdyArrayShardingList) + assert isinstance(sharding, SdyArrayList) assert len(sharding.shardings) >= 1 # `RecvOp`'s last argument is a `TokenType`. Since Shardy requires the # number of shardings to match the number of results, but JAX only sees @@ -655,10 +644,10 @@ def receive_from_host( # Note that even if a function returns N results, we will end up with N # `RecvOp`s, so we only need to get the first sharding. All shardings are # the same anyways, operating on the same single device ID. - sharding = SdyArrayShardingList([ + sharding = SdyArrayList([ sharding.shardings[0], - SdyArraySharding( - mesh_shape=(), dimension_shardings=[], + SdyArray( + mesh_shape=(), dim_shardings=[], logical_device_ids=sharding.shardings[0].logical_device_ids)]) mlir.set_sharding(recv_op, sharding) # Token should be at the end of the results @@ -666,6 +655,25 @@ def receive_from_host( return token, result + +def _aval_to_xla_shape(aval: core.AbstractValue) -> xc.Shape: + try: + return _xla_shape_handlers[type(aval)](aval) + except KeyError as err: + raise TypeError(f"No xla_shape_handler for type: {type(aval)}") from err + +_xla_shape_handlers: dict[type[core.AbstractValue], + Callable[[Any], xc.Shape]] = {} + +def _make_array_shape(aval: core.ShapedArray) -> xc.Shape: + aval = core.physical_aval(aval) + dtype = np.dtype('bool') if aval.dtype == dtypes.float0 else aval.dtype + return xc.Shape.array_shape(dtype, aval.shape) +_xla_shape_handlers[core.ShapedArray] = _make_array_shape + +_xla_shape_handlers[core.AbstractToken] = lambda _: xc.Shape.token_shape() + + def _emit_tpu_python_callback( backend: xb.XlaBackend, ctx: mlir.LoweringRuleContext, @@ -677,7 +685,7 @@ def _emit_tpu_python_callback( result_avals: Sequence[core.ShapedArray], result_shapes: Sequence[xc.Shape], *, - sharding: SdyArrayShardingList | xc.OpSharding | None = None, + sharding: SdyArrayList | xc.OpSharding | None = None, ) -> tuple[Sequence[ir.Value], Any]: token = token or hlo.create_token() _wrapped_callback = callback @@ -695,8 +703,7 @@ def _wrapped_callback(*args): # pylint: disable=function-redefined send_channel = ctx.module_context.new_channel() dummy_send_aval = core.ShapedArray((1,), np.float32) dummy_send_val = mlir.ir_constant(np.zeros(1, np.float32)) - operand_shapes = [*operand_shapes, - xla.aval_to_xla_shapes(dummy_send_aval)[0]] + operand_shapes = [*operand_shapes, _aval_to_xla_shape(dummy_send_aval)] token = send_to_host(send_channel, token, 
dummy_send_val, callback.__name__, sharding=sharding) send_channels.append(send_channel) @@ -723,21 +730,6 @@ def _wrapped_callback(*args): # pylint: disable=function-redefined return outputs, token -def _layout_to_mlir_layout(minor_to_major: Sequence[int] | None): - if minor_to_major is None: - # Needed for token layouts - layout: np.ndarray = np.zeros((0,), dtype="int64") - else: - layout = np.array(minor_to_major, dtype="int64") - return ir.DenseIntElementsAttr.get(layout, type=ir.IndexType.get()) - - -def _aval_to_default_layouts(aval): - avals = [core.physical_aval(aval)] - # Row major order is default for `NumPy`. - return [list(range(aval.ndim - 1, -1, -1)) for aval in avals] - - def emit_python_callback( ctx: mlir.LoweringRuleContext, callback, @@ -747,30 +739,42 @@ def emit_python_callback( result_avals: Sequence[core.ShapedArray], *, has_side_effect: bool, - sharding: SdyArrayShardingList | xc.OpSharding | None = None, - operand_layouts: Sequence[Sequence[int] | None] | None = None, - result_layouts: Sequence[Sequence[int] | None] | None = None, + partitioned: bool = False, + sharding: SdyArrayList | xc.OpSharding | None = None, ) -> tuple[Sequence[mlir.IrValues], Any, Any]: - """Emits MLIR that calls back to a provided Python function.""" + """Emits MLIR that calls back to a provided Python function. + + Args: + ctx: The lowering context. + callback: The Python callback function. + token: The token to use for the callback. + operands: The operands to the callback. + operand_avals: The abstract values of the operands. + result_avals: The abstract values of the results. + has_side_effect: Whether the callback has side effects. + partitioned: If True, then `callback` is called on local shards only. If + False, then `callback` is called on all shards. + sharding: The sharding of the callback. + + Returns: + A tuple of MLIR result values, a new token (if any), and the host callback + object. + """ if len(ctx.module_context.platforms) > 1: raise NotImplementedError("multi-platform lowering for python_callback") platform = ctx.module_context.platforms[0] if platform not in {"cpu", "cuda", "rocm", "tpu"}: raise ValueError( f"`EmitPythonCallback` not supported on {platform} backend.") + if partitioned: + if platform not in {"cpu", "cuda", "rocm"}: + raise ValueError( + f"Partitioned callback not supported on {platform} backend.") + if result_avals: + raise ValueError("Partitioned callback not supported with return values.") backend = ctx.module_context.get_backend() - result_shapes = util.flatten( - [xla.aval_to_xla_shapes(result_aval) for result_aval in result_avals]) - operand_shapes = util.flatten( - [xla.aval_to_xla_shapes(op_aval) for op_aval in operand_avals]) - # Handling layouts - if operand_layouts is None: - operand_layouts = util.concatenate( - map(_aval_to_default_layouts, operand_avals)) - operand_mlir_layouts = map(_layout_to_mlir_layout, operand_layouts) - if result_layouts is None: - result_layouts = util.concatenate(map(_aval_to_default_layouts, result_avals)) - result_mlir_layouts = map(_layout_to_mlir_layout, result_layouts) + result_shapes = [_aval_to_xla_shape(aval) for aval in result_avals] + operand_shapes = [_aval_to_xla_shape(aval) for aval in operand_avals] # First we apply checks to ensure output shapes and dtypes match the expected # ones. 
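The comment above ("First we apply checks...") refers to the host-side validation that runs before callback results are handed back to XLA: each returned array's shape and dtype is compared against the expected abstract values. A minimal standalone sketch of that kind of check follows; the helper name ``check_results`` and its error message are invented for illustration and are not the implementation in ``callback.py``.

    import numpy as np

    def check_results(results, expected_shape_dtypes):
      # Illustrative only: compare each callback output against the expected
      # (shape, dtype) pair and fail loudly on any mismatch.
      checked = []
      for out, (shape, dtype) in zip(results, expected_shape_dtypes):
        out = np.asarray(out)
        if out.shape != tuple(shape) or out.dtype != np.dtype(dtype):
          raise RuntimeError(
              f"Callback returned an array with shape {out.shape} and dtype "
              f"{out.dtype}, but shape {tuple(shape)} and dtype "
              f"{np.dtype(dtype)} were expected.")
        checked.append(out)
      return checked

For example, ``check_results([np.zeros((2,), np.float32)], [((2,), np.float32)])`` passes, while returning a float64 array in its place would raise.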
@@ -822,55 +826,51 @@ def _wrapped_callback(*args): for result_aval in result_avals] return outputs, token, None - result_types = mlir.flatten_ir_types([mlir.aval_to_ir_type(aval) for aval in result_avals]) + device = "gpu" if platform in {"cuda", "rocm"} else "cpu" + partition = "_partitioned" if partitioned else "" + call_target_name = f"xla_ffi{partition}_python_{device}_callback" if token: - callback_without_token = _wrapped_callback def _wrapped_callback(token, *args): # type: ignore # pylint: disable=function-redefined return (token, *callback_without_token(*args)) - - operand_shapes = [ - xla.aval_to_xla_shapes(core.abstract_token)[0], *operand_shapes - ] - result_shapes = [ - xla.aval_to_xla_shapes(core.abstract_token)[0], *result_shapes - ] operands = [token, *operands] - result_types = [mlir.token_type(), *result_types] - operand_mlir_layouts = [_layout_to_mlir_layout(None), *operand_mlir_layouts] - result_mlir_layouts = [_layout_to_mlir_layout(None), *result_mlir_layouts] - callback_descriptor, ifrt_callback = ( - backend.get_emit_python_callback_descriptor(_wrapped_callback, - operand_shapes, - result_shapes)) + if ( + config.use_shardy_partitioner.value + and sharding is not None + and len(ctx.avals_out) > 0 + and isinstance(sharding, SdyArrayList) + ): + # Add a sharding annotation for the token if we have at least one + # output. Otherwise, the single shardy annotation required of all ops + # (even those without any results) can annotate the token. + sharding = SdyArrayList([ + SdyArray( + mesh_shape=(), + dim_shardings=[], + logical_device_ids=()), + *sharding.shardings]) + ctx = dataclasses.replace( + ctx, + avals_in=[core.abstract_token, *ctx.avals_in], + avals_out=[core.abstract_token, *ctx.avals_out], + ) + + # TODO(dsuo): Remove this line once we deprecate the XLA custom call + # handler. 
+ ifrt_callback = _wrapped_callback ctx.module_context.add_host_callback(ifrt_callback) - descriptor_operand = mlir.ir_constant(callback_descriptor) - callback_operands = [descriptor_operand, *operands] - if operand_mlir_layouts is not None: - operand_mlir_layouts = [_layout_to_mlir_layout([]), *operand_mlir_layouts] - result_type = ir.TupleType.get_tuple(result_types) - call_target_name = ("xla_python_gpu_callback" - if platform in {"cuda", "rocm"} else "xla_python_cpu_callback") - result = hlo.CustomCallOp( - [result_type], - callback_operands, - call_target_name=ir.StringAttr.get(call_target_name), - has_side_effect=ir.BoolAttr.get(has_side_effect), - api_version=mlir.i32_attr(2), - called_computations=ir.ArrayAttr.get([]), - backend_config=ir.StringAttr.get(str(callback_descriptor)), - operand_layouts=( - None if operand_mlir_layouts is None - else ir.ArrayAttr.get(operand_mlir_layouts)), - result_layouts=( - None if result_mlir_layouts is None - else ir.ArrayAttr.get(result_mlir_layouts))) + index = np.uint64(len(ctx.module_context.host_callbacks) - 1) + result = ffi.build_ffi_lowering_function( # type: ignore + call_target_name, + has_side_effect=has_side_effect, + )(ctx, *operands, index=np.uint64(index)) + if sharding is not None: mlir.set_sharding(result, sharding) - results = [ - hlo.get_tuple_element(result, mlir.i32_attr(i)) - for i in range(len(result_types)) - ] + + results = result.results # type: ignore + if token: - token, *results = results - return results, token, ifrt_callback + token, *results = results # type: ignore + + return results, token, ifrt_callback # type: ignore diff --git a/jax/_src/checkify.py b/jax/_src/checkify.py index 1ec8ad50b456..0061c9c63f7b 100644 --- a/jax/_src/checkify.py +++ b/jax/_src/checkify.py @@ -25,7 +25,10 @@ from jax import dtypes from jax import lax -from jax.experimental import shard_map +# TODO(yashkatariya): Remove the experimental import after users are migrated +# to `jax.shard_map`. 
+from jax.experimental import shard_map # noqa: F401 +from jax._src import shard_map as jshmap from jax._src import api from jax._src import api_util from jax._src import ad_checkpoint @@ -50,6 +53,7 @@ from jax._src.tree_util import tree_map from jax._src.tree_util import tree_unflatten from jax._src.typing import Array +from jax._src.partition_spec import PartitionSpec as P from jax._src.util import (as_hashable_function, split_list, safe_map, safe_zip, unzip3, weakref_lru_cache, HashableWrapper, foreach) @@ -261,7 +265,7 @@ def _get_batched_exception(self) -> BatchedError | None: cur_effect = None for error_effect, code in self._code.items(): if self._pred[error_effect][idx]: # type: ignore - if min_code is None or code[idx] < min_code: + if min_code is None or code[idx] < min_code: # type: ignore[index] min_code = code[idx] # type: ignore cur_effect = error_effect @@ -600,7 +604,7 @@ def isnan(x): lax.igamma_p, lax.igammac_p, lax.integer_pow_p, lax.lgamma_p, lax.linear_solve_p, lax.log1p_p, lax.log_p, lax.logistic_p, lax.mul_p, lax.pad_p, lax.pow_p, lax.psum_p, - lax.random_gamma_grad_p, lax.reduce_p, lax.reduce_prod_p, + lax.reduce_p, lax.reduce_prod_p, lax.reduce_sum_p, lax.reduce_window_p, lax.reduce_window_sum_p, lax.regularized_incomplete_beta_p, lax.rem_p, lax.rng_uniform_p, lax.rsqrt_p, lax.sin_p, @@ -756,7 +760,8 @@ def jaxpr_to_checkify_jaxpr( out_tree, error_effects = metadata() return checked_jaxpr, out_tree, error_effects -def cond_error_check(error: Error, enabled_errors, index, *ops, branches): +def cond_error_check(error: Error, enabled_errors, index, *ops, + branches, **params): # Get the error-effects out of all branches so the cond can be called with # a merged error with all these effects. err_vals, err_tree = jtu.tree_flatten(error) @@ -777,7 +782,7 @@ def get_error_effects_from_jaxpr(jxpr): err_and_outs = lax.cond_p.bind( index, *err_vals, *ops, - branches=tuple(new_branches)) + branches=tuple(new_branches), **params) # we need to merge metadata across out_trees (a tuple) err0, out = tree_unflatten(out_trees[0], err_and_outs) @@ -913,14 +918,14 @@ def pjit_error_check(error, enabled_errors, *vals_in, jaxpr, # Update pjit params to account for extra error values. num_error_vals = len(err_vals) num_out_error_vals = out_tree.num_leaves - len(out_shardings) - sharding = sharding_impls.UNSPECIFIED new_in_shardings = (*[sharding] * num_error_vals, *in_shardings) - new_out_shardings = (*[sharding] * num_out_error_vals, *out_shardings) new_in_layouts = (*[None] * num_error_vals, *in_layouts) - new_out_layouts = (*[None] * num_out_error_vals, *out_layouts) new_donated_invars = (*[False] * num_error_vals, *donated_invars) + new_out_shardings = (*[sharding] * num_out_error_vals, *out_shardings) + new_out_layouts = (*[None] * num_out_error_vals, *out_layouts) + err_and_out = pjit.pjit_p.bind( *new_vals_in, jaxpr=checked_jaxpr, @@ -954,7 +959,7 @@ def remat_error_check(error, enabled_errors, *vals_in, jaxpr, **params): def shard_map_error_check( error: Error, enabled_errors, *vals_in, - jaxpr: core.Jaxpr, in_names, out_names, **kwargs + jaxpr: core.Jaxpr, in_specs, out_specs, **kwargs ): if (mesh := kwargs.get('mesh')) is None: raise ValueError('Mesh must be provided for shard_map with checkify.') @@ -962,22 +967,24 @@ def shard_map_error_check( err_vals, err_tree = jtu.tree_flatten(error) num_error_vals = len(err_vals) # Replicated sharding for in errors. 
- new_in_names = (*([{}] * num_error_vals), *in_names) + new_in_specs = (*([P()] * num_error_vals), *in_specs) new_vals_in = [*err_vals, *vals_in] in_avals = list(map(core.get_aval, new_vals_in)) - auto = kwargs.get('auto') + manual_axes = kwargs.get('manual_axes') + check_vma = kwargs.get('check_vma') for i, v in enumerate(in_avals): if not (sharder := core.shard_aval_handlers.get(type(v))): raise ValueError(f'Unsupported aval type: {type(v)}') - in_avals[i] = sharder(mesh, auto, new_in_names[i], v) + in_avals[i] = sharder(mesh, manual_axes, check_vma, new_in_specs[i], v) - with (shard_map._extend_axis_env(mesh, auto), - mesh_lib.use_abstract_mesh(shard_map._as_manual_mesh(mesh, auto))): + with (jshmap._extend_axis_env(mesh, manual_axes), + mesh_lib.use_abstract_mesh(jshmap._as_manual_mesh(mesh, manual_axes)), # type: ignore[arg-type] + config._check_vma(check_vma)): # jaxpr to checked_jaxpr checked_jaxpr, out_tree, _ = jaxpr_to_checkify_jaxpr( pe.close_jaxpr(jaxpr), enabled_errors, err_tree, *in_avals ) - num_out_error_vals = out_tree.num_leaves - len(out_names) + num_out_error_vals = out_tree.num_leaves - len(out_specs) def expand_errors_leading_dim(*xs): outs = core.eval_jaxpr(checked_jaxpr.jaxpr, checked_jaxpr.consts, *xs) @@ -985,7 +992,7 @@ def expand_errors_leading_dim(*xs): errs = [lax.expand_dims(e, [0]) for e in errs] return *errs, *outs - with core.extend_axis_env_nd(mesh.shape.items()): + with core.extend_axis_env_nd(mesh.shape.items()), config._check_vma(check_vma): jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic( lu.wrap_init(expand_errors_leading_dim, debug_info=checked_jaxpr.jaxpr.debug_info), @@ -995,22 +1002,22 @@ def expand_errors_leading_dim(*xs): # Update shard_map params to account for extra error values. # Use fully sharded partitioning for out errors. 
- new_out_names = (*([{0: mesh.axis_names}] * num_out_error_vals), *out_names) + new_out_specs = (*([P(mesh.axis_names)] * num_out_error_vals), *out_specs) subfun = lu.hashable_partial( lu.wrap_init(core.eval_jaxpr, debug_info=checked_jaxpr.jaxpr.debug_info), checked_jaxpr.jaxpr, checked_jaxpr.consts ) new_params = dict( jaxpr=checked_jaxpr.jaxpr, - in_names=new_in_names, - out_names=new_out_names, + in_specs=new_in_specs, + out_specs=new_out_specs, **kwargs, ) - _, new_params = shard_map.shard_map_p.get_bind_params(new_params) + _, new_params = jshmap.shard_map_p.get_bind_params(new_params) - err_and_out = shard_map.shard_map_p.bind(subfun, *new_vals_in, **new_params) + err_and_out = jshmap.shard_map_p.bind(subfun, *new_vals_in, **new_params) return tree_unflatten(out_tree, err_and_out) -error_checks[shard_map.shard_map_p] = shard_map_error_check +error_checks[jshmap.shard_map_p] = shard_map_error_check def custom_jvp_call_rule(in_err: Error, enabled_errors: set, *in_vals, num_consts, @@ -1073,17 +1080,17 @@ def jvp(*xs): return [*primal_errs, *out_primals, *tangent_errs, *out_tangents] return lu.wrap_init(jvp, debug_info=jvp_jaxpr_fun.debug_info) -def custom_vjp_call_jaxpr_rule(in_err, enabled_errors, *in_vals, - fun_jaxpr: core.ClosedJaxpr, - fwd_jaxpr_thunk, num_consts, - bwd: lu.WrappedFun, out_trees, - symbolic_zeros: bool): +def custom_vjp_call_rule(in_err, enabled_errors, *in_vals, + call_jaxpr: core.ClosedJaxpr, + fwd_jaxpr_thunk, num_consts, + bwd: lu.WrappedFun, out_trees, + symbolic_zeros: bool): err_vals, err_tree = jtu.tree_flatten(in_err) num_errs = err_tree.num_leaves checkified_fun = lu.wrap_init( - functools.partial(checkify_jaxpr_flat, fun_jaxpr.jaxpr, - fun_jaxpr.consts, enabled_errors, err_tree), - debug_info=fun_jaxpr.jaxpr.debug_info) + functools.partial(checkify_jaxpr_flat, call_jaxpr.jaxpr, + call_jaxpr.consts, enabled_errors, err_tree), + debug_info=call_jaxpr.jaxpr.debug_info) checkified_fun, fun_metadata = _flatten_and_get_error_metadata_thunk( checkified_fun) @@ -1091,13 +1098,13 @@ def checkified_fwd(*args): # TODO(lenamartens, sharadmv): why not checkify here? xs, zeros = args[::2], args[1::2] xs, zeros = xs[num_errs:], zeros[num_errs:] - fwd_jaxpr, fwd_consts = fwd_jaxpr_thunk(*zeros) + fwd_jaxpr, fwd_consts = fwd_jaxpr_thunk.call_wrapped(*zeros) xs_without_consts = xs[num_consts:] return core.eval_jaxpr(fwd_jaxpr, fwd_consts, *xs_without_consts) # TODO(necula): the fwd result_paths are not quite the same as fun_jaxpr checkified_fwd_wrapped = lu.wrap_init(checkified_fwd, - debug_info=fun_jaxpr.jaxpr.debug_info) + debug_info=fwd_jaxpr_thunk.debug_info) bwd_ = lu.wrap_init(lambda *args: (*(None,)*num_errs, *bwd.call_wrapped(*args)), debug_info=bwd.debug_info) checkified_fwd_wrapped, fwd_out_tree = flatten_fun_output(checkified_fwd_wrapped) @@ -1112,7 +1119,7 @@ def checkified_fwd(*args): else: out_err, out_vals = in_err, all_outs return out_err, out_vals -error_checks[custom_derivatives.custom_vjp_call_jaxpr_p] = custom_vjp_call_jaxpr_rule +error_checks[custom_derivatives.custom_vjp_call_p] = custom_vjp_call_rule def check_discharge_rule(error, enabled_errors, *args, err_tree, debug): diff --git a/jax/_src/cloud_tpu_init.py b/jax/_src/cloud_tpu_init.py index 0539e4253063..0d4f37203fbc 100644 --- a/jax/_src/cloud_tpu_init.py +++ b/jax/_src/cloud_tpu_init.py @@ -74,6 +74,8 @@ def cloud_tpu_init() -> None: # Exit early if we're not running on a Cloud TPU VM or libtpu isn't installed. 
libtpu_path = get_tpu_library_path() num_tpu_chips, tpu_id = hardware_utils.num_available_tpu_chips_and_device_id() + if num_tpu_chips == 0: + os.environ['TPU_SKIP_MDS_QUERY'] = '1' if ( tpu_id is not None and tpu_id >= hardware_utils.TpuVersion.v5e diff --git a/jax/_src/clusters/cloud_tpu_cluster.py b/jax/_src/clusters/cloud_tpu_cluster.py index c8aa765c181c..f45af0e76dd1 100644 --- a/jax/_src/clusters/cloud_tpu_cluster.py +++ b/jax/_src/clusters/cloud_tpu_cluster.py @@ -54,24 +54,26 @@ def get_metadata(key): raise RuntimeError(f"Getting metadata['{key}'] failed for 6 tries") return api_resp.text, api_resp.status_code -def get_tpu_env_value(key): - def get_tpu_env_value_from_metadata(key): - tpu_env_data = get_metadata('tpu-env')[0] - key_value_pairs = tpu_env_data.split('\n') - for key_value_pair in key_value_pairs: - # Typical line is MEGASCALE_NUM_SLICES: '2' - if ':' in key_value_pair: - row_key, value = re.split(':', key_value_pair, 1) - row_key = row_key.strip() - if row_key == key: - return value.strip().strip("'") - return None - +def get_tpu_env_value_from_metadata(key) -> str | None: + metadata_value = None + tpu_env_data = get_metadata('tpu-env')[0] + key_value_pairs = tpu_env_data.split('\n') + for key_value_pair in key_value_pairs: + # Typical line is MEGASCALE_NUM_SLICES: '2' + if ':' in key_value_pair: + row_key, value = re.split(':', key_value_pair, 1) + row_key = row_key.strip() + if row_key == key: + metadata_value = value.strip().strip("'") + return metadata_value + +def get_tpu_env_value(key) -> str | None: + # First try to get the value from the environment. value = os.environ.get(key, None) - return value if value is not None else get_tpu_env_value_from_metadata(key) - -def has_megascale_address(): - return get_tpu_env_value('MEGASCALE_COORDINATOR_ADDRESS') is not None + if value is None: + # If not found, try to get it from the metadata. + value = get_tpu_env_value_from_metadata(key) + return value class BaseTpuCluster(clusters.ClusterEnv): @@ -94,12 +96,11 @@ def is_env_present(cls) -> bool: @classmethod def get_coordinator_address(cls, timeout_secs: int | None) -> str: - if has_megascale_address(): - # For both GCE via QueuedResources and GKE via JobSet, the - # Megascale coordinator address is set as the host with process id = 0, - # so can be used as the jax distributed system coordinator. - coordinator_address = get_tpu_env_value('MEGASCALE_COORDINATOR_ADDRESS') - else: + # For both GCE via QueuedResources and GKE via JobSet, the + # Megascale coordinator address is set as the host with process id = 0, + # so can be used as the jax distributed system coordinator. + coordinator_address = get_tpu_env_value('MEGASCALE_COORDINATOR_ADDRESS') + if not coordinator_address: # For both GCE (QueuedResources and TPUVM create) and GKE via Job API, # the workers lists are sorted by process ID so the first one can # be used as the jax distributed system coordinator. 
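The refactored ``get_tpu_env_value`` consults the process environment first and falls back to the instance's ``tpu-env`` metadata only when the variable is unset; each metadata line has the form ``MEGASCALE_NUM_SLICES: '2'``. Below is a self-contained sketch of that parsing and fallback order, using hypothetical helper names (``parse_tpu_env_blob``, ``lookup``) that do not exist in the module:

    import os

    def parse_tpu_env_blob(blob: str, key: str) -> str | None:
      # Each line has the form "SOME_KEY: 'value'"; return the unquoted value.
      for line in blob.split("\n"):
        if ":" not in line:
          continue
        row_key, _, value = line.partition(":")
        if row_key.strip() == key:
          return value.strip().strip("'")
      return None

    def lookup(key: str, metadata_blob: str) -> str | None:
      value = os.environ.get(key)
      if value is None:
        # Fall back to the TPU metadata only when the env var is absent.
        value = parse_tpu_env_blob(metadata_blob, key)
      return value

    assert parse_tpu_env_blob("MEGASCALE_NUM_SLICES: '2'", "MEGASCALE_NUM_SLICES") == "2"

This mirrors the behaviour relied on by ``get_coordinator_address`` and ``_get_num_slices`` in the hunks that follow, which treat a missing or empty value as "not running under Megascale".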
@@ -149,17 +150,18 @@ def get_process_id(cls) -> int: @staticmethod def _get_num_slices() -> int: - if has_megascale_address(): - return int(get_tpu_env_value('MEGASCALE_NUM_SLICES')) - else: + num_slices = get_tpu_env_value('MEGASCALE_NUM_SLICES') + if not num_slices: return 1 + return int(num_slices) # type: ignore + @staticmethod def _get_slice_id() -> int: - if has_megascale_address(): - return int(get_tpu_env_value('MEGASCALE_SLICE_ID')) - else: + slice_id = get_tpu_env_value('MEGASCALE_SLICE_ID') + if not slice_id: return 0 + return int(slice_id) # type: ignore @staticmethod def _get_process_id_in_slice() -> int: diff --git a/jax/_src/clusters/cluster.py b/jax/_src/clusters/cluster.py index 69ef77a6421d..1c0a6fca9df6 100644 --- a/jax/_src/clusters/cluster.py +++ b/jax/_src/clusters/cluster.py @@ -23,7 +23,7 @@ class ClusterEnv: """Interface for defining a cluster environment. - To enable auto bootrapping (aka :func:`jax.distributed.initialize()`), + To enable auto bootstrapping (aka :func:`jax.distributed.initialize()`), cluster environments need to derive from :class:`ClusterEnv` and implement :func:`is_env_present`, :func:`get_coordinator_address`, :func:`get_process_count`, and :func:`get_process_id`. diff --git a/jax/_src/clusters/k8s_cluster.py b/jax/_src/clusters/k8s_cluster.py index 1274724b8ebd..fb312038bf2c 100644 --- a/jax/_src/clusters/k8s_cluster.py +++ b/jax/_src/clusters/k8s_cluster.py @@ -16,13 +16,51 @@ from contextlib import contextmanager from functools import cache +from itertools import chain +import logging +import numpy as np import os import socket +import time import textwrap import warnings from jax._src import clusters +logger = logging.getLogger(__name__) + + +def retry( + func=None, + initial_delay=0, + wait=np.logspace(-1, 1, 5) * np.random.rand(5), + exceptions=Exception, +): + def retry_decorator(func): + def retry_driver(*args, **kwargs): + # Retry the function call with exponential backoff + for i, t in enumerate(chain([initial_delay], wait)): + logger.debug( + f"Trying {func.__name__} in {t:.2f} seconds, attempt {i}/{len(wait)}" + ) + time.sleep(t) + try: + return func(*args, **kwargs) + except exceptions as e: + if i == len(wait): + raise RuntimeError('Retry failed with all attempts exhausted') from e + finally: + logger.debug( + f"Finished {func.__name__} after {i+1} attempts" + ) + return retry_driver + + if func is None: + return retry_decorator + else: + return retry_decorator(func) + + class K8sCluster(clusters.ClusterEnv): # Use an arbitrarily chosen port for the coordinator since we cannot @@ -34,16 +72,18 @@ def is_env_present(cls) -> bool: if 'KUBERNETES_SERVICE_HOST' in os.environ: try: import kubernetes as k8s # pytype: disable=import-error - except ImportError as e: - warnings.warn(textwrap.fill( - "Kubernetes environment detected, but the `kubernetes` package is " - "not installed to enable automatic bootstrapping in this " - "environment. To enable automatic boostrapping, please install " - "jax with the [k8s] extra. For example:" - " pip install jax[k8s]" - " OR" - " pip install jax[k8s,]" - )) + except (ImportError, ModuleNotFoundError): + warnings.warn( + '\n'.join([ + textwrap.fill( + "Kubernetes environment detected, but the `kubernetes` package " + "is not installed to enable automatic bootstrapping in this " + "environment. To enable automatic bootstrapping, please install " + "jax with the [k8s] extra. 
For example:"), + " pip install jax[k8s]", + " pip install jax[k8s,]", + ]) + ) return False k8s.config.load_incluster_config() @@ -67,7 +107,9 @@ def _handle_api_exception(cls): "this job does not have the permission for pod introspection. Please " "either grant the default SA permission to read pod info, or create a " "dedicated service account with the permission and associated with " - "the job. For more details, see .", + "the job. For an example on setting up the service account, see the " + "example/k8s directory in the JAX repo. For more details, please refer to " + "https://docs.jax.dev/en/latest/multi_process.html#kubernetes-example", width=80 )) raise RuntimeError('\n'.join(err_msg)) from e @@ -81,16 +123,16 @@ def _namespace(cls): @classmethod @cache + # in case of latency for core DNS to update pod IP to etcd/API server + @retry(exceptions=ValueError) def _pod(cls): + ip = socket.gethostbyname(os.getenv('HOSTNAME')) with cls._handle_api_exception(): - ip = socket.gethostbyname(os.getenv('HOSTNAME')) - pods = cls._core_api.list_namespaced_pod( + [pod] = cls._core_api.list_namespaced_pod( namespace=cls._namespace(), field_selector=f'status.podIP={ip}' ).items - assert len(pods) == 1, \ - f"Exactly 1 Kubernetes pod should have IP {ip}, got {len(pods)}." - return pods[0] + return pod @classmethod @cache @@ -101,13 +143,127 @@ def _job(cls): ) @classmethod - def get_coordinator_address(cls, timeout_secs: int | None) -> str: - return '{job_name}-0.{jobset_name}:{port}'.format( - job_name=cls._pod().metadata.labels['job-name'], - jobset_name=cls._job().metadata.labels['jobset.sigs.k8s.io/jobset-name'], - port=cls._coordinator_port + @cache + def _headless_svc(cls): + with cls._handle_api_exception(): + services = cls._core_api.list_namespaced_service(cls._namespace()).items + + pod_labels = cls._pod().metadata.labels or {} + for svc in services: + if svc.spec.cluster_ip == "None": # if headless service + svc_selector = svc.spec.selector or {} + if all(pod_labels.get(k) == v for k, v in svc_selector.items()): + return svc + + # returns None if no headless service targets the current pod + return None + + @classmethod + @cache + def _controller(cls): + # https://github.com/kubernetes/apimachinery/blob/7b4292b/pkg/apis/meta/v1/types.go#L235 + # states that there cannot be more than one managing controller. + for owner in cls._pod().metadata.owner_references: + if owner.controller is True: + return owner + + raise RuntimeError( + 'Cannot automatically initialize distributed workload: ' + f'pod {cls._pod().metadata.name} does not have a controller.' ) + @classmethod + def get_coordinator_address(cls, timeout_secs: int | None) -> str: + controller = cls._controller() + job = cls._job() + pod = cls._pod() + if controller.kind == 'Job': + # if job belongs to a jobset + if 'jobset.sigs.k8s.io/jobset-name' in job.metadata.labels: + coordinator_hostname = '{job_name}-0.{subdomain}'.format( + job_name=job.metadata.name, + subdomain=job.metadata.labels['jobset.sigs.k8s.io/jobset-name'] + ) + # if job is standalone + else: + # check if the job is associated with a headless service, which is + # necessary for pods to communicate with each other + if pod.spec.subdomain is None: + # check if a headless service exists but not specified as subdomain + svc = cls._headless_svc() + err_msg = ( + "Pods within a job need a headless service in order to " + "communicate with each other. 
" + ) + if svc: + err_msg += ( + f"A headless service '{svc.metadata.name}' is found that " + "targets this job, but it is not specified as the job subdomain. " + "Please add the following to the job specification: " + ) + fix_msg = [ + "```", + "kind: Job", + "spec:", + " ...", + " template:", + " spec:", + f" subdomain: {svc.metadata.name}", + "```", + ] + else: + err_msg += "To fix, add the following to the job specification:" + fix_msg = [ + "```", + "apiVersion: v1", + "kind: Service", + "metadata:", + " name: jaxpods", + "spec:", + " publishNotReadyAddresses: true", + " clusterIP: None", + " selector:", + f" job-name: {job.metadata.name}", + "---", + "kind: Job", + "spec:", + " ...", + " template:", + " spec:", + " subdomain: jaxpods", + "```", + ] + + raise RuntimeError('\n'.join([textwrap.fill(err_msg)] + fix_msg)) + + coordinator_hostname = '{job_name}-0.{subdomain}'.format( + job_name=job.metadata.name, + subdomain=pod.spec.subdomain + ) + + if timeout_secs: + # Ensure host pod is up before trying to communicate + # Retry in case of cached NXDOMAIN DNS failure (30 secs default) + @retry( + initial_delay=0.5, + wait=np.logspace(-1, 1.5, 8) * np.random.rand(8), + exceptions=socket.gaierror + ) + def wait_for_host(hostname): + socket.gethostbyname(hostname) + + wait_for_host(coordinator_hostname) + + return '{hostname}:{port}'.format( + hostname=coordinator_hostname, + port=cls._coordinator_port + ) + + else: + raise RuntimeError( + 'In K8s, cluster automatic bootstrap only supports Job/JobSet.' + ) + @classmethod def get_process_count(cls) -> int: # https://kubernetes.io/docs/concepts/workloads/controllers/job/#controlling-parallelism @@ -120,5 +276,6 @@ def get_process_id(cls) -> int: return int(os.environ['JOB_COMPLETION_INDEX']) except KeyError: raise RuntimeError( - 'K8s job must be run with `completionMode: "Indexed"`.' + 'To enable automatic bootstrap in a K8s cluster, ' + 'jobs must be indexed by setting `completionMode: "Indexed"`.' ) diff --git a/jax/_src/compilation_cache.py b/jax/_src/compilation_cache.py index f1b56adf3359..db2730bb22bc 100644 --- a/jax/_src/compilation_cache.py +++ b/jax/_src/compilation_cache.py @@ -23,7 +23,7 @@ # If zstandard is installed, we use zstd compression, otherwise we use zlib. try: - import zstandard + import zstandard # pytype: disable=import-error except ImportError: zstandard = None @@ -207,7 +207,7 @@ def is_executable_in_cache(backend, cache_key: str) -> bool: def get_executable_and_time( - cache_key: str, compile_options, backend + cache_key: str, compile_options, backend, executable_devices ) -> tuple[xla_client.LoadedExecutable | None, int | None]: """Returns the cached executable and its compilation time if present, or None otherwise. @@ -224,7 +224,7 @@ def get_executable_and_time( serialized_executable, compile_time = extract_executable_and_time( executable_and_time) xla_executable_deserialized = backend.deserialize_executable( - serialized_executable, compile_options) + serialized_executable, executable_devices, compile_options) return xla_executable_deserialized, compile_time @@ -275,7 +275,7 @@ def put_executable_and_time( f"PERSISTENT CACHE WRITE with key {cache_key}, this is unexpected because " "JAX_COMPILATION_CACHE_EXPECT_PGLE is set. 
The execution that populated the " "cache may lack coverage, " - "https://jax.readthedocs.io/en/latest/persistent_compilation_cache.html may " + "https://docs.jax.dev/en/latest/persistent_compilation_cache.html may " "help debug why this has happened") cache.put(cache_key, executable_and_time) @@ -341,7 +341,7 @@ def combine_executable_and_time( def extract_executable_and_time( - exectuable_and_time: bytes + executable_and_time: bytes ) -> tuple[bytes, int]: """Given the cache entry in the format shown below, extract the serialized executable and the compilation time. @@ -351,5 +351,5 @@ def extract_executable_and_time( Content: compilation time serialized executable (big-endian int) """ - return exectuable_and_time[4:], int.from_bytes( - exectuable_and_time[:4], byteorder='big') + return executable_and_time[4:], int.from_bytes( + executable_and_time[:4], byteorder='big') diff --git a/jax/_src/compiler.py b/jax/_src/compiler.py index dea532d13031..2aef697a353d 100644 --- a/jax/_src/compiler.py +++ b/jax/_src/compiler.py @@ -17,9 +17,12 @@ from __future__ import annotations from collections.abc import Sequence +import copy +from functools import partial import logging import time -from typing import Any, Callable +from typing import Any +from collections.abc import Callable import warnings from jax._src import cache_key as cache_key_type @@ -33,6 +36,7 @@ from jax._src import traceback_util from jax._src.interpreters import mlir from jax._src.lib import xla_client as xc +from jax._src.lib import _jax from jax._src.lib.mlir import ir import numpy as np @@ -113,7 +117,6 @@ def get_compile_options( num_partitions: int, device_assignment=None, use_spmd_partitioning: bool = True, - use_shardy_partitioner: bool = False, use_auto_spmd_partitioning: bool = False, auto_spmd_partitioning_mesh_shape: list[int] | None = None, auto_spmd_partitioning_mesh_ids: list[int] | None = None, @@ -133,10 +136,6 @@ def get_compile_options( `num_partitions`. use_spmd_partitioning: boolean indicating whether to enable SPMD or MPMD partitioning in XLA. - use_shardy_partitioner: boolean indicating whether to use the Shardy - partitioner in XLA. Shardy is a new open sourced propagation framework for - MLIR. Currently Shardy is experimental in JAX. See - www.github.com/openxla/shardy. use_auto_spmd_partitioning: boolean indicating whether to automatically generate XLA shardings for SPMD partitioner. auto_spmd_partitioning_mesh_shape: device mesh shape used to create @@ -156,7 +155,7 @@ def get_compile_options( build_options = compile_options.executable_build_options build_options.use_spmd_partitioning = use_spmd_partitioning build_options.use_auto_spmd_partitioning = use_auto_spmd_partitioning - build_options.use_shardy_partitioner = use_shardy_partitioner + build_options.use_shardy_partitioner = config.use_shardy_partitioner.value if fdo_profile is not None: build_options.fdo_profile = fdo_profile if use_auto_spmd_partitioning: @@ -197,15 +196,6 @@ def get_compile_options( config.memory_fitting_level.value ).value - # This is a temporary workaround to simplify the AutoPGLE usage. - # TODO(b/376647494): Remove once the bug is fixed. 
- if ((config.enable_pgle.value and config.pgle_profiling_runs.value > 0) - or config.compilation_cache_expect_pgle.value): - logger.debug("Explicitly disabling command buffer scheduling for AutoPGLE.") - if env_options_overrides is None: - env_options_overrides = {} - env_options_overrides['xla_gpu_enable_command_buffer'] = '' - if env_options_overrides is not None: # Some overrides are passed directly on build_options. overrides_on_build_options = [ @@ -248,7 +238,7 @@ def get_compile_options( else: compile_options.profile_version = _NO_PROFILE_DONT_RETRIEVE if backend is None: - logging.info("get_compile_options: no backend supplied; " + logger.info("get_compile_options: no backend supplied; " "disabling XLA-AutoFDO profile") else: fdo_profile_version = get_latest_profile_version(backend) @@ -295,9 +285,49 @@ def get_compile_options( def backend_compile( backend: xc.Client, module: ir.Module, + executable_devices: xc.DeviceList, + options: xc.CompileOptions, +) -> xc.Executable: + sym_name = module.operation.attributes['sym_name'] + module_name = ir.StringAttr(sym_name).value + # Convert ir.Module to a string representation, unless the backend + # explicitly flags the ability to handle a module directly (avoiding the + # overhead of back and forth conversions). + # TODO(slebedev): Change the backend.compile() to accept ir.Module. + built_c: Any + if getattr(backend, "needs_str_ir", True): + built_c = mlir.module_to_bytecode(module) + else: + built_c = module + + if (options.executable_build_options.fdo_profile is not None + and len(options.executable_build_options.fdo_profile)): + logger.debug( + "Compiling module %s with FDO profile of length %d", + module_name, + len(options.executable_build_options.fdo_profile), + ) + + try: + return backend.compile(built_c, executable_devices, options) + except xc.XlaRuntimeError as e: + for error_handler in _XLA_RUNTIME_ERROR_HANDLERS: + handler_result = error_handler(e) + if handler_result is not None: + raise handler_result from e + raise e + + +@profiler.annotate_function +def backend_compile_and_load( + backend: xc.Client, + module: ir.Module, + executable_devices: xc.DeviceList, options: xc.CompileOptions, host_callbacks: Sequence[Any], ) -> xc.LoadedExecutable: + sym_name = module.operation.attributes['sym_name'] + module_name = ir.StringAttr(sym_name).value # Convert ir.Module to a string representation, unless the backend # explicitly flags the ability to handle a module directly (avoiding the # overhead of back and forth conversions). @@ -308,17 +338,47 @@ def backend_compile( else: built_c = module + if (options.executable_build_options.fdo_profile is not None + and len(options.executable_build_options.fdo_profile)): + logger.debug( + "Compiling module %s with FDO profile of length %d", + module_name, + len(options.executable_build_options.fdo_profile), + ) + try: # we use a separate function call to ensure that XLA compilation appears # separately in Python profiling results - if host_callbacks: + # TODO(dsuo): Simplify this logic once we delete _jax.CompileOnlyPyClient. 
+ if isinstance(backend, _jax.CompileOnlyPyClient): + if host_callbacks: + return backend.compile( + built_c, + executable_devices=executable_devices, # type: ignore + compile_options=options, + host_callbacks=host_callbacks, # type: ignore + ) + # Some backends don't have `host_callbacks` option yet + # TODO(sharadmv): remove this fallback when all backends allow `compile` + # to take in `host_callbacks` return backend.compile( - built_c, compile_options=options, host_callbacks=host_callbacks + built_c, executable_devices=executable_devices, compile_options=options) # type: ignore + else: + if host_callbacks: + return backend.compile_and_load( + built_c, + executable_devices=executable_devices, + compile_options=options, + host_callbacks=host_callbacks, + ) + # Some backends don't have `host_callbacks` option yet + # TODO(sharadmv): remove this fallback when all backends allow `compile` + # to take in `host_callbacks` + return backend.compile_and_load( + built_c, + executable_devices=executable_devices, + compile_options=options, ) - # Some backends don't have `host_callbacks` option yet - # TODO(sharadmv): remove this fallback when all backends allow `compile` - # to take in `host_callbacks` - return backend.compile(built_c, compile_options=options) except xc.XlaRuntimeError as e: for error_handler in _XLA_RUNTIME_ERROR_HANDLERS: handler_result = error_handler(e) @@ -354,15 +414,14 @@ def compile_or_get_cached( devices: np.ndarray, compile_options: xc.CompileOptions, host_callbacks: Sequence[Any], + executable_devices: xc.DeviceList, pgle_profiler: profiler.PGLEProfiler | None = None, ) -> xc.LoadedExecutable: sym_name = computation.operation.attributes['sym_name'] module_name = ir.StringAttr(sym_name).value if dumped_to := mlir.dump_module_to_file(computation, "compile"): - logging.info("Dumped the module to %s.", dumped_to) - - use_compilation_cache = compilation_cache.is_cache_used(backend) + logger.info("Dumped the module to %s.", dumped_to) is_multi_process = ( len({device.process_index for device in devices.flatten()}) > 1 @@ -370,67 +429,29 @@ def compile_or_get_cached( min_device_process_id = min( devices.flatten(), key=lambda device: device.id ).process_index - is_auto_pgle_used = ( - config.enable_pgle.value and config.pgle_profiling_runs.value > 0 - ) - if not use_compilation_cache: - if ( - is_multi_process - and is_auto_pgle_used - and distributed.global_state.client is not None - ): - compile_options.executable_build_options.fdo_profile = ( - _share_fdo_profiles( - computation, - devices, - compile_options, - backend, - distributed.global_state.client, - min_device_process_id, - ) - ) + # cache_key: may be None if compilation caching is disabled + cache_key, compile_options = _resolve_compilation_strategy( + computation, + devices, + compile_options, + backend, + pgle_profiler, + is_multi_process, + module_name, + min_device_process_id, + ) - return backend_compile(backend, computation, compile_options, - host_callbacks) + if cache_key is None: + return backend_compile_and_load( + backend, computation, executable_devices, compile_options, + host_callbacks) monitoring.record_event('/jax/compilation_cache/compile_requests_use_cache') - try: - if config.remove_custom_partitioning_ptr_from_cache_key.value: - ignore_callbacks = cache_key_type.IgnoreCallbacks.CUSTOM_PARTITIONING - else: - ignore_callbacks = cache_key_type.IgnoreCallbacks.NO - - cache_key = compilation_cache.get_cache_key( - computation, - devices, - compile_options, - backend, - ignore_callbacks=ignore_callbacks, 
- ) - except xc._xla.XlaRuntimeError as ex: - logger.error("compile_or_get_cached: unable to generate cache key, " - "skipping the cache: %s", ex) - return backend_compile(backend, computation, compile_options, - host_callbacks) - - if is_auto_pgle_used or config.compilation_cache_expect_pgle.value: - cache_key = _resolve_pgle_module_cache_key( - computation, - devices, - compile_options, - backend, - pgle_profiler, - is_multi_process, - cache_key, - module_name, - min_device_process_id, - ) - cache_retrieval_start = time.monotonic() retrieved_executable, retrieved_compile_time = _cache_read( - module_name, cache_key, compile_options, backend) + module_name, cache_key, compile_options, backend, executable_devices) cache_retrieval_time = time.monotonic() - cache_retrieval_start if retrieved_executable is not None: @@ -450,7 +471,7 @@ def compile_or_get_cached( config.share_binary_between_hosts.value and is_multi_process and distributed.global_state.client is not None - # Host callbacks are currently baked into the HLO module so we cant share + # Host callbacks are currently baked into the HLO module so we can't share # them. and len(host_callbacks) == 0 ): @@ -458,6 +479,7 @@ def compile_or_get_cached( return _compile_and_share_module( backend, computation, + executable_devices, compile_options, host_callbacks, distributed.global_state.client, @@ -470,6 +492,7 @@ def compile_or_get_cached( return _compile_and_write_cache( backend, computation, + executable_devices, compile_options, host_callbacks, module_name, @@ -481,85 +504,130 @@ def compile_or_get_cached( # 1. PGLE optimized module (the one which was recompiled with FDO profile) is # in the persistent cache. In this case the module should be returned from # cache and PGLE should be disabled for this module. Is module is stored in -# the persistent cache under the "pgle_profiled_module_key" which calculated -# with replacing FDO profile with flag which identify that module were PGLE -# profiled. +# the persistent cache under the "pgle_optimized_cache_key", which is +# calculated by replacing the FDO profile with a sentinel value that identifies +# that the module was optimized with PGLE. # 2. PGLE profiled module is not in the persistent cache and the module is -# getting built with an FDO profile. In this case we need to share FDO profile -# with other processes and store the result under the -# "pgle_profiled_module_key" so later in case 1 we will be able to find the +# getting built with an FDO profile. In this case we need to share the FDO +# profile with any other processes and store the result under the +# "pgle_optimized_cache_key" so later in case 1 we will be able to find the # module. # 3. PGLE profiled module is not in the persistent cache and the module is # getting compiled to be PGLEd (FDO profile is empty). In this case we need to -# simply return the non-PGLE profiled module from the persistent cache. +# simply return the non-PGLE profiled module from the persistent cache if it +# exists, and otherwise compile it. # # If the compilation_cache_expect_pgle option is set then in case 1 the PGLE # optimized module will be loaded even if PGLE is not enabled in the current # process. This is useful if we want to combine the use of PGLE with other # profiling tools (e.g. Nsight Systems) that cannot co-exist with PGLE due to # contention for CUPTI resources. 
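For orientation, here is a minimal sketch of how the three cases described in the comment above resolve to a cache key. The helper names (`choose_key`, `cache_lookup`) and the key strings are hypothetical stand-ins for illustration only; they are not the real `_resolve_compilation_strategy` or `compilation_cache.get_cache_key` signatures:

```python
# Hypothetical sketch of the case analysis above.
# `cache_lookup` stands in for "is this key present in the persistent cache?".
def choose_key(cache_lookup, fdo_profile: bytes | None, expect_pgle: bool) -> str:
    plain_key = "key(module, devices, options)"               # ordinary cache key
    pgle_key = "key(module, devices, options, fdo=sentinel)"  # FDO profile replaced by a sentinel
    if cache_lookup(pgle_key):
        return pgle_key     # case 1: load the PGLE-optimized executable, disable further profiling
    if expect_pgle and cache_lookup(plain_key):
        print("warning: non-PGLE hit but no PGLE-optimized entry")  # unexpected miss
    if fdo_profile:
        return pgle_key     # case 2: recompiling with a profile; store under the PGLE key
    return plain_key        # case 3: plain build used to collect profile data
```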
-def _resolve_pgle_module_cache_key( +def _resolve_compilation_strategy( computation: ir.Module, devices: np.ndarray, compile_options: xc.CompileOptions, backend: xc.Client, pgle_profiler: profiler.PGLEProfiler | None, is_multi_process: bool, - cache_key: str, module_name: str, min_device_process_id: int, -) -> str: - fdo_profile = compile_options.executable_build_options.fdo_profile - compile_options.executable_build_options.fdo_profile = b"pgle profiled" - - pgle_profiled_module_key = compilation_cache.get_cache_key( - computation, - devices, - compile_options, - backend, - cache_key_type.IgnoreCallbacks.ALL, +) -> tuple[str | None, xc.CompileOptions]: + is_auto_pgle_used = ( + config.enable_pgle.value and config.pgle_profiling_runs.value > 0 ) - compile_options.executable_build_options.fdo_profile = fdo_profile - - result_key = cache_key - if _is_executable_in_cache(backend, pgle_profiled_module_key): - # Load PGLE profiled module from the persistent cache. - result_key = pgle_profiled_module_key - if config.compilation_cache_expect_pgle.value: - logging.info(f"PGLE-optimized {module_name} loaded from compilation cache") - if pgle_profiler is not None: - pgle_profiler.disable() + + get_cache_key = partial(_get_cache_key, backend=backend, + computation=computation, devices=devices) + + if is_auto_pgle_used or config.compilation_cache_expect_pgle.value: + # This can be None if cache key generation fails. + pgle_optimized_cache_key = get_cache_key(compile_options, + override_fdo_profile=b"pgle profiled") + # TODO(b/376647494): remove the workaround when the bug is fixed; the JAX + # profiler cannot collect sufficiently detailed profile data for PGLE if + # command buffers / CUDA graphs are enabled. Therefore disable command + # buffers when compiling for PGLE data collection, but not if AutoPGLE is + # not enabled, and not when re-compiling using PGLE data. This condition + # includes `compilation_cache_expect_pgle` so that slow-to-compile modules + # that are not executed often enough to trigger re-compilation will still + # be cached between an "enable_pgle" run and an "expect_pgle" run. + first_pass_compile_options = copy.deepcopy(compile_options) + first_pass_compile_options.env_option_overrides += [ + ("xla_gpu_enable_command_buffer", ""), + ] else: - # No PGLE-optimised module found in the persistent cache. 
- if (config.compilation_cache_expect_pgle.value - and _is_executable_in_cache(backend, cache_key)): - # The user asserted this miss was unexpected; emit a warning + pgle_optimized_cache_key = None + first_pass_compile_options = compile_options + + # This can be None if cache key generation fails or caching is disabled + cache_key = get_cache_key(first_pass_compile_options) + + if cache_key is not None and pgle_optimized_cache_key is not None: + # The compilation cache is enabled and AutoPGLE is enabled/expected + if _is_executable_in_cache(backend, pgle_optimized_cache_key): + if config.compilation_cache_expect_pgle.value: + logger.info(f"PGLE-optimized {module_name} loaded from compilation cache") + # No need to record N profiles in this case + if pgle_profiler is not None: + pgle_profiler.disable() + return pgle_optimized_cache_key, compile_options + elif (config.compilation_cache_expect_pgle.value + and _is_executable_in_cache(backend, cache_key)): + # No PGLE-optimized module found in the persistent cache, and the user + # asserted (expect_pgle) that this miss was unexpected warnings.warn(f"PERSISTENT CACHE MISS for PGLE-optimized {module_name} " "despite non-PGLE hit; it may not have been executed " "enough times when the cache was populated") - if fdo_profile is not None and len(fdo_profile) > 0: - # Store module under PGLE profiled module cache key. - result_key = pgle_profiled_module_key - if is_multi_process and distributed.global_state.client is not None: - compile_options.executable_build_options.fdo_profile = ( - _share_fdo_profiles( - computation, - devices, - compile_options, - backend, - distributed.global_state.client, - min_device_process_id, - ) - ) - else: - compile_options.executable_build_options.fdo_profile = fdo_profile - logger.debug( - "Compiling module %s with FDO profile of length %d", - module_name, - len(compile_options.executable_build_options.fdo_profile), + + if (is_auto_pgle_used + and compile_options.executable_build_options.fdo_profile is not None + and len(compile_options.executable_build_options.fdo_profile)): + # Profile data are available to trigger a PGLE-optimized recompilation; + # store under `pgle_optimized_cache_key` if the cache is enabled + if is_multi_process and distributed.global_state.client is not None: + compile_options.executable_build_options.fdo_profile = ( + _share_fdo_profiles( + computation, + devices, + compile_options, + backend, + distributed.global_state.client, + min_device_process_id, ) - return result_key + ) + return pgle_optimized_cache_key, compile_options + else: + # Compile for PGLE collection, store under `cache_key` if the cache is + # enabled. This is also the AutoPGLE-disabled path. 
+ return cache_key, first_pass_compile_options +def _get_cache_key( + options: xc.CompileOptions, + backend: xc.Client, + computation: ir.Module, + devices: np.ndarray, + override_fdo_profile: bytes | None = None) -> str | None: + if not compilation_cache.is_cache_used(backend): + return None + if config.remove_custom_partitioning_ptr_from_cache_key.value: + ignore_callbacks = cache_key_type.IgnoreCallbacks.CUSTOM_PARTITIONING + else: + ignore_callbacks = cache_key_type.IgnoreCallbacks.NO + if override_fdo_profile is not None: + options = copy.deepcopy(options) + options.executable_build_options.fdo_profile = override_fdo_profile + try: + return compilation_cache.get_cache_key( + computation, + devices, + options, + backend, + ignore_callbacks, + ) + except xc._xla.XlaRuntimeError as ex: + logger.error("compile_or_get_cached: unable to generate cache key, " + "skipping the cache: %s", ex) + return None # The process that has the lowest device ID should share FDO profile before # compilation with other processes. @@ -568,7 +636,7 @@ def _share_fdo_profiles( devices: np.ndarray, compile_options: xc.CompileOptions, backend: xc.Client, - global_client: lib.xla_extension.DistributedRuntimeClient, + global_client: lib._jax.DistributedRuntimeClient, min_process_id ) -> bytes | None: sym_name = computation.operation.attributes['sym_name'] @@ -624,14 +692,16 @@ def _share_fdo_profiles( _share_fdo_profiles.modules_profiles = {} + # The process with the first_process_id should compile the module and write it # to the K-V storage. def _compile_and_share_module( backend: xc.Client, computation: ir.Module, + executable_devices: xc.DeviceList, compile_options: xc.CompileOptions, host_callbacks: Sequence[Any], - global_client: lib.xla_extension.DistributedRuntimeClient, + global_client: lib._jax.DistributedRuntimeClient, module_name: str, cache_key: str, first_process_id: int @@ -647,6 +717,7 @@ def _compile_and_share_module( executable = _compile_and_write_cache( backend, computation, + executable_devices, compile_options, host_callbacks, module_name, @@ -667,25 +738,27 @@ def _compile_and_share_module( serialized_executable ) executable = backend.deserialize_executable( - serialized_executable, compile_options - ) + serialized_executable, executable_devices, compile_options) # type: ignore _compile_and_share_module.modules_cache[cache_key] = executable return executable + _compile_and_share_module.modules_cache = {} + def _compile_and_write_cache( backend: xc.Client, computation: ir.Module, + executable_devices: xc.DeviceList, compile_options: xc.CompileOptions, host_callbacks: Sequence[Any], module_name: str, cache_key: str, ) -> xc.LoadedExecutable: start_time = time.monotonic() - executable = backend_compile( - backend, computation, compile_options, host_callbacks + executable = backend_compile_and_load( + backend, computation, executable_devices, compile_options, host_callbacks ) compile_time = time.monotonic() - start_time _cache_write( @@ -693,6 +766,7 @@ def _compile_and_write_cache( ) return executable + def _is_executable_in_cache(backend, cache_key) -> bool: """Checks if executable is presented in cache on a given key """ @@ -709,14 +783,14 @@ def _is_executable_in_cache(backend, cache_key) -> bool: def _cache_read( module_name: str, cache_key: str, compile_options: xc.CompileOptions, - backend: xc.Client + backend: xc.Client, executable_devices: xc.DeviceList, ) -> tuple[xc.LoadedExecutable | None, int | None]: """Looks up the `computation` and it's compilation time in the persistent 
compilation cache repository. """ try: return compilation_cache.get_executable_and_time( - cache_key, compile_options, backend) + cache_key, compile_options, backend, executable_devices) except Exception as ex: if config.raise_persistent_cache_errors.value: raise diff --git a/jax/_src/config.py b/jax/_src/config.py index cf6a07834a10..9d19ceb8b261 100644 --- a/jax/_src/config.py +++ b/jax/_src/config.py @@ -235,17 +235,20 @@ def trace_context(): threefry_partitionable.value, threefry_gpu_kernel_lowering.value, use_direct_linearize.value, - varying_axes_in_types.value, softmax_custom_jvp.value, disable_jit.value, debug_key_reuse.value, jax_xla_profile_version.value, + _check_vma.value, # Technically this affects jaxpr->stablehlo lowering, not tracing. hlo_source_file_canonicalization_regex.value, pgle_profiling_runs.value, enable_pgle.value, use_shardy_partitioner.value, - use_high_dynamic_range_gumbel.value) + use_high_dynamic_range_gumbel.value, + error_checking_behavior_nan.value, + error_checking_behavior_divide.value, + error_checking_behavior_oob.value) config = Config() @@ -356,7 +359,7 @@ def __exit__(self, exc_type, exc_value, traceback): " This will be enabled by default in future versions of JAX, at which " "point all uses of the flag will be considered deprecated (following " "the `API compatibility policy " - "`_).") + "`_).") UPGRADE_BOOL_EXTRA_DESC = " (transient)" @@ -908,7 +911,7 @@ def enum_flag(name, default, *args, **kwargs) -> Flag[str]: 'The calling convention version number to use for exporting. This must be ' 'within the range of versions supported by the tf.XlaCallModule ' 'used in your deployment environment. ' - 'See https://jax.readthedocs.io/en/latest/export/shape_poly.html#calling-convention-versions.' + 'See https://docs.jax.dev/en/latest/export/shape_poly.html#calling-convention-versions.' ) ) @@ -917,7 +920,7 @@ def enum_flag(name, default, *args, **kwargs) -> Flag[str]: default=bool_env('JAX_EXPORT_IGNORE_FORWARD_COMPATIBILIY', False), help=( 'Whether to ignore the forward compatibility lowering rules. ' - 'See https://jax.readthedocs.io/en/latest/export/export.html#compatibility-guarantees-for-custom-calls.' + 'See https://docs.jax.dev/en/latest/export/export.html#compatibility-guarantees-for-custom-calls.' ) ) @@ -937,11 +940,17 @@ def enum_flag(name, default, *args, **kwargs) -> Flag[str]: 'otherwise.' )) -jax_pjrt_client_create_options = optional_string_state( +def _validate_jax_pjrt_client_create_options(new_val): + if new_val is not None and not isinstance(new_val, (str, dict)): + raise ValueError('new string config value must be None or of type dict' + f' | str, got {new_val} of type {type(new_val)}.') + +jax_pjrt_client_create_options = string_or_object_state( name='jax_pjrt_client_create_options', default=None, help=('A set of key-value pairs in the format of "k1:v1;k2:v2" strings ' - 'provided to a device platform pjrt client as extra arguments.')) + 'provided to a device platform pjrt client as extra arguments.'), + validator=_validate_jax_pjrt_client_create_options) enable_checks = bool_state( name='jax_enable_checks', @@ -967,6 +976,28 @@ def enum_flag(name, default, *args, **kwargs) -> Flag[str]: 'to disable any debuggers while leak checking is enabled.')) checking_leaks = functools.partial(check_tracer_leaks, True) + +captured_constants_warn_bytes = int_state( + name='jax_captured_constants_warn_bytes', + default=2 * 10 ** 9, + help=('The number of bytes of parameters that may be captured as constants ' + 'before a warning is issued. 
Defaults to approximately 2GB. ' + 'Set to -1 to disable issuing a warning.' + ) +) + +captured_constants_report_frames = int_state( + name='jax_captured_constants_report_frames', + default=0, + help=('The number of stack frames reported for each captured constant ' + 'indicating the file and operation where the constant was captured. ' + 'Set to -1 to print the complete set of frames, or 0 to disable. ' + 'N.b. the report is only generated if the total amount of captured ' + 'constants exceeds `jax_captured_constants_warn_bytes`, as it is expensive' + 'to generate the report.' + ) +) + debug_nans = bool_state( name='jax_debug_nans', default=False, @@ -995,7 +1026,7 @@ def enum_flag(name, default, *args, **kwargs) -> Flag[str]: name='jax_explain_cache_misses', default=False, help=('Each time there is a miss on one of the main caches (e.g. the ' - 'tracing cache), log an explanation.. Logging is performed with ' + 'tracing cache), log an explanation. Logging is performed with ' '`logging`. When this option is set, the log level is WARNING; ' 'otherwise the level is DEBUG.')) @@ -1013,19 +1044,6 @@ def enum_flag(name, default, *args, **kwargs) -> Flag[str]: help='If True, pmap and shard_map API will be merged.') -spmd_mode = enum_state( - name='jax_spmd_mode', - enum_values=['allow_all', 'allow_jit'], - default='allow_jit', - help=("Decides whether Math on `jax.Array`'s that are not fully addressable " - "(i.e. spans across multiple processes) is allowed. The options are: " - "* allow_jit: Default, `pjit` and `jax.jit` computations are allowed " - " to execute on non-fully addressable `jax.Array`s\n" - "* allow_all: `jnp`, normal math (like `a + b`, etc), `pjit`, " - " `jax.jit` and all other operations are allowed to " - " execute on non-fully addressable `jax.Array`s.")) - - distributed_debug = bool_state( name='jax_distributed_debug', default=False, @@ -1088,20 +1106,13 @@ def enum_flag(name, default, *args, **kwargs) -> Flag[str]: help=('Use direct linearization instead JVP followed by partial eval'), include_in_jit_key=True) -varying_axes_in_types = bool_state( - name='jax_varying_axes_in_types', +# TODO make it so people don't use this, this is internal... +_check_vma = bool_state( + name='check_vma', default=False, - help=('Adds varying manual axes to ShapedArray to track which mesh axes the' - ' array is varying over. This will help to remove the efficient' - ' transpose rewrite machinery in shard_map'), + help='internal implementation detail of shard_map, DO NOT USE', include_in_jit_key=True) -data_dependent_tracing_fallback = bool_state( - name='jax_data_dependent_tracing_fallback', - default=False, - help=('When True, falls back to trace dispatch based on data dependence ' - 'instead of throwing an escaped tracer error.')) - softmax_custom_jvp = bool_state( name='jax_softmax_custom_jvp', default=False, @@ -1317,6 +1328,41 @@ def enum_flag(name, default, *args, **kwargs) -> Flag[str]: ), ) +# TODO(ayx): Move these 3 flags out of config once we have a user-level +# extension mechanism for adding contexts to which the jit cache is sensitive. +error_checking_behavior_nan = enum_state( + name='jax_error_checking_behavior_nan', + enum_values=['ignore', 'raise'], + default='ignore', + help=( + 'Specify the behavior when a NaN is encountered. Options are "ignore"' + ' or "raise".' 
+ ), + include_in_jit_key=True, +) + +error_checking_behavior_divide = enum_state( + name='jax_error_checking_behavior_divide', + enum_values=['ignore', 'raise'], + default='ignore', + help=( + 'Specify the behavior when a divide by zero is encountered. Options are' + ' "ignore" or "raise".' + ), + include_in_jit_key=True, +) + +error_checking_behavior_oob = enum_state( + name='jax_error_checking_behavior_oob', + enum_values=['ignore', 'raise'], + default='ignore', + help=( + 'Specify the behavior when an out of bounds access is encountered.' + ' Options are "ignore" or "raise".' + ), + include_in_jit_key=True, +) + def _update_x64_global(val): jax_jit.global_state().enable_x64 = val @@ -1437,18 +1483,17 @@ def _update_disable_jit_thread_local(val): enum_values=["off", "tracebackhide", "remove_frames", "quiet_remove_frames", "auto"], default="auto", - help="Controls how JAX filters internal frames out of tracebacks.\n\n" - "Valid values are:\n" - " * \"off\": disables traceback filtering.\n" - " * \"auto\": use \"tracebackhide\" if running under a sufficiently" - " new IPython, or \"remove_frames\" otherwise.\n" - " * \"tracebackhide\": adds \"__tracebackhide__\" annotations to" - " hidden stack frames, which some traceback printers support.\n" - " * \"remove_frames\": removes hidden frames from tracebacks, and adds" - " the unfiltered traceback as a __cause__ of the exception.\n" - " * \"quiet_remove_frames\": removes hidden frames from tracebacks, and adds" - " a brief message (to the __cause__ of the exception) describing that this has" - " happened.\n") + help="Controls how JAX filters internal frames out of tracebacks. Valid values are:\n" + "- ``off``: disables traceback filtering.\n" + "- ``auto``: use ``tracebackhide`` if running under a sufficiently " + "new IPython, or ``remove_frames`` otherwise.\n" + "- ``tracebackhide``: adds ``__tracebackhide__`` annotations to " + "hidden stack frames, which some traceback printers support.\n" + "- ``remove_frames``: removes hidden frames from tracebacks, and adds " + "the unfiltered traceback as a ``__cause__`` of the exception.\n" + "- ``quiet_remove_frames``: removes hidden frames from tracebacks, and adds " + "a brief message (to the ``__cause__`` of the exception) describing that this has " + "happened.\n\n") # This flag is for internal use. # TODO(tianjianlu): Removes once we always enable cusparse lowering. @@ -1474,13 +1519,6 @@ def _update_disable_jit_thread_local(val): help=('Attempt constant folding during staging.'), include_in_jit_key=True) -# This flag is temporary during rollout of the remat barrier. -# TODO(parkers): Remove if there are no complaints. -remat_opt_barrier = bool_state( - name='jax_remat_opt_barrier', - default=True, - help=('Enables using optimization-barrier op for lowering remat.')) - enable_remat_opt_pass = bool_state( name='jax_compiler_enable_remat_pass', default=True, @@ -1489,13 +1527,6 @@ def _update_disable_jit_thread_local(val): 'compute when encountering OOM errors. However, you are ' 'likely to get better results manually with jax.checkpoint')) -# TODO(sharadmv,mattjj): set default to True, then remove -eager_pmap = bool_state( - name='jax_eager_pmap', - default=True, - upgrade=True, - help='Enable eager-mode pmap when jax_disable_jit is activated.') - no_tracing = bool_state( name='jax_no_tracing', default=False, @@ -1650,7 +1681,7 @@ def transfer_guard(new_val: str) -> Iterator[None]: """A contextmanager to control the transfer guard level for all transfers. 
For more information, see - https://jax.readthedocs.io/en/latest/transfer_guard.html + https://docs.jax.dev/en/latest/transfer_guard.html Args: new_val: The new thread-local transfer guard level for all transfers. @@ -1685,16 +1716,17 @@ def _update_garbage_collection_guard(state, key, val): # The default is applied by guard_lib. default=None, help=( - 'Select garbage collection guard level for "jax.Array" objects.\nThis' - ' option can be used to control what happens when a "jax.Array"' - ' object is garbage collected. It is desirable for "jax.Array"' - ' objects to be freed by Python reference couting rather than garbage' + 'Select garbage collection guard level for ``jax.Array`` objects.\n\n' + 'This option can be used to control what happens when a ``jax.Array``' + ' object is garbage collected. It is desirable for ``jax.Array``' + ' objects to be freed by Python reference counting rather than garbage' ' collection in order to avoid device memory being held by the arrays' - ' until garbage collection occurs.\n\nValid values are:\n * "allow":' - ' do not log garbage collection of "jax.Array" objects.\n * "log":' - ' log an error when a "jax.Array" is garbage collected.\n * "fatal":' - ' fatal error if a "jax.Array" is garbage collected.\nDefault is' - ' "allow". Note that not all cycles may be detected.' + ' until garbage collection occurs.\n\n' + 'Valid values are:\n\n' + '* ``allow``: do not log garbage collection of ``jax.Array`` objects.\n' + '* ``log``: log an error when a ``jax.Array`` is garbage collected.\n' + '* ``fatal``: fatal error if a ``jax.Array`` is garbage collected.\n\n' + 'Default is ``allow``. Note that not all cycles may be detected.' ), update_global_hook=lambda val: _update_garbage_collection_guard( guard_lib.global_state(), 'garbage_collect_array', val @@ -1790,18 +1822,20 @@ def _update_garbage_collection_guard(state, key, val): 'O2', 'O3', ], - default='UNKNOWN', + default='O2', help=( 'The degree to which the compiler should attempt to make the program' ' fit in memory' ), - include_in_jit_key=True + include_in_jit_key=True, ) +DEFAULT_CPU_COLLECTIVES_IMPL = "gloo" + cpu_collectives_implementation = optional_enum_state( name='jax_cpu_collectives_implementation', enum_values=["gloo", "mpi", "megascale"], - default=None, + default=DEFAULT_CPU_COLLECTIVES_IMPL, help=( "Cross-process collective implementation used on CPU. 
Must be one of " '("gloo", "mpi")'), diff --git a/jax/_src/core.py b/jax/_src/core.py index 36ce2f004ed4..f03fc96b76e6 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -34,6 +34,7 @@ import numpy as np +from jax._src import deprecations from jax._src import dtypes from jax._src import config from jax._src import effects @@ -45,7 +46,7 @@ ConcretizationTypeError, TracerArrayConversionError, TracerBoolConversionError, TracerIntegerConversionError, UnexpectedTracerError) from jax._src import linear_util as lu - +from jax._src.tree_util import tree_flatten, tree_unflatten from jax._src import source_info_util from jax._src.util import (safe_zip, safe_map, curry, tuple_insert, tuple_delete, cache, @@ -87,7 +88,7 @@ class Jaxpr: __slots__ = ['__weakref__', '_constvars', '_invars', '_outvars', '_eqns', - '_effects', '_debug_info'] + '_effects', '_debug_info', '_is_high'] _constvars: list[Var] _invars: list[Var] @@ -95,6 +96,7 @@ class Jaxpr: _eqns: list[JaxprEqn] _effects: Effects _debug_info: DebugInfo + _is_high: bool @property def constvars(self) -> list[Var]: @@ -120,6 +122,10 @@ def effects(self) -> Effects: def debug_info(self) -> DebugInfo: return self._debug_info + @property + def is_high(self) -> bool: + return self._is_high + def __init__(self, constvars: Sequence[Var], invars: Sequence[Var], outvars: Sequence[Atom], eqns: Sequence[JaxprEqn], effects: Effects = no_effects, @@ -127,6 +133,7 @@ def __init__(self, constvars: Sequence[Var], invars: Sequence[Var], # compatibility we have to allow calls when the debug_info # is missing. debug_info: DebugInfo = None, # type: ignore[annotation-type-mismatch,assignment] + is_high: bool = False, ): """ Args: @@ -151,6 +158,8 @@ def __init__(self, constvars: Sequence[Var], invars: Sequence[Var], # TODO(necula): re-enable these safety checks # assert (len(debug_info.arg_names) == len(invars)), (debug_info, invars) # assert (len(debug_info.result_paths) == len(outvars)), (debug_info, outvars) + self._is_high = is_high + num_vars = len(constvars) + len(invars) def __str__(self): return str(self.pretty_print()) @@ -177,6 +186,7 @@ def replace(self, **kwargs): eqns=kwargs.pop("eqns", self.eqns), effects=kwargs.pop("effects", self.effects), debug_info=kwargs.pop("debug_info", self.debug_info), + is_high=kwargs.pop("is_high", self.is_high), ) if kwargs: raise ValueError(f"Unknown keyword arguments: {kwargs}") @@ -223,6 +233,16 @@ def __init__(self, jaxpr: Jaxpr, consts: Sequence): def in_avals(self): return [v.aval for v in self.jaxpr.invars] + @property + def in_aval_qdds(self) -> list[AbstractValue | AvalQDD]: + return [v.aval if v.initial_qdd is None else AvalQDD(v.aval, v.initial_qdd) + for v in self.jaxpr.invars] + + @property + def final_aval_qdds(self) -> list[AbstractValue | AvalQDD]: + return [v.aval if v.final_qdd is None else AvalQDD(v.aval, v.final_qdd) + for v in self.jaxpr.invars] + @property def out_avals(self): return [v.aval for v in self.jaxpr.outvars] @@ -320,7 +340,7 @@ def __enter__(self): def __exit__(self, exc_type, exc_value, traceback): config.compute_on_context_manager.set_local(self.prev_compute_type) config.threefry_partitionable.set_local(self.prev_threefry_partitionable) - if self.context.xla_metadata is not None: + if self.context.xla_metadata: config.xla_metadata_context_manager.set_local(self.prev_xla_metadata) config.abstract_mesh_context_manager.set_local(self.prev_abstract_mesh) @@ -412,31 +432,32 @@ def new_jaxpr_eqn(invars, outvars, primitive, params, effects, source_info=None, _var_counter = it.count() 
-@total_ordering class Var: - __slots__ = ["count", "suffix", "aval"] + __slots__ = ["count", "aval", "initial_qdd", "final_qdd"] count: int - suffix: str aval: AbstractValue + # these are only useful for jaxpr binders but rather than create a separate + # type for those, breaking existing interpreters, we add fields here. + initial_qdd : QuasiDynamicData | None + final_qdd : QuasiDynamicData | None - def __init__(self, suffix: str, aval: AbstractValue): + def __init__(self, aval: AbstractValue, initial_qdd = None, final_qdd = None): + assert isinstance(aval, AbstractValue) self.count = next(_var_counter) - self.suffix = suffix self.aval = aval - - # TODO(phawkins, mattjj): remove ordering of variables. JAX itself does not - # care about variable ordering, but the downstream package kfac_jax does. - def __lt__(self, other): - return self.count < other.count + self.initial_qdd = initial_qdd + self.final_qdd = final_qdd def __repr__(self): - return f'Var(id={id(self)}){self.suffix}:{self.aval.str_short()}' + return f'Var(id={id(self)}):{self.aval.str_short()}' + + def pretty_print(self, context: JaxprPpContext, *, print_dtype: bool = True): + del print_dtype # unused + return f"{context.var_names[self]}" -def gensym(suffix: str = '') -> Callable[[AbstractValue], Var]: - """Produce distinct variables, printed with the optional suffix.""" - return partial(Var, suffix) +gensym = lambda: Var # In a jaxpr, `dropvar` can appear in place of a bound variable to indicate that # the assignment is dropped, i.e. that an expression's output value will never @@ -444,38 +465,51 @@ def gensym(suffix: str = '') -> Callable[[AbstractValue], Var]: # treat it as a special case of one. Its `aval` is similarly inexact. class DropVar(Var): def __init__(self, aval: AbstractValue): - super().__init__('', aval) + super().__init__(aval) def __repr__(self): return '_' + def pretty_print(self, context: JaxprPpContext, *, print_dtype: bool = True): + del context, print_dtype # unused + return '_' class Literal: - __slots__ = ["val", "aval", "hash"] + __slots__ = ["val", "aval"] val: Any aval: AbstractValue - hash: int | None def __init__(self, val, aval): self.val = val self.aval = aval + + @property + def hash(self): try: - self.hash = hash(val) + return hash(self.val) except TypeError: - if type(val) in literalable_types: + if type(self.val) in literalable_types: try: - self.hash = hash((val.item(), val.dtype)) + return hash((self.val.item(), self.val.dtype)) except (TypeError, AttributeError, ValueError): - self.hash = None + return None __hash__ = None # type: ignore - def __repr__(self): - if hasattr(self, 'hash'): - return f'{self.val}' + def pretty_print(self, context: JaxprPpContext, *, print_dtype: bool = True): + del context # unused + dtype = getattr(self.aval, 'dtype', None) + if print_dtype and dtype: + return f'{self.val}:{self.aval.str_short(short_dtypes=True)}' else: - return f'Literal(val={self.val})' + return f'{self.val}' + + def __repr__(self): + return f'{self.val}' literalable_types: set[type] = set() +def is_literalable(x: Any) -> bool: + return type(x) in dtypes.python_scalar_dtypes or (type(x) in literalable_types and not np.shape(x)) + Atom = Union[Var, Literal] class Primitive: @@ -503,11 +537,9 @@ def bind(self, *args, **params): def _true_bind(self, *args, **params): for arg in args: - if (isinstance(arg, Tracer) - and not arg._trace.is_valid() - and not config.data_dependent_tracing_fallback.value): + if isinstance(arg, Tracer) and not arg._trace.is_valid(): raise escaped_tracer_error(arg) - 
# TODO: figure out how to handle function arguments + # TODO: figure out how to handle function arguments for this assert # assert (not config.enable_checks.value or # all(isinstance(arg, Tracer) or valid_jaxtype(arg) for arg in args)), args @@ -522,6 +554,11 @@ def _true_bind(self, *args, **params): trace_ctx.set_trace(prev_trace) def bind_with_trace(self, trace, args, params): + # TODO(mattjj,dougalm): remove this block? + if self.is_high(**params) and trace.requires_low: + with set_current_trace(trace): + return self.to_lojax(*args, **params) # type: ignore + return trace.process_primitive(self, args, params) def def_impl(self, impl): @@ -551,6 +588,9 @@ def abstract_eval(self, *args, **params): def get_bind_params(self, params): return [], params + def is_high(self, **params) -> bool: + return False + def _effect_free_abstract_eval(abstract_eval): def abstract_eval_(*args, **kwargs): @@ -574,7 +614,7 @@ def read(v: Atom) -> Any: def write(v: Var, val: Any) -> None: if config.enable_checks.value and not config.dynamic_shapes.value: - assert typecheck(v.aval, val), (v.aval, val) + assert typecheck(v.aval, val), (v.aval, get_aval(val)) env[v] = val env: dict[Var, Any] = {} @@ -617,12 +657,13 @@ def check_avals_context_mesh(avals, prim_name): TracerType = TypeVar('TracerType', bound='Tracer') class Trace(Generic[TracerType]): - __slots__ = ("__weakref__", "_invalidated", "_weakref") + __slots__ = ("__weakref__", "_invalidated", "_weakref", "requires_low") def __init__(self): self._invalidated = False # We frequently need a weakref to a trace, so let's precompute one. self._weakref = weakref.ref(self) + self.requires_low = True def process_primitive(self, primitive, tracers, params): raise NotImplementedError("must override") @@ -634,7 +675,7 @@ def is_valid(self): return not self._invalidated def __repr__(self): - return '{}'.format(self.__class__.__name__) + return f'{self.__class__.__name__}' def process_call(self, call_primitive, f, tracers, params): msg = (f"{type(self)} must override process_call to handle call-like " @@ -1021,10 +1062,6 @@ def process_primitive(self, primitive, args, params): else: # TODO(dougalm): delete. this shouldn't be necessary args = map(full_lower, args) - if config.data_dependent_tracing_fallback.value: - for arg in args: - if isinstance(arg, Tracer): - return primitive.bind_with_trace(arg._trace, args, params) check_eval_args(args) return primitive.impl(*args, **params) @@ -1049,6 +1086,8 @@ def process_custom_vjp_call(self, primitive, fun, fwd, bwd, tracers, **_): # py del primitive, fwd, bwd, _ # Unused. return fun.call_wrapped(*tracers) + def cur_qdd(self, x): + return x.cur_qdd() class TraceTag: # TODO: this works for surprisingly subtle reasons. 
Function transformations @@ -1439,6 +1478,8 @@ def definitely_equal(x, y): class AbstractValue: __slots__: list[str] = [] + is_high = False + has_qdd = False def to_tangent_aval(self): raise NotImplementedError("must override") @@ -1457,6 +1498,9 @@ def __repr__(self): def update_weak_type(self, weak_type): return self + def update_vma(self, vma): + return self + def strip_weak_type(self) -> AbstractValue: return self.update_weak_type(False) @@ -1466,7 +1510,13 @@ def normalize(self) -> AbstractValue: def update(self, **kwargs): raise NotImplementedError("must override") - def str_short(self, short_dtypes=False): + def lo_ty(self): + raise NotImplementedError("must override") + + def lo_ty_qdd(self, qdd): + raise NotImplementedError("avals with qdd must override") + + def str_short(self, short_dtypes=False, mesh_axis_types=False): return str(self) # For type signatures involving dynamic shapes, we use lists of abstract values @@ -1502,11 +1552,6 @@ def _jaxpr_type_to_callable_annotation(jaxpr: Jaxpr) -> InputType: for v in jaxpr.invars] return tuple(out) -# TODO(dougalm): Deprecate. This is here for backwards compat. -def lattice_join(x, y): - assert typematch(x, y) - return x - # For use in typing annotations to denote either a Tracer or a `valid_jaxtype`. Value = Any @@ -1528,7 +1573,7 @@ def check_valid_jaxtype(x): def update_aval_with_sharding(aval, sharding): if isinstance(sharding, NamedSharding): - aval = aval.update(sharding=NamedSharding( + return aval.update(sharding=NamedSharding( sharding.mesh.abstract_mesh, sharding.spec._normalized_spec_for_aval(aval.ndim))) return aval @@ -1554,6 +1599,12 @@ def shaped_abstractify(x): if isinstance(x, AbstractValue): return x if hasattr(x, '__jax_array__'): + deprecations.warn( + 'jax-abstract-dunder-array', + ('Triggering of __jax_array__() during abstractification is deprecated.' + ' To avoid this error, either explicitly convert your object using' + ' jax.numpy.array(), or register your object as a pytree.'), + stacklevel=6) return shaped_abstractify(x.__jax_array__()) if hasattr(x, 'dtype'): aval = ShapedArray(np.shape(x), x.dtype, @@ -1578,6 +1629,12 @@ def get_aval(x): if (aval_fn := pytype_aval_mappings.get(t)): return aval_fn(x) if hasattr(x, '__jax_array__'): + deprecations.warn( + 'jax-abstract-dunder-array', + ('Triggering of __jax_array__() during abstractification is deprecated.' + ' To avoid this error, either explicitly convert your object using' + ' jax.numpy.array(), or register your object as a pytree.'), + stacklevel=6) return get_aval(x.__jax_array__()) raise TypeError(f"Argument '{x}' of type '{typ}' is not a valid JAX type") @@ -1630,6 +1687,54 @@ def concrete_dim_or_error(val: Any, context=""): else: return concrete_or_error(operator.index, val, context=context) +### Quasi-dynamic data + +# Quasi-dynamic data includes things like liveness bits and the content type of +# a type-changeable box. These change throughout the program but at a given +# point in the program they have a single statically known value. 
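As a rough mental model for the classes introduced below, here is a toy example; the names `BoxContents` and `MutableBoxInfo` are made up for illustration and are not the JAX classes:

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class BoxContents:
    """Toy quasi-dynamic payload: the current element type of a box."""
    elt_type: str

class MutableBoxInfo:
    """Toy mutable holder: updated as tracing proceeds, but at any given
    program point the value it holds is a single statically known object."""
    def __init__(self, contents: BoxContents):
        self.init_val = contents
        self.cur_val = contents

    def update(self, contents: BoxContents) -> None:
        # Later equations observe the new value, which is again statically known.
        self.cur_val = contents

info = MutableBoxInfo(BoxContents("f32[3]"))
info.update(BoxContents("i32[3]"))   # the box's content type changed mid-program
assert info.init_val.elt_type == "f32[3]"
assert info.cur_val.elt_type == "i32[3]"
```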
+ +class MutableQuasiDynamicData: + def __init__(self, val : QuasiDynamicData | None): + self.init_val = val + self.cur_val = val # immutable payload + + def update(self, val): + self.cur_val = val + +class QuasiDynamicData: + pass + +@dataclass(frozen=True) +class AvalQDD: + aval: AbstractValue + qdd: QuasiDynamicData | None # immutable + + has_qdd = True + def lo_ty(self): + return self.aval.lo_ty_qdd(self.qdd) # type: ignore + + def read_loval(self, val): + return self.aval.read_loval(self.qdd, val) # type: ignore + + def new_from_loval(self, *lovals): + return self.aval.new_from_loval(self.qdd, *lovals) # type: ignore + + def to_tangent_aval(self): + return AvalQDD(self.aval.to_tangent_aval(), self.qdd.to_tangent_qdd()) + +@dataclass(frozen=True) +class AvalMutableQDD: + aval: AbstractValue + mutable_qdd: MutableQuasiDynamicData + +def cur_qdd(x): + prev_trace = trace_ctx.trace + trace_ctx.set_trace(eval_trace) + try: + return prev_trace.cur_qdd(x) + finally: + trace_ctx.set_trace(prev_trace) + ### Extended dtypes # # Extended dtypes are JAX-specific dtypes that allow us to represent logical @@ -1660,7 +1765,8 @@ def physical_aval(aval): if isinstance(aval, ShapedArray): from jax._src.sharding_impls import physical_sharding # type: ignore return ShapedArray((*aval.shape, *elt_aval.shape), elt_aval.dtype, - sharding=physical_sharding(aval, aval.sharding)) + sharding=physical_sharding(aval, aval.sharding), + vma=aval.vma) return DShapedArray((*aval.shape, *elt_aval.shape), elt_aval.dtype) return aval @@ -1704,6 +1810,9 @@ def __repr__(self): return '{}({}{})'.format(self.__class__.__name__, self.str_short(), ", weak_type=True" if self.weak_type else "") + def __str__(self): + return '{}{}'.format("~" if self.weak_type else "", self.str_short()) + _bool = concretization_function_error(bool) _int = concretization_function_error(int, True) _float = concretization_function_error(float, True) @@ -1712,7 +1821,7 @@ def __repr__(self): _oct = concretization_function_error(oct) _index = concretization_function_error(operator.index) - def str_short(self, short_dtypes=False) -> str: + def str_short(self, short_dtypes=False, mesh_axis_types=False) -> str: return dtypes.short_dtype_name(self.dtype) if short_dtypes else self.dtype.name def update_weak_type(self, weak_type): @@ -1786,6 +1895,10 @@ def _invalid_shape_error(shape: Shape, context: str=""): return TypeError(msg) +class ShardingTypeError(Exception): + pass + + # TODO(dougalm): Cast scalar, numpy arrays, etc to jax arrays so that values # passed to primitives are always have avals, etc i.e. they are canonical. def canonicalize_value(val): @@ -1801,14 +1914,20 @@ def canonicalize_value(val): cur_mesh = mesh_lib.get_abstract_mesh() if cur_mesh == aval.sharding.mesh: return val - # Atleast 1 mesh axis should be Manual and all other axes should be - # Manual or Auto to allow casting. # TODO(yashkatariy): Casting to Explicit is not yet allowed. Maybe we need # cast_and_slice_p for it since shape might change? - if (cur_mesh._any_axis_manual and cur_mesh._are_all_axes_auto_or_manual and - aval.sharding.mesh._are_all_axes_auto): - from jax._src.pjit import mesh_cast # pytype: disable=import-error - return mesh_cast(val, NamedSharding(cur_mesh, P(*[None] * aval.ndim))) + # Atleast 1 mesh axis should be Manual and all other axes should be + # Manual or Auto to allow casting. 
+ if cur_mesh._any_axis_manual and cur_mesh._are_all_axes_auto_or_manual: + if aval.sharding.mesh._are_all_axes_auto: + from jax._src.pjit import mesh_cast # pytype: disable=import-error + return mesh_cast(val, NamedSharding(cur_mesh, P(*[None] * aval.ndim))) + elif aval.sharding.mesh._any_axis_explicit: + raise NotImplementedError( + "Closing over inputs to shard_map where the input is sharded on" + " `Explicit` axes is not implemented. As a workaround, please pass" + " those inputs as an argument to shard_map. Got input with shape" + f" {aval.str_short(True, True)}") return val @@ -1817,11 +1936,12 @@ def get_cur_mesh_sharding(spec=None): return NamedSharding(mesh_lib.get_abstract_mesh(), spec) def _make_lengths_same(sharding, ndim): - if ndim > len(sharding.spec): - return sharding.with_spec(sharding.spec._normalized_spec_for_aval(ndim)) - if ndim < len(sharding.spec): - assert all(s is None for s in sharding.spec[ndim:]) - return sharding.with_spec(sharding.spec[:ndim]) + pspec = sharding.spec + if ndim > len(pspec): + return sharding.update(spec=pspec._normalized_spec_for_aval(ndim)) + if ndim < len(pspec): + assert all(s is None for s in pspec[ndim:]), (ndim, pspec) + return sharding.update(spec=P(*pspec[:ndim], unreduced=pspec.unreduced)) assert False, "unreachable" # TODO(yashkatariya): Only works with User/Auto. Generalize it to work with @@ -1833,27 +1953,21 @@ def modify_spec_for_auto_manual(spec, mesh) -> P: new_spec.append(s) else: temp_s = s[0] if isinstance(s, tuple) else s - new_spec.append( - None - if mesh._name_to_type[temp_s] in (AxisType.Auto, AxisType.Manual) - else s) - return P(*new_spec) + new_spec.append(s if mesh._name_to_type[temp_s] == AxisType.Explicit + else None) + new_unreduced = {u for u in spec.unreduced + if mesh._name_to_type[u] == AxisType.Explicit} + return P(*new_spec, unreduced=new_unreduced) def _maybe_modify_sharding(sharding, ndim): if len(sharding.spec) == 0 or all(s is None for s in sharding.spec): - if len(sharding.spec) != ndim: - return _make_lengths_same(sharding, ndim) - return sharding - - if sharding.mesh._are_all_axes_explicit: - if ndim > len(sharding.spec): - return sharding.with_spec(sharding.spec._normalized_spec_for_aval(ndim)) - return sharding - - out = sharding.with_spec(modify_spec_for_auto_manual( - sharding.spec, sharding.mesh)) - if (len(out.spec) != ndim and - (out.mesh.empty or out.mesh._are_all_axes_auto_or_manual)): + out = sharding + elif sharding.mesh._are_all_axes_explicit: + out = sharding + else: + out = sharding.update(spec=modify_spec_for_auto_manual( + sharding.spec, sharding.mesh)) + if len(out.spec) != ndim: out = _make_lengths_same(out, ndim) return out @@ -1894,28 +2008,58 @@ def get_sharding(sharding, shape): raise ValueError("Mesh of an aval must be an AbstractMesh. 
" f"Got {out_s.mesh} of type {type(out_s.mesh)}") _check_divisibility(out_s, shape) + assert out_s.memory_kind is None return out_s -def str_short_aval(shape, dtype, mesh, spec, short_dtypes=False, - mesh_axis_types=False) -> str: +def str_short_aval(shape, dtype, mesh, spec, vma, + short_dtypes=False, mesh_axis_types=False) -> str: dt_str = dtypes.short_dtype_name(dtype) if short_dtypes else dtype.name dt_str = dt_str.replace('void', 'float0') shapestr = _get_shape_sharding_str(shape, spec) mesh_axes = f'({mesh._axis_types_dict})' if mesh_axis_types else '' - return f'{dt_str}[{shapestr}]{mesh_axes}' + vma_ur = _vma_ur_str(vma, spec.unreduced) + return f'{dt_str}[{shapestr}]{vma_ur}{mesh_axes}' + +@cache(max_size=4096, trace_context_in_key=False) +def get_vma(vma, mesh): + if mesh.empty: + return vma + axis_env = get_axis_env() + for i in vma: + if axis_env.axis_exists(i) and i not in mesh._name_to_type: + continue + if mesh._name_to_type[i] != AxisType.Manual: + raise ValueError( + "Axes mentioned in `vma` field of ShapedArray should" + f" be of type `Manual`. Got axis: {i} of type {mesh._name_to_type[i]}") + assert isinstance(vma, frozenset) + return vma + + +class SingleSideCollectiveEffect(effects.Effect): + __str__ = lambda _: "one-sided communication" + + +single_side_collective_effect = SingleSideCollectiveEffect() +effects.control_flow_allowed_effects.add_type(SingleSideCollectiveEffect) class ShapedArray(UnshapedArray): - __slots__ = ['shape', 'sharding', 'varying_manual_axes'] # inherits slots from parent + __slots__ = ['shape', 'sharding', 'vma'] # inherits slots from parent array_abstraction_level = 2 def __init__(self, shape, dtype, weak_type=False, *, sharding=None, - varying_manual_axes: frozenset[AxisName] = frozenset()): + vma: frozenset[AxisName] = frozenset()): self.shape = canonicalize_shape(shape) self.dtype = _dtype_object(dtype) self.weak_type = weak_type self.sharding = get_sharding(sharding, self.shape) - if config.varying_axes_in_types.value: - self.varying_manual_axes = varying_manual_axes + # short for varying_manual_axes. See docs at + # https://docs.jax.dev/en/latest/notebooks/shard_map.html#tracking-how-values-vary-over-manual-mesh-axes-and-check-vma-true + self.vma = get_vma(vma, self.sharding.mesh) + + def lower_val(self, val): return [val] + def raise_val(self, val): return val + def lo_ty(self): return [self] def update(self, shape=None, dtype=None, weak_type=None, **kwargs): if shape is None: @@ -1926,9 +2070,8 @@ def update(self, shape=None, dtype=None, weak_type=None, **kwargs): weak_type = self.weak_type if 'sharding' not in kwargs: kwargs['sharding'] = self.sharding - if 'varying_manual_axes' not in kwargs: - kwargs['varying_manual_axes'] = getattr(self, 'varying_manual_axes', - frozenset()) + if 'vma' not in kwargs: + kwargs['vma'] = self.vma return ShapedArray(shape, dtype, weak_type, **kwargs) ndim = property(lambda self: len(self.shape)) @@ -1946,26 +2089,30 @@ def __eq__(self, other): and self.dtype == other.dtype and self.shape == other.shape and self.weak_type == other.weak_type and self.sharding == other.sharding - and (getattr(self, 'varying_manual_axes', frozenset()) == - getattr(other, 'varying_manual_axes', frozenset()))) + and self.vma == other.vma) def __hash__(self): # can use hash(self.dtype) and rely on the fact that numpy reuses base dtype # objects, e.g. 
`np.zeros(3).dtype is np.zeros(4).dtype`, or we can use # the unique character code via hash(self.dtype.char) return hash((self.shape, self.dtype, self.weak_type, self.sharding, - getattr(self, 'varying_manual_axes', frozenset()))) + self.vma)) def to_tangent_aval(self): return ShapedArray( self.shape, primal_dtype_to_tangent_dtype(self.dtype), - self.weak_type, sharding=self.sharding, - varying_manual_axes=getattr(self, 'varying_manual_axes', frozenset())) + self.weak_type, sharding=self.sharding, vma=self.vma) + + def to_cotangent_aval(self): + dtype = primal_dtype_to_tangent_dtype(self.dtype) + sharding = primal_sharding_to_cotangent_sharding(self.sharding) + return ShapedArray( + self.shape, dtype, self.weak_type, sharding=sharding, vma=self.vma) def str_short(self, short_dtypes=False, mesh_axis_types=False): return str_short_aval( self.shape, self.dtype, self.sharding.mesh, self.sharding.spec, - short_dtypes, mesh_axis_types) + self.vma, short_dtypes, mesh_axis_types) def _len(self, ignored_tracer): try: @@ -1973,6 +2120,9 @@ def _len(self, ignored_tracer): except IndexError as err: raise TypeError("len() of unsized object") from err # same as numpy error + def update_vma(self, vma): + return self.update(vma=vma) + def _get_shape_sharding_str(shape, spec): out = [] @@ -1986,6 +2136,18 @@ def _get_shape_sharding_str(shape, spec): out.append(f"{s1}@{s2}") return ','.join(out) +def _create_str(x, prefix): + x_str = f"{','.join(i for i in x)}" + x_str = x_str if len(x) == 1 else f"({x_str})" + return f"{prefix}:{x_str}" + +def _vma_ur_str(vma, unreduced): + if not vma and not unreduced: + return '' + vma_str = _create_str(vma, 'V') if vma else '' + ur_str = _create_str(unreduced, 'U') if unreduced else '' + sep = ', ' if vma and unreduced else '' + return f"{{{vma_str}{sep}{ur_str}}}" def primal_dtype_to_tangent_dtype(primal_dtype): if isinstance(primal_dtype, dtypes.ExtendedDType): @@ -1995,6 +2157,71 @@ def primal_dtype_to_tangent_dtype(primal_dtype): else: return primal_dtype +def primal_sharding_to_cotangent_sharding(sharding): + new_spec = P(*sharding.spec, unreduced=sharding.spec.reduced, + reduced=sharding.spec.unreduced) + return sharding.update(spec=new_spec) + +def pvary(x, axis_name): + if not axis_name: + return x + axes = (axis_name,) if not isinstance(axis_name, tuple) else axis_name + xs, treedef = tree_flatten(x) + ys = pvary_p.bind(*xs, axes=axes, axis_index_groups=None) + return tree_unflatten(treedef, ys) + +pvary_p = Primitive('pvary') +pvary_p.multiple_results = True +pvary_p.def_impl(lambda *args, axes, axis_index_groups: args) + +def _pvary_abstract_eval(*args, axes, axis_index_groups): + if not config._check_vma.value: + return args + assert isinstance(axes, tuple) + arg_vma = [a.vma for a in args] + # If there is intersection between arg_vma and axes, error + if any(set(axes) & a for a in arg_vma): + raise ValueError( + "Collective pvary must be applied to a " + f"non-device-varying type, but got {arg_vma} for collective acting " + f"over axis name {axes}. 
Please open an issue at " + "https://github.com/jax-ml/jax/issues, and as a temporary " + "workaround pass the check_vma=False argument to `jax.shard_map`") + sharding = NamedSharding(mesh_lib.get_abstract_mesh(), P()) + return [a.update(sharding=sharding, vma=a.vma.union(frozenset(axes))) + for a in args] +pvary_p.def_abstract_eval(_pvary_abstract_eval) + + +def standard_insert_pvary(*args): + if not config._check_vma.value: + return args + if not args: + return args + in_vma = [frozenset() if (aval := get_aval(a)) is abstract_token + else aval.vma for a in args] # pytype: disable=attribute-error + out_vma = frozenset.union(*in_vma) + return [ + pvary(arg, tuple(n for n in out_vma if n not in src)) + if isinstance(get_aval(arg), ShapedArray) and out_vma - src + else arg + for arg, src in zip(args, in_vma) + ] + +def standard_vma_rule(prim_name, *avals, **kwargs) -> frozenset[AxisName]: + if not config._check_vma.value: + return frozenset() + avals = tuple(a for a in avals if a is not abstract_token) + if not avals: + return frozenset() + vma, *vmas = (a.vma for a in avals) + if not all(vma == vma_ for vma_ in vmas): + raise ValueError( + f'Primitive {prim_name} requires varying manual axes ' + f'to match, but got {[vma, *vmas]}. Please open an issue at ' + 'https://github.com/jax-ml/jax/issues and as a temporary ' + 'workaround pass the check_vma=False argument to `jax.shard_map`') + return vma # Dynamic shape stuff below here! We keep the abstract values distinct just so # as not to interfere with any static shape machinery. @@ -2013,6 +2240,7 @@ class DShapedArray(UnshapedArray): array_abstraction_level: int = 3 def __init__(self, shape, dtype, weak_type=False): + assert not any(isinstance(d, Literal) for d in shape) self.shape = shape self.dtype = dtype self.weak_type = weak_type @@ -2022,7 +2250,7 @@ def __init__(self, shape, dtype, weak_type=False): 0 if any(type(d) is int and d == 0 for d in self.shape) else math.prod(self.shape)) - def str_short(self, short_dtypes=False) -> str: + def str_short(self, short_dtypes=False, mesh_axis_types=False) -> str: del short_dtypes # ignored shape = f'{",".join(str(d) for d in self.shape)}' if self.shape else '' dtype = dtypes.short_dtype_name(self.dtype) @@ -2042,6 +2270,10 @@ def update(self, shape=None, dtype=None, weak_type=None): def sharding(self): return NamedSharding(mesh_lib.empty_abstract_mesh, P()) + @property + def vma(self): + return frozenset() + def _len(self, tracer): return self.shape[0] @@ -2051,12 +2283,16 @@ def __eq__(self, other): and self.weak_type == other.weak_type) def __hash__(self): - return hash((self.shape, self.dtype, self.weak_type)) + # We don't hash the contents of the shape because it may contain tracers. 
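The vma bookkeeping introduced above has two observable behaviours: standard_insert_pvary broadcasts each argument over the manual axes it is missing relative to the union of all arguments' vma sets, and _vma_ur_str renders non-empty vma/unreduced sets with a V:/U: prefix when pretty-printing avals. A minimal standalone sketch of both, in plain Python with illustrative helper names (not part of this change):

def missing_axes(arg_vmas):
    # Union of all arguments' varying manual axes.
    out_vma = frozenset().union(*arg_vmas)
    # Each argument must be pvary'd over the axes it does not already vary over.
    return [tuple(n for n in sorted(out_vma) if n not in vma) for vma in arg_vmas]

def vma_ur_str(vma, unreduced):
    def fmt(x, prefix):
        body = ','.join(sorted(x))
        return f"{prefix}:{body}" if len(x) == 1 else f"{prefix}:({body})"
    parts = ([fmt(vma, 'V')] if vma else []) + ([fmt(unreduced, 'U')] if unreduced else [])
    return '{' + ', '.join(parts) + '}' if parts else ''

print(missing_axes([frozenset({'x'}), frozenset({'y'})]))  # [('y',), ('x',)]
print(vma_ur_str(frozenset({'x'}), frozenset()))           # {V:x}
print(vma_ur_str(frozenset({'x', 'y'}), frozenset()))      # {V:(x,y)}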
+ return hash((len(self.shape), self.dtype, self.weak_type)) def to_tangent_aval(self): return DShapedArray(self.shape, primal_dtype_to_tangent_dtype(self.dtype), self.weak_type) + def update_vma(self, vma): + return self + class DArray: _aval: DShapedArray @@ -2143,9 +2379,12 @@ def __init__(self, aval, buf): shape = property(lambda self: self._aval.shape) dtype = property(lambda self: self._aval.dtype) sharding = property(lambda self: self._buf.sharding) + format = property(lambda self: self._buf.format) + committed = _committed = property(lambda self: self._buf._committed) def __getitem__(self, idx): return self._aval._getitem(self, idx) def __setitem__(self, idx, x): return self._aval._setitem(self, idx, x) def __repr__(self) -> str: return 'Mutable' + repr(self[...]) + def __len__(self) -> int: return self._aval._len(self) pytype_aval_mappings[MutableArray] = lambda x: x._aval def mutable_array(init_val): @@ -2183,7 +2422,7 @@ def _freeze_impl(ref): return ref[()] class AbstractToken(AbstractValue): - def str_short(self, short_dtypes=False): return 'Tok' + def str_short(self, short_dtypes=False, mesh_axis_types=False): return 'Tok' def to_tangent_aval(self): return self abstract_token: AbstractToken = AbstractToken() @@ -2203,16 +2442,10 @@ def block_until_ready(self): pytype_aval_mappings[Token] = lambda _: abstract_token -# TODO(dougalm): Deprecate these. They're just here for backwards compat. -def raise_to_shaped(aval): - return aval -raise_to_shaped_mappings: dict[type, Callable] = {} - ### Operations on shapes and dimension sizes. class InconclusiveDimensionOperation(Exception): """Raised when we cannot conclusively compute with symbolic dimensions.""" - pass def is_symbolic_dim(v: Any) -> bool: """Checks if a value is a symbolic dimension used for shape polymorphism. @@ -2224,6 +2457,9 @@ def is_symbolic_dim(v: Any) -> bool: def is_constant_dim(d: DimSize) -> bool: # Whether the dimension is a static integer constant. + # Try using a fast path for non-concrete Tracers. 
+ if isinstance(d, Tracer) and not is_concrete(d): + return False try: operator.index(d) return True @@ -2547,22 +2783,23 @@ def unmapped_aval(size: AxisSize, axis: int | None, def _map_shaped_array( size: int, axis: int | None, aval: ShapedArray) -> ShapedArray: - assert axis is None or aval.shape[axis] == size - # TODO: Extend the named shape - if axis is None: return aval - sharding = aval.sharding.with_spec(tuple_delete(aval.sharding.spec, axis)) + # assert axis is None or aval.shape[axis] == size + if axis is None: + return aval + sharding = aval.sharding.update(spec=tuple_delete(aval.sharding.spec, axis)) return ShapedArray(tuple_delete(aval.shape, axis), aval.dtype, - weak_type=aval.weak_type, sharding=sharding) + weak_type=aval.weak_type, sharding=sharding, vma=aval.vma) def _unmap_shaped_array( size: int, axis: int | None, explicit_mesh_axis, aval: ShapedArray ) -> ShapedArray: if axis is None: return aval elif type(axis) is int: - sharding = aval.sharding.with_spec(tuple_insert( + sharding = aval.sharding.update(spec=tuple_insert( aval.sharding.spec, axis, explicit_mesh_axis)) return ShapedArray(tuple_insert(aval.shape, axis, size), aval.dtype, - weak_type=aval.weak_type, sharding=sharding) + weak_type=aval.weak_type, sharding=sharding, + vma=aval.vma) else: raise TypeError(axis) def _map_dshaped_array( @@ -2616,10 +2853,8 @@ def __lt__(self, other): @dataclass(frozen=True) class NamedAxisEffect(effects.Effect): """A side-effect introducing a new named axis into the current scope.""" - name: AxisName - effects.control_flow_allowed_effects.add_type(NamedAxisEffect) effects.custom_derivatives_allowed_effects.add_type(NamedAxisEffect) effects.lowerable_effects.add_type(NamedAxisEffect) @@ -2674,10 +2909,31 @@ def typematch(t1: AbstractValue, t2: AbstractValue) -> bool: # could try normalizing first and then doing simple equality. # TODO(yashkatariya): Also check `sharding` here. # See https://github.com/jax-ml/jax/issues/26474 - return t1.dtype == t2.dtype and definitely_equal_shape(t1.shape, t2.shape) + return (t1.dtype == t2.dtype and definitely_equal_shape(t1.shape, t2.shape) + and t1.vma == t2.vma) # type: ignore else: return False +def aval_mismatch_extra(a1: AbstractValue, a2: AbstractValue) -> str: + assert not typematch(a1, a2) + if isinstance(a1, ShapedArray) and isinstance(a2, ShapedArray): + mismatches = [] + if a1.dtype != a2.dtype: + mismatches.append('the dtypes do not match') + if a1.shape != a2.shape: + mismatches.append('the shapes do not match') + if a1.vma != a2.vma: + mismatches.append('the varying manual axes do not match') + # TODO(yashkatariya,mattjj): add check for sharding-in-types mismatch + + if len(mismatches) == 0: + return '' + elif len(mismatches) == 1: + return ', so ' + mismatches[0] + else: + return ', so ' + ', '.join(mismatches[:-1]) + ', and ' + mismatches[-1] + return '' + class JaxprTypeError(TypeError): pass custom_typechecks: dict[Primitive, Callable] = {} @@ -2728,15 +2984,19 @@ def ctx_factory(): from jax.experimental.key_reuse._core import check_key_reuse_jaxpr # pytype: disable=import-error check_key_reuse_jaxpr(jaxpr) +# A place to track the quasi-dynamic data associated with a variable during typechecking +@dataclass(frozen=True) +class MutableTypecheckVal: + aval : AbstractValue + mutable_qdd : MutableQuasiDynamicData def _check_jaxpr( ctx_factory: Callable[[], tuple[JaxprPpContext, JaxprPpSettings]], jaxpr: Jaxpr ) -> None: - # Use set of variables to types to check that variables are in scope. 
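aval_mismatch_extra above turns the individual dtype/shape/vma mismatches into a suffix for type-mismatch error messages. A standalone sketch of just the clause-joining behaviour (illustrative only, not part of the change):

def join_mismatches(mismatches):
    # '', a single clause, or several clauses joined with a trailing 'and'.
    if not mismatches:
        return ''
    if len(mismatches) == 1:
        return ', so ' + mismatches[0]
    return ', so ' + ', '.join(mismatches[:-1]) + ', and ' + mismatches[-1]

print(join_mismatches(['the dtypes do not match']))
# -> ', so the dtypes do not match'
print(join_mismatches(['the dtypes do not match', 'the shapes do not match']))
# -> ', so the dtypes do not match, and the shapes do not match'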
- env: set[Var] = set() + env: dict[Var, Atom | MutableTypecheckVal] = {} - def read(x: Atom) -> Atom: + def read(x: Atom) -> Atom | MutableTypecheckVal: # Check the type annotation is itself well-typed. check_type(ctx_factory, env, x.aval) if isinstance(x, Var): @@ -2744,7 +3004,7 @@ def read(x: Atom) -> Atom: if x not in env: ctx, _ = ctx_factory() raise JaxprTypeError(f"Variable '{pp_var(x, ctx)}' not defined") - return x + return env[x] elif isinstance(x, Literal): # Check that the literal matches its type annotation. if not typecheck(x.aval, x.val): @@ -2756,7 +3016,8 @@ def read(x: Atom) -> Atom: else: assert False, "syntactically invalid jaxpr" - def write(v: Var, a: AbstractValue) -> None: + def write(v: Var, a: AvalQDD) -> None: + aval, qdd = a.aval, a.qdd assert isinstance(v, Var), "syntactically invalid jaxpr" # Check the type annotation of the binder is itself well-typed. check_type(ctx_factory, env, v.aval) @@ -2765,19 +3026,23 @@ def write(v: Var, a: AbstractValue) -> None: ctx, _ = ctx_factory() raise JaxprTypeError(f"Variable '{pp_var(v, ctx)}' already bound") # Check that the computed type is consistent with the binder annotation. - if not typematch(v.aval, a): + if not typematch(v.aval, aval): ctx, _ = ctx_factory() raise JaxprTypeError( f"Value for variable '{pp_var(v, ctx)}' inconsistently typed " - f"as {pp_aval(a, ctx)} for let-binder of type {pp_aval(v.aval, ctx)}") + f"as {pp_aval(aval, ctx)} for let-binder of type {pp_aval(v.aval, ctx)}") + # If the variable is not a DropVar, add it to the environment. if not isinstance(v, DropVar): - env.add(v) + if qdd is None: + env[v] = v + else: + env[v] = MutableTypecheckVal(aval, MutableQuasiDynamicData(qdd)) # Check type annotations on lambda binders. for v in it.chain(jaxpr.constvars, jaxpr.invars): check_type(ctx_factory, env, v.aval) - write(v, v.aval) + write(v, AvalQDD(v.aval, v.initial_qdd)) # Check each eqn. sentinel = object() @@ -2787,7 +3052,8 @@ def write(v: Var, a: AbstractValue) -> None: prim = eqn.primitive try: in_atoms = map(read, eqn.invars) - in_avals = [x.aval for x in in_atoms] # use in_atoms for dyn shapes + in_avals = [AvalMutableQDD(x.aval, x.mutable_qdd) if isinstance(x, MutableTypecheckVal) + else x.aval for x in in_atoms] # use in_atoms for dyn shapes # Compute the type of the primitive application. with eqn.ctx.manager: @@ -2837,6 +3103,7 @@ def write(v: Var, a: AbstractValue) -> None: # Check out_type matches the let-binders' annotation (after substitution). out_type = substitute_vars_in_output_ty(out_type, eqn.invars, eqn.outvars) + out_type = [t if isinstance(t, AvalQDD) else AvalQDD(t, None) for t in out_type] foreach(write, eqn.outvars, out_type) except JaxprTypeError as e: @@ -2852,7 +3119,7 @@ def write(v: Var, a: AbstractValue) -> None: def check_type( ctx_factory: Callable[[], tuple[JaxprPpContext, JaxprPpSettings]], - env: set[Var], + env: dict[Var, Atom | MutableTypecheckVal], ty: AbstractValue, ) -> None: if isinstance(ty, DShapedArray): @@ -2922,7 +3189,7 @@ def _check_call(ctx_factory, prim, in_atoms, params): f"{len(call_jaxpr.invars)} inputs") # Check `call_jaxpr` can be applied to in_atoms. 
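The _check_call hunk below builds an environment mapping the callee jaxpr's invars to the caller's atoms (or literal values) and uses it to substitute dimension variables into DShapedArray shapes before comparing types. The substitution itself reduces to a dictionary lookup with pass-through, sketched here standalone (illustrative only):

def substitute_shape(shape, env):
    # Dimensions bound in the environment are replaced; everything else passes through.
    return tuple(env.get(d, d) for d in shape)

print(substitute_shape(('n', 4), {'n': 8}))  # (8, 4)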
- env: dict[Var, Atom] = {} + env: dict[Var, Atom | MutableTypecheckVal] = {} def substitute(aval: AbstractValue): if isinstance(aval, DShapedArray): aval = aval.update(shape=tuple(env.get(d, d) for d in aval.shape)) # type: ignore @@ -2933,7 +3200,7 @@ def substitute(aval: AbstractValue): raise JaxprTypeError(f"Call primitive {prim} passes operand {x} of type " f"{x.aval} to jaxpr expecting type " f"{substitute(v.aval)}") - env[v] = x if type(x) is Var else x.val + env[v] = x.val if type(x) is Literal else x _check_jaxpr(ctx_factory, call_jaxpr) @@ -3097,9 +3364,9 @@ def suggest_same_var_names(self, self.var_names[for_v] = pp_var(like_v, self) -def pp_var(v: Var | Literal, context: JaxprPpContext) -> str: - if isinstance(v, (Literal, DropVar)): return str(v) - return f"{context.var_names[v]}{v.suffix}" +def pp_var(v: Var | Literal, context: JaxprPpContext, *, + print_literal_dtype: bool = True) -> str: + return v.pretty_print(context, print_dtype=print_literal_dtype) def pp_aval(a: AbstractValue, context: JaxprPpContext) -> str: if isinstance(a, DShapedArray): @@ -3139,13 +3406,13 @@ def pp_kv_pair(k:str, v: Any, context: JaxprPpContext, settings: JaxprPpSettings def pp_kv_pairs(kv_pairs, context: JaxprPpContext, settings: JaxprPpSettings) -> pp.Doc: if not kv_pairs: return pp.nil() - return pp.group( + return pp.group(pp.concat([ pp.nest(2, pp.concat([ pp.text("["), pp.brk(""), pp.join(pp.brk(), [pp_kv_pair(k, v, context, settings) for k, v in kv_pairs]) - ])) - + pp.brk("") + pp.text("]") - ) + ])), + pp.brk(""), pp.text("]") + ])) def pp_eqn(eqn: JaxprEqn, context: JaxprPpContext, settings: JaxprPpSettings ) -> pp.Doc: @@ -3166,7 +3433,7 @@ def _pp_eqn(eqn: JaxprEqn, context: JaxprPpContext, settings: JaxprPpSettings, rhs = [pp.text(eqn.primitive.name, annotation=name_stack_annotation), pp_kv_pairs([(p, eqn.params[p]) for p in params], context, settings), pp.text(" ") + pp_vars(eqn.invars, context)] - if lhs.format(): + if eqn.outvars: return pp.concat([lhs, pp.text(" = ", annotation=annotation), *rhs]) else: return pp.concat(rhs) @@ -3261,10 +3528,10 @@ def pp_jaxpr( def pp_jaxprs(jaxprs: Sequence[ClosedJaxpr | Jaxpr], context: JaxprPpContext, settings: JaxprPpSettings) -> pp.Doc: jaxprs = [j.jaxpr if isinstance(j, ClosedJaxpr) else j for j in jaxprs] - return pp.group(pp.nest(2, pp.concat([ + return pp.group(pp.concat([pp.nest(2, pp.concat([ pp.text('('), pp.brk(""), pp.join(pp.brk(), map(lambda x: pp_jaxpr(x, context, settings), jaxprs))] - )) + pp.brk("") + pp.text(')') + )), pp.brk(""), pp.text(')')]) ) @@ -3330,7 +3597,7 @@ def __eq__(self, other): else: return False -def get_opaque_trace_state(convention): +def get_opaque_trace_state(convention=None): del convention return OpaqueTraceState(trace_ctx.trace._weakref) diff --git a/jax/_src/cudnn/fused_attention_stablehlo.py b/jax/_src/cudnn/fused_attention_stablehlo.py index c7e7c83f30f8..f246ff4b7aa7 100644 --- a/jax/_src/cudnn/fused_attention_stablehlo.py +++ b/jax/_src/cudnn/fused_attention_stablehlo.py @@ -24,12 +24,11 @@ from jax._src import dispatch from jax._src.custom_partitioning import custom_partitioning from jax._src.interpreters import batching +from jax._src.interpreters import mlir from jax._src.lib import cuda_versions from jax._src import xla_bridge -from jax.interpreters import mlir -from jax.interpreters import xla -from jax.interpreters.mlir import hlo -from jax.interpreters.mlir import ir +from jax._src.lib.mlir import ir +from jax._src.lib.mlir.dialects import hlo import jax.numpy as jnp from jax.sharding 
import NamedSharding, PartitionSpec @@ -122,6 +121,9 @@ def default_layouts(*shapes): def get_max_seg_per_batch(q_offsets): return q_offsets.shape[1] - 1 if len(q_offsets.shape) == 2 else 1 +def check_is_paged_attention(page_table_k): + return len(page_table_k.shape) == 4 + def create_dot_product_attention_backend_config_base( batch, num_heads, seq_q, seq_kv, dtype, fmha_scale, mask_type, layout, is_bwd ): @@ -229,6 +231,7 @@ def create_dot_product_attention_backend_config( layout, sliding_window_length, max_seg_per_batch, + is_paged_attention, is_bwd ): backend_config = create_dot_product_attention_backend_config_base( @@ -241,6 +244,7 @@ def create_dot_product_attention_backend_config( backend_config['cudnn_fmha_backend_config']["seed"] = seed backend_config['cudnn_fmha_backend_config']["sliding_window_length"] = sliding_window_length backend_config['cudnn_fmha_backend_config']["max_seg_per_batch"] = max_seg_per_batch + backend_config['cudnn_fmha_backend_config']["is_paged_attention"] = is_paged_attention return json.dumps(backend_config) def create_dot_product_attention_fp8_backend_config( @@ -273,7 +277,7 @@ def get_custom_call_name(has_bias, has_dropout, is_bwd, is_fp8=False): ) def check_layout(query, key, value, bias, q_seqlen, kv_seqlen, - q_offsets, kv_offsets, layout): + q_offsets, kv_offsets, page_table_k, page_table_v, layout): def check_eq(a, b, c, msg): if not (a == b == c): raise ValueError(f"{msg} must be same, got {a}, {b}, {b}") @@ -298,8 +302,25 @@ def check_eq(a, b, c, msg): kB, kS, kN, kH = key.shape vB, vS, vN, vH = value.shape + if page_table_k is not None and page_table_v is not None: + k_blocks, k_block_size = kB, kS + v_blocks, v_block_size = vB, vS + kB, _, k_blocks_per_batch, _ = page_table_k.shape + vB, _, v_blocks_per_batch, _ = page_table_v.shape + kS = k_blocks_per_batch * k_block_size + vS = v_blocks_per_batch * v_block_size + if kB * k_blocks_per_batch != k_blocks: + raise ValueError( + f"Key and page_table_k must have same number of blocks, " + f"got {k_blocks} vs {kB * k_blocks_per_batch}") + if vB * v_blocks_per_batch != v_blocks: + raise ValueError( + f"Value and page_table_v must have same number of blocks, " + f"got {v_blocks} vs {vB * v_blocks_per_batch}") + check_eq(qB, kB, vB, "QKV batch") - check_eq(qH, kH, vH, "QKV dim_per_head") + if qH != kH: + raise ValueError(f"QK must have same head dim, got {qH} vs {kH}") if kN != vN: raise ValueError(f"KV must have same number of heads, got {kN} vs {vN}") if kS != vS: @@ -333,33 +354,35 @@ def check_seqlen_offsets(tensor, name): def check_is_flash_attention( - query, key, layout: int, cudnn_version, has_bias, is_training, is_packed=False, - is_fp8=False): + query, key, value, layout: int, cudnn_version, has_bias, is_training, + is_packed=False, is_paged_attention=False, is_fp8=False): # Extract sequence length (T) and head dim (H) based on layout if layout == AttentionLayout.BNTH.value: - _, _, T, H = query.shape - _, _, S, _ = key.shape + _, _, T, qH = query.shape + _, _, S, vH = value.shape else: - _, T, _, H = query.shape - _, S, _, _ = key.shape + _, T, _, qH = query.shape + _, S, _, vH = value.shape # Flash attention conditions if is_fp8: # FP8 specific conditions - if not ((is_training and H == 128 and T % 128 == 0 and S % 128 == 0) or - (not is_training and H <= 256 and H % 16 == 0)): + if not ((is_training and qH == 128 and T % 128 == 0 and S % 128 == 0) or + (not is_training and qH <= 256 and qH % 16 == 0)): raise NotImplementedError( - f"Unsupported sequence length Q {T}, KV {S} and head dim {H} 
for FP8." + f"Unsupported sequence length Q {T}, KV {S} and head dim {qH} for FP8." ) else: # bf16/fp16 attention conditions # Check the head dim. is_on_hopper = is_cuda_compute_capability_equal("9.0") H_max = 256 if cudnn_version >= 90500 and is_on_hopper else 128 - if not (H <= H_max and H % 8 == 0): + # check if multi-head latent attention is needed + is_mla = qH != vH + if not (qH <= H_max and qH % 8 == 0): raise NotImplementedError( - f"The head dim must be <= {H_max} and a mutiple of 8, " - f"but got {H}." + f"The head dim must be <= {H_max} and a multiple of 8, " + f"but got {qH}." ) # Check patterns with bias, seqlen should be divisible by 2 @@ -368,8 +391,14 @@ def check_is_flash_attention( f"Unsupported sequence length Q {T}, KV {S}." ) - if is_packed and cudnn_version < 90600: - raise NotImplementedError("Packed layout requires cudnn version >= 9.6.") + if is_packed and (cudnn_version < 90600 or not check_compute_capability("9.0")): + raise NotImplementedError( + "Packed layout requires cudnn version >= 9.6 and at least hopper arch.") + if is_paged_attention and cudnn_version < 90500: + raise NotImplementedError("Page attention requires cudnn version >= 9.5.") + if is_mla and (cudnn_version < 91000 or not check_compute_capability("9.0")): + raise NotImplementedError( + "mla requires cudnn version >= 9.10 and at least hopper arch.") def check_cudnn_version(): # check if cuDNN is installed @@ -395,50 +424,59 @@ def is_cuda_compute_capability_equal(capability): def _dot_product_attention_fwd( query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets, + page_table_k, page_table_v, scale, seed, dropout_rate, variadic_args, mask_type, layout, - sliding_window_length, cudnn_version): + sliding_window_length, cudnn_version, return_residual): # check if flash attention is supported for this attention pattern check_is_flash_attention( - query, key, layout, cudnn_version, bias is not None, False, - get_max_seg_per_batch(q_offsets) > 1) + query, key, value, layout, cudnn_version, bias is not None, False, + get_max_seg_per_batch(q_offsets) > 1, check_is_paged_attention(page_table_k)) outputs = _dot_product_attention_fwd_p_wrapper.bind( query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets, - scale=scale, seed=seed, dropout_rate=dropout_rate, + page_table_k, page_table_v, scale=scale, seed=seed, dropout_rate=dropout_rate, variadic_args=variadic_args, mask_type=mask_type, layout=layout, - sliding_window_length=sliding_window_length, is_training=False) - output = outputs[0] - return output + sliding_window_length=sliding_window_length, is_training=False or return_residual) + if return_residual: + return tuple(outputs) + else: + return outputs[0] def _dot_product_attention_fwd_rule( query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets, - scale, seed, dropout_rate, variadic_args, mask_type, layout, - sliding_window_length, cudnn_version): + page_table_k, page_table_v, scale, seed, dropout_rate, variadic_args, + mask_type, layout, sliding_window_length, cudnn_version, + return_residual): # check if flash attention is supported for this attention pattern check_is_flash_attention( - query, key, layout, cudnn_version, bias is not None, True, + query, key, value, layout, cudnn_version, bias is not None, True, get_max_seg_per_batch(q_offsets) > 1) outputs = _dot_product_attention_fwd_p_wrapper.bind( query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets, - scale=scale, seed=seed, dropout_rate=dropout_rate, + page_table_k, page_table_v, scale=scale, 
seed=seed, dropout_rate=dropout_rate, variadic_args=variadic_args, mask_type=mask_type, layout=layout, sliding_window_length=sliding_window_length, is_training=True) res = (query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, - kv_offsets, outputs[1], outputs[0]) - return outputs[0], res + kv_offsets, page_table_k, page_table_v, outputs[1], outputs[0]) + if return_residual: + return tuple(outputs), res + else: + return outputs[0], res def _dot_product_attention_bwd_rule( scale, seed, dropout_rate, variadic_args, mask_type, layout, - sliding_window_length, is_training, res, grad_output): + sliding_window_length, is_training, return_residual, res, grad_output): (query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets, - activation, fwd_output) = res + page_table_k, page_table_v, activation, fwd_output) = res + if return_residual: + grad_output = grad_output[0] grads = _dot_product_attention_bwd_p_wrapper.bind( query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets, - activation, fwd_output, grad_output, scale=scale, seed=seed, - dropout_rate=dropout_rate, variadic_args=variadic_args, + page_table_k, page_table_v, activation, fwd_output, grad_output, + scale=scale, seed=seed, dropout_rate=dropout_rate, variadic_args=variadic_args, mask_type=mask_type, layout=layout, sliding_window_length=sliding_window_length ) - grads = (*grads,) + (None,) * (8 - len(grads)) + grads = (*grads,) + (None,) * (10 - len(grads)) return grads def _fix_seqlen_offsets(q_seqlen, kv_seqlen, q_offsets, kv_offsets, query, key): @@ -471,7 +509,7 @@ def _cu_offset(offsets, max_seq): batch = offsets.shape[0] offsets = jnp.where( offsets >= 0, - offsets + (jnp.arange(batch) * max_seq)[..., jnp.newaxis], + offsets + (jnp.arange(batch, dtype=offsets.dtype) * max_seq)[..., jnp.newaxis], offsets, ) return offsets @@ -501,27 +539,28 @@ def _cu_offset(offsets, max_seq): def _dot_product_attention_fwd_impl( query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets, - scale, seed, dropout_rate, variadic_args, mask_type, layout, - sliding_window_length, is_training): + page_table_k, page_table_v, scale, seed, dropout_rate, variadic_args, + mask_type, layout, sliding_window_length, is_training): # args: {Q, K, V, mask*, bias*} q_seqlen, kv_seqlen, q_offsets, kv_offsets = \ _fix_seqlen_offsets(q_seqlen, kv_seqlen, q_offsets, kv_offsets, query, key) outputs = _dot_product_attention_fwd_p.bind( query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets, - scale=scale, seed=seed, dropout_rate=dropout_rate, + page_table_k, page_table_v, scale=scale, seed=seed, dropout_rate=dropout_rate, variadic_args=variadic_args, mask_type=mask_type, layout=layout, sliding_window_length=sliding_window_length, is_training=is_training) return outputs def _dot_product_attention_bwd_impl( query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets, - activation, fwd_output, grad_output, scale, seed, dropout_rate, - variadic_args, mask_type, layout, sliding_window_length): + page_table_k, page_table_v, activation, fwd_output, grad_output, scale, + seed, dropout_rate, variadic_args, mask_type, layout, sliding_window_length): q_seqlen, kv_seqlen, q_offsets, kv_offsets = \ _fix_seqlen_offsets(q_seqlen, kv_seqlen, q_offsets, kv_offsets, query, key) grads = _dot_product_attention_bwd_p.bind( query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets, - activation, fwd_output, grad_output, scale=scale, seed=seed, + page_table_k, page_table_v, activation, fwd_output, grad_output, + scale=scale, 
seed=seed, dropout_rate=dropout_rate, variadic_args=variadic_args, mask_type=mask_type, layout=layout, sliding_window_length=sliding_window_length) @@ -529,16 +568,17 @@ def _dot_product_attention_bwd_impl( def _dot_product_attention_fwd_abstract( query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets, - *, scale, seed, dropout_rate, variadic_args, mask_type, layout, - sliding_window_length, is_training): + page_table_k, page_table_v, *, scale, seed, dropout_rate, variadic_args, + mask_type, layout, sliding_window_length, is_training): query_dtype = dtypes.canonicalize_dtype(query.dtype) if layout == AttentionLayout.BNTH.value: B, N, T, _ = query.shape - _, _, S, _ = key.shape + _, _, S, H = value.shape + output_shape = (B, N, T, H) else: B, T, N, _ = query.shape - _, S, _, _ = key.shape - output_shape = query.shape + _, S, _, H = value.shape + output_shape = (B, T, N, H) max_seg_per_batch = get_max_seg_per_batch(q_offsets) softmax_stat_shape = (B * max_seg_per_batch, N, T) @@ -555,8 +595,8 @@ def _dot_product_attention_fwd_abstract( def _dot_product_attention_bwd_abstract( query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets, - activation, fwd_output, grad_output, *, scale, seed, dropout_rate, - variadic_args, mask_type, layout, sliding_window_length): + page_table_k, page_table_v, activation, fwd_output, grad_output, *, + scale, seed, dropout_rate, variadic_args, mask_type, layout, sliding_window_length): query_dtype = dtypes.canonicalize_dtype(query.dtype) key_dtype = dtypes.canonicalize_dtype(key.dtype) value_dtype = dtypes.canonicalize_dtype(value.dtype) @@ -594,26 +634,28 @@ def _dot_product_attention_bwd_abstract( def _dot_product_attention_fwd_cuda_lowering( ctx, query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, - kv_offsets, scale, seed, dropout_rate, variadic_args, mask_type, - layout, sliding_window_length, is_training): + kv_offsets, page_table_k, page_table_v, scale, seed, dropout_rate, + variadic_args, mask_type, layout, sliding_window_length, is_training): query_type = ir.RankedTensorType(query.type) query_shape = query_type.shape - key_type = ir.RankedTensorType(key.type) - key_shape = key_type.shape + value_type = ir.RankedTensorType(value.type) + value_shape = value_type.shape if layout == AttentionLayout.BNTH.value: - B, N, T, H = query_shape - _, _, S, _ = key_shape + B, N, T, qk_H = query_shape + _, _, S, v_H = value_shape output_layout = (3, 2, 1, 0) output_transpose_perm = mlir.dense_int_array((0, 1, 2, 3)) else: - B, T, N, H = query_shape - _, S, _, _ = key_shape + B, T, N, qk_H = query_shape + _, S, _, v_H = value_shape output_layout = (3, 1, 2, 0) output_transpose_perm = mlir.dense_int_array((0, 2, 1, 3)) max_seg_per_batch = get_max_seg_per_batch(ir.RankedTensorType(q_offsets.type)) - output_shape = (B, N, T, H) + is_paged_attention = check_is_paged_attention(ir.RankedTensorType(page_table_k.type)) + + output_shape = (B, N, T, v_H) softmax_stat_shape = (B * max_seg_per_batch, N, T) workspace_shape = (0,) workspace_type = ir.IntegerType.get_unsigned(8) @@ -622,19 +664,22 @@ def _dot_product_attention_fwd_cuda_lowering( backend_config = create_dot_product_attention_backend_config( B, N, T, S, query_type.element_type, scale, seed, dropout_rate, mask_type, layout, sliding_window_length, max_seg_per_batch, - is_bwd=False) + is_paged_attention, is_bwd=False) # {Q, K, V, bias*, q_seqlen*, kv_seqlen*, q_offsets*, kv_offsets*}} # {output, activation*, workspace} has_dropout = dropout_rate > 0 operands = [query, key, value] if has_bias: 
operands.append(bias) - if has_padding(mask_type) or max_seg_per_batch > 1: + if has_padding(mask_type) or max_seg_per_batch > 1 or is_paged_attention: operands.append(q_seqlen) operands.append(kv_seqlen) if max_seg_per_batch > 1: operands.append(q_offsets) operands.append(kv_offsets) + if is_paged_attention: + operands.append(page_table_k) + operands.append(page_table_v) custom_call_name = get_custom_call_name(has_bias, has_dropout, False) @@ -670,38 +715,38 @@ def _dot_product_attention_fwd_cuda_lowering( def _dot_product_attention_bwd_cuda_lowering( ctx, query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets, - activation, fwd_output, grad_output, scale, seed, dropout_rate, - variadic_args, mask_type, layout, sliding_window_length): + page_table_k, page_table_v, activation, fwd_output, grad_output, + scale, seed, dropout_rate, variadic_args, mask_type, layout, sliding_window_length): query_type = ir.RankedTensorType(query.type) query_shape = query_type.shape key_type = ir.RankedTensorType(key.type) - key_shape = key_type.shape value_type = ir.RankedTensorType(value.type) + value_shape = value_type.shape if layout == AttentionLayout.BNTH.value: - B, q_N, T, H = query_shape - _, k_N, S, _ = key_shape + B, q_N, T, qk_H = query_shape + _, v_N, S, v_H = value_shape grad_layout = (3, 2, 1, 0) grad_transpose_perm = mlir.dense_int_array((0, 1, 2, 3)) else: - B, T, q_N, H = query_shape - _, S, k_N, _ = key_shape + B, T, q_N, qk_H = query_shape + _, S, v_N, v_H = value_shape grad_layout = (3, 1, 2, 0) grad_transpose_perm = mlir.dense_int_array((0, 2, 1, 3)) workspace_shape = (0,) workspace_type = ir.IntegerType.get_unsigned(8) - grad_query_shape = (B, q_N, T, H) - grad_key_shape = (B, k_N, S, H) - grad_value_shape = (B, k_N, S, H) + grad_query_shape = (B, q_N, T, qk_H) + grad_key_shape = (B, v_N, S, qk_H) + grad_value_shape = (B, v_N, S, v_H) has_bias, has_dbias = variadic_args max_seg_per_batch = get_max_seg_per_batch(ir.RankedTensorType(q_offsets.type)) backend_config = create_dot_product_attention_backend_config( B, q_N, T, S, query_type.element_type, scale, seed, dropout_rate, mask_type, layout, sliding_window_length, max_seg_per_batch, - is_bwd=True) + False, is_bwd=True) # {Q, K, V, activation, dO, bias*, O, q_seqlen*, kv_seqlen*, # q_offsets*, kv_offsets*} # {dQ, dK, dV, dbias*, workspace} @@ -769,7 +814,7 @@ def _dot_product_attention_fwd_batcher( mask_type, layout, sliding_window_length, is_training): _check_valid_batch_dims(batch_dims) query, key, value, bias, q_seqlen, kv_seqlen, \ - q_offsets, kv_offsets = batched_args + q_offsets, kv_offsets, page_table_k, page_table_v = batched_args query_bdim = batch_dims[0] if is_training: out_bdims = query_bdim, query_bdim @@ -797,7 +842,7 @@ def _dot_product_attention_fwd_batcher( outputs = _dot_product_attention_fwd_p_wrapper.bind( query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets, - scale=scale, seed=seed, dropout_rate=dropout_rate, + page_table_k, page_table_v, scale=scale, seed=seed, dropout_rate=dropout_rate, variadic_args=variadic_args, mask_type=mask_type, layout=layout, sliding_window_length=sliding_window_length, is_training=is_training) @@ -816,7 +861,7 @@ def _dot_product_attention_bwd_batcher( mask_type, layout, sliding_window_length): _check_valid_batch_dims(batch_dims) query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets, \ - activation, fwd_output, grad_output = batched_args + page_table_k, page_table_v, activation, fwd_output, grad_output = batched_args query_bdim = batch_dims[0] 
out_bdims = query_bdim, query_bdim, query_bdim @@ -853,8 +898,8 @@ def _dot_product_attention_bwd_batcher( grads = _dot_product_attention_bwd_p_wrapper.bind( query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets, - activation, fwd_output, grad_output, scale=scale, seed=seed, - dropout_rate=dropout_rate, variadic_args=variadic_args, + page_table_k, page_table_v, activation, fwd_output, grad_output, + scale=scale, seed=seed, dropout_rate=dropout_rate, variadic_args=variadic_args, mask_type=mask_type, layout=layout, sliding_window_length=sliding_window_length, ) @@ -929,7 +974,7 @@ def _infer_fwd_output_sharding(mesh, arg_shapes, variadic_args,is_training, layo return [out_sharding] _dot_product_attention_fwd_lower = custom_partitioning( - _dot_product_attention_fwd_impl, static_argnums=(8, 9, 10, 11, 12, 13, 14, 15)) + _dot_product_attention_fwd_impl, static_argnums=(10, 11, 12, 13, 14, 15, 16, 17)) def _dot_product_attention_fwd_infer_sharding_from_operands( scale, seed, dropout_rate, variadic_args, mask_type, layout, sliding_window_length, @@ -978,7 +1023,7 @@ def _infer_bwd_output_sharding(mesh, arg_shapes, layout, variadic_args): return out_shardings _dot_product_attention_bwd_lower = custom_partitioning( - _dot_product_attention_bwd_impl, static_argnums=(11, 12, 13, 14, 15, 16, 17) + _dot_product_attention_bwd_impl, static_argnums=(13, 14, 15, 16, 17, 18, 19) ) def _dot_product_attention_bwd_infer_sharding_from_operands( @@ -1018,7 +1063,7 @@ def sharded_impl(*args): _dot_product_attention_fwd_p = core.Primitive("dot_product_attention_fwd") _dot_product_attention_fwd_p.multiple_results = True _dot_product_attention_fwd_p.def_impl( - functools.partial(xla.apply_primitive, _dot_product_attention_fwd_p) + functools.partial(dispatch.apply_primitive, _dot_product_attention_fwd_p) ) _dot_product_attention_fwd_p.def_abstract_eval( _dot_product_attention_fwd_abstract @@ -1043,7 +1088,7 @@ def sharded_impl(*args): _dot_product_attention_bwd_p = core.Primitive("dot_product_attention_bwd") _dot_product_attention_bwd_p.multiple_results = True _dot_product_attention_bwd_p.def_impl( - functools.partial(xla.apply_primitive, _dot_product_attention_bwd_p) + functools.partial(dispatch.apply_primitive, _dot_product_attention_bwd_p) ) _dot_product_attention_bwd_p.def_abstract_eval( _dot_product_attention_bwd_abstract @@ -1071,16 +1116,21 @@ def sharded_impl(*args): _dot_product_attention_bwd_p_wrapper ] = _dot_product_attention_bwd_batcher +def not_implemented_sharding_rule(*args, **kwargs): + return NotImplementedError("Sharding rule not implemented.") + _dot_product_attention_fwd_lower.def_partition( infer_sharding_from_operands=_dot_product_attention_fwd_infer_sharding_from_operands, - partition=_dot_product_attention_fwd_partition) + partition=_dot_product_attention_fwd_partition, + sharding_rule=not_implemented_sharding_rule) mlir.register_lowering(_dot_product_attention_fwd_p_wrapper, mlir.lower_fun(_dot_product_attention_fwd_lower, multiple_results=True)) _dot_product_attention_bwd_lower.def_partition( infer_sharding_from_operands=_dot_product_attention_bwd_infer_sharding_from_operands, - partition=_dot_product_attention_bwd_partition) + partition=_dot_product_attention_bwd_partition, + sharding_rule=not_implemented_sharding_rule) mlir.register_lowering(_dot_product_attention_bwd_p_wrapper, mlir.lower_fun(_dot_product_attention_bwd_lower, multiple_results=True)) @@ -1098,7 +1148,7 @@ def sharded_impl(*args): _dot_product_attention_bwd_p_wrapper ) -@functools.partial(jax.custom_vjp, 
nondiff_argnums=(8, 9, 10, 11, 12, 13, 14, 15)) +@functools.partial(jax.custom_vjp, nondiff_argnums=(10, 11, 12, 13, 14, 15, 16, 17, 18)) def _dot_product_attention(query: Array, key: Array, value: Array, @@ -1107,6 +1157,8 @@ def _dot_product_attention(query: Array, kv_seqlen: Array, q_offsets: Array, kv_offsets: Array, + page_table_k: Array, + page_table_v: Array, scale: float, seed: int, dropout_rate: float, @@ -1114,13 +1166,14 @@ def _dot_product_attention(query: Array, mask_type: bool, layout: int, sliding_window_length: int | None, - cudnn_version: int): + cudnn_version: int, + return_residual: bool): output = _dot_product_attention_fwd( query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets, - scale=scale, seed=seed, dropout_rate=dropout_rate, + page_table_k, page_table_v, scale=scale, seed=seed, dropout_rate=dropout_rate, variadic_args=variadic_args, mask_type=mask_type, layout=layout, sliding_window_length=sliding_window_length, - cudnn_version=cudnn_version) + cudnn_version=cudnn_version, return_residual=return_residual) return output _dot_product_attention.defvjp( @@ -1161,7 +1214,7 @@ def _dot_product_attention_fp8_fwd( fp8_params_fwd, scale, use_causal_mask, layout, cudnn_version): check_is_flash_attention_fp8( - query, key, layout, cudnn_version, is_training=False) + query, key, value, layout, cudnn_version, is_training=False) descale_q, descale_k, descale_v, descale_s, scale_s, scale_o = fp8_params_fwd outputs = _dot_product_attention_fp8_fwd_p_wrapper.bind( query, key, value, @@ -1175,7 +1228,7 @@ def _dot_product_attention_fp8_fwd_rule( fp8_params, scale, use_causal_mask, layout, cudnn_version): check_is_flash_attention_fp8( - query, key, layout, cudnn_version, is_training=True) + query, key, value, layout, cudnn_version, is_training=True) outputs = _dot_product_attention_fp8_fwd_p_wrapper.bind( query, key, value, *params_from_keys(fp8_params, fp8_params_keys_fwd), @@ -1604,7 +1657,7 @@ def _dot_product_attention_fp8_bwd_partition( _dot_product_attention_fp8_fwd_p = core.Primitive("dot_product_attention_fp8_fwd") _dot_product_attention_fp8_fwd_p.multiple_results = True _dot_product_attention_fp8_fwd_p.def_impl( - functools.partial(xla.apply_primitive, _dot_product_attention_fp8_fwd_p) + functools.partial(dispatch.apply_primitive, _dot_product_attention_fp8_fwd_p) ) _dot_product_attention_fp8_fwd_p.def_abstract_eval( _dot_product_attention_fp8_fwd_abstract @@ -1629,7 +1682,7 @@ def _dot_product_attention_fp8_bwd_partition( _dot_product_attention_fp8_bwd_p = core.Primitive("dot_product_attention_fp8_bwd") _dot_product_attention_fp8_bwd_p.multiple_results = True _dot_product_attention_fp8_bwd_p.def_impl( - functools.partial(xla.apply_primitive, _dot_product_attention_fp8_bwd_p) + functools.partial(dispatch.apply_primitive, _dot_product_attention_fp8_bwd_p) ) _dot_product_attention_fp8_bwd_p.def_abstract_eval( _dot_product_attention_fp8_bwd_abstract @@ -1701,7 +1754,119 @@ def _dot_product_attention_fp8(query: Array, _dot_product_attention_fp8.defvjp(_dot_product_attention_fp8_fwd_rule, _dot_product_attention_fp8_bwd_rule) +def combine_bias_and_mask(bias, mask, dtype): + if bias is not None: + # reshape bias to have 4D shape + bias = bias.reshape((1,) * (4 - len(bias.shape)) + bias.shape) + + if mask is not None: + if mask.dtype == jnp.bool: + large_negative_number = get_large_negative_number(dtype) + mask = jnp.where(mask, jnp.asarray(0, dtype), large_negative_number) + # reshape mask to have 4D shape + mask = mask.reshape((1,) * (4 - len(mask.shape)) + 
mask.shape) # type: ignore[union-attr] + + # combine bias and mask + if bias is None: + bias = mask + else: + if mask is not None: + # should be broadcast to same shape + bias = bias + mask + return bias + # User interface +def paged_attention( + query: Array, + key: Array, + value: Array, + q_seqlen: Array, + kv_seqlen: Array, + page_table_k: Array, + page_table_v: Array, + bias: Array | None = None, + mask: Array | None = None, + fp8_params: FP8Params | None = None, + *, + scale: float = 1.0, + mask_type: MaskType = MaskType.NO_MASK, + seed: int = 42, + dropout_rate: float = 0., + qkv_layout: str = "BTNH", + sliding_window_length: int | None = None, + use_fp8: bool = False, + return_residual: bool = False +): + """Computes paged attention described in https://arxiv.org/pdf/2309.06180. + + B = batch size + S = length of the key/value (source) + T = length of the query (target) + N = number of attention heads + H = dimensions of each attention head. + + Args: + query: Queries for attention calculation with a shape of BTNH or BNTH. + key: Keys for attention calculation with a shape of + [num_blocks, block_size, N, H] or [num_blocks, N, block_size, H] where + num_blocks = B * Ceil(S / block_size). + value: Values to be used in attention with a shape of + [num_blocks, block_size, N, H] or [num_blocks, N, block_size, H] where + num_blocks = B * Ceil(S / block_size). + q_seqlen: Non padded sequence length of query with a shape of B. + kv_seqlen: Non padded sequence length of key and value with a shape of B. + page_table_k: page table for key of shape [B, 1, num_blocks_per_batch, 1] + where num_blocks_per_batch = Ceil(S / block_size). + page_table_v: page table for value of shape [B, 1, num_blocks_per_batch, 1] + where num_blocks_per_batch = Ceil(S / block_size). + bias: Bias to be added to logits with a shape of BNTS. + mask: Mask used to filter out logits with a shape of BNTS. + scale: Scale for the query. + qkv_layout: Layout string, with supported formats being BTNH, BNTH, BSNH, + BNSH. + sliding_window_length: Window size to make attention only attend to each + token's left local window (pos - sliding_window_length, pos] where `pos` + is the index of each token. E.g., if sliding_window_length == 3 and the + sequence is [0, 1, 2, 3, c, 4, 5], token `c` can attend to [4, 5, c]. + use_fp8: Whether to use FP8 attention mechanism. + return_residual: Whether to return the logsumexp tensor of shape BTN + or BNT to users. See section 3.1.1 in the FlashAttention-2 paper: + https://arxiv.org/pdf/2307.08691 to find the definition of logsumexp. + Returns: + output: the same shape as the query. + residual: the logsumexp tensor if return_residual=True. 
(non fp8) + """ + cudnn_version = check_cudnn_version() + layout = _normalize_layout(qkv_layout) + if use_fp8: + raise ValueError("Paged attention doesn't support fp8 for now.") + if has_padding(mask_type) and (q_seqlen is None or kv_seqlen is None): + raise ValueError("Require q_seqlen and kv_seqlen to generate padding mask.") + if sliding_window_length is not None and sliding_window_length <= 0: + raise ValueError( + f"Require sliding_window_length > 0, got {sliding_window_length}.") + + bias = combine_bias_and_mask(bias, mask, query.dtype) + # check if input shape and data type is compatiable + check_layout(query, key, value, bias, q_seqlen, kv_seqlen, None, None, + page_table_k, page_table_v, layout) + has_bias = bias is not None + has_dbias = has_bias and \ + should_export_dbias(bias.shape, query.shape, layout) # type: ignore[union-attr] + variadic_args = (has_bias, has_dbias) + + _not_used = jnp.zeros(0, dtype=query.dtype) + if bias is None: + bias = _not_used + + output = _dot_product_attention( + query, key, value, bias, q_seqlen, kv_seqlen, _not_used, _not_used, + page_table_k, page_table_v, scale, seed, dropout_rate, variadic_args, + mask_type, layout.value, sliding_window_length, cudnn_version, + return_residual) + return output + + def dot_product_attention( query: Array, key: Array, @@ -1720,7 +1885,8 @@ def dot_product_attention( dropout_rate: float = 0., qkv_layout: str = "BTNH", sliding_window_length: int | None = None, - use_fp8: bool = False + use_fp8: bool = False, + return_residual: bool = False ): """Computes dot-product attention given query (Q), key (K), and value (V). @@ -1776,8 +1942,12 @@ def dot_product_attention( is the index of each token. E.g., if sliding_window_length == 3 and the sequence is [0, 1, 2, 3, c, 4, 5], token `c` can attend to [4, 5, c]. use_fp8: Whether to use FP8 attention mechanism. + return_residual: Whether to return the logsumexp tensor of shape BTN + or BNT to users. See section 3.1.1 in the FlashAttention-2 paper: + https://arxiv.org/pdf/2307.08691 to find the definition of logsumexp. Returns: - Output of the same shape as the query. + output: the same shape as the query. + residual: the logsumexp tensor if return_residual=True. (non fp8) amax_s: amax of state. (fp8 only) amax_o: amax of output. 
(fp8 only) """ @@ -1797,7 +1967,8 @@ def dot_product_attention( f"but got: bias={bias}, mask={mask}, q_seqlen={q_seqlen}, kv_seqlen={kv_seqlen}" ) check_fp8_params(fp8_params) - check_layout(query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets, layout) + check_layout(query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets, + None, None, layout) output, amax_s, amax_o = _dot_product_attention_fp8( query, key, value, fp8_params, scale, mask_type == MaskType.CAUSAL, layout.value, cudnn_version @@ -1812,44 +1983,30 @@ def dot_product_attention( if q_offsets is not None and (q_seqlen is None or kv_seqlen is None): raise ValueError("Require q_seqlen and kv_seqlen to use packed layout") - if bias is not None: - # reshape bias to have 4D shape - bias = bias.reshape((1,) * (4 - len(bias.shape)) + bias.shape) - - if mask is not None: - if mask.dtype == jnp.bool: - large_negative_number = get_large_negative_number(query.dtype) - mask = jnp.where(mask, jnp.asarray(0, query.dtype), large_negative_number) - # reshape mask to have 4D shape - mask = mask.reshape((1,) * (4 - len(mask.shape)) + mask.shape) # type: ignore[union-attr] - - # combine bias and mask - if bias is None: - bias = mask - else: - if mask is not None: - # should be broadcast to same shape - bias = bias + mask - + bias = combine_bias_and_mask(bias, mask, query.dtype) # check if input shape and data type is compatiable - check_layout(query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets, layout) + check_layout(query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets, + None, None, layout) has_bias = bias is not None has_dbias = has_bias and \ should_export_dbias(bias.shape, query.shape, layout) # type: ignore[union-attr] variadic_args = (has_bias, has_dbias) + _not_used = jnp.zeros(0, dtype=query.dtype) if bias is None: - bias = jnp.zeros(0, dtype=query.dtype) + bias = _not_used if q_seqlen is None: - q_seqlen = jnp.zeros(0, dtype=query.dtype) + q_seqlen = _not_used if kv_seqlen is None: - kv_seqlen = jnp.zeros(0, dtype=query.dtype) + kv_seqlen = _not_used if q_offsets is None: - q_offsets = jnp.zeros(0, dtype=query.dtype) + q_offsets = _not_used if kv_offsets is None: - kv_offsets = jnp.zeros(0, dtype=query.dtype) + kv_offsets = _not_used + output = _dot_product_attention( query, key, value, bias, q_seqlen, kv_seqlen, q_offsets, kv_offsets, - scale, seed, dropout_rate, variadic_args, mask_type, layout.value, - sliding_window_length, cudnn_version) + _not_used, _not_used, scale, seed, dropout_rate, variadic_args, + mask_type, layout.value, sliding_window_length, cudnn_version, + return_residual) return output diff --git a/jax/_src/cudnn/fusion.py b/jax/_src/cudnn/fusion.py index f320672463cb..355b33e1509c 100644 --- a/jax/_src/cudnn/fusion.py +++ b/jax/_src/cudnn/fusion.py @@ -16,8 +16,8 @@ import jax from jax._src import core as jax_core from jax.interpreters import mlir -from jax.interpreters.mlir import hlo -from jax.interpreters.mlir import ir +from jax._src.lib.mlir import ir +from jax._src.lib.mlir.dialects import hlo diff --git a/jax/_src/cudnn/scaled_matmul_stablehlo.py b/jax/_src/cudnn/scaled_matmul_stablehlo.py index 1a8dee293082..49238f09c46a 100644 --- a/jax/_src/cudnn/scaled_matmul_stablehlo.py +++ b/jax/_src/cudnn/scaled_matmul_stablehlo.py @@ -16,7 +16,6 @@ import json import operator from functools import partial, reduce -from typing import List # Third-party imports import jax @@ -28,7 +27,7 @@ from jax._src.interpreters import batching from jax._src.lax.lax import 
ranges_like, remaining from jax._src.typing import DTypeLike -from jax.interpreters import mlir, xla +from jax._src.interpreters import mlir from jax.interpreters.mlir import ir from jax.sharding import NamedSharding from jax.sharding import PartitionSpec as P @@ -112,7 +111,7 @@ def _scaled_matmul_abstract(a, b, a_scale, b_scale, *, preferred_element_type): _scaled_matmul_p = core.Primitive("scaled_matmul") _scaled_matmul_p.multiple_results = True -_scaled_matmul_p.def_impl(partial(xla.apply_primitive, _scaled_matmul_p)) +dispatch.simple_impl(_scaled_matmul_p) _scaled_matmul_p.def_abstract_eval(_scaled_matmul_abstract) @@ -492,13 +491,14 @@ def quantize(x, config): elif config.mode == "nvfp4": assert config.scale_type == jnp.float8_e4m3fn assert config.global_scale.dtype == jnp.float32 + SCALE_MAX = jnp.finfo(config.scale_type).max.astype(x.dtype) - scales = scales / config.global_scale - scales_q = jax.lax.optimization_barrier(scales.astype(jnp.float8_e4m3fn)) - scaled_x = x / (scales_q.astype(jnp.float32) * - config.global_scale).astype(x.dtype) + scales_q = jnp.clip(scales / config.global_scale, 0, SCALE_MAX) + scales_q = jax.lax.optimization_barrier(scales_q.astype(config.scale_type)) + scaled_x = x / scales_q.astype(jnp.float32) else: raise ValueError(f"Unrecognized mode: {config.mode}.") + clipped_x = jnp.clip(scaled_x, -MAX, MAX) x_q = clipped_x.astype(config.data_type) @@ -590,7 +590,7 @@ def scaled_dot_general_transpose_lhs( def scaled_dot_general_transpose_rhs( g, x, y, *, dimension_numbers, preferred_element_type: DTypeLike, - configs: List[BlockScaleConfig] + configs: list[BlockScaleConfig] ): (x_contract, y_contract), (x_batch, y_batch) = dimension_numbers swapped_dimension_numbers = ((y_contract, x_contract), (y_batch, x_batch)) @@ -639,6 +639,17 @@ def scaled_dot_bwd(dimension_numbers, preferred_element_type, configs, res, g): } grad_lhs = scaled_dot_general_transpose_lhs(*args, **lhs_kw_args) grad_rhs = scaled_dot_general_transpose_rhs(*args, **rhs_kw_args) + + # We apply a Straight-Through Estimator (STE) with zero-out behavior: if + # inputs are clipped during quantization in fprop, their corresponding gradients + # are zeroed out; otherwise, they pass through unchanged. 
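The comment above describes a zero-out straight-through estimator for the nvfp4 path: where the forward quantization clipped an input, its gradient is dropped; elsewhere the gradient passes through unchanged. A minimal numeric sketch of that rule (illustrative only; the real code derives the clip bound from the data type, scale type and global scale):

import jax.numpy as jnp

def ste_zero_out(x, grad, clip_max):
    # Pass gradients through inside the representable range, zero them where
    # the forward pass would have clipped.
    return jnp.where(jnp.abs(x) <= clip_max, grad, 0)

x = jnp.array([0.5, -3.0, 10.0])
print(ste_zero_out(x, jnp.ones_like(x), clip_max=4.0))  # [1. 1. 0.]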
+ if configs[2].mode == "nvfp4": + assert rhs.dtype == lhs.dtype + MAX = jnp.finfo(configs[0].data_type).max.astype(lhs.dtype) + SCALE_MAX = jnp.finfo(configs[0].scale_type).max.astype(lhs.dtype) + grad_lhs = jnp.where(jnp.abs(lhs) <= configs[0].global_scale * MAX * SCALE_MAX, grad_lhs, 0) + grad_rhs = jnp.where(jnp.abs(rhs) <= configs[1].global_scale * MAX * SCALE_MAX, grad_rhs, 0) + return (grad_lhs, grad_rhs) @@ -674,7 +685,7 @@ def _ensure_batch_dim(lhs, rhs, dimension_numbers): def scaled_dot_general_wrapper( lhs, rhs, dimension_numbers, preferred_element_type=jnp.float32, - configs: List[BlockScaleConfig] | None=None, + configs: list[BlockScaleConfig] | None=None, ): if preferred_element_type not in (jnp.float32, jnp.bfloat16, jnp.float16): msg = ('Only support preferred_element_type in (f32, bf16, f16), but got ' diff --git a/jax/_src/custom_batching.py b/jax/_src/custom_batching.py index 338074837ea5..a8876cd9c86c 100644 --- a/jax/_src/custom_batching.py +++ b/jax/_src/custom_batching.py @@ -19,7 +19,6 @@ import functools import operator -from jax import lax from jax._src import api from jax._src import core from jax._src import custom_api_util @@ -103,7 +102,7 @@ class custom_vmap: >>> jax.grad(f)(jnp.zeros(()), jnp.ones(())) Array(1., dtype=float32) - Note that the :py:class:`jax.custom_vjp` must be on the ouside, wrapping the + Note that the :py:class:`jax.custom_vjp` must be on the outside, wrapping the ``custom_vmap``-decorated function. """ @@ -394,6 +393,8 @@ def sequential_vmap(f): See the documentation for :py:class:`~jax.custom_batching.custom_vmap` for more details. """ + from jax._src.lax import control_flow # pytype: disable=import-error + f = custom_vmap(f) @f.def_vmap @@ -405,7 +406,7 @@ def to_map(mapped_args): return f(*args) mapped_args, bcast_args = tree_split(in_batched, list(args)) - out = lax.map(to_map, mapped_args) + out = control_flow.map(to_map, mapped_args) out_batched = tree_map(lambda _: True, out) return out, out_batched diff --git a/jax/_src/custom_dce.py b/jax/_src/custom_dce.py index d336c969a3c4..25fe604085fd 100644 --- a/jax/_src/custom_dce.py +++ b/jax/_src/custom_dce.py @@ -251,9 +251,9 @@ def flatten_dce_rule( # For error checking purposes, we need to reformat the pytree structure # of the output of the DCE rule to match the original output. The catch is # that the DCE rule can return a None to indicated an unused subtree, so we - # need to rebuild those subtrees with a sentinal value at the leaves. This + # need to rebuild those subtrees with a sentinel value at the leaves. This # logic is very similar to what is used in custom_dervatives._flatten_bwd. 
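For context on the custom_batching change above (routing the mapped call through jax._src.lax.control_flow.map instead of the public jax.lax alias), the user-facing behaviour of sequential_vmap is unchanged: a decorated function is vmapped by scanning over the batch rather than by true batching. A hypothetical usage sketch, assuming the public jax.custom_batching.sequential_vmap entry point:

import jax
import jax.numpy as jnp
from jax import custom_batching

@custom_batching.sequential_vmap
def add_one(x):
    # Batched applications run as a sequential map over the batch dimension.
    return x + 1.0

print(jax.vmap(add_one)(jnp.arange(3.0)))  # [1. 2. 3.]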
- sentinal = object() + sentinel = object() dummy = tree_util.tree_unflatten(out_tree, [object()] * out_tree.num_leaves) keypaths, _ = util.unzip2(tree_util.tree_flatten_with_path(dummy)[0]) out_flat = [] @@ -261,7 +261,7 @@ def flatten_dce_rule( def append(x, d): num_leaves = len(tree_util.tree_flatten(d)[0]) if x is None and d is not None: - out_flat.extend([sentinal] * num_leaves) + out_flat.extend([sentinel] * num_leaves) elif x is not None: out_flat.extend([x] * num_leaves) return x @@ -281,7 +281,7 @@ def append(x, d): for kp, used, aval, val in zip(keypaths, used_outs, out_avals, out_flat): if not used: continue - if val is sentinal: + if val is sentinel: raise ValueError( f"Custom DCE rule {rule_name} for function {fun_name} must produce " "values for all of the required outputs (as specified by the " diff --git a/jax/_src/custom_derivatives.py b/jax/_src/custom_derivatives.py index 32856106ad8f..e4fc919e1bb1 100644 --- a/jax/_src/custom_derivatives.py +++ b/jax/_src/custom_derivatives.py @@ -31,7 +31,8 @@ stop_gradient_p, SymbolicZero, Zero, zeros_like_aval) from jax._src.api_util import ( argnums_partial, flatten_fun_nokwargs, resolve_kwargs, - prepend_static_args, debug_info) + prepend_static_args, debug_info, fun_signature, + infer_argnums_and_argnames) from jax._src.errors import UnexpectedTracerError from jax._src.state.types import AbstractRef from jax._src.interpreters import ad @@ -40,11 +41,10 @@ from jax._src.interpreters import partial_eval as pe from jax._src.interpreters import xla from jax._src.interpreters.batching import not_mapped -from jax._src.lax import lax from jax._src.tree_util import ( tree_flatten, tree_unflatten, tree_map, treedef_is_leaf, treedef_tuple, register_pytree_node_class, tree_leaves, tree_flatten_with_path, - tree_leaves_with_path, keystr, treedef_children, PyTreeDef) + tree_leaves_with_path, keystr, treedef_children, tree_structure, PyTreeDef) from jax._src.util import (cache, safe_zip, safe_map, split_list, unzip2, weakref_lru_cache) @@ -87,7 +87,7 @@ def _flatten_fun_nokwargs(f: Callable, ans = f(*py_args) ans_flat, ans_tree = tree_flatten(ans) ans_avals = [core.get_aval(x) for x in ans_flat] - store.store((ans_tree, ans_avals)) + store.store((ans_tree, ans_avals, ())) return ans_flat @@ -130,20 +130,35 @@ def f_jvp(primals, tangents): For a more detailed introduction, see the tutorial_. - .. _tutorial: https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html + .. 
_tutorial: https://docs.jax.dev/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html """ fun: Callable[..., ReturnValue] nondiff_argnums: Sequence[int] + nondiff_argnames: Sequence[str] jvp: Callable[..., tuple[ReturnValue, ReturnValue]] | None = None symbolic_zeros: bool = False def __init__(self, fun: Callable[..., ReturnValue], nondiff_argnums: Sequence[int] = (), + nondiff_argnames: Sequence[str] = (), ): update_wrapper(self, fun) self.fun = fun - self.nondiff_argnums = nondiff_argnums + + nondiff_argnums_: set[int] = set() + if nondiff_argnames: + sig = fun_signature(self.fun) + assert sig is not None + inferred_nondiff_argnums, _ = infer_argnums_and_argnames( + sig, None, nondiff_argnames + ) + nondiff_argnums_.update(inferred_nondiff_argnums) + + if nondiff_argnums: + nondiff_argnums_.update(nondiff_argnums) + + self.nondiff_argnums = tuple(sorted(nondiff_argnums_)) __getattr__ = custom_api_util.forward_attr @@ -260,10 +275,9 @@ def __call__(self, *args: Any, **kwargs: Any) -> ReturnValue: # pytype: disable ) from e if self.nondiff_argnums: - nondiff_argnums = set(self.nondiff_argnums) - args = tuple(_stop_gradient(x) if i in nondiff_argnums else x + args = tuple(_stop_gradient(x) if i in self.nondiff_argnums else x for i, x in enumerate(args)) - diff_argnums = [i for i in range(len(args)) if i not in nondiff_argnums] + diff_argnums = [i for i in range(len(args)) if i not in self.nondiff_argnums] f_, dyn_args = argnums_partial(lu.wrap_init(self.fun, debug_info=debug), diff_argnums, args, require_static_args_hashable=False) @@ -287,7 +301,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> ReturnValue: # pytype: disable in_tree, out_type1) out_flat = custom_jvp_call_p.bind(flat_fun, flat_jvp, *args_flat, symbolic_zeros=self.symbolic_zeros) - _, (out_tree, _) = lu.merge_linear_aux(out_type1, out_type2) + _, (out_tree, _, _) = lu.merge_linear_aux(out_type1, out_type2) return tree_unflatten(out_tree, out_flat) @partial(lu.transformation_with_aux2, use_eq_store=True) @@ -314,7 +328,7 @@ def _flatten_jvp(f, store, primal_name, jvp_name, in_tree, maybe_out_type, *args try: out_type_ = maybe_out_type() except lu.StoreException: out_type_ = None if out_type_ is not None: - out_tree_, primal_avals_ = out_type_ + out_tree_, primal_avals_, () = out_type_ ty_tree = tree_unflatten(out_tree , [a.str_short() for a in primal_avals]) ty_tree_ = tree_unflatten(out_tree_, [a.str_short() for a in primal_avals_]) if out_tree_ != out_tree: @@ -366,7 +380,7 @@ def _flatten_jvp(f, store, primal_name, jvp_name, in_tree, maybe_out_type, *args if av_et != av_t) raise TypeError(msg.format('\n'.join(disagreements))) - store.store((out_tree, primal_avals)) + store.store((out_tree, primal_avals, ())) return primals_out + tangents_out class CustomJVPCallPrimitive(core.Primitive): @@ -410,8 +424,6 @@ def jvp(*xs): return [*out_primals, *out_tangents] return lu.wrap_init(jvp, debug_info=jvp_jaxpr_fun.debug_info) -effects.custom_derivatives_allowed_effects.add_type(lax.InOutFeedEffect) - custom_jvp_call_p = CustomJVPCallPrimitive('custom_jvp_call') def _custom_jvp_call_typecheck(_, *in_avals, call_jaxpr, jvp_jaxpr_fun, @@ -425,16 +437,14 @@ def _custom_jvp_call_typecheck(_, *in_avals, call_jaxpr, jvp_jaxpr_fun, return call_jaxpr.out_avals, call_jaxpr.effects core.custom_typechecks[custom_jvp_call_p] = _custom_jvp_call_typecheck -def _custom_jvp_call_mlir_translation(ctx, *args, call_jaxpr, jvp_jaxpr_fun, - num_consts, symbolic_zeros): - del jvp_jaxpr_fun, num_consts, symbolic_zeros +def 
_custom_jvp_vjp_call_lowering(ctx, *args, call_jaxpr, **_): consts = mlir._ir_consts(call_jaxpr.consts) out, tokens = mlir.jaxpr_subcomp(ctx.module_context, call_jaxpr.jaxpr, ctx.name_stack, ctx.tokens_in, consts, *args, dim_var_values=ctx.dim_var_values) ctx.set_tokens_out(tokens) return out -mlir.register_lowering(custom_jvp_call_p, _custom_jvp_call_mlir_translation) +mlir.register_lowering(custom_jvp_call_p, _custom_jvp_vjp_call_lowering) # If a (multi)linear function is defined with a custom jvp, then # custom_jvp_call_ can appear in jaxprs to be transposed. Since it's already @@ -487,6 +497,21 @@ def dce_jvp_jaxpr_thunk(*in_zeros): pe.dce_rules[custom_jvp_call_p] = _custom_jvp_call_dce +def _custom_jvp_call_pp_rule(eqn: core.JaxprEqn, + context: core.JaxprPpContext, + settings: core.JaxprPpSettings) -> core.pp.Doc: + params = dict(eqn.params) + if not params["num_consts"]: + params.pop("num_consts") + params["jvp"] = params.pop("jvp_jaxpr_fun").debug_info.func_name + names = sorted(params) + params["name"] = params["call_jaxpr"].jaxpr.debug_info.func_name + return core._pp_eqn(eqn.replace(params=params), context, settings, + params=["name"] + names) + + +core.pp_eqn_rules[custom_jvp_call_p] = _custom_jvp_call_pp_rule + ### VJPs @custom_api_util.register_custom_decorator_type @@ -521,15 +546,29 @@ def f_bwd(res, g): For a more detailed introduction, see the tutorial_. - .. _tutorial: https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html + .. _tutorial: https://docs.jax.dev/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html """ def __init__(self, fun: Callable[..., ReturnValue], - nondiff_argnums: Sequence[int] = ()): + nondiff_argnums: Sequence[int] = (), + nondiff_argnames: Sequence[str] = ()): update_wrapper(self, fun) self.fun = fun - self.nondiff_argnums = nondiff_argnums + + nondiff_argnums_: set[int] = set() + if nondiff_argnames: + sig = fun_signature(self.fun) + assert sig is not None + inferred_nondiff_argnums, _ = infer_argnums_and_argnames( + sig, None, nondiff_argnames + ) + nondiff_argnums_.update(inferred_nondiff_argnums) + + if nondiff_argnums: + nondiff_argnums_.update(nondiff_argnums) + + self.nondiff_argnums = tuple(sorted(nondiff_argnums_)) self.fwd: Callable[..., tuple[ReturnValue, Any]] | None = None self.bwd: Callable[..., tuple[Any, ...]] | None = None self.symbolic_zeros = False @@ -671,8 +710,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> ReturnValue: # pytype: disable else: if self.nondiff_argnums: for i in self.nondiff_argnums: _check_for_tracers(args[i]) - nondiff_argnums = set(self.nondiff_argnums) - dyn_argnums = [i for i in range(len(args)) if i not in nondiff_argnums] + dyn_argnums = [i for i in range(len(args)) if i not in self.nondiff_argnums] f_, dyn_args = argnums_partial( lu.wrap_init(self.fun, debug_info=debug_fun), dyn_argnums, args, require_static_args_hashable=False) @@ -698,24 +736,27 @@ def __call__(self, *args: Any, **kwargs: Any) -> ReturnValue: # pytype: disable out_flat = custom_vjp_call_p.bind(flat_fun, flat_fwd, flat_bwd, *args_flat, out_trees=out_trees, symbolic_zeros=self.symbolic_zeros) - _, (out_tree, _) = lu.merge_linear_aux(out_type, out_trees) + _, (out_tree, _, _) = lu.merge_linear_aux(out_type, out_trees) return tree_unflatten(out_tree, out_flat) @lu.transformation2 -def _check_primal_refs(f: Callable, nondiff_argnums: Sequence[int], - debug_info: core.DebugInfo, *args): - _check_for_aliased_refs(f, nondiff_argnums, debug_info, args) +def _check_primal_refs( + f: 
Callable, nondiff_argnums: Sequence[int], debug: core.DebugInfo, *args): + _check_for_aliased_refs(f, nondiff_argnums, debug, args) out = f(*args) - _check_for_returned_refs(f, out, 'primal') + _check_for_returned_refs(f, out, 'primal', [], 0) return out -def _check_for_aliased_refs(f: Callable, - nondiff_argnums: Sequence[int], - debug: core.DebugInfo, - args): +def _check_for_aliased_refs( + f: Callable, nondiff_argnums: Sequence[int], debug: core.DebugInfo, args): + nondiff_argnums_ = set(nondiff_argnums) + argnums = [x for i, arg in enumerate(args) + for x in [i] * tree_structure(arg).num_leaves] leaves = tree_leaves(args) refs: dict[int, int] = {} - for i, x in enumerate(leaves): + for i, (argnum, x) in enumerate(zip(argnums, leaves)): + if argnum in nondiff_argnums: continue + x = x.value if isinstance(x, CustomVJPPrimal) else x if (isinstance((a := core.get_aval(x)), AbstractRef) and (dup_idx := refs.setdefault(id(core.get_referent(x)), i)) != i): arg_names = debug.safe_arg_names(len(leaves)) @@ -725,14 +766,21 @@ def _check_for_aliased_refs(f: Callable, f"array reference of type {a.str_short()} at {arg_names[dup_idx]} and" f" {arg_names[i]}.") -def _check_for_returned_refs(f, out, kind): +def _check_for_returned_refs(f, out, kind, args, after_idx): + args = [x.value if isinstance(x, CustomVJPPrimal) else x for x in args] + ids = {id(x) for x in args if isinstance(core.get_aval(x), AbstractRef)} leaves = tree_leaves_with_path(out) - for path, leaf in leaves: + for i, (path, leaf) in enumerate(leaves): if isinstance((a := core.get_aval(leaf)), AbstractRef): loc = f' at output tree path {keystr(path)}' if path else '' - raise ValueError(f"custom_vjp {kind} function {f} returned a mutable " - f"a array reference of type {a.str_short()}{loc}, " - "but mutable array references cannot be returned.") + if i < after_idx: + raise ValueError(f"custom_vjp {kind} function {f} returned a mutable " + f"array reference of type {a.str_short()}{loc}, " + "but mutable array references cannot be returned there.") + if id(leaf) not in ids: + raise ValueError(f"custom_vjp {kind} function {f} returned a mutable " + f"array reference of type {a.str_short()}{loc} " + "that was not an argument.") @dataclasses.dataclass class CustomVJPPrimal: @@ -787,8 +835,6 @@ def _flatten_fwd(f: Callable, store: lu.EqualStore, if config.mutable_array_checks.value: _check_for_aliased_refs(f, nondiff_argnums, debug_primal, py_args) pair_out = f(*py_args) - if config.mutable_array_checks.value: - _check_for_returned_refs(f, pair_out, "fwd") if not isinstance(pair_out, (list, tuple)) or len(pair_out) != 2: msg = (f"Custom VJP fwd rule {fwd_name} for function {primal_name} " "must produce a pair (list or tuple of length two) where the first " @@ -801,12 +847,14 @@ def _flatten_fwd(f: Callable, store: lu.EqualStore, py_primals_out, res = pair_out primals_out, out_tree = tree_flatten(py_primals_out) res, res_tree = tree_flatten(res) + if config.mutable_array_checks.value: + _check_for_returned_refs(f, pair_out, "fwd", args, out_tree.num_leaves) primal_avals = [core.get_aval(x) for x in primals_out] # If the primal function already ran, check out_tree agreement. 
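For context on the `nondiff_argnames` parameter added to `custom_jvp` and `custom_vjp` above, a minimal usage sketch follows. It assumes this patch is applied; the function name `scaled_pow` is made up for illustration, and the by-name resolution is just a convenience over the existing `nondiff_argnums` behavior, so the JVP rule signature is unchanged.

    from functools import partial
    import jax

    # Mark the static argument by name rather than by position (new in this patch).
    @partial(jax.custom_jvp, nondiff_argnames=('exponent',))
    def scaled_pow(x, exponent):
      return x ** exponent

    @scaled_pow.defjvp
    def scaled_pow_jvp(exponent, primals, tangents):
      # As with nondiff_argnums, non-differentiable arguments come first.
      (x,), (x_dot,) = primals, tangents
      return scaled_pow(x, exponent), exponent * x ** (exponent - 1) * x_dot

    print(jax.grad(scaled_pow)(2.0, 3))  # 12.0
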
try: out_type_ = maybe_out_type() except lu.StoreException: out_type_ = None if out_type_ is not None: - out_tree_, primal_avals_ = out_type_ + out_tree_, primal_avals_, () = out_type_ ty_tree = tree_unflatten(out_tree , [a.str_short() for a in primal_avals]) ty_tree_ = tree_unflatten(out_tree_, [a.str_short() for a in primal_avals_]) if out_tree_ != out_tree: @@ -838,15 +886,21 @@ def _flatten_fwd(f: Callable, store: lu.EqualStore, "shapes/dtypes of:\n" f""" {str(ty_tree_).replace("'", "")}""") raise TypeError(m) - store.store((out_tree, res_tree)) - return (*res, *primals_out) + pruned_res, input_forwards = _filter_forwarded_inputs(res, args) # prune + store.store((out_tree, res_tree, input_forwards)) + return (*pruned_res, *primals_out) + +def _filter_forwarded_inputs(outs, ins): + idxs: dict[int, int] = {id(x): i for i, x in enumerate(ins)} + return [o for o in outs if id(o) not in idxs], [idxs.get(id(o)) for o in outs] @lu.transformation2 def _flatten_bwd(f: Callable, in_tree: PyTreeDef, in_avals: Sequence[core.AbstractValue], - out_trees: Callable[[], Sequence[PyTreeDef]], *args): - out_tree, res_tree = out_trees() + out_trees: Callable[[], tuple[PyTreeDef, PyTreeDef, list[int | None]]], + *args): + out_tree, res_tree, _ = out_trees() assert len(args) == res_tree.num_leaves + out_tree.num_leaves res, cts_out = split_list(args, [res_tree.num_leaves]) py_res = tree_unflatten(res_tree, res) @@ -904,7 +958,8 @@ def append(x, d): "shape/dtypes as the args tuple of the primal function, but at " f"output{keystr(kp)} the bwd rule produced an output of " f"shape/dtype {a_.str_short()} corresponding " - f"to an input of shape/dtype {a.str_short()}.") + f"to an input of shape/dtype {a.str_short()}" + f"{core.aval_mismatch_extra(a, a_)}") raise ValueError(msg) results.append(ct) return results @@ -921,8 +976,8 @@ def _temporary_dtype_exception(a, a_) -> bool: def _temporary_shape_exception(a, a_) -> bool: return config.custom_vjp_disable_shape_check.value -class CustomVJPCallPrimitive(core.CallPrimitive): - initial_style: core.Primitive +class CustomVJPCallPrimitive(core.Primitive): + multiple_results = True def bind(self, *args, **params): return self._true_bind(*args, **params) @@ -931,119 +986,81 @@ def bind_with_trace(self, trace, args, params): fun, fwd, bwd, tracers = args[0], args[1], args[2], args[3:] return trace.process_custom_vjp_call(self, fun, fwd, bwd, tracers, **params) -custom_vjp_call_p = CustomVJPCallPrimitive('custom_vjp_call') + def impl(self, fun, fwd, bwd, *args): + raise NotImplementedError -def _custom_vjp_call_jaxpr_impl(*args, fun_jaxpr, **_): - return core.jaxpr_as_fun(fun_jaxpr)(*args) + def get_bind_params(self, params): + new_params = dict(params) + call_jaxpr: core.ClosedJaxpr = new_params.pop('call_jaxpr') + num_consts: int = new_params.pop('num_consts') + fwd_jaxpr_thunk = new_params.pop('fwd_jaxpr_thunk') + fun = lu.wrap_init(core.jaxpr_as_fun(call_jaxpr), + debug_info=call_jaxpr.jaxpr.debug_info) + fwd = lift_fwd(num_consts, fwd_jaxpr_thunk) + const_avals, _ = split_list(call_jaxpr.in_avals, [num_consts]) + bwd = _handle_consts_in_bwd(new_params.pop('bwd'), const_avals) + return [fun, fwd, bwd], new_params + +def lift_fwd(num_consts: int, fwd_jaxpr_thunk: lu.WrappedFun) -> lu.WrappedFun: + def fwd(*args): + vals, nonzeros = args[::2], args[1::2] + assert len(vals) == len(nonzeros) + _, primals = split_list(vals, [num_consts]) + const_nonzeros, in_nonzeros = split_list(nonzeros, [num_consts]) + if any(const_nonzeros): raise ad.CustomVJPException() + fwd_jaxpr, 
fwd_consts = fwd_jaxpr_thunk.call_wrapped(*in_nonzeros) + return core.eval_jaxpr(fwd_jaxpr, fwd_consts, *primals) + return lu.wrap_init(fwd, debug_info=fwd_jaxpr_thunk.debug_info) -def _custom_vjp_call_jaxpr_abstract_eval(*_, fun_jaxpr, **__): - disallowed_effects = effects.custom_derivatives_allowed_effects.filter_not_in(fun_jaxpr.effects) +@lu.transformation2 +def _handle_consts_in_bwd(f, const_avals, *args): + return [Zero(a) for a in const_avals] + list(f(*args)) + +custom_vjp_call_p = CustomVJPCallPrimitive('custom_vjp_call') +mlir.register_lowering(custom_vjp_call_p, _custom_jvp_vjp_call_lowering) + +def _custom_vjp_call_typecheck(_, *in_avals, call_jaxpr, **kwargs): + del in_avals, kwargs + disallowed_effects = effects.custom_derivatives_allowed_effects.filter_not_in( + call_jaxpr.effects) if disallowed_effects: raise NotImplementedError( f'Effects not supported in `custom_vjp`: {disallowed_effects}') - return fun_jaxpr.out_avals, fun_jaxpr.effects - -custom_vjp_call_jaxpr_p = core.Primitive('custom_vjp_call_jaxpr') -custom_vjp_call_jaxpr_p.multiple_results = True -custom_vjp_call_jaxpr_p.def_impl(_custom_vjp_call_jaxpr_impl) -custom_vjp_call_jaxpr_p.def_effectful_abstract_eval(_custom_vjp_call_jaxpr_abstract_eval) -CustomVJPCallPrimitive.initial_style = custom_vjp_call_jaxpr_p - -mlir.register_lowering(custom_vjp_call_jaxpr_p, mlir.lower_fun( - _custom_vjp_call_jaxpr_impl, multiple_results=True)) - -def _custom_vjp_call_jaxpr_jvp( - primals, tangents, *, fun_jaxpr: core.ClosedJaxpr, - fwd_jaxpr_thunk: Callable[..., tuple[core.Jaxpr, Sequence[Any]]], - num_consts: int, bwd: lu.WrappedFun, - out_trees: Callable[[], Sequence[PyTreeDef]], - symbolic_zeros: bool): - _, args = split_list(primals, [num_consts]) - consts_dot, args_dot = split_list(tangents, [num_consts]) - if any(type(t) is not Zero for t in consts_dot): - raise ad.CustomVJPException() - zeros = [type(t) is not Zero for t in args_dot] - fwd_jaxpr, fwd_consts = fwd_jaxpr_thunk(*zeros) # consts can be tracers! 
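To illustrate the residual-forwarding optimization that `_filter_forwarded_inputs` (defined above) enables in `_flatten_fwd`: residuals that are, by object identity, just forwarded primal inputs are pruned from what gets stored, and a list of forwarding indices records how to recover them. A self-contained sketch of that bookkeeping in plain Python (the name `filter_forwarded` is hypothetical; it mirrors the helper above):

    def filter_forwarded(outs, ins):
      # Drop outputs that alias an input; remember which input each output came
      # from (None means the residual is kept and must be stored).
      idxs = {id(x): i for i, x in enumerate(ins)}
      return [o for o in outs if id(o) not in idxs], [idxs.get(id(o)) for o in outs]

    a, b, c = object(), object(), object()
    pruned, fwds = filter_forwarded([a, c, b], ins=[a, b])
    assert pruned == [c] and fwds == [0, None, 1]

    # Reconstruction, mirroring the consumer side (e.g. wrapped_fwd later in this
    # file): walk the forwarding indices, pulling stored residuals where needed.
    stored = iter(pruned)
    res = [next(stored) if f is None else [a, b][f] for f in fwds]
    assert res == [a, c, b]
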
- _, res_tree = out_trees() - res_and_primals_out = core.eval_jaxpr(fwd_jaxpr, fwd_consts, *args) - res, primals_out = split_list(res_and_primals_out, [res_tree.num_leaves]) - avals_out = [core.get_aval(x).to_tangent_aval() for x in primals_out] - args_dot = map(ad.instantiate_zeros, args_dot) - tangents_out = ad.custom_lin_p.bind( - *res, *args_dot, num_res=res_tree.num_leaves, bwd=bwd, - out_avals=avals_out, symbolic_zeros=symbolic_zeros) - tangents_out = map(lax.tie_p.bind, primals_out, tangents_out) - return primals_out, tangents_out -ad.primitive_jvps[custom_vjp_call_jaxpr_p] = _custom_vjp_call_jaxpr_jvp - -def _custom_vjp_call_jaxpr_vmap( - axis_data, args, in_dims, *, - fun_jaxpr: core.ClosedJaxpr, - fwd_jaxpr_thunk: Callable[..., tuple[core.Jaxpr, Sequence[Any]]], - num_consts: int, bwd: lu.WrappedFun, - out_trees: Callable, symbolic_zeros: bool): - args = [batching.moveaxis(x, d, 0) if d is not not_mapped and d != 0 - else x for x, d in zip(args, in_dims)] - in_batched = [d is not not_mapped for d in in_dims] - _, args_batched = split_list(in_batched, [num_consts]) - batched_fun_jaxpr, out_batched = batching.batch_jaxpr( - fun_jaxpr, axis_data, in_batched, False) - out_dims1 = [0 if b else not_mapped for b in out_batched] - out_dims2 = [] - - @pe._memoize - def batched_fwd_jaxpr_thunk(*zeros): - fwd_jaxpr = core.ClosedJaxpr(*fwd_jaxpr_thunk(*zeros)) # consts can be tracers - batched_fwd_jaxpr, out_batched = batching.batch_jaxpr( - fwd_jaxpr, axis_data, args_batched, False) - out_dims2.append([0 if b else not_mapped for b in out_batched]) - return batched_fwd_jaxpr.jaxpr, batched_fwd_jaxpr.consts - - fwd_args_batched = [0 if b else not_mapped for b in args_batched] - fwd_out_dims = lambda: out_dims2[0] - tag = core.TraceTag() - batched_bwd = batching.batch_custom_vjp_bwd( - bwd, tag, axis_data, fwd_out_dims, fwd_args_batched) - - batched_outs = custom_vjp_call_jaxpr_p.bind( - *args, fun_jaxpr=batched_fun_jaxpr, - fwd_jaxpr_thunk=batched_fwd_jaxpr_thunk, bwd=batched_bwd, - num_consts=num_consts, out_trees=out_trees, symbolic_zeros=symbolic_zeros) - out_dims = out_dims2[0] if out_dims2 else out_dims1 - return batched_outs, out_dims -batching.fancy_primitive_batchers[custom_vjp_call_jaxpr_p] = _custom_vjp_call_jaxpr_vmap + return call_jaxpr.out_avals, call_jaxpr.effects +core.custom_typechecks[custom_vjp_call_p] = _custom_vjp_call_typecheck -def _custom_vjp_call_jaxpr_dce( +def _custom_vjp_call_dce( used_outs: Sequence[bool], eqn: core.JaxprEqn ) -> tuple[list[bool], core.JaxprEqn | None]: if not any(used_outs) and not pe.has_effects(eqn): return [False] * len(eqn.invars), None - fun_jaxpr: core.ClosedJaxpr = eqn.params["fun_jaxpr"] + call_jaxpr: core.ClosedJaxpr = eqn.params["call_jaxpr"] fwd_jaxpr_thunk = eqn.params["fwd_jaxpr_thunk"] bwd: lu.WrappedFun = eqn.params["bwd"] - out_trees: Callable[[], Sequence[PyTreeDef]] = eqn.params["out_trees"] + out_trees: Callable[[], tuple[PyTreeDef, PyTreeDef, list[int | None]]] = eqn.params["out_trees"] symbolic_zeros: bool = eqn.params["symbolic_zeros"] - dce_fun_jaxpr: core.ClosedJaxpr + dce_call_jaxpr: core.ClosedJaxpr used_ins: Sequence[bool] - dce_fun_jaxpr, used_ins = _cached_closed_call_dce_instantiate( - fun_jaxpr, tuple(used_outs)) + dce_call_jaxpr, used_ins = _cached_closed_call_dce_instantiate( + call_jaxpr, tuple(used_outs)) assert all(used_ins) + @partial(lu.wrap_init, debug_info=fwd_jaxpr_thunk.debug_info) @pe._memoize def dce_fwd_jaxpr_thunk(*zeros): - fwd_jaxpr = core.ClosedJaxpr(*fwd_jaxpr_thunk(*zeros)) - _, res_tree = 
out_trees() - num_res = res_tree.num_leaves + fwd_jaxpr = core.ClosedJaxpr(*fwd_jaxpr_thunk.call_wrapped(*zeros)) + _, res_tree, fwds = out_trees() + num_res_out = res_tree.num_leaves - sum(f is not None for f in fwds) dce_fwd_jaxpr, _ = _cached_closed_call_dce_instantiate( - fwd_jaxpr, (True,) * num_res + tuple(used_outs)) + fwd_jaxpr, (True,) * num_res_out + tuple(used_outs)) return dce_fwd_jaxpr.jaxpr, dce_fwd_jaxpr.consts def dce_bwd(*args): - _, res_tree = out_trees() + _, res_tree, _ = out_trees() res, cts = split_list(args, [res_tree.num_leaves]) cts_ = iter(cts) all_cts = [] - for used, aval in zip(used_outs, fun_jaxpr.out_avals): + for used, aval in zip(used_outs, call_jaxpr.out_avals): if used: all_cts.append(next(cts_)) else: @@ -1060,17 +1077,32 @@ def dce_bwd(*args): outvars = [v for used, v in zip(used_outs, eqn.outvars) if used] new_params = dict( eqn.params, - fun_jaxpr=dce_fun_jaxpr, + call_jaxpr=dce_call_jaxpr, fwd_jaxpr_thunk=dce_fwd_jaxpr_thunk, bwd=dce_bwd_wrapped, ) new_eqn = pe.new_jaxpr_eqn( - eqn.invars, outvars, eqn.primitive, new_params, dce_fun_jaxpr.effects, + eqn.invars, outvars, eqn.primitive, new_params, dce_call_jaxpr.effects, eqn.source_info, eqn.ctx) return list(used_ins), new_eqn -pe.dce_rules[custom_vjp_call_jaxpr_p] = _custom_vjp_call_jaxpr_dce +pe.dce_rules[custom_vjp_call_p] = _custom_vjp_call_dce + -xla.register_initial_style_primitive(custom_vjp_call_jaxpr_p) +def _custom_vjp_call_pp_rule(eqn: core.JaxprEqn, + context: core.JaxprPpContext, + settings: core.JaxprPpSettings) -> core.pp.Doc: + params = dict(eqn.params) + if not params["num_consts"]: + params.pop("num_consts") + params.pop("out_trees") + params["fwd"] = params.pop("fwd_jaxpr_thunk").debug_info.func_name + params["bwd"] = params.pop("bwd").debug_info.func_name + names = sorted(params) + params["name"] = params["call_jaxpr"].jaxpr.debug_info.func_name + return core._pp_eqn(eqn.replace(params=params), context, settings, + params=["name"] + names) + +core.pp_eqn_rules[custom_vjp_call_p] = _custom_vjp_call_pp_rule batching.primitive_batchers[ad.custom_lin_p] = ad.raise_custom_vjp_error_on_jvp mlir.register_lowering(ad.custom_lin_p, ad.raise_custom_vjp_error_on_jvp) @@ -1276,8 +1308,8 @@ def _maybe_perturbed(x: Any) -> bool: @cache() def _closure_convert_for_avals(fun, in_tree, in_avals, debug_info: core.DebugInfo): - wrapped_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun, debug_info=debug_info), - in_tree) + wrapped_fun, out_tree = flatten_fun_nokwargs( + lu.wrap_init(fun, debug_info=debug_info), in_tree) jaxpr, out_pvals, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals) out_tree = out_tree() @@ -1424,53 +1456,65 @@ def linear_call(fun: Callable, (residual_args, linear_args), {})), t_in_tree) - t_jaxpr, t_consts = _initial_style_jaxpr(t, (*res_avals, *out_avals)) - t_jaxpr_closed = _close_jaxpr(t_jaxpr) - - if t_out_tree() != lin_tree: - raise TypeError( - 'transpose output pytree structure must match that of linear inputs, ' - f'got output structure {t_out_tree()} ' - f'and input structure {lin_tree}.') + @pe._memoize + def transpose_thunk(): + t_jaxpr, t_consts = _initial_style_jaxpr(t, (*res_avals, *out_avals)) + if t_out_tree() != lin_tree: + raise TypeError( + 'transpose output pytree structure must match that of linear inputs, ' + f'got output structure {t_out_tree()} ' + f'and input structure {lin_tree}.') + return _close_jaxpr(t_jaxpr), t_consts - out = linear_call_p.bind(*f_consts, *t_consts, *operands_res, *operands_lin, + out = linear_call_p.bind(*f_consts, 
*operands_res, *operands_lin, callee=f_jaxpr_closed, - transpose=t_jaxpr_closed, + transpose_thunk=transpose_thunk, num_callee_consts=len(f_consts), - num_transpose_consts=len(t_consts), num_res=len(operands_res)) return tree_unflatten(out_tree(), out) -def _linear_call_impl(*args, callee, transpose, num_callee_consts, - num_transpose_consts, num_res): - del transpose - consts, _, operands_res, operands_lin = split_list( - args, [num_callee_consts, num_transpose_consts, num_res]) - return core.eval_jaxpr(callee.jaxpr, (), *consts, *operands_res, *operands_lin) - -def _linear_call_transpose_rule(cts, *args, callee, transpose, - num_callee_consts, - num_transpose_consts, num_res): - f_consts, t_consts, operands_res, operands_lin = split_list( - args, [num_callee_consts, num_transpose_consts, num_res]) +def _linear_call_impl(*args, callee, transpose_thunk, num_callee_consts, + num_res): + del transpose_thunk, num_callee_consts, num_res + return core.eval_jaxpr(callee.jaxpr, (), *args) + +def _linear_call_jvp_rule(primals, tangents, callee, transpose_thunk, + num_callee_consts, num_res): + consts_and_res, primals = split_list(primals, [num_callee_consts + num_res]) + const_tangents, tangents = split_list(tangents, [num_callee_consts + num_res]) + assert all(type(t) is Zero for t in const_tangents) + primals_out = linear_call_p.bind( + *consts_and_res, *primals, callee=callee, transpose_thunk=transpose_thunk, + num_callee_consts=num_callee_consts, num_res=num_res) + tangents_out = linear_call_p.bind( + *consts_and_res, *tangents, callee=callee, transpose_thunk=transpose_thunk, + num_callee_consts=num_callee_consts, num_res=num_res) + return primals_out, tangents_out + +def _linear_call_transpose_rule(cts, *args, callee, transpose_thunk, + num_callee_consts, num_res): + transpose, t_consts = transpose_thunk() + f_consts, operands_res, operands_lin = split_list( + args, [num_callee_consts, num_res]) _, _, cts_avals = split_list( - transpose.in_avals, [num_transpose_consts, num_res]) + transpose.in_avals, [len(t_consts), num_res]) assert all(ad.is_undefined_primal(x) for x in operands_lin) assert all(not ad.is_undefined_primal(x) for x in operands_res) + def new_transpose_thunk(): + return callee, f_consts + cts = [zeros_like_aval(a) if type(ct) is Zero else ct for ct, a in zip(cts, cts_avals)] - - cts_out = linear_call_p.bind(*t_consts, *f_consts, *operands_res, *cts, + cts_out = linear_call_p.bind(*t_consts, *operands_res, *cts, callee=transpose, - transpose=callee, + transpose_thunk=new_transpose_thunk, num_callee_consts=len(t_consts), - num_transpose_consts=len(f_consts), num_res=len(operands_res)) - return [None] * (num_callee_consts + num_transpose_consts + num_res) + cts_out + return [None] * (num_callee_consts + num_res) + cts_out def _linear_call_abstract_eval(*args, **kwargs): return kwargs['callee'].out_avals @@ -1479,6 +1523,7 @@ def _linear_call_abstract_eval(*args, **kwargs): linear_call_p.multiple_results = True linear_call_p.def_impl(_linear_call_impl) linear_call_p.def_abstract_eval(_linear_call_abstract_eval) +ad.primitive_jvps[linear_call_p] = _linear_call_jvp_rule ad.primitive_transposes[linear_call_p] = _linear_call_transpose_rule xla.register_initial_style_primitive(linear_call_p) mlir.register_lowering(linear_call_p, mlir.lower_fun( @@ -1558,7 +1603,6 @@ def jvp(primals, tangents): # TODO(mattjj): remove these stubs, which exist to avoid breaking internal users custom_jvp_call_jaxpr_p = core.Primitive("custom_jvp_call_jaxpr") - # The following is a helper for optimizing the 
behavior of custom_vjp when used # under remat. This is really only useful when the `fwd` function to custom_vjp # executes a black box kernel. Otherwise, DCE will perform this optimization @@ -1612,26 +1656,27 @@ def wrapped_fwd(*args, **kwargs) -> tuple[ReturnValue, Any]: in_avals = [core.get_aval(x) for x in args_flat] fwd_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fwd, in_avals) fwd_jaxpr = pe.close_jaxpr(pe.convert_constvars_jaxpr(fwd_jaxpr)) - prim_tree, res_tree = out_trees() - num_res = res_tree.num_leaves + prim_tree, res_tree, fwds = out_trees() + num_res_out = res_tree.num_leaves - sum(f is not None for f in fwds) - if fwd_jaxpr.effects: + disallowed_effects = effects.custom_derivatives_allowed_effects.filter_not_in(fwd_jaxpr.effects) + if disallowed_effects: raise NotImplementedError( "remat optimization for custom_vjp does not support forward " - f"functions with side effects, but {fwd_name} has the following " - f"effects: {fwd_jaxpr.effects}") + f"functions with these side effects: {disallowed_effects}") @pe._memoize def fun_jaxpr_thunk(): jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals) return jaxpr, consts - out_flat = remat_opt_p.bind(*consts, *args_flat, - num_consts=len(consts), - num_res=num_res, - fwd_jaxpr=fwd_jaxpr, + out_flat = remat_opt_p.bind(*consts, *args_flat, num_consts=len(consts), + num_res=num_res_out, fwd_jaxpr=fwd_jaxpr, fun_jaxpr_thunk=fun_jaxpr_thunk) - res, out_flat = split_list(out_flat, [num_res]) + res, out_flat = split_list(out_flat, [num_res_out]) + res_ = iter(res) + res = [next(res_) if f is None else args_flat[f] for f in fwds] + assert next(res_, None) is None out_tree = treedef_tuple((prim_tree, res_tree)) return tree_unflatten(out_tree, (*out_flat, *res)) diff --git a/jax/_src/custom_partitioning.py b/jax/_src/custom_partitioning.py index 5374071517f1..3c6fe7299dbe 100644 --- a/jax/_src/custom_partitioning.py +++ b/jax/_src/custom_partitioning.py @@ -21,20 +21,23 @@ from functools import partial import inspect -from typing import Any, Callable +from typing import Any +from collections.abc import Callable import weakref import numpy as np -import jax -from jax import tree_util + +from jax._src import api from jax._src import api_util from jax._src import config from jax._src import core from jax._src import custom_api_util from jax._src import dispatch +from jax._src import errors from jax._src import linear_util as lu from jax._src import mesh as mesh_lib from jax._src import sharding_impls +from jax._src import tree_util from jax._src import xla_bridge as xb from jax._src.custom_partitioning_sharding_rule import sdy_sharding_rule_to_mlir, SdyShardingRule, str_to_sdy_sharding_rule from jax._src.interpreters import mlir @@ -42,7 +45,7 @@ from jax._src.lib import xla_client as xc from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import hlo -from jax.errors import UnexpectedTracerError +from jax._src.sharding import Sharding def _resolve_kwargs(fun, args, kwargs): @@ -93,7 +96,7 @@ def _to_jax_shape(s): def _to_jax_sharded_shape(s, sharding): - return jax.ShapeDtypeStruct( + return api.ShapeDtypeStruct( s.dimensions(), s.numpy_dtype(), sharding=sharding ) @@ -140,7 +143,7 @@ def _custom_partitioning_propagate_user_sharding(user_sharding, shape, def _to_hlo_sharding(sharding, num_dimensions): - if not isinstance(sharding, jax.sharding.Sharding): + if not isinstance(sharding, Sharding): raise ValueError("Custom Partitioning rules must return Sharding.") return 
sharding._to_xla_hlo_sharding(num_dimensions) @@ -178,7 +181,7 @@ def _custom_partitioning_partition(arg_shapes, arg_shardings, result_shape, _to_jax_shape(sharding.tile(s)) for sharding, s in zip(result_shardings, result_shapes) ] - closed_jaxpr = jax.make_jaxpr(lower_fn, axis_env=list(mesh.shape.items()))( + closed_jaxpr = api.make_jaxpr(lower_fn, axis_env=list(mesh.shape.items()))( *info.in_tree.unflatten(tiled_args) ) if ([(o.shape, o.dtype) for o in closed_jaxpr.out_avals] != @@ -251,7 +254,7 @@ def _custom_partitioning_impl(*args, call, in_tree, out_tree, def _check_for_tracers(x): if any(isinstance(leaf, core.Tracer) for leaf in tree_util.tree_leaves(x)): - raise UnexpectedTracerError( + raise errors.UnexpectedTracerError( "Found a JAX Tracer object passed as an argument to a" "custom_partitioning function in a position indicated as static by" "static_argnums. " @@ -482,10 +485,10 @@ def __call__(self, *args, **kwargs): args, require_static_args_hashable=False, ) - static_args = [args[i] for i in self.static_argnums] + static_args = tuple(args[i] for i in self.static_argnums) _check_for_tracers(static_args) else: - static_args = [] + static_args = () f_, dyn_args = lu.wrap_init(self.fun, debug_info=debug), args args_flat, in_tree = tree_util.tree_flatten(dyn_args) flat_fun, out_tree = api_util.flatten_fun_nokwargs(f_, in_tree) @@ -500,6 +503,14 @@ def __call__(self, *args, **kwargs): infer_sharding_from_operands = None sharding_rule = None if config.use_shardy_partitioner.value: + if (self.sharding_rule is None and + (self.propagate_user_sharding is not None or + self.infer_sharding_from_operands is not None)): + raise NotImplementedError( + "Shardy is used, but sharding propagation callbacks instead of " + "sharding_rule are provided. Need to provide sharding_rule to " + "migrate to Shardy." + ) sharding_rule = self.sharding_rule else: propagate_user_sharding = self.propagate_user_sharding @@ -557,11 +568,11 @@ def to_mesh_pspec_sharding(hlo_sharding: xc.HloSharding | None, ndim): return hlo_sharding if mesh.empty or not decode_shardings: assert devices is not None - return sharding_impls._op_sharding_to_pos_sharding(hlo_sharding, devices) + return sharding_impls.GSPMDSharding(devices, hlo_sharding) pspec = sharding_impls.parse_flatten_op_sharding( hlo_sharding, mesh)[0] - pspec = jax.sharding.PartitionSpec(*pspec, *((None,) * (ndim - len(pspec)))) - return jax.sharding.NamedSharding(mesh, pspec) + pspec = sharding_impls.PartitionSpec(*pspec, *((None,) * (ndim - len(pspec)))) + return sharding_impls.NamedSharding(mesh, pspec) sharding_callback_info = _ShardingCallbackInfo(propagate_user_sharding, partition, to_mesh_pspec_sharding, in_tree, out_tree, diff --git a/jax/_src/custom_partitioning_sharding_rule.py b/jax/_src/custom_partitioning_sharding_rule.py index 5e2e5f4e0479..bc27f34b3bfb 100644 --- a/jax/_src/custom_partitioning_sharding_rule.py +++ b/jax/_src/custom_partitioning_sharding_rule.py @@ -15,6 +15,7 @@ """Implements SdyShardingRule.""" from collections import OrderedDict +from typing import Union from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import sdy @@ -28,7 +29,7 @@ _BATCHING_DIM_FACTOR_PREFIX = "?" # A Jax value in general corresponds to an ir.Type or a tuple of ir.Types. -IrTypes = ir.Type | tuple[ir.Type, ...] +IrTypes = Union[ir.Type, tuple[ir.Type, ...]] def _check_factor(factor:str): """Validates a factor. 
@@ -137,12 +138,12 @@ def __init__(self, operand_mappings: tuple[ArrayMapping, ...], # Check that factors that are used for a whole dimension aren't in # factor_sizes and factors that are never used for a whole dimension are # in factor_sizes. - for factor, inferrable in factors_inferrable.items(): - if factor not in factor_sizes and not inferrable: + for factor, inferable in factors_inferrable.items(): + if factor not in factor_sizes and not inferable: raise ValueError( f"Factor {factor} is only used in compound factors; must specify" " its size") - if factor in factor_sizes and inferrable: + if factor in factor_sizes and inferable: raise ValueError( f"Factor {factor} represents a whole dimension; do not specify its" " size") @@ -472,4 +473,5 @@ def build_dim_mapping_for_compound_factors(i, j, factors): return sdy.OpShardingRuleAttr.get( factor_sizes=[item[1] for item in factors_to_indices_sizes.values()], operand_mappings=tensor_mappings[0:len(operand_types)], - result_mappings=tensor_mappings[len(operand_types):]) + result_mappings=tensor_mappings[len(operand_types):], + is_custom=True) diff --git a/jax/_src/custom_transpose.py b/jax/_src/custom_transpose.py index 5e87fdb203c9..fb125e174122 100644 --- a/jax/_src/custom_transpose.py +++ b/jax/_src/custom_transpose.py @@ -177,15 +177,19 @@ def bind_with_trace(self, trace, call_args, params): # TODO(frostig,mattjj): consider keeping `call` as a named parameter # instead of following this "call primitive" convention. def get_bind_params(self, params): - assert 'call_jaxpr' in params - assert 'transpose_jaxpr_thunk' in params - new_params: dict[str, Any] = dict(params) - new_params['transpose'] = make_transpose_from_thunk( - new_params.pop('transpose_jaxpr_thunk'), - new_params['lin_tree']) - call_jaxpr: core.ClosedJaxpr = new_params.pop('call_jaxpr') - call = lu.wrap_init(core.jaxpr_as_fun(call_jaxpr), - debug_info=call_jaxpr.jaxpr.debug_info) + if 'call_jaxpr' in params: + assert 'transpose_jaxpr_thunk' in params + new_params: dict[str, Any] = dict(params) + new_params['transpose'] = make_transpose_from_thunk( + new_params.pop('transpose_jaxpr_thunk'), + new_params['lin_tree']) + call_jaxpr: core.ClosedJaxpr = new_params.pop('call_jaxpr') + call = lu.wrap_init(core.jaxpr_as_fun(call_jaxpr), + debug_info=call_jaxpr.jaxpr.debug_info) + else: + assert 'transpose' in params + new_params: dict[str, Any] = dict(params) + call = new_params.pop("call") return [call], new_params @@ -213,7 +217,6 @@ def custom_transpose_transpose_rule( # Consider passing this information to the custom transpose rule? res_arg, lin_arg = tree_unflatten(call_in_tree, args) - del lin_arg assert all(not ad.is_undefined_primal(x) for x in tree_leaves(res_arg)) cts = [ad_util.zeros_like_aval(ct.aval) if type(ct) is ad_util.Zero else ct @@ -221,10 +224,17 @@ def custom_transpose_transpose_rule( ct_out = tree_unflatten(out_tree, cts) ct_lin = transpose.call_wrapped(res_arg, ct_out) check_transpose_rule_trees(transpose, lin_tree, tree_structure(ct_lin)) - ct_lin_flat, _ = tree_flatten( - tree_broadcast(lin_tree, ct_lin, is_leaf=lambda x: x is None), - is_leaf=lambda x: x is None) - return [None] * len(tree_leaves(res_arg)) + ct_lin_flat + ct_lin = tree_broadcast(lin_tree, ct_lin, is_leaf=lambda x: x is None) + + # When the transpose returns None, we treat that as a Zero, except when the + # input is also None. In that case, the cotangent corresponding to that input + # should be dropped. 
+ zero = object() + ct_lin = tree_map(lambda l, ct: zero if ct is None and l is not None else ct, + lin_arg, ct_lin, is_leaf=ad.is_undefined_primal) + + ct_lin_flat, _ = tree_flatten(ct_lin) + return [None] * res_tree.num_leaves + [None if ct is zero else ct for ct in ct_lin_flat] def custom_transpose_lowering(*args, call_jaxpr, **params): diff --git a/jax/_src/debugger/cli_debugger.py b/jax/_src/debugger/cli_debugger.py index bf4b38765026..eb1eca3bec48 100644 --- a/jax/_src/debugger/cli_debugger.py +++ b/jax/_src/debugger/cli_debugger.py @@ -105,7 +105,7 @@ def do_pp(self, arg): def do_up(self, _): """u(p) - Move down a stack frame. + Move up a stack frame. """ if self.frame_index == len(self.frames) - 1: print('At topmost frame.', file=self.stdout) diff --git a/jax/_src/debugging.py b/jax/_src/debugging.py index b61b28e12f43..7c09ab998195 100644 --- a/jax/_src/debugging.py +++ b/jax/_src/debugging.py @@ -69,6 +69,8 @@ class OrderedDebugEffect(effects.Effect): effects.remat_allowed_effects.add_type(OrderedDebugEffect) effects.custom_derivatives_allowed_effects.add_type(DebugEffect) effects.custom_derivatives_allowed_effects.add_type(OrderedDebugEffect) +effects.partial_eval_kept_effects.add_type(DebugEffect) +effects.partial_eval_kept_effects.add_type(OrderedDebugEffect) # `debug_callback_p` is the main primitive for staging out Python callbacks. debug_callback_p = core.Primitive('debug_callback') @@ -78,8 +80,8 @@ class OrderedDebugEffect(effects.Effect): @debug_callback_p.def_impl def debug_callback_impl(*args, callback: Callable[..., Any], - effect: DebugEffect): - del effect + effect: DebugEffect, partitioned: bool): + del effect, partitioned try: cpu_device, *_ = jax.local_devices(backend="cpu") except RuntimeError as e: @@ -99,8 +101,8 @@ def debug_callback_impl(*args, callback: Callable[..., Any], @debug_callback_p.def_effectful_abstract_eval def debug_callback_abstract_eval(*flat_avals, callback: Callable[..., Any], - effect: DebugEffect): - del flat_avals, callback + effect: DebugEffect, partitioned: bool): + del flat_avals, callback, partitioned return [], {effect} def debug_callback_batching_rule(args, dims, **params): @@ -126,14 +128,13 @@ def debug_callback_jvp_rule(primals, tangents, **params): return debug_callback_p.bind(*primals, **params), [] ad.primitive_jvps[debug_callback_p] = debug_callback_jvp_rule -def debug_callback_transpose_rule(*flat_args, callback: Callable[..., Any], - effect: DebugEffect): - del flat_args, callback, effect - raise ValueError("Transpose doesn't support debugging callbacks.") +def debug_callback_transpose_rule(_, *flat_args, callback: Callable[..., Any], + effect: DebugEffect, partitioned): + del callback, effect, partitioned + return [None for _ in flat_args] ad.primitive_transposes[debug_callback_p] = debug_callback_transpose_rule def _debug_callback_partial_auto(axis_context, *args, **params): - from jax.experimental.shard_map import shard_map partial_auto = list(set(axis_context.mesh.axis_names) - axis_context.manual_axes) def f(): idx = jax.lax.with_sharding_constraint( @@ -142,9 +143,9 @@ def f(): return jax.lax.cond(idx == 0, lambda: debug_callback_p.bind(*args, **params), lambda: []) - return shard_map(f, axis_context.mesh, in_specs=(), out_specs=[])() + return jax.shard_map(f, in_specs=(), out_specs=[])() -def debug_callback_lowering(ctx, *args, effect, callback, **params): +def debug_callback_lowering(ctx, *args, effect, partitioned, callback, **params): axis_context = ctx.module_context.axis_context if isinstance(axis_context, 
sharding_impls.SPMDAxisContext): # We're a shard_map, which might be partial-manual or full-manual. @@ -152,21 +153,29 @@ def debug_callback_lowering(ctx, *args, effect, callback, **params): if partial_auto: # If we have partial manual / partial auto sharding, we gather and # conditionally run the callback. - lower = partial(_debug_callback_partial_auto, axis_context, - effect=effect, callback=callback, **params) + lower = partial( + _debug_callback_partial_auto, + axis_context, + effect=effect, + partitioned=partitioned, + callback=callback, + **params, + ) return mlir.lower_fun(lower)(ctx, *args) elif set(axis_context.manual_axes) == set(axis_context.mesh.axis_names): # If we have fully manual sharding during lowering, that means the JAX # program has per-device semantics, so we run the callback on each device. if config.use_shardy_partitioner.value: - assert len(ctx.avals_out) == 1 - sharding = sharding_impls.SdyArrayShardingList([ - sharding_impls.SdyArraySharding( + ndim = 0 + if ctx.avals_out and isinstance(ctx.avals_out[0], core.ShapedArray): + ndim = ctx.avals_out[0].ndim + sharding = sharding_impls.SdyArrayList([ + sharding_impls.SdyArray( mesh_shape=(), - dimension_shardings=[ - sharding_impls.SdyDimSharding(axes=[], is_closed=True) - ] * ctx.avals_out[0].ndim, - logical_device_ids=())]) + dim_shardings=[ + sharding_impls.SdyDim(axes=[], is_open=False) + ] * ndim, + logical_device_ids=(0,))]) else: sharding = xc.OpSharding() sharding.type = xc.OpSharding.Type.MANUAL @@ -177,9 +186,9 @@ def debug_callback_lowering(ctx, *args, effect, callback, **params): # program has bulk array semantics, so we run the callback with a MAXIMAL # sharding and hence execute it only once on the full logical value). if config.use_shardy_partitioner.value: - sharding = sharding_impls.SdyArrayShardingList([ - sharding_impls.SdyArraySharding( - mesh_shape=(), dimension_shardings=[], logical_device_ids=(0,))]) + sharding = sharding_impls.SdyArrayList([ + sharding_impls.SdyArray( + mesh_shape=(), dim_shardings=[], logical_device_ids=(0,))]) else: sharding = xc.OpSharding() sharding.type = xc.OpSharding.Type.MAXIMAL @@ -191,18 +200,23 @@ def debug_callback_lowering(ctx, *args, effect, callback, **params): def _callback(*flat_args): debug_callback_p.impl( - *flat_args, effect=effect, callback=callback, **params) + *flat_args, + effect=effect, + partitioned=partitioned, + callback=callback, + **params, + ) return () if effects.ordered_effects.contains(effect): token = ctx.tokens_in.get(effect) result, token, _ = cb.emit_python_callback( ctx, _callback, token, list(args), ctx.avals_in, ctx.avals_out, - has_side_effect=True) + has_side_effect=True, partitioned=partitioned) ctx.set_tokens_out(mlir.TokenSet({effect: token})) else: result, _, _ = cb.emit_python_callback( ctx, _callback, None, list(args), ctx.avals_in, ctx.avals_out, - has_side_effect=True, sharding=sharding) + has_side_effect=True, partitioned=partitioned, sharding=sharding) return result mlir.register_lowering(debug_callback_p, debug_callback_lowering, platform="cpu") @@ -244,14 +258,22 @@ def _debug_callback_partial_eval_custom(saveable, unks_in, inst_in, eqn): @state_discharge.register_discharge_rule(debug_callback_p) def _debug_callback_state_discharge_rule( - in_avals, out_avals, *args, effect, callback, **params + in_avals, out_avals, *args, effect, partitioned, callback, **params ): del in_avals, out_avals # Unused. 
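As a usage sketch for the `partitioned` flag threaded through `debug_callback` and `debug_print` in this file: with `partitioned=True` the callback sees per-shard values and the all-gather of operands is skipped. The mesh and array setup below are illustrative only, and the flag itself assumes this patch is applied.

    import jax
    import jax.numpy as jnp
    from jax.sharding import NamedSharding, PartitionSpec as P

    mesh = jax.make_mesh((jax.device_count(),), ('x',))
    x = jax.device_put(jnp.arange(4.0 * jax.device_count()),
                       NamedSharding(mesh, P('x')))

    @jax.jit
    def f(x):
      # Each device prints only its local shard instead of the gathered value.
      jax.debug.print("shard: {}", x, partitioned=True)
      return x * 2

    f(x)
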
- out = debug_callback_p.bind(*args, effect=effect, callback=callback, **params) + out = debug_callback_p.bind( + *args, effect=effect, partitioned=partitioned, callback=callback, **params + ) return args, out -def debug_callback(callback: Callable[..., None], *args: Any, - ordered: bool = False, **kwargs: Any) -> None: + +def debug_callback( + callback: Callable[..., None], + *args: Any, + ordered: bool = False, + partitioned: bool = False, + **kwargs: Any, +) -> None: """Calls a stageable Python callback. For more explanation, see `External Callbacks`_. @@ -274,6 +296,9 @@ def debug_callback(callback: Callable[..., None], *args: Any, ordered: A keyword only argument used to indicate whether or not the staged out computation will enforce ordering of this callback w.r.t. other ordered callbacks. + partitioned: If True, then print local shards only; this option avoids an + all-gather of the operands. If False, print with logical operands; this + option requires an all-gather of operands first. **kwargs: The keyword arguments to the callback. Returns: @@ -284,7 +309,7 @@ def debug_callback(callback: Callable[..., None], *args: Any, - :func:`jax.pure_callback`: callback designed for pure functions. - :func:`jax.debug.print`: callback designed for printing. - .. _External Callbacks: https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html + .. _External Callbacks: https://docs.jax.dev/en/latest/notebooks/external_callbacks.html """ if not callable(callback): raise TypeError("first argument to jax.debug.callback must be callable, " @@ -312,7 +337,10 @@ def _flat_callback(*dyn_args): return () effect = ordered_debug_effect if ordered else debug_effect - debug_callback_p.bind(*dyn_args, callback=_flat_callback, effect=effect) + debug_callback_p.bind( + *dyn_args, callback=_flat_callback, effect=effect, partitioned=partitioned + ) + class _DebugPrintFormatChecker(string.Formatter): @@ -338,7 +366,10 @@ def _format_print_callback(fmt: str, np_printoptions, *args, **kwargs): with np.printoptions(**np_printoptions): sys.stdout.write(fmt.format(*args, **kwargs) + "\n") -def debug_print(fmt: str, *args, ordered: bool = False, **kwargs) -> None: + +def debug_print( + fmt: str, *args, ordered: bool = False, partitioned: bool = False, **kwargs +) -> None: """Prints values and works in staged out JAX functions. This function does *not* work with f-strings because formatting is delayed. @@ -367,6 +398,9 @@ def debug_print(fmt: str, *args, **kwargs): ordered: A keyword only argument used to indicate whether or not the staged out computation will enforce ordering of this ``jax.debug.print`` w.r.t. other ordered ``jax.debug.print`` calls. + partitioned: If True, then print local shards only; this option avoids an + all-gather of the operands. If False, print with logical operands; this + option requires an all-gather of operands first. **kwargs: Additional keyword arguments to be formatted, as if passed to ``fmt.format``. 
""" @@ -374,7 +408,7 @@ def debug_print(fmt: str, *args, **kwargs): formatter.format(fmt, *args, **kwargs) debug_callback(partial(_format_print_callback, fmt, np.get_printoptions()), - *args, **kwargs, ordered=ordered) + *args, **kwargs, ordered=ordered, partitioned=partitioned) # Sharding visualization @@ -429,6 +463,7 @@ def _inspect_sharding_lowering_rule(ctx: mlir.LoweringRuleContext, value, *, mesh = mesh_lib.Mesh(np.array(devices).reshape(am.axis_sizes), am.axis_names) elif isinstance(axis_context, sharding_impls.SPMDAxisContext): + mesh = axis_context.mesh devices = axis_context.mesh._flat_devices_tuple else: raise NotImplementedError(type(axis_context)) @@ -439,8 +474,9 @@ def _inspect_sharding_lowering_rule(ctx: mlir.LoweringRuleContext, value, *, def _hlo_sharding_callback(hlo_sharding: xc.HloSharding): if mesh.empty: return callback( - sharding_impls._op_sharding_to_pos_sharding(hlo_sharding, devices)) - pspec = parse_flatten_op_sharding(hlo_sharding, mesh)[0] + sharding_impls.GSPMDSharding(devices, hlo_sharding)) + pspec = (P() if hlo_sharding.is_manual() else + parse_flatten_op_sharding(hlo_sharding, mesh)[0]) return callback(NamedSharding(mesh, pspec)) if len(devices) == 1: diff --git a/jax/_src/deprecations.py b/jax/_src/deprecations.py index 37f2f0264782..4e5e22745658 100644 --- a/jax/_src/deprecations.py +++ b/jax/_src/deprecations.py @@ -127,12 +127,12 @@ def warn(deprecation_id: str, message: str, stacklevel: int) -> None: register('jax-dlpack-import-legacy') register('jax-nn-one-hot-float-input') register("jax-numpy-astype-complex-to-real") -register("jax-numpy-array-none") register('jax-numpy-clip-args') register('jax-numpy-linalg-matrix_rank-tol') register('jax-numpy-linalg-pinv-rcond') register('jax-numpy-quantile-interpolation') register('jax-numpy-reduction-non-boolean-where') register('jax-numpy-trimzeros-not-1d-array') -register('pallas-gpu-triton') register('jax-scipy-special-sph-harm') +register('jax-jit-positional-args') +register('jax-abstract-dunder-array') diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py index 2330f7628966..f51213e8f5ad 100644 --- a/jax/_src/dispatch.py +++ b/jax/_src/dispatch.py @@ -24,9 +24,8 @@ import logging import threading import time -from typing import Any, Callable, NamedTuple +from typing import Any -import jax from jax._src import api from jax._src import array from jax._src import basearray @@ -34,24 +33,28 @@ from jax._src import core from jax._src import dtypes from jax._src import lib -from jax._src import source_info_util +from jax._src import pjit from jax._src import traceback_util from jax._src import util + +from jax._src import xla_bridge from jax._src.abstract_arrays import array_types from jax._src.interpreters import ad from jax._src.interpreters import batching from jax._src.interpreters import mlir from jax._src.interpreters import pxla from jax._src.interpreters import xla -from jax._src.layout import DeviceLocalLayout, Layout +from jax._src.api_util import InternalFloatingPointError +from jax._src.layout import DeviceLocalLayout, Format from jax._src.lib import xla_client as xc from jax._src.mesh import AbstractMesh, Mesh -from jax._src.monitoring import record_event_duration_secs, record_event_time_span +from jax._src.monitoring import record_scalar, record_event_duration_secs, record_event_time_span from jax._src.partition_spec import PartitionSpec from jax._src.sharding import Sharding -from jax._src.sharding_impls import ( NamedSharding, - SingleDeviceSharding, TransferToMemoryKind, +from 
jax._src.sharding_impls import ( + NamedSharding, SingleDeviceSharding, TransferToMemoryKind, GSPMDSharding, is_single_device_sharding) +from jax._src.stages import SourceInfo import numpy as np @@ -132,11 +135,12 @@ def get_token_input( # TODO(yueshengys): This might still be buggy in a multi-process SPMD # scenario. Revise the logic later. A distributed shutdown barrier inside # the XLA program may be needed. - return jax.device_put(tok, jax.sharding.PositionalSharding(devices)) + return api.device_put( + tok, NamedSharding(Mesh(devices, 'x'), PartitionSpec('x'))) # We only use replicated sharding for the first time when the token for the # order effect hasn't been created. - s = jax.sharding.GSPMDSharding.get_replicated(devices) + s = GSPMDSharding.get_replicated(devices) sharded_tok = core.Token(pxla.shard_args([s], [None], [None], [tok])[0]) self.current_tokens[eff] = sharded_tok return sharded_tok @@ -178,6 +182,10 @@ def __init__(self, fmt: str, fun_name: str, event: str | None = None): def __enter__(self): self.start_time = time.time() + if self.event is not None: + record_scalar( + self.event, self.start_time, fun_name=self.fun_name + ) def __exit__(self, exc_type, exc_value, traceback): if _on_exit: @@ -190,8 +198,12 @@ def __exit__(self, exc_type, exc_value, traceback): logger.log(log_priority, self.fmt.format( fun_name=self.fun_name, elapsed_time=elapsed_time)) if self.event is not None: - record_event_duration_secs(self.event, elapsed_time) - record_event_time_span(self.event, self.start_time, end_time) + record_event_duration_secs( + self.event, elapsed_time, fun_name=self.fun_name + ) + record_event_time_span( + self.event, self.start_time, end_time, fun_name=self.fun_name + ) log_elapsed_time = LogElapsedTimeContextManager @@ -231,16 +243,10 @@ def jaxpr_has_prim_requiring_devices(jaxpr: core.Jaxpr) -> bool: return False -class SourceInfo(NamedTuple): - source_info: source_info_util.SourceInfo - eqn_name: str - - @util.weakref_lru_cache def get_intermediate_shardings( jaxpr: core.Jaxpr) -> Sequence[tuple[Sharding, SourceInfo]]: - from jax._src import pjit - from jax.experimental import shard_map + from jax._src import shard_map # pytype: disable=import-error out = [] for eqn in jaxpr.eqns: @@ -255,14 +261,12 @@ def get_intermediate_shardings( out.extend((i, source_info) for i in eqn.params['in_shardings']) out.extend((o, source_info) for o in eqn.params['out_shardings']) elif eqn.primitive is shard_map.shard_map_p: - if isinstance(eqn.params['mesh'], AbstractMesh): + mesh = eqn.params['mesh'] + if isinstance(mesh, AbstractMesh): continue source_info = SourceInfo(eqn.source_info, eqn.primitive.name) - def _names_to_pspec(names): - ndmin = max(names) + 1 if names else 0 - return PartitionSpec(*(names.get(i) for i in range(ndmin))) - out.extend((NamedSharding(eqn.params['mesh'], _names_to_pspec(names)), source_info) - for names in [*eqn.params['in_names'], *eqn.params['out_names']]) + out.extend((NamedSharding(mesh, spec), source_info) + for spec in [*eqn.params['in_specs'], *eqn.params['out_specs']]) elif eqn.primitive is device_put_p: source_info = SourceInfo(eqn.source_info, eqn.primitive.name) out.extend((s, source_info) for s in eqn.params['devices'] @@ -339,43 +343,6 @@ class CopySemantics(enum.Enum): COPY = enum.auto() DONATE = enum.auto() -class InternalFloatingPointError(Exception): - name: str - ty: str - - def __init__(self, name: str, ty: str): - self.name = name - self.ty = ty - -def maybe_recursive_nan_check(e: Exception, fun: Callable, args, kwargs, -) -> None: 
# always raises an exception - print("Invalid nan value encountered in the output of a jax.jit " - "function. Calling the de-optimized version.") - try: - _ = fun(*args, **kwargs) - except (FloatingPointError, ZeroDivisionError) as e2: - raise e2 from None - else: - _raise_no_nan_in_deoptimized(e) - -def _raise_no_nan_in_deoptimized(e) -> None: - msg = (f"{str(e)}. Because " - "jax_config.debug_nans.value and/or config.jax_debug_infs is set, the " - "de-optimized function (i.e., the function as if the `jit` " - "decorator were removed) was called in an attempt to get a more " - "precise error message. However, the de-optimized function did not " - "produce invalid values during its execution. This behavior can " - "result from `jit` optimizations causing the invalid value to be " - "produced. It may also arise from having nan/inf literals as " - "inputs or outputs, like `jax.jit(lambda ...: jax.numpy.nan)(...)`. " - "\n\n" - "It may be possible to avoid the invalid value by removing the " - "`jit` decorator, at the cost of losing optimizations. " - "\n\n" - "If you see this error, consider opening a bug report at " - "https://github.com/jax-ml/jax.") - raise FloatingPointError(msg) from None - def _identity_fn(x): return x @@ -389,16 +356,6 @@ def _different_device_order_reshard(x, target_sharding, copy: CopySemantics): return api.jit(_identity_fn, out_shardings=target_sharding, donate_argnums=donate_argnums)(x) - if inp_sharding.device_set != target_sharding.device_set: - inp_ids = [d.id for d in inp_sharding._device_assignment] - inp_plat = inp_sharding._device_assignment[0].platform.upper() - target_ids = [d.id for d in target_sharding._device_assignment] - target_plat = target_sharding._device_assignment[0].platform.upper() - raise ValueError("Input and target sharding should have the same set of " - f"devices. Got input's device set ids: {inp_ids} on " - f"platform {inp_plat} and target sharding's device set " - f"ids: {target_ids} on platform {target_plat}") - if inp_sharding.is_fully_replicated: permute_order = None else: @@ -422,6 +379,31 @@ def _reorder_shards(x, new_s, copy_semantics: CopySemantics): return xc.reorder_shards(x, new_s, xc_copy_semantics) # type: ignore +@util.cache() +def _is_supported_cross_host_transfer(ndim, src_sharding, dst_sharding): + """Returns True if src->dst is a supported cross-host transfer.""" + backend = xla_bridge.get_backend() + # There is experimental support for cross-host device transfers on TFRT TPU + # backends only. + # TODO: https://github.com/jax-ml/jax/issues/26645 - Allow backends to be + # queried for their cross-host transfer support. + if (xla_bridge.process_count() == 1 or backend.platform not in {"gpu", "tpu"} + or (backend.platform == "gpu" and not backend.platform_version.startswith("cuda")) + or (backend.platform == "tpu" and not backend.platform_version.startswith("TFRT TPU"))): + return False + if (src_sharding._internal_device_list.device_kind != + dst_sharding._internal_device_list.device_kind): + return False + if (src_sharding._to_xla_hlo_sharding(ndim) != + dst_sharding._to_xla_hlo_sharding(ndim)): + return False + # This check excludes the case where the source and destination shardings + # have the same process index sets but there are shards that require + # cross-host transfers. This case is supportable but expensive to check for. 
+ return (src_sharding._internal_device_list.process_indices != + dst_sharding._internal_device_list.process_indices) + + @dataclasses.dataclass(frozen=True) class _DeferredShardArg: """Deferred call to `pxla.shard_args`. @@ -443,7 +425,7 @@ def result_handler(self, shard_arg_result): def _device_put_sharding_impl(x, aval, device, copy): - from jax.experimental import multihost_utils + from jax.experimental import multihost_utils # pytype: disable=import-error if isinstance(device, Sharding): s = device @@ -452,7 +434,8 @@ def _device_put_sharding_impl(x, aval, device, copy): return x if (not s.is_fully_addressable and - isinstance(x, array.ArrayImpl) and not x.is_fully_addressable): + isinstance(x, array.ArrayImpl) and not x.is_fully_addressable and + s.device_set == x.sharding.device_set): assert isinstance(s, Sharding) return _different_device_order_reshard(x, s, copy) @@ -463,12 +446,43 @@ def _device_put_sharding_impl(x, aval, device, copy): assert isinstance(s, Sharding) return _different_device_order_reshard(x, s, copy) + # There is experimental support for cross-host device transfers on TFRT TPU. + if (isinstance(x, array.ArrayImpl) and x._committed + and _is_supported_cross_host_transfer(x.ndim, x.sharding, s)): + return xc.batched_copy_array_to_devices_with_sharding( + [x], [s._internal_device_list], [s], # pytype: disable=attribute-error + pxla.to_xc_copy_semantics([copy]))[0] + if not s.is_fully_addressable: + # If both the source and target shardings are not fully addressable and + # one of the above conditions has not been met, then assume that the user + # is attempting a different device order reshard. + if (isinstance(x, array.ArrayImpl) and not x.is_fully_addressable + and s.device_set != x.sharding.device_set): + inp_ids = [d.id for d in x.sharding._device_assignment] + inp_plat = x.sharding._device_assignment[0].platform.upper() + target_ids = [d.id for d in s._device_assignment] + target_plat = s._device_assignment[0].platform.upper() + raise ValueError( + "For a cross-host reshard in multi-controller JAX, input and target" + " sharding should have the same set of devices. Got input's device" + f" set ids: {inp_ids} on platform {inp_plat} and target sharding's" + f" device set ids: {target_ids} on platform {target_plat}.\n\n" + "There is experimental support for cross-host transfers with " + "different device sets, when input/output shardings have the same " + "indices and layouts, in the TFRT TPU runtime only.") + if ((isinstance(x, array.ArrayImpl) and not x._committed) or - type(x) in array_types): - # TODO(emilyaf): Remove this condition when jit works when a sharding - # has no local devices. - if not config.enable_empty_arrays.value: + type(x) in array_types or type(x) in dtypes.python_scalar_dtypes): + # If all hosts participate in the sharding, assert that the input is the + # same on all hosts. If some hosts have no addressable devices in the + # sharding, bypass the check, since we can't easily distinguish between + # these two cases: (1) the sharding contains the same subset of global + # devices on all hosts (and hosts with no addressable devices in the + # sharding do not transfer data) or (2) the sharding contains a + # different subset of devices on each host. For (1), the input should be + # the same on all hosts, but for (2) it need not be. 
+ if xla_bridge.process_count() == len(s._internal_device_list.process_indices): # pytype: disable=attribute-error multihost_utils.assert_equal( x, fail_message=( f"{type(x)} passed to device_put is not the same on each" @@ -495,6 +509,9 @@ def _device_put_sharding_impl(x, aval, device, copy): return _DeferredShardArg(x, x.sharding, aval, x.committed, copy) elif is_single_device_sharding(x.sharding): device = x.sharding._device_assignment[0] if device is None else device + if copy == CopySemantics.COPY: + return xc.batched_device_put(aval, SingleDeviceSharding(device), [x], + [device], True, True) return pxla.batched_device_put(aval, SingleDeviceSharding(device), [x], [device]) @@ -504,8 +521,8 @@ def _device_put_sharding_impl(x, aval, device, copy): def _device_put_impl( - x, *, device: Device | Sharding | Layout | None, - src: Device | Sharding | Layout | None, copy: CopySemantics): + x, *, device: Device | Sharding | Format | None, + src: Device | Sharding | Format | None, copy: CopySemantics): if (isinstance(device, TransferToMemoryKind) or isinstance(src, TransferToMemoryKind)): raise ValueError( @@ -519,10 +536,10 @@ def _device_put_impl( raise TypeError( f"Argument '{x}' of type {type(x)} is not a valid JAX type") from err - if isinstance(device, Layout): + if isinstance(device, Format): l = device dll = l.device_local_layout - x_dll = x.layout.device_local_layout if hasattr(x, 'layout') else None + x_dll = x.format.device_local_layout if hasattr(x, 'format') else None if dll is None and l.sharding is None: return _device_put_sharding_impl(x, aval, l.sharding, copy) if (not isinstance(l.sharding, Sharding) or @@ -544,8 +561,8 @@ def _device_put_impl( def _batched_device_put_impl( *xs, - devices: Sequence[Device | Sharding | Layout | None], - srcs: Sequence[Device | Sharding | Layout | None], + devices: Sequence[Device | Sharding | Format | None], + srcs: Sequence[Device | Sharding | Format | None], copy_semantics: Sequence[CopySemantics]): ys = [] dsa_indices, dsa_xs, dsa_shardings, dsa_copy_semantics = [], [], [], [] @@ -561,7 +578,7 @@ def _batched_device_put_impl( if dsa_xs: # Batch shard_arg calls. Helps improve efficiency for backends that support # efficient batch transfer. - # device_put handles `Layout` via a different path, so just pass `None` as + # device_put handles `Format` via a different path, so just pass `None` as # the layout here. 
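# Sketch of the case the assert_equal check above guards: an uncommitted host
# value placed with a non-fully-addressable sharding in which every process
# owns devices, so the value must be identical on all processes. The mesh
# construction here is illustrative only.
import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

def _example_replicated_device_put(host_value: np.ndarray) -> jax.Array:
  mesh = Mesh(np.array(jax.devices()), ("x",))
  # Every process appears in the sharding, so `host_value` is checked to be
  # the same on all of them before being transferred.
  return jax.device_put(host_value, NamedSharding(mesh, P()))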
shard_arg_results = pxla.shard_args(dsa_shardings, [None] * len(dsa_xs), dsa_copy_semantics, dsa_xs) @@ -600,7 +617,7 @@ def _device_put_transpose(cts, *_, devices, srcs, copy_semantics): assert cp == CopySemantics.COPY new_copy_semantics.append(CopySemantics.COPY) ys = device_put_p.bind(*args, devices=srcs, srcs=devices, - copy_semantics=new_copy_semantics) + copy_semantics=tuple(new_copy_semantics)) for i, y in zip(indices, ys): results[i] = y return results diff --git a/jax/_src/distributed.py b/jax/_src/distributed.py index af50e2e9e31a..ae1baf8052c0 100644 --- a/jax/_src/distributed.py +++ b/jax/_src/distributed.py @@ -22,7 +22,7 @@ from jax._src import clusters from jax._src import config from jax._src import xla_bridge -from jax._src.lib import xla_extension +from jax._src.lib import _jax logger = logging.getLogger(__name__) @@ -37,10 +37,11 @@ class State: process_id: int = 0 num_processes: int = 1 - service: Any | None = None - client: Any | None = None + service: _jax.DistributedRuntimeService | Any | None = None + client: _jax.DistributedRuntimeClient | Any | None = None preemption_sync_manager: Any | None = None coordinator_address: str | None = None + slice_index: int | None = None def initialize(self, coordinator_address: str | None = None, @@ -53,7 +54,8 @@ def initialize(self, service_heartbeat_interval_seconds: int = 10, service_max_missing_heartbeats: int = 10, client_heartbeat_interval_seconds: int = 10, - client_max_missing_heartbeats: int = 10): + client_max_missing_heartbeats: int = 10, + slice_index: int | None = None): coordinator_address = (coordinator_address or os.environ.get('JAX_COORDINATOR_ADDRESS')) if isinstance(local_device_ids, int): @@ -130,7 +132,7 @@ def initialize(self, logger.info( 'Starting JAX distributed service on %s', coordinator_bind_address ) - self.service = xla_extension.get_distributed_runtime_service( + self.service = _jax.get_distributed_runtime_service( coordinator_bind_address, num_processes, heartbeat_interval=service_heartbeat_interval_seconds, max_missing_heartbeats=service_max_missing_heartbeats) @@ -140,7 +142,7 @@ def initialize(self, if self.client is not None: raise RuntimeError('distributed.initialize should only be called once.') - self.client = xla_extension.get_distributed_runtime_client( + self.client = _jax.get_distributed_runtime_client( coordinator_address, process_id, init_timeout=initialization_timeout, heartbeat_interval=client_heartbeat_interval_seconds, max_missing_heartbeats=client_max_missing_heartbeats, use_compression=True) @@ -149,22 +151,31 @@ def initialize(self, self.initialize_preemption_sync_manager() + if slice_index is None and 'JAX_SLICE_INDEX' in os.environ: + slice_index = int(os.environ.get('JAX_SLICE_INDEX')) # type: ignore + self.slice_index = slice_index + def shutdown(self): + if self.preemption_sync_manager: + # It's important to shut down the preemption sync manager before the + # client because the preemption sync manager depends on the client. 
+ # TODO: Delete hasattr check once 0.6.1 is the minimum jaxlib version + if hasattr(self.preemption_sync_manager, "shutdown"): + self.preemption_sync_manager.shutdown() + self.preemption_sync_manager = None if self.client: self.client.shutdown() self.client = None if self.service: self.service.shutdown() self.service = None - if self.preemption_sync_manager: - self.preemption_sync_manager = None def initialize_preemption_sync_manager(self): if self.preemption_sync_manager is not None: raise RuntimeError( 'Preemption sync manager should only be initialized once.') self.preemption_sync_manager = ( - xla_extension.create_preemption_sync_manager()) + _jax.create_preemption_sync_manager()) self.preemption_sync_manager.initialize(self.client) global_state = State() @@ -175,7 +186,8 @@ def initialize(coordinator_address: str | None = None, local_device_ids: int | Sequence[int] | None = None, cluster_detection_method: str | None = None, initialization_timeout: int = 300, - coordinator_bind_address: str | None = None): + coordinator_bind_address: str | None = None, + slice_index: int | None = None): """Initializes the JAX distributed system. Calling :func:`~jax.distributed.initialize` prepares JAX for execution on @@ -236,6 +248,8 @@ def initialize(coordinator_address: str | None = None, all available addresses on the same port as ``coordinator_address``. On systems that have multiple network interfaces per node it may be insufficient to only have the coordinator service listen on one address/interface. + slice_index: The slice index assigned to this process' local devices. If any process sets ``slice_index``, + then all processes must do so. If ``None`` the slice indices will be chosen automatically. Raises: RuntimeError: If :func:`~jax.distributed.initialize` is called more than once @@ -261,7 +275,8 @@ def initialize(coordinator_address: str | None = None, "This includes any computation, but also calls to jax.devices, jax.device_put, and others.") global_state.initialize(coordinator_address, num_processes, process_id, local_device_ids, cluster_detection_method, - initialization_timeout, coordinator_bind_address) + initialization_timeout, coordinator_bind_address, + slice_index=slice_index) def is_initialized() -> bool: diff --git a/jax/_src/dlpack.py b/jax/_src/dlpack.py index a0b1db608ad0..1f19ac0f45c0 100644 --- a/jax/_src/dlpack.py +++ b/jax/_src/dlpack.py @@ -130,7 +130,7 @@ def to_dlpack(x: Array, stream: int | Any | None = None, ) from None # As new versions are adopted over time, we can maintain some legacy paths - # for compatability mediated through the max_version parameter. + # for compatibility mediated through the max_version parameter. # TODO(micky774): Deprecate default usage of DLPackManagedTensor when XLA # supports DLManagedTensorVersioned (DLPack version 1.0) and repurpose the # current _to_dlpack as a legacy path for (0,5) <= max_version < (1,0). @@ -240,7 +240,7 @@ def from_dlpack(external_array, device transfer or copy was requested. Args: - external_array: An array object that has ``__dlpack__` and + external_array: An array object that has ``__dlpack__`` and ``__dlpack_device__`` methods. device: The (optional) :py:class:`Device`, representing the device on which the returned array should be placed. If given, then the result is diff --git a/jax/_src/dtypes.py b/jax/_src/dtypes.py index 01500c008405..3be98da4a84b 100644 --- a/jax/_src/dtypes.py +++ b/jax/_src/dtypes.py @@ -90,19 +90,18 @@ def type(self) -> type: ... 
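# The definitions below drop the hasattr() guards for these dtypes, which are
# unconditional once a sufficiently recent ml_dtypes is the minimum; a quick
# sketch of direct use through the private dtypes module:
import jax.numpy as jnp
from jax._src import dtypes

x = jnp.zeros((4,), dtype=dtypes.float8_e4m3)
scales = jnp.ones((4,), dtype=dtypes.float8_e8m0fnu)  # exponent-only format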
# fp8 support -# TODO: remove Optional when minimum ml_dtypes version >= 0.5.0 -float8_e3m4: type[np.generic] | None = None -float8_e4m3: type[np.generic] | None = None -float8_e8m0fnu: type[np.generic] | None = None +float8_e3m4: type[np.generic] = ml_dtypes.float8_e3m4 +float8_e4m3: type[np.generic] = ml_dtypes.float8_e4m3 +float8_e8m0fnu: type[np.generic] = ml_dtypes.float8_e8m0fnu float8_e4m3b11fnuz: type[np.generic] = ml_dtypes.float8_e4m3b11fnuz float8_e4m3fn: type[np.generic] = ml_dtypes.float8_e4m3fn float8_e4m3fnuz: type[np.generic] = ml_dtypes.float8_e4m3fnuz float8_e5m2: type[np.generic] = ml_dtypes.float8_e5m2 float8_e5m2fnuz: type[np.generic] = ml_dtypes.float8_e5m2fnuz -_float8_e3m4_dtype: np.dtype | None = None -_float8_e4m3_dtype: np.dtype | None = None -_float8_e8m0fnu_dtype: np.dtype | None = None +_float8_e3m4_dtype: np.dtype = np.dtype(float8_e3m4) +_float8_e4m3_dtype: np.dtype = np.dtype(float8_e4m3) +_float8_e8m0fnu_dtype: np.dtype = np.dtype(float8_e8m0fnu) _float8_e4m3b11fnuz_dtype: np.dtype = np.dtype(float8_e4m3b11fnuz) _float8_e4m3fn_dtype: np.dtype = np.dtype(float8_e4m3fn) _float8_e4m3fnuz_dtype: np.dtype = np.dtype(float8_e4m3fnuz) @@ -111,9 +110,9 @@ def type(self) -> type: ... # fp4 support # TODO: remove Optional when minimum ml_dtypes version >= 0.5.0 -float4_e2m1fn: type[np.generic] | None = None +float4_e2m1fn: type[np.generic] = ml_dtypes.float4_e2m1fn -_float4_e2m1fn_dtype: np.dtype | None = None +_float4_e2m1fn_dtype: np.dtype = np.dtype(float4_e2m1fn) def supports_inf(dtype: DTypeLike) -> bool: """Return true if the dtype supports infinity, else return False.""" @@ -127,6 +126,10 @@ def supports_inf(dtype: DTypeLike) -> bool: _bfloat16_dtype: np.dtype = np.dtype(bfloat16) _custom_float_scalar_types = [ + float4_e2m1fn, + float8_e3m4, + float8_e4m3, + float8_e8m0fnu, float8_e4m3b11fnuz, float8_e4m3fn, float8_e4m3fnuz, @@ -135,6 +138,10 @@ def supports_inf(dtype: DTypeLike) -> bool: bfloat16, ] _custom_float_dtypes = [ + _float4_e2m1fn_dtype, + _float8_e3m4_dtype, + _float8_e4m3_dtype, + _float8_e8m0fnu_dtype, _float8_e4m3b11fnuz_dtype, _float8_e4m3fn_dtype, _float8_e4m3fnuz_dtype, @@ -143,6 +150,9 @@ def supports_inf(dtype: DTypeLike) -> bool: _bfloat16_dtype, ] _float8_dtypes = [ + _float8_e3m4_dtype, + _float8_e4m3_dtype, + _float8_e8m0fnu_dtype, _float8_e4m3b11fnuz_dtype, _float8_e4m3fn_dtype, _float8_e4m3fnuz_dtype, @@ -150,58 +160,28 @@ def supports_inf(dtype: DTypeLike) -> bool: _float8_e5m2fnuz_dtype, ] -_float4_dtypes: list[np.dtype] = [] - -# TODO: remove the if statements below when minimum ml_dtypes version >= 0.5.0 -if hasattr(ml_dtypes, "float8_e4m3"): - float8_e4m3 = ml_dtypes.float8_e4m3 - _float8_e4m3_dtype = np.dtype(float8_e4m3) - _custom_float_scalar_types.insert(0, float8_e4m3) # type: ignore[arg-type] - _custom_float_dtypes.insert(0, _float8_e4m3_dtype) - _float8_dtypes.insert(0, _float8_e4m3_dtype) -if hasattr(ml_dtypes, "float8_e3m4"): - float8_e3m4 = ml_dtypes.float8_e3m4 - _float8_e3m4_dtype = np.dtype(float8_e3m4) - _custom_float_scalar_types.insert(0, float8_e3m4) # type: ignore[arg-type] - _custom_float_dtypes.insert(0, _float8_e3m4_dtype) - _float8_dtypes.insert(0, _float8_e3m4_dtype) -if hasattr(ml_dtypes, "float8_e8m0fnu"): - float8_e8m0fnu = ml_dtypes.float8_e8m0fnu - _float8_e8m0fnu_dtype = np.dtype(float8_e8m0fnu) - _custom_float_scalar_types.insert(0, float8_e8m0fnu) # type: ignore[arg-type] - _custom_float_dtypes.insert(0, _float8_e8m0fnu_dtype) - _float8_dtypes.insert(0, _float8_e8m0fnu_dtype) -if 
hasattr(ml_dtypes, "float4_e2m1fn"): - float4_e2m1fn = ml_dtypes.float4_e2m1fn - _float4_e2m1fn_dtype = np.dtype(float4_e2m1fn) - _custom_float_scalar_types.insert(0, float4_e2m1fn) # type: ignore[arg-type] - _custom_float_dtypes.insert(0, _float4_e2m1fn_dtype) - _float4_dtypes.insert(0, _float4_e2m1fn_dtype) - -# 2-bit integer support -int2: type[np.generic] | None = None -uint2: type[np.generic] | None = None - -_int2_dtype: np.dtype | None = None -_uint2_dtype: np.dtype | None = None - -_intn_dtypes = [] - -# Remove the condition once the minimum ml_dtypes version required by JAX -# contains https://github.com/jax-ml/ml_dtypes/pull/154. -if hasattr(ml_dtypes, 'int2'): - int2 = ml_dtypes.int2 - uint2 = ml_dtypes.uint2 - _int2_dtype = np.dtype(int2) - _uint2_dtype = np.dtype(uint2) - _intn_dtypes.extend([_int2_dtype, _uint2_dtype]) +_float4_dtypes: list[np.dtype] = [ + _float4_e2m1fn_dtype, +] + +int2: type[np.generic] = ml_dtypes.int2 +uint2: type[np.generic] = ml_dtypes.uint2 + +_int2_dtype: np.dtype = np.dtype(int2) +_uint2_dtype: np.dtype = np.dtype(uint2) # 4-bit integer support int4: type[np.generic] = ml_dtypes.int4 uint4: type[np.generic] = ml_dtypes.uint4 _int4_dtype = np.dtype(int4) _uint4_dtype = np.dtype(uint4) -_intn_dtypes.extend([_int4_dtype, _uint4_dtype]) + +_intn_dtypes = [ + _int2_dtype, + _uint2_dtype, + _int4_dtype, + _uint4_dtype, +] # Default types. bool_ = np.bool_ @@ -299,6 +279,12 @@ def to_inexact_dtype(dtype: DTypeLike) -> DType: return _dtype_to_inexact.get(dtype_, dtype_) +def to_floating_dtype(dtype: DTypeLike) -> DType: + """Promotes a dtype to a non-complex floating dtype.""" + dtype_ = np.dtype(dtype) + return finfo(_dtype_to_inexact.get(dtype_, dtype_)).dtype + + def to_complex_dtype(dtype: DTypeLike) -> DType: ftype = to_inexact_dtype(dtype) if ftype in [np.dtype('float64'), np.dtype('complex128')]: @@ -387,7 +373,8 @@ def _scalar_type_to_dtype(typ: type, value: Any = None) -> DType: """ dtype = canonicalize_dtype(python_scalar_dtypes[typ]) if typ is int and value is not None: - if value < np.iinfo(dtype).min or value > np.iinfo(dtype).max: + iinfo = np.iinfo(dtype) + if value < iinfo.min or value > iinfo.max: raise OverflowError(f"Python int {value} too large to convert to {dtype}") return dtype @@ -472,9 +459,9 @@ def _issubdtype_cached(a: type | np.dtype | ExtendedDType, # to the normal scalar type hierarchy. 
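# Sketch of the subtype relationships the branches below encode, now that
# int2/uint2 are always available (the private dtypes module is used here
# only for brevity):
import jax.numpy as jnp
from jax._src import dtypes

assert jnp.issubdtype(dtypes.int2, jnp.signedinteger)
assert jnp.issubdtype(dtypes.uint2, jnp.unsignedinteger)
assert not jnp.issubdtype(dtypes.uint2, jnp.signedinteger)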
if a_sctype in _custom_float_scalar_types: return b_sctype in {a_sctype, np.floating, np.inexact, np.number, np.generic} - if (int2 is not None and a_sctype == int2) or a_sctype == int4: + if a_sctype in [int2, int4]: return b_sctype in {a_sctype, np.signedinteger, np.integer, np.number, np.generic} - if (uint2 is not None and a_sctype == uint2) or a_sctype == uint4: + if a_sctype in [uint2, uint4]: return b_sctype in {a_sctype, np.unsignedinteger, np.integer, np.number, np.generic} # Otherwise, fall back to numpy.issubdtype @@ -491,6 +478,7 @@ def _issubdtype_cached(a: type | np.dtype | ExtendedDType, _unsigned_types: list[JAXType] _int_types: list[JAXType] _unsigned_types = [ + np.dtype(uint2), np.dtype(uint4), np.dtype('uint8'), np.dtype('uint16'), @@ -498,6 +486,7 @@ def _issubdtype_cached(a: type | np.dtype | ExtendedDType, np.dtype('uint64'), ] _signed_types = [ + np.dtype(int2), np.dtype(int4), np.dtype('int8'), np.dtype('int16'), @@ -505,11 +494,6 @@ def _issubdtype_cached(a: type | np.dtype | ExtendedDType, np.dtype('int64'), ] -if _int2_dtype is not None: - _signed_types.insert(0, _int2_dtype) -if _uint2_dtype is not None: - _unsigned_types.insert(0, _uint2_dtype) - _int_types = _unsigned_types + _signed_types _float_types: list[JAXType] = [ @@ -622,11 +606,7 @@ def _type_promotion_lattice(jax_numpy_dtype_promotion: str) -> dict[JAXType, lis This DAG maps each type to its immediately higher type on the lattice. """ b1, = _bool_types - if _int2_dtype is not None: - assert _uint2_dtype is not None - _uint2, uint4, u1, u2, u4, u8, _int2, int4, i1, i2, i4, i8 = _int_types - else: - uint4, u1, u2, u4, u8, int4, i1, i2, i4, i8 = _int_types + uint2, uint4, u1, u2, u4, u8, int2, int4, i1, i2, i4, i8 = _int_types *f1_types, bf, f2, f4, f8 = _float_types c4, c8 = _complex_types i_, f_, c_ = _weak_types @@ -634,19 +614,13 @@ def _type_promotion_lattice(jax_numpy_dtype_promotion: str) -> dict[JAXType, lis out: dict[JAXType, list[JAXType]] out = { b1: [i_], - i_: [u1, uint4, i1, int4], - uint4: [], u1: [i2, u2], u2: [i4, u4], u4: [i8, u8], u8: [f_], - int4: [], i1: [i2], i2: [i4], i4: [i8], i8: [f_], + i_: [u1, uint2, uint4, i1, int2, int4], + uint2: [], uint4: [], u1: [i2, u2], u2: [i4, u4], u4: [i8, u8], u8: [f_], + int2: [], int4: [], i1: [i2], i2: [i4], i4: [i8], i8: [f_], f_: [*f1_types, bf, f2, c_], **{t: [] for t in f1_types}, bf: [f4], f2: [f4], f4: [f8, c4], f8: [c8], c_: [c4], c4: [c8], c8: [], } - if _int2_dtype is not None: - out[i_].append(_int2_dtype) - out[_int2_dtype] = [] - if _uint2_dtype is not None: - out[i_].append(_uint2_dtype) - out[_uint2_dtype] = [] return out elif jax_numpy_dtype_promotion == 'strict': return { @@ -1010,6 +984,7 @@ class PrimalTangentDType(ExtendedDType): return PrimalTangentDType() +@functools.cache def short_dtype_name(dtype) -> str: if isinstance(dtype, ExtendedDType): return str(dtype) diff --git a/jax/_src/effects.py b/jax/_src/effects.py index 36528c5feae5..fb79c542e78b 100644 --- a/jax/_src/effects.py +++ b/jax/_src/effects.py @@ -47,7 +47,7 @@ for each thread the `RuntimeToken` returned by the last dispatched computation. For more details, see the design note: -https://jax.readthedocs.io/en/latest/jep/10657-sequencing-effects.html. +https://docs.jax.dev/en/latest/jep/10657-sequencing-effects.html. 
""" from __future__ import annotations @@ -118,3 +118,5 @@ def filter_not_in(self, effects: Iterable[Effect]) -> list[Effect]: control_flow_allowed_effects: EffectTypeSet = EffectTypeSet() custom_derivatives_allowed_effects: EffectTypeSet = EffectTypeSet() remat_allowed_effects: EffectTypeSet = EffectTypeSet() + +partial_eval_kept_effects: EffectTypeSet = EffectTypeSet() diff --git a/jax/_src/environment_info.py b/jax/_src/environment_info.py index 4abfdeaa0f14..e11d1b4f4b31 100644 --- a/jax/_src/environment_info.py +++ b/jax/_src/environment_info.py @@ -14,6 +14,7 @@ from __future__ import annotations +import os import platform import subprocess import sys @@ -48,8 +49,10 @@ def print_environment_info(return_string: bool = False) -> str | None: python: {python_version} device info: {xb.devices()[0].device_kind}-{xb.device_count()}, {xb.local_device_count()} local devices" process_count: {xb.process_count()} - platform: {platform.uname()} -""") + platform: {platform.uname()}""") + for key, value in os.environ.items(): + if key.startswith("JAX_"): + info += f"\n{key}={value}" nvidia_smi = try_nvidia_smi() if nvidia_smi: info += '\n\n$ nvidia-smi\n' + nvidia_smi diff --git a/jax/_src/error_check.py b/jax/_src/error_check.py index 60dc2f76a5b2..339407fa295f 100644 --- a/jax/_src/error_check.py +++ b/jax/_src/error_check.py @@ -14,27 +14,30 @@ from __future__ import annotations +import dataclasses from functools import partial +import json import threading +import traceback as tb_lib +from types import TracebackType +import warnings import jax from jax._src import core from jax._src import source_info_util from jax._src import traceback_util import jax._src.mesh as mesh_lib -from jax.experimental.shard_map import shard_map +from jax._src import shard_map +import jax.export import jax.numpy as jnp from jax.sharding import NamedSharding, PartitionSpec as P -Traceback = source_info_util.Traceback - - traceback_util.register_exclusion(__file__) class JaxValueError(ValueError): - """Exception raised for failed runtime error checks in JAX.""" + """Exception raised for runtime errors detected within JAX computations.""" #: The default error code for no error. @@ -44,8 +47,9 @@ class JaxValueError(ValueError): _NO_ERROR = jnp.iinfo(jnp.uint32).max -_error_list_lock = threading.Lock() -_error_list: list[tuple[str, Traceback]] = [] # (error_message, traceback) pair +_error_list_lock = threading.RLock() +# (error_message, traceback) pairs. Traceback is `str` when imported from AOT. +_error_list: list[tuple[str, TracebackType | str]] = [] class _ErrorStorage(threading.local): @@ -58,38 +62,42 @@ def __init__(self): def _initialize_error_code_ref() -> None: - """Initialize error_code_ref in the current thread. + """Initialize the error code ref in the current thread. - The size of the error code array is determined by the mesh in the context. In - single-device environment, the array is a scalar. In multi-device - environment, the array has the same shape as the mesh. + The shape and size of the error code array depend on the mesh in the context. + In single-device environments, the array is a scalar. In multi-device + environments, its shape and size match those of the mesh. """ - with core.eval_context(): - # Get mesh from the context. - mesh = mesh_lib.get_concrete_mesh() - - if mesh is None: # single-device case. - error_code = jnp.uint32(_NO_ERROR) - - else: # multi-device case. 
- sharding = NamedSharding(mesh, P(*mesh.axis_names)) - error_code = jnp.full( - mesh.axis_sizes, - jnp.uint32(_NO_ERROR), - device=sharding, - ) + # Get mesh from the context. + mesh = mesh_lib.get_concrete_mesh() + + if mesh is None: # single-device case. + error_code = jnp.uint32(_NO_ERROR) + + else: # multi-device case. + sharding = NamedSharding(mesh, P(*mesh.axis_names)) + error_code = jnp.full( + mesh.axis_sizes, + jnp.uint32(_NO_ERROR), + device=sharding, + ) - _error_storage.ref = core.mutable_array(error_code) + _error_storage.ref = core.mutable_array(error_code) class error_checking_context: - """Redefine the error checking state based on the mesh in the context. + """Redefine the internal error state based on the mesh in the context. - This context manager should be used when starting a multi-device - computation, and whenever the mesh is changed. + When using JAX in multi-device environments in explicit mode, error tracking + needs to be properly aligned with the device mesh. This context manager + ensures that the internal error state is correctly initialized based on the + current mesh configuration. - When exiting the context, the error checking state will be reset to the - original state. + This context manager should be used when starting a multi-device computation, + or when switching between different device meshes. + + On entering the context, it initializes a new error state based on the mesh in + the context. On exiting the context, it restores the previous error state. """ __slots__ = ("old_ref",) @@ -99,7 +107,8 @@ def __init__(self): def __enter__(self): self.old_ref = _error_storage.ref - _initialize_error_code_ref() + with core.eval_context(): + _initialize_error_code_ref() return self def __exit__(self, exc_type, exc_value, traceback): @@ -107,19 +116,46 @@ def __exit__(self, exc_type, exc_value, traceback): def set_error_if(pred: jax.Array, /, msg: str) -> None: - """Set error if any element of pred is true. - - If the error is already set, the new error will be ignored. It will not - override the existing error. - - In auto mode, this function does not work under jit. + """Set the internal error state if any element of `pred` is `True`. + + This function is used inside JAX computations to detect runtime errors without + immediately halting execution. When this function is traced (e.g., inside + :func:`jax.jit`), the corresponding error message and its traceback are + recorded. At execution time, if `pred` contains any `True` values, the error + state is set, but execution continues without interruption. The recorded error + can later be raised using :func:`raise_if_error`. + + If the error state has already been set, subsequent errors are ignored and + will not override the existing error. + + For multi-device environments, in explicit mode, users must call + :func:`error_checking_context` to initialize a new error tracking state that + matches the device mesh. In auto mode, implicit cross-device communication may + occur inside this function, which could impact performance. A warning is + issued in such cases. + + When exporting a function with `jax.export`, error checking must be explicitly + wrapped using :func:`wrap_for_export` before export and + :func:`unwrap_from_import` after import. + + Args: + pred: A JAX boolean array. If any element of `pred` is `True`, the internal + error state will be set. + msg: The corresponding error message to be raised later. 
""" if _error_storage.ref is None: - _initialize_error_code_ref() + with core.eval_context(): + _initialize_error_code_ref() assert _error_storage.ref is not None + # Get the traceback. traceback = source_info_util.current().traceback assert traceback is not None + traceback = traceback.as_python_traceback() + assert isinstance(traceback, TracebackType) + traceback = traceback_util.filter_traceback(traceback) + assert isinstance(traceback, TracebackType) + with _error_list_lock: new_error_code = jnp.uint32(len(_error_list)) _error_list.append((msg, traceback)) @@ -127,41 +163,55 @@ def set_error_if(pred: jax.Array, /, msg: str) -> None: out_sharding = core.typeof(_error_storage.ref).sharding in_sharding: NamedSharding = core.typeof(pred).sharding - if out_sharding.mesh.shape_tuple == (): # single-device case. + # Reduce `pred`. + if all(dim is None for dim in out_sharding.spec): # single-device case. pred = pred.any() else: # multi-device case. has_auto_axes = mesh_lib.AxisType.Auto in in_sharding.mesh.axis_types - if has_auto_axes: - raise NotImplementedError( - "Error checking in auto mode is not supported yet. Please use" - " explicit mode." - ) - if out_sharding.mesh != in_sharding.mesh: - raise ValueError( - "The error code state and the predicate must be on the same mesh, " - f"but got {out_sharding.mesh} and {in_sharding.mesh} respectively. " - "Please use `with error_checking_context()` to redefine the error " - "code state based on the mesh." + if has_auto_axes: # auto mode. + warnings.warn( + "When at least one mesh axis of `pred` is in auto mode, calling" + " `set_error_if` will cause implicit communication between devices." + " To avoid this, consider converting the mesh axis in auto mode to" + " explicit mode.", + RuntimeWarning, ) - pred = shard_map( - partial(jnp.any, keepdims=True), - mesh=out_sharding.mesh, - in_specs=in_sharding.spec, - out_specs=out_sharding.spec, - )(pred) # perform per-device reduction + pred = pred.any() # reduce to a single scalar + else: # explicit mode. + if out_sharding.mesh != in_sharding.mesh: + raise ValueError( + "The error code state and the predicate must be on the same mesh, " + f"but got {out_sharding.mesh} and {in_sharding.mesh} respectively. " + "Please use `with error_checking_context()` to redefine the error " + "code state based on the mesh." + ) + pred = shard_map.shard_map( + partial(jnp.any, keepdims=True), + mesh=out_sharding.mesh, + in_specs=in_sharding.spec, + out_specs=out_sharding.spec, + )(pred) # perform per-device reduction error_code = _error_storage.ref[...] - should_update = jnp.logical_and(pred, error_code == jnp.uint32(_NO_ERROR)) + should_update = jnp.logical_and(error_code == jnp.uint32(_NO_ERROR), pred) error_code = jnp.where(should_update, new_error_code, error_code) # TODO(ayx): support vmap and shard_map. _error_storage.ref[...] = error_code def raise_if_error() -> None: - """Raise error if an error is set. + """Raise an exception if the internal error state is set. + + This function should be called after a computation completes to check for any + errors that were marked during execution via `set_error_if()`. If an error + exists, it raises a `JaxValueError` with the corresponding error message. + + This function should not be called inside a traced function (e.g., inside + :func:`jax.jit`). Doing so will raise a `ValueError`. - This function should be called after the computation is finished. It should - not be called within a traced context, such as within a jitted function." 
+ Raises: + JaxValueError: If the internal error state is set. + ValueError: If called within a traced JAX function. """ if _error_storage.ref is None: # if not initialized, do nothing return @@ -180,8 +230,136 @@ def raise_if_error() -> None: device=_error_storage.ref.sharding, ) # clear the error code - msg, traceback = _error_list[error_code] - exc = JaxValueError(msg) - traceback = traceback.as_python_traceback() - filtered_traceback = traceback_util.filter_traceback(traceback) - raise exc.with_traceback(filtered_traceback) + with _error_list_lock: + msg, traceback = _error_list[error_code] + if isinstance(traceback, str): # from imported AOT functions + exc = JaxValueError( + f"{msg}\nThe original traceback is shown below:\n{traceback}" + ) + raise exc + else: + exc = JaxValueError(msg) + raise exc.with_traceback(traceback) + + +@dataclasses.dataclass(frozen=True) +class _ErrorClass: + """A class to store error information for AOT compilation. + + This class is used internally by the wrapper functions `wrap_for_export` and + `unwrap_from_import` to encapsulate error-related data within an exported + function. + + Attributes: + error_code (jax.Array): A JAX array representing the final error state of + the function to be exported. This value is local to the wrapper function. + error_list (list[tuple[str, str]]): A list of `(error_message, traceback)` + pairs containing error messages and corresponding stack traces. This error + list is local to the wrapper function, and does not contain pairs of error + information from other functions. + """ + + error_code: jax.Array + error_list: list[tuple[str, str]] + + +jax.tree_util.register_dataclass( + _ErrorClass, data_fields=("error_code",), meta_fields=("error_list",) +) +jax.export.register_pytree_node_serialization( + _ErrorClass, + serialized_name=f"{_ErrorClass.__module__}.{_ErrorClass.__name__}", + serialize_auxdata=lambda x: json.dumps(x, ensure_ascii=False).encode( + "utf-8" + ), + deserialize_auxdata=lambda x: json.loads(x.decode("utf-8")), +) + + +def _traceback_to_str(traceback: TracebackType) -> str: + """Convert a traceback to a string for export.""" + return "".join(tb_lib.format_list(tb_lib.extract_tb(traceback))).rstrip("\n") + + +def wrap_for_export(f): + """Wrap a function with error checking to make it compatible with AOT mode. + + Error checking relies on global state, which cannot be serialized across + processes. This wrapper ensures that the error state remains within the + function scope, making it possible to export the function and later import in + other processes. + + When the function is later imported, it must be wrapped with + :func:`unwrap_from_import` to integrate the error checking mechanism of the + imported function into the global error checking mechanism of the current + process. + + This function should only be applied once to a function; wrapping the same + function multiple times is unnecessary. + """ + + def inner(*args, **kwargs): + global _error_list + + # 1. Save the old state and initialize a new state. + with core.eval_context(): + old_ref = _error_storage.ref + _initialize_error_code_ref() + with _error_list_lock: + old_error_list, _error_list = _error_list, [] + + # 2. Trace the function. + out = f(*args, **kwargs) + error_code = _error_storage.ref[...].min() + + # 3. Restore the old state. 
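# An end-to-end sketch of the error-checking flow these wrappers build on
# (module path jax._src.error_check as in this file; for AOT, the jitted
# function would additionally be wrapped with wrap_for_export before export
# and with unwrap_from_import after import):
import jax
import jax.numpy as jnp
from jax._src import error_check

@jax.jit
def safe_log(x):
  error_check.set_error_if(x <= 0, "safe_log: non-positive input")
  return jnp.log(x)

out = safe_log(jnp.array([1.0, -1.0]))  # execution is not interrupted here
error_check.raise_if_error()            # raises JaxValueError for the -1.0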
+ _error_list, new_error_list = old_error_list, _error_list + with core.eval_context(): + _error_storage.ref = old_ref + + new_error_list = [ + (msg, _traceback_to_str(traceback)) for msg, traceback in new_error_list + ] + return out, _ErrorClass(error_code, new_error_list) + + return inner + + +def unwrap_from_import(f): + """Unwrap a function after AOT import to restore error checking. + + When an AOT-exported function is imported in a new process, its error state is + separate from the global error state of the current process. This wrapper + ensures that errors detected during execution are correctly integrated into + the global error checking mechanism of the current process. + + This function should only be applied to functions that were previously wrapped + with :func:`wrap_for_export` before export. + """ + if _error_storage.ref is None: + with core.eval_context(): + _initialize_error_code_ref() + assert _error_storage.ref is not None + + def inner(*args, **kwargs): + out, error_class = f(*args, **kwargs) + new_error_code, error_list = error_class.error_code, error_class.error_list + + # Update the global error list. + with _error_list_lock: + offset = len(_error_list) + _error_list.extend(error_list) + + # Update the global error code array. + error_code = _error_storage.ref[...] + should_update = jnp.logical_and( + error_code == jnp.uint32(_NO_ERROR), + new_error_code != jnp.uint32(_NO_ERROR), + ) + error_code = jnp.where(should_update, new_error_code + offset, error_code) + # TODO(ayx): support vmap and shard_map. + _error_storage.ref[...] = error_code + + return out + + return inner diff --git a/jax/_src/errors.py b/jax/_src/errors.py index 6540fd1f5d41..a548714869ab 100644 --- a/jax/_src/errors.py +++ b/jax/_src/errors.py @@ -21,7 +21,7 @@ class _JAXErrorMixin: """Mixin for JAX-specific errors""" - _error_page = 'https://jax.readthedocs.io/en/latest/errors.html' + _error_page = 'https://docs.jax.dev/en/latest/errors.html' _module_name = "jax.errors" def __init__(self, message: str): @@ -306,7 +306,7 @@ class TracerArrayConversionError(JAXTypeError): and concrete vs. abstract values, you may want to read :ref:`faq-different-kinds-of-jax-values`. - .. _External Callbacks: https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html + .. _External Callbacks: https://docs.jax.dev/en/latest/notebooks/external_callbacks.html """ def __init__(self, tracer: core.Tracer): super().__init__( @@ -503,7 +503,7 @@ class TracerBoolConversionError(ConcretizationTypeError): In this case, the error occurs because Python's built-in ``min`` function is not compatible with JAX transforms. This can be fixed by replacing it with - ``jnp.minumum``: + ``jnp.minimum``: >>> @jit ... def func(x): @@ -530,7 +530,7 @@ class UnexpectedTracerError(JAXTypeError): function ``f`` that stores, in some scope outside of ``f``, a reference to an intermediate value, that value is considered to have been leaked. Leaking values is a side effect. (Read more about avoiding side effects in - `Pure Functions `_) + `Pure Functions `_) JAX detects leaks when you then use the leaked value in another operation later on, at which point it raises an ``UnexpectedTracerError``. @@ -678,6 +678,5 @@ class KeyReuseError(JAXTypeError): This sort of key reuse is problematic because the JAX PRNG is stateless, and keys must be manually split; For more information on this see `the Pseudorandom Numbers - tutorial `_. + tutorial `_. 
""" - pass diff --git a/jax/_src/export/_export.py b/jax/_src/export/_export.py index afae3d9bcdc2..16ffedc8dd09 100644 --- a/jax/_src/export/_export.py +++ b/jax/_src/export/_export.py @@ -11,9 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""JAX APIs for exporting JAX functions for interoperation. - -""" +"""JAX APIs for exporting JAX functions for interoperation.""" from __future__ import annotations @@ -27,15 +25,14 @@ import re from typing import Any, Protocol, TypeVar, Union, cast -from absl import logging +import logging import numpy as np -import jax -from jax import sharding - from jax._src import ad_util +from jax._src import api from jax._src import config from jax._src import core +from jax._src import custom_derivatives from jax._src import dispatch from jax._src import dtypes from jax._src import effects @@ -43,20 +40,25 @@ from jax._src.interpreters import mlir from jax._src.interpreters import pxla from jax._src.lib import xla_client -from jax._src.lib import xla_extension, xla_extension_version +from jax._src.lib import _jax from jax._src.lib.mlir import ir, passmanager from jax._src.lib.mlir.dialects import hlo from jax._src.lib.mlir.dialects import func as func_dialect +from jax._src import mesh from jax._src import pjit +from jax._src import sharding from jax._src import sharding_impls from jax._src import source_info_util from jax._src import stages from jax._src import tree_util +from jax._src import typing from jax._src import util from jax._src import xla_bridge as xb from jax._src.export import shape_poly +logger = logging.getLogger(__name__) + map = util.safe_map zip = util.safe_zip @@ -67,7 +69,7 @@ HloSharding = xla_client.HloSharding # The minimum and maximum supported calling convention version. -# See https://jax.readthedocs.io/en/latest/export/export.html#export-calling-convention-version +# See https://docs.jax.dev/en/latest/export/export.html#export-calling-convention-version minimum_supported_calling_convention_version = 9 maximum_supported_calling_convention_version = 9 @@ -153,16 +155,16 @@ class Exported: platforms: a tuple containing the platforms for which the function should be exported. The set of platforms in JAX is open-ended; users can add platforms. JAX built-in platforms are: 'tpu', 'cpu', 'cuda', 'rocm'. - See https://jax.readthedocs.io/en/latest/export/export.html#cross-platform-and-multi-platform-export. + See https://docs.jax.dev/en/latest/export/export.html#cross-platform-and-multi-platform-export. ordered_effects: the ordered effects present in the serialized module. - This is present from serialization version 9. See https://jax.readthedocs.io/en/latest/export/export.html#module-calling-convention + This is present from serialization version 9. See https://docs.jax.dev/en/latest/export/export.html#module-calling-convention for the calling convention in presence of ordered effects. unordered_effects: the unordered effects present in the serialized module. This is present from serialization version 9. mlir_module_serialized: the serialized lowered VHLO module. calling_convention_version: a version number for the calling convention of the exported module. - See more versioning details at https://jax.readthedocs.io/en/latest/export/export.html#calling-convention-versions. + See more versioning details at https://docs.jax.dev/en/latest/export/export.html#calling-convention-versions. 
module_kept_var_idx: the sorted indices of the arguments among `in_avals` that must be passed to the module. The other arguments have been dropped because they are not used. @@ -181,7 +183,7 @@ class Exported: for each primal output. It returns a tuple with the cotangents corresponding to the flattened primal inputs. - See a [description of the calling convention for the `mlir_module`](https://jax.readthedocs.io/en/latest/export/export.html#module-calling-convention). + See a [description of the calling convention for the `mlir_module`](https://docs.jax.dev/en/latest/export/export.html#module-calling-convention). """ fun_name: str in_tree: tree_util.PyTreeDef @@ -215,7 +217,7 @@ def __str__(self): def in_shardings_jax( self, - mesh: sharding.Mesh) -> Sequence[sharding.Sharding | None]: + mesh: mesh.Mesh) -> Sequence[sharding.Sharding | None]: """Creates Shardings corresponding to self.in_shardings_hlo. The Exported object stores `in_shardings_hlo` as HloShardings, which are @@ -225,7 +227,7 @@ def in_shardings_jax( Example usage: - >>> from jax import export + >>> from jax import export, sharding >>> # Prepare the exported object: >>> exp_mesh = sharding.Mesh(jax.devices(), ("a",)) >>> exp = export.export(jax.jit(lambda x: jax.numpy.add(x, x), @@ -255,7 +257,7 @@ def in_shardings_jax( def out_shardings_jax( self, - mesh: sharding.Mesh) -> Sequence[sharding.Sharding | None]: + mesh: mesh.Mesh) -> Sequence[sharding.Sharding | None]: """Creates Shardings corresponding to `self.out_shardings_hlo`. See documentation for in_shardings_jax. @@ -306,7 +308,7 @@ def call(self, *args, **kwargs): The invocation supports reverse-mode AD, and all the features supported by exporting: shape polymorphism, multi-platform, device polymorphism. - See the examples in the [JAX export documentation](https://jax.readthedocs.io/en/latest/export/export.html). + See the examples in the [JAX export documentation](https://docs.jax.dev/en/latest/export/export.html). """ return call_exported(self)(*args, **kwargs) @@ -512,13 +514,13 @@ def default_export_platform() -> str: One of: `tpu`, `cpu`, `cuda`, `rocm`. """ # Canonicalize to turn 'gpu' into 'cuda' or 'rocm' - return xb.canonicalize_platform(jax.default_backend()) + return xb.canonicalize_platform(xb.default_backend()) default_lowering_platform = default_export_platform def shape_and_dtype_jax_array(a) -> tuple[Sequence[int | None], DType]: """Returns the shape and dtype of a jax.Array or a j""" - if isinstance(a, jax.ShapeDtypeStruct): + if isinstance(a, api.ShapeDtypeStruct): return a.shape, a.dtype aval = core.get_aval(a) return aval.shape, aval.dtype @@ -529,6 +531,7 @@ def export( *, platforms: Sequence[str] | None = None, disabled_checks: Sequence[DisabledSafetyCheck] = (), + _override_lowering_rules: Sequence[tuple[Any, Any]] | None = None ) -> Callable[..., Exported]: """Exports a JAX function for persistent serialization. @@ -540,7 +543,14 @@ def export( the exported code takes an argument specifying the platform. If None, then use the default JAX backend. The calling convention for multiple platforms is explained at - https://jax.readthedocs.io/en/latest/export/export.html#module-calling-convention. + https://docs.jax.dev/en/latest/export/export.html#module-calling-convention. + _override_lowering_rules: an optional sequence of custom lowering rules + for some JAX primitives. Each element of the sequence is a pair + of a JAX primitive and a lowering function. 
Defining lowering rules + is an advanced feature using JAX internal APIs, which are subject + to change. Furthermore, the responsibility for the stability of the + MLIR emitted through these custom lowering rules, rests with the user + of these rules. disabled_checks: the safety checks to disable. See documentation for of `jax.export.DisabledSafetyCheck`. @@ -568,7 +578,8 @@ def export( Array([0.09983342, 0.19866933, 0.29552022, 0.38941833], dtype=float32) """ return _export_internal(fun_jit, platforms=platforms, - disabled_checks=disabled_checks) + disabled_checks=disabled_checks, + override_lowering_rules=_override_lowering_rules) # TODO(necula): remove this once we improve the integration with jax2tf. @@ -577,13 +588,14 @@ def _export_internal( *, platforms: Sequence[str] | None = None, disabled_checks: Sequence[DisabledSafetyCheck] = (), - _device_assignment_for_internal_jax2tf_use_only = None, + _device_assignment_for_internal_jax2tf_use_only=None, + override_lowering_rules=None, ) -> Callable[..., Exported]: """Exports native serialization for a JAX function. Note: this function exists only for internal usage by jax2tf. Use `jax.export` instead. - See https://jax.readthedocs.io/en/latest/export/export.html + See https://docs.jax.dev/en/latest/export/export.html See docstring of `export` for more details. """ @@ -604,6 +616,7 @@ def do_export(*args_specs, **kwargs_specs) -> Exported: lowered = traced.lower( lowering_platforms=actual_lowering_platforms, _private_parameters=mlir.LoweringParameters( + override_lowering_rules=override_lowering_rules, for_export=True, export_ignore_forward_compatibility=config.export_ignore_forward_compatibility.value)) return _export_lowered( @@ -658,6 +671,12 @@ def _export_lowered( # For pmap module_kept_var_idx = tuple(range(len(args_avals_flat))) shape_poly_state = lowering.compile_args["shape_poly_state"] + + # Make a copy of mlir module as we should not mutate it + # because it may be cached + context = mlir.make_ir_context() + with context, ir.Location.unknown(context): + mlir_module = ir.Module.parse(mlir.module_to_bytecode(mlir_module)) if (not all(core.is_constant_shape(a.shape) for a in args_avals_flat) or lowering.compile_args.get("ordered_effects", [])): mlir_module = _wrap_main_func( @@ -674,12 +693,10 @@ def _export_lowered( # Shardy was used during lowering if we can find the Shardy mesh in the # module. Note that the mesh should have been lifted by the # `sdy-lift-inlined-meshes` pass in mlir.py. - shardy_enabled = False - if xla_extension_version >= 319: - shardy_enabled = xla_extension.sdy.lowered_with_shardy( - mlir.module_to_bytecode(mlir_module)) + shardy_enabled = _jax.sdy.lowered_with_shardy( + mlir.module_to_bytecode(mlir_module)) - mlir_module_serialized = _module_to_bytecode(mlir_module, shardy_enabled) + mlir_module_serialized = _module_to_bytecode(mlir_module) # Figure out the result types and shapes if "global_out_avals" in lowering.compile_args: @@ -691,16 +708,15 @@ def _export_lowered( out_avals_flat = lowered.compile_args["out_avals"] # type: ignore # Log and then check the module. 
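# Usage sketch for export with an explicit platform list (running on CPU is
# assumed so that the exported artifact can be called in place):
import jax
import jax.numpy as jnp
from jax import export

exp = export.export(jax.jit(jnp.sin), platforms=("cpu",))(
    jax.ShapeDtypeStruct((4,), jnp.float32))
y = exp.call(jnp.arange(4, dtype=jnp.float32))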
- if logging.vlog_is_on(3): - logmsg = (f"fun_name={fun_name} version={version} " - f"lowering_platforms={lowering._platforms} " # type: ignore[unused-ignore,attribute-error] - f"disabled_checks={disabled_checks}") - logging.info("Exported JAX function: %s\n", logmsg) - logging.info(mlir.dump_module_message(mlir_module, "export")) - logging.info( - "Size of mlir_module_serialized: %d byte", - len(mlir_module_serialized), - ) + logmsg = (f"fun_name={fun_name} version={version} " + f"lowering_platforms={lowering._platforms} " # type: ignore[unused-ignore,attribute-error] + f"disabled_checks={disabled_checks}") + logger.debug("Exported JAX function: %s\n", logmsg) + logger.debug(mlir.dump_module_message(mlir_module, "export")) + logger.debug( + "Size of mlir_module_serialized: %d byte", + len(mlir_module_serialized), + ) _check_module(mlir_module, disabled_checks=disabled_checks, @@ -730,18 +746,29 @@ def export_sharding(s: LoweringSharding, if _device_assignment_for_internal_jax2tf_use_only is not None: _device_assignment_for_internal_jax2tf_use_only[0] = device_assignment - mesh = None + cur_mesh = cur_arg = cur_k_path = None + # lowered.args_info is a tree of the args, but we need the out avals too to + # get the key paths for. + out_avals_tree = tree_util.tree_unflatten(lowered.out_tree, out_avals_flat) if config.use_shardy_partitioner.value: - for sharding in itertools.chain.from_iterable( - [all_in_shardings, lowering.compile_args["out_shardings"]]): + for sharding, (k_path, arg) in zip( + itertools.chain.from_iterable([ + all_in_shardings, lowering.compile_args["out_shardings"]]), + itertools.chain.from_iterable([ + tree_util.tree_flatten_with_path(lowered.args_info)[0], + tree_util.tree_flatten_with_path(out_avals_tree)[0]])): if isinstance(sharding, sharding_impls.NamedSharding): - if mesh is not None and mesh.shape_tuple != sharding.mesh.shape_tuple: + if cur_mesh is None: + cur_mesh, cur_arg, cur_k_path = sharding.mesh, arg, k_path + elif cur_mesh.shape_tuple != sharding.mesh.shape_tuple: raise ValueError( - f'Mesh for all inputs should be equal. Got one mesh: {mesh} and' - f' another mesh: {sharding.mesh}') - mesh = sharding.mesh - if mesh and isinstance(mesh, mesh_lib.Mesh): - mesh = mesh.abstract_mesh + "Mesh for all inputs/outputs should be equal. 
Got one mesh " + f"{cur_mesh} on an array {cur_arg._aval} at " # type: ignore[union-attr] + f"{shape_poly.args_kwargs_path_to_str(cur_k_path)} and another mesh: " # type: ignore[arg-type] + f"{sharding.mesh}' on a tensor {arg._aval} at " + f"{shape_poly.args_kwargs_path_to_str(k_path)}") + if cur_mesh and isinstance(cur_mesh, mesh_lib.Mesh): + cur_mesh = cur_mesh.abstract_mesh def _get_exported_vjp(exp_primal: Exported) -> Exported: # Turn the primal jaxpr into a function, in preparation for exporting @@ -759,7 +786,7 @@ def _get_exported_vjp(exp_primal: Exported) -> Exported: device_assignment=device_assignment, apply_jit=True, flat_primal_fun=True, - mesh=mesh) # type: ignore[arg-type] + mesh=cur_mesh) # type: ignore[arg-type] return export(fun_vjp_jax, # type: ignore[arg-type] platforms=exp_primal.platforms, disabled_checks=exp_primal.disabled_safety_checks)(*vjp_in_avals) @@ -783,12 +810,8 @@ def _get_exported_vjp(exp_primal: Exported) -> Exported: calling_convention_version=version, _get_vjp=_get_exported_vjp) -def _module_to_bytecode(module: ir.Module, shardy_enabled: bool) -> bytes: - if xla_extension_version >= 319 and shardy_enabled: - mlir_str = xla_extension.sdy.sdy_round_trip_export_pipeline( - mlir.module_to_bytecode(module)) - else: - mlir_str = mlir.module_to_bytecode(module) +def _module_to_bytecode(module: ir.Module) -> bytes: + mlir_str = mlir.module_to_bytecode(module) # `target_version` is used to manage situations when a StableHLO producer # and a StableHLO consumer were built using different versions of StableHLO. # @@ -828,10 +851,10 @@ def _wrap_main_func( ) -> ir.Module: """Wraps the lowered module with a new "main" handling dimension arguments. - See calling convention documentation https://jax.readthedocs.io/en/latest/export/export.html#module-calling-convention. + See calling convention documentation https://docs.jax.dev/en/latest/export/export.html#module-calling-convention. Args: - module: the HLO module as obtained from lowering. + module: a copy of HLO module as obtained from lowering. args_avals_flat: the avals for all the arguments of the lowered function, which correspond to the array arguments of the `module`. args_kwargs_tree: the PyTreeDef corresponding to `(args, kwargs)`, for error @@ -845,10 +868,9 @@ def _wrap_main_func( Returns the wrapped module, without dimension and token arguments. """ dim_vars = shape_poly.all_dim_vars(args_avals_flat) - context = mlir.make_ir_context() + context = module.context + wrapped_module = module with context, ir.Location.unknown(context): - # Make a copy, do not mutate because it may be cached - wrapped_module = ir.Module.parse(mlir.module_to_bytecode(module)) symbol_table = ir.SymbolTable(wrapped_module.operation) orig_main = symbol_table["main"] orig_main.attributes["sym_visibility"] = ir.StringAttr.get("private") @@ -1066,6 +1088,8 @@ def _check_lowering(lowering) -> None: "hipsolver_gesvd_ffi", "hipsolver_gesvdj_ffi", # tridiagonal on GPU "cusolver_sytrd_ffi", + # tridiagonal_solve on GPU + "cusparse_gtsv2_ffi", ] # These are the JAX custom call target names that are guaranteed to be stable. # Their backwards compatibility is tested by back_compat_test.py. @@ -1073,33 +1097,21 @@ def _check_lowering(lowering) -> None: *_CPU_FFI_KERNELS, *_GPU_FFI_KERNELS, "Sharding", "SPMDFullToShardShape", "SPMDShardToFullShape", + "annotate_device_placement", "cu_threefry2x32_ffi", # Triton IR does not guarantee stability. 
# "__gpu$xla.gpu.triton", - # cholesky on CPU - "lapack_spotrf", "lapack_dpotrf", "lapack_cpotrf", "lapack_zpotrf", # eigh on TPU "Eigh", - # eig on CPU - "lapack_sgeev", "lapack_dgeev", "lapack_cgeev", "lapack_zgeev", - # svd on CPU - "lapack_sgesdd", "lapack_dgesdd", "lapack_cgesdd", "lapack_zgesdd", # qr and svd on TPU "Qr", "ProductOfElementaryHouseholderReflectors", - # triangular_solve on CPU - "blas_strsm", "blas_dtrsm", "blas_ctrsm", "blas_ztrsm", - # schur on CPU - "lapack_sgees", "lapack_dgees", "lapack_cgees", "lapack_zgees", - # tridiagonal on CPU - "lapack_ssytrd", "lapack_dsytrd", "lapack_chetrd", "lapack_zhetrd", - # hessenberg on CPU - "lapack_sgehrd", "lapack_dgehrd", "lapack_cgehrd", "lapack_zgehrd", # lu on TPU "LuDecomposition", # ApproxTopK on TPU "ApproxTopK", "stablehlo.dynamic_approx_top_k", "tf.call_tf_function", # From jax2tf.call_tf(func, call_tf_graph=True) "tpu_custom_call", # Pallas/TPU kernels + "mosaic_gpu", # Pallas Mosaic GPU kernels # TODO(burmako): maintain backwards compatibility for these, until they # are upstreamed to StableHLO. # See https://github.com/openxla/stablehlo/issues/8. @@ -1177,7 +1189,7 @@ def walk_operations(op): disallowed_custom_call_ops_str = "\n".join(disallowed_custom_call_ops) msg = ("Cannot serialize code with custom calls whose targets have no " "compatibility guarantees. " - "See https://jax.readthedocs.io/en/latest/export/export.html#compatibility-guarantees-for-custom-calls. " + "See https://docs.jax.dev/en/latest/export/export.html#compatibility-guarantees-for-custom-calls. " "Examples are:\n" f"{disallowed_custom_call_ops_str}.\n") raise ValueError(msg) @@ -1198,22 +1210,13 @@ def expand_in_shardings(in_shardings: Sequence[LoweringSharding], return tuple(all_in_shardings) -def _hlo_sharding_to_xla_compatible_sharding( - hlo_sharding: HloSharding | None, - mesh: sharding.Mesh) -> sharding.Sharding | None: - if hlo_sharding is None: - return None - return sharding_impls._gspmd_to_named_sharding_via_mesh( - _hlo_sharding_to_gspmd_sharding(hlo_sharding, tuple(mesh.devices.flat)), # type: ignore[arg-type] - mesh) - - def _hlo_sharding_to_gspmd_sharding( hlo_sharding: HloSharding | None, - device_assignment: Sequence[jax.Device]) -> sharding.GSPMDSharding | None: + device_assignment: Sequence[_jax.Device] + ) -> sharding_impls.GSPMDSharding | None: if hlo_sharding is None: return None - return sharding.GSPMDSharding(device_assignment, hlo_sharding) + return sharding_impls.GSPMDSharding(device_assignment, hlo_sharding) def _hlo_sharding_to_named_sharding( @@ -1254,7 +1257,7 @@ def flattened_primal_fun_jax(*args_flat): args_flat_jax, out_cts_flat_jax = util.split_list(args_and_out_cts_flat_jax, [len(in_avals)]) - _, pullback_jax = jax.vjp(primal_fun if flat_primal_fun else flattened_primal_fun_jax, + _, pullback_jax = api.vjp(primal_fun if flat_primal_fun else flattened_primal_fun_jax, *args_flat_jax) return pullback_jax(out_cts_flat_jax) @@ -1286,12 +1289,12 @@ def flattened_primal_fun_jax(*args_flat): ### Calling the exported function -def call(exported: Exported) -> Callable[..., jax.Array]: +def call(exported: Exported) -> Callable[..., typing.Array]: if not isinstance(exported, Exported): raise ValueError( "The exported argument must be an export.Exported. 
" f"Found {exported}.") - @jax.custom_vjp + @custom_derivatives.custom_vjp def f_flat(*args_flat): return call_exported_p.bind(*args_flat, exported=exported) @@ -1400,7 +1403,7 @@ def pp_arg_dim(dim_idx: int | None) -> str: # it would be ambiguous whether we should continue tracing with a result # of type `f32[c]` or `f32[d]`. shape_constraints.check_statically(synthetic_eval) - exported_dim_values = [synthetic_eval.evaluate(solution[var]) + exported_dim_values = [synthetic_eval.evaluate(solution[var]) # type: ignore[arg-type] for var in exported_dim_vars] out_avals = tuple( core.ShapedArray(core.evaluate_shape(out_aval.shape, exported_dim_vars, @@ -1423,25 +1426,23 @@ def _call_exported_lowering(ctx: mlir.LoweringRuleContext, *args, ctx.module_context.shape_poly_state.uses_dim_vars = True submodule = ir.Module.parse(exported.mlir_module()) - shardy_enabled = False - if xla_extension_version >= 319: - shardy_enabled = xla_extension.sdy.lowered_with_shardy( - mlir.module_to_bytecode(submodule)) + submodule_bc = mlir.module_to_bytecode(submodule) + shardy_enabled = _jax.sdy.lowered_with_shardy(submodule_bc) if shardy_enabled: - submodule = ir.Module.parse(xla_extension.sdy.sdy_round_trip_import_shardings( - mlir.module_to_bytecode(submodule))) + submodule = ir.Module.parse( + _jax.sdy.sdy_round_trip_import_shardings(submodule_bc) + ) with submodule.context: pipeline = passmanager.PassManager.parse( 'builtin.module(sdy-lift-inlined-meshes)') pipeline.run(submodule.operation) - # TODO(bartchr): delete this once I have JAX export support multiple meshes. mesh = None if shardy_enabled: - sdy_mesh_axes = xla_extension.sdy.get_mesh(mlir.module_to_bytecode(submodule)) - mesh = mesh_lib.AbstractMesh( - *list(zip(*sdy_mesh_axes))[::-1]) if sdy_mesh_axes else None + sdy_mesh_axes = _jax.sdy.get_mesh(mlir.module_to_bytecode(submodule)) + mesh = (mesh_lib.AbstractMesh(*list(zip(*sdy_mesh_axes))[::-1]) + if sdy_mesh_axes else mesh_lib.empty_abstract_mesh) axis_context = ctx.module_context.axis_context if isinstance(axis_context, sharding_impls.ShardingContext): diff --git a/jax/_src/export/serialization.fbs b/jax/_src/export/serialization.fbs index 7d3e342f1879..01cfa9944dfd 100644 --- a/jax/_src/export/serialization.fbs +++ b/jax/_src/export/serialization.fbs @@ -45,7 +45,7 @@ enum AbstractValueKind: byte { } enum DType: byte { - // Last used id: 22 + // Last used id: 29 bool = 0, i8 = 1, i16 = 2, @@ -76,6 +76,10 @@ enum DType: byte { f8_e5m2fnuz = 21, f8_e8m0fnu = 25, f4_e2m1fn = 26, + + key_fry = 27, + key_rbg = 28, + key_unsafe_rbg = 29, } table AbstractValue { diff --git a/jax/_src/export/serialization.py b/jax/_src/export/serialization.py index ac97c11d1177..dd33de846c42 100644 --- a/jax/_src/export/serialization.py +++ b/jax/_src/export/serialization.py @@ -19,7 +19,7 @@ import types from collections.abc import Callable, Sequence from functools import partial -from typing import TypeVar +from typing import Any, TypeVar try: import flatbuffers @@ -48,6 +48,8 @@ # Version 2, Dec 16th, 2023, adds the f0 dtype. # Version 3, October 16th, 2024, adds serialization for namedtuple and custom types # This version is backwards compatible with Version 2. +# Version 4, April 7th, 2025, adds serialization for PRNGs key types. +# This version is backwards compatible with Version 2 and 3. 
_SERIALIZATION_VERSION = 2 def serialize(exp: _export.Exported, vjp_order: int = 0) -> bytearray: @@ -357,21 +359,22 @@ def _deserialize_pytreedef_to_pytree(p: ser_flatbuf.PyTreeDef): dtypes._float8_e4m3fnuz_dtype: ser_flatbuf.DType.f8_e4m3fnuz, dtypes._float8_e5m2_dtype: ser_flatbuf.DType.f8_e5m2, dtypes._float8_e5m2fnuz_dtype: ser_flatbuf.DType.f8_e5m2fnuz, + dtypes._float8_e3m4_dtype: ser_flatbuf.DType.f8_e3m4, + dtypes._float8_e4m3_dtype: ser_flatbuf.DType.f8_e4m3, + dtypes._float8_e8m0fnu_dtype: ser_flatbuf.DType.f8_e8m0fnu, + dtypes._float4_e2m1fn_dtype: ser_flatbuf.DType.f4_e2m1fn, } -if dtypes._float8_e3m4_dtype is not None: - _dtype_to_dtype_kind[dtypes._float8_e3m4_dtype] = ser_flatbuf.DType.f8_e3m4 -if dtypes._float8_e4m3_dtype is not None: - _dtype_to_dtype_kind[dtypes._float8_e4m3_dtype] = ser_flatbuf.DType.f8_e4m3 -if dtypes._float8_e8m0fnu_dtype is not None: - _dtype_to_dtype_kind[dtypes._float8_e8m0fnu_dtype] = ser_flatbuf.DType.f8_e8m0fnu -if dtypes._float4_e2m1fn_dtype is not None: - _dtype_to_dtype_kind[dtypes._float4_e2m1fn_dtype] = ser_flatbuf.DType.f4_e2m1fn _dtype_kind_to_dtype = { kind: dtype for dtype, kind in _dtype_to_dtype_kind.items() } +def register_dtype_kind(dtype: Any, kind: int): + _dtype_to_dtype_kind[dtype] = kind + _dtype_kind_to_dtype[kind] = dtype + + def _serialize_aval( builder: flatbuffers.Builder, aval: core.ShapedArray ) -> int: diff --git a/jax/_src/export/serialization_generated.py b/jax/_src/export/serialization_generated.py index b1fc13333777..5a3ba8f72322 100644 --- a/jax/_src/export/serialization_generated.py +++ b/jax/_src/export/serialization_generated.py @@ -21,7 +21,7 @@ from flatbuffers.compat import import_numpy np = import_numpy() -class PyTreeDefKind(object): +class PyTreeDefKind: leaf = 0 none = 1 tuple = 2 @@ -30,12 +30,12 @@ class PyTreeDefKind(object): custom = 5 -class AbstractValueKind(object): +class AbstractValueKind: shapedArray = 0 abstractToken = 1 -class DType(object): +class DType: bool = 0 i8 = 1 i16 = 2 @@ -53,30 +53,33 @@ class DType(object): bf16 = 14 i4 = 15 ui4 = 16 - f8_e3m4 = 24 - f8_e4m3 = 23 f8_e4m3b11fnuz = 17 f8_e4m3fn = 18 f8_e4m3fnuz = 19 f8_e5m2 = 20 f8_e5m2fnuz = 21 f0 = 22 + f8_e4m3 = 23 + f8_e3m4 = 24 f8_e8m0fnu = 25 f4_e2m1fn = 26 + key_fry = 27 + key_rbg = 28 + key_unsafe_rbg = 29 -class ShardingKind(object): +class ShardingKind: unspecified = 0 hlo_sharding = 1 -class DisabledSafetyCheckKind(object): +class DisabledSafetyCheckKind: platform = 0 custom_call = 1 shape_assertions = 2 -class PyTreeDef(object): +class PyTreeDef: __slots__ = ['_tab'] @classmethod @@ -211,7 +214,7 @@ def PyTreeDefEnd(builder): -class AbstractValue(object): +class AbstractValue: __slots__ = ['_tab'] @classmethod @@ -283,7 +286,7 @@ def AbstractValueEnd(builder): -class Sharding(object): +class Sharding: __slots__ = ['_tab'] @classmethod @@ -352,7 +355,7 @@ def ShardingEnd(builder): -class Effect(object): +class Effect: __slots__ = ['_tab'] @classmethod @@ -388,7 +391,7 @@ def EffectEnd(builder): -class DisabledSafetyCheck(object): +class DisabledSafetyCheck: __slots__ = ['_tab'] @classmethod @@ -434,7 +437,7 @@ def DisabledSafetyCheckEnd(builder): -class Exported(object): +class Exported: __slots__ = ['_tab'] @classmethod diff --git a/jax/_src/export/shape_poly.py b/jax/_src/export/shape_poly.py index 6a6ce93712ff..c89ae2cc04ca 100644 --- a/jax/_src/export/shape_poly.py +++ b/jax/_src/export/shape_poly.py @@ -13,7 +13,7 @@ # limitations under the License. """Shape polymorphism support. 
-See documentation at https://jax.readthedocs.io/en/latest/export/shape_poly.html. +See documentation at https://docs.jax.dev/en/latest/export/shape_poly.html. """ from __future__ import annotations @@ -34,23 +34,21 @@ import numpy as np import opt_einsum -import jax - +from jax._src import api from jax._src import config from jax._src import core from jax._src import dtypes from jax._src import effects -from jax._src.lax import lax from jax._src.interpreters import mlir -from jax._src.numpy import einsum as jnp_einsum from jax._src import source_info_util from jax._src import tree_util +from jax._src import typing from jax._src import util DimSize = Union["_DimExpr", int] TfVal = Any -DimVarEnv = dict[str, jax.Array] +DimVarEnv = dict[str, typing.Array] DType = Any # Tuples of terms and their coefficients, sorted with the largest term first. @@ -70,7 +68,7 @@ class InconclusiveDimensionOperation(core.InconclusiveDimensionOperation): are non-constant, and the result of the operation cannot be represented as a boolean value for all values of the symbolic dimensions involved. -Please see https://jax.readthedocs.io/en/latest/export/shape_poly.html#comparison-of-symbolic-dimensions-is-partially-supported +Please see https://docs.jax.dev/en/latest/export/shape_poly.html#comparison-of-symbolic-dimensions-is-partially-supported for more details. """ @@ -214,6 +212,8 @@ def __ge__(self, other: _DimFactor): return self._syntactic_cmp(other) >= 0 def evaluate(self, env: DimVarEnv, scope: SymbolicScope): + from jax._src.lax import lax # pytype: disable=import-error + if self.var is not None: try: return env[self.var] @@ -227,7 +227,7 @@ def evaluate(self, env: DimVarEnv, scope: SymbolicScope): return normalized_var._evaluate(env) # type: ignore err_msg = ( f"Encountered dimension variable '{self.var}' that is not appearing in the shapes of the function arguments.\n" - "Please see https://jax.readthedocs.io/en/latest/export/shape_poly.html#dimension-variables-must-be-solvable-from-the-input-shapes for more details.") + "Please see https://docs.jax.dev/en/latest/export/shape_poly.html#dimension-variables-must-be-solvable-from-the-input-shapes for more details.") raise UnexpectedDimVar(err_msg) else: operand_values = [opnd._evaluate(env) for opnd in self.operands] @@ -654,7 +654,7 @@ def _eq(self, other: _DimExpr) -> bool: # Here we really ought to raise InconclusiveDimensionOperation, but __eq__ # cannot raise exceptions, because it is used indirectly when hashing. # So, we say that the expressions are disequal, which is really unsound. - # See https://jax.readthedocs.io/en/latest/export/shape_poly.html#comparison-of-symbolic-dimensions-is-partially-supported + # See https://docs.jax.dev/en/latest/export/shape_poly.html#comparison-of-symbolic-dimensions-is-partially-supported return False return diff == 0 @@ -841,7 +841,7 @@ def __eq__(self, other: Any) -> bool: # Here we really ought to raise InconclusiveDimensionOperation, but __eq__ # cannot raise exceptions, because it is used indirectly when hashing. # So, we say that the expressions are disequal, which is really unsound. - # See https://jax.readthedocs.io/en/latest/export/shape_poly.html#comparison-of-symbolic-dimensions-is-partially-supported + # See https://docs.jax.dev/en/latest/export/shape_poly.html#comparison-of-symbolic-dimensions-is-partially-supported return False return diff == 0 @@ -978,7 +978,7 @@ def cmp_sequence(s1, s2, elem_cmp) -> int: class SymbolicScope: - """Indentifies a scope for symbolic expressions. 
+ """Identifies a scope for symbolic expressions. All symbolic expressions that interact (e.g., appear in the argument shapes for one JAX function invocation, or are involved in arithmetic operations) @@ -986,7 +986,7 @@ class SymbolicScope: Holds the constraints on symbolic expressions. - See [the README](https://jax.readthedocs.io/en/latest/export/shape_poly.html#user-specified-symbolic-constraints) + See [the README](https://docs.jax.dev/en/latest/export/shape_poly.html#user-specified-symbolic-constraints) for more details. Args: @@ -1112,7 +1112,7 @@ def _check_same_scope(self, other: _DimExpr, f"Invalid mixing of symbolic scopes {when}.\n" f"Expected {self_descr}scope {self}\n" f"and found for '{other}' ({other_descr}) scope {other.scope}\n" - f"See https://jax.readthedocs.io/en/latest/export/shape_poly.html#user-specified-symbolic-constraints.") + f"See https://docs.jax.dev/en/latest/export/shape_poly.html#user-specified-symbolic-constraints.") def _clear_caches(self): self._bounds_cache.clear() @@ -1255,7 +1255,7 @@ def fake_dim(d): # here some errors due to non-equal dimensions, but we catch them # later. return 8 - fake_ops.append(jax.ShapeDtypeStruct(tuple(map(fake_dim, shape)), + fake_ops.append(api.ShapeDtypeStruct(tuple(map(fake_dim, shape)), operand.dtype)) contract_fake_ops, contractions = opt_einsum.contract_path(*fake_ops, @@ -1267,8 +1267,6 @@ def fake_dim(d): contract_operands.append(operands[idx[0]]) return contract_operands, contractions -jnp_einsum._poly_einsum_handlers[_DimExpr] = _einsum_contract_path - # To implement shape-constraint checking we use a shape assertion primitive. # shape_assertion_p.bind(assert_what: bool, *error_message_inputs, # error_message="...{0}...{1}") @@ -1303,8 +1301,8 @@ class ShapeAssertionEffect(effects.Effect): effects.remat_allowed_effects.add_type(ShapeAssertionEffect) effects.custom_derivatives_allowed_effects.add_type(ShapeAssertionEffect) -def shape_assertion(assert_what: jax.Array, - *error_message_inputs: jax.Array, +def shape_assertion(assert_what: typing.Array, + *error_message_inputs: typing.Array, error_message: str) -> None: """Adds a shape assertion in the code. @@ -1384,7 +1382,7 @@ def symbolic_shape(shape_spec: str | None, ) -> Sequence[DimSize]: """Constructs a symbolic shape from a string representation. - See https://jax.readthedocs.io/en/latest/export/shape_poly.html for examples. + See https://docs.jax.dev/en/latest/export/shape_poly.html for examples. Args: shape_spec: a symbolic shape specification. None stands for "...". @@ -1396,13 +1394,13 @@ def symbolic_shape(shape_spec: str | None, mod(e1, e2), max(e1, e2), or min(e1, e2). constraints: a sequence of constraints on symbolic dimension expressions, of the form `e1 >= e2` or `e1 <= e2`, or `e1 == e2`. - See [the documentation](https://jax.readthedocs.io/en/latest/export/shape_poly.html#user-specified-symbolic-constraints) + See [the documentation](https://docs.jax.dev/en/latest/export/shape_poly.html#user-specified-symbolic-constraints) for usage. scope: optionally, you can specify that the parsed symbolic expressions be created in the given scope. If this is missing, then a new `SymbolicScope` is created with the given `constraints`. You cannot specify both a `scope` and `constraints`. - See [the documentation](https://jax.readthedocs.io/en/latest/export/shape_poly.html#user-specified-symbolic-constraints) + See [the documentation](https://docs.jax.dev/en/latest/export/shape_poly.html#user-specified-symbolic-constraints) for usage. 
like: when `shape_spec` contains placeholders ("_", "..."), use this shape to fill in the placeholders. @@ -1437,7 +1435,7 @@ def symbolic_args_specs( """Constructs a pytree of jax.ShapeDtypeSpec arguments specs for `export`. See the documentation of :func:`jax.export.symbolic_shape` and - the [shape polymorphism documentation](https://jax.readthedocs.io/en/latest/export/shape_poly.html) for details. + the [shape polymorphism documentation](https://docs.jax.dev/en/latest/export/shape_poly.html) for details. Args: args: a pytree of arguments. These can be jax.Array, or jax.ShapeDTypeSpec. @@ -1450,7 +1448,7 @@ def symbolic_args_specs( applies to all arguments), or a pytree matching a prefix of the `args`. See [how optional parameters are matched to - arguments](https://jax.readthedocs.io/en/latest/pytrees.html#applying-optional-parameters-to-pytrees). + arguments](https://docs.jax.dev/en/latest/pytrees.html#applying-optional-parameters-to-pytrees). constraints: as for :func:`jax.export.symbolic_shape`. scope: as for :func:`jax.export.symbolic_shape`. @@ -1485,14 +1483,14 @@ def symbolic_args_specs( elif constraints: raise ValueError("Cannot use both `scope` and `constraints`") args_specs_flat = ( - jax.ShapeDtypeStruct(symbolic_shape(spec, like=s, scope=scope), t) + api.ShapeDtypeStruct(symbolic_shape(spec, like=s, scope=scope), t) for s, t, spec in zip(shapes, dtypes, polymorphic_shapes_flat)) return args_tree.unflatten(args_specs_flat) def shape_and_dtype_jax_array(a) -> tuple[Sequence[int | None], DType]: """Returns the shape and dtype of a jax.Array or a j""" - if isinstance(a, jax.ShapeDtypeStruct): + if isinstance(a, api.ShapeDtypeStruct): return a.shape, a.dtype aval = core.get_aval(a) return aval.shape, aval.dtype @@ -1785,7 +1783,7 @@ def check_statically(self, eval: ShapeEvaluator) -> None: if not ok: raise self.make_error(eval) - def compute(self, eval: ShapeEvaluator) -> jax.Array | None: + def compute(self, eval: ShapeEvaluator) -> typing.Array | None: """Computes if the constraint is satisfied. If the constraint can be resolved statically returns None @@ -1793,6 +1791,8 @@ def compute(self, eval: ShapeEvaluator) -> jax.Array | None: resolved statically, returns a value representing if the constraint is satisfied. """ + from jax._src.lax import lax # pytype: disable=import-error + left, right = eval.evaluate(self.left), eval.evaluate(self.right) # Try to evaluate the constraint statically. if core.is_constant_shape((left, right)): @@ -1997,8 +1997,8 @@ def solve_dim_vars( def compute_dim_vars_from_arg_shapes( args_avals: Sequence[core.ShapedArray], - *actual_args: jax.Array, - args_kwargs_tree: tree_util.PyTreeDef) -> Sequence[jax.Array]: + *actual_args: typing.Array, + args_kwargs_tree: tree_util.PyTreeDef) -> Sequence[typing.Array]: """Computes values of dimension variables to unify args_avals with actual arguments. Like `solve_dim_vars` except that here we express the solution as @@ -2021,7 +2021,7 @@ def compute_dim_vars_from_arg_shapes( } synthetic_eval = ShapeEvaluator(synthetic_env) shape_constraints.shape_assertions(synthetic_eval) - return tuple(synthetic_eval.evaluate(solution[var]) for var in dim_vars) + return tuple(synthetic_eval.evaluate(solution[var]) for var in dim_vars) # type: ignore[arg-type] def _solve_dim_equations( eqns: list[_DimEquation], @@ -2038,7 +2038,7 @@ def _solve_dim_equations( " Using the following polymorphic shapes specifications: " + ",".join(f"{arg_name}.shape = {arg_spec}" for arg_name, arg_spec in polymorphic_shape_specs)) + "." 
- solution_err_msg_trailer_errors = ". Please see https://jax.readthedocs.io/en/latest/export/shape_poly.html#shape-assertion-errors for more details." + solution_err_msg_trailer_errors = ". Please see https://docs.jax.dev/en/latest/export/shape_poly.html#shape-assertion-errors for more details." shape_constraints = ShapeConstraints() # accumulate shape constraints scope: SymbolicScope | None = None @@ -2171,6 +2171,6 @@ def add_explicit_symbolic_constraints(shape_env: DimVarEnv): " Unprocessed specifications: " + ", ".join(f"'{eqn.aval_dim_expr}' for dimension size {eqn.dim_name}" for eqn in eqns) + - ". Please see https://jax.readthedocs.io/en/latest/export/shape_poly.html#dimension-variables-must-be-solvable-from-the-input-shapes for more details." + ". Please see https://docs.jax.dev/en/latest/export/shape_poly.html#dimension-variables-must-be-solvable-from-the-input-shapes for more details." ) raise ValueError(err_msg) diff --git a/jax/_src/ffi.py b/jax/_src/ffi.py index 05697f00e945..e774f98aaa78 100644 --- a/jax/_src/ffi.py +++ b/jax/_src/ffi.py @@ -16,19 +16,19 @@ from collections.abc import Callable, Mapping, Sequence import ctypes +import dataclasses import functools import os from typing import Any, overload import numpy as np -import jax from jax._src import core -from jax._src import deprecations from jax._src import dispatch from jax._src import effects from jax._src import util from jax._src import xla_bridge +from jax._src.hashable_array import HashableArray from jax._src.interpreters import ad from jax._src.interpreters import batching from jax._src.interpreters import mlir @@ -39,11 +39,6 @@ from jax._src.typing import (Array, ArrayLike, DeprecatedArg, DuckTypedArray, Shape) -# TODO(dfm): Remove after 6 months or less because there aren't any offical -# compatibility guarantees for jax.extend (see JEP 15856) -# Added Oct 13, 2024 -deprecations.register("jax-ffi-call-args") - map, unsafe_map = util.safe_map, map FfiLayoutOptions = Sequence[int] | DeviceLocalLayout | None @@ -61,7 +56,7 @@ def register_ffi_target( name: the name of the target. fn: a ``PyCapsule`` object containing the function pointer, or a ``dict`` where the keys are FFI stage names (e.g. `"execute"`) and the values are - ``PyCapsule`` objects continaing a pointer to the handler for that stage. + ``PyCapsule`` objects containing a pointer to the handler for that stage. platform: the target platform. api_version: the XLA custom call API version to use. Supported versions are: 1 (default) for the typed FFI or 0 for the earlier "custom call" API. @@ -141,7 +136,7 @@ def include_dir() -> str: def _aval_shape(aval: core.AbstractValue) -> Shape: - return () if aval is core.abstract_token else aval.shape # pytype: disable=attribute-error + return () if aval is core.abstract_token else core.physical_aval(aval).shape # pytype: disable=attribute-error def _convert_layout_for_lowering( @@ -325,7 +320,7 @@ def _convert_layouts_for_ffi_call( def ffi_call( target_name: str, result_shape_dtypes: ResultMetadata, - *deprecated_args: ArrayLike, + *, has_side_effect: bool = ..., vmap_method: str | None = ..., input_layouts: Sequence[FfiLayoutOptions] | None = ..., @@ -333,9 +328,8 @@ def ffi_call( input_output_aliases: dict[int, int] | None = ..., custom_call_api_version: int = ..., legacy_backend_config: str | None = ..., - vectorized: bool | DeprecatedArg = ..., - **deprecated_kwargs: Any, -) -> Callable[..., Array] | Array: + vectorized: bool | None | DeprecatedArg = DeprecatedArg(), +) -> Callable[..., Array]: ... 
@@ -343,7 +337,7 @@ def ffi_call( def ffi_call( target_name: str, result_shape_dtypes: Sequence[ResultMetadata], - *deprecated_args: ArrayLike, + *, has_side_effect: bool = ..., vmap_method: str | None = ..., input_layouts: Sequence[FfiLayoutOptions] | None = ..., @@ -351,16 +345,15 @@ def ffi_call( input_output_aliases: dict[int, int] | None = ..., custom_call_api_version: int = ..., legacy_backend_config: str | None = ..., - vectorized: bool | DeprecatedArg = ..., - **deprecated_kwargs: Any, -) -> Callable[..., Sequence[Array]] | Sequence[Array]: + vectorized: bool | None | DeprecatedArg = DeprecatedArg(), +) -> Callable[..., Sequence[Array]]: ... def ffi_call( target_name: str, result_shape_dtypes: ResultMetadata | Sequence[ResultMetadata], - *deprecated_args: ArrayLike, + *, has_side_effect: bool = False, vmap_method: str | None = None, input_layouts: Sequence[FfiLayoutOptions] | None = None, @@ -368,16 +361,15 @@ def ffi_call( input_output_aliases: dict[int, int] | None = None, custom_call_api_version: int = 4, legacy_backend_config: str | None = None, - vectorized: bool | DeprecatedArg = DeprecatedArg(), - **deprecated_kwargs: Any, -) -> Callable[..., Array | Sequence[Array]] | Array | Sequence[Array]: + vectorized: bool | None | DeprecatedArg = DeprecatedArg(), +) -> Callable[..., Array | Sequence[Array]]: """Call a foreign function interface (FFI) target. See the :ref:`ffi-tutorial` tutorial for more information. Like :func:`~jax.pure_callback`, the behavior of ``ffi_call`` under :func:`~jax.vmap` depends on the value of ``vmap_method``. See the - :func:`~jax.pure_callback` documenation for more details about the allowed + :func:`~jax.pure_callback` documentation for more details about the allowed values and examples of their behavior. The current default behavior is to use ``vmap_method="sequential"`` when @@ -430,18 +422,11 @@ def ffi_call( to execute the FFI handler. Any keyword arguments are passed as named attributes to the FFI handler using XLA's FFI interface. """ - if not isinstance(vectorized, DeprecatedArg) and not vectorized is None: - deprecations.warn( - "jax-callback-vectorized", - "The vectorized argument of ffi_call is deprecated and setting " - "it will soon raise an error. To avoid an error in the future, and to " - "suppress this warning, please use the vmap_method argument instead.", - stacklevel=2) - if vmap_method is not None: - raise ValueError( - "the vectorized and vmap_method arguments of ffi_call cannot " - "be used together. Please use the vmap_method argument.") - vmap_method = "legacy_vectorized" if vectorized else "sequential" + # TODO(danfm): Remove this check 3 months after v0.6.0 is released. + if not isinstance(vectorized, DeprecatedArg): + raise ValueError( + "The 'vectorized' argument of jax.ffi.ffi_call was removed in JAX " + "v0.6.0. 
Use 'vmap_method' instead.") allowed_vmap_methods = ["sequential", "sequential_unrolled", "expand_dims", "broadcast_all", "legacy_vectorized", None] if vmap_method not in allowed_vmap_methods: @@ -515,11 +500,10 @@ def wrapped(*args: ArrayLike, **kwargs: Any): "and an output with a different layout " f"{static_output_layouts[o_idx]}.") static_input_output_aliases += ((i_idx, o_idx),) - + args = core.standard_insert_pvary(*args) results = ffi_call_p.bind( *args, result_avals=result_avals, - vectorized=vectorized, vmap_method=vmap_method, target_name=target_name, has_side_effect=has_side_effect, @@ -537,19 +521,7 @@ def wrapped(*args: ArrayLike, **kwargs: Any): else: return results[0] - if deprecated_args or deprecated_kwargs: - deprecations.warn( - "jax-ffi-call-args", - "Calling ffi_call directly with input arguments is deprecated. " - "Instead, ffi_call should be used to construct a callable, which can " - "then be called with the appropriate inputs. For example,\n" - " ffi_call('target_name', output_type, x, argument=5)\n" - "should be replaced with\n" - " ffi_call('target_name', output_type)(x, argument=5)", - stacklevel=2) - return wrapped(*deprecated_args, **deprecated_kwargs) - else: - return wrapped + return wrapped # ffi_call must support some small non-hashable input arguments, like np.arrays @@ -587,24 +559,6 @@ def _unwrap_kwargs_hashable(kwargs: Sequence[tuple[str, Any]]) -> dict[str, Any] return unwrapped_kwargs -class HashableArray: - __slots__ = ["val"] - - def __init__(self, val): - assert isinstance(val, np.ndarray) - self.val = np.copy(val) - self.val.setflags(write=False) - - def __repr__(self): - return f"HashableArray({self.val})" - - def __hash__(self): - return hash((self.val.shape, self.val.dtype, self.val.tobytes())) - - def __eq__(self, other): - return isinstance(other, HashableArray) and np.array_equal(self.val, other.val) - - class HashableDict: __slots__ = ["val"] @@ -622,6 +576,7 @@ def __eq__(self, other): return isinstance(other, HashableDict) and self.val == other.val +@dataclasses.dataclass(frozen=True) class FfiEffect(effects.Effect): def __str__(self): return "FFI" @@ -638,9 +593,10 @@ def ffi_call_abstract_eval( has_side_effect: bool, **_, ): - del avals_in # unused + out_vma = core.standard_vma_rule('ffi_call', *avals_in) effects = {_FfiEffect} if has_side_effect else core.no_effects - return result_avals, effects + return tuple(r if r is core.abstract_token else r.update(vma=out_vma) + for r in result_avals), effects def ffi_call_jvp(*args, target_name, **_): @@ -684,20 +640,12 @@ def ffi_batching_rule( args, dims, *, - vectorized: bool | None | DeprecatedArg, vmap_method: str | None, result_avals: Sequence[core.ShapedArray], **kwargs: Any, ): - if isinstance(vectorized, DeprecatedArg) and vmap_method is None: - deprecations.warn( - "jax-callback-vectorized", - f"The default behavior of {prim.name} under vmap will soon " - "change. Currently, the default behavior is to generate a sequential " - "vmap (i.e. a loop), but in the future the default will be to raise " - "an error. 
To keep the current default, set vmap_method='sequential'.", - stacklevel=6) - vmap_method = "sequential" + from jax._src.lax import control_flow # pytype: disable=import-error + from jax._src.lax import lax # pytype: disable=import-error axis_size, = {a.shape[d] for a, d in zip(args, dims) if d is not batching.not_mapped} @@ -726,7 +674,6 @@ def ffi_batching_rule( for layout, d in zip(kwargs["input_layouts"], dims)) outvals = prim.bind( *new_args, - vectorized=vectorized, vmap_method=vmap_method, result_avals=batched_result_avals, **kwargs, @@ -734,7 +681,7 @@ def ffi_batching_rule( elif vmap_method == "expand_dims" or vmap_method == "broadcast_all": size = axis_size if vmap_method == "broadcast_all" else 1 bcast_args = [ - jax.lax.broadcast(x, (size,)) if d is batching.not_mapped else x + lax.broadcast(x, (size,)) if d is batching.not_mapped else x for x, d in zip(new_args, dims)] if kwargs.get("input_layouts") is not None: kwargs["input_layouts"] = tuple( @@ -742,7 +689,6 @@ def ffi_batching_rule( for layout in kwargs["input_layouts"]) outvals = prim.bind( *bcast_args, - vectorized=vectorized, vmap_method=vmap_method, result_avals=batched_result_avals, **kwargs, @@ -755,13 +701,12 @@ def _batch_fun(batched_args): return prim.bind( *merged_args, result_avals=result_avals, - vectorized=vectorized, vmap_method=vmap_method, **kwargs, ) unroll = vmap_method == "sequential_unrolled" g = lambda _, x: ((), _batch_fun(x)) - _, outvals = jax.lax.scan(g, (), batched_args, unroll=unroll) + _, outvals = control_flow.scan(g, (), batched_args, unroll=unroll) else: raise NotImplementedError( f"vmap is only supported for the {prim.name} primitive when vmap_method " diff --git a/jax/_src/flatten_util.py b/jax/_src/flatten_util.py index ff35b8db8e25..bf3bd33f286a 100644 --- a/jax/_src/flatten_util.py +++ b/jax/_src/flatten_util.py @@ -41,7 +41,7 @@ def ravel_pytree(pytree): component of the output. For details on dtype promotion, see - https://jax.readthedocs.io/en/latest/type_promotion.html. + https://docs.jax.dev/en/latest/type_promotion.html. """ leaves, treedef = tree_flatten(pytree) diff --git a/jax/_src/frozen_dict.py b/jax/_src/frozen_dict.py new file mode 100644 index 000000000000..c110717d80b5 --- /dev/null +++ b/jax/_src/frozen_dict.py @@ -0,0 +1,49 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, TypeVar +from collections.abc import Iterator, Mapping + +K = TypeVar("K") +V = TypeVar("V") + + +class FrozenDict(Mapping[K, V]): + + def __init__(self, d: Mapping[K, V]): + self._d = dict(d.items()) + + def __repr__(self) -> str: + return f"FrozenDict({self._d!r})" + + def __str__(self) -> str: + return f"FrozenDict({self._d})" + + def __getitem__(self, key: K) -> V: + return self._d[key] + + def __hash__(self) -> int: + # This assumes that the values are hashable. 
+ return hash(frozenset(self._d.items())) + + def __eq__(self, other: Any) -> bool: + if not isinstance(other, FrozenDict): + return False + return self._d == other._d + + def __iter__(self) -> Iterator[K]: + return iter(self._d) + + def __len__(self) -> int: + return len(self._d) diff --git a/jax/_src/hashable_array.py b/jax/_src/hashable_array.py new file mode 100644 index 000000000000..4757a9c5eb24 --- /dev/null +++ b/jax/_src/hashable_array.py @@ -0,0 +1,37 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the + +import numpy as np + + +class HashableArray: + __slots__ = ["val"] + val: np.ndarray + + def __init__(self, val): + self.val = np.array(val, copy=True) + self.val.setflags(write=False) + + def __repr__(self): + return f"HashableArray({self.val!r})" + + def __str__(self): + return f"HashableArray({self.val})" + + def __hash__(self): + return hash((self.val.shape, self.val.dtype, self.val.tobytes())) + + def __eq__(self, other): + return isinstance(other, HashableArray) and np.array_equal( + self.val, other.val + ) diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/__init__.py b/jax/_src/internal_test_util/export_back_compat_test_data/__init__.py new file mode 100644 index 000000000000..3da0dd1fa3ca --- /dev/null +++ b/jax/_src/internal_test_util/export_back_compat_test_data/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2025 The JAX Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/annotate_data_placement.py b/jax/_src/internal_test_util/export_back_compat_test_data/annotate_data_placement.py new file mode 100644 index 000000000000..d3aab292c9fe --- /dev/null +++ b/jax/_src/internal_test_util/export_back_compat_test_data/annotate_data_placement.py @@ -0,0 +1,134 @@ +# Copyright 2023 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# ruff: noqa + +import datetime +from numpy import array, float32 + +data_2025_04_07_tpu = {} + +# Pasted from the test output (see export_back_compat_test_util.py module docstring) +data_2025_04_07_tpu['gspmd'] = dict( + testdata_version=1, + platform='tpu', + custom_call_targets=['annotate_device_placement'], + serialized_date=datetime.date(2025, 4, 7), + inputs=(array([0.], dtype=float32), array([0.], dtype=float32)), + expected_outputs=(array([0.], dtype=float32),), + mlir_module_text=r""" +#loc1 = loc("x") +#loc2 = loc("y") +module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main(%arg0: tensor<1xf32> {mhlo.memory_kind = "device", mhlo.sharding = "{maximal device=0}"} loc("x"), %arg1: tensor<1xf32> {mhlo.memory_kind = "pinned_host", mhlo.sharding = "{maximal device=0}"} loc("y")) -> (tensor<1xf32> {jax.result_info = "result", mhlo.memory_kind = "pinned_host", mhlo.sharding = "{maximal device=0}"}) { + %0 = stablehlo.add %arg0, %arg1 : tensor<1xf32> loc(#loc4) + %1 = stablehlo.custom_call @annotate_device_placement(%0) {has_side_effect = true, mhlo.frontend_attributes = {_xla_buffer_placement = "pinned_host"}} : (tensor<1xf32>) -> tensor<1xf32> loc(#loc) + return %1 : tensor<1xf32> loc(#loc) + } loc(#loc) +} loc(#loc) +#loc = loc(unknown) +#loc3 = loc("third_party/py/jax/tests/export_back_compat_test.py":878:13) +#loc4 = loc("jit(func)/jit(main)/add"(#loc3)) +""", + mlir_module_serialized=b"ML\xefR\rStableHLO_v1.9.3\x00\x01\x1b\x05\x01\x05\x0b\x01\x03\x0b\x03\t\x0f\x13\x17\x1b\x03oQ\x0b\x01%\x07\x0f#\x0b\x0f\x0b\x0b\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x17\x0b\x13\x0b\x03-\x0b\x0b\x0b\x0b\x0b\x13\x1b\x0b\x1b\x0b\x0f#\x0b\x0b\x0b\x0b\x13\x0b\x0b\x0b\x0b\x0b\x01\x05\x0b\x0f\x03\x07\x13\x1b\x07\x02\n\x02\x1f\x11\x03\x05\x03\x07\x07\t\x0b\x03\r\x03\x05\x0f\x11\x01\x00\x05\x11\x05\x13\x05\x15\x1d\x13\x01\x05\x17\x1d\x17\x01\x05\x19\x1d\x1b\x1d\x05\x1b\x17\x1f\xba\r\x1b\x05\x1d\x03\x03#E\x05\x1f\x03\x01\x1d!\x1d#\x1d%\x1d'\x03\x0515\r\x05'3)+\x1d)\r\x05'-)+#\x07\x03\x03;\r\x07=?'-)+\x1d+\x1d-\x1d/\x1d1\r\x03G-\x1d3\x0b\x03\x1d5\x1d7\x05\x03\x01\t\x01\x02\x02)\x03\x05\t\x11\x05\x05\x05\x03\x05\t\x04e\x05\x01Q\x01\x05\x01\x07\x04S\x03\x01\x05\x03P\x01\x03\x07\x04?\x03\t\x0f\x05\x0b\x11\x0b\x15\x00\x05\x06\x19\x03\x05\x05\x01\x03\x07G\x01!\x05\x03\x05\x03\x05\t\x04\x01\x03\x07\x06\x03\x01\x05\x01\x00\x9a\x0695\x03-\x0f\x0b\x0f!\x0f\x19'\x1d#3i1\x05\x05\x13%)9\x15\x1f\x0f\x11\x0f\x0b\x11builtin\x00vhlo\x00module\x00func_v1\x00add_v1\x00custom_call_v1\x00return_v1\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00x\x00y\x00jit(func)/jit(main)/add\x00third_party/py/jax/tests/export_back_compat_test.py\x00mhlo.frontend_attributes\x00mhlo.memory_kind\x00mhlo.sharding\x00{maximal device=0}\x00pinned_host\x00device\x00jax.result_info\x00result\x00main\x00public\x00_xla_buffer_placement\x00\x00annotate_device_placement\x00\x08'\x07\x05\x1f\x01\x0b/79AC\x11IKM%O%%%", + xla_call_module_version=9, + nr_devices=1, +) # End paste + +# Pasted from the test output (see export_back_compat_test_util.py module docstring) +data_2025_04_07_tpu['shardy'] = dict( + testdata_version=1, + platform='tpu', + custom_call_targets=['annotate_device_placement', 'xla.sdy.FuncResultSharding'], + serialized_date=datetime.date(2025, 5, 28), + inputs=(array([0.], dtype=float32), array([0.], dtype=float32)), + expected_outputs=(array([0.], dtype=float32),), + mlir_module_text=r""" +#loc1 = loc("x") 
+#loc2 = loc("y") +module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.frontend_attributes = {xla.sdy.meshes = "{mesh = #sdy.mesh<[\22a\22=1]>}"}, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main(%arg0: tensor<1xf32> {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding<@mesh, [{\22a\22}]>"}, mhlo.memory_kind = "device", mhlo.sharding = "{devices=[1]<=[1]}"} loc("x"), %arg1: tensor<1xf32> {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding<@mesh, [{\22a\22}]>"}, mhlo.memory_kind = "pinned_host", mhlo.sharding = "{devices=[1]<=[1]}"} loc("y")) -> (tensor<1xf32> {jax.result_info = "result", mhlo.memory_kind = "pinned_host", mhlo.sharding = "{devices=[1]<=[1]}"}) { + %0 = stablehlo.add %arg0, %arg1 : tensor<1xf32> loc(#loc4) + %1 = stablehlo.custom_call @annotate_device_placement(%0) {has_side_effect = true, mhlo.frontend_attributes = {_xla_buffer_placement = "pinned_host"}} : (tensor<1xf32>) -> tensor<1xf32> loc(#loc) + %2 = stablehlo.custom_call @xla.sdy.FuncResultSharding(%1) {has_side_effect = true, mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\22a\22}]>]>"}} : (tensor<1xf32>) -> tensor<1xf32> loc(#loc) + return %2 : tensor<1xf32> loc(#loc) + } loc(#loc) +} loc(#loc) +#loc = loc(unknown) +#loc3 = loc("third_party/py/jax/tests/export_back_compat_test.py":801:13) +#loc4 = loc("jit(func)/jit(main)/add"(#loc3)) +""", + mlir_module_serialized=b'ML\xefR\rStableHLO_v1.10.3\x00\x01\x1b\x05\x01\x05\x0b\x01\x03\x0b\x03\t\x0f\x13\x17\x1b\x03\x85g\x0b\x01-\x07\x0b\x0f+\x0b\x0f\x13\x0b\x0b\x0b\x0b\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x17\x0b\x13\x13\x03;\x0b\x0b\x0b\x0b\x0b\x0b\x13\x0b\x0b\x0b\x0b\x13#\x0b\x0b#\x0b\x0f#\x0b\x0b\x0b\x0b\x13\x0b\x0b\x13\x0b\x0b\x01\x05\x0b\x0f\x03\x07\x13\x1b\x07\x02\x9a\x02\x1f\x05\x0f\x11\x03\x05\x03\t\t\x0b\x03\r\x13\x05\x15\x05\x05\x11\x11\x01\x00\x03\x03\x0f\x11\x05\x13\x05\x15\x05\x17\x05\x19\x05\x1b\x1d\x1b\x01\x05\x1d\x1d\x1f\x01\x05\x1f\x1d#%\x05!\x17\'\x86\x0c\x1b\x05#\x03\x03\x03[\x03\x03\x03a\x03\x01\x1d%\x1d\'\x1d)\x1d+\x1d\x0f\r\x03;G\x1d-\x0b\x03\x1d/\x05\x03\x03\x05EK\r\x0779/I13\x1d1\x1d3\r\x0779/513#\x07\x03\x03Q\r\x07SU/513\x1d5\x1d7\x1d9\x1d;\r\x03]5\x1d=\x1d?\r\x03;c\x1dA\x1dC\x01\t\x01\x02\x02)\x03\x05\t\x11\x05\x05\x05\x03\x05\t\x04w\x05\x01Q\x01\x07\x01\x07\x04e\x03\x01\x05\x05P\x01\x03\x07\x04Q\x03\x0b\x13\x05\x0b\x19\x0b\x1d\x00\x07\x06!\x03\x05\x05\x01\x03\x03G\x01)\x05\x03\x05\x03\x05\x03G\x01+\x07\x03\x05\x03\x07\t\x04\x01\x03\t\x06\x03\x01\x05\x01\x006\tE7Y5-\x0f\x0b\x0f!\x0f=\x03#\x19\'\x1d#i1\x05\x05\x13%)9\x1f93\x15\x0f\x11\x1f\x0f\x0b\x11builtin\x00vhlo\x00module\x00custom_call_v1\x00func_v1\x00add_v1\x00return_v1\x00mhlo.frontend_attributes\x00jax.uses_shape_polymorphism\x00xla.sdy.meshes\x00{mesh = #sdy.mesh<["a"=1]>}\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00x\x00y\x00jit(func)/jit(main)/add\x00third_party/py/jax/tests/export_back_compat_test.py\x00mhlo.memory_kind\x00mhlo.sharding\x00{devices=[1]<=[1]}\x00pinned_host\x00xla.sdy.sharding\x00\x00#sdy.sharding<@mesh, [{"a"}]>\x00device\x00jax.result_info\x00result\x00main\x00public\x00_xla_buffer_placement\x00annotate_device_placement\x00#sdy.sharding_per_value<[<@mesh, [{"a"}]>]>\x00xla.sdy.FuncResultSharding\x00\x089\t\x05/\x01\x0bCMOWY\x11=?_-A---\x11=?e-A---', + xla_call_module_version=9, + nr_devices=1, +) # End paste + + +data_2025_04_07_cuda = {} + +# Pasted from the test output (see export_back_compat_test_util.py module docstring) 
+data_2025_04_07_cuda['gspmd'] = dict( + testdata_version=1, + platform='cuda', + custom_call_targets=['annotate_device_placement'], + serialized_date=datetime.date(2025, 4, 7), + inputs=(array([0.], dtype=float32), array([0.], dtype=float32)), + expected_outputs=(array([0.], dtype=float32),), + mlir_module_text=r""" +#loc1 = loc("x") +#loc2 = loc("y") +module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main(%arg0: tensor<1xf32> {mhlo.memory_kind = "device", mhlo.sharding = "{maximal device=0}"} loc("x"), %arg1: tensor<1xf32> {mhlo.memory_kind = "pinned_host", mhlo.sharding = "{maximal device=0}"} loc("y")) -> (tensor<1xf32> {jax.result_info = "result", mhlo.memory_kind = "pinned_host", mhlo.sharding = "{maximal device=0}"}) { + %0 = stablehlo.add %arg0, %arg1 : tensor<1xf32> loc(#loc4) + %1 = stablehlo.custom_call @annotate_device_placement(%0) {has_side_effect = true, mhlo.frontend_attributes = {_xla_buffer_placement = "pinned_host"}} : (tensor<1xf32>) -> tensor<1xf32> loc(#loc) + return %1 : tensor<1xf32> loc(#loc) + } loc(#loc) +} loc(#loc) +#loc = loc(unknown) +#loc3 = loc("third_party/py/jax/tests/export_back_compat_test.py":878:13) +#loc4 = loc("jit(func)/jit(main)/add"(#loc3)) +""", + mlir_module_serialized=b"ML\xefR\rStableHLO_v1.9.3\x00\x01\x1b\x05\x01\x05\x0b\x01\x03\x0b\x03\t\x0f\x13\x17\x1b\x03oQ\x0b\x01%\x07\x0f#\x0b\x0f\x0b\x0b\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x17\x0b\x13\x0b\x03-\x0b\x0b\x0b\x0b\x0b\x13\x1b\x0b\x1b\x0b\x0f#\x0b\x0b\x0b\x0b\x13\x0b\x0b\x0b\x0b\x0b\x01\x05\x0b\x0f\x03\x07\x13\x1b\x07\x02\n\x02\x1f\x11\x03\x05\x03\x07\x07\t\x0b\x03\r\x03\x05\x0f\x11\x01\x00\x05\x11\x05\x13\x05\x15\x1d\x13\x01\x05\x17\x1d\x17\x01\x05\x19\x1d\x1b\x1d\x05\x1b\x17\x1f\xba\r\x1b\x05\x1d\x03\x03#E\x05\x1f\x03\x01\x1d!\x1d#\x1d%\x1d'\x03\x0515\r\x05'3)+\x1d)\r\x05'-)+#\x07\x03\x03;\r\x07=?'-)+\x1d+\x1d-\x1d/\x1d1\r\x03G-\x1d3\x0b\x03\x1d5\x1d7\x05\x03\x01\t\x01\x02\x02)\x03\x05\t\x11\x05\x05\x05\x03\x05\t\x04e\x05\x01Q\x01\x05\x01\x07\x04S\x03\x01\x05\x03P\x01\x03\x07\x04?\x03\t\x0f\x05\x0b\x11\x0b\x15\x00\x05\x06\x19\x03\x05\x05\x01\x03\x07G\x01!\x05\x03\x05\x03\x05\t\x04\x01\x03\x07\x06\x03\x01\x05\x01\x00\x9a\x0695\x03-\x0f\x0b\x0f!\x0f\x19'\x1d#3i1\x05\x05\x13%)9\x15\x1f\x0f\x11\x0f\x0b\x11builtin\x00vhlo\x00module\x00func_v1\x00add_v1\x00custom_call_v1\x00return_v1\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00x\x00y\x00jit(func)/jit(main)/add\x00third_party/py/jax/tests/export_back_compat_test.py\x00mhlo.frontend_attributes\x00mhlo.memory_kind\x00mhlo.sharding\x00{maximal device=0}\x00pinned_host\x00device\x00jax.result_info\x00result\x00main\x00public\x00_xla_buffer_placement\x00\x00annotate_device_placement\x00\x08'\x07\x05\x1f\x01\x0b/79AC\x11IKM%O%%%", + xla_call_module_version=9, + nr_devices=1, +) # End paste + + +# Pasted from the test output (see export_back_compat_test_util.py module docstring) +data_2025_04_07_cuda['shardy'] = dict( + testdata_version=1, + platform='cuda', + custom_call_targets=['annotate_device_placement', 'xla.sdy.FuncResultSharding'], + serialized_date=datetime.date(2025, 5, 28), + inputs=(array([0.], dtype=float32), array([0.], dtype=float32)), + expected_outputs=(array([0.], dtype=float32),), + mlir_module_text=r""" +#loc1 = loc("x") +#loc2 = loc("y") +module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.frontend_attributes = {xla.sdy.meshes = "{mesh = #sdy.mesh<[\22a\22=1]>}"}, 
mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main(%arg0: tensor<1xf32> {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding<@mesh, [{\22a\22}]>"}, mhlo.memory_kind = "device", mhlo.sharding = "{devices=[1]<=[1]}"} loc("x"), %arg1: tensor<1xf32> {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding<@mesh, [{\22a\22}]>"}, mhlo.memory_kind = "pinned_host", mhlo.sharding = "{devices=[1]<=[1]}"} loc("y")) -> (tensor<1xf32> {jax.result_info = "result", mhlo.memory_kind = "pinned_host", mhlo.sharding = "{devices=[1]<=[1]}"}) { + %0 = stablehlo.add %arg0, %arg1 : tensor<1xf32> loc(#loc4) + %1 = stablehlo.custom_call @annotate_device_placement(%0) {has_side_effect = true, mhlo.frontend_attributes = {_xla_buffer_placement = "pinned_host"}} : (tensor<1xf32>) -> tensor<1xf32> loc(#loc) + %2 = stablehlo.custom_call @xla.sdy.FuncResultSharding(%1) {has_side_effect = true, mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\22a\22}]>]>"}} : (tensor<1xf32>) -> tensor<1xf32> loc(#loc) + return %2 : tensor<1xf32> loc(#loc) + } loc(#loc) +} loc(#loc) +#loc = loc(unknown) +#loc3 = loc("third_party/py/jax/tests/export_back_compat_test.py":806:13) +#loc4 = loc("jit(func)/jit(main)/add"(#loc3)) +""", + mlir_module_serialized=b'ML\xefR\rStableHLO_v1.10.3\x00\x01\x1b\x05\x01\x05\x0b\x01\x03\x0b\x03\t\x0f\x13\x17\x1b\x03\x85g\x0b\x01-\x07\x0b\x0f+\x0b\x0f\x13\x0b\x0b\x0b\x0b\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x17\x0b\x13\x13\x03;\x0b\x0b\x0b\x0b\x0b\x0b\x13\x0b\x0b\x0b\x0b\x13#\x0b\x0b#\x0b\x0f#\x0b\x0b\x0b\x0b\x13\x0b\x0b\x13\x0b\x0b\x01\x05\x0b\x0f\x03\x07\x13\x1b\x07\x02\x9a\x02\x1f\x05\x0f\x11\x03\x05\x03\t\t\x0b\x03\r\x13\x05\x15\x05\x05\x11\x11\x01\x00\x03\x03\x0f\x11\x05\x13\x05\x15\x05\x17\x05\x19\x05\x1b\x1d\x1b\x01\x05\x1d\x1d\x1f\x01\x05\x1f\x1d#%\x05!\x17\'\x9a\x0c\x1b\x05#\x03\x03\x03[\x03\x03\x03a\x03\x01\x1d%\x1d\'\x1d)\x1d+\x1d\x0f\r\x03;G\x1d-\x0b\x03\x1d/\x05\x03\x03\x05EK\r\x0779/I13\x1d1\x1d3\r\x0779/513#\x07\x03\x03Q\r\x07SU/513\x1d5\x1d7\x1d9\x1d;\r\x03]5\x1d=\x1d?\r\x03;c\x1dA\x1dC\x01\t\x01\x02\x02)\x03\x05\t\x11\x05\x05\x05\x03\x05\t\x04w\x05\x01Q\x01\x07\x01\x07\x04e\x03\x01\x05\x05P\x01\x03\x07\x04Q\x03\x0b\x13\x05\x0b\x19\x0b\x1d\x00\x07\x06!\x03\x05\x05\x01\x03\x03G\x01)\x05\x03\x05\x03\x05\x03G\x01+\x07\x03\x05\x03\x07\t\x04\x01\x03\t\x06\x03\x01\x05\x01\x006\tE7Y5-\x0f\x0b\x0f!\x0f=\x03#\x19\'\x1d#i1\x05\x05\x13%)9\x1f93\x15\x0f\x11\x1f\x0f\x0b\x11builtin\x00vhlo\x00module\x00custom_call_v1\x00func_v1\x00add_v1\x00return_v1\x00mhlo.frontend_attributes\x00jax.uses_shape_polymorphism\x00xla.sdy.meshes\x00{mesh = #sdy.mesh<["a"=1]>}\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00x\x00y\x00jit(func)/jit(main)/add\x00third_party/py/jax/tests/export_back_compat_test.py\x00mhlo.memory_kind\x00mhlo.sharding\x00{devices=[1]<=[1]}\x00pinned_host\x00xla.sdy.sharding\x00\x00#sdy.sharding<@mesh, [{"a"}]>\x00device\x00jax.result_info\x00result\x00main\x00public\x00_xla_buffer_placement\x00annotate_device_placement\x00#sdy.sharding_per_value<[<@mesh, [{"a"}]>]>\x00xla.sdy.FuncResultSharding\x00\x089\t\x05/\x01\x0bCMOWY\x11=?_-A---\x11=?e-A---', + xla_call_module_version=9, + nr_devices=1, +) # End paste diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/cpu_cholesky_lapack_potrf.py b/jax/_src/internal_test_util/export_back_compat_test_data/cpu_cholesky_lapack_potrf.py index eb4143615da6..ee06d902d235 100644 --- 
a/jax/_src/internal_test_util/export_back_compat_test_data/cpu_cholesky_lapack_potrf.py +++ b/jax/_src/internal_test_util/export_back_compat_test_data/cpu_cholesky_lapack_potrf.py @@ -17,345 +17,8 @@ import datetime from numpy import array, float32, complex64 -data_2023_06_19 = {} - - -# Pasted from the test output (see back_compat_test.py module docstring) -data_2023_06_19["f32"] = dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['lapack_spotrf'], - serialized_date=datetime.date(2023, 6, 19), - inputs=(array([[ 24.343887, 13.603932, 20.50489 , 12.063956], - [ 13.603932, 58.879757, -31.84056 , 16.328012], - [ 20.50489 , -31.84056 , 66.890755, -9.92216 ], - [ 12.063956, 16.328012, -9.92216 , 23.640734]], dtype=float32),), - expected_outputs=(array([[ 4.9339523, 0. , 0. , 0. ], - [ 2.7572079, 7.1608353, 0. , 0. ], - [ 4.155875 , -6.0466647, 3.6134892, 0. ], - [ 2.4450896, 1.3387254, -3.3177967, 2.2050648]], dtype=float32),), - mlir_module_text=r""" -#loc = loc(unknown) -module @jit_cholesky attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main(%arg0: tensor<4x4xf32> {jax.arg_info = "x", mhlo.sharding = "{replicated}"} loc(unknown)) -> (tensor<4x4xf32> {jax.result_info = ""}) { - %0 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<4x4xf32>) -> tensor<4x4xf32> loc(#loc2) - %1 = stablehlo.add %arg0, %0 : tensor<4x4xf32> loc(#loc3) - %2 = stablehlo.constant dense<2.000000e+00> : tensor loc(#loc) - %3 = stablehlo.broadcast_in_dim %2, dims = [] : (tensor) -> tensor<4x4xf32> loc(#loc4) - %4 = stablehlo.divide %1, %3 : tensor<4x4xf32> loc(#loc4) - %5 = stablehlo.constant dense<1> : tensor loc(#loc5) - %6 = stablehlo.constant dense<1> : tensor loc(#loc5) - %7 = stablehlo.constant dense<4> : tensor loc(#loc5) - %8:2 = stablehlo.custom_call @lapack_spotrf(%5, %6, %7, %4) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>]} : (tensor, tensor, tensor, tensor<4x4xf32>) -> (tensor<4x4xf32>, tensor) loc(#loc5) - %9 = stablehlo.constant dense<0> : tensor loc(#loc5) - %10 = stablehlo.broadcast_in_dim %9, dims = [] : (tensor) -> tensor loc(#loc5) - %11 = stablehlo.compare EQ, %8#1, %10, SIGNED : (tensor, tensor) -> tensor loc(#loc5) - %12 = stablehlo.broadcast_in_dim %11, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc5) - %13 = stablehlo.constant dense<0x7FC00000> : tensor loc(#loc5) - %14 = stablehlo.broadcast_in_dim %13, dims = [] : (tensor) -> tensor<4x4xf32> loc(#loc5) - %15 = stablehlo.broadcast_in_dim %12, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc5) - %16 = stablehlo.select %15, %8#0, %14 : tensor<4x4xi1>, tensor<4x4xf32> loc(#loc5) - %17 = call @tril(%16) : (tensor<4x4xf32>) -> tensor<4x4xf32> loc(#loc6) - return %17 : tensor<4x4xf32> loc(#loc) - } loc(#loc) - func.func private @tril(%arg0: tensor<4x4xf32> loc(unknown)) -> tensor<4x4xf32> { - %0 = stablehlo.iota dim = 0 : tensor<4x4xi32> loc(#loc7) - %1 = stablehlo.constant dense<0> : tensor loc(#loc6) - %2 = stablehlo.broadcast_in_dim %1, dims = [] : (tensor) -> tensor<4x4xi32> loc(#loc8) - %3 = stablehlo.add %0, %2 : tensor<4x4xi32> loc(#loc8) - %4 = stablehlo.iota dim = 1 : tensor<4x4xi32> loc(#loc9) - %5 = stablehlo.compare GE, %3, %4, SIGNED : (tensor<4x4xi32>, tensor<4x4xi32>) -> tensor<4x4xi1> loc(#loc10) - %6 = 
stablehlo.constant dense<0.000000e+00> : tensor loc(#loc6) - %7 = stablehlo.broadcast_in_dim %6, dims = [] : (tensor) -> tensor<4x4xf32> loc(#loc11) - %8 = stablehlo.select %5, %arg0, %7 : tensor<4x4xi1>, tensor<4x4xf32> loc(#loc12) - return %8 : tensor<4x4xf32> loc(#loc6) - } loc(#loc6) -} loc(#loc) -#loc1 = loc("/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py":292:0) -#loc2 = loc("jit(cholesky)/jit(main)/transpose[permutation=(1, 0)]"(#loc1)) -#loc3 = loc("jit(cholesky)/jit(main)/add"(#loc1)) -#loc4 = loc("jit(cholesky)/jit(main)/div"(#loc1)) -#loc5 = loc("jit(cholesky)/jit(main)/cholesky"(#loc1)) -#loc6 = loc("jit(cholesky)/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) resource_env=None donated_invars=(False,) name=tril keep_unused=False inline=False]"(#loc1)) -#loc7 = loc("jit(cholesky)/jit(main)/jit(tril)/iota[dtype=int32 shape=(4, 4) dimension=0]"(#loc1)) -#loc8 = loc("jit(cholesky)/jit(main)/jit(tril)/add"(#loc1)) -#loc9 = loc("jit(cholesky)/jit(main)/jit(tril)/iota[dtype=int32 shape=(4, 4) dimension=1]"(#loc1)) -#loc10 = loc("jit(cholesky)/jit(main)/jit(tril)/ge"(#loc1)) -#loc11 = loc("jit(cholesky)/jit(main)/jit(tril)/broadcast_in_dim[shape=(4, 4) broadcast_dimensions=()]"(#loc1)) -#loc12 = loc("jit(cholesky)/jit(main)/jit(tril)/select_n"(#loc1)) -""", - mlir_module_serialized=b'ML\xefR\x01StableHLO_v0.9.0\x00\x01)\x05\x01\x03\x01\x03\x05\x03\x19\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x1d\x03"\x02\xd9%\x01\x87\x0f\x17\x07\x0b\x13\x0f\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x13\x0b\x0f\x0b\x0b\x0f\x13#\x0b\x0b\x0b33\x0b\x0b\x13\x0f\x0b\x0b\x13\x0f\x0b\x1b\x0f\x0b\x13\x0f\x0b\x0f\x0b\x13\x0b\x0f\x0b\x0f\x0b\x13\x0b\x0b\x13K\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x1b\x13\x13\x13\x0b\x03S\x0f\x0b\x0b\x0f\x0b\x0bO\x0f\x1b\x0b\x0b\x0b\x0b\x0f\x13\x0b\x0b\x0b\x0b\x0b\x0f\x1f\x0f\x0f\x0b\x1fO\x1f\x1f\x1f\x0b\x0b\x0b\x0b\x1b\x0f\x17\x13\x0b\x1fO\x01\x03\x0f\x03#\x17\x0f\x0f\x17\x07\x07\x07\x07\x17\x13\x07\x17\x13\x13\x13\x0f\x17\x02J\x07\x1dg\x03\x177\x92\x04\x01\x1f\x05\x1f\x03\x03\x1d\xb3\x1d5\x03\x05!\x11\x01\x05\x05#\x05%\x05\'\x05)\x05+\x03\x03\x07\xb1\x05-\x1d?\x03\x05/\x051\x1de\x03\x03\x03\x07\xbf\x03\x07+\x0f-\x0f\r/\x053\x055\x057\x03\x0b\x11\x95\x13\x89\x15\xa1\r\xa7\x17\xa9\x03\x0b\x11\x8d\x13\x89\x15\x8d\r\x8f\x17\xad\x059\x05;\x03\x03\x19\xaf\x1d=\x03\x05=\x05?\x03\x03\x19\xb5\x1dE\x03\x05A\x03\x05!\x91#\xb7\x1dK\x03\x05C\x03\x03\x07\xb9\x1dQ\x03\x05E\x1dU\x03\x05G\x03\x03Y\xbb\x05I\x1d]\x03\x05K\x1da\x03\x05M\x03\x03\x07\xbd\x05O\x05Q\x03\x03\x07\xc1\x03\x11m\xc3o\x8bq\xc5s\xc7u\xc9w\xcby\xcd{\xd1\x05S\x05U\x05W\x05Y\x05[\x05]\x05_\x05a\x03\x05!\x91#\xd3\x03\x03\x07\xd5\x03\x03\x1d\xd7\x03\x03\x85\x8f\x05c\x1f\x1d\x01#\x19\x1de\x03\x03\xab\x1dg\t\x07\x1f\x1f!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x03\x03\x97\r\x05\x99\x9b\x9d\x9f\x1di\x1dk\x1dm\x1do\x03\x03\xa3\r\x03\xa5\x8b\x1dq\x1ds\x1du\r\x01\x1dw\x13\x0b\x01\x1f\x05\t\x00\x00\x00\x00\x1f\x1b\x01\x13\x0b\x05\x07\x05\x1f\x07\t\x00\x00\x00\x00\x1f\x15!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x07\t\x00\x00\x00@\x1f\x05\t\x01\x00\x00\x00\x1f\x05\t\x04\x00\x00\x00\x0b\x05\x1dy\x03\x01\x05\x01\x03\t\x87\x87\x87\x93\x03\x03\xcf\x15\x03\x01\r\x01\x03\x05\x93\x87\x07\x01\x1f\x07\t\x00\x00\xc0\x7f\x1f\x15!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\x02\x02)\x05\x11\x11\x0f)\x01\x11)\x01\x0f)\x05\x11\x11\x11\x1d\x01\t\x1b)\x05\x11\x11\r)\x03\t\x0b\x13\x11\x03\x03\x03\x03)\x03\x01\x0b)\x03\x01\x17)\x03\t\x17)\x01
\r)\x05\x05\x05\r\x04\xd6\x03\x05\x01\x11\x05)\x07\x03\x01\t\x07\x11\x051\x05\x03)O\x03\x03\x05\x13\x07[W\x03\x03\x03\x01\x0b\x06_\x03\x03\x05\x01\x03\x03\x03\x05c\x03\x07\x05\x07%\t\x03\x03\x03\x07\x15\x06%\x03\x03\x05\x05\t\x03\x03\x01\'\x03\x05\x03\x03\x01\'\x03\x05\x03\x03\x01i\x03\x05\x17\x07\x01k\x05\x03\x05\t\r\x0f\x11\x0b\x03\x03\x01\x1b\x03\x05\x05\x07\x01\t\x03\x05\x03\x17\r\x07\x01}\x03!\x05\x15\x19\x05\x07\x01\t\x03#\x03\x1b\x03\x03\x01\x7f\x03\x07\x05\x07\x01\t\x03\x03\x03\x1f\x05\x07\x01\x81\x03\x13\x03\x1d\x0f\x06\x01\x03\x03\x07#\x13!\x19\x07\x0b\x83\x03\x03\x03%\x11\x04\x05\x03\'\x07\x11\x0b3\x05\x03\x15+\x03\x03\x05\t\x03;9\x03\t\x03\x03\x0b\x1b\x03\x05\x05\x07\x1f\t\x03\t\x03\x05\x0b\x06\x1f\x03\t\x05\x03\x07\t\x03CA\x03\t\r\x07IG\x03\x13\x05\t\x0b\x03\x03\x0bM\x03\x07\x05\x07O\t\x03\x03\x03\x0f\x0f\x06S\x03\x03\x07\r\x01\x11\x11\x04\x0b\x03\x13\x06\x03\x01\x05\x01\x00\n\x16{\x1d\x11\x0f\x0b!\x1b\x1d\x05\x1b\x0b\x03\x0f\x1f/!!)#\x1f\x19C99m\x19W\xb3K\x9bM\x9b\x97\xd2\x02\x1b%)+\x1b+\x1f\x1f\x15\x1d\x15\x13\r\x11\x1f\x15\x1b\x15\x15\x17\x0f\x11\x11)\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00broadcast_in_dim_v1\x00func_v1\x00iota_v1\x00add_v1\x00compare_v1\x00select_v1\x00return_v1\x00transpose_v1\x00divide_v1\x00custom_call_v1\x00call_v1\x00value\x00sym_name\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00broadcast_dimensions\x00compare_type\x00comparison_direction\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_cholesky\x00jit(cholesky)/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) resource_env=None donated_invars=(False,) name=tril keep_unused=False inline=False]\x00/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py\x00jit(cholesky)/jit(main)/jit(tril)/iota[dtype=int32 shape=(4, 4) dimension=0]\x00jit(cholesky)/jit(main)/jit(tril)/add\x00jit(cholesky)/jit(main)/jit(tril)/iota[dtype=int32 shape=(4, 4) dimension=1]\x00jit(cholesky)/jit(main)/jit(tril)/ge\x00jit(cholesky)/jit(main)/jit(tril)/broadcast_in_dim[shape=(4, 4) broadcast_dimensions=()]\x00jit(cholesky)/jit(main)/jit(tril)/select_n\x00permutation\x00jit(cholesky)/jit(main)/transpose[permutation=(1, 0)]\x00jit(cholesky)/jit(main)/add\x00jit(cholesky)/jit(main)/div\x00jit(cholesky)/jit(main)/cholesky\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00callee\x00\x00tril\x00jax.arg_info\x00x\x00mhlo.sharding\x00{replicated}\x00jax.result_info\x00main\x00public\x00private\x00lapack_spotrf\x00', - xla_call_module_version=6, -) # End paste - - -# Pasted from the test output (see back_compat_test.py module docstring) -data_2023_06_19["f64"] = dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['lapack_dpotrf'], - serialized_date=datetime.date(2023, 6, 19), - inputs=(array([[ 23.022171138130666 , -16.79765603341739 , 0.9133449305189146, - -25.36636199966769 ], - [-16.79765603341739 , 31.655770252600092 , -1.5189878284433445, - 20.0344758332268 ], - [ 0.9133449305189146, -1.5189878284433445, 10.940134497877208 , - 8.169020034607513 ], - [-25.36636199966769 , 20.0344758332268 , 8.169020034607513 , - 37.054603917509596 ]]),), - expected_outputs=(array([[ 4.7981424674691215 , 0. , 0. , - 0. ], - [-3.500866459740129 , 4.404509539513645 , 0. , - 0. ], - [ 0.19035385812557523, -0.1935707899825621 , 3.2964268922333835 , - 0. 
], - [-5.286704630312426 , 0.3465604732420997 , 2.8037778311164425 , - 1.060228174247855 ]]),), - mlir_module_text=r""" -#loc = loc(unknown) -module @jit_cholesky attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main(%arg0: tensor<4x4xf64> {jax.arg_info = "x", mhlo.sharding = "{replicated}"} loc(unknown)) -> (tensor<4x4xf64> {jax.result_info = ""}) { - %0 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<4x4xf64>) -> tensor<4x4xf64> loc(#loc2) - %1 = stablehlo.add %arg0, %0 : tensor<4x4xf64> loc(#loc3) - %2 = stablehlo.constant dense<2.000000e+00> : tensor loc(#loc) - %3 = stablehlo.broadcast_in_dim %2, dims = [] : (tensor) -> tensor<4x4xf64> loc(#loc4) - %4 = stablehlo.divide %1, %3 : tensor<4x4xf64> loc(#loc4) - %5 = stablehlo.constant dense<1> : tensor loc(#loc5) - %6 = stablehlo.constant dense<1> : tensor loc(#loc5) - %7 = stablehlo.constant dense<4> : tensor loc(#loc5) - %8:2 = stablehlo.custom_call @lapack_dpotrf(%5, %6, %7, %4) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>]} : (tensor, tensor, tensor, tensor<4x4xf64>) -> (tensor<4x4xf64>, tensor) loc(#loc5) - %9 = stablehlo.constant dense<0> : tensor loc(#loc5) - %10 = stablehlo.broadcast_in_dim %9, dims = [] : (tensor) -> tensor loc(#loc5) - %11 = stablehlo.compare EQ, %8#1, %10, SIGNED : (tensor, tensor) -> tensor loc(#loc5) - %12 = stablehlo.broadcast_in_dim %11, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc5) - %13 = stablehlo.constant dense<0x7FF8000000000000> : tensor loc(#loc5) - %14 = stablehlo.broadcast_in_dim %13, dims = [] : (tensor) -> tensor<4x4xf64> loc(#loc5) - %15 = stablehlo.broadcast_in_dim %12, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc5) - %16 = stablehlo.select %15, %8#0, %14 : tensor<4x4xi1>, tensor<4x4xf64> loc(#loc5) - %17 = call @tril(%16) : (tensor<4x4xf64>) -> tensor<4x4xf64> loc(#loc6) - return %17 : tensor<4x4xf64> loc(#loc) - } loc(#loc) - func.func private @tril(%arg0: tensor<4x4xf64> loc(unknown)) -> tensor<4x4xf64> { - %0 = stablehlo.iota dim = 0 : tensor<4x4xi32> loc(#loc7) - %1 = stablehlo.constant dense<0> : tensor loc(#loc6) - %2 = stablehlo.broadcast_in_dim %1, dims = [] : (tensor) -> tensor<4x4xi32> loc(#loc8) - %3 = stablehlo.add %0, %2 : tensor<4x4xi32> loc(#loc8) - %4 = stablehlo.iota dim = 1 : tensor<4x4xi32> loc(#loc9) - %5 = stablehlo.compare GE, %3, %4, SIGNED : (tensor<4x4xi32>, tensor<4x4xi32>) -> tensor<4x4xi1> loc(#loc10) - %6 = stablehlo.constant dense<0.000000e+00> : tensor loc(#loc6) - %7 = stablehlo.broadcast_in_dim %6, dims = [] : (tensor) -> tensor<4x4xf64> loc(#loc11) - %8 = stablehlo.select %5, %arg0, %7 : tensor<4x4xi1>, tensor<4x4xf64> loc(#loc12) - return %8 : tensor<4x4xf64> loc(#loc6) - } loc(#loc6) -} loc(#loc) -#loc1 = loc("/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py":292:0) -#loc2 = loc("jit(cholesky)/jit(main)/transpose[permutation=(1, 0)]"(#loc1)) -#loc3 = loc("jit(cholesky)/jit(main)/add"(#loc1)) -#loc4 = loc("jit(cholesky)/jit(main)/div"(#loc1)) -#loc5 = loc("jit(cholesky)/jit(main)/cholesky"(#loc1)) -#loc6 = loc("jit(cholesky)/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) resource_env=None donated_invars=(False,) name=tril keep_unused=False inline=False]"(#loc1)) -#loc7 = 
loc("jit(cholesky)/jit(main)/jit(tril)/iota[dtype=int32 shape=(4, 4) dimension=0]"(#loc1)) -#loc8 = loc("jit(cholesky)/jit(main)/jit(tril)/add"(#loc1)) -#loc9 = loc("jit(cholesky)/jit(main)/jit(tril)/iota[dtype=int32 shape=(4, 4) dimension=1]"(#loc1)) -#loc10 = loc("jit(cholesky)/jit(main)/jit(tril)/ge"(#loc1)) -#loc11 = loc("jit(cholesky)/jit(main)/jit(tril)/broadcast_in_dim[shape=(4, 4) broadcast_dimensions=()]"(#loc1)) -#loc12 = loc("jit(cholesky)/jit(main)/jit(tril)/select_n"(#loc1)) -""", - mlir_module_serialized=b'ML\xefR\x01StableHLO_v0.9.0\x00\x01)\x05\x01\x03\x01\x03\x05\x03\x19\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x1d\x03"\x02\xd9%\x01\x87\x0f\x17\x07\x0b\x13\x0f\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x13\x0b\x0f\x0b\x0b\x0f\x13#\x0b\x0b\x0b33\x0b\x0b\x13\x0f\x0b\x0b\x13\x0f\x0b\x1b\x0f\x0b\x13\x0f\x0b\x0f\x0b\x13\x0b\x0f\x0b\x0f\x0b\x13\x0b\x0b\x13K\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x1b\x13\x13\x13\x0b\x03S\x0f\x0b\x0b\x0f\x0b\x0bO\x0f\x1b\x0b\x0b\x0b\x0b\x0f\x13\x0b\x0b\x0b\x0b\x0b\x0f\x1f\x0f\x0f\x0b/O/\x1f\x1f\x0b\x0b\x0b\x0b\x1b\x0f\x17\x13\x0b/O\x01\x03\x0f\x03#\x17\x0f\x0f\x17\x07\x07\x07\x07\x17\x13\x07\x17\x13\x13\x13\x0f\x17\x02z\x07\x1dg\x03\x177\x92\x04\x01\x1f\x05\x1f\x03\x03\x1d\xb3\x1d5\x03\x05!\x11\x01\x05\x05#\x05%\x05\'\x05)\x05+\x03\x03\x07\xb1\x05-\x1d?\x03\x05/\x051\x1de\x03\x03\x03\x07\xbf\x03\x07+\x0f-\x0f\r/\x053\x055\x057\x03\x0b\x11\x95\x13\x89\x15\xa1\r\xa7\x17\xa9\x03\x0b\x11\x8d\x13\x89\x15\x8d\r\x8f\x17\xad\x059\x05;\x03\x03\x19\xaf\x1d=\x03\x05=\x05?\x03\x03\x19\xb5\x1dE\x03\x05A\x03\x05!\x91#\xb7\x1dK\x03\x05C\x03\x03\x07\xb9\x1dQ\x03\x05E\x1dU\x03\x05G\x03\x03Y\xbb\x05I\x1d]\x03\x05K\x1da\x03\x05M\x03\x03\x07\xbd\x05O\x05Q\x03\x03\x07\xc1\x03\x11m\xc3o\x8bq\xc5s\xc7u\xc9w\xcby\xcd{\xd1\x05S\x05U\x05W\x05Y\x05[\x05]\x05_\x05a\x03\x05!\x91#\xd3\x03\x03\x07\xd5\x03\x03\x1d\xd7\x03\x03\x85\x8f\x05c\x1f\x1d\x01#\x19\x1de\x03\x03\xab\x1dg\t\x07\x1f\x1f!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x03\x03\x97\r\x05\x99\x9b\x9d\x9f\x1di\x1dk\x1dm\x1do\x03\x03\xa3\r\x03\xa5\x8b\x1dq\x1ds\x1du\r\x01\x1dw\x13\x0b\x01\x1f\x05\t\x00\x00\x00\x00\x1f\x1b\x01\x13\x0b\x05\x07\x05\x1f\x07\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x15!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x07\x11\x00\x00\x00\x00\x00\x00\x00@\x1f\x05\t\x01\x00\x00\x00\x1f\x05\t\x04\x00\x00\x00\x0b\x05\x1dy\x03\x01\x05\x01\x03\t\x87\x87\x87\x93\x03\x03\xcf\x15\x03\x01\r\x01\x03\x05\x93\x87\x07\x01\x1f\x07\x11\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f\x15!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\x02\x02)\x05\x11\x11\x0f)\x01\x11)\x01\x0f)\x05\x11\x11\x11\x1d\x01\x0b\x1b)\x05\x11\x11\r)\x03\t\x0b\x13\x11\x03\x03\x03\x03)\x03\x01\x0b)\x03\x01\x17)\x03\t\x17)\x01\r)\x05\x05\x05\r\x04\xd6\x03\x05\x01\x11\x05)\x07\x03\x01\t\x07\x11\x051\x05\x03)O\x03\x03\x05\x13\x07[W\x03\x03\x03\x01\x0b\x06_\x03\x03\x05\x01\x03\x03\x03\x05c\x03\x07\x05\x07%\t\x03\x03\x03\x07\x15\x06%\x03\x03\x05\x05\t\x03\x03\x01\'\x03\x05\x03\x03\x01\'\x03\x05\x03\x03\x01i\x03\x05\x17\x07\x01k\x05\x03\x05\t\r\x0f\x11\x0b\x03\x03\x01\x1b\x03\x05\x05\x07\x01\t\x03\x05\x03\x17\r\x07\x01}\x03!\x05\x15\x19\x05\x07\x01\t\x03#\x03\x1b\x03\x03\x01\x7f\x03\x07\x05\x07\x01\t\x03\x03\x03\x1f\x05\x07\x01\x81\x03\x13\x03\x1d\x0f\x06\x01\x03\x03\x07#\x13!\x19\x07\x0b\x83\x03\x03\x03%\x11\x04\x05\x03\'\x07\x11\x0b3\x05\x03\x15+\x03\x03\x05\t\x03;9\x03\t\x03\x03\x0b\x1b\x03\x05\x05\x07\x1f\t\x03\t\x03\x05\x0b\x06\x1f\x03\t\x05\x03\x07\t\x03CA\x03\t\r\x07IG\x03\x13\x05\t\x0b\x03\x03\x0bM\x03\x07\
x05\x07O\t\x03\x03\x03\x0f\x0f\x06S\x03\x03\x07\r\x01\x11\x11\x04\x0b\x03\x13\x06\x03\x01\x05\x01\x00\n\x16{\x1d\x11\x0f\x0b!\x1b\x1d\x05\x1b\x0b\x03\x0f\x1f/!!)#\x1f\x19C99m\x19W\xb3K\x9bM\x9b\x97\xd2\x02\x1b%)+\x1b+\x1f\x1f\x15\x1d\x15\x13\r\x11\x1f\x15\x1b\x15\x15\x17\x0f\x11\x11)\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00broadcast_in_dim_v1\x00func_v1\x00iota_v1\x00add_v1\x00compare_v1\x00select_v1\x00return_v1\x00transpose_v1\x00divide_v1\x00custom_call_v1\x00call_v1\x00value\x00sym_name\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00broadcast_dimensions\x00compare_type\x00comparison_direction\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_cholesky\x00jit(cholesky)/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) resource_env=None donated_invars=(False,) name=tril keep_unused=False inline=False]\x00/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py\x00jit(cholesky)/jit(main)/jit(tril)/iota[dtype=int32 shape=(4, 4) dimension=0]\x00jit(cholesky)/jit(main)/jit(tril)/add\x00jit(cholesky)/jit(main)/jit(tril)/iota[dtype=int32 shape=(4, 4) dimension=1]\x00jit(cholesky)/jit(main)/jit(tril)/ge\x00jit(cholesky)/jit(main)/jit(tril)/broadcast_in_dim[shape=(4, 4) broadcast_dimensions=()]\x00jit(cholesky)/jit(main)/jit(tril)/select_n\x00permutation\x00jit(cholesky)/jit(main)/transpose[permutation=(1, 0)]\x00jit(cholesky)/jit(main)/add\x00jit(cholesky)/jit(main)/div\x00jit(cholesky)/jit(main)/cholesky\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00callee\x00\x00tril\x00jax.arg_info\x00x\x00mhlo.sharding\x00{replicated}\x00jax.result_info\x00main\x00public\x00private\x00lapack_dpotrf\x00', - xla_call_module_version=6, -) # End paste - - - -# Pasted from the test output (see back_compat_test.py module docstring) -data_2023_06_19["c64"] = dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['lapack_cpotrf'], - serialized_date=datetime.date(2023, 6, 19), - inputs=(array([[ 38.089394 +6.36582342e-09j, 3.3509154+3.13455486e+01j, - -0.5972489-3.80308151e+01j, -19.04205 +1.22770605e+01j], - [ 3.3509154-3.13455486e+01j, 73.875755 +4.06565448e-09j, - -12.427276 -1.23379612e+01j, 41.542507 -9.63993359e+00j], - [ -0.5972489+3.80308151e+01j, -12.427276 +1.23379612e+01j, - 73.04141 -4.18667753e-07j, 8.193126 -2.60565052e+01j], - [-19.04205 -1.22770605e+01j, 41.542507 +9.63993359e+00j, - 8.193126 +2.60565052e+01j, 52.977036 -1.09952367e-07j]], - dtype=complex64),), - expected_outputs=(array([[ 6.1716604 +0.j , 0. +0.j , - 0. +0.j , 0. +0.j ], - [ 0.542952 -5.078949j , 6.912687 +0.j , - 0. +0.j , 0. +0.j ], - [-0.09677281+6.162169j , 2.7373738 +1.3719271j, - 5.0679703 +0.j , 0. 
+0.j ], - [-3.0854013 -1.9892638j, 4.7903748 +3.8177056j, - 0.3555784 +0.5865844j, 1.2276335 +0.j ]], dtype=complex64),), - mlir_module_text=r""" -#loc = loc(unknown) -module @jit_cholesky attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main(%arg0: tensor<4x4xcomplex> {jax.arg_info = "x", mhlo.sharding = "{replicated}"} loc(unknown)) -> (tensor<4x4xcomplex> {jax.result_info = ""}) { - %0 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<4x4xcomplex>) -> tensor<4x4xcomplex> loc(#loc2) - %1 = stablehlo.real %0 : (tensor<4x4xcomplex>) -> tensor<4x4xf32> loc(#loc3) - %2 = stablehlo.imag %0 : (tensor<4x4xcomplex>) -> tensor<4x4xf32> loc(#loc4) - %3 = stablehlo.negate %2 : tensor<4x4xf32> loc(#loc5) - %4 = stablehlo.complex %1, %3 : tensor<4x4xcomplex> loc(#loc6) - %5 = stablehlo.add %arg0, %4 : tensor<4x4xcomplex> loc(#loc7) - %6 = stablehlo.constant dense<(2.000000e+00,0.000000e+00)> : tensor> loc(#loc) - %7 = stablehlo.broadcast_in_dim %6, dims = [] : (tensor>) -> tensor<4x4xcomplex> loc(#loc8) - %8 = stablehlo.divide %5, %7 : tensor<4x4xcomplex> loc(#loc8) - %9 = stablehlo.constant dense<1> : tensor loc(#loc9) - %10 = stablehlo.constant dense<1> : tensor loc(#loc9) - %11 = stablehlo.constant dense<4> : tensor loc(#loc9) - %12:2 = stablehlo.custom_call @lapack_cpotrf(%9, %10, %11, %8) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>]} : (tensor, tensor, tensor, tensor<4x4xcomplex>) -> (tensor<4x4xcomplex>, tensor) loc(#loc9) - %13 = stablehlo.constant dense<0> : tensor loc(#loc9) - %14 = stablehlo.broadcast_in_dim %13, dims = [] : (tensor) -> tensor loc(#loc9) - %15 = stablehlo.compare EQ, %12#1, %14, SIGNED : (tensor, tensor) -> tensor loc(#loc9) - %16 = stablehlo.broadcast_in_dim %15, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc9) - %17 = stablehlo.constant dense<(0x7FC00000,0x7FC00000)> : tensor> loc(#loc9) - %18 = stablehlo.broadcast_in_dim %17, dims = [] : (tensor>) -> tensor<4x4xcomplex> loc(#loc9) - %19 = stablehlo.broadcast_in_dim %16, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc9) - %20 = stablehlo.select %19, %12#0, %18 : tensor<4x4xi1>, tensor<4x4xcomplex> loc(#loc9) - %21 = call @tril(%20) : (tensor<4x4xcomplex>) -> tensor<4x4xcomplex> loc(#loc10) - return %21 : tensor<4x4xcomplex> loc(#loc) - } loc(#loc) - func.func private @tril(%arg0: tensor<4x4xcomplex> loc(unknown)) -> tensor<4x4xcomplex> { - %0 = stablehlo.iota dim = 0 : tensor<4x4xi32> loc(#loc11) - %1 = stablehlo.constant dense<0> : tensor loc(#loc10) - %2 = stablehlo.broadcast_in_dim %1, dims = [] : (tensor) -> tensor<4x4xi32> loc(#loc12) - %3 = stablehlo.add %0, %2 : tensor<4x4xi32> loc(#loc12) - %4 = stablehlo.iota dim = 1 : tensor<4x4xi32> loc(#loc13) - %5 = stablehlo.compare GE, %3, %4, SIGNED : (tensor<4x4xi32>, tensor<4x4xi32>) -> tensor<4x4xi1> loc(#loc14) - %6 = stablehlo.constant dense<(0.000000e+00,0.000000e+00)> : tensor> loc(#loc10) - %7 = stablehlo.broadcast_in_dim %6, dims = [] : (tensor>) -> tensor<4x4xcomplex> loc(#loc15) - %8 = stablehlo.select %5, %arg0, %7 : tensor<4x4xi1>, tensor<4x4xcomplex> loc(#loc16) - return %8 : tensor<4x4xcomplex> loc(#loc10) - } loc(#loc10) -} loc(#loc) -#loc1 = loc("/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py":292:0) -#loc2 = 
loc("jit(cholesky)/jit(main)/transpose[permutation=(1, 0)]"(#loc1)) -#loc3 = loc("jit(cholesky)/jit(main)/real"(#loc1)) -#loc4 = loc("jit(cholesky)/jit(main)/imag"(#loc1)) -#loc5 = loc("jit(cholesky)/jit(main)/neg"(#loc1)) -#loc6 = loc("jit(cholesky)/jit(main)/complex"(#loc1)) -#loc7 = loc("jit(cholesky)/jit(main)/add"(#loc1)) -#loc8 = loc("jit(cholesky)/jit(main)/div"(#loc1)) -#loc9 = loc("jit(cholesky)/jit(main)/cholesky"(#loc1)) -#loc10 = loc("jit(cholesky)/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) resource_env=None donated_invars=(False,) name=tril keep_unused=False inline=False]"(#loc1)) -#loc11 = loc("jit(cholesky)/jit(main)/jit(tril)/iota[dtype=int32 shape=(4, 4) dimension=0]"(#loc1)) -#loc12 = loc("jit(cholesky)/jit(main)/jit(tril)/add"(#loc1)) -#loc13 = loc("jit(cholesky)/jit(main)/jit(tril)/iota[dtype=int32 shape=(4, 4) dimension=1]"(#loc1)) -#loc14 = loc("jit(cholesky)/jit(main)/jit(tril)/ge"(#loc1)) -#loc15 = loc("jit(cholesky)/jit(main)/jit(tril)/broadcast_in_dim[shape=(4, 4) broadcast_dimensions=()]"(#loc1)) -#loc16 = loc("jit(cholesky)/jit(main)/jit(tril)/select_n"(#loc1)) -""", - mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x011\x05\x01\x03\x01\x03\x05\x03!\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x1d\x1f!#%\x03J\x02\xe9)\x01\x97\x17\x0f\x07\x0b\x13\x0f\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x13\x0b\x0f\x0b\x0b\x0f\x13#\x0b\x0b\x0b33\x0b\x0b\x13\x0f\x0b\x0b\x13\x0f\x0b\x1b\x0f\x0b\x13\x0f\x0b\x0f\x0b\x13\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x13\x0b\x0b\x13K\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x1b\x13\x13\x13\x0b\x03S\x0f\x0b\x0b\x0f\x0b\x0bO\x0f\x1b\x0b\x0b\x0b\x0b\x0f\x13\x0b\x0b\x0b\x0b\x0b\x0f\x1f\x0f\x0f\x0b/O/\x1f\x1f\x0b\x0b\x0b\x0b\x1b\x0f\x17\x13\x0b/O\x01\x03\x0f\x03'\x17\x0f\x0f\x17\x07\x07\x17\x0b\x07\x07\x17\x13\x07\x17\x13\x13\x13\x0f\x17\x02\xe6\x07\x177\x92\x04\x01\x1dw\x01\x1f\x05'\x03\x03\x1d\xc3\x1d5\x01\x05)\x11\x01\x05\x05+\x05-\x05/\x051\x053\x03\x03\x07\xc1\x055\x1d?\x01\x057\x059\x1du\x01\x03\x03\x07\xcf\x03\x07+\x0f-\x0f\r/\x05;\x05=\x05?\x03\x0b\x11\xa5\x13\x99\x15\xb1\r\xb7\x17\xb9\x03\x0b\x11\x9d\x13\x99\x15\x9d\r\x9f\x17\xbd\x05A\x05C\x03\x03\x19\xbf\x1d=\x01\x05E\x05G\x03\x03\x19\xc5\x1dE\x01\x05I\x03\x05!\xa1#\xc7\x1dK\x01\x05K\x03\x03\x07\xc9\x1dQ\x01\x05M\x1dU\x01\x05O\x03\x03Y\xcb\x05Q\x1d]\x01\x05S\x1da\x01\x05U\x1de\x01\x05W\x1di\x01\x05Y\x1dm\x01\x05[\x1dq\x01\x05]\x03\x03\x07\xcd\x05_\x05a\x03\x03\x07\xd1\x03\x11}\xd3\x7f\x9b\x81\xd5\x83\xd7\x85\xd9\x87\xdb\x89\xdd\x8b\xe1\x05c\x05e\x05g\x05i\x05k\x05m\x05o\x05q\x03\x05!\xa1#\xe3\x03\x03\x07\xe5\x03\x03\x1d\xe7\x03\x03\x95\x9f\x05s\x1f!\x01#\x1d\x1du\x03\x03\xbb\x1dw\t\x07\x1f#!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x03\x03\xa7\r\x05\xa9\xab\xad\xaf\x1dy\x1d{\x1d}\x1d\x7f\x03\x03\xb3\r\x03\xb5\x9b\x1d\x81\x1d\x83\x1d\x85\r\x01\x1d\x87\x13\x0b\x01\x1f\x05\t\x00\x00\x00\x00\x1f\x1f\x01\x13\x0b\x05\x07\x05\x1f\x07\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x19!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x07\x11\x00\x00\x00@\x00\x00\x00\x00\x1f\x05\t\x01\x00\x00\x00\x1f\x05\t\x04\x00\x00\x00\x0b\x05\x1d\x89\x03\x01\x05\x01\x03\t\x97\x97\x97\xa3\x03\x03\xdf\x15\x03\x01\r\x01\x03\x05\xa3\x97\x07\x01\x1f\x07\x11\x00\x00\xc0\x7f\x00\x00\xc0\x7f\x1f\x19!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\x02\x02)\x05\x11\x11\x11)\x01\x15)\x01\x11)\x05\x11\x11\x15\x1d\x01)\x05\x11\x11\x13\x03\x13\t\x1b)\x05\x11\x11\r)\x03\t\x0b\x13\x11\x03\x03\x03\x03)\x03\x01\x0b)\x03\x01\x1b)\x03\t\
x1b)\x01\r)\x05\x05\x05\r\x04J\x04\x05\x01\x11\x05)\x07\x03\x01\t\x07\x11\x051\x05\x031_\x03\x03\x05\x13\x07[W\x03\x03\x03\x01\x15\x06_\x03\x0f\x03\x03\x17\x06c\x03\x0f\x03\x03\x19\x06g\x03\x0f\x03\x07\x1b\x06k\x03\x03\x05\x05\t\x0b\x06o\x03\x03\x05\x01\x0b\x03\x03\x05s\x03\x07\x05\x07%\t\x03\x03\x03\x0f\x1d\x06%\x03\x03\x05\r\x11\x03\x03\x03'\x03\x05\x03\x03\x03'\x03\x05\x03\x03\x03y\x03\x05\x1f\x07\x03{\x05\x03\x05\t\x15\x17\x19\x13\x03\x03\x03\x1b\x03\x05\x05\x07\x03\t\x03\x05\x03\x1f\r\x07\x03\x8d\x03%\x05\x1d!\x05\x07\x03\t\x03'\x03#\x03\x03\x03\x8f\x03\x07\x05\x07\x03\t\x03\x03\x03'\x05\x07\x03\x91\x03\x17\x03%\x0f\x06\x03\x03\x03\x07+\x1b)!\x07\x0b\x93\x03\x03\x03-\x11\x04\x05\x03/\x07\x11\x0b3\x05\x03\x15+\x03\x03\x05\t\x03;9\x03\t\x03\x03\x0b\x1b\x03\x05\x05\x07\x1f\t\x03\t\x03\x05\x0b\x06\x1f\x03\t\x05\x03\x07\t\x03CA\x03\t\r\x07IG\x03\x17\x05\t\x0b\x03\x03\x0bM\x03\x07\x05\x07O\t\x03\x03\x03\x0f\x0f\x06S\x03\x03\x07\r\x01\x11\x11\x04\x0b\x03\x13\x06\x03\x01\x05\x01\x00\x96\x18\x8b\x1d\x11\x0f\x0b!\x1b\x1d\x05\x1b\x0b\x03\x0f\x1f/!!)#\x1f\x19C99A9;;m\x19W\xb3K\x9bM\x9b\x97\xd2\x02\x1b%)+\x1b+\x1f\x1f\x15\x1d\x15\x13\r\x11\x1f\x15\x17\x15\x11\x11\x1b\x15\x15\x17\x0f\x11\x11)\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00broadcast_in_dim_v1\x00func_v1\x00iota_v1\x00add_v1\x00compare_v1\x00select_v1\x00return_v1\x00transpose_v1\x00real_v1\x00imag_v1\x00negate_v1\x00complex_v1\x00divide_v1\x00custom_call_v1\x00call_v1\x00value\x00sym_name\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00broadcast_dimensions\x00compare_type\x00comparison_direction\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_cholesky\x00jit(cholesky)/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) resource_env=None donated_invars=(False,) name=tril keep_unused=False inline=False]\x00/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py\x00jit(cholesky)/jit(main)/jit(tril)/iota[dtype=int32 shape=(4, 4) dimension=0]\x00jit(cholesky)/jit(main)/jit(tril)/add\x00jit(cholesky)/jit(main)/jit(tril)/iota[dtype=int32 shape=(4, 4) dimension=1]\x00jit(cholesky)/jit(main)/jit(tril)/ge\x00jit(cholesky)/jit(main)/jit(tril)/broadcast_in_dim[shape=(4, 4) broadcast_dimensions=()]\x00jit(cholesky)/jit(main)/jit(tril)/select_n\x00permutation\x00jit(cholesky)/jit(main)/transpose[permutation=(1, 0)]\x00jit(cholesky)/jit(main)/real\x00jit(cholesky)/jit(main)/imag\x00jit(cholesky)/jit(main)/neg\x00jit(cholesky)/jit(main)/complex\x00jit(cholesky)/jit(main)/add\x00jit(cholesky)/jit(main)/div\x00jit(cholesky)/jit(main)/cholesky\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00callee\x00\x00tril\x00jax.arg_info\x00x\x00mhlo.sharding\x00{replicated}\x00jax.result_info\x00main\x00public\x00private\x00lapack_cpotrf\x00", - xla_call_module_version=6, -) # End paste - - -# Pasted from the test output (see back_compat_test.py module docstring) -data_2023_06_19["c128"] = dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['lapack_zpotrf'], - serialized_date=datetime.date(2023, 6, 19), - inputs=(array([[ 77.35445791180521 -6.4555004827448569e-16j, - 16.89356598261691 -5.4959586590823566e+00j, - -21.124380423202325+6.4431220601700787e+01j, - 55.385054340628855+2.5198457006849742e+00j], - [ 16.89356598261691 +5.4959586590823566e+00j, - 67.125263428637 -3.2921739472953976e-16j, - 25.14078382035968 +1.2783276691803774e+01j, - 
51.116221409460884-2.2635508887939348e+00j], - [-21.124380423202325-6.4431220601700787e+01j, - 25.14078382035968 -1.2783276691803774e+01j, - 107.43449297637208 -2.8959717546347756e-15j, - 12.493792156221616-5.7556567757218694e+01j], - [ 55.385054340628855-2.5198457006849715e+00j, - 51.116221409460884+2.2635508887939326e+00j, - 12.493792156221616+5.7556567757218708e+01j, - 78.9856503203742 +2.0971925518284437e-16j]]),), - expected_outputs=(array([[ 8.795138311124232 +0.j , - 0. +0.j , - 0. +0.j , - 0. +0.j ], - [ 1.9207845726825759+0.624885984127274j , - 7.940111306576433 +0.j , - 0. +0.j , - 0. +0.j ], - [-2.401824698593298 -7.325776846534311j , - 4.3238621722485755-0.026813746599595675j, - 5.413152651345813 +0.j , - 0. +0.j ], - [ 6.297235174866659 -0.28650438589440164j , - 4.936910868956218 +0.849977768846063j , - 0.7751580530200595+1.279980716041562j , - 3.451611642915363 +0.j ]]),), - mlir_module_text=r""" -#loc = loc(unknown) -module @jit_cholesky attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main(%arg0: tensor<4x4xcomplex> {jax.arg_info = "x", mhlo.sharding = "{replicated}"} loc(unknown)) -> (tensor<4x4xcomplex> {jax.result_info = ""}) { - %0 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<4x4xcomplex>) -> tensor<4x4xcomplex> loc(#loc2) - %1 = stablehlo.real %0 : (tensor<4x4xcomplex>) -> tensor<4x4xf64> loc(#loc3) - %2 = stablehlo.imag %0 : (tensor<4x4xcomplex>) -> tensor<4x4xf64> loc(#loc4) - %3 = stablehlo.negate %2 : tensor<4x4xf64> loc(#loc5) - %4 = stablehlo.complex %1, %3 : tensor<4x4xcomplex> loc(#loc6) - %5 = stablehlo.add %arg0, %4 : tensor<4x4xcomplex> loc(#loc7) - %6 = stablehlo.constant dense<(2.000000e+00,0.000000e+00)> : tensor> loc(#loc) - %7 = stablehlo.broadcast_in_dim %6, dims = [] : (tensor>) -> tensor<4x4xcomplex> loc(#loc8) - %8 = stablehlo.divide %5, %7 : tensor<4x4xcomplex> loc(#loc8) - %9 = stablehlo.constant dense<1> : tensor loc(#loc9) - %10 = stablehlo.constant dense<1> : tensor loc(#loc9) - %11 = stablehlo.constant dense<4> : tensor loc(#loc9) - %12:2 = stablehlo.custom_call @lapack_zpotrf(%9, %10, %11, %8) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>]} : (tensor, tensor, tensor, tensor<4x4xcomplex>) -> (tensor<4x4xcomplex>, tensor) loc(#loc9) - %13 = stablehlo.constant dense<0> : tensor loc(#loc9) - %14 = stablehlo.broadcast_in_dim %13, dims = [] : (tensor) -> tensor loc(#loc9) - %15 = stablehlo.compare EQ, %12#1, %14, SIGNED : (tensor, tensor) -> tensor loc(#loc9) - %16 = stablehlo.broadcast_in_dim %15, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc9) - %17 = stablehlo.constant dense<(0x7FF8000000000000,0x7FF8000000000000)> : tensor> loc(#loc9) - %18 = stablehlo.broadcast_in_dim %17, dims = [] : (tensor>) -> tensor<4x4xcomplex> loc(#loc9) - %19 = stablehlo.broadcast_in_dim %16, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc9) - %20 = stablehlo.select %19, %12#0, %18 : tensor<4x4xi1>, tensor<4x4xcomplex> loc(#loc9) - %21 = call @tril(%20) : (tensor<4x4xcomplex>) -> tensor<4x4xcomplex> loc(#loc10) - return %21 : tensor<4x4xcomplex> loc(#loc) - } loc(#loc) - func.func private @tril(%arg0: tensor<4x4xcomplex> loc(unknown)) -> tensor<4x4xcomplex> { - %0 = stablehlo.iota dim = 0 : tensor<4x4xi32> loc(#loc11) - %1 = stablehlo.constant dense<0> : 
tensor loc(#loc10) - %2 = stablehlo.broadcast_in_dim %1, dims = [] : (tensor) -> tensor<4x4xi32> loc(#loc12) - %3 = stablehlo.add %0, %2 : tensor<4x4xi32> loc(#loc12) - %4 = stablehlo.iota dim = 1 : tensor<4x4xi32> loc(#loc13) - %5 = stablehlo.compare GE, %3, %4, SIGNED : (tensor<4x4xi32>, tensor<4x4xi32>) -> tensor<4x4xi1> loc(#loc14) - %6 = stablehlo.constant dense<(0.000000e+00,0.000000e+00)> : tensor> loc(#loc10) - %7 = stablehlo.broadcast_in_dim %6, dims = [] : (tensor>) -> tensor<4x4xcomplex> loc(#loc15) - %8 = stablehlo.select %5, %arg0, %7 : tensor<4x4xi1>, tensor<4x4xcomplex> loc(#loc16) - return %8 : tensor<4x4xcomplex> loc(#loc10) - } loc(#loc10) -} loc(#loc) -#loc1 = loc("/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py":292:0) -#loc2 = loc("jit(cholesky)/jit(main)/transpose[permutation=(1, 0)]"(#loc1)) -#loc3 = loc("jit(cholesky)/jit(main)/real"(#loc1)) -#loc4 = loc("jit(cholesky)/jit(main)/imag"(#loc1)) -#loc5 = loc("jit(cholesky)/jit(main)/neg"(#loc1)) -#loc6 = loc("jit(cholesky)/jit(main)/complex"(#loc1)) -#loc7 = loc("jit(cholesky)/jit(main)/add"(#loc1)) -#loc8 = loc("jit(cholesky)/jit(main)/div"(#loc1)) -#loc9 = loc("jit(cholesky)/jit(main)/cholesky"(#loc1)) -#loc10 = loc("jit(cholesky)/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) resource_env=None donated_invars=(False,) name=tril keep_unused=False inline=False]"(#loc1)) -#loc11 = loc("jit(cholesky)/jit(main)/jit(tril)/iota[dtype=int32 shape=(4, 4) dimension=0]"(#loc1)) -#loc12 = loc("jit(cholesky)/jit(main)/jit(tril)/add"(#loc1)) -#loc13 = loc("jit(cholesky)/jit(main)/jit(tril)/iota[dtype=int32 shape=(4, 4) dimension=1]"(#loc1)) -#loc14 = loc("jit(cholesky)/jit(main)/jit(tril)/ge"(#loc1)) -#loc15 = loc("jit(cholesky)/jit(main)/jit(tril)/broadcast_in_dim[shape=(4, 4) broadcast_dimensions=()]"(#loc1)) -#loc16 = loc("jit(cholesky)/jit(main)/jit(tril)/select_n"(#loc1)) -""", - 
mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x011\x05\x01\x03\x01\x03\x05\x03!\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x1d\x1f!#%\x03J\x02\xe9)\x01\x97\x17\x0f\x07\x0b\x13\x0f\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x13\x0b\x0f\x0b\x0b\x0f\x13#\x0b\x0b\x0b33\x0b\x0b\x13\x0f\x0b\x0b\x13\x0f\x0b\x1b\x0f\x0b\x13\x0f\x0b\x0f\x0b\x13\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x13\x0b\x0b\x13K\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x1b\x13\x13\x13\x0b\x03S\x0f\x0b\x0b\x0f\x0b\x0bO\x0f\x1b\x0b\x0b\x0b\x0b\x0f\x13\x0b\x0b\x0b\x0b\x0b\x0f\x1f\x0f\x0f\x0bOOO\x1f\x1f\x0b\x0b\x0b\x0b\x1b\x0f\x17\x13\x0bOO\x01\x03\x0f\x03'\x17\x0f\x0f\x17\x07\x07\x17\x0b\x07\x07\x17\x13\x07\x17\x13\x13\x13\x0f\x17\x02F\x08\x177\x92\x04\x01\x1dw\x01\x1f\x05'\x03\x03\x1d\xc3\x1d5\x01\x05)\x11\x01\x05\x05+\x05-\x05/\x051\x053\x03\x03\x07\xc1\x055\x1d?\x01\x057\x059\x1du\x01\x03\x03\x07\xcf\x03\x07+\x0f-\x0f\r/\x05;\x05=\x05?\x03\x0b\x11\xa5\x13\x99\x15\xb1\r\xb7\x17\xb9\x03\x0b\x11\x9d\x13\x99\x15\x9d\r\x9f\x17\xbd\x05A\x05C\x03\x03\x19\xbf\x1d=\x01\x05E\x05G\x03\x03\x19\xc5\x1dE\x01\x05I\x03\x05!\xa1#\xc7\x1dK\x01\x05K\x03\x03\x07\xc9\x1dQ\x01\x05M\x1dU\x01\x05O\x03\x03Y\xcb\x05Q\x1d]\x01\x05S\x1da\x01\x05U\x1de\x01\x05W\x1di\x01\x05Y\x1dm\x01\x05[\x1dq\x01\x05]\x03\x03\x07\xcd\x05_\x05a\x03\x03\x07\xd1\x03\x11}\xd3\x7f\x9b\x81\xd5\x83\xd7\x85\xd9\x87\xdb\x89\xdd\x8b\xe1\x05c\x05e\x05g\x05i\x05k\x05m\x05o\x05q\x03\x05!\xa1#\xe3\x03\x03\x07\xe5\x03\x03\x1d\xe7\x03\x03\x95\x9f\x05s\x1f!\x01#\x1d\x1du\x03\x03\xbb\x1dw\t\x07\x1f#!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x03\x03\xa7\r\x05\xa9\xab\xad\xaf\x1dy\x1d{\x1d}\x1d\x7f\x03\x03\xb3\r\x03\xb5\x9b\x1d\x81\x1d\x83\x1d\x85\r\x01\x1d\x87\x13\x0b\x01\x1f\x05\t\x00\x00\x00\x00\x1f\x1f\x01\x13\x0b\x05\x07\x05\x1f\x07!\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x19!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x07!\x00\x00\x00\x00\x00\x00\x00@\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x05\t\x01\x00\x00\x00\x1f\x05\t\x04\x00\x00\x00\x0b\x05\x1d\x89\x03\x01\x05\x01\x03\t\x97\x97\x97\xa3\x03\x03\xdf\x15\x03\x01\r\x01\x03\x05\xa3\x97\x07\x01\x1f\x07!\x00\x00\x00\x00\x00\x00\xf8\x7f\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f\x19!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\x02\x02)\x05\x11\x11\x11)\x01\x15)\x01\x11)\x05\x11\x11\x15\x1d\x01)\x05\x11\x11\x13\x03\x13\x0b\x1b)\x05\x11\x11\r)\x03\t\x0b\x13\x11\x03\x03\x03\x03)\x03\x01\x0b)\x03\x01\x1b)\x03\t\x1b)\x01\r)\x05\x05\x05\r\x04J\x04\x05\x01\x11\x05)\x07\x03\x01\t\x07\x11\x051\x05\x031_\x03\x03\x05\x13\x07[W\x03\x03\x03\x01\x15\x06_\x03\x0f\x03\x03\x17\x06c\x03\x0f\x03\x03\x19\x06g\x03\x0f\x03\x07\x1b\x06k\x03\x03\x05\x05\t\x0b\x06o\x03\x03\x05\x01\x0b\x03\x03\x05s\x03\x07\x05\x07%\t\x03\x03\x03\x0f\x1d\x06%\x03\x03\x05\r\x11\x03\x03\x03'\x03\x05\x03\x03\x03'\x03\x05\x03\x03\x03y\x03\x05\x1f\x07\x03{\x05\x03\x05\t\x15\x17\x19\x13\x03\x03\x03\x1b\x03\x05\x05\x07\x03\t\x03\x05\x03\x1f\r\x07\x03\x8d\x03%\x05\x1d!\x05\x07\x03\t\x03'\x03#\x03\x03\x03\x8f\x03\x07\x05\x07\x03\t\x03\x03\x03'\x05\x07\x03\x91\x03\x17\x03%\x0f\x06\x03\x03\x03\x07+\x1b)!\x07\x0b\x93\x03\x03\x03-\x11\x04\x05\x03/\x07\x11\x0b3\x05\x03\x15+\x03\x03\x05\t\x03;9\x03\t\x03\x03\x0b\x1b\x03\x05\x05\x07\x1f\t\x03\t\x03\x05\x0b\x06\x1f\x03\t\x05\x03\x07\t\x03CA\x03\t\r\x07IG\x03\x17\x05\t\x0b\x03\x03\x0bM\x03\x07\x05\x07O\t\x03\x03\x03\x0f\x0f\x06S\x03\x03\x07\r\x01\x11\x11\x04\x0b\x03\x13\x06\x03\x01\x05\x01\x00\x96\x18\x8b\x1d\x11\x0f\x0b!\x1b\x1d\x05\x1b\x0b\x03\x0f\x1f/!!)#\x1f\x1
9C99A9;;m\x19W\xb3K\x9bM\x9b\x97\xd2\x02\x1b%)+\x1b+\x1f\x1f\x15\x1d\x15\x13\r\x11\x1f\x15\x17\x15\x11\x11\x1b\x15\x15\x17\x0f\x11\x11)\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00broadcast_in_dim_v1\x00func_v1\x00iota_v1\x00add_v1\x00compare_v1\x00select_v1\x00return_v1\x00transpose_v1\x00real_v1\x00imag_v1\x00negate_v1\x00complex_v1\x00divide_v1\x00custom_call_v1\x00call_v1\x00value\x00sym_name\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00broadcast_dimensions\x00compare_type\x00comparison_direction\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_cholesky\x00jit(cholesky)/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) resource_env=None donated_invars=(False,) name=tril keep_unused=False inline=False]\x00/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py\x00jit(cholesky)/jit(main)/jit(tril)/iota[dtype=int32 shape=(4, 4) dimension=0]\x00jit(cholesky)/jit(main)/jit(tril)/add\x00jit(cholesky)/jit(main)/jit(tril)/iota[dtype=int32 shape=(4, 4) dimension=1]\x00jit(cholesky)/jit(main)/jit(tril)/ge\x00jit(cholesky)/jit(main)/jit(tril)/broadcast_in_dim[shape=(4, 4) broadcast_dimensions=()]\x00jit(cholesky)/jit(main)/jit(tril)/select_n\x00permutation\x00jit(cholesky)/jit(main)/transpose[permutation=(1, 0)]\x00jit(cholesky)/jit(main)/real\x00jit(cholesky)/jit(main)/imag\x00jit(cholesky)/jit(main)/neg\x00jit(cholesky)/jit(main)/complex\x00jit(cholesky)/jit(main)/add\x00jit(cholesky)/jit(main)/div\x00jit(cholesky)/jit(main)/cholesky\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00callee\x00\x00tril\x00jax.arg_info\x00x\x00mhlo.sharding\x00{replicated}\x00jax.result_info\x00main\x00public\x00private\x00lapack_zpotrf\x00", - xla_call_module_version=6, -) # End paste - data_2024_05_31 = {} - # Pasted from the test output (see export_back_compat_test_util.py module docstring) data_2024_05_31["c128"] = dict( testdata_version=1, diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/cpu_eig_lapack_geev.py b/jax/_src/internal_test_util/export_back_compat_test_data/cpu_eig_lapack_geev.py index bc28857fa325..e6792dc2d1b4 100644 --- a/jax/_src/internal_test_util/export_back_compat_test_data/cpu_eig_lapack_geev.py +++ b/jax/_src/internal_test_util/export_back_compat_test_data/cpu_eig_lapack_geev.py @@ -15,279 +15,10 @@ # ruff: noqa import datetime -from numpy import array, float32, complex64 - -data_2023_06_19 = {} - -# Pasted from the test output (see back_compat_test.py module docstring) -data_2023_06_19["f32"] = dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['lapack_sgeev'], - serialized_date=datetime.date(2023, 6, 19), - inputs=(), - expected_outputs=(array([ 3.2464241e+01+0.j, -2.4642489e+00+0.j, 1.4189274e-07+0.j, - -4.0686123e-07+0.j], dtype=complex64), array([[-0.40377745 +0.j, -0.82883257 +0.j, -0.06733338 +0.j, - -0.5208027 +0.j], - [-0.46480742 +0.j, -0.4371466 +0.j, 0.49492982 +0.j, - 0.82081676 +0.j], - [-0.52583724 +0.j, -0.045459956+0.j, -0.78785884 +0.j, - -0.07922471 +0.j], - [-0.5868671 +0.j, 0.3462263 +0.j, 0.36026272 +0.j, - -0.2207891 +0.j]], dtype=complex64), array([[-0.11417642+0.j, -0.73277813+0.j, 0.16960056+0.j, - -0.5435681 +0.j], - [-0.33000448+0.j, -0.28974825+0.j, 0.16204938+0.j, - 0.67456985+0.j], - [-0.54583275+0.j, 0.15328142+0.j, -0.8329006 +0.j, - 0.28156415+0.j], - [-0.761661 +0.j, 0.5963111 +0.j, 0.5012507 +0.j, - 
-0.41256607+0.j]], dtype=complex64)), - mlir_module_text=r""" -module @jit_func attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main() -> (tensor<4xcomplex> {jax.result_info = "[0]"}, tensor<4x4xcomplex> {jax.result_info = "[1]"}, tensor<4x4xcomplex> {jax.result_info = "[2]"}) { - %0 = stablehlo.iota dim = 0 : tensor<16xf32> loc(#loc3) - %1 = stablehlo.reshape %0 : (tensor<16xf32>) -> tensor<4x4xf32> loc(#loc4) - %2 = stablehlo.constant dense<1> : tensor loc(#loc5) - %3 = stablehlo.constant dense<4> : tensor loc(#loc5) - %4 = stablehlo.constant dense<86> : tensor loc(#loc5) - %5 = stablehlo.constant dense<86> : tensor loc(#loc5) - %6:8 = stablehlo.custom_call @lapack_sgeev(%2, %3, %4, %5, %1) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<[0, 1]> : tensor<2xindex>, dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>, dense<[0, 1]> : tensor<2xindex>, dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>]} : (tensor, tensor, tensor, tensor, tensor<4x4xf32>) -> (tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xcomplex>, tensor<4x4xcomplex>, tensor) loc(#loc5) - %7 = stablehlo.complex %6#3, %6#4 : tensor<4xcomplex> loc(#loc5) - %8 = stablehlo.constant dense<0> : tensor loc(#loc5) - %9 = stablehlo.broadcast_in_dim %8, dims = [] : (tensor) -> tensor loc(#loc5) - %10 = stablehlo.compare EQ, %6#7, %9, SIGNED : (tensor, tensor) -> tensor loc(#loc5) - %11 = stablehlo.broadcast_in_dim %10, dims = [] : (tensor) -> tensor<1xi1> loc(#loc5) - %12 = stablehlo.constant dense<(0x7FC00000,0x7FC00000)> : tensor> loc(#loc5) - %13 = stablehlo.broadcast_in_dim %12, dims = [] : (tensor>) -> tensor<4xcomplex> loc(#loc5) - %14 = stablehlo.broadcast_in_dim %11, dims = [0] : (tensor<1xi1>) -> tensor<4xi1> loc(#loc5) - %15 = stablehlo.select %14, %7, %13 : tensor<4xi1>, tensor<4xcomplex> loc(#loc5) - %16 = stablehlo.broadcast_in_dim %10, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc5) - %17 = stablehlo.constant dense<(0x7FC00000,0x7FC00000)> : tensor> loc(#loc5) - %18 = stablehlo.broadcast_in_dim %17, dims = [] : (tensor>) -> tensor<4x4xcomplex> loc(#loc5) - %19 = stablehlo.broadcast_in_dim %16, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc5) - %20 = stablehlo.select %19, %6#5, %18 : tensor<4x4xi1>, tensor<4x4xcomplex> loc(#loc5) - %21 = stablehlo.broadcast_in_dim %10, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc5) - %22 = stablehlo.constant dense<(0x7FC00000,0x7FC00000)> : tensor> loc(#loc5) - %23 = stablehlo.broadcast_in_dim %22, dims = [] : (tensor>) -> tensor<4x4xcomplex> loc(#loc5) - %24 = stablehlo.broadcast_in_dim %21, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc5) - %25 = stablehlo.select %24, %6#6, %23 : tensor<4x4xi1>, tensor<4x4xcomplex> loc(#loc5) - return %15, %20, %25 : tensor<4xcomplex>, tensor<4x4xcomplex>, tensor<4x4xcomplex> loc(#loc) - } loc(#loc) -} loc(#loc) -#loc = loc(unknown) -#loc1 = loc("/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py":496:0) -#loc2 = loc("/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py":497:0) -#loc3 = loc("jit(func)/jit(main)/iota[dtype=float32 shape=(16,) dimension=0]"(#loc1)) -#loc4 = loc("jit(func)/jit(main)/reshape[new_sizes=(4, 4) dimensions=None]"(#loc1)) -#loc5 
= loc("jit(func)/jit(main)/eig[compute_left_eigenvectors=True compute_right_eigenvectors=True]"(#loc2)) -""", - mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x01%\x05\x01\x03\x01\x03\x05\x03\x15\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x03\xe7\x9b9\x01[\x0f\x13\x0b\x07\x0b\x13\x0f\x0b\x17\x0b\x13\x13#\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x13\x0b\x0f\x0b\x0f\x0b\x13\x0b\x17\x13K\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x1b\x0b\x0b\x13\x03AO\x0f\x0b\x0b/\x0b\x17\x13\x0b\x13\x0b\x13\x0b\x0b\x0b\x0f\x1f\x1f\x13\x0b\x0b\x0b\x0b\x1f+\x1f\x0f\x0b\x0b//O\x01\x03\x0f\x037\x17\x0f\x07\x13\x07\x07\x17\x0f\x0b\x0f\x07\x13\x17\x17\x1b\x13\x07\x07\x13\x13\x13\x13\x0f\x13\x13\x13\x13\x02v\x06\x1d9;\x03\x03\t\x8f\x05\x1b\x1f\x05\x1d\x03\x03\x05\x95\x11\x01\x05\x05\x1f\x17\x13\xc2\x07\x01\x05!\x03\x03\x05\x7f\x03\x03\t\x99\x03\x07\x1b\r\x1d\r\x0f\x1f\x05#\x05%\x05'\x03\x0b#_%e'g\x0fu)w\x05)\x05+\x05-\x05/\x03\x03-y\x051\x1d1\x11\x053\x1d5\x11\x055\x03\x03\x05{\x057\x17\x13\xc6\x07\x01\x03\x03\x05}\x03\x11A\x81C\x83E\x85G_I\x87K\x89M_O\x8b\x059\x05;\x05=\x05?\x05A\x05C\x05E\x05G\x03\x03\x05\x8d\x03\x05U\x91W\x93\x05I\x05K\x03\x03\t\x97\x1f)!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f'\x01\x03\x01\x1dM\x1f+\x11\x00\x00\x00\x00\x00\x00\x00\x00#\x1f\x03\x07imq\r\x03ak\x1dO\r\x03ao\x1dQ\r\x03as\x1dS\x1dU\x1dW\x13\r\x01\x1f\x05\t\x01\x00\x00\x00\x1f\x05\t\x04\x00\x00\x00\x1f\x15\x03V\x0b\x05\x1dY\x1d[\x05\x01\x03\x0b]]]][\x03\x11[[[cc[[]\x1f\x05\t\x00\x00\x00\x00\x1f-\x01\t\x07\x07\x01\x1f\x11\x11\x00\x00\xc0\x7f\x00\x00\xc0\x7f\x1f5\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f7!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\x02\x02)\x05\x11\x11\x13)\x01#\x01)\x03\x11\x13\t\x1d)\x05\x11\x11\x0b)\x01\x13\x03\x0b)\x01%\x13)\x03\x11\x0b)\x05\x05\x05\x07)\x05\x11\x11\x07\x11\x01\x07\t\x03\x03)\x03A\x0b\x1b!)\x03\x01\x17)\x03\t\x17)\x03\x05\x17)\x03\x01\r)\x01\x07)\x03\x05\x07)\x03\x11\x07)\x03\x05\r)\x03\t\r\x04\x92\x03\x05\x01\x11\x07\x19\x07\x03\x01\x05\t\x11\x07!\x05\x03Cm\x0b\x03/+\x03!\r\x063\x03\x0f\x03\x01\x05\x03\x017\x03\x05\x05\x03\x01=\x03\x05\x05\x03\x01\x15\x03\x15\x05\x03\x01\x15\x03\x15\x0f\x07\x01?\x11\x0f\x0f\x0f\x19\x19\x03\x03\x05\x0b\x05\x07\t\x0b\x03\x11\x06\x01\x03\t\x05\x13\x15\x05\x03\x01Q\x03\x05\x03\x07\x01\x03\x03\x05\x03\x1f\x13\x07\x01S\x03/\x05\x1b!\x03\x07\x01\x03\x031\x03#\x05\x03\x01\x0b\x03\x11\x03\x07\x01\x03\x03\t\x03'\x03\x07\x01Y\x033\x03%\x07\x06\x01\x03\t\x07+\x1d)\x03\x07\x01\x03\x03\x1b\x03#\x05\x03\x01\x0b\x03\x11\x03\x07\x01\x03\x03\x03\x031\x03\x07\x01\x17\x03\x1d\x03/\x07\x06\x01\x03\x03\x075\x173\x03\x07\x01\x03\x03\x1b\x03#\x05\x03\x01\x0b\x03\x11\x03\x07\x01\x03\x03\x03\x03;\x03\x07\x01\x17\x03\x1d\x039\x07\x06\x01\x03\x03\x07?\x19=\x15\x04\x07\x07-7A\x06\x03\x01\x05\x01\x00&\r]\x1b\x03\x0f\x0b\t\t\t!+\x1b\x1f/!!)#\x1f\x19\xb1}\x81\x1f\x1f\x15\x1d\x15\x13%)\x97\x13+\r\x15\x17\x17\x1f\x17\x11\x11\x15\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00select_v1\x00func_v1\x00iota_v1\x00reshape_v1\x00custom_call_v1\x00complex_v1\x00compare_v1\x00return_v1\x00value\x00broadcast_dimensions\x00sym_name\x00/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00jit(func)/jit(main)/iota[dtype=float32 shape=(16,) dimension=0]\x00jit(func)/jit(main)/reshape[new_sizes=(4, 4) dimensions=None]\x00jit(func)/jit(main)/eig[compute_left_eigenvectors=True 
compute_right_eigenvectors=True]\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00compare_type\x00comparison_direction\x00jax.result_info\x00[0]\x00[1]\x00[2]\x00main\x00public\x00\x00lapack_sgeev\x00", - xla_call_module_version=6, -) # End paste - - -# Pasted from the test output (see back_compat_test.py module docstring) -data_2023_06_19["f64"] = dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['lapack_dgeev'], - serialized_date=datetime.date(2023, 6, 19), - inputs=(), - expected_outputs=(array([ 3.2464249196572972e+01+0.j, -2.4642491965729802e+00+0.j, - -1.5210037805054253e-15+0.j, 1.2568096307462507e-16+0.j]), array([[-0.4037774907686232 +0.j, 0.8288327563197505 +0.j, - 0.5454962288885842 +0.j, -0.2420483778598153 +0.j], - [-0.46480737115848986 +0.j, 0.43714638836388725 +0.j, - -0.7640998541831632 +0.j, -0.04349021275982002 +0.j], - [-0.5258372515483576 +0.j, 0.045460020408024715+0.j, - -0.10828897829942748 +0.j, 0.8131255590990858 +0.j], - [-0.5868671319382249 +0.j, -0.3462263475478384 +0.j, - 0.32689260359400607 +0.j, -0.5275869684794504 +0.j]]), array([[-0.11417645138733863+0.j, 0.7327780959803554 +0.j, - 0.49133754464261303+0.j, -0.04933420991901029+0.j], - [-0.33000459866554765+0.j, 0.28974835239692637+0.j, - -0.8355289351028521 +0.j, -0.3408099365295394 +0.j], - [-0.545832745943757 +0.j, -0.1532813911865017 +0.j, - 0.1970452362778633 +0.j, 0.8296225028161098 +0.j], - [-0.7616608932219663 +0.j, -0.5963111347699308 +0.j, - 0.14714615418237506+0.j, -0.43947835636755994+0.j]])), - mlir_module_text=r""" -module @jit_func attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main() -> (tensor<4xcomplex> {jax.result_info = "[0]"}, tensor<4x4xcomplex> {jax.result_info = "[1]"}, tensor<4x4xcomplex> {jax.result_info = "[2]"}) { - %0 = stablehlo.iota dim = 0 : tensor<16xf64> loc(#loc3) - %1 = stablehlo.reshape %0 : (tensor<16xf64>) -> tensor<4x4xf64> loc(#loc4) - %2 = stablehlo.constant dense<1> : tensor loc(#loc5) - %3 = stablehlo.constant dense<4> : tensor loc(#loc5) - %4 = stablehlo.constant dense<86> : tensor loc(#loc5) - %5 = stablehlo.constant dense<86> : tensor loc(#loc5) - %6:8 = stablehlo.custom_call @lapack_dgeev(%2, %3, %4, %5, %1) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<[0, 1]> : tensor<2xindex>, dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>, dense<[0, 1]> : tensor<2xindex>, dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>]} : (tensor, tensor, tensor, tensor, tensor<4x4xf64>) -> (tensor<4x4xf64>, tensor<4x4xf64>, tensor<4x4xf64>, tensor<4xf64>, tensor<4xf64>, tensor<4x4xcomplex>, tensor<4x4xcomplex>, tensor) loc(#loc5) - %7 = stablehlo.complex %6#3, %6#4 : tensor<4xcomplex> loc(#loc5) - %8 = stablehlo.constant dense<0> : tensor loc(#loc5) - %9 = stablehlo.broadcast_in_dim %8, dims = [] : (tensor) -> tensor loc(#loc5) - %10 = stablehlo.compare EQ, %6#7, %9, SIGNED : (tensor, tensor) -> tensor loc(#loc5) - %11 = stablehlo.broadcast_in_dim %10, dims = [] : (tensor) -> tensor<1xi1> loc(#loc5) - %12 = stablehlo.constant dense<(0x7FF8000000000000,0x7FF8000000000000)> : tensor> loc(#loc5) - %13 = stablehlo.broadcast_in_dim %12, dims = [] : (tensor>) -> 
tensor<4xcomplex> loc(#loc5) - %14 = stablehlo.broadcast_in_dim %11, dims = [0] : (tensor<1xi1>) -> tensor<4xi1> loc(#loc5) - %15 = stablehlo.select %14, %7, %13 : tensor<4xi1>, tensor<4xcomplex> loc(#loc5) - %16 = stablehlo.broadcast_in_dim %10, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc5) - %17 = stablehlo.constant dense<(0x7FF8000000000000,0x7FF8000000000000)> : tensor> loc(#loc5) - %18 = stablehlo.broadcast_in_dim %17, dims = [] : (tensor>) -> tensor<4x4xcomplex> loc(#loc5) - %19 = stablehlo.broadcast_in_dim %16, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc5) - %20 = stablehlo.select %19, %6#5, %18 : tensor<4x4xi1>, tensor<4x4xcomplex> loc(#loc5) - %21 = stablehlo.broadcast_in_dim %10, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc5) - %22 = stablehlo.constant dense<(0x7FF8000000000000,0x7FF8000000000000)> : tensor> loc(#loc5) - %23 = stablehlo.broadcast_in_dim %22, dims = [] : (tensor>) -> tensor<4x4xcomplex> loc(#loc5) - %24 = stablehlo.broadcast_in_dim %21, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc5) - %25 = stablehlo.select %24, %6#6, %23 : tensor<4x4xi1>, tensor<4x4xcomplex> loc(#loc5) - return %15, %20, %25 : tensor<4xcomplex>, tensor<4x4xcomplex>, tensor<4x4xcomplex> loc(#loc) - } loc(#loc) -} loc(#loc) -#loc = loc(unknown) -#loc1 = loc("/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py":496:0) -#loc2 = loc("/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py":497:0) -#loc3 = loc("jit(func)/jit(main)/iota[dtype=float64 shape=(16,) dimension=0]"(#loc1)) -#loc4 = loc("jit(func)/jit(main)/reshape[new_sizes=(4, 4) dimensions=None]"(#loc1)) -#loc5 = loc("jit(func)/jit(main)/eig[compute_left_eigenvectors=True compute_right_eigenvectors=True]"(#loc2)) -""", - 
mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x01%\x05\x01\x03\x01\x03\x05\x03\x15\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x03\xe7\x9b9\x01[\x0f\x13\x0b\x07\x0b\x13\x0f\x0b\x17\x0b\x13\x13#\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x13\x0b\x0f\x0b\x0f\x0b\x13\x0b\x17\x13K\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x1b\x0b\x0b\x13\x03AO\x0f\x0b\x0b/\x0b\x17\x13\x0b\x13\x0b\x13\x0b\x0b\x0b\x0f\x1f\x1f\x13\x0b\x0b\x0b\x0b\x1f+\x1f\x0f\x0b\x0bO/O\x01\x03\x0f\x037\x17\x0f\x07\x13\x07\x07\x17\x0f\x0b\x0f\x07\x13\x17\x17\x1b\x13\x07\x07\x13\x13\x13\x13\x0f\x13\x13\x13\x13\x02\x96\x06\x1d9;\x03\x03\t\x8f\x05\x1b\x1f\x05\x1d\x03\x03\x05\x95\x11\x01\x05\x05\x1f\x17\x13\xc2\x07\x01\x05!\x03\x03\x05\x7f\x03\x03\t\x99\x03\x07\x1b\r\x1d\r\x0f\x1f\x05#\x05%\x05'\x03\x0b#_%e'g\x0fu)w\x05)\x05+\x05-\x05/\x03\x03-y\x051\x1d1\x11\x053\x1d5\x11\x055\x03\x03\x05{\x057\x17\x13\xc6\x07\x01\x03\x03\x05}\x03\x11A\x81C\x83E\x85G_I\x87K\x89M_O\x8b\x059\x05;\x05=\x05?\x05A\x05C\x05E\x05G\x03\x03\x05\x8d\x03\x05U\x91W\x93\x05I\x05K\x03\x03\t\x97\x1f)!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f'\x01\x03\x01\x1dM\x1f+\x11\x00\x00\x00\x00\x00\x00\x00\x00#\x1f\x03\x07imq\r\x03ak\x1dO\r\x03ao\x1dQ\r\x03as\x1dS\x1dU\x1dW\x13\r\x01\x1f\x05\t\x01\x00\x00\x00\x1f\x05\t\x04\x00\x00\x00\x1f\x15\x03V\x0b\x05\x1dY\x1d[\x05\x01\x03\x0b]]]][\x03\x11[[[cc[[]\x1f\x05\t\x00\x00\x00\x00\x1f-\x01\t\x07\x07\x01\x1f\x11!\x00\x00\x00\x00\x00\x00\xf8\x7f\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f5\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f7!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\x02\x02)\x05\x11\x11\x13)\x01#\x01)\x03\x11\x13\x0b\x1d)\x05\x11\x11\x0b)\x01\x13\x03\x0b)\x01%\x13)\x03\x11\x0b)\x05\x05\x05\x07)\x05\x11\x11\x07\x11\x01\x07\t\x03\x03)\x03A\x0b\x1b!)\x03\x01\x17)\x03\t\x17)\x03\x05\x17)\x03\x01\r)\x01\x07)\x03\x05\x07)\x03\x11\x07)\x03\x05\r)\x03\t\r\x04\x92\x03\x05\x01\x11\x07\x19\x07\x03\x01\x05\t\x11\x07!\x05\x03Cm\x0b\x03/+\x03!\r\x063\x03\x0f\x03\x01\x05\x03\x017\x03\x05\x05\x03\x01=\x03\x05\x05\x03\x01\x15\x03\x15\x05\x03\x01\x15\x03\x15\x0f\x07\x01?\x11\x0f\x0f\x0f\x19\x19\x03\x03\x05\x0b\x05\x07\t\x0b\x03\x11\x06\x01\x03\t\x05\x13\x15\x05\x03\x01Q\x03\x05\x03\x07\x01\x03\x03\x05\x03\x1f\x13\x07\x01S\x03/\x05\x1b!\x03\x07\x01\x03\x031\x03#\x05\x03\x01\x0b\x03\x11\x03\x07\x01\x03\x03\t\x03'\x03\x07\x01Y\x033\x03%\x07\x06\x01\x03\t\x07+\x1d)\x03\x07\x01\x03\x03\x1b\x03#\x05\x03\x01\x0b\x03\x11\x03\x07\x01\x03\x03\x03\x031\x03\x07\x01\x17\x03\x1d\x03/\x07\x06\x01\x03\x03\x075\x173\x03\x07\x01\x03\x03\x1b\x03#\x05\x03\x01\x0b\x03\x11\x03\x07\x01\x03\x03\x03\x03;\x03\x07\x01\x17\x03\x1d\x039\x07\x06\x01\x03\x03\x07?\x19=\x15\x04\x07\x07-7A\x06\x03\x01\x05\x01\x00&\r]\x1b\x03\x0f\x0b\t\t\t!+\x1b\x1f/!!)#\x1f\x19\xb1}\x81\x1f\x1f\x15\x1d\x15\x13%)\x97\x13+\r\x15\x17\x17\x1f\x17\x11\x11\x15\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00select_v1\x00func_v1\x00iota_v1\x00reshape_v1\x00custom_call_v1\x00complex_v1\x00compare_v1\x00return_v1\x00value\x00broadcast_dimensions\x00sym_name\x00/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00jit(func)/jit(main)/iota[dtype=float64 shape=(16,) dimension=0]\x00jit(func)/jit(main)/reshape[new_sizes=(4, 4) dimensions=None]\x00jit(func)/jit(main)/eig[compute_left_eigenvectors=True 
compute_right_eigenvectors=True]\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00compare_type\x00comparison_direction\x00jax.result_info\x00[0]\x00[1]\x00[2]\x00main\x00public\x00\x00lapack_dgeev\x00", - xla_call_module_version=6, -) # End paste - - -# Pasted from the test output (see back_compat_test.py module docstring) -data_2023_06_19["c64"] = dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['lapack_cgeev'], - serialized_date=datetime.date(2023, 6, 19), - inputs=(), - expected_outputs=(array([ 3.2464237e+01+0.j, -2.4642489e+00+0.j, -5.7737714e-07+0.j, - 1.4719126e-07+0.j], dtype=complex64), array([[ 0.4037776 +0.j, 0.8288327 +0.j, -0.53126234 -0.j, - 0.052026853-0.j], - [ 0.46480742 +0.j, 0.43714646 -0.j, 0.80768156 +0.j, - -0.47577178 -0.j], - [ 0.52583724 +0.j, 0.045459922-0.j, -0.021575088-0.j, - 0.79546237 +0.j], - [ 0.5868671 +0.j, -0.3462263 -0.j, -0.25484383 -0.j, - -0.3717177 -0.j]], dtype=complex64), array([[ 0.114176475+0.j, 0.7327782 +0.j, -0.5452461 -0.j, - -0.13326685 -0.j], - [ 0.3300045 +0.j, 0.28974816 -0.j, 0.68821603 +0.j, - -0.2182906 -0.j], - [ 0.5458328 +0.j, -0.1532814 -0.j, 0.25930583 -0.j, - 0.8363818 +0.j], - [ 0.76166093 +0.j, -0.5963111 -0.j, -0.40227592 -0.j, - -0.4848244 -0.j]], dtype=complex64)), - mlir_module_text=r""" -module @jit_func attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main() -> (tensor<4xcomplex> {jax.result_info = "[0]"}, tensor<4x4xcomplex> {jax.result_info = "[1]"}, tensor<4x4xcomplex> {jax.result_info = "[2]"}) { - %0 = stablehlo.iota dim = 0 : tensor<16xcomplex> loc(#loc3) - %1 = stablehlo.reshape %0 : (tensor<16xcomplex>) -> tensor<4x4xcomplex> loc(#loc4) - %2 = stablehlo.constant dense<1> : tensor loc(#loc5) - %3 = stablehlo.constant dense<4> : tensor loc(#loc5) - %4 = stablehlo.constant dense<86> : tensor loc(#loc5) - %5 = stablehlo.constant dense<86> : tensor loc(#loc5) - %6:6 = stablehlo.custom_call @lapack_cgeev(%2, %3, %4, %5, %1) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>, dense<[0, 1]> : tensor<2xindex>, dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>]} : (tensor, tensor, tensor, tensor, tensor<4x4xcomplex>) -> (tensor<4x4xcomplex>, tensor<8xf32>, tensor<4xcomplex>, tensor<4x4xcomplex>, tensor<4x4xcomplex>, tensor) loc(#loc5) - %7 = stablehlo.constant dense<0> : tensor loc(#loc5) - %8 = stablehlo.broadcast_in_dim %7, dims = [] : (tensor) -> tensor loc(#loc5) - %9 = stablehlo.compare EQ, %6#5, %8, SIGNED : (tensor, tensor) -> tensor loc(#loc5) - %10 = stablehlo.broadcast_in_dim %9, dims = [] : (tensor) -> tensor<1xi1> loc(#loc5) - %11 = stablehlo.constant dense<(0x7FC00000,0x7FC00000)> : tensor> loc(#loc5) - %12 = stablehlo.broadcast_in_dim %11, dims = [] : (tensor>) -> tensor<4xcomplex> loc(#loc5) - %13 = stablehlo.broadcast_in_dim %10, dims = [0] : (tensor<1xi1>) -> tensor<4xi1> loc(#loc5) - %14 = stablehlo.select %13, %6#2, %12 : tensor<4xi1>, tensor<4xcomplex> loc(#loc5) - %15 = stablehlo.broadcast_in_dim %9, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc5) - %16 = stablehlo.constant dense<(0x7FC00000,0x7FC00000)> : tensor> loc(#loc5) - %17 = stablehlo.broadcast_in_dim %16, dims = [] : 
(tensor>) -> tensor<4x4xcomplex> loc(#loc5) - %18 = stablehlo.broadcast_in_dim %15, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc5) - %19 = stablehlo.select %18, %6#3, %17 : tensor<4x4xi1>, tensor<4x4xcomplex> loc(#loc5) - %20 = stablehlo.broadcast_in_dim %9, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc5) - %21 = stablehlo.constant dense<(0x7FC00000,0x7FC00000)> : tensor> loc(#loc5) - %22 = stablehlo.broadcast_in_dim %21, dims = [] : (tensor>) -> tensor<4x4xcomplex> loc(#loc5) - %23 = stablehlo.broadcast_in_dim %20, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc5) - %24 = stablehlo.select %23, %6#4, %22 : tensor<4x4xi1>, tensor<4x4xcomplex> loc(#loc5) - return %14, %19, %24 : tensor<4xcomplex>, tensor<4x4xcomplex>, tensor<4x4xcomplex> loc(#loc) - } loc(#loc) -} loc(#loc) -#loc = loc(unknown) -#loc1 = loc("/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py":496:0) -#loc2 = loc("/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py":497:0) -#loc3 = loc("jit(func)/jit(main)/iota[dtype=complex64 shape=(16,) dimension=0]"(#loc1)) -#loc4 = loc("jit(func)/jit(main)/reshape[new_sizes=(4, 4) dimensions=None]"(#loc1)) -#loc5 = loc("jit(func)/jit(main)/eig[compute_left_eigenvectors=True compute_right_eigenvectors=True]"(#loc2)) -""", - mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x01#\x05\x01\x03\x01\x03\x05\x03\x13\x07\t\x0b\r\x0f\x11\x13\x15\x17\x03\xe5\x9b7\x01[\x0f\x13\x0b\x07\x0b\x13\x0f\x0b\x17\x0b\x13\x13#\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x13\x0b\x0f\x0b\x0f\x0b\x13\x0b\x17\x13K\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x1b\x0b\x0b\x13\x03A\x0fO\x0b\x0b/\x0b\x17\x13\x0b\x13\x0b\x13\x0b\x0b\x0b\x0f\x1f\x1f\x13\x0b\x0b\x0b\x0b\x1f#\x1f\x0f\x0b\x0b//O\x01\x03\x0f\x035\x17\x0f\x07\x13\x0b\x07\x0f\x0f\x07\x07\x17\x17\x1b\x13\x07\x07\x13\x13\x13\x13\x13\x0f\x13\x13\x13\x13\x02Z\x06\x1d9;\x03\x03\t\x8f\x05\x19\x1f\x05\x1b\x03\x03\x05\x95\x11\x01\x05\x05\x1d\x17\x13\xc2\x07\x01\x05\x1f\x03\x03\x05\x7f\x03\x03\t\x99\x03\x07\x1b\r\x1d\r\x0f\x1f\x05!\x05#\x05%\x03\x0b#_%e'g\x0fu)w\x05'\x05)\x05+\x05-\x03\x03-y\x05/\x1d1\x11\x051\x1d5\x11\x053\x03\x03\x05{\x055\x17\x13\xc6\x07\x01\x03\x03\x05}\x03\x11A\x81C\x83E\x85G_I\x87K\x89M_O\x8b\x057\x059\x05;\x05=\x05?\x05A\x05C\x05E\x03\x03\x05\x8d\x03\x05U\x91W\x93\x05G\x05I\x03\x03\t\x97\x1f%\x01\x1f'!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x03\x01\x1dK\x1f)\x11\x00\x00\x00\x00\x00\x00\x00\x00#\x1b\x03\x07imq\r\x03ak\x1dM\r\x03ao\x1dO\r\x03as\x1dQ\x1dS\x1dU\x13\r\x01\x1f\x05\t\x01\x00\x00\x00\x1f\x05\t\x04\x00\x00\x00\x1f\x11\x03V\x0b\x05\x1dW\x1dY\x05\x01\x03\x0b[[[[]\x03\r]cc]][\x1f\x05\t\x00\x00\x00\x00\x1f+\x01\t\x07\x07\x01\x1f\x0f\x11\x00\x00\xc0\x7f\x00\x00\xc0\x7f\x1f3\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f5!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\x02\x02)\x05\x11\x11\x0b)\x01\x1f\x01)\x03\x11\x0b\x03\x15\x1d)\x01\x0b)\x01!\x13\t)\x05\x05\x05\x07)\x05\x11\x11\x07\x11\x01\x07\t\x03\x03)\x03A\x0b\x1b!)\x03!\x15)\x03\x01\x13)\x03\t\x13)\x03\x05\x13)\x03\x01\r)\x01\x07)\x03\x05\x07)\x03\x11\x07)\x03\x05\r)\x03\t\r\x04j\x03\x05\x01\x11\x07\x19\x07\x03\x01\x05\t\x11\x07!\x05\x03=i\x0b\x03/+\x03\x1d\r\x063\x03\x03\x03\x01\x05\x03\x017\x03\x05\x05\x03\x01=\x03\x05\x05\x03\x01\x15\x03\x11\x05\x03\x01\x15\x03\x11\x0f\x07\x01?\r\x03#\t\x03\x03\x05\x0b\x05\x07\t\x0b\x03\x05\x03\x01Q\x03\x05\x03\x07\x01\x03\x03\x05\x03\x19\x11\x07\x01S\x03-\x05\x17\x1b\x03\x07\x01\x03\x03/\x03\x1d\x05\x03\x01\x0b\x03\x0f\x03\x07\x01\x03\x03\t\x03!\x03\x07
\x01Y\x031\x03\x1f\x07\x06\x01\x03\t\x07%\x11#\x03\x07\x01\x03\x03\x17\x03\x1d\x05\x03\x01\x0b\x03\x0f\x03\x07\x01\x03\x03\x03\x03+\x03\x07\x01\x17\x03\x19\x03)\x07\x06\x01\x03\x03\x07/\x13-\x03\x07\x01\x03\x03\x17\x03\x1d\x05\x03\x01\x0b\x03\x0f\x03\x07\x01\x03\x03\x03\x035\x03\x07\x01\x17\x03\x19\x033\x07\x06\x01\x03\x03\x079\x157\x13\x04\x07\x07'1;\x06\x03\x01\x05\x01\x00\xfe\x0c[\x1b\x03\x0f\x0b\t\t\t!+\x1b\x1f/!!)#\x1f\x19\xb1}\x85\x1f\x1f\x15\x1d\x15\x13%)\x97\x13+\r\x15\x17\x1f\x17\x11\x11\x15\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00select_v1\x00func_v1\x00iota_v1\x00reshape_v1\x00custom_call_v1\x00compare_v1\x00return_v1\x00value\x00broadcast_dimensions\x00sym_name\x00/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00jit(func)/jit(main)/iota[dtype=complex64 shape=(16,) dimension=0]\x00jit(func)/jit(main)/reshape[new_sizes=(4, 4) dimensions=None]\x00jit(func)/jit(main)/eig[compute_left_eigenvectors=True compute_right_eigenvectors=True]\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00compare_type\x00comparison_direction\x00jax.result_info\x00[0]\x00[1]\x00[2]\x00main\x00public\x00\x00lapack_cgeev\x00", - xla_call_module_version=6, -) # End paste - - -# Pasted from the test output (see back_compat_test.py module docstring) -data_2023_06_19["c128"] = dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['lapack_zgeev'], - serialized_date=datetime.date(2023, 6, 19), - inputs=(), - expected_outputs=(array([ 3.2464249196572965e+01+0.j, -2.4642491965729807e+00+0.j, - -1.6035677295293283e-15+0.j, 1.2218554396786611e-16+0.j]), array([[ 0.40377749076862335 +0.j, 0.8288327563197505 +0.j, - -0.5457111210844892 +0.j, -0.2322136424094458 -0.j], - [ 0.46480737115848997 +0.j, 0.4371463883638875 -0.j, - 0.7625701354883243 +0.j, -0.06012408092789514 -0.j], - [ 0.5258372515483578 +0.j, 0.045460020408024694-0.j, - 0.1119930922768192 +0.j, 0.8168890890841272 +0.j], - [ 0.5868671319382247 +0.j, -0.34622634754783854 -0.j, - -0.32885210668065423 +0.j, -0.5245513657467864 -0.j]]), array([[ 0.11417645138733871+0.j, 0.7327780959803554 +0.j, - -0.49606131100796214+0.j, -0.04689746607984153-0.j], - [ 0.3300045986655476 +0.j, 0.2897483523969264 -0.j, - 0.8344969112540657 +0.j, -0.34421909950105706-0.j], - [ 0.5458327459437571 +0.j, -0.15328139118650172-0.j, - -0.18080988948424467+0.j, 0.8291305972416383 +0.j], - [ 0.7616608932219663 +0.j, -0.5963111347699308 -0.j, - -0.1576257107618584 +0.j, -0.4380140316607401 -0.j]])), - mlir_module_text=r""" -module @jit_func attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main() -> (tensor<4xcomplex> {jax.result_info = "[0]"}, tensor<4x4xcomplex> {jax.result_info = "[1]"}, tensor<4x4xcomplex> {jax.result_info = "[2]"}) { - %0 = stablehlo.iota dim = 0 : tensor<16xcomplex> loc(#loc3) - %1 = stablehlo.reshape %0 : (tensor<16xcomplex>) -> tensor<4x4xcomplex> loc(#loc4) - %2 = stablehlo.constant dense<1> : tensor loc(#loc5) - %3 = stablehlo.constant dense<4> : tensor loc(#loc5) - %4 = stablehlo.constant dense<86> : tensor loc(#loc5) - %5 = stablehlo.constant dense<86> : tensor loc(#loc5) - %6:6 = stablehlo.custom_call @lapack_zgeev(%2, %3, %4, %5, %1) {api_version = 2 : i32, operand_layouts = [dense<> : 
tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>, dense<[0, 1]> : tensor<2xindex>, dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>]} : (tensor, tensor, tensor, tensor, tensor<4x4xcomplex>) -> (tensor<4x4xcomplex>, tensor<8xf64>, tensor<4xcomplex>, tensor<4x4xcomplex>, tensor<4x4xcomplex>, tensor) loc(#loc5) - %7 = stablehlo.constant dense<0> : tensor loc(#loc5) - %8 = stablehlo.broadcast_in_dim %7, dims = [] : (tensor) -> tensor loc(#loc5) - %9 = stablehlo.compare EQ, %6#5, %8, SIGNED : (tensor, tensor) -> tensor loc(#loc5) - %10 = stablehlo.broadcast_in_dim %9, dims = [] : (tensor) -> tensor<1xi1> loc(#loc5) - %11 = stablehlo.constant dense<(0x7FF8000000000000,0x7FF8000000000000)> : tensor> loc(#loc5) - %12 = stablehlo.broadcast_in_dim %11, dims = [] : (tensor>) -> tensor<4xcomplex> loc(#loc5) - %13 = stablehlo.broadcast_in_dim %10, dims = [0] : (tensor<1xi1>) -> tensor<4xi1> loc(#loc5) - %14 = stablehlo.select %13, %6#2, %12 : tensor<4xi1>, tensor<4xcomplex> loc(#loc5) - %15 = stablehlo.broadcast_in_dim %9, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc5) - %16 = stablehlo.constant dense<(0x7FF8000000000000,0x7FF8000000000000)> : tensor> loc(#loc5) - %17 = stablehlo.broadcast_in_dim %16, dims = [] : (tensor>) -> tensor<4x4xcomplex> loc(#loc5) - %18 = stablehlo.broadcast_in_dim %15, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc5) - %19 = stablehlo.select %18, %6#3, %17 : tensor<4x4xi1>, tensor<4x4xcomplex> loc(#loc5) - %20 = stablehlo.broadcast_in_dim %9, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc5) - %21 = stablehlo.constant dense<(0x7FF8000000000000,0x7FF8000000000000)> : tensor> loc(#loc5) - %22 = stablehlo.broadcast_in_dim %21, dims = [] : (tensor>) -> tensor<4x4xcomplex> loc(#loc5) - %23 = stablehlo.broadcast_in_dim %20, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc5) - %24 = stablehlo.select %23, %6#4, %22 : tensor<4x4xi1>, tensor<4x4xcomplex> loc(#loc5) - return %14, %19, %24 : tensor<4xcomplex>, tensor<4x4xcomplex>, tensor<4x4xcomplex> loc(#loc) - } loc(#loc) -} loc(#loc) -#loc = loc(unknown) -#loc1 = loc("/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py":496:0) -#loc2 = loc("/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py":497:0) -#loc3 = loc("jit(func)/jit(main)/iota[dtype=complex128 shape=(16,) dimension=0]"(#loc1)) -#loc4 = loc("jit(func)/jit(main)/reshape[new_sizes=(4, 4) dimensions=None]"(#loc1)) -#loc5 = loc("jit(func)/jit(main)/eig[compute_left_eigenvectors=True compute_right_eigenvectors=True]"(#loc2)) -""", - 
mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x01#\x05\x01\x03\x01\x03\x05\x03\x13\x07\t\x0b\r\x0f\x11\x13\x15\x17\x03\xe5\x9b7\x01[\x0f\x13\x0b\x07\x0b\x13\x0f\x0b\x17\x0b\x13\x13#\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x13\x0b\x0f\x0b\x0f\x0b\x13\x0b\x17\x13K\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x1b\x0b\x0b\x13\x03A\x0fO\x0b\x0b/\x0b\x17\x13\x0b\x13\x0b\x13\x0b\x0b\x0b\x0f\x1f\x1f\x13\x0b\x0b\x0b\x0b\x1f#\x1f\x0f\x0b\x0bO/O\x01\x03\x0f\x035\x17\x0f\x07\x13\x0b\x07\x0f\x0f\x07\x07\x17\x17\x1b\x13\x07\x07\x13\x13\x13\x13\x13\x0f\x13\x13\x13\x13\x02z\x06\x1d9;\x03\x03\t\x8f\x05\x19\x1f\x05\x1b\x03\x03\x05\x95\x11\x01\x05\x05\x1d\x17\x13\xc2\x07\x01\x05\x1f\x03\x03\x05\x7f\x03\x03\t\x99\x03\x07\x1b\r\x1d\r\x0f\x1f\x05!\x05#\x05%\x03\x0b#_%e'g\x0fu)w\x05'\x05)\x05+\x05-\x03\x03-y\x05/\x1d1\x11\x051\x1d5\x11\x053\x03\x03\x05{\x055\x17\x13\xc6\x07\x01\x03\x03\x05}\x03\x11A\x81C\x83E\x85G_I\x87K\x89M_O\x8b\x057\x059\x05;\x05=\x05?\x05A\x05C\x05E\x03\x03\x05\x8d\x03\x05U\x91W\x93\x05G\x05I\x03\x03\t\x97\x1f%\x01\x1f'!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x03\x01\x1dK\x1f)\x11\x00\x00\x00\x00\x00\x00\x00\x00#\x1b\x03\x07imq\r\x03ak\x1dM\r\x03ao\x1dO\r\x03as\x1dQ\x1dS\x1dU\x13\r\x01\x1f\x05\t\x01\x00\x00\x00\x1f\x05\t\x04\x00\x00\x00\x1f\x11\x03V\x0b\x05\x1dW\x1dY\x05\x01\x03\x0b[[[[]\x03\r]cc]][\x1f\x05\t\x00\x00\x00\x00\x1f+\x01\t\x07\x07\x01\x1f\x0f!\x00\x00\x00\x00\x00\x00\xf8\x7f\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f3\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f5!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\x02\x02)\x05\x11\x11\x0b)\x01\x1f\x01)\x03\x11\x0b\x03\x15\x1d)\x01\x0b)\x01!\x13\x0b)\x05\x05\x05\x07)\x05\x11\x11\x07\x11\x01\x07\t\x03\x03)\x03A\x0b\x1b!)\x03!\x15)\x03\x01\x13)\x03\t\x13)\x03\x05\x13)\x03\x01\r)\x01\x07)\x03\x05\x07)\x03\x11\x07)\x03\x05\r)\x03\t\r\x04j\x03\x05\x01\x11\x07\x19\x07\x03\x01\x05\t\x11\x07!\x05\x03=i\x0b\x03/+\x03\x1d\r\x063\x03\x03\x03\x01\x05\x03\x017\x03\x05\x05\x03\x01=\x03\x05\x05\x03\x01\x15\x03\x11\x05\x03\x01\x15\x03\x11\x0f\x07\x01?\r\x03#\t\x03\x03\x05\x0b\x05\x07\t\x0b\x03\x05\x03\x01Q\x03\x05\x03\x07\x01\x03\x03\x05\x03\x19\x11\x07\x01S\x03-\x05\x17\x1b\x03\x07\x01\x03\x03/\x03\x1d\x05\x03\x01\x0b\x03\x0f\x03\x07\x01\x03\x03\t\x03!\x03\x07\x01Y\x031\x03\x1f\x07\x06\x01\x03\t\x07%\x11#\x03\x07\x01\x03\x03\x17\x03\x1d\x05\x03\x01\x0b\x03\x0f\x03\x07\x01\x03\x03\x03\x03+\x03\x07\x01\x17\x03\x19\x03)\x07\x06\x01\x03\x03\x07/\x13-\x03\x07\x01\x03\x03\x17\x03\x1d\x05\x03\x01\x0b\x03\x0f\x03\x07\x01\x03\x03\x03\x035\x03\x07\x01\x17\x03\x19\x033\x07\x06\x01\x03\x03\x079\x157\x13\x04\x07\x07'1;\x06\x03\x01\x05\x01\x00\x02\r[\x1b\x03\x0f\x0b\t\t\t!+\x1b\x1f/!!)#\x1f\x19\xb1}\x87\x1f\x1f\x15\x1d\x15\x13%)\x97\x13+\r\x15\x17\x1f\x17\x11\x11\x15\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00select_v1\x00func_v1\x00iota_v1\x00reshape_v1\x00custom_call_v1\x00compare_v1\x00return_v1\x00value\x00broadcast_dimensions\x00sym_name\x00/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00jit(func)/jit(main)/iota[dtype=complex128 shape=(16,) dimension=0]\x00jit(func)/jit(main)/reshape[new_sizes=(4, 4) dimensions=None]\x00jit(func)/jit(main)/eig[compute_left_eigenvectors=True 
compute_right_eigenvectors=True]\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00compare_type\x00comparison_direction\x00jax.result_info\x00[0]\x00[1]\x00[2]\x00main\x00public\x00\x00lapack_zgeev\x00", - xla_call_module_version=6, -) # End paste - +from numpy import array, complex64 data_2024_08_19 = {} - # Pasted from the test output (see export_back_compat_test_util.py module docstring) data_2024_08_19["c128"] = dict( testdata_version=1, diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/cpu_eigh_lapack_syev.py b/jax/_src/internal_test_util/export_back_compat_test_data/cpu_eigh_lapack_syev.py index f0696db1aeda..cd5f5c55caf9 100644 --- a/jax/_src/internal_test_util/export_back_compat_test_data/cpu_eigh_lapack_syev.py +++ b/jax/_src/internal_test_util/export_back_compat_test_data/cpu_eigh_lapack_syev.py @@ -17,376 +17,8 @@ import datetime from numpy import array, float32, complex64 -data_2023_03_17 = dict( - # Pasted from the test output (see back_compat_test.py module docstring) - f32=dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['lapack_ssyevd'], - serialized_date=datetime.date(2023, 3, 17), - inputs=(), - expected_outputs=(array([[-0.6185769 , -0.20142993 , -0.09725195 , 0.62983674 , - -0.07926044 , 0.3605001 , -0.019093221 , -0.18446997 ], - [-0.47070873 , 0.29325768 , -0.19454119 , -0.6394365 , - 0.0622955 , 0.33249345 , 0.28112718 , -0.22856665 ], - [-0.32284075 , -0.12361939 , 0.20547704 , -0.18307868 , - 0.47294614 , -0.3170349 , -0.6373532 , -0.27266347 ], - [-0.17497246 , -0.079641335 , 0.15042791 , -0.15416273 , - -0.815209 , -0.38054234 , -0.083263926 , -0.31676024 ], - [-0.027104253 , -0.26490977 , 0.32271704 , 0.08653544 , - 0.30305928 , -0.33998996 , 0.6926741 , -0.360857 ], - [ 0.12076397 , 0.43288827 , -0.64385164 , 0.2652551 , - 0.09482376 , -0.37435007 , 0.00091664493, -0.40495378 ], - [ 0.26863196 , 0.51607686 , 0.53846526 , 0.16969058 , - -0.021670295 , 0.35755336 , -0.113144726 , -0.4490505 ], - [ 0.4165004 , -0.57262254 , -0.2814425 , -0.17463988 , - -0.01698498 , 0.3613705 , -0.12186296 , -0.49314725 ]], - dtype=float32), array([-2.4598808e+01, -3.3105560e-05, -3.1002426e-05, -1.0103593e-05, - -1.0022322e-05, 4.0141886e-06, 9.5510331e-06, 2.7659882e+02], - dtype=float32)), - mlir_module_text=r""" -module @jit__lambda_ { - func.func public @main() -> (tensor<8x8xf32> {jax.result_info = "[0]"}, tensor<8xf32> {jax.result_info = "[1]"}) { - %0 = stablehlo.iota dim = 0 : tensor<64xf32> - %1 = stablehlo.reshape %0 : (tensor<64xf32>) -> tensor<8x8xf32> - %2 = stablehlo.transpose %1, dims = [1, 0] : (tensor<8x8xf32>) -> tensor<8x8xf32> - %3 = stablehlo.add %1, %2 : tensor<8x8xf32> - %4 = stablehlo.constant dense<2.000000e+00> : tensor - %5 = stablehlo.broadcast_in_dim %4, dims = [] : (tensor) -> tensor<8x8xf32> - %6 = stablehlo.divide %3, %5 : tensor<8x8xf32> - %7 = call @tril(%6) : (tensor<8x8xf32>) -> tensor<8x8xf32> - %8 = stablehlo.constant dense<1> : tensor - %9 = stablehlo.constant dense<1> : tensor - %10 = stablehlo.constant dense<8> : tensor - %11 = stablehlo.custom_call @lapack_ssyevd(%8, %9, %10, %7) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : 
tensor<0xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>]} : (tensor, tensor, tensor, tensor<8x8xf32>) -> tuple, tensor<8xf32>, tensor, tensor<177xf32>, tensor<43xi32>> - %12 = stablehlo.get_tuple_element %11[0] : (tuple, tensor<8xf32>, tensor, tensor<177xf32>, tensor<43xi32>>) -> tensor<8x8xf32> - %13 = stablehlo.get_tuple_element %11[1] : (tuple, tensor<8xf32>, tensor, tensor<177xf32>, tensor<43xi32>>) -> tensor<8xf32> - %14 = stablehlo.get_tuple_element %11[2] : (tuple, tensor<8xf32>, tensor, tensor<177xf32>, tensor<43xi32>>) -> tensor - %15 = stablehlo.get_tuple_element %11[3] : (tuple, tensor<8xf32>, tensor, tensor<177xf32>, tensor<43xi32>>) -> tensor<177xf32> - %16 = stablehlo.get_tuple_element %11[4] : (tuple, tensor<8xf32>, tensor, tensor<177xf32>, tensor<43xi32>>) -> tensor<43xi32> - %17 = stablehlo.constant dense<0> : tensor - %18 = stablehlo.broadcast_in_dim %17, dims = [] : (tensor) -> tensor - %19 = stablehlo.compare EQ, %14, %18, SIGNED : (tensor, tensor) -> tensor - %20 = stablehlo.broadcast_in_dim %19, dims = [] : (tensor) -> tensor<1x1xi1> - %21 = stablehlo.constant dense<0x7FC00000> : tensor - %22 = stablehlo.broadcast_in_dim %21, dims = [] : (tensor) -> tensor<8x8xf32> - %23 = stablehlo.broadcast_in_dim %20, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<8x8xi1> - %24 = stablehlo.select %23, %12, %22 : tensor<8x8xi1>, tensor<8x8xf32> - %25 = stablehlo.broadcast_in_dim %19, dims = [] : (tensor) -> tensor<1xi1> - %26 = stablehlo.constant dense<0x7FC00000> : tensor - %27 = stablehlo.broadcast_in_dim %26, dims = [] : (tensor) -> tensor<8xf32> - %28 = stablehlo.broadcast_in_dim %25, dims = [0] : (tensor<1xi1>) -> tensor<8xi1> - %29 = stablehlo.select %28, %13, %27 : tensor<8xi1>, tensor<8xf32> - return %24, %29 : tensor<8x8xf32>, tensor<8xf32> - } - func.func private @tril(%arg0: tensor<8x8xf32>) -> tensor<8x8xf32> { - %0 = stablehlo.iota dim = 0 : tensor<8x8xi32> - %1 = stablehlo.constant dense<0> : tensor - %2 = stablehlo.broadcast_in_dim %1, dims = [] : (tensor) -> tensor<8x8xi32> - %3 = stablehlo.add %0, %2 : tensor<8x8xi32> - %4 = stablehlo.iota dim = 1 : tensor<8x8xi32> - %5 = stablehlo.compare GE, %3, %4, SIGNED : (tensor<8x8xi32>, tensor<8x8xi32>) -> tensor<8x8xi1> - %6 = stablehlo.constant dense<0.000000e+00> : tensor - %7 = stablehlo.broadcast_in_dim %6, dims = [] : (tensor) -> tensor<8x8xf32> - %8 = stablehlo.select %5, %arg0, %7 : tensor<8x8xi1>, tensor<8x8xf32> - return %8 : tensor<8x8xf32> - } -} -""", - 
mlir_module_serialized=b"ML\xefR\x03MLIRxxx-trunk\x00\x01-\x05\x01\x05\x01\x03\x05\x03\x1d\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x1d\x1f!\x03z\x02\xf77\x01\x9b\x0f\x17\x13\x0b\x07\x0f\x0b\x0b\x0b\x0b\x17\x0b\x0b\x0b\x0b\x13\x0b\x13\x0f\x0b\x0b\x17\x0f\x13\x13\x13\x0b33\x0b\x0f\x0b\x0b\x13\x0f\x0b\x1b\x0f\x0b\x13\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x13\x0b\x0f\x0b\x0f\x0b\x13\x0b\x13\x0b\x0b\x13K\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x13\x13\x13\x13\x1b\x13\x13\x03]\x0f/\x0b\x0b\x0f\x0b\x0bO\x0b\x13\x13\x0b\x13\x0b\x0b\x0b\x0b\x0b\x0b\x0f\x1f\x0f\x0f\x0b\x1fO\x1f\x1f\x1f\x0b\x0b\x0b\x0b\x1b\x0f\x17\x1f\x0f\x0f\x0f\x0f\x0f\x0b\x1fO/\x037\x17\x0f\x07\x0f\x07\x13\x07\x07\x17\x07\x17\x13\x17\x13\x17\x17\x13\x17\x1f\x13\x13\x13\x0f\x17\x13\x13\x13\x02\n\t\x1du\x03\x17\x11\xf6\x04\x01\x03\x03\x13\xc5\x05#\x1f\x1d;\x03\x05%\x05'\x05)\x05+\x17\x11\xf2\x04\x01\x05-\x05/\x051\x053\x03\x03!\xc1\x055\x03\x03\x07\xc3\x1dA\x03\x057\x059\x17\x11\xea\x04\x01\x1do\x15\x03\x03\x07\xd1\x03\x03\x07\xf1\x03\x03\x0f5\x05;\x03\x0b\x17\x9f\x19\xab\x1b\xad\x0f\xb7\x1d\xb9\x03\x0b\x17\xa3\x19\xbd\x1b\xa3\x0f\xa5\x1d\xbf\x05=\x1d?\x03\x05?\x05A\x03\x03!\xc7\x1dG\x03\x05C\x03\x05'\xa7)\xc9\x1dM\x03\x05E\x03\x03\x07\xcb\x1dS\x03\x05G\x1dW\x03\x05I\x1d[+\x05K\x1d_+\x05M\x03\x03c\xcd\x05O\x1dg\x15\x05Q\x1dk\x15\x05S\x03\x03\x07\xcf\x05U\x03\x03s\xa5\x05W\x05Y\x03\x03\x07\xd3\x03\x11{\xd5}\xd7\x7f\xd9\x81\x9f\x83\xdb\x85\xdd\x87\xdf\x89\xe3\x05[\x05]\x05_\x05a\x05c\x05e\x05g\x05i\x03\x03\r\xe5\x03\x03\r\xe7\x03\x03\r\xe9\x03\x03\r\xeb\x03\x03\r\xed\x03\x05'\xa7)\xef\x03\x03\x13\xf3\x03\x03\x13\xf5\x1f'\x01\x1f+\x11\x00\x00\x00\x00\x00\x00\x00\x00\x03\x01\x1dk\x03\x03\xbb\x1dm\t\x07\x1f)!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00#\x1d\x03\x05\xaf\xb3\r\x03\xa1\xb1\x1do\r\x03\xa1\xb5\x1dq\x1ds\x1du\r\x01#\x1f\x1dw\x13\r\x01\x1f\x03\t\x00\x00\x00\x00\x1f!\x01\x13\r\x05\x07\x05\x1f\x07\t\x00\x00\x00\x00\x1f\x17!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x07\t\x00\x00\x00@\x1f\x03\t\x01\x00\x00\x00\x1f\x03\t\x08\x00\x00\x00\x0b\x05\x1dy\x1d{\x05\x01\x03\t\x9b\x9b\x9b\xa9\x03\x03\xe1\x15\x03\x01\r\x01\x03\x0b\xa9\x9d\x9b\x9d\x9d\x13\x05\x01\x13\x05\x05\x13\x05\t\x13\x05\r\x13\x05\x11\x07\x01\x1f\x07\t\x00\x00\xc0\x7f\x1f\x17!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f5\x11\x00\x00\x00\x00\x00\x00\x00\x00)\x05!!\t)\x01\x05\x1b)\x01\t\t)\x03!\t\x1d\x01)\x05!!\x05\x13)\x05!!\x0f)\x03\t\r)\x03\x8a\x05\t)\x03\xad\x05\x11\x01\x05\x01\x0b\x11\x03\x01\x03\x01)\x03\x01\r)\x03\x02\x02\t/\x0b\x01\x0b\x03\x19\x1b)\x03\x01\x13)\x03\t\x13)\x03\x05\x13)\x01\x0f)\x05\x05\x05\x0f)\x03\x05\x0f)\x03!\x0f)\x03\x05\r\x04:\x05\x05\x01\x11\t3\x07\x03\x01\t\r\x11\t7\x05\x03=}\t\x03Y\x1f\x03#\x15\x06]\x03\x01\x03\x01\x17\x07ea\x03\x01\x03\x03\x0f\x06i\x03\x01\x05\x03\x05\x05\x03\tm\x03\x07\x03\x07-\x05\x03\x01\x03\t\x19\x06-\x03\x01\x05\x07\x0b\x1b\x07\x0bq\x03\x01\x03\r\x05\x03\x01/\x03\x03\x05\x03\x01/\x03\x03\x05\x03\x01w\x03\x03\x1d\x07\x01y\x03%\t\x11\x13\x15\x0f\x07\x07\x01\x8b\x03\x01\x03\x17\x07\x07\x01\x8d\x03\x0b\x03\x17\x07\x07\x01\x8f\x03\x03\x03\x17\x07\x07\x01\x91\x03\x19\x03\x17\x07\x07\x01\x93\x03\x1b\x03\x17\x05\x03\x01#\x03\x03\x03\x07\x01\x05\x03\x03\x03#\x11\x07\x01\x95\x03-\x05\x1d%\x03\x07\x01\x05\x03/\x03'\x05\x03\x011\x03\x07\x03\x07\x01\x05\x03\x01\x03+\x03\x07\x01\x97\x03\x15\x03)\x0b\x06\x01\x03\x01\x07/\x19-\x03\x07\x01\x05\x031\x03'\x05\x03\x011\x03\x07\x03\x07\x01\x05\x03\x0b\x035\x03\x07\x01\x99\x033\x033\x0b\x06\x01\x03\x0b\x079\x1b7\x13\x04\t\x051;\r\x1
1\x0b9\x05\x03\x15+\x03\x01\t\t\x03=\x1f\x03\x11\x05\x03\x0b#\x03\x03\x03\x07%\x05\x03\x11\x03\x05\x0f\x06%\x03\x11\x05\x03\x07\t\x03EC\x03\x11\x11\x07KI\x03\x15\x05\t\x0b\x05\x03\x0bO\x03\x07\x03\x07Q\x05\x03\x01\x03\x0f\x0b\x06U\x03\x01\x07\r\x01\x11\x13\x04\x0b\x03\x13\x06\x03\x01\x05\x01\x00\xb2\x19}\x1d\x03\x11\x0f\x0b\t\t\x0b!\x1f/!!)#\x1f\x19\x7f\x0f99m\x19\x85\x89W\xb3K\x9bM\x9b\x96\x04\x1b+\x1b\x1f\x1f\x15\x1d\x15+\x83\x13\r\r\x1f\x11\x15\x1b\x17\x15\x17\x0f\x11\x15\x11+\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00get_tuple_element_v1\x00iota_v1\x00select_v1\x00func_v1\x00add_v1\x00compare_v1\x00return_v1\x00reshape_v1\x00transpose_v1\x00divide_v1\x00call_v1\x00custom_call_v1\x00value\x00index\x00sym_name\x00third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py\x00broadcast_dimensions\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00compare_type\x00comparison_direction\x00jit__lambda_\x00jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) resource_env=None donated_invars=(False,) name=tril in_positional_semantics=(<_PositionalSemantics.GLOBAL: 1>,) out_positional_semantics=_PositionalSemantics.GLOBAL keep_unused=False inline=False]\x00jit()/jit(main)/jit(tril)/iota[dtype=int32 shape=(8, 8) dimension=0]\x00jit()/jit(main)/jit(tril)/add\x00jit()/jit(main)/jit(tril)/iota[dtype=int32 shape=(8, 8) dimension=1]\x00jit()/jit(main)/jit(tril)/ge\x00jit()/jit(main)/jit(tril)/broadcast_in_dim[shape=(8, 8) broadcast_dimensions=()]\x00jit()/jit(main)/jit(tril)/select_n\x00jit()/jit(main)/iota[dtype=float32 shape=(64,) dimension=0]\x00jit()/jit(main)/reshape[new_sizes=(8, 8) dimensions=None]\x00permutation\x00jit()/jit(main)/transpose[permutation=(1, 0)]\x00jit()/jit(main)/add\x00jit()/jit(main)/div\x00callee\x00jit()/jit(main)/eigh[lower=True sort_eigenvalues=True]\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00jax.result_info\x00tril\x00[0]\x00[1]\x00main\x00public\x00private\x00\x00lapack_ssyevd\x00", - xla_call_module_version=4, - ), # End paste - - # Pasted from the test output (see back_compat_test.py module docstring) - f64=dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['lapack_dsyevd'], - serialized_date=datetime.date(2023, 3, 17), - inputs=(), - expected_outputs=(array([[-6.1857700048412056e-01, 2.4081403770912022e-01, - 3.5662489253627483e-01, -6.3034019033669797e-01, - 1.0043483479985752e-16, -2.8842036081919542e-02, - 7.7164692943283169e-25, -1.8446994643771725e-01], - [-4.7070881487314614e-01, 4.7473787464450845e-01, - -4.8036836210243367e-01, 4.3802686872516400e-01, - 1.7961797619639258e-01, 8.3080980076741355e-03, - 2.1415294457221756e-01, -2.2856669794666584e-01], - [-3.2284062926217072e-01, -5.4336490915553370e-01, - 2.2181041859724990e-01, 2.9947877954402297e-01, - -3.6491813600134632e-01, 3.2867679819727436e-01, - 3.8223299448843473e-01, -2.7266344945561438e-01], - [-1.7497244365119530e-01, -8.9251550609769414e-02, - -6.3518515114898394e-02, 1.9162997359209971e-01, - -2.2087281326110139e-01, 5.9957027043505064e-02, - -8.7632498908241274e-01, -3.1676020096456303e-01], - [-2.7104258040220038e-02, -3.3772873786627672e-01, - 2.5901386593721748e-01, 1.7032650752287815e-01, - 6.7521217612940332e-01, -4.5036136532965476e-01, - -1.2279030059078447e-02, -3.6085695247351163e-01], - [ 1.2076392757075530e-01, -3.3834734096469254e-01, - 
-6.5506827461665540e-01, -5.0472498521116749e-01, - 6.9987430903492118e-02, 1.0595648906599275e-01, - 8.3443844143082022e-02, -4.0495370398246017e-01], - [ 2.6863211318173097e-01, 2.2958613191407318e-01, - 6.3952843755683941e-02, 1.8776775771084137e-02, - -5.3523731432241317e-01, -5.9199531677602002e-01, - 1.7916671834524248e-01, -4.4905045549140887e-01], - [ 4.1650029879270661e-01, 3.6355449432857079e-01, - 2.9755313100756142e-01, 1.6826270392615944e-02, - 1.9621068035557282e-01, 5.6830030587314817e-01, - 2.9607517592514246e-02, -4.9314720700035747e-01]]), array([-2.4598804776133626e+01, -4.6567755957874661e-14, - -1.9932120610662194e-14, -5.7323356091157378e-15, - -4.5459724251334835e-16, 4.0479851042511616e-14, - 9.2325194924982089e-14, 2.7659880477613365e+02])), - mlir_module_text=r""" -module @jit__lambda_ { - func.func public @main() -> (tensor<8x8xf64> {jax.result_info = "[0]"}, tensor<8xf64> {jax.result_info = "[1]"}) { - %0 = stablehlo.iota dim = 0 : tensor<64xf64> - %1 = stablehlo.reshape %0 : (tensor<64xf64>) -> tensor<8x8xf64> - %2 = stablehlo.transpose %1, dims = [1, 0] : (tensor<8x8xf64>) -> tensor<8x8xf64> - %3 = stablehlo.add %1, %2 : tensor<8x8xf64> - %4 = stablehlo.constant dense<2.000000e+00> : tensor - %5 = stablehlo.broadcast_in_dim %4, dims = [] : (tensor) -> tensor<8x8xf64> - %6 = stablehlo.divide %3, %5 : tensor<8x8xf64> - %7 = call @tril(%6) : (tensor<8x8xf64>) -> tensor<8x8xf64> - %8 = stablehlo.constant dense<1> : tensor - %9 = stablehlo.constant dense<1> : tensor - %10 = stablehlo.constant dense<8> : tensor - %11 = stablehlo.custom_call @lapack_dsyevd(%8, %9, %10, %7) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>]} : (tensor, tensor, tensor, tensor<8x8xf64>) -> tuple, tensor<8xf64>, tensor, tensor<177xf64>, tensor<43xi32>> - %12 = stablehlo.get_tuple_element %11[0] : (tuple, tensor<8xf64>, tensor, tensor<177xf64>, tensor<43xi32>>) -> tensor<8x8xf64> - %13 = stablehlo.get_tuple_element %11[1] : (tuple, tensor<8xf64>, tensor, tensor<177xf64>, tensor<43xi32>>) -> tensor<8xf64> - %14 = stablehlo.get_tuple_element %11[2] : (tuple, tensor<8xf64>, tensor, tensor<177xf64>, tensor<43xi32>>) -> tensor - %15 = stablehlo.get_tuple_element %11[3] : (tuple, tensor<8xf64>, tensor, tensor<177xf64>, tensor<43xi32>>) -> tensor<177xf64> - %16 = stablehlo.get_tuple_element %11[4] : (tuple, tensor<8xf64>, tensor, tensor<177xf64>, tensor<43xi32>>) -> tensor<43xi32> - %17 = stablehlo.constant dense<0> : tensor - %18 = stablehlo.broadcast_in_dim %17, dims = [] : (tensor) -> tensor - %19 = stablehlo.compare EQ, %14, %18, SIGNED : (tensor, tensor) -> tensor - %20 = stablehlo.broadcast_in_dim %19, dims = [] : (tensor) -> tensor<1x1xi1> - %21 = stablehlo.constant dense<0x7FF8000000000000> : tensor - %22 = stablehlo.broadcast_in_dim %21, dims = [] : (tensor) -> tensor<8x8xf64> - %23 = stablehlo.broadcast_in_dim %20, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<8x8xi1> - %24 = stablehlo.select %23, %12, %22 : tensor<8x8xi1>, tensor<8x8xf64> - %25 = stablehlo.broadcast_in_dim %19, dims = [] : (tensor) -> tensor<1xi1> - %26 = stablehlo.constant dense<0x7FF8000000000000> : tensor - %27 = stablehlo.broadcast_in_dim %26, dims = [] : (tensor) -> tensor<8xf64> - %28 = 
stablehlo.broadcast_in_dim %25, dims = [0] : (tensor<1xi1>) -> tensor<8xi1> - %29 = stablehlo.select %28, %13, %27 : tensor<8xi1>, tensor<8xf64> - return %24, %29 : tensor<8x8xf64>, tensor<8xf64> - } - func.func private @tril(%arg0: tensor<8x8xf64>) -> tensor<8x8xf64> { - %0 = stablehlo.iota dim = 0 : tensor<8x8xi32> - %1 = stablehlo.constant dense<0> : tensor - %2 = stablehlo.broadcast_in_dim %1, dims = [] : (tensor) -> tensor<8x8xi32> - %3 = stablehlo.add %0, %2 : tensor<8x8xi32> - %4 = stablehlo.iota dim = 1 : tensor<8x8xi32> - %5 = stablehlo.compare GE, %3, %4, SIGNED : (tensor<8x8xi32>, tensor<8x8xi32>) -> tensor<8x8xi1> - %6 = stablehlo.constant dense<0.000000e+00> : tensor - %7 = stablehlo.broadcast_in_dim %6, dims = [] : (tensor) -> tensor<8x8xf64> - %8 = stablehlo.select %5, %arg0, %7 : tensor<8x8xi1>, tensor<8x8xf64> - return %8 : tensor<8x8xf64> - } -} -""", - mlir_module_serialized=b"ML\xefR\x03MLIRxxx-trunk\x00\x01-\x05\x01\x05\x01\x03\x05\x03\x1d\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x1d\x1f!\x03z\x02\xf77\x01\x9b\x0f\x17\x13\x0b\x07\x0f\x0b\x0b\x0b\x0b\x17\x0b\x0b\x0b\x0b\x13\x0b\x13\x0f\x0b\x0b\x17\x0f\x13\x13\x13\x0b33\x0b\x0f\x0b\x0b\x13\x0f\x0b\x1b\x0f\x0b\x13\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x13\x0b\x0f\x0b\x0f\x0b\x13\x0b\x13\x0b\x0b\x13K\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x13\x13\x13\x13\x1b\x13\x13\x03]\x0f/\x0b\x0b\x0f\x0b\x0bO\x0b\x13\x13\x0b\x13\x0b\x0b\x0b\x0b\x0b\x0b\x0f\x1f\x0f\x0f\x0b/O/\x1f\x1f\x0b\x0b\x0b\x0b\x1b\x0f\x17\x1f\x0f\x0f\x0f\x0f\x0f\x0b/O/\x037\x17\x0f\x07\x0f\x07\x13\x07\x07\x17\x07\x17\x13\x17\x13\x17\x17\x13\x17\x1f\x13\x13\x13\x0f\x17\x13\x13\x13\x02:\t\x1du\x03\x17\x11\xf6\x04\x01\x03\x03\x13\xc5\x05#\x1f\x1d;\x03\x05%\x05'\x05)\x05+\x17\x11\xf2\x04\x01\x05-\x05/\x051\x053\x03\x03!\xc1\x055\x03\x03\x07\xc3\x1dA\x03\x057\x059\x17\x11\xea\x04\x01\x1do\x15\x03\x03\x07\xd1\x03\x03\x07\xf1\x03\x03\x0f5\x05;\x03\x0b\x17\x9f\x19\xab\x1b\xad\x0f\xb7\x1d\xb9\x03\x0b\x17\xa3\x19\xbd\x1b\xa3\x0f\xa5\x1d\xbf\x05=\x1d?\x03\x05?\x05A\x03\x03!\xc7\x1dG\x03\x05C\x03\x05'\xa7)\xc9\x1dM\x03\x05E\x03\x03\x07\xcb\x1dS\x03\x05G\x1dW\x03\x05I\x1d[+\x05K\x1d_+\x05M\x03\x03c\xcd\x05O\x1dg\x15\x05Q\x1dk\x15\x05S\x03\x03\x07\xcf\x05U\x03\x03s\xa5\x05W\x05Y\x03\x03\x07\xd3\x03\x11{\xd5}\xd7\x7f\xd9\x81\x9f\x83\xdb\x85\xdd\x87\xdf\x89\xe3\x05[\x05]\x05_\x05a\x05c\x05e\x05g\x05i\x03\x03\r\xe5\x03\x03\r\xe7\x03\x03\r\xe9\x03\x03\r\xeb\x03\x03\r\xed\x03\x05'\xa7)\xef\x03\x03\x13\xf3\x03\x03\x13\xf5\x1f'\x01\x1f+\x11\x00\x00\x00\x00\x00\x00\x00\x00\x03\x01\x1dk\x03\x03\xbb\x1dm\t\x07\x1f)!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00#\x1d\x03\x05\xaf\xb3\r\x03\xa1\xb1\x1do\r\x03\xa1\xb5\x1dq\x1ds\x1du\r\x01#\x1f\x1dw\x13\r\x01\x1f\x03\t\x00\x00\x00\x00\x1f!\x01\x13\r\x05\x07\x05\x1f\x07\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x17!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x07\x11\x00\x00\x00\x00\x00\x00\x00@\x1f\x03\t\x01\x00\x00\x00\x1f\x03\t\x08\x00\x00\x00\x0b\x05\x1dy\x1d{\x05\x01\x03\t\x9b\x9b\x9b\xa9\x03\x03\xe1\x15\x03\x01\r\x01\x03\x0b\xa9\x9d\x9b\x9d\x9d\x13\x05\x01\x13\x05\x05\x13\x05\t\x13\x05\r\x13\x05\x11\x07\x01\x1f\x07\x11\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f\x17!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f5\x11\x00\x00\x00\x00\x00\x00\x00\x00)\x05!!\t)\x01\x05\x1b)\x01\t\x0b)\x03!\t\x1d\x01)\x05!!\x05\x13)\x05!!\x0f)\x03\t\r)\x03\x8a\x05\t)\x03\xad\x05\x11\x01\x05\x01\x0b\x11\x03\x01\x03\x01)\x03\x01\r)\x03\x02\x02\t/\x0b\x01\x0b\x03\x19\x1b)\x03\x01\x13)\x03\t\x13)\x03\x05\x13)\x01\x0f)\x
05\x05\x05\x0f)\x03\x05\x0f)\x03!\x0f)\x03\x05\r\x04:\x05\x05\x01\x11\t3\x07\x03\x01\t\r\x11\t7\x05\x03=}\t\x03Y\x1f\x03#\x15\x06]\x03\x01\x03\x01\x17\x07ea\x03\x01\x03\x03\x0f\x06i\x03\x01\x05\x03\x05\x05\x03\tm\x03\x07\x03\x07-\x05\x03\x01\x03\t\x19\x06-\x03\x01\x05\x07\x0b\x1b\x07\x0bq\x03\x01\x03\r\x05\x03\x01/\x03\x03\x05\x03\x01/\x03\x03\x05\x03\x01w\x03\x03\x1d\x07\x01y\x03%\t\x11\x13\x15\x0f\x07\x07\x01\x8b\x03\x01\x03\x17\x07\x07\x01\x8d\x03\x0b\x03\x17\x07\x07\x01\x8f\x03\x03\x03\x17\x07\x07\x01\x91\x03\x19\x03\x17\x07\x07\x01\x93\x03\x1b\x03\x17\x05\x03\x01#\x03\x03\x03\x07\x01\x05\x03\x03\x03#\x11\x07\x01\x95\x03-\x05\x1d%\x03\x07\x01\x05\x03/\x03'\x05\x03\x011\x03\x07\x03\x07\x01\x05\x03\x01\x03+\x03\x07\x01\x97\x03\x15\x03)\x0b\x06\x01\x03\x01\x07/\x19-\x03\x07\x01\x05\x031\x03'\x05\x03\x011\x03\x07\x03\x07\x01\x05\x03\x0b\x035\x03\x07\x01\x99\x033\x033\x0b\x06\x01\x03\x0b\x079\x1b7\x13\x04\t\x051;\r\x11\x0b9\x05\x03\x15+\x03\x01\t\t\x03=\x1f\x03\x11\x05\x03\x0b#\x03\x03\x03\x07%\x05\x03\x11\x03\x05\x0f\x06%\x03\x11\x05\x03\x07\t\x03EC\x03\x11\x11\x07KI\x03\x15\x05\t\x0b\x05\x03\x0bO\x03\x07\x03\x07Q\x05\x03\x01\x03\x0f\x0b\x06U\x03\x01\x07\r\x01\x11\x13\x04\x0b\x03\x13\x06\x03\x01\x05\x01\x00\xb2\x19}\x1d\x03\x11\x0f\x0b\t\t\x0b!\x1f/!!)#\x1f\x19\x7f\x0f99m\x19\x85\x89W\xb3K\x9bM\x9b\x96\x04\x1b+\x1b\x1f\x1f\x15\x1d\x15+\x83\x13\r\r\x1f\x11\x15\x1b\x17\x15\x17\x0f\x11\x15\x11+\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00get_tuple_element_v1\x00iota_v1\x00select_v1\x00func_v1\x00add_v1\x00compare_v1\x00return_v1\x00reshape_v1\x00transpose_v1\x00divide_v1\x00call_v1\x00custom_call_v1\x00value\x00index\x00sym_name\x00third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py\x00broadcast_dimensions\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00compare_type\x00comparison_direction\x00jit__lambda_\x00jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) resource_env=None donated_invars=(False,) name=tril in_positional_semantics=(<_PositionalSemantics.GLOBAL: 1>,) out_positional_semantics=_PositionalSemantics.GLOBAL keep_unused=False inline=False]\x00jit()/jit(main)/jit(tril)/iota[dtype=int32 shape=(8, 8) dimension=0]\x00jit()/jit(main)/jit(tril)/add\x00jit()/jit(main)/jit(tril)/iota[dtype=int32 shape=(8, 8) dimension=1]\x00jit()/jit(main)/jit(tril)/ge\x00jit()/jit(main)/jit(tril)/broadcast_in_dim[shape=(8, 8) broadcast_dimensions=()]\x00jit()/jit(main)/jit(tril)/select_n\x00jit()/jit(main)/iota[dtype=float64 shape=(64,) dimension=0]\x00jit()/jit(main)/reshape[new_sizes=(8, 8) dimensions=None]\x00permutation\x00jit()/jit(main)/transpose[permutation=(1, 0)]\x00jit()/jit(main)/add\x00jit()/jit(main)/div\x00callee\x00jit()/jit(main)/eigh[lower=True sort_eigenvalues=True]\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00jax.result_info\x00tril\x00[0]\x00[1]\x00main\x00public\x00private\x00\x00lapack_dsyevd\x00", - xla_call_module_version=4, - ), # End paste - - # Pasted from the test output (see back_compat_test.py module docstring) - c64=dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['lapack_cheevd'], - serialized_date=datetime.date(2023, 3, 17), - inputs=(), - expected_outputs=(array([[-0.6185769 +0.j, -0.20142993 +0.j, -0.09725195 +0.j, - 0.62983674 +0.j, -0.07926044 +0.j, 0.3605001 -0.j, - -0.019093221 +0.j, -0.18446997 +0.j], - 
[-0.47070873 +0.j, 0.29325768 +0.j, -0.19454116 +0.j, - -0.6394365 +0.j, 0.06229549 +0.j, 0.33249345 +0.j, - 0.28112718 +0.j, -0.22856665 +0.j], - [-0.32284075 +0.j, -0.12361939 +0.j, 0.20547704 +0.j, - -0.18307868 +0.j, 0.47294614 +0.j, -0.3170349 +0.j, - -0.6373532 +0.j, -0.27266347 +0.j], - [-0.17497246 +0.j, -0.079641335 +0.j, 0.15042792 +0.j, - -0.15416273 +0.j, -0.815209 +0.j, -0.38054234 +0.j, - -0.083263926 +0.j, -0.31676024 +0.j], - [-0.027104257 +0.j, -0.26490977 +0.j, 0.32271704 +0.j, - 0.08653544 +0.j, 0.30305928 +0.j, -0.33998996 +0.j, - 0.6926741 +0.j, -0.360857 +0.j], - [ 0.120763965 +0.j, 0.43288827 +0.j, -0.64385164 +0.j, - 0.2652551 +0.j, 0.094823755 +0.j, -0.37435007 +0.j, - 0.00091664493+0.j, -0.40495378 +0.j], - [ 0.26863196 +0.j, 0.51607686 +0.j, 0.53846526 +0.j, - 0.16969058 +0.j, -0.0216703 +0.j, 0.35755336 +0.j, - -0.113144726 +0.j, -0.4490505 +0.j], - [ 0.4165004 +0.j, -0.57262254 +0.j, -0.28144246 +0.j, - -0.17463988 +0.j, -0.016984984 +0.j, 0.3613705 +0.j, - -0.12186296 +0.j, -0.49314725 +0.j]], dtype=complex64), array([-2.4598808e+01, -3.3105560e-05, -3.1002426e-05, -1.0103593e-05, - -1.0022322e-05, 4.0141886e-06, 9.5510331e-06, 2.7659882e+02], - dtype=float32)), - mlir_module_text=r""" -module @jit__lambda_ { - func.func public @main() -> (tensor<8x8xcomplex> {jax.result_info = "[0]"}, tensor<8xf32> {jax.result_info = "[1]"}) { - %0 = stablehlo.iota dim = 0 : tensor<64xcomplex> - %1 = stablehlo.reshape %0 : (tensor<64xcomplex>) -> tensor<8x8xcomplex> - %2 = stablehlo.transpose %1, dims = [1, 0] : (tensor<8x8xcomplex>) -> tensor<8x8xcomplex> - %3 = stablehlo.real %2 : (tensor<8x8xcomplex>) -> tensor<8x8xf32> - %4 = stablehlo.imag %2 : (tensor<8x8xcomplex>) -> tensor<8x8xf32> - %5 = stablehlo.negate %4 : tensor<8x8xf32> - %6 = stablehlo.complex %3, %5 : tensor<8x8xcomplex> - %7 = stablehlo.add %1, %6 : tensor<8x8xcomplex> - %8 = stablehlo.constant dense<(2.000000e+00,0.000000e+00)> : tensor> - %9 = stablehlo.broadcast_in_dim %8, dims = [] : (tensor>) -> tensor<8x8xcomplex> - %10 = stablehlo.divide %7, %9 : tensor<8x8xcomplex> - %11 = call @tril(%10) : (tensor<8x8xcomplex>) -> tensor<8x8xcomplex> - %12 = stablehlo.constant dense<1> : tensor - %13 = stablehlo.constant dense<1> : tensor - %14 = stablehlo.constant dense<8> : tensor - %15 = stablehlo.custom_call @lapack_cheevd(%12, %13, %14, %11) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>]} : (tensor, tensor, tensor, tensor<8x8xcomplex>) -> tuple>, tensor<8xf32>, tensor, tensor<81xcomplex>, tensor<169xf32>, tensor<43xi32>> - %16 = stablehlo.get_tuple_element %15[0] : (tuple>, tensor<8xf32>, tensor, tensor<81xcomplex>, tensor<169xf32>, tensor<43xi32>>) -> tensor<8x8xcomplex> - %17 = stablehlo.get_tuple_element %15[1] : (tuple>, tensor<8xf32>, tensor, tensor<81xcomplex>, tensor<169xf32>, tensor<43xi32>>) -> tensor<8xf32> - %18 = stablehlo.get_tuple_element %15[2] : (tuple>, tensor<8xf32>, tensor, tensor<81xcomplex>, tensor<169xf32>, tensor<43xi32>>) -> tensor - %19 = stablehlo.get_tuple_element %15[3] : (tuple>, tensor<8xf32>, tensor, tensor<81xcomplex>, tensor<169xf32>, tensor<43xi32>>) -> tensor<81xcomplex> - %20 = stablehlo.get_tuple_element %15[4] : (tuple>, 
tensor<8xf32>, tensor, tensor<81xcomplex>, tensor<169xf32>, tensor<43xi32>>) -> tensor<169xf32> - %21 = stablehlo.get_tuple_element %15[5] : (tuple>, tensor<8xf32>, tensor, tensor<81xcomplex>, tensor<169xf32>, tensor<43xi32>>) -> tensor<43xi32> - %22 = stablehlo.constant dense<0> : tensor - %23 = stablehlo.broadcast_in_dim %22, dims = [] : (tensor) -> tensor - %24 = stablehlo.compare EQ, %18, %23, SIGNED : (tensor, tensor) -> tensor - %25 = stablehlo.broadcast_in_dim %24, dims = [] : (tensor) -> tensor<1x1xi1> - %26 = stablehlo.constant dense<(0x7FC00000,0x7FC00000)> : tensor> - %27 = stablehlo.broadcast_in_dim %26, dims = [] : (tensor>) -> tensor<8x8xcomplex> - %28 = stablehlo.broadcast_in_dim %25, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<8x8xi1> - %29 = stablehlo.select %28, %16, %27 : tensor<8x8xi1>, tensor<8x8xcomplex> - %30 = stablehlo.broadcast_in_dim %24, dims = [] : (tensor) -> tensor<1xi1> - %31 = stablehlo.constant dense<0x7FC00000> : tensor - %32 = stablehlo.broadcast_in_dim %31, dims = [] : (tensor) -> tensor<8xf32> - %33 = stablehlo.broadcast_in_dim %30, dims = [0] : (tensor<1xi1>) -> tensor<8xi1> - %34 = stablehlo.select %33, %17, %32 : tensor<8xi1>, tensor<8xf32> - return %29, %34 : tensor<8x8xcomplex>, tensor<8xf32> - } - func.func private @tril(%arg0: tensor<8x8xcomplex>) -> tensor<8x8xcomplex> { - %0 = stablehlo.iota dim = 0 : tensor<8x8xi32> - %1 = stablehlo.constant dense<0> : tensor - %2 = stablehlo.broadcast_in_dim %1, dims = [] : (tensor) -> tensor<8x8xi32> - %3 = stablehlo.add %0, %2 : tensor<8x8xi32> - %4 = stablehlo.iota dim = 1 : tensor<8x8xi32> - %5 = stablehlo.compare GE, %3, %4, SIGNED : (tensor<8x8xi32>, tensor<8x8xi32>) -> tensor<8x8xi1> - %6 = stablehlo.constant dense<(0.000000e+00,0.000000e+00)> : tensor> - %7 = stablehlo.broadcast_in_dim %6, dims = [] : (tensor>) -> tensor<8x8xcomplex> - %8 = stablehlo.select %5, %arg0, %7 : tensor<8x8xi1>, tensor<8x8xcomplex> - return %8 : tensor<8x8xcomplex> - } -} -""", - 
mlir_module_serialized=b"ML\xefR\x03MLIRxxx-trunk\x00\x015\x05\x01\x05\x01\x03\x05\x03%\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x1d\x1f!#%')\x03\xc6\x02\x1e\x02?\x01\xa9\x0f\x17\x13\x0b\x17\x0b\x07\x0f\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x0b\x13\x0f\x0b\x0b\x17\x0f\x13\x13\x0b33\x0b\x0f\x0b\x0b\x13\x0f\x0b\x1b\x0f\x0b\x13\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x13\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x13\x0b\x13\x0b\x0b\x13K\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x13\x13\x13\x13\x13\x1b\x17\x03a\x0f/\x0b\x0b\x0f\x0b\x0bO\x0b\x13\x13\x0b\x13\x0b\x0b\x0b\x0b\x0b\x0b\x0f\x1f\x0f\x0f\x0b/O/\x1f\x1f\x0b\x0b\x0b\x0b\x1b\x0f\x17#\x0f\x0f\x0f\x0f\x0f\x0f\x0b/O\x1f/\x01\x07\x17\x17\x17\x03?\x17\x0f\x07\x0f\x07\x13\x07\x07\x0b\x17\x17\x07\x17\x13\x17\x17\x13\x0f\x17\x17\x13\x17#\x13\x13\x13\x0f\x17\x13\x13\x13\x02&\n\x1d\x83\x03\x17\x13\xf6\x04\x01\x03\x03\x15\xd3\x05+\x17\x13\xf2\x04\x01\x05-\x1f\x1d9\x03\x05/\x051\x053\x055\x057\x059\x05;\x03\x03!\xcf\x05=\x03\x03\x07\xd1\x1d?\x03\x05?\x05A\x17\x13\xea\x04\x01\x1d}\t\x03\x03\x07\xdf\x03\x03\x113\x05C\x03\x0b\x17\xad\x19\xb9\x1b\xbb\x11\xc5\x1d\xc7\x03\x0b\x17\xb1\x19\xcb\x1b\xb1\x11\xb3\x1d\xcd\x05E\x1d=\x03\x05G\x05I\x03\x03!\xd5\x1dE\x03\x05K\x03\x05'\xb5)\xd7\x1dK\x03\x05M\x03\x03\x07\xd9\x1dQ\x03\x05O\x1dU\x03\x05Q\x1dY+\x05S\x1d]+\x05U\x03\x03a\xdb\x05W\x1de\t\x05Y\x1di\t\x05[\x1dm\t\x05]\x1dq\t\x05_\x1du\t\x05a\x1dy\t\x05c\x03\x03\x07\xdd\x05e\x03\x03\x81\xb3\x05g\x05i\x03\x03\x07\xe1\x03\x11\x89\xe3\x8b\xe5\x8d\xe7\x8f\xad\x91\xe9\x93\xeb\x95\xed\x97\xf1\x05k\x05m\x05o\x05q\x05s\x05u\x05w\x05y\x03\x03\x0b\xf3\x03\x03\x0b\xf5\x03\x03\x0b\xf7\x03\x03\x0b\xf9\x03\x03\x0b\xfb\x03\x03\x0b\xfd\x03\x05'\xb5)\xff\x03\x03\x07\x02\x02\x1f/\x01\x1f3\x11\x00\x00\x00\x00\x00\x00\x00\x00\x03\x01\x1d{\x03\x03\xc9\x1d}\t\x07\x1f1!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00#%\x03\x05\xbd\xc1\r\x03\xaf\xbf\x1d\x7f\r\x03\xaf\xc3\x1d\x81\x1d\x83\x1d\x85\r\x01#'\x1d\x87\x13\r\x01\x1f\x03\t\x00\x00\x00\x00\x1f)\x01\x13\r\x05\x07\x05\x1f\x07\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x1b!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x07\x11\x00\x00\x00@\x00\x00\x00\x00\x1f\x03\t\x01\x00\x00\x00\x1f\x03\t\x08\x00\x00\x00\x0b\x05\x1d\x89\x1d\x8b\x05\x01\x03\t\xa9\xa9\xa9\xb7\x03\x03\xef\x15\x03\x01\r\x01\x03\r\xb7\xab\xa9\xab\xab\xab\x13\x05\x01\x13\x05\x05\x13\x05\t\x13\x05\r\x13\x05\x11\x13\x05\x15\x07\x01\x1f\x07\x11\x00\x00\xc0\x7f\x00\x00\xc0\x7f\x1f\x1b!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f#\t\x00\x00\xc0\x7f\x1f=\x11\x00\x00\x00\x00\x00\x00\x00\x00\x03\x03\x15\x06\x02\x03\x03\x07\n\x02\x03\x03\x15\x0e\x02)\x05!!\x11)\x01\x05\x1b)\x01\x11\t)\x03!\t\x1d\x01\x03\t)\x05!!\x05)\x05!!\t\x13)\x05!!\x0f)\x03\t\r)\x03\x8a\x02\x11)\x03J\x05\t)\x03\xad\x05)\x01\t\x11\x01\x05\x01\x0b\x11\x03\x01\x03\x01)\x03\x01\r)\x03\x02\x02\x11/\r\x01\x0b\x03\x1d\x1f!)\x03\x01\x17)\x03\t\x17)\x03\x05\x17)\x01\x0f)\x05\x05\x05\x0f)\x03\x05\x0f)\x03!\x0f)\x03\x05\r\x04\xda\x05\x05\x01\x11\r1\x07\x03\x01\t\r\x11\r5\x05\x03G\x91\t\x03W\x1f\x03+\x15\x06[\x03\x01\x03\x01\x17\x07c_\x03\x01\x03\x03\x19\x06g\x03\x15\x03\x05\x1b\x06k\x03\x15\x03\x05\x1d\x06o\x03\x15\x03\t\x1f\x06s\x03\x01\x05\x07\x0b\x0f\x06w\x03\x01\x05\x03\r\x05\x03\r{\x03\x07\x03\x07-\x05\x03\x01\x03\x11!\x06-\x03\x01\x05\x0f\x13#\x07\x0f\x7f\x03\x01\x03\x15\x05\x03\x01/\x03\x03\x05\x03\x01/\x03\x03\x05\x03\x01\x85\x03\x03%\x07\x01\x87\x03-\t\x19\x1b\x1d\x17\x07\x07\x01\x99\x03\x01\x03\x1f\x07\x07\x01\x9b\x03\x0b\x03\x1f\x07\x07\x01\x9d\x03\x03\x03\x1f\
x07\x07\x01\x9f\x03\x1d\x03\x1f\x07\x07\x01\xa1\x03\x1f\x03\x1f\x07\x07\x01\xa3\x03!\x03\x1f\x05\x03\x01#\x03\x03\x03\x07\x01\x05\x03\x03\x03-\x11\x07\x01\xa5\x035\x05%/\x03\x07\x01\x05\x037\x031\x05\x03\x01\xa7\x03\x07\x03\x07\x01\x05\x03\x01\x035\x03\x07\x01\x12\x02\x03\x19\x033\x0b\x06\x01\x03\x01\x079!7\x03\x07\x01\x05\x039\x031\x05\x03\x01\x16\x02\x03#\x03\x07\x01\x05\x03\x0b\x03?\x03\x07\x01\x1a\x02\x03;\x03=\x0b\x06\x01\x03\x0b\x07C#A\x13\x04\r\x05;E\r\x11\x0f7\x05\x03\x15+\x03\x01\r\t\x03;\x1f\x03\x13\x05\x03\x0f#\x03\x03\x03\x07%\x05\x03\x13\x03\x05\x0f\x06%\x03\x13\x05\x03\x07\t\x03CA\x03\x13\x11\x07IG\x03\x19\x05\t\x0b\x05\x03\x0fM\x03\x07\x03\x07O\x05\x03\x01\x03\x0f\x0b\x06S\x03\x01\x07\r\x01\x11\x13\x04\x0f\x03\x13\x06\x03\x01\x05\x01\x00F\x1c\x8d\x1d\x03\x11\x0f\x0b\t\t\x0b!\x1f/!!)#\x1f\x19\x7f\x0f99A9;;m\x19\x85\x8dW\xb3K\x9bM\x9b\x96\x04\x1b+\x1b\x1f\x1f\x15\x1d\x15+\x83\x13\r\r\x1f\x11\x15\x17\x15\x11\x11\x1b\x17\x15\x17\x0f\x11\x15\x11+\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00get_tuple_element_v1\x00iota_v1\x00select_v1\x00func_v1\x00add_v1\x00compare_v1\x00return_v1\x00reshape_v1\x00transpose_v1\x00real_v1\x00imag_v1\x00negate_v1\x00complex_v1\x00divide_v1\x00call_v1\x00custom_call_v1\x00value\x00index\x00sym_name\x00third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py\x00broadcast_dimensions\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00compare_type\x00comparison_direction\x00jit__lambda_\x00jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) resource_env=None donated_invars=(False,) name=tril in_positional_semantics=(<_PositionalSemantics.GLOBAL: 1>,) out_positional_semantics=_PositionalSemantics.GLOBAL keep_unused=False inline=False]\x00jit()/jit(main)/jit(tril)/iota[dtype=int32 shape=(8, 8) dimension=0]\x00jit()/jit(main)/jit(tril)/add\x00jit()/jit(main)/jit(tril)/iota[dtype=int32 shape=(8, 8) dimension=1]\x00jit()/jit(main)/jit(tril)/ge\x00jit()/jit(main)/jit(tril)/broadcast_in_dim[shape=(8, 8) broadcast_dimensions=()]\x00jit()/jit(main)/jit(tril)/select_n\x00jit()/jit(main)/iota[dtype=complex64 shape=(64,) dimension=0]\x00jit()/jit(main)/reshape[new_sizes=(8, 8) dimensions=None]\x00permutation\x00jit()/jit(main)/transpose[permutation=(1, 0)]\x00jit()/jit(main)/real\x00jit()/jit(main)/imag\x00jit()/jit(main)/neg\x00jit()/jit(main)/complex\x00jit()/jit(main)/add\x00jit()/jit(main)/div\x00callee\x00jit()/jit(main)/eigh[lower=True sort_eigenvalues=True]\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00jax.result_info\x00tril\x00[0]\x00[1]\x00main\x00public\x00private\x00\x00lapack_cheevd\x00", - xla_call_module_version=4, - ), # End paste - - # Pasted from the test output (see back_compat_test.py module docstring) - c128=dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['lapack_zheevd'], - serialized_date=datetime.date(2023, 3, 17), - inputs=(), - expected_outputs=(array([[-6.1857700048412056e-01+0.j, 2.4081403770912022e-01+0.j, - 3.5662489253627483e-01+0.j, -6.3034019033669797e-01+0.j, - 1.0043483479985752e-16+0.j, -2.8842036081919542e-02+0.j, - 7.7164692943283169e-25+0.j, -1.8446994643771725e-01+0.j], - [-4.7070881487314609e-01+0.j, 4.7473787464450828e-01+0.j, - -4.8036836210243361e-01+0.j, 4.3802686872516400e-01+0.j, - 1.7961797619639255e-01+0.j, 8.3080980076741355e-03+0.j, - 2.1415294457221759e-01+0.j, 
-2.2856669794666584e-01+0.j], - [-3.2284062926217072e-01+0.j, -5.4336490915553370e-01+0.j, - 2.2181041859724987e-01+0.j, 2.9947877954402286e-01+0.j, - -3.6491813600134637e-01+0.j, 3.2867679819727436e-01+0.j, - 3.8223299448843473e-01+0.j, -2.7266344945561438e-01+0.j], - [-1.7497244365119527e-01+0.j, -8.9251550609769331e-02+0.j, - -6.3518515114898352e-02+0.j, 1.9162997359209963e-01+0.j, - -2.2087281326110142e-01+0.j, 5.9957027043505008e-02+0.j, - -8.7632498908241274e-01+0.j, -3.1676020096456303e-01+0.j], - [-2.7104258040220017e-02+0.j, -3.3772873786627688e-01+0.j, - 2.5901386593721754e-01+0.j, 1.7032650752287815e-01+0.j, - 6.7521217612940321e-01+0.j, -4.5036136532965476e-01+0.j, - -1.2279030059078447e-02+0.j, -3.6085695247351163e-01+0.j], - [ 1.2076392757075533e-01+0.j, -3.3834734096469249e-01+0.j, - -6.5506827461665529e-01+0.j, -5.0472498521116760e-01+0.j, - 6.9987430903492132e-02+0.j, 1.0595648906599270e-01+0.j, - 8.3443844143082035e-02+0.j, -4.0495370398246017e-01+0.j], - [ 2.6863211318173102e-01+0.j, 2.2958613191407312e-01+0.j, - 6.3952843755683969e-02+0.j, 1.8776775771084192e-02+0.j, - -5.3523731432241317e-01+0.j, -5.9199531677602002e-01+0.j, - 1.7916671834524250e-01+0.j, -4.4905045549140887e-01+0.j], - [ 4.1650029879270667e-01+0.j, 3.6355449432857068e-01+0.j, - 2.9755313100756148e-01+0.j, 1.6826270392616000e-02+0.j, - 1.9621068035557282e-01+0.j, 5.6830030587314817e-01+0.j, - 2.9607517592514260e-02+0.j, -4.9314720700035747e-01+0.j]]), array([-2.4598804776133626e+01, -4.6567755957874661e-14, - -1.9932120610662194e-14, -5.7323356091157378e-15, - -4.5459724251334835e-16, 4.0479851042511616e-14, - 9.2325194924982089e-14, 2.7659880477613365e+02])), - mlir_module_text=r""" -module @jit__lambda_ { - func.func public @main() -> (tensor<8x8xcomplex> {jax.result_info = "[0]"}, tensor<8xf64> {jax.result_info = "[1]"}) { - %0 = stablehlo.iota dim = 0 : tensor<64xcomplex> - %1 = stablehlo.reshape %0 : (tensor<64xcomplex>) -> tensor<8x8xcomplex> - %2 = stablehlo.transpose %1, dims = [1, 0] : (tensor<8x8xcomplex>) -> tensor<8x8xcomplex> - %3 = stablehlo.real %2 : (tensor<8x8xcomplex>) -> tensor<8x8xf64> - %4 = stablehlo.imag %2 : (tensor<8x8xcomplex>) -> tensor<8x8xf64> - %5 = stablehlo.negate %4 : tensor<8x8xf64> - %6 = stablehlo.complex %3, %5 : tensor<8x8xcomplex> - %7 = stablehlo.add %1, %6 : tensor<8x8xcomplex> - %8 = stablehlo.constant dense<(2.000000e+00,0.000000e+00)> : tensor> - %9 = stablehlo.broadcast_in_dim %8, dims = [] : (tensor>) -> tensor<8x8xcomplex> - %10 = stablehlo.divide %7, %9 : tensor<8x8xcomplex> - %11 = call @tril(%10) : (tensor<8x8xcomplex>) -> tensor<8x8xcomplex> - %12 = stablehlo.constant dense<1> : tensor - %13 = stablehlo.constant dense<1> : tensor - %14 = stablehlo.constant dense<8> : tensor - %15 = stablehlo.custom_call @lapack_zheevd(%12, %13, %14, %11) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>]} : (tensor, tensor, tensor, tensor<8x8xcomplex>) -> tuple>, tensor<8xf64>, tensor, tensor<81xcomplex>, tensor<169xf64>, tensor<43xi32>> - %16 = stablehlo.get_tuple_element %15[0] : (tuple>, tensor<8xf64>, tensor, tensor<81xcomplex>, tensor<169xf64>, tensor<43xi32>>) -> tensor<8x8xcomplex> - %17 = 
stablehlo.get_tuple_element %15[1] : (tuple>, tensor<8xf64>, tensor, tensor<81xcomplex>, tensor<169xf64>, tensor<43xi32>>) -> tensor<8xf64> - %18 = stablehlo.get_tuple_element %15[2] : (tuple>, tensor<8xf64>, tensor, tensor<81xcomplex>, tensor<169xf64>, tensor<43xi32>>) -> tensor - %19 = stablehlo.get_tuple_element %15[3] : (tuple>, tensor<8xf64>, tensor, tensor<81xcomplex>, tensor<169xf64>, tensor<43xi32>>) -> tensor<81xcomplex> - %20 = stablehlo.get_tuple_element %15[4] : (tuple>, tensor<8xf64>, tensor, tensor<81xcomplex>, tensor<169xf64>, tensor<43xi32>>) -> tensor<169xf64> - %21 = stablehlo.get_tuple_element %15[5] : (tuple>, tensor<8xf64>, tensor, tensor<81xcomplex>, tensor<169xf64>, tensor<43xi32>>) -> tensor<43xi32> - %22 = stablehlo.constant dense<0> : tensor - %23 = stablehlo.broadcast_in_dim %22, dims = [] : (tensor) -> tensor - %24 = stablehlo.compare EQ, %18, %23, SIGNED : (tensor, tensor) -> tensor - %25 = stablehlo.broadcast_in_dim %24, dims = [] : (tensor) -> tensor<1x1xi1> - %26 = stablehlo.constant dense<(0x7FF8000000000000,0x7FF8000000000000)> : tensor> - %27 = stablehlo.broadcast_in_dim %26, dims = [] : (tensor>) -> tensor<8x8xcomplex> - %28 = stablehlo.broadcast_in_dim %25, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<8x8xi1> - %29 = stablehlo.select %28, %16, %27 : tensor<8x8xi1>, tensor<8x8xcomplex> - %30 = stablehlo.broadcast_in_dim %24, dims = [] : (tensor) -> tensor<1xi1> - %31 = stablehlo.constant dense<0x7FF8000000000000> : tensor - %32 = stablehlo.broadcast_in_dim %31, dims = [] : (tensor) -> tensor<8xf64> - %33 = stablehlo.broadcast_in_dim %30, dims = [0] : (tensor<1xi1>) -> tensor<8xi1> - %34 = stablehlo.select %33, %17, %32 : tensor<8xi1>, tensor<8xf64> - return %29, %34 : tensor<8x8xcomplex>, tensor<8xf64> - } - func.func private @tril(%arg0: tensor<8x8xcomplex>) -> tensor<8x8xcomplex> { - %0 = stablehlo.iota dim = 0 : tensor<8x8xi32> - %1 = stablehlo.constant dense<0> : tensor - %2 = stablehlo.broadcast_in_dim %1, dims = [] : (tensor) -> tensor<8x8xi32> - %3 = stablehlo.add %0, %2 : tensor<8x8xi32> - %4 = stablehlo.iota dim = 1 : tensor<8x8xi32> - %5 = stablehlo.compare GE, %3, %4, SIGNED : (tensor<8x8xi32>, tensor<8x8xi32>) -> tensor<8x8xi1> - %6 = stablehlo.constant dense<(0.000000e+00,0.000000e+00)> : tensor> - %7 = stablehlo.broadcast_in_dim %6, dims = [] : (tensor>) -> tensor<8x8xcomplex> - %8 = stablehlo.select %5, %arg0, %7 : tensor<8x8xi1>, tensor<8x8xcomplex> - return %8 : tensor<8x8xcomplex> - } -} -""", - 
mlir_module_serialized=b"ML\xefR\x03MLIRxxx-trunk\x00\x015\x05\x01\x05\x01\x03\x05\x03%\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x1d\x1f!#%')\x03\xc6\x02\x1e\x02?\x01\xa9\x0f\x17\x13\x0b\x17\x0b\x07\x0f\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x0b\x13\x0f\x0b\x0b\x17\x0f\x13\x13\x0b33\x0b\x0f\x0b\x0b\x13\x0f\x0b\x1b\x0f\x0b\x13\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x13\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x13\x0b\x13\x0b\x0b\x13K\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x13\x13\x13\x13\x13\x1b\x17\x03a\x0f/\x0b\x0b\x0f\x0b\x0bO\x0b\x13\x13\x0b\x13\x0b\x0b\x0b\x0b\x0b\x0b\x0f\x1f\x0f\x0f\x0bOOO\x1f\x1f\x0b\x0b\x0b\x0b\x1b\x0f\x17#\x0f\x0f\x0f\x0f\x0f\x0f\x0bOO//\x01\x07\x17\x17\x17\x03?\x17\x0f\x07\x0f\x07\x13\x07\x07\x0b\x17\x17\x07\x17\x13\x17\x17\x13\x0f\x17\x17\x13\x17#\x13\x13\x13\x0f\x17\x13\x13\x13\x02\x96\n\x1d\x83\x03\x17\x13\xf6\x04\x01\x03\x03\x15\xd3\x05+\x17\x13\xf2\x04\x01\x05-\x1f\x1d9\x03\x05/\x051\x053\x055\x057\x059\x05;\x03\x03!\xcf\x05=\x03\x03\x07\xd1\x1d?\x03\x05?\x05A\x17\x13\xea\x04\x01\x1d}\t\x03\x03\x07\xdf\x03\x03\x113\x05C\x03\x0b\x17\xad\x19\xb9\x1b\xbb\x11\xc5\x1d\xc7\x03\x0b\x17\xb1\x19\xcb\x1b\xb1\x11\xb3\x1d\xcd\x05E\x1d=\x03\x05G\x05I\x03\x03!\xd5\x1dE\x03\x05K\x03\x05'\xb5)\xd7\x1dK\x03\x05M\x03\x03\x07\xd9\x1dQ\x03\x05O\x1dU\x03\x05Q\x1dY+\x05S\x1d]+\x05U\x03\x03a\xdb\x05W\x1de\t\x05Y\x1di\t\x05[\x1dm\t\x05]\x1dq\t\x05_\x1du\t\x05a\x1dy\t\x05c\x03\x03\x07\xdd\x05e\x03\x03\x81\xb3\x05g\x05i\x03\x03\x07\xe1\x03\x11\x89\xe3\x8b\xe5\x8d\xe7\x8f\xad\x91\xe9\x93\xeb\x95\xed\x97\xf1\x05k\x05m\x05o\x05q\x05s\x05u\x05w\x05y\x03\x03\x0b\xf3\x03\x03\x0b\xf5\x03\x03\x0b\xf7\x03\x03\x0b\xf9\x03\x03\x0b\xfb\x03\x03\x0b\xfd\x03\x05'\xb5)\xff\x03\x03\x07\x02\x02\x1f/\x01\x1f3\x11\x00\x00\x00\x00\x00\x00\x00\x00\x03\x01\x1d{\x03\x03\xc9\x1d}\t\x07\x1f1!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00#%\x03\x05\xbd\xc1\r\x03\xaf\xbf\x1d\x7f\r\x03\xaf\xc3\x1d\x81\x1d\x83\x1d\x85\r\x01#'\x1d\x87\x13\r\x01\x1f\x03\t\x00\x00\x00\x00\x1f)\x01\x13\r\x05\x07\x05\x1f\x07!\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x1b!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x07!\x00\x00\x00\x00\x00\x00\x00@\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x03\t\x01\x00\x00\x00\x1f\x03\t\x08\x00\x00\x00\x0b\x05\x1d\x89\x1d\x8b\x05\x01\x03\t\xa9\xa9\xa9\xb7\x03\x03\xef\x15\x03\x01\r\x01\x03\r\xb7\xab\xa9\xab\xab\xab\x13\x05\x01\x13\x05\x05\x13\x05\t\x13\x05\r\x13\x05\x11\x13\x05\x15\x07\x01\x1f\x07!\x00\x00\x00\x00\x00\x00\xf8\x7f\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f\x1b!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f#\x11\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f=\x11\x00\x00\x00\x00\x00\x00\x00\x00\x03\x03\x15\x06\x02\x03\x03\x07\n\x02\x03\x03\x15\x0e\x02)\x05!!\x11)\x01\x05\x1b)\x01\x11\x0b)\x03!\t\x1d\x01\x03\t)\x05!!\x05)\x05!!\t\x13)\x05!!\x0f)\x03\t\r)\x03\x8a\x02\x11)\x03J\x05\t)\x03\xad\x05)\x01\t\x11\x01\x05\x01\x0b\x11\x03\x01\x03\x01)\x03\x01\r)\x03\x02\x02\x11/\r\x01\x0b\x03\x1d\x1f!)\x03\x01\x17)\x03\t\x17)\x03\x05\x17)\x01\x0f)\x05\x05\x05\x0f)\x03\x05\x0f)\x03!\x0f)\x03\x05\r\x04\xda\x05\x05\x01\x11\r1\x07\x03\x01\t\r\x11\r5\x05\x03G\x91\t\x03W\x1f\x03+\x15\x06[\x03\x01\x03\x01\x17\x07c_\x03\x01\x03\x03\x19\x06g\x03\x15\x03\x05\x1b\x06k\x03\x15\x03\x05\x1d\x06o\x03\x15\x03\t\x1f\x06s\x03\x01\x05\x07\x0b\x0f\x06w\x03\x01\x05\x03\r\x05\x03\r{\x03\x07\x03\x07-\x05\x03\x01\x03\x11!\x06-\x03\x01\x05\x0f\x13#\x07\x0f\x7f\x03\x01\x03\x15\x05\x03\x01/\x03\x03\x05\x03\x01/\x03\x03\x05\x03\x01\x85\x03\x03%\x07\x01\x87\x03-\t\x19\x
1b\x1d\x17\x07\x07\x01\x99\x03\x01\x03\x1f\x07\x07\x01\x9b\x03\x0b\x03\x1f\x07\x07\x01\x9d\x03\x03\x03\x1f\x07\x07\x01\x9f\x03\x1d\x03\x1f\x07\x07\x01\xa1\x03\x1f\x03\x1f\x07\x07\x01\xa3\x03!\x03\x1f\x05\x03\x01#\x03\x03\x03\x07\x01\x05\x03\x03\x03-\x11\x07\x01\xa5\x035\x05%/\x03\x07\x01\x05\x037\x031\x05\x03\x01\xa7\x03\x07\x03\x07\x01\x05\x03\x01\x035\x03\x07\x01\x12\x02\x03\x19\x033\x0b\x06\x01\x03\x01\x079!7\x03\x07\x01\x05\x039\x031\x05\x03\x01\x16\x02\x03#\x03\x07\x01\x05\x03\x0b\x03?\x03\x07\x01\x1a\x02\x03;\x03=\x0b\x06\x01\x03\x0b\x07C#A\x13\x04\r\x05;E\r\x11\x0f7\x05\x03\x15+\x03\x01\r\t\x03;\x1f\x03\x13\x05\x03\x0f#\x03\x03\x03\x07%\x05\x03\x13\x03\x05\x0f\x06%\x03\x13\x05\x03\x07\t\x03CA\x03\x13\x11\x07IG\x03\x19\x05\t\x0b\x05\x03\x0fM\x03\x07\x03\x07O\x05\x03\x01\x03\x0f\x0b\x06S\x03\x01\x07\r\x01\x11\x13\x04\x0f\x03\x13\x06\x03\x01\x05\x01\x00J\x1c\x8d\x1d\x03\x11\x0f\x0b\t\t\x0b!\x1f/!!)#\x1f\x19\x7f\x0f99A9;;m\x19\x85\x8fW\xb3K\x9bM\x9b\x96\x04\x1b+\x1b\x1f\x1f\x15\x1d\x15+\x83\x13\r\r\x1f\x11\x15\x17\x15\x11\x11\x1b\x17\x15\x17\x0f\x11\x15\x11+\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00get_tuple_element_v1\x00iota_v1\x00select_v1\x00func_v1\x00add_v1\x00compare_v1\x00return_v1\x00reshape_v1\x00transpose_v1\x00real_v1\x00imag_v1\x00negate_v1\x00complex_v1\x00divide_v1\x00call_v1\x00custom_call_v1\x00value\x00index\x00sym_name\x00third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py\x00broadcast_dimensions\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00compare_type\x00comparison_direction\x00jit__lambda_\x00jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) resource_env=None donated_invars=(False,) name=tril in_positional_semantics=(<_PositionalSemantics.GLOBAL: 1>,) out_positional_semantics=_PositionalSemantics.GLOBAL keep_unused=False inline=False]\x00jit()/jit(main)/jit(tril)/iota[dtype=int32 shape=(8, 8) dimension=0]\x00jit()/jit(main)/jit(tril)/add\x00jit()/jit(main)/jit(tril)/iota[dtype=int32 shape=(8, 8) dimension=1]\x00jit()/jit(main)/jit(tril)/ge\x00jit()/jit(main)/jit(tril)/broadcast_in_dim[shape=(8, 8) broadcast_dimensions=()]\x00jit()/jit(main)/jit(tril)/select_n\x00jit()/jit(main)/iota[dtype=complex128 shape=(64,) dimension=0]\x00jit()/jit(main)/reshape[new_sizes=(8, 8) dimensions=None]\x00permutation\x00jit()/jit(main)/transpose[permutation=(1, 0)]\x00jit()/jit(main)/real\x00jit()/jit(main)/imag\x00jit()/jit(main)/neg\x00jit()/jit(main)/complex\x00jit()/jit(main)/add\x00jit()/jit(main)/div\x00callee\x00jit()/jit(main)/eigh[lower=True sort_eigenvalues=True]\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00jax.result_info\x00tril\x00[0]\x00[1]\x00main\x00public\x00private\x00\x00lapack_zheevd\x00", - xla_call_module_version=4, - ), # End paste -) - data_2024_08_19 = {} - # Pasted from the test output (see export_back_compat_test_util.py module docstring) data_2024_08_19["c128"] = dict( testdata_version=1, diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/cpu_hessenberg_lapack_gehrd.py b/jax/_src/internal_test_util/export_back_compat_test_data/cpu_hessenberg_lapack_gehrd.py index 204af8f55396..8d87c2524e64 100644 --- a/jax/_src/internal_test_util/export_back_compat_test_data/cpu_hessenberg_lapack_gehrd.py +++ b/jax/_src/internal_test_util/export_back_compat_test_data/cpu_hessenberg_lapack_gehrd.py @@ -17,275 
+17,8 @@ import datetime from numpy import array, float32, complex64 -data_2024_08_30 = {} - - -# Pasted from the test output (see export_back_compat_test_util.py module docstring) -data_2024_08_30["c128"] = dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['lapack_zgehrd'], - serialized_date=datetime.date(2024, 8, 30), - inputs=(), - expected_outputs=(array([[[ 0.7137638961069523 +2.4533812415320035e+00j, - -0.3272236912989258 -3.2003874808591863e+00j, - -3.065817294924296 +1.6978219378771007e+00j, - -3.3971558164664 +2.6931967836060400e-01j], - [ 6.346214936866542 +0.0000000000000000e+00j, - 2.083218259144673 -1.2191838498692813e+00j, - 1.9552582313969427 -3.3216313521481879e+00j, - 2.7451664155727293 +2.5460553490974451e+00j], - [-0.16133388943502391 +3.6906265775683444e-01j, - -4.698636849217318 +0.0000000000000000e+00j, - 2.5396292124414077 -3.3038474840573420e+00j, - 2.5410992366186456 +4.1958389320867528e-01j], - [ 0.47396123039280513 +3.9524384493417053e-03j, - 0.058880409351504966-7.8934332132630333e-02j, - 0.9469634796174572 +0.0000000000000000e+00j, - -3.130422531669044 -8.8070401977461810e-01j]], - - [[-6.7065483048969465 -4.1981401054281309e-01j, - -0.21813268822330256 -3.8602920478381799e+00j, - -0.8248337528620167 -2.9073223456990824e+00j, - -3.597231249446879 +2.7626541679004930e+00j], - [-6.812126638479044 +0.0000000000000000e+00j, - -0.20651586628458585 -1.0948249928988512e+00j, - -1.6675586608354327 +4.2553627621795744e+00j, - -2.410110723267707 +3.6065122124698634e-01j], - [ 0.038235817369200516-3.7823713529009173e-01j, - -8.508141062606947 +0.0000000000000000e+00j, - 4.260708077719245 -6.8052584397204630e-02j, - 5.345997177836541 -1.1955161503390279e+00j], - [-0.18541509608158574 -1.2016051097247168e-01j, - -0.02698777746917469 -4.4847463691672246e-01j, - 6.149305574585603 +0.0000000000000000e+00j, - -2.483131585236393 +2.8524912589603817e+00j]]]), array([[1.2286220194325557+0.5121060656500841j , - 1.9529937219183482-0.23299856112387676j, - 1.5940499664125072-0.8044281430962614j ], - [1.6682114302246909-0.11372755955977935j, - 1.4075913155446236-0.6008708461880701j , - 1.5086928152468893-0.8609480935086589j ]])), - mlir_module_text=r""" -module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main() -> (tensor<2x4x4xcomplex> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<2x3xcomplex> {jax.result_info = "[1]", mhlo.layout_mode = "default"}) { - %cst = stablehlo.constant dense<[[[(0.71376389610695234,2.4533812415320035), (-1.0686093138739379,-1.885041510645256), (3.2629529488994033,-0.87160041258342402), (2.4332168907311504,3.4960248990882183)], [(-1.450884474619478,-3.249935163088522), (0.53920035905924757,-5.0056840575116066), (0.13157186736298554,2.5015499854549939), (-1.2451270607408882,0.24345856951924827)], [(2.457366083193417,-2.3532935513245605), (-0.37595429769485644,1.5729223427874068), (3.5877693970448052,-0.30904304334212157), (-1.685615117470264,2.6148811836470265)], [(-3.6826776618664727,-1.5711608241015744), (-0.12407609317204518,-4.7137561145212281), (1.3298255603911306,-1.6739172003954141), (-2.6345448161870149,-0.089008252847513236)]], [[(-6.7065483048969465,-0.41981401054281309), (-2.1586544949255457,0.34815132010709054), (-5.1462488701272413,3.440817752555807), (1.0301804086076078,-0.6994760434270566)], [(4.551940883969797,-0.77472653800638502), (4.4485186470774796,-0.0024458890677252756), 
(0.66610302132250898,2.5976571401862039), (-5.0693248202533674,-5.7405538897950699)], [(0.14148406399087146,-4.3279346473525058), (-2.353557113110897,2.0880432773400326), (-3.2524452107293618,-0.42398740171508631), (3.7200566224095519,-0.56951559566037058)], [(-2.2001612082232613,-1.2218661647417151), (0.72437359623190833,8.6381970213061301), (0.72314820631775734,0.058458198280771749), (0.37498718985014962,2.1160469724471378)]]]> : tensor<2x4x4xcomplex> loc(#loc) - %c = stablehlo.constant dense<4> : tensor loc(#loc2) - %c_0 = stablehlo.constant dense<1> : tensor loc(#loc2) - %c_1 = stablehlo.constant dense<4> : tensor loc(#loc2) - %c_2 = stablehlo.constant dense<4> : tensor loc(#loc2) - %c_3 = stablehlo.constant dense<2> : tensor loc(#loc2) - %c_4 = stablehlo.constant dense<4288> : tensor loc(#loc2) - %0:4 = stablehlo.custom_call @lapack_zgehrd(%c, %c_0, %c_1, %c_2, %c_3, %c_4, %cst) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[1, 2, 0]> : tensor<3xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 0]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>]} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor<2x4x4xcomplex>) -> (tensor<2x4x4xcomplex>, tensor<2x3xcomplex>, tensor<2xi32>, tensor<4288xcomplex>) loc(#loc2) - %c_5 = stablehlo.constant dense<0> : tensor loc(#loc2) - %1 = stablehlo.broadcast_in_dim %c_5, dims = [] : (tensor) -> tensor<2xi32> loc(#loc2) - %2 = stablehlo.compare EQ, %0#2, %1, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc2) - %3 = stablehlo.broadcast_in_dim %2, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> loc(#loc2) - %cst_6 = stablehlo.constant dense<(0x7FF8000000000000,0x7FF8000000000000)> : tensor> loc(#loc2) - %4 = stablehlo.broadcast_in_dim %cst_6, dims = [] : (tensor>) -> tensor<2x4x4xcomplex> loc(#loc2) - %5 = stablehlo.broadcast_in_dim %3, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x4x4xi1> loc(#loc2) - %6 = stablehlo.select %5, %0#0, %4 : tensor<2x4x4xi1>, tensor<2x4x4xcomplex> loc(#loc2) - %7 = stablehlo.broadcast_in_dim %2, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc2) - %cst_7 = stablehlo.constant dense<(0x7FF8000000000000,0x7FF8000000000000)> : tensor> loc(#loc2) - %8 = stablehlo.broadcast_in_dim %cst_7, dims = [] : (tensor>) -> tensor<2x3xcomplex> loc(#loc2) - %9 = stablehlo.broadcast_in_dim %7, dims = [0, 1] : (tensor<2x1xi1>) -> tensor<2x3xi1> loc(#loc2) - %10 = stablehlo.select %9, %0#1, %8 : tensor<2x3xi1>, tensor<2x3xcomplex> loc(#loc2) - return %6, %10 : tensor<2x4x4xcomplex>, tensor<2x3xcomplex> loc(#loc) - } loc(#loc) -} loc(#loc) -#loc = loc(unknown) -#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":697:13) -#loc2 = loc("jit(func)/jit(main)/hessenberg"(#loc1)) -""", - 
mlir_module_serialized=b'ML\xefR\x01StableHLO_v0.9.0\x00\x01\x1f\x05\x01\x03\x01\x03\x05\x03\x0f\x07\t\x0b\r\x0f\x11\x13\x03\xef\xa19\x01W\x0f\x0b\x07\x0b\x13\x13\x0f\x0b\x13\x13+\x0b\x0f\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x13\x0b\x17\x0b\x13\x13\x13K\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x1b\x0b\x0b\x13\x13\x03K\x0f\x0b\x0b\x0b\x0bo/\x0b\x13\x1b\x0b\x1b\x0b\x0b\x0b&\x10\x1f\x1f\x1f\x1f\x0b\x0b\x0b\x0b\'\x0f\x17\x1bO\x1f\x0f\x0b\x0b/OoO\x01\x05\x0b\x0f\x035\x0f\x1b\x07\x0b\x17\x07\x07\x0f\x07\x13\x17\x07\x17\x13\x13\x13\x13\x13\x13\x1b\x13\x1b\x13\x17\x17\x13\x02\xce\x0f\x1d-/\x05\x15\x1f\x05\x17\x03\x03\x03w\x03\x03\x07\x93\x11\x03\x05\x05\x19\x03\x03\x07\x99\x03\x03\x03\x9b\x03\t\x17\x19\x1b\r\x1d\r\x0f\x1f\x05\x1b\x11\x01\x00\x05\x1d\x05\x1f\x05!\x03\x0b#Y%e\'g\x0fq)s\x05#\x05%\x05\'\x05)\x03\x03\x03u\x05+\x171\xe6\n\x1b\x05-\x03\x03\x03y\x03\x03\x03{\x03\x03\x03}\x03\x11;\x7f=\x81?\x83AYC\x85E\x87G\x89I\x8d\x05/\x051\x053\x055\x057\x059\x05;\x05=\x03\x03\x03\x91\x03\x05O\x95Q\x97\x05?\x05A\x03\x03\x07\x9d\x03\x03\x07\x9f\x1f\x1f\x01\x03\x01\x1dC\x1dE\x1dG\x1f!1\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f%\x11\x00\x00\x00\x00\x00\x00\x00\x00#\x19\x03\x05im\r\x05[k]_\x1dI\r\x05[o]_\x1dK\x1dM\x1dO\x1f\x07\x02\x08p\t\xdba\'\xd7\xe6?\xa8\xff\'X\x86\xa0\x03@\x0c\xa2t\x14\x06\x19\xf1\xbfT.}I!)\xfe\xbf\x0fG_\x13\x87\x1a\n@\xae:g\x8c&\xe4\xeb\xbf\xeb\x1e\xcej:w\x03@N\xaf\xfc\xe6\xdb\xf7\x0b@\x9f<\x8c\xa3\xd26\xf7\xbf^\xaf\xbc\x01\xde\xff\t\xc0b\xd4\x84\x1c!A\xe1?\xd6{\xa4\n\xd2\x05\x14\xc0\xf0\xe6\xb2\xd1X\xd7\xc0?2\xb5\x86\xa3,\x03\x04@\x91\xf2SZ\n\xec\xf3\xbf\x04\x10\x02\x81\xa6)\xcf?8\xec\x8c\x8c\xaf\xa8\x03@\r\x9d\xc6\x91\x8b\xd3\x02\xc0\xb0\xf6X\x9d\xa2\x0f\xd8\xbf\xbd\xb6V\x9e\xb0*\xf9?7-\x0fq\xc0\xb3\x0c@{|\ry\\\xc7\xd3\xbf\x04\xd9\xb2\x8eG\xf8\xfa\xbf\x9b\x84u\xd3F\xeb\x04@\xf4h\xbb\xb4\x1fv\r\xc0\xdc\\D\x88y#\xf9\xbf\x9a\xaecjs\xc3\xbf\xbf<\xc1\x04\xe2\xe2\xda\x12\xc0\x89<\xb4*\xf7F\xf5?\x1b\x90\xfef]\xc8\xfa\xbf\xdc\xf4\x8a;\x8c\x13\x05\xc0\xf8\xdd\r\xaf>\xc9\xb6\xbfvN\x1af\x81\xd3\x1a\xc0Z\xc6k\x95;\xde\xda\xbf\x87\x8c\xd8\xa5\xecD\x01\xc0\xdd\xd3zy\x1cH\xd6?\x04\x18\x89C\xc2\x95\x14\xc0\x8c\xc95u\xcb\x86\x0b@\x881\xbfs\x9e{\xf0?\x92Y[\x95\x1bb\xe6\xbf\x06\xe7\xb7\xfd/5\x12@L\x95\x02O\x8f\xca\xe8\xbf2`\xe3xH\xcb\x11@>\xda\xc6\xb1f\td\xbfZ\x1a\x8bH\xb7P\xe5?\xa8\x90zw\x00\xc8\x04@<(\xef\x15\xfdF\x14\xc0\xb4aF\xc2S\xf6\x16\xc0\xc1{\xdfY&\x1c\xc2?\xcfj\xa6\x19\xceO\x11\xc0\xc4\xa2p\xc0\x15\xd4\x02\xc0\xfcv\xa6\x08P\xb4\x00@^\xea\xa0\xfe\x01\x05\n\xc0^\x11\x12\x0e\x9c"\xdb\xbfR#\xe4\x0b\xad\xc2\r@F\x8b=\xc5x9\xe2\xbfZ\xf9\x99\x1e\xee\x99\x01\xc0My\x1a\x89\xc3\x8c\xf3\xbf\xd1\xdc<\x89\x11.\xe7?2\xd4\x8d\xc2\xc1F!@mw\t\xb5\x07$\xe7?G\x16\x99\xa3;\xee\xad?M\xd24E\xca\xff\xd7?\xa2\xae\xfb\x08\xaa\xed\x00@\x1f\x05\t\x04\x00\x00\x00\x1f\x05\t\x01\x00\x00\x00\x1f\x05\t\x02\x00\x00\x00\x1f\x05\t\xc0\x10\x00\x00\x0b\x05\x1dQ\x1dS\x05\x01\x03\x0fWWWWWWa\x03\x03\x8b\x15\x03\x01\x19\x01\x03\ta\x8fcc\x1f#!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x05\t\x00\x00\x00\x00\x1f\'\x01\t\x07\x07\x01\x1f-\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x13!\x00\x00\x00\x00\x00\x00\xf8\x7f\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f11\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x1f7!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\t\x01\x02\x02)\x01\x15)\x07\t\x11\x11\x0b\x01\x03\x1b)\x05\t\r\x0b\x13\x1d)\x01\x0b\x1b)\x03\t\x15\x11\x01\x05\x07\r\x0b)\x03\x02\x86\x0b)\x03\x01\x0f)\x0
3\r\x0f)\x03\t\x0f)\x03\x05\x0f)\x03\x01\x11)\x03\t\t)\x07\t\x05\x05\t)\x03\x05\x11)\x07\t\x11\x11\t)\x03\r\x11)\x05\t\x05\t)\x05\t\r\t)\x03\t\x11\x04\xde\x02\x05\x01\x11\x05\x15\x07\x03\x01\x05\t\x11\x05!\x07\x031Y\x03\x03\x05+\x03\x07\x03\x03\x01\t\x03\x05\x03\x03\x013\x03\x05\x03\x03\x01\t\x03\x05\x03\x03\x01\t\x03\x05\x03\x03\x015\x03\x05\x03\x03\x017\x03\x05\x0b\x07\x019\t\x07\r\x17\x1d\x0f\x03\x05\x07\t\x0b\r\x01\x03\x03\x01K\x03\x05\x05\x07\x01\x0b\x03\x17\x03\x17\r\x07\x01M\x03)\x05\x13\x19\x05\x07\x01\x11\x03+\x03\x1b\x03\x03\x01\x13\x03\x13\x05\x07\x01\x0b\x03\x07\x03\x1f\x05\x07\x01S\x03/\x03\x1d\x07\x06\x01\x03\x07\x07#\x0f!\x05\x07\x01\x11\x033\x03\x1b\x03\x03\x01\x13\x03\x13\x05\x07\x01\x0b\x03\r\x03)\x05\x07\x01U\x035\x03\'\x07\x06\x01\x03\r\x07-\x11+\x0f\x04\x05\x05%/\x06\x03\x01\x05\x01\x00\xf2\tU\x1d\x03\x0f\x0b\t\t\x11#!+\x1b\x1f/!!)#\x1f\x19i?\x1f\x15\x1d\x15\x13%)9\x13+\r\x15\x17\x1f\x11\x15)\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00broadcast_in_dim_v1\x00select_v1\x00func_v1\x00custom_call_v1\x00compare_v1\x00return_v1\x00value\x00broadcast_dimensions\x00sym_name\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00jit(func)/jit(main)/hessenberg\x00third_party/py/jax/tests/export_back_compat_test.py\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00compare_type\x00comparison_direction\x00jax.result_info\x00mhlo.layout_mode\x00default\x00[0]\x00[1]\x00main\x00public\x00\x00lapack_zgehrd\x00', - xla_call_module_version=9, - nr_devices=1, -) # End paste - - -# Pasted from the test output (see export_back_compat_test_util.py module docstring) -data_2024_08_30["c64"] = dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['lapack_cgehrd'], - serialized_date=datetime.date(2024, 8, 30), - inputs=(), - expected_outputs=(array([[[ 5.2023945 -0.878671j , -2.8841915 -0.47488597j , - 1.3024182 +0.6651789j , 4.9291854 -1.9147056j ], - [ 6.3457894 +0.j , 1.6869383 -4.6557646j , - 0.88955224-1.7617276j , 2.9149916 +4.342665j ], - [-0.2465725 -0.5776757j , -5.3007755 +0.j , - -0.9786545 -0.0633831j , -1.3690261 -1.5921416j ], - [ 0.35462287+0.35993803j , -0.38403815-0.46558398j , - 2.8020499 +0.j , 0.5636822 -6.218306j ]], - - [[ 1.0687767 -3.88293j , -4.0144 -2.5885587j , - 5.3900986 -0.8850739j , 2.079677 +3.5515747j ], - [ 7.5675693 +0.j , 0.5971966 -3.6699948j , - 2.246994 -1.0858283j , -0.8870981 -0.022960603j], - [-0.2183232 +0.10552277j , 5.860886 +0.j , - -5.091036 +6.2841997j , 5.008773 +1.8765848j ], - [ 0.1378771 +0.427895j , 0.63263524-0.3470098j , - 6.4528017 +0.j , -4.233642 -0.84165764j ]]], - dtype=complex64), array([[1.0933675-0.3605358j , 1.1987956+0.5659744j , - 1.9999101-0.013409062j], - [1.4504763-0.44363326j , 1.3110259-0.07426627j , - 1.227255 +0.97383535j ]], dtype=complex64)), - mlir_module_text=r""" -module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main() -> (tensor<2x4x4xcomplex> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<2x3xcomplex> {jax.result_info = "[1]", mhlo.layout_mode = "default"}) { - %cst = stablehlo.constant dense<[[[(5.20239449,-0.87867099), (-0.211780012,-0.923053801), (-5.25181627,1.90887547), (-1.61342144,-1.98000157)], [(-5.924900e-01,2.28788424), (-1.74142945,-3.25563216), 
(3.08765078,-3.25260139), (-3.35189271,-0.571629047)], [(3.032444,3.44394636), (1.22205484,0.808871626), (2.58686161,-7.47011566), (1.9139297,-2.57945323)], [(-3.28396916,-1.68601465), (2.62759161,-0.953538239), (-2.78763294,-0.0429570749), (0.426534384,-0.211706176)]], [[(1.06877673,-3.882930e+00), (-0.0192247611,5.96663713), (1.15329504,-5.0599103), (-1.76508892,-1.98541296)], [(-3.40901089,3.35722542), (-6.13531398,2.55851483), (-4.8095789,0.164206699), (-0.247624069,-3.13545418)], [(2.04217815,-1.89123917), (-1.18974173,-1.69466627), (-2.28673625,-0.487834573), (3.01541853,-1.85637176)], [(-2.9499588,-4.23393869), (8.44624137,5.57274485), (-1.09048736,2.4864223), (-0.305431545,-0.298133373)]]]> : tensor<2x4x4xcomplex> loc(#loc) - %c = stablehlo.constant dense<4> : tensor loc(#loc2) - %c_0 = stablehlo.constant dense<1> : tensor loc(#loc2) - %c_1 = stablehlo.constant dense<4> : tensor loc(#loc2) - %c_2 = stablehlo.constant dense<4> : tensor loc(#loc2) - %c_3 = stablehlo.constant dense<2> : tensor loc(#loc2) - %c_4 = stablehlo.constant dense<4288> : tensor loc(#loc2) - %0:4 = stablehlo.custom_call @lapack_cgehrd(%c, %c_0, %c_1, %c_2, %c_3, %c_4, %cst) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[1, 2, 0]> : tensor<3xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 0]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>]} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor<2x4x4xcomplex>) -> (tensor<2x4x4xcomplex>, tensor<2x3xcomplex>, tensor<2xi32>, tensor<4288xcomplex>) loc(#loc2) - %c_5 = stablehlo.constant dense<0> : tensor loc(#loc2) - %1 = stablehlo.broadcast_in_dim %c_5, dims = [] : (tensor) -> tensor<2xi32> loc(#loc2) - %2 = stablehlo.compare EQ, %0#2, %1, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc2) - %3 = stablehlo.broadcast_in_dim %2, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> loc(#loc2) - %cst_6 = stablehlo.constant dense<(0x7FC00000,0x7FC00000)> : tensor> loc(#loc2) - %4 = stablehlo.broadcast_in_dim %cst_6, dims = [] : (tensor>) -> tensor<2x4x4xcomplex> loc(#loc2) - %5 = stablehlo.broadcast_in_dim %3, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x4x4xi1> loc(#loc2) - %6 = stablehlo.select %5, %0#0, %4 : tensor<2x4x4xi1>, tensor<2x4x4xcomplex> loc(#loc2) - %7 = stablehlo.broadcast_in_dim %2, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc2) - %cst_7 = stablehlo.constant dense<(0x7FC00000,0x7FC00000)> : tensor> loc(#loc2) - %8 = stablehlo.broadcast_in_dim %cst_7, dims = [] : (tensor>) -> tensor<2x3xcomplex> loc(#loc2) - %9 = stablehlo.broadcast_in_dim %7, dims = [0, 1] : (tensor<2x1xi1>) -> tensor<2x3xi1> loc(#loc2) - %10 = stablehlo.select %9, %0#1, %8 : tensor<2x3xi1>, tensor<2x3xcomplex> loc(#loc2) - return %6, %10 : tensor<2x4x4xcomplex>, tensor<2x3xcomplex> loc(#loc) - } loc(#loc) -} loc(#loc) -#loc = loc(unknown) -#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":697:13) -#loc2 = loc("jit(func)/jit(main)/hessenberg"(#loc1)) -""", - 
mlir_module_serialized=b'ML\xefR\x01StableHLO_v0.9.0\x00\x01\x1f\x05\x01\x03\x01\x03\x05\x03\x0f\x07\t\x0b\r\x0f\x11\x13\x03\xef\xa19\x01W\x0f\x0b\x07\x0b\x13\x13\x0f\x0b\x13\x13+\x0b\x0f\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x13\x0b\x17\x0b\x13\x13\x13K\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x1b\x0b\x0b\x13\x13\x03K\x0f\x0b\x0b\x0b\x0bo/\x0b\x13\x1b\x0b\x1b\x0b\x0b\x0b&\x08\x1f\x1f\x1f\x1f\x0b\x0b\x0b\x0b\'\x0f\x17\x1bO\x1f\x0f\x0b\x0b//oO\x01\x05\x0b\x0f\x035\x0f\x1b\x07\x0b\x17\x07\x07\x0f\x07\x13\x17\x07\x17\x13\x13\x13\x13\x13\x13\x1b\x13\x1b\x13\x17\x17\x13\x02\xae\x0b\x1d-/\x05\x15\x1f\x05\x17\x03\x03\x03w\x03\x03\x07\x93\x11\x03\x05\x05\x19\x03\x03\x07\x99\x03\x03\x03\x9b\x03\t\x17\x19\x1b\r\x1d\r\x0f\x1f\x05\x1b\x11\x01\x00\x05\x1d\x05\x1f\x05!\x03\x0b#Y%e\'g\x0fq)s\x05#\x05%\x05\'\x05)\x03\x03\x03u\x05+\x171\xe6\n\x1b\x05-\x03\x03\x03y\x03\x03\x03{\x03\x03\x03}\x03\x11;\x7f=\x81?\x83AYC\x85E\x87G\x89I\x8d\x05/\x051\x053\x055\x057\x059\x05;\x05=\x03\x03\x03\x91\x03\x05O\x95Q\x97\x05?\x05A\x03\x03\x07\x9d\x03\x03\x07\x9f\x1f\x1f\x01\x03\x01\x1dC\x1dE\x1dG\x1f!1\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f%\x11\x00\x00\x00\x00\x00\x00\x00\x00#\x19\x03\x05im\r\x05[k]_\x1dI\r\x05[o]_\x1dK\x1dM\x1dO\x1f\x07\x02\x04\x04z\xa6@\x95\xf0`\xbf\xdc\xdcX\xbeAMl\xbf\xe1\x0e\xa8\xc0\x08V\xf4?\x98\x84\xce\xbf\xb1p\xfd\xbfm\xad\x17\xbf\xb2l\x12@)\xe7\xde\xbfG\\P\xc0\x12\x9cE@\x9f*P\xc0i\x85V\xc0HV\x12\xbf\x90\x13B@\x9ei\\@Kl\x9c?6\x12O?$\x8f%@0\x0b\xef\xc0\xa6\xfb\xf4?\xc3\x15%\xc0\x8d,R\xc0T\xcf\xd7\xbfv*(@\x15\x1bt\xbf\x94h2\xc0\xc2\xf3/\xbd\xb7b\xda>\x81\xc9X\xbe\xad\xcd\x88?\xed\x81x\xc0?}\x9d\xbc\xb1\xee\xbe@,\x9f\x93?\xc9\xea\xa1\xc0o\xee\xe1\xbf\x03"\xfe\xbf<-Z\xc0\xc8\xdcV@~T\xc4\xc0\xb5\xbe#@\x12\xe8\x99\xc0\xcd%(>*\x91}\xbeH\xabH\xc0\x0c\xb3\x02@ \x14\xf2\xbfuI\x98\xbf\xd3\xea\xd8\xbf\xe3Y\x12\xc0t\xc5\xf9\xbe\x9e\xfc@@\x97\x9d\xed\xbf 
\xcc<\xc0m|\x87\xc0\xce#\x07A\xedS\xb2@\x17\x95\x8b\xbf\x8b!\x1f@\x86a\x9c\xbe\xf0\xa4\x98\xbe\x1f\x05\t\x04\x00\x00\x00\x1f\x05\t\x01\x00\x00\x00\x1f\x05\t\x02\x00\x00\x00\x1f\x05\t\xc0\x10\x00\x00\x0b\x05\x1dQ\x1dS\x05\x01\x03\x0fWWWWWWa\x03\x03\x8b\x15\x03\x01\x19\x01\x03\ta\x8fcc\x1f#!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x05\t\x00\x00\x00\x00\x1f\'\x01\t\x07\x07\x01\x1f-\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x13\x11\x00\x00\xc0\x7f\x00\x00\xc0\x7f\x1f11\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x1f7!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\t\x01\x02\x02)\x01\x15)\x07\t\x11\x11\x0b\x01\x03\x1b)\x05\t\r\x0b\x13\x1d)\x01\x0b\x1b)\x03\t\x15\x11\x01\x05\x07\r\t)\x03\x02\x86\x0b)\x03\x01\x0f)\x03\r\x0f)\x03\t\x0f)\x03\x05\x0f)\x03\x01\x11)\x03\t\t)\x07\t\x05\x05\t)\x03\x05\x11)\x07\t\x11\x11\t)\x03\r\x11)\x05\t\x05\t)\x05\t\r\t)\x03\t\x11\x04\xde\x02\x05\x01\x11\x05\x15\x07\x03\x01\x05\t\x11\x05!\x07\x031Y\x03\x03\x05+\x03\x07\x03\x03\x01\t\x03\x05\x03\x03\x013\x03\x05\x03\x03\x01\t\x03\x05\x03\x03\x01\t\x03\x05\x03\x03\x015\x03\x05\x03\x03\x017\x03\x05\x0b\x07\x019\t\x07\r\x17\x1d\x0f\x03\x05\x07\t\x0b\r\x01\x03\x03\x01K\x03\x05\x05\x07\x01\x0b\x03\x17\x03\x17\r\x07\x01M\x03)\x05\x13\x19\x05\x07\x01\x11\x03+\x03\x1b\x03\x03\x01\x13\x03\x13\x05\x07\x01\x0b\x03\x07\x03\x1f\x05\x07\x01S\x03/\x03\x1d\x07\x06\x01\x03\x07\x07#\x0f!\x05\x07\x01\x11\x033\x03\x1b\x03\x03\x01\x13\x03\x13\x05\x07\x01\x0b\x03\r\x03)\x05\x07\x01U\x035\x03\'\x07\x06\x01\x03\r\x07-\x11+\x0f\x04\x05\x05%/\x06\x03\x01\x05\x01\x00\xf2\tU\x1d\x03\x0f\x0b\t\t\x11#!+\x1b\x1f/!!)#\x1f\x19i?\x1f\x15\x1d\x15\x13%)9\x13+\r\x15\x17\x1f\x11\x15)\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00broadcast_in_dim_v1\x00select_v1\x00func_v1\x00custom_call_v1\x00compare_v1\x00return_v1\x00value\x00broadcast_dimensions\x00sym_name\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00jit(func)/jit(main)/hessenberg\x00third_party/py/jax/tests/export_back_compat_test.py\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00compare_type\x00comparison_direction\x00jax.result_info\x00mhlo.layout_mode\x00default\x00[0]\x00[1]\x00main\x00public\x00\x00lapack_cgehrd\x00', - xla_call_module_version=9, - nr_devices=1, -) # End paste - - -# Pasted from the test output (see export_back_compat_test_util.py module docstring) -data_2024_08_30["f32"] = dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['lapack_sgehrd'], - serialized_date=datetime.date(2024, 8, 30), - inputs=(), - expected_outputs=(array([[[-3.5237675 , -6.1161256 , -0.549011 , -4.7706876 ], - [ 5.8401766 , 3.424213 , 0.3059119 , 2.3492367 ], - [ 0.63135445 , 2.7238827 , -0.106214404, -0.82470125 ], - [-0.27146497 , 0.09917235 , 0.2545611 , -0.5113605 ]], - - [[ 4.297168 , -1.8758869 , 0.33528137 , 5.867136 ], - [-7.129698 , -3.3118155 , -1.3492918 , -2.8959117 ], - [-0.7266852 , -3.506432 , 4.77164 , -4.0780373 ], - [ 0.14084078 , 0.3389384 , 2.3910007 , -0.79807365 ]]], - dtype=float32), array([[1.3584172, 1.9805213, 0. ], - [1.2920669, 1.7939165, 0. 
]], dtype=float32)), - mlir_module_text=r""" -module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main() -> (tensor<2x4x4xf32> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<2x3xf32> {jax.result_info = "[1]", mhlo.layout_mode = "default"}) { - %cst = stablehlo.constant dense<[[[-3.52376747, -0.758410036, 4.85795927, -6.0243597], [-2.09321976, -1.27957773, -0.956288218, -1.11928439], [-5.00878525, 0.51314038, 3.53047514, -2.91282868], [2.15363932, 0.635739565, -0.21264787, 0.555740714]], [[4.29716778, -3.86209464, -2.39021468, 4.17441607], [2.08234859, -1.03958249, 4.09025383, 5.22586823], [-6.69425774, 3.43749118, -0.691099107, 1.59547663], [1.29743183, -2.00156212, 3.08750296, 2.39243269]]]> : tensor<2x4x4xf32> loc(#loc) - %c = stablehlo.constant dense<4> : tensor loc(#loc2) - %c_0 = stablehlo.constant dense<1> : tensor loc(#loc2) - %c_1 = stablehlo.constant dense<4> : tensor loc(#loc2) - %c_2 = stablehlo.constant dense<4> : tensor loc(#loc2) - %c_3 = stablehlo.constant dense<2> : tensor loc(#loc2) - %c_4 = stablehlo.constant dense<4288> : tensor loc(#loc2) - %0:4 = stablehlo.custom_call @lapack_sgehrd(%c, %c_0, %c_1, %c_2, %c_3, %c_4, %cst) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[1, 2, 0]> : tensor<3xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 0]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>]} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor<2x4x4xf32>) -> (tensor<2x4x4xf32>, tensor<2x3xf32>, tensor<2xi32>, tensor<4288xf32>) loc(#loc2) - %c_5 = stablehlo.constant dense<0> : tensor loc(#loc2) - %1 = stablehlo.broadcast_in_dim %c_5, dims = [] : (tensor) -> tensor<2xi32> loc(#loc2) - %2 = stablehlo.compare EQ, %0#2, %1, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc2) - %3 = stablehlo.broadcast_in_dim %2, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> loc(#loc2) - %cst_6 = stablehlo.constant dense<0x7FC00000> : tensor loc(#loc2) - %4 = stablehlo.broadcast_in_dim %cst_6, dims = [] : (tensor) -> tensor<2x4x4xf32> loc(#loc2) - %5 = stablehlo.broadcast_in_dim %3, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x4x4xi1> loc(#loc2) - %6 = stablehlo.select %5, %0#0, %4 : tensor<2x4x4xi1>, tensor<2x4x4xf32> loc(#loc2) - %7 = stablehlo.broadcast_in_dim %2, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc2) - %cst_7 = stablehlo.constant dense<0x7FC00000> : tensor loc(#loc2) - %8 = stablehlo.broadcast_in_dim %cst_7, dims = [] : (tensor) -> tensor<2x3xf32> loc(#loc2) - %9 = stablehlo.broadcast_in_dim %7, dims = [0, 1] : (tensor<2x1xi1>) -> tensor<2x3xi1> loc(#loc2) - %10 = stablehlo.select %9, %0#1, %8 : tensor<2x3xi1>, tensor<2x3xf32> loc(#loc2) - return %6, %10 : tensor<2x4x4xf32>, tensor<2x3xf32> loc(#loc) - } loc(#loc) -} loc(#loc) -#loc = loc(unknown) -#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":697:13) -#loc2 = loc("jit(func)/jit(main)/hessenberg"(#loc1)) -""", - 
mlir_module_serialized=b'ML\xefR\x01StableHLO_v0.9.0\x00\x01\x1f\x05\x01\x03\x01\x03\x05\x03\x0f\x07\t\x0b\r\x0f\x11\x13\x03\xed\xa17\x01W\x0f\x0b\x07\x0b\x13\x13\x0f\x0b\x13\x13+\x0b\x0f\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x13\x0b\x17\x0b\x13\x13\x13K\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x1b\x0b\x0b\x13\x13\x03K\x0f\x0b\x0b\x0b\x0bo/\x0b\x13\x1b\x0b\x1b\x0b\x0b\x0b&\x04\x1f\x1f\x1f\x1f\x0b\x0b\x0b\x0b\'\x0f\x17\x1bO\x1f\x0f\x0b\x0b/\x1foO\x01\x05\x0b\x0f\x033\x0f\x1b\x07\x07\x17\x07\x07\x0f\x07\x13\x17\x17\x13\x13\x13\x13\x13\x13\x1b\x13\x1b\x13\x17\x17\x13\x02\x96\t\x1d-/\x05\x15\x1f\x05\x17\x03\x03\x03w\x03\x03\x07\x93\x11\x03\x05\x05\x19\x03\x03\x07\x99\x03\x03\x03\x9b\x03\t\x17\x19\x1b\r\x1d\r\x0f\x1f\x05\x1b\x11\x01\x00\x05\x1d\x05\x1f\x05!\x03\x0b#Y%e\'g\x0fq)s\x05#\x05%\x05\'\x05)\x03\x03\x03u\x05+\x171\xe6\n\x1b\x05-\x03\x03\x03y\x03\x03\x03{\x03\x03\x03}\x03\x11;\x7f=\x81?\x83AYC\x85E\x87G\x89I\x8d\x05/\x051\x053\x055\x057\x059\x05;\x05=\x03\x03\x03\x91\x03\x05O\x95Q\x97\x05?\x05A\x03\x03\x07\x9d\x03\x03\x07\x9f\x1f\x1d\x01\x03\x01\x1dC\x1dE\x1dG\x1f\x1f1\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f#\x11\x00\x00\x00\x00\x00\x00\x00\x00#\x19\x03\x05im\r\x05[k]_\x1dI\r\x05[o]_\x1dK\x1dM\x1dO\x1f\x07\x02\x02h\x85a\xc0)\'B\xbfgt\x9b@\x8e\xc7\xc0\xc0P\xf7\x05\xc04\xc9\xa3\xbfN\xcft\xbf\xb6D\x8f\xbf\xf8G\xa0\xc0+]\x03?N\xf3a@\xc9k:\xc0:\xd5\t@\xd4\xbf"?]\xc0Y\xbe\x06E\x0e?f\x82\x89@\x8f,w\xc0G\xf9\x18\xc0\xd1\x94\x85@3E\x05@\n\x11\x85\xbf\\\xe3\x82@P:\xa7@\\7\xd6\xc0\xdb\xff[@\xdf\xeb0\xbf\x948\xcc??\x12\xa6?\x98\x19\x00\xc0\xa6\x99E@\x9e\x1d\x19@\x1f\x05\t\x04\x00\x00\x00\x1f\x05\t\x01\x00\x00\x00\x1f\x05\t\x02\x00\x00\x00\x1f\x05\t\xc0\x10\x00\x00\x0b\x05\x1dQ\x1dS\x05\x01\x03\x0fWWWWWWa\x03\x03\x8b\x15\x03\x01\x19\x01\x03\ta\x8fcc\x1f!!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x05\t\x00\x00\x00\x00\x1f%\x01\t\x07\x07\x01\x1f+\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x13\t\x00\x00\xc0\x7f\x1f/1\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x1f5!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\t\x01\x02\x02)\x01\x15)\x07\t\x11\x11\x0b\x01\t)\x05\t\r\x0b\x13\x1d)\x01\x0b\x1b)\x03\t\x15\x11\x01\x05\x07\r)\x03\x02\x86\x0b)\x03\x01\x0f)\x03\r\x0f)\x03\t\x0f)\x03\x05\x0f)\x03\x01\x11)\x03\t\t)\x07\t\x05\x05\t)\x03\x05\x11)\x07\t\x11\x11\t)\x03\r\x11)\x05\t\x05\t)\x05\t\r\t)\x03\t\x11\x04\xde\x02\x05\x01\x11\x05\x15\x07\x03\x01\x05\t\x11\x05!\x07\x031Y\x03\x03\x05+\x03\x07\x03\x03\x01\t\x03\x05\x03\x03\x013\x03\x05\x03\x03\x01\t\x03\x05\x03\x03\x01\t\x03\x05\x03\x03\x015\x03\x05\x03\x03\x017\x03\x05\x0b\x07\x019\t\x07\r\x17\x1b\x0f\x03\x05\x07\t\x0b\r\x01\x03\x03\x01K\x03\x05\x05\x07\x01\x0b\x03\x17\x03\x17\r\x07\x01M\x03\'\x05\x13\x19\x05\x07\x01\x11\x03)\x03\x1b\x03\x03\x01\x13\x03\x13\x05\x07\x01\x0b\x03\x07\x03\x1f\x05\x07\x01S\x03-\x03\x1d\x07\x06\x01\x03\x07\x07#\x0f!\x05\x07\x01\x11\x031\x03\x1b\x03\x03\x01\x13\x03\x13\x05\x07\x01\x0b\x03\r\x03)\x05\x07\x01U\x033\x03\'\x07\x06\x01\x03\r\x07-\x11+\x0f\x04\x05\x05%/\x06\x03\x01\x05\x01\x00\xf2\tU\x1d\x03\x0f\x0b\t\t\x11#!+\x1b\x1f/!!)#\x1f\x19i?\x1f\x15\x1d\x15\x13%)9\x13+\r\x15\x17\x1f\x11\x15)\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00broadcast_in_dim_v1\x00select_v1\x00func_v1\x00custom_call_v1\x00compare_v1\x00return_v1\x00value\x00broadcast_dimensions\x00sym_name\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00arg_attrs\x00function_t
ype\x00res_attrs\x00sym_visibility\x00jit(func)/jit(main)/hessenberg\x00third_party/py/jax/tests/export_back_compat_test.py\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00compare_type\x00comparison_direction\x00jax.result_info\x00mhlo.layout_mode\x00default\x00[0]\x00[1]\x00main\x00public\x00\x00lapack_sgehrd\x00', - xla_call_module_version=9, - nr_devices=1, -) # End paste - - -# Pasted from the test output (see export_back_compat_test_util.py module docstring) -data_2024_08_30["f64"] = dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['lapack_dgehrd'], - serialized_date=datetime.date(2024, 8, 30), - inputs=(), - expected_outputs=(array([[[ 0.9307390587491866 , -0.35692982324474015 , - -0.1271353200176119 , -0.43952156917870067 ], - [ 2.2633695323673964 , 0.9965090965971986 , - -1.3244131008423046 , 1.7324542351344163 ], - [ 0.24558316247256504 , 2.922776762811796 , - 3.630059093036474 , 1.4330664619737252 ], - [-0.2856727718012896 , -0.4601276537179077 , - -2.8602148466873802 , 1.9928744545245372 ]], - - [[-0.5351339571818844 , 5.753313169426148 , - 0.1385440281649789 , 2.8445493054193807 ], - [ 4.676815781213274 , 2.920688567170204 , - -2.610159425457712 , 4.0359806870679655 ], - [-0.16963242599901043 , -2.342935131066633 , - 4.179999589709703 , -0.6810604472011716 ], - [ 0.030645999613174775, -0.2271804227402005 , - -2.2755242550977153 , 0.7136684502626782 ]]]), array([[1.751436143556826 , 1.6505497938190505, 0. ], - [1.9422862513069978, 1.9018440331997255, 0. ]])), - mlir_module_text=r""" -module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main() -> (tensor<2x4x4xf64> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<2x3xf64> {jax.result_info = "[1]", mhlo.layout_mode = "default"}) { - %cst = stablehlo.constant dense<[[[0.93073905874918661, 0.18483901505653183, -0.11804347408930886, -0.53725392025434981], [-1.700777672846173, 1.3531570270421245, -2.4375034855727518, 2.2945174202226699], [-0.97352780716312858, -0.8319788592736328, 2.4986640885328582, -2.8118637941861766], [1.1324489199416958, -1.9301638714393787, 1.5523821278819048, 2.7676215285832253]], [[-0.53513395718188439, -5.2137633671981938, 2.9644475919777618, 2.2891023676266191], [-4.4068992105328642, 1.2751848926168665, -2.8947257279736456, -2.6817410994805888], [1.5408926111334784, -0.85423691880254915, 6.4217874587762065, -0.43997818045540715], [-0.27837952612324207, 1.1509460853774549, -0.21686805683301608, 0.11738425574951133]]]> : tensor<2x4x4xf64> loc(#loc) - %c = stablehlo.constant dense<4> : tensor loc(#loc2) - %c_0 = stablehlo.constant dense<1> : tensor loc(#loc2) - %c_1 = stablehlo.constant dense<4> : tensor loc(#loc2) - %c_2 = stablehlo.constant dense<4> : tensor loc(#loc2) - %c_3 = stablehlo.constant dense<2> : tensor loc(#loc2) - %c_4 = stablehlo.constant dense<4288> : tensor loc(#loc2) - %0:4 = stablehlo.custom_call @lapack_dgehrd(%c, %c_0, %c_1, %c_2, %c_3, %c_4, %cst) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[1, 2, 0]> : tensor<3xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 0]> : tensor<2xindex>, dense<0> : 
tensor<1xindex>, dense<0> : tensor<1xindex>]} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor<2x4x4xf64>) -> (tensor<2x4x4xf64>, tensor<2x3xf64>, tensor<2xi32>, tensor<4288xf64>) loc(#loc2) - %c_5 = stablehlo.constant dense<0> : tensor loc(#loc2) - %1 = stablehlo.broadcast_in_dim %c_5, dims = [] : (tensor) -> tensor<2xi32> loc(#loc2) - %2 = stablehlo.compare EQ, %0#2, %1, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc2) - %3 = stablehlo.broadcast_in_dim %2, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> loc(#loc2) - %cst_6 = stablehlo.constant dense<0x7FF8000000000000> : tensor loc(#loc2) - %4 = stablehlo.broadcast_in_dim %cst_6, dims = [] : (tensor) -> tensor<2x4x4xf64> loc(#loc2) - %5 = stablehlo.broadcast_in_dim %3, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x4x4xi1> loc(#loc2) - %6 = stablehlo.select %5, %0#0, %4 : tensor<2x4x4xi1>, tensor<2x4x4xf64> loc(#loc2) - %7 = stablehlo.broadcast_in_dim %2, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc2) - %cst_7 = stablehlo.constant dense<0x7FF8000000000000> : tensor loc(#loc2) - %8 = stablehlo.broadcast_in_dim %cst_7, dims = [] : (tensor) -> tensor<2x3xf64> loc(#loc2) - %9 = stablehlo.broadcast_in_dim %7, dims = [0, 1] : (tensor<2x1xi1>) -> tensor<2x3xi1> loc(#loc2) - %10 = stablehlo.select %9, %0#1, %8 : tensor<2x3xi1>, tensor<2x3xf64> loc(#loc2) - return %6, %10 : tensor<2x4x4xf64>, tensor<2x3xf64> loc(#loc) - } loc(#loc) -} loc(#loc) -#loc = loc(unknown) -#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":697:13) -#loc2 = loc("jit(func)/jit(main)/hessenberg"(#loc1)) -""", - mlir_module_serialized=b'ML\xefR\x01StableHLO_v0.9.0\x00\x01\x1f\x05\x01\x03\x01\x03\x05\x03\x0f\x07\t\x0b\r\x0f\x11\x13\x03\xed\xa17\x01W\x0f\x0b\x07\x0b\x13\x13\x0f\x0b\x13\x13+\x0b\x0f\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x13\x0b\x17\x0b\x13\x13\x13K\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x1b\x0b\x0b\x13\x13\x03K\x0f\x0b\x0b\x0b\x0bo/\x0b\x13\x1b\x0b\x1b\x0b\x0b\x0b&\x08\x1f\x1f\x1f\x1f\x0b\x0b\x0b\x0b\'\x0f\x17\x1bO\x1f\x0f\x0b\x0b//oO\x01\x05\x0b\x0f\x033\x0f\x1b\x07\x07\x17\x07\x07\x0f\x07\x13\x17\x17\x13\x13\x13\x13\x13\x13\x1b\x13\x1b\x13\x17\x17\x13\x02\xa6\x0b\x1d-/\x05\x15\x1f\x05\x17\x03\x03\x03w\x03\x03\x07\x93\x11\x03\x05\x05\x19\x03\x03\x07\x99\x03\x03\x03\x9b\x03\t\x17\x19\x1b\r\x1d\r\x0f\x1f\x05\x1b\x11\x01\x00\x05\x1d\x05\x1f\x05!\x03\x0b#Y%e\'g\x0fq)s\x05#\x05%\x05\'\x05)\x03\x03\x03u\x05+\x171\xe6\n\x1b\x05-\x03\x03\x03y\x03\x03\x03{\x03\x03\x03}\x03\x11;\x7f=\x81?\x83AYC\x85E\x87G\x89I\x8d\x05/\x051\x053\x055\x057\x059\x05;\x05=\x03\x03\x03\x91\x03\x05O\x95Q\x97\x05?\x05A\x03\x03\x07\x9d\x03\x03\x07\x9f\x1f\x1d\x01\x03\x01\x1dC\x1dE\x1dG\x1f\x1f1\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f#\x11\x00\x00\x00\x00\x00\x00\x00\x00#\x19\x03\x05im\r\x05[k]_\x1dI\r\x05[o]_\x1dK\x1dM\x1dO\x1f\x07\x02\x04\xa6\x00NG\x9d\xc8\xed?\xf2\xa8X\n\xce\xa8\xc7?#E\xb8\xdc\x188\xbe\xbf\xb8|$"/1\xe1\xbf\xc4B*\xa6b6\xfb\xbf\xe8\xf9\x97\xfb\x87\xa6\xf5?)^\xd3\xd3\x01\x80\x03\xc0T\xab\xff\xf2+[\x02@4d\xb0\xc9#\'\xef\xbf~e\xf1 
\x92\x9f\xea\xbf\x96\x81\xff\x98C\xfd\x03@W\xb0\xe6q\xb2~\x06\xc0F\xa48\xc2\x82\x1e\xf2?\xcc\x0b\xfc\x82\xf3\xe1\xfe\xbf\xdc\\b\xa4\x8e\xd6\xf8?\x8c\xc3\x87\xc1\x16$\x06@\x83h\xa2?\xd1\x1f\xe1\xbf\xdc\xcb\xbc\xc8\xe4\xda\x14\xc0\xe6\x00\x92L0\xb7\x07@Q8\xf1\xe6\x14P\x02@\t\x07\xc8/\xaa\xa0\x11\xc0\x8eH"F(g\xf4?\xf5Jd\xf6e(\x07\xc0\x9e\xddt\xad4t\x05\xc0\x1cv\xb7\x02\x7f\xa7\xf8?B^\xa9\xa9\xe8U\xeb\xbf\x1e:5\r\xe9\xaf\x19@\xa2\x9c\x00>\x9a(\xdc\xbf\xc1\xd1$\\\xf8\xd0\xd1\xbf}|BqFj\xf2?6\x8b\xd2\x1dU\xc2\xcb\xbfdk\x82\x03\xe5\x0c\xbe?\x1f\x05\t\x04\x00\x00\x00\x1f\x05\t\x01\x00\x00\x00\x1f\x05\t\x02\x00\x00\x00\x1f\x05\t\xc0\x10\x00\x00\x0b\x05\x1dQ\x1dS\x05\x01\x03\x0fWWWWWWa\x03\x03\x8b\x15\x03\x01\x19\x01\x03\ta\x8fcc\x1f!!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x05\t\x00\x00\x00\x00\x1f%\x01\t\x07\x07\x01\x1f+\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x13\x11\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f/1\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x1f5!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\t\x01\x02\x02)\x01\x15)\x07\t\x11\x11\x0b\x01\x0b)\x05\t\r\x0b\x13\x1d)\x01\x0b\x1b)\x03\t\x15\x11\x01\x05\x07\r)\x03\x02\x86\x0b)\x03\x01\x0f)\x03\r\x0f)\x03\t\x0f)\x03\x05\x0f)\x03\x01\x11)\x03\t\t)\x07\t\x05\x05\t)\x03\x05\x11)\x07\t\x11\x11\t)\x03\r\x11)\x05\t\x05\t)\x05\t\r\t)\x03\t\x11\x04\xde\x02\x05\x01\x11\x05\x15\x07\x03\x01\x05\t\x11\x05!\x07\x031Y\x03\x03\x05+\x03\x07\x03\x03\x01\t\x03\x05\x03\x03\x013\x03\x05\x03\x03\x01\t\x03\x05\x03\x03\x01\t\x03\x05\x03\x03\x015\x03\x05\x03\x03\x017\x03\x05\x0b\x07\x019\t\x07\r\x17\x1b\x0f\x03\x05\x07\t\x0b\r\x01\x03\x03\x01K\x03\x05\x05\x07\x01\x0b\x03\x17\x03\x17\r\x07\x01M\x03\'\x05\x13\x19\x05\x07\x01\x11\x03)\x03\x1b\x03\x03\x01\x13\x03\x13\x05\x07\x01\x0b\x03\x07\x03\x1f\x05\x07\x01S\x03-\x03\x1d\x07\x06\x01\x03\x07\x07#\x0f!\x05\x07\x01\x11\x031\x03\x1b\x03\x03\x01\x13\x03\x13\x05\x07\x01\x0b\x03\r\x03)\x05\x07\x01U\x033\x03\'\x07\x06\x01\x03\r\x07-\x11+\x0f\x04\x05\x05%/\x06\x03\x01\x05\x01\x00\xf2\tU\x1d\x03\x0f\x0b\t\t\x11#!+\x1b\x1f/!!)#\x1f\x19i?\x1f\x15\x1d\x15\x13%)9\x13+\r\x15\x17\x1f\x11\x15)\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00broadcast_in_dim_v1\x00select_v1\x00func_v1\x00custom_call_v1\x00compare_v1\x00return_v1\x00value\x00broadcast_dimensions\x00sym_name\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00jit(func)/jit(main)/hessenberg\x00third_party/py/jax/tests/export_back_compat_test.py\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00compare_type\x00comparison_direction\x00jax.result_info\x00mhlo.layout_mode\x00default\x00[0]\x00[1]\x00main\x00public\x00\x00lapack_dgehrd\x00', - xla_call_module_version=9, - nr_devices=1, -) # End paste - data_2024_08_31 = {} - # Pasted from the test output (see export_back_compat_test_util.py module docstring) data_2024_08_31["c128"] = dict( testdata_version=1, diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/cpu_lu_lapack_getrf.py b/jax/_src/internal_test_util/export_back_compat_test_data/cpu_lu_lapack_getrf.py index 72d97df53a4f..2290db62e436 100644 --- a/jax/_src/internal_test_util/export_back_compat_test_data/cpu_lu_lapack_getrf.py +++ b/jax/_src/internal_test_util/export_back_compat_test_data/cpu_lu_lapack_getrf.py @@ -17,527 +17,8 @@ 
import datetime from numpy import array, int32, float32, complex64 -data_2023_06_14 = {} - -# Pasted from the test output (see back_compat_test.py module docstring) -data_2023_06_14['f32'] = dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['lapack_sgetrf'], - serialized_date=datetime.date(2023, 6, 14), - inputs=(), - expected_outputs=(array([[6. , 7. , 8. ], - [0. , 1. , 2. ], - [0.5, 0.5, 0. ]], dtype=float32), array([2, 2, 2], dtype=int32), array([2, 0, 1], dtype=int32)), - mlir_module_text=r""" -#loc = loc(unknown) -module @jit__lambda_ attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main() -> (tensor<3x3xf32> {jax.result_info = "[0]"}, tensor<3xi32> {jax.result_info = "[1]"}, tensor<3xi32> {jax.result_info = "[2]"}) { - %0 = stablehlo.iota dim = 0 : tensor<9xf32> loc(#loc3) - %1 = stablehlo.reshape %0 : (tensor<9xf32>) -> tensor<3x3xf32> loc(#loc4) - %2 = stablehlo.constant dense<3> : tensor loc(#loc5) - %3 = stablehlo.constant dense<3> : tensor loc(#loc5) - %4 = stablehlo.convert %2 : (tensor) -> tensor loc(#loc5) - %5 = stablehlo.reshape %4 : (tensor) -> tensor<1xi32> loc(#loc5) - %6 = stablehlo.convert %3 : (tensor) -> tensor loc(#loc5) - %7 = stablehlo.reshape %6 : (tensor) -> tensor<1xi32> loc(#loc5) - %8 = stablehlo.concatenate %5, %7, dim = 0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> loc(#loc5) - %9 = stablehlo.constant dense<3> : tensor loc(#loc5) - %10 = stablehlo.convert %9 : (tensor) -> tensor loc(#loc5) - %11 = stablehlo.reshape %10 : (tensor) -> tensor<1xi32> loc(#loc5) - %12 = stablehlo.constant dense<> : tensor<0xi32> loc(#loc5) - %13 = stablehlo.constant dense<1> : tensor loc(#loc5) - %14 = stablehlo.constant dense<3> : tensor loc(#loc5) - %15 = stablehlo.constant dense<3> : tensor loc(#loc5) - %16:3 = stablehlo.custom_call @lapack_sgetrf(%13, %14, %15, %1) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>]} : (tensor, tensor, tensor, tensor<3x3xf32>) -> (tensor<3x3xf32>, tensor<3xi32>, tensor) loc(#loc5) - %17 = stablehlo.constant dense<1> : tensor loc(#loc5) - %18 = stablehlo.broadcast_in_dim %17, dims = [] : (tensor) -> tensor<3xi32> loc(#loc5) - %19 = stablehlo.subtract %16#1, %18 : tensor<3xi32> loc(#loc5) - %20 = stablehlo.constant dense<0> : tensor loc(#loc5) - %21 = stablehlo.broadcast_in_dim %20, dims = [] : (tensor) -> tensor loc(#loc5) - %22 = stablehlo.compare GE, %16#2, %21, SIGNED : (tensor, tensor) -> tensor loc(#loc5) - %23 = stablehlo.broadcast_in_dim %22, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc5) - %24 = stablehlo.constant dense<0x7FC00000> : tensor loc(#loc5) - %25 = stablehlo.broadcast_in_dim %24, dims = [] : (tensor) -> tensor<3x3xf32> loc(#loc5) - %26 = stablehlo.broadcast_in_dim %23, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<3x3xi1> loc(#loc5) - %27 = stablehlo.select %26, %16#0, %25 : tensor<3x3xi1>, tensor<3x3xf32> loc(#loc5) - %28 = stablehlo.iota dim = 0 : tensor<3xi32> loc(#loc6) - %29 = stablehlo.constant dense<0> : tensor loc(#loc7) - %30 = stablehlo.constant dense<0> : tensor loc(#loc8) - %31:4 = stablehlo.while(%iterArg = %30, %iterArg_0 = %29, %iterArg_1 = %28, %iterArg_2 = %19) : tensor, tensor, tensor<3xi32>, tensor<3xi32> - cond { - %32 = stablehlo.constant dense<3> : tensor loc(#loc9) 
- %33 = stablehlo.compare LT, %iterArg, %32, SIGNED : (tensor, tensor) -> tensor loc(#loc10) - stablehlo.return %33 : tensor loc(#loc9) - } do { - %32 = stablehlo.constant dense<1> : tensor loc(#loc9) - %33 = stablehlo.add %iterArg_0, %32 : tensor loc(#loc11) - %34 = stablehlo.constant dense<0> : tensor loc(#loc9) - %35 = stablehlo.compare LT, %iterArg_0, %34, SIGNED : (tensor, tensor) -> tensor loc(#loc12) - %36 = stablehlo.constant dense<3> : tensor loc(#loc9) - %37 = stablehlo.add %iterArg_0, %36 : tensor loc(#loc11) - %38 = stablehlo.select %35, %37, %iterArg_0 : tensor, tensor loc(#loc13) - %39 = stablehlo.convert %38 : (tensor) -> tensor loc(#loc14) - %40 = stablehlo.broadcast_in_dim %39, dims = [] : (tensor) -> tensor<1xi32> loc(#loc15) - %41 = "stablehlo.gather"(%iterArg_2, %40) {dimension_numbers = #stablehlo.gather, indices_are_sorted = true, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1xi32>) -> tensor loc(#loc16) - %42 = stablehlo.constant dense<0> : tensor loc(#loc9) - %43 = stablehlo.compare LT, %iterArg_0, %42, SIGNED : (tensor, tensor) -> tensor loc(#loc12) - %44 = stablehlo.constant dense<3> : tensor loc(#loc9) - %45 = stablehlo.add %iterArg_0, %44 : tensor loc(#loc11) - %46 = stablehlo.select %43, %45, %iterArg_0 : tensor, tensor loc(#loc13) - %47 = stablehlo.convert %46 : (tensor) -> tensor loc(#loc14) - %48 = stablehlo.broadcast_in_dim %47, dims = [] : (tensor) -> tensor<1xi32> loc(#loc15) - %49 = "stablehlo.gather"(%iterArg_1, %48) {dimension_numbers = #stablehlo.gather, indices_are_sorted = true, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1xi32>) -> tensor loc(#loc16) - %50 = stablehlo.constant dense<0> : tensor loc(#loc9) - %51 = stablehlo.compare LT, %41, %50, SIGNED : (tensor, tensor) -> tensor loc(#loc12) - %52 = stablehlo.constant dense<3> : tensor loc(#loc9) - %53 = stablehlo.add %41, %52 : tensor loc(#loc11) - %54 = stablehlo.select %51, %53, %41 : tensor, tensor loc(#loc13) - %55 = stablehlo.dynamic_slice %iterArg_1, %54, sizes = [1] : (tensor<3xi32>, tensor) -> tensor<1xi32> loc(#loc17) - %56 = stablehlo.reshape %55 : (tensor<1xi32>) -> tensor loc(#loc18) - %57 = stablehlo.constant dense<0> : tensor loc(#loc9) - %58 = stablehlo.compare LT, %iterArg_0, %57, SIGNED : (tensor, tensor) -> tensor loc(#loc12) - %59 = stablehlo.constant dense<3> : tensor loc(#loc9) - %60 = stablehlo.add %iterArg_0, %59 : tensor loc(#loc11) - %61 = stablehlo.select %58, %60, %iterArg_0 : tensor, tensor loc(#loc13) - %62 = stablehlo.convert %61 : (tensor) -> tensor loc(#loc14) - %63 = stablehlo.broadcast_in_dim %62, dims = [] : (tensor) -> tensor<1xi32> loc(#loc15) - %64 = "stablehlo.scatter"(%iterArg_1, %63, %56) ({ - ^bb0(%arg0: tensor loc(unknown), %arg1: tensor loc(unknown)): - stablehlo.return %arg1 : tensor loc(#loc19) - }) {indices_are_sorted = true, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = true} : (tensor<3xi32>, tensor<1xi32>, tensor) -> tensor<3xi32> loc(#loc19) - %65 = stablehlo.constant dense<0> : tensor loc(#loc9) - %66 = stablehlo.compare LT, %41, %65, SIGNED : (tensor, tensor) -> tensor loc(#loc12) - %67 = stablehlo.constant dense<3> : tensor loc(#loc9) - %68 = stablehlo.add %41, %67 : tensor loc(#loc11) - %69 = stablehlo.select %66, %68, %41 : tensor, tensor loc(#loc13) - %70 = stablehlo.broadcast_in_dim %69, dims = [] : (tensor) -> tensor<1xi32> loc(#loc15) - %71 = "stablehlo.scatter"(%64, %70, %49) ({ - ^bb0(%arg0: tensor loc(unknown), %arg1: tensor loc(unknown)): - stablehlo.return %arg1 : tensor 
loc(#loc19) - }) {indices_are_sorted = true, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = true} : (tensor<3xi32>, tensor<1xi32>, tensor) -> tensor<3xi32> loc(#loc19) - %72 = stablehlo.constant dense<1> : tensor loc(#loc9) - %73 = stablehlo.add %iterArg, %72 : tensor loc(#loc11) - stablehlo.return %73, %33, %71, %iterArg_2 : tensor, tensor, tensor<3xi32>, tensor<3xi32> loc(#loc9) - } loc(#loc9) - return %27, %19, %31#2 : tensor<3x3xf32>, tensor<3xi32>, tensor<3xi32> loc(#loc) - } loc(#loc) -} loc(#loc) -#loc1 = loc("third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py":550:0) -#loc2 = loc("third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py":551:0) -#loc3 = loc("jit()/jit(main)/iota[dtype=float32 shape=(9,) dimension=0]"(#loc1)) -#loc4 = loc("jit()/jit(main)/reshape[new_sizes=(3, 3) dimensions=None]"(#loc1)) -#loc5 = loc("jit()/jit(main)/lu"(#loc2)) -#loc6 = loc("jit()/jit(main)/iota[dtype=int32 shape=(3,) dimension=0]"(#loc2)) -#loc7 = loc("jit()/jit(main)/lu_pivots_to_permutation[permutation_size=3]"(#loc2)) -#loc8 = loc("jit()/jit(main)/scan[reverse=False length=3 num_consts=0 num_carry=3 linear=(False, False, False) unroll=1]"(#loc2)) -#loc9 = loc("jit()/jit(main)/while[cond_nconsts=0 body_nconsts=0]"(#loc2)) -#loc10 = loc("jit()/jit(main)/while/cond/lt"(#loc2)) -#loc11 = loc("jit()/jit(main)/while/body/add"(#loc2)) -#loc12 = loc("jit()/jit(main)/while/body/lt"(#loc2)) -#loc13 = loc("jit()/jit(main)/while/body/select_n"(#loc2)) -#loc14 = loc("jit()/jit(main)/while/body/convert_element_type[new_dtype=int32 weak_type=False]"(#loc2)) -#loc15 = loc("jit()/jit(main)/while/body/broadcast_in_dim[shape=(1,) broadcast_dimensions=()]"(#loc2)) -#loc16 = loc("jit()/jit(main)/while/body/gather[dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,)) slice_sizes=(1,) unique_indices=True indices_are_sorted=True mode=GatherScatterMode.PROMISE_IN_BOUNDS fill_value=None]"(#loc2)) -#loc17 = loc("jit()/jit(main)/while/body/dynamic_slice[slice_sizes=(1,)]"(#loc2)) -#loc18 = loc("jit()/jit(main)/while/body/squeeze[dimensions=(0,)]"(#loc2)) -#loc19 = loc("jit()/jit(main)/while/body/scatter[update_consts=() dimension_numbers=ScatterDimensionNumbers(update_window_dims=(), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,)) indices_are_sorted=True unique_indices=True mode=GatherScatterMode.FILL_OR_DROP]"(#loc2)) -""", - 
mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x013\x05\x01\x03\x01\x03\x05\x03#\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x1d\x1f!#%'\x03\xa6\x02\x0e\x023\x01\xb1\x0f\x0f\x07\x17\x0b\x13\x13\x0f\x1b\x13\x0f\x0f\x13\x0f\x13\x13\x13\x0f\x0f\x0b\x13\x17\x0b\x0b\x0b\x0b;\x0b\x0b\x0b\x0f;#\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x0b\x0f\x0b\x0f\x0b\x0b\x13\x0b\x13K\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x1b\x13\x13\x0f\x0b\x0f\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x0f\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x03Q\x0f\x0f/\x0b\x0f\x0b\x0bO\x0b/\x0b\x17\x13\x0b\x13\x0b\x13\x0b\x0b\x0b/\x0f/\x1f\x0b\x0b\x0b\x0b\x1b\x0f\x17\x17/\x1f\x1f\x0b\x1fO/\x0b\x01\x07\x0b\x13\x0b\x01\x03\x0f\x031\x0f\x0f\x13\x13\x0f\x17\x07\x07\x07\x07\x07\x13\x0f\x13\x1b\x13\x13\x13\x13\x13\x13\x17\x17\x13\x02J\t\x1d]\x07\x1d\x8b\x07\x1f\x17-\x9e\x08\x01\x05)\x03\x03/\xb9\x03\x03\t\xd9\x1d\x8d\x07\x03\x051\xc13\xff\x03\x03\t\xfd\x1d\x8f\x07\x1d\x91\x07\x03\x03\t\xdf\x1d\x95\x07\x1d\x02\x02\x07\x03\x03\t\xdd\x03\x03\t\xf5\x1d\x93\x07\x11\x01\x05\x05+\x03\x03S\xb1\x17-\x9a\x08\x01\x05-\x05/\x051\x053\x03\r\x97\xb57\xb19\xbb\x99\xb9;\xc3\x9b\xb5\x055\x057\x059\x1d\x9d\x07\x03\r7\xb19\xbb\xa9\xb5\xab\xb5\xad\xbb\xaf\xb9\x03\x07C%E%'G\x05;\x05=\x05?\x03\x0bK\xbdM\xc5O\xc7'\xd5Q\xd7\x05A\x05C\x05E\x05G\x05I\x1dW+\x05K\x1d[+\x05M\x05O\x03\x03a\xb1\x05Q\x03\x03\t\xdb\x03\x11g\xe1i\xe3k\xe5m\xbdo\xe7q\xe9s\xebu\xef\x05S\x05U\x05W\x05Y\x05[\x05]\x05_\x05a\x03\x03\t\xf3\x03\x051\xc13\xf7\x03\x03\t\xf9\x03\x03/\xfb\x1d\x81\x07\x05c\x1d\x85\x07\x05e\x1d\x89\x07\x05g\x05i\x05k\x05m\x05o\x05q\x05s\x05u\x05w\x05y\x05{\x03\x03;\xc3\x1d\xa3\x07\x05}\x1d\xa7\x07\x05\x7f\x05\x81\x05\x83\x05\x85\x05\x87\x13\x11\x01\x1f%\x01\x1f\x1d\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1d\x89\x1f+\x01\x05\x03\x03\x01\x1f'!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\t\x07\x1f\x1d\x11\x01\x00\x00\x00\x00\x00\x00\x00#\x1f\x03\x07\xc9\xcd\xd1\r\x03\xb7\xcb\x1d\x8b\r\x03\xb7\xcf\x1d\x8d\r\x03\xb7\xd3\x1d\x8f\x1d\x91\x1d\x93\x1f\x03\x11\x03\x00\x00\x00\x00\x00\x00\x00\x1f\x19\x01\x1f\x03\x11\x01\x00\x00\x00\x00\x00\x00\x00\x1f\x05\t\x03\x00\x00\x00\x0b\x05\x1d\x95\x1d\x97\x05\x01\x03\t\xb3\xb3\xb3\xbf\x03\x03\xed\x15\x03\x01\r\x01\x03\x07\xbf\xf1\xb3\x1f)\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x05\t\x01\x00\x00\x00\x1f\x05\t\x00\x00\x00\x00\x07\x05\x1f\x1b\t\x00\x00\xc0\x7f\x1f1!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f\x03\x11\x00\x00\x00\x00\x00\x00\x00\x00\x07\x0b\x05\x99\x1d\n\x02\x07\x05\x9b\x01\x02\x02)\x01\x11)\x01\x0f)\x03\r\x0f)\x03\x05\x0f)\x01\x17)\x05\r\r\x13\x1b\x1d\t\x13\x01)\x03\x01\x0f)\x01\x13)\x03\x05\x11\x11\x01\x07\r\x07\x07)\x03%\x13)\x03\t\x0f)\x03\x01\x15)\x03\t\x15)\x03\x05\x15)\x03\x01\x11)\x05\x05\x05\x17)\x05\r\r\x17)\x03\t\x11\x04f\n\x05\x01\x11\x05A\x07\x03\x01\x05\x19\x11\x05I\x05\x03K\x85\x13\x03U)\x03!\x0f\x06Y\x03\r\x03\x01\x03\x03\x01\r\x03\x03\x03\x03\x01\r\x03\x03\x0b\x06\x01\x03\x05\x03\x05\x0f\x06\x01\x03\t\x03\t\x0b\x06\x01\x03\x05\x03\x07\x0f\x06\x01\x03\t\x03\r\x1b\x07\x01_\x03#\x05\x0b\x0f\x03\x03\x01\r\x03\x03\x0b\x06\x01\x03\x05\x03\x13\x0f\x06\x01\x03\t\x03\x15\x03\x03\x01c\x03\x19\x03\x03\x01\x1f\x03\x03\x03\x03\x01\x19\x03\x05\x03\x03\x01\x19\x03\x05\x1d\x07\x01e\x07\r\x07\x05\t\x1b\x1d\x1f\x03\x03\x03\x01w\x03\x05\x05\x07\x01\x0b\x03\x07\x03'\x1f\x06\x01\x03\x07\x05#)\x03\x03\x01!\x03\x05\x05\x07\x01\x0b\x03\x05\x03-\x07\x07\x01y\x03\x0b\x05%/\x05\x07\x01\x0b\x03-\x031\x03\x03\x01{\x03\x1b\x05\x07\x01\x0b\x03\r\x035\x05\x07\x01}\x03/\x033\r\x06\x01\x03\r\x079!7\x13\x03\x7f)
\x03\x07\x03\x03\x83\x13\x03\x03\x03\x03\x87\x13\x03\x03!\x16\x03\t\x03\x03\x07\x07\tA?=+\t\x03\r\x0f\t\x03\x05\x03\x05\x07\x05\x07\x05\x03\x03\x03\r\x03\x03\x07\x07\x06\x02\x11\x03\x0b\x05KS\x11\x04\x03\x03U\x03]\xaf\t\x03\x05\x03\x05\x07\x05\x07\x05\x03\x03\x03\x1f\x03\x03\t\x06\x0f\x03\x03\x05MS\x03\x03\x03\x13\x03\x03\x07\x07\x15\x11\x03\x0b\x05MW\x03\x03\x03\r\x03\x03\t\x06\x0f\x03\x03\x05M[\r\x06\x17\x03\x03\x07Y]M\x0b\x06#\x03\x05\x03_\x05\x07\x1b\x0b\x03\t\x03a\x15\x07=5\x03\x05\x05Qc\x03\x03\x03\x13\x03\x03\x07\x07\x15\x11\x03\x0b\x05Mg\x03\x03\x03\r\x03\x03\t\x06\x0f\x03\x03\x05Mk\r\x06\x17\x03\x03\x07imM\x0b\x06#\x03\x05\x03o\x05\x07\x1b\x0b\x03\t\x03q\x15\x07=5\x03\x05\x05Os\x03\x03\x03!\x03\x05\x07\x07\x15\x11\x03\x0b\x05ew\x03\x03\x03\x19\x03\x05\t\x06\x0f\x03\x05\x05e{\r\x06\x17\x03\x05\x07y}e#\x07\xa1\x9f\x03\t\x05O\x7f\x0f\x06\xa5\x03\x05\x03\x81\x03\x03\x03\x13\x03\x03\x07\x07\x15\x11\x03\x0b\x05M\x85\x03\x03\x03\r\x03\x03\t\x06\x0f\x03\x03\x05M\x89\r\x06\x17\x03\x03\x07\x87\x8bM\x0b\x06#\x03\x05\x03\x8d\x05\x07\x1b\x0b\x03\t\x03\x8f\x17\x17\x1d?\x03\x07\x07O\x91\x83\x05\x03\x05\x07\x05\x05\x05\x05\x05\x11\x04\x1d\x03\xa9\x03\x03\x03!\x03\x05\x07\x07\x15\x11\x03\x0b\x05e\x95\x03\x03\x03\x19\x03\x05\t\x06\x0f\x03\x05\x05e\x99\r\x06\x17\x03\x05\x07\x97\x9be\x05\x07\x1b\x0b\x03\t\x03\x9d\x17\x17\x1d?\x03\x07\x07\x93\x9fu\x05\x03\x05\x07\x05\x05\x05\x05\x05\x11\x04\x1d\x03\xa9\x03\x03\x03\x1f\x03\x03\t\x06\x0f\x03\x03\x05K\xa3\x11\x04\x03\t\xa5U\xa1Q\x11\x04\x05\x07;+G\x06\x03\x01\x05\x01\x00v%\x9dM2\x04\x1d\x03\x0f\x0b\t\t\t!'\x1f;+y\x87.\x04!\x19+\xb1\xb3YMO{\xe9\x8b\x83\x1f/!!)#\x1f\x19\x157\x85\x87\x1f\x1f\x15\x1d\x15\x1b%)\x19'#+\x1b+\x83\x13\r#\x13\x19\x1f\x1f\x11\x17\x15\x11\x15\x17\x15\x17\x0f\x17)\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00broadcast_in_dim_v1\x00compare_v1\x00add_v1\x00convert_v1\x00select_v1\x00reshape_v1\x00return_v1\x00iota_v1\x00gather_v1\x00scatter_v1\x00func_v1\x00concatenate_v1\x00custom_call_v1\x00subtract_v1\x00while_v1\x00dynamic_slice_v1\x00value\x00sym_name\x00third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py\x00broadcast_dimensions\x00compare_type\x00comparison_direction\x00index_vector_dim\x00indices_are_sorted\x00slice_sizes\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit__lambda_\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00jit()/jit(main)/iota[dtype=float32 shape=(9,) dimension=0]\x00jit()/jit(main)/reshape[new_sizes=(3, 3) dimensions=None]\x00jit()/jit(main)/lu\x00dimension\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00jit()/jit(main)/iota[dtype=int32 shape=(3,) dimension=0]\x00jit()/jit(main)/lu_pivots_to_permutation[permutation_size=3]\x00jit()/jit(main)/scan[reverse=False length=3 num_consts=0 num_carry=3 linear=(False, False, False) unroll=1]\x00jit()/jit(main)/while[cond_nconsts=0 body_nconsts=0]\x00jit()/jit(main)/while/body/add\x00jit()/jit(main)/while/body/lt\x00jit()/jit(main)/while/body/select_n\x00jit()/jit(main)/while/body/convert_element_type[new_dtype=int32 weak_type=False]\x00jit()/jit(main)/while/body/broadcast_in_dim[shape=(1,) broadcast_dimensions=()]\x00collapsed_slice_dims\x00offset_dims\x00start_index_map\x00jit()/jit(main)/while/body/gather[dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,)) slice_sizes=(1,) unique_indices=True indices_are_sorted=True 
mode=GatherScatterMode.PROMISE_IN_BOUNDS fill_value=None]\x00jit()/jit(main)/while/body/dynamic_slice[slice_sizes=(1,)]\x00jit()/jit(main)/while/body/squeeze[dimensions=(0,)]\x00inserted_window_dims\x00scatter_dims_to_operand_dims\x00unique_indices\x00update_window_dims\x00jax.result_info\x00[0]\x00[1]\x00[2]\x00main\x00public\x00\x00lapack_sgetrf\x00jit()/jit(main)/while/body/scatter[update_consts=() dimension_numbers=ScatterDimensionNumbers(update_window_dims=(), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,)) indices_are_sorted=True unique_indices=True mode=GatherScatterMode.FILL_OR_DROP]\x00jit()/jit(main)/while/cond/lt\x00", - xla_call_module_version=6, -) # End paste - - -# Pasted from the test output (see back_compat_test.py module docstring) -data_2023_06_14['f64'] = dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['lapack_dgetrf'], - serialized_date=datetime.date(2023, 6, 14), - inputs=(), - expected_outputs=(array([[6. , 7. , 8. ], - [0. , 1. , 2. ], - [0.5, 0.5, 0. ]]), array([2, 2, 2], dtype=int32), array([2, 0, 1], dtype=int32)), - mlir_module_text=r""" -#loc = loc(unknown) -module @jit__lambda_ attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main() -> (tensor<3x3xf64> {jax.result_info = "[0]"}, tensor<3xi32> {jax.result_info = "[1]"}, tensor<3xi32> {jax.result_info = "[2]"}) { - %0 = stablehlo.iota dim = 0 : tensor<9xf64> loc(#loc3) - %1 = stablehlo.reshape %0 : (tensor<9xf64>) -> tensor<3x3xf64> loc(#loc4) - %2 = stablehlo.constant dense<3> : tensor loc(#loc5) - %3 = stablehlo.constant dense<3> : tensor loc(#loc5) - %4 = stablehlo.convert %2 : (tensor) -> tensor loc(#loc5) - %5 = stablehlo.reshape %4 : (tensor) -> tensor<1xi32> loc(#loc5) - %6 = stablehlo.convert %3 : (tensor) -> tensor loc(#loc5) - %7 = stablehlo.reshape %6 : (tensor) -> tensor<1xi32> loc(#loc5) - %8 = stablehlo.concatenate %5, %7, dim = 0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> loc(#loc5) - %9 = stablehlo.constant dense<3> : tensor loc(#loc5) - %10 = stablehlo.convert %9 : (tensor) -> tensor loc(#loc5) - %11 = stablehlo.reshape %10 : (tensor) -> tensor<1xi32> loc(#loc5) - %12 = stablehlo.constant dense<> : tensor<0xi32> loc(#loc5) - %13 = stablehlo.constant dense<1> : tensor loc(#loc5) - %14 = stablehlo.constant dense<3> : tensor loc(#loc5) - %15 = stablehlo.constant dense<3> : tensor loc(#loc5) - %16:3 = stablehlo.custom_call @lapack_dgetrf(%13, %14, %15, %1) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>]} : (tensor, tensor, tensor, tensor<3x3xf64>) -> (tensor<3x3xf64>, tensor<3xi32>, tensor) loc(#loc5) - %17 = stablehlo.constant dense<1> : tensor loc(#loc5) - %18 = stablehlo.broadcast_in_dim %17, dims = [] : (tensor) -> tensor<3xi32> loc(#loc5) - %19 = stablehlo.subtract %16#1, %18 : tensor<3xi32> loc(#loc5) - %20 = stablehlo.constant dense<0> : tensor loc(#loc5) - %21 = stablehlo.broadcast_in_dim %20, dims = [] : (tensor) -> tensor loc(#loc5) - %22 = stablehlo.compare GE, %16#2, %21, SIGNED : (tensor, tensor) -> tensor loc(#loc5) - %23 = stablehlo.broadcast_in_dim %22, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc5) - %24 = stablehlo.constant dense<0x7FF8000000000000> : tensor loc(#loc5) - %25 = stablehlo.broadcast_in_dim %24, dims = [] : 
(tensor) -> tensor<3x3xf64> loc(#loc5) - %26 = stablehlo.broadcast_in_dim %23, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<3x3xi1> loc(#loc5) - %27 = stablehlo.select %26, %16#0, %25 : tensor<3x3xi1>, tensor<3x3xf64> loc(#loc5) - %28 = stablehlo.iota dim = 0 : tensor<3xi32> loc(#loc6) - %29 = stablehlo.constant dense<0> : tensor loc(#loc7) - %30 = stablehlo.constant dense<0> : tensor loc(#loc8) - %31:4 = stablehlo.while(%iterArg = %30, %iterArg_0 = %29, %iterArg_1 = %28, %iterArg_2 = %19) : tensor, tensor, tensor<3xi32>, tensor<3xi32> - cond { - %32 = stablehlo.constant dense<3> : tensor loc(#loc9) - %33 = stablehlo.compare LT, %iterArg, %32, SIGNED : (tensor, tensor) -> tensor loc(#loc10) - stablehlo.return %33 : tensor loc(#loc9) - } do { - %32 = stablehlo.constant dense<1> : tensor loc(#loc9) - %33 = stablehlo.add %iterArg_0, %32 : tensor loc(#loc11) - %34 = stablehlo.constant dense<0> : tensor loc(#loc9) - %35 = stablehlo.compare LT, %iterArg_0, %34, SIGNED : (tensor, tensor) -> tensor loc(#loc12) - %36 = stablehlo.constant dense<3> : tensor loc(#loc9) - %37 = stablehlo.add %iterArg_0, %36 : tensor loc(#loc11) - %38 = stablehlo.select %35, %37, %iterArg_0 : tensor, tensor loc(#loc13) - %39 = stablehlo.convert %38 : (tensor) -> tensor loc(#loc14) - %40 = stablehlo.broadcast_in_dim %39, dims = [] : (tensor) -> tensor<1xi32> loc(#loc15) - %41 = "stablehlo.gather"(%iterArg_2, %40) {dimension_numbers = #stablehlo.gather, indices_are_sorted = true, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1xi32>) -> tensor loc(#loc16) - %42 = stablehlo.constant dense<0> : tensor loc(#loc9) - %43 = stablehlo.compare LT, %iterArg_0, %42, SIGNED : (tensor, tensor) -> tensor loc(#loc12) - %44 = stablehlo.constant dense<3> : tensor loc(#loc9) - %45 = stablehlo.add %iterArg_0, %44 : tensor loc(#loc11) - %46 = stablehlo.select %43, %45, %iterArg_0 : tensor, tensor loc(#loc13) - %47 = stablehlo.convert %46 : (tensor) -> tensor loc(#loc14) - %48 = stablehlo.broadcast_in_dim %47, dims = [] : (tensor) -> tensor<1xi32> loc(#loc15) - %49 = "stablehlo.gather"(%iterArg_1, %48) {dimension_numbers = #stablehlo.gather, indices_are_sorted = true, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1xi32>) -> tensor loc(#loc16) - %50 = stablehlo.constant dense<0> : tensor loc(#loc9) - %51 = stablehlo.compare LT, %41, %50, SIGNED : (tensor, tensor) -> tensor loc(#loc12) - %52 = stablehlo.constant dense<3> : tensor loc(#loc9) - %53 = stablehlo.add %41, %52 : tensor loc(#loc11) - %54 = stablehlo.select %51, %53, %41 : tensor, tensor loc(#loc13) - %55 = stablehlo.dynamic_slice %iterArg_1, %54, sizes = [1] : (tensor<3xi32>, tensor) -> tensor<1xi32> loc(#loc17) - %56 = stablehlo.reshape %55 : (tensor<1xi32>) -> tensor loc(#loc18) - %57 = stablehlo.constant dense<0> : tensor loc(#loc9) - %58 = stablehlo.compare LT, %iterArg_0, %57, SIGNED : (tensor, tensor) -> tensor loc(#loc12) - %59 = stablehlo.constant dense<3> : tensor loc(#loc9) - %60 = stablehlo.add %iterArg_0, %59 : tensor loc(#loc11) - %61 = stablehlo.select %58, %60, %iterArg_0 : tensor, tensor loc(#loc13) - %62 = stablehlo.convert %61 : (tensor) -> tensor loc(#loc14) - %63 = stablehlo.broadcast_in_dim %62, dims = [] : (tensor) -> tensor<1xi32> loc(#loc15) - %64 = "stablehlo.scatter"(%iterArg_1, %63, %56) ({ - ^bb0(%arg0: tensor loc(unknown), %arg1: tensor loc(unknown)): - stablehlo.return %arg1 : tensor loc(#loc19) - }) {indices_are_sorted = true, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = true} : (tensor<3xi32>, 
tensor<1xi32>, tensor) -> tensor<3xi32> loc(#loc19) - %65 = stablehlo.constant dense<0> : tensor loc(#loc9) - %66 = stablehlo.compare LT, %41, %65, SIGNED : (tensor, tensor) -> tensor loc(#loc12) - %67 = stablehlo.constant dense<3> : tensor loc(#loc9) - %68 = stablehlo.add %41, %67 : tensor loc(#loc11) - %69 = stablehlo.select %66, %68, %41 : tensor, tensor loc(#loc13) - %70 = stablehlo.broadcast_in_dim %69, dims = [] : (tensor) -> tensor<1xi32> loc(#loc15) - %71 = "stablehlo.scatter"(%64, %70, %49) ({ - ^bb0(%arg0: tensor loc(unknown), %arg1: tensor loc(unknown)): - stablehlo.return %arg1 : tensor loc(#loc19) - }) {indices_are_sorted = true, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = true} : (tensor<3xi32>, tensor<1xi32>, tensor) -> tensor<3xi32> loc(#loc19) - %72 = stablehlo.constant dense<1> : tensor loc(#loc9) - %73 = stablehlo.add %iterArg, %72 : tensor loc(#loc11) - stablehlo.return %73, %33, %71, %iterArg_2 : tensor, tensor, tensor<3xi32>, tensor<3xi32> loc(#loc9) - } loc(#loc9) - return %27, %19, %31#2 : tensor<3x3xf64>, tensor<3xi32>, tensor<3xi32> loc(#loc) - } loc(#loc) -} loc(#loc) -#loc1 = loc("third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py":553:0) -#loc2 = loc("third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py":554:0) -#loc3 = loc("jit()/jit(main)/iota[dtype=float64 shape=(9,) dimension=0]"(#loc1)) -#loc4 = loc("jit()/jit(main)/reshape[new_sizes=(3, 3) dimensions=None]"(#loc1)) -#loc5 = loc("jit()/jit(main)/lu"(#loc2)) -#loc6 = loc("jit()/jit(main)/iota[dtype=int32 shape=(3,) dimension=0]"(#loc2)) -#loc7 = loc("jit()/jit(main)/lu_pivots_to_permutation[permutation_size=3]"(#loc2)) -#loc8 = loc("jit()/jit(main)/scan[reverse=False length=3 num_consts=0 num_carry=3 linear=(False, False, False) unroll=1]"(#loc2)) -#loc9 = loc("jit()/jit(main)/while[cond_nconsts=0 body_nconsts=0]"(#loc2)) -#loc10 = loc("jit()/jit(main)/while/cond/lt"(#loc2)) -#loc11 = loc("jit()/jit(main)/while/body/add"(#loc2)) -#loc12 = loc("jit()/jit(main)/while/body/lt"(#loc2)) -#loc13 = loc("jit()/jit(main)/while/body/select_n"(#loc2)) -#loc14 = loc("jit()/jit(main)/while/body/convert_element_type[new_dtype=int32 weak_type=False]"(#loc2)) -#loc15 = loc("jit()/jit(main)/while/body/broadcast_in_dim[shape=(1,) broadcast_dimensions=()]"(#loc2)) -#loc16 = loc("jit()/jit(main)/while/body/gather[dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,)) slice_sizes=(1,) unique_indices=True indices_are_sorted=True mode=GatherScatterMode.PROMISE_IN_BOUNDS fill_value=None]"(#loc2)) -#loc17 = loc("jit()/jit(main)/while/body/dynamic_slice[slice_sizes=(1,)]"(#loc2)) -#loc18 = loc("jit()/jit(main)/while/body/squeeze[dimensions=(0,)]"(#loc2)) -#loc19 = loc("jit()/jit(main)/while/body/scatter[update_consts=() dimension_numbers=ScatterDimensionNumbers(update_window_dims=(), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,)) indices_are_sorted=True unique_indices=True mode=GatherScatterMode.FILL_OR_DROP]"(#loc2)) -""", - 
mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x013\x05\x01\x03\x01\x03\x05\x03#\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x1d\x1f!#%'\x03\xa6\x02\x0e\x023\x01\xb1\x0f\x0f\x07\x17\x0b\x13\x13\x0f\x1b\x13\x0f\x0f\x13\x0f\x13\x13\x13\x0f\x0f\x0b\x13\x17\x0b\x0b\x0b\x0b;\x0b\x0b\x0b\x0f;#\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x0b\x0f\x0b\x0f\x0b\x0b\x13\x0b\x13K\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x1b\x13\x13\x0f\x0b\x0f\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x0f\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x03Q\x0f\x0f/\x0b\x0f\x0b\x0bO\x0b/\x0b\x17\x13\x0b\x13\x0b\x13\x0b\x0b\x0b/\x0f/\x1f\x0b\x0b\x0b\x0b\x1b\x0f\x17\x17/\x1f\x1f\x0b/O/\x0b\x01\x07\x0b\x13\x0b\x01\x03\x0f\x031\x0f\x0f\x13\x13\x0f\x17\x07\x07\x07\x07\x07\x13\x0f\x13\x1b\x13\x13\x13\x13\x13\x13\x17\x17\x13\x02Z\t\x1d]\x07\x1d\x8b\x07\x1f\x17-\xaa\x08\x01\x05)\x03\x03/\xb9\x03\x03\t\xd9\x1d\x8d\x07\x03\x051\xc13\xff\x03\x03\t\xfd\x1d\x8f\x07\x1d\x91\x07\x03\x03\t\xdf\x1d\x95\x07\x1d\x02\x02\x07\x03\x03\t\xdd\x03\x03\t\xf5\x1d\x93\x07\x11\x01\x05\x05+\x03\x03S\xb1\x17-\xa6\x08\x01\x05-\x05/\x051\x053\x03\r\x97\xb57\xb19\xbb\x99\xb9;\xc3\x9b\xb5\x055\x057\x059\x1d\x9d\x07\x03\r7\xb19\xbb\xa9\xb5\xab\xb5\xad\xbb\xaf\xb9\x03\x07C%E%'G\x05;\x05=\x05?\x03\x0bK\xbdM\xc5O\xc7'\xd5Q\xd7\x05A\x05C\x05E\x05G\x05I\x1dW+\x05K\x1d[+\x05M\x05O\x03\x03a\xb1\x05Q\x03\x03\t\xdb\x03\x11g\xe1i\xe3k\xe5m\xbdo\xe7q\xe9s\xebu\xef\x05S\x05U\x05W\x05Y\x05[\x05]\x05_\x05a\x03\x03\t\xf3\x03\x051\xc13\xf7\x03\x03\t\xf9\x03\x03/\xfb\x1d\x81\x07\x05c\x1d\x85\x07\x05e\x1d\x89\x07\x05g\x05i\x05k\x05m\x05o\x05q\x05s\x05u\x05w\x05y\x05{\x03\x03;\xc3\x1d\xa3\x07\x05}\x1d\xa7\x07\x05\x7f\x05\x81\x05\x83\x05\x85\x05\x87\x13\x11\x01\x1f%\x01\x1f\x1d\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1d\x89\x1f+\x01\x05\x03\x03\x01\x1f'!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\t\x07\x1f\x1d\x11\x01\x00\x00\x00\x00\x00\x00\x00#\x1f\x03\x07\xc9\xcd\xd1\r\x03\xb7\xcb\x1d\x8b\r\x03\xb7\xcf\x1d\x8d\r\x03\xb7\xd3\x1d\x8f\x1d\x91\x1d\x93\x1f\x03\x11\x03\x00\x00\x00\x00\x00\x00\x00\x1f\x19\x01\x1f\x03\x11\x01\x00\x00\x00\x00\x00\x00\x00\x1f\x05\t\x03\x00\x00\x00\x0b\x05\x1d\x95\x1d\x97\x05\x01\x03\t\xb3\xb3\xb3\xbf\x03\x03\xed\x15\x03\x01\r\x01\x03\x07\xbf\xf1\xb3\x1f)\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x05\t\x01\x00\x00\x00\x1f\x05\t\x00\x00\x00\x00\x07\x05\x1f\x1b\x11\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f1!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f\x03\x11\x00\x00\x00\x00\x00\x00\x00\x00\x07\x0b\x05\x99\x1d\n\x02\x07\x05\x9b\x01\x02\x02)\x01\x11)\x01\x0f)\x03\r\x0f)\x03\x05\x0f)\x01\x17)\x05\r\r\x13\x1b\x1d\x0b\x13\x01)\x03\x01\x0f)\x01\x13)\x03\x05\x11\x11\x01\x07\r\x07\x07)\x03%\x13)\x03\t\x0f)\x03\x01\x15)\x03\t\x15)\x03\x05\x15)\x03\x01\x11)\x05\x05\x05\x17)\x05\r\r\x17)\x03\t\x11\x04f\n\x05\x01\x11\x05A\x07\x03\x01\x05\x19\x11\x05I\x05\x03K\x85\x13\x03U)\x03!\x0f\x06Y\x03\r\x03\x01\x03\x03\x01\r\x03\x03\x03\x03\x01\r\x03\x03\x0b\x06\x01\x03\x05\x03\x05\x0f\x06\x01\x03\t\x03\t\x0b\x06\x01\x03\x05\x03\x07\x0f\x06\x01\x03\t\x03\r\x1b\x07\x01_\x03#\x05\x0b\x0f\x03\x03\x01\r\x03\x03\x0b\x06\x01\x03\x05\x03\x13\x0f\x06\x01\x03\t\x03\x15\x03\x03\x01c\x03\x19\x03\x03\x01\x1f\x03\x03\x03\x03\x01\x19\x03\x05\x03\x03\x01\x19\x03\x05\x1d\x07\x01e\x07\r\x07\x05\t\x1b\x1d\x1f\x03\x03\x03\x01w\x03\x05\x05\x07\x01\x0b\x03\x07\x03'\x1f\x06\x01\x03\x07\x05#)\x03\x03\x01!\x03\x05\x05\x07\x01\x0b\x03\x05\x03-\x07\x07\x01y\x03\x0b\x05%/\x05\x07\x01\x0b\x03-\x031\x03\x03\x01{\x03\x1b\x05\x07\x01\x0b\x03\r\x035\x05\x07\x01}\x03/\x033\r\x06\x01\x03\r\x0
79!7\x13\x03\x7f)\x03\x07\x03\x03\x83\x13\x03\x03\x03\x03\x87\x13\x03\x03!\x16\x03\t\x03\x03\x07\x07\tA?=+\t\x03\r\x0f\t\x03\x05\x03\x05\x07\x05\x07\x05\x03\x03\x03\r\x03\x03\x07\x07\x06\x02\x11\x03\x0b\x05KS\x11\x04\x03\x03U\x03]\xaf\t\x03\x05\x03\x05\x07\x05\x07\x05\x03\x03\x03\x1f\x03\x03\t\x06\x0f\x03\x03\x05MS\x03\x03\x03\x13\x03\x03\x07\x07\x15\x11\x03\x0b\x05MW\x03\x03\x03\r\x03\x03\t\x06\x0f\x03\x03\x05M[\r\x06\x17\x03\x03\x07Y]M\x0b\x06#\x03\x05\x03_\x05\x07\x1b\x0b\x03\t\x03a\x15\x07=5\x03\x05\x05Qc\x03\x03\x03\x13\x03\x03\x07\x07\x15\x11\x03\x0b\x05Mg\x03\x03\x03\r\x03\x03\t\x06\x0f\x03\x03\x05Mk\r\x06\x17\x03\x03\x07imM\x0b\x06#\x03\x05\x03o\x05\x07\x1b\x0b\x03\t\x03q\x15\x07=5\x03\x05\x05Os\x03\x03\x03!\x03\x05\x07\x07\x15\x11\x03\x0b\x05ew\x03\x03\x03\x19\x03\x05\t\x06\x0f\x03\x05\x05e{\r\x06\x17\x03\x05\x07y}e#\x07\xa1\x9f\x03\t\x05O\x7f\x0f\x06\xa5\x03\x05\x03\x81\x03\x03\x03\x13\x03\x03\x07\x07\x15\x11\x03\x0b\x05M\x85\x03\x03\x03\r\x03\x03\t\x06\x0f\x03\x03\x05M\x89\r\x06\x17\x03\x03\x07\x87\x8bM\x0b\x06#\x03\x05\x03\x8d\x05\x07\x1b\x0b\x03\t\x03\x8f\x17\x17\x1d?\x03\x07\x07O\x91\x83\x05\x03\x05\x07\x05\x05\x05\x05\x05\x11\x04\x1d\x03\xa9\x03\x03\x03!\x03\x05\x07\x07\x15\x11\x03\x0b\x05e\x95\x03\x03\x03\x19\x03\x05\t\x06\x0f\x03\x05\x05e\x99\r\x06\x17\x03\x05\x07\x97\x9be\x05\x07\x1b\x0b\x03\t\x03\x9d\x17\x17\x1d?\x03\x07\x07\x93\x9fu\x05\x03\x05\x07\x05\x05\x05\x05\x05\x11\x04\x1d\x03\xa9\x03\x03\x03\x1f\x03\x03\t\x06\x0f\x03\x03\x05K\xa3\x11\x04\x03\t\xa5U\xa1Q\x11\x04\x05\x07;+G\x06\x03\x01\x05\x01\x00v%\x9dM2\x04\x1d\x03\x0f\x0b\t\t\t!'\x1f;+y\x87.\x04!\x19+\xb1\xb3YMO{\xe9\x8b\x83\x1f/!!)#\x1f\x19\x157\x85\x87\x1f\x1f\x15\x1d\x15\x1b%)\x19'#+\x1b+\x83\x13\r#\x13\x19\x1f\x1f\x11\x17\x15\x11\x15\x17\x15\x17\x0f\x17)\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00broadcast_in_dim_v1\x00compare_v1\x00add_v1\x00convert_v1\x00select_v1\x00reshape_v1\x00return_v1\x00iota_v1\x00gather_v1\x00scatter_v1\x00func_v1\x00concatenate_v1\x00custom_call_v1\x00subtract_v1\x00while_v1\x00dynamic_slice_v1\x00value\x00sym_name\x00third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py\x00broadcast_dimensions\x00compare_type\x00comparison_direction\x00index_vector_dim\x00indices_are_sorted\x00slice_sizes\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit__lambda_\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00jit()/jit(main)/iota[dtype=float64 shape=(9,) dimension=0]\x00jit()/jit(main)/reshape[new_sizes=(3, 3) dimensions=None]\x00jit()/jit(main)/lu\x00dimension\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00jit()/jit(main)/iota[dtype=int32 shape=(3,) dimension=0]\x00jit()/jit(main)/lu_pivots_to_permutation[permutation_size=3]\x00jit()/jit(main)/scan[reverse=False length=3 num_consts=0 num_carry=3 linear=(False, False, False) unroll=1]\x00jit()/jit(main)/while[cond_nconsts=0 body_nconsts=0]\x00jit()/jit(main)/while/body/add\x00jit()/jit(main)/while/body/lt\x00jit()/jit(main)/while/body/select_n\x00jit()/jit(main)/while/body/convert_element_type[new_dtype=int32 weak_type=False]\x00jit()/jit(main)/while/body/broadcast_in_dim[shape=(1,) broadcast_dimensions=()]\x00collapsed_slice_dims\x00offset_dims\x00start_index_map\x00jit()/jit(main)/while/body/gather[dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,)) slice_sizes=(1,) unique_indices=True indices_are_sorted=True 
mode=GatherScatterMode.PROMISE_IN_BOUNDS fill_value=None]\x00jit()/jit(main)/while/body/dynamic_slice[slice_sizes=(1,)]\x00jit()/jit(main)/while/body/squeeze[dimensions=(0,)]\x00inserted_window_dims\x00scatter_dims_to_operand_dims\x00unique_indices\x00update_window_dims\x00jax.result_info\x00[0]\x00[1]\x00[2]\x00main\x00public\x00\x00lapack_dgetrf\x00jit()/jit(main)/while/body/scatter[update_consts=() dimension_numbers=ScatterDimensionNumbers(update_window_dims=(), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,)) indices_are_sorted=True unique_indices=True mode=GatherScatterMode.FILL_OR_DROP]\x00jit()/jit(main)/while/cond/lt\x00", - xla_call_module_version=6, -) # End paste - - - -# Pasted from the test output (see back_compat_test.py module docstring) -data_2023_06_14['c64'] = dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['lapack_cgetrf'], - serialized_date=datetime.date(2023, 6, 14), - inputs=(), - expected_outputs=(array([[6. +0.j, 7. +0.j, 8. +0.j], - [0. +0.j, 1. +0.j, 2. +0.j], - [0.5+0.j, 0.5+0.j, 0. +0.j]], dtype=complex64), array([2, 2, 2], dtype=int32), array([2, 0, 1], dtype=int32)), - mlir_module_text=r""" -#loc = loc(unknown) -module @jit__lambda_ attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main() -> (tensor<3x3xcomplex> {jax.result_info = "[0]"}, tensor<3xi32> {jax.result_info = "[1]"}, tensor<3xi32> {jax.result_info = "[2]"}) { - %0 = stablehlo.iota dim = 0 : tensor<9xcomplex> loc(#loc3) - %1 = stablehlo.reshape %0 : (tensor<9xcomplex>) -> tensor<3x3xcomplex> loc(#loc4) - %2 = stablehlo.constant dense<3> : tensor loc(#loc5) - %3 = stablehlo.constant dense<3> : tensor loc(#loc5) - %4 = stablehlo.convert %2 : (tensor) -> tensor loc(#loc5) - %5 = stablehlo.reshape %4 : (tensor) -> tensor<1xi32> loc(#loc5) - %6 = stablehlo.convert %3 : (tensor) -> tensor loc(#loc5) - %7 = stablehlo.reshape %6 : (tensor) -> tensor<1xi32> loc(#loc5) - %8 = stablehlo.concatenate %5, %7, dim = 0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> loc(#loc5) - %9 = stablehlo.constant dense<3> : tensor loc(#loc5) - %10 = stablehlo.convert %9 : (tensor) -> tensor loc(#loc5) - %11 = stablehlo.reshape %10 : (tensor) -> tensor<1xi32> loc(#loc5) - %12 = stablehlo.constant dense<> : tensor<0xi32> loc(#loc5) - %13 = stablehlo.constant dense<1> : tensor loc(#loc5) - %14 = stablehlo.constant dense<3> : tensor loc(#loc5) - %15 = stablehlo.constant dense<3> : tensor loc(#loc5) - %16:3 = stablehlo.custom_call @lapack_cgetrf(%13, %14, %15, %1) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>]} : (tensor, tensor, tensor, tensor<3x3xcomplex>) -> (tensor<3x3xcomplex>, tensor<3xi32>, tensor) loc(#loc5) - %17 = stablehlo.constant dense<1> : tensor loc(#loc5) - %18 = stablehlo.broadcast_in_dim %17, dims = [] : (tensor) -> tensor<3xi32> loc(#loc5) - %19 = stablehlo.subtract %16#1, %18 : tensor<3xi32> loc(#loc5) - %20 = stablehlo.constant dense<0> : tensor loc(#loc5) - %21 = stablehlo.broadcast_in_dim %20, dims = [] : (tensor) -> tensor loc(#loc5) - %22 = stablehlo.compare GE, %16#2, %21, SIGNED : (tensor, tensor) -> tensor loc(#loc5) - %23 = stablehlo.broadcast_in_dim %22, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc5) - %24 = stablehlo.constant 
dense<(0x7FC00000,0x7FC00000)> : tensor> loc(#loc5) - %25 = stablehlo.broadcast_in_dim %24, dims = [] : (tensor>) -> tensor<3x3xcomplex> loc(#loc5) - %26 = stablehlo.broadcast_in_dim %23, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<3x3xi1> loc(#loc5) - %27 = stablehlo.select %26, %16#0, %25 : tensor<3x3xi1>, tensor<3x3xcomplex> loc(#loc5) - %28 = stablehlo.iota dim = 0 : tensor<3xi32> loc(#loc6) - %29 = stablehlo.constant dense<0> : tensor loc(#loc7) - %30 = stablehlo.constant dense<0> : tensor loc(#loc8) - %31:4 = stablehlo.while(%iterArg = %30, %iterArg_0 = %29, %iterArg_1 = %28, %iterArg_2 = %19) : tensor, tensor, tensor<3xi32>, tensor<3xi32> - cond { - %32 = stablehlo.constant dense<3> : tensor loc(#loc9) - %33 = stablehlo.compare LT, %iterArg, %32, SIGNED : (tensor, tensor) -> tensor loc(#loc10) - stablehlo.return %33 : tensor loc(#loc9) - } do { - %32 = stablehlo.constant dense<1> : tensor loc(#loc9) - %33 = stablehlo.add %iterArg_0, %32 : tensor loc(#loc11) - %34 = stablehlo.constant dense<0> : tensor loc(#loc9) - %35 = stablehlo.compare LT, %iterArg_0, %34, SIGNED : (tensor, tensor) -> tensor loc(#loc12) - %36 = stablehlo.constant dense<3> : tensor loc(#loc9) - %37 = stablehlo.add %iterArg_0, %36 : tensor loc(#loc11) - %38 = stablehlo.select %35, %37, %iterArg_0 : tensor, tensor loc(#loc13) - %39 = stablehlo.convert %38 : (tensor) -> tensor loc(#loc14) - %40 = stablehlo.broadcast_in_dim %39, dims = [] : (tensor) -> tensor<1xi32> loc(#loc15) - %41 = "stablehlo.gather"(%iterArg_2, %40) {dimension_numbers = #stablehlo.gather, indices_are_sorted = true, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1xi32>) -> tensor loc(#loc16) - %42 = stablehlo.constant dense<0> : tensor loc(#loc9) - %43 = stablehlo.compare LT, %iterArg_0, %42, SIGNED : (tensor, tensor) -> tensor loc(#loc12) - %44 = stablehlo.constant dense<3> : tensor loc(#loc9) - %45 = stablehlo.add %iterArg_0, %44 : tensor loc(#loc11) - %46 = stablehlo.select %43, %45, %iterArg_0 : tensor, tensor loc(#loc13) - %47 = stablehlo.convert %46 : (tensor) -> tensor loc(#loc14) - %48 = stablehlo.broadcast_in_dim %47, dims = [] : (tensor) -> tensor<1xi32> loc(#loc15) - %49 = "stablehlo.gather"(%iterArg_1, %48) {dimension_numbers = #stablehlo.gather, indices_are_sorted = true, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1xi32>) -> tensor loc(#loc16) - %50 = stablehlo.constant dense<0> : tensor loc(#loc9) - %51 = stablehlo.compare LT, %41, %50, SIGNED : (tensor, tensor) -> tensor loc(#loc12) - %52 = stablehlo.constant dense<3> : tensor loc(#loc9) - %53 = stablehlo.add %41, %52 : tensor loc(#loc11) - %54 = stablehlo.select %51, %53, %41 : tensor, tensor loc(#loc13) - %55 = stablehlo.dynamic_slice %iterArg_1, %54, sizes = [1] : (tensor<3xi32>, tensor) -> tensor<1xi32> loc(#loc17) - %56 = stablehlo.reshape %55 : (tensor<1xi32>) -> tensor loc(#loc18) - %57 = stablehlo.constant dense<0> : tensor loc(#loc9) - %58 = stablehlo.compare LT, %iterArg_0, %57, SIGNED : (tensor, tensor) -> tensor loc(#loc12) - %59 = stablehlo.constant dense<3> : tensor loc(#loc9) - %60 = stablehlo.add %iterArg_0, %59 : tensor loc(#loc11) - %61 = stablehlo.select %58, %60, %iterArg_0 : tensor, tensor loc(#loc13) - %62 = stablehlo.convert %61 : (tensor) -> tensor loc(#loc14) - %63 = stablehlo.broadcast_in_dim %62, dims = [] : (tensor) -> tensor<1xi32> loc(#loc15) - %64 = "stablehlo.scatter"(%iterArg_1, %63, %56) ({ - ^bb0(%arg0: tensor loc(unknown), %arg1: tensor loc(unknown)): - stablehlo.return %arg1 : tensor loc(#loc19) - }) 
{indices_are_sorted = true, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = true} : (tensor<3xi32>, tensor<1xi32>, tensor) -> tensor<3xi32> loc(#loc19) - %65 = stablehlo.constant dense<0> : tensor loc(#loc9) - %66 = stablehlo.compare LT, %41, %65, SIGNED : (tensor, tensor) -> tensor loc(#loc12) - %67 = stablehlo.constant dense<3> : tensor loc(#loc9) - %68 = stablehlo.add %41, %67 : tensor loc(#loc11) - %69 = stablehlo.select %66, %68, %41 : tensor, tensor loc(#loc13) - %70 = stablehlo.broadcast_in_dim %69, dims = [] : (tensor) -> tensor<1xi32> loc(#loc15) - %71 = "stablehlo.scatter"(%64, %70, %49) ({ - ^bb0(%arg0: tensor loc(unknown), %arg1: tensor loc(unknown)): - stablehlo.return %arg1 : tensor loc(#loc19) - }) {indices_are_sorted = true, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = true} : (tensor<3xi32>, tensor<1xi32>, tensor) -> tensor<3xi32> loc(#loc19) - %72 = stablehlo.constant dense<1> : tensor loc(#loc9) - %73 = stablehlo.add %iterArg, %72 : tensor loc(#loc11) - stablehlo.return %73, %33, %71, %iterArg_2 : tensor, tensor, tensor<3xi32>, tensor<3xi32> loc(#loc9) - } loc(#loc9) - return %27, %19, %31#2 : tensor<3x3xcomplex>, tensor<3xi32>, tensor<3xi32> loc(#loc) - } loc(#loc) -} loc(#loc) -#loc1 = loc("third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py":553:0) -#loc2 = loc("third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py":554:0) -#loc3 = loc("jit()/jit(main)/iota[dtype=complex64 shape=(9,) dimension=0]"(#loc1)) -#loc4 = loc("jit()/jit(main)/reshape[new_sizes=(3, 3) dimensions=None]"(#loc1)) -#loc5 = loc("jit()/jit(main)/lu"(#loc2)) -#loc6 = loc("jit()/jit(main)/iota[dtype=int32 shape=(3,) dimension=0]"(#loc2)) -#loc7 = loc("jit()/jit(main)/lu_pivots_to_permutation[permutation_size=3]"(#loc2)) -#loc8 = loc("jit()/jit(main)/scan[reverse=False length=3 num_consts=0 num_carry=3 linear=(False, False, False) unroll=1]"(#loc2)) -#loc9 = loc("jit()/jit(main)/while[cond_nconsts=0 body_nconsts=0]"(#loc2)) -#loc10 = loc("jit()/jit(main)/while/cond/lt"(#loc2)) -#loc11 = loc("jit()/jit(main)/while/body/add"(#loc2)) -#loc12 = loc("jit()/jit(main)/while/body/lt"(#loc2)) -#loc13 = loc("jit()/jit(main)/while/body/select_n"(#loc2)) -#loc14 = loc("jit()/jit(main)/while/body/convert_element_type[new_dtype=int32 weak_type=False]"(#loc2)) -#loc15 = loc("jit()/jit(main)/while/body/broadcast_in_dim[shape=(1,) broadcast_dimensions=()]"(#loc2)) -#loc16 = loc("jit()/jit(main)/while/body/gather[dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,)) slice_sizes=(1,) unique_indices=True indices_are_sorted=True mode=GatherScatterMode.PROMISE_IN_BOUNDS fill_value=None]"(#loc2)) -#loc17 = loc("jit()/jit(main)/while/body/dynamic_slice[slice_sizes=(1,)]"(#loc2)) -#loc18 = loc("jit()/jit(main)/while/body/squeeze[dimensions=(0,)]"(#loc2)) -#loc19 = loc("jit()/jit(main)/while/body/scatter[update_consts=() dimension_numbers=ScatterDimensionNumbers(update_window_dims=(), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,)) indices_are_sorted=True unique_indices=True mode=GatherScatterMode.FILL_OR_DROP]"(#loc2)) -""", - 
mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x013\x05\x01\x03\x01\x03\x05\x03#\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x1d\x1f!#%'\x03\xaa\x02\x0e\x025\x01\xb1\x0f\x0f\x07\x17\x0b\x13\x13\x0f\x1b\x13\x0f\x0f\x13\x0f\x13\x13\x13\x0f\x0f\x0b\x13\x17\x0b\x0b\x0b\x0b;\x0b\x0b\x0b\x0f;#\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x0b\x0f\x0b\x0f\x0b\x0b\x13\x0b\x13K\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x1b\x13\x13\x0f\x0b\x0f\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x0f\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x03Q\x0f\x0f/\x0b\x0f\x0b\x0bO\x0b/\x0b\x17\x13\x0b\x13\x0b\x13\x0b\x0b\x0b/\x0f/\x1f\x0b\x0b\x0b\x0b\x1b\x0f\x17\x17/\x1f\x1f\x0b/O/\x0b\x01\x07\x0b\x13\x0b\x01\x03\x0f\x033\x0f\x0f\x13\x13\x0f\x17\x07\x07\x0b\x07\x07\x13\x0f\x13\x1b\x07\x13\x13\x13\x13\x13\x13\x17\x17\x13\x02b\t\x1d]\x07\x1d\x8b\x07\x1f\x17-\xaa\x08\x01\x05)\x03\x03/\xb9\x03\x03\t\xd9\x1d\x8d\x07\x03\x051\xc13\xff\x03\x03\t\xfd\x1d\x8f\x07\x1d\x91\x07\x03\x03\t\xdf\x1d\x95\x07\x1d\x02\x02\x07\x03\x03\t\xdd\x03\x03\t\xf5\x1d\x93\x07\x11\x01\x05\x05+\x03\x03S\xb1\x17-\xa6\x08\x01\x05-\x05/\x051\x053\x03\r\x97\xb57\xb19\xbb\x99\xb9;\xc3\x9b\xb5\x055\x057\x059\x1d\x9d\x07\x03\r7\xb19\xbb\xa9\xb5\xab\xb5\xad\xbb\xaf\xb9\x03\x07C%E%'G\x05;\x05=\x05?\x03\x0bK\xbdM\xc5O\xc7'\xd5Q\xd7\x05A\x05C\x05E\x05G\x05I\x1dW+\x05K\x1d[+\x05M\x05O\x03\x03a\xb1\x05Q\x03\x03\t\xdb\x03\x11g\xe1i\xe3k\xe5m\xbdo\xe7q\xe9s\xebu\xef\x05S\x05U\x05W\x05Y\x05[\x05]\x05_\x05a\x03\x03\t\xf3\x03\x051\xc13\xf7\x03\x03\t\xf9\x03\x03/\xfb\x1d\x81\x07\x05c\x1d\x85\x07\x05e\x1d\x89\x07\x05g\x05i\x05k\x05m\x05o\x05q\x05s\x05u\x05w\x05y\x05{\x03\x03;\xc3\x1d\xa3\x07\x05}\x1d\xa7\x07\x05\x7f\x05\x81\x05\x83\x05\x85\x05\x87\x13\x11\x01\x1f'\x01\x1f\x1d\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1d\x89\x1f-\x01\x05\x03\x03\x01\x1f)!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\t\x07\x1f\x1d\x11\x01\x00\x00\x00\x00\x00\x00\x00#\x1f\x03\x07\xc9\xcd\xd1\r\x03\xb7\xcb\x1d\x8b\r\x03\xb7\xcf\x1d\x8d\r\x03\xb7\xd3\x1d\x8f\x1d\x91\x1d\x93\x1f\x03\x11\x03\x00\x00\x00\x00\x00\x00\x00\x1f\x19\x01\x1f\x03\x11\x01\x00\x00\x00\x00\x00\x00\x00\x1f\x05\t\x03\x00\x00\x00\x0b\x05\x1d\x95\x1d\x97\x05\x01\x03\t\xb3\xb3\xb3\xbf\x03\x03\xed\x15\x03\x01\r\x01\x03\x07\xbf\xf1\xb3\x1f+\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x05\t\x01\x00\x00\x00\x1f\x05\t\x00\x00\x00\x00\x07\x05\x1f\x1b\x11\x00\x00\xc0\x7f\x00\x00\xc0\x7f\x1f3!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f\x03\x11\x00\x00\x00\x00\x00\x00\x00\x00\x07\x0b\x05\x99\x1d\n\x02\x07\x05\x9b\x01\x02\x02)\x01\x11)\x01\x0f)\x03\r\x0f)\x03\x05\x0f)\x01\x17)\x05\r\r\x13\x1b\x1d\x03!\x13\x01)\x03\x01\x0f)\x01\x13)\x03\x05\x11\x11\x01\x07\r\x07\x07\t)\x03%\x13)\x03\t\x0f)\x03\x01\x15)\x03\t\x15)\x03\x05\x15)\x03\x01\x11)\x05\x05\x05\x17)\x05\r\r\x17)\x03\t\x11\x04f\n\x05\x01\x11\x05A\x07\x03\x01\x05\x19\x11\x05I\x05\x03K\x85\x13\x03U)\x03#\x0f\x06Y\x03\r\x03\x01\x03\x03\x01\r\x03\x03\x03\x03\x01\r\x03\x03\x0b\x06\x01\x03\x05\x03\x05\x0f\x06\x01\x03\t\x03\t\x0b\x06\x01\x03\x05\x03\x07\x0f\x06\x01\x03\t\x03\r\x1b\x07\x01_\x03%\x05\x0b\x0f\x03\x03\x01\r\x03\x03\x0b\x06\x01\x03\x05\x03\x13\x0f\x06\x01\x03\t\x03\x15\x03\x03\x01c\x03\x19\x03\x03\x01\x1f\x03\x03\x03\x03\x01\x19\x03\x05\x03\x03\x01\x19\x03\x05\x1d\x07\x01e\x07\r\x07\x05\t\x1b\x1d\x1f\x03\x03\x03\x01w\x03\x05\x05\x07\x01\x0b\x03\x07\x03'\x1f\x06\x01\x03\x07\x05#)\x03\x03\x01!\x03\x05\x05\x07\x01\x0b\x03\x05\x03-\x07\x07\x01y\x03\x0b\x05%/\x05\x07\x01\x0b\x03/\x031\x03\x03\x01{\x03\x1b\x05\x07\x01\x0b\x03\r\x035\x05\x07\x01}\x031\x033\r\x06\x01\x
03\r\x079!7\x13\x03\x7f)\x03\x07\x03\x03\x83\x13\x03\x03\x03\x03\x87\x13\x03\x03!\x16\x03\t\x03\x03\x07\x07\tA?=+\t\x03\r\x0f\t\x03\x05\x03\x05\x07\x05\x07\x05\x03\x03\x03\r\x03\x03\x07\x07\x06\x02\x11\x03\x0b\x05KS\x11\x04\x03\x03U\x03]\xaf\t\x03\x05\x03\x05\x07\x05\x07\x05\x03\x03\x03\x1f\x03\x03\t\x06\x0f\x03\x03\x05MS\x03\x03\x03\x13\x03\x03\x07\x07\x15\x11\x03\x0b\x05MW\x03\x03\x03\r\x03\x03\t\x06\x0f\x03\x03\x05M[\r\x06\x17\x03\x03\x07Y]M\x0b\x06#\x03\x05\x03_\x05\x07\x1b\x0b\x03\t\x03a\x15\x07=5\x03\x05\x05Qc\x03\x03\x03\x13\x03\x03\x07\x07\x15\x11\x03\x0b\x05Mg\x03\x03\x03\r\x03\x03\t\x06\x0f\x03\x03\x05Mk\r\x06\x17\x03\x03\x07imM\x0b\x06#\x03\x05\x03o\x05\x07\x1b\x0b\x03\t\x03q\x15\x07=5\x03\x05\x05Os\x03\x03\x03!\x03\x05\x07\x07\x15\x11\x03\x0b\x05ew\x03\x03\x03\x19\x03\x05\t\x06\x0f\x03\x05\x05e{\r\x06\x17\x03\x05\x07y}e#\x07\xa1\x9f\x03\t\x05O\x7f\x0f\x06\xa5\x03\x05\x03\x81\x03\x03\x03\x13\x03\x03\x07\x07\x15\x11\x03\x0b\x05M\x85\x03\x03\x03\r\x03\x03\t\x06\x0f\x03\x03\x05M\x89\r\x06\x17\x03\x03\x07\x87\x8bM\x0b\x06#\x03\x05\x03\x8d\x05\x07\x1b\x0b\x03\t\x03\x8f\x17\x17\x1d?\x03\x07\x07O\x91\x83\x05\x03\x05\x07\x05\x05\x05\x05\x05\x11\x04\x1d\x03\xa9\x03\x03\x03!\x03\x05\x07\x07\x15\x11\x03\x0b\x05e\x95\x03\x03\x03\x19\x03\x05\t\x06\x0f\x03\x05\x05e\x99\r\x06\x17\x03\x05\x07\x97\x9be\x05\x07\x1b\x0b\x03\t\x03\x9d\x17\x17\x1d?\x03\x07\x07\x93\x9fu\x05\x03\x05\x07\x05\x05\x05\x05\x05\x11\x04\x1d\x03\xa9\x03\x03\x03\x1f\x03\x03\t\x06\x0f\x03\x03\x05K\xa3\x11\x04\x03\t\xa5U\xa1Q\x11\x04\x05\x07;+G\x06\x03\x01\x05\x01\x00~%\x9dM2\x04\x1d\x03\x0f\x0b\t\t\t!'\x1f;+y\x87.\x04!\x19+\xb1\xb3YMO{\xe9\x8b\x83\x1f/!!)#\x1f\x19\x157\x85\x8b\x1f\x1f\x15\x1d\x15\x1b%)\x19'#+\x1b+\x83\x13\r#\x13\x19\x1f\x1f\x11\x17\x15\x11\x15\x17\x15\x17\x0f\x17)\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00broadcast_in_dim_v1\x00compare_v1\x00add_v1\x00convert_v1\x00select_v1\x00reshape_v1\x00return_v1\x00iota_v1\x00gather_v1\x00scatter_v1\x00func_v1\x00concatenate_v1\x00custom_call_v1\x00subtract_v1\x00while_v1\x00dynamic_slice_v1\x00value\x00sym_name\x00third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py\x00broadcast_dimensions\x00compare_type\x00comparison_direction\x00index_vector_dim\x00indices_are_sorted\x00slice_sizes\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit__lambda_\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00jit()/jit(main)/iota[dtype=complex64 shape=(9,) dimension=0]\x00jit()/jit(main)/reshape[new_sizes=(3, 3) dimensions=None]\x00jit()/jit(main)/lu\x00dimension\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00jit()/jit(main)/iota[dtype=int32 shape=(3,) dimension=0]\x00jit()/jit(main)/lu_pivots_to_permutation[permutation_size=3]\x00jit()/jit(main)/scan[reverse=False length=3 num_consts=0 num_carry=3 linear=(False, False, False) unroll=1]\x00jit()/jit(main)/while[cond_nconsts=0 body_nconsts=0]\x00jit()/jit(main)/while/body/add\x00jit()/jit(main)/while/body/lt\x00jit()/jit(main)/while/body/select_n\x00jit()/jit(main)/while/body/convert_element_type[new_dtype=int32 weak_type=False]\x00jit()/jit(main)/while/body/broadcast_in_dim[shape=(1,) broadcast_dimensions=()]\x00collapsed_slice_dims\x00offset_dims\x00start_index_map\x00jit()/jit(main)/while/body/gather[dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,)) slice_sizes=(1,) unique_indices=True 
indices_are_sorted=True mode=GatherScatterMode.PROMISE_IN_BOUNDS fill_value=None]\x00jit()/jit(main)/while/body/dynamic_slice[slice_sizes=(1,)]\x00jit()/jit(main)/while/body/squeeze[dimensions=(0,)]\x00inserted_window_dims\x00scatter_dims_to_operand_dims\x00unique_indices\x00update_window_dims\x00jax.result_info\x00[0]\x00[1]\x00[2]\x00main\x00public\x00\x00lapack_cgetrf\x00jit()/jit(main)/while/body/scatter[update_consts=() dimension_numbers=ScatterDimensionNumbers(update_window_dims=(), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,)) indices_are_sorted=True unique_indices=True mode=GatherScatterMode.FILL_OR_DROP]\x00jit()/jit(main)/while/cond/lt\x00", - xla_call_module_version=6, -) # End paste - - -# Pasted from the test output (see back_compat_test.py module docstring) -data_2023_06_14['c128'] = dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['lapack_zgetrf'], - serialized_date=datetime.date(2023, 6, 14), - inputs=(), - expected_outputs=(array([[6. +0.j, 7. +0.j, 8. +0.j], - [0. +0.j, 1. +0.j, 2. +0.j], - [0.5+0.j, 0.5+0.j, 0. +0.j]]), array([2, 2, 2], dtype=int32), array([2, 0, 1], dtype=int32)), - mlir_module_text=r""" -#loc = loc(unknown) -module @jit__lambda_ attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main() -> (tensor<3x3xcomplex> {jax.result_info = "[0]"}, tensor<3xi32> {jax.result_info = "[1]"}, tensor<3xi32> {jax.result_info = "[2]"}) { - %0 = stablehlo.iota dim = 0 : tensor<9xcomplex> loc(#loc3) - %1 = stablehlo.reshape %0 : (tensor<9xcomplex>) -> tensor<3x3xcomplex> loc(#loc4) - %2 = stablehlo.constant dense<3> : tensor loc(#loc5) - %3 = stablehlo.constant dense<3> : tensor loc(#loc5) - %4 = stablehlo.convert %2 : (tensor) -> tensor loc(#loc5) - %5 = stablehlo.reshape %4 : (tensor) -> tensor<1xi32> loc(#loc5) - %6 = stablehlo.convert %3 : (tensor) -> tensor loc(#loc5) - %7 = stablehlo.reshape %6 : (tensor) -> tensor<1xi32> loc(#loc5) - %8 = stablehlo.concatenate %5, %7, dim = 0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> loc(#loc5) - %9 = stablehlo.constant dense<3> : tensor loc(#loc5) - %10 = stablehlo.convert %9 : (tensor) -> tensor loc(#loc5) - %11 = stablehlo.reshape %10 : (tensor) -> tensor<1xi32> loc(#loc5) - %12 = stablehlo.constant dense<> : tensor<0xi32> loc(#loc5) - %13 = stablehlo.constant dense<1> : tensor loc(#loc5) - %14 = stablehlo.constant dense<3> : tensor loc(#loc5) - %15 = stablehlo.constant dense<3> : tensor loc(#loc5) - %16:3 = stablehlo.custom_call @lapack_zgetrf(%13, %14, %15, %1) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>]} : (tensor, tensor, tensor, tensor<3x3xcomplex>) -> (tensor<3x3xcomplex>, tensor<3xi32>, tensor) loc(#loc5) - %17 = stablehlo.constant dense<1> : tensor loc(#loc5) - %18 = stablehlo.broadcast_in_dim %17, dims = [] : (tensor) -> tensor<3xi32> loc(#loc5) - %19 = stablehlo.subtract %16#1, %18 : tensor<3xi32> loc(#loc5) - %20 = stablehlo.constant dense<0> : tensor loc(#loc5) - %21 = stablehlo.broadcast_in_dim %20, dims = [] : (tensor) -> tensor loc(#loc5) - %22 = stablehlo.compare GE, %16#2, %21, SIGNED : (tensor, tensor) -> tensor loc(#loc5) - %23 = stablehlo.broadcast_in_dim %22, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc5) - %24 = stablehlo.constant 
dense<(0x7FF8000000000000,0x7FF8000000000000)> : tensor> loc(#loc5) - %25 = stablehlo.broadcast_in_dim %24, dims = [] : (tensor>) -> tensor<3x3xcomplex> loc(#loc5) - %26 = stablehlo.broadcast_in_dim %23, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<3x3xi1> loc(#loc5) - %27 = stablehlo.select %26, %16#0, %25 : tensor<3x3xi1>, tensor<3x3xcomplex> loc(#loc5) - %28 = stablehlo.iota dim = 0 : tensor<3xi32> loc(#loc6) - %29 = stablehlo.constant dense<0> : tensor loc(#loc7) - %30 = stablehlo.constant dense<0> : tensor loc(#loc8) - %31:4 = stablehlo.while(%iterArg = %30, %iterArg_0 = %29, %iterArg_1 = %28, %iterArg_2 = %19) : tensor, tensor, tensor<3xi32>, tensor<3xi32> - cond { - %32 = stablehlo.constant dense<3> : tensor loc(#loc9) - %33 = stablehlo.compare LT, %iterArg, %32, SIGNED : (tensor, tensor) -> tensor loc(#loc10) - stablehlo.return %33 : tensor loc(#loc9) - } do { - %32 = stablehlo.constant dense<1> : tensor loc(#loc9) - %33 = stablehlo.add %iterArg_0, %32 : tensor loc(#loc11) - %34 = stablehlo.constant dense<0> : tensor loc(#loc9) - %35 = stablehlo.compare LT, %iterArg_0, %34, SIGNED : (tensor, tensor) -> tensor loc(#loc12) - %36 = stablehlo.constant dense<3> : tensor loc(#loc9) - %37 = stablehlo.add %iterArg_0, %36 : tensor loc(#loc11) - %38 = stablehlo.select %35, %37, %iterArg_0 : tensor, tensor loc(#loc13) - %39 = stablehlo.convert %38 : (tensor) -> tensor loc(#loc14) - %40 = stablehlo.broadcast_in_dim %39, dims = [] : (tensor) -> tensor<1xi32> loc(#loc15) - %41 = "stablehlo.gather"(%iterArg_2, %40) {dimension_numbers = #stablehlo.gather, indices_are_sorted = true, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1xi32>) -> tensor loc(#loc16) - %42 = stablehlo.constant dense<0> : tensor loc(#loc9) - %43 = stablehlo.compare LT, %iterArg_0, %42, SIGNED : (tensor, tensor) -> tensor loc(#loc12) - %44 = stablehlo.constant dense<3> : tensor loc(#loc9) - %45 = stablehlo.add %iterArg_0, %44 : tensor loc(#loc11) - %46 = stablehlo.select %43, %45, %iterArg_0 : tensor, tensor loc(#loc13) - %47 = stablehlo.convert %46 : (tensor) -> tensor loc(#loc14) - %48 = stablehlo.broadcast_in_dim %47, dims = [] : (tensor) -> tensor<1xi32> loc(#loc15) - %49 = "stablehlo.gather"(%iterArg_1, %48) {dimension_numbers = #stablehlo.gather, indices_are_sorted = true, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1xi32>) -> tensor loc(#loc16) - %50 = stablehlo.constant dense<0> : tensor loc(#loc9) - %51 = stablehlo.compare LT, %41, %50, SIGNED : (tensor, tensor) -> tensor loc(#loc12) - %52 = stablehlo.constant dense<3> : tensor loc(#loc9) - %53 = stablehlo.add %41, %52 : tensor loc(#loc11) - %54 = stablehlo.select %51, %53, %41 : tensor, tensor loc(#loc13) - %55 = stablehlo.dynamic_slice %iterArg_1, %54, sizes = [1] : (tensor<3xi32>, tensor) -> tensor<1xi32> loc(#loc17) - %56 = stablehlo.reshape %55 : (tensor<1xi32>) -> tensor loc(#loc18) - %57 = stablehlo.constant dense<0> : tensor loc(#loc9) - %58 = stablehlo.compare LT, %iterArg_0, %57, SIGNED : (tensor, tensor) -> tensor loc(#loc12) - %59 = stablehlo.constant dense<3> : tensor loc(#loc9) - %60 = stablehlo.add %iterArg_0, %59 : tensor loc(#loc11) - %61 = stablehlo.select %58, %60, %iterArg_0 : tensor, tensor loc(#loc13) - %62 = stablehlo.convert %61 : (tensor) -> tensor loc(#loc14) - %63 = stablehlo.broadcast_in_dim %62, dims = [] : (tensor) -> tensor<1xi32> loc(#loc15) - %64 = "stablehlo.scatter"(%iterArg_1, %63, %56) ({ - ^bb0(%arg0: tensor loc(unknown), %arg1: tensor loc(unknown)): - stablehlo.return %arg1 : tensor 
loc(#loc19) - }) {indices_are_sorted = true, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = true} : (tensor<3xi32>, tensor<1xi32>, tensor) -> tensor<3xi32> loc(#loc19) - %65 = stablehlo.constant dense<0> : tensor loc(#loc9) - %66 = stablehlo.compare LT, %41, %65, SIGNED : (tensor, tensor) -> tensor loc(#loc12) - %67 = stablehlo.constant dense<3> : tensor loc(#loc9) - %68 = stablehlo.add %41, %67 : tensor loc(#loc11) - %69 = stablehlo.select %66, %68, %41 : tensor, tensor loc(#loc13) - %70 = stablehlo.broadcast_in_dim %69, dims = [] : (tensor) -> tensor<1xi32> loc(#loc15) - %71 = "stablehlo.scatter"(%64, %70, %49) ({ - ^bb0(%arg0: tensor loc(unknown), %arg1: tensor loc(unknown)): - stablehlo.return %arg1 : tensor loc(#loc19) - }) {indices_are_sorted = true, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = true} : (tensor<3xi32>, tensor<1xi32>, tensor) -> tensor<3xi32> loc(#loc19) - %72 = stablehlo.constant dense<1> : tensor loc(#loc9) - %73 = stablehlo.add %iterArg, %72 : tensor loc(#loc11) - stablehlo.return %73, %33, %71, %iterArg_2 : tensor, tensor, tensor<3xi32>, tensor<3xi32> loc(#loc9) - } loc(#loc9) - return %27, %19, %31#2 : tensor<3x3xcomplex>, tensor<3xi32>, tensor<3xi32> loc(#loc) - } loc(#loc) -} loc(#loc) -#loc1 = loc("third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py":553:0) -#loc2 = loc("third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py":554:0) -#loc3 = loc("jit()/jit(main)/iota[dtype=complex128 shape=(9,) dimension=0]"(#loc1)) -#loc4 = loc("jit()/jit(main)/reshape[new_sizes=(3, 3) dimensions=None]"(#loc1)) -#loc5 = loc("jit()/jit(main)/lu"(#loc2)) -#loc6 = loc("jit()/jit(main)/iota[dtype=int32 shape=(3,) dimension=0]"(#loc2)) -#loc7 = loc("jit()/jit(main)/lu_pivots_to_permutation[permutation_size=3]"(#loc2)) -#loc8 = loc("jit()/jit(main)/scan[reverse=False length=3 num_consts=0 num_carry=3 linear=(False, False, False) unroll=1]"(#loc2)) -#loc9 = loc("jit()/jit(main)/while[cond_nconsts=0 body_nconsts=0]"(#loc2)) -#loc10 = loc("jit()/jit(main)/while/cond/lt"(#loc2)) -#loc11 = loc("jit()/jit(main)/while/body/add"(#loc2)) -#loc12 = loc("jit()/jit(main)/while/body/lt"(#loc2)) -#loc13 = loc("jit()/jit(main)/while/body/select_n"(#loc2)) -#loc14 = loc("jit()/jit(main)/while/body/convert_element_type[new_dtype=int32 weak_type=False]"(#loc2)) -#loc15 = loc("jit()/jit(main)/while/body/broadcast_in_dim[shape=(1,) broadcast_dimensions=()]"(#loc2)) -#loc16 = loc("jit()/jit(main)/while/body/gather[dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,)) slice_sizes=(1,) unique_indices=True indices_are_sorted=True mode=GatherScatterMode.PROMISE_IN_BOUNDS fill_value=None]"(#loc2)) -#loc17 = loc("jit()/jit(main)/while/body/dynamic_slice[slice_sizes=(1,)]"(#loc2)) -#loc18 = loc("jit()/jit(main)/while/body/squeeze[dimensions=(0,)]"(#loc2)) -#loc19 = loc("jit()/jit(main)/while/body/scatter[update_consts=() dimension_numbers=ScatterDimensionNumbers(update_window_dims=(), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,)) indices_are_sorted=True unique_indices=True mode=GatherScatterMode.FILL_OR_DROP]"(#loc2)) -""", - 
mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x013\x05\x01\x03\x01\x03\x05\x03#\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x1d\x1f!#%'\x03\xaa\x02\x0e\x025\x01\xb1\x0f\x0f\x07\x17\x0b\x13\x13\x0f\x1b\x13\x0f\x0f\x13\x0f\x13\x13\x13\x0f\x0f\x0b\x13\x17\x0b\x0b\x0b\x0b;\x0b\x0b\x0b\x0f;#\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x0b\x0f\x0b\x0f\x0b\x0b\x13\x0b\x13K\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x1b\x13\x13\x0f\x0b\x0f\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x0f\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x03Q\x0f\x0f/\x0b\x0f\x0b\x0bO\x0b/\x0b\x17\x13\x0b\x13\x0b\x13\x0b\x0b\x0b/\x0f/\x1f\x0b\x0b\x0b\x0b\x1b\x0f\x17\x17/\x1f\x1f\x0bOO/\x0b\x01\x07\x0b\x13\x0b\x01\x03\x0f\x033\x0f\x0f\x13\x13\x0f\x17\x07\x07\x0b\x07\x07\x13\x0f\x13\x1b\x07\x13\x13\x13\x13\x13\x13\x17\x17\x13\x02\x82\t\x1d]\x07\x1d\x8b\x07\x1f\x17-\xaa\x08\x01\x05)\x03\x03/\xb9\x03\x03\t\xd9\x1d\x8d\x07\x03\x051\xc13\xff\x03\x03\t\xfd\x1d\x8f\x07\x1d\x91\x07\x03\x03\t\xdf\x1d\x95\x07\x1d\x02\x02\x07\x03\x03\t\xdd\x03\x03\t\xf5\x1d\x93\x07\x11\x01\x05\x05+\x03\x03S\xb1\x17-\xa6\x08\x01\x05-\x05/\x051\x053\x03\r\x97\xb57\xb19\xbb\x99\xb9;\xc3\x9b\xb5\x055\x057\x059\x1d\x9d\x07\x03\r7\xb19\xbb\xa9\xb5\xab\xb5\xad\xbb\xaf\xb9\x03\x07C%E%'G\x05;\x05=\x05?\x03\x0bK\xbdM\xc5O\xc7'\xd5Q\xd7\x05A\x05C\x05E\x05G\x05I\x1dW+\x05K\x1d[+\x05M\x05O\x03\x03a\xb1\x05Q\x03\x03\t\xdb\x03\x11g\xe1i\xe3k\xe5m\xbdo\xe7q\xe9s\xebu\xef\x05S\x05U\x05W\x05Y\x05[\x05]\x05_\x05a\x03\x03\t\xf3\x03\x051\xc13\xf7\x03\x03\t\xf9\x03\x03/\xfb\x1d\x81\x07\x05c\x1d\x85\x07\x05e\x1d\x89\x07\x05g\x05i\x05k\x05m\x05o\x05q\x05s\x05u\x05w\x05y\x05{\x03\x03;\xc3\x1d\xa3\x07\x05}\x1d\xa7\x07\x05\x7f\x05\x81\x05\x83\x05\x85\x05\x87\x13\x11\x01\x1f'\x01\x1f\x1d\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1d\x89\x1f-\x01\x05\x03\x03\x01\x1f)!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\t\x07\x1f\x1d\x11\x01\x00\x00\x00\x00\x00\x00\x00#\x1f\x03\x07\xc9\xcd\xd1\r\x03\xb7\xcb\x1d\x8b\r\x03\xb7\xcf\x1d\x8d\r\x03\xb7\xd3\x1d\x8f\x1d\x91\x1d\x93\x1f\x03\x11\x03\x00\x00\x00\x00\x00\x00\x00\x1f\x19\x01\x1f\x03\x11\x01\x00\x00\x00\x00\x00\x00\x00\x1f\x05\t\x03\x00\x00\x00\x0b\x05\x1d\x95\x1d\x97\x05\x01\x03\t\xb3\xb3\xb3\xbf\x03\x03\xed\x15\x03\x01\r\x01\x03\x07\xbf\xf1\xb3\x1f+\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x05\t\x01\x00\x00\x00\x1f\x05\t\x00\x00\x00\x00\x07\x05\x1f\x1b!\x00\x00\x00\x00\x00\x00\xf8\x7f\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f3!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f\x03\x11\x00\x00\x00\x00\x00\x00\x00\x00\x07\x0b\x05\x99\x1d\n\x02\x07\x05\x9b\x01\x02\x02)\x01\x11)\x01\x0f)\x03\r\x0f)\x03\x05\x0f)\x01\x17)\x05\r\r\x13\x1b\x1d\x03!\x13\x01)\x03\x01\x0f)\x01\x13)\x03\x05\x11\x11\x01\x07\r\x07\x07\x0b)\x03%\x13)\x03\t\x0f)\x03\x01\x15)\x03\t\x15)\x03\x05\x15)\x03\x01\x11)\x05\x05\x05\x17)\x05\r\r\x17)\x03\t\x11\x04f\n\x05\x01\x11\x05A\x07\x03\x01\x05\x19\x11\x05I\x05\x03K\x85\x13\x03U)\x03#\x0f\x06Y\x03\r\x03\x01\x03\x03\x01\r\x03\x03\x03\x03\x01\r\x03\x03\x0b\x06\x01\x03\x05\x03\x05\x0f\x06\x01\x03\t\x03\t\x0b\x06\x01\x03\x05\x03\x07\x0f\x06\x01\x03\t\x03\r\x1b\x07\x01_\x03%\x05\x0b\x0f\x03\x03\x01\r\x03\x03\x0b\x06\x01\x03\x05\x03\x13\x0f\x06\x01\x03\t\x03\x15\x03\x03\x01c\x03\x19\x03\x03\x01\x1f\x03\x03\x03\x03\x01\x19\x03\x05\x03\x03\x01\x19\x03\x05\x1d\x07\x01e\x07\r\x07\x05\t\x1b\x1d\x1f\x03\x03\x03\x01w\x03\x05\x05\x07\x01\x0b\x03\x07\x03'\x1f\x06\x01\x03\x07\x05#)\x03\x03\x01!\x03\x05\x05\x07\x01\x0b\x03\x05\x03-\x07\x07\x01y\x03\x0b\x05%/\x05\x07\x01\x0b\x03/\x031\x03\x03\x01{\x03\x1b\x05\x07\x01\x0b\x03\r\x035\
x05\x07\x01}\x031\x033\r\x06\x01\x03\r\x079!7\x13\x03\x7f)\x03\x07\x03\x03\x83\x13\x03\x03\x03\x03\x87\x13\x03\x03!\x16\x03\t\x03\x03\x07\x07\tA?=+\t\x03\r\x0f\t\x03\x05\x03\x05\x07\x05\x07\x05\x03\x03\x03\r\x03\x03\x07\x07\x06\x02\x11\x03\x0b\x05KS\x11\x04\x03\x03U\x03]\xaf\t\x03\x05\x03\x05\x07\x05\x07\x05\x03\x03\x03\x1f\x03\x03\t\x06\x0f\x03\x03\x05MS\x03\x03\x03\x13\x03\x03\x07\x07\x15\x11\x03\x0b\x05MW\x03\x03\x03\r\x03\x03\t\x06\x0f\x03\x03\x05M[\r\x06\x17\x03\x03\x07Y]M\x0b\x06#\x03\x05\x03_\x05\x07\x1b\x0b\x03\t\x03a\x15\x07=5\x03\x05\x05Qc\x03\x03\x03\x13\x03\x03\x07\x07\x15\x11\x03\x0b\x05Mg\x03\x03\x03\r\x03\x03\t\x06\x0f\x03\x03\x05Mk\r\x06\x17\x03\x03\x07imM\x0b\x06#\x03\x05\x03o\x05\x07\x1b\x0b\x03\t\x03q\x15\x07=5\x03\x05\x05Os\x03\x03\x03!\x03\x05\x07\x07\x15\x11\x03\x0b\x05ew\x03\x03\x03\x19\x03\x05\t\x06\x0f\x03\x05\x05e{\r\x06\x17\x03\x05\x07y}e#\x07\xa1\x9f\x03\t\x05O\x7f\x0f\x06\xa5\x03\x05\x03\x81\x03\x03\x03\x13\x03\x03\x07\x07\x15\x11\x03\x0b\x05M\x85\x03\x03\x03\r\x03\x03\t\x06\x0f\x03\x03\x05M\x89\r\x06\x17\x03\x03\x07\x87\x8bM\x0b\x06#\x03\x05\x03\x8d\x05\x07\x1b\x0b\x03\t\x03\x8f\x17\x17\x1d?\x03\x07\x07O\x91\x83\x05\x03\x05\x07\x05\x05\x05\x05\x05\x11\x04\x1d\x03\xa9\x03\x03\x03!\x03\x05\x07\x07\x15\x11\x03\x0b\x05e\x95\x03\x03\x03\x19\x03\x05\t\x06\x0f\x03\x05\x05e\x99\r\x06\x17\x03\x05\x07\x97\x9be\x05\x07\x1b\x0b\x03\t\x03\x9d\x17\x17\x1d?\x03\x07\x07\x93\x9fu\x05\x03\x05\x07\x05\x05\x05\x05\x05\x11\x04\x1d\x03\xa9\x03\x03\x03\x1f\x03\x03\t\x06\x0f\x03\x03\x05K\xa3\x11\x04\x03\t\xa5U\xa1Q\x11\x04\x05\x07;+G\x06\x03\x01\x05\x01\x00\x82%\x9dM2\x04\x1d\x03\x0f\x0b\t\t\t!'\x1f;+y\x87.\x04!\x19+\xb1\xb3YMO{\xe9\x8b\x83\x1f/!!)#\x1f\x19\x157\x85\x8d\x1f\x1f\x15\x1d\x15\x1b%)\x19'#+\x1b+\x83\x13\r#\x13\x19\x1f\x1f\x11\x17\x15\x11\x15\x17\x15\x17\x0f\x17)\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00broadcast_in_dim_v1\x00compare_v1\x00add_v1\x00convert_v1\x00select_v1\x00reshape_v1\x00return_v1\x00iota_v1\x00gather_v1\x00scatter_v1\x00func_v1\x00concatenate_v1\x00custom_call_v1\x00subtract_v1\x00while_v1\x00dynamic_slice_v1\x00value\x00sym_name\x00third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py\x00broadcast_dimensions\x00compare_type\x00comparison_direction\x00index_vector_dim\x00indices_are_sorted\x00slice_sizes\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit__lambda_\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00jit()/jit(main)/iota[dtype=complex128 shape=(9,) dimension=0]\x00jit()/jit(main)/reshape[new_sizes=(3, 3) dimensions=None]\x00jit()/jit(main)/lu\x00dimension\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00jit()/jit(main)/iota[dtype=int32 shape=(3,) dimension=0]\x00jit()/jit(main)/lu_pivots_to_permutation[permutation_size=3]\x00jit()/jit(main)/scan[reverse=False length=3 num_consts=0 num_carry=3 linear=(False, False, False) unroll=1]\x00jit()/jit(main)/while[cond_nconsts=0 body_nconsts=0]\x00jit()/jit(main)/while/body/add\x00jit()/jit(main)/while/body/lt\x00jit()/jit(main)/while/body/select_n\x00jit()/jit(main)/while/body/convert_element_type[new_dtype=int32 weak_type=False]\x00jit()/jit(main)/while/body/broadcast_in_dim[shape=(1,) broadcast_dimensions=()]\x00collapsed_slice_dims\x00offset_dims\x00start_index_map\x00jit()/jit(main)/while/body/gather[dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,)) slice_sizes=(1,) 
unique_indices=True indices_are_sorted=True mode=GatherScatterMode.PROMISE_IN_BOUNDS fill_value=None]\x00jit()/jit(main)/while/body/dynamic_slice[slice_sizes=(1,)]\x00jit()/jit(main)/while/body/squeeze[dimensions=(0,)]\x00inserted_window_dims\x00scatter_dims_to_operand_dims\x00unique_indices\x00update_window_dims\x00jax.result_info\x00[0]\x00[1]\x00[2]\x00main\x00public\x00\x00lapack_zgetrf\x00jit()/jit(main)/while/body/scatter[update_consts=() dimension_numbers=ScatterDimensionNumbers(update_window_dims=(), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,)) indices_are_sorted=True unique_indices=True mode=GatherScatterMode.FILL_OR_DROP]\x00jit()/jit(main)/while/cond/lt\x00", - xla_call_module_version=6, -) # End paste - data_2024_05_31 = {} - # Pasted from the test output (see export_back_compat_test_util.py module docstring) data_2024_05_31["c128"] = dict( testdata_version=1, diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/cpu_qr_lapack_geqrf.py b/jax/_src/internal_test_util/export_back_compat_test_data/cpu_qr_lapack_geqrf.py index 94314a7ae518..bf41f3c3445c 100644 --- a/jax/_src/internal_test_util/export_back_compat_test_data/cpu_qr_lapack_geqrf.py +++ b/jax/_src/internal_test_util/export_back_compat_test_data/cpu_qr_lapack_geqrf.py @@ -17,259 +17,13 @@ import datetime from numpy import array, float32, complex64 -data_2023_03_17 = {} +data_2025_04_02 = {} -# Pasted from the test output (see back_compat_test.py module docstring) -data_2023_03_17["f32"] = dict( +data_2025_04_02['c128'] = dict( testdata_version=1, platform='cpu', - custom_call_targets=['lapack_sgeqrf', 'lapack_sorgqr'], - serialized_date=datetime.date(2023, 3, 17), - inputs=(), - expected_outputs=(array([[ 0. , 0.91287076, 0.4082487 ], - [-0.44721356, 0.36514866, -0.8164965 ], - [-0.8944271 , -0.18257445, 0.40824816]], dtype=float32), array([[-6.7082043e+00, -8.0498438e+00, -9.3914852e+00], - [ 0.0000000e+00, 1.0954441e+00, 2.1908894e+00], - [ 0.0000000e+00, 0.0000000e+00, 7.1525574e-07]], dtype=float32)), - mlir_module_text=r""" -module @jit__lambda_ { - func.func public @main() -> (tensor<3x3xf32> {jax.result_info = "[0]"}, tensor<3x3xf32> {jax.result_info = "[1]"}) { - %0 = stablehlo.iota dim = 0 : tensor<9xf32> - %1 = stablehlo.reshape %0 : (tensor<9xf32>) -> tensor<3x3xf32> - %2 = stablehlo.constant dense<1> : tensor - %3 = stablehlo.constant dense<3> : tensor - %4 = stablehlo.constant dense<3> : tensor - %5 = stablehlo.constant dense<96> : tensor - %6 = stablehlo.custom_call @lapack_sgeqrf(%2, %3, %4, %5, %1) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>, dense<0> : tensor<1xindex>]} : (tensor, tensor, tensor, tensor, tensor<3x3xf32>) -> tuple, tensor<3xf32>, tensor, tensor<96xf32>> - %7 = stablehlo.get_tuple_element %6[0] : (tuple, tensor<3xf32>, tensor, tensor<96xf32>>) -> tensor<3x3xf32> - %8 = stablehlo.get_tuple_element %6[1] : (tuple, tensor<3xf32>, tensor, tensor<96xf32>>) -> tensor<3xf32> - %9 = stablehlo.get_tuple_element %6[2] : (tuple, tensor<3xf32>, tensor, tensor<96xf32>>) -> tensor - %10 = stablehlo.get_tuple_element %6[3] : (tuple, tensor<3xf32>, tensor, tensor<96xf32>>) -> tensor<96xf32> - %11 = stablehlo.constant dense<0> : tensor - %12 = stablehlo.broadcast_in_dim %11, 
dims = [] : (tensor) -> tensor - %13 = stablehlo.compare EQ, %9, %12, SIGNED : (tensor, tensor) -> tensor - %14 = stablehlo.broadcast_in_dim %13, dims = [] : (tensor) -> tensor<1x1xi1> - %15 = stablehlo.constant dense<0x7FC00000> : tensor - %16 = stablehlo.broadcast_in_dim %15, dims = [] : (tensor) -> tensor<3x3xf32> - %17 = stablehlo.broadcast_in_dim %14, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<3x3xi1> - %18 = stablehlo.select %17, %7, %16 : tensor<3x3xi1>, tensor<3x3xf32> - %19 = stablehlo.broadcast_in_dim %13, dims = [] : (tensor) -> tensor<1xi1> - %20 = stablehlo.constant dense<0x7FC00000> : tensor - %21 = stablehlo.broadcast_in_dim %20, dims = [] : (tensor) -> tensor<3xf32> - %22 = stablehlo.broadcast_in_dim %19, dims = [0] : (tensor<1xi1>) -> tensor<3xi1> - %23 = stablehlo.select %22, %8, %21 : tensor<3xi1>, tensor<3xf32> - %24 = stablehlo.constant dense<0.000000e+00> : tensor - %25 = stablehlo.pad %18, %24, low = [0, 0], high = [0, 0], interior = [0, 0] : (tensor<3x3xf32>, tensor) -> tensor<3x3xf32> - %26 = stablehlo.constant dense<1> : tensor - %27 = stablehlo.constant dense<3> : tensor - %28 = stablehlo.constant dense<3> : tensor - %29 = stablehlo.constant dense<3> : tensor - %30 = stablehlo.constant dense<96> : tensor - %31 = stablehlo.custom_call @lapack_sorgqr(%26, %27, %28, %29, %30, %25, %23) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>, dense<0> : tensor<1xindex>]} : (tensor, tensor, tensor, tensor, tensor, tensor<3x3xf32>, tensor<3xf32>) -> tuple, tensor, tensor<96xf32>> - %32 = stablehlo.get_tuple_element %31[0] : (tuple, tensor, tensor<96xf32>>) -> tensor<3x3xf32> - %33 = stablehlo.get_tuple_element %31[1] : (tuple, tensor, tensor<96xf32>>) -> tensor - %34 = stablehlo.get_tuple_element %31[2] : (tuple, tensor, tensor<96xf32>>) -> tensor<96xf32> - %35 = stablehlo.constant dense<0> : tensor - %36 = stablehlo.broadcast_in_dim %35, dims = [] : (tensor) -> tensor - %37 = stablehlo.compare EQ, %33, %36, SIGNED : (tensor, tensor) -> tensor - %38 = stablehlo.broadcast_in_dim %37, dims = [] : (tensor) -> tensor<1x1xi1> - %39 = stablehlo.constant dense<0x7FC00000> : tensor - %40 = stablehlo.broadcast_in_dim %39, dims = [] : (tensor) -> tensor<3x3xf32> - %41 = stablehlo.broadcast_in_dim %38, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<3x3xi1> - %42 = stablehlo.select %41, %32, %40 : tensor<3x3xi1>, tensor<3x3xf32> - %43 = call @triu(%18) : (tensor<3x3xf32>) -> tensor<3x3xf32> - return %42, %43 : tensor<3x3xf32>, tensor<3x3xf32> - } - func.func private @triu(%arg0: tensor<3x3xf32>) -> tensor<3x3xf32> { - %0 = stablehlo.iota dim = 0 : tensor<3x3xi32> - %1 = stablehlo.constant dense<-1> : tensor - %2 = stablehlo.broadcast_in_dim %1, dims = [] : (tensor) -> tensor<3x3xi32> - %3 = stablehlo.add %0, %2 : tensor<3x3xi32> - %4 = stablehlo.iota dim = 1 : tensor<3x3xi32> - %5 = stablehlo.compare GE, %3, %4, SIGNED : (tensor<3x3xi32>, tensor<3x3xi32>) -> tensor<3x3xi1> - %6 = stablehlo.constant dense<0.000000e+00> : tensor - %7 = stablehlo.broadcast_in_dim %6, dims = [] : (tensor) -> tensor<3x3xf32> - %8 = stablehlo.select %5, %7, %arg0 : tensor<3x3xi1>, tensor<3x3xf32> - return %8 : tensor<3x3xf32> - } -} -""", - 
mlir_module_serialized=b"ML\xefR\x03MLIRxxx-trunk\x00\x01+\x05\x01\x05\x01\x03\x05\x03\x1b\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x1d\x1f\x03\xa2\x02\n\x027\x01\x9b\x0f\x0f\x17\x13\x0b\x0f\x13\x07\x0b\x0b\x0b\x13\x0b\x0b\x0b\x0b\x0b\x13\x0b\x0f\x0b\x0b\x13\x17\x13\x13\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x13\x13\x13\x1b\x13\x13\x0b33\x0b\x0f\x0b\x13\x0b\x13\x0f\x0b\x1b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0bK\x13\x13\x0f\x0b#\x0b\x0b\x0b\x0f\x0b\x0bK\x03g\x0fO/\x0b/\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x0b\x13\x13\x0b\x13\x0b\x0b\x0b\x0b\x0b\x0b\x0f\x1f\x0f\x0f\x0b\x1f\x1f\x1f\x1f\x0b\x1f\x0f\x17\x1b\x0f\x0f\x0f\x0f\x1f\x0b\x1fO/\x0b'\x0f\x17\x17\x01\x05\x17\x0b\x037\x0f\x17\x0f\x07\x07\x07\x07\x17\x13\x17\x17\x07\x0f\x17\x13\x17\x17\x13\x13\x1b\x13\x13\x13\x13\x13\x13\x17\x02\xae\t\x1d\x7f\x05\x1d\x97\x05\x17!\xee\x05\x01\x03\x03\x15\xcd\x05!\x1dY\x05\x03\x03\t\xd7\x1f\x05#\x05%\x05'\x03\x03\t\xf1\x05)\x05+\x05-\x05/\x051\x03\x03%\xc9\x053\x1da\x05\x055\x057\x03\x03\t\xd3\x17!\xea\x05\x01\x03\x03\t\xd5\x03\x03\t\xd9\x059\x05;\x05=\x05?\x05A\x05C\x05E\x05G\x03\x03\x11\xe5\x03\x03\x11\xe7\x03\x03\x11\xe9\x03\x03\t\xed\x03\x05)\xab+\xef\x03\x03\x15\xf3\x03\x03\x13S\x05I\x03\x0b\x19\xa1\x1b\xb3\x1d\xb5\x13\xbf\x1f\xc1\x03\x0b\x19\xa7\x1b\xc5\x1d\xa7\x13\xa9\x1f\xc7\x05K\x1d]\x05\x05M\x03\x03\t\xcb\x05O\x03\x03%\xcf\x1dg\x05\x05Q\x03\x05)\xab+\xd1\x1dm\x05\x05S\x1dq\x05\x05U\x1du\x05\x05W\x1dy/\x05Y\x1d}/\x05[\x05]\x03\x115\xad7\xaf9\xdb;\xa1=\xb1?\xddA\xdfC\xe3\x03\x03\x11\xeb\x03\x03\x15\xf5\x1d\x89\x05\x05_\x03\x07\x8d\xa3\x8f\xa3\x91\xa3\x05a\x05c\x05e\x1d\x95\x05\x05g\x05i\x03\x115\xad7\xaf9\xf7;\xa1=\xb1?\xf9A\xfbC\xff\x1f)\x01\x1f+!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f-\x11\x00\x00\x00\x00\x00\x00\x00\x00\x03\x01\x1f\x1d\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1dk\x03\x03\xc3\x1dm\t\x07\x0b\x05\x1do\x05\x01#\x1f\x03\x05\xb7\xbb\r\x03\xa5\xb9\x1dq\r\x03\xa5\xbd\x1ds\x1du\x1dw\r\x01#!\x1dy\x13\x0b\x01\x1f\x01\t\xff\xff\xff\xff\x1f#\x01\x13\x0b\x05\x07\x05\x1f\x05\t\x00\x00\x00\x00\x1f\x01\t\x01\x00\x00\x00\x1f\x01\t\x03\x00\x00\x00\x1f\x01\t`\x00\x00\x00\x1d{\x03\x0b\x9b\x9b\x9b\x9b\x9d\x03\x03\xe1\x15\x03\x01\x11\x01\x03\t\x9d\x9f\x9b\x9f\x13\x07\x01\x13\x07\x05\x13\x07\t\x13\x07\r\x1f\x01\t\x00\x00\x00\x00\x07\x01\x1f\x05\t\x00\x00\xc0\x7f\x1f\x1d!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f3\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1d}\x03\x0f\x9b\x9b\x9b\x9b\x9b\x9d\x9f\x03\x03\xfd\x15\x03\x01\x15\x01\x03\x07\x9d\x9b\x9f\x03\x03\x06\x02\xa9\x05\x7f)\x01\x07)\x05\r\r\t)\x01\t\x1b\t\x1d\x01)\x05\r\r\x07)\x03\r\t)\x03\x02\x03\t)\x05\r\r\r\x13)\x01\r)\x05\x05\x05\r)\x03\t\x0b\x11\x01\x05\x03\x03\x11\x03\x03\x03\x03)\x03\x01\x0b)\x03%\t/\t\x03\x11\x01\x13)\x03\x01\x17)\x03\t\x17)\x03\x05\x17)\x03\x05\r)\x03\r\r)\x03\x05\x0b/\x07\x03\x01\x13\x04\xe6\x06\x05\x01\x11\x0fQ\x07\x03\x01\t\x0f\x11\x0fU\x05\x03Y\xb5\x0b\x03w#\x03%\x17\x06{\x03\x03\x03\x01\x03\x03\x011\x03\x01\x03\x03\x01\r\x03\x01\x03\x03\x01\r\x03\x01\x03\x03\x013\x03\x01\x13\x07\x01\x81\x03'\x0b\x05\x07\t\x0b\x03\x07\x07\x01E\x03\x03\x03\r\x07\x07\x01G\x03\x11\x03\r\x07\x07\x01I\x03\x01\x03\r\x07\x07\x01\x83\x03\x13\x03\r\x03\x03\x01K\x03\x01\x05\x07\x01\x07\x03\x01\x03\x17\r\x07\x01M\x03\x19\x05\x13\x19\x05\x07\x01\x07\x03\x1b\x03\x1b\x03\x03\x01\x17\x03\x05\x05\x07\x01\x07\x03\x03\x03\x1f\x05\x07\x01O\x03\x15\x03\x1d\t\x06\x01\x03\x03\x07#\x0f!\x05\x07\x01\x07\x03/\x03\x1b\x03\x03\x01\x17\x03\x05\x05\x07\x01\x07\x03\x11\x03)\x05\x07\x01\x85\x031\x03'\t\x06\x01\x03\x11\x07-\x11+\x03\x03
\x87-\x03\x05\x19\x07\x93\x8b\x03\x03\x05%1\x03\x03\x031\x03\x01\x03\x03\x03\r\x03\x01\x03\x03\x03\r\x03\x01\x03\x03\x03\r\x03\x01\x03\x03\x033\x03\x01\x13\x07\x03\x99\x035\x0f579;=3/\x07\x07\x03E\x03\x03\x03?\x07\x07\x03G\x03\x01\x03?\x07\x07\x03I\x03\x13\x03?\x03\x03\x03K\x03\x01\x05\x07\x03\x07\x03\x01\x03G\r\x07\x03M\x03\x19\x05CI\x05\x07\x03\x07\x03\x1b\x03K\x03\x03\x03\x17\x03\x05\x05\x07\x03\x07\x03\x03\x03O\x05\x07\x03O\x03\x15\x03M\t\x06\x03\x03\x03\x07SAQ\x1b\x07\x0b\x02\x02\x03\x03\x03%\x11\x04\x0f\x05UW\x0f\x11\x0bW\x05\x03\x15+\x03\x03\x0f\x0b\x03[#\x03\x0f\x03\x03\x0b_\x03\x01\x05\x07'\x07\x03\x0f\x03\x05\x15\x06'\x03\x0f\x05\x03\x07\x0b\x03ec\x03\x0f\r\x07ki\x03\x15\x05\t\x0b\x03\x03\x0b-\x03\x05\x05\x07o\x07\x03\x03\x03\x0f\t\x06s\x03\x03\x07\r\x11\x01\x11\x04\x0b\x03\x13\x06\x03\x01\x05\x01\x00\xc6\x18\x81\x0f\x1d\x1d\x11\x0f\x0b\t\t\x03\x0b!Y\x87##%_=\x85\x87W\xb3K\x9bM\x9b\xd2\x02\x1b\x1f/!!)#\x1f\x19+\x1b\x1f\x83\x1f\x15\x1d\x15+\x13\r\r\x11\x0f\x17\x0f\x1f\x15\x11\x17\x11\x15+)\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00broadcast_in_dim_v1\x00get_tuple_element_v1\x00select_v1\x00iota_v1\x00compare_v1\x00func_v1\x00return_v1\x00custom_call_v1\x00add_v1\x00reshape_v1\x00pad_v1\x00call_v1\x00value\x00index\x00sym_name\x00broadcast_dimensions\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py\x00iota_dimension\x00compare_type\x00comparison_direction\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00jit__lambda_\x00jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) resource_env=None donated_invars=(False,) name=triu keep_unused=False inline=False]\x00jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=0]\x00jit()/jit(main)/jit(triu)/add\x00jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=1]\x00jit()/jit(main)/jit(triu)/ge\x00jit()/jit(main)/jit(triu)/broadcast_in_dim[shape=(3, 3) broadcast_dimensions=()]\x00jit()/jit(main)/jit(triu)/select_n\x00jit()/jit(main)/iota[dtype=float32 shape=(9,) dimension=0]\x00jit()/jit(main)/reshape[new_sizes=(3, 3) dimensions=None]\x00jit()/jit(main)/geqrf\x00jit()/jit(main)/qr[full_matrices=True]\x00edge_padding_high\x00edge_padding_low\x00interior_padding\x00jit()/jit(main)/pad[padding_config=((0, 0, 0), (0, 0, 0))]\x00jit()/jit(main)/householder_product\x00jax.result_info\x00triu\x00\x00[0]\x00[1]\x00main\x00public\x00private\x00lapack_sgeqrf\x00lapack_sorgqr\x00callee\x00", - xla_call_module_version=4, -) # End paste - - -# Pasted from the test output (see back_compat_test.py module docstring) -data_2023_03_17["f64"] = dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['lapack_dgeqrf', 'lapack_dorgqr'], - serialized_date=datetime.date(2023, 3, 17), - inputs=(), - expected_outputs=(array([[ 0. 
, 0.9128709291752773 , 0.40824829046386235], - [-0.447213595499958 , 0.3651483716701102 , -0.8164965809277263 ], - [-0.894427190999916 , -0.1825741858350548 , 0.40824829046386324]]), array([[-6.7082039324993694e+00, -8.0498447189992444e+00, - -9.3914855054991175e+00], - [ 0.0000000000000000e+00, 1.0954451150103341e+00, - 2.1908902300206665e+00], - [ 0.0000000000000000e+00, 0.0000000000000000e+00, - -8.8817841970012523e-16]])), - mlir_module_text=r""" -module @jit__lambda_ { - func.func public @main() -> (tensor<3x3xf64> {jax.result_info = "[0]"}, tensor<3x3xf64> {jax.result_info = "[1]"}) { - %0 = stablehlo.iota dim = 0 : tensor<9xf64> - %1 = stablehlo.reshape %0 : (tensor<9xf64>) -> tensor<3x3xf64> - %2 = stablehlo.constant dense<1> : tensor - %3 = stablehlo.constant dense<3> : tensor - %4 = stablehlo.constant dense<3> : tensor - %5 = stablehlo.constant dense<96> : tensor - %6 = stablehlo.custom_call @lapack_dgeqrf(%2, %3, %4, %5, %1) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>, dense<0> : tensor<1xindex>]} : (tensor, tensor, tensor, tensor, tensor<3x3xf64>) -> tuple, tensor<3xf64>, tensor, tensor<96xf64>> - %7 = stablehlo.get_tuple_element %6[0] : (tuple, tensor<3xf64>, tensor, tensor<96xf64>>) -> tensor<3x3xf64> - %8 = stablehlo.get_tuple_element %6[1] : (tuple, tensor<3xf64>, tensor, tensor<96xf64>>) -> tensor<3xf64> - %9 = stablehlo.get_tuple_element %6[2] : (tuple, tensor<3xf64>, tensor, tensor<96xf64>>) -> tensor - %10 = stablehlo.get_tuple_element %6[3] : (tuple, tensor<3xf64>, tensor, tensor<96xf64>>) -> tensor<96xf64> - %11 = stablehlo.constant dense<0> : tensor - %12 = stablehlo.broadcast_in_dim %11, dims = [] : (tensor) -> tensor - %13 = stablehlo.compare EQ, %9, %12, SIGNED : (tensor, tensor) -> tensor - %14 = stablehlo.broadcast_in_dim %13, dims = [] : (tensor) -> tensor<1x1xi1> - %15 = stablehlo.constant dense<0x7FF8000000000000> : tensor - %16 = stablehlo.broadcast_in_dim %15, dims = [] : (tensor) -> tensor<3x3xf64> - %17 = stablehlo.broadcast_in_dim %14, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<3x3xi1> - %18 = stablehlo.select %17, %7, %16 : tensor<3x3xi1>, tensor<3x3xf64> - %19 = stablehlo.broadcast_in_dim %13, dims = [] : (tensor) -> tensor<1xi1> - %20 = stablehlo.constant dense<0x7FF8000000000000> : tensor - %21 = stablehlo.broadcast_in_dim %20, dims = [] : (tensor) -> tensor<3xf64> - %22 = stablehlo.broadcast_in_dim %19, dims = [0] : (tensor<1xi1>) -> tensor<3xi1> - %23 = stablehlo.select %22, %8, %21 : tensor<3xi1>, tensor<3xf64> - %24 = stablehlo.constant dense<0.000000e+00> : tensor - %25 = stablehlo.pad %18, %24, low = [0, 0], high = [0, 0], interior = [0, 0] : (tensor<3x3xf64>, tensor) -> tensor<3x3xf64> - %26 = stablehlo.constant dense<1> : tensor - %27 = stablehlo.constant dense<3> : tensor - %28 = stablehlo.constant dense<3> : tensor - %29 = stablehlo.constant dense<3> : tensor - %30 = stablehlo.constant dense<96> : tensor - %31 = stablehlo.custom_call @lapack_dorgqr(%26, %27, %28, %29, %30, %25, %23) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>], output_operand_aliases 
= [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>, dense<0> : tensor<1xindex>]} : (tensor, tensor, tensor, tensor, tensor, tensor<3x3xf64>, tensor<3xf64>) -> tuple, tensor, tensor<96xf64>> - %32 = stablehlo.get_tuple_element %31[0] : (tuple, tensor, tensor<96xf64>>) -> tensor<3x3xf64> - %33 = stablehlo.get_tuple_element %31[1] : (tuple, tensor, tensor<96xf64>>) -> tensor - %34 = stablehlo.get_tuple_element %31[2] : (tuple, tensor, tensor<96xf64>>) -> tensor<96xf64> - %35 = stablehlo.constant dense<0> : tensor - %36 = stablehlo.broadcast_in_dim %35, dims = [] : (tensor) -> tensor - %37 = stablehlo.compare EQ, %33, %36, SIGNED : (tensor, tensor) -> tensor - %38 = stablehlo.broadcast_in_dim %37, dims = [] : (tensor) -> tensor<1x1xi1> - %39 = stablehlo.constant dense<0x7FF8000000000000> : tensor - %40 = stablehlo.broadcast_in_dim %39, dims = [] : (tensor) -> tensor<3x3xf64> - %41 = stablehlo.broadcast_in_dim %38, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<3x3xi1> - %42 = stablehlo.select %41, %32, %40 : tensor<3x3xi1>, tensor<3x3xf64> - %43 = call @triu(%18) : (tensor<3x3xf64>) -> tensor<3x3xf64> - return %42, %43 : tensor<3x3xf64>, tensor<3x3xf64> - } - func.func private @triu(%arg0: tensor<3x3xf64>) -> tensor<3x3xf64> { - %0 = stablehlo.iota dim = 0 : tensor<3x3xi32> - %1 = stablehlo.constant dense<-1> : tensor - %2 = stablehlo.broadcast_in_dim %1, dims = [] : (tensor) -> tensor<3x3xi32> - %3 = stablehlo.add %0, %2 : tensor<3x3xi32> - %4 = stablehlo.iota dim = 1 : tensor<3x3xi32> - %5 = stablehlo.compare GE, %3, %4, SIGNED : (tensor<3x3xi32>, tensor<3x3xi32>) -> tensor<3x3xi1> - %6 = stablehlo.constant dense<0.000000e+00> : tensor - %7 = stablehlo.broadcast_in_dim %6, dims = [] : (tensor) -> tensor<3x3xf64> - %8 = stablehlo.select %5, %7, %arg0 : tensor<3x3xi1>, tensor<3x3xf64> - return %8 : tensor<3x3xf64> - } -} -""", - 
mlir_module_serialized=b"ML\xefR\x03MLIRxxx-trunk\x00\x01+\x05\x01\x05\x01\x03\x05\x03\x1b\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x1d\x1f\x03\xa2\x02\n\x027\x01\x9b\x0f\x0f\x17\x13\x0b\x0f\x13\x07\x0b\x0b\x0b\x13\x0b\x0b\x0b\x0b\x0b\x13\x0b\x0f\x0b\x0b\x13\x17\x13\x13\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x13\x13\x13\x1b\x13\x13\x0b33\x0b\x0f\x0b\x13\x0b\x13\x0f\x0b\x1b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0bK\x13\x13\x0f\x0b#\x0b\x0b\x0b\x0f\x0b\x0bK\x03g\x0fO/\x0b/\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x0b\x13\x13\x0b\x13\x0b\x0b\x0b\x0b\x0b\x0b\x0f\x1f\x0f\x0f\x0b/\x1f\x1f\x1f\x0b\x1f\x0f\x17\x1b\x0f\x0f\x0f\x0f\x1f\x0b/O/\x0b'\x0f\x17\x17\x01\x05\x17\x0b\x037\x0f\x17\x0f\x07\x07\x07\x07\x17\x13\x17\x17\x07\x0f\x17\x13\x17\x17\x13\x13\x1b\x13\x13\x13\x13\x13\x13\x17\x02\xce\t\x1d\x7f\x05\x1d\x97\x05\x17!\xee\x05\x01\x03\x03\x15\xcd\x05!\x1dY\x05\x03\x03\t\xd7\x1f\x05#\x05%\x05'\x03\x03\t\xf1\x05)\x05+\x05-\x05/\x051\x03\x03%\xc9\x053\x1da\x05\x055\x057\x03\x03\t\xd3\x17!\xea\x05\x01\x03\x03\t\xd5\x03\x03\t\xd9\x059\x05;\x05=\x05?\x05A\x05C\x05E\x05G\x03\x03\x11\xe5\x03\x03\x11\xe7\x03\x03\x11\xe9\x03\x03\t\xed\x03\x05)\xab+\xef\x03\x03\x15\xf3\x03\x03\x13S\x05I\x03\x0b\x19\xa1\x1b\xb3\x1d\xb5\x13\xbf\x1f\xc1\x03\x0b\x19\xa7\x1b\xc5\x1d\xa7\x13\xa9\x1f\xc7\x05K\x1d]\x05\x05M\x03\x03\t\xcb\x05O\x03\x03%\xcf\x1dg\x05\x05Q\x03\x05)\xab+\xd1\x1dm\x05\x05S\x1dq\x05\x05U\x1du\x05\x05W\x1dy/\x05Y\x1d}/\x05[\x05]\x03\x115\xad7\xaf9\xdb;\xa1=\xb1?\xddA\xdfC\xe3\x03\x03\x11\xeb\x03\x03\x15\xf5\x1d\x89\x05\x05_\x03\x07\x8d\xa3\x8f\xa3\x91\xa3\x05a\x05c\x05e\x1d\x95\x05\x05g\x05i\x03\x115\xad7\xaf9\xf7;\xa1=\xb1?\xf9A\xfbC\xff\x1f)\x01\x1f+!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f-\x11\x00\x00\x00\x00\x00\x00\x00\x00\x03\x01\x1f\x1d\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1dk\x03\x03\xc3\x1dm\t\x07\x0b\x05\x1do\x05\x01#\x1f\x03\x05\xb7\xbb\r\x03\xa5\xb9\x1dq\r\x03\xa5\xbd\x1ds\x1du\x1dw\r\x01#!\x1dy\x13\x0b\x01\x1f\x01\t\xff\xff\xff\xff\x1f#\x01\x13\x0b\x05\x07\x05\x1f\x05\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x01\t\x01\x00\x00\x00\x1f\x01\t\x03\x00\x00\x00\x1f\x01\t`\x00\x00\x00\x1d{\x03\x0b\x9b\x9b\x9b\x9b\x9d\x03\x03\xe1\x15\x03\x01\x11\x01\x03\t\x9d\x9f\x9b\x9f\x13\x07\x01\x13\x07\x05\x13\x07\t\x13\x07\r\x1f\x01\t\x00\x00\x00\x00\x07\x01\x1f\x05\x11\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f\x1d!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f3\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1d}\x03\x0f\x9b\x9b\x9b\x9b\x9b\x9d\x9f\x03\x03\xfd\x15\x03\x01\x15\x01\x03\x07\x9d\x9b\x9f\x03\x03\x06\x02\xa9\x05\x7f)\x01\x07)\x05\r\r\t)\x01\t\x1b\x0b\x1d\x01)\x05\r\r\x07)\x03\r\t)\x03\x02\x03\t)\x05\r\r\r\x13)\x01\r)\x05\x05\x05\r)\x03\t\x0b\x11\x01\x05\x03\x03\x11\x03\x03\x03\x03)\x03\x01\x0b)\x03%\t/\t\x03\x11\x01\x13)\x03\x01\x17)\x03\t\x17)\x03\x05\x17)\x03\x05\r)\x03\r\r)\x03\x05\x0b/\x07\x03\x01\x13\x04\xe6\x06\x05\x01\x11\x0fQ\x07\x03\x01\t\x0f\x11\x0fU\x05\x03Y\xb5\x0b\x03w#\x03%\x17\x06{\x03\x03\x03\x01\x03\x03\x011\x03\x01\x03\x03\x01\r\x03\x01\x03\x03\x01\r\x03\x01\x03\x03\x013\x03\x01\x13\x07\x01\x81\x03'\x0b\x05\x07\t\x0b\x03\x07\x07\x01E\x03\x03\x03\r\x07\x07\x01G\x03\x11\x03\r\x07\x07\x01I\x03\x01\x03\r\x07\x07\x01\x83\x03\x13\x03\r\x03\x03\x01K\x03\x01\x05\x07\x01\x07\x03\x01\x03\x17\r\x07\x01M\x03\x19\x05\x13\x19\x05\x07\x01\x07\x03\x1b\x03\x1b\x03\x03\x01\x17\x03\x05\x05\x07\x01\x07\x03\x03\x03\x1f\x05\x07\x01O\x03\x15\x03\x1d\t\x06\x01\x03\x03\x07#\x0f!\x05\x07\x01\x07\x03/\x03\x1b\x03\x03\x01\x17\x03\x05\x05\x07\x01\x07\x03\x11\x03)\x05\x07\x01\x85\x031\x03'\t\x
06\x01\x03\x11\x07-\x11+\x03\x03\x87-\x03\x05\x19\x07\x93\x8b\x03\x03\x05%1\x03\x03\x031\x03\x01\x03\x03\x03\r\x03\x01\x03\x03\x03\r\x03\x01\x03\x03\x03\r\x03\x01\x03\x03\x033\x03\x01\x13\x07\x03\x99\x035\x0f579;=3/\x07\x07\x03E\x03\x03\x03?\x07\x07\x03G\x03\x01\x03?\x07\x07\x03I\x03\x13\x03?\x03\x03\x03K\x03\x01\x05\x07\x03\x07\x03\x01\x03G\r\x07\x03M\x03\x19\x05CI\x05\x07\x03\x07\x03\x1b\x03K\x03\x03\x03\x17\x03\x05\x05\x07\x03\x07\x03\x03\x03O\x05\x07\x03O\x03\x15\x03M\t\x06\x03\x03\x03\x07SAQ\x1b\x07\x0b\x02\x02\x03\x03\x03%\x11\x04\x0f\x05UW\x0f\x11\x0bW\x05\x03\x15+\x03\x03\x0f\x0b\x03[#\x03\x0f\x03\x03\x0b_\x03\x01\x05\x07'\x07\x03\x0f\x03\x05\x15\x06'\x03\x0f\x05\x03\x07\x0b\x03ec\x03\x0f\r\x07ki\x03\x15\x05\t\x0b\x03\x03\x0b-\x03\x05\x05\x07o\x07\x03\x03\x03\x0f\t\x06s\x03\x03\x07\r\x11\x01\x11\x04\x0b\x03\x13\x06\x03\x01\x05\x01\x00\xc6\x18\x81\x0f\x1d\x1d\x11\x0f\x0b\t\t\x03\x0b!Y\x87##%_=\x85\x87W\xb3K\x9bM\x9b\xd2\x02\x1b\x1f/!!)#\x1f\x19+\x1b\x1f\x83\x1f\x15\x1d\x15+\x13\r\r\x11\x0f\x17\x0f\x1f\x15\x11\x17\x11\x15+)\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00broadcast_in_dim_v1\x00get_tuple_element_v1\x00select_v1\x00iota_v1\x00compare_v1\x00func_v1\x00return_v1\x00custom_call_v1\x00add_v1\x00reshape_v1\x00pad_v1\x00call_v1\x00value\x00index\x00sym_name\x00broadcast_dimensions\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py\x00iota_dimension\x00compare_type\x00comparison_direction\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00jit__lambda_\x00jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) resource_env=None donated_invars=(False,) name=triu keep_unused=False inline=False]\x00jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=0]\x00jit()/jit(main)/jit(triu)/add\x00jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=1]\x00jit()/jit(main)/jit(triu)/ge\x00jit()/jit(main)/jit(triu)/broadcast_in_dim[shape=(3, 3) broadcast_dimensions=()]\x00jit()/jit(main)/jit(triu)/select_n\x00jit()/jit(main)/iota[dtype=float64 shape=(9,) dimension=0]\x00jit()/jit(main)/reshape[new_sizes=(3, 3) dimensions=None]\x00jit()/jit(main)/geqrf\x00jit()/jit(main)/qr[full_matrices=True]\x00edge_padding_high\x00edge_padding_low\x00interior_padding\x00jit()/jit(main)/pad[padding_config=((0, 0, 0), (0, 0, 0))]\x00jit()/jit(main)/householder_product\x00jax.result_info\x00triu\x00\x00[0]\x00[1]\x00main\x00public\x00private\x00lapack_dgeqrf\x00lapack_dorgqr\x00callee\x00", - xla_call_module_version=4, -) # End paste - - -# Pasted from the test output (see back_compat_test.py module docstring) -data_2023_03_17["c64"] = dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['lapack_cgeqrf', 'lapack_cungqr'], - serialized_date=datetime.date(2023, 3, 17), - inputs=(), - expected_outputs=(array([[ 0. 
+0.j, 0.91287076+0.j, 0.4082487 +0.j], - [-0.44721356-0.j, 0.36514866+0.j, -0.8164965 +0.j], - [-0.8944271 -0.j, -0.18257445+0.j, 0.40824816+0.j]], - dtype=complex64), array([[-6.7082043e+00+0.j, -8.0498438e+00+0.j, -9.3914852e+00+0.j], - [ 0.0000000e+00+0.j, 1.0954441e+00+0.j, 2.1908894e+00+0.j], - [ 0.0000000e+00+0.j, 0.0000000e+00+0.j, 7.1525574e-07+0.j]], - dtype=complex64)), - mlir_module_text=r""" -module @jit__lambda_ { - func.func public @main() -> (tensor<3x3xcomplex> {jax.result_info = "[0]"}, tensor<3x3xcomplex> {jax.result_info = "[1]"}) { - %0 = stablehlo.iota dim = 0 : tensor<9xcomplex> - %1 = stablehlo.reshape %0 : (tensor<9xcomplex>) -> tensor<3x3xcomplex> - %2 = stablehlo.constant dense<1> : tensor - %3 = stablehlo.constant dense<3> : tensor - %4 = stablehlo.constant dense<3> : tensor - %5 = stablehlo.constant dense<96> : tensor - %6 = stablehlo.custom_call @lapack_cgeqrf(%2, %3, %4, %5, %1) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>, dense<0> : tensor<1xindex>]} : (tensor, tensor, tensor, tensor, tensor<3x3xcomplex>) -> tuple>, tensor<3xcomplex>, tensor, tensor<96xcomplex>> - %7 = stablehlo.get_tuple_element %6[0] : (tuple>, tensor<3xcomplex>, tensor, tensor<96xcomplex>>) -> tensor<3x3xcomplex> - %8 = stablehlo.get_tuple_element %6[1] : (tuple>, tensor<3xcomplex>, tensor, tensor<96xcomplex>>) -> tensor<3xcomplex> - %9 = stablehlo.get_tuple_element %6[2] : (tuple>, tensor<3xcomplex>, tensor, tensor<96xcomplex>>) -> tensor - %10 = stablehlo.get_tuple_element %6[3] : (tuple>, tensor<3xcomplex>, tensor, tensor<96xcomplex>>) -> tensor<96xcomplex> - %11 = stablehlo.constant dense<0> : tensor - %12 = stablehlo.broadcast_in_dim %11, dims = [] : (tensor) -> tensor - %13 = stablehlo.compare EQ, %9, %12, SIGNED : (tensor, tensor) -> tensor - %14 = stablehlo.broadcast_in_dim %13, dims = [] : (tensor) -> tensor<1x1xi1> - %15 = stablehlo.constant dense<(0x7FC00000,0x7FC00000)> : tensor> - %16 = stablehlo.broadcast_in_dim %15, dims = [] : (tensor>) -> tensor<3x3xcomplex> - %17 = stablehlo.broadcast_in_dim %14, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<3x3xi1> - %18 = stablehlo.select %17, %7, %16 : tensor<3x3xi1>, tensor<3x3xcomplex> - %19 = stablehlo.broadcast_in_dim %13, dims = [] : (tensor) -> tensor<1xi1> - %20 = stablehlo.constant dense<(0x7FC00000,0x7FC00000)> : tensor> - %21 = stablehlo.broadcast_in_dim %20, dims = [] : (tensor>) -> tensor<3xcomplex> - %22 = stablehlo.broadcast_in_dim %19, dims = [0] : (tensor<1xi1>) -> tensor<3xi1> - %23 = stablehlo.select %22, %8, %21 : tensor<3xi1>, tensor<3xcomplex> - %24 = stablehlo.constant dense<(0.000000e+00,0.000000e+00)> : tensor> - %25 = stablehlo.pad %18, %24, low = [0, 0], high = [0, 0], interior = [0, 0] : (tensor<3x3xcomplex>, tensor>) -> tensor<3x3xcomplex> - %26 = stablehlo.constant dense<1> : tensor - %27 = stablehlo.constant dense<3> : tensor - %28 = stablehlo.constant dense<3> : tensor - %29 = stablehlo.constant dense<3> : tensor - %30 = stablehlo.constant dense<96> : tensor - %31 = stablehlo.custom_call @lapack_cungqr(%26, %27, %28, %29, %30, %25, %23) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : 
tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>, dense<0> : tensor<1xindex>]} : (tensor, tensor, tensor, tensor, tensor, tensor<3x3xcomplex>, tensor<3xcomplex>) -> tuple>, tensor, tensor<96xcomplex>> - %32 = stablehlo.get_tuple_element %31[0] : (tuple>, tensor, tensor<96xcomplex>>) -> tensor<3x3xcomplex> - %33 = stablehlo.get_tuple_element %31[1] : (tuple>, tensor, tensor<96xcomplex>>) -> tensor - %34 = stablehlo.get_tuple_element %31[2] : (tuple>, tensor, tensor<96xcomplex>>) -> tensor<96xcomplex> - %35 = stablehlo.constant dense<0> : tensor - %36 = stablehlo.broadcast_in_dim %35, dims = [] : (tensor) -> tensor - %37 = stablehlo.compare EQ, %33, %36, SIGNED : (tensor, tensor) -> tensor - %38 = stablehlo.broadcast_in_dim %37, dims = [] : (tensor) -> tensor<1x1xi1> - %39 = stablehlo.constant dense<(0x7FC00000,0x7FC00000)> : tensor> - %40 = stablehlo.broadcast_in_dim %39, dims = [] : (tensor>) -> tensor<3x3xcomplex> - %41 = stablehlo.broadcast_in_dim %38, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<3x3xi1> - %42 = stablehlo.select %41, %32, %40 : tensor<3x3xi1>, tensor<3x3xcomplex> - %43 = call @triu(%18) : (tensor<3x3xcomplex>) -> tensor<3x3xcomplex> - return %42, %43 : tensor<3x3xcomplex>, tensor<3x3xcomplex> - } - func.func private @triu(%arg0: tensor<3x3xcomplex>) -> tensor<3x3xcomplex> { - %0 = stablehlo.iota dim = 0 : tensor<3x3xi32> - %1 = stablehlo.constant dense<-1> : tensor - %2 = stablehlo.broadcast_in_dim %1, dims = [] : (tensor) -> tensor<3x3xi32> - %3 = stablehlo.add %0, %2 : tensor<3x3xi32> - %4 = stablehlo.iota dim = 1 : tensor<3x3xi32> - %5 = stablehlo.compare GE, %3, %4, SIGNED : (tensor<3x3xi32>, tensor<3x3xi32>) -> tensor<3x3xi1> - %6 = stablehlo.constant dense<(0.000000e+00,0.000000e+00)> : tensor> - %7 = stablehlo.broadcast_in_dim %6, dims = [] : (tensor>) -> tensor<3x3xcomplex> - %8 = stablehlo.select %5, %7, %arg0 : tensor<3x3xi1>, tensor<3x3xcomplex> - return %8 : tensor<3x3xcomplex> - } -} -""", - 
mlir_module_serialized=b"ML\xefR\x03MLIRxxx-trunk\x00\x01+\x05\x01\x05\x01\x03\x05\x03\x1b\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x1d\x1f\x03\xa6\x02\n\x029\x01\x9b\x0f\x0f\x17\x13\x0b\x0f\x13\x07\x0b\x0b\x0b\x13\x0b\x0b\x0b\x0b\x0b\x13\x0b\x0f\x0b\x0b\x13\x17\x13\x13\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x13\x13\x13\x1b\x13\x13\x0b33\x0b\x0f\x0b\x13\x0b\x13\x0f\x0b\x1b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0bK\x13\x13\x0f\x0b#\x0b\x0b\x0b\x0f\x0b\x0bK\x03g\x0fO/\x0b/\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x0b\x13\x13\x0b\x13\x0b\x0b\x0b\x0b\x0b\x0b\x0f\x1f\x0f\x0f\x0b/\x1f\x1f\x1f\x0b\x1f\x0f\x17\x1b\x0f\x0f\x0f\x0f\x1f\x0b/O/\x0b'\x0f\x17\x17\x01\x05\x17\x0b\x039\x0f\x17\x0f\x07\x0b\x07\x07\x17\x13\x17\x17\x07\x0f\x17\x13\x17\x07\x17\x13\x13\x1b\x13\x13\x13\x13\x13\x13\x17\x02\xd6\t\x1d\x7f\x05\x1d\x97\x05\x17!\xee\x05\x01\x03\x03\x15\xcd\x05!\x1dY\x05\x03\x03\t\xd7\x1f\x05#\x05%\x05'\x03\x03\t\xf1\x05)\x05+\x05-\x05/\x051\x03\x03%\xc9\x053\x1da\x05\x055\x057\x03\x03\t\xd3\x17!\xea\x05\x01\x03\x03\t\xd5\x03\x03\t\xd9\x059\x05;\x05=\x05?\x05A\x05C\x05E\x05G\x03\x03\x11\xe5\x03\x03\x11\xe7\x03\x03\x11\xe9\x03\x03\t\xed\x03\x05)\xab+\xef\x03\x03\x15\xf3\x03\x03\x13S\x05I\x03\x0b\x19\xa1\x1b\xb3\x1d\xb5\x13\xbf\x1f\xc1\x03\x0b\x19\xa7\x1b\xc5\x1d\xa7\x13\xa9\x1f\xc7\x05K\x1d]\x05\x05M\x03\x03\t\xcb\x05O\x03\x03%\xcf\x1dg\x05\x05Q\x03\x05)\xab+\xd1\x1dm\x05\x05S\x1dq\x05\x05U\x1du\x05\x05W\x1dy/\x05Y\x1d}/\x05[\x05]\x03\x115\xad7\xaf9\xdb;\xa1=\xb1?\xddA\xdfC\xe3\x03\x03\x11\xeb\x03\x03\x15\xf5\x1d\x89\x05\x05_\x03\x07\x8d\xa3\x8f\xa3\x91\xa3\x05a\x05c\x05e\x1d\x95\x05\x05g\x05i\x03\x115\xad7\xaf9\xf7;\xa1=\xb1?\xf9A\xfbC\xff\x1f+\x01\x1f-!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f/\x11\x00\x00\x00\x00\x00\x00\x00\x00\x03\x01\x1f\x1d\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1dk\x03\x03\xc3\x1dm\t\x07\x0b\x05\x1do\x05\x01#\x1f\x03\x05\xb7\xbb\r\x03\xa5\xb9\x1dq\r\x03\xa5\xbd\x1ds\x1du\x1dw\r\x01##\x1dy\x13\x0b\x01\x1f\x01\t\xff\xff\xff\xff\x1f%\x01\x13\x0b\x05\x07\x05\x1f\x05\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x01\t\x01\x00\x00\x00\x1f\x01\t\x03\x00\x00\x00\x1f\x01\t`\x00\x00\x00\x1d{\x03\x0b\x9b\x9b\x9b\x9b\x9d\x03\x03\xe1\x15\x03\x01\x11\x01\x03\t\x9d\x9f\x9b\x9f\x13\x07\x01\x13\x07\x05\x13\x07\t\x13\x07\r\x1f\x01\t\x00\x00\x00\x00\x07\x01\x1f\x05\x11\x00\x00\xc0\x7f\x00\x00\xc0\x7f\x1f\x1d!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f5\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1d}\x03\x0f\x9b\x9b\x9b\x9b\x9b\x9d\x9f\x03\x03\xfd\x15\x03\x01\x15\x01\x03\x07\x9d\x9b\x9f\x03\x03\x06\x02\xa9\x05\x7f)\x01\x07)\x05\r\r\t)\x01\t\x1b\x03!\x1d\x01)\x05\r\r\x07)\x03\r\t)\x03\x02\x03\t)\x05\r\r\r\x13)\x01\r)\x05\x05\x05\r)\x03\t\x0b\x11\x01\x05\x03\x03\t\x11\x03\x03\x03\x03)\x03\x01\x0b)\x03%\t/\t\x03\x11\x01\x13)\x03\x01\x17)\x03\t\x17)\x03\x05\x17)\x03\x05\r)\x03\r\r)\x03\x05\x0b/\x07\x03\x01\x13\x04\xe6\x06\x05\x01\x11\x0fQ\x07\x03\x01\t\x0f\x11\x0fU\x05\x03Y\xb5\x0b\x03w#\x03'\x17\x06{\x03\x03\x03\x01\x03\x03\x011\x03\x01\x03\x03\x01\r\x03\x01\x03\x03\x01\r\x03\x01\x03\x03\x013\x03\x01\x13\x07\x01\x81\x03)\x0b\x05\x07\t\x0b\x03\x07\x07\x01E\x03\x03\x03\r\x07\x07\x01G\x03\x11\x03\r\x07\x07\x01I\x03\x01\x03\r\x07\x07\x01\x83\x03\x13\x03\r\x03\x03\x01K\x03\x01\x05\x07\x01\x07\x03\x01\x03\x17\r\x07\x01M\x03\x19\x05\x13\x19\x05\x07\x01\x07\x03\x1b\x03\x1b\x03\x03\x01\x17\x03\x05\x05\x07\x01\x07\x03\x03\x03\x1f\x05\x07\x01O\x03\x15\x03\x1d\t\x06\x01\x03\x03\x07#\x0f!\x05\x07\x01\x07\x031\x03\x1b\x03\x03\x01\x17\x03\x05\x05\x07\x01\x07\x03\x11\x03)\x05\x07\x01\x85\x033\x
03'\t\x06\x01\x03\x11\x07-\x11+\x03\x03\x87-\x03\x05\x19\x07\x93\x8b\x03\x03\x05%1\x03\x03\x031\x03\x01\x03\x03\x03\r\x03\x01\x03\x03\x03\r\x03\x01\x03\x03\x03\r\x03\x01\x03\x03\x033\x03\x01\x13\x07\x03\x99\x037\x0f579;=3/\x07\x07\x03E\x03\x03\x03?\x07\x07\x03G\x03\x01\x03?\x07\x07\x03I\x03\x13\x03?\x03\x03\x03K\x03\x01\x05\x07\x03\x07\x03\x01\x03G\r\x07\x03M\x03\x19\x05CI\x05\x07\x03\x07\x03\x1b\x03K\x03\x03\x03\x17\x03\x05\x05\x07\x03\x07\x03\x03\x03O\x05\x07\x03O\x03\x15\x03M\t\x06\x03\x03\x03\x07SAQ\x1b\x07\x0b\x02\x02\x03\x03\x03%\x11\x04\x0f\x05UW\x0f\x11\x0bW\x05\x03\x15+\x03\x03\x0f\x0b\x03[#\x03\x0f\x03\x03\x0b_\x03\x01\x05\x07'\x07\x03\x0f\x03\x05\x15\x06'\x03\x0f\x05\x03\x07\x0b\x03ec\x03\x0f\r\x07ki\x03\x15\x05\t\x0b\x03\x03\x0b-\x03\x05\x05\x07o\x07\x03\x03\x03\x0f\t\x06s\x03\x03\x07\r\x11\x01\x11\x04\x0b\x03\x13\x06\x03\x01\x05\x01\x00\xce\x18\x81\x0f\x1d\x1d\x11\x0f\x0b\t\t\x03\x0b!Y\x87##%_=\x85\x8bW\xb3K\x9bM\x9b\xd2\x02\x1b\x1f/!!)#\x1f\x19+\x1b\x1f\x83\x1f\x15\x1d\x15+\x13\r\r\x11\x0f\x17\x0f\x1f\x15\x11\x17\x11\x15+)\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00broadcast_in_dim_v1\x00get_tuple_element_v1\x00select_v1\x00iota_v1\x00compare_v1\x00func_v1\x00return_v1\x00custom_call_v1\x00add_v1\x00reshape_v1\x00pad_v1\x00call_v1\x00value\x00index\x00sym_name\x00broadcast_dimensions\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py\x00iota_dimension\x00compare_type\x00comparison_direction\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00jit__lambda_\x00jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) resource_env=None donated_invars=(False,) name=triu keep_unused=False inline=False]\x00jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=0]\x00jit()/jit(main)/jit(triu)/add\x00jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=1]\x00jit()/jit(main)/jit(triu)/ge\x00jit()/jit(main)/jit(triu)/broadcast_in_dim[shape=(3, 3) broadcast_dimensions=()]\x00jit()/jit(main)/jit(triu)/select_n\x00jit()/jit(main)/iota[dtype=complex64 shape=(9,) dimension=0]\x00jit()/jit(main)/reshape[new_sizes=(3, 3) dimensions=None]\x00jit()/jit(main)/geqrf\x00jit()/jit(main)/qr[full_matrices=True]\x00edge_padding_high\x00edge_padding_low\x00interior_padding\x00jit()/jit(main)/pad[padding_config=((0, 0, 0), (0, 0, 0))]\x00jit()/jit(main)/householder_product\x00jax.result_info\x00triu\x00\x00[0]\x00[1]\x00main\x00public\x00private\x00lapack_cgeqrf\x00lapack_cungqr\x00callee\x00", - xla_call_module_version=4, -) # End paste - - -# Pasted from the test output (see back_compat_test.py module docstring) -data_2023_03_17["c128"] = dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['lapack_zgeqrf', 'lapack_zungqr'], - serialized_date=datetime.date(2023, 3, 17), + custom_call_targets=['lapack_zgeqrf_ffi', 'lapack_zungqr_ffi'], + serialized_date=datetime.date(2025, 4, 2), inputs=(), expected_outputs=(array([[ 0. 
+0.j, 0.9128709291752773 +0.j, 0.40824829046386235+0.j], @@ -283,531 +37,199 @@ [ 0.0000000000000000e+00+0.j, 0.0000000000000000e+00+0.j, -8.8817841970012523e-16+0.j]])), mlir_module_text=r""" -module @jit__lambda_ { - func.func public @main() -> (tensor<3x3xcomplex> {jax.result_info = "[0]"}, tensor<3x3xcomplex> {jax.result_info = "[1]"}) { - %0 = stablehlo.iota dim = 0 : tensor<9xcomplex> - %1 = stablehlo.reshape %0 : (tensor<9xcomplex>) -> tensor<3x3xcomplex> - %2 = stablehlo.constant dense<1> : tensor - %3 = stablehlo.constant dense<3> : tensor - %4 = stablehlo.constant dense<3> : tensor - %5 = stablehlo.constant dense<96> : tensor - %6 = stablehlo.custom_call @lapack_zgeqrf(%2, %3, %4, %5, %1) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>, dense<0> : tensor<1xindex>]} : (tensor, tensor, tensor, tensor, tensor<3x3xcomplex>) -> tuple>, tensor<3xcomplex>, tensor, tensor<96xcomplex>> - %7 = stablehlo.get_tuple_element %6[0] : (tuple>, tensor<3xcomplex>, tensor, tensor<96xcomplex>>) -> tensor<3x3xcomplex> - %8 = stablehlo.get_tuple_element %6[1] : (tuple>, tensor<3xcomplex>, tensor, tensor<96xcomplex>>) -> tensor<3xcomplex> - %9 = stablehlo.get_tuple_element %6[2] : (tuple>, tensor<3xcomplex>, tensor, tensor<96xcomplex>>) -> tensor - %10 = stablehlo.get_tuple_element %6[3] : (tuple>, tensor<3xcomplex>, tensor, tensor<96xcomplex>>) -> tensor<96xcomplex> - %11 = stablehlo.constant dense<0> : tensor - %12 = stablehlo.broadcast_in_dim %11, dims = [] : (tensor) -> tensor - %13 = stablehlo.compare EQ, %9, %12, SIGNED : (tensor, tensor) -> tensor - %14 = stablehlo.broadcast_in_dim %13, dims = [] : (tensor) -> tensor<1x1xi1> - %15 = stablehlo.constant dense<(0x7FF8000000000000,0x7FF8000000000000)> : tensor> - %16 = stablehlo.broadcast_in_dim %15, dims = [] : (tensor>) -> tensor<3x3xcomplex> - %17 = stablehlo.broadcast_in_dim %14, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<3x3xi1> - %18 = stablehlo.select %17, %7, %16 : tensor<3x3xi1>, tensor<3x3xcomplex> - %19 = stablehlo.broadcast_in_dim %13, dims = [] : (tensor) -> tensor<1xi1> - %20 = stablehlo.constant dense<(0x7FF8000000000000,0x7FF8000000000000)> : tensor> - %21 = stablehlo.broadcast_in_dim %20, dims = [] : (tensor>) -> tensor<3xcomplex> - %22 = stablehlo.broadcast_in_dim %19, dims = [0] : (tensor<1xi1>) -> tensor<3xi1> - %23 = stablehlo.select %22, %8, %21 : tensor<3xi1>, tensor<3xcomplex> - %24 = stablehlo.constant dense<(0.000000e+00,0.000000e+00)> : tensor> - %25 = stablehlo.pad %18, %24, low = [0, 0], high = [0, 0], interior = [0, 0] : (tensor<3x3xcomplex>, tensor>) -> tensor<3x3xcomplex> - %26 = stablehlo.constant dense<1> : tensor - %27 = stablehlo.constant dense<3> : tensor - %28 = stablehlo.constant dense<3> : tensor - %29 = stablehlo.constant dense<3> : tensor - %30 = stablehlo.constant dense<96> : tensor - %31 = stablehlo.custom_call @lapack_zungqr(%26, %27, %28, %29, %30, %25, %23) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, 
dense<> : tensor<0xindex>, dense<0> : tensor<1xindex>]} : (tensor, tensor, tensor, tensor, tensor, tensor<3x3xcomplex>, tensor<3xcomplex>) -> tuple>, tensor, tensor<96xcomplex>> - %32 = stablehlo.get_tuple_element %31[0] : (tuple>, tensor, tensor<96xcomplex>>) -> tensor<3x3xcomplex> - %33 = stablehlo.get_tuple_element %31[1] : (tuple>, tensor, tensor<96xcomplex>>) -> tensor - %34 = stablehlo.get_tuple_element %31[2] : (tuple>, tensor, tensor<96xcomplex>>) -> tensor<96xcomplex> - %35 = stablehlo.constant dense<0> : tensor - %36 = stablehlo.broadcast_in_dim %35, dims = [] : (tensor) -> tensor - %37 = stablehlo.compare EQ, %33, %36, SIGNED : (tensor, tensor) -> tensor - %38 = stablehlo.broadcast_in_dim %37, dims = [] : (tensor) -> tensor<1x1xi1> - %39 = stablehlo.constant dense<(0x7FF8000000000000,0x7FF8000000000000)> : tensor> - %40 = stablehlo.broadcast_in_dim %39, dims = [] : (tensor>) -> tensor<3x3xcomplex> - %41 = stablehlo.broadcast_in_dim %38, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<3x3xi1> - %42 = stablehlo.select %41, %32, %40 : tensor<3x3xi1>, tensor<3x3xcomplex> - %43 = call @triu(%18) : (tensor<3x3xcomplex>) -> tensor<3x3xcomplex> - return %42, %43 : tensor<3x3xcomplex>, tensor<3x3xcomplex> - } - func.func private @triu(%arg0: tensor<3x3xcomplex>) -> tensor<3x3xcomplex> { - %0 = stablehlo.iota dim = 0 : tensor<3x3xi32> - %1 = stablehlo.constant dense<-1> : tensor - %2 = stablehlo.broadcast_in_dim %1, dims = [] : (tensor) -> tensor<3x3xi32> - %3 = stablehlo.add %0, %2 : tensor<3x3xi32> - %4 = stablehlo.iota dim = 1 : tensor<3x3xi32> - %5 = stablehlo.compare GE, %3, %4, SIGNED : (tensor<3x3xi32>, tensor<3x3xi32>) -> tensor<3x3xi1> - %6 = stablehlo.constant dense<(0.000000e+00,0.000000e+00)> : tensor> - %7 = stablehlo.broadcast_in_dim %6, dims = [] : (tensor>) -> tensor<3x3xcomplex> - %8 = stablehlo.select %5, %7, %arg0 : tensor<3x3xi1>, tensor<3x3xcomplex> - return %8 : tensor<3x3xcomplex> - } -} -""", - 
mlir_module_serialized=b"ML\xefR\x03MLIRxxx-trunk\x00\x01+\x05\x01\x05\x01\x03\x05\x03\x1b\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x1d\x1f\x03\xa6\x02\n\x029\x01\x9b\x0f\x0f\x17\x13\x0b\x0f\x13\x07\x0b\x0b\x0b\x13\x0b\x0b\x0b\x0b\x0b\x13\x0b\x0f\x0b\x0b\x13\x17\x13\x13\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x13\x13\x13\x1b\x13\x13\x0b33\x0b\x0f\x0b\x13\x0b\x13\x0f\x0b\x1b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0bK\x13\x13\x0f\x0b#\x0b\x0b\x0b\x0f\x0b\x0bK\x03g\x0fO/\x0b/\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x0b\x13\x13\x0b\x13\x0b\x0b\x0b\x0b\x0b\x0b\x0f\x1f\x0f\x0f\x0bO\x1f\x1f\x1f\x0b\x1f\x0f\x17\x1b\x0f\x0f\x0f\x0f\x1f\x0bOO/\x0b'\x0f\x17\x17\x01\x05\x17\x0b\x039\x0f\x17\x0f\x07\x0b\x07\x07\x17\x13\x17\x17\x07\x0f\x17\x13\x17\x07\x17\x13\x13\x1b\x13\x13\x13\x13\x13\x13\x17\x02\x16\n\x1d\x7f\x05\x1d\x97\x05\x17!\xee\x05\x01\x03\x03\x15\xcd\x05!\x1dY\x05\x03\x03\t\xd7\x1f\x05#\x05%\x05'\x03\x03\t\xf1\x05)\x05+\x05-\x05/\x051\x03\x03%\xc9\x053\x1da\x05\x055\x057\x03\x03\t\xd3\x17!\xea\x05\x01\x03\x03\t\xd5\x03\x03\t\xd9\x059\x05;\x05=\x05?\x05A\x05C\x05E\x05G\x03\x03\x11\xe5\x03\x03\x11\xe7\x03\x03\x11\xe9\x03\x03\t\xed\x03\x05)\xab+\xef\x03\x03\x15\xf3\x03\x03\x13S\x05I\x03\x0b\x19\xa1\x1b\xb3\x1d\xb5\x13\xbf\x1f\xc1\x03\x0b\x19\xa7\x1b\xc5\x1d\xa7\x13\xa9\x1f\xc7\x05K\x1d]\x05\x05M\x03\x03\t\xcb\x05O\x03\x03%\xcf\x1dg\x05\x05Q\x03\x05)\xab+\xd1\x1dm\x05\x05S\x1dq\x05\x05U\x1du\x05\x05W\x1dy/\x05Y\x1d}/\x05[\x05]\x03\x115\xad7\xaf9\xdb;\xa1=\xb1?\xddA\xdfC\xe3\x03\x03\x11\xeb\x03\x03\x15\xf5\x1d\x89\x05\x05_\x03\x07\x8d\xa3\x8f\xa3\x91\xa3\x05a\x05c\x05e\x1d\x95\x05\x05g\x05i\x03\x115\xad7\xaf9\xf7;\xa1=\xb1?\xf9A\xfbC\xff\x1f+\x01\x1f-!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f/\x11\x00\x00\x00\x00\x00\x00\x00\x00\x03\x01\x1f\x1d\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1dk\x03\x03\xc3\x1dm\t\x07\x0b\x05\x1do\x05\x01#\x1f\x03\x05\xb7\xbb\r\x03\xa5\xb9\x1dq\r\x03\xa5\xbd\x1ds\x1du\x1dw\r\x01##\x1dy\x13\x0b\x01\x1f\x01\t\xff\xff\xff\xff\x1f%\x01\x13\x0b\x05\x07\x05\x1f\x05!\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x01\t\x01\x00\x00\x00\x1f\x01\t\x03\x00\x00\x00\x1f\x01\t`\x00\x00\x00\x1d{\x03\x0b\x9b\x9b\x9b\x9b\x9d\x03\x03\xe1\x15\x03\x01\x11\x01\x03\t\x9d\x9f\x9b\x9f\x13\x07\x01\x13\x07\x05\x13\x07\t\x13\x07\r\x1f\x01\t\x00\x00\x00\x00\x07\x01\x1f\x05!\x00\x00\x00\x00\x00\x00\xf8\x7f\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f\x1d!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f5\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1d}\x03\x0f\x9b\x9b\x9b\x9b\x9b\x9d\x9f\x03\x03\xfd\x15\x03\x01\x15\x01\x03\x07\x9d\x9b\x9f\x03\x03\x06\x02\xa9\x05\x7f)\x01\x07)\x05\r\r\t)\x01\t\x1b\x03!\x1d\x01)\x05\r\r\x07)\x03\r\t)\x03\x02\x03\t)\x05\r\r\r\x13)\x01\r)\x05\x05\x05\r)\x03\t\x0b\x11\x01\x05\x03\x03\x0b\x11\x03\x03\x03\x03)\x03\x01\x0b)\x03%\t/\t\x03\x11\x01\x13)\x03\x01\x17)\x03\t\x17)\x03\x05\x17)\x03\x05\r)\x03\r\r)\x03\x05\x0b/\x07\x03\x01\x13\x04\xe6\x06\x05\x01\x11\x0fQ\x07\x03\x01\t\x0f\x11\x0fU\x05\x03Y\xb5\x0b\x03w#\x03'\x17\x06{\x03\x03\x03\x01\x03\x03\x011\x03\x01\x03\x03\x01\r\x03\x01\x03\x03\x01\r\x03\x01\x03\x03\x013\x03\x01\x13\x07\x01\x81\x03)\x0b\x05\x07\t\x0b\x03\x07\x07\x01E\x03\x03\x03\r\x07\x07\x01G\x03\x11\x03\r\x07\x07\x01I\x03\x01\x03\r\x07\x07\x01\x83\x03\x13\x03\r\x03\x03\x01K\x03\x01\x05\x07\x01\x07\x03\x01\x03\x17\r\x07\x01M\x03\x19\x05\x13\x19\x05\x07\x01\x07\x03\x1b\x03\x1b\x03\x03\x01\x17\x03\x05\x05\x07\x01\x07\x03\x03\x03\x1f\x05\x07\x01O\x03\x15\x03\x1d\t\x06\x01\x03\x03\x07#\x0f!\x05\x07\x01\x07\x031\x03\x1b\x03\x03\x01\x17
\x03\x05\x05\x07\x01\x07\x03\x11\x03)\x05\x07\x01\x85\x033\x03'\t\x06\x01\x03\x11\x07-\x11+\x03\x03\x87-\x03\x05\x19\x07\x93\x8b\x03\x03\x05%1\x03\x03\x031\x03\x01\x03\x03\x03\r\x03\x01\x03\x03\x03\r\x03\x01\x03\x03\x03\r\x03\x01\x03\x03\x033\x03\x01\x13\x07\x03\x99\x037\x0f579;=3/\x07\x07\x03E\x03\x03\x03?\x07\x07\x03G\x03\x01\x03?\x07\x07\x03I\x03\x13\x03?\x03\x03\x03K\x03\x01\x05\x07\x03\x07\x03\x01\x03G\r\x07\x03M\x03\x19\x05CI\x05\x07\x03\x07\x03\x1b\x03K\x03\x03\x03\x17\x03\x05\x05\x07\x03\x07\x03\x03\x03O\x05\x07\x03O\x03\x15\x03M\t\x06\x03\x03\x03\x07SAQ\x1b\x07\x0b\x02\x02\x03\x03\x03%\x11\x04\x0f\x05UW\x0f\x11\x0bW\x05\x03\x15+\x03\x03\x0f\x0b\x03[#\x03\x0f\x03\x03\x0b_\x03\x01\x05\x07'\x07\x03\x0f\x03\x05\x15\x06'\x03\x0f\x05\x03\x07\x0b\x03ec\x03\x0f\r\x07ki\x03\x15\x05\t\x0b\x03\x03\x0b-\x03\x05\x05\x07o\x07\x03\x03\x03\x0f\t\x06s\x03\x03\x07\r\x11\x01\x11\x04\x0b\x03\x13\x06\x03\x01\x05\x01\x00\xd2\x18\x81\x0f\x1d\x1d\x11\x0f\x0b\t\t\x03\x0b!Y\x87##%_=\x85\x8dW\xb3K\x9bM\x9b\xd2\x02\x1b\x1f/!!)#\x1f\x19+\x1b\x1f\x83\x1f\x15\x1d\x15+\x13\r\r\x11\x0f\x17\x0f\x1f\x15\x11\x17\x11\x15+)\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00broadcast_in_dim_v1\x00get_tuple_element_v1\x00select_v1\x00iota_v1\x00compare_v1\x00func_v1\x00return_v1\x00custom_call_v1\x00add_v1\x00reshape_v1\x00pad_v1\x00call_v1\x00value\x00index\x00sym_name\x00broadcast_dimensions\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py\x00iota_dimension\x00compare_type\x00comparison_direction\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00jit__lambda_\x00jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) resource_env=None donated_invars=(False,) name=triu keep_unused=False inline=False]\x00jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=0]\x00jit()/jit(main)/jit(triu)/add\x00jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=1]\x00jit()/jit(main)/jit(triu)/ge\x00jit()/jit(main)/jit(triu)/broadcast_in_dim[shape=(3, 3) broadcast_dimensions=()]\x00jit()/jit(main)/jit(triu)/select_n\x00jit()/jit(main)/iota[dtype=complex128 shape=(9,) dimension=0]\x00jit()/jit(main)/reshape[new_sizes=(3, 3) dimensions=None]\x00jit()/jit(main)/geqrf\x00jit()/jit(main)/qr[full_matrices=True]\x00edge_padding_high\x00edge_padding_low\x00interior_padding\x00jit()/jit(main)/pad[padding_config=((0, 0, 0), (0, 0, 0))]\x00jit()/jit(main)/householder_product\x00jax.result_info\x00triu\x00\x00[0]\x00[1]\x00main\x00public\x00private\x00lapack_zgeqrf\x00lapack_zungqr\x00callee\x00", - xla_call_module_version=4, -) # End paste - - -data_2024_08_22 = {} - - -# Pasted from the test output (see export_back_compat_test_util.py module docstring) -data_2024_08_22['c128'] = dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['lapack_zgeqrf_ffi', 'lapack_zungqr_ffi'], - serialized_date=datetime.date(2024, 8, 22), - inputs=(), - expected_outputs=( - array([ - [0.0 + 0.0j, 0.9128709291752773 + 0.0j, 0.40824829046386235 + 0.0j], - [ - -0.447213595499958 - 0.0j, - 0.3651483716701102 + 0.0j, - -0.8164965809277263 + 0.0j, - ], - [ - -0.894427190999916 - 0.0j, - -0.1825741858350548 + 0.0j, - 0.40824829046386324 + 0.0j, - ], - ]), - array([ - [ - -6.7082039324993694e00 + 0.0j, - -8.0498447189992444e00 + 0.0j, - -9.3914855054991175e00 + 0.0j, - ], - [ - 0.0000000000000000e00 + 
0.0j, - 1.0954451150103341e00 + 0.0j, - 2.1908902300206665e00 + 0.0j, - ], - [ - 0.0000000000000000e00 + 0.0j, - 0.0000000000000000e00 + 0.0j, - -8.8817841970012523e-16 + 0.0j, - ], - ]), - ), - mlir_module_text=r""" -#loc3 = loc("third_party/py/jax/tests/export_back_compat_test.py":364:11) -#loc10 = loc("jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) in_layouts=(None,) out_layouts=(None,) resource_env=None donated_invars=(False,) name=triu keep_unused=False inline=False]"(#loc3)) module @jit__lambda_ attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main() -> (tensor<3x3xcomplex> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<3x3xcomplex> {jax.result_info = "[1]", mhlo.layout_mode = "default"}) { + func.func public @main() -> (tensor<3x3xcomplex> {jax.result_info = "result[0]"}, tensor<3x3xcomplex> {jax.result_info = "result[1]"}) { + %c = stablehlo.constant dense<-1> : tensor loc(#loc) + %cst = stablehlo.constant dense<(0.000000e+00,0.000000e+00)> : tensor> loc(#loc) %0 = stablehlo.iota dim = 0 : tensor<9xcomplex> loc(#loc4) %1 = stablehlo.reshape %0 : (tensor<9xcomplex>) -> tensor<3x3xcomplex> loc(#loc5) - %c = stablehlo.constant dense<3> : tensor loc(#loc6) - %c_0 = stablehlo.constant dense<3> : tensor loc(#loc6) - %2:2 = stablehlo.custom_call @lapack_zgeqrf_ffi(%1) {mhlo.backend_config = {}, operand_layouts = [dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>]} : (tensor<3x3xcomplex>) -> (tensor<3x3xcomplex>, tensor<3xcomplex>) loc(#loc6) - %cst = stablehlo.constant dense<(0.000000e+00,0.000000e+00)> : tensor> loc(#loc7) - %3 = stablehlo.pad %2#0, %cst, low = [0, 0], high = [0, 0], interior = [0, 0] : (tensor<3x3xcomplex>, tensor>) -> tensor<3x3xcomplex> loc(#loc8) - %c_1 = stablehlo.constant dense<3> : tensor loc(#loc9) - %c_2 = stablehlo.constant dense<3> : tensor loc(#loc9) - %c_3 = stablehlo.constant dense<3> : tensor loc(#loc9) - %c_4 = stablehlo.constant dense<1> : tensor loc(#loc9) - %c_5 = stablehlo.constant dense<3> : tensor loc(#loc9) - %c_6 = stablehlo.constant dense<3> : tensor loc(#loc9) - %c_7 = stablehlo.constant dense<3> : tensor loc(#loc9) - %c_8 = stablehlo.constant dense<96> : tensor loc(#loc9) - %4:3 = stablehlo.custom_call @lapack_zungqr(%c_4, %c_5, %c_6, %c_7, %c_8, %3, %2#1) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>, dense<0> : tensor<1xindex>]} : (tensor, tensor, tensor, tensor, tensor, tensor<3x3xcomplex>, tensor<3xcomplex>) -> (tensor<3x3xcomplex>, tensor, tensor<96xcomplex>) loc(#loc9) - %c_9 = stablehlo.constant dense<0> : tensor loc(#loc9) - %5 = stablehlo.broadcast_in_dim %c_9, dims = [] : (tensor) -> tensor loc(#loc9) - %6 = stablehlo.compare EQ, %4#1, %5, SIGNED : (tensor, tensor) -> tensor loc(#loc9) - %7 = stablehlo.broadcast_in_dim %6, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc9) - %cst_10 = stablehlo.constant dense<(0x7FF8000000000000,0x7FF8000000000000)> : tensor> loc(#loc9) - %8 = stablehlo.broadcast_in_dim %cst_10, dims = [] : (tensor>) -> 
tensor<3x3xcomplex> loc(#loc9) - %9 = stablehlo.broadcast_in_dim %7, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<3x3xi1> loc(#loc9) - %10 = stablehlo.select %9, %4#0, %8 : tensor<3x3xi1>, tensor<3x3xcomplex> loc(#loc9) - %11 = call @triu(%2#0) : (tensor<3x3xcomplex>) -> tensor<3x3xcomplex> loc(#loc10) - return %10, %11 : tensor<3x3xcomplex>, tensor<3x3xcomplex> loc(#loc) + %2:2 = stablehlo.custom_call @lapack_zgeqrf_ffi(%1) {mhlo.backend_config = {}, mhlo.frontend_attributes = {num_batch_dims = "0"}, operand_layouts = [dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>]} : (tensor<3x3xcomplex>) -> (tensor<3x3xcomplex>, tensor<3xcomplex>) loc(#loc6) + %3 = stablehlo.pad %2#0, %cst, low = [0, 0], high = [0, 0], interior = [0, 0] : (tensor<3x3xcomplex>, tensor>) -> tensor<3x3xcomplex> loc(#loc7) + %4 = stablehlo.custom_call @lapack_zungqr_ffi(%3, %2#1) {mhlo.backend_config = {}, mhlo.frontend_attributes = {num_batch_dims = "0"}, operand_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>]} : (tensor<3x3xcomplex>, tensor<3xcomplex>) -> tensor<3x3xcomplex> loc(#loc8) + %5 = stablehlo.iota dim = 0 : tensor<3x3xi32> loc(#loc9) + %6 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor) -> tensor<3x3xi32> loc(#loc10) + %7 = stablehlo.add %5, %6 : tensor<3x3xi32> loc(#loc10) + %8 = stablehlo.iota dim = 1 : tensor<3x3xi32> loc(#loc9) + %9 = stablehlo.compare GE, %7, %8, SIGNED : (tensor<3x3xi32>, tensor<3x3xi32>) -> tensor<3x3xi1> loc(#loc11) + %10 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor>) -> tensor<3x3xcomplex> loc(#loc12) + %11 = stablehlo.select %9, %10, %2#0 : tensor<3x3xi1>, tensor<3x3xcomplex> loc(#loc13) + return %4, %11 : tensor<3x3xcomplex>, tensor<3x3xcomplex> loc(#loc) } loc(#loc) - func.func private @triu(%arg0: tensor<3x3xcomplex> {mhlo.layout_mode = "default"} loc("jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) in_layouts=(None,) out_layouts=(None,) resource_env=None donated_invars=(False,) name=triu keep_unused=False inline=False]"(#loc3))) -> (tensor<3x3xcomplex> {mhlo.layout_mode = "default"}) { - %0 = stablehlo.iota dim = 0 : tensor<3x3xi32> loc(#loc11) - %c = stablehlo.constant dense<-1> : tensor loc(#loc10) - %1 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor) -> tensor<3x3xi32> loc(#loc12) - %2 = stablehlo.add %0, %1 : tensor<3x3xi32> loc(#loc12) - %3 = stablehlo.iota dim = 1 : tensor<3x3xi32> loc(#loc13) - %4 = stablehlo.compare GE, %2, %3, SIGNED : (tensor<3x3xi32>, tensor<3x3xi32>) -> tensor<3x3xi1> loc(#loc14) - %cst = stablehlo.constant dense<(0.000000e+00,0.000000e+00)> : tensor> loc(#loc10) - %5 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor>) -> tensor<3x3xcomplex> loc(#loc15) - %6 = stablehlo.select %4, %5, %arg0 : tensor<3x3xi1>, tensor<3x3xcomplex> loc(#loc16) - return %6 : tensor<3x3xcomplex> loc(#loc10) - } loc(#loc10) } loc(#loc) #loc = loc(unknown) -#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":363:26) -#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":363:14) -#loc4 = loc("jit()/jit(main)/iota[dtype=complex128 shape=(9,) dimension=0]"(#loc1)) -#loc5 = loc("jit()/jit(main)/reshape[new_sizes=(3, 3) dimensions=None]"(#loc2)) +#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":411:26) +#loc2 = 
loc("third_party/py/jax/tests/export_back_compat_test.py":411:14) +#loc3 = loc("third_party/py/jax/tests/export_back_compat_test.py":412:11) +#loc4 = loc("jit()/jit(main)/iota"(#loc1)) +#loc5 = loc("jit()/jit(main)/reshape"(#loc2)) #loc6 = loc("jit()/jit(main)/geqrf"(#loc3)) -#loc7 = loc("jit()/jit(main)/qr[full_matrices=True]"(#loc3)) -#loc8 = loc("jit()/jit(main)/pad[padding_config=((0, 0, 0), (0, 0, 0))]"(#loc3)) -#loc9 = loc("jit()/jit(main)/householder_product"(#loc3)) -#loc11 = loc("jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=0]"(#loc3)) -#loc12 = loc("jit()/jit(main)/jit(triu)/add"(#loc3)) -#loc13 = loc("jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=1]"(#loc3)) -#loc14 = loc("jit()/jit(main)/jit(triu)/ge"(#loc3)) -#loc15 = loc("jit()/jit(main)/jit(triu)/broadcast_in_dim[shape=(3, 3) broadcast_dimensions=()]"(#loc3)) -#loc16 = loc("jit()/jit(main)/jit(triu)/select_n"(#loc3)) +#loc7 = loc("jit()/jit(main)/pad"(#loc3)) +#loc8 = loc("jit()/jit(main)/householder_product"(#loc3)) +#loc9 = loc("jit()/jit(main)/iota"(#loc3)) +#loc10 = loc("jit()/jit(main)/add"(#loc3)) +#loc11 = loc("jit()/jit(main)/ge"(#loc3)) +#loc12 = loc("jit()/jit(main)/broadcast_in_dim"(#loc3)) +#loc13 = loc("jit()/jit(main)/select_n"(#loc3)) """, - mlir_module_serialized=( - b"ML\xefR\x01StableHLO_v0.9.0\x00\x01)\x05\x01\x03\x01\x03\x05\x03\x19\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x1d\x03\xae\x02\x12\x025\x01\x9d\x0f\x17\x0b\x0f\x13\x13\x0b\x07\x0b\x0f\x13\x0f\x0b\x0b\x0b\x0b\x13\x0b\x0b\x0f\x0b\x0b\x13\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b+\x0b\x0f\x0b\x0b\x0b33\x0b\x0f\x0b\x13\x0b\x13\x0f\x0b\x1b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x17\x0f\x0b\x17\x0bS\x0b\x0f\x0b#\x0b\x0b\x0b\x0f\x0b\x0b\x13\x13K\x13\x1b\x13\x03g\x0fO\x0b\x0b\x0b//\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x13\x1b\x0b\x1b\x0b\x0b\x0b\x13\x0b\x0b\x0f\x1f\x0f\x0f\x0bO/\x0b\x0b\x0b\x0f\x0f\x17\x13\x1f\x1f\x1f\x0b\x0b'\x0f\x17\x17\x1f\x0bOO\x01\x07\x17\x17\x0b\x01\x05\x0b\x0f\x031\x17\x0f\x0f\x0b\x07\x0f\x17\x07\x07\x07\x17\x13\x17\x07\x17\x13\x13\x13\x13\x13\x17\x13\x0f\x17\x02\xf2\t\x1d\x8f\x03\x17\x11\xb2\x05\x17\x05\x1f\x1dO\x03\x03\x03%\xd1\x03\x03\x05\xd9\x05!\x1f\x05#\x1dy\x03\x03\x03\x05\xeb\x11\x03\x05\x05%\x05'\x05)\x05+\x03\x03#\xcd\x05-\x05/\x1dW\x03\x051\x053\x03\x03\x05\xd7\x055\x057\x059\x05;\x05=\x05?\x05A\x05C\x03\tACE\x17G\x17\rI\x05E\x11\x01\x00\x05G\x05I\x05K\x03\x0b\x19\xa1\x1b\xb7\x1d\xb9\r\xc3\x1f\xc5\x03\x0b\x19\xad\x1b\xc9\x1d\xad\r\xaf\x1f\xcb\x05M\x1dS\x03\x05O\x03\x03\x05\xcf\x05Q\x03\x03#\xd3\x1d]\x03\x05S\x03\x05)\xb1+\xd5\x1dc\x03\x05U\x1dg\x03\x05W\x1dk\x03\x05Y\x1doq\x05[\x17\x11\xae\x055\x1duw\x05]\x17\x11\xae\x05\x1d\x05_\x03\x13/\xdb1\xb33\xdd5\xa17\xb5}\xdf9\xe1;\xe3=\xe7\x05a\x1d\x81\x03\x05c\x03\x07\x85\xa9\x87\xa9\x89\xa9\x05e\x05g\x05i\x1d\x8d\x03\x05k\x05m\x03\x03\x05\xe9\x03\x03\x05\xed\x03\x11/\xef1\xb33\xf15\xa17\xb59\xf3;\xf5=\xf9\x03\x03\x05\xfb\x03\x05)\xb1+\xfd\x03\x03\x05\xff\x1f/\x01\x1f)!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x03\x01\x1do\x1dq\x1f+\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x1b\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1ds\x03\x03\xc7\x1du\t\x07\x1dw\x05\x01#\x1d\x03\x05\xbb\xbf\r\x05\xab\xbd\xa3\xa5\x1dy\r\x05\xab\xc1\xa3\xa5\x1d{\x1d}\x1d\x7f\r\x03\xa3\xa5#!\x1d\x81\x13\r\x01\x1f\x07\t\xff\xff\xff\xff\x1f#\x01\x13\r\x05\x07\x05\x1f\x0f!\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f\t\x11\x03\x00\x00\x00\x00\x00\x00\x00\x0b\x03\x1d\x83\r\x01\x03\x03\x9f\x03\x03\xe5\x15\x03\x01\x01\x01\x03\x05\x9f\xa7\x1f\x07\t\x01\x00\x00\x0
0\x1f\x07\t\x03\x00\x00\x00\x1f\x07\t`\x00\x00\x00\x0b\x05\x1d\x85\x03\x0f\x9d\x9d\x9d\x9d\x9d\x9f\xa7\x03\x03\xf7\x15\x03\x01\x15\x01\x03\x07\x9f\x9d\xa7\x1f\x07\t\x00\x00\x00\x00\x07\x01\x1f\x0f!\x00\x00\x00\x00\x00\x00\xf8\x7f\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f\x1b!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x03\x03%\x02\x02\x03\x03\x0e\x02\xaf\x05\x87\x01\t\x01\x02\x02)\x05\r\r\x0b)\x01\x17)\x01\r\x03\x1f\x1d)\x01\x0b)\x05\r\r\x17\x01\x13\x1b)\x05\r\r\x13)\x03\t\r\x11\x01\x05\x05\x05\x0b\x11\x03\x05\x03\x05)\x03\x01\r)\x03%\x0b)\x03\r\x0b)\x03\t\x15)\x03\x05\x15)\x03\x02\x03\x0b)\x03\x01\x15)\x01\x13)\x05\x05\x05\x13\x04\x8a\x04\x05\x01\x11\x0f?\x07\x03\x01\t\t\x11\x0fK\x07\x039i\x07\x03m!\x03%\x15\x06s\x03\x05\x03\x01\x03\x03\x13\x0b\x03\t\x03\x03\x13\x0b\x03\t\x11\x07\x13{\x05\x05'\x03\x03\x03\x03\x7f-\x03\x0f\x17\x07\x8b\x83\x03\x05\x05\t\r\x03\x03\x01\x0b\x03\t\x03\x03\x01\x0b\x03\t\x03\x03\x01\x0b\x03\t\x03\x03\x01\x91\x03\x07\x03\x03\x01\x15\x03\x07\x03\x03\x01\x15\x03\x07\x03\x03\x01\x15\x03\x07\x03\x03\x01\x93\x03\x07\x11\x07\x01\x95\x07\x05\x07-\x0f\x17\x19\x1b\x1d\x1f\x0f\x0b\x03\x03\x01\x97\x03\x07\x05\x07\x01\t\x03\x07\x03'\x0b\x07\x01\x99\x031\x05#)\x05\x07\x01\t\x033\x03+\x03\x03\x01\x9b\x03\x0f\x05\x07\x01\t\x03\x05\x03/\x05\x07\x01\x06\x02\x03\x19\x03-\r\x06\x01\x03\x05\x073!1\x19\x07\x07\n\x02\x03\x05\x03\t\x0f\x04\x0f\x0557\t\x11\x07M\x07\x03\x15+\x03\x05\x07\x07\x03Q!\x03\x11\x03\x03\x07U\x03\x07\x05\x07'\t\x03\x11\x03\x05\x13\x06'\x03\x11\x05\x03\x07\x07\x03[Y\x03\x11\x0b\x07a_\x03\x19\x05\t\x0b\x03\x03\x07-\x03\x0f\x05\x07e\t\x03\x05\x03\x0f\r\x06i\x03\x05\x07\r\x11\x01\x0f\x04\x07\x03\x13\x06\x03\x01\x05\x01\x00\xaa\x1a\x89\x0f\x1d%\x11\x0f\x0b\t\t\x03\x0b!\x11#Y\x87##%_)=\x85\x8dW\xb3K\x9bM\x9bn\x03\x1b%)9\x1f/!!)#\x1f\x19+\x1b+\x1f\x1f\x15\x1d\x15i\x13\r\x11\x0f\x17\x0f\x1f\x15\x15\x17\x11\x11)\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00broadcast_in_dim_v1\x00iota_v1\x00func_v1\x00compare_v1\x00select_v1\x00return_v1\x00custom_call_v1\x00add_v1\x00reshape_v1\x00pad_v1\x00call_v1\x00value\x00sym_name\x00third_party/py/jax/tests/export_back_compat_test.py\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00broadcast_dimensions\x00compare_type\x00comparison_direction\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit__lambda_\x00jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,)" - b' out_shardings=(UnspecifiedValue,) in_layouts=(None,)' - b' out_layouts=(None,) resource_env=None donated_invars=(False,)' - b' name=triu keep_unused=False' - b' inline=False]\x00jit()/jit(main)/jit(triu)/iota[dtype=int32' - b' shape=(3, 3)' - b' dimension=0]\x00jit()/jit(main)/jit(triu)/add\x00jit()/jit(main)/jit(triu)/iota[dtype=int32' - b' shape=(3, 3)' - b' dimension=1]\x00jit()/jit(main)/jit(triu)/ge\x00jit()/jit(main)/jit(triu)/broadcast_in_dim[shape=(3,' - b' 3) broadcast_dimensions=()]\x00jit()/jit(main)/jit(triu)/select_n\x00jit()/jit(main)/iota[dtype=complex128' - b' shape=(9,)' - b' dimension=0]\x00jit()/jit(main)/reshape[new_sizes=(3, 3)' - b' dimensions=None]\x00jit()/jit(main)/geqrf\x00mhlo.backend_config\x00jit()/jit(main)/qr[full_matrices=True]\x00edge_padding_high\x00edge_padding_low\x00interior_padding\x00jit()/jit(main)/pad[padding_config=((0,' - b' 0, 0), (0, 0,' - b' 
0))]\x00jit()/jit(main)/householder_product\x00mhlo.layout_mode\x00default\x00jax.result_info\x00triu\x00\x00[0]\x00[1]\x00main\x00public\x00private\x00lapack_zgeqrf_ffi\x00lapack_zungqr\x00callee\x00' - ), + mlir_module_serialized=b'ML\xefR\rStableHLO_v1.9.3\x00\x01)\x05\x01\x05\x19\x01\x03\x0b\x03\x17\x0f\x13\x17\x1b\x1f#\'+/37\x03\xc7\x8b)\x01E\x17\x07\x0b\x0f\x0b\x1b\x0f\x0f#\x0b\x0f\x0b\x0b\x0b\x0f\x17\x0f\x0b\x17\x0b\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x03G\x0b/\x0b\x0f\x0b\x0b\x0b\x0fO\x13\x0f\x0b\x13\x13\x0b\x13\x0b\x0b\x0b\x1fO\x0b\x13\x0b\x0b\x0b\x0f\x17/\x0b\x0f\x13\x0f\x0b\x0b\x01\x05\x0b\x0f\x03%\x17\x0b\x07\x17\x0f\x07\x0f\x07\x17\x07\x13\x13\x13\x13\x13\x13\x17\x07\x02\xd2\x04\x17\x05r\x06\x17\x1f\x05\x1d\x11\x03\x05\x05\x1f\x03\x05\'o)q\x1d\t\x01\x1d7\x01\x03\x07\x13\x15\x17\x07\x19\x07\x05!\x11\x01\x00\x05#\x05%\x05\'\x1d\t\x1f\x17\x05n\x065\x1d#%\x05)\x17\x05n\x06\x1d\x05+\x05-\x1d-\x01\x05/\x1d1\x01\x051\x1d5\x01\x053\x055\x1d;\x01\x057\x1d?\x01\x059\x1dC\x01\x05;\x03\x01\x1f!\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1d=\x13\t\x01\x0b\x03\x1d?\x05\x01\x03\x03U\x1f\x1d!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x03\x05U}\x1f#\x01#\x15\x03\x05_c\r\x03Ia\x1dA\r\x03Ie\x1dC\x1dE\x1dG\x1f\r\t\xff\xff\xff\xff\x1f\x11!\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\r\x01\r\x03su\x1dI\x1dK\x1dM\x03\x03{\x15\x03\x01\x01\x01\x1f\x1f\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1dO\x03\x03\x83\x15\x01\x01\x01\x13\t\x05\t\x07\x07\x05\x01\t\x01\x02\x02)\x05\r\r\x07\x03\x17\x1d)\x05\r\r\x0f)\x01\x0f\x1b)\x01\x07\x13\x11\x01\x05\x05\x05\x0b)\x03%\x07)\x03\r\x07)\x03\t\x13)\x03\x05\x13)\x03\t\t)\x03\x01\t)\x05\r\r\'\x01\x04"\x02\x05\x01Q\x03\x11\x01\x07\x04\xff\x03\x01\x05\x0bP\x03\x03\x07\x04\xeb\x03\x1f=\x05B\x03\x05\x03\r\x05B\x03\x07\x03\x11\x03B\x1d\t\x03\x19\r\x06!\x03\x05\x03\x05\x07G+\x0b\x0b\x05\x05\x1b\x03\x07\x0fF/\r\x03\x05\x05\t\x03\x07G3\x0b\x0f\x03\x05\x05\r\x0b\x03B\r\t\x03\x0b\tF\x0f\x11\x03\x0b\x03\x01\x11\x06\x0f\x03\x0b\x05\x11\x13\x03B\r\x13\x03\x0b\x13F9\x15\x03%\x05\x15\x17\tF=\x11\x03\x05\x03\x03\x15\x06A\x03\x05\x07\x19\x1b\t\x17\x04\x03\x05\x0f\x1d\x06\x03\x01\x05\x01\x00\xba\x0bQ%%\x05\x1f\x0f\x0b\x15\x15\x03!CS79Y9=3)A\x1b%)9;i\x15\x15\x17\x0f\x0f\x17\x11)\x1f\x19\x11\x0f\x0b\x11builtin\x00vhlo\x00module\x00iota_v1\x00constant_v1\x00custom_call_v1\x00broadcast_in_dim_v1\x00func_v1\x00reshape_v1\x00pad_v1\x00add_v1\x00compare_v1\x00select_v1\x00return_v1\x00third_party/py/jax/tests/export_back_compat_test.py\x00jit()/jit(main)/iota\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit__lambda_\x00jit()/jit(main)/reshape\x00mhlo.backend_config\x00mhlo.frontend_attributes\x00jit()/jit(main)/geqrf\x00jit()/jit(main)/pad\x00jit()/jit(main)/householder_product\x00jit()/jit(main)/add\x00jit()/jit(main)/ge\x00jit()/jit(main)/broadcast_in_dim\x00jit()/jit(main)/select_n\x00jax.result_info\x00\x00result[0]\x00result[1]\x00main\x00public\x00num_batch_dims\x000\x00lapack_zgeqrf_ffi\x00lapack_zungqr_ffi\x00\x08[\x17\x057\x01\x0bE[]gi\x03k\x03m\x03K\x11MOwEQSyW\x07GGG\x11MO\x7fEQW\x81S\x03Y\x03\x85\x05\x87\x89', xla_call_module_version=9, nr_devices=1, ) # End paste - -# Pasted from the test output (see export_back_compat_test_util.py module docstring) -data_2024_08_22['c64'] = dict( +data_2025_04_02['c64'] = dict( testdata_version=1, platform='cpu', custom_call_targets=['lapack_cgeqrf_ffi', 'lapack_cungqr_ffi'], - serialized_date=datetime.date(2024, 8, 22), + serialized_date=datetime.date(2025, 4, 
2), inputs=(), - expected_outputs=( - array( - [ - [0.0 + 0.0j, 0.91287076 + 0.0j, 0.4082487 + 0.0j], - [-0.44721356 - 0.0j, 0.36514866 + 0.0j, -0.8164965 + 0.0j], - [-0.8944271 - 0.0j, -0.18257445 + 0.0j, 0.40824816 + 0.0j], - ], - dtype=complex64, - ), - array( - [ - [ - -6.7082043e00 + 0.0j, - -8.0498438e00 + 0.0j, - -9.3914852e00 + 0.0j, - ], - [0.0000000e00 + 0.0j, 1.0954441e00 + 0.0j, 2.1908894e00 + 0.0j], - [ - 0.0000000e00 + 0.0j, - 0.0000000e00 + 0.0j, - 7.1525574e-07 + 0.0j, - ], - ], - dtype=complex64, - ), - ), + expected_outputs=(array([[ 0. +0.j, 0.91287076+0.j, 0.4082487 +0.j], + [-0.44721356-0.j, 0.36514866+0.j, -0.8164965 +0.j], + [-0.8944271 -0.j, -0.18257445+0.j, 0.40824816+0.j]], + dtype=complex64), array([[-6.7082043e+00+0.j, -8.0498438e+00+0.j, -9.3914852e+00+0.j], + [ 0.0000000e+00+0.j, 1.0954441e+00+0.j, 2.1908894e+00+0.j], + [ 0.0000000e+00+0.j, 0.0000000e+00+0.j, 7.1525574e-07+0.j]], + dtype=complex64)), mlir_module_text=r""" -#loc3 = loc("third_party/py/jax/tests/export_back_compat_test.py":364:11) -#loc10 = loc("jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) in_layouts=(None,) out_layouts=(None,) resource_env=None donated_invars=(False,) name=triu keep_unused=False inline=False]"(#loc3)) module @jit__lambda_ attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main() -> (tensor<3x3xcomplex> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<3x3xcomplex> {jax.result_info = "[1]", mhlo.layout_mode = "default"}) { + func.func public @main() -> (tensor<3x3xcomplex> {jax.result_info = "result[0]"}, tensor<3x3xcomplex> {jax.result_info = "result[1]"}) { + %c = stablehlo.constant dense<-1> : tensor loc(#loc) + %cst = stablehlo.constant dense<(0.000000e+00,0.000000e+00)> : tensor> loc(#loc) %0 = stablehlo.iota dim = 0 : tensor<9xcomplex> loc(#loc4) %1 = stablehlo.reshape %0 : (tensor<9xcomplex>) -> tensor<3x3xcomplex> loc(#loc5) - %c = stablehlo.constant dense<3> : tensor loc(#loc6) - %c_0 = stablehlo.constant dense<3> : tensor loc(#loc6) - %2:2 = stablehlo.custom_call @lapack_cgeqrf_ffi(%1) {mhlo.backend_config = {}, operand_layouts = [dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>]} : (tensor<3x3xcomplex>) -> (tensor<3x3xcomplex>, tensor<3xcomplex>) loc(#loc6) - %cst = stablehlo.constant dense<(0.000000e+00,0.000000e+00)> : tensor> loc(#loc7) - %3 = stablehlo.pad %2#0, %cst, low = [0, 0], high = [0, 0], interior = [0, 0] : (tensor<3x3xcomplex>, tensor>) -> tensor<3x3xcomplex> loc(#loc8) - %c_1 = stablehlo.constant dense<3> : tensor loc(#loc9) - %c_2 = stablehlo.constant dense<3> : tensor loc(#loc9) - %c_3 = stablehlo.constant dense<3> : tensor loc(#loc9) - %c_4 = stablehlo.constant dense<1> : tensor loc(#loc9) - %c_5 = stablehlo.constant dense<3> : tensor loc(#loc9) - %c_6 = stablehlo.constant dense<3> : tensor loc(#loc9) - %c_7 = stablehlo.constant dense<3> : tensor loc(#loc9) - %c_8 = stablehlo.constant dense<96> : tensor loc(#loc9) - %4:3 = stablehlo.custom_call @lapack_cungqr(%c_4, %c_5, %c_6, %c_7, %c_8, %3, %2#1) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], 
result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>, dense<0> : tensor<1xindex>]} : (tensor, tensor, tensor, tensor, tensor, tensor<3x3xcomplex>, tensor<3xcomplex>) -> (tensor<3x3xcomplex>, tensor, tensor<96xcomplex>) loc(#loc9) - %c_9 = stablehlo.constant dense<0> : tensor loc(#loc9) - %5 = stablehlo.broadcast_in_dim %c_9, dims = [] : (tensor) -> tensor loc(#loc9) - %6 = stablehlo.compare EQ, %4#1, %5, SIGNED : (tensor, tensor) -> tensor loc(#loc9) - %7 = stablehlo.broadcast_in_dim %6, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc9) - %cst_10 = stablehlo.constant dense<(0x7FC00000,0x7FC00000)> : tensor> loc(#loc9) - %8 = stablehlo.broadcast_in_dim %cst_10, dims = [] : (tensor>) -> tensor<3x3xcomplex> loc(#loc9) - %9 = stablehlo.broadcast_in_dim %7, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<3x3xi1> loc(#loc9) - %10 = stablehlo.select %9, %4#0, %8 : tensor<3x3xi1>, tensor<3x3xcomplex> loc(#loc9) - %11 = call @triu(%2#0) : (tensor<3x3xcomplex>) -> tensor<3x3xcomplex> loc(#loc10) - return %10, %11 : tensor<3x3xcomplex>, tensor<3x3xcomplex> loc(#loc) + %2:2 = stablehlo.custom_call @lapack_cgeqrf_ffi(%1) {mhlo.backend_config = {}, mhlo.frontend_attributes = {num_batch_dims = "0"}, operand_layouts = [dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>]} : (tensor<3x3xcomplex>) -> (tensor<3x3xcomplex>, tensor<3xcomplex>) loc(#loc6) + %3 = stablehlo.pad %2#0, %cst, low = [0, 0], high = [0, 0], interior = [0, 0] : (tensor<3x3xcomplex>, tensor>) -> tensor<3x3xcomplex> loc(#loc7) + %4 = stablehlo.custom_call @lapack_cungqr_ffi(%3, %2#1) {mhlo.backend_config = {}, mhlo.frontend_attributes = {num_batch_dims = "0"}, operand_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>]} : (tensor<3x3xcomplex>, tensor<3xcomplex>) -> tensor<3x3xcomplex> loc(#loc8) + %5 = stablehlo.iota dim = 0 : tensor<3x3xi32> loc(#loc9) + %6 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor) -> tensor<3x3xi32> loc(#loc10) + %7 = stablehlo.add %5, %6 : tensor<3x3xi32> loc(#loc10) + %8 = stablehlo.iota dim = 1 : tensor<3x3xi32> loc(#loc9) + %9 = stablehlo.compare GE, %7, %8, SIGNED : (tensor<3x3xi32>, tensor<3x3xi32>) -> tensor<3x3xi1> loc(#loc11) + %10 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor>) -> tensor<3x3xcomplex> loc(#loc12) + %11 = stablehlo.select %9, %10, %2#0 : tensor<3x3xi1>, tensor<3x3xcomplex> loc(#loc13) + return %4, %11 : tensor<3x3xcomplex>, tensor<3x3xcomplex> loc(#loc) } loc(#loc) - func.func private @triu(%arg0: tensor<3x3xcomplex> {mhlo.layout_mode = "default"} loc("jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) in_layouts=(None,) out_layouts=(None,) resource_env=None donated_invars=(False,) name=triu keep_unused=False inline=False]"(#loc3))) -> (tensor<3x3xcomplex> {mhlo.layout_mode = "default"}) { - %0 = stablehlo.iota dim = 0 : tensor<3x3xi32> loc(#loc11) - %c = stablehlo.constant dense<-1> : tensor loc(#loc10) - %1 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor) -> tensor<3x3xi32> loc(#loc12) - %2 = stablehlo.add %0, %1 : tensor<3x3xi32> loc(#loc12) - %3 = stablehlo.iota dim = 1 : tensor<3x3xi32> loc(#loc13) - %4 = stablehlo.compare GE, %2, %3, SIGNED : (tensor<3x3xi32>, tensor<3x3xi32>) -> tensor<3x3xi1> loc(#loc14) - %cst = stablehlo.constant 
dense<(0.000000e+00,0.000000e+00)> : tensor> loc(#loc10) - %5 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor>) -> tensor<3x3xcomplex> loc(#loc15) - %6 = stablehlo.select %4, %5, %arg0 : tensor<3x3xi1>, tensor<3x3xcomplex> loc(#loc16) - return %6 : tensor<3x3xcomplex> loc(#loc10) - } loc(#loc10) } loc(#loc) #loc = loc(unknown) -#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":363:26) -#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":363:14) -#loc4 = loc("jit()/jit(main)/iota[dtype=complex64 shape=(9,) dimension=0]"(#loc1)) -#loc5 = loc("jit()/jit(main)/reshape[new_sizes=(3, 3) dimensions=None]"(#loc2)) +#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":411:26) +#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":411:14) +#loc3 = loc("third_party/py/jax/tests/export_back_compat_test.py":412:11) +#loc4 = loc("jit()/jit(main)/iota"(#loc1)) +#loc5 = loc("jit()/jit(main)/reshape"(#loc2)) #loc6 = loc("jit()/jit(main)/geqrf"(#loc3)) -#loc7 = loc("jit()/jit(main)/qr[full_matrices=True]"(#loc3)) -#loc8 = loc("jit()/jit(main)/pad[padding_config=((0, 0, 0), (0, 0, 0))]"(#loc3)) -#loc9 = loc("jit()/jit(main)/householder_product"(#loc3)) -#loc11 = loc("jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=0]"(#loc3)) -#loc12 = loc("jit()/jit(main)/jit(triu)/add"(#loc3)) -#loc13 = loc("jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=1]"(#loc3)) -#loc14 = loc("jit()/jit(main)/jit(triu)/ge"(#loc3)) -#loc15 = loc("jit()/jit(main)/jit(triu)/broadcast_in_dim[shape=(3, 3) broadcast_dimensions=()]"(#loc3)) -#loc16 = loc("jit()/jit(main)/jit(triu)/select_n"(#loc3)) +#loc7 = loc("jit()/jit(main)/pad"(#loc3)) +#loc8 = loc("jit()/jit(main)/householder_product"(#loc3)) +#loc9 = loc("jit()/jit(main)/iota"(#loc3)) +#loc10 = loc("jit()/jit(main)/add"(#loc3)) +#loc11 = loc("jit()/jit(main)/ge"(#loc3)) +#loc12 = loc("jit()/jit(main)/broadcast_in_dim"(#loc3)) +#loc13 = loc("jit()/jit(main)/select_n"(#loc3)) """, - mlir_module_serialized=( - 
b"ML\xefR\x01StableHLO_v0.9.0\x00\x01)\x05\x01\x03\x01\x03\x05\x03\x19\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x1d\x03\xae\x02\x12\x025\x01\x9d\x0f\x17\x0b\x0f\x13\x13\x0b\x07\x0b\x0f\x13\x0f\x0b\x0b\x0b\x0b\x13\x0b\x0b\x0f\x0b\x0b\x13\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b+\x0b\x0f\x0b\x0b\x0b33\x0b\x0f\x0b\x13\x0b\x13\x0f\x0b\x1b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x17\x0f\x0b\x17\x0bS\x0b\x0f\x0b#\x0b\x0b\x0b\x0f\x0b\x0b\x13\x13K\x13\x1b\x13\x03g\x0fO\x0b\x0b\x0b//\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x13\x1b\x0b\x1b\x0b\x0b\x0b\x13\x0b\x0b\x0f\x1f\x0f\x0f\x0b//\x0b\x0b\x0b\x0f\x0f\x17\x13\x1f\x1f\x1f\x0b\x0b'\x0f\x17\x17\x1f\x0b/O\x01\x07\x17\x17\x0b\x01\x05\x0b\x0f\x031\x17\x0f\x0f\x0b\x07\x0f\x17\x07\x07\x07\x17\x13\x17\x07\x17\x13\x13\x13\x13\x13\x17\x13\x0f\x17\x02\xb2\t\x1d\x8f\x03\x17\x11\xb2\x05\x17\x05\x1f\x1dO\x03\x03\x03%\xd1\x03\x03\x05\xd9\x05!\x1f\x05#\x1dy\x03\x03\x03\x05\xeb\x11\x03\x05\x05%\x05'\x05)\x05+\x03\x03#\xcd\x05-\x05/\x1dW\x03\x051\x053\x03\x03\x05\xd7\x055\x057\x059\x05;\x05=\x05?\x05A\x05C\x03\tACE\x17G\x17\rI\x05E\x11\x01\x00\x05G\x05I\x05K\x03\x0b\x19\xa1\x1b\xb7\x1d\xb9\r\xc3\x1f\xc5\x03\x0b\x19\xad\x1b\xc9\x1d\xad\r\xaf\x1f\xcb\x05M\x1dS\x03\x05O\x03\x03\x05\xcf\x05Q\x03\x03#\xd3\x1d]\x03\x05S\x03\x05)\xb1+\xd5\x1dc\x03\x05U\x1dg\x03\x05W\x1dk\x03\x05Y\x1doq\x05[\x17\x11\xae\x055\x1duw\x05]\x17\x11\xae\x05\x1d\x05_\x03\x13/\xdb1\xb33\xdd5\xa17\xb5}\xdf9\xe1;\xe3=\xe7\x05a\x1d\x81\x03\x05c\x03\x07\x85\xa9\x87\xa9\x89\xa9\x05e\x05g\x05i\x1d\x8d\x03\x05k\x05m\x03\x03\x05\xe9\x03\x03\x05\xed\x03\x11/\xef1\xb33\xf15\xa17\xb59\xf3;\xf5=\xf9\x03\x03\x05\xfb\x03\x05)\xb1+\xfd\x03\x03\x05\xff\x1f/\x01\x1f)!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x03\x01\x1do\x1dq\x1f+\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x1b\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1ds\x03\x03\xc7\x1du\t\x07\x1dw\x05\x01#\x1d\x03\x05\xbb\xbf\r\x05\xab\xbd\xa3\xa5\x1dy\r\x05\xab\xc1\xa3\xa5\x1d{\x1d}\x1d\x7f\r\x03\xa3\xa5#!\x1d\x81\x13\r\x01\x1f\x07\t\xff\xff\xff\xff\x1f#\x01\x13\r\x05\x07\x05\x1f\x0f\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\t\x11\x03\x00\x00\x00\x00\x00\x00\x00\x0b\x03\x1d\x83\r\x01\x03\x03\x9f\x03\x03\xe5\x15\x03\x01\x01\x01\x03\x05\x9f\xa7\x1f\x07\t\x01\x00\x00\x00\x1f\x07\t\x03\x00\x00\x00\x1f\x07\t`\x00\x00\x00\x0b\x05\x1d\x85\x03\x0f\x9d\x9d\x9d\x9d\x9d\x9f\xa7\x03\x03\xf7\x15\x03\x01\x15\x01\x03\x07\x9f\x9d\xa7\x1f\x07\t\x00\x00\x00\x00\x07\x01\x1f\x0f\x11\x00\x00\xc0\x7f\x00\x00\xc0\x7f\x1f\x1b!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x03\x03%\x02\x02\x03\x03\x0e\x02\xaf\x05\x87\x01\t\x01\x02\x02)\x05\r\r\x0b)\x01\x17)\x01\r\x03\x1f\x1d)\x01\x0b)\x05\r\r\x17\x01\x13\x1b)\x05\r\r\x13)\x03\t\r\x11\x01\x05\x05\x05\t\x11\x03\x05\x03\x05)\x03\x01\r)\x03%\x0b)\x03\r\x0b)\x03\t\x15)\x03\x05\x15)\x03\x02\x03\x0b)\x03\x01\x15)\x01\x13)\x05\x05\x05\x13\x04\x8a\x04\x05\x01\x11\x0f?\x07\x03\x01\t\t\x11\x0fK\x07\x039i\x07\x03m!\x03%\x15\x06s\x03\x05\x03\x01\x03\x03\x13\x0b\x03\t\x03\x03\x13\x0b\x03\t\x11\x07\x13{\x05\x05'\x03\x03\x03\x03\x7f-\x03\x0f\x17\x07\x8b\x83\x03\x05\x05\t\r\x03\x03\x01\x0b\x03\t\x03\x03\x01\x0b\x03\t\x03\x03\x01\x0b\x03\t\x03\x03\x01\x91\x03\x07\x03\x03\x01\x15\x03\x07\x03\x03\x01\x15\x03\x07\x03\x03\x01\x15\x03\x07\x03\x03\x01\x93\x03\x07\x11\x07\x01\x95\x07\x05\x07-\x0f\x17\x19\x1b\x1d\x1f\x0f\x0b\x03\x03\x01\x97\x03\x07\x05\x07\x01\t\x03\x07\x03'\x0b\x07\x01\x99\x031\x05#)\x05\x07\x01\t\x033\x03+\x03\x03\x01\x9b\x03\x0f\x05\x07\x01\t\x03\x05\x03/\x05\x07\x01\x06\x02\x03\x19\x03-\r\x06\x01\x03\x05\x073!1\x19\x07\x07\n\x02\x03\x05\x03
\t\x0f\x04\x0f\x0557\t\x11\x07M\x07\x03\x15+\x03\x05\x07\x07\x03Q!\x03\x11\x03\x03\x07U\x03\x07\x05\x07'\t\x03\x11\x03\x05\x13\x06'\x03\x11\x05\x03\x07\x07\x03[Y\x03\x11\x0b\x07a_\x03\x19\x05\t\x0b\x03\x03\x07-\x03\x0f\x05\x07e\t\x03\x05\x03\x0f\r\x06i\x03\x05\x07\r\x11\x01\x0f\x04\x07\x03\x13\x06\x03\x01\x05\x01\x00\xa6\x1a\x89\x0f\x1d%\x11\x0f\x0b\t\t\x03\x0b!\x11#Y\x87##%_)=\x85\x8bW\xb3K\x9bM\x9bn\x03\x1b%)9\x1f/!!)#\x1f\x19+\x1b+\x1f\x1f\x15\x1d\x15i\x13\r\x11\x0f\x17\x0f\x1f\x15\x15\x17\x11\x11)\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00broadcast_in_dim_v1\x00iota_v1\x00func_v1\x00compare_v1\x00select_v1\x00return_v1\x00custom_call_v1\x00add_v1\x00reshape_v1\x00pad_v1\x00call_v1\x00value\x00sym_name\x00third_party/py/jax/tests/export_back_compat_test.py\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00broadcast_dimensions\x00compare_type\x00comparison_direction\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit__lambda_\x00jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,)" - b' out_shardings=(UnspecifiedValue,) in_layouts=(None,)' - b' out_layouts=(None,) resource_env=None donated_invars=(False,)' - b' name=triu keep_unused=False' - b' inline=False]\x00jit()/jit(main)/jit(triu)/iota[dtype=int32' - b' shape=(3, 3)' - b' dimension=0]\x00jit()/jit(main)/jit(triu)/add\x00jit()/jit(main)/jit(triu)/iota[dtype=int32' - b' shape=(3, 3)' - b' dimension=1]\x00jit()/jit(main)/jit(triu)/ge\x00jit()/jit(main)/jit(triu)/broadcast_in_dim[shape=(3,' - b' 3) broadcast_dimensions=()]\x00jit()/jit(main)/jit(triu)/select_n\x00jit()/jit(main)/iota[dtype=complex64' - b' shape=(9,)' - b' dimension=0]\x00jit()/jit(main)/reshape[new_sizes=(3, 3)' - b' dimensions=None]\x00jit()/jit(main)/geqrf\x00mhlo.backend_config\x00jit()/jit(main)/qr[full_matrices=True]\x00edge_padding_high\x00edge_padding_low\x00interior_padding\x00jit()/jit(main)/pad[padding_config=((0,' - b' 0, 0), (0, 0,' - b' 0))]\x00jit()/jit(main)/householder_product\x00mhlo.layout_mode\x00default\x00jax.result_info\x00triu\x00\x00[0]\x00[1]\x00main\x00public\x00private\x00lapack_cgeqrf_ffi\x00lapack_cungqr\x00callee\x00' - ), + 
mlir_module_serialized=b'ML\xefR\rStableHLO_v1.9.3\x00\x01)\x05\x01\x05\x19\x01\x03\x0b\x03\x17\x0f\x13\x17\x1b\x1f#\'+/37\x03\xc7\x8b)\x01E\x17\x07\x0b\x0f\x0b\x1b\x0f\x0f#\x0b\x0f\x0b\x0b\x0b\x0f\x17\x0f\x0b\x17\x0b\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x03G\x0b/\x0b\x0f\x0b\x0b\x0b\x0fO\x13\x0f\x0b\x13\x13\x0b\x13\x0b\x0b\x0b\x1f/\x0b\x13\x0b\x0b\x0b\x0f\x17/\x0b\x0f\x13\x0f\x0b\x0b\x01\x05\x0b\x0f\x03%\x17\x0b\x07\x17\x0f\x07\x0f\x07\x17\x07\x13\x13\x13\x13\x13\x13\x17\x07\x02\xb2\x04\x17\x05r\x06\x17\x1f\x05\x1d\x11\x03\x05\x05\x1f\x03\x05\'o)q\x1d\t\x01\x1d7\x01\x03\x07\x13\x15\x17\x07\x19\x07\x05!\x11\x01\x00\x05#\x05%\x05\'\x1d\t\x1f\x17\x05n\x065\x1d#%\x05)\x17\x05n\x06\x1d\x05+\x05-\x1d-\x01\x05/\x1d1\x01\x051\x1d5\x01\x053\x055\x1d;\x01\x057\x1d?\x01\x059\x1dC\x01\x05;\x03\x01\x1f!\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1d=\x13\t\x01\x0b\x03\x1d?\x05\x01\x03\x03U\x1f\x1d!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x03\x05U}\x1f#\x01#\x15\x03\x05_c\r\x03Ia\x1dA\r\x03Ie\x1dC\x1dE\x1dG\x1f\r\t\xff\xff\xff\xff\x1f\x11\x11\x00\x00\x00\x00\x00\x00\x00\x00\r\x01\r\x03su\x1dI\x1dK\x1dM\x03\x03{\x15\x03\x01\x01\x01\x1f\x1f\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1dO\x03\x03\x83\x15\x01\x01\x01\x13\t\x05\t\x07\x07\x05\x01\t\x01\x02\x02)\x05\r\r\x07\x03\x17\x1d)\x05\r\r\x0f)\x01\x0f\x1b)\x01\x07\x13\x11\x01\x05\x05\x05\t)\x03%\x07)\x03\r\x07)\x03\t\x13)\x03\x05\x13)\x03\t\t)\x03\x01\t)\x05\r\r\'\x01\x04"\x02\x05\x01Q\x03\x11\x01\x07\x04\xff\x03\x01\x05\x0bP\x03\x03\x07\x04\xeb\x03\x1f=\x05B\x03\x05\x03\r\x05B\x03\x07\x03\x11\x03B\x1d\t\x03\x19\r\x06!\x03\x05\x03\x05\x07G+\x0b\x0b\x05\x05\x1b\x03\x07\x0fF/\r\x03\x05\x05\t\x03\x07G3\x0b\x0f\x03\x05\x05\r\x0b\x03B\r\t\x03\x0b\tF\x0f\x11\x03\x0b\x03\x01\x11\x06\x0f\x03\x0b\x05\x11\x13\x03B\r\x13\x03\x0b\x13F9\x15\x03%\x05\x15\x17\tF=\x11\x03\x05\x03\x03\x15\x06A\x03\x05\x07\x19\x1b\t\x17\x04\x03\x05\x0f\x1d\x06\x03\x01\x05\x01\x00\xba\x0bQ%%\x05\x1f\x0f\x0b\x15\x15\x03!CS79Y9=3)A\x1b%)9;i\x15\x15\x17\x0f\x0f\x17\x11)\x1f\x19\x11\x0f\x0b\x11builtin\x00vhlo\x00module\x00iota_v1\x00constant_v1\x00custom_call_v1\x00broadcast_in_dim_v1\x00func_v1\x00reshape_v1\x00pad_v1\x00add_v1\x00compare_v1\x00select_v1\x00return_v1\x00third_party/py/jax/tests/export_back_compat_test.py\x00jit()/jit(main)/iota\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit__lambda_\x00jit()/jit(main)/reshape\x00mhlo.backend_config\x00mhlo.frontend_attributes\x00jit()/jit(main)/geqrf\x00jit()/jit(main)/pad\x00jit()/jit(main)/householder_product\x00jit()/jit(main)/add\x00jit()/jit(main)/ge\x00jit()/jit(main)/broadcast_in_dim\x00jit()/jit(main)/select_n\x00jax.result_info\x00\x00result[0]\x00result[1]\x00main\x00public\x00num_batch_dims\x000\x00lapack_cgeqrf_ffi\x00lapack_cungqr_ffi\x00\x08[\x17\x057\x01\x0bE[]gi\x03k\x03m\x03K\x11MOwEQSyW\x07GGG\x11MO\x7fEQW\x81S\x03Y\x03\x85\x05\x87\x89', xla_call_module_version=9, nr_devices=1, ) # End paste - -# Pasted from the test output (see export_back_compat_test_util.py module docstring) -data_2024_08_22['f32'] = dict( +data_2025_04_02['f32'] = dict( testdata_version=1, platform='cpu', custom_call_targets=['lapack_sgeqrf_ffi', 'lapack_sorgqr_ffi'], - serialized_date=datetime.date(2024, 8, 22), + serialized_date=datetime.date(2025, 4, 2), inputs=(), - expected_outputs=( - array( - [ - [0.0, 0.91287076, 0.4082487], - [-0.44721356, 0.36514866, -0.8164965], - [-0.8944271, -0.18257445, 0.40824816], - ], - dtype=float32, - ), - array( - [ - [-6.7082043e00, -8.0498438e00, 
-9.3914852e00], - [0.0000000e00, 1.0954441e00, 2.1908894e00], - [0.0000000e00, 0.0000000e00, 7.1525574e-07], - ], - dtype=float32, - ), - ), + expected_outputs=(array([[ 0. , 0.91287076, 0.4082487 ], + [-0.44721356, 0.36514866, -0.8164965 ], + [-0.8944271 , -0.18257445, 0.40824816]], dtype=float32), array([[-6.7082043e+00, -8.0498438e+00, -9.3914852e+00], + [ 0.0000000e+00, 1.0954441e+00, 2.1908894e+00], + [ 0.0000000e+00, 0.0000000e+00, 7.1525574e-07]], dtype=float32)), mlir_module_text=r""" -#loc3 = loc("third_party/py/jax/tests/export_back_compat_test.py":364:11) -#loc10 = loc("jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) in_layouts=(None,) out_layouts=(None,) resource_env=None donated_invars=(False,) name=triu keep_unused=False inline=False]"(#loc3)) module @jit__lambda_ attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main() -> (tensor<3x3xf32> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<3x3xf32> {jax.result_info = "[1]", mhlo.layout_mode = "default"}) { + func.func public @main() -> (tensor<3x3xf32> {jax.result_info = "result[0]"}, tensor<3x3xf32> {jax.result_info = "result[1]"}) { + %c = stablehlo.constant dense<-1> : tensor loc(#loc) + %cst = stablehlo.constant dense<0.000000e+00> : tensor loc(#loc) %0 = stablehlo.iota dim = 0 : tensor<9xf32> loc(#loc4) %1 = stablehlo.reshape %0 : (tensor<9xf32>) -> tensor<3x3xf32> loc(#loc5) - %c = stablehlo.constant dense<3> : tensor loc(#loc6) - %c_0 = stablehlo.constant dense<3> : tensor loc(#loc6) - %2:2 = stablehlo.custom_call @lapack_sgeqrf_ffi(%1) {mhlo.backend_config = {}, operand_layouts = [dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>]} : (tensor<3x3xf32>) -> (tensor<3x3xf32>, tensor<3xf32>) loc(#loc6) - %cst = stablehlo.constant dense<0.000000e+00> : tensor loc(#loc7) - %3 = stablehlo.pad %2#0, %cst, low = [0, 0], high = [0, 0], interior = [0, 0] : (tensor<3x3xf32>, tensor) -> tensor<3x3xf32> loc(#loc8) - %c_1 = stablehlo.constant dense<3> : tensor loc(#loc9) - %c_2 = stablehlo.constant dense<3> : tensor loc(#loc9) - %c_3 = stablehlo.constant dense<3> : tensor loc(#loc9) - %c_4 = stablehlo.constant dense<1> : tensor loc(#loc9) - %c_5 = stablehlo.constant dense<3> : tensor loc(#loc9) - %c_6 = stablehlo.constant dense<3> : tensor loc(#loc9) - %c_7 = stablehlo.constant dense<3> : tensor loc(#loc9) - %c_8 = stablehlo.constant dense<96> : tensor loc(#loc9) - %4:3 = stablehlo.custom_call @lapack_sorgqr(%c_4, %c_5, %c_6, %c_7, %c_8, %3, %2#1) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>, dense<0> : tensor<1xindex>]} : (tensor, tensor, tensor, tensor, tensor, tensor<3x3xf32>, tensor<3xf32>) -> (tensor<3x3xf32>, tensor, tensor<96xf32>) loc(#loc9) - %c_9 = stablehlo.constant dense<0> : tensor loc(#loc9) - %5 = stablehlo.broadcast_in_dim %c_9, dims = [] : (tensor) -> tensor loc(#loc9) - %6 = stablehlo.compare EQ, %4#1, %5, SIGNED : (tensor, tensor) -> tensor loc(#loc9) - %7 = stablehlo.broadcast_in_dim %6, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc9) 
- %cst_10 = stablehlo.constant dense<0x7FC00000> : tensor loc(#loc9) - %8 = stablehlo.broadcast_in_dim %cst_10, dims = [] : (tensor) -> tensor<3x3xf32> loc(#loc9) - %9 = stablehlo.broadcast_in_dim %7, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<3x3xi1> loc(#loc9) - %10 = stablehlo.select %9, %4#0, %8 : tensor<3x3xi1>, tensor<3x3xf32> loc(#loc9) - %11 = call @triu(%2#0) : (tensor<3x3xf32>) -> tensor<3x3xf32> loc(#loc10) - return %10, %11 : tensor<3x3xf32>, tensor<3x3xf32> loc(#loc) + %2:2 = stablehlo.custom_call @lapack_sgeqrf_ffi(%1) {mhlo.backend_config = {}, mhlo.frontend_attributes = {num_batch_dims = "0"}, operand_layouts = [dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>]} : (tensor<3x3xf32>) -> (tensor<3x3xf32>, tensor<3xf32>) loc(#loc6) + %3 = stablehlo.pad %2#0, %cst, low = [0, 0], high = [0, 0], interior = [0, 0] : (tensor<3x3xf32>, tensor) -> tensor<3x3xf32> loc(#loc7) + %4 = stablehlo.custom_call @lapack_sorgqr_ffi(%3, %2#1) {mhlo.backend_config = {}, mhlo.frontend_attributes = {num_batch_dims = "0"}, operand_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>]} : (tensor<3x3xf32>, tensor<3xf32>) -> tensor<3x3xf32> loc(#loc8) + %5 = stablehlo.iota dim = 0 : tensor<3x3xi32> loc(#loc9) + %6 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor) -> tensor<3x3xi32> loc(#loc10) + %7 = stablehlo.add %5, %6 : tensor<3x3xi32> loc(#loc10) + %8 = stablehlo.iota dim = 1 : tensor<3x3xi32> loc(#loc9) + %9 = stablehlo.compare GE, %7, %8, SIGNED : (tensor<3x3xi32>, tensor<3x3xi32>) -> tensor<3x3xi1> loc(#loc11) + %10 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> tensor<3x3xf32> loc(#loc12) + %11 = stablehlo.select %9, %10, %2#0 : tensor<3x3xi1>, tensor<3x3xf32> loc(#loc13) + return %4, %11 : tensor<3x3xf32>, tensor<3x3xf32> loc(#loc) } loc(#loc) - func.func private @triu(%arg0: tensor<3x3xf32> {mhlo.layout_mode = "default"} loc("jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) in_layouts=(None,) out_layouts=(None,) resource_env=None donated_invars=(False,) name=triu keep_unused=False inline=False]"(#loc3))) -> (tensor<3x3xf32> {mhlo.layout_mode = "default"}) { - %0 = stablehlo.iota dim = 0 : tensor<3x3xi32> loc(#loc11) - %c = stablehlo.constant dense<-1> : tensor loc(#loc10) - %1 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor) -> tensor<3x3xi32> loc(#loc12) - %2 = stablehlo.add %0, %1 : tensor<3x3xi32> loc(#loc12) - %3 = stablehlo.iota dim = 1 : tensor<3x3xi32> loc(#loc13) - %4 = stablehlo.compare GE, %2, %3, SIGNED : (tensor<3x3xi32>, tensor<3x3xi32>) -> tensor<3x3xi1> loc(#loc14) - %cst = stablehlo.constant dense<0.000000e+00> : tensor loc(#loc10) - %5 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> tensor<3x3xf32> loc(#loc15) - %6 = stablehlo.select %4, %5, %arg0 : tensor<3x3xi1>, tensor<3x3xf32> loc(#loc16) - return %6 : tensor<3x3xf32> loc(#loc10) - } loc(#loc10) } loc(#loc) #loc = loc(unknown) -#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":363:26) -#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":363:14) -#loc4 = loc("jit()/jit(main)/iota[dtype=float32 shape=(9,) dimension=0]"(#loc1)) -#loc5 = loc("jit()/jit(main)/reshape[new_sizes=(3, 3) dimensions=None]"(#loc2)) +#loc1 = 
loc("third_party/py/jax/tests/export_back_compat_test.py":411:26) +#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":411:14) +#loc3 = loc("third_party/py/jax/tests/export_back_compat_test.py":412:11) +#loc4 = loc("jit()/jit(main)/iota"(#loc1)) +#loc5 = loc("jit()/jit(main)/reshape"(#loc2)) #loc6 = loc("jit()/jit(main)/geqrf"(#loc3)) -#loc7 = loc("jit()/jit(main)/qr[full_matrices=True]"(#loc3)) -#loc8 = loc("jit()/jit(main)/pad[padding_config=((0, 0, 0), (0, 0, 0))]"(#loc3)) -#loc9 = loc("jit()/jit(main)/householder_product"(#loc3)) -#loc11 = loc("jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=0]"(#loc3)) -#loc12 = loc("jit()/jit(main)/jit(triu)/add"(#loc3)) -#loc13 = loc("jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=1]"(#loc3)) -#loc14 = loc("jit()/jit(main)/jit(triu)/ge"(#loc3)) -#loc15 = loc("jit()/jit(main)/jit(triu)/broadcast_in_dim[shape=(3, 3) broadcast_dimensions=()]"(#loc3)) -#loc16 = loc("jit()/jit(main)/jit(triu)/select_n"(#loc3)) +#loc7 = loc("jit()/jit(main)/pad"(#loc3)) +#loc8 = loc("jit()/jit(main)/householder_product"(#loc3)) +#loc9 = loc("jit()/jit(main)/iota"(#loc3)) +#loc10 = loc("jit()/jit(main)/add"(#loc3)) +#loc11 = loc("jit()/jit(main)/ge"(#loc3)) +#loc12 = loc("jit()/jit(main)/broadcast_in_dim"(#loc3)) +#loc13 = loc("jit()/jit(main)/select_n"(#loc3)) """, - mlir_module_serialized=( - b"ML\xefR\x01StableHLO_v0.9.0\x00\x01)\x05\x01\x03\x01\x03\x05\x03\x19\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x1d\x03\xaa\x02\x12\x023\x01\x9d\x0f\x17\x0b\x0f\x13\x13\x0b\x07\x0b\x0f\x13\x0f\x0b\x0b\x0b\x0b\x13\x0b\x0b\x0f\x0b\x0b\x13\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b+\x0b\x0f\x0b\x0b\x0b33\x0b\x0f\x0b\x13\x0b\x13\x0f\x0b\x1b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x17\x0f\x0b\x17\x0bS\x0b\x0f\x0b#\x0b\x0b\x0b\x0f\x0b\x0b\x13\x13K\x13\x1b\x13\x03g\x0fO\x0b\x0b\x0b//\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x13\x1b\x0b\x1b\x0b\x0b\x0b\x13\x0b\x0b\x0f\x1f\x0f\x0f\x0b\x1f/\x0b\x0b\x0b\x0f\x0f\x17\x13\x1f\x1f\x1f\x0b\x0b'\x0f\x17\x17\x1f\x0b\x1fO\x01\x07\x17\x17\x0b\x01\x05\x0b\x0f\x03/\x17\x0f\x0f\x07\x07\x0f\x17\x07\x07\x07\x17\x13\x17\x17\x13\x13\x13\x13\x13\x17\x13\x0f\x17\x02\x8a\t\x1d\x8f\x03\x17\x11\xb2\x05\x17\x05\x1f\x1dO\x03\x03\x03%\xd1\x03\x03\x05\xd9\x05!\x1f\x05#\x1dy\x03\x03\x03\x05\xeb\x11\x03\x05\x05%\x05'\x05)\x05+\x03\x03#\xcd\x05-\x05/\x1dW\x03\x051\x053\x03\x03\x05\xd7\x055\x057\x059\x05;\x05=\x05?\x05A\x05C\x03\tACE\x17G\x17\rI\x05E\x11\x01\x00\x05G\x05I\x05K\x03\x0b\x19\xa1\x1b\xb7\x1d\xb9\r\xc3\x1f\xc5\x03\x0b\x19\xad\x1b\xc9\x1d\xad\r\xaf\x1f\xcb\x05M\x1dS\x03\x05O\x03\x03\x05\xcf\x05Q\x03\x03#\xd3\x1d]\x03\x05S\x03\x05)\xb1+\xd5\x1dc\x03\x05U\x1dg\x03\x05W\x1dk\x03\x05Y\x1doq\x05[\x17\x11\xae\x055\x1duw\x05]\x17\x11\xae\x05\x1d\x05_\x03\x13/\xdb1\xb33\xdd5\xa17\xb5}\xdf9\xe1;\xe3=\xe7\x05a\x1d\x81\x03\x05c\x03\x07\x85\xa9\x87\xa9\x89\xa9\x05e\x05g\x05i\x1d\x8d\x03\x05k\x05m\x03\x03\x05\xe9\x03\x03\x05\xed\x03\x11/\xef1\xb33\xf15\xa17\xb59\xf3;\xf5=\xf9\x03\x03\x05\xfb\x03\x05)\xb1+\xfd\x03\x03\x05\xff\x1f-\x01\x1f'!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x03\x01\x1do\x1dq\x1f)\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x1b\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1ds\x03\x03\xc7\x1du\t\x07\x1dw\x05\x01#\x1d\x03\x05\xbb\xbf\r\x05\xab\xbd\xa3\xa5\x1dy\r\x05\xab\xc1\xa3\xa5\x1d{\x1d}\x1d\x7f\r\x03\xa3\xa5#\x1f\x1d\x81\x13\r\x01\x1f\x07\t\xff\xff\xff\xff\x1f!\x01\x13\r\x05\x07\x05\x1f\x0f\t\x00\x00\x00\x00\x1f\t\x11\x03\x00\x00\x00\x00\x00\x00\x00\x0b\x03\x1d\x83\r\x01\x03\x03\x9f\x03\x03\xe5\x15\x03\x01\x01\x01\x03\x05
\x9f\xa7\x1f\x07\t\x01\x00\x00\x00\x1f\x07\t\x03\x00\x00\x00\x1f\x07\t`\x00\x00\x00\x0b\x05\x1d\x85\x03\x0f\x9d\x9d\x9d\x9d\x9d\x9f\xa7\x03\x03\xf7\x15\x03\x01\x15\x01\x03\x07\x9f\x9d\xa7\x1f\x07\t\x00\x00\x00\x00\x07\x01\x1f\x0f\t\x00\x00\xc0\x7f\x1f\x1b!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x03\x03%\x02\x02\x03\x03\x0e\x02\xaf\x05\x87\x01\t\x01\x02\x02)\x05\r\r\x0b)\x01\x17)\x01\r\t\x1d)\x01\x0b)\x05\r\r\x17\x01\x13\x1b)\x05\r\r\x13)\x03\t\r\x11\x01\x05\x05\x05\x11\x03\x05\x03\x05)\x03\x01\r)\x03%\x0b)\x03\r\x0b)\x03\t\x15)\x03\x05\x15)\x03\x02\x03\x0b)\x03\x01\x15)\x01\x13)\x05\x05\x05\x13\x04\x8a\x04\x05\x01\x11\x0f?\x07\x03\x01\t\t\x11\x0fK\x07\x039i\x07\x03m!\x03#\x15\x06s\x03\x05\x03\x01\x03\x03\x13\x0b\x03\t\x03\x03\x13\x0b\x03\t\x11\x07\x13{\x05\x05%\x03\x03\x03\x03\x7f-\x03\x0f\x17\x07\x8b\x83\x03\x05\x05\t\r\x03\x03\x01\x0b\x03\t\x03\x03\x01\x0b\x03\t\x03\x03\x01\x0b\x03\t\x03\x03\x01\x91\x03\x07\x03\x03\x01\x15\x03\x07\x03\x03\x01\x15\x03\x07\x03\x03\x01\x15\x03\x07\x03\x03\x01\x93\x03\x07\x11\x07\x01\x95\x07\x05\x07+\x0f\x17\x19\x1b\x1d\x1f\x0f\x0b\x03\x03\x01\x97\x03\x07\x05\x07\x01\t\x03\x07\x03'\x0b\x07\x01\x99\x03/\x05#)\x05\x07\x01\t\x031\x03+\x03\x03\x01\x9b\x03\x0f\x05\x07\x01\t\x03\x05\x03/\x05\x07\x01\x06\x02\x03\x19\x03-\r\x06\x01\x03\x05\x073!1\x19\x07\x07\n\x02\x03\x05\x03\t\x0f\x04\x0f\x0557\t\x11\x07M\x07\x03\x15+\x03\x05\x07\x07\x03Q!\x03\x11\x03\x03\x07U\x03\x07\x05\x07'\t\x03\x11\x03\x05\x13\x06'\x03\x11\x05\x03\x07\x07\x03[Y\x03\x11\x0b\x07a_\x03\x19\x05\t\x0b\x03\x03\x07-\x03\x0f\x05\x07e\t\x03\x05\x03\x0f\r\x06i\x03\x05\x07\r\x11\x01\x0f\x04\x07\x03\x13\x06\x03\x01\x05\x01\x00\x9e\x1a\x89\x0f\x1d%\x11\x0f\x0b\t\t\x03\x0b!\x11#Y\x87##%_)=\x85\x87W\xb3K\x9bM\x9bn\x03\x1b%)9\x1f/!!)#\x1f\x19+\x1b+\x1f\x1f\x15\x1d\x15i\x13\r\x11\x0f\x17\x0f\x1f\x15\x15\x17\x11\x11)\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00broadcast_in_dim_v1\x00iota_v1\x00func_v1\x00compare_v1\x00select_v1\x00return_v1\x00custom_call_v1\x00add_v1\x00reshape_v1\x00pad_v1\x00call_v1\x00value\x00sym_name\x00third_party/py/jax/tests/export_back_compat_test.py\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00broadcast_dimensions\x00compare_type\x00comparison_direction\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit__lambda_\x00jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,)" - b' out_shardings=(UnspecifiedValue,) in_layouts=(None,)' - b' out_layouts=(None,) resource_env=None donated_invars=(False,)' - b' name=triu keep_unused=False' - b' inline=False]\x00jit()/jit(main)/jit(triu)/iota[dtype=int32' - b' shape=(3, 3)' - b' dimension=0]\x00jit()/jit(main)/jit(triu)/add\x00jit()/jit(main)/jit(triu)/iota[dtype=int32' - b' shape=(3, 3)' - b' dimension=1]\x00jit()/jit(main)/jit(triu)/ge\x00jit()/jit(main)/jit(triu)/broadcast_in_dim[shape=(3,' - b' 3) broadcast_dimensions=()]\x00jit()/jit(main)/jit(triu)/select_n\x00jit()/jit(main)/iota[dtype=float32' - b' shape=(9,)' - b' dimension=0]\x00jit()/jit(main)/reshape[new_sizes=(3, 3)' - b' dimensions=None]\x00jit()/jit(main)/geqrf\x00mhlo.backend_config\x00jit()/jit(main)/qr[full_matrices=True]\x00edge_padding_high\x00edge_padding_low\x00interior_padding\x00jit()/jit(main)/pad[padding_config=((0,' - b' 0, 0), (0, 0,' - b' 
0))]\x00jit()/jit(main)/householder_product\x00mhlo.layout_mode\x00default\x00jax.result_info\x00triu\x00\x00[0]\x00[1]\x00main\x00public\x00private\x00lapack_sgeqrf_ffi\x00lapack_sorgqr\x00callee\x00' - ), + mlir_module_serialized=b'ML\xefR\rStableHLO_v1.9.3\x00\x01)\x05\x01\x05\x19\x01\x03\x0b\x03\x17\x0f\x13\x17\x1b\x1f#\'+/37\x03\xc5\x8b\'\x01E\x17\x07\x0b\x0f\x0b\x1b\x0f\x0f#\x0b\x0f\x0b\x0b\x0b\x0f\x17\x0f\x0b\x17\x0b\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x03G\x0b/\x0b\x0f\x0b\x0b\x0b\x0fO\x13\x0f\x0b\x13\x13\x0b\x13\x0b\x0b\x0b\x1f\x1f\x0b\x13\x0b\x0b\x0b\x0f\x17/\x0b\x0f\x13\x0f\x0b\x0b\x01\x05\x0b\x0f\x03#\x17\x07\x07\x17\x0f\x07\x0f\x07\x17\x13\x13\x13\x13\x13\x13\x17\x07\x02\x9a\x04\x17\x05r\x06\x17\x1f\x05\x1d\x11\x03\x05\x05\x1f\x03\x05\'o)q\x1d\t\x01\x1d7\x01\x03\x07\x13\x15\x17\x07\x19\x07\x05!\x11\x01\x00\x05#\x05%\x05\'\x1d\t\x1f\x17\x05n\x065\x1d#%\x05)\x17\x05n\x06\x1d\x05+\x05-\x1d-\x01\x05/\x1d1\x01\x051\x1d5\x01\x053\x055\x1d;\x01\x057\x1d?\x01\x059\x1dC\x01\x05;\x03\x01\x1f\x1f\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1d=\x13\t\x01\x0b\x03\x1d?\x05\x01\x03\x03U\x1f\x1b!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x03\x05U}\x1f!\x01#\x15\x03\x05_c\r\x03Ia\x1dA\r\x03Ie\x1dC\x1dE\x1dG\x1f\r\t\xff\xff\xff\xff\x1f\x11\t\x00\x00\x00\x00\r\x01\r\x03su\x1dI\x1dK\x1dM\x03\x03{\x15\x03\x01\x01\x01\x1f\x1d\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1dO\x03\x03\x83\x15\x01\x01\x01\x13\t\x05\t\x07\x07\x05\x01\t\x01\x02\x02)\x05\r\r\x07\t\x1d)\x05\r\r\x0f)\x01\x0f\x1b)\x01\x07\x13\x11\x01\x05\x05\x05)\x03%\x07)\x03\r\x07)\x03\t\x13)\x03\x05\x13)\x03\t\t)\x03\x01\t)\x05\r\r%\x01\x04"\x02\x05\x01Q\x03\x11\x01\x07\x04\xff\x03\x01\x05\x0bP\x03\x03\x07\x04\xeb\x03\x1f=\x05B\x03\x05\x03\r\x05B\x03\x07\x03\x11\x03B\x1d\t\x03\x17\r\x06!\x03\x05\x03\x05\x07G+\x0b\x0b\x05\x05\x19\x03\x07\x0fF/\r\x03\x05\x05\t\x03\x07G3\x0b\x0f\x03\x05\x05\r\x0b\x03B\r\t\x03\x0b\tF\x0f\x11\x03\x0b\x03\x01\x11\x06\x0f\x03\x0b\x05\x11\x13\x03B\r\x13\x03\x0b\x13F9\x15\x03#\x05\x15\x17\tF=\x11\x03\x05\x03\x03\x15\x06A\x03\x05\x07\x19\x1b\t\x17\x04\x03\x05\x0f\x1d\x06\x03\x01\x05\x01\x00\xba\x0bQ%%\x05\x1f\x0f\x0b\x15\x15\x03!CS79Y9=3)A\x1b%)9;i\x15\x15\x17\x0f\x0f\x17\x11)\x1f\x19\x11\x0f\x0b\x11builtin\x00vhlo\x00module\x00iota_v1\x00constant_v1\x00custom_call_v1\x00broadcast_in_dim_v1\x00func_v1\x00reshape_v1\x00pad_v1\x00add_v1\x00compare_v1\x00select_v1\x00return_v1\x00third_party/py/jax/tests/export_back_compat_test.py\x00jit()/jit(main)/iota\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit__lambda_\x00jit()/jit(main)/reshape\x00mhlo.backend_config\x00mhlo.frontend_attributes\x00jit()/jit(main)/geqrf\x00jit()/jit(main)/pad\x00jit()/jit(main)/householder_product\x00jit()/jit(main)/add\x00jit()/jit(main)/ge\x00jit()/jit(main)/broadcast_in_dim\x00jit()/jit(main)/select_n\x00jax.result_info\x00\x00result[0]\x00result[1]\x00main\x00public\x00num_batch_dims\x000\x00lapack_sgeqrf_ffi\x00lapack_sorgqr_ffi\x00\x08[\x17\x057\x01\x0bE[]gi\x03k\x03m\x03K\x11MOwEQSyW\x07GGG\x11MO\x7fEQW\x81S\x03Y\x03\x85\x05\x87\x89', xla_call_module_version=9, nr_devices=1, ) # End paste - -# Pasted from the test output (see export_back_compat_test_util.py module docstring) -data_2024_08_22['f64'] = dict( +data_2025_04_02['f64'] = dict( testdata_version=1, platform='cpu', custom_call_targets=['lapack_dgeqrf_ffi', 'lapack_dorgqr_ffi'], - serialized_date=datetime.date(2024, 8, 22), + serialized_date=datetime.date(2025, 4, 2), inputs=(), - expected_outputs=( - array([ - [0.0, 
0.9128709291752773, 0.40824829046386235], - [-0.447213595499958, 0.3651483716701102, -0.8164965809277263], - [-0.894427190999916, -0.1825741858350548, 0.40824829046386324], - ]), - array([ - [ - -6.7082039324993694e00, - -8.0498447189992444e00, - -9.3914855054991175e00, - ], - [ - 0.0000000000000000e00, - 1.0954451150103341e00, - 2.1908902300206665e00, - ], - [ - 0.0000000000000000e00, - 0.0000000000000000e00, - -8.8817841970012523e-16, - ], - ]), - ), + expected_outputs=(array([[ 0. , 0.9128709291752773 , 0.40824829046386235], + [-0.447213595499958 , 0.3651483716701102 , -0.8164965809277263 ], + [-0.894427190999916 , -0.1825741858350548 , 0.40824829046386324]]), array([[-6.7082039324993694e+00, -8.0498447189992444e+00, + -9.3914855054991175e+00], + [ 0.0000000000000000e+00, 1.0954451150103341e+00, + 2.1908902300206665e+00], + [ 0.0000000000000000e+00, 0.0000000000000000e+00, + -8.8817841970012523e-16]])), mlir_module_text=r""" -#loc3 = loc("third_party/py/jax/tests/export_back_compat_test.py":364:11) -#loc10 = loc("jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) in_layouts=(None,) out_layouts=(None,) resource_env=None donated_invars=(False,) name=triu keep_unused=False inline=False]"(#loc3)) module @jit__lambda_ attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main() -> (tensor<3x3xf64> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<3x3xf64> {jax.result_info = "[1]", mhlo.layout_mode = "default"}) { + func.func public @main() -> (tensor<3x3xf64> {jax.result_info = "result[0]"}, tensor<3x3xf64> {jax.result_info = "result[1]"}) { + %c = stablehlo.constant dense<-1> : tensor loc(#loc) + %cst = stablehlo.constant dense<0.000000e+00> : tensor loc(#loc) %0 = stablehlo.iota dim = 0 : tensor<9xf64> loc(#loc4) %1 = stablehlo.reshape %0 : (tensor<9xf64>) -> tensor<3x3xf64> loc(#loc5) - %c = stablehlo.constant dense<3> : tensor loc(#loc6) - %c_0 = stablehlo.constant dense<3> : tensor loc(#loc6) - %2:2 = stablehlo.custom_call @lapack_dgeqrf_ffi(%1) {mhlo.backend_config = {}, operand_layouts = [dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>]} : (tensor<3x3xf64>) -> (tensor<3x3xf64>, tensor<3xf64>) loc(#loc6) - %cst = stablehlo.constant dense<0.000000e+00> : tensor loc(#loc7) - %3 = stablehlo.pad %2#0, %cst, low = [0, 0], high = [0, 0], interior = [0, 0] : (tensor<3x3xf64>, tensor) -> tensor<3x3xf64> loc(#loc8) - %c_1 = stablehlo.constant dense<3> : tensor loc(#loc9) - %c_2 = stablehlo.constant dense<3> : tensor loc(#loc9) - %c_3 = stablehlo.constant dense<3> : tensor loc(#loc9) - %c_4 = stablehlo.constant dense<1> : tensor loc(#loc9) - %c_5 = stablehlo.constant dense<3> : tensor loc(#loc9) - %c_6 = stablehlo.constant dense<3> : tensor loc(#loc9) - %c_7 = stablehlo.constant dense<3> : tensor loc(#loc9) - %c_8 = stablehlo.constant dense<96> : tensor loc(#loc9) - %4:3 = stablehlo.custom_call @lapack_dorgqr(%c_4, %c_5, %c_6, %c_7, %c_8, %3, %2#1) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>, dense<0> : tensor<1xindex>]} 
: (tensor, tensor, tensor, tensor, tensor, tensor<3x3xf64>, tensor<3xf64>) -> (tensor<3x3xf64>, tensor, tensor<96xf64>) loc(#loc9) - %c_9 = stablehlo.constant dense<0> : tensor loc(#loc9) - %5 = stablehlo.broadcast_in_dim %c_9, dims = [] : (tensor) -> tensor loc(#loc9) - %6 = stablehlo.compare EQ, %4#1, %5, SIGNED : (tensor, tensor) -> tensor loc(#loc9) - %7 = stablehlo.broadcast_in_dim %6, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc9) - %cst_10 = stablehlo.constant dense<0x7FF8000000000000> : tensor loc(#loc9) - %8 = stablehlo.broadcast_in_dim %cst_10, dims = [] : (tensor) -> tensor<3x3xf64> loc(#loc9) - %9 = stablehlo.broadcast_in_dim %7, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<3x3xi1> loc(#loc9) - %10 = stablehlo.select %9, %4#0, %8 : tensor<3x3xi1>, tensor<3x3xf64> loc(#loc9) - %11 = call @triu(%2#0) : (tensor<3x3xf64>) -> tensor<3x3xf64> loc(#loc10) - return %10, %11 : tensor<3x3xf64>, tensor<3x3xf64> loc(#loc) + %2:2 = stablehlo.custom_call @lapack_dgeqrf_ffi(%1) {mhlo.backend_config = {}, mhlo.frontend_attributes = {num_batch_dims = "0"}, operand_layouts = [dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>]} : (tensor<3x3xf64>) -> (tensor<3x3xf64>, tensor<3xf64>) loc(#loc6) + %3 = stablehlo.pad %2#0, %cst, low = [0, 0], high = [0, 0], interior = [0, 0] : (tensor<3x3xf64>, tensor) -> tensor<3x3xf64> loc(#loc7) + %4 = stablehlo.custom_call @lapack_dorgqr_ffi(%3, %2#1) {mhlo.backend_config = {}, mhlo.frontend_attributes = {num_batch_dims = "0"}, operand_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>]} : (tensor<3x3xf64>, tensor<3xf64>) -> tensor<3x3xf64> loc(#loc8) + %5 = stablehlo.iota dim = 0 : tensor<3x3xi32> loc(#loc9) + %6 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor) -> tensor<3x3xi32> loc(#loc10) + %7 = stablehlo.add %5, %6 : tensor<3x3xi32> loc(#loc10) + %8 = stablehlo.iota dim = 1 : tensor<3x3xi32> loc(#loc9) + %9 = stablehlo.compare GE, %7, %8, SIGNED : (tensor<3x3xi32>, tensor<3x3xi32>) -> tensor<3x3xi1> loc(#loc11) + %10 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> tensor<3x3xf64> loc(#loc12) + %11 = stablehlo.select %9, %10, %2#0 : tensor<3x3xi1>, tensor<3x3xf64> loc(#loc13) + return %4, %11 : tensor<3x3xf64>, tensor<3x3xf64> loc(#loc) } loc(#loc) - func.func private @triu(%arg0: tensor<3x3xf64> {mhlo.layout_mode = "default"} loc("jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) in_layouts=(None,) out_layouts=(None,) resource_env=None donated_invars=(False,) name=triu keep_unused=False inline=False]"(#loc3))) -> (tensor<3x3xf64> {mhlo.layout_mode = "default"}) { - %0 = stablehlo.iota dim = 0 : tensor<3x3xi32> loc(#loc11) - %c = stablehlo.constant dense<-1> : tensor loc(#loc10) - %1 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor) -> tensor<3x3xi32> loc(#loc12) - %2 = stablehlo.add %0, %1 : tensor<3x3xi32> loc(#loc12) - %3 = stablehlo.iota dim = 1 : tensor<3x3xi32> loc(#loc13) - %4 = stablehlo.compare GE, %2, %3, SIGNED : (tensor<3x3xi32>, tensor<3x3xi32>) -> tensor<3x3xi1> loc(#loc14) - %cst = stablehlo.constant dense<0.000000e+00> : tensor loc(#loc10) - %5 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> tensor<3x3xf64> loc(#loc15) - %6 = stablehlo.select %4, %5, %arg0 : tensor<3x3xi1>, tensor<3x3xf64> loc(#loc16) - return %6 : 
tensor<3x3xf64> loc(#loc10) - } loc(#loc10) } loc(#loc) #loc = loc(unknown) -#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":363:26) -#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":363:14) -#loc4 = loc("jit()/jit(main)/iota[dtype=float64 shape=(9,) dimension=0]"(#loc1)) -#loc5 = loc("jit()/jit(main)/reshape[new_sizes=(3, 3) dimensions=None]"(#loc2)) +#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":411:26) +#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":411:14) +#loc3 = loc("third_party/py/jax/tests/export_back_compat_test.py":412:11) +#loc4 = loc("jit()/jit(main)/iota"(#loc1)) +#loc5 = loc("jit()/jit(main)/reshape"(#loc2)) #loc6 = loc("jit()/jit(main)/geqrf"(#loc3)) -#loc7 = loc("jit()/jit(main)/qr[full_matrices=True]"(#loc3)) -#loc8 = loc("jit()/jit(main)/pad[padding_config=((0, 0, 0), (0, 0, 0))]"(#loc3)) -#loc9 = loc("jit()/jit(main)/householder_product"(#loc3)) -#loc11 = loc("jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=0]"(#loc3)) -#loc12 = loc("jit()/jit(main)/jit(triu)/add"(#loc3)) -#loc13 = loc("jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=1]"(#loc3)) -#loc14 = loc("jit()/jit(main)/jit(triu)/ge"(#loc3)) -#loc15 = loc("jit()/jit(main)/jit(triu)/broadcast_in_dim[shape=(3, 3) broadcast_dimensions=()]"(#loc3)) -#loc16 = loc("jit()/jit(main)/jit(triu)/select_n"(#loc3)) +#loc7 = loc("jit()/jit(main)/pad"(#loc3)) +#loc8 = loc("jit()/jit(main)/householder_product"(#loc3)) +#loc9 = loc("jit()/jit(main)/iota"(#loc3)) +#loc10 = loc("jit()/jit(main)/add"(#loc3)) +#loc11 = loc("jit()/jit(main)/ge"(#loc3)) +#loc12 = loc("jit()/jit(main)/broadcast_in_dim"(#loc3)) +#loc13 = loc("jit()/jit(main)/select_n"(#loc3)) """, - mlir_module_serialized=( - 
b"ML\xefR\x01StableHLO_v0.9.0\x00\x01)\x05\x01\x03\x01\x03\x05\x03\x19\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x1d\x03\xaa\x02\x12\x023\x01\x9d\x0f\x17\x0b\x0f\x13\x13\x0b\x07\x0b\x0f\x13\x0f\x0b\x0b\x0b\x0b\x13\x0b\x0b\x0f\x0b\x0b\x13\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b+\x0b\x0f\x0b\x0b\x0b33\x0b\x0f\x0b\x13\x0b\x13\x0f\x0b\x1b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x17\x0f\x0b\x17\x0bS\x0b\x0f\x0b#\x0b\x0b\x0b\x0f\x0b\x0b\x13\x13K\x13\x1b\x13\x03g\x0fO\x0b\x0b\x0b//\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x13\x1b\x0b\x1b\x0b\x0b\x0b\x13\x0b\x0b\x0f\x1f\x0f\x0f\x0b//\x0b\x0b\x0b\x0f\x0f\x17\x13\x1f\x1f\x1f\x0b\x0b'\x0f\x17\x17\x1f\x0b/O\x01\x07\x17\x17\x0b\x01\x05\x0b\x0f\x03/\x17\x0f\x0f\x07\x07\x0f\x17\x07\x07\x07\x17\x13\x17\x17\x13\x13\x13\x13\x13\x17\x13\x0f\x17\x02\xaa\t\x1d\x8f\x03\x17\x11\xb2\x05\x17\x05\x1f\x1dO\x03\x03\x03%\xd1\x03\x03\x05\xd9\x05!\x1f\x05#\x1dy\x03\x03\x03\x05\xeb\x11\x03\x05\x05%\x05'\x05)\x05+\x03\x03#\xcd\x05-\x05/\x1dW\x03\x051\x053\x03\x03\x05\xd7\x055\x057\x059\x05;\x05=\x05?\x05A\x05C\x03\tACE\x17G\x17\rI\x05E\x11\x01\x00\x05G\x05I\x05K\x03\x0b\x19\xa1\x1b\xb7\x1d\xb9\r\xc3\x1f\xc5\x03\x0b\x19\xad\x1b\xc9\x1d\xad\r\xaf\x1f\xcb\x05M\x1dS\x03\x05O\x03\x03\x05\xcf\x05Q\x03\x03#\xd3\x1d]\x03\x05S\x03\x05)\xb1+\xd5\x1dc\x03\x05U\x1dg\x03\x05W\x1dk\x03\x05Y\x1doq\x05[\x17\x11\xae\x055\x1duw\x05]\x17\x11\xae\x05\x1d\x05_\x03\x13/\xdb1\xb33\xdd5\xa17\xb5}\xdf9\xe1;\xe3=\xe7\x05a\x1d\x81\x03\x05c\x03\x07\x85\xa9\x87\xa9\x89\xa9\x05e\x05g\x05i\x1d\x8d\x03\x05k\x05m\x03\x03\x05\xe9\x03\x03\x05\xed\x03\x11/\xef1\xb33\xf15\xa17\xb59\xf3;\xf5=\xf9\x03\x03\x05\xfb\x03\x05)\xb1+\xfd\x03\x03\x05\xff\x1f-\x01\x1f'!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x03\x01\x1do\x1dq\x1f)\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x1b\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1ds\x03\x03\xc7\x1du\t\x07\x1dw\x05\x01#\x1d\x03\x05\xbb\xbf\r\x05\xab\xbd\xa3\xa5\x1dy\r\x05\xab\xc1\xa3\xa5\x1d{\x1d}\x1d\x7f\r\x03\xa3\xa5#\x1f\x1d\x81\x13\r\x01\x1f\x07\t\xff\xff\xff\xff\x1f!\x01\x13\r\x05\x07\x05\x1f\x0f\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\t\x11\x03\x00\x00\x00\x00\x00\x00\x00\x0b\x03\x1d\x83\r\x01\x03\x03\x9f\x03\x03\xe5\x15\x03\x01\x01\x01\x03\x05\x9f\xa7\x1f\x07\t\x01\x00\x00\x00\x1f\x07\t\x03\x00\x00\x00\x1f\x07\t`\x00\x00\x00\x0b\x05\x1d\x85\x03\x0f\x9d\x9d\x9d\x9d\x9d\x9f\xa7\x03\x03\xf7\x15\x03\x01\x15\x01\x03\x07\x9f\x9d\xa7\x1f\x07\t\x00\x00\x00\x00\x07\x01\x1f\x0f\x11\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f\x1b!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x03\x03%\x02\x02\x03\x03\x0e\x02\xaf\x05\x87\x01\t\x01\x02\x02)\x05\r\r\x0b)\x01\x17)\x01\r\x0b\x1d)\x01\x0b)\x05\r\r\x17\x01\x13\x1b)\x05\r\r\x13)\x03\t\r\x11\x01\x05\x05\x05\x11\x03\x05\x03\x05)\x03\x01\r)\x03%\x0b)\x03\r\x0b)\x03\t\x15)\x03\x05\x15)\x03\x02\x03\x0b)\x03\x01\x15)\x01\x13)\x05\x05\x05\x13\x04\x8a\x04\x05\x01\x11\x0f?\x07\x03\x01\t\t\x11\x0fK\x07\x039i\x07\x03m!\x03#\x15\x06s\x03\x05\x03\x01\x03\x03\x13\x0b\x03\t\x03\x03\x13\x0b\x03\t\x11\x07\x13{\x05\x05%\x03\x03\x03\x03\x7f-\x03\x0f\x17\x07\x8b\x83\x03\x05\x05\t\r\x03\x03\x01\x0b\x03\t\x03\x03\x01\x0b\x03\t\x03\x03\x01\x0b\x03\t\x03\x03\x01\x91\x03\x07\x03\x03\x01\x15\x03\x07\x03\x03\x01\x15\x03\x07\x03\x03\x01\x15\x03\x07\x03\x03\x01\x93\x03\x07\x11\x07\x01\x95\x07\x05\x07+\x0f\x17\x19\x1b\x1d\x1f\x0f\x0b\x03\x03\x01\x97\x03\x07\x05\x07\x01\t\x03\x07\x03'\x0b\x07\x01\x99\x03/\x05#)\x05\x07\x01\t\x031\x03+\x03\x03\x01\x9b\x03\x0f\x05\x07\x01\t\x03\x05\x03/\x05\x07\x01\x06\x02\x03\x19\x03-\r\x06\x01\x03\x05\x073!1\x19\x07\x07\n\x02\x03\x05\x03\t\x0f\
x04\x0f\x0557\t\x11\x07M\x07\x03\x15+\x03\x05\x07\x07\x03Q!\x03\x11\x03\x03\x07U\x03\x07\x05\x07'\t\x03\x11\x03\x05\x13\x06'\x03\x11\x05\x03\x07\x07\x03[Y\x03\x11\x0b\x07a_\x03\x19\x05\t\x0b\x03\x03\x07-\x03\x0f\x05\x07e\t\x03\x05\x03\x0f\r\x06i\x03\x05\x07\r\x11\x01\x0f\x04\x07\x03\x13\x06\x03\x01\x05\x01\x00\x9e\x1a\x89\x0f\x1d%\x11\x0f\x0b\t\t\x03\x0b!\x11#Y\x87##%_)=\x85\x87W\xb3K\x9bM\x9bn\x03\x1b%)9\x1f/!!)#\x1f\x19+\x1b+\x1f\x1f\x15\x1d\x15i\x13\r\x11\x0f\x17\x0f\x1f\x15\x15\x17\x11\x11)\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00broadcast_in_dim_v1\x00iota_v1\x00func_v1\x00compare_v1\x00select_v1\x00return_v1\x00custom_call_v1\x00add_v1\x00reshape_v1\x00pad_v1\x00call_v1\x00value\x00sym_name\x00third_party/py/jax/tests/export_back_compat_test.py\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00broadcast_dimensions\x00compare_type\x00comparison_direction\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit__lambda_\x00jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,)" - b' out_shardings=(UnspecifiedValue,) in_layouts=(None,)' - b' out_layouts=(None,) resource_env=None donated_invars=(False,)' - b' name=triu keep_unused=False' - b' inline=False]\x00jit()/jit(main)/jit(triu)/iota[dtype=int32' - b' shape=(3, 3)' - b' dimension=0]\x00jit()/jit(main)/jit(triu)/add\x00jit()/jit(main)/jit(triu)/iota[dtype=int32' - b' shape=(3, 3)' - b' dimension=1]\x00jit()/jit(main)/jit(triu)/ge\x00jit()/jit(main)/jit(triu)/broadcast_in_dim[shape=(3,' - b' 3) broadcast_dimensions=()]\x00jit()/jit(main)/jit(triu)/select_n\x00jit()/jit(main)/iota[dtype=float64' - b' shape=(9,)' - b' dimension=0]\x00jit()/jit(main)/reshape[new_sizes=(3, 3)' - b' dimensions=None]\x00jit()/jit(main)/geqrf\x00mhlo.backend_config\x00jit()/jit(main)/qr[full_matrices=True]\x00edge_padding_high\x00edge_padding_low\x00interior_padding\x00jit()/jit(main)/pad[padding_config=((0,' - b' 0, 0), (0, 0,' - b' 0))]\x00jit()/jit(main)/householder_product\x00mhlo.layout_mode\x00default\x00jax.result_info\x00triu\x00\x00[0]\x00[1]\x00main\x00public\x00private\x00lapack_dgeqrf_ffi\x00lapack_dorgqr\x00callee\x00' - ), + 
mlir_module_serialized=b'ML\xefR\rStableHLO_v1.9.3\x00\x01)\x05\x01\x05\x19\x01\x03\x0b\x03\x17\x0f\x13\x17\x1b\x1f#\'+/37\x03\xc5\x8b\'\x01E\x17\x07\x0b\x0f\x0b\x1b\x0f\x0f#\x0b\x0f\x0b\x0b\x0b\x0f\x17\x0f\x0b\x17\x0b\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x03G\x0b/\x0b\x0f\x0b\x0b\x0b\x0fO\x13\x0f\x0b\x13\x13\x0b\x13\x0b\x0b\x0b\x1f/\x0b\x13\x0b\x0b\x0b\x0f\x17/\x0b\x0f\x13\x0f\x0b\x0b\x01\x05\x0b\x0f\x03#\x17\x07\x07\x17\x0f\x07\x0f\x07\x17\x13\x13\x13\x13\x13\x13\x17\x07\x02\xaa\x04\x17\x05r\x06\x17\x1f\x05\x1d\x11\x03\x05\x05\x1f\x03\x05\'o)q\x1d\t\x01\x1d7\x01\x03\x07\x13\x15\x17\x07\x19\x07\x05!\x11\x01\x00\x05#\x05%\x05\'\x1d\t\x1f\x17\x05n\x065\x1d#%\x05)\x17\x05n\x06\x1d\x05+\x05-\x1d-\x01\x05/\x1d1\x01\x051\x1d5\x01\x053\x055\x1d;\x01\x057\x1d?\x01\x059\x1dC\x01\x05;\x03\x01\x1f\x1f\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1d=\x13\t\x01\x0b\x03\x1d?\x05\x01\x03\x03U\x1f\x1b!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x03\x05U}\x1f!\x01#\x15\x03\x05_c\r\x03Ia\x1dA\r\x03Ie\x1dC\x1dE\x1dG\x1f\r\t\xff\xff\xff\xff\x1f\x11\x11\x00\x00\x00\x00\x00\x00\x00\x00\r\x01\r\x03su\x1dI\x1dK\x1dM\x03\x03{\x15\x03\x01\x01\x01\x1f\x1d\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1dO\x03\x03\x83\x15\x01\x01\x01\x13\t\x05\t\x07\x07\x05\x01\t\x01\x02\x02)\x05\r\r\x07\x0b\x1d)\x05\r\r\x0f)\x01\x0f\x1b)\x01\x07\x13\x11\x01\x05\x05\x05)\x03%\x07)\x03\r\x07)\x03\t\x13)\x03\x05\x13)\x03\t\t)\x03\x01\t)\x05\r\r%\x01\x04"\x02\x05\x01Q\x03\x11\x01\x07\x04\xff\x03\x01\x05\x0bP\x03\x03\x07\x04\xeb\x03\x1f=\x05B\x03\x05\x03\r\x05B\x03\x07\x03\x11\x03B\x1d\t\x03\x17\r\x06!\x03\x05\x03\x05\x07G+\x0b\x0b\x05\x05\x19\x03\x07\x0fF/\r\x03\x05\x05\t\x03\x07G3\x0b\x0f\x03\x05\x05\r\x0b\x03B\r\t\x03\x0b\tF\x0f\x11\x03\x0b\x03\x01\x11\x06\x0f\x03\x0b\x05\x11\x13\x03B\r\x13\x03\x0b\x13F9\x15\x03#\x05\x15\x17\tF=\x11\x03\x05\x03\x03\x15\x06A\x03\x05\x07\x19\x1b\t\x17\x04\x03\x05\x0f\x1d\x06\x03\x01\x05\x01\x00\xba\x0bQ%%\x05\x1f\x0f\x0b\x15\x15\x03!CS79Y9=3)A\x1b%)9;i\x15\x15\x17\x0f\x0f\x17\x11)\x1f\x19\x11\x0f\x0b\x11builtin\x00vhlo\x00module\x00iota_v1\x00constant_v1\x00custom_call_v1\x00broadcast_in_dim_v1\x00func_v1\x00reshape_v1\x00pad_v1\x00add_v1\x00compare_v1\x00select_v1\x00return_v1\x00third_party/py/jax/tests/export_back_compat_test.py\x00jit()/jit(main)/iota\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit__lambda_\x00jit()/jit(main)/reshape\x00mhlo.backend_config\x00mhlo.frontend_attributes\x00jit()/jit(main)/geqrf\x00jit()/jit(main)/pad\x00jit()/jit(main)/householder_product\x00jit()/jit(main)/add\x00jit()/jit(main)/ge\x00jit()/jit(main)/broadcast_in_dim\x00jit()/jit(main)/select_n\x00jax.result_info\x00\x00result[0]\x00result[1]\x00main\x00public\x00num_batch_dims\x000\x00lapack_dgeqrf_ffi\x00lapack_dorgqr_ffi\x00\x08[\x17\x057\x01\x0bE[]gi\x03k\x03m\x03K\x11MOwEQSyW\x07GGG\x11MO\x7fEQW\x81S\x03Y\x03\x85\x05\x87\x89', xla_call_module_version=9, nr_devices=1, ) # End paste diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/cpu_schur_lapack_gees.py b/jax/_src/internal_test_util/export_back_compat_test_data/cpu_schur_lapack_gees.py index 309aa73f20ba..db514111ec3e 100644 --- a/jax/_src/internal_test_util/export_back_compat_test_data/cpu_schur_lapack_gees.py +++ b/jax/_src/internal_test_util/export_back_compat_test_data/cpu_schur_lapack_gees.py @@ -15,232 +15,7 @@ # ruff: noqa import datetime -from numpy import array, int32, float32, complex64 - -data_2023_07_16 = {} - -# Pasted from the test output (see back_compat_test.py module 
docstring) -data_2023_07_16["f32"] = dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['lapack_sgees'], - serialized_date=datetime.date(2023, 7, 16), - inputs=(array([[ 0., 1., 2., 3.], - [ 4., 5., 6., 7.], - [ 8., 9., 10., 11.], - [12., 13., 14., 15.]], dtype=float32),), - expected_outputs=(array([[ 3.2464233e+01, -1.3416403e+01, -1.5532076e-05, -4.3390692e-06], - [ 0.0000000e+00, -2.4642491e+00, -1.4625000e-06, -6.4478525e-07], - [ 0.0000000e+00, 0.0000000e+00, -8.1893580e-07, -2.5704816e-07], - [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 1.5155359e-07]], - dtype=float32), array([[-0.11417631 , 0.828833 , -0.546308 , -0.039330132], - [-0.33000442 , 0.4371459 , 0.69909686 , 0.45963493 ], - [-0.54583275 , 0.045459975, 0.24073309 , -0.80127877 ], - [-0.7616609 , -0.34622616 , -0.39352104 , 0.3809742 ]], - dtype=float32)), - mlir_module_text=r""" -#loc = loc(unknown) -module @jit_func attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main(%arg0: tensor<4x4xf32> {jax.arg_info = "input", mhlo.sharding = "{replicated}"} loc(unknown)) -> (tensor<4x4xf32> {jax.result_info = "[0]"}, tensor<4x4xf32> {jax.result_info = "[1]"}) { - %0 = stablehlo.constant dense<1> : tensor loc(#loc2) - %1 = stablehlo.constant dense<4> : tensor loc(#loc2) - %2 = stablehlo.constant dense<86> : tensor loc(#loc2) - %3 = stablehlo.constant dense<78> : tensor loc(#loc2) - %4:6 = stablehlo.custom_call @lapack_sgees(%0, %1, %2, %3, %arg0) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>, dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>]} : (tensor, tensor, tensor, tensor, tensor<4x4xf32>) -> (tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor, tensor) loc(#loc2) - %5 = stablehlo.constant dense<0> : tensor loc(#loc2) - %6 = stablehlo.broadcast_in_dim %5, dims = [] : (tensor) -> tensor loc(#loc2) - %7 = stablehlo.compare EQ, %4#5, %6, SIGNED : (tensor, tensor) -> tensor loc(#loc2) - %8 = stablehlo.broadcast_in_dim %7, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc2) - %9 = stablehlo.constant dense<0x7FC00000> : tensor loc(#loc2) - %10 = stablehlo.broadcast_in_dim %9, dims = [] : (tensor) -> tensor<4x4xf32> loc(#loc2) - %11 = stablehlo.broadcast_in_dim %8, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc2) - %12 = stablehlo.select %11, %4#0, %10 : tensor<4x4xi1>, tensor<4x4xf32> loc(#loc2) - %13 = stablehlo.broadcast_in_dim %7, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc2) - %14 = stablehlo.constant dense<0x7FC00000> : tensor loc(#loc2) - %15 = stablehlo.broadcast_in_dim %14, dims = [] : (tensor) -> tensor<4x4xf32> loc(#loc2) - %16 = stablehlo.broadcast_in_dim %13, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc2) - %17 = stablehlo.select %16, %4#3, %15 : tensor<4x4xi1>, tensor<4x4xf32> loc(#loc2) - return %12, %17 : tensor<4x4xf32>, tensor<4x4xf32> loc(#loc) - } loc(#loc) -} loc(#loc) -#loc1 = loc("/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py":483:0) -#loc2 = loc("jit(func)/jit(main)/schur[compute_schur_vectors=True sort_eig_vals=False select_callable=None]"(#loc1)) -""", - 
mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x01\x1f\x05\x01\x03\x01\x03\x05\x03\x0f\x07\t\x0b\r\x0f\x11\x13\x03\xd5\x97+\x01M\x0f\x0b\x13\x07\x0f\x0b\x0b\x13\x13#\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x13\x0b\x17\x0b\x13\x13\x13K\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x1b\x0b\x0b\x03K\x0fO\x0b/\x0f\x1b\x0b\x0b\x0b\x0b\x0b\x13\x13\x0b\x13\x0b\x0b\x0b\x1f\x1f\x13\x13\x0b\x0b\x0b\x0b\x0b\x1f\x0f\x17#\x1f\x0f\x0b\x0b\x1fO\x01\x03\x0f\x03)\x17\x0f\x0f\x07\x07\x07\x0f\x13\x07\x17\x17\x1b\x07\x07\x13\x13\x13\x13\x0f\x13\x02\xbe\x05\x1d')\x05\x15\x03\x03\r\x8d\x1f\x11\x01\x05\x05\x17\x05\x19\x03\x03\x03\x93\x03\x03\r\x95\x03\x07\x15\t\x17\t\x0b\x19\x05\x1b\x05\x1d\x05\x1f\x03\x0b\x1dU\x1fa!c\x0bm#o\x05!\x05#\x05%\x05'\x03\x03\x03q\x05)\x17+\x8e\x07\x01\x05+\x03\x03\x03s\x03\x03\x03u\x03\x03\x03w\x03\x115y7{9};\x7f=\x81?\x83A\x85C\x89\x05-\x05/\x051\x053\x055\x057\x059\x05;\x03\x03\x03\x8b\x03\x05I\x8fK\x91\x05=\x05?\x1f\x1f\x01\x1f!!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1dA\x1f#\x11\x00\x00\x00\x00\x00\x00\x00\x00\x03\x03W\r\x05Y[]_\x1dC\x1dE\x1dG\x1dI#\x19\x03\x05ei\r\x03Qg\x1dK\r\x03Qk\x1dM\x1dO\x1dQ\x1f\x05\t\x01\x00\x00\x00\x1f\x05\t\x04\x00\x00\x00\x1f\x07\x03V\x1f\x07\x03N\x0b\x05\x1dS\x1dU\x03\x01\x05\x01\x03\x0bMMMMO\x03\x03\x87\x15\x03\x01\x11\x01\x03\rOSSOMM\x1f\x05\t\x00\x00\x00\x00\x1f%\x01\t\x07\x07\x01\x1f\x0f\t\x00\x00\xc0\x7f\x1f)!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\x02\x02)\x05\x11\x11\t)\x01\x1b)\x01\x1d\t\x13\x01)\x01\t)\x03\x11\t\x1d)\x05\x05\x05\r)\x05\x11\x11\r\x11\x03\x03\x05\x03\x03\x1b!)\x03\x01\x0b)\x03\t\x0b)\x03\x05\x0b)\x03\x01\x13)\x01\r)\x03\t\x13\x04\xa2\x02\x05\x01\x11\x07\x13\x07\x03\x01\x05\t\x11\x07\x1b\x05\x031O\x03\x03\x07\x03\x03\x01%\x03\x05\x03\x03\x01-\x03\x05\x03\x03\x01/\x03\x07\x03\x03\x011\x03\x07\x0b\x07\x013\r\x03\x11\x11\x03\x05\x05\x0b\x03\x05\x07\t\x01\x03\x03\x01E\x03\x05\x05\x07\x01\x05\x03\x05\x03\x17\r\x07\x01G\x03'\x05\x15\x19\x05\x07\x01\x05\x03\x15\x03\x1b\x03\x03\x01\x0f\x03\x0f\x05\x07\x01\x05\x03\x03\x03\x1f\x05\x07\x01\x11\x03\x17\x03\x1d\x07\x06\x01\x03\x03\x07#\x0b!\x05\x07\x01\x05\x03\x15\x03\x1b\x03\x03\x01\x0f\x03\x0f\x05\x07\x01\x05\x03\x03\x03)\x05\x07\x01\x11\x03\x17\x03'\x07\x06\x01\x03\x03\x07-\x11+\x0f\x04\x07\x05%/\x06\x03\x01\x05\x01\x002\x0bW\x1b\x03\x0f\x0b\t\t\x1b\x1d\r\x1b!+\x1b\x1f/!!)#\x1f\x19\x97\xbf\x1f\x15\x1d\x15\x13%)+\x13\r\x15\x17\x1f\x11\x15)\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00broadcast_in_dim_v1\x00select_v1\x00func_v1\x00custom_call_v1\x00compare_v1\x00return_v1\x00value\x00sym_name\x00broadcast_dimensions\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00jit(func)/jit(main)/schur[compute_schur_vectors=True sort_eig_vals=False select_callable=None]\x00/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00compare_type\x00comparison_direction\x00jax.result_info\x00jax.arg_info\x00input\x00mhlo.sharding\x00{replicated}\x00[0]\x00[1]\x00main\x00public\x00\x00lapack_sgees\x00", - xla_call_module_version=6, -) # End paste - - -# Pasted from the test output (see back_compat_test.py module docstring) -data_2023_07_16["f64"] = dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['lapack_dgees'], - serialized_date=datetime.date(2023, 7, 16), - inputs=(array([[ 0., 1., 
2., 3.], - [ 4., 5., 6., 7.], - [ 8., 9., 10., 11.], - [12., 13., 14., 15.]]),), - expected_outputs=(array([[ 3.2464249196572958e+01, -1.3416407864998734e+01, - 1.4217165257496823e-15, 1.7257338996070338e-16], - [ 0.0000000000000000e+00, -2.4642491965729794e+00, - 4.0099214829607365e-16, 2.9384059908060751e-16], - [ 0.0000000000000000e+00, 0.0000000000000000e+00, - -1.5668631265126207e-15, 6.3403580326623540e-16], - [ 0.0000000000000000e+00, 0.0000000000000000e+00, - 0.0000000000000000e+00, 1.2369554016158485e-16]]), array([[-0.11417645138733855 , 0.8288327563197505 , - 0.4940336612834742 , -0.23649681080057947 ], - [-0.3300045986655475 , 0.4371463883638869 , - -0.8349858635153001 , -0.052901868866879136], - [-0.545832745943757 , 0.045460020408024784, - 0.18787074318017621 , 0.8152941701354965 ], - [-0.7616608932219662 , -0.3462263475478383 , - 0.1530814590516493 , -0.525895490468038 ]])), - mlir_module_text=r""" -#loc = loc(unknown) -module @jit_func attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main(%arg0: tensor<4x4xf64> {jax.arg_info = "input", mhlo.sharding = "{replicated}"} loc(unknown)) -> (tensor<4x4xf64> {jax.result_info = "[0]"}, tensor<4x4xf64> {jax.result_info = "[1]"}) { - %0 = stablehlo.constant dense<1> : tensor loc(#loc2) - %1 = stablehlo.constant dense<4> : tensor loc(#loc2) - %2 = stablehlo.constant dense<86> : tensor loc(#loc2) - %3 = stablehlo.constant dense<78> : tensor loc(#loc2) - %4:6 = stablehlo.custom_call @lapack_dgees(%0, %1, %2, %3, %arg0) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>, dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>]} : (tensor, tensor, tensor, tensor, tensor<4x4xf64>) -> (tensor<4x4xf64>, tensor<4xf64>, tensor<4xf64>, tensor<4x4xf64>, tensor, tensor) loc(#loc2) - %5 = stablehlo.constant dense<0> : tensor loc(#loc2) - %6 = stablehlo.broadcast_in_dim %5, dims = [] : (tensor) -> tensor loc(#loc2) - %7 = stablehlo.compare EQ, %4#5, %6, SIGNED : (tensor, tensor) -> tensor loc(#loc2) - %8 = stablehlo.broadcast_in_dim %7, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc2) - %9 = stablehlo.constant dense<0x7FF8000000000000> : tensor loc(#loc2) - %10 = stablehlo.broadcast_in_dim %9, dims = [] : (tensor) -> tensor<4x4xf64> loc(#loc2) - %11 = stablehlo.broadcast_in_dim %8, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc2) - %12 = stablehlo.select %11, %4#0, %10 : tensor<4x4xi1>, tensor<4x4xf64> loc(#loc2) - %13 = stablehlo.broadcast_in_dim %7, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc2) - %14 = stablehlo.constant dense<0x7FF8000000000000> : tensor loc(#loc2) - %15 = stablehlo.broadcast_in_dim %14, dims = [] : (tensor) -> tensor<4x4xf64> loc(#loc2) - %16 = stablehlo.broadcast_in_dim %13, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc2) - %17 = stablehlo.select %16, %4#3, %15 : tensor<4x4xi1>, tensor<4x4xf64> loc(#loc2) - return %12, %17 : tensor<4x4xf64>, tensor<4x4xf64> loc(#loc) - } loc(#loc) -} loc(#loc) -#loc1 = loc("/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py":483:0) -#loc2 = loc("jit(func)/jit(main)/schur[compute_schur_vectors=True sort_eig_vals=False select_callable=None]"(#loc1)) -""", - 
mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x01\x1f\x05\x01\x03\x01\x03\x05\x03\x0f\x07\t\x0b\r\x0f\x11\x13\x03\xd5\x97+\x01M\x0f\x0b\x13\x07\x0f\x0b\x0b\x13\x13#\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x13\x0b\x17\x0b\x13\x13\x13K\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x1b\x0b\x0b\x03K\x0fO\x0b/\x0f\x1b\x0b\x0b\x0b\x0b\x0b\x13\x13\x0b\x13\x0b\x0b\x0b\x1f\x1f\x13\x13\x0b\x0b\x0b\x0b\x0b\x1f\x0f\x17#\x1f\x0f\x0b\x0b/O\x01\x03\x0f\x03)\x17\x0f\x0f\x07\x07\x07\x0f\x13\x07\x17\x17\x1b\x07\x07\x13\x13\x13\x13\x0f\x13\x02\xce\x05\x1d')\x05\x15\x03\x03\r\x8d\x1f\x11\x01\x05\x05\x17\x05\x19\x03\x03\x03\x93\x03\x03\r\x95\x03\x07\x15\t\x17\t\x0b\x19\x05\x1b\x05\x1d\x05\x1f\x03\x0b\x1dU\x1fa!c\x0bm#o\x05!\x05#\x05%\x05'\x03\x03\x03q\x05)\x17+\x8e\x07\x01\x05+\x03\x03\x03s\x03\x03\x03u\x03\x03\x03w\x03\x115y7{9};\x7f=\x81?\x83A\x85C\x89\x05-\x05/\x051\x053\x055\x057\x059\x05;\x03\x03\x03\x8b\x03\x05I\x8fK\x91\x05=\x05?\x1f\x1f\x01\x1f!!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1dA\x1f#\x11\x00\x00\x00\x00\x00\x00\x00\x00\x03\x03W\r\x05Y[]_\x1dC\x1dE\x1dG\x1dI#\x19\x03\x05ei\r\x03Qg\x1dK\r\x03Qk\x1dM\x1dO\x1dQ\x1f\x05\t\x01\x00\x00\x00\x1f\x05\t\x04\x00\x00\x00\x1f\x07\x03V\x1f\x07\x03N\x0b\x05\x1dS\x1dU\x03\x01\x05\x01\x03\x0bMMMMO\x03\x03\x87\x15\x03\x01\x11\x01\x03\rOSSOMM\x1f\x05\t\x00\x00\x00\x00\x1f%\x01\t\x07\x07\x01\x1f\x0f\x11\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f)!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\x02\x02)\x05\x11\x11\t)\x01\x1b)\x01\x1d\x0b\x13\x01)\x01\t)\x03\x11\t\x1d)\x05\x05\x05\r)\x05\x11\x11\r\x11\x03\x03\x05\x03\x03\x1b!)\x03\x01\x0b)\x03\t\x0b)\x03\x05\x0b)\x03\x01\x13)\x01\r)\x03\t\x13\x04\xa2\x02\x05\x01\x11\x07\x13\x07\x03\x01\x05\t\x11\x07\x1b\x05\x031O\x03\x03\x07\x03\x03\x01%\x03\x05\x03\x03\x01-\x03\x05\x03\x03\x01/\x03\x07\x03\x03\x011\x03\x07\x0b\x07\x013\r\x03\x11\x11\x03\x05\x05\x0b\x03\x05\x07\t\x01\x03\x03\x01E\x03\x05\x05\x07\x01\x05\x03\x05\x03\x17\r\x07\x01G\x03'\x05\x15\x19\x05\x07\x01\x05\x03\x15\x03\x1b\x03\x03\x01\x0f\x03\x0f\x05\x07\x01\x05\x03\x03\x03\x1f\x05\x07\x01\x11\x03\x17\x03\x1d\x07\x06\x01\x03\x03\x07#\x0b!\x05\x07\x01\x05\x03\x15\x03\x1b\x03\x03\x01\x0f\x03\x0f\x05\x07\x01\x05\x03\x03\x03)\x05\x07\x01\x11\x03\x17\x03'\x07\x06\x01\x03\x03\x07-\x11+\x0f\x04\x07\x05%/\x06\x03\x01\x05\x01\x002\x0bW\x1b\x03\x0f\x0b\t\t\x1b\x1d\r\x1b!+\x1b\x1f/!!)#\x1f\x19\x97\xbf\x1f\x15\x1d\x15\x13%)+\x13\r\x15\x17\x1f\x11\x15)\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00broadcast_in_dim_v1\x00select_v1\x00func_v1\x00custom_call_v1\x00compare_v1\x00return_v1\x00value\x00sym_name\x00broadcast_dimensions\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00jit(func)/jit(main)/schur[compute_schur_vectors=True sort_eig_vals=False select_callable=None]\x00/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00compare_type\x00comparison_direction\x00jax.result_info\x00jax.arg_info\x00input\x00mhlo.sharding\x00{replicated}\x00[0]\x00[1]\x00main\x00public\x00\x00lapack_dgees\x00", - xla_call_module_version=6, -) # End paste - -# Pasted from the test output (see back_compat_test.py module docstring) -data_2023_07_16["c64"] = dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['lapack_cgees'], - serialized_date=datetime.date(2023, 7, 16), - 
inputs=(array([[ 0.+0.j, 1.+0.j, 2.+0.j, 3.+0.j], - [ 4.+0.j, 5.+0.j, 6.+0.j, 7.+0.j], - [ 8.+0.j, 9.+0.j, 10.+0.j, 11.+0.j], - [12.+0.j, 13.+0.j, 14.+0.j, 15.+0.j]], dtype=complex64),), - expected_outputs=(array([[ 3.2464264e+01+0.j, -1.3416414e+01+0.j, -3.3649465e-06+0.j, - 3.5482326e-06+0.j], - [ 0.0000000e+00+0.j, -2.4642489e+00+0.j, -7.4810049e-07+0.j, - 6.1193055e-07+0.j], - [ 0.0000000e+00+0.j, 0.0000000e+00+0.j, -5.7737759e-07+0.j, - 2.5704813e-07+0.j], - [ 0.0000000e+00+0.j, 0.0000000e+00+0.j, 0.0000000e+00+0.j, - 1.4719124e-07+0.j]], dtype=complex64), array([[ 0.11417647 +0.j, -0.8288329 +0.j, 0.5452458 +0.j, - -0.05202686 +0.j], - [ 0.3300045 +0.j, -0.43714625 +0.j, -0.68821627 +0.j, - 0.47577178 +0.j], - [ 0.54583293 +0.j, -0.045460097-0.j, -0.25930598 +0.j, - -0.79546237 +0.j], - [ 0.76166105 +0.j, 0.3462263 +0.j, 0.40227604 +0.j, - 0.37171766 +0.j]], dtype=complex64)), - mlir_module_text=r""" -#loc = loc(unknown) -module @jit_func attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main(%arg0: tensor<4x4xcomplex> {jax.arg_info = "input", mhlo.sharding = "{replicated}"} loc(unknown)) -> (tensor<4x4xcomplex> {jax.result_info = "[0]"}, tensor<4x4xcomplex> {jax.result_info = "[1]"}) { - %0 = stablehlo.constant dense<1> : tensor loc(#loc2) - %1 = stablehlo.constant dense<4> : tensor loc(#loc2) - %2 = stablehlo.constant dense<86> : tensor loc(#loc2) - %3 = stablehlo.constant dense<78> : tensor loc(#loc2) - %4:6 = stablehlo.custom_call @lapack_cgees(%0, %1, %2, %3, %arg0) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>, dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>]} : (tensor, tensor, tensor, tensor, tensor<4x4xcomplex>) -> (tensor<4x4xcomplex>, tensor<4xf32>, tensor<4xcomplex>, tensor<4x4xcomplex>, tensor, tensor) loc(#loc2) - %5 = stablehlo.constant dense<0> : tensor loc(#loc2) - %6 = stablehlo.broadcast_in_dim %5, dims = [] : (tensor) -> tensor loc(#loc2) - %7 = stablehlo.compare EQ, %4#5, %6, SIGNED : (tensor, tensor) -> tensor loc(#loc2) - %8 = stablehlo.broadcast_in_dim %7, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc2) - %9 = stablehlo.constant dense<(0x7FC00000,0x7FC00000)> : tensor> loc(#loc2) - %10 = stablehlo.broadcast_in_dim %9, dims = [] : (tensor>) -> tensor<4x4xcomplex> loc(#loc2) - %11 = stablehlo.broadcast_in_dim %8, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc2) - %12 = stablehlo.select %11, %4#0, %10 : tensor<4x4xi1>, tensor<4x4xcomplex> loc(#loc2) - %13 = stablehlo.broadcast_in_dim %7, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc2) - %14 = stablehlo.constant dense<(0x7FC00000,0x7FC00000)> : tensor> loc(#loc2) - %15 = stablehlo.broadcast_in_dim %14, dims = [] : (tensor>) -> tensor<4x4xcomplex> loc(#loc2) - %16 = stablehlo.broadcast_in_dim %13, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc2) - %17 = stablehlo.select %16, %4#3, %15 : tensor<4x4xi1>, tensor<4x4xcomplex> loc(#loc2) - return %12, %17 : tensor<4x4xcomplex>, tensor<4x4xcomplex> loc(#loc) - } loc(#loc) -} loc(#loc) -#loc1 = loc("/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py":483:0) -#loc2 = loc("jit(func)/jit(main)/schur[compute_schur_vectors=True 
sort_eig_vals=False select_callable=None]"(#loc1)) -""", - mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x01\x1f\x05\x01\x03\x01\x03\x05\x03\x0f\x07\t\x0b\r\x0f\x11\x13\x03\xd9\x97/\x01M\x0f\x0b\x13\x07\x0f\x0b\x0b\x13\x13#\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x13\x0b\x17\x0b\x13\x13\x13K\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x1b\x0b\x0b\x03K\x0fO\x0b/\x0f\x1b\x0b\x0b\x0b\x0b\x0b\x13\x13\x0b\x13\x0b\x0b\x0b\x1f\x1f\x13\x13\x0b\x0b\x0b\x0b\x0b\x1f\x0f\x17#\x1f\x0f\x0b\x0b/O\x01\x03\x0f\x03-\x17\x0f\x0f\x0b\x07\x07\x0f\x07\x07\x17\x17\x1b\x07\x07\x13\x13\x13\x13\x13\x13\x0f\x13\x02\xe6\x05\x1d')\x05\x15\x03\x03\r\x8d\x1f\x11\x01\x05\x05\x17\x05\x19\x03\x03\x03\x93\x03\x03\r\x95\x03\x07\x15\t\x17\t\x0b\x19\x05\x1b\x05\x1d\x05\x1f\x03\x0b\x1dU\x1fa!c\x0bm#o\x05!\x05#\x05%\x05'\x03\x03\x03q\x05)\x17+\x8e\x07\x01\x05+\x03\x03\x03s\x03\x03\x03u\x03\x03\x03w\x03\x115y7{9};\x7f=\x81?\x83A\x85C\x89\x05-\x05/\x051\x053\x055\x057\x059\x05;\x03\x03\x03\x8b\x03\x05I\x8fK\x91\x05=\x05?\x1f#\x01\x1f%!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1dA\x1f'\x11\x00\x00\x00\x00\x00\x00\x00\x00\x03\x03W\r\x05Y[]_\x1dC\x1dE\x1dG\x1dI#\x19\x03\x05ei\r\x03Qg\x1dK\r\x03Qk\x1dM\x1dO\x1dQ\x1f\x05\t\x01\x00\x00\x00\x1f\x05\t\x04\x00\x00\x00\x1f\x07\x03V\x1f\x07\x03N\x0b\x05\x1dS\x1dU\x03\x01\x05\x01\x03\x0bMMMMO\x03\x03\x87\x15\x03\x01\x11\x01\x03\rOSSOMM\x1f\x05\t\x00\x00\x00\x00\x1f)\x01\t\x07\x07\x01\x1f\x0f\x11\x00\x00\xc0\x7f\x00\x00\xc0\x7f\x1f-!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\x02\x02)\x05\x11\x11\t)\x01\x1b)\x01\x1d\x03\x11\x13\x01)\x01\t\t\x1d)\x05\x05\x05\r)\x05\x11\x11\r\x11\x03\x03\x05\x03\x03\x1b!)\x03\x11\x11)\x03\x11\t)\x03\x01\x0b)\x03\t\x0b)\x03\x05\x0b)\x03\x01\x13)\x01\r)\x03\t\x13\x04\xa2\x02\x05\x01\x11\x07\x13\x07\x03\x01\x05\t\x11\x07\x1b\x05\x031O\x03\x03\x07\x03\x03\x01%\x03\x05\x03\x03\x01-\x03\x05\x03\x03\x01/\x03\x07\x03\x03\x011\x03\x07\x0b\x07\x013\r\x03\x1f!\x03\x05\x05\x0b\x03\x05\x07\t\x01\x03\x03\x01E\x03\x05\x05\x07\x01\x05\x03\x05\x03\x17\r\x07\x01G\x03+\x05\x15\x19\x05\x07\x01\x05\x03\x15\x03\x1b\x03\x03\x01\x0f\x03\x0f\x05\x07\x01\x05\x03\x03\x03\x1f\x05\x07\x01\x11\x03\x17\x03\x1d\x07\x06\x01\x03\x03\x07#\x0b!\x05\x07\x01\x05\x03\x15\x03\x1b\x03\x03\x01\x0f\x03\x0f\x05\x07\x01\x05\x03\x03\x03)\x05\x07\x01\x11\x03\x17\x03'\x07\x06\x01\x03\x03\x07-\x11+\x0f\x04\x07\x05%/\x06\x03\x01\x05\x01\x002\x0bW\x1b\x03\x0f\x0b\t\t\x1b\x1d\r\x1b!+\x1b\x1f/!!)#\x1f\x19\x97\xbf\x1f\x15\x1d\x15\x13%)+\x13\r\x15\x17\x1f\x11\x15)\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00broadcast_in_dim_v1\x00select_v1\x00func_v1\x00custom_call_v1\x00compare_v1\x00return_v1\x00value\x00sym_name\x00broadcast_dimensions\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00jit(func)/jit(main)/schur[compute_schur_vectors=True sort_eig_vals=False select_callable=None]\x00/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00compare_type\x00comparison_direction\x00jax.result_info\x00jax.arg_info\x00input\x00mhlo.sharding\x00{replicated}\x00[0]\x00[1]\x00main\x00public\x00\x00lapack_cgees\x00", - xla_call_module_version=6, -) # End paste - -# Pasted from the test output (see back_compat_test.py module docstring) -data_2023_07_16["c128"] = dict( - testdata_version=1, - platform='cpu', - 
custom_call_targets=['lapack_zgees'], - serialized_date=datetime.date(2023, 7, 16), - inputs=(array([[ 0.+0.j, 1.+0.j, 2.+0.j, 3.+0.j], - [ 4.+0.j, 5.+0.j, 6.+0.j, 7.+0.j], - [ 8.+0.j, 9.+0.j, 10.+0.j, 11.+0.j], - [12.+0.j, 13.+0.j, 14.+0.j, 15.+0.j]]),), - expected_outputs=(array([[ 3.2464249196572965e+01+0.j, -1.3416407864998730e+01+0.j, - 4.3084836728703156e-15+0.j, 2.8665351303736084e-15+0.j], - [ 0.0000000000000000e+00+0.j, -2.4642491965729802e+00+0.j, - -2.3716026934523430e-16+0.j, 3.7279396143672773e-16+0.j], - [ 0.0000000000000000e+00+0.j, 0.0000000000000000e+00+0.j, - -1.6035677295293287e-15+0.j, -6.3403580326623540e-16+0.j], - [ 0.0000000000000000e+00+0.j, 0.0000000000000000e+00+0.j, - 0.0000000000000000e+00+0.j, 1.2218554396786608e-16+0.j]]), array([[ 0.11417645138733863+0.j, -0.8288327563197504 +0.j, - 0.4960613110079619 +0.j, 0.2322136424094458 +0.j], - [ 0.33000459866554754+0.j, -0.43714638836388703+0.j, - -0.8344969112540657 +0.j, 0.06012408092789509+0.j], - [ 0.5458327459437572 +0.j, -0.04546002040802478-0.j, - 0.18080988948424495+0.j, -0.8168890890841272 +0.j], - [ 0.7616608932219662 +0.j, 0.34622634754783854+0.j, - 0.15762571076185886+0.j, 0.5245513657467864 +0.j]])), - mlir_module_text=r""" -#loc = loc(unknown) -module @jit_func attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main(%arg0: tensor<4x4xcomplex> {jax.arg_info = "input", mhlo.sharding = "{replicated}"} loc(unknown)) -> (tensor<4x4xcomplex> {jax.result_info = "[0]"}, tensor<4x4xcomplex> {jax.result_info = "[1]"}) { - %0 = stablehlo.constant dense<1> : tensor loc(#loc2) - %1 = stablehlo.constant dense<4> : tensor loc(#loc2) - %2 = stablehlo.constant dense<86> : tensor loc(#loc2) - %3 = stablehlo.constant dense<78> : tensor loc(#loc2) - %4:6 = stablehlo.custom_call @lapack_zgees(%0, %1, %2, %3, %arg0) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>, dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>]} : (tensor, tensor, tensor, tensor, tensor<4x4xcomplex>) -> (tensor<4x4xcomplex>, tensor<4xf64>, tensor<4xcomplex>, tensor<4x4xcomplex>, tensor, tensor) loc(#loc2) - %5 = stablehlo.constant dense<0> : tensor loc(#loc2) - %6 = stablehlo.broadcast_in_dim %5, dims = [] : (tensor) -> tensor loc(#loc2) - %7 = stablehlo.compare EQ, %4#5, %6, SIGNED : (tensor, tensor) -> tensor loc(#loc2) - %8 = stablehlo.broadcast_in_dim %7, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc2) - %9 = stablehlo.constant dense<(0x7FF8000000000000,0x7FF8000000000000)> : tensor> loc(#loc2) - %10 = stablehlo.broadcast_in_dim %9, dims = [] : (tensor>) -> tensor<4x4xcomplex> loc(#loc2) - %11 = stablehlo.broadcast_in_dim %8, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc2) - %12 = stablehlo.select %11, %4#0, %10 : tensor<4x4xi1>, tensor<4x4xcomplex> loc(#loc2) - %13 = stablehlo.broadcast_in_dim %7, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc2) - %14 = stablehlo.constant dense<(0x7FF8000000000000,0x7FF8000000000000)> : tensor> loc(#loc2) - %15 = stablehlo.broadcast_in_dim %14, dims = [] : (tensor>) -> tensor<4x4xcomplex> loc(#loc2) - %16 = stablehlo.broadcast_in_dim %13, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc2) - %17 = stablehlo.select 
%16, %4#3, %15 : tensor<4x4xi1>, tensor<4x4xcomplex> loc(#loc2) - return %12, %17 : tensor<4x4xcomplex>, tensor<4x4xcomplex> loc(#loc) - } loc(#loc) -} loc(#loc) -#loc1 = loc("/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py":483:0) -#loc2 = loc("jit(func)/jit(main)/schur[compute_schur_vectors=True sort_eig_vals=False select_callable=None]"(#loc1)) -""", - mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x01\x1f\x05\x01\x03\x01\x03\x05\x03\x0f\x07\t\x0b\r\x0f\x11\x13\x03\xd9\x97/\x01M\x0f\x0b\x13\x07\x0f\x0b\x0b\x13\x13#\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x13\x0b\x17\x0b\x13\x13\x13K\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x1b\x0b\x0b\x03K\x0fO\x0b/\x0f\x1b\x0b\x0b\x0b\x0b\x0b\x13\x13\x0b\x13\x0b\x0b\x0b\x1f\x1f\x13\x13\x0b\x0b\x0b\x0b\x0b\x1f\x0f\x17#\x1f\x0f\x0b\x0bOO\x01\x03\x0f\x03-\x17\x0f\x0f\x0b\x07\x07\x0f\x07\x07\x17\x17\x1b\x07\x07\x13\x13\x13\x13\x13\x13\x0f\x13\x02\x06\x06\x1d')\x05\x15\x03\x03\r\x8d\x1f\x11\x01\x05\x05\x17\x05\x19\x03\x03\x03\x93\x03\x03\r\x95\x03\x07\x15\t\x17\t\x0b\x19\x05\x1b\x05\x1d\x05\x1f\x03\x0b\x1dU\x1fa!c\x0bm#o\x05!\x05#\x05%\x05'\x03\x03\x03q\x05)\x17+\x8e\x07\x01\x05+\x03\x03\x03s\x03\x03\x03u\x03\x03\x03w\x03\x115y7{9};\x7f=\x81?\x83A\x85C\x89\x05-\x05/\x051\x053\x055\x057\x059\x05;\x03\x03\x03\x8b\x03\x05I\x8fK\x91\x05=\x05?\x1f#\x01\x1f%!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1dA\x1f'\x11\x00\x00\x00\x00\x00\x00\x00\x00\x03\x03W\r\x05Y[]_\x1dC\x1dE\x1dG\x1dI#\x19\x03\x05ei\r\x03Qg\x1dK\r\x03Qk\x1dM\x1dO\x1dQ\x1f\x05\t\x01\x00\x00\x00\x1f\x05\t\x04\x00\x00\x00\x1f\x07\x03V\x1f\x07\x03N\x0b\x05\x1dS\x1dU\x03\x01\x05\x01\x03\x0bMMMMO\x03\x03\x87\x15\x03\x01\x11\x01\x03\rOSSOMM\x1f\x05\t\x00\x00\x00\x00\x1f)\x01\t\x07\x07\x01\x1f\x0f!\x00\x00\x00\x00\x00\x00\xf8\x7f\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f-!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\x02\x02)\x05\x11\x11\t)\x01\x1b)\x01\x1d\x03\x11\x13\x01)\x01\t\x0b\x1d)\x05\x05\x05\r)\x05\x11\x11\r\x11\x03\x03\x05\x03\x03\x1b!)\x03\x11\x11)\x03\x11\t)\x03\x01\x0b)\x03\t\x0b)\x03\x05\x0b)\x03\x01\x13)\x01\r)\x03\t\x13\x04\xa2\x02\x05\x01\x11\x07\x13\x07\x03\x01\x05\t\x11\x07\x1b\x05\x031O\x03\x03\x07\x03\x03\x01%\x03\x05\x03\x03\x01-\x03\x05\x03\x03\x01/\x03\x07\x03\x03\x011\x03\x07\x0b\x07\x013\r\x03\x1f!\x03\x05\x05\x0b\x03\x05\x07\t\x01\x03\x03\x01E\x03\x05\x05\x07\x01\x05\x03\x05\x03\x17\r\x07\x01G\x03+\x05\x15\x19\x05\x07\x01\x05\x03\x15\x03\x1b\x03\x03\x01\x0f\x03\x0f\x05\x07\x01\x05\x03\x03\x03\x1f\x05\x07\x01\x11\x03\x17\x03\x1d\x07\x06\x01\x03\x03\x07#\x0b!\x05\x07\x01\x05\x03\x15\x03\x1b\x03\x03\x01\x0f\x03\x0f\x05\x07\x01\x05\x03\x03\x03)\x05\x07\x01\x11\x03\x17\x03'\x07\x06\x01\x03\x03\x07-\x11+\x0f\x04\x07\x05%/\x06\x03\x01\x05\x01\x002\x0bW\x1b\x03\x0f\x0b\t\t\x1b\x1d\r\x1b!+\x1b\x1f/!!)#\x1f\x19\x97\xbf\x1f\x15\x1d\x15\x13%)+\x13\r\x15\x17\x1f\x11\x15)\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00broadcast_in_dim_v1\x00select_v1\x00func_v1\x00custom_call_v1\x00compare_v1\x00return_v1\x00value\x00sym_name\x00broadcast_dimensions\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00jit(func)/jit(main)/schur[compute_schur_vectors=True sort_eig_vals=False 
select_callable=None]\x00/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00compare_type\x00comparison_direction\x00jax.result_info\x00jax.arg_info\x00input\x00mhlo.sharding\x00{replicated}\x00[0]\x00[1]\x00main\x00public\x00\x00lapack_zgees\x00", - xla_call_module_version=6, -) # End paste +from numpy import array, float32, complex64 data_2024_11_29 = {} diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/cpu_svd_lapack_gesdd.py b/jax/_src/internal_test_util/export_back_compat_test_data/cpu_svd_lapack_gesdd.py index 2d71308caeda..995847a03a60 100644 --- a/jax/_src/internal_test_util/export_back_compat_test_data/cpu_svd_lapack_gesdd.py +++ b/jax/_src/internal_test_util/export_back_compat_test_data/cpu_svd_lapack_gesdd.py @@ -17,435 +17,8 @@ import datetime from numpy import array, float32, complex64 -data_2023_06_19 = {} - - - -# Pasted from the test output (see back_compat_test.py module docstring) -data_2023_06_19["f32"] = dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['lapack_sgesdd'], - serialized_date=datetime.date(2023, 6, 22), - inputs=(array([[[ 1.5410905 , -2.775912 , -2.374003 , 4.028736 ], - [-0.56933475, 1.6115232 , 0.9041465 , -0.8321383 ], - [-5.382895 , 4.734856 , 2.1972926 , 1.5553856 ], - [ 0.5109847 , -1.1969309 , 3.3766198 , -1.3678027 ]], - - [[ 2.2637439 , 3.406768 , 4.809871 , 2.8010902 ], - [-1.9981416 , -0.6599986 , 0.5138156 , 4.5982494 ], - [-2.335944 , -9.151717 , -1.0481138 , 2.272443 ], - [-8.257684 , 1.8223318 , 0.38403794, 5.0769973 ]]], - dtype=float32),), - expected_outputs=(array([[[-0.48540133 , 0.6682397 , -0.48819906 , -0.28196266 ], - [ 0.2180054 , -0.13631375 , 0.14819765 , -0.95495003 ], - [ 0.8457052 , 0.44643915 , -0.27943406 , 0.08597418 ], - [ 0.040523227, -0.57928085 , -0.8133977 , -0.03429017 ]], - - [[-0.21146733 , 0.46376425 , 0.786309 , 0.34917438 ], - [ 0.3461469 , 0.21883713 , 0.3399653 , -0.84659094 ], - [ 0.6526192 , -0.5834038 , 0.3972404 , 0.2755518 ], - [ 0.6399631 , 0.6298203 , -0.32915345 , 0.2922879 ]]], - dtype=float32), array([[ 8.551608 , 5.3574076, 2.8073738, 0.5226082], - [11.457576 , 10.041606 , 5.6716514, 1.4754109]], dtype=float32), array([[[-0.6319046 , 0.6612254 , 0.39110154 , -0.102553196], - [-0.2971051 , 0.13673358 , -0.50112 , 0.80119365 ], - [ 0.08969147 , 0.4433047 , -0.73647296 , -0.5030348 ], - [-0.7101976 , -0.5895471 , -0.23135659 , -0.30745354 ]], - - [[-0.6964344 , -0.5023085 , -0.11150039 , 0.50023323 ], - [-0.32121164 , 0.7889568 , 0.3183193 , 0.41598475 ], - [ 0.5096958 , -0.31399378 , 0.60193455 , 0.5284816 ], - [-0.3898877 , -0.16322286 , 0.7238198 , -0.5453721 ]]], - dtype=float32)), - mlir_module_text=r""" -#loc = loc(unknown) -module @jit_func attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main(%arg0: tensor<2x4x4xf32> {jax.arg_info = "input", mhlo.sharding = "{replicated}"} loc(unknown)) -> (tensor<2x4x4xf32> {jax.result_info = "[0]"}, tensor<2x4xf32> {jax.result_info = "[1]"}, tensor<2x4x4xf32> {jax.result_info = "[2]"}) { - %0 = stablehlo.constant dense<1> : tensor loc(#loc2) - %1 = stablehlo.constant dense<1> : tensor loc(#loc2) - %2 = stablehlo.constant dense<2> : tensor loc(#loc2) - %3 = stablehlo.constant dense<4> : tensor loc(#loc2) - %4 = stablehlo.constant dense<4> : tensor loc(#loc2) - %5 = stablehlo.constant dense<268> : 
tensor loc(#loc2) - %6:7 = stablehlo.custom_call @lapack_sgesdd(%0, %1, %2, %3, %4, %5, %arg0) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[1, 2, 0]> : tensor<3xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 0]> : tensor<2xindex>, dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 2, 0]> : tensor<3xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>]} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor<2x4x4xf32>) -> (tensor<2x4x4xf32>, tensor<2x4xf32>, tensor<2x4x4xf32>, tensor<2x4x4xf32>, tensor<2xi32>, tensor<32xi32>, tensor<268xf32>) loc(#loc2) - %7 = stablehlo.constant dense<0> : tensor loc(#loc2) - %8 = stablehlo.broadcast_in_dim %7, dims = [] : (tensor) -> tensor<2xi32> loc(#loc2) - %9 = stablehlo.compare EQ, %6#4, %8, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc2) - %10 = stablehlo.broadcast_in_dim %9, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc2) - %11 = stablehlo.constant dense<0x7FC00000> : tensor loc(#loc2) - %12 = stablehlo.broadcast_in_dim %11, dims = [] : (tensor) -> tensor<2x4xf32> loc(#loc2) - %13 = stablehlo.broadcast_in_dim %10, dims = [0, 1] : (tensor<2x1xi1>) -> tensor<2x4xi1> loc(#loc2) - %14 = stablehlo.select %13, %6#1, %12 : tensor<2x4xi1>, tensor<2x4xf32> loc(#loc2) - %15 = stablehlo.broadcast_in_dim %9, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> loc(#loc2) - %16 = stablehlo.constant dense<0x7FC00000> : tensor loc(#loc2) - %17 = stablehlo.broadcast_in_dim %16, dims = [] : (tensor) -> tensor<2x4x4xf32> loc(#loc2) - %18 = stablehlo.broadcast_in_dim %15, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x4x4xi1> loc(#loc2) - %19 = stablehlo.select %18, %6#2, %17 : tensor<2x4x4xi1>, tensor<2x4x4xf32> loc(#loc2) - %20 = stablehlo.broadcast_in_dim %9, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> loc(#loc2) - %21 = stablehlo.constant dense<0x7FC00000> : tensor loc(#loc2) - %22 = stablehlo.broadcast_in_dim %21, dims = [] : (tensor) -> tensor<2x4x4xf32> loc(#loc2) - %23 = stablehlo.broadcast_in_dim %20, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x4x4xi1> loc(#loc2) - %24 = stablehlo.select %23, %6#3, %22 : tensor<2x4x4xi1>, tensor<2x4x4xf32> loc(#loc2) - return %19, %14, %24 : tensor<2x4x4xf32>, tensor<2x4xf32>, tensor<2x4x4xf32> loc(#loc) - } loc(#loc) -} loc(#loc) -#loc1 = loc("/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py":355:0) -#loc2 = loc("jit(func)/jit(main)/svd[full_matrices=True compute_uv=True]"(#loc1)) -""", - 
mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x01\x1f\x05\x01\x03\x01\x03\x05\x03\x0f\x07\t\x0b\r\x0f\x11\x13\x03\xef\xa57\x01Q\x0f\x0b\x07\x13\x0b\x13\x13\x0f\x0b\x13\x13\x13#\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x0b\x17\x0b\x13\x13K\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x1b\x0b\x0b\x13\x03U\x0fo\x0b/\x0f\x1b\x0b\x0b\x0b\x0b\x0b\x17\x13\x0b\x13\x0b\x13\x0b\x0b\x0b\x1f\x1f\x1f\x1f\x0b\x0b\x0b\x0b\x0b'\x0f\x17'O\x1f\x0f\x0b\x0b/\x1fOo\x01\x03\x0f\x035\x0f\x1b\x07\x07\x17\x07\x07\x0f\x07\x13\x1b\x1b\x1f\x13\x17\x13\x13\x13\x13\x13\x13\x17\x13\x17\x13\x13\x02\xb6\x07\x1d+-\x05\x15\x1f\x03\x03\t\x97\x05\x17\x03\x03\t\x9d\x03\x03\x03\x9f\x11\x01\x05\x05\x19\x03\x03\x03y\x03\x03\x03}\x03\x03\t\xa3\x03\x07\x1b\x0f\x1d\x0f\x11\x1f\x05\x1b\x05\x1d\x05\x1f\x03\x0b#Y%e'g\x11u)w\x05!\x05#\x05%\x05'\x05)\x17/\x8e\x05\x01\x05+\x03\x03\x03{\x03\x03\x03\x7f\x03\x117\x819\x83;\x85=\x87?\x89A\x8bC\x8dE\x91\x05-\x05/\x051\x053\x055\x057\x059\x05;\x03\x03\x03\x95\x03\x05K\x99M\x9b\x05=\x05?\x03\x03\t\xa1\x1f!\x01\x1f#1\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1dA\x1f'\x11\x00\x00\x00\x00\x00\x00\x00\x00\x03\x03[\r\x05]_ac\x1dC\x1dE\x1dG\x1dI#\x1b\x03\x07imq\r\x03Uk\x1dK\r\x03Uo\x1dM\r\x03Us\x1dO\x1dQ\x1dS\x1f\x03\t\x01\x00\x00\x00\x1f\x03\t\x02\x00\x00\x00\x1f\x03\t\x04\x00\x00\x00\x1f\x03\t\x0c\x01\x00\x00\x0b\x05\x1dU\x1dW\x03\x01\x05\x01\x03\x0fQQQQQQS\x03\x03\x8f\x15\x03\x01\x19\x01\x03\x0fS\x93SSWWW\x1f%!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x03\t\x00\x00\x00\x00\x1f)\x01\t\x07\x07\x01\x1f/\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x11\t\x00\x00\xc0\x7f\x1f3!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f51\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x01\x02\x02)\x01\x13)\x07\t\x11\x11\t\x01\t)\x05\t\x11\t\x13\x1d)\x01\t\x1b)\x03\t\x13)\x07\t\x05\x05\x07)\x07\t\x11\x11\x07\x11\x03\x05\x07\x05\x0b\x05)\x03\x81\x13)\x03b\x08\t)\x03\x01\r)\x03\r\r)\x03\t\r)\x03\x05\r)\x03\x01\x0f)\x03\t\x07)\x05\t\x05\x07)\x03\x05\x0f)\x05\t\x11\x07)\x03\t\x0f)\x03\r\x0f\x04~\x03\x05\x01\x11\x05\x19\x07\x03\x01\x05\t\x11\x05!\x05\x03Ak\x03\x05\x05\x03\x03\x01\x13\x03\x03\x03\x03\x01\x13\x03\x03\x03\x03\x011\x03\x03\x03\x03\x01\x15\x03\x03\x03\x03\x01\x15\x03\x03\x03\x03\x013\x03\x03\x0b\x07\x015\x0f\x05\x0b\x05\x05\x15\x1d\x1f\x0f\x03\x05\x07\t\x0b\r\x01\x03\x03\x01G\x03\x03\x05\x07\x01\x07\x03\x15\x03\x1d\r\x07\x01I\x03+\x05\x17\x1f\x05\x07\x01\x0b\x03-\x03!\x03\x03\x01\r\x03\x11\x05\x07\x01\x07\x03\x0b\x03%\x05\x07\x01O\x031\x03#\x07\x06\x01\x03\x0b\x07)\x11'\x05\x07\x01\x0b\x03\x17\x03!\x03\x03\x01\r\x03\x11\x05\x07\x01\x07\x03\x05\x03/\x05\x07\x01\x17\x03\x19\x03-\x07\x06\x01\x03\x05\x073\x131\x05\x07\x01\x0b\x03\x17\x03!\x03\x03\x01\r\x03\x11\x05\x07\x01\x07\x03\x05\x039\x05\x07\x01\x17\x03\x19\x037\x07\x06\x01\x03\x05\x07=\x15;\x0f\x04\x05\x075+?\x06\x03\x01\x05\x01\x00\xbe\nY\x1d\x03\x0f\x0b\t\t\t\x1b\x1d\r\x1b!+\x1b\x1f/!!)#\x1f\x19\x97y\x1f\x15\x1d\x15\x13%)\x13+\r\x15\x17\x1f\x11\x15)\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00broadcast_in_dim_v1\x00select_v1\x00func_v1\x00custom_call_v1\x00compare_v1\x00return_v1\x00value\x00broadcast_dimensions\x00sym_name\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00jit(func)/jit(main)/svd[full_matrices=True 
compute_uv=True]\x00/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00compare_type\x00comparison_direction\x00jax.result_info\x00jax.arg_info\x00input\x00mhlo.sharding\x00{replicated}\x00[0]\x00[1]\x00[2]\x00main\x00public\x00\x00lapack_sgesdd\x00", - xla_call_module_version=6, -) # End paste - - -# Pasted from the test output (see back_compat_test.py module docstring) -data_2023_06_19["f64"] = dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['lapack_dgesdd'], - serialized_date=datetime.date(2023, 6, 22), - inputs=(array([[[ 0.3445689867809981 , 3.5114993759427104 , - 4.702602090972179 , -0.2702264758497052 ], - [ 2.209901632583705 , -2.6286702510632773 , - 4.591276599385847 , 3.4465035398844828 ], - [-1.5083742421154478 , 3.3225165204269635 , - 1.2596205557926703 , 3.524804355848018 ], - [ 1.5118969169108838 , 1.838885943509677 , - 2.818520751293422 , 3.06002540493494 ]], - - [[-2.4045510943950843 , -1.5657555633438576 , - -0.6061472334580296 , -0.23926156407779164], - [ 4.087879920053448 , -3.2507640936811715 , - -2.2556577657517476 , 6.090369998330348 ], - [ 1.1165401344486945 , 2.2134726894037247 , - 5.225178515435584 , 1.9794693474107725 ], - [-4.127878192684534 , -0.37313660200336163, - 0.7893465897510026 , -2.0315217791342848 ]]]),), - expected_outputs=(array([[[-0.5109626909166218 , -0.41744996156105785, - -0.731253241567692 , 0.1729779025790829 ], - [-0.5623501368035175 , 0.7608931604238581 , - 0.03470920608540986, 0.32186828528169453], - [-0.39585755254587435, -0.4954770291405409 , - 0.6561880513437818 , 0.4089212062978684 ], - [-0.5157288533916834 , -0.03577207859388855, - 0.18297871183094833, -0.8362194085221047 ]], - - [[-0.12124821978030875, -0.30260506534356213, - -0.5817463045715607 , -0.7451847292758064 ], - [ 0.8877417367326685 , -0.15794001239879188, - -0.3761180739267688 , 0.2133184375808915 ], - [ 0.03055221675864994, 0.9244545314395409 , - -0.3686107533067095 , -0.09260936183071355], - [-0.44303503260363514, -0.16990864078317836, - -0.619864940232637 , 0.624994775612963 ]]]), array([[8.951386926411189 , 5.762891699811626 , 3.839104008889441 , - 1.2696468971033248 ], - [9.21500688857692 , 6.477297670883227 , 3.24626945855818 , - 0.05112101994354587]]), array([[[-0.17890276924244797 , -0.2881812520705063 , - -0.7749616998111006 , -0.5332726590950898 ], - [ 0.38712159387038353 , -0.8985113987184378 , - 0.1397618670046424 , 0.15258033445914954 ], - [-0.23140697924040152 , -0.03708202130554661 , - -0.5045854966104308 , 0.8309447696839614 ], - [-0.8744034999217865 , -0.32901938548360005 , - 0.35396957633060866 , -0.043246992182741084]], - - [[ 0.6276106632546885 , -0.26728735347872895 , - -0.22995258718774078 , 0.6941067163520401 ], - [ 0.2802931697592562 , 0.4781137804659157 , - 0.808362569504731 , 0.19847646746808023 ], - [ 0.6187014005224262 , 0.47714095343944474 , - -0.3740686697560633 , -0.49961757159793246 ], - [-0.3804591585793503 , 0.6872417290515944 , - -0.3921025301835001 , 0.47875384105714014 ]]])), - mlir_module_text=r""" -#loc = loc(unknown) -module @jit_func attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main(%arg0: tensor<2x4x4xf64> {jax.arg_info = "input", mhlo.sharding = "{replicated}"} loc(unknown)) -> (tensor<2x4x4xf64> {jax.result_info = "[0]"}, tensor<2x4xf64> {jax.result_info = "[1]"}, tensor<2x4x4xf64> 
{jax.result_info = "[2]"}) { - %0 = stablehlo.constant dense<1> : tensor loc(#loc2) - %1 = stablehlo.constant dense<1> : tensor loc(#loc2) - %2 = stablehlo.constant dense<2> : tensor loc(#loc2) - %3 = stablehlo.constant dense<4> : tensor loc(#loc2) - %4 = stablehlo.constant dense<4> : tensor loc(#loc2) - %5 = stablehlo.constant dense<268> : tensor loc(#loc2) - %6:7 = stablehlo.custom_call @lapack_dgesdd(%0, %1, %2, %3, %4, %5, %arg0) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[1, 2, 0]> : tensor<3xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 0]> : tensor<2xindex>, dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 2, 0]> : tensor<3xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>]} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor<2x4x4xf64>) -> (tensor<2x4x4xf64>, tensor<2x4xf64>, tensor<2x4x4xf64>, tensor<2x4x4xf64>, tensor<2xi32>, tensor<32xi32>, tensor<268xf64>) loc(#loc2) - %7 = stablehlo.constant dense<0> : tensor loc(#loc2) - %8 = stablehlo.broadcast_in_dim %7, dims = [] : (tensor) -> tensor<2xi32> loc(#loc2) - %9 = stablehlo.compare EQ, %6#4, %8, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc2) - %10 = stablehlo.broadcast_in_dim %9, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc2) - %11 = stablehlo.constant dense<0x7FF8000000000000> : tensor loc(#loc2) - %12 = stablehlo.broadcast_in_dim %11, dims = [] : (tensor) -> tensor<2x4xf64> loc(#loc2) - %13 = stablehlo.broadcast_in_dim %10, dims = [0, 1] : (tensor<2x1xi1>) -> tensor<2x4xi1> loc(#loc2) - %14 = stablehlo.select %13, %6#1, %12 : tensor<2x4xi1>, tensor<2x4xf64> loc(#loc2) - %15 = stablehlo.broadcast_in_dim %9, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> loc(#loc2) - %16 = stablehlo.constant dense<0x7FF8000000000000> : tensor loc(#loc2) - %17 = stablehlo.broadcast_in_dim %16, dims = [] : (tensor) -> tensor<2x4x4xf64> loc(#loc2) - %18 = stablehlo.broadcast_in_dim %15, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x4x4xi1> loc(#loc2) - %19 = stablehlo.select %18, %6#2, %17 : tensor<2x4x4xi1>, tensor<2x4x4xf64> loc(#loc2) - %20 = stablehlo.broadcast_in_dim %9, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> loc(#loc2) - %21 = stablehlo.constant dense<0x7FF8000000000000> : tensor loc(#loc2) - %22 = stablehlo.broadcast_in_dim %21, dims = [] : (tensor) -> tensor<2x4x4xf64> loc(#loc2) - %23 = stablehlo.broadcast_in_dim %20, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x4x4xi1> loc(#loc2) - %24 = stablehlo.select %23, %6#3, %22 : tensor<2x4x4xi1>, tensor<2x4x4xf64> loc(#loc2) - return %19, %14, %24 : tensor<2x4x4xf64>, tensor<2x4xf64>, tensor<2x4x4xf64> loc(#loc) - } loc(#loc) -} loc(#loc) -#loc1 = loc("/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py":355:0) -#loc2 = loc("jit(func)/jit(main)/svd[full_matrices=True compute_uv=True]"(#loc1)) -""", - 
mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x01\x1f\x05\x01\x03\x01\x03\x05\x03\x0f\x07\t\x0b\r\x0f\x11\x13\x03\xef\xa57\x01Q\x0f\x0b\x07\x13\x0b\x13\x13\x0f\x0b\x13\x13\x13#\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x0b\x17\x0b\x13\x13K\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x1b\x0b\x0b\x13\x03U\x0fo\x0b/\x0f\x1b\x0b\x0b\x0b\x0b\x0b\x17\x13\x0b\x13\x0b\x13\x0b\x0b\x0b\x1f\x1f\x1f\x1f\x0b\x0b\x0b\x0b\x0b'\x0f\x17'O\x1f\x0f\x0b\x0b//Oo\x01\x03\x0f\x035\x0f\x1b\x07\x07\x17\x07\x07\x0f\x07\x13\x1b\x1b\x1f\x13\x17\x13\x13\x13\x13\x13\x13\x17\x13\x17\x13\x13\x02\xc6\x07\x1d+-\x05\x15\x1f\x03\x03\t\x97\x05\x17\x03\x03\t\x9d\x03\x03\x03\x9f\x11\x01\x05\x05\x19\x03\x03\x03y\x03\x03\x03}\x03\x03\t\xa3\x03\x07\x1b\x0f\x1d\x0f\x11\x1f\x05\x1b\x05\x1d\x05\x1f\x03\x0b#Y%e'g\x11u)w\x05!\x05#\x05%\x05'\x05)\x17/\x8e\x05\x01\x05+\x03\x03\x03{\x03\x03\x03\x7f\x03\x117\x819\x83;\x85=\x87?\x89A\x8bC\x8dE\x91\x05-\x05/\x051\x053\x055\x057\x059\x05;\x03\x03\x03\x95\x03\x05K\x99M\x9b\x05=\x05?\x03\x03\t\xa1\x1f!\x01\x1f#1\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1dA\x1f'\x11\x00\x00\x00\x00\x00\x00\x00\x00\x03\x03[\r\x05]_ac\x1dC\x1dE\x1dG\x1dI#\x1b\x03\x07imq\r\x03Uk\x1dK\r\x03Uo\x1dM\r\x03Us\x1dO\x1dQ\x1dS\x1f\x03\t\x01\x00\x00\x00\x1f\x03\t\x02\x00\x00\x00\x1f\x03\t\x04\x00\x00\x00\x1f\x03\t\x0c\x01\x00\x00\x0b\x05\x1dU\x1dW\x03\x01\x05\x01\x03\x0fQQQQQQS\x03\x03\x8f\x15\x03\x01\x19\x01\x03\x0fS\x93SSWWW\x1f%!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x03\t\x00\x00\x00\x00\x1f)\x01\t\x07\x07\x01\x1f/\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x11\x11\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f3!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f51\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x01\x02\x02)\x01\x13)\x07\t\x11\x11\t\x01\x0b)\x05\t\x11\t\x13\x1d)\x01\t\x1b)\x03\t\x13)\x07\t\x05\x05\x07)\x07\t\x11\x11\x07\x11\x03\x05\x07\x05\x0b\x05)\x03\x81\x13)\x03b\x08\t)\x03\x01\r)\x03\r\r)\x03\t\r)\x03\x05\r)\x03\x01\x0f)\x03\t\x07)\x05\t\x05\x07)\x03\x05\x0f)\x05\t\x11\x07)\x03\t\x0f)\x03\r\x0f\x04~\x03\x05\x01\x11\x05\x19\x07\x03\x01\x05\t\x11\x05!\x05\x03Ak\x03\x05\x05\x03\x03\x01\x13\x03\x03\x03\x03\x01\x13\x03\x03\x03\x03\x011\x03\x03\x03\x03\x01\x15\x03\x03\x03\x03\x01\x15\x03\x03\x03\x03\x013\x03\x03\x0b\x07\x015\x0f\x05\x0b\x05\x05\x15\x1d\x1f\x0f\x03\x05\x07\t\x0b\r\x01\x03\x03\x01G\x03\x03\x05\x07\x01\x07\x03\x15\x03\x1d\r\x07\x01I\x03+\x05\x17\x1f\x05\x07\x01\x0b\x03-\x03!\x03\x03\x01\r\x03\x11\x05\x07\x01\x07\x03\x0b\x03%\x05\x07\x01O\x031\x03#\x07\x06\x01\x03\x0b\x07)\x11'\x05\x07\x01\x0b\x03\x17\x03!\x03\x03\x01\r\x03\x11\x05\x07\x01\x07\x03\x05\x03/\x05\x07\x01\x17\x03\x19\x03-\x07\x06\x01\x03\x05\x073\x131\x05\x07\x01\x0b\x03\x17\x03!\x03\x03\x01\r\x03\x11\x05\x07\x01\x07\x03\x05\x039\x05\x07\x01\x17\x03\x19\x037\x07\x06\x01\x03\x05\x07=\x15;\x0f\x04\x05\x075+?\x06\x03\x01\x05\x01\x00\xbe\nY\x1d\x03\x0f\x0b\t\t\t\x1b\x1d\r\x1b!+\x1b\x1f/!!)#\x1f\x19\x97y\x1f\x15\x1d\x15\x13%)\x13+\r\x15\x17\x1f\x11\x15)\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00broadcast_in_dim_v1\x00select_v1\x00func_v1\x00custom_call_v1\x00compare_v1\x00return_v1\x00value\x00broadcast_dimensions\x00sym_name\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00jit(func)/jit(main)/svd[full_matrices=True 
compute_uv=True]\x00/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00compare_type\x00comparison_direction\x00jax.result_info\x00jax.arg_info\x00input\x00mhlo.sharding\x00{replicated}\x00[0]\x00[1]\x00[2]\x00main\x00public\x00\x00lapack_dgesdd\x00", - xla_call_module_version=6, -) # End paste - - -# Pasted from the test output (see back_compat_test.py module docstring) -data_2023_06_19["c64"] = dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['lapack_cgesdd'], - serialized_date=datetime.date(2023, 6, 22), - inputs=(array([[[ 1.6052934 +0.45878917j, 4.587192 -4.5177283j , - 0.4177733 -1.9419309j , -2.2248359 -4.5042715j ], - [-7.083374 -8.127356j , 2.7596245 -4.991001j , - -0.52622825+5.033981j , -0.35441273-1.8215327j ], - [-0.7996552 -2.4052901j , -0.8506142 -3.164714j , - -0.3090829 +2.2020447j , 1.2367196 +2.8830793j ], - [ 1.4633094 -0.5451007j , -3.7833478 +6.6770763j , - -3.1279542 -2.2322626j , -2.1099617 -2.9661314j ]], - - [[ 1.2560439 -5.4743752j , -2.0085676 +2.0063214j , - -0.8132642 -3.4407883j , -0.17360081+0.6419895j ], - [ 2.3756726 +6.3315964j , -0.31447247-1.9387872j , - 4.6732006 -4.286903j , 1.7702469 -1.4957623j ], - [ 1.6918924 -0.52161306j, 0.49963537+4.7751374j , - -1.9243752 -4.5870543j , 2.8829405 +1.7382988j ], - [ 1.4884951 -0.44194785j, -1.3645276 -2.8733373j , - -0.39430943+2.4366508j , -0.76268387+5.2014065j ]]], - dtype=complex64),), - expected_outputs=(array([[[ 0.016725361+0.19210356j , 0.5452691 +0.5572638j , - 0.41363996 +0.18964858j , -0.26152334 -0.28195143j ], - [ 0.53678626 +0.64057267j , -0.21783225 -0.21288812j , - 0.28426644 +0.30535883j , 0.15201284 +0.10768581j ], - [ 0.21286921 +0.154735j , 0.066471666-0.25652882j , - -0.4074613 -0.10356682j , -0.11794163 -0.81844836j ], - [-0.39079374 -0.20583564j , -0.18335931 -0.4421772j , - 0.63489586 +0.19758748j , 0.038680226-0.36351213j ]], - - [[-0.3178596 +0.39032036j , -0.1273337 -0.30841744j , - 0.26394194 +0.26815224j , -0.21332254 -0.66947937j ], - [-0.39241245 -0.60790956j , -0.14006221 +0.41040683j , - -0.0830612 -0.10184447j , -0.45091942 -0.2603987j ], - [-0.36103728 +0.2876153j , -0.4965461 +0.10084368j , - -0.13752826 -0.6203828j , 0.35439825 -0.028546419j], - [ 0.062335093-0.078214265j, 0.35014474 -0.5668197j , - -0.42214075 -0.5090833j , -0.2889288 -0.15894148j ]]], - dtype=complex64), array([[15.135655 , 9.373035 , 7.444931 , 0.41523397], - [12.316969 , 8.661011 , 5.005059 , 2.115905 ]], - dtype=float32), array([[[-0.6537865 +0.j , -0.20306697 -0.6166746j , - 0.29948467 +0.24257992j , -0.007604365+0.04945353j ], - [ 0.52712685 +0.j , -0.11291563 -0.7116954j , - -0.089219 -0.36348897j , -0.23654723 -0.08269388j ], - [-0.31538543 +0.j , -0.014410622+0.15958191j , - -0.17958623 -0.13690898j , -0.6930434 -0.58613425j ], - [-0.44185135 +0.j , 0.17604677 -0.050492246j, - -0.4213856 -0.69485146j , 0.22373371 +0.2465445j ]], - - [[-0.64551586 +0.j , 0.32932255 -0.11672116j , - -0.093527466+0.6710145j , -0.038554154+0.02716677j ], - [ 0.4241116 +0.j , 0.031135002-0.539813j , - -0.26271763 +0.22760014j , -0.63609654 -0.04817467j ], - [-0.4577485 +0.j , -0.15202768 +0.2734652j , - 0.18931003 -0.3297506j , -0.7331101 -0.10269702j ], - [ 0.44034657 +0.j , 0.29474002 +0.63307834j , - 0.31271848 +0.4216674j , -0.20595454 -0.020532424j]]], - dtype=complex64)), - mlir_module_text=r""" -#loc = 
loc(unknown) -module @jit_func attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main(%arg0: tensor<2x4x4xcomplex> {jax.arg_info = "input", mhlo.sharding = "{replicated}"} loc(unknown)) -> (tensor<2x4x4xcomplex> {jax.result_info = "[0]"}, tensor<2x4xf32> {jax.result_info = "[1]"}, tensor<2x4x4xcomplex> {jax.result_info = "[2]"}) { - %0 = stablehlo.constant dense<1> : tensor loc(#loc2) - %1 = stablehlo.constant dense<1> : tensor loc(#loc2) - %2 = stablehlo.constant dense<2> : tensor loc(#loc2) - %3 = stablehlo.constant dense<4> : tensor loc(#loc2) - %4 = stablehlo.constant dense<4> : tensor loc(#loc2) - %5 = stablehlo.constant dense<264> : tensor loc(#loc2) - %6:8 = stablehlo.custom_call @lapack_cgesdd(%0, %1, %2, %3, %4, %5, %arg0) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[1, 2, 0]> : tensor<3xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 0]> : tensor<2xindex>, dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 2, 0]> : tensor<3xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>]} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor<2x4x4xcomplex>) -> (tensor<2x4x4xcomplex>, tensor<2x4xf32>, tensor<2x4x4xcomplex>, tensor<2x4x4xcomplex>, tensor<2xi32>, tensor<32xi32>, tensor<100xf32>, tensor<264xcomplex>) loc(#loc2) - %7 = stablehlo.constant dense<0> : tensor loc(#loc2) - %8 = stablehlo.broadcast_in_dim %7, dims = [] : (tensor) -> tensor<2xi32> loc(#loc2) - %9 = stablehlo.compare EQ, %6#4, %8, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc2) - %10 = stablehlo.broadcast_in_dim %9, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc2) - %11 = stablehlo.constant dense<0x7FC00000> : tensor loc(#loc2) - %12 = stablehlo.broadcast_in_dim %11, dims = [] : (tensor) -> tensor<2x4xf32> loc(#loc2) - %13 = stablehlo.broadcast_in_dim %10, dims = [0, 1] : (tensor<2x1xi1>) -> tensor<2x4xi1> loc(#loc2) - %14 = stablehlo.select %13, %6#1, %12 : tensor<2x4xi1>, tensor<2x4xf32> loc(#loc2) - %15 = stablehlo.broadcast_in_dim %9, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> loc(#loc2) - %16 = stablehlo.constant dense<(0x7FC00000,0x7FC00000)> : tensor> loc(#loc2) - %17 = stablehlo.broadcast_in_dim %16, dims = [] : (tensor>) -> tensor<2x4x4xcomplex> loc(#loc2) - %18 = stablehlo.broadcast_in_dim %15, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x4x4xi1> loc(#loc2) - %19 = stablehlo.select %18, %6#2, %17 : tensor<2x4x4xi1>, tensor<2x4x4xcomplex> loc(#loc2) - %20 = stablehlo.broadcast_in_dim %9, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> loc(#loc2) - %21 = stablehlo.constant dense<(0x7FC00000,0x7FC00000)> : tensor> loc(#loc2) - %22 = stablehlo.broadcast_in_dim %21, dims = [] : (tensor>) -> tensor<2x4x4xcomplex> loc(#loc2) - %23 = stablehlo.broadcast_in_dim %20, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x4x4xi1> loc(#loc2) - %24 = stablehlo.select %23, %6#3, %22 : tensor<2x4x4xi1>, tensor<2x4x4xcomplex> loc(#loc2) - return %19, %14, %24 : tensor<2x4x4xcomplex>, tensor<2x4xf32>, tensor<2x4x4xcomplex> loc(#loc) - } loc(#loc) -} loc(#loc) -#loc1 = loc("/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py":355:0) -#loc2 = loc("jit(func)/jit(main)/svd[full_matrices=True compute_uv=True]"(#loc1)) -""", 
- mlir_module_serialized=b'ML\xefR\x01StableHLO_v0.9.0\x00\x01\x1f\x05\x01\x03\x01\x03\x05\x03\x0f\x07\t\x0b\r\x0f\x11\x13\x03\xf9\xa9=\x01S\x0f\x0b\x07\x13\x0b\x13\x0f\x0b\x13\x13\x13\x13#\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x0b\x17\x0b\x13\x13K\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x1b\x0b\x0b\x13\x13\x03W\x0fo/\x0b\x0f\x1b\x0b\x0b\x0b\x0b\x0b\x17\x13\x0b\x13\x0b\x13\x0b\x0b\x0b\x1f\x1f\x1f\x1f\x0b\x0b\x0b\x0b\x0b\'\x0f\x17+O\x1f\x0f\x0b\x0b/\x1fO/o\x01\x03\x0f\x03;\x0f\x1b\x07\x07\x17\x07\x07\x0b\x07\x0f\x13\x0f\x1b\x1b\x1f\x13\x17\x17\x13\x13\x13\x13\x13\x13\x17\x13\x17\x13\x13\x02\x1e\x08\x1d+-\x05\x15\x1f\x03\x03\t\x99\x05\x17\x03\x03\t\x9f\x11\x01\x05\x05\x19\x03\x03\x03{\x03\x03\x03\x7f\x03\x03\x03\xa5\x03\x03\t\xa7\x03\x07\x1b\r\x1d\r\x0f\x1f\x05\x1b\x05\x1d\x05\x1f\x03\x0b#[%g\'i\x0fw)y\x05!\x05#\x05%\x05\'\x05)\x17/\x8e\x05\x01\x05+\x03\x03\x03}\x03\x03\x03\x81\x03\x117\x839\x85;\x87=\x89?\x8bA\x8dC\x8fE\x93\x05-\x05/\x051\x053\x055\x057\x059\x05;\x03\x03\x03\x97\x03\x05K\x9bM\x9d\x05=\x05?\x03\x03\x03\xa1\x03\x03\t\xa3\x1f\'\x01\x1f)1\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f-\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1dA\x03\x03]\r\x05_ace\x1dC\x1dE\x1dG\x1dI#\x1f\x03\x07kos\r\x03Ym\x1dK\r\x03Yq\x1dM\r\x03Yu\x1dO\x1dQ\x1dS\x1f\x03\t\x01\x00\x00\x00\x1f\x03\t\x02\x00\x00\x00\x1f\x03\t\x04\x00\x00\x00\x1f\x03\t\x08\x01\x00\x00\x0b\x05\x1dU\x1dW\x03\x01\x05\x01\x03\x0fSSSSSSU\x03\x03\x91\x15\x03\x01\x19\x01\x03\x11U\x95UUWWWW\x1f+!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x03\t\x00\x00\x00\x00\x1f/\x01\t\x07\x07\x01\x1f5\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x19\t\x00\x00\xc0\x7f\x1f9!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f\x15\x11\x00\x00\xc0\x7f\x00\x00\xc0\x7f\x1f;1\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x01\x02\x02)\x01\x13)\x07\t\x11\x11\x11\x01\t)\x05\t\x11\t\x13\x1d\x03\t\x1b)\x01\x11)\x03\t\x13)\x01\t)\x07\t\x05\x05\x07)\x07\t\x11\x11\x07\x11\x03\x05\x07\x05\x0b\x05)\x03\x81\x13)\x03"\x03\t)\x03B\x08\x11)\x03\x01\r)\x03\r\r)\x03\t\r)\x03\x05\r)\x03\x01\x0f)\x03\t\x07)\x05\t\x05\x07)\x03\x05\x0f)\x05\t\x11\x07)\x03\t\x0f)\x03\r\x0f\x04\x82\x03\x05\x01\x11\x05\x19\x07\x03\x01\x05\t\x11\x05!\x05\x03Ck\x03\x05\x05\x03\x03\x01\x11\x03\x03\x03\x03\x01\x11\x03\x03\x03\x03\x011\x03\x03\x03\x03\x01\x13\x03\x03\x03\x03\x01\x13\x03\x03\x03\x03\x013\x03\x03\x0b\x07\x015\x11\x05\x0b\x05\x05\x17!#%\x0f\x03\x05\x07\t\x0b\r\x01\x03\x03\x01G\x03\x03\x05\x07\x01\x07\x03\x17\x03\x1f\r\x07\x01I\x031\x05\x17!\x05\x07\x01\x0b\x033\x03#\x03\x03\x01O\x03\x19\x05\x07\x01\x07\x03\x0b\x03\'\x05\x07\x01Q\x037\x03%\x07\x06\x01\x03\x0b\x07+\x11)\x05\x07\x01\x0b\x03\x1b\x03#\x03\x03\x01\x15\x03\x15\x05\x07\x01\x07\x03\x05\x031\x05\x07\x01\x17\x03\x1d\x03/\x07\x06\x01\x03\x05\x075\x133\x05\x07\x01\x0b\x03\x1b\x03#\x03\x03\x01\x15\x03\x15\x05\x07\x01\x07\x03\x05\x03;\x05\x07\x01\x17\x03\x1d\x039\x07\x06\x01\x03\x05\x07?\x15=\x0f\x04\x05\x077-A\x06\x03\x01\x05\x01\x00\xbe\nY\x1d\x03\x0f\x0b\t\t\t\x1b\x1d\r\x1b!+\x1b\x1f/!!)#\x1f\x19\x97y\x1f\x15\x1d\x15\x13%)\x13+\r\x15\x17\x1f\x11\x15)\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00broadcast_in_dim_v1\x00select_v1\x00func_v1\x00custom_call_v1\x00compare_v1\x00return_v1\x00value\x00broadcast_dimensions\x00sym_name\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00jit(func)/jit(main)/svd[full_matrices=True 
compute_uv=True]\x00/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00compare_type\x00comparison_direction\x00jax.result_info\x00jax.arg_info\x00input\x00mhlo.sharding\x00{replicated}\x00[0]\x00[1]\x00[2]\x00main\x00public\x00\x00lapack_cgesdd\x00', - xla_call_module_version=6, -) # End paste - - -# Pasted from the test output (see back_compat_test.py module docstring) -data_2023_06_19["c128"] = dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['lapack_zgesdd'], - serialized_date=datetime.date(2023, 6, 22), - inputs=(array([[[-0.9247611722912019-1.3615157109291343j , - -1.0663457975211892+4.73170030936092j , - -1.4918732811689488-2.880861991859318j , - -1.111356346434667 -2.869701609083459j ], - [-4.71291623424314 -1.5444012898828912j , - -5.232967549101415 -0.41287816948482003j, - 0.8905737109262459+9.50245186328329j , - 4.397722119094926 -6.842005210371916j ], - [ 1.9369405063276903+2.3496014107398917j , - -1.5609345742256133+4.2102103739897805j , - 0.6596030248996742+5.195353435247212j , - 0.6315014498240328-1.2778849649354402j ], - [ 5.115159214503849 -0.8856276268773485j , - 1.3719934567460779-2.236070491368575j , - 0.4974504006612811-3.0462081956756637j , - -0.2620346712025989+4.424682727912594j ]], - - [[-1.8242711798401063-0.8543252170262536j , - -2.724527211360488 +2.256038331706666j , - -1.2777487543905157+0.976556823566376j , - 3.7438974536713223-0.4994301527847589j ], - [-0.6359051102028691+2.730662301129662j , - -1.2877728943263032+3.9124921723649053j , - -3.4618573226579894+1.7835551986994034j , - -1.4710491660152465+2.144967500163963j ], - [-3.6013691182532828+2.8182351980619034j , - 2.0045935428878803+1.1146211993017152j , - -2.332213857689336 -0.874915651404938j , - -1.5393862406530452+0.6852883119580928j ], - [-2.674897392856801 +2.0724239502976984j , - -3.349108041292141 -1.0215359152295307j , - 0.2603515088197114-1.9093411474619364j , - 5.41252457188561 +8.634368042893094j ]]]),), - expected_outputs=(array([[[-0.04173678258633362+0.10796693731538423j , - 0.6813428383170976 +0.34327979589293334j , - -0.41770229002865755+0.20028957850808823j , - -0.43443513665085287+0.034743251442636465j], - [-0.8408468609573512 -0.1326064604464803j , - -0.21674151028481228+0.015170556885426551j, - 0.17147327711152338+0.1531041615298256j , - -0.3568765623609291 +0.21904384306708768j ], - [-0.2673618144044136 +0.1379833616281103j , - -0.17534278352558025-0.378992615769627j , - -0.8179957069096054 -0.037506032257391624j, - 0.25392637883428526-0.009771014463849802j], - [ 0.40569239968065934-0.08297706578106905j , - -0.4321527034953765 +0.09791545663574397j , - -0.23439193826962654-0.08427130532228161j , - -0.42348296145608866+0.6251448114949291j ]], - - [[ 0.0272684373986653 +0.36312055550335454j , - 0.270297713559288 +0.1304616587162563j , - 0.04286867013923673-0.4765859417602139j , - 0.7242702256119968 +0.15420620503522459j ], - [-0.08593436615104483+0.1189990183325552j , - 0.37050286109355285-0.6240865462984536j , - 0.46902056878806025-0.34747949920770266j , - -0.31667671459632074-0.10340064369932994j ], - [-0.07914843440873574-0.033487314943774035j, - 0.4110353453489128 -0.455090805566563j , - -0.431131803930273 +0.40910871949632j , - 0.13782730102420274+0.49428280062680086j ], - [-0.7478497242333215 +0.5283836938016964j , - -0.08345894989956631+0.011807690067190268j, - 
-0.27178304569905287+0.056526279406748176j, - -0.09911954913441999-0.2598859654000683j ]]]), array([[16.80132997488892 , 7.744755614558116 , 5.831221808032041 , - 1.1195288361137765], - [12.39537594694893 , 8.218551160453814 , 4.683634850274079 , - 1.8820915363839188]]), array([[[ 0.35796251040556704 +0.j , - 0.40179383774178046 -0.1269359716702074j , - -0.0751486661300563 -0.6109813931761136j , - -0.23049271148274278 +0.51209309438597j ], - [-0.4682861415308549 +0.j , - -0.013958972669495105+0.4210606476774211j , - -0.6006888466394119 -0.3766516564723718j , - -0.24264518623237025 -0.20408557153193485j ], - [-0.6392945524816095 +0.j , - 0.2432388607602898 -0.6679928485374246j , - 0.18168178910997038 -0.08126854868489754j , - -0.2030612067046724 -0.07124733621915219j ], - [-0.49383540371426055 +0.j , - -0.010402968929686592+0.3734624991410737j , - 0.27994282704104956 +0.01949406216762731j , - 0.32588905219319236 +0.6569569657140543j ]], - - [[ 0.2666920370516844 +0.j , - 0.24929033811571413 +0.27271089049933883j , - -0.012922512768026735+0.16383354123801513j , - 0.07388201893235022 -0.8717175469187741j ], - [-0.6156140469162428 +0.j , - -0.33787077397020143 +0.37797154650923376j , - -0.3916043058726119 -0.2839601305776179j , - -0.2714888604157674 -0.23729034093304682j ], - [ 0.5618758038857617 +0.j , - -0.5788776267734554 -0.13833058883452312j , - -0.48995086206819644 +0.19259594116096765j , - -0.22967101640965012 -0.012926826751577613j], - [-0.48393210641613593 +0.j , - -0.1049229605428438 -0.4911419972025977j , - -0.07782239226461217 +0.6751317817750165j , - 0.11941657609231515 -0.19354808489959852j ]]])), - mlir_module_text=r""" -#loc = loc(unknown) -module @jit_func attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main(%arg0: tensor<2x4x4xcomplex> {jax.arg_info = "input", mhlo.sharding = "{replicated}"} loc(unknown)) -> (tensor<2x4x4xcomplex> {jax.result_info = "[0]"}, tensor<2x4xf64> {jax.result_info = "[1]"}, tensor<2x4x4xcomplex> {jax.result_info = "[2]"}) { - %0 = stablehlo.constant dense<1> : tensor loc(#loc2) - %1 = stablehlo.constant dense<1> : tensor loc(#loc2) - %2 = stablehlo.constant dense<2> : tensor loc(#loc2) - %3 = stablehlo.constant dense<4> : tensor loc(#loc2) - %4 = stablehlo.constant dense<4> : tensor loc(#loc2) - %5 = stablehlo.constant dense<264> : tensor loc(#loc2) - %6:8 = stablehlo.custom_call @lapack_zgesdd(%0, %1, %2, %3, %4, %5, %arg0) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[1, 2, 0]> : tensor<3xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 0]> : tensor<2xindex>, dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 2, 0]> : tensor<3xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>]} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor<2x4x4xcomplex>) -> (tensor<2x4x4xcomplex>, tensor<2x4xf64>, tensor<2x4x4xcomplex>, tensor<2x4x4xcomplex>, tensor<2xi32>, tensor<32xi32>, tensor<100xf64>, tensor<264xcomplex>) loc(#loc2) - %7 = stablehlo.constant dense<0> : tensor loc(#loc2) - %8 = stablehlo.broadcast_in_dim %7, dims = [] : (tensor) -> tensor<2xi32> loc(#loc2) - %9 = stablehlo.compare EQ, %6#4, %8, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc2) - %10 = stablehlo.broadcast_in_dim 
%9, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc2) - %11 = stablehlo.constant dense<0x7FF8000000000000> : tensor loc(#loc2) - %12 = stablehlo.broadcast_in_dim %11, dims = [] : (tensor) -> tensor<2x4xf64> loc(#loc2) - %13 = stablehlo.broadcast_in_dim %10, dims = [0, 1] : (tensor<2x1xi1>) -> tensor<2x4xi1> loc(#loc2) - %14 = stablehlo.select %13, %6#1, %12 : tensor<2x4xi1>, tensor<2x4xf64> loc(#loc2) - %15 = stablehlo.broadcast_in_dim %9, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> loc(#loc2) - %16 = stablehlo.constant dense<(0x7FF8000000000000,0x7FF8000000000000)> : tensor> loc(#loc2) - %17 = stablehlo.broadcast_in_dim %16, dims = [] : (tensor>) -> tensor<2x4x4xcomplex> loc(#loc2) - %18 = stablehlo.broadcast_in_dim %15, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x4x4xi1> loc(#loc2) - %19 = stablehlo.select %18, %6#2, %17 : tensor<2x4x4xi1>, tensor<2x4x4xcomplex> loc(#loc2) - %20 = stablehlo.broadcast_in_dim %9, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> loc(#loc2) - %21 = stablehlo.constant dense<(0x7FF8000000000000,0x7FF8000000000000)> : tensor> loc(#loc2) - %22 = stablehlo.broadcast_in_dim %21, dims = [] : (tensor>) -> tensor<2x4x4xcomplex> loc(#loc2) - %23 = stablehlo.broadcast_in_dim %20, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x4x4xi1> loc(#loc2) - %24 = stablehlo.select %23, %6#3, %22 : tensor<2x4x4xi1>, tensor<2x4x4xcomplex> loc(#loc2) - return %19, %14, %24 : tensor<2x4x4xcomplex>, tensor<2x4xf64>, tensor<2x4x4xcomplex> loc(#loc) - } loc(#loc) -} loc(#loc) -#loc1 = loc("/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py":355:0) -#loc2 = loc("jit(func)/jit(main)/svd[full_matrices=True compute_uv=True]"(#loc1)) -""", - mlir_module_serialized=b'ML\xefR\x01StableHLO_v0.9.0\x00\x01\x1f\x05\x01\x03\x01\x03\x05\x03\x0f\x07\t\x0b\r\x0f\x11\x13\x03\xf9\xa9=\x01S\x0f\x0b\x07\x13\x0b\x13\x0f\x0b\x13\x13\x13\x13#\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x0b\x17\x0b\x13\x13K\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x1b\x0b\x0b\x13\x13\x03W\x0fo/\x0b\x0f\x1b\x0b\x0b\x0b\x0b\x0b\x17\x13\x0b\x13\x0b\x13\x0b\x0b\x0b\x1f\x1f\x1f\x1f\x0b\x0b\x0b\x0b\x0b\'\x0f\x17+O\x1f\x0f\x0b\x0b//OOo\x01\x03\x0f\x03;\x0f\x1b\x07\x07\x17\x07\x07\x0b\x07\x0f\x13\x0f\x1b\x1b\x1f\x13\x17\x17\x13\x13\x13\x13\x13\x13\x17\x13\x17\x13\x13\x02N\x08\x1d+-\x05\x15\x1f\x03\x03\t\x99\x05\x17\x03\x03\t\x9f\x11\x01\x05\x05\x19\x03\x03\x03{\x03\x03\x03\x7f\x03\x03\x03\xa5\x03\x03\t\xa7\x03\x07\x1b\r\x1d\r\x0f\x1f\x05\x1b\x05\x1d\x05\x1f\x03\x0b#[%g\'i\x0fw)y\x05!\x05#\x05%\x05\'\x05)\x17/\x8e\x05\x01\x05+\x03\x03\x03}\x03\x03\x03\x81\x03\x117\x839\x85;\x87=\x89?\x8bA\x8dC\x8fE\x93\x05-\x05/\x051\x053\x055\x057\x059\x05;\x03\x03\x03\x97\x03\x05K\x9bM\x9d\x05=\x05?\x03\x03\x03\xa1\x03\x03\t\xa3\x1f\'\x01\x1f)1\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f-\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1dA\x03\x03]\r\x05_ace\x1dC\x1dE\x1dG\x1dI#\x1f\x03\x07kos\r\x03Ym\x1dK\r\x03Yq\x1dM\r\x03Yu\x1dO\x1dQ\x1dS\x1f\x03\t\x01\x00\x00\x00\x1f\x03\t\x02\x00\x00\x00\x1f\x03\t\x04\x00\x00\x00\x1f\x03\t\x08\x01\x00\x00\x0b\x05\x1dU\x1dW\x03\x01\x05\x01\x03\x0fSSSSSSU\x03\x03\x91\x15\x03\x01\x19\x01\x03\x11U\x95UUWWWW\x1f+!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x03\t\x00\x00\x00\x00\x1f/\x01\t\x07\x07\x01\x1f5\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x19\x11\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f9!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f\x15!\x00\x00\x00\x00\x00\x00\xf8\x7f\x00\x00\x00\x00\x00\x00\xf8\x7f\x1
f;1\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x01\x02\x02)\x01\x13)\x07\t\x11\x11\x11\x01\x0b)\x05\t\x11\t\x13\x1d\x03\t\x1b)\x01\x11)\x03\t\x13)\x01\t)\x07\t\x05\x05\x07)\x07\t\x11\x11\x07\x11\x03\x05\x07\x05\x0b\x05)\x03\x81\x13)\x03"\x03\t)\x03B\x08\x11)\x03\x01\r)\x03\r\r)\x03\t\r)\x03\x05\r)\x03\x01\x0f)\x03\t\x07)\x05\t\x05\x07)\x03\x05\x0f)\x05\t\x11\x07)\x03\t\x0f)\x03\r\x0f\x04\x82\x03\x05\x01\x11\x05\x19\x07\x03\x01\x05\t\x11\x05!\x05\x03Ck\x03\x05\x05\x03\x03\x01\x11\x03\x03\x03\x03\x01\x11\x03\x03\x03\x03\x011\x03\x03\x03\x03\x01\x13\x03\x03\x03\x03\x01\x13\x03\x03\x03\x03\x013\x03\x03\x0b\x07\x015\x11\x05\x0b\x05\x05\x17!#%\x0f\x03\x05\x07\t\x0b\r\x01\x03\x03\x01G\x03\x03\x05\x07\x01\x07\x03\x17\x03\x1f\r\x07\x01I\x031\x05\x17!\x05\x07\x01\x0b\x033\x03#\x03\x03\x01O\x03\x19\x05\x07\x01\x07\x03\x0b\x03\'\x05\x07\x01Q\x037\x03%\x07\x06\x01\x03\x0b\x07+\x11)\x05\x07\x01\x0b\x03\x1b\x03#\x03\x03\x01\x15\x03\x15\x05\x07\x01\x07\x03\x05\x031\x05\x07\x01\x17\x03\x1d\x03/\x07\x06\x01\x03\x05\x075\x133\x05\x07\x01\x0b\x03\x1b\x03#\x03\x03\x01\x15\x03\x15\x05\x07\x01\x07\x03\x05\x03;\x05\x07\x01\x17\x03\x1d\x039\x07\x06\x01\x03\x05\x07?\x15=\x0f\x04\x05\x077-A\x06\x03\x01\x05\x01\x00\xbe\nY\x1d\x03\x0f\x0b\t\t\t\x1b\x1d\r\x1b!+\x1b\x1f/!!)#\x1f\x19\x97y\x1f\x15\x1d\x15\x13%)\x13+\r\x15\x17\x1f\x11\x15)\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00broadcast_in_dim_v1\x00select_v1\x00func_v1\x00custom_call_v1\x00compare_v1\x00return_v1\x00value\x00broadcast_dimensions\x00sym_name\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00jit(func)/jit(main)/svd[full_matrices=True compute_uv=True]\x00/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00compare_type\x00comparison_direction\x00jax.result_info\x00jax.arg_info\x00input\x00mhlo.sharding\x00{replicated}\x00[0]\x00[1]\x00[2]\x00main\x00public\x00\x00lapack_zgesdd\x00', - xla_call_module_version=6, -) # End paste - data_2024_08_13 = {} - # Pasted from the test output (see export_back_compat_test_util.py module docstring) data_2024_08_13["c128"] = dict( testdata_version=1, diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/cpu_triangular_solve_blas_trsm.py b/jax/_src/internal_test_util/export_back_compat_test_data/cpu_triangular_solve_blas_trsm.py index c401ca041bfb..f3640c6114c7 100644 --- a/jax/_src/internal_test_util/export_back_compat_test_data/cpu_triangular_solve_blas_trsm.py +++ b/jax/_src/internal_test_util/export_back_compat_test_data/cpu_triangular_solve_blas_trsm.py @@ -17,196 +17,8 @@ import datetime from numpy import array, float32, complex64 -data_2023_07_16 = {} -# Pasted from the test output (see back_compat_test_util.py module docstring) -data_2023_07_16["f32"] = dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['blas_strsm'], - serialized_date=datetime.date(2023, 7, 16), - inputs=(array([[ 5., 0., 0., 0.], - [ 4., 10., 0., 0.], - [ 8., 9., 15., 0.], - [12., 13., 14., 20.]], dtype=float32), array([[ 0., 1., 2., 3., 4.], - [ 5., 6., 7., 8., 9.], - [10., 11., 12., 13., 14.], - [15., 16., 17., 18., 19.]], dtype=float32)), - expected_outputs=(array([[ 0. 
, 0.2 , 0.4 , 0.6 , - 0.8 ], - [ 0.5 , 0.52 , 0.54 , 0.56 , - 0.58000004 ], - [ 0.36666667 , 0.31466666 , 0.26266667 , 0.21066667 , - 0.15866666 ], - [ 0.16833334 , 0.12173338 , 0.0751333 , 0.02853328 , - -0.018066704]], dtype=float32),), - mlir_module_text=r""" -#loc = loc(unknown) -module @jit_func attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main(%arg0: tensor<4x4xf32> {jax.arg_info = "a", mhlo.sharding = "{replicated}"} loc(unknown), %arg1: tensor<4x5xf32> {jax.arg_info = "b", mhlo.sharding = "{replicated}"} loc(unknown)) -> (tensor<4x5xf32> {jax.result_info = ""}) { - %0 = stablehlo.constant dense<1.000000e+00> : tensor loc(#loc2) - %1 = stablehlo.constant dense<1> : tensor loc(#loc2) - %2 = stablehlo.constant dense<1> : tensor loc(#loc2) - %3 = stablehlo.constant dense<0> : tensor loc(#loc2) - %4 = stablehlo.constant dense<0> : tensor loc(#loc2) - %5 = stablehlo.constant dense<4> : tensor loc(#loc2) - %6 = stablehlo.constant dense<5> : tensor loc(#loc2) - %7 = stablehlo.constant dense<1> : tensor loc(#loc2) - %8 = stablehlo.custom_call @blas_strsm(%1, %2, %3, %4, %5, %6, %7, %0, %arg0, %arg1) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>, dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>]} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor<4x4xf32>, tensor<4x5xf32>) -> tensor<4x5xf32> loc(#loc2) - return %8 : tensor<4x5xf32> loc(#loc) - } loc(#loc) -} loc(#loc) -#loc1 = loc("/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py":508:0) -#loc2 = loc("jit(func)/jit(main)/triangular_solve[left_side=True lower=True transpose_a=False conjugate_a=False unit_diagonal=False]"(#loc1)) -""", - 
mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x01\x19\x05\x01\x03\x01\x03\x05\x03\t\x07\t\x0b\r\x03\xa5{\x17\x01?\x0f\x07\x0b\x13\x0f\x0b\x13#\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x13\x0b\x17\x0b\x13\x13K\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x03=\x0fO\x0b\x0b\x0b\x0b\x13\x1b\x0b\x1b\x0b\x0b\x0f\x13\x0b\x0b\x0b\x1f\x1f\x1f\x1f\x1f\x0b\x0b\x0b\x0b3\x0f\x13\x0f\x01\x03\x0f\x03\x15\x0f\x17\x07\x17\x0f\x07\x1b\x07\x13\x13\x02J\x04\x1d#%\x1f\x05\x0f\x03\x03\x05c\x11\x01\x05\x05\x11\x03\x03\x05e\x03\x07\x11\t\x13\t\x0b\x15\x05\x13\x05\x15\x05\x17\x03\x0b\x19K\x1bU\x1dW\x0b]\x1f_\x05\x19\x05\x1b\x05\x1d\x05\x1f\x03\x03\x05a\x05!\x17'\xf2\x07\x01\x05#\x03\x03\x05g\x03\x03\x05i\x03\x11/k1I3m5o7q9s;u=y\x05%\x05'\x05)\x05+\x05-\x05/\x051\x053\x1f\x13\x01\x1f\x15!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1d5\x1d7\x1d9\x1d;\x03\x05MQ\r\x05COEG\x1d=\r\x05CSEG\x1d?#\x0f\x03\x03Y\r\x03[I\x1dA\x1dC\x1dE\x1f\x0b\t\x00\x00\x80?\x1f\x03\t\x01\x00\x00\x00\x1f\x03\t\x00\x00\x00\x00\x1f\x03\t\x04\x00\x00\x00\x1f\x03\t\x05\x00\x00\x00\x0b\x05\x1dG\x03\x01\x05\x01\x03\x15????????AA\x03\x03w\x15\x01%\x01\x03\x03A\x01\x02\x02)\x01\x11)\x05\x11\x15\x07\t)\x05\x11\x11\x07)\x01\x07\x13\x11\x05\t\x05\x03\x05\x1b)\x03\x01\r)\x03\t\r\x04\xb9\x05\x01\x11\x03\x0f\x07\x03\x01\x05\x05\x11\x03\x17\x05\x03\x17+\x05\t\x03\x05\x03\x03\x03\x01!\x03\x0b\x03\x03\x01\x07\x03\x03\x03\x03\x01\x07\x03\x03\x03\x03\x01\r\x03\x03\x03\x03\x01\r\x03\x03\x03\x03\x01)\x03\x03\x03\x03\x01+\x03\x03\x03\x03\x01\x07\x03\x03\x07\x07\x01-\x03\x05\x15\x07\t\x0b\r\x0f\x11\x13\x05\x01\x03\t\x04\x03\x03\x15\x06\x03\x01\x05\x01\x00\xca\tI\x17\x0f\x0b!\x05\x05\x03\x1b\x1d\x1b\x1f/!!)#\x1f\x19\x97\xf1\x1f\x15\x1d\x15\x13%)\x13\r\x15\x1f\x11\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00func_v1\x00custom_call_v1\x00return_v1\x00value\x00sym_name\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00jit(func)/jit(main)/triangular_solve[left_side=True lower=True transpose_a=False conjugate_a=False unit_diagonal=False]\x00/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00jax.arg_info\x00mhlo.sharding\x00{replicated}\x00\x00a\x00b\x00jax.result_info\x00main\x00public\x00blas_strsm\x00", - xla_call_module_version=6, -) # End paste - -# Pasted from the test output (see back_compat_test.py module docstring) -data_2023_07_16["f64"] = dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['blas_dtrsm'], - serialized_date=datetime.date(2023, 7, 16), - inputs=(array([[ 5., 0., 0., 0.], - [ 4., 10., 0., 0.], - [ 8., 9., 15., 0.], - [12., 13., 14., 20.]]), array([[ 0., 1., 2., 3., 4.], - [ 5., 6., 7., 8., 9.], - [10., 11., 12., 13., 14.], - [15., 16., 17., 18., 19.]])), - expected_outputs=(array([[ 0. 
, 0.2 , - 0.4 , 0.6000000000000001 , - 0.8 ], - [ 0.5 , 0.52 , - 0.54 , 0.5599999999999999 , - 0.58 ], - [ 0.36666666666666664 , 0.3146666666666667 , - 0.2626666666666667 , 0.21066666666666667 , - 0.15866666666666665 ], - [ 0.16833333333333336 , 0.1217333333333333 , - 0.07513333333333323 , 0.0285333333333333 , - -0.018066666666666675]]),), - mlir_module_text=r""" -#loc = loc(unknown) -module @jit_func attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main(%arg0: tensor<4x4xf64> {jax.arg_info = "a", mhlo.sharding = "{replicated}"} loc(unknown), %arg1: tensor<4x5xf64> {jax.arg_info = "b", mhlo.sharding = "{replicated}"} loc(unknown)) -> (tensor<4x5xf64> {jax.result_info = ""}) { - %0 = stablehlo.constant dense<1.000000e+00> : tensor loc(#loc2) - %1 = stablehlo.constant dense<1> : tensor loc(#loc2) - %2 = stablehlo.constant dense<1> : tensor loc(#loc2) - %3 = stablehlo.constant dense<0> : tensor loc(#loc2) - %4 = stablehlo.constant dense<0> : tensor loc(#loc2) - %5 = stablehlo.constant dense<4> : tensor loc(#loc2) - %6 = stablehlo.constant dense<5> : tensor loc(#loc2) - %7 = stablehlo.constant dense<1> : tensor loc(#loc2) - %8 = stablehlo.custom_call @blas_dtrsm(%1, %2, %3, %4, %5, %6, %7, %0, %arg0, %arg1) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>, dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>]} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor<4x4xf64>, tensor<4x5xf64>) -> tensor<4x5xf64> loc(#loc2) - return %8 : tensor<4x5xf64> loc(#loc) - } loc(#loc) -} loc(#loc) -#loc1 = loc("/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py":511:0) -#loc2 = loc("jit(func)/jit(main)/triangular_solve[left_side=True lower=True transpose_a=False conjugate_a=False unit_diagonal=False]"(#loc1)) -""", - 
mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x01\x19\x05\x01\x03\x01\x03\x05\x03\t\x07\t\x0b\r\x03\xa5{\x17\x01?\x0f\x07\x0b\x13\x0f\x0b\x13#\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x13\x0b\x17\x0b\x13\x13K\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x03=\x0fO\x0b\x0b\x0b\x0b\x13\x1b\x0b\x1b\x0b\x0b\x0f\x13\x0b\x0b\x0b/\x1f\x1f\x1f\x1f\x0b\x0b\x0b\x0b3\x0f\x13\x0f\x01\x03\x0f\x03\x15\x0f\x17\x07\x17\x0f\x07\x1b\x07\x13\x13\x02Z\x04\x1d#%\x1f\x05\x0f\x03\x03\x05c\x11\x01\x05\x05\x11\x03\x03\x05e\x03\x07\x11\t\x13\t\x0b\x15\x05\x13\x05\x15\x05\x17\x03\x0b\x19K\x1bU\x1dW\x0b]\x1f_\x05\x19\x05\x1b\x05\x1d\x05\x1f\x03\x03\x05a\x05!\x17'\xfe\x07\x01\x05#\x03\x03\x05g\x03\x03\x05i\x03\x11/k1I3m5o7q9s;u=y\x05%\x05'\x05)\x05+\x05-\x05/\x051\x053\x1f\x13\x01\x1f\x15!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1d5\x1d7\x1d9\x1d;\x03\x05MQ\r\x05COEG\x1d=\r\x05CSEG\x1d?#\x0f\x03\x03Y\r\x03[I\x1dA\x1dC\x1dE\x1f\x0b\x11\x00\x00\x00\x00\x00\x00\xf0?\x1f\x03\t\x01\x00\x00\x00\x1f\x03\t\x00\x00\x00\x00\x1f\x03\t\x04\x00\x00\x00\x1f\x03\t\x05\x00\x00\x00\x0b\x05\x1dG\x03\x01\x05\x01\x03\x15????????AA\x03\x03w\x15\x01%\x01\x03\x03A\x01\x02\x02)\x01\x11)\x05\x11\x15\x07\x0b)\x05\x11\x11\x07)\x01\x07\x13\x11\x05\t\x05\x03\x05\x1b)\x03\x01\r)\x03\t\r\x04\xb9\x05\x01\x11\x03\x0f\x07\x03\x01\x05\x05\x11\x03\x17\x05\x03\x17+\x05\t\x03\x05\x03\x03\x03\x01!\x03\x0b\x03\x03\x01\x07\x03\x03\x03\x03\x01\x07\x03\x03\x03\x03\x01\r\x03\x03\x03\x03\x01\r\x03\x03\x03\x03\x01)\x03\x03\x03\x03\x01+\x03\x03\x03\x03\x01\x07\x03\x03\x07\x07\x01-\x03\x05\x15\x07\t\x0b\r\x0f\x11\x13\x05\x01\x03\t\x04\x03\x03\x15\x06\x03\x01\x05\x01\x00\xca\tI\x17\x0f\x0b!\x05\x05\x03\x1b\x1d\x1b\x1f/!!)#\x1f\x19\x97\xf1\x1f\x15\x1d\x15\x13%)\x13\r\x15\x1f\x11\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00func_v1\x00custom_call_v1\x00return_v1\x00value\x00sym_name\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00jit(func)/jit(main)/triangular_solve[left_side=True lower=True transpose_a=False conjugate_a=False unit_diagonal=False]\x00/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00jax.arg_info\x00mhlo.sharding\x00{replicated}\x00\x00a\x00b\x00jax.result_info\x00main\x00public\x00blas_dtrsm\x00", - xla_call_module_version=6, -) # End paste - - -# Pasted from the test output (see back_compat_test.py module docstring) -data_2023_07_16["c64"] = dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['blas_ctrsm'], - serialized_date=datetime.date(2023, 7, 16), - inputs=(array([[ 5.+0.j, 0.+0.j, 0.+0.j, 0.+0.j], - [ 4.+0.j, 10.+0.j, 0.+0.j, 0.+0.j], - [ 8.+0.j, 9.+0.j, 15.+0.j, 0.+0.j], - [12.+0.j, 13.+0.j, 14.+0.j, 20.+0.j]], dtype=complex64), array([[ 0.+0.j, 1.+0.j, 2.+0.j, 3.+0.j, 4.+0.j], - [ 5.+0.j, 6.+0.j, 7.+0.j, 8.+0.j, 9.+0.j], - [10.+0.j, 11.+0.j, 12.+0.j, 13.+0.j, 14.+0.j], - [15.+0.j, 16.+0.j, 17.+0.j, 18.+0.j, 19.+0.j]], dtype=complex64)), - expected_outputs=(array([[ 0. 
+0.j, 0.2 +0.j, 0.4 +0.j, - 0.6 +0.j, 0.8 +0.j], - [ 0.5 +0.j, 0.52 +0.j, 0.54 +0.j, - 0.56 +0.j, 0.58000004 +0.j], - [ 0.36666667 +0.j, 0.31466666 +0.j, 0.26266667 +0.j, - 0.21066667 +0.j, 0.15866666 +0.j], - [ 0.16833334 +0.j, 0.12173338 +0.j, 0.0751333 +0.j, - 0.02853328 +0.j, -0.018066704+0.j]], dtype=complex64),), - mlir_module_text=r""" -#loc = loc(unknown) -module @jit_func attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main(%arg0: tensor<4x4xcomplex> {jax.arg_info = "a", mhlo.sharding = "{replicated}"} loc(unknown), %arg1: tensor<4x5xcomplex> {jax.arg_info = "b", mhlo.sharding = "{replicated}"} loc(unknown)) -> (tensor<4x5xcomplex> {jax.result_info = ""}) { - %0 = stablehlo.constant dense<(1.000000e+00,0.000000e+00)> : tensor> loc(#loc2) - %1 = stablehlo.constant dense<1> : tensor loc(#loc2) - %2 = stablehlo.constant dense<1> : tensor loc(#loc2) - %3 = stablehlo.constant dense<0> : tensor loc(#loc2) - %4 = stablehlo.constant dense<0> : tensor loc(#loc2) - %5 = stablehlo.constant dense<4> : tensor loc(#loc2) - %6 = stablehlo.constant dense<5> : tensor loc(#loc2) - %7 = stablehlo.constant dense<1> : tensor loc(#loc2) - %8 = stablehlo.custom_call @blas_ctrsm(%1, %2, %3, %4, %5, %6, %7, %0, %arg0, %arg1) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>, dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>]} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor>, tensor<4x4xcomplex>, tensor<4x5xcomplex>) -> tensor<4x5xcomplex> loc(#loc2) - return %8 : tensor<4x5xcomplex> loc(#loc) - } loc(#loc) -} loc(#loc) -#loc1 = loc("/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py":510:0) -#loc2 = loc("jit(func)/jit(main)/triangular_solve[left_side=True lower=True transpose_a=False conjugate_a=False unit_diagonal=False]"(#loc1)) -""", - 
mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x01\x19\x05\x01\x03\x01\x03\x05\x03\t\x07\t\x0b\r\x03\xa7{\x19\x01?\x0f\x07\x0b\x13\x0f\x0b\x13#\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x13\x0b\x17\x0b\x13\x13K\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x03=\x0fO\x0b\x0b\x0b\x0b\x13\x1b\x0b\x1b\x0b\x0b\x0f\x13\x0b\x0b\x0b/\x1f\x1f\x1f\x1f\x0b\x0b\x0b\x0b3\x0f\x13\x0f\x01\x03\x0f\x03\x17\x0f\x17\x0b\x17\x0f\x07\x1b\x07\x07\x13\x13\x02b\x04\x1d#%\x1f\x05\x0f\x03\x03\x05c\x11\x01\x05\x05\x11\x03\x03\x05e\x03\x07\x11\t\x13\t\x0b\x15\x05\x13\x05\x15\x05\x17\x03\x0b\x19K\x1bU\x1dW\x0b]\x1f_\x05\x19\x05\x1b\x05\x1d\x05\x1f\x03\x03\x05a\x05!\x17'\xfa\x07\x01\x05#\x03\x03\x05g\x03\x03\x05i\x03\x11/k1I3m5o7q9s;u=y\x05%\x05'\x05)\x05+\x05-\x05/\x051\x053\x1f\x15\x01\x1f\x17!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1d5\x1d7\x1d9\x1d;\x03\x05MQ\r\x05COEG\x1d=\r\x05CSEG\x1d?#\x0f\x03\x03Y\r\x03[I\x1dA\x1dC\x1dE\x1f\x0b\x11\x00\x00\x80?\x00\x00\x00\x00\x1f\x03\t\x01\x00\x00\x00\x1f\x03\t\x00\x00\x00\x00\x1f\x03\t\x04\x00\x00\x00\x1f\x03\t\x05\x00\x00\x00\x0b\x05\x1dG\x03\x01\x05\x01\x03\x15????????AA\x03\x03w\x15\x01%\x01\x03\x03A\x01\x02\x02)\x01\x13)\x05\x11\x15\x07\x03\x11)\x05\x11\x11\x07)\x01\x07\x13\x11\x05\t\x05\x03\x05\t\x1b)\x03\x01\r)\x03\t\r\x04\xb9\x05\x01\x11\x03\x0f\x07\x03\x01\x05\x05\x11\x03\x17\x05\x03\x17+\x05\t\x03\x05\x03\x03\x03\x01!\x03\x0b\x03\x03\x01\x07\x03\x03\x03\x03\x01\x07\x03\x03\x03\x03\x01\r\x03\x03\x03\x03\x01\r\x03\x03\x03\x03\x01)\x03\x03\x03\x03\x01+\x03\x03\x03\x03\x01\x07\x03\x03\x07\x07\x01-\x03\x05\x15\x07\t\x0b\r\x0f\x11\x13\x05\x01\x03\t\x04\x03\x03\x15\x06\x03\x01\x05\x01\x00\xca\tI\x17\x0f\x0b!\x05\x05\x03\x1b\x1d\x1b\x1f/!!)#\x1f\x19\x97\xf1\x1f\x15\x1d\x15\x13%)\x13\r\x15\x1f\x11\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00func_v1\x00custom_call_v1\x00return_v1\x00value\x00sym_name\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00jit(func)/jit(main)/triangular_solve[left_side=True lower=True transpose_a=False conjugate_a=False unit_diagonal=False]\x00/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00jax.arg_info\x00mhlo.sharding\x00{replicated}\x00\x00a\x00b\x00jax.result_info\x00main\x00public\x00blas_ctrsm\x00", - xla_call_module_version=6, -) # End paste - - -# Pasted from the test output (see back_compat_test.py module docstring) -data_2023_07_16["c128"] = dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['blas_ztrsm'], - serialized_date=datetime.date(2023, 7, 16), - inputs=(array([[ 5.+0.j, 0.+0.j, 0.+0.j, 0.+0.j], - [ 4.+0.j, 10.+0.j, 0.+0.j, 0.+0.j], - [ 8.+0.j, 9.+0.j, 15.+0.j, 0.+0.j], - [12.+0.j, 13.+0.j, 14.+0.j, 20.+0.j]]), array([[ 0.+0.j, 1.+0.j, 2.+0.j, 3.+0.j, 4.+0.j], - [ 5.+0.j, 6.+0.j, 7.+0.j, 8.+0.j, 9.+0.j], - [10.+0.j, 11.+0.j, 12.+0.j, 13.+0.j, 14.+0.j], - [15.+0.j, 16.+0.j, 17.+0.j, 18.+0.j, 19.+0.j]])), - expected_outputs=(array([[ 0. 
+0.j, 0.2 +0.j, - 0.4 +0.j, 0.6000000000000001 +0.j, - 0.8 +0.j], - [ 0.5 +0.j, 0.52 +0.j, - 0.54 +0.j, 0.5599999999999999 +0.j, - 0.58 +0.j], - [ 0.36666666666666664 +0.j, 0.3146666666666667 +0.j, - 0.2626666666666667 +0.j, 0.21066666666666667 +0.j, - 0.15866666666666665 +0.j], - [ 0.16833333333333336 +0.j, 0.1217333333333333 +0.j, - 0.07513333333333323 +0.j, 0.0285333333333333 +0.j, - -0.018066666666666675+0.j]]),), - mlir_module_text=r""" -#loc = loc(unknown) -module @jit_func attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main(%arg0: tensor<4x4xcomplex> {jax.arg_info = "a", mhlo.sharding = "{replicated}"} loc(unknown), %arg1: tensor<4x5xcomplex> {jax.arg_info = "b", mhlo.sharding = "{replicated}"} loc(unknown)) -> (tensor<4x5xcomplex> {jax.result_info = ""}) { - %0 = stablehlo.constant dense<(1.000000e+00,0.000000e+00)> : tensor> loc(#loc2) - %1 = stablehlo.constant dense<1> : tensor loc(#loc2) - %2 = stablehlo.constant dense<1> : tensor loc(#loc2) - %3 = stablehlo.constant dense<0> : tensor loc(#loc2) - %4 = stablehlo.constant dense<0> : tensor loc(#loc2) - %5 = stablehlo.constant dense<4> : tensor loc(#loc2) - %6 = stablehlo.constant dense<5> : tensor loc(#loc2) - %7 = stablehlo.constant dense<1> : tensor loc(#loc2) - %8 = stablehlo.custom_call @blas_ztrsm(%1, %2, %3, %4, %5, %6, %7, %0, %arg0, %arg1) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>, dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>]} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor>, tensor<4x4xcomplex>, tensor<4x5xcomplex>) -> tensor<4x5xcomplex> loc(#loc2) - return %8 : tensor<4x5xcomplex> loc(#loc) - } loc(#loc) -} loc(#loc) -#loc1 = loc("/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py":510:0) -#loc2 = loc("jit(func)/jit(main)/triangular_solve[left_side=True lower=True transpose_a=False conjugate_a=False unit_diagonal=False]"(#loc1)) -""", - 
mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x01\x19\x05\x01\x03\x01\x03\x05\x03\t\x07\t\x0b\r\x03\xa7{\x19\x01?\x0f\x07\x0b\x13\x0f\x0b\x13#\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x13\x0b\x17\x0b\x13\x13K\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x03=\x0fO\x0b\x0b\x0b\x0b\x13\x1b\x0b\x1b\x0b\x0b\x0f\x13\x0b\x0b\x0bO\x1f\x1f\x1f\x1f\x0b\x0b\x0b\x0b3\x0f\x13\x0f\x01\x03\x0f\x03\x17\x0f\x17\x0b\x17\x0f\x07\x1b\x07\x07\x13\x13\x02\x82\x04\x1d#%\x1f\x05\x0f\x03\x03\x05c\x11\x01\x05\x05\x11\x03\x03\x05e\x03\x07\x11\t\x13\t\x0b\x15\x05\x13\x05\x15\x05\x17\x03\x0b\x19K\x1bU\x1dW\x0b]\x1f_\x05\x19\x05\x1b\x05\x1d\x05\x1f\x03\x03\x05a\x05!\x17'\xfa\x07\x01\x05#\x03\x03\x05g\x03\x03\x05i\x03\x11/k1I3m5o7q9s;u=y\x05%\x05'\x05)\x05+\x05-\x05/\x051\x053\x1f\x15\x01\x1f\x17!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1d5\x1d7\x1d9\x1d;\x03\x05MQ\r\x05COEG\x1d=\r\x05CSEG\x1d?#\x0f\x03\x03Y\r\x03[I\x1dA\x1dC\x1dE\x1f\x0b!\x00\x00\x00\x00\x00\x00\xf0?\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x03\t\x01\x00\x00\x00\x1f\x03\t\x00\x00\x00\x00\x1f\x03\t\x04\x00\x00\x00\x1f\x03\t\x05\x00\x00\x00\x0b\x05\x1dG\x03\x01\x05\x01\x03\x15????????AA\x03\x03w\x15\x01%\x01\x03\x03A\x01\x02\x02)\x01\x13)\x05\x11\x15\x07\x03\x11)\x05\x11\x11\x07)\x01\x07\x13\x11\x05\t\x05\x03\x05\x0b\x1b)\x03\x01\r)\x03\t\r\x04\xb9\x05\x01\x11\x03\x0f\x07\x03\x01\x05\x05\x11\x03\x17\x05\x03\x17+\x05\t\x03\x05\x03\x03\x03\x01!\x03\x0b\x03\x03\x01\x07\x03\x03\x03\x03\x01\x07\x03\x03\x03\x03\x01\r\x03\x03\x03\x03\x01\r\x03\x03\x03\x03\x01)\x03\x03\x03\x03\x01+\x03\x03\x03\x03\x01\x07\x03\x03\x07\x07\x01-\x03\x05\x15\x07\t\x0b\r\x0f\x11\x13\x05\x01\x03\t\x04\x03\x03\x15\x06\x03\x01\x05\x01\x00\xca\tI\x17\x0f\x0b!\x05\x05\x03\x1b\x1d\x1b\x1f/!!)#\x1f\x19\x97\xf1\x1f\x15\x1d\x15\x13%)\x13\r\x15\x1f\x11\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00func_v1\x00custom_call_v1\x00return_v1\x00value\x00sym_name\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00jit(func)/jit(main)/triangular_solve[left_side=True lower=True transpose_a=False conjugate_a=False unit_diagonal=False]\x00/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00jax.arg_info\x00mhlo.sharding\x00{replicated}\x00\x00a\x00b\x00jax.result_info\x00main\x00public\x00blas_ztrsm\x00", - xla_call_module_version=6, -) # End paste - data_2024_12_02 = {} - # Pasted from the test output (see export_back_compat_test_util.py module docstring) data_2024_12_02['c128'] = dict( testdata_version=1, diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/cpu_tridiagonal_lapack_sytrd_hetrd.py b/jax/_src/internal_test_util/export_back_compat_test_data/cpu_tridiagonal_lapack_sytrd_hetrd.py index 9e245052e03a..c986e4ffd115 100644 --- a/jax/_src/internal_test_util/export_back_compat_test_data/cpu_tridiagonal_lapack_sytrd_hetrd.py +++ b/jax/_src/internal_test_util/export_back_compat_test_data/cpu_tridiagonal_lapack_sytrd_hetrd.py @@ -17,432 +17,8 @@ import datetime from numpy import array, float32, complex64 -data_2024_09_03 = {} - - -# Pasted from the test output (see export_back_compat_test_util.py module docstring) -data_2024_09_03["c128"] = dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['lapack_zhetrd'], - serialized_date=datetime.date(2024, 9, 3), - inputs=(), - 
expected_outputs=(array([[[-1.6782909868280393 +0.j , - -0.44670237330570184+4.847000766107959j , - 2.05945450900321 -2.2848432268240106j , - -1.852046418980849 +1.672382006137275j ], - [ 8.516713699516982 +0.j , - -2.7881860505313174 +0.j , - 0.9238284715039695 -2.3790501284019947j , - 0.5005102262291599 -1.30066052934836j ], - [-0.12132810525381293-0.2963030371159077j , - -3.6374350042782893 +0.j , - 0.5605752523031344 +0.j , - -2.9865099107523174 +0.5492956557924651j ], - [-0.40379248092949666-0.7813328344426929j , - -0.07101654492399719-0.27208840961051617j, - -7.4654253782049285 +0.j , - -8.172380353916964 +0.j ]], - - [[-3.996403598623405 +0.j , - 0.59408630943699 +2.531609474375295j , - -1.789098034543644 -2.538389274566601j , - -1.291106590337488 +3.1576544511573843j ], - [10.8950662522622 +0.j , - -2.8151642043836693 +0.j , - 6.18998567202382 +1.1866537964613415j , - 3.1900218245393352 +2.7291222716752372j ], - [-0.3142889671188478 -0.37781876498252764j, - 3.049208563595754 +0.j , - -2.4383044880335487 +0.j , - 4.075435464493341 -0.6653616942280807j ], - [ 0.32757687545025194+0.565870910342534j , - 0.8177026465997795 -0.15906305615104555j, - 3.3415143060767125 +0.j , - 4.094619408678314 +0.j ]]]), array([[-1.6782909868280393, -2.7881860505313174, 0.5605752523031344, - -8.172380353916964 ], - [-3.996403598623405 , -2.8151642043836693, -2.4383044880335487, - 4.094619408678314 ]]), array([[ 8.516713699516982 , -3.6374350042782893, -7.4654253782049285], - [10.8950662522622 , 3.049208563595754 , 3.3415143060767125]]), array([[1.0626274644222748+0.06050271598884928j, - 1.834630852474663 +0.18575551495730305j, - 1.981584368497257 +0.19102912741736966j], - [1.0365789616521406-0.40942548304121656j, - 1.0872592163018966-0.3187050677167622j , - 1.0458498304770472-0.9989483435319496j ]])), - mlir_module_text=r""" -#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":695:13) -#loc5 = loc("jit(func)/jit(main)/pjit"(#loc1)) -module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main() -> (tensor<2x4x4xcomplex> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<2x4xf64> {jax.result_info = "[1]", mhlo.layout_mode = "default"}, tensor<2x3xf64> {jax.result_info = "[2]", mhlo.layout_mode = "default"}, tensor<2x3xcomplex> {jax.result_info = "[3]", mhlo.layout_mode = "default"}) { - %cst = stablehlo.constant dense<[[[(-1.6782909868280393,-0.44303325034407437), (-0.44670237330570184,4.8470007661079588), (2.0594545090032099,-2.2848432268240106), (-1.852046418980849,1.6723820061372749)], [(-0.53338018421119981,-0.5152843101202178), (-8.6208093221459947,-1.4723511111926109), (0.92382847150396952,-2.3790501284019947), (0.50051022622915986,-1.30066052934836)], [(0.94535043721506584,2.744088772946665), (-5.9178492824175759,-4.3744650461123786), (1.8341291553102983,-4.8378584827626838), (-2.9865099107523174,0.54929565579246509)], [(3.2517513113853891,7.2792034361133062), (-0.09841002311276037,0.88008791818205689), (-0.035759860211603468,2.4677764344580244), (-3.6133109853094476,-2.2833696560058976)]], [[(-3.996403598623405,2.42308766118121), (0.59408630943699003,2.531609474375295), (-1.789098034543644,-2.538389274566601), (-1.2911065903374881,3.1576544511573843)], [(-0.39853021063902833,4.4607177630985086), (1.0742061295773189,-2.6002112528615386), (6.1899856720238198,1.1866537964613415), (3.1900218245393352,2.7291222716752372)], [(5.2347956435718022,2.8649782894514577), 
(2.3527586611916762,2.4688953673448575), (-2.317572140163894,4.3609023810820053), (4.0754354644933413,-0.66536169422808067)], [(-6.2237114632988675,-4.9294897244018943), (4.2994486027667103,-1.3300494261380422), (-0.51942958410141249,0.60038999428238982), (0.084516726847668963,-7.2944134049318752)]]]> : tensor<2x4x4xcomplex> loc(#loc) - %c = stablehlo.constant dense<4> : tensor loc(#loc2) - %c_0 = stablehlo.constant dense<1> : tensor loc(#loc2) - %c_1 = stablehlo.constant dense<4> : tensor loc(#loc2) - %c_2 = stablehlo.constant dense<2> : tensor loc(#loc2) - %c_3 = stablehlo.constant dense<128> : tensor loc(#loc2) - %0:6 = stablehlo.custom_call @lapack_zhetrd(%c, %c_0, %c_1, %c_2, %c_3, %cst) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[1, 2, 0]> : tensor<3xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 0]> : tensor<2xindex>, dense<[1, 0]> : tensor<2xindex>, dense<[1, 0]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>]} : (tensor, tensor, tensor, tensor, tensor, tensor<2x4x4xcomplex>) -> (tensor<2x4x4xcomplex>, tensor<2x4xf64>, tensor<2x3xf64>, tensor<2x3xcomplex>, tensor<2xi32>, tensor<128xcomplex>) loc(#loc2) - %c_4 = stablehlo.constant dense<0> : tensor loc(#loc) - %1 = stablehlo.broadcast_in_dim %c_4, dims = [] : (tensor) -> tensor<2xi32> loc(#loc3) - %2 = stablehlo.compare EQ, %0#4, %1, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc3) - %3 = stablehlo.broadcast_in_dim %2, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> loc(#loc4) - %cst_5 = stablehlo.constant dense<(0x7FF8000000000000,0x7FF8000000000000)> : tensor> loc(#loc) - %4 = call @_where(%3, %0#0, %cst_5) : (tensor<2x1x1xi1>, tensor<2x4x4xcomplex>, tensor>) -> tensor<2x4x4xcomplex> loc(#loc5) - %c_6 = stablehlo.constant dense<0> : tensor loc(#loc) - %5 = stablehlo.broadcast_in_dim %c_6, dims = [] : (tensor) -> tensor<2xi32> loc(#loc3) - %6 = stablehlo.compare EQ, %0#4, %5, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc3) - %7 = stablehlo.broadcast_in_dim %6, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc4) - %cst_7 = stablehlo.constant dense<0x7FF8000000000000> : tensor loc(#loc) - %8 = call @_where_0(%7, %0#1, %cst_7) : (tensor<2x1xi1>, tensor<2x4xf64>, tensor) -> tensor<2x4xf64> loc(#loc5) - %c_8 = stablehlo.constant dense<0> : tensor loc(#loc) - %9 = stablehlo.broadcast_in_dim %c_8, dims = [] : (tensor) -> tensor<2xi32> loc(#loc3) - %10 = stablehlo.compare EQ, %0#4, %9, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc3) - %11 = stablehlo.broadcast_in_dim %10, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc4) - %cst_9 = stablehlo.constant dense<0x7FF8000000000000> : tensor loc(#loc) - %12 = call @_where_1(%11, %0#2, %cst_9) : (tensor<2x1xi1>, tensor<2x3xf64>, tensor) -> tensor<2x3xf64> loc(#loc5) - %c_10 = stablehlo.constant dense<0> : tensor loc(#loc) - %13 = stablehlo.broadcast_in_dim %c_10, dims = [] : (tensor) -> tensor<2xi32> loc(#loc3) - %14 = stablehlo.compare EQ, %0#4, %13, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc3) - %15 = stablehlo.broadcast_in_dim %14, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc4) - %cst_11 = stablehlo.constant dense<(0x7FF8000000000000,0x7FF8000000000000)> : tensor> loc(#loc) - %16 = call @_where_2(%15, %0#3, %cst_11) : (tensor<2x1xi1>, 
tensor<2x3xcomplex>, tensor>) -> tensor<2x3xcomplex> loc(#loc5) - return %4, %8, %12, %16 : tensor<2x4x4xcomplex>, tensor<2x4xf64>, tensor<2x3xf64>, tensor<2x3xcomplex> loc(#loc) - } loc(#loc) - func.func private @_where(%arg0: tensor<2x1x1xi1> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg1: tensor<2x4x4xcomplex> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg2: tensor> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1))) -> (tensor<2x4x4xcomplex> {mhlo.layout_mode = "default"}) { - %0 = stablehlo.broadcast_in_dim %arg0, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x4x4xi1> loc(#loc6) - %1 = stablehlo.broadcast_in_dim %arg2, dims = [] : (tensor>) -> tensor<2x4x4xcomplex> loc(#loc6) - %2 = stablehlo.select %0, %arg1, %1 : tensor<2x4x4xi1>, tensor<2x4x4xcomplex> loc(#loc7) - return %2 : tensor<2x4x4xcomplex> loc(#loc5) - } loc(#loc5) - func.func private @_where_0(%arg0: tensor<2x1xi1> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg1: tensor<2x4xf64> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg2: tensor {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1))) -> (tensor<2x4xf64> {mhlo.layout_mode = "default"}) { - %0 = stablehlo.broadcast_in_dim %arg0, dims = [0, 1] : (tensor<2x1xi1>) -> tensor<2x4xi1> loc(#loc6) - %1 = stablehlo.broadcast_in_dim %arg2, dims = [] : (tensor) -> tensor<2x4xf64> loc(#loc6) - %2 = stablehlo.select %0, %arg1, %1 : tensor<2x4xi1>, tensor<2x4xf64> loc(#loc7) - return %2 : tensor<2x4xf64> loc(#loc5) - } loc(#loc5) - func.func private @_where_1(%arg0: tensor<2x1xi1> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg1: tensor<2x3xf64> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg2: tensor {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1))) -> (tensor<2x3xf64> {mhlo.layout_mode = "default"}) { - %0 = stablehlo.broadcast_in_dim %arg0, dims = [0, 1] : (tensor<2x1xi1>) -> tensor<2x3xi1> loc(#loc6) - %1 = stablehlo.broadcast_in_dim %arg2, dims = [] : (tensor) -> tensor<2x3xf64> loc(#loc6) - %2 = stablehlo.select %0, %arg1, %1 : tensor<2x3xi1>, tensor<2x3xf64> loc(#loc7) - return %2 : tensor<2x3xf64> loc(#loc5) - } loc(#loc5) - func.func private @_where_2(%arg0: tensor<2x1xi1> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg1: tensor<2x3xcomplex> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg2: tensor> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1))) -> (tensor<2x3xcomplex> {mhlo.layout_mode = "default"}) { - %0 = stablehlo.broadcast_in_dim %arg0, dims = [0, 1] : (tensor<2x1xi1>) -> tensor<2x3xi1> loc(#loc6) - %1 = stablehlo.broadcast_in_dim %arg2, dims = [] : (tensor>) -> tensor<2x3xcomplex> loc(#loc6) - %2 = stablehlo.select %0, %arg1, %1 : tensor<2x3xi1>, tensor<2x3xcomplex> loc(#loc7) - return %2 : tensor<2x3xcomplex> loc(#loc5) - } loc(#loc5) -} loc(#loc) -#loc = loc(unknown) -#loc2 = loc("jit(func)/jit(main)/tridiagonal"(#loc1)) -#loc3 = loc("jit(func)/jit(main)/eq"(#loc1)) -#loc4 = loc("jit(func)/jit(main)/broadcast_in_dim"(#loc1)) -#loc6 = loc("jit(func)/jit(main)/jit(_where)/broadcast_in_dim"(#loc1)) -#loc7 = loc("jit(func)/jit(main)/jit(_where)/select_n"(#loc1)) -""", - 
mlir_module_serialized=b'ML\xefR\rStableHLO_v1.3.0\x00\x01#\x05\x01\x05\x13\x01\x03\x0b\x03\x11\x0f\x13\x17\x1b\x1f#\'+\x03\xf7\x99I\x01-\x0f\x07\x0f\x0f\x17\x0f\x0f\x0f\x0f#\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x03m\x0f\x0b\x0b\x0f\x0b\x17\x13\x0f\x0b\x1f\x0b\x0b/OO\x0b\x0b\x0b\x0b\x0b\x1fo/O/\x0b\x1b\x1b\x0b\x1b\x0b\x1b\x0b\x1b\x0b\x0b\x0b\x0b\x0b\x0b\x0bo&\x10\x1f\x1f\x1f\x0b\x0b\x0b\x0b#\x0f\x17#\x01\x05\x0b\x0f\x03E\x0f\x1b\x17\x17\x17\x17\x0f\x0f\x07\x13\x0b\x07\x07\x07\x13\x1b\x17\x07\x1f\x1f\x1f\x1f\x1f\x13\x13\x17\x1b\x13\x17\x13\x13\x13\x13\x13\x02\x12\x10\x1d\x1f\t\x1f\x1d#\t\x1d)\t\x17!\xde\n\x1b\x1d\'\t\x1d%\t\x1d+\t\x11\x03\x05\x03\x07\x15\x17\x19\x11\x1b\x11\x05\x17\x11\x01\x00\x05\x19\x05\x1b\x05\x1d\x05\x1f\x05!\x05#\x05%\x05\'\x05)\x05+\x1f5\x01\x1d-\x1d/\x1f?\x01\x1d1\x03\x07999\r\x03/1\x03\x039\x1d3\x1f\x05\t\x00\x00\x00\x00\t\x07\x07\x01\x1fG\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f3!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1fC!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03\x01\x1d5\x1d7\x1d9\x1d;\x1f\x05\t\x04\x00\x00\x00\x1fA1\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1fE\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x11!\x00\x00\x00\x00\x00\x00\xf8\x7f\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f\x13\x11\x00\x00\x00\x00\x00\x00\xf8\x7f#)\x03\tcgko\r\x055e/1\x1d=\r\x055i/1\x1d?\r\x055m/1\x1dA\r\x055q/1\x1dC\x1dE\x1dG#+#-#/#1\x1f;1\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x1f\x07\x02\x08d\x91Y\xa6G\xda\xfa\xbf$-Q"\xa8Z\xdc\xbfL0\x19\x8d\xc5\x96\xdc\xbf\x86{8+Tc\x13@\xf0%\x1eI\xc3y\x00@\xe4\x91\xbd\xe2[G\x02\xc0\x85%\x03m\xfb\xa1\xfd\xbf\x9atl\xa2\x13\xc2\xfa?\x9c\xb0\xf0Qs\x11\xe1\xbf\xd8v\x83\x855}\xe0\xbf\x84V/\xb8\xda=!\xc0\n\xd3\xec\t\xc0\x8e\xf7\xbf\x98$\x07\xba\x00\x90\xed?\xd5?\x08oK\x08\x03\xc0>\xf8\x9e\x05.\x04\xe0?\xf2\xfcKj\x81\xcf\xf4\xbf\xe4"c\x8fO@\xee?y\x03\x89\xd0\xe4\xf3\x05@\xee\x8f\xaa\xae\xe0\xab\x17\xc0\xf20\xda\xc3s\x7f\x11\xc0V*+\xd0\x97X\xfd?P\x91\xf8\x92\xf7Y\x13\xc0\x7f\xe3\xdeN_\xe4\x07\xc0\x14\xd5\xae{\xd4\x93\xe1?\xbc\x00\t1\x96\x03\n@`&l\x81\xe7\x1d\x1d@X/\xde6f1\xb9\xbf\x06KF#\xae)\xec?\xcd\x9a<\xcc\x1dO\xa2\xbf\x91\xb1>\x92\x01\xbe\x03@\xf2s\x01\x97\x0f\xe8\x0c\xc0\xf5\xcaiOWD\x02\xc0F\xa2-s\xa2\xf8\x0f\xc0X\xea\xa0\xc8{b\x03@\x0b\x10\xc1J\xc1\x02\xe3?2|\xd5w\xbc@\x04@\xca>\xbbB%\xa0\xfc\xbf\xe8>6\t\x9fN\x04\xc0\xafdRb_\xa8\xf4\xbf\x80Q>V\xe0B\t@UhJ\xdb\x84\x81\xd9\xbf\t\xc7\xb4e\xc6\xd7\x11@<(;\xc4\xf2/\xf1?\x1a\xda\xad\x8e;\xcd\x04\xc0\x1c4\xa0\x9a\x8b\xc2\x18@z\x9c\xf7\xb0\x88\xfc\xf2?\xaea\x8f)*\x85\t@\x00\x0b\xbd\x0e>\xd5\x05@b\x89\xe9Dn\xf0\x14@a\x8d\xc7\xbcy\xeb\x06@\x8a\x97\t"s\xd2\x02@\xc2\xef\xdf6L\xc0\x03@J\xff 
Cc\x8a\x02\xc0\xd7.\xcfd\x90q\x11@s\xd4S\xf4>M\x10@t\x10\x97\x9b\xa4J\xe5\xbf\x8eo*\x9e\x14\xe5\x18\xc0\xc5\x18\x81\'\xcc\xb7\x13\xc0\x19\xdd\x8e\xa7\xa22\x11@-95\xe8\xe1G\xf5\xbfZK\x89\xca*\x9f\xe0\xbfR;\xc9\x13e6\xe3?\x7f\x94\xc6a\xe3\xa2\xb5?\xe2\xbe&\xb5z-\x1d\xc0\x1f\x05\t\x01\x00\x00\x00\x1f\x05\t\x02\x00\x00\x00\x1f\x05\t\x80\x00\x00\x00\x0b\x05\x1dI\x1dK\x05\x01\x03\r33333W\x03\x03\x95\x15\x03\x01\x15\x01\x03\rWIIIYY\x01\t\x01\x02\x02)\x01\')\x07\t\x11\x11\x19)\x05\t\x05\x15)\x05\t\x11\x1b)\x05\t\r\x1b)\x05\t\r\x19)\x01\x19)\x01\x1b\x01)\x03\t\'\x03\x1b\x0b\x1d\x13)\x03\t\x15)\x07\t\x05\x05\x15)\x05\t\r\x15\x1b\x11\x01\t\x07\x0b\r\x0f\x11\x07#\x07\x11\x03\x07\x11\x07\t\x0b\x13\x03\x0b\x11\x07\t\r\x13\x03\r\x11\x07\t\x0f\x11\x03\x0f)\x03\t\x1d)\x03\x01\x1d)\x05\t\x11\x15)\x07\t\x11\x11\x15)\x03\r\x1d)\x03\x02\x04\x19)\x03\x01\x1f)\x03\r\x1f)\x03\t\x1f)\x03\x05\x1f)\x03\x05\x1d\x04J\x07\x05\x01Q\x03\x13\x01\x07\x04"\x07\x03\x01\x15\x07P\x03\x03\x07\x04\xf6\x03\x03I\x81\x05B\x03\x05\x03\x07\x05B\x0b\x07\x03\x05\x05B\x0b\t\x03\x05\x05B\x0b\x07\x03\x05\x05B\x0b\x0b\x03\x05\x05B\x0b\r\x03\x05\x11F\x0b\x0f\r\x07\x0b\r\x0f\x17=\r\x03\x05\x07\t\x0b\x01\x05B\x03\x11\x03\x05\x03F\x07\x13\x03\x17\x03\x19\rF\x07\x15\x03!\x05\x15\x1b\x03F\x0f\x17\x03#\x03\x1d\x05B\x03\x19\x03\x11\x0fF\x01\x1b\x03\x07\x07\x1f\r!\x05B\x03\x11\x03\x05\x03F\x07\x13\x03\x17\x03%\rF\x07\x15\x03!\x05\x15\'\x03F\x0f\x17\x03\t\x03)\x05B\x03\x1d\x03\x13\x0fF\x01\x1f\x03\x0b\x07+\x0f-\x05B\x03\x11\x03\x05\x03F\x07\x13\x03\x17\x031\rF\x07\x15\x03!\x05\x153\x03F\x0f\x17\x03\t\x035\x05B\x03\x1d\x03\x13\x0fF\x01!\x03\r\x077\x119\x05B\x03\x11\x03\x05\x03F\x07\x13\x03\x17\x03=\rF\x07\x15\x03!\x05\x15?\x03F\x0f\x17\x03\t\x03A\x05B\x03\x19\x03\x11\x0fF\x01#\x03\x0f\x07C\x13E\t\x04\x03\t#/;G\x07P\x01%\x07\x04S\x03\r\x13\x07G\x01\x0f\x01#\x01\x00\x03F\x05\'\x039\x03\x01\x03F\x05\x13\x03\x07\x03\x05\x0b\x06\r\x03\x07\x07\x07\x03\t\t\x04\x01\x03\x0b\x07P\x01)\x07\x04S\x03\r\x13\x07\x13\x01\x17\x01\'\x01\x00\x03F\x05+\x037\x03\x01\x03F\x05\x13\x03\x0b\x03\x05\x0b\x06\r\x03\x0b\x07\x07\x03\t\t\x04\x01\x03\x0b\x07P\x01-\x07\x04S\x03\r\x13\x07\x13\x01\x1b\x01\'\x01\x00\x03F\x05+\x03%\x03\x01\x03F\x05\x13\x03\r\x03\x05\x0b\x06\r\x03\r\x07\x07\x03\t\t\x04\x01\x03\x0b\x07P\x01/\x07\x04S\x03\r\x13\x07\x13\x01\x1f\x01#\x01\x00\x03F\x05+\x03%\x03\x01\x03F\x05\x13\x03\x0f\x03\x05\x0b\x06\r\x03\x0f\x07\x07\x03\t\t\x04\x01\x03\x0b\x06\x03\x01\x05\x01\x00\x96\tM\x1d\x03\x0f\x0b\t\t\t\t\x13\x13\x13\x0f\x11!\x11#K/ASci3\x13%)9\x1f\x11\x17\x15\x15\x11\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00func_v1\x00return_v1\x00select_v1\x00compare_v1\x00call_v1\x00custom_call_v1\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00jit(func)/jit(main)/pjit\x00third_party/py/jax/tests/export_back_compat_test.py\x00jit(func)/jit(main)/jit(_where)/broadcast_in_dim\x00jit(func)/jit(main)/jit(_where)/select_n\x00jit(func)/jit(main)/tridiagonal\x00jit(func)/jit(main)/eq\x00jit(func)/jit(main)/broadcast_in_dim\x00mhlo.layout_mode\x00default\x00jax.result_info\x00private\x00_where\x00_where_0\x00_where_1\x00_where_2\x00[0]\x00[1]\x00[2]\x00[3]\x00main\x00public\x00\x00lapack_zhetrd\x00\x08\x9d1\x05;\x01\x0bK_asu\x03\x81\x03U\x03\x83\x03\x85\x03\x87\x11\x89\x8b\x8dK\x8f\x91\x93\x97\x03?\x03-\x05AC\x03E\x03[\x03M\x03]\x03O\x03Q\x03S\x0b7w;M=\x03\x7f\x0b7y;O=\x03G\x0b7{;Q=\x0b7};S=', - xla_call_module_version=9, - nr_devices=1, -) # End paste - - -# Pasted from the test output (see 
export_back_compat_test_util.py module docstring) -data_2024_09_03["c64"] = dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['lapack_chetrd'], - serialized_date=datetime.date(2024, 9, 3), - inputs=(), - expected_outputs=(array([[[ 3.3228416 +0.j , -1.9756439 +4.593356j , - 7.367708 +0.88518727j , -8.659938 +1.6132793j ], - [-6.9206004 +0.j , -3.6362798 +0.j , - 3.3011198 -4.644362j , -4.8589935 -0.61439794j ], - [ 0.64957 +0.060723424j, 6.620491 +0.j , - 0.2882607 +0.j , -1.0288142 +1.8544064j ], - [-0.05458622 +0.10473086j , -0.15611424 +0.06925995j , - -4.431866 +0.j , 2.364208 +0.j ]], - - [[-4.1803885 +0.j , 0.5670845 +0.6913016j , - 2.675204 -0.23881845j , -0.41825035 -1.4060576j ], - [ 8.33625 +0.j , 2.6144838 +0.j , - -2.4941807 -1.9316154j , 0.6687787 -2.209776j ], - [ 0.019031923+0.17462212j , 2.7034955 +0.j , - -0.70924187 +0.j , 2.7962255 +1.5316825j ], - [-0.057821754+0.023692288j, -0.62805307 -0.0882424j , - 6.6364865 +0.j , -1.698973 +0.j ]]], - dtype=complex64), array([[ 3.3228416 , -3.6362798 , 0.2882607 , 2.364208 ], - [-4.1803885 , 2.6144838 , -0.70924187, -1.698973 ]], - dtype=float32), array([[-6.9206004, 6.620491 , -4.431866 ], - [ 8.33625 , 2.7034955, 6.6364865]], dtype=float32), array([[1.360567 +0.1977107j , 1.7586378-0.56989706j, - 1.5772758-0.8165493j ], - [1.9152443-0.1834492j , 1.1593437+0.55631363j, - 1.6889225-0.724835j ]], dtype=complex64)), - mlir_module_text=r""" -#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":695:13) -#loc5 = loc("jit(func)/jit(main)/pjit"(#loc1)) -module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main() -> (tensor<2x4x4xcomplex> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<2x4xf32> {jax.result_info = "[1]", mhlo.layout_mode = "default"}, tensor<2x3xf32> {jax.result_info = "[2]", mhlo.layout_mode = "default"}, tensor<2x3xcomplex> {jax.result_info = "[3]", mhlo.layout_mode = "default"}) { - %cst = stablehlo.constant dense<[[[(3.32284164,1.14621949), (-1.97564387,4.59335613), (7.36770821,0.885187268), (-8.65993785,1.61327934)], [(2.495340e+00,1.36827672), (-3.96969199,-0.636681795), (3.3011198,-4.64436197), (-4.85899353,-0.614397943)], [(6.03322554,1.46055949), (-3.89591122,-4.1833396), (-1.46423841,-0.106284566), (-1.0288142,1.85440636)], [(-0.657281339,0.911450386), (3.18693113,-2.02812219), (-2.64483237,0.351429433), (4.45011663,-1.79112875)]], [[(-4.18038845,-3.65238023), (0.567084491,0.691301584), (2.67520404,-0.238818452), (-0.418250352,-1.4060576)], [(-7.62970591,1.5292784), (0.269325763,2.48722434), (-2.49418068,-1.93161535), (0.668778717,-2.20977592)], [(-0.570908666,-2.75890398), (-0.235837936,3.45861554), (-0.946199476,0.23120968), (2.79622555,1.53168249)], [(0.886947453,-0.466695577), (-3.194850e+00,-0.0176551137), (-4.37602425,-3.7703948), (0.883143305,-4.70016575)]]]> : tensor<2x4x4xcomplex> loc(#loc) - %c = stablehlo.constant dense<4> : tensor loc(#loc2) - %c_0 = stablehlo.constant dense<1> : tensor loc(#loc2) - %c_1 = stablehlo.constant dense<4> : tensor loc(#loc2) - %c_2 = stablehlo.constant dense<2> : tensor loc(#loc2) - %c_3 = stablehlo.constant dense<128> : tensor loc(#loc2) - %0:6 = stablehlo.custom_call @lapack_chetrd(%c, %c_0, %c_1, %c_2, %c_3, %cst) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[1, 2, 0]> : tensor<3xindex>], 
output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 0]> : tensor<2xindex>, dense<[1, 0]> : tensor<2xindex>, dense<[1, 0]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>]} : (tensor, tensor, tensor, tensor, tensor, tensor<2x4x4xcomplex>) -> (tensor<2x4x4xcomplex>, tensor<2x4xf32>, tensor<2x3xf32>, tensor<2x3xcomplex>, tensor<2xi32>, tensor<128xcomplex>) loc(#loc2) - %c_4 = stablehlo.constant dense<0> : tensor loc(#loc) - %1 = stablehlo.broadcast_in_dim %c_4, dims = [] : (tensor) -> tensor<2xi32> loc(#loc3) - %2 = stablehlo.compare EQ, %0#4, %1, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc3) - %3 = stablehlo.broadcast_in_dim %2, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> loc(#loc4) - %cst_5 = stablehlo.constant dense<(0x7FC00000,0x7FC00000)> : tensor> loc(#loc) - %4 = call @_where(%3, %0#0, %cst_5) : (tensor<2x1x1xi1>, tensor<2x4x4xcomplex>, tensor>) -> tensor<2x4x4xcomplex> loc(#loc5) - %c_6 = stablehlo.constant dense<0> : tensor loc(#loc) - %5 = stablehlo.broadcast_in_dim %c_6, dims = [] : (tensor) -> tensor<2xi32> loc(#loc3) - %6 = stablehlo.compare EQ, %0#4, %5, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc3) - %7 = stablehlo.broadcast_in_dim %6, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc4) - %cst_7 = stablehlo.constant dense<0x7FC00000> : tensor loc(#loc) - %8 = call @_where_0(%7, %0#1, %cst_7) : (tensor<2x1xi1>, tensor<2x4xf32>, tensor) -> tensor<2x4xf32> loc(#loc5) - %c_8 = stablehlo.constant dense<0> : tensor loc(#loc) - %9 = stablehlo.broadcast_in_dim %c_8, dims = [] : (tensor) -> tensor<2xi32> loc(#loc3) - %10 = stablehlo.compare EQ, %0#4, %9, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc3) - %11 = stablehlo.broadcast_in_dim %10, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc4) - %cst_9 = stablehlo.constant dense<0x7FC00000> : tensor loc(#loc) - %12 = call @_where_1(%11, %0#2, %cst_9) : (tensor<2x1xi1>, tensor<2x3xf32>, tensor) -> tensor<2x3xf32> loc(#loc5) - %c_10 = stablehlo.constant dense<0> : tensor loc(#loc) - %13 = stablehlo.broadcast_in_dim %c_10, dims = [] : (tensor) -> tensor<2xi32> loc(#loc3) - %14 = stablehlo.compare EQ, %0#4, %13, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc3) - %15 = stablehlo.broadcast_in_dim %14, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc4) - %cst_11 = stablehlo.constant dense<(0x7FC00000,0x7FC00000)> : tensor> loc(#loc) - %16 = call @_where_2(%15, %0#3, %cst_11) : (tensor<2x1xi1>, tensor<2x3xcomplex>, tensor>) -> tensor<2x3xcomplex> loc(#loc5) - return %4, %8, %12, %16 : tensor<2x4x4xcomplex>, tensor<2x4xf32>, tensor<2x3xf32>, tensor<2x3xcomplex> loc(#loc) - } loc(#loc) - func.func private @_where(%arg0: tensor<2x1x1xi1> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg1: tensor<2x4x4xcomplex> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg2: tensor> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1))) -> (tensor<2x4x4xcomplex> {mhlo.layout_mode = "default"}) { - %0 = stablehlo.broadcast_in_dim %arg0, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x4x4xi1> loc(#loc6) - %1 = stablehlo.broadcast_in_dim %arg2, dims = [] : (tensor>) -> tensor<2x4x4xcomplex> loc(#loc6) - %2 = stablehlo.select %0, %arg1, %1 : tensor<2x4x4xi1>, tensor<2x4x4xcomplex> loc(#loc7) - return %2 : tensor<2x4x4xcomplex> loc(#loc5) - } loc(#loc5) - func.func private @_where_0(%arg0: tensor<2x1xi1> 
{mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg1: tensor<2x4xf32> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg2: tensor {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1))) -> (tensor<2x4xf32> {mhlo.layout_mode = "default"}) { - %0 = stablehlo.broadcast_in_dim %arg0, dims = [0, 1] : (tensor<2x1xi1>) -> tensor<2x4xi1> loc(#loc6) - %1 = stablehlo.broadcast_in_dim %arg2, dims = [] : (tensor) -> tensor<2x4xf32> loc(#loc6) - %2 = stablehlo.select %0, %arg1, %1 : tensor<2x4xi1>, tensor<2x4xf32> loc(#loc7) - return %2 : tensor<2x4xf32> loc(#loc5) - } loc(#loc5) - func.func private @_where_1(%arg0: tensor<2x1xi1> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg1: tensor<2x3xf32> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg2: tensor {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1))) -> (tensor<2x3xf32> {mhlo.layout_mode = "default"}) { - %0 = stablehlo.broadcast_in_dim %arg0, dims = [0, 1] : (tensor<2x1xi1>) -> tensor<2x3xi1> loc(#loc6) - %1 = stablehlo.broadcast_in_dim %arg2, dims = [] : (tensor) -> tensor<2x3xf32> loc(#loc6) - %2 = stablehlo.select %0, %arg1, %1 : tensor<2x3xi1>, tensor<2x3xf32> loc(#loc7) - return %2 : tensor<2x3xf32> loc(#loc5) - } loc(#loc5) - func.func private @_where_2(%arg0: tensor<2x1xi1> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg1: tensor<2x3xcomplex> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg2: tensor> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1))) -> (tensor<2x3xcomplex> {mhlo.layout_mode = "default"}) { - %0 = stablehlo.broadcast_in_dim %arg0, dims = [0, 1] : (tensor<2x1xi1>) -> tensor<2x3xi1> loc(#loc6) - %1 = stablehlo.broadcast_in_dim %arg2, dims = [] : (tensor>) -> tensor<2x3xcomplex> loc(#loc6) - %2 = stablehlo.select %0, %arg1, %1 : tensor<2x3xi1>, tensor<2x3xcomplex> loc(#loc7) - return %2 : tensor<2x3xcomplex> loc(#loc5) - } loc(#loc5) -} loc(#loc) -#loc = loc(unknown) -#loc2 = loc("jit(func)/jit(main)/tridiagonal"(#loc1)) -#loc3 = loc("jit(func)/jit(main)/eq"(#loc1)) -#loc4 = loc("jit(func)/jit(main)/broadcast_in_dim"(#loc1)) -#loc6 = loc("jit(func)/jit(main)/jit(_where)/broadcast_in_dim"(#loc1)) -#loc7 = loc("jit(func)/jit(main)/jit(_where)/select_n"(#loc1)) -""", - 
mlir_module_serialized=b'ML\xefR\rStableHLO_v1.3.0\x00\x01#\x05\x01\x05\x13\x01\x03\x0b\x03\x11\x0f\x13\x17\x1b\x1f#\'+\x03\xf7\x99I\x01-\x0f\x07\x0f\x0f\x17\x0f\x0f\x0f\x0f#\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x03m\x0f\x0b\x0b\x0f\x0b\x17\x13\x0f\x0b\x1f\x0b\x0b/OO\x0b\x0b\x0b\x0b\x0b\x1fo//\x1f\x0b\x1b\x1b\x0b\x1b\x0b\x1b\x0b\x1b\x0b\x0b\x0b\x0b\x0b\x0b\x0bo&\x08\x1f\x1f\x1f\x0b\x0b\x0b\x0b#\x0f\x17#\x01\x05\x0b\x0f\x03E\x0f\x1b\x17\x17\x17\x17\x0f\x0f\x07\x13\x0b\x07\x07\x07\x13\x1b\x17\x07\x1f\x1f\x1f\x1f\x1f\x13\x13\x17\x1b\x13\x17\x13\x13\x13\x13\x13\x02\xe2\x0b\x1d\x1f\t\x1f\x1d#\t\x1d)\t\x17!\xde\n\x1b\x1d\'\t\x1d%\t\x1d+\t\x11\x03\x05\x03\x07\x15\x17\x19\x11\x1b\x11\x05\x17\x11\x01\x00\x05\x19\x05\x1b\x05\x1d\x05\x1f\x05!\x05#\x05%\x05\'\x05)\x05+\x1f5\x01\x1d-\x1d/\x1f?\x01\x1d1\x03\x07999\r\x03/1\x03\x039\x1d3\x1f\x05\t\x00\x00\x00\x00\t\x07\x07\x01\x1fG\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f3!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1fC!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03\x01\x1d5\x1d7\x1d9\x1d;\x1f\x05\t\x04\x00\x00\x00\x1fA1\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1fE\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x11\x11\x00\x00\xc0\x7f\x00\x00\xc0\x7f\x1f\x13\t\x00\x00\xc0\x7f#)\x03\tcgko\r\x055e/1\x1d=\r\x055i/1\x1d?\r\x055m/1\x1dA\r\x055q/1\x1dC\x1dE\x1dG#+#-#/#1\x1f;1\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x1f\x07\x02\x04p\xa9T@R\xb7\x92?\xe6\xe1\xfc\xbf\xc6\xfc\x92@D\xc4\xeb@\xa2\x9bb?\x1b\x8f\n\xc1\xf0\x7f\xce?\xa7\xb3\x1f@\xb1#\xaf?o\x0f~\xc0\x94\xfd"\xbf\x8cES@\x9d\x9e\x94\xc0\xe0|\x9b\xc0/I\x1d\xbf/\x10\xc1@\x9d\xf3\xba?\x9cVy\xc0\xeb\xdd\x85\xc0*l\xbb\xbf\xb9\xab\xd9\xbd/\xb0\x83\xbf0]\xed?\x97C(\xbf\xd0Ti?\xae\xf6K@\xc1\xcc\x01\xc0\xefD)\xc0\x8f\xee\xb3>[g\x8e@\xb5C\xe5\xbf\xbe\xc5\x85\xc0\x99\xc0i\xc0s,\x11?$\xf90?\x8b6+@\xd3\x8ct\xbe\xe9$\xd6\xbe\xb2\xf9\xb3\xbf\x8d&\xf4\xc0e\xbf\xc3?\x11\xe5\x89>\xaf.\x1f@\xa8\xa0\x1f\xc0,?\xf7\xbf\x155+?\xf8l\r\xc0\x12\'\x12\xbf\xe2\x910\xc0\x80\x7fq\xbe\xf5Y]@!:r\xbf;\xc2l>\\\xf52@,\x0e\xc4?\xfd\x0ec?\xb9\xf2\xee\xbelxL\xc0u\xa1\x90\xbcd\x08\x8c\xc0&Nq\xc0\xae\x15b?\xc2g\x96\xc0\x1f\x05\t\x01\x00\x00\x00\x1f\x05\t\x02\x00\x00\x00\x1f\x05\t\x80\x00\x00\x00\x0b\x05\x1dI\x1dK\x05\x01\x03\r33333W\x03\x03\x95\x15\x03\x01\x15\x01\x03\rWIIIYY\x01\t\x01\x02\x02)\x01\')\x07\t\x11\x11\x19)\x05\t\x05\x15)\x05\t\x11\x1b)\x05\t\r\x1b)\x05\t\r\x19)\x01\x19)\x01\x1b\x01)\x03\t\'\x03\x1b\t\x1d\x13)\x03\t\x15)\x07\t\x05\x05\x15)\x05\t\r\x15\x1b\x11\x01\t\x07\x0b\r\x0f\x11\x07#\x07\x11\x03\x07\x11\x07\t\x0b\x13\x03\x0b\x11\x07\t\r\x13\x03\r\x11\x07\t\x0f\x11\x03\x0f)\x03\t\x1d)\x03\x01\x1d)\x05\t\x11\x15)\x07\t\x11\x11\x15)\x03\r\x1d)\x03\x02\x04\x19)\x03\x01\x1f)\x03\r\x1f)\x03\t\x1f)\x03\x05\x1f)\x03\x05\x1d\x04J\x07\x05\x01Q\x03\x13\x01\x07\x04"\x07\x03\x01\x15\x07P\x03\x03\x07\x04\xf6\x03\x03I\x81\x05B\x03\x05\x03\x07\x05B\x0b\x07\x03\x05\x05B\x0b\t\x03\x05\x05B\x0b\x07\x03\x05\x05B\x0b\x0b\x03\x05\x05B\x0b\r\x03\x05\x11F\x0b\x0f\r\x07\x0b\r\x0f\x17=\r\x03\x05\x07\t\x0b\x01\x05B\x03\x11\x03\x05\x03F\x07\x13\x03\x17\x03\x19\rF\x07\x15\x03!\x05\x15\x1b\x03F\x0f\x17\x03#\x03\x1d\x05B\x03\x19\x03\x11\x0fF\x01\x1b\x03\x07\x07\x1f\r!\x05B\x03\x11\x03\x05\x03F\x07\x13\x03\x17\x03%\rF\x07\x15\x03!\x05\x15\'\x03F\x0f\x17\x03\t\x03)\x05B\x03\x1d\x03\x13\x0fF\x01\x1f\x03\x0b\x07+\x0f-\x05B\x03\x11\x03\x05\x03F\x07\x13\x03\x17\x031\rF\x07\x15\x03!\x05\x153\x03F\x0f\x17\x03\t\x0
35\x05B\x03\x1d\x03\x13\x0fF\x01!\x03\r\x077\x119\x05B\x03\x11\x03\x05\x03F\x07\x13\x03\x17\x03=\rF\x07\x15\x03!\x05\x15?\x03F\x0f\x17\x03\t\x03A\x05B\x03\x19\x03\x11\x0fF\x01#\x03\x0f\x07C\x13E\t\x04\x03\t#/;G\x07P\x01%\x07\x04S\x03\r\x13\x07G\x01\x0f\x01#\x01\x00\x03F\x05\'\x039\x03\x01\x03F\x05\x13\x03\x07\x03\x05\x0b\x06\r\x03\x07\x07\x07\x03\t\t\x04\x01\x03\x0b\x07P\x01)\x07\x04S\x03\r\x13\x07\x13\x01\x17\x01\'\x01\x00\x03F\x05+\x037\x03\x01\x03F\x05\x13\x03\x0b\x03\x05\x0b\x06\r\x03\x0b\x07\x07\x03\t\t\x04\x01\x03\x0b\x07P\x01-\x07\x04S\x03\r\x13\x07\x13\x01\x1b\x01\'\x01\x00\x03F\x05+\x03%\x03\x01\x03F\x05\x13\x03\r\x03\x05\x0b\x06\r\x03\r\x07\x07\x03\t\t\x04\x01\x03\x0b\x07P\x01/\x07\x04S\x03\r\x13\x07\x13\x01\x1f\x01#\x01\x00\x03F\x05+\x03%\x03\x01\x03F\x05\x13\x03\x0f\x03\x05\x0b\x06\r\x03\x0f\x07\x07\x03\t\t\x04\x01\x03\x0b\x06\x03\x01\x05\x01\x00\x96\tM\x1d\x03\x0f\x0b\t\t\t\t\x13\x13\x13\x0f\x11!\x11#K/ASci3\x13%)9\x1f\x11\x17\x15\x15\x11\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00func_v1\x00return_v1\x00select_v1\x00compare_v1\x00call_v1\x00custom_call_v1\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00jit(func)/jit(main)/pjit\x00third_party/py/jax/tests/export_back_compat_test.py\x00jit(func)/jit(main)/jit(_where)/broadcast_in_dim\x00jit(func)/jit(main)/jit(_where)/select_n\x00jit(func)/jit(main)/tridiagonal\x00jit(func)/jit(main)/eq\x00jit(func)/jit(main)/broadcast_in_dim\x00mhlo.layout_mode\x00default\x00jax.result_info\x00private\x00_where\x00_where_0\x00_where_1\x00_where_2\x00[0]\x00[1]\x00[2]\x00[3]\x00main\x00public\x00\x00lapack_chetrd\x00\x08\x9d1\x05;\x01\x0bK_asu\x03\x81\x03U\x03\x83\x03\x85\x03\x87\x11\x89\x8b\x8dK\x8f\x91\x93\x97\x03?\x03-\x05AC\x03E\x03[\x03M\x03]\x03O\x03Q\x03S\x0b7w;M=\x03\x7f\x0b7y;O=\x03G\x0b7{;Q=\x0b7};S=', - xla_call_module_version=9, - nr_devices=1, -) # End paste - - -# Pasted from the test output (see export_back_compat_test_util.py module docstring) -data_2024_09_03["f32"] = dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['lapack_ssytrd'], - serialized_date=datetime.date(2024, 9, 3), - inputs=(), - expected_outputs=(array([[[-0.8395241 , 0.156272 , -1.6810869 , 0.23832119], - [-2.985257 , -5.571 , -0.22652794, -0.83806676], - [ 0.27237308, -1.6295947 , 2.0042834 , -1.148861 ], - [-0.17183593, 0.57464546, 0.5536146 , -4.206357 ]], - - [[ 1.7666914 , 2.569005 , -0.86576384, -0.1617768 ], - [-5.143918 , 5.0426254 , -3.7237067 , 4.383015 ], - [ 0.33311516, -1.5299042 , -8.854181 , -2.896776 ], - [ 0.3419102 , 0.2669245 , -2.8250606 , 5.752488 ]]], - dtype=float32), array([[-0.8395241, -5.571 , 2.0042834, -4.206357 ], - [ 1.7666914, 5.0426254, -8.854181 , 5.752488 ]], dtype=float32), array([[-2.985257 , -1.6295947, 0.5536146], - [-5.143918 , -1.5299042, -2.8250606]], dtype=float32), array([[1.8120625, 1.5035137, 0. ], - [1.6288393, 1.8669801, 0. 
]], dtype=float32)), - mlir_module_text=r""" -#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":695:13) -#loc5 = loc("jit(func)/jit(main)/pjit"(#loc1)) -module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main() -> (tensor<2x4x4xf32> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<2x4xf32> {jax.result_info = "[1]", mhlo.layout_mode = "default"}, tensor<2x3xf32> {jax.result_info = "[2]", mhlo.layout_mode = "default"}, tensor<2x3xf32> {jax.result_info = "[3]", mhlo.layout_mode = "default"}) { - %cst = stablehlo.constant dense<[[[-0.83952409, 1.562720e-01, -1.6810869, 0.238321185], [2.42421508, -5.17118931, -0.226527944, -0.838066756], [1.47339451, -1.32866347, -3.3505435, -1.14886105], [-0.929541587, -0.955984473, 2.71886253, 0.748659431]], [[1.76669145, 2.56900501, -0.865763843, -0.161776796], [3.23469758, -0.362713158, -3.72370672, 4.38301516], [2.79104376, 7.36582708, -3.04437494, -2.89677596], [2.86473417, 0.981746375, -2.13533139, 5.34802151]]]> : tensor<2x4x4xf32> loc(#loc) - %c = stablehlo.constant dense<4> : tensor loc(#loc2) - %c_0 = stablehlo.constant dense<1> : tensor loc(#loc2) - %c_1 = stablehlo.constant dense<4> : tensor loc(#loc2) - %c_2 = stablehlo.constant dense<2> : tensor loc(#loc2) - %c_3 = stablehlo.constant dense<128> : tensor loc(#loc2) - %0:6 = stablehlo.custom_call @lapack_ssytrd(%c, %c_0, %c_1, %c_2, %c_3, %cst) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[1, 2, 0]> : tensor<3xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 0]> : tensor<2xindex>, dense<[1, 0]> : tensor<2xindex>, dense<[1, 0]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>]} : (tensor, tensor, tensor, tensor, tensor, tensor<2x4x4xf32>) -> (tensor<2x4x4xf32>, tensor<2x4xf32>, tensor<2x3xf32>, tensor<2x3xf32>, tensor<2xi32>, tensor<128xf32>) loc(#loc2) - %c_4 = stablehlo.constant dense<0> : tensor loc(#loc) - %1 = stablehlo.broadcast_in_dim %c_4, dims = [] : (tensor) -> tensor<2xi32> loc(#loc3) - %2 = stablehlo.compare EQ, %0#4, %1, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc3) - %3 = stablehlo.broadcast_in_dim %2, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> loc(#loc4) - %cst_5 = stablehlo.constant dense<0x7FC00000> : tensor loc(#loc) - %4 = call @_where(%3, %0#0, %cst_5) : (tensor<2x1x1xi1>, tensor<2x4x4xf32>, tensor) -> tensor<2x4x4xf32> loc(#loc5) - %c_6 = stablehlo.constant dense<0> : tensor loc(#loc) - %5 = stablehlo.broadcast_in_dim %c_6, dims = [] : (tensor) -> tensor<2xi32> loc(#loc3) - %6 = stablehlo.compare EQ, %0#4, %5, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc3) - %7 = stablehlo.broadcast_in_dim %6, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc4) - %cst_7 = stablehlo.constant dense<0x7FC00000> : tensor loc(#loc) - %8 = call @_where_0(%7, %0#1, %cst_7) : (tensor<2x1xi1>, tensor<2x4xf32>, tensor) -> tensor<2x4xf32> loc(#loc5) - %c_8 = stablehlo.constant dense<0> : tensor loc(#loc) - %9 = stablehlo.broadcast_in_dim %c_8, dims = [] : (tensor) -> tensor<2xi32> loc(#loc3) - %10 = stablehlo.compare EQ, %0#4, %9, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc3) - %11 = stablehlo.broadcast_in_dim %10, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> 
loc(#loc4) - %cst_9 = stablehlo.constant dense<0x7FC00000> : tensor loc(#loc) - %12 = call @_where_1(%11, %0#2, %cst_9) : (tensor<2x1xi1>, tensor<2x3xf32>, tensor) -> tensor<2x3xf32> loc(#loc5) - %c_10 = stablehlo.constant dense<0> : tensor loc(#loc) - %13 = stablehlo.broadcast_in_dim %c_10, dims = [] : (tensor) -> tensor<2xi32> loc(#loc3) - %14 = stablehlo.compare EQ, %0#4, %13, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc3) - %15 = stablehlo.broadcast_in_dim %14, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc4) - %cst_11 = stablehlo.constant dense<0x7FC00000> : tensor loc(#loc) - %16 = call @_where_1(%15, %0#3, %cst_11) : (tensor<2x1xi1>, tensor<2x3xf32>, tensor) -> tensor<2x3xf32> loc(#loc5) - return %4, %8, %12, %16 : tensor<2x4x4xf32>, tensor<2x4xf32>, tensor<2x3xf32>, tensor<2x3xf32> loc(#loc) - } loc(#loc) - func.func private @_where(%arg0: tensor<2x1x1xi1> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg1: tensor<2x4x4xf32> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg2: tensor {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1))) -> (tensor<2x4x4xf32> {mhlo.layout_mode = "default"}) { - %0 = stablehlo.broadcast_in_dim %arg0, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x4x4xi1> loc(#loc6) - %1 = stablehlo.broadcast_in_dim %arg2, dims = [] : (tensor) -> tensor<2x4x4xf32> loc(#loc6) - %2 = stablehlo.select %0, %arg1, %1 : tensor<2x4x4xi1>, tensor<2x4x4xf32> loc(#loc7) - return %2 : tensor<2x4x4xf32> loc(#loc5) - } loc(#loc5) - func.func private @_where_0(%arg0: tensor<2x1xi1> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg1: tensor<2x4xf32> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg2: tensor {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1))) -> (tensor<2x4xf32> {mhlo.layout_mode = "default"}) { - %0 = stablehlo.broadcast_in_dim %arg0, dims = [0, 1] : (tensor<2x1xi1>) -> tensor<2x4xi1> loc(#loc6) - %1 = stablehlo.broadcast_in_dim %arg2, dims = [] : (tensor) -> tensor<2x4xf32> loc(#loc6) - %2 = stablehlo.select %0, %arg1, %1 : tensor<2x4xi1>, tensor<2x4xf32> loc(#loc7) - return %2 : tensor<2x4xf32> loc(#loc5) - } loc(#loc5) - func.func private @_where_1(%arg0: tensor<2x1xi1> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg1: tensor<2x3xf32> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg2: tensor {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1))) -> (tensor<2x3xf32> {mhlo.layout_mode = "default"}) { - %0 = stablehlo.broadcast_in_dim %arg0, dims = [0, 1] : (tensor<2x1xi1>) -> tensor<2x3xi1> loc(#loc6) - %1 = stablehlo.broadcast_in_dim %arg2, dims = [] : (tensor) -> tensor<2x3xf32> loc(#loc6) - %2 = stablehlo.select %0, %arg1, %1 : tensor<2x3xi1>, tensor<2x3xf32> loc(#loc7) - return %2 : tensor<2x3xf32> loc(#loc5) - } loc(#loc5) -} loc(#loc) -#loc = loc(unknown) -#loc2 = loc("jit(func)/jit(main)/tridiagonal"(#loc1)) -#loc3 = loc("jit(func)/jit(main)/eq"(#loc1)) -#loc4 = loc("jit(func)/jit(main)/broadcast_in_dim"(#loc1)) -#loc6 = loc("jit(func)/jit(main)/jit(_where)/broadcast_in_dim"(#loc1)) -#loc7 = loc("jit(func)/jit(main)/jit(_where)/select_n"(#loc1)) -""", - 
mlir_module_serialized=b'ML\xefR\rStableHLO_v1.3.0\x00\x01#\x05\x01\x05\x13\x01\x03\x0b\x03\x11\x0f\x13\x17\x1b\x1f#\'+\x03\xe9\x93A\x01-\x0f\x07\x0f\x17\x0f\x0f\x0f\x0f\x0f#\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x03g\x0f\x0b\x0b\x0f\x0b\x13\x1f\x0b\x0b/\x1f\x17\x0f\x0b\x0bO\x0b\x0b\x0bO\x1fo/\x0b\x1b\x1b\x0b\x1b\x0b\x1b\x0b\x1b\x0b\x0b\x0b\x0b\x0b\x0bo&\x04\x1f\x1f\x1f\x0b\x0b\x0b\x0b#\x0f\x17#\x01\x05\x0b\x0f\x03=\x0f\x17\x0f\x1b\x17\x17\x07\x07\x13\x07\x07\x13\x1b\x07\x1f\x1f\x1f\x1f\x17\x13\x13\x17\x1b\x13\x17\x13\x13\x13\x13\x13\x02b\t\x1d\x1f\x07\x1f\x1d)\x07\x17!\xde\n\x1b\x1d#\x07\x1d\'\x07\x1d+\x07\x1d%\x07\x11\x03\x05\x03\x07\x15\x17\x19\x11\x1b\x11\x05\x17\x11\x01\x00\x05\x19\x05\x1b\x05\x1d\x05\x1f\x05!\x05#\x05%\x05\'\x05)\x05+\x1f-\x01\x1d-\x1d/\x1f7\x01\x1d1\r\x03/1\x1f\x05\t\x00\x00\x00\x00\t\x07\x07\x01\x1f?\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\t\t\x00\x00\xc0\x7f\x03\x07777\x03\x037\x1d3\x1d5\x1f;!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03\x01\x1d7\x1d9\x1f+!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f\x05\t\x04\x00\x00\x00\x1f91\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f=\x11\x00\x00\x00\x00\x00\x00\x00\x00#!\x03\t_cgk\r\x055a/1\x1d;\r\x055e/1\x1d=\r\x055i/1\x1d?\r\x055m/1\x1dA\x1dC\x1dE###%#\'\x1f31\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x1f\x0b\x02\x02\r\xebV\xbf\xc4\x05 >\xdb-\xd7\xbfx\nt>W&\x1b@bz\xa5\xc0\xf1\xf6g\xbe\x8b\x8bV\xbf1\x98\xbc?\xa5\x11\xaa\xbfNoV\xc0\xe1\r\x93\xbfp\xf6m\xbff\xbbt\xbf\xd8\x01.@%\xa8??\xf2"\xe2?\x94j$@\xb3\xa2]\xbf\xd1\xa8%\xbeI\x05O@\x8a\xb5\xb9\xbe6Qn\xc0\xa9A\x8c@v\xa02@\xdb\xb4\xeb@\n\xd7B\xc0\xc7d9\xc0\xceW7@\xbbS{?E\xa9\x08\xc0\xfe"\xab@\x1f\x05\t\x01\x00\x00\x00\x1f\x05\t\x02\x00\x00\x00\x1f\x05\t\x80\x00\x00\x00\x0b\x05\x1dG\x1dI\x05\x01\x03\r33333W\x03\x03\x8f\x15\x03\x01\x15\x01\x03\rWKKKYY\x01\t\x01\x02\x02)\x01\x1f)\x05\t\r\x13)\x01\x13)\x07\t\x11\x11\x13)\x05\t\x11\x13)\x05\t\x05\x11\x01\t)\x03\t\x1f\x1d\x13)\x03\t\x11)\x07\t\x05\x05\x11\x1b\x11\x01\t\x0b\r\x07\x07\x11\x07\x1d\x0b\t\x03\x0b\x11\x07\x0f\r\t\x03\r\x11\x07\x0f\x07\t\x03\x07)\x05\t\r\x11)\x03\t\x17)\x03\x01\x17)\x05\t\x11\x11)\x07\t\x11\x11\x11)\x03\r\x17)\x03\x02\x04\x13)\x03\x01\x19)\x03\r\x19)\x03\t\x19)\x03\x05\x19)\x03\x05\x17\x04\x8a\x06\x05\x01Q\x03\x13\x01\x07\x04b\x06\x03\x01\x11\x07P\x03\x03\x07\x04\xf6\x03\x03I\x81\x05B\x03\x05\x03\x0b\x05B\x0b\x07\x03\x05\x05B\x0b\t\x03\x05\x05B\x0b\x07\x03\x05\x05B\x0b\x0b\x03\x05\x05B\x0b\r\x03\x05\x11F\x0b\x0f\r\x0b\r\x07\x07\x155\r\x03\x05\x07\t\x0b\x01\x05B\x03\x11\x03\x05\x03F\x05\x13\x03\x15\x03\x19\x0bF\x05\x15\x03\x1b\x05\x15\x1b\x03F\r\x17\x03\x1d\x03\x1d\x05B\x03\x19\x03\t\rF\x01\x1b\x03\x0b\x07\x1f\r!\x05B\x03\x11\x03\x05\x03F\x05\x13\x03\x15\x03%\x0bF\x05\x15\x03\x1b\x05\x15\'\x03F\r\x17\x03\x0f\x03)\x05B\x03\x19\x03\t\rF\x01\x1d\x03\r\x07+\x0f-\x05B\x03\x11\x03\x05\x03F\x05\x13\x03\x15\x031\x0bF\x05\x15\x03\x1b\x05\x153\x03F\r\x17\x03\x0f\x035\x05B\x03\x19\x03\t\rF\x01\x1f\x03\x07\x077\x119\x05B\x03\x11\x03\x05\x03F\x05\x13\x03\x15\x03=\x0bF\x05\x15\x03\x1b\x05\x15?\x03F\r\x17\x03\x0f\x03A\x05B\x03\x19\x03\t\rF\x01\x1f\x03\x07\x07C\x13E\t\x04\x03\t#/;G\x07P\x01!\x07\x04S\x03\r\x13\x07;\x01\x17\x01\x13\x01\x00\x03F\t#\x031\x03\x01\x03F\t\x13\x03\x0b\x03\x05\x0f\x06\x0f\x03\x0b\x07\x07\x03\t\t\x04\x01\x03\x0b\x07P\x01%\x07\x04S\x03\r\x13\x07\x1f\x01\x1b\x01\x13\x01\x00\x03F\t\'\x03/\x03\x01\x03F\t\x13\x03\r\x03\x05\x0f\x06\x0f\x03\r\x
07\x07\x03\t\t\x04\x01\x03\x0b\x07P\x01)\x07\x04S\x03\r\x13\x07\x1f\x01\x0f\x01\x13\x01\x00\x03F\t\'\x03)\x03\x01\x03F\t\x13\x03\x07\x03\x05\x0f\x06\x0f\x03\x07\x07\x07\x03\t\t\x04\x01\x03\x0b\x06\x03\x01\x05\x01\x00n\tK\x1d\x03\x0f\x0b\t\t\t\t\x13\x0f\x13\x11!\x11#K/ASci3\x13%)9\x1f\x15\x11\x17\x15\x11\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00func_v1\x00return_v1\x00compare_v1\x00call_v1\x00select_v1\x00custom_call_v1\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00jit(func)/jit(main)/pjit\x00third_party/py/jax/tests/export_back_compat_test.py\x00jit(func)/jit(main)/jit(_where)/broadcast_in_dim\x00jit(func)/jit(main)/jit(_where)/select_n\x00jit(func)/jit(main)/tridiagonal\x00jit(func)/jit(main)/eq\x00jit(func)/jit(main)/broadcast_in_dim\x00mhlo.layout_mode\x00default\x00jax.result_info\x00private\x00_where_1\x00_where\x00_where_0\x00[0]\x00[1]\x00[2]\x00[3]\x00main\x00public\x00\x00lapack_ssytrd\x00\x08\x89+\x05;\x01\x0bM[]oq\x03{\x03U\x03}\x03\x7f\x03\x81\x11\x83\x85\x87M\x89\x8b\x8d\x91\x039\x03-\x05;=\x03?\x03A\x03O\x03Q\x03I\x0bCsEOG\x03y\x0bCuEQG\x03S\x0bCwEIG', - xla_call_module_version=9, - nr_devices=1, -) # End paste - - -# Pasted from the test output (see export_back_compat_test_util.py module docstring) -data_2024_09_03["f64"] = dict( - testdata_version=1, - platform='cpu', - custom_call_targets=['lapack_dsytrd'], - serialized_date=datetime.date(2024, 9, 3), - inputs=(), - expected_outputs=(array([[[ 0.8251247184208595 , -2.6963562039892532 , - 0.8082445002373937 , -1.551980329390836 ], - [-2.629505060186711 , 4.427374205796291 , - -2.2111093161901074 , 7.552489598405787 ], - [ 0.2269453213819231 , 0.3650586474106988 , - -3.5933639667756205 , 4.828829679372501 ], - [-0.6415372293575187 , -0.2519326897319508 , - -1.7607827845801751 , -3.381311711243865 ]], - - [[-4.000421911405985 , 3.6303350337601055 , - 2.8066821235532355 , 1.099224389184342 ], - [-4.141622408467332 , -5.276404169116551 , - -0.8496056221591237 , -2.275319346221659 ], - [ 0.5828958067901202 , 0.9351254869793256 , - 2.7765603683442177 , -4.339686212557215 ], - [-0.6391146585297987 , 0.3129920702652711 , - -0.25441692469349864, -1.4155240723557498 ]]]), array([[ 0.8251247184208595, 4.427374205796291 , -3.5933639667756205, - -3.381311711243865 ], - [-4.000421911405985 , -5.276404169116551 , 2.7765603683442177, - -1.4155240723557498]]), array([[-2.629505060186711 , 0.3650586474106988 , -1.7607827845801751 ], - [-4.141622408467332 , 0.9351254869793256 , -0.25441692469349864]]), array([[1.3669846724688552, 1.8806358893589366, 0. ], - [1.1440109149169537, 1.8215532880266878, 0. 
]])), - mlir_module_text=r""" -#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":695:13) -#loc5 = loc("jit(func)/jit(main)/pjit"(#loc1)) -module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main() -> (tensor<2x4x4xf64> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<2x4xf64> {jax.result_info = "[1]", mhlo.layout_mode = "default"}, tensor<2x3xf64> {jax.result_info = "[2]", mhlo.layout_mode = "default"}, tensor<2x3xf64> {jax.result_info = "[3]", mhlo.layout_mode = "default"}) { - %cst = stablehlo.constant dense<[[[0.82512471842085955, -2.6963562039892532, 0.80824450023739369, -1.5519803293908361], [0.96498805326781766, -4.1313349231964409, -2.2111093161901074, 7.5524895984057867], [0.81575339483804743, 1.0647235400727899, -1.0064296232364345, 4.8288296793725012], [-2.3060011529502993, -2.9182106402942192, -1.7781896154088577, 2.5904630742096817]], [[-4.0004219114059847, 3.6303350337601055, 2.8066821235532355, 1.0992243891843421], [0.59643883228393779, -1.5243235004961249, -0.84960562215912372, -2.275319346221659], [2.7617960295487092, -0.57538970930521982, 0.12559406141906576, -4.3396862125572149], [-3.0281643919760217, 0.38177997229319849, 3.860398204232184, -2.5166384340510231]]]> : tensor<2x4x4xf64> loc(#loc) - %c = stablehlo.constant dense<4> : tensor loc(#loc2) - %c_0 = stablehlo.constant dense<1> : tensor loc(#loc2) - %c_1 = stablehlo.constant dense<4> : tensor loc(#loc2) - %c_2 = stablehlo.constant dense<2> : tensor loc(#loc2) - %c_3 = stablehlo.constant dense<128> : tensor loc(#loc2) - %0:6 = stablehlo.custom_call @lapack_dsytrd(%c, %c_0, %c_1, %c_2, %c_3, %cst) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[1, 2, 0]> : tensor<3xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 0]> : tensor<2xindex>, dense<[1, 0]> : tensor<2xindex>, dense<[1, 0]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>]} : (tensor, tensor, tensor, tensor, tensor, tensor<2x4x4xf64>) -> (tensor<2x4x4xf64>, tensor<2x4xf64>, tensor<2x3xf64>, tensor<2x3xf64>, tensor<2xi32>, tensor<128xf64>) loc(#loc2) - %c_4 = stablehlo.constant dense<0> : tensor loc(#loc) - %1 = stablehlo.broadcast_in_dim %c_4, dims = [] : (tensor) -> tensor<2xi32> loc(#loc3) - %2 = stablehlo.compare EQ, %0#4, %1, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc3) - %3 = stablehlo.broadcast_in_dim %2, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> loc(#loc4) - %cst_5 = stablehlo.constant dense<0x7FF8000000000000> : tensor loc(#loc) - %4 = call @_where(%3, %0#0, %cst_5) : (tensor<2x1x1xi1>, tensor<2x4x4xf64>, tensor) -> tensor<2x4x4xf64> loc(#loc5) - %c_6 = stablehlo.constant dense<0> : tensor loc(#loc) - %5 = stablehlo.broadcast_in_dim %c_6, dims = [] : (tensor) -> tensor<2xi32> loc(#loc3) - %6 = stablehlo.compare EQ, %0#4, %5, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc3) - %7 = stablehlo.broadcast_in_dim %6, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc4) - %cst_7 = stablehlo.constant dense<0x7FF8000000000000> : tensor loc(#loc) - %8 = call @_where_0(%7, %0#1, %cst_7) : (tensor<2x1xi1>, tensor<2x4xf64>, tensor) -> tensor<2x4xf64> loc(#loc5) - %c_8 = stablehlo.constant dense<0> : tensor loc(#loc) - %9 = stablehlo.broadcast_in_dim 
%c_8, dims = [] : (tensor) -> tensor<2xi32> loc(#loc3) - %10 = stablehlo.compare EQ, %0#4, %9, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc3) - %11 = stablehlo.broadcast_in_dim %10, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc4) - %cst_9 = stablehlo.constant dense<0x7FF8000000000000> : tensor loc(#loc) - %12 = call @_where_1(%11, %0#2, %cst_9) : (tensor<2x1xi1>, tensor<2x3xf64>, tensor) -> tensor<2x3xf64> loc(#loc5) - %c_10 = stablehlo.constant dense<0> : tensor loc(#loc) - %13 = stablehlo.broadcast_in_dim %c_10, dims = [] : (tensor) -> tensor<2xi32> loc(#loc3) - %14 = stablehlo.compare EQ, %0#4, %13, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc3) - %15 = stablehlo.broadcast_in_dim %14, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc4) - %cst_11 = stablehlo.constant dense<0x7FF8000000000000> : tensor loc(#loc) - %16 = call @_where_1(%15, %0#3, %cst_11) : (tensor<2x1xi1>, tensor<2x3xf64>, tensor) -> tensor<2x3xf64> loc(#loc5) - return %4, %8, %12, %16 : tensor<2x4x4xf64>, tensor<2x4xf64>, tensor<2x3xf64>, tensor<2x3xf64> loc(#loc) - } loc(#loc) - func.func private @_where(%arg0: tensor<2x1x1xi1> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg1: tensor<2x4x4xf64> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg2: tensor {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1))) -> (tensor<2x4x4xf64> {mhlo.layout_mode = "default"}) { - %0 = stablehlo.broadcast_in_dim %arg0, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x4x4xi1> loc(#loc6) - %1 = stablehlo.broadcast_in_dim %arg2, dims = [] : (tensor) -> tensor<2x4x4xf64> loc(#loc6) - %2 = stablehlo.select %0, %arg1, %1 : tensor<2x4x4xi1>, tensor<2x4x4xf64> loc(#loc7) - return %2 : tensor<2x4x4xf64> loc(#loc5) - } loc(#loc5) - func.func private @_where_0(%arg0: tensor<2x1xi1> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg1: tensor<2x4xf64> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg2: tensor {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1))) -> (tensor<2x4xf64> {mhlo.layout_mode = "default"}) { - %0 = stablehlo.broadcast_in_dim %arg0, dims = [0, 1] : (tensor<2x1xi1>) -> tensor<2x4xi1> loc(#loc6) - %1 = stablehlo.broadcast_in_dim %arg2, dims = [] : (tensor) -> tensor<2x4xf64> loc(#loc6) - %2 = stablehlo.select %0, %arg1, %1 : tensor<2x4xi1>, tensor<2x4xf64> loc(#loc7) - return %2 : tensor<2x4xf64> loc(#loc5) - } loc(#loc5) - func.func private @_where_1(%arg0: tensor<2x1xi1> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg1: tensor<2x3xf64> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1)), %arg2: tensor {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc1))) -> (tensor<2x3xf64> {mhlo.layout_mode = "default"}) { - %0 = stablehlo.broadcast_in_dim %arg0, dims = [0, 1] : (tensor<2x1xi1>) -> tensor<2x3xi1> loc(#loc6) - %1 = stablehlo.broadcast_in_dim %arg2, dims = [] : (tensor) -> tensor<2x3xf64> loc(#loc6) - %2 = stablehlo.select %0, %arg1, %1 : tensor<2x3xi1>, tensor<2x3xf64> loc(#loc7) - return %2 : tensor<2x3xf64> loc(#loc5) - } loc(#loc5) -} loc(#loc) -#loc = loc(unknown) -#loc2 = loc("jit(func)/jit(main)/tridiagonal"(#loc1)) -#loc3 = loc("jit(func)/jit(main)/eq"(#loc1)) -#loc4 = loc("jit(func)/jit(main)/broadcast_in_dim"(#loc1)) -#loc6 = loc("jit(func)/jit(main)/jit(_where)/broadcast_in_dim"(#loc1)) -#loc7 = loc("jit(func)/jit(main)/jit(_where)/select_n"(#loc1)) -""", - 
mlir_module_serialized=b'ML\xefR\rStableHLO_v1.3.0\x00\x01#\x05\x01\x05\x13\x01\x03\x0b\x03\x11\x0f\x13\x17\x1b\x1f#\'+\x03\xe9\x93A\x01-\x0f\x07\x0f\x17\x0f\x0f\x0f\x0f\x0f#\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x03g\x0f\x0b\x0b\x0f\x0b\x13\x1f\x0b\x0b//\x17\x0f\x0b\x0bO\x0b\x0b\x0bO\x1fo/\x0b\x1b\x1b\x0b\x1b\x0b\x1b\x0b\x1b\x0b\x0b\x0b\x0b\x0b\x0bo&\x08\x1f\x1f\x1f\x0b\x0b\x0b\x0b#\x0f\x17#\x01\x05\x0b\x0f\x03=\x0f\x17\x0f\x1b\x17\x17\x07\x07\x13\x07\x07\x13\x1b\x07\x1f\x1f\x1f\x1f\x17\x13\x13\x17\x1b\x13\x17\x13\x13\x13\x13\x13\x02r\x0b\x1d\x1f\x07\x1f\x1d)\x07\x17!\xde\n\x1b\x1d#\x07\x1d\'\x07\x1d+\x07\x1d%\x07\x11\x03\x05\x03\x07\x15\x17\x19\x11\x1b\x11\x05\x17\x11\x01\x00\x05\x19\x05\x1b\x05\x1d\x05\x1f\x05!\x05#\x05%\x05\'\x05)\x05+\x1f-\x01\x1d-\x1d/\x1f7\x01\x1d1\r\x03/1\x1f\x05\t\x00\x00\x00\x00\t\x07\x07\x01\x1f?\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\t\x11\x00\x00\x00\x00\x00\x00\xf8\x7f\x03\x07777\x03\x037\x1d3\x1d5\x1f;!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03\x01\x1d7\x1d9\x1f+!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f\x05\t\x04\x00\x00\x00\x1f91\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f=\x11\x00\x00\x00\x00\x00\x00\x00\x00#!\x03\t_cgk\r\x055a/1\x1d;\r\x055e/1\x1d=\r\x055i/1\x1d?\r\x055m/1\x1dA\x1dC\x1dE###%#\'\x1f31\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x1f\x0b\x02\x04A\xa4\x17\xf4kg\xea?\x1f\x01\x943#\x92\x05\xc0\x86 \xf6\x91#\xdd\xe9?\x9dMlS\xe9\xd4\xf8\xbf\x88\x1c:\xa0.\xe1\xee?8\xce\x7f\xa9|\x86\x10\xc0\xe8V\xc7\x14Z\xb0\x01\xc0\xd2!R\xd5\xbf5\x1e@\xbf\xc5\r\xdd\xa6\x1a\xea?\xbcM\xfe\x8c\x1b\t\xf1?\xdbj\xd8\xf2U\x1a\xf0\xbf\xado;\xba\xb8P\x13@\xbb\xad\x83\xbb\xb0r\x02\xc0\x1f9\xf7\xd1~X\x07\xc0)ID\xf4vs\xfc\xbfD\xcfI\xb4D\xb9\x04@\x16\xc3\xfe\x99n\x00\x10\xc0\x82.\x1c\x18\xed\n\r@\x8cn\xd7\xc1\x15t\x06@|2(Pl\x96\xf1?\x88*\xd7\xe3\x06\x16\xe3?F{\xf2\t\xa1c\xf8\xbf8z5!\xf8/\xeb\xbf4\xd3\x1f\xa1\xda3\x02\xc0)\x13I\x84(\x18\x06@\xbcw\xfd\xad\x97i\xe2\xbf\x1e\xf0.Yw\x13\xc0?dW\xd7\xb3\xd6[\x11\xc0\x04\x97\xb3@\xae9\x08\xc0\xbc\x17\xd1C\x15o\xd8?\x02\xb7%t\x18\xe2\x0e@\xac\xd8\xd0T\x13"\x04\xc0\x1f\x05\t\x01\x00\x00\x00\x1f\x05\t\x02\x00\x00\x00\x1f\x05\t\x80\x00\x00\x00\x0b\x05\x1dG\x1dI\x05\x01\x03\r33333W\x03\x03\x8f\x15\x03\x01\x15\x01\x03\rWKKKYY\x01\t\x01\x02\x02)\x01\x1f)\x05\t\r\x13)\x01\x13)\x07\t\x11\x11\x13)\x05\t\x11\x13)\x05\t\x05\x11\x01\x0b)\x03\t\x1f\x1d\x13)\x03\t\x11)\x07\t\x05\x05\x11\x1b\x11\x01\t\x0b\r\x07\x07\x11\x07\x1d\x0b\t\x03\x0b\x11\x07\x0f\r\t\x03\r\x11\x07\x0f\x07\t\x03\x07)\x05\t\r\x11)\x03\t\x17)\x03\x01\x17)\x05\t\x11\x11)\x07\t\x11\x11\x11)\x03\r\x17)\x03\x02\x04\x13)\x03\x01\x19)\x03\r\x19)\x03\t\x19)\x03\x05\x19)\x03\x05\x17\x04\x8a\x06\x05\x01Q\x03\x13\x01\x07\x04b\x06\x03\x01\x11\x07P\x03\x03\x07\x04\xf6\x03\x03I\x81\x05B\x03\x05\x03\x0b\x05B\x0b\x07\x03\x05\x05B\x0b\t\x03\x05\x05B\x0b\x07\x03\x05\x05B\x0b\x0b\x03\x05\x05B\x0b\r\x03\x05\x11F\x0b\x0f\r\x0b\r\x07\x07\x155\r\x03\x05\x07\t\x0b\x01\x05B\x03\x11\x03\x05\x03F\x05\x13\x03\x15\x03\x19\x0bF\x05\x15\x03\x1b\x05\x15\x1b\x03F\r\x17\x03\x1d\x03\x1d\x05B\x03\x19\x03\t\rF\x01\x1b\x03\x0b\x07\x1f\r!\x05B\x03\x11\x03\x05\x03F\x05\x13\x03\x15\x03%\x0bF\x05\x15\x03\x1b\x05\x15\'\x03F\r\x17\x03\x0f\x03)\x05B\x03\x19\x03\t\rF\x01\x1d\x03\r\x07+\x0f-\x05B\x03\x11\x03\x05\x03F\x05\x13\x03\x15\x031\x0bF\x05\x15\x03\x1b\x05\x153\x03F\r\x17\x03\x0f\x035\x05B\x03\x19\x03\t\rF\x01\x1f\x03\x07\x077\x119\x05B\x03\x11\x03\x05
\x03F\x05\x13\x03\x15\x03=\x0bF\x05\x15\x03\x1b\x05\x15?\x03F\r\x17\x03\x0f\x03A\x05B\x03\x19\x03\t\rF\x01\x1f\x03\x07\x07C\x13E\t\x04\x03\t#/;G\x07P\x01!\x07\x04S\x03\r\x13\x07;\x01\x17\x01\x13\x01\x00\x03F\t#\x031\x03\x01\x03F\t\x13\x03\x0b\x03\x05\x0f\x06\x0f\x03\x0b\x07\x07\x03\t\t\x04\x01\x03\x0b\x07P\x01%\x07\x04S\x03\r\x13\x07\x1f\x01\x1b\x01\x13\x01\x00\x03F\t\'\x03/\x03\x01\x03F\t\x13\x03\r\x03\x05\x0f\x06\x0f\x03\r\x07\x07\x03\t\t\x04\x01\x03\x0b\x07P\x01)\x07\x04S\x03\r\x13\x07\x1f\x01\x0f\x01\x13\x01\x00\x03F\t\'\x03)\x03\x01\x03F\t\x13\x03\x07\x03\x05\x0f\x06\x0f\x03\x07\x07\x07\x03\t\t\x04\x01\x03\x0b\x06\x03\x01\x05\x01\x00n\tK\x1d\x03\x0f\x0b\t\t\t\t\x13\x0f\x13\x11!\x11#K/ASci3\x13%)9\x1f\x15\x11\x17\x15\x11\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00func_v1\x00return_v1\x00compare_v1\x00call_v1\x00select_v1\x00custom_call_v1\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00jit(func)/jit(main)/pjit\x00third_party/py/jax/tests/export_back_compat_test.py\x00jit(func)/jit(main)/jit(_where)/broadcast_in_dim\x00jit(func)/jit(main)/jit(_where)/select_n\x00jit(func)/jit(main)/tridiagonal\x00jit(func)/jit(main)/eq\x00jit(func)/jit(main)/broadcast_in_dim\x00mhlo.layout_mode\x00default\x00jax.result_info\x00private\x00_where_1\x00_where\x00_where_0\x00[0]\x00[1]\x00[2]\x00[3]\x00main\x00public\x00\x00lapack_dsytrd\x00\x08\x89+\x05;\x01\x0bM[]oq\x03{\x03U\x03}\x03\x7f\x03\x81\x11\x83\x85\x87M\x89\x8b\x8d\x91\x039\x03-\x05;=\x03?\x03A\x03O\x03Q\x03I\x0bCsEOG\x03y\x0bCuEQG\x03S\x0bCwEIG', - xla_call_module_version=9, - nr_devices=1, -) # End paste - data_2024_12_01 = {} - # Pasted from the test output (see export_back_compat_test_util.py module docstring) data_2024_12_01["c128"] = dict( testdata_version=1, diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/cuda_eigh_cusolver_syev.py b/jax/_src/internal_test_util/export_back_compat_test_data/cuda_eigh_cusolver_syev.py index 56479e82f9d9..4b7c37723cc1 100644 --- a/jax/_src/internal_test_util/export_back_compat_test_data/cuda_eigh_cusolver_syev.py +++ b/jax/_src/internal_test_util/export_back_compat_test_data/cuda_eigh_cusolver_syev.py @@ -17,1399 +17,6 @@ import datetime from numpy import array, float32, complex64 -data_2023_03_17=dict( - # Pasted from the test output (see back_compat_test.py module docstring) - f32_syevj=dict( - testdata_version=1, - platform='cuda', - custom_call_targets=['cusolver_syevj'], - serialized_date=datetime.date(2023, 3, 17), - inputs=(), - expected_outputs=(array([[ 6.18577063e-01, -8.00570633e-05, -1.96905047e-01, - -8.95753130e-02, 7.24549413e-01, -1.07546024e-01, - -4.77200520e-04, 1.84469908e-01], - [ 4.70708847e-01, 3.31519186e-05, 2.80930042e-01, - -5.84393919e-01, -4.93098050e-01, -2.50211239e-01, - -1.14346610e-03, 2.28566617e-01], - [ 3.22840720e-01, -5.11042356e-01, -3.03526163e-01, - 2.48800799e-01, -3.14544559e-01, 5.54342926e-01, - 1.10838346e-06, 2.72663534e-01], - [ 1.74972475e-01, 4.18093473e-01, -2.66933769e-01, - 5.78716159e-01, -2.97307134e-01, -4.46864694e-01, - 1.09066934e-06, 3.16760242e-01], - [ 2.71042082e-02, 4.29418474e-01, 4.71952170e-01, - 1.10573582e-01, 9.57800150e-02, 4.65731144e-01, - -4.72866714e-01, 3.60856950e-01], - [-1.20763958e-01, -3.84347916e-01, 5.79687178e-01, - 2.87678182e-01, 1.63329691e-01, -2.02215970e-01, - 4.32829827e-01, 4.04953718e-01], - [-2.68632114e-01, 3.63640338e-01, -2.97110289e-01, - -3.32554609e-01, 3.46945561e-02, 2.77071655e-01, - 5.63131213e-01, 
4.49050426e-01], - [-4.16500419e-01, -3.15715015e-01, -2.68094122e-01, - -2.19244853e-01, 8.65960941e-02, -2.90307850e-01, - -5.21475971e-01, 4.93147314e-01]], dtype=float32), array([-2.4598812e+01, -2.4345848e-06, -1.2664314e-06, -8.6959182e-07, - -8.2917722e-07, 1.6633214e-06, 2.0499781e-06, 2.7659885e+02], - dtype=float32)), - mlir_module_text=""" -module @jit__lambda_ { - func.func public @main() -> (tensor<8x8xf32> {jax.result_info = "[0]"}, tensor<8xf32> {jax.result_info = "[1]"}) { - %0 = stablehlo.iota dim = 0 : tensor<64xf32> - %1 = stablehlo.reshape %0 : (tensor<64xf32>) -> tensor<8x8xf32> - %2 = stablehlo.transpose %1, dims = [1, 0] : (tensor<8x8xf32>) -> tensor<8x8xf32> - %3 = stablehlo.add %1, %2 : tensor<8x8xf32> - %4 = stablehlo.constant dense<2.000000e+00> : tensor - %5 = stablehlo.broadcast_in_dim %4, dims = [] : (tensor) -> tensor<8x8xf32> - %6 = stablehlo.divide %3, %5 : tensor<8x8xf32> - %7 = call @tril(%6) : (tensor<8x8xf32>) -> tensor<8x8xf32> - %8 = stablehlo.custom_call @cusolver_syevj(%7) {api_version = 2 : i32, backend_config = "\00\00\00\00\00\00\00\00\01\00\00\00\08\00\00\00M\08\00\00", operand_layouts = [dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>, dense<0> : tensor<1xindex>]} : (tensor<8x8xf32>) -> tuple, tensor<8xf32>, tensor, tensor<2125xf32>> - %9 = stablehlo.get_tuple_element %8[0] : (tuple, tensor<8xf32>, tensor, tensor<2125xf32>>) -> tensor<8x8xf32> - %10 = stablehlo.get_tuple_element %8[1] : (tuple, tensor<8xf32>, tensor, tensor<2125xf32>>) -> tensor<8xf32> - %11 = stablehlo.get_tuple_element %8[2] : (tuple, tensor<8xf32>, tensor, tensor<2125xf32>>) -> tensor - %12 = stablehlo.get_tuple_element %8[3] : (tuple, tensor<8xf32>, tensor, tensor<2125xf32>>) -> tensor<2125xf32> - %13 = stablehlo.constant dense<0> : tensor - %14 = stablehlo.broadcast_in_dim %13, dims = [] : (tensor) -> tensor - %15 = stablehlo.compare EQ, %11, %14, SIGNED : (tensor, tensor) -> tensor - %16 = stablehlo.broadcast_in_dim %15, dims = [] : (tensor) -> tensor<1x1xi1> - %17 = stablehlo.constant dense<0x7FC00000> : tensor - %18 = stablehlo.broadcast_in_dim %17, dims = [] : (tensor) -> tensor<8x8xf32> - %19 = stablehlo.broadcast_in_dim %16, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<8x8xi1> - %20 = stablehlo.select %19, %9, %18 : tensor<8x8xi1>, tensor<8x8xf32> - %21 = stablehlo.broadcast_in_dim %15, dims = [] : (tensor) -> tensor<1xi1> - %22 = stablehlo.constant dense<0x7FC00000> : tensor - %23 = stablehlo.broadcast_in_dim %22, dims = [] : (tensor) -> tensor<8xf32> - %24 = stablehlo.broadcast_in_dim %21, dims = [0] : (tensor<1xi1>) -> tensor<8xi1> - %25 = stablehlo.select %24, %10, %23 : tensor<8xi1>, tensor<8xf32> - return %20, %25 : tensor<8x8xf32>, tensor<8xf32> - } - func.func private @tril(%arg0: tensor<8x8xf32>) -> tensor<8x8xf32> { - %0 = stablehlo.iota dim = 0 : tensor<8x8xi32> - %1 = stablehlo.constant dense<0> : tensor - %2 = stablehlo.broadcast_in_dim %1, dims = [] : (tensor) -> tensor<8x8xi32> - %3 = stablehlo.add %0, %2 : tensor<8x8xi32> - %4 = stablehlo.iota dim = 1 : tensor<8x8xi32> - %5 = stablehlo.compare GE, %3, %4, SIGNED : (tensor<8x8xi32>, tensor<8x8xi32>) -> tensor<8x8xi1> - %6 = stablehlo.constant dense<0.000000e+00> : tensor - %7 = stablehlo.broadcast_in_dim %6, dims = [] : (tensor) -> tensor<8x8xf32> - %8 = stablehlo.select %5, %arg0, %7 : tensor<8x8xi1>, tensor<8x8xf32> - return %8 : tensor<8x8xf32> - } -} 
-""", - mlir_module_serialized=b"ML\xefR\x03MLIRxxx-trunk\x00\x01-\x05\x01\x05\x01\x03\x05\x03\x1d\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x1d\x1f!\x03^\x02\xeb5\x01\x95\x0f\x17\x13\x07\x0f\x0b\x0b\x0b\x0b\x0b\x17\x0b\x0b\x0b\x0b\x13\x0b\x13\x0f\x0b\x0b\x17\x0f\x13\x13\x0b33\x0b\x0f\x0b\x0b\x13\x0f\x0b\x1b\x0f\x0b\x13\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x13\x0b\x0f\x0b\x0f\x0b\x13\x0b\x13\x0bK\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x13\x13\x13\x1b\x13\x13\x03W\x0b\x0b\x0f\x0b\x0bO/\x0b\x13\x13\x0b\x13\x0b\x0b\x0b\x0b\x0b\x0b\x0f\x1f\x0f\x0f\x0b\x1fO\x1f\x0b\x0b\x0b\x0b\x0f\x0f\x17\x1b\x0f\x0f\x0f\x0f\x0f\x0b\x1fO/\x035\x17\x0f\x07\x0f\x07\x13\x07\x07\x17\x07\x17\x13\x17\x17\x17\x13\x17\x1b\x13\x13\x13\x0f\x17\x13\x13\x13\x02r\x08\x1d\x85\x03\x17\x116\x04\x01\x03\x03\x13\xbd\x1f\x1d9\x03\x05#\x05%\x05'\x05)\x05+\x17\x112\x04\x01\x05-\x05/\x051\x053\x03\x03!\xb9\x055\x03\x03\x0b\xbb\x1d?\x03\x057\x059\x17\x11*\x04\x01\x1dm\x15\x03\x03\x0b\xe5\x03\x03\x0f3\x05;\x03\x0b\x17\x95\x19\xa3\x1b\xa5\x0f\xaf\x1d\xb1\x03\x0b\x17\x99\x19\xb5\x1b\x99\x0f\x9b\x1d\xb7\x05=\x1d=\x03\x05?\x05A\x03\x03!\xbf\x1dE\x03\x05C\x03\x05'\x9d)\xc1\x1dK\x03\x05E\x03\x03\x0b\xc3\x1dQ\x03\x05G\x1dU\x03\x05I\x1dY+\x05K\x1d]+\x05M\x03\x03a\xc5\x05O\x1de\x15\x05Q\x1di\x15\x05S\x03\x03\x0b\xc7\x05U\x03\x03q\x9b\x05W\x03\x11u\xc9w\xcby\xcd{\x95}\xcf\x7f\xd1\x81\xd3\x83\xd7\x05Y\x05[\x05]\x05_\x05a\x05c\x05e\x05g\x05i\x03\x03\r\xdb\x03\x03\r\xdd\x03\x03\r\xdf\x03\x03\r\xe1\x03\x05'\x9d)\xe3\x03\x03\x13\xe7\x03\x03\x13\xe9\x03\x01\x1dk\x03\x03\xb3\x1dm\t\x07\x1f%!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f'\x11\x00\x00\x00\x00\x00\x00\x00\x00#\x1b\x03\x05\xa7\xab\r\x03\x97\xa9\x1do\r\x03\x97\xad\x1dq\x1ds\x1du\r\x01#\x1d\x1dw\x13\r\x01\x1f\x07\t\x00\x00\x00\x00\x1f\x1f\x01\x13\r\x05\x07\x05\x1f\x03\t\x00\x00\x00\x00\x1f\x17!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x03\t\x00\x00\x00@\x0b\x05\x1dy\x1d{\x05\x01\x03\x03\x9f\x03\x03\xd5\x15\x03\x01\x01\x01\x03\t\x9f\xa1\xd9\xa1\x1f)\x01\x13\x05\x01\x13\x05\x05\x13\x05\t\x13\x05\r\x07\x01\x1f\x03\t\x00\x00\xc0\x7f\x1f\x17!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f3\x11\x00\x00\x00\x00\x00\x00\x00\x00)\x05!!\t)\x01\t\x1b)\x01\x05\t)\x03!\t\x1d\x01)\x05!!\x05\x13)\x05!!\x0f)\x03\t\r)\x03jB\t\x11\x01\x05\x01\x0b\x11\x03\x01\x03\x01)\x03\x01\r)\x03\x02\x02\t/\t\x01\x0b\x07\x19)\x03\t\x13)\x03\x05\x13)\x03\x01\x13)\x01\x0f)\x05\x05\x05\x0f)\x03\x05\x0f)\x03!\x0f)\x03\x05\r\x04\xc6\x04\x05\x01\x11\x071\x07\x03\x01\t\r\x11\x075\x05\x035m\t\x03W\x1f\x03!\x15\x06[\x03\x01\x03\x01\x17\x07c_\x03\x01\x03\x03\x0f\x06g\x03\x01\x05\x03\x05\x05\x03\x07k\x03\x03\x03\x07-\x05\x03\x01\x03\t\x19\x06-\x03\x01\x05\x07\x0b\x1b\x07\to\x03\x01\x03\r\x1d\x07\x01s\x03#\x03\x0f\x07\x07\x01\x87\x03\x01\x03\x11\x07\x07\x01\x89\x03\x0b\x03\x11\x07\x07\x01\x8b\x03\x07\x03\x11\x07\x07\x01\x8d\x03\x19\x03\x11\x05\x03\x01#\x03\x07\x03\x07\x01\x05\x03\x07\x03\x1b\x11\x07\x01\x8f\x03+\x05\x17\x1d\x03\x07\x01\x05\x03-\x03\x1f\x05\x03\x01/\x03\x03\x03\x07\x01\x05\x03\x01\x03#\x03\x07\x01\x91\x03\x15\x03!\x0b\x06\x01\x03\x01\x07'\x13%\x03\x07\x01\x05\x03/\x03\x1f\x05\x03\x01/\x03\x03\x03\x07\x01\x05\x03\x0b\x03-\x03\x07\x01\x93\x031\x03+\x0b\x06\x01\x03\x0b\x071\x15/\x13\x04\x07\x05)3\r\x11\t7\x05\x03\x15+\x03\x01\x07\t\x03;\x1f\x03\x11\x05\x03\t#\x03\x07\x03\x07%\x05\x03\x11\x03\x05\x0f\x06%\x03\x11\x05\x03\x07\t\x03CA\x03\x11\x11\x07IG\x03\x15\x05\t\x0b\x05\x03\tM\x03\x03\x03\x07O\x05\x03\x01\x03\x0f\x0b\x06S\x03\x01\x07\r\x01\x11\x13\x04\t\x03\x
13\x06\x03\x01\x05\x01\x00\x06\x1a}\x1f+\x11\x0f\x0b\t\t\x0b!\x7f\x1f/!!)#\x1f\x19\x0f99m\x19\x85\x89W\xb3K\x9bM\x9b\x96\x04\x1b+\x1b\x1f\x1f\x15\x1d\x15+\x83\x13\r\r\x1f\x11\x15\x1b\x17\x15\x17\x0f\x11\x15\x11+\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00get_tuple_element_v1\x00iota_v1\x00select_v1\x00func_v1\x00add_v1\x00compare_v1\x00return_v1\x00reshape_v1\x00transpose_v1\x00divide_v1\x00call_v1\x00custom_call_v1\x00value\x00index\x00sym_name\x00third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py\x00broadcast_dimensions\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00compare_type\x00comparison_direction\x00jit__lambda_\x00jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) resource_env=None donated_invars=(False,) name=tril in_positional_semantics=(<_PositionalSemantics.GLOBAL: 1>,) out_positional_semantics=_PositionalSemantics.GLOBAL keep_unused=False inline=False]\x00jit()/jit(main)/jit(tril)/iota[dtype=int32 shape=(8, 8) dimension=0]\x00jit()/jit(main)/jit(tril)/add\x00jit()/jit(main)/jit(tril)/iota[dtype=int32 shape=(8, 8) dimension=1]\x00jit()/jit(main)/jit(tril)/ge\x00jit()/jit(main)/jit(tril)/broadcast_in_dim[shape=(8, 8) broadcast_dimensions=()]\x00jit()/jit(main)/jit(tril)/select_n\x00jit()/jit(main)/iota[dtype=float32 shape=(64,) dimension=0]\x00jit()/jit(main)/reshape[new_sizes=(8, 8) dimensions=None]\x00permutation\x00jit()/jit(main)/transpose[permutation=(1, 0)]\x00jit()/jit(main)/add\x00jit()/jit(main)/div\x00callee\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00jit()/jit(main)/eigh[lower=True sort_eigenvalues=True]\x00jax.result_info\x00tril\x00[0]\x00[1]\x00main\x00public\x00private\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x08\x00\x00\x00M\x08\x00\x00\x00cusolver_syevj\x00", - xla_call_module_version=4, - ), # End paste - - # Pasted from the test output (see back_compat_test.py module docstring) - f32_syevd=dict( - testdata_version=1, - platform='cuda', - custom_call_targets=['cusolver_syevd'], - serialized_date=datetime.date(2023, 3, 17), - inputs=(), - expected_outputs=(array([[ 3.14863890e-01, 0.00000000e+00, 0.00000000e+00, - 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, - 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, - 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, - 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, - 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, - 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, - -4.91220355e-01, 0.00000000e+00, 0.00000000e+00, - 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, - 0.00000000e+00, 0.00000000e+00, 8.05416584e-01, - 0.00000000e+00, -1.77893345e-03, -2.64500137e-02, - 1.46598322e-04, -5.19353598e-02, -8.64148438e-02], - [ 2.99391806e-01, 2.77544819e-02, 6.73292065e-03, - -6.83086272e-03, -3.54272849e-03, -1.21014733e-02, - -1.32716037e-02, -1.15843862e-03, -8.83520208e-03, - -6.63395738e-03, 1.60171092e-03, -1.01765711e-03, - 1.19860061e-02, -1.33239310e-02, 1.76237477e-03, - 1.27085261e-02, 3.38556734e-03, -8.78101215e-03, - 1.58616400e-03, -7.37631368e-03, 3.81911686e-03, - -5.18379211e-02, -7.22059654e-03, 1.85085051e-02, - 2.94725411e-03, 4.74284729e-03, -1.33781182e-02, - -3.61499190e-03, -5.49228955e-03, -1.05845921e-01, - 1.01772454e-02, 4.47412670e-01, 1.95654288e-01, - 3.94686669e-01, 7.00925171e-01, -9.06614065e-02], - [ 2.83920437e-01, 1.69272088e-02, 
6.64264262e-02, - -1.18565477e-01, 3.54601629e-02, -1.52457461e-01, - 6.84847543e-03, 1.90414500e-03, -2.76310533e-01, - 3.76881436e-02, 1.22269124e-01, -1.01556584e-01, - -1.90264836e-01, -1.16590485e-01, 6.09031200e-01, - -9.43092555e-02, -3.74726858e-03, -2.33182713e-01, - 1.95203945e-01, -1.20613754e-01, 3.94887812e-02, - -5.88066364e-03, 1.19152360e-01, -1.46030456e-01, - -4.74781469e-02, 2.67041594e-01, -1.22617789e-01, - 5.77996820e-02, 2.58437768e-02, -1.34434626e-01, - -3.28330845e-02, -9.32494774e-02, 1.14714004e-01, - 1.21207587e-01, -2.04871535e-01, -9.49072391e-02], - [ 2.68448830e-01, 2.17946004e-02, -1.94895901e-02, - 3.40374447e-02, 6.18659109e-02, 1.72068894e-01, - -8.02555401e-03, 9.68076065e-02, 4.98391055e-02, - 5.55528253e-02, -3.23998183e-02, -2.63249427e-01, - -4.35045222e-03, 5.20016700e-02, -5.92328422e-02, - 4.31317724e-02, -2.00986061e-02, -2.69871447e-02, - 1.54309347e-01, 1.74670279e-01, -4.97168908e-03, - -4.15510803e-01, -4.33471389e-02, -3.71299796e-02, - 5.26434295e-02, -1.18867345e-01, -2.42547281e-02, - -3.90263759e-02, -2.58720964e-01, -3.92957211e-01, - -1.28192365e-01, 2.77028710e-01, -4.02157485e-01, - -1.77024350e-01, -1.76668167e-01, -9.91534367e-02], - [ 2.52977222e-01, 3.48518007e-02, 7.02044442e-02, - 1.42712081e-02, 4.50692251e-02, 7.16193160e-03, - 1.19931757e-01, 2.32399218e-02, -6.05047755e-02, - 1.06077030e-01, 1.03731848e-01, -1.13200452e-02, - 5.94755262e-03, -2.32813850e-01, 8.72232541e-02, - 8.17264095e-02, 3.30835059e-02, 4.88227099e-01, - 6.14454560e-02, 1.43805355e-01, -7.40422234e-02, - 2.25823849e-01, -3.86487693e-01, 1.30468249e-01, - 3.16427708e-01, -1.19733319e-01, -4.18486483e-02, - -2.74667948e-01, -2.16731444e-01, 2.60375626e-02, - 5.77645637e-02, -7.56322592e-02, 2.28632554e-01, - 2.37157010e-02, -1.40153974e-01, -1.03399649e-01], - [ 2.37505659e-01, 7.01064467e-02, -3.83728333e-02, - 5.06979637e-02, 1.83892641e-02, 4.02548499e-02, - -3.88330072e-02, 3.13181393e-02, -5.75652197e-02, - 7.04995319e-02, -6.92743529e-03, -9.82947052e-02, - -4.91717793e-02, 4.06844541e-02, -1.53035461e-03, - 4.68783826e-02, 5.36918640e-03, -1.67432979e-01, - 1.03467651e-01, 3.48554403e-02, 3.20128165e-02, - 4.70223904e-01, 9.19904634e-02, 6.90946281e-02, - -6.94891065e-02, 3.92344594e-02, -6.30731881e-02, - 2.22810470e-02, -3.87494615e-03, 1.96694940e-01, - -1.92701817e-02, 2.01028123e-01, 1.89283062e-02, - -6.97807550e-01, 2.03354478e-01, -1.07645869e-01], - [ 2.22034067e-01, -1.60748392e-01, 2.42968962e-01, - -3.35482806e-01, -3.41870189e-02, 1.28819138e-01, - 1.24212839e-01, -3.87125909e-02, -5.60933471e-01, - 7.95257688e-02, -3.60307507e-02, 3.67332071e-01, - -5.87672107e-02, 7.33083040e-02, -3.94398779e-01, - -7.60597512e-02, 1.71925854e-02, 1.17799109e-02, - -2.65986789e-02, 1.98394638e-02, -1.35528380e-02, - -3.39059532e-02, 9.92002785e-02, -7.92167559e-02, - 9.19176906e-04, -4.89958897e-02, 5.72972372e-02, - 1.21006947e-02, 4.03640568e-02, -1.18844979e-01, - -2.80744191e-02, -1.74218431e-01, -4.31395955e-02, - -6.09265082e-02, 3.76862884e-02, -1.11892074e-01], - [ 2.06562474e-01, 1.73960440e-02, -2.63249487e-01, - 1.38902217e-01, -4.79032584e-02, -2.24852517e-01, - 4.69521992e-02, -3.35566737e-02, 1.37603536e-01, - -5.11448458e-02, 8.18398222e-02, 1.07205749e-01, - -1.46739393e-01, -1.30916521e-01, -2.28276670e-01, - -7.91462511e-02, 6.24803789e-02, 4.59876209e-02, - 8.15130547e-02, 1.46908918e-02, -2.61019613e-03, - 1.13239333e-01, 2.98404664e-01, -1.80148214e-01, - 1.44556239e-01, -3.98542970e-01, -4.15323582e-03, - 4.42554235e-01, 
4.46505845e-02, -3.50878686e-02, - -1.36736231e-02, 1.28197059e-01, 1.92225441e-01, - 9.25138816e-02, -2.71676213e-01, -1.16138257e-01], - [ 1.91090912e-01, -3.68523598e-02, -6.60930753e-01, - 3.02158773e-01, 1.77503861e-02, 1.00428194e-01, - -1.10393446e-02, 9.11340117e-03, -7.01573640e-02, - -3.42316413e-03, -7.93189174e-05, 2.59178817e-01, - 1.22925844e-02, 6.14976510e-02, -1.56667307e-01, - -5.03374226e-02, -4.95696850e-02, -1.59401018e-02, - -4.26767953e-02, -5.12050986e-02, -6.04047906e-03, - 5.44762500e-02, -1.07276395e-01, -1.12534806e-01, - -1.20743208e-01, 3.80993217e-01, -2.20808387e-02, - -2.89817184e-01, 3.23761255e-02, -6.17432930e-02, - -3.90686616e-02, -5.96804358e-02, -4.96021062e-02, - 8.57739672e-02, -8.64073634e-02, -1.20384485e-01], - [ 1.75619304e-01, 1.71932317e-02, 4.29833472e-01, - 8.81271958e-02, -3.94745134e-02, -5.61874844e-02, - 7.05854744e-02, 7.86138419e-03, 4.67237175e-01, - -1.88353360e-02, 6.92435876e-02, -3.38627174e-02, - -8.19625556e-02, -4.84902970e-02, -2.62022078e-01, - -1.48765266e-01, 7.19114691e-02, -1.21600203e-01, - 1.18209779e-01, 2.58331411e-02, 4.69931588e-02, - 9.96347591e-02, 2.32059956e-01, -1.78489253e-01, - 1.77511200e-03, 1.59484446e-01, 3.28991674e-02, - -4.70239580e-01, 1.65105104e-01, -2.61324756e-02, - -1.49319443e-04, -8.15570727e-02, 7.44131976e-05, - 8.14437792e-02, -7.25714415e-02, -1.24630690e-01], - [ 1.60147712e-01, -1.10780589e-01, 2.73144871e-01, - 1.10703602e-01, 2.37337053e-02, 4.52041216e-02, - 1.52682560e-02, -3.83009948e-02, 2.30164632e-01, - 2.54375394e-02, -3.03758867e-02, 8.13979190e-03, - 2.33282149e-02, 3.12441736e-02, -1.84844747e-01, - 2.14728359e-02, -5.53616770e-02, -2.22909674e-02, - -9.31906551e-02, -1.01961263e-01, -3.32283713e-02, - 8.18983093e-02, -3.90430242e-01, 1.43959653e-02, - -1.31596243e-02, 4.55893874e-01, -4.22518775e-02, - 5.82709551e-01, -1.36653170e-01, -3.07889320e-02, - -4.67781313e-02, -6.33331314e-02, -5.06754033e-03, - 3.76623571e-02, -6.18892610e-02, -1.28876895e-01], - [ 1.44676119e-01, -2.91557442e-02, 2.55934417e-01, - 5.66692650e-01, 3.84408869e-02, 1.04354315e-01, - -1.37322113e-01, 7.15484237e-03, -1.95520781e-02, - -2.59401686e-02, -9.82144028e-02, 2.44248882e-01, - 1.52861271e-02, 1.99174404e-01, 2.76121795e-01, - 8.94557908e-02, -1.24152258e-01, 6.37411512e-03, - -1.13803938e-01, -3.23315486e-02, -3.17632034e-02, - 2.70075332e-02, 2.75091957e-02, -4.90174480e-02, - -2.08239228e-01, -3.95830333e-01, -5.95310889e-02, - -4.46558185e-03, -7.16161057e-02, -4.99811508e-02, - -1.02262713e-01, -2.79212356e-01, -5.11405505e-02, - 2.62467805e-02, 1.03744328e-01, -1.33123115e-01], - [ 1.29204527e-01, -1.63312718e-01, -1.99243486e-01, - -2.34051406e-01, 3.55675933e-03, 1.56449080e-02, - 9.30304453e-02, -7.26388171e-02, 1.25461653e-01, - 1.20737530e-01, 4.42517921e-02, -4.18601990e-01, - -1.94645032e-01, 1.02710314e-01, -7.12260604e-02, - -6.79927021e-02, -3.08946688e-02, -8.88019353e-02, - -4.35314551e-02, -2.15784147e-01, -1.86102502e-02, - 5.49090989e-02, -3.75167191e-01, -8.20007622e-02, - -2.06737250e-01, -3.52603942e-01, -3.86392660e-02, - -2.84039471e-02, 2.83454835e-01, -2.61564963e-02, - 1.20758023e-02, -2.92337686e-01, -5.17344326e-02, - 3.77417319e-02, 1.23368390e-01, -1.37369320e-01], - [ 1.13732927e-01, -6.13378249e-02, -1.77854180e-01, - -4.99198377e-01, 2.01901477e-02, 1.41450047e-01, - -3.23677920e-02, 9.39797983e-03, 5.04098058e-01, - 1.23931216e-02, -8.47154856e-02, 3.81212860e-01, - 1.21610202e-01, 4.87964153e-02, 2.52459884e-01, - 1.51112108e-02, -4.74717468e-02, 
-1.84605867e-02, - -7.36073852e-02, 3.58235948e-02, -7.69592915e-03, - -7.00120777e-02, 1.28127992e-01, 4.49521616e-02, - -7.93955289e-04, -3.76549661e-02, 1.04670962e-02, - 7.88062997e-03, -2.23614484e-01, -9.32817012e-02, - -4.67354655e-02, -1.74636483e-01, 1.47633761e-01, - -1.42957285e-01, 7.11189136e-02, -1.41615525e-01], - [ 9.82613564e-02, -1.55768439e-01, -1.11842593e-02, - 6.37831986e-02, 5.79317398e-02, 3.34746271e-01, - 3.84975046e-01, -2.11655404e-02, -4.85437140e-02, - -4.50517267e-01, -3.28294598e-02, -2.49714255e-01, - 3.28522325e-01, -1.25372112e-01, 2.82705110e-02, - 1.42169207e-01, -8.04641694e-02, 6.62415996e-02, - -9.59652960e-02, -5.61193414e-02, -4.80792150e-02, - -4.04721648e-02, 2.45707080e-01, 2.35501617e-01, - -4.14447524e-02, 4.34486791e-02, -4.62412462e-02, - 4.26126681e-02, 2.55748153e-01, -7.83308148e-02, - 2.59090564e-03, -3.38329338e-02, 1.78729519e-01, - -3.09782606e-02, -8.34960043e-02, -1.45861730e-01], - [ 8.27897936e-02, -5.39819747e-02, 5.41151650e-02, - -2.87518036e-02, 1.98750496e-02, -1.58728033e-01, - -4.75713938e-01, 1.16178179e-02, -2.98879808e-03, - 2.26475924e-01, 2.46154964e-02, 1.24507852e-01, - 4.07826692e-01, -2.43859500e-01, 1.46053182e-02, - 8.78053382e-02, -7.19747171e-02, -4.02797535e-02, - -8.92022029e-02, -4.73439731e-02, 2.02829354e-02, - -9.01956186e-02, -1.16379023e-01, 1.02566876e-01, - 1.27621949e-01, -3.85584086e-02, -1.85301397e-02, - -1.46384817e-02, 5.42852879e-01, -1.11336805e-01, - -4.69652563e-02, 1.10105053e-01, -3.25540863e-02, - -9.18325037e-02, -1.09285243e-01, -1.50107935e-01], - [ 6.73181787e-02, -1.69579491e-01, -5.90509735e-02, - -8.87142718e-02, -4.61161807e-02, -1.32888526e-01, - -4.28256035e-01, -4.96512838e-02, -1.00748278e-01, - -1.56540096e-01, -1.33985683e-01, -3.31550747e-01, - 3.25447232e-01, 2.73245610e-02, -6.19893037e-02, - -1.48184791e-01, 1.88705355e-01, 1.62340149e-01, - 1.02853999e-01, 3.19841057e-01, -6.06105961e-02, - 1.69779122e-01, 1.54020518e-01, -8.75391066e-02, - -2.06520095e-01, 6.03866279e-02, 1.08508043e-01, - 4.56446186e-02, -2.30992153e-01, 6.16142601e-02, - 5.93037927e-04, -2.22505212e-01, -4.13618460e-02, - 1.47342280e-01, 4.37493399e-02, -1.54354155e-01], - [ 5.18466011e-02, 1.40082181e-01, -2.43853368e-02, - -9.01594944e-03, -2.02037729e-02, -2.15594158e-01, - -1.49669036e-01, -2.02583615e-02, 4.76960652e-03, - -4.28980350e-01, -2.16286242e-01, 2.93388069e-02, - -2.61512101e-01, 4.32281435e-01, 5.15976362e-02, - -2.38068718e-02, 1.35174215e-01, 1.65118262e-01, - 1.18229888e-01, -4.75422740e-02, -1.69874616e-02, - -9.87956077e-02, -6.16191179e-02, 1.92472130e-01, - 4.03664082e-01, 9.86855701e-02, 2.18016505e-02, - 9.58452746e-03, 2.42479756e-01, -9.45590809e-02, - 6.06411323e-02, -1.15035795e-01, -5.60823381e-02, - -1.10115618e-01, 7.84227401e-02, -1.58600360e-01], - [ 3.63750085e-02, 2.90070504e-01, 2.58655623e-02, - -4.51171659e-02, -9.76288766e-02, -7.32196262e-03, - 2.62665208e-02, -1.30719528e-01, -3.34864855e-02, - 1.83281839e-01, -2.03847468e-01, -7.86208585e-02, - 2.39961028e-01, 9.32282284e-02, -1.40201841e-02, - -1.65743440e-01, 2.50046160e-02, 1.87149823e-01, - -1.68221984e-02, -6.99453712e-01, 2.46135090e-02, - 9.76792276e-02, 1.59403309e-01, 1.05807781e-01, - -1.64897703e-02, -3.37719321e-02, 9.97098759e-02, - -5.71760125e-02, -2.09543109e-01, 1.61970984e-02, - 4.49959114e-02, 1.13044158e-01, -1.33089647e-01, - 6.79383874e-02, -1.17107280e-01, -1.62846550e-01], - [ 2.09034402e-02, 9.87452939e-02, 3.10002435e-02, - -3.82550769e-02, 6.49476936e-03, -1.86508909e-01, - 
-1.58566430e-01, 1.52609888e-02, 2.44785240e-03, - -1.72963649e-01, 2.82357018e-02, 6.35804012e-02, - -4.01134878e-01, -3.48292142e-01, -9.30772051e-02, - 2.69406252e-02, -1.48355186e-01, 6.67649359e-02, - -1.52495161e-01, -4.16254858e-03, -7.79623985e-02, - -8.69922712e-02, 1.67651065e-02, 4.43452805e-01, - -4.69122916e-01, 1.32700158e-02, 1.84264123e-01, - -4.69396599e-02, -8.76988843e-02, -8.42647329e-02, - 1.80242240e-01, 4.39915545e-02, -3.01284958e-02, - -4.19178084e-02, -6.55100867e-02, -1.67092770e-01], - [ 5.43184578e-03, -8.44964292e-03, 5.85759105e-03, - -7.32589066e-02, -6.53161779e-02, 1.58945680e-01, - -1.98484868e-01, -2.29594544e-01, -3.62942442e-02, - -4.60159145e-02, 4.65791941e-01, -1.32931456e-01, - -1.30874768e-01, 1.82594404e-01, 4.72868867e-02, - 7.68151507e-02, -1.17584936e-01, -7.83182383e-02, - -5.70569098e-01, 5.07849343e-02, -6.92476258e-02, - 1.45652056e-01, 1.57256410e-01, -2.92076059e-02, - 2.85284370e-01, 2.52744146e-02, 2.82830708e-02, - -5.04164398e-02, -1.00659683e-01, 5.86346574e-02, - 1.91001222e-02, 8.99196714e-02, -1.54763028e-01, - 1.01448707e-01, -7.42661506e-02, -1.71338975e-01], - [-1.00397598e-02, 6.89980984e-02, 5.02617331e-03, - -5.32203764e-02, 1.92967560e-02, -5.64105034e-01, - 3.46719325e-01, -7.40835667e-02, -5.14018210e-03, - 9.32325572e-02, 1.93343818e-01, 3.23573984e-02, - 2.21131876e-01, 3.06417048e-01, -8.70961323e-03, - 4.47171003e-01, 8.35162401e-02, 8.83740187e-02, - -8.72178078e-02, 1.18704282e-01, 1.05058528e-01, - -4.56921048e-02, 1.59751941e-02, -3.00876088e-02, - -2.47394085e-01, 4.93424907e-02, -6.64604902e-02, - -3.64027135e-02, -1.82686392e-02, -4.59523462e-02, - -1.26862470e-02, 2.52796169e-02, -4.81151454e-02, - -2.86283679e-02, -2.56162435e-02, -1.75585181e-01], - [-2.55113579e-02, 1.63476765e-02, -6.48622513e-02, - 8.53358284e-02, -1.47179626e-02, -2.74279952e-01, - 3.23813617e-01, 1.18787922e-01, -3.12188938e-02, - 1.27388835e-01, -1.47029653e-01, -6.44396339e-03, - 1.59717619e-01, -8.00469816e-02, 4.15628105e-02, - -3.71895492e-01, -2.58336008e-01, -3.58502686e-01, - -9.30814072e-02, 2.37474293e-01, -1.02323368e-01, - 7.77886510e-02, -2.62345857e-04, 3.05618107e-01, - 2.69323707e-01, -4.94645983e-02, 7.17321262e-02, - 1.81141701e-02, -7.26979673e-02, 3.66130173e-02, - 3.41478437e-02, -1.42837018e-01, -2.29302347e-01, - 9.40499976e-02, 9.85415503e-02, -1.79831386e-01], - [-4.09829244e-02, 2.96095997e-01, 5.72670512e-02, - -1.39296770e-01, -1.60581374e-03, 2.67294142e-02, - 5.13432994e-02, 3.44210893e-01, -4.88008671e-02, - -1.20673403e-01, -4.54095185e-01, -3.60888802e-02, - -3.48375738e-02, -3.80728357e-02, 6.19033575e-02, - 2.85812598e-02, -5.49174994e-02, 8.16437509e-03, - -3.89526159e-01, 1.42197743e-01, -6.57034442e-02, - 9.32944417e-02, -1.29381031e-01, -4.54968363e-01, - -7.63084590e-02, -1.27602285e-02, -3.93663906e-02, - -2.22954508e-02, 9.34363678e-02, 4.61584628e-02, - 1.17300354e-01, 1.84356645e-01, 4.64061499e-02, - 2.61230320e-02, -1.38632745e-01, -1.84077591e-01], - [-5.64545169e-02, -3.65092814e-01, -4.26685773e-02, - 1.75265297e-02, -1.79290678e-03, 7.54252076e-02, - -2.16403184e-03, 1.22491851e-01, 4.61655157e-03, - 9.93698239e-02, -2.86250204e-01, 1.17600495e-02, - -1.76643163e-01, -1.61555171e-01, 4.21675071e-02, - 4.96386349e-01, 2.84064054e-01, -1.88499331e-01, - 5.03461063e-02, -9.29289460e-02, 2.72047639e-01, - 1.54824242e-01, 7.62812719e-02, 9.09931362e-02, - 1.82046860e-01, -1.51961623e-02, 1.57171339e-01, - -2.52939817e-02, -6.88583925e-02, 8.74516144e-02, - 1.06507227e-01, 3.63174151e-03, 
-2.16592148e-01, - 1.95526704e-01, -2.63463091e-02, -1.88323811e-01], - [-7.19260871e-02, 1.53307199e-01, 2.98810583e-02, - -1.76042188e-02, 4.68952209e-02, 2.30930567e-01, - -1.91631261e-02, -3.50371659e-01, -1.39247498e-03, - -3.16982158e-02, 3.19441818e-02, 1.38011038e-01, - 1.15297228e-01, 1.21593997e-01, 1.12343794e-02, - -6.25559241e-02, 2.27593221e-02, -1.95765942e-01, - 2.61839062e-01, 1.88924655e-01, 1.47905156e-01, - 3.61047573e-02, -1.53986499e-01, 4.26004231e-02, - -1.01659156e-01, -9.87078920e-02, -1.97795078e-01, - 2.87956242e-02, 2.66166143e-02, 2.03926936e-02, - 6.36121154e-01, 1.17329828e-01, -1.68884546e-02, - 1.05052806e-01, -1.36004210e-01, -1.92570001e-01], - [-8.73977244e-02, 2.91939259e-01, -6.38535023e-02, - 1.23778999e-01, 2.33115517e-02, 8.99281502e-02, - -2.38235518e-02, 2.54457176e-01, -2.92873345e-02, - 1.45903289e-01, 2.51857221e-01, -1.22888424e-01, - 4.71667722e-02, -1.51163086e-01, -6.75680041e-02, - 1.34960130e-01, -5.27166612e-02, 5.85827529e-02, - 6.49949759e-02, -6.27990216e-02, 7.91215152e-02, - -2.11644500e-01, 1.25666901e-01, -2.19153777e-01, - 1.45102561e-01, 9.46507752e-02, 2.63710856e-01, - 1.36273995e-01, -2.85680946e-02, -9.64817554e-02, - 3.51572961e-01, -3.73799771e-01, 7.54300505e-02, - -1.52278930e-01, 2.77134597e-01, -1.96816236e-01], - [-1.02869295e-01, 4.54483837e-01, -3.16920318e-02, - -9.15080402e-03, 4.94015254e-02, 2.09832817e-01, - 9.22076330e-02, -3.92193407e-01, -1.33265834e-03, - 1.03313603e-01, -7.82989189e-02, 8.86598602e-03, - -9.18587223e-02, -1.70766622e-01, 5.54255210e-02, - 2.28601284e-02, 1.81634039e-01, 4.14796174e-02, - 3.81892845e-02, 2.48120666e-01, 1.65915981e-01, - 2.87097245e-02, -2.50649545e-02, 4.36540544e-02, - -5.01171201e-02, 3.54694985e-02, 1.90053612e-01, - 9.52630565e-02, 1.70738876e-01, 3.70882489e-02, - -4.90600616e-01, -9.28841755e-02, -8.13470930e-02, - 8.31348598e-02, 5.93565181e-02, -2.01062426e-01], - [-1.18340865e-01, -6.85950592e-02, 4.95309308e-02, - -1.77844893e-02, -9.69045609e-02, 2.31995173e-02, - -1.06131600e-03, 2.21603140e-01, -6.05566725e-02, - -2.82245725e-01, 2.64784724e-01, 8.62200931e-02, - 1.37575060e-01, 1.50092602e-01, 4.38311473e-02, - -1.27834529e-01, -1.75913945e-02, -2.03415841e-01, - 1.48476526e-01, -7.80855790e-02, 2.29345813e-01, - 3.37421596e-02, -3.02611887e-01, -3.64654101e-02, - -4.98286486e-02, -1.24875009e-01, 5.32554924e-01, - -5.55246398e-02, -8.19649324e-02, 4.32646945e-02, - -1.92818239e-01, 1.91410363e-01, 1.91146538e-01, - -1.30635314e-02, -1.27977282e-01, -2.05308631e-01], - [-1.33812457e-01, 5.83807267e-02, 6.38746191e-03, - -6.32736981e-02, 2.60766506e-01, 1.92557305e-01, - -4.26477045e-02, 5.47973156e-01, 1.53431622e-02, - 2.03396276e-01, 2.18420655e-01, 1.71779748e-02, - -7.09848702e-02, 2.39939511e-01, -2.50959713e-02, - -1.48106590e-01, 1.51656091e-01, 1.71890616e-01, - 7.37760216e-02, 5.53064533e-02, 1.98505912e-02, - 9.67100039e-02, 1.37430176e-01, 2.82746285e-01, - -1.24559112e-01, 1.80215873e-02, -2.68079907e-01, - 9.55012143e-02, 1.30839288e-01, 8.27972442e-02, - -9.96278524e-02, 4.17835526e-02, -4.81917933e-02, - 1.98767141e-01, -6.95911944e-02, -2.09554836e-01], - [-1.49284035e-01, -7.56456144e-03, -8.76261014e-03, - 2.92932428e-02, -8.39372516e-01, 5.67366369e-02, - -2.41059046e-02, 8.43372419e-02, -2.29054149e-02, - 3.72556150e-02, 3.59098194e-03, -3.51436548e-02, - -4.86128107e-02, -4.90781479e-02, -2.96334457e-02, - 2.16081198e-02, -6.04292788e-02, 1.73466746e-02, - 5.54120354e-02, 4.32790630e-02, 1.27067477e-01, - -9.41377804e-02, 
-1.37587115e-02, 7.06801787e-02, - -1.22610051e-02, 2.18931045e-02, -3.70597780e-01, - -1.30672632e-02, -4.53533195e-02, -1.70034133e-02, - -1.13316208e-01, -3.45941707e-02, 1.05737671e-01, - -2.95185428e-02, 2.46357918e-02, -2.13801056e-01], - [-1.64755657e-01, -1.91551998e-01, 1.24477036e-02, - 1.76897332e-01, -1.70191415e-02, 2.34046783e-02, - 6.76611960e-02, -1.21719569e-01, -1.60261299e-02, - 2.84169883e-01, -7.72131458e-02, -4.39732298e-02, - -6.60723150e-02, 8.68341923e-02, 7.35200867e-02, - -1.56345084e-01, 4.99212921e-01, -9.53519195e-02, - -1.69593558e-01, 3.12364921e-02, -4.14223462e-01, - -2.19161183e-01, -7.49167113e-04, 4.25142385e-02, - -2.26298310e-02, 3.90600637e-02, 1.34113848e-01, - -4.32782359e-02, -2.25105719e-03, -8.36708769e-02, - 7.53742829e-02, 1.09890841e-01, 3.47145647e-01, - -1.67040601e-01, -4.17540558e-02, -2.18047246e-01], - [-1.80227250e-01, -3.65751952e-01, 1.95310116e-02, - 3.56181487e-02, -2.47674435e-02, -2.56252866e-02, - 1.70394495e-01, -1.01341322e-01, 6.43750429e-02, - -1.18520278e-02, 7.76712969e-02, 1.21111691e-01, - -7.56260678e-02, -1.32285401e-01, 2.50612080e-01, - -2.70852149e-01, -9.66061503e-02, 4.63890702e-01, - 5.18286489e-02, 1.14975851e-02, 7.05922395e-02, - 7.95801077e-03, 3.40116471e-02, -2.50298321e-01, - -4.72176410e-02, 7.11330771e-02, 7.71585703e-02, - 7.12307394e-02, 1.51480496e-01, 4.94032800e-02, - 9.26278085e-02, 1.93590626e-01, -3.63108933e-01, - -1.36400744e-01, 1.46016315e-01, -2.22293481e-01], - [-1.95698813e-01, 8.16941485e-02, 6.35532150e-03, - -5.50320372e-02, 1.45350844e-01, -7.66825154e-02, - -1.48402769e-02, 8.44644289e-03, -3.05129532e-02, - -3.45072865e-01, 1.88118920e-01, 1.39703169e-01, - 9.01852995e-02, -3.05740625e-01, -7.54492134e-02, - 6.51175901e-02, 2.45817453e-01, -1.89270392e-01, - 1.16880536e-01, -2.26171866e-01, -3.72853994e-01, - 5.43844700e-03, -1.24716990e-01, -1.48458153e-01, - 5.83554097e-02, -8.44632387e-02, -3.41172040e-01, - -5.05601391e-02, -1.60052970e-01, 5.74440435e-02, - -1.45993277e-01, -4.03214097e-02, -2.16732427e-01, - -2.84256153e-02, 1.41579702e-01, -2.26539686e-01], - [-2.11170420e-01, -6.31088763e-02, 8.17671046e-03, - -5.57366088e-02, 6.94130734e-02, 3.52174342e-02, - -6.57851174e-02, -9.82191563e-02, -1.27271414e-02, - 1.43996403e-01, -1.19659491e-01, -5.62400967e-02, - -1.02117673e-01, 1.46197915e-01, -6.46053180e-02, - 2.75428176e-01, -5.38663089e-01, 1.51460487e-02, - 3.81278455e-01, 1.08411210e-02, -4.44346756e-01, - 4.02242467e-02, 9.23668295e-02, -7.21167400e-02, - 3.91138941e-02, 4.99221608e-02, 9.94546860e-02, - -3.87978405e-02, 1.93843860e-02, 8.32882449e-02, - -1.15623131e-01, 8.08125958e-02, 1.40358344e-01, - 1.01261795e-01, -5.90205789e-02, -2.30785877e-01], - [-2.26641983e-01, -1.44536331e-01, 8.91233422e-03, - 5.05167954e-02, 3.87359351e-01, -1.25706807e-01, - -9.50697213e-02, -1.42298609e-01, -7.01352954e-02, - -3.15868692e-03, -1.33074358e-01, -1.18453935e-01, - -7.71054849e-02, -4.75535467e-02, -1.50268868e-01, - -1.44392461e-01, -1.82032049e-01, -1.19762598e-02, - -1.21959276e-01, -6.38470054e-02, 4.80738163e-01, - -1.59658909e-01, 2.71296166e-02, -4.31644246e-02, - 1.02411315e-01, 2.07743910e-03, -2.89108336e-01, - -1.03720047e-01, -2.01758668e-01, -2.16420572e-02, - -1.27163813e-01, -7.36601278e-03, 3.14732850e-01, - -1.12868495e-01, 3.11465543e-02, -2.35032097e-01]], dtype=float32), array([-1.89882166e+03, -1.79985218e-04, -1.70435800e-04, -1.27975552e-04, - -1.24901737e-04, -1.24676313e-04, -1.16428266e-04, -1.06598200e-04, - -1.00050034e-04, -9.61478145e-05, 
-8.36294785e-05, -6.41566730e-05, - -4.51904889e-05, -2.39018827e-05, -1.49146554e-05, -9.43070791e-06, - -8.04440424e-06, 1.51055592e-05, 2.01099483e-05, 2.64523860e-05, - 3.25085311e-05, 5.15936626e-05, 5.31896258e-05, 7.24942220e-05, - 9.04739063e-05, 1.04830775e-04, 1.08393360e-04, 1.37811687e-04, - 1.49946762e-04, 1.86386926e-04, 1.89535742e-04, 2.40968098e-03, - 2.56012683e-03, 2.69382820e-03, 3.27441283e-03, 2.52088105e+04], - dtype=float32)), - mlir_module_text=r""" -module @jit__lambda_ { - func.func public @main() -> (tensor<36x36xf32> {jax.result_info = "[0]"}, tensor<36xf32> {jax.result_info = "[1]"}) { - %0 = stablehlo.iota dim = 0 : tensor<1296xf32> - %1 = stablehlo.reshape %0 : (tensor<1296xf32>) -> tensor<36x36xf32> - %2 = stablehlo.transpose %1, dims = [1, 0] : (tensor<36x36xf32>) -> tensor<36x36xf32> - %3 = stablehlo.add %1, %2 : tensor<36x36xf32> - %4 = stablehlo.constant dense<2.000000e+00> : tensor - %5 = stablehlo.broadcast_in_dim %4, dims = [] : (tensor) -> tensor<36x36xf32> - %6 = stablehlo.divide %3, %5 : tensor<36x36xf32> - %7 = call @tril(%6) : (tensor<36x36xf32>) -> tensor<36x36xf32> - %8 = stablehlo.custom_call @cusolver_syevd(%7) {api_version = 2 : i32, backend_config = "\00\00\00\00\00\00\00\00\01\00\00\00$\00\00\00Y\98\00\00", operand_layouts = [dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>, dense<0> : tensor<1xindex>]} : (tensor<36x36xf32>) -> tuple, tensor<36xf32>, tensor, tensor<39001xf32>> - %9 = stablehlo.get_tuple_element %8[0] : (tuple, tensor<36xf32>, tensor, tensor<39001xf32>>) -> tensor<36x36xf32> - %10 = stablehlo.get_tuple_element %8[1] : (tuple, tensor<36xf32>, tensor, tensor<39001xf32>>) -> tensor<36xf32> - %11 = stablehlo.get_tuple_element %8[2] : (tuple, tensor<36xf32>, tensor, tensor<39001xf32>>) -> tensor - %12 = stablehlo.get_tuple_element %8[3] : (tuple, tensor<36xf32>, tensor, tensor<39001xf32>>) -> tensor<39001xf32> - %13 = stablehlo.constant dense<0> : tensor - %14 = stablehlo.broadcast_in_dim %13, dims = [] : (tensor) -> tensor - %15 = stablehlo.compare EQ, %11, %14, SIGNED : (tensor, tensor) -> tensor - %16 = stablehlo.broadcast_in_dim %15, dims = [] : (tensor) -> tensor<1x1xi1> - %17 = stablehlo.constant dense<0x7FC00000> : tensor - %18 = stablehlo.broadcast_in_dim %17, dims = [] : (tensor) -> tensor<36x36xf32> - %19 = stablehlo.broadcast_in_dim %16, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<36x36xi1> - %20 = stablehlo.select %19, %9, %18 : tensor<36x36xi1>, tensor<36x36xf32> - %21 = stablehlo.broadcast_in_dim %15, dims = [] : (tensor) -> tensor<1xi1> - %22 = stablehlo.constant dense<0x7FC00000> : tensor - %23 = stablehlo.broadcast_in_dim %22, dims = [] : (tensor) -> tensor<36xf32> - %24 = stablehlo.broadcast_in_dim %21, dims = [0] : (tensor<1xi1>) -> tensor<36xi1> - %25 = stablehlo.select %24, %10, %23 : tensor<36xi1>, tensor<36xf32> - return %20, %25 : tensor<36x36xf32>, tensor<36xf32> - } - func.func private @tril(%arg0: tensor<36x36xf32>) -> tensor<36x36xf32> { - %0 = stablehlo.iota dim = 0 : tensor<36x36xi32> - %1 = stablehlo.constant dense<0> : tensor - %2 = stablehlo.broadcast_in_dim %1, dims = [] : (tensor) -> tensor<36x36xi32> - %3 = stablehlo.add %0, %2 : tensor<36x36xi32> - %4 = stablehlo.iota dim = 1 : tensor<36x36xi32> - %5 = stablehlo.compare GE, %3, %4, SIGNED : (tensor<36x36xi32>, tensor<36x36xi32>) -> tensor<36x36xi1> - %6 = stablehlo.constant dense<0.000000e+00> : 
tensor - %7 = stablehlo.broadcast_in_dim %6, dims = [] : (tensor) -> tensor<36x36xf32> - %8 = stablehlo.select %5, %arg0, %7 : tensor<36x36xi1>, tensor<36x36xf32> - return %8 : tensor<36x36xf32> - } -} -""", - mlir_module_serialized=b"ML\xefR\x03MLIRxxx-trunk\x00\x01-\x05\x01\x05\x01\x03\x05\x03\x1d\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x1d\x1f!\x03^\x02\xeb5\x01\x95\x0f\x17\x13\x07\x0f\x0b\x0b\x0b\x0b\x0b\x17\x0b\x0b\x0b\x0b\x13\x0b\x13\x0f\x0b\x0b\x17\x0f\x13\x13\x0b33\x0b\x0f\x0b\x0b\x13\x0f\x0b\x1b\x0f\x0b\x13\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x13\x0b\x0f\x0b\x0f\x0b\x13\x0b\x13\x0bK\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x13\x13\x13\x1b\x13\x13\x03W\x0b\x0b\x0f\x0b\x0bO/\x0b\x13\x13\x0b\x13\x0b\x0b\x0b\x0b\x0b\x0b\x0f\x1f\x0f\x0f\x0b\x1fO\x1f\x0b\x0b\x0b\x0b\x0f\x0f\x17\x1b\x0f\x0f\x0f\x0f\x0f\x0b\x1fO/\x035\x17\x0f\x07\x0f\x07\x13\x07\x07\x17\x07\x17\x13\x1b\x17\x17\x13\x17\x1b\x13\x13\x13\x0f\x17\x13\x13\x13\x02v\x08\x1d\x85\x03\x17\x11R\x04\x01\x03\x03\x13\xbd\x1f\x1d9\x03\x05#\x05%\x05'\x05)\x05+\x17\x11N\x04\x01\x05-\x05/\x051\x053\x03\x03!\xb9\x055\x03\x03\x0b\xbb\x1d?\x03\x057\x059\x17\x11F\x04\x01\x1dm\x15\x03\x03\x0b\xe5\x03\x03\x0f3\x05;\x03\x0b\x17\x95\x19\xa3\x1b\xa5\x0f\xaf\x1d\xb1\x03\x0b\x17\x99\x19\xb5\x1b\x99\x0f\x9b\x1d\xb7\x05=\x1d=\x03\x05?\x05A\x03\x03!\xbf\x1dE\x03\x05C\x03\x05'\x9d)\xc1\x1dK\x03\x05E\x03\x03\x0b\xc3\x1dQ\x03\x05G\x1dU\x03\x05I\x1dY+\x05K\x1d]+\x05M\x03\x03a\xc5\x05O\x1de\x15\x05Q\x1di\x15\x05S\x03\x03\x0b\xc7\x05U\x03\x03q\x9b\x05W\x03\x11u\xc9w\xcby\xcd{\x95}\xcf\x7f\xd1\x81\xd3\x83\xd7\x05Y\x05[\x05]\x05_\x05a\x05c\x05e\x05g\x05i\x03\x03\r\xdb\x03\x03\r\xdd\x03\x03\r\xdf\x03\x03\r\xe1\x03\x05'\x9d)\xe3\x03\x03\x13\xe7\x03\x03\x13\xe9\x03\x01\x1dk\x03\x03\xb3\x1dm\t\x07\x1f%!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f'\x11\x00\x00\x00\x00\x00\x00\x00\x00#\x1b\x03\x05\xa7\xab\r\x03\x97\xa9\x1do\r\x03\x97\xad\x1dq\x1ds\x1du\r\x01#\x1d\x1dw\x13\r\x01\x1f\x07\t\x00\x00\x00\x00\x1f\x1f\x01\x13\r\x05\x07\x05\x1f\x03\t\x00\x00\x00\x00\x1f\x17!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x03\t\x00\x00\x00@\x0b\x05\x1dy\x1d{\x05\x01\x03\x03\x9f\x03\x03\xd5\x15\x03\x01\x01\x01\x03\t\x9f\xa1\xd9\xa1\x1f)\x01\x13\x05\x01\x13\x05\x05\x13\x05\t\x13\x05\r\x07\x01\x1f\x03\t\x00\x00\xc0\x7f\x1f\x17!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f3\x11\x00\x00\x00\x00\x00\x00\x00\x00)\x05\x91\x91\t)\x01\t\x1b)\x01\x05\t)\x03\x91\t\x1d\x01)\x05\x91\x91\x05\x13)\x05\x91\x91\x0f)\x03\t\r)\x03\x94\x85\t\t\x11\x01\x05\x01\x0b\x11\x03\x01\x03\x01)\x03\x01\r)\x03\x82(\t/\t\x01\x0b\x07\x19)\x03\t\x13)\x03\x05\x13)\x03\x01\x13)\x01\x0f)\x05\x05\x05\x0f)\x03\x05\x0f)\x03\x91\x0f)\x03\x05\r\x04\xc6\x04\x05\x01\x11\x071\x07\x03\x01\t\r\x11\x075\x05\x035m\t\x03W\x1f\x03!\x15\x06[\x03\x01\x03\x01\x17\x07c_\x03\x01\x03\x03\x0f\x06g\x03\x01\x05\x03\x05\x05\x03\x07k\x03\x03\x03\x07-\x05\x03\x01\x03\t\x19\x06-\x03\x01\x05\x07\x0b\x1b\x07\to\x03\x01\x03\r\x1d\x07\x01s\x03#\x03\x0f\x07\x07\x01\x87\x03\x01\x03\x11\x07\x07\x01\x89\x03\x0b\x03\x11\x07\x07\x01\x8b\x03\x07\x03\x11\x07\x07\x01\x8d\x03\x19\x03\x11\x05\x03\x01#\x03\x07\x03\x07\x01\x05\x03\x07\x03\x1b\x11\x07\x01\x8f\x03+\x05\x17\x1d\x03\x07\x01\x05\x03-\x03\x1f\x05\x03\x01/\x03\x03\x03\x07\x01\x05\x03\x01\x03#\x03\x07\x01\x91\x03\x15\x03!\x0b\x06\x01\x03\x01\x07'\x13%\x03\x07\x01\x05\x03/\x03\x1f\x05\x03\x01/\x03\x03\x03\x07\x01\x05\x03\x0b\x03-\x03\x07\x01\x93\x031\x03+\x0b\x06\x01\x03\x0b\x071\x15/\x13\x04\x07\x05)3\r\x11\t7\x05\x03\x15+\x03\x01\x07\t\x
03;\x1f\x03\x11\x05\x03\t#\x03\x07\x03\x07%\x05\x03\x11\x03\x05\x0f\x06%\x03\x11\x05\x03\x07\t\x03CA\x03\x11\x11\x07IG\x03\x15\x05\t\x0b\x05\x03\tM\x03\x03\x03\x07O\x05\x03\x01\x03\x0f\x0b\x06S\x03\x01\x07\r\x01\x11\x13\x04\t\x03\x13\x06\x03\x01\x05\x01\x00.\x1a}\x1f+\x11\x0f\x0b\t\t\x0b!\x7f\x1f/!!)#\x1f\x19\x0f99m\x19\x89\x8dW\xb7K\x9fM\x9f\x96\x04\x1b+\x1b\x1f\x1f\x15\x1d\x15+\x83\x13\r\r\x1f\x11\x15\x1b\x17\x15\x17\x0f\x11\x15\x11+\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00get_tuple_element_v1\x00iota_v1\x00select_v1\x00func_v1\x00add_v1\x00compare_v1\x00return_v1\x00reshape_v1\x00transpose_v1\x00divide_v1\x00call_v1\x00custom_call_v1\x00value\x00index\x00sym_name\x00third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py\x00broadcast_dimensions\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00compare_type\x00comparison_direction\x00jit__lambda_\x00jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) resource_env=None donated_invars=(False,) name=tril in_positional_semantics=(<_PositionalSemantics.GLOBAL: 1>,) out_positional_semantics=_PositionalSemantics.GLOBAL keep_unused=False inline=False]\x00jit()/jit(main)/jit(tril)/iota[dtype=int32 shape=(36, 36) dimension=0]\x00jit()/jit(main)/jit(tril)/add\x00jit()/jit(main)/jit(tril)/iota[dtype=int32 shape=(36, 36) dimension=1]\x00jit()/jit(main)/jit(tril)/ge\x00jit()/jit(main)/jit(tril)/broadcast_in_dim[shape=(36, 36) broadcast_dimensions=()]\x00jit()/jit(main)/jit(tril)/select_n\x00jit()/jit(main)/iota[dtype=float32 shape=(1296,) dimension=0]\x00jit()/jit(main)/reshape[new_sizes=(36, 36) dimensions=None]\x00permutation\x00jit()/jit(main)/transpose[permutation=(1, 0)]\x00jit()/jit(main)/add\x00jit()/jit(main)/div\x00callee\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00jit()/jit(main)/eigh[lower=True sort_eigenvalues=True]\x00jax.result_info\x00tril\x00[0]\x00[1]\x00main\x00public\x00private\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00$\x00\x00\x00Y\x98\x00\x00\x00cusolver_syevd\x00", - xla_call_module_version=4, - ), # End paste - - # Pasted from the test output (see back_compat_test.py module docstring) - f64_syevj=dict( - testdata_version=1, - platform='cuda', - custom_call_targets=['cusolver_syevj'], - serialized_date=datetime.date(2023, 3, 17), - inputs=(), - expected_outputs=(array([[ 6.1857700048412179e-01, -7.9870412160195655e-05, - -7.1795133407817180e-02, 7.2651725579187088e-01, - -5.8816812454044016e-04, -1.0752133550364418e-01, - -1.9695247974936425e-01, 1.8446994643771727e-01], - [ 4.7070881487314487e-01, 3.3071017759156432e-05, - -5.9630159401629157e-01, -4.7856902268752244e-01, - -1.4151478943184035e-03, -2.5017522435505674e-01, - 2.8106392345809550e-01, 2.2856669794666581e-01], - [ 3.2284062926217122e-01, -5.1104181032785456e-01, - 2.4098685972870454e-01, -3.2057977627137213e-01, - 6.0128498619340851e-04, 5.5435726441071020e-01, - -3.0349043125069775e-01, 2.7266344945561433e-01], - [ 1.7497244365119549e-01, 4.1809211960021736e-01, - 5.7112844532216078e-01, -3.1146378582869927e-01, - -4.8989605706119613e-04, -4.4689091764000977e-01, - -2.6709076241922963e-01, 3.1676020096456298e-01], - [ 2.7104258040218803e-02, 4.2941995817157164e-01, - 1.1304358388496584e-01, 9.3073375918824142e-02, - -4.7236149166811120e-01, 4.6617552271070906e-01, - 4.7197416944525139e-01, 3.6085695247351168e-01], - 
[-1.2076392757075657e-01, -3.8434927079561992e-01, - 2.9171425263113138e-01, 1.5624558970245273e-01, - 4.3260383504376299e-01, -2.0278835428567779e-01, - 5.7959048064074936e-01, 4.0495370398246017e-01], - [-2.6863211318173014e-01, 3.6363990709349564e-01, - -3.3163183889685732e-01, 4.2836063092320187e-02, - 5.6343802845177837e-01, 2.7652818360156795e-01, - -2.9700444618985122e-01, 4.4905045549140854e-01], - [-4.1650029879270561e-01, -3.1571410434740910e-01, - -2.1714457524599659e-01, 9.1940300282126255e-02, - -5.2178844473770358e-01, -2.8968513893859849e-01, - -2.6809045393495168e-01, 4.9314720700035708e-01]]), array([-2.4598804776133605e+01, -2.8026300235964570e-15, - -1.8958980326674837e-15, 1.5553235693581772e-15, - 1.6670762548207520e-15, 2.2405283578797194e-15, - 5.4086800892994285e-15, 2.7659880477613365e+02])), - mlir_module_text=""" -module @jit__lambda_ { - func.func public @main() -> (tensor<8x8xf64> {jax.result_info = "[0]"}, tensor<8xf64> {jax.result_info = "[1]"}) { - %0 = stablehlo.iota dim = 0 : tensor<64xf64> - %1 = stablehlo.reshape %0 : (tensor<64xf64>) -> tensor<8x8xf64> - %2 = stablehlo.transpose %1, dims = [1, 0] : (tensor<8x8xf64>) -> tensor<8x8xf64> - %3 = stablehlo.add %1, %2 : tensor<8x8xf64> - %4 = stablehlo.constant dense<2.000000e+00> : tensor - %5 = stablehlo.broadcast_in_dim %4, dims = [] : (tensor) -> tensor<8x8xf64> - %6 = stablehlo.divide %3, %5 : tensor<8x8xf64> - %7 = call @tril(%6) : (tensor<8x8xf64>) -> tensor<8x8xf64> - %8 = stablehlo.custom_call @cusolver_syevj(%7) {api_version = 2 : i32, backend_config = "\01\00\00\00\00\00\00\00\01\00\00\00\08\00\00\00M\08\00\00", operand_layouts = [dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>, dense<0> : tensor<1xindex>]} : (tensor<8x8xf64>) -> tuple, tensor<8xf64>, tensor, tensor<2125xf64>> - %9 = stablehlo.get_tuple_element %8[0] : (tuple, tensor<8xf64>, tensor, tensor<2125xf64>>) -> tensor<8x8xf64> - %10 = stablehlo.get_tuple_element %8[1] : (tuple, tensor<8xf64>, tensor, tensor<2125xf64>>) -> tensor<8xf64> - %11 = stablehlo.get_tuple_element %8[2] : (tuple, tensor<8xf64>, tensor, tensor<2125xf64>>) -> tensor - %12 = stablehlo.get_tuple_element %8[3] : (tuple, tensor<8xf64>, tensor, tensor<2125xf64>>) -> tensor<2125xf64> - %13 = stablehlo.constant dense<0> : tensor - %14 = stablehlo.broadcast_in_dim %13, dims = [] : (tensor) -> tensor - %15 = stablehlo.compare EQ, %11, %14, SIGNED : (tensor, tensor) -> tensor - %16 = stablehlo.broadcast_in_dim %15, dims = [] : (tensor) -> tensor<1x1xi1> - %17 = stablehlo.constant dense<0x7FF8000000000000> : tensor - %18 = stablehlo.broadcast_in_dim %17, dims = [] : (tensor) -> tensor<8x8xf64> - %19 = stablehlo.broadcast_in_dim %16, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<8x8xi1> - %20 = stablehlo.select %19, %9, %18 : tensor<8x8xi1>, tensor<8x8xf64> - %21 = stablehlo.broadcast_in_dim %15, dims = [] : (tensor) -> tensor<1xi1> - %22 = stablehlo.constant dense<0x7FF8000000000000> : tensor - %23 = stablehlo.broadcast_in_dim %22, dims = [] : (tensor) -> tensor<8xf64> - %24 = stablehlo.broadcast_in_dim %21, dims = [0] : (tensor<1xi1>) -> tensor<8xi1> - %25 = stablehlo.select %24, %10, %23 : tensor<8xi1>, tensor<8xf64> - return %20, %25 : tensor<8x8xf64>, tensor<8xf64> - } - func.func private @tril(%arg0: tensor<8x8xf64>) -> tensor<8x8xf64> { - %0 = stablehlo.iota dim = 0 : tensor<8x8xi32> - %1 = stablehlo.constant dense<0> : tensor 
- %2 = stablehlo.broadcast_in_dim %1, dims = [] : (tensor) -> tensor<8x8xi32> - %3 = stablehlo.add %0, %2 : tensor<8x8xi32> - %4 = stablehlo.iota dim = 1 : tensor<8x8xi32> - %5 = stablehlo.compare GE, %3, %4, SIGNED : (tensor<8x8xi32>, tensor<8x8xi32>) -> tensor<8x8xi1> - %6 = stablehlo.constant dense<0.000000e+00> : tensor - %7 = stablehlo.broadcast_in_dim %6, dims = [] : (tensor) -> tensor<8x8xf64> - %8 = stablehlo.select %5, %arg0, %7 : tensor<8x8xi1>, tensor<8x8xf64> - return %8 : tensor<8x8xf64> - } -} -""", - mlir_module_serialized=b"ML\xefR\x03MLIRxxx-trunk\x00\x01-\x05\x01\x05\x01\x03\x05\x03\x1d\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x1d\x1f!\x03^\x02\xeb5\x01\x95\x0f\x17\x13\x07\x0f\x0b\x0b\x0b\x0b\x0b\x17\x0b\x0b\x0b\x0b\x13\x0b\x13\x0f\x0b\x0b\x17\x0f\x13\x13\x0b33\x0b\x0f\x0b\x0b\x13\x0f\x0b\x1b\x0f\x0b\x13\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x13\x0b\x0f\x0b\x0f\x0b\x13\x0b\x13\x0bK\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x13\x13\x13\x1b\x13\x13\x03W\x0b\x0b\x0f\x0b\x0bO/\x0b\x13\x13\x0b\x13\x0b\x0b\x0b\x0b\x0b\x0b\x0f\x1f\x0f\x0f\x0b/O/\x0b\x0b\x0b\x0b\x0f\x0f\x17\x1b\x0f\x0f\x0f\x0f\x0f\x0b/O/\x035\x17\x0f\x07\x0f\x07\x13\x07\x07\x17\x07\x17\x13\x17\x17\x17\x13\x17\x1b\x13\x13\x13\x0f\x17\x13\x13\x13\x02\xa2\x08\x1d\x85\x03\x17\x116\x04\x01\x03\x03\x13\xbd\x1f\x1d9\x03\x05#\x05%\x05'\x05)\x05+\x17\x112\x04\x01\x05-\x05/\x051\x053\x03\x03!\xb9\x055\x03\x03\x0b\xbb\x1d?\x03\x057\x059\x17\x11*\x04\x01\x1dm\x15\x03\x03\x0b\xe5\x03\x03\x0f3\x05;\x03\x0b\x17\x95\x19\xa3\x1b\xa5\x0f\xaf\x1d\xb1\x03\x0b\x17\x99\x19\xb5\x1b\x99\x0f\x9b\x1d\xb7\x05=\x1d=\x03\x05?\x05A\x03\x03!\xbf\x1dE\x03\x05C\x03\x05'\x9d)\xc1\x1dK\x03\x05E\x03\x03\x0b\xc3\x1dQ\x03\x05G\x1dU\x03\x05I\x1dY+\x05K\x1d]+\x05M\x03\x03a\xc5\x05O\x1de\x15\x05Q\x1di\x15\x05S\x03\x03\x0b\xc7\x05U\x03\x03q\x9b\x05W\x03\x11u\xc9w\xcby\xcd{\x95}\xcf\x7f\xd1\x81\xd3\x83\xd7\x05Y\x05[\x05]\x05_\x05a\x05c\x05e\x05g\x05i\x03\x03\r\xdb\x03\x03\r\xdd\x03\x03\r\xdf\x03\x03\r\xe1\x03\x05'\x9d)\xe3\x03\x03\x13\xe7\x03\x03\x13\xe9\x03\x01\x1dk\x03\x03\xb3\x1dm\t\x07\x1f%!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f'\x11\x00\x00\x00\x00\x00\x00\x00\x00#\x1b\x03\x05\xa7\xab\r\x03\x97\xa9\x1do\r\x03\x97\xad\x1dq\x1ds\x1du\r\x01#\x1d\x1dw\x13\r\x01\x1f\x07\t\x00\x00\x00\x00\x1f\x1f\x01\x13\r\x05\x07\x05\x1f\x03\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x17!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x03\x11\x00\x00\x00\x00\x00\x00\x00@\x0b\x05\x1dy\x1d{\x05\x01\x03\x03\x9f\x03\x03\xd5\x15\x03\x01\x01\x01\x03\t\x9f\xa1\xd9\xa1\x1f)\x01\x13\x05\x01\x13\x05\x05\x13\x05\t\x13\x05\r\x07\x01\x1f\x03\x11\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f\x17!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f3\x11\x00\x00\x00\x00\x00\x00\x00\x00)\x05!!\t)\x01\t\x1b)\x01\x05\x0b)\x03!\t\x1d\x01)\x05!!\x05\x13)\x05!!\x0f)\x03\t\r)\x03jB\t\x11\x01\x05\x01\x0b\x11\x03\x01\x03\x01)\x03\x01\r)\x03\x02\x02\t/\t\x01\x0b\x07\x19)\x03\t\x13)\x03\x05\x13)\x03\x01\x13)\x01\x0f)\x05\x05\x05\x0f)\x03\x05\x0f)\x03!\x0f)\x03\x05\r\x04\xc6\x04\x05\x01\x11\x071\x07\x03\x01\t\r\x11\x075\x05\x035m\t\x03W\x1f\x03!\x15\x06[\x03\x01\x03\x01\x17\x07c_\x03\x01\x03\x03\x0f\x06g\x03\x01\x05\x03\x05\x05\x03\x07k\x03\x03\x03\x07-\x05\x03\x01\x03\t\x19\x06-\x03\x01\x05\x07\x0b\x1b\x07\to\x03\x01\x03\r\x1d\x07\x01s\x03#\x03\x0f\x07\x07\x01\x87\x03\x01\x03\x11\x07\x07\x01\x89\x03\x0b\x03\x11\x07\x07\x01\x8b\x03\x07\x03\x11\x07\x07\x01\x8d\x03\x19\x03\x11\x05\x03\x01#\x03\x07\x03\x07\x01\x05\x03\x07\x03\x1b\x11\x07\x01\x8f\x03+\x05\x17\
x1d\x03\x07\x01\x05\x03-\x03\x1f\x05\x03\x01/\x03\x03\x03\x07\x01\x05\x03\x01\x03#\x03\x07\x01\x91\x03\x15\x03!\x0b\x06\x01\x03\x01\x07'\x13%\x03\x07\x01\x05\x03/\x03\x1f\x05\x03\x01/\x03\x03\x03\x07\x01\x05\x03\x0b\x03-\x03\x07\x01\x93\x031\x03+\x0b\x06\x01\x03\x0b\x071\x15/\x13\x04\x07\x05)3\r\x11\t7\x05\x03\x15+\x03\x01\x07\t\x03;\x1f\x03\x11\x05\x03\t#\x03\x07\x03\x07%\x05\x03\x11\x03\x05\x0f\x06%\x03\x11\x05\x03\x07\t\x03CA\x03\x11\x11\x07IG\x03\x15\x05\t\x0b\x05\x03\tM\x03\x03\x03\x07O\x05\x03\x01\x03\x0f\x0b\x06S\x03\x01\x07\r\x01\x11\x13\x04\t\x03\x13\x06\x03\x01\x05\x01\x00\x06\x1a}\x1f+\x11\x0f\x0b\t\t\x0b!\x7f\x1f/!!)#\x1f\x19\x0f99m\x19\x85\x89W\xb3K\x9bM\x9b\x96\x04\x1b+\x1b\x1f\x1f\x15\x1d\x15+\x83\x13\r\r\x1f\x11\x15\x1b\x17\x15\x17\x0f\x11\x15\x11+\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00get_tuple_element_v1\x00iota_v1\x00select_v1\x00func_v1\x00add_v1\x00compare_v1\x00return_v1\x00reshape_v1\x00transpose_v1\x00divide_v1\x00call_v1\x00custom_call_v1\x00value\x00index\x00sym_name\x00third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py\x00broadcast_dimensions\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00compare_type\x00comparison_direction\x00jit__lambda_\x00jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) resource_env=None donated_invars=(False,) name=tril in_positional_semantics=(<_PositionalSemantics.GLOBAL: 1>,) out_positional_semantics=_PositionalSemantics.GLOBAL keep_unused=False inline=False]\x00jit()/jit(main)/jit(tril)/iota[dtype=int32 shape=(8, 8) dimension=0]\x00jit()/jit(main)/jit(tril)/add\x00jit()/jit(main)/jit(tril)/iota[dtype=int32 shape=(8, 8) dimension=1]\x00jit()/jit(main)/jit(tril)/ge\x00jit()/jit(main)/jit(tril)/broadcast_in_dim[shape=(8, 8) broadcast_dimensions=()]\x00jit()/jit(main)/jit(tril)/select_n\x00jit()/jit(main)/iota[dtype=float64 shape=(64,) dimension=0]\x00jit()/jit(main)/reshape[new_sizes=(8, 8) dimensions=None]\x00permutation\x00jit()/jit(main)/transpose[permutation=(1, 0)]\x00jit()/jit(main)/add\x00jit()/jit(main)/div\x00callee\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00jit()/jit(main)/eigh[lower=True sort_eigenvalues=True]\x00jax.result_info\x00tril\x00[0]\x00[1]\x00main\x00public\x00private\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x08\x00\x00\x00M\x08\x00\x00\x00cusolver_syevj\x00", - xla_call_module_version=4, - ), # End paste - - # Pasted from the test output (see back_compat_test.py module docstring) - f64_syevd=dict( - testdata_version=1, - platform='cuda', - custom_call_targets=['cusolver_syevd'], - serialized_date=datetime.date(2023, 3, 17), - inputs=(), - expected_outputs=(array([[-3.1486359056225782e-01, 3.7431364158123925e-02, - 6.1831284766658730e-02, -1.2946991231313536e-02, - 1.9330566993707950e-02, 3.1760201896488226e-03, - 0.0000000000000000e+00, 0.0000000000000000e+00, - 0.0000000000000000e+00, 0.0000000000000000e+00, - 0.0000000000000000e+00, 0.0000000000000000e+00, - 0.0000000000000000e+00, 0.0000000000000000e+00, - 0.0000000000000000e+00, 0.0000000000000000e+00, - 0.0000000000000000e+00, 0.0000000000000000e+00, - 0.0000000000000000e+00, 0.0000000000000000e+00, - 0.0000000000000000e+00, 0.0000000000000000e+00, - 0.0000000000000000e+00, 0.0000000000000000e+00, - 0.0000000000000000e+00, 0.0000000000000000e+00, - 0.0000000000000000e+00, 0.0000000000000000e+00, - 
0.0000000000000000e+00, 0.0000000000000000e+00, - 0.0000000000000000e+00, 0.0000000000000000e+00, - 0.0000000000000000e+00, 0.0000000000000000e+00, - 9.4213470166864710e-01, -8.6414847942068732e-02], - [-2.9939200325938797e-01, 8.3501568928299474e-01, - 4.0680107296867257e-01, -4.6573192775473518e-02, - 6.5422207600829785e-02, 2.2099527094683900e-02, - -1.0242349878775975e-02, 4.0829390183091318e-03, - -1.5827725558444371e-02, -8.6793932713605769e-03, - 1.3047005177451432e-03, -5.3573283556152184e-03, - -1.1723085990292578e-02, -3.4282481604778923e-03, - 1.5300655388654032e-03, 1.3010433879291027e-02, - -7.6245808434662662e-03, 5.9569775610370131e-04, - -5.9294293157650772e-03, -1.9734040942842074e-03, - -1.8628968192927392e-02, -1.3034235399858809e-02, - -5.0097004610369401e-03, 2.4749245795903537e-02, - -5.0644358547264675e-03, 3.0532167800601515e-03, - 2.0824661626164857e-02, -1.5147462161617094e-03, - 1.6322395782111299e-02, -1.1236053191734820e-02, - -1.1821960842042806e-02, 3.8822577430670320e-03, - 7.0724820528586508e-04, 1.9906723944256747e-02, - -1.7030338737863057e-01, -9.0661051391036640e-02], - [-2.8392041595652112e-01, -1.0171687781151459e-01, - -1.1816431661072314e-01, 2.9212172394267638e-01, - 3.3294458108354380e-01, 4.2087881292542445e-01, - -2.2194306321456944e-01, 1.2056157631930936e-01, - -1.0764065526585581e-01, 4.4945129933377570e-02, - -1.1518299700192679e-01, -3.1085391640205563e-02, - 3.1385765542768805e-02, -2.2533661915179113e-02, - 9.3053311217867085e-02, -1.6099650538834706e-01, - -3.8639305088265900e-02, 9.2990366329018387e-03, - 4.6666113341746911e-02, -2.1871647987757620e-01, - 1.7703518610745730e-01, 1.5467613762024190e-01, - -7.2294521250116733e-02, 2.3499877830015681e-01, - -5.6829378083033165e-03, -1.0178485446351725e-01, - 1.7877785721217213e-01, 2.1684187554288339e-01, - 7.7233872499541889e-02, 2.2835265304748494e-02, - 3.1080805156356406e-01, 3.1722234078538948e-02, - -7.8092425763001377e-02, 9.4554636051152510e-02, - -9.6031463624110386e-02, -9.4907254840003452e-02], - [-2.6844882865365438e-01, -2.0201860535424061e-02, - -2.0343029420688158e-01, 1.2815855886454322e-01, - 4.8774092445450092e-02, 1.3232562034943543e-01, - -1.8521836621459195e-01, 9.8747816539597660e-02, - 2.7324903486606195e-01, -7.8737437097193080e-02, - 4.9421661772677816e-02, 7.1493931251323112e-02, - 3.5542595611320515e-01, 1.3920746216059152e-01, - -2.8249741974519734e-02, 6.7932896387190703e-02, - -2.3008512044551552e-01, 5.5015746716542496e-02, - -6.0329018554125865e-03, 8.4249901371007491e-02, - -1.0850059549176212e-01, -2.7052679792044718e-02, - 1.7199248671821082e-01, -2.0779039909219962e-01, - 1.1023999772580403e-01, 4.0228126834019268e-01, - -7.1331569093078903e-02, -2.2546040356632324e-01, - -5.6848723613690040e-02, 2.0039103669806510e-01, - -2.2375524112669190e-01, -6.6955463229343037e-02, - -1.4356710092268696e-01, 2.2907198003730800e-01, - -8.4342913246148038e-02, -9.9153458288970819e-02], - [-2.5297724135078736e-01, -9.7633470097019753e-02, - -2.0613664461051402e-02, -4.6575018452204114e-01, - -4.5475545929408095e-01, -1.6835202228307944e-01, - -2.7411043542686481e-01, 1.4382896244553764e-01, - 1.5533482960243880e-01, -7.7897907011887785e-02, - -5.9104799414908579e-02, -5.1049057176047449e-02, - 5.0937034273965797e-03, -2.9920502980456239e-02, - 7.9164430071644656e-02, 6.5334090456028976e-02, - -2.4594170101813598e-01, 4.0287932953704184e-02, - 1.3071075582032446e-01, -5.6912271071735306e-02, - -1.2680756132856946e-02, 3.5044366466197449e-02, - -5.1780762628180410e-03, 
1.2325979893038844e-01, - -1.3286387357961091e-01, -1.9718715617446650e-01, - -7.0204376770955132e-02, -9.3710658292701816e-03, - 7.6870928390159760e-03, 1.2623341382152653e-01, - 3.4895566103640097e-01, 7.7553659039143241e-02, - -3.4023999296528072e-02, 8.3074702907895745e-02, - -8.5300072672481381e-02, -1.0339966173793817e-01], - [-2.3750565404792034e-01, -8.2181485614283623e-02, - -2.4796576412755008e-02, 2.6469606244089910e-01, - 2.5136155191565374e-01, -8.5932117879471037e-01, - -6.7801327364868255e-02, 2.3630380146045637e-02, - -6.0339530364635997e-02, 2.4318784991642788e-02, - -2.0157980609574723e-02, 1.3969684905577337e-02, - 5.2064373452097072e-02, -1.3504287072787914e-03, - 1.1948855400414819e-02, -7.7684684576308824e-02, - -1.8126869586737940e-02, -3.2895203661275497e-02, - -4.7194795185232655e-03, -6.2526420481870917e-02, - 7.8353014950393762e-02, 4.3021669650274826e-02, - 4.1123834759705602e-02, 2.1527669096626890e-02, - 3.2298969317449348e-02, 2.3438124417394162e-02, - 3.1518151219115144e-02, 8.9704214482948422e-02, - 7.6821260017619769e-03, -8.5409778343425186e-03, - 1.5521001031338759e-02, -1.3290428648657086e-02, - 1.8906628930454021e-02, -1.2782589525387992e-02, - -8.2979044248598546e-02, -1.0764586518690553e-01], - [-2.2203406674505338e-01, -9.0264475102341105e-02, - 9.0740700176499111e-03, 6.9171384437416147e-02, - -1.3111811612891669e-01, -1.8966507957248607e-02, - 4.0414304307463594e-01, -3.2564666059313241e-02, - 5.6086124244845181e-01, -4.0083205571491060e-02, - -2.4505702319715772e-02, 2.8981348567837486e-02, - -1.8028953963325864e-01, 1.2810669493073431e-02, - -3.0205734928244080e-02, 1.3016546116209483e-03, - 4.1180187675978214e-01, 1.8487430939971340e-03, - 2.1878399115523185e-02, -1.2942737544986772e-02, - 3.1612876215063763e-02, 1.9040590265843902e-02, - -2.9853451951736565e-01, -2.1069261774264141e-02, - 1.2756924052704141e-02, 1.0396556130345047e-02, - 2.0982593071380967e-01, 7.2513245350085284e-02, - 2.6961322653924678e-02, 4.4259057451694346e-02, - 1.3245555422671054e-02, -1.1355432725780245e-02, - -1.6423769454471046e-01, 2.1283797622603673e-01, - -7.7771821344734746e-02, -1.1189206863587289e-01], - [-2.0656247944218639e-01, -7.5555152047925872e-02, - 2.1436004480934572e-02, 1.8519822533150174e-01, - -4.7687267679858099e-02, 1.0893715640778658e-01, - 5.4446388557811642e-01, 6.7864355635107079e-02, - 1.8925675037139755e-01, 3.6392773516755073e-02, - -2.4764455183159433e-02, -3.8468294614801751e-02, - -2.8696444635530814e-02, -1.8823021866307067e-02, - 4.8264052464878845e-02, -3.6882747079153497e-02, - -3.0155420938729255e-01, 1.0404831951207047e-02, - 4.4505477004053171e-03, -4.6873846610364103e-02, - 2.4798470273412251e-02, 2.5891733287640804e-02, - 3.5011544817152707e-01, 8.5903050378751358e-02, - -1.6860450574909990e-02, -3.9052038500091160e-02, - -2.9924661599529656e-01, -1.5823886416275893e-03, - 2.8254484941419005e-03, -4.8861168063938747e-03, - 9.7917302635802658e-02, 2.7710576047465570e-02, - 2.3536560145276611e-01, -3.9600571986552502e-01, - -7.4934893198527877e-02, -1.1613827208484023e-01], - [-1.9109089213931946e-01, -8.4666472598656825e-02, - 5.7740802097843921e-02, 1.9626130737187028e-01, - -2.4601756649487860e-01, 8.1511271167717628e-02, - -4.6530930078469529e-01, 6.8795587726048116e-02, - 5.2415554010200038e-02, -1.7332120317563506e-03, - 3.1251731285109323e-02, 1.5521676381926154e-02, - -1.2359815126908288e-01, 2.7460289856811461e-02, - 1.9114633014954776e-02, 2.8966001347205911e-03, - 4.3487864890462036e-01, -2.2957986155413699e-02, - 
-1.5357935266312277e-02, 1.0016152245695723e-02, - -4.5019081491420573e-02, -2.4405778384030734e-02, - -7.4832588748429490e-02, -4.4078616914614753e-02, - 3.0809052034342380e-02, 1.1926634983737788e-01, - -8.1517751909305367e-02, -7.7527914203627396e-02, - -3.7123430398910418e-02, 1.3750979135916276e-02, - -9.7457414231716055e-02, -1.7178991628521816e-02, - 2.1304973749867503e-01, -5.4941011823140218e-01, - -6.7860578570392335e-02, -1.2038447553380759e-01], - [-1.7561930483645249e-01, -8.8342789136092309e-02, - -1.1242590243640400e-02, -1.8652768797207359e-01, - -9.8464009205703876e-02, 1.7256713195193910e-02, - 2.9649268724224581e-01, 5.8780632678962143e-02, - -3.4585362321307522e-01, 7.6907763800451081e-03, - 2.5103268120083535e-02, 2.5393826053803564e-02, - 4.3240349879996420e-01, 3.3310696488693933e-02, - 2.1609140330890370e-02, 1.3951456173138647e-03, - -1.2840968480253712e-01, -3.3248191939129826e-02, - -8.9379099725266672e-04, -1.8994911138723630e-03, - -2.3834826680311980e-02, 4.7502947323282011e-03, - -4.4024121870114297e-01, -6.7327999197165686e-02, - 2.9359383382924452e-02, 9.1479482958182867e-02, - 3.8593300484440007e-01, -4.7958512765110956e-02, - -5.1251961259242168e-02, 1.8636628882937378e-02, - -6.5572564769060912e-02, -2.2887842635462220e-02, - -1.6042006104302377e-02, -3.3250776465128573e-01, - -6.6477273291217359e-02, -1.2463067898277495e-01], - [-1.6014771753358550e-01, -8.3434053708109190e-02, - 1.3638599925185501e-02, -2.4158649874087133e-02, - -1.1124755841847851e-01, 4.2695267715302458e-02, - 1.4866152720116035e-01, 4.9700778378845270e-04, - -3.5326388070491549e-01, -1.5745483283003094e-02, - -8.9738221678782072e-03, 1.0993364411347295e-02, - 1.9527915544397639e-01, 1.3259513825918660e-02, - -3.9339417079053149e-03, -3.7389315402467350e-02, - 3.0825337281314197e-01, 2.9465425388143118e-02, - -1.0086552608467406e-04, -2.1130010935818223e-02, - 2.4746795171351338e-02, 1.2876294127766924e-02, - -1.3542161100061775e-01, 2.3491306500478031e-02, - 2.8381089185132442e-02, 5.0060402655999779e-02, - -4.7990645387633185e-01, 1.7841388064942280e-02, - 3.6163722246352295e-02, 2.2692968040711251e-02, - -1.4881297657765719e-03, -1.1068249840362020e-02, - 4.3250260717661632e-01, 4.5393847466427317e-01, - -6.1116215809998306e-02, -1.2887688243174231e-01], - [-1.4467613023071851e-01, -8.5360329958689612e-02, - 3.6773895176301370e-02, 2.8417567832807769e-04, - -1.4251569175101705e-01, 1.8419541161364662e-02, - 1.4739729008583152e-01, -6.2901931512317516e-02, - -4.3820330673251112e-01, -1.1585923923104585e-01, - -4.6526417840431711e-02, 1.2161556905396271e-02, - -8.3388018002128958e-02, 2.3616237126461999e-02, - -9.1086898933490409e-02, 9.6073985629915787e-02, - 3.0200810799555788e-01, 9.9080289536070815e-02, - 4.9921034650103280e-02, 7.6871969202905246e-02, - -8.3377720121475072e-03, -1.7031625806123534e-02, - 4.5636496936456672e-01, -4.0005637071420394e-02, - -1.9891703100641429e-02, 1.2472945837760744e-02, - 5.9697784009368959e-03, -9.5789228620796370e-03, - 6.8806967828826657e-02, 1.5038487697273856e-01, - 6.8452882565985446e-02, 1.3123694381544091e-02, - -5.6226049096551989e-01, -4.1018946243773058e-02, - -5.6717572380307106e-02, -1.3312308588070965e-01], - [-1.2920454292785150e-01, -6.6253352907543861e-02, - -1.0164436321011842e-01, -1.4433060335444364e-01, - 1.6028176487458967e-01, 3.4584483531135940e-02, - 1.9900533500768001e-02, -5.2164178106233798e-02, - -1.2875710620386896e-01, -1.3038955529948765e-01, - -3.1311992664378889e-02, 2.5299917094429910e-02, - 
-4.1764341929454979e-01, 5.7547077142788963e-02, - -1.1598534347679475e-01, 1.8086109486937549e-01, - -6.3115663671148348e-02, 8.6408791666891471e-02, - 4.0289642159952954e-02, 1.2892059198986330e-01, - -7.5052803928986972e-02, -3.4807004039357006e-02, - 2.0072216849958635e-01, -1.1909118683716058e-01, - -2.6393566026650855e-02, 6.6849035713186178e-02, - 4.7200759534307635e-01, -7.6853961442131774e-02, - 2.6993333821331650e-02, 1.7484304402685918e-01, - 5.3240433359001025e-03, 2.9788042206222785e-03, - 5.1760936987899087e-01, 1.1384037033693235e-01, - -5.1865856323749862e-02, -1.3736928932967699e-01], - [-1.1373295562498452e-01, -5.7235135967154585e-02, - -4.7652965020097103e-02, -1.7627396739100985e-02, - 7.7938405922626644e-02, 2.2087656281477019e-02, - 6.1009605667557178e-03, -5.4981966965685393e-02, - -1.8486086378646865e-01, 3.8911039431433647e-02, - 3.5079519080830110e-02, 1.9272432328556483e-02, - -5.9096451891695889e-01, -7.7247905448605157e-03, - 3.7441325666613741e-02, -4.9165769090891341e-02, - -3.3776276260195798e-01, 1.6606308621317768e-02, - 3.8859102913090936e-02, -1.9047412918711374e-02, - -3.8482634352387676e-02, -4.8755071639337150e-02, - -4.3270527443011519e-01, -9.1999354995766322e-02, - 1.0430914529054176e-01, 1.4978760949122619e-01, - -3.4135100214765429e-01, -2.5289826614278744e-02, - 3.4608873349492607e-02, 8.8085003662463843e-02, - -1.5196825642675141e-01, -9.3051296574294673e-03, - -2.4468277187262805e-01, -2.4348157193486621e-02, - -4.7513567722300747e-02, -1.4161549277864433e-01], - [-9.8261368322117570e-02, -1.6390394385331745e-02, - -5.4742294041798749e-02, -5.8987021949670405e-02, - -1.6882319276059432e-01, 4.3601612172208745e-02, - -2.9911314975774938e-02, 2.3284677199386728e-03, - -3.1808540586289284e-02, 6.9627318822466044e-01, - 1.6271702602637766e-01, 1.5743246880124597e-02, - -4.3195703838658110e-02, -2.2494758789598773e-01, - 7.1399213422553218e-02, -1.3240943946997921e-01, - -8.4980139589052577e-03, -3.2038201094679952e-01, - 6.2407097431780204e-02, -7.6882180114861851e-02, - 2.9470860002467913e-02, -4.2571478756212582e-02, - 2.0163350380724604e-01, -3.2389702717405428e-01, - 6.9711204990479309e-02, -8.1573794801329258e-02, - 1.3304500243627673e-01, 4.0406118875997113e-02, - 8.2477981782237836e-02, -1.1543529624088469e-01, - -1.1014206710642817e-01, 4.2320022953069426e-06, - 3.8041226304310447e-03, 1.3395530894194055e-01, - -3.9467794046677329e-02, -1.4586169622761166e-01], - [-8.2789781019250525e-02, -1.9278711714630567e-01, - 2.2165755909431184e-01, -2.1201546316262262e-01, - 1.4307796989725635e-01, 6.0334342472999250e-02, - -5.5139304406736672e-02, -1.9408969113742302e-02, - 5.4970843704949646e-02, -4.5047658482968128e-01, - -3.3338315762977556e-02, -6.5308425743183532e-02, - 1.4218465309675436e-02, 4.9087218418760230e-02, - 1.8670840217742501e-01, -1.5287462038432642e-01, - -1.3217180940167689e-02, -6.6463048958420534e-02, - 3.8845065361654303e-04, -2.2429929685530131e-01, - -2.6776933696982124e-02, 8.5772405898653856e-02, - 1.1857225379472448e-01, -3.3789334871471582e-01, - 8.3834684881833613e-02, -1.7391265231974168e-01, - -5.9431721332300208e-03, 2.7485104738181495e-02, - 1.6105963634532708e-01, -4.7246605597344127e-01, - -2.3898285645951292e-01, -2.0628986543330220e-02, - -2.1798010578591574e-02, 1.6076906598537423e-02, - -5.4377032852269684e-02, -1.5010789967657906e-01], - [-6.7318193716383562e-02, -1.3247564302860890e-01, - 1.7006921492087917e-01, 1.2398760160260749e-01, - -1.4177630269484331e-01, 1.5422349385403381e-02, - 
-5.9592326716797428e-02, -3.5882053764316857e-02, - -1.7232432793461348e-02, 2.3701488719579314e-01, - -4.6593215018650616e-02, -6.3082282004145299e-02, - -2.0902723950643357e-02, 5.2050993065408405e-02, - -8.0468326155430828e-02, -5.0880717820819980e-02, - -1.1820152914284968e-01, 5.6506976812092713e-01, - -2.1968735055254530e-02, 1.6529598718631755e-01, - 1.0797738052990204e-01, -3.0113303079001008e-02, - 5.5521405735639642e-03, 2.7802427161516047e-01, - -1.3829193596041753e-01, -1.1466435184415830e-01, - 1.1740546330296046e-01, -1.7311150238082029e-01, - -1.6365530586101310e-01, -3.6819727396673907e-01, - -3.1239015782869367e-01, 6.3966770007709506e-02, - -2.6591619532336051e-02, 1.2885889151522636e-01, - -3.7992961598361283e-02, -1.5435410312554640e-01], - [-5.1846606413516585e-02, -6.0477319044140554e-02, - -7.5750638182608219e-03, -1.0624372654415394e-01, - 8.1266486795481985e-02, 4.0180836057036554e-02, - -3.7783670829837974e-02, 4.6289675320758547e-02, - 3.3808855820936547e-02, -1.9195948450068509e-01, - -5.8196442046703094e-02, 1.7282080569685822e-03, - 1.4755965059760449e-02, -6.0959969133142022e-01, - -2.8239274796445768e-01, 1.2486767782495350e-01, - -1.6812624118941352e-02, -3.1637047991210354e-01, - -3.4329518102613220e-02, 2.9658523886210797e-01, - 2.1095830387260842e-01, -7.1581690223787436e-02, - 1.4902746008909057e-02, 2.5118050689616306e-01, - 1.5960904763919231e-01, 1.6146826320314336e-01, - -3.0778528162015331e-02, -6.0781897242040703e-03, - -1.5766062756371724e-01, -2.2924930849571712e-01, - -2.3919944196342770e-02, 4.0432828090792343e-02, - -3.3603315710298294e-02, 6.6005717038430623e-03, - -3.2237412023528290e-02, -1.5860030657451374e-01], - [-3.6375019110649567e-02, -4.6095054123273631e-02, - 4.1487329226456366e-03, -4.9882330119267008e-02, - 2.6789583798631911e-01, 2.8310263556813459e-02, - -5.0744234427433435e-02, -2.1955670997388516e-01, - 8.8814242427478526e-02, 7.2616405945027329e-02, - 3.7105581486243189e-01, 1.3801726499993164e-01, - 1.2228306569610396e-01, -1.8641957679946289e-01, - -1.7746951776518829e-01, 1.1838468893129621e-01, - 4.1434840944853890e-02, 3.4352445701196649e-01, - -1.3539286248067484e-01, 1.2179016223131671e-01, - -1.4481862254120659e-01, -6.0813770391397334e-02, - -9.5024877677197070e-02, -2.6026144416788322e-01, - 6.7007386100264313e-02, -2.7403316717453452e-01, - -1.2940472617950355e-01, -7.0811325772559455e-02, - 1.0283464270665656e-02, -5.0042226650144100e-02, - 3.9567119578457077e-01, -2.3131183910318670e-01, - -2.4438157021422158e-02, -9.5495078814865603e-02, - -3.1811761848109070e-02, -1.6284651002348108e-01], - [-2.0903431807782615e-02, 7.2327502897265056e-02, - -2.1426834420397733e-01, -2.4971807305411563e-02, - -6.8251303361485452e-02, -3.5176957926268708e-03, - -1.7281098595222758e-02, -2.7919893499292525e-01, - -7.5490419998562163e-03, 8.8933532299955390e-02, - -8.3918077552881970e-02, 4.2946166228858822e-02, - -3.5084337029511685e-02, 5.2484778345047800e-01, - -1.3476341073870199e-01, 8.9651093734304757e-02, - -2.6221874920893444e-02, -3.2081171793188057e-01, - -7.0201683149374666e-02, 9.7920337768921742e-02, - -7.6208072805887969e-02, 2.9964575931518713e-02, - 2.1839138515231137e-03, 2.1907625163481245e-01, - 7.8802565386018458e-02, 1.0637722019900711e-01, - -1.5047419808766808e-02, -1.2522929609505140e-01, - 1.0489044814827699e-01, -4.4452472469644072e-01, - 2.5261973738582033e-01, -1.9360753077714768e-01, - -3.0637038971187570e-02, -3.9473390838082588e-04, - -1.0054456334322568e-02, -1.6709271347244839e-01], - 
[-5.4318445049156196e-03, -1.1991560506989501e-01, - 1.6016393502783463e-01, -9.0534713898102900e-02, - 1.7803986653673967e-01, 4.2517830558630100e-02, - -6.5595472901773699e-02, -6.9456352075150884e-02, - 7.9849581869208763e-02, 1.4596149872374808e-01, - -3.7448911148165226e-01, 3.0784697110174092e-02, - 1.0212691273921030e-01, 1.2477201433959939e-01, - -2.1170895978207616e-01, 1.9057503902571590e-01, - -1.9885301263116554e-02, -2.1847437899940467e-01, - -1.3659628076825936e-01, 6.2262165446311392e-02, - -1.9622860693073528e-02, 4.1620399347292121e-02, - -3.1648999142503326e-02, 8.2027519954154221e-02, - -7.9260224219164649e-02, -4.4257777757196498e-01, - -1.0450524222584731e-01, 7.1670676847096298e-02, - 4.6620848245388563e-02, 3.5490360494088574e-01, - -3.4694381436297000e-01, -2.2966638374036538e-01, - -2.1349097951285249e-02, -5.0149218417714851e-02, - -2.8318514185483656e-02, -1.7133891692141581e-01], - [ 1.0039742797951326e-02, 1.4486958501002600e-01, - -3.0487486722127227e-01, 1.2108072885929126e-01, - -1.1723298949673400e-01, 9.6017523703054095e-03, - 4.9883113678426960e-03, 3.2018649396693973e-02, - -4.0095882258820964e-02, -2.4528012104090294e-01, - 6.0349817604330003e-01, -6.0025406492642708e-02, - -1.6146280657180472e-02, 1.5798023347451132e-01, - -1.5035528625979958e-02, -2.2434556029665070e-02, - -2.4354754626807390e-02, -1.5308774844201870e-01, - -1.1065734099847921e-02, 5.1339996940509787e-02, - 1.6396255893983677e-01, 2.4722965810338692e-02, - 9.6017297101513074e-03, 1.6662850312888863e-01, - 9.1395453034799151e-02, -4.2004786665153609e-01, - 3.0226599593042958e-02, 3.3204444593892296e-02, - -9.0545811500522586e-02, 1.1327046229049616e-01, - -2.5108979165208944e-01, 1.2687846708619716e-01, - -2.1404901679780933e-03, 2.9977168343317158e-02, - 5.5400172108409033e-03, -1.7558512037038310e-01], - [ 2.5511330100818325e-02, -5.8698168696025753e-02, - 8.0629703301508024e-02, -7.0612253616157819e-02, - 3.2715731475630602e-02, 2.1732269341780134e-02, - -5.6700795470449199e-02, -6.8235752853351661e-01, - 6.4905178300795938e-02, -3.5862976828251472e-02, - 8.8618413873728166e-02, 3.1550620324006268e-01, - 9.2319437517647415e-02, -1.0599662867975553e-01, - 2.6587503059973538e-01, -1.0545080566473539e-01, - -2.2738440485640277e-02, -6.6368929276419075e-02, - -5.1003071286368440e-02, -1.1626185301232636e-01, - 5.4119363471023328e-02, -2.4882466696968256e-02, - 4.6420092314024886e-02, 1.7831888983094824e-01, - -2.7253935859206135e-01, 1.7198911112035339e-01, - 1.3432430343834192e-02, 7.1000954309573148e-03, - -3.8416339301886476e-03, 1.6384316059667964e-01, - -6.0953258543061287e-02, 2.6960776094017469e-01, - 2.0718992188831518e-02, -2.7614704623654989e-02, - -1.2643038301898243e-02, -1.7983132381935049e-01], - [ 4.0982917403685273e-02, -4.7160343894475723e-02, - 7.8787266856851345e-03, -1.6730572778497552e-01, - 2.7113248408711793e-01, 9.8438763801876154e-03, - 2.2608843153598773e-02, 4.0738411310515976e-01, - 3.2355058682223534e-02, 1.1698920368317291e-01, - 1.4072643414054364e-01, 6.7061453574130916e-02, - -1.8930127519950827e-02, 1.9146087806398635e-01, - -2.4250669817151019e-02, 1.1868698006794093e-01, - 1.0317141879348907e-01, -8.5252634874863287e-02, - -2.8010523433118828e-01, 1.3060583612270180e-01, - -9.9969111180962050e-02, -3.4760563118607063e-02, - -1.7994529116745678e-02, -6.0554676763009442e-02, - -4.6559703882739706e-01, 1.1940676107160293e-01, - -1.0161278374127546e-01, 1.3173327834920193e-01, - -2.2709272071986680e-02, -1.1755702148341549e-01, - 3.7441059930431703e-02, 
4.4164660080364565e-01, - -6.6992110689447992e-02, -2.5301348191003502e-02, - -9.7262032302421250e-03, -1.8407752726831786e-01], - [ 5.6454504706552257e-02, 7.8158541336176779e-02, - -1.4338657014458589e-01, 1.0703741291078765e-01, - -1.3942580377761906e-03, 2.2695174951015635e-03, - -3.8562621975632518e-02, -3.0965063003047144e-01, - 3.7355997032764349e-02, 1.4990453152525209e-02, - -1.1227058245216649e-01, -7.0287795373175999e-01, - 1.1718292741895955e-01, -5.1035967037226390e-02, - -9.4000621055494157e-02, 1.7518267045374700e-01, - -1.4730348981690847e-02, 5.1783743616797537e-02, - 2.1169018058168132e-01, 5.8597372997689870e-02, - -1.6243455966644404e-01, 5.9497378897041750e-02, - -7.3121464646455983e-02, -1.8084067697810838e-01, - -6.6501694611624321e-02, 4.1097079298917809e-02, - -4.3356588698331838e-02, 2.4444891440205574e-01, - 6.5642952335239826e-03, -9.6906979426258765e-03, - 1.8913630981055121e-03, 2.7008769602574367e-01, - 8.8545125037905337e-03, -3.9988001886776758e-02, - 9.3906452280477001e-03, -1.8832373071728517e-01], - [ 7.1926092009419212e-02, 8.0994217906793439e-02, - -2.0767188447365928e-01, -1.5196436606475891e-01, - 1.3077919554196207e-01, -2.1254474743086713e-02, - 4.5019671597743463e-02, 9.6558458919928689e-02, - 1.2420216348711157e-02, -6.2064238471275191e-03, - 9.8956490118614168e-02, -3.2363738790615754e-01, - -3.2870638207842147e-02, -1.5482218310094722e-01, - 2.9647782998980127e-01, -6.1576762109174010e-02, - 1.2666434428081200e-01, 2.1955834692424955e-02, - -1.8997255642944891e-03, -1.0295835477975461e-01, - 1.8208445909004639e-02, -1.1030261882048981e-01, - 4.3794875217006007e-02, 1.8518198489376456e-01, - -4.0747443172392700e-01, 1.3827664021164707e-01, - -4.1431123873109715e-03, -1.4061023435938111e-01, - 1.3942741953117222e-02, 1.9365617058920072e-02, - -8.4489815015323350e-02, -5.7838799828344145e-01, - -2.8902818751484066e-02, -2.4186610549109096e-02, - 1.2086263962861131e-02, -1.9256993416625251e-01], - [ 8.7397679312286217e-02, 1.5064561887342939e-01, - -2.1080556782941462e-01, 1.5916760566958116e-01, - -1.9624826757584166e-01, 1.5198104896205650e-02, - -1.4330248064956560e-02, 3.3068118190946301e-02, - -3.5714352226646290e-02, -1.4260141979380403e-01, - -2.4115477092387741e-01, 3.4101861982281523e-01, - -1.9029646752241479e-03, -2.7699284020832545e-02, - 1.0920088465260440e-01, -1.5239532632222408e-01, - -8.5144012779746134e-02, 5.5970342531910411e-02, - 6.9106614215268647e-02, 2.4036876137100174e-01, - -1.2301443222654272e-01, -1.1953863304856910e-01, - 3.5171852820881193e-03, -2.1104179481631563e-01, - -1.6652675336533382e-01, -6.9825511877400867e-02, - 7.3611503187800218e-03, 5.1349708686040763e-01, - -3.0172148431909446e-01, -1.0589893886410634e-01, - 3.6783462028334960e-03, -2.0553003985674112e-01, - 1.8790472746182015e-02, 1.9823557204917654e-02, - 2.5168461511062466e-02, -1.9681613761521988e-01], - [ 1.0286926661515323e-01, -5.1095768728277327e-02, - 1.3471859461003702e-01, 3.0500821091821676e-02, - -1.6790235354550213e-02, -7.0308669455806175e-03, - -3.0939649438101019e-03, 2.5665199177927620e-02, - 2.1279168221811904e-03, -2.5037640808915945e-02, - -1.2405085129935786e-01, -2.6231150724568519e-01, - -8.5787446133464614e-03, 3.9627338244596369e-02, - 2.3267441336286346e-01, -4.0743293242468487e-01, - 2.4149661576382757e-04, -6.3680910375172040e-02, - -4.3805185403053759e-01, 2.0300111728111647e-01, - -2.1099142295899803e-01, -3.4325637130492054e-01, - 2.4798870388207689e-02, 5.8652422232119368e-02, - 3.1273508409742873e-01, 
-6.5663309732651248e-02, - 8.4976320234436575e-02, -1.2972698624062320e-01, - -1.0136590956706468e-01, 1.7606369902531008e-01, - 1.7776135567204221e-01, 1.0742707779456324e-01, - -7.9052346006256245e-03, 7.3493627583932908e-02, - 9.9131943085618447e-03, -2.0106234106418724e-01], - [ 1.1834085391802023e-01, 8.1061946874736585e-02, - -1.6265342280467382e-01, -2.5856375159094996e-01, - 1.4258244531423583e-01, -2.5799424990869069e-02, - 1.9638649342146815e-02, 7.3355921016709083e-02, - 5.9394009978013036e-02, 1.5655633426552953e-01, - -9.8792934500238835e-02, 9.9575902680803088e-02, - 1.8527367488061958e-02, 6.3288806058580380e-02, - 3.7739330071097632e-01, 3.9157302813010320e-01, - 1.3485974151563190e-01, 2.4396726581112591e-01, - 3.6171829433890815e-02, -1.5329124928290030e-01, - 1.0994295285071572e-01, -6.2470988682208468e-02, - 7.2649015124010521e-02, 1.4656583051512045e-01, - 5.0160574613932607e-01, 4.7267639935224738e-02, - -4.2965682291764895e-02, 1.8881658695211850e-01, - -1.0776584277343945e-01, -2.6754374009298049e-02, - -7.7009726198669998e-02, 8.6417047403639091e-02, - -5.3833621971674586e-03, -8.0918819205681225e-02, - 2.1780800232539175e-02, -2.0530854451315456e-01], - [ 1.3381244122088717e-01, 1.5997082437941978e-01, - -2.1906335649966574e-01, 2.3332171765159351e-01, - -6.4994730069703827e-02, -2.7137179321886296e-02, - 4.4299490835366419e-02, 5.4082161016101568e-02, - -6.1822856454263338e-02, -6.6517101749567792e-02, - -2.9376460130324589e-01, 1.1103413626514062e-01, - -3.3806550575053815e-02, -1.8397686746205080e-01, - 3.9400318507963744e-02, 1.8758272608343995e-01, - 2.1898570040268548e-02, -5.7258401311702969e-02, - -1.2054652895121606e-01, -3.3785342949153890e-01, - -3.9112933378476634e-02, 1.2987622324621689e-01, - -7.4850924489854642e-02, -1.8237325410753219e-01, - -1.1058781873480500e-01, -2.0595217802395629e-01, - 7.5757742040963461e-03, -5.3655875317100610e-01, - -2.0896258914648322e-01, -5.5945308120122161e-02, - 8.2455318541596961e-02, 1.7624602710846482e-01, - -2.2489297400574856e-02, 5.2915934277181324e-02, - 3.8152138968863464e-02, -2.0955474796212192e-01], - [ 1.4928402852375416e-01, -4.7103999084602964e-02, - 1.5843017378423407e-01, -1.0471529213101267e-01, - 4.1822224430947852e-02, 4.9674575956627585e-03, - -1.3311898606966285e-03, 4.8322275176183468e-02, - 2.6782623911085109e-02, 1.3647784270166637e-02, - 1.0980857986376788e-01, -5.0748588072257886e-04, - -1.0361251293227987e-02, 1.1049141088458188e-01, - -4.7174567274205670e-01, -2.0220954115377396e-01, - 1.3182956708179594e-02, -1.1843903142333311e-02, - 2.0088578029524848e-01, -5.3080319758187777e-01, - -1.6308626968204651e-01, -1.6901681485606096e-01, - 7.0269705034495436e-02, 9.8708103667137601e-02, - 5.8906260202682963e-02, 1.3406466835766842e-01, - 1.3927440769859889e-02, 9.2483635015958410e-02, - -4.1489874017913597e-01, 5.3520424215223954e-02, - 3.3087563030626183e-02, -4.3491644319790287e-02, - -1.4259433018598195e-05, 8.4993306168228474e-03, - 1.9440725644047020e-02, -2.1380095141108932e-01], - [ 1.6475561582662113e-01, 1.5611161975472390e-01, - -2.2981567498448408e-01, -2.5170242091030143e-01, - 1.2572164509633985e-01, -3.5101394036068920e-02, - 1.4388788465769620e-02, 6.4367254285863956e-02, - 7.9127393952476463e-02, 4.7770236979792664e-02, - -1.4967962998375717e-01, 9.6657597995555136e-02, - 2.8600846275685401e-02, 1.0247100377903102e-03, - -1.7416826445456965e-01, -5.2903452729155642e-01, - 1.2378709088794008e-01, 1.6002483124980629e-01, - 2.3117191956384286e-01, 2.0936710152049257e-01, - 
8.2739337123958492e-02, 3.7851995698789648e-01, - 1.9060641335918893e-02, -6.4314540668445608e-03, - 8.4778867125413743e-02, 1.6232730574310308e-02, - -5.8776952506303194e-02, -1.6317833767006093e-01, - 2.0541131812472332e-01, 9.4709191370766388e-02, - -3.0776520624173034e-02, 1.1938827858311649e-01, - 1.2517716200189802e-02, -1.3352132837280375e-01, - 3.8021934168930759e-02, -2.1804715486005660e-01], - [ 1.8022720312948814e-01, 8.6827318631575279e-02, - -7.4501227114099414e-02, 1.2876209736226957e-01, - -2.2037890384696301e-01, -2.5814842105572621e-02, - -3.1406758090893994e-02, 9.6294241223690305e-02, - 1.8240072112824506e-02, -7.8775576899911090e-02, - -3.0389264268442007e-02, 8.6684499299869738e-02, - 5.4365532030452843e-02, -1.1850090448995039e-01, - -2.4574663651253167e-01, 5.0606647353540021e-02, - -1.1179254494673002e-01, 1.5746625930135386e-01, - -2.3653025671773734e-01, -2.4326576699636770e-01, - 5.1089622549619594e-02, -2.8901934374460203e-01, - -2.6451534372339578e-02, 5.4045829899974578e-02, - 7.5844174532653701e-03, 9.5261278786040723e-02, - 6.5117432591824925e-02, 1.5374072905554484e-01, - 6.6944827374030014e-01, 2.5045538719576737e-03, - -5.5672913354967879e-02, 1.2051210553600417e-02, - 3.3658431259863966e-02, -3.1395677687489406e-03, - 4.7661017511192831e-02, -2.2229335830902394e-01], - [ 1.9569879043235514e-01, -9.1769753653107577e-02, - 2.7141769527027171e-01, 2.2785564717029946e-01, - 6.4057719170856758e-02, -3.7788206214948872e-03, - 9.7259287514508460e-03, 1.6918261328737952e-01, - 6.8155784376799586e-04, -1.4846652373116371e-02, - 9.7665427524227605e-02, 1.4020779957899679e-01, - 5.4803013440760974e-02, -3.7770889485239614e-02, - 2.0161818269196646e-01, 1.3431772896192445e-01, - -2.2780324141178667e-02, -1.3299949529057514e-01, - 5.6952253822862586e-01, 1.7551693338628394e-01, - -3.8851158821630960e-01, -8.2597118671349307e-02, - -5.5521724833590726e-02, 1.8126259477529724e-01, - 1.7814975368438311e-02, -6.5528218308503153e-02, - 3.7971760553771383e-02, -1.5071623691597721e-01, - 2.1592446351812103e-01, -5.6402536331480002e-04, - 4.5088070248228272e-02, 2.6712876881033590e-02, - -5.4087768899409383e-03, 6.8686308808012492e-02, - 3.2287080492312645e-02, -2.2653956175799139e-01], - [ 2.1117037773522207e-01, -5.0164247531242157e-02, - 2.6588099000556803e-01, 9.2461134185888125e-02, - -1.8638912752062822e-01, -1.3326201088302150e-02, - -1.5139012219398481e-02, 5.6526342555140038e-02, - -2.1347405801495557e-02, 4.2134620640229903e-03, - 1.6189227618448768e-01, -4.0274584225345120e-02, - -5.6430110607539385e-02, -5.8413256975427548e-02, - 5.2327365554425583e-02, 1.0547316593589447e-01, - -1.0141590903757328e-01, 2.2750086641208328e-03, - -2.9965053997941909e-01, 1.5580924251156411e-03, - -9.8801397992561726e-02, 7.0133690173366392e-01, - 2.9288631311505543e-02, 3.2187639373342534e-02, - 8.5847997795661615e-02, 2.0571325754758280e-01, - 7.4079833507648560e-02, 1.5568547966076893e-01, - -4.9689302197244593e-02, 7.8435365554783448e-02, - 4.8351735020509205e-02, -1.7685071128733182e-01, - 6.5889048949493989e-03, 8.0297089881752479e-02, - 3.9088810533135447e-02, -2.3078576520695873e-01], - [ 2.2664196503808903e-01, -3.6435359235223168e-02, - 2.7461198824493543e-01, 8.5347376974543573e-02, - -2.1059797477235808e-02, 1.1448326379020789e-02, - -2.6592754399652377e-02, 2.5891172442431810e-02, - 2.8366243844641929e-02, -2.0536075588459556e-02, - 6.6444382000443650e-05, -6.6068428617317751e-02, - 2.3676624954254568e-02, 2.2112015932022797e-01, - 3.6011261258148117e-02, 
6.3110902119789564e-02, - -6.5129709470743133e-02, -4.8955274099800709e-02, - 1.5625642089103450e-01, 1.1336968441478927e-01, - 7.1887047535547766e-01, -1.4060033754799098e-01, - -4.3732646616641863e-02, -2.9113406474813336e-01, - -5.4252028224128682e-02, 8.5563234976626823e-02, - -9.8842092892354998e-03, -8.6014269752744857e-02, - -5.3867992496449059e-02, 1.0226004671603665e-01, - 2.0616418999784455e-01, -6.6321426514466278e-02, - 1.7485733797709232e-02, 1.0373147806260606e-02, - 3.9178042791043720e-02, -2.3503196865592610e-01]]), array([-1.8988227080038084e+03, -8.1652460579197793e-12, - -6.8293671717855184e-12, -5.0961343548435651e-12, - -4.6422244875241180e-12, -4.0432649621797409e-12, - -4.6750947941168519e-13, -4.2866623066103143e-13, - -3.9638626555876315e-13, -3.4647469398250028e-13, - -3.2765729675497798e-13, -3.0727463002427591e-13, - -2.9879803908775378e-13, -2.4080245315867009e-13, - -2.1775959053373055e-13, -1.8534745675222213e-13, - -1.5959779217062472e-13, -1.0879546752449559e-13, - -9.0067575069985811e-14, -5.3973885458936187e-14, - -4.6064162488080463e-14, 6.1429074771130427e-15, - 1.3659631287864453e-14, 3.4753391317142145e-14, - 8.7547004653142170e-14, 1.2585089324337818e-13, - 1.5745245909745148e-13, 2.0606204849135956e-13, - 2.1792577470203850e-13, 2.6674476798831050e-13, - 3.0421425292401405e-13, 3.1193691330212636e-13, - 3.1270969371399125e-13, 4.3446674157388007e-13, - 1.6764394233642590e-12, 2.5208822708003838e+04])), - mlir_module_text=r""" -module @jit__lambda_ { - func.func public @main() -> (tensor<36x36xf64> {jax.result_info = "[0]"}, tensor<36xf64> {jax.result_info = "[1]"}) { - %0 = stablehlo.iota dim = 0 : tensor<1296xf64> - %1 = stablehlo.reshape %0 : (tensor<1296xf64>) -> tensor<36x36xf64> - %2 = stablehlo.transpose %1, dims = [1, 0] : (tensor<36x36xf64>) -> tensor<36x36xf64> - %3 = stablehlo.add %1, %2 : tensor<36x36xf64> - %4 = stablehlo.constant dense<2.000000e+00> : tensor - %5 = stablehlo.broadcast_in_dim %4, dims = [] : (tensor) -> tensor<36x36xf64> - %6 = stablehlo.divide %3, %5 : tensor<36x36xf64> - %7 = call @tril(%6) : (tensor<36x36xf64>) -> tensor<36x36xf64> - %8 = stablehlo.custom_call @cusolver_syevd(%7) {api_version = 2 : i32, backend_config = "\01\00\00\00\00\00\00\00\01\00\00\00$\00\00\00Y\98\00\00", operand_layouts = [dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>, dense<0> : tensor<1xindex>]} : (tensor<36x36xf64>) -> tuple, tensor<36xf64>, tensor, tensor<39001xf64>> - %9 = stablehlo.get_tuple_element %8[0] : (tuple, tensor<36xf64>, tensor, tensor<39001xf64>>) -> tensor<36x36xf64> - %10 = stablehlo.get_tuple_element %8[1] : (tuple, tensor<36xf64>, tensor, tensor<39001xf64>>) -> tensor<36xf64> - %11 = stablehlo.get_tuple_element %8[2] : (tuple, tensor<36xf64>, tensor, tensor<39001xf64>>) -> tensor - %12 = stablehlo.get_tuple_element %8[3] : (tuple, tensor<36xf64>, tensor, tensor<39001xf64>>) -> tensor<39001xf64> - %13 = stablehlo.constant dense<0> : tensor - %14 = stablehlo.broadcast_in_dim %13, dims = [] : (tensor) -> tensor - %15 = stablehlo.compare EQ, %11, %14, SIGNED : (tensor, tensor) -> tensor - %16 = stablehlo.broadcast_in_dim %15, dims = [] : (tensor) -> tensor<1x1xi1> - %17 = stablehlo.constant dense<0x7FF8000000000000> : tensor - %18 = stablehlo.broadcast_in_dim %17, dims = [] : (tensor) -> tensor<36x36xf64> - %19 = stablehlo.broadcast_in_dim %16, dims = [0, 1] : (tensor<1x1xi1>) 
-> tensor<36x36xi1> - %20 = stablehlo.select %19, %9, %18 : tensor<36x36xi1>, tensor<36x36xf64> - %21 = stablehlo.broadcast_in_dim %15, dims = [] : (tensor) -> tensor<1xi1> - %22 = stablehlo.constant dense<0x7FF8000000000000> : tensor - %23 = stablehlo.broadcast_in_dim %22, dims = [] : (tensor) -> tensor<36xf64> - %24 = stablehlo.broadcast_in_dim %21, dims = [0] : (tensor<1xi1>) -> tensor<36xi1> - %25 = stablehlo.select %24, %10, %23 : tensor<36xi1>, tensor<36xf64> - return %20, %25 : tensor<36x36xf64>, tensor<36xf64> - } - func.func private @tril(%arg0: tensor<36x36xf64>) -> tensor<36x36xf64> { - %0 = stablehlo.iota dim = 0 : tensor<36x36xi32> - %1 = stablehlo.constant dense<0> : tensor - %2 = stablehlo.broadcast_in_dim %1, dims = [] : (tensor) -> tensor<36x36xi32> - %3 = stablehlo.add %0, %2 : tensor<36x36xi32> - %4 = stablehlo.iota dim = 1 : tensor<36x36xi32> - %5 = stablehlo.compare GE, %3, %4, SIGNED : (tensor<36x36xi32>, tensor<36x36xi32>) -> tensor<36x36xi1> - %6 = stablehlo.constant dense<0.000000e+00> : tensor - %7 = stablehlo.broadcast_in_dim %6, dims = [] : (tensor) -> tensor<36x36xf64> - %8 = stablehlo.select %5, %arg0, %7 : tensor<36x36xi1>, tensor<36x36xf64> - return %8 : tensor<36x36xf64> - } -} -""", - mlir_module_serialized=b"ML\xefR\x03MLIRxxx-trunk\x00\x01-\x05\x01\x05\x01\x03\x05\x03\x1d\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x1d\x1f!\x03^\x02\xeb5\x01\x95\x0f\x17\x13\x07\x0f\x0b\x0b\x0b\x0b\x0b\x17\x0b\x0b\x0b\x0b\x13\x0b\x13\x0f\x0b\x0b\x17\x0f\x13\x13\x0b33\x0b\x0f\x0b\x0b\x13\x0f\x0b\x1b\x0f\x0b\x13\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x13\x0b\x0f\x0b\x0f\x0b\x13\x0b\x13\x0bK\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x13\x13\x13\x1b\x13\x13\x03W\x0b\x0b\x0f\x0b\x0bO/\x0b\x13\x13\x0b\x13\x0b\x0b\x0b\x0b\x0b\x0b\x0f\x1f\x0f\x0f\x0b/O/\x0b\x0b\x0b\x0b\x0f\x0f\x17\x1b\x0f\x0f\x0f\x0f\x0f\x0b/O/\x035\x17\x0f\x07\x0f\x07\x13\x07\x07\x17\x07\x17\x13\x1b\x17\x17\x13\x17\x1b\x13\x13\x13\x0f\x17\x13\x13\x13\x02\xa6\x08\x1d\x85\x03\x17\x11R\x04\x01\x03\x03\x13\xbd\x1f\x1d9\x03\x05#\x05%\x05'\x05)\x05+\x17\x11N\x04\x01\x05-\x05/\x051\x053\x03\x03!\xb9\x055\x03\x03\x0b\xbb\x1d?\x03\x057\x059\x17\x11F\x04\x01\x1dm\x15\x03\x03\x0b\xe5\x03\x03\x0f3\x05;\x03\x0b\x17\x95\x19\xa3\x1b\xa5\x0f\xaf\x1d\xb1\x03\x0b\x17\x99\x19\xb5\x1b\x99\x0f\x9b\x1d\xb7\x05=\x1d=\x03\x05?\x05A\x03\x03!\xbf\x1dE\x03\x05C\x03\x05'\x9d)\xc1\x1dK\x03\x05E\x03\x03\x0b\xc3\x1dQ\x03\x05G\x1dU\x03\x05I\x1dY+\x05K\x1d]+\x05M\x03\x03a\xc5\x05O\x1de\x15\x05Q\x1di\x15\x05S\x03\x03\x0b\xc7\x05U\x03\x03q\x9b\x05W\x03\x11u\xc9w\xcby\xcd{\x95}\xcf\x7f\xd1\x81\xd3\x83\xd7\x05Y\x05[\x05]\x05_\x05a\x05c\x05e\x05g\x05i\x03\x03\r\xdb\x03\x03\r\xdd\x03\x03\r\xdf\x03\x03\r\xe1\x03\x05'\x9d)\xe3\x03\x03\x13\xe7\x03\x03\x13\xe9\x03\x01\x1dk\x03\x03\xb3\x1dm\t\x07\x1f%!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f'\x11\x00\x00\x00\x00\x00\x00\x00\x00#\x1b\x03\x05\xa7\xab\r\x03\x97\xa9\x1do\r\x03\x97\xad\x1dq\x1ds\x1du\r\x01#\x1d\x1dw\x13\r\x01\x1f\x07\t\x00\x00\x00\x00\x1f\x1f\x01\x13\r\x05\x07\x05\x1f\x03\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x17!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x03\x11\x00\x00\x00\x00\x00\x00\x00@\x0b\x05\x1dy\x1d{\x05\x01\x03\x03\x9f\x03\x03\xd5\x15\x03\x01\x01\x01\x03\t\x9f\xa1\xd9\xa1\x1f)\x01\x13\x05\x01\x13\x05\x05\x13\x05\t\x13\x05\r\x07\x01\x1f\x03\x11\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f\x17!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f3\x11\x00\x00\x00\x00\x00\x00\x00\x00)\x05\x91\x91\t)\x01\t\x1b)\x01\x05\x0b)\x03\x91\t\x1d\x0
1)\x05\x91\x91\x05\x13)\x05\x91\x91\x0f)\x03\t\r)\x03\x94\x85\t\t\x11\x01\x05\x01\x0b\x11\x03\x01\x03\x01)\x03\x01\r)\x03\x82(\t/\t\x01\x0b\x07\x19)\x03\t\x13)\x03\x05\x13)\x03\x01\x13)\x01\x0f)\x05\x05\x05\x0f)\x03\x05\x0f)\x03\x91\x0f)\x03\x05\r\x04\xc6\x04\x05\x01\x11\x071\x07\x03\x01\t\r\x11\x075\x05\x035m\t\x03W\x1f\x03!\x15\x06[\x03\x01\x03\x01\x17\x07c_\x03\x01\x03\x03\x0f\x06g\x03\x01\x05\x03\x05\x05\x03\x07k\x03\x03\x03\x07-\x05\x03\x01\x03\t\x19\x06-\x03\x01\x05\x07\x0b\x1b\x07\to\x03\x01\x03\r\x1d\x07\x01s\x03#\x03\x0f\x07\x07\x01\x87\x03\x01\x03\x11\x07\x07\x01\x89\x03\x0b\x03\x11\x07\x07\x01\x8b\x03\x07\x03\x11\x07\x07\x01\x8d\x03\x19\x03\x11\x05\x03\x01#\x03\x07\x03\x07\x01\x05\x03\x07\x03\x1b\x11\x07\x01\x8f\x03+\x05\x17\x1d\x03\x07\x01\x05\x03-\x03\x1f\x05\x03\x01/\x03\x03\x03\x07\x01\x05\x03\x01\x03#\x03\x07\x01\x91\x03\x15\x03!\x0b\x06\x01\x03\x01\x07'\x13%\x03\x07\x01\x05\x03/\x03\x1f\x05\x03\x01/\x03\x03\x03\x07\x01\x05\x03\x0b\x03-\x03\x07\x01\x93\x031\x03+\x0b\x06\x01\x03\x0b\x071\x15/\x13\x04\x07\x05)3\r\x11\t7\x05\x03\x15+\x03\x01\x07\t\x03;\x1f\x03\x11\x05\x03\t#\x03\x07\x03\x07%\x05\x03\x11\x03\x05\x0f\x06%\x03\x11\x05\x03\x07\t\x03CA\x03\x11\x11\x07IG\x03\x15\x05\t\x0b\x05\x03\tM\x03\x03\x03\x07O\x05\x03\x01\x03\x0f\x0b\x06S\x03\x01\x07\r\x01\x11\x13\x04\t\x03\x13\x06\x03\x01\x05\x01\x00.\x1a}\x1f+\x11\x0f\x0b\t\t\x0b!\x7f\x1f/!!)#\x1f\x19\x0f99m\x19\x89\x8dW\xb7K\x9fM\x9f\x96\x04\x1b+\x1b\x1f\x1f\x15\x1d\x15+\x83\x13\r\r\x1f\x11\x15\x1b\x17\x15\x17\x0f\x11\x15\x11+\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00get_tuple_element_v1\x00iota_v1\x00select_v1\x00func_v1\x00add_v1\x00compare_v1\x00return_v1\x00reshape_v1\x00transpose_v1\x00divide_v1\x00call_v1\x00custom_call_v1\x00value\x00index\x00sym_name\x00third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py\x00broadcast_dimensions\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00compare_type\x00comparison_direction\x00jit__lambda_\x00jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) resource_env=None donated_invars=(False,) name=tril in_positional_semantics=(<_PositionalSemantics.GLOBAL: 1>,) out_positional_semantics=_PositionalSemantics.GLOBAL keep_unused=False inline=False]\x00jit()/jit(main)/jit(tril)/iota[dtype=int32 shape=(36, 36) dimension=0]\x00jit()/jit(main)/jit(tril)/add\x00jit()/jit(main)/jit(tril)/iota[dtype=int32 shape=(36, 36) dimension=1]\x00jit()/jit(main)/jit(tril)/ge\x00jit()/jit(main)/jit(tril)/broadcast_in_dim[shape=(36, 36) broadcast_dimensions=()]\x00jit()/jit(main)/jit(tril)/select_n\x00jit()/jit(main)/iota[dtype=float64 shape=(1296,) dimension=0]\x00jit()/jit(main)/reshape[new_sizes=(36, 36) dimensions=None]\x00permutation\x00jit()/jit(main)/transpose[permutation=(1, 0)]\x00jit()/jit(main)/add\x00jit()/jit(main)/div\x00callee\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00jit()/jit(main)/eigh[lower=True sort_eigenvalues=True]\x00jax.result_info\x00tril\x00[0]\x00[1]\x00main\x00public\x00private\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00$\x00\x00\x00Y\x98\x00\x00\x00cusolver_syevd\x00", - xla_call_module_version=4, - ) # End paste -) - data_2024_09_30 = {} data_2024_09_30["f32"] = dict( diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/cuda_lu_pivots_to_permutation.py 
b/jax/_src/internal_test_util/export_back_compat_test_data/cuda_lu_pivots_to_permutation.py index 12285a45b77a..8063d9f44722 100644 --- a/jax/_src/internal_test_util/export_back_compat_test_data/cuda_lu_pivots_to_permutation.py +++ b/jax/_src/internal_test_util/export_back_compat_test_data/cuda_lu_pivots_to_permutation.py @@ -16,11 +16,11 @@ from numpy import array, int32 # Pasted from the test output (see export_back_compat_test_util.py module docstring) -data_2024_08_08 = dict( +data_2025_04_01 = dict( testdata_version=1, platform='cuda', custom_call_targets=['cu_lu_pivots_to_permutation'], - serialized_date=datetime.date(2024, 8, 8), + serialized_date=datetime.date(2025, 4, 1), inputs=(), expected_outputs=(array([[[0, 1, 2, 3, 4, 5, 6, 7], [4, 5, 6, 7, 0, 1, 2, 3], @@ -31,25 +31,22 @@ [0, 1, 2, 3, 4, 5, 6, 7]]], dtype=int32),), mlir_module_text=r""" module @jit__lambda_ attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main() -> (tensor<2x3x8xi32> {jax.result_info = "", mhlo.layout_mode = "default"}) { + func.func public @main() -> (tensor<2x3x8xi32> {jax.result_info = "result"}) { %0 = stablehlo.iota dim = 0 : tensor<24xi32> loc(#loc4) %1 = stablehlo.reshape %0 : (tensor<24xi32>) -> tensor<2x3x4xi32> loc(#loc5) - %c = stablehlo.constant dense<2> : tensor loc(#loc6) - %c_0 = stablehlo.constant dense<3> : tensor loc(#loc6) - %c_1 = stablehlo.constant dense<4> : tensor loc(#loc6) - %2 = stablehlo.custom_call @cu_lu_pivots_to_permutation(%1) {mhlo.backend_config = {permutation_size = 8 : i32}, operand_layouts = [dense<[2, 1, 0]> : tensor<3xindex>], result_layouts = [dense<[2, 1, 0]> : tensor<3xindex>]} : (tensor<2x3x4xi32>) -> tensor<2x3x8xi32> loc(#loc6) + %2 = stablehlo.custom_call @cu_lu_pivots_to_permutation(%1) {mhlo.backend_config = {}, mhlo.frontend_attributes = {num_batch_dims = "2"}, operand_layouts = [dense<[2, 1, 0]> : tensor<3xindex>], result_layouts = [dense<[2, 1, 0]> : tensor<3xindex>]} : (tensor<2x3x4xi32>) -> tensor<2x3x8xi32> loc(#loc6) return %2 : tensor<2x3x8xi32> loc(#loc) } loc(#loc) } loc(#loc) #loc = loc(unknown) -#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":347:26) -#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":347:14) -#loc3 = loc("third_party/py/jax/tests/export_back_compat_test.py":348:11) -#loc4 = loc("jit()/jit(main)/iota[dtype=int32 shape=(24,) dimension=0]"(#loc1)) -#loc5 = loc("jit()/jit(main)/reshape[new_sizes=(2, 3, 4) dimensions=None]"(#loc2)) -#loc6 = loc("jit()/jit(main)/lu_pivots_to_permutation[permutation_size=8]"(#loc3)) +#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":408:26) +#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":408:14) +#loc3 = loc("third_party/py/jax/tests/export_back_compat_test.py":409:11) +#loc4 = loc("jit()/jit(main)/iota"(#loc1)) +#loc5 = loc("jit()/jit(main)/reshape"(#loc2)) +#loc6 = loc("jit()/jit(main)/lu_pivots_to_permutation"(#loc3)) """, - 
mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x01\x1d\x05\x01\x03\x01\x03\x05\x03\r\x07\t\x0b\r\x0f\x11\x03\xa7}\x17\x01Q\x0f\x07\x0b\x0b\x0f\x0b+\x0b\x0f\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x13\x0b\x0f\x0b\x17\x0f\x0b\x17\x13\x0b\x17\x13\x13S\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x03-\x0b\x0b\x0f\x0b\x0f\x1b\x0b\x0b\x0b\x0b\x0b\x0f///\x0b\x0b\x0b\x13\x0b\x0fo\x01\x05\x0b\x0f\x03\x13\x0f\x07\x1b\x07\x13\x13\x1b\x13\x07\x02Z\x04\x1d57\x1f\x05\x13\x05\x15\x11\x03\x05\x05\x17\x03\t\x0f\x11\x13\t\x15\t\x0b\x17\x05\x19\x11\x01\x00\x05\x1b\x05\x1d\x05\x1f\x03\x0b\x1bQ\x1dW\x1fY\x0bc!e\x05!\x05#\x05%\x05'\x03\x03%g\x05)\x1d)+\x05+\x17\x05n\x055\x1d/1\x05-\x17\x05n\x05\x1d\x03\x03\x07i\x05/\x17\x05r\x05\x17\x03\x03\x07k\x03\x03\x07m\x03\x13?oASCqEQGsIuKUMQOU\x051\x053\x055\x057\x059\x05;\x05=\x05?\x05A\x03\x01\x1dC\x03\x03{#\r\x03\x03[\r\x05]S_a\x1dE\x1dG\x1dI\x1dK\x1dM\x13\x0b\x01\x1f\x05\x11\x02\x00\x00\x00\x00\x00\x00\x00\x1f\x05\x11\x03\x00\x00\x00\x00\x00\x00\x00\x1f\x05\x11\x04\x00\x00\x00\x00\x00\x00\x00\x0b\x03\x1dO\x05\x01\r\x03wy\x1dQ\x13\x07!\x1f\x131\x02\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\t\x01\x02\x02)\x01\x0b\x1b)\x07\t\r!\x07\x1d\x11\x01\x03\t)\x03a\x07)\x07\t\r\x11\x07)\x03\r\x15\x13\x04{\x05\x01\x11\x03\r\x07\x03\x01\x05\x05\x11\x03\x19\x07\x03\r\x1d\x07\x03'#\x03\x0f\t\x06-\x03\x11\x03\x01\x03\x03\x013\x03\x05\x03\x03\x019\x03\x05\x03\x03\x01;\x03\x05\x0b\x07\x01=\x03\t\x03\x03\r\x04\x03\x03\x0b\x06\x03\x01\x05\x01\x00f\x0cS#9\x0f\x0b\x11#!\x03\x1f/!)!)#\x1f\x19\x8b\x8b\x85\x1f\x1f\x15\x1d\x15\x1b%)9\x13\ri\x15\x1f\x17\x11\x11\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00func_v1\x00iota_v1\x00reshape_v1\x00custom_call_v1\x00return_v1\x00third_party/py/jax/tests/export_back_compat_test.py\x00value\x00sym_name\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit__lambda_\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00jit()/jit(main)/iota[dtype=int32 shape=(24,) dimension=0]\x00jit()/jit(main)/reshape[new_sizes=(2, 3, 4) dimensions=None]\x00jit()/jit(main)/lu_pivots_to_permutation[permutation_size=8]\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00mhlo.backend_config\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00\x00jax.result_info\x00mhlo.layout_mode\x00default\x00main\x00public\x00cu_lu_pivots_to_permutation\x00permutation_size\x00", + 
mlir_module_serialized=b"ML\xefR\rStableHLO_v1.9.3\x00\x01\x1d\x05\x01\x05\r\x01\x03\x0b\x03\x0b\x0f\x13\x17\x1b\x1f\x03yQ\x15\x01+\x07\x0b\x0f#\x0b\x0f\x0b\x0b\x0b\x0f\x0b\x17\x0f\x0b\x17\x1b\x0b\x0b\x0f\x0b\x17\x03'\x0b\x0f\x0b\x0f\x13\x0b\x0b\x0b\x0b\x0f\x0b\x13\x0b\x0b\x0b\x0b\x0b\x0bo\x01\x05\x0b\x0f\x03\x11\x07\x1b\x13\x13\x07\x1b\x13\x07\x02\x9e\x02\x1f\x05\x11\x11\x03\x05\x03\x07\t\x0b\r\x05\x0f\x05\x05\x13\x11\x01\x00\x05\x15\x05\x17\x05\x19\x1d\x15\x17\x05\x1b\x17\x03b\x065\x1d\x1b\x1d\x05\x1d\x17\x03b\x06\x1d\x03\x05!?#A\x05\x1f\x05!\x1d')\x05#\x17\x03f\x06\x17\x03\x01\x03\x03O#\t\x03\x033\r\x0357\x1d%\x1d'\x1d)\x1d+\x13\r\x01\r\x01\r\x03CE\x1d-\x1d/\x0b\x03\x1d1\x1d3\x05\x01\x1f\x111\x02\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\t\x01\x02\x02\x1b)\x07\t\r!\x05\x11\x01\x03\x07)\x03a\x05\x1d)\x07\t\r\x11\x05)\x03\r\x13\x13\x04c\x05\x01Q\x01\x07\x01\x07\x04Q\x03\x01\x05\x03P\x01\x03\x07\x04=\x03\x07\x11\x05B\x13\x05\x03\x0b\x07\x06\x19\x03\x0f\x03\x01\tG%\x1f\x07\x03\x07\x03\x03\x0b\x04\x01\x03\x05\x06\x03\x01\x05\x01\x00J\x0759\x03\x05\x1f\x0f\x0b\x0f!c3)A;\x1b%)9i\x15\x1f\x17\x11\x11\x0f\x0b\x11builtin\x00vhlo\x00module\x00func_v1\x00iota_v1\x00reshape_v1\x00custom_call_v1\x00return_v1\x00third_party/py/jax/tests/export_back_compat_test.py\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit__lambda_\x00jit()/jit(main)/iota\x00jit()/jit(main)/reshape\x00mhlo.backend_config\x00mhlo.frontend_attributes\x00jit()/jit(main)/lu_pivots_to_permutation\x00jax.result_info\x00result\x00main\x00public\x00num_batch_dims\x002\x00\x00cu_lu_pivots_to_permutation\x00\x08+\t\x05#\x01\x0b+/19;\x03=\x11GIK+M-+-", xla_call_module_version=9, nr_devices=1, ) # End paste diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/cuda_qr_cusolver_geqrf.py b/jax/_src/internal_test_util/export_back_compat_test_data/cuda_qr_cusolver_geqrf.py index be5c6e01f8d8..00ced41a0492 100644 --- a/jax/_src/internal_test_util/export_back_compat_test_data/cuda_qr_cusolver_geqrf.py +++ b/jax/_src/internal_test_util/export_back_compat_test_data/cuda_qr_cusolver_geqrf.py @@ -15,149 +15,10 @@ # ruff: noqa import datetime -from numpy import array, float32, float64, complex64, complex128 +from numpy import array, float32, complex64 -data_2023_03_18 = {} - -# Pasted from the test output (see back_compat_test.py module docstring) -data_2023_03_18["unbatched"] = dict( - testdata_version=1, - platform='cuda', - custom_call_targets=['cusolver_geqrf', 'cusolver_orgqr'], - serialized_date=datetime.date(2023, 3, 18), - inputs=(), - expected_outputs=(array([[ 0. 
, 0.9128705 , 0.40824863], - [-0.44721356, 0.36514878, -0.8164964 ], - [-0.8944271 , -0.18257457, 0.40824813]], dtype=float32), array([[-6.7082043e+00, -8.0498438e+00, -9.3914843e+00], - [ 0.0000000e+00, 1.0954436e+00, 2.1908882e+00], - [ 0.0000000e+00, 0.0000000e+00, 5.6703755e-08]], dtype=float32)), - mlir_module_text=r""" -module @jit__lambda_ { - func.func public @main() -> (tensor<3x3xf32> {jax.result_info = "[0]"}, tensor<3x3xf32> {jax.result_info = "[1]"}) { - %0 = stablehlo.iota dim = 0 : tensor<9xf32> - %1 = stablehlo.reshape %0 : (tensor<9xf32>) -> tensor<3x3xf32> - %2 = stablehlo.custom_call @cusolver_geqrf(%1) {api_version = 2 : i32, backend_config = "\00\00\00\00\01\00\00\00\03\00\00\00\03\00\00\00\00\00\03\00", operand_layouts = [dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>, dense<0> : tensor<1xindex>]} : (tensor<3x3xf32>) -> tuple, tensor<3xf32>, tensor, tensor<196608xf32>> - %3 = stablehlo.get_tuple_element %2[0] : (tuple, tensor<3xf32>, tensor, tensor<196608xf32>>) -> tensor<3x3xf32> - %4 = stablehlo.get_tuple_element %2[1] : (tuple, tensor<3xf32>, tensor, tensor<196608xf32>>) -> tensor<3xf32> - %5 = stablehlo.get_tuple_element %2[2] : (tuple, tensor<3xf32>, tensor, tensor<196608xf32>>) -> tensor - %6 = stablehlo.get_tuple_element %2[3] : (tuple, tensor<3xf32>, tensor, tensor<196608xf32>>) -> tensor<196608xf32> - %7 = stablehlo.constant dense<0> : tensor - %8 = stablehlo.broadcast_in_dim %7, dims = [] : (tensor) -> tensor - %9 = stablehlo.compare EQ, %5, %8, SIGNED : (tensor, tensor) -> tensor - %10 = stablehlo.broadcast_in_dim %9, dims = [] : (tensor) -> tensor<1x1xi1> - %11 = stablehlo.constant dense<0x7FC00000> : tensor - %12 = stablehlo.broadcast_in_dim %11, dims = [] : (tensor) -> tensor<3x3xf32> - %13 = stablehlo.broadcast_in_dim %10, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<3x3xi1> - %14 = stablehlo.select %13, %3, %12 : tensor<3x3xi1>, tensor<3x3xf32> - %15 = stablehlo.broadcast_in_dim %9, dims = [] : (tensor) -> tensor<1xi1> - %16 = stablehlo.constant dense<0x7FC00000> : tensor - %17 = stablehlo.broadcast_in_dim %16, dims = [] : (tensor) -> tensor<3xf32> - %18 = stablehlo.broadcast_in_dim %15, dims = [0] : (tensor<1xi1>) -> tensor<3xi1> - %19 = stablehlo.select %18, %4, %17 : tensor<3xi1>, tensor<3xf32> - %20 = stablehlo.constant dense<0.000000e+00> : tensor - %21 = stablehlo.pad %14, %20, low = [0, 0], high = [0, 0], interior = [0, 0] : (tensor<3x3xf32>, tensor) -> tensor<3x3xf32> - %22 = stablehlo.custom_call @cusolver_orgqr(%21, %19) {api_version = 2 : i32, backend_config = "\00\00\00\00\01\00\00\00\03\00\00\00\03\00\00\00\03\00\00\00 \81\00\00", operand_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>, dense<0> : tensor<1xindex>]} : (tensor<3x3xf32>, tensor<3xf32>) -> tuple, tensor, tensor<33056xf32>> - %23 = stablehlo.get_tuple_element %22[0] : (tuple, tensor, tensor<33056xf32>>) -> tensor<3x3xf32> - %24 = stablehlo.get_tuple_element %22[1] : (tuple, tensor, tensor<33056xf32>>) -> tensor - %25 = stablehlo.get_tuple_element %22[2] : (tuple, tensor, tensor<33056xf32>>) -> tensor<33056xf32> - %26 = stablehlo.constant dense<0> : tensor - %27 = stablehlo.broadcast_in_dim %26, dims = [] : (tensor) -> tensor - %28 = stablehlo.compare EQ, %24, %27, SIGNED 
: (tensor, tensor) -> tensor - %29 = stablehlo.broadcast_in_dim %28, dims = [] : (tensor) -> tensor<1x1xi1> - %30 = stablehlo.constant dense<0x7FC00000> : tensor - %31 = stablehlo.broadcast_in_dim %30, dims = [] : (tensor) -> tensor<3x3xf32> - %32 = stablehlo.broadcast_in_dim %29, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<3x3xi1> - %33 = stablehlo.select %32, %23, %31 : tensor<3x3xi1>, tensor<3x3xf32> - %34 = call @triu(%14) : (tensor<3x3xf32>) -> tensor<3x3xf32> - return %33, %34 : tensor<3x3xf32>, tensor<3x3xf32> - } - func.func private @triu(%arg0: tensor<3x3xf32>) -> tensor<3x3xf32> { - %0 = stablehlo.iota dim = 0 : tensor<3x3xi32> - %1 = stablehlo.constant dense<-1> : tensor - %2 = stablehlo.broadcast_in_dim %1, dims = [] : (tensor) -> tensor<3x3xi32> - %3 = stablehlo.add %0, %2 : tensor<3x3xi32> - %4 = stablehlo.iota dim = 1 : tensor<3x3xi32> - %5 = stablehlo.compare GE, %3, %4, SIGNED : (tensor<3x3xi32>, tensor<3x3xi32>) -> tensor<3x3xi1> - %6 = stablehlo.constant dense<0.000000e+00> : tensor - %7 = stablehlo.broadcast_in_dim %6, dims = [] : (tensor) -> tensor<3x3xf32> - %8 = stablehlo.select %5, %7, %arg0 : tensor<3x3xi1>, tensor<3x3xf32> - return %8 : tensor<3x3xf32> - } -} -""", - mlir_module_serialized=b"ML\xefR\x03MLIRxxx-trunk\x00\x01+\x05\x01\x05\x01\x03\x05\x03\x1b\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x1d\x1f\x03~\x02\xf79\x01\x99\x0f\x0f\x17\x13\x0f\x07\x0b\x0b\x0b\x0b\x13\x0b\x0b\x0b\x0b\x0b\x13\x0b\x0f\x0b\x0b\x13\x17\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x13\x13\x13\x1b\x13\x13\x0b33\x0b\x0f\x0b\x13\x0b\x13\x0f\x0b\x1b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0bK\x0b\x13\x13\x0f\x0b#\x0b\x0b\x0b\x0f\x0bK\x0b\x13\x0b\x03_O/\x0b/\x0b\x0f\x0b\x0b\x0b\x0b\x0f\x0f\x0b\x13\x13\x0b\x13\x0b\x0b\x0b\x0b\x0b\x0b\x0f\x1f\x0f\x0f\x0b\x1f\x0b\x0b\x0f\x17\x1b\x0f\x0f\x0f\x0f\x1f\x0b\x1fO/\x0b\x0b\x13\x17\x039\x17\x0f\x0f\x07\x07\x07\x07\x17\x13\x17\x07\x1b\x0f\x17\x13\x1b\x17\x17\x13\x13\x1b\x13\x13\x13\x13\x13\x13\x17\x02\x06\t\x1d{\x05\x1d\x93\x05\x17\x1f\n\x06\x01\x03\x03\x13\xcb\x1dS\x05\x1f\x05!\x05#\x05%\x05'\x03\x03\r\xe9\x05)\x05+\x05-\x05/\x051\x03\x03#\xc7\x053\x1d[\x05\x055\x057\x03\x03\r\xd1\x17\x1f\x06\x06\x01\x059\x05;\x05=\x05?\x05A\x05C\x05E\x05G\x03\x03\x0f\xdd\x03\x03\x0f\xdf\x03\x03\x0f\xe1\x03\x03\r\xe5\x03\x05'\xa7)\xe7\x03\x03\x13\xeb\x03\x03\x11M\x05I\x03\x0b\x17\x9d\x19\xb1\x1b\xb3\x11\xbd\x1d\xbf\x03\x0b\x17\xa3\x19\xc3\x1b\xa3\x11\xa5\x1d\xc5\x05K\x1dW\x05\x05M\x03\x03\r\xc9\x05O\x03\x03#\xcd\x1da\x05\x05Q\x03\x05'\xa7)\xcf\x1dg\x05\x05S\x1dk\x05\x05U\x1do\x05\x05W\x1ds-\x05Y\x1dw-\x05[\x03\x11/\xa91\xd33\xd55\x9d7\xab9\xd7;\xad=\xdb\x05]\x03\x03\x0f\xe3\x03\x03\x13\xed\x1d\x83\x05\x05_\x03\x07\x87\x9f\x89\x9f\x8b\x9f\x05a\x05c\x05e\x1d\x8f\x05\x05g\x03\x11/\xa91\xef3\xf15\x9d7\xab9\xf3;\xad=\xf5\x05i\x03\x03\x97\xa5\x05k\x1f+!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f-\x11\x00\x00\x00\x00\x00\x00\x00\x00\x03\x01\x1f\x1d\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1dm\x03\x03\xc1\x1do\t\x07\x0b\x05\x05\x01\x03\x03\xd9\x1f/\x01#!\x03\x05\xb5\xb9\r\x03\xa1\xb7\x1dq\r\x03\xa1\xbb\x1ds\x1du\x1dw\r\x01##\x1dy\x13\x0b\x01\x1f\x03\t\xff\xff\xff\xff\x1f%\x01\x13\x0b\x05\x07\x05\x1f\x05\t\x00\x00\x00\x00\x1d{\x1d}\x03\x03\x99\x15\x03\x01\x01\x01\x03\t\x99\x9b\xaf\x9b\x13\t\x01\x13\t\x05\x13\t\t\x13\t\r\x1f\x03\t\x00\x00\x00\x00\x07\x01\x1f\x05\t\x00\x00\xc0\x7f\x1f\x1d!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f5\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1d\x7f\x1d\x81\x03\x05\x99\x9b\x03\x07\x99\xaf\x9b)\x05\r\r\x07)\x01\t)\x01\x07
\t\x1b\x1d\x01)\x05\r\r\t)\x03\r\x07)\x05\r\r\r\x13)\x03\x04\x000\x07)\x01\r)\x05\x05\x05\r)\x03\t\x0b)\x03\x04\x12\x08\x07\x11\x01\x05\x01\x01\x11\x03\x01\x03\x01)\x03\x01\x0b)\x03%\x07/\t\x01\x11\x03\x17)\x03\t\x15)\x03\x05\x15)\x03\x01\x15)\x03\x05\r)\x03\r\r)\x03\x05\x0b/\x07\x01\x03\x1f\x04\xe6\x05\x05\x01\x11\x0bK\x07\x03\x01\t\x0f\x11\x0bO\x05\x03G\x91\x0b\x03q!\x03'\x17\x06u\x03\x01\x03\x01\x13\x07\x01y\x03)\x03\x03\x07\x07\x01?\x03\x01\x03\x05\x07\x07\x01A\x03\x11\x03\x05\x07\x07\x01C\x03\x03\x03\x05\x07\x07\x01}\x03\x17\x03\x05\x05\x03\x01E\x03\x03\x03\x07\x01\x07\x03\x03\x03\x0f\r\x07\x01G\x03\x19\x05\x0b\x11\x03\x07\x01\x07\x03\x1b\x03\x13\x05\x03\x01\x15\x03\x05\x03\x07\x01\x07\x03\x01\x03\x17\x03\x07\x01I\x03\x13\x03\x15\t\x06\x01\x03\x01\x07\x1b\x07\x19\x03\x07\x01\x07\x031\x03\x13\x05\x03\x01\x15\x03\x05\x03\x07\x01\x07\x03\x11\x03!\x03\x07\x01\x7f\x033\x03\x1f\t\x06\x01\x03\x11\x07%\t#\x05\x03\x81+\x03\x05\x19\x07\x8d\x85\x03\x01\x05\x1d)\x13\x07\x03\x91\x037\x05+'\x07\x07\x03?\x03\x01\x03-\x07\x07\x03A\x03\x03\x03-\x07\x07\x03C\x03\x1f\x03-\x05\x03\x03E\x03\x03\x03\x07\x03\x07\x03\x03\x035\r\x07\x03G\x03\x19\x0517\x03\x07\x03\x07\x03\x1b\x039\x05\x03\x03\x15\x03\x05\x03\x07\x03\x07\x03\x01\x03=\x03\x07\x03I\x03\x13\x03;\t\x06\x03\x03\x01\x07A/?\x1b\x07\t\x95\x03\x01\x03\x1d\x11\x04\x0b\x05CE\x0f\x11\tQ\x05\x03\x15+\x03\x01\x0b\x0b\x03U!\x03\x0f\x05\x03\tY\x03\x03\x03\x07%\x07\x03\x0f\x03\x05\x15\x06%\x03\x0f\x05\x03\x07\x0b\x03_]\x03\x0f\r\x07ec\x03\x13\x05\t\x0b\x05\x03\t+\x03\x05\x03\x07i\x07\x03\x01\x03\x0f\t\x06m\x03\x01\x07\r\x11\x01\x11\x04\t\x03\x13\x06\x03\x01\x05\x01\x00\x86\x19\x83\x1f3\x1f+\x11\x0f\x0b\t\t\x0b!\x0fY\x87##%_=\x85\x87W\xb3K\x9bM\x9b\xd2\x02\x1b\x1f/!!)#\x1f\x19+\x1b\x1f\x83\x1f\x15\x1d\x15+\x13\r\r\x11\x0f\x17\x0f\x1f\x15\x11\x17\x11\x15+\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00get_tuple_element_v1\x00select_v1\x00iota_v1\x00compare_v1\x00func_v1\x00return_v1\x00custom_call_v1\x00add_v1\x00reshape_v1\x00pad_v1\x00call_v1\x00value\x00index\x00sym_name\x00broadcast_dimensions\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py\x00iota_dimension\x00compare_type\x00comparison_direction\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00jit__lambda_\x00jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) resource_env=None donated_invars=(False,) name=triu keep_unused=False inline=False]\x00jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=0]\x00jit()/jit(main)/jit(triu)/add\x00jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=1]\x00jit()/jit(main)/jit(triu)/ge\x00jit()/jit(main)/jit(triu)/broadcast_in_dim[shape=(3, 3) broadcast_dimensions=()]\x00jit()/jit(main)/jit(triu)/select_n\x00jit()/jit(main)/iota[dtype=float32 shape=(9,) dimension=0]\x00jit()/jit(main)/reshape[new_sizes=(3, 3) dimensions=None]\x00jit()/jit(main)/geqrf\x00jit()/jit(main)/qr[full_matrices=True]\x00edge_padding_high\x00edge_padding_low\x00interior_padding\x00jit()/jit(main)/pad[padding_config=((0, 0, 0), (0, 0, 
0))]\x00jit()/jit(main)/householder_product\x00callee\x00jax.result_info\x00triu\x00[0]\x00[1]\x00main\x00public\x00private\x00\x00\x00\x00\x00\x01\x00\x00\x00\x03\x00\x00\x00\x03\x00\x00\x00\x00\x00\x03\x00\x00cusolver_geqrf\x00\x00\x00\x00\x00\x01\x00\x00\x00\x03\x00\x00\x00\x03\x00\x00\x00\x03\x00\x00\x00 \x81\x00\x00\x00cusolver_orgqr\x00", - xla_call_module_version=4, -) # End paste - - -# Pasted from the test output (see back_compat_test.py module docstring) -data_2023_03_18["batched"] = dict( - testdata_version=1, - platform='cuda', - custom_call_targets=['cublas_geqrf_batched', 'cusolver_orgqr'], - serialized_date=datetime.date(2023, 3, 18), - inputs=(), - expected_outputs=(array([[[ 0. , 0.91287094, 0.40824836], - [-0.4472136 , 0.36514843, -0.81649655], - [-0.8944272 , -0.18257417, 0.4082483 ]], - - [[-0.42426407, 0.80828977, 0.40824953], - [-0.5656854 , 0.11547142, -0.8164964 ], - [-0.7071068 , -0.5773508 , 0.4082474 ]]], dtype=float32), array([[[-6.7082038e+00, -8.0498447e+00, -9.3914852e+00], - [ 0.0000000e+00, 1.0954450e+00, 2.1908898e+00], - [ 0.0000000e+00, 0.0000000e+00, 4.8374091e-08]], - - [[-2.1213203e+01, -2.2910259e+01, -2.4607319e+01], - [ 0.0000000e+00, 3.4641042e-01, 6.9282258e-01], - [ 0.0000000e+00, 0.0000000e+00, 1.4548683e-06]]], dtype=float32)), - mlir_module_text=r""" -module @jit__lambda_ { - func.func public @main() -> (tensor<2x3x3xf32> {jax.result_info = "[0]"}, tensor<2x3x3xf32> {jax.result_info = "[1]"}) { - %0 = stablehlo.iota dim = 0 : tensor<18xf32> - %1 = stablehlo.reshape %0 : (tensor<18xf32>) -> tensor<2x3x3xf32> - %2 = stablehlo.custom_call @cublas_geqrf_batched(%1) {api_version = 2 : i32, backend_config = "\00\00\00\00\02\00\00\00\03\00\00\00\03\00\00\00", operand_layouts = [dense<[1, 2, 0]> : tensor<3xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 0]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>]} : (tensor<2x3x3xf32>) -> tuple, tensor<2x3xf32>, tensor<16xi8>, tensor<16xi8>> - %3 = stablehlo.get_tuple_element %2[0] : (tuple, tensor<2x3xf32>, tensor<16xi8>, tensor<16xi8>>) -> tensor<2x3x3xf32> - %4 = stablehlo.get_tuple_element %2[1] : (tuple, tensor<2x3xf32>, tensor<16xi8>, tensor<16xi8>>) -> tensor<2x3xf32> - %5 = stablehlo.get_tuple_element %2[2] : (tuple, tensor<2x3xf32>, tensor<16xi8>, tensor<16xi8>>) -> tensor<16xi8> - %6 = stablehlo.get_tuple_element %2[3] : (tuple, tensor<2x3xf32>, tensor<16xi8>, tensor<16xi8>>) -> tensor<16xi8> - %7 = stablehlo.constant dense<0.000000e+00> : tensor - %8 = stablehlo.pad %3, %7, low = [0, 0, 0], high = [0, 0, 0], interior = [0, 0, 0] : (tensor<2x3x3xf32>, tensor) -> tensor<2x3x3xf32> - %9 = stablehlo.custom_call @cusolver_orgqr(%8, %4) {api_version = 2 : i32, backend_config = "\00\00\00\00\02\00\00\00\03\00\00\00\03\00\00\00\03\00\00\00 \81\00\00", operand_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 0]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>]} : (tensor<2x3x3xf32>, tensor<2x3xf32>) -> tuple, tensor<2xi32>, tensor<33056xf32>> - %10 = stablehlo.get_tuple_element %9[0] : (tuple, tensor<2xi32>, tensor<33056xf32>>) -> tensor<2x3x3xf32> - %11 = stablehlo.get_tuple_element %9[1] : (tuple, tensor<2xi32>, tensor<33056xf32>>) -> tensor<2xi32> - %12 = stablehlo.get_tuple_element %9[2] : (tuple, tensor<2xi32>, tensor<33056xf32>>) -> 
tensor<33056xf32> - %13 = stablehlo.constant dense<0> : tensor - %14 = stablehlo.broadcast_in_dim %13, dims = [] : (tensor) -> tensor<2xi32> - %15 = stablehlo.compare EQ, %11, %14, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> - %16 = stablehlo.broadcast_in_dim %15, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> - %17 = stablehlo.constant dense<0x7FC00000> : tensor - %18 = stablehlo.broadcast_in_dim %17, dims = [] : (tensor) -> tensor<2x3x3xf32> - %19 = stablehlo.broadcast_in_dim %16, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x3x3xi1> - %20 = stablehlo.select %19, %10, %18 : tensor<2x3x3xi1>, tensor<2x3x3xf32> - %21 = call @triu(%3) : (tensor<2x3x3xf32>) -> tensor<2x3x3xf32> - return %20, %21 : tensor<2x3x3xf32>, tensor<2x3x3xf32> - } - func.func private @triu(%arg0: tensor<2x3x3xf32>) -> tensor<2x3x3xf32> { - %0 = stablehlo.iota dim = 0 : tensor<3x3xi32> - %1 = stablehlo.constant dense<-1> : tensor - %2 = stablehlo.broadcast_in_dim %1, dims = [] : (tensor) -> tensor<3x3xi32> - %3 = stablehlo.add %0, %2 : tensor<3x3xi32> - %4 = stablehlo.iota dim = 1 : tensor<3x3xi32> - %5 = stablehlo.compare GE, %3, %4, SIGNED : (tensor<3x3xi32>, tensor<3x3xi32>) -> tensor<3x3xi1> - %6 = stablehlo.broadcast_in_dim %5, dims = [1, 2] : (tensor<3x3xi1>) -> tensor<2x3x3xi1> - %7 = stablehlo.constant dense<0.000000e+00> : tensor - %8 = stablehlo.broadcast_in_dim %7, dims = [] : (tensor) -> tensor<2x3x3xf32> - %9 = stablehlo.select %6, %8, %arg0 : tensor<2x3x3xi1>, tensor<2x3x3xf32> - return %9 : tensor<2x3x3xf32> - } -} -""", - mlir_module_serialized=b"ML\xefR\x03MLIRxxx-trunk\x00\x01+\x05\x01\x05\x01\x03\x05\x03\x1b\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x1d\x1f\x03\x96\x02\xff=\x01\x9f\x17\x0f\x0f\x0f\x07\x0b\x13\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x0b\x0f\x0b\x0b\x13\x17\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x13\x13\x13\x0b33\x0b\x0f\x0b\x13\x0b\x13\x0f\x0b\x1b\x0f\x0b\x13\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0bK\x0b\x13\x0f\x0b#\x0b\x0b\x0b\x0f\x0bK\x0b\x13\x1b\x13\x13\x13\x13\x0b\x03ao/\x0b/\x0b\x0f\x0b\x0b\x0b\x0b\x0fO\x0b\x13\x13\x0b\x13\x0b\x0b\x0b\x0b\x0b\x0b\x0f\x1f\x0f\x0f\x0bO\x1f\x0b\x0b\x0f\x17\x1b\x0f\x0f\x0f\x0f\x0b\x0b\x13\x17\x1f\x0b/\x1fo\x03=\x1b\x07\x07\x07\x0f\x17\x0f\x07\x13\x07\x13\x1b\x17\x13\x1b\x17\x17\x13\x17\x13\x13\x1b\x07\x13\x13\x13\x17\x13\x1b\x13\x02\x1a\n\x17\x1d\n\x06\x01\x1d\x8f\x01\x1dK\x01\x1dy\x01\x1f\x05!\x03\x03\x0f\xd1\x05#\x05%\x05'\x05)\x05+\x05-\x05/\x051\x03\x03!\xcd\x053\x1dS\x01\x055\x057\x03\x03\x0b\xd9\x17\x1d\x06\x06\x01\x059\x05;\x05=\x05?\x05A\x05C\x05E\x05G\x03\x03\x11\xe5\x03\x03\x11\xe7\x03\x03\x11\xe9\x03\x03\x13E\x05I\x03\x0b\x15\xa3\x17\xb7\x19\xb9\x13\xc3\x1b\xc5\x03\x0b\x15\xa9\x17\xc9\x19\xa9\x13\xab\x1b\xcb\x05K\x1dO\x01\x05M\x03\x03\x0b\xcf\x05O\x03\x03!\xd3\x1dY\x01\x05Q\x03\x05%\xad'\xd5\x1d_\x01\x05S\x03\x03\x0f\xd7\x1de\x01\x05U\x1di\x01\x05W\x1dm\x01\x05Y\x1dq+\x05[\x1du+\x05]\x03\x11-\xaf/\xdb1\xdd3\xa35\xb17\xdf9\xb3;\xe3\x05_\x03\x03\x11\xeb\x1d\x7f\x01\x05a\x03\x07\x83\xa5\x85\xa5\x87\xa5\x05c\x05e\x05g\x1d\x8b\x01\x05i\x03\x11-\xaf/\xed1\xef3\xa35\xb17\xf19\xb3;\xf3\x05k\x03\x03\x0b\xf5\x03\x05%\xad'\xf7\x03\x03\x0f\xf9\x03\x03\x0b\xfb\x03\x03\x0f\xfd\x03\x03\x9d\xab\x05m\x1f/1\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f3\x11\x00\x00\x00\x00\x00\x00\x00\x00\x03\x01\x1f\x1b\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1do\x03\x03\xc7\x1dq\t\x07\x0b\x05\x05\x01\x03\x03\xe1\x1f1!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00#\x1f\x03\x05\xbb\xbf\r\x03\xa7\xbd
\x1ds\r\x03\xa7\xc1\x1du\x1dw\x1dy\r\x01#!\x1d{\x13\x05\x01\x1f\r\t\xff\xff\xff\xff\x1f#\x01\x13\x05\x05\x07\x05\x1f'!\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x1f\t\t\x00\x00\x00\x00\x1d}\x1d\x7f\x03\x03\x9f\x15\x03\x01\x01\x01\x03\t\x9f\xb5\xa1\xa1\x13\x03\x01\x13\x03\x05\x13\x03\t\x13\x03\r\x1d\x81\x1d\x83\x03\x05\x9f\xb5\x03\x07\x9f\xa1\xa1\x1f\r\t\x00\x00\x00\x00\x07\x01\x1f;\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\t\t\x00\x00\xc0\x7f\x1f\x1b1\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00)\x07\t\r\r\x07\x1b\x1d\t)\x01\x07)\x05\r\r\x03)\x01\x03\x01)\x03A-\x13)\x03\t\x03)\x07\t\r\r\x0f)\x05\t\r\x07)\x03\r\x05)\x03\x04\x12\x08\x07\x11\x01\x05\x01\x01\x11\x03\x01\x03\x01)\x03\x01\x05)\x05\r\r\x0f)\x03\t\x05)\x03I\x07/\t\x01\x19\x11\x11\x17)\x03\r\x13)\x03\t\x13)\x03\x05\x13/\x07\x01\x15\x1d)\x03\t\x0f)\x07\t\x05\x05\x0f)\x03\x05\x05\x04r\x04\x05\x01\x11\tC\x07\x03\x01\t\x0b\x11\tG\x05\x03-]\t\x03o\x1f\x03)\x17\x06s\x03\x01\x03\x01\x13\x07\x07w\x03+\x03\x03\x05\x07\x07=\x03\x01\x03\x05\x05\x07\x07?\x03\x19\x03\x05\x05\x07\x07A\x03\x11\x03\x05\x05\x07\x07{\x03\x11\x03\x05\x07\x03})\x03\t\x19\x07\x89\x81\x03\x01\x05\x07\x0f\x13\x07\x03\x8d\x035\x05\x11\t\x05\x07\x03=\x03\x01\x03\x13\x05\x07\x03?\x03\x15\x03\x13\x05\x07\x03A\x03\x1d\x03\x13\x07\x03\x03\x91\x03\r\x03\x07\x03\r\x03\x15\x03\x1b\r\x07\x03\x93\x037\x05\x17\x1d\x03\x07\x03\x95\x039\x03\x1f\x07\x03\x03\x97\x03\t\x03\x07\x03\r\x03\x01\x03#\x03\x07\x03\x99\x03\x17\x03!\x0f\x06\x03\x03\x01\x07'\x15%\x1b\x07\x05\x9b\x03\x01\x03\x07\x11\x04\t\x05)+\x0b\x11\x05I\x05\x03\x17/\x03\x01\t\t\x03M\x1f\x03\x0b\x07\x03\x05Q\x03\r\x03\x07#\r\x03\x0b\x03\x05\x15\x06#\x03\x0b\x05\x03\x07\t\x03WU\x03\x0b\r\x07][\x03%\x05\t\x0b\x03\x07ca\x03\x17\x03\r\x07\x03\x05)\x03\t\x03\x07g\r\x03\x01\x03\x11\x0f\x06k\x03\x01\x07\x0f\x13\x01\x11\x04\x05\x03\x15\x06\x03\x01\x05\x01\x00Z\x1b\x85\x1f3+#\x11\x0f\x0b\t\t\x0b!\x0fY\x9d##%_=\x8b\x89W\xb9\xc1K\x9bM\x9b\xd2\x02\x1b\x1f/!!)#\x1f\x19+\x1b\x1f\x83\x1f\x15\x1d\x15\x13\r+\r\x11\x0f\x17\x0f\x1f\x15\x15\x17\x11\x11\x19+)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00get_tuple_element_v1\x00constant_v1\x00iota_v1\x00func_v1\x00compare_v1\x00select_v1\x00return_v1\x00custom_call_v1\x00add_v1\x00reshape_v1\x00pad_v1\x00call_v1\x00value\x00broadcast_dimensions\x00index\x00sym_name\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py\x00iota_dimension\x00compare_type\x00comparison_direction\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00jit__lambda_\x00jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) resource_env=None donated_invars=(False,) name=triu keep_unused=False inline=False]\x00jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=0]\x00jit()/jit(main)/jit(triu)/add\x00jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=1]\x00jit()/jit(main)/jit(triu)/ge\x00jit()/jit(main)/jit(triu)/broadcast_in_dim[shape=(2, 3, 3) broadcast_dimensions=(1, 2)]\x00jit()/jit(main)/jit(triu)/broadcast_in_dim[shape=(2, 3, 3) broadcast_dimensions=()]\x00jit()/jit(main)/jit(triu)/select_n\x00jit()/jit(main)/iota[dtype=float32 shape=(18,) dimension=0]\x00jit()/jit(main)/reshape[new_sizes=(2, 3, 3) 
dimensions=None]\x00jit()/jit(main)/geqrf\x00jit()/jit(main)/qr[full_matrices=True]\x00edge_padding_high\x00edge_padding_low\x00interior_padding\x00jit()/jit(main)/pad[padding_config=((0, 0, 0), (0, 0, 0), (0, 0, 0))]\x00jit()/jit(main)/householder_product\x00callee\x00jax.result_info\x00triu\x00[0]\x00[1]\x00main\x00public\x00private\x00\x00\x00\x00\x00\x02\x00\x00\x00\x03\x00\x00\x00\x03\x00\x00\x00\x00cublas_geqrf_batched\x00\x00\x00\x00\x00\x02\x00\x00\x00\x03\x00\x00\x00\x03\x00\x00\x00\x03\x00\x00\x00 \x81\x00\x00\x00cusolver_orgqr\x00", - xla_call_module_version=4, -) # End paste data_2024_09_26 = {} - data_2024_09_26["f32"] = dict( testdata_version=1, platform='cuda', diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/cuda_tridiagonal_solve.py b/jax/_src/internal_test_util/export_back_compat_test_data/cuda_tridiagonal_solve.py new file mode 100644 index 000000000000..c81d4d4a139d --- /dev/null +++ b/jax/_src/internal_test_util/export_back_compat_test_data/cuda_tridiagonal_solve.py @@ -0,0 +1,84 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ruff: noqa + +import datetime +from numpy import array, float32 + +data_2025_06_16 = {} + +# Pasted from the test output (see export_back_compat_test_util.py module docstring) +data_2025_06_16["f32"] = dict( + testdata_version=1, + platform='cuda', + custom_call_targets=['cusparse_gtsv2_ffi'], + serialized_date=datetime.date(2025, 6, 16), + inputs=(array([0., 2., 3.], dtype=float32), array([1., 1., 1.], dtype=float32), array([1., 2., 0.], dtype=float32), array([[1.], + [1.], + [1.]], dtype=float32)), + expected_outputs=(array([[ 0.57142854], + [ 0.42857146], + [-0.2857143 ]], dtype=float32),), + mlir_module_text=r""" +#loc1 = loc("dl") +#loc2 = loc("d") +#loc3 = loc("du") +#loc4 = loc("b") +module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main(%arg0: tensor<3xf32> loc("dl"), %arg1: tensor<3xf32> loc("d"), %arg2: tensor<3xf32> loc("du"), %arg3: tensor<3x1xf32> loc("b")) -> (tensor<3x1xf32> {jax.result_info = "result"}) { + %0 = stablehlo.custom_call @cusparse_gtsv2_ffi(%arg0, %arg1, %arg2, %arg3) {mhlo.backend_config = {}, mhlo.frontend_attributes = {num_batch_dims = "0"}, operand_layouts = [dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>, dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>]} : (tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3x1xf32>) -> tensor<3x1xf32> loc(#loc6) + return %0 : tensor<3x1xf32> loc(#loc) + } loc(#loc) +} loc(#loc) +#loc = loc(unknown) +#loc5 = loc("third_party/py/jax/tests/export_back_compat_test.py":760:13) +#loc6 = loc("jit(func)/jit(main)/tridiagonal_solve"(#loc5)) +""", + 
mlir_module_serialized=b"ML\xefR\rStableHLO_v1.10.4\x00\x01\x19\x05\x01\x05\t\x01\x03\x0b\x03\x07\x0f\x13\x17\x03\x83]\x13\x01/\x07\x0f#\x0b\x0f\x0b\x0b\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x1b\x0b\x0b\x0f\x0b\x17\x0b\x03/\x0b/O\x1b\x0b\x0f\x13\x0b\x0b\x0b\x0b\x13\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x1b\x0f\x13\x0f\x01\x05\x0b\x0f\x03\x0f\x13\x17\x07\x07#\x13\x13\x02\xea\x02\x1f\x11\x03\x05\x03\x07\x07\t\x0b\x03\r\x03\x05\r\x11\x01\x00\x05\x0f\x05\x11\x05\x13\x1d\x13\x01\x05\x15\x1d\x17\x01\x05\x17\x1d\x1b\x01\x05\x19\x1d\x1f\x01\x05\x1b\x03\x05#/%E\x05\x1d\x05\x1f\x1d)+\x05!\x17-\xe2\x0b\x1b\x05#\r\x01\x1f\x0f\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x11!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x03\t////#\r\x03\x03;\r\x03=?\x1d%\x1d'\x1d)\x1d+\r\x03GI\x1d-\x1d/\x0b\x03\x1d1\x1d3\x03\x01\x05\x01\x03\t1113\x03\x03Y\x15\x01\r\x01\x03\x033\x01\t\x01\x02\x02)\x03\r\t)\x05\r\x05\t\t\x13\x11\t\x05\x05\x05\x07\x03\x07)\x03\x05\x0b)\x03\t\x0b\x04c\x05\x01Q\x01\x05\x01\x07\x04Q\x03\x01\x05\x03P\x01\x03\x07\x04=\x03\x0b\x0b\t\x0b\x11\x0b\x15\x0b\x19\x0f\x1d\x00\x05G'!\x05\x03\x07\t\x01\x03\x05\x07\x07\x04\x01\x03\t\x06\x03\x01\x05\x01\x00\xd2\x055'\x03\x05\x1f\x0f\x0b\x0f!iM3)\x05\x07\x05\x07\x13%)9\x15\x1f\x11\x0f\x0b\x11builtin\x00vhlo\x00module\x00func_v1\x00custom_call_v1\x00return_v1\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00dl\x00d\x00du\x00b\x00mhlo.backend_config\x00mhlo.frontend_attributes\x00jit(func)/jit(main)/tridiagonal_solve\x00third_party/py/jax/tests/export_back_compat_test.py\x00jax.result_info\x00result\x00main\x00public\x00num_batch_dims\x000\x00\x00cusparse_gtsv2_ffi\x00\x08'\x07\x05\x1f\x01\x0b579AC\x11KMOQSUW[", + xla_call_module_version=9, + nr_devices=1, +) # End paste + +# Pasted from the test output (see export_back_compat_test_util.py module docstring) +data_2025_06_16["f64"] = dict( + testdata_version=1, + platform='cuda', + custom_call_targets=['cusparse_gtsv2_ffi'], + serialized_date=datetime.date(2025, 6, 16), + inputs=(array([0., 2., 3.]), array([1., 1., 1.]), array([1., 2., 0.]), array([[1.], + [1.], + [1.]])), + expected_outputs=(array([[ 0.5714285714285714 ], + [ 0.42857142857142855], + [-0.2857142857142857 ]]),), + mlir_module_text=r""" +#loc1 = loc("dl") +#loc2 = loc("d") +#loc3 = loc("du") +#loc4 = loc("b") +module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main(%arg0: tensor<3xf64> loc("dl"), %arg1: tensor<3xf64> loc("d"), %arg2: tensor<3xf64> loc("du"), %arg3: tensor<3x1xf64> loc("b")) -> (tensor<3x1xf64> {jax.result_info = "result"}) { + %0 = stablehlo.custom_call @cusparse_gtsv2_ffi(%arg0, %arg1, %arg2, %arg3) {mhlo.backend_config = {}, mhlo.frontend_attributes = {num_batch_dims = "0"}, operand_layouts = [dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>, dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>]} : (tensor<3xf64>, tensor<3xf64>, tensor<3xf64>, tensor<3x1xf64>) -> tensor<3x1xf64> loc(#loc6) + return %0 : tensor<3x1xf64> loc(#loc) + } loc(#loc) +} loc(#loc) +#loc = loc(unknown) +#loc5 = loc("third_party/py/jax/tests/export_back_compat_test.py":760:13) +#loc6 = loc("jit(func)/jit(main)/tridiagonal_solve"(#loc5)) +""", + 
mlir_module_serialized=b"ML\xefR\rStableHLO_v1.10.4\x00\x01\x19\x05\x01\x05\t\x01\x03\x0b\x03\x07\x0f\x13\x17\x03\x83]\x13\x01/\x07\x0f#\x0b\x0f\x0b\x0b\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x1b\x0b\x0b\x0f\x0b\x17\x0b\x03/\x0b/O\x1b\x0b\x0f\x13\x0b\x0b\x0b\x0b\x13\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x1b\x0f\x13\x0f\x01\x05\x0b\x0f\x03\x0f\x13\x17\x07\x07#\x13\x13\x02\xea\x02\x1f\x11\x03\x05\x03\x07\x07\t\x0b\x03\r\x03\x05\r\x11\x01\x00\x05\x0f\x05\x11\x05\x13\x1d\x13\x01\x05\x15\x1d\x17\x01\x05\x17\x1d\x1b\x01\x05\x19\x1d\x1f\x01\x05\x1b\x03\x05#/%E\x05\x1d\x05\x1f\x1d)+\x05!\x17-\xe2\x0b\x1b\x05#\r\x01\x1f\x0f\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x11!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x03\t////#\r\x03\x03;\r\x03=?\x1d%\x1d'\x1d)\x1d+\r\x03GI\x1d-\x1d/\x0b\x03\x1d1\x1d3\x03\x01\x05\x01\x03\t1113\x03\x03Y\x15\x01\r\x01\x03\x033\x01\t\x01\x02\x02)\x03\r\t)\x05\r\x05\t\x0b\x13\x11\t\x05\x05\x05\x07\x03\x07)\x03\x05\x0b)\x03\t\x0b\x04c\x05\x01Q\x01\x05\x01\x07\x04Q\x03\x01\x05\x03P\x01\x03\x07\x04=\x03\x0b\x0b\t\x0b\x11\x0b\x15\x0b\x19\x0f\x1d\x00\x05G'!\x05\x03\x07\t\x01\x03\x05\x07\x07\x04\x01\x03\t\x06\x03\x01\x05\x01\x00\xd2\x055'\x03\x05\x1f\x0f\x0b\x0f!iM3)\x05\x07\x05\x07\x13%)9\x15\x1f\x11\x0f\x0b\x11builtin\x00vhlo\x00module\x00func_v1\x00custom_call_v1\x00return_v1\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00dl\x00d\x00du\x00b\x00mhlo.backend_config\x00mhlo.frontend_attributes\x00jit(func)/jit(main)/tridiagonal_solve\x00third_party/py/jax/tests/export_back_compat_test.py\x00jax.result_info\x00result\x00main\x00public\x00num_batch_dims\x000\x00\x00cusparse_gtsv2_ffi\x00\x08'\x07\x05\x1f\x01\x0b579AC\x11KMOQSUW[", + xla_call_module_version=9, + nr_devices=1, +) # End paste diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/pallas/__init__.py b/jax/_src/internal_test_util/export_back_compat_test_data/pallas/__init__.py new file mode 100644 index 000000000000..3da0dd1fa3ca --- /dev/null +++ b/jax/_src/internal_test_util/export_back_compat_test_data/pallas/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2025 The JAX Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/pallas/mosaic_gpu_add_one.py b/jax/_src/internal_test_util/export_back_compat_test_data/pallas/mosaic_gpu_add_one.py new file mode 100644 index 000000000000..fd66c35de7c9 --- /dev/null +++ b/jax/_src/internal_test_util/export_back_compat_test_data/pallas/mosaic_gpu_add_one.py @@ -0,0 +1,88 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import datetime + +from numpy import array, float32 + +# Pasted from the test output (see export_back_compat_test_util.py module docstring) +data_2025_04_22 = dict( + testdata_version=1, + platform='cuda', + custom_call_targets=['mosaic_gpu'], + serialized_date=datetime.date(2025, 4, 22), + inputs=(array([ 0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., + 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., + 22., 23., 24., 25., 26., 27., 28., 29., 30., 31., 32., + 33., 34., 35., 36., 37., 38., 39., 40., 41., 42., 43., + 44., 45., 46., 47., 48., 49., 50., 51., 52., 53., 54., + 55., 56., 57., 58., 59., 60., 61., 62., 63., 64., 65., + 66., 67., 68., 69., 70., 71., 72., 73., 74., 75., 76., + 77., 78., 79., 80., 81., 82., 83., 84., 85., 86., 87., + 88., 89., 90., 91., 92., 93., 94., 95., 96., 97., 98., + 99., 100., 101., 102., 103., 104., 105., 106., 107., 108., 109., + 110., 111., 112., 113., 114., 115., 116., 117., 118., 119., 120., + 121., 122., 123., 124., 125., 126., 127., 128., 129., 130., 131., + 132., 133., 134., 135., 136., 137., 138., 139., 140., 141., 142., + 143., 144., 145., 146., 147., 148., 149., 150., 151., 152., 153., + 154., 155., 156., 157., 158., 159., 160., 161., 162., 163., 164., + 165., 166., 167., 168., 169., 170., 171., 172., 173., 174., 175., + 176., 177., 178., 179., 180., 181., 182., 183., 184., 185., 186., + 187., 188., 189., 190., 191., 192., 193., 194., 195., 196., 197., + 198., 199., 200., 201., 202., 203., 204., 205., 206., 207., 208., + 209., 210., 211., 212., 213., 214., 215., 216., 217., 218., 219., + 220., 221., 222., 223., 224., 225., 226., 227., 228., 229., 230., + 231., 232., 233., 234., 235., 236., 237., 238., 239., 240., 241., + 242., 243., 244., 245., 246., 247., 248., 249., 250., 251., 252., + 253., 254., 255.], dtype=float32),), + expected_outputs=(array([ 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., + 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., + 23., 24., 25., 26., 27., 28., 29., 30., 31., 32., 33., + 34., 35., 36., 37., 38., 39., 40., 41., 42., 43., 44., + 45., 46., 47., 48., 49., 50., 51., 52., 53., 54., 55., + 56., 57., 58., 59., 60., 61., 62., 63., 64., 65., 66., + 67., 68., 69., 70., 71., 72., 73., 74., 75., 76., 77., + 78., 79., 80., 81., 82., 83., 84., 85., 86., 87., 88., + 89., 90., 91., 92., 93., 94., 95., 96., 97., 98., 99., + 100., 101., 102., 103., 104., 105., 106., 107., 108., 109., 110., + 111., 112., 113., 114., 115., 116., 117., 118., 119., 120., 121., + 122., 123., 124., 125., 126., 127., 128., 129., 130., 131., 132., + 133., 134., 135., 136., 137., 138., 139., 140., 141., 142., 143., + 144., 145., 146., 147., 148., 149., 150., 151., 152., 153., 154., + 155., 156., 157., 158., 159., 160., 161., 162., 163., 164., 165., + 166., 167., 168., 169., 170., 171., 172., 173., 174., 175., 176., + 177., 178., 179., 180., 181., 182., 183., 184., 185., 186., 187., + 188., 189., 190., 191., 192., 193., 194., 195., 196., 197., 198., + 199., 200., 201., 202., 203., 204., 205., 206., 207., 208., 209., + 210., 211., 212., 213., 214., 215., 216., 217., 218., 219., 220., + 221., 222., 223., 224., 225., 226., 227., 228., 229., 
230., 231., + 232., 233., 234., 235., 236., 237., 238., 239., 240., 241., 242., + 243., 244., 245., 246., 247., 248., 249., 250., 251., 252., 253., + 254., 255., 256.], dtype=float32),), + mlir_module_text=r""" +#loc1 = loc("args[0]") +module @jit_wrapped attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main(%arg0: tensor<256xf32> loc("args[0]")) -> (tensor<256xf32> {jax.result_info = "result"}) { + %0 = stablehlo.custom_call @mosaic_gpu(%arg0) {api_version = 2 : i32, backend_config = "\A9C\FB\81\9A1\C2?\0E\F4\E1\E4\E77\03\B6\97\E5G(]WR\98\EB{\BA\8A\84\01\12'#loc = loc(\22third_party/py/jax/tests/pallas/export_back_compat_pallas_test.py\22:83:4)\0A#loc1 = loc(\22-\22:94:40)\0A#loc2 = loc(\22-\22:94:47)\0A#loc3 = loc(\22-\22:94:54)\0A#loc4 = loc(\22-\22:94:116)\0A#loc5 = loc(\22-\22:94:123)\0A#loc6 = loc(\22-\22:94:130)\0A#loc7 = loc(\22-\22:94:65)\0A#loc8 = loc(\22-\22:94:78)\0A#loc9 = loc(\22-\22:94:91)\0A#loc10 = loc(\22-\22:94:141)\0A#loc11 = loc(\22-\22:94:157)\0A#loc12 = loc(\22-\22:94:174)\0A#loc17 = loc(\22jit(wrapped)/jit(main)/pallas_call\22(#loc))\0A\22builtin.module\22() <{sym_name = \22add_one\22}> ({\0A \22stable_mosaic_gpu.func.func\22() ({\0A }) {function_type = (!llvm.ptr, !llvm.ptr, i64, i64, !llvm.ptr, !llvm.ptr, i64, !llvm.ptr) -> (), sym_name = \22mosaic_gpu_init_tma_desc\22, sym_visibility = \22private\22} : () -> () loc(#loc17)\0A \22stable_mosaic_gpu.llvm.mlir.global\22() ({\0A }) {addr_space = 4 : i32, global_type = !llvm.array<0 x i8>, linkage = #llvm.linkage, sym_name = \22global_scratch\22, unnamed_addr = 0 : i64, visibility_ = 0 : i64} : () -> () loc(#loc17)\0A \22stable_mosaic_gpu.func.func\22() ({\0A ^bb0(%arg0: !llvm.ptr loc(\22jit(wrapped)/jit(main)/pallas_call\22(#loc)), %arg1: !llvm.ptr loc(\22jit(wrapped)/jit(main)/pallas_call\22(#loc))):\0A %0 = \22stable_mosaic_gpu.builtin.unrealized_conversion_cast\22(%arg0) : (!llvm.ptr) -> !gpu.async.token loc(#loc17)\0A %1 = \22stable_mosaic_gpu.llvm.getelementptr\22(%arg1) {elem_type = !llvm.ptr, rawConstantIndices = array} : (!llvm.ptr) -> !llvm.ptr loc(#loc17)\0A %2 = \22stable_mosaic_gpu.llvm.load\22(%1) {ordering = 0 : i64} : (!llvm.ptr) -> !llvm.ptr loc(#loc17)\0A %3 = \22stable_mosaic_gpu.llvm.mlir.undef\22() : () -> !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> loc(#loc17)\0A %4 = \22stable_mosaic_gpu.llvm.insertvalue\22(%3, %2) {position = array} : (!llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>, !llvm.ptr) -> !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> loc(#loc17)\0A %5 = \22stable_mosaic_gpu.llvm.insertvalue\22(%4, %2) {position = array} : (!llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>, !llvm.ptr) -> !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> loc(#loc17)\0A %6 = \22stable_mosaic_gpu.llvm.mlir.constant\22() {value = 0 : i64} : () -> i64 loc(#loc17)\0A %7 = \22stable_mosaic_gpu.llvm.insertvalue\22(%5, %6) {position = array} : (!llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>, i64) -> !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> loc(#loc17)\0A %8 = \22stable_mosaic_gpu.llvm.mlir.constant\22() {value = 256 : i64} : () -> i64 loc(#loc17)\0A %9 = \22stable_mosaic_gpu.llvm.insertvalue\22(%7, %8) {position = array} : (!llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>, i64) -> !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> loc(#loc17)\0A %10 = \22stable_mosaic_gpu.llvm.mlir.constant\22() {value 
= 1 : i64} : () -> i64 loc(#loc17)\0A %11 = \22stable_mosaic_gpu.llvm.insertvalue\22(%9, %10) {position = array} : (!llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>, i64) -> !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> loc(#loc17)\0A %12 = \22stable_mosaic_gpu.builtin.unrealized_conversion_cast\22(%11) : (!llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>) -> memref<256xf32> loc(#loc17)\0A %13 = \22stable_mosaic_gpu.llvm.getelementptr\22(%arg1) {elem_type = !llvm.ptr, rawConstantIndices = array} : (!llvm.ptr) -> !llvm.ptr loc(#loc17)\0A %14 = \22stable_mosaic_gpu.llvm.load\22(%13) {ordering = 0 : i64} : (!llvm.ptr) -> !llvm.ptr loc(#loc17)\0A %15 = \22stable_mosaic_gpu.llvm.mlir.undef\22() : () -> !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> loc(#loc17)\0A %16 = \22stable_mosaic_gpu.llvm.insertvalue\22(%15, %14) {position = array} : (!llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>, !llvm.ptr) -> !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> loc(#loc17)\0A %17 = \22stable_mosaic_gpu.llvm.insertvalue\22(%16, %14) {position = array} : (!llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>, !llvm.ptr) -> !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> loc(#loc17)\0A %18 = \22stable_mosaic_gpu.llvm.mlir.constant\22() {value = 0 : i64} : () -> i64 loc(#loc17)\0A %19 = \22stable_mosaic_gpu.llvm.insertvalue\22(%17, %18) {position = array} : (!llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>, i64) -> !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> loc(#loc17)\0A %20 = \22stable_mosaic_gpu.llvm.mlir.constant\22() {value = 256 : i64} : () -> i64 loc(#loc17)\0A %21 = \22stable_mosaic_gpu.llvm.insertvalue\22(%19, %20) {position = array} : (!llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>, i64) -> !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> loc(#loc17)\0A %22 = \22stable_mosaic_gpu.llvm.mlir.constant\22() {value = 1 : i64} : () -> i64 loc(#loc17)\0A %23 = \22stable_mosaic_gpu.llvm.insertvalue\22(%21, %22) {position = array} : (!llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>, i64) -> !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> loc(#loc17)\0A %24 = \22stable_mosaic_gpu.builtin.unrealized_conversion_cast\22(%23) : (!llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>) -> memref<256xf32> loc(#loc17)\0A %25 = \22stable_mosaic_gpu.arith.constant\22() {value = 1 : i64} : () -> i64 loc(#loc17)\0A %26 = \22stable_mosaic_gpu.llvm.alloca\22(%25) {alignment = 64 : i64, elem_type = !llvm.array<256 x i8>} : (i64) -> !llvm.ptr loc(#loc17)\0A %27 = \22stable_mosaic_gpu.llvm.getelementptr\22(%26) {elem_type = i8, rawConstantIndices = array} : (!llvm.ptr) -> !llvm.ptr loc(#loc17)\0A %28:4 = \22stable_mosaic_gpu.memref.extract_strided_metadata\22(%12) : (memref<256xf32>) -> (memref, index, index, index) loc(#loc17)\0A %29 = \22stable_mosaic_gpu.memref.extract_aligned_pointer_as_index\22(%12) : (memref<256xf32>) -> index loc(#loc17)\0A %30 = \22stable_mosaic_gpu.arith.index_cast\22(%29) : (index) -> i64 loc(#loc17)\0A %31 = \22stable_mosaic_gpu.llvm.inttoptr\22(%30) : (i64) -> !llvm.ptr loc(#loc17)\0A %32 = \22stable_mosaic_gpu.arith.index_cast\22(%28#1) : (index) -> i64 loc(#loc17)\0A %33 = \22stable_mosaic_gpu.llvm.getelementptr\22(%31, %32) {elem_type = f32, rawConstantIndices = array} : (!llvm.ptr, i64) -> !llvm.ptr loc(#loc17)\0A %34 = \22stable_mosaic_gpu.arith.constant\22() {value = 6 : i64} : () -> i64 
loc(#loc17)\0A %35 = \22stable_mosaic_gpu.arith.constant\22() {value = 1 : i64} : () -> i64 loc(#loc17)\0A %36 = \22stable_mosaic_gpu.arith.index_cast\22(%28#2) : (index) -> i64 loc(#loc17)\0A %37 = \22stable_mosaic_gpu.arith.constant\22() {value = 1 : i64} : () -> i64 loc(#loc17)\0A %38 = \22stable_mosaic_gpu.llvm.alloca\22(%37) {elem_type = i64} : (i64) -> !llvm.ptr loc(#loc17)\0A %39 = \22stable_mosaic_gpu.llvm.getelementptr\22(%38) {elem_type = i64, rawConstantIndices = array} : (!llvm.ptr) -> !llvm.ptr loc(#loc17)\0A \22stable_mosaic_gpu.llvm.store\22(%36, %39) {ordering = 0 : i64} : (i64, !llvm.ptr) -> () loc(#loc17)\0A %40 = \22stable_mosaic_gpu.arith.index_cast\22(%28#3) : (index) -> i64 loc(#loc17)\0A %41 = \22stable_mosaic_gpu.arith.constant\22() {value = 1 : i64} : () -> i64 loc(#loc17)\0A %42 = \22stable_mosaic_gpu.llvm.alloca\22(%41) {elem_type = i64} : (i64) -> !llvm.ptr loc(#loc17)\0A %43 = \22stable_mosaic_gpu.llvm.getelementptr\22(%42) {elem_type = i64, rawConstantIndices = array} : (!llvm.ptr) -> !llvm.ptr loc(#loc17)\0A \22stable_mosaic_gpu.llvm.store\22(%40, %43) {ordering = 0 : i64} : (i64, !llvm.ptr) -> () loc(#loc17)\0A %44 = \22stable_mosaic_gpu.arith.constant\22() {value = 16 : i64} : () -> i64 loc(#loc17)\0A %45 = \22stable_mosaic_gpu.arith.constant\22() {value = 256 : i64} : () -> i64 loc(#loc17)\0A %46 = \22stable_mosaic_gpu.arith.constant\22() {value = 1 : i64} : () -> i64 loc(#loc17)\0A %47 = \22stable_mosaic_gpu.llvm.alloca\22(%46) {elem_type = i64} : (i64) -> !llvm.ptr loc(#loc17)\0A %48 = \22stable_mosaic_gpu.llvm.getelementptr\22(%47) {elem_type = i64, rawConstantIndices = array} : (!llvm.ptr) -> !llvm.ptr loc(#loc17)\0A \22stable_mosaic_gpu.llvm.store\22(%45, %48) {ordering = 0 : i64} : (i64, !llvm.ptr) -> () loc(#loc17)\0A \22stable_mosaic_gpu.func.call\22(%27, %33, %34, %35, %38, %42, %44, %47) {callee = @mosaic_gpu_init_tma_desc} : (!llvm.ptr, !llvm.ptr, i64, i64, !llvm.ptr, !llvm.ptr, i64, !llvm.ptr) -> () loc(#loc17)\0A %49 = \22stable_mosaic_gpu.llvm.getelementptr\22(%26) {elem_type = i8, rawConstantIndices = array} : (!llvm.ptr) -> !llvm.ptr loc(#loc17)\0A %50:4 = \22stable_mosaic_gpu.memref.extract_strided_metadata\22(%24) : (memref<256xf32>) -> (memref, index, index, index) loc(#loc17)\0A %51 = \22stable_mosaic_gpu.memref.extract_aligned_pointer_as_index\22(%24) : (memref<256xf32>) -> index loc(#loc17)\0A %52 = \22stable_mosaic_gpu.arith.index_cast\22(%51) : (index) -> i64 loc(#loc17)\0A %53 = \22stable_mosaic_gpu.llvm.inttoptr\22(%52) : (i64) -> !llvm.ptr loc(#loc17)\0A %54 = \22stable_mosaic_gpu.arith.index_cast\22(%50#1) : (index) -> i64 loc(#loc17)\0A %55 = \22stable_mosaic_gpu.llvm.getelementptr\22(%53, %54) {elem_type = f32, rawConstantIndices = array} : (!llvm.ptr, i64) -> !llvm.ptr loc(#loc17)\0A %56 = \22stable_mosaic_gpu.arith.constant\22() {value = 6 : i64} : () -> i64 loc(#loc17)\0A %57 = \22stable_mosaic_gpu.arith.constant\22() {value = 1 : i64} : () -> i64 loc(#loc17)\0A %58 = \22stable_mosaic_gpu.arith.index_cast\22(%50#2) : (index) -> i64 loc(#loc17)\0A %59 = \22stable_mosaic_gpu.arith.constant\22() {value = 1 : i64} : () -> i64 loc(#loc17)\0A %60 = \22stable_mosaic_gpu.llvm.alloca\22(%59) {elem_type = i64} : (i64) -> !llvm.ptr loc(#loc17)\0A %61 = \22stable_mosaic_gpu.llvm.getelementptr\22(%60) {elem_type = i64, rawConstantIndices = array} : (!llvm.ptr) -> !llvm.ptr loc(#loc17)\0A \22stable_mosaic_gpu.llvm.store\22(%58, %61) {ordering = 0 : i64} : (i64, !llvm.ptr) -> () loc(#loc17)\0A %62 = 
\22stable_mosaic_gpu.arith.index_cast\22(%50#3) : (index) -> i64 loc(#loc17)\0A %63 = \22stable_mosaic_gpu.arith.constant\22() {value = 1 : i64} : () -> i64 loc(#loc17)\0A %64 = \22stable_mosaic_gpu.llvm.alloca\22(%63) {elem_type = i64} : (i64) -> !llvm.ptr loc(#loc17)\0A %65 = \22stable_mosaic_gpu.llvm.getelementptr\22(%64) {elem_type = i64, rawConstantIndices = array} : (!llvm.ptr) -> !llvm.ptr loc(#loc17)\0A \22stable_mosaic_gpu.llvm.store\22(%62, %65) {ordering = 0 : i64} : (i64, !llvm.ptr) -> () loc(#loc17)\0A %66 = \22stable_mosaic_gpu.arith.constant\22() {value = 16 : i64} : () -> i64 loc(#loc17)\0A %67 = \22stable_mosaic_gpu.arith.constant\22() {value = 256 : i64} : () -> i64 loc(#loc17)\0A %68 = \22stable_mosaic_gpu.arith.constant\22() {value = 1 : i64} : () -> i64 loc(#loc17)\0A %69 = \22stable_mosaic_gpu.llvm.alloca\22(%68) {elem_type = i64} : (i64) -> !llvm.ptr loc(#loc17)\0A %70 = \22stable_mosaic_gpu.llvm.getelementptr\22(%69) {elem_type = i64, rawConstantIndices = array} : (!llvm.ptr) -> !llvm.ptr loc(#loc17)\0A \22stable_mosaic_gpu.llvm.store\22(%67, %70) {ordering = 0 : i64} : (i64, !llvm.ptr) -> () loc(#loc17)\0A \22stable_mosaic_gpu.func.call\22(%49, %55, %56, %57, %60, %64, %66, %69) {callee = @mosaic_gpu_init_tma_desc} : (!llvm.ptr, !llvm.ptr, i64, i64, !llvm.ptr, !llvm.ptr, i64, !llvm.ptr) -> () loc(#loc17)\0A %71 = \22stable_mosaic_gpu.llvm.load\22(%26) {ordering = 0 : i64} : (!llvm.ptr) -> !llvm.array<256 x i8> loc(#loc17)\0A %72 = \22stable_mosaic_gpu.arith.constant\22() {value = 2 : index} : () -> index loc(#loc17)\0A %73 = \22stable_mosaic_gpu.arith.constant\22() {value = 1 : index} : () -> index loc(#loc17)\0A %74 = \22stable_mosaic_gpu.arith.constant\22() {value = 1 : index} : () -> index loc(#loc17)\0A %75 = \22stable_mosaic_gpu.arith.constant\22() {value = 128 : index} : () -> index loc(#loc17)\0A %76 = \22stable_mosaic_gpu.arith.constant\22() {value = 1 : index} : () -> index loc(#loc17)\0A %77 = \22stable_mosaic_gpu.arith.constant\22() {value = 1 : index} : () -> index loc(#loc17)\0A %78 = \22stable_mosaic_gpu.arith.constant\22() {value = 2056 : i32} : () -> i32 loc(#loc17)\0A %79 = \22stable_mosaic_gpu.gpu.launch\22(%0, %72, %73, %74, %75, %76, %77, %78) ({\0A ^bb0(%arg2: index loc(\22-\22:94:40), %arg3: index loc(\22-\22:94:47), %arg4: index loc(\22-\22:94:54), %arg5: index loc(\22-\22:94:116), %arg6: index loc(\22-\22:94:123), %arg7: index loc(\22-\22:94:130), %arg8: index loc(\22-\22:94:65), %arg9: index loc(\22-\22:94:78), %arg10: index loc(\22-\22:94:91), %arg11: index loc(\22-\22:94:141), %arg12: index loc(\22-\22:94:157), %arg13: index loc(\22-\22:94:174)):\0A %80 = \22stable_mosaic_gpu.gpu.dynamic_shared_memory\22() : () -> memref> loc(#loc17)\0A %81 = \22stable_mosaic_gpu.builtin.unrealized_conversion_cast\22(%71) : (!llvm.array<256 x i8>) -> !llvm.ptr loc(#loc17)\0A %82 = \22stable_mosaic_gpu.llvm.getelementptr\22(%81) {elem_type = i8, rawConstantIndices = array} : (!llvm.ptr) -> !llvm.ptr loc(#loc18)\0A %83 = \22stable_mosaic_gpu.llvm.getelementptr\22(%81) {elem_type = i8, rawConstantIndices = array} : (!llvm.ptr) -> !llvm.ptr loc(#loc19)\0A %84 = \22stable_mosaic_gpu.arith.constant\22() {value = 0 : index} : () -> index loc(#loc17)\0A %85 = \22stable_mosaic_gpu.memref.view\22(%80, %84) : (memref>, index) -> memref<2048xi8, #gpu.address_space> loc(#loc17)\0A %86 = \22stable_mosaic_gpu.builtin.unrealized_conversion_cast\22(%80) : (memref>) -> !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)> loc(#loc17)\0A %87 = 
\22stable_mosaic_gpu.llvm.extractvalue\22(%86) {position = array} : (!llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>) -> !llvm.ptr<3> loc(#loc17)\0A %88 = \22stable_mosaic_gpu.llvm.extractvalue\22(%86) {position = array} : (!llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>) -> i64 loc(#loc17)\0A %89 = \22stable_mosaic_gpu.arith.constant\22() {value = 1 : i64} : () -> i64 loc(#loc17)\0A %90 = \22stable_mosaic_gpu.llvm.mul\22(%88, %89) : (i64, i64) -> i64 loc(#loc17)\0A %91 = \22stable_mosaic_gpu.llvm.ptrtoint\22(%87) : (!llvm.ptr<3>) -> i64 loc(#loc17)\0A %92 = \22stable_mosaic_gpu.llvm.add\22(%91, %90) : (i64, i64) -> i64 loc(#loc17)\0A %93 = \22stable_mosaic_gpu.llvm.inttoptr\22(%92) : (i64) -> !llvm.ptr<3> loc(#loc17)\0A %94 = \22stable_mosaic_gpu.llvm.getelementptr\22(%93) {elem_type = i8, rawConstantIndices = array} : (!llvm.ptr<3>) -> !llvm.ptr<3> loc(#loc17)\0A %95 = \22stable_mosaic_gpu.memref.alloca\22() {operandSegmentSizes = array} : () -> memref loc(#loc17)\0A %96 = \22stable_mosaic_gpu.arith.constant\22() {value = 0 : i32} : () -> i32 loc(#loc17)\0A \22stable_mosaic_gpu.memref.store\22(%96, %95) : (i32, memref) -> () loc(#loc17)\0A %97 = \22stable_mosaic_gpu.nvvm.elect.sync\22() : () -> i1 loc(#loc17)\0A %98 = \22stable_mosaic_gpu.gpu.thread_id\22() {dimension = #gpu} : () -> index loc(#loc17)\0A %99 = \22stable_mosaic_gpu.arith.index_cast\22(%98) : (index) -> i32 loc(#loc17)\0A %100 = \22stable_mosaic_gpu.gpu.block_dim\22() {dimension = #gpu} : () -> index loc(#loc17)\0A %101 = \22stable_mosaic_gpu.arith.index_cast\22(%100) : (index) -> i32 loc(#loc17)\0A %102 = \22stable_mosaic_gpu.gpu.thread_id\22() {dimension = #gpu} : () -> index loc(#loc17)\0A %103 = \22stable_mosaic_gpu.arith.index_cast\22(%102) : (index) -> i32 loc(#loc17)\0A %104 = \22stable_mosaic_gpu.arith.muli\22(%103, %101) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc17)\0A %105 = \22stable_mosaic_gpu.arith.addi\22(%99, %104) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc17)\0A %106 = \22stable_mosaic_gpu.gpu.block_dim\22() {dimension = #gpu} : () -> index loc(#loc17)\0A %107 = \22stable_mosaic_gpu.arith.index_cast\22(%106) : (index) -> i32 loc(#loc17)\0A %108 = \22stable_mosaic_gpu.arith.muli\22(%101, %107) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc17)\0A %109 = \22stable_mosaic_gpu.gpu.thread_id\22() {dimension = #gpu} : () -> index loc(#loc17)\0A %110 = \22stable_mosaic_gpu.arith.index_cast\22(%109) : (index) -> i32 loc(#loc17)\0A %111 = \22stable_mosaic_gpu.arith.muli\22(%110, %108) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc17)\0A %112 = \22stable_mosaic_gpu.arith.addi\22(%105, %111) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc17)\0A %113 = \22stable_mosaic_gpu.gpu.block_dim\22() {dimension = #gpu} : () -> index loc(#loc17)\0A %114 = \22stable_mosaic_gpu.arith.index_cast\22(%113) : (index) -> i32 loc(#loc17)\0A %115 = \22stable_mosaic_gpu.arith.muli\22(%108, %114) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc17)\0A %116 = \22stable_mosaic_gpu.arith.constant\22() {value = 5 : i32} : () -> i32 loc(#loc17)\0A %117 = \22stable_mosaic_gpu.arith.shrui\22(%112, %116) : (i32, i32) -> i32 loc(#loc17)\0A %118 = \22stable_mosaic_gpu.arith.constant\22() {value = -1 : i32} : () -> i32 loc(#loc17)\0A %119 = \22stable_mosaic_gpu.arith.constant\22() {value = 0 : i32} : () -> i32 loc(#loc17)\0A %120 = \22stable_mosaic_gpu.arith.constant\22() {value = 31 : i32} : () -> i32 
loc(#loc17)\0A %121 = \22stable_mosaic_gpu.nvvm.shfl.sync\22(%118, %117, %119, %120) {kind = #nvvm} : (i32, i32, i32, i32) -> i32 loc(#loc17)\0A %122 = \22stable_mosaic_gpu.arith.constant\22() {value = 0 : i32} : () -> i32 loc(#loc17)\0A %123 = \22stable_mosaic_gpu.arith.cmpi\22(%121, %122) {predicate = 0 : i64} : (i32, i32) -> i1 loc(#loc17)\0A %124 = \22stable_mosaic_gpu.arith.andi\22(%123, %97) : (i1, i1) -> i1 loc(#loc17)\0A \22stable_mosaic_gpu.scf.if\22(%124) ({\0A %332 = \22stable_mosaic_gpu.llvm.getelementptr\22(%94) {elem_type = i64, rawConstantIndices = array} : (!llvm.ptr<3>) -> !llvm.ptr<3> loc(#loc17)\0A %333 = \22stable_mosaic_gpu.arith.constant\22() {value = 128 : i32} : () -> i32 loc(#loc17)\0A \22stable_mosaic_gpu.nvvm.mbarrier.init.shared\22(%332, %333) : (!llvm.ptr<3>, i32) -> () loc(#loc17)\0A \22stable_mosaic_gpu.scf.yield\22() : () -> () loc(#loc13)\0A }, {\0A }) : (i1) -> () loc(#loc17)\0A %125 = \22stable_mosaic_gpu.arith.constant\22() {value = 0 : i32} : () -> i32 loc(#loc17)\0A \22stable_mosaic_gpu.nvvm.fence.mbarrier.init\22() : () -> () loc(#loc17)\0A \22stable_mosaic_gpu.gpu.barrier\22() : () -> () loc(#loc17)\0A %126 = \22stable_mosaic_gpu.nvvm.elect.sync\22() : () -> i1 loc(#loc17)\0A %127 = \22stable_mosaic_gpu.gpu.thread_id\22() {dimension = #gpu} : () -> index loc(#loc17)\0A %128 = \22stable_mosaic_gpu.arith.index_cast\22(%127) : (index) -> i32 loc(#loc17)\0A %129 = \22stable_mosaic_gpu.gpu.block_dim\22() {dimension = #gpu} : () -> index loc(#loc17)\0A %130 = \22stable_mosaic_gpu.arith.index_cast\22(%129) : (index) -> i32 loc(#loc17)\0A %131 = \22stable_mosaic_gpu.gpu.thread_id\22() {dimension = #gpu} : () -> index loc(#loc17)\0A %132 = \22stable_mosaic_gpu.arith.index_cast\22(%131) : (index) -> i32 loc(#loc17)\0A %133 = \22stable_mosaic_gpu.arith.muli\22(%132, %130) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc17)\0A %134 = \22stable_mosaic_gpu.arith.addi\22(%128, %133) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc17)\0A %135 = \22stable_mosaic_gpu.gpu.block_dim\22() {dimension = #gpu} : () -> index loc(#loc17)\0A %136 = \22stable_mosaic_gpu.arith.index_cast\22(%135) : (index) -> i32 loc(#loc17)\0A %137 = \22stable_mosaic_gpu.arith.muli\22(%130, %136) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc17)\0A %138 = \22stable_mosaic_gpu.gpu.thread_id\22() {dimension = #gpu} : () -> index loc(#loc17)\0A %139 = \22stable_mosaic_gpu.arith.index_cast\22(%138) : (index) -> i32 loc(#loc17)\0A %140 = \22stable_mosaic_gpu.arith.muli\22(%139, %137) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc17)\0A %141 = \22stable_mosaic_gpu.arith.addi\22(%134, %140) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc17)\0A %142 = \22stable_mosaic_gpu.gpu.block_dim\22() {dimension = #gpu} : () -> index loc(#loc17)\0A %143 = \22stable_mosaic_gpu.arith.index_cast\22(%142) : (index) -> i32 loc(#loc17)\0A %144 = \22stable_mosaic_gpu.arith.muli\22(%137, %143) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc17)\0A %145 = \22stable_mosaic_gpu.arith.constant\22() {value = 5 : i32} : () -> i32 loc(#loc17)\0A %146 = \22stable_mosaic_gpu.arith.shrui\22(%141, %145) : (i32, i32) -> i32 loc(#loc17)\0A %147 = \22stable_mosaic_gpu.arith.constant\22() {value = -1 : i32} : () -> i32 loc(#loc17)\0A %148 = \22stable_mosaic_gpu.arith.constant\22() {value = 0 : i32} : () -> i32 loc(#loc17)\0A %149 = \22stable_mosaic_gpu.arith.constant\22() {value = 31 : i32} : () -> i32 loc(#loc17)\0A %150 = 
\22stable_mosaic_gpu.nvvm.shfl.sync\22(%147, %146, %148, %149) {kind = #nvvm} : (i32, i32, i32, i32) -> i32 loc(#loc17)\0A %151 = \22stable_mosaic_gpu.arith.constant\22() {value = 4 : i32} : () -> i32 loc(#loc17)\0A %152 = \22stable_mosaic_gpu.arith.remui\22(%150, %151) : (i32, i32) -> i32 loc(#loc17)\0A %153 = \22stable_mosaic_gpu.arith.constant\22() {value = 0 : i32} : () -> i32 loc(#loc17)\0A %154 = \22stable_mosaic_gpu.arith.cmpi\22(%152, %153) {predicate = 0 : i64} : (i32, i32) -> i1 loc(#loc17)\0A %155 = \22stable_mosaic_gpu.arith.andi\22(%154, %126) : (i1, i1) -> i1 loc(#loc17)\0A %156 = \22stable_mosaic_gpu.nvvm.elect.sync\22() : () -> i1 loc(#loc17)\0A %157 = \22stable_mosaic_gpu.gpu.block_id\22() {dimension = #gpu} : () -> index loc(#loc17)\0A %158 = \22stable_mosaic_gpu.arith.index_cast\22(%157) : (index) -> i32 loc(#loc17)\0A %159 = \22stable_mosaic_gpu.gpu.dynamic_shared_memory\22() : () -> memref> loc(#loc20)\0A %160 = \22stable_mosaic_gpu.arith.constant\22() {value = 0 : index} : () -> index loc(#loc20)\0A %161 = \22stable_mosaic_gpu.memref.view\22(%159, %160) : (memref>, index) -> memref<1x256xf32, #gpu.address_space> loc(#loc20)\0A %162 = \22stable_mosaic_gpu.gpu.dynamic_shared_memory\22() : () -> memref> loc(#loc20)\0A %163 = \22stable_mosaic_gpu.arith.constant\22() {value = 1024 : index} : () -> index loc(#loc20)\0A %164 = \22stable_mosaic_gpu.memref.view\22(%162, %163) : (memref>, index) -> memref<1x256xf32, #gpu.address_space> loc(#loc20)\0A %165 = \22stable_mosaic_gpu.arith.constant\22() {value = 0 : index} : () -> index loc(#loc19)\0A %166 = \22stable_mosaic_gpu.memref.subview\22(%161, %165) {operandSegmentSizes = array, static_offsets = array, static_sizes = array, static_strides = array} : (memref<1x256xf32, #gpu.address_space>, index) -> memref<256xf32, strided<[1], offset: ?>, #gpu.address_space> loc(#loc19)\0A %167 = \22stable_mosaic_gpu.arith.constant\22() {value = 0 : index} : () -> index loc(#loc19)\0A %168 = \22stable_mosaic_gpu.arith.index_castui\22(%167) : (index) -> i32 loc(#loc19)\0A %169 = \22stable_mosaic_gpu.arith.addi\22(%125, %168) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc19)\0A %170 = \22stable_mosaic_gpu.arith.constant\22() {value = 8 : i32} : () -> i32 loc(#loc19)\0A %171 = \22stable_mosaic_gpu.llvm.getelementptr\22(%94, %169) {elem_type = i64, rawConstantIndices = array} : (!llvm.ptr<3>, i32) -> !llvm.ptr<3> loc(#loc19)\0A \22stable_mosaic_gpu.nvvm.mbarrier.arrive.expect_tx.shared\22(%171, %170) : (!llvm.ptr<3>, i32) -> () loc(#loc19)\0A %172 = \22stable_mosaic_gpu.arith.constant\22() {value = 0 : index} : () -> index loc(#loc19)\0A %173 = \22stable_mosaic_gpu.arith.index_cast\22(%172) : (index) -> i32 loc(#loc19)\0A %174 = \22stable_mosaic_gpu.builtin.unrealized_conversion_cast\22(%166) : (memref<256xf32, strided<[1], offset: ?>, #gpu.address_space>) -> !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)> loc(#loc19)\0A %175 = \22stable_mosaic_gpu.llvm.extractvalue\22(%174) {position = array} : (!llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>) -> !llvm.ptr<3> loc(#loc19)\0A %176 = \22stable_mosaic_gpu.llvm.extractvalue\22(%174) {position = array} : (!llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>) -> i64 loc(#loc19)\0A %177 = \22stable_mosaic_gpu.arith.constant\22() {value = 4 : i64} : () -> i64 loc(#loc19)\0A %178 = \22stable_mosaic_gpu.llvm.mul\22(%176, %177) : (i64, i64) -> i64 loc(#loc19)\0A %179 = \22stable_mosaic_gpu.llvm.ptrtoint\22(%175) : (!llvm.ptr<3>) -> i64 
loc(#loc19)\0A %180 = \22stable_mosaic_gpu.llvm.add\22(%179, %178) : (i64, i64) -> i64 loc(#loc19)\0A %181 = \22stable_mosaic_gpu.llvm.inttoptr\22(%180) : (i64) -> !llvm.ptr<3> loc(#loc19)\0A %182 = \22stable_mosaic_gpu.arith.constant\22() {value = 1024 : i32} : () -> i32 loc(#loc19)\0A %183 = \22stable_mosaic_gpu.llvm.getelementptr\22(%94, %169) {elem_type = i64, rawConstantIndices = array} : (!llvm.ptr<3>, i32) -> !llvm.ptr<3> loc(#loc19)\0A \22stable_mosaic_gpu.nvvm.cp.async.bulk.tensor.shared.cluster.global\22(%181, %83, %173, %183, %155) {operandSegmentSizes = array} : (!llvm.ptr<3>, !llvm.ptr, i32, !llvm.ptr<3>, i1) -> () loc(#loc19)\0A %184 = \22stable_mosaic_gpu.arith.constant\22() {value = 0 : i32} : () -> i32 loc(#loc21)\0A %185 = \22stable_mosaic_gpu.arith.constant\22() {value = 0 : i32} : () -> i32 loc(#loc21)\0A %186 = \22stable_mosaic_gpu.arith.addi\22(%185, %184) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc21)\0A %187 = \22stable_mosaic_gpu.arith.constant\22() {value = 1 : i32} : () -> i32 loc(#loc22)\0A %188 = \22stable_mosaic_gpu.arith.remsi\22(%186, %187) : (i32, i32) -> i32 loc(#loc22)\0A %189 = \22stable_mosaic_gpu.arith.index_cast\22(%188) : (i32) -> index loc(#loc23)\0A %190 = \22stable_mosaic_gpu.arith.index_castui\22(%189) : (index) -> i32 loc(#loc23)\0A %191 = \22stable_mosaic_gpu.arith.addi\22(%125, %190) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc23)\0A %192 = \22stable_mosaic_gpu.memref.load\22(%95) : (memref) -> i32 loc(#loc23)\0A %193 = \22stable_mosaic_gpu.arith.constant\22() {value = 1 : i32} : () -> i32 loc(#loc23)\0A %194 = \22stable_mosaic_gpu.arith.shli\22(%193, %191) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc23)\0A %195 = \22stable_mosaic_gpu.arith.andi\22(%192, %194) : (i32, i32) -> i32 loc(#loc23)\0A %196 = \22stable_mosaic_gpu.arith.constant\22() {value = 0 : i32} : () -> i32 loc(#loc23)\0A %197 = \22stable_mosaic_gpu.arith.cmpi\22(%195, %196) {predicate = 1 : i64} : (i32, i32) -> i1 loc(#loc23)\0A %198 = \22stable_mosaic_gpu.arith.xori\22(%192, %194) : (i32, i32) -> i32 loc(#loc23)\0A \22stable_mosaic_gpu.memref.store\22(%198, %95) : (i32, memref) -> () loc(#loc23)\0A %199 = \22stable_mosaic_gpu.arith.constant\22() {value = 10000000 : i32} : () -> i32 loc(#loc23)\0A %200 = \22stable_mosaic_gpu.arith.extui\22(%197) : (i1) -> i32 loc(#loc23)\0A %201 = \22stable_mosaic_gpu.llvm.getelementptr\22(%94, %191) {elem_type = i64, rawConstantIndices = array} : (!llvm.ptr<3>, i32) -> !llvm.ptr<3> loc(#loc23)\0A \22stable_mosaic_gpu.nvvm.mbarrier.try_wait.parity.shared\22(%201, %200, %199) : (!llvm.ptr<3>, i32, i32) -> () loc(#loc23)\0A %202 = \22stable_mosaic_gpu.arith.index_cast\22(%188) : (i32) -> index loc(#loc24)\0A %203 = \22stable_mosaic_gpu.memref.subview\22(%161, %202) {operandSegmentSizes = array, static_offsets = array, static_sizes = array, static_strides = array} : (memref<1x256xf32, #gpu.address_space>, index) -> memref<256xf32, strided<[1], offset: ?>, #gpu.address_space> loc(#loc24)\0A %204 = \22stable_mosaic_gpu.arith.index_cast\22(%188) : (i32) -> index loc(#loc24)\0A %205 = \22stable_mosaic_gpu.memref.subview\22(%164, %204) {operandSegmentSizes = array, static_offsets = array, static_sizes = array, static_strides = array} : (memref<1x256xf32, #gpu.address_space>, index) -> memref<256xf32, strided<[1], offset: ?>, #gpu.address_space> loc(#loc24)\0A %206 = \22stable_mosaic_gpu.gpu.block_id\22() {dimension = #gpu} : () -> index loc(#loc24)\0A %207 = 
\22stable_mosaic_gpu.arith.index_cast\22(%206) : (index) -> i32 loc(#loc24)\0A %208 = \22stable_mosaic_gpu.memref.subview\22(%203) {operandSegmentSizes = array, static_offsets = array, static_sizes = array, static_strides = array} : (memref<256xf32, strided<[1], offset: ?>, #gpu.address_space>) -> memref<256xf32, strided<[1], offset: ?>, #gpu.address_space> loc(#loc25)\0A %209 = \22stable_mosaic_gpu.memref.collapse_shape\22(%208) {reassociation = [[0]]} : (memref<256xf32, strided<[1], offset: ?>, #gpu.address_space>) -> memref<256xf32, strided<[1], offset: ?>, #gpu.address_space> loc(#loc25)\0A %210 = \22stable_mosaic_gpu.gpu.thread_id\22() {dimension = #gpu} : () -> index loc(#loc25)\0A %211 = \22stable_mosaic_gpu.arith.constant\22() {value = 128 : index} : () -> index loc(#loc25)\0A %212 = \22stable_mosaic_gpu.arith.remui\22(%210, %211) : (index, index) -> index loc(#loc25)\0A %213 = \22stable_mosaic_gpu.arith.constant\22() {value = 2 : index} : () -> index loc(#loc25)\0A %214 = \22stable_mosaic_gpu.arith.muli\22(%212, %213) {overflowFlags = #arith.overflow} : (index, index) -> index loc(#loc25)\0A %215 = \22stable_mosaic_gpu.arith.constant\22() {value = 0 : index} : () -> index loc(#loc25)\0A %216 = \22stable_mosaic_gpu.arith.addi\22(%214, %215) {overflowFlags = #arith.overflow} : (index, index) -> index loc(#loc25)\0A %217 = \22stable_mosaic_gpu.vector.load\22(%209, %216) : (memref<256xf32, strided<[1], offset: ?>, #gpu.address_space>, index) -> vector<2xf32> loc(#loc25)\0A %218 = \22stable_mosaic_gpu.arith.constant\22() {value = 1.000000e+00 : f32} : () -> f32 loc(#loc26)\0A %219 = \22stable_mosaic_gpu.vector.splat\22(%218) : (f32) -> vector<2xf32> loc(#loc26)\0A %220 = \22stable_mosaic_gpu.arith.addf\22(%217, %219) {fastmath = #arith.fastmath} : (vector<2xf32>, vector<2xf32>) -> vector<2xf32> loc(#loc26)\0A %221 = \22stable_mosaic_gpu.memref.subview\22(%205) {operandSegmentSizes = array, static_offsets = array, static_sizes = array, static_strides = array} : (memref<256xf32, strided<[1], offset: ?>, #gpu.address_space>) -> memref<256xf32, strided<[1], offset: ?>, #gpu.address_space> loc(#loc27)\0A %222 = \22stable_mosaic_gpu.memref.collapse_shape\22(%221) {reassociation = [[0]]} : (memref<256xf32, strided<[1], offset: ?>, #gpu.address_space>) -> memref<256xf32, strided<[1], offset: ?>, #gpu.address_space> loc(#loc27)\0A %223 = \22stable_mosaic_gpu.gpu.thread_id\22() {dimension = #gpu} : () -> index loc(#loc27)\0A %224 = \22stable_mosaic_gpu.arith.constant\22() {value = 128 : index} : () -> index loc(#loc27)\0A %225 = \22stable_mosaic_gpu.arith.remui\22(%223, %224) : (index, index) -> index loc(#loc27)\0A %226 = \22stable_mosaic_gpu.arith.constant\22() {value = 2 : index} : () -> index loc(#loc27)\0A %227 = \22stable_mosaic_gpu.arith.muli\22(%225, %226) {overflowFlags = #arith.overflow} : (index, index) -> index loc(#loc27)\0A %228 = \22stable_mosaic_gpu.arith.constant\22() {value = 0 : index} : () -> index loc(#loc27)\0A %229 = \22stable_mosaic_gpu.arith.addi\22(%227, %228) {overflowFlags = #arith.overflow} : (index, index) -> index loc(#loc27)\0A %230 = \22stable_mosaic_gpu.vector.load\22(%222, %229) : (memref<256xf32, strided<[1], offset: ?>, #gpu.address_space>, index) -> vector<2xf32> loc(#loc27)\0A %231 = \22stable_mosaic_gpu.memref.collapse_shape\22(%221) {reassociation = [[0]]} : (memref<256xf32, strided<[1], offset: ?>, #gpu.address_space>) -> memref<256xf32, strided<[1], offset: ?>, #gpu.address_space> loc(#loc27)\0A %232 = \22stable_mosaic_gpu.gpu.thread_id\22() {dimension 
= #gpu} : () -> index loc(#loc27)\0A %233 = \22stable_mosaic_gpu.arith.constant\22() {value = 128 : index} : () -> index loc(#loc27)\0A %234 = \22stable_mosaic_gpu.arith.remui\22(%232, %233) : (index, index) -> index loc(#loc27)\0A %235 = \22stable_mosaic_gpu.arith.constant\22() {value = 2 : index} : () -> index loc(#loc27)\0A %236 = \22stable_mosaic_gpu.arith.muli\22(%234, %235) {overflowFlags = #arith.overflow} : (index, index) -> index loc(#loc27)\0A %237 = \22stable_mosaic_gpu.arith.constant\22() {value = 0 : index} : () -> index loc(#loc27)\0A %238 = \22stable_mosaic_gpu.arith.addi\22(%236, %237) {overflowFlags = #arith.overflow} : (index, index) -> index loc(#loc27)\0A \22stable_mosaic_gpu.vector.store\22(%220, %231, %238) : (vector<2xf32>, memref<256xf32, strided<[1], offset: ?>, #gpu.address_space>, index) -> () loc(#loc27)\0A \22stable_mosaic_gpu.nvvm.cp.async.bulk.commit.group\22() : () -> () loc(#loc28)\0A %239 = \22stable_mosaic_gpu.arith.constant\22() {value = 1 : i32} : () -> i32 loc(#loc29)\0A %240 = \22stable_mosaic_gpu.arith.addi\22(%186, %239) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc29)\0A %241 = \22stable_mosaic_gpu.arith.constant\22() {value = 1 : i32} : () -> i32 loc(#loc22)\0A %242 = \22stable_mosaic_gpu.arith.remsi\22(%240, %241) : (i32, i32) -> i32 loc(#loc22)\0A %243 = \22stable_mosaic_gpu.arith.constant\22() {value = 0 : i32} : () -> i32 loc(#loc30)\0A %244 = \22stable_mosaic_gpu.arith.cmpi\22(%186, %243) {predicate = 9 : i64} : (i32, i32) -> i1 loc(#loc30)\0A %245 = \22stable_mosaic_gpu.arith.constant\22() {value = 1 : i32} : () -> i32 loc(#loc31)\0A %246 = \22stable_mosaic_gpu.arith.cmpi\22(%240, %245) {predicate = 6 : i64} : (i32, i32) -> i1 loc(#loc31)\0A %247 = \22stable_mosaic_gpu.arith.andi\22(%244, %246) : (i1, i1) -> i1 loc(#loc32)\0A %248 = \22stable_mosaic_gpu.arith.extui\22(%247) : (i1) -> i32 loc(#loc33)\0A %249 = \22stable_mosaic_gpu.arith.index_cast\22(%248) : (i32) -> index loc(#loc34)\0A \22stable_mosaic_gpu.scf.index_switch\22(%249) ({\0A %313 = \22stable_mosaic_gpu.arith.index_cast\22(%242) : (i32) -> index loc(#loc19)\0A %314 = \22stable_mosaic_gpu.memref.subview\22(%161, %313) {operandSegmentSizes = array, static_offsets = array, static_sizes = array, static_strides = array} : (memref<1x256xf32, #gpu.address_space>, index) -> memref<256xf32, strided<[1], offset: ?>, #gpu.address_space> loc(#loc19)\0A %315 = \22stable_mosaic_gpu.arith.index_cast\22(%242) : (i32) -> index loc(#loc19)\0A %316 = \22stable_mosaic_gpu.arith.index_castui\22(%315) : (index) -> i32 loc(#loc19)\0A %317 = \22stable_mosaic_gpu.arith.addi\22(%125, %316) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc19)\0A %318 = \22stable_mosaic_gpu.arith.constant\22() {value = 8 : i32} : () -> i32 loc(#loc19)\0A %319 = \22stable_mosaic_gpu.llvm.getelementptr\22(%94, %317) {elem_type = i64, rawConstantIndices = array} : (!llvm.ptr<3>, i32) -> !llvm.ptr<3> loc(#loc19)\0A \22stable_mosaic_gpu.nvvm.mbarrier.arrive.expect_tx.shared\22(%319, %318) : (!llvm.ptr<3>, i32) -> () loc(#loc19)\0A %320 = \22stable_mosaic_gpu.arith.constant\22() {value = 0 : index} : () -> index loc(#loc19)\0A %321 = \22stable_mosaic_gpu.arith.index_cast\22(%320) : (index) -> i32 loc(#loc19)\0A %322 = \22stable_mosaic_gpu.builtin.unrealized_conversion_cast\22(%314) : (memref<256xf32, strided<[1], offset: ?>, #gpu.address_space>) -> !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)> loc(#loc19)\0A %323 = \22stable_mosaic_gpu.llvm.extractvalue\22(%322) {position = 
array} : (!llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>) -> !llvm.ptr<3> loc(#loc19)\0A %324 = \22stable_mosaic_gpu.llvm.extractvalue\22(%322) {position = array} : (!llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>) -> i64 loc(#loc19)\0A %325 = \22stable_mosaic_gpu.arith.constant\22() {value = 4 : i64} : () -> i64 loc(#loc19)\0A %326 = \22stable_mosaic_gpu.llvm.mul\22(%324, %325) : (i64, i64) -> i64 loc(#loc19)\0A %327 = \22stable_mosaic_gpu.llvm.ptrtoint\22(%323) : (!llvm.ptr<3>) -> i64 loc(#loc19)\0A %328 = \22stable_mosaic_gpu.llvm.add\22(%327, %326) : (i64, i64) -> i64 loc(#loc19)\0A %329 = \22stable_mosaic_gpu.llvm.inttoptr\22(%328) : (i64) -> !llvm.ptr<3> loc(#loc19)\0A %330 = \22stable_mosaic_gpu.arith.constant\22() {value = 1024 : i32} : () -> i32 loc(#loc19)\0A %331 = \22stable_mosaic_gpu.llvm.getelementptr\22(%94, %317) {elem_type = i64, rawConstantIndices = array} : (!llvm.ptr<3>, i32) -> !llvm.ptr<3> loc(#loc19)\0A \22stable_mosaic_gpu.nvvm.cp.async.bulk.tensor.shared.cluster.global\22(%329, %83, %321, %331, %155) {operandSegmentSizes = array} : (!llvm.ptr<3>, !llvm.ptr, i32, !llvm.ptr<3>, i1) -> () loc(#loc19)\0A \22stable_mosaic_gpu.scf.yield\22() : () -> () loc(#loc16)\0A }, {\0A \22stable_mosaic_gpu.scf.yield\22() : () -> () loc(#loc34)\0A }) {cases = array} : (index) -> () loc(#loc34)\0A %250 = \22stable_mosaic_gpu.arith.constant\22() {value = 1 : i32} : () -> i32 loc(#loc21)\0A %251 = \22stable_mosaic_gpu.arith.addi\22(%184, %250) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc21)\0A \22stable_mosaic_gpu.nvvm.fence.proxy\22() {kind = #nvvm.proxy_kind, space = #nvvm.shared_space} : () -> () loc(#loc35)\0A %252 = \22stable_mosaic_gpu.gpu.thread_id\22() {dimension = #gpu} : () -> index loc(#loc35)\0A %253 = \22stable_mosaic_gpu.arith.index_cast\22(%252) : (index) -> i32 loc(#loc35)\0A %254 = \22stable_mosaic_gpu.gpu.block_dim\22() {dimension = #gpu} : () -> index loc(#loc35)\0A %255 = \22stable_mosaic_gpu.arith.index_cast\22(%254) : (index) -> i32 loc(#loc35)\0A %256 = \22stable_mosaic_gpu.gpu.thread_id\22() {dimension = #gpu} : () -> index loc(#loc35)\0A %257 = \22stable_mosaic_gpu.arith.index_cast\22(%256) : (index) -> i32 loc(#loc35)\0A %258 = \22stable_mosaic_gpu.arith.muli\22(%257, %255) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc35)\0A %259 = \22stable_mosaic_gpu.arith.addi\22(%253, %258) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc35)\0A %260 = \22stable_mosaic_gpu.gpu.block_dim\22() {dimension = #gpu} : () -> index loc(#loc35)\0A %261 = \22stable_mosaic_gpu.arith.index_cast\22(%260) : (index) -> i32 loc(#loc35)\0A %262 = \22stable_mosaic_gpu.arith.muli\22(%255, %261) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc35)\0A %263 = \22stable_mosaic_gpu.gpu.thread_id\22() {dimension = #gpu} : () -> index loc(#loc35)\0A %264 = \22stable_mosaic_gpu.arith.index_cast\22(%263) : (index) -> i32 loc(#loc35)\0A %265 = \22stable_mosaic_gpu.arith.muli\22(%264, %262) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc35)\0A %266 = \22stable_mosaic_gpu.arith.addi\22(%259, %265) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc35)\0A %267 = \22stable_mosaic_gpu.gpu.block_dim\22() {dimension = #gpu} : () -> index loc(#loc35)\0A %268 = \22stable_mosaic_gpu.arith.index_cast\22(%267) : (index) -> i32 loc(#loc35)\0A %269 = \22stable_mosaic_gpu.arith.muli\22(%262, %268) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc35)\0A %270 = 
\22stable_mosaic_gpu.arith.constant\22() {value = 7 : i32} : () -> i32 loc(#loc35)\0A %271 = \22stable_mosaic_gpu.arith.shrui\22(%266, %270) : (i32, i32) -> i32 loc(#loc35)\0A %272 = \22stable_mosaic_gpu.arith.constant\22() {value = 1 : i32} : () -> i32 loc(#loc35)\0A %273 = \22stable_mosaic_gpu.arith.addi\22(%271, %272) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc35)\0A %274 = \22stable_mosaic_gpu.llvm.inline_asm\22(%273) {asm_string = \22bar.sync $0, 128;\22, constraints = \22r\22, has_side_effects} : (i32) -> !llvm.void loc(#loc35)\0A %275 = \22stable_mosaic_gpu.arith.constant\22() {value = 0 : i32} : () -> i32 loc(#loc22)\0A %276 = \22stable_mosaic_gpu.arith.constant\22() {value = 1 : i32} : () -> i32 loc(#loc22)\0A %277 = \22stable_mosaic_gpu.arith.remsi\22(%275, %276) : (i32, i32) -> i32 loc(#loc22)\0A %278 = \22stable_mosaic_gpu.arith.index_cast\22(%277) : (i32) -> index loc(#loc18)\0A %279 = \22stable_mosaic_gpu.memref.subview\22(%164, %278) {operandSegmentSizes = array, static_offsets = array, static_sizes = array, static_strides = array} : (memref<1x256xf32, #gpu.address_space>, index) -> memref<256xf32, strided<[1], offset: ?>, #gpu.address_space> loc(#loc18)\0A %280 = \22stable_mosaic_gpu.arith.constant\22() {value = 0 : index} : () -> index loc(#loc18)\0A %281 = \22stable_mosaic_gpu.arith.index_cast\22(%280) : (index) -> i32 loc(#loc18)\0A %282 = \22stable_mosaic_gpu.builtin.unrealized_conversion_cast\22(%279) : (memref<256xf32, strided<[1], offset: ?>, #gpu.address_space>) -> !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)> loc(#loc18)\0A %283 = \22stable_mosaic_gpu.llvm.extractvalue\22(%282) {position = array} : (!llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>) -> !llvm.ptr<3> loc(#loc18)\0A %284 = \22stable_mosaic_gpu.llvm.extractvalue\22(%282) {position = array} : (!llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>) -> i64 loc(#loc18)\0A %285 = \22stable_mosaic_gpu.arith.constant\22() {value = 4 : i64} : () -> i64 loc(#loc18)\0A %286 = \22stable_mosaic_gpu.llvm.mul\22(%284, %285) : (i64, i64) -> i64 loc(#loc18)\0A %287 = \22stable_mosaic_gpu.llvm.ptrtoint\22(%283) : (!llvm.ptr<3>) -> i64 loc(#loc18)\0A %288 = \22stable_mosaic_gpu.llvm.add\22(%287, %286) : (i64, i64) -> i64 loc(#loc18)\0A %289 = \22stable_mosaic_gpu.llvm.inttoptr\22(%288) : (i64) -> !llvm.ptr<3> loc(#loc18)\0A \22stable_mosaic_gpu.nvvm.cp.async.bulk.tensor.global.shared.cta\22(%82, %289, %281, %155) {operandSegmentSizes = array} : (!llvm.ptr, !llvm.ptr<3>, i32, i1) -> () loc(#loc18)\0A \22stable_mosaic_gpu.nvvm.cp.async.bulk.commit.group\22() : () -> () loc(#loc28)\0A \22stable_mosaic_gpu.nvvm.cp.async.bulk.wait_group\22() {group = 0 : i32} : () -> () loc(#loc36)\0A %290 = \22stable_mosaic_gpu.gpu.thread_id\22() {dimension = #gpu} : () -> index loc(#loc36)\0A %291 = \22stable_mosaic_gpu.arith.index_cast\22(%290) : (index) -> i32 loc(#loc36)\0A %292 = \22stable_mosaic_gpu.gpu.block_dim\22() {dimension = #gpu} : () -> index loc(#loc36)\0A %293 = \22stable_mosaic_gpu.arith.index_cast\22(%292) : (index) -> i32 loc(#loc36)\0A %294 = \22stable_mosaic_gpu.gpu.thread_id\22() {dimension = #gpu} : () -> index loc(#loc36)\0A %295 = \22stable_mosaic_gpu.arith.index_cast\22(%294) : (index) -> i32 loc(#loc36)\0A %296 = \22stable_mosaic_gpu.arith.muli\22(%295, %293) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc36)\0A %297 = \22stable_mosaic_gpu.arith.addi\22(%291, %296) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 
loc(#loc36)\0A %298 = \22stable_mosaic_gpu.gpu.block_dim\22() {dimension = #gpu} : () -> index loc(#loc36)\0A %299 = \22stable_mosaic_gpu.arith.index_cast\22(%298) : (index) -> i32 loc(#loc36)\0A %300 = \22stable_mosaic_gpu.arith.muli\22(%293, %299) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc36)\0A %301 = \22stable_mosaic_gpu.gpu.thread_id\22() {dimension = #gpu} : () -> index loc(#loc36)\0A %302 = \22stable_mosaic_gpu.arith.index_cast\22(%301) : (index) -> i32 loc(#loc36)\0A %303 = \22stable_mosaic_gpu.arith.muli\22(%302, %300) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc36)\0A %304 = \22stable_mosaic_gpu.arith.addi\22(%297, %303) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc36)\0A %305 = \22stable_mosaic_gpu.gpu.block_dim\22() {dimension = #gpu} : () -> index loc(#loc36)\0A %306 = \22stable_mosaic_gpu.arith.index_cast\22(%305) : (index) -> i32 loc(#loc36)\0A %307 = \22stable_mosaic_gpu.arith.muli\22(%300, %306) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc36)\0A %308 = \22stable_mosaic_gpu.arith.constant\22() {value = 7 : i32} : () -> i32 loc(#loc36)\0A %309 = \22stable_mosaic_gpu.arith.shrui\22(%304, %308) : (i32, i32) -> i32 loc(#loc36)\0A %310 = \22stable_mosaic_gpu.arith.constant\22() {value = 1 : i32} : () -> i32 loc(#loc36)\0A %311 = \22stable_mosaic_gpu.arith.addi\22(%309, %310) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc36)\0A %312 = \22stable_mosaic_gpu.llvm.inline_asm\22(%311) {asm_string = \22bar.sync $0, 128;\22, constraints = \22r\22, has_side_effects} : (i32) -> !llvm.void loc(#loc36)\0A \22stable_mosaic_gpu.gpu.terminator\22() : () -> () loc(#loc17)\0A }) {operandSegmentSizes = array, workgroup_attributions = 0 : i64} : (!gpu.async.token, index, index, index, index, index, index, i32) -> !gpu.async.token loc(#loc17)\0A \22stable_mosaic_gpu.func.return\22() : () -> () loc(#loc17)\0A }) {function_type = (!llvm.ptr, !llvm.ptr) -> (), llvm.emit_c_interface, sym_name = \22mosaic_gpu_body\22} : () -> () loc(#loc17)\0A}) {stable_mosaic_gpu.version = 1 : i64} : () -> () loc(#loc17)\0A#loc13 = loc(\22-\22:141:7)\0A#loc14 = loc(\22third_party/py/jax/tests/pallas/export_back_compat_pallas_test.py\22:78:19)\0A#loc15 = loc(\22third_party/py/jax/tests/pallas/export_back_compat_pallas_test.py\22:78:6)\0A#loc16 = loc(\22-\22:279:7)\0A#loc18 = loc(\22/copy_smem_to_gmem\22(#loc))\0A#loc19 = loc(\22/copy_gmem_to_smem\22(#loc))\0A#loc20 = loc(\22/run_scoped\22(#loc))\0A#loc21 = loc(\22/scan\22(#loc))\0A#loc22 = loc(\22/rem\22(#loc))\0A#loc23 = loc(\22/barrier_wait\22(#loc))\0A#loc24 = loc(\22/jaxpr_call\22(#loc))\0A#loc25 = loc(\22/get\22(#loc14))\0A#loc26 = loc(\22/add\22(#loc14))\0A#loc27 = loc(\22/swap\22(#loc15))\0A#loc28 = loc(\22/commit_group\22(#loc))\0A#loc29 = loc(\22/add\22(#loc))\0A#loc30 = loc(\22/ge\22(#loc))\0A#loc31 = loc(\22/lt\22(#loc))\0A#loc32 = loc(\22/and\22(#loc))\0A#loc33 = loc(\22/convert_element_type\22(#loc))\0A#loc34 = loc(\22/cond\22(#loc))\0A#loc35 = loc(\22/commit_smem\22(#loc))\0A#loc36 = loc(\22/wait_smem_to_gmem\22(#loc))\0A", operand_layouts = [dense<0> : tensor<1xindex>], result_layouts = [dense<0> : tensor<1xindex>]} : (tensor<256xf32>) -> tensor<256xf32> loc(#loc3) + return %0 : tensor<256xf32> loc(#loc) + } loc(#loc) +} loc(#loc) +#loc = loc(unknown) +#loc2 = loc("third_party/py/jax/tests/pallas/export_back_compat_pallas_test.py":83:4) +#loc3 = loc("jit(wrapped)/jit(main)/pallas_call"(#loc2)) +""", + 
mlir_module_serialized=b'ML\xefR\rStableHLO_v1.9.7\x00\x01\x19\x05\x01\x05\t\x01\x03\x0b\x03\x07\x0f\x13\x17\x03_=\x0f\x01\x1d\x07\x0f#\x0b\x0f\x0b\x0b\x0b\x0f\x0b\x0f\x0b\x13\x0b\x03!\x0b\x0f\x0f\x0b\x0b\x0f\x13\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b/\x01\x05\x0b\x0f\x03\x0b\x17\x17\x07\x13\x07\x02\xd5\x1f\x11\x03\x05\x03\x07\x07\t\x0b\x03\r\x03\x05\r\x11\x01\x00\x05\x0f\x05\x11\x05\x13\x1d\x13\x01\x05\x15\x1d\x17\x19\x05\x17\x17\x1b\xa7\t\x05\x19\x03\x01\x03\x03;\x03\x03#\r\x01#\x07\x03\x03)\r\x03+-\x1d\x1b\x1d\x1d\x1d\x1f\x1d!\x0b\x05\x1d#\x1d%\x05\x01\x1f\x0b\x11\x00\x00\x00\x00\x00\x00\x00\x00\x01\t\x01\x02\x02)\x03\x02\x08\t\x11\x03\x05\x03\x05\t)\x03\x05\r\x13\x04O\x05\x01Q\x01\x05\x01\x07\x04=\x03\x01\x05\x03P\x01\x03\x07\x04)\x03\x05\x0b\x03\x0b\x11\x00\x05F\x15\x05\x03\x05\x03\x01\x07\x04\x01\x03\x03\x06\x03\x01\x05\x01\x00D\xae\x05\'\x17\xa4\xa4\x05\x0f\x0b\x0f!\x85G\x11\x19%)9\x15\x1f\x11\x0f\x0b\x11builtin\x00vhlo\x00module\x00func_v1\x00custom_call_v1\x00return_v1\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_wrapped\x00args[0]\x00jit(wrapped)/jit(main)/pallas_call\x00third_party/py/jax/tests/pallas/export_back_compat_pallas_test.py\x00jax.result_info\x00result\x00main\x00public\x00\xa9C\xfb\x81\x9a1\xc2?\x0e\xf4\xe1\xe4\xe77\x03\xb6\x97\xe5G(]WR\x98\xeb{\xba\x8a\x84\x01\x12\'#loc = loc("third_party/py/jax/tests/pallas/export_back_compat_pallas_test.py":83:4)\n#loc1 = loc("-":94:40)\n#loc2 = loc("-":94:47)\n#loc3 = loc("-":94:54)\n#loc4 = loc("-":94:116)\n#loc5 = loc("-":94:123)\n#loc6 = loc("-":94:130)\n#loc7 = loc("-":94:65)\n#loc8 = loc("-":94:78)\n#loc9 = loc("-":94:91)\n#loc10 = loc("-":94:141)\n#loc11 = loc("-":94:157)\n#loc12 = loc("-":94:174)\n#loc17 = loc("jit(wrapped)/jit(main)/pallas_call"(#loc))\n"builtin.module"() <{sym_name = "add_one"}> ({\n "stable_mosaic_gpu.func.func"() ({\n }) {function_type = (!llvm.ptr, !llvm.ptr, i64, i64, !llvm.ptr, !llvm.ptr, i64, !llvm.ptr) -> (), sym_name = "mosaic_gpu_init_tma_desc", sym_visibility = "private"} : () -> () loc(#loc17)\n "stable_mosaic_gpu.llvm.mlir.global"() ({\n }) {addr_space = 4 : i32, global_type = !llvm.array<0 x i8>, linkage = #llvm.linkage, sym_name = "global_scratch", unnamed_addr = 0 : i64, visibility_ = 0 : i64} : () -> () loc(#loc17)\n "stable_mosaic_gpu.func.func"() ({\n ^bb0(%arg0: !llvm.ptr loc("jit(wrapped)/jit(main)/pallas_call"(#loc)), %arg1: !llvm.ptr loc("jit(wrapped)/jit(main)/pallas_call"(#loc))):\n %0 = "stable_mosaic_gpu.builtin.unrealized_conversion_cast"(%arg0) : (!llvm.ptr) -> !gpu.async.token loc(#loc17)\n %1 = "stable_mosaic_gpu.llvm.getelementptr"(%arg1) {elem_type = !llvm.ptr, rawConstantIndices = array} : (!llvm.ptr) -> !llvm.ptr loc(#loc17)\n %2 = "stable_mosaic_gpu.llvm.load"(%1) {ordering = 0 : i64} : (!llvm.ptr) -> !llvm.ptr loc(#loc17)\n %3 = "stable_mosaic_gpu.llvm.mlir.undef"() : () -> !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> loc(#loc17)\n %4 = "stable_mosaic_gpu.llvm.insertvalue"(%3, %2) {position = array} : (!llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>, !llvm.ptr) -> !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> loc(#loc17)\n %5 = "stable_mosaic_gpu.llvm.insertvalue"(%4, %2) {position = array} : (!llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>, !llvm.ptr) -> !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> loc(#loc17)\n %6 = "stable_mosaic_gpu.llvm.mlir.constant"() {value = 0 : i64} : () -> i64 loc(#loc17)\n %7 = "stable_mosaic_gpu.llvm.insertvalue"(%5, %6) {position = 
array} : (!llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>, i64) -> !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> loc(#loc17)\n %8 = "stable_mosaic_gpu.llvm.mlir.constant"() {value = 256 : i64} : () -> i64 loc(#loc17)\n %9 = "stable_mosaic_gpu.llvm.insertvalue"(%7, %8) {position = array} : (!llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>, i64) -> !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> loc(#loc17)\n %10 = "stable_mosaic_gpu.llvm.mlir.constant"() {value = 1 : i64} : () -> i64 loc(#loc17)\n %11 = "stable_mosaic_gpu.llvm.insertvalue"(%9, %10) {position = array} : (!llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>, i64) -> !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> loc(#loc17)\n %12 = "stable_mosaic_gpu.builtin.unrealized_conversion_cast"(%11) : (!llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>) -> memref<256xf32> loc(#loc17)\n %13 = "stable_mosaic_gpu.llvm.getelementptr"(%arg1) {elem_type = !llvm.ptr, rawConstantIndices = array} : (!llvm.ptr) -> !llvm.ptr loc(#loc17)\n %14 = "stable_mosaic_gpu.llvm.load"(%13) {ordering = 0 : i64} : (!llvm.ptr) -> !llvm.ptr loc(#loc17)\n %15 = "stable_mosaic_gpu.llvm.mlir.undef"() : () -> !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> loc(#loc17)\n %16 = "stable_mosaic_gpu.llvm.insertvalue"(%15, %14) {position = array} : (!llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>, !llvm.ptr) -> !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> loc(#loc17)\n %17 = "stable_mosaic_gpu.llvm.insertvalue"(%16, %14) {position = array} : (!llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>, !llvm.ptr) -> !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> loc(#loc17)\n %18 = "stable_mosaic_gpu.llvm.mlir.constant"() {value = 0 : i64} : () -> i64 loc(#loc17)\n %19 = "stable_mosaic_gpu.llvm.insertvalue"(%17, %18) {position = array} : (!llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>, i64) -> !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> loc(#loc17)\n %20 = "stable_mosaic_gpu.llvm.mlir.constant"() {value = 256 : i64} : () -> i64 loc(#loc17)\n %21 = "stable_mosaic_gpu.llvm.insertvalue"(%19, %20) {position = array} : (!llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>, i64) -> !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> loc(#loc17)\n %22 = "stable_mosaic_gpu.llvm.mlir.constant"() {value = 1 : i64} : () -> i64 loc(#loc17)\n %23 = "stable_mosaic_gpu.llvm.insertvalue"(%21, %22) {position = array} : (!llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>, i64) -> !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> loc(#loc17)\n %24 = "stable_mosaic_gpu.builtin.unrealized_conversion_cast"(%23) : (!llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>) -> memref<256xf32> loc(#loc17)\n %25 = "stable_mosaic_gpu.arith.constant"() {value = 1 : i64} : () -> i64 loc(#loc17)\n %26 = "stable_mosaic_gpu.llvm.alloca"(%25) {alignment = 64 : i64, elem_type = !llvm.array<256 x i8>} : (i64) -> !llvm.ptr loc(#loc17)\n %27 = "stable_mosaic_gpu.llvm.getelementptr"(%26) {elem_type = i8, rawConstantIndices = array} : (!llvm.ptr) -> !llvm.ptr loc(#loc17)\n %28:4 = "stable_mosaic_gpu.memref.extract_strided_metadata"(%12) : (memref<256xf32>) -> (memref, index, index, index) loc(#loc17)\n %29 = "stable_mosaic_gpu.memref.extract_aligned_pointer_as_index"(%12) : (memref<256xf32>) -> index loc(#loc17)\n %30 = "stable_mosaic_gpu.arith.index_cast"(%29) : (index) -> i64 
loc(#loc17)\n %31 = "stable_mosaic_gpu.llvm.inttoptr"(%30) : (i64) -> !llvm.ptr loc(#loc17)\n %32 = "stable_mosaic_gpu.arith.index_cast"(%28#1) : (index) -> i64 loc(#loc17)\n %33 = "stable_mosaic_gpu.llvm.getelementptr"(%31, %32) {elem_type = f32, rawConstantIndices = array} : (!llvm.ptr, i64) -> !llvm.ptr loc(#loc17)\n %34 = "stable_mosaic_gpu.arith.constant"() {value = 6 : i64} : () -> i64 loc(#loc17)\n %35 = "stable_mosaic_gpu.arith.constant"() {value = 1 : i64} : () -> i64 loc(#loc17)\n %36 = "stable_mosaic_gpu.arith.index_cast"(%28#2) : (index) -> i64 loc(#loc17)\n %37 = "stable_mosaic_gpu.arith.constant"() {value = 1 : i64} : () -> i64 loc(#loc17)\n %38 = "stable_mosaic_gpu.llvm.alloca"(%37) {elem_type = i64} : (i64) -> !llvm.ptr loc(#loc17)\n %39 = "stable_mosaic_gpu.llvm.getelementptr"(%38) {elem_type = i64, rawConstantIndices = array} : (!llvm.ptr) -> !llvm.ptr loc(#loc17)\n "stable_mosaic_gpu.llvm.store"(%36, %39) {ordering = 0 : i64} : (i64, !llvm.ptr) -> () loc(#loc17)\n %40 = "stable_mosaic_gpu.arith.index_cast"(%28#3) : (index) -> i64 loc(#loc17)\n %41 = "stable_mosaic_gpu.arith.constant"() {value = 1 : i64} : () -> i64 loc(#loc17)\n %42 = "stable_mosaic_gpu.llvm.alloca"(%41) {elem_type = i64} : (i64) -> !llvm.ptr loc(#loc17)\n %43 = "stable_mosaic_gpu.llvm.getelementptr"(%42) {elem_type = i64, rawConstantIndices = array} : (!llvm.ptr) -> !llvm.ptr loc(#loc17)\n "stable_mosaic_gpu.llvm.store"(%40, %43) {ordering = 0 : i64} : (i64, !llvm.ptr) -> () loc(#loc17)\n %44 = "stable_mosaic_gpu.arith.constant"() {value = 16 : i64} : () -> i64 loc(#loc17)\n %45 = "stable_mosaic_gpu.arith.constant"() {value = 256 : i64} : () -> i64 loc(#loc17)\n %46 = "stable_mosaic_gpu.arith.constant"() {value = 1 : i64} : () -> i64 loc(#loc17)\n %47 = "stable_mosaic_gpu.llvm.alloca"(%46) {elem_type = i64} : (i64) -> !llvm.ptr loc(#loc17)\n %48 = "stable_mosaic_gpu.llvm.getelementptr"(%47) {elem_type = i64, rawConstantIndices = array} : (!llvm.ptr) -> !llvm.ptr loc(#loc17)\n "stable_mosaic_gpu.llvm.store"(%45, %48) {ordering = 0 : i64} : (i64, !llvm.ptr) -> () loc(#loc17)\n "stable_mosaic_gpu.func.call"(%27, %33, %34, %35, %38, %42, %44, %47) {callee = @mosaic_gpu_init_tma_desc} : (!llvm.ptr, !llvm.ptr, i64, i64, !llvm.ptr, !llvm.ptr, i64, !llvm.ptr) -> () loc(#loc17)\n %49 = "stable_mosaic_gpu.llvm.getelementptr"(%26) {elem_type = i8, rawConstantIndices = array} : (!llvm.ptr) -> !llvm.ptr loc(#loc17)\n %50:4 = "stable_mosaic_gpu.memref.extract_strided_metadata"(%24) : (memref<256xf32>) -> (memref, index, index, index) loc(#loc17)\n %51 = "stable_mosaic_gpu.memref.extract_aligned_pointer_as_index"(%24) : (memref<256xf32>) -> index loc(#loc17)\n %52 = "stable_mosaic_gpu.arith.index_cast"(%51) : (index) -> i64 loc(#loc17)\n %53 = "stable_mosaic_gpu.llvm.inttoptr"(%52) : (i64) -> !llvm.ptr loc(#loc17)\n %54 = "stable_mosaic_gpu.arith.index_cast"(%50#1) : (index) -> i64 loc(#loc17)\n %55 = "stable_mosaic_gpu.llvm.getelementptr"(%53, %54) {elem_type = f32, rawConstantIndices = array} : (!llvm.ptr, i64) -> !llvm.ptr loc(#loc17)\n %56 = "stable_mosaic_gpu.arith.constant"() {value = 6 : i64} : () -> i64 loc(#loc17)\n %57 = "stable_mosaic_gpu.arith.constant"() {value = 1 : i64} : () -> i64 loc(#loc17)\n %58 = "stable_mosaic_gpu.arith.index_cast"(%50#2) : (index) -> i64 loc(#loc17)\n %59 = "stable_mosaic_gpu.arith.constant"() {value = 1 : i64} : () -> i64 loc(#loc17)\n %60 = "stable_mosaic_gpu.llvm.alloca"(%59) {elem_type = i64} : (i64) -> !llvm.ptr loc(#loc17)\n %61 = "stable_mosaic_gpu.llvm.getelementptr"(%60) 
{elem_type = i64, rawConstantIndices = array} : (!llvm.ptr) -> !llvm.ptr loc(#loc17)\n "stable_mosaic_gpu.llvm.store"(%58, %61) {ordering = 0 : i64} : (i64, !llvm.ptr) -> () loc(#loc17)\n %62 = "stable_mosaic_gpu.arith.index_cast"(%50#3) : (index) -> i64 loc(#loc17)\n %63 = "stable_mosaic_gpu.arith.constant"() {value = 1 : i64} : () -> i64 loc(#loc17)\n %64 = "stable_mosaic_gpu.llvm.alloca"(%63) {elem_type = i64} : (i64) -> !llvm.ptr loc(#loc17)\n %65 = "stable_mosaic_gpu.llvm.getelementptr"(%64) {elem_type = i64, rawConstantIndices = array} : (!llvm.ptr) -> !llvm.ptr loc(#loc17)\n "stable_mosaic_gpu.llvm.store"(%62, %65) {ordering = 0 : i64} : (i64, !llvm.ptr) -> () loc(#loc17)\n %66 = "stable_mosaic_gpu.arith.constant"() {value = 16 : i64} : () -> i64 loc(#loc17)\n %67 = "stable_mosaic_gpu.arith.constant"() {value = 256 : i64} : () -> i64 loc(#loc17)\n %68 = "stable_mosaic_gpu.arith.constant"() {value = 1 : i64} : () -> i64 loc(#loc17)\n %69 = "stable_mosaic_gpu.llvm.alloca"(%68) {elem_type = i64} : (i64) -> !llvm.ptr loc(#loc17)\n %70 = "stable_mosaic_gpu.llvm.getelementptr"(%69) {elem_type = i64, rawConstantIndices = array} : (!llvm.ptr) -> !llvm.ptr loc(#loc17)\n "stable_mosaic_gpu.llvm.store"(%67, %70) {ordering = 0 : i64} : (i64, !llvm.ptr) -> () loc(#loc17)\n "stable_mosaic_gpu.func.call"(%49, %55, %56, %57, %60, %64, %66, %69) {callee = @mosaic_gpu_init_tma_desc} : (!llvm.ptr, !llvm.ptr, i64, i64, !llvm.ptr, !llvm.ptr, i64, !llvm.ptr) -> () loc(#loc17)\n %71 = "stable_mosaic_gpu.llvm.load"(%26) {ordering = 0 : i64} : (!llvm.ptr) -> !llvm.array<256 x i8> loc(#loc17)\n %72 = "stable_mosaic_gpu.arith.constant"() {value = 2 : index} : () -> index loc(#loc17)\n %73 = "stable_mosaic_gpu.arith.constant"() {value = 1 : index} : () -> index loc(#loc17)\n %74 = "stable_mosaic_gpu.arith.constant"() {value = 1 : index} : () -> index loc(#loc17)\n %75 = "stable_mosaic_gpu.arith.constant"() {value = 128 : index} : () -> index loc(#loc17)\n %76 = "stable_mosaic_gpu.arith.constant"() {value = 1 : index} : () -> index loc(#loc17)\n %77 = "stable_mosaic_gpu.arith.constant"() {value = 1 : index} : () -> index loc(#loc17)\n %78 = "stable_mosaic_gpu.arith.constant"() {value = 2056 : i32} : () -> i32 loc(#loc17)\n %79 = "stable_mosaic_gpu.gpu.launch"(%0, %72, %73, %74, %75, %76, %77, %78) ({\n ^bb0(%arg2: index loc("-":94:40), %arg3: index loc("-":94:47), %arg4: index loc("-":94:54), %arg5: index loc("-":94:116), %arg6: index loc("-":94:123), %arg7: index loc("-":94:130), %arg8: index loc("-":94:65), %arg9: index loc("-":94:78), %arg10: index loc("-":94:91), %arg11: index loc("-":94:141), %arg12: index loc("-":94:157), %arg13: index loc("-":94:174)):\n %80 = "stable_mosaic_gpu.gpu.dynamic_shared_memory"() : () -> memref> loc(#loc17)\n %81 = "stable_mosaic_gpu.builtin.unrealized_conversion_cast"(%71) : (!llvm.array<256 x i8>) -> !llvm.ptr loc(#loc17)\n %82 = "stable_mosaic_gpu.llvm.getelementptr"(%81) {elem_type = i8, rawConstantIndices = array} : (!llvm.ptr) -> !llvm.ptr loc(#loc18)\n %83 = "stable_mosaic_gpu.llvm.getelementptr"(%81) {elem_type = i8, rawConstantIndices = array} : (!llvm.ptr) -> !llvm.ptr loc(#loc19)\n %84 = "stable_mosaic_gpu.arith.constant"() {value = 0 : index} : () -> index loc(#loc17)\n %85 = "stable_mosaic_gpu.memref.view"(%80, %84) : (memref>, index) -> memref<2048xi8, #gpu.address_space> loc(#loc17)\n %86 = "stable_mosaic_gpu.builtin.unrealized_conversion_cast"(%80) : (memref>) -> !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)> loc(#loc17)\n %87 = 
"stable_mosaic_gpu.llvm.extractvalue"(%86) {position = array} : (!llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>) -> !llvm.ptr<3> loc(#loc17)\n %88 = "stable_mosaic_gpu.llvm.extractvalue"(%86) {position = array} : (!llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>) -> i64 loc(#loc17)\n %89 = "stable_mosaic_gpu.arith.constant"() {value = 1 : i64} : () -> i64 loc(#loc17)\n %90 = "stable_mosaic_gpu.llvm.mul"(%88, %89) : (i64, i64) -> i64 loc(#loc17)\n %91 = "stable_mosaic_gpu.llvm.ptrtoint"(%87) : (!llvm.ptr<3>) -> i64 loc(#loc17)\n %92 = "stable_mosaic_gpu.llvm.add"(%91, %90) : (i64, i64) -> i64 loc(#loc17)\n %93 = "stable_mosaic_gpu.llvm.inttoptr"(%92) : (i64) -> !llvm.ptr<3> loc(#loc17)\n %94 = "stable_mosaic_gpu.llvm.getelementptr"(%93) {elem_type = i8, rawConstantIndices = array} : (!llvm.ptr<3>) -> !llvm.ptr<3> loc(#loc17)\n %95 = "stable_mosaic_gpu.memref.alloca"() {operandSegmentSizes = array} : () -> memref loc(#loc17)\n %96 = "stable_mosaic_gpu.arith.constant"() {value = 0 : i32} : () -> i32 loc(#loc17)\n "stable_mosaic_gpu.memref.store"(%96, %95) : (i32, memref) -> () loc(#loc17)\n %97 = "stable_mosaic_gpu.nvvm.elect.sync"() : () -> i1 loc(#loc17)\n %98 = "stable_mosaic_gpu.gpu.thread_id"() {dimension = #gpu} : () -> index loc(#loc17)\n %99 = "stable_mosaic_gpu.arith.index_cast"(%98) : (index) -> i32 loc(#loc17)\n %100 = "stable_mosaic_gpu.gpu.block_dim"() {dimension = #gpu} : () -> index loc(#loc17)\n %101 = "stable_mosaic_gpu.arith.index_cast"(%100) : (index) -> i32 loc(#loc17)\n %102 = "stable_mosaic_gpu.gpu.thread_id"() {dimension = #gpu} : () -> index loc(#loc17)\n %103 = "stable_mosaic_gpu.arith.index_cast"(%102) : (index) -> i32 loc(#loc17)\n %104 = "stable_mosaic_gpu.arith.muli"(%103, %101) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc17)\n %105 = "stable_mosaic_gpu.arith.addi"(%99, %104) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc17)\n %106 = "stable_mosaic_gpu.gpu.block_dim"() {dimension = #gpu} : () -> index loc(#loc17)\n %107 = "stable_mosaic_gpu.arith.index_cast"(%106) : (index) -> i32 loc(#loc17)\n %108 = "stable_mosaic_gpu.arith.muli"(%101, %107) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc17)\n %109 = "stable_mosaic_gpu.gpu.thread_id"() {dimension = #gpu} : () -> index loc(#loc17)\n %110 = "stable_mosaic_gpu.arith.index_cast"(%109) : (index) -> i32 loc(#loc17)\n %111 = "stable_mosaic_gpu.arith.muli"(%110, %108) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc17)\n %112 = "stable_mosaic_gpu.arith.addi"(%105, %111) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc17)\n %113 = "stable_mosaic_gpu.gpu.block_dim"() {dimension = #gpu} : () -> index loc(#loc17)\n %114 = "stable_mosaic_gpu.arith.index_cast"(%113) : (index) -> i32 loc(#loc17)\n %115 = "stable_mosaic_gpu.arith.muli"(%108, %114) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc17)\n %116 = "stable_mosaic_gpu.arith.constant"() {value = 5 : i32} : () -> i32 loc(#loc17)\n %117 = "stable_mosaic_gpu.arith.shrui"(%112, %116) : (i32, i32) -> i32 loc(#loc17)\n %118 = "stable_mosaic_gpu.arith.constant"() {value = -1 : i32} : () -> i32 loc(#loc17)\n %119 = "stable_mosaic_gpu.arith.constant"() {value = 0 : i32} : () -> i32 loc(#loc17)\n %120 = "stable_mosaic_gpu.arith.constant"() {value = 31 : i32} : () -> i32 loc(#loc17)\n %121 = "stable_mosaic_gpu.nvvm.shfl.sync"(%118, %117, %119, %120) {kind = #nvvm} : (i32, i32, i32, i32) -> i32 loc(#loc17)\n %122 = "stable_mosaic_gpu.arith.constant"() 
{value = 0 : i32} : () -> i32 loc(#loc17)\n %123 = "stable_mosaic_gpu.arith.cmpi"(%121, %122) {predicate = 0 : i64} : (i32, i32) -> i1 loc(#loc17)\n %124 = "stable_mosaic_gpu.arith.andi"(%123, %97) : (i1, i1) -> i1 loc(#loc17)\n "stable_mosaic_gpu.scf.if"(%124) ({\n %332 = "stable_mosaic_gpu.llvm.getelementptr"(%94) {elem_type = i64, rawConstantIndices = array} : (!llvm.ptr<3>) -> !llvm.ptr<3> loc(#loc17)\n %333 = "stable_mosaic_gpu.arith.constant"() {value = 128 : i32} : () -> i32 loc(#loc17)\n "stable_mosaic_gpu.nvvm.mbarrier.init.shared"(%332, %333) : (!llvm.ptr<3>, i32) -> () loc(#loc17)\n "stable_mosaic_gpu.scf.yield"() : () -> () loc(#loc13)\n }, {\n }) : (i1) -> () loc(#loc17)\n %125 = "stable_mosaic_gpu.arith.constant"() {value = 0 : i32} : () -> i32 loc(#loc17)\n "stable_mosaic_gpu.nvvm.fence.mbarrier.init"() : () -> () loc(#loc17)\n "stable_mosaic_gpu.gpu.barrier"() : () -> () loc(#loc17)\n %126 = "stable_mosaic_gpu.nvvm.elect.sync"() : () -> i1 loc(#loc17)\n %127 = "stable_mosaic_gpu.gpu.thread_id"() {dimension = #gpu} : () -> index loc(#loc17)\n %128 = "stable_mosaic_gpu.arith.index_cast"(%127) : (index) -> i32 loc(#loc17)\n %129 = "stable_mosaic_gpu.gpu.block_dim"() {dimension = #gpu} : () -> index loc(#loc17)\n %130 = "stable_mosaic_gpu.arith.index_cast"(%129) : (index) -> i32 loc(#loc17)\n %131 = "stable_mosaic_gpu.gpu.thread_id"() {dimension = #gpu} : () -> index loc(#loc17)\n %132 = "stable_mosaic_gpu.arith.index_cast"(%131) : (index) -> i32 loc(#loc17)\n %133 = "stable_mosaic_gpu.arith.muli"(%132, %130) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc17)\n %134 = "stable_mosaic_gpu.arith.addi"(%128, %133) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc17)\n %135 = "stable_mosaic_gpu.gpu.block_dim"() {dimension = #gpu} : () -> index loc(#loc17)\n %136 = "stable_mosaic_gpu.arith.index_cast"(%135) : (index) -> i32 loc(#loc17)\n %137 = "stable_mosaic_gpu.arith.muli"(%130, %136) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc17)\n %138 = "stable_mosaic_gpu.gpu.thread_id"() {dimension = #gpu} : () -> index loc(#loc17)\n %139 = "stable_mosaic_gpu.arith.index_cast"(%138) : (index) -> i32 loc(#loc17)\n %140 = "stable_mosaic_gpu.arith.muli"(%139, %137) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc17)\n %141 = "stable_mosaic_gpu.arith.addi"(%134, %140) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc17)\n %142 = "stable_mosaic_gpu.gpu.block_dim"() {dimension = #gpu} : () -> index loc(#loc17)\n %143 = "stable_mosaic_gpu.arith.index_cast"(%142) : (index) -> i32 loc(#loc17)\n %144 = "stable_mosaic_gpu.arith.muli"(%137, %143) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc17)\n %145 = "stable_mosaic_gpu.arith.constant"() {value = 5 : i32} : () -> i32 loc(#loc17)\n %146 = "stable_mosaic_gpu.arith.shrui"(%141, %145) : (i32, i32) -> i32 loc(#loc17)\n %147 = "stable_mosaic_gpu.arith.constant"() {value = -1 : i32} : () -> i32 loc(#loc17)\n %148 = "stable_mosaic_gpu.arith.constant"() {value = 0 : i32} : () -> i32 loc(#loc17)\n %149 = "stable_mosaic_gpu.arith.constant"() {value = 31 : i32} : () -> i32 loc(#loc17)\n %150 = "stable_mosaic_gpu.nvvm.shfl.sync"(%147, %146, %148, %149) {kind = #nvvm} : (i32, i32, i32, i32) -> i32 loc(#loc17)\n %151 = "stable_mosaic_gpu.arith.constant"() {value = 4 : i32} : () -> i32 loc(#loc17)\n %152 = "stable_mosaic_gpu.arith.remui"(%150, %151) : (i32, i32) -> i32 loc(#loc17)\n %153 = "stable_mosaic_gpu.arith.constant"() {value = 0 : i32} : () -> i32 loc(#loc17)\n %154 = 
"stable_mosaic_gpu.arith.cmpi"(%152, %153) {predicate = 0 : i64} : (i32, i32) -> i1 loc(#loc17)\n %155 = "stable_mosaic_gpu.arith.andi"(%154, %126) : (i1, i1) -> i1 loc(#loc17)\n %156 = "stable_mosaic_gpu.nvvm.elect.sync"() : () -> i1 loc(#loc17)\n %157 = "stable_mosaic_gpu.gpu.block_id"() {dimension = #gpu} : () -> index loc(#loc17)\n %158 = "stable_mosaic_gpu.arith.index_cast"(%157) : (index) -> i32 loc(#loc17)\n %159 = "stable_mosaic_gpu.gpu.dynamic_shared_memory"() : () -> memref> loc(#loc20)\n %160 = "stable_mosaic_gpu.arith.constant"() {value = 0 : index} : () -> index loc(#loc20)\n %161 = "stable_mosaic_gpu.memref.view"(%159, %160) : (memref>, index) -> memref<1x256xf32, #gpu.address_space> loc(#loc20)\n %162 = "stable_mosaic_gpu.gpu.dynamic_shared_memory"() : () -> memref> loc(#loc20)\n %163 = "stable_mosaic_gpu.arith.constant"() {value = 1024 : index} : () -> index loc(#loc20)\n %164 = "stable_mosaic_gpu.memref.view"(%162, %163) : (memref>, index) -> memref<1x256xf32, #gpu.address_space> loc(#loc20)\n %165 = "stable_mosaic_gpu.arith.constant"() {value = 0 : index} : () -> index loc(#loc19)\n %166 = "stable_mosaic_gpu.memref.subview"(%161, %165) {operandSegmentSizes = array, static_offsets = array, static_sizes = array, static_strides = array} : (memref<1x256xf32, #gpu.address_space>, index) -> memref<256xf32, strided<[1], offset: ?>, #gpu.address_space> loc(#loc19)\n %167 = "stable_mosaic_gpu.arith.constant"() {value = 0 : index} : () -> index loc(#loc19)\n %168 = "stable_mosaic_gpu.arith.index_castui"(%167) : (index) -> i32 loc(#loc19)\n %169 = "stable_mosaic_gpu.arith.addi"(%125, %168) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc19)\n %170 = "stable_mosaic_gpu.arith.constant"() {value = 8 : i32} : () -> i32 loc(#loc19)\n %171 = "stable_mosaic_gpu.llvm.getelementptr"(%94, %169) {elem_type = i64, rawConstantIndices = array} : (!llvm.ptr<3>, i32) -> !llvm.ptr<3> loc(#loc19)\n "stable_mosaic_gpu.nvvm.mbarrier.arrive.expect_tx.shared"(%171, %170) : (!llvm.ptr<3>, i32) -> () loc(#loc19)\n %172 = "stable_mosaic_gpu.arith.constant"() {value = 0 : index} : () -> index loc(#loc19)\n %173 = "stable_mosaic_gpu.arith.index_cast"(%172) : (index) -> i32 loc(#loc19)\n %174 = "stable_mosaic_gpu.builtin.unrealized_conversion_cast"(%166) : (memref<256xf32, strided<[1], offset: ?>, #gpu.address_space>) -> !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)> loc(#loc19)\n %175 = "stable_mosaic_gpu.llvm.extractvalue"(%174) {position = array} : (!llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>) -> !llvm.ptr<3> loc(#loc19)\n %176 = "stable_mosaic_gpu.llvm.extractvalue"(%174) {position = array} : (!llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>) -> i64 loc(#loc19)\n %177 = "stable_mosaic_gpu.arith.constant"() {value = 4 : i64} : () -> i64 loc(#loc19)\n %178 = "stable_mosaic_gpu.llvm.mul"(%176, %177) : (i64, i64) -> i64 loc(#loc19)\n %179 = "stable_mosaic_gpu.llvm.ptrtoint"(%175) : (!llvm.ptr<3>) -> i64 loc(#loc19)\n %180 = "stable_mosaic_gpu.llvm.add"(%179, %178) : (i64, i64) -> i64 loc(#loc19)\n %181 = "stable_mosaic_gpu.llvm.inttoptr"(%180) : (i64) -> !llvm.ptr<3> loc(#loc19)\n %182 = "stable_mosaic_gpu.arith.constant"() {value = 1024 : i32} : () -> i32 loc(#loc19)\n %183 = "stable_mosaic_gpu.llvm.getelementptr"(%94, %169) {elem_type = i64, rawConstantIndices = array} : (!llvm.ptr<3>, i32) -> !llvm.ptr<3> loc(#loc19)\n "stable_mosaic_gpu.nvvm.cp.async.bulk.tensor.shared.cluster.global"(%181, %83, %173, %183, %155) 
{operandSegmentSizes = array} : (!llvm.ptr<3>, !llvm.ptr, i32, !llvm.ptr<3>, i1) -> () loc(#loc19)\n %184 = "stable_mosaic_gpu.arith.constant"() {value = 0 : i32} : () -> i32 loc(#loc21)\n %185 = "stable_mosaic_gpu.arith.constant"() {value = 0 : i32} : () -> i32 loc(#loc21)\n %186 = "stable_mosaic_gpu.arith.addi"(%185, %184) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc21)\n %187 = "stable_mosaic_gpu.arith.constant"() {value = 1 : i32} : () -> i32 loc(#loc22)\n %188 = "stable_mosaic_gpu.arith.remsi"(%186, %187) : (i32, i32) -> i32 loc(#loc22)\n %189 = "stable_mosaic_gpu.arith.index_cast"(%188) : (i32) -> index loc(#loc23)\n %190 = "stable_mosaic_gpu.arith.index_castui"(%189) : (index) -> i32 loc(#loc23)\n %191 = "stable_mosaic_gpu.arith.addi"(%125, %190) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc23)\n %192 = "stable_mosaic_gpu.memref.load"(%95) : (memref) -> i32 loc(#loc23)\n %193 = "stable_mosaic_gpu.arith.constant"() {value = 1 : i32} : () -> i32 loc(#loc23)\n %194 = "stable_mosaic_gpu.arith.shli"(%193, %191) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc23)\n %195 = "stable_mosaic_gpu.arith.andi"(%192, %194) : (i32, i32) -> i32 loc(#loc23)\n %196 = "stable_mosaic_gpu.arith.constant"() {value = 0 : i32} : () -> i32 loc(#loc23)\n %197 = "stable_mosaic_gpu.arith.cmpi"(%195, %196) {predicate = 1 : i64} : (i32, i32) -> i1 loc(#loc23)\n %198 = "stable_mosaic_gpu.arith.xori"(%192, %194) : (i32, i32) -> i32 loc(#loc23)\n "stable_mosaic_gpu.memref.store"(%198, %95) : (i32, memref) -> () loc(#loc23)\n %199 = "stable_mosaic_gpu.arith.constant"() {value = 10000000 : i32} : () -> i32 loc(#loc23)\n %200 = "stable_mosaic_gpu.arith.extui"(%197) : (i1) -> i32 loc(#loc23)\n %201 = "stable_mosaic_gpu.llvm.getelementptr"(%94, %191) {elem_type = i64, rawConstantIndices = array} : (!llvm.ptr<3>, i32) -> !llvm.ptr<3> loc(#loc23)\n "stable_mosaic_gpu.nvvm.mbarrier.try_wait.parity.shared"(%201, %200, %199) : (!llvm.ptr<3>, i32, i32) -> () loc(#loc23)\n %202 = "stable_mosaic_gpu.arith.index_cast"(%188) : (i32) -> index loc(#loc24)\n %203 = "stable_mosaic_gpu.memref.subview"(%161, %202) {operandSegmentSizes = array, static_offsets = array, static_sizes = array, static_strides = array} : (memref<1x256xf32, #gpu.address_space>, index) -> memref<256xf32, strided<[1], offset: ?>, #gpu.address_space> loc(#loc24)\n %204 = "stable_mosaic_gpu.arith.index_cast"(%188) : (i32) -> index loc(#loc24)\n %205 = "stable_mosaic_gpu.memref.subview"(%164, %204) {operandSegmentSizes = array, static_offsets = array, static_sizes = array, static_strides = array} : (memref<1x256xf32, #gpu.address_space>, index) -> memref<256xf32, strided<[1], offset: ?>, #gpu.address_space> loc(#loc24)\n %206 = "stable_mosaic_gpu.gpu.block_id"() {dimension = #gpu} : () -> index loc(#loc24)\n %207 = "stable_mosaic_gpu.arith.index_cast"(%206) : (index) -> i32 loc(#loc24)\n %208 = "stable_mosaic_gpu.memref.subview"(%203) {operandSegmentSizes = array, static_offsets = array, static_sizes = array, static_strides = array} : (memref<256xf32, strided<[1], offset: ?>, #gpu.address_space>) -> memref<256xf32, strided<[1], offset: ?>, #gpu.address_space> loc(#loc25)\n %209 = "stable_mosaic_gpu.memref.collapse_shape"(%208) {reassociation = [[0]]} : (memref<256xf32, strided<[1], offset: ?>, #gpu.address_space>) -> memref<256xf32, strided<[1], offset: ?>, #gpu.address_space> loc(#loc25)\n %210 = "stable_mosaic_gpu.gpu.thread_id"() {dimension = #gpu} : () -> index loc(#loc25)\n %211 = 
"stable_mosaic_gpu.arith.constant"() {value = 128 : index} : () -> index loc(#loc25)\n %212 = "stable_mosaic_gpu.arith.remui"(%210, %211) : (index, index) -> index loc(#loc25)\n %213 = "stable_mosaic_gpu.arith.constant"() {value = 2 : index} : () -> index loc(#loc25)\n %214 = "stable_mosaic_gpu.arith.muli"(%212, %213) {overflowFlags = #arith.overflow} : (index, index) -> index loc(#loc25)\n %215 = "stable_mosaic_gpu.arith.constant"() {value = 0 : index} : () -> index loc(#loc25)\n %216 = "stable_mosaic_gpu.arith.addi"(%214, %215) {overflowFlags = #arith.overflow} : (index, index) -> index loc(#loc25)\n %217 = "stable_mosaic_gpu.vector.load"(%209, %216) : (memref<256xf32, strided<[1], offset: ?>, #gpu.address_space>, index) -> vector<2xf32> loc(#loc25)\n %218 = "stable_mosaic_gpu.arith.constant"() {value = 1.000000e+00 : f32} : () -> f32 loc(#loc26)\n %219 = "stable_mosaic_gpu.vector.splat"(%218) : (f32) -> vector<2xf32> loc(#loc26)\n %220 = "stable_mosaic_gpu.arith.addf"(%217, %219) {fastmath = #arith.fastmath} : (vector<2xf32>, vector<2xf32>) -> vector<2xf32> loc(#loc26)\n %221 = "stable_mosaic_gpu.memref.subview"(%205) {operandSegmentSizes = array, static_offsets = array, static_sizes = array, static_strides = array} : (memref<256xf32, strided<[1], offset: ?>, #gpu.address_space>) -> memref<256xf32, strided<[1], offset: ?>, #gpu.address_space> loc(#loc27)\n %222 = "stable_mosaic_gpu.memref.collapse_shape"(%221) {reassociation = [[0]]} : (memref<256xf32, strided<[1], offset: ?>, #gpu.address_space>) -> memref<256xf32, strided<[1], offset: ?>, #gpu.address_space> loc(#loc27)\n %223 = "stable_mosaic_gpu.gpu.thread_id"() {dimension = #gpu} : () -> index loc(#loc27)\n %224 = "stable_mosaic_gpu.arith.constant"() {value = 128 : index} : () -> index loc(#loc27)\n %225 = "stable_mosaic_gpu.arith.remui"(%223, %224) : (index, index) -> index loc(#loc27)\n %226 = "stable_mosaic_gpu.arith.constant"() {value = 2 : index} : () -> index loc(#loc27)\n %227 = "stable_mosaic_gpu.arith.muli"(%225, %226) {overflowFlags = #arith.overflow} : (index, index) -> index loc(#loc27)\n %228 = "stable_mosaic_gpu.arith.constant"() {value = 0 : index} : () -> index loc(#loc27)\n %229 = "stable_mosaic_gpu.arith.addi"(%227, %228) {overflowFlags = #arith.overflow} : (index, index) -> index loc(#loc27)\n %230 = "stable_mosaic_gpu.vector.load"(%222, %229) : (memref<256xf32, strided<[1], offset: ?>, #gpu.address_space>, index) -> vector<2xf32> loc(#loc27)\n %231 = "stable_mosaic_gpu.memref.collapse_shape"(%221) {reassociation = [[0]]} : (memref<256xf32, strided<[1], offset: ?>, #gpu.address_space>) -> memref<256xf32, strided<[1], offset: ?>, #gpu.address_space> loc(#loc27)\n %232 = "stable_mosaic_gpu.gpu.thread_id"() {dimension = #gpu} : () -> index loc(#loc27)\n %233 = "stable_mosaic_gpu.arith.constant"() {value = 128 : index} : () -> index loc(#loc27)\n %234 = "stable_mosaic_gpu.arith.remui"(%232, %233) : (index, index) -> index loc(#loc27)\n %235 = "stable_mosaic_gpu.arith.constant"() {value = 2 : index} : () -> index loc(#loc27)\n %236 = "stable_mosaic_gpu.arith.muli"(%234, %235) {overflowFlags = #arith.overflow} : (index, index) -> index loc(#loc27)\n %237 = "stable_mosaic_gpu.arith.constant"() {value = 0 : index} : () -> index loc(#loc27)\n %238 = "stable_mosaic_gpu.arith.addi"(%236, %237) {overflowFlags = #arith.overflow} : (index, index) -> index loc(#loc27)\n "stable_mosaic_gpu.vector.store"(%220, %231, %238) : (vector<2xf32>, memref<256xf32, strided<[1], offset: ?>, #gpu.address_space>, index) -> () loc(#loc27)\n 
"stable_mosaic_gpu.nvvm.cp.async.bulk.commit.group"() : () -> () loc(#loc28)\n %239 = "stable_mosaic_gpu.arith.constant"() {value = 1 : i32} : () -> i32 loc(#loc29)\n %240 = "stable_mosaic_gpu.arith.addi"(%186, %239) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc29)\n %241 = "stable_mosaic_gpu.arith.constant"() {value = 1 : i32} : () -> i32 loc(#loc22)\n %242 = "stable_mosaic_gpu.arith.remsi"(%240, %241) : (i32, i32) -> i32 loc(#loc22)\n %243 = "stable_mosaic_gpu.arith.constant"() {value = 0 : i32} : () -> i32 loc(#loc30)\n %244 = "stable_mosaic_gpu.arith.cmpi"(%186, %243) {predicate = 9 : i64} : (i32, i32) -> i1 loc(#loc30)\n %245 = "stable_mosaic_gpu.arith.constant"() {value = 1 : i32} : () -> i32 loc(#loc31)\n %246 = "stable_mosaic_gpu.arith.cmpi"(%240, %245) {predicate = 6 : i64} : (i32, i32) -> i1 loc(#loc31)\n %247 = "stable_mosaic_gpu.arith.andi"(%244, %246) : (i1, i1) -> i1 loc(#loc32)\n %248 = "stable_mosaic_gpu.arith.extui"(%247) : (i1) -> i32 loc(#loc33)\n %249 = "stable_mosaic_gpu.arith.index_cast"(%248) : (i32) -> index loc(#loc34)\n "stable_mosaic_gpu.scf.index_switch"(%249) ({\n %313 = "stable_mosaic_gpu.arith.index_cast"(%242) : (i32) -> index loc(#loc19)\n %314 = "stable_mosaic_gpu.memref.subview"(%161, %313) {operandSegmentSizes = array, static_offsets = array, static_sizes = array, static_strides = array} : (memref<1x256xf32, #gpu.address_space>, index) -> memref<256xf32, strided<[1], offset: ?>, #gpu.address_space> loc(#loc19)\n %315 = "stable_mosaic_gpu.arith.index_cast"(%242) : (i32) -> index loc(#loc19)\n %316 = "stable_mosaic_gpu.arith.index_castui"(%315) : (index) -> i32 loc(#loc19)\n %317 = "stable_mosaic_gpu.arith.addi"(%125, %316) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc19)\n %318 = "stable_mosaic_gpu.arith.constant"() {value = 8 : i32} : () -> i32 loc(#loc19)\n %319 = "stable_mosaic_gpu.llvm.getelementptr"(%94, %317) {elem_type = i64, rawConstantIndices = array} : (!llvm.ptr<3>, i32) -> !llvm.ptr<3> loc(#loc19)\n "stable_mosaic_gpu.nvvm.mbarrier.arrive.expect_tx.shared"(%319, %318) : (!llvm.ptr<3>, i32) -> () loc(#loc19)\n %320 = "stable_mosaic_gpu.arith.constant"() {value = 0 : index} : () -> index loc(#loc19)\n %321 = "stable_mosaic_gpu.arith.index_cast"(%320) : (index) -> i32 loc(#loc19)\n %322 = "stable_mosaic_gpu.builtin.unrealized_conversion_cast"(%314) : (memref<256xf32, strided<[1], offset: ?>, #gpu.address_space>) -> !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)> loc(#loc19)\n %323 = "stable_mosaic_gpu.llvm.extractvalue"(%322) {position = array} : (!llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>) -> !llvm.ptr<3> loc(#loc19)\n %324 = "stable_mosaic_gpu.llvm.extractvalue"(%322) {position = array} : (!llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>) -> i64 loc(#loc19)\n %325 = "stable_mosaic_gpu.arith.constant"() {value = 4 : i64} : () -> i64 loc(#loc19)\n %326 = "stable_mosaic_gpu.llvm.mul"(%324, %325) : (i64, i64) -> i64 loc(#loc19)\n %327 = "stable_mosaic_gpu.llvm.ptrtoint"(%323) : (!llvm.ptr<3>) -> i64 loc(#loc19)\n %328 = "stable_mosaic_gpu.llvm.add"(%327, %326) : (i64, i64) -> i64 loc(#loc19)\n %329 = "stable_mosaic_gpu.llvm.inttoptr"(%328) : (i64) -> !llvm.ptr<3> loc(#loc19)\n %330 = "stable_mosaic_gpu.arith.constant"() {value = 1024 : i32} : () -> i32 loc(#loc19)\n %331 = "stable_mosaic_gpu.llvm.getelementptr"(%94, %317) {elem_type = i64, rawConstantIndices = array} : (!llvm.ptr<3>, i32) -> !llvm.ptr<3> loc(#loc19)\n 
"stable_mosaic_gpu.nvvm.cp.async.bulk.tensor.shared.cluster.global"(%329, %83, %321, %331, %155) {operandSegmentSizes = array} : (!llvm.ptr<3>, !llvm.ptr, i32, !llvm.ptr<3>, i1) -> () loc(#loc19)\n "stable_mosaic_gpu.scf.yield"() : () -> () loc(#loc16)\n }, {\n "stable_mosaic_gpu.scf.yield"() : () -> () loc(#loc34)\n }) {cases = array} : (index) -> () loc(#loc34)\n %250 = "stable_mosaic_gpu.arith.constant"() {value = 1 : i32} : () -> i32 loc(#loc21)\n %251 = "stable_mosaic_gpu.arith.addi"(%184, %250) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc21)\n "stable_mosaic_gpu.nvvm.fence.proxy"() {kind = #nvvm.proxy_kind, space = #nvvm.shared_space} : () -> () loc(#loc35)\n %252 = "stable_mosaic_gpu.gpu.thread_id"() {dimension = #gpu} : () -> index loc(#loc35)\n %253 = "stable_mosaic_gpu.arith.index_cast"(%252) : (index) -> i32 loc(#loc35)\n %254 = "stable_mosaic_gpu.gpu.block_dim"() {dimension = #gpu} : () -> index loc(#loc35)\n %255 = "stable_mosaic_gpu.arith.index_cast"(%254) : (index) -> i32 loc(#loc35)\n %256 = "stable_mosaic_gpu.gpu.thread_id"() {dimension = #gpu} : () -> index loc(#loc35)\n %257 = "stable_mosaic_gpu.arith.index_cast"(%256) : (index) -> i32 loc(#loc35)\n %258 = "stable_mosaic_gpu.arith.muli"(%257, %255) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc35)\n %259 = "stable_mosaic_gpu.arith.addi"(%253, %258) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc35)\n %260 = "stable_mosaic_gpu.gpu.block_dim"() {dimension = #gpu} : () -> index loc(#loc35)\n %261 = "stable_mosaic_gpu.arith.index_cast"(%260) : (index) -> i32 loc(#loc35)\n %262 = "stable_mosaic_gpu.arith.muli"(%255, %261) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc35)\n %263 = "stable_mosaic_gpu.gpu.thread_id"() {dimension = #gpu} : () -> index loc(#loc35)\n %264 = "stable_mosaic_gpu.arith.index_cast"(%263) : (index) -> i32 loc(#loc35)\n %265 = "stable_mosaic_gpu.arith.muli"(%264, %262) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc35)\n %266 = "stable_mosaic_gpu.arith.addi"(%259, %265) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc35)\n %267 = "stable_mosaic_gpu.gpu.block_dim"() {dimension = #gpu} : () -> index loc(#loc35)\n %268 = "stable_mosaic_gpu.arith.index_cast"(%267) : (index) -> i32 loc(#loc35)\n %269 = "stable_mosaic_gpu.arith.muli"(%262, %268) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc35)\n %270 = "stable_mosaic_gpu.arith.constant"() {value = 7 : i32} : () -> i32 loc(#loc35)\n %271 = "stable_mosaic_gpu.arith.shrui"(%266, %270) : (i32, i32) -> i32 loc(#loc35)\n %272 = "stable_mosaic_gpu.arith.constant"() {value = 1 : i32} : () -> i32 loc(#loc35)\n %273 = "stable_mosaic_gpu.arith.addi"(%271, %272) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc35)\n %274 = "stable_mosaic_gpu.llvm.inline_asm"(%273) {asm_string = "bar.sync $0, 128;", constraints = "r", has_side_effects} : (i32) -> !llvm.void loc(#loc35)\n %275 = "stable_mosaic_gpu.arith.constant"() {value = 0 : i32} : () -> i32 loc(#loc22)\n %276 = "stable_mosaic_gpu.arith.constant"() {value = 1 : i32} : () -> i32 loc(#loc22)\n %277 = "stable_mosaic_gpu.arith.remsi"(%275, %276) : (i32, i32) -> i32 loc(#loc22)\n %278 = "stable_mosaic_gpu.arith.index_cast"(%277) : (i32) -> index loc(#loc18)\n %279 = "stable_mosaic_gpu.memref.subview"(%164, %278) {operandSegmentSizes = array, static_offsets = array, static_sizes = array, static_strides = array} : (memref<1x256xf32, #gpu.address_space>, index) -> memref<256xf32, strided<[1], offset: ?>, 
#gpu.address_space> loc(#loc18)\n %280 = "stable_mosaic_gpu.arith.constant"() {value = 0 : index} : () -> index loc(#loc18)\n %281 = "stable_mosaic_gpu.arith.index_cast"(%280) : (index) -> i32 loc(#loc18)\n %282 = "stable_mosaic_gpu.builtin.unrealized_conversion_cast"(%279) : (memref<256xf32, strided<[1], offset: ?>, #gpu.address_space>) -> !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)> loc(#loc18)\n %283 = "stable_mosaic_gpu.llvm.extractvalue"(%282) {position = array} : (!llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>) -> !llvm.ptr<3> loc(#loc18)\n %284 = "stable_mosaic_gpu.llvm.extractvalue"(%282) {position = array} : (!llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>) -> i64 loc(#loc18)\n %285 = "stable_mosaic_gpu.arith.constant"() {value = 4 : i64} : () -> i64 loc(#loc18)\n %286 = "stable_mosaic_gpu.llvm.mul"(%284, %285) : (i64, i64) -> i64 loc(#loc18)\n %287 = "stable_mosaic_gpu.llvm.ptrtoint"(%283) : (!llvm.ptr<3>) -> i64 loc(#loc18)\n %288 = "stable_mosaic_gpu.llvm.add"(%287, %286) : (i64, i64) -> i64 loc(#loc18)\n %289 = "stable_mosaic_gpu.llvm.inttoptr"(%288) : (i64) -> !llvm.ptr<3> loc(#loc18)\n "stable_mosaic_gpu.nvvm.cp.async.bulk.tensor.global.shared.cta"(%82, %289, %281, %155) {operandSegmentSizes = array} : (!llvm.ptr, !llvm.ptr<3>, i32, i1) -> () loc(#loc18)\n "stable_mosaic_gpu.nvvm.cp.async.bulk.commit.group"() : () -> () loc(#loc28)\n "stable_mosaic_gpu.nvvm.cp.async.bulk.wait_group"() {group = 0 : i32} : () -> () loc(#loc36)\n %290 = "stable_mosaic_gpu.gpu.thread_id"() {dimension = #gpu} : () -> index loc(#loc36)\n %291 = "stable_mosaic_gpu.arith.index_cast"(%290) : (index) -> i32 loc(#loc36)\n %292 = "stable_mosaic_gpu.gpu.block_dim"() {dimension = #gpu} : () -> index loc(#loc36)\n %293 = "stable_mosaic_gpu.arith.index_cast"(%292) : (index) -> i32 loc(#loc36)\n %294 = "stable_mosaic_gpu.gpu.thread_id"() {dimension = #gpu} : () -> index loc(#loc36)\n %295 = "stable_mosaic_gpu.arith.index_cast"(%294) : (index) -> i32 loc(#loc36)\n %296 = "stable_mosaic_gpu.arith.muli"(%295, %293) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc36)\n %297 = "stable_mosaic_gpu.arith.addi"(%291, %296) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc36)\n %298 = "stable_mosaic_gpu.gpu.block_dim"() {dimension = #gpu} : () -> index loc(#loc36)\n %299 = "stable_mosaic_gpu.arith.index_cast"(%298) : (index) -> i32 loc(#loc36)\n %300 = "stable_mosaic_gpu.arith.muli"(%293, %299) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc36)\n %301 = "stable_mosaic_gpu.gpu.thread_id"() {dimension = #gpu} : () -> index loc(#loc36)\n %302 = "stable_mosaic_gpu.arith.index_cast"(%301) : (index) -> i32 loc(#loc36)\n %303 = "stable_mosaic_gpu.arith.muli"(%302, %300) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc36)\n %304 = "stable_mosaic_gpu.arith.addi"(%297, %303) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc36)\n %305 = "stable_mosaic_gpu.gpu.block_dim"() {dimension = #gpu} : () -> index loc(#loc36)\n %306 = "stable_mosaic_gpu.arith.index_cast"(%305) : (index) -> i32 loc(#loc36)\n %307 = "stable_mosaic_gpu.arith.muli"(%300, %306) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc36)\n %308 = "stable_mosaic_gpu.arith.constant"() {value = 7 : i32} : () -> i32 loc(#loc36)\n %309 = "stable_mosaic_gpu.arith.shrui"(%304, %308) : (i32, i32) -> i32 loc(#loc36)\n %310 = "stable_mosaic_gpu.arith.constant"() {value = 1 : i32} : () -> i32 loc(#loc36)\n %311 = 
"stable_mosaic_gpu.arith.addi"(%309, %310) {overflowFlags = #arith.overflow} : (i32, i32) -> i32 loc(#loc36)\n %312 = "stable_mosaic_gpu.llvm.inline_asm"(%311) {asm_string = "bar.sync $0, 128;", constraints = "r", has_side_effects} : (i32) -> !llvm.void loc(#loc36)\n "stable_mosaic_gpu.gpu.terminator"() : () -> () loc(#loc17)\n }) {operandSegmentSizes = array, workgroup_attributions = 0 : i64} : (!gpu.async.token, index, index, index, index, index, index, i32) -> !gpu.async.token loc(#loc17)\n "stable_mosaic_gpu.func.return"() : () -> () loc(#loc17)\n }) {function_type = (!llvm.ptr, !llvm.ptr) -> (), llvm.emit_c_interface, sym_name = "mosaic_gpu_body"} : () -> () loc(#loc17)\n}) {stable_mosaic_gpu.version = 1 : i64} : () -> () loc(#loc17)\n#loc13 = loc("-":141:7)\n#loc14 = loc("third_party/py/jax/tests/pallas/export_back_compat_pallas_test.py":78:19)\n#loc15 = loc("third_party/py/jax/tests/pallas/export_back_compat_pallas_test.py":78:6)\n#loc16 = loc("-":279:7)\n#loc18 = loc("/copy_smem_to_gmem"(#loc))\n#loc19 = loc("/copy_gmem_to_smem"(#loc))\n#loc20 = loc("/run_scoped"(#loc))\n#loc21 = loc("/scan"(#loc))\n#loc22 = loc("/rem"(#loc))\n#loc23 = loc("/barrier_wait"(#loc))\n#loc24 = loc("/jaxpr_call"(#loc))\n#loc25 = loc("/get"(#loc14))\n#loc26 = loc("/add"(#loc14))\n#loc27 = loc("/swap"(#loc15))\n#loc28 = loc("/commit_group"(#loc))\n#loc29 = loc("/add"(#loc))\n#loc30 = loc("/ge"(#loc))\n#loc31 = loc("/lt"(#loc))\n#loc32 = loc("/and"(#loc))\n#loc33 = loc("/convert_element_type"(#loc))\n#loc34 = loc("/cond"(#loc))\n#loc35 = loc("/commit_smem"(#loc))\n#loc36 = loc("/wait_smem_to_gmem"(#loc))\n\x00mosaic_gpu\x00\x08\'\x07\x05\x1f\x01\x0b!%\'/1\x11357\x1d9\x1f\x1d\x1f', + xla_call_module_version=9, + nr_devices=1, +) # End paste diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/rocm_qr_hipsolver_geqrf.py b/jax/_src/internal_test_util/export_back_compat_test_data/rocm_qr_hipsolver_geqrf.py deleted file mode 100644 index bd5fa628741e..000000000000 --- a/jax/_src/internal_test_util/export_back_compat_test_data/rocm_qr_hipsolver_geqrf.py +++ /dev/null @@ -1,176 +0,0 @@ -# Copyright 2023 The JAX Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import datetime -from numpy import array, float32 - -data_2024_08_05 = {} - -# Pasted from the test output (see export_back_compat_test_util.py module docstring) -data_2024_08_05["unbatched"] = dict( - testdata_version=1, - platform='rocm', - custom_call_targets=['hipsolver_geqrf', 'hipsolver_orgqr'], - serialized_date=datetime.date(2024, 8, 5), - inputs=(), - expected_outputs=(array([[ 0. 
, 0.9128709 , 0.40824834], - [-0.4472136 , 0.3651484 , -0.81649655], - [-0.8944272 , -0.18257423, 0.40824828]], dtype=float32), array([[-6.7082038e+00, -8.0498447e+00, -9.3914852e+00], - [ 0.0000000e+00, 1.0954450e+00, 2.1908898e+00], - [ 0.0000000e+00, 0.0000000e+00, 1.6371473e-09]], dtype=float32)), - mlir_module_text=r""" -#loc2 = loc("/release/jax/tests/export_back_compat_test.py":346:0) -#loc9 = loc("jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) in_layouts=(None,) out_layouts=(None,) resource_env=None donated_invars=(False,) name=triu keep_unused=False inline=False]"(#loc2)) -module @jit__lambda_ attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main() -> (tensor<3x3xf32> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<3x3xf32> {jax.result_info = "[1]", mhlo.layout_mode = "default"}) { - %0 = stablehlo.iota dim = 0 : tensor<9xf32> loc(#loc3) - %1 = stablehlo.reshape %0 : (tensor<9xf32>) -> tensor<3x3xf32> loc(#loc4) - %2:4 = stablehlo.custom_call @hipsolver_geqrf(%1) {api_version = 2 : i32, backend_config = "\00\00\00\00\01\00\00\00\03\00\00\00\03\00\00\00\00\01\00\00", operand_layouts = [dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>, dense<0> : tensor<1xindex>]} : (tensor<3x3xf32>) -> (tensor<3x3xf32>, tensor<3xf32>, tensor, tensor<256xf32>) loc(#loc5) - %c = stablehlo.constant dense<0> : tensor loc(#loc5) - %3 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor) -> tensor loc(#loc5) - %4 = stablehlo.compare EQ, %2#2, %3, SIGNED : (tensor, tensor) -> tensor loc(#loc5) - %5 = stablehlo.broadcast_in_dim %4, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc5) - %cst = stablehlo.constant dense<0x7FC00000> : tensor loc(#loc5) - %6 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> tensor<3x3xf32> loc(#loc5) - %7 = stablehlo.broadcast_in_dim %5, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<3x3xi1> loc(#loc5) - %8 = stablehlo.select %7, %2#0, %6 : tensor<3x3xi1>, tensor<3x3xf32> loc(#loc5) - %9 = stablehlo.broadcast_in_dim %4, dims = [] : (tensor) -> tensor<1xi1> loc(#loc5) - %cst_0 = stablehlo.constant dense<0x7FC00000> : tensor loc(#loc5) - %10 = stablehlo.broadcast_in_dim %cst_0, dims = [] : (tensor) -> tensor<3xf32> loc(#loc5) - %11 = stablehlo.broadcast_in_dim %9, dims = [0] : (tensor<1xi1>) -> tensor<3xi1> loc(#loc5) - %12 = stablehlo.select %11, %2#1, %10 : tensor<3xi1>, tensor<3xf32> loc(#loc5) - %cst_1 = stablehlo.constant dense<0.000000e+00> : tensor loc(#loc6) - %13 = stablehlo.pad %8, %cst_1, low = [0, 0], high = [0, 0], interior = [0, 0] : (tensor<3x3xf32>, tensor) -> tensor<3x3xf32> loc(#loc7) - %14:3 = stablehlo.custom_call @hipsolver_orgqr(%13, %12) {api_version = 2 : i32, backend_config = "\00\00\00\00\01\00\00\00\03\00\00\00\03\00\00\00\03\00\00\00\80\00\00\00", operand_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>, dense<0> : tensor<1xindex>]} : (tensor<3x3xf32>, tensor<3xf32>) -> (tensor<3x3xf32>, tensor, tensor<128xf32>) loc(#loc8) - %c_2 = stablehlo.constant dense<0> : tensor loc(#loc8) - %15 = stablehlo.broadcast_in_dim %c_2, dims = [] : (tensor) -> tensor loc(#loc8) - %16 = stablehlo.compare EQ, %14#1, %15, SIGNED 
: (tensor, tensor) -> tensor loc(#loc8) - %17 = stablehlo.broadcast_in_dim %16, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc8) - %cst_3 = stablehlo.constant dense<0x7FC00000> : tensor loc(#loc8) - %18 = stablehlo.broadcast_in_dim %cst_3, dims = [] : (tensor) -> tensor<3x3xf32> loc(#loc8) - %19 = stablehlo.broadcast_in_dim %17, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<3x3xi1> loc(#loc8) - %20 = stablehlo.select %19, %14#0, %18 : tensor<3x3xi1>, tensor<3x3xf32> loc(#loc8) - %21 = call @triu(%8) : (tensor<3x3xf32>) -> tensor<3x3xf32> loc(#loc9) - return %20, %21 : tensor<3x3xf32>, tensor<3x3xf32> loc(#loc) - } loc(#loc) - func.func private @triu(%arg0: tensor<3x3xf32> {mhlo.layout_mode = "default"} loc("jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) in_layouts=(None,) out_layouts=(None,) resource_env=None donated_invars=(False,) name=triu keep_unused=False inline=False]"(#loc2))) -> (tensor<3x3xf32> {mhlo.layout_mode = "default"}) { - %0 = stablehlo.iota dim = 0 : tensor<3x3xi32> loc(#loc10) - %c = stablehlo.constant dense<-1> : tensor loc(#loc9) - %1 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor) -> tensor<3x3xi32> loc(#loc11) - %2 = stablehlo.add %0, %1 : tensor<3x3xi32> loc(#loc11) - %3 = stablehlo.iota dim = 1 : tensor<3x3xi32> loc(#loc12) - %4 = stablehlo.compare GE, %2, %3, SIGNED : (tensor<3x3xi32>, tensor<3x3xi32>) -> tensor<3x3xi1> loc(#loc13) - %cst = stablehlo.constant dense<0.000000e+00> : tensor loc(#loc9) - %5 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> tensor<3x3xf32> loc(#loc14) - %6 = stablehlo.select %4, %5, %arg0 : tensor<3x3xi1>, tensor<3x3xf32> loc(#loc15) - return %6 : tensor<3x3xf32> loc(#loc9) - } loc(#loc9) -} loc(#loc) -#loc = loc(unknown) -#loc1 = loc("/release/jax/tests/export_back_compat_test.py":345:0) -#loc3 = loc("jit()/jit(main)/iota[dtype=float32 shape=(9,) dimension=0]"(#loc1)) -#loc4 = loc("jit()/jit(main)/reshape[new_sizes=(3, 3) dimensions=None]"(#loc1)) -#loc5 = loc("jit()/jit(main)/geqrf"(#loc2)) -#loc6 = loc("jit()/jit(main)/qr[full_matrices=True]"(#loc2)) -#loc7 = loc("jit()/jit(main)/pad[padding_config=((0, 0, 0), (0, 0, 0))]"(#loc2)) -#loc8 = loc("jit()/jit(main)/householder_product"(#loc2)) -#loc10 = loc("jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=0]"(#loc2)) -#loc11 = loc("jit()/jit(main)/jit(triu)/add"(#loc2)) -#loc12 = loc("jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=1]"(#loc2)) -#loc13 = loc("jit()/jit(main)/jit(triu)/ge"(#loc2)) -#loc14 = loc("jit()/jit(main)/jit(triu)/broadcast_in_dim[shape=(3, 3) broadcast_dimensions=()]"(#loc2)) -#loc15 = loc("jit()/jit(main)/jit(triu)/select_n"(#loc2)) -""", - 
mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x01)\x05\x01\x03\x01\x03\x05\x03\x19\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x1d\x03~\x02\xf39\x01\x99\x0f\x17\x13\x0f\x0f\x0b\x0b\x07\x0b\x13\x0f\x0b\x0b\x0b\x0b\x0b\x13\x0b\x0f\x0b\x0b\x13\x17\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x1b\x13+\x0b\x0f\x0b\x0b\x0b33\x0b\x0f\x0b\x13\x0b\x13\x0f\x0b\x1b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0bK\x0b\x13\x0f\x0b#\x0b\x0b\x0b\x0f\x0bK\x0b\x13\x0b\x03[O/\x0b\x0b\x0b/\x0b\x0f\x0b\x0b\x0b\x0b\x0f\x0f\x0b\x13\x1b\x0b\x1b\x0b\x0b\x0b\x13\x0b\x0b\x0f\x1f\x0f\x0f\x0b\x1f\x0b\x0b\x0f\x17\x1b\x1f\x0b\x1fO/\x0b\x0b\x13\x17\x01\x05\x0b\x0f\x035\x17\x0f\x0f\x07\x07\x07\x17\x17\x13\x07\x07\x0f\x17\x13\x17\x17\x13\x13\x17\x13\x13\x13\x13\x13\x13\x17\x02\xde\x08\x1d}\x03\x17\x1fj\x05\x01\x03\x03\x11\xcf\x1d\x93\x03\x1dU\x03\x05\x1f\x05!\x1f\x05#\x03\x03\x0b\xe5\x11\x03\x05\x05%\x05'\x05)\x05+\x05-\x03\x03#\xcb\x05/\x1d]\x03\x051\x053\x03\x03\x0b\xd5\x17\x1ff\x05\x01\x055\x057\x059\x05;\x05=\x05?\x05A\x05C\x03\x03\x0b\xe1\x03\x05'\xab)\xe3\x03\x03\x11\xe7\x03\tGIK\x15M\x15\rO\x05E\x11\x01\x00\x05G\x05I\x05K\x03\x0b\x17\x9d\x19\xb5\x1b\xb7\r\xc1\x1d\xc3\x03\x0b\x17\xa7\x19\xc7\x1b\xa7\r\xa9\x1d\xc9\x05M\x1dY\x03\x05O\x03\x03\x0b\xcd\x05Q\x03\x03#\xd1\x1dc\x03\x05S\x03\x05'\xab)\xd3\x1di\x03\x05U\x1dm\x03\x05W\x1dq\x03\x05Y\x1du-\x05[\x1dy-\x05]\x03\x11/\xad1\xd73\xd95\x9d7\xaf9\xdb;\xb1=\xdf\x05_\x03\x03\x11\xe9\x1d\x83\x03\x05a\x03\x07\x87\xa3\x89\xa3\x8b\xa3\x05c\x05e\x05g\x1d\x8f\x03\x05i\x03\x11/\xad1\xeb3\xed5\x9d7\xaf9\xef;\xb1=\xf1\x05k\x03\x03\x97\xa9\x05m\x1f+!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f-\x11\x00\x00\x00\x00\x00\x00\x00\x00\x03\x01\x1do\x1dq\x1f\x1f\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1ds\x03\x03\xc5\x1du\t\x07\x0b\x05\x05\x01\x03\x03\xdd\x1f/\x01#!\x03\x05\xb9\xbd\r\x05\xa5\xbb\x9f\xa1\x1dw\r\x05\xa5\xbf\x9f\xa1\x1dy\x1d{\x1d}\r\x03\x9f\xa1##\x1d\x7f\x13\r\x01\x1f\x07\t\xff\xff\xff\xff\x1f%\x01\x13\r\x05\x07\x05\x1f\t\t\x00\x00\x00\x00\x1d\x81\x1d\x83\x03\x03\x99\x15\x03\x01\x01\x01\x03\t\x99\x9b\xb3\x9b\x1f\x07\t\x00\x00\x00\x00\x07\x01\x1f\t\t\x00\x00\xc0\x7f\x1f\x1f!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f5\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1d\x85\x1d\x87\x03\x05\x99\x9b\x03\x07\x99\xb3\x9b\x01\t\x01\x02\x02)\x05\r\r\x0b)\x01\x19)\x01\x0b\t\x1d\x01)\x05\r\r\x19)\x05\r\r\x0f)\x03\r\x0b\x13\x1b)\x01\x0f)\x05\x05\x05\x0f)\x03\t\r\x11\x01\x05\x05\x05\x11\x03\x05\x03\x05)\x03\x01\r)\x03%\x0b)\x03\x02\x08\x0b)\x03\t\x17)\x03\x05\x17)\x03\x01\x17)\x03\x05\x0f)\x03\r\x0f)\x03\x05\r)\x03\x02\x04\x0b\x04\x1a\x05\x05\x01\x11\x0fE\x07\x03\x01\t\r\x11\x0fQ\x07\x03Cu\t\x03s!\x03'\x15\x06w\x03\x05\x03\x01\x11\x07\x01{\t\x05\x15\x07)\x03\x03\x05\x03\x01?\x03\x07\x03\x07\x01\x05\x03\x07\x03\r\x0b\x07\x01A\x03\x1b\x05\t\x0f\x03\x07\x01\x05\x03\x1d\x03\x11\x05\x03\x01\x13\x03\t\x03\x07\x01\x05\x03\x05\x03\x15\x03\x07\x01C\x03\x13\x03\x13\x07\x06\x01\x03\x05\x07\x19\x05\x17\x03\x07\x01\x05\x031\x03\x11\x05\x03\x01\x13\x03\t\x03\x07\x01\x05\x03\x15\x03\x1f\x03\x07\x01\x7f\x033\x03\x1d\x07\x06\x01\x03\x15\x07#\x07!\x05\x03\x81+\x03\t\x17\x07\x8d\x85\x03\x05\x05\x1b'\x11\x07\x07\x91\x07\x05\x077\x05)%\x05\x03\x07?\x03\x07\x03\x07\x07\x05\x03\x07\x031\x0b\x07\x07A\x03\x1b\x05-3\x03\x07\x07\x05\x03\x1d\x035\x05\x03\x07\x13\x03\t\x03\x07\x07\x05\x03\x05\x039\x03\x07\x07C\x03\x13\x037\x07\x06\x07\x03\x05\x07=+;\x19\x07\t\x95\x03\x05\x03\x1b\x0f\x04\x0f\x05?A\r\x11\tS\x07\x03\x15+\x03\x05\t\t\x03W!\x03\x11\x05\x03\t[\x03\x07\x03\x07%\x05\x03\x11\x03\x05\x
13\x06%\x03\x11\x05\x03\x07\t\x03a_\x03\x11\x0b\x07ge\x03\x13\x05\t\x0b\x05\x03\t+\x03\t\x03\x07k\x05\x03\x05\x03\x0f\x07\x06o\x03\x05\x07\r\x11\x01\x0f\x04\t\x03\x13\x06\x03\x01\x05\x01\x00\xea\x1a\x89!3!+\x11\x0f\x0b\t\t\x0b!\x11#\x0fY\x87##%_=\x85\x87W\xb3K\x9bM\x9bn\x03\x1b%)9\x1f/!!)#\x1f\x19+\x1b\x1f]\x1f\x15\x1d\x15+\x13\r\x11\x0f\x17\x0f\x1f\x15\x11\x17\x11\x15\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00select_v1\x00iota_v1\x00compare_v1\x00func_v1\x00return_v1\x00custom_call_v1\x00add_v1\x00reshape_v1\x00pad_v1\x00call_v1\x00value\x00sym_name\x00broadcast_dimensions\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00/release/jax/tests/export_back_compat_test.py\x00iota_dimension\x00compare_type\x00comparison_direction\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit__lambda_\x00jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) in_layouts=(None,) out_layouts=(None,) resource_env=None donated_invars=(False,) name=triu keep_unused=False inline=False]\x00jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=0]\x00jit()/jit(main)/jit(triu)/add\x00jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=1]\x00jit()/jit(main)/jit(triu)/ge\x00jit()/jit(main)/jit(triu)/broadcast_in_dim[shape=(3, 3) broadcast_dimensions=()]\x00jit()/jit(main)/jit(triu)/select_n\x00jit()/jit(main)/iota[dtype=float32 shape=(9,) dimension=0]\x00jit()/jit(main)/reshape[new_sizes=(3, 3) dimensions=None]\x00jit()/jit(main)/geqrf\x00jit()/jit(main)/qr[full_matrices=True]\x00edge_padding_high\x00edge_padding_low\x00interior_padding\x00jit()/jit(main)/pad[padding_config=((0, 0, 0), (0, 0, 0))]\x00jit()/jit(main)/householder_product\x00callee\x00mhlo.layout_mode\x00default\x00jax.result_info\x00triu\x00[0]\x00[1]\x00main\x00public\x00private\x00\x00\x00\x00\x00\x01\x00\x00\x00\x03\x00\x00\x00\x03\x00\x00\x00\x00\x01\x00\x00\x00hipsolver_geqrf\x00\x00\x00\x00\x00\x01\x00\x00\x00\x03\x00\x00\x00\x03\x00\x00\x00\x03\x00\x00\x00\x80\x00\x00\x00\x00hipsolver_orgqr\x00", - xla_call_module_version=9, - nr_devices=1, -) # End paste - - -# Pasted from the test output (see export_back_compat_test_util.py module docstring) -data_2024_08_05["batched"] = dict( - testdata_version=1, - platform='rocm', - custom_call_targets=['hipblas_geqrf_batched', 'hipsolver_orgqr'], - serialized_date=datetime.date(2024, 8, 5), - inputs=(), - expected_outputs=(array([[[ 0. 
, 0.9128709 , 0.40824834], - [-0.4472136 , 0.3651484 , -0.81649655], - [-0.8944272 , -0.18257423, 0.40824828]], - - [[-0.42426407, 0.8082888 , 0.4082513 ], - [-0.5656854 , 0.11547317, -0.81649613], - [-0.7071068 , -0.5773518 , 0.40824607]]], dtype=float32), array([[[-6.7082038e+00, -8.0498447e+00, -9.3914852e+00], - [ 0.0000000e+00, 1.0954450e+00, 2.1908898e+00], - [ 0.0000000e+00, 0.0000000e+00, 1.6371473e-09]], - - [[-2.1213203e+01, -2.2910259e+01, -2.4607313e+01], - [ 0.0000000e+00, 3.4641036e-01, 6.9281983e-01], - [ 0.0000000e+00, 0.0000000e+00, 8.3555670e-07]]], dtype=float32)), - mlir_module_text=r""" -#loc2 = loc("/release/jax/tests/export_back_compat_test.py":346:0) -#loc9 = loc("jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) in_layouts=(None,) out_layouts=(None,) resource_env=None donated_invars=(False,) name=triu keep_unused=False inline=False]"(#loc2)) -module @jit__lambda_ attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main() -> (tensor<2x3x3xf32> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<2x3x3xf32> {jax.result_info = "[1]", mhlo.layout_mode = "default"}) { - %0 = stablehlo.iota dim = 0 : tensor<18xf32> loc(#loc3) - %1 = stablehlo.reshape %0 : (tensor<18xf32>) -> tensor<2x3x3xf32> loc(#loc4) - %2:4 = stablehlo.custom_call @hipblas_geqrf_batched(%1) {api_version = 2 : i32, backend_config = "\00\00\00\00\02\00\00\00\03\00\00\00\03\00\00\00", operand_layouts = [dense<[1, 2, 0]> : tensor<3xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 0]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>]} : (tensor<2x3x3xf32>) -> (tensor<2x3x3xf32>, tensor<2x3xf32>, tensor<16xi8>, tensor<16xi8>) loc(#loc5) - %cst = stablehlo.constant dense<0.000000e+00> : tensor loc(#loc6) - %3 = stablehlo.pad %2#0, %cst, low = [0, 0, 0], high = [0, 0, 0], interior = [0, 0, 0] : (tensor<2x3x3xf32>, tensor) -> tensor<2x3x3xf32> loc(#loc7) - %4:3 = stablehlo.custom_call @hipsolver_orgqr(%3, %2#1) {api_version = 2 : i32, backend_config = "\00\00\00\00\02\00\00\00\03\00\00\00\03\00\00\00\03\00\00\00\80\00\00\00", operand_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 0]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>]} : (tensor<2x3x3xf32>, tensor<2x3xf32>) -> (tensor<2x3x3xf32>, tensor<2xi32>, tensor<128xf32>) loc(#loc8) - %c = stablehlo.constant dense<0> : tensor loc(#loc8) - %5 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor) -> tensor<2xi32> loc(#loc8) - %6 = stablehlo.compare EQ, %4#1, %5, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc8) - %7 = stablehlo.broadcast_in_dim %6, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> loc(#loc8) - %cst_0 = stablehlo.constant dense<0x7FC00000> : tensor loc(#loc8) - %8 = stablehlo.broadcast_in_dim %cst_0, dims = [] : (tensor) -> tensor<2x3x3xf32> loc(#loc8) - %9 = stablehlo.broadcast_in_dim %7, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x3x3xi1> loc(#loc8) - %10 = stablehlo.select %9, %4#0, %8 : tensor<2x3x3xi1>, tensor<2x3x3xf32> loc(#loc8) - %11 = call @triu(%2#0) : (tensor<2x3x3xf32>) -> tensor<2x3x3xf32> loc(#loc9) - return %10, %11 : tensor<2x3x3xf32>, tensor<2x3x3xf32> loc(#loc) - } loc(#loc) - func.func private @triu(%arg0: 
tensor<2x3x3xf32> {mhlo.layout_mode = "default"} loc("jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) in_layouts=(None,) out_layouts=(None,) resource_env=None donated_invars=(False,) name=triu keep_unused=False inline=False]"(#loc2))) -> (tensor<2x3x3xf32> {mhlo.layout_mode = "default"}) { - %0 = stablehlo.iota dim = 0 : tensor<3x3xi32> loc(#loc10) - %c = stablehlo.constant dense<-1> : tensor loc(#loc9) - %1 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor) -> tensor<3x3xi32> loc(#loc11) - %2 = stablehlo.add %0, %1 : tensor<3x3xi32> loc(#loc11) - %3 = stablehlo.iota dim = 1 : tensor<3x3xi32> loc(#loc12) - %4 = stablehlo.compare GE, %2, %3, SIGNED : (tensor<3x3xi32>, tensor<3x3xi32>) -> tensor<3x3xi1> loc(#loc13) - %5 = stablehlo.broadcast_in_dim %4, dims = [1, 2] : (tensor<3x3xi1>) -> tensor<2x3x3xi1> loc(#loc14) - %cst = stablehlo.constant dense<0.000000e+00> : tensor loc(#loc9) - %6 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> tensor<2x3x3xf32> loc(#loc15) - %7 = stablehlo.select %5, %6, %arg0 : tensor<2x3x3xi1>, tensor<2x3x3xf32> loc(#loc16) - return %7 : tensor<2x3x3xf32> loc(#loc9) - } loc(#loc9) -} loc(#loc) -#loc = loc(unknown) -#loc1 = loc("/release/jax/tests/export_back_compat_test.py":345:0) -#loc3 = loc("jit()/jit(main)/iota[dtype=float32 shape=(18,) dimension=0]"(#loc1)) -#loc4 = loc("jit()/jit(main)/reshape[new_sizes=(2, 3, 3) dimensions=None]"(#loc1)) -#loc5 = loc("jit()/jit(main)/geqrf"(#loc2)) -#loc6 = loc("jit()/jit(main)/qr[full_matrices=True]"(#loc2)) -#loc7 = loc("jit()/jit(main)/pad[padding_config=((0, 0, 0), (0, 0, 0), (0, 0, 0))]"(#loc2)) -#loc8 = loc("jit()/jit(main)/householder_product"(#loc2)) -#loc10 = loc("jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=0]"(#loc2)) -#loc11 = loc("jit()/jit(main)/jit(triu)/add"(#loc2)) -#loc12 = loc("jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=1]"(#loc2)) -#loc13 = loc("jit()/jit(main)/jit(triu)/ge"(#loc2)) -#loc14 = loc("jit()/jit(main)/jit(triu)/broadcast_in_dim[shape=(2, 3, 3) broadcast_dimensions=(1, 2)]"(#loc2)) -#loc15 = loc("jit()/jit(main)/jit(triu)/broadcast_in_dim[shape=(2, 3, 3) broadcast_dimensions=()]"(#loc2)) -#loc16 = loc("jit()/jit(main)/jit(triu)/select_n"(#loc2)) -""", - 
mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x01)\x05\x01\x03\x01\x03\x05\x03\x19\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x1d\x03\x96\x02\xfb=\x01\x9f\x17\x0f\x0f\x0b\x13\x0b\x0b\x07\x0f\x0b\x0b\x0b\x0b\x0b\x13\x0b\x0f\x0b\x0b\x13\x17\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b+\x0b\x0f\x0b\x0b\x0b33\x0b\x0f\x0b\x13\x0b\x13\x0f\x0b\x1b\x0f\x0b\x13\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0bK\x0f\x0b\x0f\x0b#\x0b\x0b\x0b\x0f\x0bK\x0b\x13\x1b\x13\x13\x13\x13\x0b\x03]o/\x0b\x0b\x0b/\x0b\x0f\x0b\x0b\x0b\x0b\x0fO\x0b\x13\x1b\x0b\x1b\x0b\x0b\x0b\x13\x0b\x0b\x0f\x1f\x0f\x0f\x0bO\x1f\x0b\x0b\x0f\x17\x1b\x0b\x0b\x13\x17\x1f\x0b/\x1fo\x01\x05\x0b\x0f\x039\x1b\x07\x07\x0f\x17\x0f\x07\x07\x07\x1b\x13\x13\x13\x17\x17\x13\x17\x13\x13\x17\x07\x13\x13\x13\x17\x13\x1b\x13\x02\xf6\t\x17\x1bj\x05\x01\x1d\x8f\x01\x1dK\x01\x05\x1f\x03\x03\x0b\xd5\x05!\x05#\x1f\x11\x03\x05\x05%\x05'\x05)\x05+\x05-\x03\x03\x1f\xd1\x05/\x1dS\x01\x051\x053\x03\x03\x07\xdd\x17\x1bf\x05\x01\x055\x057\x059\x05;\x05=\x05?\x05A\x05C\x03\t=?A\x11C\x11\rE\x05E\x11\x01\x00\x05G\x05I\x05K\x03\x0b\x13\xa3\x15\xbb\x17\xbd\r\xc7\x19\xc9\x03\x0b\x13\xad\x15\xcd\x17\xad\r\xaf\x19\xcf\x05M\x1dO\x01\x05O\x03\x03\x07\xd3\x05Q\x03\x03\x1f\xd7\x1dY\x01\x05S\x03\x05#\xb1%\xd9\x1d_\x01\x05U\x03\x03\x0b\xdb\x1de\x01\x05W\x1di\x01\x05Y\x1dm\x01\x05[\x1dq)\x05]\x1du)\x05_\x03\x11+\xb3-\xdf/\xe11\xa33\xb55\xe37\xb79\xe7\x1d{\x01\x05a\x1d\x7f\x01\x05c\x03\x07\x83\xa9\x85\xa9\x87\xa9\x05e\x05g\x05i\x1d\x8b\x01\x05k\x03\x11+\xb3-\xe9/\xeb1\xa33\xb55\xed7\xb79\xef\x05m\x03\x03\x07\xf1\x03\x05#\xb1%\xf3\x03\x03\x0b\xf5\x03\x03\x07\xf7\x03\x03\x0b\xf9\x03\x03\x9d\xaf\x05o\x1f/1\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f3\x11\x00\x00\x00\x00\x00\x00\x00\x00\x03\x01\x1dq\x1ds\x1f\x1b\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1du\x03\x03\xcb\x1dw\t\x07\x0b\x05\x05\x01\x03\x03\xe5\x1f1!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00#\x1f\x03\x05\xbf\xc3\r\x05\xab\xc1\xa5\xa7\x1dy\r\x05\xab\xc5\xa5\xa7\x1d{\x1d}\x1d\x7f\r\x03\xa5\xa7#!\x1d\x81\x13\x07\x01\x1f\x0f\t\xff\xff\xff\xff\x1f#\x01\x13\x07\x05\x07\x05\x1f'!\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x1f\x0b\t\x00\x00\x00\x00\x1d\x83\x1d\x85\x03\x03\x9f\x15\x03\x01\x01\x01\x03\t\x9f\xb9\xa1\xa1\x1d\x87\x1d\x89\x03\x05\x9f\xb9\x03\x07\x9f\xa1\xa1\x1f\x0f\t\x00\x00\x00\x00\x07\x01\x1f;\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x0b\t\x00\x00\xc0\x7f\x1f\x1b1\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x01\t\x01\x02\x02)\x07\t\r\r\t\x1d\t)\x01\t)\x05\r\r\x13)\x01\x13\x01\x1b\x13)\x07\t\r\r\x11)\x03A-)\x03\r\x07)\x03\t\x13\x11\x01\x05\x05\x05\x11\x03\x05\x03\x05)\x03\x01\x07)\x05\r\r\x11)\x03\t\x07)\x03I\t)\x05\t\r\t\x17)\x03\r\x15)\x03\t\x15)\x03\x05\x15)\x03\x02\x04\t)\x03\t\x11)\x07\t\x05\x05\x11)\x03\x05\x07\x04\xa6\x03\x05\x01\x11\x0f;\x07\x03\x01\t\t\x11\x0fG\x07\x03)A\x07\x03o\x1d\x03)\x15\x06s\x03\x05\x03\x01\x11\x07yw\t\x05+\x19\x19\x03\x03\x05\x03}'\x03\x0b\x17\x07\x89\x81\x03\x05\x05\x05\r\x11\x07\x03\x8d\x07\x05\x1d5\x05\x0f\x07\x05\x03\x03\x91\x03\x0f\x03\x07\x03\t\x03\x1d\x03\x17\x0b\x07\x03\x93\x037\x05\x13\x19\x03\x07\x03\x95\x039\x03\x1b\x05\x03\x03\x97\x03\x0b\x03\x07\x03\t\x03\x05\x03\x1f\x03\x07\x03\x99\x03\x17\x03\x1d\r\x06\x03\x03\x05\x07#\x11!\x19\x07\x05\x9b\x03\x05\x03\x05\x0f\x04\x0f\x05%'\t\x11\x05I\x07\x03\x17/\x03\x05\x05\x07\x03M\x1d\x03\r\x05\x03\x05Q\x03\x0f\x03\x07!\t\x03\r\x03\x05\x13\x06!\x03\r\x05\x03\x07\x07\x03WU\x03\r\x0b\x07][\x03%\x05\t\x0b\
x03\x07ca\x03\x17\x03\r\x05\x03\x05'\x03\x0b\x03\x07g\t\x03\x05\x03\x11\r\x06k\x03\x05\x07\x0f\x13\x01\x0f\x04\x05\x03\x15\x06\x03\x01\x05\x01\x00\xbe\x1c\x8b!3-#\x11\x0f\x0b\t\t\x0b!\x11#\x0fY\x9d##%_=\x8b\x89W\xb9\xc1K\x9bM\x9bn\x03\x1b%)9\x1f/!!)#\x1f\x19+\x1b\x1f]\x1f\x15\x1d\x15\x13+\r\x11\x0f\x17\x0f\x1f\x15\x15\x17\x11\x11\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00iota_v1\x00func_v1\x00compare_v1\x00select_v1\x00return_v1\x00custom_call_v1\x00add_v1\x00reshape_v1\x00pad_v1\x00call_v1\x00value\x00broadcast_dimensions\x00sym_name\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00/release/jax/tests/export_back_compat_test.py\x00iota_dimension\x00compare_type\x00comparison_direction\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit__lambda_\x00jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) in_layouts=(None,) out_layouts=(None,) resource_env=None donated_invars=(False,) name=triu keep_unused=False inline=False]\x00jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=0]\x00jit()/jit(main)/jit(triu)/add\x00jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=1]\x00jit()/jit(main)/jit(triu)/ge\x00jit()/jit(main)/jit(triu)/broadcast_in_dim[shape=(2, 3, 3) broadcast_dimensions=(1, 2)]\x00jit()/jit(main)/jit(triu)/broadcast_in_dim[shape=(2, 3, 3) broadcast_dimensions=()]\x00jit()/jit(main)/jit(triu)/select_n\x00jit()/jit(main)/iota[dtype=float32 shape=(18,) dimension=0]\x00jit()/jit(main)/reshape[new_sizes=(2, 3, 3) dimensions=None]\x00jit()/jit(main)/geqrf\x00jit()/jit(main)/qr[full_matrices=True]\x00edge_padding_high\x00edge_padding_low\x00interior_padding\x00jit()/jit(main)/pad[padding_config=((0, 0, 0), (0, 0, 0), (0, 0, 0))]\x00jit()/jit(main)/householder_product\x00callee\x00mhlo.layout_mode\x00default\x00jax.result_info\x00triu\x00[0]\x00[1]\x00main\x00public\x00private\x00\x00\x00\x00\x00\x02\x00\x00\x00\x03\x00\x00\x00\x03\x00\x00\x00\x00hipblas_geqrf_batched\x00\x00\x00\x00\x00\x02\x00\x00\x00\x03\x00\x00\x00\x03\x00\x00\x00\x03\x00\x00\x00\x80\x00\x00\x00\x00hipsolver_orgqr\x00", - xla_call_module_version=9, - nr_devices=1, -) # End paste diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/shardy_sharding_ops_with_different_meshes.py b/jax/_src/internal_test_util/export_back_compat_test_data/shardy_sharding_ops_with_different_meshes.py index b54234d11cca..89ba7b0a8790 100644 --- a/jax/_src/internal_test_util/export_back_compat_test_data/shardy_sharding_ops_with_different_meshes.py +++ b/jax/_src/internal_test_util/export_back_compat_test_data/shardy_sharding_ops_with_different_meshes.py @@ -54,4 +54,83 @@ 
mlir_module_serialized=b'ML\xefR\rStableHLO_v1.8.8\x00\x01\x1d\x05\x01\x05\r\x01\x03\x0b\x03\x0b\x0f\x13\x17\x1b\x1f\x03\x97q\x13\x019\x0f\x07\x0b\x0b+\x0b\x0f\x13\x0b\x0b\x0b\x0f\x0b\x0f\x0b\x0b\x17\x0f\x0b\x17\x0f\x0b\x1b\x0b\x0f\x0b\x17\x13\x039\x0b\x0b\x0b\x0b\x0b\x0b\x0f\x0b\x0b\x0f\x13\x0b\x0b\x0b\x0b\x0f\x8f\x13\x0b\x0b\x0b\x0b#\x0b\x0b\x0b\x0b\x0b\x01\x05\x0f\x0b\x03\x0f\x17\x17\x07\x07\x17\x17\x17\x02v\x03\x1d\x1f!\x1f\x05\x11\x05\x13\x03\t\x0b\r\x05\x0f\x15\x17\x19\x1b\x05\x15\x11\x03\x00\x03\x03\x11\x13\x05\x17\x05\x19\x05\x1b\x11\x01\t\x05\x1d\x11\x01\x05\x05\x1f\x05!\x17\x07r\x10\x1b\x1d%\'\x05#\x17\x07j\x10\x1f\x1d+\x03\x05%\x03\x05\x05[/_\x05\'\x1d35\x05)\x17\x07n\x10\x15\x03\x03\x05e\x03\x01\x1d+\x1d-\x0b\x03\x05\x01\x1d/\x03\x03G\r\x01#\r\x03\x03M\r\x03O;\x1d1\x1d3\x1d5#\x0f\x13\x0b\x05\x1f\x11A\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\r\x03]=\x1d7\x1d9\x1d;\x1d=\r\x07g=ikm=\x1d?\x1dA\x1dC\x1dE\x1dG\x01\x02\x02\x01\t)\x05\x05\x11\t)\x05\t\x11\t\t\x1d\x11\x03\x07\x03\x07\x11\x03\x05\x03\x05)\x05\t\t\x0b\x04\xb9\x05\x01Q\x03\t\x01\x07\x04\xa7\x03\x01\t\x05P\x03\x03\x07\x04]\x03\x0b\x17\x03\x0f)\x00\x03G1-\x05\x03\x07\x03\x01\x03F\x01\x07\x03\x05\x03\x03\x0bG\x017\t\x03\x05\x03\x05\x03F\x01\x0b\x03\x07\x03\x07\x07\x04\x03\x03\t\x05P\x01\r\x07\x04)\x03\x05\x0b\x03\x0b\x01\x00\tF#\x0f\x03\x05\x03\x01\x07\x04\x01\x03\x03\x06\x03\x01\x05\x01\x00r\x0bI7-3)+7\x13+#\x0f\x0b!Ae\x03Q\x1d\x05;=\x13%)=\x1f9i3\x11-\x15\x11\x1f\x0f\x0b\x11builtin\x00vhlo\x00module\x00custom_call_v1\x00func_v1\x00return_v1\x00collective_permute_v1\x00call_v1\x00mhlo.frontend_attributes\x00third_party/py/jax/tests/export_back_compat_test.py\x00jax.uses_shape_polymorphism\x00xla.sdy.meshes\x00{mesh = #sdy.mesh<[\\"a\\"=2]>}\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00jit(func)/jit(main)/shard_map\x00jit(func)/jit(main)/ppermute\x00x\x00mhlo.sharding\x00jit(func)/jit(main)/sharding_constraint\x00\x00#sdy.sharding_per_value<[<@mesh, [{\\"a\\"}, {}]>]>\x00xla.sdy.manual_computation_body\x00jax.result_info\x00main\x00public\x00xla.sdy.sharding\x00{devices=[2,1]<=[2]}\x00Sharding\x00xla.sdy.GlobalToLocalShape\x00xla.sdy.in_shardings\x00xla.sdy.manual_axes\x00#sdy\x00xla.sdy.out_shardings\x00xla.sdy.LocalToGlobalShape\x00\x08a\x11\x05;\x01\x0bEIKQS\x11?;a9A999\x11?;c9A999\x03C\x11?;o9A999\x0b9U9C;\x05WY', xla_call_module_version=9, nr_devices=2, +) + + +# Pasted from the test output (see export_back_compat_test_util.py module docstring) +data_2025_04_14 = dict( + testdata_version=1, + platform='tpu', + custom_call_targets=['Sharding', 'xla.sdy.GlobalToLocalShape', 'xla.sdy.LocalToGlobalShape'], + serialized_date=datetime.date(2025, 4, 14), + inputs=(array([[0., 1., 2., 3.], + [4., 5., 6., 7.]], dtype=float32),), + expected_outputs=(array([[4., 5., 6., 7.], + [0., 1., 2., 3.]], dtype=float32),), + mlir_module_text=r""" +#loc1 = loc("x") +#loc3 = loc("third_party/py/jax/tests/export_back_compat_test.py":1017:8 to :54) +#loc4 = loc("third_party/py/absl/testing/absltest.py":2872:19 to :56) +#loc5 = loc("third_party/py/absl/testing/absltest.py":2908:35 to 2910:3) +#loc6 = loc("third_party/py/absl/testing/absltest.py":2449:6 to :34) +#loc7 = loc("third_party/py/absl/app.py":404:13 to :23) +#loc8 = loc("third_party/py/absl/app.py":484:6 to :27) +#loc9 = loc("third_party/py/absl/testing/absltest.py":2451:4 to :31) +#loc10 = loc("third_party/py/absl/testing/absltest.py":2333:2 to :38) +#loc11 = 
loc("third_party/py/jax/tests/export_back_compat_test.py":1021:2 to :47) +#loc12 = loc("third_party/py/jax/tests/export_back_compat_test.py":1008:13 to :30) +#loc15 = loc("ShardyCompatTest.test_shardy_sharding_ops_with_different_meshes"(#loc3)) +#loc16 = loc("_run_and_get_tests_result"(#loc4)) +#loc17 = loc("run_tests"(#loc5)) +#loc18 = loc("_run_in_app..main_function"(#loc6)) +#loc19 = loc("_run_main"(#loc7)) +#loc20 = loc("run"(#loc8)) +#loc21 = loc("_run_in_app"(#loc9)) +#loc22 = loc("main"(#loc10)) +#loc23 = loc(""(#loc11)) +#loc24 = loc("ShardyCompatTest.test_shardy_sharding_ops_with_different_meshes..func"(#loc12)) +#loc26 = loc(callsite(#loc22 at #loc23)) +#loc28 = loc(callsite(#loc21 at #loc26)) +#loc30 = loc(callsite(#loc20 at #loc28)) +#loc32 = loc(callsite(#loc19 at #loc30)) +#loc34 = loc(callsite(#loc18 at #loc32)) +#loc36 = loc(callsite(#loc17 at #loc34)) +#loc38 = loc(callsite(#loc16 at #loc36)) +#loc40 = loc(callsite(#loc15 at #loc38)) +#loc43 = loc(callsite(#loc24 at #loc40)) +#loc46 = loc("jit(func)/jit(main)/shard_map"(#loc43)) +module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.frontend_attributes = {xla.sdy.meshes = "{mesh = #sdy.mesh<[\22a\22=2]>}"}, mhlo.num_partitions = 2 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main(%arg0: tensor<2x4xf32> loc("x")) -> (tensor<2x4xf32> {jax.result_info = "result"}) { + %0 = stablehlo.custom_call @Sharding(%arg0) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\22a\22}, {}]>]>"}, mhlo.sharding = "{devices=[2,1]<=[2]}"} : (tensor<2x4xf32>) -> tensor<2x4xf32> loc(#loc45) + %1 = stablehlo.custom_call @xla.sdy.GlobalToLocalShape(%0) {has_side_effect = true, mhlo.frontend_attributes = {xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh, [{\22a\22}, {}]>]>", xla.sdy.manual_axes = "#sdy"}} : (tensor<2x4xf32>) -> tensor<1x4xf32> loc(#loc46) + %2 = call @xla.sdy.manual_computation_body(%1) {mhlo.frontend_attributes = {xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh, [{\22a\22}, {}]>]>", xla.sdy.manual_axes = "#sdy", xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh, [{\22a\22}, {}]>]>"}} : (tensor<1x4xf32>) -> tensor<1x4xf32> loc(#loc46) + %3 = stablehlo.custom_call @xla.sdy.LocalToGlobalShape(%2) {has_side_effect = true, mhlo.frontend_attributes = {xla.sdy.manual_axes = "#sdy", xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh, [{\22a\22}, {}]>]>"}} : (tensor<1x4xf32>) -> tensor<2x4xf32> loc(#loc46) + return %3 : tensor<2x4xf32> loc(#loc) + } loc(#loc) + func.func @xla.sdy.manual_computation_body(%arg0: tensor<1x4xf32> loc("jit(func)/jit(main)/shard_map"(#loc43))) -> tensor<1x4xf32> { + %0 = "stablehlo.collective_permute"(%arg0) <{channel_handle = #stablehlo.channel_handle, source_target_pairs = dense<[[0, 1], [1, 0]]> : tensor<2x2xi64>}> : (tensor<1x4xf32>) -> tensor<1x4xf32> loc(#loc47) + return %0 : tensor<1x4xf32> loc(#loc46) + } loc(#loc46) +} loc(#loc) +#loc = loc(unknown) +#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":1007:10 to :73) +#loc13 = loc("third_party/py/jax/tests/export_back_compat_test.py":1006:15 to :46) +#loc14 = loc("ShardyCompatTest.test_shardy_sharding_ops_with_different_meshes..func"(#loc2)) +#loc25 = loc("ShardyCompatTest.test_shardy_sharding_ops_with_different_meshes..func..shard_map_func"(#loc13)) +#loc27 = loc(callsite(#loc21 at #loc22)) +#loc29 = loc(callsite(#loc20 at #loc27)) +#loc31 = loc(callsite(#loc19 at #loc29)) +#loc33 = loc(callsite(#loc18 at #loc31)) +#loc35 = loc(callsite(#loc17 at 
#loc33)) +#loc37 = loc(callsite(#loc16 at #loc35)) +#loc39 = loc(callsite(#loc15 at #loc37)) +#loc41 = loc(callsite(#loc24 at #loc39)) +#loc42 = loc(callsite(#loc14 at #loc40)) +#loc44 = loc(callsite(#loc25 at #loc41)) +#loc45 = loc("jit(func)/jit(main)/sharding_constraint"(#loc42)) +#loc47 = loc("jit(func)/jit(main)/ppermute"(#loc44)) +""", + mlir_module_serialized=b'ML\xefR\rStableHLO_v1.9.5\x00\x01\x1d\x05\x01\x05\r\x01\x03\x0b\x03\x0b\x0f\x13\x17\x1b\x1f\x03\x1a\x02\xe7\x13\x01\xa7\x0f\x0b\x0b\x0b\x07\x0f\x0b\x0f\x0f\x0f\x0f\x0f\x0f\x0b\x0f\x0f\x0f+\x0b\x0f\x13\x0b\x0b\x0b\x0f\x0b\x0f\x0b\x0b\x0f\x1f\x0b\x1f\x0f\x0b\x1f\x0f\x0b\'\x0f\x0b\x1f\x0f\x0b\x1f\x0f\x0b\x1f\x0f\x0b\x1f\x0f\x0b\x1f\x0f\x0b\x1f\x0f\x0b\x0f\x0f\x0b\x1f\x0f\x0f\x0f\x0f\x0f\x0f\x0f\x0f\x0f\x0b\x1b\x0b\x0f\x0b\x0f\x0f\x1f\x13\x13\x13\x03A\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0f\x0b\x0b\x0f\x13\x0b\x0b\x0b\x0b\x0b\x0f\x8f\x13\x0b\x0b\x0b\x0b\x1b\x0b#\x1b\x0b\x01\x05\x0f\x0b\x03\x0f\x17\x17\x07\x07\x17\x17\x17\x02\xce\x06\x1d9;\x05\x11\x05\x13\x05\x15\x1f\x1d\r=\x05\x17\x15\x11C\x1d?A\x1dEG\x1dKM\x1dQS\x1dWY\x05\x19\x1d]_\x1dce\x1dik\x03\t%\'\x03)/135\x05\x1b\x11\x03\x00\x03\x03+-\x05\x1d\x05\x1f\x05!\x11\x01\t\x05#\x11\x01\x05\x05%\x05\'\x15\x0b\x0f-\x05\x07\xc2\x0f\x1b=\x05)-\x05\x07\xe6\x0f\x11m\x15\x13I\x05+-\x07\x07\xe2,\'q\x15\x15O\x05--\x07\tr-Gz-\x07\x15\x17U\x05/-\x07\x07F&\rE\x15\x19[\x051-\x1b\x07R\x06\x1b/\x15\x1da\x053-\x1b\x07\x92\x07\r7\x15\x1fg\x055-\x07\x07N&\t?\x15!m\x057-\x07\x07v$\x05M\x1doq\x059-\x05\x07\xf6\x0f\x05_\x1duw\x05;\x15y\x7f\x1d{}\x05=-\x05\x07\xba\x0f\x1f]\x15\x0b\x81\x15\x11\x83\x15\x13\x85\x15\x15\x87\x15\x17\x89\x15\x19\x8b\x15\x1d\x8d\x15\x1f!\x1d\x91\t\x05?\x03\x05\x03\xd3\x95\xd7\x05A\x1d\x99\x9b\x05C\x15\x9d\x0f\x1d\r\x9f-\x05\x07\xbe\x0f\x15\x93\x03\x03\x03\xdd\x03\x03\x03\xe1\x03\x03\x03\xe3\x03\x01\x1dE\x1dG\x0b\x03\x1dI\x1dK\x1dM\x1dO\x05\x03\x1dQ\x03\x03\xbd\r\x01#\r\x03\x03\xc3\r\x03\xc5\xc7\x1dS\x1dU\x1d7\x1dW#\x0f\x13\x0b\x05\x1f\x11A\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\r\x03\xd5\xa9\x1dY\x1d[\x1d]\x05\x01\r\x05\xb5\xa9\xaf\xb1\x1d_\r\x07\xb5\xa9\xaf\xb1\xb9\xa9\r\x05\xaf\xb1\xb9\xa9\x1da\x01\x02\x02\x01\t)\x05\x05\x11\t)\x05\t\x11\t\t\x1d\x11\x03\x07\x03\x07\x11\x03\x05\x03\x05)\x05\t\t\x0b\x04\xbd\x05\x01Q\t#\x01\x07\x04\xab\x03\x01\t\x05P\t\x03\x07\x04a\x03\x0b\x17\x03\x0f\x8f\x00\x03G\x97\x93\x05\x03\x07\x03\x01\x03G\x01\xa1\x07\x03\x05\x03\x03\x0bG\x01\xa3\t\x03\x05\x03\x05\x03G\x01\xa5\x0b\x03\x07\x03\x07\x07\x04\t\x03\t\x05P\x01\r\x07\x04)\x03\x05\x0b\x03\x0b\x01\x00\tFs\x0f\x03\x05\x03\x01\x07\x04\x01\x03\x03\x06\x03\x01\x05\x01\x00.\x12c77\x13+#\x0f\x0f!-+A/)\x03aQ\x1d\x05\xcd;\x13\x0b\x19\t\x15G\x155\x81=\x13%)9\x1f97\x9dQi3\x11-\x15\x11\x1f\x0f\x0b\x11builtin\x00vhlo\x00module\x00custom_call_v1\x00func_v1\x00return_v1\x00collective_permute_v1\x00call_v1\x00mhlo.frontend_attributes\x00third_party/py/jax/tests/export_back_compat_test.py\x00third_party/py/absl/testing/absltest.py\x00ShardyCompatTest.test_shardy_sharding_ops_with_different_meshes..func\x00third_party/py/absl/app.py\x00jax.uses_shape_polymorphism\x00xla.sdy.meshes\x00{mesh = 
#sdy.mesh<["a"=2]>}\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00jit(func)/jit(main)/shard_map\x00ShardyCompatTest.test_shardy_sharding_ops_with_different_meshes\x00_run_and_get_tests_result\x00run_tests\x00_run_in_app..main_function\x00_run_main\x00run\x00_run_in_app\x00main\x00\x00jit(func)/jit(main)/ppermute\x00ShardyCompatTest.test_shardy_sharding_ops_with_different_meshes..func..shard_map_func\x00x\x00mhlo.sharding\x00jit(func)/jit(main)/sharding_constraint\x00#sdy.sharding_per_value<[<@mesh, [{"a"}, {}]>]>\x00\x00xla.sdy.manual_axes\x00#sdy\x00xla.sdy.manual_computation_body\x00xla.sdy.in_shardings\x00xla.sdy.out_shardings\x00jax.result_info\x00result\x00public\x00xla.sdy.sharding\x00{devices=[2,1]<=[2]}\x00Sharding\x00xla.sdy.GlobalToLocalShape\x00xla.sdy.LocalToGlobalShape\x00\x08a\x11\x05o\x01\x0b\xbb\xbf\xc1\xc9\xcb\x11\xad\xab\xd9\xa7\xdb\xa7\xa7\xa7\x11\xad\xab\xdf\xa7\xb7\xa7\xa7\xa7\x03\xb3\x11\xad\xab\xe5\xa7\xb7\xa7\xa7\xa7\x0b\xa7\xcd\xa7\xb3\xab\x05\xcf\xd1', + xla_call_module_version=9, + nr_devices=2, ) # End paste diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/tpu_Sharding.py b/jax/_src/internal_test_util/export_back_compat_test_data/tpu_Sharding.py index f2d8be3b958a..1caac00a4680 100644 --- a/jax/_src/internal_test_util/export_back_compat_test_data/tpu_Sharding.py +++ b/jax/_src/internal_test_util/export_back_compat_test_data/tpu_Sharding.py @@ -15,8 +15,10 @@ import datetime from numpy import array, float32 +data_2023_03_16 = {} + # Pasted from the test output (see module docstring) -data_2023_03_16 = dict( +data_2023_03_16['gspmd'] = dict( testdata_version=1, platform='tpu', custom_call_targets=['SPMDFullToShardShape', 'SPMDShardToFullShape', 'Sharding'], @@ -47,3 +49,40 @@ xla_call_module_version=4, nr_devices=2, ) # End paste + + +# Pasted from the test output (see export_back_compat_test_util.py module docstring) +data_2023_03_16['shardy'] = dict( + testdata_version=1, + platform='tpu', + custom_call_targets=['xla.sdy.FuncResultSharding', 'xla.sdy.GlobalToLocalShape', 'xla.sdy.LocalToGlobalShape'], + serialized_date=datetime.date(2025, 5, 28), + inputs=(array([[0., 1., 2., 3.], + [4., 5., 6., 7.]], dtype=float32),), + expected_outputs=(array([[4., 5., 6., 7.], + [0., 1., 2., 3.]], dtype=float32),), + mlir_module_text=r""" +#loc1 = loc("x") +#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":783:6) +#loc4 = loc("jit(func)/jit(main)/shard_map"(#loc2)) +module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.frontend_attributes = {xla.sdy.meshes = "{mesh = #sdy.mesh<[\22a\22=2]>}"}, mhlo.num_partitions = 2 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main(%arg0: tensor<2x4xf32> {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding<@mesh, [{\22a\22}, {}]>"}, mhlo.sharding = "{devices=[2,1]<=[2]}"} loc("x")) -> (tensor<2x4xf32> {jax.result_info = "result", mhlo.sharding = "{devices=[2,1]<=[2]}"}) { + %0 = stablehlo.custom_call @xla.sdy.GlobalToLocalShape(%arg0) {has_side_effect = true, mhlo.frontend_attributes = {xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh, [{\22a\22}, {}]>]>", xla.sdy.manual_axes = "#sdy"}} : (tensor<2x4xf32>) -> tensor<1x4xf32> loc(#loc4) + %1 = call @xla.sdy.manual_computation_body(%0) : (tensor<1x4xf32>) -> tensor<1x4xf32> loc(#loc4) + %2 = stablehlo.custom_call @xla.sdy.LocalToGlobalShape(%1) {mhlo.frontend_attributes = {xla.sdy.manual_axes = "#sdy", xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh, [{\22a\22}, {}]>]>"}} : 
(tensor<1x4xf32>) -> tensor<2x4xf32> loc(#loc4) + %3 = stablehlo.custom_call @xla.sdy.FuncResultSharding(%2) {has_side_effect = true, mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\22a\22}, {}]>]>"}} : (tensor<2x4xf32>) -> tensor<2x4xf32> loc(#loc4) + return %3 : tensor<2x4xf32> loc(#loc) + } loc(#loc) + func.func @xla.sdy.manual_computation_body(%arg0: tensor<1x4xf32> loc("jit(func)/jit(main)/shard_map"(#loc2))) -> tensor<1x4xf32> { + %0 = "stablehlo.collective_permute"(%arg0) <{channel_handle = #stablehlo.channel_handle, source_target_pairs = dense<[[0, 1], [1, 0]]> : tensor<2x2xi64>}> : (tensor<1x4xf32>) -> tensor<1x4xf32> loc(#loc5) + return %0 : tensor<1x4xf32> loc(#loc4) + } loc(#loc4) +} loc(#loc) +#loc = loc(unknown) +#loc3 = loc("third_party/py/jax/tests/export_back_compat_test.py":779:13) +#loc5 = loc("jit(func)/jit(main)/ppermute"(#loc3)) +""", + mlir_module_serialized=b'ML\xefR\rStableHLO_v1.10.3\x00\x01\x1d\x05\x01\x05\r\x01\x03\x0b\x03\x0b\x0f\x13\x17\x1b\x1f\x03\x9fy\x13\x013\x0f\x0b\x07\x0b+\x0b\x0f\x13\x0b\x0b\x0b\x0f\x0b\x0f\x0b\x0b\x17\x0f\x0b\x17\x0f\x0b\x13\x13\x13\x03G\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0f\x1b\x0b\x13\x0b\x0b\x0f\x1b\x0b\x0b\x0b\x0b\x0b\x0f\x8f\x1b\x0b\x0b\x1b\x0b\x0b\x0b\x13\x0b\x01\x05\x0f\x0b\x03\x0f\x17\x17\x07\x07\x17\x17\x17\x02\xae\x03\x1d\x1f!\x05\x11\x1f\x05\x13\x03\t\x0b\r\x03\x0f\x15\x17\x19\x1b\x05\x15\x11\x03\x00\x03\x03\x11\x13\x05\x17\x05\x19\x05\x1b\x11\x01\t\x05\x1d\x11\x01\x05\x05\x1f\x05!\x17\x07>\x0c\r\x1d%\'\x05#\x17\x07.\x0c\x1b\x1d+\x05\x05%\x03\x03\x03g\x03\x03\x03m\x03\x03\x03u\x03\x01\x1d\'\x1d)\x0b\x03\x1d+\x1d-\x1d/\x1d1\x1d3\x1d5\x05\x03\x03\x03K\r\x05MO=?\x1d\x11\r\x03;Q\x1d7#\r\x03\x03W\r\x05Y[=?\x1d9\x1d;\x1d=\x1d?#\x0f\x13\x0b\x05\x1f\x11A\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\r\x05i7CE\x1dA\x1dC\r\x05CEo7\x1dE\x1dG\x05\x01\r\x03;7\x1dI\x01\x02\x02\x01\t)\x05\x05\x11\t)\x05\t\x11\t\t\x1d\x11\x03\x07\x03\x07\x11\x03\x05\x03\x05)\x05\t\t\x0b\x04\xbb\x05\x01Q\x05\t\x01\x07\x04\xa9\x03\x01\t\x05P\x05\x03\x07\x04_\x03\x0b\x17\x03\x0f)\x00\x03G\x01-\x05\x03\x05\x03\x01\x0bF\x01\x07\x03\x05\x03\x03\x03G\x01/\t\x03\x07\x03\x05\x03G\x011\x0b\x03\x07\x03\x07\x07\x04\x05\x03\t\x05P\x01\r\x07\x04)\x03\x05\x0b\x03\x0b\x01\x00\tF#\x0f\x03\x05\x03\x01\x07\x04\x01\x03\x03\x06\x03\x01\x05\x01\x00\xaa\x0bK77-7+\x0f\x0b\x0f!E/)A+\x1d#a\x03\x05;=\x13%)9\x1f9i3\x11-\x15\x11\x1f\x0f\x0b\x11builtin\x00vhlo\x00module\x00custom_call_v1\x00func_v1\x00return_v1\x00collective_permute_v1\x00call_v1\x00mhlo.frontend_attributes\x00third_party/py/jax/tests/export_back_compat_test.py\x00jax.uses_shape_polymorphism\x00xla.sdy.meshes\x00{mesh = #sdy.mesh<["a"=2]>}\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00jit(func)/jit(main)/shard_map\x00jit(func)/jit(main)/ppermute\x00x\x00\x00#sdy.sharding_per_value<[<@mesh, [{"a"}, {}]>]>\x00xla.sdy.sharding\x00mhlo.sharding\x00{devices=[2,1]<=[2]}\x00xla.sdy.manual_computation_body\x00xla.sdy.manual_axes\x00#sdy\x00#sdy.sharding<@mesh, [{"a"}, {}]>\x00jax.result_info\x00result\x00main\x00public\x00xla.sdy.in_shardings\x00xla.sdy.GlobalToLocalShape\x00xla.sdy.out_shardings\x00xla.sdy.LocalToGlobalShape\x00xla.sdy.FuncResultSharding\x00\x08a\x11\x05;\x01\x0bISU]_\x1195k3G333\x03A\x1195q3s333\x1195w3G333\x0b3a3A5\x05ce', + xla_call_module_version=9, + nr_devices=2, +) # End paste diff --git a/jax/_src/internal_test_util/export_back_compat_test_util.py 
b/jax/_src/internal_test_util/export_back_compat_test_util.py index 5d5e95b5cb9a..7b4af36e5dc4 100644 --- a/jax/_src/internal_test_util/export_back_compat_test_util.py +++ b/jax/_src/internal_test_util/export_back_compat_test_util.py @@ -90,6 +90,7 @@ def func(...): ... from jax.experimental import pjit from jax._src import core +from jax._src import stages from jax._src import test_util as jtu from jax._src import xla_bridge as xb @@ -165,7 +166,8 @@ def load_testdata_nested(self, testdata_nest) -> Iterable[CompatTestData]: else: assert False, testdata_nest - def run_one_test(self, func: Callable[..., jax.Array], + def run_one_test(self, + func: Callable[..., jax.Array] | stages.Wrapped, data: CompatTestData, polymorphic_shapes: Sequence[str] | None = None, rtol: float | None = None, @@ -176,7 +178,8 @@ def run_one_test(self, func: Callable[..., jax.Array], """Run one compatibility test. Args: - func: the JAX function to serialize and run + func: the JAX function to serialize and run, either as a Python Callable + or as a `jax.jit(callable)`. data: the test data polymorphic_shapes: when using shape polymorphism, the specification for each argument of `func`. @@ -269,19 +272,22 @@ def run_one_test(self, func: Callable[..., jax.Array], expect_current_custom_calls = data.custom_call_targets self.assertItemsEqual(expect_current_custom_calls, current_custom_call_targets) - def run_current(self, func: Callable, data: CompatTestData): + def run_current(self, + func: Callable | stages.Wrapped, + data: CompatTestData): """Lowers and runs the test function at the current JAX version.""" - return jax.jit(func)(*data.inputs) + jit_func = func if isinstance(func, stages.Wrapped) else jax.jit(func) + return jit_func(*data.inputs) def serialize(self, - func: Callable, data: CompatTestData, *, + func: Callable | stages.Wrapped, data: CompatTestData, *, polymorphic_shapes: Sequence[str] | None = None, allow_unstable_custom_call_targets: Sequence[str] = () ) -> tuple[bytes, str, int, int]: """Serializes the test function. Args: - func: the function to serialize + func: the function to serialize. polymorphic_shapes: the polymorphic_shapes to use for serialization allow_unstable_custom_call_targets: whether to allow additional custom call targets besides those known as stable. @@ -292,8 +298,9 @@ def serialize(self, """ # Use the native exporter, to make sure we get the proper serialization. args_specs = export.symbolic_args_specs(data.inputs, polymorphic_shapes) + jit_func = func if isinstance(func, stages.Wrapped) else jax.jit(func) exported = export.export( - jax.jit(func), + jit_func, platforms=(self.default_jax_backend(),), disabled_checks=tuple( export.DisabledSafetyCheck.custom_call(target) @@ -314,7 +321,7 @@ def ndarray_to_aval(a: np.ndarray) -> core.ShapedArray: in_avals_tree = tree_util.tree_map(ndarray_to_aval, args_specs) # TODO: we ought to ensure that out_avals are polymorphic if need be. We # could either save the in/out_avals (but we need to first implement that - # support in export), or we can just re-use them from the current + # support in export), or we can just reuse them from the current # exported. 
out_avals_tree = tree_util.tree_map(ndarray_to_aval, data.expected_outputs) # in_tree must be for (args, kwargs) diff --git a/jax/_src/internal_test_util/lax_test_util.py b/jax/_src/internal_test_util/lax_test_util.py index 4e28791e9cee..767b41dc8ba0 100644 --- a/jax/_src/internal_test_util/lax_test_util.py +++ b/jax/_src/internal_test_util/lax_test_util.py @@ -304,7 +304,7 @@ def lax_ops(): float_dtypes, test_util.rand_uniform, { - np.float32: 1e-5, + np.float32: 2e-5, np.float64: 1e-12, }, ), diff --git a/jax/_src/internal_test_util/test_harnesses.py b/jax/_src/internal_test_util/test_harnesses.py index 48c645c4d033..a3e873c43c92 100644 --- a/jax/_src/internal_test_util/test_harnesses.py +++ b/jax/_src/internal_test_util/test_harnesses.py @@ -408,11 +408,11 @@ def parameterized(harnesses: Iterable[Harness], ############################################################################### -def _make_unary_elementwise_harness(*, prim, shape=(20, 20), dtype): +def _make_unary_elementwise_harness(*, prim, shape=(20, 20), dtype, **kwargs): define( str(prim), f"shape={jtu.format_shape_dtype_string(shape, dtype)}", - prim.bind, [RandArg(shape, dtype)], + lambda x: prim.bind(x, **kwargs), [RandArg(shape, dtype)], prim=prim, dtype=dtype, shape=shape) @@ -429,19 +429,19 @@ def _make_unary_elementwise_harness(*, prim, shape=(20, 20), dtype): _make_unary_elementwise_harness(prim=lax.acos_p, dtype=dtype) _make_unary_elementwise_harness(prim=lax.atan_p, dtype=dtype) _make_unary_elementwise_harness(prim=lax.asin_p, dtype=dtype) - _make_unary_elementwise_harness(prim=lax.cos_p, dtype=dtype) + _make_unary_elementwise_harness(prim=lax.cos_p, dtype=dtype, accuracy=None) _make_unary_elementwise_harness(prim=lax.cosh_p, dtype=dtype) - _make_unary_elementwise_harness(prim=lax.exp_p, dtype=dtype) - _make_unary_elementwise_harness(prim=lax.expm1_p, dtype=dtype) - _make_unary_elementwise_harness(prim=lax.log_p, dtype=dtype) - _make_unary_elementwise_harness(prim=lax.log1p_p, dtype=dtype) - _make_unary_elementwise_harness(prim=lax.rsqrt_p, dtype=dtype) - _make_unary_elementwise_harness(prim=lax.sin_p, dtype=dtype) + _make_unary_elementwise_harness(prim=lax.exp_p, dtype=dtype, accuracy=None) + _make_unary_elementwise_harness(prim=lax.expm1_p, dtype=dtype, accuracy=None) + _make_unary_elementwise_harness(prim=lax.log_p, dtype=dtype, accuracy=None) + _make_unary_elementwise_harness(prim=lax.log1p_p, dtype=dtype, accuracy=None) + _make_unary_elementwise_harness(prim=lax.rsqrt_p, dtype=dtype, accuracy=None) + _make_unary_elementwise_harness(prim=lax.sin_p, dtype=dtype, accuracy=None) _make_unary_elementwise_harness(prim=lax.sinh_p, dtype=dtype) - _make_unary_elementwise_harness(prim=lax.sqrt_p, dtype=dtype) - _make_unary_elementwise_harness(prim=lax.tan_p, dtype=dtype) - _make_unary_elementwise_harness(prim=lax.tanh_p, dtype=dtype) - _make_unary_elementwise_harness(prim=lax.logistic_p, dtype=dtype) + _make_unary_elementwise_harness(prim=lax.sqrt_p, dtype=dtype, accuracy=None) + _make_unary_elementwise_harness(prim=lax.tan_p, dtype=dtype, accuracy=None) + _make_unary_elementwise_harness(prim=lax.tanh_p, dtype=dtype, accuracy=None) + _make_unary_elementwise_harness(prim=lax.logistic_p, dtype=dtype, accuracy=None) for dtype in jtu.dtypes.all_floating: _make_unary_elementwise_harness(prim=lax.bessel_i0e_p, dtype=dtype) @@ -654,8 +654,8 @@ def _make_device_put_harness(name, "device_put", f"{name}_shape={jtu.format_shape_dtype_string(shape, dtype)}_{device=}", lambda x: dispatch.device_put_p.bind( - x, devices=[_device_fn()], 
srcs=[None], - copy_semantics=[dispatch.CopySemantics.ALIAS])[0], + x, devices=(_device_fn(),), srcs=(None,), + copy_semantics=(dispatch.CopySemantics.ALIAS,))[0], [RandArg(shape, dtype)], shape=shape, dtype=dtype, @@ -2744,7 +2744,8 @@ def wrap_and_split(): "random_categorical", f"shape={jtu.format_shape_dtype_string(shape, dtype)}_{axis=}", lambda x, axis: jax.random.categorical( - jax.random.key(42), x, axis), + # TODO(b/416027995): Change this key back to 42. + jax.random.key(1337), x, axis), [RandArg(shape, dtype), StaticArg(axis)], dtype=dtype, @@ -3375,8 +3376,9 @@ def _make_conv_harness(name, define( lax.rng_bit_generator_p, f"{key_dtype=}_shape={jtu.format_shape_dtype_string(shape, dtype)}_{algorithm=}", - lambda key, shape, dtype, algorithm: lax.rng_bit_generator(key, shape, dtype=dtype, - algorithm=algorithm), + lambda key, shape, dtype, algorithm, out_sharding=None: lax.rng_bit_generator( + key, shape, dtype=dtype, algorithm=algorithm, + out_sharding=out_sharding), [RandArg(key_shape, key_dtype), StaticArg(shape), StaticArg(dtype), StaticArg(algorithm)], shape=shape, diff --git a/jax/_src/interpreters/ad.py b/jax/_src/interpreters/ad.py index ddf96af6a010..dc5d64e75423 100644 --- a/jax/_src/interpreters/ad.py +++ b/jax/_src/interpreters/ad.py @@ -21,12 +21,12 @@ from functools import partial from typing import Any +from jax._src import api_util from jax._src import config -from jax._src import dispatch from jax._src import linear_util as lu from jax._src.interpreters import partial_eval as pe -from jax.tree_util import (tree_flatten, tree_unflatten, - register_pytree_node, Partial, PyTreeDef) +from jax._src.tree_util import (tree_flatten, tree_unflatten, + register_pytree_node, Partial, PyTreeDef) from jax._src import mesh as mesh_lib from jax._src import core from jax._src import source_info_util @@ -73,6 +73,7 @@ def jvp(fun: lu.WrappedFun, has_aux=False, instantiate=True, def jvpfun(f: Callable, instantiate, transform_stack, primals, tangents): tag = core.TraceTag() tangents = [Zero.from_primal_value(t) if not isinstance(t, Zero) + and isinstance(core.typeof(t), core.ShapedArray) and dtype(t) == float0 else t for t in tangents] ctx = (source_info_util.transform_name_stack('jvp') if transform_stack else contextlib.nullcontext()) @@ -89,12 +90,14 @@ def linearize_subtrace(_f: Callable, _store: lu.Store, _tag: core.TraceTag, nzs_in: Sequence[bool], debug_info: core.DebugInfo, *primals, **params): + source_info = source_info_util.current() with core.take_current_trace() as parent_trace: tangent_trace = pe.DynamicJaxprTrace(debug_info) tangent_trace.tag = _tag linearize_trace = LinearizeTrace(parent_trace, tangent_trace, tag=_tag) tracers = [LinearizeTracer(linearize_trace, p, - tangent_trace.new_arg(get_aval(p).to_tangent_aval())) + tangent_trace.new_arg(get_aval(p).to_tangent_aval(), + source_info)) if nz else p for p, nz in zip(primals, nzs_in)] with core.set_current_trace(linearize_trace, check_leaks=True): @@ -103,8 +106,8 @@ def linearize_subtrace(_f: Callable, _store: lu.Store, _tag: core.TraceTag, del linearize_trace, ans, tracers nzs_out = tuple(type(t) is not Zero for t in out_tangents) out_tangents = tuple(t for t, nz in zip(out_tangents, nzs_out) if nz) - out_tangents = map(tangent_trace.to_jaxpr_tracer, out_tangents) # type: ignore[assignment] - jaxpr, consts, attrs_tracked = tangent_trace.to_jaxpr(out_tangents, debug_info) + out_tangents = map(partial(tangent_trace.to_jaxpr_tracer, source_info=source_info), out_tangents) # type: ignore[assignment] + jaxpr, consts, 
attrs_tracked = tangent_trace.to_jaxpr(out_tangents, debug_info, source_info) if attrs_tracked: raise NotImplementedError("TODO: attrs") which_env = [(isinstance(c, pe.DynamicJaxprTracer) and @@ -172,13 +175,14 @@ def _linearize_jaxpr( lin_trace = LinearizeTrace(primal_trace, tangent_trace) tangent_trace.tag = lin_trace.tag - def new_arg(trace, primal_aval, nz): - primal = primal_trace.new_arg(primal_aval) + def new_arg(trace, primal_aval, nz, source_info): + primal = primal_trace.new_arg(primal_aval, source_info) tangent_aval = primal_aval.to_tangent_aval() - tangent = tangent_trace.new_arg(tangent_aval) if nz else Zero(tangent_aval) + tangent = tangent_trace.new_arg(tangent_aval, source_info) if nz else Zero(tangent_aval) return LinearizeTracer(trace, primal, tangent) - tracers = [new_arg(lin_trace, v.aval, nz) + source_info = source_info_util.current() + tracers = [new_arg(lin_trace, v.aval, nz, source_info) for (v, nz) in zip(jaxpr.jaxpr.invars, nonzeros)] with core.set_current_trace(lin_trace, check_leaks=True): @@ -188,15 +192,22 @@ def new_arg(trace, primal_aval, nz): debug_info = jaxpr.jaxpr.debug_info nzs_out = [type(t) is not Zero for t in out_tangents] - out_tangents = tuple(tangent_trace.to_jaxpr_tracer(t) + out_tangents = tuple(tangent_trace.to_jaxpr_tracer(t, source_info) for (nz, t) in zip(nzs_out, out_tangents) if nz) - tangent_jaxpr, tangent_consts, attrs_tracked = tangent_trace.to_jaxpr(out_tangents, debug_info) + tangent_jaxpr, tangent_consts, attrs_tracked = tangent_trace.to_jaxpr( + out_tangents, debug_info, source_info) tangent_trace.invalidate() if attrs_tracked: raise NotImplementedError("TODO: attrs") + tangent_jaxpr, used_consts, _ = pe.dce_jaxpr_consts( + tangent_jaxpr, [True] * len(tangent_jaxpr.outvars), + [False] * len(tangent_jaxpr.constvars) + [True] * len(tangent_jaxpr.invars)) + tangent_consts = [c for c, used in zip(tangent_consts, used_consts) if used] + residuals_and_primals = (*tangent_consts, *out_primals) - residuals_and_primals = map(primal_trace.to_jaxpr_tracer, residuals_and_primals) - primal_jaxpr, primal_consts, attrs_tracked = primal_trace.to_jaxpr(residuals_and_primals, debug_info) + residuals_and_primals = map(partial(primal_trace.to_jaxpr_tracer, source_info=source_info), residuals_and_primals) + primal_jaxpr, primal_consts, attrs_tracked = primal_trace.to_jaxpr( + residuals_and_primals, debug_info, source_info) primal_trace.invalidate() num_residuals = len(tangent_consts) tangent_jaxpr = pe.close_jaxpr(convert_constvars_jaxpr_constvars_at_end(tangent_jaxpr)) @@ -207,8 +218,9 @@ def new_arg(trace, primal_aval, nz): def direct_linearize(traceable: lu.WrappedFun, primals, kwargs, *, has_aux=False, tag=None): with core.take_current_trace() as parent_trace: + source_info = source_info_util.current() tangent_trace = pe.DynamicJaxprTrace(traceable.debug_info) - tangents = [tangent_trace.new_arg(get_aval(p).to_tangent_aval()) for p in primals] + tangents = [tangent_trace.new_arg(get_aval(p).to_tangent_aval(), source_info) for p in primals] tangents = [Zero.from_primal_value(t) if dtype(t) == float0 else t for t in tangents] linearize_trace = LinearizeTrace(parent_trace, tangent_trace, tag=tag) tangent_trace.tag = linearize_trace.tag @@ -229,8 +241,9 @@ def direct_linearize(traceable: lu.WrappedFun, del linearize_trace, ans, tracers out_nzs = [type(t) is not Zero for t in out_tangents] out_nz_tangents = [t for t, nz in zip(out_tangents, out_nzs) if nz] - out_nz_tangents = map(tangent_trace.to_jaxpr_tracer, out_nz_tangents) - jaxpr, consts, 
attrs_tracked = tangent_trace.to_jaxpr(out_nz_tangents, traceable.debug_info) + out_nz_tangents = map(partial(tangent_trace.to_jaxpr_tracer, source_info=source_info), out_nz_tangents) + jaxpr, consts, attrs_tracked = tangent_trace.to_jaxpr( + out_nz_tangents, traceable.debug_info, source_info) tangent_trace.invalidate() jaxpr, used_consts, _ = pe.dce_jaxpr_consts( jaxpr, [True] * len(jaxpr.outvars), @@ -267,7 +280,7 @@ def linearize(traceable: lu.WrappedFun, raise ValueError( "Linearization failed to produce known values for all output primals. " "This is typically caused by attempting to differentiate a function " - "uses an operation that does not support reverse-mode autodiff.") + "using an operation that does not support reverse-mode autodiff.") out_primals_consts = [pval.get_known() for pval in out_primals_pvals] if not has_aux: return out_primals_consts, out_tangents_pvals, jaxpr, consts @@ -407,6 +420,12 @@ def write_primal(v, val): try: cts_out = get_primitive_transpose(eqn.primitive)( cts_in, *invals, **eqn.params) + except core.ShardingTypeError as e: + extra_msg = ("This is a potential JAX bug. Please file an issue at" + " https://github.com/jax-ml/jax/issues") + if extra_msg in str(e): + raise + raise core.ShardingTypeError(f"{str(e)}\n{extra_msg}") except (FloatingPointError, ZeroDivisionError) as e: msg = "When differentiating the code at the top of the callstack:" if msg not in e.args[0]: @@ -460,6 +479,7 @@ def __init__(self, parent_trace, tag): super().__init__() self.tag = tag self.parent_trace = parent_trace + self.requires_low = False def to_primal_tangent_pair(self, val): if isinstance(val, JVPTracer) and val._trace.tag is self.tag: @@ -484,6 +504,11 @@ def process_primitive(self, primitive, tracers, params): else: return maybe_jvp_tracer(self, primal_out, tangent_out) + def cur_qdd(self, x): + p, _ = self.to_primal_tangent_pair(x) + with core.set_current_trace(self.parent_trace): + return core.cur_qdd(p) + def process_call(self, call_primitive, f, tracers, params): assert call_primitive.multiple_results primals, tangents = unzip2(map(self.to_primal_tangent_pair, tracers)) @@ -547,15 +572,20 @@ def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees, with core.set_current_trace(self.parent_trace): res_and_primals_out = fwd.call_wrapped(*fwd_in) - _, res_tree = out_trees() - res, primals_out = split_list(res_and_primals_out, [res_tree.num_leaves]) + _, res_tree, input_fwds = out_trees() + num_res_out = res_tree.num_leaves - sum(f is not None for f in input_fwds) + res_out, primals_out = split_list(res_and_primals_out, [num_res_out]) + res_out_ = iter(res_out) + res = [next(res_out_) if f is None else primals_in[f] for f in input_fwds] + assert next(res_out_, None) is None + avals_out = [core.get_aval(x).to_tangent_aval() for x in primals_out] - # TODO(frostig,mattjj): avoid instantiating zeros when we don't have to! 
+ in_zeros = [type(t) is Zero for t in tangents_in] + nz_tangents_in = [t for z, t in zip(in_zeros, tangents_in) if not z] with core.set_current_trace(self.parent_trace): - tangents_in = map(instantiate_zeros, tangents_in) tangents_out = custom_lin_p.bind( - *res, *tangents_in, num_res=res_tree.num_leaves, bwd=bwd, - out_avals=avals_out, symbolic_zeros=symbolic_zeros) + *res, *nz_tangents_in, num_res=res_tree.num_leaves, bwd=bwd, + out_avals=avals_out, symbolic_zeros=symbolic_zeros, in_zeros=in_zeros) return map(partial(maybe_jvp_tracer, self), primals_out, tangents_out) def process_custom_transpose(self, prim, call, tracers, **params): @@ -591,7 +621,9 @@ def process_custom_transpose(self, prim, call, tracers, **params): return map(partial(maybe_jvp_tracer, self), ps_out, ts_out) def maybe_jvp_tracer(trace, primal, tangent): - if type(tangent) is Zero or dtype(tangent) == float0: + if (type(tangent) is Zero or + isinstance(core.typeof(tangent), core.ShapedArray) + and dtype(tangent) == float0): return primal else: return JVPTracer(trace, primal, tangent) @@ -610,6 +642,9 @@ def __init__(self, trace, primal, tangent): def aval(self): return get_aval(self.primal) + def cur_qdd(self): + return core.cur_qdd(self.primal) + def full_lower(self): if type(self.tangent) is Zero: return core.full_lower(self.primal) @@ -622,10 +657,14 @@ def to_concrete_value(self): def get_referent(self): return core.get_referent(self.primal) + def type_state(self): + return self.primal.type_state() + def _primal_tangent_shapes_match(primal, tangent): if type(tangent) is not Zero: primal_aval = get_aval(primal).strip_weak_type() tangent_aval = get_aval(tangent).strip_weak_type() + if not isinstance(primal_aval, core.ShapedArray): return # TODO(mattjj,dougalm) assert core.definitely_equal_shape(primal_aval.shape, tangent_aval.shape), (primal_aval.shape, tangent_aval.shape) expected_tangent_dtype = core.primal_dtype_to_tangent_dtype(primal_aval.dtype) assert expected_tangent_dtype == tangent_aval.dtype, (expected_tangent_dtype, tangent_aval.dtype) @@ -703,7 +742,7 @@ def _f_jvp(primals, tangents): def process_custom_vjp_call(self, prim, fun, fwd, bwd: lu.WrappedFun, tracers, - out_trees: Callable[[], Sequence[PyTreeDef]], + out_trees: Callable[[], tuple[PyTreeDef, PyTreeDef, list[int | None]]], symbolic_zeros: bool): primals_in, tangents_in = unzip2(map(self.to_primal_tangent_pair, tracers)) if all(type(t) is Zero for t in tangents_in): @@ -715,15 +754,20 @@ def process_custom_vjp_call(self, prim, fun, fwd, with core.set_current_trace(self.parent_trace): res_and_primals_out = fwd.call_wrapped(*fwd_in_flat) - _, res_tree = out_trees() - res, primals_out = split_list(res_and_primals_out, [res_tree.num_leaves]) + _, res_tree, input_fwds = out_trees() + num_res_out = res_tree.num_leaves - sum(f is not None for f in input_fwds) + res_out, primals_out = split_list(res_and_primals_out, [num_res_out]) + res_out_ = iter(res_out) + res = [next(res_out_) if f is None else primals_in[f] for f in input_fwds] + assert next(res_out_, None) is None avals_out = [core.get_aval(x).to_tangent_aval() for x in primals_out] - tangents_in_zeros = map(instantiate_zeros, tangents_in) + in_zeros = [type(t) is Zero for t in tangents_in] + nz_tangents_in = [t for z, t in zip(in_zeros, tangents_in) if not z] with core.set_current_trace(self.tangent_trace): tangents_out = custom_lin_p.bind( - *res, *tangents_in_zeros, num_res=res_tree.num_leaves, bwd=bwd, - out_avals=avals_out, symbolic_zeros=symbolic_zeros) + *res, *nz_tangents_in, 
num_res=res_tree.num_leaves, bwd=bwd, + out_avals=avals_out, symbolic_zeros=symbolic_zeros, in_zeros=in_zeros) tangent_nzs_out = [type(t) is not Zero for t in tangents_out] return map(partial(maybe_linearize_tracer, self), primals_out, tangent_nzs_out, tangents_out) @@ -864,7 +908,11 @@ def make_zero(aval): out_nz_tracers = [trace.to_jaxpr_tracer(r) for (r, nz) in zip(out_tangents, out_nzs) if nz] in_tracers = [t for t, nz in zip(tangent_args, nonzeros) if nz] - jaxpr, out_consts, _ = pe.tracers_to_jaxpr(in_tracers, out_nz_tracers, jvp.debug_info) + jaxpr, out_consts, _ = pe.tracers_to_jaxpr(in_tracers, out_nz_tracers, [], jvp.debug_info) + jaxpr, used_consts, _ = pe.dce_jaxpr_consts( + jaxpr, [True] * len(jaxpr.outvars), + [False] * len(jaxpr.constvars) + [True] * len(jaxpr.invars)) + out_consts = [c for used, c in zip(used_consts, out_consts) if used] def linearized(residuals, *tangents): nz_tangents_in = [t for (t, nz) in zip(tangents, nonzeros) if nz] @@ -1100,7 +1148,7 @@ def out_axes_thunk(): try: out_flat = primitive.bind(fun, *all_args, **new_params) - except dispatch.InternalFloatingPointError as e: + except api_util.InternalFloatingPointError as e: print("Invalid nan value encountered in the backward pass of a jax.jit " "function. Calling the de-optimized backward pass.") try: @@ -1110,7 +1158,7 @@ def out_axes_thunk(): else: # If control reaches this line, we got a NaN on the output of `compiled` # but not `fun.call_wrapped` on the same arguments. Let's tell the user. - dispatch._raise_no_nan_in_deoptimized(e) + api_util._raise_no_nan_in_deoptimized(e) arg_cts = tree_unflatten(out_tree(), out_flat) # The freevars are being fanned out (not mapped). During transpose the @@ -1141,8 +1189,9 @@ def _jvp_jaxpr(jaxpr: core.ClosedJaxpr, debug_info=jaxpr.jaxpr.debug_info) f_jvp, out_nonzeros = f_jvp_traceable( jvp(f, instantiate=instantiate, transform_stack=False), nonzeros) - tangent_avals = [aval.to_tangent_aval() for aval, nz in zip(jaxpr.in_avals, nonzeros) if nz] - avals_in = list(it.chain(jaxpr.in_avals, tangent_avals)) + tangent_avals = [aval.to_tangent_aval() + for aval, nz in zip(jaxpr.in_aval_qdds, nonzeros) if nz] + avals_in = list(it.chain(jaxpr.in_aval_qdds, tangent_avals)) jaxpr_out, avals_out, literals_out, () = pe.trace_to_jaxpr_dynamic( f_jvp, avals_in) return core.ClosedJaxpr(jaxpr_out, literals_out), out_nonzeros() @@ -1164,18 +1213,18 @@ def rearrange_binders(jaxpr: core.ClosedJaxpr, primals_in, tangents_in, primals_ new_invars = _perm(primals_in, tangents_in, jaxpr.jaxpr.invars) new_outvars = _perm(primals_out, tangents_out, jaxpr.jaxpr.outvars) new_debug_info = jaxpr.jaxpr.debug_info - new_arg_names = tuple(_perm(primals_in, tangents_in, - jaxpr.jaxpr.debug_info.safe_arg_names(len(jaxpr.jaxpr.invars)))) - new_result_paths = tuple(_perm(primals_out, tangents_out, - jaxpr.jaxpr.debug_info.safe_result_paths(len(jaxpr.jaxpr.outvars)))) + arg_names = jaxpr.jaxpr.debug_info.safe_arg_names(len(jaxpr.in_avals)) + result_paths = jaxpr.jaxpr.debug_info.safe_result_paths(len(jaxpr.out_avals)) + new_arg_names = tuple(_perm(primals_in, tangents_in, arg_names)) + new_result_paths = tuple(_perm(primals_out, tangents_out, result_paths)) new_debug_info = new_debug_info._replace( - arg_names=new_arg_names, - result_paths=new_result_paths, - ) - new_jaxpr = core.Jaxpr(jaxpr.jaxpr.constvars, - new_invars, new_outvars, jaxpr.jaxpr.eqns, - jaxpr.jaxpr.effects, - new_debug_info) + arg_names=new_arg_names, result_paths=new_result_paths) + constvars = jaxpr.jaxpr.constvars + new_effects = 
pe._renumber_effects( + (*constvars, *new_invars), (*constvars, *jaxpr.jaxpr.invars), + jaxpr.jaxpr.effects) + new_jaxpr = core.Jaxpr(constvars, new_invars, new_outvars, jaxpr.jaxpr.eqns, + new_effects, new_debug_info) return core.ClosedJaxpr(new_jaxpr, jaxpr.consts) def _perm(primal_counts: Sequence[int], tangent_counts: Sequence[int], @@ -1202,7 +1251,7 @@ def raise_custom_vjp_error_on_jvp(*_, **__): def _custom_lin_transpose(cts_out, *invals, num_res, bwd: lu.WrappedFun, out_avals, - symbolic_zeros): + symbolic_zeros, in_zeros): res, _ = split_list(invals, [num_res]) if symbolic_zeros: cts_out = map(replace_internal_symbolic_zeros, cts_out) @@ -1210,9 +1259,18 @@ def _custom_lin_transpose(cts_out, *invals, num_res, cts_out = map(instantiate_zeros, cts_out) cts_in = bwd.call_wrapped(*res, *cts_out) cts_in = map(replace_rule_output_symbolic_zeros, cts_in) - return [None] * num_res + list(cts_in) + nz_cts_in, _ = partition_list(in_zeros, cts_in) + return [None] * num_res + nz_cts_in primitive_transposes[custom_lin_p] = _custom_lin_transpose +def _custom_lin_pp_rule(eqn: core.JaxprEqn, context: core.JaxprPpContext, + settings: core.JaxprPpSettings) -> core.pp.Doc: + params = dict(eqn.params) + params.pop("out_avals") + params["bwd"] = params.pop("bwd").debug_info.func_name + return core._pp_eqn(eqn.replace(params=params), context, settings) +core.pp_eqn_rules[custom_lin_p] = _custom_lin_pp_rule + class CustomJVPException(Exception): def __init__(self): diff --git a/jax/_src/interpreters/batching.py b/jax/_src/interpreters/batching.py index 03c9a95105d7..49c853e33040 100644 --- a/jax/_src/interpreters/batching.py +++ b/jax/_src/interpreters/batching.py @@ -21,7 +21,6 @@ import numpy as np -import jax from jax._src import config from jax._src import core from jax._src import source_info_util @@ -29,9 +28,7 @@ from jax._src.partition_spec import PartitionSpec as P from jax._src.sharding_impls import NamedSharding from jax._src import mesh as mesh_lib -from jax._src.ad_util import (Zero, instantiate, SymbolicZero, - replace_rule_output_symbolic_zeros, - add_jaxvals, add_jaxvals_p) +from jax._src.ad_util import Zero, SymbolicZero, add_jaxvals, add_jaxvals_p from jax._src.core import Trace, Tracer, TraceTag, AxisName from jax._src.interpreters import partial_eval as pe from jax._src.tree_util import (tree_unflatten, tree_flatten, @@ -114,7 +111,7 @@ def _jumble_unflatten(aval, x): register_pytree_node(Jumble, _jumble_flatten, _jumble_unflatten) def _jumble_result(axis_size, stacked_axis, ragged_axes, x): - binder = core.Var('', core.ShapedArray((), np.dtype('int32'))) + binder = core.Var(core.ShapedArray((), np.dtype('int32'))) if stacked_axis != 0: raise NotImplementedError # TODO Transpose x so the stacked axis is axis 0 shape = list(x.shape) @@ -178,7 +175,7 @@ def bdim_as_shape( bdim: int | RaggedAxis, data_shape: core.Shape) -> core.Shape: if isinstance(bdim, RaggedAxis): result = list(data_shape) - binder = core.Var('', core.ShapedArray((), np.dtype('int32'))) + binder = core.Var(core.ShapedArray((), np.dtype('int32'))) for ragged_axis, segment_lens in bdim.ragged_axes: result[ragged_axis] = IndexedAxisSize(binder, segment_lens) return tuple(result) @@ -303,11 +300,14 @@ def _cont(axis_size, elt, axis): from_elt_handlers: dict[type, FromEltHandler] = {} def make_iota(axis_size: AxisSize) -> Array: + # Callers of this utility, via batch() or vtile(), must be in a context + # where lax is importable. 
+ from jax import lax # pytype: disable=import-error handler = make_iota_handlers.get(type(axis_size)) if handler: return handler(axis_size) else: - return jax.lax.iota('int32', int(axis_size)) + return lax.iota('int32', int(axis_size)) make_iota_handlers: dict[type, MakeIotaHandler] = {} def register_vmappable(data_type: type, spec_type: type, axis_size_type: type, @@ -408,6 +408,10 @@ def __init__(self, trace, val, batch_dim: NotMapped | int | RaggedAxis, @property def aval(self): aval = core.get_aval(self.val) + if self._trace.axis_data.spmd_name is not None: + if config._check_vma.value: + aval = aval.update( + vma=aval.vma - frozenset(self._trace.axis_data.spmd_name)) if self.batch_dim is not_mapped: return aval elif type(self.batch_dim) is int: @@ -456,6 +460,8 @@ class AxisData: def get_sharding_for_vmap(axis_data, orig_sharding, axis): val = axis_data.explicit_mesh_axis + # TODO(yashkatariya): Preserve unreduced here using + # `orig_sharding.spec.update` new_spec = P(*tuple_insert(orig_sharding.spec, axis, val)) return NamedSharding(orig_sharding.mesh, new_spec) @@ -498,7 +504,7 @@ def process_primitive(self, p, tracers, params): with core.set_current_trace(self.parent_trace): val_out, dim_out = primitive_batchers[p](vals_in, dims_in, **params) else: - raise NotImplementedError("Batching rule for '{}' not implemented".format(p)) + raise NotImplementedError(f"Batching rule for '{p}' not implemented") src = source_info_util.current() if p.multiple_results: with core.set_current_trace(self.parent_trace): # val_out may be lazy map @@ -565,12 +571,9 @@ def process_custom_jvp_call(self, prim, fun, jvp, tracers, *, symbolic_zeros): in_vals, in_dims = unzip2(map(self.to_batch_info, tracers)) fun, out_dims1 = batch_subtrace(fun, self.tag, self.axis_data, in_dims) jvp, out_dims2 = batch_custom_jvp_subtrace(jvp, self.tag, self.axis_data, in_dims) - out_vals = prim.bind_with_trace(self.parent_trace, (fun, jvp) + tuple(in_vals), + out_vals = prim.bind_with_trace(self.parent_trace, (fun, jvp, *in_vals), dict(symbolic_zeros=symbolic_zeros)) fst, out_dims = lu.merge_linear_aux(out_dims1, out_dims2) - if not fst: - assert out_dims == out_dims[:len(out_dims) // 2] * 2 - out_dims = out_dims[:len(out_dims) // 2] src = source_info_util.current() return [BatchTracer(self, v, d, src) for v, d in zip(out_vals, out_dims)] @@ -582,14 +585,21 @@ def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, *, out_trees, fun, out_dims1 = batch_subtrace(fun, self.tag, self.axis_data, in_dims) fwd, out_dims2 = batch_subtrace(fwd, self.tag, self.axis_data, fwd_in_dims) - bwd = batch_custom_vjp_bwd(bwd, self.tag, self.axis_data, out_dims2, in_dims) + def bwd_in_dims(): + _, _, input_fwds = out_trees() + pruned_dims = iter(out_dims2()) + full_dims = [next(pruned_dims) if f is None else in_dims[f] for f in input_fwds] + return [*full_dims, *pruned_dims] + + bwd = batch_custom_vjp_bwd(bwd, self.tag, self.axis_data, bwd_in_dims, in_dims) out_vals = prim.bind_with_trace(self.parent_trace, (fun, fwd, bwd) + tuple(in_vals), dict(out_trees=out_trees, symbolic_zeros=symbolic_zeros)) fst, out_dims = lu.merge_linear_aux(out_dims1, out_dims2) if not fst: - _, res_tree = out_trees() - _, out_dims = split_list(out_dims, [res_tree.num_leaves]) + _, res_tree, input_fwds = out_trees() + num_res = res_tree.num_leaves - sum(f is not None for f in input_fwds) + _, out_dims = split_list(out_dims, [num_res]) src = source_info_util.current() return [BatchTracer(self, v, d, src) for v, d in zip(out_vals, out_dims)] @@ -616,17 +626,15 
@@ def _batch_inner(f: Callable, axis_data, out_dim_dests, tag, in_dims, *in_vals): trace = BatchTrace(parent_trace, tag, axis_data) idx = memoize(lambda: BatchTracer(trace, make_iota(axis_data.size), 0, source_info_util.current())) - in_tracers = map(partial(to_elt, trace, idx), in_vals, in_dims) + with core.set_current_trace(parent_trace): + in_tracers = map(partial(to_elt, trace, idx), in_vals, in_dims) with (core.set_current_trace(trace), core.extend_axis_env_nd([(axis_data.name, axis_data.size)]), core.add_spmd_axis_names(axis_data.spmd_name)): outs = f(*in_tracers) - - out_dim_dests = out_dim_dests() if callable(out_dim_dests) else out_dim_dests - out_vals = map(partial(from_elt, trace, axis_data.size, - axis_data.explicit_mesh_axis), - range(len(outs)), outs, out_dim_dests) - + out_dim_dests = out_dim_dests() if callable(out_dim_dests) else out_dim_dests + out_vals = map(partial(from_elt, trace, axis_data.size, axis_data.explicit_mesh_axis), + range(len(outs)), outs, out_dim_dests) return out_vals, trace # NOTE: This divides the in_axes by the tile_size and multiplies the out_axes by it. @@ -771,10 +779,17 @@ def _batch_jaxpr2( handle_ragged(closed_jaxpr.in_avals, dim, aval) if isinstance(dim, RaggedAxis) else (dim, aval) for dim, aval in zip(in_axes, closed_jaxpr.in_avals)]) - avals_in2 = [core.unmapped_aval(axis_data.size, b, aval, - axis_data.explicit_mesh_axis) - if b is not not_mapped else aval - for aval, b in unsafe_zip(avals_in, in_axes2)] + avals_in2 = [] + for aval, b in unsafe_zip(avals_in, in_axes2): + if b is not_mapped: + avals_in2.append(aval) + else: + aval = core.unmapped_aval( + axis_data.size, b, aval, axis_data.explicit_mesh_axis) + if axis_data.spmd_name is not None: + if config._check_vma.value: + aval = aval.update(vma=aval.vma | frozenset(axis_data.spmd_name)) # type: ignore + avals_in2.append(aval) jaxpr_out, _, consts, () = pe.trace_to_jaxpr_dynamic(f, avals_in2) return core.ClosedJaxpr(jaxpr_out, consts), out_axes() @@ -888,20 +903,16 @@ def batch_custom_jvp_subtrace(f, store, tag, axis_data, in_dims, *in_vals): if type(val) is SymbolicZero else BatchTracer(trace, val, dim) for val, dim in zip(in_vals, in_dims * 2)] with core.set_current_trace(trace): - outs = f(*in_tracers) - # TODO(mattjj,frostig): instantiating any SymbolicZero output is easy, but can - # be wasteful in the rare case it actually triggers; handle symbolically! 
- outs = [instantiate(replace_rule_output_symbolic_zeros(x)) for x in outs] - - out_vals, out_dims = unzip2(map(trace.to_batch_info, outs)) + out_tracers: list[BatchTracer | SymbolicZero] = f(*in_tracers) + out_vals, out_dims = unzip2(map(trace.to_batch_info, out_tracers)) out_primals, out_tangents = split_list(out_vals, [len(out_vals) // 2]) out_primal_bds, out_tangent_bds = split_list(out_dims, [len(out_vals) // 2]) out_dims = map(_merge_bdims, out_primal_bds, out_tangent_bds) out_primals = map(partial(matchaxis, trace.axis_data.name, size, mesh_axis), out_primal_bds, out_dims, out_primals) - out_tangents = map(partial(matchaxis, trace.axis_data.name, size, mesh_axis), + out_tangents = map(partial(_matchaxis_symzeros, trace.axis_data.name, size, mesh_axis), out_tangent_bds, out_dims, out_tangents) - store.store(out_dims * 2) + store.store(out_dims) return out_primals + out_tangents def batch_custom_vjp_bwd(bwd: lu.WrappedFun, tag: core.TraceTag, @@ -929,12 +940,11 @@ def _match_axes_and_sum(f, axis_size, axis_name, mesh_axis, out_dims_thunk, out_dim_dests, *in_vals): # this is like _match_axes, but we do reduce-sums as needed out_vals = f(*in_vals) - return map(partial(_matchaxis_symbolic_zeros, axis_name, axis_size, mesh_axis, - axis_name, sum_match=True), + return map(partial(_matchaxis_symzeros, axis_name, axis_size, mesh_axis, + sum_match=True), out_dims_thunk(), out_dim_dests, out_vals) -def _matchaxis_symbolic_zeros(axis_name, sz, mesh_axis, name, src, dst, x, - sum_match=False): +def _matchaxis_symzeros(axis_name, sz, mesh_axis, src, dst, x, sum_match=False): # Just like `matchaxis`, but handles symbolic zeros using ad_util.py # TODO(mattjj): dedup with matchaxis if isinstance(x, (Zero, SymbolicZero)): @@ -942,11 +952,11 @@ def _matchaxis_symbolic_zeros(axis_name, sz, mesh_axis, name, src, dst, x, return x elif type(src) == type(dst) == int: aval = core.mapped_aval(sz, src, x.aval) - return Zero(core.unmapped_aval(sz, dst, aval, mesh_axis)) + return type(x)(core.unmapped_aval(sz, dst, aval, mesh_axis)) elif src is not_mapped and dst is not not_mapped: - return Zero(core.unmapped_aval(sz, dst, x.aval, mesh_axis)) + return type(x)(core.unmapped_aval(sz, dst, x.aval, mesh_axis)) elif dst is not_mapped and sum_match: - return Zero(core.mapped_aval(sz, src, x.aval)) + return type(x)(core.mapped_aval(sz, src, x.aval)) else: raise ValueError((axis_name, x, src, dst)) else: @@ -1018,10 +1028,13 @@ def broadcast_batcher(prim, args, dims, **params): return (out, (0,) * len(out)) if prim.multiple_results else (out, 0) def _handle_scalar_broadcasting(nd, x, d): + # Callers of this utility, via broadcast_batcher() or defbroadcasting(), + # must be in a context where lax is importable. + from jax import lax # pytype: disable=import-error if d is not_mapped or nd == np.ndim(x): return x else: - return jax.lax.expand_dims(x, tuple(range(np.ndim(x), nd))) + return lax.expand_dims(x, tuple(range(np.ndim(x), nd))) def defreducer(prim, ident): primitive_batchers[prim] = partial(reducer_batcher, prim, ident) @@ -1077,17 +1090,20 @@ def mask_ragged_axes(operand: Array, ident, axis_spec: RaggedAxis) -> Array: def _mask_one_ragged_axis( operand: Array, ident, axis_spec: RaggedAxis) -> Array: + # Callers of this utility, via reducer_batcher() or defreducer(), + # must be in a context where lax is importable. 
+ from jax import lax # pytype: disable=import-error assert len(axis_spec.ragged_axes) == 1, "Mask just one ragged axis at a time" ragged_axis, segment_lengths = axis_spec.ragged_axes[0] value = ident(operand.dtype) - positions = jax.lax.broadcasted_iota('int32', operand.shape, ragged_axis) + positions = lax.broadcasted_iota('int32', operand.shape, ragged_axis) # TODO(mattjj, axch) can't get ._data, need to convert it - # lengths = jax.lax.convert_element_type(segment_lengths._data, 'int32') - lengths = jax.lax.convert_element_type(segment_lengths, 'int32') - limits = jax.lax.broadcast_in_dim( + # lengths = lax.convert_element_type(segment_lengths._data, 'int32') + lengths = lax.convert_element_type(segment_lengths, 'int32') + limits = lax.broadcast_in_dim( lengths, operand.shape, [axis_spec.stacked_axis]) mask = positions < limits - return jax.lax.select(mask, operand, jax.lax.broadcast(value, operand.shape)) + return lax.select(mask, operand, lax.broadcast(value, operand.shape)) def move_stacked_axis(operand, bdim, dst): dst = canonicalize_axis(dst, operand.ndim) @@ -1102,23 +1118,34 @@ def move_stacked_axis(operand, bdim, dst): ### general utilities for manipulating axes on jaxpr types (not vmappables) def broadcast(x, sz, axis, mesh_axis=None): + # Callers of this utility must be in a context where lax is importable. + from jax import lax # pytype: disable=import-error shape = list(np.shape(x)) shape.insert(axis, sz) broadcast_dims = tuple(np.delete(np.arange(len(shape)), axis)) x_aval = core.get_aval(x) + if x_aval.sharding.mesh.empty: + mesh_axis = None new_spec = P(*tuple_insert(x_aval.sharding.spec, axis, mesh_axis)) - sharding = x_aval.sharding.with_spec(new_spec) + sharding = x_aval.sharding.update(spec=new_spec) # TODO(dougalm, yashkatariya): Delete this context manager once we figure # out how to ensure jaxpr arguments always have the context mesh. with mesh_lib.use_abstract_mesh(sharding.mesh): - return jax.lax.broadcast_in_dim(x, shape, broadcast_dims, - out_sharding=sharding) + x = lax.broadcast_in_dim(x, shape, broadcast_dims, out_sharding=sharding) + if config._check_vma.value: + # TODO(yashkatariya,parkers): don't do this, fix during fixit week 2026 + spmd_names = core.get_axis_env().spmd_axis_names + if len(spmd_names) > 1: + raise NotImplementedError + if spmd_names: + x = core.pvary(x, tuple(spmd_names)) + return x def matchaxis(axis_name, sz, mesh_axis, src, dst, x, sum_match=False): if dst == jumble_axis: x = bdim_at_front(x, src, sz) elt_ty = x.aval.update(shape=x.shape[1:]) - aval = JumbleTy(core.Var('', core.ShapedArray((), np.dtype('int32'))), + aval = JumbleTy(core.Var(core.ShapedArray((), np.dtype('int32'))), x.shape[0], elt_ty) return Jumble(aval, x) try: @@ -1147,9 +1174,9 @@ def __init__(self, leaf_idx, src, dst): self.src = src self.dst = dst -def bdim_at_front(x, bdim, size): +def bdim_at_front(x, bdim, size, mesh_axis=None): if bdim is not_mapped: - return broadcast(x, size, 0) + return broadcast(x, size, 0, mesh_axis=mesh_axis) else: return moveaxis(x, bdim, 0) @@ -1169,3 +1196,17 @@ def add_batched(batched_args, batch_dims): x = moveaxis(x, bdx, bdy) return add_jaxvals(x, y), bdy primitive_batchers[add_jaxvals_p] = add_batched + +########################### core. 
################################## + +def _pvary_batcher(vals_in, dims_in, *, axes, axis_index_groups): + if any(type(axis) is int for axis in axes): + raise NotImplementedError + vals_out = core.pvary_p.bind(*vals_in, axes=axes, + axis_index_groups=axis_index_groups) + return vals_out, dims_in +primitive_batchers[core.pvary_p] = _pvary_batcher + +### mutable arrays + +defvectorized(core.mutable_array_p) diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index 1369f72ac74c..840ec336a330 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -16,11 +16,12 @@ from __future__ import annotations import collections -import contextlib from collections.abc import Callable, Iterable, Iterator, Sequence +import contextlib import dataclasses import functools from functools import partial +import heapq import io import itertools import operator @@ -31,14 +32,14 @@ from typing import Any, NamedTuple, Protocol, Union, cast as type_cast import warnings -import numpy as np - from jax._src import ad_util from jax._src import api_util from jax._src import config from jax._src import core from jax._src import dtypes from jax._src import effects as effects_lib +from jax._src import hashable_array +from jax._src import jaxpr_util from jax._src import linear_util as lu from jax._src import path from jax._src import sharding_impls @@ -48,18 +49,20 @@ from jax._src.interpreters import partial_eval as pe from jax._src.interpreters import xla from jax._src.layout import AutoLayout, DeviceLocalLayout -from jax._src.partition_spec import PartitionSpec -from jax._src.sharding import Sharding as JSharding -from jax._src.sharding_impls import (AUTO, NamedSharding, - modify_sdy_sharding_wrt_axis_types, - SdyArraySharding, SdyArrayShardingList) -from jax._src.util import foreach +from jax._src.lib import _jax from jax._src.lib import xla_client as xc -from jax._src.lib import xla_extension, xla_extension_version from jax._src.lib.mlir import dialects, ir, passmanager -from jax._src.lib.mlir.dialects import func as func_dialect, hlo from jax._src.lib.mlir import register_jax_dialects +from jax._src.lib.mlir.dialects import func as func_dialect, hlo +from jax._src.mesh import AxisType +from jax._src.partition_spec import PartitionSpec +from jax._src.sharding import Sharding as JSharding +from jax._src.sharding_impls import ( AUTO, NamedSharding, + SdyArray, SdyArrayList, + modify_sdy_sharding_wrt_axis_types) from jax._src.state.types import AbstractRef +from jax._src.util import foreach +import numpy as np # mypy: ignore-errors @@ -96,7 +99,6 @@ def _is_not_block_argument(x: IrValues) -> bool: - """Returns true if `x` is not a block argument.""" return not isinstance(x, ir.BlockArgument) @@ -185,24 +187,14 @@ def _is_ir_values(x: IrValues) -> bool: np.dtype(np.float64): ir.F64Type.get, np.dtype(np.complex64): lambda: ir.ComplexType.get(ir.F32Type.get()), np.dtype(np.complex128): lambda: ir.ComplexType.get(ir.F64Type.get()), + np.dtype(dtypes.int2): partial(ir.IntegerType.get_signless, 2), + np.dtype(dtypes.uint2): partial(ir.IntegerType.get_unsigned, 2), + np.dtype(dtypes.float8_e3m4): ir.Float8E3M4Type.get, + np.dtype(dtypes.float8_e4m3): ir.Float8E4M3Type.get, + np.dtype(dtypes.float8_e8m0fnu): ir.Float8E8M0FNUType.get, + np.dtype(dtypes.float4_e2m1fn): ir.Float4E2M1FNType.get, } - -if dtypes.int2 is not None: - assert dtypes.uint2 is not None - _dtype_to_ir_type[np.dtype(dtypes.int2)] = partial(ir.IntegerType.get_signless, 2) - _dtype_to_ir_type[np.dtype(dtypes.uint2)] 
= partial(ir.IntegerType.get_unsigned, 2) - -if dtypes.float8_e3m4 is not None: - _dtype_to_ir_type[np.dtype(dtypes.float8_e3m4)] = ir.Float8E3M4Type.get -if dtypes.float8_e4m3 is not None: - _dtype_to_ir_type[np.dtype(dtypes.float8_e4m3)] = ir.Float8E4M3Type.get -if dtypes.float8_e8m0fnu is not None: - _dtype_to_ir_type[np.dtype(dtypes.float8_e8m0fnu)] = ir.Float8E8M0FNUType.get - -if dtypes.float4_e2m1fn is not None: - _dtype_to_ir_type[np.dtype(dtypes.float4_e2m1fn)] = ir.Float4E2M1FNType.get - def dtype_to_ir_type(dtype: core.bint | np.dtype | np.generic) -> ir.Type: if isinstance(dtype, core.bint): # TODO Support different-size underlying dtypes to take advantage of the @@ -392,6 +384,8 @@ def _numpy_array_attribute_handler(val: np.ndarray | np.generic) -> ir.Attribute return _numpy_array_attribute(val) register_attribute_handler(np.ndarray, _numpy_array_attribute_handler) +register_attribute_handler(hashable_array.HashableArray, + lambda x: _numpy_array_attribute_handler(x.val)) for _scalar_type in [np.int8, np.int16, np.int32, np.int64, np.uint8, np.uint16, np.uint32, np.uint64, @@ -593,6 +587,11 @@ def module_to_bytecode(module: ir.Module) -> bytes: # Translation rules +# Create one global thread pool that can be shared between multiple ir.Contexts +# and enabling multi-threading +global_thread_pool = ir.ThreadPool() + + class JaxIrContext(ir.Context): def __init__(self, *args, **kwargs): # Note: we're very intentionally *not* calling the __init__() of our @@ -607,15 +606,8 @@ def make_ir_context() -> ir.Context: context.append_dialect_registry(upstream_dialects) context.load_all_available_dialects() - # If threading is enabled, each MLIR context will keep alive a thread pool. - # Since we cache MLIR modules (and hence contexts), this means we might keep - # several threads alive for each cache entry. This is a terrible idea. However - # we don't do any heavy computation on MLIR modules from Python anyway, so we - # just disable threading. - context.enable_multithreading(False) - # TODO(bartchr): Once JAX is released with SDY, remove the if. - if dialects.sdy: - dialects.sdy.register_dialect(context) + context.set_thread_pool(global_thread_pool) + dialects.sdy.register_dialect(context) dialects.mhlo.register_mhlo_dialect(context) dialects.chlo.register_dialect(context) dialects.hlo.register_dialect(context) @@ -662,7 +654,7 @@ def __init__(self, @dataclasses.dataclass(frozen=True) class LoweringParameters: # A mapping between primitives and user-defined LoweringRules. - # When lowering a primitive, give priorioty to the rule in this map over + # When lowering a primitive, give priority to the rule in this map over # existing Jax rules. override_lowering_rules: tuple[tuple[core.Primitive, LoweringRule]] | None = None @@ -676,7 +668,7 @@ class LoweringParameters: # Signals that we are lowering for exporting. for_export: bool = False - # See usage in https://jax.readthedocs.io/en/latest/export/export.html#ensuring-forward-and-backward-compatibility + # See usage in https://docs.jax.dev/en/latest/export/export.html#ensuring-forward-and-backward-compatibility # We have this here to ensure it is reflected in the cache keys export_ignore_forward_compatibility: bool = False @@ -717,7 +709,7 @@ class ModuleContext: # Cached primitive lowerings. cached_primitive_lowerings: dict[Any, func_dialect.FuncOp] - # Cached traceback infromation. + # Cached traceback information. 
traceback_caches: TracebackCaches lowering_parameters: LoweringParameters @@ -834,11 +826,18 @@ def is_forward_compat(self) -> bool: """Returns true if the lowering parameters are in forward compatibility mode. """ lowering_parameters = self.module_context.lowering_parameters - return ( - lowering_parameters.for_export - and not lowering_parameters.export_ignore_forward_compatibility + + check_platforms: Sequence[str] = ( + self.platforms or self.module_context.platforms + ) + force_forward_compat = any( + p in xb.FORCE_FORWARD_COMPAT_LOWERING_PLATFORMS for p in check_platforms ) + return ( + lowering_parameters.for_export or force_forward_compat + ) and not lowering_parameters.export_ignore_forward_compatibility + if not MYPY: class LoweringRule(Protocol): @@ -1010,24 +1009,35 @@ class LoweringResult(NamedTuple): def add_manual_axes(axis_ctx: sharding_impls.SPMDAxisContext, sharding, ndim): - mesh = axis_ctx.mesh + mesh = axis_ctx.mesh.abstract_mesh + sharding_mesh = sharding.mesh.abstract_mesh if (isinstance(sharding, sharding_impls.NamedSharding) and - sharding.mesh.shape == mesh.shape): - return sharding_impls.NamedSharding( - sharding.mesh, sharding.spec, memory_kind=sharding.memory_kind, - _manual_axes=axis_ctx.manual_axes) + sharding_mesh.shape == mesh.shape): + out_mesh, spec = sharding_mesh, sharding.spec else: - spec = sharding_impls.parse_flatten_op_sharding( + out_mesh, spec = mesh, sharding_impls.parse_flatten_op_sharding( sharding._to_xla_hlo_sharding(ndim), mesh)[0] - return sharding_impls.NamedSharding( - mesh, spec, memory_kind=sharding.memory_kind, - _manual_axes=axis_ctx.manual_axes) + + out_mesh = out_mesh.update_axis_types( + {a: AxisType.Manual for a in axis_ctx.manual_axes}) + out = sharding_impls.NamedSharding(out_mesh, spec, + memory_kind=sharding.memory_kind) + manual_axes = out.mesh.manual_axes + if any(p in manual_axes for s in out.spec + if s is not None and s is not PartitionSpec.UNCONSTRAINED + for p in (s if isinstance(s, tuple) else (s,))): + raise ValueError( + f'pspec {out.spec} contains a manual axes {manual_axes} of mesh' + f' which is not allowed. 
If you are using a' + ' with_sharding_constraint under a shard_map, only use the' + ' mesh axis in PartitionSpec which are not manual.') + return out def _to_physical_op_sharding( ctx: ModuleContext, aval: core.AbstractValue, sharding: JSharding | AUTO | None, -) -> xc.OpSharding | SdyArraySharding | None: +) -> xc.OpSharding | SdyArray | None: if sharding is None: return None if all_unconstrained(sharding, aval): @@ -1091,13 +1101,58 @@ class UnconstrainedVariants(NamedTuple): def _get_unconstrained_variants(s, aval) -> UnconstrainedVariants: us = contains_unconstrained(s) - unconstrained_dims = ({i for i, p in enumerate(s.spec) + unconstrained_dims = ({i for i, p in enumerate(s.spec) # pytype: disable=attribute-error if p is PartitionSpec.UNCONSTRAINED} if us else None) return UnconstrainedVariants( contains_unconstrained=us, all_unconstrained=all_unconstrained(s, aval), unconstrained_dims=unconstrained_dims) +def check_jaxpr_constants(closed_jaxpr: core.ClosedJaxpr): + """Check if a JAXPR contains an excessive amount of constants, if so, report where they were captured""" + if (threshold := config.captured_constants_warn_bytes.value) == -1: + return + + # need the unaesthetic getter here as some of the consts in the test suite are arbitrary objects + total_iter, nbytes_iter = itertools.tee( + map(lambda c: getattr(c, "nbytes", 0), closed_jaxpr.consts) + ) + + if (total_bytes := sum(total_iter)) < threshold: + return + + message = ( + "A large amount of constants were captured during lowering" + f" ({util.pprint_bytes(total_bytes)} total). If this is intentional," + " disable this warning by setting JAX_CAPTURED_CONSTANTS_WARN_BYTES=-1. " + ) + + if not (num_frames := config.captured_constants_report_frames.value): + message += ( + "To obtain a report of where these constants were encountered, " + "set JAX_CAPTURED_CONSTANTS_REPORT_FRAMES=-1." + ) + warnings.warn(message) + return + + message += ( + "The subsequent report may be disabled by setting JAX_CAPTURED_CONSTANTS_REPORT_FRAMES=0.\n\n" + f"Largest {min(num_frames, len(closed_jaxpr.consts))} allocation(s):\n" + ) + try: + nbytes_var_const = zip(nbytes_iter, closed_jaxpr.jaxpr.constvars, closed_jaxpr.consts) + for nbytes, var, const in heapq.nlargest(5, nbytes_var_const, key=operator.itemgetter(0)): + message += f" Constant {type(const)}, {var.aval.str_short()}, {util.pprint_bytes(nbytes)} captured at:\n" + + for eqn in jaxpr_util.eqns_using_var(closed_jaxpr.jaxpr, var): + call_frame_source_info = source_info_util.summarize(eqn.source_info, num_frames) + message += " " * 2 + call_frame_source_info.replace("\n", "\n" + " " * 2) + "\n\n" + + warnings.warn(message) + except Exception as exc: + warnings.warn(message + f" Exception raised while generating report: {exc}") + + def lower_jaxpr_to_module( module_name: str, jaxpr: core.ClosedJaxpr, @@ -1171,7 +1226,7 @@ def lower_jaxpr_to_module( donated_args[input_id] = False if any(donated_args): unused_donations = [str(a) for a, d in zip(in_avals, donated_args) if d] - msg = "See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer-donation." + msg = "See an explanation at https://docs.jax.dev/en/latest/faq.html#buffer-donation." 
if not platforms_with_donation: msg = f"Donation is not implemented for {platforms}.\n{msg}" if unused_donations: @@ -1435,6 +1490,8 @@ def lower_jaxpr_to_fun( MLIR func op """ util.test_event("lower_jaxpr_to_fun", name) + check_jaxpr_constants(jaxpr) + # The first dimension variable may be the platform index num_dim_vars = len(ctx.shape_poly_state.dim_vars) dim_var_avals = [core.ShapedArray((), dtypes.canonicalize_dtype(np.int64))] * num_dim_vars @@ -1785,10 +1842,10 @@ def replicate_trailing_dims(ctx, val: ir.Value, aval) -> ir.Value: assert isinstance(aval, (core.ShapedArray, core.DShapedArray)) if config.use_shardy_partitioner.value: physical_ndim = core.physical_aval(aval).ndim - s = SdyArraySharding( + s = SdyArray( mesh_shape=None, - dimension_shardings=[ - sharding_impls.SdyDimSharding(axes=[], is_closed=i >= aval.ndim) + dim_shardings=[ + sharding_impls.SdyDim(axes=[], is_open=i < aval.ndim) for i in range(physical_ndim) ]) return wrap_with_sharding_op(ctx, val, aval, s) @@ -2026,6 +2083,11 @@ def _platforms_for_eqn_ctx(eqn_ctx: core.JaxprEqnContext | None return ('tpu',) return () +def _platforms_for_eqn(ctx: LoweringRuleContext) -> tuple[str, ...]: + """The lowering platforms for the current eqn""" + return tuple(_platforms_for_eqn_ctx(ctx.jaxpr_eqn_ctx) or + ctx.platforms or ctx.module_context.platforms) + def lower_per_platform(ctx: LoweringRuleContext, description: str, @@ -2068,8 +2130,7 @@ def lower_per_platform(ctx: LoweringRuleContext, rule_args: the args of the lowering rules. rule_kwargs: the kwargs of the lowering rules. """ - platforms: Sequence[str] = (_platforms_for_eqn_ctx(ctx.jaxpr_eqn_ctx) or - ctx.platforms or ctx.module_context.platforms) + platforms: Sequence[str] = _platforms_for_eqn(ctx) # Special case the common case (single-platform lowering) if len(platforms) == 1: rule = platform_rules.get(platforms[0], default_rule) @@ -2235,10 +2296,17 @@ def _lower_jaxpr_to_fun_cached(ctx, fn_name, call_jaxpr, effects, name_stack, try: func_op = ctx.cached_primitive_lowerings[key] except KeyError: + num_callbacks = len(ctx.host_callbacks) func_op = lower_jaxpr_to_fun( ctx, fn_name, call_jaxpr, effects, name_stack, arg_names=arg_names, result_names=result_names) - ctx.cached_primitive_lowerings[key] = func_op + + # If this Jaxpr includes callbacks, we can't cache the lowering because + # on TPU every callback must have a globally unique channel, but the + # channel gets assigned during lowering. + has_callbacks = len(ctx.host_callbacks) > num_callbacks + if not has_callbacks or "tpu" not in ctx.platforms: + ctx.cached_primitive_lowerings[key] = func_op else: func_op = lower_jaxpr_to_fun( ctx, fn_name, call_jaxpr, effects, name_stack, arg_names=arg_names, @@ -2326,21 +2394,24 @@ def core_call_lowering(ctx: LoweringRuleContext, register_lowering(core.closed_call_p, partial(core_call_lowering, name=None)) -def map_compute_type(c_type): - if c_type == 'device_host': - return 'host' - elif c_type == 'device': - return 'dense' - elif c_type == 'tpu_sparsecore': - return 'sparse' - raise ValueError(f'Invalid compute type {c_type}. Current supported values ' - 'are `device_host`, `device` and `tpu_sparsecore') - -def wrap_compute_type_in_place(ctx, op): +def map_compute_type(c_type: str) -> str: + if c_type == "device_host": + return "host" + elif c_type == "device": + return "dense" + elif c_type == "tpu_sparsecore": + return "sparse" + raise ValueError(f"Invalid compute type {c_type}. 
Current supported values " + "are `device_host`, `device` and `tpu_sparsecore`") + +def wrap_compute_type_in_place(ctx: LoweringRuleContext, op: ir.Operation) -> None: if ctx.jaxpr_eqn_ctx is not None and ctx.jaxpr_eqn_ctx.compute_type is not None: if ctx.jaxpr_eqn_ctx.compute_type.startswith("gpu_stream:"): stream = ctx.jaxpr_eqn_ctx.compute_type.split(":")[1] - dict_attr = {"_xla_stream_annotation": ir.StringAttr.get(stream)} + dict_attr = { + "_xla_stream_annotation": ir.StringAttr.get(stream), + "inlineable": ir.StringAttr.get("false"), + } op.operation.attributes["mhlo.frontend_attributes"] = ir.DictAttr.get(dict_attr) else: dict_attr = {"_xla_compute_type": ir.StringAttr.get( @@ -2348,7 +2419,7 @@ def wrap_compute_type_in_place(ctx, op): op.operation.attributes["mhlo.frontend_attributes"] = ir.DictAttr.get(dict_attr) -def wrap_xla_metadata_in_place(ctx, op): +def wrap_xla_metadata_in_place(ctx: LoweringRuleContext, op: ir.Operation) -> None: ctx_attributes = {} existing_attributes = {} if ctx.jaxpr_eqn_ctx is not None and ctx.jaxpr_eqn_ctx.xla_metadata: @@ -2597,7 +2668,7 @@ def _wrap_with_spmd_op(name: str, ctx: LoweringRuleContext, x: ir.Value, aval_out: core.AbstractValue, - sharding: xc.OpSharding | SdyArraySharding, + sharding: xc.OpSharding | SdyArray, unspecified_dims: set[int] | None = None, has_side_effect: bool = False, allow_shardy_lowering: bool = False): @@ -2662,7 +2733,7 @@ def lower_with_sharding_in_types(ctx, op, aval, sharding_proto=None): return wrap_with_sharding_op(ctx, op, aval, proto, unspecified_dims) -def set_sharding(op, sharding: xc.OpSharding | SdyArraySharding | SdyArrayShardingList): +def set_sharding(op, sharding: xc.OpSharding | SdyArray | SdyArrayList): if config.use_shardy_partitioner.value: op.attributes["sdy.sharding"] = get_sharding_attr(sharding) else: @@ -2670,7 +2741,7 @@ def set_sharding(op, sharding: xc.OpSharding | SdyArraySharding | SdyArrayShardi def get_sharding_attr( - sharding: xc.OpSharding | SdyArraySharding | SdyArrayShardingList + sharding: xc.OpSharding | SdyArray | SdyArrayList ) -> ir.Attribute: if config.use_shardy_partitioner.value: return sharding.build() # type: ignore @@ -2746,11 +2817,6 @@ def cached_lowering(ctx, *args, **params): return cached_lowering -def xla_computation_to_mlir_module(xla_computation: xc.XlaComputation - ) -> ir.Module: - module_str = xc._xla.mlir.xla_computation_to_mlir_module(xla_computation) - return ir.Module.parse(module_str) - def merge_mlir_modules(dst_module: ir.Module, sym_name: str, src_module: ir.Module, @@ -3027,15 +3093,12 @@ def refine_polymorphic_shapes(module: ir.Module) -> ir.Module: Then verifies that there are no more dynamic shapes in the module. """ try: - refine_polymorphic_shapes = partial(xla_extension.mlir.refine_polymorphic_shapes, + refine_polymorphic_shapes = partial(_jax.mlir.refine_polymorphic_shapes, mlir_module=module_to_bytecode(module), enable_shape_assertions=True, validate_static_shapes=True) - if xla_extension_version >= 319: - refined_module_str = refine_polymorphic_shapes( - enable_shardy=config.use_shardy_partitioner.value) - else: - refined_module_str = refine_polymorphic_shapes() + refined_module_str = refine_polymorphic_shapes( + enable_shardy=config.use_shardy_partitioner.value) except Exception as e: raise ValueError( "Error refining shapes. 
" + @@ -3044,3 +3107,7 @@ def refine_polymorphic_shapes(module: ir.Module) -> ir.Module: context = make_ir_context() with context: return ir.Module.parse(refined_module_str) + +########################### pvary ################################## + +register_lowering(core.pvary_p, lambda ctx, *x, axes, axis_index_groups: x) diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py index 07c516fd95c7..b7e02ee0fd18 100644 --- a/jax/_src/interpreters/partial_eval.py +++ b/jax/_src/interpreters/partial_eval.py @@ -11,19 +11,20 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +# pytype: skip-file from __future__ import annotations from collections import namedtuple from collections.abc import Callable, Sequence, Hashable -from contextlib import contextmanager +import contextlib +from dataclasses import dataclass from functools import partial import itertools as it import operator as op from typing import Any, NamedTuple, Union from weakref import ref -import numpy as np - from jax._src import ad_util from jax._src import api_util from jax._src import config @@ -41,13 +42,14 @@ JaxprEqn, Primitive, ShapedArray, DShapedArray, mapped_aval, unmapped_aval, DBIdx, InDBIdx, OutDBIdx, InputType, OutputType, get_referent, JaxprEqnContext) +from jax._src.source_info_util import SourceInfo from jax._src.state.types import AbstractRef, ReadEffect -from jax._src.tree_util import (PyTreeDef, treedef_tuple, - tree_flatten, tree_structure) +from jax._src.tree_util import (PyTreeDef, treedef_tuple, tree_flatten, + tree_structure, register_static) from jax._src.util import (unzip2, safe_zip, safe_map, toposort, split_list, merge_lists, partition_list, OrderedSet, as_hashable_function, weakref_lru_cache, subs_list, - HashableFunction, foreach) + HashableFunction, foreach, cache) map, unsafe_map = safe_map, map @@ -58,6 +60,15 @@ def identity(x): return x AvalId = int ConstId = int +AttrKind = Any +PyTree = Any + +# Attrs flavors, see jax/experimental/attrs.py +ReadWrite = type('ReadWrite', (), {})() +Append = type('Append', (), {})() +BoxAttr = type('BoxAttr', (), {})() +ListAttr = type('ListAttr', (), {})() + def _update_annotation_known( f: lu.WrappedFun, orig_type: InputType | None, @@ -137,6 +148,10 @@ def get_aval(self) -> AbstractValue: else: return self[0] +@dataclass(frozen=True) +class EffectHandle: + parents : list[Tracer] + recipe : JaxprEqnRecipe class JaxprTrace(Trace['JaxprTracer']): @@ -145,6 +160,9 @@ def __init__(self, parent_trace:Trace, name_stack: source_info_util.NameStack, t self.name_stack = name_stack self.tag = tag self.parent_trace = parent_trace + self.requires_low = False + self.effect_handles : list[EffectHandle] = [] + self.counter = it.count() def to_jaxpr_tracer(self, x): if isinstance(x, JaxprTracer) and x._trace.tag is self.tag: @@ -191,18 +209,18 @@ def instantiate_const(self, tracer: JaxprTracer) -> JaxprTracer: if const is None: return tracer else: - if type(const) in core.literalable_types and np.shape(const) == (): + if core.is_literalable(const): return self.new_instantiated_literal(const) else: return self.new_instantiated_const(const) - def instantiate_const_abstracted(self, tracer) -> JaxprTracer: - const = tracer.pval.get_known() + def cur_qdd(self, x): + const = self.to_jaxpr_tracer(x).pval.get_known() if const is None: - return tracer + assert False # TODO: track tangent QDDs else: - aval = 
get_aval(const).update_weak_type(np.isscalar(const)) - return JaxprTracer(self, PartialVal.unknown(aval), ConstVar(const)) + with core.set_current_trace(self.parent_trace): + return core.cur_qdd(const) def process_primitive(self, primitive, tracers, params): with core.set_current_trace(self.parent_trace): @@ -222,20 +240,25 @@ def default_process_primitive(self, primitive, tracers, params): return primitive.bind_with_trace(self.parent_trace, consts, params) tracers = map(self.instantiate_const, tracers) avals = [t.aval for t in tracers] - out_aval, effects = primitive.abstract_eval(*avals, **params) + out_aval, effs = primitive.abstract_eval(*avals, **params) name_stack = self._current_truncated_name_stack() source = source_info_util.current().replace(name_stack=name_stack) if primitive.multiple_results: out_tracers = [JaxprTracer(self, PartialVal.unknown(aval), None) for aval in out_aval] - eqn = new_eqn_recipe(tracers, out_tracers, primitive, params, effects, + eqn = new_eqn_recipe(self, tracers, out_tracers, primitive, params, effs, source) + if effects.partial_eval_kept_effects.filter_in(effs): + self.effect_handles.append(EffectHandle(tracers, eqn)) for t in out_tracers: t.recipe = eqn return out_tracers else: out_tracer = JaxprTracer(self, PartialVal.unknown(out_aval), None) - out_tracer.recipe = new_eqn_recipe(tracers, [out_tracer], primitive, - params, effects, source) + eqn = new_eqn_recipe(self, tracers, [out_tracer], primitive, + params, effs, source) + if effects.partial_eval_kept_effects.filter_in(effs): + self.effect_handles.append(EffectHandle(tracers, eqn)) + out_tracer.recipe = eqn return out_tracer def process_call(self, primitive, f: lu.WrappedFun, tracers, params): @@ -310,7 +333,7 @@ def process_call(self, primitive, f: lu.WrappedFun, tracers, params): for a in out_type] name_stack = self._current_truncated_name_stack() source = source_info_util.current().replace(name_stack=name_stack) - eqn = new_eqn_recipe((*res_tracers, *env_tracers, *unknown_arg_tracers), + eqn = new_eqn_recipe(self, (*res_tracers, *env_tracers, *unknown_arg_tracers), out_tracers, primitive, staged_params, jaxpr.effects, source) for t in out_tracers: t.recipe = eqn @@ -379,7 +402,7 @@ def const_out_axes_thunk(): for a in out_avals] effs = core.filter_named_axis_effects(jaxpr.effects, {params['axis_name']}) src_info = source_info_util.current() - eqn = new_eqn_recipe((*const_tracers, *env_tracers, *unknown_arg_tracers), + eqn = new_eqn_recipe(self, (*const_tracers, *env_tracers, *unknown_arg_tracers), out_tracers, primitive, staged_params, effs, src_info) for t in out_tracers: t.recipe = eqn @@ -395,8 +418,7 @@ def process_custom_jvp_call(self, prim, fun, jvp, tracers, symbolic_zeros): vals = [t.pval[1] for t in tracers] return prim.bind(fun, jvp, *vals, symbolic_zeros=symbolic_zeros) # We assume non-trivial partial evaluation is only performed to build linear - # functions, and hence we don't need to keep the custom JVP rule around - # anymore. + # functions, and hence we don't need to keep the custom JVP rule around. 
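# Illustrative sketch (not from the patch): the assumption stated above, that
# non-trivial partial evaluation only builds linear functions, is what
# jax.linearize relies on. Primal (known) values are computed eagerly, while
# the tangent (unknown) computation is staged into a jaxpr for later replay.
import jax
import jax.numpy as jnp

def h(x):
  return jnp.sin(x) * x

y, h_lin = jax.linearize(h, 2.0)  # y is computed now, from the known half
t = h_lin(1.0)                    # staged linear half: cos(2.0)*2.0 + sin(2.0)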
del jvp, symbolic_zeros with core.set_current_trace(self): return fun.call_wrapped(*tracers) @@ -415,7 +437,7 @@ def process_custom_transpose(self, prim, call, tracers, **params): for aval in params['out_types']] in_tracers = map(self.instantiate_const, tracers) new_params = dict(params, call=call) - eqn = new_eqn_recipe(in_tracers, out_tracers, prim, new_params, + eqn = new_eqn_recipe(self, in_tracers, out_tracers, prim, new_params, core.no_effects, source_info_util.current()) for t in out_tracers: t.recipe = eqn return out_tracers @@ -425,49 +447,45 @@ def process_custom_vjp_call(self, prim, f, fwd, bwd, tracers, out_trees, symboli if all(t.is_known() for t in tracers): vals = [t.pval[1] for t in tracers] with core.set_current_trace(self.parent_trace): - return prim.bind(f, fwd, bwd, *vals, out_trees=out_trees, symbolic_zeros=symbolic_zeros) - else: - # TODO(mattjj): remove non-ad users of partial eval, then drop this case. - # We stage out the whole thing, i.e. no nontrivial partial evaluation. - tracers = map(self.instantiate_const_abstracted, tracers) - # Because we instantiate all tracers, in_knowns is all False. - in_knowns, in_avals, () = partition_pvals([t.pval for t in tracers]) - f = trace_to_subjaxpr_nounits(f, self, True, f.debug_info) - f, aux = partial_eval_wrapper_nounits(f, (*in_knowns,), (*in_avals,)) - with core.set_current_trace(self.parent_trace): - out_flat = prim.bind(f, fwd, bwd, out_trees=out_trees, - symbolic_zeros=symbolic_zeros) - out_knowns, out_avals, jaxpr, env = aux() - out_consts, res = split_list(out_flat, [len(out_flat)-len(jaxpr.constvars)]) - res_tracers = map(self.new_instantiated_const, res) - env_tracers = map(self.to_jaxpr_tracer, env) - out_tracers = [JaxprTracer(self, PartialVal.unknown(a), None) - for a in out_avals] - closed_jaxpr = core.ClosedJaxpr(convert_constvars_jaxpr(jaxpr), ()) - - @_memoize - def fwd_jaxpr_thunk(*zeros): - fwd_ = _interleave_fun(fwd, zeros) - fwd_ = trace_to_subjaxpr_nounits(fwd_, self, True, fwd_.debug_info) - fwd_, aux = partial_eval_wrapper_nounits(fwd_, (*in_knowns,), (*in_avals,)) - out_flat = fwd_.call_wrapped() - out_knowns, out_avals, jaxpr, env = aux() - _, res = split_list(out_flat, [len(out_flat)-len(jaxpr.constvars)]) - converted_jaxpr = convert_envvars_to_constvars(jaxpr, len(env)) - return converted_jaxpr, (*res, *env) + return prim.bind(f, fwd, bwd, *vals, out_trees=out_trees, + symbolic_zeros=symbolic_zeros) + + tracers = map(self.instantiate_const, tracers) + in_knowns = (False,) * len(tracers) + in_avals = tuple(t.aval for t in tracers) + f_ = trace_to_subjaxpr_nounits2(f, self.tag, f.debug_info, True) + f_, aux = partial_eval_wrapper_nounits(f_, in_knowns, in_avals) + params = dict(out_trees=out_trees, symbolic_zeros=symbolic_zeros) + res = prim.bind_with_trace(self.parent_trace, (f_, fwd, bwd), params) + out_knowns, out_avals, jaxpr, env = aux() + assert not any(out_knowns) + res_tracers = map(self.instantiate_const, map(self.new_const, res)) + env_tracers = map(self.to_jaxpr_tracer, env) + out_tracers = [JaxprTracer(self, PartialVal.unknown(a), None) + for a in out_avals] + closed_jaxpr = close_jaxpr(convert_constvars_jaxpr(jaxpr)) + + @partial(lu.wrap_init, debug_info=fwd.debug_info) + @_memoize + def fwd_jaxpr_thunk(*zeros): + fwd_ = _interleave_fun(fwd, zeros) + fwd_jaxpr, _, consts, () = trace_to_jaxpr_dynamic(fwd_, in_avals) + return fwd_jaxpr, consts name_stack = self._current_truncated_name_stack() source = source_info_util.current().replace(name_stack=name_stack) - eqn = 
new_eqn_recipe((*res_tracers, *env_tracers, *tracers), - out_tracers, prim.initial_style, - dict(fun_jaxpr=closed_jaxpr, - fwd_jaxpr_thunk=fwd_jaxpr_thunk, - num_consts=len(res) + len(env), - bwd=bwd, out_trees=out_trees, - symbolic_zeros=symbolic_zeros), - jaxpr.effects, source) + params = dict( + call_jaxpr=closed_jaxpr, + fwd_jaxpr_thunk=fwd_jaxpr_thunk, + num_consts=len(res) + len(env), + bwd=bwd, + out_trees=out_trees, + symbolic_zeros=symbolic_zeros + ) + eqn = new_eqn_recipe(self, (*res_tracers, *env_tracers, *tracers), + out_tracers, prim, params, jaxpr.effects, source) for t in out_tracers: t.recipe = eqn - return merge_lists(out_knowns, out_tracers, out_consts) + return out_tracers def partition_pvals( pvals: list[PartialVal] @@ -494,6 +512,24 @@ def partial_eval_wrapper_nounits( store.store((*maybe_fwds, out_knowns, out_avals, jaxpr, env)) return (*out_consts, *res) +@lu.transformation_with_aux2 +def partial_eval_wrapper_nounits2( + f: Callable, + store: lu.Store, + in_knowns: Sequence[bool], + in_avals: Sequence[AbstractValue], + *in_consts: Any): + in_avals_, in_consts_ = iter(in_avals), iter(in_consts) + in_pvals = [PartialVal.known(next(in_consts_)) if known else + PartialVal.unknown(next(in_avals_)) for known in in_knowns] + sentinel = object() + assert next(in_avals_, sentinel) is next(in_consts_, sentinel) is sentinel + jaxpr, (*maybe_fwds, out_pvals, res, env) = f(in_pvals) + out_knowns, _, out_consts = partition_pvals(out_pvals) + res_avals = [core.typeof(r) for r in res] + store.store((*maybe_fwds, out_knowns, res_avals, jaxpr, env)) + return (*out_consts, *res) + custom_partial_eval_rules: dict[Primitive, Callable] = {} call_partial_eval_rules: dict[Primitive, Callable] = {} call_param_updaters: dict[Primitive, Callable] = {} @@ -633,7 +669,7 @@ def _trace_to_subjaxpr_nounits(f: Callable, trace: JaxprTrace, out_tracers = [trace.instantiate_const(t) if inst else t for inst, t in zip(instantiate, out_tracers)] out_tracers_ = [t for t in out_tracers if not t.is_known()] - jaxpr, out_consts, env = tracers_to_jaxpr(in_tracers, out_tracers_, debug_info) + jaxpr, out_consts, env = tracers_to_jaxpr(in_tracers, out_tracers_, trace.effect_handles, debug_info) return out_tracers, jaxpr, out_consts, env # The below variant implements an optimization where residuals which are also @@ -715,7 +751,8 @@ class JaxprEqnRecipe(NamedTuple): source_info: source_info_util.SourceInfo ctx: JaxprEqnContext -def new_eqn_recipe(in_tracers: Sequence[JaxprTracer], +def new_eqn_recipe(trace: JaxprTrace, + in_tracers: Sequence[JaxprTracer], out_tracers: Sequence[JaxprTracer], primitive: Primitive, params: dict[str, Any], @@ -738,7 +775,7 @@ def new_eqn_recipe(in_tracers: Sequence[JaxprTracer], config.threefry_partitionable.value, xla_metadata_lib.current_xla_metadata(), ) - return JaxprEqnRecipe(object(), tuple(in_tracers), map(ref, out_tracers), + return JaxprEqnRecipe(next(trace.counter), tuple(in_tracers), map(ref, out_tracers), out_avals, primitive, params, effects, source_info, ctx) @@ -756,6 +793,7 @@ def recipe_to_eqn(getvar: Callable[[JaxprTracer], Atom], def tracers_to_jaxpr( in_tracers: Sequence[JaxprTracer], out_tracers: Sequence[JaxprTracer], + effect_handles: Sequence[Any], debug_info: core.DebugInfo, ) -> tuple[Jaxpr, tuple[Any, ...], tuple[Any, ...]]: """Constructs Jaxpr given tracers for inputs and outputs. 
@@ -797,7 +835,15 @@ def type_substitute(aval: AbstractValue) -> AbstractValue: processed_eqn_ids = set() eqns: list[core.JaxprEqn] = [] - for t in toposort((*in_tracers, *out_tracers)): + + reachable = toposort + tracers = reachable((*in_tracers, *out_tracers, *effect_handles)) + def sort_key(t): + r = t.recipe + return r.eqn_id if isinstance(r, JaxprEqnRecipe) else -1 + tracers = sorted(tracers, key=sort_key) + + for t in tracers: r = t.recipe if isinstance(r, JaxprEqnRecipe): # TODO broadcast_in_dim can create a new tracer, not present in parents @@ -850,10 +896,8 @@ def convert_constvars_jaxpr(jaxpr: Jaxpr) -> Jaxpr: config.enable_checks.value and core.check_jaxpr(jaxpr) dbg = jaxpr.debug_info._replace( arg_names=("",) * len(jaxpr.constvars) + jaxpr.debug_info.arg_names) - lifted_jaxpr = Jaxpr(constvars=(), - invars=jaxpr.constvars + jaxpr.invars, - outvars=jaxpr.outvars, eqns=jaxpr.eqns, - effects=jaxpr.effects, debug_info=dbg) + lifted_jaxpr = jaxpr.replace( + constvars=(), invars=jaxpr.constvars + jaxpr.invars, debug_info=dbg) config.enable_checks.value and core.check_jaxpr(lifted_jaxpr) return lifted_jaxpr @@ -876,9 +920,8 @@ def convert_envvars_to_constvars(jaxpr: Jaxpr, num_env_vars: int) -> Jaxpr: raise NotImplementedError config.enable_checks.value and core.check_jaxpr(jaxpr) env_vars, invars = split_list(jaxpr.invars, [num_env_vars]) - converted_jaxpr = Jaxpr(constvars=jaxpr.constvars + env_vars, - invars=invars, outvars=jaxpr.outvars, eqns=jaxpr.eqns, - effects=jaxpr.effects, debug_info=jaxpr.debug_info) + converted_jaxpr = jaxpr.replace(constvars=jaxpr.constvars + env_vars, + invars=invars) config.enable_checks.value and core.check_jaxpr(converted_jaxpr) return converted_jaxpr @@ -944,75 +987,69 @@ def partial_eval_jaxpr_nounits( passed to jaxpr_unknown (as leading inputs). 
""" instantiate = tuple(instantiate) if isinstance(instantiate, list) else instantiate - return _partial_eval_jaxpr_nounits(jaxpr, tuple(unknowns), instantiate) + return _partial_eval_jaxpr_nounits(jaxpr, tuple(unknowns), instantiate, False)[:-1] + +def partial_eval_jaxpr_nounits_fwd( + jaxpr: ClosedJaxpr, unknowns: Sequence[bool], + instantiate: bool | Sequence[bool], +) -> tuple[ClosedJaxpr, ClosedJaxpr, list[bool], list[AbstractValue], list[int | None]]: + instantiate = tuple(instantiate) if isinstance(instantiate, list) else instantiate + return _partial_eval_jaxpr_nounits(jaxpr, tuple(unknowns), instantiate, True) @weakref_lru_cache -def _partial_eval_jaxpr_nounits(jaxpr: ClosedJaxpr, - in_unknowns: Sequence[bool], - instantiate: bool | Sequence[bool]): - f = lu.wrap_init(core.jaxpr_as_fun(jaxpr), - debug_info=jaxpr.jaxpr.debug_info) +def _partial_eval_jaxpr_nounits( + jaxpr: ClosedJaxpr, in_unknowns: Sequence[bool], + instantiate: bool | Sequence[bool], fwd: bool): + f = lu.wrap_init(core.jaxpr_as_fun(jaxpr), debug_info=jaxpr.jaxpr.debug_info) cell = [] def fun(*known_vals_in): - known_vals_in = iter(known_vals_in) + known_vals_in_ = iter(known_vals_in) unknown_avals = (a for a, uk in zip(jaxpr.in_avals, in_unknowns) if uk) in_pvals = [PartialVal.unknown(next(unknown_avals)) if uk - else PartialVal.known(next(known_vals_in)) for uk in in_unknowns] - assert next(known_vals_in, None) is next(unknown_avals, None) is None - jaxpr_unknown_, out_pvals, residuals = trace_to_jaxpr_nounits( - f, in_pvals, instantiate=instantiate) + else PartialVal.known(next(known_vals_in_)) for uk in in_unknowns] + assert next(known_vals_in_, None) is next(unknown_avals, None) is None + jaxpr_unknown_, (fwds, out_pvals, residuals, ()) = trace_to_subjaxpr_nounits_fwd( + f, TraceTag(), jaxpr.jaxpr.debug_info, instantiate).call_wrapped(in_pvals) jaxpr_unknown = convert_constvars_jaxpr(jaxpr_unknown_) out_unknowns = [not pval.is_known() for pval in out_pvals] + if not fwd: + residuals_ = iter(residuals) + residuals = [next(residuals_) if f is None else known_vals_in[f] + for f in fwds] + assert next(residuals_, None) is None + fwds = [None] * len(fwds) + else: + fwds, residuals = _include_consts_in_fwds(jaxpr.consts, fwds, residuals) res_avals = [core.get_aval(r) for r in residuals] - cell.append((out_unknowns, jaxpr_unknown, res_avals)) + cell.append((out_unknowns, jaxpr_unknown, res_avals, fwds)) known_vals_out = [pval.get_known() for pval in out_pvals if pval.is_known()] return [*known_vals_out, *residuals] - known_avals = [a for a, uk in zip(jaxpr.in_avals, in_unknowns) if not uk] + known_avals = [a for a, uk in zip(jaxpr.in_aval_qdds, in_unknowns) if not uk] jaxpr_known, _, consts_known, () = trace_to_jaxpr_dynamic( - lu.wrap_init(fun, debug_info=f.debug_info), - known_avals) - (out_unknowns, jaxpr_unknown, res_avals), = cell # pytype: disable=bad-unpacking + lu.wrap_init(fun, debug_info=f.debug_info), known_avals) + (out_unknowns, jaxpr_unknown, res_avals, fwds), = cell # pytype: disable=bad-unpacking - # check jaxpr_known and jaxpr_unknown in isolation - # TODO(mattjj): enable weak type checking here if config.enable_checks.value: core.check_jaxpr(jaxpr_known) core.check_jaxpr(jaxpr_unknown) - def check(first, second): - for f, s in zip(first, second): - if (not isinstance(f, core.ShapedArray) and - not isinstance(s, core.ShapedArray)): - assert f == s - elif f.sharding.mesh.empty or s.sharding.mesh.empty: - assert (f.shape, f.dtype) == (s.shape, s.dtype) - else: - assert f == s, (f, s) - - # check 
jaxpr_known has input type corresponding to known inputs of jaxpr - assert ([v.aval for v in jaxpr_known.invars] == - [a for a, uk in zip(jaxpr.in_avals, in_unknowns) if not uk]) - # check jaxpr_known has out type corresponding to known outs of jaxpr plus res - # Change this to `assert ... == ...` and remove the check function. - # See https://github.com/jax-ml/jax/issues/26474 - check([v.aval.strip_weak_type() for v in jaxpr_known.outvars], - [a.strip_weak_type() for a, uk in zip(jaxpr.out_avals, out_unknowns) - if not uk] + [a.strip_weak_type() for a in res_avals]) - # check jaxpr_unknown has input type corresponding to res plus unknown inputs - assert ([v.aval.strip_weak_type() for v in jaxpr_unknown.invars] == - [a.strip_weak_type() for a in res_avals] + - [a.strip_weak_type() for a, uk in zip(jaxpr.in_avals, in_unknowns) - if uk]) - # check jaxpr_unknown has output type corresponding to unknown outputs - check([v.aval.strip_weak_type() for v in jaxpr_unknown.outvars], - [a.strip_weak_type() for a, uk in zip(jaxpr.out_avals, out_unknowns) - if uk]) - closed_jaxpr_known = ClosedJaxpr(jaxpr_known, consts_known) closed_jaxpr_unknown = ClosedJaxpr(jaxpr_unknown, ()) - return closed_jaxpr_known, closed_jaxpr_unknown, out_unknowns, res_avals + return closed_jaxpr_known, closed_jaxpr_unknown, out_unknowns, res_avals, fwds + +def _include_consts_in_fwds(consts, fwds, residuals): + if all(f is None for f in fwds): + return fwds, residuals + dummys = [object() for _ in range(max(f for f in fwds if f is not None) + 1)] + residuals_ = iter(residuals) + residuals = [next(residuals_) if f is None else dummys[f] for f in fwds] + assert next(residuals_, None) is None + idxs = {id(x): i for i, x in enumerate((*consts, *dummys))} + fwds = [idxs.get(id(r)) for r in residuals] + residuals = [r for r in residuals if id(r) not in idxs] + return fwds, residuals def partial_eval_jaxpr_custom( @@ -1083,7 +1120,6 @@ def ensure_instantiated(inst: bool, x: Atom) -> Atom: def has_effects(effects) -> bool: return bool({e for e in effects if not isinstance(e, core.NamedAxisEffect)}) - newvar = core.gensym(suffix='_offload') known_eqns, staged_eqns = [], [] foreach(write, in_unknowns, in_inst, jaxpr.invars) foreach(partial(write, False, True), jaxpr.constvars) @@ -1113,13 +1149,15 @@ def has_effects(effects) -> bool: elif isinstance(policy, Offloadable): # TODO(slebedev): This is a legit error which requires a BUILD fix. 
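# Hedged sketch of the residual-forwarding information introduced above.
# partial_eval_jaxpr_nounits_fwd (an internal API added in this patch, not a
# stable interface) returns an extra `fwds` list with one entry per residual:
# None means the known jaxpr must compute that residual, while an index means
# the residual merely forwards a const/known input, so callers can avoid
# materializing a duplicate.
import jax
from jax._src.interpreters import partial_eval as pe

def f(x, y):
  return x, x * y

closed = jax.make_jaxpr(f)(1.0, 2.0)
known, unknown, out_unknowns, res_avals, fwds = (
    pe.partial_eval_jaxpr_nounits_fwd(closed, [False, True], False))
# Here x is needed by the unknown half but is already a known input, so its
# residual is expected to appear as a forwarded index rather than None.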
from jax._src.dispatch import device_put_p, TransferToMemoryKind, CopySemantics # pytype: disable=import-error - resvars = [newvar(v.aval) for v in eqn.outvars] + resvars = [Var(v.aval) for v in eqn.outvars] outvars_copy = list[Atom](eqn.outvars) offload_eqn = core.JaxprEqn( outvars_copy, resvars, device_put_p, - dict(devices=[TransferToMemoryKind(policy.dst) - ] * len(outvars_copy), srcs=[None], - copy_semantics=[CopySemantics.COPY]), + dict( + devices=(TransferToMemoryKind(policy.dst),) * len(outvars_copy), + srcs=(None,), + copy_semantics=(CopySemantics.COPY,), + ), set(), source_info_util.new_source_info(), JaxprEqnContext(None, False)) known_eqns.append(offload_eqn) @@ -1128,9 +1166,11 @@ def has_effects(effects) -> bool: residuals.update(resvars) reload_eqn = core.JaxprEqn( resvars, eqn.outvars, device_put_p, - dict(devices=[TransferToMemoryKind(policy.src) - ] * len(resvars), srcs=[None], - copy_semantics=[CopySemantics.COPY]), + dict( + devices=(TransferToMemoryKind(policy.src),) * len(resvars), + srcs=(None,), + copy_semantics=(CopySemantics.COPY,) + ), set(), source_info_util.new_source_info(), JaxprEqnContext(None, False)) staged_eqns.append(reload_eqn) @@ -1150,6 +1190,7 @@ def has_effects(effects) -> bool: out_unknowns = map(op.or_, out_unknowns, ensure_out_unknowns) out_inst = map(op.or_, out_inst, ensure_out_inst) + ins_known, _ = partition_list(in_unknowns, jaxpr.invars) outs_known, _ = partition_list(out_unknowns, jaxpr.outvars) ref_res_is_input = [r in ins_known for r in residual_refs] @@ -1158,8 +1199,11 @@ def has_effects(effects) -> bool: known_outvars = [*outs_known, *residuals] known_effects = make_jaxpr_effects(jaxpr.constvars, ins_known_and_ref_res, known_outvars, known_eqns) - jaxpr_known = Jaxpr(jaxpr.constvars, ins_known_and_ref_res, known_outvars, - known_eqns, known_effects, jaxpr.debug_info) + + # TODO(mattjj,necula): debug info should be updated here + jaxpr_known = jaxpr.replace( + invars=ins_known_and_ref_res, outvars=known_outvars, + eqns=known_eqns, effects=known_effects) config.enable_checks.value and core.check_jaxpr(jaxpr_known) _, ins_staged = partition_list(in_inst, jaxpr.invars) @@ -1167,9 +1211,10 @@ def has_effects(effects) -> bool: staged_invars = [*residuals, *non_input_res_refs, *ins_staged] staged_effects = make_jaxpr_effects(jaxpr.constvars, staged_invars, outs_staged, staged_eqns) - jaxpr_staged = Jaxpr(jaxpr.constvars, staged_invars, - outs_staged, staged_eqns, staged_effects, - jaxpr.debug_info) + # TODO(mattjj,necula): debug info should be updated here + jaxpr_staged = jaxpr.replace( + invars=staged_invars, outvars=outs_staged, eqns=staged_eqns, + effects=staged_effects) config.enable_checks.value and core.check_jaxpr(jaxpr_staged) return (jaxpr_known, jaxpr_staged, out_unknowns, out_inst, len(residuals), @@ -1229,14 +1274,12 @@ def _default_res_aval_updater( params: dict[str, Any], aval: AbstractValue) -> AbstractValue: return aval -@contextmanager -def trivial_ctx(_): yield def call_partial_eval_custom_rule( jaxpr_param_name: str, params_updater: ParamsUpdater, saveable: Callable[..., RematCases_], unks_in: list[bool], inst_in: list[bool], eqn: JaxprEqn, *, res_aval: ResAvalUpdater = _default_res_aval_updater, - ctx = trivial_ctx, + ctx = contextlib.nullcontext, ) -> tuple[JaxprEqn, JaxprEqn, Sequence[bool], Sequence[bool], list[Var]]: jaxpr = eqn.params[jaxpr_param_name] with ctx(eqn.params): @@ -1246,13 +1289,12 @@ def call_partial_eval_custom_rule( out_binders_known, _ = partition_list(unks_out, eqn.outvars) _, ins_staged = 
partition_list(inst_in, eqn.invars) _, out_binders_staged = partition_list(inst_out, eqn.outvars) - newvar = core.gensym() params_known = {**eqn.params, jaxpr_param_name: jaxpr_known} params_staged = {**eqn.params, jaxpr_param_name: jaxpr_staged} params_known, params_staged = params_updater( unks_in, inst_in, map(op.not_, unks_out), inst_out, num_res, params_known, params_staged) - residuals = [newvar(res_aval(params_known, var.aval)) + residuals = [Var(res_aval(params_known, var.aval)) for var in jaxpr_staged.invars[:num_res]] eqn_known = new_jaxpr_eqn(ins_known, [*out_binders_known, *residuals], eqn.primitive, params_known, jaxpr_known.effects, @@ -1285,14 +1327,13 @@ def closed_call_partial_eval_custom_rule( ins_known, _ = partition_list(unks_in, eqn.invars) _, ins_staged = partition_list(inst_in, eqn.invars) _, out_binders_staged = partition_list(inst_out, eqn.outvars) - newvar = core.gensym() params_known = {**eqn.params, jaxpr_param_name: jaxpr_known} params_staged = {**eqn.params, jaxpr_param_name: jaxpr_staged} params_known, params_staged = params_updater( unks_in, inst_in, map(op.not_, unks_out), inst_out, sum(f is None for f in out_fwd), num_res, params_known, params_staged) res_val_binders, res_ref_binders = split_list( - [newvar(res_aval(params_known, v)) + [Var(res_aval(params_known, v)) for v in jaxpr_staged.in_avals[:num_res]], [num_res_val]) res_val_binders = [v for v, f in zip(res_val_binders, out_fwd) if f is None] res_val_vars = subs_list(out_fwd, out_binders_known, res_val_binders) @@ -1350,15 +1391,15 @@ def _closed_jaxpr_partial_eval_custom_cached( def _jaxpr_forwarding(jaxpr: Jaxpr) -> list[int | None]: # Compute which inputs are just forwarded to outputs. - fwds: dict[Var, Var] = dict(zip(jaxpr.invars, jaxpr.invars)) + fwds: dict[Var, Atom] = dict(zip(jaxpr.invars, jaxpr.invars)) for eqn in jaxpr.eqns: if eqn.primitive in forwarding_rules: eqn = eqn.replace(invars=[a if type(a) is Literal else fwds.get(a, a) # type: ignore for a in eqn.invars]) - fwd_vars, _ = forwarding_rules[eqn.primitive](eqn) - for v_orig, v_new in zip(eqn.outvars, fwd_vars): - if v_new is not None: - fwds[v_orig] = v_new + fwd_idx, _ = forwarding_rules[eqn.primitive](eqn) + for v_orig, idx in zip(eqn.outvars, fwd_idx): + if idx is not None: + fwds[v_orig] = eqn.invars[idx] idxs: dict[Var, int] = {v: i for i, v in enumerate(jaxpr.invars)} return [None if type(v) is Literal else idxs.get(fwds.get(v)) # type: ignore for v in jaxpr.outvars] @@ -1462,7 +1503,8 @@ def write(x: Atom, b: bool) -> None: jaxpr.debug_info.traced_for, jaxpr.debug_info.func_src_info, jaxpr.debug_info.filter_arg_names(used_inputs), jaxpr.debug_info.filter_result_paths(used_outputs)) - new_jaxpr = Jaxpr(jaxpr.constvars, invars, outvars, eqns, jaxpr_effects, dbg) + new_jaxpr = jaxpr.replace(invars=invars, outvars=outvars, eqns=eqns, + effects=jaxpr_effects, debug_info=dbg) config.enable_checks.value and core.check_jaxpr(new_jaxpr) return new_jaxpr, used_inputs @@ -1527,6 +1569,20 @@ def dce_jaxpr_closed_call_rule(used_outputs: list[bool], eqn: JaxprEqn def close_jaxpr(jaxpr: Jaxpr) -> ClosedJaxpr: return ClosedJaxpr(jaxpr, ()) +def move_invars_right(jaxpr: ClosedJaxpr, to_move: Sequence[bool]): + return _move_invars_right(jaxpr, tuple(to_move)) + +@weakref_lru_cache +def _move_invars_right(jaxpr: ClosedJaxpr, to_move: tuple[bool, ...]): + invars, rest = split_list(jaxpr.jaxpr.invars, [len(to_move)]) + left_invars, right_invars = partition_list(to_move, invars) + new_invars = [*left_invars, *right_invars, *rest] + new_effs = 
_renumber_effects( + (*jaxpr.jaxpr.constvars, *new_invars), + (*jaxpr.jaxpr.constvars, *jaxpr.jaxpr.invars), + jaxpr.jaxpr.effects) + return jaxpr.replace(jaxpr=jaxpr.jaxpr.replace(invars=new_invars, effects=new_effs)) + def move_binders_to_front(closed_jaxpr: ClosedJaxpr, to_move: Sequence[bool] ) -> ClosedJaxpr: """Reorder `invars` by moving those indicated in `to_move` to the front.""" @@ -1536,14 +1592,21 @@ def move_binders_to_front(closed_jaxpr: ClosedJaxpr, to_move: Sequence[bool] def _move_binders_to_front(closed_jaxpr: ClosedJaxpr, to_move: tuple[bool, ...] ) -> ClosedJaxpr: assert len(closed_jaxpr.in_avals) == len(to_move) - new_invars = _move_to_front(closed_jaxpr.jaxpr.invars, to_move) - new_jaxpr = Jaxpr(closed_jaxpr.jaxpr.constvars, new_invars, - closed_jaxpr.jaxpr.outvars, closed_jaxpr.jaxpr.eqns, - closed_jaxpr.jaxpr.effects, - closed_jaxpr.jaxpr.debug_info) + constvars, invars = closed_jaxpr.jaxpr.constvars, closed_jaxpr.jaxpr.invars + new_invars = _move_to_front(invars, to_move) + new_effs = _renumber_effects( + (*constvars, *new_invars), (*constvars, *invars), closed_jaxpr.jaxpr.effects) + new_jaxpr = closed_jaxpr.jaxpr.replace( + constvars=constvars, invars=new_invars, effects=new_effs) new_closed_jaxpr = core.ClosedJaxpr(new_jaxpr, closed_jaxpr.consts) return new_closed_jaxpr +def _renumber_effects(new_vars, old_vars, effs): + newvar_idxs = {id(v): i for i, v in enumerate(new_vars)} + old_to_new = {i: newvar_idxs[id(v)] for i, v in enumerate(old_vars)} + return {e.replace(input_index=old_to_new[e.input_index]) + if isinstance(e, effects.JaxprInputEffect) else e for e in effs} + def _move_to_front(lst: Sequence, to_move: Sequence[bool]) -> Sequence: return ([elt for elt, move in zip(lst, to_move) if move] + [elt for elt, move in zip(lst, to_move) if not move]) @@ -1553,20 +1616,47 @@ def move_binders_to_back(closed_jaxpr: ClosedJaxpr, to_move: Sequence[bool] """Reorder `invars` by moving those indicated in `to_move` to the back.""" return move_binders_to_front(closed_jaxpr, map(op.not_, to_move)) +def move_outvars_to_back(jaxpr: ClosedJaxpr, to_move: Sequence[bool]) -> ClosedJaxpr: + return _move_outvars_to_back(jaxpr, tuple(to_move)) + +@weakref_lru_cache +def _move_outvars_to_back(jaxpr, to_move): + new_outvars = ([e for e, m in zip(jaxpr.jaxpr.outvars, to_move) if not m] + + [e for e, m in zip(jaxpr.jaxpr.outvars, to_move) if m]) + return jaxpr.replace(jaxpr=jaxpr.jaxpr.replace(outvars=new_outvars)) + + class DynamicJaxprTracer(core.Tracer): - __slots__ = ['aval', '_debug_info'] + __slots__ = ['aval', 'mutable_qdd', '_debug_info'] def __init__(self, trace: DynamicJaxprTrace, - aval: core.AbstractValue, + aval: core.AbstractValue | core.AvalQDD, line_info: source_info_util.SourceInfo | None = None): + if isinstance(aval, core.AvalQDD): + assert aval.qdd is not None + aval, qdd = aval.aval, aval.qdd + else: + assert not aval.has_qdd + qdd = None self._trace = trace self._line_info = line_info self._debug_info = self._trace.frame.debug_info # for UnexpectedTracerError self.aval = aval # type: ignore[misc] + self.mutable_qdd = core.MutableQuasiDynamicData(qdd) + + @property + def aval_mutable_qdd(self): + aval = self.aval + if aval.has_qdd: + return core.AvalMutableQDD(aval, self.mutable_qdd) + else: + return aval def full_lower(self): var = self._trace.frame.tracer_to_var.get(id(self)) if var is None: return self + if isinstance(var, Literal): + return var.val val = self._trace.frame.constvar_to_val.get(var) if val is None: return self return core.full_lower(val) @@ 
-1610,7 +1700,8 @@ def _origin_msg(self): def get_referent(self): frame = self._trace.frame - val = frame.constvar_to_val.get(frame.tracer_to_var.get(id(self))) + var = frame.tracer_to_var.get(id(self)) + val = frame.constvar_to_val.get(var) if isinstance(var, Var) else None return self if val is None else get_referent(val) core.pytype_aval_mappings[DynamicJaxprTracer] = lambda x: x.aval @@ -1633,16 +1724,19 @@ def make_jaxpr_effects(constvars, invars, outvars, eqns) -> effects.Effects: f"\n Equation: {eqn}\n" "\n Jaxpr: " f"{core.Jaxpr(constvars, invars, outvars, eqns, set())}") - invar = eqn.invars[eff.input_index] - if invar in mut_arrays: + eqn_invar = eqn.invars[eff.input_index] + if eqn_invar in mut_arrays: continue - if (input_index := all_vars.get(invar, sentinel)) is sentinel: + if (input_index := all_vars.get(eqn_invar, sentinel)) is sentinel: + # TODO(mattjj): ask for forgiveness + dbg = type('Fake', (), {'resolve_result_paths': lambda _: None})() raise ValueError( f"`JaxprInputEffect` {eff} does not have " - f"corresponding input: {invar}." + f"corresponding jaxpr input: {eqn_invar=}." f"\n Equation: {eqn}\n" + f"\n Effects: {eqn.effects}\n" "\n Jaxpr: " - f"{core.Jaxpr(constvars, invars, outvars, eqns, set())}") + f"{core.Jaxpr(constvars, invars, outvars, eqns, set(), dbg)}") # type: ignore eff = eff.replace(input_index=input_index) jaxpr_effects.add(eff) return jaxpr_effects @@ -1650,17 +1744,20 @@ def make_jaxpr_effects(constvars, invars, outvars, eqns) -> effects.Effects: class JaxprStackFrame: gensym: Callable[[AbstractValue], Var] - tracer_to_var: dict[TracerId, Var] + tracer_to_var: dict[TracerId, Atom] constid_to_tracer: dict[ConstId, Tracer] constvar_to_val: dict[Var, Any] tracers: list[DynamicJaxprTracer] # hold onto strong refs for all tracers eqns: list[JaxprEqn] invars: list[Var] effects: core.Effects - attrs_tracked: list[tuple[Any, str]] + attrs_tracked: list[tuple[Any, str, AttrKind]] attrs_inits: list attrs_vars: list[Var] debug_info: core.DebugInfo + is_high: bool + mutable_qdds: list[tuple[Var, core.MutableQuasiDynamicData]] + def __init__(self, debug_info: core.DebugInfo): self.gensym = core.gensym() @@ -1675,37 +1772,49 @@ def __init__(self, debug_info: core.DebugInfo): self.attrs_inits = [] self.attrs_vars = [] self.debug_info = debug_info + self.is_high = False + self.mutable_qdds = [] def add_eqn(self, eqn: core.JaxprEqn): self.eqns.append(eqn) - def to_jaxpr(self, trace: DynamicJaxprTrace, - out_tracers: Sequence[Tracer], - debug_info: core.DebugInfo, - ) -> tuple[Jaxpr, list[Any], list[tuple[PyTreeDef, PyTreeDef, tuple[Any, str]]]]: + def reset_states(self, trace): + reset_states(trace, self.attrs_tracked, self.attrs_inits) + + def to_jaxpr( + self, trace: DynamicJaxprTrace, + out_tracers: Sequence[Tracer], + debug_info: core.DebugInfo, + source_info: SourceInfo, + ) -> tuple[Jaxpr, list[Any], list[tuple[PyTreeDef, PyTreeDef, tuple[Any, str, AttrKind]]]]: # It's not necessary, but we keep the tracer-to-var mapping injective: - assert len(self.tracer_to_var) == len(set(self.tracer_to_var.values())) + vars = [v for v in self.tracer_to_var.values() if not isinstance(v, Literal)] + assert len(vars) == len(set(vars)) invars = self.attrs_vars + self.invars state_ans, end_trees = unzip2( tree_flatten(t) for t in get_states(self.attrs_tracked)) - state_outvars = [self.tracer_to_var[id(trace.to_jaxpr_tracer(x))] + state_outvars = [self.tracer_to_var[id(trace.to_jaxpr_tracer(x, source_info))] for xs in state_ans for x in xs] explicit_outvars = 
[self.tracer_to_var[id(t)] for t in out_tracers] outvars = state_outvars + explicit_outvars constvars, constvals = unzip2(self.constvar_to_val.items()) jaxpr_effects = make_jaxpr_effects(constvars, self.invars, explicit_outvars, self.eqns) + + # TODO(dougalm): handle qdd for consts + for v, qdd in self.mutable_qdds: + v.final_qdd = qdd.cur_val + jaxpr = Jaxpr(constvars, invars, outvars, self.eqns, jaxpr_effects, - debug_info) - jaxpr, constvals = _const_folding_and_forwarding(jaxpr, constvals) - jaxpr, constvals = _inline_literals(jaxpr, constvals) + debug_info, self.is_high) + jaxpr, constvals = _drop_unused_vars(jaxpr, constvals) init_trees = [tree_structure(init_val) for init_val in self.attrs_inits] - set_states(self.attrs_tracked, self.attrs_inits) return jaxpr, list(constvals), zip(init_trees, end_trees, self.attrs_tracked) def to_jaxpr2(self, out_tracers: Sequence[core.Tracer], debug_info: core.DebugInfo): # It's not necessary, but we keep the tracer-to-var mapping injective: - assert len(self.tracer_to_var) == len(set(self.tracer_to_var.values())) + vars = [v for v in self.tracer_to_var.values() if not isinstance(v, Literal)] + assert len(vars) == len(set(vars)) constvars, constvals = unzip2(self.constvar_to_val.items()) expl_outvars = [self.tracer_to_var[id(t)] for t in out_tracers] jaxpr_effects = make_jaxpr_effects(constvars, self.invars, expl_outvars, @@ -1713,8 +1822,7 @@ def to_jaxpr2(self, out_tracers: Sequence[core.Tracer], jaxpr = Jaxpr(constvars, self.invars, expl_outvars, self.eqns, jaxpr_effects, debug_info) # We can't run check_jaxpr until after we normalize. - jaxpr, constvals = _const_folding_and_forwarding(jaxpr, constvals) - jaxpr, constvals = _inline_literals(jaxpr, constvals) + jaxpr, constvals = _drop_unused_vars(jaxpr, constvals) jaxpr, out_type = _add_implicit_outputs(jaxpr) config.enable_checks.value and core.check_jaxpr(jaxpr) return jaxpr, out_type, constvals @@ -1724,12 +1832,16 @@ def newvar(self, aval): # this aval may have tracers in it, so we replace those with variables new_shape = [self.tracer_to_var[id(d)] if isinstance(d, Tracer) else d for d in aval.shape] + new_shape = [d.val if isinstance(d, Literal) else d for d in new_shape] aval = aval.update(shape=tuple(new_shape)) - return self.gensym(aval) + if isinstance(aval, core.AvalQDD): + return self.gensym(aval.aval, initial_qdd=aval.qdd) + else: + return self.gensym(aval) def find_progenitors(self, tracer): var = self.tracer_to_var.get(id(tracer)) - if not var: + if not var or isinstance(var, Literal): return None, None active_vars = {var} for eqn in self.eqns[::-1]: @@ -1739,49 +1851,11 @@ def find_progenitors(self, tracer): active_vars.update({v for v in eqn.invars if type(v) is Var}) invar_positions = [i for i, v in enumerate(self.invars) if v in active_vars] constvars = active_vars & set(self.constvar_to_val) - const_eqns = [eqn for eqn in self.eqns - if {v for v in eqn.invars if type(v) is Var} & constvars] + const_eqns = [eqn for eqn in self.eqns if any( + v in constvars if type(v) is Var else type(v) is Literal + for v in eqn.invars)] return invar_positions, const_eqns -def _const_folding_and_forwarding( - jaxpr: Jaxpr, constvals: Sequence[Any]) -> tuple[Jaxpr, tuple[Any, ...]]: - consts: dict[Var, Any] = dict(zip(jaxpr.constvars, constvals)) - var_subs: dict[Var, Var] = {} # not Dict[Var, Atom] b/c literals not inlined - new_eqns = [] - def apply_var_sub(a: Atom) -> Atom: - return var_subs.get(a, a) if isinstance(a, Var) else a - for eqn in jaxpr.eqns: - # always apply invar substitutions - 
eqn = eqn.replace(invars=[apply_var_sub(v) for v in eqn.invars]) - # if any inputs are constants and we have a constant-folding rule, apply it - has_input_effect = any(isinstance(eff, effects.JaxprInputEffect) - for eff in eqn.effects) - if (eqn.primitive in const_fold_rules and - any(v in consts for v in eqn.invars if isinstance(v, Var)) and - not has_input_effect): - consts_in = [consts.get(v) if isinstance(v, Var) else None - for v in eqn.invars] - consts_out, new_eqn = const_fold_rules[eqn.primitive](consts_in, eqn) - assert (new_eqn is None) == all(c is not None for c in consts_out) - for v, c in zip(eqn.outvars, consts_out): - if c is not None: consts[v] = c - if new_eqn is None: continue - else: eqn = new_eqn - # if the application trivially maps some inputs to outputs, simplify - if eqn.primitive in forwarding_rules and not has_input_effect: - fwd_vars, new_eqn = forwarding_rules[eqn.primitive](eqn) - for v_orig, v_new in zip(eqn.outvars, fwd_vars): - if v_new is not None: var_subs[v_orig] = v_new - if new_eqn is None: continue - else: eqn = new_eqn - new_eqns.append(eqn) - new_constvars, new_constvals = unzip2(consts.items()) - new_outvars = [apply_var_sub(v) for v in jaxpr.outvars] - jaxpr_effects = make_jaxpr_effects(new_constvars, jaxpr.invars, new_outvars, - new_eqns) - new_jaxpr = Jaxpr(new_constvars, jaxpr.invars, new_outvars, new_eqns, - jaxpr_effects, jaxpr.debug_info) - return new_jaxpr, new_constvals ConstFoldRule = Callable[ [list[Union[Any, None]], JaxprEqn], @@ -1791,119 +1865,126 @@ def apply_var_sub(a: Atom) -> Atom: ForwardingRule = Callable[ [JaxprEqn], - tuple[list[Union[Var, None]], Union[JaxprEqn, None]] + tuple[list[Union[int, None]], Union[JaxprEqn, None]] ] forwarding_rules: dict[Primitive, ForwardingRule] = {} -def _inline_literals( +def _drop_unused_vars( jaxpr: Jaxpr, constvals: Sequence[Any] ) -> tuple[Jaxpr, list[Any]]: - # This function also prunes unused constants and inserts `dropvar` symbols. 
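# Hedged sketch of the new forwarding-rule convention above: a rule now returns
# integer indices into eqn.invars (or None) instead of Vars. `copy_p` is a
# hypothetical single-output primitive used purely for illustration.
def _copy_forwarding_rule(eqn):
  # output 0 simply forwards input 0; no replacement equation is needed
  return [0], None

# registration would look like: forwarding_rules[copy_p] = _copy_forwarding_rule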
- input_effects = {eff for eff in jaxpr.effects - if isinstance(eff, effects.JaxprInputEffect)} - # Don't inline any literal with an input effect - has_input_effect = [any(eff.input_index == i for eff in input_effects) - for i in range(len(constvals))] - lits = {v: Literal(c, v.aval) for v, c, e in zip(jaxpr.constvars, constvals, - has_input_effect) - if type(c) in core.literalable_types and not np.shape(c) and not e} - def lit(a: Atom) -> Literal | None: - return (a if isinstance(a, Literal) else lits.get(a) if isinstance(a, Var) - else None) - newname: Callable[[AbstractValue], Var] = core.gensym() - newvars: dict[Var, Var] = {} - newvar = lambda aval: newname(_substitute_vars_in_type(lits, newvars, aval)) - var = lambda v: newvars.get(v) or newvars.setdefault(v, newvar(v.aval)) - lit_or_var = ( - lambda a: a if isinstance(a, Literal) else (lit(a) or var(a)) - ) - dropvar = lambda aval: DropVar(_substitute_vars_in_type(lits, newvars, aval)) - - def vars_in_shape(aval: AbstractValue) -> Sequence[Var]: + def vars(atom: Atom) -> list[Var]: + if isinstance(atom, Literal): + return [] + aval = atom.aval if isinstance(aval, DShapedArray): - return [d for d in aval.shape if isinstance(d, Var)] - return [] - - used = {v for eqn in jaxpr.eqns for atom in eqn.invars - for v in it.chain([atom], vars_in_shape(atom.aval)) - if isinstance(atom, Var)} - used |= {v for outvar in jaxpr.outvars - for v in it.chain([outvar], vars_in_shape(outvar.aval))} - new_constvars = [var(v) for v in jaxpr.constvars if v in used and not lit(v)] - new_constvals = [c for v, c in zip(jaxpr.constvars, constvals) - if v in used and not lit(v)] - new_invars = [var(v) for v in jaxpr.invars] - new_eqns = [] - for eqn in jaxpr.eqns: - invars = [lit_or_var(x) for x in eqn.invars] - outvars = [var(v) if v in used else dropvar(v.aval) for v in eqn.outvars] - new_eqns.append(eqn.replace(invars=invars, outvars=outvars)) - new_outvars = [lit_or_var(v) for v in jaxpr.outvars] - jaxpr_effects = make_jaxpr_effects(new_constvars, new_invars, new_outvars, - new_eqns) - new_jaxpr = Jaxpr(new_constvars, new_invars, new_outvars, new_eqns, - jaxpr_effects, jaxpr.debug_info) - return new_jaxpr, new_constvals + return [atom] + [d for d in aval.shape if isinstance(d, Var)] + return [atom] + used: set[Var] = {v for atom in jaxpr.outvars for v in vars(atom)} + for eqn in jaxpr.eqns[::-1]: + eqn.outvars = [v if v in used else DropVar(v.aval) for v in eqn.outvars] + used.update(v for atom in eqn.invars for v in vars(atom)) + cvars, constvals = unzip2( + (v, val) for v, val in zip(jaxpr.constvars, constvals) if v in used) + jaxpr._constvars = list(cvars) + jaxpr._effects = make_jaxpr_effects(jaxpr.constvars, jaxpr.invars, + jaxpr.outvars, jaxpr.eqns) + return jaxpr, list(constvals) + + +@cache() +def _cached_abstract_eval(primitive: core.Primitive, *aval_qdds, **params): + return primitive.abstract_eval(*aval_qdds, **params) + + +def _verify_params_are_hashable( + primitive: core.Primitive, params: dict[str, Any]) -> None: + for k, v in params.items(): + try: + hash(v) + except TypeError as e: + raise TypeError( + "As of JAX v0.7, parameters to jaxpr equations must have __hash__ and " + f"__eq__ methods. 
In a call to primitive {primitive}, the value of " + f"parameter {k} was not hashable: {v}") from e class DynamicJaxprTrace(core.Trace): - __slots__ = ("frame", "tag") + __slots__ = ("frame", "tag", "parent_trace") - def __init__(self, debug_info: core.DebugInfo): + def __init__(self, debug_info: core.DebugInfo, parent_trace=None, lower=False): super().__init__() + self.requires_low = lower self.frame = JaxprStackFrame(debug_info) + self.parent_trace = parent_trace def invalidate(self): # avoid cyclic refs self.frame.tracers = [] self.frame.constid_to_tracer = {} + self.frame.constvar_to_val = {} + self.frame.attrs_tracked = [] + self.frame.attrs_inits = [] - def to_jaxpr_tracer(self, x): + def to_jaxpr_tracer(self, x, source_info: SourceInfo): as_local_var = self.frame.tracer_to_var.get(id(x)) if as_local_var is None: if hasattr(x, "dimension_as_value"): # Used for shape_poly._DimExpr with core.set_current_trace(self): x = x.dimension_as_value() - return self.to_jaxpr_tracer(x) + return self.to_jaxpr_tracer(x, source_info) else: - return self.new_const(x) + return self.new_const(x, source_info) else: return x - def new_arg(self, aval): - tracer = DynamicJaxprTracer(self, aval, source_info_util.current()) + def new_arg(self, aval, source_info: SourceInfo): + tracer = DynamicJaxprTracer(self, aval, source_info) self.frame.tracers.append(tracer) self.frame.tracer_to_var[id(tracer)] = var = self.frame.newvar(aval) self.frame.invars.append(var) + self.frame.mutable_qdds.append((var, tracer.mutable_qdd)) return tracer - def new_const(self, c): + def new_const(self, c, source_info: SourceInfo): # TODO(mattjj): for ints, or hashable consts, don't rely on id tracer = self.frame.constid_to_tracer.get(id(c)) if tracer is None: aval = get_aval(c) - if hasattr(aval, "weak_type"): - aval = aval.update_weak_type(dtypes.is_weakly_typed(c)) - aval = self._lift_tracers_in_aval(aval) - tracer = self._new_const(aval, c) + if aval.has_qdd: + with core.set_current_trace(self.parent_trace): + aval = core.AvalQDD(aval, core.cur_qdd(c)) + aval = self._lift_tracers_in_aval(aval, source_info) + tracer = self._new_const(aval, c, source_info) return tracer pure = lift = new_const - def _new_const(self, aval, c) -> DynamicJaxprTracer: - tracer = DynamicJaxprTracer(self, aval, source_info_util.current()) + def _new_const(self, aval, c, source_info: SourceInfo) -> DynamicJaxprTracer: + tracer = DynamicJaxprTracer(self, aval, source_info) self.frame.tracers.append(tracer) - self.frame.tracer_to_var[id(tracer)] = var = self.frame.newvar(aval) - self.frame.constid_to_tracer[id(c)] = tracer - self.frame.constvar_to_val[var] = c + if core.is_literalable(c): + self.frame.tracer_to_var[id(tracer)] = Literal(c, aval) + else: + self.frame.tracer_to_var[id(tracer)] = var = self.frame.newvar(aval) + self.frame.constid_to_tracer[id(c)] = tracer + if isinstance(aval, core.AvalQDD): + self.frame.mutable_qdds.append((var, tracer.mutable_qdd)) + self.frame.constvar_to_val[var] = c return tracer - def _lift_tracers_in_aval(self, aval): + def get_const(self, tracer) -> Any: + var = self.frame.tracer_to_var.get(id(tracer)) + if isinstance(var, Literal): + return var.val + elif var is not None: + return self.frame.constvar_to_val.get(var) + + def _lift_tracers_in_aval(self, aval, source_info: SourceInfo): if (not isinstance(aval, DShapedArray) or not any(isinstance(d, Tracer) for d in aval.shape)): return aval - shape = [self.to_jaxpr_tracer(d) if isinstance(d, Tracer) else d + shape = [self.to_jaxpr_tracer(d, source_info) if isinstance(d, 
Tracer) else d for d in aval.shape] return aval.update(shape=tuple(shape)) @@ -1920,46 +2001,93 @@ def makevar(self, tracer): var = self.frame.tracer_to_var[id(tracer)] = self.frame.newvar(tracer.aval) return var - def is_const(self, tracer): - return self.frame.tracer_to_var.get(id(tracer)) is None + def cur_qdd(self, x): + source_info = source_info_util.current() + return self.to_jaxpr_tracer(x, source_info=source_info).mutable_qdd.cur_val def process_primitive(self, primitive, tracers, params): - if (config.eager_constant_folding.value and all(map(self.is_const, tracers))): + self.frame.is_high |= primitive.is_high(**params) + if config.eager_constant_folding.value and not any(isinstance(x, Tracer) for x in tracers): return primitive.bind_with_trace(core.eval_trace, tracers, params) - jaxpr_tracers = map(self.to_jaxpr_tracer, tracers) + source_info = source_info_util.current() + to_jaxpr_tracer = partial(self.to_jaxpr_tracer, source_info=source_info) + jaxpr_tracers = map(to_jaxpr_tracer, tracers) if primitive in custom_staging_rules: - return custom_staging_rules[primitive](self, *jaxpr_tracers, **params) - return self.default_process_primitive(primitive, jaxpr_tracers, params) + return custom_staging_rules[primitive](self, source_info, *jaxpr_tracers, + **params) + return self.default_process_primitive( + primitive, jaxpr_tracers, params, source_info) + + def default_process_primitive(self, primitive, tracers, params, + source_info=None): + aval_qdds = [t.aval_mutable_qdd for t in tracers] + # TODO(mattjj): make custom_lin have hashable params. + # TODO(dougalm): add an attribute to primitives to mark primitives with + # effectful abstract_eval rules. + if ( + primitive.name == "custom_lin" + or config.dynamic_shapes.value + or any( + isinstance(aval, core.MutableQuasiDynamicData) for aval in aval_qdds + ) + ): + out_avals, effs = primitive.abstract_eval(*aval_qdds, **params) + else: + try: + out_avals, effs = _cached_abstract_eval(primitive, *aval_qdds, **params) + except Exception as e: + # TODO(phawkins): remove this 3 months after the release of JAX v0.7. 
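# Hedged sketch of the hashability requirement checked by
# _verify_params_are_hashable above: per its error message, as of JAX v0.7
# equation params must define __hash__/__eq__ so abstract_eval results can be
# cached. `my_p` is a hypothetical primitive used only for illustration.
#
#   my_p.bind(x, sizes=[1, 2, 3])   # list param: unhashable, now rejected
#   my_p.bind(x, sizes=(1, 2, 3))   # tuple param: hashable and cacheable
#
# Structured params can stay hashable with a frozen dataclass, for example:
from dataclasses import dataclass

@dataclass(frozen=True)
class MyParams:
  sizes: tuple[int, ...]
  mode: str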
+ _verify_params_are_hashable(primitive, params) + raise - def default_process_primitive(self, primitive, tracers, params): - avals = [t.aval for t in tracers] - out_avals, effects = primitive.abstract_eval(*avals, **params) if isinstance(out_avals, (tuple, list)) != primitive.multiple_results: raise ValueError(f"{primitive}.abstract_eval() method should return " f"a tuple or a list iff {primitive}.multiple_results.") out_avals = [out_avals] if not primitive.multiple_results else out_avals - source_info = source_info_util.current() + source_info = source_info or source_info_util.current() out_tracers = [DynamicJaxprTracer(self, a, source_info) for a in out_avals] invars = map(self.getvar, tracers) outvars = map(self.makevar, out_tracers) - eqn = new_jaxpr_eqn(invars, outvars, primitive, params, effects, - source_info) - self.frame.add_eqn(eqn) + eqn = new_jaxpr_eqn(invars, outvars, primitive, params, effs, source_info) + no_input_effects = not any(isinstance(e, effects.JaxprInputEffect) + for e in eqn.effects) + + # Constant folding + if no_input_effects and primitive in const_fold_rules: + consts_in = map(self.get_const, tracers) + if any(c is not None for c in consts_in): + consts_out, eqn = const_fold_rules[primitive](consts_in, eqn) + assert (eqn is None) == all(c is not None for c in consts_out) + for i, c in enumerate(consts_out): + if c is not None: + out_tracers[i] = self.new_const(c, source_info) + + # Input-to-output tracer forwarding + if eqn is not None and no_input_effects and primitive in forwarding_rules: + in_fwd, eqn = forwarding_rules[primitive](eqn) + for out_idx, in_idx in enumerate(in_fwd): + if in_idx is not None: + out_tracers[out_idx] = tracers[in_idx] + + if eqn is not None: + self.frame.add_eqn(eqn) return out_tracers if primitive.multiple_results else out_tracers.pop() def process_call(self, call_primitive, f: lu.WrappedFun, explicit_tracers, params): + source_info = source_info_util.current() + to_jaxpr_tracer = partial(self.to_jaxpr_tracer, source_info=source_info) if f.in_type is None: f = lu.annotate(f, tuple((get_aval(t), True) for t in explicit_tracers)) assert f.in_type is not None - implicit_tracers = _extract_implicit_args(self, f.in_type, explicit_tracers) - in_tracers = map(self.to_jaxpr_tracer, [*implicit_tracers, *explicit_tracers]) + implicit_tracers = _extract_implicit_args(self, f.in_type, explicit_tracers, + source_info) + in_tracers = map(to_jaxpr_tracer, [*implicit_tracers, *explicit_tracers]) # TODO(mattjj): check in_tracers are consistent with f.in_type annotation jaxpr, out_type, consts = trace_to_jaxpr_dynamic2(f) if params.get('inline', False): return core.eval_jaxpr(jaxpr, consts, *in_tracers, propagate_source_info=False) - source_info = source_info_util.current() out_tracers: list[Tracer] = [] for aval, _ in out_type: if type(aval) is DShapedArray: @@ -1969,7 +2097,7 @@ def process_call(self, call_primitive, f: lu.WrappedFun, aval = aval.update(shape=tuple(get_referent(d) for d in shape)) out_tracers.append(DynamicJaxprTracer(self, aval, source_info)) invars = map(self.getvar, in_tracers) - constvars = map(self.getvar, map(self.to_jaxpr_tracer, consts)) + constvars = map(self.getvar, map(to_jaxpr_tracer, consts)) outvars = map(self.makevar, out_tracers) new_params = dict(params, call_jaxpr=convert_constvars_jaxpr(jaxpr)) update_params = call_param_updaters.get(call_primitive) @@ -1982,7 +2110,9 @@ def process_call(self, call_primitive, f: lu.WrappedFun, return [t for t, (_, keep) in zip(out_tracers, out_type) if keep] def 
process_map(self, map_primitive, f: lu.WrappedFun, tracers, params): - tracers = map(self.to_jaxpr_tracer, tracers) + source_info = source_info_util.current() + to_jaxpr_tracer = partial(self.to_jaxpr_tracer, source_info=source_info) + tracers = map(to_jaxpr_tracer, tracers) in_avals = [t.aval for t in tracers] axis_name, axis_size = params['axis_name'], params['axis_size'] reduced_in_avals = [core.mapped_aval(axis_size, in_axis, a) @@ -2001,10 +2131,9 @@ def process_map(self, map_primitive, f: lu.WrappedFun, tracers, params): out_avals = [core.unmapped_aval(axis_size, out_axis, a) if out_axis is not None else a for a, out_axis in zip(reduced_out_avals, out_axes)] - source_info = source_info_util.current() out_tracers = [DynamicJaxprTracer(self, a, source_info) for a in out_avals] invars = map(self.getvar, tracers) - constvars = map(self.getvar, map(self.to_jaxpr_tracer, consts)) + constvars = map(self.getvar, map(to_jaxpr_tracer, consts)) outvars = map(self.makevar, out_tracers) new_in_axes = (None,) * len(consts) + params['in_axes'] new_params = dict(params, in_axes=new_in_axes, out_axes=out_axes, @@ -2022,12 +2151,15 @@ def process_map(self, map_primitive, f: lu.WrappedFun, tracers, params): def process_custom_jvp_call(self, prim, fun: lu.WrappedFun, jvp: lu.WrappedFun, tracers, symbolic_zeros: bool): - tracers = map(self.to_jaxpr_tracer, tracers) + source_info = source_info_util.current() + to_jaxpr_tracer = partial(self.to_jaxpr_tracer, source_info=source_info) + tracers = map(to_jaxpr_tracer, tracers) in_avals = [t.aval for t in tracers] in_tangent_avals = [t.to_tangent_aval() for t in in_avals] fun_jaxpr, out_avals, consts, () = trace_to_jaxpr_dynamic(fun, in_avals) closed_fun_jaxpr = core.ClosedJaxpr(convert_constvars_jaxpr(fun_jaxpr), ()) + @partial(lu.wrap_init, debug_info=jvp.debug_info) @_memoize def jvp_jaxpr_thunk(*in_zeros): for store in jvp.stores: store and store.reset() @@ -2039,29 +2171,32 @@ def jvp_jaxpr_thunk(*in_zeros): out_tracers = [DynamicJaxprTracer(self, a) for a in out_avals] invars = map(self.getvar, tracers) - constvars = map(self.getvar, map(self.to_jaxpr_tracer, consts)) + constvars = map(self.getvar, map(to_jaxpr_tracer, consts)) outvars = map(self.makevar, out_tracers) eqn = new_jaxpr_eqn([*constvars, *invars], outvars, prim, dict(call_jaxpr=closed_fun_jaxpr, - jvp_jaxpr_fun=lu.wrap_init(jvp_jaxpr_thunk, - debug_info=jvp.debug_info), + jvp_jaxpr_fun=jvp_jaxpr_thunk, num_consts=len(consts), symbolic_zeros=symbolic_zeros), fun_jaxpr.effects, - source_info_util.current()) + source_info) self.frame.add_eqn(eqn) return out_tracers def process_custom_vjp_call(self, prim: core.Primitive, fun: lu.WrappedFun, fwd: lu.WrappedFun, bwd: lu.WrappedFun, tracers, - out_trees: Callable[[], Sequence[PyTreeDef]], + out_trees: Callable[[], tuple[PyTreeDef, PyTreeDef, list[int | None]]], symbolic_zeros: bool): - tracers = map(self.to_jaxpr_tracer, tracers) + source_info = source_info_util.current() + to_jaxpr_tracer = partial(self.to_jaxpr_tracer, source_info=source_info) + tracers = map(to_jaxpr_tracer, tracers) in_avals = [t.aval for t in tracers] fun_jaxpr, out_avals, consts, _ = trace_to_jaxpr_dynamic(fun, in_avals) + num_consts = len(consts) closed_fun_jaxpr = core.ClosedJaxpr(convert_constvars_jaxpr(fun_jaxpr), ()) + @partial(lu.wrap_init, debug_info=fwd.debug_info) @_memoize def fwd_jaxpr_from_zeros(*zeros): for store in fwd.stores: store and store.reset() @@ -2070,19 +2205,23 @@ def fwd_jaxpr_from_zeros(*zeros): if attrs: raise NotImplementedError return jaxpr, 
consts - out_tracers = [DynamicJaxprTracer(self, a) for a in out_avals] + def out_trees_(): + out_tree, res_tree, input_fwds = out_trees() + input_fwds = [f if f is None else f + num_consts for f in input_fwds] + return out_tree, res_tree, input_fwds + + out_tracers = [DynamicJaxprTracer(self, a, source_info) for a in out_avals] invars = map(self.getvar, tracers) - constvars = map(self.getvar, map(self.to_jaxpr_tracer, consts)) + constvars = map(self.getvar, map(to_jaxpr_tracer, consts)) outvars = map(self.makevar, out_tracers) - eqn = new_jaxpr_eqn([*constvars, *invars], outvars, - prim.initial_style, # pytype: disable=attribute-error - dict(fun_jaxpr=closed_fun_jaxpr, + eqn = new_jaxpr_eqn([*constvars, *invars], outvars, prim, + dict(call_jaxpr=closed_fun_jaxpr, fwd_jaxpr_thunk=fwd_jaxpr_from_zeros, - num_consts=len(consts), - bwd=bwd, out_trees=out_trees, + num_consts=num_consts, + bwd=bwd, out_trees=out_trees_, symbolic_zeros=symbolic_zeros), fun_jaxpr.effects, - source_info_util.current()) + source_info) self.frame.add_eqn(eqn) return out_tracers @@ -2092,7 +2231,9 @@ def process_custom_transpose(self, prim: core.Primitive, # type: ignore[overrid out_types, lin_tree: PyTreeDef, res_tree: PyTreeDef, out_tree: PyTreeDef): - tracers = map(self.to_jaxpr_tracer, tracers) + source_info = source_info_util.current() + to_jaxpr_tracer = partial(self.to_jaxpr_tracer, source_info=source_info) + tracers = map(to_jaxpr_tracer, tracers) tracers_res, tracers_lin = split_list(tracers, [res_tree.num_leaves]) in_avals_p = [t.aval for t in tracers] @@ -2112,9 +2253,9 @@ def transpose_jaxpr_thunk(): jaxpr, _, consts, () = trace_to_jaxpr_dynamic(transpose_flat, in_avals_t) return jaxpr, consts - out_tracers = [DynamicJaxprTracer(self, a) for a in out_avals] + out_tracers = [DynamicJaxprTracer(self, a, source_info) for a in out_avals] invars = map(self.getvar, tracers) - constvars = map(self.getvar, map(self.to_jaxpr_tracer, call_consts)) + constvars = map(self.getvar, map(to_jaxpr_tracer, call_consts)) outvars = map(self.makevar, out_tracers) eqn = new_jaxpr_eqn([*constvars, *invars], outvars, prim, dict(call_jaxpr=closed_call_jaxpr, @@ -2122,13 +2263,13 @@ def transpose_jaxpr_thunk(): out_types=out_types, res_tree=res_tree, lin_tree=lin_tree, out_tree=out_tree), closed_call_jaxpr.effects, - source_info_util.current()) + source_info) self.frame.add_eqn(eqn) return out_tracers def to_jaxpr(self, out_tracers: Sequence[Tracer], - debug_info: core.DebugInfo): - return self.frame.to_jaxpr(self, out_tracers, debug_info) + debug_info: core.DebugInfo, source_info: SourceInfo): + return self.frame.to_jaxpr(self, out_tracers, debug_info, source_info) custom_staging_rules: dict[Primitive, Callable] = {} @@ -2171,24 +2312,48 @@ def trace_to_jaxpr_dynamic( in_avals: Sequence[AbstractValue], *, keep_inputs: list[bool] | None = None, + lower: bool = False, ) -> tuple[Jaxpr, list[AbstractValue], list[Any], - list[tuple[PyTreeDef, PyTreeDef, tuple[Any, str]]]]: + list[tuple[PyTreeDef, PyTreeDef, tuple[Any, str, AttrKind]]]]: keep_inputs = [True] * len(in_avals) if keep_inputs is None else keep_inputs - trace = DynamicJaxprTrace(fun.debug_info) + parent_trace = core.trace_ctx.trace + trace = DynamicJaxprTrace(fun.debug_info, parent_trace=parent_trace, lower=lower) with core.ensure_no_leaks(trace), source_info_util.reset_name_stack(): - in_tracers = _input_type_to_tracers(trace.new_arg, in_avals) + source_info = source_info_util.current() + in_tracers = _input_type_to_tracers( + partial(trace.new_arg, 
source_info=source_info), in_avals) in_tracers = [t for t, keep in zip(in_tracers, keep_inputs) if keep] - with core.set_current_trace(trace): - ans = fun.call_wrapped(*in_tracers) - out_tracers = map(trace.to_jaxpr_tracer, ans) - _check_no_returned_refs(fun.debug_info, out_tracers) - jaxpr, consts, attrs_tracked = trace.to_jaxpr(out_tracers, fun.debug_info) - del trace, fun, in_tracers, out_tracers, ans + try: + with core.set_current_trace(trace): + ans = fun.call_wrapped(*in_tracers) + _check_returned_jaxtypes(fun.debug_info, ans) + out_tracers = map(partial(trace.to_jaxpr_tracer, source_info=source_info), ans) + _check_no_returned_refs(fun.debug_info, out_tracers) + jaxpr, consts, attrs_tracked = trace.frame.to_jaxpr( + trace, out_tracers, fun.debug_info, source_info) + del fun, in_tracers, out_tracers, ans + finally: + trace.frame.reset_states(trace) + del trace config.enable_checks.value and core.check_jaxpr(jaxpr) return jaxpr, [v.aval for v in jaxpr.outvars], consts, attrs_tracked +def _check_returned_jaxtypes(dbg, out_tracers): + for i, x in enumerate(out_tracers): + try: + core.typeof(x) + except TypeError: + if (dbg and len(paths := dbg.result_paths()) > i and + (p := paths[i].removeprefix('result'))): + extra = f' at output component {p}' + else: + extra = '' + raise TypeError( + f"function {dbg.func_src_info} traced for {dbg.traced_for} returned a " + f"value of type {type(x)}{extra}, which is not a valid JAX type") from None + def _check_no_returned_refs( dbg: core.DebugInfo, out_tracers: Sequence[DynamicJaxprTracer] @@ -2223,14 +2388,17 @@ def trace_to_jaxpr_dynamic2( ) -> tuple[Jaxpr, OutputType, list[Any]]: assert fun.in_type is not None, "fun must be annotated with lu.annotate()" - trace = DynamicJaxprTrace(fun.debug_info) + parent_trace = core.trace_ctx.trace + trace = DynamicJaxprTrace(fun.debug_info, parent_trace=parent_trace) with core.ensure_no_leaks(trace), source_info_util.reset_name_stack(): + source_info = source_info_util.current() in_avals, keep_inputs = unzip2(fun.in_type) - in_tracers = _input_type_to_tracers(trace.new_arg, in_avals) + in_tracers = _input_type_to_tracers( + partial(trace.new_arg, source_info=source_info), in_avals) in_tracers = [t for t, keep in zip(in_tracers, keep_inputs) if keep] with core.set_current_trace(trace): ans = fun.call_wrapped(*in_tracers) - out_tracers = map(trace.to_jaxpr_tracer, ans) + out_tracers = map(partial(trace.to_jaxpr_tracer, source_info=source_info), ans) jaxpr = trace.frame.to_jaxpr2(out_tracers, fun.debug_info) del trace, in_tracers, out_tracers, ans @@ -2242,14 +2410,18 @@ def trace_to_jaxpr_dynamic2( tuple[AbstractedAxisName, ...], ] -AttrsTracked = list[tuple[Any, str]] +AttrsTracked = list[tuple[Any, str, AttrKind]] AttrStates = list -def set_states(attrs_tracked: AttrsTracked, vals: AttrStates): - for ((obj, attr), val) in zip(attrs_tracked, vals): - setattr(obj, attr, val) +def reset_states(trace, attrs_tracked: AttrsTracked, init_vals: AttrStates) -> None: + for ((obj, attr, kind), val) in zip(attrs_tracked, init_vals): + setattr(obj, attr, val) if val is not dne_sentinel else delattr(obj, attr) -def get_states(attrs_tracked: AttrsTracked): - return [getattr(obj, attr) for (obj, attr) in attrs_tracked] +def get_states(attrs_tracked: AttrsTracked) -> list[PyTree]: + return [getattr(obj, attr) for (obj, attr, kind) in attrs_tracked] + +@register_static +class DoesNotExist: ... 
+dne_sentinel = DoesNotExist() def infer_lambda_input_type( @@ -2384,8 +2556,7 @@ def _add_implicit_outputs(jaxpr: Jaxpr) -> tuple[Jaxpr, OutputType]: kept_outs = [False] * len(impl_outvars) + [True] * len(expl_outvars) out_type = tuple(zip(out_avals, kept_outs)) - new_jaxpr = Jaxpr(jaxpr.constvars, jaxpr.invars, outvars, jaxpr.eqns, - jaxpr.effects, jaxpr.debug_info) + new_jaxpr = jaxpr.replace(outvars=outvars) config.enable_checks.value and core.check_jaxpr(jaxpr) return new_jaxpr, out_type @@ -2401,7 +2572,7 @@ def __hash__(self): def _extract_implicit_args( trace: DynamicJaxprTrace, in_type: Sequence[tuple[AbstractValue, bool]], - explicit_tracers: Sequence[DynamicJaxprTracer] + explicit_tracers: Sequence[DynamicJaxprTracer], source_info: SourceInfo, ) -> Sequence[DynamicJaxprTracer]: # First, construct a list to represent the full argument list, leaving the # implicit arguments as Nones for now. @@ -2419,8 +2590,8 @@ def _extract_implicit_args( for d1, d2 in zip(aval.shape, tracer.aval.shape): if isinstance(d1, DBIdx): if tracers[d1.val] is None: - tracers[d1.val] = trace.to_jaxpr_tracer(d2) - assert tracers[d1.val] is trace.to_jaxpr_tracer(d2) + tracers[d1.val] = trace.to_jaxpr_tracer(d2, source_info) + assert tracers[d1.val] is trace.to_jaxpr_tracer(d2, source_info) assert all(t is not None for t in tracers) return [t for t, (_, e) in zip(tracers, in_type) if not e] # type: ignore @@ -2565,33 +2736,59 @@ def instantiate_const_at(trace: JaxprTrace, instantiate: bool, tracer): return tracer def inline_jaxpr_into_trace( - trace: DynamicJaxprTrace, jaxpr: Jaxpr, consts: Sequence[Any], - *arg_tracers: DynamicJaxprTracer) -> list[Any]: + trace: DynamicJaxprTrace, src: SourceInfo, jaxpr: Jaxpr, + consts: Sequence[Any], *arg_tracers: DynamicJaxprTracer) -> list[Any]: # This function is conceptually the same thing as just calling eval_jaxpr, - const_tracers = map(trace.new_const, consts) + const_tracers = map(partial(trace.new_const, source_info=src), consts) constvars = map(trace.getvar, const_tracers) argvars = map(trace.getvar, arg_tracers) - env: dict[Var, Var] = dict(zip([*jaxpr.constvars, *jaxpr.invars], - [*constvars, *argvars])) + const_env: dict[Var, Any] = { + v: c for v, c in zip(constvars, consts) if not isinstance(v, Literal)} + env: dict[Var, Atom] = dict(zip([*jaxpr.constvars, *jaxpr.invars], + [*constvars, *argvars])) - src = source_info_util.current() for eqn in jaxpr.eqns: invars = [x if isinstance(x, Literal) else env[x] for x in eqn.invars] - outvars = [Var('', v.aval) for v in eqn.outvars] + orig_outvars = eqn.outvars + outvars = [Var(v.aval) for v in orig_outvars] src_ = (src if not eqn.source_info.name_stack else src.replace(name_stack=src.name_stack + eqn.source_info.name_stack)) - trace.frame.add_eqn(eqn.replace(invars, outvars, source_info=src_)) - foreach(env.setdefault, eqn.outvars, outvars) - - tracer_env: dict[Var, Any] = dict(zip([*jaxpr.constvars, *jaxpr.invars], - [*consts, *arg_tracers])) - def new_tracer(atom): + eqn = eqn.replace(invars, outvars, source_info=src_) + foreach(env.setdefault, orig_outvars, outvars) + + # We must re-run constant folding when inlining because some jaxpr inputs + # may be consts in the outer scope. 
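# Illustrative sketch (toy code, not the real const_fold_rules): the folding
# hook used in the staging code above takes per-input constants (None means
# "unknown") and the equation, and returns per-output constants plus a residual
# equation. The invariant asserted during tracing is that the residual equation
# is None exactly when every output became a known constant.
from dataclasses import dataclass

@dataclass
class ToyEqn:                        # hypothetical stand-in for a jaxpr equation
    prim: str
    invals: tuple

def fold_add(consts_in, eqn):
    a, b = consts_in
    if a is not None and b is not None:
        return [a + b], None         # fully folded: nothing left to stage out
    return [None], eqn               # some input unknown: keep the equation

consts_out, residual = fold_add([2, 3], ToyEqn("add", (2, 3)))
assert (residual is None) == all(c is not None for c in consts_out)
consts_out, residual = fold_add([2, None], ToyEqn("add", (2, None)))
assert residual is not None and consts_out == [None]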
+ eqn_: JaxprEqn | None = eqn + inp_eff = any(isinstance(e, effects.JaxprInputEffect) for e in eqn.effects) + if eqn.primitive in const_fold_rules and not inp_eff: + consts_in = [v.val if isinstance(v, Literal) else const_env.get(v) + for v in invars] + if any(c is not None for c in consts_in): + consts_out, eqn_ = const_fold_rules[eqn.primitive](consts_in, eqn) + assert (eqn_ is None) == all(c is not None for c in consts_out) + for v, c in zip(orig_outvars, consts_out): + if c is not None: + if core.is_literalable(c): + env[v] = Literal(c, v.aval) + else: + const_env[v] = c + if eqn_ is not None: + trace.frame.add_eqn(eqn_) + + tracer_env: dict[Var, Any] = const_env + tracer_env.update( + {v: t for v, t in zip(argvars, arg_tracers) if not isinstance(v, Literal)} + ) + def maybe_new_tracer(atom): + if isinstance(atom, Literal): + return atom.val + if atom in tracer_env: + return tracer_env[atom] tracer = tracer_env[atom] = DynamicJaxprTracer(trace, atom.aval, src) trace.frame.tracers.append(tracer) - trace.frame.tracer_to_var[id(tracer)] = env[atom] + trace.frame.tracer_to_var[id(tracer)] = atom return tracer - return [x.val if isinstance(x, Literal) else tracer_env[x] if x in tracer_env - else new_tracer(x) for x in jaxpr.outvars] + return [maybe_new_tracer(x if isinstance(x, Literal) else env[x]) for x in jaxpr.outvars] # TODO(mattjj,dougalm): this special handling is to avoid round-tripping the # jaxpr when we do grad-of-pmap. The tag is set by LinearizeTrace.process_call's @@ -2602,3 +2799,37 @@ def _linearize_of_pmap_hack(f: lu.WrappedFun, jaxpr, consts) -> tuple[Jaxpr, lis _, jaxpr = f.f.closure return convert_constvars_jaxpr(jaxpr), [] return jaxpr, consts + + +@weakref_lru_cache +def lower_jaxpr(hi_jaxpr): + lo_avals = [lo_ty for aval in hi_jaxpr.in_aval_qdds for lo_ty in aval.lo_ty()] + f = lu.wrap_init(partial(lower_traceable, hi_jaxpr), + debug_info=hi_jaxpr.jaxpr.debug_info) + lo_jaxpr, _, lo_consts, () = trace_to_jaxpr_dynamic(f, lo_avals, lower=True) + return core.ClosedJaxpr(lo_jaxpr, lo_consts) + +def lower_traceable(jaxpr, *lo_args): + lo_args_ = iter(lo_args) + hi_args = [aval.raise_val(*it.islice(lo_args_, len(aval.lo_ty()))) + if not aval.has_qdd else + aval.new_from_loval(*it.islice(lo_args_, len(aval.lo_ty()))) + for aval in jaxpr.in_aval_qdds] + assert (problem := next(lo_args_, None)) is None + hi_outs = core.jaxpr_as_fun(jaxpr)(*hi_args) + mut_outs = [lo_val for aval, hi_arg in zip(jaxpr.final_aval_qdds, hi_args) if aval.has_qdd + for lo_val in aval.read_loval(hi_arg)] + lo_outs = [lo_val for v, hi_val in zip(jaxpr.jaxpr.outvars, hi_outs) + for lo_val in v.aval.lower_val(hi_val)] + return mut_outs + lo_outs + +def convert_const_himutables(jaxpr): + move = [core.typeof(c).has_qdd for c in jaxpr.consts] + constvals, in_mutables = partition_list(move, jaxpr.consts) + constvars, boxvars = partition_list(move, jaxpr.jaxpr.constvars) + invars = *boxvars, *jaxpr.jaxpr.invars + effects = make_jaxpr_effects(constvars, invars, jaxpr.jaxpr.outvars, + jaxpr.jaxpr.eqns) + new_jaxpr = jaxpr.jaxpr.replace(constvars=constvars, invars=invars, + effects=effects) + return jaxpr.replace(jaxpr=new_jaxpr, consts=constvals), in_mutables diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index c06eda5214ed..c4585663b68d 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -15,7 +15,6 @@ from __future__ import annotations -import enum import collections from collections import namedtuple from collections.abc import Callable, Sequence, 
Iterable @@ -30,9 +29,8 @@ import numpy as np -import jax - from jax._src import api +from jax._src import array from jax._src import compiler from jax._src import config from jax._src import core @@ -42,11 +40,13 @@ from jax._src import linear_util as lu from jax._src import op_shardings from jax._src import sharding_specs +from jax._src import pjit from jax._src import profiler from jax._src import sharding_impls from jax._src import source_info_util from jax._src import stages from jax._src import tree_util +from jax._src import typing from jax._src import util from jax._src import xla_bridge as xb from jax._src.abstract_arrays import array_types @@ -57,7 +57,7 @@ from jax._src.interpreters import partial_eval as pe from jax._src.interpreters import mlir from jax._src.interpreters import xla -from jax._src.layout import DeviceLocalLayout, AutoLayout, Layout +from jax._src.layout import DeviceLocalLayout, AutoLayout, Format from jax._src.lib import xla_client as xc from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import hlo @@ -68,10 +68,12 @@ from jax._src.sharding_impls import ( ArrayMapping, ArrayMappingOrAutoOrUnspecified, AUTO, UnspecifiedValue, get_array_mapping as _get_array_mapping, array_mapping_to_axis_resources, - SingleDeviceSharding, GSPMDSharding, NamedSharding, PositionalSharding) + SingleDeviceSharding, GSPMDSharding, NamedSharding, + PartitionSpec as P) from jax._src.util import (safe_map, safe_zip, partition_list, wrap_name, tuple_update, tuple_delete, distributed_debug_log, - unzip2, HashableFunction, weakref_lru_cache) + unzip2, HashableFunction, weakref_lru_cache, + tuple_insert) from jax._src.state.types import AbstractRef, RefEffect @@ -83,6 +85,7 @@ class WeakRefList(list): xe = xc._xla unsafe_map, map = map, safe_map # type: ignore +zip, unsafe_zip = safe_zip, zip # type: ignore logger = logging.getLogger(__name__) @@ -152,7 +155,7 @@ def shard_args(shardings: Sequence[JSharding], layouts, copy_semantics, # from each call in the same order as `args`. Since `batches` is grouped by # types, we cannot simply flatten the results and we have to use the original # indices to put each array back to its original position. 
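# Illustrative sketch (plain Python, not the real shard_args): the batching
# strategy described in the comment above -- group arguments by type, make one
# handler call per group, then use the remembered indices to put each result
# back in its original position.
from collections import defaultdict

def batched_apply(handlers, items):
    groups = defaultdict(list)               # type -> [(original_index, item), ...]
    for i, item in enumerate(items):
        groups[type(item)].append((i, item))
    results = [None] * len(items)
    for ty, group in groups.items():
        idxs, vals = zip(*group)
        outs = handlers[ty](list(vals))      # one batched call per type group
        for i, out in zip(idxs, outs):
            results[i] = out                 # scatter back to original order
    return results

print(batched_apply({int: lambda xs: [x + 1 for x in xs],
                     str: lambda xs: [s.upper() for s in xs]},
                    [1, "a", 2, "b"]))       # -> [2, 'A', 3, 'B']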
- results: list[jax.Array | None] = [None] * len(args) + results: list[typing.Array | None] = [None] * len(args) for t, (indices, a, s, l, cs) in batches.items(): outs = shard_arg_handlers[t](a, s, l, cs) for i, out in safe_zip(indices, outs): @@ -206,7 +209,7 @@ def _shard_np_array(xs, shardings, layouts, copy_semantics): x = np.zeros(x.shape, dtype=np.dtype(bool)) aval = core.shaped_abstractify(x) if layout is not None: - results.append(api.device_put(x, Layout(layout, sharding))) + results.append(api.device_put(x, Format(layout, sharding))) else: if sharding.is_fully_replicated: shards = [x] * len(devices) @@ -228,11 +231,9 @@ def _shard_mutable_array(xs, shardings, layouts, copy_semantics): def batched_device_put(aval: core.ShapedArray, sharding: JSharding, xs: Sequence[Any], - devices: Sequence[jax.Device], committed: bool = True): + devices: Sequence[xc.Device], committed: bool = True): util.test_event("batched_device_put_start") try: - from jax._src import array - bufs = [x for x, d in safe_zip(xs, devices) if (isinstance(x, array.ArrayImpl) and dispatch.is_single_device_sharding(x.sharding) and @@ -257,7 +258,7 @@ def _shard_abstract_array(size, axis: int, x): raise ValueError(f"Axis size {size} does not match dimension {axis} of " f"shape {x.shape}") except IndexError: - raise ValueError("Cannot split a {x.dim}D value along axis {axis}") from None + raise ValueError(f"Cannot split a {x.dim}D value along axis {axis}") from None if config.pmap_no_rank_reduction.value: return x.update(shape=tuple_update(x.shape, axis, 1)) else: @@ -338,8 +339,8 @@ def xla_pmap_impl_lazy( donated_invars: Sequence[bool], is_explicit_global_axis_size: bool, ) -> Callable: - if (config.disable_jit.value and config.eager_pmap.value and - not is_explicit_global_axis_size and not any(d for d in donated_invars)): + if (config.disable_jit.value and + not is_explicit_global_axis_size and not any(donated_invars)): def _emap_apply_fn(*args): return _emap_impl(fun, *args, backend=backend, axis_name=axis_name, axis_size=axis_size, global_axis_size=global_axis_size, @@ -383,7 +384,6 @@ def _emap_impl(fun: lu.WrappedFun, *args, donated_invars: Sequence[bool], is_explicit_global_axis_size: bool, ): - from jax._src import array # TODO(sharadmv,mattjj): implement these cases if any(d for d in donated_invars): raise NotImplementedError("Buffer donation not supported in eager pmap.") @@ -408,12 +408,12 @@ def _emap_impl(fun: lu.WrappedFun, *args, donate_argnums = (1,) if platform in {"cuda", "rocm", "tpu"} else () new_outvals = [] for out_axis_src, out_axis, outval in zip(out_axes_src, out_axes, outvals): - with jax.disable_jit(False): + with api.disable_jit(False): donate_argnums_ = donate_argnums if isinstance(outval, array.ArrayImpl): # We don't want to donate if it's already sharded. 
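# Usage sketch (not part of this patch): with jit disabled and no donated
# buffers, the branch above routes pmap through the eager "emap" implementation
# rather than compiling the mapped function. A minimal way to hit that path:
import jax
import jax.numpy as jnp

x = jnp.arange(jax.device_count(), dtype=jnp.float32)
with jax.disable_jit():
    y = jax.pmap(lambda v: v + 1.0)(x)       # evaluated through the eager path
print(y)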
donate_argnums_ = () - out = jax.pmap( + out = api.pmap( lambda _, x: x, in_axes=(0, out_axis_src.get(axis_name)), out_axes=out_axis, @@ -446,7 +446,7 @@ def _multi_pmap(f: Callable, info: EmapInfo, names: list[core.AxisName], for i, name in reversed(list(enumerate(names))): in_axes = tuple(arg_axis[i] for arg_axis in all_axes) if any(in_axis is not None for in_axis in in_axes): - f = jax.pmap( + f = api.pmap( f, in_axes=in_axes, axis_name=name, @@ -474,11 +474,12 @@ def to_map_tracer(self, val): return MapTracer(self, val, {}) def process_primitive(self, primitive, tracers, params): - if primitive is jax._src.lax.parallel.axis_index_p: - return self.process_axis_index(**params) - if primitive is jax._src.lax.parallel.psum_p: + from jax._src.lax import parallel # pytype: disable=import-error + if primitive is parallel.axis_index_p: + return self.process_axis_index(**params) # pytype: disable=missing-parameter + if primitive is parallel.psum_p: f = HashableFunction( - lambda *xs: jax._src.lax.parallel.psum( + lambda *xs: parallel.psum( xs, axis_name=params['axes'], axis_index_groups=params['axis_index_groups']), (primitive, tuple(params.items()))) else: @@ -490,7 +491,7 @@ def process_primitive(self, primitive, tracers, params): names = core.get_axis_env().axis_names() all_axes = tuple(_map_schedule(map(s.get, names)) for s in shard_axes) # pytype: disable=wrong-arg-types # always-use-return-annotations f_mapped, out_shard_axes = _multi_pmap(f, self.emap_info, names, all_axes) - with core.eval_context(), jax.disable_jit(False): + with core.eval_context(), api.disable_jit(False): outvals = f_mapped(*vals) if primitive.multiple_results: return [MapTracer(self, val, out_shard_axes) for val in outvals] @@ -544,11 +545,12 @@ def process_custom_vjp_call(self, primitive, fun, fwd, bwd, tracers, return fun.call_wrapped(*tracers) def process_axis_index(self, axis_name): + from jax._src.lax import lax, parallel # pytype: disable=import-error bind = HashableFunction( - lambda _: jax.lax.axis_index(axis_name), - (jax.lax.axis_index, axis_name)) + lambda _: parallel.axis_index(axis_name), + (parallel.axis_index, axis_name)) fake_primitive = FakePrimitive(multiple_results=False, bind=bind) - range = jax.lax.iota(np.int32, core.get_axis_env().axis_size(axis_name)) + range = lax.iota(np.int32, core.get_axis_env().axis_size(axis_name)) dummy_tracer = MapTracer(self, range, {axis_name: 0}) return self.process_primitive(fake_primitive, (dummy_tracer,), {}) @@ -693,14 +695,15 @@ def find_replicas( @lu.transformation2 def _change_argument_ranks(f, in_axes, out_axes_thunk, *args): + from jax._src.lax import lax # pytype: disable=import-error args = tuple( - arg if in_axis is None else jax.lax.squeeze(arg, dimensions=(in_axis,)) + arg if in_axis is None else lax.squeeze(arg, dimensions=(in_axis,)) for in_axis, arg in zip(in_axes, args) ) results = f(*args) out_axes = out_axes_thunk() return tuple( - x if axis is None else jax.lax.expand_dims(x, dimensions=(axis,)) + x if axis is None else lax.expand_dims(x, dimensions=(axis,)) for x, axis in zip(results, out_axes) ) @@ -921,7 +924,7 @@ def _pmap_unmapped_aval(size: core.AxisSize, axis: int | None, raise TypeError(f"no unmapping handler for {aval} of type {type(aval)}") -class PmapComputation(stages.XlaLowering): +class PmapComputation(stages.Lowering): _hlo: ir.Module _executable: PmapExecutable | None @@ -930,7 +933,7 @@ def __init__(self, hlo: ir.Module, **compile_args): self._hlo = hlo self.compile_args = compile_args - # -- stages.XlaLowering overrides + # -- 
stages.Lowering overrides def stablehlo(self) -> ir.Module: return self._hlo @@ -1097,9 +1100,14 @@ def from_hlo(hlo: ir.Module, with dispatch.log_elapsed_time( "Finished XLA compilation of {fun_name} in {elapsed_time:.9f} sec", fun_name=pci.name, event=dispatch.BACKEND_COMPILE_EVENT): + # `executable_devices` contains devices for output shardings of a pmapped + # function. It contains only local devices for correspondence with + # `PmapSharding`s, which also contain only local devices. + executable_devices = _create_da_object( + tuple(local_device_assignment.flat)) compiled = compiler.compile_or_get_cached( pci.backend, hlo, device_assignment, compile_options, - host_callbacks) + host_callbacks, executable_devices) return UnloadedPmapExecutable( compiled=compiled, @@ -1115,7 +1123,7 @@ def from_hlo(hlo: ir.Module, jaxpr_debug_info=jaxpr_debug_info).load() -class PmapExecutable(stages.XlaExecutable): +class PmapExecutable(stages.Executable): __slots__ = ["xla_executable", "_unsafe_call", "build_unsafe_call", "fingerprint", "in_avals", "_unloaded_executable"] @@ -1135,7 +1143,7 @@ def unsafe_call(self) -> Callable[..., Any]: self._unsafe_call = self.build_unsafe_call() return self._unsafe_call # type: ignore - # -- stages.XlaExecutable overrides + # -- stages.Executable overrides def xla_extension_executable(self): return self.xla_executable @@ -1267,8 +1275,8 @@ def _handle_token_bufs(self, token_bufs, sharded_token): for token in token_buf: assert isinstance(token.sharding, sharding_impls.SingleDeviceSharding) token_devices.append(token.sharding._device_assignment[0]) - s = PositionalSharding(token_devices) - global_token_array = jax.make_array_from_single_device_arrays( + s = NamedSharding(Mesh(token_devices, 'x'), P('x')) + global_token_array = array.make_array_from_single_device_arrays( (0,), s, token_buf ) dispatch.runtime_tokens.set_token_result( @@ -1314,7 +1322,7 @@ def __call__(self, *args): out_ = [] for i, o in zip(self.mut.out_mut, out): if i is not None: - args[i]._buf = o + args[i]._buf._replace_with(o) # type: ignore else: out_.append(o) return out_ @@ -1651,67 +1659,10 @@ def check_if_any_auto( return True return False -class MismatchType(enum.Enum): - ARG_SHARDING = 0 - OUT_SHARDING = 1 - SHARDING_INSIDE_COMPUTATION = 2 - CONTEXT_DEVICES = 3 - IN_SHARDING = 4 - - def __str__(self): - if self.name == 'IN_SHARDING': - return 'explicit input sharding' - elif self.name == 'OUT_SHARDING': - return 'explicit output sharding' - elif self.name == 'CONTEXT_DEVICES': - return 'context mesh' - return f'{self.name}' - - -@dataclasses.dataclass -class DeviceAssignmentMismatch: - da: Sequence[xc.Device] - m_type: MismatchType - source_info: dispatch.SourceInfo | None - - @property - def device_ids(self) -> Sequence[int]: - return [d.id for d in self.da] - - @property - def platform(self) -> str: - return self.da[0].platform.upper() - - def _maybe_api_name(self, api_name) -> str: - return f" {api_name}'s" if self.m_type == MismatchType.CONTEXT_DEVICES else "" - - @property - def source_info_str(self): - return ( - "" if self.source_info is None - else f" at {source_info_util.summarize(self.source_info.source_info)}" - ) - - @property - def _dev_ids_plat_str(self): - return f"device ids {self.device_ids} on platform {self.platform}" - - def m_type_str(self, api_name): - return (f'{self.source_info and self.source_info.eqn_name} inside {api_name}' - if self.m_type == MismatchType.SHARDING_INSIDE_COMPUTATION else self.m_type) - - def _str(self, api_name): - return 
(f"{self._maybe_api_name(api_name)} {self.m_type_str(api_name)} with " - f"{self._dev_ids_plat_str}{self.source_info_str}") - - -class DeviceAssignmentMismatchError(Exception): - pass - ShardingInfo = tuple[ Union[JSharding, UnspecifiedValue, AUTO], - MismatchType, + stages.MismatchType, Union[Any, None], # Any is dispatch.SourceInfo to avoid circular imports ] @@ -1743,14 +1694,14 @@ def _get_and_check_device_assignment( else sh._device_assignment) if not devices: if first_sharding_info[0] != arr_device_assignment: - raise DeviceAssignmentMismatchError([ - DeviceAssignmentMismatch(*first_sharding_info), - DeviceAssignmentMismatch(arr_device_assignment, s_type, source_info)]) + raise stages.DeviceAssignmentMismatchError([ + stages.DeviceAssignmentMismatch(*first_sharding_info), + stages.DeviceAssignmentMismatch(arr_device_assignment, s_type, source_info)]) else: if devices != arr_device_assignment: - raise DeviceAssignmentMismatchError([ - DeviceAssignmentMismatch(devices, MismatchType.CONTEXT_DEVICES, None), - DeviceAssignmentMismatch(arr_device_assignment, s_type, source_info)]) + raise stages.DeviceAssignmentMismatchError([ + stages.DeviceAssignmentMismatch(devices, stages.MismatchType.CONTEXT_DEVICES, None), + stages.DeviceAssignmentMismatch(arr_device_assignment, s_type, source_info)]) if first_sharding_info is None and devices: final_device_assignment = devices elif first_sharding_info is None: @@ -1803,7 +1754,7 @@ class MutationData(NamedTuple): def _discharge_refs( jaxpr: core.ClosedJaxpr ) -> tuple[core.ClosedJaxpr, Sequence[int | None], MutationData]: - from jax._src.state.discharge import discharge_state + from jax._src.state.discharge import discharge_state # pytype: disable=import-error jaxpr, in_mut = _move_mutable_consts(jaxpr) new_jaxpr = core.ClosedJaxpr(*discharge_state(jaxpr.jaxpr, jaxpr.consts)) count = it.count(len(jaxpr.out_avals)) # new outputs are appended to the end @@ -1824,13 +1775,14 @@ def _move_mutable_consts( constvars, mutvars = partition_list(hoist, jaxpr.constvars) invars = (*jaxpr.invars, *mutvars) effects = pe.make_jaxpr_effects(constvars, invars, jaxpr.outvars, jaxpr.eqns) + # TODO(mattjj): debug_info must be updated... jaxpr = core.Jaxpr(constvars, invars, jaxpr.outvars, jaxpr.eqns, effects, closed_jaxpr.jaxpr.debug_info) return core.ClosedJaxpr(jaxpr, consts), in_mut @weakref_lru_cache def _discharge_internal_refs(jaxpr: core.ClosedJaxpr) -> core.ClosedJaxpr: - from jax._src.state.discharge import discharge_state + from jax._src.state.discharge import discharge_state # pytype: disable=import-error jaxpr_, consts = discharge_state(jaxpr.jaxpr, jaxpr.consts) jaxpr_._debug_info = jaxpr.jaxpr._debug_info return core.ClosedJaxpr(jaxpr_, consts) @@ -1876,14 +1828,14 @@ def _raise_warnings_or_errors_for_jit_of_pmap( "input and output arrays onto a single device. " "Consider removing the outer jit unless you know what you're doing. " "See https://github.com/jax-ml/jax/issues/2926. 
Or " - "use jax.experimental.shard_map instead of pmap under jit compilation.") + "use jax.shard_map instead of pmap under jit compilation.") if nreps > xb.device_count(backend): raise ValueError( f"compiling computation `{name}` that requires {nreps} replicas, but " f"only {xb.device_count(backend)} XLA devices are available.") - if xb.process_count() > 1 and ( + if xb.process_count(backend) > 1 and ( nreps > 1 or dispatch.jaxpr_has_primitive(jaxpr, "xla_pmap") ): raise NotImplementedError( @@ -2064,8 +2016,6 @@ def _default_rule(prim, num_outvars, *_, **__): @weakref_lru_cache def get_out_layouts_via_propagation(closed_jaxpr: core.ClosedJaxpr ) -> tuple[None | DeviceLocalLayout]: - from jax._src import pjit - env = {} # type: ignore jaxpr = closed_jaxpr.jaxpr @@ -2143,8 +2093,10 @@ def _discharge_refs_jaxpr(closed_jaxpr, in_shardings, in_layouts, donated_invars, out_shardings, out_layouts): if any(isinstance(e, RefEffect) for e in closed_jaxpr.effects): closed_jaxpr, inout_aliases, mut = _discharge_refs(closed_jaxpr) - in_shardings = (*in_shardings, *(c.sharding for c in mut.in_mut)) - in_layouts = (*in_layouts,) + (None,) * len(mut.in_mut) # TODO(mattjj) + in_shardings = (*in_shardings, *( + pjit.finalize_arg_sharding(c.sharding, c.committed) for c in mut.in_mut)) + in_layouts = (*in_layouts, *(c.format.dll if hasattr(c, 'format') else None + for c in mut.in_mut)) donated_invars = (*donated_invars,) + (False,) * len(mut.in_mut) out_layouts_ = iter(zip(out_shardings, out_layouts)) out_shardings, out_layouts = unzip2( @@ -2230,8 +2182,7 @@ def lower_sharding_computation( The caller of this code can pass in a singleton UNSPECIFIED because the number of out_avals might not be known at that time and lower_sharding_computation calculates the number of out_avals so it can apply - the singleton UNSPECIFIED to all out_avals. - """ + the singleton UNSPECIFIED to all out_avals.""" auto_spmd_lowering = check_if_any_auto( it.chain.from_iterable([in_shardings, out_shardings])) @@ -2274,13 +2225,23 @@ def lower_sharding_computation( unique_out_shardings = util.stable_unique(out_shardings) backend, device_assignment = _get_and_check_device_assignment( it.chain( - ((i, MismatchType.ARG_SHARDING, None) for i in unique_in_shardings), - ((o, MismatchType.OUT_SHARDING, None) for o in unique_out_shardings), - ((js, MismatchType.SHARDING_INSIDE_COMPUTATION, source_info) + ((i, stages.MismatchType.ARG_SHARDING, None) for i in unique_in_shardings), + ((o, stages.MismatchType.OUT_SHARDING, None) for o in unique_out_shardings), + ((js, stages.MismatchType.SHARDING_INSIDE_COMPUTATION, source_info) for js, source_info in unique_intermediate_shardings)), devices_from_context) unique_intermediate_shardings = [js for js, _ in unique_intermediate_shardings] + for a in global_out_avals: + if (a is not core.abstract_token and not a.sharding.mesh.empty and + a.sharding.mesh._are_all_axes_explicit and + len(device_assignment) != a.sharding.mesh.size): + raise ValueError( + f"Length of device assignment {len(device_assignment)} is not equal" + f" to the size of the mesh {a.sharding.mesh.size} of aval" + f" {a.str_short(True, True)}. Please enter your `jit` into a mesh" + " context via `jax.sharding.use_mesh`.") + # TODO(parkers): One _raw_platform has been unified with platform, # change this back to just read platform. 
platforms = lowering_platforms or ( @@ -2419,7 +2380,7 @@ def _to_logical_sharding( raise TypeError(aval) -class MeshComputation(stages.XlaLowering): +class MeshComputation(stages.Lowering): _hlo: ir.Module _executable: MeshExecutable | None @@ -2435,7 +2396,7 @@ def __init__(self, name: str, hlo: ir.Module, self.compile_args = compile_args self._executable = None - # -- stages.XlaLowering overrides + # -- stages.Lowering overrides def stablehlo(self) -> ir.Module: return self._hlo @@ -2463,14 +2424,41 @@ def cost_analysis(self) -> dict[str, float]: return xe.hlo_module_cost_analysis(backend, self.hlo().as_hlo_module()) +def get_op_sharding_from_executable( + executable) -> tuple[Sequence[xc.OpSharding], Sequence[xc.OpSharding]]: + in_op_shardings: list[xc.OpSharding] = [] + parameter_shardings_from_xla = executable.get_parameter_shardings() + if parameter_shardings_from_xla is not None: + in_op_shardings = parameter_shardings_from_xla + + out_op_shardings: list[xc.OpSharding] = [] + output_shardings_from_xla = executable.get_output_shardings() + if output_shardings_from_xla is not None: + out_op_shardings = output_shardings_from_xla + + return in_op_shardings, out_op_shardings + + +def get_pspec_from_executable( + executable, mesh: Mesh +) -> tuple[tuple[PartitionSpec, ...], tuple[PartitionSpec, ...]]: + input_op_s, output_op_s = get_op_sharding_from_executable(executable) + in_pspec: list[PartitionSpec] = [] + for s in input_op_s: + in_pspec.extend(sharding_impls.parse_flatten_op_sharding(s, mesh)) + + out_pspec: list[PartitionSpec] = [] + for s in output_op_s: + out_pspec.extend(sharding_impls.parse_flatten_op_sharding(s, mesh)) + return tuple(in_pspec), tuple(out_pspec) + + def get_out_shardings_from_executable( xla_executable, device_assignment: Sequence[xc.Device], num_out_avals: int, num_ordered_effects: int, ) -> Sequence[sharding_impls.GSPMDSharding] | None: - from jax._src import pjit - try: omk = xla_executable.get_output_memory_kinds()[0] if num_ordered_effects > 0: @@ -2486,7 +2474,7 @@ def get_out_shardings_from_executable( return [sharding_impls.GSPMDSharding.get_replicated(device_assignment, memory_kind=mk) for mk in omk] - _, out_op_shardings = pjit.get_op_sharding_from_executable(xla_executable) + _, out_op_shardings = get_op_sharding_from_executable(xla_executable) if not out_op_shardings: return None @@ -2517,14 +2505,12 @@ def _get_in_shardings_from_xla( num_ordered_effects: int ) -> Sequence[GSPMDSharding] | None: """Returns input shardings from XLA.""" - from jax._src import pjit - # When the device assignment only has 1 device, SPMD partitioner will not run. # Hence the op shardings will not be set on the `hlo_module`. 
if len(device_assignment) == 1: return [GSPMDSharding.get_replicated(device_assignment)] * num_in_avals - in_op_shardings, _ = pjit.get_op_sharding_from_executable(xla_executable) + in_op_shardings, _ = get_op_sharding_from_executable(xla_executable) if not in_op_shardings: return None @@ -2543,9 +2529,7 @@ def _get_in_shardings_from_xla( def _get_mesh_pspec_shardings_from_executable( xla_executable, mesh: Mesh ) -> tuple[Sequence[NamedSharding], Sequence[NamedSharding]]: - from jax._src import pjit - - in_pspec, out_pspec = pjit.get_pspec_from_executable(xla_executable, mesh) + in_pspec, out_pspec = get_pspec_from_executable(xla_executable, mesh) return ([NamedSharding(mesh, i) for i in in_pspec], [NamedSharding(mesh, o) for o in out_pspec]) @@ -2565,15 +2549,6 @@ def _gspmd_to_named_sharding( return sharding_impls._gspmd_to_named_sharding_via_mesh(out_s, mesh) _orig_out_sharding_handlers[NamedSharding] = _gspmd_to_named_sharding -def _gspmd_to_positional_sharding( - out_s: GSPMDSharding, out_aval, orig_in_s: PositionalSharding - ) -> PositionalSharding: - assert isinstance(out_s, GSPMDSharding) - assert isinstance(orig_in_s, PositionalSharding) - return sharding_impls._op_sharding_to_pos_sharding( - out_s._hlo_sharding, orig_in_s._device_assignment, out_s.memory_kind) -_orig_out_sharding_handlers[PositionalSharding] = _gspmd_to_positional_sharding # type: ignore - def _gspmd_to_single_device_sharding( out_s: GSPMDSharding, out_aval, orig_in_s: SingleDeviceSharding ) -> SingleDeviceSharding: @@ -2716,7 +2691,6 @@ def create_compile_options( num_partitions=num_partitions, device_assignment=xla_device_assignment, use_spmd_partitioning=spmd_lowering, - use_shardy_partitioner=config.use_shardy_partitioner.value, use_auto_spmd_partitioning=auto_spmd_lowering, env_options_overrides=compiler_options, fdo_profile=fdo_profile, @@ -2757,7 +2731,7 @@ def _cached_compilation(computation, name, mesh, spmd_lowering, fun_name=name, event=dispatch.BACKEND_COMPILE_EVENT): xla_executable = compiler.compile_or_get_cached( backend, computation, dev, compile_options, host_callbacks, - pgle_profiler) + da, pgle_profiler) return xla_executable @@ -2970,8 +2944,6 @@ def from_hlo(name: str, allow_prop_to_outputs, tuple(host_callbacks), backend, da, pmap_nreps, compiler_options_kvs, pgle_profiler) - orig_out_shardings = out_shardings - if auto_spmd_lowering: assert mesh is not None in_shardings_xla, out_shardings_xla = _get_mesh_pspec_shardings_from_executable( @@ -2994,7 +2966,7 @@ def from_hlo(name: str, xla_executable.local_devices(), len(in_shardings), len(out_shardings)) # xla_in_layouts are all either None or DeviceLocalLayout. Even default - # layout are concrete layouts and they are used in `compiled.input_layouts` + # layout are concrete layouts and they are used in `compiled.input_formats` # to return concrete layouts to users. # `dispatch_in_layouts` replaces default layouts with `None` to simplify # dispatch logic downstream. 
@@ -3085,7 +3057,7 @@ def reflatten_outputs_for_dispatch(out_tree, out_flat): return tree_util.dispatch_registry.flatten(out_unflat, None) -class MeshExecutable(stages.XlaExecutable): +class MeshExecutable(stages.Executable): __slots__ = [ "xla_executable", "_unsafe_call", "build_unsafe_call", "in_avals", "out_avals", "_in_shardings", "_out_shardings", "_auto_spmd_lowering", @@ -3121,7 +3093,7 @@ def unsafe_call(self) -> Callable[..., Any]: self._unsafe_call = self.build_unsafe_call() return self._unsafe_call # type: ignore - # -- stages.XlaExecutable overrides + # -- stages.Executable overrides def xla_extension_executable(self): return self.xla_executable @@ -3149,20 +3121,6 @@ def call(self, *args): self._kept_var_idx) return self.unsafe_call(*args) # pylint: disable=not-callable - def input_shardings(self) -> Sequence[JSharding]: - return self._in_shardings - - def output_shardings(self) -> Sequence[JSharding]: - return self._out_shardings - - def input_layouts(self): - return [Layout(l, s) - for l, s in safe_zip(self._xla_in_layouts, self._in_shardings)] - - def output_layouts(self): - return [Layout(l, s) - for l, s in safe_zip(self._xla_out_layouts, self._out_shardings)] - def create_cpp_call(self, no_kwargs, in_tree, out_tree): if not (isinstance(self.unsafe_call, ExecuteReplicated) and not self.unsafe_call.has_unordered_effects and @@ -3261,7 +3219,6 @@ def check_array_xla_sharding_layout_match( in_xla_layouts: Sequence[DeviceLocalLayout], jaxpr_debug_info: core.DebugInfo, kept_var_idx: set[int]) -> None: - from jax._src.array import ArrayImpl # jaxpr_debug_info.arg_names are before DCE, so need to DCE them. arg_names = ( [a for i, a in enumerate(jaxpr_debug_info.arg_names) @@ -3271,7 +3228,7 @@ def check_array_xla_sharding_layout_match( num_errors = 5 for arg, xs, xl, name in safe_zip( args_after_dce, in_xla_shardings, in_xla_layouts, arg_names): - if not isinstance(arg, ArrayImpl): + if not isinstance(arg, array.ArrayImpl): continue if isinstance(xs, (UnspecifiedValue, AUTO)): continue @@ -3287,11 +3244,11 @@ def check_array_xla_sharding_layout_match( 'sharding')) if (not db_xs and arg._committed and - arg.layout.device_local_layout is not None and xl is not None and - arg.layout.device_local_layout != xl): + arg.format.device_local_layout is not None and xl is not None and + arg.format.device_local_layout != xl): errors.append( ("Got input layout(s) that compiled object was called with: " - f"{arg.layout.device_local_layout} and layout(s) the computation was " + f"{arg.format.device_local_layout} and layout(s) the computation was " f"compiled with: {xl} for arg {name} with " f"shape: {arg.aval.str_short()}", 'layout')) @@ -3314,6 +3271,12 @@ def check_array_xla_sharding_layout_match( "compiled with. 
" f"Here are {num_mismatch_str}:\n{str_errors}") +def batch_spec(spec, dim, val): + too_short = dim - len(spec) + if too_short > 0: + spec += (None,) * too_short + new_partitions = tuple_insert(spec, dim, val) # type: ignore + return PartitionSpec(*new_partitions) def get_array_mapping(pspec: PartitionSpec) -> ArrayMappingOrAutoOrUnspecified: pspec = sharding_impls.prepare_axis_resources(pspec, "pspec to array_mapping") diff --git a/jax/_src/interpreters/xla.py b/jax/_src/interpreters/xla.py index 33a8992a8be4..73a57f935f5d 100644 --- a/jax/_src/interpreters/xla.py +++ b/jax/_src/interpreters/xla.py @@ -16,16 +16,16 @@ from __future__ import annotations -from collections.abc import Callable, Sequence +from collections.abc import Callable from functools import partial from typing import Any, Union import numpy as np from jax._src import core +from jax._src import deprecations from jax._src import dtypes from jax._src.abstract_arrays import numpy_scalar_types -from jax._src.core import ShapedArray from jax._src.util import safe_zip, safe_map from jax._src.typing import Shape @@ -41,11 +41,6 @@ def identity(x): return x _scalar_types = dtypes.python_scalar_dtypes.keys() -def _make_array_shape(aval: ShapedArray) -> Sequence[xc.Shape]: - aval = core.physical_aval(aval) - dtype = np.dtype('bool') if aval.dtype == dtypes.float0 else aval.dtype - return (xc.Shape.array_shape(dtype, aval.shape),) - # Utilities # HLO instructions optionally can be annotated to say how the output should be @@ -90,20 +85,6 @@ def tuple_sharding_proto(elems): ### handlers -# JAX abstract values -> XLA shapes - -def aval_to_xla_shapes(aval: core.AbstractValue) -> Sequence[xc.Shape]: - try: - return _xla_shape_handlers[type(aval)](aval) - except KeyError as err: - raise TypeError(f"No xla_shape_handler for type: {type(aval)}") from err - -_xla_shape_handlers: dict[type[core.AbstractValue], - Callable[[Any], Sequence[xc.Shape]]] = { - ShapedArray: _make_array_shape, -} -_xla_shape_handlers[core.AbstractToken] = lambda _: (xc.Shape.token_shape(),) - # IR constants @@ -120,6 +101,12 @@ def canonicalize_dtype(x): handler = canonicalize_dtype_handlers.get(typ) if handler: return handler(x) if hasattr(x, '__jax_array__'): + deprecations.warn( + 'jax-abstract-dunder-array', + ('Triggering of __jax_array__() during abstractification is deprecated.' 
+ ' To avoid this error, either explicitly convert your object using' + ' jax.numpy.array(), or register your object as a pytree.'), + stacklevel=6) return canonicalize_dtype(x.__jax_array__()) raise InvalidInputException( f"Argument '{x}' of type {type(x)} is not a valid JAX type.") diff --git a/jax/_src/jaxpr_util.py b/jax/_src/jaxpr_util.py index ab72634d3bdf..81ffb2730b1f 100644 --- a/jax/_src/jaxpr_util.py +++ b/jax/_src/jaxpr_util.py @@ -22,7 +22,8 @@ import itertools import json import types -from typing import Any, Iterator, Union +from typing import Any, Union +from collections.abc import Iterator from jax._src import core from jax._src import util @@ -33,11 +34,23 @@ zip, unsafe_zip = util.safe_zip, zip -def all_eqns(jaxpr: core.Jaxpr) -> Iterator[tuple[core.Jaxpr, core.JaxprEqn]]: +def _all_eqns( + jaxpr: core.Jaxpr, visited: set[core.Jaxpr] | None, +) -> Iterator[tuple[core.Jaxpr, core.JaxprEqn]]: for eqn in jaxpr.eqns: yield (jaxpr, eqn) for subjaxpr in core.subjaxprs(jaxpr): - yield from all_eqns(subjaxpr) + if visited is None: + yield from _all_eqns(subjaxpr, visited) + elif subjaxpr not in visited: + visited.add(subjaxpr) + yield from _all_eqns(subjaxpr, visited) + +def all_eqns( + jaxpr: core.Jaxpr, revisit_inner_jaxprs: bool = True +) -> Iterator[tuple[core.Jaxpr, core.JaxprEqn]]: + yield from _all_eqns(jaxpr, None if revisit_inner_jaxprs else set()) + def collect_eqns(jaxpr: core.Jaxpr, key: Callable): d = defaultdict(list) @@ -206,6 +219,38 @@ def pprof_equation_profile(jaxpr: core.Jaxpr) -> bytes: """ d = Counter( (eqn.source_info.traceback, eqn.primitive) - for _, eqn in all_eqns(jaxpr) + for _, eqn in all_eqns(jaxpr, revisit_inner_jaxprs=False) ) return _pprof_profile(d) + +def eqns_using_var_with_invar_index(jaxpr: core.Jaxpr, invar: core.Var) -> Iterator[tuple[core.JaxprEqn, int]]: + """Find all the equations which use invar and the positional index of its binder""" + for eqn in jaxpr.eqns: + for invar_index, eqn_var in enumerate(eqn.invars): + if eqn_var == invar: + yield eqn, invar_index + break # we found the var, no need to keep looking in this eqn + +def jaxpr_and_binder_in_params(params, index: int) -> Iterator[tuple[core.Jaxpr, core.Var]]: + for val in params.values(): + vals = val if isinstance(val, tuple) else (val,) + for v in vals: + if isinstance(v, core.Jaxpr): + if index >= len(v.invars): + raise RuntimeError(f"Failed to find index {index} in jaxpr.invars while building report") + yield v, v.invars[index] + elif isinstance(v, core.ClosedJaxpr): + if index >= len(v.jaxpr.invars): + raise RuntimeError(f"Failed to find index {index} in jaxpr.invars while building report") + yield v.jaxpr, v.jaxpr.invars[index] + +def eqns_using_var(jaxpr: core.Jaxpr, invar: core.Var) -> Iterator[core.JaxprEqn]: + """Find the leaf equations using a variable""" + # The complexity of this call is because the invar might originate from a nested jaxpr + for eqn, invar_index in eqns_using_var_with_invar_index(jaxpr, invar): + if (child_jaxprs_and_vars := tuple(jaxpr_and_binder_in_params(eqn.params, invar_index))): + for (jaxpr, invar) in child_jaxprs_and_vars: + yield from eqns_using_var(jaxpr, invar) + else: + # if the previous condition fails, there is no deeper jaxpr to explore =( + yield eqn diff --git a/jax/_src/lax/ann.py b/jax/_src/lax/ann.py index 0e037ec774b5..61d383ee29c2 100644 --- a/jax/_src/lax/ann.py +++ b/jax/_src/lax/ann.py @@ -82,7 +82,7 @@ def pmap_mips(qy, db, db_offset, db_size, k, recall_target): from jax._src.interpreters import batching from 
jax._src.interpreters import mlir from jax._src.lax import lax -from jax._src.lib import xla_client as xc +from jax._src.lib import _jax from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import func from jax._src.lib.mlir.dialects import hlo @@ -231,7 +231,7 @@ def _approx_top_k_abstract_eval(operand, *, k, reduction_dimension, if aggregate_to_topk: dims[reduction_dimension] = k elif core.is_constant_shape((reduction_input_size, k)): - dims[reduction_dimension] = xc.ops.ApproxTopKReductionOutputSize( + dims[reduction_dimension] = _jax.approx_top_k_reduction_output_size( reduction_input_size, len(dims), k, recall_target, aggregate_to_topk, reduction_input_size_override)[0] else: @@ -240,8 +240,8 @@ def _approx_top_k_abstract_eval(operand, *, k, reduction_dimension, f"either the `k` ({k}) or the " f" reduction dimension size ({reduction_input_size}) are symbolic") return (operand.update(shape=dims, dtype=operand.dtype, - weak_type=operand.weak_type), - operand.update(shape=dims, dtype=np.dtype(np.int32))) + weak_type=operand.weak_type, vma=operand.vma), + operand.update(shape=dims, dtype=np.dtype(np.int32), vma=operand.vma)) def _get_init_val_literal(op_type, is_max_k): return np.array(-np.inf if is_max_k else np.inf, dtype=op_type) diff --git a/jax/_src/lax/control_flow/__init__.py b/jax/_src/lax/control_flow/__init__.py index f89e4d53a476..44ee94e14ca2 100644 --- a/jax/_src/lax/control_flow/__init__.py +++ b/jax/_src/lax/control_flow/__init__.py @@ -34,6 +34,7 @@ while_p as while_p, ) from jax._src.lax.control_flow.conditionals import ( + BranchesPlatforms as BranchesPlatforms, cond as cond, cond_p as cond_p, switch as switch, diff --git a/jax/_src/lax/control_flow/common.py b/jax/_src/lax/control_flow/common.py index b75cbf6ac708..b90eda4e765c 100644 --- a/jax/_src/lax/control_flow/common.py +++ b/jax/_src/lax/control_flow/common.py @@ -20,17 +20,17 @@ from functools import partial from typing import Any +from jax._src import ad_util from jax._src import api_util from jax._src import core -from jax._src import linear_util as lu -from jax._src.lax import lax from jax._src import effects -from jax._src import ad_util +from jax._src import linear_util as lu from jax._src import state +from jax._src.lax import lax from jax._src.util import weakref_lru_cache, safe_map, partition_list from jax._src.interpreters import partial_eval as pe -from jax.tree_util import tree_map, tree_unflatten, keystr, PyTreeDef -from jax._src.tree_util import equality_errors_pytreedef +from jax._src.tree_util import (equality_errors_pytreedef, tree_map, + tree_unflatten, keystr, PyTreeDef) map, unsafe_map = safe_map, map @@ -184,9 +184,8 @@ def _pad_jaxpr_constvars(jaxpr, i, canonical_ref_avals, canonical_ref_indices, canonical_non_ref_avals, canonical_non_ref_indices): is_ref = [isinstance(v.aval, state.AbstractRef) for v in jaxpr.constvars] nonref_constvars, ref_constvars = partition_list(is_ref, jaxpr.constvars) - newvar = core.gensym(suffix='_') - padded_ref_constvars = map(newvar, canonical_ref_avals) - padded_non_ref_constvars = map(newvar, canonical_non_ref_avals) + padded_ref_constvars = map(core.Var, canonical_ref_avals) + padded_non_ref_constvars = map(core.Var, canonical_non_ref_avals) for canonical_id, ref_var in zip(canonical_ref_indices[i], ref_constvars): padded_ref_constvars[canonical_id] = ref_var for canonical_id, non_ref_var in zip(canonical_non_ref_indices[i], nonref_constvars): diff --git a/jax/_src/lax/control_flow/conditionals.py b/jax/_src/lax/control_flow/conditionals.py index 
63896cc2a0bf..c270b54f8713 100644 --- a/jax/_src/lax/control_flow/conditionals.py +++ b/jax/_src/lax/control_flow/conditionals.py @@ -23,7 +23,9 @@ import operator from typing import Any, TypeVar -from jax.tree_util import tree_flatten, tree_unflatten +from jax._src.tree_util import ( + tree_flatten, tree_unflatten, tree_flatten_with_path, keystr, + equality_errors_pytreedef) from jax._src import ad_util from jax._src import api_util from jax._src import config @@ -44,19 +46,15 @@ from jax._src.interpreters import xla from jax._src.lax import lax from jax._src.traceback_util import api_boundary -from jax._src.util import (safe_map, split_list, partition_list) +from jax._src.typing import ArrayLike +from jax._src.util import safe_map, split_list, partition_list, unzip2 from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import hlo import numpy as np from jax._src.lax.control_flow.common import ( - _avals_short, - _check_tree_and_avals, - _initial_style_jaxprs_with_common_consts, - _make_closed_jaxpr, - _prune_zeros, - _typecheck_param, - ) + _avals_short, _typecheck_param, _initial_style_jaxprs_with_common_consts, + _make_closed_jaxpr, _prune_zeros) map, unsafe_map = safe_map, map @@ -130,9 +128,17 @@ def switch(index, branches, *operands): lo = np.array(0, np.int32) hi = np.array(len(branches) - 1, np.int32) index = lax.clamp(lo, index, hi) + return _switch_internal(index, branches, operands, + branches_platforms=None) + +def _switch_internal( + index: ArrayLike, + branches: Sequence[Callable], + operands: Sequence[ArrayLike], *, + branches_platforms: BranchesPlatforms | None): if (config.disable_jit.value and core.is_concrete(index)): - return branches[int(index)](*operands) + return branches[int(index)](*operands) # type: ignore dbgs = [api_util.debug_info("switch", branch, operands, {}) for branch in branches] @@ -147,16 +153,33 @@ def switch(index, branches, *operands): if config.mutable_array_checks.value: api_util._check_no_aliased_closed_over_refs(dbgs[0], (*jaxprs[0].consts, *consts), ops) for i, (out_tree, jaxpr) in enumerate(zip(out_trees[1:], jaxprs[1:])): - _check_tree_and_avals("branch 0 output", - out_trees[0], jaxprs[0].out_avals, - f"branch {i + 1} output", - out_tree, jaxpr.out_avals) + _check_branch_outputs( + "switch", "branch 0", f"branch{i+1}", branches[0], branches[i+1], + out_trees[0], out_tree, jaxprs[0].out_avals, jaxpr.out_avals) + # prune passthrough outputs + fwds = [pe._jaxpr_forwarding(jaxpr.jaxpr) for jaxpr in jaxprs] + in_fwd = [xs[0] if len(set(xs)) == 1 else None for xs in zip(*fwds)] + keep = [f is None for f in in_fwd] + jaxprs = [pe.prune_closed_jaxpr_outputs(jaxpr, keep) for jaxpr in jaxprs] + joined_effects = core.join_effects(*(jaxpr.effects for jaxpr in jaxprs)) disallowed_effects = effects.control_flow_allowed_effects.filter_not_in(joined_effects) if disallowed_effects: raise NotImplementedError( f'Effects not supported in `switch`: {disallowed_effects}') - out = cond_p.bind(index, *consts, *ops, branches=tuple(jaxprs)) + jaxprs = [replace_jaxpr_effects(jaxpr, joined_effects) for jaxpr in jaxprs] + params = dict(branches=tuple(jaxprs)) + if branches_platforms is not None: + params["branches_platforms"] = branches_platforms + out = cond_p.bind(index, *consts, *ops, **params) + out_ = iter(out) + + all_inputs = [*consts, *ops] + out = [ + next(out_) if fwd is None else lax.asarray(all_inputs[fwd]) + for fwd in in_fwd + ] + assert next(out_, None) is None return tree_unflatten(out_trees[0], out) @@ -255,11 +278,11 @@ def cond(pred, 
true_fun, false_fun, *operands): true_jaxpr.out_avals + false_jaxpr.out_avals): raise ValueError("Cannot return `Ref`s from `cond`.") - _check_tree_and_avals("true_fun output", - out_tree, true_jaxpr.out_avals, - "false_fun output", - false_out_tree, false_jaxpr.out_avals) - # prune passhtrough outputs + _check_branch_outputs( + 'cond', 'true_fun', 'false_fun', true_fun, false_fun, out_tree, + false_out_tree, true_jaxpr.out_avals, false_jaxpr.out_avals) + + # prune passthrough outputs true_fwds = pe._jaxpr_forwarding(true_jaxpr.jaxpr) false_fwds = pe._jaxpr_forwarding(false_jaxpr.jaxpr) in_fwd = [i if i == j else None for i, j in zip(true_fwds, false_fwds)] @@ -278,7 +301,6 @@ def cond(pred, true_fun, false_fun, *operands): true_jaxpr = replace_jaxpr_effects(true_jaxpr, joined_effects) out = cond_p.bind(index, *consts, *ops, branches=(false_jaxpr, true_jaxpr)) - num_consts = len(consts) out_ = iter(out) all_inputs = [*consts, *ops] @@ -289,6 +311,90 @@ def cond(pred, true_fun, false_fun, *operands): assert next(out_, None) is None return tree_unflatten(out_tree, out) +def _check_branch_outputs( + api_name, name1, name2, f1, f2, out_tree1, out_tree2, out_avals1, + out_avals2) -> None: + info1 = api_util.fun_sourceinfo(f1) + info2 = api_util.fun_sourceinfo(f2) + try: + outs1 = tree_unflatten(out_tree1, out_avals1) + except: + paths = [None] * len(out_avals1) + component = lambda _: '' + else: + leaves_and_paths, _ = tree_flatten_with_path(outs1) + paths, _ = unzip2(leaves_and_paths) # type: ignore + component = lambda p: f' at path {keystr(p)}' if p else '' + + if out_tree1 != out_tree2: + diffs = [f'{name1} output{component(p)} is a {thing1} but ' + f'{name2} output{component(p)} is a {thing2}, so {expl}' + for p, thing1, thing2, expl + in equality_errors_pytreedef(out_tree1, out_tree2)] + + if len(diffs) == 0: + return # the trees may have different aux data, but structures are same + elif len(diffs) == 1: + differences = f'{diffs[0]}.\n' + else: + differences = ('\n'.join(f' * {d};\n' for d in diffs[:-1]) + + f' * {diffs[-1]}.\n') + + raise TypeError( + f'{api_name} branch outputs must have the same pytree structure, but ' + 'they differ:\n\n' + f'{name1} is {info1}\n' + f'{name2} is {info2}\n\n' + f'{differences}\n' + f'Revise {name1} and/or {name2} so that they have the same pytree ' + 'structure.') + + if not all(map(core.typematch, out_avals1, out_avals2)): + diffs = [f'the output of {name1}{component(p)} has type {a1.str_short()}' + f' but the corresponding output of {name2} has type ' + f'{a2.str_short()}{core.aval_mismatch_extra(a1, a2)}' + for p, a1, a2 in zip(paths, out_avals1, out_avals2) + if not core.typematch(a1, a2)] + if len(diffs) == 0: + return # seems unreachable but in any case we don't have a good error msg + elif len(diffs) == 1: + differences = f'{_capitalize(diffs[0])}.\n' + else: + differences = ('\n'.join(f' * {d};' for d in diffs[:-1]) + + f'\n * {diffs[-1]}.\n') + + pvary_applications = [ + f'applying `jax.lax.pvary(..., {tuple(a1.vma - a2.vma)})` ' + f'to the output of {n}{component(p)}' + for p, aval1, aval2 in zip(paths, out_avals1, out_avals2) + for n, a1, a2 in [(name1, aval2, aval1), (name2, aval1, aval2)] + if not core.typematch(a1, a2) and + isinstance(a1, core.ShapedArray) and isinstance(a2, core.ShapedArray) + and a1.vma != a2.vma and a2.vma - a1.vma] + + if not pvary_applications: + pvary_msg = '' + elif len(pvary_applications) == 1: + pvary_msg = f'This might be fixed by {pvary_applications[0]}.\n' + else: + pvary_msg = ('This might be fixed by:\n' + 
+ '\n'.join(f' * {d};' for d in pvary_applications[:-1]) + + f'\n * {pvary_applications[-1]}.\n') + if pvary_msg: + pvary_msg += ("See https://docs.jax.dev/en/latest/notebooks/shard_map.html#scan-vma " + "for more information.\n\n") + + raise TypeError( + f'{api_name} branches must have equal output types but they differ.\n\n' + f'{name1} is {info1}\n' + f'{name2} is {info2}\n\n' + f'{differences}\n' + f'{pvary_msg}' + f'Revise {name1} and/or {name2} so that all output types match.') + + +def _capitalize(s): + # s.capitalize() converts s[1:] to lowercase which we don't want. + return s[0].capitalize() + s[1:] + @api_boundary @functools.wraps(_cond) def cond(*args, **kwargs): @@ -347,6 +453,15 @@ def _cond_abstract_eval(*avals: core.AbstractValue, if disallowed_effects: raise NotImplementedError( f'Effects not supported in `cond`: {disallowed_effects}') + b0_vma = [o.vma for o in branches[0].out_avals] + for branch in branches[1:]: + b_vma = [o.vma for o in branch.out_avals] + if b0_vma != b_vma: + raise Exception("The branches of cond produced mismatched varying manual " + f"axes. Got {b0_vma} and {b_vma}. Please open an issue " + "at https://github.com/jax-ml/jax/issues, and as a " + "temporary workaround pass the check_vma=False argument " + "to `jax.shard_map`") return branches[0].out_avals, joined_effects def _bcast_select(pred, on_true, on_false): @@ -361,7 +476,7 @@ def _bcast_select_n(pred, *cases): pred = lax.broadcast_in_dim(pred, np.shape(cases[0]), idx) return lax.select_n(pred, *cases) -def _cond_batching_rule(axis_data, args, dims, branches): +def _cond_batching_rule(axis_data, args, dims, *, branches, **params): index, *ops = args index_dim, *op_dims = dims # TODO(sharadmv): clean this up by adding a specific blocklist @@ -375,6 +490,11 @@ def _cond_batching_rule(axis_data, args, dims, branches): raise NotImplementedError( "IO effect not supported in vmap-of-cond.") + if "branches_platforms" in params and (index_dim is not batching.not_mapped): + # If we end up with a mapped index for a platform_dependent cond, we can + # replace the index with a fresh call to platform_index. See #29329. + index = platform_index_p.bind(platforms=params["branches_platforms"]) + index_dim = batching.not_mapped if index_dim is not batching.not_mapped: # Convert to a lax.select. 
While we could get away with not broadcasting @@ -415,10 +535,11 @@ def _cond_batching_rule(axis_data, args, dims, branches): for jaxpr in branches) out_dims = [0 if b else batching.not_mapped for b in out_bat] - out = cond_p.bind(index, *ops, branches=branches_batched) + out = cond_p.bind(index, *ops, branches=branches_batched, + **params) return out, out_dims -def _cond_jvp(primals, tangents, branches): +def _cond_jvp(primals, tangents, *, branches, **params): nonzeros = [type(t) is not ad_util.Zero for t in tangents] index_nz, *ops_nz = nonzeros @@ -435,15 +556,16 @@ def _cond_jvp(primals, tangents, branches): _, *ops_dot = tangents ops_dot = _prune_zeros(ops_dot) - out = cond_p.bind(index, *ops, *ops_dot, branches=branches_jvp) + out = cond_p.bind(index, *ops, *ops_dot, branches=branches_jvp, + **params) out_primals, out_tangents = split_list(out, [len(out_nz)]) out_tangents_iter = iter(out_tangents) out_tangents = [next(out_tangents_iter) if nz else ad_util.Zero.from_primal_value(p) for p, nz in zip(out_primals, out_nz)] return out_primals, out_tangents -def _cond_partial_eval(trace, *tracers, branches): - in_unknowns = [t.pval[0] is not None for t in tracers] +def _cond_partial_eval(trace, *tracers, branches, **params): + in_unknowns = [not t.pval.is_known() for t in tracers] index_uk, *ops_uk = in_unknowns if any(isinstance(eff, RefEffect) for branch in branches for eff in branch.jaxpr.effects): @@ -453,7 +575,7 @@ def _cond_partial_eval(trace, *tracers, branches): if index_uk: # When the branch index is unknown, we stage out the whole cond. # TODO(mattjj): remove this path when old remat is removed - params = dict(branches=branches) + params = dict(branches=branches, **params) return trace.default_process_primitive(cond_p, tracers, params) branches_out_uks = [] @@ -483,7 +605,8 @@ def _cond_partial_eval(trace, *tracers, branches): for j in branches_known[1:]) in_consts = [t.pval.get_known() for t in tracers if t.pval.is_known()] - out_consts_res = cond_p.bind(*in_consts, branches=branches_known) + out_consts_res = cond_p.bind(*in_consts, branches=branches_known, + **params) out_consts, res = split_list(out_consts_res, [len(out_consts_res) - num_res]) index_tracer = trace.instantiate_const(tracers[0]) @@ -492,11 +615,11 @@ def _cond_partial_eval(trace, *tracers, branches): res_tracers = map(trace.new_instantiated_const, res) out_tracers = [pe.JaxprTracer(trace, pe.PartialVal.unknown(aval), None) for aval in branches_unknown[0].out_avals] - params = dict(branches=branches_unknown) + params = dict(branches=branches_unknown, **params) name_stack = source_info_util.current_name_stack()[len(trace.name_stack):] source = source_info_util.current().replace(name_stack=name_stack) eqn = pe.new_eqn_recipe( - [index_tracer] + res_tracers + ops_tracers, out_tracers, cond_p, params, + trace, [index_tracer] + res_tracers + ops_tracers, out_tracers, cond_p, params, core.join_effects(*(j.effects for j in branches_unknown)), source) for t in out_tracers: t.recipe = eqn return util.merge_lists(out_uks, out_consts, out_tracers) @@ -505,6 +628,7 @@ def _cond_partial_eval(trace, *tracers, branches): def _cond_partial_eval_custom(saveable, unks_in, inst_in, eqn): index_uk, *ops_uk = unks_in branches = eqn.params['branches'] + eqn_rest_params = dict(k_v for k_v in eqn.params.items() if k_v[0] != 'branches') # Instantiate all inputs (b/c jaxpr_staged will take all inputs). 
new_inst = [x for x, inst in zip(eqn.invars, inst_in) @@ -555,13 +679,12 @@ def _cond_partial_eval_custom(saveable, unks_in, inst_in, eqn): for j in branches_known[1:]) # Create residual variables. - newvar = core.gensym() - res_binders = map(newvar, all_res_avals) + res_binders = map(core.Var, all_res_avals) # Build the known eqn. ins_known, _ = partition_list(unks_in, eqn.invars) # includes index invar out_binders_known, _ = partition_list(unks_out, eqn.outvars) - params_known = dict(branches=branches_known) + params_known = dict(branches=branches_known, **eqn_rest_params) effects_known = _join_cond_effects(branches_known) eqn_known = pe.new_jaxpr_eqn( ins_known, [*out_binders_known, *res_binders], cond_p, params_known, @@ -569,7 +692,7 @@ def _cond_partial_eval_custom(saveable, unks_in, inst_in, eqn): # Build the staged eqn. _, out_binders_staged = partition_list(inst_out, eqn.outvars) - params_staged = dict(branches=branches_staged) + params_staged = dict(branches=branches_staged, **eqn_rest_params) effects_staged = _join_cond_effects(branches_staged) eqn_staged = pe.new_jaxpr_eqn( [eqn.invars[0], *res_binders, *eqn.invars[1:]], out_binders_staged, @@ -641,8 +764,7 @@ def f_aug(*args): def _join_cond_pe_staged_jaxpr_inputs(jaxprs: Sequence[core.ClosedJaxpr], all_res_avals, res_aval_indices_per_jaxpr): - newvar = core.gensym(suffix='_') - all_res_vars = map(newvar, all_res_avals) + all_res_vars = map(core.Var, all_res_avals) def augment_jaxpr(jaxpr: core.ClosedJaxpr, res_indices) -> core.ClosedJaxpr: num_res = len(res_indices) @@ -715,7 +837,7 @@ def transposed(*args): debug_info=jaxpr.jaxpr.debug_info), res_avals + jaxpr.out_avals) -def _cond_transpose(cts, *args, branches): +def _cond_transpose(cts, *args, branches, **params): index, *ops = args assert type(index) is not ad.UndefinedPrimal linear = [type(x) is ad.UndefinedPrimal for x in ops] @@ -735,7 +857,8 @@ def _cond_transpose(cts, *args, branches): res = ops[:num_res] cts = map(ad.instantiate_zeros, cts) - out = cond_p.bind(index, *res, *cts, branches=branches_trans) + out = cond_p.bind(index, *res, *cts, branches=branches_trans, + **params) assert all(map(core.typecheck, lin_in_avals, out)) out_iter = iter(out) @@ -743,7 +866,8 @@ def _cond_transpose(cts, *args, branches): assert next(out_iter, None) is None return [None] + out -def _cond_typecheck(bind_time, *in_atoms, branches): +def _cond_typecheck(bind_time, *in_atoms, branches, **params): + del params if not bind_time: _, *in_atoms = in_atoms avals = [x.aval for x in in_atoms] @@ -797,6 +921,16 @@ def _cond_typecheck(bind_time, *in_atoms, branches): f'called with operands of type {_avals_short(op_avals)}') return jaxpr0.out_avals, joined_effects + +BranchesPlatforms = tuple[tuple[str, ...] | None, ...] +# cond_p takes an optional branches_platforms param of type `BranchesPlatforms` +# when it is a `platform_dependent` conditional. +# In that case, `branches_platforms` is a tuple as long +# as `branches` and for each branch it specifies the lowering platforms it +# corresponds to. The last element, corresponding to the last branch, +# can be `None` to represent a default match-all-lowering-platforms. +# The index argument of a `platform_dependent` cond is always a +# `platform_index` primitive. 
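For orientation, a minimal usage sketch (hypothetical callables and values) of `lax.platform_dependent`, whose lowering the optional `branches_platforms` parameter described above encodes on the resulting cond:

import jax
import jax.numpy as jnp
from jax import lax

def _cuda_impl(x):     # placeholder CUDA-specific computation
  return x * 2.0

def _default_impl(x):  # portable fallback
  return x + x

@jax.jit
def f(x):
  # Per the comment above, this lowers to a cond whose index comes from the
  # platform_index primitive; with one named platform plus a default, the
  # branches_platforms param would be (('cuda',), None).
  return lax.platform_dependent(x, cuda=_cuda_impl, default=_default_impl)

print(f(jnp.arange(3.0)))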
cond_p = core.Primitive('cond') cond_p.multiple_results = True cond_p.skip_canonicalization = True @@ -812,7 +946,39 @@ def _cond_typecheck(bind_time, *in_atoms, branches): pe.dce_rules[cond_p] = _cond_dce_rule batching.ragged_prop_rules[cond_p] = batching.ragged_mask_assert_no_op_rule -def _cond_lowering(ctx, index, *args, branches): +def _cond_lowering(ctx, index, *args, branches, + **params): + if (branches_platforms := params.get("branches_platforms", None)) is not None: + branches_kept: list[core.ClosedJaxpr] = [] + index_to_kept_index: dict[int, int] = {} + for p in mlir._platforms_for_eqn(ctx): + # Each `p` must appear in exactly one branches_platforms, or in the + # last default branch. Otherwise, platform_index lowering would have + # failed already. + for b_idx, b_platforms in enumerate(branches_platforms): + if b_platforms is None or p in b_platforms: + if b_idx not in index_to_kept_index: + index_to_kept_index[b_idx] = len(branches_kept) + branches_kept.append(branches[b_idx]) + break + else: + assert False, p + + # Compute the new index into branches_keep + i32_type = ir.RankedTensorType.get([], mlir.dtype_to_ir_type(dtypes.dtype(np.int32))) + kept_index_case_op = hlo.CaseOp([i32_type], + index=index, + num_branches=len(branches)) + for i in range(len(branches)): + branch = kept_index_case_op.regions[i].blocks.append() + with ir.InsertionPoint(branch): + kept_i = np.int32(index_to_kept_index.get(i, 0)) + hlo.return_([mlir.ir_constant(kept_i)]) + + index = kept_index_case_op + branches = branches_kept + assert branches, "platform_index lowering should have failed first" + joined_effects = core.join_effects(*(branch.effects for branch in branches)) ordered_effects = list(effects.ordered_effects.filter_in(joined_effects)) num_tokens = len(ordered_effects) @@ -849,7 +1015,8 @@ def _cond_lowering(ctx, index, *args, branches): mlir.register_lowering(cond_p, _cond_lowering) @register_partial_discharge_rule(cond_p) -def _cond_state_discharge_rule(should_discharge, in_avals, out_avals, index, *args, branches): +def _cond_state_discharge_rule(should_discharge, in_avals, out_avals, index, *args, + branches, **params): assert not should_discharge[0], "Can't discharge the index." discharged_branches = tuple( discharge_state(branch.jaxpr, (), should_discharge=should_discharge[1:])[0] @@ -878,7 +1045,8 @@ def _cond_state_discharge_rule(should_discharge, in_avals, out_avals, index, *ar if fwd is None]), ()) for branch in discharged_branches ) - out_vals_no_fwd = cond_p.bind(index, *args, branches=new_branches) + out_vals_no_fwd = cond_p.bind(index, *args, branches=new_branches, + **params) out_vals, out_ref_vals_no_fwd = util.split_list(out_vals_no_fwd, [len(out_avals)]) # Insert forwarded values into reference outputs ref_val_no_fwd_iter = iter(out_ref_vals_no_fwd) @@ -943,50 +1111,41 @@ def other_platforms_code(*args): ... The value ``per_platform[execution_platform](*args)``. 
""" # Join identical branches - platform_branches: list[tuple[list[str], Callable]] = [] + branches_platforms_list: list[tuple[list[str], Callable]] = [] for pname, pbranch in per_platform.items(): + if not callable(pbranch): + raise TypeError(f"lax.platform_dependent: the '{pname}' branch must " + "be a callable.") if pname == "gpu": raise ValueError("Use 'cuda' or 'rocm' for lax.platform_dependent.") - for ps, b in platform_branches: + for ps, b in branches_platforms_list: if b == pbranch: ps.append(pname) break else: - platform_branches.append(([pname], pbranch)) - - platforms_lists, branches = util.unzip2(platform_branches) - platform_index = platform_index_p.bind( - platforms=tuple(tuple(ps) for ps in platforms_lists), - has_default=(default is not None)) + branches_platforms_list.append(([pname], pbranch)) + platforms_lists, branches = util.unzip2(branches_platforms_list) + branches_platforms: BranchesPlatforms = tuple(tuple(ps) for ps in platforms_lists) if default is not None: + if not callable(default): + raise TypeError("lax.platform_dependent: the 'default' branch must " + "be a callable.") branches = branches + (default,) - # Use a switch, to get the proper transformation rules for free. Since - # platform index has no dependence on the input data, it won't be vectorized - # under vmap. - # If the switch and the platform_index_p above are in the same compilation - # unit then constant-folding will remove the unnecessary branches. However, - # if we run in eager mode the switch below cannot be constant-folded and - # the compilation may fail if some of the branches contain custom calls not - # recognized on the compilation platform. Detect eager mode and keep only the - # needed branch. - try: - # Note/TODO(mvoz): This actually rarely seems to concretize - we could look into - # core.ensure_compile_time_eval to get better single-branch selection. - platform_index_concrete = core.concrete_or_error(operator.index, platform_index) - except core.ConcretizationTypeError: - return switch(platform_index, branches, *args) - else: - assert 0 <= platform_index_concrete < len(branches) - return branches[platform_index_concrete](*args) + branches_platforms = branches_platforms + (None,) # type: ignore + platform_index = platform_index_p.bind(platforms=branches_platforms) + + if core.is_concrete(platform_index): + return branches[int(platform_index)](*args) + return _switch_internal(platform_index, branches, args, + branches_platforms=branches_platforms) + # A primitive to compute the index of a platform into a list of platforms. # Args: -# platforms: Sequence[Sequence[str]]: a sequence of sequences of platform -# names. If the current lowering platform is in one of the inner sequences -# returns the index of that inner sequence in the outer sequence. -# has_default: if True, and if the lowering platform is not found in -# `platforms` then return `len(platforms)`. Otherwise, raise an error. +# platforms: BranchesPlatforms. If the current lowering +# platform is in one of the inner tuples returns the index of that inner +# tuple in the outer tuple. 
platform_index_p = core.Primitive("platform_index") platform_index_p.multiple_results = False platform_index_p.def_impl(functools.partial(dispatch.apply_primitive, @@ -998,25 +1157,25 @@ def _platform_index_aval(*_, **__): def _platform_index_lowering(ctx: mlir.LoweringRuleContext, *, - platforms: Sequence[Sequence[str]], - has_default: bool): - def lower_constant( - ctx: mlir.LoweringRuleContext, *, i: int - ) -> Sequence[ir.Value]: + platforms: BranchesPlatforms): + def lower_constant(ctx: mlir.LoweringRuleContext, *, + i: int) -> Sequence[ir.Value]: v = mlir.ir_constant(np.int32(i)) - assert isinstance(v, ir.Value), v return [v] + platform_rules: dict[str, mlir.LoweringRule] = {} + default_rule = None for i, ps in enumerate(platforms): rule = partial(lower_constant, i=i) - for p in ps: - platform_rules[p] = rule + if ps is None: + default_rule = rule + else: + for p in ps: + platform_rules[p] = rule - default_rule = ( - partial(lower_constant, i=len(platforms)) if has_default else None) return mlir.lower_per_platform( ctx, - f"platform_index(platforms={platforms}, has_default={has_default})", + f"platform_index(platforms={platforms})", platform_rules, default_rule, effects.no_effects) mlir.register_lowering(platform_index_p, _platform_index_lowering) diff --git a/jax/_src/lax/control_flow/for_loop.py b/jax/_src/lax/control_flow/for_loop.py index fc7ebde4cbea..773061c59bd4 100644 --- a/jax/_src/lax/control_flow/for_loop.py +++ b/jax/_src/lax/control_flow/for_loop.py @@ -20,14 +20,13 @@ import operator from typing import Any, Generic, TypeVar -from jax import lax from jax._src import api_util from jax._src.interpreters import ad from jax._src.interpreters import batching from jax._src.interpreters import mlir from jax._src.interpreters import partial_eval as pe -from jax.tree_util import (tree_flatten, tree_structure, tree_unflatten, - treedef_tuple, tree_map, tree_leaves, PyTreeDef) +from jax._src.tree_util import (tree_flatten, tree_structure, tree_unflatten, + treedef_tuple, tree_map, tree_leaves, PyTreeDef) from jax._src import ad_util from jax._src import core @@ -35,6 +34,7 @@ from jax._src import dtypes from jax._src import linear_util as lu from jax._src import source_info_util +from jax._src.lax import lax from jax._src.state.types import (ReadEffect, AbstractRef, StateEffect) from jax._src.state import discharge as state_discharge from jax._src.state import primitives as state_primitives @@ -272,7 +272,7 @@ def while_body(carry): state = body(i, state) i = i + 1 return i, state - _, state = lax.while_loop(cond, while_body, (i, state)) + _, state = loops.while_loop(cond, while_body, (i, state)) return state mlir.register_lowering(for_p, mlir.lower_fun(_for_impl, multiple_results=True)) @@ -498,7 +498,7 @@ def _for_partial_eval(trace: pe.JaxprTrace, *tracers: pe.JaxprTracer, assert len(unknown_inputs) == len(res_ref_unknown_outputs) assert len(unknown_inputs) == len(jaxpr_unknown.invars) - 1 - eqn = pe.new_eqn_recipe(unknown_inputs, res_ref_unknown_outputs, + eqn = pe.new_eqn_recipe(trace, unknown_inputs, res_ref_unknown_outputs, for_p, dict(jaxpr=jaxpr_unknown, nsteps=nsteps, reverse=reverse, which_linear=which_linear_unknown, diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index 3084fa722977..65162ea15305 100644 --- a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -17,7 +17,7 @@ from collections.abc import Callable, Sequence from functools import partial import inspect -import itertools +import itertools as it 
import operator from typing import Any, TypeVar import weakref @@ -53,31 +53,22 @@ _initial_style_jaxpr_attrs, _make_closed_jaxpr_attrs, _prune_zeros, _typecheck_param) from jax._src.lax.other import logaddexp +from jax._src.pjit import auto_axes, PartitionSpec as P +from jax._src.mesh import get_abstract_mesh from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import hlo from jax._src.state import discharge as state_discharge from jax._src.traceback_util import api_boundary from jax._src.tree_util import equality_errors from jax._src.typing import Array +from jax._src.attrs import jax_setattr, jax_getattr, jax_extendattr from jax._src.util import ( - merge_lists, - partition_list, - safe_map, - safe_zip, - split_list, - split_list_checked, - unzip2, - weakref_lru_cache, -) + merge_lists, partition_list, safe_map, safe_zip, split_list, + split_list_checked, unzip2, weakref_lru_cache,) from jax._src import xla_bridge as xb -from jax.tree_util import ( - keystr, - tree_flatten, - tree_flatten_with_path, - tree_map, - tree_unflatten, - treedef_is_leaf, -) +from jax._src.tree_util import ( + keystr, tree_flatten, tree_flatten_with_path, tree_map, tree_unflatten, + treedef_is_leaf) import numpy as np _map = safe_map @@ -178,6 +169,11 @@ def scan(f, init, xs, length=None): :py:func:`scan` compiles ``f``, so while it can be combined with :py:func:`jit`, it's usually unnecessary. + .. note:: + :func:`scan` is designed for iterating with a static number of iterations. + For iteration with a dynamic number of iterations, use :func:`fori_loop` + or :func:`while_loop`. + Args: f: a Python function to be scanned of type ``c -> a -> (c, b)``, meaning that ``f`` accepts two arguments where the first is a value of the loop @@ -202,7 +198,7 @@ def scan(f, init, xs, length=None): a single iteration of a loop. If an integer is provided, it determines how many unrolled loop iterations to run within a single rolled iteration of the loop. If a boolean is provided, it will determine if the loop is - competely unrolled (i.e. `unroll=True`) or left completely rolled (i.e. + completely unrolled (i.e. `unroll=True`) or left completely rolled (i.e. `unroll=False`). _split_transpose: experimental optional bool specifying whether to further split the transpose into a scan (computing activation gradients), and a @@ -239,7 +235,9 @@ def scan(f, init, xs, length=None): try: length = int(length) except core.ConcretizationTypeError as err: - msg = 'The `length` argument to `scan` expects a concrete `int` value.' + msg = ('The `length` argument to `scan` expects a concrete `int` value.' + ' For scan-like iteration with a dynamic length, use `while_loop`' + ' or `fori_loop`.') raise core.ConcretizationTypeError(length, msg) from None # type: ignore[arg-type] if not all(length == l for l in lengths): msg = ("scan got `length` argument of {} which disagrees with " @@ -291,8 +289,18 @@ def _create_jaxpr(init): if len(out_tree_children) != 2: msg = "scan body output must be a pair, got {}." 
raise TypeError(msg.format(tree_unflatten(out_tree, jaxpr.out_avals))) - _, carry_avals_out, _ = split_list( - jaxpr.out_avals, [len(attrs_tracked), out_tree_children[0].num_leaves]) + + if attrs_tracked: + appends_out = [k for _, t, (_, _, k) in attrs_tracked + for k in [k in (pe.Append, pe.ListAttr)] * t.num_leaves] + jaxpr = pe.move_outvars_to_back( + jaxpr, appends_out + [False] * (len(jaxpr.out_avals) - len(appends_out))) + num_attr_carry = sum(init_tree.num_leaves for init_tree, _, (_, _, kind) + in attrs_tracked if kind in (pe.ReadWrite, pe.BoxAttr)) + _, carry_avals_out, _ = split_list( + jaxpr.out_avals, [num_attr_carry, out_tree_children[0].num_leaves]) + else: + carry_avals_out, _ = split_list(jaxpr.out_avals, [out_tree_children[0].num_leaves]) return (init_flat, carry_avals, carry_avals_out, init_tree, in_flat, jaxpr, consts, out_tree, out_tree_children, attrs_tracked) @@ -308,6 +316,9 @@ def _create_jaxpr(init): init_flat, carry_avals, carry_avals_out, init_tree, *rest = _create_jaxpr(init) in_flat, jaxpr, consts, out_tree, out_tree_children, attrs_tracked = rest num_carry = len(init_flat) + num_xs = len(x_avals) + num_ys = len(jaxpr.out_avals) - num_carry + del init_flat _check_carry_type('scan body', f, init, out_tree_children[0], carry_avals_out) disallowed_effects = effects.control_flow_allowed_effects.filter_not_in(jaxpr.effects) @@ -323,39 +334,123 @@ def _create_jaxpr(init): unroll = max(length, 1) if unroll else 1 if unroll < 1: raise ValueError("`unroll` must be a `bool` or a positive `int`.") + if attrs_tracked: in_state = _get_states(attrs_tracked) - in_carry, in_ext = split_list(in_flat, [num_carry]) - in_flat = [*in_state, *in_carry, *in_ext] - num_carry += len(attrs_tracked) + in_flat = [*in_state, *in_flat] + num_carry += len(in_state) + + # If the body forwards an input carry to an output carry, that input is + # read-only and can be moved to be a const. Doing so can lead to efficiency + # wins, e.g. if the scan is inside a cond with a batched predicate. + carry_fwd, ext_fwd = split_list(pe._jaxpr_forwarding(jaxpr.jaxpr), [num_carry]) + move_to_const = [len(consts) + i == f for i, f in enumerate(carry_fwd)] + if any(move_to_const): + jaxpr = pe.prune_closed_jaxpr_outputs( + jaxpr, [not m for m in move_to_const] + [True] * num_ys) + jaxpr = pe.move_binders_to_front( + jaxpr, [False] * len(consts) + move_to_const + [False] * num_xs) + in_flat, new_consts = partition_list(move_to_const + [False] * num_xs, in_flat) + consts = [*new_consts, *consts] + num_carry -= len(new_consts) + + # When an extensive output is forwarded from an extensive input, we can + # avoid copying it by pruning it from the jaxpr and forwarding manually. We + # don't need to update the indexing based on the optimization above since it + # doesn't change the total number of consts and carries combined, and + # `ext_fwd` already only includes the extensive outputs. But, we do remove + # the number of consts from the index since we're going to use it to index + # into `in_flat`, which doesn't include consts. 
+ ext_to_ext_fwd = [ + in_idx - len(consts) if in_idx is not None and + in_idx >= num_carry + len(consts) else None for in_idx in ext_fwd] + jaxpr = pe.prune_closed_jaxpr_outputs( + jaxpr, [True] * num_carry + [i is None for i in ext_to_ext_fwd]) + out = scan_p.bind(*consts, *in_flat, reverse=reverse, length=length, jaxpr=jaxpr, num_consts=len(consts), num_carry=num_carry, linear=(False,) * (len(consts) + len(in_flat)), - unroll=unroll, - _split_transpose=_split_transpose) + unroll=unroll, _split_transpose=_split_transpose) + + # Apply input to output forwarding that was computed above. + carry_out, out = split_list(out, [num_carry]) + out_ = iter(out) + out = [next(out_) if f is None else _maybe_put(in_flat[f]) for f in ext_to_ext_fwd] + assert next(out_, None) is None + out = [*carry_out, *out] + + if any(move_to_const): + out = pe.merge_lists(move_to_const + [False] * num_ys, out, new_consts) + if attrs_tracked: - out_state, out = split_list(out, [len(attrs_tracked)]) - _set_states(attrs_tracked, out_state) + num_ext = (len(out) - len(in_state) + - sum(k is pe.Append for *_, (_, _, k) in attrs_tracked) + - sum(t.num_leaves for _, t, (_, _, k) in attrs_tracked + if k is pe.ListAttr)) + out_state, out, out_append = split_list(out, [len(in_state), num_ext]) + out_attrs = _merge_attrs_out(attrs_tracked, out_state, out_append) + _set_states(attrs_tracked, out_attrs) + return tree_unflatten(out_tree, out) def _set_states(attrs_tracked, vals): - from jax.experimental.attrs import jax_setattr valss = split_list_checked(vals, [td.num_leaves for _, td, _ in attrs_tracked]) - for ((_, treedef, (obj, attr)), leaves) in zip(attrs_tracked, valss): - val = tree_unflatten(treedef, leaves) - jax_setattr(obj, attr, val) + for ((_, treedef, (obj, attr, kind)), leaves) in zip(attrs_tracked, valss): + if kind is pe.ReadWrite: + val = tree_unflatten(treedef, leaves) + jax_setattr(obj, attr, val) + elif kind is pe.Append: + val, = leaves + jax_extendattr(obj, attr, val.reshape(-1, *val.shape[2:])) + elif kind is pe.BoxAttr: + val = tree_unflatten(treedef, leaves) + obj.set(val) + elif kind is pe.ListAttr: + for leaves_ in zip(*leaves): + for item in tree_unflatten(treedef, leaves_): + obj.append(item) + else: + assert False def _get_states(attrs_tracked): - from jax.experimental.attrs import jax_getattr vals = [] - for treedef, _, (obj, attr) in attrs_tracked: - tree = jax_getattr(obj, attr) - leaves, treedef_ = tree_flatten(tree) - assert treedef == treedef_ - vals.extend(leaves) + for treedef, _, (obj, attr, kind) in attrs_tracked: + if kind is pe.ReadWrite: + tree = jax_getattr(obj, attr) + leaves, treedef_ = tree_flatten(tree) + assert treedef == treedef_ + vals.extend(leaves) + elif kind is pe.Append: + pass + elif kind is pe.BoxAttr: + tree = obj.get() + leaves, treedef_ = tree_flatten(tree) + assert treedef == treedef_ + vals.extend(leaves) + elif kind is pe.ListAttr: + pass + else: + assert False return vals +def _merge_attrs_out(attrs_tracked, out_state, out_append): + # merge out_state & out_append back into attrs_tracked order + out_state_, out_append_ = iter(out_state), iter(out_append) + out_attrs = [] + for _, out_tree, (_, _, k) in attrs_tracked: + if k in (pe.ReadWrite, pe.BoxAttr): + out_attrs.extend(it.islice(out_state_, out_tree.num_leaves)) + elif k is pe.Append: + out_attrs.append(next(out_append_)) + elif k is pe.ListAttr: + out_attrs.extend(it.islice(out_append_, out_tree.num_leaves)) + else: + assert False + assert next(out_state_, None) is next(out_append_, None) is None + return 
out_attrs + + def _capitalize(s): # s.capitalize() converts s[1:] to lowercase which we don't want. return s[0].capitalize() + s[1:] @@ -390,9 +485,8 @@ def _check_carry_type(name, body_fun, in_carry, out_carry_tree, out_avals): for path, thing1, thing2, explanation in equality_errors(in_carry, out_carry)] if len(diffs) == 0: - # The trees may have different aux data but structures are the same. - return - if len(diffs) == 1: + return # the trees may have different aux data, but structures are same + elif len(diffs) == 1: differences = f'{_capitalize(diffs[0])}.\n' else: differences = ('\n'.join(f' * {d};\n' for d in diffs[:-1]) @@ -406,35 +500,45 @@ def _check_carry_type(name, body_fun, in_carry, out_carry_tree, out_avals): if not all(_map(core.typematch, in_avals, out_avals)): diffs = [f'{component(path)} has type {in_aval.str_short()}' ' but the corresponding output carry component has type ' - f'{out_aval.str_short()}{_aval_mismatch_extra(in_aval, out_aval)}' + f'{out_aval.str_short()}{core.aval_mismatch_extra(in_aval, out_aval)}' for path, in_aval, out_aval in zip(paths, in_avals, out_avals) if not core.typematch(in_aval, out_aval)] + if len(diffs) == 0: - # The trees may have different aux data but structures are the same. - return + return # seems unreachable but in any case we don't have a good error msg if len(diffs) == 1: differences = f'{_capitalize(diffs[0])}.\n' else: differences = ('\n'.join(f' * {d};\n' for d in diffs[:-1]) + f' * {diffs[-1]}.\n') + + pvary_applications = [ + f'applying `jax.lax.pvary(..., {tuple(out_aval.vma - in_aval.vma)})` ' + f'to the initial carry value corresponding to {component(path)}' + for path, in_aval, out_aval in zip(paths, in_avals, out_avals) + if not core.typematch(in_aval, out_aval) and + isinstance(in_aval, ShapedArray) and isinstance(out_aval, ShapedArray) + and in_aval.vma != out_aval.vma and out_aval.vma - in_aval.vma] + + if not pvary_applications: + pvary_msg = '' + elif len(pvary_applications) == 1: + pvary_msg = f'This might be fixed by {pvary_applications[0]}.\n' + else: + pvary_msg = ('This might be fixed by:\n' + + '\n'.join(f' * {d};\n' for d in pvary_applications[:-1]) + + f' * {pvary_applications[-1]}.\n') + if pvary_msg: + pvary_msg += ("See https://docs.jax.dev/en/latest/notebooks/shard_map.html#scan-vma " + "for more information.\n\n") + raise TypeError( - f"{name} function carry input and carry output must have equal types " - "(e.g. shapes and dtypes of arrays), " + f"{name} function carry input and carry output must have equal types, " "but they differ:\n\n" f"{differences}\n" - "Revise the function so that all output types (e.g. shapes " - "and dtypes) match the corresponding input types.") - -def _aval_mismatch_extra(a1: core.AbstractValue, a2: core.AbstractValue) -> str: - assert not core.typematch(a1, a2) - if isinstance(a1, core.ShapedArray) and isinstance(a2, core.ShapedArray): - dtype_mismatch = a1.dtype != a2.dtype - shape_mismatch = a1.shape != a2.shape - return (', so ' * (dtype_mismatch or shape_mismatch) + - 'the dtypes do not match' * dtype_mismatch + - ' and also ' * (dtype_mismatch and shape_mismatch) + - 'the shapes do not match' * shape_mismatch) - return '' + f"{pvary_msg}" + "Revise the function so that all output types match the corresponding " + "input types.") # TODO(mattjj): re-land #19819 version? simpler, but caused ~1 perf regression. 
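A tiny, hypothetical trigger for the carry-type error assembled above (the exact message wording may differ):

import jax
import jax.numpy as jnp

def body(c, x):
  # The carry enters as float32 but is returned as float16, so the carry
  # input and output types no longer match.
  return c.astype(jnp.float16), x

try:
  jax.lax.scan(body, jnp.zeros((), jnp.float32), jnp.arange(4.0))
except TypeError as e:
  print(e)  # roughly: "scan body function carry input and carry output must
            # have equal types, but they differ: ..."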
def _scan_impl(*args, reverse, length, num_consts, num_carry, jaxpr, linear, @@ -512,15 +616,17 @@ def _split_leading(sz, x): def _concat(a, b): return lax.concatenate([a, b], 0) def _empty_array(prefix, length_spec, aval): - sharding = aval.sharding.with_spec((*length_spec, *aval.sharding.spec)) - return lax.broadcast(lax.empty(aval.dtype), (*prefix, *aval.shape), - out_sharding=sharding) + sharding = aval.sharding.update(spec=(*length_spec, *aval.sharding.spec)) + empty = core.pvary(lax.empty(aval.dtype), tuple(aval.vma)) + return lax.broadcast(empty, (*prefix, *aval.shape), out_sharding=sharding) eval_jaxpr_p = core.Primitive('eval_jaxpr') eval_jaxpr_p.multiple_results = True -def _stage_jaxpr(trace: pe.JaxprTrace, *tracers, jaxpr: core.ClosedJaxpr): +def _stage_jaxpr(trace: pe.DynamicJaxprTrace, source_info, *tracers, + jaxpr: core.ClosedJaxpr): params = dict(call_jaxpr=jaxpr) - return trace.default_process_primitive(core.closed_call_p, tracers, params) + return trace.default_process_primitive(core.closed_call_p, tracers, params, + source_info=source_info) pe.custom_staging_rules[eval_jaxpr_p] = _stage_jaxpr @eval_jaxpr_p.def_effectful_abstract_eval # abstract eval only used for jax2tf @@ -532,9 +638,17 @@ def _prepend_dim_to_aval(sz, aval): def _scan_abstract_eval(*args, reverse, length, num_consts, num_carry, jaxpr, linear, unroll, _split_transpose): - carry_avals, y_avals = split_list(jaxpr.out_avals, [num_carry]) + out_carry_avals, y_avals = split_list(jaxpr.out_avals, [num_carry]) + _, in_carry_avals, _ = split_list(args, [num_consts, num_carry]) + if [i.vma for i in in_carry_avals] != [o.vma for o in out_carry_avals]: + raise ValueError( + 'Scan carry input and output got mismatched varying manual axes ' + f'{in_carry_avals} and {out_carry_avals}. Please open an ' + 'issue at https://github.com/jax-ml/jax/issues, and as a ' + 'temporary workaround pass the check_vma=False argument to ' + '`jax.shard_map`') ys_avals = _map(partial(_prepend_dim_to_aval, length), y_avals) - return carry_avals + ys_avals, jaxpr.effects + return out_carry_avals + ys_avals, jaxpr.effects def _scan_jvp(primals, tangents, reverse, length, jaxpr, num_consts, num_carry, linear, unroll, _split_transpose): @@ -593,6 +707,108 @@ def _scan_jvp(primals, tangents, reverse, length, jaxpr, num_consts, num_carry, for p, nz in zip(primals_out, nonzeros_out)] return primals_out, tangents_out +def _scan_linearization(nzs, *primals_in, reverse: bool, length: int, + num_consts: int, num_carry: int, + jaxpr: core.ClosedJaxpr, linear: Sequence[bool], + unroll: int, _split_transpose: bool): + const_nz, init_nz, xs_nz = split_list(nzs, [num_consts, num_carry]) + carry_nz = init_nz + for _ in range(1 + num_carry): + nzs = const_nz + carry_nz + xs_nz + primal_jaxpr, num_res, nzs_out, tangent_jaxpr = ad.linearize_jaxpr(jaxpr, nzs) + carry_nz_out = nzs_out[:num_carry] + if carry_nz_out == carry_nz: + break + else: + carry_nz = _map(operator.or_, carry_nz, carry_nz_out) + else: + assert False, "Fixpoint not reached" + + # The linearize_jaxpr function produces primal_jaxpr with num_res residuals + # output at the front, and tangent_jaxpr with num_res residuals input at the + # back. We could move all the residuals to the back and treat them as + # extensive outputs, but this would be wasteful for residuals that are + # loop invariant, or forwarded extensive inputs. + + # First, for residuals that are forwarded constants, we move those to the + # front in the tangent_jaxpr to treat them as intensive inputs. 
+ in_fwd = pe._jaxpr_forwarding(primal_jaxpr.jaxpr) + primal_jaxpr, tangent_jaxpr, intensive_res, in_fwd = _const_to_intensive_res_forwarding( + primal_jaxpr, tangent_jaxpr, num_res, num_consts, primals_in, in_fwd) + num_intensive_res = len(intensive_res) + num_res -= num_intensive_res + + # After pruning the intensive residuals, the rest get moved to the back and + # handled as extensive outputs from the primal. + num_out = len(nzs_out) + primal_jaxpr = pe.move_outvars_to_back( + primal_jaxpr, [True] * num_res + [False] * num_out) + in_fwd = in_fwd[num_res:] + in_fwd[:num_res] + + # Then, any residuals or other extensive outputs that are forwarded extensive + # inputs, we remove them from the primal jaxpr, and manually forward them. + in_fwd = [in_idx if out_idx >= num_carry and in_idx is not None and + in_idx >= num_consts + num_carry else None + for out_idx, in_idx in enumerate(in_fwd)] + primal_jaxpr = pe.prune_closed_jaxpr_outputs(primal_jaxpr, + [i is None for i in in_fwd]) + + out = scan_p.bind(*primals_in, jaxpr=primal_jaxpr, reverse=reverse, + length=length, num_consts=num_consts, num_carry=num_carry, + linear=linear, unroll=unroll, _split_transpose=_split_transpose) + out_ = iter(out) + all_out = [next(out_) if f is None else _maybe_put(primals_in[f]) for f in in_fwd] + assert next(out_, None) is None + primals_out, extensive_res = split_list(all_out, [len(all_out) - num_res]) + res = [*intensive_res, *extensive_res] + + def tangent_fun(res, *tangents): + intensive_res, extensive_res = split_list(res, [num_intensive_res]) + nz_tangents = [ad.instantiate_zeros(x) for nz, x in zip(nzs, tangents) if nz] + tangent_linear = ( + (False,) * len(intensive_res) + + (True,) * len(nz_tangents) + + (False,) * len(extensive_res) + ) + tangent_num_consts = len(intensive_res) + sum(nzs[:num_consts]) + tangent_num_carry = sum(nzs[num_consts:num_consts + num_carry]) + nz_tangents_out = scan_p.bind(*intensive_res, *nz_tangents, *extensive_res, + jaxpr=tangent_jaxpr, + reverse=reverse, length=length, + num_consts=tangent_num_consts, + num_carry=tangent_num_carry, + linear=tangent_linear, unroll=unroll, + _split_transpose=_split_transpose) + tangent_avals_out = [v.aval.to_tangent_aval() for v in jaxpr.jaxpr.outvars] + nz_tangents_out_ = iter(nz_tangents_out) + tangents_out = [next(nz_tangents_out_) if nz else ad.Zero(aval) + for aval, nz in zip(tangent_avals_out, nzs_out)] + assert next(nz_tangents_out_, None) is None + return tangents_out + + return primals_out, nzs_out, res, tangent_fun + +def _const_to_intensive_res_forwarding( + primal_jaxpr: core.ClosedJaxpr, + tangent_jaxpr: core.ClosedJaxpr, + num_res: int, + num_consts: int, + primals_in: Sequence[Any], + in_fwd: list[int | None] +) -> tuple[core.ClosedJaxpr, core.ClosedJaxpr, list[Any], list[int | None]]: + const_to_res = [in_idx if in_idx is not None and in_idx < num_consts else None + for in_idx in in_fwd[:num_res]] + new_in_fwd = [f for c, f in zip(const_to_res, in_fwd[:num_res]) if c is None] + new_in_fwd += in_fwd[num_res:] + intensive_res = [primals_in[f] for f in const_to_res if f is not None] + num_out = len(primal_jaxpr.out_avals) - num_res + primal_jaxpr = pe.prune_closed_jaxpr_outputs( + primal_jaxpr, [i is None for i in const_to_res] + [True] * num_out) + num_nz = len(tangent_jaxpr.in_avals) - num_res + tangent_jaxpr = pe.move_binders_to_front( + tangent_jaxpr, [False] * num_nz + [i is not None for i in const_to_res]) + return primal_jaxpr, tangent_jaxpr, intensive_res, new_in_fwd + def _scan_partial_eval(trace, *tracers, 
reverse: bool, length: int, num_consts: int, num_carry: int, jaxpr: core.ClosedJaxpr, linear: Sequence[bool], @@ -642,6 +858,8 @@ def _scan_partial_eval(trace, *tracers, reverse: bool, # want to broadcast the matrix!). So, outside the loop we perform a partial # evaluation with known 'const' inputs (but all other inputs unknown). const_pvals = [pe.PartialVal.known(t.pval.get_known()) + if not isinstance(t.aval, state.AbstractRef) + else pe.PartialVal.unknown(t.aval) for t in tracers[:num_consts] if t.pval.is_known()] other_pvals = [pe.PartialVal.unknown(aval) for aval in jaxpr_known.in_avals[len(const_pvals):]] @@ -655,7 +873,7 @@ def _scan_partial_eval(trace, *tracers, reverse: bool, # The above trace_to_jaxpr_nounits call computed loop-invariant residuals # (known values in invar_pvals_out) and also computed loop-invariant values # needed by the new jaxpr_known (in jaxpr_known_consts, which replace the - # previous consts). We need to collect the computed inteisive residuals, and + # previous consts). We need to collect the computed intensive residuals, and # move corresponding intensive residual binders in jaxpr_unknown to the front. res_pvals = invar_pvals_out[len(invar_pvals_out) - num_res:] intensive_res = [pval.get_known() for pval in res_pvals if pval.is_known()] @@ -686,7 +904,9 @@ def _scan_partial_eval(trace, *tracers, reverse: bool, # We use `fwds_known` below when forming the output of scanning jaxpr_known. # Run the known part of the scan (if it has any outputs or effects). - known_inputs = (list(jaxpr_known_consts) + + known_mutable_consts = [t.pval.get_known() for t in tracers[:num_consts] + if t.pval.is_known() and isinstance(t.aval, state.AbstractRef)] + known_inputs = (list(jaxpr_known_consts) + known_mutable_consts + [t.pval.get_known() for t in tracers[num_consts:] if t.pval.is_known()]) if not jaxpr_known.out_avals and not jaxpr_known.effects: @@ -695,7 +915,8 @@ def _scan_partial_eval(trace, *tracers, reverse: bool, linear_known = [False] * len(known_inputs) # conservative! out_known = scan_p.bind( *known_inputs, reverse=reverse, length=length, jaxpr=jaxpr_known, - num_consts=len(jaxpr_known_consts), num_carry=num_carry - sum(carry_uk), + num_consts=len(jaxpr_known_consts) + len(known_mutable_consts), + num_carry=num_carry - sum(carry_uk), linear=tuple(linear_known), unroll=unroll, _split_transpose=_split_transpose) del linear_known @@ -719,7 +940,7 @@ def _scan_partial_eval(trace, *tracers, reverse: bool, ys_avals = [core.unmapped_aval(length, 0, y_aval) for y_aval in y_avals] out_tracers = [pe.JaxprTracer(trace, pe.PartialVal.unknown(a), None) - for a in itertools.chain(carry_avals, ys_avals)] + for a in it.chain(carry_avals, ys_avals)] del carry_avals, y_avals # Create equation. 
linear_unknown = tuple([False] * len(intensive_res) + @@ -728,7 +949,7 @@ def _scan_partial_eval(trace, *tracers, reverse: bool, name_stack = source_info_util.current_name_stack()[len(trace.name_stack):] source = source_info_util.current().replace(name_stack=name_stack) assert len(out_tracers) == len(jaxpr_unknown.out_avals) - eqn = pe.new_eqn_recipe([*intensive_res, *unknown_inputs, *extensive_res], + eqn = pe.new_eqn_recipe(trace, [*intensive_res, *unknown_inputs, *extensive_res], out_tracers, scan_p, dict(reverse=reverse, length=length, unroll=unroll, jaxpr=jaxpr_unknown, linear=linear_unknown, @@ -778,16 +999,22 @@ def _scan_transpose(cts, *args, reverse, length, num_consts, ct_consts = _map(ad_util.zeros_like_aval, jaxpr.in_avals[num_ires:num_consts]) # jaxpr :: [ires, T d] -> [T c] -> [T a, eres] -> ([T c], [T b]) - # jaxpr_trans :: [ires] -> [CT d, CT c] -> [CT b, eres] -> ([CT d, CT c], [CT a]) + # jaxpr_trans :: [ires] -> [CT d, CT c] -> [CT b, eres] -> ([CT d, CT c], [CT a, e]) jaxpr_trans, attrs_tracked = _transpose_scan_jaxpr( jaxpr, num_ires, num_consts - num_ires, num_eres, ct_ys_is_zeros) - linear_trans = ([False] * num_ires + [False] * len(attrs_tracked) + + appends_out = [k for _, t, (_, _, k) in attrs_tracked + for k in [k in (pe.Append, pe.ListAttr)] * t.num_leaves] + jaxpr_trans = pe.move_outvars_to_back( + jaxpr_trans, appends_out + [False] * (len(jaxpr_trans.out_avals) - len(appends_out))) + num_attr_carry = sum(init_tree.num_leaves for init_tree, _, (_, _, kind) + in attrs_tracked if kind is pe.ReadWrite) + linear_trans = ([False] * num_ires + [False] * num_attr_carry + [True] * (len(ct_consts) + len(ct_carry) + len(ct_ys)) + [False] * num_eres) in_state = _get_states(attrs_tracked) transpose_inputs = *ires, *in_state, *ct_consts, *ct_carry, *ct_ys, *eres - transpose_num_out_carry = num_consts-num_ires+num_carry+len(attrs_tracked) + transpose_num_out_carry = num_consts-num_ires+num_carry+num_attr_carry if not _split_transpose: outs = scan_p.bind( @@ -882,8 +1109,10 @@ def _scan_transpose(cts, *args, reverse, length, num_consts, for mask in outs_mask ] - out_state, outs = split_list(outs, [len(attrs_tracked)]) - _set_states(attrs_tracked, out_state) + num_outs = len(outs) - num_attr_carry - sum(appends_out) + out_state, outs, out_append = split_list(outs, [num_attr_carry, num_outs]) + out_attrs = _merge_attrs_out(attrs_tracked, out_state, out_append) + _set_states(attrs_tracked, out_attrs) ct_consts, ct_init, ct_xs = split_list(outs, [num_consts - num_ires, num_carry]) return [None] * num_ires + ct_consts + ct_init + ct_xs + [None] * num_eres @@ -928,12 +1157,10 @@ def transposed(*res1_cbar_bbar_res2): return c_bar + a_bar # TODO(necula): fix arg names and results for transposed - transposed_wrapped = lu.wrap_init(transposed, - debug_info=jaxpr.jaxpr.debug_info) - return _make_closed_jaxpr_attrs( - transposed_wrapped, - tuple(res1_avals + c_avals + b_carry_avals + - b_ys_avals_stripped + res2_avals)) + transposed_wrapped = lu.wrap_init(transposed, debug_info=jaxpr.jaxpr.debug_info) + trans_avals = (*res1_avals, *c_avals, *b_carry_avals, *b_ys_avals_stripped, *res2_avals) + trans_jaxpr, attrs_tracked = _make_closed_jaxpr_attrs(transposed_wrapped, trans_avals) + return trans_jaxpr, attrs_tracked def _scan_batching_rule(axis_data, args, @@ -1074,10 +1301,12 @@ def _scan_partial_eval_custom(saveable, unks_in, inst_in, eqn): num_const_known = len(const_uk) - sum(const_uk) num_carry_known = len(carry_uk) - sum(carry_uk) num_xs_known = len( xs_uk) - sum( xs_uk) + 
const_donthoist = [isinstance(a, state.AbstractRef) + for a in jaxpr_known.in_avals[:num_const_known]] jaxpr_known_hoist, jaxpr_known_loop, loop_dep, consts_known_lp_avals = \ pe.partial_eval_jaxpr_nounits( jaxpr_known, - [False] * num_const_known + [True] * (num_carry_known + num_xs_known), + const_donthoist + [True] * (num_carry_known + num_xs_known), [True] * (len(unks_out) - sum(unks_out)) + [False] * num_res) # jaxpr_known_hoist produces intensive residuals followed by the constants for # jaxpr_known_loop. We adjust jaxpr_staged to accept intensive res as consts. @@ -1110,10 +1339,13 @@ def _scan_partial_eval_custom(saveable, unks_in, inst_in, eqn): linear=tuple(linear_known)) def known(*ins_known): - consts_known_hoist, ins_known_lp = split_list(ins_known, [num_const_known]) + consts_known_maybehoist, ins_known_lp = split_list(ins_known, [num_const_known]) + consts_known_hoist, consts_known_donthoist = \ + partition_list(const_donthoist, consts_known_maybehoist) out_hoist = core.jaxpr_as_fun(jaxpr_known_hoist)(*consts_known_hoist) intensive_res, consts_known_lp = split_list(out_hoist, [num_intensive_res]) - out_loop = scan_p.bind(*consts_known_lp, *ins_known_lp, **params_known) + out_loop = scan_p.bind(*consts_known_lp, *consts_known_donthoist, + *ins_known_lp, **params_known) return [*intensive_res, *out_loop] call_jaxpr_, _, call_jaxpr_consts, () = pe.trace_to_jaxpr_dynamic( lu.wrap_init(known, debug_info=jaxpr_known_hoist.jaxpr.debug_info), @@ -1289,6 +1521,7 @@ def arrange_jaxpr_args_for_wrapped(args): scan_p.def_effectful_abstract_eval(_scan_abstract_eval) ad.primitive_jvps[scan_p] = _scan_jvp ad.primitive_transposes[scan_p] = _scan_transpose +ad.primitive_linearizations[scan_p] = _scan_linearization pe.custom_partial_eval_rules[scan_p] = _scan_partial_eval xla.register_initial_style_primitive(scan_p) mlir.register_lowering(scan_p, @@ -1300,6 +1533,64 @@ def arrange_jaxpr_args_for_wrapped(args): pe.dce_rules[scan_p] = _scan_dce_rule state_discharge.register_partial_discharge_rule(scan_p)(_scan_state_partial_discharge_rule) +def _is_high(jaxpr, **_) -> bool: + return jaxpr.jaxpr.is_high +scan_p.is_high = _is_high # type: ignore + +def _to_lojax(*hi_args, jaxpr, num_carry, num_consts, linear, **params): + + # move box binders and hi_args from consts slots to carry slots + to_move = [t.has_qdd for t in jaxpr.in_aval_qdds[:num_consts]] + jaxpr = pe.move_invars_right(jaxpr, to_move) + hi_args = _move_right(hi_args, to_move) + num_consts -= sum(to_move) + num_carry += sum(to_move) + + # expand num_consts, num_carry, linear according to lo types + const_in_avals, carry_in_avals, _ = split_list(jaxpr.in_aval_qdds, [num_consts, num_carry]) + num_consts = sum(len(aval.lo_ty()) for aval in const_in_avals) + num_carry = sum(len(aval.lo_ty()) for aval in carry_in_avals) + linear = [l for aval, l_ in zip(jaxpr.in_aval_qdds, linear) + for l in (l_,) * len(aval.lo_ty())] + lo_muts_out = sum(len(aval.lo_ty()) for aval in jaxpr.final_aval_qdds if aval.has_qdd) + + # collect lo input values + lo_args = [lo_val for aval, x in zip(jaxpr.in_aval_qdds, hi_args) + for lo_val in (aval.read_loval(x) if aval.has_qdd + else aval.lower_val(x))] + + # lower the jaxpr and bind it using lo input values + lo_jaxpr = pe.lower_jaxpr(jaxpr) + all_outs = scan_p.bind(*lo_args, jaxpr=lo_jaxpr, num_consts=num_consts, + num_carry=num_carry, linear=tuple(linear), **params) + out_mut, lo_outs = split_list(all_outs, [lo_muts_out]) + + # collect and apply mutations + out_mut_ = iter(out_mut) + in_idx = {v: i for i, v in 
enumerate(jaxpr.jaxpr.invars)} + + for v in jaxpr.jaxpr.invars: + if v.final_qdd is not None: + qdd = v.final_qdd + lo_vals = it.islice(out_mut_, len(v.aval.lo_ty_qdd(qdd))) + v.aval.update_from_loval(qdd, hi_args[in_idx[v]], *lo_vals) + + assert next(out_mut_, None) is None + + # collect output values into hi types + lo_outs_ = iter(lo_outs) + hi_outs = [t.raise_val(*it.islice(lo_outs_, len(t.lo_ty()))) + for t in jaxpr.out_avals] + assert next(lo_outs_, None) is None + + return hi_outs +scan_p.to_lojax = _to_lojax + +def _move_right(lst, to_move): + lst, rest = split_list(lst, [len(to_move)]) + left, right = partition_list(to_move, lst) + return [*left, *right, *rest] + def _propagate_mem_kind_scan(*xm, reverse, length, num_consts, num_carry, jaxpr, linear, unroll, _split_transpose): return pxla.get_out_memory_kinds_via_propagation(jaxpr) @@ -1413,9 +1704,34 @@ def _create_jaxpr(init_val): if disallowed_effects: raise NotImplementedError( f'Effects not supported in `while`: {disallowed_effects}') + + # If the body forwards an input carry to an output carry, *and* it's not used + # by the cond fun, it can be moved to be a body const. Doing so can lead to + # efficiency wins: if e.g. we vmap the loop with a batched predicate, we batch + # the carry too, but not the body consts. + body_fwd = pe._jaxpr_forwarding(body_jaxpr.jaxpr) + carry_nofwd = [len(body_consts) + i != f for i, f in enumerate(body_fwd)] + cond_jaxpr_, keep_cond = pe.dce_jaxpr( + cond_jaxpr.jaxpr, [True], [True] * len(cond_consts) + carry_nofwd) + _, keep_cond_carry = split_list(keep_cond, [len(cond_consts)]) + move_to_const = _map(operator.not_, keep_cond_carry) + + if any(move_to_const): + cond_jaxpr = pe.close_jaxpr(cond_jaxpr_) + body_jaxpr = pe.prune_closed_jaxpr_outputs( + body_jaxpr, [not m for m in move_to_const]) + body_jaxpr = pe.move_binders_to_front( + body_jaxpr, [False] * len(body_consts) + move_to_const) + init_vals, new_body_consts = partition_list(move_to_const, init_vals) + body_consts = [*new_body_consts, *body_consts] + outs = while_p.bind(*cond_consts, *body_consts, *init_vals, cond_nconsts=len(cond_consts), cond_jaxpr=cond_jaxpr, body_nconsts=len(body_consts), body_jaxpr=body_jaxpr) + + if any(move_to_const): + outs = pe.merge_lists(move_to_const, outs, new_body_consts) + return tree_unflatten(body_tree, outs) @@ -1438,7 +1754,29 @@ def _join_while_effects(body_jaxpr, cond_jaxpr, body_nconsts, cond_nconsts def _while_loop_abstract_eval(*avals, cond_jaxpr, body_jaxpr, body_nconsts, cond_nconsts): - del avals + cond_consts_avals, body_consts_avals, in_avals = \ + util.split_list(avals, [cond_nconsts, body_nconsts]) + + if len(cond_jaxpr.in_avals) != len(cond_consts_avals) + len(in_avals): + raise core.JaxprTypeError( + f"while_loop {len(cond_jaxpr.in_avals)=} but {len(cond_consts_avals) + len(in_avals)=}") + if len(body_jaxpr.in_avals) != len(body_consts_avals) + len(in_avals): + raise core.JaxprTypeError( + f"while_loop {len(body_jaxpr.in_avals)=} but {len(body_consts_avals) + len(in_avals)=}") + # TODO(mattjj): check body carry type + # TODO(mattjj): make these typecompat checks work with bints + # if not all(_map(core.typecompat, [*cond_consts_avals, *in_avals], cond_jaxpr.in_avals)): # type: ignore + # cond_avals = [*cond_consts_avals, *in_avals] + # a1, a2 = next((a1, a2) for a1, a2 in zip(cond_avals, cond_jaxpr.in_avals) + # if not core.typecompat(a1, a2)) + # raise core.JaxprTypeError(f"while_loop cond function input type error: {a1} != {a2}") + # if not all(_map(core.typecompat, 
[*body_consts_avals, *in_avals], body_jaxpr.in_avals)): # type: ignore + # body_avals = [*body_consts_avals, *in_avals] + # a1, a2 = next((a1, a2) for a1, a2 in zip(body_avals, body_jaxpr.in_avals) + # if not core.typecompat(a1, a2)) + # raise core.JaxprTypeError(f"while_loop body function input type error: {a1} != {a2}") + + joined_effects = _join_while_effects(body_jaxpr, cond_jaxpr, body_nconsts, cond_nconsts) disallowed_effects = effects.control_flow_allowed_effects.filter_not_in(joined_effects) @@ -1679,7 +2017,7 @@ def _while_partial_eval_custom(saveable, unks_in, inst_in, eqn): assert False, "Fixpoint not reached" assert not num_res body_jaxpr_known = core.ClosedJaxpr(jaxpr_known_, body_jaxpr.consts) - del jaxpr_known_, carry_uk_out, num_res + del jaxpr_known_, carry_uk_out, num_res, unks_in # Instantiate all inputs (b/c jaxpr_staged will take all inputs). new_inst = [x for x, inst in zip(eqn.invars, inst_in) @@ -1701,6 +2039,7 @@ def _while_partial_eval_custom(saveable, unks_in, inst_in, eqn): del cond_uk # Build the known eqn. + unks_in = [*cond_consts_uk, *body_consts_uk, *carry_uk] # fixpoint carry_uk ins_known, _ = partition_list(unks_in, eqn.invars) out_binders_known, _ = partition_list(carry_uk, eqn.outvars) params_known = dict(cond_jaxpr=cond_jaxpr_known, body_jaxpr=body_jaxpr_known, @@ -1711,6 +2050,11 @@ def _while_partial_eval_custom(saveable, unks_in, inst_in, eqn): eqn_known = pe.new_jaxpr_eqn(ins_known, out_binders_known, while_p, params_known, effects_known, eqn.source_info, eqn.ctx) + # Typecheck known eqn. + _while_loop_abstract_eval( + *[v.aval for v in eqn_known.invars], cond_jaxpr=cond_jaxpr_known, + body_jaxpr=body_jaxpr_known, body_nconsts=params_known['body_nconsts'], + cond_nconsts=params_known['cond_nconsts']) # Staged eqn is same as input eqn. 
eqn_staged = eqn @@ -1763,18 +2107,19 @@ def cond(args): pred = lax.reduce_or(pred, tuple(range(len(pred_aval.shape)))) return pred def body(args): - return tuple(core.eval_jaxpr(body_jaxpr.jaxpr, body_jaxpr.consts, *args)) + return core.eval_jaxpr(body_jaxpr.jaxpr, body_jaxpr.consts, *args) def new_cond(pred_args): - pred, _ = pred_args + pred, *_ = pred_args return pred def new_body(pred_args): - _, args = pred_args - args = body(args) - pred = cond(args) - return pred, args + _, cond_consts, body_consts, carry = pred_args + carry = body((*body_consts, *carry)) + pred = cond((*cond_consts, *carry)) + return pred, cond_consts, body_consts, carry def fun(*args): - pred = cond(args) - _, out = while_loop(new_cond, new_body, (pred, args)) + cond_consts, body_consts, carry = split_list(args, [cond_nconsts, body_nconsts]) + pred = cond((*cond_consts, *carry)) + *_, out = while_loop(new_cond, new_body, (pred, cond_consts, body_consts, carry)) return out return mlir.lower_fun(fun)(ctx, *args) @@ -1798,8 +2143,7 @@ def fun(*args): cond_block.arguments[i] for i in range(len(flat_loop_carry_types)) ] cond_args = mlir.unflatten_ir_values_like_types(flat_cond_args, loop_carry_types) - # Remove tokens from cond args - cond_args = cond_args[num_tokens:] + cond_args = cond_args[num_tokens:] # Remove tokens from cond args x, _, z = util.split_list(cond_args, [cond_nconsts, body_nconsts]) cond_consts = [ mlir.ir_constant(xla.canonicalize_dtype(x)) for x in cond_jaxpr.consts @@ -1820,7 +2164,8 @@ def fun(*args): name_stack=cond_name_stack, primitive=None, avals_in=[pred_aval], - avals_out=[pred_aval.update(shape=())], + avals_out=[pred_aval.update( + shape=(), sharding=pred_aval.sharding.update(spec=()))], tokens_in=mlir.TokenSet(), tokens_out=None) pred, = lax._unary_reduce_lower( @@ -1861,8 +2206,9 @@ def fun(*args): partial(_pred_bcast_select_hlo, ctx, pred_aval, body_pred), new_z, z, body_jaxpr.out_avals) - hlo.return_([*mlir.flatten_ir_values(out_tokens), *mlir.flatten_ir_values(x), *mlir.flatten_ir_values(y), - *mlir.flatten_ir_values(new_z)]) + hlo.return_([*mlir.flatten_ir_values(out_tokens), + *mlir.flatten_ir_values(x), *mlir.flatten_ir_values(y), + *mlir.flatten_ir_values(new_z)]) outputs = mlir.unflatten_ir_values_like_types(while_op.results, loop_carry_types) tokens, _, _, z = util.split_list(outputs, [num_tokens, cond_nconsts, body_nconsts]) @@ -1975,8 +2321,8 @@ def new_cond(*consts_refs_carry): ad.primitive_transposes[while_p] = _while_transpose_error batching.fancy_primitive_batchers[while_p] = _while_loop_batching_rule pe.partial_eval_jaxpr_custom_rules[while_p] = _while_partial_eval_custom -mlir.register_lowering(while_p, _while_lowering) core.custom_typechecks[while_p] = _while_typecheck +mlir.register_lowering(while_p, _while_lowering) state_discharge.register_partial_discharge_rule(while_p)(_while_partial_discharge_rule) @@ -2071,7 +2417,7 @@ def fori_loop(lower, upper, body_fun, init_val): unroll: An optional integer or boolean that determines how much to unroll the loop. If an integer is provided, it determines how many unrolled loop iterations to run within a single rolled iteration of the loop. If a - boolean is provided, it will determine if the loop is competely unrolled + boolean is provided, it will determine if the loop is completely unrolled (i.e. `unroll=True`) or left completely unrolled (i.e. `unroll=False`). This argument is only applicable if the loop bounds are statically known. 
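A hypothetical sketch of the `unroll` argument when the `fori_loop` bounds are dynamic: with the relaxed check in the next hunk, a trivial unroll (1 or False) is accepted and simply ignored, while a genuine unroll value still raises:

import jax
from jax import lax

@jax.jit
def f(n):
  # n is traced, so the loop bounds are not statically known and fori_loop
  # falls back to a while_loop.
  body = lambda i, c: c + i
  return lax.fori_loop(0, n, body, 0, unroll=1)   # accepted: no-op unroll
  # lax.fori_loop(0, n, body, 0, unroll=4)        # would still raise ValueError

print(f(5))  # 0 + 1 + 2 + 3 + 4 = 10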
@@ -2137,7 +2483,7 @@ def fori_loop(lower, upper, body_fun, init_val): unroll=unroll, ) return result - if unroll is not None: + if unroll is not None and unroll is not False and unroll != 1: raise ValueError("Can only use `unroll` in `fori_loop` if the loop bounds " "are statically known.") @@ -2153,29 +2499,45 @@ def fori_loop(lower, upper, body_fun, init_val): ### map and miscellaneous rules -def _batch_and_remainder(x, batch_size: int): - leaves, treedef = tree_flatten(x) - - scan_leaves = [] - remainder_leaves = [] +def _scan_leaf(leaf, batch_elems, num_batches, batch_size): + def f(l): + return l[:batch_elems].reshape(num_batches, batch_size, *leaf.shape[1:]) - for leaf in leaves: - num_batches, _ = divmod(leaf.shape[0], batch_size) - total_batch_elems = num_batches * batch_size - scan_leaves.append(leaf[:total_batch_elems].reshape(num_batches, batch_size, *leaf.shape[1:])) - remainder_leaves.append(leaf[total_batch_elems:]) + aval = core.typeof(leaf) + if aval.sharding.spec[0] is not None: + raise ValueError( + '0th dimension of leaf passed to `jax.lax.map` should be replicated.' + f' Got {aval.str_short(True, True)}') + if get_abstract_mesh()._are_all_axes_explicit: + out_s = aval.sharding.update(spec=P(None, None, *aval.sharding.spec[1:])) + return auto_axes(f, out_sharding=out_s)(leaf) + return f(leaf) + +def _remainder_leaf(leaf, batch_elems): + def f(l): + return l[batch_elems:] + if get_abstract_mesh()._are_all_axes_explicit: + return auto_axes(f, out_sharding=core.typeof(leaf).sharding)(leaf) + return f(leaf) - scan_tree = treedef.unflatten(scan_leaves) - remainder_tree = treedef.unflatten(remainder_leaves) - return scan_tree, remainder_tree +def _batch_and_remainder(x, batch_size: int): + leaves, treedef = tree_flatten(x) + if not leaves: + return x, None + num_batches, remainder = divmod(leaves[0].shape[0], batch_size) + batch_elems = num_batches * batch_size + if remainder: + scan_leaves, remainder_leaves = unzip2( + [(_scan_leaf(leaf, batch_elems, num_batches, batch_size), + _remainder_leaf(leaf, batch_elems)) for leaf in leaves]) + return treedef.unflatten(scan_leaves), treedef.unflatten(remainder_leaves) + else: + scan_leaves = tuple(_scan_leaf(leaf, batch_elems, num_batches, batch_size) + for leaf in leaves) + return treedef.unflatten(scan_leaves), None @api_boundary -def map( - f, - xs, - *, - batch_size: int | None = None, -): +def map(f, xs, *, batch_size: int | None = None): """Map a function over leading array axes. 
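Annotation (not part of the patch): with the `_batch_and_remainder` rewrite above, `jax.lax.map` only runs the extra `vmap` pass when the leading axis does not divide evenly by `batch_size`; otherwise the remainder is `None` and is skipped. A usage sketch:

```python
import jax
import jax.numpy as jnp

xs = jnp.arange(10.0)
f = lambda x: x * x

# 10 elements = 2 full batches of 4 plus a remainder of 2. The scanned
# batches and the remainder are computed separately and concatenated, so
# the result matches an unbatched map.
ys = jax.lax.map(f, xs, batch_size=4)
print(jnp.allclose(ys, xs * xs))  # True

# When batch_size divides the leading axis, no remainder pass is run.
zs = jax.lax.map(f, xs, batch_size=5)
```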
Like Python's builtin map, except inputs and outputs are in the form of @@ -2227,27 +2589,35 @@ def map(f, xs): scan_xs, remainder_xs = _batch_and_remainder(xs, batch_size) g = lambda _, x: ((), api.vmap(f)(x)) _, scan_ys = scan(g, (), scan_xs) - remainder_ys = api.vmap(f)(remainder_xs) flatten = lambda x: x.reshape(-1, *x.shape[2:]) - ys = tree_map( - lambda x, y: lax.concatenate([flatten(x), y], dimension=0), scan_ys, remainder_ys, - ) + if remainder_xs is not None: + remainder_ys = api.vmap(f)(remainder_xs) + ys = tree_map( + lambda x, y: lax.concatenate([flatten(x), y], dimension=0), scan_ys, + remainder_ys) + else: + ys = tree_map(flatten, scan_ys) else: g = lambda _, x: ((), f(x)) _, ys = scan(g, (), xs) return ys -def _rng_bit_generator_batching_rule(batched_args, batch_dims, *, shape, dtype, algorithm): +def _rng_bit_generator_batching_rule(batched_args, batch_dims, *, shape, dtype, + algorithm, out_sharding): keys, = batched_args bd, = batch_dims if bd is batching.not_mapped: - return lax.rng_bit_generator_p.bind(keys, shape=shape, dtype=dtype, - algorithm=algorithm), (None, None) + return lax.rng_bit_generator_p.bind( + keys, shape=shape, dtype=dtype, algorithm=algorithm, + out_sharding=out_sharding), (None, None) keys = batching.moveaxis(keys, bd, 0) batch_size = keys.shape[0] + out_s = (out_sharding.update(spec=(keys.aval.sharding.spec[0], *out_sharding.spec)) + if out_sharding is not None else None) key = keys[0] - new_key, bits = lax.rng_bit_generator_p.bind(key, shape=(batch_size, *shape), - dtype=dtype, algorithm=algorithm) + new_key, bits = lax.rng_bit_generator_p.bind( + key, shape=(batch_size, *shape), dtype=dtype, algorithm=algorithm, + out_sharding=out_s) new_keys = slicing.dynamic_update_index_in_dim(keys, new_key, 0, axis=0) return (new_keys, bits), (0, 0) @@ -2288,6 +2658,9 @@ def associative_scan(fn: Callable, elems, reverse: bool = False, axis: int = 0): of ``elems`` along ``axis``. For example, given ``elems = [a, b, c, ...]``, the result would be ``[a, fn(a, b), fn(fn(a, b), c), ...]``. + If ``elems = [..., x, y, z]`` and ``reverse`` is true, the result is + ``[..., f(f(z, y), x), f(z, y), z]``. 
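Annotation (not part of the patch): to make the new `reverse` sentence in the `associative_scan` docstring concrete, a quick check of the forward and reverse scans:

```python
import jax.numpy as jnp
from jax import lax

x = jnp.array([1, 2, 3, 4])
print(lax.associative_scan(jnp.add, x))                # [ 1  3  6 10]
print(lax.associative_scan(jnp.add, x, reverse=True))  # [10  9  7  4]
```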
+ Example 1: partial sums of an array of numbers: >>> lax.associative_scan(jnp.add, jnp.arange(0, 4)) @@ -2488,7 +2861,8 @@ def _cumred_dtype_rule(name, operand, *args, **kw): def _cumulative_reduction_primitive(name, reduce_fn, reduce_window_fn): reducer_p = lax.standard_primitive( _cumred_shape_rule, partial(_cumred_dtype_rule, name), - name, sharding_rule=_cumred_sharding_rule) + name, sharding_rule=_cumred_sharding_rule, + vma_rule=partial(core.standard_vma_rule, name)) batching.primitive_batchers[reducer_p] = partial(_cumred_batch_rule, reducer_p) diff --git a/jax/_src/lax/control_flow/solves.py b/jax/_src/lax/control_flow/solves.py index acfcfd7ff3d3..dd35044e7047 100644 --- a/jax/_src/lax/control_flow/solves.py +++ b/jax/_src/lax/control_flow/solves.py @@ -15,10 +15,9 @@ import collections from functools import partial import operator -from typing import Any, Callable +from typing import Any +from collections.abc import Callable -from jax.tree_util import (tree_flatten, treedef_children, tree_leaves, - tree_unflatten, treedef_tuple) from jax._src import ad_util from jax._src import api from jax._src import api_util @@ -30,6 +29,8 @@ from jax._src.interpreters import mlir from jax._src.interpreters import xla from jax._src.traceback_util import api_boundary +from jax._src.tree_util import (tree_flatten, treedef_children, tree_leaves, + tree_unflatten, treedef_tuple) from jax._src.util import split_list, safe_map import numpy as np @@ -309,24 +310,24 @@ def f_aux(x): jaxprs = _LinearSolveTuple( matvec_jaxpr, vecmat_jaxpr, solve_jaxpr, tr_solve_jaxpr) - out_flat = linear_solve_p.bind( - *(_flatten(all_consts) + b_flat), - const_lengths=const_lengths, jaxprs=jaxprs) + args = _flatten(all_consts) + b_flat + args = core.standard_insert_pvary(*args) + out_flat = linear_solve_p.bind(*args, const_lengths=const_lengths, jaxprs=jaxprs) return tree_unflatten(out_tree, out_flat) def _linear_solve_abstract_eval(*args, const_lengths, jaxprs): args_to_raise = args[sum(const_lengths):] - # raise aux_args to shaped arrays as well if present # number of aux args is the difference in out_avals # of solve and matvec (since they map to the same vector space) - num_aux = len(jaxprs.solve.out_avals) - len(jaxprs.matvec.out_avals) if num_aux > 0: args_to_raise += tuple(jaxprs.solve.out_avals[-num_aux:]) - return args_to_raise, jaxprs.solve.effects + out_vma = core.standard_vma_rule('linear_solve', *args_to_raise) + return (tuple(a.update(vma=out_vma) for a in args_to_raise), + jaxprs.solve.effects) def _custom_linear_solve_impl(*args, const_lengths, jaxprs): diff --git a/jax/_src/lax/convolution.py b/jax/_src/lax/convolution.py index 290d027cc6bc..53f9c88369a6 100644 --- a/jax/_src/lax/convolution.py +++ b/jax/_src/lax/convolution.py @@ -53,6 +53,8 @@ class ConvDimensionNumbers(NamedTuple): None, ] +# TODO(yashkatariya): conv_general_dilated should take `out_sharding` argument +# similar to `dot_general` def conv_general_dilated( lhs: Array, rhs: Array, window_strides: Sequence[int], padding: str | Sequence[tuple[int, int]], @@ -158,6 +160,7 @@ def conv_general_dilated( preferred_element_type = ( None if preferred_element_type is None else dtypes.canonicalize_dtype(np.dtype(preferred_element_type))) + lhs, rhs = core.standard_insert_pvary(lhs, rhs) return conv_general_dilated_p.bind( lhs, rhs, window_strides=tuple(window_strides), padding=tuple(padding), lhs_dilation=tuple(lhs_dilation), rhs_dilation=tuple(rhs_dilation), @@ -414,6 +417,26 @@ def _conv_general_dilated_shape_rule( return 
tuple(np.take(out_trans, np.argsort(out_perm))) +def _conv_general_dilated_sharding_rule( + lhs: core.ShapedArray, rhs: core.ShapedArray, *, window_strides, padding, + lhs_dilation, rhs_dilation, dimension_numbers, feature_group_count, + batch_group_count, **unused_kwargs): + # Only allow if rhs is fully replicated and lhs's feature dim is not sharded + if ((rhs.sharding.mesh.empty or rhs.sharding.is_fully_replicated) and + lhs.sharding.spec[dimension_numbers.lhs_spec[1]] is None): + out_shape = _conv_general_dilated_shape_rule( + lhs, rhs, window_strides=window_strides, padding=padding, + lhs_dilation=lhs_dilation, rhs_dilation=rhs_dilation, + dimension_numbers=dimension_numbers, + feature_group_count=feature_group_count, + batch_group_count=batch_group_count) + return lax.slicing._get_sharding_for_varying_out_shape( + out_shape, lhs, "conv_general_dilated") + # TODO(yashkatariya): In this case, just let the user specify the out_sharding + # via `out_sharding` argument to `conv_general_dilated`. + raise core.ShardingTypeError( + "Please file an issue at https://github.com/jax-ml/jax/issues") + def _conv_general_dilated_dtype_rule( lhs, rhs, *, window_strides, padding, lhs_dilation, rhs_dilation, dimension_numbers, preferred_element_type, **unused_kwargs): @@ -633,7 +656,9 @@ def _conv_general_dilated_batch_rule( conv_general_dilated_p = lax.standard_primitive( _conv_general_dilated_shape_rule, _conv_general_dilated_dtype_rule, - 'conv_general_dilated') + 'conv_general_dilated', + sharding_rule=_conv_general_dilated_sharding_rule, + vma_rule=partial(core.standard_vma_rule, 'conv_general_dilated')) ad.defbilinear(conv_general_dilated_p, _conv_general_dilated_transpose_lhs, @@ -711,21 +736,18 @@ def _conv_general_dilated_lower( # TODO(https://github.com/openxla/stablehlo/issues/1268) raise NotImplementedError("Convolutions with non-static strides, dilation, feature_group_count, or batch_group_count") if all(core.is_constant_shape(p) for p in padding): - return [ - hlo.convolution( - mlir.aval_to_ir_type(aval_out), - lhs, - rhs, - dimension_numbers=dnums, - feature_group_count=mlir.i64_attr(feature_group_count), - batch_group_count=mlir.i64_attr(batch_group_count), - window_strides=mlir.dense_int_array(window_strides), - padding=mlir.dense_int_elements(padding), - lhs_dilation=mlir.dense_int_array(lhs_dilation), - rhs_dilation=mlir.dense_int_array(rhs_dilation), - window_reversal=window_reversal, - precision_config=lax.precision_attr(precision)) - ] + out = hlo.convolution( + mlir.aval_to_ir_type(aval_out), lhs, rhs, + dimension_numbers=dnums, + feature_group_count=mlir.i64_attr(feature_group_count), + batch_group_count=mlir.i64_attr(batch_group_count), + window_strides=mlir.dense_int_array(window_strides), + padding=mlir.dense_int_elements(padding), + lhs_dilation=mlir.dense_int_array(lhs_dilation), + rhs_dilation=mlir.dense_int_array(rhs_dilation), + window_reversal=window_reversal, + precision_config=lax.precision_attr(precision)) + return [mlir.lower_with_sharding_in_types(ctx, out, aval_out)] else: # d_padding will be an array i32[N, 2] with pad_lo and pad_hi for each # spatial dimension. 
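Annotation (not part of the patch): the new `_conv_general_dilated_sharding_rule` only accepts the case where the kernel is fully replicated and the lhs feature dimension is unsharded (i.e. batch/spatial parallelism); anything else raises `ShardingTypeError` until an `out_sharding` argument exists, as the TODO notes. For orientation, a plain unsharded call with illustrative shapes:

```python
import jax.numpy as jnp
from jax import lax

lhs = jnp.ones((8, 3, 32, 32))   # NCHW: batch, features, H, W
rhs = jnp.ones((16, 3, 3, 3))    # OIHW: out features, in features, kH, kW

out = lax.conv_general_dilated(
    lhs, rhs, window_strides=(1, 1), padding="SAME",
    dimension_numbers=("NCHW", "OIHW", "NCHW"))
print(out.shape)  # (8, 16, 32, 32)
```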
diff --git a/jax/_src/lax/fft.py b/jax/_src/lax/fft.py index 6ca1a4abd193..08e06287b784 100644 --- a/jax/_src/lax/fft.py +++ b/jax/_src/lax/fft.py @@ -21,8 +21,6 @@ import numpy as np -from jax import lax - from jax._src import dispatch from jax._src import dtypes from jax._src.api import jit, linear_transpose, ShapeDtypeStruct @@ -30,6 +28,7 @@ from jax._src.interpreters import ad from jax._src.interpreters import batching from jax._src.interpreters import mlir +from jax._src.lax import lax from jax._src.lib.mlir.dialects import hlo __all__ = [ @@ -124,7 +123,7 @@ def fft_abstract_eval(x, fft_type, fft_lengths): f"be equal to fft_lengths {fft_lengths}") shape = x.shape dtype = x.dtype - return x.update(shape=shape, dtype=dtype) + return x.update(shape=shape, dtype=dtype, vma=x.vma) def _fft_lowering(ctx, x, *, fft_type, fft_lengths): if not is_constant_shape(fft_lengths): diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 86a75ada63ad..fcbe86968e06 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -28,10 +28,6 @@ import numpy as np -from jax import tree_util -from jax.sharding import Sharding -from jax.tree_util import tree_map - from jax._src import ad_util from jax._src import api from jax._src import api_util @@ -46,11 +42,13 @@ from jax._src import pretty_printer as pp from jax._src import source_info_util from jax._src import state +from jax._src import tree_util from jax._src import util from jax._src.abstract_arrays import array_types from jax._src.core import (Primitive, UnshapedArray, ShapedArray, abstract_token, canonicalize_shape) from jax._src.errors import UnexpectedTracerError +from jax._src.hashable_array import HashableArray from jax._src.interpreters import ad from jax._src.interpreters import batching from jax._src.interpreters import mlir @@ -66,11 +64,12 @@ from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import chlo from jax._src.lib.mlir.dialects import hlo -from jax._src.lib import xla_extension_version +from jax._src.sharding import Sharding from jax._src.sharding_impls import (PmapSharding, NamedSharding, + ShardingContext, SPMDAxisContext, PartitionSpec as P, canonicalize_sharding) from jax._src.typing import Array, ArrayLike, DimSize, DuckTypedArray, DTypeLike, Shape -from jax._src.util import (NumpyComplexWarning, cache, canonicalize_axis, +from jax._src.util import (cache, canonicalize_axis, safe_map, safe_zip, split_list, weakref_lru_cache, foreach) @@ -238,7 +237,7 @@ def broadcast_shardings(*avals): new_spec = P(*(None,) * (ndim - a.ndim) + a.sharding.spec) new_shape = (1,) * (ndim - a.ndim) + a.shape aval_list.append(a.update(shape=new_shape, - sharding=a.sharding.with_spec(new_spec))) + sharding=a.sharding.update(spec=new_spec))) return broadcasting_sharding_rule('broadcast_shardings', *aval_list) def _identity(x, **_): return x @@ -266,8 +265,8 @@ def _merge_dyn_shape( assert next(dyn_shape_it, None) is None return shape -def _dyn_shape_staging_rule(trace, prim, out_aval, *args, **params): - source_info = source_info_util.current() +def _dyn_shape_staging_rule(trace, source_info, prim, out_aval, *args, + **params): out_tracer = pe.DynamicJaxprTracer(trace, out_aval, source_info) eqn = pe.new_jaxpr_eqn([trace.getvar(x) for x in args], [trace.makevar(out_tracer)], @@ -369,6 +368,7 @@ def nextafter(x1: ArrayLike, x2: ArrayLike) -> Array: For the smallest usable (i.e. normal) float, use ``tiny`` of ``jnp.finfo``. 
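Annotation (not part of the patch): the `standard_insert_pvary` calls added to `nextafter` and the other binary ops below are internal plumbing for the varying-axes (vma) checks; for ordinary inputs outside `shard_map` they should be a no-op, e.g.:

```python
import jax.numpy as jnp
from jax import lax

one = jnp.float32(1)
print(lax.nextafter(one, jnp.float32(2)))   # 1.0000001, one ulp above 1
print(one + jnp.finfo(jnp.float32).eps)     # same value
```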
""" + x1, x2 = core.standard_insert_pvary(x1, x2) return nextafter_p.bind(x1, x2) @export @@ -483,14 +483,41 @@ def is_finite(x: ArrayLike) -> Array: """ return is_finite_p.bind(x) +class Tolerance: + """Specify the tolerances used for computing unary functions. + + Maximum two tolerances can be specified: (atol and rtol) or (atol and ulps). + """ + + def __init__(self, atol: float = 0.0, rtol: float = 0.0, ulps: int = 0): + if atol < 0.0 or rtol < 0.0 or ulps < 0.0: + raise ValueError('Tolerances must be non-negative.') + if atol == 0.0 and rtol == 0.0 and ulps == 0: + raise ValueError('At least one of atol, rtol, or ulps must be set.') + + self.atol = atol + self.rtol = rtol + self.ulps = ulps + + +class AccuracyMode(enum.Enum): + HIGHEST = 1 + DEFAULT = 2 + @export -def exp(x: ArrayLike) -> Array: +def exp(x: ArrayLike, accuracy=None) -> Array: r"""Elementwise exponential: :math:`e^x`. This function lowers directly to the `stablehlo.exponential`_ operation. Args: x: input array. Must have floating-point or complex type. + accuracy: Optional `lax.Tolerance` or `lax.AccuracyMode` object that + selects the implementation of the op based on the requested accuracy. If + the implementation cannot satisfy the requested tolerance, the + compiler will return an error. If mode is specified and there are no + multiple implementations available, the default implementation will be + used. Returns: Array of the same shape and dtype as ``x`` containing the element-wise @@ -502,10 +529,10 @@ def exp(x: ArrayLike) -> Array: .. _stablehlo.exponential: https://openxla.org/stablehlo/spec#exponential """ - return exp_p.bind(x) + return exp_p.bind(x, accuracy=accuracy) -@export -def exp2(x: ArrayLike) -> Array: + +def exp2(x: ArrayLike, accuracy=None) -> Array: r"""Elementwise base-2 exponential: :math:`2^x`. This function is implemented in terms of the `stablehlo.exponential`_ @@ -513,6 +540,12 @@ def exp2(x: ArrayLike) -> Array: Args: x: input array. Must have floating-point or complex type. + accuracy: Optional `lax.Tolerance` or `lax.AccuracyMode` object that + selects the implementation of the op based on the requested accuracy. If + the implementation cannot satisfy the requested tolerance, the + compiler will return an error. If mode is specified and there are no + multiple implementations available, the default implementation will be + used. Returns: Array of the same shape and dtype as ``x`` containing the element-wise @@ -525,10 +558,10 @@ def exp2(x: ArrayLike) -> Array: .. _stablehlo.exponential: https://openxla.org/stablehlo/spec#exponential .. _stablehlo.multiply: https://openxla.org/stablehlo/spec#multiply """ - return exp2_p.bind(x) + return exp2_p.bind(x, accuracy=accuracy) @export -def expm1(x: ArrayLike) -> Array: +def expm1(x: ArrayLike, accuracy=None) -> Array: r"""Elementwise :math:`e^{x} - 1`. This function lowers directly to the `stablehlo.exponential_minus_one`_ @@ -537,6 +570,12 @@ def expm1(x: ArrayLike) -> Array: Args: x: input array. Must have floating-point or complex type. + accuracy: Optional `lax.Tolerance` or `lax.AccuracyMode` object that + selects the implementation of the op based on the requested accuracy. If + the implementation cannot satisfy the requested tolerance, the + compiler will return an error. If mode is specified and there are no + multiple implementations available, the default implementation will be + used. Returns: Array of the same shape and dtype as ``x`` containing the element-wise @@ -548,16 +587,22 @@ def expm1(x: ArrayLike) -> Array: .. 
_stablehlo.exponential_minus_one: https://openxla.org/stablehlo/spec#exponential_minus_one """ - return expm1_p.bind(x) + return expm1_p.bind(x, accuracy=accuracy) @export -def log(x: ArrayLike) -> Array: +def log(x: ArrayLike, accuracy=None) -> Array: r"""Elementwise natural logarithm: :math:`\mathrm{log}(x)`. This function lowers directly to the `stablehlo.log`_ operation. Args: x: input array. Must have floating-point or complex type. + accuracy: Optional `lax.Tolerance` or `lax.AccuracyMode` object that + selects the implementation of the op based on the requested accuracy. If + the implementation cannot satisfy the requested tolerance, the + compiler will return an error. If mode is specified and there are no + multiple implementations available, the default implementation will be + used. Returns: Array of the same shape and dtype as ``x`` containing the element-wise @@ -568,10 +613,10 @@ def log(x: ArrayLike) -> Array: .. _stablehlo.log: https://openxla.org/stablehlo/spec#log """ - return log_p.bind(x) + return log_p.bind(x, accuracy=accuracy) @export -def log1p(x: ArrayLike) -> Array: +def log1p(x: ArrayLike, accuracy=None) -> Array: r"""Elementwise :math:`\mathrm{log}(1 + x)`. This function lowers directly to the `stablehlo.log_plus_one`_ operation. @@ -580,6 +625,12 @@ def log1p(x: ArrayLike) -> Array: Args: x: input array. Must have floating-point or complex type. + accuracy: Optional `lax.Tolerance` or `lax.AccuracyMode` object that + selects the implementation of the op based on the requested accuracy. If + the implementation cannot satisfy the requested tolerance, the + compiler will return an error. If mode is specified and there are no + multiple implementations available, the default implementation will be + used. Returns: Array of the same shape and dtype as ``x`` containing the element-wise @@ -591,16 +642,22 @@ def log1p(x: ArrayLike) -> Array: .. _stablehlo.log_plus_one: https://openxla.org/stablehlo/spec#log_plus_one """ - return log1p_p.bind(x) + return log1p_p.bind(x, accuracy=accuracy) @export -def tanh(x: ArrayLike) -> Array: +def tanh(x: ArrayLike, accuracy=None) -> Array: r"""Elementwise hyperbolic tangent: :math:`\mathrm{tanh}(x)`. This function lowers directly to the `stablehlo.tanh`_ operation. Args: x: input array. Must have floating-point or complex type. + accuracy: Optional `lax.Tolerance` or `lax.AccuracyMode` object that + selects the implementation of the op based on the requested accuracy. If + the implementation cannot satisfy the requested tolerance, the + compiler will return an error. If mode is specified and there are no + multiple implementations available, the default implementation will be + used. Returns: Array of the same shape and dtype as ``x`` containing the element-wise @@ -613,10 +670,11 @@ def tanh(x: ArrayLike) -> Array: .. _stablehlo.tanh: https://openxla.org/stablehlo/spec#tanh """ - return tanh_p.bind(x) + return tanh_p.bind(x, accuracy=accuracy) @export -def logistic(x: ArrayLike) -> Array: + +def logistic(x: ArrayLike, accuracy=None) -> Array: r"""Elementwise logistic (sigmoid) function: :math:`\frac{1}{1 + e^{-x}}`. There is no HLO logistic/sigmoid primitive, so this lowers to a sequence @@ -632,10 +690,10 @@ def logistic(x: ArrayLike) -> Array: See also: - :func:`jax.nn.sigmoid`: an alternative API for this functionality. 
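Annotation (not part of the patch): the new `accuracy` argument threads through all of these unary ops. A usage sketch, assuming `Tolerance` and `AccuracyMode` are re-exported under `jax.lax` as the docstrings suggest:

```python
import jax.numpy as jnp
from jax import lax

x = jnp.linspace(-1.0, 1.0, 8, dtype=jnp.float32)

# Numeric tolerances: at least one of atol/rtol/ulps must be non-zero.
y = lax.exp(x, accuracy=lax.Tolerance(atol=1e-6, rtol=1e-6))

# Or a named mode; the compiler errors out if the request cannot be met.
z = lax.tanh(x, accuracy=lax.AccuracyMode.HIGHEST)
```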
""" - return logistic_p.bind(x) + return logistic_p.bind(x, accuracy=accuracy) @export -def sin(x: ArrayLike) -> Array: +def sin(x: ArrayLike, accuracy=None) -> Array: r"""Elementwise sine: :math:`\mathrm{sin}(x)`. For floating-point inputs, this function lowers directly to the @@ -644,6 +702,12 @@ def sin(x: ArrayLike) -> Array: Args: x: input array. Must have floating-point or complex type. + accuracy: Optional `lax.Tolerance` or `lax.AccuracyMode` object that + selects the implementation of the op based on the requested accuracy. If + the implementation cannot satisfy the requested tolerance, the + compiler will return an error. If mode is specified and there are no + multiple implementations available, the default implementation will be + used. Returns: Array of the same shape and dtype as ``x`` containing the element-wise @@ -656,10 +720,10 @@ def sin(x: ArrayLike) -> Array: .. _stablehlo.sine: https://openxla.org/stablehlo/spec#sine """ - return sin_p.bind(x) + return sin_p.bind(x, accuracy=accuracy) @export -def cos(x: ArrayLike) -> Array: +def cos(x: ArrayLike, accuracy=None) -> Array: r"""Elementwise cosine: :math:`\mathrm{cos}(x)`. For floating-point inputs, this function lowers directly to the @@ -668,6 +732,12 @@ def cos(x: ArrayLike) -> Array: Args: x: input array. Must have floating-point or complex type. + accuracy: Optional `lax.Tolerance` or `lax.AccuracyMode` object that + selects the implementation of the op based on the requested accuracy. If + the implementation cannot satisfy the requested tolerance, the + compiler will return an error. If mode is specified and there are no + multiple implementations available, the default implementation will be + used. Returns: Array of the same shape and dtype as ``x`` containing the element-wise @@ -680,7 +750,7 @@ def cos(x: ArrayLike) -> Array: .. _stablehlo.cosine: https://openxla.org/stablehlo/spec#cosine """ - return cos_p.bind(x) + return cos_p.bind(x, accuracy=accuracy) @export def atan2(x: ArrayLike, y: ArrayLike) -> Array: @@ -704,6 +774,7 @@ def atan2(x: ArrayLike, y: ArrayLike) -> Array: .. _stablehlo.atan2: https://openxla.org/stablehlo/spec#atan2 """ + x, y = core.standard_insert_pvary(x, y) return atan2_p.bind(x, y) @export @@ -773,6 +844,7 @@ def complex(x: ArrayLike, y: ArrayLike) -> Array: .. _stablehlo.complex: https://openxla.org/stablehlo/spec#complex """ + x, y = core.standard_insert_pvary(x, y) return complex_p.bind(x, y) @export @@ -844,6 +916,7 @@ def pow(x: ArrayLike, y: ArrayLike) -> Array: .. _stablehlo.convert: https://openxla.org/stablehlo/spec#convert .. _stablehlo.pow: https://openxla.org/stablehlo/spec#pow """ + x, y = core.standard_insert_pvary(x, y) return pow_p.bind(x, y) @export @@ -861,20 +934,27 @@ def integer_pow(x: ArrayLike, y: int) -> Array: An array of the same shape and dtype as ``x`` containing the elementwise power. See also: - :func:`jax.lax.pow`: Elementwise pwoer where ``y`` is an array. + :func:`jax.lax.pow`: Elementwise power where ``y`` is an array. .. _stablehlo.multiply: https://openxla.org/stablehlo/spec#multiply """ return integer_pow_p.bind(x, y=y) + @export -def sqrt(x: ArrayLike) -> Array: +def sqrt(x: ArrayLike, accuracy=None) -> Array: r"""Elementwise square root: :math:`\sqrt{x}`. This function lowers directly to the `stablehlo.sqrt`_ operation. Args: x: Input array. Must have floating or complex dtype. + accuracy: Optional `lax.Tolerance` or `lax.AccuracyMode` object that + selects the implementation of the op based on the requested accuracy. 
If + the implementation cannot satisfy the requested tolerance, the + compiler will return an error. If mode is specified and there are no + multiple implementations available, the default implementation will be + used. Returns: An array of the same shape and dtype as ``x`` containing the square root. @@ -886,16 +966,22 @@ def sqrt(x: ArrayLike) -> Array: .. _stablehlo.sqrt: https://openxla.org/stablehlo/spec#sqrt """ - return sqrt_p.bind(x) + return sqrt_p.bind(x, accuracy=accuracy) @export -def rsqrt(x: ArrayLike) -> Array: +def rsqrt(x: ArrayLike, accuracy=None) -> Array: r"""Elementwise reciprocal square root: :math:`1 \over \sqrt{x}`. This function lowers directly to the `stablehlo.rsqrt`_ operation. Args: x: Input array. Must have floating or complex dtype. + accuracy: Optional `lax.Tolerance` or `lax.AccuracyMode` object that + selects the implementation of the op based on the requested accuracy. If + the implementation cannot satisfy the requested tolerance, the + compiler will return an error. If mode is specified and there are no + multiple implementations available, the default implementation will be + used. Returns: An array of the same shape and dtype as ``x`` containing the @@ -908,16 +994,22 @@ def rsqrt(x: ArrayLike) -> Array: .. _stablehlo.rsqrt: https://openxla.org/stablehlo/spec#rsqrt """ - return rsqrt_p.bind(x) + return rsqrt_p.bind(x, accuracy=accuracy) @export -def cbrt(x: ArrayLike) -> Array: +def cbrt(x: ArrayLike, accuracy=None) -> Array: r"""Elementwise cube root: :math:`\sqrt[3]{x}`. This function lowers directly to the `stablehlo.cbrt`_ operation. Args: x: Input array. Must have floating or complex dtype. + accuracy: Optional `lax.Tolerance` or `lax.AccuracyMode` object that + selects the implementation of the op based on the requested accuracy. If + the implementation cannot satisfy the requested tolerance, the + compiler will return an error. If mode is specified and there are no + multiple implementations available, the default implementation will be + used. Returns: An array of the same shape and dtype as ``x`` containing the cube root. @@ -929,7 +1021,7 @@ def cbrt(x: ArrayLike) -> Array: .. _stablehlo.cbrt: https://openxla.org/stablehlo/spec#cbrt """ - return cbrt_p.bind(x) + return cbrt_p.bind(x, accuracy=accuracy) @export def bitwise_not(x: ArrayLike) -> Array: @@ -979,6 +1071,7 @@ def bitwise_and(x: ArrayLike, y: ArrayLike) -> Array: .. _stablehlo.and: https://openxla.org/stablehlo/spec#and """ + x, y = core.standard_insert_pvary(x, y) return and_p.bind(x, y) @export @@ -1005,6 +1098,7 @@ def bitwise_or(x: ArrayLike, y: ArrayLike) -> Array: .. _stablehlo.or: https://openxla.org/stablehlo/spec#or """ + x, y = core.standard_insert_pvary(x, y) return or_p.bind(x, y) @export @@ -1031,6 +1125,7 @@ def bitwise_xor(x: ArrayLike, y: ArrayLike) -> Array: .. _stablehlo.xor: https://openxla.org/stablehlo/spec#xor """ + x, y = core.standard_insert_pvary(x, y) return xor_p.bind(x, y) @export @@ -1095,6 +1190,7 @@ def add(x: ArrayLike, y: ArrayLike) -> Array: .. _stablehlo.add: https://openxla.org/stablehlo/spec#add """ + x, y = core.standard_insert_pvary(x, y) return add_p.bind(x, y) @export @@ -1118,6 +1214,7 @@ def sub(x: ArrayLike, y: ArrayLike) -> Array: .. _stablehlo.subtract: https://openxla.org/stablehlo/spec#subtract """ + x, y = core.standard_insert_pvary(x, y) return sub_p.bind(x, y) @export @@ -1141,6 +1238,7 @@ def mul(x: ArrayLike, y: ArrayLike) -> Array: .. 
_stablehlo.multiply: https://openxla.org/stablehlo/spec#multiply """ + x, y = core.standard_insert_pvary(x, y) return mul_p.bind(x, y) @export @@ -1170,6 +1268,7 @@ def div(x: ArrayLike, y: ArrayLike) -> Array: .. _stablehlo.divide: https://openxla.org/stablehlo/spec#divide """ + x, y = core.standard_insert_pvary(x, y) return div_p.bind(x, y) @export @@ -1197,6 +1296,7 @@ def rem(x: ArrayLike, y: ArrayLike) -> Array: .. _stablehlo.remainder: https://openxla.org/stablehlo/spec#remainder """ + x, y = core.standard_insert_pvary(x, y) return rem_p.bind(x, y) @export @@ -1222,6 +1322,7 @@ def max(x: ArrayLike, y: ArrayLike) -> Array: .. _stablehlo.maximum: https://openxla.org/stablehlo/spec#maximum """ + x, y = core.standard_insert_pvary(x, y) return max_p.bind(x, y) @export @@ -1247,6 +1348,7 @@ def min(x: ArrayLike, y: ArrayLike) -> Array: .. _stablehlo.minimum: https://openxla.org/stablehlo/spec#minimum """ + x, y = core.standard_insert_pvary(x, y) return min_p.bind(x, y) @export @@ -1272,6 +1374,7 @@ def shift_left(x: ArrayLike, y: ArrayLike) -> Array: .. _stablehlo.shift_left: https://openxla.org/stablehlo/spec#shift_left """ + x, y = core.standard_insert_pvary(x, y) return shift_left_p.bind(x, y) @export @@ -1298,6 +1401,7 @@ def shift_right_arithmetic(x: ArrayLike, y: ArrayLike) -> Array: .. _stablehlo.shift_right_arithmetic: https://openxla.org/stablehlo/spec#shift_right_arithmetic """ + x, y = core.standard_insert_pvary(x, y) return shift_right_arithmetic_p.bind(x, y) @export @@ -1324,6 +1428,7 @@ def shift_right_logical(x: ArrayLike, y: ArrayLike) -> Array: .. _stablehlo.shift_right_logical: https://openxla.org/stablehlo/spec#shift_right_logical """ + x, y = core.standard_insert_pvary(x, y) return shift_right_logical_p.bind(x, y) @export @@ -1354,6 +1459,7 @@ def eq(x: ArrayLike, y: ArrayLike) -> Array: .. _stablehlo.compare: https://openxla.org/stablehlo/spec#compare """ + x, y = core.standard_insert_pvary(x, y) return eq_p.bind(x, y) @export @@ -1384,6 +1490,7 @@ def ne(x: ArrayLike, y: ArrayLike) -> Array: .. _stablehlo.compare: https://openxla.org/stablehlo/spec#compare """ + x, y = core.standard_insert_pvary(x, y) return ne_p.bind(x, y) @export @@ -1414,6 +1521,7 @@ def ge(x: ArrayLike, y: ArrayLike) -> Array: .. _stablehlo.compare: https://openxla.org/stablehlo/spec#compare """ + x, y = core.standard_insert_pvary(x, y) return ge_p.bind(x, y) @export @@ -1444,6 +1552,7 @@ def gt(x: ArrayLike, y: ArrayLike) -> Array: .. _stablehlo.compare: https://openxla.org/stablehlo/spec#compare """ + x, y = core.standard_insert_pvary(x, y) return gt_p.bind(x, y) @export @@ -1474,6 +1583,7 @@ def le(x: ArrayLike, y: ArrayLike) -> Array: .. _stablehlo.compare: https://openxla.org/stablehlo/spec#compare """ + x, y = core.standard_insert_pvary(x, y) return le_p.bind(x, y) @export @@ -1504,6 +1614,7 @@ def lt(x: ArrayLike, y: ArrayLike) -> Array: .. 
_stablehlo.compare: https://openxla.org/stablehlo/spec#compare """ + x, y = core.standard_insert_pvary(x, y) return lt_p.bind(x, y) @export @@ -1573,12 +1684,11 @@ def _convert_element_type( "Instead, convert to and from their representation dtypes, e.g.:\n" f"{dtype_to_string(old_dtype)} -> {dtype_to_string(old_rep_dtype)} " f"-> {dtype_to_string(new_rep_dtype)} -> {dtype_to_string(new_dtype)}") + if isinstance(new_dtype, dtypes.ExtendedDType): return to_edtype_p.bind(operand, edtype=new_dtype) return from_edtype_p.bind(operand, dtype=np.dtype(new_dtype)) - new_dtype = type_cast(DTypeLike | None, new_dtype) - old_weak_type = dtypes.is_weakly_typed(operand) if new_dtype is None: new_dtype = old_dtype @@ -1593,7 +1703,7 @@ def _convert_element_type( dtypes.issubdtype(old_dtype, np.complexfloating) and not dtypes.issubdtype(new_dtype, np.complexfloating)): msg = "Casting complex values to real discards the imaginary part" - warnings.warn(msg, NumpyComplexWarning, stacklevel=2) + warnings.warn(msg, np.exceptions.ComplexWarning, stacklevel=2) # Python has big integers, but convert_element_type(2 ** 100, np.float32) need # not be an error since the target dtype fits the value. Handle this case by @@ -1658,6 +1768,7 @@ def clamp(min: ArrayLike, x: ArrayLike, max: ArrayLike) -> Array: x & \text{otherwise} \end{cases}`. """ + min, x, max = core.standard_insert_pvary(min, x, max) return clamp_p.bind(min, x, max) @@ -1764,10 +1875,18 @@ def _decorator(*args, **kwargs): closed_jaxpr, out_tree = _trace_composite_to_jaxpr( partial(decomposition, **kwargs), in_tree, in_avals, name, debug_info ) + attributes = [] + for k, v in kwargs.items(): + leaves, treedef = tree_util.tree_flatten(v) + leaves = tuple( + HashableArray(v) if isinstance(v, np.ndarray) else v for v in leaves + ) + attributes.append((k, leaves, treedef)) + flat_args = core.standard_insert_pvary(*flat_args) out_flat = composite_p.bind( *flat_args, name=name, - attributes=tuple((k, v) for k, v in kwargs.items()), + attributes=tuple(attributes), version=version, jaxpr=closed_jaxpr, ) @@ -1780,7 +1899,7 @@ def _composite_lowering( ctx: mlir.LoweringRuleContext, *args: Any, name: str, - attributes: Sequence[tuple[str, Any]], + attributes: Sequence[tuple[str, tuple[Any, ...], tree_util.PyTreeDef]], version: int, jaxpr: core.ClosedJaxpr, ): @@ -1807,11 +1926,11 @@ def _composite_lowering( ctx.avals_out, ctx.tokens_in, ) - composite_attrs = { - k : mlir.ir_attribute(v) - for k, v in attributes - if v is not None - } + composite_attrs = {} + for k, leaves, treedef in attributes: + v = treedef.unflatten(leaves) + if v is not None: + composite_attrs[k] = mlir.ir_attribute(v) symbol_name = func_op.name.value composite = hlo.CompositeOp( func_op.type.results, @@ -1838,7 +1957,7 @@ def composite_jvp(*args, **_): raise ValueError( "JVP rule for composite not implemented. You can use `jax.custom_jvp` to " "add support. See " - "https://jax.readthedocs.io/en/latest/_autosummary/jax.custom_jvp.html" + "https://docs.jax.dev/en/latest/_autosummary/jax.custom_jvp.html" ) @@ -1847,7 +1966,7 @@ def composite_transpose(*args, **_): raise ValueError( "Transpose rule for composite not implemented. You can use" "`jax.custom_jvp` or `jax.custom_vjp` to add support. 
See " - "https://jax.readthedocs.io/en/latest/_autosummary/jax.custom_jvp.html" + "https://docs.jax.dev/en/latest/_autosummary/jax.custom_jvp.html" ) @@ -1881,6 +2000,7 @@ def concatenate(operands: Array | Sequence[ArrayLike], dimension: int) -> Array: op, = operands if isinstance(op, Array): return op + operands = core.standard_insert_pvary(*operands) return concatenate_p.bind(*operands, dimension=dimension) @@ -1986,7 +2106,7 @@ class DotAlgorithm(NamedTuple): The `StableHLO spec `_ for the dot operation doesn't require that the precision types be the same as the - storage types for the inputs or outputs, but some plaforms may require that + storage types for the inputs or outputs, but some platforms may require that these types match. Furthermore, the return type of :func:`~jax.lax.dot_general` is always defined by the ``accumulation_type`` parameter of the input algorithm, if specified. @@ -2230,13 +2350,10 @@ def _convert_to_hlo_attr(self, lhs_dtype: DTypeLike, np.dtype(dtypes.float8_e4m3fnuz), np.dtype(dtypes.float8_e5m2), np.dtype(dtypes.float8_e5m2fnuz), + np.dtype(dtypes.float8_e3m4), + np.dtype(dtypes.float8_e4m3), + np.dtype(dtypes.float8_e8m0fnu), ] - if dtypes.float8_e3m4 is not None: - fp8_dtypes += [np.dtype(dtypes.float8_e3m4)] - if dtypes.float8_e4m3 is not None: - fp8_dtypes += [np.dtype(dtypes.float8_e4m3)] - if dtypes.float8_e8m0fnu is not None: - fp8_dtypes += [np.dtype(dtypes.float8_e8m0fnu)] if lhs_dtype not in fp8_dtypes or rhs_dtype not in fp8_dtypes: raise ValueError( f"The dot algorithm '{self}' requires both inputs to have float8 " @@ -2267,11 +2384,6 @@ def _convert_to_hlo_attr(self, lhs_dtype: DTypeLike, case DotAlgorithmPreset.BF16_BF16_F32_X6: return hlo.DotAlgorithm.get(bf16, bf16, f32, 1, 1, 6, False) case DotAlgorithmPreset.BF16_BF16_F32_X9: - if xla_extension_version < 320: - raise ValueError( - "The dot algorithm BF16_BF16_F32_X9 requires XLA extension " - "version >= 320." - ) return hlo.DotAlgorithm.get(bf16, bf16, f32, 1, 1, 9, False) case DotAlgorithmPreset.TF32_TF32_F32: return hlo.DotAlgorithm.get(tf32, tf32, f32, 1, 1, 1, False) @@ -2352,6 +2464,7 @@ def dot(lhs: Array, rhs: Array, precision: PrecisionLike = None, def dot_general(lhs: ArrayLike, rhs: ArrayLike, dimension_numbers: DotDimensionNumbers, precision: PrecisionLike = None, preferred_element_type: DTypeLike | None = None, + *, out_sharding=None) -> Array: """General dot product/contraction operator. @@ -2398,10 +2511,6 @@ def dot_general(lhs: ArrayLike, rhs: ArrayLike, dimension_numbers: DotDimensionN by the ``lhs`` non-contracting/non-batch dimensions, and finally the ``rhs`` non-contracting/non-batch dimensions. """ - if out_sharding is not None and not isinstance(out_sharding, NamedSharding): - raise NotImplementedError( - '`out_sharding` argument of `dot_general` only supports NamedSharding ' - 'instances. 
Please file a bug if this is not enough for your use case.') out_sharding = canonicalize_sharding(out_sharding, 'dot_general') (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers cdims = (api_util._ensure_index_tuple(lhs_contract), @@ -2411,6 +2520,7 @@ def dot_general(lhs: ArrayLike, rhs: ArrayLike, dimension_numbers: DotDimensionN preferred_element_type = ( None if preferred_element_type is None else dtypes.canonicalize_dtype(np.dtype(preferred_element_type))) + lhs, rhs = core.standard_insert_pvary(lhs, rhs) return dot_general_p.bind(lhs, rhs, dimension_numbers=(cdims, bdims), precision=canonicalize_precision(precision), @@ -2546,6 +2656,7 @@ def ragged_dot_general( extra leading dimension of size `g` in the case where the lhs ragged dimension is a contracting dimension. """ + lhs, rhs, group_sizes = core.standard_insert_pvary(lhs, rhs, group_sizes) return ragged_dot_general_p.bind( lhs, rhs, @@ -2557,7 +2668,7 @@ def ragged_dot_general( ) -def broadcast(operand: ArrayLike, sizes: Sequence[int], out_sharding=None +def broadcast(operand: ArrayLike, sizes: Sequence[int], *, out_sharding=None ) -> Array: """Broadcasts an array, adding new leading dimensions @@ -2579,7 +2690,7 @@ def broadcast(operand: ArrayLike, sizes: Sequence[int], out_sharding=None out_sharding=out_sharding) def broadcast_in_dim(operand: ArrayLike, shape: Shape, - broadcast_dimensions: Sequence[int], out_sharding=None + broadcast_dimensions: Sequence[int], *, out_sharding=None ) -> Array: """Wraps XLA's `BroadcastInDim `_ @@ -2598,6 +2709,7 @@ def broadcast_in_dim(operand: ArrayLike, shape: Shape, See Also: jax.lax.broadcast : simpler interface to add new leading dimensions. """ + # TODO(dfm): Re-write this as a "reshard" when only the sharding changes. out_sharding = canonicalize_sharding(out_sharding, 'broadcast_in_dim') if (np.ndim(operand) == len(shape) and not len(broadcast_dimensions) and isinstance(operand, Array) and out_sharding is None): @@ -2622,7 +2734,7 @@ def broadcast_to_rank(x: ArrayLike, rank: int) -> Array: def reshape(operand: ArrayLike, new_sizes: Shape, dimensions: Sequence[int] | None = None, - out_sharding: NamedSharding | P | None = None) -> Array: + *, out_sharding: NamedSharding | P | None = None) -> Array: """Wraps XLA's `Reshape `_ operator. @@ -2729,6 +2841,7 @@ def pad(operand: ArrayLike, padding_value: ArrayLike, [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], dtype=int32) """ + operand, padding_value = core.standard_insert_pvary(operand, padding_value) return pad_p.bind(operand, padding_value, padding_config=tuple(padding_config)) def rev(operand: ArrayLike, dimensions: Sequence[int]) -> Array: @@ -2761,6 +2874,8 @@ def select(pred: ArrayLike, on_true: ArrayLike, on_false: ArrayLike) -> Array: """ # Caution! The select_n_p primitive has the *opposite* order of arguments to # select(). This is because it implements `select_n`. 
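Annotation (not part of the patch): several shape-manipulation entry points above (`broadcast`, `broadcast_in_dim`, `reshape`, and `broadcasted_iota` further down) now take `out_sharding` keyword-only, and `dot_general` drops its NamedSharding-only restriction in favour of `canonicalize_sharding`. A sketch of the call-site impact:

```python
import jax.numpy as jnp
from jax import lax

x = jnp.ones((4,))

y = lax.broadcast(x, (2,), out_sharding=None)   # keyword: fine
# lax.broadcast(x, (2,), None)                  # positional: now a TypeError
```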
+ pred, on_false, on_true = core.standard_insert_pvary( + pred, on_false, on_true) return select_n_p.bind(pred, on_false, on_true) def select_n(which: ArrayLike, *cases: ArrayLike) -> Array: @@ -2786,6 +2901,7 @@ def select_n(which: ArrayLike, *cases: ArrayLike) -> Array: """ if len(cases) == 0: raise ValueError("select_n() must have at least one case") + which, *cases = core.standard_insert_pvary(which, *cases) return select_n_p.bind(which, *cases) @@ -2799,6 +2915,7 @@ def transpose(operand: ArrayLike, if permutation == tuple(range(np.ndim(operand))) and isinstance(operand, Array): return operand else: + return transpose_p.bind(operand, permutation=permutation) def argmin(operand: ArrayLike, axis: int, @@ -2848,6 +2965,8 @@ def reduce(operands: Any, flat_init_avals = safe_map(core.get_aval, flat_init_values) closed_jaxpr, out_tree = _variadic_reduction_jaxpr( computation, comp_debug, tuple(flat_init_avals), init_value_tree) + flat_operands = core.standard_insert_pvary(*flat_operands) + flat_init_avals = core.standard_insert_pvary(*flat_init_values) out = reduce_p.bind(*flat_operands, *flat_init_values, computation=computation, jaxpr=closed_jaxpr, dimensions=tuple(dimensions)) return tree_util.tree_unflatten(out_tree, out) @@ -3146,6 +3265,7 @@ def sort(operand: Array | Sequence[Array], dimension: int = -1, if not (1 <= num_keys <= len(operand)): raise ValueError(f"{num_keys=} must be between 1 and {len(operand)=}") dimension = canonicalize_axis(dimension, len(operand[0].shape)) + operand = core.standard_insert_pvary(*operand) return tuple(sort_p.bind(*operand, dimension=dimension, is_stable=is_stable, num_keys=num_keys)) @@ -3241,8 +3361,8 @@ def zeros_like_shaped_array(aval: ShapedArray) -> Array: scalar_zero = np.zeros((), dtype=aval.dtype) else: scalar_zero = _convert_element_type(0, aval.dtype, aval.weak_type) - return broadcast(scalar_zero, aval.shape, out_sharding=aval.sharding) - + out = broadcast(scalar_zero, aval.shape, out_sharding=aval.sharding) + return core.pvary(out, tuple(aval.vma)) ad_util.aval_zeros_likers[ShapedArray] = zeros_like_shaped_array def zeros_like_abstract_ref(aval: state.AbstractRef) -> core.MutableArray: @@ -3262,7 +3382,7 @@ def iota(dtype: DTypeLike, size: int) -> Array: return broadcasted_iota(dtype, (size,), 0) def broadcasted_iota(dtype: DTypeLike, shape: Shape, dimension: int, - out_sharding=None) -> Array: + *, out_sharding=None) -> Array: """Convenience wrapper around ``iota``.""" dtype = dtypes.canonicalize_dtype(dtype) shape = canonicalize_shape(shape) @@ -3365,7 +3485,7 @@ def stop(x): return ad_util.stop_gradient_p.bind(x) else: return x - return tree_map(stop, x) + return tree_util.tree_map(stop, x) def reduce_precision(operand: float | ArrayLike, exponent_bits: int, @@ -3378,7 +3498,8 @@ def reduce_precision(operand: float | ArrayLike, operator.index, exponent_bits, "exponent_bits argument of lax.reduce_precision") mantissa_bits = core.concrete_or_error( operator.index, mantissa_bits, "mantissa_bits argument of lax.reduce_precision") - return reduce_precision_p.bind(operand, exponent_bits=exponent_bits, mantissa_bits=mantissa_bits) + return reduce_precision_p.bind(operand, exponent_bits=exponent_bits, + mantissa_bits=mantissa_bits) def squeeze(array: ArrayLike, dimensions: Sequence[int]) -> Array: """Squeeze any number of size 1 dimensions from an array.""" @@ -3445,12 +3566,13 @@ def full_like(x: ArrayLike | DuckTypedArray, # This bypasses the check. 
and not isinstance(x, core.Tracer) and hasattr(x, 'sharding') + and x.sharding is not None + and x.sharding._is_concrete and getattr(x, '_committed', True) and not weak_type and fill_shape == np.shape(x) # type: ignore[arg-type] ) if use_x_sharding: - # TODO(yashkatariya): Use shard_alike in tracing_mode once it is supported. sharding = x.sharding # type: ignore val = full(fill_shape, _convert_element_type(fill_value, dtype, weak_type), sharding=sharding) @@ -3513,13 +3635,19 @@ def reciprocal(x: ArrayLike) -> Array: return integer_pow(x, -1) @export -def tan(x: ArrayLike) -> Array: +def tan(x: ArrayLike, accuracy=None) -> Array: r"""Elementwise tangent: :math:`\mathrm{tan}(x)`. This function lowers directly to the `stablehlo.tangent`_ operation. Args: x: input array. Must have floating-point or complex type. + accuracy: Optional `lax.Tolerance` or `lax.AccuracyMode` object that + selects the implementation of the op based on the requested accuracy. If + the implementation cannot satisfy the requested tolerance, the + compiler will return an error. If mode is specified and there are no + multiple implementations available, the default implementation will be + used. Returns: Array of the same shape and dtype as ``x`` containing the element-wise @@ -3533,7 +3661,7 @@ def tan(x: ArrayLike) -> Array: .. _stablehlo.tangent: https://openxla.org/stablehlo/spec#tangent """ - return tan_p.bind(x) + return tan_p.bind(x, accuracy=accuracy) @export def asin(x: ArrayLike) -> Array: @@ -3762,7 +3890,8 @@ def unop_dtype_rule(result_dtype, accepted_dtypes, name, aval, **kwargs): def unop(result_dtype, accepted_dtypes, name): dtype_rule = partial(unop_dtype_rule, result_dtype, accepted_dtypes, name) prim = standard_primitive(_attrgetter('shape'), dtype_rule, name, - sharding_rule=_attrgetter('sharding')) + sharding_rule=_attrgetter('sharding'), + vma_rule=_attrgetter('vma')) batching.defvectorized(prim) pe.def_trivial_padding(prim) return prim @@ -3811,7 +3940,7 @@ def broadcasting_sharding_rule(name, *avals): for a in avals: if a.sharding is not None and not a.sharding.mesh.empty: if mesh is not None and mesh != a.sharding.mesh: - raise ValueError( + raise core.ShardingTypeError( f'Mesh for all inputs should be equal. Got one mesh: {mesh} and' f' another mesh: {a.sharding.mesh}') mesh = a.sharding.mesh @@ -3828,7 +3957,7 @@ def broadcasting_sharding_rule(name, *avals): result_specs = [None] * len(shapes[0]) for i, (ss, ds) in enumerate(zip(zip(*specs), zip(*shapes))): - if all(s == ss[0] for s in ss[1:]): + if all(ss[0] == s for s in ss[1:]): # if all dimension shardings are same, the resulting dimension sharding is # the same. 
result_specs[i] = ss[0] @@ -3845,21 +3974,22 @@ def broadcasting_sharding_rule(name, *avals): result_specs[i] = s elif (result_specs[i] is not None and s is not None and result_specs[i] != s): - raise TypeError( + raise core.ShardingTypeError( f'{name} got incompatible shardings for broadcasting: ' f'{", ".join(map(str, map(tuple, specs)))}.') return NamedSharding(mesh, P(*result_specs)) - def naryop(result_dtype, accepted_dtypes, name, allow_extended_dtype=False, - require_same_dtypes=True): + require_same_dtypes=True, unreduced_rule=None): dtype_rule = partial(naryop_dtype_rule, result_dtype, accepted_dtypes, name, allow_extended_dtype=allow_extended_dtype, require_same=require_same_dtypes) shape_rule = partial(broadcasting_shape_rule, name) sharding_rule = partial(broadcasting_sharding_rule, name) - prim = standard_primitive(shape_rule, dtype_rule, name, - sharding_rule=sharding_rule) + prim = standard_primitive( + shape_rule, dtype_rule, name, sharding_rule=sharding_rule, + vma_rule=partial(core.standard_vma_rule, name), + unreduced_rule=unreduced_rule) batching.defbroadcasting(prim) pe.def_trivial_padding(prim) return prim @@ -3926,8 +4056,9 @@ def multi_sharding_in_dim(ctx, ops, in_avals, out_aval): return out -def _nary_lower_hlo(op: Callable, ctx, - *args: ir.Value, **params) -> Sequence[ir.Value]: +def _nary_lower_hlo( + op: Callable, ctx, *args: ir.Value, accuracy=None, **params +) -> Sequence[ir.Value]: """Lowers an elementwise operator to its MLIR equivalent. """ del params @@ -3936,8 +4067,15 @@ def _nary_lower_hlo(op: Callable, ctx, args = multi_sharding_in_dim(ctx, args, avals_in, aval_out) out = op(*args) + if accuracy: + out = op(*args, result_accuracy=accuracy_attr(accuracy)) return [mlir.lower_with_sharding_in_types(ctx, out, aval_out)] +def _unary_with_accuracy_pp_rule(eqn, context, settings): + params = dict(eqn.params) + if 'accuracy' in params and params['accuracy'] is None: + del params['accuracy'] + return core._pp_eqn(eqn.replace(params=params), context, settings) _float = {np.floating} _complex = {np.complexfloating} @@ -3997,48 +4135,68 @@ def _round_lower(ctx, x, *, rounding_method): mlir.register_lowering(is_finite_p, partial(_nary_lower_hlo, hlo.is_finite)) exp_p = standard_unop(_float | _complex, 'exp') -ad.defjvp2(exp_p, lambda g, ans, x: mul(g, ans)) +ad.defjvp2(exp_p, lambda g, ans, x, **kwargs: mul(g, ans)) mlir.register_lowering(exp_p, partial(_nary_lower_hlo, hlo.exponential)) batching.ragged_prop_rules[exp_p] = batching.ragged_mask_elementwise_rule +core.pp_eqn_rules[exp_p] = _unary_with_accuracy_pp_rule exp2_p = standard_unop(_float | _complex, 'exp2') -ad.defjvp2(exp2_p, lambda g, ans, x: mul(log(_const(x, 2)), mul(g, ans))) -def _exp2_lower(ctx, x): +ad.defjvp2( + exp2_p, lambda g, ans, x, **kwargs: mul(log(_const(x, 2)), mul(g, ans)) +) + +def _exp2_lower(ctx, x, accuracy): x_aval, = ctx.avals_in log2 = mlir.ir_constant(np.array(np.log(2), x_aval.dtype)) log2 = mlir.broadcast_in_dim(ctx, log2, x_aval, broadcast_dimensions=()) - return [hlo.exponential(hlo.multiply(log2, x))] + return [ + hlo.exponential( + hlo.multiply(log2, x), result_accuracy=accuracy_attr(accuracy) + ) + ] + mlir.register_lowering(exp2_p, _exp2_lower) +core.pp_eqn_rules[exp2_p] = _unary_with_accuracy_pp_rule log_p = standard_unop(_float | _complex, 'log') -ad.defjvp(log_p, lambda g, x: div(g, x)) +ad.defjvp(log_p, lambda g, x, **kwargs: div(g, x)) mlir.register_lowering(log_p, partial(_nary_lower_hlo, hlo.log)) +core.pp_eqn_rules[log_p] = _unary_with_accuracy_pp_rule expm1_p = 
standard_unop(_float | _complex, 'expm1') -ad.defjvp2(expm1_p, lambda g, ans, x: mul(g, add(ans, _one(ans)))) +ad.defjvp2(expm1_p, lambda g, ans, x, **kwargs: mul(g, add(ans, _one(ans)))) mlir.register_lowering(expm1_p, partial(_nary_lower_hlo, hlo.exponential_minus_one)) +core.pp_eqn_rules[expm1_p] = _unary_with_accuracy_pp_rule log1p_p = standard_unop(_float | _complex, 'log1p') -ad.defjvp(log1p_p, lambda g, x: div(g, add(x, _one(x)))) +ad.defjvp(log1p_p, lambda g, x, **kwargs: div(g, add(x, _one(x)))) mlir.register_lowering(log1p_p, partial(_nary_lower_hlo, hlo.log_plus_one)) +core.pp_eqn_rules[log1p_p] = _unary_with_accuracy_pp_rule tanh_p = standard_unop(_float | _complex, 'tanh') -ad.defjvp2(tanh_p, lambda g, ans, x: mul(add(g, mul(g, ans)), - sub(_one(x), ans))) +ad.defjvp2( + tanh_p, + lambda g, ans, x, **kwargs: mul(add(g, mul(g, ans)), sub(_one(x), ans)), +) mlir.register_lowering(tanh_p, partial(_nary_lower_hlo, hlo.tanh)) +core.pp_eqn_rules[tanh_p] = _unary_with_accuracy_pp_rule logistic_p = standard_unop(_float | _complex, 'logistic') -ad.defjvp2(logistic_p, lambda g, ans, x: mul(g, mul(ans, sub(_one(ans), ans)))) +ad.defjvp2( + logistic_p, + lambda g, ans, x, **kwargs: mul(g, mul(ans, sub(_one(ans), ans))), +) # TODO(phawkins): switch to LogisticOp lowering; debug numerical problems. # mlir.register_lowering(logistic_p, partial(_nary_lower_hlo, hlo.logistic)) -def logistic_impl(x): +def logistic_impl(x, accuracy): one = _const(x, 1) return div(one, add(one, exp(neg(x)))) mlir.register_lowering(logistic_p, mlir.lower_fun(logistic_impl, multiple_results=False)) +core.pp_eqn_rules[logistic_p] = _unary_with_accuracy_pp_rule def _sin_complex(x): # use expm1 instead of exp to avoid cancellation when abs(x) is small @@ -4056,21 +4214,28 @@ def _sin_complex(x): # avoid nan value when real(x) is zero and abs(x) is so large that abs(expm1(x)) is inf return select(a_is_zero, complex(_const(a, 0), im), complex(re, im)) -def _sin_lowering(ctx, x): +def _sin_lowering(ctx, x, accuracy): if dtypes.issubdtype(ctx.avals_in[0].dtype, np.complexfloating): sine = mlir.lower_fun(_sin_complex, multiple_results=False) return sine(ctx, x) - return _nary_lower_hlo(hlo.sine, ctx, x) + return _nary_lower_hlo(hlo.sine, ctx, x, accuracy=accuracy) + -def _sin_lin(nzs, x): +def _sin_p_lin(nzs, x, accuracy): nz, = nzs cos_x = cos(x) # TODO: allow this to happen in the linearized computation (need to fix backward_pass) - return (sin_p.bind(x), nz, cos_x, lambda cos_x_, t: mul(t, cos_x_)) + return ( + sin_p.bind(x, accuracy=accuracy), + nz, + cos_x, + lambda cos_x_, t: mul(t, cos_x_), + ) sin_p = standard_unop(_float | _complex, 'sin') -ad.defjvp(sin_p, lambda g, x: mul(g, cos(x))) -ad.primitive_linearizations[sin_p] = _sin_lin +ad.defjvp(sin_p, lambda g, x, accuracy: mul(g, cos(x, accuracy=accuracy))) +ad.primitive_linearizations[sin_p] = _sin_p_lin mlir.register_lowering(sin_p, _sin_lowering) +core.pp_eqn_rules[sin_p] = _unary_with_accuracy_pp_rule batching.ragged_prop_rules[sin_p] = batching.ragged_mask_elementwise_rule def _cos_complex(x): @@ -4085,19 +4250,23 @@ def _cos_complex(x): re, im = mul(cs, csh), mul(neg(sn), snh) return select(a_is_zero, complex(re, _const(a, 0)), complex(re, im)) -def _cos_lowering(ctx, x): +def _cos_lowering(ctx, x, accuracy): if dtypes.issubdtype(ctx.avals_in[0].dtype, np.complexfloating): cosine = mlir.lower_fun(_cos_complex, multiple_results=False) return cosine(ctx, x) - return _nary_lower_hlo(hlo.cosine, ctx, x) + return _nary_lower_hlo(hlo.cosine, ctx, x, 
accuracy=accuracy) cos_p = standard_unop(_float | _complex, 'cos') -ad.defjvp(cos_p, lambda g, x: neg(mul(g, sin(x)))) +ad.defjvp( + cos_p, lambda g, x, accuracy: neg(mul(g, sin(x, accuracy=accuracy))) +) mlir.register_lowering(cos_p, _cos_lowering) +core.pp_eqn_rules[cos_p] = _unary_with_accuracy_pp_rule tan_p = standard_unop(_float | _complex, 'tan') -ad.defjvp2(tan_p, lambda g, ans, x: mul(g, add(_const(x, 1), square(ans)))) +ad.defjvp2(tan_p, lambda g, ans, x, **kwargs: mul(g, add(_const(x, 1), square(ans)))) mlir.register_lowering(tan_p, partial(_nary_lower_hlo, hlo.tan)) +core.pp_eqn_rules[tan_p] = _unary_with_accuracy_pp_rule asin_p = standard_unop(_float | _complex, 'asin') ad.defjvp(asin_p, lambda g, x: mul(g, rsqrt(sub(_const(x, 1), square(x))))) @@ -4213,19 +4382,27 @@ def _abs_jvp_rule(g, ans, x): _maybe_real = lambda x: real(x) if _iscomplex(x) else x sqrt_p = standard_unop(_float | _complex, 'sqrt') -ad.defjvp2(sqrt_p, lambda g, ans, x: mul(g, div(_const(x, 0.5), ans))) +ad.defjvp2(sqrt_p, lambda g, ans, x, **kwargs: mul(g, div(_const(x, 0.5), ans))) mlir.register_lowering(sqrt_p, partial(_nary_lower_hlo, hlo.sqrt)) +core.pp_eqn_rules[sqrt_p] = _unary_with_accuracy_pp_rule rsqrt_p = standard_unop(_float | _complex, 'rsqrt') -ad.defjvp2(rsqrt_p, - lambda g, ans, x: - mul(g, mul(_const(x, -0.5), div(ans, x)))) +ad.defjvp2( + rsqrt_p, + lambda g, ans, x, **kwargs: mul(g, mul(_const(x, -0.5), div(ans, x))), +) mlir.register_lowering(rsqrt_p, partial(_nary_lower_hlo, hlo.rsqrt)) +core.pp_eqn_rules[rsqrt_p] = _unary_with_accuracy_pp_rule cbrt_p = standard_unop(_float, 'cbrt') -ad.defjvp2(cbrt_p, - lambda g, ans, x: mul(g, mul(_const(x, 1/3), integer_pow(ans, -2)))) +ad.defjvp2( + cbrt_p, + lambda g, ans, x, **kwargs: mul( + g, mul(_const(x, 1 / 3), integer_pow(ans, -2)) + ), +) mlir.register_lowering(cbrt_p, partial(_nary_lower_hlo, hlo.cbrt)) +core.pp_eqn_rules[cbrt_p] = _unary_with_accuracy_pp_rule square_p = standard_unop(_int | _float | _complex, 'square') @@ -4307,7 +4484,7 @@ def _integer_pow_jvp(g, x, *, y): integer_pow_p = standard_primitive( _attrgetter('shape'), _integer_pow_dtype_rule, 'integer_pow', - sharding_rule=_attrgetter('sharding')) + sharding_rule=_attrgetter('sharding'), vma_rule=_attrgetter('vma')) batching.defvectorized(integer_pow_p) ad.defjvp(integer_pow_p, _integer_pow_jvp) pe.def_trivial_padding(integer_pow_p) @@ -4400,8 +4577,30 @@ def _add_transpose(t, x, y): else: return [_unbroadcast(x_aval, t), _unbroadcast(y_aval, t)] -# TODO(slebedev): Why does mypy fail to infer the type here? -add_p: Primitive = standard_naryop([_num, _num], 'add') +def _add_unreduced(out_sharding, x, y): + x_ur, y_ur = x.sharding.spec.unreduced, y.sharding.spec.unreduced + if x_ur and y_ur: + if x_ur != y_ur: + raise core.ShardingTypeError( + 'lhs and rhs to `add` must be unreduced along the same mesh axes. ' + f'Got lhs={x_ur}, rhs={y_ur}') + res_unreduced = x_ur + elif x_ur or y_ur: + if x_ur and not y_ur: + lhs_str, rhs_str = 'lhs', 'rhs' + else: + assert not x_ur and y_ur + lhs_str, rhs_str = 'rhs', 'lhs' + raise core.ShardingTypeError( + f'{lhs_str} is unreduced while {rhs_str} is not. `add` operation does' + ' not allow this because there will be implicit communication. 
Please' + f' reduce {lhs_str} via `reshard` before calling `add`.') + else: + res_unreduced = frozenset() + return out_sharding.update(spec=out_sharding.spec.update(unreduced=res_unreduced)) + +add_p: Primitive = naryop(_input_dtype, [_num, _num], 'add', + unreduced_rule=_add_unreduced) ad.primitive_jvps[add_p] = _add_jvp ad.primitive_transposes[add_p] = _add_transpose mlir.register_lowering(add_p, partial(_nary_lower_hlo, hlo.add)) @@ -4678,11 +4877,11 @@ def _convert_elt_type_folding_rule(consts, eqn): def _convert_elt_type_fwd_rule(eqn): v, = eqn.invars - if (not dtypes.issubdtype(eqn.params['new_dtype'], dtypes.extended) and + if (v.aval.dtype == eqn.params['new_dtype'] and + v.aval.weak_type == eqn.params['weak_type'] and not dtypes.issubdtype(v.aval.dtype, dtypes.extended) and - v.aval.dtype == eqn.params['new_dtype'] and - v.aval.weak_type == eqn.params['weak_type']): - return [v], None + (eqn.params['sharding'] is None or eqn.params['sharding'] == v.aval.sharding)): + return [0], None else: return [None], eqn @@ -4710,10 +4909,21 @@ def _convert_element_type_bind_with_trace(trace, args, params): partial(standard_abstract_eval, convert_element_type_p, _convert_element_type_shape_rule, _convert_element_type_dtype_rule, _convert_element_type_weak_type_rule, - _convert_element_type_sharding_rule)) + _convert_element_type_sharding_rule, + partial(core.standard_vma_rule, convert_element_type_p.name), + None)) ad.defjvp2(convert_element_type_p, _convert_element_type_jvp_rule) ad.primitive_transposes[convert_element_type_p] = _convert_element_type_transpose_rule -batching.defvectorized(convert_element_type_p) + +def _convert_element_type_batching_rule( + axis_data, batched_args, batch_dims, *, new_dtype, weak_type, sharding): + if sharding is not None: + sharding = batching.get_sharding_for_vmap(axis_data, sharding, 0) + new_params = dict(new_dtype=new_dtype, weak_type=weak_type, sharding=sharding) + return convert_element_type_p.bind(*batched_args, **new_params), batch_dims[0] +batching.fancy_primitive_batchers[convert_element_type_p] = _convert_element_type_batching_rule +batching.skippable_batchers[convert_element_type_p] = lambda _: () + pe.const_fold_rules[convert_element_type_p] = _convert_elt_type_folding_rule pe.forwarding_rules[convert_element_type_p] = _convert_elt_type_fwd_rule pe.def_trivial_padding(convert_element_type_p) @@ -4743,6 +4953,9 @@ def _to_edtype_abstract_eval(x, *, edtype): not isinstance(x.dtype, dtypes.ExtendedDType)) # For backward compatibility, if the edtype rules have a `convert_to` method, # use that rather than looking for an `allow_conversion: bool` attribute. 
+ if not isinstance(x, (ShapedArray, core.DShapedArray)): + raise TypeError("can only convert to an extended dtype on an array type," + f"but got {type(x)}") if convert_to := getattr(edtype._rules, 'convert_to', None): allow_conversion = convert_to(x.dtype, edtype) else: @@ -4752,6 +4965,7 @@ def _to_edtype_abstract_eval(x, *, edtype): f"Cannot convert_element_type from {dtype_to_string(x.dtype)} " f"to {dtype_to_string(edtype)}") rep_aval = core.physical_element_aval(edtype) + assert tuple(rep_aval.sharding.spec) == (None,) * rep_aval.ndim if x.dtype != rep_aval.dtype: raise ValueError( "can only convert to extended dtype from its representation dtype, " @@ -4774,7 +4988,20 @@ def _to_edtype_abstract_eval(x, *, edtype): f" has a representation shape {rep_aval.shape} while the given " f"representation array has shape {x.shape}, so the shape suffix " f"does not match: given {shape_suffix} but required {rep_aval.shape}.") - return x.update(shape=shape_prefix, dtype=edtype) + if isinstance(x, ShapedArray): + spec_prefix, spec_suffix = x.sharding.spec[:n], x.sharding.spec[n:] + if tuple(spec_suffix) != (None,) * len(spec_suffix): + raise ValueError( + "can only convert to extended dtype from an array with trailing " + "axes that are not explicitly sharded, but tried to convert from " + f"{x.str_short(short_dtypes=True)} to an extended dtype with element " + f"shape {rep_aval.shape}") + return x.update(shape=shape_prefix, dtype=edtype, + sharding=x.sharding.update(spec=spec_prefix)) + elif isinstance(x, core.DShapedArray): + return x.update(shape=shape_prefix, dtype=edtype) + else: + assert False # unreachable, see isinstance check above to_edtype_p = Primitive('to_edtype') to_edtype_p.def_impl(partial(dispatch.apply_primitive, to_edtype_p)) @@ -4791,6 +5018,9 @@ def _to_edtype_abstract_eval(x, *, edtype): def _from_edtype_abstract_eval(x, *, dtype): assert (isinstance(x.dtype, dtypes.ExtendedDType) and not isinstance(dtype, dtypes.ExtendedDType)) + if not isinstance(x, (ShapedArray, core.DShapedArray)): + raise TypeError("can only convert from an extended dtype on an array type," + f"but got {type(x)}") if convert_from := getattr(x.dtype._rules, 'convert_from', None): allow_conversion = convert_from(x.dtype, dtype) else: @@ -4800,16 +5030,22 @@ def _from_edtype_abstract_eval(x, *, dtype): f"Cannot convert_element_type from {dtype_to_string(x.dtype)} " f"to {dtype_to_string(dtype)}") rep_aval = core.physical_element_aval(x.dtype) + assert tuple(rep_aval.sharding.spec) == (None,) * rep_aval.ndim if rep_aval.dtype != dtype: raise ValueError( "can only convert from extended dtype to its representation dtype, " f"but tried to convert from {dtype_to_string(x.dtype)} to " f"{dtype_to_string(dtype)} which doesn't match the representation type " f"{dtype_to_string(rep_aval.dtype)}.") - if all(isinstance(d, int) for d in x.shape): - return core.ShapedArray(shape=(*x.shape, *rep_aval.shape), dtype=dtype) + if isinstance(x, ShapedArray): + return x.update(shape=(*x.shape, *rep_aval.shape), dtype=dtype) + elif isinstance(x, core.DShapedArray): + if all(isinstance(d, int) for d in x.shape): + return core.ShapedArray(shape=(*x.shape, *rep_aval.shape), dtype=dtype) + else: + raise NotImplementedError else: - raise NotImplementedError + assert False # unreachable, see isinstance check above from_edtype_p = Primitive('from_edtype') from_edtype_p.def_impl(partial(dispatch.apply_primitive, from_edtype_p)) @@ -4854,9 +5090,9 @@ def _bitcast_convert_type_sharding_rule(operand, *, new_dtype): if old_nbits == 
new_nbits: return operand.sharding elif old_nbits > new_nbits: - return operand.sharding.with_spec((*operand.sharding.spec, None)) + return operand.sharding.update(spec=(*operand.sharding.spec, None)) else: - return operand.sharding.with_spec(operand.sharding.spec[:-1]) + return operand.sharding.update(spec=operand.sharding.spec[:-1]) def _bitcast_convert_type_dtype_rule(operand, *, new_dtype): old_dtype = dtypes.canonicalize_dtype(operand.dtype) @@ -4875,7 +5111,8 @@ def _bitcast_convert_type_dtype_rule(operand, *, new_dtype): bitcast_convert_type_p = standard_primitive( _bitcast_convert_type_shape_rule, _bitcast_convert_type_dtype_rule, 'bitcast_convert_type', weak_type_rule=_strip_weak_type, - sharding_rule=_bitcast_convert_type_sharding_rule) + sharding_rule=_bitcast_convert_type_sharding_rule, + vma_rule=partial(core.standard_vma_rule, 'bitcast_convert_type')) ad.defjvp_zero(bitcast_convert_type_p) batching.defvectorized(bitcast_convert_type_p) @@ -4985,7 +5222,7 @@ def _dot_general_shape_computation(lhs_shape, rhs_shape, dimension_numbers): lhs_tensored_shape = tuple_delete(lhs_shape, lhs_contract_or_batch) rhs_group = () if isinstance(dimension_numbers, RaggedDotDimensionNumbers): - rhs_group = tuple(dimension_numbers.rhs_group_dimensions) + rhs_group = tuple(dimension_numbers.rhs_group_dimensions) # pytype: disable=attribute-error rhs_contract_or_batch_or_group = tuple( sorted(tuple(rhs_contracting) + tuple(rhs_batch) + rhs_group) ) @@ -4996,21 +5233,36 @@ def _dot_general_shape_computation(lhs_shape, rhs_shape, dimension_numbers): def _check_specs_match(lhs_spec, rhs_spec, msg): for l, r in zip(lhs_spec, rhs_spec): if l is not None and r is not None and l != r: - raise TypeError(msg) + raise core.ShardingTypeError(msg) def _dot_general_sharding_rule(lhs, rhs, *, dimension_numbers, precision, preferred_element_type: DTypeLike | None, out_sharding): if lhs.sharding.mesh != rhs.sharding.mesh: - raise ValueError( + raise core.ShardingTypeError( 'Mesh of both lhs and rhs should match. Got lhs:' f' {lhs.sharding.mesh} and rhs: {rhs.sharding.mesh}') + (lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers + lhs_contracting_spec = tuple(lhs.sharding.spec[i] for i in lhs_contracting) + rhs_contracting_spec = tuple(rhs.sharding.spec[i] for i in rhs_contracting) + if out_sharding is not None: assert isinstance(out_sharding, NamedSharding) + if out_sharding.spec.unreduced: + if lhs_contracting_spec != rhs_contracting_spec: + raise core.ShardingTypeError( + 'lhs and rhs contracting dims should be sharded identically when' + ' out_sharding provided to dot_general mentions unreduced_axes.' + f' Got {out_sharding=}, {lhs_contracting_spec=},' + f' {rhs_contracting_spec=}') + if out_sharding.spec.unreduced != frozenset(lhs_contracting_spec): + raise core.ShardingTypeError( + "out_sharding's unreduced axes should be equal to the contracting" + f' specs. 
Got unreduced axes={out_sharding.spec.unreduced} and' + f' contracting spec={lhs_contracting_spec}') return out_sharding - (lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers lhs_batch_spec = tuple(lhs.sharding.spec[i] for i in lhs_batch) rhs_batch_spec = tuple(rhs.sharding.spec[i] for i in rhs_batch) msg = ("dot_general requires lhs batch dimensions and rhs batch dimensions " @@ -5018,15 +5270,13 @@ def _dot_general_sharding_rule(lhs, rhs, *, dimension_numbers, precision, f"{rhs_batch_spec}.") _check_specs_match(lhs_batch_spec, rhs_batch_spec, msg) - lhs_contracting_spec = tuple(lhs.sharding.spec[i] for i in lhs_contracting) - rhs_contracting_spec = tuple(rhs.sharding.spec[i] for i in rhs_contracting) msg = ("dot_general requires contracting dimensions to have consistent " f"sharding, got {lhs_contracting_spec} and {rhs_contracting_spec}.") _check_specs_match(lhs_contracting_spec, rhs_contracting_spec, msg) for l, r in zip(lhs_contracting_spec, rhs_contracting_spec): if l is not None and r is not None: - raise ValueError( + raise core.ShardingTypeError( 'Contracting dimensions are sharded and it is ambiguous how the' ' output should be sharded. Please specify the output sharding via' ' the `out_sharding` parameter of einsum. Or reshard your input via' @@ -5116,7 +5366,7 @@ def _dot_general_transpose_lhs(g, x, y, *, dimension_numbers, precision, out_axes = np.argsort(unsorted_axes) xs = x.aval.sharding inverse_spec = tuple(xs.spec[o] for o in unsorted_axes) - ds = xs.with_spec(inverse_spec) + ds = xs.update(spec=inverse_spec) dot_general_out = dot_general(g, y, dims, precision=precision, preferred_element_type=preferred_element_type, out_sharding=ds) @@ -5344,6 +5594,7 @@ def _dot_general_ragged_prop_rule(eqn_params, invar_raggedness, outvars): _dot_general_dtype_rule, 'dot_general', sharding_rule=_dot_general_sharding_rule, + vma_rule=partial(core.standard_vma_rule, 'dot_general') ) @@ -5371,15 +5622,26 @@ def _dot_general_batch_unpack_dims(batch_dims): core.pp_eqn_rules[dot_general_p] = _dot_general_pp_rule batching.ragged_prop_rules[dot_general_p] = _dot_general_ragged_prop_rule -def precision_attr(precision: Precision) -> ir.ArrayAttr: + +def _full_precision(precision: Precision) -> tuple[Precision, Precision]: if precision is None or isinstance(precision, (DotAlgorithm, DotAlgorithmPreset)): - full_precision = (Precision.DEFAULT, Precision.DEFAULT) + return (Precision.DEFAULT, Precision.DEFAULT) elif not isinstance(precision, tuple): - full_precision = (precision, precision) + return (precision, precision) else: - full_precision = precision + return precision + + +def precision_attr(precision: Precision) -> ir.ArrayAttr: return ir.ArrayAttr.get( - [hlo.PrecisionAttr.get(str(p)) for p in full_precision]) + [hlo.PrecisionAttr.get(str(p)) for p in _full_precision(precision)] + ) + + +def chlo_precision_attr(precision: Precision) -> ir.ArrayAttr: + return ir.ArrayAttr.get( + [chlo.PrecisionAttr.get(str(p)) for p in _full_precision(precision)] + ) def dot_algorithm_attr(precision: CanonicalPrecision, lhs_dtype: DTypeLike, @@ -5417,32 +5679,30 @@ def maybe_convert_dtype(input_dtype, target_dtypes): return lhs_dtype, rhs_dtype, out_type -def _dot_general_lower(ctx, lhs, rhs, *, dimension_numbers, - precision, preferred_element_type: np.dtype | None, - out_sharding, platform: str = "default"): +def accuracy_attr(accuracy) -> hlo.ResultAccuracyAttr: + if isinstance(accuracy, AccuracyMode): + return hlo.ResultAccuracyAttr.get(0.0, 0.0, int(0), str(accuracy.name)) + 
elif isinstance(accuracy, Tolerance): + return hlo.ResultAccuracyAttr.get( + atol=accuracy.atol, + rtol=accuracy.rtol, + ulps=accuracy.ulps, + mode='TOLERANCE', + ) + +def _handle_dot_precision(ctx, lhs, rhs, precision, platform): def _is_fp8_mixed_precision_matmul(_lhs_dtypes, _rhs_dtypes): fp8_dtypes = (dtypes.float8_e4m3fn, dtypes.float8_e5m2, - dtypes.float8_e5m2fnuz, dtypes.float8_e4m3fnuz) - if dtypes.float8_e3m4 is not None: - fp8_dtypes += (dtypes.float8_e3m4,) - if dtypes.float8_e4m3 is not None: - fp8_dtypes += (dtypes.float8_e4m3,) - if dtypes.float8_e8m0fnu is not None: - fp8_dtypes += (dtypes.float8_e8m0fnu,) + dtypes.float8_e5m2fnuz, dtypes.float8_e4m3fnuz, + dtypes.float8_e3m4, dtypes.float8_e4m3, + dtypes.float8_e8m0fnu) return _lhs_dtypes in fp8_dtypes and _rhs_dtypes in fp8_dtypes - del preferred_element_type # Implied by the output aval - lhs_aval, rhs_aval = ctx.avals_in + + # The *_ lets us reuse this for ragged_dot_general, which has group_sizes. + lhs_aval, rhs_aval, *_ = ctx.avals_in lhs_dtype, rhs_dtype = lhs_aval.dtype, rhs_aval.dtype aval_out, = ctx.avals_out accumulation_aval = aval_out - (lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers - - dot_dnums = hlo.DotDimensionNumbers.get( - lhs_batching_dimensions=list(lhs_batch), - rhs_batching_dimensions=list(rhs_batch), - lhs_contracting_dimensions=list(lhs_contracting), - rhs_contracting_dimensions=list(rhs_contracting)) - algorithm_kwarg = {} if isinstance(precision, (DotAlgorithm, DotAlgorithmPreset)): # The CPU backend silently ignores the algorithm spec, so we check here to @@ -5500,7 +5760,22 @@ def maybe_convert_dtype(operand, operand_aval, target_dtype): core.ShapedArray(lhs_aval.shape, aval_out.dtype)) rhs = mlir.convert_hlo(ctx, rhs, rhs_aval, core.ShapedArray(rhs_aval.shape, aval_out.dtype)) + return lhs, rhs, accumulation_aval, algorithm_kwarg + +def _dot_general_lower(ctx, lhs, rhs, *, dimension_numbers, + precision, preferred_element_type: np.dtype | None, + out_sharding, platform: str = "default"): + del preferred_element_type # Implied by the output aval + lhs, rhs, accumulation_aval, algorithm_kwarg = _handle_dot_precision( + ctx, lhs, rhs, precision, platform + ) + (lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers + dot_dnums = hlo.DotDimensionNumbers.get( + lhs_batching_dimensions=list(lhs_batch), + rhs_batching_dimensions=list(rhs_batch), + lhs_contracting_dimensions=list(lhs_contracting), + rhs_contracting_dimensions=list(rhs_contracting)) result = hlo.dot_general( mlir.aval_to_ir_type(accumulation_aval), lhs, @@ -5509,7 +5784,7 @@ def maybe_convert_dtype(operand, operand_aval, target_dtype): precision_config=precision_attr(precision), **algorithm_kwarg, ) - + aval_out, = ctx.avals_out result = mlir.lower_with_sharding_in_types(ctx, result, aval_out) if accumulation_aval.dtype != aval_out.dtype: result = mlir.convert_hlo(ctx, result, accumulation_aval, aval_out) @@ -5805,7 +6080,7 @@ def grad_x_dims(): unsorted_axes = list(x_batch) + x_kept + x_contract_sorted_by_y case RaggedDotMode.RAGGED_CONTRACTING | RaggedDotMode.RAGGED_BATCH: raise unimplemented('grad_x_dims', mode) - return dims, unsorted_axes + return dims, unsorted_axes # pytype: disable=name-error def grad_y_dims(): match mode: @@ -5824,7 +6099,7 @@ def grad_y_dims(): ) case RaggedDotMode.RAGGED_CONTRACTING | RaggedDotMode.RAGGED_BATCH: raise unimplemented('grad_y_dims', mode) - return dims, unsorted_axes + return dims, unsorted_axes # pytype: disable=name-error def 
_ragged_dot_grad(lhs, rhs, dims_fn, aval): dims, unsorted_axes = dims_fn() @@ -5918,6 +6193,7 @@ def _ragged_dot_general_batch_rule( _ragged_dot_general_shape_rule, _ragged_dot_general_dtype_rule, 'ragged_dot_general', + vma_rule=partial(core.standard_vma_rule, 'ragged_dot') ) ad.primitive_jvps[ragged_dot_general_p] = _ragged_dot_general_jvp_rule ad.primitive_transposes[ragged_dot_general_p] = _ragged_dot_general_transpose_rule @@ -6025,13 +6301,88 @@ def expand(x, dim, gs, *axes): lhs, rhs, dimension_numbers=ragged_dot_dimension_numbers.dot_dimension_numbers, - ) + ) # pytype: disable=bad-return-type + + +def _ragged_dot_general_lower( + ctx, + lhs, + rhs, + group_sizes, + *, + ragged_dot_dimension_numbers, + precision, + preferred_element_type: np.dtype | None, + group_offset: Array | None = None, + platform: str = 'default', +): + if group_offset is not None: + raise NotImplementedError('Unimplemented group_offset support.') + + # TODO(pravnar): Remove this once we have sharding support. + def use_default_lowering(): + axis_context = ctx.module_context.axis_context + return ( + isinstance(axis_context, SPMDAxisContext) + or isinstance(axis_context, ShardingContext) + and axis_context.num_devices > 1 + ) + if use_default_lowering(): + result = mlir.lower_fun(_ragged_dot_general_impl, multiple_results=False)( + ctx, lhs, rhs, group_sizes, + ragged_dot_dimension_numbers=ragged_dot_dimension_numbers, + precision=precision, + preferred_element_type=preferred_element_type, + group_offset=group_offset + ) + (aval_out,) = ctx.avals_out + return mlir.lower_with_sharding_in_types(ctx, result, aval_out) + + del preferred_element_type # Implied by the output aval + lhs, rhs, accumulation_aval, _ = _handle_dot_precision( + ctx, lhs, rhs, precision, platform + ) + (lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = ( + ragged_dot_dimension_numbers.dot_dimension_numbers + ) + ragged_dot_dnums = chlo.RaggedDotDimensionNumbers.get( + lhs_batching_dimensions=list(lhs_batch), + rhs_batching_dimensions=list(rhs_batch), + lhs_contracting_dimensions=list(lhs_contracting), + rhs_contracting_dimensions=list(rhs_contracting), + lhs_ragged_dimensions=list( + ragged_dot_dimension_numbers.lhs_ragged_dimensions + ), + rhs_group_dimensions=list( + ragged_dot_dimension_numbers.rhs_group_dimensions + ), + ) + result = chlo.ragged_dot( + mlir.aval_to_ir_type(accumulation_aval), + lhs, + rhs, + group_sizes, + ragged_dot_dnums, + precision_config=chlo_precision_attr(precision), + ) + (aval_out,) = ctx.avals_out + result = mlir.lower_with_sharding_in_types(ctx, result, aval_out) + if accumulation_aval.dtype != aval_out.dtype: + result = mlir.convert_hlo(ctx, result, accumulation_aval, aval_out) + return [result] mlir.register_lowering(ragged_dot_general_p, mlir.lower_fun(_ragged_dot_general_impl, multiple_results=False)) +for platform in ['tpu']: + mlir.register_lowering( + ragged_dot_general_p, + partial(_ragged_dot_general_lower, platform=platform), + platform=platform, + ) + def _broadcast_in_dim_shape_rule(operand, *, shape, broadcast_dimensions, sharding): @@ -6077,7 +6428,7 @@ def _broadcast_in_dim_sharding_rule(operand, *, shape, broadcast_dimensions, orig_spec = iter(operand.sharding.spec) new_spec = [next(orig_spec) if i in bds else None for i in range(len(shape))] assert next(orig_spec, None) is None - return operand.sharding.with_spec(new_spec) + return operand.sharding.update(spec=new_spec) def _broadcast_in_dim_typecheck_rule( _, operand, *dyn_shape, shape, broadcast_dimensions, sharding): @@ 
-6163,20 +6514,23 @@ def _broadcast_in_dim_batch_rule(axis_data, batched_args, batch_dims, shape, def _broadcast_in_dim_fwd_rule(eqn): v, *dyn = eqn.invars - if not dyn and core.definitely_equal_shape(eqn.params['shape'], v.aval.shape): - return [v], None + if (not dyn and core.definitely_equal_shape(eqn.params['shape'], v.aval.shape) + and (eqn.params['sharding'] is None or + eqn.params['sharding'] == v.aval.sharding)): + return [0], None else: return [None], eqn def _broadcast_in_dim_staging_rule( - trace, x, *dyn, shape, broadcast_dimensions, sharding): + trace, source_info, x, *dyn, shape, broadcast_dimensions, sharding): params = dict(shape=shape, broadcast_dimensions=broadcast_dimensions, sharding=sharding) if not dyn: - return trace.default_process_primitive(broadcast_in_dim_p, (x,), params) + return trace.default_process_primitive(broadcast_in_dim_p, (x,), params, + source_info=source_info) aval = core.DShapedArray(_merge_dyn_shape(shape, dyn), x.dtype, x.weak_type) - return _dyn_shape_staging_rule(trace, broadcast_in_dim_p, aval, x, *dyn, - **params) + return _dyn_shape_staging_rule(trace, source_info, broadcast_in_dim_p, aval, + x, *dyn, **params) def _broadcast_in_dim_padding_rule(in_avals, out_avals, x, *dyn_shape, shape, broadcast_dimensions): @@ -6226,7 +6580,7 @@ def _broadcast_in_dim_partial_eval( out_aval = core.DShapedArray(tuple(shape_), operand.dtype, operand.weak_type) out_tracer = pe.JaxprTracer(trace, pe.PartialVal.unknown(out_aval), None) eqn = pe.new_eqn_recipe( - [operand_tracer, *dyn_shape_tracers], [out_tracer], broadcast_in_dim_p, + trace, [operand_tracer, *dyn_shape_tracers], [out_tracer], broadcast_in_dim_p, dict(shape=shape, broadcast_dimensions=broadcast_dimensions, sharding=None), core.no_effects, source_info_util.current()) @@ -6252,7 +6606,9 @@ def _broadcast_in_dim_abstract_eval(x, *dyn_shape, shape, broadcast_dimensions, new_sharding = _broadcast_in_dim_sharding_rule( x, shape=shape, broadcast_dimensions=broadcast_dimensions, sharding=sharding) - return core.ShapedArray(shape, x.dtype, x.weak_type, sharding=new_sharding) + new_vma = core.standard_vma_rule('broadcast_in_dim', x) + return core.ShapedArray(shape, x.dtype, x.weak_type, sharding=new_sharding, + vma=new_vma) # If any BInts in shape, or Tracers in dyn_shape, produce a DShapedArray # (even if x is a ShapedArray) # TODO(mattjj): unify DShapedArray with ShapedArray, and remove this code @@ -6336,7 +6692,8 @@ def _clamp_batch_rule(batched_args, batch_dims, **params): return clamp_p.bind(min, x, max), 0 clamp_p = standard_primitive(_clamp_shape_rule, _clamp_dtype_rule, 'clamp', - sharding_rule=_clamp_sharding_rule) + sharding_rule=_clamp_sharding_rule, + vma_rule=partial(core.standard_vma_rule, 'clamp')) ad.defjvp(clamp_p, lambda g, min, operand, max: select(bitwise_and(gt(min, operand), lt(min, max)), @@ -6384,7 +6741,7 @@ def _concatenate_sharding_rule(*operands, **kwargs): return core.get_cur_mesh_sharding() if not all(s == non_empty_s[0] for s in non_empty_s): ss = ", ".join(str(o.sharding) for o in operands) - raise TypeError( + raise core.ShardingTypeError( f"All operands should have the same sharding. 
Got shardings {ss}") return non_empty_s[0] @@ -6409,7 +6766,7 @@ def _concatenate_batch_rule(batched_args, batch_dims, *, dimension): for op, bdim in zip(batched_args, batch_dims) if bdim is not None) operands = [batching.moveaxis(op, bdim, 0) if bdim is not None else broadcast( - op, (size,), out_sharding=core.get_aval(op).sharding.with_spec( + op, (size,), out_sharding=core.get_aval(op).sharding.update(spec= (spec, *core.get_aval(op).sharding.spec))) for op, bdim in zip(batched_args, batch_dims)] return concatenate(operands, dimension + 1), 0 @@ -6423,7 +6780,8 @@ def _concatenate_pad_rule(in_avals, out_avals, *operands, dimension): concatenate_p = standard_primitive( _concatenate_shape_rule, _concatenate_dtype_rule, 'concatenate', - sharding_rule=_concatenate_sharding_rule) + sharding_rule=_concatenate_sharding_rule, + vma_rule=partial(core.standard_vma_rule, 'concatenate')) ad.deflinear2(concatenate_p, _concatenate_transpose_rule) ad.primitive_transposes[concatenate_p] = _concatenate_transpose_rule batching.primitive_batchers[concatenate_p] = _concatenate_batch_rule @@ -6495,11 +6853,17 @@ def _split_sharding_rule(operand, *, sizes, axis): return [slicing._get_sharding_for_varying_out_shape(out_sh, operand, 'split') for out_sh in out_shapes] +def _split_vma_rule(operand, *, sizes, axis): + out_vma = core.standard_vma_rule('split', operand) + out_shapes = _split_shape_rule(operand, sizes=sizes, axis=axis) + return [out_vma] * len(out_shapes) + split_p = core.Primitive('split') split_p.multiple_results = True split_p.def_abstract_eval( partial(standard_multi_result_abstract_eval, split_p, _split_shape_rule, - _split_dtype_rule, _split_weak_type_rule, _split_sharding_rule)) + _split_dtype_rule, _split_weak_type_rule, _split_sharding_rule, + _split_vma_rule)) split_p.def_impl(partial(dispatch.apply_primitive, split_p)) ad.deflinear2(split_p, _split_transpose_rule) batching.primitive_batchers[split_p] = _split_batch_rule @@ -6581,7 +6945,8 @@ def _pad_batch_rule(batched_args, batch_dims, *, padding_config): return select(mask, x, broadcasted_padding), operand_bdim pad_p = standard_primitive(_pad_shape_rule, _pad_dtype_rule, 'pad', - sharding_rule=_pad_sharding_rule) + sharding_rule=_pad_sharding_rule, + vma_rule=partial(core.standard_vma_rule, 'pad')) ad.deflinear2(pad_p, _pad_transpose) batching.primitive_batchers[pad_p] = _pad_batch_rule @@ -6615,7 +6980,7 @@ def _squeeze_sharding_rule(operand, *, dimensions): dims_set = set(dimensions) new_spec = tuple(s for i, s in enumerate(operand.sharding.spec) if i not in dims_set) - return operand.sharding.with_spec(new_spec) + return operand.sharding.update(spec=new_spec) def _compute_squeeze_shape(shape, dimensions): dims_set = set(dimensions) @@ -6645,7 +7010,8 @@ def _squeeze_batch_rule(batched_args, batch_dims, *, dimensions): return squeeze(operand, dimensions=dimensions), bdim_out squeeze_p = standard_primitive(_squeeze_shape_rule, _squeeze_dtype_rule, - 'squeeze', sharding_rule=_squeeze_sharding_rule) + 'squeeze', sharding_rule=_squeeze_sharding_rule, + vma_rule=partial(core.standard_vma_rule, 'squeeze')) ad.deflinear2(squeeze_p, _squeeze_transpose_rule) batching.primitive_batchers[squeeze_p] = _squeeze_batch_rule pe.def_trivial_padding(squeeze_p) @@ -6703,16 +7069,21 @@ def _split_on_one_axis(op_shape, new_sizes, name): else: count += 1 if count > 1: - raise ValueError( + raise core.ShardingTypeError( f'{name} on more than 1 axis is not supported. 
Please specify' ' the sharding of the output via the `sharding` argument of' f' jax.lax.reshape. Got operand.shape={op_shape} and {new_sizes=}') temp = [new_sizes[j]] - while math.prod(temp) != op_shape[i]: + next_j = j + 1 + while (math.prod(temp) != op_shape[i] or + (next_j < len(new_sizes) and new_sizes[next_j] == 1)): if math.prod(temp) > op_shape[i]: return False, [] j += 1 + if j >= len(new_sizes): + return False, [] temp.append(new_sizes[j]) + next_j += 1 out.append(temp) i += 1 j += 1 @@ -6729,6 +7100,8 @@ def _merge_on_one_axis(operand, new_sizes): def _reshape_sharding_rule(operand, *, new_sizes, dimensions, sharding): if sharding is not None: return sharding + if operand.sharding.is_fully_replicated: + return operand.sharding non_1s_op_shape = [s for s in operand.shape if s != 1] non_1s_new_shape = [s for s in new_sizes if s != 1] if non_1s_op_shape == non_1s_new_shape: @@ -6744,11 +7117,10 @@ def _reshape_sharding_rule(operand, *, new_sizes, dimensions, sharding): return _merge_an_axis_sharding_rule(operand, operand_merge, new_sizes, dimensions) - raise ValueError( + raise core.ShardingTypeError( 'This reshape is not supported. Please specify the sharding of' ' the output via the `out_sharding` argument of jax.lax.reshape. Got' - f' operand shape: {operand.shape}, new sizes: {new_sizes} and' - f' operand spec: {operand.sharding.spec}') + f' operand shape: {operand}, new sizes: {new_sizes}') def _split_merge_singleton_dim_sharding_rule(operand, new_sizes): filtered_spec = [sp for sh, sp in zip(operand.shape, operand.sharding.spec) @@ -6761,7 +7133,7 @@ def _split_merge_singleton_dim_sharding_rule(operand, new_sizes): else: sp = next(fs) new_spec.append(sp) - return operand.sharding.with_spec(new_spec) + return operand.sharding.update(spec=new_spec) def _get_spec_size(sp, mesh): tup_sp = sp if isinstance(sp, tuple) else (sp,) @@ -6777,15 +7149,14 @@ def _split_an_axis_sharding_rule(operand, out_split, new_sizes, dimensions): elif dimensions is None and out[0] % _get_spec_size(sp, mesh) == 0: new_spec.extend([sp] + [None] * (len(out) - 1)) else: - raise ValueError( + raise core.ShardingTypeError( 'This reshape is not supported. Please specify the sharding of the' ' output via the `sharding` argument of jax.lax.reshape. Got' - f' operand shape: {operand.shape}, new sizes: {new_sizes} and' - f' operand spec: {operand.sharding.spec}') + f' operand shape: {operand}, new sizes: {new_sizes}') else: new_spec.append(sp) assert len(new_spec) == len(new_sizes), (new_spec, new_sizes) - return operand.sharding.with_spec(new_spec) + return operand.sharding.update(spec=new_spec) def _merge_an_axis_sharding_rule(operand, operand_merge, new_sizes, dimensions): @@ -6802,16 +7173,15 @@ def _merge_an_axis_sharding_rule(operand, operand_merge, new_sizes, dimensions): assert new_size % _get_spec_size(sp[0], mesh) == 0 new_spec.append(sp[0]) else: - raise ValueError( + raise core.ShardingTypeError( 'This reshape is not supported. Please specify the sharding of the' ' output via the `sharding` argument of jax.lax.reshape. 
Got' - f' operand shape: {operand.shape}, new sizes: {new_sizes} and' - f' operand spec: {operand.sharding.spec}') + f' operand shape: {operand}, new sizes: {new_sizes}') else: new_spec.append(next(op_spec)) assert next(op_spec, None) is None assert len(new_spec) == len(new_sizes), (new_spec, new_sizes) - return operand.sharding.with_spec(new_spec) + return operand.sharding.update(spec=new_spec) def _reshape_typecheck_rule(_, operand, *dyn_shape, new_sizes, dimensions, @@ -6838,7 +7208,7 @@ def _reshape_transpose_rule(t, operand, *, new_sizes, dimensions, sharding): if dimensions is None: return [reshape(t, operand.aval.shape, out_sharding=operand.aval.sharding)] else: - t_s = operand.aval.sharding.with_spec( + t_s = operand.aval.sharding.update(spec= tuple(map(lambda s: s if s is None else str(s), np.take(operand.aval.sharding.spec, dimensions)))) return [transpose(reshape(t, np.take(operand.aval.shape, dimensions), @@ -6871,15 +7241,18 @@ def _reshape_lower(ctx, x, *dyn_shape, new_sizes, dimensions, sharding): return [mlir.lower_with_sharding_in_types(ctx, out, aval_out)] def _reshape_staging_rule( - trace, x, *dyn, new_sizes, dimensions, sharding): + trace, source_info, x, *dyn, new_sizes, dimensions, sharding): params = dict(new_sizes=new_sizes, dimensions=dimensions, sharding=sharding) if not dyn: - return trace.default_process_primitive(reshape_p, (x,), params) + return trace.default_process_primitive(reshape_p, (x,), params, + source_info=source_info) av = core.DShapedArray(_merge_dyn_shape(new_sizes, dyn), x.dtype, x.weak_type) - return _dyn_shape_staging_rule(trace, reshape_p, av, x, *dyn, **params) + return _dyn_shape_staging_rule(trace, source_info, reshape_p, av, x, *dyn, + **params) reshape_p = standard_primitive(_reshape_shape_rule, _reshape_dtype_rule, - 'reshape', sharding_rule=_reshape_sharding_rule) + 'reshape', sharding_rule=_reshape_sharding_rule, + vma_rule=partial(core.standard_vma_rule, 'reshape')) ad.deflinear2(reshape_p, _reshape_transpose_rule) batching.fancy_primitive_batchers[reshape_p] = _reshape_batch_rule batching.skippable_batchers[reshape_p] = lambda _: () @@ -6911,7 +7284,8 @@ def _rev_batch_rule(batched_args, batch_dims, *, dimensions): return rev(operand, new_dimensions), bdim rev_p = standard_primitive(_rev_shape_rule, _input_dtype, 'rev', - sharding_rule=_rev_sharding_rule) + sharding_rule=_rev_sharding_rule, + vma_rule=partial(core.standard_vma_rule, 'rev')) ad.deflinear2(rev_p, lambda t, _, dimensions: [rev(t, dimensions)]) batching.primitive_batchers[rev_p] = _rev_batch_rule @@ -6935,7 +7309,7 @@ def _transpose_shape_rule(operand, *, permutation): def _transpose_sharding_rule(operand, *, permutation): o_spec = operand.sharding.spec new_spec = [o_spec[old_idx] for old_idx in permutation] - return operand.sharding.with_spec(new_spec) + return operand.sharding.update(spec=new_spec) def _transpose_batch_rule(batched_args, batch_dims, *, permutation): operand, = batched_args @@ -6959,7 +7333,8 @@ def _transpose_lower(ctx, x, *, permutation): transpose_p = standard_primitive( _transpose_shape_rule, _input_dtype, 'transpose', - sharding_rule=_transpose_sharding_rule) + sharding_rule=_transpose_sharding_rule, + vma_rule=partial(core.standard_vma_rule, 'transpose')) ad.deflinear2(transpose_p, lambda t, _, permutation: [transpose(t, np.argsort(permutation))]) batching.primitive_batchers[transpose_p] = _transpose_batch_rule @@ -6985,10 +7360,11 @@ def _select_sharding_rule(which, *cases): return core.get_cur_mesh_sharding() if any(s != non_empty_s[0] for s in 
non_empty_s[1:]): msg = "select cases must have the same shardings, got [{}]." - raise TypeError(msg.format(", ".join([str(c.sharding) for c in cases]))) + raise core.ShardingTypeError( + msg.format(", ".join([str(c.sharding) for c in cases]))) if (which.shape and not which.sharding.mesh.empty and which.sharding != non_empty_s[0]): - raise TypeError( + raise core.ShardingTypeError( 'select `which` must be scalar or have the same sharding as cases, got' f' `which` sharding {which.sharding} but case sharding' f' {cases[0].sharding}.') @@ -7025,7 +7401,7 @@ def _select_transpose_rule(t, which, *cases): if ad.is_undefined_primal(case) else None for i, case in enumerate(cases) ] -def _select_batch_rule(batched_args, batch_dims, **unused_kwargs): +def _select_batch_rule(axis_data, batched_args, batch_dims, **unused_kwargs): which, *cases = batched_args which_bdim, *case_bdims = batch_dims size = next(x.shape[i] for x, i in zip(batched_args, batch_dims) @@ -7038,7 +7414,8 @@ def _select_batch_rule(batched_args, batch_dims, **unused_kwargs): else: # vmapped function had a scalar which with nonscalar args assert np.ndim(which) == 1 - which = broadcast_in_dim(which, cases[0].shape, [which_bdim]) + which = broadcast_in_dim(which, cases[0].shape, [which_bdim], + out_sharding=core.typeof(cases[0]).sharding) return select_n(which, *cases), which_bdim elif np.ndim(which) == 0 and all(bdim is not None for bdim in case_bdims): if all(case_bdims[0] == bdim for bdim in case_bdims[1:]): @@ -7049,16 +7426,18 @@ def _select_batch_rule(batched_args, batch_dims, **unused_kwargs): for c, c_bdim in zip(cases[1:], case_bdims[1:])] return select_n(which, cases[0], *other_cases), bdim - which = (batching.bdim_at_front(which, which_bdim, size) if np.shape(which) - else which) + which = (batching.bdim_at_front(which, which_bdim, size, + axis_data.explicit_mesh_axis) + if np.shape(which) else which) if not all(() == np.shape(c) for c in cases): - cases = [batching.bdim_at_front(c, bdim, size) + cases = [batching.bdim_at_front(c, bdim, size, axis_data.explicit_mesh_axis) for c, bdim in zip(cases, case_bdims)] assert all(np.shape(cases[0]) == np.shape(c) for c in cases[1:]) if 0 < np.ndim(which) < np.ndim(cases[0]): # vmapped function had a scalar which with nonscalar args assert np.ndim(which) == 1 - which = broadcast_in_dim(which, cases[0].shape, [0]) + which = broadcast_in_dim(which, cases[0].shape, [0], + out_sharding=core.typeof(cases[0]).sharding) if np.ndim(which) > np.ndim(cases[0]): assert np.ndim(cases[0]) == 0 cases = [broadcast(c, which.shape) for c in cases] @@ -7079,7 +7458,11 @@ def _select_jvp(primals, tangents): def _select_hlo_lowering_opaque(ctx, which, *cases): avals_in = ctx.avals_in aval_out, = ctx.avals_out - assert all(aval_case == aval_out for aval_case in avals_in[1:]) + assert all((aval_case.shape, aval_case.dtype) == (aval_out.shape, aval_out.dtype) + for aval_case in avals_in[1:]) + assert all( + aval_case == aval_out for aval_case in avals_in[1:] + if not aval_case.sharding.mesh.empty and not aval_out.sharding.mesh.empty) select_lower = _select_hlo_lowering physical_aval_out = core.physical_aval(aval_out) @@ -7134,10 +7517,12 @@ def _select(offset, cases): select_n_p = standard_primitive( _select_shape_rule, _select_dtype_rule, 'select_n', - weak_type_rule=_select_weak_type_rule, sharding_rule=_select_sharding_rule) + weak_type_rule=_select_weak_type_rule, sharding_rule=_select_sharding_rule, + vma_rule=partial(core.standard_vma_rule, 'select_n')) ad.primitive_jvps[select_n_p] = _select_jvp 
ad.primitive_transposes[select_n_p] = _select_transpose_rule -batching.primitive_batchers[select_n_p] = _select_batch_rule +batching.fancy_primitive_batchers[select_n_p] = _select_batch_rule +batching.skippable_batchers[select_n_p] = lambda _: () mlir.register_lowering(select_n_p, _select_hlo_lowering) pe.def_trivial_padding(select_n_p) @@ -7151,9 +7536,14 @@ def _reduce_shape_rule(*avals, computation, jaxpr, dimensions): def _reduce_sharding_rule(*avals, computation, jaxpr, dimensions): operand_avals, _ = split_list(avals, [len(avals) // 2]) - return [op.sharding.with_spec(tuple_delete(op.sharding.spec, dimensions)) + return [op.sharding.update(spec=tuple_delete(op.sharding.spec, dimensions)) for op in operand_avals] +def _reduce_vma_rule(*avals, computation, jaxpr, dimensions): + operand_avals, _ = split_list(avals, [len(avals) // 2]) + out_vma = core.standard_vma_rule('reduce', *operand_avals) + return [out_vma] * len(operand_avals) + def _reduce_dtype_rule(*avals, computation, jaxpr, dimensions): operand_avals, init_val_avals = split_list(avals, [len(avals) // 2]) operand_dtypes = [dtypes.canonicalize_dtype(op.dtype) for op in operand_avals] @@ -7240,7 +7630,8 @@ def _reduce_jvp_rule(primals, tangents, *, computation, jaxpr, dimensions): reduce_p.def_impl(partial(dispatch.apply_primitive, reduce_p)) reduce_p.def_abstract_eval( partial(standard_multi_result_abstract_eval, reduce_p, _reduce_shape_rule, - _reduce_dtype_rule, _reduce_weak_type_rule, _reduce_sharding_rule)) + _reduce_dtype_rule, _reduce_weak_type_rule, _reduce_sharding_rule, + _reduce_vma_rule)) batching.primitive_batchers[reduce_p] = _reduce_batch_rule ad.primitive_jvps[reduce_p] = _reduce_jvp_rule @@ -7310,11 +7701,12 @@ def _reduce_op_sharding_rule(operand, *, axes): axes = frozenset(axes) new_spec = P(*tuple(s for i, s in enumerate(operand.sharding.spec) if i not in axes)) - return operand.sharding.with_spec(new_spec) + return operand.sharding.update(spec=new_spec) reduce_sum_p = standard_primitive( _reduce_op_shape_rule, partial(_reduce_number_dtype_rule, 'reduce_sum'), - 'reduce_sum', sharding_rule=_reduce_op_sharding_rule) + 'reduce_sum', sharding_rule=_reduce_op_sharding_rule, + vma_rule=partial(core.standard_vma_rule, 'reduce_sum')) ad.deflinear2(reduce_sum_p, _reduce_sum_transpose_rule) batching.defreducer(reduce_sum_p, _get_sum_identity) pe.padding_rules[reduce_sum_p] = partial(_reducer_padding, reduce_sum, @@ -7329,7 +7721,8 @@ def _reduce_prod_jvp_rule(primals, tangents, *, axes): reduce_prod_p = standard_primitive( _reduce_op_shape_rule, partial(_reduce_number_dtype_rule, 'reduce_prod'), - 'reduce_prod', sharding_rule=_reduce_op_sharding_rule) + 'reduce_prod', sharding_rule=_reduce_op_sharding_rule, + vma_rule=partial(core.standard_vma_rule, 'reduce_prod')) ad.primitive_jvps[reduce_prod_p] = _reduce_prod_jvp_rule batching.defreducer(reduce_prod_p, _get_prod_identity) pe.padding_rules[reduce_prod_p] = partial(_reducer_padding, reduce_prod, @@ -7349,7 +7742,8 @@ def _reduce_chooser_jvp_rule(g, ans, operand, *, axes): reduce_max_p = standard_primitive( _reduce_op_shape_rule, _input_dtype, 'reduce_max', - sharding_rule=_reduce_op_sharding_rule) + sharding_rule=_reduce_op_sharding_rule, + vma_rule=partial(core.standard_vma_rule, 'reduce_max')) ad.defjvp2(reduce_max_p, _reduce_chooser_jvp_rule) batching.defreducer(reduce_max_p, _get_max_identity) pe.padding_rules[reduce_max_p] = partial(_reducer_padding, reduce_max, @@ -7359,7 +7753,8 @@ def _reduce_chooser_jvp_rule(g, ans, operand, *, axes): reduce_min_p = 
standard_primitive( _reduce_op_shape_rule, _input_dtype, 'reduce_min', - sharding_rule=_reduce_op_sharding_rule) + sharding_rule=_reduce_op_sharding_rule, + vma_rule=partial(core.standard_vma_rule, 'reduce_min')) ad.defjvp2(reduce_min_p, _reduce_chooser_jvp_rule) batching.defreducer(reduce_min_p, _get_min_identity) pe.padding_rules[reduce_min_p] = partial(_reducer_padding, reduce_min, @@ -7377,7 +7772,7 @@ def _argminmax_shape_rule(operand, *, axes, index_dtype): def _argminmax_sharding_rule(operand, *, axes, index_dtype): axis, = axes - return operand.sharding.with_spec( + return operand.sharding.update(spec= util.tuple_delete(operand.sharding.spec, axis)) def _argminmax_dtype_rule(operand, *, axes, index_dtype): @@ -7426,13 +7821,15 @@ def _compute_argminmax(value_comparator, get_identity, argmin_p = standard_primitive(_argminmax_shape_rule, _argminmax_dtype_rule, 'argmin', weak_type_rule=_strip_weak_type, - sharding_rule=_argminmax_sharding_rule) + sharding_rule=_argminmax_sharding_rule, + vma_rule=partial(core.standard_vma_rule, 'argmin')) batching.defreducer(argmin_p, _get_min_identity) ad.defjvp_zero(argmin_p) argmax_p = standard_primitive(_argminmax_shape_rule, _argminmax_dtype_rule, 'argmax', weak_type_rule=_strip_weak_type, - sharding_rule=_argminmax_sharding_rule) + sharding_rule=_argminmax_sharding_rule, + vma_rule=partial(core.standard_vma_rule, 'argmax')) batching.defreducer(argmax_p, _get_max_identity) ad.defjvp_zero(argmax_p) @@ -7451,24 +7848,27 @@ def _reduce_logical_shape_rule(operand, *, axes): return tuple(np.delete(operand.shape, axes)) def _reduce_logical_sharding_rule(operand, *, axes): - return operand.sharding.with_spec(tuple_delete(operand.sharding.spec, axes)) + return operand.sharding.update(spec=tuple_delete(operand.sharding.spec, axes)) reduce_or_p = standard_primitive( _reduce_logical_shape_rule, _input_dtype, 'reduce_or', - weak_type_rule=_strip_weak_type, sharding_rule=_reduce_logical_sharding_rule) + weak_type_rule=_strip_weak_type, sharding_rule=_reduce_logical_sharding_rule, + vma_rule=partial(core.standard_vma_rule, 'reduce_or')) batching.defreducer(reduce_or_p, _get_bitwise_or_identity) reduce_and_p = standard_primitive( _reduce_logical_shape_rule, _input_dtype, 'reduce_and', - weak_type_rule=_strip_weak_type, sharding_rule=_reduce_logical_sharding_rule) + weak_type_rule=_strip_weak_type, sharding_rule=_reduce_logical_sharding_rule, + vma_rule=partial(core.standard_vma_rule, 'reduce_and')) batching.defreducer(reduce_and_p, _get_bitwise_and_identity) batching.ragged_prop_rules[reduce_and_p] = batching.ragged_mask_elementwise_rule reduce_xor_p = standard_primitive( _reduce_logical_shape_rule, _input_dtype, 'reduce_xor', - weak_type_rule=_strip_weak_type, sharding_rule=_reduce_logical_sharding_rule) + weak_type_rule=_strip_weak_type, sharding_rule=_reduce_logical_sharding_rule, + vma_rule=partial(core.standard_vma_rule, 'reduce_xor')) batching.defreducer(reduce_xor_p, _get_bitwise_or_identity) @@ -7515,7 +7915,8 @@ def _reduce_precision_sharding_rule(operand, *, exponent_bits, mantissa_bits): reduce_precision_p = standard_primitive( _reduce_precision_shape_rule, partial(unop_dtype_rule, _identity, _float, 'reduce_precision'), - name='reduce_precision', sharding_rule=_reduce_precision_sharding_rule) + name='reduce_precision', sharding_rule=_reduce_precision_sharding_rule, + vma_rule=partial(core.standard_vma_rule, 'reduce_precision')) ad.deflinear(reduce_precision_p, lambda t, **kwargs: [reduce_precision_p.bind(t, **kwargs)]) 
batching.defvectorized(reduce_precision_p) @@ -7550,7 +7951,7 @@ def _sort_abstract_eval(*args, **kwargs): def _canonicalize_float_for_sort(x): - # In the sort comparator, we are going to use a comparision operator where -0 + # In the sort comparator, we are going to use a comparison operator where -0 # would be before 0, and -NaN and NaN appear at the beginning and end of the # ordering. In this scheme, -0 would be before 0, and -NaN and NaN appear at # the beginning and end of the ordering. This causes issues for stable @@ -7663,7 +8064,7 @@ def _sort_lower(ctx, *operands, dimension, is_stable, num_keys): mlir.flatten_ir_values(operands), dimension=mlir.i64_attr(dimension), is_stable=ir.BoolAttr.get(is_stable)) - scalar_s = lambda a: a.sharding.with_spec(P()) + scalar_s = lambda a: a.sharding.update(spec=P()) scalar_avals = [aval.update(shape=(), sharding=scalar_s(aval)) for aval in ctx.avals_in] scalar_types = safe_map(mlir.aval_to_ir_type, scalar_avals) @@ -7695,6 +8096,15 @@ def _top_k_abstract_eval(operand, *, k): if shape[-1] < k: msg = "k argument to top_k must be no larger than minor dimension; {} vs {}" raise ValueError(msg.format(k, shape)) + int32_max = dtypes.iinfo('int32').max + try: + too_large = (shape[-1] > int32_max + 1) + except core.InconclusiveDimensionOperation: + pass + else: + if too_large: + raise ValueError("top_k returns int32 indices, which will overflow for array dimensions " + f"larger than the maximum int32 ({int32_max}). Got {operand.shape=}") shape[-1] = k return (operand.update(shape=shape, dtype=operand.dtype, weak_type=operand.weak_type), @@ -7791,7 +8201,8 @@ def _create_token_lowering(ctx, *operands): def after_all(*operands): """Merges one or more XLA token values. Experimental. - Wraps the XLA AfterAll operator.""" + Wraps the XLA after all operator.""" + operands = core.standard_insert_pvary(*operands) return after_all_p.bind(*operands) def _after_all_abstract_eval(*operands): @@ -7815,6 +8226,7 @@ class InOutFeedEffect(effects.Effect): infeed_effect = InOutFeedEffect() outfeed_effect = InOutFeedEffect() +effects.custom_derivatives_allowed_effects.add_type(InOutFeedEffect) def infeed(token, shape=None, partitions=None): """Consumes an infeed value of `shape` from the host. Experimental. @@ -7926,6 +8338,7 @@ def rng_uniform(a, b, shape): This API may be removed at any time. 
""" + a, b = core.standard_insert_pvary(a, b) return rng_uniform_p.bind(a, b, shape=tuple(shape)) def _rng_uniform_abstract_eval(a, b, *, shape): @@ -7952,15 +8365,24 @@ def _rng_uniform_lowering(ctx, a, b, *, shape): mlir.register_lowering(rng_uniform_p, _rng_uniform_lowering) -def _rng_bit_generator_shape_rule(key, *, shape, dtype, algorithm): +def _rng_bit_generator_shape_rule(key, *, shape, dtype, algorithm, out_sharding): del dtype, algorithm return (key.shape, tuple(shape)) -def _rng_bit_generator_dtype_rule(key, *, shape, dtype, algorithm): +def _rng_bit_generator_sharding_rule(key, *, shape, dtype, algorithm, + out_sharding): + return (key.sharding, out_sharding) + +def _rng_bit_generator_vma_rule(key, *, shape, dtype, algorithm, out_sharding): + assert key.vma == frozenset() + return (key.vma, frozenset()) + +def _rng_bit_generator_dtype_rule(key, *, shape, dtype, algorithm, out_sharding): del shape, algorithm return (key.dtype, dtype) -def _rng_bit_generator_weak_type_rule(key, *, shape, dtype, algorithm): +def _rng_bit_generator_weak_type_rule(key, *, shape, dtype, algorithm, + out_sharding): del shape, dtype, algorithm return (key.weak_type, False) @@ -7991,7 +8413,7 @@ def _rng_algorithm(algorithm: RandomAlgorithm): assert False def _rng_bit_generator_lowering( - ctx, key, *, shape, dtype, algorithm): + ctx, key, *, shape, dtype, algorithm, out_sharding): key_type = ir.RankedTensorType(key.type) key_shape, key_etype = key_type.shape, key_type.element_type # While the RngBitGenerator HLO accepts a u64[2] key on all backends, we @@ -8020,7 +8442,7 @@ def _rng_bit_generator_lowering( ir.RankedTensorType.get([2], u64_type), hlo.reshape(ir.RankedTensorType.get([2, 2], u32_type), key)) algorithm_attr = _rng_algorithm(algorithm) - _, out_vals_aval = ctx.avals_out + out_key_aval, out_vals_aval = ctx.avals_out if any(not core.is_constant_shape(a.shape) for a in ctx.avals_out): output_shape = mlir.shape_tensor( mlir.eval_dynamic_shape(ctx, out_vals_aval.shape)) @@ -8044,7 +8466,8 @@ def _rng_bit_generator_lowering( out_vals = hlo.convert( ir.RankedTensorType.get(ir.RankedTensorType(out_vals.type).shape, etype), out_vals) - return [out_key, out_vals] + return [mlir.lower_with_sharding_in_types(ctx, out_key, out_key_aval), + mlir.lower_with_sharding_in_types(ctx, out_vals, out_vals_aval)] rng_bit_generator_p = Primitive("rng_bit_generator") @@ -8054,7 +8477,8 @@ def _rng_bit_generator_lowering( rng_bit_generator_p.def_abstract_eval( partial(standard_multi_result_abstract_eval, rng_bit_generator_p, _rng_bit_generator_shape_rule, _rng_bit_generator_dtype_rule, - _rng_bit_generator_weak_type_rule, None)) + _rng_bit_generator_weak_type_rule, _rng_bit_generator_sharding_rule, + _rng_bit_generator_vma_rule)) mlir.register_lowering(rng_bit_generator_p, _rng_bit_generator_lowering) @@ -8118,7 +8542,8 @@ def _propagate_mem_kind_copy(in_mem_kind): pxla.memory_kind_propagate_rule[copy_p] = _propagate_mem_kind_copy def rng_bit_generator(key, shape, dtype=np.uint32, - algorithm=RandomAlgorithm.RNG_DEFAULT): + algorithm=RandomAlgorithm.RNG_DEFAULT, + *, out_sharding=None): """Stateless PRNG bit generator. Experimental and its use is discouraged. 
Returns uniformly distributed random bits with the specified shape and dtype @@ -8134,12 +8559,14 @@ def rng_bit_generator(key, shape, dtype=np.uint32, """ shape = core.canonicalize_shape(shape) dtype = dtypes.canonicalize_dtype(dtype) + out_sharding = canonicalize_sharding(out_sharding, 'rng_bit_generator') if np.dtype(dtype) not in {np.dtype('uint8'), np.dtype('uint16'), np.dtype('uint32'), np.dtype('uint64')}: raise TypeError(f'rng_bit_generator: unsupported dtype {dtype}') return tuple( rng_bit_generator_p.bind( - key, shape=shape, dtype=dtype, algorithm=algorithm)) + key, shape=shape, dtype=dtype, algorithm=algorithm, + out_sharding=out_sharding)) def _iota_abstract_eval(*dyn_shape, dtype, shape, dimension, sharding): @@ -8169,13 +8596,16 @@ def _iota_abstract_eval(*dyn_shape, dtype, shape, dimension, sharding): iota_p.def_abstract_eval(_iota_abstract_eval) batching.ragged_prop_rules[iota_p] = batching.ragged_mask_no_op_rule -def _iota_staging_rule(trace, *dyn_shape, dtype, shape, dimension, sharding): +def _iota_staging_rule(trace, source_info, *dyn_shape, dtype, shape, dimension, + sharding): params = dict(dtype=dtype, shape=shape, dimension=dimension, sharding=sharding) if not dyn_shape: - return trace.default_process_primitive(iota_p, (), params) + return trace.default_process_primitive(iota_p, (), params, + source_info=source_info) aval = core.DShapedArray(_merge_dyn_shape(shape, dyn_shape), dtype, False) - return _dyn_shape_staging_rule(trace, iota_p, aval, *dyn_shape, **params) + return _dyn_shape_staging_rule(trace, source_info, iota_p, aval, *dyn_shape, + **params) pe.custom_staging_rules[iota_p] = _iota_staging_rule def _iota_typecheck_rule(_, *dyn_shape, dtype, shape, dimension, sharding): @@ -8405,15 +8835,19 @@ def _const(example, val): def _zero(x): x_aval = core.get_aval(x) - return full_like(x, shape=(), fill_value=0, - sharding=x_aval.sharding.with_spec(P())) + out = full_like(x, shape=(), fill_value=0, + sharding=x_aval.sharding.update(spec=P())) + out = core.pvary(out, tuple(x_aval.vma)) + return out _ones: Callable = partial(full_like, fill_value=1) def _one(x): x_aval = core.get_aval(x) - return full_like(x, shape=(), fill_value=1, - sharding=x_aval.sharding.with_spec(P())) + out = full_like(x, shape=(), fill_value=1, + sharding=x_aval.sharding.update(spec=P())) + out = core.pvary(out, tuple(x_aval.vma)) + return out _twos: Callable = partial(full_like, fill_value=2) _two: Callable = partial(full_like, shape=(), fill_value=2) @@ -8594,11 +9028,13 @@ def optimization_barrier(operand, /): Array(0., dtype=float32, weak_type=True) """ flat_args, treedef = tree_util.tree_flatten(operand) - return tree_util.tree_unflatten( - treedef, optimization_barrier_p.bind(*flat_args)) + flat_args = core.standard_insert_pvary(*flat_args) + out = optimization_barrier_p.bind(*flat_args) + return tree_util.tree_unflatten(treedef, out) def _optimization_barrier_abstract_eval(*args): + core.standard_vma_rule('optimization_barrier', *args) return args def _optimization_barrier_lowering_rule(ctx, *args): diff --git a/jax/_src/lax/linalg.py b/jax/_src/lax/linalg.py index c674401fb80d..81d23465ea34 100644 --- a/jax/_src/lax/linalg.py +++ b/jax/_src/lax/linalg.py @@ -23,8 +23,6 @@ import numpy as np -from jax import lax - from jax._src import ad_util from jax._src import api from jax._src import config @@ -39,16 +37,13 @@ from jax._src.interpreters import batching from jax._src.interpreters import mlir from jax._src.lax import control_flow -from jax._src.lax import eigh as lax_eigh -from 
jax._src.lax import lax as lax_internal -from jax._src.lax import svd as lax_svd +from jax._src.lax import lax from jax._src.lax import utils as lax_utils from jax._src.lax.lax import _float, _complex, _int from jax._src.lib import gpu_linalg from jax._src.lib import gpu_solver from jax._src.lib import gpu_sparse from jax._src.lib import lapack -from jax._src.lib import version as jaxlib_version from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import chlo from jax._src.lib.mlir.dialects import hlo @@ -121,6 +116,7 @@ def cholesky_update(r_matrix: ArrayLike, w_vector: ArrayLike) -> Array: A new upper-triangular matrix :math:`R` defining the Cholesky decomposition of :math:`A + w \, w^T`. """ + r_matrix, w_vector = core.standard_insert_pvary(r_matrix, w_vector) return cholesky_update_p.bind(r_matrix, w_vector) @@ -268,6 +264,7 @@ def householder_product(a: ArrayLike, taus: ArrayLike) -> Array: A batch of orthogonal (unitary) matrices with the same shape as ``a``, containing the products of the elementary Householder reflectors. """ + a, taus = core.standard_insert_pvary(a, taus) return householder_product_p.bind(a, taus) @@ -526,7 +523,7 @@ def symmetric_product( Computes the symmetric product - ..math:: + .. math:: \alpha \, A \, A^T + \beta \, C where :math:`A` is a rectangular matrix and :math:`C` is a symmetric matrix. @@ -545,6 +542,7 @@ def symmetric_product( ``symmetrize_output`` is ``True``, the upper triangle is filled with the transpose of the lower triangle, and the whole matrix is valid. """ + a_matrix, c_matrix = core.standard_insert_pvary(a_matrix, c_matrix) result = symmetric_product_p.bind(a_matrix, c_matrix, alpha=alpha, beta=beta) if symmetrize_output: upper_half = lax.transpose( @@ -602,6 +600,7 @@ def triangular_solve( singleton = np.ndim(b) == np.ndim(a) - 1 if singleton: b = lax.expand_dims(b, (-1 if left_side else -2,)) + a, b = core.standard_insert_pvary(a, b) out = triangular_solve_p.bind( a, b, left_side=left_side, lower=lower, transpose_a=transpose_a, conjugate_a=conjugate_a, unit_diagonal=unit_diagonal) @@ -635,7 +634,7 @@ def tridiagonal( superdiagonal. ``taus`` contains the scalar factors of the elementary Householder reflectors. """ - return tridiagonal_p.bind(lax_internal.asarray(a), lower=lower) + return tridiagonal_p.bind(lax.asarray(a), lower=lower) def tridiagonal_solve(dl: Array, d: Array, du: Array, b: Array) -> Array: @@ -661,6 +660,7 @@ def tridiagonal_solve(dl: Array, d: Array, du: Array, b: Array) -> Array: Returns: Solution ``X`` of tridiagonal system. """ + dl, d, du, b = core.standard_insert_pvary(dl, d, du, b) return tridiagonal_solve_p.bind(dl, d, du, b) @@ -717,34 +717,42 @@ def linalg_sharding_rule( spec = aval.sharding.spec batch_spec, rest_spec = spec[:len(spec) - rank], spec[len(spec) - rank:] if not all(s is None for s in rest_spec): - raise ValueError( + raise core.ShardingTypeError( f"Input {i} to {name} must be unsharded on non-batch dimensions, " f"but got {spec}." ) batch_specs.append(batch_spec) batch_spec = batch_specs[0] if any(b != batch_spec for b in batch_specs[1:]): - raise ValueError( + raise core.ShardingTypeError( f"All inputs to {name} must have the same batch sharding, but got " f"{batch_specs}." 
) sharding = avals[0].sharding if multiple_results: return [ - sharding.with_spec( + sharding.update(spec= P(*(tuple(batch_spec) + (None,) * (len(s) - len(batch_spec)))) ) for s in output_shapes ] else: ndim = len(output_shapes) - len(batch_spec) - return sharding.with_spec(P(*(tuple(batch_spec) + (None,) * ndim))) + return sharding.update(spec=P(*(tuple(batch_spec) + (None,) * ndim))) + +def linalg_vma_rule(multiple_results, shape_rule, name, *avals, **kwargs): + output_shapes = shape_rule(*avals, **kwargs) + out_vma = core.standard_vma_rule(name, *avals) + if multiple_results: + return [out_vma] * len(output_shapes) + else: + return out_vma def linalg_primitive(result_dtype, accepted_dtypes, ranks, result_shape, name, multiple_results=False, supports_batching=True, require_same=True): dtype_rule = partial( - lax_internal.naryop_dtype_rule, result_dtype, accepted_dtypes, name, + lax.naryop_dtype_rule, result_dtype, accepted_dtypes, name, require_same=require_same) shape_rule = partial( linalg_shape_rule, multiple_results, supports_batching, ranks, @@ -754,6 +762,7 @@ def linalg_primitive(result_dtype, accepted_dtypes, ranks, result_shape, name, linalg_sharding_rule, multiple_results, shape_rule, ranks, name) else: sharding_rule = None + vma_rule = partial(linalg_vma_rule, multiple_results, shape_rule, name) prim = core.Primitive(name) prim.multiple_results = multiple_results prim.def_impl(partial(dispatch.apply_primitive, prim)) @@ -761,17 +770,19 @@ def linalg_primitive(result_dtype, accepted_dtypes, ranks, result_shape, name, prim.def_abstract_eval( partial(lax_utils.standard_multi_result_abstract_eval, prim, shape_rule, dtype_rule, lax_utils._standard_weak_type_rule, - sharding_rule)) + sharding_rule, vma_rule)) else: prim.def_abstract_eval( partial(lax_utils.standard_abstract_eval, prim, shape_rule, dtype_rule, - lax_utils._standard_weak_type_rule, sharding_rule)) + lax_utils._standard_weak_type_rule, sharding_rule, + partial(core.standard_vma_rule, name), + None)) if supports_batching: batching.primitive_batchers[prim] = partial( batching.expand_dims_batcher, prim) return prim -standard_linalg_primitive = partial(linalg_primitive, lax_internal._input_dtype) +standard_linalg_primitive = partial(linalg_primitive, lax._input_dtype) # Primitive implementations @@ -794,7 +805,7 @@ def _cholesky_jvp_rule(primals, tangents): def phi(X): l = _tril(X) return l / lax.expand_dims( - lax_internal._const(X, 1) + lax_internal._eye(X.dtype, (X.shape[-1], X.shape[-1])), + lax._const(X, 1) + lax._eye(X.dtype, (X.shape[-1], X.shape[-1])), range(l.ndim - 2)) tmp = triangular_solve(L, sigma_dot, left_side=False, transpose_a=True, @@ -857,7 +868,8 @@ def _drotg_nonzero(x, y): np.array(1., dtype=x.dtype), np.array(0., dtype=x.dtype), ) - return lax.cond(y == 0, lambda x, y: one_and_zero, _drotg_nonzero, x, y) + return control_flow.cond( + y == 0, lambda x, y: one_and_zero, _drotg_nonzero, x, y) def _drot( first_vector: Array, second_vector: Array, @@ -1051,7 +1063,7 @@ def _eigh_jacobi_shape_rule(shape, **_): def _eigh_jacobi_dtype_rule(dtype, **_): dtype = dtypes.canonicalize_dtype(dtype) - return lax_internal._complex_basetype(dtype), dtype + return lax._complex_basetype(dtype), dtype def _eigh_jacobi_lowering_rule(ctx, operand, lower, sort_eigenvalues): operand_aval, = ctx.avals_in @@ -1106,7 +1118,7 @@ def _eigh_shape_rule(shape, *, subset_by_index, **_): def _eigh_dtype_rule(dtype, **_): dtype = dtypes.canonicalize_dtype(dtype) - return dtype, lax_internal._complex_basetype(dtype) + return dtype, 
lax._complex_basetype(dtype) def _eigh_cpu_gpu_lowering( ctx, operand, *, lower, sort_eigenvalues, subset_by_index, @@ -1145,57 +1157,6 @@ def _eigh_cpu_gpu_lowering( return [v, w] -def _eigh_tpu_impl(x, *, lower, sort_eigenvalues, subset_by_index): - *_, m, n = x.shape - assert m == n, (m, n) - - termination_size = 256 - if not is_constant_dim(m): - # TODO: maybe we can relax the check below for shape polymorphism? - raise NotImplementedError( - "Shape polymorphism for native lowering for eigh is implemented " - f"only for the batch dimensions: {x.shape}") - if m <= termination_size and ( - subset_by_index is None or subset_by_index == (0, n) - ): - eig_vals, eig_vecs = eigh_jacobi(x, lower=lower, - sort_eigenvalues=sort_eigenvalues) - return eig_vecs, eig_vals - - def eigh_qdwh(x): - if len(x.shape) > 2: - return control_flow.map(eigh_qdwh, x) - - # We should only look at elements from the lower/upper triangle. Reflects - # that triangle into the other triangle to form a Hermitian matrix. - if lower: - mask = lax_internal._tri(bool, (n, n), 0) - else: - mask = lax.bitwise_not(lax_internal._tri(bool, (n, n), -1)) - if dtypes.issubdtype(x.dtype, np.complexfloating): - re = lax.select(mask, lax.real(x), _T(lax.real(x))) - if lower: - im_mask = lax_internal._tri(bool, (n, n), -1) - else: - im_mask = lax.bitwise_not(lax_internal._tri(bool, (n, n), 0)) - im = lax.imag(x) - im = lax.select(im_mask, im, lax.full_like(im, 0)) - im = lax.select(mask, im, -_T(im)) - x = lax.complex(re, im) - else: - x = lax.select(mask, x, _T(x)) - - return lax_eigh.eigh( - x, - sort_eigenvalues=sort_eigenvalues, - termination_size=termination_size, - subset_by_index=subset_by_index, - ) - - eig_vals, eig_vecs = eigh_qdwh(x) - return eig_vecs, eig_vals - - def _eigh_jvp_rule( primals, tangents, *, lower, sort_eigenvalues, subset_by_index ): @@ -1224,7 +1185,7 @@ def _eigh_jvp_rule( # for complex numbers we need eigenvalues to be full dtype of v, a: w = w_real.astype(a.dtype) - eye_n = lax_internal._eye(a.dtype, (n, n)) + eye_n = lax._eye(a.dtype, (n, n)) # carefully build reciprocal delta-eigenvalue matrix, avoiding NaNs. with config.numpy_rank_promotion("allow"): Fmat = lax.integer_pow(eye_n + w[..., np.newaxis, :] - w[..., np.newaxis], -1) - eye_n @@ -1241,9 +1202,6 @@ def _eigh_jvp_rule( _eigh_dtype_rule, (_float | _complex,), (2,), _eigh_shape_rule, "eigh", multiple_results=True) ad.primitive_jvps[eigh_p] = _eigh_jvp_rule -mlir.register_lowering( - eigh_p, mlir.lower_fun(_eigh_tpu_impl, multiple_results=True), - platform='tpu') register_cpu_gpu_lowering(eigh_p, _eigh_cpu_gpu_lowering) @@ -1374,7 +1332,7 @@ def body(k, state): # a[k+1:, k+1:] -= jnp.outer(a[k+1:, k], a[k, k+1:]) a_outer = a[:, k, None] * a[k, None] a = a - lax.select((m_idx[:, None] > k) & (n_idx[None, :] > k), - a_outer, lax_internal._zeros(a_outer)) + a_outer, lax._zeros(a_outer)) return pivot, perm, a pivot = lax.full((min(m, n),), 0, dtype=np.int32) @@ -1383,7 +1341,7 @@ def body(k, state): # If the array is empty, the loop body never executes but tracing it to a # jaxpr fails because the indexing cannot succeed. 
return (pivot, perm, a) - return lax.fori_loop(0, min(m, n), body, (pivot, perm, a)) + return control_flow.fori_loop(0, min(m, n), body, (pivot, perm, a)) def _lu_blocked(a, block_size=128): @@ -1447,10 +1405,10 @@ def _lu_jvp_inner(lu, a_dot, permutation): l_padding = [(0, 0, 0)] * 2 l_padding[-1] = (0, m - k, 0) - zero = lax_internal._const(lu, 0) + zero = lax._const(lu, 0) l = lax.pad(_tril(lu[:, :k], -1), zero, l_padding) - l = l + lax_internal._eye(dtype, (m, m)) - u_eye = lax.pad(lax_internal._eye(dtype, (n - k, n - k)), zero, + l = l + lax._eye(dtype, (m, m)) + u_eye = lax.pad(lax._eye(dtype, (n - k, n - k)), zero, ((k, 0, 0), (k, 0, 0))) u_padding = [(0, 0, 0)] * 2 u_padding[-2] = (0, n - k, 0) @@ -1643,8 +1601,9 @@ def _generic_lu_pivots_to_permutation(swaps, permutation_size): if m == 0 or k == 0: return permutation upper = np.array(k, np.int32) if is_constant_dim(k) else k - result, _ = lax.fori_loop(np.array(0, np.int32), upper, _lu_pivots_body_fn, - (permutation, swaps)) + permutation, swaps = core.standard_insert_pvary(permutation, swaps) + result, _ = control_flow.fori_loop(np.array(0, np.int32), upper, + _lu_pivots_body_fn, (permutation, swaps)) return result @@ -1758,6 +1717,7 @@ def geqp3(a: ArrayLike, jpvt: ArrayLike, *, elementary Householder reflectors, and ``jpvt`` is the column-pivot indices such that ``a[:, jpvt] = q @ r``. """ + a, jpvt = core.standard_insert_pvary(a, jpvt) a_out, jpvt_out, taus = geqp3_p.bind(a, jpvt, use_magma=use_magma) return a_out, jpvt_out, taus @@ -1816,7 +1776,7 @@ def qr_jvp_rule(primals, tangents, *, pivoting, full_matrices, use_magma): qt_dx_rinv_lower = _tril(qt_dx_rinv, -1) do = qt_dx_rinv_lower - _H(qt_dx_rinv_lower) # This is skew-symmetric # The following correction is necessary for complex inputs - I = lax.expand_dims(lax_internal._eye(do.dtype, (n, n)), range(qt_dx_rinv.ndim - 2)) + I = lax.expand_dims(lax._eye(do.dtype, (n, n)), range(qt_dx_rinv.ndim - 2)) do = do + I * (qt_dx_rinv - qt_dx_rinv.real.astype(qt_dx_rinv.dtype)) dq = q @ (do - qt_dx_rinv) + dx_rinv dr = (qt_dx_rinv - do) @ r @@ -1829,7 +1789,7 @@ def _qr_lowering(a, *, pivoting, full_matrices, use_magma): *batch_dims, m, n = a.shape if m == 0 or n == 0: k = m if full_matrices else core.min_dim(m, n) - q = lax.broadcast_in_dim(lax_internal._eye(a.dtype, (m, k)), + q = lax.broadcast_in_dim(lax._eye(a.dtype, (m, k)), (*batch_dims, m, k), (len(batch_dims), len(batch_dims) + 1)) r = lax.full((*batch_dims, k, n), 0, dtype=a.dtype) @@ -1849,7 +1809,7 @@ def _qr_lowering(a, *, pivoting, full_matrices, use_magma): q = householder_product(r[..., :m, :m], taus) elif full_matrices: pads = [(0, 0, 0)] * (len(batch_dims) + 1) + [(0, m - n, 0)] - q = lax.pad(r, lax_internal._zero(r), pads) + q = lax.pad(r, lax._zero(r), pads) q = householder_product(q, taus) else: q = householder_product(r, taus) @@ -1949,7 +1909,7 @@ def _svd_shape_rule(shape, *, full_matrices, compute_uv, subset_by_index, **_): def _svd_dtype_rule(dtype, *, compute_uv, **_): dtype = dtypes.canonicalize_dtype(dtype) - real_dtype = lax_internal._complex_basetype(dtype) + real_dtype = lax._complex_basetype(dtype) if compute_uv: return real_dtype, dtype, dtype else: @@ -1981,7 +1941,7 @@ def _svd_jvp_rule( return (s,), (ds,) s_diffs = (s_dim + _T(s_dim)) * (s_dim - _T(s_dim)) - s_diffs_zeros = lax_internal._eye(s.dtype, (s.shape[-1], s.shape[-1])) # jnp.ones((), dtype=A.dtype) * (s_diffs == 0.) # is 1. where s_diffs is 0. and is 0. 
everywhere else + s_diffs_zeros = lax._eye(s.dtype, (s.shape[-1], s.shape[-1])) # jnp.ones((), dtype=A.dtype) * (s_diffs == 0.) # is 1. where s_diffs is 0. and is 0. everywhere else s_diffs_zeros = lax.expand_dims(s_diffs_zeros, range(s_diffs.ndim - 2)) F = 1 / (s_diffs + s_diffs_zeros) - s_diffs_zeros dSS = s_dim.astype(A.dtype) * dS # dS.dot(jnp.diag(s)) @@ -2007,12 +1967,12 @@ def _svd_jvp_rule( def _empty_svd(a, *, full_matrices, compute_uv): batch_shape = a.shape[:-2] m, n = a.shape[-2:] - s = lax.full(batch_shape + (0,), 0, dtype=lax_internal._complex_basetype(a.dtype)) + s = lax.full(batch_shape + (0,), 0, dtype=lax._complex_basetype(a.dtype)) if not compute_uv: return (s,) if full_matrices: size = max(m, n) - u = lax.broadcast_in_dim(lax_internal._eye(a.dtype, (size, size)), + u = lax.broadcast_in_dim(lax._eye(a.dtype, (size, size)), (*batch_shape, size, size), (len(batch_shape), len(batch_shape) + 1)) else: @@ -2130,7 +2090,7 @@ def _svd_gpu_sub_lowering(ctx, operand, *, full_matrices, compute_uv, # default QR algorithm, but users can (in principle) override this behavior # by passing `use_jacobi=True`. # - # TODO(danfm): Since this was originally implemented, hipSolver appers to + # TODO(danfm): Since this was originally implemented, hipSolver appears to # have added support for the Jacobi algorithm, so we should investigate # removing this condition. if algorithm is None or algorithm == SvdAlgorithm.DEFAULT: @@ -2195,57 +2155,12 @@ def _svd_gpu_sub_lowering(ctx, operand, *, full_matrices, compute_uv, else: return s, u, vt, info -def _svd_tpu(a, *, full_matrices, compute_uv, subset_by_index, algorithm=None): - if algorithm is not None and algorithm != SvdAlgorithm.DEFAULT: - raise NotImplementedError( - "The SVD algorithm parameter is not implemented on TPU.") - - batch_dims = a.shape[:-2] - fn = partial( - lax_svd.svd, - full_matrices=full_matrices, - compute_uv=compute_uv, - subset_by_index=subset_by_index, - ) - for _ in range(len(batch_dims)): - fn = api.vmap(fn) - - if compute_uv: - u, s, vh = fn(a) - return [s, u, vh] - else: - s = fn(a) - return [s] - -def _svd_tpu_lowering_rule( - ctx, operand, *, full_matrices, compute_uv, subset_by_index, algorithm=None -): - del algorithm # unused - operand_aval, = ctx.avals_in - m, n = operand_aval.shape[-2:] - - if m == 0 or n == 0: - return mlir.lower_fun(_empty_svd, multiple_results=True)( - ctx, - operand, - full_matrices=full_matrices, - compute_uv=compute_uv, - ) - - return mlir.lower_fun(_svd_tpu, multiple_results=True)( - ctx, - operand, - full_matrices=full_matrices, - compute_uv=compute_uv, - subset_by_index=subset_by_index, - ) svd_p = linalg_primitive( _svd_dtype_rule, (_float | _complex,), (2,), _svd_shape_rule, "svd", multiple_results=True) ad.primitive_jvps[svd_p] = _svd_jvp_rule register_cpu_gpu_lowering(svd_p, _svd_cpu_gpu_lowering) -mlir.register_lowering(svd_p, _svd_tpu_lowering_rule) # Symmetric product @@ -2321,7 +2236,7 @@ def a_inverse(rhs): transpose_a=transpose_a, conjugate_a=conjugate_a, unit_diagonal=unit_diagonal) - # triangular_solve is about the same cost as matrix multplication (~n^2 FLOPs + # triangular_solve is about the same cost as matrix multiplication (~n^2 FLOPs # for matrix/vector inputs). Order these operations in whichever order is # cheaper. 
if left_side: @@ -2410,15 +2325,7 @@ def _triangular_solve_cpu_lower( conjugate_a = False if np.dtype(a_aval.dtype) in _cpu_lapack_types: target_name = lapack.prepare_lapack_call("trsm_ffi", a_aval.dtype) - # TODO(b/397715595): Remove forward_compat check no earlier than 2025-03-18. - if ctx.is_forward_compat() or jaxlib_version <= (0, 5, 1): - alpha = mlir.ir_constant(np.array(1, dtype=a_aval.dtype)), - alpha_aval = ShapedArray((), a_aval.dtype), - batch_partitionable = False - else: - alpha = () - alpha_aval = () - batch_partitionable = True + alpha, alpha_aval, batch_partitionable = (), (), True rule = _linalg_ffi_lowering(target_name, [a_aval, b_aval, *alpha_aval], operand_output_aliases={1: 0}, @@ -2464,7 +2371,7 @@ def _tridiagonal_shape_rule(shape, **_): def _tridiagonal_dtype_rule(dtype, **_): dtype = dtypes.canonicalize_dtype(dtype) - real_dtype = lax_internal._complex_basetype(dtype) + real_dtype = lax._complex_basetype(dtype) return dtype, real_dtype, real_dtype, dtype def _tridiagonal_cpu_gpu_lowering(ctx, a, *, lower, target_name_prefix): @@ -2511,16 +2418,10 @@ def _tridiagonal_solve_shape_rule(dl_shape, d_shape, du_shape, b_shape, **_): "equal the dimensions of the diagonal arguments.") return b_shape -def _tridiagonal_solve_gpu_lowering(lowering, ctx, dl, d, du, b): - _, _, _, b_aval = ctx.avals_in - if b_aval.dtype != np.float32 and b_aval.dtype != np.float64: - raise NotImplementedError( - "tridiagonal_solve is only implemented for float32 and float64 on GPU.") - m, n = b_aval.shape[-2:] - b_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, b_aval.shape) - return [lowering( - dl, d, du, b, m=m, n=n, ldb=m, t=b_aval.dtype, - b_shape_vals=b_shape_vals)] +def _tridiagonal_solve_gpu_lowering(ctx, dl, d, du, b, *, target_name_prefix): + target_name = f"{target_name_prefix}sparse_gtsv2_ffi" + rule = _linalg_ffi_lowering(target_name, operand_output_aliases={3: 0}) + return rule(ctx, dl, d, du, b) def _tridiagonal_solve_cpu_lowering(ctx, dl, d, du, b, **kwargs): del kwargs # unused @@ -2598,7 +2499,7 @@ def fwd(carry, args): dp_next = (d - a * dp) / (b - a * cp) return (cp_next, dp_next), (cp, dp) - (_, final), (cp, dp) = lax.scan( + (_, final), (cp, dp) = control_flow.scan( fwd, (du[0] / d[0], b[0] / d[0]), (dl[1:], d[1:], du[1:], b[1:, :]), unroll=32) @@ -2607,7 +2508,7 @@ def bwd(xn, args): x = dp - cp * xn return x, xn - end, ans = lax.scan(bwd, final, (cp, dp), unroll=32, reverse=True) + end, ans = control_flow.scan(bwd, final, (cp, dp), unroll=32, reverse=True) return lax.concatenate((end[None], ans), 0) def _tridiagonal_solve_jax(dl, d, du, b, **_): @@ -2628,11 +2529,11 @@ def _tridiagonal_solve_jax(dl, d, du, b, **_): platform='cpu') mlir.register_lowering( tridiagonal_solve_p, - partial(_tridiagonal_solve_gpu_lowering, gpu_sparse.cuda_gtsv2), + partial(_tridiagonal_solve_gpu_lowering, target_name_prefix='cu'), platform='cuda') mlir.register_lowering( tridiagonal_solve_p, - partial(_tridiagonal_solve_gpu_lowering, gpu_sparse.rocm_gtsv2), + partial(_tridiagonal_solve_gpu_lowering, target_name_prefix='hip'), platform='rocm') mlir.register_lowering(tridiagonal_solve_p, mlir.lower_fun( _tridiagonal_solve_jax, multiple_results=False)) @@ -2672,7 +2573,7 @@ def _solve(a: Array, b: Array) -> Array: # computing sensitivities. This is considerably faster. 
lu_, _, permutation = lu(lax.stop_gradient(a)) custom_solve = partial( - lax.custom_linear_solve, + control_flow.custom_linear_solve, lambda x: _broadcasted_matvec(a, x), solve=lambda _, x: lu_solve(lu_, permutation, x, trans=0), transpose_solve=lambda _, x: lu_solve(lu_, permutation, x, trans=1)) @@ -2693,12 +2594,12 @@ def symmetrize(x: Array) -> Array: return (x + _H(x)) / 2 def _tril(m: Array, k:int = 0) -> Array: *_, N, M = m.shape - mask = lax_internal._tri(bool, (N, M), k) + mask = lax._tri(bool, (N, M), k) return lax.select(lax.broadcast(mask, m.shape[:-2]), m, lax.zeros_like_array(m)) def _triu(m: Array, k:int = 0) -> Array: *_, N, M = m.shape - mask = lax_internal._tri(bool, (N, M), k - 1) + mask = lax._tri(bool, (N, M), k - 1) return lax.select(lax.broadcast(mask, m.shape[:-2]), lax.zeros_like_array(m), m) def _construct_diagonal(s: Array) -> Array: @@ -2723,7 +2624,7 @@ def _nan_like_hlo(ctx: mlir.LoweringRuleContext, aval) -> ir.Value: def _broadcasting_select_hlo(ctx, which, which_aval, x, x_aval, y, y_aval) -> ir.Value: """Wrapper around XLA `Select` that broadcasts its arguments.""" - out_shapes = list(lax_internal.broadcast_shapes( + out_shapes = list(lax.broadcast_shapes( tuple(which_aval.shape), tuple(x_aval.shape), tuple(y_aval.shape))) which, x, y = mlir.multi_broadcast_in_dim(ctx, (which, x, y), (which_aval, x_aval, y_aval), @@ -2763,9 +2664,9 @@ def _column_major_matrix_layout(dim: int) -> tuple[int, ...]: return (dim - 2, dim - 1) + tuple(range(dim - 3, -1, -1)) def _sdy_rule_for_aval(letters, num_batch_dims, aval): - return " ".join( - ("...", *(next(letters) for _ in range(len(aval.shape) - num_batch_dims))) - ) + d = len(aval.shape) - num_batch_dims + prefix = "... " if num_batch_dims and d >= 0 else "" + return prefix + " ".join(next(letters) for _ in range(d)) def _build_sdy_sharding_rule(num_batch_dims, avals_in, avals_out): letters = iter(string.ascii_letters) diff --git a/jax/_src/lax/other.py b/jax/_src/lax/other.py index 00e15ef6a91d..6da39b0c2405 100644 --- a/jax/_src/lax/other.py +++ b/jax/_src/lax/other.py @@ -287,3 +287,35 @@ def _logaddexp_jvp(primals, tangents): tangent_out = lax.add(lax.mul(t1, lax.exp(lax.sub(_replace_inf(x1), _replace_inf(primal_out)))), lax.mul(t2, lax.exp(lax.sub(_replace_inf(x2), _replace_inf(primal_out))))) return primal_out, tangent_out + + +@custom_jvp +def logaddexp2(x1: ArrayLike, x2: ArrayLike, /) -> Array: + """Compute log2(exp2(x1) + exp2(x2)) avoiding overflow.""" + x1_arr = lax.asarray(x1) + x2_arr = lax.asarray(x2) + assert x1_arr.dtype == x2_arr.dtype + + amax = lax.max(x1_arr, x2_arr) + invln2 = lax._const(amax, 1/np.log(2)) + if dtypes.isdtype(x1_arr.dtype, "real floating"): + delta = lax.sub(x1_arr, x2_arr) + return lax.select(lax._isnan(delta), + lax.add(x1_arr, x2_arr), # NaNs or infinities of the same sign. 
+ lax.add(amax, lax.mul(invln2, lax.log1p(lax.exp2(lax.neg(lax.abs(delta))))))) + elif dtypes.isdtype(x1_arr.dtype, "complex floating"): + delta = lax.sub(lax.add(x1_arr, x2_arr), lax.mul(amax, lax._const(amax, 2))) + out = lax.add(amax, lax.mul(invln2, lax.log1p(lax.exp2(delta)))) + return lax.complex(lax.real(out), _wrap_between(lax.imag(out), np.pi / np.log(2))) + else: + raise ValueError(f"logaddexp2 requires floating-point or complex inputs; got {x1_arr.dtype}") + + +@logaddexp2.defjvp +def _logaddexp2_jvp(primals, tangents): + x1, x2 = primals + t1, t2 = tangents + primal_out = logaddexp2(x1, x2) + tangent_out = lax.add(lax.mul(t1, lax.exp2(lax.sub(_replace_inf(x1), _replace_inf(primal_out)))), + lax.mul(t2, lax.exp2(lax.sub(_replace_inf(x2), _replace_inf(primal_out))))) + return primal_out, tangent_out diff --git a/jax/_src/lax/parallel.py b/jax/_src/lax/parallel.py index 221fe2a9e87a..06d4ec2f4281 100644 --- a/jax/_src/lax/parallel.py +++ b/jax/_src/lax/parallel.py @@ -22,11 +22,11 @@ import itertools import math -import jax -from jax import tree_util from jax._src import core +from jax._src import config from jax._src import dispatch from jax._src import dtypes +from jax._src import tree_util from jax._src.sharding_impls import (SPMDAxisContext, ShardingContext, NamedSharding, PartitionSpec as P) from jax._src.core import AxisName, ShapedArray @@ -34,10 +34,15 @@ from jax._src.interpreters import batching from jax._src.interpreters import mlir from jax._src.interpreters import pxla +from jax._src.mesh import get_abstract_mesh +from jax._src.core import abstract_token, pvary +from jax._src.lax import control_flow from jax._src.lax import lax from jax._src.lax import slicing from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import hlo +from jax._src.lib import xla_client as xc +from jax._src.typing import Array from jax._src.util import (canonicalize_axis, moveaxis, safe_map, safe_zip, unzip2) import numpy as np @@ -115,6 +120,8 @@ def psum(x, axis_name, *, axis_index_groups=None): """ if not isinstance(axis_name, (tuple, list)): axis_name = (axis_name,) + if not axis_name: + return x if any(isinstance(axis, int) for axis in axis_name) and axis_index_groups is not None: raise ValueError("axis_index_groups only supported for sums over just named axes") _validate_reduce_axis_index_groups(axis_index_groups) @@ -139,10 +146,27 @@ def pos_reduce(x): size = math.prod([core.get_axis_env().axis_size(name) for name in named_axes]) out_flat = tuple(lax._const(leaf, size) * pos_reduce(leaf) for leaf in leaves) else: - out_flat = psum_p.bind( - *leaves, axes=tuple(axis_name), axis_index_groups=axis_index_groups) + if config._check_vma.value: + out_flat = bind_psum_invariant( + leaves, axes=tuple(axis_name), axis_index_groups=axis_index_groups) + else: + out_flat = psum_p.bind( + *leaves, axes=tuple(axis_name), axis_index_groups=axis_index_groups) return tree_util.tree_unflatten(treedef, out_flat) +def bind_psum_invariant(leaves, *, axes, axis_index_groups): + if axis_index_groups is not None: + raise NotImplementedError + axes_ = frozenset(axes) + args_ = [] + for x in leaves: + in_vma = core.get_aval(x).vma + args_.append(pvary(x, tuple(pbroadcast_names)) + if (pbroadcast_names := axes_ - in_vma) else x) + return psum_invariant_p.bind(*args_, axes=axes, + axis_index_groups=axis_index_groups) + + def pmean(x, axis_name, *, axis_index_groups=None): """Compute an all-reduce mean on ``x`` over the pmapped axis ``axis_name``. 
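The ``logaddexp2`` helper added to jax/_src/lax/other.py above relies on the usual max-factoring identity to stay finite where the naive evaluation overflows. A quick standalone check of that identity for the real-valued branch, written in plain NumPy as an illustrative sketch (none of these names are part of the change):

    import numpy as np

    x1, x2 = 2000.0, 2001.0                     # np.exp2(2000.0) overflows float64
    naive = np.log2(np.exp2(x1) + np.exp2(x2))  # inf, plus an overflow warning
    # Identity used above: log2(2**x1 + 2**x2)
    #   == max(x1, x2) + log1p(2**-|x1 - x2|) / ln(2)
    amax, delta = max(x1, x2), abs(x1 - x2)
    stable = amax + np.log1p(np.exp2(-delta)) / np.log(2)
    print(naive, stable)                        # inf ~2001.585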
@@ -173,7 +197,7 @@ def pmean(x, axis_name, *, axis_index_groups=None):
   [0. 0.6666667 1.3333334 2. ]
   """
   x = psum(x, axis_name=axis_name, axis_index_groups=axis_index_groups)
-  n = psum(1, axis_name=axis_name, axis_index_groups=axis_index_groups)
+  n = _axis_size(axis_name, axis_index_groups)
   return tree_util.tree_map(lambda v: v / n, x)
 
 def pmax(x, axis_name, *, axis_index_groups=None):
@@ -202,6 +226,7 @@ def pmax(x, axis_name, *, axis_index_groups=None):
   _validate_reduce_axis_index_groups(axis_index_groups)
   leaves, treedef = tree_util.tree_flatten(x)
   axis_index_groups = _canonicalize_axis_index_groups(axis_index_groups)
+  leaves = map(partial(insert_collective_pvary, axis_name), leaves)
   out_flat = pmax_p.bind(*leaves, axes=axis_name,
                          axis_index_groups=axis_index_groups)
   return tree_util.tree_unflatten(treedef, out_flat)
@@ -232,6 +257,7 @@ def pmin(x, axis_name, *, axis_index_groups=None):
   _validate_reduce_axis_index_groups(axis_index_groups)
   leaves, treedef = tree_util.tree_flatten(x)
   axis_index_groups = _canonicalize_axis_index_groups(axis_index_groups)
+  leaves = map(partial(insert_collective_pvary, axis_name), leaves)
   out_flat = pmin_p.bind(*leaves, axes=axis_name,
                          axis_index_groups=axis_index_groups)
   return tree_util.tree_unflatten(treedef, out_flat)
@@ -325,9 +351,84 @@ def ppermute(x, axis_name, perm):
   """
   if not isinstance(axis_name, (list, tuple)):
     axis_name = (axis_name,)
-  return tree_util.tree_map(
-      partial(ppermute_p.bind, axis_name=axis_name,
-              perm=tuple(map(tuple, perm))), x)
+  def bind(leaf):
+    leaf = insert_collective_pvary(axis_name, leaf)
+    return ppermute_p.bind(leaf, axis_name=axis_name, perm=tuple(map(tuple, perm)))
+  return tree_util.tree_map(bind, x)
+
+
+def psend(x, axis_name, perm):
+  """Perform a collective send according to the permutation ``perm``.
+
+  If ``x`` is a pytree then the result is equivalent to mapping this function to
+  each leaf in the tree.
+
+  This function is an analog of the Send HLO.
+
+  Args:
+    x: array(s) with a mapped axis named ``axis_name``.
+    axis_name: hashable Python object used to name a pmapped axis (see the
+      :func:`jax.pmap` documentation for more details).
+    perm: list of pairs of ints, representing ``(source_index,
+      destination_index)`` pairs that encode how the mapped axis named
+      ``axis_name`` should be shuffled. The integer values are treated as
+      indices into the mapped axis ``axis_name``. Any two pairs should not have
+      the same source index or the same destination index. For each index of the
+      axis ``axis_name`` that does not correspond to a destination index in
+      ``perm``, the corresponding values in the result are filled with zeros of
+      the appropriate type. The semantics here are platform-specific, and for
+      GPU they correspond to NCCL send.
+
+  Returns:
+    A compiler token that can be used by precv and lax.optimization_barrier to
+    enforce ordering of collective ops.
+  """
+  axis_name = tuple(axis_name) if isinstance(axis_name, (list, tuple)) else (axis_name,)
+
+  def bind(leaf):
+    leaf = insert_collective_pvary(axis_name, leaf)
+    return psend_p.bind(leaf, axis_name=axis_name, perm=tuple(map(tuple, perm)))
+
+  return tree_util.tree_map(bind, x)
+
+
+def precv(token, out_shape, axis_name, perm):
+  """Perform a collective recv according to the permutation ``perm``.
+
+  This function is an analog of the Recv HLO.
+
+  Args:
+    token: a compiler token, either generated by a matching psend or
+      lax.create_token(). This is used to enforce control dependencies between
+      collectives.
+ out_shape: ShapeDtypeStruct(s) containing the dtype and shape + of the result. + axis_name: hashable Python object used to name a pmapped axis (see the + :func:`jax.pmap` documentation for more details). + perm: list of pairs of ints, representing ``(source_index, + destination_index)`` pairs that encode how the mapped axis named + ``axis_name`` should be shuffled. The integer values are treated as + indices into the mapped axis ``axis_name``. Any two pairs should not have + the same source index or the same destination index. For each index of the + axis ``axis_name`` that does not correspond to a destination index in + ``perm``, the corresponding values in the result are filled with zeros of + the appropriate type. The semantics here are platform-specific, and for + GPU they correspond to NCCL recv. + + Returns: + Array(s) with the same shape as ``out_shape``. + """ + axis_name = tuple(axis_name) if isinstance(axis_name, (list, tuple)) else (axis_name,) + + return precv_p.bind( + token, + out_shape=core.ShapedArray( + out_shape.shape, out_shape.dtype + ), + axis_name=axis_name, + perm=tuple(map(tuple, perm)), + ) + def pshuffle(x, axis_name, perm): """Convenience wrapper of jax.lax.ppermute with alternate permutation encoding @@ -421,14 +522,14 @@ def all_to_all(x, axis_name, split_axis, concat_axis, *, axis_index_groups=None, np.insert(np.delete(x.shape, split_axis), concat_axis, axis_size) where ``axis_size`` is the size of the mapped axis named ``axis_name`` in - the input ``x``, i.e. ``axis_size = lax.psum(1, axis_name)``. + the input ``x``. Otherwise array with shape similar to the input shape, except with split_axis divided by axis size and concat_axis multiplied by axis size. """ axis_index_groups = _canonicalize_axis_index_groups(axis_index_groups) def bind(x, split_axis=split_axis, concat_axis=concat_axis): - group_size = psum(1, axis_name, axis_index_groups=axis_index_groups) + group_size = _axis_size(axis_name, axis_index_groups) if tiled: if x.shape[split_axis] % group_size != 0: raise ValueError(f"The size of all_to_all split_axis ({x.shape[split_axis]}) " @@ -447,6 +548,7 @@ def bind(x, split_axis=split_axis, concat_axis=concat_axis): else: # concat_axis < split_axis x = lax.expand_dims(x, (concat_axis,)) # insert the new axis split_axis += 1 # we have a new axis before split_axis now + x = insert_collective_pvary(axis_name, x) result = all_to_all_p.bind(x, split_axis=split_axis, concat_axis=concat_axis, axis_name=axis_name, axis_index_groups=axis_index_groups, @@ -612,7 +714,7 @@ def ragged_all_to_all( axis_index_groups=axis_index_groups) -def axis_index(axis_name): +def axis_index(axis_name: AxisName) -> Array: """Return the index along the mapped axis ``axis_name``. Args: @@ -628,16 +730,16 @@ def axis_index(axis_name): ... def f(_): ... return lax.axis_index('i') ... - >>> f(np.zeros(4)) + >>> f(jnp.zeros(4)) Array([0, 1, 2, 3], dtype=int32) - >>> f(np.zeros(8)) + >>> f(jnp.zeros(8)) Array([0, 1, 2, 3, 4, 5, 6, 7], dtype=int32) >>> @partial(jax.pmap, axis_name='i') ... @partial(jax.pmap, axis_name='j') ... def f(_): ... return lax.axis_index('i'), lax.axis_index('j') ... 
- >>> x, y = f(np.zeros((4, 2))) + >>> x, y = f(jnp.zeros((4, 2))) >>> print(x) [[0 0] [1 1] @@ -653,12 +755,53 @@ def axis_index(axis_name): return axis_index_p.bind(axis_name=axis_name) else: inner_size = 1 - index = 0 + index = lax.asarray(0) for name in reversed(axis_name): index += axis_index(name) * inner_size - inner_size *= psum(1, name) + inner_size *= axis_size(name) return index + +def axis_size(axis_name: AxisName) -> int: + """Return the size of the mapped axis ``axis_name``. + + Args: + axis_name: hashable Python object used to name the mapped axis. + + Returns: + An integer representing the size. + + For example, with 8 XLA devices available: + + >>> from functools import partial + >>> from jax.sharding import PartitionSpec as P + >>> mesh = jax.make_mesh((8,), 'i') + >>> @partial(jax.shard_map, mesh=mesh, in_specs=P('i'), out_specs=P()) + ... def f(_): + ... return lax.axis_size('i') + ... + >>> f(jnp.zeros(16)) + Array(8, dtype=int32, weak_type=True) + >>> mesh = jax.make_mesh((4, 2), ('i', 'j')) + >>> @partial(jax.shard_map, mesh=mesh, in_specs=P('i', 'j'), out_specs=P()) + ... def f(_): + ... return lax.axis_size(('i', 'j')) + ... + >>> f(jnp.zeros((16, 8))) + Array(8, dtype=int32, weak_type=True) + """ + return _axis_size(axis_name) + + +def _axis_size( + axis_name: AxisName, + axis_index_groups: Sequence[Sequence[int]] | None = None, + /, +) -> int: + axis_index_groups = _canonicalize_axis_index_groups(axis_index_groups) + return psum(1, axis_name, axis_index_groups=axis_index_groups) + + def pgather(src, idx, axes: int | AxisName): """Uses the last positional axis of idx to index into src's axes.""" if not isinstance(axes, (tuple, list)): @@ -666,7 +809,6 @@ def pgather(src, idx, axes: int | AxisName): # TODO: Canonicalize exes! return pgather_p.bind(src, idx, axes=tuple(axes)) - ### parallel primitives def _names_in_param(pname: str, params: core.ParamDict) -> tuple[str]: @@ -800,6 +942,48 @@ def _allreduce_effectful_abstract_eval(*args, axes, axis_index_groups): ] return out_avals, {core.NamedAxisEffect(axis) for axis in named_axes} +def _psum_invariant_abstract_eval(name, *args, axes, axis_index_groups): + if not config._check_vma.value: + return psum_p.abstract_eval( + *args, axes=axes, axis_index_groups=axis_index_groups) + + assert isinstance(axes, tuple) + _check_axis_names(axes) + arg_vma = [a.vma for a in args] + # If intersection between arg_vma and axes is empty, error + if any(not set(axes) & a for a in arg_vma): + raise ValueError( + f"Collective {name} must be applied to a device-varying " + f"type, but got {arg_vma} for collective acting " + f"over axis name {axes}. 
Please open an issue at " + "https://github.com/jax-ml/jax/issues, and as a temporary " + "workaround pass the check_vma=False argument to `jax.shard_map`") + + named_axes = tuple(axis for axis in axes if not isinstance(axis, int)) + pos_axes = tuple(axis for axis in axes if isinstance(axis, int)) + if axis_index_groups is not None: + if len(pos_axes) != 0: + raise ValueError( + "axis_index_groups can only be used with reductions over " + f"named axes, but got: {axes}") + core.check_avals_context_mesh(args, 'all_reduce') + out_avals = [ + core.ShapedArray( + lax._reduce_op_shape_rule(arg, axes=pos_axes), arg.dtype, + sharding=lax._reduce_op_sharding_rule(arg, axes=pos_axes), + vma=frozenset(a for a in arg.vma if a not in named_axes)) + for arg in args + ] + return out_avals, {core.NamedAxisEffect(axis) for axis in named_axes} + +# TODO(yashkatariya): Replace this with _psum_invariant_abstract_eval +def _pmin_pmax_abstract_eval(name, *args, axes, axis_index_groups): + if not config._check_vma.value: + return _allreduce_effectful_abstract_eval( + *args, axes=axes, axis_index_groups=axis_index_groups) + return _psum_invariant_abstract_eval( + name, *args, axes=axes, axis_index_groups=axis_index_groups) + def _check_axis_names(axes): named_axes = tuple(axis for axis in axes if not isinstance(axis, int)) axis_env = core.get_axis_env() @@ -899,7 +1083,7 @@ def broadcast_positional(ct, arg): pmax_p = core.Primitive('pmax') pmax_p.multiple_results = True pmax_p.def_impl(partial(_allreduce_impl, pmax_p, lax.reduce_max)) -pmax_p.def_effectful_abstract_eval(_allreduce_effectful_abstract_eval) +pmax_p.def_effectful_abstract_eval(partial(_pmin_pmax_abstract_eval, 'pmax')) mlir.register_lowering( pmax_p, partial(_allreduce_lowering, lax.max_p, lax.reduce_max)) batching.fancy_primitive_batchers[pmax_p] = \ @@ -910,7 +1094,7 @@ def broadcast_positional(ct, arg): pmin_p = core.Primitive('pmin') pmin_p.multiple_results = True pmin_p.def_impl(partial(_allreduce_impl, pmin_p, lax.reduce_min)) -pmin_p.def_effectful_abstract_eval(_allreduce_effectful_abstract_eval) +pmin_p.def_effectful_abstract_eval(partial(_pmin_pmax_abstract_eval, 'pmin')) mlir.register_lowering( pmin_p, partial(_allreduce_lowering, lax.min_p, lax.reduce_min)) batching.fancy_primitive_batchers[pmin_p] = \ @@ -918,12 +1102,12 @@ def broadcast_positional(ct, arg): batching.skippable_batchers[pmin_p] = partial(_names_in_param, 'axes') -def _ppermute_lowering(ctx, x, *, axis_name, perm): +def _pcollectives_lowering_common(ctx, *, axis_name, perm, op_name): replica_groups = _replica_groups(ctx.module_context.axis_env, axis_name, None) group_size = len(replica_groups[0]) srcs, dsts = unzip2((src % group_size, dst % group_size) for src, dst in perm) if not (len(srcs) == len(set(srcs)) and len(dsts) == len(set(dsts))): - msg = "ppermute sources and destinations must be unique, got {}." + msg = f"{op_name} sources and destinations must be unique, got {{}}." 
raise ValueError(msg.format(perm)) full_perm = np.zeros((len(replica_groups), len(perm), 2), np.int64) @@ -945,10 +1129,17 @@ def _ppermute_lowering(ctx, x, *, axis_name, perm): channel_handle=hlo.ChannelHandle.get(channel, mlir.DEVICE_TO_DEVICE_TYPE)) else: other_args = {} + return full_perm, other_args + +def _ppermute_lowering(ctx, x, *, axis_name, perm): + full_perm, other_args = _pcollectives_lowering_common( + ctx, axis_name=axis_name, perm=perm, op_name="ppermute" + ) return hlo.CollectivePermuteOp( x, mlir.dense_int_elements(full_perm), **other_args).results + def _ppermute_transpose_rule(t, x, perm, axis_name): srcs, dsts = unzip2(perm) inverse_perm = list(zip(dsts, srcs)) @@ -975,6 +1166,7 @@ def _ppermute_batcher(axis_data, vals_in, dims_in, axis_name, perm): def _raise_to_shaped_abstract_eval(x, *, axis_name, **params): _check_axis_names(axis_name) + collective_vma_rule('ppermute', axis_name, x) return x ppermute_p = core.Primitive('ppermute') @@ -984,6 +1176,97 @@ def _raise_to_shaped_abstract_eval(x, *, axis_name, **params): batching.fancy_primitive_batchers[ppermute_p] = _ppermute_batcher batching.skippable_batchers[ppermute_p] = partial(_names_in_param, 'axis_name') + +def _psend_lowering_gpu(ctx, x, *, axis_name, perm): + if ("cuda" not in ctx.module_context.platforms): + raise NotImplementedError("psend is currently only implemented on GPUs") + + full_perm, other_args = _pcollectives_lowering_common( + ctx, axis_name=axis_name, perm=perm, op_name="psend" + ) + token = hlo.create_token() + send_op = hlo.SendOp( + [x], + token, + source_target_pairs=mlir.dense_int_elements(full_perm), + **other_args, + ) + axis_ctx = ctx.module_context.axis_context + if not isinstance(axis_ctx, SPMDAxisContext): + raise NotImplementedError("psend currently only supports manual sharding") + + sharding = xc.OpSharding() + sharding.type = xc.OpSharding.Type.MANUAL + mlir.set_sharding(send_op, sharding) + return [send_op.results] + + +mlir.lowerable_effects.add_type(core.SingleSideCollectiveEffect) + + +def _psend_abstract_eval(x, *, axis_name, **params): + _check_axis_names(axis_name) + return abstract_token, { + *map(core.NamedAxisEffect, axis_name), + core.SingleSideCollectiveEffect(), + } + + +psend_p = core.Primitive("psend") +psend_p.def_impl(partial(dispatch.apply_primitive, psend_p)) +psend_p.def_effectful_abstract_eval(_psend_abstract_eval) +mlir.register_lowering(psend_p, _psend_lowering_gpu, platform="gpu") + +def _psend_lowering(ctx, x, *, axis_name, perm): + raise NotImplementedError("psend is currently only implemented on GPU") +mlir.register_lowering(psend_p, _psend_lowering) + +batching.fancy_primitive_batchers[psend_p] = _ppermute_batcher +batching.skippable_batchers[psend_p] = partial(_names_in_param, "axis_name") + + +def _precv_lowering_gpu(ctx, token, *, out_shape, axis_name, perm): + full_perm, other_args = _pcollectives_lowering_common( + ctx, axis_name=axis_name, perm=perm, op_name="precv" + ) + recv_op = hlo.RecvOp( + [mlir.aval_to_ir_type(out_shape), token.type], + token, + source_target_pairs=mlir.dense_int_elements(full_perm), + **other_args, + ) + axis_ctx = ctx.module_context.axis_context + if not isinstance(axis_ctx, SPMDAxisContext): + raise NotImplementedError("precv currently only supports manual sharding") + + sharding = xc.OpSharding() + sharding.type = xc.OpSharding.Type.MANUAL + mlir.set_sharding(recv_op, sharding) + + # recv_op should return an array of [RankedTensorType, StableHlo.token]; we + # only need the tensor. 
+ results = recv_op.results + return [results[0]] + + +def _precv_abstract_eval( + token, *, out_shape, axis_name, **params +): + return out_shape, {*map(core.NamedAxisEffect, axis_name), + core.SingleSideCollectiveEffect()} + +precv_p = core.Primitive("precv") +precv_p.multiple_results = False +precv_p.def_effectful_abstract_eval(_precv_abstract_eval) +mlir.register_lowering(precv_p, _precv_lowering_gpu, platform='gpu') + +def _precv_lowering(ctx, token, *, out_shape, axis_name, perm): + raise NotImplementedError("precv is currently only implemented on GPU") +mlir.register_lowering(precv_p, _precv_lowering) + +batching.fancy_primitive_batchers[precv_p] = _ppermute_batcher +batching.skippable_batchers[precv_p] = partial(_names_in_param, "axis_name") + def _pbroadcast_transpose_rule(t, x, source, axis_name): is_source = axis_index(axis_name) == source tsum = psum(t, axis_name) @@ -1012,14 +1295,27 @@ def _pbroadcast_lowering(ctx, x, *, axis_name, source): def source_to_front(group): return [group[source]] + list(group[:source]) + list(group[source + 1:]) replica_groups = [source_to_front(group) for group in replica_groups] - channel = ctx.module_context.new_channel() + is_spmd = isinstance( + ctx.module_context.axis_context, + (SPMDAxisContext, ShardingContext), + ) + if is_spmd: + # We want to emit the collective-broadcast with global device IDs and a unique + # channel ID, as otherwise it interprets the devices as replicas instead + # of partitions - and XLA is configured with only a single replica. + channel = ctx.module_context.new_channel() + channel_handle = hlo.ChannelHandle.get(channel, mlir.DEVICE_TO_DEVICE_TYPE) + other_args = dict(channel_handle=channel_handle) + else: + other_args = {} return hlo.CollectiveBroadcastOp( - x, replica_groups=_replica_groups_hlo(replica_groups)).results + x, replica_groups=_replica_groups_hlo(replica_groups), **other_args + ).results pbroadcast_p = core.Primitive('pbroadcast') pbroadcast_p.def_abstract_eval(_raise_to_shaped_abstract_eval) ad.deflinear2(pbroadcast_p, _pbroadcast_transpose_rule) -mlir.register_lowering(pbroadcast_p, _pbroadcast_lowering) +mlir.register_lowering(pbroadcast_p, _pbroadcast_lowering, platform='gpu') batching.fancy_primitive_batchers[pbroadcast_p] = _pbroadcast_batcher batching.skippable_batchers[pbroadcast_p] = partial(_names_in_param, 'axis_name') @@ -1109,15 +1405,15 @@ def _all_to_all_batcher(vals_in, dims_in, *, axis_name, split_axis, concat_axis, def _all_to_all_batched_collective(axis_data, vals_in, dims_in, axis_name, split_axis, concat_axis, axis_index_groups, tiled): - axis_size, frame_name = axis_data.size, axis_data.name if axis_index_groups is not None: raise NotImplementedError("Please open a feature request!") + axis_size, frame_name = axis_data.size, axis_data.name if isinstance(axis_name, (list, tuple)): axes_names = axis_name else: axes_names = [axis_name] - if axis_data.name not in axes_names: + if frame_name not in axes_names: return _all_to_all_batcher( vals_in, dims_in, axis_name=axis_name, split_axis=split_axis, concat_axis=concat_axis, axis_index_groups=axis_index_groups, tiled=tiled) @@ -1157,6 +1453,7 @@ def _all_to_all_batched_collective(axis_data, vals_in, dims_in, axis_index_groups=axis_index_groups, tiled=tiled) # Split out the local part into axis new_d (NOTE: d is already in axis 1) + assert d == 1 x = _splitaxis(split_axis, axis_size, x) new_d = split_axis concat_axis += (split_axis <= concat_axis) # Offset the existing axes by the new batch axis @@ -1184,11 +1481,16 @@ def 
_all_to_all_effectful_abstract_eval( axis_name = (axis_name,) _check_axis_names(axis_name) shape = list(input_aval.shape) - axis_size = psum(1, axis_name) if axis_index_groups is None else len(axis_index_groups[0]) + axis_size = ( + _axis_size(axis_name) + if axis_index_groups is None + else len(axis_index_groups[0]) + ) assert shape[split_axis] % axis_size == 0, (shape[split_axis], axis_size) shape[split_axis] //= axis_size shape[concat_axis] *= axis_size - out_aval = input_aval.update(shape=tuple(shape), weak_type=False) + vma = collective_vma_rule('all_to_all', axis_name, input_aval) + out_aval = input_aval.update(shape=tuple(shape), weak_type=False, vma=vma) effects = {*map(core.NamedAxisEffect, axis_name)} return out_aval, effects @@ -1298,20 +1600,59 @@ def _ragged_all_to_all_transpose( operand_t = ragged_all_to_all_p.bind( t, zero, output_offsets_, recv_sizes, input_offsets_, send_sizes, axis_name=axis_name, axis_index_groups=axis_index_groups) - mask = jax.numpy.cumsum( - jax.numpy.zeros(t.shape[0], dtype='int32').at[output_offsets_].set(1)\ + mask = control_flow.cumsum( + lax.full(t.shape[0], 0, dtype='int32').at[output_offsets_].set(1) .at[output_offsets_ + recv_sizes].add(-1)) - mask = jax.numpy.expand_dims(mask, (*range(1, t.ndim),)) - output_t = jax.numpy.where(mask, 0, t) + mask = lax.expand_dims(mask, (*range(1, t.ndim),)) + output_t = lax.select(mask, lax._zeros(t), t) return [operand_t, output_t] + [None] * 4 +def _ragged_all_to_all_batched_collective(axis_data, vals_in, dims_in, + axis_name, axis_index_groups): + del axis_data + if axis_index_groups: + raise NotImplementedError("Please open a feature request!") + + operand, output, input_offsets, send_sizes, output_offsets, recv_sizes = vals_in + operand_dim, output_dim, input_offsets_dim, send_sizes_dim, output_offsets_dim, recv_sizes_dim = dims_in + if not (operand.shape[operand_dim] == output.shape[output_dim] == input_offsets.shape[input_offsets_dim] == send_sizes.shape[send_sizes_dim] == output_offsets.shape[output_offsets_dim] == recv_sizes.shape[recv_sizes_dim]): + raise ValueError("all operands must have the same batch sizes") + + sliced_results = [] + for i in range(operand.shape[operand_dim]): + sliced_operand = slicing.slice_in_dim(operand, start_index=i, limit_index=i+1, axis=operand_dim).flatten() + sliced_output = slicing.slice_in_dim(output, start_index=i, limit_index=i+1, axis=output_dim) + sliced_output_shape = sliced_output.shape + sliced_output = sliced_output.flatten() + sliced_input_offsets = slicing.slice_in_dim(input_offsets, start_index=i, limit_index=i+1, axis=input_offsets_dim).flatten() + sliced_send_sizes = slicing.slice_in_dim(send_sizes, start_index=i, limit_index=i+1, axis=send_sizes_dim).flatten() + sliced_output_offsets = slicing.slice_in_dim(output_offsets, start_index=i, limit_index=i+1, axis=output_offsets_dim).flatten() + sliced_recv_sizes = slicing.slice_in_dim(recv_sizes, start_index=i, limit_index=i+1, axis=recv_sizes_dim).flatten() + sliced_result = ragged_all_to_all(sliced_operand, sliced_output, sliced_input_offsets, sliced_send_sizes, sliced_output_offsets, sliced_recv_sizes, axis_name=axis_name, axis_index_groups=axis_index_groups) + sliced_result = lax.expand_dims(sliced_result.reshape(sliced_output_shape), dimensions=(output_dim,)) + sliced_results.append(sliced_result) + + concat_result = lax.concatenate(sliced_results, dimension=output_dim) + return concat_result, operand_dim + ragged_all_to_all_p = core.Primitive('ragged_all_to_all') 
ragged_all_to_all_p.def_effectful_abstract_eval(_ragged_all_to_all_effectful_abstract_eval) ad.primitive_jvps[ragged_all_to_all_p] = _ragged_all_to_all_jvp ad.primitive_transposes[ragged_all_to_all_p] = _ragged_all_to_all_transpose mlir.register_lowering(ragged_all_to_all_p, _ragged_all_to_all_lowering) +batching.fancy_primitive_batchers[ragged_all_to_all_p] = _ragged_all_to_all_batched_collective batching.skippable_batchers[ragged_all_to_all_p] = partial(_names_in_param, 'axis_name') +def insert_collective_pvary(axis_name, x): + if not config._check_vma.value: + return x + + axis_name = (axis_name,) if not isinstance(axis_name, tuple) else axis_name + aval = core.get_aval(x) + names_union = set(axis_name) | aval.vma + x = pvary(x, tuple(n for n in names_union if n not in aval.vma)) + return x + def all_gather(x, axis_name, *, axis_index_groups=None, axis=0, tiled=False): """Gather values of x across all replicas. @@ -1379,14 +1720,15 @@ def all_gather(x, axis_name, *, axis_index_groups=None, axis=0, tiled=False): if not isinstance(axis_name, tuple): axis_name = axis_name, axis_index_groups = _canonicalize_axis_index_groups(axis_index_groups) - axis_size = psum(1, axis_name, axis_index_groups=axis_index_groups) + axis_size = _axis_size(axis_name, axis_index_groups) def bind(leaf): + leaf = insert_collective_pvary(axis_name, leaf) return all_gather_p.bind( leaf, all_gather_dimension=canonicalize_axis( axis, np.ndim(leaf) if tiled else np.ndim(leaf) + 1), axis_name=axis_name, axis_index_groups=axis_index_groups, - axis_size=int(axis_size), tiled=tiled) + axis_size=axis_size, tiled=tiled) return tree_util.tree_map(bind, x) def _all_gather_impl(x, *, all_gather_dimension, axis_name, axis_index_groups, axis_size, tiled): @@ -1433,6 +1775,19 @@ def _all_gather_lowering(ctx, x, *, all_gather_dimension, axis_name, **other_args).results +def collective_vma_rule(prim_name, axis_name, x_aval): + if not config._check_vma.value: + return frozenset() + axis_name = (axis_name,) if not isinstance(axis_name, tuple) else axis_name + if any(a not in x_aval.vma for a in axis_name): + raise ValueError( + f"Collective {prim_name} must be applied to a device-varying " + f" type, but got {x_aval.vma} for collective acting " + f"over axis name {axis_name}. Please open an issue at " + "https://github.com/jax-ml/jax/issues and as a temporary " + "workaround pass the check_vma=False argument to `jax.shard_map`") + return x_aval.vma + def _all_gather_effectful_abstract_eval( x_aval, *, all_gather_dimension, axis_name, axis_index_groups, axis_size, tiled ): @@ -1444,39 +1799,45 @@ def _all_gather_effectful_abstract_eval( new_shape[all_gather_dimension] *= axis_size else: new_shape.insert(all_gather_dimension, axis_size) - return x_aval.update(shape=new_shape), {*map(core.NamedAxisEffect, axis_name)} + out_vma = collective_vma_rule('all_gather', axis_name, x_aval) + return (x_aval.update(shape=new_shape, vma=out_vma), + {*map(core.NamedAxisEffect, axis_name)}) -def _all_gather_transpose_rule(cts, x, *, all_gather_dimension, axis_name, axis_index_groups, axis_size, tiled): +def _all_gather_transpose_rule(cts, x, *, all_gather_dimension, axis_name, + axis_index_groups, axis_size, tiled): return (psum_scatter(cts, axis_name=axis_name, scatter_dimension=all_gather_dimension, axis_index_groups=axis_index_groups, tiled=tiled),) - # TODO(sharadmv,apaszke): re-enable this when we can properly detect replication. 
- # return (lax.dynamic_index_in_dim(cts, idx, axis=all_gather_dimension, keepdims=False) * axis_size,) -def _all_gather_batcher(vals_in, dims_in, *, all_gather_dimension, axis_name, axis_index_groups, axis_size, tiled): +def _all_gather_batcher(prim, vals_in, dims_in, *, all_gather_dimension, axis_name, + axis_index_groups, axis_size, tiled): (x,), (d,) = vals_in, dims_in if d is not batching.not_mapped: if d <= all_gather_dimension: all_gather_dimension += 1 elif not tiled: # Tiled all-gather doesn't modify the set of dimensions d += 1 - result = all_gather_p.bind( - x, - all_gather_dimension=all_gather_dimension, - axis_name=axis_name, - axis_index_groups=axis_index_groups, - axis_size=axis_size, - tiled=tiled) - return result, d + if prim is all_gather_p: + result = all_gather_p.bind( + x, all_gather_dimension=all_gather_dimension, axis_name=axis_name, + axis_index_groups=axis_index_groups, axis_size=axis_size, + tiled=tiled) + return result, d + else: + assert prim is all_gather_invariant_p + result = all_gather_invariant_p.bind( + x, all_gather_dimension=all_gather_dimension, axis_name=axis_name, + axis_size=axis_size, tiled=tiled) + return result, d -def _all_gather_batched_collective(axis_data, vals_in, dims_in, +def _all_gather_batched_collective(prim, axis_data, vals_in, dims_in, all_gather_dimension, axis_name, axis_index_groups, axis_size, tiled): frame_size, frame_name = axis_data.size, axis_data.name if frame_name not in axis_name: return _all_gather_batcher( - vals_in, dims_in, all_gather_dimension=all_gather_dimension, + prim, vals_in, dims_in, all_gather_dimension=all_gather_dimension, axis_name=axis_name, axis_index_groups=axis_index_groups, axis_size=axis_size, tiled=tiled) if axis_index_groups is not None: @@ -1508,10 +1869,100 @@ def _all_gather_batched_collective(axis_data, vals_in, dims_in, partial(_all_gather_lowering, platform=p), platform=p) ad.deflinear2(all_gather_p, _all_gather_transpose_rule) -batching.fancy_primitive_batchers[all_gather_p] = _all_gather_batched_collective +batching.fancy_primitive_batchers[all_gather_p] = partial( + _all_gather_batched_collective, all_gather_p) batching.skippable_batchers[all_gather_p] = partial(_names_in_param, 'axis_name') +def all_gather_invariant(x, axis_name, *, axis: int = 0, tiled: bool = False): + """Gather values of x across all replicas. + + If ``x`` is a pytree then the result is equivalent to mapping this function to + each leaf in the tree. + + all_gather_invariant differs from all_gather in the following ways: + + * all_gather_invariant is Varying -> Invariant. + For example: `out: f32[8] = all_gather_invariant(inp: f32[4]{V: x}, 'x')` + where the size of mesh axis `x` is 2. + While all_gather is Varying -> Varying. + + * all_gather_invariant transposes to dynamic_slice which is + Invariant -> Varying. While all_gather transposes to reduce_scatter + which is Varying -> Varying. 
+ """ + if not isinstance(axis_name, tuple): + axis_name = axis_name, + axis_size = _axis_size(axis_name, None) + axes_ = frozenset(axis_name) + def bind(leaf): + in_vma = core.typeof(leaf).vma + if vary_names := axes_ - in_vma: + leaf = pvary(leaf, tuple(vary_names)) + return all_gather_invariant_p.bind( + leaf, + all_gather_dimension=canonicalize_axis(axis, np.ndim(leaf) if tiled else + np.ndim(leaf) + 1), + axis_name=axis_name, axis_size=axis_size, tiled=tiled) + return tree_util.tree_map(bind, x) + +all_gather_invariant_p = core.Primitive('all_gather_invariant') + +def _all_gather_invariant_effectful_abstract_eval( + x_aval, *, all_gather_dimension, axis_name, axis_size, tiled +): + _check_axis_names(axis_name) + new_shape = list(x_aval.shape) + if tiled: + new_shape[all_gather_dimension] *= axis_size + else: + new_shape.insert(all_gather_dimension, axis_size) + out_vma = frozenset(v for v in x_aval.vma if v not in axis_name) + return (x_aval.update(shape=new_shape, vma=out_vma), + {*map(core.NamedAxisEffect, axis_name)}) + +all_gather_invariant_p.def_effectful_abstract_eval( + _all_gather_invariant_effectful_abstract_eval) + +def _all_gather_invariant_impl(x, *, all_gather_dimension, axis_name, axis_size, + tiled): + raise NotImplementedError +all_gather_invariant_p.def_impl(_all_gather_invariant_impl) + + +def _all_gather_invariant_lowering( + ctx, x, *, all_gather_dimension, axis_name, axis_size, tiled, platform=None): + return _all_gather_lowering( + ctx, x, all_gather_dimension=all_gather_dimension, axis_name=axis_name, + axis_index_groups=None, axis_size=axis_size, tiled=tiled, + platform=platform) + +mlir.register_lowering(all_gather_invariant_p, _all_gather_invariant_lowering) +for p in ("cuda", "rocm", "tpu"): + mlir.register_lowering(all_gather_invariant_p, + partial(_all_gather_invariant_lowering, platform=p), + platform=p) + +def _all_gather_invariant_transpose_rule( + cts, x, *, all_gather_dimension, axis_name, axis_size, tiled): + slice_size, rem = divmod(cts.shape[all_gather_dimension], axis_size) + assert not rem + idx = axis_index(axis_name) * slice_size + out = slicing.dynamic_slice_in_dim( + cts, idx, slice_size=slice_size, axis=all_gather_dimension) + return (out,) if tiled else (lax.squeeze(out, [all_gather_dimension]),) +ad.deflinear2(all_gather_invariant_p, _all_gather_invariant_transpose_rule) + +def _all_gather_invariant_batched_collective( + axis_data, vals_in, dims_in, all_gather_dimension, axis_name, axis_size, + tiled): + return _all_gather_batched_collective( + all_gather_invariant_p, axis_data, vals_in, dims_in, all_gather_dimension, + axis_name, None, axis_size, tiled) +batching.fancy_primitive_batchers[all_gather_invariant_p] = _all_gather_invariant_batched_collective +batching.skippable_batchers[all_gather_invariant_p] = partial(_names_in_param, 'axis_name') + + def _reduce_scatter_lowering( prim, ctx, x, *, scatter_dimension, axis_name, @@ -1581,7 +2032,9 @@ def _reduce_scatter_effectful_abstract_eval( f"{scatter_dim_input_size} must match shard count " f"{axis_size}") del new_shape[scatter_dimension] - return x_aval.update(shape=new_shape), {*map(core.NamedAxisEffect, axis_name)} + vma = collective_vma_rule('reduce_scatter', axis_name, x_aval) + return (x_aval.update(shape=new_shape, vma=vma), + {*map(core.NamedAxisEffect, axis_name)}) def _reduce_scatter_transpose_rule(cts, x, *, axis_name, scatter_dimension, @@ -1723,19 +2176,19 @@ def psum_scatter(x, axis_name, *, scatter_dimension=0, axis_index_groups=None, """ if not isinstance(axis_name, 
tuple): axis_name = axis_name, - axis_size = psum(1, axis_name, axis_index_groups=axis_index_groups) + axis_size = _axis_size(axis_name, axis_index_groups) axis_index_groups = _canonicalize_axis_index_groups(axis_index_groups) - bind = partial( - reduce_scatter_p.bind, - axis_name=axis_name, - scatter_dimension=scatter_dimension, - axis_index_groups=axis_index_groups, - axis_size=axis_size, - tiled=tiled) + def bind(leaf): + leaf = insert_collective_pvary(axis_name, leaf) + return reduce_scatter_p.bind( + leaf, axis_name=axis_name, scatter_dimension=scatter_dimension, + axis_index_groups=axis_index_groups, axis_size=axis_size, tiled=tiled) return tree_util.tree_map(bind, x) def _build_axis_index_lowering_hlo(ctx, axis_name, axis_env): + from jax._src.shard_map import shard_map # pytype: disable=import-error + if isinstance(axis_name, tuple): assert axis_name, 'empty axis name' if len(axis_name) > 1: @@ -1753,12 +2206,11 @@ def _build_axis_index_lowering_hlo(ctx, axis_name, axis_env): axis_context.manual_axes != frozenset(axis_context.mesh.axis_names)): if axis_env.sizes[axis_pos] == 1: return hlo.constant(ir.DenseElementsAttr.get(np.asarray(0, dtype=np.int32))) - from jax.experimental.shard_map import shard_map def f(): return axis_index_p.bind(axis_name=axis_name) return mlir.lower_fun( - lambda: [shard_map(f, axis_context.mesh, check_rep=False, - in_specs=(), out_specs=P())()])(ctx)[0] + lambda: [shard_map(f, check_vma=False, in_specs=(), + out_specs=P())()])(ctx)[0] nreplicas = axis_env.nreps // math.prod(axis_env.sizes) div = mlir.ir_constant( @@ -1781,8 +2233,14 @@ def _axis_index_lowering(ctx, *, axis_name): ctx.module_context.axis_env)] def _axis_index_effectful_abstract_eval(*, axis_name): - _check_axis_names([axis_name]) - return ShapedArray((), np.int32), {core.NamedAxisEffect(axis_name)} + effect = {core.NamedAxisEffect(axis_name)} + axis_name = (axis_name,) if not isinstance(axis_name, tuple) else axis_name + _check_axis_names(axis_name) + mesh = get_abstract_mesh() + sharding = NamedSharding(mesh, P()) + vma = ((frozenset(axis_name) if mesh._any_axis_manual else frozenset()) + if config._check_vma.value else frozenset()) + return ShapedArray((), np.int32, sharding=sharding, vma=vma), effect def _axis_index_batcher(axis_data, vals_in, dims_in, *, axis_name): return lax.iota(np.int32, axis_data.size), 0 @@ -1856,3 +2314,38 @@ def _pgather_collective_batcher(axis_size, frame_name, _, vals_in, dims_in, *, a # TODO: Transpose? That requires adding pscatter... 
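The composite form of ``axis_index`` above flattens several named axes in row-major order, i.e. ``axis_index(('i', 'j')) == axis_index('i') * axis_size('j') + axis_index('j')``, with the new ``axis_size`` helper supplying the strides. A small sanity-check sketch, assuming at least four available devices and mirroring the ``jax.shard_map``/``jax.make_mesh`` usage from the doctests in this file (not part of the change itself):

    from functools import partial
    import jax
    import jax.numpy as jnp
    from jax.sharding import PartitionSpec as P

    mesh = jax.make_mesh((2, 2), ('i', 'j'))

    @partial(jax.shard_map, mesh=mesh, in_specs=P('i', 'j'), out_specs=P('i', 'j'))
    def flat_index(x):
      # Each (1, 1) shard is offset by its flattened ('i', 'j') position.
      return x + jax.lax.axis_index(('i', 'j'))

    print(flat_index(jnp.zeros((2, 2), jnp.int32)))  # [[0 1]
                                                     #  [2 3]]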
batching.fancy_primitive_batchers[pgather_p] = _pgather_collective_batcher batching.skippable_batchers[pgather_p] = partial(_names_in_param, 'axes') + +psum_invariant_p = core.Primitive('psum_invariant') +psum_invariant_p.multiple_results = True +psum_invariant_p.def_impl(psum_p.impl) +psum_invariant_p.def_effectful_abstract_eval( + partial(_psum_invariant_abstract_eval, psum_invariant_p.name)) +mlir.register_lowering(psum_invariant_p, mlir._lowerings[psum_p]) +batching.fancy_primitive_batchers[psum_invariant_p] = partial( + _batched_reduction_collective, psum_invariant_p, + lambda v, axis_size: axis_size * v) +batching.skippable_batchers[psum_invariant_p] = partial(_names_in_param, 'axes') + +def _psum_invariant_transpose_rule(cts, *args, axes, axis_index_groups): + def f(ct, arg): + assert ad.is_undefined_primal(arg) + return ad.Zero(arg.aval) if type(ct) is ad.Zero else ct + cts = map(f, cts, args) + nonzero_out_cts, treedef = tree_util.tree_flatten(cts) + nonzero_in_cts = core.pvary_p.bind(*nonzero_out_cts, axes=axes, + axis_index_groups=axis_index_groups) + return tree_util.tree_unflatten(treedef, nonzero_in_cts) +ad.deflinear2(psum_invariant_p, _psum_invariant_transpose_rule) + +########################### pvary ################################## + +def _pvary_transpose_rule(cts, *args, axes, axis_index_groups): + def f(ct, arg): + assert ad.is_undefined_primal(arg) + return ad.Zero(arg.aval) if type(ct) is ad.Zero else ct + cts = map(f, cts, args) + nonzero_out_cts, treedef = tree_util.tree_flatten(cts) + nonzero_in_cts = psum_invariant_p.bind(*nonzero_out_cts, axes=axes, + axis_index_groups=axis_index_groups) + return tree_util.tree_unflatten(treedef, nonzero_in_cts) +ad.deflinear2(core.pvary_p, _pvary_transpose_rule) diff --git a/jax/_src/lax/slicing.py b/jax/_src/lax/slicing.py index c26de99c7374..d70110c7301a 100644 --- a/jax/_src/lax/slicing.py +++ b/jax/_src/lax/slicing.py @@ -173,6 +173,8 @@ def dynamic_slice( else: dynamic_sizes = [] static_sizes = core.canonicalize_shape(slice_sizes) # type: ignore + operand, *start_indices = core.standard_insert_pvary( + operand, *start_indices) return dynamic_slice_p.bind(operand, *start_indices, *dynamic_sizes, slice_sizes=tuple(static_sizes)) @@ -234,6 +236,8 @@ def dynamic_update_slice( """ start_indices = _dynamic_slice_indices( operand, start_indices, allow_negative_indices) + operand, update, *start_indices = core.standard_insert_pvary( + operand, update, *start_indices) return dynamic_update_slice_p.bind(operand, update, *start_indices) @@ -303,7 +307,7 @@ class GatherScatterMode(enum.Enum): ONE_HOT = enum.auto() @staticmethod - def from_any(s: str | GatherScatterMode | None): + def from_any(s: str | GatherScatterMode | None) -> GatherScatterMode: if isinstance(s, GatherScatterMode): return s if s == "clip": @@ -416,6 +420,7 @@ def gather(operand: ArrayLike, start_indices: ArrayLike, raise ValueError(f"Unsupported dtype for gather fill_value {dtype}") else: fill_value = None + operand, start_indices = core.standard_insert_pvary(operand, start_indices) return gather_p.bind( operand, start_indices, dimension_numbers=dimension_numbers, slice_sizes=core.canonicalize_shape(slice_sizes), @@ -505,6 +510,8 @@ def scatter_add( """ jaxpr, consts = lax._reduction_jaxpr(lax.add, core.get_aval(lax._const(operand, 0))) + operand, scatter_indices, updates = core.standard_insert_pvary( + operand, scatter_indices, updates) return scatter_add_p.bind( operand, scatter_indices, updates, update_jaxpr=jaxpr, update_consts=consts, 
dimension_numbers=dimension_numbers, @@ -559,6 +566,8 @@ def scatter_sub( jaxpr, consts = lax._reduction_jaxpr( lax.sub, core.get_aval(lax._const(operand, 0)) ) + operand, scatter_indices, updates = core.standard_insert_pvary( + operand, scatter_indices, updates) return scatter_sub_p.bind( operand, scatter_indices, @@ -613,6 +622,8 @@ def scatter_mul( """ jaxpr, consts = lax._reduction_jaxpr(lax.mul, core.get_aval(lax._const(operand, 1))) + operand, scatter_indices, updates = core.standard_insert_pvary( + operand, scatter_indices, updates) return scatter_mul_p.bind( operand, scatter_indices, updates, update_jaxpr=jaxpr, update_consts=consts, dimension_numbers=dimension_numbers, @@ -660,6 +671,8 @@ def scatter_min( """ jaxpr, consts = lax._reduction_jaxpr(lax.min, core.get_aval(lax._const(operand, 0))) + operand, scatter_indices, updates = core.standard_insert_pvary( + operand, scatter_indices, updates) return scatter_min_p.bind( operand, scatter_indices, updates, update_jaxpr=jaxpr, update_consts=consts, dimension_numbers=dimension_numbers, @@ -707,6 +720,8 @@ def scatter_max( """ jaxpr, consts = lax._reduction_jaxpr(lax.max, core.get_aval(lax._const(operand, 0))) + operand, scatter_indices, updates = core.standard_insert_pvary( + operand, scatter_indices, updates) return scatter_max_p.bind( operand, scatter_indices, updates, update_jaxpr=jaxpr, update_consts=consts, dimension_numbers=dimension_numbers, @@ -771,6 +786,8 @@ def scatter_apply( pass jaxpr, consts = lax._reduction_jaxpr(_apply, core.get_aval(lax._zero(operand))) # TODO: implement this via its own primitive so we can define appropriate autodiff rules. + operand, scatter_indices, unused = core.standard_insert_pvary( + operand, scatter_indices, unused) return scatter_p.bind( operand, scatter_indices, unused, update_jaxpr=jaxpr, update_consts=consts, dimension_numbers=dimension_numbers, @@ -854,6 +871,8 @@ def scatter( ... mode=lax.GatherScatterMode.PROMISE_IN_BOUNDS) Array([0., 2., 3., 0., 4.], dtype=float32) """ + operand, scatter_indices, updates = core.standard_insert_pvary( + operand, scatter_indices, updates) return scatter_p.bind( operand, scatter_indices, updates, update_jaxpr=None, update_consts=(), dimension_numbers=dimension_numbers, @@ -1333,10 +1352,11 @@ def _get_sharding_for_varying_out_shape(out_shape, operand, name): operand.shape, out_shape, operand.sharding.spec): if (op_sh != out_sh and op_spec is not None and out_sh % _get_sub_spec_size(mesh, op_spec) != 0): - raise NotImplementedError( - f"{name} on sharded dims where out dim ({out_sh}) is not divisble by" + raise core.ShardingTypeError( + f"{name} on sharded dims where out dim ({out_sh}) is not divisible by" f" mesh axes ({_get_sub_spec_size(mesh, op_spec)}) with spec" - f" ({op_spec}) is not implemented.") + f" ({op_spec}) is not implemented." + ) # TODO(yashkatariya): Returning operand.sharding as is may or may not move # data. So think about how to avoid it which might include creating a new # mesh? 
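The scatter wrappers above now insert pvary on their operands before binding. As a point of reference, a minimal hedged example of the public lax.scatter_add API they wrap; the shapes and dimension numbers are chosen purely for illustration.

import jax.numpy as jnp
from jax import lax

operand = jnp.zeros((5,), jnp.float32)
indices = jnp.array([[1], [3]], jnp.int32)          # last dim is the index vector
updates = jnp.array([2.0, 4.0], jnp.float32)        # one scalar update per index row
dnums = lax.ScatterDimensionNumbers(
    update_window_dims=(),            # scalar updates, no window dims
    inserted_window_dims=(0,),
    scatter_dims_to_operand_dims=(0,))
out = lax.scatter_add(operand, indices, updates, dnums)  # -> [0., 2., 0., 4., 0.]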
For example: @@ -1393,7 +1413,8 @@ def _slice_batching_rule(batched_args, batch_dims, *, start_indices, return out, bdim slice_p = standard_primitive(_slice_shape_rule, _input_dtype, 'slice', - sharding_rule=_slice_sharding_rule) + sharding_rule=_slice_sharding_rule, + vma_rule=partial(core.standard_vma_rule, 'slice')) ad.deflinear2(slice_p, _slice_transpose_rule) batching.primitive_batchers[slice_p] = _slice_batching_rule # TODO(mvoz): A better slice rule for ragged prop, enforcing boundaries @@ -1472,11 +1493,12 @@ def _dynamic_slice_jvp(primals, tangents, *, slice_sizes): def _dynamic_slice_transpose_rule(t, operand, *start_indices, slice_sizes): assert ad.is_undefined_primal(operand) assert all(not ad.is_undefined_primal(s) for s in start_indices) - operand_shape, operand_dtype = operand.aval.shape, operand.aval.dtype if type(t) is ad_util.Zero: return [ad_util.Zero(operand.aval)] + [None] * len(start_indices) else: - zeros = lax.full(operand_shape, 0, operand_dtype) + zeros = lax.full(operand.aval.shape, 0, operand.aval.dtype, + sharding=operand.aval.sharding) + zeros = core.pvary(zeros, tuple(operand.aval.vma)) return ([dynamic_update_slice_p.bind(zeros, t, *start_indices)] + [None] * len(start_indices)) @@ -1520,15 +1542,17 @@ def _dynamic_slice_batching_rule(batched_args, batch_dims, *, slice_sizes): slice_sizes=slice_sizes, unique_indices=True, indices_are_sorted=True, mode=GatherScatterMode.PROMISE_IN_BOUNDS, fill_value=None) -def _dynamic_slice_staging_rule(trace, x, *starts_and_dyn_sizes, slice_sizes): +def _dynamic_slice_staging_rule(trace, source_info, x, *starts_and_dyn_sizes, + slice_sizes): start_indices, dyn = util.split_list(starts_and_dyn_sizes, [x.ndim]) if not dyn: - return trace.default_process_primitive(dynamic_slice_p, (x, *start_indices), - dict(slice_sizes=slice_sizes)) + return trace.default_process_primitive( + dynamic_slice_p, (x, *start_indices), dict(slice_sizes=slice_sizes), + source_info=source_info) shape = lax._merge_dyn_shape(slice_sizes, dyn) aval = core.DShapedArray(shape, x.dtype, False) - return lax._dyn_shape_staging_rule(trace, dynamic_slice_p, aval, x, - *starts_and_dyn_sizes, + return lax._dyn_shape_staging_rule(trace, source_info, dynamic_slice_p, aval, + x, *starts_and_dyn_sizes, slice_sizes=slice_sizes) def _dynamic_slice_typecheck_rule(_, x, *starts_and_dyn_sizes, slice_sizes): @@ -1558,7 +1582,8 @@ def _dynamic_slice_padding_rule(in_avals, out_avals, x, *starts_and_dyn, dynamic_slice_p = standard_primitive( _dynamic_slice_shape_rule, _dynamic_slice_dtype_rule, 'dynamic_slice', weak_type_rule=_argnum_weak_type(0), - sharding_rule=_dynamic_slice_sharding_rule) + sharding_rule=_dynamic_slice_sharding_rule, + vma_rule=partial(core.standard_vma_rule, 'dynamic_slice')) ad.primitive_jvps[dynamic_slice_p] = _dynamic_slice_jvp ad.primitive_transposes[dynamic_slice_p] = _dynamic_slice_transpose_rule batching.primitive_batchers[dynamic_slice_p] = _dynamic_slice_batching_rule @@ -1606,7 +1631,7 @@ def _dynamic_update_slice_shape_rule(operand, update, *start_indices): def _dynamic_update_slice_sharding_rule(operand, update, *start_indices): if operand.sharding != update.sharding: - raise TypeError( + raise core.ShardingTypeError( "dynamic_update_slice update sharding must be equal to operand" " sharding, got update sharding" f" {update.str_short(mesh_axis_types=True)} for operand sharding" @@ -1678,7 +1703,8 @@ def _dynamic_update_slice_batching_rule(batched_args, batch_dims): dynamic_update_slice_p = standard_primitive( 
_dynamic_update_slice_shape_rule, _dynamic_update_slice_dtype_rule, - 'dynamic_update_slice', sharding_rule=_dynamic_update_slice_sharding_rule) + 'dynamic_update_slice', sharding_rule=_dynamic_update_slice_sharding_rule, + vma_rule=partial(core.standard_vma_rule, 'dynamic_update_slice')) ad.primitive_jvps[dynamic_update_slice_p] = _dynamic_update_slice_jvp ad.primitive_transposes[dynamic_update_slice_p] = \ _dynamic_update_slice_transpose_rule @@ -1921,9 +1947,6 @@ def _gather_shape_computation(indices, dimension_numbers, slice_sizes): else next(indices_shape_gen) for i in range(output_shape_rank)) return ans -class GatherShardingError(Exception): - pass - def _gather_sharding_rule(operand, indices, *, dimension_numbers, slice_sizes, unique_indices, indices_are_sorted, mode, fill_value): @@ -1935,7 +1958,7 @@ def _gather_sharding_rule(operand, indices, *, dimension_numbers, all(s is None for s in operand.sharding.spec) and all(s is None for s in indices.sharding.spec)): return core.get_cur_mesh_sharding() - raise GatherShardingError( + raise core.ShardingTypeError( "Use `.at[...].get(out_sharding=)` to provide output PartitionSpec for" " the gather indexing.") @@ -2119,7 +2142,8 @@ def _gather_pad_rule(in_avals, out_avals, operand, indices, *, gather_p = standard_primitive( _gather_shape_rule, _gather_dtype_rule, 'gather', - weak_type_rule=_argnum_weak_type(0), sharding_rule=_gather_sharding_rule) + weak_type_rule=_argnum_weak_type(0), sharding_rule=_gather_sharding_rule, + vma_rule=partial(core.standard_vma_rule, 'gather')) ad.defjvp(gather_p, _gather_jvp_rule, None) ad.primitive_transposes[gather_p] = _gather_transpose_rule batching.primitive_batchers[gather_p] = _gather_batching_rule @@ -2601,7 +2625,8 @@ def _scatter_batching_rule(scatter_op, batched_args, batch_dims, *, scatter_add_p = standard_primitive( _scatter_shape_rule, _scatter_dtype_rule, 'scatter-add', - weak_type_rule=_argnum_weak_type(0)) + weak_type_rule=_argnum_weak_type(0), + vma_rule=partial(core.standard_vma_rule, 'scatter_add')) ad.primitive_jvps[scatter_add_p] = partial(_scatter_addsub_jvp, scatter_add_p) ad.primitive_transposes[scatter_add_p] = partial(_scatter_addsub_transpose_rule, scatter_add_p) batching.primitive_batchers[scatter_add_p] = ( @@ -2612,6 +2637,7 @@ def _scatter_batching_rule(scatter_op, batched_args, batch_dims, *, _scatter_dtype_rule, "scatter-sub", weak_type_rule=_argnum_weak_type(0), + vma_rule=partial(core.standard_vma_rule, 'scatter_sub') ) ad.primitive_jvps[scatter_sub_p] = partial(_scatter_addsub_jvp, scatter_sub_p) ad.primitive_transposes[scatter_sub_p] = partial(_scatter_addsub_transpose_rule, scatter_sub_p) @@ -2621,7 +2647,8 @@ def _scatter_batching_rule(scatter_op, batched_args, batch_dims, *, scatter_mul_p = standard_primitive( _scatter_shape_rule, _scatter_dtype_rule, 'scatter-mul', - weak_type_rule=_argnum_weak_type(0)) + weak_type_rule=_argnum_weak_type(0), + vma_rule=partial(core.standard_vma_rule, 'scatter_mul')) def _scatter_mul_jvp_rhs(g, x, i, y, *, dimension_numbers, indices_are_sorted, unique_indices, mode, **kw): @@ -2750,14 +2777,16 @@ def _scatter_extremal_jvp(scatter_op, primals, tangents, update_jaxpr, scatter_min_p = standard_primitive( _scatter_shape_rule, _scatter_dtype_rule, 'scatter-min', - weak_type_rule=_argnum_weak_type(0)) + weak_type_rule=_argnum_weak_type(0), + vma_rule=partial(core.standard_vma_rule, 'scatter_min')) batching.primitive_batchers[scatter_min_p] = ( partial(_scatter_batching_rule, scatter_min_p)) ad.primitive_jvps[scatter_min_p] = 
partial(_scatter_extremal_jvp, scatter_min_p) scatter_max_p = standard_primitive( _scatter_shape_rule, _scatter_dtype_rule, 'scatter-max', - weak_type_rule=_argnum_weak_type(0)) + weak_type_rule=_argnum_weak_type(0), + vma_rule=partial(core.standard_vma_rule, 'scatter_max')) batching.primitive_batchers[scatter_max_p] = ( partial(_scatter_batching_rule, scatter_max_p)) ad.primitive_jvps[scatter_max_p] = partial(_scatter_extremal_jvp, scatter_max_p) @@ -2915,7 +2944,8 @@ def _scatter_transpose_rule(t, operand, indices, updates, *, scatter_p = standard_primitive( _scatter_shape_rule, _scatter_dtype_rule, 'scatter', - weak_type_rule=_argnum_weak_type(0)) + weak_type_rule=_argnum_weak_type(0), + vma_rule=partial(core.standard_vma_rule, 'scatter')) ad.primitive_jvps[scatter_p] = _scatter_jvp ad.primitive_transposes[scatter_p] = _scatter_transpose_rule batching.primitive_batchers[scatter_p] = ( diff --git a/jax/_src/lax/special.py b/jax/_src/lax/special.py index b70513bc2d20..023fed34fdc9 100644 --- a/jax/_src/lax/special.py +++ b/jax/_src/lax/special.py @@ -21,6 +21,7 @@ import numpy as np from functools import partial +from jax._src import core from jax._src.lax.lax import (add, bitwise_and, bitwise_not, bitwise_or, broadcast_in_dim, broadcast_shapes, convert_element_type, div, eq, exp, full_like, ge, @@ -29,7 +30,7 @@ standard_naryop, standard_unop, sub, _const, _dtype, _float, _nary_lower_hlo, _ones, _isnan, _reduce) -from jax._src.lax.control_flow import while_loop +from jax._src.lax.control_flow.loops import while_loop from jax._src import dtypes from jax._src.interpreters import ad @@ -37,8 +38,28 @@ from jax._src.lib.mlir.dialects import chlo from jax._src.typing import Array, ArrayLike +# TODO(mattjj): this function sucks, delete it +def _up_and_broadcast(doit): + def up_and_broadcast(*args): + broadcasted_shape = broadcast_shapes(*(a.shape for a in args)) + args = [broadcast_in_dim(a, broadcasted_shape, list(range(a.ndim))) for a in args] + + a_dtype = args[0].dtype + needs_upcast = a_dtype == dtypes.bfloat16 or a_dtype == np.float16 + if needs_upcast: + args = [convert_element_type(a, np.float32) for a in args] + a_x_type = np.float32 + else: + a_x_type = a_dtype + result = doit(*args, dtype=a_x_type) + if needs_upcast: + result = convert_element_type(result, a_dtype) + return result + return up_and_broadcast + def betainc(a: ArrayLike, b: ArrayLike, x: ArrayLike) -> Array: r"""Elementwise regularized incomplete beta integral.""" + a, b, x = core.standard_insert_pvary(a, b, x) return regularized_incomplete_beta_p.bind(a, b, x) def lgamma(x: ArrayLike) -> Array: @@ -51,26 +72,33 @@ def digamma(x: ArrayLike) -> Array: def polygamma(m: ArrayLike, x: ArrayLike) -> Array: r"""Elementwise polygamma: :math:`\psi^{(m)}(x)`.""" + m, x = core.standard_insert_pvary(m, x) return polygamma_p.bind(m, x) def igamma(a: ArrayLike, x: ArrayLike) -> Array: r"""Elementwise regularized incomplete gamma function.""" + a, x = core.standard_insert_pvary(a, x) return igamma_p.bind(a, x) def igammac(a: ArrayLike, x: ArrayLike) -> Array: r"""Elementwise complementary regularized incomplete gamma function.""" + a, x = core.standard_insert_pvary(a, x) return igammac_p.bind(a, x) def igamma_grad_a(a: ArrayLike, x: ArrayLike) -> Array: r"""Elementwise derivative of the regularized incomplete gamma function.""" + a, x = core.standard_insert_pvary(a, x) return igamma_grad_a_p.bind(a, x) -def random_gamma_grad(a: ArrayLike, x: ArrayLike) -> Array: +@_up_and_broadcast +def random_gamma_grad(a: ArrayLike, x: ArrayLike, 
*, dtype) -> Array: r"""Elementwise derivative of samples from `Gamma(a, 1)`.""" - return random_gamma_grad_p.bind(a, x) + a, x = core.standard_insert_pvary(a, x) + return random_gamma_grad_impl(a, x, dtype=dtype) def zeta(x: ArrayLike, q: ArrayLike) -> Array: r"""Elementwise Hurwitz zeta function: :math:`\zeta(x, q)`""" + x, q = core.standard_insert_pvary(x, q) return zeta_p.bind(x, q) def bessel_i0e(x: ArrayLike) -> Array: @@ -194,12 +222,18 @@ def nth_partial_betainc_numerator(iteration, a, b, x): iteration_is_one = eq(iteration_bcast, full_like(iteration_bcast, 1)) iteration_minus_one = iteration_bcast - full_like(iteration_bcast, 1) m = iteration_minus_one // full_like(iteration_minus_one, 2) + m_is_zero = eq(m, full_like(m, 0)) m = convert_element_type(m, dtype) one = full_like(a, 1) two = full_like(a, 2.0) # Partial numerator terms - even_numerator = -(a + m) * (a + b + m) * x / ( - (a + two * m) * (a + two * m + one)) + + # When a is close to zero and m == 0, using zero_numerator avoids + # inaccuracies when FTZ or DAZ is enabled: + zero_numerator = -(a + b) * x / (a + one) + even_numerator = select(m_is_zero, zero_numerator, + -(a + m) * (a + b + m) * x / ( + (a + two * m) * (a + two * m + one))) odd_numerator = m * (b - m) * x / ((a + two * m - one) * (a + two * m)) one_numerator = full_like(x, 1.0) numerator = select(iteration_is_even, even_numerator, odd_numerator) @@ -210,12 +244,24 @@ def nth_partial_betainc_denominator(iteration, a, b, x): return select(eq(iteration_bcast, full_like(iteration_bcast, 0)), full_like(x, 0), full_like(x, 1)) + a_is_zero = bitwise_or(eq(a, full_like(a, 0)), eq(b, full_like(b, float('inf')))) + b_is_zero = bitwise_or(eq(b, full_like(b, 0)), eq(a, full_like(a, float('inf')))) + x_is_zero = eq(x, full_like(x, 0)) + x_is_one = eq(x, full_like(x, 1)) + x_is_not_zero = bitwise_not(x_is_zero) + x_is_not_one = bitwise_not(x_is_one) + is_nan = bitwise_or(bitwise_or(_isnan(a), _isnan(b)), _isnan(x)) + + result_is_zero = bitwise_or(bitwise_and(b_is_zero, x_is_not_one), bitwise_and(a_is_zero, x_is_zero)) + result_is_one = bitwise_or(bitwise_and(a_is_zero, x_is_not_zero), bitwise_and(b_is_zero, x_is_one)) + result_is_nan = bitwise_or(bitwise_or(bitwise_or( - le(a, full_like(a, 0)), le(b, full_like(b, 0))), + lt(a, full_like(a, 0)), lt(b, full_like(b, 0))), lt(x, full_like(x, 0))), gt(x, full_like(x, 1))) + result_is_nan = bitwise_or(result_is_nan, bitwise_or(bitwise_and(a_is_zero, b_is_zero), is_nan)) - # The continued fraction will converge rapidly when x < (a+1)/(a+b+2) - # as per: http://dlmf.nist.gov/8.17.E23 + # The continued fraction will converge rapidly when x < + # (a+1)/(a+b+2) as per: http://dlmf.nist.gov/8.17.E23. # # Otherwise, we can rewrite using the symmetry relation as per: # http://dlmf.nist.gov/8.17.E4 @@ -234,10 +280,21 @@ def nth_partial_betainc_denominator(iteration, a, b, x): inputs=[a, b, x] ) - lbeta_ab = lgamma(a) + lgamma(b) - lgamma(a + b) - result = continued_fraction * exp(log(x) * a + log1p(-x) * b - lbeta_ab) / a + # For very small a and to avoid division by zero, we'll use + # a * gamma(a) = gamma(a + 1) -> 1 as a -> 0+. 
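# Editorial note on the small-a branch chosen below (not part of the diff);
# it uses B(a, b) = Gamma(a) * Gamma(b) / Gamma(a + b):
#   x**a * (1-x)**b / (a * B(a, b))
#     = x**a * (1-x)**b * Gamma(a+b) / (Gamma(a+1) * Gamma(b))
# and since Gamma(a+1) -> 1 and x**a -> 1 as a -> 0+, the factor tends to
#   (1-x)**b * Gamma(a+b) / Gamma(b) == exp(log1p(-x)*b - lbeta_ab_small_a),
# which is the expression selected when a < very_small.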
+ very_small = (dtypes.finfo(dtype).tiny * 2).astype(dtype) + lbeta_ab_small_a = lgamma(b) - lgamma(a + b) + lbeta_ab = lgamma(a) + lbeta_ab_small_a + factor = select(lt(a, full_like(a, very_small)), + exp(log1p(-x) * b - lbeta_ab_small_a), + exp(log(x) * a + log1p(-x) * b - lbeta_ab) / a) + result = continued_fraction * factor + result = select(converges_rapidly, result, sub(full_like(result, 1), result)) + + result = select(result_is_zero, full_like(a, 0), result) + result = select(result_is_one, full_like(a, 1), result) result = select(result_is_nan, full_like(a, float('nan')), result) - return select(converges_rapidly, result, sub(full_like(result, 1), result)) + return result class IgammaMode(Enum): VALUE = 1 @@ -494,24 +551,6 @@ def random_gamma_grad_impl(a, x, *, dtype): full_like(a, float('nan')), output) return output -def _up_and_broadcast(doit): - def up_and_broadcast(*args): - broadcasted_shape = broadcast_shapes(*(a.shape for a in args)) - args = [broadcast_in_dim(a, broadcasted_shape, list(range(a.ndim))) for a in args] - - a_dtype = args[0].dtype - needs_upcast = a_dtype == dtypes.bfloat16 or a_dtype == np.float16 - if needs_upcast: - args = [convert_element_type(a, np.float32) for a in args] - a_x_type = np.float32 - else: - a_x_type = a_dtype - result = doit(*args, dtype=a_x_type) - if needs_upcast: - result = convert_element_type(result, a_dtype) - return result - return up_and_broadcast - def evaluate_chebyshev_polynomial(x, coefficients): b0 = full_like(x,0) @@ -657,11 +696,6 @@ def bessel_i0e_impl(x): ad.defjvp(igammac_p, igammac_grada, igammac_gradx) -random_gamma_grad_p = standard_naryop([_float, _float], 'random_gamma_grad') -mlir.register_lowering(random_gamma_grad_p, - mlir.lower_fun(_up_and_broadcast(random_gamma_grad_impl), - multiple_results=False)) - zeta_p = standard_naryop([_float, _float], 'zeta') mlir.register_lowering(zeta_p, partial(_nary_lower_hlo, chlo.zeta)) diff --git a/jax/_src/lax/utils.py b/jax/_src/lax/utils.py index f39d925ac2ad..97a2687bbb67 100644 --- a/jax/_src/lax/utils.py +++ b/jax/_src/lax/utils.py @@ -22,9 +22,10 @@ from jax._src import dispatch from jax._src import dtypes from jax._src import mesh as mesh_lib -from jax._src.util import safe_zip +from jax._src import state +from jax._src.named_sharding import DuplicateSpecError, NamedSharding from jax._src.partition_spec import PartitionSpec as P -from jax._src.named_sharding import NamedSharding, DuplicateSpecError +from jax._src.util import safe_zip zip, unsafe_zip = safe_zip, zip @@ -37,13 +38,14 @@ def _argnum_weak_type(*argnums): return lambda *args, **_: all(args[i].weak_type for i in argnums) def standard_primitive(shape_rule, dtype_rule, name, - weak_type_rule=None, sharding_rule=None): + weak_type_rule=None, sharding_rule=None, vma_rule=None, + unreduced_rule=None): weak_type_rule = weak_type_rule or _standard_weak_type_rule prim = core.Primitive(name) prim.def_impl(partial(dispatch.apply_primitive, prim)) prim.def_abstract_eval( partial(standard_abstract_eval, prim, shape_rule, dtype_rule, - weak_type_rule, sharding_rule)) + weak_type_rule, sharding_rule, vma_rule, unreduced_rule)) return prim def _get_array_abstraction_level(a): return a.array_abstraction_level @@ -65,7 +67,7 @@ def _get_abstract_mesh_from_avals(in_avals) -> mesh_lib.AbstractMesh: return mesh_lib.empty_abstract_mesh if m is None else m -def call_sharding_rule(prim, rule, num_out, *avals, **kwargs): +def call_sharding_rule(prim, sh_rule, unreduced_rule, num_out, *avals, **kwargs): cur_mesh = 
mesh_lib.get_abstract_mesh() aval_mesh = _get_abstract_mesh_from_avals(avals) if ((cur_mesh.empty or cur_mesh._are_all_axes_auto_or_manual) and @@ -73,36 +75,56 @@ def call_sharding_rule(prim, rule, num_out, *avals, **kwargs): aval_mesh = cur_mesh if aval_mesh.empty else aval_mesh s = NamedSharding(aval_mesh, P()) return s if num_out is None else [s] * num_out - if rule is None: - raise ValueError( - f'sharding rule for {prim.name} is not implemented. Please file a' - ' bug at https://github.com/jax-ml/jax/issues. You can work around' + if sh_rule is None: + raise core.ShardingTypeError( + f'sharding rule for {prim.name} is not implemented. Please file an' + ' issue at https://github.com/jax-ml/jax/issues. You can work around' ' this error by dropping that operation into full auto sharding' ' mode via: `jax.experimental.shard.auto_axes(fun, out_shardings=...)`') - return rule(*avals, **kwargs) + out_sharding = sh_rule(*avals, **kwargs) + if unreduced_rule is not None: + out_sharding = unreduced_rule(out_sharding, *avals, **kwargs) + else: + if any(a.sharding.spec.unreduced for a in avals): + raise NotImplementedError( + f'unreduced rule for {prim.name} is not implemented. Please file an' + ' issue at https://github.com/jax-ml/jax/issues') + return out_sharding def call_shape_dtype_sharding_rule(prim, shape_rule, dtype_rule, sharding_rule, - multi_out, *avals, **kwargs): + unreduced_rule, multi_out, *avals, **kwargs): out_shapes = shape_rule(*avals, **kwargs) out_dtypes = dtype_rule(*avals, **kwargs) num_out = len(out_shapes) if multi_out else None try: out_shardings = call_sharding_rule( - prim, sharding_rule, num_out, *avals, **kwargs) + prim, sharding_rule, unreduced_rule, num_out, *avals, **kwargs) except DuplicateSpecError as e: if multi_out: raise avals_str = ', '.join(i.str_short(short_dtypes=True) for i in avals) mesh = mesh_lib.empty_abstract_mesh if e.mesh is None else e.mesh out_aval_str = core.str_short_aval(out_shapes, out_dtypes, mesh, e.pspec, - short_dtypes=True) - raise TypeError( + frozenset(), short_dtypes=True) + raise core.ShardingTypeError( f'{prim} operation with inputs: {avals_str} produces an illegally' f' sharded result: {out_aval_str}') from e return out_shapes, out_dtypes, out_shardings def standard_abstract_eval(prim, shape_rule, dtype_rule, weak_type_rule, - sharding_rule, *avals, **kwargs): + sharding_rule, vma_rule, unreduced_rule, + *avals, **kwargs): + for a in avals: + if isinstance(a, state.AbstractRef): + raise ValueError( + f' Attempting to pass a Ref {a} to a primitive:' + f' {prim} - did you forget to unpack ([...]) the ref?' 
+ ) + if not isinstance(a, core.UnshapedArray): + raise ValueError( + f'Attempting to pass an unexpected type {a} to a' + f' primitive: {prim}' + ) assert all(isinstance(aval, core.UnshapedArray) for aval in avals), avals assert not prim.multiple_results weak_type = weak_type_rule(*avals, **kwargs) @@ -110,10 +132,12 @@ def standard_abstract_eval(prim, shape_rule, dtype_rule, weak_type_rule, if least_specialized is core.ShapedArray: core.check_avals_context_mesh(avals, prim.name) out_shape, out_dtype, out_sharding = call_shape_dtype_sharding_rule( - prim, shape_rule, dtype_rule, sharding_rule, False, + prim, shape_rule, dtype_rule, sharding_rule, unreduced_rule, False, *avals, **kwargs) + out_vma = vma_rule(*avals, **kwargs) out_aval = core.ShapedArray( - out_shape, out_dtype, weak_type=weak_type, sharding=out_sharding) + out_shape, out_dtype, weak_type=weak_type, sharding=out_sharding, + vma=out_vma) core.check_avals_context_mesh([out_aval], prim.name) return out_aval elif least_specialized is core.DShapedArray: @@ -127,7 +151,7 @@ def standard_abstract_eval(prim, shape_rule, dtype_rule, weak_type_rule, raise TypeError(avals, least_specialized) def standard_multi_result_abstract_eval( - prim, shape_rule, dtype_rule, weak_type_rule, sharding_rule, + prim, shape_rule, dtype_rule, weak_type_rule, sharding_rule, vma_rule, *avals, **kwargs): assert prim.multiple_results assert all(isinstance(aval, core.UnshapedArray) for aval in avals), avals @@ -136,12 +160,14 @@ def standard_multi_result_abstract_eval( if least_specialized is core.ShapedArray: core.check_avals_context_mesh(avals, prim.name) out_shapes, out_dtypes, out_shardings = call_shape_dtype_sharding_rule( - prim, shape_rule, dtype_rule, sharding_rule, True, *avals, **kwargs) + prim, shape_rule, dtype_rule, sharding_rule, None, True, + *avals, **kwargs) + out_vmas = vma_rule(*avals, **kwargs) if isinstance(weak_types, bool): weak_types = (weak_types,) * len(out_shapes) - out_avals = [core.ShapedArray(s, d, weak_type=weak_type, sharding=sh) - for s, d, weak_type, sh in zip(out_shapes, out_dtypes, - weak_types, out_shardings)] + out_avals = [core.ShapedArray(s, d, weak_type=weak_type, sharding=sh, vma=vma) + for s, d, weak_type, sh, vma in zip( + out_shapes, out_dtypes, weak_types, out_shardings, out_vmas)] core.check_avals_context_mesh(out_avals, prim.name) return out_avals elif least_specialized is core.UnshapedArray: diff --git a/jax/_src/lax/windowed_reductions.py b/jax/_src/lax/windowed_reductions.py index 400646f6238f..e322fc447e7c 100644 --- a/jax/_src/lax/windowed_reductions.py +++ b/jax/_src/lax/windowed_reductions.py @@ -18,13 +18,14 @@ from functools import partial import warnings -from jax import tree_util +from jax._src import ad_util from jax._src import api_util from jax._src import core from jax._src import dispatch from jax._src import dtypes +from jax._src import tree_util from jax._src import util -from jax._src.core import ShapedArray +from jax._src.core import ClosedJaxpr, ShapedArray, jaxpr_as_fun from jax._src.interpreters import ad from jax._src.interpreters import batching from jax._src.interpreters import mlir @@ -35,11 +36,8 @@ from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import hlo from jax._src.typing import Array + import numpy as np -from jax._src.core import ClosedJaxpr -from jax._src.core import jaxpr_as_fun -from jax._src.interpreters.ad import jvp_jaxpr -from jax._src import ad_util map = util.safe_map zip = util.safe_zip @@ -79,7 +77,7 @@ def _reduce_window( padding = 
tuple(lax.padtype_to_pads( flat_operands[0].shape, dilated_window_dims, window_strides, padding)) else: - padding = tuple(padding) + padding = tuple((x, y) for x, y in padding) if base_dilation is None: base_dilation = (1,) * len(window_dimensions) if window_dilation is None: @@ -97,6 +95,7 @@ def _reduce_window( raise ValueError( 'reduce_window output must have the same tree structure as the operands' f' {operand_tree} vs. {out_tree}') + flat_operands = core.standard_insert_pvary(*flat_operands) out_flat = reduce_window_p.bind( *flat_operands, *flat_init_values, @@ -250,6 +249,8 @@ def _select_and_scatter(operand: Array, select: Callable, select, core.get_aval(init_value)) scatter_jaxpr, scatter_consts = lax._reduction_jaxpr( scatter, core.get_aval(init_value)) + operand, source, init_value = core.standard_insert_pvary( + operand, source, init_value) return select_and_scatter_p.bind( operand, source, init_value, select_jaxpr=select_jaxpr, select_consts=select_consts, scatter_jaxpr=scatter_jaxpr, @@ -261,6 +262,7 @@ def _select_and_scatter_add(source: Array, operand: Array, window_dimensions: core.Shape, window_strides: Sequence[int], padding: Sequence[tuple[int, int]]) -> Array: + source, operand = core.standard_insert_pvary(source, operand) return select_and_scatter_add_p.bind( source, operand, select_prim=select_prim, window_dimensions=tuple(window_dimensions), @@ -296,6 +298,7 @@ def _select_and_gather_add(tangents: Array, operand: Array, An array containing the elements in `tangents` corresponding to the output of the reduction of `operand` fin each window. """ + tangents, operand = core.standard_insert_pvary(tangents, operand) return select_and_gather_add_p.bind( tangents, operand, select_prim=select_prim, window_dimensions=tuple(window_dimensions), @@ -332,7 +335,8 @@ def _reduce_window_abstract_eval_rule( out_sharding = reduce_window_sharding_rule( operand_avals[0], window_dimensions, window_strides, padding, base_dilation, window_dilation) - return tuple(ShapedArray(out_shape, op.dtype, sharding=out_sharding) + vma = core.standard_vma_rule('reduce_window', *operand_avals) + return tuple(ShapedArray(out_shape, op.dtype, sharding=out_sharding, vma=vma) for op in operand_avals) @@ -398,7 +402,7 @@ def reduce_window_jvp( init_value_tangent = map(ad_util.instantiate, init_value_tangent) c_reduction_jaxpr = ClosedJaxpr(reduction_jaxpr, consts) - jvp_reduction = jvp_jaxpr(c_reduction_jaxpr, (True,) * len(tangents), [False] * len(init_value_tangent))[0] + jvp_reduction = ad.jvp_jaxpr(c_reduction_jaxpr, (True,) * len(tangents), [False] * len(init_value_tangent))[0] def wrapper(left, right): pl, tl = util.split_list(left, [n]) @@ -514,25 +518,16 @@ def _reduce_window_batch_rule(reduce_window, batched_args, bdims, *, def reduce_window_sharding_rule(operand, window_dimensions, window_strides, padding, base_dilation, window_dilation): - if base_dilation is None: - base_dilation = [1] * operand.ndim - if window_dilation is None: - window_dilation = [1] * operand.ndim - - for spec, wdim, ws, pd, bd, wdil in zip( - operand.sharding.spec, window_dimensions, window_strides, padding, - base_dilation, window_dilation): - if spec is None: - continue - if not (wdim == 1 and ws == 1 and pd == 1 and bd == 1 and wdil == 1): - raise NotImplementedError( - "Only trivial windowing is supported along non-replicated" - f" dimensions. 
Got {operand.sharding.spec=}") - return operand.sharding + out_shape = reduce_window_shape_tuple( + operand.shape, window_dimensions, window_strides, padding, base_dilation, + window_dilation) + return lax.slicing._get_sharding_for_varying_out_shape( + out_shape, operand, 'reduce_window') reduce_window_sum_p = lax.standard_primitive( _reduce_window_sum_shape_rule, lax._input_dtype, 'reduce_window_sum', - sharding_rule=reduce_window_sharding_rule) + sharding_rule=reduce_window_sharding_rule, + vma_rule=partial(core.standard_vma_rule, 'reduce_window_sum')) ad.deflinear2(reduce_window_sum_p, _reduce_window_sum_transpose_rule) batching.primitive_batchers[reduce_window_sum_p] = partial( _reduce_window_batch_rule, _reduce_window_sum) @@ -598,7 +593,8 @@ def reduce_window_shape_tuple(operand_shape, window_dimensions, window_strides, reduce_window_max_p = lax.standard_primitive( _common_reduce_window_shape_rule, lax._input_dtype, 'reduce_window_max', - sharding_rule=reduce_window_sharding_rule) + sharding_rule=reduce_window_sharding_rule, + vma_rule=partial(core.standard_vma_rule, 'reduce_window_max')) ad.defjvp(reduce_window_max_p, partial(_reduce_window_chooser_jvp_rule, lax.max_p)) batching.primitive_batchers[reduce_window_max_p] = partial( @@ -606,7 +602,8 @@ def reduce_window_shape_tuple(operand_shape, window_dimensions, window_strides, reduce_window_min_p = lax.standard_primitive( _common_reduce_window_shape_rule, lax._input_dtype, 'reduce_window_min', - sharding_rule=reduce_window_sharding_rule) + sharding_rule=reduce_window_sharding_rule, + vma_rule=partial(core.standard_vma_rule, 'reduce_window_min')) ad.defjvp(reduce_window_min_p, partial(_reduce_window_chooser_jvp_rule, lax.min_p)) @@ -630,7 +627,8 @@ def _reduce_window_lower( ): operand_aval, = ctx.avals_in - scalar_aval = operand_aval.update(shape=()) + scalar_aval = operand_aval.update( + shape=(), sharding=operand_aval.sharding.update(spec=())) return mlir.reduce_window( ctx, @@ -670,8 +668,15 @@ def _select_and_scatter_shape_rule( raise TypeError(msg.format(window_strides, window_dimensions)) return operand.shape +def _select_and_scatter_sharding_rule( + operand, source, init_value, *, select_jaxpr, select_consts, scatter_jaxpr, + scatter_consts, window_dimensions, window_strides, padding): + return operand.sharding + select_and_scatter_p = lax.standard_primitive( - _select_and_scatter_shape_rule, lax._input_dtype, 'select_and_scatter') + _select_and_scatter_shape_rule, lax._input_dtype, 'select_and_scatter', + sharding_rule=_select_and_scatter_sharding_rule, + vma_rule=partial(core.standard_vma_rule, 'select_and_scatter')) def _select_and_scatter_lower( ctx, operand, source, init_value, *, select_jaxpr, @@ -679,7 +684,8 @@ def _select_and_scatter_lower( window_strides, padding): operand_aval, source_aval, init_value_aval = ctx.avals_in aval_out, = ctx.avals_out - scalar_aval = operand_aval.update(shape=()) + scalar_aval = operand_aval.update( + shape=(), sharding=operand_aval.sharding.update(spec=())) scalar_type = mlir.aval_to_ir_type(scalar_aval) op = hlo.SelectAndScatterOp( mlir.aval_to_ir_type(aval_out), @@ -710,7 +716,8 @@ def _select_and_scatter_lower( *scatter.arguments, dim_var_values=ctx.dim_var_values) hlo.return_(mlir.flatten_ir_values(out_nodes)) - return op.results + return [mlir.lower_with_sharding_in_types(ctx, r, aval) + for r, aval in zip(op.results, ctx.avals_out)] mlir.register_lowering(select_and_scatter_p, _select_and_scatter_lower) @@ -719,6 +726,11 @@ def _select_and_scatter_add_shape_rule( padding): return 
operand.shape +def _select_and_scatter_add_sharding_rule( + source, operand, *, select_prim, window_dimensions, window_strides, + padding): + return operand.sharding + def _select_and_scatter_add_jvp( primals, tangents, *, select_prim, window_dimensions, window_strides, padding): @@ -766,7 +778,9 @@ def _select_and_scatter_add_batch_rule( select_and_scatter_add_p = lax.standard_primitive( _select_and_scatter_add_shape_rule, lax._input_dtype, - 'select_and_scatter_add') + 'select_and_scatter_add', + sharding_rule=_select_and_scatter_add_sharding_rule, + vma_rule=partial(core.standard_vma_rule, 'select_and_scatter_add')) ad.primitive_transposes[select_and_scatter_add_p] = \ _select_and_scatter_add_transpose @@ -826,7 +840,7 @@ def _select_and_gather_add_sharding_rule( tangents, operand, *, select_prim, window_dimensions, window_strides, padding, base_dilation, window_dilation): if tangents.sharding != operand.sharding: - raise TypeError( + raise core.ShardingTypeError( "select_and_gather_add tangents and operand shardings must match, " f"got {tangents.sharding} and {operand.sharding}.") return reduce_window_sharding_rule( @@ -1039,7 +1053,8 @@ def _select_and_gather_add_batching_rule( select_and_gather_add_p = lax.standard_primitive( _select_and_gather_add_shape_rule, lax._input_dtype, - 'select_and_gather_add', sharding_rule=_select_and_gather_add_sharding_rule) + 'select_and_gather_add', sharding_rule=_select_and_gather_add_sharding_rule, + vma_rule=partial(core.standard_vma_rule, 'select_and_gather_add')) ad.primitive_jvps[select_and_gather_add_p] = _select_and_gather_add_jvp ad.primitive_transposes[select_and_gather_add_p] = \ _select_and_gather_add_transpose diff --git a/jax/_src/layout.py b/jax/_src/layout.py index 5309f0b1fd9c..824778df453b 100644 --- a/jax/_src/layout.py +++ b/jax/_src/layout.py @@ -94,7 +94,7 @@ def check_compatible_aval(self, aval_shape: Shape): ShardingOptions = Union[Sharding, None, AutoSharding] -class Layout: +class Format: __slots__ = ['device_local_layout', 'sharding'] def __init__(self, device_local_layout: LayoutOptions = None, @@ -105,7 +105,7 @@ def __init__(self, device_local_layout: LayoutOptions = None, raise ValueError( 'Sharding has to be concrete when layout is of type' f' {type(device_local_layout)}. Please pass a' - ' `jax.sharding.NamedSharding`, `jax.sharding.PositionalSharding` or' + ' `jax.sharding.NamedSharding` or' ' `jax.sharding.SingleDeviceSharding` to the sharding argument. 
Got' f' sharding {sharding}' ) @@ -127,6 +127,10 @@ def __init__(self, device_local_layout: LayoutOptions = None, self.device_local_layout = device_local_layout self.sharding = sharding + @property + def dll(self): + return self.device_local_layout + def __repr__(self): return (f'Layout(device_local_layout={self.device_local_layout},' f' sharding={self.sharding})') @@ -135,7 +139,7 @@ def __hash__(self): return hash((self.device_local_layout, self.sharding)) def __eq__(self, other): - if not isinstance(other, Layout): + if not isinstance(other, Format): return False return (self.device_local_layout == other.device_local_layout and self.sharding == other.sharding) diff --git a/jax/_src/lib/BUILD b/jax/_src/lib/BUILD index 1fcbd4b6b7ef..e0b5ea607501 100644 --- a/jax/_src/lib/BUILD +++ b/jax/_src/lib/BUILD @@ -40,26 +40,5 @@ py_library_providing_imports_info( "//jax:version", ] + if_building_jaxlib([ "//jaxlib", - "//jaxlib/mosaic/python:gpu_dialect", - "//jaxlib/mosaic/python:tpu_dialect", - "//jaxlib:cpu_feature_guard", - "//jaxlib:utils", - "//jaxlib/triton", - "//jaxlib/mlir/_mlir_libs:register_jax_dialects", - "//jaxlib/mlir:arithmetic_dialect", - "//jaxlib/mlir:builtin_dialect", - "//jaxlib/mlir:chlo_dialect", - "//jaxlib/mlir:func_dialect", - "//jaxlib/mlir:ir", - "//jaxlib/mlir:math_dialect", - "//jaxlib/mlir:memref_dialect", - "//jaxlib/mlir:mhlo_dialect", - "//jaxlib/mlir:pass_manager", - "//jaxlib/mlir:scf_dialect", - "//jaxlib/mlir:sdy_dialect", - "//jaxlib/mlir:sparse_tensor_dialect", - "//jaxlib/mlir:stablehlo_dialect", - "//jaxlib/mlir:vector_dialect", - # xla_client ]), ) diff --git a/jax/_src/lib/__init__.py b/jax/_src/lib/__init__.py index 7933bb769733..fde926094e8b 100644 --- a/jax/_src/lib/__init__.py +++ b/jax/_src/lib/__init__.py @@ -40,7 +40,7 @@ raise ImportError(msg) from err -# Checks the jaxlib version before importing anything else from jaxlib. +# Checks the jaxlib version before importing anything else. # Returns the jaxlib version string. def check_jaxlib_version(jax_version: str, jaxlib_version: str, minimum_jaxlib_version: str) -> tuple[int, ...]: @@ -77,20 +77,42 @@ def _parse_version(v: str) -> tuple[int, ...]: jaxlib_version=jaxlib.version.__version__, minimum_jaxlib_version=jax.version._minimum_jaxlib_version) -# Before importing any C compiled modules from jaxlib, first import the CPU +# Before importing any C compiled modules, first import the CPU # feature guard module to verify that jaxlib was compiled in a way that only # uses instructions that are present on this machine. import jaxlib.cpu_feature_guard as cpu_feature_guard cpu_feature_guard.check_cpu_features() -import jaxlib.utils as utils # noqa: F401 -import jaxlib.xla_client as xla_client import jaxlib.lapack as lapack # noqa: F401 +import jaxlib.utils as utils # noqa: F401 +import jaxlib._jax as _jax # noqa: F401 +from jaxlib._jax import guard_lib as guard_lib # noqa: F401 +from jaxlib._jax import jax_jit as jax_jit # noqa: F401 +from jaxlib._jax import pmap_lib as pmap_lib # noqa: F401 +from jaxlib._jax import pytree as pytree # noqa: F401 +from jaxlib._jax import Device as Device # noqa: F401 +from jaxlib import _profiler as _profiler # noqa: F401 + +import jaxlib.xla_client as xla_client # noqa: F401 + +# Jaxlib code is split between the Jax and the XLA repositories. +# Only for the internal usage of the JAX developers, we expose a version +# number that can be used to perform changes without breaking the main +# branch on the Jax github. 
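Referring back to the jax/_src/layout.py hunk above (Layout renamed to Format, plus the new dll alias), a minimal hedged sketch of the renamed container; the import path follows the diff and public aliases may differ by JAX version.

import jax
from jax.sharding import SingleDeviceSharding
from jax._src.layout import Format  # path per this diff

fmt = Format(device_local_layout=None,
             sharding=SingleDeviceSharding(jax.devices()[0]))
assert fmt.dll is fmt.device_local_layout  # `dll` is the alias added above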
+jaxlib_extension_version: int = getattr(xla_client, '_version', 0) +ifrt_version: int = getattr(xla_client, '_ifrt_version', 0) + +from jaxlib._jax import ffi as ffi # noqa: F401 +import jaxlib.cpu_sparse as cpu_sparse # noqa: F401 +has_cpu_sparse = True + +import jaxlib.weakref_lru_cache as weakref_lru_cache # noqa: F401 + +if jaxlib_extension_version >= 350: + import jaxlib._pretty_printer as _pretty_printer # noqa: F401 +else: + _pretty_printer = None -xla_extension = xla_client._xla -pytree = xla_client._xla.pytree -jax_jit = xla_client._xla.jax_jit -pmap_lib = xla_client._xla.pmap_lib # XLA garbage collection: see https://github.com/jax-ml/jax/issues/14882 def _xla_gc_callback(*args): @@ -109,13 +131,6 @@ def _xla_gc_callback(*args): import jaxlib.gpu_sparse as gpu_sparse # pytype: disable=import-error # noqa: F401 import jaxlib.gpu_prng as gpu_prng # pytype: disable=import-error # noqa: F401 import jaxlib.gpu_linalg as gpu_linalg # pytype: disable=import-error # noqa: F401 -import jaxlib.hlo_helpers as hlo_helpers # pytype: disable=import-error # noqa: F401 - -# Jaxlib code is split between the Jax and the Tensorflow repositories. -# Only for the internal usage of the JAX developers, we expose a version -# number that can be used to perform changes without breaking the main -# branch on the Jax github. -xla_extension_version: int = getattr(xla_client, '_version', 0) import jaxlib.gpu_rnn as gpu_rnn # pytype: disable=import-error # noqa: F401 import jaxlib.gpu_triton as gpu_triton # pytype: disable=import-error # noqa: F401 @@ -123,9 +138,6 @@ def _xla_gc_callback(*args): import jaxlib.mosaic.python.mosaic_gpu as mosaic_gpu_dialect # pytype: disable=import-error # noqa: F401 import jaxlib.mosaic.python.tpu as tpu # pytype: disable=import-error # noqa: F401 -# Version number for MLIR:Python APIs, provided by jaxlib. -mlir_api_version = xla_client.mlir_api_version - # TODO(rocm): check if we need the same for rocm. def _cuda_path() -> str | None: @@ -160,14 +172,23 @@ def _try_cuda_nvcc_import() -> str | None: return str(cuda_nvcc_path) + def _try_bazel_runfiles() -> str | None: + """Try to get the path to the cuda installation in bazel runfiles.""" + python_runfiles = os.environ.get('PYTHON_RUNFILES') + if not python_runfiles: + return None + cuda_nvcc_root = os.path.join(python_runfiles, 'cuda_nvcc') + if os.path.exists(cuda_nvcc_root): + return cuda_nvcc_root + return None + if (path := _try_cuda_root_environment_variable()) is not None: return path elif (path := _try_cuda_nvcc_import()) is not None: return path + elif (path := _try_bazel_runfiles()) is not None: + return path return None cuda_path = _cuda_path() - -guard_lib = xla_client._xla.guard_lib -Device = xla_client._xla.Device diff --git a/jax/_src/lib/mlir/dialects/__init__.py b/jax/_src/lib/mlir/dialects/__init__.py index a9bae8821db5..eccd40104dc1 100644 --- a/jax/_src/lib/mlir/dialects/__init__.py +++ b/jax/_src/lib/mlir/dialects/__init__.py @@ -19,6 +19,7 @@ if TYPE_CHECKING: from jaxlib.mlir.dialects import arith as arith from jaxlib.mlir.dialects import builtin as builtin + from jaxlib.mlir.dialects import cf as cf from jaxlib.mlir.dialects import chlo as chlo from jaxlib.mlir.dialects import func as func from jaxlib.mlir.dialects import gpu as gpu @@ -36,6 +37,7 @@ __getattr__, __dir__, __all__ = _lazy.attach("jaxlib.mlir.dialects", [ "arith", "builtin", + "cf", "chlo", "func", "gpu", @@ -51,11 +53,9 @@ ]) del _lazy -# TODO(bartchr): Once JAX is released with SDY, remove the try/except. 
-try: - from jaxlib.mlir.dialects import sdy as sdy -except ImportError: - sdy: Any = None # type: ignore[no-redef] +from jaxlib.mlir.dialects import sdy # Alias that is set up to abstract away the transition from MHLO to StableHLO. from jaxlib.mlir.dialects import stablehlo as hlo + +from jax._src import lib diff --git a/jax/_src/linear_util.py b/jax/_src/linear_util.py index 1497597ebd62..b6b0b3cce982 100644 --- a/jax/_src/linear_util.py +++ b/jax/_src/linear_util.py @@ -67,15 +67,17 @@ def trans1(static_arg, *dynamic_args, **kwargs): from collections.abc import Callable, Sequence from functools import partial import re -from typing import Any, Hashable, NamedTuple +import time +from typing import Any, NamedTuple +from collections.abc import Hashable import warnings import weakref from jax._src import config from jax._src import core from jax._src import traceback_util -from jax._src.tree_util import keystr, KeyPath, generate_key_paths -from jax._src.util import curry, cache_clearing_funs, HashableFunction +from jax._src.tree_util import KeyPath, generate_key_paths, keystr +from jax._src.util import HashableFunction, cache_clearing_funs, curry, fun_name traceback_util.register_exclusion(__file__) @@ -185,7 +187,7 @@ def __init__(self, f: Callable, @property def __name__(self): - return getattr(self.f, '__name__', '') + return fun_name(self.f, "") def wrap(self, gen, gen_static_args, out_store: Store | EqualStore | None) -> WrappedFun: @@ -265,12 +267,6 @@ def transformation_with_aux2( out_thunk = lambda: out_store.val return fun.wrap(gen, gen_static_args, out_store), out_thunk -def fun_name(f): - try: - return f.__name__ - except: - return str(f) - class DebugInfo(NamedTuple): """Debugging info about a func, its arguments, and results.""" @@ -326,6 +322,18 @@ def replace_func_name(self, name: str) -> DebugInfo: func_src_comps[0] = name return self._replace(func_src_info=" ".join(func_src_comps)) + @property + def func_filename(self) -> str | None: + m = _re_func_src_info.match(self.func_src_info) + if not m: return None + return m.group(3) + + @property + def func_lineno(self) -> int | None: + m = _re_func_src_info.match(self.func_src_info) + if not m or m.group(4) is None: return None + return int(m.group(4)) + def safe_arg_names(self, expected: int) -> tuple[str, ...]: """Get the arg_names with a safety check.""" if len(self.arg_names) == expected: @@ -352,6 +360,7 @@ def filter_result_paths(self, keep: Sequence[bool]) -> tuple[str, ...]: assert self.result_paths is not None and not callable(self.result_paths), self return tuple(v for v, b in zip(self.safe_result_paths(len(keep)), keep) if b) +_re_func_src_info = re.compile(r"([^ ]+)( at (.+):(\d+))?$") def _missing_debug_info(for_what: str) -> DebugInfo: warnings.warn( @@ -433,7 +442,7 @@ def valid_size(d) -> bool: def cache(call: Callable, *, - explain: Callable[[WrappedFun, bool, dict, tuple], None] | None = None): + explain: Callable[[WrappedFun, bool, dict, tuple, float], None] | None = None): """Memoization decorator for functions taking a WrappedFun as first argument. Args: @@ -442,7 +451,8 @@ def cache(call: Callable, *, memoization cache key. explain: a function that is invoked upon cache misses to log an explanation - of the miss. Invoked with `(fun, is_cache_first_use, cache, key)`. + of the miss. + Invoked with `(fun, is_cache_first_use, cache, key, elapsed_sec)`. Returns: A memoized version of ``call``. 
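A hedged sketch of an `explain` callback matching the new five-argument signature documented above, `(fun, is_cache_first_use, cache, key, elapsed_sec)`; the logger name and message are illustrative.

import logging

logger = logging.getLogger("jax.cache_miss_explanations")  # illustrative name

def explain_miss(fun, is_first_use, cache, key, elapsed_sec):
  kind = "first use of cache" if is_first_use else "cache miss"
  # `fun` is a WrappedFun; its __name__ falls back to str(f) per the diff above.
  logger.info("%s for %s; retracing took %.3fs (cache size before insert: %d)",
              kind, fun.__name__, elapsed_sec, len(cache))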
@@ -457,9 +467,11 @@ def memoized_fun(fun: WrappedFun, *args): ans, stores = result fun.populate_stores(stores) else: + if do_explain := explain and config.explain_cache_misses.value: + start = time.time() ans = call(fun, *args) - if explain and config.explain_cache_misses.value: - explain(fun, cache is new_cache, cache, key) + if do_explain: + explain(fun, cache is new_cache, cache, key, time.time() - start) # type: ignore cache[key] = (ans, fun.stores) return ans diff --git a/jax/_src/mesh.py b/jax/_src/mesh.py index b490febf7b0c..442ca3f18d83 100644 --- a/jax/_src/mesh.py +++ b/jax/_src/mesh.py @@ -111,12 +111,16 @@ class AxisType(enum.Enum): def __repr__(self): return self.name -def _normalize_axis_types(axis_names, axis_types): +def _normalize_axis_types(axis_names, axis_types, name): axis_types = ((AxisType.Auto,) * len(axis_names) if axis_types is None else axis_types) if not isinstance(axis_types, tuple): - assert isinstance(axis_types, AxisType), axis_types axis_types = (axis_types,) + + if not all(isinstance(a, AxisType) for a in axis_types): + raise TypeError( + f"axis_types passed to {name} must be of type `jax.sharding.AxisType`." + f" Got {axis_types} of type {tuple(type(a) for a in axis_types)}") if len(axis_names) != len(axis_types): raise ValueError( "Number of axis names should match the number of axis_types. Got" @@ -174,6 +178,28 @@ def _any_axis_auto(self) -> bool: def _any_axis_explicit(self) -> bool: return any_axis_types_match(self._axis_types, AxisType.Explicit) + @functools.cached_property + def _any_axis_auto_or_manual(self) -> bool: + if not self._axis_types: + return False + return any(t == AxisType.Auto or t == AxisType.Manual + for t in self._axis_types) + + @functools.cached_property + def auto_axes(self): + return tuple(n for n, t in safe_zip(self.axis_names, self._axis_types) + if t == AxisType.Auto) + + @functools.cached_property + def explicit_axes(self): + return tuple(n for n, t in safe_zip(self.axis_names, self._axis_types) + if t == AxisType.Explicit) + + @functools.cached_property + def manual_axes(self): + return tuple(n for n, t in safe_zip(self.axis_names, self._axis_types) + if t == AxisType.Manual) + @functools.cached_property def _axis_types_dict(self): if not self.axis_names: @@ -194,16 +220,9 @@ def _name_to_type(self): class Mesh(_BaseMesh, contextlib.ContextDecorator): """Declare the hardware resources available in the scope of this manager. - In particular, all ``axis_names`` become valid resource names inside the - managed block and can be used e.g. in the ``in_axis_resources`` argument of - :py:func:`jax.experimental.pjit.pjit`. Also see JAX's multi-process programming - model (https://jax.readthedocs.io/en/latest/multi_process.html) - and the Distributed arrays and automatic parallelization tutorial - (https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html) - - If you are compiling in multiple threads, make sure that the - ``with Mesh`` context manager is inside the function that the threads will - execute. 
+ See the Distributed arrays and automatic parallelization tutorial + (https://docs.jax.dev/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html) + and Explicit sharding tutorial (https://docs.jax.dev/en/latest/notebooks/explicit-sharding.html) Args: devices: A NumPy ndarray object containing JAX device objects (as @@ -211,35 +230,24 @@ class Mesh(_BaseMesh, contextlib.ContextDecorator): axis_names: A sequence of resource axis names to be assigned to the dimensions of the ``devices`` argument. Its length should match the rank of ``devices``. + axis_types: and optional tuple of :class:`jax.sharding.AxisType` entries corresponding to + the ``axis_names``. See `Explicit Sharding`_ for more information. Examples: - >>> from jax.experimental.pjit import pjit >>> from jax.sharding import Mesh - >>> from jax.sharding import PartitionSpec as P + >>> from jax.sharding import PartitionSpec as P, NamedSharding >>> import numpy as np ... - >>> inp = np.arange(16).reshape((8, 2)) - >>> devices = np.array(jax.devices()).reshape(4, 2) - ... >>> # Declare a 2D mesh with axes `x` and `y`. - >>> global_mesh = Mesh(devices, ('x', 'y')) - >>> # Use the mesh object directly as a context manager. - >>> with global_mesh: - ... out = pjit(lambda x: x, in_shardings=None, out_shardings=None)(inp) - - >>> # Initialize the Mesh and use the mesh as the context manager. - >>> with Mesh(devices, ('x', 'y')) as global_mesh: - ... out = pjit(lambda x: x, in_shardings=None, out_shardings=None)(inp) - - >>> # Also you can use it as `with ... as ...`. - >>> global_mesh = Mesh(devices, ('x', 'y')) - >>> with global_mesh as m: - ... out = pjit(lambda x: x, in_shardings=None, out_shardings=None)(inp) - - >>> # You can also use it as `with Mesh(...)`. - >>> with Mesh(devices, ('x', 'y')): - ... out = pjit(lambda x: x, in_shardings=None, out_shardings=None)(inp) + >>> devices = np.array(jax.devices()).reshape(4, 2) + >>> mesh = Mesh(devices, ('x', 'y')) + >>> inp = np.arange(16).reshape(8, 2) + >>> arr = jax.device_put(inp, NamedSharding(mesh, P('x', 'y'))) + >>> out = jax.jit(lambda x: x * 2)(arr) + >>> assert out.sharding == NamedSharding(mesh, P('x', 'y')) + + .. _Explicit Sharding: https://docs.jax.dev/en/latest/notebooks/explicit-sharding.html """ devices: np.ndarray @@ -263,7 +271,7 @@ def __new__(cls, devices: np.ndarray | Sequence[xc.Device], f"devices.ndim == {devices.ndim} and " f"len(axis_names) == {len(axis_names)}.") - axis_types = _normalize_axis_types(axis_names, axis_types) + axis_types = _normalize_axis_types(axis_names, axis_types, 'Mesh') key = (axis_names, devices.shape, tuple(devices.flat), axis_types) val = _mesh_object_dict.get(key, None) @@ -362,10 +370,6 @@ def empty(self): def is_multi_process(self): return self.devices.size != len(self.local_devices) - @functools.cached_property - def _process_indices(self): - return {d.process_index for d in self._flat_devices_tuple} - @property def local_mesh(self): return self._local_mesh(xb.process_index()) @@ -440,6 +444,16 @@ class AbstractMesh(_BaseMesh): your mesh shape and axis names stay the same but the devices change. See the description of https://github.com/jax-ml/jax/pull/23022 for more details. + + Args: + axis_sizes: A tuple of integers specifying the size of each resource axis. + axis_names: A tuple of resource axis names to be assigned to the + dimensions of the ``devices`` argument. Its length should match the + rank of ``devices``. 
+ axis_types: and optional tuple of :class:`jax.sharding.AxisType` entries corresponding to + the ``axis_names``. See `Explicit Sharding`_ for more information. + + .. _Explicit Sharding: https://docs.jax.dev/en/latest/notebooks/explicit-sharding.html """ def __init__(self, axis_sizes: tuple[int, ...], axis_names: tuple[str, ...], @@ -447,7 +461,8 @@ def __init__(self, axis_sizes: tuple[int, ...], axis_names: tuple[str, ...], self.axis_sizes = axis_sizes self.axis_names = axis_names self._size = math.prod(self.axis_sizes) if self.axis_sizes else 0 - self._axis_types = _normalize_axis_types(self.axis_names, axis_types) + self._axis_types = _normalize_axis_types( + self.axis_names, axis_types, 'AbstractMesh') self._hash = hash((self.axis_sizes, self.axis_names, self._axis_types)) def __hash__(self): diff --git a/jax/_src/mesh_utils.py b/jax/_src/mesh_utils.py index ccc75af8c84f..fdb8e10d598e 100644 --- a/jax/_src/mesh_utils.py +++ b/jax/_src/mesh_utils.py @@ -1,4 +1,4 @@ -# Copyright 2021 The TensorFlow Authors. All Rights Reserved. +# Copyright 2021 The JAX Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -34,6 +34,7 @@ _TPU_V5_LITE = "TPU v5 lite" _TPU_V5E = "TPU v5e" _TPU_V5P = "TPU v5p" +_TPU_V6_LITE = "TPU v6 lite" # Maps physical topology -> mesh shape -> transpose to use for jekbradbury's # famous contiguous mesh trick. @@ -190,6 +191,7 @@ def _v5p_create_device_mesh( _TPU_V3: _tpu_v2_v3_create_device_mesh, _TPU_V5_LITE: _v5e_create_device_mesh, _TPU_V5P: _v5p_create_device_mesh, + _TPU_V6_LITE: _v5e_create_device_mesh, } diff --git a/jax/_src/monitoring.py b/jax/_src/monitoring.py index 99e957733ba2..de706ccbaef5 100644 --- a/jax/_src/monitoring.py +++ b/jax/_src/monitoring.py @@ -46,10 +46,18 @@ def __call__( ) -> None: ... +class ScalarListenerWithMetadata(Protocol): + + def __call__( + self, event: str, value: float | int, **kwargs: str | int, + ) -> None: + ... 
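A hedged usage sketch of the scalar-monitoring hooks added in this monitoring.py hunk; the event name, metadata, and listener body are illustrative, and the import uses the internal module shown here.

from jax._src import monitoring

def on_scalar(event: str, value, **kwargs):
  print(f"{event} = {value} ({kwargs})")

monitoring.register_scalar_listener(on_scalar)
monitoring.record_scalar("/jax/example/peak_memory_mb", 123.0, backend="cpu")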
+ _event_listeners: list[EventListenerWithMetadata] = [] _event_duration_secs_listeners: list[EventDurationListenerWithMetadata] = [] _event_time_span_listeners: list[EventTimeSpanListenerWithMetadata] = [] +_scalar_listeners: list[ScalarListenerWithMetadata] = [] def record_event(event: str, **kwargs: str | int) -> None: @@ -81,6 +89,14 @@ def record_event_time_span( callback(event, start_time, end_time, **kwargs) +def record_scalar( + event: str, value: float | int, **kwargs: str | int +) -> None: + """Record a scalar summary value.""" + for callback in _scalar_listeners: + callback(event, value, **kwargs) + + def register_event_listener( callback: EventListenerWithMetadata, ) -> None: @@ -100,6 +116,14 @@ def register_event_duration_secs_listener( """Register a callback to be invoked during record_event_duration_secs().""" _event_duration_secs_listeners.append(callback) + +def register_scalar_listener( + callback : ScalarListenerWithMetadata, +) -> None: + """Register a callback to be invoked during record_scalar().""" + _scalar_listeners.append(callback) + + def get_event_duration_listeners() -> list[EventDurationListenerWithMetadata]: """Get event duration listeners.""" return list(_event_duration_secs_listeners) @@ -114,12 +138,20 @@ def get_event_listeners() -> list[EventListenerWithMetadata]: """Get event listeners.""" return list(_event_listeners) + +def get_scalar_listeners() -> list[ScalarListenerWithMetadata]: + """Get scalar event listeners.""" + return list(_scalar_listeners) + + def clear_event_listeners(): """Clear event listeners.""" global _event_listeners, _event_duration_secs_listeners, _event_time_span_listeners _event_listeners = [] _event_duration_secs_listeners = [] _event_time_span_listeners = [] + _scalar_listeners = [] + def _unregister_event_duration_listener_by_callback( callback: EventDurationListenerWithMetadata) -> None: @@ -159,3 +191,14 @@ def _unregister_event_listener_by_callback( """ assert callback in _event_listeners _event_listeners.remove(callback) + + +def _unregister_scalar_listener_by_callback( + callback: ScalarListenerWithMetadata, +) -> None: + """Unregister a scalar event listener by callback. + + This function is supposed to be called for testing only. + """ + assert callback in _scalar_listeners + _scalar_listeners.remove(callback) diff --git a/jax/_src/named_sharding.py b/jax/_src/named_sharding.py index 5accdd880a79..1b0ae46a968b 100644 --- a/jax/_src/named_sharding.py +++ b/jax/_src/named_sharding.py @@ -21,13 +21,13 @@ from typing import Any, Union from jax._src import config -from jax._src.util import use_cpp_class, cache, use_cpp_method, tuple_insert +from jax._src.util import use_cpp_class, cache, use_cpp_method from jax._src.lib import xla_client as xc from jax._src.lib.mlir.dialects import sdy from jax._src import mesh as mesh_lib -from jax._src.partition_spec import PartitionSpec, UnconstrainedSingleton +from jax._src.mesh import AxisType +from jax._src.partition_spec import PartitionSpec from jax._src import sharding as JSharding -from jax._src import xla_bridge as xb import numpy as np Shape = tuple[int, ...] 
@@ -41,10 +41,11 @@ class AUTO: def __init__(self, mesh: mesh_lib.Mesh): self.mesh = mesh - def _to_sdy_sharding(self, ndim: int) -> SdyArraySharding: - dim_shardings = [SdyDimSharding(axes=[], is_closed=False) + def _to_sdy_sharding(self, ndim: int) -> SdyArray: + dim_shardings = [SdyDim(axes=[], is_open=True) for _ in range(ndim)] - return SdyArraySharding(self.mesh.shape_tuple, dim_shardings) + return SdyArray(mesh_shape=self.mesh.shape_tuple, + dim_shardings=dim_shardings) class UnspecifiedValue: def __repr__(self): @@ -73,6 +74,11 @@ def __repr__(self): ArrayMappingOrAutoOrUnspecified = Union[ArrayMapping, AUTO, UnspecifiedValue] +def _unpickle_named_sharding(mesh, spec, memory_kind, logical_device_ids): + return NamedSharding(mesh, spec, memory_kind=memory_kind, + _logical_device_ids=logical_device_ids) + + @use_cpp_class(xc.NamedSharding) class NamedSharding(JSharding.Sharding): r"""A :class:`NamedSharding` expresses sharding using named axes. @@ -92,7 +98,7 @@ class NamedSharding(JSharding.Sharding): across ``y`` axis of the mesh. The Distributed arrays and automatic parallelization - (https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html#namedsharding-gives-a-way-to-express-shardings-with-names) + (https://docs.jax.dev/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html#namedsharding-gives-a-way-to-express-shardings-with-names) tutorial has more details and diagrams that explain how :class:`Mesh` and :class:`PartitionSpec` are used. @@ -112,20 +118,17 @@ class NamedSharding(JSharding.Sharding): mesh: mesh_lib.Mesh | mesh_lib.AbstractMesh spec: PartitionSpec _memory_kind: str | None - _manual_axes: frozenset[MeshAxisName] _logical_device_ids: tuple[int, ...] | None @use_cpp_method() def __init__( self, mesh: mesh_lib.Mesh | mesh_lib.AbstractMesh, spec: PartitionSpec, *, - memory_kind: str | None = None, _manual_axes=frozenset(), - _logical_device_ids=None): + memory_kind: str | None = None, _logical_device_ids=None): self.mesh = mesh self.spec = spec self._memory_kind = memory_kind - self._manual_axes = _manual_axes self._logical_device_ids = _logical_device_ids - check_pspec(self.mesh, self.spec, self._manual_axes) + check_pspec(self.mesh, self.spec) def __repr__(self): mem = '' if self.memory_kind is None else f', memory_kind={self.memory_kind}' @@ -135,22 +138,21 @@ def __repr__(self): return f'NamedSharding(mesh={mesh_repr}, spec={self.spec}{mem}{ldi})' def __reduce__(self): - return (type(self), (self.mesh, self.spec), - {'memory_kind': self.memory_kind, - '_manual_axes': self._manual_axes, - '_logical_device_ids': self._logical_device_ids}) + return (_unpickle_named_sharding, + (self.mesh, self.spec, self.memory_kind, self._logical_device_ids)) @property def memory_kind(self) -> str | None: return self._memory_kind + @use_cpp_method() def __hash__(self): if not hasattr(self, '_hash'): self._hash = hash( - (self.mesh, self.memory_kind, self.spec, self._manual_axes, - self._logical_device_ids)) + (self.mesh, self.memory_kind, self.spec, self._logical_device_ids)) return self._hash + @use_cpp_method() def __eq__(self, other): if not isinstance(other, NamedSharding): return False @@ -158,7 +160,6 @@ def __eq__(self, other): return True if (self.spec != other.spec or self.memory_kind != other.memory_kind - or self._manual_axes != other._manual_axes or self._logical_device_ids != other._logical_device_ids): return False return self.mesh is other.mesh or self.mesh == other.mesh @@ -198,10 +199,7 @@ def 
is_fully_addressable(self) -> bool: # Speed up `is_fully_addressable` since there is a high chance that the # mesh across multiple NamedSharding objects will be the same. if config.enable_empty_arrays.value: - client = self._internal_device_list[0].client - return (len(self.mesh._process_indices) == 1 and - next(iter(self.mesh._process_indices)) == - xb.process_index(client)) + return self._internal_device_list.is_fully_addressable # type: ignore return not self.mesh.is_multi_process @property @@ -231,31 +229,40 @@ def is_fully_replicated(self) -> bool: return num_partitions == 1 def with_memory_kind(self, kind: str) -> NamedSharding: - return NamedSharding(self.mesh, self.spec, memory_kind=kind) + return self.update(memory_kind=kind) - def with_spec(self, spec: PartitionSpec | Sequence[Any]) -> NamedSharding: + def update(self, **kwargs) -> NamedSharding: + spec = kwargs.pop("spec", self.spec) if not isinstance(spec, PartitionSpec): spec = PartitionSpec(*spec) - return NamedSharding(self.mesh, spec, memory_kind=self.memory_kind) + return NamedSharding( + mesh=kwargs.pop("mesh", self.mesh), + spec=spec, + memory_kind=kwargs.pop("memory_kind", self.memory_kind), + _logical_device_ids=kwargs.pop("_logical_device_ids", + self._logical_device_ids)) def _to_xla_hlo_sharding(self, num_dimensions: int) -> xc.HloSharding: return named_sharding_to_xla_hlo_sharding(self, num_dimensions) - def _to_sdy_sharding(self, num_dimensions: int) -> SdyArraySharding: - dim_shardings = [SdyDimSharding(axes=[], is_closed=True) + def _to_sdy_sharding(self, num_dimensions: int) -> SdyArray: + dim_shardings = [SdyDim(axes=[], is_open=False) for _ in range(num_dimensions)] for i, dim_spec in enumerate(self.spec): if dim_spec is PartitionSpec.UNCONSTRAINED: - dim_shardings[i].is_closed = False + dim_shardings[i].is_open = True elif dim_spec is None: # Already empty and closed sharding. pass else: dim_spec = dim_spec if isinstance(dim_spec, tuple) else (dim_spec,) dim_shardings[i].axes = dim_spec - return SdyArraySharding(self.mesh.shape_tuple, dim_shardings, - self._logical_device_ids) + return SdyArray(mesh_shape=self.mesh.shape_tuple, + dim_shardings=dim_shardings, + logical_device_ids=self._logical_device_ids, + unreduced_axes=self.spec.unreduced) +NamedSharding.__module__ = 'jax.sharding' def get_array_mapping( axis_resources: PartitionSpec | AUTO | UnspecifiedValue @@ -272,35 +279,42 @@ def get_array_mapping( return d @dataclasses.dataclass -class SdyDimSharding: +class SdyDim: axes: Sequence[str] - is_closed: bool + is_open: bool priority: int | None = None def build(self) -> sdy.DimensionShardingAttr: return sdy.DimensionShardingAttr.get( [sdy.AxisRefAttr.get(axis) for axis in self.axes], - is_closed=self.is_closed, - priority=self.priority) + is_closed=not self.is_open, priority=self.priority) def __repr__(self): - return f'SdyDimSharding({self._custom_repr()})' + return f'SdyDim({self._custom_repr()})' def _custom_repr(self): axes_repr = ', '.join(f"'{a}'" for a in self.axes) open_repr = '' - if not self.is_closed: + if self.is_open: open_repr = ', ?' if self.axes else '?' priority_repr = '' if self.priority is None else f'p{self.priority}' return f'{{{axes_repr}{open_repr}}}{priority_repr}' +def _get_axes(axes, mesh_shape): + if not axes: + return () + assert mesh_shape is not None + # Sort wrt mesh axis names so order is deterministic and doesn't hang in + # McJAX. 
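A usage sketch of the new NamedSharding.update method shown above (it replaces with_spec and now backs with_memory_kind). The single-axis mesh below is purely illustrative, and which memory kinds are valid depends on the backend.

import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), ('x',))
s = NamedSharding(mesh, P('x'))
# update() returns a copy with only the given fields overridden; mesh,
# memory_kind and _logical_device_ids are otherwise carried over.
replicated = s.update(spec=P(None))
assert replicated.mesh is s.mesh and replicated.spec == P(None)
# with_memory_kind(kind) is now a thin wrapper around update(memory_kind=kind).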
+ return tuple(n for n, _ in mesh_shape if n in axes) -@dataclasses.dataclass -class SdyArraySharding: +@dataclasses.dataclass(kw_only=True) +class SdyArray: mesh_shape: tuple[tuple[str, int], ...] | None - dimension_shardings: Sequence[SdyDimSharding] + dim_shardings: Sequence[SdyDim] logical_device_ids: tuple[int, ...] | None = None replicated_axes: tuple[str, ...] = () + unreduced_axes: frozenset[str] = frozenset() def build(self) -> sdy.TensorShardingAttr: if self.mesh_shape is None: @@ -311,94 +325,44 @@ def build(self) -> sdy.TensorShardingAttr: mesh_attr = sdy.MeshAttr.get( [sdy.MeshAxisAttr.get(name, size) for name, size in self.mesh_shape], ldi) + + replicated_axes = _get_axes(self.replicated_axes, self.mesh_shape) + unreduced_axes = _get_axes(self.unreduced_axes, self.mesh_shape) return sdy.TensorShardingAttr.get( mesh_attr, - [dim_sharding.build() for dim_sharding in self.dimension_shardings], - replicated_axes=[sdy.AxisRefAttr.get(axis) for axis in self.replicated_axes]) + [dim_sharding.build() for dim_sharding in self.dim_shardings], + replicated_axes=[sdy.AxisRefAttr.get(axis) for axis in replicated_axes], + unreduced_axes=[sdy.AxisRefAttr.get(axis) for axis in unreduced_axes]) def __repr__(self): dim_sharding_repr = ', '.join( - d._custom_repr() for d in self.dimension_shardings) + d._custom_repr() for d in self.dim_shardings) device_id_repr = (f', device_ids={self.logical_device_ids}' if self.logical_device_ids is not None else '') rar = (f', replicated_axes={self.replicated_axes}' if self.replicated_axes else '') - return f"SdyArraySharding([{dim_sharding_repr}]{device_id_repr}{rar})" - -# TODO(yashkatariya): Remove this after jax 0.5.2 release -class ParsedPartitionSpec: - __slots__ = ('_user_spec', 'partitions') - - _user_spec: PartitionSpec | None - partitions: tuple[tuple[MeshAxisName, ...] | UnconstrainedSingleton, ...] + return f"SdyArray([{dim_sharding_repr}]{device_id_repr}{rar})" + + +# TODO(yashkatariya): Upstream this into `_to_sdy_sharding` maybe with an extra +# parameter to it `_to_sdy_sharding(self, ndim, modify_wrt_axis_types=False)` +def modify_sdy_sharding_wrt_axis_types(sdy_sharding: SdyArray, mesh): + if mesh._any_axis_auto: + dim_shardings, used_axes = [], [] # type: ignore + for d in sdy_sharding.dim_shardings: + # TODO(yashkatariya): Maybe if any mesh axis is auto, mark all axes as open? 
+ dim_shardings.append(SdyDim(axes=[], is_open=True) + if not d.axes and not d.is_open else d) + used_axes.extend(d.axes) + remaining_axes = set(mesh.axis_names) - set(used_axes) + replicated_axes = tuple(r for r in remaining_axes + if mesh._name_to_type[r] == mesh_lib.AxisType.Explicit) + return SdyArray(mesh_shape=sdy_sharding.mesh_shape, + dim_shardings=dim_shardings, + logical_device_ids=sdy_sharding.logical_device_ids, + replicated_axes=replicated_axes) + return sdy_sharding - def __init__(self, user_spec, partitions): - self._user_spec = user_spec - assert None not in partitions, partitions - self.partitions = tuple(partitions) - - def get_partition_spec(self) -> PartitionSpec: - if isinstance(self._user_spec, PartitionSpec): - return self._user_spec - else: - return get_single_pspec(self) - - def insert_axis_partitions(self, dim, val): - parts = self.partitions - too_short = dim - len(parts) - if too_short > 0: - parts += ((),) * too_short - new_partitions = tuple_insert(parts, dim, val) - return ParsedPartitionSpec(None, new_partitions) - - @classmethod - def from_user_input( - cls, - entry: PartitionSpec | None, - arg_name: str, - allow_unconstrained_dims: bool = False, - ) -> ParsedPartitionSpec: - if entry is None: - return cls(entry, ()) - if not isinstance(entry, PartitionSpec): - raise TypeError(f"{arg_name} are expected to be " - f"PartitionSpec instances or None, but got {entry}") - axis_specs = [] - for axis_spec in entry: - if axis_spec is None: - axis_spec = () - elif isinstance(axis_spec, (list, tuple)): - axis_spec = tuple(axis_spec) - elif axis_spec is PartitionSpec.UNCONSTRAINED: - if not allow_unconstrained_dims: - raise ValueError(f"Unconstrained dims are not allowed: {entry}") - axis_spec = PartitionSpec.UNCONSTRAINED - else: - axis_spec = (axis_spec,) - axis_specs.append(axis_spec) - new_entry = PartitionSpec( - *[tuple(e) if isinstance(e, (list, tuple)) else e for e in entry]) - return cls(new_entry, axis_specs) - - def __hash__(self): - return hash(self.partitions) - - def __eq__(self, other): - if not isinstance(other, ParsedPartitionSpec): - return False - return self.partitions == other.partitions - - def __len__(self): - return len(self.partitions) - - def __getitem__(self, i): - return self.partitions[i] - - def __iter__(self): - return iter(self.partitions) - - def __repr__(self): - return f"ParsedPartitionSpec(partitions={self.partitions})" @cache(max_size=4096, trace_context_in_key=False) def named_sharding_to_xla_hlo_sharding( @@ -408,9 +372,7 @@ def named_sharding_to_xla_hlo_sharding( mesh_axis_pos = {name: i for i, name in enumerate(self.mesh.axis_names)} special_axes = {} - mesh_manual_axes = {n for n, t in self.mesh._name_to_type.items() - if t == mesh_lib.AxisType.Manual} - manual_axes = self._manual_axes.union(mesh_manual_axes) + manual_axes = frozenset(self.mesh.manual_axes) if manual_axes: axis_names = self.mesh.axis_names for manual_axis in manual_axes: @@ -432,7 +394,7 @@ def named_sharding_to_xla_hlo_sharding( last_tile_dims = [] if replicated_mesh_axes: - axes_by_type = collections.defaultdict(list) + axes_by_type: dict[Any, list[int]] = collections.defaultdict(list) size_by_type = collections.defaultdict(lambda: 1) # type: ignore assert {x[0] for x in replicated_mesh_axes}.issuperset(set(special_axes.keys())) for i, size in replicated_mesh_axes: @@ -491,21 +453,12 @@ def array_mapping_to_axis_resources(array_mapping: ArrayMapping): partitions.append(None) return PartitionSpec(*partitions) -get_single_pspec = lambda p: 
array_mapping_to_axis_resources(get_array_mapping(p)) # type: ignore - -# TODO(yashkatariya): Remove this after jax 0.5.2 release -def preprocess(mesh, spec, parsed_pspec, _manual_axes=frozenset()): - if parsed_pspec is None: - spec = PartitionSpec() if spec is None else spec - parsed_pspec = ParsedPartitionSpec.from_user_input( - spec, "NamedSharding spec", allow_unconstrained_dims=True) - _check_unique_resources(parsed_pspec, "NamedSharding spec", mesh) - _check_mesh_resource_axis(mesh, parsed_pspec, _manual_axes) - return parsed_pspec +@cache(max_size=128, trace_context_in_key=False) def check_pspec(mesh, spec, _manual_axes=frozenset()): _check_unique_resources(spec, "NamedSharding spec", mesh) - _check_mesh_resource_axis(mesh, spec, _manual_axes) + _check_mesh_resource_axis(mesh, spec) + _check_mesh_unreduced(mesh, spec) class DuplicateSpecError(Exception): def __init__(self, message, mesh, pspec): @@ -517,13 +470,10 @@ def __init__(self, message, mesh, pspec): def __str__(self): return f"{self.message}" -def _check_unique_resources( - pspec: ParsedPartitionSpec | PartitionSpec, arg_name: str, mesh=None, -) -> None: +def _check_unique_resources(pspec: PartitionSpec, arg_name: str, mesh=None + ) -> None: resource_counts: dict[MeshAxisName, int] = {} duplicate = False - pspec = (pspec.get_partition_spec() if isinstance(pspec, ParsedPartitionSpec) - else pspec) for d in pspec: if d is PartitionSpec.UNCONSTRAINED or d is None: continue @@ -542,31 +492,47 @@ def _check_unique_resources( f' for {mesh_lib.show_axes(multiple_uses)}'), mesh=mesh, pspec=pspec) -@cache(max_size=128, trace_context_in_key=False) -def _check_mesh_resource_axis(mesh, pspec, _manual_axes): - pspec = (pspec.get_partition_spec() if isinstance(pspec, ParsedPartitionSpec) - else pspec) +def _check_mesh_resource_axis(mesh, pspec): for p in pspec: if p is PartitionSpec.UNCONSTRAINED or p is None: continue p = p if isinstance(p, tuple) else (p,) for r in p: - if r not in mesh.shape: + if r not in mesh.axis_names: raise ValueError( f"Resource axis: {r} of {pspec} " f"is not found in mesh: {tuple(mesh.shape.keys())}.") - if r in _manual_axes: - raise ValueError( - f"Axis: {r} of {pspec} " - f"is also found in manual_axes: {_manual_axes}.") from None if not all(mesh._name_to_type[p[0]] == mesh._name_to_type[r] for r in p): raise ValueError( 'AxisTypes should be the same in a tuple subset of PartitionSpec:' f' {pspec}. Got subset {p} with axis' f' types: ({", ".join(str(mesh._name_to_type[r]) for r in p)})') - if (mesh_lib.AxisType.Auto not in mesh._axis_types_dict and + if (AxisType.Auto not in mesh._axis_types_dict and PartitionSpec.UNCONSTRAINED in pspec): raise ValueError( f'{pspec} cannot contain' ' `P.UNCONSTRAINED` when no mesh axis_types are `Auto`. Got mesh' f' axis_types: {mesh._axis_types_dict}') + +def _check_mesh_unreduced(mesh, pspec): + for u in pspec.unreduced: + if u not in mesh.axis_names: + raise ValueError( + f'Unreduced axes {u} is not found in {mesh.axis_names=}. ' + f'Got {pspec=}') + if mesh._name_to_type[u] in (AxisType.Auto, AxisType.Manual): + raise ValueError( + 'Unreduced axes can only refer to mesh axes that is of type' + f' `Explicit`. Got unreduced axes: {pspec.unreduced} and' + f' mesh: {mesh}') + + for u in pspec.reduced: + if u not in mesh.axis_names: + raise ValueError( + f'Reduced axes {u} is not found in {mesh.axis_names=}. 
' + f'Got {pspec=}') + if mesh._name_to_type[u] in (AxisType.Auto, AxisType.Manual): + raise ValueError( + 'Reduced axes can only refer to mesh axes that is of type' + f' `Explicit`. Got reduced axes: {pspec.reduced} and' + f' mesh: {mesh}') diff --git a/jax/_src/nn/functions.py b/jax/_src/nn/functions.py index 7df0a638e566..db960e842403 100644 --- a/jax/_src/nn/functions.py +++ b/jax/_src/nn/functions.py @@ -21,7 +21,8 @@ import operator import math import numpy as np -from typing import Any, List, Literal +from typing import Any, Literal +import warnings import jax import jax.numpy as jnp @@ -47,13 +48,26 @@ from jax._src.ops.special import logsumexp as _logsumexp -class Unspecified: - def __repr__(self): - return "_UNSPECIFIED" -_UNSPECIFIED = Unspecified() +# activations +@jax.jit +def identity(x: ArrayLike) -> Array: + r"""Identity activation function. + Returns the argument unmodified. -# activations + Args: + x : input array + + Returns: + The argument `x` unmodified. + + Examples: + >>> jax.nn.identity(jax.numpy.array([-2., -1., -0.5, 0, 0.5, 1., 2.])) + Array([-2. , -1. , -0.5, 0. , 0.5, 1. , 2. ], dtype=float32) + + """ + numpy_util.check_arraylike("identity", x) + return jnp.asarray(x) @custom_jvp @jax.jit @@ -505,8 +519,7 @@ def glu(x: ArrayLike, axis: int = -1) -> Array: @partial(jax.jit, static_argnames=("axis",)) def log_softmax(x: ArrayLike, axis: int | tuple[int, ...] | None = -1, - where: ArrayLike | None = None, - initial: Unspecified = _UNSPECIFIED) -> Array: + where: ArrayLike | None = None) -> Array: r"""Log-Softmax function. Computes the logarithm of the :code:`softmax` function, which rescales @@ -532,10 +545,6 @@ def log_softmax(x: ArrayLike, See also: :func:`softmax` """ - # TODO(jakevdp): remove the initial argument after JAX v0.4.40. - if initial is not _UNSPECIFIED: - raise TypeError("The initial argument to jax.nn.log_softmax was removed in JAX v0.4.36.") - del initial numpy_util.check_arraylike("log_softmax", x) x_arr = jnp.asarray(x) x_max = jnp.max(x_arr, axis, where=where, initial=-jnp.inf, keepdims=True) @@ -553,8 +562,7 @@ def log_softmax(x: ArrayLike, # @partial(jax.jit, static_argnames=("axis",)) def softmax(x: ArrayLike, axis: int | tuple[int, ...] | None = -1, - where: ArrayLike | None = None, - initial: Unspecified = _UNSPECIFIED) -> Array: + where: ArrayLike | None = None) -> Array: r"""Softmax function. Computes the function which rescales elements to the range :math:`[0, 1]` @@ -580,10 +588,6 @@ def softmax(x: ArrayLike, See also: :func:`log_softmax` """ - # TODO(jakevdp): remove the initial argument after JAX v0.4.40. - if initial is not _UNSPECIFIED: - raise TypeError("The initial argument to jax.nn.softmax was removed in JAX v0.4.36.") - del initial if config.softmax_custom_jvp.value: # mypy is confused by the `functools.partial` application in the definition # of `_softmax` and incorrectly concludes that `_softmax` returns @@ -629,12 +633,38 @@ def _softmax_deprecated( @partial(jax.jit, static_argnames=("axis",)) def standardize(x: ArrayLike, - axis: int | tuple[int, ...] | None = -1, - mean: ArrayLike | None = None, - variance: ArrayLike | None = None, - epsilon: ArrayLike = 1e-5, - where: ArrayLike | None = None) -> Array: - r"""Normalizes an array by subtracting ``mean`` and dividing by :math:`\sqrt{\mathrm{variance}}`.""" + axis: int | tuple[int, ...] 
| None = -1, + mean: ArrayLike | None = None, + variance: ArrayLike | None = None, + epsilon: ArrayLike = 1e-5, + where: ArrayLike | None = None) -> Array: + r"""Standardizes input to zero mean and unit variance. + + The standardization is given by: + + .. math:: + + x_{std} = \frac{x - \langle x\rangle}{\sqrt{\langle(x - \langle x\rangle)^2\rangle + \epsilon}} + + where :math:`\langle x\rangle` indicates the mean of :math:`x`, and :math:`\epsilon` is + a small correction factor introduced to avoid division by zero. + + Args: + x: input array to be standardized. + axis: integer or tuple of integers representing the axes along which + to standardize. Defaults to the last axis (``-1``). + mean: optionally specify the mean used for standardization. If not specified, + then ``x.mean(axis, where=where)`` will be used. + variance: optionally specify the variance used for standardization. If not + specified, then ``x.var(axis, where=where)`` will be used. + epsilon: correction factor added to variance to avoid division by zero; defaults + to ``1E-5``. + where: optional boolean mask specifying which elements to use when computing + the mean and variance. + + Returns: + An array of the same shape as ``x`` containing the standardized input. + """ numpy_util.check_arraylike("standardize", x) numpy_util.check_arraylike_or_none("standardize", mean, variance, where) if mean is None: @@ -657,9 +687,9 @@ def _one_hot(x: Array, num_classes: int, *, "The error arose in jax.nn.one_hot argument `num_classes`.") dtype = dtypes.canonicalize_dtype(dtype) try: - output_pos_axis = util.canonicalize_axis(axis, x.ndim + 1) + output_pos_axis = util.canonicalize_axis(axis, x.ndim + 1) # type: ignore[arg-type] except TypeError: - axis_size = lax.psum(1, axis) + axis_size = lax.axis_size(axis) if num_classes != axis_size: raise ValueError(f"Expected num_classes to match the size of axis {axis}, " f"but {num_classes} != {axis_size}") from None @@ -1073,7 +1103,7 @@ def dot_product_attention( token's local window. If set, this specifies the (left_window_size, right_window_size) for each token. E.g., if local_window_size == (3, 2) and the sequence is [0, 1, 2, 3, 4, 5, c, 7, 8, 9], token `c` can attend - to [3, 4, 5, c, 7, 8]. If a single int is given, it will be intepreted as + to [3, 4, 5, c, 7, 8]. If a single int is given, it will be interpreted as a symmetric window (window_size, window_size). implementation: A string to control which implementation backend to use. Supported strings are `xla`, `cudnn` (cuDNN flash attention). It defaults @@ -1197,101 +1227,219 @@ def scaled_matmul( rhs_scales: Array, preferred_element_type: DTypeLike = jnp.float32, ) -> Array: - r""" - Performs scaled matrix multiplication between two 3D arrays, with scaling - factors applied to the matrices. - .. math:: - \mathrm{ScaledMatmul}(lhs, rhs, lhs_scales, rhs_scales)=lhs_scales \cdot rhs_scales \cdot \mathrm{dot}(lhs, rhs) + r"""Scaled matrix multiplication function. + + Performs block-scaled matmul of `a` and `b` using `a_scales` and `b_scales`. + The last dim is the contracting dim, and block size is inferred. + + Mathematically, this operation is equivalent to:: + + a_block_size = a.shape[-1] // a_scales.shape[-1] + b_block_size = b.shape[-1] // b_scales.shape[-1] + a_scaled = a * jnp.repeat(a_scales, a_block_size, axis=-1) + b_scaled = b * jnp.repeat(b_scales, b_block_size, axis=-1) + jnp.einsum('BMK,BNK->BMN', a_scaled, b_scaled) + Args: - lhs (Array): A 3D array of shape (B, M, K). - rhs (Array): A 3D array of shape (B, N, K). 
- lhs_scales (Array): A 3D array of shape (B, M, K_block). - rhs_scales (Array): A 3D array of shape (B, N, K_block). - preferred_element_type (DTypeLike, optional): The preferred data type - for the computation. Defaults to `jnp.float32`. + lhs (Array): Operand a, shape (B, M, K). + rhs (Array): Operand b, shape (B, N, K). + lhs_scales (Array): Shape (B, M, K_a), where `K % K_a == 0`. + rhs_scales (Array): Shape (B, N, K_b), where `K % K_b == 0`. + preferred_element_type (DTypeLike, optional): Defaults to `jnp.float32`. + Returns: - Array: A 3D array of shape (B, M, N) representing the scaled matrix - multiplication result. - Raises: - AssertionError: If the number of columns in `lhs` (`lhs_K`) does not - match the number of columns in `rhs` (`rhs_K`). + Array of shape (B, M, N). + Notes: - - The function ensures that the `preferred_element_type` is - danonicalized before passing it to the underlying computation. - - Scaling is applied to the matrices based on the `lhs_scales` and - `rhs_scales` arrays, enabling efficient computations in blocks. + - We currently do not support user-defined `precision` for customizing the + compute data type. It is fixed to `jnp.float32`. + - Block size is inferred as `K // K_a` for `a` and `K // K_b` for `b`. + - To use cuDNN with Nvidia Blackwell GPUs, inputs must match:: + + # mxfp8 + a, b: jnp.float8_e4m3fn | jnp.float8_e5m2 + a_scales, b_scales: jnp.float8_e8m0fnu + block_size: 32 + # nvfp4 + a, b: jnp.float4_e2m1fn + a_scales, b_scales: jnp.float8_e4m3fn + block_size: 16 + + Examples: + + Basic case: + + >>> a = jnp.array([1, 2, 3]).reshape((1, 1, 3)) + >>> b = jnp.array([4, 5, 6]).reshape((1, 1, 3)) + >>> a_scales = jnp.array([0.5]).reshape((1, 1, 1)) + >>> b_scales = jnp.array([0.5]).reshape((1, 1, 1)) + >>> scaled_matmul(a, b, a_scales, b_scales) # doctest: +SKIP + Array([[[8.]]], dtype=float32) + + Using fused cuDNN call on Blackwell GPUs: + + >>> dtype = jnp.float8_e4m3fn + >>> a = jax.random.normal(jax.random.PRNGKey(1), (3, 128, 64), dtype=dtype) + >>> b = jax.random.normal(jax.random.PRNGKey(2), (3, 128, 64), dtype=dtype) + >>> a_scales = jnp.ones((3, 128, 4), dtype=jnp.float8_e8m0fnu) + >>> b_scales = jnp.ones((3, 128, 4), dtype=jnp.float8_e8m0fnu) + >>> scaled_matmul(a, b, a_scales, b_scales) # doctest: +SKIP """ - B, M, lhs_K = lhs.shape - _, N, rhs_K = rhs.shape - assert lhs_K == rhs_K - _, _, K_block = lhs_scales.shape + a, b, a_scales, b_scales = lhs, rhs, lhs_scales, rhs_scales + if not all(x.ndim == 3 for x in (a, b, a_scales, b_scales)): + raise ValueError( + "scaled_matmul requires all inputs to be 3-dimensional arrays" + ) + + B_a, M_a, K_a = a.shape + B_b, N_b, K_b = b.shape + if K_a != K_b or B_a != B_b: + raise ValueError( + "scaled_matmul requires inputs a and b to have matching batch (B) " + f"and contract (K) dimensions, but got shapes {a.shape} and " + f"{b.shape}" + ) + + B_as, M_as, K_as = a_scales.shape + B_bs, N_bs, K_bs = b_scales.shape + if K_as != K_bs or B_as != B_bs: + raise ValueError( + "scaled_matmul requires scales to have matching batch (B) and " + f"contract (K) dimensions, but got shapes {a_scales.shape} and " + f"{b_scales.shape}" + ) + + if M_as != M_a or N_bs != N_b: + raise ValueError( + "scaled_matmul requires scales to match non-contract dimensions of " + f"inputs, but got shapes a: {a.shape}, b: {b.shape}, a_scales: " + f"{a_scales.shape}, b_scales: {b_scales.shape}" + ) preferred_element_type = dtypes.canonicalize_dtype( np.dtype(preferred_element_type) ) out = cudnn_scaled_matmul( - lhs, - rhs, - 
lhs_scales, - rhs_scales, + a, + b, + a_scales, + b_scales, preferred_element_type=preferred_element_type, ) return out +def get_scaled_dot_general_config(mode: Literal['nvfp4', 'mxfp8'], + global_scale: Array | None = None): + r"""Get quantization configs for scaled_dot_general. + + Create quantization configs for the `jax.nn.scaled_dot_general`. + + See Also: + - :func:`jax.nn.scaled_dot_general`: Scaled dot general function. + """ + + if mode == 'nvfp4': + one = jnp.ones((1,), dtype=jnp.float32) + return BlockScaleConfig( + mode='nvfp4', + block_size=16, + data_type=jnp.float4_e2m1fn, + scale_type=jnp.float8_e4m3fn, + global_scale=one if global_scale is None else global_scale, + infer_only=False + ) + elif mode == 'mxfp8': + return BlockScaleConfig( + mode='mxfp8', + block_size=32, + data_type=jnp.float8_e4m3fn, + scale_type=jnp.float8_e8m0fnu, + global_scale=None, + infer_only=False + ) + else: + raise ValueError(f"Unsupported mode: {mode}") + def scaled_dot_general( lhs, rhs, dimension_numbers, preferred_element_type=jnp.float32, - configs: List[BlockScaleConfig] | None = None, + configs: list[BlockScaleConfig] | None = None, implementation: Literal['cudnn'] | None = None, ): r"""Scaled dot general operation. - Computes the scaled dot general on lhs, rhs with quanitzation specified by configs: - .. math:: - \widehat{lhs}, s_a=\mathrm{quantize}(lhs) \\ - \widehat{rhs}, s_b=\mathrm{quantize}(rhs) \\ - \mathrm{ScaledDot}(lhs, rhs)=s_a \cdot s_b \cdot \mathrm{dot}(\widehat{lhs}, \widehat{rhs}) + + Performs a generalized dot product with block-scaled quantization on the + lhs and rhs inputs. This operation extends `lax.dot_general` to support + user-defined scaling configurations. + + Essentially, the operation follows:: + + a, a_scales = quantize(lhs, configs[0]) + b, b_scales = quantize(rhs, configs[1]) + c = jax.nn.scaled_matmul(a, b, a_scales, b_scales) + Args: - lhs: Left-hand side input tensor. - rhs: Right-hand side input tensor. - dimension_numbers: A tuple specifying the contraction and batch dimensions - for the dot general operation. Must follow the format: - `((lhs_contracting_dims, rhs_contracting_dims), (lhs_batch_dims, rhs_batch_dims))`. - preferred_element_type: The preferred output data type. Supported types are - `jnp.float32`, `jnp.bfloat16`, and `jnp.float16`. Defaults to `jnp.float32`. - configs: A list of `BlockScaleConfig` specifying the scaling - configurations for the operation. Defaults to `mxfp8`. - implementation: A string to control which implementation backend to use. - Supported strings are `cudnn` (cuDNN block scaled dot). It defaults - to `None`, which will automatically select the best available backend. + lhs (ArrayLike): Input array. + rhs (ArrayLike): Input array. + dimension_numbers (DotDimensionNumbers): A tuple of two tuples specifying + the contraction and batch dimensions: + `((lhs_contracting_dims, rhs_contracting_dims), (lhs_batch_dims, rhs_batch_dims))`. + preferred_element_type (DTypeLike, optional): Output data type of the dot + product. Defaults to `jnp.float32`. Other valid types include + `jnp.bfloat16` and `jnp.float16`. + configs (list of BlockScaleConfig, optional): Scaling configurations for + lhs, rhs, and gradients. Users can obtain valid configurations via + `jax.nn.get_scaled_dot_general_config`. Currently, `nvfp4` and `mxfp8` + are supported. If `None`, falls back to `lax.dot_general`. + implementation: str + (Deprecated) Backend selector, now ignored. The system chooses the backend + automatically. 
Scheduled for removal in future releases. + Returns: - The result of the scaled dot general operation. + Array: The resulting tensor, with batch dimensions first, followed by + non-contracting/non-batch dimensions of lhs, and then those of rhs. + + See Also: + - :func:`jax.nn.scaled_matmul`: Scaled matmul function. + - :func:`jax.lax.dot_general`: General dot product operator. + + Notes: + - Unlike `nn.scaled_matmul`, which assumes quantized low-precision + inputs with explicit scaling factors, this operator takes high-precision + inputs, applies quantization internally, and handles the backward pass. + + Examples: + + Creating config for mxfp8: + + >>> configs = [jax.nn.get_scaled_dot_general_config('mxfp8')] * 3 + + Creating config for nvfp4: + + >>> global_scale = jnp.array([0.5], jnp.float32) + >>> configs = [jax.nn.get_scaled_dot_general_config('nvfp4', global_scale)] * 3 + + Using scaled_dot_general with the configs: + + >>> import functools + >>> scaled_dot_general_fn = functools.partial(jax.nn.scaled_dot_general, configs=configs) + >>> lhs = jax.random.normal(jax.random.PRNGKey(1), (3, 128, 64)) + >>> rhs = jax.random.normal(jax.random.PRNGKey(2), (3, 128, 64)) + >>> out = scaled_dot_general_fn(lhs, rhs, (((2,), (2,)), ((0,), (0,)))) # doctest: +SKIP """ - # Create configs if not provided - if configs is None: - if dtypes.float8_e8m0fnu is None: - raise ValueError("Requires >= ml_dtypes 0.5.0 to support float8_e8m0fnu") - mxfp8_config = BlockScaleConfig( - mode='mxfp8', - block_size=32, - data_type=jnp.float8_e4m3fn, - scale_type=jnp.float8_e8m0fnu, - global_scale=None, - infer_only=False - ) - configs = [mxfp8_config for _ in range(3)] + if implementation is not None: + warnings.warn("Backend selector, now ignored. The system chooses the " + "backend automatically.", DeprecationWarning) - if implementation is None: - implementation = 'cudnn' + if configs is None: + return lax.dot_general(lhs, rhs, dimension_numbers, + preferred_element_type=preferred_element_type) - match implementation: - case 'cudnn': - out = cudnn_scaled_dot_general( - lhs, rhs, dimension_numbers, - preferred_element_type=preferred_element_type, - configs=configs - ) - case _: - raise ValueError(f"Unsupported implementation option: {implementation}") + out = cudnn_scaled_dot_general( + lhs, rhs, dimension_numbers, + preferred_element_type=preferred_element_type, + configs=configs + ) return out diff --git a/jax/_src/nn/initializers.py b/jax/_src/nn/initializers.py index 287e8f039e1d..855729fa16ff 100644 --- a/jax/_src/nn/initializers.py +++ b/jax/_src/nn/initializers.py @@ -30,6 +30,7 @@ from jax import random from jax._src import core from jax._src import dtypes +from jax._src.sharding_impls import canonicalize_sharding from jax._src.typing import Array, ArrayLike from jax._src.util import set_module @@ -48,7 +49,8 @@ class Initializer(Protocol): def __call__(self, key: Array, shape: core.Shape, - dtype: DTypeLikeInexact = jnp.float_) -> Array: + dtype: DTypeLikeInexact = jnp.float_, + out_sharding=None) -> Array: raise NotImplementedError @export @@ -100,9 +102,12 @@ def constant(value: ArrayLike, """ def init(key: Array, shape: core.Shape, - dtype: DTypeLikeInexact = dtype) -> Array: + dtype: DTypeLikeInexact = dtype, + out_sharding=None) -> Array: dtype = dtypes.canonicalize_dtype(dtype) - return jnp.full(shape, value, dtype=dtype) + out_sharding = canonicalize_sharding( + out_sharding, 'nn.initializers.constant') + return jnp.full(shape, value, dtype=dtype, device=out_sharding) return init @export @@ 
-126,9 +131,11 @@ def uniform(scale: RealNumeric = 1e-2, """ def init(key: Array, shape: core.Shape, - dtype: DTypeLikeInexact = dtype) -> Array: + dtype: DTypeLikeInexact = dtype, + out_sharding=None) -> Array: dtype = dtypes.canonicalize_dtype(dtype) - return random.uniform(key, shape, dtype) * jnp.array(scale, dtype) + return random.uniform(key, shape, dtype, + out_sharding=out_sharding) * jnp.array(scale, dtype) return init @export @@ -152,9 +159,11 @@ def normal(stddev: RealNumeric = 1e-2, """ def init(key: Array, shape: core.Shape, - dtype: DTypeLikeInexact = dtype) -> Array: + dtype: DTypeLikeInexact = dtype, + out_sharding=None) -> Array: dtype = dtypes.canonicalize_dtype(dtype) - return random.normal(key, shape, dtype) * jnp.array(stddev, dtype) + return random.normal(key, shape, dtype, + out_sharding=out_sharding) * jnp.array(stddev, dtype) return init @export @@ -189,10 +198,12 @@ def truncated_normal(stddev: RealNumeric = 1e-2, def init(key: Array, shape: core.Shape, - dtype: DTypeLikeInexact = dtype) -> Array: + dtype: DTypeLikeInexact = dtype, + out_sharding=None) -> Array: dtype = dtypes.canonicalize_dtype(dtype) return random.truncated_normal( - key, lower, upper, shape, dtype) * jnp.array(stddev, dtype) + key, lower, upper, shape, dtype, + out_sharding=out_sharding) * jnp.array(stddev, dtype) return init @export @@ -267,7 +278,7 @@ def variance_scaling( Literal["uniform"]), in_axis: int | Sequence[int] = -2, out_axis: int | Sequence[int] = -1, - batch_axis: Sequence[int] = (), + batch_axis: int | Sequence[int] = (), dtype: DTypeLikeInexact = jnp.float_ ) -> Initializer: r""" @@ -315,7 +326,8 @@ def variance_scaling( def init(key: Array, shape: core.Shape, - dtype: DTypeLikeInexact = dtype) -> Array: + dtype: DTypeLikeInexact = dtype, + out_sharding=None) -> Array: shape = core.canonicalize_shape(shape) dtype = dtypes.canonicalize_dtype(dtype) fan_in, fan_out = _compute_fans(shape, in_axis, out_axis, batch_axis) @@ -332,16 +344,19 @@ def init(key: Array, if jnp.issubdtype(dtype, jnp.floating): # constant is stddev of standard normal truncated to (-2, 2) stddev = jnp.sqrt(variance) / jnp.array(.87962566103423978, dtype) - return random.truncated_normal(key, -2, 2, shape, dtype) * stddev + return random.truncated_normal(key, -2, 2, shape, dtype, + out_sharding=out_sharding) * stddev else: # constant is stddev of complex standard normal truncated to 2 stddev = jnp.sqrt(variance) / jnp.array(.95311164380491208, dtype) return _complex_truncated_normal(key, 2, shape, dtype) * stddev elif distribution == "normal": - return random.normal(key, shape, dtype) * jnp.sqrt(variance) + return random.normal(key, shape, dtype, + out_sharding=out_sharding) * jnp.sqrt(variance) elif distribution == "uniform": if jnp.issubdtype(dtype, jnp.floating): - return random.uniform(key, shape, dtype, -1) * jnp.sqrt(3 * variance) + return random.uniform(key, shape, dtype, -1, + out_sharding=out_sharding) * jnp.sqrt(3 * variance) else: return _complex_uniform(key, shape, dtype) * jnp.sqrt(variance) else: @@ -352,7 +367,7 @@ def init(key: Array, @export def glorot_uniform(in_axis: int | Sequence[int] = -2, out_axis: int | Sequence[int] = -1, - batch_axis: Sequence[int] = (), + batch_axis: int | Sequence[int] = (), dtype: DTypeLikeInexact = jnp.float_) -> Initializer: """Builds a Glorot uniform initializer (aka Xavier uniform initializer). 
@@ -390,7 +405,7 @@ def glorot_uniform(in_axis: int | Sequence[int] = -2, @export def glorot_normal(in_axis: int | Sequence[int] = -2, out_axis: int | Sequence[int] = -1, - batch_axis: Sequence[int] = (), + batch_axis: int | Sequence[int] = (), dtype: DTypeLikeInexact = jnp.float_) -> Initializer: """Builds a Glorot normal initializer (aka Xavier normal initializer). @@ -428,7 +443,7 @@ def glorot_normal(in_axis: int | Sequence[int] = -2, @export def lecun_uniform(in_axis: int | Sequence[int] = -2, out_axis: int | Sequence[int] = -1, - batch_axis: Sequence[int] = (), + batch_axis: int | Sequence[int] = (), dtype: DTypeLikeInexact = jnp.float_) -> Initializer: """Builds a Lecun uniform initializer. @@ -464,7 +479,7 @@ def lecun_uniform(in_axis: int | Sequence[int] = -2, @export def lecun_normal(in_axis: int | Sequence[int] = -2, out_axis: int | Sequence[int] = -1, - batch_axis: Sequence[int] = (), + batch_axis: int | Sequence[int] = (), dtype: DTypeLikeInexact = jnp.float_) -> Initializer: """Builds a Lecun normal initializer. @@ -500,7 +515,7 @@ def lecun_normal(in_axis: int | Sequence[int] = -2, @export def he_uniform(in_axis: int | Sequence[int] = -2, out_axis: int | Sequence[int] = -1, - batch_axis: Sequence[int] = (), + batch_axis: int | Sequence[int] = (), dtype: DTypeLikeInexact = jnp.float_) -> Initializer: """Builds a He uniform initializer (aka Kaiming uniform initializer). @@ -538,7 +553,7 @@ def he_uniform(in_axis: int | Sequence[int] = -2, @export def he_normal(in_axis: int | Sequence[int] = -2, out_axis: int | Sequence[int] = -1, - batch_axis: Sequence[int] = (), + batch_axis: int | Sequence[int] = (), dtype: DTypeLikeInexact = jnp.float_) -> Initializer: """Builds a He normal initializer (aka Kaiming normal initializer). @@ -601,7 +616,10 @@ def orthogonal(scale: RealNumeric = 1.0, """ def init(key: Array, shape: core.Shape, - dtype: DTypeLikeInexact = dtype) -> Array: + dtype: DTypeLikeInexact = dtype, + out_sharding=None) -> Array: + if out_sharding is not None: + raise NotImplementedError dtype = dtypes.canonicalize_dtype(dtype) if len(shape) < 2: raise ValueError("orthogonal initializer requires at least a 2D shape") @@ -651,7 +669,10 @@ def delta_orthogonal( """ def init(key: Array, shape: core.Shape, - dtype: DTypeLikeInexact = dtype) -> Array: + dtype: DTypeLikeInexact = dtype, + out_sharding=None) -> Array: + if out_sharding is not None: + raise NotImplementedError dtype = dtypes.canonicalize_dtype(dtype) if len(shape) not in [3, 4, 5]: raise ValueError("Delta orthogonal initializer requires a 3D, 4D or 5D " diff --git a/jax/_src/numpy/array.py b/jax/_src/numpy/array.py new file mode 100644 index 000000000000..73bbd7d09554 --- /dev/null +++ b/jax/_src/numpy/array.py @@ -0,0 +1,383 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
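Relating to the initializer changes above: the extended Initializer signature keeps out_sharding optional, so existing call sites behave exactly as before, while orthogonal and delta_orthogonal explicitly reject an out_sharding for now. A minimal sketch (key seeds and shapes are arbitrary):

import jax
import jax.numpy as jnp

init = jax.nn.initializers.glorot_uniform()
w = init(jax.random.key(0), (8, 4), jnp.float32)  # unchanged call; out_sharding defaults to None
orth = jax.nn.initializers.orthogonal()
v = orth(jax.random.key(1), (8, 4))
# Passing an explicit out_sharding to orthogonal()/delta_orthogonal()
# raises NotImplementedError under this change.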
+ +import importlib +from typing import Any + +import numpy as np + +from jax._src import api +from jax._src import config +from jax._src import core +from jax._src import dtypes +from jax._src import tree_util +from jax._src import xla_bridge +from jax._src.lax import lax +from jax._src.lib import xla_client as xc +from jax._src.numpy import util +from jax._src.typing import Array, ArrayLike, DTypeLike +from jax._src.sharding import Sharding + + +export = util.set_module('jax.numpy') + +for pkg_name in ['jax_cuda12_plugin', 'jax.jaxlib.cuda']: + try: + cuda_plugin_extension = importlib.import_module( + f'{pkg_name}.cuda_plugin_extension' + ) + except ImportError: + cuda_plugin_extension = None # type: ignore + else: + break + + +def _supports_buffer_protocol(obj): + try: + view = memoryview(obj) + except TypeError: + return False + else: + return True + + +def _make_string_array( + object: np.ndarray, + dtype: DTypeLike | None = None, + ndmin: int = 0, + device: xc.Device | Sharding | None = None, +) -> Array: + if not isinstance(object, np.ndarray): + raise TypeError( + "Currently, string arrays can only be made from NumPy" + f" arrays. Got: {type(object)}." + ) + if dtype is not None and ( + dtypes.is_string_dtype(object.dtype) != dtypes.is_string_dtype(dtype) + ): + raise TypeError( + f"Cannot make an array with dtype {dtype} from an object with dtype" + f" {object.dtype}." + ) + if ndmin > object.ndim: + raise TypeError( + f"ndmin {ndmin} cannot be greater than object's ndims" + f" {object.ndim} for string arrays." + ) + + # Just do a device_put since XLA does not support string as a data type. + return api.device_put(x=object, device=device) + + +@export +def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True, + order: str | None = "K", ndmin: int = 0, + *, device: xc.Device | Sharding | None = None) -> Array: + """Convert an object to a JAX array. + + JAX implementation of :func:`numpy.array`. + + Args: + object: an object that is convertible to an array. This includes JAX + arrays, NumPy arrays, Python scalars, Python collections like lists + and tuples, objects with an ``__array__`` method, and objects + supporting the Python buffer protocol. + dtype: optionally specify the dtype of the output array. If not + specified it will be inferred from the input. + copy: specify whether to force a copy of the input. Default: True. + order: not implemented in JAX + ndmin: integer specifying the minimum number of dimensions in the + output array. + device: optional :class:`~jax.Device` or :class:`~jax.sharding.Sharding` + to which the created array will be committed. + + Returns: + A JAX array constructed from the input. + + See also: + - :func:`jax.numpy.asarray`: like `array`, but by default only copies + when necessary. + - :func:`jax.numpy.from_dlpack`: construct a JAX array from an object + that implements the dlpack interface. + - :func:`jax.numpy.frombuffer`: construct a JAX array from an object + that implements the buffer interface. 
+ + Examples: + Constructing JAX arrays from Python scalars: + + >>> jnp.array(True) + Array(True, dtype=bool) + >>> jnp.array(42) + Array(42, dtype=int32, weak_type=True) + >>> jnp.array(3.5) + Array(3.5, dtype=float32, weak_type=True) + >>> jnp.array(1 + 1j) + Array(1.+1.j, dtype=complex64, weak_type=True) + + Constructing JAX arrays from Python collections: + + >>> jnp.array([1, 2, 3]) # list of ints -> 1D array + Array([1, 2, 3], dtype=int32) + >>> jnp.array([(1, 2, 3), (4, 5, 6)]) # list of tuples of ints -> 2D array + Array([[1, 2, 3], + [4, 5, 6]], dtype=int32) + >>> jnp.array(range(5)) + Array([0, 1, 2, 3, 4], dtype=int32) + + Constructing JAX arrays from NumPy arrays: + + >>> jnp.array(np.linspace(0, 2, 5)) + Array([0. , 0.5, 1. , 1.5, 2. ], dtype=float32) + + Constructing a JAX array via the Python buffer interface, using Python's + built-in :mod:`array` module. + + >>> from array import array + >>> pybuffer = array('i', [2, 3, 5, 7]) + >>> jnp.array(pybuffer) + Array([2, 3, 5, 7], dtype=int32) + """ + if order is not None and order != "K": + raise NotImplementedError("Only implemented for order='K'") + + # check if the given dtype is compatible with JAX + dtypes.check_user_dtype_supported(dtype, "array") + + # Here we make a judgment call: we only return a weakly-typed array when the + # input object itself is weakly typed. That ensures asarray(x) is a no-op + # whenever x is weak, but avoids introducing weak types with something like + # array([1, 2, 3]) + weak_type = dtype is None and dtypes.is_weakly_typed(object) + if device is None and isinstance(object, core.Tracer): + sharding = object.aval.sharding + sharding = None if sharding.mesh.empty else sharding + else: + sharding = util.canonicalize_device_to_sharding(device) + + # Use device_put to avoid a copy for ndarray inputs. + if (not copy and isinstance(object, np.ndarray) and + (dtype is None or dtype == object.dtype) and (ndmin <= object.ndim) and + device is None): + # Keep the output uncommitted. + return api.device_put(object) + + # String arrays need separate handling because XLA does not support string + # as a data type. + if dtypes.is_string_dtype(dtype) or ( + hasattr(object, "dtype") and dtypes.is_string_dtype(object.dtype) + ): + return _make_string_array( + object=object, dtype=dtype, ndmin=ndmin, device=device + ) + + # For Python scalar literals, call coerce_to_array to catch any overflow + # errors. We don't use dtypes.is_python_scalar because we don't want this + # triggering for traced values. We do this here because it matters whether or + # not dtype is None. We don't assign the result because we want the raw object + # to be used for type inference below. + if isinstance(object, (bool, int, float, complex)): + _ = dtypes.coerce_to_array(object, dtype) + elif not isinstance(object, Array): + # Check if object supports any of the data exchange protocols + # (except dlpack, see data-apis/array-api#301). If it does, + # consume the object as jax array and continue (but not return) so + # that other array() arguments get processed against the input + # object. + # + # Notice that data exchange protocols define dtype in the + # corresponding data structures and it may not be available as + # object.dtype. So, we'll resolve the protocols here before + # evaluating object.dtype. 
+ if hasattr(object, '__jax_array__'): + object = object.__jax_array__() + elif hasattr(object, '__cuda_array_interface__'): + cai = object.__cuda_array_interface__ + backend = xla_bridge.get_backend("cuda") + if cuda_plugin_extension is None: + device_id = None + else: + device_id = cuda_plugin_extension.get_device_ordinal(cai["data"][0]) + object = xc._xla.cuda_array_interface_to_buffer( + cai=cai, gpu_backend=backend, device_id=device_id) + + leaves, treedef = tree_util.tree_flatten(object, is_leaf=lambda x: x is None) + if any(leaf is None for leaf in leaves): + raise ValueError("None is not a valid value for jnp.array") + leaves = [ + leaf + if (leaf_jax_array := getattr(leaf, "__jax_array__", None)) is None + else leaf_jax_array() + for leaf in leaves + ] + if dtype is None: + # Use lattice_result_type rather than result_type to avoid canonicalization. + # Otherwise, weakly-typed inputs would have their dtypes canonicalized. + try: + dtype = dtypes._lattice_result_type(*leaves)[0] if leaves else dtypes.float_ + except TypeError: + # This happens if, e.g. one of the entries is a memoryview object. + # This is rare, so we only handle it if the normal path fails. + leaves = [_convert_to_array_if_dtype_fails(leaf) for leaf in leaves] + dtype = dtypes._lattice_result_type(*leaves)[0] + + if not weak_type: + dtype = dtypes.canonicalize_dtype(dtype, allow_extended_dtype=True) # type: ignore[assignment] + + object = treedef.unflatten(leaves) + out: ArrayLike + if all(not isinstance(leaf, Array) for leaf in leaves): + # TODO(jakevdp): falling back to numpy here fails to overflow for lists + # containing large integers; see discussion in + # https://github.com/jax-ml/jax/pull/6047. More correct would be to call + # coerce_to_array on each leaf, but this may have performance implications. + out = np.asarray(object, dtype=dtype) + elif isinstance(object, Array): + assert object.aval is not None + out = lax._array_copy(object) if copy else object + elif isinstance(object, (list, tuple)): + if object: + arrs = (array(elt, dtype=dtype, copy=False) for elt in object) + arrays_out = [lax.expand_dims(arr, [0]) for arr in arrs] + # lax.concatenate can be slow to compile for wide concatenations, so form a + # tree of concatenations as a workaround especially for op-by-op mode. + # (https://github.com/jax-ml/jax/issues/653). + k = 16 + while len(arrays_out) > k: + arrays_out = [lax.concatenate(arrays_out[i:i+k], 0) + for i in range(0, len(arrays_out), k)] + out = lax.concatenate(arrays_out, 0) + else: + out = np.array([], dtype=dtype) + elif _supports_buffer_protocol(object): + object = memoryview(object) + # TODO(jakevdp): update this once we support NumPy 2.0 semantics for the copy arg. 
+ out = np.array(object) if copy else np.asarray(object) + else: + raise TypeError(f"Unexpected input type for array: {type(object)}") + out_array: Array = lax._convert_element_type( + out, dtype, weak_type=weak_type, sharding=sharding) + if ndmin > np.ndim(out_array): + out_array = lax.expand_dims(out_array, range(ndmin - np.ndim(out_array))) + return out_array + + +def _get_platform( + device_or_sharding: xc.Device | Sharding | None | str) -> str: + """Get device_or_sharding platform or look up config.default_device.value.""" + if isinstance(device_or_sharding, xc.Device): + return device_or_sharding.platform + elif isinstance(device_or_sharding, Sharding): + return list(device_or_sharding.device_set)[0].platform + elif isinstance(device_or_sharding, str): + return device_or_sharding + elif device_or_sharding is None: + if config.default_device.value is None: + return xla_bridge.default_backend() + else: + return _get_platform(config.default_device.value) + else: + raise ValueError(f"`{device_or_sharding = }` was passed to" + "`canonicalize_or_get_default_platform`, only xc.Device," + " Sharding, None or str values are supported.") + + +def _convert_to_array_if_dtype_fails(x: ArrayLike) -> ArrayLike: + try: + dtypes.dtype(x) + except TypeError: + return np.asarray(x) + else: + return x + + +@export +def asarray(a: Any, dtype: DTypeLike | None = None, order: str | None = None, + *, copy: bool | None = None, + device: xc.Device | Sharding | None = None) -> Array: + """Convert an object to a JAX array. + + JAX implementation of :func:`numpy.asarray`. + + Args: + a: an object that is convertible to an array. This includes JAX + arrays, NumPy arrays, Python scalars, Python collections like lists + and tuples, objects with an ``__array__`` method, and objects + supporting the Python buffer protocol. + dtype: optionally specify the dtype of the output array. If not + specified it will be inferred from the input. + order: not implemented in JAX + copy: optional boolean specifying the copy mode. If True, then always + return a copy. If False, then error if a copy is necessary. Default is + None, which will only copy when necessary. + device: optional :class:`~jax.Device` or :class:`~jax.sharding.Sharding` + to which the created array will be committed. + + Returns: + A JAX array constructed from the input. + + See also: + - :func:`jax.numpy.array`: like `asarray`, but defaults to `copy=True`. + - :func:`jax.numpy.from_dlpack`: construct a JAX array from an object + that implements the dlpack interface. + - :func:`jax.numpy.frombuffer`: construct a JAX array from an object + that implements the buffer interface. + + Examples: + Constructing JAX arrays from Python scalars: + + >>> jnp.asarray(True) + Array(True, dtype=bool) + >>> jnp.asarray(42) + Array(42, dtype=int32, weak_type=True) + >>> jnp.asarray(3.5) + Array(3.5, dtype=float32, weak_type=True) + >>> jnp.asarray(1 + 1j) + Array(1.+1.j, dtype=complex64, weak_type=True) + + Constructing JAX arrays from Python collections: + + >>> jnp.asarray([1, 2, 3]) # list of ints -> 1D array + Array([1, 2, 3], dtype=int32) + >>> jnp.asarray([(1, 2, 3), (4, 5, 6)]) # list of tuples of ints -> 2D array + Array([[1, 2, 3], + [4, 5, 6]], dtype=int32) + >>> jnp.asarray(range(5)) + Array([0, 1, 2, 3, 4], dtype=int32) + + Constructing JAX arrays from NumPy arrays: + + >>> jnp.asarray(np.linspace(0, 2, 5)) + Array([0. , 0.5, 1. , 1.5, 2. ], dtype=float32) + + Constructing a JAX array via the Python buffer interface, using Python's + built-in :mod:`array` module. 
+ + >>> from array import array + >>> pybuffer = array('i', [2, 3, 5, 7]) + >>> jnp.asarray(pybuffer) + Array([2, 3, 5, 7], dtype=int32) + """ + # For copy=False, the array API specifies that we raise a ValueError if the input supports + # the buffer protocol but a copy is required. Since array() supports the buffer protocol + # via numpy, this is only the case when the default device is not 'cpu' + if (copy is False and not isinstance(a, Array) + and _get_platform(device) != "cpu" + and _supports_buffer_protocol(a)): + raise ValueError(f"jnp.asarray: cannot convert object of type {type(a)} to JAX Array " + f"on platform={_get_platform(device)} with " + "copy=False. Consider using copy=None or copy=True instead.") + dtypes.check_user_dtype_supported(dtype, "asarray") + if dtype is not None: + dtype = dtypes.canonicalize_dtype(dtype, allow_extended_dtype=True) # type: ignore[assignment] + return array(a, dtype=dtype, copy=bool(copy), order=order, device=device) diff --git a/jax/_src/numpy/array_api_metadata.py b/jax/_src/numpy/array_api_metadata.py index 4a01f579a67e..af4e27cd5f8d 100644 --- a/jax/_src/numpy/array_api_metadata.py +++ b/jax/_src/numpy/array_api_metadata.py @@ -21,13 +21,14 @@ from types import ModuleType -import jax from jax._src.sharding import Sharding from jax._src.lib import xla_client as xc -from jax._src import dtypes as _dtypes, config +from jax._src import config +from jax._src import dtypes as _dtypes +from jax._src import xla_bridge as xb -__array_api_version__ = '2023.12' +__array_api_version__ = '2024.12' def __array_namespace__(self, *, api_version: None | str = None) -> ModuleType: @@ -38,6 +39,7 @@ def __array_namespace__(self, *, api_version: None | str = None) -> ModuleType: if api_version is not None and api_version != __array_api_version__: raise ValueError(f"{api_version=!r} is not available; " f"available versions are: {[__array_api_version__]}") + import jax.numpy # pytype: disable=import-error return jax.numpy @@ -51,8 +53,9 @@ class ArrayNamespaceInfo: .. _Python array API: https://data-apis.org/array-api/ """ _capabilities = { - "boolean indexing": True, - "data-dependent shapes": False, + "boolean indexing": False, # within transformations + "data-dependent shapes": False, # within transformations + "max dimensions": 64, # XLA limitation } def _build_dtype_dict(self): @@ -72,7 +75,10 @@ def default_device(self): return None def devices(self): - return jax.devices() + out = [None] # None indicates "uncommitted" + for backend in xb.backends(): + out.extend(xb.devices(backend)) + return out def capabilities(self): return self._capabilities diff --git a/jax/_src/numpy/array_creation.py b/jax/_src/numpy/array_creation.py index 67418e7322c9..63ef76c01b69 100644 --- a/jax/_src/numpy/array_creation.py +++ b/jax/_src/numpy/array_creation.py @@ -13,19 +13,23 @@ # limitations under the License. 
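To illustrate the copy=False contract implemented in asarray above, a short sketch; it assumes the default device is CPU, where the NumPy buffer can be used without a forced copy.

import numpy as np
import jax.numpy as jnp

x = np.arange(4, dtype=np.int32)
a = jnp.asarray(x, copy=False)  # allowed on CPU: no copy is strictly required
b = jnp.asarray(x, copy=True)   # always copies
# With a non-CPU default device, jnp.asarray(x, copy=False) raises ValueError,
# because moving the NumPy buffer onto the device necessarily copies it.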
import types -from typing import Any +from functools import partial +import operator +from typing import Any, Literal, overload import numpy as np -import jax -from jax import lax +from jax._src.api import device_put, jit from jax._src import core from jax._src import dtypes +from jax._src.lax import lax from jax._src.lib import xla_client as xc +from jax._src.numpy.array import asarray +from jax._src.numpy import ufuncs from jax._src.numpy import util +from jax._src.sharding import Sharding from jax._src.typing import Array, ArrayLike, DuckTypedArray, DTypeLike -from jax._src.util import set_module -from jax.sharding import Sharding +from jax._src.util import canonicalize_axis, set_module export = set_module('jax.numpy') @@ -50,7 +54,8 @@ def zeros(shape: Any, dtype: DTypeLike | None = None, *, Args: shape: int or sequence of ints specifying the shape of the created array. - dtype: optional dtype for the created array; defaults to floating point. + dtype: optional dtype for the created array; defaults to float32 or float64 + depending on the X64 configuration (see :ref:`default-dtypes`). device: (optional) :class:`~jax.Device` or :class:`~jax.sharding.Sharding` to which the created array will be committed. @@ -87,7 +92,8 @@ def ones(shape: Any, dtype: DTypeLike | None = None, *, Args: shape: int or sequence of ints specifying the shape of the created array. - dtype: optional dtype for the created array; defaults to floating point. + dtype: optional dtype for the created array; defaults to float32 or float64 + depending on the X64 configuration (see :ref:`default-dtypes`). device: (optional) :class:`~jax.Device` or :class:`~jax.sharding.Sharding` to which the created array will be committed. @@ -126,7 +132,8 @@ def empty(shape: Any, dtype: DTypeLike | None = None, *, Args: shape: int or sequence of ints specifying the shape of the created array. - dtype: optional dtype for the created array; defaults to floating point. + dtype: optional dtype for the created array; defaults to float32 or float64 + depending on the X64 configuration (see :ref:`default-dtypes`). device: (optional) :class:`~jax.Device` or :class:`~jax.sharding.Sharding` to which the created array will be committed. 
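The reworded dtype notes above refer to JAX's default-dtype behavior, which the following sketch makes concrete (the shapes are arbitrary):

import jax
import jax.numpy as jnp

print(jnp.zeros(3).dtype)                  # float32 under the default configuration
jax.config.update("jax_enable_x64", True)  # opt in to 64-bit default types
print(jnp.ones((2, 2)).dtype)              # float64 once X64 is enabled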
@@ -204,8 +211,8 @@ def full(shape: Any, fill_value: ArrayLike, shape = canonicalize_shape(shape) return lax.full(shape, fill_value, dtype, sharding=util.normalize_device_to_sharding(device)) else: - return jax.device_put( - util._broadcast_to(jax.numpy.asarray(fill_value, dtype=dtype), shape), device) + return device_put( + util._broadcast_to(asarray(fill_value, dtype=dtype), shape), device) @export @@ -244,6 +251,8 @@ def zeros_like(a: ArrayLike | DuckTypedArray, [0, 0, 0]], dtype=int32) """ if not (hasattr(a, 'dtype') and hasattr(a, 'shape')): # support duck typing + if hasattr(a, '__jax_array__'): + a = a.__jax_array__() util.check_arraylike("zeros_like", a) dtypes.check_user_dtype_supported(dtype, "zeros_like") if shape is not None: @@ -287,6 +296,8 @@ def ones_like(a: ArrayLike | DuckTypedArray, [1, 1, 1]], dtype=int32) """ if not (hasattr(a, 'dtype') and hasattr(a, 'shape')): # support duck typing + if hasattr(a, '__jax_array__'): + a = a.__jax_array__() util.check_arraylike("ones_like", a) dtypes.check_user_dtype_supported(dtype, "ones_like") if shape is not None: @@ -332,9 +343,13 @@ def empty_like(prototype: ArrayLike | DuckTypedArray, [0, 0, 0]], dtype=int32) """ if not (hasattr(prototype, 'dtype') and hasattr(prototype, 'shape')): # support duck typing - util.check_arraylike("empty_like", prototype) - dtypes.check_user_dtype_supported(dtype, "empty_like") - return zeros_like(prototype, dtype=dtype, shape=shape, device=device) + if hasattr(prototype, '__jax_array__'): + prototype = prototype.__jax_array__() + util.check_arraylike("empty_like", prototype) + dtypes.check_user_dtype_supported(dtype, "empty_like") + if shape is not None: + shape = canonicalize_shape(shape) + return lax.full_like(prototype, 0, dtype, shape, sharding=util.normalize_device_to_sharding(device)) @export @@ -382,6 +397,8 @@ def full_like(a: ArrayLike | DuckTypedArray, util.check_arraylike("full_like", 0, fill_value) else: util.check_arraylike("full_like", a, fill_value) + if hasattr(a, '__jax_array__'): + a = a.__jax_array__() dtypes.check_user_dtype_supported(dtype, "full_like") if shape is not None: shape = canonicalize_shape(shape) @@ -390,5 +407,317 @@ def full_like(a: ArrayLike | DuckTypedArray, else: shape = np.shape(a) if shape is None else shape # type: ignore[arg-type] dtype = dtypes.result_type(a) if dtype is None else dtype - return jax.device_put( - util._broadcast_to(jax.numpy.asarray(fill_value, dtype=dtype), shape), device) + return device_put( + util._broadcast_to(asarray(fill_value, dtype=dtype), shape), device) + +@overload +def linspace(start: ArrayLike, stop: ArrayLike, num: int = 50, + endpoint: bool = True, retstep: Literal[False] = False, + dtype: DTypeLike | None = None, + axis: int = 0, + *, device: xc.Device | Sharding | None = None) -> Array: ... +@overload +def linspace(start: ArrayLike, stop: ArrayLike, num: int, + endpoint: bool, retstep: Literal[True], + dtype: DTypeLike | None = None, + axis: int = 0, + *, device: xc.Device | Sharding | None = None) -> tuple[Array, Array]: ... +@overload +def linspace(start: ArrayLike, stop: ArrayLike, num: int = 50, + endpoint: bool = True, *, retstep: Literal[True], + dtype: DTypeLike | None = None, + axis: int = 0, + device: xc.Device | Sharding | None = None) -> tuple[Array, Array]: ...
+@overload +def linspace(start: ArrayLike, stop: ArrayLike, num: int = 50, + endpoint: bool = True, retstep: bool = False, + dtype: DTypeLike | None = None, + axis: int = 0, + *, device: xc.Device | Sharding | None = None) -> Array | tuple[Array, Array]: ... +@export +def linspace(start: ArrayLike, stop: ArrayLike, num: int = 50, + endpoint: bool = True, retstep: bool = False, + dtype: DTypeLike | None = None, + axis: int = 0, + *, device: xc.Device | Sharding | None = None) -> Array | tuple[Array, Array]: + """Return evenly-spaced numbers within an interval. + + JAX implementation of :func:`numpy.linspace`. + + Args: + start: scalar or array of starting values. + stop: scalar or array of stop values. + num: number of values to generate. Default: 50. + endpoint: if True (default) then include the ``stop`` value in the result. + If False, then exclude the ``stop`` value. + retstep: If True, then return a ``(result, step)`` tuple, where ``step`` is the + interval between adjacent values in ``result``. + axis: integer axis along which to generate the linspace. Defaults to zero. + device: optional :class:`~jax.Device` or :class:`~jax.sharding.Sharding` + to which the created array will be committed. + + Returns: + An array ``values``, or a tuple ``(values, step)`` if ``retstep`` is True, where: + + - ``values`` is an array of evenly-spaced values from ``start`` to ``stop`` + - ``step`` is the interval between adjacent values. + + See also: + - :func:`jax.numpy.arange`: Generate ``N`` evenly-spaced values given a starting + point and a step + - :func:`jax.numpy.logspace`: Generate logarithmically-spaced values. + - :func:`jax.numpy.geomspace`: Generate geometrically-spaced values. + + Examples: + List of 5 values between 0 and 10: + + >>> jnp.linspace(0, 10, 5) + Array([ 0. , 2.5, 5. , 7.5, 10. ], dtype=float32) + + List of 8 values between 0 and 10, excluding the endpoint: + + >>> jnp.linspace(0, 10, 8, endpoint=False) + Array([0. , 1.25, 2.5 , 3.75, 5. , 6.25, 7.5 , 8.75], dtype=float32) + + List of values and the step size between them + + >>> vals, step = jnp.linspace(0, 10, 9, retstep=True) + >>> vals + Array([ 0. , 1.25, 2.5 , 3.75, 5. , 6.25, 7.5 , 8.75, 10. ], dtype=float32) + >>> step + Array(1.25, dtype=float32) + + Multi-dimensional linspace: + + >>> start = jnp.array([0, 5]) + >>> stop = jnp.array([5, 10]) + >>> jnp.linspace(start, stop, 5) + Array([[ 0. , 5. ], + [ 1.25, 6.25], + [ 2.5 , 7.5 ], + [ 3.75, 8.75], + [ 5. , 10. 
]], dtype=float32) + """ + num = core.concrete_dim_or_error(num, "'num' argument of jnp.linspace") + axis = core.concrete_or_error(operator.index, axis, "'axis' argument of jnp.linspace") + return _linspace(start, stop, num, endpoint, retstep, dtype, axis, device=device) + +@partial(jit, static_argnames=('num', 'endpoint', 'retstep', 'dtype', 'axis', 'device')) +def _linspace(start: ArrayLike, stop: ArrayLike, num: int = 50, + endpoint: bool = True, retstep: bool = False, + dtype: DTypeLike | None = None, + axis: int = 0, + *, device: xc.Device | Sharding | None = None) -> Array | tuple[Array, Array]: + """Implementation of linspace differentiable in start and stop args.""" + dtypes.check_user_dtype_supported(dtype, "linspace") + if num < 0: + raise ValueError(f"Number of samples, {num}, must be non-negative.") + start, stop = util.ensure_arraylike("linspace", start, stop) + + if dtype is None: + dtype = dtypes.to_inexact_dtype(dtypes.result_type(start, stop)) + dtype = dtypes.jax_dtype(dtype) + computation_dtype = dtypes.to_inexact_dtype(dtype) + start = start.astype(computation_dtype) + stop = stop.astype(computation_dtype) + + bounds_shape = list(lax.broadcast_shapes(np.shape(start), np.shape(stop))) + broadcast_start = util._broadcast_to(start, bounds_shape) + broadcast_stop = util._broadcast_to(stop, bounds_shape) + axis = len(bounds_shape) + axis + 1 if axis < 0 else axis + bounds_shape.insert(axis, 1) + div = (num - 1) if endpoint else num + if num > 1: + delta: Array = lax.convert_element_type(stop - start, computation_dtype) / asarray(div, dtype=computation_dtype) + iota_shape = [1,] * len(bounds_shape) + iota_shape[axis] = div + # This approach recovers the endpoints with float32 arithmetic, + # but can lead to rounding errors for integer outputs. + real_dtype = dtypes.finfo(computation_dtype).dtype + step = lax.iota(real_dtype, div).reshape(iota_shape) / asarray(div, real_dtype) + step = step.astype(computation_dtype) + out = (broadcast_start.reshape(bounds_shape) * (1 - step) + + broadcast_stop.reshape(bounds_shape) * step) + + if endpoint: + out = lax.concatenate([out, lax.expand_dims(broadcast_stop, (axis,))], + canonicalize_axis(axis, out.ndim)) + + elif num == 1: + delta = asarray(np.nan if endpoint else stop - start, dtype=computation_dtype) + out = broadcast_start.reshape(bounds_shape) + else: # num == 0 degenerate case, match numpy behavior + empty_shape = list(lax.broadcast_shapes(np.shape(start), np.shape(stop))) + empty_shape.insert(axis, 0) + delta = full((), np.nan, computation_dtype) + out = empty(empty_shape, dtype) + + if dtypes.issubdtype(dtype, np.integer) and not dtypes.issubdtype(out.dtype, np.integer): + out = lax.floor(out) + + sharding = util.canonicalize_device_to_sharding(device) + result = lax._convert_element_type(out, dtype, sharding=sharding) + return (result, delta) if retstep else result + + +@export +def logspace(start: ArrayLike, stop: ArrayLike, num: int = 50, + endpoint: bool = True, base: ArrayLike = 10.0, + dtype: DTypeLike | None = None, axis: int = 0) -> Array: + """Generate logarithmically-spaced values. + + JAX implementation of :func:`numpy.logspace`. + + Args: + start: scalar or array. Used to specify the start value. The start value is + ``base ** start``. + stop: scalar or array. Used to specify the stop value. The end value is + ``base ** stop``. + num: int, optional, default=50. Number of values to generate. + endpoint: bool, optional, default=True. If True, then include the ``stop`` value + in the result. 
If False, then exclude the ``stop`` value. + base: scalar or array, optional, default=10. Specifies the base of the logarithm. + dtype: optional. Specifies the dtype of the output. + axis: int, optional, default=0. Axis along which to generate the logspace. + + Returns: + An array of logarithmically-spaced values. + + See also: + - :func:`jax.numpy.arange`: Generate ``N`` evenly-spaced values given a starting + point and a step value. + - :func:`jax.numpy.linspace`: Generate evenly-spaced values. + - :func:`jax.numpy.geomspace`: Generate geometrically-spaced values. + + Examples: + List 5 logarithmically spaced values between 1 (``10 ** 0``) and 100 + (``10 ** 2``): + + >>> with jnp.printoptions(precision=3, suppress=True): + ... jnp.logspace(0, 2, 5) + Array([ 1. , 3.162, 10. , 31.623, 100. ], dtype=float32) + + List 5 logarithmically-spaced values between 1 (``10 ** 0``) and 100 + (``10 ** 2``), excluding endpoint: + + >>> with jnp.printoptions(precision=3, suppress=True): + ... jnp.logspace(0, 2, 5, endpoint=False) + Array([ 1. , 2.512, 6.31 , 15.849, 39.811], dtype=float32) + + List 7 logarithmically-spaced values between 1 (``2 ** 0``) and 4 (``2 ** 2``) + with base 2: + + >>> with jnp.printoptions(precision=3, suppress=True): + ... jnp.logspace(0, 2, 7, base=2) + Array([1. , 1.26 , 1.587, 2. , 2.52 , 3.175, 4. ], dtype=float32) + + Multi-dimensional logspace: + + >>> start = jnp.array([0, 5]) + >>> stop = jnp.array([5, 0]) + >>> base = jnp.array([2, 3]) + >>> with jnp.printoptions(precision=3, suppress=True): + ... jnp.logspace(start, stop, 5, base=base) + Array([[ 1. , 243. ], + [ 2.378, 61.547], + [ 5.657, 15.588], + [ 13.454, 3.948], + [ 32. , 1. ]], dtype=float32) + """ + num = core.concrete_or_error(operator.index, num, "'num' argument of jnp.logspace") + axis = core.concrete_or_error(operator.index, axis, "'axis' argument of jnp.logspace") + return _logspace(start, stop, num, endpoint, base, dtype, axis) + +@partial(jit, static_argnames=('num', 'endpoint', 'dtype', 'axis')) +def _logspace(start: ArrayLike, stop: ArrayLike, num: int = 50, + endpoint: bool = True, base: ArrayLike = 10.0, + dtype: DTypeLike | None = None, axis: int = 0) -> Array: + """Implementation of logspace differentiable in start and stop args.""" + dtypes.check_user_dtype_supported(dtype, "logspace") + if dtype is None: + dtype = dtypes.to_inexact_dtype(dtypes.result_type(start, stop)) + dtype = dtypes.jax_dtype(dtype) + computation_dtype = dtypes.to_inexact_dtype(dtype) + start, stop = util.ensure_arraylike("logspace", start, stop) + start = start.astype(computation_dtype) + stop = stop.astype(computation_dtype) + lin = linspace(start, stop, num, + endpoint=endpoint, retstep=False, dtype=None, axis=axis) + return lax.convert_element_type(ufuncs.power(base, lin), dtype) + + +@export +def geomspace(start: ArrayLike, stop: ArrayLike, num: int = 50, endpoint: bool = True, + dtype: DTypeLike | None = None, axis: int = 0) -> Array: + """Generate geometrically-spaced values. + + JAX implementation of :func:`numpy.geomspace`. + + Args: + start: scalar or array. Specifies the starting values. + stop: scalar or array. Specifies the stop values. + num: int, optional, default=50. Number of values to generate. + endpoint: bool, optional, default=True. If True, then include the ``stop`` value + in the result. If False, then exclude the ``stop`` value. + dtype: optional. Specifies the dtype of the output. + axis: int, optional, default=0. Axis along which to generate the geomspace.
+ + Returns: + An array containing the geometrically-spaced values. + + See also: + - :func:`jax.numpy.arange`: Generate ``N`` evenly-spaced values given a starting + point and a step value. + - :func:`jax.numpy.linspace`: Generate evenly-spaced values. + - :func:`jax.numpy.logspace`: Generate logarithmically-spaced values. + + Examples: + List 5 geometrically-spaced values between 1 and 16: + + >>> with jnp.printoptions(precision=3, suppress=True): + ... jnp.geomspace(1, 16, 5) + Array([ 1., 2., 4., 8., 16.], dtype=float32) + + List 4 geometrically-spaced values between 1 and 16, with ``endpoint=False``: + + >>> with jnp.printoptions(precision=3, suppress=True): + ... jnp.geomspace(1, 16, 4, endpoint=False) + Array([1., 2., 4., 8.], dtype=float32) + + Multi-dimensional geomspace: + + >>> start = jnp.array([1, 1000]) + >>> stop = jnp.array([27, 1]) + >>> with jnp.printoptions(precision=3, suppress=True): + ... jnp.geomspace(start, stop, 4) + Array([[ 1., 1000.], + [ 3., 100.], + [ 9., 10.], + [ 27., 1.]], dtype=float32) + """ + num = core.concrete_or_error(operator.index, num, "'num' argument of jnp.geomspace") + axis = core.concrete_or_error(operator.index, axis, "'axis' argument of jnp.geomspace") + return _geomspace(start, stop, num, endpoint, dtype, axis) + +@partial(jit, static_argnames=('num', 'endpoint', 'dtype', 'axis')) +def _geomspace(start: ArrayLike, stop: ArrayLike, num: int = 50, endpoint: bool = True, + dtype: DTypeLike | None = None, axis: int = 0) -> Array: + """Implementation of geomspace differentiable in start and stop args.""" + dtypes.check_user_dtype_supported(dtype, "geomspace") + if dtype is None: + dtype = dtypes.to_inexact_dtype(dtypes.result_type(start, stop)) + dtype = dtypes.jax_dtype(dtype) + computation_dtype = dtypes.to_inexact_dtype(dtype) + start, stop = util.ensure_arraylike("geomspace", start, stop) + start = start.astype(computation_dtype) + stop = stop.astype(computation_dtype) + + sign = ufuncs.sign(start) + res = sign * logspace(ufuncs.log10(start / sign), ufuncs.log10(stop / sign), + num, endpoint=endpoint, base=10.0, + dtype=computation_dtype, axis=0) + axis = canonicalize_axis(axis, res.ndim) + if axis != 0: + # res = moveaxis(res, 0, axis) + res = lax.transpose(res, permutation=(*range(1, axis + 1), 0, *range(axis + 1, res.ndim))) + return lax.convert_element_type(res, dtype) diff --git a/jax/_src/numpy/array_methods.py b/jax/_src/numpy/array_methods.py index e9e097c85aff..27a1ddce7685 100644 --- a/jax/_src/numpy/array_methods.py +++ b/jax/_src/numpy/array_methods.py @@ -26,12 +26,12 @@ import abc from functools import partial, wraps import math -from typing import Any, Sequence +from typing import Any +from collections.abc import Callable, Sequence import numpy as np -import jax + from jax import lax -from jax.sharding import Sharding from jax._src import api from jax._src import core from jax._src import dtypes @@ -44,6 +44,7 @@ from jax._src.numpy import lax_numpy from jax._src.numpy import tensor_contractions from jax._src.pjit import PartitionSpec +from jax._src.sharding import Sharding from jax._src.sharding_impls import canonicalize_sharding, NamedSharding from jax._src.numpy import reductions from jax._src.numpy import ufuncs @@ -197,12 +198,12 @@ def _dot(self: Array, b: ArrayLike, *, precision: lax_internal.PrecisionLike = N """ return tensor_contractions.dot(self, b, precision=precision, preferred_element_type=preferred_element_type) -def _flatten(self: Array, order: str = "C") -> Array: +def _flatten(self: Array, order: str = "C", *, 
out_sharding=None) -> Array: """Flatten array into a 1-dimensional shape. Refer to :func:`jax.numpy.ravel` for the full documentation. """ - return lax_numpy.ravel(self, order=order) + return lax_numpy.ravel(self, order=order, out_sharding=out_sharding) def _imag_property(self: Array) -> Array: """Return the imaginary part of the array.""" @@ -293,14 +294,18 @@ def _real_property(self: Array) -> Array: return ufuncs.real(self) def _repeat(self: Array, repeats: ArrayLike, axis: int | None = None, *, - total_repeat_length: int | None = None) -> Array: + total_repeat_length: int | None = None, + out_sharding: NamedSharding | PartitionSpec | None = None) -> Array: """Construct an array from repeated elements. Refer to :func:`jax.numpy.repeat` for the full documentation. """ - return lax_numpy.repeat(self, repeats=repeats, axis=axis, total_repeat_length=total_repeat_length) + return lax_numpy.repeat(self, repeats=repeats, axis=axis, + total_repeat_length=total_repeat_length, + out_sharding=out_sharding) -def _reshape(self: Array, *args: Any, order: str = "C") -> Array: +def _reshape(self: Array, *args: Any, order: str = "C", out_sharding=None + ) -> Array: """Returns an array containing the same data with a new shape. Refer to :func:`jax.numpy.reshape` for full documentation. @@ -308,10 +313,10 @@ def _reshape(self: Array, *args: Any, order: str = "C") -> Array: __tracebackhide__ = True newshape = _compute_newshape(self, args[0] if len(args) == 1 else args) if order == "C": - return lax.reshape(self, newshape, None) + return lax.reshape(self, newshape, None, out_sharding=out_sharding) elif order == "F": dims = list(range(self.ndim)[::-1]) - return lax.reshape(self, newshape[::-1], dims).T + return lax.reshape(self, newshape[::-1], dims, out_sharding=out_sharding).T elif order == "A": raise NotImplementedError("np.reshape order=A is not implemented.") else: @@ -588,7 +593,7 @@ def deferring_binary_op(self, other): def _unimplemented_setitem(self, i, x): msg = ("JAX arrays are immutable and do not support in-place item assignment." " Instead of x[idx] = y, use x = x.at[idx].set(y) or another .at[] method:" - " https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html") + " https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html") raise TypeError(msg.format(type(self))) def _operator_round(number: ArrayLike, ndigits: int | None = None) -> Array: @@ -608,12 +613,13 @@ def _deepcopy(self: Array, memo: Any) -> Array: def __array_module__(self, types): if all(issubclass(t, _HANDLED_ARRAY_TYPES) for t in types): + import jax.numpy # pytype: disable=import-error return jax.numpy else: return NotImplemented -@partial(jax.jit, static_argnums=(1,2,3)) +@partial(api.jit, static_argnums=(1,2,3)) def _multi_slice(self: Array, start_indices: tuple[tuple[int, ...]], limit_indices: tuple[tuple[int, ...]], @@ -633,7 +639,7 @@ def _multi_slice(self: Array, # The next two functions are related to iter(array), implemented here to # avoid circular imports. -@jax.jit +@api.jit def _unstack(x: Array) -> list[Array]: dims = (0,) return [lax.squeeze(t, dims) for t in lax.split(x, (1,) * x.shape[0])] @@ -689,10 +695,8 @@ class _IndexUpdateHelper: By default, JAX assumes that all indices are in-bounds. Alternative out-of-bound index semantics can be specified via the ``mode`` parameter (see below). - Arguments - --------- - mode : str - Specify out-of-bound indexing mode. Options are: + Args: + mode: string specifying out-of-bound indexing mode. 
Options are: - ``"promise_in_bounds"``: (default) The user promises that indices are in bounds. No additional checking will be performed. In practice, this means that @@ -703,50 +707,68 @@ class _IndexUpdateHelper: - ``"fill"``: alias for ``"drop"``. For `get()`, the optional ``fill_value`` argument specifies the value that will be returned. - See :class:`jax.lax.GatherScatterMode` for more details. - - indices_are_sorted : bool - If True, the implementation will assume that the indices passed to ``at[]`` - are sorted in ascending order, which can lead to more efficient execution - on some backends. - unique_indices : bool - If True, the implementation will assume that the indices passed to ``at[]`` - are unique, which can result in more efficient execution on some backends. - fill_value : Any - Only applies to the ``get()`` method: the fill value to return for out-of-bounds - slices when `mode` is ``'fill'``. Ignored otherwise. Defaults to ``NaN`` for - inexact types, the largest negative value for signed types, the largest positive - value for unsigned types, and ``True`` for booleans. - - Examples - -------- - >>> x = jnp.arange(5.0) - >>> x - Array([0., 1., 2., 3., 4.], dtype=float32) - >>> x.at[2].add(10) - Array([ 0., 1., 12., 3., 4.], dtype=float32) - >>> x.at[10].add(10) # out-of-bounds indices are ignored - Array([0., 1., 2., 3., 4.], dtype=float32) - >>> x.at[20].add(10, mode='clip') - Array([ 0., 1., 2., 3., 14.], dtype=float32) - >>> x.at[2].get() - Array(2., dtype=float32) - >>> x.at[20].get() # out-of-bounds indices clipped - Array(4., dtype=float32) - >>> x.at[20].get(mode='fill') # out-of-bounds indices filled with NaN - Array(nan, dtype=float32) - >>> x.at[20].get(mode='fill', fill_value=-1) # custom fill value - Array(-1., dtype=float32) + See :class:`jax.lax.GatherScatterMode` for more details. + wrap_negative_indices: If True (default) then negative indices indicate position + from the end of the array, similar to Python and NumPy indexing. If False, then + negative indices are considered out-of-bounds and behave according to the + ``mode`` parameter. + fill_value: Only applies to the ``get()`` method: the fill value to return for + out-of-bounds slices when ``mode`` is ``'fill'``. Ignored otherwise. Defaults + to ``NaN`` for inexact types, the largest negative value for signed types, the + largest positive value for unsigned types, and ``True`` for booleans. + indices_are_sorted: If True, the implementation will assume that the (normalized) + indices passed to ``at[]`` are sorted in ascending order, which can lead to more + efficient execution on some backends. If True but the indices are not actually + sorted, the output is undefined. + unique_indices: If True, the implementation will assume that the (normalized) indices + passed to ``at[]`` are unique, which can result in more efficient execution on some + backends. If True but the indices are not actually unique, the output is undefined. 
+ + Examples: + >>> x = jnp.arange(5.0) + >>> x + Array([0., 1., 2., 3., 4.], dtype=float32) + >>> x.at[2].get() + Array(2., dtype=float32) + >>> x.at[2].add(10) + Array([ 0., 1., 12., 3., 4.], dtype=float32) + + By default, out-of-bound indices are ignored in updates, but this behavior + can be controlled with the ``mode`` parameter: + + >>> x.at[10].add(10) # dropped + Array([0., 1., 2., 3., 4.], dtype=float32) + >>> x.at[20].add(10, mode='clip') # clipped + Array([ 0., 1., 2., 3., 14.], dtype=float32) + + For ``get()``, out-of-bound indices are clipped by default: + + >>> x.at[20].get() # out-of-bounds indices clipped + Array(4., dtype=float32) + >>> x.at[20].get(mode='fill') # out-of-bounds indices filled with NaN + Array(nan, dtype=float32) + >>> x.at[20].get(mode='fill', fill_value=-1) # custom fill value + Array(-1., dtype=float32) + + Negative indices count from the end of the array, but this behavior can + be disabled by setting ``wrap_negative_indices = False``: + + >>> x.at[-1].set(99) + Array([ 0., 1., 2., 3., 99.], dtype=float32) + >>> x.at[-1].set(99, wrap_negative_indices=False, mode='drop') # dropped! + Array([0., 1., 2., 3., 4.], dtype=float32) """ __slots__ = ("array",) - def __init__(self, array): + array: Array + + def __init__(self, array: Array): self.array = array - def __getitem__(self, index): + def __getitem__(self, index: scatter.Index) -> _IndexUpdateRef: return _IndexUpdateRef(self.array, index) - def __repr__(self): + def __repr__(self) -> str: return f"_IndexUpdateHelper({self.array!r})" @@ -759,15 +781,21 @@ class _IndexUpdateRef: """ __slots__ = ("array", "index") - def __init__(self, array, index): + array: Array + index: scatter.Index + + def __init__(self, array: Array, index: scatter.Index): self.array = array self.index = index def __repr__(self) -> str: return f"_IndexUpdateRef({self.array!r}, {self.index!r})" - def get(self, *, indices_are_sorted=False, unique_indices=False, - mode=None, fill_value=None, out_sharding=None): + def get(self, *, indices_are_sorted: bool = False, unique_indices: bool = False, + mode: str | lax.GatherScatterMode | None = None, + fill_value: ArrayLike | None = None, + out_sharding: Sharding | PartitionSpec | None = None, + wrap_negative_indices: bool = True): """Equivalent to ``x[idx]``. Returns the value of ``x`` that would result from the NumPy-style @@ -775,7 +803,7 @@ def get(self, *, indices_are_sorted=False, unique_indices=False, the usual array indexing syntax in that it allows additional keyword arguments ``indices_are_sorted`` and ``unique_indices`` to be passed. - See :mod:`jax.ops` for details. + See :func:`jax.numpy.ndarray.at` for details. """ if out_sharding is not None: assert isinstance(out_sharding, (NamedSharding, PartitionSpec)) @@ -784,23 +812,32 @@ def get(self, *, indices_are_sorted=False, unique_indices=False, indices_are_sorted=indices_are_sorted, unique_indices=unique_indices, mode=mode, fill_value=fill_value, + normalize_indices=wrap_negative_indices, out_sharding=out_sharding) - def set(self, values, *, indices_are_sorted=False, unique_indices=False, - mode=None): + def set(self, values: ArrayLike, *, indices_are_sorted: bool = False, + unique_indices: bool = False, + mode: str | lax.GatherScatterMode | None = None, + wrap_negative_indices: bool = True) -> None: """Pure equivalent of ``x[idx] = y``. Returns the value of ``x`` that would result from the NumPy-style :mod:`indexed assignment ` ``x[idx] = y``. - See :mod:`jax.ops` for details. + See :func:`jax.numpy.ndarray.at` for details. 
""" + out_s = core.typeof(self.array).sharding + if out_s.mesh.empty or out_s.mesh._are_all_axes_auto_or_manual: + out_s = None return scatter._scatter_update(self.array, self.index, values, lax.scatter, indices_are_sorted=indices_are_sorted, - unique_indices=unique_indices, mode=mode) + unique_indices=unique_indices, mode=mode, + out_sharding=out_s, normalize_indices=wrap_negative_indices) - def apply(self, func, *, indices_are_sorted=False, unique_indices=False, - mode=None): + def apply(self, func: Callable[[ArrayLike], Array], *, + indices_are_sorted: bool = False, unique_indices: bool = False, + mode: str | lax.GatherScatterMode | None = None, + wrap_negative_indices: bool = True) -> Array: """Pure equivalent of ``func.at(x, idx)`` for a unary ufunc ``func``. Returns the value of ``x`` that would result from applying the unary @@ -812,7 +849,7 @@ def apply(self, func, *, indices_are_sorted=False, unique_indices=False, Note that in the current implementation, ``scatter_apply`` is not compatible with automatic differentiation. - See :mod:`jax.ops` for details. + See :func:`jax.numpy.ndarray.at` for details. """ def _scatter_apply(x, indices, y, dims, **kwargs): return lax.scatter_apply(x, indices, func, dims, update_shape=y.shape, **kwargs) @@ -820,113 +857,134 @@ def _scatter_apply(x, indices, y, dims, **kwargs): lax_internal._zero(self.array), _scatter_apply, indices_are_sorted=indices_are_sorted, - unique_indices=unique_indices, mode=mode) + unique_indices=unique_indices, mode=mode, + normalize_indices=wrap_negative_indices) - def add(self, values, *, indices_are_sorted=False, unique_indices=False, - mode=None): + def add(self, values: ArrayLike, *, + indices_are_sorted: bool = False, unique_indices: bool = False, + mode: str | lax.GatherScatterMode | None = None, + wrap_negative_indices: bool = True) -> Array: """Pure equivalent of ``x[idx] += y``. Returns the value of ``x`` that would result from the NumPy-style :mod:indexed assignment ` ``x[idx] += y``. - See :mod:`jax.ops` for details. + See :func:`jax.numpy.ndarray.at` for details. """ return scatter._scatter_update(self.array, self.index, values, lax.scatter_add, indices_are_sorted=indices_are_sorted, - unique_indices=unique_indices, mode=mode) + unique_indices=unique_indices, mode=mode, + normalize_indices=wrap_negative_indices) - def subtract(self, values, *, indices_are_sorted=False, unique_indices=False, - mode=None): + def subtract(self, values: ArrayLike, *, + indices_are_sorted: bool = False, unique_indices: bool = False, + mode: str | lax.GatherScatterMode | None = None, + wrap_negative_indices: bool = True) -> Array: """Pure equivalent of ``x[idx] -= y``. Returns the value of ``x`` that would result from the NumPy-style :mod:indexed assignment ` ``x[idx] -= y``. - See :mod:`jax.ops` for details. + See :func:`jax.numpy.ndarray.at` for details. """ return scatter._scatter_update(self.array, self.index, values, lax.scatter_sub, indices_are_sorted=indices_are_sorted, - unique_indices=unique_indices, mode=mode) + unique_indices=unique_indices, mode=mode, + normalize_indices=wrap_negative_indices) - def multiply(self, values, *, indices_are_sorted=False, unique_indices=False, - mode=None): + def multiply(self, values: ArrayLike, *, + indices_are_sorted: bool = False, unique_indices: bool = False, + mode: str | lax.GatherScatterMode | None = None, + wrap_negative_indices: bool = True) -> Array: """Pure equivalent of ``x[idx] *= y``. 
Returns the value of ``x`` that would result from the NumPy-style :mod:indexed assignment ` ``x[idx] *= y``. - See :mod:`jax.ops` for details. + See :func:`jax.numpy.ndarray.at` for details. """ return scatter._scatter_update(self.array, self.index, values, lax.scatter_mul, indices_are_sorted=indices_are_sorted, unique_indices=unique_indices, - mode=mode) + mode=mode, normalize_indices=wrap_negative_indices) mul = multiply - def divide(self, values, *, indices_are_sorted=False, unique_indices=False, - mode=None): + def divide(self, values: ArrayLike, *, + indices_are_sorted: bool = False, unique_indices: bool = False, + mode: str | lax.GatherScatterMode | None = None, + wrap_negative_indices: bool = True) -> Array: """Pure equivalent of ``x[idx] /= y``. Returns the value of ``x`` that would result from the NumPy-style :mod:indexed assignment ` ``x[idx] /= y``. - See :mod:`jax.ops` for details. + See :func:`jax.numpy.ndarray.at` for details. """ return ufuncs.divide( self.array, scatter._scatter_update(lax_numpy.ones_like(self.array), self.index, values, lax.scatter_mul, indices_are_sorted=indices_are_sorted, - unique_indices=unique_indices, mode=mode)) + unique_indices=unique_indices, mode=mode, + normalize_indices=wrap_negative_indices)) - def power(self, values, *, indices_are_sorted=False, unique_indices=False, - mode=None): + def power(self, values: ArrayLike, *, + indices_are_sorted: bool = False, unique_indices: bool = False, + mode: str | lax.GatherScatterMode | None = None, + wrap_negative_indices: bool = True) -> Array: """Pure equivalent of ``x[idx] **= y``. Returns the value of ``x`` that would result from the NumPy-style :mod:indexed assignment ` ``x[idx] **= y``. - See :mod:`jax.ops` for details. + See :func:`jax.numpy.ndarray.at` for details. """ return ufuncs.power( self.array, scatter._scatter_update(lax_numpy.ones_like(self.array), self.index, values, lax.scatter_mul, indices_are_sorted=indices_are_sorted, - unique_indices=unique_indices, mode=mode)) + unique_indices=unique_indices, mode=mode, + normalize_indices=wrap_negative_indices)) - def min(self, values, *, indices_are_sorted=False, unique_indices=False, - mode=None): + def min(self, values: ArrayLike, *, + indices_are_sorted: bool = False, unique_indices: bool = False, + mode: str | lax.GatherScatterMode | None = None, + wrap_negative_indices: bool = True) -> Array: """Pure equivalent of ``x[idx] = minimum(x[idx], y)``. Returns the value of ``x`` that would result from the NumPy-style :mod:indexed assignment ` ``x[idx] = minimum(x[idx], y)``. - See :mod:`jax.ops` for details. + See :func:`jax.numpy.ndarray.at` for details. """ return scatter._scatter_update(self.array, self.index, values, lax.scatter_min, indices_are_sorted=indices_are_sorted, - unique_indices=unique_indices, mode=mode) + unique_indices=unique_indices, mode=mode, + normalize_indices=wrap_negative_indices) - def max(self, values, *, indices_are_sorted=False, unique_indices=False, - mode=None): + def max(self, values: ArrayLike, *, + indices_are_sorted: bool = False, unique_indices: bool = False, + mode: str | lax.GatherScatterMode | None = None, + wrap_negative_indices: bool = True) -> Array: """Pure equivalent of ``x[idx] = maximum(x[idx], y)``. Returns the value of ``x`` that would result from the NumPy-style :mod:indexed assignment ` ``x[idx] = maximum(x[idx], y)``. - See :mod:`jax.ops` for details. + See :func:`jax.numpy.ndarray.at` for details. 
""" return scatter._scatter_update(self.array, self.index, values, lax.scatter_max, indices_are_sorted=indices_are_sorted, - unique_indices=unique_indices, mode=mode) + unique_indices=unique_indices, mode=mode, + normalize_indices=wrap_negative_indices) _array_operators = { "getitem": _getitem, @@ -948,8 +1006,6 @@ def max(self, values, *, indices_are_sorted=False, unique_indices=False, "rsub": _defer_to_unrecognized_arg("-", ufuncs.subtract, swap=True), "mul": _defer_to_unrecognized_arg("*", ufuncs.multiply), "rmul": _defer_to_unrecognized_arg("*", ufuncs.multiply, swap=True), - "div": _defer_to_unrecognized_arg("/", ufuncs.divide), - "rdiv": _defer_to_unrecognized_arg("/", ufuncs.divide, swap=True), "truediv": _defer_to_unrecognized_arg("/", ufuncs.true_divide), "rtruediv": _defer_to_unrecognized_arg("/", ufuncs.true_divide, swap=True), "floordiv": _defer_to_unrecognized_arg("//", ufuncs.floor_divide), diff --git a/jax/_src/numpy/einsum.py b/jax/_src/numpy/einsum.py index 9d745643b596..372b643fbc02 100644 --- a/jax/_src/numpy/einsum.py +++ b/jax/_src/numpy/einsum.py @@ -13,7 +13,8 @@ # limitations under the License. import collections -from typing import overload, Any, Callable, Sequence +from typing import overload, Any +from collections.abc import Callable, Sequence import numpy as np import opt_einsum @@ -22,10 +23,11 @@ from jax._src import core from jax._src import dtypes from jax._src.api import jit, named_call +from jax._src.export import shape_poly from jax._src.lax import lax from jax._src.lax.lax import PrecisionLike from jax._src.numpy import util -from jax._src.sharding_impls import canonicalize_sharding, NamedSharding, PartitionSpec as P +from jax._src.sharding_impls import canonicalize_sharding, NamedSharding from jax._src.typing import Array, ArrayLike, DTypeLike from jax._src.util import partition_list, set_module, unzip2 @@ -288,6 +290,10 @@ def einsum( spec = operands[0] if isinstance(operands[0], str) else None path_type = 'optimal' if optimize is True else Unoptimized() if optimize is False else optimize + # Extract __jax_array__ before passing to contract_path() + operands = tuple(op.__jax_array__() if hasattr(op, "__jax_array__") else op + for op in operands) + # Allow handling of shape polymorphism non_constant_dim_types = { type(d) for op in operands if not isinstance(op, str) @@ -418,7 +424,8 @@ def _einsum( " instances. 
Please file a bug if this is not enough for your use case.") dtypes.check_user_dtype_supported(preferred_element_type, "einsum") if preferred_element_type is None: - preferred_element_type, output_weak_type = dtypes.result_type(*operands, return_weak_type_flag=True) + preferred_element_type, output_weak_type = dtypes.result_type( + *operands, return_weak_type_flag=True) else: output_weak_type = False @@ -548,12 +555,12 @@ def filter_singleton_dims(operand, names, other_shape, other_names): dot_general_out_sharding = None elif out_sharding is not None and names != result_names: if len(result_names) > len(out_sharding.spec): - out_sharding = out_sharding.with_spec( + out_sharding = out_sharding.update(spec= out_sharding.spec._normalized_spec_for_aval(len(result_names))) spec = out_sharding.spec inverse_spec = tuple(spec[result_names.index(name)] for name in names) dot_general_out_sharding = NamedSharding( - out_sharding.mesh, P(*inverse_spec)) + out_sharding.mesh, spec.update(partitions=inverse_spec)) else: dot_general_out_sharding = out_sharding # type: ignore dimension_numbers = ((lhs_cont, rhs_cont), (lhs_batch, rhs_batch)) @@ -576,3 +583,5 @@ def filter_singleton_dims(operand, names, other_shape, other_names): return lax._convert_element_type(operands[0], preferred_element_type, output_weak_type) + +_poly_einsum_handlers[shape_poly._DimExpr] = shape_poly._einsum_contract_path diff --git a/jax/_src/numpy/error.py b/jax/_src/numpy/error.py new file mode 100644 index 000000000000..cf69eb10b1a3 --- /dev/null +++ b/jax/_src/numpy/error.py @@ -0,0 +1,206 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import contextlib +from typing import Literal +from collections.abc import Sequence + +import numpy as np + +from jax._src import config +from jax._src import dtypes +from jax._src.typing import Array, ArrayLike + +Category = Literal["nan", "divide", "oob"] + + +def _is_category_disabled( + category: Category | None, +) -> bool: + """Check if the error checking behavior for the given category is disabled.""" + if category is None: + return False + if category == "nan": + raise ValueError("nan is deprecated. Use `_set_error_if_nan` instead.") + if category == "divide": + raise ValueError( + "divide is deprecated. Use `_set_error_if_divide_by_zero` instead." + ) + if category == "oob": + return config.error_checking_behavior_oob.value == "ignore" + raise ValueError(f"Invalid category: {category}") + + +def _set_error_if_with_category( + pred: Array, + /, + msg: str, + category: Category | None = None, +) -> None: + """Set the internal error state if any element of `pred` is `True`. + + This function is similar to :func:`set_error_if`, but it also takes a category + argument. The category can be "nan", "divide", or "oob". The error checking + behavior for each category can be configured using + :func:`set_error_checking_behavior`. If not provided, there will be no + category. 
+ + This function is intended for use in JAX internal APIs (e.g., `jax.numpy`) + to perform category-specific runtime checks tied to the operation being + performed. + """ + if _is_category_disabled(category): + return + + # TODO(mattjj): fix the circular import issue. + from jax._src import error_check as error_check_lib + error_check_lib.set_error_if(pred, msg) + + +def _set_error_if_nan(pred: Array, /): + """Set the internal error state if any element of `pred` is `NaN`. + + This function is disabled if the `jax_error_checking_behavior_nan` flag is + set to "ignore". + """ + if config.error_checking_behavior_nan.value == "ignore": + return + + if not dtypes.issubdtype(pred.dtype, np.floating): # only check floats + return + + # TODO(mattjj): fix the circular import issue. + from jax._src import error_check as error_check_lib + import jax.numpy as jnp + + error_check_lib.set_error_if(jnp.isnan(pred), "NaN encountered") + + +def _set_error_if_divide_by_zero(pred: Array, /): + """Set the internal error state if any element of `pred` is zero. + + This function is intended for checking if the denominator of a division is + zero. + + This function is disabled if the `jax_error_checking_behavior_divide` flag is + set to "ignore". + """ + if config.error_checking_behavior_divide.value == "ignore": + return + + # TODO(ayx): fix the circular import issue. + from jax._src import error_check as error_check_lib + import jax.numpy as jnp + zero = jnp.zeros_like(pred, shape=()) + error_check_lib.set_error_if(pred == zero, "Division by zero encountered") + + +def _check_precondition_oob_gather( + shape: tuple[int, ...], gather_indices: ArrayLike +) -> None: + """Check for out of bounds errors before calling `lax.gather`.""" + if config.error_checking_behavior_oob.value == "ignore": + return + + # TODO(mattjj): fix the circular import issue. + from jax._src import error_check as error_check_lib + import jax.numpy as jnp + + shape = jnp.array(shape, dtype=jnp.int32) + error_check_lib.set_error_if( + jnp.logical_or( + jnp.min(gather_indices) < -shape, + jnp.max(gather_indices) >= shape, + ), + "Out of bounds encountered before calling `lax.gather`", + ) + + +def _check_precondition_oob_dynamic_slice( + shape: tuple[int, ...], + start_indices: Sequence[ArrayLike], + slice_sizes: list[int], + allow_negative_indices: list[bool], +) -> None: + """Check for out of bounds errors before calling `lax.dynamic_slice`.""" + if config.error_checking_behavior_oob.value == "ignore": + return + + # TODO(mattjj): fix the circular import issue. + from jax._src import error_check as error_check_lib + import jax.numpy as jnp + + shape = jnp.array(shape, dtype=jnp.int32) + start_indices = jnp.array(start_indices, dtype=jnp.int32) + slice_sizes = jnp.array(slice_sizes, dtype=jnp.int32) + allow_negative_indices = jnp.array(allow_negative_indices, dtype=jnp.bool_) + + lower_bound = jnp.where(allow_negative_indices, -shape, 0) + error_check_lib.set_error_if( + jnp.logical_or( + jnp.minimum(start_indices, start_indices + slice_sizes) < lower_bound, + jnp.maximum(start_indices, start_indices + slice_sizes) >= shape, + ), + "Out of bounds encountered before calling `lax.dynamic_slice`", + ) + + +Behavior = Literal["ignore", "raise"] + + +class error_checking_behavior: + """A context manager to set the error checking behavior. + + If both `all` and a category are provided, the category will override the + `all` setting. + + When the error checking behavior is set to "ignore", all errors will be + ignored. 
When set to "raise", errors will be detected and recorded, but an + exception will not be raised immediately. Users must call + :func:`raise_if_error` to at the end of the computation to raise the + exception. + """ + + def __init__( + self, + *, + all: Behavior | None = None, + nan: Behavior | None = None, + divide: Behavior | None = None, + oob: Behavior | None = None, + ) -> None: + new_settings = {} + if all is not None: + new_settings["nan"] = new_settings["divide"] = new_settings["oob"] = all + if nan is not None: + new_settings["nan"] = nan + if divide is not None: + new_settings["divide"] = divide + if oob is not None: + new_settings["oob"] = oob + self.new_settings = new_settings + self.stack = contextlib.ExitStack() + + def __enter__(self): + config_flags = { + "nan": config.error_checking_behavior_nan, + "divide": config.error_checking_behavior_divide, + "oob": config.error_checking_behavior_oob, + } + for key, value in self.new_settings.items(): + self.stack.enter_context(config_flags[key](value)) + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.stack.close() diff --git a/jax/_src/numpy/fft.py b/jax/_src/numpy/fft.py index f962438f23bb..970847532e46 100644 --- a/jax/_src/numpy/fft.py +++ b/jax/_src/numpy/fft.py @@ -18,8 +18,8 @@ import operator import numpy as np -from jax import lax from jax._src import dtypes +from jax._src.lax import fft as lax_fft from jax._src.lib import xla_client from jax._src.util import safe_zip from jax._src.numpy.util import ensure_arraylike, promote_dtypes_inexact @@ -45,7 +45,7 @@ def _fft_norm(s: Array, func_name: str, norm: str) -> Array: '"ortho" or "forward".') -def _fft_core(func_name: str, fft_type: lax.FftType, a: ArrayLike, +def _fft_core(func_name: str, fft_type: lax_fft.FftType, a: ArrayLike, s: Shape | None, axes: Sequence[int] | None, norm: str | None) -> Array: full_name = f"jax.numpy.fft.{func_name}" @@ -80,14 +80,14 @@ def _fft_core(func_name: str, fft_type: lax.FftType, a: ArrayLike, in_s = list(arr.shape) for axis, x in safe_zip(axes, s): in_s[axis] = x - if fft_type == lax.FftType.IRFFT: + if fft_type == lax_fft.FftType.IRFFT: in_s[-1] = (in_s[-1] // 2 + 1) # Cropping arr = arr[tuple(map(slice, in_s))] # Padding arr = jnp.pad(arr, [(0, x-y) for x, y in zip(in_s, arr.shape)]) else: - if fft_type == lax.FftType.IRFFT: + if fft_type == lax_fft.FftType.IRFFT: s = [arr.shape[axis] for axis in axes[:-1]] if axes: s += [max(0, 2 * (arr.shape[axes[-1]] - 1))] @@ -103,10 +103,10 @@ def _fft_core(func_name: str, fft_type: lax.FftType, a: ArrayLike, return transformed -def _fft_core_nd(arr: Array, fft_type: lax.FftType, s: Shape) -> Array: +def _fft_core_nd(arr: Array, fft_type: lax_fft.FftType, s: Shape) -> Array: # XLA supports N-D transforms up to N=3 so we use XLA's FFT N-D directly. if len(s) <= 3: - return lax.fft(arr, fft_type, tuple(s)) + return lax_fft.fft(arr, fft_type, tuple(s)) # For larger N, we repeatedly apply N<=3 transforms until we reach the # requested dimension. 
We special case N=4 to use two 2-D transforms instead @@ -115,16 +115,16 @@ def _fft_core_nd(arr: Array, fft_type: lax.FftType, s: Shape) -> Array: n = 2 if len(s) == 4 else 3 src = tuple(range(arr.ndim - len(s), arr.ndim - n)) dst = tuple(range(arr.ndim - len(s) + n, arr.ndim)) - if fft_type in {lax.FftType.RFFT, lax.FftType.FFT}: - arr = lax.fft(arr, fft_type, tuple(s)[-n:]) + if fft_type in {lax_fft.FftType.RFFT, lax_fft.FftType.FFT}: + arr = lax_fft.fft(arr, fft_type, tuple(s)[-n:]) arr = jnp.moveaxis(arr, src, dst) - arr = _fft_core_nd(arr, lax.FftType.FFT, s[:-n]) + arr = _fft_core_nd(arr, lax_fft.FftType.FFT, s[:-n]) arr = jnp.moveaxis(arr, dst, src) else: arr = jnp.moveaxis(arr, src, dst) - arr = _fft_core_nd(arr, lax.FftType.IFFT, s[:-n]) + arr = _fft_core_nd(arr, lax_fft.FftType.IFFT, s[:-n]) arr = jnp.moveaxis(arr, dst, src) - arr = lax.fft(arr, fft_type, tuple(s)[-n:]) + arr = lax_fft.fft(arr, fft_type, tuple(s)[-n:]) return arr @@ -199,7 +199,7 @@ def fftn(a: ArrayLike, s: Shape | None = None, >>> jnp.allclose(x, jnp.fft.ifftn(x_fftn)) Array(True, dtype=bool) """ - return _fft_core('fftn', lax.FftType.FFT, a, s, axes, norm) + return _fft_core('fftn', lax_fft.FftType.FFT, a, s, axes, norm) def ifftn(a: ArrayLike, s: Shape | None = None, @@ -267,7 +267,7 @@ def ifftn(a: ArrayLike, s: Shape | None = None, [[ 2.5 +0.j 0. -0.58j 0. +0.58j] [ 0.17+0.j -0.83-0.29j -0.83+0.29j]] """ - return _fft_core('ifftn', lax.FftType.IFFT, a, s, axes, norm) + return _fft_core('ifftn', lax_fft.FftType.IFFT, a, s, axes, norm) def rfftn(a: ArrayLike, s: Shape | None = None, @@ -358,7 +358,7 @@ def rfftn(a: ArrayLike, s: Shape | None = None, >>> jnp.fft.rfftn(x1) Array([10.+0.j, -2.+2.j, -2.+0.j], dtype=complex64) """ - return _fft_core('rfftn', lax.FftType.RFFT, a, s, axes, norm) + return _fft_core('rfftn', lax_fft.FftType.RFFT, a, s, axes, norm) def irfftn(a: ArrayLike, s: Shape | None = None, @@ -435,7 +435,7 @@ def irfftn(a: ArrayLike, s: Shape | None = None, [[-2., -2., -2.], [-2., -2., -2.]]], dtype=float32) """ - return _fft_core('irfftn', lax.FftType.IRFFT, a, s, axes, norm) + return _fft_core('irfftn', lax_fft.FftType.IRFFT, a, s, axes, norm) def _axis_check_1d(func_name: str, axis: int | None): @@ -446,7 +446,7 @@ def _axis_check_1d(func_name: str, axis: int | None): "Got axis = %r." 
% (full_name, full_name, axis) ) -def _fft_core_1d(func_name: str, fft_type: lax.FftType, +def _fft_core_1d(func_name: str, fft_type: lax_fft.FftType, a: ArrayLike, n: int | None, axis: int | None, norm: str | None) -> Array: _axis_check_1d(func_name, axis) @@ -514,7 +514,7 @@ def fft(a: ArrayLike, n: int | None = None, >>> jnp.allclose(x, jnp.fft.ifft(x_fft)) Array(True, dtype=bool) """ - return _fft_core_1d('fft', lax.FftType.FFT, a, n=n, axis=axis, + return _fft_core_1d('fft', lax_fft.FftType.FFT, a, n=n, axis=axis, norm=norm) @@ -570,7 +570,7 @@ def ifft(a: ArrayLike, n: int | None = None, [ 0.67+0.58j -0.5 +1.44j 0.17+2.02j 1.83+0.29j] [ 0.67-0.58j -0.5 -1.44j 0.17-2.02j 1.83-0.29j]] """ - return _fft_core_1d('ifft', lax.FftType.IFFT, a, n=n, axis=axis, + return _fft_core_1d('ifft', lax_fft.FftType.IFFT, a, n=n, axis=axis, norm=norm) @@ -631,7 +631,7 @@ def rfft(a: ArrayLike, n: int | None = None, [ 1.-2.j, 3.-4.j, 5.-6.j], [-1.+0.j, -1.+0.j, -1.+0.j]], dtype=complex64) """ - return _fft_core_1d('rfft', lax.FftType.RFFT, a, n=n, axis=axis, + return _fft_core_1d('rfft', lax_fft.FftType.RFFT, a, n=n, axis=axis, norm=norm) @@ -691,7 +691,7 @@ def irfft(a: ArrayLike, n: int | None = None, [-0.75, -1.25, -1.75], [ 0.25, 0.75, 1.25]], dtype=float32) """ - return _fft_core_1d('irfft', lax.FftType.IRFFT, a, n=n, axis=axis, + return _fft_core_1d('irfft', lax_fft.FftType.IRFFT, a, n=n, axis=axis, norm=norm) @@ -712,7 +712,7 @@ def hfft(a: ArrayLike, n: int | None = None, are supported. Default is "backward". Returns: - A real-valued array containing the one-dimensional discret Fourier transform + A real-valued array containing the one-dimensional discrete Fourier transform of ``a`` by exploiting its inherent Hermitian-symmetry, having a dimension of ``n`` along ``axis``. 
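As a rough consistency check of the N-D decomposition performed by `_fft_core_nd` above: a transform over more than three axes is split into chained <=3-D XLA FFTs, but should still agree with NumPy. This is an illustrative sketch, not part of the change; tolerances are kept loose to allow for float32 accumulation error.

    import numpy as np
    import jax.numpy as jnp

    x = np.random.default_rng(0).normal(size=(2, 3, 4, 5)).astype(np.float32)
    # A 4-D FFT exceeds XLA's native 3-D limit, so it is performed as two
    # 2-D transforms internally; the values should still match NumPy's fftn.
    np.testing.assert_allclose(jnp.fft.fftn(x), np.fft.fftn(x), rtol=1e-4, atol=1e-3)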
@@ -781,7 +781,7 @@ def hfft(a: ArrayLike, n: int | None = None, conj_a = ufuncs.conj(a) _axis_check_1d('hfft', axis) nn = (conj_a.shape[axis] - 1) * 2 if n is None else n - return _fft_core_1d('hfft', lax.FftType.IRFFT, conj_a, n=n, axis=axis, + return _fft_core_1d('hfft', lax_fft.FftType.IRFFT, conj_a, n=n, axis=axis, norm=norm) * nn @@ -831,12 +831,12 @@ def ihfft(a: ArrayLike, n: int | None = None, _axis_check_1d('ihfft', axis) arr = jnp.asarray(a) nn = arr.shape[axis] if n is None else n - output = _fft_core_1d('ihfft', lax.FftType.RFFT, arr, n=n, axis=axis, + output = _fft_core_1d('ihfft', lax_fft.FftType.RFFT, arr, n=n, axis=axis, norm=norm) return ufuncs.conj(output) * (1 / nn) -def _fft_core_2d(func_name: str, fft_type: lax.FftType, a: ArrayLike, +def _fft_core_2d(func_name: str, fft_type: lax_fft.FftType, a: ArrayLike, s: Shape | None, axes: Sequence[int], norm: str | None) -> Array: full_name = f"jax.numpy.fft.{func_name}" @@ -923,7 +923,7 @@ def fft2(a: ArrayLike, s: Shape | None = None, axes: Sequence[int] = (-2,-1), >>> jnp.allclose(x, jnp.fft.ifft2(x_fft2)) Array(True, dtype=bool) """ - return _fft_core_2d('fft2', lax.FftType.FFT, a, s=s, axes=axes, + return _fft_core_2d('fft2', lax_fft.FftType.FFT, a, s=s, axes=axes, norm=norm) @@ -995,7 +995,7 @@ def ifft2(a: ArrayLike, s: Shape | None = None, axes: Sequence[int] = (-2,-1), [-0.33-0.58j, -0.33-0.58j], [-0.33+0.58j, -0.33+0.58j]]], dtype=complex64) """ - return _fft_core_2d('ifft2', lax.FftType.IFFT, a, s=s, axes=axes, + return _fft_core_2d('ifft2', lax_fft.FftType.IFFT, a, s=s, axes=axes, norm=norm) @@ -1074,7 +1074,7 @@ def rfft2(a: ArrayLike, s: Shape | None = None, axes: Sequence[int] = (-2,-1), [ 3.47+10.11j, 6.43+11.42j, 9.38+12.74j], [ 3.19 +1.63j, 4.4 +1.38j, 5.61 +1.12j]]], dtype=complex64) """ - return _fft_core_2d('rfft2', lax.FftType.RFFT, a, s=s, axes=axes, + return _fft_core_2d('rfft2', lax_fft.FftType.RFFT, a, s=s, axes=axes, norm=norm) @@ -1149,7 +1149,7 @@ def irfft2(a: ArrayLike, s: Shape | None = None, axes: Sequence[int] = (-2,-1), [ 0. , 0. , 0. ], [ 0. , 0. , 0. ]]], dtype=float32) """ - return _fft_core_2d('irfft2', lax.FftType.IRFFT, a, s=s, axes=axes, + return _fft_core_2d('irfft2', lax_fft.FftType.IRFFT, a, s=s, axes=axes, norm=norm) @@ -1186,21 +1186,8 @@ def fftfreq(n: int, d: ArrayLike = 1.0, *, dtype: DTypeLike | None = None, "The d argument of jax.numpy.fft.fftfreq only takes a single value. " "Got d = %s." 
% list(d)) - k = jnp.zeros(n, dtype=dtype, device=device) - if n % 2 == 0: - # k[0: n // 2 - 1] = jnp.arange(0, n // 2 - 1) - k = k.at[0: n // 2].set(jnp.arange(0, n // 2, dtype=dtype)) - - # k[n // 2:] = jnp.arange(-n // 2, -1) - k = k.at[n // 2:].set(jnp.arange(-n // 2, 0, dtype=dtype)) - - else: - # k[0: (n - 1) // 2] = jnp.arange(0, (n - 1) // 2) - k = k.at[0: (n - 1) // 2 + 1].set(jnp.arange(0, (n - 1) // 2 + 1, dtype=dtype)) - - # k[(n - 1) // 2 + 1:] = jnp.arange(-(n - 1) // 2, -1) - k = k.at[(n - 1) // 2 + 1:].set(jnp.arange(-(n - 1) // 2, 0, dtype=dtype)) - + i = jnp.arange(n, dtype=dtype, device=device) + k = ((i + n//2) % n - n//2) return k / jnp.array(d * n, dtype=dtype, device=device) diff --git a/jax/_src/numpy/index_tricks.py b/jax/_src/numpy/index_tricks.py index ec67d7489f30..8b70c37192c2 100644 --- a/jax/_src/numpy/index_tricks.py +++ b/jax/_src/numpy/index_tricks.py @@ -17,17 +17,18 @@ from collections.abc import Iterable from typing import Any, Union -import jax +import numpy as np + +from jax._src import config from jax._src import core +from jax._src.numpy.array import array from jax._src.numpy.util import promote_dtypes from jax._src.numpy.lax_numpy import ( - arange, array, concatenate, expand_dims, linspace, meshgrid, stack, transpose + arange, concatenate, expand_dims, linspace, meshgrid, stack, transpose ) from jax._src.typing import Array, ArrayLike from jax._src.util import set_module -import numpy as np - export = set_module('jax.numpy') @@ -83,7 +84,7 @@ def __getitem__(self, key: slice | tuple[slice, ...]) -> Array: if isinstance(key, slice): return _make_1d_grid_from_slice(key, op_name="mgrid") output: Iterable[Array] = (_make_1d_grid_from_slice(k, op_name="mgrid") for k in key) - with jax.numpy_dtype_promotion('standard'): + with config.numpy_dtype_promotion('standard'): output = promote_dtypes(*output) output_arr = meshgrid(*output, indexing='ij', sparse=False) if len(output_arr) == 0: @@ -128,7 +129,7 @@ def __getitem__( if isinstance(key, slice): return _make_1d_grid_from_slice(key, op_name="ogrid") output: Iterable[Array] = (_make_1d_grid_from_slice(k, op_name="ogrid") for k in key) - with jax.numpy_dtype_promotion('standard'): + with config.numpy_dtype_promotion('standard'): output = promote_dtypes(*output) return meshgrid(*output, indexing='ij', sparse=True) diff --git a/jax/_src/numpy/indexing.py b/jax/_src/numpy/indexing.py index 5d59bb53b457..934246dc8cbd 100644 --- a/jax/_src/numpy/indexing.py +++ b/jax/_src/numpy/indexing.py @@ -18,11 +18,11 @@ from functools import partial import operator import string -from typing import Any, NamedTuple, Sequence +from typing import Any, NamedTuple +from collections.abc import Sequence import numpy as np -import jax from jax import lax from jax._src import array from jax._src import config @@ -33,14 +33,14 @@ from jax._src.api import jit from jax._src.lax import lax as lax_internal from jax._src.numpy import einsum -from jax._src import mesh as mesh_lib -from jax._src.pjit import auto_axes +from jax._src.numpy import error as jnp_error from jax._src.numpy import lax_numpy from jax._src.numpy import ufuncs from jax._src.numpy import util +from jax._src.pjit import auto_axes from jax._src.tree_util import tree_flatten from jax._src.typing import Array, ArrayLike, StaticScalar -from jax._src.util import canonicalize_axis, set_module, tuple_replace, safe_zip +from jax._src.util import canonicalize_axis, safe_zip, set_module, tuple_update export = set_module('jax.numpy') @@ -315,8 +315,10 @@ def replace(tup, val): 
return lax.full(out_shape, 0, a.dtype) if mode == "one_hot": + from jax import nn # pytype: disable=import-error + indices = _normalize_index(indices, axis_size) - hot = jax.nn.one_hot(indices, axis_size, dtype=np.bool_) + hot = nn.one_hot(indices, axis_size, dtype=np.bool_) if a.ndim == 1: return einsum.einsum("...b,b->...", hot, a, preferred_element_type=a.dtype) if axis_int > len(string.ascii_letters) - 2: @@ -397,7 +399,9 @@ def replace(tup, val): def _make_along_axis_idx(shape, indices, axis): - return tuple_replace(lax_numpy.indices(shape, sparse=True), axis, indices) + if axis < 0: + axis += len(shape) + return tuple_update(lax_numpy.indices(shape, sparse=True), axis, indices) @export @@ -520,14 +524,13 @@ def _is_contiguous_slice(idx): (idx.stop is None or _is_integer_index(idx.stop)) and (idx.step is None or (_is_integer_index(idx.step) and idx.step == 1))) -def _attempt_rewriting_take_via_slice(arr: Array, idx: Any, mode: str | None) -> Array | None: +def _attempt_rewriting_take_via_slice(arr: Array, idx: Any, mode: str | None, + out_sharding=None) -> Array | None: # attempt to compute _rewriting_take via lax.slice(); return None if not possible. idx = idx if isinstance(idx, tuple) else (idx,) if not all(isinstance(i, int) for i in arr.shape): return None - if len(idx) > arr.ndim: - return None if any(i is None for i in idx): return None # TODO(jakevdp): handle newaxis case # For symbolic dimensions fallback to gather @@ -535,10 +538,13 @@ def _attempt_rewriting_take_via_slice(arr: Array, idx: Any, mode: str | None) -> for i in idx if isinstance(i, slice) for elt in (i.start, i.stop, i.step)): return None - if any(i is Ellipsis for i in idx): - # Remove ellipses and add trailing `slice(None)`. + # Remove ellipses and pad with trailing `slice(None)` if necessary. + # Do this before checking against rank of `arr` so that `...` can + # count as no dimensions at all (e.g. `my_1d_array[:, ...]` succeeds) idx = _canonicalize_tuple_index(arr.ndim, idx=idx) + if len(idx) > arr.ndim: + return None simple_revs = {i for i, ind in enumerate(idx) if _is_simple_reverse_slice(ind)} int_indices = {i for i, (ind, size) in enumerate(zip(idx, arr.shape)) @@ -570,7 +576,7 @@ def _attempt_rewriting_take_via_slice(arr: Array, idx: Any, mode: str | None) -> idx += (arr.ndim - len(idx)) * (slice(None),) start_indices: Sequence[ArrayLike] = [] - slice_sizes: Sequence[int] = [] + slice_sizes: list[int] = [] allow_negative_indices: list[bool] = [] for ind, size in safe_zip(idx, arr.shape): @@ -587,6 +593,7 @@ def _attempt_rewriting_take_via_slice(arr: Array, idx: Any, mode: str | None) -> slice_sizes.append(1) allow_negative_indices.append( not isinstance(ind, (int, np.integer)) or bool(ind < 0)) + # Try to use static slicing when possible. if all(isinstance(i, (int, np.integer)) and i >= 0 for i in start_indices): int_start_indices = [int(i) for i in start_indices] # type: ignore @@ -598,25 +605,34 @@ def _attempt_rewriting_take_via_slice(arr: Array, idx: Any, mode: str | None) -> # start indices to have matching types. 
if len(start_indices) > 1: start_indices = util.promote_dtypes(*start_indices) - arr = lax.dynamic_slice( - arr, start_indices=start_indices, slice_sizes=slice_sizes, - allow_negative_indices=allow_negative_indices) + jnp_error._check_precondition_oob_dynamic_slice( + arr.shape, start_indices, slice_sizes, allow_negative_indices + ) + internal_ds = partial(lax.dynamic_slice, slice_sizes=slice_sizes, + allow_negative_indices=allow_negative_indices) + if out_sharding is not None: + arr = auto_axes(internal_ds, out_sharding=out_sharding)(arr, start_indices) + else: + arr = internal_ds(arr, start_indices) if int_indices: arr = lax.squeeze(arr, tuple(int_indices)) return arr def rewriting_take(arr, idx, indices_are_sorted=False, unique_indices=False, - mode=None, fill_value=None, out_sharding=None): + mode=None, fill_value=None, normalize_indices=True, + out_sharding=None): # Computes arr[idx]. # All supported cases of indexing can be implemented as an XLA gather, # followed by an optional reverse and broadcast_in_dim. - # For simplicity of generated primitives, we call lax.dynamic_slice in the - # simplest cases: i.e. non-dynamic arrays indexed with integers and slices. - - if (result := _attempt_rewriting_take_via_slice(arr, idx, mode)) is not None: - return result + # For simplicity of generated primitives, we call lax.slice or lax.dynamic_slice + # in the simplest cases: i.e. non-dynamic arrays indexed with integers and slices. + # TODO(jakevdp): lower to slice even when normalize_indices is False + if normalize_indices: + result = _attempt_rewriting_take_via_slice(arr, idx, mode, out_sharding) + if result is not None: + return result # TODO(mattjj,dougalm): expand dynamic shape indexing support if config.dynamic_shapes.value and arr.ndim > 0: @@ -630,16 +646,24 @@ def rewriting_take(arr, idx, indices_are_sorted=False, unique_indices=False, return lax.dynamic_index_in_dim(arr, idx, keepdims=False) treedef, static_idx, dynamic_idx = split_index_for_jit(idx, arr.shape) - return _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted, - unique_indices, mode, fill_value, out_sharding) + internal_gather = partial( + _gather, treedef=treedef, static_idx=static_idx, + indices_are_sorted=indices_are_sorted, unique_indices=unique_indices, + mode=mode, fill_value=fill_value, normalize_indices=normalize_indices) + if out_sharding is not None: + return auto_axes(internal_gather, out_sharding=out_sharding + )(arr, dynamic_idx) + return internal_gather(arr, dynamic_idx) + # TODO(phawkins): re-enable jit after fixing excessive recompilation for # slice indexes (e.g., slice(0, 5, None), slice(10, 15, None), etc.). # @partial(jit, static_argnums=(1, 2)) -def _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted, - unique_indices, mode, fill_value, out_sharding): +def _gather(arr, dynamic_idx, *, treedef, static_idx, indices_are_sorted, + unique_indices, mode, fill_value, normalize_indices): idx = merge_static_and_dynamic_indices(treedef, static_idx, dynamic_idx) - indexer = index_to_gather(np.shape(arr), idx) # shared with _scatter_update + indexer = index_to_gather(np.shape(arr), idx, normalize_indices=normalize_indices) # shared with _scatter_update + jnp_error._check_precondition_oob_gather(arr.shape, indexer.gather_indices) y = arr if fill_value is not None: @@ -660,26 +684,19 @@ def _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted, # We avoid generating a gather when indexer.gather_indices.size is empty. 
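# Sketch (illustration only; exact jaxpr text varies across JAX versions): the
# slice fast path described above can be observed by inspecting jaxprs. A
# static basic slice should lower to slice/dynamic_slice rather than gather.
import jax
import jax.numpy as jnp

x = jnp.arange(8)
print(jax.make_jaxpr(lambda a: a[2:5])(x))
print(jax.make_jaxpr(lambda a, i: a[i])(x, 2))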
if not core.is_empty_shape(indexer.gather_indices.shape): - internal_gather = partial( - lax.gather, - dimension_numbers=indexer.dnums, - slice_sizes=indexer.gather_slice_shape, + y = lax.gather( + y, indexer.gather_indices, indexer.dnums, indexer.gather_slice_shape, unique_indices=unique_indices or indexer.unique_indices, indices_are_sorted=indices_are_sorted or indexer.indices_are_sorted, mode=mode, fill_value=fill_value) - if out_sharding is not None: - internal_gather = auto_axes( - internal_gather, axes=mesh_lib.get_abstract_mesh().axis_names, - out_shardings=out_sharding) - y = internal_gather(y, indexer.gather_indices) # Reverses axes with negative strides. if indexer.reversed_y_dims: y = lax.rev(y, indexer.reversed_y_dims) - # This adds np.newaxis/None dimensions. return lax.expand_dims(y, indexer.newaxis_dims) + class _Indexer(NamedTuple): # The expected shape of the slice output. slice_shape: Sequence[int] @@ -1259,16 +1276,16 @@ def put(a: ArrayLike, ind: ArrayLike, v: ArrayLike, [ 0, 0, 20, 0, 0], [ 0, 0, 0, 0, 30]], dtype=int32) """ + if inplace: + raise ValueError( + "jax.numpy.put cannot modify arrays in-place, because JAX arrays are immutable. " + "Pass inplace=False to instead return an updated array.") arr, ind_arr, _ = util.ensure_arraylike("put", a, ind, v) ind_arr = ind_arr.ravel() v_arr = lax_numpy.ravel(v) if not arr.size or not ind_arr.size or not v_arr.size: return arr v_arr = lax_numpy._tile_to_size(v_arr, len(ind_arr)) - if inplace: - raise ValueError( - "jax.numpy.put cannot modify arrays in-place, because JAX arrays are immutable. " - "Pass inplace=False to instead return an updated array.") if mode is None: scatter_mode = "drop" elif mode == "clip": diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 96efc48062e1..abbdf6d0411d 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -28,59 +28,50 @@ import builtins from collections.abc import Callable, Sequence from functools import partial -import importlib import math import operator import os from typing import (Any, IO, Literal, Protocol, TypeVar, Union, overload) import warnings -import jax -from jax import jit from jax import lax +from jax._src.api import jit +from jax._src import api from jax._src import config from jax._src import core from jax._src import deprecations from jax._src import dtypes -from jax._src import xla_bridge from jax._src.api_util import _ensure_index_tuple from jax._src.custom_derivatives import custom_jvp from jax._src.lax import lax as lax_internal from jax._src.lax.lax import (PrecisionLike,_array_copy, _sort_le_comparator, _sort_lt_comparator) from jax._src.lib import xla_client as xc -from jax._src.numpy.array_creation import (empty, empty_like, full, - ones, ones_like, zeros, zeros_like) +from jax._src.numpy.array import array, asarray from jax._src.numpy import indexing from jax._src.numpy import reductions from jax._src.numpy import tensor_contractions from jax._src.numpy import ufuncs from jax._src.numpy import util +from jax._src.numpy.array_creation import (empty, empty_like, full, linspace, + ones, ones_like, zeros, zeros_like) from jax._src.numpy.sorting import argsort, sort from jax._src.numpy.vectorize import vectorize from jax._src.typing import ( - Array, ArrayLike, DType, DTypeLike, DeprecatedArg, DimSize, Shape + Array, ArrayLike, DType, DTypeLike, DeprecatedArg, DimSize, Shape, SupportsShape ) from jax._src.util import ( - NumpyComplexWarning, canonicalize_axis as _canonicalize_axis, + canonicalize_axis as 
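# Sketch (illustration only, arbitrary values) of jnp.put after the reordering
# above, which raises the inplace=True error before converting any inputs.
import jax.numpy as jnp

a = jnp.zeros(5, dtype=int)
print(jnp.put(a, jnp.array([1, 3]), jnp.array([10, 30]), inplace=False))
# [ 0 10  0 30  0]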
_canonicalize_axis, ceil_of_ratio, safe_zip, set_module, unzip2) from jax.sharding import Sharding -from jax._src.sharding_impls import SingleDeviceSharding -from jax.tree_util import tree_leaves, tree_map +from jax._src.sharding_impls import NamedSharding, PartitionSpec as P +from jax._src.mesh import get_abstract_mesh +from jax._src.pjit import auto_axes +from jax.tree_util import tree_map import numpy as np export = set_module('jax.numpy') -for pkg_name in ['jax_cuda12_plugin', 'jax.jaxlib.cuda']: - try: - cuda_plugin_extension = importlib.import_module( - f'{pkg_name}.cuda_plugin_extension' - ) - except ImportError: - cuda_plugin_extension = None # type: ignore - else: - break - T = TypeVar('T') # Wrappers for NumPy printoptions @@ -169,7 +160,7 @@ def _dtype(x: Any) -> DType: can_cast = dtypes.can_cast promote_types = dtypes.promote_types -ComplexWarning = NumpyComplexWarning +ComplexWarning = np.exceptions.ComplexWarning _lax_const = lax_internal._const @@ -270,7 +261,7 @@ def load(file: IO[bytes] | str | os.PathLike[Any], *args: Any, **kwargs: Any) -> def fmin(x1: ArrayLike, x2: ArrayLike) -> Array: """Return element-wise minimum of the input arrays. - JAX implemtentation of :func:`numpy.fmin`. + JAX implementation of :func:`numpy.fmin`. Args: x1: input array or scalar. @@ -506,7 +497,7 @@ def isscalar(element: Any) -> bool: """ if np.isscalar(element): return True - elif isinstance(element, (np.ndarray, jax.Array)): + elif isinstance(element, (np.ndarray, Array)): return element.ndim == 0 elif hasattr(element, '__jax_array__'): return asarray(element).ndim == 0 @@ -552,7 +543,7 @@ def result_type(*args: Any) -> DType: For details on 64-bit values, refer to `Sharp bits - double precision`_: - .. _Sharp bits - double precision: https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision + .. 
_Sharp bits - double precision: https://docs.jax.dev/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision """ return dtypes.result_type(*args) @@ -911,11 +902,11 @@ def histogram(a: ArrayLike, bins: ArrayLike = 10, Array(True, dtype=bool) """ if weights is None: - util.check_arraylike("histogram", a, bins) + a, _ = util.ensure_arraylike("histogram", a, bins) a, = util.promote_dtypes_inexact(a) weights = ones_like(a) else: - util.check_arraylike("histogram", a, bins, weights) + a, _, weights = util.ensure_arraylike("histogram", a, bins, weights) if np.shape(a) != np.shape(weights): raise ValueError("weights should have the same shape as a.") a, weights = util.promote_dtypes_inexact(a, weights) @@ -1005,7 +996,7 @@ def histogram2d(x: ArrayLike, y: ArrayLike, bins: ArrayLike | list[ArrayLike] = >>> jnp.allclose(normed_sum, 1.0) Array(True, dtype=bool) """ - util.check_arraylike("histogram2d", x, y) + x, y = util.ensure_arraylike("histogram2d", x, y) try: N = len(bins) # type: ignore[arg-type] except TypeError: @@ -1077,10 +1068,10 @@ def histogramdd(sample: ArrayLike, bins: ArrayLike | list[ArrayLike] = 10, Array(True, dtype=bool) """ if weights is None: - util.check_arraylike("histogramdd", sample) + sample = util.ensure_arraylike("histogramdd", sample) sample, = util.promote_dtypes_inexact(sample) else: - util.check_arraylike("histogramdd", sample, weights) + sample, weights = util.ensure_arraylike("histogramdd", sample, weights) if np.shape(weights) != np.shape(sample)[:1]: raise ValueError("should have one weight for each sample.") sample, weights = util.promote_dtypes_inexact(sample, weights) @@ -1203,8 +1194,8 @@ def transpose(a: ArrayLike, axes: Sequence[int] | None = None) -> Array: Array([[1, 3], [2, 4]], dtype=int32) """ - util.check_arraylike("transpose", a) - axes_ = list(range(np.ndim(a))[::-1]) if axes is None else axes + a = util.ensure_arraylike("transpose", a) + axes_ = list(range(a.ndim)[::-1]) if axes is None else axes axes_ = [_canonicalize_axis(i, np.ndim(a)) for i in axes_] return lax.transpose(a, axes_) @@ -1235,7 +1226,7 @@ def permute_dims(a: ArrayLike, /, axes: tuple[int, ...]) -> Array: [2, 5], [3, 6]], dtype=int32) """ - util.check_arraylike("permute_dims", a) + a = util.ensure_arraylike("permute_dims", a) return lax.transpose(a, axes) @@ -1285,8 +1276,8 @@ def matrix_transpose(x: ArrayLike, /) -> Array: [[5, 7], [6, 8]]], dtype=int32) """ - util.check_arraylike("matrix_transpose", x) - ndim = np.ndim(x) + x = util.ensure_arraylike("matrix_transpose", x) + ndim = x.ndim if ndim < 2: raise ValueError(f"x must be at least two-dimensional for matrix_transpose; got {ndim=}") axes = (*range(ndim - 2), ndim - 1, ndim - 2) @@ -1353,7 +1344,7 @@ def rot90(m: ArrayLike, k: int = 1, axes: tuple[int, int] = (0, 1)) -> Array: [11, 8], [12, 9]]], dtype=int32) """ - util.check_arraylike("rot90", m) + m = util.ensure_arraylike("rot90", m) if np.ndim(m) < 2: raise ValueError("rot90 requires its first argument to have ndim at least " f"two, but got first argument of shape {np.shape(m)}, " @@ -1589,6 +1580,7 @@ def angle(z: ArrayLike, deg: bool = False) -> Array: [[ 71.57 -68.2 ] [-36.87 33.69]] """ + z = util.ensure_arraylike('angle', z) re = ufuncs.real(z) im = ufuncs.imag(z) dtype = _dtype(re) @@ -1944,9 +1936,8 @@ def isrealobj(x: Any) -> bool: @export def reshape( - a: ArrayLike, shape: DimSize | Shape | None = None, order: str = "C", *, - newshape: DimSize | Shape | DeprecatedArg = DeprecatedArg(), - copy: bool | None = None) -> Array: + a: ArrayLike, shape: 
DimSize | Shape, order: str = "C", *, + copy: bool | None = None, out_sharding=None) -> Array: """Return a reshaped copy of an array. JAX implementation of :func:`numpy.reshape`, implemented in terms of @@ -1962,8 +1953,6 @@ def reshape( JAX does not support ``order="A"``. copy: unused by JAX; JAX always returns a copy, though under JIT the compiler may optimize such copies away. - newshape: deprecated alias of the ``shape`` argument. Will result in a - :class:`DeprecationWarning` if used. Returns: reshaped copy of input array with the specified shape. @@ -2021,25 +2010,18 @@ def reshape( __tracebackhide__ = True util.check_arraylike("reshape", a) - # TODO(jakevdp): finalized 2024-12-2; remove argument after JAX v0.4.40. - if not isinstance(newshape, DeprecatedArg): - raise TypeError("The newshape argument to jnp.reshape was removed in JAX v0.4.36." - " Use shape instead.") - if shape is None: - raise TypeError( - "jnp.shape requires passing a `shape` argument, but none was given." - ) try: - # forward to method for ndarrays - return a.reshape(shape, order=order) # type: ignore[call-overload,union-attr] + if out_sharding is None: + # forward to method for ndarrays + return a.reshape(shape, order=order) # type: ignore[call-overload,union-attr] except AttributeError: pass - return asarray(a).reshape(shape, order=order) + return asarray(a).reshape(shape, order=order, out_sharding=out_sharding) @export -@partial(jit, static_argnames=('order',), inline=True) -def ravel(a: ArrayLike, order: str = "C") -> Array: +@partial(jit, static_argnames=('order', 'out_sharding'), inline=True) +def ravel(a: ArrayLike, order: str = "C", *, out_sharding=None) -> Array: """Flatten array into a 1-dimensional shape. JAX implementation of :func:`numpy.ravel`, implemented in terms of @@ -2085,10 +2067,10 @@ def ravel(a: ArrayLike, order: str = "C") -> Array: >>> x.ravel() Array([1, 2, 3, 4, 5, 6], dtype=int32) """ - util.check_arraylike("ravel", a) + a = util.ensure_arraylike("ravel", a) if order == "K": raise NotImplementedError("Ravel not implemented for order='K'.") - return reshape(a, (np.size(a),), order) + return reshape(a, (np.size(a),), order, out_sharding=out_sharding) @export @@ -2150,8 +2132,7 @@ def ravel_multi_index(multi_index: Sequence[ArrayLike], dims: Sequence[int], """ assert len(multi_index) == len(dims), f"len(multi_index)={len(multi_index)} != len(dims)={len(dims)}" dims = tuple(core.concrete_or_error(operator.index, d, "in `dims` argument of ravel_multi_index().") for d in dims) - util.check_arraylike("ravel_multi_index", *multi_index) - multi_index_arr = [asarray(i) for i in multi_index] + multi_index_arr = list(util.ensure_arraylike_tuple("ravel_multi_index", multi_index)) for index in multi_index_arr: if mode == 'raise': core.concrete_or_error(array, index, @@ -2259,7 +2240,7 @@ def resize(a: ArrayLike, new_shape: Shape) -> Array: Returns: A resized array with specified shape. The elements of ``a`` are repeated in - the resized array, if the resized array is larger than the original aray. + the resized array, if the resized array is larger than the original array. See also: - :func:`jax.numpy.reshape`: Returns a reshaped copy of an array. 
@@ -2435,7 +2416,7 @@ def expand_dims(a: ArrayLike, axis: int | Sequence[int]) -> Array: [2], [3]]]], dtype=int32) """ - util.check_arraylike("expand_dims", a) + a = util.ensure_arraylike("expand_dims", a) axis = _ensure_index_tuple(axis) return lax.expand_dims(a, axis) @@ -2482,7 +2463,7 @@ def swapaxes(a: ArrayLike, axis1: int, axis2: int) -> Array: >>> a.transpose(0, 3, 2, 1).shape (2, 5, 4, 3) """ - util.check_arraylike("swapaxes", a) + a = util.ensure_arraylike("swapaxes", a) perm = np.arange(np.ndim(a)) perm[axis1], perm[axis2] = perm[axis2], perm[axis1] return lax.transpose(a, list(perm)) @@ -2607,41 +2588,23 @@ def isclose(a: ArrayLike, b: ArrayLike, rtol: ArrayLike = 1e-05, atol: ArrayLike a, b = util.promote_args_inexact("isclose", a, b) dtype = _dtype(a) if issubdtype(dtype, np.complexfloating): - dtype = util._complex_elem_type(dtype) + dtype = np.array(0, dtype).real.dtype rtol = lax.convert_element_type(rtol, dtype) atol = lax.convert_element_type(atol, dtype) - out = lax.le( + both_nan = ufuncs.logical_and(ufuncs.isnan(a), ufuncs.isnan(b)) + check_fin = ufuncs.isfinite(b) + in_range = lax.le( lax.abs(lax.sub(a, b)), lax.add(atol, lax.mul(rtol, lax.abs(b)))) - # This corrects the comparisons for infinite and nan values - a_inf = ufuncs.isinf(a) - b_inf = ufuncs.isinf(b) - any_inf = ufuncs.logical_or(a_inf, b_inf) - both_inf = ufuncs.logical_and(a_inf, b_inf) - # Make all elements where either a or b are infinite to False - out = ufuncs.logical_and(out, ufuncs.logical_not(any_inf)) - # Make all elements where both a or b are the same inf to True - same_value = lax.eq(a, b) - same_inf = ufuncs.logical_and(both_inf, same_value) - out = ufuncs.logical_or(out, same_inf) - - # Make all elements where either a or b is NaN to False - a_nan = ufuncs.isnan(a) - b_nan = ufuncs.isnan(b) - any_nan = ufuncs.logical_or(a_nan, b_nan) - out = ufuncs.logical_and(out, ufuncs.logical_not(any_nan)) - if equal_nan: - # Make all elements where both a and b is NaN to True - both_nan = ufuncs.logical_and(a_nan, b_nan) - out = ufuncs.logical_or(out, both_nan) - return out + out = ufuncs.logical_or(lax.eq(a, b), ufuncs.logical_and(check_fin, in_range)) + return ufuncs.logical_or(out, both_nan) if equal_nan else out def _interp(x: ArrayLike, xp: ArrayLike, fp: ArrayLike, left: ArrayLike | str | None = None, right: ArrayLike | str | None = None, period: ArrayLike | None = None) -> Array: - util.check_arraylike("interp", x, xp, fp) + x, xp, fp = util.ensure_arraylike("interp", x, xp, fp) if np.shape(xp) != np.shape(fp) or np.ndim(xp) != 1: raise ValueError("xp and fp must be one-dimensional arrays of equal size") x_arr, xp_arr = util.promote_dtypes_inexact(x, xp) @@ -2825,7 +2788,7 @@ def where(condition, x=None, y=None, /, *, size=None, fill_value=None): (reverse-mode differentiation), a NaN in either ``x`` or ``y`` will propagate into the gradient, regardless of the value of ``condition``. More information on this behavior and workarounds is available in the `JAX FAQ - `_. + `_. 
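# Sketch (illustration only, arbitrary sample values): spot-checking the
# simplified isclose logic above on inf/nan edge cases against numpy.
import numpy as np
import jax.numpy as jnp

a = jnp.array([1.0, jnp.inf, jnp.inf, -jnp.inf, jnp.nan])
b = jnp.array([1.0, jnp.inf, -jnp.inf, -jnp.inf, jnp.nan])
print(jnp.isclose(a, b))                  # [ True  True False  True False]
print(jnp.isclose(a, b, equal_nan=True))  # [ True  True False  True  True]
print(np.isclose(np.asarray(a), np.asarray(b), equal_nan=True))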
Examples: When ``x`` and ``y`` are not provided, ``where`` behaves equivalently to @@ -2918,6 +2881,12 @@ def select( raise ValueError(msg.format(len(condlist), len(choicelist))) if len(condlist) == 0: raise ValueError("condlist must be non-empty") + + util.check_arraylike("select", *condlist, *choicelist, default) + condlist = [asarray(cond) for cond in condlist] + choicelist = [asarray(choice) for choice in choicelist] + default = asarray(default) + # Put the default at front with condition False because # argmax returns zero for an array of False values. choicelist = util.promote_dtypes(default, *choicelist) @@ -2934,7 +2903,7 @@ def bincount(x: ArrayLike, weights: ArrayLike | None = None, JAX implementation of :func:`numpy.bincount`. - For an array of positive integers ``x``, this function returns an array ``counts`` + For an array of non-negative integers ``x``, this function returns an array ``counts`` of size ``x.max() + 1``, such that ``counts[i]`` contains the number of occurrences of the value ``i`` in ``x``. @@ -2947,7 +2916,7 @@ def bincount(x: ArrayLike, weights: ArrayLike | None = None, like :func:`jax.jit`. In this case, items larger than `length + 1` will be dropped. Args: - x : N-dimensional array of positive integers + x : 1-dimensional array of non-negative integers weights: optional array of weights associated with ``x``. If not specified, the weight for each entry will be ``1``. minlength: the minimum length of the output counts array. @@ -2989,7 +2958,7 @@ def bincount(x: ArrayLike, weights: ArrayLike | None = None, >>> jnp.bincount(x, length=5) Array([2, 1, 0, 1, 0], dtype=int32) """ - util.check_arraylike("bincount", x) + x = util.ensure_arraylike("bincount", x) if _dtype(x) == bool: x = lax.convert_element_type(x, 'int32') if not issubdtype(_dtype(x), np.integer): @@ -3097,11 +3066,13 @@ def broadcast_arrays(*args: ArrayLike) -> list[Array]: .. _NumPy broadcasting: https://numpy.org/doc/stable/user/basics.broadcasting.html """ + args = util.ensure_arraylike_tuple("broadcast_arrays", args) return util._broadcast_arrays(*args) @export -def broadcast_to(array: ArrayLike, shape: DimSize | Shape) -> Array: +def broadcast_to(array: ArrayLike, shape: DimSize | Shape, + *, out_sharding: NamedSharding | P | None = None) -> Array: """Broadcast an array to a specified shape. JAX implementation of :func:`numpy.broadcast_to`. JAX uses NumPy-style @@ -3135,7 +3106,7 @@ def broadcast_to(array: ArrayLike, shape: DimSize | Shape) -> Array: .. _NumPy broadcasting: https://numpy.org/doc/stable/user/basics.broadcasting.html """ - return util._broadcast_to(array, shape) + return util._broadcast_to(array, shape, sharding=out_sharding) def _split(op: str, ary: ArrayLike, @@ -3410,6 +3381,7 @@ def clip( Returns: An array containing values from ``arr``, with values smaller than ``min`` set to ``min``, and values larger than ``max`` set to ``max``. + Wherever ``min`` is larger than ``max``, the value of ``max`` is returned. See also: - :func:`jax.numpy.minimum`: Compute the element-wise minimum value of two arrays. @@ -3435,7 +3407,7 @@ def clip( ) util.check_arraylike("clip", arr) - if any(jax.numpy.iscomplexobj(t) for t in (arr, min, max)): + if any(iscomplexobj(t) for t in (arr, min, max)): raise ValueError( "Clip received a complex value either through the input or the min/max " "keywords. Complex values have no ordering and cannot be clipped. 
" @@ -3444,7 +3416,7 @@ def clip( if min is not None: arr = ufuncs.maximum(min, arr) if max is not None: - arr = ufuncs.minimum(max, arr) + arr = ufuncs.minimum(max, arr) # type: ignore return asarray(arr) @@ -3559,7 +3531,7 @@ def fix(x: ArrayLike, out: None = None) -> Array: [-0., 0., -3.], [-1., 1., 2.]], dtype=float32) """ - util.check_arraylike("fix", x) + x = util.ensure_arraylike("fix", x) if out is not None: raise NotImplementedError("The 'out' argument to jnp.fix is not supported.") zero = _lax_const(x, 0) @@ -3777,7 +3749,7 @@ def nonzero(a: ArrayLike, *, size: int | None = None, return tuple(zeros(calculated_size, int) for dim in arr.shape) flat_indices = reductions.cumsum( bincount(reductions.cumsum(mask), length=calculated_size)) - strides: np.ndarray = (np.cumprod(arr.shape[::-1])[::-1] // arr.shape).astype(dtypes.int_) + strides: np.ndarray = (np.cumprod(arr.shape[::-1])[::-1] // arr.shape).astype(flat_indices.dtype) out = tuple((flat_indices // stride) % size for stride, size in zip(strides, arr.shape)) if fill_value is not None: fill_value_tup = fill_value if isinstance(fill_value, tuple) else arr.ndim * (fill_value,) @@ -4376,7 +4348,7 @@ def pad_func(row: Array, pad_width: tuple[int, int], Array([-10, -10, 2, 3, 4, 10, 10], dtype=int32) """ - util.check_arraylike("pad", array) + array = util.ensure_arraylike("pad", array) pad_width = _broadcast_to_pairs(pad_width, np.ndim(array), "pad_width") if pad_width and not all(core.is_dim(p[0]) and core.is_dim(p[1]) for p in pad_width): @@ -4471,7 +4443,7 @@ def stack(arrays: np.ndarray | Array | Sequence[ArrayLike], axis = _canonicalize_axis(axis, arrays.ndim) return concatenate(expand_dims(arrays, axis + 1), axis=axis, dtype=dtype) else: - util.check_arraylike("stack", *arrays) + arrays = util.ensure_arraylike_tuple("stack", arrays) shape0 = np.shape(arrays[0]) axis = _canonicalize_axis(axis, len(shape0) + 1) new_arrays = [] @@ -4560,7 +4532,7 @@ def tile(A: ArrayLike, reps: DimSize | Sequence[DimSize]) -> Array: [1, 2], [3, 4]], dtype=int32) """ - util.check_arraylike("tile", A) + A = util.ensure_arraylike("tile", A) try: iter(reps) # type: ignore[arg-type] except TypeError: @@ -4633,7 +4605,7 @@ def concatenate(arrays: np.ndarray | Array | Sequence[ArrayLike], """ if isinstance(arrays, (np.ndarray, Array)): return _concatenate_array(arrays, axis, dtype=dtype) - util.check_arraylike("concatenate", *arrays) + arrays = util.ensure_arraylike_tuple("concatenate", arrays) if not len(arrays): raise ValueError("Need at least one array to concatenate.") if axis is None: @@ -4693,7 +4665,7 @@ def concat(arrays: Sequence[ArrayLike], /, *, axis: int | None = 0) -> Array: [1., 1., 1., 0.]], dtype=float32) """ util.check_arraylike("concat", *arrays) - return jax.numpy.concatenate(arrays, axis=axis) + return concatenate(arrays, axis=axis) @export @@ -4749,7 +4721,7 @@ def vstack(tup: np.ndarray | Array | Sequence[ArrayLike], """ arrs: Array | list[Array] if isinstance(tup, (np.ndarray, Array)): - arrs = jax.vmap(atleast_2d)(tup) + arrs = api.vmap(atleast_2d)(tup) else: # TODO(jakevdp): Non-array input deprecated 2023-09-22; change to error. util.check_arraylike("vstack", *tup, emit_warning=True) @@ -4808,7 +4780,7 @@ def hstack(tup: np.ndarray | Array | Sequence[ArrayLike], """ arrs: Array | list[Array] if isinstance(tup, (np.ndarray, Array)): - arrs = jax.vmap(atleast_1d)(tup) + arrs = api.vmap(atleast_1d)(tup) arr0_ndim = arrs.ndim - 1 else: # TODO(jakevdp): Non-array input deprecated 2023-09-22; change to error. 
@@ -4871,10 +4843,11 @@ def dstack(tup: np.ndarray | Array | Sequence[ArrayLike], """ arrs: Array | list[Array] if isinstance(tup, (np.ndarray, Array)): - arrs = jax.vmap(atleast_3d)(tup) + arrs = api.vmap(atleast_3d)(tup) else: # TODO(jakevdp): Non-array input deprecated 2023-09-22; change to error. util.check_arraylike("dstack", *tup, emit_warning=True) + tup = util.ensure_arraylike_tuple("dstack", tup) arrs = [atleast_3d(m) for m in tup] return concatenate(arrs, axis=2, dtype=dtype) @@ -4932,7 +4905,7 @@ def column_stack(tup: np.ndarray | Array | Sequence[ArrayLike]) -> Array: """ arrs: Array | list[Array] | np.ndarray if isinstance(tup, (np.ndarray, Array)): - arrs = jax.vmap(lambda x: atleast_2d(x).T)(tup) if tup.ndim < 3 else tup + arrs = api.vmap(lambda x: atleast_2d(x).T)(tup) if tup.ndim < 3 else tup else: # TODO(jakevdp): Non-array input deprecated 2023-09-22; change to error. util.check_arraylike("column_stack", *tup, emit_warning=True) @@ -5022,7 +4995,7 @@ def choose(a, choices): """ if out is not None: raise NotImplementedError("The 'out' argument to jnp.choose is not supported.") - util.check_arraylike('choose', a, *choices) + a, *choices = util.ensure_arraylike_tuple('choose', (a, *choices)) if not issubdtype(_dtype(a), np.integer): raise ValueError("`a` array must be integer typed") N = len(choices) @@ -5336,252 +5309,13 @@ def atleast_3d(*arys: ArrayLike) -> Array | list[Array]: return [atleast_3d(arr) for arr in arys] -def _supports_buffer_protocol(obj): - try: - view = memoryview(obj) - except TypeError: - return False - else: - return True - - -def _make_string_array( - object: np.ndarray, - dtype: DTypeLike | None = None, - ndmin: int = 0, - device: xc.Device | Sharding | None = None, -) -> Array: - if not isinstance(object, np.ndarray): - raise TypeError( - "Currently, string arrays can only be made from NumPy" - f" arrays. Got: {type(object)}." - ) - if dtype is not None and ( - dtypes.is_string_dtype(object.dtype) != dtypes.is_string_dtype(dtype) - ): - raise TypeError( - f"Cannot make an array with dtype {dtype} from an object with dtype" - f" {object.dtype}." - ) - if ndmin > object.ndim: - raise TypeError( - f"ndmin {ndmin} cannot be greater than object's ndims" - f" {object.ndim} for string arrays." - ) - - # Just do a device_put since XLA does not support string as a data type. - return jax.device_put(x=object, device=device) - - -@export -def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True, - order: str | None = "K", ndmin: int = 0, - *, device: xc.Device | Sharding | None = None) -> Array: - """Convert an object to a JAX array. - - JAX implementation of :func:`numpy.array`. - - Args: - object: an object that is convertible to an array. This includes JAX - arrays, NumPy arrays, Python scalars, Python collections like lists - and tuples, objects with an ``__array__`` method, and objects - supporting the Python buffer protocol. - dtype: optionally specify the dtype of the output array. If not - specified it will be inferred from the input. - copy: specify whether to force a copy of the input. Default: True. - order: not implemented in JAX - ndmin: integer specifying the minimum number of dimensions in the - output array. - device: optional :class:`~jax.Device` or :class:`~jax.sharding.Sharding` - to which the created array will be committed. - - Returns: - A JAX array constructed from the input. - - See also: - - :func:`jax.numpy.asarray`: like `array`, but by default only copies - when necessary. 
- - :func:`jax.numpy.from_dlpack`: construct a JAX array from an object - that implements the dlpack interface. - - :func:`jax.numpy.frombuffer`: construct a JAX array from an object - that implements the buffer interface. - - Examples: - Constructing JAX arrays from Python scalars: - - >>> jnp.array(True) - Array(True, dtype=bool) - >>> jnp.array(42) - Array(42, dtype=int32, weak_type=True) - >>> jnp.array(3.5) - Array(3.5, dtype=float32, weak_type=True) - >>> jnp.array(1 + 1j) - Array(1.+1.j, dtype=complex64, weak_type=True) - - Constructing JAX arrays from Python collections: - - >>> jnp.array([1, 2, 3]) # list of ints -> 1D array - Array([1, 2, 3], dtype=int32) - >>> jnp.array([(1, 2, 3), (4, 5, 6)]) # list of tuples of ints -> 2D array - Array([[1, 2, 3], - [4, 5, 6]], dtype=int32) - >>> jnp.array(range(5)) - Array([0, 1, 2, 3, 4], dtype=int32) - - Constructing JAX arrays from NumPy arrays: - - >>> jnp.array(np.linspace(0, 2, 5)) - Array([0. , 0.5, 1. , 1.5, 2. ], dtype=float32) - - Constructing a JAX array via the Python buffer interface, using Python's - built-in :mod:`array` module. - - >>> from array import array - >>> pybuffer = array('i', [2, 3, 5, 7]) - >>> jnp.array(pybuffer) - Array([2, 3, 5, 7], dtype=int32) - """ - if order is not None and order != "K": - raise NotImplementedError("Only implemented for order='K'") - - # check if the given dtype is compatible with JAX - dtypes.check_user_dtype_supported(dtype, "array") - - # Here we make a judgment call: we only return a weakly-typed array when the - # input object itself is weakly typed. That ensures asarray(x) is a no-op - # whenever x is weak, but avoids introducing weak types with something like - # array([1, 2, 3]) - weak_type = dtype is None and dtypes.is_weakly_typed(object) - if device is None and isinstance(object, core.Tracer): - sharding = object.aval.sharding - sharding = None if sharding.mesh.empty else sharding - else: - sharding = canonicalize_device_to_sharding(device) - - # Use device_put to avoid a copy for ndarray inputs. - if (not copy and isinstance(object, np.ndarray) and - (dtype is None or dtype == object.dtype) and (ndmin <= object.ndim) and - device is None): - # Keep the output uncommitted. - return jax.device_put(object) - - # String arrays need separate handling because XLA does not support string - # as a data type. - if dtypes.is_string_dtype(dtype) or ( - hasattr(object, "dtype") and dtypes.is_string_dtype(object.dtype) - ): - return _make_string_array( - object=object, dtype=dtype, ndmin=ndmin, device=device - ) - - # For Python scalar literals, call coerce_to_array to catch any overflow - # errors. We don't use dtypes.is_python_scalar because we don't want this - # triggering for traced values. We do this here because it matters whether or - # not dtype is None. We don't assign the result because we want the raw object - # to be used for type inference below. - if isinstance(object, (bool, int, float, complex)): - _ = dtypes.coerce_to_array(object, dtype) - elif not isinstance(object, Array): - # Check if object supports any of the data exchange protocols - # (except dlpack, see data-apis/array-api#301). If it does, - # consume the object as jax array and continue (but not return) so - # that other array() arguments get processed against the input - # object. - # - # Notice that data exchange protocols define dtype in the - # corresponding data structures and it may not be available as - # object.dtype. So, we'll resolve the protocols here before - # evaluating object.dtype. 
- if hasattr(object, '__jax_array__'): - object = object.__jax_array__() - elif hasattr(object, '__cuda_array_interface__'): - cai = object.__cuda_array_interface__ - backend = xla_bridge.get_backend("cuda") - if cuda_plugin_extension is None: - device_id = None - else: - device_id = cuda_plugin_extension.get_device_ordinal(cai["data"][0]) - object = xc._xla.cuda_array_interface_to_buffer( - cai=cai, gpu_backend=backend, device_id=device_id) - - object = tree_map(lambda leaf: leaf.__jax_array__() - if hasattr(leaf, "__jax_array__") else leaf, object) - leaves = tree_leaves(object, is_leaf=lambda x: x is None) - if any(leaf is None for leaf in leaves): - # Added Nov 16 2023 - if deprecations.is_accelerated("jax-numpy-array-none"): - raise TypeError("None is not a valid value for jnp.array") - warnings.warn( - "None encountered in jnp.array(); this is currently treated as NaN. " - "In the future this will result in an error.", - FutureWarning, stacklevel=2) - leaves = tree_leaves(object) - if dtype is None: - # Use lattice_result_type rather than result_type to avoid canonicalization. - # Otherwise, weakly-typed inputs would have their dtypes canonicalized. - try: - dtype = dtypes._lattice_result_type(*leaves)[0] if leaves else dtypes.float_ - except TypeError: - # This happens if, e.g. one of the entries is a memoryview object. - # This is rare, so we only handle it if the normal path fails. - leaves = [_convert_to_array_if_dtype_fails(leaf) for leaf in leaves] - dtype = dtypes._lattice_result_type(*leaves)[0] - - if not weak_type: - dtype = dtypes.canonicalize_dtype(dtype, allow_extended_dtype=True) # type: ignore[assignment] - - out: ArrayLike - - if all(not isinstance(leaf, Array) for leaf in leaves): - # TODO(jakevdp): falling back to numpy here fails to overflow for lists - # containing large integers; see discussion in - # https://github.com/jax-ml/jax/pull/6047. More correct would be to call - # coerce_to_array on each leaf, but this may have performance implications. - out = np.asarray(object, dtype=dtype) - elif isinstance(object, Array): - assert object.aval is not None - out = _array_copy(object) if copy else object - elif isinstance(object, (list, tuple)): - if object: - out = stack([asarray(elt, dtype=dtype) for elt in object]) - else: - out = np.array([], dtype=dtype) - elif _supports_buffer_protocol(object): - object = memoryview(object) - # TODO(jakevdp): update this once we support NumPy 2.0 semantics for the copy arg. - out = np.array(object) if copy else np.asarray(object) - else: - raise TypeError(f"Unexpected input type for array: {type(object)}") - out_array: Array = lax_internal._convert_element_type( - out, dtype, weak_type=weak_type, sharding=sharding) - if ndmin > np.ndim(out_array): - out_array = lax.expand_dims(out_array, range(ndmin - np.ndim(out_array))) - return out_array - - -def canonicalize_device_to_sharding(device: xc.Device | Sharding | None - ) -> Sharding | None: - if isinstance(device, xc.Device): - return SingleDeviceSharding(device) - return device - - -def _convert_to_array_if_dtype_fails(x: ArrayLike) -> ArrayLike: - try: - dtypes.dtype(x) - except TypeError: - return np.asarray(x) - else: - return x - - @export def astype(x: ArrayLike, dtype: DTypeLike | None, /, *, copy: bool = False, device: xc.Device | Sharding | None = None) -> Array: """Convert an array to a specified dtype. - JAX imlementation of :func:`numpy.astype`. + JAX implementation of :func:`numpy.astype`. 
This is implemented via :func:`jax.lax.convert_element_type`, which may have slightly different behavior than :func:`numpy.astype` in some cases. @@ -5639,88 +5373,6 @@ def astype(x: ArrayLike, dtype: DTypeLike | None, return _array_copy(result) if copy else result -@export -def asarray(a: Any, dtype: DTypeLike | None = None, order: str | None = None, - *, copy: bool | None = None, - device: xc.Device | Sharding | None = None) -> Array: - """Convert an object to a JAX array. - - JAX implementation of :func:`numpy.asarray`. - - Args: - a: an object that is convertible to an array. This includes JAX - arrays, NumPy arrays, Python scalars, Python collections like lists - and tuples, objects with an ``__array__`` method, and objects - supporting the Python buffer protocol. - dtype: optionally specify the dtype of the output array. If not - specified it will be inferred from the input. - order: not implemented in JAX - copy: optional boolean specifying the copy mode. If True, then always - return a copy. If False, then error if a copy is necessary. Default is - None, which will only copy when necessary. - device: optional :class:`~jax.Device` or :class:`~jax.sharding.Sharding` - to which the created array will be committed. - - Returns: - A JAX array constructed from the input. - - See also: - - :func:`jax.numpy.array`: like `asarray`, but defaults to `copy=True`. - - :func:`jax.numpy.from_dlpack`: construct a JAX array from an object - that implements the dlpack interface. - - :func:`jax.numpy.frombuffer`: construct a JAX array from an object - that implements the buffer interface. - - Examples: - Constructing JAX arrays from Python scalars: - - >>> jnp.asarray(True) - Array(True, dtype=bool) - >>> jnp.asarray(42) - Array(42, dtype=int32, weak_type=True) - >>> jnp.asarray(3.5) - Array(3.5, dtype=float32, weak_type=True) - >>> jnp.asarray(1 + 1j) - Array(1.+1.j, dtype=complex64, weak_type=True) - - Constructing JAX arrays from Python collections: - - >>> jnp.asarray([1, 2, 3]) # list of ints -> 1D array - Array([1, 2, 3], dtype=int32) - >>> jnp.asarray([(1, 2, 3), (4, 5, 6)]) # list of tuples of ints -> 2D array - Array([[1, 2, 3], - [4, 5, 6]], dtype=int32) - >>> jnp.asarray(range(5)) - Array([0, 1, 2, 3, 4], dtype=int32) - - Constructing JAX arrays from NumPy arrays: - - >>> jnp.asarray(np.linspace(0, 2, 5)) - Array([0. , 0.5, 1. , 1.5, 2. ], dtype=float32) - - Constructing a JAX array via the Python buffer interface, using Python's - built-in :mod:`array` module. - - >>> from array import array - >>> pybuffer = array('i', [2, 3, 5, 7]) - >>> jnp.asarray(pybuffer) - Array([2, 3, 5, 7], dtype=int32) - """ - # For copy=False, the array API specifies that we raise a ValueError if the input supports - # the buffer protocol but a copy is required. Since array() supports the buffer protocol - # via numpy, this is only the case when the default device is not 'cpu' - if (copy is False and not isinstance(a, Array) - and jax.default_backend() != 'cpu' - and _supports_buffer_protocol(a)): - raise ValueError(f"jnp.asarray: cannot convert object of type {type(a)} to JAX Array " - f"on backend={jax.default_backend()!r} with copy=False. 
" - "Consider using copy=None or copy=True instead.") - dtypes.check_user_dtype_supported(dtype, "asarray") - if dtype is not None: - dtype = dtypes.canonicalize_dtype(dtype, allow_extended_dtype=True) # type: ignore[assignment] - return array(a, dtype=dtype, copy=bool(copy), order=order, device=device) - - @export def copy(a: ArrayLike, order: str | None = None) -> Array: """Return a copy of the array. @@ -5910,14 +5562,14 @@ def fromfile(*args, **kwargs): ``jnp.asarray(np.fromfile(...))`` instead, although care should be taken if ``np.fromfile`` is used within jax transformations because of its potential side-effect of consuming the file object; for more information see `Common Gotchas: Pure Functions - `_. + `_. """ raise NotImplementedError( "jnp.fromfile() is not implemented because it may be non-pure and thus unsafe for use " "with JIT and other JAX transformations. Consider using jnp.asarray(np.fromfile(...)) " "instead, although care should be taken if np.fromfile is used within a jax transformations " "because of its potential side-effect of consuming the file object; for more information see " - "https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#pure-functions") + "https://docs.jax.dev/en/latest/notebooks/Common_Gotchas_in_JAX.html#pure-functions") @export @@ -5929,14 +5581,14 @@ def fromiter(*args, **kwargs): ``jnp.asarray(np.fromiter(...))`` instead, although care should be taken if ``np.fromiter`` is used within jax transformations because of its potential side-effect of consuming the iterable object; for more information see `Common Gotchas: Pure Functions - `_. + `_. """ raise NotImplementedError( "jnp.fromiter() is not implemented because it may be non-pure and thus unsafe for use " "with JIT and other JAX transformations. Consider using jnp.asarray(np.fromiter(...)) " "instead, although care should be taken if np.fromiter is used within a jax transformations " "because of its potential side-effect of consuming the iterable object; for more information see " - "https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#pure-functions") + "https://docs.jax.dev/en/latest/notebooks/Common_Gotchas_in_JAX.html#pure-functions") @export @@ -5963,7 +5615,7 @@ def from_dlpack(x: Any, /, *, device: xc.Device | Sharding | None = None, if needed for a device transfer. Returns: - A JAX array of the imput buffer. + A JAX array of the input buffer. 
Note: While JAX arrays are always immutable, dlpack buffers cannot be marked as @@ -6083,7 +5735,7 @@ def fromfunction(function: Callable[..., Array], shape: Any, shape = core.canonicalize_shape(shape, context="shape argument of jnp.fromfunction()") for i in range(len(shape)): in_axes = [0 if i == j else None for j in range(len(shape))] - function = jax.vmap(function, in_axes=tuple(in_axes[::-1])) + function = api.vmap(function, in_axes=tuple(in_axes[::-1])) return function(*(arange(s, dtype=dtype) for s in shape), **kwargs) @@ -6172,7 +5824,7 @@ def eye(N: DimSize, M: DimSize | None = None, # instead of putting it on default device and then on the specific device output = _eye(N, M=M, k=k, dtype=dtype) if device is not None: - return jax.device_put(output, device=device) + return api.device_put(output, device=device) return output @@ -6305,7 +5957,7 @@ def arange(start: ArrayLike | DimSize, stop: ArrayLike | DimSize | None = None, # instead of putting it on default device and then on the specific device output = _arange(start, stop=stop, step=step, dtype=dtype) if device is not None: - return jax.device_put(output, device=device) + return api.device_put(output, device=device) return output @@ -6373,316 +6025,6 @@ def _arange_dynamic( return (array(start, dtype=dtype) + array(step, dtype=dtype) * lax.iota(dtype, size)) -@overload -def linspace(start: ArrayLike, stop: ArrayLike, num: int = 50, - endpoint: bool = True, retstep: Literal[False] = False, - dtype: DTypeLike | None = None, - axis: int = 0, - *, device: xc.Device | Sharding | None = None) -> Array: ... -@overload -def linspace(start: ArrayLike, stop: ArrayLike, num: int, - endpoint: bool, retstep: Literal[True], - dtype: DTypeLike | None = None, - axis: int = 0, - *, device: xc.Device | Sharding | None = None) -> tuple[Array, Array]: ... -@overload -def linspace(start: ArrayLike, stop: ArrayLike, num: int = 50, - endpoint: bool = True, *, retstep: Literal[True], - dtype: DTypeLike | None = None, - axis: int = 0, - device: xc.Device | Sharding | None = None) -> tuple[Array, Array]: ... -@overload -def linspace(start: ArrayLike, stop: ArrayLike, num: int = 50, - endpoint: bool = True, retstep: bool = False, - dtype: DTypeLike | None = None, - axis: int = 0, - *, device: xc.Device | Sharding | None = None) -> Array | tuple[Array, Array]: ... -@export -def linspace(start: ArrayLike, stop: ArrayLike, num: int = 50, - endpoint: bool = True, retstep: bool = False, - dtype: DTypeLike | None = None, - axis: int = 0, - *, device: xc.Device | Sharding | None = None) -> Array | tuple[Array, Array]: - """Return evenly-spaced numbers within an interval. - - JAX implementation of :func:`numpy.linspace`. - - Args: - start: scalar or array of starting values. - stop: scalar or array of stop values. - num: number of values to generate. Default: 50. - endpoint: if True (default) then include the ``stop`` value in the result. - If False, then exclude the ``stop`` value. - retstep: If True, then return a ``(result, step)`` tuple, where ``step`` is the - interval between adjacent values in ``result``. - axis: integer axis along which to generate the linspace. Defaults to zero. - device: optional :class:`~jax.Device` or :class:`~jax.sharding.Sharding` - to which the created array will be committed. - - Returns: - An array ``values``, or a tuple ``(values, step)`` if ``retstep`` is True, where: - - - ``values`` is an array of evenly-spaced values from ``start`` to ``stop`` - - ``step`` is the interval between adjacent values. 
- - See also: - - :func:`jax.numpy.arange`: Generate ``N`` evenly-spaced values given a starting - point and a step - - :func:`jax.numpy.logspace`: Generate logarithmically-spaced values. - - :func:`jax.numpy.geomspace`: Generate geometrically-spaced values. - - Examples: - List of 5 values between 0 and 10: - - >>> jnp.linspace(0, 10, 5) - Array([ 0. , 2.5, 5. , 7.5, 10. ], dtype=float32) - - List of 8 values between 0 and 10, excluding the endpoint: - - >>> jnp.linspace(0, 10, 8, endpoint=False) - Array([0. , 1.25, 2.5 , 3.75, 5. , 6.25, 7.5 , 8.75], dtype=float32) - - List of values and the step size between them - - >>> vals, step = jnp.linspace(0, 10, 9, retstep=True) - >>> vals - Array([ 0. , 1.25, 2.5 , 3.75, 5. , 6.25, 7.5 , 8.75, 10. ], dtype=float32) - >>> step - Array(1.25, dtype=float32) - - Multi-dimensional linspace: - - >>> start = jnp.array([0, 5]) - >>> stop = jnp.array([5, 10]) - >>> jnp.linspace(start, stop, 5) - Array([[ 0. , 5. ], - [ 1.25, 6.25], - [ 2.5 , 7.5 ], - [ 3.75, 8.75], - [ 5. , 10. ]], dtype=float32) - """ - num = core.concrete_dim_or_error(num, "'num' argument of jnp.linspace") - axis = core.concrete_or_error(operator.index, axis, "'axis' argument of jnp.linspace") - return _linspace(start, stop, num, endpoint, retstep, dtype, axis, device=device) - -@partial(jit, static_argnames=('num', 'endpoint', 'retstep', 'dtype', 'axis', 'device')) -def _linspace(start: ArrayLike, stop: ArrayLike, num: int = 50, - endpoint: bool = True, retstep: bool = False, - dtype: DTypeLike | None = None, - axis: int = 0, - *, device: xc.Device | Sharding | None = None) -> Array | tuple[Array, Array]: - """Implementation of linspace differentiable in start and stop args.""" - dtypes.check_user_dtype_supported(dtype, "linspace") - if num < 0: - raise ValueError(f"Number of samples, {num}, must be non-negative.") - start, stop = util.ensure_arraylike("linspace", start, stop) - - if dtype is None: - dtype = dtypes.to_inexact_dtype(result_type(start, stop)) - dtype = dtypes.jax_dtype(dtype) - computation_dtype = dtypes.to_inexact_dtype(dtype) - start = start.astype(computation_dtype) - stop = stop.astype(computation_dtype) - - bounds_shape = list(lax.broadcast_shapes(np.shape(start), np.shape(stop))) - broadcast_start = broadcast_to(start, bounds_shape) - broadcast_stop = broadcast_to(stop, bounds_shape) - axis = len(bounds_shape) + axis + 1 if axis < 0 else axis - bounds_shape.insert(axis, 1) - div = (num - 1) if endpoint else num - if num > 1: - delta: Array = lax.convert_element_type(stop - start, computation_dtype) / array(div, dtype=computation_dtype) - iota_shape = [1,] * len(bounds_shape) - iota_shape[axis] = div - # This approach recovers the endpoints with float32 arithmetic, - # but can lead to rounding errors for integer outputs. 
- real_dtype = finfo(computation_dtype).dtype - step = reshape(lax.iota(real_dtype, div), iota_shape) / array(div, real_dtype) - step = step.astype(computation_dtype) - out = (reshape(broadcast_start, bounds_shape) * (1 - step) + - reshape(broadcast_stop, bounds_shape) * step) - - if endpoint: - out = lax.concatenate([out, lax.expand_dims(broadcast_stop, (axis,))], - _canonicalize_axis(axis, out.ndim)) - - elif num == 1: - delta = asarray(np.nan if endpoint else stop - start, dtype=computation_dtype) - out = reshape(broadcast_start, bounds_shape) - else: # num == 0 degenerate case, match numpy behavior - empty_shape = list(lax.broadcast_shapes(np.shape(start), np.shape(stop))) - empty_shape.insert(axis, 0) - delta = asarray(np.nan, dtype=computation_dtype) - out = reshape(array([], dtype=dtype), empty_shape) - - if issubdtype(dtype, np.integer) and not issubdtype(out.dtype, np.integer): - out = lax.floor(out) - - sharding = canonicalize_device_to_sharding(device) - result = lax_internal._convert_element_type(out, dtype, sharding=sharding) - return (result, delta) if retstep else result - - -@export -def logspace(start: ArrayLike, stop: ArrayLike, num: int = 50, - endpoint: bool = True, base: ArrayLike = 10.0, - dtype: DTypeLike | None = None, axis: int = 0) -> Array: - """Generate logarithmically-spaced values. - - JAX implementation of :func:`numpy.logspace`. - - Args: - start: scalar or array. Used to specify the start value. The start value is - ``base ** start``. - stop: scalar or array. Used to specify the stop value. The end value is - ``base ** stop``. - num: int, optional, default=50. Number of values to generate. - endpoint: bool, optional, default=True. If True, then include the ``stop`` value - in the result. If False, then exclude the ``stop`` value. - base: scalar or array, optional, default=10. Specifies the base of the logarithm. - dtype: optional. Specifies the dtype of the output. - axis: int, optional, default=0. Axis along which to generate the logspace. - - Returns: - An array of logarithm. - - See also: - - :func:`jax.numpy.arange`: Generate ``N`` evenly-spaced values given a starting - point and a step value. - - :func:`jax.numpy.linspace`: Generate evenly-spaced values. - - :func:`jax.numpy.geomspace`: Generate geometrically-spaced values. - - Examples: - List 5 logarithmically spaced values between 1 (``10 ** 0``) and 100 - (``10 ** 2``): - - >>> with jnp.printoptions(precision=3, suppress=True): - ... jnp.logspace(0, 2, 5) - Array([ 1. , 3.162, 10. , 31.623, 100. ], dtype=float32) - - List 5 logarithmically-spaced values between 1(``10 ** 0``) and 100 - (``10 ** 2``), excluding endpoint: - - >>> with jnp.printoptions(precision=3, suppress=True): - ... jnp.logspace(0, 2, 5, endpoint=False) - Array([ 1. , 2.512, 6.31 , 15.849, 39.811], dtype=float32) - - List 7 logarithmically-spaced values between 1 (``2 ** 0``) and 4 (``2 ** 2``) - with base 2: - - >>> with jnp.printoptions(precision=3, suppress=True): - ... jnp.logspace(0, 2, 7, base=2) - Array([1. , 1.26 , 1.587, 2. , 2.52 , 3.175, 4. ], dtype=float32) - - Multi-dimensional logspace: - - >>> start = jnp.array([0, 5]) - >>> stop = jnp.array([5, 0]) - >>> base = jnp.array([2, 3]) - >>> with jnp.printoptions(precision=3, suppress=True): - ... jnp.logspace(start, stop, 5, base=base) - Array([[ 1. , 243. ], - [ 2.378, 61.547], - [ 5.657, 15.588], - [ 13.454, 3.948], - [ 32. , 1. 
]], dtype=float32) - """ - num = core.concrete_or_error(operator.index, num, "'num' argument of jnp.logspace") - axis = core.concrete_or_error(operator.index, axis, "'axis' argument of jnp.logspace") - return _logspace(start, stop, num, endpoint, base, dtype, axis) - -@partial(jit, static_argnames=('num', 'endpoint', 'dtype', 'axis')) -def _logspace(start: ArrayLike, stop: ArrayLike, num: int = 50, - endpoint: bool = True, base: ArrayLike = 10.0, - dtype: DTypeLike | None = None, axis: int = 0) -> Array: - """Implementation of logspace differentiable in start and stop args.""" - dtypes.check_user_dtype_supported(dtype, "logspace") - if dtype is None: - dtype = dtypes.to_inexact_dtype(result_type(start, stop)) - dtype = dtypes.jax_dtype(dtype) - computation_dtype = dtypes.to_inexact_dtype(dtype) - start, stop = util.ensure_arraylike("logspace", start, stop) - start = start.astype(computation_dtype) - stop = stop.astype(computation_dtype) - lin = linspace(start, stop, num, - endpoint=endpoint, retstep=False, dtype=None, axis=axis) - return lax.convert_element_type(ufuncs.power(base, lin), dtype) - - -@export -def geomspace(start: ArrayLike, stop: ArrayLike, num: int = 50, endpoint: bool = True, - dtype: DTypeLike | None = None, axis: int = 0) -> Array: - """Generate geometrically-spaced values. - - JAX implementation of :func:`numpy.geomspace`. - - Args: - start: scalar or array. Specifies the starting values. - stop: scalar or array. Specifies the stop values. - num: int, optional, default=50. Number of values to generate. - endpoint: bool, optional, default=True. If True, then include the ``stop`` value - in the result. If False, then exclude the ``stop`` value. - dtype: optional. Specifies the dtype of the output. - axis: int, optional, default=0. Axis along which to generate the geomspace. - - Returns: - An array containing the geometrically-spaced values. - - See also: - - :func:`jax.numpy.arange`: Generate ``N`` evenly-spaced values given a starting - point and a step value. - - :func:`jax.numpy.linspace`: Generate evenly-spaced values. - - :func:`jax.numpy.logspace`: Generate logarithmically-spaced values. - - Examples: - List 5 geometrically-spaced values between 1 and 16: - - >>> with jnp.printoptions(precision=3, suppress=True): - ... jnp.geomspace(1, 16, 5) - Array([ 1., 2., 4., 8., 16.], dtype=float32) - - List 4 geomtrically-spaced values between 1 and 16, with ``endpoint=False``: - - >>> with jnp.printoptions(precision=3, suppress=True): - ... jnp.geomspace(1, 16, 4, endpoint=False) - Array([1., 2., 4., 8.], dtype=float32) - - Multi-dimensional geomspace: - - >>> start = jnp.array([1, 1000]) - >>> stop = jnp.array([27, 1]) - >>> with jnp.printoptions(precision=3, suppress=True): - ... 
jnp.geomspace(start, stop, 4) - Array([[ 1., 1000.], - [ 3., 100.], - [ 9., 10.], - [ 27., 1.]], dtype=float32) - """ - num = core.concrete_or_error(operator.index, num, "'num' argument of jnp.geomspace") - axis = core.concrete_or_error(operator.index, axis, "'axis' argument of jnp.geomspace") - return _geomspace(start, stop, num, endpoint, dtype, axis) - -@partial(jit, static_argnames=('num', 'endpoint', 'dtype', 'axis')) -def _geomspace(start: ArrayLike, stop: ArrayLike, num: int = 50, endpoint: bool = True, - dtype: DTypeLike | None = None, axis: int = 0) -> Array: - """Implementation of geomspace differentiable in start and stop args.""" - dtypes.check_user_dtype_supported(dtype, "geomspace") - if dtype is None: - dtype = dtypes.to_inexact_dtype(result_type(start, stop)) - dtype = dtypes.jax_dtype(dtype) - computation_dtype = dtypes.to_inexact_dtype(dtype) - start, stop = util.ensure_arraylike("geomspace", start, stop) - start = start.astype(computation_dtype) - stop = stop.astype(computation_dtype) - - sign = ufuncs.sign(start) - res = sign * logspace(ufuncs.log10(start / sign), ufuncs.log10(stop / sign), - num, endpoint=endpoint, base=10.0, - dtype=computation_dtype, axis=0) - if axis != 0: - res = moveaxis(res, 0, axis) - return lax.convert_element_type(res, dtype) - @export def meshgrid(*xi: ArrayLike, copy: bool = True, sparse: bool = False, @@ -6812,7 +6154,7 @@ def _i0(x): @_i0.defjvp def _i0_jvp(primals, tangents): - primal_out, tangent_out = jax.jvp(_i0.fun, primals, tangents) + primal_out, tangent_out = api.jvp(_i0.fun, primals, tangents) return primal_out, where(primals[0] == 0, 0.0, tangent_out) @export @@ -6931,7 +6273,8 @@ def indices(dimensions: Sequence[int], dtype: DTypeLike | None = None, @export def repeat(a: ArrayLike, repeats: ArrayLike, axis: int | None = None, *, - total_repeat_length: int | None = None) -> Array: + total_repeat_length: int | None = None, + out_sharding: NamedSharding | P | None = None) -> Array: """Construct an array from repeated elements. JAX implementation of :func:`numpy.repeat`. 
@@ -6995,8 +6338,44 @@ def repeat(a: ArrayLike, repeats: ArrayLike, axis: int | None = None, *, Array([[1, 1, 2, 2, 2, 2, 2], [3, 3, 4, 4, 4, 4, 4]], dtype=int32) """ - arr = util.ensure_arraylike("repeat", a) - core.is_dim(repeats) or util.check_arraylike("repeat", repeats) + if out_sharding is not None: + return _auto_repeat(_repeat, a, repeats, axis, total_repeat_length, + out_sharding) + ctx_mesh = get_abstract_mesh() + if ctx_mesh._are_all_axes_explicit: + aval = core.typeof(a) + if axis is None or aval.sharding.spec[axis] is not None: + raise ValueError( + "Please pass sharding to `jnp.repeat` via `out_sharding` parameter.") + assert axis is not None and aval.sharding.spec[axis] is None + out_sharding = (NamedSharding(ctx_mesh, P()) + if aval.sharding.mesh.empty else aval.sharding) + return _auto_repeat(_repeat, a, repeats, axis, total_repeat_length, + out_sharding) + try: + return _repeat(a, repeats=repeats, axis=axis, + total_repeat_length=total_repeat_length) + except core.ShardingTypeError as e: + raise ValueError( + "Please pass sharding to `jnp.repeat` via `out_sharding` parameter.") + +def _auto_repeat(fun, a, repeats, axis, total_repeat_length, out_sharding): + if total_repeat_length is None: + return auto_axes(partial(fun, repeats=repeats, axis=axis, + total_repeat_length=total_repeat_length), + out_sharding=out_sharding)(a) + else: + return auto_axes( + partial(fun, axis=axis, total_repeat_length=total_repeat_length), + out_sharding=out_sharding)(a, repeats=repeats) + +def _repeat(a: ArrayLike, *, repeats: ArrayLike, axis: int | None = None, + total_repeat_length: int | None = None) -> Array: + if core.is_dim(repeats): + util.check_arraylike("repeat", a) + else: + util.check_arraylike("repeat", a, repeats) + arr = asarray(a) if axis is None: arr = arr.ravel() @@ -7118,11 +6497,11 @@ def trapezoid(y: ArrayLike, x: ArrayLike | None = None, dx: ArrayLike = 1.0, # TODO(phawkins): remove this annotation after fixing jnp types. 
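# Sketch (illustration only, arbitrary values): jnp.repeat under jit with traced
# repeats, made shape-static by the total_repeat_length argument that the new
# _repeat helper above threads through.
import jax
import jax.numpy as jnp

f = jax.jit(lambda a, r: jnp.repeat(a, r, total_repeat_length=6))
print(f(jnp.array([3, 1, 2]), jnp.array([1, 2, 3])))  # [3 1 1 2 2 2]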
dx_array: Array if x is None: - util.check_arraylike('trapezoid', y) + y = util.ensure_arraylike('trapezoid', y) y_arr, = util.promote_dtypes_inexact(y) dx_array = asarray(dx) else: - util.check_arraylike('trapezoid', y, x) + y, x = util.ensure_arraylike('trapezoid', y, x) y_arr, x_arr = util.promote_dtypes_inexact(y, x) if x_arr.ndim == 1: dx_array = diff(x_arr) @@ -7243,7 +6622,7 @@ def tril(m: ArrayLike, k: int = 0) -> Array: [[5, 0], [7, 8]]], dtype=int32) """ - util.check_arraylike("tril", m) + m = util.ensure_arraylike("tril", m) m_shape = np.shape(m) if len(m_shape) < 2: raise ValueError("Argument to jax.numpy.tril must be at least 2D") @@ -7310,7 +6689,7 @@ def triu(m: ArrayLike, k: int = 0) -> Array: [[5, 6], [0, 8]]], dtype=int32) """ - util.check_arraylike("triu", m) + m = util.ensure_arraylike("triu", m) m_shape = np.shape(m) if len(m_shape) < 2: raise ValueError("Argument to jax.numpy.triu must be at least 2D") @@ -7367,7 +6746,7 @@ def trace(a: ArrayLike, offset: int | ArrayLike = 0, axis1: int = 0, axis2: int >>> jnp.trace(x, offset=1, axis1=1, axis2=2) Array([2, 6], dtype=int32) """ - util.check_arraylike("trace", a) + a = util.ensure_arraylike("trace", a) if out is not None: raise NotImplementedError("The 'out' argument to jnp.trace is not supported.") @@ -7564,7 +6943,7 @@ def tril_indices(n: int, k: int = 0, m: int | None = None) -> tuple[Array, Array @export -def triu_indices_from(arr: ArrayLike, k: int = 0) -> tuple[Array, Array]: +def triu_indices_from(arr: ArrayLike | SupportsShape, k: int = 0) -> tuple[Array, Array]: """Return the indices of upper triangle of a given array. JAX implementation of :func:`numpy.triu_indices_from`. @@ -7615,14 +6994,18 @@ def triu_indices_from(arr: ArrayLike, k: int = 0) -> tuple[Array, Array]: >>> jnp.triu_indices_from(arr, k=-1) (Array([0, 0, 0, 1, 1, 1, 2, 2], dtype=int32), Array([0, 1, 2, 0, 1, 2, 1, 2], dtype=int32)) """ - arr_shape = np.shape(arr) + if hasattr(arr, "shape"): + arr_shape = arr.shape + else: + arr = util.ensure_arraylike("triu_indices_from", arr) + arr_shape = arr.shape if len(arr_shape) != 2: raise ValueError("Only 2-D inputs are accepted") return triu_indices(arr_shape[0], k=k, m=arr_shape[1]) @export -def tril_indices_from(arr: ArrayLike, k: int = 0) -> tuple[Array, Array]: +def tril_indices_from(arr: ArrayLike | SupportsShape, k: int = 0) -> tuple[Array, Array]: """Return the indices of lower triangle of a given array. JAX implementation of :func:`numpy.tril_indices_from`. 
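The widened `ArrayLike | SupportsShape` annotation, together with the `hasattr(arr, "shape")` branch in the hunks that follow, suggests that any shape-carrying object is now accepted, not only concrete arrays. A sketch under that assumption; treating `jax.ShapeDtypeStruct` as the motivating use case is a guess on my part, not something stated in the patch:

import jax
import jax.numpy as jnp

x = jnp.arange(9).reshape(3, 3)
jnp.triu_indices_from(x)            # works as before with a concrete array

spec = jax.ShapeDtypeStruct((3, 3), jnp.float32)
jnp.triu_indices_from(spec)         # only `.shape` is consulted; no array data is needed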
@@ -7673,7 +7056,11 @@ def tril_indices_from(arr: ArrayLike, k: int = 0) -> tuple[Array, Array]: >>> jnp.tril_indices_from(arr, k=-1) (Array([1, 2, 2], dtype=int32), Array([0, 0, 1], dtype=int32)) """ - arr_shape = np.shape(arr) + if hasattr(arr, "shape"): + arr_shape = arr.shape + else: + arr = util.ensure_arraylike("tril_indices_from", arr) + arr_shape = arr.shape if len(arr_shape) != 2: raise ValueError("Only 2-D inputs are accepted") return tril_indices(arr_shape[0], k=k, m=arr_shape[1]) @@ -7827,7 +7214,7 @@ def diag_indices_from(arr: ArrayLike) -> tuple[Array, ...]: Array([0, 1], dtype=int32), Array([0, 1], dtype=int32)) """ - util.check_arraylike("diag_indices_from", arr) + arr = util.ensure_arraylike("diag_indices_from", arr) nd = np.ndim(arr) if not np.ndim(arr) >= 2: raise ValueError("input array must be at least 2-d") @@ -7876,7 +7263,7 @@ def diagonal(a: ArrayLike, offset: int = 0, axis1: int = 0, >>> jnp.diagonal(x, offset=-1) Array([4, 8], dtype=int32) """ - util.check_arraylike("diagonal", a) + a = util.ensure_arraylike("diagonal", a) if np.ndim(a) < 2: raise ValueError("diagonal requires an array of at least two dimensions.") @@ -7962,11 +7349,11 @@ def diag(v: ArrayLike, k: int = 0) -> Array: >>> jnp.diag(x) Array([1, 5, 9], dtype=int32) """ + v = util.ensure_arraylike("diag", v) return _diag(v, operator.index(k)) @partial(jit, static_argnames=('k',)) -def _diag(v, k): - util.check_arraylike("diag", v) +def _diag(v: Array, k: int): v_shape = np.shape(v) if len(v_shape) == 1: zero = lambda x: lax.full_like(x, shape=(), fill_value=0) @@ -8063,7 +7450,7 @@ def trim_zeros(filt: ArrayLike, trim: str ='fb') -> Array: util.check_arraylike("trim_zeros", filt, emit_warning=True) core.concrete_or_error(None, filt, "Error arose in the `filt` argument of trim_zeros()") - filt_arr = jax.numpy.asarray(filt) + filt_arr = asarray(filt) del filt if filt_arr.ndim != 1: # Added on 2024-09-11 @@ -8243,6 +7630,9 @@ def delete( # Case 3: obj is an array # NB: pass both arrays to check for appropriate error message. util.check_arraylike("delete", a, obj) + # Can't use ensure_arraylike here because obj may be static. + if hasattr(obj, "__jax_array__"): + obj = obj.__jax_array__() # Case 3a: unique integer indices; delete in a JIT-compatible way if issubdtype(_dtype(obj), np.integer) and assume_unique_indices: @@ -8441,9 +7831,9 @@ def apply_along_axis( axis = _canonicalize_axis(axis, num_dims) func = lambda arr: func1d(arr, *args, **kwargs) for i in range(1, num_dims - axis): - func = jax.vmap(func, in_axes=i, out_axes=-1) + func = api.vmap(func, in_axes=i, out_axes=-1) for i in range(axis): - func = jax.vmap(func, in_axes=0, out_axes=0) + func = api.vmap(func, in_axes=0, out_axes=0) return func(arr) @@ -8687,7 +8077,7 @@ def vander( [3, 1], [4, 1]], dtype=int32) - Generates the Vandermonde matrix in increaing order of powers, when + Generates the Vandermonde matrix in increasing order of powers, when ``increasing=True``. 
>>> jnp.vander(x, increasing=True) @@ -8773,6 +8163,7 @@ def argwhere( >>> jnp.argwhere(0) Array([], shape=(0, 0), dtype=int32) """ + a = util.ensure_arraylike("argwhere", a) result = transpose(vstack(nonzero(atleast_1d(a), size=size, fill_value=fill_value))) if np.ndim(a) == 0: return result[:0].reshape(result.shape[0], 0) @@ -8945,12 +8336,12 @@ def nanargmax( """ if out is not None: raise NotImplementedError("The 'out' argument to jnp.nanargmax is not supported.") + a = util.ensure_arraylike("nanargmax", a) return _nanargmax(a, None if axis is None else operator.index(axis), keepdims=bool(keepdims)) @partial(jit, static_argnames=('axis', 'keepdims')) -def _nanargmax(a, axis: int | None = None, keepdims: bool = False): - util.check_arraylike("nanargmax", a) +def _nanargmax(a: Array, axis: int | None = None, keepdims: bool = False): if not issubdtype(_dtype(a), np.inexact): return argmax(a, axis=axis, keepdims=keepdims) nan_mask = ufuncs.isnan(a) @@ -9006,12 +8397,12 @@ def nanargmin( """ if out is not None: raise NotImplementedError("The 'out' argument to jnp.nanargmin is not supported.") + a = util.ensure_arraylike("nanargmin", a) return _nanargmin(a, None if axis is None else operator.index(axis), keepdims=bool(keepdims)) @partial(jit, static_argnames=('axis', 'keepdims')) -def _nanargmin(a, axis: int | None = None, keepdims : bool = False): - util.check_arraylike("nanargmin", a) +def _nanargmin(a: Array, axis: int | None = None, keepdims : bool = False): if not issubdtype(_dtype(a), np.inexact): return argmin(a, axis=axis, keepdims=keepdims) nan_mask = ufuncs.isnan(a) @@ -9154,7 +8545,7 @@ def rollaxis(a: ArrayLike, axis: int, start: int = 0) -> Array: >>> jnp.moveaxis(a, 1, -1).shape (2, 4, 5, 3) """ - util.check_arraylike("rollaxis", a) + a = util.ensure_arraylike("rollaxis", a) start = core.concrete_or_error(operator.index, start, "'start' argument of jnp.rollaxis()") a_ndim = np.ndim(a) axis = _canonicalize_axis(axis, a_ndim) @@ -9231,7 +8622,7 @@ def packbits(a: ArrayLike, axis: int | None = None, bitorder: str = "big") -> Ar raise TypeError('Expected an input array of integer or boolean data type') if bitorder not in ['little', 'big']: raise ValueError("'order' must be either 'little' or 'big'") - arr = lax.gt(arr, _lax_const(a, 0)).astype('uint8') + arr = lax.ne(arr, _lax_const(arr, 0)).astype('uint8') bits = arange(8, dtype='uint8') if bitorder == 'big': bits = bits[::-1] @@ -9391,7 +8782,7 @@ def gcd(x1: ArrayLike, x2: ArrayLike) -> Array: >>> jnp.gcd(x1, x2) Array([ 6, 3, 12], dtype=int32) """ - util.check_arraylike("gcd", x1, x2) + x1, x2 = util.ensure_arraylike("gcd", x1, x2) x1, x2 = util.promote_dtypes(x1, x2) if not issubdtype(_dtype(x1), np.integer): raise ValueError("Arguments to jax.numpy.gcd must be integers.") @@ -9438,7 +8829,7 @@ def lcm(x1: ArrayLike, x2: ArrayLike) -> Array: >>> jnp.lcm(x1, x2) Array([12, 36, 12], dtype=int32) """ - util.check_arraylike("lcm", x1, x2) + x1, x2 = util.ensure_arraylike("lcm", x1, x2) x1, x2 = util.promote_dtypes(x1, x2) x1, x2 = ufuncs.abs(x1), ufuncs.abs(x2) if not issubdtype(_dtype(x1), np.integer): @@ -9890,7 +9281,7 @@ def _rank(x): def _searchsorted_via_compare_all(sorted_arr: Array, query: Array, side: str, dtype: type) -> Array: op = _sort_lt_comparator if side == 'left' else _sort_le_comparator - comparisons = jax.vmap(op, in_axes=(0, None))(sorted_arr, query) + comparisons = api.vmap(op, in_axes=(0, None))(sorted_arr, query) return comparisons.sum(dtype=dtype, axis=0) @@ -9957,9 +9348,9 @@ def searchsorted(a: ArrayLike, 
v: ArrayLike, side: str = 'left', Array([0, 2, 5, 1, 1], dtype=int32) """ if sorter is None: - util.check_arraylike("searchsorted", a, v) + a, v = util.ensure_arraylike("searchsorted", a, v) else: - util.check_arraylike("searchsorted", a, v, sorter) + a, v, sorter = util.ensure_arraylike("searchsorted", a, v, sorter) if side not in ['left', 'right']: raise ValueError(f"{side!r} is an invalid value for keyword 'side'. " "Expected one of ['left', 'right'].") diff --git a/jax/_src/numpy/linalg.py b/jax/_src/numpy/linalg.py index 23f2a58b09f6..2351b0ccb075 100644 --- a/jax/_src/numpy/linalg.py +++ b/jax/_src/numpy/linalg.py @@ -23,10 +23,11 @@ import operator from typing import Literal, NamedTuple, overload -import jax -from jax import jit, custom_jvp from jax import lax +from jax._src.api import jit +from jax._src import config +from jax._src.custom_derivatives import custom_jvp from jax._src import deprecations from jax._src.lax import lax as lax_internal from jax._src.lax.lax import PrecisionLike @@ -44,24 +45,24 @@ class EighResult(NamedTuple): - eigenvalues: jax.Array - eigenvectors: jax.Array + eigenvalues: Array + eigenvectors: Array class QRResult(NamedTuple): - Q: jax.Array - R: jax.Array + Q: Array + R: Array class SlogdetResult(NamedTuple): - sign: jax.Array - logabsdet: jax.Array + sign: Array + logabsdet: Array class SVDResult(NamedTuple): - U: jax.Array - S: jax.Array - Vh: jax.Array + U: Array + S: Array + Vh: Array def _H(x: ArrayLike) -> Array: @@ -72,8 +73,8 @@ def _symmetrize(x: Array) -> Array: return (x + _H(x)) / 2 @export -@partial(jit, static_argnames=['upper']) -def cholesky(a: ArrayLike, *, upper: bool = False) -> Array: +@partial(jit, static_argnames=['upper', 'symmetrize_input']) +def cholesky(a: ArrayLike, *, upper: bool = False, symmetrize_input: bool = True) -> Array: """Compute the Cholesky decomposition of a matrix. JAX implementation of :func:`numpy.linalg.cholesky`. @@ -98,6 +99,10 @@ def cholesky(a: ArrayLike, *, upper: bool = False) -> Array: Must have shape ``(..., N, N)``. upper: if True, compute the upper Cholesky decomposition `U`. if False (default), compute the lower Cholesky decomposition `L`. + symmetrize_input: if True (default) then input is symmetrized, which leads + to better behavior under automatic differentiation. Note that when this + is set to True, both the upper and lower triangles of the input will + be used in computing the decomposition. Returns: array of shape ``(..., N, N)`` representing the Cholesky decomposition @@ -135,7 +140,7 @@ def cholesky(a: ArrayLike, *, upper: bool = False) -> Array: """ a = ensure_arraylike("jnp.linalg.cholesky", a) a, = promote_dtypes_inexact(a) - L = lax_linalg.cholesky(a) + L = lax_linalg.cholesky(a, symmetrize_input=symmetrize_input) return L.mT.conj() if upper else L @@ -363,8 +368,7 @@ def matrix_power(a: ArrayLike, n: int) -> Array: Array([[ 5.5 , -2.5 ], [-3.75, 1.75]], dtype=float32) """ - a = ensure_arraylike("jnp.linalg.matrix_power", a) - arr, = promote_dtypes_inexact(a) + arr = ensure_arraylike("jnp.linalg.matrix_power", a) if arr.ndim < 2: raise TypeError("{}-dimensional array given. Array must be at least " @@ -821,7 +825,9 @@ def eigh(a: ArrayLike, UPLO: str | None = None, UPLO: specifies whether the calculation is done with the lower triangular part of ``a`` (``'L'``, default) or the upper triangular part (``'U'``). symmetrize_input: if True (default) then input is symmetrized, which leads - to better behavior under automatic differentiation. 
+ to better behavior under automatic differentiation. Note that when this + is set to True, both the upper and lower triangles of the input will + be used in computing the decomposition. Returns: A namedtuple ``(eigenvalues, eigenvectors)`` where @@ -863,8 +869,9 @@ def eigh(a: ArrayLike, UPLO: str | None = None, @export -@partial(jit, static_argnames=('UPLO',)) -def eigvalsh(a: ArrayLike, UPLO: str | None = 'L') -> Array: +@partial(jit, static_argnames=('UPLO', 'symmetrize_input')) +def eigvalsh(a: ArrayLike, UPLO: str | None = 'L', *, + symmetrize_input: bool = True) -> Array: """ Compute the eigenvalues of a Hermitian matrix. @@ -875,6 +882,10 @@ def eigvalsh(a: ArrayLike, UPLO: str | None = 'L') -> Array: or symmetric (if real) matrix. UPLO: specifies whether the calculation is done with the lower triangular part of ``a`` (``'L'``, default) or the upper triangular part (``'U'``). + symmetrize_input: if True (default) then input is symmetrized, which leads + to better behavior under automatic differentiation. Note that when this + is set to True, both the upper and lower triangles of the input will + be used in computing the decomposition. Returns: An array of shape ``(..., M)`` containing the eigenvalues, sorted in @@ -894,7 +905,7 @@ def eigvalsh(a: ArrayLike, UPLO: str | None = 'L') -> Array: """ a = ensure_arraylike("jnp.linalg.eigvalsh", a) a, = promote_dtypes_inexact(a) - w, _ = eigh(a, UPLO) + w, _ = eigh(a, UPLO, symmetrize_input=symmetrize_input) return w @@ -985,7 +996,7 @@ def _pinv(a: ArrayLike, rtol: ArrayLike | None = None, hermitian: bool = False) @_pinv.defjvp -@jax.default_matmul_precision("float32") +@config.default_matmul_precision("float32") def _pinv_jvp(rtol, hermitian, primals, tangents): # The Differentiation of Pseudo-Inverses and Nonlinear Least Squares Problems # Whose Variables Separate. Author(s): G. H. Golub and V. Pereyra. SIAM @@ -1606,8 +1617,8 @@ def matrix_transpose(x: ArrayLike, /) -> Array: x_arr = ensure_arraylike('jnp.linalg.matrix_transpose', x) ndim = x_arr.ndim if ndim < 2: - raise ValueError(f"matrix_transpose requres at least 2 dimensions; got {ndim=}") - return jax.lax.transpose(x_arr, (*range(ndim - 2), ndim - 1, ndim - 2)) + raise ValueError(f"matrix_transpose requires at least 2 dimensions; got {ndim=}") + return lax.transpose(x_arr, (*range(ndim - 2), ndim - 1, ndim - 2)) @export diff --git a/jax/_src/numpy/polynomial.py b/jax/_src/numpy/polynomial.py index 81d320cb7403..2f7a32c3f52d 100644 --- a/jax/_src/numpy/polynomial.py +++ b/jax/_src/numpy/polynomial.py @@ -19,8 +19,8 @@ import numpy as np -from jax import jit from jax import lax +from jax._src.api import jit from jax._src import dtypes from jax._src import core from jax._src.lax import lax as lax_internal @@ -146,7 +146,7 @@ def polyfit(x: ArrayLike, y: ArrayLike, deg: int, rcond: float | None = None, rcond: Relative condition number of the fit. Default value is ``len(x) * eps``. It must be specified statically. full: Switch that controls the return value. Default is ``False`` which - restricts the return value to the array of polynomail coefficients ``p``. + restricts the return value to the array of polynomial coefficients ``p``. If ``True``, the function returns a tuple ``(p, resids, rank, s, rcond)``. It must be specified statically. w: Array of weights of shape ``(M,)``. 
If None, all data points are considered @@ -154,8 +154,8 @@ def polyfit(x: ArrayLike, y: ArrayLike, deg: int, rcond: float | None = None, unsquared residual of :math:`y_i - \widehat{y}_i` at :math:`x_i`, where :math:`\widehat{y}_i` is the fitted value of :math:`y_i`. Default is None. cov: Boolean or string. If ``True``, returns the covariance matrix scaled - by ``resids/(M-deg-1)`` along with ploynomial coefficients. If - ``cov='unscaled'``, returns the unscaaled version of covariance matrix. + by ``resids/(M-deg-1)`` along with polynomial coefficients. If + ``cov='unscaled'``, returns the unscaled version of covariance matrix. Default is ``False``. ``cov`` is ignored if ``full=True``. It must be specified statically. @@ -224,7 +224,7 @@ def polyfit(x: ArrayLike, y: ArrayLike, deg: int, rcond: float | None = None, >>> p, C = jnp.polyfit(x, y, 2, cov=True) >>> p.shape, C.shape - ((3, 3), (3, 3, 1)) + ((3, 3), (3, 3, 3)) """ if w is None: x_arr, y_arr = ensure_arraylike("polyfit", x, y) @@ -233,7 +233,6 @@ def polyfit(x: ArrayLike, y: ArrayLike, deg: int, rcond: float | None = None, del x, y deg = core.concrete_or_error(int, deg, "deg must be int") order = deg + 1 - # check arguments if deg < 0: raise ValueError("expected deg >= 0") if x_arr.ndim != 1: @@ -245,7 +244,6 @@ def polyfit(x: ArrayLike, y: ArrayLike, deg: int, rcond: float | None = None, if x_arr.shape[0] != y_arr.shape[0]: raise TypeError("expected x and y to have same length") - # set rcond if rcond is None: rcond = len(x_arr) * float(finfo(x_arr.dtype).eps) rcond = core.concrete_or_error(float, rcond, "rcond must be float") @@ -268,9 +266,17 @@ def polyfit(x: ArrayLike, y: ArrayLike, deg: int, rcond: float | None = None, # scale lhs to improve condition number and solve scale = sqrt((lhs*lhs).sum(axis=0)) - lhs /= scale[np.newaxis,:] + lhs /= scale[np.newaxis, :] c, resids, rank, s = linalg.lstsq(lhs, rhs, rcond) - c = (c.T/scale).T # broadcast scale coefficients + + # Broadcasting scale coefficients + if c.ndim > 1: + # For multi-dimensional output, make scale (1, order) to divide + # across the c.T of shape (num_rhs, order) + c = (c.T / scale[np.newaxis, :]).T + else: + # Simple case for 1D output + c = c / scale if full: assert rcond is not None @@ -278,22 +284,25 @@ def polyfit(x: ArrayLike, y: ArrayLike, deg: int, rcond: float | None = None, elif cov: Vbase = linalg.inv(dot(lhs.T, lhs)) Vbase /= outer(scale, scale) + if cov == "unscaled": - fac = 1 + fac = array(1.0) else: if len(x_arr) <= order: - raise ValueError("the number of data points must exceed order " - "to scale the covariance matrix") + raise ValueError("the number of data points must exceed order" + " to scale the covariance matrix") fac = resids / (len(x_arr) - order) - fac = fac[0] #making np.array() of shape (1,) to int + if y_arr.ndim == 1: + fac = atleast_1d(fac)[np.newaxis] + # For 1D output, simple scalar multiplication return c, Vbase * fac else: - return c, Vbase[:, :, np.newaxis] * fac + # For multiple rhs, broadcast fac to match shape + return c, Vbase[:, :, np.newaxis] * atleast_1d(fac)[np.newaxis, np.newaxis, :] else: return c - @export @jit def poly(seq_of_zeros: ArrayLike) -> Array: diff --git a/jax/_src/numpy/reductions.py b/jax/_src/numpy/reductions.py index 985b296bc06f..e1f499ccc530 100644 --- a/jax/_src/numpy/reductions.py +++ b/jax/_src/numpy/reductions.py @@ -23,16 +23,17 @@ import numpy as np -import jax from jax import lax from jax._src import api +from jax._src import config from jax._src import core from jax._src import deprecations 
from jax._src import dtypes from jax._src.numpy.util import ( - _broadcast_to, check_arraylike, _complex_elem_type, + _broadcast_to, ensure_arraylike, promote_dtypes_inexact, promote_dtypes_numeric, _where) from jax._src.lax import lax as lax_internal +from jax._src.lax import other as lax_other from jax._src.typing import Array, ArrayLike, DType, DTypeLike, DeprecatedArg from jax._src.util import ( canonicalize_axis as _canonicalize_axis, maybe_named_axis, @@ -54,8 +55,7 @@ def _isscalar(element: Any) -> bool: def _moveaxis(a: ArrayLike, source: int, destination: int) -> Array: # simplified version of jnp.moveaxis() for local use. - check_arraylike("moveaxis", a) - a = lax_internal.asarray(a) + a = ensure_arraylike("moveaxis", a) source = _canonicalize_axis(source, np.ndim(a)) destination = _canonicalize_axis(destination, np.ndim(a)) perm = [i for i in range(np.ndim(a)) if i != source] @@ -83,8 +83,7 @@ def _promote_integer_dtype(dtype: DTypeLike) -> DTypeLike: def check_where(name: str, where: ArrayLike | None) -> Array | None: if where is None: return where - check_arraylike(name, where) - where_arr = lax_internal.asarray(where) + where_arr = ensure_arraylike(name, where) if where_arr.dtype != bool: # Deprecation added 2024-12-05 deprecations.warn( @@ -113,7 +112,7 @@ def _reduction(a: ArrayLike, name: str, op: ReductionOp, init_val: ArrayLike, # exists, passing along all its arguments. if out is not None: raise NotImplementedError(f"The 'out' argument to jnp.{name} is not supported.") - check_arraylike(name, a) + a = ensure_arraylike(name, a) where_ = check_where(name, where_) dtypes.check_user_dtype_supported(dtype, name) axis = core.concrete_or_error(None, axis, f"axis argument to jnp.{name}().") @@ -122,7 +121,6 @@ def _reduction(a: ArrayLike, name: str, op: ReductionOp, init_val: ArrayLike, raise ValueError(f"reduction operation {name} does not have an identity, so to use a " f"where mask one has to specify 'initial'") - a = a if isinstance(a, Array) else lax_internal.asarray(a) a = preproc(a) if preproc else a pos_dims, dims = _reduction_dims(a, axis) @@ -401,11 +399,11 @@ def prod(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, @partial(api.jit, static_argnames=('axis', 'keepdims'), inline=True) -def _reduce_max(a: ArrayLike, axis: Axis = None, out: None = None, - keepdims: bool = False, initial: ArrayLike | None = None, - where: ArrayLike | None = None) -> Array: +def _reduce_max(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, + out: None = None, keepdims: bool = False, + initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array: return _reduction(a, "max", lax.max, -np.inf, has_identity=False, - axis=axis, out=out, keepdims=keepdims, + axis=axis, dtype=dtype, out=out, keepdims=keepdims, initial=initial, where_=where, parallel_reduce=lax.pmax) @@ -483,12 +481,12 @@ def max(a: ArrayLike, axis: Axis = None, out: None = None, return _reduce_max(a, axis=_ensure_optional_axes(axis), out=out, keepdims=keepdims, initial=initial, where=where) -@partial(api.jit, static_argnames=('axis', 'keepdims'), inline=True) -def _reduce_min(a: ArrayLike, axis: Axis = None, out: None = None, - keepdims: bool = False, initial: ArrayLike | None = None, - where: ArrayLike | None = None) -> Array: +@partial(api.jit, static_argnames=('axis', 'keepdims', 'dtype'), inline=True) +def _reduce_min(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, + out: None = None, keepdims: bool = False, + initial: ArrayLike | None = None, where: ArrayLike | 
None = None) -> Array: return _reduction(a, "min", lax.min, np.inf, has_identity=False, - axis=axis, out=out, keepdims=keepdims, + axis=axis, dtype=dtype, out=out, keepdims=keepdims, initial=initial, where_=where, parallel_reduce=lax.pmin) @@ -685,7 +683,7 @@ def _reduce_bitwise_and(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array: arr = lax_internal.asarray(a) - init_val = np.array(-1, dtype=dtype or arr.dtype) + init_val = np.array(-1).astype(dtype or arr.dtype) return _reduction(arr, name="reduce_bitwise_and", op=lax.bitwise_and, init_val=init_val, preproc=_require_integer, axis=_ensure_optional_axes(axis), dtype=dtype, out=out, keepdims=keepdims, initial=initial, where_=where) @@ -743,7 +741,7 @@ def _logsumexp(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, if out is not None: raise NotImplementedError("The 'out' argument to jnp.logaddexp.reduce is not supported.") dtypes.check_user_dtype_supported(dtype, "jnp.logaddexp.reduce") - check_arraylike("logsumexp", a) + a = ensure_arraylike("logsumexp", a) where = check_where("logsumexp", where) a_arr, = promote_dtypes_inexact(a) pos_dims, dims = _reduction_dims(a_arr, axis) @@ -753,7 +751,7 @@ def _logsumexp(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, exp_a = lax.exp(lax.sub(a_arr, amax_with_dims.astype(a_arr.dtype))) sumexp = exp_a.sum(axis=dims, keepdims=keepdims, where=where) result = lax.add(lax.log(sumexp), amax.astype(sumexp.dtype)) - return result if initial is None else lax.logaddexp(initial, result) + return result if initial is None else lax_other.logaddexp(initial, result) def _logsumexp2(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, @@ -763,7 +761,7 @@ def _logsumexp2(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, if out is not None: raise NotImplementedError("The 'out' argument to jnp.logaddexp2.reduce is not supported.") dtypes.check_user_dtype_supported(dtype, "jnp.logaddexp2.reduce") - check_arraylike("logsumexp2", a) + a = ensure_arraylike("logsumexp2", a) where = check_where("logsumexp2", where) ln2 = float(np.log(2)) if initial is not None: @@ -771,7 +769,6 @@ def _logsumexp2(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, return _logsumexp(a * ln2, axis=axis, dtype=dtype, keepdims=keepdims, where=where, initial=initial) / ln2 - @export def amin(a: ArrayLike, axis: Axis = None, out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, @@ -796,7 +793,7 @@ def _axis_size(a: ArrayLike, axis: int | Sequence[int]): size = 1 a_shape = np.shape(a) for a in axis_seq: - size *= maybe_named_axis(a, lambda i: a_shape[i], lambda name: lax.psum(1, name)) + size *= maybe_named_axis(a, lambda i: a_shape[i], lax.axis_size) return size @@ -873,7 +870,7 @@ def _mean(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, keepdims: bool = False, *, upcast_f16_for_computation: bool = True, where: ArrayLike | None = None) -> Array: - check_arraylike("mean", a) + a = ensure_arraylike("mean", a) where = check_where("mean", where) if out is not None: raise NotImplementedError("The 'out' argument to jnp.mean is not supported.") @@ -972,7 +969,7 @@ def average(a: ArrayLike, axis: Axis = None, weights: ArrayLike | None = None, def _average(a: ArrayLike, axis: Axis = None, weights: ArrayLike | None = None, returned: bool = False, keepdims: bool = False) -> Array | tuple[Array, Array]: if 
weights is None: # Treat all weights as 1 - check_arraylike("average", a) + a = ensure_arraylike("average", a) a, = promote_dtypes_inexact(a) avg = mean(a, axis=axis, keepdims=keepdims) if axis is None: @@ -982,7 +979,7 @@ def _average(a: ArrayLike, axis: Axis = None, weights: ArrayLike | None = None, else: weights_sum = lax.full_like(avg, core.dimension_as_value(a.shape[axis])) # type: ignore[index] else: - check_arraylike("average", a, weights) + a, weights = ensure_arraylike("average", a, weights) a, weights = promote_dtypes_inexact(a, weights) a_shape = np.shape(a) @@ -991,7 +988,7 @@ def _average(a: ArrayLike, axis: Axis = None, weights: ArrayLike | None = None, if axis is None: pass - elif isinstance(axis, tuple): + elif isinstance(axis, Sequence): axis = tuple(_canonicalize_axis(d, a_ndim) for d in axis) else: axis = _canonicalize_axis(axis, a_ndim) @@ -1104,14 +1101,14 @@ def var(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, correction = ddof elif not isinstance(ddof, int) or ddof != 0: raise ValueError("ddof and correction can't be provided simultaneously.") + a = ensure_arraylike("var", a) return _var(a, _ensure_optional_axes(axis), dtype, out, correction, keepdims, where=where) @partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims')) -def _var(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, +def _var(a: Array, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, correction: int | float = 0, keepdims: bool = False, *, where: ArrayLike | None = None) -> Array: - check_arraylike("var", a) where = check_where("var", where) dtypes.check_user_dtype_supported(dtype, "var") if out is not None: @@ -1139,7 +1136,7 @@ def _var(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, normalizer = lax.sub(normalizer, lax.convert_element_type(correction, computation_dtype)) result = sum(centered, axis, dtype=computation_dtype, keepdims=keepdims, where=where) result = lax.div(result, normalizer).astype(dtype) - with jax.debug_nans(False): + with config.debug_nans(False): result = _where(normalizer > 0, result, np.nan) return result @@ -1160,7 +1157,7 @@ def _var_promote_types(a_dtype: DTypeLike, dtype: DTypeLike | None) -> tuple[DTy dtype = dtypes.to_inexact_dtype(a_dtype) computation_dtype = dtype else: - dtype = _complex_elem_type(a_dtype) + dtype = np.array(0, a_dtype).real.dtype computation_dtype = a_dtype return _upcast_f16(computation_dtype), np.dtype(dtype) @@ -1242,14 +1239,14 @@ def std(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, correction = ddof elif not isinstance(ddof, int) or ddof != 0: raise ValueError("ddof and correction can't be provided simultaneously.") + a = ensure_arraylike("std", a) return _std(a, _ensure_optional_axes(axis), dtype, out, correction, keepdims, where=where) @partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims')) -def _std(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, +def _std(a: Array, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, correction: int | float = 0, keepdims: bool = False, *, where: ArrayLike | None = None) -> Array: - check_arraylike("std", a) where = check_where("std", where) dtypes.check_user_dtype_supported(dtype, "std") if dtype is not None and not dtypes.issubdtype(dtype, np.inexact): @@ -1298,12 +1295,12 @@ def ptp(a: ArrayLike, axis: Axis = None, out: None = None, [7], [6]], dtype=int32) """ + a = ensure_arraylike("ptp", a) return _ptp(a, _ensure_optional_axes(axis), out, keepdims) @partial(api.jit, 
static_argnames=('axis', 'keepdims')) -def _ptp(a: ArrayLike, axis: Axis = None, out: None = None, +def _ptp(a: Array, axis: Axis = None, out: None = None, keepdims: bool = False) -> Array: - check_arraylike("ptp", a) if out is not None: raise NotImplementedError("The 'out' argument to jnp.ptp is not supported.") x = amax(a, axis=axis, keepdims=keepdims) @@ -1350,7 +1347,7 @@ def count_nonzero(a: ArrayLike, axis: Axis = None, [1], [3]], dtype=int32) """ - check_arraylike("count_nonzero", a) + a = ensure_arraylike("count_nonzero", a) return sum(lax.ne(a, _lax_const(a, 0)), axis=axis, dtype=dtypes.canonicalize_dtype(int), keepdims=keepdims) @@ -1359,7 +1356,7 @@ def _nan_reduction(a: ArrayLike, name: str, jnp_reduction: Callable[..., Array], init_val: ArrayLike, nan_if_all_nan: bool, axis: Axis = None, keepdims: bool = False, where: ArrayLike | None = None, **kwargs) -> Array: - check_arraylike(name, a) + a = ensure_arraylike(name, a) where = check_where(name, where) if not dtypes.issubdtype(dtypes.dtype(a), np.inexact): return jnp_reduction(a, axis=axis, keepdims=keepdims, where=where, **kwargs) @@ -1783,7 +1780,7 @@ def nanmean(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out >>> jnp.nanmean(x, axis=0, keepdims=True, where=where) Array([[nan, nan, nan, nan]], dtype=float32) """ - check_arraylike("nanmean", a) + a = ensure_arraylike("nanmean", a) where = check_where("nanmean", where) if out is not None: raise NotImplementedError("The 'out' argument to jnp.nanmean is not supported.") @@ -1877,7 +1874,7 @@ def nanvar(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: [0. ], [4. ]], dtype=float32) """ - check_arraylike("nanvar", a) + a = ensure_arraylike("nanvar", a) where = check_where("nanvar", where) dtypes.check_user_dtype_supported(dtype, "nanvar") if out is not None: @@ -1973,7 +1970,7 @@ def nanstd(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: >>> jnp.nanstd(x, axis=0, keepdims=True, where=where) Array([[0.5, 0.5, 0. , 0. ]], dtype=float32) """ - check_arraylike("nanstd", a) + a = ensure_arraylike("nanstd", a) where = check_where("nanstd", where) dtypes.check_user_dtype_supported(dtype, "nanstd") if out is not None: @@ -1992,7 +1989,7 @@ def _cumulative_reduction( fill_nan: bool = False, fill_value: ArrayLike = 0, promote_integers: bool = False) -> Array: """Helper function for implementing cumulative reductions.""" - check_arraylike(name, a) + a = ensure_arraylike(name, a) if out is not None: raise NotImplementedError(f"The 'out' argument to jnp.{name} is not supported") dtypes.check_user_dtype_supported(dtype, name) @@ -2242,8 +2239,7 @@ def cumulative_sum( Array([[ 0, 1, 3, 6], [ 0, 4, 9, 15]], dtype=int32) """ - check_arraylike("cumulative_sum", x) - x = lax_internal.asarray(x) + x = ensure_arraylike("cumulative_sum", x) if x.ndim == 0: raise ValueError( "The input must be non-scalar to take a cumulative sum, however a " @@ -2304,8 +2300,7 @@ def cumulative_prod( Array([[ 1, 1, 2, 6], [ 1, 4, 20, 120]], dtype=int32) """ - check_arraylike("cumulative_prod", x) - x = lax_internal.asarray(x) + x = ensure_arraylike("cumulative_prod", x) if x.ndim == 0: raise ValueError( "The input must be non-scalar to take a cumulative product, however a " @@ -2377,7 +2372,7 @@ def quantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] 
| None = No >>> jnp.quantile(x, q, method='nearest') Array([2., 4., 7.], dtype=float32) """ - check_arraylike("quantile", a, q) + a, q = ensure_arraylike("quantile", a, q) if overwrite_input or out is not None: raise ValueError("jax.numpy.quantile does not support overwrite_input=True " "or out != None") @@ -2435,7 +2430,7 @@ def nanquantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = >>> jnp.nanquantile(x, q) Array([1.5, 3. , 4.5], dtype=float32) """ - check_arraylike("nanquantile", a, q) + a, q = ensure_arraylike("nanquantile", a, q) if overwrite_input or out is not None: msg = ("jax.numpy.nanquantile does not support overwrite_input=True or " "out != None") @@ -2518,7 +2513,7 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None, index[axis] = high high_value = a[tuple(index)] else: - with jax.debug_nans(False): + with config.debug_nans(False): a = _where(any(lax_internal._isnan(a), axis=axis, keepdims=True), np.nan, a) a = lax.sort(a, dimension=axis) n = lax.convert_element_type(a_shape[axis], lax_internal._dtype(q)) @@ -2618,7 +2613,7 @@ def percentile(a: ArrayLike, q: ArrayLike, >>> jnp.percentile(x, q, method='nearest') Array([1., 3., 4.], dtype=float32) """ - check_arraylike("percentile", a, q) + a, q = ensure_arraylike("percentile", a, q) q, = promote_dtypes_inexact(q) if not isinstance(interpolation, DeprecatedArg): deprecations.warn( @@ -2678,7 +2673,7 @@ def nanpercentile(a: ArrayLike, q: ArrayLike, >>> jnp.nanpercentile(x, q) Array([1.5, 3. , 4.5], dtype=float32) """ - check_arraylike("nanpercentile", a, q) + a, q = ensure_arraylike("nanpercentile", a, q) q, = promote_dtypes_inexact(q) q = q / 100 if not isinstance(interpolation, DeprecatedArg): @@ -2738,7 +2733,7 @@ def median(a: ArrayLike, axis: int | tuple[int, ...] | None = None, [4. ], [4.5]], dtype=float32) """ - check_arraylike("median", a) + a = ensure_arraylike("median", a) return quantile(a, 0.5, axis=axis, out=out, overwrite_input=overwrite_input, keepdims=keepdims, method='midpoint') @@ -2795,7 +2790,7 @@ def nanmedian(a: ArrayLike, axis: int | tuple[int, ...] | None = None, [5. ], [3. ]], dtype=float32) """ - check_arraylike("nanmedian", a) + a = ensure_arraylike("nanmedian", a) return nanquantile(a, 0.5, axis=axis, out=out, overwrite_input=overwrite_input, keepdims=keepdims, method='midpoint') diff --git a/jax/_src/numpy/scalar_types.py b/jax/_src/numpy/scalar_types.py index 2f9954488b41..442df38a9641 100644 --- a/jax/_src/numpy/scalar_types.py +++ b/jax/_src/numpy/scalar_types.py @@ -22,11 +22,12 @@ from typing import Any -import jax +import numpy as np + from jax._src.typing import Array from jax._src import core from jax._src import dtypes -import numpy as np +from jax._src.numpy.array import asarray # Some objects below rewrite their __module__ attribute to this name. 
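The scalar_types.py hunks reroute `_ScalarMeta.__call__` through the module-local `asarray` and register the narrow float/int types unconditionally; the user-visible behavior of the wrappers should be unchanged. A brief sketch for orientation, with outputs inferred from `__call__` and `__instancecheck__` rather than verified against this branch:

import numpy as np
import jax.numpy as jnp

jnp.float32(3)                            # Array(3., dtype=float32): calling the wrapper builds a JAX array
isinstance(np.float32(1.0), jnp.float32)  # True: instance checks defer to the underlying NumPy scalar type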
@@ -46,7 +47,7 @@ def __ne__(self, other: Any) -> bool: return not (self == other) def __call__(self, x: Any) -> Array: - return jax.numpy.asarray(x, dtype=self.dtype) + return asarray(x, dtype=self.dtype) def __instancecheck__(self, instance: Any) -> bool: return isinstance(instance, self.dtype.type) @@ -68,33 +69,27 @@ def _make_scalar_type(np_scalar_type: type) -> _ScalarMeta: return meta bool_ = _make_scalar_type(np.bool_) -if dtypes.uint2 is not None: - uint2 = _make_scalar_type(dtypes.uint2) +uint2 = _make_scalar_type(dtypes.uint2) uint4 = _make_scalar_type(dtypes.uint4) uint8 = _make_scalar_type(np.uint8) uint16 = _make_scalar_type(np.uint16) uint32 = _make_scalar_type(np.uint32) uint64 = _make_scalar_type(np.uint64) -if dtypes.int2 is not None: - int2 = _make_scalar_type(dtypes.int2) +int2 = _make_scalar_type(dtypes.int2) int4 = _make_scalar_type(dtypes.int4) int8 = _make_scalar_type(np.int8) int16 = _make_scalar_type(np.int16) int32 = _make_scalar_type(np.int32) int64 = _make_scalar_type(np.int64) -if dtypes.float8_e3m4 is not None: - float8_e3m4 = _make_scalar_type(dtypes.float8_e3m4) -if dtypes.float8_e4m3 is not None: - float8_e4m3 = _make_scalar_type(dtypes.float8_e4m3) -if dtypes.float8_e8m0fnu is not None: - float8_e8m0fnu = _make_scalar_type(dtypes.float8_e8m0fnu) +float4_e2m1fn = _make_scalar_type(dtypes.float4_e2m1fn) +float8_e3m4 = _make_scalar_type(dtypes.float8_e3m4) +float8_e4m3 = _make_scalar_type(dtypes.float8_e4m3) +float8_e8m0fnu = _make_scalar_type(dtypes.float8_e8m0fnu) float8_e4m3fn = _make_scalar_type(dtypes.float8_e4m3fn) float8_e4m3fnuz = _make_scalar_type(dtypes.float8_e4m3fnuz) float8_e5m2 = _make_scalar_type(dtypes.float8_e5m2) float8_e5m2fnuz = _make_scalar_type(dtypes.float8_e5m2fnuz) float8_e4m3b11fnuz = _make_scalar_type(dtypes.float8_e4m3b11fnuz) -if dtypes.float4_e2m1fn is not None: - float4_e2m1fn = _make_scalar_type(dtypes.float4_e2m1fn) bfloat16 = _make_scalar_type(dtypes.bfloat16) float16 = _make_scalar_type(np.float16) float32 = single = _make_scalar_type(np.float32) diff --git a/jax/_src/numpy/setops.py b/jax/_src/numpy/setops.py index d4a8e41dd317..ef1d44ae01b1 100644 --- a/jax/_src/numpy/setops.py +++ b/jax/_src/numpy/setops.py @@ -21,10 +21,9 @@ import numpy as np -import jax -from jax import jit from jax import lax +from jax._src.api import jit from jax._src import core from jax._src import dtypes from jax._src.lax import lax as lax_internal @@ -59,8 +58,10 @@ def _in1d(ar1: ArrayLike, ar2: ArrayLike, invert: bool, else: return (arr1[:, None] == arr2[None, :]).any(-1) elif method == 'binary_search': + from jax._src.numpy.lax_numpy import searchsorted + arr2 = lax.sort(arr2) - ind = jax.numpy.searchsorted(arr2, arr1) + ind = searchsorted(arr2, arr1) if invert: return arr1 != arr2[ind] else: diff --git a/jax/_src/numpy/sorting.py b/jax/_src/numpy/sorting.py index a0f368e2ef07..d8d1f7751d67 100644 --- a/jax/_src/numpy/sorting.py +++ b/jax/_src/numpy/sorting.py @@ -13,18 +13,18 @@ # limitations under the License. 
from functools import partial -from typing import Sequence +from collections.abc import Sequence import numpy as np -import jax +from jax import lax + from jax._src import api from jax._src import core from jax._src import dtypes from jax._src.numpy import util from jax._src.util import canonicalize_axis, set_module from jax._src.typing import Array, ArrayLike -from jax import lax export = set_module('jax.numpy') @@ -226,7 +226,7 @@ def partition(a: ArrayLike, kth: int, axis: int = -1) -> Array: axis = canonicalize_axis(axis, arr.ndim) kth = canonicalize_axis(kth, arr.shape[axis]) - arr = jax.numpy.swapaxes(arr, axis, -1) + arr = arr.swapaxes(axis, -1) if dtypes.isdtype(arr.dtype, "unsigned integer"): # Here, we apply a trick to handle correctly 0 values for unsigned integers bottom = -lax.top_k(-(arr + 1), kth + 1)[0] - 1 @@ -234,7 +234,7 @@ def partition(a: ArrayLike, kth: int, axis: int = -1) -> Array: bottom = -lax.top_k(-arr, kth + 1)[0] top = lax.top_k(arr, arr.shape[-1] - kth - 1)[0] out = lax.concatenate([bottom, top], dimension=arr.ndim - 1) - return jax.numpy.swapaxes(out, -1, axis) + return out.swapaxes(-1, axis) @export @@ -297,7 +297,7 @@ def argpartition(a: ArrayLike, kth: int, axis: int = -1) -> Array: axis = canonicalize_axis(axis, arr.ndim) kth = canonicalize_axis(kth, arr.shape[axis]) - arr = jax.numpy.swapaxes(arr, axis, -1) + arr = arr.swapaxes(axis, -1) if dtypes.isdtype(arr.dtype, "unsigned integer"): # Here, we apply a trick to handle correctly 0 values for unsigned integers bottom_ind = lax.top_k(-(arr + 1), kth + 1)[1] @@ -307,11 +307,11 @@ def argpartition(a: ArrayLike, kth: int, axis: int = -1) -> Array: # To avoid issues with duplicate values, we compute the top indices via a proxy set_to_zero = lambda a, i: a.at[i].set(0) for _ in range(arr.ndim - 1): - set_to_zero = jax.vmap(set_to_zero) - proxy = set_to_zero(jax.numpy.ones(arr.shape), bottom_ind) + set_to_zero = api.vmap(set_to_zero) + proxy = set_to_zero(lax.full(arr.shape, 1.0), bottom_ind) top_ind = lax.top_k(proxy, arr.shape[-1] - kth - 1)[1] out = lax.concatenate([bottom_ind, top_ind], dimension=arr.ndim - 1) - return jax.numpy.swapaxes(out, -1, axis) + return out.swapaxes(-1, axis) @export @@ -421,7 +421,7 @@ def lexsort(keys: Array | np.ndarray | Sequence[ArrayLike], axis: int = -1) -> A if len({np.shape(key) for key in key_arrays}) > 1: raise ValueError("all keys need to be the same shape") if np.ndim(key_arrays[0]) == 0: - return jax.numpy.array(0, dtype=dtypes.canonicalize_dtype(dtypes.int_)) + return lax.full((), 0, dtypes.canonicalize_dtype(dtypes.int_)) axis = canonicalize_axis(axis, np.ndim(key_arrays[0])) use_64bit_index = key_arrays[0].shape[axis] >= (1 << 31) iota = lax.broadcasted_iota(np.dtype('int64') if use_64bit_index else dtypes.int_, diff --git a/jax/_src/numpy/tensor_contractions.py b/jax/_src/numpy/tensor_contractions.py index 850eb90cf1d2..255da08e1816 100644 --- a/jax/_src/numpy/tensor_contractions.py +++ b/jax/_src/numpy/tensor_contractions.py @@ -20,7 +20,6 @@ import numpy as np -import jax from jax import lax from jax._src import core from jax._src import dtypes @@ -36,10 +35,12 @@ export = set_module('jax.numpy') @export -@partial(jit, static_argnames=('precision', 'preferred_element_type'), inline=True) +@partial(jit, static_argnames=('precision', 'preferred_element_type', 'out_sharding'), + inline=True) def dot(a: ArrayLike, b: ArrayLike, *, precision: PrecisionLike = None, - preferred_element_type: DTypeLike | None = None) -> Array: + preferred_element_type: DTypeLike | None 
= None, + out_sharding=None) -> Array: """Compute the dot product of two arrays. JAX implementation of :func:`numpy.dot`. @@ -119,7 +120,8 @@ def dot(a: ArrayLike, b: ArrayLike, *, contract_dims = ((a_ndim - 1,), (b_ndim - 2,)) result = lax.dot_general(a, b, dimension_numbers=(contract_dims, batch_dims), precision=precision, - preferred_element_type=preferred_element_type) + preferred_element_type=preferred_element_type, + out_sharding=out_sharding) return lax_internal._convert_element_type(result, preferred_element_type, output_weak_type) @@ -284,7 +286,7 @@ def matvec(x1: ArrayLike, x2: ArrayLike, /) -> Array: Array([[ 50, 122], [ 38, 92]], dtype=int32) """ - util.check_arraylike("matvec", x1, x2) + x1, x2 = util.ensure_arraylike("matvec", x1, x2) return vectorize(matmul, signature="(n,m),(m)->(n)")(x1, x2) @@ -326,7 +328,7 @@ def vecmat(x1: ArrayLike, x2: ArrayLike, /) -> Array: Array([[ 40, 46], [ 94, 109]], dtype=int32) """ - util.check_arraylike("matvec", x1, x2) + x1, x2 = util.ensure_arraylike("matvec", x1, x2) return vectorize(matmul, signature="(n),(n,m)->(m)")(ufuncs.conj(x1), x2) @@ -372,10 +374,10 @@ def vdot( >>> jnp.dot(x, y) Array(0.+14.j, dtype=complex64) """ - util.check_arraylike("vdot", a, b) + a, b = util.ensure_arraylike("vdot", a, b) if dtypes.issubdtype(dtypes.dtype(a, canonicalize=True), np.complexfloating): a = ufuncs.conj(a) - return dot(jax.numpy.ravel(a), jax.numpy.ravel(b), precision=precision, + return dot(a.ravel(), b.ravel(), precision=precision, preferred_element_type=preferred_element_type) @@ -426,11 +428,13 @@ def vecdot(x1: ArrayLike, x2: ArrayLike, /, *, axis: int = -1, >>> jnp.linalg.vecdot(a, b, axis=-1) Array([20, 47], dtype=int32) """ + from jax._src.numpy.lax_numpy import moveaxis + x1_arr, x2_arr = util.ensure_arraylike("jnp.vecdot", x1, x2) if x1_arr.shape[axis] != x2_arr.shape[axis]: raise ValueError(f"axes must match; got shapes {x1_arr.shape} and {x2_arr.shape} with {axis=}") - x1_arr = jax.numpy.moveaxis(x1_arr, axis, -1) - x2_arr = jax.numpy.moveaxis(x2_arr, axis, -1) + x1_arr = moveaxis(x1_arr, axis, -1) + x2_arr = moveaxis(x2_arr, axis, -1) return vectorize(partial(vdot, precision=precision, preferred_element_type=preferred_element_type), signature="(n),(n)->()")(x1_arr, x2_arr) @@ -601,8 +605,9 @@ def inner( """ a, b = util.ensure_arraylike("inner", a, b) if np.ndim(a) == 0 or np.ndim(b) == 0: - a = jax.numpy.asarray(a, dtype=preferred_element_type) - b = jax.numpy.asarray(b, dtype=preferred_element_type) + if preferred_element_type is not None: + a = a.astype(preferred_element_type) + b = b.astype(preferred_element_type) return a * b return tensordot(a, b, (-1, -1), precision=precision, preferred_element_type=preferred_element_type) @@ -638,6 +643,6 @@ def outer(a: ArrayLike, b: ArrayLike, out: None = None) -> Array: """ if out is not None: raise NotImplementedError("The 'out' argument to jnp.outer is not supported.") - util.check_arraylike("outer", a, b) + a, b = util.ensure_arraylike("outer", a, b) a, b = util.promote_dtypes(a, b) - return jax.numpy.ravel(a)[:, None] * jax.numpy.ravel(b)[None, :] + return a.ravel()[:, None] * b.ravel()[None, :] diff --git a/jax/_src/numpy/ufunc_api.py b/jax/_src/numpy/ufunc_api.py index c488855b70fa..c85621d6cdba 100644 --- a/jax/_src/numpy/ufunc_api.py +++ b/jax/_src/numpy/ufunc_api.py @@ -22,9 +22,11 @@ import operator from typing import Any -import jax +from jax._src import api from jax._src.typing import Array, ArrayLike, DTypeLike -from jax._src.lax import lax as lax_internal +from jax._src.lax 
import control_flow +from jax._src.lax import slicing +from jax._src.lax import lax from jax._src.numpy import indexing import jax._src.numpy.lax_numpy as jnp from jax._src.numpy.reductions import _moveaxis @@ -90,7 +92,7 @@ class ufunc: [ 5, 6, 7, 8, 9], [ 6, 7, 8, 9, 10]], dtype=int32) - The :meth:`ufunc.reduce` method perfoms a reduction over the array. + The :meth:`ufunc.reduce` method performs a reduction over the array. For example, :meth:`jnp.add.reduce` is equivalent to ``jnp.sum``: >>> jnp.add.reduce(x) @@ -110,7 +112,7 @@ class ufunc: Array([101, 2, 3, 4, 5], dtype=int32) And the :meth:`ufunc.reduceat` method performs a number of ``reduce`` - operations bewteen specified indices of an array; for ``jnp.add`` the + operations between specified indices of an array; for ``jnp.add`` the operation is similar to :func:`jax.ops.segment_sum`: >>> jnp.add.reduceat(x, jnp.array([0, 2])) @@ -179,12 +181,12 @@ def __call__(self, *args: ArrayLike, out: None = None, where: None = None) -> An call = self.__static_props['call'] or self._call_vectorized return call(*args) - @partial(jax.jit, static_argnames=['self']) + @partial(api.jit, static_argnames=['self']) def _call_vectorized(self, *args): return vectorize(self._func)(*args) - @partial(jax.jit, static_argnames=['self', 'axis', 'dtype', 'out', 'keepdims']) - def reduce(self, a: ArrayLike, axis: int = 0, + @partial(api.jit, static_argnames=['self', 'axis', 'dtype', 'out', 'keepdims']) + def reduce(self, a: ArrayLike, axis: int | None = 0, dtype: DTypeLike | None = None, out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array: @@ -249,8 +251,8 @@ def reduce(self, a: ArrayLike, axis: int = 0, if self.identity is None and initial is None: raise ValueError(f"reduction operation {self.__name__!r} does not have an identity, " "so to use a where mask one has to specify 'initial'.") - if lax_internal._dtype(where) != bool: - raise ValueError(f"where argument must have dtype=bool; got dtype={lax_internal._dtype(where)}") + if lax._dtype(where) != bool: + raise ValueError(f"where argument must have dtype=bool; got dtype={lax._dtype(where)}") reduce = self.__static_props['reduce'] or self._reduce_via_scan return reduce(a, axis=axis, dtype=dtype, keepdims=keepdims, initial=initial, where=where) @@ -258,11 +260,11 @@ def _reduce_via_scan(self, arr: ArrayLike, axis: int | None = 0, dtype: DTypeLik keepdims: bool = False, initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array: assert self.nin == 2 and self.nout == 1 - arr = lax_internal.asarray(arr) + arr = lax.asarray(arr) if initial is None: initial = self.identity if dtype is None: - dtype = jax.eval_shape(self._func, lax_internal._one(arr), lax_internal._one(arr)).dtype + dtype = api.eval_shape(self._func, lax._one(arr), lax._one(arr)).dtype if where is not None: where = _broadcast_to(where, arr.shape) if isinstance(axis, tuple): @@ -306,15 +308,15 @@ def body_fun(i, val): else: start_index = 0 start_value = initial - start_value = _broadcast_to(lax_internal.asarray(start_value).astype(dtype), arr.shape[1:]) + start_value = _broadcast_to(lax.asarray(start_value).astype(dtype), arr.shape[1:]) - result = jax.lax.fori_loop(start_index, arr.shape[0], body_fun, start_value) + result = control_flow.fori_loop(start_index, arr.shape[0], body_fun, start_value) if keepdims: result = result.reshape(final_shape) return result - @partial(jax.jit, static_argnames=['self', 'axis', 'dtype']) + @partial(api.jit, static_argnames=['self', 
'axis', 'dtype']) def accumulate(self, a: ArrayLike, axis: int = 0, dtype: DTypeLike | None = None, out: None = None) -> Array: """Accumulate operation derived from binary ufunc. @@ -376,10 +378,10 @@ def _accumulate_via_scan(self, arr: ArrayLike, axis: int = 0, dtype: DTypeLike | None = None) -> Array: assert self.nin == 2 and self.nout == 1 check_arraylike(f"{self.__name__}.accumulate", arr) - arr = lax_internal.asarray(arr) + arr = lax.asarray(arr) if dtype is None: - dtype = jax.eval_shape(self._func, lax_internal._one(arr), lax_internal._one(arr)).dtype + dtype = api.eval_shape(self._func, lax._one(arr), lax._one(arr)).dtype if axis is None or isinstance(axis, tuple): raise ValueError("accumulate does not allow multiple axes") @@ -390,10 +392,10 @@ def scan_fun(carry, _): i, x = carry y = _where(i == 0, arr[0].astype(dtype), self(x.astype(dtype), arr[i].astype(dtype))) return (i + 1, y), y - _, result = jax.lax.scan(scan_fun, (0, arr[0].astype(dtype)), None, length=arr.shape[0]) + _, result = control_flow.scan(scan_fun, (0, arr[0].astype(dtype)), None, length=arr.shape[0]) return _moveaxis(result, 0, axis) - @partial(jax.jit, static_argnums=[0], static_argnames=['inplace']) + @partial(api.jit, static_argnums=[0], static_argnames=['inplace']) def at(self, a: ArrayLike, indices: Any, b: ArrayLike | None = None, /, *, inplace: bool = True) -> Array: """Update elements of an array via the specified unary or binary ufunc. @@ -440,15 +442,15 @@ def at(self, a: ArrayLike, indices: Any, b: ArrayLike | None = None, /, *, def _at_via_scan(self, a: ArrayLike, indices: Any, *args: Any) -> Array: assert len(args) in {0, 1} check_arraylike(f"{self.__name__}.at", a, *args) - dtype = jax.eval_shape(self._func, lax_internal._one(a), *(lax_internal._one(arg) for arg in args)).dtype - a = lax_internal.asarray(a).astype(dtype) - args = tuple(lax_internal.asarray(arg).astype(dtype) for arg in args) + dtype = api.eval_shape(self._func, lax._one(a), *(lax._one(arg) for arg in args)).dtype + a = lax.asarray(a).astype(dtype) + args = tuple(lax.asarray(arg).astype(dtype) for arg in args) indices = indexing.eliminate_deprecated_list_indexing(indices) if not indices: return a shapes = [np.shape(i) for i in indices if not isinstance(i, slice)] - shape = shapes and jax.lax.broadcast_shapes(*shapes) + shape = shapes and lax.broadcast_shapes(*shapes) if not shape: return a.at[indices].set(self(a.at[indices].get(), *args)) @@ -462,10 +464,10 @@ def scan_fun(carry, x): idx = tuple(ind if isinstance(ind, slice) else ind[i] for ind in indices) a = a.at[idx].set(self(a.at[idx].get(), *(arg[i] for arg in args))) return (i + 1, a), x - carry, _ = jax.lax.scan(scan_fun, (0, a), None, len(indices[0])) # type: ignore[arg-type] + carry, _ = control_flow.scan(scan_fun, (0, a), None, len(indices[0])) # type: ignore[arg-type] return carry[1] - @partial(jax.jit, static_argnames=['self', 'axis', 'dtype']) + @partial(api.jit, static_argnames=['self', 'axis', 'dtype']) def reduceat(self, a: ArrayLike, indices: Any, axis: int = 0, dtype: DTypeLike | None = None, out: None = None) -> Array: """Reduce an array between specified indices via a binary ufunc. 
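The ufunc_api.py changes here are largely mechanical import rewrites (public `jax.*` entry points replaced by `jax._src` internals), so the ufunc methods keep their documented semantics. A few illustrative calls, with expected results per that documented behavior:

import jax.numpy as jnp

x = jnp.array([1, 2, 3, 4])

jnp.add.reduce(x)                # Array(10, dtype=int32), same as jnp.sum(x)
jnp.add.accumulate(x)            # Array([ 1,  3,  6, 10], dtype=int32), same as jnp.cumsum(x)
jnp.add.outer(jnp.arange(3), x)  # shape (3, 4) table of pairwise sums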
@@ -517,7 +519,7 @@ def reduceat(self, a: ArrayLike, indices: Any, axis: int = 0, def _reduceat_via_scan(self, a: ArrayLike, indices: Any, axis: int = 0, dtype: DTypeLike | None = None) -> Array: check_arraylike(f"{self.__name__}.reduceat", a, indices) - a = lax_internal.asarray(a) + a = lax.asarray(a) idx_tuple = indexing.eliminate_deprecated_list_indexing(indices) assert len(idx_tuple) == 1 indices = idx_tuple[0] @@ -531,17 +533,17 @@ def _reduceat_via_scan(self, a: ArrayLike, indices: Any, axis: int = 0, raise ValueError("reduceat requires a single integer axis.") axis = canonicalize_axis(axis, a.ndim) out = indexing.take(a, indices, axis=axis) - ind = jax.lax.expand_dims(jnp.append(indices, a.shape[axis]), - list(np.delete(np.arange(out.ndim), axis))) - ind_start = jax.lax.slice_in_dim(ind, 0, ind.shape[axis] - 1, axis=axis) - ind_end = jax.lax.slice_in_dim(ind, 1, ind.shape[axis], axis=axis) + ind = lax.expand_dims(jnp.append(indices, a.shape[axis]), + list(np.delete(np.arange(out.ndim), axis))) + ind_start = slicing.slice_in_dim(ind, 0, ind.shape[axis] - 1, axis=axis) + ind_end = slicing.slice_in_dim(ind, 1, ind.shape[axis], axis=axis) def loop_body(i, out): return _where((i > ind_start) & (i < ind_end), - self(out, indexing.take(a, jax.lax.expand_dims(i, (0,)), axis=axis)), + self(out, indexing.take(a, lax.expand_dims(i, (0,)), axis=axis)), out) - return jax.lax.fori_loop(0, a.shape[axis], loop_body, out) + return control_flow.fori_loop(0, a.shape[axis], loop_body, out) - @partial(jax.jit, static_argnums=[0]) + @partial(api.jit, static_argnums=[0]) def outer(self, A: ArrayLike, B: ArrayLike, /) -> Array: """Apply the function to all pairs of values in ``A`` and ``B``. @@ -572,7 +574,7 @@ def outer(self, A: ArrayLike, B: ArrayLike, /) -> Array: [ 10 20 30 40 50 60 70 80 90 100]] For input arrays with ``N`` and ``M`` dimensions respectively, the output - will have dimesion ``N + M``: + will have dimension ``N + M``: >>> x = jnp.ones((1, 3, 5)) >>> y = jnp.ones((2, 4)) @@ -584,8 +586,8 @@ def outer(self, A: ArrayLike, B: ArrayLike, /) -> Array: if self.nout != 1: raise ValueError("outer only supported for functions returning a single value") check_arraylike(f"{self.__name__}.outer", A, B) - _ravel = lambda A: jax.lax.reshape(A, (np.size(A),)) - result = jax.vmap(jax.vmap(self, (None, 0)), (0, None))(_ravel(A), _ravel(B)) + _ravel = lambda A: lax.reshape(A, (np.size(A),)) + result = api.vmap(api.vmap(self, (None, 0)), (0, None))(_ravel(A), _ravel(B)) return result.reshape(*np.shape(A), *np.shape(B)) diff --git a/jax/_src/numpy/ufuncs.py b/jax/_src/numpy/ufuncs.py index 91191d24a12e..b0ff3cb9747a 100644 --- a/jax/_src/numpy/ufuncs.py +++ b/jax/_src/numpy/ufuncs.py @@ -32,12 +32,13 @@ from jax._src.lax import lax from jax._src.lax import other as lax_other from jax._src.typing import Array, ArrayLike +from jax._src.numpy import error as jnp_error +from jax._src.numpy import reductions +from jax._src.numpy.ufunc_api import ufunc from jax._src.numpy.util import ( - check_arraylike, promote_args, promote_args_inexact, + check_arraylike, ensure_arraylike, promote_args, promote_args_inexact, promote_args_numeric, promote_dtypes_inexact, promote_dtypes_numeric, promote_shapes, _where, check_no_float0s) -from jax._src.numpy.ufunc_api import ufunc -from jax._src.numpy import reductions from jax._src.util import set_module @@ -118,7 +119,7 @@ def fabs(x: ArrayLike, /) -> Array: >>> jnp.fabs(x2) Array([1., 0.], dtype=float32) """ - check_arraylike('fabs', x) + x = ensure_arraylike('fabs', x) if 
dtypes.issubdtype(dtypes.dtype(x), np.complexfloating): raise TypeError("ufunc 'fabs' does not support complex dtypes") return lax.abs(*promote_args_inexact('fabs', x)) @@ -296,7 +297,7 @@ def sign(x: ArrayLike, /) -> Array: -1, & x < 0 \end{cases} - For complex valued input, ``jnp.sign`` returns a unit vector repesenting the + For complex valued input, ``jnp.sign`` returns a unit vector representing the phase. For generalized case, the sign of ``x`` is given by: .. math:: @@ -346,8 +347,8 @@ def floor(x: ArrayLike, /) -> Array: the nearest integer that is less than or equal to the value itself. See also: - - :func:`jax.numpy.fix`: Rounds the input to the nearest interger towards zero. - - :func:`jax.numpy.trunc`: Rounds the input to the nearest interger towards + - :func:`jax.numpy.fix`: Rounds the input to the nearest integer towards zero. + - :func:`jax.numpy.trunc`: Rounds the input to the nearest integer towards zero. - :func:`jax.numpy.ceil`: Rounds the input up to the nearest integer. @@ -364,9 +365,9 @@ def floor(x: ArrayLike, /) -> Array: [ 0., -1., 0.], [-5., 2., 1.]], dtype=float32) """ - check_arraylike('floor', x) + x = ensure_arraylike('floor', x) if dtypes.isdtype(dtypes.dtype(x), ('integral', 'bool')): - return lax.asarray(x) + return x return lax.floor(*promote_args_inexact('floor', x)) @@ -385,8 +386,8 @@ def ceil(x: ArrayLike, /) -> Array: the nearest integer that is greater than or equal to the value itself. See also: - - :func:`jax.numpy.fix`: Rounds the input to the nearest interger towards zero. - - :func:`jax.numpy.trunc`: Rounds the input to the nearest interger towards + - :func:`jax.numpy.fix`: Rounds the input to the nearest integer towards zero. + - :func:`jax.numpy.trunc`: Rounds the input to the nearest integer towards zero. - :func:`jax.numpy.floor`: Rounds the input down to the nearest integer. @@ -403,7 +404,7 @@ def ceil(x: ArrayLike, /) -> Array: [-0., 4., 1.], [ 5., 4., -1.]], dtype=float32) """ - check_arraylike('ceil', x) + x = ensure_arraylike('ceil', x) if dtypes.isdtype(dtypes.dtype(x), ('integral', 'bool')): return lax.asarray(x) return lax.ceil(*promote_args_inexact('ceil', x)) @@ -486,7 +487,9 @@ def log(x: ArrayLike, /) -> Array: >>> jnp.allclose(jnp.log(x1*x2), jnp.log(x1)+jnp.log(x2)) Array(True, dtype=bool) """ - return lax.log(*promote_args_inexact('log', x)) + out = lax.log(*promote_args_inexact('log', x)) + jnp_error._set_error_if_nan(out) + return out @export @@ -572,7 +575,9 @@ def log1p(x: ArrayLike, /) -> Array: >>> jnp.expm1(jnp.log(x1+1)) # doctest: +SKIP Array([1.000166e-04, 9.536743e-07, 0.000000e+00], dtype=float32) """ - return lax.log1p(*promote_args_inexact('log1p', x)) + out = lax.log1p(*promote_args_inexact('log1p', x)) + jnp_error._set_error_if_nan(out) + return out @export @@ -604,7 +609,9 @@ def sin(x: ArrayLike, /) -> Array: ... print(jnp.sin(x)) [ 0.707 1. 0.707 -0. ] """ - return lax.sin(*promote_args_inexact('sin', x)) + out = lax.sin(*promote_args_inexact('sin', x)) + jnp_error._set_error_if_nan(out) + return out @export @@ -635,7 +642,9 @@ def cos(x: ArrayLike, /) -> Array: ... print(jnp.cos(x)) [ 0.707 -0. -0.707 -0.866] """ - return lax.cos(*promote_args_inexact('cos', x)) + out = lax.cos(*promote_args_inexact('cos', x)) + jnp_error._set_error_if_nan(out) + return out @export @@ -666,7 +675,9 @@ def tan(x: ArrayLike, /) -> Array: ... print(jnp.tan(x)) [ 0. 0.577 1. -1. 
-0.577] """ - return lax.tan(*promote_args_inexact('tan', x)) + out = lax.tan(*promote_args_inexact('tan', x)) + jnp_error._set_error_if_nan(out) + return out @export @@ -708,7 +719,9 @@ def arcsin(x: ArrayLike, /) -> Array: ... jnp.arcsin(3+4j) Array(0.634+2.306j, dtype=complex64, weak_type=True) """ - return lax.asin(*promote_args_inexact('arcsin', x)) + out = lax.asin(*promote_args_inexact('arcsin', x)) + jnp_error._set_error_if_nan(out) + return out @export @@ -751,7 +764,9 @@ def arccos(x: ArrayLike, /) -> Array: ... jnp.arccos(4-1j) Array(0.252+2.097j, dtype=complex64, weak_type=True) """ - return lax.acos(*promote_args_inexact('arccos', x)) + out = lax.acos(*promote_args_inexact('arccos', x)) + jnp_error._set_error_if_nan(out) + return out @export @@ -1005,6 +1020,7 @@ def arccosh(x: ArrayLike, /) -> Array: # Note: arccosh is multi-valued for complex input, and lax.acosh # uses a different convention than np.arccosh. result = lax.acosh(*promote_args_inexact("arccosh", x)) + jnp_error._set_error_if_nan(result) if dtypes.issubdtype(result.dtype, np.complexfloating): result = _where(real(result) < 0, lax.neg(result), result) return result @@ -1110,7 +1126,9 @@ def arctanh(x: ArrayLike, /) -> Array: ... jnp.arctanh(x1) Array([-0.549+1.571j, 0.347+1.571j, 0.239-1.509j], dtype=complex64) """ - return lax.atanh(*promote_args_inexact('arctanh', x)) + out = lax.atanh(*promote_args_inexact('arctanh', x)) + jnp_error._set_error_if_nan(out) + return out @export @@ -1143,7 +1161,9 @@ def sqrt(x: ArrayLike, /) -> Array: >>> jnp.sqrt(-1) Array(nan, dtype=float32, weak_type=True) """ - return lax.sqrt(*promote_args_inexact('sqrt', x)) + out = lax.sqrt(*promote_args_inexact('sqrt', x)) + jnp_error._set_error_if_nan(out) + return out @export @@ -1212,7 +1232,11 @@ def add(x: ArrayLike, y: ArrayLike, /) -> Array: Array([10, 11, 12, 13], dtype=int32) """ x, y = promote_args("add", x, y) - return lax.add(x, y) if x.dtype != bool else lax.bitwise_or(x, y) + if x.dtype == bool: + return lax.bitwise_or(x, y) + out = lax.add(x, y) + jnp_error._set_error_if_nan(out) + return out def _multiply_at(a: Array, indices: Any, b: ArrayLike) -> Array: @@ -1541,7 +1565,9 @@ def subtract(x: ArrayLike, y: ArrayLike, /) -> Array: >>> x - 10 Array([-10, -9, -8, -7], dtype=int32) """ - return lax.sub(*promote_args("subtract", x, y)) + out = lax.sub(*promote_args("subtract", x, y)) + jnp_error._set_error_if_nan(out) + return out @export @@ -1595,13 +1621,12 @@ def arctan2(x1: ArrayLike, x2: ArrayLike, /) -> Array: The results match the input ``theta``, except at the endpoints where :math:`+\pi` and :math:`-\pi` represent indistinguishable points on the unit circle. By convention, - :func:`arctan2` alwasy returns values between :math:`-\pi` and :math:`+\pi` inclusive. + :func:`arctan2` always returns values between :math:`-\pi` and :math:`+\pi` inclusive. """ return lax.atan2(*promote_args_inexact("arctan2", x1, x2)) -@export -@partial(jit, inline=True) +@binary_ufunc(identity=None, reduce=reductions._reduce_min) def minimum(x: ArrayLike, y: ArrayLike, /) -> Array: """Return element-wise minimum of the input arrays. @@ -1661,8 +1686,7 @@ def minimum(x: ArrayLike, y: ArrayLike, /) -> Array: return lax.min(*promote_args("minimum", x, y)) -@export -@partial(jit, inline=True) +@binary_ufunc(identity=None, reduce=reductions._reduce_max) def maximum(x: ArrayLike, y: ArrayLike, /) -> Array: """Return element-wise maximum of the input arrays. @@ -1686,7 +1710,7 @@ def maximum(x: ArrayLike, y: ArrayLike, /) -> Array: arrays. 
- :func:`jax.numpy.fmax`: Returns element-wise maximum of the input arrays, ignoring NaNs. - - :func:`jax.numpy.amax`: Retruns the maximum of array elements along a given + - :func:`jax.numpy.amax`: Returns the maximum of array elements along a given axis. - :func:`jax.numpy.nanmax`: Returns the maximum of the array elements along a given axis, ignoring NaNs. @@ -1750,7 +1774,7 @@ def float_power(x: ArrayLike, y: ArrayLike, /) -> Array: >>> jnp.float_power(x, y) Array([ 9. , 1. , -0.2], dtype=float32) - Inputs with broacast compatibility: + Inputs with broadcast compatibility: >>> x1 = jnp.array([[2, -4, 1], ... [-1, 2, 3]]) @@ -1765,7 +1789,9 @@ def float_power(x: ArrayLike, y: ArrayLike, /) -> Array: >>> jnp.float_power(-3, 1.7) Array(nan, dtype=float32, weak_type=True) """ - return lax.pow(*promote_args_inexact("float_power", x, y)) + out = lax.pow(*promote_args_inexact("float_power", x, y)) + jnp_error._set_error_if_nan(out) + return out @export @@ -2315,7 +2341,7 @@ def absolute(x: ArrayLike, /) -> Array: >>> jnp.absolute(x3) Array([17., 5., 5.], dtype=float32) """ - check_arraylike('absolute', x) + x = ensure_arraylike('absolute', x) dt = dtypes.dtype(x) return lax.asarray(x) if dt == np.bool_ or dtypes.issubdtype(dt, np.unsignedinteger) else lax.abs(x) @@ -2358,7 +2384,7 @@ def rint(x: ArrayLike, /) -> Array: >>> jnp.rint(x3) Array([-2.+4.j, 4.-0.j], dtype=complex64) """ - check_arraylike('rint', x) + x = ensure_arraylike('rint', x) dtype = dtypes.dtype(x) if dtype == bool or dtypes.issubdtype(dtype, np.integer): return lax.convert_element_type(x, dtypes.float_) @@ -2443,7 +2469,10 @@ def true_divide(x1: ArrayLike, x2: ArrayLike, /) -> Array: :func:`jax.numpy.floor_divide` for integer division """ x1, x2 = promote_args_inexact("true_divide", x1, x2) - return lax.div(x1, x2) + jnp_error._set_error_if_divide_by_zero(x2) + out = lax.div(x1, x2) + jnp_error._set_error_if_nan(out) + return out @export @@ -2493,6 +2522,7 @@ def floor_divide(x1: ArrayLike, x2: ArrayLike, /) -> Array: Array([3., 2., 2.], dtype=float32) """ x1, x2 = promote_args_numeric("floor_divide", x1, x2) + jnp_error._set_error_if_divide_by_zero(x2) dtype = dtypes.dtype(x1) if dtypes.issubdtype(dtype, np.unsignedinteger): return lax.div(x1, x2) @@ -2547,6 +2577,7 @@ def divmod(x1: ArrayLike, x2: ArrayLike, /) -> tuple[Array, Array]: if dtypes.issubdtype(dtypes.dtype(x1), np.integer): return floor_divide(x1, x2), remainder(x1, x2) else: + jnp_error._set_error_if_divide_by_zero(x2) return _float_divmod(x1, x2) @@ -2582,8 +2613,9 @@ def power(x1: ArrayLike, x2: ArrayLike, /) -> Array: :func:`jax.lax.integer_pow`. - When ``x2`` is a traced scalar or an array, ``jnp.power`` lowers to :func:`jax.lax.pow`. - - ``jnp.power`` raises a ``TypeError`` for integer type raised to negative - integer power. + - ``jnp.power`` raises a ``TypeError`` for integer type raised to a concrete + negative integer power. For a non-concrete power, the operation is invalid + and the returned value is implementation-defined. - ``jnp.power`` returns ``nan`` for negative value raised to the power of non-integer values. @@ -2619,6 +2651,11 @@ def power(x1: ArrayLike, x2: ArrayLike, /) -> Array: [nan, 27., 1.]], dtype=float32) """ check_arraylike("power", x1, x2) + + # Must do __jax_array__ conversion prior to dtype check. 
+ x1 = x1.__jax_array__() if hasattr(x1, "__jax_array__") else x1 + x2 = x2.__jax_array__() if hasattr(x2, "__jax_array__") else x2 + check_no_float0s("power", x1, x2) # We apply special cases, both for algorithmic and autodiff reasons: @@ -2645,7 +2682,9 @@ def power(x1: ArrayLike, x2: ArrayLike, /) -> Array: return lax.integer_pow(x1, x2) # Handle cases #2 and #3 under a jit: - return _power(x1, x2) + out = _power(x1, x2) + jnp_error._set_error_if_nan(out) + return out @export def pow(x1: ArrayLike, x2: ArrayLike, /) -> Array: @@ -2741,8 +2780,7 @@ def logaddexp2(x1: ArrayLike, x2: ArrayLike, /) -> Array: Array(True, dtype=bool) """ x1, x2 = promote_args_inexact("logaddexp2", x1, x2) - ln2 = float(np.log(2)) - return logaddexp(x1 * ln2, x2 * ln2) / ln2 + return lax_other.logaddexp2(x1, x2) @export @@ -2771,7 +2809,9 @@ def log2(x: ArrayLike, /) -> Array: im = lax.imag(r) ln2 = lax.log(_constant_like(re, 2)) return lax.complex(lax.div(re, ln2), lax.div(im, ln2)) - return lax.div(lax.log(x), lax.log(_constant_like(x, 2))) + out = lax.div(lax.log(x), lax.log(_constant_like(x, 2))) + jnp_error._set_error_if_nan(out) + return out @export @@ -2801,7 +2841,9 @@ def log10(x: ArrayLike, /) -> Array: im = lax.imag(r) ln10 = lax.log(_constant_like(re, 10)) return lax.complex(lax.div(re, ln10), lax.div(im, ln10)) - return lax.div(lax.log(x), lax.log(_constant_like(x, 10))) + out = lax.div(lax.log(x), lax.log(_constant_like(x, 10))) + jnp_error._set_error_if_nan(out) + return out @export @@ -2950,7 +2992,7 @@ def ldexp(x1: ArrayLike, x2: ArrayLike, /) -> Array: >>> jnp.ldexp(m, e) Array([ 2., 3., 5., 11.], dtype=float32) """ - check_arraylike("ldexp", x1, x2) + x1, x2 = ensure_arraylike("ldexp", x1, x2) x1_dtype = dtypes.dtype(x1) x2_dtype = dtypes.dtype(x2) if (dtypes.issubdtype(x1_dtype, np.complexfloating) @@ -2958,7 +3000,16 @@ def ldexp(x1: ArrayLike, x2: ArrayLike, /) -> Array: raise ValueError(f"ldexp not supported for input types {(x1_dtype, x2_dtype)}") x1, = promote_args_inexact("ldexp", x1) x2 = lax.convert_element_type(x2, dtypes.dtype(x1)) - x = x1 * (2 ** x2) + + # Split off the exponent to avoid overflow for small x1 and large x2. + m, e = frexp(x1) + e = (e.astype(x2.dtype) + x2).astype(x1.dtype) + + # exponent may overflow by 1 and still have a finite result. 
+ m = _where(e > 0, m * 2, m) + e = _where(e > 0, e - 1, e) + + x = m * (2 ** e.astype(m.dtype)) return _where(isinf(x1) | (x1 == 0), x1, x) @@ -2995,11 +3046,14 @@ def frexp(x: ArrayLike, /) -> tuple[Array, Array]: >>> m * 2 ** e Array([1., 2., 3., 4., 5.], dtype=float32) """ - check_arraylike("frexp", x) + x = ensure_arraylike("frexp", x) x, = promote_dtypes_inexact(x) if dtypes.issubdtype(x.dtype, np.complexfloating): raise TypeError("frexp does not support complex-valued inputs") + return _frexp(x) +@custom_jvp +def _frexp(x): dtype = dtypes.dtype(x) info = dtypes.finfo(dtype) mask = (1 << info.nexp) - 1 @@ -3016,6 +3070,16 @@ def frexp(x: ArrayLike, /) -> tuple[Array, Array]: return _where(cond, x, x1), lax.convert_element_type(x2, np.int32) +@_frexp.defjvp +def _frexp_jvp(primals, tangents): + x, = primals + t, = tangents + m, e = frexp(x) + mdot = t * exp2(-e.astype(t.dtype)) + edot = np.empty(e.shape, dtypes.float0) + return (m, e), (mdot, edot) + + @export @jit def remainder(x1: ArrayLike, x2: ArrayLike, /) -> Array: @@ -3054,6 +3118,7 @@ def remainder(x1: ArrayLike, x2: ArrayLike, /) -> Array: [ 0., 2., -2.]], dtype=float32) """ x1, x2 = promote_args_numeric("remainder", x1, x2) + jnp_error._set_error_if_divide_by_zero(x2) zero = _constant_like(x1, 0) if dtypes.issubdtype(x2.dtype, np.integer): x2 = _where(x2 == 0, lax._ones(x2), x2) @@ -3061,7 +3126,9 @@ def remainder(x1: ArrayLike, x2: ArrayLike, /) -> Array: trunc_mod_not_zero = lax.ne(trunc_mod, zero) do_plus = lax.bitwise_and( lax.ne(lax.lt(trunc_mod, zero), lax.lt(x2, zero)), trunc_mod_not_zero) - return lax.select(do_plus, lax.add(trunc_mod, x2), trunc_mod) + out = lax.select(do_plus, lax.add(trunc_mod, x2), trunc_mod) + jnp_error._set_error_if_nan(out) + return out @export @@ -3106,10 +3173,12 @@ def fmod(x1: ArrayLike, x2: ArrayLike, /) -> Array: Array([[ 1., -1., 4.], [ 0., 2., -2.]], dtype=float32) """ - check_arraylike("fmod", x1, x2) + x1, x2 = ensure_arraylike("fmod", x1, x2) if dtypes.issubdtype(dtypes.result_type(x1, x2), np.integer): x2 = _where(x2 == 0, lax._ones(x2), x2) - return lax.rem(*promote_args_numeric("fmod", x1, x2)) + out = lax.rem(*promote_args_numeric("fmod", x1, x2)) + jnp_error._set_error_if_nan(out) + return out @export @@ -3157,7 +3226,7 @@ def square(x: ArrayLike, /) -> Array: >>> jnp.square(x2) Array([-8.-6.j, -1.+0.j, 4.+0.j], dtype=complex64) """ - check_arraylike("square", x) + x = ensure_arraylike("square", x) x, = promote_dtypes_numeric(x) return lax.square(x) @@ -3271,7 +3340,7 @@ def conjugate(x: ArrayLike, /) -> Array: >>> jnp.conjugate(x) Array([2.+1.j, 3.-5.j, 7.-0.j], dtype=complex64) """ - check_arraylike("conjugate", x) + x = ensure_arraylike("conjugate", x) return lax.conj(x) if np.iscomplexobj(x) else lax.asarray(x) @@ -3309,7 +3378,7 @@ def imag(val: ArrayLike, /) -> Array: >>> jnp.imag(x) Array([ 3., -1., 0.], dtype=float32) """ - check_arraylike("imag", val) + val = ensure_arraylike("imag", val) return lax.imag(val) if np.iscomplexobj(val) else lax.full_like(val, 0) @@ -3341,7 +3410,7 @@ def real(val: ArrayLike, /) -> Array: >>> jnp.real(x) Array([ 3., 4., -0.], dtype=float32) """ - check_arraylike("real", val) + val = ensure_arraylike("real", val) return lax.real(val) if np.iscomplexobj(val) else lax.asarray(val) @@ -3371,7 +3440,7 @@ def modf(x: ArrayLike, /, out=None) -> tuple[Array, Array]: >>> jnp.modf(x) (Array([-0.4000001 , -0.6999998 , 0.6 , 0.5 , 0.29999995], dtype=float32), Array([-3., -5., 0., 1., 2.], dtype=float32)) """ - check_arraylike("modf", x) + x = 
ensure_arraylike("modf", x) x, = promote_dtypes_inexact(x) if out is not None: raise NotImplementedError("The 'out' argument to jnp.modf is not supported.") @@ -3410,7 +3479,7 @@ def isfinite(x: ArrayLike, /) -> Array: >>> jnp.isfinite(3-4j) Array(True, dtype=bool, weak_type=True) """ - check_arraylike("isfinite", x) + x = ensure_arraylike("isfinite", x) dtype = dtypes.dtype(x) if dtypes.issubdtype(dtype, np.floating): return lax.is_finite(x) @@ -3451,7 +3520,7 @@ def isinf(x: ArrayLike, /) -> Array: >>> jnp.isinf(x) Array([False, True, False, True, False], dtype=bool) """ - check_arraylike("isinf", x) + x = ensure_arraylike("isinf", x) dtype = dtypes.dtype(x) if dtypes.issubdtype(dtype, np.floating): return lax.eq(lax.abs(x), _constant_like(x, np.inf)) @@ -3464,7 +3533,7 @@ def isinf(x: ArrayLike, /) -> Array: return lax.full_like(x, False, dtype=np.bool_) -def _isposneginf(infinity: float, x: ArrayLike, out) -> Array: +def _isposneginf(infinity: float, x: Array, out) -> Array: if out is not None: raise NotImplementedError("The 'out' argument to isneginf/isposinf is not supported.") dtype = dtypes.dtype(x) @@ -3507,6 +3576,7 @@ def isposinf(x, /, out=None): >>> jnp.isposinf(x) Array([False, False, True, False, False], dtype=bool) """ + x = ensure_arraylike("isposinf", x) return _isposneginf(np.inf, x, out) @@ -3541,6 +3611,7 @@ def isneginf(x, /, out=None): >>> jnp.isneginf(x) Array([ True, False, False, False, False], dtype=bool) """ + x = ensure_arraylike("isneginf", x) return _isposneginf(-np.inf, x, out) @@ -3575,7 +3646,7 @@ def isnan(x: ArrayLike, /) -> Array: >>> jnp.isnan(x) Array([False, False, False, True], dtype=bool) """ - check_arraylike("isnan", x) + x = ensure_arraylike("isnan", x) return lax.ne(x, x) @@ -3591,9 +3662,9 @@ def heaviside(x1: ArrayLike, x2: ArrayLike, /) -> Array: .. math:: \mathrm{heaviside}(x1, x2) = \begin{cases} - 0., & x < 0\\ - x2, & x = 0\\ - 1., & x > 0. + 0, & x1 < 0\\ + x2, & x1 = 0\\ + 1, & x1 > 0. \end{cases} Args: @@ -3622,7 +3693,7 @@ def heaviside(x1: ArrayLike, x2: ArrayLike, /) -> Array: >>> jnp.heaviside(-3, x2) Array([0., 0., 0.], dtype=float32) """ - check_arraylike("heaviside", x1, x2) + x1, x2 = ensure_arraylike("heaviside", x1, x2) x1, x2 = promote_dtypes_inexact(x1, x2) zero = _lax_const(x1, 0) return _where(lax.lt(x1, zero), zero, @@ -3707,7 +3778,7 @@ def reciprocal(x: ArrayLike, /) -> Array: >>> jnp.reciprocal(x) Array([1. 
, 0.2 , 0.25], dtype=float32) """ - check_arraylike("reciprocal", x) + x = ensure_arraylike("reciprocal", x) x, = promote_dtypes_inexact(x) return lax.integer_pow(x, -1) @@ -3760,7 +3831,7 @@ def sinc(x: ArrayLike, /) -> Array: (d/dx)^4 f(0.0) = 19.48 (d/dx)^5 f(0.0) = 0.00 """ - check_arraylike("sinc", x) + x = ensure_arraylike("sinc", x) x, = promote_dtypes_inexact(x) eq_zero = lax.eq(x, _lax_const(x, 0)) pi_x = lax.mul(_lax_const(x, np.pi), x) diff --git a/jax/_src/numpy/util.py b/jax/_src/numpy/util.py index e281c63ae654..6302e1a9b54c 100644 --- a/jax/_src/numpy/util.py +++ b/jax/_src/numpy/util.py @@ -27,7 +27,8 @@ from jax._src.lib import xla_client as xc from jax._src.sharding_impls import SingleDeviceSharding from jax._src.util import safe_zip, safe_map, set_module -from jax._src.typing import Array, ArrayLike, DimSize, DType, DTypeLike, Shape +from jax._src.typing import ( + Array, ArrayLike, DimSize, Shape, SupportsNdim, SupportsShape, SupportsSize) from jax.sharding import Sharding import numpy as np @@ -69,13 +70,13 @@ def _rank_promotion_warning_or_error(fun_name: str, shapes: Sequence[Shape]): msg = ("Following NumPy automatic rank promotion for {} on shapes {}. " "Set the jax_numpy_rank_promotion config option to 'allow' to " "disable this warning; for more information, see " - "https://jax.readthedocs.io/en/latest/rank_promotion_warning.html.") + "https://docs.jax.dev/en/latest/rank_promotion_warning.html.") warnings.warn(msg.format(fun_name, ' '.join(map(str, shapes)))) elif config.numpy_rank_promotion.value == "raise": msg = ("Operands could not be broadcast together for {} on shapes {} " "and with the config option jax_numpy_rank_promotion='raise'. " "For more information, see " - "https://jax.readthedocs.io/en/latest/rank_promotion_warning.html.") + "https://docs.jax.dev/en/latest/rank_promotion_warning.html.") raise ValueError(msg.format(fun_name, ' '.join(map(str, shapes)))) @@ -123,11 +124,6 @@ def promote_dtypes_complex(*args: ArrayLike) -> list[Array]: for x in args] -def _complex_elem_type(dtype: DTypeLike) -> DType: - """Returns the float type of the real/imaginary parts of a complex dtype.""" - return np.abs(np.zeros((), dtype)).dtype - - def _arraylike(x: ArrayLike) -> bool: return (isinstance(x, np.ndarray) or isinstance(x, Array) or hasattr(x, '__jax_array__') or np.isscalar(x)) @@ -140,6 +136,10 @@ def _arraylike_asarray(x: Any) -> Array: return lax.asarray(x) +def _check_jax_array_protocol(x: Any) -> Any: + return x.__jax_array__() if hasattr(x, '__jax_array__') else x + + @overload def ensure_arraylike(fun_name: str, /) -> tuple[()]: ... @overload @@ -158,7 +158,7 @@ def ensure_arraylike(fun_name: str, /, *args: Any) -> Array | tuple[Array, ...]: return tuple(_arraylike_asarray(arg) for arg in args) # pytype: disable=bad-return-type -def ensure_arraylike_tuple(fun_name: str, tup: tuple[Any, ...]) -> tuple[Array, ...]: +def ensure_arraylike_tuple(fun_name: str, tup: Sequence[Any]) -> tuple[Array, ...]: """Check that argument elements are arraylike and convert to a tuple of arrays. This is useful because ensure_arraylike with a single argument returns a single array. 
@@ -222,6 +222,7 @@ def check_for_prngkeys(fun_name: str, *args: Any): def promote_args(fun_name: str, *args: ArrayLike) -> list[Array]: """Convenience function to apply Numpy argument shape and dtype promotion.""" check_arraylike(fun_name, *args) + args = tuple(_check_jax_array_protocol(arg) for arg in args) _check_no_float0s(fun_name, *args) check_for_prngkeys(fun_name, *args) return promote_shapes(fun_name, *promote_dtypes(*args)) @@ -229,6 +230,7 @@ def promote_args(fun_name: str, *args: ArrayLike) -> list[Array]: def promote_args_numeric(fun_name: str, *args: ArrayLike) -> list[Array]: check_arraylike(fun_name, *args) + args = tuple(_check_jax_array_protocol(arg) for arg in args) _check_no_float0s(fun_name, *args) check_for_prngkeys(fun_name, *args) return promote_shapes(fun_name, *promote_dtypes_numeric(*args)) @@ -239,11 +241,19 @@ def promote_args_inexact(fun_name: str, *args: ArrayLike) -> list[Array]: Promotes non-inexact types to an inexact type.""" check_arraylike(fun_name, *args) + args = tuple(_check_jax_array_protocol(arg) for arg in args) _check_no_float0s(fun_name, *args) check_for_prngkeys(fun_name, *args) return promote_shapes(fun_name, *promote_dtypes_inexact(*args)) +def canonicalize_device_to_sharding(device: xc.Device | Sharding | None + ) -> Sharding | None: + if isinstance(device, xc.Device): + return SingleDeviceSharding(device) + return device + + @partial(api.jit, inline=True) def _broadcast_arrays(*args: ArrayLike) -> list[Array]: """Like Numpy's broadcast_arrays but doesn't return views.""" @@ -258,7 +268,7 @@ def _broadcast_arrays(*args: ArrayLike) -> list[Array]: def _broadcast_to(arr: ArrayLike, shape: DimSize | Shape, sharding=None ) -> Array: - check_arraylike("broadcast_to", arr) + arr = ensure_arraylike("broadcast_to", arr) arr = arr if isinstance(arr, Array) else lax.asarray(arr) if not isinstance(shape, tuple) and np.ndim(shape) == 0: shape = (shape,) @@ -286,6 +296,7 @@ def _broadcast_to(arr: ArrayLike, shape: DimSize | Shape, sharding=None # materialize the broadcast forms of scalar arguments. @api.jit def _where(condition: ArrayLike, x: ArrayLike, y: ArrayLike) -> Array: + condition, x, y = ensure_arraylike("where", condition, x, y) if x is None or y is None: raise ValueError("Either both or neither of the x and y arguments should " "be provided to jax.numpy.where, got {} and {}." @@ -313,7 +324,7 @@ def normalize_device_to_sharding(device: xc.Device | Sharding | None) -> Shardin @export -def ndim(a: ArrayLike) -> int: +def ndim(a: ArrayLike | SupportsNdim) -> int: """Return the number of dimensions of an array. JAX implementation of :func:`numpy.ndim`. Unlike ``np.ndim``, this function @@ -321,7 +332,7 @@ def ndim(a: ArrayLike) -> int: tuple. Args: - a: array-like object. + a: array-like object, or any object with an ``ndim`` attribute. Returns: An integer specifying the number of dimensions of ``a``. @@ -346,13 +357,18 @@ def ndim(a: ArrayLike) -> int: >>> x.ndim 1 """ + if hasattr(a, "ndim"): + return a.ndim # Deprecation warning added 2025-2-20. check_arraylike("ndim", a, emit_warning=True) - return np.ndim(a) # NumPy dispatches to a.ndim if available. + if hasattr(a, "__jax_array__"): + a = a.__jax_array__() + # NumPy dispatches to a.ndim if available. + return np.ndim(a) # type: ignore[arg-type] @export -def shape(a: ArrayLike) -> tuple[int, ...]: +def shape(a: ArrayLike | SupportsShape) -> tuple[int, ...]: """Return the shape an array. JAX implementation of :func:`numpy.shape`. 
Unlike ``np.shape``, this function @@ -360,7 +376,7 @@ def shape(a: ArrayLike) -> tuple[int, ...]: tuple. Args: - a: array-like object. + a: array-like object, or any object with a ``shape`` attribute. Returns: An tuple of integers representing the shape of ``a``. @@ -385,13 +401,18 @@ def shape(a: ArrayLike) -> tuple[int, ...]: >>> x.shape (10,) """ + if hasattr(a, "shape"): + return a.shape # Deprecation warning added 2025-2-20. check_arraylike("shape", a, emit_warning=True) - return np.shape(a) # NumPy dispatches to a.shape if available. + if hasattr(a, "__jax_array__"): + a = a.__jax_array__() + # NumPy dispatches to a.shape if available. + return np.shape(a) # type: ignore[arg-type] @export -def size(a: ArrayLike, axis: int | None = None) -> int: +def size(a: ArrayLike | SupportsSize | SupportsShape, axis: int | None = None) -> int: """Return number of elements along a given axis. JAX implementation of :func:`numpy.size`. Unlike ``np.size``, this function @@ -399,7 +420,8 @@ def size(a: ArrayLike, axis: int | None = None) -> int: tuple. Args: - a: array-like object + a: array-like object, or any object with a ``size`` attribute when ``axis`` is not + specified, or with a ``shape`` attribute when ``axis`` is specified. axis: optional integer along which to count elements. By default, return the total number of elements. @@ -428,6 +450,12 @@ def size(a: ArrayLike, axis: int | None = None) -> int: >>> y.size 6 """ + if (axis is None and hasattr(a, "size")) or (axis is not None and hasattr(a, "shape")): + # NumPy dispatches to a.size/a.shape if available. + return np.size(a, axis=axis) # type: ignore[arg-type] # Deprecation warning added 2025-2-20. check_arraylike("size", a, emit_warning=True) - return np.size(a, axis=axis) # NumPy dispatches to a.size if available. + if hasattr(a, "__jax_array__"): + a = a.__jax_array__() + # NumPy dispatches to a.size/a.shape if available. 
+ return np.size(a, axis=axis) # type: ignore[arg-type] diff --git a/jax/_src/numpy/vectorize.py b/jax/_src/numpy/vectorize.py index e6ad1386a52e..f166a96a4693 100644 --- a/jax/_src/numpy/vectorize.py +++ b/jax/_src/numpy/vectorize.py @@ -23,7 +23,7 @@ from jax._src import api from jax._src import config -from jax import lax +from jax._src.lax import lax from jax._src.numpy import lax_numpy as jnp from jax._src.util import set_module, safe_map as map, safe_zip as zip @@ -144,18 +144,15 @@ def wrapped(*args): out = func(*args) out_shapes = map(np.shape, out if isinstance(out, tuple) else [out]) - if expected_output_core_dims is None: - output_core_dims = [()] * len(out_shapes) - else: - output_core_dims = expected_output_core_dims - if len(output_core_dims) > 1 and not isinstance(out, tuple): - raise TypeError( - "output must be a tuple when multiple outputs are expected, " - "got: {!r}\n{}".format(out, error_context)) - if len(out_shapes) != len(output_core_dims): - raise TypeError( - 'wrong number of output arguments: expected %r, got %r %s' - % (len(output_core_dims), len(out_shapes), error_context)) + output_core_dims = expected_output_core_dims + if len(output_core_dims) > 1 and not isinstance(out, tuple): + raise TypeError( + "output must be a tuple when multiple outputs are expected, " + "got: {!r}\n{}".format(out, error_context)) + if len(out_shapes) != len(output_core_dims): + raise TypeError( + 'wrong number of output arguments: expected %r, got %r %s' + % (len(output_core_dims), len(out_shapes), error_context)) sizes = dict(dim_sizes) for shape, core_dims in zip(out_shapes, output_core_dims): @@ -215,7 +212,8 @@ def vectorize(pyfunc, *, excluded=frozenset(), signature=None): ``(m,n),(n)->(m)`` for vectorized matrix-vector multiplication. If provided, ``pyfunc`` will be called with (and expected to return) arrays with shapes given by the size of corresponding core dimensions. By - default, pyfunc is assumed to take scalars arrays as input and output. + default, pyfunc is assumed to take scalar arrays as input, and if + ``signature`` is ``None``, ``pyfunc`` can produce outputs of any shape. Returns: Vectorized version of the given function. @@ -294,8 +292,11 @@ def wrapped(*args, **kwargs): broadcast_shape, dim_sizes = _parse_input_dimensions( args, input_core_dims, error_context) - checked_func = _check_output_dims( - excluded_func, dim_sizes, output_core_dims, error_context) + if output_core_dims is None: + checked_func = excluded_func + else: + checked_func = _check_output_dims( + excluded_func, dim_sizes, output_core_dims, error_context) # Detect implicit rank promotion: if config.numpy_rank_promotion.value != "allow": @@ -307,7 +308,7 @@ def wrapped(*args, **kwargs): f" promotion for jnp.vectorize function with signature {signature}." 
" Set the jax_numpy_rank_promotion config option to 'allow' to" " disable this message; for more information, see" - " https://jax.readthedocs.io/en/latest/rank_promotion_warning.html.") + " https://docs.jax.dev/en/latest/rank_promotion_warning.html.") if config.numpy_rank_promotion.value == "warn": warnings.warn(msg) elif config.numpy_rank_promotion.value == "raise": diff --git a/jax/_src/numpy/window_functions.py b/jax/_src/numpy/window_functions.py index 96a15db777a8..6d1bfb245272 100644 --- a/jax/_src/numpy/window_functions.py +++ b/jax/_src/numpy/window_functions.py @@ -16,11 +16,11 @@ from jax._src import core from jax._src import dtypes +from jax._src.lax import lax from jax._src.numpy import lax_numpy from jax._src.numpy import ufuncs from jax._src.typing import Array, ArrayLike from jax._src.util import set_module -from jax import lax export = set_module('jax.numpy') diff --git a/jax/_src/ops/scatter.py b/jax/_src/ops/scatter.py index e19be6622168..4db79557c3cc 100644 --- a/jax/_src/ops/scatter.py +++ b/jax/_src/ops/scatter.py @@ -17,8 +17,9 @@ from __future__ import annotations from collections.abc import Callable, Sequence -from typing import Union +from typing import Any, Union import warnings +from functools import partial import numpy as np @@ -27,9 +28,12 @@ from jax._src import config from jax._src import core from jax._src import dtypes +from jax._src import sharding +from jax._src import tree_util from jax._src import util from jax._src.lax import lax as lax_internal from jax._src.numpy import indexing +from jax._src.pjit import auto_axes from jax._src.numpy import lax_numpy as jnp from jax._src.numpy import reductions from jax._src.numpy.util import check_arraylike, promote_dtypes @@ -42,8 +46,10 @@ Scalar = Union[complex, float, int, np.number] -def _scatter_update(x, idx, y, scatter_op, indices_are_sorted, - unique_indices, mode=None, normalize_indices=True): +def _scatter_update(x: ArrayLike, idx: Index, y: ArrayLike, scatter_op: Callable[..., Array], + indices_are_sorted: bool, unique_indices: bool, + mode: lax.GatherScatterMode | str | None = None, normalize_indices: bool = True, + out_sharding: sharding.Sharding | None = None): """Helper for indexed updates. Computes the value of x that would result from computing:: @@ -74,17 +80,26 @@ def _scatter_update(x, idx, y, scatter_op, indices_are_sorted, # XLA gathers and scatters are very similar in structure; the scatter logic # is more or less a transpose of the gather equivalent. treedef, static_idx, dynamic_idx = indexing.split_index_for_jit(idx, x.shape) - return _scatter_impl(x, y, scatter_op, treedef, static_idx, dynamic_idx, - indices_are_sorted, unique_indices, mode, - normalize_indices) + + internal_scatter = partial( + _scatter_impl, scatter_op=scatter_op, treedef=treedef, + static_idx=static_idx, indices_are_sorted=indices_are_sorted, + unique_indices=unique_indices, mode=mode, + normalize_indices=normalize_indices) + if out_sharding is not None: + return auto_axes(internal_scatter, out_sharding=out_sharding + )(x, y, dynamic_idx) + return internal_scatter(x, y, dynamic_idx) # TODO(phawkins): re-enable jit after fixing excessive recompilation for # slice indexes (e.g., slice(0, 5, None), slice(10, 15, None), etc.). 
# @partial(jit, static_argnums=(2, 3, 4)) -def _scatter_impl(x, y, scatter_op, treedef, static_idx, dynamic_idx, - indices_are_sorted, unique_indices, mode, - normalize_indices): +def _scatter_impl(x: ArrayLike, y: ArrayLike, dynamic_idx: tuple[Any, ...], *, + scatter_op: Callable[..., Array], + treedef: tree_util.PyTreeDef, static_idx: tuple[Any, ...], + indices_are_sorted: bool, unique_indices: bool, + mode: lax.GatherScatterMode | str | None, normalize_indices: bool): dtype = lax.dtype(x) weak_type = dtypes.is_weakly_typed(x) @@ -168,7 +183,7 @@ def _segment_update(name: str, unique_indices: bool = False, bucket_size: int | None = None, reducer: Callable | None = None, - mode: lax.GatherScatterMode | None = None) -> Array: + mode: lax.GatherScatterMode | str | None = None) -> Array: check_arraylike(name, data, segment_ids) mode = lax.GatherScatterMode.FILL_OR_DROP if mode is None else mode data = jnp.asarray(data) @@ -207,7 +222,7 @@ def segment_sum(data: ArrayLike, indices_are_sorted: bool = False, unique_indices: bool = False, bucket_size: int | None = None, - mode: lax.GatherScatterMode | None = None) -> Array: + mode: lax.GatherScatterMode | str | None = None) -> Array: """Computes the sum within segments of an array. Similar to TensorFlow's `segment_sum @@ -262,7 +277,7 @@ def segment_prod(data: ArrayLike, indices_are_sorted: bool = False, unique_indices: bool = False, bucket_size: int | None = None, - mode: lax.GatherScatterMode | None = None) -> Array: + mode: lax.GatherScatterMode | str | None = None) -> Array: """Computes the product within segments of an array. Similar to TensorFlow's `segment_prod @@ -272,8 +287,7 @@ def segment_prod(data: ArrayLike, data: an array with the values to be reduced. segment_ids: an array with integer dtype that indicates the segments of `data` (along its leading axis) to be reduced. Values can be repeated and - need not be sorted. Values outside of the range [0, num_segments) are - dropped and do not contribute to the result. + need not be sorted. num_segments: optional, an int with nonnegative value indicating the number of segments. The default is set to be the minimum number of segments that would support all indices in ``segment_ids``, calculated as @@ -283,11 +297,11 @@ def segment_prod(data: ArrayLike, indices_are_sorted: whether ``segment_ids`` is known to be sorted. unique_indices: whether `segment_ids` is known to be free of duplicates. bucket_size: size of bucket to group indices into. ``segment_prod`` is - performed on each bucket separately to improve numerical stability of - addition. Default ``None`` means no bucketing. + performed on each bucket separately to improve numerical stability. + Default ``None`` means no bucketing. mode: a :class:`jax.lax.GatherScatterMode` value describing how out-of-bounds indices should be handled. By default, values outside of the - range [0, num_segments) are dropped and do not contribute to the sum. + range [0, num_segments) are dropped and do not contribute to the result. Returns: An array with shape :code:`(num_segments,) + data.shape[1:]` representing the @@ -318,7 +332,7 @@ def segment_max(data: ArrayLike, indices_are_sorted: bool = False, unique_indices: bool = False, bucket_size: int | None = None, - mode: lax.GatherScatterMode | None = None) -> Array: + mode: lax.GatherScatterMode | str | None = None) -> Array: """Computes the maximum within segments of an array. 
Similar to TensorFlow's `segment_max @@ -328,8 +342,7 @@ def segment_max(data: ArrayLike, data: an array with the values to be reduced. segment_ids: an array with integer dtype that indicates the segments of `data` (along its leading axis) to be reduced. Values can be repeated and - need not be sorted. Values outside of the range [0, num_segments) are - dropped and do not contribute to the result. + need not be sorted. num_segments: optional, an int with nonnegative value indicating the number of segments. The default is set to be the minimum number of segments that would support all indices in ``segment_ids``, calculated as @@ -342,7 +355,7 @@ def segment_max(data: ArrayLike, performed on each bucket separately. Default ``None`` means no bucketing. mode: a :class:`jax.lax.GatherScatterMode` value describing how out-of-bounds indices should be handled. By default, values outside of the - range [0, num_segments) are dropped and do not contribute to the sum. + range [0, num_segments) are dropped and do not contribute to the result. Returns: An array with shape :code:`(num_segments,) + data.shape[1:]` representing the @@ -373,7 +386,7 @@ def segment_min(data: ArrayLike, indices_are_sorted: bool = False, unique_indices: bool = False, bucket_size: int | None = None, - mode: lax.GatherScatterMode | None = None) -> Array: + mode: lax.GatherScatterMode | str | None = None) -> Array: """Computes the minimum within segments of an array. Similar to TensorFlow's `segment_min @@ -383,8 +396,7 @@ def segment_min(data: ArrayLike, data: an array with the values to be reduced. segment_ids: an array with integer dtype that indicates the segments of `data` (along its leading axis) to be reduced. Values can be repeated and - need not be sorted. Values outside of the range [0, num_segments) are - dropped and do not contribute to the result. + need not be sorted. num_segments: optional, an int with nonnegative value indicating the number of segments. The default is set to be the minimum number of segments that would support all indices in ``segment_ids``, calculated as @@ -397,7 +409,7 @@ def segment_min(data: ArrayLike, performed on each bucket separately. Default ``None`` means no bucketing. mode: a :class:`jax.lax.GatherScatterMode` value describing how out-of-bounds indices should be handled. By default, values outside of the - range [0, num_segments) are dropped and do not contribute to the sum. + range [0, num_segments) are dropped and do not contribute to the result. Returns: An array with shape :code:`(num_segments,) + data.shape[1:]` representing the diff --git a/jax/_src/ops/special.py b/jax/_src/ops/special.py index fe4c46394832..6ed8804bc8e2 100644 --- a/jax/_src/ops/special.py +++ b/jax/_src/ops/special.py @@ -47,16 +47,15 @@ def logsumexp(a: ArrayLike, axis: Axis = None, b: ArrayLike | None = None, JAX implementation of :func:`scipy.special.logsumexp`. .. math:: - \mathrm{logsumexp}(a) = \mathrm{log} \sum_j b \cdot \mathrm{exp}(a_{ij}) + \operatorname{logsumexp} a = \log \sum_i b_i \exp a_i - where the :math:`j` indices range over one or more dimensions to be reduced. + where the :math:`i` indices range over one or more dimensions to be reduced. Args: a: the input array axis: int or sequence of ints, default=None. Axis along which the sum to be computed. If None, the sum is computed along all the axes. - b: scaling factors for :math:`\mathrm{exp}(a)`. Must be broadcastable to the - shape of `a`. + b: scaling factors for the exponentials. Must be broadcastable to the shape of `a`. 
keepdims: If ``True``, the axes that are reduced are left in the output as dimensions of size 1. return_sign: If ``True``, the output will be a ``(result, sign)`` pair, diff --git a/jax/_src/pallas/BUILD b/jax/_src/pallas/BUILD index 91987167512c..e080c601836f 100644 --- a/jax/_src/pallas/BUILD +++ b/jax/_src/pallas/BUILD @@ -45,6 +45,7 @@ py_library( "//jax:core", "//jax:dtypes", "//jax:effects", + "//jax:frozen_dict", "//jax:mlir", "//jax:partial_eval", "//jax:pretty_printer", diff --git a/jax/_src/pallas/core.py b/jax/_src/pallas/core.py index 206c2a73fbed..1aae4452d32f 100644 --- a/jax/_src/pallas/core.py +++ b/jax/_src/pallas/core.py @@ -24,7 +24,8 @@ import functools import itertools import threading -from typing import Any, ClassVar, Hashable, Protocol, Union, runtime_checkable +from typing import Any, ClassVar, Literal, Protocol, TypeAlias, Union, runtime_checkable +from collections.abc import Hashable import jax from jax._src import api_util @@ -39,10 +40,12 @@ from jax._src.interpreters import mlir from jax._src.interpreters import partial_eval as pe from jax._src.state import discharge as state_discharge +from jax._src.state import indexing from jax._src.state import types as state_types from jax._src.state.types import TransformedRef import jax.numpy as jnp + class DynamicGridDim: def __repr__(self): return "DynamicGridDim" @@ -67,11 +70,62 @@ def __repr__(self): SEMAPHORE_INTERPRET_DTYPE = jnp.int16 SEMAPHORE_MAX_VALUE = jnp.iinfo(SEMAPHORE_INTERPRET_DTYPE).max +class AbstractSemaphoreTyRules: + @staticmethod + def pallas_interpret_element_aval(_) -> jax_core.ShapedArray: + return jax_core.ShapedArray((), SEMAPHORE_INTERPRET_DTYPE) + + @staticmethod + def physical_element_aval(_) -> jax_core.ShapedArray: + return jax_core.ShapedArray((), jnp.int32) + +# TODO(sharadmv): implement dtype rules for AbstractSemaphoreTy +class AbstractSemaphoreTy(dtypes.ExtendedDType): + name: str + _rules = AbstractSemaphoreTyRules + + def __repr__(self) -> str: + return self.name + + def __eq__(self, other): + return self.__class__ == other.__class__ + + def __hash__(self) -> int: + return hash(self.__class__) + +class semaphore_dtype(dtypes.extended): + """Common dtype for all kinds of semaphore dtypes. + + This is an abstract class that should never be instantiated, but rather + exists for the sake of `jnp.issubdtype`. + """ + +class semaphore(semaphore_dtype): + """Regular semaphore dtype. + + Like its superclass, this class should never be instantiated. + """ + +class Semaphore(AbstractSemaphoreTy): + name = "semaphore" + type = semaphore + +class barrier_semaphore(semaphore_dtype): + """Barrier semaphore dtype. + + Like its superclass, this class should never be instantiated. + """ + +class BarrierSemaphore(AbstractSemaphoreTy): + name = "barrier_semaphore" + type = barrier_semaphore + +Backend = Literal["mosaic_tpu", "triton", "mosaic_gpu"] @runtime_checkable class CompilerParams(Protocol): """Base class for compiler parameters.""" - PLATFORM: ClassVar[str] + BACKEND: ClassVar[Backend] # Subclasses must be dataclasses. 
__dataclass_fields__: ClassVar[dict[str, dataclasses.Field[Any]]] @@ -90,34 +144,27 @@ class ShapedArrayWithMemorySpace(jax_core.ShapedArray): __slots__ = ["memory_space"] def __init__(self, shape, dtype, weak_type=False, sharding=None, - memory_space=None): - super().__init__(shape, dtype, weak_type=weak_type, sharding=sharding) + vma=frozenset(), memory_space=None): + super().__init__(shape, dtype, weak_type=weak_type, sharding=sharding, + vma=vma) self.memory_space = memory_space def __eq__(self, other): return super().__eq__(other) and self.memory_space == other.memory_space def __hash__(self): - return hash(( - self.shape, - self.dtype, - self.weak_type, - getattr(self, "sharding", None), - self.memory_space, - )) + return hash((self.shape, self.dtype, self.weak_type, self.sharding, + self.vma, self.memory_space)) def str_short(self, short_dtypes=False): - dt_str = \ - dtypes.short_dtype_name(self.dtype) if short_dtypes else self.dtype.name + dt_str = (dtypes.short_dtype_name(self.dtype) if short_dtypes else + self.dtype.name) dt_str = dt_str.replace("void", "float0") shapestr = ",".join(map(str, self.shape)) - if hasattr(self, "sharding"): - sharding_str = f"{dt_str}[{shapestr}]({self.sharding})" - else: - sharding_str = "" - memoryspace_str = ( - "" if self.memory_space is None else f"<{self.memory_space}>" - ) + sharding_str = (f"{dt_str}[{shapestr}]({self.sharding})" + if self.sharding else "") + memoryspace_str = ("" if self.memory_space is None + else f"<{self.memory_space}>") return f"{dt_str}{memoryspace_str}[{shapestr}]{sharding_str}" def update( @@ -126,6 +173,7 @@ def update( dtype=None, weak_type=None, sharding=None, + vma=None, memory_space=None, ): if shape is None: @@ -135,11 +183,14 @@ def update( if weak_type is None: weak_type = self.weak_type if sharding is None: - sharding = getattr(self, "sharding", None) + sharding = self.sharding + if vma is None: + vma = self.vma if memory_space is None: memory_space = self.memory_space return ShapedArrayWithMemorySpace( - shape, dtype, weak_type, sharding=sharding, memory_space=memory_space + shape, dtype, weak_type, sharding=sharding, vma=vma, + memory_space=memory_space ) mlir.ir_type_handlers[ShapedArrayWithMemorySpace] = mlir._array_ir_types @@ -148,7 +199,7 @@ def update( class MemoryRef: """Like jax.ShapeDtypeStruct but with memory spaces.""" shape: tuple[int, ...] 
- dtype: jnp.dtype + dtype: jnp.dtype | dtypes.ExtendedDType # TODO(b/368122763): Unify memory space types across backends memory_space: Any @@ -186,8 +237,10 @@ def __repr__(self) -> str: return f'MemRef<{self.memory_space}>{{{self.inner_aval.str_short()}}}' def update_weak_type(self, weak_type): - return AbstractMemoryRef( - self.inner_aval.update_weak_type(weak_type), self.memory_space) + return self.update(inner_aval=self.inner_aval.update_weak_type(weak_type)) + + def update_vma(self, vma): + return self.update(inner_aval=self.inner_aval.update_vma(vma)) def update(self, inner_aval=None, memory_space=None): inner_aval = self.inner_aval if inner_aval is None else inner_aval @@ -195,8 +248,7 @@ def update(self, inner_aval=None, memory_space=None): return AbstractMemoryRef(inner_aval, memory_space) def to_tangent_aval(self): - return AbstractMemoryRef( - self.inner_aval.to_tangent_aval(), self.memory_space) + return self.update(inner_aval=self.inner_aval.to_tangent_aval()) # TODO(dougalm, sharadmv): figure out how to avoid needing this def normalize(self): @@ -219,6 +271,7 @@ class MemorySpace(enum.Enum): ANY = "any" # Unrestricted memory space (usually HBM) ERROR = "error" # Memory space for checkify errors. INDEX = "index" # Memory space for scalar prefetch arguments. + KEY = "key" # Memory space for PRNG keys. def __str__(self) -> str: return self.value @@ -283,49 +336,158 @@ def current_grid_env() -> GridEnv | None: return _pallas_tracing_env.grid_env_stack[-1] -class Mapped: - """Used as a block shape dimension to denote a mapped dimension. - A mapped dimension behaves like `1` except it is squeezed from the block. - See :ref:`pallas_blockspec` for more details. - """ - def __repr__(self): - return "Mapped" -mapped = Mapped() +@dataclasses.dataclass(frozen=True) +class Element: + """Use to index an array using an elementwise start index.""" + block_size: int + padding: tuple[int, int] = (0, 0) + def __str__(self): + if self.padding == (0, 0): + return f"Element({self.block_size})" + return f"Element({self.block_size}, padding={self.padding})" @dataclasses.dataclass(frozen=True) -class Unblocked: - padding: tuple[tuple[int, int], ...] | None = None - - def __repr__(self): - return f"Unblocked(padding={self.padding})" -unblocked = Unblocked() +class Squeezed: + """Represents a one-sized block dimension that is squeezed out in the kernel.""" +squeezed = Squeezed() +@dataclasses.dataclass(frozen=True) class Blocked: + """The default BlockShape type.""" + block_size: int + + def __str__(self): + return f"Blocked({self.block_size})" + +@dataclasses.dataclass(frozen=True) +class BoundedSlice: + """Allows specifying a bounded slice of a dimension. + + Specifically, the index_map needs to return a `pl.Slice/pl.ds` for this + dimension. The start and size may be dynamic, as long as the size <= + block_size. + """ + block_size: int + def __repr__(self): - return "Blocked" -blocked = Blocked() + return f"BoundedSlice({self.block_size})" +BlockDim: TypeAlias = Element | Squeezed | Blocked | BoundedSlice -IndexingMode = Union[Blocked, Unblocked] def default_index_map(ndim: int) -> Callable: return lambda *args: (0,) * ndim + +def _canonicalize_block_dim(dim: BlockDim | int | None) -> BlockDim: + match dim: + case None: + return squeezed + case int(): + return Blocked(int(dim)) + case Squeezed() | Blocked() | Element() | BoundedSlice(): + return dim + case _: + # Handle case where the dim is a symbolic dimension so we assume it is + # Blocked.
+ if jax_core.is_symbolic_dim(dim): + return Blocked(dim) + try: + return Blocked(int(dim)) + except Exception as e: + raise ValueError( + f"Unsupported block dimension type: {type(dim)}. Allowed types:" + " `pl.Squeezed`, `pl.Blocked`, `pl.Element`, `int`, `None`." + ) from e + +def _canonicalize_block_shape(block_shape: Sequence[BlockDim | int | None] + ) -> tuple[BlockDim, ...]: + return tuple(_canonicalize_block_dim(dim) for dim in block_shape) + + +def _get_block_dim_size(dim: BlockDim) -> int: + match dim: + case Squeezed(): + return 1 + case Blocked(block_size): + return block_size + case Element(): + return dim.block_size + case BoundedSlice(block_size): + return block_size + case _: + raise ValueError(f"Unsupported block shape type: {type(dim)}") + + +def _get_block_shape(block_shape: tuple[BlockDim, ...]) -> tuple[int, ...]: + return tuple(_get_block_dim_size(dim) for dim in block_shape) + +def _get_ref_block_shape(block_shape: tuple[BlockDim, ...]) -> tuple[int, ...]: + # Special handling for squeezed here (don't include Squeezed dims in the Ref + # shape). + return tuple( + _get_block_dim_size(dim) + for dim in block_shape + if not isinstance(dim, Squeezed) + ) + + +class _IndexMapFunc: + """Helper class that checks for index_map equality.""" + + def __init__(self, index_map): + self.index_map = index_map + functools.update_wrapper(self, self.index_map) + + def __eq__(self, other: object): + if not isinstance(other, _IndexMapFunc): + return NotImplemented + return self.index_map == other.index_map + + def __call__(self, *args, **kwargs): + out_indices = self.index_map(*args, **kwargs) + if isinstance(out_indices, list): + out_indices = tuple(out_indices) + if not isinstance(out_indices, tuple): + out_indices = (out_indices,) + return out_indices + + @dataclasses.dataclass class BlockSpec: """Specifies how an array should be sliced for each invocation of a kernel. - See :ref:`pallas_blockspec` for more details. + The `block_shape` is a sequence of `int | None`s, or `BlockDim` types (e.g. + `pl.Element`, `pl.Squeezed`, `pl.Blocked`, `pl.BoundedSlice`). Each of these + types specifies the size of the block dimension. `None` is used to specify a + dimension that is squeezed out of the kernel. The `BlockDim` types allow for + more fine-grained control over the indexing of the dimension. The `index_map` + needs to return a tuple of the same length as `block_shape`, with each entry + depending on the type of `BlockDim`. + + See :ref:`pallas_blockspec` and the individual `BlockDim` type docstrings for + more details. """ # An internal canonicalized version is in BlockMapping. - block_shape: Sequence[int | None] | None = None + block_shape: Sequence[BlockDim | int | None] | None = None index_map: Callable[..., Any] | None = None memory_space: Any | None = dataclasses.field(kw_only=True, default=None) - indexing_mode: IndexingMode = dataclasses.field(kw_only=True, default=blocked) + indexing_mode: Any | None = None pipeline_mode: Buffered | None = None + def __post_init__(self): + # TODO(sharadmv): Remove this check. + if self.indexing_mode is not None: + raise ValueError( + "indexing_mode has been removed. Please pass in `pl.Element` for each" + " block dimension in `block_shape` instead to enable 'Unblocked'" + " indexing."
+ ) + if self.index_map is not None: + self.index_map = _IndexMapFunc(self.index_map) + def to_block_mapping( self, origin: OriginStr, @@ -336,6 +498,7 @@ def to_block_mapping( index_map_tree: tree_util.PyTreeDef, grid: GridMappingGrid, mapped_dims: tuple[int, ...], + debug: bool = False, ) -> BlockMapping: if self.index_map is None: index_map_func = default_index_map(len(array_aval.shape)) @@ -343,9 +506,9 @@ def to_block_mapping( else: index_map_func = self.index_map if self.block_shape is None: - block_shape = array_aval.shape + block_shape = _canonicalize_block_shape(array_aval.shape) else: - block_shape = self.block_shape # type: ignore + block_shape = _canonicalize_block_shape(self.block_shape) if len(array_aval.shape) != len(block_shape): raise ValueError( f"Block shape for {origin} (= {block_shape}) " @@ -353,15 +516,21 @@ def to_block_mapping( f"array shape {array_aval.shape}." ) - unmapped_block_shape = tuple(s for s in block_shape if s is not None) - block_array_aval = array_aval.update(shape=unmapped_block_shape) + ref_block_shape = _get_ref_block_shape(block_shape) if isinstance(array_aval, jax_core.DShapedArray): # Get the "max" shape for the ragged array. + block_array_aval = array_aval.update(shape=ref_block_shape) block_array_aval = jax_core.ShapedArray( block_array_aval.shape, block_array_aval.dtype, block_array_aval.weak_type, ) + elif isinstance(array_aval, ShapedArrayWithMemorySpace): + block_array_aval = jax_core.ShapedArray( + ref_block_shape, array_aval.dtype, array_aval.weak_type + ) + else: + block_array_aval = array_aval.update(shape=ref_block_shape) block_aval = AbstractMemoryRef(block_array_aval, self.memory_space) if ( @@ -376,50 +545,76 @@ def to_block_mapping( fake_index_map_args, fake_index_map_kwargs = \ index_map_tree.unflatten([False] * index_map_tree.num_leaves) - debug = api_util.debug_info("pallas_call index_map", - index_map_func, fake_index_map_args, - fake_index_map_kwargs) + debug_info = api_util.debug_info( + "pallas_call index_map", + index_map_func, + fake_index_map_args, + fake_index_map_kwargs, + ) flat_index_map_fun, index_map_out_tree_thunk = api_util.flatten_fun( - lu.wrap_init(index_map_func, debug_info=debug), index_map_tree) + lu.wrap_init(index_map_func, debug_info=debug_info), index_map_tree + ) with tracing_grid_env(grid, mapped_dims): jaxpr, out_avals, consts, () = pe.trace_to_jaxpr_dynamic( flat_index_map_fun, index_map_avals ) + index_map_out_tree = index_map_out_tree_thunk() + unflat_avals = tree_util.tree_unflatten(index_map_out_tree, out_avals) - mapped_block_shape = tuple(mapped if s is None else s for s in block_shape) - if len(out_avals) != len(block_shape): + if len(unflat_avals) != len(block_shape): raise ValueError( - f"Index map function {debug.func_src_info} for " + f"Index map function {debug_info.func_src_info} for " f"{origin} must return " f"{len(block_shape)} values to match {block_shape=}. " - f"Currently returning {len(out_avals)} values." 
+ f"Currently returning {len(unflat_avals)} values:" ) + # Verify types match + for i, (idx_aval, bd) in enumerate(zip(unflat_avals, block_shape)): + match bd: + case BoundedSlice(): + if not isinstance(idx_aval, indexing.Slice): + raise ValueError( + "index_map returned a value of type" + f" {type(idx_aval)} at position {i} with block dimension" + f" {bd} when it should be pl.Slice" + ) + case Blocked() | Element() | Squeezed() | int(): + if ( + not isinstance(idx_aval, jax_core.ShapedArray) + and not idx_aval.shape + ): + raise ValueError( + "index_map returned a value of type" + f" {type(idx_aval)} at position {i} with block dimension" + f" {bd} when it should be a scalar" + ) for i, ov in enumerate(out_avals): if ov.shape or ov.dtype not in [jnp.int32, jnp.int64]: raise ValueError( - f"Index map function {debug.func_src_info} for " + f"Index map function {debug_info.func_src_info} for " f"{origin} must return integer scalars. Output[{i}] has type " f"{ov}." ) if consts: raise ValueError( - f"Index map function {debug.func_src_info} for " + f"Index map function {debug_info.func_src_info} for " f"{origin} must not capture constants: {consts}" ) array_aval_shape = _max_shape_from_aval(array_aval) mapping = BlockMapping( - block_shape=mapped_block_shape, + block_shape=block_shape, transformed_block_aval=block_aval, # There are no transforms by default index_map_jaxpr=jax_core.ClosedJaxpr(jaxpr, consts), - indexing_mode=self.indexing_mode, + index_map_out_tree=index_map_out_tree, array_shape_dtype=jax.ShapeDtypeStruct( array_aval_shape, array_aval.dtype ), origin=origin, pipeline_mode=self.pipeline_mode, + debug=debug, ) mapping.check_invariants() return mapping @@ -453,30 +648,27 @@ class BlockMapping: """ # TODO(apaszke,sharadmv): Replace mapped dims in block_shape with a transform. # After all, it's just indexing out singleton dimensions. - block_shape: tuple[Mapped | int, ...] + block_shape: tuple[BlockDim, ...] transformed_block_aval: AbstractMemoryRef index_map_jaxpr: jax_core.ClosedJaxpr - indexing_mode: IndexingMode + index_map_out_tree: tree_util.PyTreeDef array_shape_dtype: jax.ShapeDtypeStruct # The whole array origin: OriginStr transforms: Sequence[MemoryRefTransform] = () pipeline_mode: Buffered | None = None + debug: bool = False def check_invariants(self) -> None: if not config.enable_checks.value: return - unmapped_block_shape = tuple(s for s in self.block_shape if s is not mapped) - assert unmapped_block_shape == self.ref_aval.shape, ( + ref_block_shape = _get_ref_block_shape(self.block_shape) + assert ref_block_shape == self.ref_aval.shape, ( self.block_shape, self.ref_aval.shape) assert len(self.block_shape) == len(self.array_shape_dtype.shape), ( self.block_shape, self.array_shape_dtype ) assert not self.index_map_jaxpr.consts - assert len(self.block_shape) == len(self.index_map_jaxpr.out_avals), ( - self.block_shape, - self.index_map_jaxpr.out_avals, - ) assert all(ov.shape == () and (ov.dtype == jnp.int32 or ov.dtype == jnp.int64) for ov in self.index_map_jaxpr.out_avals), ( @@ -514,24 +706,46 @@ def compute_start_indices_interpret(self, loop_idx, *args): # updated values since we only care about the return values. 
block_indices, _ = split_list(block_indices_and_rest, [len(self.block_shape)]) - if isinstance(self.indexing_mode, Blocked): - return tuple(i if b is mapped else b * i - for b, i in zip(self.block_shape, block_indices)) - elif isinstance(self.indexing_mode, Unblocked): - return block_indices - else: - raise RuntimeError(f"Unknown indexing mode: {self.indexing_mode}") + def _get_start_index(i, b): + match b: + case Squeezed() | Element(): + return i + case Blocked(block_size): + return block_size * i + case _: + raise ValueError(f"Unsupported block dim type: {type(b)}") + return tuple( + _get_start_index(i, b) for i, b in zip(block_indices, self.block_shape) + ) def has_trivial_window(self): """If block shape is same as the array shape and index_map returns 0s.""" for b, s in zip(self.block_shape, self.array_shape_dtype.shape): - if b != s and not (b is mapped and s == 1): + if _get_block_dim_size(b) != s: return False for atom in self.index_map_jaxpr.jaxpr.outvars: if not (isinstance(atom, jax_core.Literal) and atom.val == 0): return False return True + def __repr__(self): + if self.debug: + return ( + f"BlockMapping(block_shape={self.block_shape}, " + f"transformed_block_aval={self.transformed_block_aval}, " + f"index_map_jaxpr={self.index_map_jaxpr}, " + f"index_map_out_tree={self.index_map_out_tree}, " + f"array_shape_dtype={self.array_shape_dtype}, " + f"origin={self.origin}, " + f"transforms={self.transforms}, " + f"pipeline_mode={self.pipeline_mode}, " + f"debug={self.debug})" + ) + return f"BlockMapping(block_shape={self.block_shape})" + + def __str__(self): + return self.__repr__() + @contextlib.contextmanager def tracing_grid_env(grid: GridMappingGrid, mapped_dims: tuple[int, ...]): @@ -596,6 +810,8 @@ class GridMapping: num_scratch_operands: int get_grid_indices: Callable | None = None local_grid_env: Callable | None = None + # Primarily dictates how much debugging information is printed. + debug: bool = False def check_invariants(self) -> None: if not config.enable_checks.value: return @@ -719,6 +935,29 @@ def out_shapes(self) -> Iterable[jax.ShapeDtypeStruct]: return tuple( bm.array_shape_dtype for bm in self.block_mappings_output) + def __repr__(self): + if self.debug: + return ( + f"GridMapping(grid={self.grid}, grid_names={self.grid_names}, " + f"block_mappings={self.block_mappings}, " + f"index_map_tree={self.index_map_tree}, " + f"index_map_avals={self.index_map_avals}, " + f"vmapped_dims={self.vmapped_dims}, " + f"num_index_operands={self.num_index_operands}, " + f"num_inputs={self.num_inputs}, " + f"num_outputs={self.num_outputs}, " + f"num_scratch_operands={self.num_scratch_operands}, " + f"get_grid_indices={self.get_grid_indices}, " + f"local_grid_env={self.local_grid_env}, " + f"debug={self.debug})" + ) + return ( + f"GridMapping(grid={self.grid}, block_mappings={self.block_mappings})" + ) + + def __str__(self): + return self.__repr__() + def _is_valid_grid_dim(dim: int | jax.Array) -> bool: if isinstance(dim, jax.Array): @@ -754,6 +993,7 @@ def _convert_block_spec_to_block_mapping( index_map_tree: tree_util.PyTreeDef, grid: GridMappingGrid, mapped_dims: tuple[int, ...], + debug: bool = False, ) -> BlockMapping: if block_spec is no_block_spec: block_spec = BlockSpec(None, None) @@ -764,15 +1004,17 @@ def _convert_block_spec_to_block_mapping( index_map_tree=index_map_tree, grid=grid, mapped_dims=mapped_dims, + debug=debug, ) + index_map_grid_aval = jax_core.ShapedArray((), jnp.int32) class ScratchShape(Protocol): def get_array_aval(self) -> jax_core.AbstractValue: ... 
- def get_ref_aval(self) -> state.AbstractRef: + def get_ref_aval(self) -> state.AbstractRef | TransformedRef: ... @@ -839,8 +1081,8 @@ def get_grid_mapping( out_avals: Sequence[jax_core.AbstractValue], out_tree: tree_util.PyTreeDef, out_origins: Sequence[OriginStr], -) -> tuple[tuple[jax_core.AbstractValue, ...], - GridMapping]: + debug: bool = False, +) -> tuple[tuple[jax_core.AbstractValue, ...], GridMapping]: if dynamic_shapes_export_enabled(): dim_check : Any = jax_core.is_dim else: @@ -895,7 +1137,7 @@ def get_grid_mapping( if in_specs_tree != in_tree: raise ValueError( pytreedef_mismatch_err_msg("`in_specs`", in_specs_tree, - "inputs", in_tree)) + "`inputs`", in_tree)) else: flat_in_specs = [no_block_spec] * len(in_avals) @@ -906,6 +1148,7 @@ def get_grid_mapping( index_map_tree=index_map_tree, grid=grid_mapping_grid, # type: ignore[arg-type] mapped_dims=(), + debug=debug, ), flat_in_specs, in_origins[num_flat_scalar_prefetch:], @@ -928,6 +1171,7 @@ def get_grid_mapping( index_map_tree=index_map_tree, grid=grid_mapping_grid, # type: ignore[arg-type] mapped_dims=(), + debug=debug, ), flat_out_specs, out_origins, @@ -944,6 +1188,7 @@ def get_grid_mapping( num_inputs=len(flat_in_specs), num_outputs=len(flat_out_specs), num_scratch_operands=num_flat_scratch_operands, + debug=debug, ) grid_mapping.check_invariants() in_ref_avals = [bm.ref_aval for bm in in_block_mappings] @@ -1040,7 +1285,10 @@ def wrapped(f): debug_info=api_util.debug_info("pallas_core_map", f, (), {})), in_tree) - with jax_core.extend_axis_env_nd(mesh.shape.items()): + with ( + tracing_grid_env(tuple(mesh.shape.values()), mapped_dims=()), + jax_core.extend_axis_env_nd(mesh.shape.items()), + ): jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, flat_args) out = core_map_p.bind(*consts, jaxpr=jaxpr, mesh=mesh, compiler_params=compiler_params, @@ -1054,11 +1302,19 @@ def wrapped(f): @core_map_p.def_effectful_abstract_eval -def _core_map_abstract_eval(*args, jaxpr, mesh, **_): +def _core_map_abstract_eval(*args, jaxpr, mesh, **kwargs): del args if jaxpr.outvars: raise ValueError("core_map must not return any outputs.") + interpret = kwargs.get('interpret', False) effs = set() + if interpret: + try: + from jax._src.pallas.mosaic import interpret as mosaic_tpu_interpret # Avoid circular dependency. + if isinstance(interpret, mosaic_tpu_interpret.InterpretParams): + effs = mosaic_tpu_interpret.get_interpret_effects() + except ImportError: + pass for eff in jaxpr.effects: if mesh.discharges_effect(eff): continue @@ -1095,6 +1351,7 @@ def default_mesh_discharge_rule( interpret, cost_estimate, name, + memory_space=MemorySpace.ANY, ): """Discharges a ``core_map`` over a mesh to a ``pallas_call``.""" del out_avals # Unused. @@ -1111,13 +1368,9 @@ def body(*args): for eff in jaxpr.effects if isinstance(eff, state_types.WriteEffect) ) - any_spec = BlockSpec(memory_space=MemorySpace.ANY) - grid_spec = GridSpec( - grid=tuple(mesh.shape.items()), - in_specs=[any_spec] * len(in_avals), - out_specs=[any_spec] * len(modified_idxs), - ) + spec = BlockSpec(memory_space=memory_space) from jax._src.pallas import pallas_call # Avoid circular dependency. 
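# A rough, hedged sketch of the pattern above using the public Pallas API: one
# BlockSpec, whose memory_space is now a parameter of the discharge rule, is
# reused for every input and every aliased output. The grid, counts, and
# memory_space value below are placeholders, not taken from this change.
import jax.experimental.pallas as pl

def make_grid_spec(grid, num_inputs, num_outputs, memory_space=None):
    spec = pl.BlockSpec(memory_space=memory_space)
    return pl.GridSpec(
        grid=grid,
        in_specs=[spec] * num_inputs,
        out_specs=[spec] * num_outputs,
    )

grid_spec = make_grid_spec(grid=(4, 2), num_inputs=3, num_outputs=1)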
+ outs = pallas_call._pallas_call( body, name=name, @@ -1125,7 +1378,11 @@ def body(*args): input_output_aliases={ in_idx: out_idx for out_idx, in_idx in enumerate(modified_idxs) }, - grid_spec=grid_spec, + grid_spec=GridSpec( + grid=tuple(mesh.shape.items()), + in_specs=[spec] * len(in_avals), + out_specs=[spec] * len(modified_idxs), + ), mesh=mesh, compiler_params=compiler_params, interpret=interpret, @@ -1149,10 +1406,18 @@ def _core_map_discharge_rule(in_avals, out_avals, *args_flat, jaxpr, mesh, **kwa def _core_map_typecheck_rule(_, *in_atoms, jaxpr, mesh, **kwargs): - del in_atoms, kwargs + del in_atoms with jax_core.extend_axis_env_nd(tuple(mesh.shape.items())): jax_core.check_jaxpr(jaxpr) + interpret = kwargs.get('interpret', False) effs = set() + if interpret: + try: + from jax._src.pallas.mosaic import interpret as mosaic_tpu_interpret # Avoid circular dependency. + if isinstance(interpret, mosaic_tpu_interpret.InterpretParams): + effs = mosaic_tpu_interpret.get_interpret_effects() + except ImportError: + pass for eff in jaxpr.effects: if mesh.discharges_effect(eff): continue diff --git a/jax/_src/pallas/cost_estimate.py b/jax/_src/pallas/cost_estimate.py index 73db4a2e2d4a..ad238bdf475d 100644 --- a/jax/_src/pallas/cost_estimate.py +++ b/jax/_src/pallas/cost_estimate.py @@ -15,7 +15,8 @@ import dataclasses import functools import math -from typing import Any, Sequence +from typing import Any +from collections.abc import Sequence import jax from jax._src import api_util @@ -64,12 +65,11 @@ def cost_estimate_jaxpr( total_cost = CostEstimate(flops=0, transcendentals=0, bytes_accessed=0) for eqn in jaxpr.eqns: - _, bind_params = eqn.primitive.get_bind_params(eqn.params) rule = _cost_rules.get(eqn.primitive, None) if rule is not None: context = Context(avals_in=[v.aval for v in eqn.invars], avals_out=[v.aval for v in eqn.outvars]) - op_cost = rule(context, **bind_params) + op_cost = rule(context, **eqn.params) total_cost = total_cost + op_cost return pallas_core.CostEstimate( flops=total_cost.flops, @@ -239,15 +239,15 @@ def _pjit_cost_rule(ctx, *, jaxpr: jax_core.ClosedJaxpr, **_): ) register_cost_rule(pjit.pjit_p, _pjit_cost_rule) -def _custom_vjp_rule(ctx, *, fun_jaxpr: jax_core.ClosedJaxpr, **_): +def _custom_vjp_rule(ctx, *, call_jaxpr: jax_core.ClosedJaxpr, **_): del ctx - inner_cost = cost_estimate_jaxpr(fun_jaxpr) + inner_cost = cost_estimate_jaxpr(call_jaxpr) return CostEstimate( flops=inner_cost.flops, transcendentals=inner_cost.transcendentals, bytes_accessed=inner_cost.bytes_accessed, ) -register_cost_rule(custom_derivatives.custom_vjp_call_jaxpr_p, _custom_vjp_rule) +register_cost_rule(custom_derivatives.custom_vjp_call_p, _custom_vjp_rule) def _run_state_rule(*_, jaxpr: jax_core.Jaxpr, **_2): inner_cost = cost_estimate_jaxpr(pe.close_jaxpr(jaxpr)) diff --git a/jax/_src/pallas/fuser/BUILD b/jax/_src/pallas/fuser/BUILD index 66bbac33aabb..951c08d8f4fa 100644 --- a/jax/_src/pallas/fuser/BUILD +++ b/jax/_src/pallas/fuser/BUILD @@ -33,7 +33,7 @@ pytype_strict_library( deps = [ ":block_spec", ":custom_evaluate", - ":fusable", + ":fusible", ":fusion", ":jaxpr_fusion", ], @@ -48,8 +48,10 @@ pytype_strict_library( ":fuser_utils", "//jax", "//jax:ad_util", + "//jax:api", "//jax:api_util", "//jax:core", + "//jax:custom_derivatives", "//jax:partial_eval", "//jax:tree_util", "//jax:util", @@ -58,9 +60,9 @@ pytype_strict_library( ) pytype_strict_library( - name = "fusable", + name = "fusible", srcs = [ - "fusable.py", + "fusible.py", ], deps = [ ":fusion", @@ -91,28 +93,30 @@ 
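# The cost-estimate change above relies on eqn.params already holding each
# equation's bound parameters, so get_bind_params is no longer needed. A small
# standalone illustration using the public tracing API (the function and shapes
# below are made up):
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) + 2.0 * x

closed_jaxpr = jax.make_jaxpr(f)(jnp.ones((4,), jnp.float32))
for eqn in closed_jaxpr.jaxpr.eqns:
    # eqn.params is a dict of the primitive's parameters for this equation.
    print(eqn.primitive.name, dict(eqn.params))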
pytype_strict_library( "jaxpr_fusion.py", ], deps = [ - ":fusable", - ":fusable_dtype", + ":fusible", + ":fusible_dtype", ":fusion", "//jax", "//jax:api_util", "//jax:core", "//jax:partial_eval", "//jax:tree_util", + "//jax:util", ], ) pytype_strict_library( - name = "fusable_dtype", + name = "fusible_dtype", srcs = [ - "fusable_dtype.py", + "fusible_dtype.py", ], deps = [ ":block_spec", - ":fusable", + ":fusible", "//jax", "//jax:api_util", "//jax:core", + "//jax:custom_derivatives", "//jax:dtypes", "//jax:partial_eval", "//jax:source_info_util", diff --git a/jax/_src/pallas/fuser/__init__.py b/jax/_src/pallas/fuser/__init__.py index 3295c8f1061a..39720100eb1d 100644 --- a/jax/_src/pallas/fuser/__init__.py +++ b/jax/_src/pallas/fuser/__init__.py @@ -17,6 +17,6 @@ from jax._src.pallas.fuser.block_spec import pull_block_spec as pull_block_spec from jax._src.pallas.fuser.block_spec import push_block_spec as push_block_spec from jax._src.pallas.fuser.custom_evaluate import evaluate as evaluate -from jax._src.pallas.fuser.fusable import fusable as fusable +from jax._src.pallas.fuser.fusible import fusible as fusible from jax._src.pallas.fuser.fusion import Fusion as Fusion from jax._src.pallas.fuser.jaxpr_fusion import fuse as fuse diff --git a/jax/_src/pallas/fuser/block_spec.py b/jax/_src/pallas/fuser/block_spec.py index de0cdd204f3c..e6ca4dddc61b 100644 --- a/jax/_src/pallas/fuser/block_spec.py +++ b/jax/_src/pallas/fuser/block_spec.py @@ -21,7 +21,8 @@ import enum import functools import threading -from typing import Any, Callable, Protocol, Sequence +from typing import Any, Protocol +from collections.abc import Callable, Sequence import jax from jax import lax @@ -29,11 +30,15 @@ from jax._src import core from jax._src import custom_derivatives from jax._src import pjit +from jax._src import prng +from jax._src import state from jax._src import tree_util from jax._src import util from jax._src.interpreters import partial_eval as pe from jax._src.pallas import core as pallas_core from jax._src.pallas.fuser import fuser_utils +from jax._src.state import indexing +from jax._src.state import primitives as state_primitives import jax.numpy as jnp import numpy as np @@ -94,6 +99,20 @@ def wrapped(*args): return wrapped +def _block_size(dim: pallas_core.Element | int | None) -> int | None: + match dim: + case ( + pallas_core.Element() + | pallas_core.BoundedSlice() + | pallas_core.Blocked() + ): + return dim.block_size + case pallas_core.Squeezed() | None: + return None + case _: + return dim # pytype: disable=bad-return-type + + @dataclasses.dataclass class UsageRuleContext: avals_in: tuple[core.AbstractValue, ...] @@ -170,8 +189,14 @@ def get_out_block_indices(self): _illegal = object() -_sp_env = threading.local() -_sp_env.scalar_prefetch = None + +class _SpEnv(threading.local): + + def __init__(self): + self.scalar_prefetch = None + + +_sp_env = _SpEnv() @contextlib.contextmanager @@ -192,7 +217,7 @@ def _wrap_block_spec_scalar_prefetch( block_spec: pallas_core.BlockSpec, num_grid_args: int, ) -> pallas_core.BlockSpec: - if block_spec is pallas_core.no_block_spec: + if block_spec is pallas_core.no_block_spec or block_spec.index_map is None: return block_spec def new_index_map(*args_and_scalar_prefetch): @@ -236,9 +261,7 @@ def wrapped(*args, **kwargs): jaxpr, consts, in_tree, out_tree_ = fuser_utils.make_jaxpr( f, *args, **kwargs ) - # TODO(sharadmv): handle these consts better, they should correspond to - # scalar prefetch. 
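# The _SpEnv change above moves the thread-local default into __init__ so that
# every thread, not just the importing one, starts with scalar_prefetch=None.
# A generic, self-contained version of the pattern:
import threading

class _ThreadLocalEnv(threading.local):
    def __init__(self):
        # Runs once per thread, so each thread sees its own default value.
        self.value = None

_env = _ThreadLocalEnv()
assert _env.value is None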
- del consts, out_tree_ + del out_tree_ jaxpr_out_usages = [{Usage.REGULAR}] * len(jaxpr.outvars) block_specs_ = jax.tree.map( _unwrap_block_spec_scalar_prefetch, out_block_specs @@ -251,15 +274,17 @@ def wrapped(*args, **kwargs): ) assert all(used_invars) assert all(used_consts) + read_usage_env = compute_usage(jaxpr, jaxpr_out_usages) in_block_specs, env, read_usage_env = _pull_block_spec( jaxpr, tuple(flat_block_specs), - jaxpr_out_usages, scalar_prefetch_handler=scalar_prefetch_handler, + read_usage_env=read_usage_env, grid=grid, ) kernel_fn = make_kernel_function( jaxpr, + consts, in_tree, out_tree, read_usage_env, @@ -285,8 +310,8 @@ def wrapped(*args, **kwargs): def _pull_block_spec( jaxpr: core.Jaxpr, out_block_specs: tuple[pallas_core.BlockSpec, ...], - out_usages, *, + read_usage_env: Callable[[core.Var], set[Usage]], scalar_prefetch_handler: Any | None = None, grid: tuple[int | jax.Array, ...], ) -> tuple[ @@ -294,7 +319,6 @@ def _pull_block_spec( tuple[dict[core.Var, pallas_core.BlockSpec], dict[int, Any]], Any, ]: - read_usage_env = compute_usage(jaxpr, out_usages) jaxpr_invar_usages = util.safe_map(read_usage_env, jaxpr.invars) env: dict[core.Var, pallas_core.BlockSpec] = {} scalar_prefetch_fn_env = {} @@ -306,7 +330,7 @@ def _pull_block_spec( def _read_block_spec(atom: core.Atom) -> pallas_core.BlockSpec | Any: if isinstance(atom, core.Literal): return pallas_core.no_block_spec - return env[atom] + return env.get(atom, pallas_core.no_block_spec) def _write_block_spec(atom: core.Atom, block_spec: pallas_core.BlockSpec): if isinstance(atom, core.Literal): @@ -315,9 +339,11 @@ def _write_block_spec(atom: core.Atom, block_spec: pallas_core.BlockSpec): for i, eqn in reversed(list(enumerate(jaxpr.eqns))): eqn_out_block_specs = tuple(util.safe_map(_read_block_spec, eqn.outvars)) + if all(bs is pallas_core.no_block_spec for bs in eqn_out_block_specs): + continue rule = pull_block_spec_rules.get(eqn.primitive, None) if not rule: - raise NotImplementedError(eqn.primitive) + raise NotImplementedError(eqn.primitive, eqn_out_block_specs) ctx = PullRuleContext( avals_in=tuple(v.aval for v in eqn.invars), avals_out=tuple(v.aval for v in eqn.outvars), @@ -405,6 +431,7 @@ def _get_in_block_spec(v, usage): def make_kernel_function( jaxpr: core.Jaxpr, + consts, in_tree, out_tree, read_usage_env, @@ -417,15 +444,22 @@ def make_kernel_function( invar_usages = util.safe_map(read_usage_env, jaxpr.invars) bs_env, scalar_prefetch_fn_env = block_spec_env - def _remove_nones(shape: tuple[int | None, ...] | None) -> tuple[int, ...]: + def _remove_nones( + shape: tuple[pallas_core.BlockDim | int | None, ...] 
| None, + ) -> tuple[int, ...]: assert shape is not None - return tuple(s for s in shape if s is not None) + new_shape = tuple(_block_size(s) for s in shape) + return tuple(s for s in new_shape if s is not None) _no_aval = object() def _get_block_aval(bs, aval): + if isinstance(aval, state.AbstractRef): + return aval if bs is pallas_core.no_block_spec or bs is None: return _no_aval + if bs.block_shape is None: + return aval return aval.update(shape=_remove_nones(bs.block_shape)) # pytype: disable=attribute-error in_block_avals = [ @@ -451,7 +485,7 @@ def sds_like(x): def _read_block_spec(atom: core.Atom) -> pallas_core.BlockSpec | Any: if isinstance(atom, core.Literal): return pallas_core.no_block_spec - return bs_env[atom] + return bs_env.get(atom, pallas_core.no_block_spec) def kernel_fn(program_ids, scalar_prefetch, *args, **kwargs): def _check_args(prefix, path, x, y, usage): @@ -502,6 +536,8 @@ def read_env(atom): def write_env(var, val): env[var] = val + for const, constvar in zip(consts, jaxpr.constvars): + env[constvar] = const for invar, arg, usage in zip(jaxpr.invars, flat_args, invar_usages): if Usage.REGULAR in usage: env[invar] = arg @@ -714,7 +750,14 @@ def new_index_map(i, *args): idx = util.tuple_update(idx, i, 0) return idx - new_block_shape = util.tuple_update(block_spec.block_shape, i, 1) + # TODO(wdvi): This is a hack needed since lowering rules require block shape + # to contain either all pl.Element or none + bcast_dim_block_shape = 1 + if isinstance(block_spec.block_shape[i], pallas_core.Element): + bcast_dim_block_shape = pallas_core.Element(1) + new_block_shape = util.tuple_update( # pytype: disable=wrong-arg-types + block_spec.block_shape, i, bcast_dim_block_shape + ) return pallas_core.BlockSpec( new_block_shape, functools.partial(new_index_map, i) ) @@ -768,12 +811,22 @@ def _eval_function(_, x, y): return [l_block_spec, r_block_spec] +def register_default_eval_rule(prim: core.Primitive): + def default_rule(ctx, *args, **params): + assert all(bs is pallas_core.no_block_spec for bs in ctx.out_block_specs) + return prim.bind(*args, **params) + + register_eval_rule(prim)(default_rule) + + def register_binop_rule(prim: core.Primitive): register_pull_block_spec_rule(prim)(functools.partial(_binop_pull_rule, prim)) register_usage_rule(prim)(functools.partial(_binop_usage_rule, prim)) register_eval_rule(prim)(functools.partial(_binop_eval_rule, prim)) +register_default_eval_rule(state_primitives.get_p) + register_binop_rule(lax.mul_p) register_binop_rule(lax.add_p) register_binop_rule(lax.sub_p) @@ -784,7 +837,10 @@ def register_binop_rule(prim: core.Primitive): register_binop_rule(lax.eq_p) register_binop_rule(lax.gt_p) register_binop_rule(lax.ge_p) +register_binop_rule(lax.or_p) +register_binop_rule(lax.xor_p) register_binop_rule(lax.and_p) +register_binop_rule(lax.shift_right_logical_p) register_binop_rule(ad_util.add_any_p) @@ -839,10 +895,74 @@ def new_index_map(*args): def _slice_eval_rule(ctx, x, **params): del params out_block_shape = ctx.out_block_specs[0].block_shape - assert len(x.shape) == sum(1 for bs in out_block_shape if bs is not None) + assert len(x.shape) == sum( + 1 + for bs in out_block_shape + if not (bs is None or isinstance(bs, pallas_core.Squeezed)) + ) return x +def _offset_indexer( + bs: pallas_core.BlockDim | int | None, + indexer, + slice_start, + slice_size, +): + # Short-circuit if the slice start is just at zero. 
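# Shape-only stand-in for _remove_nones/_get_block_aval above: the in-kernel
# block aval drops squeezed (None) dimensions after normalizing block dims to
# their sizes. getattr(..., 'block_size', ...) is a simplified substitute for
# _block_size, used only for illustration.
def kernel_block_shape(block_shape):
    sizes = tuple(getattr(s, 'block_size', s) for s in block_shape)
    return tuple(s for s in sizes if s is not None)

assert kernel_block_shape((None, 8, 128)) == (8, 128)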
+ if isinstance(slice_start, int) and slice_start == 0: + return indexer + match bs: + case None | pallas_core.Squeezed(): + return indexer + slice_start + case pallas_core.Element(block_size): + _maybe_static_check( + slice_start % block_size == 0, + f'slice_start is not a multiple of block_size {block_size}', + ) + _maybe_static_check( + slice_size % block_size == 0, + f'slice_size is not a multiple of block_size {block_size}', + ) + return indexer + slice_start + case int() | pallas_core.Blocked(): + block_size = _block_size(bs) + _maybe_static_check( + slice_start % block_size == 0, + f'slice_start is not a multiple of block_size {block_size}', + ) + _maybe_static_check( + slice_size % block_size == 0, + f'slice_size is not a multiple of block_size {block_size}', + ) + # indexer is a block index so we need to offset it by the block offset. + return indexer + slice_start // block_size + case pallas_core.BoundedSlice(block_size): + assert isinstance(indexer, indexing.Slice) + _maybe_static_check( + indexer.start % block_size == 0, + f'slice_start is not a multiple of block_size {block_size}', + ) + _maybe_static_check( + indexer.size % block_size == 0, + f'slice_size is not a multiple of block_size {block_size}', + ) + return indexing.ds(indexer.start + slice_start, indexer.size) + case _: + raise ValueError(f'Unsupported block size {bs}') + + +def _maybe_static_check(pred: bool, msg: str): + # Tries to emit a static error if possible, otherwise falls back to runtime. + from jax.experimental import checkify + + if isinstance(pred, jax.Array): + checkify.check(pred, msg, debug=True) + else: + if not pred: + raise ValueError(msg) + + @register_pull_block_spec_rule(lax.slice_p) def _slice_rule( ctx: PullRuleContext, @@ -858,25 +978,42 @@ def _slice_rule( slice_sizes = tuple( int(end - start) for start, end in zip(start_indices, limit_indices) ) + # Do some basic checks for bs, slice_start, slice_size in zip( block_spec.block_shape, start_indices, slice_sizes ): - if bs is None: - continue - assert slice_start % bs == 0, (start_indices, block_spec.block_shape) - assert slice_size % bs == 0, (slice_sizes, block_spec.block_shape) - offsets = tuple( - slice_start // bs if bs is not None else slice_start - for slice_start, bs in zip(start_indices, block_spec.block_shape) - ) - - def _offset(x, i): - return x + i if i != 0 else x + match bs: + case None | pallas_core.Squeezed(): + continue + case pallas_core.BoundedSlice() | pallas_core.Element(): + block_size = _block_size(bs) + # Require that block_size no bigger than the slice. 
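# Worked example of the offset arithmetic in _offset_indexer: for Blocked dims
# the index map produces *block* indices, so a slice offset must be divided by
# the block size before being added; for Element (or squeezed) dims the offset
# is added to the element index directly. Plain integers stand in for traced
# index values.
block_size = 128
slice_start = 512            # must be a multiple of block_size

block_index = 3              # what the index map returned for a Blocked dim
assert block_index + slice_start // block_size == 7

element_index = 3            # what the index map returned for an Element dim
assert element_index + slice_start == 515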
+ if block_size > slice_size: + raise ValueError( + f'Block size {block_size} is larger than the slice size' + f' {slice_size}' + ) + case _: + block_size = _block_size(bs) + assert slice_start % block_size == 0, ( + start_indices, + block_spec.block_shape, + ) + assert slice_size % block_size == 0, ( + slice_sizes, + block_spec.block_shape, + ) def new_index_map(*args): idx = block_spec.index_map(*args) assert len(idx) == len(block_spec.block_shape) - return tuple(_offset(i, o) for i, o in zip(idx, offsets)) + idx = tuple( + _offset_indexer(bs, i, start, size) + for bs, i, start, size in zip( + block_spec.block_shape, idx, start_indices, slice_sizes, strict=True + ) + ) + return idx return [pallas_core.BlockSpec(block_spec.block_shape, new_index_map)] @@ -893,20 +1030,6 @@ def _dynamic_slice_usage_rule(ctx, used_out: set[Usage], **params): return [set()] * len(ctx.avals_in) -def _offset(x, i, s): - from jax.experimental import checkify - - if s is not None: - pred = i % s == 0 - if isinstance(pred, jax.Array): - checkify.check(i % s == 0, 'Invalid index', debug=True) - else: - if not pred: - raise ValueError('Invalid index') - offset = jax.lax.div(i, s) if s is not None else i - return x + offset - - @register_eval_rule(lax.dynamic_slice_p) def _dynamic_slice_eval_rule(ctx, x, *args, **params): del ctx, params @@ -920,7 +1043,6 @@ def _dynamic_slice_rule( *, slice_sizes: tuple[int, ...], ): - del slice_sizes def new_index_map(*args): slice_starts = ctx.scalar_prefetch_fn() @@ -942,11 +1064,11 @@ def new_index_map(*args): # multiples of the block sizes. The indices of the block that correspond to # the slice are then given by (i // b_l, j // b_m, k // b_n). # We then add these block indices to block indices produced by the index - # map. + # map block_indices = tuple( - _offset(i, o, s) - for i, o, s in zip( - idx, slice_starts, block_spec.block_shape, strict=True + _offset_indexer(s, i, start, size) + for i, s, start, size in zip( + idx, block_spec.block_shape, slice_starts, slice_sizes, strict=True ) ) return block_indices @@ -957,12 +1079,175 @@ def new_index_map(*args): ) +@register_pull_block_spec_rule(state_primitives.swap_p) +def _swap_pull_rule( + ctx: PullRuleContext, + block_spec: pallas_core.BlockSpec, + **kwargs, +): + del ctx, kwargs + # The output and val block spec are the same. + return [block_spec, block_spec] + + +@register_eval_rule(state_primitives.swap_p) +def _swap_eval_rule(ctx: KernelEvalContext, ref, val, *idx, tree): + indexers = tree_util.tree_unflatten(tree, idx) + ref_aval, _ = ctx.avals_in[:2] + indexers_avals = tree_util.tree_unflatten(tree, ctx.avals_in[2:]) + assert hasattr(ref_aval, 'shape') + if len(indexers) > 1: + raise NotImplementedError('swap not supported yet') + indexer_aval = indexers_avals[0] + for idx_aval, size in zip(indexer_aval.indices, ref_aval.shape, strict=True): + if not isinstance(idx_aval, indexing.Slice): + raise NotImplementedError('swap not supported yet') + if not isinstance(idx_aval.start, int): + raise NotImplementedError('swap not supported yet') + if not isinstance(idx_aval.size, int): + raise NotImplementedError('swap not supported yet') + if idx_aval.stride != 1: + raise NotImplementedError('swap not supported yet') + if idx_aval.start != 0: + raise NotImplementedError('swap not supported yet') + if idx_aval.size != size: + raise NotImplementedError('swap not supported yet') + # We have a pure slice so now we can just re-index the ref according to the + # block indices. 
+ block_spec = ctx.out_block_specs[0] + block_idx = ctx.get_out_block_indices()[0] + + def _slice(i, b): + if not isinstance(b, int): + raise NotImplementedError('swap not supported yet') + return i if b is None else indexing.ds(i * b, b) + + indexer = tuple( + _slice(i, b) + for i, b in zip(block_idx, block_spec.block_shape, strict=True) + ) + return ref.swap(val, idx=indexer) + + +@register_pull_block_spec_rule(state_primitives.get_p) +def _get_pull_rule( + ctx: PullRuleContext, block_spec: pallas_core.BlockSpec, *, tree +): + ref_aval = ctx.avals_in[0] + assert hasattr(ref_aval, 'shape') + indexers_avals = tree_util.tree_unflatten(tree, ctx.avals_in[1:]) + if len(indexers_avals) > 1: + raise NotImplementedError('get not supported yet') + indexer_aval = indexers_avals[0] + block_shape_iter = iter(block_spec.block_shape) + block_shape = [] + if not all( + isinstance(bd, (int, pallas_core.Blocked, pallas_core.Squeezed, None)) + for bd in block_spec.block_shape + ): + raise NotImplementedError('get not supported yet') + for idx_aval, size in zip(indexer_aval.indices, ref_aval.shape, strict=True): + if not isinstance(idx_aval, indexing.Slice): + assert hasattr(idx_aval, 'shape') and not idx_aval.shape + block_shape.append(pallas_core.Squeezed()) + continue + if not isinstance(idx_aval.start, int): + raise NotImplementedError('get not supported yet') + if not isinstance(idx_aval.size, int): + raise NotImplementedError('get not supported yet') + if idx_aval.stride != 1: + raise NotImplementedError('get not supported yet') + if idx_aval.start != 0: + raise NotImplementedError('get not supported yet') + if idx_aval.size != size: + raise NotImplementedError('get not supported yet') + bd = next(block_shape_iter) + block_shape.append(_block_size(bd)) + assert next(block_shape_iter, None) is None + + def new_index_map(*args): + idx = block_spec.index_map(*args) + idx_iter = iter(idx) + indices = tuple( + 0 + if (bd is None or isinstance(bd, pallas_core.Squeezed)) + else next(idx_iter) + for bd in range(len(block_shape)) + ) + assert next(idx_iter, None) is None + return indices + + block_spec = pallas_core.BlockSpec(block_shape, new_index_map) + return [block_spec] + [pallas_core.no_block_spec] * (len(ctx.avals_in) - 1) + + +@register_eval_rule(state_primitives.get_p) +def _get_eval_rule(ctx: KernelEvalContext, ref, *idx, tree): + indexers = tree_util.tree_unflatten(tree, idx) + ref_aval = ctx.avals_in[0] + indexers_avals = tree_util.tree_unflatten(tree, ctx.avals_in[1:]) + ref_block_spec = ctx.in_block_specs[0] + assert hasattr(ref_aval, 'shape') + if len(indexers) > 1: + raise NotImplementedError('get not supported yet') + indexer = indexers[0] + indexer_aval = indexers_avals[0] + block_indexer = [] + + def _slice(i, b): + match b: + case int(): + return indexing.ds(i * b, b) + case pallas_core.Blocked(bs): + return indexing.ds(i * bs, bs) + case pallas_core.Squeezed() | None: + return i + case _: + raise NotImplementedError('get not supported yet') + + if ref_block_spec is pallas_core.no_block_spec: + # Short-circuit if the ref is not blocked. 
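# The get/swap eval rules re-index the full ref with slices derived from the
# block indices: block index i of a dimension with block size b selects
# elements [i*b, (i+1)*b). A NumPy stand-in for that indexing (squeezed dims
# would keep a scalar index instead of a slice):
import numpy as np

def block_slice(block_index, block_size):
    return slice(block_index * block_size, (block_index + 1) * block_size)

x = np.arange(64).reshape(8, 8)
block = x[block_slice(1, 2), block_slice(1, 4)]   # block (1, 1) of a (2, 4) blocking
assert block.shape == (2, 4)
assert block[0, 0] == 20                          # == x[2, 4]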
+ return state_primitives.get_p.bind(ref, *idx, tree=tree) + block_idx_iter = iter(ctx.get_out_block_indices()[0]) + for idx_aval, size, idx, bd in zip( + indexer_aval.indices, + ref_aval.shape, + indexer.indices, + ref_block_spec.block_shape, + strict=True, + ): + if not isinstance(idx_aval, indexing.Slice): + assert hasattr(idx_aval, 'shape') and not idx_aval.shape, idx_aval + assert bd is None or isinstance(bd, pallas_core.Squeezed) + block_indexer.append(idx) + continue + if not isinstance(idx_aval.start, int): + raise NotImplementedError('get not supported yet') + if not isinstance(idx_aval.size, int): + raise NotImplementedError('get not supported yet') + if idx_aval.stride != 1: + raise NotImplementedError('get not supported yet') + if idx_aval.start != 0: + raise NotImplementedError('get not supported yet') + if idx_aval.size != size: + raise NotImplementedError('get not supported yet') + bidx = next(block_idx_iter) + block_indexer.append(_slice(bidx, bd)) + assert next(block_idx_iter, None) is None + return ref.get(idx=tuple(block_indexer)) + + @register_eval_rule(lax.concatenate_p) def _concatenate_eval_rule(ctx: KernelEvalContext, *args, dimension): # We now handle the case where each of the concatenated array dimensions # divides the block size. block_spec = ctx.out_block_specs[0] block_shape = block_spec.block_shape + is_element_block = [isinstance(bd, pallas_core.Element) for bd in block_shape] + if any(is_element_block): + raise NotImplementedError( + 'Concatenation with Element indexing is not yet supported.' + ) block_dim = block_shape[dimension] if block_dim is None: block_dim = 1 @@ -1006,15 +1291,20 @@ def _concatenate_rule( dimension: int, ): block_shape = block_spec.block_shape + is_element_block = [isinstance(bd, pallas_core.Element) for bd in block_shape] + if any(is_element_block): + raise NotImplementedError( + 'Concatenation with Element indexing is not yet supported.' + ) num_blocks = [] block_dim = block_shape[dimension] - if block_dim is None: + if block_dim is None or isinstance(block_dim, pallas_core.Squeezed): block_dim = 1 if block_dim == sum(aval.shape[dimension] for aval in ctx.avals_in): # pytype: disable=attribute-error # Handle special case if the block contains all of the concatenated # array. new_shapes = [ - util.tuple_update( + util.tuple_update( # pytype: disable=wrong-arg-types block_spec.block_shape, dimension, aval.shape[dimension] # pytype: disable=attribute-error ) for aval in ctx.avals_in @@ -1077,12 +1367,15 @@ def _broadcast_in_dim_usage_rule(ctx, used_out: set[Usage], **params): def _broadcast_in_dim_eval_rule( eval_ctx: KernelEvalContext, x, broadcast_dimensions, **params ): - if not eval_ctx.avals_in[0].shape: # pytype: disable=attribute-error - # Scalar -> Array broadcast - block_spec = eval_ctx.out_block_specs[0] - shape = tuple(s for s in block_spec.block_shape if s is not None) - return jax.lax.broadcast_in_dim(x, broadcast_dimensions=(), shape=shape) - return x + del params # Unused. 
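# Standalone version of the dimension bookkeeping used below: squeezed (None)
# block dims are dropped from the output shape, and each broadcast dimension is
# shifted down by the number of squeezed dims before it. Shapes are made up.
import jax
import jax.numpy as jnp

block_shape = (None, 8, 128)      # output block shape; None marks a squeezed dim
broadcast_dimensions = (1,)       # operand dim 0 maps to output dim 1

dims = tuple(
    d - sum(s is None for s in block_shape[:d])
    for d in broadcast_dimensions
    if block_shape[d] is not None
)
shape = tuple(s for s in block_shape if s is not None)

y = jax.lax.broadcast_in_dim(jnp.ones((8,)), shape=shape, broadcast_dimensions=dims)
assert dims == (0,) and y.shape == (8, 128)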
+ shape = tuple(map(_block_size, eval_ctx.out_block_specs[0].block_shape)) + dims = tuple( + d - sum(s is None for s in shape[:d]) + for d in broadcast_dimensions + if shape[d] is not None + ) + shape = tuple(s for s in shape if s is not None) + return jax.lax.broadcast_in_dim(x, broadcast_dimensions=dims, shape=shape) @register_pull_block_spec_rule(lax.broadcast_in_dim_p) @@ -1096,15 +1389,20 @@ def _broadcast_in_dim_pull_rule( ): del shape, sharding - if not ctx.avals_in[0].shape: # pytype: disable=attribute-error + shape = ctx.avals_in[0].shape # pytype: disable=attribute-error + if not shape: return [pallas_core.no_block_spec] def new_index_map(*args): idx = block_spec.index_map(*args) - return tuple(idx[i] for i in broadcast_dimensions) + return tuple( + 0 if (d == 1) else idx[i] + for i, d in zip(broadcast_dimensions, shape, strict=True) + ) new_block_shape = tuple( - block_spec.block_shape[i] for i in broadcast_dimensions + b if ((b := block_spec.block_shape[i]) is None) or (d != 1) else 1 + for i, d in zip(broadcast_dimensions, shape, strict=True) ) return [pallas_core.BlockSpec(new_block_shape, new_index_map)] @@ -1115,10 +1413,17 @@ def _transpose_eval_rule( ): block_spec = eval_ctx.out_block_specs[0] block_shape = block_spec.block_shape - block_shape_no_nones = tuple(bs for bs in block_shape if bs is not None) + block_shape_no_nones = tuple( + bs + for bs in block_shape + if not (bs is None or isinstance(bs, pallas_core.Squeezed)) + ) block_dims_iter = iter(range(len(block_shape_no_nones))) expanded_block_dims = [ - None if bs is None else next(block_dims_iter) for bs in block_shape + None + if (bs is None or isinstance(bs, pallas_core.Squeezed)) + else next(block_dims_iter) + for bs in block_shape ] assert next(block_dims_iter, None) is None permuted_block_dims = [expanded_block_dims[p] for p in permutation] @@ -1171,6 +1476,67 @@ def _convert_element_type_pull_rule( return [block_spec] +@register_eval_rule(lax.bitcast_convert_type_p) +def _bitcast_convert_type_eval_rule(eval_ctx: KernelEvalContext, x, new_dtype): + return jax.lax.bitcast_convert_type(x, new_dtype) + + +@register_pull_block_spec_rule(lax.bitcast_convert_type_p) +def _bitcast_convert_type_pull_rule( + ctx: PullRuleContext, + block_spec: pallas_core.BlockSpec, + *, + new_dtype: jnp.dtype, +): + old_dtype = ctx.avals_in[0].dtype # pytype: disable=attribute-error + if old_dtype.itemsize != new_dtype.itemsize: + raise NotImplementedError( + 'bitcast_convert_type with different bitwidths not supported yet:' + f' {old_dtype=}, {new_dtype=}' + ) + return [block_spec] + + +@register_eval_rule(prng.random_bits_p) +def _random_bits_eval_rule(eval_ctx: KernelEvalContext, key, bit_width, shape): + del shape + block_spec = eval_ctx.out_block_specs[0] + indices = eval_ctx.get_out_block_indices()[0] + block_shape = block_spec.block_shape + # This is the important part here: we fold in block indices into the key so + # each block gets different random numbers. 
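# Public-API sketch of the fold-in trick described above: folding each block
# index into the key gives every block an independent random stream. The
# shapes and indices here are placeholders.
import jax
import jax.numpy as jnp

def per_block_bits(key, block_indices, block_shape):
    for idx in block_indices:
        key = jax.random.fold_in(key, idx)
    return jax.random.bits(key, shape=block_shape, dtype=jnp.uint32)

key = jax.random.key(0)
a = per_block_bits(key, (0, 1), (8, 128))
b = per_block_bits(key, (1, 0), (8, 128))
assert not jnp.array_equal(a, b)   # distinct blocks draw distinct bits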
+ for idx in indices: + key = jax.random.fold_in(key, idx) + return prng.random_bits(key, bit_width=bit_width, shape=block_shape) + + +@register_pull_block_spec_rule(prng.random_bits_p) +def _random_bits_pull_rule( + ctx: PullRuleContext, + block_spec: pallas_core.BlockSpec, + **_, +): + del ctx, block_spec + key_block_spec = pallas_core.BlockSpec( + block_shape=None, memory_space=pallas_core.MemorySpace.KEY + ) + return [key_block_spec] + + +@register_eval_rule(prng.random_wrap_p) +def _random_wrap_eval_rule(eval_ctx: KernelEvalContext, arr, *, impl): + del eval_ctx + return jax.random.wrap_key_data(arr, impl=impl) + + +@register_pull_block_spec_rule(prng.random_wrap_p) +def _random_wrap_pull_rule( + ctx: PullRuleContext, block_spec: pallas_core.BlockSpec, *, impl +): + del ctx, block_spec, impl + return [pallas_core.BlockSpec(block_shape=None)] + + @register_eval_rule(lax.iota_p) def _iota_eval_rule( eval_ctx: KernelEvalContext, *, dimension, shape, dtype, sharding @@ -1203,6 +1569,138 @@ def _iota_pull_rule( return [] +def _pattern_match_sublanes_to_lanes_reshape( + aval_in: core.ShapedArray, + aval_out: core.ShapedArray, +) -> bool: + # Pattern matches a reshape of the form (..., n/l, l) -> (..., n * l) + # where l is a multiple of 128. + + *leading_in, second_to_last_dim, last_dim = aval_in.shape + *leading_out, last_dim_out = aval_out.shape + if leading_in != leading_out: + return False + if second_to_last_dim * last_dim != last_dim_out: + return False + if last_dim % 128 != 0: + return False + return True + + +def _pattern_match_lanes_to_sublanes_reshape( + aval_in: core.ShapedArray, + aval_out: core.ShapedArray, +) -> bool: + # Pattern matches a reshape of the form (..., n * l) -> (..., n, l) + # where l is a multiple of 128. + + *leading_out, last_dim_in = aval_in.shape + *leading_in, second_to_last_dim_out, last_dim = aval_out.shape + if leading_in != leading_out: + return False + if second_to_last_dim_out * last_dim != last_dim_in: + return False + if last_dim % 128 != 0: + return False + return True + + +@register_pull_block_spec_rule(lax.reshape_p) +def _reshape_pull_rule( + ctx: PullRuleContext, + block_spec: pallas_core.BlockSpec, + *, + dimensions: tuple[int, ...] 
| None, + new_sizes: tuple[int, ...], + sharding: jax.sharding.Sharding, +): + del sharding, new_sizes + if dimensions is not None: + raise NotImplementedError('reshape with None dimensions not supported yet') + aval_in = ctx.avals_in[0] + assert isinstance(aval_in, core.ShapedArray) + aval_out = ctx.avals_out[0] + assert isinstance(aval_out, core.ShapedArray) + + # Handle the case where we reshape from (..., n/l, l) -> (..., n * l) + if _pattern_match_sublanes_to_lanes_reshape(aval_in, aval_out): + block_shape = tuple(block_spec.block_shape) + if not isinstance(block_shape[-1], (int, pallas_core.Blocked)): + raise NotImplementedError( + f'reshape must use Blocked block size on lanes: {block_shape}' + ) + last_dim = _block_size(block_shape[-1]) + if last_dim % 128 != 0: + raise NotImplementedError( + 'reshape with non-128 aligned block size on lanes not supported yet' + ) + # We can now reshape last dim from d -> (d/128, 128) + new_block_shape = block_shape[:1] + (last_dim // 128, 128) + + def new_index_map(*args): + idx = block_spec.index_map(*args) + return *idx, 0 + + return [pallas_core.BlockSpec(new_block_shape, new_index_map)] + + # Handle the case where we reshape from (..., n * l) -> (..., n, l) + if _pattern_match_lanes_to_sublanes_reshape(aval_in, aval_out): + block_shape = tuple(block_spec.block_shape) + if not isinstance(block_shape[-1], (int, pallas_core.Blocked)): + raise NotImplementedError( + f'reshape must use Blocked block size on lanes: {block_shape}' + ) + if not isinstance(block_shape[-2], (int, pallas_core.Blocked)): + raise NotImplementedError( + f'reshape must use Blocked block size on sublanes: {block_shape}' + ) + last_dim = aval_out.shape[-1] + block_sublane_dim, block_lane_dim = ( + _block_size(block_shape[-2]), + _block_size(block_shape[-1]), + ) + total_block_size = block_sublane_dim * block_lane_dim + if total_block_size % 128 != 0: + raise NotImplementedError( + 'reshape with non-128 aligned block size on lanes not supported yet' + ) + if block_lane_dim != last_dim: + raise NotImplementedError( + 'reshape with non-matching block size on lanes not supported yet:' + f' {block_shape}' + ) + new_block_shape = block_shape[:-2] + (total_block_size,) + def new_index_map(*args): # pylint: disable=function-redefined + *idx, second_to_last, last = block_spec.index_map(*args) + # last should always be 0 + if not isinstance(last, int) and last != 0: + raise NotImplementedError( + 'Must select entire block on last dimension for reshape' + ) + return *idx, second_to_last + return [pallas_core.BlockSpec(new_block_shape, new_index_map)] + + raise NotImplementedError(f'reshape not supported yet: {aval_in}, {aval_out}') + + +@register_eval_rule(lax.reshape_p) +def _reshape_eval_rule( + eval_ctx: KernelEvalContext, x, *, dimensions, new_sizes, sharding +): + del sharding, dimensions, new_sizes + out_shape_nones = tuple( + _block_size(s) for s in eval_ctx.out_block_specs[0].block_shape + ) + out_shape = tuple(s for s in out_shape_nones if s is not None) + # Because we have restricted the pull block spec rule, we can just apply a + # basic reshape here. 
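# Pure shape math behind _pattern_match_sublanes_to_lanes_reshape and
# _pattern_match_lanes_to_sublanes_reshape above: only reshapes that merge or
# split the last two dimensions with a lane-aligned (multiple of 128) minor
# dimension are supported.
def merges_sublanes_into_lanes(in_shape, out_shape):
    *lead_in, second_to_last, last = in_shape
    *lead_out, out_last = out_shape
    return (lead_in == lead_out
            and second_to_last * last == out_last
            and last % 128 == 0)

assert merges_sublanes_into_lanes((4, 8, 128), (4, 1024))
assert not merges_sublanes_into_lanes((4, 8, 64), (4, 512))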
+ x = x.reshape(out_shape) + return x + + +# Higher order primitives + + @register_usage_rule(pjit.pjit_p) def _jit_usage_rule( ctx, used_out: list[set[Usage]], *, jaxpr: core.ClosedJaxpr, **_ @@ -1219,16 +1717,20 @@ def _jit_eval_rule(ctx: KernelEvalContext, *args, jaxpr, **kwargs): raise NotImplementedError('pjit with consts not supported yet') out_tree = tree_util.tree_structure(tuple(jaxpr.outvars)) in_tree = tree_util.tree_structure((tuple(jaxpr.invars), {})) - read_usage_env = compute_usage(jaxpr, ctx.out_usages) + + def read_usage_env(_: core.Var): + return {Usage.REGULAR} + _, env, _ = _pull_block_spec( jaxpr, ctx.out_block_specs, - ctx.out_usages, scalar_prefetch_handler=ctx.scalar_prefetch_handler, + read_usage_env=read_usage_env, grid=ctx.grid, ) kernel_fn = make_kernel_function( jaxpr, + (), in_tree, out_tree, read_usage_env, @@ -1247,11 +1749,15 @@ def _jit_pull_block_spec_rule( jaxpr, consts = jaxpr.jaxpr, jaxpr.consts if consts: raise NotImplementedError('pjit with consts not supported yet') + + def read_usage_env(_: core.Var): + return {Usage.REGULAR} + in_block_specs, _, _ = _pull_block_spec( jaxpr, out_block_specs, - ctx.out_usages, scalar_prefetch_handler=ctx.scalar_prefetch_handler, + read_usage_env=read_usage_env, grid=ctx.grid, ) return in_block_specs @@ -1276,16 +1782,20 @@ def _custom_jvp_call_eval_rule( raise NotImplementedError('custom_jvp_call with consts not supported yet') out_tree = tree_util.tree_structure(tuple(jaxpr.outvars)) in_tree = tree_util.tree_structure((tuple(jaxpr.invars), {})) - read_usage_env = compute_usage(jaxpr, ctx.out_usages) + + def read_usage_env(_: core.Var): + return {Usage.REGULAR} + _, env, _ = _pull_block_spec( jaxpr, ctx.out_block_specs, - ctx.out_usages, scalar_prefetch_handler=ctx.scalar_prefetch_handler, grid=ctx.grid, + read_usage_env=read_usage_env, ) kernel_fn = make_kernel_function( jaxpr, + (), in_tree, out_tree, read_usage_env, @@ -1304,12 +1814,16 @@ def _custom_jvp_call_pull_block_spec_rule( jaxpr, consts = call_jaxpr.jaxpr, call_jaxpr.consts if consts: raise NotImplementedError('custom_jvp_call with consts not supported yet') + + def read_usage_env(_: core.Var): + return {Usage.REGULAR} + in_block_specs, _, _ = _pull_block_spec( jaxpr, out_block_specs, - ctx.out_usages, scalar_prefetch_handler=ctx.scalar_prefetch_handler, grid=ctx.grid, + read_usage_env=read_usage_env, ) return in_block_specs @@ -1510,9 +2024,14 @@ def _select_n_push_rule( ): del ctx block_specs = [b for b in args if b is not pallas_core.no_block_spec] + assert len(block_specs) > 0 + block_spec = block_specs[0] if len(block_specs) > 1: - raise NotImplementedError('select_n with multiple inputs not supported yet') - return block_specs[0] + if any(b is not block_spec for b in block_specs): + raise NotImplementedError( + 'select_n with multiple differing inputs not supported yet' + ) + return block_spec @register_push_block_spec_rule(custom_derivatives.custom_jvp_call_p) @@ -1548,3 +2067,45 @@ def register_eltwise_rule(prim: core.Primitive): register_eltwise_rule(lax.rsqrt_p) register_eltwise_rule(lax.log_p) register_eltwise_rule(lax.integer_pow_p) + +@register_push_block_spec_rule(lax.reshape_p) +def _reshape_push_rule( + ctx: PullRuleContext, + block_spec: pallas_core.BlockSpec, + *, + dimensions: tuple[int, ...] 
| None, + new_sizes: tuple[int, ...], + sharding: jax.sharding.Sharding, +): + del sharding, new_sizes + if dimensions is not None: + raise NotImplementedError('reshape with None dimensions not supported yet') + aval_in = ctx.avals_in[0] + assert isinstance(aval_in, core.ShapedArray) + aval_out = ctx.avals_out[0] + assert isinstance(aval_out, core.ShapedArray) + if _pattern_match_lanes_to_sublanes_reshape(aval_in, aval_out): + block_shape = tuple(block_spec.block_shape) + if not isinstance(block_shape[-1], (int, pallas_core.Blocked)): + raise NotImplementedError( + f'reshape must use Blocked block size on lanes: {block_shape}' + ) + last_dim = aval_out.shape[-1] + last_block_dim = _block_size(block_shape[-1]) + if last_block_dim % 128 != 0: + raise NotImplementedError( + 'reshape with non-128 aligned block size on lanes not supported yet' + ) + if last_block_dim % last_dim != 0: + raise NotImplementedError( + 'reshape with non-divisible block size on lanes not supported yet' + ) + num_last_dim_blocks = last_block_dim // last_dim + new_block_shape = block_shape[:1] + (num_last_dim_blocks, last_dim) + + def new_index_map(*args): + *idx, last = block_spec.index_map(*args) + return *idx, last, 0 + + return pallas_core.BlockSpec(new_block_shape, new_index_map) + raise NotImplementedError(f'reshape not supported yet: {aval_in}, {aval_out}') diff --git a/jax/_src/pallas/fuser/fusable.py b/jax/_src/pallas/fuser/fusable.py deleted file mode 100644 index b075c6d136c9..000000000000 --- a/jax/_src/pallas/fuser/fusable.py +++ /dev/null @@ -1,83 +0,0 @@ -# Copyright 2025 The JAX Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -"""Fusable primitive.""" - -import jax -from jax._src import api_util -from jax._src import core as jax_core -from jax._src import linear_util as lu -from jax._src import tree_util -from jax._src import util -from jax._src.interpreters import mlir -from jax._src.interpreters import partial_eval as pe -from jax._src.pallas.fuser import fusion as fusion_lib - -fusable_p = jax_core.Primitive('fusable') -fusable_p.multiple_results = True - - -def _get_aval(x): - return jax_core.raise_to_shaped(jax_core.get_aval(x)) - - -def _make_trivial_fusion(x: jax.Array) -> fusion_lib.Fusion: - return fusion_lib.Fusion( - func=lambda: x, - in_type=((), {}), - out_type=jax.ShapeDtypeStruct(x.shape, x.dtype), - ) - - -def fusable(f): - def wrapper(*args): - def wrapped(*args): - in_fusions = tree_util.tree_map(_make_trivial_fusion, args) - return f(*in_fusions, None) - - flat_args, in_tree = tree_util.tree_flatten(args) - debug_info = api_util.debug_info('fusable', wrapped, args, {}) - flat_fun, out_tree_thunk = api_util.flatten_fun_nokwargs( - lu.wrap_init(wrapped, debug_info=debug_info), in_tree - ) - flat_avals = [_get_aval(x) for x in flat_args] - jaxpr, _, consts, _ = pe.trace_to_jaxpr_dynamic(flat_fun, flat_avals) - out_tree = out_tree_thunk() - out = fusable_p.bind( - *consts, - *flat_args, - jaxpr=jaxpr, - num_consts=len(consts), - in_tree=in_tree, - out_tree=out_tree, - func=f, - ) - return tree_util.tree_unflatten(out_tree, out) - - return wrapper - - -@fusable_p.def_impl -def _(*consts_and_args, jaxpr, num_consts, **_): - consts, args = util.split_list(consts_and_args, [num_consts]) - return jax_core.eval_jaxpr(jaxpr, consts, *args) - - -mlir.register_lowering(fusable_p, mlir.lower_fun(fusable_p.impl)) - - -@fusable_p.def_abstract_eval -def _(*args, jaxpr, **kwargs): - del args, kwargs - return [v.aval for v in jaxpr.outvars] diff --git a/jax/_src/pallas/fuser/fusible.py b/jax/_src/pallas/fuser/fusible.py new file mode 100644 index 000000000000..f0d03cb18d94 --- /dev/null +++ b/jax/_src/pallas/fuser/fusible.py @@ -0,0 +1,86 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Fusible primitive.""" +from typing import Any + +import jax +from jax._src import api_util +from jax._src import core as jax_core +from jax._src import linear_util as lu +from jax._src import tree_util +from jax._src import util +from jax._src.interpreters import mlir +from jax._src.interpreters import partial_eval as pe +from jax._src.pallas.fuser import fusion as fusion_lib + +fusible_p = jax_core.Primitive('fusible') +fusible_p.multiple_results = True + + +def _make_trivial_fusion(x: jax.Array) -> fusion_lib.Fusion: + return fusion_lib.Fusion( + func=lambda: x, + in_type=((), {}), + out_type=jax.ShapeDtypeStruct(x.shape, x.dtype), + ) + + +def fusible(f=None, *, output_fusion_prefix: Any = True): + def decorator(f): + def wrapper(*args): + def wrapped(*args): + in_fusions = tree_util.tree_map(_make_trivial_fusion, args) + return f(*in_fusions, None) + + flat_args, in_tree = tree_util.tree_flatten(args) + debug_info = api_util.debug_info('fusible', wrapped, args, {}) + flat_fun, out_tree_thunk = api_util.flatten_fun_nokwargs( + lu.wrap_init(wrapped, debug_info=debug_info), in_tree + ) + flat_avals = [jax_core.get_aval(x) for x in flat_args] + jaxpr, _, consts, _ = pe.trace_to_jaxpr_dynamic(flat_fun, flat_avals) + out_tree = out_tree_thunk() + out = fusible_p.bind( + *consts, + *flat_args, + jaxpr=jaxpr, + num_consts=len(consts), + in_tree=in_tree, + out_tree=out_tree, + func=f, + output_fusion_prefix=output_fusion_prefix, + ) + return tree_util.tree_unflatten(out_tree, out) + + return wrapper + + if f is not None: + return decorator(f) + return decorator + + +@fusible_p.def_impl +def _(*consts_and_args, jaxpr, num_consts, **_): + consts, args = util.split_list(consts_and_args, [num_consts]) + return jax_core.eval_jaxpr(jaxpr, consts, *args) + + +mlir.register_lowering(fusible_p, mlir.lower_fun(fusible_p.impl)) + + +@fusible_p.def_effectful_abstract_eval +def _(*args, jaxpr, **kwargs): + del args, kwargs + return [v.aval for v in jaxpr.outvars], jaxpr.effects diff --git a/jax/_src/pallas/fuser/fusable_dtype.py b/jax/_src/pallas/fuser/fusible_dtype.py similarity index 87% rename from jax/_src/pallas/fuser/fusable_dtype.py rename to jax/_src/pallas/fuser/fusible_dtype.py index e5bc9ab683ab..2d2c8aac2967 100644 --- a/jax/_src/pallas/fuser/fusable_dtype.py +++ b/jax/_src/pallas/fuser/fusible_dtype.py @@ -12,16 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-"""Custom fusable dtypes.""" +"""Custom fusible dtypes.""" import abc import dataclasses import functools -from typing import Any, Sequence, TypeVar +import itertools as it +from typing import Any, TypeVar +from collections.abc import Sequence import jax from jax._src import api_util from jax._src import core +from jax._src import custom_derivatives from jax._src import dtypes from jax._src import linear_util as lu from jax._src import source_info_util @@ -34,7 +37,7 @@ from jax._src.pallas import pallas_call from jax._src.pallas import primitives as pallas_primitives from jax._src.pallas.fuser import block_spec -from jax._src.pallas.fuser.fusable import fusable_p +from jax._src.pallas.fuser.fusible import fusible_p from jax._src.state import discharge as state_discharge from jax._src.state import primitives as state_primitives from jax._src.util import foreach @@ -54,7 +57,7 @@ @pack_dtype_p.def_abstract_eval def pack_dtype_abstract_eval(*xs, dtype): - if dtypes.issubdtype(dtype, FusableElementDType): + if dtypes.issubdtype(dtype, FusibleElementDType): return dtype.abstract_pack(*xs) raise ValueError("Attempted to pack non-fusion dtype: {dtype}") @@ -69,7 +72,7 @@ def pack(*xs, dtype): @unpack_dtype_p.def_abstract_eval def unpack_dtype_abstract_eval(x): - if dtypes.issubdtype(x.dtype, FusableElementDType): + if dtypes.issubdtype(x.dtype, FusibleElementDType): return x.dtype.abstract_unpack(x) elif isinstance(x.dtype, pallas_core.AbstractMemoryRef): raise NotImplementedError() @@ -80,22 +83,20 @@ def unpack(x): return unpack_dtype_p.bind(x) -class FusableElementDType(dtypes.extended): - """Scalar dtype for fusable dtypes.""" +class FusibleElementDType(dtypes.extended): + """Scalar dtype for fusible dtypes.""" - pass - -class FusableTyRules: +class FusibleTyRules: allow_conversion: bool = False class FusionDType(dtypes.ExtendedDType, metaclass=abc.ABCMeta): - """Base class for fusable extended dtypes.""" + """Base class for fusible extended dtypes.""" _op_registry = {} - _rules = FusableTyRules - type = FusableElementDType + _rules = FusibleTyRules + type = FusibleElementDType @abc.abstractmethod def abstract_unpack(self, x) -> Sequence[Any]: @@ -126,7 +127,7 @@ def pull_block_spec_one_step(self, *args, **kwargs): def physicalize(f): - """Runs a function that contains fusable extended dtypes.""" + """Runs a function that contains fusible extended dtypes.""" def wrapper(*args, **kwargs): if kwargs: @@ -205,7 +206,7 @@ class Context: def physicalize_interp( jaxpr: core.Jaxpr, consts: Sequence[core.Value], *args: core.Value ): - """Physicalizes a jaxpr by replacing fusable dtypes with physical types.""" + """Physicalizes a jaxpr by replacing fusible dtypes with physical types.""" # TODO: Merge into JAX core. 
env: dict[core.Var, Any] = {} @@ -302,9 +303,9 @@ def _pallas_call_physicalize_rule( def _cond_physicalize_rule(ctx: Context, *args, branches, **kwargs): _assert_no_fusion_types(ctx.avals_out) - physicalized_branches = [ + physicalized_branches = tuple( physicalize_closed_jaxpr(branch) for branch in branches - ] + ) flat_args = jax.tree.leaves(args) return conditionals.cond_p.bind( *flat_args, branches=physicalized_branches, **kwargs @@ -314,6 +315,41 @@ def _cond_physicalize_rule(ctx: Context, *args, branches, **kwargs): _physicalize_rules[conditionals.cond_p] = _cond_physicalize_rule +@lu.transformation2 +def _physicalize_transform(f, *args): + vals, zeros = args[::2], args[1::2] + assert len(vals) == len(zeros) + wrapper = lambda *inner_vals: f( + *it.chain.from_iterable(zip(inner_vals, zeros)) + ) + return physicalize(wrapper)(*vals) + + +@lu.transformation2 +def _physicalize_transform_bwd(f, const_avals, *args): + return [custom_derivatives.Zero(a) for a in const_avals] + list( + physicalize(f)(*args) + ) + + +def _custom_vjp_call_physicalize_rule( + ctx: Context, *args, call_jaxpr, num_consts, fwd_jaxpr_thunk, bwd, **kwargs +): + _assert_no_fusion_types(ctx.avals_out) + new_jaxpr = physicalize_closed_jaxpr(call_jaxpr) + fun = lu.wrap_init(core.jaxpr_as_fun(new_jaxpr), + debug_info=call_jaxpr.jaxpr.debug_info) + fwd = custom_derivatives.lift_fwd(num_consts, fwd_jaxpr_thunk) + fwd_physicalized = _physicalize_transform(fwd) + const_avals, _ = util.split_list(new_jaxpr.in_avals, [num_consts]) + bwd_physicalized = _physicalize_transform_bwd(bwd, const_avals) + return custom_derivatives.custom_vjp_call_p.bind( + fun, fwd_physicalized, bwd_physicalized, *args, **kwargs + ) + +_physicalize_rules[custom_derivatives.custom_vjp_call_p] = _custom_vjp_call_physicalize_rule + + def _run_state_rule(ctx: Context, *args, jaxpr, which_linear, is_initialized): _assert_no_fusion_types(ctx.avals_in) _assert_no_fusion_types(ctx.avals_out) @@ -448,12 +484,12 @@ def _pack_dtype_pull_rule( return dtype.pull_block_spec_one_step(block_spec) # pytype: disable=attribute-error -def _fusable_physicalize_rule( +def _fusible_physicalize_rule( _, *consts_and_args, jaxpr, num_consts, in_tree, out_tree, func ): consts, _ = util.split_list(consts_and_args, [num_consts]) new_jaxpr = physicalize_closed_jaxpr(core.ClosedJaxpr(jaxpr, consts)) - return fusable_p.bind( + return fusible_p.bind( *consts_and_args, jaxpr=new_jaxpr.jaxpr, num_consts=num_consts, @@ -463,4 +499,4 @@ def _fusable_physicalize_rule( ) -_physicalize_rules[fusable_p] = _fusable_physicalize_rule +_physicalize_rules[fusible_p] = _fusible_physicalize_rule diff --git a/jax/_src/pallas/fuser/fusion.py b/jax/_src/pallas/fuser/fusion.py index eff8c36ddb08..6319722a9823 100644 --- a/jax/_src/pallas/fuser/fusion.py +++ b/jax/_src/pallas/fuser/fusion.py @@ -17,7 +17,8 @@ from __future__ import annotations import dataclasses -from typing import Any, Callable, Generic, ParamSpec, TypeVar +from typing import Any, Generic, ParamSpec, TypeVar +from collections.abc import Callable import jax from jax._src import util diff --git a/jax/_src/pallas/fuser/jaxpr_fusion.py b/jax/_src/pallas/fuser/jaxpr_fusion.py index 3d36b8f3e2fd..8e12b5db483d 100644 --- a/jax/_src/pallas/fuser/jaxpr_fusion.py +++ b/jax/_src/pallas/fuser/jaxpr_fusion.py @@ -14,35 +14,32 @@ """Fuses a function.""" +from collections.abc import Sequence +import functools from typing import Any - import jax from jax._src import api_util from jax._src import core as jax_core from jax._src import linear_util as lu 
from jax._src import tree_util from jax._src.interpreters import partial_eval as pe - -from jax._src.pallas.fuser import fusable_dtype +from jax._src.pallas.fuser import fusible_dtype from jax._src.pallas.fuser import fusion as fusion_lib -from jax._src.pallas.fuser.fusable import fusable_p - +from jax._src.pallas.fuser.fusible import fusible_p -def _get_aval(x): - return jax_core.raise_to_shaped(jax_core.get_aval(x)) - -def fuse(f=None, *, physicalize: bool = False, debug: bool = False): - """Fuses a function into a single fusable. +def fuse(f=None, *, resolve_fusion_dtypes: bool = True, debug: bool = False): + """Fuses a function into a single fusible. Args: f: The function to fuse. - physicalize: (experimental) whether to physicalize the function. + resolve_fusion_dtypes: (experimental) whether or not to resolve fusion + dtypes (which don't correspond to physical dtypes) debug: Whether to print debug information. - There should be a single call to a `fusable` inside the body of `f`. `fuse` + There should be a single call to a `fusible` inside the body of `f`. `fuse` returns a transformed function that will fuse the surrounding computation into - the fusable and invoke it. + the fusible and invoke it. """ def decorator(f): @@ -52,7 +49,7 @@ def wrapper(*args, **kwargs): flat_fun, out_tree_thunk = api_util.flatten_fun( lu.wrap_init(f, debug_info=debug_info), in_tree ) - flat_avals = [_get_aval(x) for x in flat_args] + flat_avals = [jax_core.get_aval(x) for x in flat_args] jaxpr, _, consts, _ = pe.trace_to_jaxpr_dynamic(flat_fun, flat_avals) if debug: print("Jaxpr before fusion:") @@ -61,8 +58,8 @@ def wrapper(*args, **kwargs): out_flat = fuse_jaxpr(jaxpr, out_tree, consts, *flat_args) return tree_util.tree_unflatten(out_tree, out_flat) - if physicalize: - wrapper = fusable_dtype.physicalize(wrapper) + if resolve_fusion_dtypes: + wrapper = fusible_dtype.physicalize(wrapper) return wrapper if f is not None: @@ -70,12 +67,12 @@ def wrapper(*args, **kwargs): return decorator -_fusable: dict[jax_core.Primitive, Any] = {} +_fusible: dict[jax_core.Primitive, Any] = {} -def construct_fusion( +def _construct_fusion_jaxpr( candidate_values, jaxpr: jax_core.Jaxpr, outvars, *invars, **kwargs -) -> fusion_lib.Fusion: +): flat_outvars, out_tree = tree_util.tree_flatten(outvars) flat_invars, in_tree = tree_util.tree_flatten((invars, kwargs)) new_jaxpr_no_dce = jaxpr.replace( @@ -94,12 +91,6 @@ def construct_fusion( c for used, c in zip(used_consts, candidate_values, strict=True) if used ) kernel_in_tree = tree_util.tree_structure((invars, kwargs)) - - def _fn(*args, **kwargs): - flat_args, _ = tree_util.tree_flatten((args, kwargs)) - out_flat = jax_core.eval_jaxpr(new_jaxpr, new_values, *flat_args) - return tree_util.tree_unflatten(out_tree, out_flat) - flat_in_type = [ jax.ShapeDtypeStruct(x.aval.shape, x.aval.dtype) for x in flat_invars ] @@ -108,9 +99,158 @@ def _fn(*args, **kwargs): out_tree, [jax.ShapeDtypeStruct(x.aval.shape, x.aval.dtype) for x in flat_outvars], ) + return new_jaxpr, new_values, in_type, out_type, out_tree + + +def construct_fusion( + candidate_values, jaxpr: jax_core.Jaxpr, outvars, *invars, **kwargs +) -> fusion_lib.Fusion: + new_jaxpr, new_values, in_type, out_type, out_tree = _construct_fusion_jaxpr( + candidate_values, jaxpr, outvars, *invars, **kwargs + ) + + def _fn(*args, **kwargs): + flat_args, _ = tree_util.tree_flatten((args, kwargs)) + out_flat = jax_core.eval_jaxpr(new_jaxpr, new_values, *flat_args) + return tree_util.tree_unflatten(out_tree, out_flat) + return 
fusion_lib.Fusion(_fn, in_type, out_type) +def _find_downstream( + jaxpr: jax_core.Jaxpr, in_used: Sequence[bool] +) -> tuple[bool, ...]: + # TODO(sharadmv): We use partial_eval to query downstream dependencies which + # is not an officially sanctioned way to do so, since PE is really used for + # AD. In the future, we should have a special Jaxpr API that queries this. + _, _, out_used, *_ = pe.partial_eval_jaxpr_custom( + jaxpr, + in_unknowns=in_used, + in_inst=in_used, + ensure_out_unknowns=False, + ensure_out_inst=False, + saveable=lambda *_, **__: False, + ) + return tuple(out_used) + + +def _construct_output_permutation( + used: list[tuple[bool, ...]], +) -> list[int]: + order = [] + for u in used: + true_vals = [i for i in range(len(u)) if u[i]] + order.extend(true_vals) + return [order.index(i) for i in range(len(order))] + + +def _construct_output_fusions( + candidate_values, + jaxpr, + out_tree, + fusion_eqn_index, + fusion_eqn_outvars, # Flat list of vars output by the fusible eqn + fusion_eqn_out_tree, # Tree structure of the fusible eqn outputs + output_fusion_prefix, # Pytree defining output groups +): + # 1. Create jaxpr_out: represents computation *after* the fusible + # Inputs: fusion_eqn_outvars + # Outputs: jaxpr.outvars + jaxpr_out, all_values, _, _, _ = _construct_fusion_jaxpr( + candidate_values, + jaxpr.replace( + eqns=jaxpr.eqns[:fusion_eqn_index] + + jaxpr.eqns[fusion_eqn_index + 1 :] + ), + tree_util.tree_unflatten(out_tree, jaxpr.outvars), # Original outputs + tree_util.tree_unflatten( + fusion_eqn_out_tree, fusion_eqn_outvars + ), # Fusible outputs as inputs + ) + + # 2. Group fusible outputs based on the mask + unflat_fusible_outvars = jax.tree.unflatten( + fusion_eqn_out_tree, fusion_eqn_outvars + ) + partial_flat = jax.tree.structure(output_fusion_prefix).flatten_up_to( + unflat_fusible_outvars + ) + + # 3. Calculate dependencies and check disjointedness + downstream_outputs_used_masks = [] # List of bool tuples, one per group + already_used_final_outputs = set() # Indices of final outputs already claimed + for outvars_group in partial_flat: + # Identify vars in this group + used_fusible_outvars = set(jax.tree.leaves(outvars_group)) + # Create mask for jaxpr_out inputs corresponding to this group + in_used_mask = [ + True if v in used_fusible_outvars else False for v in jaxpr_out.invars + ] + # Trace dependencies through jaxpr_out to find which final outputs are affected + downstream_used_mask = _find_downstream( + jaxpr_out, in_used_mask + ) # Mask for jaxpr_out.outvars (== jaxpr.outvars) + + # Check for overlap in final output usage across groups + for i, used in enumerate(downstream_used_mask): + if used: + if i in already_used_final_outputs: + raise ValueError( + "Outputs must be disjoint in order to use separate output fusions" + ) + already_used_final_outputs.add(i) + downstream_outputs_used_masks.append(downstream_used_mask) + + # 4. 
Construct output permutation needed to restore original output order + output_permutation = _construct_output_permutation( + downstream_outputs_used_masks + ) + + # Construct fusions for each group by DCEing the jaxpr_out + output_fusions = [] + for i, outvars_group in enumerate(partial_flat): + flat_group_vars, _ = tree_util.tree_flatten(outvars_group) + downstream_used_mask = downstream_outputs_used_masks[i] + + used_jaxpr_invars = [False] * len(all_values) + [ + v in flat_group_vars for v in jaxpr_out.invars + ] + jaxpr_out_for_group, used_consts, _ = pe.dce_jaxpr_consts( + jaxpr_out, downstream_used_mask, instantiate=used_jaxpr_invars + ) + values_for_jaxpr = tuple( + c for used, c in zip(used_consts, all_values, strict=True) if used + ) + + def _fn(jaxpr, vals, *args, **kwargs): + flat_args, _ = tree_util.tree_flatten((args, kwargs)) + out_flat = jax_core.eval_jaxpr(jaxpr, vals, *flat_args) + return tuple(out_flat) + + fn = functools.partial(_fn, jaxpr_out_for_group, values_for_jaxpr) + in_type = jax.tree.map( + lambda v: jax.ShapeDtypeStruct(v.aval.shape, v.aval.dtype), # pytype: disable=attribute-error + outvars_group, + ) + out_type = tuple( + jax.ShapeDtypeStruct(v.aval.shape, v.aval.dtype) # pytype: disable=attribute-error + for v in jaxpr_out_for_group.outvars + ) + fusion = fusion_lib.Fusion( + fn, + (in_type, {}), + out_type, + ) + output_fusions.append(fusion) + + return ( + tree_util.tree_unflatten( + tree_util.tree_structure(output_fusion_prefix), output_fusions + ), + output_permutation, + ) + + def fuse_jaxpr( jaxpr: jax_core.Jaxpr, out_tree: tree_util.PyTreeDef, consts, *args ): @@ -118,16 +258,25 @@ def fuse_jaxpr( # Collect input fusions for i, eqn in enumerate(jaxpr.eqns): - if eqn.primitive is fusable_p: + if eqn.primitive is fusible_p: fusion_eqn_index = i break if fusion_eqn_index is None: - raise ValueError("No fusable eqn found") + raise ValueError("No fusible eqn found") fusion_eqn = jaxpr.eqns[fusion_eqn_index] + # Now let's check if we need to do any fusion at all, e.g. do the outputs of + # the jaxpr have any dependence on the fusion at all? We can DCE the jaxpr + # with all the inputs and outputs to check if there is a dependence. + dced_jaxpr, _ = pe.dce_jaxpr(jaxpr, [True] * len(jaxpr.outvars), + instantiate=True) + if not any(eqn.primitive is fusible_p for eqn in dced_jaxpr.eqns): + # Short circuit if there is nothing to fuse. + return jax_core.eval_jaxpr(dced_jaxpr, consts, *args) + candidate_values = [*consts, *args] - # Construct fusions for non-constant inputs to the fusable. + # Construct fusions for non-constant inputs to the fusible. in_fusions_flat = [ construct_fusion( candidate_values, @@ -141,21 +290,20 @@ def fuse_jaxpr( in_fusions = tree_util.tree_unflatten( fusion_eqn.params["in_tree"], in_fusions_flat ) - out_fusion = construct_fusion( + output_fusions, output_permutation = _construct_output_fusions( candidate_values, - jaxpr.replace( - eqns=jaxpr.eqns[:fusion_eqn_index] - + jaxpr.eqns[fusion_eqn_index + 1 :] - ), - tree_util.tree_unflatten(out_tree, jaxpr.outvars), - tree_util.tree_unflatten( - fusion_eqn.params["out_tree"], fusion_eqn.outvars - ), + jaxpr, + out_tree, + fusion_eqn_index, + fusion_eqn.outvars, + fusion_eqn.params["out_tree"], + fusion_eqn.params["output_fusion_prefix"], ) - # Run the fusable. - out = fusion_eqn.params["func"](*in_fusions, out_fusion) - - # Now return the flattened output (the fuse_jaxpr caller should unflatten). 
- out_flat = tree_util.tree_leaves(out) - assert len(out_flat) == len(jaxpr.outvars) - return out_flat + out = fusion_eqn.params["func"](*in_fusions, output_fusions) + flat_out = jax.tree.leaves(out) + permuted_out = [flat_out[i] for i in output_permutation] + assert len(permuted_out) == len(jaxpr.outvars), ( + len(permuted_out), + len(jaxpr.outvars), + ) + return permuted_out diff --git a/jax/_src/pallas/helpers.py b/jax/_src/pallas/helpers.py index 1b2649d4e987..71004cd405a3 100644 --- a/jax/_src/pallas/helpers.py +++ b/jax/_src/pallas/helpers.py @@ -13,44 +13,54 @@ # limitations under the License. """Pallas helper functions.""" -from typing import Any, Protocol +from collections.abc import Callable import jax -import jax.numpy as jnp -from jax._src.pallas import pallas_call +from jax._src import checkify +from jax._src import config from jax._src.pallas import core as pl_core +from jax._src.pallas import pallas_call @jax.named_call def empty( - shape: tuple[int, ...], dtype: jnp.dtype, *, memory_space: Any = None + shape: tuple[int, ...], + dtype: jax.typing.DTypeLike, + *, + memory_space: object | None = None, + interpret: bool = False, + backend: pl_core.Backend | None = None, ): - def _empty_kernel(_): - # No-op to leave the out_ref uninitialized - pass + return empty_like( + jax.ShapeDtypeStruct(shape, dtype), + memory_space=memory_space, + interpret=interpret, + backend=backend, + ) + +@jax.named_call +def empty_like( + x: object, + *, + memory_space: object | None = None, + interpret: bool = False, + backend: pl_core.Backend | None = None, +): if memory_space is None: - kernel_memory_space = pl_core.MemorySpace.ANY - memory_space = jax.ShapeDtypeStruct - else: - kernel_memory_space = memory_space + memory_space = pl_core.MemorySpace.ANY return pallas_call.pallas_call( - _empty_kernel, - in_specs=[], - out_specs=pl_core.BlockSpec(memory_space=kernel_memory_space), - out_shape=memory_space(shape, dtype), + # No-op to leave the out_ref uninitialized + lambda *_: None, + out_specs=jax.tree.map( + lambda _: pl_core.BlockSpec(memory_space=memory_space), x + ), + out_shape=x, + interpret=interpret, + backend=backend, )() -class ArrayLike(Protocol): - shape: tuple[int, ...] - dtype: jnp.dtype - - -def empty_like(x: ArrayLike, *, memory_space: Any = None): - return empty(x.shape, x.dtype, memory_space=memory_space) - - def when(condition): def _wrapped(f): if isinstance(condition, bool): @@ -59,3 +69,43 @@ def _wrapped(f): else: jax.lax.cond(condition, f, lambda: None) return _wrapped + + +def loop( + lower: jax.typing.ArrayLike, + upper: jax.typing.ArrayLike, + *, + unroll: int | bool | None = None, +) -> Callable[[Callable[[jax.Array], None]], None]: + def decorator(body): + jax.lax.fori_loop( + lower, upper, lambda idx, _: body(idx), init_val=None, unroll=unroll + ) + + return decorator + + +_ENABLE_DEBUG_CHECKS = config.bool_state( + "jax_pallas_enable_debug_checks", + default=False, + help=( + "If set, ``pl.debug_check`` calls are checked at runtime. Otherwise," + " they are a noop." + ), +) + + +enable_debug_checks = _ENABLE_DEBUG_CHECKS + + +def debug_checks_enabled() -> bool: + """Returns whether runtime checks are enabled.""" + return _ENABLE_DEBUG_CHECKS.value + + +def debug_check(condition, message): + """Check the condition if + :func:`~jax.experimental.pallas.enable_debug_checks` is set, otherwise + do nothing.
+ """ + return checkify.debug_check(condition, message) diff --git a/jax/_src/pallas/hlo_interpreter.py b/jax/_src/pallas/hlo_interpreter.py index 6fbe5e914bfe..038d93d3f9e2 100644 --- a/jax/_src/pallas/hlo_interpreter.py +++ b/jax/_src/pallas/hlo_interpreter.py @@ -27,7 +27,8 @@ from collections.abc import Iterable, Sequence from functools import reduce, partial import itertools -from typing import Any, Callable +from typing import Any +from collections.abc import Callable import jax from jax import lax @@ -83,18 +84,19 @@ def _logical_aval_to_interpret_mode_aval(aval): return aval -def _dynamic_slice(start_idx, block_shape, value, is_indexing): +def _dynamic_slice( + start_idx, block_shape: tuple[int, ...], value, is_squeeze, +): start_idx = tuple(jnp.asarray(s, dtype=jnp.int32) for s in start_idx) output = lax.dynamic_slice(value, start_idx, slice_sizes=block_shape) - squeeze_dims = tuple(np.arange(len(is_indexing))[np.array(is_indexing, - dtype=np.bool_)]) - return lax.squeeze(output, squeeze_dims) + squeeze_dims = tuple(np.arange(len(is_squeeze))[np.array(is_squeeze, + dtype=np.bool_)]) + return lax.squeeze(output, squeeze_dims) # type: ignore[arg-type] -def _dynamic_update_slice(start_idx, block_shape, value, update, - is_indexing): +def _dynamic_update_slice(start_idx, block_shape, value, update, is_squeeze): start_idx = tuple(jnp.asarray(s, dtype=jnp.int32) for s in start_idx) - broadcast_dims = tuple(i for i, b in enumerate(is_indexing) + broadcast_dims = tuple(i for i, b in enumerate(is_squeeze) if not b) update = lax.broadcast_in_dim(update, block_shape, broadcast_dims) assert update.shape == block_shape @@ -112,8 +114,7 @@ def _get_next_indices(grid, indices): return tuple(reversed(next_indices)) -def _pad_to_block_dimension(value, - block_shape): +def _pad_to_block_dimension(value, block_shape: tuple[int, ...]): """Pads values so the shape evenly divides into block dimensions. For example, if values has a shape of (33, 2, 5) with a block_shape of @@ -121,8 +122,7 @@ def _pad_to_block_dimension(value, Args: value: Array to be padded. - block_shape: Block shapes to use for padding. If None, no padding will - be performed. + block_shape: Block shapes to use for padding. Returns: A padded array. @@ -190,7 +190,7 @@ def eval_jaxpr_recursive( consts: Consts that ``jaxpr`` closes over. *args: Input arguments to the ``jaxpr``. recurse_hop_rule: A Jaxpr interpreter to call on sub-jaxprs of - higher-order primtives. + higher-order primitives. propagate_source_info: Whether to propagate source info. """ def read(v: jax_core.Atom) -> Any: @@ -236,8 +236,7 @@ def pad_jaxpr_constvars(jaxpr: jax_core.Jaxpr, to pad each Jaxpr with all consts from all branches so the signatures match, but only use the consts for this branch. 
""" - newvar = jax_core.gensym(suffix='_') - unused_const_vars = [tuple(map(newvar, const_avals)) + unused_const_vars = [tuple(map(jax_core.Var, const_avals)) for const_avals in all_const_avals] const_prefix = util.concatenate(unused_const_vars[:i]) const_suffix = util.concatenate(unused_const_vars[i + 1:]) @@ -313,9 +312,15 @@ def rule(interpreter, *args, **params): lax.while_p, 'body_jaxpr', 'cond_jaxpr') _eval_jaxpr_hop_rules[lax.cond_p] = make_hop_rule(lax.cond_p, 'branches') def _run_scoped_physicalize_rule( - interpreter, *consts, jaxpr: jax_core.Jaxpr): + interpreter, *consts, jaxpr: jax_core.Jaxpr, collective_axes): + if collective_axes: + raise NotImplementedError( + "run_scoped interpret rule does not support collective axes" + ) physical_jaxpr, physical_consts = interpreter(jaxpr, consts) - return primitives.run_scoped_p.bind(*physical_consts, jaxpr=physical_jaxpr) + return primitives.run_scoped_p.bind( + *physical_consts, jaxpr=physical_jaxpr, collective_axes=collective_axes + ) _eval_jaxpr_hop_rules[primitives.run_scoped_p] = _run_scoped_physicalize_rule @@ -377,23 +382,21 @@ def pallas_call_hlo_interpret( carry = [] for x, bm in zip(itertools.chain(block_args, out), grid_mapping.block_mappings): - if isinstance(bm.indexing_mode, pallas_core.Unblocked): - padding = bm.indexing_mode.padding - if padding is not None and any(p != (0, 0) for p in padding): - if input_output_aliases: - raise NotImplementedError("Padding with aliasing not supported.") - pad_value = primitives.uninitialized_value(shape=(), dtype=x.dtype) - x = lax.pad(x, pad_value, [(*p, 0) for p in padding]) + padding = [bd.padding if isinstance(bd, pallas_core.Element) else (0, 0) + for bd in bm.block_shape] + if padding is not None and any(p != (0, 0) for p in padding): + if input_output_aliases: + raise NotImplementedError("Padding with aliasing not supported.") + pad_value = primitives.uninitialized_value(shape=(), dtype=x.dtype) + x = lax.pad(x, pad_value, [(*p, 0) for p in padding]) carry.append(x) - is_indexing_dim = [ - tuple(b is pallas_core.mapped for b in bm.block_shape) + block_shapes = [pallas_core._get_block_shape(bm.block_shape) + for bm in grid_mapping.block_mappings] + is_squeeze_dim = [ + tuple(isinstance(bd, pallas_core.Squeezed) for bd in bm.block_shape) for bm in grid_mapping.block_mappings ] - block_shapes = [ - tuple(1 if i else b for i, b in zip(iid, bm.block_shape)) - for iid, bm in zip(is_indexing_dim, grid_mapping.block_mappings) - ] # Pad values to evenly divide into block dimensions. This matches the # behavior of the non-interpret mode. 
We pad with NaN, to make it easier @@ -416,7 +419,7 @@ def pallas_call_hlo_interpret( num_iterations = 1 # The scan carry: (i, loop_idx, *consts, *ins, *outs, *scratch) - # i:int32 is the interation index + # i:int32 is the iteration index # loop_idx: tuple[int32] are the program ids for each grid axis def cond(carry): i, *_ = carry @@ -444,7 +447,7 @@ def body(carry): for bm in grid_mapping.block_mappings ] blocks = map(_dynamic_slice, start_indices, block_shapes, - carry_consts_ins, is_indexing_dim) + carry_consts_ins, is_squeeze_dim) with pallas_core.grid_env(local_grid_env): assert len(discharged_jaxpr.invars) == len(scalars) + len(blocks) + len( scratch_values @@ -462,7 +465,7 @@ def body(carry): _, out_inout, out_scratch = split_list( blocks, [grid_mapping.num_index_operands, num_inout_blocks]) out_carry = map(_dynamic_update_slice, start_indices, block_shapes, - carry_consts_ins, out_inout, is_indexing_dim) + carry_consts_ins, out_inout, is_squeeze_dim) return (i + 1, _get_next_indices(grid, loop_idx), *out_carry, *out_scratch) @@ -473,14 +476,14 @@ def body(carry): out_out = carry[len(block_args):len(block_args) + len(out)] out_nopad = [] for o, bm in zip(out_out, grid_mapping.block_mappings_output): - if isinstance(bm.indexing_mode, pallas_core.Unblocked): - padding = bm.indexing_mode.padding - if padding is not None and any(p != (0, 0) for p in padding): - if input_output_aliases: - raise NotImplementedError("Padding with aliasing not supported.") - pad_low, pad_high = zip(*padding) - limit_indices = [s - p for s, p in zip(o.shape, pad_high)] - o = lax.slice(o, pad_low, limit_indices) + padding = [bd.padding if isinstance(bd, pallas_core.Element) else (0, 0) + for bd in bm.block_shape] + if padding is not None and any(p != (0, 0) for p in padding): + if input_output_aliases: + raise NotImplementedError("Padding with aliasing not supported.") + pad_low, pad_high = zip(*padding) + limit_indices = [s - p for s, p in zip(o.shape, pad_high)] + o = lax.slice(o, pad_low, limit_indices) if o.shape != bm.array_shape_dtype.shape: o = lax.slice(o, (0,) * o.ndim, bm.array_shape_dtype.shape) out_nopad.append(o) diff --git a/jax/_src/pallas/mosaic/BUILD b/jax/_src/pallas/mosaic/BUILD index 24e8341046b0..83525f11d3cf 100644 --- a/jax/_src/pallas/mosaic/BUILD +++ b/jax/_src/pallas/mosaic/BUILD @@ -103,6 +103,7 @@ py_library( "//jax", "//jax:ad_util", "//jax:core", + "//jax:custom_derivatives", "//jax:dtypes", "//jax:mesh", "//jax:mlir", @@ -158,6 +159,7 @@ py_library( deps = [ ":core", ":primitives", + ":verification", "//jax", "//jax:core", "//jax:source_info_util", diff --git a/jax/_src/pallas/mosaic/core.py b/jax/_src/pallas/mosaic/core.py index f582248ee7c3..a63df1ca8b42 100644 --- a/jax/_src/pallas/mosaic/core.py +++ b/jax/_src/pallas/mosaic/core.py @@ -21,12 +21,12 @@ import enum import functools from typing import Any, ClassVar, Literal +from collections.abc import Mapping import jax -from jax._src import config from jax._src import core as jax_core -from jax._src import dtypes from jax._src import util +from jax._src.frozen_dict import FrozenDict from jax._src.pallas import core as pallas_core import jax.numpy as jnp import numpy as np @@ -49,61 +49,105 @@ _out_shape_to_aval_mapping = pallas_core._out_shape_to_aval_mapping split_list = util.split_list -_ENABLE_RUNTIME_ASSERT = config.bool_state( - "jax_pallas_enable_runtime_assert", - default=False, - help=( - "If set, enables runtime assertions in the kernel via checkify.check." 
- " Otherwise, runtime asserts will be ignored unless functionalized" - " using checkify.checkify." - ), -) + +class KernelType(enum.Enum): + TC = 0 + SC_SCALAR_SUBCORE = 1 + SC_VECTOR_SUBCORE = 2 + + +class GridDimensionSemantics(enum.Enum): + PARALLEL = "parallel" + ARBITRARY = "arbitrary" + +PARALLEL = GridDimensionSemantics.PARALLEL +ARBITRARY = GridDimensionSemantics.ARBITRARY + + +DimensionSemantics = Literal["parallel", "arbitrary"] | GridDimensionSemantics @dataclasses.dataclass(frozen=True) -class TPUCompilerParams(pallas_core.CompilerParams): +class CompilerParams(pallas_core.CompilerParams): """Mosaic TPU compiler parameters. Attributes: - dimension_semantics: A list of dimension semantics for each grid - dimension of the kernel. Either "parallel" for dimensions that can - execute in any order, or "arbitrary" for dimensions that must be - executed sequentially. + dimension_semantics: A list of dimension semantics for each grid dimension + of the kernel. Either "parallel" for dimensions that can execute in any + order, or "arbitrary" for dimensions that must be executed sequentially. allow_input_fusion: A list of booleans indicating whether input fusion is allowed for each argument. - vmem_limit_bytes: Overrides the default VMEM limit for a kernel. Note - that this must be used in conjunction with the + vmem_limit_bytes: Overrides the default VMEM limit for a kernel. Note that + this must be used in conjunction with the --xla_tpu_scoped_vmem_limit_kib=N flag with N*1kib > vmem_limit_bytes. - collective_id: Indicates which barrier semaphore to use for the kernel. - Note that using the same collective_id does not guarantee that - the same barrier semaphore will be allocated between kernels. + collective_id: Indicates which barrier semaphore to use for the kernel. Note + that using the same collective_id does not guarantee that the same barrier + semaphore will be allocated between kernels. internal_scratch_in_bytes: The size of the internal scratch space used by Mosaic. flags: A dictionary of command line flags for the kernel. serialization_format: The serialization format for the kernel body. - device_type: The device type to compile for. + disable_bounds_checks: Disable bounds checks in the kernel. """ - PLATFORM: ClassVar[str] = "mosaic" - dimension_semantics: ( - Sequence[Literal["parallel", "arbitrary"] | GridDimensionSemantics] | None - ) = None - allow_input_fusion: Sequence[bool] | None = None + BACKEND: ClassVar[pallas_core.Backend] = "mosaic_tpu" + dimension_semantics: tuple[DimensionSemantics, ...] | None = None + allow_input_fusion: tuple[bool, ...] 
| None = None vmem_limit_bytes: int | None = None collective_id: int | None = None has_side_effects: bool = False flags: dict[str, Any] | None = None internal_scratch_in_bytes: int | None = None serialization_format: int = 1 - device_type: str | None = None + kernel_type: KernelType = KernelType.TC + disable_bounds_checks: bool = False + def __init__( + self, + dimension_semantics: Sequence[DimensionSemantics] | None = None, + allow_input_fusion: Sequence[bool] | None = None, + vmem_limit_bytes: int | None = None, + collective_id: int | None = None, + has_side_effects: bool = False, + flags: Mapping[str, Any] | None = None, + internal_scratch_in_bytes: int | None = None, + serialization_format: int = 1, + kernel_type: KernelType = KernelType.TC, + disable_bounds_checks: bool = False, + ): + object.__setattr__( + self, + "dimension_semantics", + None if dimension_semantics is None else tuple(dimension_semantics), + ) + object.__setattr__( + self, + "allow_input_fusion", + None if allow_input_fusion is None else tuple(allow_input_fusion), + ) + object.__setattr__(self, "vmem_limit_bytes", vmem_limit_bytes) + object.__setattr__(self, "collective_id", collective_id) + object.__setattr__(self, "has_side_effects", has_side_effects) + object.__setattr__( + self, "flags", None if flags is None else FrozenDict(flags) + ) + object.__setattr__( + self, "internal_scratch_in_bytes", internal_scratch_in_bytes + ) + object.__setattr__(self, "serialization_format", serialization_format) + object.__setattr__(self, "kernel_type", kernel_type) + object.__setattr__(self, "disable_bounds_checks", disable_bounds_checks) + + # Replace is a method, not a field. replace = dataclasses.replace -class TPUMemorySpace(enum.Enum): + +class MemorySpace(enum.Enum): ANY = "any" # TODO(b/368401328): Remove this and just use pl.ANY. VMEM = "vmem" SMEM = "smem" CMEM = "cmem" SEMAPHORE = "semaphore_mem" + HBM = "hbm" def __str__(self) -> str: return self.value @@ -112,47 +156,12 @@ def __call__(self, shape: tuple[int, ...], dtype: jnp.dtype): # A convenience function for constructing MemoryRef types. 
return pallas_core.MemoryRef(shape, dtype, self) -class semaphore_dtype(dtypes.extended): pass -class semaphore(semaphore_dtype): pass -class dma_semaphore(semaphore_dtype): pass -class barrier_semaphore(semaphore_dtype): pass - -class AbstractSemaphoreTyRules: - @staticmethod - def pallas_interpret_element_aval(_) -> jax_core.ShapedArray: - return jax_core.ShapedArray((), pallas_core.SEMAPHORE_INTERPRET_DTYPE) +class dma_semaphore(pallas_core.semaphore_dtype): pass - @staticmethod - def physical_element_aval(_) -> jax_core.ShapedArray: - return jax_core.ShapedArray((), jnp.int32) - -class AbstractSemaphoreTy(dtypes.ExtendedDType): - name: str - _rules = AbstractSemaphoreTyRules - - def __repr__(self) -> str: - return self.name - - def __eq__(self, other): - return self.__class__ == other.__class__ - - def __hash__(self) -> int: - return hash(self.__class__) - -# TODO(sharadmv): implement dtype rules for AbstractSemaphoreTy - -class SemaphoreTy(AbstractSemaphoreTy): - type = semaphore - name = "sem" - -class DmaSemaphoreTy(AbstractSemaphoreTy): +class DMASemaphore(pallas_core.AbstractSemaphoreTy): type = dma_semaphore name = "dma_sem" -class BarrierSemaphoreTy(AbstractSemaphoreTy): - type = barrier_semaphore - name = "barrier_sem" - class SemaphoreType(enum.Enum): REGULAR = "regular" DMA = "dma" @@ -161,12 +170,12 @@ class SemaphoreType(enum.Enum): def __call__(self, shape: tuple[int, ...]): dtype: Any if self == SemaphoreType.DMA: - dtype = DmaSemaphoreTy() + dtype = DMASemaphore() elif self == SemaphoreType.BARRIER: - dtype = BarrierSemaphoreTy() + dtype = pallas_core.BarrierSemaphore() else: - dtype = SemaphoreTy() - return pallas_core.MemoryRef(shape, dtype, TPUMemorySpace.SEMAPHORE) + dtype = pallas_core.Semaphore() + return pallas_core.MemoryRef(shape, dtype, MemorySpace.SEMAPHORE) def get_array_aval(self) -> pallas_core.ShapedArrayWithMemorySpace: return self(()).get_array_aval() @@ -197,7 +206,7 @@ def __init__( def _make_scalar_ref_aval(self, aval): return AbstractMemoryRef(jax_core.ShapedArray(aval.shape, aval.dtype), - TPUMemorySpace.SMEM) + MemorySpace.SMEM) @dataclasses.dataclass(frozen=True) @@ -211,6 +220,17 @@ class TensorCoreMesh: devices: np.ndarray axis_names: Sequence[str] + def __init__(self, devices: np.ndarray, axis_names: Sequence[str]): + devices = np.copy(devices) + devices.setflags(write=False) + object.__setattr__(self, "devices", devices) + object.__setattr__(self, "axis_names", tuple(axis_names)) + + def __hash__(self) -> int: + return hash( + (self.devices.shape, tuple(np.ravel(self.devices)), self.axis_names) + ) + @property def backend(self) -> str: return "mosaic_tpu" @@ -225,23 +245,22 @@ def discharges_effect(self, effect: jax_core.Effect): def create_tensorcore_mesh( - axis_name: str, devices: Sequence[jax.Device] | None = None + axis_name: str, + devices: Sequence[jax.Device] | None = None, + num_cores: int | None = None, ) -> TensorCoreMesh: - # TODO(b/355036384): emit a better error if we don't have tensorcores. 
- if devices is None: - devices = jax.devices() - num_cores = devices[0].num_cores + if devices is not None and num_cores is not None: + raise ValueError('cannot specify both devices and num_cores') + if num_cores is None: + if devices is None: + devices = jax.devices() + num_cores = devices[0].num_cores return TensorCoreMesh( np.array([TensorCore(i) for i in range(num_cores)]), [axis_name], ) -def runtime_assert_enabled() -> bool: - """Returns whether runtime asserts are enabled.""" - return _ENABLE_RUNTIME_ASSERT.value - - def _tensorcore_mesh_discharge_rule( in_avals, out_avals, @@ -249,18 +268,18 @@ def _tensorcore_mesh_discharge_rule( mesh, jaxpr, compiler_params: Any | None, - interpret: bool, + interpret: Any, debug: bool, cost_estimate: pallas_core.CostEstimate | None, name: str, ): assert isinstance(mesh, TensorCoreMesh) - if compiler_params and not isinstance(compiler_params, TPUCompilerParams): + if compiler_params and not isinstance(compiler_params, CompilerParams): raise ValueError( - "compiler_params must be a pltpu.TPUCompilerParams" + "compiler_params must be a pltpu.CompilerParams" ) if not compiler_params: - compiler_params = TPUCompilerParams() + compiler_params = CompilerParams() if len(mesh.shape) > 1: raise NotImplementedError("Mesh must be 1D") if compiler_params.dimension_semantics is not None: @@ -296,10 +315,3 @@ def _convert_semaphore_type_to_aval( pallas_core._out_shape_to_aval_mapping[SemaphoreType] = ( _convert_semaphore_type_to_aval ) - - -class GridDimensionSemantics(enum.Enum): - PARALLEL = "parallel" - ARBITRARY = "arbitrary" -PARALLEL = GridDimensionSemantics.PARALLEL -ARBITRARY = GridDimensionSemantics.ARBITRARY diff --git a/jax/_src/pallas/mosaic/helpers.py b/jax/_src/pallas/mosaic/helpers.py index 76421cec3340..80bb4ef4abed 100644 --- a/jax/_src/pallas/mosaic/helpers.py +++ b/jax/_src/pallas/mosaic/helpers.py @@ -60,7 +60,7 @@ def _copy_start_or_wait(action, src_ref, dst_ref): def run_on_first_core(core_axis_name: str): """Runs a function on the first core in a given axis.""" - num_cores = jax.lax.psum(1, core_axis_name) + num_cores = jax.lax.axis_size(core_axis_name) if num_cores == 1: return lambda f: f() @@ -77,7 +77,7 @@ def _(): def core_barrier(sem, *, core_axis_name: str): """Synchronizes all cores in a given axis.""" - num_cores = jax.lax.psum(1, core_axis_name) + num_cores = jax.lax.axis_size(core_axis_name) core_id = jax.lax.axis_index(core_axis_name) @pl_helpers.when(num_cores > 1) @@ -88,8 +88,8 @@ def signal_core(i): # Don't signal ourself @pl_helpers.when(core_id != i) def _(): - plm_primitives.semaphore_signal(sem, 1, core_index=i) + pl_primitives.semaphore_signal(sem, 1, core_index=i) for i in range(num_cores): signal_core(i) - plm_primitives.semaphore_wait(sem, num_cores - 1) + pl_primitives.semaphore_wait(sem, num_cores - 1) diff --git a/jax/_src/pallas/mosaic/interpret.py b/jax/_src/pallas/mosaic/interpret.py index 1ad7be8154cd..de25d739c554 100644 --- a/jax/_src/pallas/mosaic/interpret.py +++ b/jax/_src/pallas/mosaic/interpret.py @@ -13,14 +13,15 @@ # limitations under the License. 
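# Editorial sketch (not part of the patch): the reworked `CompilerParams` above (renamed
# from `TPUCompilerParams`) coerces its sequence and mapping arguments into hashable
# forms so instances stay frozen. A rough usage sketch against the internal module
# touched by this patch; the flag name below is hypothetical, and the public re-export
# path may differ from this internal import.
from jax._src.pallas.mosaic import core as mosaic_core

params = mosaic_core.CompilerParams(
    dimension_semantics=["parallel", "arbitrary"],  # stored as a tuple
    allow_input_fusion=[True, False],               # stored as a tuple
    flags={"some_mosaic_flag": 1},                  # hypothetical flag; stored as a FrozenDict
)
assert isinstance(params.dimension_semantics, tuple)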
import collections -from collections.abc import Iterable, Sequence import dataclasses import enum import functools +import gc import itertools import math import threading -from typing import Any, Literal +from typing import Any, Literal, cast +from collections.abc import Callable import jax from jax import lax @@ -29,14 +30,16 @@ from jax._src.lax.control_flow import for_loop from jax._src import linear_util as lu from jax._src import source_info_util -from jax._src.pallas.mosaic import primitives as mosaic_primitives from jax._src.pallas.mosaic import core as mosaic_core +from jax._src.pallas.mosaic import primitives as mosaic_primitives +from jax._src.pallas.mosaic import verification from jax._src.pallas import core as pallas_core from jax._src.pallas import primitives from jax._src import pjit from jax._src.state import discharge as state_discharge from jax._src.state import indexing from jax._src.state import primitives as state_primitives +from jax._src.typing import Array from jax._src.util import ( safe_map, safe_zip, @@ -64,7 +67,7 @@ @dataclasses.dataclass(frozen=True) -class TPUInterpretParams: +class InterpretParams: """Parameters for Mosaic TPU interpret mode. Attributes: @@ -73,25 +76,46 @@ class TPUInterpretParams: is waiting on a DMA semaphore that will be signaled when the read or write is complete. Default: "on_wait". - detect_races: If True, a dynamic, happens-before race detector will be - used to detect data races during kernel interpretation. If any races are - detected, a message will be printed and `races.races_found` will be set - to True. + detect_races: If True, a dynamic, happens-before race detector will be used + to detect data races during kernel interpretation. If any races are + detected, a message will be printed and `races.races_found` will be set to + True. Default: False. skip_floating_point_ops: If True, operations that produce only floating point values will not be interpreted; instead, their results will be - replaced with arrays all of `jnp.inf`. Additionaly any floating point + replaced with arrays all of `jnp.inf`. Additionally any floating point operands to any operation will be replaced with (arrays of) `jnp.inf`. Default: False. - uninitialized_memory: If "nan", allocated buffers are initialized to - to contain all NaNs (or to their maximum possible value for integers). - If "zero", allocated buffers are initialized to all zeros. + uninitialized_memory: If "nan", allocated buffers are initialized to contain + all NaNs (or to their maximum possible value for integers). If "zero", + allocated buffers are initialized to all zeros. Default: "nan". + random_seed: Seed for random number generator used during interpretation. + Currently random numbers are used to randomize the grid coordinates along + dimensions with 'parallel' semantics. + Default: None. + grid_point_recorder: Callback that is invoked by the interpreter for each + grid point in the order in which the grid points are traversed. The + callback is invoked with two arguments: + - A tuple of grid coordinates. + - The local core ID of the core that is processing the grid point. + This callback is intended for inspecting + - the randomization of coordinates along grid dimensions with 'parallel' + semantics and + - the mapping of grid points to local (i.e. per-device) cores. + Default: None. + num_cores_per_device: The number of cores per device. + Default: 1. 
""" dma_execution_mode: Literal["eager", "on_wait"] = "on_wait" detect_races: bool = False skip_floating_point_ops: bool = False uninitialized_memory: Literal["nan", "zero"] = "nan" + random_seed: int | None = None + grid_point_recorder: ( + Callable[[tuple[np.int32, ...], np.int32], None] | None + ) = None + num_cores_per_device: int = 1 VectorClock = np.ndarray @@ -101,11 +125,12 @@ class TPUInterpretParams: # of DMAs. # # Instead, we use approximate vector clocks of fixed size. We assign each DMA -# a virtual device ID in the range [num_devices + 1, NUM_VIRTUAL_DEVICES] -- +# a virtual core ID in the range +# [num_devices*num_cores_per_device + 1, NUM_VIRTUAL_CORES], # and each operation of a DMA increments the corresponding coordinate in its -# vector clock. (So the "virtual" part of a vector clock is effectively -# counting, for each virtual device, the number of DMAs that happened-before -# the vector clock and were assigned to that virtual device.) +# vector clock. (So the "virtual" part of a vector clock is effectively +# counting, for each virtual core, the number of DMAs that happened-before +# the vector clock and were assigned to that virtual core.) # # If two approximate clocks are unordered, then their corresponding events are # not ordered by the happens-before relation. So this approximation will not @@ -114,11 +139,11 @@ class TPUInterpretParams: # clocks are ordered, and we will treat the corresponding events as ordered # by the happens-before relation, but the corresponding events are not # actually ordered. -NUM_VIRTUAL_DEVICES = 32 +NUM_VIRTUAL_CORES = 32 -def make_vector_clock(num_devices: int) -> VectorClock: - del num_devices - return np.zeros(NUM_VIRTUAL_DEVICES, dtype=np.int32) +def make_vector_clock(_: int) -> VectorClock: + del _ + return np.zeros(NUM_VIRTUAL_CORES, dtype=np.int32) def copy_vector_clock(x: VectorClock) -> VectorClock: if x is None: @@ -126,7 +151,7 @@ def copy_vector_clock(x: VectorClock) -> VectorClock: return x.copy() def update_vector_clock(x: VectorClock, y: VectorClock): - x[:] = np.maximum(x, y) + x[:] = np.maximum(x[:], y[:]) def lt(x: VectorClock, y: VectorClock) -> bool: return bool((x <= y).all() & (x < y).any()) @@ -134,11 +159,17 @@ def lt(x: VectorClock, y: VectorClock) -> bool: def ordered(x: VectorClock, y: VectorClock) -> bool: return lt(x, y) | lt(y, x) -def inc_vector_clock(x: VectorClock, device_id: int): - if device_id >= len(x): - raise ValueError(f'device_id={device_id} is out of range for x={x}') - assert device_id < len(x) - x[device_id] += 1 +def inc_vector_clock(x: VectorClock, global_core_id: int): + if global_core_id >= len(x): + raise ValueError(f'device_id={global_core_id} is out of range for x={x}') + assert global_core_id < len(x) + x[global_core_id] += 1 + +def _get_global_core_id(device_id, local_core_id): + """Computes the global core ID from the given device and local core ID.""" + device_id = int(device_id) + local_core_id = int(local_core_id) + return device_id * _get_shared_memory().num_cores_per_device + local_core_id class Semaphore: @@ -151,45 +182,45 @@ def __init__(self, semaphore_id=None): # easier to do when we're using single integer device IDs.) self.cv = threading.Condition() - self.counts = np.zeros(shared_memory.num_devices, dtype=np.int32) + self.counts = np.zeros(shared_memory.num_cores, dtype=np.int32) self.interpret_params = shared_memory.interpret_params if self.interpret_params.detect_races: # We associate a vector clock with each count in self.counts. 
Whenever # self.counts[i] is signaled, self.clocks[i] is updated with the vector - # clock of the signaling device. Whenever device i successfully waits on - # self.counts[i], the vector clock of device i is updated with + # clock of the signaling core. Whenever core i successfully waits on + # self.counts[i], the vector clock of core i is updated with # self.clocks[i]. # # TODO(jburnim): Model happens-before more precisely for the case where # semaphores are over-signaled. - self.clocks = [None] * shared_memory.num_devices + self.clocks = [None] * shared_memory.num_cores - def signal(self, inc, device_id, clock): - """Signal the semaphore on `device_id` by `inc`. + def signal(self, inc, global_core_id, clock): + """Signal the semaphore on `(device_id, core_id)` by `inc`. Args: inc: A positive integer. The amount by which to increment the semaphore on the target device. - device_id: The ID of the target device. + global_core_id: The ID of the target core. clock: The vector clock of the signaling device at the time of the signal. """ - device_id = int(device_id) + global_core_id = int(global_core_id) with self.cv: - self.counts[device_id] += inc + self.counts[global_core_id] += inc if self.interpret_params.detect_races: - if self.clocks[device_id] is None: - self.clocks[device_id] = copy_vector_clock(clock) + if self.clocks[global_core_id] is None: + self.clocks[global_core_id] = copy_vector_clock(clock) else: - update_vector_clock(self.clocks[device_id], clock) + update_vector_clock(self.clocks[global_core_id], clock) self.cv.notify_all() - def read(self, device_id): + def read(self, global_core_id): with self.cv: - return self.counts[device_id] + return self.counts[global_core_id] - def wait(self, value, device_id, *, is_dma=False): - device_id = int(device_id) + def wait(self, value, global_core_id, *, is_dma=False): + global_core_id = int(global_core_id) shared_memory = _get_shared_memory() # TODO(jburnim): @@ -200,14 +231,14 @@ def wait(self, value, device_id, *, is_dma=False): # Simple implementation for non-DMA semaphores. 
if not is_dma or (self.interpret_params.dma_execution_mode == "eager"): with self.cv: - while self.counts[device_id] < value: + while self.counts[global_core_id] < value: self.cv.wait() - self.counts[device_id] -= value + self.counts[global_core_id] -= value if self.interpret_params.detect_races: - clock = copy_vector_clock(self.clocks[device_id]) + clock = copy_vector_clock(self.clocks[global_core_id]) if self.interpret_params.detect_races: with shared_memory.lock: - update_vector_clock(shared_memory.clocks[device_id], clock) + update_vector_clock(shared_memory.clocks[global_core_id], clock) return # For DMA semaphores (when dma_execution_mode=='on_wait'), while our count @@ -221,15 +252,15 @@ def wait(self, value, device_id, *, is_dma=False): while True: clock = None with self.cv: - if self.counts[device_id] >= value: - self.counts[device_id] -= value + if self.counts[global_core_id] >= value: + self.counts[global_core_id] -= value if self.interpret_params.detect_races: - clock = copy_vector_clock(self.clocks[device_id]) + clock = copy_vector_clock(self.clocks[global_core_id]) else: return if clock is not None: with shared_memory.lock: - update_vector_clock(shared_memory.clocks[device_id], clock) + update_vector_clock(shared_memory.clocks[global_core_id], clock) return with shared_memory.lock: @@ -244,25 +275,32 @@ def wait(self, value, device_id, *, is_dma=False): with dma.lock: if dma.virtual_device_id is None: dma.virtual_device_id = np.random.randint( - shared_memory.num_devices, NUM_VIRTUAL_DEVICES) + shared_memory.num_devices, NUM_VIRTUAL_CORES) if dma.state == DmaState.STARTED: # Do the read. if self.interpret_params.detect_races: inc_vector_clock(dma.clock, dma.virtual_device_id) dma.data = get(dma.src_device_id, + dma.src_local_core_id, dma.src_memory_space, dma.src_buffer_id, dma.src_transforms, clock=copy_vector_clock(dma.clock), src_device_id=dma.id, + src_local_core_id=0, source_info=dma.source_info) if self.interpret_params.detect_races: inc_vector_clock(dma.clock, dma.virtual_device_id) if dma.src_sem is not None: data_size = dma.data.itemsize * dma.data.size dma.src_sem.signal( - data_size, device_id=dma.src_device_id, clock=dma.clock) + data_size, + global_core_id=_get_global_core_id( + dma.src_device_id, dma.src_local_core_id + ), + clock=dma.clock, + ) dma.state = DmaState.READ if dma.src_sem is self: @@ -276,18 +314,25 @@ def wait(self, value, device_id, *, is_dma=False): if self.interpret_params.detect_races: inc_vector_clock(dma.clock, dma.virtual_device_id) store(dma.dst_device_id, + dma.dst_local_core_id, dma.dst_memory_space, dma.dst_buffer_id, dma.dst_transforms, dma.data, clock=copy_vector_clock(dma.clock), src_device_id=dma.id, + src_local_core_id=0, source_info=dma.source_info) if self.interpret_params.detect_races: inc_vector_clock(dma.clock, dma.virtual_device_id) data_size = dma.data.itemsize * dma.data.size dma.dst_sem.signal( - data_size, device_id=dma.dst_device_id, clock=dma.clock) + data_size, + global_core_id=_get_global_core_id( + dma.dst_device_id, dma.dst_local_core_id + ), + clock=dma.clock, + ) dma.data = None dma.state = DmaState.COMPLETED @@ -303,10 +348,12 @@ class DMA: id: int src_device_id: int + src_local_core_id: int src_memory_space: int src_buffer_id: int src_transforms: tuple[Any, ...] dst_device_id: int + dst_local_core_id: int dst_memory_space: int dst_buffer_id: int dst_transforms: tuple[Any, ...] 
@@ -325,13 +372,14 @@ class DMA: @dataclasses.dataclass class RaceDetectionState: - num_devices: int + num_cores: int - # (memory_space, buffer_id, device_id) -> [(device_id, VectorClock, range)] + + # (memory_space, buffer_id, device_id, local_core_id) -> [(device_id, local_core_id, VectorClock, range)] reads: dict = dataclasses.field( default_factory=lambda: collections.defaultdict(list)) - # (memory_space, buffer_id, device_id) -> [(device_id, VectorClock, range)] + # (memory_space, buffer_id, device_id, local_core_id) -> [(device_id, local_core_id, VectorClock, range)] writes: dict = dataclasses.field( default_factory=lambda: collections.defaultdict(list)) @@ -373,7 +421,10 @@ def ranges_overlap(range1: tuple[slice | int, ...], return all(slices_overlap(r1, r2) for r1, r2 in itertools.zip_longest(range1, range2, fillvalue=slice(None))) -def check_read(device_id, clock, buffer_key, rnge, source_info=None): + +def check_read( + device_id, local_core_id, clock, buffer_key, rnge, source_info=None +): if source_info is not None: user_frame = source_info_util.summarize(source_info) else: @@ -382,24 +433,36 @@ def check_read(device_id, clock, buffer_key, rnge, source_info=None): with races.lock: writes = races.writes[buffer_key] num_writes = len(writes) - races.reads[buffer_key].append((device_id, clock, rnge, user_frame)) + races.reads[buffer_key].append( + (device_id, local_core_id, clock, rnge, user_frame) + ) for i in range(num_writes): - write_device_id, write_clock, write_range, write_frame = writes[i] + ( + write_device_id, + write_local_core_id, + write_clock, + write_range, + write_frame, + ) = writes[i] if ordered(write_clock, clock): continue if not ranges_overlap(rnge, write_range): continue # TODO(jburnim): When printing device IDs for reads/writes, distinguish # between real device IDs vs. DMA IDs. - print('RACE DETECTED\n' - f' read of {buffer_key}[{rnge}] from {device_id}, {user_frame}\n' - f' write of {buffer_key}[{write_range}] from {write_device_id}, {write_frame}') + print( + f'RACE DETECTED\n read of {buffer_key}[{rnge}] from {device_id},' + f' {local_core_id}, {user_frame}\n write of' + f' {buffer_key}[{write_range}] from {write_device_id},' + f' {write_local_core_id} {write_frame}' + ) with races.lock: races.races_found = True return -def check_write(device_id, clock, buffer_key, rnge, source_info=None): + +def check_write(device_id, local_core_id, clock, buffer_key, rnge, source_info=None): if source_info is not None: user_frame = source_info_util.summarize(source_info) else: @@ -410,37 +473,50 @@ def check_write(device_id, clock, buffer_key, rnge, source_info=None): reads = races.reads[buffer_key] num_writes = len(writes) num_reads = len(reads) - races.writes[buffer_key].append((device_id, clock, rnge, user_frame)) + races.writes[buffer_key].append((device_id, local_core_id, clock, rnge, user_frame)) # TODO(jburnim): For performance, we should also probably remove any # conflicting reads and writes that happened-before the current write. for i in range(num_writes): - write_device_id, write_clock, write_range, write_frame = writes[i] + ( + write_device_id, + write_local_core_id, + write_clock, + write_range, + write_frame, + ) = writes[i] if ordered(write_clock, clock): continue if not ranges_overlap(rnge, write_range): continue # TODO(jburnim): When printing device IDs for reads/writes, distinguish # between real device IDs vs. DMA IDs. 
- print('RACE DETECTED\n' - f' write of {buffer_key}[{rnge}] from {device_id}, {user_frame}\n' - f' write of {buffer_key}[{write_range}] from {write_device_id}, {write_frame}') + print( + f'RACE DETECTED\n write of {buffer_key}[{rnge}] from {device_id},' + f' {local_core_id}, {user_frame}\n write of' + f' {buffer_key}[{write_range}] from {write_device_id},' + f' {write_local_core_id}, {write_frame}' + ) with races.lock: races.races_found = True break for i in range(num_reads): - read_device_id, read_clock, read_range, read_frame = reads[i] + read_device_id, read_local_core_id, read_clock, read_range, read_frame = ( + reads[i] + ) if ordered(read_clock, clock): continue if not ranges_overlap(rnge, read_range): continue # TODO(jburnim): When printing device IDs for reads/writes, distinguish # between real device IDs vs. DMA IDs. - print('RACE DETECTED\n' - f' write of {buffer_key}[{rnge}] from {device_id}, {user_frame}\n' - f' read of {buffer_key}[{read_range}] from {read_device_id}, {read_frame}') + print( + f'RACE DETECTED\n write of {buffer_key}[{rnge}] from {device_id},' + f' {local_core_id}, {user_frame}\n read of {buffer_key}[{read_range}]' + f' from {read_device_id}, {read_local_core_id}, {read_frame}' + ) with races.lock: races.races_found = True return @@ -448,14 +524,15 @@ def check_write(device_id, clock, buffer_key, rnge, source_info=None): @dataclasses.dataclass class SharedMemory: - interpret_params: TPUInterpretParams + interpret_params: InterpretParams num_devices: int + num_cores_per_device: int clocks: list[VectorClock] barrier: threading.Barrier + clean_up_barrier: threading.Barrier - # (memory_space, buffer_id, device_id) -> NumPy array - # TODO(jburnim): Handle Megacore. - mem: dict[tuple[int, int, int], np.ndarray] = dataclasses.field( + # (memory_space, buffer_id, device_id, local_core_id) -> NumPy array + mem: dict[tuple[str, int, int, int], np.ndarray] = dataclasses.field( default_factory=dict) # semaphore_id -> Semaphore @@ -463,20 +540,29 @@ class SharedMemory: # (semaphore_id, device_id) # -> list of DMAs that will signal the semaphore on the given device + # TODO(jburnim): Fix uses of `dmas_by_sem` to align with the two lines of + # documentation above, i.e. index `dmas_by_sem` with + # `(semaphore_id, device_id)` (currently indexed with `semaphore_id only). dmas_by_sem: dict[tuple[int, int], list[DMA]] = dataclasses.field( default_factory=lambda: collections.defaultdict(list)) lock: threading.Lock = dataclasses.field(default_factory=threading.Lock) - # device_id -> next buffer ID - next_buffer_id: dict[int, int] = dataclasses.field( + # (device_id, local_core_id) -> next buffer ID + next_buffer_id: dict[tuple[int, int], int] = dataclasses.field( default_factory=lambda: collections.defaultdict(lambda: 100)) - # device_id -> next semaphore ID + # global_core_id -> next semaphore ID next_semaphore_id: dict[int, int] = dataclasses.field( default_factory=lambda: collections.defaultdict(lambda: 2000)) next_dma_id: int = 100 + deallocated_bytes: int = 0 + + @property + def num_cores(self) -> int: + return self.num_devices * self.num_cores_per_device + # TODO(jburnim): Do we want to support multiple instances of SharedMemory? # Maybe for running multiple distinct interpreted computations in parallel? 
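# Editorial sketch (not part of the patch): the race detector above only reports a race
# when two accesses' vector clocks are unordered under the `lt`/`ordered` predicates
# defined earlier in this file (clocks have NUM_VIRTUAL_CORES = 32 entries in this
# patch). A minimal standalone illustration of that check:
import numpy as np

def lt(x, y):
  return bool((x <= y).all() & (x < y).any())

a = np.zeros(32, dtype=np.int32)
b = a.copy(); b[3] += 1   # b observed one more event on virtual core 3
c = a.copy(); c[4] += 1   # c observed one more event on virtual core 4
assert lt(a, b)                    # ordered: a happens-before b, no race reported
assert not (lt(b, c) or lt(c, b))  # unordered: reported as a candidate race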
@@ -493,100 +579,257 @@ def _clear_shared_memory(): with _shared_memory_init_lock: _shared_memory = None -def _initialize_shared_memory(device_id, num_devices, *, interpret_params): + +def _initialize_shared_memory( + device_id, num_devices, num_cores_per_device, *, interpret_params +): global _shared_memory del device_id num_devices = int(num_devices) + num_cores_per_device = int(num_cores_per_device) + num_cores = num_devices * num_cores_per_device with _shared_memory_init_lock: if _shared_memory is None: _shared_memory = SharedMemory( interpret_params=interpret_params, num_devices=num_devices, - clocks=[make_vector_clock(num_devices) for _ in range(num_devices)], - barrier=threading.Barrier(num_devices)) - assert _shared_memory.num_devices == num_devices + num_cores_per_device=num_cores_per_device, + clocks=[make_vector_clock(num_cores) for _ in range(num_cores)], + barrier=threading.Barrier( + num_devices, action=_update_clocks_for_global_barrier), + clean_up_barrier=threading.Barrier( + num_devices, action=_clear_shared_memory)) + assert _shared_memory.num_cores == num_cores global races - races = RaceDetectionState(num_devices=num_devices) + races = RaceDetectionState(num_cores=num_cores) + +def _update_clocks(low_global_core_id, high_global_core_id): + """Synchronizes the vector clocks for the cores with ids in the range between the two arguments.""" + shared_memory = _get_shared_memory() + # Despite only updating the vector clocks for some cores, we still need to + # hold the global lock to ensure that no other devices are concurrently + # accessing the same vector clocks. + with shared_memory.lock: + for c in shared_memory.clocks[low_global_core_id + 1 : high_global_core_id]: + update_vector_clock(shared_memory.clocks[low_global_core_id], c) + for c in shared_memory.clocks[low_global_core_id + 1 : high_global_core_id]: + update_vector_clock(c, shared_memory.clocks[low_global_core_id]) + +def _update_clocks_for_device_barrier(device_id): + """Synchronizes the vector clocks for the cores on the given device.""" + shared_memory = _get_shared_memory() + low_core_id = device_id * shared_memory.num_cores_per_device + high_core_id = (device_id + 1) * shared_memory.num_cores_per_device + _update_clocks(low_core_id, high_core_id) + +def _update_clocks_for_global_barrier(): + """Synchronizes all vector clocks.""" + shared_memory = _get_shared_memory() + _update_clocks(0, shared_memory.num_cores) + +def _barrier(device_id): + device_id = int(device_id) + shared_memory = _get_shared_memory() + if shared_memory.num_devices > 1: + shared_memory.barrier.wait() def _clean_up_shared_memory(device_id): device_id = int(device_id) shared_memory = _get_shared_memory() - shared_memory.barrier.wait() - if device_id == 0: - _clear_shared_memory() + shared_memory.clean_up_barrier.wait() def _validate(device_id): device_id = int(device_id) shared_memory = _get_shared_memory() + local_core_ids = tuple(range(shared_memory.num_cores_per_device)) with shared_memory.lock: for sem in shared_memory.sem.values(): with sem.cv: - if sem.counts[device_id] != 0: - # TODO(jburnim): Make this raise an error, but in a way that doesn't - # cause other devices to hang later in `_clean_up_shared_memory`. 
- print( - f'Semaphore {sem.id} has non-zero count for {device_id} at ' - f'kernel exit: {sem.counts[device_id]}') - -def _allocate_buffer(device_id, memory_space, val): + for lci in local_core_ids: + global_core_id = _get_global_core_id(device_id, lci) + if sem.counts[global_core_id] != 0: + # TODO(jburnim): Make this raise an error, but in a way that doesn't + # cause other devices to hang later in `_clean_up_shared_memory`. + print( + f'Semaphore {sem.id} has non-zero count for {device_id} ' + f' (core {lci}) at kernel exit: {sem.counts[global_core_id]}') + +def _allocate_buffer( + device_id: Array, + local_core_id: Array | None, + memory_space: Array, + val: Array, +): + """Allocates a memory buffer on the device with id `device_id` and core with id `local_core_id`. + + Args: + device_id: Singleton array holding the device id where the buffer will be + allocated. + local_core_id: None or singleton array holding the core id where the buffer + will be allocated. If None, a buffer will be allocated on each core on + the device. + memory_space: Singleton array indicating the memory space to allocate the + buffer in. If the corresponding memory space is "any" (i.e. HBM), at most + one buffer will be allocated and it will belong to (local) core id 0. + val: Array of values to initialize the allocated buffer with. + + Returns: + Integer id for the allocated buffer. + """ device_id = int(device_id) - memory_space = TPU_MEMORY_SPACE_NAMES[int(memory_space)] + memory_space_str = TPU_MEMORY_SPACE_NAMES[int(memory_space)] + del memory_space val = np.array(val) shared_memory = _get_shared_memory() + + if local_core_id is None: + local_core_id_int = 0 + local_core_ids = tuple(range(shared_memory.num_cores_per_device)) + else: + local_core_id_int = int(local_core_id) + local_core_ids = (local_core_id_int,) + del local_core_id + + local_core_id_to_buffer_id : dict[int, int] = {} with shared_memory.lock: - buffer_id = shared_memory.next_buffer_id[device_id] - shared_memory.next_buffer_id[device_id] = buffer_id + 1 - # TODO(jburnim): Add options for initializing memory (e.g., with NaNs, - # with zeros, or with the buffer ID). - shared_memory.mem[(memory_space, buffer_id, device_id)] = val + for lci in local_core_ids: + buffer_id = shared_memory.next_buffer_id[(device_id, lci)] + shared_memory.next_buffer_id[(device_id, lci)] = buffer_id + 1 + # If allocating in HBM, only actually allocate a buffer for core 0. + if lci == 0 or memory_space_str != 'any': + # If we are allocating more than one buffer, we must make additional + # copies of `val` so that each buffer is a distinct ndarray. + if len(local_core_id_to_buffer_id) > 0: + val = val.copy() + shared_memory.mem[(memory_space_str, buffer_id, device_id, lci)] = val + + local_core_id_to_buffer_id[lci] = buffer_id + + # The buffer ids should always be kept in sync across all cores. + assert all( + buffer_id == local_core_id_to_buffer_id[local_core_id_int] + for buffer_id in local_core_id_to_buffer_id.values() + ) # TODO(jburnim): Raise an error if buffer_id is too big for int16.
- return np.int16(buffer_id) + return np.int16(local_core_id_to_buffer_id[local_core_id_int]) -def _deallocate_buffer(device_id, memory_space, buffer_id): +def _deallocate_buffer(device_id, local_core_id, memory_space, buffer_id): device_id = int(device_id) + local_core_id = int(local_core_id) memory_space = TPU_MEMORY_SPACE_NAMES[int(memory_space)] buffer_id = int(buffer_id) + if memory_space == 'any': + local_core_id = 0 + shared_memory = _get_shared_memory() with shared_memory.lock: - # TODO(jburnim): Error if buffer doesn't exist? - shared_memory.mem.pop((memory_space, buffer_id, device_id), None) + buff = shared_memory.mem.pop( + (memory_space, buffer_id, device_id, local_core_id) + ) + shared_memory.deallocated_bytes += buff.size * buff.itemsize + del buff + + should_collect = shared_memory.deallocated_bytes > 100_000_000 + if should_collect: + shared_memory.deallocated_bytes = 0 + + if should_collect: + # Periodic garbage collection here prevents OOMs -- although it's not clear + # why arrays are not getting freed without this. + gc.collect() + + +def _allocate_semaphores( + device_id: Array, local_core_id: Array | None, shape: Array +): + """Allocates semaphores on the device with id `device_id` and core with id `local_core_id`. + + The number of semaphores allocated is given by the product of the entries in + `shape`. + + Since for each semaphore id there is really only one global `Semaphore` + object, 'allocation' of semaphores per device and core here means that the + internal counter of semaphore ids that is held by `SharedMemory` is + incremented for the device and core (or for all cores on the device if + argument `local_core_id` is None, see below). + + Args: + device_id: Singleton array holding the id for the device where the + semaphores will be allocated. + local_core_id: None or singleton array holding the id for the core where the + semaphores will be allocated. If None, semaphores will be allocated on all + cores on the device. + shape: Shape of the semaphore array to allocate. -def _allocate_semaphores(device_id, shape): + Returns: + Array of semaphore ids. + """ device_id = int(device_id) shape = tuple(map(int, shape)) num_semaphores = math.prod(shape) shared_memory = _get_shared_memory() + + if local_core_id is None: + local_core_id_int = 0 + global_core_ids = tuple( + _get_global_core_id(device_id, core_id) + for core_id in range(shared_memory.num_cores_per_device) + ) + else: + local_core_id_int = int(local_core_id) + global_core_ids = (_get_global_core_id(device_id, local_core_id_int),) + del local_core_id + + global_core_id_to_semaphore_id = {} with shared_memory.lock: - semaphore_id = shared_memory.next_semaphore_id[device_id] - shared_memory.next_semaphore_id[device_id] = semaphore_id + num_semaphores - for i in range(semaphore_id, semaphore_id + num_semaphores): - if i not in shared_memory.sem: - shared_memory.sem[i] = Semaphore(i) + for gci in global_core_ids: + semaphore_id = shared_memory.next_semaphore_id[gci] + shared_memory.next_semaphore_id[gci] = ( + semaphore_id + num_semaphores + ) + + # Ensure that only one global `Semaphore` object is allocated for each + # `semaphore_id`. + for i in range(semaphore_id, semaphore_id + num_semaphores): + if i not in shared_memory.sem: + shared_memory.sem[i] = Semaphore(i) + + global_core_id_to_semaphore_id[gci] = semaphore_id + + global_core_id = _get_global_core_id(device_id, local_core_id_int) + # The semaphore ids should always be kept in sync across all cores.
+ assert all( + semaphore_id == global_core_id_to_semaphore_id[global_core_id] + for semaphore_id in global_core_id_to_semaphore_id.values() + ) # NOTE: For now, we use a relatively uncommon datatype (int16) for # semaphore (and buffer) IDs, so these values are more easily identifiable # in kernels. # # TODO(jburnim): Raise an error if any IDs are too big for int16. - return np.int16( - range(semaphore_id, semaphore_id + num_semaphores) + semaphore_id = global_core_id_to_semaphore_id[global_core_id] + return np.arange( + semaphore_id, semaphore_id + num_semaphores, dtype=np.int16 ).reshape(shape) -TPU_MEMORY_SPACE_IDXS : dict[mosaic_core.TPUMemorySpace | None, int] = { - v: i for i, v in enumerate(mosaic_core.TPUMemorySpace)} +TPU_MEMORY_SPACE_IDXS : dict[mosaic_core.MemorySpace | pallas_core.MemorySpace | None, int] = { + v: i for i, v in enumerate(mosaic_core.MemorySpace)} +TPU_MEMORY_SPACE_IDXS[pallas_core.MemorySpace.ANY] = ( + TPU_MEMORY_SPACE_IDXS[mosaic_core.MemorySpace.ANY]) TPU_MEMORY_SPACE_NAMES = { - i: v.value for i, v in enumerate(mosaic_core.TPUMemorySpace)} + i: v.value for i, v in enumerate(mosaic_core.MemorySpace)} # Default to VMEM when no memory space is specified. TPU_MEMORY_SPACE_IDXS[None] = ( - TPU_MEMORY_SPACE_IDXS[mosaic_core.TPUMemorySpace.VMEM]) + TPU_MEMORY_SPACE_IDXS[mosaic_core.MemorySpace.VMEM]) def get_barrier_semaphore(device_id, collective_id): del device_id @@ -647,24 +890,48 @@ def _to_range(transforms) -> tuple[slice | int, ...]: ret, tuple(_transform_slice_or_index(i) for i in transform.indices)) return ret -def get(device_id, memory_space, buffer_id, transforms, *, - src_device_id=None, clock=None, source_info=None): +def _to_int(x : int | Array | None) -> int | None: + """Converts a value to an integer, or returns None if the value is None.""" + if x is None: + return None + return int(x) + +def get( + device_id, + local_core_id, + memory_space, + buffer_id, + transforms, + *, + src_device_id=None, + src_local_core_id=None, + clock=None, + source_info=None, +): device_id = int(device_id) + local_core_id = int(local_core_id) memory_space = TPU_MEMORY_SPACE_NAMES[int(memory_space)] buffer_id = int(buffer_id) try: transforms = jax.tree.map(int, transforms) except: raise ValueError('Advanced indexers are not supported on TPU') + src_device_id = _to_int(src_device_id) + src_local_core_id = _to_int(src_local_core_id) + + local_core_id_for_buffer = 0 if memory_space == 'any' else local_core_id + global_core_id = _get_global_core_id(device_id, local_core_id) shared_memory = _get_shared_memory() with shared_memory.lock: read_range = _to_range(transforms) if shared_memory.interpret_params.detect_races: - inc_vector_clock(shared_memory.clocks[device_id], device_id) + inc_vector_clock(shared_memory.clocks[global_core_id], global_core_id) if clock is None: - clock = copy_vector_clock(shared_memory.clocks[device_id]) - buffer = shared_memory.mem[(memory_space, buffer_id, device_id)] + clock = copy_vector_clock(shared_memory.clocks[global_core_id]) + buffer = shared_memory.mem[ + (memory_space, buffer_id, device_id, local_core_id_for_buffer) + ] ret = buffer[read_range].copy() if transforms: # TODO(jburnim): Instead of using NDIndexer, do the computation ourselves @@ -672,20 +939,43 @@ def get(device_id, memory_space, buffer_id, transforms, *, expected_shape = transforms[-1].get_indexer_shape() if expected_shape != ret.shape[:len(expected_shape)]: raise ValueError( - f'Out-of-bounds read of ({device_id} {memory_space} {buffer_id}): ' - f'reading [{read_range}] but 
bufer has shape {buffer.shape} .') + 'Out-of-bounds read of' + f' ({device_id} {local_core_id} {memory_space} {buffer_id}):' + f' reading [{read_range}] but buffer has shape {buffer.shape} .' + ) if shared_memory.interpret_params.detect_races: if src_device_id is None: src_device_id = device_id - check_read(src_device_id, clock, (memory_space, buffer_id, device_id), - read_range, source_info=source_info) + if src_local_core_id is None: + src_local_core_id = local_core_id + check_read( + src_device_id, + src_local_core_id, + clock, + (memory_space, buffer_id, device_id, local_core_id_for_buffer), + read_range, + source_info=source_info, + ) return ret -def store(device_id, memory_space, buffer_id, transforms, val, *, - src_device_id=None, clock=None, source_info=None): + +def store( + device_id, + local_core_id, + memory_space, + buffer_id, + transforms, + val, + *, + src_device_id=None, + src_local_core_id=None, + clock=None, + source_info=None, +): device_id = int(device_id) + local_core_id = int(local_core_id) memory_space = TPU_MEMORY_SPACE_NAMES[int(memory_space)] buffer_id = int(buffer_id) try: @@ -693,34 +983,62 @@ def store(device_id, memory_space, buffer_id, transforms, val, *, except: raise ValueError('Advanced indexers are not supported on TPU') val = np.array(val) + src_device_id = _to_int(src_device_id) + src_local_core_id = _to_int(src_local_core_id) + + local_core_id_for_buffer = 0 if memory_space == 'any' else local_core_id + global_core_id = _get_global_core_id(device_id, local_core_id) shared_memory = _get_shared_memory() with shared_memory.lock: if shared_memory.interpret_params.detect_races: - inc_vector_clock(shared_memory.clocks[device_id], device_id) + inc_vector_clock(shared_memory.clocks[global_core_id], global_core_id) if clock is None: - clock = copy_vector_clock(shared_memory.clocks[device_id]) + clock = copy_vector_clock(shared_memory.clocks[global_core_id]) - buff = shared_memory.mem[(memory_space, buffer_id, device_id)] + buff = shared_memory.mem[ + (memory_space, buffer_id, device_id, local_core_id_for_buffer) + ] assert buff.dtype == val.dtype # TODO(jburnim): Catch this statically. write_range = _to_range(transforms) # TODO(jburnim): Better error message if this raises? in_bounds_shape = buff[write_range].shape if in_bounds_shape != val.shape: raise ValueError( - f'Out-of-bounds write of ({device_id} {memory_space} {buffer_id}): ' - f'writing [{write_range}] but buffer has shape {buff.shape} .') + 'Out-of-bounds write of' + f' ({device_id} {local_core_id} {memory_space} {buffer_id}): writing' + f' [{write_range}] but buffer has shape {buff.shape} .' 
+ ) buff[write_range] = val if shared_memory.interpret_params.detect_races: if src_device_id is None: src_device_id = device_id - check_write(src_device_id, clock, (memory_space, buffer_id, device_id), - write_range, source_info=source_info) + if src_local_core_id is None: + src_local_core_id = local_core_id + check_write( + src_device_id, + src_local_core_id, + clock, + (memory_space, buffer_id, device_id, local_core_id_for_buffer), + write_range, + source_info=source_info, + ) -def swap(device_id, memory_space, buffer_id, transforms, val, mask, *, - source_info=None): + +def swap( + device_id, + local_core_id, + memory_space, + buffer_id, + transforms, + val, + mask, + *, + source_info=None, +): device_id = int(device_id) + local_core_id = int(local_core_id) memory_space = TPU_MEMORY_SPACE_NAMES[int(memory_space)] buffer_id = int(buffer_id) try: @@ -732,12 +1050,17 @@ def swap(device_id, memory_space, buffer_id, transforms, val, mask, *, if mask is not None: assert mask.shape == val.shape + local_core_id_for_buffer = 0 if memory_space == 'any' else local_core_id + global_core_id = _get_global_core_id(device_id, local_core_id) + shared_memory = _get_shared_memory() with shared_memory.lock: if shared_memory.interpret_params.detect_races: - inc_vector_clock(shared_memory.clocks[device_id], device_id) - clock = copy_vector_clock(shared_memory.clocks[device_id]) - buff = shared_memory.mem[(memory_space, buffer_id, device_id)] + inc_vector_clock(shared_memory.clocks[global_core_id], global_core_id) + clock = copy_vector_clock(shared_memory.clocks[global_core_id]) + buff = shared_memory.mem[ + (memory_space, buffer_id, device_id, local_core_id_for_buffer) + ] assert buff.dtype == val.dtype # TODO(jburnim): Catch this statically. read_write_range = _to_range(transforms) # TODO(jburnim): Better error message if this raises? @@ -746,8 +1069,11 @@ def swap(device_id, memory_space, buffer_id, transforms, val, mask, *, if mask is None: if in_bounds_shape != val.shape: raise ValueError( - f'Out-of-bounds swap of ({device_id} {memory_space} {buffer_id}): ' - f'swapping [{read_write_range}] but buffer has shape {buff.shape} .') + 'Out-of-bounds swap of' + f' ({device_id} {local_core_id} {memory_space} {buffer_id}):' + f' swapping [{read_write_range}] but buffer has shape' + f' {buff.shape} .' + ) buff[read_write_range] = val return raw_result.copy() @@ -758,8 +1084,10 @@ def swap(device_id, memory_space, buffer_id, transforms, val, mask, *, # TODO(jburnim): Include indices of out-of-bounds locations where mask # is True. raise ValueError( - f'Out-of-bounds masked swap of ({device_id} {memory_space} {buffer_id}): ' - f'swapping [{read_write_range}] but buffer has shape {buff.shape} . ') + 'Out-of-bounds masked swap of' + f' ({device_id} {local_core_id} {memory_space} {buffer_id}): swapping' + f' [{read_write_range}] but buffer has shape {buff.shape} . 
' + ) in_bounds_idx = tuple(slice(i) for i in in_bounds_shape) result = val.copy() @@ -769,8 +1097,14 @@ def swap(device_id, memory_space, buffer_id, transforms, val, mask, *, mask[in_bounds_idx], val[in_bounds_idx], raw_result) if shared_memory.interpret_params.detect_races: - check_write(device_id, clock, (memory_space, buffer_id, device_id), - read_write_range, source_info=source_info) + check_write( + device_id, + local_core_id, + clock, + (memory_space, buffer_id, device_id, local_core_id_for_buffer), + read_write_range, + source_info=source_info, + ) return result def execute_dma(dma): @@ -782,17 +1116,19 @@ def execute_dma(dma): if dma.virtual_device_id is None: # See comment in Semaphore.wait . dma.virtual_device_id = np.random.randint( - shared_memory.num_devices, NUM_VIRTUAL_DEVICES) + shared_memory.num_cores, NUM_VIRTUAL_CORES) # Do the read. if shared_memory.interpret_params.detect_races: inc_vector_clock(dma.clock, dma.virtual_device_id) dma.data = get(dma.src_device_id, + dma.src_local_core_id, dma.src_memory_space, dma.src_buffer_id, dma.src_transforms, clock=copy_vector_clock(dma.clock), src_device_id=dma.id, + src_local_core_id=0, source_info=dma.source_info) data_size = dma.data.itemsize * dma.data.size @@ -801,19 +1137,26 @@ def execute_dma(dma): inc_vector_clock(dma.clock, dma.virtual_device_id) if dma.src_sem is not None: dma.src_sem.signal( - data_size, device_id=dma.src_device_id, clock=dma.clock) + data_size, + global_core_id=_get_global_core_id( + dma.src_device_id, dma.src_local_core_id + ), + clock=dma.clock, + ) dma.state = DmaState.READ # Do the write. if shared_memory.interpret_params.detect_races: inc_vector_clock(dma.clock, dma.virtual_device_id) store(dma.dst_device_id, + dma.dst_local_core_id, dma.dst_memory_space, dma.dst_buffer_id, dma.dst_transforms, dma.data, clock=copy_vector_clock(dma.clock), src_device_id=dma.id, + src_local_core_id=0, source_info=dma.source_info) # Signal the receive semaphore. 
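As a reading aid for the semaphore changes in this and the surrounding hunks: signal and wait are now keyed by global core id rather than by device id, so each simulated core maintains its own count per semaphore. The toy class below sketches only that bookkeeping; the real Semaphore in this file additionally threads vector clocks through signal/wait for race detection, and the name here is illustrative.

import threading

class ToyPerCoreSemaphore:
  """Sketch of per-core semaphore counts; not the Semaphore class from this file."""

  def __init__(self):
    self.cv = threading.Condition()
    self.counts = {}  # global_core_id -> current count

  def signal(self, inc, global_core_id):
    with self.cv:
      self.counts[global_core_id] = self.counts.get(global_core_id, 0) + inc
      self.cv.notify_all()

  def wait(self, value, global_core_id):
    with self.cv:
      # Block until this core's count reaches `value`, then decrement it.
      self.cv.wait_for(lambda: self.counts.get(global_core_id, 0) >= value)
      self.counts[global_core_id] -= value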
@@ -821,7 +1164,12 @@ def execute_dma(dma): inc_vector_clock(dma.clock, dma.virtual_device_id) if dma.dst_sem is not None: dma.dst_sem.signal( - data_size, device_id=dma.dst_device_id, clock=dma.clock) + data_size, + global_core_id=_get_global_core_id( + dma.dst_device_id, dma.dst_local_core_id + ), + clock=dma.clock, + ) dma.data = None dma.state = DmaState.COMPLETED @@ -833,11 +1181,24 @@ def print_memory(device_id): with shared_memory.lock: print(shared_memory.mem) -def dma_start(device_id, src_memory_space, src_id, src_transforms, - dst_memory_space, dst_id, dst_transforms, - dst_sem_id, src_sem_id, dst_device_id, - source_info=None): + +def dma_start( + device_id, + src_local_core_id, + src_memory_space, + src_id, + src_transforms, + dst_memory_space, + dst_id, + dst_transforms, + dst_sem_id, + src_sem_id, + dst_device_id, + source_info=None, +): device_id = int(device_id) + src_local_core_id = int(src_local_core_id) + src_global_core_id = _get_global_core_id(device_id, src_local_core_id) src_memory_space, src_id = int(src_memory_space), int(src_id) src_transforms = jax.tree.map(int, src_transforms) dst_memory_space, dst_id = int(dst_memory_space), int(dst_id) @@ -856,15 +1217,25 @@ def dma_start(device_id, src_memory_space, src_id, src_transforms, clock = None if shared_memory.interpret_params.detect_races: - inc_vector_clock(shared_memory.clocks[device_id], device_id) - clock = copy_vector_clock(shared_memory.clocks[device_id]) + inc_vector_clock( + shared_memory.clocks[src_global_core_id], src_global_core_id + ) + clock = copy_vector_clock(shared_memory.clocks[src_global_core_id]) dma_id = shared_memory.next_dma_id shared_memory.next_dma_id += 1 dma = DMA( dma_id, - device_id, src_memory_space, src_id, src_transforms, - dst_device_id, dst_memory_space, dst_id, dst_transforms, + device_id, + src_local_core_id, + src_memory_space, + src_id, + src_transforms, + dst_device_id, + src_local_core_id, # Same core on destination device as on source. 
+ dst_memory_space, + dst_id, + dst_transforms, src_sem, dst_sem, clock=clock, @@ -880,52 +1251,61 @@ def dma_start(device_id, src_memory_space, src_id, src_transforms, assert shared_memory.interpret_params.dma_execution_mode == 'eager' execute_dma(dma) -def dma_wait(device_id, sem_id, size): +def dma_wait(device_id, local_core_id, sem_id, size): device_id = int(device_id) + local_core_id = int(local_core_id) sem_id = int(sem_id) size = int(size) + global_core_id = _get_global_core_id(device_id, local_core_id) shared_memory = _get_shared_memory() with shared_memory.lock: if shared_memory.interpret_params.detect_races: - inc_vector_clock(shared_memory.clocks[device_id], device_id) + inc_vector_clock(shared_memory.clocks[global_core_id], global_core_id) sem = shared_memory.sem[sem_id] - sem.wait(size, device_id, is_dma=True) + sem.wait(size, global_core_id, is_dma=True) -def semaphore_signal(device_id, sem_id, inc, target_device_id, - target_core_index): +def semaphore_signal(device_id, local_core_id, sem_id, inc, target_device_id, + target_local_core_id): device_id = int(device_id) + local_core_id = int(local_core_id) sem_id = int(sem_id) inc = int(inc) + src_global_core_id = _get_global_core_id(device_id, local_core_id) if target_device_id is None: target_device_id = device_id else: target_device_id = int(target_device_id) - if target_core_index is not None: - if int(target_core_index) != 0: - raise NotImplementedError('semaphore_signal with target_core_index != 0') + if target_local_core_id is None: + target_local_core_id = 0 shared_memory = _get_shared_memory() with shared_memory.lock: clock = None if shared_memory.interpret_params.detect_races: - inc_vector_clock(shared_memory.clocks[device_id], device_id) - clock = copy_vector_clock(shared_memory.clocks[device_id]) + inc_vector_clock( + shared_memory.clocks[src_global_core_id], src_global_core_id + ) + clock = copy_vector_clock(shared_memory.clocks[src_global_core_id]) sem = shared_memory.sem[sem_id] - sem.signal(inc, target_device_id, clock) + sem.signal( + inc, _get_global_core_id(target_device_id, target_local_core_id), clock + ) -def semaphore_wait(device_id, sem_id, value): +def semaphore_wait(device_id, local_core_id, sem_id, value): device_id = int(device_id) + local_core_id = int(local_core_id) sem_id = int(sem_id) value = int(value) + global_core_id = _get_global_core_id(device_id, local_core_id) shared_memory = _get_shared_memory() with shared_memory.lock: if shared_memory.interpret_params.detect_races: - inc_vector_clock(shared_memory.clocks[device_id], device_id) + inc_vector_clock(shared_memory.clocks[global_core_id], global_core_id) sem = shared_memory.sem[sem_id] - sem.wait(value, device_id) + sem.wait(value, global_core_id) def _compute_transformed_shape_and_dtype(shape, dtype, transforms): for transform in transforms: @@ -948,9 +1328,9 @@ def _device_coords_to_logical_id(device_coords, axis_sizes): def _device_id_to_logical(device_id, device_id_type, axis_sizes): if device_id is None: return None - if device_id_type == mosaic_primitives.DeviceIdType.MESH: + if device_id_type == primitives.DeviceIdType.MESH: return _device_coords_to_logical_id(device_id, axis_sizes) - elif device_id_type == mosaic_primitives.DeviceIdType.LOGICAL: + elif device_id_type == primitives.DeviceIdType.LOGICAL: return device_id else: raise ValueError(f'Unsupported device ID type: {device_id_type}') @@ -962,7 +1342,7 @@ def _to_jaxpr(flat_fun, in_avals): return new_jaxpr def _is_any(memory_space): - return ((memory_space == 
mosaic_core.TPUMemorySpace.ANY) or + return ((memory_space == mosaic_core.MemorySpace.ANY) or (memory_space == pallas_core.MemorySpace.ANY)) def _is_float(dtype): @@ -976,7 +1356,18 @@ class Placeholder: shape: tuple[int, ...] dtype: jnp.dtype -def _interpret_jaxpr(jaxpr, *args, compiler_params, interpret_params): + +def _interpret_jaxpr( + jaxpr, + *args, + axis_sizes, + mesh, + axis_indices, + device_id, + local_core_id, + compiler_params, + interpret_params +): env = {} def read(var): @@ -993,23 +1384,21 @@ def write(var, value): value = Placeholder(value.shape, value.dtype) env[var] = value - jax.util.safe_map(write, jaxpr.constvars + jaxpr.invars, args) - - # Get the device ID. - axis_sizes = jax_core.get_axis_env().axis_sizes - device_id = _device_coords_to_logical_id( - tuple(lax.axis_index(s) for s in axis_sizes.keys()), - axis_sizes) - # TODO(jburnim): Pass the device ID around, instead of re-fetching/computing - # it for each sub-jaxpr. + jax._src.util.safe_map(write, jaxpr.constvars + jaxpr.invars, args) # TODO(jburnim): Clean up and finish this evaluation loop. For example: # - Replace the big if-statement with a dictionary of rules. # - Handle other higher-order primitives? - # - Megacore. _interpret = functools.partial( - _interpret_jaxpr, compiler_params=compiler_params, - interpret_params=interpret_params) + _interpret_jaxpr, + axis_sizes=axis_sizes, + mesh=mesh, + axis_indices=axis_indices, + device_id=device_id, + local_core_id=local_core_id, + compiler_params=compiler_params, + interpret_params=interpret_params, + ) for eqn in jaxpr.eqns: with source_info_util.user_context( eqn.source_info.traceback, name_stack=eqn.source_info.name_stack): @@ -1019,7 +1408,9 @@ def write(var, value): # not need to do any reads if `interpret_params.skip_floating_point_ops` # is True. If this is the case, we want to avoid materializing the read # array into the jaxpr when this function is traced. - deferred_invals = functools.partial(jax.util.safe_map, read, eqn.invars) + deferred_invals = functools.partial( + jax._src.util.safe_map, read, eqn.invars + ) if prim is primitives.load_p: (ref, transforms, mask, _) = jax.tree.unflatten( @@ -1030,6 +1421,7 @@ def write(var, value): functools.partial(get, source_info=eqn.source_info), eqn.outvars[0].aval, device_id, + local_core_id, TPU_MEMORY_SPACE_IDXS[eqn.invars[0].aval.memory_space], ref, transforms, @@ -1042,6 +1434,7 @@ def write(var, value): functools.partial(swap, source_info=eqn.source_info), eqn.outvars[0].aval, device_id, + local_core_id, TPU_MEMORY_SPACE_IDXS[eqn.invars[0].aval.memory_space], ref, transforms, @@ -1050,8 +1443,36 @@ def write(var, value): ordered=True) elif prim is mosaic_primitives.delay_p: + # TODO(jburnim): Implement this properly? out = [] + elif prim is mosaic_primitives.prng_seed_p: + # TODO(jburnim): Implement this properly? + out = [] + + elif prim is mosaic_primitives.prng_random_bits_p: + # TODO(jburnim): Implement this properly? + out = jnp.zeros(eqn.params['shape'], jnp.int32) + + elif prim is verification.assume_p: + out = read(eqn.invars[0]) + + elif prim is verification.pretend_p: + out = [] + + elif ((prim is lax.axis_index_p) + and (mesh is not None) and (eqn.params['axis_name'] in mesh.shape)): + # We are interpreting a core_map, and this lax.axis_index call is + # querying our index along the core axis, so return our core ID. 
+ out = local_core_id + + elif ((prim is lax.axis_index_p) + and (eqn.params['axis_name'] in axis_indices)): + # We replace lax.axis_index calls in the kernel body, so that the + # kernel body jaxpr can be run on other threads (via an io_callback) + # without having to recreate the axis environment in those threads. + out = axis_indices[eqn.params['axis_name']] + elif prim is lax.cond_p: def _make_branch(jaxpr): return lambda *args: _interpret(jaxpr, *args) @@ -1102,15 +1523,21 @@ def f(*args, jaxpr): out = pjit.pjit_p.bind(*invals, **(eqn.params | {'jaxpr': new_jaxpr})) elif prim is primitives.run_scoped_p: + if eqn.params['collective_axes']: + raise NotImplementedError( + 'run_scoped_p with collective axes is not supported' + ) # Allocate a buffer or semaphore for each element of - # eqn.params['jaxpr'].invars . + # eqn.params['jaxpr'].invars. It is assumed that each core + # runs the same sequence of `run_scoped`s. allocs = [] for v in eqn.params['jaxpr'].invars: - if v.aval.memory_space == mosaic_core.TPUMemorySpace.SEMAPHORE: + if v.aval.memory_space == mosaic_core.MemorySpace.SEMAPHORE: allocs.append(callback.io_callback( _allocate_semaphores, jax.ShapeDtypeStruct(v.aval.shape, jnp.int16), device_id, + local_core_id, v.aval.shape, ordered=True)) else: @@ -1118,6 +1545,7 @@ def f(*args, jaxpr): _allocate_buffer, jax.ShapeDtypeStruct((), jnp.int16), device_id, + local_core_id, TPU_MEMORY_SPACE_IDXS[v.aval.memory_space], _uninitialized_value( v.aval.shape, v.aval.dtype, interpret_params), @@ -1125,16 +1553,8 @@ def f(*args, jaxpr): out = _interpret(eqn.params['jaxpr'], *deferred_invals(), *allocs) - for a in allocs: - if isinstance(a, tuple): - callback.io_callback( - _deallocate_buffer, - None, - device_id, - TPU_MEMORY_SPACE_IDXS[v.aval.memory_space], - a, - ordered=True) - else: + for a, v in zip(allocs, eqn.params['jaxpr'].invars): + if v.aval.memory_space == mosaic_core.MemorySpace.SEMAPHORE: # TODO(jburnim): De-allocate semaphores. 
# callback.io_callback( # _deallocate_semaphores, @@ -1143,6 +1563,15 @@ def f(*args, jaxpr): # a, # ordered=True) pass + else: + callback.io_callback( + _deallocate_buffer, + None, + device_id, + local_core_id, + TPU_MEMORY_SPACE_IDXS[v.aval.memory_space], + a, + ordered=True) elif prim is state_primitives.get_p: invals = deferred_invals() @@ -1150,6 +1579,7 @@ def f(*args, jaxpr): functools.partial(get, source_info=eqn.source_info), eqn.outvars[0].aval, device_id, + local_core_id, TPU_MEMORY_SPACE_IDXS[eqn.invars[0].aval.memory_space], invals[0], jax.tree.unflatten(eqn.params['tree'], invals[1:]), @@ -1161,6 +1591,7 @@ def f(*args, jaxpr): functools.partial(swap, source_info=eqn.source_info), eqn.outvars[0].aval, device_id, + local_core_id, TPU_MEMORY_SPACE_IDXS[eqn.invars[0].aval.memory_space], invals[0], jax.tree.unflatten(eqn.params['tree'], invals[2:]), @@ -1188,9 +1619,10 @@ def f(*args, jaxpr): functools.partial(dma_start, source_info=eqn.source_info), (), device_id, - TPU_MEMORY_SPACE_IDXS[getattr(orig_src_ref.aval, 'memory_space', mosaic_core.TPUMemorySpace.ANY)], + local_core_id, + TPU_MEMORY_SPACE_IDXS[getattr(orig_src_ref.aval, 'memory_space', mosaic_core.MemorySpace.ANY)], src, src_transforms, - TPU_MEMORY_SPACE_IDXS[getattr(orig_dst_ref.aval, 'memory_space', mosaic_core.TPUMemorySpace.ANY)], + TPU_MEMORY_SPACE_IDXS[getattr(orig_dst_ref.aval, 'memory_space', mosaic_core.MemorySpace.ANY)], dst, dst_transforms, state_discharge.transform_array(dst_sem, dst_sem_transforms), state_discharge.transform_array(src_sem, src_sem_transforms), @@ -1216,6 +1648,7 @@ def f(*args, jaxpr): dma_wait, (), device_id, + local_core_id, state_discharge.transform_array(dst_sem, dst_sem_transforms), math.prod(read_shape) * read_dtype.itemsize, ordered=True) @@ -1226,10 +1659,10 @@ def f(*args, jaxpr): get_barrier_semaphore, jax.ShapeDtypeStruct((), jnp.int16), device_id, - compiler_params['mosaic']['collective_id'], + _get_mosaic_params(compiler_params).collective_id, ordered=True) - elif prim is mosaic_primitives.semaphore_signal_p: + elif prim is primitives.semaphore_signal_p: sem, sem_transforms, inc, target_device_id, core_index = ( jax.tree.unflatten(eqn.params['args_tree'], deferred_invals())) target_device_id = _device_id_to_logical( @@ -1238,6 +1671,7 @@ def f(*args, jaxpr): semaphore_signal, (), device_id, + local_core_id, state_discharge.transform_array(sem, sem_transforms), inc, target_device_id, @@ -1245,13 +1679,14 @@ def f(*args, jaxpr): ordered=True) out = [] - elif prim is mosaic_primitives.semaphore_wait_p: + elif prim is primitives.semaphore_wait_p: sem, sem_transforms, value = ( jax.tree.unflatten(eqn.params['args_tree'], deferred_invals())) callback.io_callback( semaphore_wait, (), device_id, + local_core_id, state_discharge.transform_array(sem, sem_transforms), value, ordered=True) @@ -1279,38 +1714,46 @@ def f(*args, jaxpr): out = prim.bind(*subfuns, *deferred_invals(), **bind_params) out = out if prim.multiple_results else [out] - jax.util.safe_map(write, eqn.outvars, out) - - return jax.util.safe_map(read, jaxpr.outvars) - -def _initialize_output_vals( - block_mappings_output: Iterable[BlockMapping], - input_args, input_output_aliases, - interpret_params: TPUInterpretParams, -) -> Sequence[jax.Array]: - oi_map = {v: k for k, v in input_output_aliases} - output_vals = [] - for i, bm in enumerate(block_mappings_output): - if i in oi_map: - output_vals.append(input_args[oi_map[i]]) - else: - output_vals.append(_uninitialized_value( - bm.array_shape_dtype.shape, - 
bm.array_shape_dtype.dtype, - interpret_params)) - return output_vals - -def _compute_start_indices(block_mapping, loop_idx, *args): - block_indices = ( - jax_core.jaxpr_as_fun(block_mapping.index_map_jaxpr)(*loop_idx, *args)) - if isinstance(block_mapping.indexing_mode, pallas_core.Blocked): - ret = tuple(i if b is pallas_core.mapped else b * i - for b, i in zip(block_mapping.block_shape, block_indices)) - elif isinstance(block_mapping.indexing_mode, pallas_core.Unblocked): - ret = block_indices - else: - raise RuntimeError(f"Unknown indexing mode: {block_mapping.indexing_mode}") - return ret + jax._src.util.safe_map(write, eqn.outvars, out) + + return jax._src.util.safe_map(read, jaxpr.outvars) + +def _compute_start_indices( + block_mapping, loop_idx, *args, + axis_sizes, mesh, axis_indices, device_id, local_core_id, + compiler_params, interpret_params): + jaxpr = block_mapping.index_map_jaxpr + block_indices = _interpret_jaxpr( + jaxpr.jaxpr, + *jaxpr.consts, + *loop_idx, + *args, + axis_sizes=axis_sizes, + mesh=mesh, + axis_indices=axis_indices, + device_id=device_id, + local_core_id=local_core_id, + compiler_params=compiler_params, + interpret_params=interpret_params, + ) + def _get_start_index(i, b): + match b: + case pallas_core.Squeezed(): + return i + case pallas_core.Element(): + return i + case pallas_core.Blocked(): + return i * b.block_size + case _: + raise ValueError(f"Unsupported block dim type: {type(b)}") + ret = jnp.array( + tuple( + _get_start_index(i, b) + for i, b in zip(block_indices, block_mapping.block_shape) + ), + dtype=jnp.int32, + ) + return ret def _get_next_indices(grid, indices): next_indices = [] @@ -1321,12 +1764,142 @@ def _get_next_indices(grid, indices): next_indices.append(jnp.where(carry, 0, i)) return tuple(reversed(next_indices)) -def _maybe_dynamic_slice(start_idx, block_shape, value, is_indexing): - start_idx = tuple(jnp.array(s, dtype=jnp.int32) for s in start_idx) - output = lax.dynamic_slice(value, start_idx, slice_sizes=block_shape) - squeeze_dims = tuple(np.arange(len(is_indexing))[np.array(is_indexing, - dtype=np.bool_)]) - return lax.squeeze(output, squeeze_dims) +def _get_indices(grid, loop_index): + indices = [] + for dim_size in reversed(grid): + i = loop_index % dim_size + loop_index = loop_index // dim_size + indices.append(i) + return tuple(reversed(indices)) + +def _get_mosaic_params(compiler_params: dict[str, pallas_core.CompilerParams]) -> mosaic_core.CompilerParams: + try: + return cast(mosaic_core.CompilerParams, compiler_params['mosaic_tpu']) + except KeyError: + return mosaic_core.CompilerParams() + + +def _get_parallel_dim_semantics( + compiler_params: dict[str, Any], num_dimensions_in_grid: int, +) -> tuple[bool, ...]: + """Returns a tuple indicating which grid dimensions have parallel semantics. + + Args: + compiler_params: Representation of a `mosaic_core.CompilerParams` object + as a dictionary. + num_dimensions_in_grid: The number of dimensions in the grid. + + Returns: + A tuple of booleans where the entry at index `i` is `True` precisely if the + `i`-th dimension in the grid has parallel semantics. + + Raises: + ValueError: If the dimensions with parallel semantics do not form a prefix + of the grid. 
+ """ + mosaic_params = _get_mosaic_params(compiler_params) + if mosaic_params.dimension_semantics is None: + return (False,) * num_dimensions_in_grid + result = tuple(ds in ('parallel', mosaic_core.PARALLEL) + for ds in mosaic_params.dimension_semantics) + for ds0, ds1 in zip(result[:-1], result[1:]): + if ds1 and not ds0: + raise ValueError( + 'Dimensions with parallel semantics must form a prefix of the grid.' + ) + return result + + +def _get_parallel_subgrid_size( + parallel_semantics_per_dim: tuple[bool, ...], grid: tuple[int, ...] +) -> int: + """Returns the size of the subgrid along the parallel dimensions.""" + return functools.reduce( + lambda x, y: x * y, + ( + dim_size if parallel_dim else 1 + for dim_size, parallel_dim in zip(grid, parallel_semantics_per_dim) + ), + 1, + ) + +_GridPointCoordinatesPerDim = tuple[Array, ...] + +def _get_randomized_grid_coordinates( + grid: tuple[int, ...], + compiler_params: dict[str, Any], + random_seed: int | None, +) -> _GridPointCoordinatesPerDim: + """Returns a tuple of randomized coordinates for each 'parallel' dimension in `grid`. + + For a dimension with 'parallel' semantics at position `d` in the grid, the + returned tuple contains a random permutation of the sequence `[0,..., + grid[d] - 1]` at index `d`. For each dimension with 'arbitrary' semantics, + the resulting tuple contains an empty array. (Inserting an empty array for an + 'arbitrary' dimension at position `d` in the grid, instead of the sequence + `[0,..., grid[d] - 1]`, allows `grid[d]` to be a dynamic value, i.e. a value + not known at Jax trace time.) + + Args: + grid: Tuple of sizes of the dimensions in the grid. + compiler_params: Representation of a `mosaic_core.CompilerParams` object + as a dictionary. + parallel_semantics_per_dim: A tuple of booleans indicating whether the + corresponding dimension in the grid has parallel semantics. + random_seed: The seed to use for randomizing coordinates in parallel + dimensions. + """ + parallel_semantics_per_dim = _get_parallel_dim_semantics( + compiler_params, len(grid) + ) + + key = jax.random.key(random_seed or 0) + grid_point_coordinates = [] + for dim_size, parallel_dim in zip(grid, parallel_semantics_per_dim): + if parallel_dim: + # The size of a dimension with `parallel` semantics must be known at Jax + # trace time. This ensures that the arguments to `jnp.arange` and + # `jax.random.permutation` below are valid. + dim_size = jax_core.concrete_or_error(None, dim_size) + + coordindates_along_dim = jnp.arange(dim_size, dtype=jnp.int32) + key, subkey = jax.random.split(key) + coordindates_along_dim = jax.random.permutation( + subkey, coordindates_along_dim + ) + grid_point_coordinates.append(coordindates_along_dim) + else: + grid_point_coordinates.append(jnp.array((), dtype=jnp.int32)) + + return tuple(grid_point_coordinates) + + +def _get_grid_point( + loop_indices: tuple[Array, ...], + grid_point_coordinates: _GridPointCoordinatesPerDim, +) -> Array: + """Indexes each entry in `grid_point_coordinates` with the corresponding entry in `loop_indices`. + + If an entry in `grid_point_coordinates` is an empty array, the corresponding + entry in the returned array is the corresponding entry in `loop_indices`. + Otherwise, the returned array contains the entry in `grid_point_coordinates` + indexed with the corresponding entry in `loop_indices`. + + Args: + loop_indices: A tuple of loop indices. + grid_point_coordinates: A tuple of coordinate arrays for each dimension in + the grid. 
Dimensions with 'arbitrary' semantics are represented by empty + arrays. Dimensions with 'parallel' semantics are represented by arrays of + randomized coordinates. + + Returns: + A 1-dimensional array containing the coordinates for the grid point + corresponding to the specified `loop_indices`. + """ + grid_point = [] + for li, coords in zip(loop_indices, grid_point_coordinates): + grid_point.append(li if jnp.size(coords) == 0 else coords[li]) + return jnp.array(grid_point, dtype=np.int32) def _uninitialized_value(shape, dtype, interpret_params): if interpret_params.uninitialized_memory == 'nan': @@ -1367,6 +1940,53 @@ def _pad_to_block_dimension(value, block_shape, interpret_params): def get_interpret_effects(): return {callback._OrderedIOEffect} +def _thread_map(f, num_threads): + if num_threads == 1: + f(jnp.int32(0)) + return + + def _f(core_index): + f(core_index) + return () + jaxpr = jax.make_jaxpr(_f)(jnp.int32(0)) + + _call_threadmap_callback(jaxpr.jaxpr, num_threads, *jaxpr.consts) + +def _run_jaxpr(jaxpr, consts, *args): + def _run(jaxpr, consts, *args): + jax_core.eval_jaxpr(jaxpr, consts, *args) + traced = jax.jit(_run, static_argnums=(0,)).trace(jaxpr, consts, *args) + traced.lower().compile()(consts, *args) + return + +import concurrent.futures + +def _thread_map_callback(jaxpr, num_threads, consts): + num_threads = int(num_threads) + threads = [] + with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor: + for i in range(num_threads): + threads.append( + executor.submit(_run_jaxpr, jaxpr, consts, jnp.int32(i))) + for i in range(num_threads): + threads[i].result() + +def _call_threadmap_callback(jaxpr, num_threads, *consts): + # NOTE: At runtime, _thread_map_callback will lower and compile the + # given jaxpr. (JAX's caches should ensure the jaxpr is only lowered and + # compiled once.) + # + # TODO(jburnim): Would it be worth trying to lower/compile the jaxpr at + # lowering/compilation time? E.g., by using a custom primitive here, could + # we lower/compile jaxpr at lowering time, and then pass the compiled + # function to the callback? + return callback.io_callback( + functools.partial(_thread_map_callback, jaxpr), + (), + num_threads, + consts, + ordered=True) + def interpret_pallas_call( *args, jaxpr: jax_core.Jaxpr, @@ -1374,12 +1994,20 @@ def interpret_pallas_call( input_output_aliases: tuple[tuple[int, int], ...], grid_mapping: GridMapping, mesh: pallas_core.Mesh | None, - compiler_params: Any, + compiler_params: dict[str, Any], cost_estimate: CostEstimate, out_avals: tuple[jax_core.AbstractValue, ...], - interpret_params: TPUInterpretParams, + interpret_params: InterpretParams, ): - del debug, mesh, cost_estimate, out_avals + del debug, cost_estimate, out_avals + + if isinstance(mesh, mosaic_core.TensorCoreMesh): + # As a convenience for users, if we are interpreting a pl.core_map over a + # TensorCoreMesh, we automatically set the number of cores per device so + # that users don't have to specify it in the InterpretParams. + assert len(mesh.shape) == 1 + interpret_params = dataclasses.replace( + interpret_params, num_cores_per_device=mesh.devices.shape[0]) # args contains: *dynamic_grid_sizes, *index, *inputs. (No consts?) 
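To make the grid-partitioning arithmetic used by the multi-core interpreter concrete, here is a small worked example. It assumes a grid of (4, 3) with dimension_semantics ('parallel', 'arbitrary') and two cores per device; the numbers follow the helpers defined above and the per-core loop bounds computed further below, and are illustrative rather than taken from the change.

# Worked numbers for grid=(4, 3), dimension_semantics=('parallel', 'arbitrary'),
# num_cores_per_device=2 (sketch only).
parallel_semantics_per_dim = (True, False)
parallel_subgrid_size = 4                      # product of the parallel dims
num_iterations = 4 * 3                         # 12 grid points in total
points_per_core = (4 + 2 - 1) // 2             # 2 parallel points per core (rounded up)
iters_per_parallel_point = 12 // 4             # 3 'arbitrary' iterations per parallel point
num_iterations_per_core = points_per_core * iters_per_parallel_point  # 6

# Core 0 executes loop indices [0, 6) and core 1 executes [6, 12). The parallel
# coordinate is additionally visited through the permutation produced by
# _get_randomized_grid_coordinates, except under a core_map over a
# TensorCoreMesh, where the permutation is disabled.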
dynamic_grid_args, scalars, input_args = split_list( @@ -1397,25 +2025,26 @@ def interpret_pallas_call( axis_sizes = jax_core.get_axis_env().axis_sizes num_devices = functools.reduce( jnp.multiply, axis_sizes.values(), jnp.int32(1)) + axis_indices = {k: lax.axis_index(k) for k in axis_sizes.keys()} device_id = _device_coords_to_logical_id( - tuple(lax.axis_index(s) for s in axis_sizes.keys()), - axis_sizes) + tuple(axis_indices.values()), axis_sizes) callback.io_callback( functools.partial( _initialize_shared_memory, interpret_params=interpret_params), (), device_id, num_devices, + interpret_params.num_cores_per_device, ordered=True) # Pad input arguments. - is_indexing_dim = [ - tuple(b is pallas_core.mapped for b in bm.block_shape) + is_squeeze_dim = [ + tuple(isinstance(b, pallas_core.Squeezed) for b in bm.block_shape) for bm in grid_mapping.block_mappings ] block_shapes = [ - tuple(1 if i else b for i, b in zip(iid, bm.block_shape)) - for iid, bm in zip(is_indexing_dim, grid_mapping.block_mappings) + pallas_core._get_block_shape(bm.block_shape) + for bm in grid_mapping.block_mappings ] num_inputs = grid_mapping.num_inputs input_args = [ @@ -1423,75 +2052,121 @@ def interpret_pallas_call( for a, bs in zip(input_args, block_shapes[:num_inputs]) ] - # Allocate buffers in HBM for outputs. - output_buffer_ids = [] - output_buffer_shapes = [] - output_vals = _initialize_output_vals( - grid_mapping.block_mappings_output, - scalars + input_args, - input_output_aliases, - interpret_params) - num_outputs = grid_mapping.num_outputs - output_block_shapes = block_shapes[num_inputs : num_inputs + num_outputs] - for out_val, bs in zip(output_vals, output_block_shapes): - padded_val = _pad_to_block_dimension(out_val, bs, interpret_params) - output_buffer_shapes.append(padded_val.shape) - output_buffer_ids.append(callback.io_callback( + # Allocate HBM buffers for pallas_call inputs. + # + # TODO(jburnim): As an optimization, skip allocating buffers for inputs that + # are neither aliased nor passed to the kernel in HBM? + input_buffer_ids = [] + for i, var in enumerate( + jaxpr.invars[grid_mapping.num_index_operands:][:grid_mapping.num_inputs]): + assert var.aval.dtype == input_args[i].dtype + input_buffer_ids.append(callback.io_callback( _allocate_buffer, jax.ShapeDtypeStruct((), jnp.int16), device_id, - TPU_MEMORY_SPACE_IDXS[mosaic_core.TPUMemorySpace.ANY], - padded_val, + None, # local_core_id + TPU_MEMORY_SPACE_IDXS[mosaic_core.MemorySpace.ANY], + input_args[i], ordered=True)) - # Allocate buffers for all kernel arguments (e.g., scalars, inputs, - # outputs, scratch). - io_alias_map = dict(input_output_aliases) + + # Allocate buffers in HBM for pallas_call outputs. oi_alias_map = {v: k for k, v in input_output_aliases} - kernel_buffer_ids = [] - for _, val in zip(jaxpr.invars[grid_mapping.slice_index_ops], scalars): - kernel_buffer_ids.append(callback.io_callback( - _allocate_buffer, - jax.ShapeDtypeStruct((), jnp.int16), - device_id, - TPU_MEMORY_SPACE_IDXS[mosaic_core.TPUMemorySpace.SMEM], - val, - ordered=True)) + output_buffer_ids = [] + output_buffer_shapes = [] + output_vals = [] + num_outputs = grid_mapping.num_outputs + output_block_shapes = block_shapes[num_inputs : num_inputs + num_outputs] + for i, bm in enumerate(grid_mapping.block_mappings_output): + if i in oi_alias_map: + # Reuse the HBM buffer for the aliased pallas_call input. 
+ output_buffer_ids.append(input_buffer_ids[oi_alias_map[i]]) + output_buffer_shapes.append(input_args[oi_alias_map[i]].shape) + output_vals.append(input_args[oi_alias_map[i]]) + else: + out_val = _uninitialized_value(bm.array_shape_dtype.shape, + bm.array_shape_dtype.dtype, + interpret_params) + padded_val = _pad_to_block_dimension( + out_val, output_block_shapes[i], interpret_params + ) + output_buffer_ids.append( + callback.io_callback( + _allocate_buffer, + jax.ShapeDtypeStruct((), jnp.int16), + device_id, + None, # local_core_id + TPU_MEMORY_SPACE_IDXS[mosaic_core.MemorySpace.ANY], + padded_val, + ordered=True, + ) + ) + output_buffer_shapes.append(padded_val.shape) + output_vals.append(out_val) + + # Allocate buffers for non-HBM kernel arguments (e.g., scalars, inputs, + # outputs, scratch). + scalar_buffer_ids = [] + for var, val in zip(jaxpr.invars[grid_mapping.slice_index_ops], scalars): + assert var.aval.shape == val.shape + assert var.aval.dtype == val.dtype + scalar_buffer_ids.append( + callback.io_callback( + _allocate_buffer, + jax.ShapeDtypeStruct((), jnp.int16), + device_id, + None, # local_core_id, + TPU_MEMORY_SPACE_IDXS[mosaic_core.MemorySpace.SMEM], + val, + ordered=True, + ) + ) + + kernel_buffer_ids = scalar_buffer_ids.copy() for i, var in enumerate(jaxpr.invars[grid_mapping.num_index_operands:]): output_idx = i - grid_mapping.num_inputs is_input = i < grid_mapping.num_inputs is_output = (output_idx >= 0) and (output_idx < grid_mapping.num_outputs) - if var.aval.memory_space == mosaic_core.TPUMemorySpace.SEMAPHORE: - kernel_buffer_ids.append(callback.io_callback( - _allocate_semaphores, - jax.ShapeDtypeStruct(var.aval.shape, jnp.int16), - device_id, - var.aval.shape, - ordered=True)) - elif is_output and _is_any(var.aval.memory_space): - # Use the already-allocated HBM output buffer. + if var.aval.memory_space == mosaic_core.MemorySpace.SEMAPHORE: + kernel_buffer_ids.append( + callback.io_callback( + _allocate_semaphores, + jax.ShapeDtypeStruct(var.aval.shape, jnp.int16), + device_id, + None, # local_core_id + var.aval.shape, + ordered=True, + ) + ) + elif _is_any(var.aval.memory_space): + # Use the already-allocated HBM input or output buffer. # - # TODO(jburnim): For kernel args in HBM, check that block shape is the - # same as for the corresponding pallas_call input, and that the index_map + # TODO(jburnim): For kernel args in HBM, check that block shape eqals the + # shape of the corresponding pallas_call input, and that the index_map # is trivial. - kernel_buffer_ids.append(output_buffer_ids[output_idx]) - elif is_output and (output_idx in oi_alias_map): - # Use the already-allocated (non-HBM) input buffer. - kernel_buffer_ids.append(kernel_buffer_ids[oi_alias_map[output_idx]]) - elif is_input and (i in io_alias_map) and _is_any(var.aval.memory_space): - # Use the already-allocated HBM output buffer. - kernel_buffer_ids.append(output_buffer_ids[io_alias_map[i]]) + assert is_input ^ is_output + if is_input: + kernel_buffer_ids.append(input_buffer_ids[i]) + if is_output: + kernel_buffer_ids.append(output_buffer_ids[output_idx]) else: - # TODO(jburnim): For kernel args in HBM, check that block shape is the - # same as for the corresponding pallas_call input, and that the index_map - # is trivial. 
- kernel_buffer_ids.append(callback.io_callback( - _allocate_buffer, - jax.ShapeDtypeStruct((), jnp.int16), - device_id, - TPU_MEMORY_SPACE_IDXS[var.aval.memory_space], - _uninitialized_value( - var.aval.shape, var.aval.dtype, interpret_params), - ordered=True)) + kernel_buffer_ids.append( + callback.io_callback( + _allocate_buffer, + jax.ShapeDtypeStruct((), jnp.int16), + device_id, + None, # local_core_id, + TPU_MEMORY_SPACE_IDXS[var.aval.memory_space], + _uninitialized_value( + var.aval.shape, var.aval.dtype, interpret_params + ), + ordered=True, + ) + ) + + if _get_mosaic_params(compiler_params).collective_id is None: + # The kernel doesn't specify its own barrier semaphore, so we do a global + # barrier before running the first iteration of the kernel. + callback.io_callback(_barrier, (), device_id, ordered=True) _, input_ids, kernel_output_ids, _ = split_list( kernel_buffer_ids, @@ -1499,119 +2174,313 @@ def interpret_pallas_call( input_vars, output_vars = split_list( jaxpr.invars[grid_mapping.slice_block_ops], [num_inputs]) - # For kernel inputs that are in HBM, we populate the buffer once before - # any kernel invocations. - for buffer_id, var, val in zip(input_ids, input_vars, input_args): - if not _is_any(var.aval.memory_space): - continue - if (val.shape != var.aval.shape) or (val.dtype != var.aval.dtype): - # TODO(jburnim): Also check that the index_map is trivial. - raise ValueError() - callback.io_callback( - store, - (), - device_id, - TPU_MEMORY_SPACE_IDXS[mosaic_core.TPUMemorySpace.ANY], - buffer_id, - (), - val, - ordered=True) - if grid: num_iterations = functools.reduce(jnp.multiply, grid) # type: ignore[arg-type] else: # Base case is always one iteration when grid is () num_iterations = 1 - def body(carry): - # The loop carry: (i, loop_idx) -- - # - i:int32 is the interation index - # - loop_idx: tuple[int32] are the program ids for each grid axis - i, loop_idx = carry + if isinstance(mesh, mosaic_core.TensorCoreMesh): + # We are interpreting a pl.core_map over a TensorCoreMesh, so we use a + # fixed division of the grid between cores, instead of a random division. + randomized_grid_coordinates = (jnp.array((), dtype=jnp.int32),) * len(grid) + else: + randomized_grid_coordinates = _get_randomized_grid_coordinates( + grid, compiler_params, interpret_params.random_seed # type: ignore[arg-type] + ) + parallel_dim_semantics = _get_parallel_dim_semantics( + compiler_params, len(grid) + ) + parallel_subgrid_size = _get_parallel_subgrid_size( + parallel_dim_semantics, grid # type: ignore[arg-type] + ) + num_points_in_parallel_subgrid_per_core = ( + parallel_subgrid_size + interpret_params.num_cores_per_device - 1 + ) // interpret_params.num_cores_per_device # We round up here. + num_iterations_per_point_in_parallel_subgrid = ( + # This is evenly divisible. + num_iterations // parallel_subgrid_size # type: ignore[operator] + ) + num_iterations_per_core = ( + num_points_in_parallel_subgrid_per_core + * num_iterations_per_point_in_parallel_subgrid + ) + def _get_local_grid_env(grid_point): if grid_mapping.local_grid_env is not None: - local_grid_env = grid_mapping.local_grid_env(loop_idx, grid) + return grid_mapping.local_grid_env(grid_point, grid) else: - local_grid_env = tuple( + return tuple( pallas_core.GridAxis(idx, b) - for dim, (idx, b) in enumerate(zip(loop_idx, grid)) + for dim, (idx, b) in enumerate(zip(grid_point, grid)) if dim not in grid_mapping.vmapped_dims ) - with pallas_core.grid_env(local_grid_env): - # Copy slices of the input to the kernel buffers. 
- # - # TODO(jburnim): Only copy slices when the index mapping has changed? - start_indices = [_compute_start_indices(bm, loop_idx, *scalars) - for bm in grid_mapping.block_mappings] - for j, var in enumerate(input_vars): - if _is_any(var.aval.memory_space): - continue - sliced_val = _maybe_dynamic_slice(start_indices[j], block_shapes[j], - input_args[j], is_indexing_dim[j]) - assert(sliced_val.shape == var.aval.shape) + def _execute_grid_for_core(core_index): + # NOTE: We assume here that all parallel dimensions appear before all + # arbitrary dimensions in the grid. (We will have raised an error earlier + # if this is not the case.) + # + # TODO(jburnim): Are we overusing nested local functions here? + initial_iteration_idx = core_index * num_iterations_per_core + loop_bound = jnp.minimum( + (core_index + 1) * num_iterations_per_core, num_iterations) + + def _body( + carry: tuple[ + jnp.int32, + tuple[jnp.int32, ...], + jnp.ndarray, + list[jnp.ndarray], + list[jnp.ndarray], + ], + ) -> tuple[ + jnp.int32, + tuple[jnp.int32, ...], + jnp.ndarray, + list[jnp.ndarray], + list[jnp.ndarray], + ]: + """Performs one execution of the kernel body. + + Execution of `jaxpr` is preceded by reading kernel input buffers and + followed by writing kernel output buffers. + + Args: + carry: (iteration_idx, loop_idx, grid_point, prev_start_indices, + cur_start_indices). + - iteration_idx: the iteration index. + - loop_idx: internal indices for looping over the grid. + - grid_point: the current positions along all axes of the grid. + - prev_start_indices: a rank-1 array that contains the start indices + for the slices of inputs and outputs processed in the previous loop + iteration. + - cur_start_indices: a rank-1 array that contains the start indices + for the slices of inputs and outputs processed in the current loop + iteration. + + Note that by carrying the previous *and* current start indices between + loop iterations, it suffices to compute only one list of start indices, + i.e. `next_start_indices` (see below), per iteration. + + Returns: + The carry for the next iteration. + """ + ( + iteration_idx, + loop_idx, + grid_point, + prev_start_indices, + cur_start_indices, + ) = carry + if interpret_params.grid_point_recorder is not None: callback.io_callback( - # TODO(jburnim): Pass source_info from the pallas_call, in case this - # store is involved in a data race. - store, + interpret_params.grid_point_recorder, (), - device_id, - TPU_MEMORY_SPACE_IDXS[var.aval.memory_space], - input_ids[j], - (), - sliced_val, - ordered=True) + grid_point, + core_index, + ) - # Invoke the kernel. - _interpret_jaxpr(jaxpr, *kernel_buffer_ids, - compiler_params=compiler_params, - interpret_params=interpret_params) + with pallas_core.grid_env(_get_local_grid_env(grid_point)): + next_loop_idx = _get_next_indices(grid, loop_idx) + next_grid_point = _get_grid_point( + next_loop_idx, randomized_grid_coordinates + ) + next_start_indices = [ + _compute_start_indices( + bm, + next_grid_point, + *scalar_buffer_ids, + axis_sizes=axis_sizes, + mesh=mesh, + axis_indices=axis_indices, + device_id=device_id, + local_core_id=core_index, + compiler_params=compiler_params, + interpret_params=interpret_params, + ) + for bm in grid_mapping.block_mappings + ] + + # Copy slices of the input to the kernel buffers. + def _store_slice_to_kernel_input(index, input_var): + # Copy from the HBM buffer for the pallas_call input to the kernel + # input buffer. + # TODO(jburnim): Just use input_args[j] when the input is not aliased? 
+ transform = indexing.NDIndexer( + indices=tuple( + indexing.ds(st, sz) if not iid else st + for st, sz, iid in zip( + cur_start_indices[index], + block_shapes[index], + is_squeeze_dim[index], + ) + ), + shape=input_args[index].shape, + int_indexer_shape=(), + ) + sliced_val = callback.io_callback( + # TODO(jburnim): Pass source_info from the pallas_call, in case this + # read is involved in a data race. + get, + jax.ShapeDtypeStruct(input_var.aval.shape, input_var.aval.dtype), + device_id, + core_index, + TPU_MEMORY_SPACE_IDXS[mosaic_core.MemorySpace.ANY], + input_buffer_ids[index], + (transform,), + ordered=True, + ) + callback.io_callback( + # TODO(jburnim): Pass source_info from the pallas_call, in case this + # store is involved in a data race. + store, + (), + device_id, + core_index, + TPU_MEMORY_SPACE_IDXS[input_var.aval.memory_space], + input_ids[index], + (), + sliced_val, + ordered=True, + ) + + for j, var in enumerate(input_vars): + if _is_any(var.aval.memory_space): + continue + assert len(cur_start_indices[j].shape) == 1 + assert len(prev_start_indices[j].shape) == 1 + jax.lax.cond( + (iteration_idx == initial_iteration_idx) + | jax.lax.reduce_or( + cur_start_indices[j] != prev_start_indices[j], axes=(0,) + ), + functools.partial(_store_slice_to_kernel_input, j, var), + lambda: None, + ) + + # Invoke the kernel. + _interpret_jaxpr( + jaxpr, + *kernel_buffer_ids, + axis_sizes=axis_sizes, + mesh=mesh, + axis_indices=axis_indices, + device_id=device_id, + local_core_id=core_index, + compiler_params=compiler_params, + interpret_params=interpret_params, + ) - # Copy from the kernel buffers to slices of the output in HBM. - # - # TODO(jburnim): Only copy if the index mapping will change in the - # next iteration (or if this is the last iteration)? - for j, var in enumerate(output_vars): - if _is_any(var.aval.memory_space): - continue - kernel_output_val = callback.io_callback( - # TODO(jburnim): Pass source_info from the pallas_call, in case this - # get is involved in a data race. - get, - var.aval, - device_id, - TPU_MEMORY_SPACE_IDXS[var.aval.memory_space], - kernel_output_ids[j], - (), - ordered=True) - transform = indexing.NDIndexer( - indices=tuple(indexing.ds(st, sz) if not iid else st - for st, sz, iid in zip(start_indices[num_inputs + j], - block_shapes[num_inputs + j], - is_indexing_dim[num_inputs + j])), - shape=output_vals[j].shape, - int_indexer_shape=()) - callback.io_callback( - # TODO(jburnim): Pass source_info from the pallas_call, in case this - # store is involved in a data race. - store, - (), - device_id, - TPU_MEMORY_SPACE_IDXS[mosaic_core.TPUMemorySpace.ANY], - output_buffer_ids[j], - (transform,), - kernel_output_val, - ordered=True) + # Copy from the kernel buffers to slices of the output in HBM. + def _store_to_output_buffer(index, output_var): + kernel_output_val = callback.io_callback( + # TODO(jburnim): Pass source_info from the pallas_call, in case this + # get is involved in a data race. 
+ get, + output_var.aval, + device_id, + core_index, + TPU_MEMORY_SPACE_IDXS[output_var.aval.memory_space], + kernel_output_ids[j], + (), + ordered=True, + ) + transform = indexing.NDIndexer( + indices=tuple( + indexing.ds(st, sz) if not iid else st + for st, sz, iid in zip( + cur_start_indices[num_inputs + index], + block_shapes[num_inputs + index], + is_squeeze_dim[num_inputs + index], + ) + ), + shape=output_vals[index].shape, + int_indexer_shape=(), + ) + callback.io_callback( + # TODO(jburnim): Pass source_info from the pallas_call, in case this + # store is involved in a data race. + store, + (), + device_id, + core_index, + TPU_MEMORY_SPACE_IDXS[mosaic_core.MemorySpace.ANY], + output_buffer_ids[index], + (transform,), + kernel_output_val, + ordered=True, + ) + + for j, var in enumerate(output_vars): + if _is_any(var.aval.memory_space): + continue + assert len(cur_start_indices[num_inputs + j].shape) == 1 + assert len(next_start_indices[num_inputs + j].shape) == 1 + jax.lax.cond( + (iteration_idx + 1 == loop_bound) + | jax.lax.reduce_or( + cur_start_indices[num_inputs + j] + != next_start_indices[num_inputs + j], + axes=(0,), + ), + functools.partial(_store_to_output_buffer, j, var), + lambda: None, + ) + + return ( + iteration_idx + 1, + next_loop_idx, + next_grid_point, + cur_start_indices, + next_start_indices, + ) - return i + 1, _get_next_indices(grid, loop_idx) + initial_loop_idx = _get_indices(grid, initial_iteration_idx) + initial_grid_point = _get_grid_point( + initial_loop_idx, randomized_grid_coordinates) + with pallas_core.grid_env(_get_local_grid_env(initial_grid_point)): + initial_start_indices = [ + _compute_start_indices( + bm, + initial_grid_point, + *scalar_buffer_ids, + axis_sizes=axis_sizes, + mesh=mesh, + axis_indices=axis_indices, + device_id=device_id, + local_core_id=core_index, + compiler_params=compiler_params, + interpret_params=interpret_params, + ) + for bm in grid_mapping.block_mappings + ] + + _ = lax.while_loop( + lambda carry: carry[0] < loop_bound, + _body, + ( + initial_iteration_idx, + initial_loop_idx, + initial_grid_point, + initial_start_indices, # Previous start indices are ignored on the first iteration. + initial_start_indices, + ), + ) - # TODO(jburnim): Handle parallel grid dimensions + megacore. - _ = lax.while_loop( - lambda carry: carry[0] < num_iterations, - body, - (jnp.int32(0), (jnp.int32(0),) * len(grid)) + # TODO(jburnim): Should we only create happens-before here from core 0 to + # the other cores? + callback.io_callback( + _update_clocks_for_device_barrier, (), device_id, ordered=True ) + _thread_map(_execute_grid_for_core, interpret_params.num_cores_per_device) + + # TODO(jburnim): Should we only create happens-before here from the other + # cores to core 0? + callback.io_callback( + _update_clocks_for_device_barrier, (), device_id, ordered=True) + # Read the output from the allocated output buffers.
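For context on how this multi-core interpreter is driven end to end, the snippet below is a hypothetical usage sketch. It assumes the InterpretParams dataclass referenced throughout this file (with fields such as num_cores_per_device, dma_execution_mode and detect_races) is exposed under jax.experimental.pallas.tpu and accepted by pallas_call's interpret argument; the exact export path and defaults are not confirmed by this diff.

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu  # assumed export location

def add_one_kernel(x_ref, o_ref):
  o_ref[...] = x_ref[...] + 1.0

x = jnp.arange(8 * 128, dtype=jnp.float32).reshape(8, 128)
out = pl.pallas_call(
    add_one_kernel,
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    # Assumed: the interpreter parameters exercised above, simulating two
    # TensorCores per device with eager DMA execution and race detection.
    interpret=pltpu.InterpretParams(
        num_cores_per_device=2,
        dma_execution_mode='eager',
        detect_races=True,
    ),
)(x)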
ret = [ callback.io_callback( @@ -1620,7 +2489,8 @@ def body(carry): get, val, device_id, - TPU_MEMORY_SPACE_IDXS[mosaic_core.TPUMemorySpace.ANY], + 0, # local_core_id + TPU_MEMORY_SPACE_IDXS[mosaic_core.MemorySpace.ANY], output_buffer_id, (indexing.NDIndexer.from_indices_shape( tuple(indexing.ds(0, s) for s in val.shape), diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index 10b9de7487eb..02e3e8930651 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -15,12 +15,13 @@ """Module for lowering JAX to Mosaic-compatible MLIR dialects.""" from __future__ import annotations -from collections.abc import Callable, Sequence +from collections.abc import Callable, Collection, Hashable, Sequence import contextlib import dataclasses import functools +import operator import string -from typing import Any, Hashable +from typing import Any, TypeVar import jax from jax import api_util @@ -39,20 +40,26 @@ from jax._src import source_info_util from jax._src import state from jax._src import traceback_util +from jax._src import xla_bridge from jax._src.cloud_tpu_init import is_cloud_tpu_older_than +from jax._src.export import shape_poly from jax._src.export._export import export from jax._src.interpreters import mlir from jax._src.interpreters import partial_eval as pe +from jax._src.lax import control_flow from jax._src.lax import lax as lax_internal -from jax._src.lax.control_flow import for_loop +from jax._src.lax.control_flow import BranchesPlatforms, for_loop +from jax._src.lib import version as jaxlib_version from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import arith +from jax._src.lib.mlir.dialects import cf from jax._src.lib.mlir.dialects import func from jax._src.lib.mlir.dialects import math from jax._src.lib.mlir.dialects import memref from jax._src.lib.mlir.dialects import scf from jax._src.lib.mlir.dialects import vector from jax._src.pallas import core as pallas_core +from jax._src.pallas import helpers as pallas_helpers from jax._src.pallas import pallas_call from jax._src.pallas import primitives from jax._src.pallas import utils as pallas_utils @@ -80,10 +87,10 @@ # mypy: ignore-errors NDIndexer = indexing.NDIndexer -TPUMemorySpace = tpu_core.TPUMemorySpace -MemorySpace = pallas_core.MemorySpace | TPUMemorySpace -VMEM = tpu_core.TPUMemorySpace.VMEM -SMEM = tpu_core.TPUMemorySpace.SMEM +TPUMemorySpace = tpu_core.MemorySpace +AnyMemorySpace = pallas_core.MemorySpace | TPUMemorySpace +VMEM = TPUMemorySpace.VMEM +SMEM = TPUMemorySpace.SMEM # Booleans are stored as the following type in memrefs. BOOL_MEMREF_TYPE = np.dtype('int32') @@ -156,15 +163,14 @@ def to_placeholder(self, dim_expr: Any) -> ir.Value: @dataclasses.dataclass class LoweringContext: - ir_context: ir.Context grid_sizes: tuple[int, ...] # Includes both user and vmap axes. grid_names: tuple[Hashable, ...] | None mapped_dims: tuple[int, ...] # Indices of vmapped grid dimensions. user_grid_indices: Sequence[ir.Value] | None - block_shapes: list[tuple[int | pallas_core.Mapped, ...]] + block_shapes: list[tuple[int | pallas_core.Squeezed, ...]] name_stack: source_info_util.NameStack mesh_context: MeshContext | None - replace = dataclasses.replace + kernel_type: tpu_core.KernelType traceback_caches: mlir.TracebackCaches for_verification: bool forward_compatible: bool @@ -172,6 +178,8 @@ class LoweringContext: [tuple[jax.DimSize, ...]], tuple[int, ...] 
] + replace = dataclasses.replace + @property def grid_rank(self): return len(self.grid_sizes) @@ -196,7 +204,8 @@ class LoweringRuleContext: lowering_context: LoweringContext avals_in: Sequence[jax_core.AbstractValue] avals_out: Sequence[jax_core.AbstractValue] - block_shapes: Sequence[tuple[int | pallas_core.Mapped, ...] | None] + block_shapes: Sequence[tuple[int | pallas_core.Squeezed, ...] | None] + replace = dataclasses.replace @property @@ -204,7 +213,7 @@ def forward_compatible(self): return self.lowering_context.forward_compatible -def _memory_space_to_tpu_memory_space(memory_space: MemorySpace | None +def _memory_space_to_tpu_memory_space(memory_space: AnyMemorySpace | None ) -> TPUMemorySpace: match memory_space: case None: @@ -214,7 +223,11 @@ def _memory_space_to_tpu_memory_space(memory_space: MemorySpace | None case pallas_core.MemorySpace.ANY: # Map the general ANY memory space to TPU ANY memory space return TPUMemorySpace.ANY - case pallas_core.MemorySpace.ERROR | pallas_core.MemorySpace.INDEX: + case ( + pallas_core.MemorySpace.ERROR + | pallas_core.MemorySpace.INDEX + | pallas_core.MemorySpace.KEY + ): return TPUMemorySpace.SMEM case TPUMemorySpace(): # Leave the memory space unchanged @@ -223,27 +236,27 @@ def _memory_space_to_tpu_memory_space(memory_space: MemorySpace | None raise ValueError(f"Invalid memory space: {memory_space}") -def _memory_space_to_mosaic_attribute(memory_space: MemorySpace | None +def _memory_space_to_mosaic_attribute(memory_space: AnyMemorySpace | None ) -> ir.Attribute: tpu_memory_space = _memory_space_to_tpu_memory_space(memory_space) return ir.Attribute.parse(f"#tpu.memory_space<{tpu_memory_space}>") -def _dtype_to_ir_type(dtype: jnp.dtype, +def _dtype_to_ir_type(dtype: jax.typing.DTypeLike, is_kernel_boundary: bool = False) -> ir.Type: - if jnp.issubdtype(dtype, tpu_core.semaphore_dtype): + if jnp.issubdtype(dtype, pallas_core.semaphore_dtype): if jnp.issubdtype(dtype, tpu_core.dma_semaphore): return ir.Type.parse("!tpu.dma_semaphore") - elif jnp.issubdtype(dtype, tpu_core.semaphore): + elif jnp.issubdtype(dtype, pallas_core.semaphore): return ir.Type.parse("!tpu.semaphore") - elif jnp.issubdtype(dtype, tpu_core.barrier_semaphore): + elif jnp.issubdtype(dtype, pallas_core.barrier_semaphore): return ir.Type.parse("!tpu.semaphore") else: raise NotImplementedError - if is_kernel_boundary and jnp.issubdtype(dtype, jnp.dtype('bool')): + if is_kernel_boundary and jnp.issubdtype(dtype, jnp.bool): dtype = BOOL_MEMREF_TYPE # TODO(justinfu): Remove after mosaic supports unsigned types. # This conversion makes mosaic interpret all unsigned types as signed types. 
- type = mlir.dtype_to_ir_type(dtype) + type = mlir.dtype_to_ir_type(jnp.dtype(dtype)) if isinstance(type, ir.IntegerType): return ir.IntegerType.get_signless(type.width) else: @@ -254,7 +267,7 @@ def aval_to_ir_type( dynamic_shape_replacement_fn, aval, shape=None, - memory_space: MemorySpace | None = None, + memory_space: AnyMemorySpace | None = None, is_kernel_boundary: bool = False, ): if isinstance(aval, tpu_core.AbstractSemaphore): @@ -309,20 +322,38 @@ def ir_constant(x, mlir_type=None): x = np.array(x, np.float32) if not mlir_type: mlir_type = _dtype_to_ir_type(x.dtype) - if isinstance(x, int) or np.issubdtype(x.dtype, np.integer): + if isinstance(x, int) or jnp.issubdtype(x.dtype, np.integer): return arith.constant(mlir_type, ir.IntegerAttr.get(mlir_type, int(x))) - elif isinstance(x, float) or x.dtype == np.float32: - return arith.constant(mlir_type, ir.FloatAttr.get(mlir_type, float(x))) - elif x.dtype == jnp.bfloat16: + elif isinstance(x, float) or jnp.issubdtype(x.dtype, jnp.floating): return arith.constant(mlir_type, ir.FloatAttr.get(mlir_type, float(x))) elif x.dtype == jnp.bool_: return arith.constant(mlir_type, ir.BoolAttr.get(bool(x))) raise NotImplementedError(x.dtype) -lowering_rules = {} +lowering_rules = {kernel_type: {} for kernel_type in tpu_core.KernelType} skip_mlir_conversions = set() + +T = TypeVar("T") + + +def register_lowering_rule( + prim: jax_core.Primitive, + *, + kernel_types: Collection[tpu_core.KernelType] = (tpu_core.KernelType.TC,), + ensure_mlir_values: bool = True, +) -> Callable[[T], T]: + def decorator(rule: T) -> T: + for kernel_type in kernel_types: + lowering_rules[kernel_type][prim] = rule + if not ensure_mlir_values: + skip_mlir_conversions.add(prim) + return rule + + return decorator + + def _get_aval_physical_dtype_shape(aval): dtype_physical_shape = jax_core.physical_aval(aval).shape[ len(aval.shape) : @@ -339,7 +370,7 @@ def _get_arg_type( ): memory_space = None if isinstance(aval, pallas_core.AbstractMemoryRef): - memory_space = aval.memory_space + memory_space = _memory_space_to_tpu_memory_space(aval.memory_space) # We assume unannotated memory refs are in VMEM if memory_space is None: memory_space = TPUMemorySpace.VMEM @@ -353,7 +384,13 @@ def _get_arg_type( ), aval.shape, ) - shape = tuple(1 if b is pallas_core.mapped else b for b in block_mapping.block_shape) + shape = pallas_core._get_block_shape(block_mapping.block_shape) + # Keep around squeezed as a sentinel for the lowering rules + block_shape = tuple( + pallas_core.squeezed if isinstance(b, pallas_core.Squeezed) + else pallas_core._get_block_dim_size(b) + for b in block_mapping.block_shape + ) return ( aval_to_ir_type( dynamic_shape_replacement_fn, @@ -361,7 +398,7 @@ def _get_arg_type( shape=shape, memory_space=memory_space, ), - block_mapping.block_shape, + block_shape, ) @@ -394,11 +431,12 @@ def __init__( self, jaxpr: jax_core.Jaxpr, grid_mapping: pallas_core.GridMapping, - dimension_semantics: tuple[str | tpu_core.GridDimensionSemantics, ...] | None, + dimension_semantics: Sequence[tpu_core.DimensionSemantics] | None, mesh: mesh_lib.Mesh | None, dynamic_shape_replacement_fn: Callable[ [tuple[jax.DimSize, ...]], tuple[int, ...] 
], + arg_type_fn: Callable[..., ir.Type], ): self.grid = grid_mapping.grid self.grid_names = grid_mapping.grid_names @@ -438,17 +476,17 @@ def __init__( operand_avals = in_avals[grid_mapping.slice_block_ops] scratch_avals = in_avals[grid_mapping.slice_scratch_ops] self.scalar_prefetch_types, _ = unzip2([ - _get_arg_type(dynamic_shape_replacement_fn, aval, None) + arg_type_fn(dynamic_shape_replacement_fn, aval, None) for aval in scalar_prefetch_avals ]) self.scalar_prefetch_block_shapes = tuple( aval.shape for aval in scalar_prefetch_avals) self.operand_types, self.operand_block_shapes = unzip2([ - _get_arg_type(dynamic_shape_replacement_fn, aval, block_mapping) + arg_type_fn(dynamic_shape_replacement_fn, aval, block_mapping) for aval, block_mapping in zip(operand_avals, self.block_mappings) ]) self.scratch_types, _ = unzip2([ - _get_arg_type(dynamic_shape_replacement_fn, aval, None) + arg_type_fn(dynamic_shape_replacement_fn, aval, None) for aval in scratch_avals ]) self.scratch_block_shapes = tuple( @@ -456,7 +494,7 @@ def __init__( for aval in scratch_avals ) self.grid_types, _ = unzip2([ - _get_arg_type( + arg_type_fn( dynamic_shape_replacement_fn, pallas_core.index_map_grid_aval, None, @@ -562,10 +600,10 @@ def _check_block_mappings( rank = len(bm.block_shape) # TODO(necula): add tests for SMEM blocks with trivial windowing # We support scalars too - if (bm.block_aval.memory_space == tpu_core.TPUMemorySpace.SMEM and - bm.has_trivial_window()): + memory_space = _memory_space_to_tpu_memory_space(bm.block_aval.memory_space) + if memory_space == tpu_core.MemorySpace.SMEM and bm.has_trivial_window(): continue - if bm.block_aval.memory_space == tpu_core.TPUMemorySpace.SEMAPHORE: + if memory_space == tpu_core.MemorySpace.SEMAPHORE: continue def err_details(): @@ -575,21 +613,22 @@ def err_details(): # TODO(necula): add index_map source location info f"and index_map {bm.index_map_jaxpr.jaxpr}, in " f"memory space {bm.block_aval.memory_space}." - "\nSee details at https://jax.readthedocs.io/en/latest/pallas/grid_blockspec.html#pallas-blockspec") + "\nSee details at https://docs.jax.dev/en/latest/pallas/grid_blockspec.html#pallas-blockspec") if rank < 1: raise ValueError( "The Pallas TPU lowering currently supports only blocks of " "rank >= 1. " + err_details()) - if (bm.block_aval.memory_space == tpu_core.TPUMemorySpace.ANY and - not bm.has_trivial_window()): + if ( + memory_space == tpu_core.MemorySpace.ANY + and not bm.has_trivial_window() + ): raise ValueError( "The Pallas TPU lowering currently supports in memory space ANY " "only blocks having the same block shape as the array shape " "and a trivial index_map (returning all 0s)." + err_details()) - unmapped_bs = [ - 1 if bs is pallas_core.mapped else bs for bs in bm.block_shape] + unmapped_bs = pallas_core._get_block_shape(bm.block_shape) bs0, as0 = unmapped_bs[-1], bm.array_shape_dtype.shape[-1] if rank >= 2: bs1, as1 = unmapped_bs[-2], bm.array_shape_dtype.shape[-2] @@ -643,21 +682,21 @@ def err_details(): def lower_jaxpr_to_module( lowering_context: mlir.LoweringRuleContext, - ctx: ir.Context, grid_mapping: pallas_core.GridMapping, jaxpr: jax_core.Jaxpr, *, - dimension_semantics: ( - tuple[str | tpu_core.GridDimensionSemantics, None, ...] 
| None - ), + dimension_semantics: Sequence[tpu_core.DimensionSemantics] | None, + kernel_type: tpu_core.KernelType, mesh: mesh_lib.Mesh | None = None, for_verification: bool = False, dynamic_shape_replacement_enabled: bool = False, ) -> tuple[Module, tuple[Any, ...]]: # NOTE: We should bump this periodically if is_cloud_tpu_older_than(2025, 1, 10): + platform_version = xla_bridge.get_backend().platform_version raise RuntimeError( - "Pallas TPU requires a libTPU version that's at most a month old" + "Pallas TPU requires a libtpu version that's at most a month old. Found" + f" version string:\n{platform_version}" ) debug_info = jaxpr.debug_info _mosaic_lowering_dynamic_shape_env = None @@ -686,6 +725,7 @@ def dynamic_shape_replacement_fn( dimension_semantics, mesh, dynamic_shape_replacement_fn, + arg_type_fn=_get_arg_type, ) mosaic_grid_mapping.maybe_compress_grid() m = ir.Module.create() @@ -695,10 +735,10 @@ def dynamic_shape_replacement_fn( sym_tab = ir.SymbolTable(m.operation) func_op = lower_jaxpr_to_func( - ctx, jaxpr, mosaic_grid_mapping=mosaic_grid_mapping, name="main", + kernel_type=kernel_type, for_verification=for_verification, forward_compatible=lowering_context.is_forward_compat(), dynamic_shape_replacement_fn=dynamic_shape_replacement_fn, @@ -709,6 +749,12 @@ def dynamic_shape_replacement_fn( window_params = [] static_grid = None grid = mosaic_grid_mapping.grid + if not grid and any( + not bm.has_trivial_window() for bm in grid_mapping.block_mappings + ): + raise NotImplementedError( + "Non-trivial windowing is not supported for grid-free pallas_call." + ) if grid: for i, bm in enumerate(grid_mapping.block_mappings): func_name = f"transform_{i}" @@ -716,27 +762,34 @@ def dynamic_shape_replacement_fn( tpu_memory_space = _memory_space_to_tpu_memory_space( bm.block_aval.memory_space) if ( - tpu_memory_space == tpu_core.TPUMemorySpace.ANY - or tpu_memory_space == tpu_core.TPUMemorySpace.SEMAPHORE + tpu_memory_space == tpu_core.MemorySpace.ANY + or tpu_memory_space == tpu_core.MemorySpace.SEMAPHORE ): # We checked above that the block does not require windowing. window_params.append(ir.DictAttr.get()) continue mlir_func = lower_jaxpr_to_transform_func( - ctx, bm.index_map_jaxpr.jaxpr, bm.block_aval, name=func_name, mosaic_grid_mapping=mosaic_grid_mapping, + kernel_type=kernel_type, for_verification=for_verification, forward_compatible=lowering_context.is_forward_compat(), dynamic_shape_replacement_fn=dynamic_shape_replacement_fn, ) assert mlir_func.verify(), mlir_func - block_shape = [ - 1 if b is pallas_core.mapped else b for b in bm.block_shape - ] + block_shape = list(pallas_core._get_block_shape(bm.block_shape)) + + # Force single-buffering pipelining for trivial windowing in VMEM. + pipeline_mode = bm.pipeline_mode + if ( + tpu_memory_space == tpu_core.MemorySpace.VMEM + and bm.has_trivial_window() + ): + pipeline_mode = pallas_core.Buffered(1) + # If we have an extended dtype, we need to add the block shape for the # remaining physical dtype. 
block_shape += list(_get_aval_physical_dtype_shape(bm.block_aval.inner_aval)) @@ -746,28 +799,48 @@ def dynamic_shape_replacement_fn( window_bounds=window_shape, transform_indices=ir.FlatSymbolRefAttr.get(func_name), ) - if isinstance(bm.indexing_mode, pallas_core.Unblocked): - if bm.indexing_mode.padding is None: - pad_low = pad_high = [0] * len(bm.block_shape) - else: - pad_low, pad_high = map(list, zip(*bm.indexing_mode.padding)) + for bd in bm.block_shape: + if not isinstance( + bd, (pallas_core.Element, pallas_core.Squeezed, pallas_core.Blocked) + ): + raise NotImplementedError( + "Unsupported block dimension type: " + f"{type(bd)} for block shape: {bm.block_shape}" + ) + is_element_block = [isinstance(bd, pallas_core.Element) + for bd in bm.block_shape] + if any(is_element_block): + is_element_or_squeezed_block = [ + isinstance(bd, (pallas_core.Element, pallas_core.Squeezed)) + for bd in bm.block_shape + ] + if not all(is_element_or_squeezed_block): + raise NotImplementedError( + "All block dimensions must be Elements or none of them can be" + " Elements." + ) + padding = [ + bd.padding if isinstance(bd, pallas_core.Element) else (0, 0) + for bd in bm.block_shape + ] + pad_low, pad_high = map(list, zip(*padding)) block_params["window_kind"] = ir.Attribute.parse( f"#tpu.element_window<{pad_low},{pad_high}>" ) - if bm.pipeline_mode is not None: - if not isinstance(bm.pipeline_mode, pallas_core.Buffered): + if pipeline_mode is not None: + if not isinstance(pipeline_mode, pallas_core.Buffered): raise LoweringException( - f"Unsupported pipeline mode: {bm.pipeline_mode}." + f"Unsupported pipeline mode: {pipeline_mode}." ) - buffer_count = bm.pipeline_mode.buffer_count + buffer_count = pipeline_mode.buffer_count if buffer_count < 1 or buffer_count > 2: raise LoweringException( "Only single (1) and double (2) buffering are supported. Got" f" {buffer_count}." 
) - pipeline_mode = "synchronous" if buffer_count == 1 else "double_buffered" + pipeline_mode_str = "synchronous" if buffer_count == 1 else "double_buffered" block_params["pipeline_mode"] = ir.Attribute.parse( - f"#tpu.pipeline_mode<{pipeline_mode}>" + f"#tpu.pipeline_mode<{pipeline_mode_str}>" ) window_params.append(ir.DictAttr.get(block_params)) m.body.append(mlir_func) @@ -844,14 +917,14 @@ def dynamic_shape_replacement_fn( def lower_jaxpr_to_transform_func( - ctx: ir.Context, jaxpr: jax_core.Jaxpr, aval: jax_core.AbstractValue, *, name: str, mosaic_grid_mapping: MosaicGridMapping, + kernel_type: tpu_core.KernelType, for_verification: bool, - forward_compatible: bool, + forward_compatible: bool, dynamic_shape_replacement_fn: ( Callable[[tuple[jax.DimSize, ...]], tuple[int, ...]] | None ) = None, @@ -879,7 +952,6 @@ def body_func(*args): else: mesh_context = None lowering_context = LoweringContext( - ctx, mosaic_grid_mapping.grid, mosaic_grid_mapping.grid_names, mosaic_grid_mapping.mapped_dims, @@ -887,6 +959,7 @@ def body_func(*args): arg_block_shapes, source_info_util.NameStack(), mesh_context=mesh_context, + kernel_type=kernel_type, traceback_caches=mlir.TracebackCaches(), for_verification=for_verification, forward_compatible=forward_compatible, @@ -911,12 +984,19 @@ def body_func(*args): return body.func_op +lower_jaxpr_to_func_fns = {} + + +def register_jaxpr_to_func(kernel_type: tpu_core.KernelType): + lower_jaxpr_to_func_fns[kernel_type] = lower_jaxpr_to_func + + def lower_jaxpr_to_func( - ctx: ir.Context, jaxpr: jax_core.Jaxpr, *, mosaic_grid_mapping: MosaicGridMapping, name: str, + kernel_type: tpu_core.KernelType, for_verification: bool, forward_compatible: bool, dynamic_shape_replacement_fn: ( @@ -951,7 +1031,6 @@ def body_func(*args): else: mesh_context = None lowering_context = LoweringContext( - ctx, mosaic_grid_mapping.grid, mosaic_grid_mapping.grid_names, mosaic_grid_mapping.mapped_dims, @@ -959,6 +1038,7 @@ def body_func(*args): arg_block_shapes, source_info_util.NameStack(), mesh_context=mesh_context, + kernel_type=kernel_type, traceback_caches=mlir.TracebackCaches(), for_verification=for_verification, forward_compatible=forward_compatible, @@ -1066,7 +1146,7 @@ def write_env(var: jax_core.Var, val): loc = mlir._source_info_to_location(ctx, eqn.primitive, source_info) with (source_info_util.user_context(eqn.source_info.traceback), loc, eqn.ctx.manager): - if eqn.primitive in lowering_rules: + if eqn.primitive in lowering_rules[ctx.kernel_type]: if eqn.primitive not in skip_mlir_conversions: invals = [_ensure_mlir_value(x, v.aval) for x, v in zip(invals, eqn.invars)] @@ -1084,12 +1164,12 @@ def write_env(var: jax_core.Var, val): current_name_stack, name_stack) current_name_stack = name_stack for _ in popped: - tpu.TraceStopOp() + tpu.trace_stop() for name in pushed: - tpu.TraceStartOp(message=name, level=10) + tpu.trace_start(message=name, level=10) try: - ans = lowering_rules[eqn.primitive]( + ans = lowering_rules[ctx.kernel_type][eqn.primitive]( rule_context, *invals, **eqn.params ) except LoweringException: @@ -1109,9 +1189,10 @@ def write_env(var: jax_core.Var, val): raise new_error from e else: raise NotImplementedError( - "Unimplemented primitive in Pallas TPU lowering: " - f"{eqn.primitive.name}. " - "Please file an issue on https://github.com/jax-ml/jax/issues.") + "Unimplemented primitive in Pallas TPU lowering for" + f" {ctx.kernel_type}: {eqn.primitive.name}. Please file an issue on" + " https://github.com/jax-ml/jax/issues." 
+ ) if eqn.primitive.multiple_results: foreach(write_env, eqn.outvars, ans) else: @@ -1121,7 +1202,7 @@ def write_env(var: jax_core.Var, val): popped, pushed = _compute_name_stack_updates( current_name_stack, initial_name_stack) for _ in popped: - tpu.TraceStopOp() + tpu.trace_stop() assert len(pushed) == 0 outvals = map(read_env, jaxpr.outvars) @@ -1145,6 +1226,7 @@ def _ensure_mlir_value(val, aval): ) +@register_lowering_rule(state_primitives.get_p, ensure_mlir_values=False) def _get_lowering_rule( ctx: LoweringRuleContext, ref, *idx, tree, ): @@ -1161,10 +1243,7 @@ def _get_lowering_rule( return _load_lowering_rule(ctx, *args_flat, args_tree=args_tree) -lowering_rules[state_primitives.get_p] = _get_lowering_rule -skip_mlir_conversions.add(state_primitives.get_p) - - +@register_lowering_rule(state_primitives.swap_p, ensure_mlir_values=False) def _swap_lowering_rule( ctx: LoweringRuleContext, ref, @@ -1186,9 +1265,6 @@ def _swap_lowering_rule( ) return _masked_swap_lowering_rule(ctx, *args_flat, args_tree=args_tree) -lowering_rules[state_primitives.swap_p] = _swap_lowering_rule -skip_mlir_conversions.add(state_primitives.swap_p) - def _make_index(s): if isinstance(s, (int, np.ndarray)): @@ -1230,7 +1306,7 @@ def _index_to_start_size_stride( def _indexer_to_start_size_stride( indexer: NDIndexer, - ref_block_shape: tuple[int | pallas_core.Mapped, ...], + ref_block_shape: tuple[int | pallas_core.Squeezed, ...], *, cast_to_index: bool, ) -> tuple[ @@ -1238,21 +1314,21 @@ def _indexer_to_start_size_stride( tuple[int | ir.Value, ...], tuple[int, ...], tuple[bool, ...], - tuple[int | pallas_core.Mapped, ...], + tuple[int | pallas_core.Squeezed, ...], ]: indices_iter = iter(indexer.indices) starts, sizes, strides, squeeze_dims = [], [], [], [] for s in ref_block_shape: - start, size, stride, squeeze_dim = ( - ( - _maybe_cast_to_index(cast_to_index, 0), - 1, - 1, - True, + match s: + case pallas_core.Squeezed(): + start = _maybe_cast_to_index(cast_to_index, 0) + size = 1 + stride = 1 + squeeze_dim = True + case _: + start, size, stride, squeeze_dim = _index_to_start_size_stride( + next(indices_iter), cast_to_index ) - if s is pallas_core.mapped - else _index_to_start_size_stride(next(indices_iter), cast_to_index) - ) starts.append(start) sizes.append(size) strides.append(stride) @@ -1274,10 +1350,9 @@ def _slice_memref( ref: ir.Value, indexer: NDIndexer, ref_dtype: DTypeLike, - ref_block_shape: tuple[int | pallas_core.Mapped, ...], -) -> tuple[ir.Value, tuple[int | pallas_core.Mapped, ...]]: + ref_block_shape: tuple[int | pallas_core.Squeezed, ...], +) -> tuple[ir.Value, tuple[int | pallas_core.Squeezed, ...]]: assert ref_block_shape is not None - target_shape = indexer.get_indexer_shape() starts, sizes, strides, squeeze_dims, ref_block_shape = ( _indexer_to_start_size_stride( indexer, @@ -1287,26 +1362,68 @@ def _slice_memref( ) if not all((s is None or s == 1) for s in strides): raise NotImplementedError("Strided slices of references are unsupported.") - dynamic_sizes = tuple(s for s in sizes if isinstance(s, ir.Value)) + ir_dynamic_size = ir.ShapedType.get_dynamic_size() - static_sizes = tuple(s if not isinstance(s, ir.Value) - else ir_dynamic_size for s in sizes) - target_ref_ty = ir.MemRefType.get( - static_sizes, - _dtype_to_ir_type(ref_dtype), - memory_space=ref.type.memory_space, + static_starts = [] + for s in starts: + if not isinstance(s, ir.Value): + static_starts.append(s) + elif (v := _fold_and_get_constant_value(s)) is not None: + static_starts.append(v) + else: + 
static_starts.append(ir_dynamic_size) + + static_sizes = [] + dynamic_sizes = [] + for s in sizes: + if not isinstance(s, ir.Value): + static_sizes.append(s) + elif (v := _fold_and_get_constant_value(s)) is not None: + static_sizes.append(v) + else: + static_sizes.append(ir_dynamic_size) + dynamic_sizes.append(s) + + ref_ty = ir.MemRefType(ref.type) + ref_strides, ref_offset = ref_ty.get_strides_and_offset() + if ref_offset == ir_dynamic_size or ir_dynamic_size in static_starts: + target_offset = ir_dynamic_size + else: + target_offset = sum( + map(operator.mul, static_starts, ref_strides), ref_offset + ) + out_layout = ( + ir.StridedLayoutAttr.get(target_offset, ref_strides) + if not is_cloud_tpu_older_than(2025, 6, 20) + else None ) - out = tpu.memref_slice(target_ref_ty, ref, starts, dynamic_sizes) + out_ty = ir.MemRefType.get( + static_sizes, ref_ty.element_type, out_layout, ref_ty.memory_space + ) + out = tpu.memref_slice(out_ty, ref, starts, dynamic_sizes) if any(squeeze_dims): - # We need to squeeze out some dimensions - static_sizes = tuple(s if not isinstance(s, ir.Value) - else ir_dynamic_size for s in target_shape) - squeezed_ref_ty = ir.MemRefType.get( - static_sizes, - _dtype_to_ir_type(ref_dtype), - memory_space=ref.type.memory_space, + # We need to squeeze out some dimensions. + ref_ty = out_ty + del out_ty + ref_strides, ref_offset = ref_ty.get_strides_and_offset() + target_strides = [] + target_sizes = [] + for i, dim in enumerate(ref_ty.shape): + if not squeeze_dims[i]: + target_sizes.append(dim) + target_strides.append(ref_strides[i]) + out_layout = ( + ir.StridedLayoutAttr.get(ref_offset, target_strides) + if not is_cloud_tpu_older_than(2025, 6, 20) + else None + ) + out_ty = ir.MemRefType.get( + target_sizes, + ref_ty.element_type, + out_layout, + ref_ty.memory_space, ) - out = tpu.memref_squeeze(squeezed_ref_ty, out) + out = tpu.memref_squeeze(out_ty, out) return out, ref_block_shape @@ -1314,8 +1431,8 @@ def _bitcast_memref( ref: ir.Value, bitcaster: RefBitcaster, ref_dtype: DTypeLike, - ref_block_shape: tuple[int | pallas_core.Mapped, ...], -) -> tuple[ir.Value, DTypeLike, tuple[int | pallas_core.Mapped, ...]]: + ref_block_shape: tuple[int | pallas_core.Squeezed, ...], +) -> tuple[ir.Value, DTypeLike, tuple[int | pallas_core.Squeezed, ...]]: src_bitwidth = dtype_bitwidth(ref_dtype) dst_bitwidth = dtype_bitwidth(bitcaster.dtype) if src_bitwidth != dst_bitwidth: @@ -1323,7 +1440,7 @@ def _bitcast_memref( raise NotImplementedError( "Bitcast 1D ref with bitwidth change is not supported." ) - if ref_block_shape[-2] is pallas_core.mapped: + if ref_block_shape[-2] is pallas_core.squeezed: raise NotImplementedError( "Bitcast a ref whose 2nd minormost dimension is squeezed when" " bitwidth changes." 
@@ -1337,7 +1454,7 @@ def _bitcast_memref( new_ref_block_shape = list(ref_block_shape) if ( len(new_ref_block_shape) >= 2 - and new_ref_block_shape[-2] is not pallas_core.mapped + and new_ref_block_shape[-2] is not pallas_core.squeezed ): new_ref_block_shape[-2] = ( new_ref_block_shape[-2] * src_bitwidth // dst_bitwidth @@ -1353,8 +1470,8 @@ def _reshape_memref( ref: ir.Value, reshaper: RefReshaper, ref_dtype: DTypeLike, - ref_block_shape: tuple[int | pallas_core.Mapped, ...], -) -> tuple[ir.Value, DTypeLike, tuple[int | pallas_core.Mapped, ...]]: + ref_block_shape: tuple[int | pallas_core.Squeezed, ...], +) -> tuple[ir.Value, DTypeLike, tuple[int | pallas_core.Squeezed, ...]]: if ref_dtype != reshaper.dtype: raise ValueError( f"Reshape a ref with dtype change: {reshaper.dtype} vs {ref_dtype}" @@ -1362,8 +1479,8 @@ def _reshape_memref( if len(ref_block_shape) < 2: raise NotImplementedError("Reshape 1D ref is not supported.") if ( - ref_block_shape[-2] is pallas_core.mapped - or ref_block_shape[-1] is pallas_core.mapped + ref_block_shape[-2] is pallas_core.squeezed + or ref_block_shape[-1] is pallas_core.squeezed ): raise NotImplementedError( "Reshape a ref with squeezed dimension on last two dimensions." @@ -1421,6 +1538,8 @@ class KeyScalarBundle: key_shape: tuple[int, ...] scalars: list[ir.OpResult] + +@register_lowering_rule(primitives.load_p, ensure_mlir_values=False) def _load_lowering_rule(ctx: LoweringRuleContext, *args_flat, args_tree, **_): ref, transforms, mask, _ = args_tree.unflatten(args_flat) ref_aval, transforms_avals, _, _ = args_tree.unflatten(ctx.avals_in) @@ -1443,7 +1562,7 @@ def _load_lowering_rule(ctx: LoweringRuleContext, *args_flat, args_tree, **_): ): if not is_smem_load: raise ValueError("PRNG keys must be loaded from SMEM. Did you set " - "the memory space to TPUMemorySpace.SMEM in the " + "the memory space to MemorySpace.SMEM in the " "BlockSpec for the PRNG key input?") return _prng_key_load_lowering_rule(ctx, *args_flat, args_tree=args_tree) if not is_smem_load and not ref_block_shape: @@ -1494,10 +1613,13 @@ def _load_lowering_rule(ctx: LoweringRuleContext, *args_flat, args_tree, **_): starts, ) if load_aval != aval_out: - vec_type = ir.VectorType.get(aval_out.shape, - _dtype_to_ir_type(aval_out.dtype, - is_kernel_boundary=True)) - load_val = vector.shape_cast(vec_type, load_val) + if aval_out.shape: + vec_type = ir.VectorType.get(aval_out.shape, + _dtype_to_ir_type(aval_out.dtype, + is_kernel_boundary=True)) + load_val = vector.shape_cast(vec_type, load_val) + else: + load_val = vector.extract(load_val, [], [0] * len(load_aval.shape)) return _maybe_cast_load_to_bool(ctx, aval_out, load_val) def _prng_key_load_lowering_rule(ctx: LoweringRuleContext, *args_flat, args_tree) -> KeyScalarBundle: @@ -1514,13 +1636,12 @@ def _prng_key_load_lowering_rule(ctx: LoweringRuleContext, *args_flat, args_tree ref_block_shape = aval_out.dtype._impl.key_shape if len(ref_block_shape) != 2: - raise NotImplementedError("Seed key_data must be 2D.") - if tuple(ref_block_shape) != (1, 1): - raise NotImplementedError( - f"Seed key_data of shape != (1, 1) not supported. 
Got: {ref_block_shape}") + raise NotImplementedError("Seed key_data must be 1D.") + if ref_block_shape[0] != 1: + raise NotImplementedError("Leading dimension of seed key_data must be 1.") load_ops = [] - for i in range(ref_block_shape[0]): + for i in range(ref_block_shape[1]): idx = NDIndexer(indices=(0, i), shape=ref_block_shape, int_indexer_shape=tuple()) starts, _, _, _, _ = _indexer_to_start_size_stride( @@ -1532,10 +1653,6 @@ def _prng_key_load_lowering_rule(ctx: LoweringRuleContext, *args_flat, args_tree return KeyScalarBundle(scalars=load_ops, key_shape=tuple(ref_block_shape)) -lowering_rules[primitives.load_p] = _load_lowering_rule -skip_mlir_conversions.add(primitives.load_p) - - def _maybe_cast_load_to_bool( ctx, out_aval, val: ir.Value ) -> tuple[ir.Value, jnp.dtype]: @@ -1564,13 +1681,13 @@ def _maybe_cast_load_to_bool( out_aval, is_kernel_boundary=True, ) - vector_zeros = arith.ConstantOp( + vector_zeros = arith.constant( load_vector_type, ir.DenseElementsAttr.get_splat(load_vector_type, const_zero) ) return arith.cmpi(predicate, val, vector_zeros) else: # Scalar case. - const_zero = arith.ConstantOp(load_scalar_type, const_zero) + const_zero = arith.constant(load_scalar_type, const_zero) return arith.cmpi(predicate, val, const_zero) @@ -1588,6 +1705,7 @@ def _maybe_cast_store_to_memref_type( return arith.extui(int_out_type, val) +@register_lowering_rule(primitives.swap_p, ensure_mlir_values=False) def _masked_swap_lowering_rule( ctx: LoweringRuleContext, *args_flat, args_tree, **_ ): @@ -1643,7 +1761,7 @@ def _masked_swap_lowering_rule( result = memref.load(ref, starts) result = _maybe_cast_load_to_bool(ctx, val_aval, result) val = _maybe_cast_store_to_memref_type(ctx, val_aval, val) - memref.StoreOp(val, ref, starts) + memref.store(val, ref, starts) return result if not is_vmem_store: @@ -1664,7 +1782,7 @@ def _masked_swap_lowering_rule( mem_slice_shape.insert(i, 1) mem_slice_shape_iter = iter(mem_slice_shape) mem_slice_shape = [ - 1 if b is pallas_core.mapped else next(mem_slice_shape_iter) + 1 if b is pallas_core.squeezed else next(mem_slice_shape_iter) for b in ref_block_shape ] mem_aval = aval_out.update( @@ -1682,6 +1800,8 @@ def _masked_swap_lowering_rule( result = vector.load(mem_aval_vec_type, ref, starts) val = _maybe_cast_store_to_memref_type(ctx, val_aval, val) if mem_aval != aval_out: + if not aval_out.shape: + raise ValueError("Cannot swap scalars to VMEM.") # We are slicing a scalar so provided dummy 1 indices result_vec_type = ir.VectorType.get(aval_out.shape, _dtype_to_ir_type(aval_out.dtype, is_kernel_boundary=True)) @@ -1694,16 +1814,13 @@ def _masked_swap_lowering_rule( if need_stride: if mask is not None: raise NotImplementedError("masked swap with strided store") - tpu.StridedStoreOp(val, ref, starts, strides) + tpu.strided_store(val, ref, starts, strides) else: - tpu.VectorStoreOp(val, ref, starts, [], mask=mask) + tpu.vector_store(val, ref, starts, [], mask=mask) return result -lowering_rules[primitives.swap_p] = _masked_swap_lowering_rule -skip_mlir_conversions.add(primitives.swap_p) - - +@register_lowering_rule(primitives.multiple_of_p) def _multiple_of_lowering_rule(ctx: LoweringRuleContext, val, *, values): del ctx for multiple in values: @@ -1711,9 +1828,6 @@ def _multiple_of_lowering_rule(ctx: LoweringRuleContext, val, *, values): return val -lowering_rules[primitives.multiple_of_p] = _multiple_of_lowering_rule - - def reduce_lowering_rule(reduce_fn, type_to_kind, type_to_identity): def _lowering_rule(ctx: LoweringRuleContext, x, *, axes): 
(x_aval,) = ctx.avals_in @@ -1759,7 +1873,7 @@ def _proxy_fun(val, *, axes): ctx.lowering_context.dynamic_shape_replacement_fn, ctx.avals_out[0] ) identity = ir.DenseElementsAttr.get_splat(out_type, val) - acc = arith.ConstantOp(out_type, identity) + acc = arith.constant(out_type, identity) return vector.multi_reduction(kind, x, acc, axes) return _lowering_rule @@ -1775,7 +1889,7 @@ def _proxy_fun(val, *, axes): } _reduce_max_lowering_rule = reduce_lowering_rule( jnp.max, REDUCE_MAX_KINDS, REDUCE_MAX_IDENTITY) -lowering_rules[lax.reduce_max_p] = _reduce_max_lowering_rule +register_lowering_rule(lax.reduce_max_p)(_reduce_max_lowering_rule) REDUCE_MIN_KINDS = { @@ -1789,7 +1903,7 @@ def _proxy_fun(val, *, axes): } _reduce_min_lowering_rule = reduce_lowering_rule( jnp.min, REDUCE_MIN_KINDS, REDUCE_MIN_IDENTITY) -lowering_rules[lax.reduce_min_p] = _reduce_min_lowering_rule +register_lowering_rule(lax.reduce_min_p)(_reduce_min_lowering_rule) REDUCE_SUM_KINDS = { @@ -1803,9 +1917,10 @@ def _proxy_fun(val, *, axes): } _reduce_sum_lowering_rule = reduce_lowering_rule( jnp.sum, REDUCE_SUM_KINDS, REDUCE_SUM_IDENTITY) -lowering_rules[lax.reduce_sum_p] = _reduce_sum_lowering_rule +register_lowering_rule(lax.reduce_sum_p)(_reduce_sum_lowering_rule) +@register_lowering_rule(lax.reduce_and_p) def _reduce_and_lowering_rule(ctx: LoweringRuleContext, x, *, axes): def _proxy_reduce(arg, *, axes): # Mosaic currently only supports float reductions, so we cast the boolean @@ -1818,9 +1933,8 @@ def _proxy_reduce(arg, *, axes): _proxy_reduce, multiple_results=False) return proxy_lowering(ctx, x, axes=axes) -lowering_rules[lax.reduce_and_p] = _reduce_and_lowering_rule - +@register_lowering_rule(lax.reduce_or_p) def _reduce_or_lowering_rule(ctx: LoweringRuleContext, x, *, axes): def _proxy_reduce(arg, *, axes): # Mosaic currently only supports float reductions, so we cast the boolean @@ -1833,9 +1947,8 @@ def _proxy_reduce(arg, *, axes): _proxy_reduce, multiple_results=False) return proxy_lowering(ctx, x, axes=axes) -lowering_rules[lax.reduce_or_p] = _reduce_or_lowering_rule - +@register_lowering_rule(state_primitives.broadcast_to_p) def _broadcast_to_lowering_rule( ctx: LoweringRuleContext, x, shape: Sequence[int] ): @@ -1845,29 +1958,33 @@ def _broadcast_to_lowering_rule( ) -lowering_rules[state_primitives.broadcast_to_p] = _broadcast_to_lowering_rule - - +@register_lowering_rule( + lax.broadcast_in_dim_p, kernel_types=[*tpu_core.KernelType] +) def _broadcast_in_dim_lowering_rule( ctx: LoweringRuleContext, val, *, shape, broadcast_dimensions, sharding ): del sharding (aval_in,) = ctx.avals_in (aval_out,) = ctx.avals_out + if aval_in.shape == shape: + return val - if jnp.issubdtype(aval_in.dtype, jnp.bool_): + if jnp.issubdtype(aval_in.dtype, jnp.bool_) and ( + ctx.forward_compatible or is_cloud_tpu_older_than(2025, 6, 3) + ): # Direct broadcasts for bools are not supported in Mosaic due to booleans # living in mask registers and broadcast operating on vregs. Broadcast as an # integer instead and cast back to a bool. - # TODO(b/351019164): Implement this logic in Mosaic BroadcastOp instead. 
def _proxy_fun(val, *, shape, broadcast_dimensions): int_val = jnp.where(val, 1, 0) bcast_val = jax.lax.broadcast_in_dim(int_val, shape, broadcast_dimensions) return bcast_val == 1 - proxy_lowering = lower_fun( - _proxy_fun, multiple_results=False) + + proxy_lowering = lower_fun(_proxy_fun, multiple_results=False) return proxy_lowering( - ctx, val, shape=shape, broadcast_dimensions=broadcast_dimensions) + ctx, val, shape=shape, broadcast_dimensions=broadcast_dimensions + ) if broadcast_dimensions: out_shape_list = [1] * len(shape) @@ -1886,9 +2003,6 @@ def _proxy_fun(val, *, shape, broadcast_dimensions): return vector.broadcast(out_type, val) -lowering_rules[lax.broadcast_in_dim_p] = _broadcast_in_dim_lowering_rule - - def jax_dot_dims_to_tpu_dot_dot_dims(dimension_numbers, lhs_shape, rhs_shape): """Converts a jax dot dimension numbers to a tpu dot dimension numbers. @@ -1954,6 +2068,7 @@ def format_dims(dims): return ir.Attribute.parse(tpu_dim_numbers_str) +@register_lowering_rule(lax.dot_general_p) def _dot_general_lowering_rule( ctx: LoweringRuleContext, x, @@ -2026,12 +2141,12 @@ def _dot_general_lowering_rule( else: raise NotImplementedError(f"Unsupported {preferred_element_type=}") - acc = arith.ConstantOp( + acc = arith.constant( red_type, ir.DenseElementsAttr.get_splat(red_type, val) ) - red = vector.MultiDimReductionOp( + red = vector.multi_reduction( ir.Attribute.parse("#vector.kind"), - arith.MulFOp(x, y), + arith.mulf(x, y), acc, [1] ) @@ -2053,7 +2168,7 @@ def _dot_general_lowering_rule( ) else: raise NotImplementedError(f"Unsupported dot precision: {precision}") - out_tile = arith.ConstantOp( + out_tile = arith.constant( out_type, ir.DenseElementsAttr.get_splat(out_type, val) ) return tpu.matmul( @@ -2066,8 +2181,6 @@ def _dot_general_lowering_rule( ) -lowering_rules[lax.dot_general_p] = _dot_general_lowering_rule - def _convert_helper(x, *, to_dtype): # Helper function for dtype conversion from_dtype = x.dtype @@ -2096,18 +2209,12 @@ def _convert_helper(x, *, to_dtype): # unsigned -> float is unsupported. We fall through and raise at the bottom. if not jnp.issubdtype(to_dtype, jnp.floating): return x.astype(to_dtype) - if jnp.issubdtype(from_dtype, jnp.floating) and jnp.issubdtype( - to_dtype, jnp.signedinteger - ): - if from_dtype.itemsize < 4: - x = x.astype(jnp.float32) - if to_dtype.itemsize < 4: - # Need to clip values to match XLA - minval, maxval = jnp.iinfo(to_dtype).min, jnp.iinfo(to_dtype).max - x = jnp.clip(x, minval, maxval) - return x.astype(jnp.int32).astype(to_dtype) raise NotImplementedError(f"Unsupported cast: {from_dtype} -> {to_dtype}") + +@register_lowering_rule( + lax.convert_element_type_p, kernel_types=[*tpu_core.KernelType] +) def _convert_element_type_lowering_rule( ctx: LoweringRuleContext, x, *, new_dtype, weak_type, sharding ): @@ -2147,23 +2254,21 @@ def _convert_element_type_lowering_rule( elif jnp.iinfo(old_dtype).bits == jnp.iinfo(new_dtype).bits: # This case triggers when casting signed to unsigned or vice versa. return x - # TODO(apaszke): Remove both_32bit constraints using the Mosaic canonicalizer. elif _from(floating) and _to(signed): - # TODO(apaszke): Remove once a month has passed, along with the - # _convert_helper float -> signed conversion above. 
- if not ctx.forward_compatible or both_32bit: - return arith.fptosi(out_type, x) - elif _from(signed) and _to(floating) and both_32bit: - return arith.sitofp(out_type, x) + return arith.fptosi(out_type, x) + elif _from(signed) and _to(floating): + if ( + not (ctx.forward_compatible or is_cloud_tpu_older_than(2025, 5, 12)) + or both_32bit + ): + return arith.sitofp(out_type, x) elif old_dtype == jnp.bool_ and _to(integer) and new_dtype.itemsize == 4: return arith.extui(out_type, x) return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype), multiple_results=False)(ctx, x) -lowering_rules[lax.convert_element_type_p] = _convert_element_type_lowering_rule - - +@register_lowering_rule(lax.reshape_p) def _reshape_lowering_rule(ctx: LoweringRuleContext, x, new_sizes, dimensions, sharding): if dimensions is not None: @@ -2177,6 +2282,8 @@ def _reshape_lowering_rule(ctx: LoweringRuleContext, x, new_sizes, dimensions, ), x, ) + if not ctx.avals_out[0].shape: + return vector.extract(x, [], [0] * len(ctx.avals_in[0].shape)) return vector.shape_cast( aval_to_ir_type( ctx.lowering_context.dynamic_shape_replacement_fn, ctx.avals_out[0] @@ -2185,9 +2292,7 @@ def _reshape_lowering_rule(ctx: LoweringRuleContext, x, new_sizes, dimensions, ) -lowering_rules[lax.reshape_p] = _reshape_lowering_rule - - +@register_lowering_rule(lax.squeeze_p, kernel_types=[*tpu_core.KernelType]) def _squeeze_lowering_rule(ctx: LoweringRuleContext, x, dimensions): del dimensions # Unused. (aval_in,) = ctx.avals_in @@ -2208,9 +2313,7 @@ def _squeeze_lowering_rule(ctx: LoweringRuleContext, x, dimensions): ) -lowering_rules[lax.squeeze_p] = _squeeze_lowering_rule - - +@register_lowering_rule(lax.concatenate_p) def _concatenate_lowering_rule(ctx: LoweringRuleContext, *xs, dimension): out_type = aval_to_ir_type( ctx.lowering_context.dynamic_shape_replacement_fn, ctx.avals_out[0] @@ -2218,9 +2321,7 @@ def _concatenate_lowering_rule(ctx: LoweringRuleContext, *xs, dimension): return tpu.concatenate(out_type, xs, dimension=dimension) -lowering_rules[lax.concatenate_p] = _concatenate_lowering_rule - - +@register_lowering_rule(lax.split_p) def _split_lowering_rule( ctx: LoweringRuleContext, x, *, sizes, axis ): @@ -2245,20 +2346,27 @@ def _split_lowering_rule( starts[axis] += size return outs -lowering_rules[lax.split_p] = _split_lowering_rule - +@register_lowering_rule(lax.iota_p) def _iota_lowering_rule(ctx: LoweringRuleContext, dtype, shape, dimension, sharding): + if len(shape) == 1: + if dimension != 0: + raise ValueError("Dimension must be 0 for 1D iota.") + def _1d_iota_helper(): + iota_2d = lax.iota_p.bind(dtype=dtype, + shape=(1,) + shape, + dimension=1, + sharding=sharding) + return iota_2d[0] + return lower_fun(_1d_iota_helper, multiple_results=False)(ctx) out_type = aval_to_ir_type( ctx.lowering_context.dynamic_shape_replacement_fn, ctx.avals_out[0] ) return tpu.iota(out_type, dimension=dimension) -lowering_rules[lax.iota_p] = _iota_lowering_rule - - +@register_lowering_rule(lax.gather_p) def _gather_lowering_rule( ctx: LoweringRuleContext, x, @@ -2312,7 +2420,11 @@ def _gather_lowering_rule( operand_batching_dims=(1,), start_indices_batching_dims=(1,), ): - return tpu.dynamic_gather(out_type, x, recovered_indices, 0) + if jaxlib_version < (0, 6, 3): + # TODO: b/423649694 - Remove on 2025-07-18 + return tpu.dynamic_gather(x, recovered_indices, 0) + else: + return tpu.dynamic_gather(x, recovered_indices, [0]) if dimension_numbers == lax.GatherDimensionNumbers( offset_dims=(), collapsed_slice_dims=(1,), @@ -2320,23 
+2432,27 @@ def _gather_lowering_rule( operand_batching_dims=(0,), start_indices_batching_dims=(0,), ): - return tpu.dynamic_gather(out_type, x, recovered_indices, 1) + if jaxlib_version < (0, 6, 3): + # TODO: b/423649694 - Remove on 2025-07-18 + return tpu.dynamic_gather(x, recovered_indices, 1) + else: + return tpu.dynamic_gather(x, recovered_indices, [1]) raise NotImplementedError("Unsupported gather") -lowering_rules[lax.gather_p] = _gather_lowering_rule - - +@register_lowering_rule(lax.transpose_p) def _transpose_lowering_rule(ctx: LoweringRuleContext, x, *, permutation): - if permutation != (1, 0): + minormost_transpose = (1, 0) + untiled_tiled_swap = (1, 0, 2) + if permutation not in (minormost_transpose, untiled_tiled_swap): raise NotImplementedError out_type = aval_to_ir_type( ctx.lowering_context.dynamic_shape_replacement_fn, ctx.avals_out[0] ) - return vector.transpose(out_type, x, permutation) - - -lowering_rules[lax.transpose_p] = _transpose_lowering_rule + if ctx.forward_compatible or is_cloud_tpu_older_than(2025, 5, 8): + return vector.transpose(out_type, x, permutation) + else: + return tpu.transpose(out_type, x, permutation) def _bcast(x, y, x_aval, y_aval, out_aval): @@ -2346,13 +2462,13 @@ def _bcast(x, y, x_aval, y_aval, out_aval): y_dtype = x_aval.dtype elif x_aval.weak_type: x_dtype = y_aval.dtype - if isinstance(x, (np.ndarray, np.number, int, float)): + if not isinstance(x, ir.Value): if getattr(y, "type", None) == ir.IndexType.get(): mlir_type = y.type else: mlir_type = _dtype_to_ir_type(x_dtype) x = ir_constant(x, mlir_type) - if isinstance(y, (np.ndarray, np.number, int, float)): + if not isinstance(y, ir.Value): if getattr(x, "type", None) == ir.IndexType.get(): mlir_type = x.type else: @@ -2368,6 +2484,10 @@ def _bcast(x, y, x_aval, y_aval, out_aval): return x, y +@register_lowering_rule( + lax.add_p, kernel_types=[*tpu_core.KernelType], ensure_mlir_values=False +) +@register_lowering_rule(ad_util.add_any_p, ensure_mlir_values=False) def _add_lowering_rule(ctx: LoweringRuleContext, x, y): x, y = _bcast(x, y, ctx.avals_in[0], ctx.avals_in[1], ctx.avals_out[0]) (aval_out,) = ctx.avals_out @@ -2378,12 +2498,6 @@ def _add_lowering_rule(ctx: LoweringRuleContext, x, y): raise NotImplementedError(aval_out.dtype) -lowering_rules[lax.add_p] = _add_lowering_rule -skip_mlir_conversions.add(lax.add_p) -lowering_rules[ad_util.add_any_p] = _add_lowering_rule -skip_mlir_conversions.add(ad_util.add_any_p) - - class FoldingError(Exception): pass @@ -2391,7 +2505,7 @@ class FoldingError(Exception): def _fold_and_get_constant_value(x): def _fold(x, fuel): if fuel <= 0: - raise FoldingError("Folding depth exceeded") + raise FoldingError() op_name = getattr(x.owner, "name", None) binop_folds = { "arith.maxsi": max, @@ -2406,7 +2520,7 @@ def _fold(x, fuel): raise ValueError(f"Unsupported constant type: {x.type}") if op_name in binop_folds: return binop_folds[op_name](_fold(v, fuel - 1) for v in x.owner.operands) - raise FoldingError(f"Folding not supported for {x.owner}") + raise FoldingError() try: return _fold(x, 10) @@ -2414,6 +2528,7 @@ def _fold(x, fuel): return None +@register_lowering_rule(lax.max_p, ensure_mlir_values=False) def _max_lowering_rule(ctx: LoweringRuleContext, x, y): x, y = _bcast(x, y, ctx.avals_in[0], ctx.avals_in[1], ctx.avals_out[0]) (aval_out,) = ctx.avals_out @@ -2426,10 +2541,7 @@ def _max_lowering_rule(ctx: LoweringRuleContext, x, y): raise NotImplementedError(aval_out.dtype) -lowering_rules[lax.max_p] = _max_lowering_rule 
-skip_mlir_conversions.add(lax.max_p) - - +@register_lowering_rule(lax.min_p, ensure_mlir_values=False) def _min_lowering_rule(ctx: LoweringRuleContext, x, y): x, y = _bcast(x, y, ctx.avals_in[0], ctx.avals_in[1], ctx.avals_out[0]) (aval_out,) = ctx.avals_out @@ -2442,10 +2554,9 @@ def _min_lowering_rule(ctx: LoweringRuleContext, x, y): raise NotImplementedError(aval_out.dtype) -lowering_rules[lax.min_p] = _min_lowering_rule -skip_mlir_conversions.add(lax.min_p) - - +@register_lowering_rule( + lax.sub_p, kernel_types=[*tpu_core.KernelType], ensure_mlir_values=False +) def _sub_lowering_rule(ctx: LoweringRuleContext, x, y): x, y = _bcast(x, y, ctx.avals_in[0], ctx.avals_in[1], ctx.avals_out[0]) (aval_out,) = ctx.avals_out @@ -2456,10 +2567,9 @@ def _sub_lowering_rule(ctx: LoweringRuleContext, x, y): raise NotImplementedError(aval_out.dtype) -lowering_rules[lax.sub_p] = _sub_lowering_rule -skip_mlir_conversions.add(lax.sub_p) - - +@register_lowering_rule( + lax.mul_p, kernel_types=[*tpu_core.KernelType], ensure_mlir_values=False +) def _mul_lowering_rule(ctx: LoweringRuleContext, x, y): x, y = _bcast(x, y, ctx.avals_in[0], ctx.avals_in[1], ctx.avals_out[0]) (aval_out,) = ctx.avals_out @@ -2470,10 +2580,9 @@ def _mul_lowering_rule(ctx: LoweringRuleContext, x, y): raise NotImplementedError(aval_out.dtype) -lowering_rules[lax.mul_p] = _mul_lowering_rule -skip_mlir_conversions.add(lax.mul_p) - - +@register_lowering_rule( + lax.div_p, kernel_types=[*tpu_core.KernelType], ensure_mlir_values=False +) def _div_lowering_rule(ctx: LoweringRuleContext, x, y): x, y = _bcast(x, y, ctx.avals_in[0], ctx.avals_in[1], ctx.avals_out[0]) (aval_out,) = ctx.avals_out @@ -2486,10 +2595,9 @@ def _div_lowering_rule(ctx: LoweringRuleContext, x, y): raise NotImplementedError(aval_out.dtype) -lowering_rules[lax.div_p] = _div_lowering_rule -skip_mlir_conversions.add(lax.div_p) - - +@register_lowering_rule( + lax.rem_p, kernel_types=[*tpu_core.KernelType], ensure_mlir_values=False +) def _rem_lowering_rule(ctx: LoweringRuleContext, x, y): x, y = _bcast(x, y, ctx.avals_in[0], ctx.avals_in[1], ctx.avals_out[0]) (aval_out,) = ctx.avals_out @@ -2502,10 +2610,7 @@ def _rem_lowering_rule(ctx: LoweringRuleContext, x, y): raise NotImplementedError(aval_out.dtype) -lowering_rules[lax.rem_p] = _rem_lowering_rule -skip_mlir_conversions.add(lax.rem_p) - - +@register_lowering_rule(lax.abs_p) def _abs_lowering_rule(ctx: LoweringRuleContext, x): (aval_out,) = ctx.avals_out if jnp.issubdtype(aval_out.dtype, jnp.integer): @@ -2515,9 +2620,7 @@ def _abs_lowering_rule(ctx: LoweringRuleContext, x): raise NotImplementedError(aval_out.dtype) -lowering_rules[lax.abs_p] = _abs_lowering_rule - - +@register_lowering_rule(lax.neg_p, ensure_mlir_values=False) def _neg_lowering_rule(ctx: LoweringRuleContext, x): (x_aval,) = ctx.avals_in new_ctx = ctx.replace( @@ -2527,58 +2630,49 @@ def _neg_lowering_rule(ctx: LoweringRuleContext, x): return _sub_lowering_rule(new_ctx, np.array(0, dtype=x_aval.dtype), x) -lowering_rules[lax.neg_p] = _neg_lowering_rule -skip_mlir_conversions.add(lax.neg_p) - - +@register_lowering_rule(lax.sign_p, kernel_types=[*tpu_core.KernelType]) def _sign_lowering_rule(ctx: LoweringRuleContext, x): return lower_fun( pallas_utils.sign_lowering_helper, multiple_results=False, )(ctx, x) -lowering_rules[lax.sign_p] = _sign_lowering_rule - - +@register_lowering_rule(lax.nextafter_p) def _nextafter_lowering_rule(ctx: LoweringRuleContext, x, y): return lower_fun( pallas_utils.nextafter_lowering_helper, multiple_results=False, )(ctx, x, 
y) -lowering_rules[lax.nextafter_p] = _nextafter_lowering_rule - - -def _rsqrt_lowering_rule(ctx: LoweringRuleContext, x): +@register_lowering_rule(lax.rsqrt_p) +def _rsqrt_lowering_rule(ctx: LoweringRuleContext, x, accuracy): + if accuracy is not None: + raise NotImplementedError("Not implemented: accuracy") return math.rsqrt(x) -lowering_rules[lax.rsqrt_p] = _rsqrt_lowering_rule - - -def _sqrt_lowering_rule(ctx: LoweringRuleContext, x): +@register_lowering_rule(lax.sqrt_p) +def _sqrt_lowering_rule(ctx: LoweringRuleContext, x, accuracy): + if accuracy is not None: + raise NotImplementedError("Not implemented: accuracy") return math.sqrt(x) -lowering_rules[lax.sqrt_p] = _sqrt_lowering_rule - - +@register_lowering_rule(lax.square_p) def _square_lowering_rule(ctx: LoweringRuleContext, x): if jnp.issubdtype(ctx.avals_in[0].dtype, jnp.integer): return arith.muli(x, x) return arith.mulf(x, x) -lowering_rules[lax.square_p] = _square_lowering_rule - - -def _exp_lowering_rule(ctx: LoweringRuleContext, x): +@register_lowering_rule(lax.exp_p) +def _exp_lowering_rule(ctx: LoweringRuleContext, x, accuracy): + if accuracy is not None: + raise NotImplementedError("Not implemented: accuracy") return math.exp(x) -lowering_rules[lax.exp_p] = _exp_lowering_rule - - +@register_lowering_rule(lax.pow_p, ensure_mlir_values=False) def _pow_lowering_rule(ctx: LoweringRuleContext, x, y): # jax accepts float base (x) and integer/float exponent (y), and integer # exponent is casted to float. @@ -2593,32 +2687,28 @@ def _pow_lowering_rule(ctx: LoweringRuleContext, x, y): return math.powf(x, y) -lowering_rules[lax.pow_p] = _pow_lowering_rule -skip_mlir_conversions.add(lax.pow_p) - - +@register_lowering_rule(lax.integer_pow_p) def _integer_pow_lowering_rule(ctx: LoweringRuleContext, x, *, y): return lower_fun(lax_internal._integer_pow, multiple_results=False)( ctx, x, y=y) -lowering_rules[lax.integer_pow_p] = _integer_pow_lowering_rule - - -def _exp2_lowering_rule(ctx: LoweringRuleContext, x): +@register_lowering_rule(lax.exp2_p, ensure_mlir_values=False) +def _exp2_lowering_rule(ctx: LoweringRuleContext, x, accuracy): # exp2 in JAX lowers to exp(ln2 * x), not to pow2. We match that behavior # here. 
+ if accuracy is not None: + raise NotImplementedError("Not implemented: accuracy") return lower_fun( lambda x: jnp.exp(jnp.astype(np.log(2), x.dtype) * x), multiple_results=False, )(ctx, x) -lowering_rules[lax.exp2_p] = _exp2_lowering_rule -skip_mlir_conversions.add(lax.exp2_p) - - -def _logistic_lowering_rule(ctx: LoweringRuleContext, x): +@register_lowering_rule(lax.logistic_p) +def _logistic_lowering_rule(ctx: LoweringRuleContext, x, accuracy): + if accuracy is not None: + raise NotImplementedError("Not implemented: accuracy") neg_x = arith.negf(x) exp_neg_x = math.exp(neg_x) aval_out = ctx.avals_out[0] @@ -2633,51 +2723,49 @@ def _logistic_lowering_rule(ctx: LoweringRuleContext, x): return arith.divf(one, denom) -lowering_rules[lax.logistic_p] = _logistic_lowering_rule - - -def _sin_lowering_rule(ctx: LoweringRuleContext, x): +@register_lowering_rule(lax.sin_p) +def _sin_lowering_rule(ctx: LoweringRuleContext, x, accuracy): + if accuracy is not None: + raise NotImplementedError("Not implemented: accuracy") return math.sin(x) -lowering_rules[lax.sin_p] = _sin_lowering_rule - - -def _cos_lowering_rule(ctx: LoweringRuleContext, x): +@register_lowering_rule(lax.cos_p) +def _cos_lowering_rule(ctx: LoweringRuleContext, x, accuracy): + if accuracy is not None: + raise NotImplementedError("Not implemented: accuracy") return math.cos(x) -lowering_rules[lax.cos_p] = _cos_lowering_rule - - -def _tan_lowering_rule(ctx: LoweringRuleContext, x): +@register_lowering_rule(lax.tan_p) +def _tan_lowering_rule(ctx: LoweringRuleContext, x, accuracy): + if accuracy is not None: + raise NotImplementedError("Not implemented: accuracy") return math.tan(x) -lowering_rules[lax.tan_p] = _tan_lowering_rule - - -def _tanh_lowering_rule(ctx: LoweringRuleContext, x): +@register_lowering_rule(lax.tanh_p) +def _tanh_lowering_rule(ctx: LoweringRuleContext, x, accuracy): + if accuracy is not None: + raise NotImplementedError("Not implemented: accuracy") return math.tanh(x) -lowering_rules[lax.tanh_p] = _tanh_lowering_rule - - -def _log_lowering_rule(ctx: LoweringRuleContext, x): +@register_lowering_rule(lax.log_p) +def _log_lowering_rule(ctx: LoweringRuleContext, x, accuracy): + if accuracy is not None: + raise NotImplementedError("Not implemented: accuracy") return math.log(x) -lowering_rules[lax.log_p] = _log_lowering_rule - - -def _log1p_lowering_rule(ctx: LoweringRuleContext, x): +@register_lowering_rule(lax.log1p_p) +def _log1p_lowering_rule(ctx: LoweringRuleContext, x, accuracy): + if accuracy is not None: + raise NotImplementedError("Not implemented: accuracy") return math.log1p(x) -lowering_rules[lax.log1p_p] = _log1p_lowering_rule - - +@register_lowering_rule(lax.round_p) def _round_lowering_rule(ctx: LoweringRuleContext, x, *, rounding_method): if rounding_method == 0: return math.round(x) @@ -2687,37 +2775,28 @@ def _round_lowering_rule(ctx: LoweringRuleContext, x, *, rounding_method): raise NotImplementedError(f"Unsupported rounding method: {rounding_method}") -lowering_rules[lax.round_p] = _round_lowering_rule - - +@register_lowering_rule(lax.ceil_p) def _ceil_lowering_rule(ctx: LoweringRuleContext, x): return math.ceil(x) -lowering_rules[lax.ceil_p] = _ceil_lowering_rule - - +@register_lowering_rule(lax.floor_p) def _floor_lowering_rule(ctx: LoweringRuleContext, x): return math.floor(x) -lowering_rules[lax.floor_p] = _floor_lowering_rule - - +@register_lowering_rule(lax.clz_p) def _clz_lowering_rule(ctx: LoweringRuleContext, x): return math.ctlz(x) -lowering_rules[lax.clz_p] = _clz_lowering_rule - 
+@register_lowering_rule(lax.population_count_p) def _population_count_lowering_rule(ctx: LoweringRuleContext, x): aval_out = ctx.avals_out[0] if aval_out.shape == (): raise ValueError("Population count is not supported on scalars") return math.ctpop(x) -lowering_rules[lax.population_count_p] = _population_count_lowering_rule - # Mapping for signed integer comparisons. _cmpsi_lowering_types = { @@ -2823,23 +2902,21 @@ def _cmp_lowering_rule(primitive, ctx: LoweringRuleContext, x, y): raise NotImplementedError(f"Unsupported dtype in cmp: {dtype}") -lowering_rules[lax.eq_p] = functools.partial(_cmp_lowering_rule, lax.eq_p) -lowering_rules[lax.ne_p] = functools.partial(_cmp_lowering_rule, lax.ne_p) -lowering_rules[lax.lt_p] = functools.partial(_cmp_lowering_rule, lax.lt_p) -lowering_rules[lax.le_p] = functools.partial(_cmp_lowering_rule, lax.le_p) -lowering_rules[lax.gt_p] = functools.partial(_cmp_lowering_rule, lax.gt_p) -lowering_rules[lax.ge_p] = functools.partial(_cmp_lowering_rule, lax.ge_p) +for prim in [lax.eq_p, lax.ne_p, lax.lt_p, lax.le_p, lax.gt_p, lax.ge_p]: + register_lowering_rule(prim, kernel_types=[*tpu_core.KernelType])( + functools.partial(_cmp_lowering_rule, prim) + ) +@register_lowering_rule( + lax.and_p, kernel_types=[*tpu_core.KernelType], ensure_mlir_values=False +) def _and_lowering_rule(ctx: LoweringRuleContext, x, y): x, y = _bcast(x, y, *ctx.avals_in, *ctx.avals_out) return arith.andi(x, y) -lowering_rules[lax.and_p] = _and_lowering_rule -skip_mlir_conversions.add(lax.and_p) - - +@register_lowering_rule(lax.is_finite_p) def _is_finite_lowering_rule(ctx: LoweringRuleContext, x): out_aval, = ctx.avals_out out_type = aval_to_ir_type( @@ -2848,18 +2925,15 @@ def _is_finite_lowering_rule(ctx: LoweringRuleContext, x): return _not_lowering_rule(ctx, tpu.weird(out_type, x)) -lowering_rules[lax.is_finite_p] = _is_finite_lowering_rule - - +@register_lowering_rule( + lax.or_p, kernel_types=[*tpu_core.KernelType], ensure_mlir_values=False +) def _or_lowering_rule(ctx: LoweringRuleContext, x, y): x, y = _bcast(x, y, *ctx.avals_in, *ctx.avals_out) return arith.ori(x, y) -lowering_rules[lax.or_p] = _or_lowering_rule -skip_mlir_conversions.add(lax.or_p) - - +@register_lowering_rule(lax.not_p) def _not_lowering_rule(ctx: LoweringRuleContext, x): # The primitive not_p is lowered to # https://github.com/openxla/stablehlo/blob/main/docs/spec.md#not @@ -2878,14 +2952,13 @@ def _not_lowering_rule(ctx: LoweringRuleContext, x): ctx.lowering_context.dynamic_shape_replacement_fn, out_aval ) scalar_minus_one = ir.IntegerAttr.get(out_scalar_type, -1) - minus_one = arith.ConstantOp( + minus_one = arith.constant( out_type, ir.DenseElementsAttr.get_splat(out_type, scalar_minus_one) ) return arith.xori(x, minus_one) -lowering_rules[lax.not_p] = _not_lowering_rule - +@register_lowering_rule(lax.select_n_p, kernel_types=[*tpu_core.KernelType]) def _select_n_lowering_rule(ctx: LoweringRuleContext, pred, x, *args): if len(args) > 1: raise NotImplementedError("select_n only supported with <= 2 arguments") @@ -2905,22 +2978,18 @@ def _select_n_lowering_rule(ctx: LoweringRuleContext, pred, x, *args): return arith.select(pred, y, x) -lowering_rules[lax.select_n_p] = _select_n_lowering_rule - - def _clamp(min, operand, max): res = jnp.maximum(operand, min) return jnp.minimum(res, max) +@register_lowering_rule(lax.clamp_p) def _clamp_lowering_rule(ctx: LoweringRuleContext, min, operand, max): """Compute minimum_p(maximum_p(min, operand), max).""" return lower_fun(_clamp, multiple_results=False)(ctx, min, 
operand, max) -lowering_rules[lax.clamp_p] = _clamp_lowering_rule - - +@register_lowering_rule(for_loop.for_p) def _for_lowering_rule( ctx: LoweringRuleContext, *args, @@ -2952,9 +3021,6 @@ def _for_lowering_rule( return args -lowering_rules[for_loop.for_p] = _for_lowering_rule - - def _lower_jaxpr_to_for_loop(ctx: LoweringRuleContext, jaxpr: jax_core.Jaxpr, start: int | ir.Value, num_steps: int | ir.Value, consts, *args, @@ -2997,10 +3063,13 @@ def _run_body(i, args): iv = for_op.induction_variable inner_args = for_op.inner_iter_args inner_out = _run_body(iv, inner_args) - scf.YieldOp(inner_out) + scf.yield_(inner_out) return for_op.results +@register_lowering_rule( + lax.scan_p, kernel_types=[*tpu_core.KernelType], ensure_mlir_values=False +) def _scan_lowering_rule( ctx: LoweringRuleContext, *args, @@ -3045,8 +3114,6 @@ def _scan_lowering_rule( mlir_type=_dtype_to_ir_type(jnp.dtype('int32'))), *out] return out -lowering_rules[lax.scan_p] = _scan_lowering_rule -skip_mlir_conversions.add(lax.scan_p) def _lower_while_via_fori( @@ -3076,6 +3143,7 @@ def _lower_while_via_fori( return [ub, ub, *for_out] +@register_lowering_rule(lax.while_p) def _while_lowering_rule( ctx: LoweringRuleContext, *args, @@ -3136,9 +3204,8 @@ def _while_lowering_rule( return list(while_op.results) -lowering_rules[lax.while_p] = _while_lowering_rule - -def _cond_lowering_rule(ctx: LoweringRuleContext, *args, branches): +@register_lowering_rule(lax.cond_p) +def _cond_lowering_rule(ctx: LoweringRuleContext, *args, branches, **params): index, *args = args constant_index = _fold_and_get_constant_value(index) @@ -3169,29 +3236,25 @@ def _cond_lowering_rule(ctx: LoweringRuleContext, *args, branches): ) else: out = jaxpr_subcomp(lowering_context, branches[1].jaxpr, *args) - scf.YieldOp(out) + scf.yield_(out) with ir.InsertionPoint(if_op.else_block): out = jaxpr_subcomp(lowering_context, branches[0].jaxpr, *args) - scf.YieldOp(out) + scf.yield_(out) return if_op.results -lowering_rules[lax.cond_p] = _cond_lowering_rule - - +@register_lowering_rule(pjit.pjit_p, kernel_types=[*tpu_core.KernelType]) def _pjit_lowering_rule(ctx: LoweringRuleContext, *args, jaxpr, **_): lowering_context = ctx.lowering_context.replace(block_shapes=ctx.block_shapes) return jaxpr_subcomp(lowering_context, jaxpr.jaxpr, *args) -lowering_rules[pjit.pjit_p] = _pjit_lowering_rule - - +@register_lowering_rule(pjit.mesh_cast_p) def _mesh_cast_lowering_rule(ctx, x, dst_sharding): return x -lowering_rules[pjit.mesh_cast_p] = _mesh_cast_lowering_rule +@register_lowering_rule(custom_derivatives.custom_jvp_call_p) def _custom_jvp_call_lowering_rule( ctx: LoweringRuleContext, *args, @@ -3208,34 +3271,49 @@ def _custom_jvp_call_lowering_rule( return jaxpr_subcomp(lowering_context, call_jaxpr.jaxpr, *args) -lowering_rules[custom_derivatives.custom_jvp_call_p] = ( - _custom_jvp_call_lowering_rule) +@register_lowering_rule(custom_derivatives.custom_vjp_call_p) +def _custom_vjp_call_lowering_rule( + ctx: LoweringRuleContext, + *args, + call_jaxpr, + fwd_jaxpr_thunk, + out_trees, + symbolic_zeros, + bwd, + num_consts, +): + if num_consts: raise NotImplementedError + lowering_context = ctx.lowering_context.replace(block_shapes=ctx.block_shapes) + return jaxpr_subcomp(lowering_context, call_jaxpr.jaxpr, *args) +@register_lowering_rule(debugging.debug_callback_p) def _debug_callback_lowering_rule(ctx: LoweringRuleContext, *args, **kwargs): del ctx, args, kwargs # No-op debug callbacks in Mosaic for now return [] -lowering_rules[debugging.debug_callback_p] = 
_debug_callback_lowering_rule - - +@register_lowering_rule( + primitives.program_id_p, kernel_types=[*tpu_core.KernelType] +) def _program_id_lowering_rule(ctx: LoweringRuleContext, *, axis: int): - if ctx.lowering_context.user_grid_indices is None: raise ValueError( f"program id: {axis} was passed, but user did not provide a grid." ) length = len(ctx.lowering_context.user_grid_indices) - if not (0 <= axis < length): + if axis not in range(length): raise ValueError( f"user passed in program id with axis: {axis}, but grid only has" f" length: {length}" ) return ctx.lowering_context.user_grid_indices[axis] -lowering_rules[primitives.program_id_p] = _program_id_lowering_rule + +@register_lowering_rule( + primitives.num_programs_p, kernel_types=[*tpu_core.KernelType] +) def _num_programs_lowering_rule(ctx: LoweringRuleContext, *, axis: int): mapped_axes = set(ctx.lowering_context.mapped_dims) seen_user_axes = 0 @@ -3249,9 +3327,9 @@ def _num_programs_lowering_rule(ctx: LoweringRuleContext, *, axis: int): f" length: {len(ctx.lowering_context.grid_rank)}" ) return tpu.iteration_bound(i) -lowering_rules[primitives.num_programs_p] = _num_programs_lowering_rule +@register_lowering_rule(tpu_primitives.repeat_p) def _repeat_lowering_rule(ctx: LoweringRuleContext, x, *, repeats, axis): (out_aval,) = ctx.avals_out return tpu.repeat( @@ -3264,9 +3342,7 @@ def _repeat_lowering_rule(ctx: LoweringRuleContext, x, *, repeats, axis): ) -lowering_rules[tpu_primitives.repeat_p] = _repeat_lowering_rule - - +@register_lowering_rule(tpu_primitives.roll_p) def _roll_lowering_rule( ctx: LoweringRuleContext, x, shift, *, axis, stride, stride_axis ): @@ -3283,9 +3359,7 @@ def _roll_lowering_rule( ) -lowering_rules[tpu_primitives.roll_p] = _roll_lowering_rule - - +@register_lowering_rule(lax.slice_p, kernel_types=[*tpu_core.KernelType]) def _slice_lowering_rule( ctx: LoweringRuleContext, x, limit_indices, start_indices, strides ): @@ -3302,62 +3376,55 @@ def _slice_lowering_rule( ) -lowering_rules[lax.slice_p] = _slice_lowering_rule - - +@register_lowering_rule( + lax.xor_p, kernel_types=[*tpu_core.KernelType], ensure_mlir_values=False +) def _xor_lowering_rule(ctx: LoweringRuleContext, x, y): x, y = _bcast(x, y, *ctx.avals_in, *ctx.avals_out) return arith.xori(x, y) -lowering_rules[lax.xor_p] = _xor_lowering_rule -skip_mlir_conversions.add(lax.xor_p) - - +@register_lowering_rule( + lax.shift_left_p, + kernel_types=[*tpu_core.KernelType], + ensure_mlir_values=False, +) def _shift_left_lowering_rule(ctx: LoweringRuleContext, x, d): x, d = _bcast(x, d, *ctx.avals_in, *ctx.avals_out) return arith.shli(x, d) -lowering_rules[lax.shift_left_p] = _shift_left_lowering_rule -skip_mlir_conversions.add(lax.shift_left_p) - - +@register_lowering_rule(lax.shift_right_arithmetic_p, ensure_mlir_values=False) def _shift_right_arithmetic_lowering_rule(ctx: LoweringRuleContext, x, d): x, d = _bcast(x, d, *ctx.avals_in, *ctx.avals_out) return arith.shrsi(x, d) -lowering_rules[lax.shift_right_arithmetic_p] = _shift_right_arithmetic_lowering_rule -skip_mlir_conversions.add(lax.shift_right_arithmetic_p) - - -def _shift_right_logical_lowering_rules(ctx: LoweringRuleContext, x, d): +@register_lowering_rule( + lax.shift_right_logical_p, + kernel_types=[*tpu_core.KernelType], + ensure_mlir_values=False, +) +def _shift_right_logical_lowering_rule(ctx: LoweringRuleContext, x, d): x, d = _bcast(x, d, *ctx.avals_in, *ctx.avals_out) return arith.shrui(x, d) -lowering_rules[lax.shift_right_logical_p] = _shift_right_logical_lowering_rules 
-skip_mlir_conversions.add(lax.shift_right_logical_p) - - +@register_lowering_rule(lax.erf_inv_p) def _erf_inv_lowering_rule(ctx: LoweringRuleContext, x): return lower_fun( pallas_utils.erf_inv_lowering_helper, multiple_results=False, )(ctx, x) -lowering_rules[lax.erf_inv_p] = _erf_inv_lowering_rule - - +@register_lowering_rule(primitives.reciprocal_p) def _reciprocal_lowering_rule(ctx: LoweringRuleContext, x, *, approx): if not isinstance(x.type.element_type, ir.F32Type): raise ValueError("Only float32 is supported.") return tpu.reciprocal(x, approx=approx) -lowering_rules[primitives.reciprocal_p] = _reciprocal_lowering_rule - +@register_lowering_rule(tpu_primitives.bitcast_p) def _bitcast_lowering_rule(ctx: LoweringRuleContext, x, *, ty): del ty (out_aval,) = ctx.avals_out @@ -3368,8 +3435,8 @@ def _bitcast_lowering_rule(ctx: LoweringRuleContext, x, *, ty): x, ) -lowering_rules[tpu_primitives.bitcast_p] = _bitcast_lowering_rule +@register_lowering_rule(lax.bitcast_convert_type_p) def _bitcast_convert_type_lowering_rule( ctx: LoweringRuleContext, x, *, new_dtype): (in_aval, ) = ctx.avals_in @@ -3384,7 +3451,6 @@ def _bitcast_convert_type_lowering_rule( ), x, ) -lowering_rules[lax.bitcast_convert_type_p] = _bitcast_convert_type_lowering_rule def _alloc_value( @@ -3392,7 +3458,7 @@ def _alloc_value( ) -> ir.Value: if isinstance(aval, pallas_core.AbstractMemoryRef): memspace = _memory_space_to_mosaic_attribute(aval.memory_space) - if jnp.issubdtype(aval.dtype, tpu_core.semaphore_dtype): + if jnp.issubdtype(aval.dtype, pallas_core.semaphore_dtype): assert aval.memory_space == TPUMemorySpace.SEMAPHORE memref_type = aval_to_ir_type( ctx.lowering_context.dynamic_shape_replacement_fn, @@ -3416,7 +3482,10 @@ def _alloc_value( raise NotImplementedError(f"Cannot allocate {type(aval)}.") -def _run_scoped_lowering_rule(ctx: LoweringRuleContext, *consts, jaxpr): +@register_lowering_rule(primitives.run_scoped_p) +def _run_scoped_lowering_rule(ctx: LoweringRuleContext, *consts, jaxpr, collective_axes): + if collective_axes: + raise NotImplementedError("run_scoped lowering does not support collective axes") out_type = [ aval_to_ir_type(ctx.lowering_context.dynamic_shape_replacement_fn, aval) for aval in ctx.avals_out @@ -3434,16 +3503,14 @@ def _run_scoped_lowering_rule(ctx: LoweringRuleContext, *consts, jaxpr): block_shapes=(*ctx.block_shapes, *block_shapes) ) out = jaxpr_subcomp(ctx, jaxpr, *consts, *args) - tpu.YieldOp(out) + tpu.yield_(out) return region.results -lowering_rules[primitives.run_scoped_p] = _run_scoped_lowering_rule - def _device_id_to_logical( ctx: LoweringRuleContext, device_id, - device_id_type: tpu_primitives.DeviceIdType): - if device_id_type is tpu_primitives.DeviceIdType.MESH: + device_id_type: primitives.DeviceIdType): + if device_id_type is primitives.DeviceIdType.MESH: # Mesh means we are passed the mesh coordinates for the device device_ids = tree_util.tree_leaves(device_id) mesh_strides = ctx.lowering_context.mesh_context.mesh_strides @@ -3458,29 +3525,40 @@ def _device_id_to_logical( for a, b in zip(device_ids, mesh_strides) ), ) - elif device_id_type is tpu_primitives.DeviceIdType.LOGICAL: + elif device_id_type is primitives.DeviceIdType.LOGICAL: return device_id raise NotImplementedError(f"Unsupported device id type: {device_id_type}") +@register_lowering_rule(primitives.semaphore_read_p) def _semaphore_read_lowering_rule( ctx: LoweringRuleContext, *args, args_tree, ): - sem_aval, _ = tree_util.tree_unflatten(args_tree, ctx.avals_in) + sem_aval, sem_transforms_avals = 
tree_util.tree_unflatten(args_tree, ctx.avals_in) + primitives.check_sem_avals( + sem_aval, + sem_transforms_avals, + "read", + allowed_semaphore_types={ + tpu_core.dma_semaphore, + pallas_core.semaphore, + pallas_core.barrier_semaphore, + pallas_core.SEMAPHORE_INTERPRET_DTYPE, + }, + ) sem, transforms = tree_util.tree_unflatten(args_tree, args) sem, _ = _transform_ref(sem, sem_aval.dtype, sem_aval.shape, transforms) return tpu.sem_read(sem) -lowering_rules[tpu_primitives.semaphore_read_p] = _semaphore_read_lowering_rule - +@register_lowering_rule(primitives.semaphore_signal_p) def _semaphore_signal_lowering_rule( ctx: LoweringRuleContext, *args, args_tree, - device_id_type: tpu_primitives.DeviceIdType, + device_id_type: primitives.DeviceIdType, ): sem_aval, _, _, _, _ = tree_util.tree_unflatten(args_tree, ctx.avals_in) sem, transforms, value, device_id, core_index = tree_util.tree_unflatten( @@ -3493,20 +3571,23 @@ def _semaphore_signal_lowering_rule( return [] -lowering_rules[tpu_primitives.semaphore_signal_p] = ( - _semaphore_signal_lowering_rule) - - +@register_lowering_rule(primitives.semaphore_wait_p) def _semaphore_wait_lowering_rule(ctx: LoweringRuleContext, *args, args_tree): sem_aval, _, _ = tree_util.tree_unflatten(args_tree, ctx.avals_in) sem, transforms, value = tree_util.tree_unflatten(args_tree, args) sem, _ = _transform_ref(sem, sem_aval.dtype, sem_aval.shape, transforms) tpu.sem_wait(sem, value) return [] -lowering_rules[tpu_primitives.semaphore_wait_p] = _semaphore_wait_lowering_rule -def _dma_start_lowering_rule(ctx: LoweringRuleContext, *args, tree, - device_id_type: tpu_primitives.DeviceIdType): + +@register_lowering_rule(tpu_primitives.dma_start_p) +def _dma_start_lowering_rule( + ctx: LoweringRuleContext, + *args, + tree, + device_id_type: primitives.DeviceIdType, + priority: int, +): ( src_ref, src_transforms, @@ -3538,15 +3619,23 @@ def _dma_start_lowering_rule(ctx: LoweringRuleContext, *args, tree, sem, _ = _transform_ref(sem, sem_aval.dtype, sem_aval.shape, sem_transforms) if device_id is not None: device_id = _device_id_to_logical(ctx, device_id, device_id_type) - tpu.enqueue_dma(src_ref, dst_ref, sem, source_semaphore=src_sem, - device_id=device_id) - + priority_kwarg = {"priority": priority} + if jaxlib_version < (0, 5, 4): + priority_kwarg = {} + tpu.enqueue_dma( + src_ref, + dst_ref, + sem, + source_semaphore=src_sem, + device_id=device_id, + **priority_kwarg, + ) return [] -lowering_rules[tpu_primitives.dma_start_p] = _dma_start_lowering_rule +@register_lowering_rule(tpu_primitives.dma_wait_p) def _dma_wait_lowering_rule(ctx: LoweringRuleContext, *args, tree, - device_id_type: tpu_primitives.DeviceIdType): + device_id_type: primitives.DeviceIdType): del device_id_type (src, src_transforms, dst, transforms, sem, sem_transforms, _, _, _) = ( tree_util.tree_unflatten(tree, args) @@ -3574,12 +3663,8 @@ def _dma_wait_lowering_rule(ctx: LoweringRuleContext, *args, tree, tpu.wait_dma2(sem, src, dst) return [] -lowering_rules[tpu_primitives.dma_wait_p] = _dma_wait_lowering_rule - -def _device_id_lowering_rule(ctx: LoweringRuleContext): - return tpu.device_id() -lowering_rules[tpu_primitives.device_id_p] = _device_id_lowering_rule +@register_lowering_rule(lax.axis_index_p, kernel_types=[*tpu_core.KernelType]) def _axis_index_rule(ctx: LoweringRuleContext, *, axis_name: Hashable): grid_names = ctx.lowering_context.grid_names if grid_names and axis_name in grid_names: @@ -3598,24 +3683,23 @@ def _axis_index_rule(ctx: LoweringRuleContext, *, axis_name: Hashable): 
np.prod(mesh_shape[axis_index + 1 :], dtype=np.int32) ) return arith.remsi(arith.divsi(device_id, minor_divisor), axis_size) -lowering_rules[lax.axis_index_p] = _axis_index_rule + +@register_lowering_rule(tpu_primitives.get_barrier_semaphore_p) def _get_barrier_semaphore_rule(ctx: LoweringRuleContext): memref_type = aval_to_ir_type( ctx.lowering_context.dynamic_shape_replacement_fn, ctx.avals_out[0] ) return tpu.sem_barrier(memref_type) -lowering_rules[tpu_primitives.get_barrier_semaphore_p] = _get_barrier_semaphore_rule +@register_lowering_rule(tpu_primitives.delay_p) def _delay_rule(ctx: LoweringRuleContext, nanos: int): tpu.delay(nanos) return [] -lowering_rules[tpu_primitives.delay_p] = _delay_rule - - +@register_lowering_rule(primitives.debug_print_p) def _debug_print_rule( ctx: LoweringRuleContext, *args, fmt: str, has_placeholders: bool ): @@ -3630,8 +3714,8 @@ def _debug_print_rule( # Scalar case. if is_all_scalars: - primitives.check_debug_print_format(fmt, *args) if has_placeholders: + primitives.check_debug_print_format(fmt, *args) if not all( isinstance(arg.type, ir.IntegerType) and arg.type.width == 32 for arg in args @@ -3642,7 +3726,7 @@ def _debug_print_rule( " remove placeholders from the format string." ) - # TPU expects $0, $1 etc as placeholders. + # TPU expects $0, $1 etc as placeholders. fmt = "".join( f"{text}${idx}" for idx, (text, _, _, _) in enumerate(string.Formatter().parse(fmt)) @@ -3687,9 +3771,7 @@ def _debug_print_rule( return () -lowering_rules[primitives.debug_print_p] = _debug_print_rule - - +@register_lowering_rule(tpu_primitives.prng_seed_p) def _prng_seed_lowering_rule(ctx: LoweringRuleContext, *seeds): del ctx # In the KeyScalarBundle case we unpack the bundle and set the seed with @@ -3705,9 +3787,9 @@ def _prng_seed_lowering_rule(ctx: LoweringRuleContext, *seeds): raise ValueError(f"All seed data must be scalar integers. Got {seed_types}") tpu.prng_set_seed_32(seeds) return [] -lowering_rules[tpu_primitives.prng_seed_p] = _prng_seed_lowering_rule +@register_lowering_rule(tpu_primitives.prng_random_bits_p) def _prng_random_bits_lowering_rule(ctx: LoweringRuleContext, *, shape): if len(shape) <= 1: # TODO(b/342054464): Support implicit dims for PRNGRandomBitsOp. @@ -3717,15 +3799,15 @@ def _prng_random_bits_lowering_rule(ctx: LoweringRuleContext, *, shape): ctx.lowering_context.dynamic_shape_replacement_fn, out_aval ) return tpu.prng_random_bits(out_type) -lowering_rules[tpu_primitives.prng_random_bits_p] = _prng_random_bits_lowering_rule +@register_lowering_rule(prng.random_seed_p) def random_seed_lowering(ctx, seeds, *, impl): seed_lowering = lower_fun(impl.seed, multiple_results=False) return seed_lowering(ctx, seeds) -lowering_rules[prng.random_seed_p] = random_seed_lowering +@register_lowering_rule(prng.random_bits_p) def random_bits_lowering(ctx, keys, *, bit_width, shape): assert bit_width == 32, "Only 32-bit PRNG supported." 
aval, = ctx.avals_in @@ -3738,99 +3820,90 @@ def new_lowering(key, bit_width, shape): _proxy_fn = new_lowering bits_lowering = lower_fun(_proxy_fn, multiple_results=False) return bits_lowering(ctx, keys, bit_width=bit_width, shape=shape) -lowering_rules[prng.random_bits_p] = random_bits_lowering +@register_lowering_rule(prng.random_fold_in_p) def random_fold_in_lowering(ctx, keys, msgs): - keys_aval, _ = ctx.avals_in + keys_aval, msgs_aval = ctx.avals_in impl = keys_aval.dtype._impl fold_in_lowering = lower_fun(impl.fold_in, multiple_results=False) - return fold_in_lowering(ctx, keys, msgs) -lowering_rules[prng.random_fold_in_p] = random_fold_in_lowering + if pl_random.is_pallas_impl(impl): + return fold_in_lowering(ctx, keys, msgs) + else: + ctx = dataclasses.replace(ctx, + avals_in=[jax_core.physical_aval(keys_aval), msgs_aval], + avals_out=map(jax_core.physical_aval, ctx.avals_out)) + return fold_in_lowering(ctx, keys, msgs) +@register_lowering_rule(prng.random_unwrap_p) def random_unwrap_lowering(ctx, key): keys_aval = ctx.avals_in[0] impl = keys_aval.dtype._impl if not pl_random.is_pallas_impl(impl): return key - assert isinstance(key, KeyScalarBundle) - # Convert to a vector. - if tuple(key.key_shape) != (1, 1): - raise NotImplementedError( - "Seed key_data of shape != (1, 1) not supported. " - f"Got: {key.key_shape}") - scalar = key.scalars[0] - out_type = ir.VectorType.get( - key.key_shape, _dtype_to_ir_type(jnp.dtype('int32')) + raise ValueError( + "key_data not supported for Pallas PRNG keys. Use" + " split_pallas_seed instead." ) - val = vector.broadcast(out_type, scalar) - return val -lowering_rules[prng.random_unwrap_p] = random_unwrap_lowering +@register_lowering_rule(prng.random_wrap_p) def random_wrap_lowering(ctx, key_data, *, impl): del ctx if not pl_random.is_pallas_impl(impl): return key_data - if isinstance(key_data.type, ir.VectorType): - # If the key data lives in vregs, need to unpack it to sregs. - key_data_list = [] - key_data_shape = key_data.type.shape - if len(key_data_shape) != 2: - raise NotImplementedError("Seed key_data must be 2D.") - if tuple(key_data_shape) != (1, 1): - raise NotImplementedError( - "Seed key_data of shape != (1, 1) not supported. " - f"Got: {key_data_shape}") - for i in range(key_data_shape[1]): - key_data_list.append(vector.ExtractOp(key_data, [], [0, i])) - return KeyScalarBundle( - scalars=key_data_list, key_shape=tuple(key_data_shape)) - if isinstance(key_data, KeyScalarBundle): - return key_data - else: - raise NotImplementedError(f"key_data wrap {type(key_data)}") + raise ValueError( + "wrap_key_data not supported for Pallas PRNG keys. Use" + " wrap_pallas_seed instead." + ) -lowering_rules[prng.random_wrap_p] = random_wrap_lowering -def _checkify_lowering_rule( - ctx: LoweringRuleContext, *err_args, err_tree, debug): - if not tpu_core.runtime_assert_enabled(): - if debug: - return [] - else: - raise LoweringException("Non-debug check must be functionalized. 
" - "Enable runtime asserts with " - "--jax_pallas_enable_runtime_assert " - "or functionalize with checkify.check.") - - assert ctx.lowering_context.ir_context.allow_unregistered_dialects, ( - "allow_unregistered_dialects must be set to True for " - "runtime assert check.") +@register_lowering_rule(tpu_primitives.split_key_p) +def _split_key_lowering_rule( + ctx: LoweringRuleContext, key_data: KeyScalarBundle +): + return key_data.scalars + + +@register_lowering_rule(tpu_primitives.join_key_p) +def _join_key_lowering_rule(ctx: LoweringRuleContext, *scalars, impl): + if not pl_random.is_pallas_impl(impl): + return ValueError(f"Can only join Pallas keys. Got impl={impl}") + return KeyScalarBundle(scalars=scalars, key_shape=impl.key_shape) + + +@register_lowering_rule(checkify.check_p) +def _check_lowering_rule( + ctx: LoweringRuleContext, *err_args, err_tree, debug +): + del ctx # Unused. + + if not debug: + raise NotImplementedError( + "Non-debug checks are not supported by the Mosaic backend." + " Functionalize them via `jax.experimental.checkify`." + ) + if not pallas_helpers.debug_checks_enabled(): + return [] + error = jax.tree.unflatten(err_tree, err_args) - assert len(error._pred) == 1 - assert len(error._metadata) == 1 - assert len(error._payload) == 1 - pred = list(error._pred.items())[0][1] - metadata = list(error._metadata.items())[0] - payload = list(error._payload.items())[0][1] - exception_tree = metadata[1] + [pred] = error._pred.values() + [exception_tree] = error._metadata.values() + [payload] = error._payload.values() exception = jax.tree.unflatten(exception_tree, payload) assert isinstance(exception, checkify.FailedCheckError) + assert isinstance(exception, checkify.FailedCheckError) - # check_p has an inverted predicate compared to assert, - # so we need to compute not(pred) here. - out_scalar_type = _dtype_to_ir_type(jnp.dtype('bool')) - minus_one = ir_constant(-1, out_scalar_type) + # check_p has an inverted predicate compared to assert, so we need to compute + # ``not pred`` here. 
+ minus_one = ir_constant(-1, _dtype_to_ir_type(jnp.bool)) not_pred = arith.xori(pred, minus_one) - attrs = {"msg": ir.StringAttr.get(exception.fmt_string)} - ir.Operation.create("cf.assert", - operands=(not_pred,), - attributes=attrs) + cf.assert_(not_pred, exception.fmt_string) return [] -lowering_rules[checkify.check_p] = _checkify_lowering_rule + +@register_lowering_rule(prng.threefry2x32_p) def _threefry2x32_lowering(ctx, k1, k2, m1, m2): def _lower_fun(k1, k2, m1, m2): with jax.named_scope("threefry2x32"): @@ -3841,9 +3914,7 @@ def _lower_fun(k1, k2, m1, m2): return threefry_lowering(ctx, k1, k2, m1, m2) -lowering_rules[prng.threefry2x32_p] = _threefry2x32_lowering - - +@register_lowering_rule(prng.iota_2x32_shape_p) def _iota_2x32_shape_lowering(ctx, *, shape): total_elements = np.prod(shape) if total_elements > np.iinfo(jnp.int32).max: @@ -3865,9 +3936,7 @@ def _lower_fun(shape): return iota_lowering(ctx, shape=shape) -lowering_rules[prng.iota_2x32_shape_p] = _iota_2x32_shape_lowering - - +@register_lowering_rule(lax.pad_p) def _pad_lowering_rule(ctx: LoweringRuleContext, *args, **kwargs): operand, padding_value = args padding_config = kwargs["padding_config"] @@ -3894,7 +3963,7 @@ def _pad(val): pad = vector.broadcast(pad_vec_type, padding_value) else: scalar_attr = ir.FloatAttr.get(operand.type.element_type, padding_value) - pad = arith.ConstantOp( + pad = arith.constant( pad_vec_type, ir.DenseElementsAttr.get_splat( pad_vec_type, @@ -3929,38 +3998,25 @@ def _pad(val): return operand -lowering_rules[lax.pad_p] = _pad_lowering_rule - - +@register_lowering_rule(control_flow.platform_index_p) def _platform_index_lowering( ctx: mlir.LoweringRuleContext, *, - platforms: Sequence[Sequence[str]], - has_default: bool, + platforms: BranchesPlatforms, ): for i, ps in enumerate(platforms): # note - slightly odd structure here, as platforms is a seq[seq[str]] - if "mosaic" in ps: + if "mosaic" in ps or ps is None: return ir_constant(i) - if has_default: - return ir_constant(len(platforms)) - raise NotImplementedError( "No mosaic or default platform indexing rule found." 
) -lowering_rules[jax._src.lax.control_flow.platform_index_p] = _platform_index_lowering - - -def _dim_as_value_lowering(ctx: mlir.LoweringRuleContext, *, dim): +@register_lowering_rule(shape_poly.dim_as_value_p) +def _dim_as_value_lowering(ctx: LoweringRuleContext, *, dim): placeholder = ctx.lowering_context.dynamic_shape_replacement_fn((dim,))[0] return ir_constant( placeholder, mlir_type=_dtype_to_ir_type(jnp.dtype("int32")) ) - - -import jax._src.export.shape_poly as shape_poly - -lowering_rules[shape_poly.dim_as_value_p] = _dim_as_value_lowering diff --git a/jax/_src/pallas/mosaic/pallas_call_registration.py b/jax/_src/pallas/mosaic/pallas_call_registration.py index 896af0c464c5..8944e06443e9 100644 --- a/jax/_src/pallas/mosaic/pallas_call_registration.py +++ b/jax/_src/pallas/mosaic/pallas_call_registration.py @@ -16,9 +16,10 @@ from __future__ import annotations +from collections.abc import Sequence import os import tempfile -from typing import Any +from typing import cast import jax from jax import dtypes @@ -28,7 +29,6 @@ from jax._src import tpu_custom_call from jax._src.interpreters import mlir from jax._src.lib.mlir import ir -from jax._src.pallas import core from jax._src.pallas import core as pallas_core from jax._src.pallas.mosaic import core as tpu_core from jax._src.pallas.mosaic import lowering @@ -72,7 +72,7 @@ def _get_memory_space_from_aval( ) -> tpu_custom_call.MemorySpace | None: if not isinstance(out_aval, jax_core.ShapedArray): raise ValueError('Memory spaces not defined for non-ShapedArrays') - if not isinstance(out_aval, core.ShapedArrayWithMemorySpace): + if not isinstance(out_aval, pallas_core.ShapedArrayWithMemorySpace): # If we are passed a regular old ShapedArray, we don't constrain the # memory space return None @@ -81,39 +81,41 @@ def _get_memory_space_from_aval( match out_aval.memory_space: case None: return None - case tpu_core.TPUMemorySpace.ANY: + case tpu_core.MemorySpace.ANY: return None - case tpu_core.TPUMemorySpace.VMEM: + case tpu_core.MemorySpace.HBM: + return tpu_custom_call.MemorySpace.HBM + case tpu_core.MemorySpace.VMEM: return tpu_custom_call.MemorySpace.VMEM - case tpu_core.TPUMemorySpace.SMEM: + case tpu_core.MemorySpace.SMEM: return tpu_custom_call.MemorySpace.SMEM - case tpu_core.TPUMemorySpace.SEMAPHORE: + case tpu_core.MemorySpace.SEMAPHORE: return tpu_custom_call.MemorySpace.SEMAPHORE_MEM return None def _get_memory_spaces_from_avals( - out_avals: tuple[jax_core.AbstractValue, ...], + avals: Sequence[jax_core.AbstractValue], ) -> tuple[tpu_custom_call.MemorySpace | None, ...] 
| None: - output_memory_spaces = None + memory_spaces = None if any( - isinstance(out_aval, core.ShapedArrayWithMemorySpace) - for out_aval in out_avals + isinstance(aval, pallas_core.ShapedArrayWithMemorySpace) for aval in avals ): - output_memory_spaces = tuple(map(_get_memory_space_from_aval, out_avals)) - return output_memory_spaces + memory_spaces = tuple(map(_get_memory_space_from_aval, avals)) + return memory_spaces + def pallas_call_tpu_lowering_rule( ctx: mlir.LoweringRuleContext, *in_nodes, jaxpr: jax_core.Jaxpr, - grid_mapping: core.GridMapping, + grid_mapping: pallas_core.GridMapping, mesh: pallas_core.Mesh | None, input_output_aliases: tuple[tuple[int, int], ...], debug: bool, interpret: bool, - compiler_params: dict[str, Any], - cost_estimate: core.CostEstimate | None, + compiler_params: dict[str, pallas_core.CompilerParams], + cost_estimate: pallas_core.CostEstimate | None, out_avals: tuple[jax_core.AbstractValue, ...], ): """Lowers a pallas_call to a Mosaic TPU custom call.""" @@ -123,10 +125,13 @@ def pallas_call_tpu_lowering_rule( if debug: print(f"\nThe kernel jaxpr for pallas_call {debug_info.func_src_info}:") print(jaxpr) - if "mosaic" in compiler_params: - mosaic_params = compiler_params["mosaic"] + + if "mosaic_tpu" in compiler_params: + mosaic_params = cast( + tpu_core.CompilerParams, compiler_params["mosaic_tpu"] + ) else: - mosaic_params = {} + mosaic_params = tpu_core.CompilerParams() jax_mesh = None axis_context = ctx.module_context.axis_context @@ -139,16 +144,15 @@ def pallas_call_tpu_lowering_rule( tpu.register_dialect(mlir_ctx) def lower_module(for_verification: bool): - if for_verification or tpu_core.runtime_assert_enabled(): + if for_verification: mlir_ctx.allow_unregistered_dialects = True with mlir_ctx, ir.Location.unknown(mlir_ctx): - dimension_semantics = mosaic_params.get("dimension_semantics", None) return lowering.lower_jaxpr_to_module( ctx, - mlir_ctx, grid_mapping, jaxpr, - dimension_semantics=dimension_semantics, + dimension_semantics=mosaic_params.dimension_semantics, + kernel_type=mosaic_params.kernel_type, mesh=jax_mesh, for_verification=for_verification, dynamic_shape_replacement_enabled=pallas_core.dynamic_shapes_export_enabled(), @@ -215,6 +219,18 @@ def _maybe_cast_inputs(*args): dynamic_grid_args, args = in_nodes[:num_dyn_bounds], in_nodes[num_dyn_bounds:] kernel_ctx = ctx.replace(avals_in=kernel_in_avals, avals_out=kernel_out_avals) output_memory_spaces = _get_memory_spaces_from_avals(out_avals) + input_memory_spaces = None + if any( + isinstance(aval, pallas_core.ShapedArrayWithMemorySpace) + for aval in ctx.avals_in + ): + # TODO(sharadmv): Support dynamic grid bounds and extra args. + if num_dyn_bounds != 0 or len(extra_args) > 0: + raise NotImplementedError( + "Dynamic grid bounds and extra args are not supported when" + " specifying memory spaces for inputs." 
+ ) + input_memory_spaces = _get_memory_spaces_from_avals(ctx.avals_in) if cost_estimate is not None: mosaic_cost_estimate = tpu_custom_call.CostEstimate( flops=cost_estimate.flops, @@ -223,6 +239,15 @@ def _maybe_cast_inputs(*args): ) else: mosaic_cost_estimate = None + if input_memory_spaces is None and output_memory_spaces is not None: + input_memory_spaces_list: list[tpu_custom_call.MemorySpace | None] = [ + None, + ] * len(ctx.avals_in) + for input_output_alias in input_output_aliases: + input_memory_spaces_list[input_output_alias[0]] = output_memory_spaces[ + input_output_alias[1] + ] + input_memory_spaces = tuple(input_memory_spaces_list) out_nodes = mosaic.lower_module_to_custom_call( kernel_ctx, *dynamic_grid_args, @@ -233,16 +258,17 @@ def _maybe_cast_inputs(*args): backend="tpu", kernel_name=mlir.sanitize_name(debug_info.func_name), cost_estimate=mosaic_cost_estimate, - vmem_limit_bytes=mosaic_params.get("vmem_limit_bytes"), - flags=mosaic_params.get("flags"), - allow_input_fusion=mosaic_params.get("allow_input_fusion"), + vmem_limit_bytes=mosaic_params.vmem_limit_bytes, + flags=mosaic_params.flags, + allow_input_fusion=mosaic_params.allow_input_fusion, input_output_aliases=input_output_aliases, - serialization_format=mosaic_params.get("serialization_format", 1), - device_type=mosaic_params.get("device_type"), - internal_scratch_in_bytes=mosaic_params.get("internal_scratch_in_bytes"), - collective_id=mosaic_params.get("collective_id", None), - has_side_effects=mosaic_params.get("has_side_effects", False), + serialization_format=mosaic_params.serialization_format, + internal_scratch_in_bytes=mosaic_params.internal_scratch_in_bytes, + collective_id=mosaic_params.collective_id, + has_side_effects=mosaic_params.has_side_effects, output_memory_spaces=output_memory_spaces, + disable_bounds_checks=mosaic_params.disable_bounds_checks, + input_memory_spaces=input_memory_spaces, ) _maybe_cast_to_bool = lambda x, aval: x.astype( jax.numpy.bool_) if aval.dtype == jax.numpy.bool_ else x diff --git a/jax/_src/pallas/mosaic/pipeline.py b/jax/_src/pallas/mosaic/pipeline.py index 184b1497adf9..c766d6ec16b5 100644 --- a/jax/_src/pallas/mosaic/pipeline.py +++ b/jax/_src/pallas/mosaic/pipeline.py @@ -20,8 +20,6 @@ import dataclasses import enum import functools -import itertools -import operator from typing import Any, Union import jax @@ -31,6 +29,7 @@ from jax._src.pallas import core as pallas_core from jax._src.pallas import primitives as primitives from jax._src.pallas.mosaic import core as tpu_core +from jax._src.pallas.mosaic import helpers as tpu_helpers from jax._src.pallas.mosaic import primitives as tpu_primitives from jax.experimental import pallas as pl from jax.extend.backend import get_default_device @@ -38,9 +37,8 @@ import numpy as np -SMEM = tpu_core.TPUMemorySpace.SMEM -VMEM = tpu_core.TPUMemorySpace.VMEM -DMA = tpu_core.SemaphoreType.DMA +SMEM = tpu_core.MemorySpace.SMEM +VMEM = tpu_core.MemorySpace.VMEM REF = pallas_core.MemoryRef GridDimensionSemantics = tpu_core.GridDimensionSemantics PARALLEL = tpu_core.PARALLEL @@ -82,8 +80,11 @@ def _get_tpu_generation() -> int: kind = get_default_device().device_kind if kind.endswith(' lite'): kind = kind[:-len(' lite')] - assert kind[:5] == "TPU v", kind - return int(kind[5]) + if kind.startswith("TPU v"): + return int(kind[5]) + else: + assert "TPU7x" in kind + return 7 def _make_tiling(shape: tuple[int, ...], dtype: np.dtype) -> tuple[int, ...]: # For a n-dimensional shape, returns (8, 128) for the last 2 dimensions @@ -103,14 +104,16 
@@ def _make_tiling(shape: tuple[int, ...], dtype: np.dtype) -> tuple[int, ...]: return (*(1,) * len(leading_dims), second_minor_tiling, _TILING[1]) -def _round_up_to_nearest_multiple(s: int, multiple: int) -> int: - if s % multiple == 0: +def _round_up_to_nearest_multiple( + s: int | jax.Array, multiple: int +) -> int | jax.Array: + if isinstance(s, int) and s % multiple == 0: return s # Subtract off the remainder, then add multiple return s - s % multiple + multiple -def _make_ds( +def _make_block_ds( idx: jax.Array | int, size: jax.Array | int ) -> pl.Slice: """Make a DMA slice with mosaic size hints.""" @@ -118,33 +121,87 @@ def _make_ds( assert isinstance(out, pl.Slice) return out - -def _make_block_slice( - block_index: jax.Array, block_size: int, size: int, tiling: int -) -> pl.Slice | slice: - # Computes a slice given a block index and block size. In the default case, - # we return slice(block_index * block_size, (block_index + 1) * block_size). - # However, if the total size of the ref does not divide block size and we are - # selecting the last block, we need to pick the lowest tiling size multiple - # that contains the block. - if size % block_size == 0: - return _make_ds(block_index, block_size) +def _create_blocked_slice(block_index: jax.Array | int, + block_size: int, + dim_size: int, + tiling: int): + block_start = block_size * block_index + if (dim_rem := dim_size % block_size) == 0: + return pl.ds(block_start, block_size) if block_size % tiling != 0: raise ValueError(f"Block size must divide tiling: {block_size=}, {tiling=}") - num_blocks = pl.cdiv(size, block_size) + num_blocks = pl.cdiv(dim_size, block_size) is_last = block_index == num_blocks - 1 rounded_size = jnp.where( is_last, - _round_up_to_nearest_multiple(size % block_size, tiling), + _round_up_to_nearest_multiple(dim_rem % block_size, tiling), block_size, ) rounded_size = pl.multiple_of(rounded_size, tiling) return pl.ds(block_index * block_size, rounded_size) +def _create_bounded_slice(slice_start: jax.Array | int, + slice_size: jax.Array | int, + block_size: int, + dim_size: int, + tiling: int): + if block_size % tiling != 0: + raise ValueError(f"Block size must divide tiling: {block_size=}, {tiling=}") + # We assume by construction that slice_size <= block_size. We also assume + # that the slice_start is already aligned to the tiling. + + # If we are out of bound, we need to round the slice size down to the nearest + # multiple of the tiling. + is_oob = slice_start + slice_size > dim_size + remaining = dim_size - slice_start + rounded_size = jnp.where( + is_oob, + _round_up_to_nearest_multiple(remaining, tiling), + slice_size, + ) + rounded_size = pl.multiple_of(rounded_size, tiling) + return pl.ds(slice_start, rounded_size) + +def _make_block_slice( + block_index: jax.Array, block_size: pl.BlockDim | int | None, size: int, + tiling: int +) -> pl.Slice | slice | int | jax.Array: + # Computes a slice given a block index and block size. In the default case, + # we return slice(block_index * block_size, (block_index + 1) * block_size). + # However, if the total size of the ref does not divide block size and we are + # selecting the last block, we need to pick the lowest tiling size multiple + # that contains the block. 
+ match block_size: + case pl.Blocked(): + return _create_blocked_slice(block_index, block_size.block_size, size, tiling) + case int(): + return _create_blocked_slice(block_index, block_size, size, tiling) + case pl.Element(): + block_start = block_index + block_size = block_size.block_size + return _create_bounded_slice( + block_start, block_size, block_size, size, tiling + ) + case pl.BoundedSlice(block_size): + if not isinstance(block_index, pl.Slice): + raise ValueError( + "Must return a pl.ds from the index_map for a BoundedSlice" + " dimension." + ) + slice_start = block_index.start + slice_size = block_index.size + return _create_bounded_slice( + slice_start, slice_size, block_size, size, tiling + ) + case None | pl.Squeezed(): + return block_index + case _: + raise ValueError(f"Unsupported block dimension type: {block_size}") + def _tuples_differ(xs, ys): """Dynamic index-tuple comparison calculation.""" - differences = jax.tree.map(lambda x, y: x != y, xs, ys) + differences = jax.tree.leaves(jax.tree.map(lambda x, y: x != y, xs, ys)) return functools.reduce(lambda x, y: x | y, differences, False) @@ -156,20 +213,6 @@ def _grid_size(grid): return size -def _get_indices(step, grid, offsets): - """Get indices for a given step and grid.""" - # TODO(enriqueps): Implement using bitwise ops, avoid div/rem since they are - # expensive. - extended_grid = grid + (1,) - strides = tuple( - itertools.accumulate(extended_grid[::-1], func=operator.mul))[::-1] - indices = tuple( - lax.div(lax.rem(step, a), b) - for a, b in zip(strides[:-1], strides[1:]) - ) - return tuple(a + b for a, b in zip(indices, offsets, strict=True)) - - class BufferType(enum.Enum): """Buffer type for the arguments to an emitted pipeline.""" INPUT = 1 @@ -179,10 +222,163 @@ class BufferType(enum.Enum): MANUAL = 5 +def _get_block_shape(spec: pl.BlockSpec) -> tuple[int, ...]: + """Get the block shape for a given block spec.""" + def _get_dim_size(bd): + match bd: + case pl.Blocked(block_size): + return block_size + case pl.Element(block_size): + return block_size + case pl.BoundedSlice(block_size): + return block_size + case int(): + return bd + case None | pl.Squeezed(): + return None + case _: + raise ValueError(f"Unsupported block dimension type: {bd}") + if spec.block_shape is None: + raise ValueError("Block shape must be specified.") + block_shape_nones = tuple(_get_dim_size(x) for x in spec.block_shape) + return tuple(x for x in block_shape_nones if x is not None) + + +class BufferedRefBase: + """Abstract interface for BufferedRefs.""" + + @property + def spec(self) -> pl.BlockSpec: + raise NotImplementedError() + + @property + def buffer_type(self) -> BufferType: + raise NotImplementedError() + + @property + def is_input(self): + return self.buffer_type in [ + BufferType.INPUT, + BufferType.ACCUMULATOR, + BufferType.INPUT_OUTPUT, + ] + + @property + def is_output(self): + return self.buffer_type in [ + BufferType.OUTPUT, + BufferType.ACCUMULATOR, + BufferType.INPUT_OUTPUT, + ] + + @property + def is_accumulator(self): + return self.buffer_type == BufferType.ACCUMULATOR + + @property + def is_input_output(self): + return self.buffer_type == BufferType.INPUT_OUTPUT + + @property + def is_manual(self): + return self.buffer_type == BufferType.MANUAL + + def init_slots(self): + """Initialize slot indices.""" + raise NotImplementedError() + + def swap_slots(self, predicate: bool = True) -> BufferedRefBase: + """Switch to the next slot.""" + raise NotImplementedError() + + def load_slots(self) -> BufferedRefBase: + """Load 
slot information into registers.""" + raise NotImplementedError() + + def save_slots(self): + """Save slot information from registers.""" + raise NotImplementedError() + + @property + def block_shape(self) -> Sequence[pl.BlockDim | int | None] | None: + return self.spec.block_shape + + @property + def compute_index(self): + return self.spec.index_map + + def get_dma_slice(self, src_shape, src_dtype, grid_indices): + # We need to handle blocks that might go OOB in the src array. An in bounds + # block looks like this (for array shape (600, 600) and block shape + # (256, 256)): + # + # +--------------+------------------| + # | Block (0,0) | | + # | (256, 256) | | + # +--------------+ | + # | A (600, 600) | + # | | + # +---------------------------------+ + # + # For in-bounds blocks, we don't need to do anything special. + # An out-of-bounds block looks like this: + # + # +--------------+------------------| + # | | + # | | + # + | + # | A (600, 600) | + # +--------------+ | + # | Block (2,0) | | + # + --------------------------------| + # | XXXXXXXXXX | + # +--------------+ + # where the X's indicate where the block is out of bounds. + # + # When we have an out of bounds block like this, we need to truncate it to + # a tile boundary (tiles are (8, 128) along the two minormost dimensions). + # In this case, we'll have a block that is indexing the + # 512:768 elements of A along the first dimension. We need to convert 768 + # into 600 (600 % 8 == 0), so our indexing will look like this: + + # +--------------+------------------| + # | | + # | | + # + | + # | A (600, 600) | + # +--------------+ | + # | Block (2,0) | | + # + --------------------------------| + # where it is now a (88, 256) sized block. + # + # Suppose A is now (601, 600), instead of picking a (88, 256)-sized block + # for the last iteration on that dimension, we will pick the next highest + # tile multiple, i.e. (96, 256). + if len(src_shape) < 2: + raise NotImplementedError("Must use >1D values.") + + tiling = _make_tiling(src_shape, src_dtype) + block_indices = self.compute_index(*grid_indices) + return tuple( + _make_block_slice(bi, bs, ss, t) + for bi, bs, ss, t in zip( + block_indices, self.block_shape, src_shape, tiling, strict=True + ) + ) + + def bind_existing_ref(self, window_ref, indices): + """For handling VMEM references, the pipeline aliases the existing ref.""" + del window_ref, indices + return self + + def with_spec(self, spec: pl.BlockSpec) -> BufferedRefBase: + """Returns a new BufferedRefBase with the given block spec.""" + raise NotImplementedError() + @tree_util.register_pytree_node_class @dataclasses.dataclass(frozen=True) -class BufferedRef: +class BufferedRef(BufferedRefBase): """A helper class to automate VMEM double buffering in pallas pipelines. Attributes: @@ -195,7 +391,6 @@ class BufferedRef: reference, this simply points to the existing ref. accum_ref: accumulating buffer used by accumulator BufferedRefs. current_slot: current slot index to the working buffer. - next_slot: slot that will point to the working buffer in the next iteration. sem_recvs: Double buffered semaphores for input DMAs. sem_sends: Double buffered semaphores for output DMAs. block_shape: passthrough property for the BlockSpec's block_shape. @@ -210,33 +405,39 @@ class BufferedRef: swap: Tracks whether the BufferedRef slots need to be swapped before next copy. 
""" - spec: pl.BlockSpec # static metadata + _spec: pl.BlockSpec # static metadata dtype: Any # static metadata - buffer_type: BufferType # static metadata - window_ref: REF | None - accum_ref: REF | None + _buffer_type: BufferType # static metadata + _current_slot_reg: int | jax.Array | None + window_ref: ArrayRef | None + accum_ref: ArrayRef | None current_slot: ArrayRef | None - # TODO(ramiroleal): Unused by class. Remove argument from - # BufferedRef instantiations. - next_slot: ArrayRef | None sem_recvs: SemaphoreTuple | None sem_sends: SemaphoreTuple | None # TODO(ramiroleal): Improve prefetch/postyeet interface to avoid # using this ref. swap: ArrayRef | None + @property + def spec(self): + return self._spec + + @property + def buffer_type(self): + return self._buffer_type + def tree_flatten(self): return ( ( + self._current_slot_reg, self.window_ref, self.accum_ref, self.current_slot, - self.next_slot, self.sem_recvs, self.sem_sends, self.swap, ), - (self.spec, self.dtype, self.buffer_type), + (self._spec, self.dtype, self._buffer_type), ) @classmethod @@ -248,7 +449,8 @@ def buffer_types() -> type[BufferType]: return BufferType @classmethod - def create(cls, spec, dtype, buffer_type, needs_swap_ref=True) -> BufferedRef: + def create(cls, spec: pl.BlockSpec, dtype, buffer_type, needs_swap_ref=True + ) -> BufferedRef: """Create a BufferedRef. Args: @@ -261,7 +463,7 @@ def create(cls, spec, dtype, buffer_type, needs_swap_ref=True) -> BufferedRef: Returns: Initialized BufferedRef """ - block_shape = tuple(1 if x is None else x for x in spec.block_shape) + block_shape = _get_block_shape(spec) if buffer_type is BufferType.ACCUMULATOR: accum_ref = VMEM(block_shape, dtype) else: @@ -271,13 +473,13 @@ def create(cls, spec, dtype, buffer_type, needs_swap_ref=True) -> BufferedRef: # reference is already in VMEM, we just need allocate the accumulation # buffer and we will refer to the original reference slices directly. 
return cls( - spec=spec, + _spec=spec, dtype=dtype, - buffer_type=buffer_type, + _buffer_type=buffer_type, + _current_slot_reg=None, window_ref=None, # to be bound to existing ref by the pipeline routine accum_ref=accum_ref, current_slot=None, - next_slot=None, sem_recvs=None, sem_sends=None, swap=None, @@ -285,13 +487,13 @@ def create(cls, spec, dtype, buffer_type, needs_swap_ref=True) -> BufferedRef: else: memory_space = SMEM if spec.memory_space == SMEM else VMEM return cls( - spec=spec, + _spec=spec, dtype=dtype, - buffer_type=buffer_type, + _buffer_type=buffer_type, + _current_slot_reg=None, window_ref=memory_space((2,) + block_shape, dtype), accum_ref=accum_ref, current_slot=SMEM((1,), jnp.int32), - next_slot=None, sem_recvs=( None if buffer_type is BufferType.OUTPUT @@ -333,45 +535,39 @@ def compute_index(self): def memory_space(self): return self.spec.memory_space + def with_spec(self, spec: pl.BlockSpec) -> BufferedRef: + """Returns a new BufferedRef with the given block spec.""" + return dataclasses.replace(self, _spec=spec) + + def with_slot_index( + self, slot_index: int | jax.Array | None + ) -> BufferedRef: + """Returns a new BufferedRef with the given slot index.""" + return dataclasses.replace(self, _current_slot_reg=slot_index) + @property def current_ref(self): buffer_slice = tuple( - 0 if x is None else slice(None) for x in self.block_shape) + slice(None) + for x in self.block_shape + if not (x is None or isinstance(x, pl.Squeezed)) + ) + assert not (self.window_ref is None or isinstance(self.window_ref, REF)) if self.memory_space == VMEM: return self.window_ref.at[buffer_slice] else: return self.window_ref.at[(self.current_slot_index, *buffer_slice)] - @property - def is_input(self): - return self.buffer_type in [ - BufferType.INPUT, - BufferType.ACCUMULATOR, - BufferType.INPUT_OUTPUT, - ] - - @property - def is_output(self): - return self.buffer_type in [ - BufferType.OUTPUT, - BufferType.ACCUMULATOR, - BufferType.INPUT_OUTPUT, - ] - - @property - def is_accumulator(self): - return self.buffer_type == BufferType.ACCUMULATOR - - @property - def is_input_output(self): - return self.buffer_type == BufferType.INPUT_OUTPUT - @property def current_slot_index(self): + """Index in double buffer corresponding to the current slot.""" + if self._current_slot_reg is not None: + return self._current_slot_reg return self.current_slot[0] @property def next_slot_index(self): + """Index in double buffer corresponding to the next slot.""" return lax.rem(self.current_slot_index + 1, 2) def bind_existing_ref(self, window_ref, indices): @@ -384,9 +580,29 @@ def bind_existing_ref(self, window_ref, indices): def compute_slice(self, grid_indices): """Compute DMA slice from grid indices.""" - block_shape = tuple(1 if x is None else x for x in self.block_shape) indices = self.compute_index(*grid_indices) - return jax.tree.map(_make_ds, indices, block_shape) + assert len(self.block_shape) == len(indices) + indexer = [] + for bd, idx in zip(self.block_shape, indices, strict=True): + match bd: + case None | pl.Squeezed(): + # Dimension is squeezed out so we don't do anything. + indexer.append(idx) + case pl.Element(): + raise ValueError( + "Element block dimensions are not supported." + ) + case pl.BoundedSlice(): + raise ValueError( + "BoundedSlice block dimensions are not supported." 
+ ) + case pl.Blocked(block_size): + indexer.append(_make_block_ds(idx, block_size)) + case int(): + indexer.append(_make_block_ds(idx, bd)) + case _: + raise ValueError(f"Unsupported block dimension type: {type(bd)}") + return tuple(indexer) def init_slots(self): """Initialize slot indices.""" @@ -395,82 +611,55 @@ def init_slots(self): if self.swap is not None: self.swap[0] = False - def swap_slots(self): - """Switch to the next slot.""" - if self.memory_space == VMEM: return - self.current_slot[0] = self.next_slot_index + def swap_slots(self, predicate: bool | jax.Array = True) -> BufferedRef: + if self.memory_space == VMEM: + return self if self.swap is not None: + assert isinstance(self.swap, jax.Array) + predicate = self.swap[0] self.swap[0] = False + new_current_slot = lax.select( + predicate, self.next_slot_index, self.current_slot_index + ) + if self._current_slot_reg is not None: + return self.with_slot_index(new_current_slot) + assert isinstance(self.current_slot, jax.Array) + self.current_slot[0] = new_current_slot + return self - def get_dma_slice(self, src_shape, src_dtype, grid_indices): - # We need to handle blocks that might go OOB in the src array. An in bounds - # block looks like this (for array shape (600, 600) and block shape - # (256, 256)): - # - # +--------------+------------------| - # | Block (0,0) | | - # | (256, 256) | | - # +--------------+ | - # | A (600, 600) | - # | | - # +---------------------------------+ - # - # For in-bounds blocks, we don't need to do anything special. - # An out-of-bounds block looks like this: - # - # +--------------+------------------| - # | | - # | | - # + | - # | A (600, 600) | - # +--------------+ | - # | Block (2,0) | | - # + --------------------------------| - # | XXXXXXXXXX | - # +--------------+ - # where the X's indicate where the block is out of bounds. - # - # When we have an out of bounds block like this, we need to truncate it to - # a tile boundary (tiles are (8, 128) along the two minormost dimensions). - # In this case, we'll have a block that is indexing the - # 512:768 elements of A along the first dimension. We need to convert 768 - # into 600 (600 % 8 == 0), so our indexing will look like this: - - # +--------------+------------------| - # | | - # | | - # + | - # | A (600, 600) | - # +--------------+ | - # | Block (2,0) | | - # + --------------------------------| - # where it is now a (88, 256) sized block. - # - # Suppose A is now (601, 600), instead of picking a (88, 256)-sized block - # for the last iteration on that dimension, we will pick the next highest - # tile multiple, i.e. (96, 256). 
- if len(src_shape) < 2: - raise NotImplementedError("Must use >1D values.") + def load_slots(self) -> BufferedRef: + """Load slot information into registers.""" + if self.memory_space == VMEM: + return self + assert isinstance(self.current_slot, jax.Array) + return self.with_slot_index(self.current_slot[0]) - tiling = _make_tiling(src_shape, src_dtype) - block_shape = tuple(1 if b is None else b for b in self.block_shape) - block_indices = self.compute_index(*grid_indices) - return jax.tree.map( - _make_block_slice, block_indices, block_shape, src_shape, tiling - ) + def save_slots(self): + """Save slot information from registers.""" + if self.memory_space == VMEM: + return + assert isinstance(self.current_slot, jax.Array) + assert self._current_slot_reg is not None + self.current_slot[0] = self._current_slot_reg def copy_in(self, src_ref, grid_indices): """Starts copy of HBM dma slice into the current slot.""" assert self.is_input if self.memory_space == VMEM: return + assert not (self.window_ref is None or isinstance(self.window_ref, REF)) + assert self.sem_recvs is not None if self.swap is not None: self.swap[0] = True next_slot = self.next_slot_index src_slice = self.get_dma_slice(src_ref.shape, src_ref.dtype, grid_indices) - dst_slice = tuple(pl.ds(0, s.size) for s in src_slice) + dst_slice = tuple( + pl.ds(0, s.size) + for s, bd in zip(src_slice, self.block_shape) + if not (bd is None or isinstance(bd, pl.Squeezed)) + ) tpu_primitives.make_async_copy( src_ref.at[src_slice], - self.window_ref.at[next_slot].at[dst_slice], + self.window_ref.at[(next_slot, *dst_slice)], self.sem_recvs.at[next_slot], ).start() @@ -478,13 +667,19 @@ def copy_out(self, dst_ref, grid_indices): """Starts copy of HBM dma slice from the current slot.""" assert self.is_output if self.memory_space == VMEM: return + assert not (self.window_ref is None or isinstance(self.window_ref, REF)) + assert self.sem_sends is not None if self.swap is not None: self.swap[0] = True slot = self.current_slot_index dst_slice = self.get_dma_slice(dst_ref.shape, dst_ref.dtype, grid_indices) - src_slice = tuple(pl.ds(0, s.size) for s in dst_slice) + src_slice = tuple( + pl.ds(0, s.size) + for s, bd in zip(dst_slice, self.block_shape) + if not (bd is None or isinstance(bd, pl.Squeezed)) + ) tpu_primitives.make_async_copy( - self.window_ref.at[slot].at[src_slice], + self.window_ref.at[(slot, *src_slice)], dst_ref.at[dst_slice], self.sem_sends.at[slot], ).start() @@ -493,13 +688,19 @@ def wait_in(self, src_ref, grid_indices): """Waits for input copy to finish.""" assert self.is_input if self.memory_space == VMEM: return + assert not (self.window_ref is None or isinstance(self.window_ref, REF)) + assert self.sem_recvs is not None src_slice = self.get_dma_slice(src_ref.shape, src_ref.dtype, grid_indices) - dst_slice = tuple(pl.ds(0, s.size) for s in src_slice) + dst_slice = tuple( + pl.ds(0, s.size) + for s, bd in zip(src_slice, self.block_shape) + if not (bd is None or isinstance(bd, pl.Squeezed)) + ) current_slot = self.current_slot_index tpu_primitives.make_async_copy( src_ref.at[src_slice], # nb: doesn't matter - self.window_ref.at[current_slot].at[ - dst_slice + self.window_ref.at[ + (current_slot, *dst_slice) ], # only dst shape is important self.sem_recvs.at[current_slot], ).wait() @@ -508,12 +709,18 @@ def wait_out(self, dst_ref, grid_indices): """Waits for output copy to finish.""" assert self.is_output if self.memory_space == VMEM: return + assert not (self.window_ref is None or isinstance(self.window_ref, REF)) + assert 
self.sem_sends is not None # In a double buffer, previous slot is the same as next slot. prev_slot = self.next_slot_index dst_slice = self.get_dma_slice(dst_ref.shape, dst_ref.dtype, grid_indices) - src_slice = tuple(pl.ds(0, s.size) for s in dst_slice) + src_slice = tuple( + pl.ds(0, s.size) + for s, bd in zip(dst_slice, self.block_shape) + if not (bd is None or isinstance(bd, pl.Squeezed)) + ) tpu_primitives.make_async_copy( - self.window_ref.at[prev_slot].at[src_slice], # nb: doesn't matter + self.window_ref.at[(prev_slot, *src_slice)], # nb: doesn't matter dst_ref.at[dst_slice], # only dst shape is important self.sem_sends.at[prev_slot], ).wait() @@ -533,16 +740,18 @@ def set_accumulator(self, init=False): """Set accumulator or zero it out to initialize.""" assert self.is_accumulator if self.accum_ref is not None: + accum_dtype = self.accum_ref.dtype def _init(): self.accum_ref[...] = jnp.zeros_like(self.accum_ref[...]) def _set(): - self.accum_ref[...] = self.current_ref[...].astype(self.accum_ref.dtype) + self.accum_ref[...] = self.current_ref[...].astype(accum_dtype) lax.cond(init, _init, _set) def accumulate(self): """Add into the current slot.""" assert self.is_accumulator if self.accum_ref is not None: + assert self.window_ref is not None accum_dtype = jnp.float32 if self.window_ref.dtype == jnp.int32: accum_dtype = jnp.int32 @@ -557,7 +766,24 @@ def accumulate(self): # Helper to tree map over BufferedRefs as leaves. map_brefs = functools.partial( jax.tree.map, - is_leaf=lambda x: isinstance(x, BufferedRef)) + is_leaf=lambda x: isinstance(x, BufferedRefBase) +) + +def map_inputs(f, *args): + """Maps over all input BufferedRefs.""" + def fmap(bref, *f_args): + if bref.is_input: + return f(bref, *f_args) + return bref + return map_brefs(fmap, *args) + +def map_outputs(f, *args): + """Maps over all output BufferedRefs.""" + def fmap(bref, *f_args): + if bref.is_output: + return f(bref, *f_args) + return bref + return map_brefs(fmap, *args) def _filter_indices( @@ -606,6 +832,7 @@ def __init__( last_cycle=None, init_accumulators=None, trace_scopes=True, + use_sreg_for_state: bool = False, ): """Initializes scheduler. @@ -619,6 +846,8 @@ def __init__( init_accumulators: do we zero-initialize accumulator state for this invocation of the pipeline. trace_scopes: whether to use named_scope to trace blocks in the pipeline. + use_sreg_for_state: optional bool, indicates whether to use sregs for + current_slot state. """ self.step = step self.grid = grid @@ -626,6 +855,7 @@ def __init__( self.last_cycle = last_cycle self.init_accumulators = init_accumulators self.trace_scopes = trace_scopes + self.use_sreg_for_state = use_sreg_for_state # Total number of linear steps. 
self.num_steps = _grid_size(grid) @@ -685,18 +915,21 @@ def alias_local_refs(self, buffered_ref, ref): def initialize(self, buffered_ref, src_ref, schedule=None): if schedule is None: schedule = _default_schedule - pred = schedule["prologue_copy_in"](self, buffered_ref, src_ref) + do_copy = schedule["prologue_copy_in"](self, buffered_ref, src_ref) with self._named_scope("ep_initialize"): @pl.when(self.first_step_ever) def _init_slots(): buffered_ref.init_slots() - @pl.when(pred) - def _start(): - if buffered_ref.is_input: - buffered_ref.copy_in(src_ref, self.indices) - buffered_ref.swap_slots() + if self.use_sreg_for_state: + buffered_ref = buffered_ref.load_slots() + + @pl.when(do_copy & buffered_ref.is_input) + def _copy_in(): + buffered_ref.copy_in(src_ref, self.indices) + + return buffered_ref.swap_slots(do_copy & buffered_ref.is_input) def wait_in(self, buffered_ref, src_ref, schedule=None): if schedule is None: @@ -804,29 +1037,24 @@ def _end(): if buffered_ref.is_output: buffered_ref.wait_out(dst_ref, self.indices) - def swap_slots(self, buffered_ref, hbm_ref, schedule=None): - if buffered_ref.swap is not None: - swap = buffered_ref.swap[0] - else: - # If we are not using an SMEM `swap` tensor to keep track of - # swaps needed, then all the copies into and out of BufferedRefs - # are done by direct calls to the `copy_in` and `copy_out` - # methods in the pipeline loop. To determine if the BufferedRef - # needs a swap of slots, we recalculate the copy-in/copy-out - # conditions. - if schedule is None: - schedule = _default_schedule - pred_in = schedule["copy_in"](self, buffered_ref, hbm_ref) - pred_out = schedule["copy_out"](self, buffered_ref, hbm_ref) - - copied_in = pred_in & buffered_ref.is_input & ~self.last_step - copied_out = pred_out & buffered_ref.is_output - swap = copied_in | copied_out - - @pl.when(swap) - @self._named_scope("ep_swap") - def _swap(): - buffered_ref.swap_slots() + if self.use_sreg_for_state: + buffered_ref.save_slots() + + def swap_slots( + self, buffered_ref, hbm_ref, schedule=None + ) -> BufferedRefBase: + # All the copies into and out of BufferedRefs are done by direct + # calls to the `copy_in` and `copy_out` methods in the pipeline + # loop. To determine if the BufferedRef needs a swap of slots, we + # recalculate the copy-in/copy-out conditions. 
+ if schedule is None: + schedule = _default_schedule + pred_in = schedule["copy_in"](self, buffered_ref, hbm_ref) + pred_out = schedule["copy_out"](self, buffered_ref, hbm_ref) + + copied_in = pred_in & buffered_ref.is_input & ~self.last_step + copied_out = pred_out & buffered_ref.is_output + return buffered_ref.swap_slots(copied_in | copied_out) # END SCHEDULE -------------------------------------------------------------- @@ -888,7 +1116,7 @@ def skip_input_copies_when_init_accumulators(schedule) -> Any: def new_pred(original_pred_fn, *a): pred = original_pred_fn(*a) if a[1].is_accumulator or a[1].is_input_output: - pred &= ~a[0].init_accumulators + pred &= jnp.logical_not(a[0].init_accumulators) return pred new_schedule[k] = functools.partial( @@ -975,7 +1203,7 @@ def _partition_grid( num_cores = pl.num_programs(core_axis) core_id = pl.program_id(core_axis) else: - num_cores = jax.lax.psum(1, core_axis) + num_cores = jax.lax.axis_size(core_axis) core_id = jax.lax.axis_index(core_axis) # Check that num_cores is statically known if not isinstance(num_cores, int): @@ -1054,7 +1282,37 @@ def _partition_grid( offsets = jax_util.tuple_update( (0,) * len(grid), partition_dimension, grid_offset ) - return new_grid, offsets + return new_grid, offsets # type: ignore[return-value] + + +def sync_copy(src: REF | BufferedRef, dst: REF | BufferedRef, indices): + """Perform a synchronous copy from src to dst.""" + bref: BufferedRef + hbm_ref: REF + if isinstance(src, BufferedRef): + bref = src + if isinstance(dst, BufferedRef): + raise ValueError("Only one of src or dst can be a BufferedRef.") + hbm_ref = dst + copy_in = False + else: + if not isinstance(dst, BufferedRef): + raise ValueError("One of src or dst must be a BufferedRef.") + bref = dst + hbm_ref = src + copy_in = True + hbm_slice = bref.get_dma_slice(hbm_ref.shape, hbm_ref.dtype, indices) + bref_slice = tuple( + pl.ds(0, s.size) + for s, bd in zip(hbm_slice, bref.block_shape) + if not (bd is None or isinstance(bd, pl.Squeezed)) + ) + if copy_in: + tpu_helpers.sync_copy(hbm_ref.at[hbm_slice], + bref.current_ref.at[bref_slice]) # type: ignore[union-attr] + else: + tpu_helpers.sync_copy(bref.current_ref.at[bref_slice], # type: ignore[union-attr] + hbm_ref.at[hbm_slice]) def emit_pipeline( @@ -1068,6 +1326,8 @@ def emit_pipeline( core_axis_name: str | None = None, dimension_semantics: tuple[GridDimensionSemantics, ...] | None = None, trace_scopes: bool = True, + no_pipelining: bool = False, + use_sreg_for_state: bool = False, ): """Creates a function to emit a manual pallas pipeline. @@ -1094,6 +1354,11 @@ def emit_pipeline( or ARBITRARY). trace_scopes: optional bool, indicates whether to annotate each region in the pipeline using named_scope. + no_pipelining: If True, turns off pipelining and all copies will be + made synchronous. This is useful for debugging multiple-buffering + related bugs. + use_sreg_for_state: optional bool, indicates whether to use sregs for + current_slot state. 
""" if any(not isinstance(d, (int, jax.Array)) for d in grid): grid_types = tuple(type(d) for d in grid) @@ -1206,14 +1471,16 @@ def make_scheduler(step, indices): last_cycle=last_cycle, init_accumulators=init_accumulators, trace_scopes=trace_scopes, + use_sreg_for_state=use_sreg_for_state, ) - def loop_body(step, indices): + def loop_body(step, carry): + unaliased_brefs, indices = carry scheduler = make_scheduler(step, indices) with scheduler.grid_env(): # prepare any local VMEM aliases - brefs = map_brefs(scheduler.alias_local_refs, allocations, refs) + brefs = map_brefs(scheduler.alias_local_refs, unaliased_brefs, refs) # loop input handling phase map_brefs(scheduler.copy_in, brefs, refs, schedule) @@ -1243,25 +1510,61 @@ def loop_body(step, indices): lambda: postyeet(*brefs, scheduler), lambda: None) - map_brefs(scheduler.swap_slots, brefs, refs, schedule) - return _next_index(indices, grid) + next_brefs = map_brefs( + scheduler.swap_slots, unaliased_brefs, refs, schedule + ) + return next_brefs, _next_index(indices, grid) - @pl.when(num_steps > 0) - def _(): - # pipeline prologue + + if no_pipelining: + # Debugging mode where all copies are synchronous. initial_indices = (0,) * len(grid) scheduler = make_scheduler(0, initial_indices) brefs = map_brefs(scheduler.alias_local_refs, allocations, refs) - map_brefs(scheduler.initialize, brefs, refs, schedule) - - # pipeline loop - next_indices = lax.fori_loop(0, num_steps, loop_body, initial_indices) - - # pipeline epilogue - final_indices = _prev_index(next_indices, grid) - scheduler = make_scheduler(num_steps - 1, final_indices) - brefs = map_brefs(scheduler.alias_local_refs, allocations, refs) - map_brefs(scheduler.finalize, brefs, refs, schedule) + map_brefs(lambda bref: bref.init_slots(), brefs) + if postyeet is not None or prefetch is not None: + raise NotImplementedError("Prefetch/Postyeet not supported") + if any(bref.is_accumulator for bref in brefs): + raise NotImplementedError("Accumulators not supported") + @functools.partial(jax.lax.fori_loop, 0, num_steps, + init_val=initial_indices) + def _loop_body(step, indices): + scheduler = make_scheduler(step, indices) + with scheduler.grid_env(): + # prepare any local VMEM aliases + brefs = map_brefs(scheduler.alias_local_refs, allocations, refs) + # loop input handling phase + copy_in = lambda bref, ref: sync_copy(ref, bref, indices) + map_inputs(copy_in, brefs, refs) + # run the kernel! 
+ if body_prologue is not None: + body_prologue() + current_refs = map_brefs(lambda x: x.current_ref, brefs) + with scheduler._named_scope("ep_run_kernel"): + body(*current_refs, *scratches) + # loop output handling phase + copy_out = lambda bref, ref: sync_copy(bref, ref, indices) + map_outputs(copy_out, brefs, refs) + return _next_index(indices, grid) + else: + @pl.when(num_steps > 0) + def _(): + # pipeline prologue + initial_indices = (0,) * len(grid) + scheduler = make_scheduler(0, initial_indices) + with scheduler.grid_env(): + brefs = map_brefs(scheduler.initialize, allocations, refs, schedule) + + # pipeline loop + brefs, next_indices = lax.fori_loop( + 0, num_steps, loop_body, (brefs, initial_indices) + ) + + # pipeline epilogue + final_indices = _prev_index(next_indices, grid) + scheduler = make_scheduler(num_steps - 1, final_indices) + with scheduler.grid_env(): + map_brefs(scheduler.finalize, brefs, refs, schedule) return pipeline diff --git a/jax/_src/pallas/mosaic/primitives.py b/jax/_src/pallas/mosaic/primitives.py index fb0e0c2c55e3..aac5af84bd19 100644 --- a/jax/_src/pallas/mosaic/primitives.py +++ b/jax/_src/pallas/mosaic/primitives.py @@ -16,18 +16,20 @@ from __future__ import annotations import dataclasses -import enum from typing import Any import jax from jax._src import core as jax_core from jax._src import dtypes from jax._src import pretty_printer as pp +from jax._src import prng as jax_prng +from jax._src import random as jax_random from jax._src import state from jax._src import tree_util from jax._src import util from jax._src.interpreters import mlir from jax._src.pallas import core as pl_core +from jax._src.pallas import primitives from jax._src.pallas import utils as pallas_utils from jax._src.pallas.mosaic import core as tpu_core from jax._src.state import discharge as state_discharge @@ -160,255 +162,6 @@ def _roll(x, shift): mlir.register_lowering(roll_p, _roll_lowering_rule) -class DeviceIdType(enum.Enum): - MESH = "mesh" - LOGICAL = "logical" - - -def check_sem_avals( - sem_aval, sem_transforms_avals, name, allowed_semaphore_types=None -): - if allowed_semaphore_types is None: - allowed_semaphore_types = { - tpu_core.semaphore, - tpu_core.barrier_semaphore, - # For interpret mode. - pl_core.SEMAPHORE_INTERPRET_DTYPE, - } - if not isinstance(sem_aval, state.AbstractRef): - raise ValueError(f"Cannot {name} on a non-semaphore Ref: {sem_aval}") - sem_shape = sem_aval.shape - if sem_transforms_avals: - sem_shape = sem_transforms_avals[-1].get_indexer_shape() - if sem_shape: - raise ValueError(f"Cannot {name} on a non-()-shaped semaphore: {sem_shape}") - sem_dtype = sem_aval.dtype - if not any( - jnp.issubdtype(sem_dtype, sem_type) - for sem_type in allowed_semaphore_types - ): - raise ValueError( - f"Must {name} semaphores of the following types:" - f" {allowed_semaphore_types}. Got {sem_dtype}." 
- ) - - -def _transform_semaphore(ref_value, transforms, ref_aval): - """Helper function for indexing into a semaphore during state_discharge.""" - if ref_value.shape == ref_aval.shape: - return state_discharge.transform_array(ref_value, transforms) - elif len(ref_value.shape) == 0: - return ref_value - else: - raise ValueError( - f"Semaphore value shape {ref_value.shape} does not match aval shape" - f" {ref_aval.shape}" - ) - - -semaphore_read_p = jax_core.Primitive("semaphore_read") -semaphore_read_p.multiple_results = False - - -def semaphore_read(sem_or_view): - ref, transforms = _get_ref_and_transforms(sem_or_view) - args = [ref, transforms] - flat_args, args_tree = tree_util.tree_flatten(args) - return semaphore_read_p.bind(*flat_args, args_tree=args_tree) - -@semaphore_read_p.def_abstract_eval -def _semaphore_read_abstract_eval( - *avals, - args_tree, -): - sem_aval, sem_transforms_avals = tree_util.tree_unflatten(args_tree, avals) - check_sem_avals( - sem_aval, - sem_transforms_avals, - "read", - allowed_semaphore_types={ - tpu_core.dma_semaphore, - tpu_core.semaphore, - tpu_core.barrier_semaphore, - pl_core.SEMAPHORE_INTERPRET_DTYPE, - }, - ) - return jax_core.ShapedArray((), jnp.dtype("int32")) - -def _semaphore_read_discharge_rule(in_avals, - out_avals, - *flat_args, - args_tree): - del out_avals - [ref, transforms] = args_tree.unflatten(flat_args) - sem_value = _transform_semaphore(ref, transforms, in_avals[0]) - sem_value = sem_value.astype(jnp.int32) - return (None,) * len(in_avals), sem_value -state_discharge.register_discharge_rule(semaphore_read_p)( - _semaphore_read_discharge_rule -) - - -semaphore_signal_p = jax_core.Primitive('semaphore_signal') -semaphore_signal_p.multiple_results = True - - -def semaphore_signal( - sem_or_view, - inc: int | jax.Array = 1, - *, - device_id: int | jax.Array | None | tuple[int | jax.Array, ...] 
= None, - device_id_type: DeviceIdType = DeviceIdType.MESH, - core_index: int | jax.Array | None = None, -): - ref, transforms = _get_ref_and_transforms(sem_or_view) - inc = jnp.asarray(inc, dtype=jnp.int32) - args = [ref, transforms, inc, device_id, core_index] - flat_args, args_tree = tree_util.tree_flatten(args) - semaphore_signal_p.bind( - *flat_args, - args_tree=args_tree, - device_id_type=device_id_type, - ) - - -@semaphore_signal_p.def_abstract_eval -def _semaphore_signal_abstract_eval( - *avals, - args_tree, - device_id_type: DeviceIdType, -): - del device_id_type - ( - sem_aval, - sem_transforms_avals, - value_aval, - device_id_avals, - core_index_aval, - ) = tree_util.tree_unflatten(args_tree, avals) - check_sem_avals(sem_aval, sem_transforms_avals, "signal") - if value_aval.dtype != jnp.dtype("int32"): - raise ValueError("Must signal an int32 value.") - if device_id_avals is not None: - device_id_flat_avals = tree_util.tree_leaves(device_id_avals) - for aval in device_id_flat_avals: - if aval.dtype != jnp.dtype("int32"): - raise ValueError("`device_id`s must be an int32 value.") - return [] - - -def _semaphore_signal_pp_eqn(eqn: jax_core.JaxprEqn, - context: jax_core.JaxprPpContext, - settings: jax_core.JaxprPpSettings): - del settings - invars = eqn.invars - tree = eqn.params["args_tree"] - ( - sem, - sem_transforms, - value, - device_ids, - _, - ) = tree_util.tree_unflatten(tree, invars) - out = pp.concat([ - pp.text("semaphore_signal"), - pp.text(" "), - sp.pp_ref_transforms(context, sem, sem_transforms), - pp.text(" "), - pp.text(jax_core.pp_var(value, context)), - ]) - if device_ids is not None: - flat_device_ids = tree_util.tree_leaves(device_ids) - if not flat_device_ids: - return out - device_ids_pp = [pp.text(jax_core.pp_var(flat_device_ids[0], context))] - for device_id in flat_device_ids[1:]: - device_ids_pp.append(pp.text(" ")) - device_ids_pp.append(pp.text(jax_core.pp_var(device_id, context))) - out = pp.concat([out, pp.concat(device_ids_pp)]) - return out -jax_core.pp_eqn_rules[semaphore_signal_p] = _semaphore_signal_pp_eqn - - -def _semaphore_signal_discharge_rule(in_avals, - out_avals, - *flat_args, - args_tree, - device_id_type): - del out_avals, device_id_type - [ref, transforms, inc, device_id, core_index] = args_tree.unflatten(flat_args) - if device_id is not None: - raise NotImplementedError("Remote signal not implemented.") - if core_index is not None: - raise NotImplementedError("Multiple core support not implemented.") - sem_value = _transform_semaphore(ref, transforms, in_avals[0]) - inc = inc.astype(pl_core.SEMAPHORE_INTERPRET_DTYPE) - _, new_sem_value = state_discharge.transform_swap_array( - ref, transforms, sem_value + inc - ) - return (new_sem_value,) + (None,) * (len(in_avals) - 1), () -state_discharge.register_discharge_rule(semaphore_signal_p)( - _semaphore_signal_discharge_rule -) - - -semaphore_wait_p = jax_core.Primitive('semaphore_wait') -semaphore_wait_p.multiple_results = True - -def semaphore_wait(sem_or_view, dec: int | jax.Array = 1): - ref, transforms = _get_ref_and_transforms(sem_or_view) - dec = jnp.asarray(dec, dtype=jnp.int32) - args = [ref, transforms, dec] - flat_args, args_tree = tree_util.tree_flatten(args) - semaphore_wait_p.bind(*flat_args, args_tree=args_tree) - -@semaphore_wait_p.def_abstract_eval -def _semaphore_wait_abstract_eval(*avals, args_tree): - sem_aval, sem_transforms_avals, value_aval = tree_util.tree_unflatten( - args_tree, avals - ) - check_sem_avals(sem_aval, sem_transforms_avals, "wait") - if value_aval.dtype 
!= jnp.dtype("int32"): - raise ValueError("Must wait an int32 value.") - return [] - -def _semaphore_wait_pp_eqn(eqn: jax_core.JaxprEqn, - context: jax_core.JaxprPpContext, - settings: jax_core.JaxprPpSettings): - del settings - invars = eqn.invars - tree = eqn.params["args_tree"] - ( - sem, - sem_transforms, - value, - ) = tree_util.tree_unflatten(tree, invars) - return pp.concat([ - pp.text("semaphore_wait"), - pp.text(" "), - sp.pp_ref_transforms(context, sem, sem_transforms), - pp.text(" "), - pp.text(jax_core.pp_var(value, context)), - ]) -jax_core.pp_eqn_rules[semaphore_wait_p] = _semaphore_wait_pp_eqn - -def _semaphore_wait_discharge_rule(in_avals, - out_avals, - *flat_args, - args_tree): - del out_avals - [ref, transforms, dec] = args_tree.unflatten(flat_args) - sem_value = _transform_semaphore(ref, transforms, in_avals[0]) - dec = dec.astype(pl_core.SEMAPHORE_INTERPRET_DTYPE) - _, new_sem_value = state_discharge.transform_swap_array( - ref, transforms, sem_value - dec - ) - return (new_sem_value,) + (None,) * (len(in_avals) - 1), () -state_discharge.register_discharge_rule(semaphore_wait_p)( - _semaphore_wait_discharge_rule -) - - @dataclasses.dataclass class AsyncCopyDescriptor: src_ref: Any @@ -420,7 +173,7 @@ class AsyncCopyDescriptor: src_sem: int | jax.Array | None src_sem_transforms: tuple[Transform, ...] | None device_id: int | jax.Array | None - device_id_type: DeviceIdType = DeviceIdType.MESH + device_id_type: primitives.DeviceIdType = primitives.DeviceIdType.MESH def __post_init__(self): if (self.src_sem is None) ^ (self.device_id is None): @@ -457,9 +210,14 @@ def _get_args_and_tree(self, swap_src_and_dst: bool = False): self.device_id, )) - def start(self): + def start(self, priority: int = 0): flat_args, tree = self._get_args_and_tree() - dma_start_p.bind(*flat_args, tree=tree, device_id_type=self.device_id_type) + dma_start_p.bind( + *flat_args, + tree=tree, + device_id_type=self.device_id_type, + priority=priority, + ) def wait(self): if self.is_remote: @@ -488,7 +246,9 @@ def wait_send(self): dma_start_p.multiple_results = True @dma_start_p.def_effectful_abstract_eval -def _dma_start_abstract_eval(*args, tree, device_id_type): +def _dma_start_abstract_eval(*args, tree, device_id_type, priority): + if priority < 0: + raise ValueError(f"DMA start priority must be non-negative: {priority}") ( src_ref_aval, src_transforms_avals, @@ -523,6 +283,7 @@ def _dma_start_pp_eqn(eqn: jax_core.JaxprEqn, settings: jax_core.JaxprPpSettings): invars = eqn.invars tree = eqn.params["tree"] + priority = eqn.params["priority"] ( src_ref, src_transforms, @@ -539,7 +300,7 @@ def _dma_start_pp_eqn(eqn: jax_core.JaxprEqn, if src_sem or device_id: return jax_core._pp_eqn(eqn, context, settings) return pp.concat([ - pp.text("dma_start"), + pp.text(f"dma_start(p{priority})"), pp.text(" "), sp.pp_ref_transforms(context, src_ref, src_transforms), pp.text(" -> "), @@ -550,8 +311,12 @@ def _dma_start_pp_eqn(eqn: jax_core.JaxprEqn, jax_core.pp_eqn_rules[dma_start_p] = _dma_start_pp_eqn -def dma_start_partial_discharge_rule(should_discharge, in_avals, out_avals, - *args, tree, device_id_type): + +def dma_start_partial_discharge_rule( + should_discharge, in_avals, out_avals, *args, tree, device_id_type, priority +): + # Note: we ignore the DMA priority in discharge rules. + del priority ( src_ref, src_transforms, @@ -610,14 +375,14 @@ def dma_start_partial_discharge_rule(should_discharge, in_avals, out_avals, # TODO(justinfu): Verify that code only works in SPMD mode. 
axis_env = jax_core.get_axis_env() nonempty_axes = [name for name in axis_env.axis_sizes if name is not None] - if device_id_type == DeviceIdType.LOGICAL: + if device_id_type == primitives.DeviceIdType.LOGICAL: if len(nonempty_axes) > 1: raise NotImplementedError("Sharding with more than one named axis not " "implemented in dma_start_p for LOGICAL " "device_id_type.") shard_axis = nonempty_axes[0] my_axis = jax.lax.axis_index(shard_axis) - elif device_id_type == DeviceIdType.MESH: + elif device_id_type == primitives.DeviceIdType.MESH: device_id_len = 1 if isinstance(device_id, jax.Array): device_id_len = device_id.size @@ -667,7 +432,7 @@ def do_discharge_dst(dst_ref=dst_ref): def do_discharge_dst_sem(dst_sem=dst_sem): recv_size = jnp.minimum(updates.size, pl_core.SEMAPHORE_MAX_VALUE) recv_size = jnp.array(recv_size, dtype=pl_core.SEMAPHORE_INTERPRET_DTYPE) - dst_sem_value = _transform_semaphore( + dst_sem_value = primitives._transform_semaphore( dst_sem, dst_sem_transforms, dst_sem_aval ) _, ret = state_discharge.transform_swap_array( @@ -678,7 +443,7 @@ def do_discharge_dst_sem(dst_sem=dst_sem): def do_discharge_src_sem(src_sem=src_sem): send_size = jnp.minimum(local_src.size, pl_core.SEMAPHORE_MAX_VALUE) send_size = jnp.array(send_size, dtype=pl_core.SEMAPHORE_INTERPRET_DTYPE) - src_sem_value = _transform_semaphore( + src_sem_value = primitives._transform_semaphore( src_sem, src_sem_transforms, src_sem_aval ) _, ret = state_discharge.transform_swap_array( @@ -710,6 +475,7 @@ def do_discharge_src_sem(src_sem=src_sem): return new_vals, [] + state_discharge.register_partial_discharge_rule(dma_start_p)(dma_start_partial_discharge_rule) @@ -778,7 +544,7 @@ def dma_wait_partial_discharge_rule(should_discharge, updates = state_discharge.transform_array(dst_ref, dst_ref_transforms) copy_size = jnp.minimum(updates.size, pl_core.SEMAPHORE_MAX_VALUE) copy_size = jnp.array(copy_size, dtype=pl_core.SEMAPHORE_INTERPRET_DTYPE) - sem_value = _transform_semaphore(dst_sem, dst_sem_transforms, dst_sem_aval) + sem_value = primitives._transform_semaphore(dst_sem, dst_sem_transforms, dst_sem_aval) _, new_sem = state_discharge.transform_swap_array( dst_sem, dst_sem_transforms, sem_value - copy_size ) @@ -799,6 +565,7 @@ def _get_ref_and_transforms(ref): return ref.ref, ref.transforms return ref, () + def make_async_copy(src_ref, dst_ref, sem): """Issues a DMA copying from src_ref to dst_ref.""" src_ref, src_transforms = _get_ref_and_transforms(src_ref) @@ -814,17 +581,19 @@ def make_async_copy(src_ref, dst_ref, sem): None, None, None, - DeviceIdType.MESH, + primitives.DeviceIdType.MESH, ) -def async_copy(src_ref, dst_ref, sem): + +def async_copy(src_ref, dst_ref, sem, *, priority: int = 0): """Issues a DMA copying from src_ref to dst_ref.""" copy_descriptor = make_async_copy(src_ref, dst_ref, sem) - copy_descriptor.start() + copy_descriptor.start(priority=priority) return copy_descriptor + def make_async_remote_copy(src_ref, dst_ref, send_sem, recv_sem, device_id, - device_id_type: DeviceIdType = DeviceIdType.MESH): + device_id_type: primitives.DeviceIdType = primitives.DeviceIdType.MESH): """Creates a description of a remote copy operation. 
Copies data from src_ref on the current device to dst_ref on the device @@ -861,27 +630,19 @@ def make_async_remote_copy(src_ref, dst_ref, send_sem, recv_sem, device_id, ) def async_remote_copy(src_ref, dst_ref, send_sem, recv_sem, device_id, - device_id_type: DeviceIdType = DeviceIdType.MESH): + device_id_type: primitives.DeviceIdType = primitives.DeviceIdType.MESH): copy_descriptor = make_async_remote_copy(src_ref, dst_ref, send_sem, recv_sem, device_id, device_id_type) copy_descriptor.start() return copy_descriptor -device_id_p = jax_core.Primitive('device_id') - -@device_id_p.def_abstract_eval -def _device_id_abstract_eval(): - return jax_core.ShapedArray((), jnp.dtype("int32")) - -device_id = device_id_p.bind - get_barrier_semaphore_p = jax_core.Primitive('get_barrier_semaphore') @get_barrier_semaphore_p.def_abstract_eval def _get_barrier_semaphore_abstract_eval(): return pl_core.AbstractMemoryRef( - jax_core.ShapedArray((), tpu_core.BarrierSemaphoreTy()), - tpu_core.TPUMemorySpace.SEMAPHORE, + jax_core.ShapedArray((), pl_core.BarrierSemaphore()), + tpu_core.MemorySpace.SEMAPHORE, ) def get_barrier_semaphore(): @@ -902,7 +663,7 @@ def get_barrier_semaphore(): to share a collective_id. However, if in doubt, prefer not sharing collective_ids, as doing so incorrectly can lead to silent data corruption or crashes. - Note that re-using the same collective_id doesn't guarantee that the same + Note that reusing the same collective_id doesn't guarantee that the same semaphore is provided by XLA. """ return get_barrier_semaphore_p.bind() @@ -926,8 +687,9 @@ def delay(nanos): prng_seed_p = jax_core.Primitive("prng_seed") prng_seed_p.multiple_results = True + @prng_seed_p.def_abstract_eval -def _(*_): +def _prng_seed_abstract_eval(*_): return [] @@ -944,9 +706,113 @@ def prng_seed(*seeds: int | jax.Array) -> None: prng_random_bits_p = jax_core.Primitive( 'prng_random_bits') + @prng_random_bits_p.def_abstract_eval -def _(*, shape): +def _prng_random_bits_abstract_eval(*, shape): return jax_core.ShapedArray(shape, jnp.dtype("int32")) + def prng_random_bits(shape): return prng_random_bits_p.bind(shape=shape) + +# PRNG wrap/unwrap ops. +# We cannot use JAX's key_data and wrap_key_data because they return +# vectors, and Pallas keys are represented as lists of scalars. + +split_key_p = jax_core.Primitive("prng_split") +split_key_p.multiple_results = True + + +@split_key_p.def_abstract_eval +def _split_key_scalar_abstract_eval(seed): + key_shape = seed.dtype._impl.key_shape + if len(key_shape) != 2 or key_shape[0] != 1: + raise ValueError(f"Key shape must be (1, N), got {key_shape}") + return [jax_core.ShapedArray((), jnp.dtype("uint32"))] * key_shape[1] + + +def unwrap_pallas_seed(seed): + """Splits a PRNG key into it's scalar components.""" + return split_key_p.bind(seed) + + +join_key_p = jax_core.Primitive("prng_join") + + +@join_key_p.def_abstract_eval +def _join_key_scalar_abstract_eval(*seeds, impl): + if len(impl.key_shape) != 2 or impl.key_shape[0] != 1: + raise ValueError(f"Key shape must be (1, N), got {impl.key_shape}") + if len(seeds) != impl.key_shape[1]: + raise ValueError( + f"Number of seeds must match key shape, got {len(seeds)}" + f" != {impl.key_shape[1]}." 
+ ) + return jax_core.ShapedArray((), dtype=jax_prng.KeyTy(impl)) + + +def wrap_pallas_seed(*seeds, impl): + """Joins scalars into a single PRNG key.""" + impl = jax_random.resolve_prng_impl(impl) + return join_key_p.bind(*seeds, impl=impl) + + +with_memory_space_constraint_p = jax_core.Primitive( + 'with_memory_space_constraint') + +@with_memory_space_constraint_p.def_impl +def with_memory_space_constraint_impl(x, *, memory_space): + del x, memory_space + raise ValueError("Cannot eagerly run with_memory_space_constraint.") + + +@with_memory_space_constraint_p.def_abstract_eval +def with_memory_space_constraint_abstract_eval(x, *, memory_space): + if not isinstance(x, jax_core.ShapedArray): + raise NotImplementedError("with_memory_space_constraint only supports " + "arrays.") + return pl_core.ShapedArrayWithMemorySpace( + x.shape, x.dtype, memory_space=memory_space + ) + +def with_memory_space_constraint_lowering_rule(ctx, x, *, memory_space): + del ctx, memory_space + return [x] +mlir.register_lowering( + with_memory_space_constraint_p, with_memory_space_constraint_lowering_rule +) + +def with_memory_space_constraint( + x: jax.Array, memory_space: Any +) -> jax.Array: + """Constrains the memory space of an array. + + This primitive does not change the value of `x`, but it constrains the + memory space where it should be allocated. This is useful to force + Pallas to allocate an array in a specific memory space. + + As of now, this only applies to the inputs of pallas_calls: you can + apply it to the arguments of a pallas_call and it will constrain them, but + other operations will not respect this constraint. + + Args: + x: The array to constrain. + memory_space: The memory space to constrain to. + + Returns: + The array `x` with the memory space constraint. + """ + if memory_space in {tpu_core.MemorySpace.ANY, pl_core.MemorySpace.ANY}: + return x + if memory_space not in {tpu_core.MemorySpace.HBM, tpu_core.MemorySpace.VMEM}: + raise NotImplementedError( + "with_memory_space_constraint only supports HBM and VMEM." + ) + return with_memory_space_constraint_p.bind(x, memory_space=memory_space) + +def get_memory_space(x: jax.Array) -> Any: + """Queries the memory space of an array.""" + aval = jax_core.get_aval(x) + if isinstance(aval, pl_core.ShapedArrayWithMemorySpace): + return aval.memory_space + return None diff --git a/jax/_src/pallas/mosaic/random.py b/jax/_src/pallas/mosaic/random.py index fd8dcc720f07..8d29f857afb2 100644 --- a/jax/_src/pallas/mosaic/random.py +++ b/jax/_src/pallas/mosaic/random.py @@ -13,18 +13,18 @@ # limitations under the License.
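# [Editor's note — usage sketch, not part of the patch.] With the "pallas_tpu"
# key impl below (key_shape (1, 2)), a key unwraps into two uint32 scalars and
# can be rebuilt after mixing in data; `step` here is a hypothetical int32:
#
#   key0, key1 = unwrap_pallas_seed(key)
#   new_key = wrap_pallas_seed(key0, key1 + step, impl="pallas_tpu")
#
# The updated `_fold_in` in random.py follows the same pattern, with an extra
# `jax_prng.apply_round` to mix the halves.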
from collections.abc import Callable - import functools import jax from jax import numpy as jnp from jax import random as jax_api_random from jax._src import blocked_sampler from jax._src import dtypes +from jax._src import prng as jax_prng from jax._src import typing -from jax._src.pallas.mosaic.primitives import prng_seed -from jax._src.pallas.mosaic.primitives import prng_random_bits from jax._src.pallas import primitives -from jax._src import prng as jax_prng +from jax._src.pallas.mosaic import primitives as tpu_primitives +from jax._src.pallas.mosaic.primitives import prng_random_bits +from jax._src.pallas.mosaic.primitives import prng_seed Shape = jax_prng.Shape @@ -32,8 +32,8 @@ KeylessSampleFnType = Callable[..., jax.Array] set_seed = prng_seed - -FOLD_IN_ROUNDS = 128 +unwrap_pallas_seed = tpu_primitives.unwrap_pallas_seed +wrap_pallas_seed = tpu_primitives.wrap_pallas_seed def to_pallas_key(key: jax.Array) -> jax.Array: @@ -63,7 +63,7 @@ def is_pallas_impl(impl: jax_prng.PRNGImpl) -> bool: def _seed_func(seed: jnp.int32): seed_data = jnp.zeros(tpu_key_impl.key_shape, dtype=jnp.int32) - return (seed_data + seed).astype(jnp.uint32) + return (seed_data + seed).astype(jnp.uint32) # Broadcast the seed. def _random_bits(key: typing.Array, bit_width: int, shape: Shape): if bit_width != 32: @@ -72,42 +72,26 @@ def _random_bits(key: typing.Array, bit_width: int, shape: Shape): return prng_random_bits(shape) def _fold_in(key: jax_prng.PRNGKeyArray, data: typing.Array): - # Roughly, we compute the new key as follows: - # new_key = random_bits(data)[..., 127] ^ random_bits(old_key)[..., 127] - # Because the TPU generates random numbers in (8, 128) blocks at once, we - # can generate that many values without additional cost which will reduce - # correlation between the old and new keys. - - # TODO(justinfu): The underlying TPU hardware PRNG doesn't produce robust - # random bits when applied in rounds such as below (measured via crush). - # We should consider a different strategy for generating keys. - key_shape = tpu_key_impl.key_shape - - prng_seed(data) - data_bits = prng_random_bits( - key_shape + (FOLD_IN_ROUNDS,)).astype(jnp.uint32) - prng_seed(key) - key_bits = prng_random_bits( - key_shape + (FOLD_IN_ROUNDS,)).astype(jnp.uint32) - - mixed = key_bits[..., FOLD_IN_ROUNDS-1] ^ data_bits[..., FOLD_IN_ROUNDS-1] - assert mixed.shape == key_shape - return jax.random.wrap_key_data(mixed, impl="pallas_tpu") + key0, key1 = unwrap_pallas_seed(key) + # Perform a cheap mixing of data into the key. + key1 = key1 + data + [key0, key1] = jax_prng.apply_round([key0, key1], 13) + return wrap_pallas_seed(key0, key1, impl="pallas_tpu") def _split(key: typing.Array, shape: Shape): del key, shape - raise NotImplementedError() + raise NotImplementedError( + "Cannot split a Pallas key. Use fold_in instead to generate new keys." + ) tpu_key_impl = jax_prng.PRNGImpl( - # Pallas currently only supports 2D+ windows, so set the key_shape - # to be 2D to have better compatibility with setting BlockSpecs. - key_shape=(1, 1), - seed=_seed_func, - split=_split, - random_bits=_random_bits, - fold_in=_fold_in, - name="pallas_tpu", - tag="pl" + key_shape=(1, 2), + seed=_seed_func, + split=_split, + random_bits=_random_bits, + fold_in=_fold_in, + name="pallas_tpu", + tag="pl", ) jax_prng.register_prng(tpu_key_impl) @@ -193,7 +177,7 @@ def sample_block(sampler_fn: SampleFnType, `tile_size` should be chosen such that it is a divisor to all block sizes one needs to be invariant to. 
The larger the `tile_size`, the more - efficient the sampling process wil be and therefore the best choice is + efficient the sampling process will be and therefore the best choice is typically the greatest common divisor between all possible block sizes. Args: diff --git a/jax/_src/pallas/mosaic/verification.py b/jax/_src/pallas/mosaic/verification.py index 08ff58770804..d5266826f909 100644 --- a/jax/_src/pallas/mosaic/verification.py +++ b/jax/_src/pallas/mosaic/verification.py @@ -18,7 +18,8 @@ import itertools import math import textwrap -from typing import Any, Sequence +from typing import Any +from collections.abc import Sequence from jax import lax from jax._src import core as jax_core from jax._src import tree_util @@ -596,11 +597,10 @@ def _assume_abstract_eval(x, y): assert jax_core.typematch(x, y) return x +@lowering.register_lowering_rule(assume_p) def _assume_lowering(ctx: lowering.LoweringRuleContext, x, y): return y if ctx.lowering_context.for_verification else x -lowering.lowering_rules[assume_p] = _assume_lowering # type: ignore - def assume(normally, *, when_verifying): return assume_p.bind(normally, when_verifying) @@ -613,6 +613,7 @@ def _pretend_abstract_eval(*_, **params): del params # Unused. return () +@lowering.register_lowering_rule(pretend_p) def _pretend_lowering(ctx: lowering.LoweringRuleContext, *flat_args, tree): if ctx.lowering_context.for_verification: (base_read_refs, transforms) = tree_util.tree_unflatten(tree, flat_args) @@ -631,8 +632,6 @@ def _pretend_lowering(ctx: lowering.LoweringRuleContext, *flat_args, tree): ir.Operation.create("verification.pretend", operands=read_refs) return () -lowering.lowering_rules[pretend_p] = _pretend_lowering # type: ignore - def pretend(read_refs): refs, transforms = unzip2( primitives._get_ref_and_transforms(r) for r in read_refs diff --git a/jax/_src/pallas/mosaic_gpu/BUILD b/jax/_src/pallas/mosaic_gpu/BUILD index e5b491aef330..07b6887ae5c1 100644 --- a/jax/_src/pallas/mosaic_gpu/BUILD +++ b/jax/_src/pallas/mosaic_gpu/BUILD @@ -42,13 +42,16 @@ pytype_strict_library( name = "pallas_call_registration", srcs = ["pallas_call_registration.py"], deps = [ + ":core", ":lowering", "//jax", + "//jax:config", "//jax:core", "//jax:mlir", "//jax:mosaic_gpu", + "//jax:sharding_impls", "//jax/_src/pallas", - ], + ] + py_deps("numpy"), ) pytype_strict_library( @@ -57,12 +60,18 @@ pytype_strict_library( deps = [ ":core", "//jax", + "//jax:api", "//jax:core", + "//jax:dtypes", + "//jax:lax", + "//jax:mesh", "//jax:mlir", "//jax:mosaic_gpu", "//jax:pallas", "//jax:partial_eval", "//jax:source_info_util", + "//jax:state_types", + "//jax:tree_util", "//jax:util", "//jax/_src/lib", "//jax/_src/pallas", @@ -77,9 +86,12 @@ pytype_strict_library( "//jax:core", "//jax:dtypes", "//jax:effects", + "//jax:lax", "//jax:mosaic_gpu", + "//jax:pretty_printer", "//jax:state_types", "//jax:tree_util", + "//jax/_src/lib", "//jax/_src/pallas", "//jaxlib/mlir:ir", ] + py_deps("numpy"), @@ -93,8 +105,11 @@ pytype_strict_library( ":lowering", "//jax", "//jax:core", - "//jax:mlir", + "//jax:frozen_dict", + "//jax:lax", "//jax:mosaic_gpu", + "//jax:pretty_printer", + "//jax:state_types", "//jax:tree_util", "//jax:util", "//jax/_src/lib", @@ -117,3 +132,9 @@ pytype_strict_library( "//jax/_src/pallas", ], ) + +pytype_strict_library( + name = "helpers", + srcs = ["helpers.py"], + deps = ["//jax"], +) diff --git a/jax/_src/pallas/mosaic_gpu/core.py b/jax/_src/pallas/mosaic_gpu/core.py index 630c1b8f4bed..839d987a484a 100644 --- 
a/jax/_src/pallas/mosaic_gpu/core.py +++ b/jax/_src/pallas/mosaic_gpu/core.py @@ -18,30 +18,43 @@ import abc import collections -from collections.abc import Iterable, Sequence +from collections.abc import Callable, Iterable, Sequence import dataclasses import enum import itertools as it -from typing import Any, ClassVar, Literal +import math +from typing import Any, ClassVar, Literal, Union import jax from jax._src import core as jax_core from jax._src import dtypes from jax._src import effects +from jax._src import pretty_printer as pp from jax._src import tree_util +from jax._src.lib.mlir.dialects import arith as arith_dialect from jax._src.pallas import core as pallas_core +from jax._src.pallas import helpers as pallas_helpers +from jax._src.pallas import primitives as pallas_primitives +import jax._src.pallas.utils as pallas_utils +from jax._src.state import discharge as state_discharge from jax._src.state import indexing from jax._src.state import types as state_types -from jax._src.state import discharge as state_discharge import jax.experimental.mosaic.gpu as mgpu +from jax.experimental.mosaic.gpu import utils as mgpu_utils import jax.numpy as jnp from jaxlib.mlir import ir +_Ref = pallas_core.AbstractMemoryRef | state_types.TransformedRef AbstractMemoryRef = pallas_core.AbstractMemoryRef DimensionSemantics = Literal["parallel", "sequential"] +# We align all our SMEM allocations to 1024 bytes. TMA and WGMMA are very +# sensitive to alignment and while this is quite conservative, it gets the job +# done. We should make this more refined in the future. +SMEM_ALIGNMENT = 1024 + def is_trivial_index(idx, shape) -> bool: """Checks if the index selects the entire shape.""" @@ -58,7 +71,7 @@ def _slices(d): @dataclasses.dataclass(frozen=True, kw_only=True) -class GPUCompilerParams(pallas_core.CompilerParams): +class CompilerParams(pallas_core.CompilerParams): """Mosaic GPU compiler parameters. Attributes: @@ -74,32 +87,49 @@ class GPUCompilerParams(pallas_core.CompilerParams): references. Defaults to 0, and must be strictly smaller than max_concurrent_steps. Generally, you'll want to set it to 1 if you don't await the WGMMA in the body. + unsafe_no_auto_barriers: If True, Pallas will never automatically insert + barrier instructions that ensure synchronous semantics of loads and stores. + At the moment, the insertion is done conservatively and might regress + performance. There are (at least) two conditions that must be satisfied + for the use of this flag to be safe. First, no memory region is ever read + *and* written to by the same thread (async copies are performed by + background threads and do not count towards this rule). Secondly, no + thread ever calls commit_smem(), reads from the committed SMEM and then + issues an async copy overwriting that region (this is a very artificial + and highly unlikely scenario). profile_space: The number of profiler events that can be collected in a single invocation. It is undefined behavior if a thread collects more events than this. profile_dir: The directory to which profiling traces will be written to. 
""" - PLATFORM: ClassVar[str] = "mosaic_gpu" + BACKEND: ClassVar[pallas_core.Backend] = "mosaic_gpu" approx_math: bool = False dimension_semantics: Sequence[DimensionSemantics] | None = None max_concurrent_steps: int = 1 delay_release: int = 0 + unsafe_no_auto_barriers: bool = False profile_space: int = 0 profile_dir: str = "" - thread_semantics: mgpu.core.ThreadSemantics = mgpu.core.ThreadSemantics.Lane + lowering_semantics: mgpu.core.LoweringSemantics = mgpu.core.LoweringSemantics.Lane def __post_init__(self): + if self.dimension_semantics is not None: + object.__setattr__( + self, "dimension_semantics", tuple(self.dimension_semantics) + ) if bool(self.profile_space) ^ bool(self.profile_dir): raise ValueError( "Either both profile_space and profile_dir must be set, or neither." ) -class GPUMemorySpace(enum.Enum): +class MemorySpace(enum.Enum): #: Global memory. GMEM = "gmem" #: Shared memory. SMEM = "smem" + #: Tensor memory. + TMEM = "tmem" #: Registers. REGS = "regs" @@ -110,26 +140,77 @@ def __call__( self, shape: tuple[int, ...], dtype: jnp.dtype, + *, transforms: Sequence[MemoryRefTransform] = (), - + packed: bool | None = None, + collective: bool | None = None ) -> pallas_core.MemoryRef: # A convenience function for constructing MemoryRef types. - return GPUMemoryRef(shape, dtype, memory_space=self, transforms=transforms) + return GPUMemoryRef(shape, dtype, memory_space=self, transforms=transforms, + packed=packed, collective=collective) + + +class SemaphoreType(enum.Enum): + REGULAR = "regular" + BARRIER = "barrier" + + def __call__(self, shape: tuple[int, ...]): + dtype: Any + if self == SemaphoreType.BARRIER: + dtype = pallas_core.BarrierSemaphore() + else: + dtype = pallas_core.Semaphore() + return pallas_core.MemoryRef(shape, dtype, MemorySpace.GMEM) + + def get_array_aval(self) -> jax_core.ShapedArray: + return self(()).get_array_aval() + def get_ref_aval(self) -> _Ref: + return self(()).get_ref_aval() -def kernel(body, out_shape, compiler_params=None, **mesh_kwargs): + +class PrimitiveSemantics(enum.Enum): + """Thread semantics for a primitives at the Pallas user-level.""" + + Warp = enum.auto() + Warpgroup = enum.auto() + + +# Convenience constants for (lowering, primitive) thread semantics pairs. 
+LANExWG_SEMANTICS = ( + mgpu.LoweringSemantics.Lane, PrimitiveSemantics.Warpgroup) +LANExWARP_SEMANTICS = ( + mgpu.LoweringSemantics.Lane, PrimitiveSemantics.Warp) +WGxWG_SEMANTICS = ( + mgpu.LoweringSemantics.Warpgroup, PrimitiveSemantics.Warpgroup) + + +def kernel( + body: Callable[..., None], + out_shape: object, + *, + scratch_shapes: pallas_core.ScratchShapeTree = (), + compiler_params: pallas_core.CompilerParams | None = None, + **mesh_kwargs: object, +): if unwrap_out := not isinstance(out_shape, (tuple, list)): out_shape = (out_shape,) def wrapper(*operands): def stateful(operand_and_out_refs): operand_refs, out_refs = operand_and_out_refs + mesh = Mesh(**mesh_kwargs) + thread_name = mesh.thread_name if mesh.thread_name is not None else () def cmap_body(): - body(*operand_refs, *out_refs) + pallas_primitives.run_scoped( + lambda *scratch_refs: body(*operand_refs, *out_refs, *scratch_refs), + *scratch_shapes, + collective_axes=thread_name, + ) pallas_core.core_map( - GPUMesh(**mesh_kwargs), compiler_params=compiler_params + mesh, compiler_params=compiler_params )(cmap_body) _, outs = state_discharge.run_state(stateful)( - (operands, jax.tree.map(jnp.zeros_like, out_shape)) + (operands, pallas_helpers.empty_like(out_shape, backend="mosaic_gpu")) ) return outs[0] if unwrap_out else outs return wrapper @@ -139,13 +220,32 @@ def cmap_body(): class GPUMemoryRef(pallas_core.MemoryRef): transforms: Sequence[MemoryRefTransform] = () - def get_ref_aval(self) -> pallas_core.TransformedRef | AbstractMemoryRef: + # Whether to allow TMEM packing for sub 4-byte dtypes. + packed: bool | None = dataclasses.field(default=None, kw_only=True) + collective: bool | None = dataclasses.field(default=None, kw_only=True) + + def __post_init__(self): + if self.memory_space != MemorySpace.TMEM: + if self.packed is not None: + raise ValueError("Packed option is only supported for TMEM.") + if self.collective is not None: + raise ValueError("Collective option is only supported for TMEM.") + + def get_ref_aval(self) -> _Ref: aval = jax_core.ShapedArray(self.shape, self.dtype) for t in self.transforms: aval = t(aval) - ref = pallas_core.TransformedRef( - AbstractMemoryRef(aval, memory_space=self.memory_space), () - ) + if self.memory_space == MemorySpace.TMEM: + ref = pallas_core.TransformedRef( + AbstractTMEMRef(aval, + memory_space=self.memory_space, + packed=self.packed, + collective=self.collective), () + ) + else: + ref = pallas_core.TransformedRef( + AbstractMemoryRef(aval, memory_space=self.memory_space), () + ) for t in reversed(self.transforms): ref = t.undo(ref) if not ref.transforms: @@ -153,6 +253,132 @@ def get_ref_aval(self) -> pallas_core.TransformedRef | AbstractMemoryRef: return ref +def align_to(x: int, alignment: int): + if rem := x % alignment: + return x + alignment - rem + return x + + +# A tree of `GPUMemoryRef`s. +_GPUMemoryRefTree = Any + + +def _ref_group_size(refs: _GPUMemoryRefTree) -> int: + if isinstance(refs, GPUMemoryRef): + refs = (refs,) + size = 0 + for ref in jax.tree.leaves(refs): + # Make sure that the start of each ref is aligned with `SMEM_ALIGNMENT`. 
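# [Editor's note — worked example, not part of the patch.] With
# SMEM_ALIGNMENT = 1024, the helper rounds sizes up to the next boundary:
#
#   align_to(300, 1024)  == 1024   # 300 % 1024 == 300, pad by 724
#   align_to(2048, 1024) == 2048   # already aligned, unchanged
#
# so each ref in a group starts on a 1024-byte boundary before its own byte
# size is added to the running total.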
+ size = align_to(size, SMEM_ALIGNMENT) + if jnp.issubdtype(ref.dtype, jnp.integer): + nbits = jnp.iinfo(ref.dtype).bits + elif jnp.issubdtype(ref.dtype, jnp.floating): + nbits = jnp.finfo(ref.dtype).bits + else: + raise NotImplementedError(f"Unsupported dtype: {ref.dtype}") + ref_bits = math.prod(ref.shape) * nbits + if ref_bits % 8: + raise ValueError("Only byte-aligned shapes are supported.") + size += ref_bits // 8 + return size + + +def flatten_ref_union(ref_union: AbstractRefUnion) -> tuple[_Ref, ...]: + """Flattens a union of trees of references into a tuple of references. + + This is the moral equivalent of `jax.tree.leaves` for aliased references. + """ + flat_refs = [] + union_bytes = 0 + for ref_group in ref_union.refs: + byte_offset = 0 + for ref in jax.tree.leaves(ref_group): + byte_offset = align_to(byte_offset, SMEM_ALIGNMENT) + assert isinstance(ref, pallas_core.AbstractMemoryRef) or isinstance( + ref, pallas_core.TransformedRef + ) + if not isinstance(ref, pallas_core.TransformedRef): + ref = pallas_core.TransformedRef(ref, transforms=()) + transform = ExtractAliasedRef.from_transformed_ref(ref, byte_offset) + flat_refs.append( + pallas_core.TransformedRef( + ref_union, transforms=(transform, *ref.transforms) + ) + ) + if jnp.issubdtype(ref.dtype, jnp.integer): + nbits = jnp.iinfo(ref.dtype).bits + elif jnp.issubdtype(ref.dtype, jnp.floating): + nbits = jnp.finfo(ref.dtype).bits + else: + raise NotImplementedError(f"Unsupported dtype: {ref.dtype}") + ref_bits = math.prod(ref.shape) * nbits + if ref_bits % 8: + raise ValueError("Only byte-aligned shapes are supported.") + byte_offset += ref_bits // 8 + union_bytes = max(union_bytes, byte_offset) + assert union_bytes == ref_union.shape[0] + return tuple(flat_refs) + + +class AbstractRefUnion(pallas_core.AbstractMemoryRef): + refs: Sequence[_GPUMemoryRefTree] + + def __init__( + self, + aval, + refs: Sequence[_GPUMemoryRefTree], + memory_space, + ): + self.refs = refs + super().__init__(aval, memory_space=memory_space) + + def _iter(self, tracer): + return iter(flatten_ref_union(tracer)) + + def _getitem(self, tracer, index): + return list(iter(tracer))[index] + + def _setitem(self, tracer, index, value): + del tracer, index, value # Unused. + raise ValueError("Ref unions can't be assigned to.") + + def update(self, inner_aval=None, memory_space=None): + ref = super().update(inner_aval, memory_space) + return AbstractRefUnion(ref.inner_aval, self.refs, self.memory_space) + + +@dataclasses.dataclass(init=False, frozen=True) +class RefUnion(GPUMemoryRef): + """A sequence of trees of refs that are allowed to reuse the same memory. + + One should not make assumptions as to how each ref will map to the underlying + memory region, since arbitrary padding may be applied in between different + refs. + + As such, ref unions are only safe to use when the groups of refs that we + intend to alias have disjoint lifetimes (i.e. one should never attempt to read + data using a different ref than the one that was used to write the data). 
+ """ + refs: Sequence[_GPUMemoryRefTree] = () + + def __init__(self, *refs: _GPUMemoryRefTree): + if any(ref.memory_space != SMEM for ref in jax.tree.leaves(refs)): + raise NotImplementedError("Only SMEM refs can be aliased.") + object.__setattr__(self, "refs", refs) + num_bytes = max(map(_ref_group_size, self.refs)) + super().__init__( + shape=(num_bytes,), + dtype=jnp.int8, + memory_space=SMEM, + transforms=(), + ) + + def get_ref_aval(self) -> AbstractRefUnion: + inner_aval = jax.core.ShapedArray(self.shape, self.dtype) + refs_aval = jax.tree.map(lambda ref: ref.get_ref_aval(), self.refs) + return AbstractRefUnion(inner_aval, refs_aval, memory_space=SMEM) + + class MemoryRefTransform(pallas_core.MemoryRefTransform, abc.ABC): @abc.abstractmethod def to_gpu_transform(self) -> mgpu.MemRefTransform: @@ -171,7 +397,7 @@ def __call__(self, aval: jax_core.ShapedArray) -> jax_core.ShapedArray: shape=self.to_gpu_transform().transform_shape(aval.shape) ) -Index = slice | int | ir.Value +Index = Union[mgpu.DynamicSlice, slice, int, ir.Value] @dataclasses.dataclass(frozen=True) class TilingTransform(MemoryRefTransform): @@ -213,26 +439,74 @@ def transform_shape(self, shape): def transform_dtype(self, dtype): return dtype + def untransform_transpose( + self, perm: tuple[int, ...] + ) -> tuple[tuple[int, ...], state_types.Transform]: + # The transpose in question is applied to the utiled ref so we + # need to translate it by duplicating and offsetting the last part. + off = len(perm) + new_suffix = [i + off for i in perm[-len(self.tiling) :]] + if set(new_suffix) != set(range(off, off + len(self.tiling))): + raise ValueError( + "Transpose cannot be moved before a tiling transform when it changes" + f" the set of tiled dimensions. (permutation: {perm}, tiling:" + f" {self.tiling})" + ) + + new_tiling = tuple(self.tiling[i - off] for i in new_suffix) + return (*perm, *new_suffix), dataclasses.replace(self, tiling=new_tiling) + + def untransform_reshape( + self, dtype: jnp.dtype, shape: tuple[int, ...] + ) -> tuple[tuple[int, ...], state_types.Transform]: + del dtype + raise NotImplementedError("Reshapes don't commute with transposes.") + def untransform_index( - self, idxs: tuple[Index, ...] + self, dtype: jnp.dtype | ir.Type, idxs: tuple[Index, ...] 
) -> tuple[tuple[Index, ...], state_types.Transform]: + del dtype untiled_idxs = idxs[: -len(self.tiling)] tiled_idxs = idxs[-len(self.tiling) :] - idxs_after_tiling = [] + idxs_after_tiling: list[Index] = [] for idx, tile in zip(tiled_idxs, self.tiling): - if not isinstance(idx, slice): - raise NotImplementedError("Non-slice indices are not supported") - assert isinstance(idx, slice) - if idx.step is not None and idx.step != 1: - raise NotImplementedError("Strided slices unsupported") - if (idx.start is not None and idx.start % tile) or (idx.stop is not None and idx.stop % tile): - raise ValueError("Non-empty slices must be tile aligned") - idxs_after_tiling.append(slice(idx.start // tile, idx.stop // tile)) + if isinstance(idx, slice): + if idx.step is not None and idx.step != 1: + raise NotImplementedError("Strided slices unsupported") + if (idx.start is not None and idx.start % tile) or (idx.stop is not None and idx.stop % tile): + raise ValueError("Non-empty slices must be tile aligned") + idxs_after_tiling.append(slice(idx.start // tile, idx.stop // tile)) + elif isinstance(idx, mgpu.DynamicSlice): + if idx.length % tile: + raise ValueError( + f"Dynamic slice length ({idx.length}) is not divisible by the" + f" tiling ({tile})" + ) + if isinstance(idx.base, ir.Value): + if not mgpu_utils.is_known_divisible(idx.base, tile): + raise ValueError( + "Dynamic slice base index (which is a dynamic value) cannot be" + f" statically proven to be divisible by the tiling ({tile})" + ) + new_base = arith_dialect.divui(idx.base, mgpu.c(tile, idx.base.type)) + else: + if idx.base % tile: + raise ValueError( + f"Dynamic slice base ({idx.base}) is not divisible by the" + f" tiling ({tile})" + ) + new_base = idx.base // tile + idxs_after_tiling.append(mgpu.DynamicSlice(new_base, idx.length // tile)) + else: + raise TypeError(f"Unsupported index type: {type(idx)}") return (*untiled_idxs, *idxs_after_tiling, *(slice(None) for _ in self.tiling)), self def undo_to_gpu_transform(self) -> mgpu.MemRefTransform: return mgpu.TileTransform(self.tiling) + def pretty_print(self, context: jax_core.JaxprPpContext) -> pp.Doc: + return pp.text(f"{{untile({list(self.tiling)})}}") + def _perm_inverse(permutation: tuple[int, ...]) -> tuple[int, ...]: inverse = [-1] * len(permutation) @@ -271,7 +545,7 @@ def to_gpu_transform(self) -> mgpu.MemRefTransform: @tree_util.register_dataclass @dataclasses.dataclass(frozen=True) class TransposeRef(state_types.Transform): - permutation: tuple[int, ...] + permutation: tuple[int, ...] = dataclasses.field(metadata=dict(static=True)) def transform_shape(self, shape): if shape is None: @@ -281,11 +555,25 @@ def transform_shape(self, shape): def transform_dtype(self, dtype): return dtype + def untransform_transpose( + self, perm + ) -> tuple[tuple[int, ...], state_types.Transform]: + raise NotImplementedError( + "Commuting of transpose over transpose is not supported." + ) + + def untransform_reshape( + self, dtype: jnp.dtype | ir.Type, shape: tuple[int, ...] + ) -> tuple[tuple[int, ...], state_types.Transform]: + del shape, dtype + raise NotImplementedError("Can't reshape a transposed memref.") + def untransform_index( - self, idxs: tuple[Index, ...] + self, dtype: jnp.dtype | ir.Type, idxs: tuple[Index, ...] 
) -> tuple[tuple[Index, ...], state_types.Transform]: + del dtype removed_dims = [ - i for i, idx in enumerate(idxs) if not isinstance(idx, slice) + i for i, idx in enumerate(idxs) if not isinstance(idx, (slice, mgpu.ds)) ] new_perm = tuple( p - sum(d < p for d in removed_dims) @@ -298,20 +586,109 @@ def untransform_index( def undo_to_gpu_transform(self) -> mgpu.MemRefTransform: return mgpu.TransposeTransform(_perm_inverse(self.permutation)) + def pretty_print(self, context: jax_core.JaxprPpContext) -> pp.Doc: + return pp.text(f"{{transpose({list(self.permutation)})}}") -def transpose_ref( - ref: pallas_core.TransformedRef | Any, - permutation: tuple[int, ...], + +@tree_util.register_pytree_node_class +@dataclasses.dataclass +class PeerMemRef(state_types.Transform): + device_id: Any + device_id_type: pallas_primitives.DeviceIdType + + def transform_shape(self, shape): + return shape + + def transform_dtype(self, dtype): + return dtype + + def untransform_index( + self, idxs: tuple[Index, ...] + ) -> tuple[tuple[Index, ...], state_types.Transform]: + return idxs, self + + def tree_flatten(self): + return (self.device_id,), (self.device_id_type,) + + @classmethod + def tree_unflatten(cls, metadata, arrays): + return cls(arrays[0], metadata[0]) + + +def remote_ref( + ref: _Ref, + device_id: jax.typing.ArrayLike, + device_id_type: pallas_primitives.DeviceIdType = pallas_primitives.DeviceIdType.MESH, ) -> pallas_core.TransformedRef: + """Translate memref to a symmetric memref on a peer device.""" if not isinstance(ref, pallas_core.TransformedRef): - if not isinstance(jax_core.get_aval(ref), pallas_core.AbstractMemoryRef): + if not isinstance(jax_core.get_aval(ref), state_types.AbstractRef): raise TypeError("ref must be a reference") ref = pallas_core.TransformedRef(ref, transforms=()) return pallas_core.TransformedRef( - ref.ref, (*ref.transforms, TransposeRef(permutation)), + ref.ref, (*ref.transforms, PeerMemRef(device_id, device_id_type)), ) +def transform_ref( + ref: pallas_core.TransformedRef, + transform: state_types.Transform +) -> pallas_core.TransformedRef: + if not isinstance(ref, pallas_core.TransformedRef): + if not isinstance(jax_core.get_aval(ref), state_types.AbstractRef): + raise TypeError("ref must be a reference") + ref = pallas_core.TransformedRef(ref, transforms=()) + return pallas_core.TransformedRef( + ref.ref, (*ref.transforms, transform), + ) + +def transpose_ref( + ref: pallas_core.TransformedRef | Any, + permutation: tuple[int, ...], +) -> pallas_core.TransformedRef: + return transform_ref(ref, TransposeRef(permutation)) + +def untile_ref(ref, tiling: tuple[int, ...]) -> pallas_core.TransformedRef: + return transform_ref(ref, UntileRef(tiling)) + +def unswizzle_ref(ref, swizzle: int) -> pallas_core.TransformedRef: + return transform_ref(ref, UnswizzleRef(swizzle)) + + +@tree_util.register_pytree_node_class +@dataclasses.dataclass(frozen=True) +class ExtractAliasedRef(state_types.Transform): + """Bitcasts the underlying ref at the given offset to the given shape and dtype.""" + dtype: dtypes.DType + shape: tuple[int, ...] + offset: int + + @classmethod + def from_transformed_ref( + cls, ref: pallas_core.TransformedRef, byte_offset: int + ): + return cls( + dtypes.dtype(ref.dtype), ref.ref.shape, byte_offset + ) + + def transform_shape(self, shape): + if shape is None: + return None + return self.shape + + def transform_dtype(self, dtype): + del dtype # Unused. 
+ return self.dtype + + def tree_flatten(self): + return (), (self.dtype, self.shape, self.offset) + + @classmethod + def tree_unflatten(cls, metadata, arrays): + assert not arrays + return cls(*metadata) + + @dataclasses.dataclass(frozen=True) class SwizzleTransform(MemoryRefTransform): swizzle: int @@ -339,7 +716,7 @@ def undo_to_gpu_transform(self) -> mgpu.MemRefTransform: raise NotImplementedError def __call__(self, aval: jax_core.ShapedArray) -> jax_core.ShapedArray: - swizzle_elems = self.swizzle // aval.dtype.itemsize + swizzle_elems = (self.swizzle * 8) // pallas_utils.dtype_bitwidth(aval.dtype) if swizzle_elems != aval.shape[-1]: raise ValueError( f"Swizzle {self.swizzle} requires the trailing dimension to be of" @@ -353,28 +730,57 @@ def __call__(self, aval: jax_core.ShapedArray) -> jax_core.ShapedArray: class UnswizzleRef(state_types.Transform): swizzle: int = dataclasses.field(metadata=dict(static=True)) + def swizzle_elems(self, dtype: jnp.dtype | ir.Type) -> int: + if not isinstance(dtype, ir.Type): + dtype = mgpu_utils.dtype_to_ir_type(dtype) + return (self.swizzle * 8) // mgpu.bitwidth(dtype) + + def untransform_transpose(self, perm) -> tuple[tuple[int, ...], state_types.Transform]: + if perm[-1] != len(perm) - 1: + raise ValueError("Can't transpose the swizzled dimension.") + + return perm, self + + def untransform_reshape( + self, dtype: jnp.dtype | ir.Type, shape: tuple[int, ...] + ) -> tuple[tuple[int, ...], state_types.Transform]: + if shape[-1] != self.swizzle_elems(dtype): + raise ValueError( + f"Reshape shape {shape} is not divisible by swizzle elements" + f" {self.swizzle_elems(dtype)}" + ) + return shape, self + def untransform_index( - self, idxs: tuple[Index, ...] + self, dtype: jnp.dtype | ir.Type, idxs: tuple[Index, ...] 
) -> tuple[tuple[Index, ...], state_types.Transform]: + swizzle_elems = self.swizzle_elems(dtype) if not idxs: return idxs, self - if not all(isinstance(idx, slice) for idx in idxs[-2:]): + if not all(isinstance(idx, (slice, mgpu.ds)) for idx in idxs[-2:]): raise NotImplementedError( "Non-slice indices are not supported in 2 minormost dims" ) last_idx = idxs[-1] - assert isinstance(last_idx, slice) - if last_idx.step is not None and last_idx.step != 1: - raise NotImplementedError("Swizzled dims cannot be sliced") - if (last_idx.start is not None and last_idx.start != 0) or ( - last_idx.stop is not None and last_idx.stop != self.swizzle - ): - raise ValueError("Swizzled dims cannot be sliced") + if isinstance(last_idx, mgpu.DynamicSlice): + if last_idx.base != 0 or last_idx.length != swizzle_elems: + raise ValueError("Swizzled dims cannot be sliced") + else: + assert isinstance(last_idx, slice) + if ( + (last_idx.step is not None and last_idx.step != 1) + or (last_idx.start is not None and last_idx.start != 0) + or (last_idx.stop is not None and last_idx.stop != swizzle_elems) + ): + raise ValueError("Swizzled dims cannot be sliced") return idxs, self + def pretty_print(self, context: jax_core.JaxprPpContext) -> pp.Doc: + return pp.text(f"{{unswizzle({self.swizzle})}}") + @dataclasses.dataclass -class GPUBlockSpec(pallas_core.BlockSpec): +class BlockSpec(pallas_core.BlockSpec): transforms: Sequence[MemoryRefTransform] = () def to_block_mapping( @@ -386,6 +792,7 @@ def to_block_mapping( index_map_tree: tree_util.PyTreeDef, grid: pallas_core.GridMappingGrid, mapped_dims: tuple[int, ...], + debug: bool = False, ) -> pallas_core.BlockMapping: bm = super().to_block_mapping( origin, @@ -394,6 +801,7 @@ def to_block_mapping( index_map_tree=index_map_tree, grid=grid, mapped_dims=mapped_dims, + debug=debug, ) block_inner_aval = bm.block_aval.inner_aval for t in self.transforms: @@ -406,9 +814,10 @@ def to_block_mapping( ) -GMEM = GPUMemorySpace.GMEM -SMEM = GPUMemorySpace.SMEM -REGS = GPUMemorySpace.REGS +GMEM = MemorySpace.GMEM +SMEM = MemorySpace.SMEM +TMEM = MemorySpace.TMEM +REGS = MemorySpace.REGS class barrier_dtype(dtypes.extended): @@ -421,19 +830,60 @@ class BarrierType(dtypes.ExtendedDType): name: ClassVar[str] = "barrier" num_arrivals: int + for_tensor_core: bool def __str__(self): return self.name @dataclasses.dataclass(frozen=True) +class ClusterBarrierType(dtypes.ExtendedDType): + type: ClassVar[Any] = barrier_dtype + name: ClassVar[str] = "cluster_barrier" + + collective_axes: tuple[str | tuple[str, ...], ...] + + def __str__(self): + return self.name + + +@dataclasses.dataclass(frozen=True, kw_only=True) class Barrier: - num_arrivals: int + """Describes a barrier Ref. + + Attributes: + num_arrivals: The number of arrivals that will be recorded by this barrier. + num_barriers: The number of barriers that will be created. Individual + barriers can be accessed by indexing into the barrier Ref. + for_tensor_core: Whether this barrier is used for synchronizing with + the tensor core. This should be set to True when waiting on Blackwell + (TC Gen 5) asynchronous matmul instructions. 
+ """ + num_arrivals: int = 1 + num_barriers: int = 1 + for_tensor_core: bool = False + + def get_ref_aval(self) -> AbstractMemoryRef: + aval = jax_core.ShapedArray( + [self.num_barriers], BarrierType(self.num_arrivals, + for_tensor_core=self.for_tensor_core) + ) + return AbstractMemoryRef(aval, SMEM) + + def __post_init__(self): + if self.num_arrivals < 1: + raise ValueError( + f"Num arrivals must be at least 1, but got {self.num_arrivals}" + ) + +@dataclasses.dataclass(frozen=True, kw_only=True) +class ClusterBarrier: + collective_axes: tuple[str | tuple[str, ...], ...] num_barriers: int = 1 def get_ref_aval(self) -> AbstractMemoryRef: aval = jax_core.ShapedArray( - [self.num_barriers], BarrierType(self.num_arrivals) + [self.num_barriers], ClusterBarrierType(self.collective_axes) ) return AbstractMemoryRef(aval, SMEM) @@ -450,7 +900,7 @@ def get_ref_aval(self) -> AbstractMemoryRef: "Preinitialized WGMMAAccumulatorRef only supported in pl.run_state." ) return WGMMAAbstractAccumulatorRef( - jax_core.ShapedArray(shape=self.shape, dtype=self.dtype), GPUMemorySpace.REGS + jax_core.ShapedArray(shape=self.shape, dtype=self.dtype), MemorySpace.REGS ) @staticmethod @@ -460,7 +910,7 @@ def init(array): def _wgmma_ref_type_mapping(ref: WGMMAAccumulatorRef): aval = WGMMAAbstractAccumulatorRef( - jax_core.ShapedArray(shape=ref.shape, dtype=ref.dtype), GPUMemorySpace.REGS + jax_core.ShapedArray(shape=ref.shape, dtype=ref.dtype), MemorySpace.REGS ) return aval, ref._init state_types._ref_type_aval_mappings[WGMMAAccumulatorRef] = _wgmma_ref_type_mapping @@ -472,11 +922,12 @@ class WGMMAAbstractAccumulatorRef(AbstractMemoryRef): def __repr__(self) -> str: return f'Accumulator{{{self.inner_aval.str_short()}}}' - def update_weak_type(self, weak_type): - return _as_accum(super().update_weak_type(weak_type)) - def update(self, inner_aval=None, memory_space=None): - return _as_accum(super().update(inner_aval=None, memory_space=None)) + ref = super().update(inner_aval, memory_space) + return WGMMAAbstractAccumulatorRef( + inner_aval=ref.inner_aval, + memory_space=ref.memory_space, + ) def _getitem(self, tracer, idx): from jax._src.pallas.mosaic_gpu.primitives import wgmma_accumulator_deref # pytype: disable=import-error @@ -488,35 +939,61 @@ def _getitem(self, tracer, idx): return arr -def _as_accum(ref) -> WGMMAAbstractAccumulatorRef: - return WGMMAAbstractAccumulatorRef( - inner_aval=ref.inner_aval, - memory_space=ref.memory_space, # pytype: disable=attribute-error - ) +class AbstractTMEMRef(AbstractMemoryRef): + __slots__ = ["inner_aval", "memory_space", "packed", "collective"] + + def __init__(self, inner_aval, memory_space, packed, collective): + super().__init__(inner_aval, memory_space) + self.packed = packed + self.collective = collective + + def __repr__(self) -> str: + return f'TMEM({self.inner_aval.str_short()},packed={self.packed})' + + def update(self, inner_aval=None, memory_space=None): + ref = super().update(inner_aval, memory_space) + return AbstractTMEMRef( + ref.inner_aval, ref.memory_space, self.packed, self.collective + ) _WARPGROUP_AXIS_NAME = object() @dataclasses.dataclass(frozen=True, kw_only=True) -class GPUMesh: - grid: tuple[int, ...] = () - cluster: tuple[int, ...] = () +class Mesh: + grid: Sequence[int] = () + grid_names: Sequence[str] = () + cluster: Sequence[int] = () + cluster_names: Sequence[str] = () # Those are NOT CUDA threads. On Hopper they correspond to warpgroups. num_threads: int | None = None - axis_names: tuple[str, ...] 
= () + thread_name: str | None = None def __post_init__(self): - if len(self.axis_names) != len(self.grid) + (self.num_threads is not None): - raise ValueError("Need as many axis names as grid dimensions + warp groups") + if len(self.cluster) > 3: + raise ValueError(f"cluster= must be at most 3D, got {self}.") + if len(self.grid_names) != len(self.grid): + raise ValueError( + f"grid_names must have the same length as grid, got {self}." + ) + if len(self.cluster_names) != len(self.cluster): + raise ValueError( + f"cluster_names must have the same length as cluster, got {self}." + ) + if (self.thread_name is None) != (self.num_threads is None): + raise ValueError( + "num_threads and thread_name must be either both set or both None," + f" got {self}" + ) if self.num_threads is not None and self.num_threads > 2048 // 128: raise ValueError( "Requested too many CUDA threads per block. Each Mosaic thread" " corresponds to 128 CUDA threads." ) - if self.cluster: - raise NotImplementedError( - "Pallas/MosaicGPU does not support clusters yet." - ) + object.__setattr__(self, "grid", tuple(self.grid)) + object.__setattr__(self, "grid_names", tuple(self.grid_names)) + object.__setattr__(self, "cluster", tuple(self.cluster)) + object.__setattr__(self, "cluster_names", tuple(self.cluster_names)) @property def backend(self) -> str: @@ -527,20 +1004,40 @@ def shape(self) -> collections.OrderedDict[object, int]: pairs: Iterable[tuple[object, int]] if self.num_threads is not None: pairs = zip( - self.axis_names, (*self.grid, *self.cluster, self.num_threads) + (*self.grid_names, *self.cluster_names, self.thread_name), + (*self.grid, *self.cluster, self.num_threads), ) else: - pairs = tuple( - zip( - (*self.axis_names, _WARPGROUP_AXIS_NAME), - (*self.grid, *self.cluster, 1), - ) + pairs = zip( + (*self.grid_names, *self.cluster_names), + (*self.grid, *self.cluster), ) return collections.OrderedDict(pairs) def discharges_effect(self, effect: jax_core.Effect): return effect is _wgmma_pipeline_effect or effect is _memory_effect +@dataclasses.dataclass(frozen=True, kw_only=True) +class WarpMesh: + """Represents a mesh over individual warps within a warpgroup. + + When used in conjunction with `core_map`, the warp ID will be visible + within the body of the wrapped scope by querying `lax.axis_index` with + the specified axis name. 
+ """ + + _NUM_WARPS_PER_WARPGROUP: ClassVar[int] = 4 + axis_name: str + + @property + def shape(self): + return collections.OrderedDict([ + (self.axis_name, self._NUM_WARPS_PER_WARPGROUP), + ]) + + def discharges_effect(self, effect: jax_core.Effect): + del effect + return False def _gpu_mesh_discharge_rule( in_avals, @@ -554,17 +1051,15 @@ def _gpu_mesh_discharge_rule( cost_estimate, name, ): - if not isinstance(mesh, GPUMesh): - raise TypeError(f"Mesh must be a GPUMesh, got {type(mesh)}") - if mesh.cluster: - raise NotImplementedError - if compiler_params and not isinstance(compiler_params, GPUCompilerParams): + if not isinstance(mesh, Mesh): + raise TypeError(f"Mesh must be a `plgpu.Mesh`, got {type(mesh)}") + if compiler_params and not isinstance(compiler_params, CompilerParams): raise TypeError( - "Compiler params must be a GPUCompilerParams, got" + "Compiler params must be a `plgpu.CompilerParams`, got" f" {type(compiler_params)}" ) if not compiler_params: - compiler_params = GPUCompilerParams() + compiler_params = CompilerParams() return pallas_core.default_mesh_discharge_rule( in_avals, out_avals, @@ -576,10 +1071,11 @@ def _gpu_mesh_discharge_rule( interpret=interpret, cost_estimate=cost_estimate, name=name, + memory_space=GMEM, ) -pallas_core._core_map_mesh_rules[GPUMesh] = _gpu_mesh_discharge_rule +pallas_core._core_map_mesh_rules[Mesh] = _gpu_mesh_discharge_rule class MemoryEffect(jax_core.Effect): diff --git a/jax/_src/pallas/mosaic_gpu/helpers.py b/jax/_src/pallas/mosaic_gpu/helpers.py new file mode 100644 index 000000000000..939f3d0382e7 --- /dev/null +++ b/jax/_src/pallas/mosaic_gpu/helpers.py @@ -0,0 +1,87 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Helpers for Pallas Mosaic GPU kernels.""" + +from collections.abc import Callable, Hashable, Sequence +import math +from typing import TypeVar + +import jax +from jax import lax + +_T = TypeVar("_T") + + +def nd_loop( + grid: Sequence[int], + *, + collective_axes: Sequence[Hashable] | Hashable, +) -> Callable[[Callable[[Sequence[jax.Array]], None]], None]: + """A loop over a multi-dimensional grid partitioned along the given axes. + + For example, if ``collective_axes`` is ``"x"`` with :func:`lax.axis_size` + equal to 4 and the grid is (2, 3), the implementation would produce the + following iteration order + + loop step index axis index + + 0 (0, 0) 0 + 1 (0, 1) 1 + 2 (0, 2) 2 + 3 (1, 0) 3 + 4 (1, 1) 0 + 5 (1, 2) 1 + + which comes from partitioning the flat iteration space into chunks in an + interleaved fashion wrt the ``"x"`` axis index. + + Note that in the example the total number of loop steps is not divisible + by the axis size of ``"x"``, and thus for some ``"x"`` axis indices the + loop will do one iteration less. + + axis index indices + + 0 (0, 0), (1, 1) + 1 (0, 1), (1, 2) + 2 (0, 2) + 3 (1, 0) + + See also: + - :func:`jax.experimental.pallas.loop`: A loop over a single dimension. 
+ """ + axis_index = lax.axis_index(collective_axes) + axis_size = lax.axis_size(collective_axes) + grid_size = math.prod(grid) + + def decorator(body): + def wrapper(step, _): + step = step * axis_size + axis_index + # The loop below is conceptually ``jnp.unravel_index``, but it uses + # ``lax`` APIs instead of ``jax.numpy`` to minimize the number of + # primitives used. + index = [] + for grid_dim in reversed(grid): + grid_dim = lax.convert_element_type(grid_dim, step.dtype) + index.append(lax.rem(step, grid_dim)) + step = lax.div(step, grid_dim) + index.reverse() + return body(tuple(index)) + + upper = lax.div(grid_size, axis_size) + lax.convert_element_type( + axis_index < grid_size % axis_size, axis_index.dtype + ) + return lax.fori_loop(0, upper, wrapper, None) + + return decorator diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index 6b06e6b7dfc2..d17d8aff6801 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -17,26 +17,34 @@ from __future__ import annotations import collections -from collections.abc import Callable, Hashable, MutableMapping, MutableSequence, Sequence +from collections.abc import Callable, Hashable, Iterable, MutableMapping, MutableSequence, Sequence import contextlib import dataclasses import functools +import itertools import math -from typing import Any, Protocol, cast +import operator +from typing import Any, Protocol, cast, TypeVar, Union import jax from jax import api_util from jax import lax +from jax._src import checkify from jax._src import core as jax_core +from jax._src import dtypes from jax._src import linear_util as lu +from jax._src import mesh as mesh_lib from jax._src import pjit from jax._src import source_info_util +from jax._src import tree_util from jax._src import util from jax._src.interpreters import mlir from jax._src.interpreters import partial_eval as pe from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import arith as arith_dialect +from jax._src.lib.mlir.dialects import cf as cf_dialect from jax._src.lib.mlir.dialects import gpu as gpu_dialect +from jax._src.lib.mlir.dialects import llvm as llvm_dialect from jax._src.lib.mlir.dialects import math as math_dialect from jax._src.lib.mlir.dialects import memref as memref_dialect from jax._src.lib.mlir.dialects import nvvm as nvvm_dialect @@ -46,6 +54,7 @@ from jax._src.pallas import pallas_call from jax._src.pallas import primitives from jax._src.pallas import utils as pallas_utils +from jax._src.pallas import helpers as pallas_helpers from jax._src.pallas.mosaic_gpu import core as gpu_core from jax._src.state import discharge from jax._src.state import indexing @@ -56,6 +65,7 @@ import jax.experimental.mosaic.gpu as mgpu from jax.experimental.mosaic.gpu import core as mgpu_core from jax.experimental.mosaic.gpu import profiler as mgpu_profiler +from jax.experimental.mosaic.gpu import tcgen05 from jax.experimental.mosaic.gpu import utils as mgpu_utils import jax.numpy as jnp import numpy as np @@ -70,47 +80,58 @@ partial = functools.partial SMEM = gpu_core.SMEM -# We align all our SMEM allocations to 1024 bytes. TMA and WGMMA are very -# sensitive to alignment and while this is quite conservative, it gets the job -# done. We should make this more refined in the future. 
-_SMEM_ALIGNMENT = 1024 WARPGROUP_SIZE = 128 +RefOrTmemType = TypeVar("RefOrTmemType", bound=Union[ir.Value, tcgen05.TMEMRef]) -def _align_to(x: int, alignment: int): - if (rem := x % alignment): - return x + alignment - rem - return x - -@dataclasses.dataclass(frozen=True) +@dataclasses.dataclass(frozen=True, kw_only=True) class ResourceEstimatorContext: - thread_semantics: mgpu.ThreadSemantics + axis_names: _AxisNames + lowering_semantics: mgpu.LoweringSemantics @property def arrival_multiplier(self) -> int: return ( WARPGROUP_SIZE - if self.thread_semantics == mgpu.ThreadSemantics.Lane + if self.lowering_semantics == mgpu.LoweringSemantics.Lane else 1 ) +AnyBarrier = mgpu.Barrier | mgpu.ClusterBarrier + + @dataclasses.dataclass(kw_only=True, frozen=True) class Resources: smem_scratch_bytes: int = 0 - barrier_counts: collections.Counter[mgpu.Barrier] = dataclasses.field( + tmem_scratch_cols: int = 0 + tmem_collective_scratch_cols: int = 0 + barrier_counts: collections.Counter[AnyBarrier] = dataclasses.field( default_factory=collections.Counter ) + gmem_semaphores: int = 0 def __post_init__(self): object.__setattr__( self, "smem_scratch_bytes", - _align_to(self.smem_scratch_bytes, _SMEM_ALIGNMENT), + gpu_core.align_to(self.smem_scratch_bytes, gpu_core.SMEM_ALIGNMENT), + ) + + # TMEM must be allocated in 128x8 chunks. + object.__setattr__( + self, + "tmem_scratch_cols", + gpu_core.align_to(self.tmem_scratch_cols, 8), + ) + object.__setattr__( + self, + "tmem_collective_scratch_cols", + gpu_core.align_to(self.tmem_collective_scratch_cols, 8), ) @property - def barriers(self) -> Sequence[mgpu.Barrier]: + def barriers(self) -> Sequence[AnyBarrier]: return list(self.barrier_counts.elements()) def __add__(self, other: Resources) -> Resources: @@ -120,7 +141,11 @@ def __add__(self, other: Resources) -> Resources: # we will allocate two barriers, even though one would be enough. return Resources( smem_scratch_bytes=self.smem_scratch_bytes + other.smem_scratch_bytes, + tmem_scratch_cols=self.tmem_scratch_cols + other.tmem_scratch_cols, + tmem_collective_scratch_cols=self.tmem_collective_scratch_cols + + other.tmem_collective_scratch_cols, barrier_counts=self.barrier_counts + other.barrier_counts, + gmem_semaphores=self.gmem_semaphores + other.gmem_semaphores, ) def __or__(self, other: Resources) -> Resources: @@ -128,7 +153,13 @@ def __or__(self, other: Resources) -> Resources: smem_scratch_bytes=max( self.smem_scratch_bytes, other.smem_scratch_bytes ), + tmem_scratch_cols=max(self.tmem_scratch_cols, other.tmem_scratch_cols), + tmem_collective_scratch_cols=max( + self.tmem_collective_scratch_cols, + other.tmem_collective_scratch_cols, + ), barrier_counts=self.barrier_counts | other.barrier_counts, + gmem_semaphores=max(self.gmem_semaphores, other.gmem_semaphores), ) @@ -158,11 +189,18 @@ def _estimate_resources( rs = Resources(smem_scratch_bytes=0) for eqn in jaxpr.eqns: # TODO(slebedev): Add support for other primitives, notably control flow. - rule = _resource_estimators.get(eqn.primitive) - if rule is None: - # Assume that unsupported primitives are neutral wrt resource usage. + if rule := _resource_estimators.get(eqn.primitive): + rs |= rule(ctx, *(invar.aval for invar in eqn.invars), **eqn.params) continue - rs |= rule(ctx, *(invar.aval for invar in eqn.invars), **eqn.params) + # Assume that unsupported primitives are neutral wrt resource usage, + # unless they have a jaxpr in their params. 
+ if any( + isinstance(v, (jax_core.Jaxpr, jax_core.ClosedJaxpr)) + for v in eqn.params.values() + ): + raise NotImplementedError( + f"Resource estimation does not support {eqn.primitive}" + ) return rs @@ -170,7 +208,7 @@ def _estimate_resources( @_register_resource_estimator(lax.cond_p) def _cond_resource_estimator( ctx: ResourceEstimatorContext, *args, branches -) -> int: +) -> Resources: del args # Unused. return functools.reduce( lambda a, b: a | b, @@ -181,7 +219,7 @@ def _cond_resource_estimator( @_register_resource_estimator(lax.scan_p) def _scan_resource_estimator( ctx: ResourceEstimatorContext, *args, jaxpr: jax_core.ClosedJaxpr, **params -) -> int: +) -> Resources: del args, params # Unused. return _estimate_resources(ctx, jaxpr) @@ -193,64 +231,184 @@ def _while_resource_estimator( cond_jaxpr: jax_core.ClosedJaxpr, body_jaxpr: jax_core.ClosedJaxpr, **params, -) -> int: +) -> Resources: del args, params # Unused. return _estimate_resources(ctx, cond_jaxpr) | _estimate_resources( ctx, body_jaxpr ) +@_register_resource_estimator(pjit.pjit_p) +def _pjit_resource_estimator( + ctx: ResourceEstimatorContext, + *args, + jaxpr: jax_core.ClosedJaxpr, + **params, +) -> Resources: + del args, params # Unused. + return _estimate_resources(ctx, jaxpr) + + +@_register_resource_estimator(pallas_core.core_map_p) +def _core_map_resource_estimator( + ctx: ResourceEstimatorContext, + *args, + jaxpr: jax_core.ClosedJaxpr, + **params, +) -> Resources: + del args, params # Unused. + return _estimate_resources(ctx, jaxpr) + + +@_register_resource_estimator(discharge.run_state_p) +def _run_state_resource_estimator( + ctx: ResourceEstimatorContext, *args, jaxpr: jax_core.Jaxpr, **params +) -> Resources: + del args, params # Unused. + return _estimate_resources(ctx, jaxpr) + + @_register_resource_estimator(primitives.run_scoped_p) def _run_scoped_resource_estimator( - ctx: ResourceEstimatorContext, *consts, jaxpr: jax_core.Jaxpr -) -> int: + ctx: ResourceEstimatorContext, + *consts, + jaxpr: jax_core.Jaxpr, + collective_axes, +) -> Resources: + del collective_axes # Unused. + + # NOTE: This rule assumes that the allocation happens collectively, although + # it can't be checked here due to limited context. We check this in the actual + # lowering rule. del consts # Unused. rs = Resources() for v in jaxpr.invars: aval = v.aval if isinstance(aval.dtype, gpu_core.BarrierType): + multiplier = 1 if aval.dtype.for_tensor_core else ctx.arrival_multiplier rs += Resources( barrier_counts=collections.Counter([ mgpu.Barrier( - aval.dtype.num_arrivals * ctx.arrival_multiplier, *aval.shape + aval.dtype.num_arrivals * multiplier, *aval.shape ) ]) ) - else: + elif isinstance(aval.dtype, gpu_core.ClusterBarrierType): + collective_dims = jax.tree.map( + lambda axis: _resolve_cluster_axis(ctx.axis_names, axis), + aval.dtype.collective_axes, + ) + rs += Resources( + barrier_counts=collections.Counter( + [mgpu.ClusterBarrier(collective_dims, *aval.shape)] + ) + ) + elif aval.memory_space == gpu_core.TMEM: + if len(aval.shape) != 2: + raise ValueError(f"TMEM allocations must be 2D. Got {aval.shape}") + if aval.shape[0] % tcgen05.TMEM_ROWS != 0: + raise ValueError( + f"TMEM shape[0] must be a multiple of 128. 
Got {aval.shape[0]}.") + if aval.packed: + packing = 4 // aval.dtype.itemsize + else: + packing = 1 + layout = tcgen05._infer_tmem_layout(aval.shape, packing=packing) + cols_used = layout.cols_in_shape(aval.shape) + cols_used = tcgen05._alloc_ncols(cols_used, exact=False) + if aval.collective: + rs += Resources(tmem_collective_scratch_cols=cols_used) + else: + rs += Resources(tmem_scratch_cols=cols_used) + elif aval.memory_space == gpu_core.SMEM: rs += Resources( smem_scratch_bytes=math.prod(aval.shape) * aval.dtype.itemsize ) + elif aval.memory_space == gpu_core.REGS: + # Don't need to allocate anything. + pass + elif aval.memory_space == gpu_core.GMEM and jnp.issubdtype(aval.dtype, pallas_core.semaphore): + rs += Resources(gmem_semaphores=math.prod(aval.shape)) + else: + raise NotImplementedError( + f"Unsupported memory space: {aval.memory_space}") return rs + _estimate_resources(ctx, jaxpr) @_register_resource_estimator(lax.reduce_sum_p) def _reduce_sum_resource_estimator( ctx: ResourceEstimatorContext, x_aval: jax_core.ShapedArray, *, axes -) -> int: +) -> Resources: del ctx, axes # Unused. # We don't need shmem for some reductons, but it depends on the layout, so we # conservatively request some scratch space. return Resources(smem_scratch_bytes=4 * x_aval.dtype.itemsize) +@dataclasses.dataclass(frozen=True) +class _AxisNames: + grid: Sequence[Hashable] + cluster: Sequence[Hashable] = () + wg: Hashable | None = None + + def __iter__(self) -> Iterable[Hashable]: + return itertools.chain( + self.grid, self.cluster, [self.wg] if self.wg is not None else [] + ) + + +AnyBarrierRef = ( + mgpu.BarrierRef | mgpu.DialectBarrierRef | mgpu.CollectiveBarrierRef +) + + @dataclasses.dataclass class ModuleContext: name: str - grid_names: Sequence[Hashable] | None + axis_names: _AxisNames program_ids: Sequence[ir.Value] | None approx_math: bool - single_wg_lane_predicate: ir.Value + single_wg_lane_predicate: ir.Value | None + single_warp_lane_predicate: ir.Value | None smem_requested_bytes: int smem_used_bytes: int - runtime_barriers: MutableMapping[ - mgpu.Barrier, MutableSequence[mgpu.BarrierRef] - ] + tmem_requested_cols: int + tmem_used_cols: int + tmem_base_ptr: ir.Value + tmem_collective_requested_cols: int + tmem_collective_used_cols: int + tmem_collective_base_ptr: ir.Value + gmem_used_semaphores: int + gmem_semaphore_base_ptr: ir.Value | None + runtime_barriers: MutableMapping[AnyBarrier, MutableSequence[AnyBarrierRef]] name_stack: source_info_util.NameStack traceback_caches: mlir.TracebackCaches squashed_dims: tuple[int, ...] - thread_semantics: mgpu.ThreadSemantics + lowering_semantics: mgpu.LoweringSemantics + primitive_semantics: gpu_core.PrimitiveSemantics + mesh: mesh_lib.Mesh | None + # See the documentation of unsafe_no_auto_barriers in CompilerParams. + auto_barriers: bool + warp_axis_name: str | None = None - def reserve_barrier(self, barrier: mgpu.Barrier) -> mgpu.BarrierRef: + @property + def single_lane_predicate(self) -> ir.Value: + """Returns a predicate that is True for a single lane within the current + thread semantics. 
+ """ + assert self.lowering_semantics == mgpu.LoweringSemantics.Lane + match self.primitive_semantics: + case gpu_core.PrimitiveSemantics.Warpgroup: + return self.single_wg_lane_predicate + case gpu_core.PrimitiveSemantics.Warp: + return self.single_warp_lane_predicate + case _: + raise ValueError(f"Unknown semantics: {self.primitive_semantics}") + + @contextlib.contextmanager + def reserve_barrier( + self, barrier: mgpu.Barrier + ) -> mgpu.BarrierRef | mgpu.DialectBarrierRef | mgpu.CollectiveBarrierRef: """Reserves a barrier. Raises: @@ -259,7 +417,65 @@ def reserve_barrier(self, barrier: mgpu.Barrier) -> mgpu.BarrierRef: available = self.runtime_barriers.get(barrier, []) if not available: raise RuntimeError(f"Barrier {barrier} is already reserved") - return available.pop() + barrier = available.pop() + yield barrier + available.append(barrier) + + @contextlib.contextmanager + def reserve_semaphores( + self, shape: tuple[int, ...] + ): + allocated_sems = math.prod(shape) + ref = mgpu.memref_slice( + self.gmem_semaphore_base_ptr, + mgpu.ds(self.gmem_used_semaphores, allocated_sems), + ) + ref = mgpu.memref_reshape(ref, shape) + self.gmem_used_semaphores += allocated_sems + yield ref + # TODO: In debug mode verify the values of all semaphores are again 0 + self.gmem_used_semaphores -= allocated_sems + + @contextlib.contextmanager + def alloc_tmem( + self, + struct: jax.ShapeDtypeStruct, + *, + layout: tcgen05.TMEMLayout | None = None, + collective: bool = False, + packed: bool = False, + exact_cols: bool = False + ) -> ir.Value: + if packed: + packing = 4 // struct.dtype.itemsize + else: + packing = 1 + if layout is None: + layout = tcgen05._infer_tmem_layout(struct.shape, packing=packing) + unpadded_cols_used = layout.cols_in_shape(struct.shape) + cols_used = tcgen05._alloc_ncols(unpadded_cols_used, exact_cols) + if collective: + off = arith_dialect.addi( + self.tmem_collective_base_ptr, + _i32_constant(self.tmem_collective_used_cols), + ) + else: + off = arith_dialect.addi( + self.tmem_base_ptr, _i32_constant(self.tmem_used_cols) + ) + tmem_ref = tcgen05.TMEMRef( + address=off, + shape=struct.shape, + dtype=mgpu_utils.dtype_to_ir_type(struct.dtype), + layout=layout) + if collective: + self.tmem_collective_used_cols += cols_used + yield tmem_ref + self.tmem_collective_used_cols -= cols_used + else: + self.tmem_used_cols += cols_used + yield tmem_ref + self.tmem_used_cols -= cols_used # TODO(cperivol): Only return the shapes and figure out the sizes when freeing. @contextlib.contextmanager @@ -286,13 +502,13 @@ def scratch_view( smem = ir.Attribute.parse("#gpu.address_space") i8 = ir.IntegerType.get_signless(8) i32 = ir.IntegerType.get_signless(32) - if self.thread_semantics == mgpu.ThreadSemantics.Lane: + if self.lowering_semantics == mgpu.LoweringSemantics.Lane: smem_base = gpu_dialect.dynamic_shared_memory( ir.MemRefType.get((mgpu_utils.DYNAMIC,), i8, memory_space=smem) ) views = [] off = initial_used_bytes = self.smem_used_bytes - assert off % _SMEM_ALIGNMENT == 0 + assert off % gpu_core.SMEM_ALIGNMENT == 0 for s in structs: scratch_ty = ir.MemRefType.get( s.shape, @@ -302,7 +518,7 @@ def scratch_view( # The below code emission relies on the assumption that the first scratch # operand provided by Mosaic GPU always begins at the beginning of # dynamic SMEM. Mosaic GPU is expected to uphold that invariant. 
- if self.thread_semantics == mgpu.ThreadSemantics.Lane: + if self.lowering_semantics == mgpu.LoweringSemantics.Lane: view = memref_dialect.view( scratch_ty, smem_base, _as_index(off), [] ) @@ -310,11 +526,12 @@ def scratch_view( view = mgpu.dialect.slice_smem(scratch_ty, mgpu_utils.c(off, i32)) views.append(view) - off += _align_to( - math.prod(s.shape) * jnp.dtype(s.dtype).itemsize, _SMEM_ALIGNMENT + off += gpu_core.align_to( + math.prod(s.shape) * jnp.dtype(s.dtype).itemsize, + gpu_core.SMEM_ALIGNMENT, ) assert off <= self.smem_requested_bytes, "Ran out of scoped SMEM" - assert off % _SMEM_ALIGNMENT == 0 + assert off % gpu_core.SMEM_ALIGNMENT == 0 self.smem_used_bytes = off yield views @@ -333,7 +550,10 @@ class LoweringRuleContext: @property def estimator_ctx(self) -> ResourceEstimatorContext: - return ResourceEstimatorContext(thread_semantics=self.module_ctx.thread_semantics) + return ResourceEstimatorContext( + axis_names=self.module_ctx.axis_names, + lowering_semantics=self.module_ctx.lowering_semantics, + ) @dataclasses.dataclass(frozen=True) @@ -341,8 +561,9 @@ class LoweringResult: module: ir.Module grid: tuple[int, ...] block: tuple[int, ...] - out_structs: tuple[jax.ShapeDtypeStruct, ...] + new_out_shapes: tuple[jax.ShapeDtypeStruct, ...] # Does not include gmem scratch! profiler_context: ProfilerContext | None + gmem_scratch_shapes: tuple[jax.ShapeDtypeStruct, ...] @dataclasses.dataclass(frozen=True) @@ -366,11 +587,13 @@ def _eval_index_map( ) result = [] for i, b in zip(block_indices, block_mapping.block_shape): - if b is pallas_core.mapped: - result.append(i) - else: - # TODO(slebedev): Use a type-agnostic multiplication wrapper. - result.append(arith_dialect.muli(_as_index(i), _as_index(b))) + match b: + case pallas_core.Squeezed() | pallas_core.Element(): + result.append(i) + case pallas_core.Blocked(): + result.append(arith_dialect.muli(_as_index(i), _as_index(b))) + case _: + raise ValueError(f"Unsupported block dim type: {b}") return tuple(result) @@ -387,7 +610,7 @@ def err_details(bm: pallas_core.BlockMapping) -> str: f" and index_map {bm.index_map_jaxpr.jaxpr} in" f" memory space {bm.transformed_block_aval.memory_space}." " See details at" - " https://jax.readthedocs.io/en/latest/pallas/grid_blockspec.html#pallas-blockspec." + " https://docs.jax.dev/en/latest/pallas/grid_blockspec.html#pallas-blockspec." ) for bm in block_mappings: @@ -402,7 +625,7 @@ def err_details(bm: pallas_core.BlockMapping) -> str: + err_details(bm) ) - if not isinstance(bm.indexing_mode, pallas_core.Blocked): + if any(isinstance(b, pallas_core.Element) for b in bm.block_shape): raise NotImplementedError( "Only Blocked indexing mode is supported in Mosaic GPU lowering.\n\n" + err_details(bm) @@ -439,20 +662,20 @@ def index_map(*indices): ) return eval_index_map(*new_indices) - return gpu_core.GPUBlockSpec( + return gpu_core.BlockSpec( bm.block_shape, index_map, memory_space=bm.transformed_block_aval.memory_space, - indexing_mode=bm.indexing_mode, transforms=bm.transforms, ) def lower_pipelined_jaxpr_to_module( grid_mapping: pallas_core.GridMapping, - mesh: pallas_core.Mesh | None, + gpu_mesh: pallas_core.Mesh | None, + jax_mesh: mesh_lib.Mesh | None, jaxpr: jax_core.Jaxpr, - compiler_params: dict[str, Any], + params: gpu_core.CompilerParams, cost_estimate: pallas_core.CostEstimate | None, ) -> LoweringResult: del cost_estimate # Unused. 
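The ``run_scoped`` estimator above accepts TMEM allocations only when they are 2D with a leading dimension that is a multiple of 128, and sizes them in 8-column chunks. A hedged sketch of requesting such a scratch buffer from user code, assuming the ``plgpu.TMEM`` spelling that mirrors ``gpu_core.TMEM`` used elsewhere in this patch:

    import jax.numpy as jnp
    import jax.experimental.pallas as pl
    import jax.experimental.pallas.mosaic_gpu as plgpu  # assumed public alias

    def scoped_body(acc_tmem):
      # `acc_tmem` is a 128x64 TMEM ref; Blackwell matmul primitives would
      # write their accumulator here.  Left empty in this sketch.
      del acc_tmem

    pl.run_scoped(scoped_body, plgpu.TMEM((128, 64), jnp.float32))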
@@ -474,24 +697,23 @@ def lower_pipelined_jaxpr_to_module( block_mappings, [grid_mapping.num_inputs] ) - if mesh is not None: - assert isinstance(mesh, gpu_core.GPUMesh) - if mesh and mesh.num_threads is not None: - # Last dim corresponds to the warpgroup count. - block = (128 * grid_mapping.grid[-1], 1, 1) - grid = grid_mapping.grid[:-1] + if gpu_mesh: + assert isinstance(gpu_mesh, gpu_core.Mesh) + block = (128 * (gpu_mesh.num_threads or 1), 1, 1) + grid = gpu_mesh.grid + thread_axis = ( + gpu_mesh.thread_name if gpu_mesh.thread_name is not None else () + ) else: block = (128, 1, 1) grid = grid_mapping.grid + thread_axis = () - params = compiler_params.get("mosaic_gpu", {}) - dimension_semantics = params.get("dimension_semantics", None) - if dimension_semantics is None: + if params.dimension_semantics is None: which_parallel = [True] * len(grid) else: - assert len(dimension_semantics) == len(grid) - which_parallel = [ds == "parallel" for ds in dimension_semantics] - del dimension_semantics + assert len(params.dimension_semantics) == len(grid) + which_parallel = [ds == "parallel" for ds in params.dimension_semantics] sequential_grid = tuple( d for axis, d in enumerate(grid) if not which_parallel[axis] @@ -506,28 +728,25 @@ def lower_pipelined_jaxpr_to_module( def ref_for_aval(aval: jax_core.AbstractValue): if isinstance(aval, gpu_core.WGMMAAbstractAccumulatorRef): return gpu_core.WGMMAAccumulatorRef(aval.shape, aval.dtype) + elif isinstance(aval, gpu_core.AbstractTMEMRef): + return gpu_core.TMEM(aval.shape, aval.dtype, packed=aval.packed) elif isinstance(aval, pallas_core.AbstractMemoryRef): return pallas_core.MemoryRef(aval.shape, aval.dtype, aval.memory_space) else: return gpu_core.SMEM(aval.shape, aval.dtype) def pipeline_fn(*refs): - return primitives.run_scoped( + primitives.run_scoped( functools.partial(scoped_pipeline_fn, *refs), - scratch_refs=[ - ref_for_aval(v.aval) - for v in jaxpr.invars[grid_mapping.slice_scratch_ops] - ], + scratch_refs=[ref_for_aval(v.aval) for v in jaxpr.invars[grid_mapping.slice_scratch_ops]], + collective_axes=thread_axis, # scratch_refs are shared across threads ) + return () # ``wrap_init`` does not support functions returning None. def scoped_pipeline_fn(*refs, scratch_refs): - def body_fn(*refs): - grid_env = pallas_core.current_grid_env() - assert grid_env is not None # Set by ``emit_pipeline``. + def body_fn(indices, *refs): program_ids_template = util.merge_lists( - which_parallel, - [grid_axis.index for grid_axis in grid_env], - [None] * sum(which_parallel), + which_parallel, indices, [None] * sum(which_parallel) ) assert len(refs) + len(scratch_refs) == len(jaxpr.invars) return gpu_primitives.jaxpr_call( @@ -545,17 +764,13 @@ def body_fn(*refs): _block_spec_from_block_mapping(bm, which_parallel) for bm in out_block_mappings ], - max_concurrent_steps=params.pop("max_concurrent_steps", 1), - delay_release=params.pop("delay_release", 0), + max_concurrent_steps=params.max_concurrent_steps, + delay_release=params.delay_release, )(*refs) with grid_mapping.trace_env(): new_jaxpr, _, new_consts, () = pe.trace_to_jaxpr_dynamic( - lu.wrap_init( - # ``wrap_init`` does not support functions returning None. 
- lambda *args: pipeline_fn(*args) or (), - debug_info=jaxpr.debug_info, - ), + lu.wrap_init(pipeline_fn, debug_info=jaxpr.debug_info), [ gpu_core.GMEM( bm.array_shape_dtype.shape, bm.array_shape_dtype.dtype @@ -565,37 +780,46 @@ def body_fn(*refs): ) assert not new_consts + axis_names = ( + _AxisNames(gpu_mesh.grid_names, gpu_mesh.cluster_names, gpu_mesh.thread_name) + if gpu_mesh is not None + else _AxisNames(grid_mapping.grid_names or ()) + ) with grid_mapping.trace_env(): return lower_jaxpr_to_module( + jax_mesh, + axis_names, parallel_grid, - grid_mapping.grid_names, block, - mesh.cluster if mesh is not None else (), + gpu_mesh.cluster if gpu_mesh is not None else (), [bm.array_shape_dtype for bm in in_block_mappings], [bm.array_shape_dtype for bm in out_block_mappings], new_jaxpr, - compiler_params, + params, new_consts, ) def lower_jaxpr_to_module( + jax_mesh: mesh_lib.Mesh | None, + axis_names: _AxisNames, grid: Sequence[int], - grid_names: Sequence[str], block: Sequence[int], cluster: Sequence[int], in_shapes: Sequence[jax.ShapeDtypeStruct], out_shapes: Sequence[jax.ShapeDtypeStruct], jaxpr: jax_core.Jaxpr, - compiler_params: dict[str, Any], + params: gpu_core.CompilerParams, consts=(), ) -> LoweringResult: debug_info = jaxpr.debug_info - params = compiler_params.get("mosaic_gpu", {}) - approx_math = params.get("approx_math", False) - thread_semantics = params.get( - "thread_semantics", mgpu_core.ThreadSemantics.Lane - ) + approx_math = params.approx_math + lowering_semantics = params.lowering_semantics + + if len(cluster) < 3: + cluster = cluster + (1,) * (3 - len(cluster)) + else: + assert len(cluster) == 3 if len(grid) <= 3: squashed_dims = () @@ -606,85 +830,186 @@ def lower_jaxpr_to_module( squashed_dims = grid[:-2] parallel_grid = (math.prod(grid[:-2]), *grid[-2:]) + rs = _estimate_resources( + ResourceEstimatorContext( + axis_names=axis_names, lowering_semantics=lowering_semantics + ), + jaxpr, + ) + def body(launch_ctx: mgpu.LaunchContext, *buffers: ir.Value): - *buffers_gmem, (runtime_smem, runtime_barriers) = buffers + *buffers_gmem, ( + runtime_smem, + runtime_barriers, + runtime_tmem, + runtime_tmem_collective, + ) = buffers + gmem_semaphores = None + if rs.gmem_semaphores: + # Extract the semaphores local to the current block. + index = ir.IndexType.get() + block_idx = arith_dialect.index_castui(index, mgpu_utils.block_idx()) + gmem_semaphores = mgpu.memref_slice( + buffers_gmem[-1], + mgpu.ds( + arith_dialect.muli( + block_idx, arith_dialect.constant(index, rs.gmem_semaphores) + ), + rs.gmem_semaphores, + ), + ) + # The semaphore buffer is an aliased input/output, so we need to skip it twice. + buffers_gmem = buffers_gmem[:len(in_shapes)] + buffers_gmem[-len(out_shapes) - 1:-1] grouped_barriers = collections.defaultdict(list) for barrier, barrier_ref in zip(rs.barriers, runtime_barriers): grouped_barriers[barrier].append(barrier_ref) + if runtime_tmem is not None: + tmem_cols = math.prod(runtime_tmem.shape) // tcgen05.TMEM_ROWS + else: + tmem_cols = 0 + if runtime_tmem_collective is not None: + tmem_collective_cols = ( + math.prod(runtime_tmem_collective.shape) // tcgen05.TMEM_ROWS + ) + else: + tmem_collective_cols = 0 + + if lowering_semantics == mgpu.LoweringSemantics.Lane: + single_wg_lane_predicate = mgpu.single_thread_predicate( + scope=mgpu.ThreadSubset.WARPGROUP) + single_warp_lane_predicate = mgpu.single_thread_predicate( + scope=mgpu.ThreadSubset.WARP) + else: # Warpgroup semantics do not have a single lane predicate. 
+ single_wg_lane_predicate = None + single_warp_lane_predicate = None + module_ctx = ModuleContext( mlir.sanitize_name(debug_info.func_name), - grid_names, + axis_names, [_program_id(axis, squashed_dims) for axis in range(len(grid))], approx_math, - mgpu.single_thread_predicate(per_block=False), + single_wg_lane_predicate, + single_warp_lane_predicate, smem_requested_bytes=math.prod(ir.MemRefType(runtime_smem.type).shape), smem_used_bytes=0, + tmem_requested_cols=tmem_cols, + tmem_used_cols=0, + tmem_base_ptr=runtime_tmem.address if runtime_tmem else None, + tmem_collective_requested_cols=tmem_collective_cols, + tmem_collective_used_cols=0, + tmem_collective_base_ptr=runtime_tmem_collective.address + if runtime_tmem_collective + else None, + gmem_used_semaphores=0, + gmem_semaphore_base_ptr=gmem_semaphores, runtime_barriers=grouped_barriers, name_stack=source_info_util.NameStack(), traceback_caches=mlir.TracebackCaches(), squashed_dims=squashed_dims, - thread_semantics=thread_semantics, + lowering_semantics=lowering_semantics, + primitive_semantics=gpu_core.PrimitiveSemantics.Warpgroup, + mesh=jax_mesh, + auto_barriers=not params.unsafe_no_auto_barriers, ) del runtime_smem, grouped_barriers, runtime_barriers - _ = lower_jaxpr_to_mosaic_gpu( module_ctx, launch_ctx, jaxpr, buffers_gmem, consts ) - rs = _estimate_resources(ResourceEstimatorContext(thread_semantics), jaxpr) - smem_scratch_bytes = params.get("smem_scratch_bytes") - if smem_scratch_bytes is None: - smem_scratch_bytes = rs.smem_scratch_bytes + scratch_buffers = [ + jax.ShapeDtypeStruct(shape=[rs.smem_scratch_bytes], dtype=np.int8), + rs.barriers, + ] + if rs.tmem_scratch_cols > 0: + scratch_buffers.append( + mgpu.TMEM( + shape=[tcgen05.TMEM_ROWS, rs.tmem_scratch_cols], + dtype=np.int32, + collective=False, + ), + ) + else: + scratch_buffers.append(None) + if rs.tmem_collective_scratch_cols > 0: + scratch_buffers.append( + mgpu.TMEM( + shape=[tcgen05.TMEM_ROWS, rs.tmem_collective_scratch_cols], + dtype=np.int32, + collective=True, + ), + ) + else: + scratch_buffers.append(None) prof_ctx = prof_spec = None - if prof_space := params.get("profile_space", 0): + if params.profile_space: # Each range is 2 events, each event is 4 bytes. - prof_spec = mgpu_profiler.ProfilerSpec(prof_space * 2 * 4) - prof_ctx = ProfilerContext(params["profile_dir"], prof_spec) - module, out_structs_gmem, _, launch_ctx, scratch_arr = ( + prof_spec = mgpu_profiler.ProfilerSpec(params.profile_space * 2 * 4) + prof_ctx = ProfilerContext(params.profile_dir, prof_spec) + mgpu_grid = tuple(map(operator.mul, parallel_grid, cluster)) + semaphores_shape = () + if rs.gmem_semaphores: + semaphores_shape = ( + jax.ShapeDtypeStruct( + shape=(math.prod(mgpu_grid) * rs.gmem_semaphores,), dtype=np.int32 + ), + ) + # NOTE: new_out_shapes has out_shapes, then semaphores_shape and + # optionally the profiler buffer. 
+ module, new_out_shapes, _, launch_ctx = ( mgpu_core._lower_as_gpu_kernel( body, - grid=parallel_grid, + grid=mgpu_grid, cluster=cluster, block=block, - in_shapes=in_shapes, - out_shape=out_shapes, - smem_scratch_shape=( - jax.ShapeDtypeStruct(shape=[smem_scratch_bytes], dtype=np.int8), - rs.barriers, - ), + in_shapes=(*in_shapes, *semaphores_shape), + out_shape=(*out_shapes, *semaphores_shape), + inout_shape=(), + smem_scratch_shape=scratch_buffers, + lowering_semantics=lowering_semantics, module_name=mlir.sanitize_name(debug_info.func_name), prof_spec=prof_spec, ) ) - if thread_semantics == mgpu.ThreadSemantics.Warpgroup: + if lowering_semantics == mgpu.LoweringSemantics.Warpgroup: + # We need to run a pass that removes dead-code for which layout inference + # does not work. + pm = mlir.passmanager.PassManager.parse("builtin.module(canonicalize)", module.context) + pm.run(module.operation) + # Run Python lowering passes. The remaining passes will be run in C++ in # jax/jaxlib/mosaic/gpu/custom_call.cc mgpu.infer_layout(module) # pytype: disable=attribute-error + mgpu.infer_transforms(module) # pytype: disable=attribute-error mgpu.lower_mgpu_dialect(module, launch_ctx) # pytype: disable=attribute-error - mgpu_core._initialize_scratch(launch_ctx, scratch_arr) + launch_ctx.scratch.finalize_size() return LoweringResult( - module, parallel_grid, block, out_structs_gmem, prof_ctx + module, parallel_grid, block, new_out_shapes, prof_ctx, semaphores_shape ) mosaic_lowering_rules = { # Lowering rules when using Mosaic GPU lane semantics. - mgpu.ThreadSemantics.Lane: {} , + (mgpu.LoweringSemantics.Lane, gpu_core.PrimitiveSemantics.Warpgroup): {} , + gpu_core.LANExWARP_SEMANTICS: {} , # Lowering rules when using Mosaic GPU warpgroup semantics. - mgpu.ThreadSemantics.Warpgroup: {}, + (mgpu.LoweringSemantics.Warpgroup, + gpu_core.PrimitiveSemantics.Warpgroup): {}, } def register_lowering_rule( - primitive: jax_core.Primitive, thread_semantics: mgpu.ThreadSemantics + primitive: jax_core.Primitive, + lowering_semantics: mgpu.LoweringSemantics, + primitive_semantics: gpu_core.PrimitiveSemantics = gpu_core.PrimitiveSemantics.Warpgroup, ): def deco(fn): - mosaic_lowering_rules[thread_semantics][primitive] = fn + mosaic_lowering_rules[ + (lowering_semantics, primitive_semantics)][primitive] = fn return fn return deco @@ -720,7 +1045,7 @@ def write_env(var: jax_core.Var, val, require_value: bool = True): # TODO(apaszke): Handle other avals (refs, etc.). if isinstance(aval := var.aval, jax_core.ShapedArray): # TODO(apaszke): Clarify the type invariants for lane semantics? - if module_ctx.thread_semantics == mgpu.ThreadSemantics.Warpgroup: + if module_ctx.lowering_semantics == mgpu.LoweringSemantics.Warpgroup: # Shaped arrays must be vectors if and only if their shape is non-empty. # Those with empty shapes should be represented by their scalar type. mlir_dtype = mgpu_utils.dtype_to_ir_type(aval.dtype) @@ -745,8 +1070,11 @@ def write_env(var: jax_core.Var, val, require_value: bool = True): if val.type != mlir_dtype: raise AssertionError(f"Scalar type must match ShapedArray dtype, got: {val.type} != {mlir_dtype}") - foreach(write_env, jaxpr.constvars, consts) - foreach(lambda v, a: write_env(v, a, require_value=False), jaxpr.invars, args) + foreach( + functools.partial(write_env, require_value=False), jaxpr.constvars, consts + ) + foreach(functools.partial(write_env, require_value=False), jaxpr.invars, args) + # TODO(justinfu): Handle transform scopes. 
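The pipelined lowering above now reads its options from a structured params object instead of a ``compiler_params`` dict. A minimal sketch of constructing it, assuming the ``plgpu.CompilerParams`` export introduced by this patch; the field values and the "arbitrary" marker for non-parallel axes are illustrative only.

    import jax.experimental.pallas.mosaic_gpu as plgpu  # assumed public alias

    params = plgpu.CompilerParams(
        # Axes marked "parallel" are lifted into the CUDA grid; the remaining
        # axes are executed sequentially by the emitted pipeline.
        dimension_semantics=("parallel", "arbitrary"),
        max_concurrent_steps=2,
        delay_release=1,
    )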
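Lowering rules are now registered per (lowering semantics, primitive semantics) pair, as shown by the keyed ``mosaic_lowering_rules`` dict above. The sketch below shows the registration pattern using the module-level names defined in this file (``register_lowering_rule``, ``mgpu``, ``gpu_core``); ``my_prim`` is a hypothetical primitive and the rule body is a stub.

    my_prim = jax_core.Primitive("my_prim")  # hypothetical primitive

    @register_lowering_rule(my_prim, mgpu.LoweringSemantics.Lane)
    @register_lowering_rule(my_prim, mgpu.LoweringSemantics.Lane,
                            gpu_core.PrimitiveSemantics.Warp)
    def _my_prim_lowering_rule(ctx: LoweringRuleContext, x):
      # The same body is reused for warpgroup- and warp-level primitive
      # semantics here; primitives that differ per level register two rules.
      del ctx
      return x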
last_local_name_stack: list[str] = [] named_regions = [] @@ -757,10 +1085,13 @@ def write_env(var: jax_core.Var, val, require_value: bool = True): ) loc = mlir._source_info_to_location(module_ctx, eqn.primitive, source_info) with source_info_util.user_context(eqn.source_info.traceback), loc: - if eqn.primitive not in mosaic_lowering_rules[module_ctx.thread_semantics]: + if eqn.primitive not in mosaic_lowering_rules[ + (module_ctx.lowering_semantics, module_ctx.primitive_semantics)]: raise NotImplementedError( "Unimplemented primitive in Pallas Mosaic GPU lowering: " - f"{eqn.primitive.name}. " + f"{eqn.primitive.name} for lowering semantics " + f"{module_ctx.lowering_semantics} and user thread semantics " + f"{module_ctx.primitive_semantics}. " "Please file an issue on https://github.com/jax-ml/jax/issues." ) new_local_name_stack = [scope.name for scope in eqn.source_info.name_stack.stack] @@ -772,7 +1103,9 @@ def write_env(var: jax_core.Var, val, require_value: bool = True): wrapper_stack = contextlib.ExitStack() wrapper_stack.enter_context(launch_ctx.named_region(name)) named_regions.append(wrapper_stack) - rule = mosaic_lowering_rules[module_ctx.thread_semantics][eqn.primitive] + rule = mosaic_lowering_rules[ + (module_ctx.lowering_semantics, module_ctx.primitive_semantics) + ][eqn.primitive] rule_ctx = LoweringRuleContext( module_ctx, launch_ctx, @@ -801,8 +1134,9 @@ def write_env(var: jax_core.Var, val, require_value: bool = True): return map(read_env, jaxpr.outvars) -@register_lowering_rule(primitives.program_id_p, mgpu.ThreadSemantics.Lane) -@register_lowering_rule(primitives.program_id_p, mgpu.ThreadSemantics.Warpgroup) +@register_lowering_rule(primitives.program_id_p, mgpu.LoweringSemantics.Lane) +@register_lowering_rule( + primitives.program_id_p, mgpu.LoweringSemantics.Warpgroup) def _program_id_lowering_rule(ctx: LoweringRuleContext, axis): if ctx.module_ctx.program_ids is None: raise NotImplementedError("pl.program_id() is not supported in this context") @@ -869,8 +1203,9 @@ def lowering_rule(ctx: LoweringRuleContext, *args, **params): return lowering_rule -@register_lowering_rule(primitives.num_programs_p, mgpu.ThreadSemantics.Lane) -@register_lowering_rule(primitives.num_programs_p, mgpu.ThreadSemantics.Warpgroup) +@register_lowering_rule(primitives.num_programs_p, mgpu.LoweringSemantics.Lane) +@register_lowering_rule( + primitives.num_programs_p, mgpu.LoweringSemantics.Warpgroup) def _num_programs_lowering_rule(ctx: LoweringRuleContext, axis): del ctx # Unused. return arith_dialect.index_cast( @@ -879,60 +1214,176 @@ def _num_programs_lowering_rule(ctx: LoweringRuleContext, axis): ) -def _handle_reshaping( - ref: ir.Value, transforms: Sequence[gpu_core.Transform] -) -> tuple[ir.Value, Sequence[gpu_core.Transform]]: - is_trivial_indexer = lambda t: isinstance( - t, indexing.NDIndexer - ) and gpu_core.is_trivial_index(t.indices, t.shape) +def _handle_dtype_bitcast( + ref: ir.Value, src_dtype: ir.Type, dst_dtype: ir.Type +) -> ir.Value: + """Allows bitcasting a SMEM ref from one element type to another. - last_reshaper_idx = next( - reversed([i for i, t in enumerate(transforms) if isinstance(t, RefReshaper)]), - None, - ) - if last_reshaper_idx is None: - return ref, transforms - # Check that before the reshape are only trivial indexes and or - # other reshapes. - # TODO(cperivol): Reshapes should bubble up rather than being - # expected to effectively be the first ref transform. 
- if not all(isinstance(t, RefReshaper) or is_trivial_indexer(t) for t in transforms[:last_reshaper_idx]): + Args: + ref: the reference to bitcast. + src_dtype: the source element type. + dst_dtype: the destination element type. + + Returns: + A bitcasted version of `ref` with element type `dst_dtype`. + + Raises: + ValueError: if the source ref is not in SMEM. + """ + if src_dtype == dst_dtype: + return ref + if src_dtype != ir.IntegerType.get_signless(8): + raise NotImplementedError( + "Data type bitcast is only supported from i8 to other types." + ) + ref_ty = ir.MemRefType(ref.type) + if ref_ty.memory_space != ir.Attribute.parse("#gpu.address_space"): + raise ValueError(f"Only workgroup memory is supported but got {ref}.") + if len(ref_ty.shape) != 1: raise NotImplementedError( - "Reshapes do not compose with other transforms and indexers must be" - f" trivial (transforms: {transforms})" + "Data type bitcast is only supported for 1D arrays." + ) + [stride], _ = ref_ty.get_strides_and_offset() + if stride != 1: + raise ValueError( + "Data type bitcast is only supported for contiguous 1D arrays, but got " + f"stride={stride}." + ) + [shape_bytes] = ref_ty.shape + shape_bitwidth = shape_bytes * 8 + target_bitwidth = mgpu_utils.bitwidth(dst_dtype) + + if shape_bitwidth % target_bitwidth: + raise ValueError( + f"Can not bitcast memory region of size {shape_bitwidth} bits to dtype " + f"with {target_bitwidth} bits." ) - reshaper = cast(RefReshaper, transforms[last_reshaper_idx]) - # Skip all the reshapes and trivial indexes. - return mgpu.memref_reshape(ref, reshaper.shape), transforms[last_reshaper_idx + 1:] + result_type = ir.MemRefType.get( + shape=(shape_bitwidth // target_bitwidth,), + element_type=dst_dtype, + memory_space=ref_ty.memory_space, + ) -def _handle_indexing( + # Do a memref_ptr/ptr_as_memref roundtrip instead of using `memref.view`, + # which refuses to take in our source ref. This is because `memref.view` only + # works on a super restricted set of `memref`s. E.g., it does not work if an + # offset is specified, which can be the case for our SMEM refs. 
+ smem = mgpu_utils.WORKGROUP_NVPTX_ADDRESS_SPACE + ref = mgpu_utils.memref_ptr(ref, memory_space=smem) + return mgpu_utils.ptr_as_memref(ref, result_type, ptr_memory_space=smem) + + +def _extract_aliased_ref( ref: ir.Value, transforms: Sequence[gpu_core.Transform] ) -> tuple[ir.Value, Sequence[gpu_core.Transform]]: - if not transforms: - pass - indexer_idxs = [ - i for i, t in enumerate(transforms) if isinstance(t, indexing.NDIndexer) - ] - if not indexer_idxs: - return ref, transforms - sliced_ref = ref + match transforms: + case ( + gpu_core.ExtractAliasedRef(dtype, transformed_shape, offset), + *other_transforms, + ): + mlir_dtype = mgpu_utils.dtype_to_ir_type(dtype) + ref_bits = math.prod(transformed_shape) * mgpu_utils.bitwidth(mlir_dtype) + if ref_bits % 8: + raise NotImplementedError("Only byte-aligned bitcasts are supported.") + assert offset % gpu_core.SMEM_ALIGNMENT == 0 + ref_bytes = ref_bits // 8 + ref = mgpu.memref_slice(ref, slice(offset, offset + ref_bytes)) + ref = _handle_dtype_bitcast( + ref, + ir.MemRefType(ref.type).element_type, + mgpu_utils.dtype_to_ir_type(dtype), + ) + ref = mgpu.memref_reshape(ref, transformed_shape) + return ref, tuple(other_transforms) + case _: + return ref, transforms + + +def _transform_dtype( + dtype: dtypes.DType, + transforms: Sequence[state_types.Transform], +) -> dtypes.DType: + """Applies `t.transform_dtype` for `t` in `transforms` sequentially on `dtype`.""" + for transform in transforms: + dtype = transform.transform_dtype(dtype) + return dtype + + +def _handle_transforms( + ctx: LoweringRuleContext, + ref: RefOrTmemType, + transforms: Sequence[state_types.Transform], + *, + handle_transposes=True, + handle_reshapes=True, + allow_peer_refs=False, +) -> tuple[RefOrTmemType, Sequence[state_types.Transform]]: + if isinstance(ref, tcgen05.TMEMRef): + mlir_dtype = ref.dtype + else: + # Before we handle other transforms, we resolve any possible leading + # aliasing transform. 
+ ref, transforms = _extract_aliased_ref(ref, transforms) + mlir_dtype = ir.MemRefType(ref.type).element_type + transformed_ref = ref new_transforms = [] - for t in transforms: - if not isinstance(t, indexing.NDIndexer): - new_transforms.append(t) - continue - indexer = cast(indexing.NDIndexer, t) - if indexer.int_indexer_shape: - raise NotImplementedError("int_indexer_shape non-empty") - indices = _ndindexer_indices(indexer) + def _bubble_up(untransform_fn, data): + nonlocal new_transforms new_transforms_rev = [] for t in reversed(new_transforms): - indices, new_t = t.untransform_index(indices) + data, new_t = untransform_fn(t, data) new_transforms_rev.append(new_t) - sliced_ref = mgpu.memref_slice(sliced_ref, indices) + new_transforms = list(reversed(new_transforms_rev)) - return sliced_ref, new_transforms + return data + + peer_device_id = None + for t in transforms: + match t: + case indexing.NDIndexer(): + indexer = cast(indexing.NDIndexer, t) + if indexer.int_indexer_shape: + raise NotImplementedError("int_indexer_shape non-empty") + indices = _ndindexer_indices(indexer) + indices = _bubble_up( + lambda t, idxs: t.untransform_index(mlir_dtype, idxs), indices + ) + if isinstance(transformed_ref, tcgen05.TMEMRef): + transformed_ref = transformed_ref.slice(*indices) + else: + transformed_ref = mgpu.memref_slice(transformed_ref, indices) + case gpu_core.TransposeRef(perm) if handle_transposes: + perm = _bubble_up(lambda t, p: t.untransform_transpose(p), + perm) + if isinstance(transformed_ref, tcgen05.TMEMRef): + raise ValueError("TMEM transpose not allowed.") + transformed_ref = mgpu.memref_transpose(transformed_ref, perm) + case RefReshaper(dtype=dtype, shape=shape) if handle_reshapes: + shape = _bubble_up( + lambda t, p: t.untransform_reshape(dtype, p), # pylint: disable=cell-var-from-loop + shape) + if isinstance(transformed_ref, tcgen05.TMEMRef): + raise ValueError("TMEM reshape not allowed.") + transformed_ref = mgpu.memref_reshape(transformed_ref, shape) + case gpu_core.PeerMemRef(device_id, device_id_type): + if device_id_type != primitives.DeviceIdType.LOGICAL: + raise NotImplementedError( + "Only logical device IDs are supported for peer memrefs." + ) + peer_device_id = device_id + case _: + new_transforms.append(t) + if peer_device_id is not None: + if not allow_peer_refs: + raise NotImplementedError( + "Peer device references are not allowed in the lowering of this" + " primitive." + ) + transformed_ref = ctx.launch_ctx.to_remote( + transformed_ref, _ensure_ir_value(peer_device_id, jnp.int32) + ) + return transformed_ref, new_transforms def _ndindexer_indices(indexer: indexing.NDIndexer) -> tuple[gpu_core.Index, ...]: @@ -954,48 +1405,65 @@ def _ndindexer_indices(indexer: indexing.NDIndexer) -> tuple[gpu_core.Index, ... return tuple(indices) -@register_lowering_rule(sp.get_p, mgpu.ThreadSemantics.Lane) -def _get_lowering_rule(ctx: LoweringRuleContext, x_smem, *leaves, tree): - if not isinstance(x_smem, ir.Value) and ir.MemRefType.isinstance(x_smem): - raise TypeError(f"Can only load from references (got {x_smem}).") +@register_lowering_rule(sp.get_p, mgpu.LoweringSemantics.Lane) +def _get_lowering_rule(ctx: LoweringRuleContext, x_ref, *leaves, tree): + if isinstance(x_ref, tcgen05.TMEMRef): + transforms = jax.tree.unflatten(tree, leaves) + x_tmem, transforms = _handle_transforms( + ctx, x_ref, transforms, handle_transposes=False, handle_reshapes=False, + ) + if transforms: + raise NotImplementedError( + f"Unimplemented transforms for TMEM refs. 
{transforms=}" + ) + return x_tmem.load() - x_aval = ctx.avals_in[0] + if not isinstance(x_ref, ir.Value) and ir.MemRefType.isinstance(x_ref): + raise TypeError(f"Can only load from references (got {x_ref}).") + dtype = ctx.avals_out[0].dtype transforms = jax.tree.unflatten(tree, leaves) - x_smem, transforms = _handle_reshaping(x_smem, transforms) - x_smem, transforms = _handle_indexing(x_smem, transforms) + x_smem, transforms = _handle_transforms( + ctx, x_ref, transforms, allow_peer_refs=True + ) match transforms: case (gpu_core.UnswizzleRef(swizzle), gpu_core.UntileRef(tiling)): - if tiling != (64, swizzle // x_aval.dtype.itemsize): - raise NotImplementedError("Tiling does not fit swizzle") + if len(tiling) != 2: + raise NotImplementedError(f"Only 2D tiling is supported, got: {tiling}") + expected_minor_tiling = swizzle * 8 // pallas_utils.dtype_bitwidth(dtype) + if tiling[-1] != expected_minor_tiling: + raise NotImplementedError( + "Minor tiling dimension does not fit swizzle: " + f" expected {expected_minor_tiling}, got {tiling[-1]}" + ) return mgpu.FragmentedArray.load_tiled( - x_smem, is_signed=mgpu_utils.is_signed(x_aval.dtype), swizzle=swizzle + x_smem, is_signed=mgpu_utils.is_signed(dtype), swizzle=swizzle ) case (): # Handle scalar indexing. if not ctx.avals_out[0].shape: - is_signed = mgpu_utils.is_signed(x_aval.dtype) + is_signed = mgpu_utils.is_signed(dtype) val = memref_dialect.load(x_smem, []) return mgpu.FragmentedArray.splat(val, shape=(), is_signed=is_signed) return mgpu.FragmentedArray.load_strided( - x_smem, is_signed=mgpu_utils.is_signed(x_aval.dtype) + x_smem, is_signed=mgpu_utils.is_signed(dtype) ) case _: raise NotImplementedError(f"Unsupported transforms: {transforms}") -@register_lowering_rule(sp.get_p, mgpu.ThreadSemantics.Warpgroup) +@register_lowering_rule(sp.get_p, mgpu.LoweringSemantics.Warpgroup) def _get_lowering_rule_wg(ctx: LoweringRuleContext, x_smem, *leaves, tree): if not isinstance(x_smem, ir.Value) and ir.MemRefType.isinstance(x_smem): raise TypeError(f"Can only load from references (got {x_smem}).") - x_aval = ctx.avals_in[0] - transforms = jax.tree.unflatten(tree, leaves) - x_smem, transforms = _handle_reshaping(x_smem, transforms) - x_smem, transforms = _handle_indexing(x_smem, transforms) + x_smem, transforms = _handle_transforms( + ctx, x_smem, transforms, allow_peer_refs=True + ) + mlir_dtype = ir.MemRefType(x_smem.type).element_type if transforms: raise NotImplementedError( @@ -1003,7 +1471,7 @@ def _get_lowering_rule_wg(ctx: LoweringRuleContext, x_smem, *leaves, tree): ) shape = ctx.avals_out[0].shape - ty = ir.VectorType.get(shape, mgpu_utils.dtype_to_ir_type(x_aval.dtype)) + ty = ir.VectorType.get(shape, mlir_dtype) if shape: zero_index = arith_dialect.constant(ir.IndexType.get(), 0) indices = [zero_index for _ in range(len(shape))] @@ -1012,59 +1480,114 @@ def _get_lowering_rule_wg(ctx: LoweringRuleContext, x_smem, *leaves, tree): return memref_dialect.load(x_smem, []) -@register_lowering_rule(sp.swap_p, mgpu.ThreadSemantics.Lane) +@register_lowering_rule(sp.swap_p, mgpu.LoweringSemantics.Lane) def _swap_lowering_rule( - ctx: LoweringRuleContext, x_smem, value, *leaves, tree + ctx: LoweringRuleContext, x_ref, value, *leaves, tree ): if not isinstance(value, mgpu.FragmentedArray): raise TypeError(f"Can only store arrays (got {value}).") - if not isinstance(x_smem, ir.Value) and ir.MemRefType.isinstance(x_smem): - raise TypeError(f"Can only store to references (got {x_smem}).") - x_aval = ctx.avals_in[0] + + if isinstance(x_ref, 
tcgen05.TMEMRef): + transforms = jax.tree.unflatten(tree, leaves) + x_tmem, transforms = _handle_transforms( + ctx, x_ref, transforms, handle_transposes=False, handle_reshapes=False, + ) + if transforms: + raise NotImplementedError( + f"Unimplemented transforms for TMEM refs. {transforms=}" + ) + old_value = x_tmem.load(layout=value.layout) + x_tmem.store(value) + return old_value + + if not isinstance(x_ref, ir.Value) and ir.MemRefType.isinstance(x_ref): + raise TypeError(f"Can only store to references (got {x_ref}).") + v_aval = ctx.avals_in[1] transforms = jax.tree.unflatten(tree, leaves) - x_smem, transforms = _handle_reshaping(x_smem, transforms) - x_smem, transforms = _handle_indexing(x_smem, transforms) + transposed_value = value.layout == mgpu.WGMMA_TRANSPOSED_LAYOUT + x_smem, transforms = _handle_transforms( + ctx, x_ref, transforms, handle_transposes=not transposed_value, + allow_peer_refs=True + ) + if ctx.module_ctx.auto_barriers: + mgpu.warpgroup_barrier() # Make sure reads have completed before we write. match transforms: - case (gpu_core.UnswizzleRef(swizzle), gpu_core.UntileRef(tiling)): - if tiling != (64, swizzle // x_aval.dtype.itemsize): - raise NotImplementedError("Tiling does not fit swizzle") + case ( + gpu_core.UnswizzleRef(swizzle), + gpu_core.UntileRef(tiling), + *maybe_transpose, + ): + if len(tiling) != 2: + raise NotImplementedError(f"Only 2D tiling is supported, got: {tiling}") + bw = pallas_utils.dtype_bitwidth(v_aval.dtype) + expected_minor_tiling = swizzle * 8 // bw + if tiling[-1] != expected_minor_tiling: + raise NotImplementedError( + "Minor tiling dimension does not fit swizzle: " + f" expected {expected_minor_tiling}, got {tiling[-1]}" + ) + + if transposed_value != bool(maybe_transpose): + raise ValueError( + "Either both the ref and the value are transposed or neither is." + ) + + if maybe_transpose: + if maybe_transpose != [gpu_core.TransposeRef((1, 0))]: + raise NotImplementedError( + f"Unsupported transforms: {transforms} ({maybe_transpose})" + ) + + x_smem = mgpu.memref_transpose(x_smem, (1, 0, 3, 2)) + old_value = mgpu.FragmentedArray.load_tiled( - x_smem, is_signed=mgpu_utils.is_signed(x_aval.dtype), swizzle=swizzle + x_smem, + is_signed=mgpu_utils.is_signed(v_aval.dtype), + swizzle=swizzle, + layout=value.layout, ) value.store_tiled(x_smem, swizzle=swizzle) - return old_value case (): - old_value = mgpu.FragmentedArray.load_strided( - x_smem, is_signed=mgpu_utils.is_signed(x_aval.dtype) - ) - value.store_untiled(x_smem) - return old_value + match value.layout: + case mgpu.TiledLayout(): + old_value = mgpu.FragmentedArray.load_untiled( + x_smem, + layout=value.layout, + is_signed=mgpu_utils.is_signed(v_aval.dtype), + optimized=False, + ) + value.store_untiled(x_smem, optimized=False) + case _: + old_value = mgpu.FragmentedArray.load_strided( + x_smem, is_signed=mgpu_utils.is_signed(v_aval.dtype) + ) + value.store_untiled(x_smem) case _: raise NotImplementedError(f"Unsupported transforms: {transforms}") + if ctx.module_ctx.auto_barriers: + mgpu.warpgroup_barrier() # Make sure the writes have completed. 
+ return old_value -@register_lowering_rule(sp.swap_p, mgpu.ThreadSemantics.Warpgroup) +@register_lowering_rule(sp.swap_p, mgpu.LoweringSemantics.Warpgroup) def _swap_lowering_rule_wg( ctx: LoweringRuleContext, x_smem, value, *leaves, tree ): - if not ir.VectorType.isinstance(value.type): - raise TypeError(f"Can only store vectors (got {value}).") + shape = ctx.avals_out[0].shape + if shape and not ir.VectorType.isinstance(value.type): + raise TypeError(f"Can only store scalars or vectors (got {value}).") if not ir.MemRefType.isinstance(x_smem.type): raise TypeError(f"Can only store to references (got {x_smem}).") - x_aval = ctx.avals_in[0] - transforms = jax.tree.unflatten(tree, leaves) - x_smem, transforms = _handle_reshaping(x_smem, transforms) - x_smem, transforms = _handle_indexing(x_smem, transforms) - + x_smem, transforms = _handle_transforms( + ctx, x_smem, transforms, allow_peer_refs=True) if transforms: raise NotImplementedError( "Transforms are not yet implemented for warpgroup semantics" ) - - shape = ctx.avals_out[0].shape - ty = ir.VectorType.get(shape, mgpu_utils.dtype_to_ir_type(x_aval.dtype)) + x_mlir_dtype = ir.MemRefType(x_smem.type).element_type + ty = ir.VectorType.get(shape, x_mlir_dtype) if shape: zero_index = arith_dialect.constant(ir.IndexType.get(), 0) indices = [zero_index for _ in range(len(shape))] @@ -1076,8 +1599,8 @@ def _swap_lowering_rule_wg( return old_value -@register_lowering_rule(pjit.pjit_p, mgpu.ThreadSemantics.Lane) -@register_lowering_rule(pjit.pjit_p, mgpu.ThreadSemantics.Warpgroup) +@register_lowering_rule(pjit.pjit_p, mgpu.LoweringSemantics.Lane) +@register_lowering_rule(pjit.pjit_p, mgpu.LoweringSemantics.Warpgroup) def _pjit_lowering_rule(ctx: LoweringRuleContext, *args, jaxpr, **kwargs): if jaxpr.consts: raise NotImplementedError @@ -1085,11 +1608,8 @@ def _pjit_lowering_rule(ctx: LoweringRuleContext, *args, jaxpr, **kwargs): ctx.module_ctx, ctx.launch_ctx, jaxpr.jaxpr, args, ) -@register_lowering_rule(pjit.mesh_cast_p, mgpu.ThreadSemantics.Lane) -def _mesh_cast_lowering_rule(ctx, x, dst_sharding): - return x -@register_lowering_rule(lax.slice_p, mgpu.ThreadSemantics.Lane) +@register_lowering_rule(lax.slice_p, mgpu.LoweringSemantics.Lane) def _slice_lowering_rule( ctx: LoweringRuleContext, x, limit_indices, start_indices, strides ): @@ -1099,8 +1619,10 @@ def _slice_lowering_rule( return x[tuple(slice(b, e) for b, e in zip(start_indices, limit_indices))] -@register_lowering_rule(lax.select_n_p, mgpu.ThreadSemantics.Lane) -@register_lowering_rule(lax.select_n_p, mgpu.ThreadSemantics.Warpgroup) +@register_lowering_rule(lax.select_n_p, mgpu.LoweringSemantics.Lane) +@register_lowering_rule(lax.select_n_p, mgpu.LoweringSemantics.Lane, + gpu_core.PrimitiveSemantics.Warp) +@register_lowering_rule(lax.select_n_p, mgpu.LoweringSemantics.Warpgroup) def _select_n_lowering_rule(ctx: LoweringRuleContext, pred, *cases): if len(cases) != 2: raise NotImplementedError( @@ -1108,8 +1630,12 @@ def _select_n_lowering_rule(ctx: LoweringRuleContext, pred, *cases): f" {len(cases)}" ) pred_aval, *cases_avals = ctx.avals_in + if ctx.module_ctx.primitive_semantics == gpu_core.PrimitiveSemantics.Warp: + if not all(aval.shape == () for aval in ctx.avals_in): + raise NotImplementedError( + "Can only select on scalars in warp-level lowering.") [out_aval] = ctx.avals_out - if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Lane: + if ctx.module_ctx.lowering_semantics == mgpu.LoweringSemantics.Lane: pred = _ensure_fa(pred, pred_aval.dtype) cases = _bcast(*cases, 
*cases_avals, out_aval) # ``select`` expects the first case to be the true branch, but ``select_n`` @@ -1127,7 +1653,7 @@ def _select_n_lowering_rule(ctx: LoweringRuleContext, pred, *cases): return arith_dialect.select(pred, *reversed(cases)) -@register_lowering_rule(lax.broadcast_in_dim_p, mgpu.ThreadSemantics.Lane) +@register_lowering_rule(lax.broadcast_in_dim_p, mgpu.LoweringSemantics.Lane) def _broadcast_in_dim_lowering_rule( ctx: LoweringRuleContext, x: mgpu.FragmentedArray, @@ -1143,46 +1669,70 @@ def _broadcast_in_dim_lowering_rule( if ( broadcast_dimensions == tuple(range(x_aval.ndim)) and y_aval.ndim == x_aval.ndim + 1 - and x.layout == mgpu.WGMMA_ROW_LAYOUT + and x.layout in (mgpu.WGMMA_ROW_LAYOUT, mgpu.TCGEN05_ROW_LAYOUT) ): return x.broadcast_minor(y_aval.shape[-1]) + if ( + broadcast_dimensions == (1,) + and y_aval.ndim == x_aval.ndim + 1 + and x.layout in (mgpu.WGMMA_COL_LAYOUT, mgpu.TCGEN05_COL_LAYOUT) + ): + return x.broadcast_major(y_aval.shape[-2]) if broadcast_dimensions: - raise NotImplementedError + raise NotImplementedError( + f"Unsupported broadcast {broadcast_dimensions} for layout: {x.layout}" + ) return x.broadcast(shape) -@register_lowering_rule(lax.broadcast_in_dim_p, mgpu.ThreadSemantics.Warpgroup) +@register_lowering_rule( + lax.broadcast_in_dim_p, mgpu.LoweringSemantics.Warpgroup) def _broadcast_in_dim_lowering_rule_wg( ctx: LoweringRuleContext, - x: ir.Value, + x, *, broadcast_dimensions, shape, sharding, ): del sharding - if broadcast_dimensions: - raise NotImplementedError + [x_aval] = ctx.avals_in - x = _ensure_ir_value(x, x_aval.dtype) - return vector_dialect.splat( - ir.VectorType.get(shape, mgpu_utils.dtype_to_ir_type(x_aval.dtype)), - x, - ) + if not broadcast_dimensions: + # Even though we could implement this case by passing a 0D vector as input + # to mgpu.dialect.BroadcastInDimOp, we don't want that. 0D vectors are + # generally problematic and so we avoid them by specializing that case + # directly here.
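For context, the two layout-preserving cases added to the lane rule above correspond to the following `lax.broadcast_in_dim` patterns. This is a plain-JAX sketch of the shapes involved, not Mosaic-specific code; the variable names are mine.

```python
# Plain-JAX sketch of the two broadcast patterns the lane lowering now
# special-cases: a per-row vector broadcast along the minor (column) dimension
# (broadcast_dimensions == (0,)) and a per-column vector broadcast along the
# major (row) dimension (broadcast_dimensions == (1,)).
import jax
import jax.numpy as jnp

m, n = 4, 8
per_row = jnp.arange(m, dtype=jnp.float32)
per_col = jnp.arange(n, dtype=jnp.float32)

# Maps to broadcast_minor on a row-layout vector in the rule above.
minor = jax.lax.broadcast_in_dim(per_row, shape=(m, n), broadcast_dimensions=(0,))
# Maps to broadcast_major on a col-layout vector in the rule above.
major = jax.lax.broadcast_in_dim(per_col, shape=(m, n), broadcast_dimensions=(1,))
assert minor.shape == major.shape == (m, n)
```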
+ x = _ensure_ir_value(x, x_aval.dtype) + return vector_dialect.splat( + ir.VectorType.get(shape, mgpu_utils.dtype_to_ir_type(x_aval.dtype)), + x, + ) + mlir_type = mgpu_utils.dtype_to_ir_type(x_aval.dtype) + result_ty = ir.VectorType.get(shape, mlir_type) + return mgpu.dialect.broadcast_in_dim(result_ty, x, broadcast_dimensions) -@register_lowering_rule(lax.convert_element_type_p, mgpu.ThreadSemantics.Lane) + +@register_lowering_rule(lax.convert_element_type_p, mgpu.LoweringSemantics.Lane) +@register_lowering_rule(lax.convert_element_type_p, + mgpu.LoweringSemantics.Lane, gpu_core.PrimitiveSemantics.Warp) def _convert_element_type_lowering_rule( ctx: LoweringRuleContext, x, *, new_dtype, weak_type, sharding ): del weak_type, sharding [x_aval] = ctx.avals_in + if ctx.module_ctx.primitive_semantics == gpu_core.PrimitiveSemantics.Warp: + if x_aval.shape != (): + raise NotImplementedError( + "Non-scalar arithmetic is not supported in warp-level lowering.") return _ensure_fa(x, x_aval.dtype).astype( mgpu_utils.dtype_to_ir_type(new_dtype), is_signed=mgpu_utils.is_signed(new_dtype) ) -@register_lowering_rule(lax.convert_element_type_p, mgpu.ThreadSemantics.Warpgroup) +@register_lowering_rule( + lax.convert_element_type_p, mgpu.LoweringSemantics.Warpgroup) def _convert_element_type_lowering_rule_wg( ctx: LoweringRuleContext, x, *, new_dtype, weak_type, sharding ): @@ -1274,25 +1824,42 @@ def convert(ty, x): return convert(ty, x) -mosaic_lowering_rules[mgpu.ThreadSemantics.Lane].update({ +mosaic_lowering_rules[gpu_core.LANExWG_SEMANTICS].update({ lax.neg_p: lambda ctx, x: -x, lax.not_p: lambda ctx, x: ~x, }) -mosaic_lowering_rules[mgpu.ThreadSemantics.Warpgroup].update({ +def _unary_warp_lowering_rule(impl): + def _lowering_rule(ctx: LoweringRuleContext, x): + if not all(aval_in.shape == () for aval_in in ctx.avals_in): + raise NotImplementedError( + "Non-scalar arithmetic is not supported in warp-level lowering.") + return impl(x) + return _lowering_rule + +mosaic_lowering_rules[gpu_core.LANExWARP_SEMANTICS].update({ + lax.neg_p: _unary_warp_lowering_rule(lambda x: -x), + lax.not_p: _unary_warp_lowering_rule(lambda x: ~x) +}) + +mosaic_lowering_rules[gpu_core.WGxWG_SEMANTICS].update({ lax.neg_p: _lower_fun(lambda x: jnp.subtract(0, x), multiple_results=False), lax.not_p: _lower_fun( - lambda x: jnp.bitwise_xor(x, -1), multiple_results=False + lambda x: jnp.astype(jnp.bitwise_xor(jnp.astype(x, int), -1), jnp.dtype(x)), multiple_results=False, ), }) def _binary_op_lowering_rule(ctx: LoweringRuleContext, x, y, *, impl): + if ctx.module_ctx.primitive_semantics == gpu_core.PrimitiveSemantics.Warp: + if not all(aval_in.shape == () for aval_in in ctx.avals_in): + raise NotImplementedError( + "Non-scalar arithmetic is not supported in warp-level lowering.") x, y = _bcast(x, y, *ctx.avals_in, *ctx.avals_out) return impl(x, y) - -mosaic_lowering_rules[mgpu.ThreadSemantics.Lane].update({ +for semantics in [gpu_core.LANExWG_SEMANTICS, gpu_core.LANExWARP_SEMANTICS]: + mosaic_lowering_rules[semantics].update({ lax.add_p: partial(_binary_op_lowering_rule, impl=lambda x, y: x + y), lax.sub_p: partial(_binary_op_lowering_rule, impl=lambda x, y: x - y), lax.mul_p: partial(_binary_op_lowering_rule, impl=lambda x, y: x * y), @@ -1308,8 +1875,7 @@ def _binary_op_lowering_rule(ctx: LoweringRuleContext, x, y, *, impl): lax.ne_p: partial(_binary_op_lowering_rule, impl=lambda x, y: x != y), lax.max_p: partial(_binary_op_lowering_rule, impl=lambda x, y: x.max(y)), lax.min_p: partial(_binary_op_lowering_rule, impl=lambda 
x, y: x.min(y)), -}) - + }) def _binary_op_lowering_rule_wg( ctx: LoweringRuleContext, x, y, *, ui_impl, si_impl, f_impl=None @@ -1353,7 +1919,7 @@ def _binary_op_lowering_rule_wg( arith_dialect.minimumf, ), ]: - mosaic_lowering_rules[mgpu.ThreadSemantics.Warpgroup][op] = partial( + mosaic_lowering_rules[gpu_core.WGxWG_SEMANTICS][op] = partial( _binary_op_lowering_rule_wg, si_impl=si_impl, ui_impl=ui_impl, @@ -1372,7 +1938,7 @@ def _binary_boolean_op_lowering_rule_wg( (lax.or_p, arith_dialect.ori), (lax.xor_p, arith_dialect.xori), ]: - mosaic_lowering_rules[mgpu.ThreadSemantics.Warpgroup][op] = partial( + mosaic_lowering_rules[gpu_core.WGxWG_SEMANTICS][op] = partial( _binary_boolean_op_lowering_rule_wg, impl=impl, ) @@ -1387,7 +1953,7 @@ def _comparison_lowering_rule_wg( x, y = _bcast_wg(x, y, *ctx.avals_in, *ctx.avals_out) if jnp.issubdtype(x_aval, jnp.signedinteger): return arith_dialect.cmpi(si_pred, x, y) - elif jnp.issubdtype(x_aval, jnp.integer) or jnp.issubdtype(x_aval, jnp.bool): + elif jnp.issubdtype(x_aval, jnp.unsignedinteger) or jnp.issubdtype(x_aval, jnp.bool): return arith_dialect.cmpi(ui_pred, x, y) elif jnp.issubdtype(x_aval, jnp.floating): return arith_dialect.cmpf(f_pred, x, y) @@ -1405,7 +1971,7 @@ def _comparison_lowering_rule_wg( (lax.gt_p, CmpIPred.sgt, CmpIPred.ugt, CmpFPred.OGT), (lax.ge_p, CmpIPred.sge, CmpIPred.uge, CmpFPred.OGE), ]: - mosaic_lowering_rules[mgpu.ThreadSemantics.Warpgroup][op] = partial( + mosaic_lowering_rules[gpu_core.WGxWG_SEMANTICS][op] = partial( _comparison_lowering_rule_wg, si_pred=si_pred, ui_pred=ui_pred, @@ -1413,7 +1979,7 @@ def _comparison_lowering_rule_wg( ) -@register_lowering_rule(lax.div_p, mgpu.ThreadSemantics.Lane) +@register_lowering_rule(lax.div_p, mgpu.LoweringSemantics.Lane) def _div_lowering_rule(ctx: LoweringRuleContext, x, y): x, y = _bcast(x, y, *ctx.avals_in, *ctx.avals_out) if ir.FloatType.isinstance(x.mlir_dtype): @@ -1421,19 +1987,19 @@ def _div_lowering_rule(ctx: LoweringRuleContext, x, y): return x // y -@register_lowering_rule(lax.integer_pow_p, mgpu.ThreadSemantics.Lane) -@register_lowering_rule(lax.integer_pow_p, mgpu.ThreadSemantics.Warpgroup) +@register_lowering_rule(lax.integer_pow_p, mgpu.LoweringSemantics.Lane) +@register_lowering_rule(lax.integer_pow_p, mgpu.LoweringSemantics.Warpgroup) def _integer_pow_lowering_rule(ctx: LoweringRuleContext, x, y): if y != 2: raise NotImplementedError return _square_lowering_rule(ctx, x) -@register_lowering_rule(lax.square_p, mgpu.ThreadSemantics.Lane) -@register_lowering_rule(lax.square_p, mgpu.ThreadSemantics.Warpgroup) +@register_lowering_rule(lax.square_p, mgpu.LoweringSemantics.Lane) +@register_lowering_rule(lax.square_p, mgpu.LoweringSemantics.Warpgroup) def _square_lowering_rule(ctx: LoweringRuleContext, x): [x_aval] = ctx.avals_in - if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Lane: + if ctx.module_ctx.lowering_semantics == mgpu.LoweringSemantics.Lane: x = _ensure_fa(x, x_aval.dtype) return x * x if jnp.issubdtype(x_aval.dtype, jnp.integer): @@ -1443,11 +2009,13 @@ def _square_lowering_rule(ctx: LoweringRuleContext, x): raise NotImplementedError(f"Unsupported dtype {x_aval.dtype}") -@register_lowering_rule(lax.rsqrt_p, mgpu.ThreadSemantics.Lane) -@register_lowering_rule(lax.rsqrt_p, mgpu.ThreadSemantics.Warpgroup) -def _rsqrt_lowering_rule(ctx: LoweringRuleContext, x): +@register_lowering_rule(lax.rsqrt_p, mgpu.LoweringSemantics.Lane) +@register_lowering_rule(lax.rsqrt_p, mgpu.LoweringSemantics.Warpgroup) +def _rsqrt_lowering_rule(ctx: 
LoweringRuleContext, x, accuracy): + if accuracy is not None: + raise NotImplementedError("Not implemented: accuracy") [x_aval] = ctx.avals_in - if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Lane: + if ctx.module_ctx.lowering_semantics == mgpu.LoweringSemantics.Lane: return _ensure_fa(x, x_aval.dtype).rsqrt(approx=ctx.module_ctx.approx_math) fastmath = ( arith_dialect.FastMathFlags.afn if ctx.module_ctx.approx_math else None @@ -1457,11 +2025,13 @@ def _rsqrt_lowering_rule(ctx: LoweringRuleContext, x): ) -@register_lowering_rule(lax.tanh_p, mgpu.ThreadSemantics.Lane) -@register_lowering_rule(lax.tanh_p, mgpu.ThreadSemantics.Warpgroup) -def _tanh_lowering_rule(ctx: LoweringRuleContext, x): +@register_lowering_rule(lax.tanh_p, mgpu.LoweringSemantics.Lane) +@register_lowering_rule(lax.tanh_p, mgpu.LoweringSemantics.Warpgroup) +def _tanh_lowering_rule(ctx: LoweringRuleContext, x, accuracy): + if accuracy is not None: + raise NotImplementedError("Not implemented: accuracy") [x_aval] = ctx.avals_in - if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Lane: + if ctx.module_ctx.lowering_semantics == mgpu.LoweringSemantics.Lane: return _ensure_fa(x, x_aval.dtype).tanh(approx=ctx.module_ctx.approx_math) fastmath = ( arith_dialect.FastMathFlags.afn if ctx.module_ctx.approx_math else None @@ -1469,23 +2039,27 @@ def _tanh_lowering_rule(ctx: LoweringRuleContext, x): return math_dialect.tanh(_ensure_ir_value(x, x_aval.dtype), fastmath=fastmath) -def _logistic(x): +def _logistic(x, accuracy): + if accuracy is not None: + raise NotImplementedError("Not implemented: accuracy") return 1.0 / (1 + lax.exp(-x)) -mosaic_lowering_rules[mgpu.ThreadSemantics.Lane][lax.logistic_p] = _lower_fun( +mosaic_lowering_rules[gpu_core.LANExWG_SEMANTICS][lax.logistic_p] = _lower_fun( _logistic, multiple_results=False ) -mosaic_lowering_rules[mgpu.ThreadSemantics.Warpgroup][lax.logistic_p] = ( +mosaic_lowering_rules[gpu_core.WGxWG_SEMANTICS][lax.logistic_p] = ( _lower_fun(_logistic, multiple_results=False) ) -@register_lowering_rule(lax.exp_p, mgpu.ThreadSemantics.Lane) -@register_lowering_rule(lax.exp_p, mgpu.ThreadSemantics.Warpgroup) -def _exp_lowering_rule(ctx: LoweringRuleContext, x): +@register_lowering_rule(lax.exp_p, mgpu.LoweringSemantics.Lane) +@register_lowering_rule(lax.exp_p, mgpu.LoweringSemantics.Warpgroup) +def _exp_lowering_rule(ctx: LoweringRuleContext, x, accuracy): + if accuracy is not None: + raise NotImplementedError("Not implemented: accuracy") [x_aval] = ctx.avals_in - if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Lane: + if ctx.module_ctx.lowering_semantics == mgpu.LoweringSemantics.Lane: return _ensure_fa(x, x_aval.dtype).exp(approx=ctx.module_ctx.approx_math) fastmath = ( arith_dialect.FastMathFlags.afn if ctx.module_ctx.approx_math else None @@ -1493,10 +2067,13 @@ def _exp_lowering_rule(ctx: LoweringRuleContext, x): return math_dialect.exp(_ensure_ir_value(x, x_aval.dtype), fastmath=fastmath) -@register_lowering_rule(lax.exp2_p, mgpu.ThreadSemantics.Lane) -def _exp2_lowering_rule(ctx: LoweringRuleContext, x): +@register_lowering_rule(lax.exp2_p, mgpu.LoweringSemantics.Lane) +@register_lowering_rule(lax.exp2_p, mgpu.LoweringSemantics.Warpgroup) +def _exp2_lowering_rule(ctx: LoweringRuleContext, x, accuracy): + if accuracy is not None: + raise NotImplementedError("Not implemented: accuracy") [x_aval] = ctx.avals_in - if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Lane: + if ctx.module_ctx.lowering_semantics == mgpu.LoweringSemantics.Lane: return 
_ensure_fa(x, x_aval.dtype).exp2(approx=ctx.module_ctx.approx_math) fastmath = ( arith_dialect.FastMathFlags.afn if ctx.module_ctx.approx_math else None @@ -1504,11 +2081,13 @@ def _exp2_lowering_rule(ctx: LoweringRuleContext, x): return math_dialect.exp2(_ensure_ir_value(x, x_aval.dtype), fastmath=fastmath) -@register_lowering_rule(lax.log_p, mgpu.ThreadSemantics.Lane) -@register_lowering_rule(lax.log_p, mgpu.ThreadSemantics.Warpgroup) -def _log_lowering_rule(ctx: LoweringRuleContext, x): +@register_lowering_rule(lax.log_p, mgpu.LoweringSemantics.Lane) +@register_lowering_rule(lax.log_p, mgpu.LoweringSemantics.Warpgroup) +def _log_lowering_rule(ctx: LoweringRuleContext, x, accuracy): + if accuracy is not None: + raise NotImplementedError("Not implemented: accuracy") [x_aval] = ctx.avals_in - if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Lane: + if ctx.module_ctx.lowering_semantics == mgpu.LoweringSemantics.Lane: return _ensure_fa(x, x_aval.dtype).log(approx=ctx.module_ctx.approx_math) fastmath = ( arith_dialect.FastMathFlags.afn if ctx.module_ctx.approx_math else None @@ -1516,17 +2095,25 @@ def _log_lowering_rule(ctx: LoweringRuleContext, x): return math_dialect.log(_ensure_ir_value(x, x_aval.dtype), fastmath=fastmath) -@register_lowering_rule(lax.reduce_sum_p, mgpu.ThreadSemantics.Lane) +@register_lowering_rule(lax.reduce_sum_p, mgpu.LoweringSemantics.Lane) def _reduce_sum_lowering_rule(ctx: LoweringRuleContext, x, *, axes): [x_aval] = ctx.avals_in match x.layout: case mgpu.WGStridedFragLayout(): if set(axes) != set(range(x_aval.ndim)): raise NotImplementedError("No support for axes yet") + # To relax the restriction below, you need to ensure sufficient + # synchronization with other places that use `scratch_view` (which at the + # time of writing is only `run_scoped`). 
+ if ctx.module_ctx.axis_names.wg is not None: + raise NotImplementedError( + "No support for reduce_sum over all axes and multiple Pallas" + " threads" + ) scratch_ty = jax.ShapeDtypeStruct(shape=(4,), dtype=x_aval.dtype) with ctx.module_ctx.scratch_view([scratch_ty]) as [scratch]: - return x.reduce_sum(scratch) - case mgpu.WGMMA_LAYOUT: + return x.reduce("add", axes, scratch) + case mgpu.TiledLayout(): if axes != (x_aval.ndim - 1,): raise NotImplementedError if not jnp.issubdtype(x_aval.dtype, jnp.floating): @@ -1536,11 +2123,11 @@ def _reduce_sum_lowering_rule(ctx: LoweringRuleContext, x, *, axes): raise NotImplementedError(f"Unsupported layout {x.layout}") -@register_lowering_rule(lax.reduce_max_p, mgpu.ThreadSemantics.Lane) +@register_lowering_rule(lax.reduce_max_p, mgpu.LoweringSemantics.Lane) def _reduce_max_lowering_rule(ctx: LoweringRuleContext, x, *, axes): [x_aval] = ctx.avals_in match x.layout: - case mgpu.WGMMA_LAYOUT: + case mgpu.TiledLayout(): if axes != (x_aval.ndim - 1,): raise NotImplementedError if not jnp.issubdtype(x_aval.dtype, jnp.floating): @@ -1577,7 +2164,7 @@ def _reduce_lowering_rule_wg( return vector_dialect.MultiDimReductionOp(kind, x, acc, axes) -@register_lowering_rule(lax.reduce_sum_p, mgpu.ThreadSemantics.Warpgroup) +@register_lowering_rule(lax.reduce_sum_p, mgpu.LoweringSemantics.Warpgroup) def _reduce_sum_lowering_rule_wg(ctx: LoweringRuleContext, x, *, axes): op = _reduce_lowering_rule_wg( vector_dialect.CombiningKind.ADD, 0, ctx, x, axes=axes @@ -1588,7 +2175,7 @@ def _reduce_sum_lowering_rule_wg(ctx: LoweringRuleContext, x, *, axes): return op.result -@register_lowering_rule(lax.reduce_max_p, mgpu.ThreadSemantics.Warpgroup) +@register_lowering_rule(lax.reduce_max_p, mgpu.LoweringSemantics.Warpgroup) def _reduce_max_lowering_rule_wg(ctx: LoweringRuleContext, x, *, axes): [x_aval] = ctx.avals_in if jnp.issubdtype(x_aval.dtype, jnp.floating): @@ -1605,52 +2192,123 @@ def _reduce_max_lowering_rule_wg(ctx: LoweringRuleContext, x, *, axes): return _reduce_lowering_rule_wg(kind, acc, ctx, x, axes=axes).result -@register_lowering_rule(lax.axis_index_p, mgpu.ThreadSemantics.Lane) +def _block_id(ctx: LoweringRuleContext, dim: gpu_dialect.Dimension) -> ir.Value: + result = gpu_dialect.block_id(dim) + cluster_size = ctx.launch_ctx.cluster_size + if math.prod(cluster_size) == 1 or cluster_size[dim.value] == 1: + return result + # We scale the grid in the presence of clusters, so we need to scale the + # block ID back here. + return arith_dialect.divui(result, _as_index(cluster_size[dim.value])) + + +def _resolve_cluster_axis(axis_names: _AxisNames | None, axis_name: str): + if not axis_names: + raise LookupError( + "No axis names are available. Make sure you are using `pl.core_map`" + " with a `plgpu.Mesh`." + ) + if not axis_names or axis_name not in axis_names.cluster: + raise LookupError( + f"Unknown cluster axis {axis_name}, available axes:" + f" {[*axis_names.cluster]}" + ) + return gpu_dialect.Dimension(axis_names.cluster.index(axis_name)) + + +@register_lowering_rule(lax.axis_index_p, mgpu.LoweringSemantics.Lane) +@register_lowering_rule(lax.axis_index_p, mgpu.LoweringSemantics.Warpgroup) def _axis_index_rule(ctx: LoweringRuleContext, *, axis_name: Hashable): - i32 = ir.IntegerType.get_signless(32) - grid_names = ctx.module_ctx.grid_names + gpu_axis_names = ctx.module_ctx.axis_names + jax_axis_names = getattr(ctx.module_ctx.mesh, "axis_names", ()) + if gpu_axis_names is None and not jax_axis_names: + raise LookupError( + "No axis names are available. 
Make sure you are using `pl.core_map`" + " with a `plgpu.Mesh` or an appropriate JAX device mesh." + ) + if axis_name not in itertools.chain((gpu_axis_names or ()), jax_axis_names): + raise LookupError( + f"Axis {axis_name} does not refer to a GPU mesh axis (available axes:" + f" {[*gpu_axis_names]}) or a JAX mesh axis (available axes:" + f" {[*jax_axis_names]})" + ) + if axis_name in jax_axis_names: + jax_mesh = ctx.module_ctx.mesh + assert jax_mesh is not None + device_id = ctx.launch_ctx.device_id() + jax_mesh_shape = jax_mesh.axis_sizes + axis_index = jax_axis_names.index(axis_name) + i32 = ir.IntegerType.get_signless(32) + axis_size = _ir_constant(jax_mesh_shape[axis_index], i32) + minor_divisor = _ir_constant( + np.prod(jax_mesh_shape[axis_index + 1 :], dtype=np.int32), i32 + ) + return arith_dialect.remsi(arith_dialect.divsi(device_id, minor_divisor), axis_size) + + # We already checked that the axis is in scope and it wasn't a JAX mesh axis. + assert gpu_axis_names is not None + + # We only deal with GPU axes from now on. + axis_names = gpu_axis_names + if axis_names.wg is not None and axis_name == axis_names.wg: + return mgpu.warpgroup_idx(sync=True) + + if axis_name in axis_names.cluster: + return arith_dialect.index_cast( + ir.IntegerType.get_signless(32), + gpu_dialect.cluster_block_id( + gpu_dialect.Dimension(axis_names.cluster.index(axis_name)) + ), + ) + squashed_dims = ctx.module_ctx.squashed_dims if squashed_dims: - unsquashed_names = grid_names[-3:] - squashed_names = grid_names[:-3] + unsquashed_names = axis_names.grid[-2:] + squashed_names = axis_names.grid[:-2] else: # These are unused but initialized for type checkers. - unsquashed_names = () - squashed_names = () - if grid_names and axis_name in grid_names: - if axis_name == grid_names[-1]: - return mgpu.warpgroup_idx(sync=True) + unsquashed_names = squashed_names = () + + if squashed_dims: + if axis_name in unsquashed_names: + # We add 1 to the index because the first dimension is the + # squashed dimension. + # e.g. for the grid (a, b, c, d, wg) + # squashed = (a, b) Mapped to Dimension.x (0) + # unsquashed = (c, d) Mapped to Dimension.y (1) and Dimension.z (2) + idx = unsquashed_names.index(axis_name) + 1 + return arith_dialect.index_cast( + ir.IntegerType.get_signless(32), + _block_id(ctx, gpu_dialect.Dimension(idx)), + ) else: - if squashed_dims: - if axis_name in unsquashed_names: - # We add 1 to the index because the first dimension is the - # squashed dimension. - # e.g. for the grid (a, b, c, d, wg) - # squashed = (a, b) Mapped to Dimension.x (0) - # unsquashed = (c, d) Mapped to Dimension.y (1) and Dimension.z (2) - idx = unsquashed_names.index(axis_name) + 1 - return arith_dialect.index_cast( - i32, - gpu_dialect.block_id(gpu_dialect.Dimension(idx)), - ) - elif axis_name in squashed_names: - # All squashed dimensions are mapped to Dimension.x. - block_id = gpu_dialect.block_id(gpu_dialect.Dimension.x) - axis = squashed_names.index(axis_name) - return _unravel_program_id(block_id, axis, squashed_dims) - else: - if axis_name in grid_names: - idx = grid_names.index(axis_name) - return arith_dialect.index_cast( - i32, - gpu_dialect.block_id(gpu_dialect.Dimension(idx)), - ) + assert axis_name in squashed_names + # All squashed dimensions are mapped to Dimension.x. 
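The device-mesh branch above recovers a logical mesh coordinate from the linear device id via `remsi(divsi(device_id, minor_divisor), axis_size)`, and the squashed-grid branch performs the same kind of unraveling on block ids. A quick NumPy check of that arithmetic (the mesh shape is chosen arbitrarily for illustration):

```python
# NumPy check of the index arithmetic used above: for a row-major shape
# `sizes`, axis i of a linear id is (id // prod(sizes[i+1:])) % sizes[i].
import numpy as np

sizes = (2, 3, 4)  # Illustrative mesh / squashed-grid shape.
for linear_id in range(int(np.prod(sizes))):
  coords = tuple(
      (linear_id // int(np.prod(sizes[i + 1:], dtype=np.int64))) % sizes[i]
      for i in range(len(sizes))
  )
  assert coords == tuple(np.unravel_index(linear_id, sizes))
```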
+ axis = squashed_names.index(axis_name) + return _unravel_program_id( + _block_id(ctx, gpu_dialect.Dimension.x), axis, squashed_dims + ) + else: + assert axis_name in axis_names.grid + idx = axis_names.grid.index(axis_name) + return arith_dialect.index_cast( + ir.IntegerType.get_signless(32), + _block_id(ctx, gpu_dialect.Dimension(idx)), + ) + +@register_lowering_rule(lax.axis_index_p, + mgpu.LoweringSemantics.Lane, gpu_core.PrimitiveSemantics.Warp) +def _axis_index_warp_rule(ctx: LoweringRuleContext, *, axis_name: Hashable): + if axis_name == ctx.module_ctx.warp_axis_name: + return mgpu.warp_idx(sync=True) raise ValueError( - "Named axes can only refer to GPUMesh axes in Mosaic GPU kernels" + "Named axes can only refer to the warp axis name inside of core_map." ) -@register_lowering_rule(primitives.debug_print_p, mgpu.ThreadSemantics.Lane) +@register_lowering_rule(primitives.debug_print_p, mgpu.LoweringSemantics.Lane) +@register_lowering_rule(primitives.debug_print_p, mgpu.LoweringSemantics.Lane, + gpu_core.PrimitiveSemantics.Warp) def _debug_print_lowering_rule( ctx: LoweringRuleContext, *args, @@ -1659,6 +2317,9 @@ def _debug_print_lowering_rule( ): del has_placeholders # Unused. primitives.check_debug_print_format(fmt, *args) + scope = mgpu.ThreadSubset.WARPGROUP + if ctx.module_ctx.primitive_semantics == gpu_core.PrimitiveSemantics.Warp: + scope = mgpu.ThreadSubset.WARP if not any(aval.shape for aval in ctx.avals_in): mgpu.debug_print( fmt, @@ -1666,6 +2327,7 @@ def _debug_print_lowering_rule( _ensure_ir_value(arg, aval.dtype) for arg, aval in zip(args, ctx.avals_in) ), + scope=scope ) elif len(ctx.avals_in) == 1: [arg] = args @@ -1678,8 +2340,8 @@ def _debug_print_lowering_rule( return () -@register_lowering_rule(primitives.debug_print_p, mgpu.ThreadSemantics.Warpgroup) -def _debug_print_lowering_rule( +@register_lowering_rule(primitives.debug_print_p, mgpu.LoweringSemantics.Warpgroup) +def _debug_print_lowering_rule_wg( ctx: LoweringRuleContext, *args, fmt, @@ -1692,79 +2354,137 @@ def _debug_print_lowering_rule( return () -@register_lowering_rule(primitives.run_scoped_p, mgpu.ThreadSemantics.Lane) -@register_lowering_rule(primitives.run_scoped_p, mgpu.ThreadSemantics.Warpgroup) +@register_lowering_rule(primitives.run_scoped_p, mgpu.LoweringSemantics.Lane) +@register_lowering_rule(primitives.run_scoped_p, mgpu.LoweringSemantics.Warpgroup) def _run_scoped_lowering_rule( - ctx: LoweringRuleContext, *consts, jaxpr: jax_core.Jaxpr + ctx: LoweringRuleContext, *consts, jaxpr: jax_core.Jaxpr, collective_axes ): input_refs = [] should_discharge = [] - alloc_stack = contextlib.ExitStack() - for v in jaxpr.invars: - aval = v.aval - if isinstance(aval, gpu_core.WGMMAAbstractAccumulatorRef): - if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Warpgroup: - # TODO(bchetioui): Fix this and remove the NotImplementedError. + wg_axis = ctx.module_ctx.axis_names.wg + is_multithreaded = wg_axis is not None + is_thread_collective = is_multithreaded and collective_axes == (wg_axis,) + # Make sure everyone has exited previous scoped allocations. Note that we + # don't synchronize when we exit the allocation, but only when we might want + # to reuse its memory again. + if is_multithreaded and is_thread_collective: + gpu_dialect.barrier() + with contextlib.ExitStack() as alloc_stack: + for v in jaxpr.invars: + aval = v.aval + if isinstance(aval, gpu_core.WGMMAAbstractAccumulatorRef): + if collective_axes: + raise ValueError( + "WGMMA accumulators can only be allocated non-collectively. 
Hint:" + " remove collective_axes from run_scoped. If other allocations" + " are performed as well, split the run_scoped into two." + ) + dtype = mlir.dtype_to_ir_type(aval.dtype) + if ctx.module_ctx.lowering_semantics == mgpu.LoweringSemantics.Lane: + input_refs.append(mgpu.WGMMAAccumulator.zero(*aval.shape, dtype)) + else: + zero = arith_dialect.constant(dtype, ir.FloatAttr.get(dtype, 0.0)) + acc = vector_dialect.splat(ir.VectorType.get(aval.shape, dtype), zero) + acc = mgpu.dialect.optimization_barrier([acc]) + nvvm_dialect.wgmma_fence_aligned() + input_refs.append(acc) + should_discharge.append(True) + continue + # All other allocations must be made collectively across all threads. + if is_multithreaded and not is_thread_collective: raise NotImplementedError( - "WGMMA accumulators are not supported with Warpgroup semantics." + "Only thread-collective allocations are supported in multithreaded" + " kernels. Hint: add" + f" collective_axes={ctx.module_ctx.axis_names.wg} to your" + " run_scoped if you intend all threads to share the same" + f" allocation (currently collective_axes={collective_axes})." ) - mlir_dtype = mlir.dtype_to_ir_type(aval.dtype) - input_refs.append(mgpu.WGMMAAccumulator.zero(*aval.shape, mlir_dtype)) - should_discharge.append(True) - elif isinstance(aval.dtype, gpu_core.BarrierType): - input_refs.append( - ctx.module_ctx.reserve_barrier( - mgpu.Barrier( - aval.dtype.num_arrivals - * ctx.estimator_ctx.arrival_multiplier, - *aval.shape, - ) - ) - ) - should_discharge.append(False) - elif aval.memory_space == gpu_core.SMEM: - [input_ref] = alloc_stack.enter_context( - ctx.module_ctx.scratch_view( - [jax.ShapeDtypeStruct(shape=aval.shape, dtype=aval.dtype)] - ) + if isinstance(aval.dtype, gpu_core.BarrierType): + multiplier = (1 if aval.dtype.for_tensor_core else + ctx.estimator_ctx.arrival_multiplier) + barrier_ref = alloc_stack.enter_context( + ctx.module_ctx.reserve_barrier( + mgpu.Barrier( + aval.dtype.num_arrivals * multiplier, + *aval.shape, + ) + ) + ) + input_refs.append(barrier_ref) + should_discharge.append(False) + elif isinstance(aval.dtype, gpu_core.ClusterBarrierType): + collective_dims = jax.tree.map( + lambda axis: _resolve_cluster_axis(ctx.module_ctx.axis_names, axis), + aval.dtype.collective_axes, + ) + barrier_ref = alloc_stack.enter_context( + ctx.module_ctx.reserve_barrier( + mgpu.ClusterBarrier(collective_dims, *aval.shape) + ) + ) + input_refs.append(barrier_ref) + should_discharge.append(False) + elif aval.memory_space == gpu_core.SMEM: + [input_ref] = alloc_stack.enter_context( + ctx.module_ctx.scratch_view( + [jax.ShapeDtypeStruct(shape=aval.shape, dtype=aval.dtype)] + ) + ) + input_refs.append(input_ref) + should_discharge.append(False) + elif aval.memory_space == gpu_core.TMEM: + input_ref = alloc_stack.enter_context( + ctx.module_ctx.alloc_tmem( + jax.ShapeDtypeStruct(shape=aval.shape, dtype=aval.dtype), + packed=aval.packed, + exact_cols=False, + collective=aval.collective, + ) + ) + input_refs.append(input_ref) + should_discharge.append(False) + elif aval.memory_space == gpu_core.GMEM and jnp.issubdtype(aval.dtype, pallas_core.semaphore): + input_ref = alloc_stack.enter_context( + ctx.module_ctx.reserve_semaphores(aval.shape) + ) + input_refs.append(input_ref) + should_discharge.append(False) + else: + raise ValueError(f"Can't convert to ref: {aval}") + + if any(should_discharge): + # We convert consts to args, because we only have ir.Values and + # not JAX values during lowering. 
discharge_state() produces JAX + # valiues for the arguments but expects them to be provided for the + # consts. We also don't want to wrap the values in refs. + no_const_jaxpr = pe.convert_constvars_jaxpr(jaxpr) + should_discharge = [False] * len(consts) + should_discharge + discharged_jaxpr, _ = discharge.discharge_state(no_const_jaxpr, (), should_discharge=should_discharge) + new_input_vals = consts + tuple(input_refs) + outs = lower_jaxpr_to_mosaic_gpu( + ctx.module_ctx, + ctx.launch_ctx, + discharged_jaxpr, + new_input_vals, + (), ) - input_refs.append(input_ref) - should_discharge.append(False) + # Discharge appends to the output the refs that got discharged. + outs = outs[:-sum(should_discharge)] else: - raise ValueError(f"Can't convert to ref: {aval}") - - if any(should_discharge): - # We convert consts to args, because we only have ir.Values and - # not JAX values during lowering. discharge_state() produces JAX - # valiues for the aguments but expects them to be provided for the - # consts. We also don't want to wrap the values in refs. - no_const_jaxpr = pe.convert_constvars_jaxpr(jaxpr) - should_discharge = [False] * len(consts) + should_discharge - discharged_jaxpr, _ = discharge.discharge_state(no_const_jaxpr, (), should_discharge=should_discharge) - new_input_vals = consts + tuple(input_refs) - outs = lower_jaxpr_to_mosaic_gpu( - ctx.module_ctx, - ctx.launch_ctx, - discharged_jaxpr, - new_input_vals, - (), - ) - # Discharge appends to the output the refs that got discharged. - outs = outs[:-sum(should_discharge)] - else: - outs = lower_jaxpr_to_mosaic_gpu( - ctx.module_ctx, - ctx.launch_ctx, - jaxpr, - input_refs, - consts, - ) + outs = lower_jaxpr_to_mosaic_gpu( + ctx.module_ctx, + ctx.launch_ctx, + jaxpr, + input_refs, + consts, + ) assert len(outs) == len(jaxpr.outvars), (jaxpr, outs) return outs -@register_lowering_rule(discharge.run_state_p, mgpu.ThreadSemantics.Lane) +@register_lowering_rule(discharge.run_state_p, mgpu.LoweringSemantics.Lane) +@register_lowering_rule(discharge.run_state_p, mgpu.LoweringSemantics.Warpgroup) def _run_state_lowering_rule( ctx: LoweringRuleContext, *args, @@ -1782,7 +2502,12 @@ def _run_state_lowering_rule( for arg, v, out_aval in zip(args, jaxpr.invars, ctx.avals_out): aval = v.aval if isinstance(aval, gpu_core.WGMMAAbstractAccumulatorRef): - new_input_vals.append(mgpu.WGMMAAccumulator.from_registers(arg)) + if ctx.module_ctx.lowering_semantics == mgpu.LoweringSemantics.Warpgroup: + arg = mgpu.dialect.optimization_barrier([arg]) + nvvm_dialect.wgmma_fence_aligned() + new_input_vals.append(arg) + else: + new_input_vals.append(mgpu.WGMMAAccumulator.from_registers(arg)) should_discharge.append(True) assert isinstance(out_aval, jax_core.ShapedArray) else: @@ -1817,12 +2542,12 @@ def _lower_jaxpr_to_for_loop( ctx: LoweringRuleContext, jaxpr: jax_core.Jaxpr, start: ir.Value, - length: ir.Value, + length: int | ir.Value, consts, *args, has_loop_index: bool, + unroll: int | None = None, ): - _consts_avals, arg_avals = util.split_list(ctx.avals_in, [len(consts)]) arg_avals = arg_avals[has_loop_index:] out_avals = [] @@ -1836,28 +2561,58 @@ def as_values(vals, avals): _ensure = ( _ensure_fa - if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Lane + if ctx.module_ctx.lowering_semantics == mgpu.LoweringSemantics.Lane else _ensure_ir_value ) return [v if a else _ensure(v, av) for a, v, av in zip(is_acc, vals, avals)] - @mgpu.fori(length, as_values(args, arg_avals)) def loop(loop_index, body_args): - if has_loop_index: - loop_index = 
arith_dialect.addi(loop_index, start) - jaxpr_args = [*consts, loop_index, *body_args] - else: - jaxpr_args = [*consts, *body_args] - outs = lower_jaxpr_to_mosaic_gpu( - ctx.module_ctx, ctx.launch_ctx, jaxpr, jaxpr_args - ) + outs = body_args + if unroll is not None: + loop_index = arith_dialect.muli( + loop_index, _ir_constant(unroll, start.type) + ) + loop_index = arith_dialect.addi(loop_index, start) + for step in range(unroll or 1): + if has_loop_index: + loop_index = arith_dialect.addi( + loop_index, _ir_constant(step, start.type) + ) + jaxpr_args = [*consts, loop_index, *outs] + else: + jaxpr_args = [*consts, *outs] + outs = lower_jaxpr_to_mosaic_gpu( + ctx.module_ctx, ctx.launch_ctx, jaxpr, jaxpr_args + ) return as_values(outs, out_avals) - return loop.results + if unroll is not None: + if not isinstance(length, int): + raise NotImplementedError( + "``length`` must be an integer when ``unroll` is specified, got" + f" {length}" + ) + if length % unroll: + # TODO(slebedev): Emit an epilogue taking care of the remaining steps. + raise NotImplementedError( + f"``unroll`` must divide ``length``, got {unroll=} and {length=}" + ) + if unroll == length: + # Special-case: the loop is fully unrolled. + return loop(_ir_constant(0, start.type), as_values(args, arg_avals)) + return mgpu.fori( + _ir_constant(length // unroll, start.type), as_values(args, arg_avals) + )(loop).results + else: + if not isinstance(length, ir.Value): + length = _ir_constant(length, start.type) + return mgpu.fori(length, as_values(args, arg_avals))(loop).results -@register_lowering_rule(lax.scan_p, mgpu.ThreadSemantics.Lane) -@register_lowering_rule(lax.scan_p, mgpu.ThreadSemantics.Warpgroup) +@register_lowering_rule(lax.scan_p, mgpu.LoweringSemantics.Lane) +@register_lowering_rule(lax.scan_p, mgpu.LoweringSemantics.Warpgroup) +@register_lowering_rule(lax.scan_p, mgpu.LoweringSemantics.Lane, + gpu_core.PrimitiveSemantics.Warp) def _scan_lowering_rule( ctx: LoweringRuleContext, *args, @@ -1871,13 +2626,9 @@ def _scan_lowering_rule( _split_transpose: bool, ): # Can only handle fori_loop-like scans. - if ( - (num_extensive := len(args) - num_consts - num_carry) - or reverse - or unroll != 1 - ): + if (num_extensive := len(args) - num_consts - num_carry) or reverse: raise NotImplementedError - del linear, num_extensive, reverse, unroll + del linear, num_extensive, reverse jaxpr, jaxpr_consts = jaxpr.jaxpr, jaxpr.consts if jaxpr_consts: @@ -1893,17 +2644,24 @@ def _scan_lowering_rule( start, *args = args index_aval, *_ = arg_avals start: ir.Value = _ensure_ir_value(start, index_aval.dtype) - length = _ir_constant(length, start.type) else: start = _i32_constant(0) - length = _i32_constant(length) + for_out = _lower_jaxpr_to_for_loop( - ctx, jaxpr, start, length, consts, *args, has_loop_index=has_loop_index + ctx, + jaxpr, + start, + length, + consts, + *args, + has_loop_index=has_loop_index, + unroll=unroll, ) if has_loop_index: # Need to return the final loop index value if the outer scan expects # it as an output. 
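The intent of the new ``unroll`` handling above is that an outer ``fori`` of ``length // unroll`` trips executes ``unroll`` consecutive body steps per trip, so the set of visited loop indices is unchanged. A small pure-Python sketch of that index schedule (this is the targeted bookkeeping, not the MLIR emission itself):

```python
# Pure-Python sketch of the unrolled index schedule: the outer loop runs
# length // unroll times and each trip covers `unroll` consecutive steps
# offset by `start`.
def unrolled_indices(start: int, length: int, unroll: int) -> list[int]:
  assert length % unroll == 0, "the lowering requires unroll to divide length"
  out = []
  for i in range(length // unroll):
    base = i * unroll + start
    for step in range(unroll):
      out.append(base + step)
  return out

assert unrolled_indices(start=3, length=8, unroll=2) == list(range(3, 11))
```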
- return [length, *for_out] + loop_index = arith_dialect.addi(start, _ir_constant(length, start.type)) + return [loop_index, *for_out] return for_out @@ -1945,8 +2703,8 @@ def _lower_while_via_fori( return ub, ub, *for_out -@register_lowering_rule(lax.while_p, mgpu.ThreadSemantics.Lane) -@register_lowering_rule(lax.while_p, mgpu.ThreadSemantics.Warpgroup) +@register_lowering_rule(lax.while_p, mgpu.LoweringSemantics.Lane) +@register_lowering_rule(lax.while_p, mgpu.LoweringSemantics.Warpgroup) def _while_lowering_rule( ctx: LoweringRuleContext, *args, @@ -1970,7 +2728,7 @@ def _while_lowering_rule( _is_acc = lambda x: isinstance(x, mgpu.WGMMAAccumulator) _ensure = _ensure_ir_value - if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Lane: + if ctx.module_ctx.lowering_semantics == mgpu.LoweringSemantics.Lane: _ensure = lambda v, aval: v if _is_acc(v) else _ensure_fa(v, aval.dtype) # If we fail conversion to fori, fallback to an ordinary while loop. @@ -2004,48 +2762,54 @@ def _while_lowering_rule( ctx.module_ctx, ctx.launch_ctx, body_jaxpr.jaxpr, body_args ) loop_out = [*map(_ensure, loop_out, carry_avals)] - for idx, (carry_fa, out_fa) in enumerate(zip(carry, loop_out)): - if _is_acc(carry_fa) != _is_acc(out_fa): - raise ValueError( - f"The loop body output has unexpected accumulator type: output[{idx}]" - f" is {out_fa}, when it should be {carry_fa}." - ) + if ctx.module_ctx.lowering_semantics == mgpu.LoweringSemantics.Lane: + for idx, (carry_fa, out_fa) in enumerate(zip(carry, loop_out)): + if _is_acc(carry_fa) != _is_acc(out_fa): + raise ValueError( + f"The loop body output has unexpected accumulator type:" + f" output[{idx}] is {out_fa}, when it should be {carry_fa}." + ) - if not _is_acc(out_fa) and carry_fa.layout != out_fa.layout: - raise ValueError( - f"The loop body output has unexpected layout: output[{idx}] has" - f" layout {out_fa.layout}, when it should be {carry_fa.layout}." - ) + if not _is_acc(out_fa) and carry_fa.layout != out_fa.layout: + raise ValueError( + f"The loop body output has unexpected layout: output[{idx}] has" + f" layout {out_fa.layout}, when it should be {carry_fa.layout}." + ) scf_dialect.yield_( carry_treedef.flatten_up_to(loop_out) if loop_out else [] ) return carry_treedef.unflatten(list(while_op.results)) -@register_lowering_rule(lax.cond_p, mgpu.ThreadSemantics.Lane) -@register_lowering_rule(lax.cond_p, mgpu.ThreadSemantics.Warpgroup) -def _cond_lowering_rule(ctx: LoweringRuleContext, index, *args, branches): +@register_lowering_rule(lax.cond_p, mgpu.LoweringSemantics.Lane) +@register_lowering_rule(lax.cond_p, + mgpu.LoweringSemantics.Lane, gpu_core.PrimitiveSemantics.Warp) +@register_lowering_rule(lax.cond_p, mgpu.LoweringSemantics.Warpgroup) +def _cond_lowering_rule(ctx: LoweringRuleContext, index, *args, branches, + **params): + if params: + raise NotImplementedError("platform_dependent cond") index_aval, *_arg_avals = ctx.avals_in def _yielded_values(outs, avals): ret = [] for out, aval in zip(outs, avals): - if isinstance(out, mgpu.FragmentedArray): + if isinstance(out, (mgpu.WGMMAAccumulator, mgpu.FragmentedArray)): ret.append(out) else: ret.append(_ensure_ir_value(out, aval.dtype)) return ret - # We need the branch return mlir types in order to construct the - # switch operation. 
To avoid leaking information about what kind of - # mlir types are internal to FragmentedArrays and other mgpu types, - # we run one of the branches in a dummy module that we throw away to - # extract the return types + # We need to know the result types ahead of time to construct the switch + # operation. Below we lower the first branch in a throw-away module to + # extract them. with ir.InsertionPoint(ir.Module.create().body): outs = lower_jaxpr_to_mosaic_gpu( ctx.module_ctx, ctx.launch_ctx, branches[0].jaxpr, args ) - yielded_types = [v.type for v in jax.tree.leaves(_yielded_values(outs, ctx.avals_out))] + yielded_types = [ + v.type for v in jax.tree.leaves(_yielded_values(outs, ctx.avals_out)) + ] del outs switch_op = scf_dialect.IndexSwitchOp( @@ -2080,9 +2844,9 @@ def _yielded_values(outs, avals): return treedef.unflatten(list(switch_op.results)) -@register_lowering_rule(lax.bitcast_convert_type_p, mgpu.ThreadSemantics.Lane) +@register_lowering_rule(lax.bitcast_convert_type_p, mgpu.LoweringSemantics.Lane) @register_lowering_rule( - lax.bitcast_convert_type_p, mgpu.ThreadSemantics.Warpgroup + lax.bitcast_convert_type_p, mgpu.LoweringSemantics.Warpgroup ) def _bitcast_convert_type_lowering_rule( ctx: LoweringRuleContext, x, *, new_dtype @@ -2098,7 +2862,7 @@ def _bitcast_convert_type_lowering_rule( " have different widths" ) - if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Warpgroup: + if ctx.module_ctx.lowering_semantics == mgpu.LoweringSemantics.Warpgroup: x = _ensure_ir_value(x, x_aval.dtype) return arith_dialect.bitcast( ir.VectorType.get(x_aval.shape, dst_elem_type), x @@ -2114,10 +2878,68 @@ def _bitcast_convert_type_lowering_rule( ) -@register_lowering_rule(lax.optimization_barrier_p, mgpu.ThreadSemantics.Lane) +@register_lowering_rule(lax.optimization_barrier_p, mgpu.LoweringSemantics.Lane) def _optimization_barrier_lowering(ctx: LoweringRuleContext, *args): - args = (_ensure_fa(arg, aval.dtype) for arg, aval in zip(args, ctx.avals_in)) - return mgpu.optimization_barrier(*args) + result = mgpu.optimization_barrier( + *(_ensure_fa(arg, aval.dtype) for arg, aval in zip(args, ctx.avals_in)) + ) + return (result,) if len(ctx.avals_in) == 1 else result + + +@register_lowering_rule( + lax.optimization_barrier_p, mgpu.LoweringSemantics.Warpgroup +) +def _optimization_barrier_lowering_wg(ctx: LoweringRuleContext, *args): + result = mgpu.dialect.optimization_barrier([ + _ensure_ir_value(arg, aval.dtype) for arg, aval in zip(args, ctx.avals_in) + ]) + return (result,) if len(ctx.avals_in) == 1 else result + + +@register_lowering_rule(pallas_core.core_map_p, mgpu.LoweringSemantics.Lane) +def _core_map_lowering_rule( + ctx: LoweringRuleContext, + *args, + jaxpr, + mesh, + **_, +): + if isinstance(mesh, gpu_core.WarpMesh): + # A core_map over a WarpMesh represents a fork/join over individual + # warps in a warpgroup. + if (ctx.module_ctx.warp_axis_name or + ctx.module_ctx.primitive_semantics == gpu_core.PrimitiveSemantics.Warp): + raise LoweringError( + "Cannot nest core_maps. Already under core_map with warp_axis_name " + f"{ctx.module_ctx.warp_axis_name}.") + module_ctx = dataclasses.replace( + ctx.module_ctx, + warp_axis_name=mesh.axis_name, + primitive_semantics=gpu_core.PrimitiveSemantics.Warp, + ) + for aval_in in ctx.avals_in: + if isinstance(aval_in, jax_core.ShapedArray) and aval_in.shape: + raise LoweringError( + "Can only close over scalars and Refs when using core_map with " + f"WarpMesh. Found array of shape {aval_in}." 
+ ) + # We allow the warps to schedule async copies without synchronizing with + # other warps, so we need to add a barrier here to make sure all reads and + # writes have completed. + if ctx.module_ctx.auto_barriers: + mgpu.warpgroup_barrier() + _ = lower_jaxpr_to_mosaic_gpu( + module_ctx, + ctx.launch_ctx, + jaxpr, + args=(), + consts=args, + ) + if ctx.module_ctx.auto_barriers: + # TODO(apaszke,justinfu): Do we really need this barrier? + mgpu.warpgroup_barrier() + return [] + raise ValueError(f"Unsupported mesh: {mesh}") def _bcast( @@ -2313,3 +3135,99 @@ def _ensure_idx_fa(x): shape=root_shape, int_indexer_shape=(), ) + + +@register_lowering_rule(primitives.semaphore_read_p, mgpu.LoweringSemantics.Lane) +def _semaphore_read_lowering_rule(ctx: LoweringRuleContext, *args, args_tree): + sem, transforms = tree_util.tree_unflatten(args_tree, args) + sem, transforms = _handle_transforms(ctx, sem, transforms) + if transforms: + raise NotImplementedError(f"Unhandled transforms for semaphore_read: {transforms}") + sem_ptr = mgpu.utils.memref_ptr(sem) + i32_ty = ir.IntegerType.get_signless(32) + return llvm_dialect.inline_asm( + i32_ty, [sem_ptr], "ld.acquire.sys.u32 $0,[$1];", "=r,l", has_side_effects=True, + ) + + +@register_lowering_rule(primitives.semaphore_signal_p, mgpu.LoweringSemantics.Lane) +def _semaphore_signal_lowering_rule( + ctx: LoweringRuleContext, + *args, + args_tree, + device_id_type, +): + i32 = ir.IntegerType.get_signless(32) + sem, transforms, value, device_id, core_index = tree_util.tree_unflatten( + args_tree, args + ) + if core_index is not None: + raise NotImplementedError( + "Mosaic GPU backend does not support the concept of cores, but" + " core_index is specified" + ) + sem, transforms = _handle_transforms(ctx, sem, transforms) + if transforms: + raise NotImplementedError(f"Unhandled transforms for semaphore_signal: {transforms}") + sem_ptr = mgpu.utils.memref_ptr(sem) + if device_id is not None: + if device_id_type != primitives.DeviceIdType.LOGICAL: + raise NotImplementedError( + f"Unsupported device id type: {device_id_type}" + ) + sem_ptr = ctx.launch_ctx.to_remote( + sem_ptr, _ensure_ir_value(device_id, jnp.int32) + ) + # TODO(apaszke): Narrow the scope from .sys to .gpu when the semaphore is local. + val = _ir_constant(value, i32) + # We only signal the semaphore from a single lane, which does not guarantee + # anything about the state of the other three warps in the warpgroup (they + # might still be e.g. reading memory that someone will overwrite once they + # receive a signal). + if ctx.module_ctx.auto_barriers: + mgpu.utils.warpgroup_barrier() + mgpu_utils.SemaphoreRef(sem_ptr).signal( + val, predicate=ctx.module_ctx.single_wg_lane_predicate + ) + return () + + +@register_lowering_rule(primitives.semaphore_wait_p, mgpu.LoweringSemantics.Lane) +def _semaphore_wait_lowering_rule(ctx: LoweringRuleContext, *args, args_tree): + sem, transforms, value = tree_util.tree_unflatten(args_tree, args) + sem, transforms = _handle_transforms(ctx, sem, transforms) + if transforms: + raise NotImplementedError( + f"Unhandled transforms for semaphore_wait: {transforms}" + ) + i32 = ir.IntegerType.get_signless(32) + val = _ir_constant(value, i32) + mgpu_utils.SemaphoreRef(mgpu.utils.memref_ptr(sem)).wait(val) + return () + + +@register_lowering_rule(checkify.check_p, mgpu.LoweringSemantics.Lane) +def _check_lowering_rule(ctx: LoweringRuleContext, *err_args, err_tree, debug): + del ctx # Unused. 
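For context, the ``WarpMesh`` handling earlier in this hunk is what backs warp-level ``core_map``. Below is a hypothetical usage sketch; the ``plgpu.WarpMesh`` constructor, the ``"warp"`` axis name, and the import paths are assumptions inferred from the names in this change, not verbatim API from it.

```python
# Hypothetical sketch of a warp-level fork/join inside a Mosaic GPU kernel,
# based on the WarpMesh lowering above. plgpu.WarpMesh(axis_name=...) and the
# decorator form of pl.core_map are assumptions.
from jax import lax
from jax.experimental import pallas as pl
from jax.experimental.pallas import mosaic_gpu as plgpu

def kernel(x_ref, o_ref):
  @pl.core_map(plgpu.WarpMesh(axis_name="warp"))
  def _per_warp():
    warp_id = lax.axis_index("warp")

    @pl.when(warp_id == 0)
    def _():
      # Only scalars and refs may be closed over here, per the rule above.
      o_ref[...] = x_ref[...] + 1.0
```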
+ + if not debug: + raise NotImplementedError( + "Non-debug checks are not supported by the Mosaic GPU backend." + " Functionalize them via `jax.experimental.checkify`." + ) + if not pallas_helpers.debug_checks_enabled(): + return [] + + error = jax.tree.unflatten(err_tree, err_args) + [pred] = error._pred.values() + [exception_tree] = error._metadata.values() + [payload] = error._payload.values() + exception = jax.tree.unflatten(exception_tree, payload) + assert isinstance(exception, checkify.FailedCheckError) + + # check_p has an inverted predicate compared to assert, so we need to compute + # ``not pred`` here. + minus_one = _ir_constant(-1, mgpu_utils.dtype_to_ir_type(jnp.bool)) + not_pred = arith_dialect.xori(pred.registers.item(), minus_one) + cf_dialect.assert_(not_pred, exception.fmt_string) + return [] diff --git a/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py b/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py index d506349fe101..1d55a6e862a0 100644 --- a/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py +++ b/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py @@ -19,15 +19,20 @@ import os import time -from typing import Any +from typing import cast import warnings import jax +from jax import lax +from jax._src import config from jax._src import core as jax_core +from jax._src import sharding_impls from jax._src.interpreters import mlir from jax._src.pallas import core as pallas_core +from jax._src.pallas.mosaic_gpu import core as gpu_core from jax._src.pallas.mosaic_gpu import lowering -import jax.experimental.mosaic.gpu.core as mosaic_core +from jax.experimental.mosaic import gpu as mgpu +import numpy as np def pallas_call_lowering( @@ -39,7 +44,7 @@ def pallas_call_lowering( input_output_aliases: tuple[tuple[int, int], ...], grid_mapping: pallas_core.GridMapping, mesh: pallas_core.Mesh | None, - compiler_params: dict[str, Any], + compiler_params: dict[str, pallas_core.CompilerParams], cost_estimate: pallas_core.CostEstimate | None, out_avals: tuple[jax_core.AbstractValue, ...], ): @@ -56,33 +61,57 @@ def pallas_call_lowering( print(f"The grid mapping for pallas_call {debug_info.func_src_info}:") print(grid_mapping) - thread_semantics = compiler_params.get("mosaic_gpu", {}).get( - "thread_semantics", mosaic_core.ThreadSemantics.Lane - ) - if thread_semantics == mosaic_core.ThreadSemantics.Warpgroup: - mosaic_core.dialect.register_dialect(ctx.module_context.context) # pytype: disable=attribute-error + mgpu.dialect.register_dialect(ctx.module_context.context) # pytype: disable=attribute-error - lowering_result = lowering.lower_pipelined_jaxpr_to_module( - grid_mapping, - mesh, - jaxpr, - compiler_params, - cost_estimate, - ) + if "mosaic_gpu" in compiler_params: + params = cast(gpu_core.CompilerParams, compiler_params["mosaic_gpu"]) + else: + params = gpu_core.CompilerParams() + + jax_mesh = None + axis_context = ctx.module_context.axis_context + if axis_context is not None: + if isinstance(axis_context, sharding_impls.SPMDAxisContext): + jax_mesh = axis_context.mesh + + # TODO(slebedev): Remove this once the ensure-debug-info-scope-on-llvm-func + # pass correctly handles full tracebacks. 
+ with config.include_full_tracebacks_in_locations(False): + lowering_result = lowering.lower_pipelined_jaxpr_to_module( + grid_mapping, mesh, jax_mesh, jaxpr, params, cost_estimate + ) if debug: print(f"\nThe Mosaic GPU module for pallas_call {debug_info.func_src_info}:") print(lowering_result.module.operation) module = lowering_result.module - new_avals_out = [ - jax_core.ShapedArray(t.shape, t.dtype) for t in lowering_result.out_structs - ] - outs = mosaic_core._mosaic_gpu_lowering_rule( - ctx.replace(avals_out=new_avals_out), - *args, + new_avals_in = list(ctx.avals_in) + new_avals_out = list(map(_as_shaped_array, lowering_result.new_out_shapes)) + scratch_args = () + if lowering_result.gmem_scratch_shapes: + # The new_out_shapes contain the original outputs first, followed by the + # GMEM scratch shapes, and optionally the profiler buffer. + input_output_aliases += tuple( + (len(ctx.avals_in) + i, len(ctx.avals_out) + i) + for i in range(len(lowering_result.gmem_scratch_shapes)) + ) + # The GMEM scratch is an aliased kernel input/output. + new_avals_in.extend(map(_as_shaped_array, lowering_result.gmem_scratch_shapes)) + # We guarantee zero-initialization of the GMEM scratch at the moment, which + # is important for semaphores. + def zero_init_gmem_scratch(): + return [lax.zeros_like_array(s) for s in lowering_result.gmem_scratch_shapes] + scratch_args = mlir.lower_fun( + zero_init_gmem_scratch, multiple_results=True + )(ctx.replace(avals_in=())) + outs = mgpu.core._mosaic_gpu_lowering_rule( + ctx.replace(avals_in=new_avals_in, avals_out=new_avals_out), + *args, *scratch_args, module=module, - out_types=lowering_result.out_structs, + out_types=lowering_result.new_out_shapes, + inout_types=(), input_output_aliases=input_output_aliases, + use_custom_barrier=False, # False until we add get_barrier_semaphore() feature ) if (prof_ctx := lowering_result.profiler_context) is not None: *outs, prof_buffer = outs @@ -111,4 +140,10 @@ def do_callback(prof_buffer): mlir.lower_fun(do_callback, multiple_results=True)( ctx.replace(avals_in=(new_avals_out[-1],)), prof_buffer ) + if lowering_result.gmem_scratch_shapes: # Drop the GMEM scratch. 
+ outs = outs[:-len(lowering_result.gmem_scratch_shapes)] return outs + + +def _as_shaped_array(t: jax.ShapeDtypeStruct) -> jax_core.ShapedArray: + return jax_core.ShapedArray(t.shape, np.dtype(t.dtype)) diff --git a/jax/_src/pallas/mosaic_gpu/pipeline.py b/jax/_src/pallas/mosaic_gpu/pipeline.py index a48fec61b7af..3ed17a085b8e 100644 --- a/jax/_src/pallas/mosaic_gpu/pipeline.py +++ b/jax/_src/pallas/mosaic_gpu/pipeline.py @@ -21,7 +21,7 @@ import functools import itertools as it import math -from typing import Any +from typing import Any, Protocol, TypeVar, cast import jax from jax import api_util @@ -33,14 +33,28 @@ from jax._src.pallas import core as pallas_core from jax._src.pallas.mosaic_gpu import core as gpu_core from jax._src.pallas.mosaic_gpu import primitives as gpu_primitives -from jax._src.util import foreach from jax.experimental import pallas as pl import jax.numpy as jnp map = util.safe_map zip = util.safe_zip +T = TypeVar('T') +def _get_block_size( + bd: pl.Blocked | pl.Element | pl.Squeezed | pl.BoundedSlice | int | None, +) -> int: + match bd: + case int(): + return bd + case pl.Blocked(block_size): + return block_size + case _: + raise NotImplementedError(f"Unsupported block size type: {type(bd)}") + +def _get_block_shape(spec: pallas_core.BlockSpec): + assert spec.block_shape is not None + return tuple(_get_block_size(bd) for bd in spec.block_shape) @jax.tree_util.register_dataclass @dataclasses.dataclass(frozen=True) @@ -65,10 +79,11 @@ def compute_gmem_slice(self, grid_indices) -> tuple[pl.Slice, ...]: # We don't allow Python scalars here, because they are interpreted # differently depending on the x32/x64 mode. assert all(i.dtype == jnp.dtype(jnp.int32) for i in grid_indices) + sizes = _get_block_shape(self.spec) return tuple( pl.Slice(idx * size, size) # type: ignore[arg-type] for idx, size in zip( - index_map(*grid_indices), self.spec.block_shape # type: ignore[arg-type] + index_map(*grid_indices), sizes # type: ignore[arg-type] ) ) @@ -115,7 +130,7 @@ def _uses_arguments( def _is_index_invariant( - spec: pallas_core.BlockSpec, grid: pallas_core.StaticGrid + spec: pallas_core.BlockSpec, grid: pallas_core.TupleGrid ) -> bool: if (index_map := spec.index_map) is None: return True @@ -123,7 +138,7 @@ def _is_index_invariant( def _inc_grid_by_1( - indices: tuple[jax.Array, ...], grid: Sequence[int] + indices: tuple[jax.Array, ...], grid: pallas_core.TupleGrid ) -> tuple[jax.Array, ...]: next_indices = [] carry: bool | jax.Array = True @@ -160,48 +175,65 @@ def __eq__(self, other: _Slice) -> jax.Array: # type: ignore def emit_pipeline( - body: Callable[..., None], + body: Callable[..., T], *, - grid: pallas_core.StaticGrid, + grid: pallas_core.TupleGrid, in_specs: Sequence[pallas_core.BlockSpec] = (), out_specs: Sequence[pallas_core.BlockSpec] = (), max_concurrent_steps: int = 1, delay_release: int = 0, + init_carry: T | None = None, ): - """Creates a function to emit a manual pipeline within a Pallas kernel. + r"""Creates a function to emit a manual pipeline within a Pallas kernel. Args: - body: The pipeline body. - grid: The grid to use for the pipeline. - in_specs: The block specs for the inputs. - out_specs: The block specs for the outputs. - max_concurrent_steps: The maximum number of sequential stages that are - active concurrently. Defaults to 1. - delay_release: The number of steps to wait before reusing the input/output - references. Defaults to 0, and must be strictly smaller than - ``max_concurrent_steps``. 
Generally, you'll want to set it to 1 if you - don't await the WGMMA in the body. + body: The pipeline body function, which is called with + + - ``indices``: Tuple of current loop indices. + - ``*input_refs``: SMEM refs for inputs. + - ``*output_refs``: SMEM refs for outputs. + + If ``init_carry`` is provided, ``body`` receives an additional argument + ``carry`` -- the carry from the previous iteration. It must then return + the next carry value. + grid: The grid dimensions for the pipeline. + in_specs: A sequence of :class:`~jax.experimental.pallas.BlockSpec`\s + for inputs. + out_specs: A sequence of :class:`~jax.experimental.pallas.BlockSpec`\s + for outputs. + max_concurrent_steps: Maximum concurrently active pipeline stages. + delay_release: Number of steps to delay before reusing input/output + references. Must be ``< max_concurrent_steps``. Useful for hiding WGMMA + latency (typically set to 1). + init_carry: Optional initial carry. If provided, ``body`` handles + carry-over state between iterations, and the pipeline returns the + final carry. + + Returns: + A function that, when called with GMEM input and output refs, executes the + pipeline and returns the final carry value (if ``init_carry`` was used), + otherwise it returns None. """ - num_steps = math.prod(grid) - if max_concurrent_steps <= delay_release: raise ValueError( "max_concurrent_steps must be greater than delay_release, but" f" {max_concurrent_steps=}, {delay_release=}" ) + num_steps = math.prod(grid) + has_dynamic_grid = not isinstance(num_steps, int) + # Shrink ``max_concurrent_steps`` if the total number of steps is lower to # reduce the size of the refs allocated in SMEM. - if max_concurrent_steps > num_steps: - max_concurrent_steps = num_steps - delay_release = 0 # No need to delay anything. + if not has_dynamic_grid and max_concurrent_steps > num_steps: + max_concurrent_steps = cast(int, num_steps) def pipeline(*gmem_refs: pallas_core.AbstractMemoryRef): in_gmem_refs, out_gmem_refs = util.split_list(gmem_refs, [len(in_specs)]) in_smem_refs, out_smem_refs = util.split_list( [ gpu_core.SMEM( - (max_concurrent_steps, *spec.block_shape), # type: ignore + (max_concurrent_steps, *_get_block_shape(spec)), # type: ignore ref.dtype, transforms=tuple( t.batch(1) for t in getattr(spec, "transforms", ()) @@ -213,6 +245,7 @@ def pipeline(*gmem_refs: pallas_core.AbstractMemoryRef): ], [len(in_specs)], ) + num_arrivals = sum(map(_in_smem, in_specs)) return pl.run_scoped( functools.partial( scoped_pipeline, @@ -221,9 +254,11 @@ def pipeline(*gmem_refs: pallas_core.AbstractMemoryRef): ), in_smem_refs=in_smem_refs, out_smem_refs=out_smem_refs, - barrier_ref=gpu_core.Barrier( + barrier_ref=None + if num_arrivals == 0 + else gpu_core.Barrier( # TODO(slebedev): Change this to arrive only once. - sum(map(_in_smem, in_specs)), + num_arrivals=num_arrivals, num_barriers=max_concurrent_steps, ), ) @@ -244,21 +279,24 @@ def scoped_pipeline( ) ] - for step, indices in enumerate( - it.islice(it.product(*map(range, grid)), max_concurrent_steps) - ): - indices = tuple(map(lambda i: jnp.asarray(i, dtype=jnp.int32), indices)) - foreach(lambda bref: bref.copy_in(step, indices, barrier_ref), in_brefs) + # Initialize the pipeline. 
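Putting the new ``indices`` argument and ``init_carry`` together, a typical invocation of ``emit_pipeline`` looks roughly like the following. This is a hedged sketch: the block sizes, ref names, and surrounding kernel are illustrative and not part of this change.

```python
# Hypothetical sketch of emit_pipeline with the new carry support: the body
# receives the grid indices first and threads a carry through the steps.
from jax.experimental import pallas as pl
from jax.experimental.pallas import mosaic_gpu as plgpu

def body(indices, x_smem, o_smem, carry):
  del indices  # Unused in this sketch.
  o_smem[...] = x_smem[...] * 2.0
  return carry + 1  # e.g. count the number of executed steps.

def kernel(x_gmem, o_gmem):
  pipeline = plgpu.emit_pipeline(
      body,
      grid=(16,),
      in_specs=[pl.BlockSpec((128,), lambda i: i)],
      out_specs=[pl.BlockSpec((128,), lambda i: i)],
      max_concurrent_steps=2,
      init_carry=0,
  )
  steps_run = pipeline(x_gmem, o_gmem)  # Returns the final carry.
```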
+ indices = (jnp.asarray(0, dtype=jnp.int32),) * len(grid) + fetch_indices = indices + for step in range(max_concurrent_steps): + for bref in in_brefs: + bref.copy_in(step, fetch_indices, barrier_ref) + fetch_indices = _inc_grid_by_1(fetch_indices, grid) + del fetch_indices # This is true if any of the outputs need to be transferred inside the loop. copies_out_in_loop = not all(bref.is_index_invariant for bref in out_brefs) def loop_body(step, carry): slot = lax.rem(step, max_concurrent_steps) - indices, fetch_indices, last_store_slices = carry + indices, fetch_indices, last_store_slices, prev_body_carry = carry - if in_specs: - # Wait for the current GMEM->SMEM copy to complete. + if barrier_ref is not None: + # Wait for the current GMEM->SMEM copy to complete, if any. gpu_primitives.barrier_wait(barrier_ref.at[slot]) # Wait for the previous output SMEM->GMEM copy to complete. if copies_out_in_loop: @@ -266,11 +304,14 @@ def loop_body(step, carry): max_concurrent_steps - (1 + delay_release), wait_read_only=True ) - with pallas_core.grid_env(map(pallas_core.GridAxis, indices, grid)): - body(*( - bref.get_ref_for_slot(slot) - for bref in it.chain(in_brefs, out_brefs) - )) + next_body_carry = body( + indices, + *( + bref.get_ref_for_slot(slot) + for bref in it.chain(in_brefs, out_brefs) + ), + *(prev_body_carry,) if init_carry is not None else (), + ) if copies_out_in_loop: gpu_primitives.commit_smem() @@ -301,7 +342,8 @@ def loop_body(step, carry): predicate=lax.bitwise_or(slices_changed, is_last_step), ) - gpu_primitives.commit_smem_to_gmem_group() + if copies_out_in_loop: + gpu_primitives.commit_smem_to_gmem_group() fetch_step = step + (max_concurrent_steps - delay_release) fetch_slot = lax.rem(fetch_step, max_concurrent_steps) @@ -320,11 +362,11 @@ def do_fetch(): _inc_grid_by_1(indices, grid), _inc_grid_by_1(fetch_indices, grid), new_store_slices, + next_body_carry if init_carry is not None else None, ) # Invariant: ``indices`` and ``fetch_indices`` are always # ``max_concurrent_steps-delay_release`` apart. - indices = (jnp.asarray(0, dtype=jnp.int32),) * len(grid) fetch_indices = indices for _ in range(max_concurrent_steps-delay_release): fetch_indices = _inc_grid_by_1(fetch_indices, grid) @@ -335,14 +377,18 @@ def do_fetch(): else (_Slice(-1, -1),) * len(bref.spec.block_shape) for bref in out_brefs ] - last_indices, _, _ = lax.fori_loop( - 0, num_steps, loop_body, (indices, fetch_indices, last_store_slices) + last_indices, _, _, final_carry = lax.fori_loop( + 0, + num_steps, + loop_body, + (indices, fetch_indices, last_store_slices, init_carry), ) # Outputs invariant to the sequential axis are never written from inside the # loop. This is the only place where we store them. if not copies_out_in_loop: gpu_primitives.commit_smem() + last_slot = lax.rem(num_steps - 1, max_concurrent_steps) for bref in out_brefs: if bref.is_index_invariant: @@ -352,21 +398,51 @@ def do_fetch(): # Finalize the pipeline. gpu_primitives.wait_smem_to_gmem(0) + return final_carry if init_carry is not None else None return pipeline + +class ComputeContext(Protocol): + """Protocol for a compute context for the warp specialized pipeline. + + The ComputeContext is run exclusively in the compute thread and allows + the user to set up a prologue to initialize a pipeline carry and an epilogue + to consume the final carry. + + All values allocated in the ComputeContext will only be allocated in the + compute thread and not the memory thread. 
This can potentially reduce + register pressure if certain values are only consumed by the compute threads. + + Usage will usually follow this structure: + + ``` + def compute_context(pipeline): + # Perform prologue work and compute the initial carry. + initial_carry = ... + # Run the pipeline. + final_carry = pipeline(*initial_carry) + # Perform epilogue work using the final carry. + do_work(final_carry) + ``` + + """ + def __call__(self, pipeline: Callable[[T], T]) -> None: + ... + + def emit_pipeline_warp_specialized( body: Callable[..., None], *, - grid: pallas_core.StaticGrid, + grid: pallas_core.TupleGrid, memory_registers: int, - in_specs: Sequence[gpu_core.GPUBlockSpec] = (), - out_specs: Sequence[gpu_core.GPUBlockSpec] = (), + in_specs: Sequence[pl.BlockSpec] = (), + out_specs: Sequence[pl.BlockSpec] = (), max_concurrent_steps: int = 2, wg_axis: str, num_compute_wgs: int, manual_consumed_barriers: bool = False, - carry_coroutine: Any | None = None, + compute_context: ComputeContext | None = None, memory_thread_idx: int | None = None, ): """Creates a function to emit a warp-specialized pipeline. @@ -376,14 +452,16 @@ def emit_pipeline_warp_specialized( ``manual_consumed_barriers`` argument is True. ``` - def body(*input_refs, *output_refs, [consumed_barriers]) -> None: + def body(indices, *input_refs, *output_refs, [consumed_barriers]) -> None: ``` - or with a carries enabled (enabled via the ``carry_coroutine`` argument), + or with a carries enabled (enabled via the ``compute_context`` argument), where the body returns the next carry: ``` - def body(*input_refs, *output_refs, [consumed_barriers], carry) -> Carry: + def body( + indices, *input_refs, *output_refs, [consumed_barriers], carry + ) -> Carry: ``` Args: @@ -400,11 +478,15 @@ def body(*input_refs, *output_refs, [consumed_barriers], carry) -> Carry: manual_consumed_barriers: If True, consumed barriers will be passed into the body function after the output refs. There will be one barrier per input and will be passed in the same order. - carry_coroutine: If specified, enables carries in the pipeline. - The signature of the body function will be modified such that the last - argument will be the current carry and it must return the next carry. - The coroutine itself should yield the initial carry, and the - yield statement will return the final value of the carry. + compute_context: If specified, enables carries in the pipeline and allows + a user-specified prologue/epilogue that is only executed in the compute + thread. The signature of the pipeline body function will be modified + such that the last argument will be the current carry and it must + return the next carry. + The compute_context itself should follow the signature of `ComputeContext` + and take a pipeline function as its sole argument. Calling the + pipeline with the initial carry will run the pipeline and return the + final carry. memory_thread_idx: The index of the memory thread. If not specified, defaults to the last thread. """ @@ -418,17 +500,18 @@ def body(*input_refs, *output_refs, [consumed_barriers], carry) -> Carry: # thread is the last thread. raise NotImplementedError("Memory thread must be the last thread.") - has_carry = carry_coroutine is not None + has_carry = compute_context is not None # Trace the index maps to determine if they depend on the grid. # Grid-independent values will not be multiple-buffered. 
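For illustration, a hypothetical `compute_context` wired into `emit_pipeline_warp_specialized`; `grid`, the specs, the GMEM refs and the `plgpu` alias are placeholders for the example.

```
def body(indices, x_smem, o_smem, carry):
  # Called once per grid step in the compute thread; returns the next carry.
  o_smem[...] = x_smem[...] + carry
  return carry + 1.0

def compute_context(pipeline):
  final = pipeline(jnp.float32(0.0))  # Prologue builds the initial carry.
  _ = final  # Epilogue work consuming the final carry would go here.

run = plgpu.emit_pipeline_warp_specialized(
    body,
    grid=grid,
    in_specs=in_specs,
    out_specs=out_specs,
    memory_registers=40,
    num_compute_wgs=1,
    wg_axis="wg",
    compute_context=compute_context,
)
run(x_gmem, o_gmem)
```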
in_spec_has_seq_axis = [ - ~_is_index_invariant(spec, grid) for spec in in_specs] + not _is_index_invariant(spec, grid) for spec in in_specs] out_spec_has_seq_axis = [ - ~_is_index_invariant(spec, grid) for spec in out_specs] + not _is_index_invariant(spec, grid) for spec in out_specs] spec_has_seq_axis = [*in_spec_has_seq_axis, *out_spec_has_seq_axis] - num_pipeline_steps = math.prod(grid) + num_steps = math.prod(grid) + has_dynamic_grid = not isinstance(num_steps, int) def _get_slot(step, has_seq_dim): """Returns the buffer slot given the pipeline step.""" @@ -439,8 +522,8 @@ def _get_slot(step, has_seq_dim): # Shrink ``max_concurrent_steps`` if the total number of steps is lower to # reduce the size of the refs allocated in SMEM. - if max_concurrent_steps > num_pipeline_steps: - max_concurrent_steps = num_pipeline_steps + if not has_dynamic_grid and max_concurrent_steps > num_steps: + max_concurrent_steps = cast(int, num_steps) def pipeline(*gmem_refs: pallas_core.AbstractMemoryRef): in_gmem_refs, out_gmem_refs = util.split_list(gmem_refs, [len(in_specs)]) @@ -458,7 +541,7 @@ def pipeline(*gmem_refs: pallas_core.AbstractMemoryRef): gpu_core.SMEM( (slots, *spec.block_shape), # type: ignore gmem_ref.dtype, - transforms=spec.transforms, + transforms=getattr(spec, "transforms", ()), ) ) in_smem_refs, out_smem_refs = util.split_list( @@ -498,6 +581,7 @@ def pipeline(*gmem_refs: pallas_core.AbstractMemoryRef): out_smem_refs=out_smem_refs, in_smem_barrier_refs=in_smem_barriers, consumed_barrier_refs=consumed_barriers, + collective_axes=wg_axis, ) def scoped_pipeline( @@ -510,13 +594,13 @@ def scoped_pipeline( consumed_barrier_refs, ): in_brefs: Sequence[BufferedRef] = [ - BufferedRef(spec, ~has_seq_axis, gmem_ref, smem_ref) + BufferedRef(spec, not has_seq_axis, gmem_ref, smem_ref) for spec, has_seq_axis, gmem_ref, smem_ref in zip( in_specs, in_spec_has_seq_axis, in_gmem_refs, in_smem_refs ) ] out_brefs: Sequence[BufferedRef] = [ - BufferedRef(spec, ~has_seq_axis, gmem_ref, smem_ref) + BufferedRef(spec, not has_seq_axis, gmem_ref, smem_ref) for spec, has_seq_axis, gmem_ref, smem_ref in zip( out_specs, out_spec_has_seq_axis, out_gmem_refs, out_smem_refs ) @@ -545,18 +629,17 @@ def compute_loop_body(step, carry): if copies_out_in_loop: gpu_primitives.wait_smem_to_gmem(max_concurrent_steps - 1) - with pallas_core.grid_env(map(pallas_core.GridAxis, indices, grid)): - body_refs = [] - for bref in it.chain(in_brefs, out_brefs): - buf_slot = _get_slot(slot, ~bref.is_index_invariant) - body_refs.append(bref.get_ref_for_slot(buf_slot)) + body_refs = [] + for bref in it.chain(in_brefs, out_brefs): + buf_slot = _get_slot(slot, not bref.is_index_invariant) + body_refs.append(bref.get_ref_for_slot(buf_slot)) - body_args = body_refs - if manual_consumed_barriers: - body_args += [consumed_barrier_ref.at[slot] for consumed_barrier_ref in consumed_barrier_refs] - if has_carry: - body_args += [prev_body_carry] - next_body_carry = body(*body_args) + body_args = body_refs + if manual_consumed_barriers: + body_args += [consumed_barrier_ref.at[slot] for consumed_barrier_ref in consumed_barrier_refs] + if has_carry: + body_args += [prev_body_carry] + next_body_carry = body(indices, *body_args) if not manual_consumed_barriers: [consumed_barrier_ref] = consumed_barrier_refs @@ -581,7 +664,7 @@ def compute_loop_body(step, carry): new_store_slices[idx], ) slices_changed = ~functools.reduce(lax.bitwise_and, are_same_slices) - bref.copy_out(_get_slot(slot, ~bref.is_index_invariant), + bref.copy_out(_get_slot(slot, not 
bref.is_index_invariant), indices, predicate=slices_changed) gpu_primitives.commit_smem_to_gmem_group() @@ -597,34 +680,37 @@ def compute_loop_body(step, carry): ] if has_carry: - _carry = carry_coroutine() - try: - carry_init = next(_carry) - except StopIteration: - raise ValueError("carry_coroutine must yield the initial carry.") # pylint: disable=raise-missing-from + last_indices = None + def pipeline_callback(user_init_carry): + nonlocal last_indices + if last_indices is not None: + raise ValueError( + "Cannot call pipeline more than once in `compute_context`") + init_loop_carry = (init_indices, last_store_slices, user_init_carry) + last_indices, _, final_body_carry = lax.fori_loop(0, + num_steps, + compute_loop_body, + init_loop_carry) + return final_body_carry + compute_context(pipeline_callback) + if last_indices is None: + raise ValueError("Pipeline was not called in `compute_context`") else: - _carry = None - carry_init = None - init_loop_carry = (init_indices, last_store_slices, carry_init) - last_indices, _, final_body_carry = lax.fori_loop(0, - num_pipeline_steps, - compute_loop_body, - init_loop_carry) - if has_carry: - try: - _carry.send(final_body_carry) # pytype: disable=attribute-error - raise ValueError("carry_coroutine must only yield once.") - except StopIteration: - pass + assert compute_context is None + last_indices, _, _ = lax.fori_loop( + 0, num_steps, compute_loop_body, + (init_indices, last_store_slices, None) + ) # Handle index_invariant outputs after the loop. They are not # written in the main pipeline loop. if not copies_out_in_loop: gpu_primitives.commit_smem() - last_slot = lax.rem(num_pipeline_steps - 1, max_concurrent_steps) + last_slot = lax.rem(num_steps - 1, max_concurrent_steps) for bref in out_brefs: if bref.is_index_invariant: - bref.copy_out(last_slot, last_indices, predicate=None) + bref.copy_out(_get_slot(last_slot, has_seq_dim=False), + last_indices, predicate=None) gpu_primitives.commit_smem_to_gmem_group() @@ -635,13 +721,22 @@ def compute_loop_body(step, carry): def memory_block(): gpu_primitives.set_max_registers(memory_registers, action="decrease") indices = (jnp.asarray(0, dtype=jnp.int32),) * len(grid) + if has_dynamic_grid: + prologue_steps = lax.min(max_concurrent_steps, num_steps) + else: + assert max_concurrent_steps <= num_steps + prologue_steps = max_concurrent_steps # Begin initial copies. - for step in range(max_concurrent_steps): + def _init_step(step, indices): for bref, barrier in zip(in_brefs, in_smem_barrier_refs): - buf_slot = _get_slot(step, ~bref.is_index_invariant) + buf_slot = _get_slot(step, not bref.is_index_invariant) bref.copy_in(buf_slot, indices, barrier) - indices = _inc_grid_by_1(indices, grid) + return _inc_grid_by_1(indices, grid) + + indices = jax.lax.fori_loop( + 0, prologue_steps, _init_step, indices, unroll=not has_dynamic_grid + ) def memory_loop_body(step, carry): indices, = carry @@ -662,11 +757,17 @@ def memory_loop_body(step, carry): if manual_consumed_barriers: gpu_primitives.barrier_wait(consumed_barrier.at[slot]) # pytype: disable=attribute-error bref.copy_in( - _get_slot(fetch_slot, ~bref.is_index_invariant), indices, barrier) + _get_slot(fetch_slot, not bref.is_index_invariant), indices, barrier) next_indices = _inc_grid_by_1(indices, grid) return (next_indices,) - lax.fori_loop(0, num_pipeline_steps - max_concurrent_steps, + lax.fori_loop(0, num_steps - max_concurrent_steps, memory_loop_body, (indices,)) + # Await all the arrivals to not leave barriers in a bad state. 
+ # We only need to account for the prologue steps. + @pl.loop(0, prologue_steps, unroll=not has_dynamic_grid) + def _epi_step(step): + for barrier in consumed_barrier_refs: + gpu_primitives.barrier_wait(barrier.at[step]) wg_idx = lax.axis_index(wg_axis) lax.cond( @@ -680,8 +781,16 @@ def _compute_registers( memory_registers: int, num_compute_wgs: int, ) -> int: - """Returns the number of registers to use for the compute thread.""" - # TODO(justinfu): Configure this per-platform. - n_registers = (512 - memory_registers) / num_compute_wgs + """Returns the max number of registers to use in compute threads. + + We start with the theoretical max registers per thread if one wargroup + (128 threads) used the entire SM's 64k register file (64k / 128 = 512). + Then reserve `memory_registers` for the producer warpgroup and distribute + the remaining registers evenly among the compute warpgroups. + + Note: The maximum number of registers per thread is 255, so we clamp + the value. + """ + n_registers = min(256, (512 - memory_registers) / num_compute_wgs) # Round down to the nearest multiple of 8. return int((n_registers // 8) * 8) diff --git a/jax/_src/pallas/mosaic_gpu/primitives.py b/jax/_src/pallas/mosaic_gpu/primitives.py index 7f26f5d2b6a3..dbe24bb299fb 100644 --- a/jax/_src/pallas/mosaic_gpu/primitives.py +++ b/jax/_src/pallas/mosaic_gpu/primitives.py @@ -16,23 +16,29 @@ from __future__ import annotations -from collections.abc import Sequence +from collections.abc import Sequence, Callable import dataclasses import enum +import functools import itertools import math from typing import Any, Literal import jax from jax._src import core as jax_core +from jax._src import frozen_dict +from jax._src import pretty_printer as pp from jax._src import state from jax._src import tree_util from jax._src import util from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import arith as arith_dialect from jax._src.lib.mlir.dialects import llvm as llvm_dialect +from jax._src.lib.mlir.dialects import memref as memref_dialect +from jax._src.lib.mlir.dialects import gpu as gpu_dialect from jax._src.lib.mlir.dialects import nvvm as nvvm_dialect from jax._src.pallas import core as pallas_core +from jax._src.pallas import primitives as pallas_primitives from jax._src.pallas.mosaic_gpu import core as gpu_core from jax._src.pallas.mosaic_gpu import lowering from jax._src.pallas.mosaic_gpu.core import state_types @@ -41,9 +47,11 @@ from jax._src.state import primitives as state_primitives from jax.experimental.mosaic import gpu as mgpu from jax.experimental.mosaic.gpu import utils as mgpu_utils +from jax.experimental.mosaic.gpu import tcgen05 import jax.numpy as jnp +WARP_SIZE = 32 WARPGROUP_SIZE = 128 @@ -51,7 +59,7 @@ def _check_ref( - aval: object, name: str, memory_space: gpu_core.GPUMemorySpace + aval: object, name: str, memory_space: gpu_core.MemorySpace ) -> None: if not isinstance(aval, state_types.AbstractRef): raise TypeError(f"{name} must be a reference, got {aval}") @@ -62,6 +70,117 @@ def _check_ref( ) +load_p = jax_core.Primitive("load") + +@load_p.def_effectful_abstract_eval +def _load_abstract_eval(src, *avals_flat, args_tree, layout, optimized): + del layout, optimized # Unused. 
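A quick numeric check of the register budget computed by `_compute_registers` above, using hypothetical inputs:

```
# Reserve 40 registers for the memory warpgroup and split the rest over 2
# compute warpgroups: min(256, (512 - 40) / 2) == 236, which rounds down to
# a multiple of 8 as 232 registers per compute thread.
memory_registers, num_compute_wgs = 40, 2
n = min(256, (512 - memory_registers) / num_compute_wgs)
assert int((n // 8) * 8) == 232
```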
+ transforms = args_tree.unflatten(avals_flat) + dtype = lowering._transform_dtype(src.dtype, transforms) + return ( + jax_core.ShapedArray(transforms[-1].get_indexer_shape(), dtype), + {state.ReadEffect(0)}, + ) + +@lowering.register_lowering_rule(load_p, mgpu.LoweringSemantics.Lane) +def _load_p_lowering_rule( + ctx: lowering.LoweringRuleContext, x_ref, *leaves, args_tree, layout, optimized +): + if not isinstance(x_ref, ir.Value) or not ir.MemRefType.isinstance(x_ref.type): + raise TypeError(f"Can only load from references (got {x_ref}).") + + out_aval = ctx.avals_out[0] + + transforms = jax.tree.unflatten(args_tree, leaves) + x_ref, transforms = lowering._handle_transforms(ctx, x_ref, transforms) + + if layout is not None: + layout = layout.to_mgpu() + + is_signed = mgpu_utils.is_signed(out_aval.dtype) + match transforms: + case (gpu_core.UnswizzleRef(swizzle), gpu_core.UntileRef(tiling)): + if tiling != (8, swizzle // out_aval.dtype.itemsize): + raise NotImplementedError("Tiling does not fit swizzle") + return mgpu.FragmentedArray.load_tiled( + x_ref, + is_signed=is_signed, + swizzle=swizzle, + layout=layout, + ) + case (): + # Handle scalar indexing. + if not out_aval.shape: + is_signed = mgpu_utils.is_signed(out_aval.dtype) + val = memref_dialect.load(x_ref, []) + return mgpu.FragmentedArray.splat( + val, shape=(), layout=layout, is_signed=is_signed + ) + match layout: + case ( + mgpu.WGMMA_ROW_LAYOUT + | mgpu.WGMMA_COL_LAYOUT + | mgpu.TCGEN05_ROW_LAYOUT + | mgpu.TCGEN05_COL_LAYOUT + ): + return mgpu.FragmentedArray.load_untiled( + x_ref, + is_signed=is_signed, + layout=layout, + swizzle=16, + optimized=optimized, + ) + case mgpu.WGStridedFragLayout(shape=shape, vec_size=vec_size): + ref_ty = ir.MemRefType(x_ref.type) + if shape != tuple(ref_ty.shape): + raise ValueError( + f"Unsupported shape {shape}, (expected {tuple(ref_ty.shape)})" + ) + return mgpu.FragmentedArray.load_strided( + x_ref, is_signed=is_signed, vec_size=vec_size, + ) + case None: + return mgpu.FragmentedArray.load_strided(x_ref, is_signed=is_signed) + case _: + raise NotImplementedError(f"Unsupported layout: {layout}") + case _: + raise NotImplementedError(f"Unsupported transforms: {transforms}") + + +def load( + src: _Ref, + idx, + *, + layout: Layout | ParameterizedLayout | None = None, + optimized: bool = True, +) -> jax.Array: + """Loads from a reference into an array with the specified layout. + + Args: + src: The reference to load from. Can be either in SMEM or GMEM. + idx: The index to load from. + layout: The optional layout to use for the resulting array. + optimized: If True, a compilation error will be raised if no optimized + implementation for the load is available. + + Returns: + The loaded array. 
+ """ + src, src_transforms = state_primitives.get_ref_and_transforms( + src, idx, "load", force_trailing_indexer=True, + ) + flat_src_transforms, src_transforms_treedef = tree_util.tree_flatten( + src_transforms + ) + return load_p.bind( + src, + *flat_src_transforms, + args_tree=src_transforms_treedef, + layout=layout, + optimized=optimized, + ) + + copy_smem_to_gmem_p = jax_core.Primitive("copy_smem_to_gmem") copy_smem_to_gmem_p.multiple_results = True @@ -74,9 +193,48 @@ def _copy_smem_to_gmem_abstract_eval(src, dst, *args, **params): return (), {state.ReadEffect(0), state.WriteEffect(1)} -@lowering.register_lowering_rule(copy_smem_to_gmem_p, mgpu.ThreadSemantics.Lane) +def _copy_smem_to_gmem_pp_eqn( + eqn: jax_core.JaxprEqn, + context: jax_core.JaxprPpContext, + settings: jax_core.JaxprPpSettings, +): + src, dst, *flat_args = eqn.invars + src_transforms_treedef = eqn.params["src_transforms_treedef"] + dst_transforms_treedef = eqn.params["dst_transforms_treedef"] + pp_params = {} + if not (commit_group := eqn.params["commit_group"]): + pp_params["commit_group"] = commit_group + if eqn.params["has_user_predicate"]: + flat_args, user_predicate = flat_args[:-1], flat_args[-1] + pp_params["user_predicate"] = jax_core.pp_var(user_predicate, context) + if reduction_op := eqn.params["reduction_op"]: + pp_params["reduction_op"] = reduction_op + flat_src_transforms, flat_dst_transforms = util.split_list( + flat_args, + [src_transforms_treedef.num_leaves], + ) + src_transforms = src_transforms_treedef.unflatten(flat_src_transforms) + dst_transforms = dst_transforms_treedef.unflatten(flat_dst_transforms) + return pp.concat([ + pp.text("copy_smem_to_gmem"), + jax_core.pp_kv_pairs(pp_params.items(), context, settings), + pp.text(" "), + state_primitives.pp_ref_transforms(context, src, src_transforms), + pp.text(" -> "), + state_primitives.pp_ref_transforms(context, dst, dst_transforms), + ]) + + +jax_core.pp_eqn_rules[copy_smem_to_gmem_p] = _copy_smem_to_gmem_pp_eqn + + @lowering.register_lowering_rule( - copy_smem_to_gmem_p, mgpu.ThreadSemantics.Warpgroup + copy_smem_to_gmem_p, mgpu.LoweringSemantics.Lane) +@lowering.register_lowering_rule( + copy_smem_to_gmem_p, mgpu.LoweringSemantics.Lane, + primitive_semantics=gpu_core.PrimitiveSemantics.Warp) +@lowering.register_lowering_rule( + copy_smem_to_gmem_p, mgpu.LoweringSemantics.Warpgroup ) def _copy_smem_to_gmem_lowering( ctx: lowering.LoweringRuleContext, @@ -87,27 +245,40 @@ def _copy_smem_to_gmem_lowering( dst_transforms_treedef, has_user_predicate, commit_group, + reduction_op, ): - predicate = ctx.module_ctx.single_wg_lane_predicate if has_user_predicate: flat_args, user_predicate = flat_args[:-1], flat_args[-1] - predicate = arith_dialect.andi( - predicate, lowering._ensure_ir_value(user_predicate, jnp.bool) - ) + predicate = lowering._ensure_ir_value(user_predicate, jnp.bool) + else: + predicate = None + + if ctx.module_ctx.lowering_semantics == mgpu.LoweringSemantics.Lane: + if predicate is not None: + assert ctx.module_ctx.single_lane_predicate is not None + predicate = arith_dialect.andi( + predicate, ctx.module_ctx.single_lane_predicate + ) + else: + predicate = ctx.module_ctx.single_lane_predicate + flat_src_transforms, flat_dst_transforms = util.split_list( flat_args, [src_transforms_treedef.num_leaves], ) src_transforms = src_transforms_treedef.unflatten(flat_src_transforms) dst_transforms = dst_transforms_treedef.unflatten(flat_dst_transforms) - src, src_transforms = lowering._handle_indexing(src, src_transforms) + src, src_transforms 
= lowering._handle_transforms( + ctx, src, src_transforms, handle_transposes=False + ) copy_params = _extract_gmem_copy_params(dst_transforms) | _extract_smem_copy_params(src_transforms) - if ctx.module_ctx.thread_semantics == mgpu.ThreadSemantics.Lane: + if ctx.module_ctx.lowering_semantics == mgpu.LoweringSemantics.Lane: ctx.launch_ctx.async_copy( src_ref=src, dst_ref=dst, predicate=predicate, arrive=commit_group, + reduction_op=reduction_op, **copy_params, ) return () @@ -119,6 +290,10 @@ def _copy_smem_to_gmem_lowering( else: indices, slice_lengths = _split_gmem_slice(copy_params["gmem_slice"]) assert copy_params.get("swizzle") is None + if copy_params.get("gmem_peer_id", None) is not None: + raise NotImplementedError( + "GMEM refs with peer ids are not supported in warpgroup lowering." + ) assert not copy_params.get("gmem_transform") mgpu.dialect.async_store( src, @@ -154,13 +329,25 @@ def _split_gmem_slice(gmem_slice): def _extract_gmem_copy_params(transforms): if not transforms: return {} + peer_id = None + indexers = [] for transform in transforms: - if not isinstance(transform, indexing.NDIndexer): + if isinstance(transform, gpu_core.PeerMemRef): + if transform.device_id_type != pallas_primitives.DeviceIdType.LOGICAL: + raise NotImplementedError( + "Only logical device ids are supported for GMEM refs." + ) + peer_id = lowering._ensure_ir_value(transform.device_id, jnp.int32) + continue + elif isinstance(transform, indexing.NDIndexer): + indexers.append(transform) + else: raise NotImplementedError( "Non-indexing transforms on GMEM refs are not implemented.") - indexer = lowering.merge_indexers(transforms) + indexer = lowering.merge_indexers(indexers) return dict( gmem_slice=lowering._ndindexer_indices(indexer), + gmem_peer_id=peer_id, ) @@ -186,6 +373,7 @@ def copy_smem_to_gmem( predicate: jax.Array | None = None, *, commit_group: bool = True, + reduction_op: mgpu.ReductionOp | None = None, ) -> None: """Asynchronously copies a SMEM reference to a GMEM reference. @@ -194,9 +382,12 @@ def copy_smem_to_gmem( dst: The GMEM reference to copy to. predicate: A boolean indicating whether the copy should be performed. If ``None``, the copy is always performed. - commit_group: If ``True``, this and any previously uncommitted copies - are committed to a group and can be awaited jointly via + commit_group: If ``True``, this and any previously uncommitted copies are + committed to a group and can be awaited jointly via :func:`jax.experimental.mosaic.gpu.wait_smem_to_gmem`. + reduction_op: If set, perform the specified reduction operation when storing + to GMEM. For example, using ``"add"`` is conceptually equivalent to + doing ``src += dst``. 
See also: :func:`jax.experimental.mosaic.gpu.wait_smem_to_gmem` @@ -224,6 +415,7 @@ def copy_smem_to_gmem( dst_transforms_treedef=dst_transforms_treedef, has_user_predicate=predicate is not None, commit_group=commit_group, + reduction_op=reduction_op, ) return None @@ -241,9 +433,51 @@ def _copy_gmem_to_smem_abstract_eval(src, dst, barrier, *args, **params): return (), {state.ReadEffect(0), state.WriteEffect(1)} -@lowering.register_lowering_rule(copy_gmem_to_smem_p, mgpu.ThreadSemantics.Lane) +def _copy_gmem_to_smem_pp_eqn( + eqn: jax_core.JaxprEqn, + context: jax_core.JaxprPpContext, + settings: jax_core.JaxprPpSettings, +): + src, dst, barrier, *flat_args = eqn.invars + src_transforms_treedef = eqn.params["src_transforms_treedef"] + dst_transforms_treedef = eqn.params["dst_transforms_treedef"] + barrier_transforms_treedef = eqn.params["barrier_transforms_treedef"] + pp_params = {} + if collective_axes := eqn.params["collective_axes"]: + pp_params["collective_axes"] = collective_axes + flat_src_transforms, flat_dst_transforms, flat_barrier_transforms = ( + util.split_list( + flat_args, + [ + src_transforms_treedef.num_leaves, + dst_transforms_treedef.num_leaves, + ], + ) + ) + src_transforms = src_transforms_treedef.unflatten(flat_src_transforms) + dst_transforms = dst_transforms_treedef.unflatten(flat_dst_transforms) + barrier_transforms = barrier_transforms_treedef.unflatten( + flat_barrier_transforms + ) + return pp.concat([ + pp.text("copy_gmem_to_smem"), + jax_core.pp_kv_pairs(pp_params.items(), context, settings), + pp.text(" "), + state_primitives.pp_ref_transforms(context, src, src_transforms), + pp.text(" -> "), + state_primitives.pp_ref_transforms(context, dst, dst_transforms), + pp.text(" using "), + state_primitives.pp_ref_transforms(context, barrier, barrier_transforms), + ]) + + +jax_core.pp_eqn_rules[copy_gmem_to_smem_p] = _copy_gmem_to_smem_pp_eqn + + +@lowering.register_lowering_rule( + copy_gmem_to_smem_p, mgpu.LoweringSemantics.Lane) @lowering.register_lowering_rule( - copy_gmem_to_smem_p, mgpu.ThreadSemantics.Warpgroup + copy_gmem_to_smem_p, mgpu.LoweringSemantics.Warpgroup ) def _copy_gmem_to_smem_lowering( ctx: lowering.LoweringRuleContext, @@ -254,6 +488,9 @@ def _copy_gmem_to_smem_lowering( src_transforms_treedef, dst_transforms_treedef, barrier_transforms_treedef, + collective_axes, + partitioned_axis, + for_warpgroup: bool = True, ): flat_src_transforms, flat_dst_transforms, flat_barrier_transforms = ( util.split_list( @@ -266,7 +503,9 @@ def _copy_gmem_to_smem_lowering( ) src_transforms = src_transforms_treedef.unflatten(flat_src_transforms) dst_transforms = dst_transforms_treedef.unflatten(flat_dst_transforms) - dst, dst_transforms = lowering._handle_indexing(dst, dst_transforms) + dst, dst_transforms = lowering._handle_transforms( + ctx, dst, dst_transforms, handle_transposes=False + ) copy_params = _extract_smem_copy_params(dst_transforms) | _extract_gmem_copy_params(src_transforms) barrier_indexer = _extract_barrier_indexer( barrier_transforms_treedef.unflatten(flat_barrier_transforms) @@ -275,24 +514,76 @@ def _copy_gmem_to_smem_lowering( barrier = barrier.__getitem__( *map(lowering._as_index, barrier_indexer.indices) ) + collective = None + if collective_axes is not None: + collective = tuple( + lowering._resolve_cluster_axis(ctx.module_ctx.axis_names, axis) + for axis in collective_axes + ) dst_ty = ir.MemRefType(dst.type) - bytes = math.prod(dst_ty.shape) * mgpu.bytewidth(dst_ty.element_type) - if ctx.module_ctx.thread_semantics == 
mgpu.ThreadSemantics.Lane: + bits = math.prod(dst_ty.shape) * mgpu.bitwidth(dst_ty.element_type) + if bits % 8: + raise ValueError( + f"Can only transfer integer bytes (shape={dst_ty.shape}," + f" dtype={dst_ty.element_type})" + ) + bytes = bits // 8 + if ctx.module_ctx.lowering_semantics == mgpu.LoweringSemantics.Lane: if bytes % WARPGROUP_SIZE: raise NotImplementedError("Only aligned copies are supported") - # We arrive uniformly from each thread in the WG, so we need to divide the - # number of bytes by the number of threads in the WG. - # TODO: apaszke - Relax this. We can just select the WG leader and have it - # arrive with the whole transfer size, while everyone else arrives with 0. - # But we should continue using this scheme as it's likely to be faster. - bytes //= WARPGROUP_SIZE - barrier.arrive_expect_tx(bytes) + if for_warpgroup: + # We arrive uniformly from each thread in the WG, so we need to divide the + # number of bytes by the number of threads in the WG. + # TODO: apaszke - Relax this. We can just select the WG leader and have it + # arrive with the whole transfer size, while everyone else arrives with 0. + # But we should continue using this scheme as it's likely to be faster. + bytes //= WARPGROUP_SIZE + if collective and partitioned_axis is not None: + raise NotImplementedError( + "Collective partitioned copies not implemented." + ) + if ctx.module_ctx.auto_barriers: + mgpu.warpgroup_barrier() # Make sure all reads have completed. + barrier.arrive_expect_tx(bytes) + else: + # In Warp-level lowering, we arrive on each CUDA thread in a warp, but + # the barrier still expects a full 128 arrivals so we arrive 4 times + # on each CUDA thread instead. + # TODO(justinfu): The arrival counts are wrong if called outside of a + # single warp. Figure out how to guard against this in user code. + bytes = bytes // WARP_SIZE + if collective and partitioned_axis is not None: + if len(collective) != 1: + raise ValueError( + f"Expected exactly one collective axis, got {collective_axes=}" + ) + if math.prod(ctx.launch_ctx.cluster_size) != 2: + raise NotImplementedError( + "Partitioned loads only supported for clusters of size 2" + ) + # Bytes is the destination size, which is only half of the total + # size of the partitioned transfer so we need to double it. + bytes *= 2 + first_block = arith_dialect.cmpi( + arith_dialect.CmpIPredicate.eq, + ctx.launch_ctx.cluster_idx(collective[0]), + mgpu.c(0, ir.IndexType.get()), + ) + with mgpu.when(first_block): + barrier.arrive(arrival_count=3, can_complete=False) + barrier.arrive_expect_tx(bytes) + else: + barrier.arrive(arrival_count=3, can_complete=False) + barrier.arrive_expect_tx(bytes) + ctx.launch_ctx.async_copy( src_ref=src, dst_ref=dst, barrier=barrier, arrive=False, - predicate=ctx.module_ctx.single_wg_lane_predicate, + predicate=ctx.module_ctx.single_lane_predicate, + collective=collective, + partitioned=partitioned_axis, **copy_params, ) return () @@ -305,7 +596,11 @@ def _copy_gmem_to_smem_lowering( indices, slice_lengths = _split_gmem_slice(copy_params["gmem_slice"]) assert copy_params.get("swizzle") is None assert not copy_params.get("gmem_transform") - barrier_ref = barrier.as_dialect_barrier_memref() + if copy_params.get("gmem_peer_id", None) is not None: + raise NotImplementedError( + "GMEM refs with peer ids are not supported in warpgroup lowering." 
+ ) + barrier_ref = barrier.as_barrier_memref() mgpu.dialect.arrive_expect_tx(barrier_ref, bytes) mgpu.dialect.async_load( src, @@ -318,9 +613,46 @@ def _copy_gmem_to_smem_lowering( return () -def copy_gmem_to_smem(src: _Ref, dst: _Ref, barrier: _Ref) -> None: +lowering.register_lowering_rule( + copy_gmem_to_smem_p, + mgpu.LoweringSemantics.Lane, + primitive_semantics=gpu_core.PrimitiveSemantics.Warp, +)(functools.partial(_copy_gmem_to_smem_lowering, for_warpgroup=False)) + + +def copy_gmem_to_smem( + src: _Ref, + dst: _Ref, + barrier: _Ref, + *, + collective_axes: str | tuple[str, ...] | None = None, + partitioned_axis: int | None = None, +) -> None: """Asynchronously copies a GMEM reference to a SMEM reference. + If collective_axes is specified, this performs a multicast copy where + all CUDA blocks that share the same index along the collective axis + receive a copy of the same block of data loaded from `dst` to `src`. + + If both collective_axes and partitioned_axis are specified, this will perform + a partitioned collective copy where each block in the cluster will receive + a tile of `transfer_size // cluster_size` data from the `src` Ref. + For example, if `src` has a shape of (256, 256) and a partitioned + copy is performed along axis 0 with cluster size 2, then the first block will + receive `src[0:128, :]` and the second will receive `src[128:256, :]`. + NOTE: Only the first block in the cluster will arrive on the barrier, + and an additional cluster barrier is necessary to ensure that all blocks in + the cluster have finished the copy. + + Args: + src: The source Ref. Must be in GMEM. + dst: The destination Ref. Must be in SMEM. + barrier: The barrier to use for tracking completion of the copy. + collective_axes: The collective axes to use for the copy. + partitioned_axis: Indicates which array axis along the src/dst Refs to + partition across during a partitioned collective copy. Requires + collective_axes to also be specified. + See also: :func:`jax.experimental.mosaic.gpu.barrier_arrive` :func:`jax.experimental.mosaic.gpu.barrier_wait` @@ -343,6 +675,8 @@ def copy_gmem_to_smem(src: _Ref, dst: _Ref, barrier: _Ref) -> None: flat_barrier_transforms, barrier_transforms_treedef = tree_util.tree_flatten( barrier_transforms ) + if isinstance(collective_axes, str): + collective_axes = (collective_axes,) copy_gmem_to_smem_p.bind( src, dst, @@ -353,6 +687,8 @@ def copy_gmem_to_smem(src: _Ref, dst: _Ref, barrier: _Ref) -> None: src_transforms_treedef=src_transforms_treedef, dst_transforms_treedef=dst_transforms_treedef, barrier_transforms_treedef=barrier_transforms_treedef, + collective_axes=collective_axes, + partitioned_axis=partitioned_axis, ) return None @@ -376,7 +712,7 @@ def _extract_barrier_indexer(transforms) -> indexing.NDIndexer | None: case []: return None case _: - raise ValueError("Barrier does not support arbirary transforms") + raise ValueError("Barrier does not support arbitrary transforms") barrier_arrive_p = jax_core.Primitive("barrier_arrive") @@ -387,10 +723,32 @@ def _extract_barrier_indexer(transforms) -> indexing.NDIndexer | None: def _barrier_arrive_abstract_eval(barrier, *args, **params): del args, params # Unused. 
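For illustration, a hypothetical kernel fragment using the new collective arguments of `copy_gmem_to_smem`; the refs, the barrier and the cluster axis name "x" are placeholders, and `plgpu` stands for `jax.experimental.pallas.mosaic_gpu`.

```
# Multicast: every block sharing an index along the "x" cluster axis receives
# the same tile of x_gmem.
plgpu.copy_gmem_to_smem(x_gmem, x_smem, barrier, collective_axes="x")
plgpu.barrier_wait(barrier)

# Partitioned: with a cluster of size 2, each block receives half of x_gmem
# along axis 0. Only the first block arrives on the barrier, so an extra
# cluster-wide synchronization is needed before the other block reads the data.
plgpu.copy_gmem_to_smem(
    x_gmem, x_smem_half, barrier, collective_axes="x", partitioned_axis=0)
```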
_check_ref(barrier, "barrier", gpu_core.SMEM) + if getattr(barrier.inner_aval.dtype, "for_tensor_core", False): + raise ValueError("Cannot arrive on a tensor core barrier.") return (), {gpu_core._memory_effect} -@lowering.register_lowering_rule(barrier_arrive_p, mgpu.ThreadSemantics.Lane) +def _barrier_arrive_pp_eqn( + eqn: jax_core.JaxprEqn, + context: jax_core.JaxprPpContext, + settings: jax_core.JaxprPpSettings, +): + del settings + barrier, *flat_transforms = eqn.invars + transforms_treedef = eqn.params["transforms_treedef"] + transforms = transforms_treedef.unflatten(flat_transforms) + return pp.concat([ + pp.text("barrier_arrive"), + pp.text(" "), + state_primitives.pp_ref_transforms(context, barrier, transforms), + ]) + + +jax_core.pp_eqn_rules[barrier_arrive_p] = _barrier_arrive_pp_eqn + + +@lowering.register_lowering_rule(barrier_arrive_p, mgpu.LoweringSemantics.Lane) +@lowering.register_lowering_rule(barrier_arrive_p, mgpu.LoweringSemantics.Warpgroup) def _barrier_arrive_lowering( ctx: lowering.LoweringRuleContext, barrier, @@ -428,20 +786,48 @@ def _barrier_wait_abstract_eval(barrier, *args, **params): return (), {gpu_core._memory_effect} -@lowering.register_lowering_rule(barrier_wait_p, mgpu.ThreadSemantics.Lane) -@lowering.register_lowering_rule(barrier_wait_p, mgpu.ThreadSemantics.Warpgroup) +def _barrier_wait_pp_eqn( + eqn: jax_core.JaxprEqn, + context: jax_core.JaxprPpContext, + settings: jax_core.JaxprPpSettings, +): + del settings + barrier, *flat_transforms = eqn.invars + transforms_treedef = eqn.params["transforms_treedef"] + transforms = transforms_treedef.unflatten(flat_transforms) + return pp.concat([ + pp.text("barrier_wait"), + pp.text(" "), + state_primitives.pp_ref_transforms(context, barrier, transforms), + ]) + + +jax_core.pp_eqn_rules[barrier_wait_p] = _barrier_wait_pp_eqn + + +@lowering.register_lowering_rule(barrier_wait_p, mgpu.LoweringSemantics.Lane) +@lowering.register_lowering_rule( + barrier_wait_p, + mgpu.LoweringSemantics.Lane, + gpu_core.PrimitiveSemantics.Warp, +) +@lowering.register_lowering_rule( + barrier_wait_p, mgpu.LoweringSemantics.Warpgroup +) def _barrier_wait_lowering( ctx: lowering.LoweringRuleContext, barrier, *flat_transforms, transforms_treedef, ): - del ctx # Unused. 
+ barrier_aval = ctx.avals_in[0] transforms = transforms_treedef.unflatten(flat_transforms) indexer = _extract_barrier_indexer(transforms) + for_tensor_core = getattr( + barrier_aval.inner_aval.dtype, "for_tensor_core", False) if indexer is not None: barrier = barrier.__getitem__(*map(lowering._as_index, indexer.indices)) - barrier.wait() + barrier.wait(for_tensor_core=for_tensor_core) return () @@ -452,7 +838,7 @@ def barrier_wait(barrier: pallas_core.AbstractMemoryRef) -> None: ) flat_transforms, transforms_treedef = tree_util.tree_flatten(transforms) barrier_wait_p.bind( - barrier, *flat_transforms, transforms_treedef=transforms_treedef + barrier, *flat_transforms, transforms_treedef=transforms_treedef, ) @@ -466,9 +852,10 @@ def _wait_smem_to_gmem_abstract_eval(n, *, wait_read_only): return (), {gpu_core._memory_effect} -@lowering.register_lowering_rule(wait_smem_to_gmem_p, mgpu.ThreadSemantics.Lane) @lowering.register_lowering_rule( - wait_smem_to_gmem_p, mgpu.ThreadSemantics.Warpgroup + wait_smem_to_gmem_p, mgpu.LoweringSemantics.Lane) +@lowering.register_lowering_rule( + wait_smem_to_gmem_p, mgpu.LoweringSemantics.Warpgroup ) def _wait_smem_to_gmem_lowering( ctx: lowering.LoweringRuleContext, n, *, wait_read_only @@ -499,8 +886,9 @@ def _commit_group_abstract_eval(): return (), {gpu_core._memory_effect} -@lowering.register_lowering_rule(commit_group_p, mgpu.ThreadSemantics.Lane) -@lowering.register_lowering_rule(commit_group_p, mgpu.ThreadSemantics.Warpgroup) +@lowering.register_lowering_rule(commit_group_p, mgpu.LoweringSemantics.Lane) +@lowering.register_lowering_rule( + commit_group_p, mgpu.LoweringSemantics.Warpgroup) def _commit_group_lowering(ctx: lowering.LoweringRuleContext): del ctx # Unused. nvvm_dialect.cp_async_bulk_commit_group() @@ -508,7 +896,7 @@ def _commit_group_lowering(ctx: lowering.LoweringRuleContext): def commit_smem_to_gmem_group() -> None: - """Commits all issued but uncommited SMEM->GMEM copies to a group.""" + """Commits all issued but uncommitted SMEM->GMEM copies to a group.""" commit_group_p.bind() @@ -517,11 +905,7 @@ def commit_smem_to_gmem_group() -> None: wgmma_ref_p.multiple_results = True -def wgmma( - acc: gpu_core.WGMMAAbstractAccumulatorRef, - a, - b: pallas_core.TransformedRef, -) -> None: +def wgmma(acc: gpu_core.WGMMAAbstractAccumulatorRef, a, b) -> None: """Performs an asynchronous warp group matmul-accumulate on the given references. Conceptually, this is equivalent to doing ``acc[...] += a[...] 
@ b[...]``, @@ -555,12 +939,17 @@ def wgmma( a = a.ref else: a_transforms_leaves, a_transforms_tree = [], None - b_transforms_leaves, b_transforms_tree = jax.tree.flatten(b.transforms) + + if isinstance(b, pallas_core.TransformedRef): + b_transforms_leaves, b_transforms_tree = jax.tree.flatten(b.transforms) + b = b.ref + else: + b_transforms_leaves, b_transforms_tree = [], None wgmma_ref_p.bind( acc, a, - b.ref, + b, *a_transforms_leaves, *b_transforms_leaves, a_transforms_tree=a_transforms_tree, @@ -582,6 +971,40 @@ def _wgmma_ref_effectful_abstract_eval(acc_aval, a_aval, b_aval, *_, **params): } +def _wgmma_ref_pp_eqn( + eqn: jax_core.JaxprEqn, + context: jax_core.JaxprPpContext, + settings: jax_core.JaxprPpSettings, +): + del settings + acc, a, b, *leaves = eqn.invars + a_transforms_treedef = eqn.params["a_transforms_tree"] + b_transforms_treedef = eqn.params["b_transforms_tree"] + split = getattr(a_transforms_treedef, "num_leaves", 0) + a_transforms = ( + a_transforms_treedef.unflatten(leaves[:split]) + if a_transforms_treedef is not None + else [] + ) + b_transforms = ( + b_transforms_treedef.unflatten(leaves[split:]) + if b_transforms_treedef is not None + else [] + ) + return pp.concat([ + pp.text("wgmma_ref"), + pp.text(" "), + pp.text(jax_core.pp_var(acc, context)), + pp.text(" <- "), + state_primitives.pp_ref_transforms(context, a, a_transforms), + pp.text(" @ "), + state_primitives.pp_ref_transforms(context, b, b_transforms), + ]) + + +jax_core.pp_eqn_rules[wgmma_ref_p] = _wgmma_ref_pp_eqn + + @discharge.register_discharge_rule(wgmma_ref_p) def _wgmma_ref_discharge(in_avals, out_avals, *args, **kwargs): del in_avals, out_avals @@ -592,7 +1015,7 @@ def _wgmma_ref_discharge(in_avals, out_avals, *args, **kwargs): wgmma_p = jax_core.Primitive("wgmma") -@lowering.register_lowering_rule(wgmma_p, mgpu.ThreadSemantics.Lane) +@lowering.register_lowering_rule(wgmma_p, mgpu.LoweringSemantics.Lane) def _wgmma_lowering( ctx: lowering.LoweringRuleContext, acc, @@ -602,22 +1025,32 @@ def _wgmma_lowering( a_transforms_tree, b_transforms_tree, ): - _, a_aval, *_ = ctx.avals_in lhs_swizzle: int | None = None if a_transforms_tree is not None: a_transforms_leaves, b_transforms_leaves = util.split_list( transforms_leaves, [a_transforms_tree.num_leaves] ) a_transforms = a_transforms_tree.unflatten(a_transforms_leaves) - a, a_transforms = lowering._handle_indexing(a, a_transforms) + a, a_transforms = lowering._handle_transforms( + ctx, a, a_transforms, handle_transposes=False, handle_reshapes=False + ) match a_transforms: case (gpu_core.UnswizzleRef(lhs_swizzle), gpu_core.UntileRef(tiling)): - swizzle_elems = lhs_swizzle // a_aval.dtype.itemsize - if tiling != (64, swizzle_elems): - raise NotImplementedError("WGMMA lhs tiling does not fit swizzle") + lhs_transpose = False + case ( + gpu_core.UnswizzleRef(lhs_swizzle), + gpu_core.UntileRef(tiling), + gpu_core.TransposeRef((1, 0)), + ): + lhs_transpose = True case _: raise ValueError(f"WGMMA lhs has unsupported transforms: {a_transforms}.") + a_mlir_dtype = ir.MemRefType(a.type).element_type + swizzle_elems = lhs_swizzle // mgpu_utils.bytewidth(a_mlir_dtype) + if tiling != (8, swizzle_elems): + raise NotImplementedError("WGMMA lhs tiling does not fit swizzle") else: + lhs_transpose = False b_transforms_leaves = transforms_leaves # type: ignore if not isinstance(a, mgpu.FragmentedArray): raise ValueError( @@ -626,16 +1059,17 @@ def _wgmma_lowering( ) b_transforms = b_transforms_tree.unflatten(b_transforms_leaves) - b, b_transforms = 
lowering._handle_indexing(b, b_transforms) + b, b_transforms = lowering._handle_transforms( + ctx, b, b_transforms, handle_transposes=False, handle_reshapes=False + ) match b_transforms: case (gpu_core.UnswizzleRef(rhs_swizzle), gpu_core.UntileRef(rhs_tiling)): rhs_transpose = False case ( gpu_core.UnswizzleRef(rhs_swizzle), - gpu_core.TransposeRef((1, 0, 2, 3)), # Only transpose between tiles gpu_core.UntileRef(rhs_tiling), - gpu_core.TransposeRef((1, 0)), # Transpose the two logical dims + gpu_core.TransposeRef((1, 0)), ): rhs_transpose = True case ( @@ -661,19 +1095,68 @@ def _wgmma_lowering( raise ValueError(f"WGMMA rhs has unsupported transforms: {b_transforms}.") if lhs_swizzle is not None: - swizzle_elems = rhs_swizzle // a_aval.dtype.itemsize + b_mlir_dtype = ir.MemRefType(b.type).element_type + swizzle_elems = rhs_swizzle // mgpu_utils.bytewidth(b_mlir_dtype) if rhs_swizzle != lhs_swizzle: raise NotImplementedError("WGMMA rhs swizzle must match lhs swizzle") - if rhs_tiling != (swizzle_elems, swizzle_elems): + if rhs_tiling != (8, swizzle_elems): raise NotImplementedError("WGMMA rhs tiling does not fit swizzle") + if lhs_transpose: + a = mgpu.memref_transpose(a, (1, 0, 3, 2)) if rhs_transpose: - b = mgpu.memref_transpose(b, (0, 1, 3, 2)) + b = mgpu.memref_transpose(b, (1, 0, 3, 2)) new_acc = mgpu.wgmma(acc, a, b, swizzle=rhs_swizzle) nvvm_dialect.wgmma_commit_group_sync_aligned() return new_acc +@lowering.register_lowering_rule(wgmma_p, mgpu.LoweringSemantics.Warpgroup) +def _wgmma_warpgroup_lowering( + ctx: lowering.LoweringRuleContext, + acc, + a, + b, + *transforms_leaves, + a_transforms_tree, + b_transforms_tree, +): + if a_transforms_tree is not None: + a_transforms_leaves, b_transforms_leaves = util.split_list( + transforms_leaves, [a_transforms_tree.num_leaves] + ) + a_transforms = a_transforms_tree.unflatten(a_transforms_leaves) + a, a_transforms = lowering._handle_transforms(ctx, a, a_transforms) + match a_transforms: + case (gpu_core.TransposeRef((1, 0)),): + a = mgpu.memref_transpose(a, (1, 0)) + case (): + pass + case _: + raise ValueError( + f"WGMMA lhs has unsupported transforms: {a_transforms}." + ) + else: + b_transforms_leaves = transforms_leaves # type: ignore + + if b_transforms_tree is not None: + b_transforms = b_transforms_tree.unflatten(b_transforms_leaves) + b, b_transforms = lowering._handle_transforms(ctx, b, b_transforms) + match b_transforms: + case (gpu_core.TransposeRef((1, 0)),): + b = mgpu.memref_transpose(b, (1, 0)) + case (): + pass + case _: + raise ValueError( + f"WGMMA rhs has unsupported transforms: {b_transforms}." 
+ ) + + new_acc = mgpu.dialect.wgmma(acc, a, b) + nvvm_dialect.wgmma_commit_group_sync_aligned() + return new_acc + + @wgmma_p.def_effectful_abstract_eval def _wgmma_effectful_abstract_eval(acc, lhs_ref, *args, **kwargs): del args, kwargs @@ -697,7 +1180,8 @@ def wgmma_wait_effectful_abstract_eval(_): return [], {gpu_core._wgmma_pipeline_effect} -@lowering.register_lowering_rule(wgmma_wait_p, mgpu.ThreadSemantics.Lane) +@lowering.register_lowering_rule(wgmma_wait_p, mgpu.LoweringSemantics.Lane) +@lowering.register_lowering_rule(wgmma_wait_p, mgpu.LoweringSemantics.Warpgroup) def _wgmma_wait_lowering(ctx: lowering.LoweringRuleContext, allow_groups): del ctx nvvm_dialect.wgmma_wait_group_sync_aligned(allow_groups) @@ -728,11 +1212,313 @@ def _wgmma_accumulator_deref_discharge(in_avals, out_avals, acc): return (None,), wgmma_accumulator_deref_p.bind(acc) -@lowering.register_lowering_rule(wgmma_accumulator_deref_p, mgpu.ThreadSemantics.Lane) +@lowering.register_lowering_rule( + wgmma_accumulator_deref_p, mgpu.LoweringSemantics.Lane +) +@lowering.register_lowering_rule( + wgmma_accumulator_deref_p, mgpu.LoweringSemantics.Warpgroup +) def _wgmma_accumulator_deref_lowering(ctx: lowering.LoweringRuleContext, acc): - del ctx nvvm_dialect.wgmma_wait_group_sync_aligned(0) - return acc.value + return ( + acc.value + if ctx.module_ctx.lowering_semantics == mgpu.LoweringSemantics.Lane + else acc + ) + + +# MMA for TensorCore gen 5. +tcgen05_mma_p = jax_core.Primitive("tcgen05_mma") +tcgen05_mma_p.multiple_results = True + +def tcgen05_mma(acc: _Ref, + a: _Ref, + b: _Ref, + barrier: _Ref, + accumulate: bool | jax.Array = True, + collective_axis: str | None = None): + """Asynchronous matrix-multiply accumulate for TensorCore gen 5 (Blackwell). + + If run in collective mode, `acc`, `a` (LHS), and `b` (RHS) should correspond + to half of the total inputs to the MMA, where `acc` and `a` (LHS) are split + in half along the rows and `b` (RHS) is split along the columns like so: + + ----------- ----------- ----------- + | ACC1 | | LHS1 | | | | + ----------- += ----------- @ |RHS1|RHS2| + | ACC2 | | LHS2 | | | | + ----------- ----------- ----------- + + Args: + acc: The accumulator. Must be a TMEM Ref. + a: The left-hand side. Must be a TMEM/SMEM Ref. + b: The right-hand side. Must be an SMEM Ref. + barrier: Barrier Ref for synchronizing with the tensor core. Should have + for_tensor_core set to True. + accumulate: Whether to accumulate into acc or overwrite it. + collective_axis: The name of the cluster axis along which to perform + a collective MMA. The cluster axis should have a size of exactly 2, + and must be on the minormost cluster axis. + """ + acc_m, acc_n = acc.shape + lhs_m, lhs_k = a.shape + rhs_k, rhs_n = b.shape + if collective_axis is not None: + acc_n /= 2 + if acc_m != lhs_m: + raise ValueError( + f"Accumulator and LHS have incompatible shapes. Accumulator: {acc.shape}. LHS: {a.shape}.") + if acc_n != rhs_n: + raise ValueError( + f"Accumulator and RHS have incompatible shapes. Accumulator: {acc.shape}. RHS: {b.shape}.") + if lhs_k != rhs_k: + raise ValueError( + f"LHS and RHS have incompatible shapes. LHS: {a.shape}. 
RHS: {b.shape}.") + + if isinstance(acc, pallas_core.TransformedRef): + acc_transforms_leaves, acc_transforms_tree = jax.tree.flatten(acc.transforms) + acc = acc.ref + else: + acc_transforms_leaves, acc_transforms_tree = [], None + + if isinstance(a, pallas_core.TransformedRef): + a_transforms_leaves, a_transforms_tree = jax.tree.flatten(a.transforms) + a = a.ref + else: + a_transforms_leaves, a_transforms_tree = [], None + + if isinstance(b, pallas_core.TransformedRef): + b_transforms_leaves, b_transforms_tree = jax.tree.flatten(b.transforms) + b = b.ref + else: + b_transforms_leaves, b_transforms_tree = [], None + + if isinstance(barrier, pallas_core.TransformedRef): + barrier_transforms_leaves, barrier_transforms_tree = jax.tree.flatten( + barrier.transforms + ) + barrier = barrier.ref + else: + barrier_transforms_leaves, barrier_transforms_tree = [], None + + tcgen05_mma_p.bind(acc, a, b, barrier, accumulate, + *acc_transforms_leaves, *a_transforms_leaves, + *b_transforms_leaves, + *barrier_transforms_leaves, + acc_transforms_tree=acc_transforms_tree, + a_transforms_tree=a_transforms_tree, + b_transforms_tree=b_transforms_tree, + barrier_transforms_tree=barrier_transforms_tree, + collective_axis=collective_axis) + + +@tcgen05_mma_p.def_abstract_eval +def _tcgen05_mma_abstract_eval(acc, a, b, barrier, accumulate, + *transforms_leaves, + acc_transforms_tree, a_transforms_tree, + b_transforms_tree, + barrier_transforms_tree, + collective_axis): + del (accumulate, transforms_leaves, acc_transforms_tree, + a_transforms_tree, b_transforms_tree, barrier_transforms_tree) + + if acc.memory_space != gpu_core.TMEM: + raise ValueError("Accumulator must be a TMEM Ref.") + if a.memory_space not in (gpu_core.SMEM, gpu_core.TMEM): + raise ValueError("LHS must be a TMEM/SMEM Ref.") + if b.memory_space != gpu_core.SMEM: + raise ValueError("RHS must be an SMEM Ref.") + + if collective_axis is not None: + # TODO(justinfu): If under a core_map, the avals for acc/a + # become normal MemRefs so we cannot check if they are collective. + # Figure out a way to fix this. 
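For illustration, the shape bookkeeping behind the checks above for a collective `tcgen05_mma` over a 2-block cluster, with hypothetical sizes:

```
# Full problem: (256, 256) += (256, 128) @ (128, 256), split across 2 blocks.
# Per-block operands, following the diagram in the docstring:
#   acc_ref: (128, 256)   # ACC1: rows 0:128 of the full accumulator.
#   a_ref:   (128, 128)   # LHS1: rows 0:128 of the full LHS.
#   b_ref:   (128, 128)   # RHS1: cols 0:128 of the full RHS.
# The checks then see acc_m == lhs_m == 128, acc_n / 2 == rhs_n == 128 and
# lhs_k == rhs_k == 128.
```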
+ if isinstance(acc, gpu_core.AbstractTMEMRef) and not acc.collective: + raise ValueError( + "Accumulator Ref must be collective if collective_axis is set.") + if isinstance(a, gpu_core.AbstractTMEMRef) and not a.collective: + raise ValueError( + "LHS Ref must be collective if collective_axis is set.") + + for_tensor_core = getattr( + barrier.inner_aval.dtype, "for_tensor_core", False) + if not for_tensor_core: + raise ValueError("MMA barrier must have for_tensor_core set to True.") + + return [] + + +@lowering.register_lowering_rule(tcgen05_mma_p, *gpu_core.LANExWG_SEMANTICS) +@lowering.register_lowering_rule(tcgen05_mma_p, *gpu_core.LANExWARP_SEMANTICS) +def _tcgen05_mma_lowering( + ctx: lowering.LoweringRuleContext, + acc: tcgen05.TMEMRef, + a_ref, + b_ref, + barrier_ref: mgpu.BarrierRef, + accumulate: bool | ir.Value, + *transforms_leaves, + acc_transforms_tree, + a_transforms_tree, + b_transforms_tree, + barrier_transforms_tree, + collective_axis, +): + _, a_aval, b_aval, *_ = ctx.avals_in + lhs_swizzle: int | None = None + lhs_transpose: bool = False + + transforms_trees = ( + acc_transforms_tree, + a_transforms_tree, + b_transforms_tree, + barrier_transforms_tree, + ) + (acc_transforms_leaves, a_transforms_leaves, b_transforms_leaves, barrier_transforms_leaves, _) = ( + util.split_list( + transforms_leaves, + [getattr(tree, "num_leaves", 0) for tree in transforms_trees], + ) + ) + + if acc_transforms_tree is not None: + acc_transforms = acc_transforms_tree.unflatten(acc_transforms_leaves) + acc, acc_transforms = lowering._handle_transforms(ctx, acc, acc_transforms) + if acc_transforms: + raise NotImplementedError( + f"Unsupported transforms: {acc_transforms}." + ) + + if a_transforms_tree is not None: + a_transforms = a_transforms_tree.unflatten(a_transforms_leaves) + a_ref, a_transforms = lowering._handle_transforms( + ctx, a_ref, a_transforms, handle_transposes=False, handle_reshapes=True + ) + match a_transforms: + case (gpu_core.UnswizzleRef(lhs_swizzle), gpu_core.UntileRef(lhs_tiling)): + lhs_transpose = False + case ( + gpu_core.UnswizzleRef(lhs_swizzle), + gpu_core.UntileRef(lhs_tiling), + gpu_core.TransposeRef((1, 0)), + ): + lhs_transpose = True + case _: + raise NotImplementedError( + f"Unsupported transforms: {a_transforms}." + ) + swizzle_elems = lhs_swizzle // a_aval.dtype.itemsize + if lhs_tiling != (8, swizzle_elems): + raise ValueError("MMA lhs tiling does not fit swizzle. " + f"{lhs_tiling=} expected={(8, swizzle_elems)}") + + assert b_transforms_tree is not None + b_transforms = b_transforms_tree.unflatten(b_transforms_leaves) + b_ref, b_transforms = lowering._handle_transforms( + ctx, b_ref, b_transforms, handle_transposes=False, handle_reshapes=True + ) + match b_transforms: + case (gpu_core.UnswizzleRef(rhs_swizzle), gpu_core.UntileRef(rhs_tiling)): + rhs_transpose = False + case ( + gpu_core.UnswizzleRef(rhs_swizzle), + gpu_core.UntileRef(rhs_tiling), + gpu_core.TransposeRef((1, 0)), + ): + rhs_transpose = True + case _: + raise NotImplementedError( + f"Unsupported transforms: {b_transforms}." 
+ ) + swizzle_elems = rhs_swizzle // b_aval.dtype.itemsize + if rhs_tiling != (8, swizzle_elems): + raise ValueError( + "MMA rhs tiling does not fit swizzle" + f" {rhs_tiling=} expected={(8, swizzle_elems)}" + ) + + if barrier_transforms_tree is not None: + barrier_transforms = barrier_transforms_tree.unflatten( + barrier_transforms_leaves + ) + indexer = _extract_barrier_indexer(barrier_transforms) + if indexer is not None: + barrier_ref = barrier_ref.__getitem__( + *map(lowering._as_index, indexer.indices) + ) + + if lhs_swizzle is None: + lhs_swizzle = rhs_swizzle + elif rhs_swizzle != lhs_swizzle: + raise ValueError("MMA rhs swizzle must match lhs swizzle." + f" {lhs_swizzle=} {rhs_swizzle=}") + if lhs_transpose: + if isinstance(a_ref, tcgen05.TMEMRef): + raise ValueError("TMEM transpose not allowed.") + a_ref = mgpu.memref_transpose(a_ref, (1, 0, 3, 2)) + if rhs_transpose: + b_ref = mgpu.memref_transpose(b_ref, (1, 0, 3, 2)) + if isinstance(accumulate, bool): + accumulate = mgpu.c(accumulate, ir.IntegerType.get_signless(1)) + elif isinstance(accumulate, mgpu.FragmentedArray): + accumulate = accumulate.registers.item() + assert isinstance(accumulate, ir.Value) + + predicate = ctx.module_ctx.single_lane_predicate + collective = False + if collective_axis is not None: + cluster_axis = lowering._resolve_cluster_axis( + ctx.module_ctx.axis_names, collective_axis) + if cluster_axis != gpu_dialect.Dimension(0): + # Note: resolve_cluster_axis checks if axis_names exists. + assert ctx.module_ctx.axis_names is not None + if len(ctx.module_ctx.axis_names.cluster) <= 1: + raise ValueError("No cluster axes found.") + minormost_cluster_axis = ctx.module_ctx.axis_names.cluster[0] + raise ValueError( + "Can only perform collective MMA along minormost cluster axis. " + f"Got {collective_axis}, expected {minormost_cluster_axis}.") + index = ir.IndexType.get() + is_leader_block = arith_dialect.cmpi( + arith_dialect.CmpIPredicate.eq, + ctx.launch_ctx.cluster_idx(cluster_axis), mgpu.c(0, index)) + predicate = arith_dialect.andi(predicate, is_leader_block) + collective = True + + with mgpu.when(predicate): + tcgen05.mma( + acc, + a_ref, + b_ref, + a_swizzle=int(lhs_swizzle), + b_swizzle=int(rhs_swizzle), + accumulate=accumulate, + collective=collective, + ) + tcgen05.commit_arrive(barrier_ref, + collective=collective, + ctx=ctx.launch_ctx) + return [] + + +commit_tmem_p = jax_core.Primitive("commit_tmem") +commit_tmem_p.multiple_results = True + + +@commit_tmem_p.def_effectful_abstract_eval +def _commit_tmem_abstract_eval(): + return (), {gpu_core._memory_effect} + + +@lowering.register_lowering_rule(commit_tmem_p, mgpu.LoweringSemantics.Lane) +def _commit_tmem_lowering(_): + tcgen05.commit_tmem() + return () + + +def commit_tmem(): + """Commits all writes to TMEM, making them visible to loads and MMA.""" + commit_tmem_p.bind() class Layout(enum.Enum): @@ -740,10 +1526,17 @@ class Layout(enum.Enum): WGMMA = enum.auto() #: [m] matrix, where m % 64 == 0. WGMMA_ROW = enum.auto() + #: [n] matrix, where n % 8 == 0. 
+ WGMMA_COL = enum.auto() + WGMMA_TRANSPOSED = enum.auto() WG_SPLAT = enum.auto() WG_STRIDED = enum.auto() + TCGEN05 = enum.auto() + TCGEN05_ROW = enum.auto() + TCGEN05_COL = enum.auto() + def __call__(self, *args, **kwargs) -> ParameterizedLayout: return ParameterizedLayout(self, args, kwargs) @@ -753,16 +1546,31 @@ def check_no_args(): raise ValueError(f"Can't instantiate {self} with arguments.") match self: + case Layout.WGMMA_TRANSPOSED: + check_no_args() + return mgpu.WGMMA_TRANSPOSED_LAYOUT case Layout.WGMMA: check_no_args() return mgpu.WGMMA_LAYOUT case Layout.WGMMA_ROW: check_no_args() return mgpu.WGMMA_ROW_LAYOUT + case Layout.WGMMA_COL: + check_no_args() + return mgpu.WGMMA_COL_LAYOUT case Layout.WG_SPLAT: return mgpu.WGSplatFragLayout(*args, **kwargs) # pytype: disable=missing-parameter case Layout.WG_STRIDED: - return mgpu.WGStridedFragLayout(*args, **kwargs) + return mgpu.WGStridedFragLayout(*args, **kwargs) # pytype: disable=missing-parameter + case Layout.TCGEN05: + check_no_args() + return mgpu.TCGEN05_LAYOUT + case Layout.TCGEN05_ROW: + check_no_args() + return mgpu.TCGEN05_ROW_LAYOUT + case Layout.TCGEN05_COL: + check_no_args() + return mgpu.TCGEN05_COL_LAYOUT @dataclasses.dataclass(frozen=True) class ParameterizedLayout: @@ -770,6 +1578,10 @@ class ParameterizedLayout: args: Sequence[Any] kwargs: Any + def __post_init__(self): + object.__setattr__(self, "args", tuple(self.args)) + object.__setattr__(self, "kwargs", frozen_dict.FrozenDict(self.kwargs)) + def to_mgpu(self) -> mgpu.FragmentedLayout: return self.layout_cls.to_mgpu(*self.args, **self.kwargs) @@ -783,12 +1595,20 @@ def _layout_cast_abstract_eval(x, new_layout): return x -@lowering.register_lowering_rule(layout_cast_p, mgpu.ThreadSemantics.Lane) +@lowering.register_lowering_rule(layout_cast_p, mgpu.LoweringSemantics.Lane) def _layout_cast_lowering(ctx: lowering.LoweringRuleContext, x, *, new_layout): del ctx # Unused. return x.to_layout(new_layout.to_mgpu()) +@lowering.register_lowering_rule(layout_cast_p, mgpu.LoweringSemantics.Warpgroup) +def _layout_cast_lowering_wg( + ctx: lowering.LoweringRuleContext, x, *, new_layout +): + del ctx # Unused. + return mgpu.dialect.layout_cast(x, mgpu.to_layout_attr(new_layout.to_mgpu())) + + def layout_cast(x: Any, new_layout: Layout | ParameterizedLayout): """Casts the layout of the given array.""" return layout_cast_p.bind(x, new_layout=new_layout) @@ -804,7 +1624,10 @@ def _set_max_registers_abstract_eval(n, *, action): return (), {gpu_core._memory_effect} -@lowering.register_lowering_rule(set_max_registers_p, mgpu.ThreadSemantics.Lane) +@lowering.register_lowering_rule( + set_max_registers_p, mgpu.LoweringSemantics.Lane) +@lowering.register_lowering_rule( + set_max_registers_p, mgpu.LoweringSemantics.Warpgroup) def _set_max_registers_lowering( ctx: lowering.LoweringRuleContext, n, *, action ): @@ -832,9 +1655,11 @@ def _commit_smem_abstract_eval(): return (), {gpu_core._memory_effect} -@lowering.register_lowering_rule(commit_smem_p, mgpu.ThreadSemantics.Lane) -@lowering.register_lowering_rule(commit_smem_p, mgpu.ThreadSemantics.Warpgroup) +@lowering.register_lowering_rule(commit_smem_p, mgpu.LoweringSemantics.Lane) +@lowering.register_lowering_rule( + commit_smem_p, mgpu.LoweringSemantics.Warpgroup) def _commit_smem_lowering(ctx: lowering.LoweringRuleContext): + # TODO(bchetioui): add primitive for commit smem to mosaic_gpu dialect. 
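The Layout members and the layout_cast rule above are thin wrappers around Mosaic GPU fragment layouts. A minimal kernel-side sketch, not part of the patch, assuming the enum and layout_cast are re-exported as plgpu.Layout / plgpu.layout_cast from jax.experimental.pallas.mosaic_gpu:

from jax.experimental.pallas import mosaic_gpu as plgpu  # assumed export path

def kernel(x_ref, o_ref):
  # Parameterized members are called first, producing a ParameterizedLayout
  # whose to_mgpu() forwards the stored args/kwargs; parameter-free members
  # such as plgpu.Layout.TCGEN05 are used directly and reject arguments via
  # check_no_args().
  x = plgpu.layout_cast(
      x_ref[...], plgpu.Layout.WG_STRIDED(x_ref.shape, vec_size=4))
  o_ref[...] = x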
mgpu.commit_shared() return () @@ -852,7 +1677,8 @@ def _broadcasted_iota_abstract_eval(dtype, shape, dimension, layout): return jax_core.ShapedArray(shape, dtype) -@lowering.register_lowering_rule(broadcasted_iota_p, mgpu.ThreadSemantics.Lane) +@lowering.register_lowering_rule( + broadcasted_iota_p, mgpu.LoweringSemantics.Lane) def _broadcasted_iota_lowering( ctx: lowering.LoweringRuleContext, dtype, shape, dimension, layout ): @@ -900,8 +1726,42 @@ def _jaxpr_call_abstract_eval(*args, jaxpr: jax_core.Jaxpr, **params): return [v.aval for v in jaxpr.outvars] -@lowering.register_lowering_rule(jaxpr_call_p, mgpu.ThreadSemantics.Lane) -@lowering.register_lowering_rule(jaxpr_call_p, mgpu.ThreadSemantics.Warpgroup) +def _jaxpr_call_pp_eqn( + eqn: jax_core.JaxprEqn, + context: jax_core.JaxprPpContext, + settings: jax_core.JaxprPpSettings, +): + flat_args = eqn.invars + ref_treedefs = eqn.params["ref_treedefs"] + flat_refs, _ = util.split_list( + flat_args, [sum(treedef.num_leaves for treedef in ref_treedefs)] + ) + flat_refs = util.split_list( + flat_refs, + [treedef.num_leaves for treedef in ref_treedefs[: len(ref_treedefs) - 1]], + ) + trailer = [] + for treedef, flat_ref in zip(ref_treedefs, flat_refs): + ref = treedef.unflatten(flat_ref) + transforms = [] + if isinstance(ref, tuple): + ref, transforms = ref + trailer.append(pp.text(" ")) + trailer.append(state_primitives.pp_ref_transforms(context, ref, transforms)) + return pp.concat([ + pp.text("jaxpr_call"), + pp.text("["), + jax_core.pp_kv_pair("jaxpr", eqn.params["jaxpr"], context, settings), + pp.text("]"), + pp.concat(trailer), + ]) + + +jax_core.pp_eqn_rules[jaxpr_call_p] = _jaxpr_call_pp_eqn + + +@lowering.register_lowering_rule(jaxpr_call_p, mgpu.LoweringSemantics.Lane) +@lowering.register_lowering_rule(jaxpr_call_p, mgpu.LoweringSemantics.Warpgroup) def _jaxpr_call_lowering_rule( ctx: lowering.LoweringRuleContext, *flat_args, @@ -920,9 +1780,12 @@ def _jaxpr_call_lowering_rule( for treedef, flat_ref in zip(ref_treedefs, flat_refs): ref = treedef.unflatten(flat_ref) if isinstance(ref, tuple): + ref, transforms = ref # We ignore other transforms here, because they are already embedded # in the jaxpr. - ref, _ = lowering._handle_indexing(*ref) + ref, _ = lowering._handle_transforms( + ctx, ref, transforms, handle_reshapes=False, handle_transposes=False + ) args.append(ref) program_ids = program_ids_treedef.unflatten(flat_program_ids) for axis, pid in enumerate(program_ids): @@ -969,7 +1832,7 @@ def _jaxpr_call_discharge( outs = jaxpr_call_p.bind( *flat_args, jaxpr=discharged_jaxpr, - ref_treedefs=ref_treedefs, + ref_treedefs=tuple(ref_treedefs), program_ids_treedef=program_ids_treedef, ) discharged_outs_it = iter(outs[len(jaxpr.outvars) :]) @@ -1022,6 +1885,230 @@ def jaxpr_call( *flat_refs, *flat_program_ids, jaxpr=jaxpr, - ref_treedefs=ref_treedefs, + ref_treedefs=tuple(ref_treedefs), program_ids_treedef=program_ids_treedef, ) + + +@dataclasses.dataclass(frozen=True) +class ShapeDtypeStruct: + shape: tuple[int, ...] + dtype: jnp.dtype + layout: ParameterizedLayout | Layout + + +inline_mgpu_p = jax_core.Primitive("inline_mgpu_p") +inline_mgpu_p.multiple_results = True + + +@dataclasses.dataclass(frozen=True) +class RefType: + transforms: tuple[gpu_core.MemoryRefTransform, ...] 
= () + + +def _undo_transforms( + raw_ref: pallas_core.AbstractMemoryRef, + memory_transforms: Sequence[gpu_core.MemoryRefTransform], +): + """Extract the `Transform`s that reverse the `MemoryRefTransform`s""" + tmp_ref = state_types.TransformedRef(raw_ref, transforms=()) + tmp_ref = functools.reduce(lambda r, t: t.undo(r), reversed(memory_transforms), tmp_ref) + return tmp_ref.transforms + + +def inline_mgpu(*, arg_types=(), return_type=None): + r"""Returns a decorator that inlines Mosaic GPU code. + + This allows using lower-level Mosaic GPU abstractions and operations, which + are otherwise not directly exposed in Pallas. + + Example:: + + layout = plgpu.Layout.WG_STRIDED(x_ref.shape, vec_size=4) + + @plgpu.inline_mgpu( + arg_types=(plgpu.RefType(),), + return_type=plgpu.ShapeDtypeStruct( + (128, 128), dtype, layout=layout + ), + ) + def add_one(ctx, smem_ref): + x = mgpu.FragmentedArray.load_tiled(smem_ref) + y = mgpu.FragmentedArray.splat( + mgpu.c(1, x.mlir_dtype), shape=x.shape, layout=x.layout + ) + return x + y + + Args: + arg_types: A sequence of pytrees where the leaves are + {class}`~jax.experimental.pallas.mosaic_gpu.RefType`\s or + {class}`~jax.experimental.pallas.mosaic_gpu.Layout`\s for reference or + array arguments respectively. + return_type: A pytree where the leaves are + {class}`~jax.experimental.pallas.mosaic_gpu.ShapeDtypeStruct`\s + representing the arrays returned by the decorated function. + """ + flat_arg_types, treedef_ty = jax.tree.flatten(tuple(arg_types)) + flat_ret_ty, pytree_ret_ty = jax.tree.flatten(return_type) + if return_type and not all(isinstance(r, ShapeDtypeStruct) for r in flat_ret_ty): + raise ValueError( + "inline_mgpu_p only supports plgpu.ShapeDtypeStruct return types." + ) + if not all(isinstance(r, (Layout, ParameterizedLayout, RefType)) for r in flat_arg_types): + raise ValueError( + "inline_mgpu_p only supports only Layout, ParameterizedLayout and" + " RefType arg types." + ) + + def inner(f): + def wrapper(*args): + flat_args, treedef = jax.tree.flatten(tuple(args)) + if treedef != treedef_ty: + raise ValueError(f"Mismatched type shape: {treedef} != {treedef_ty}") + + # Strip the transforms from the refs since they will be recorded in + # the types. 
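To make the argument pairing performed by this wrapper easier to follow, here is a hypothetical sketch, not part of the patch; the plgpu alias, smem_ref, and the doubling body are illustrative assumptions only:

import jax.numpy as jnp
from jax.experimental.pallas import mosaic_gpu as plgpu  # assumed export path

# (argument, arg_type) pairings accepted by the wrapper above:
#   TransformedRef + RefType   -> raw ref passed on, transforms recorded
#   plain Ref      + RefType   -> empty transform tuple recorded
#   array          + Layout / ParameterizedLayout -> passed through unchanged
# Anything else raises the "Mismatched type" error above.

layout = plgpu.Layout.WGMMA

@plgpu.inline_mgpu(
    arg_types=(plgpu.RefType(), layout),
    return_type=plgpu.ShapeDtypeStruct((64, 64), jnp.float32, layout=layout),
)
def double(launch_ctx, smem_ref, x):
  # Inside the body, smem_ref is an MLIR memref and x an mgpu.FragmentedArray.
  return x + x

# In a kernel body: y = double(smem_ref, x)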
+ ref_transforms = [] + raw_flat_args = [] + for a, t in zip(flat_args, flat_arg_types): + if isinstance(a, state_types.TransformedRef) and isinstance(t, RefType): + raw_flat_args.append(a.ref) + ref_transforms.append(a.transforms) + elif isinstance(aval := jax_core.get_aval(a), jax_core.ShapedArray) and isinstance(t, (ParameterizedLayout, Layout)): + raw_flat_args.append(a) + ref_transforms.append(None) + elif isinstance(aval, state.AbstractRef) and isinstance(t, RefType): + raw_flat_args.append(a) + ref_transforms.append(()) + else: + raise ValueError(f"Mismatched type: {a, t}") + + flat_ref_transforms, pytree_ref_transforms = jax.tree.flatten(ref_transforms) + flat_ret = inline_mgpu_p.bind( + *raw_flat_args, + *flat_ref_transforms, + flat_arg_types=tuple(flat_arg_types), + flat_ret_ty=tuple(flat_ret_ty), + pytree_ret_ty=pytree_ret_ty, + pytree_args=treedef, + pytree_ref_transforms=pytree_ref_transforms, + mgpu_fn=f, + ) + return jax.tree.unflatten(pytree_ret_ty, flat_ret) + return wrapper + + return inner + + +@inline_mgpu_p.def_effectful_abstract_eval +def _inline_mgpu_abstract_eval( + *flat_args_and_transforms, + flat_arg_types, + flat_ret_ty, + pytree_args, + pytree_ref_transforms, + pytree_ret_ty, + mgpu_fn, +): + del flat_arg_types, pytree_ret_ty, pytree_ref_transforms, mgpu_fn # Unused. + aval_return = tuple( + jax_core.ShapedArray(x.shape, x.dtype) for x in flat_ret_ty + ) + # TODO(cperivol): Let the user set the effects. + flat_args = flat_args_and_transforms[:pytree_args.num_leaves] + return aval_return, { + gpu_core._wgmma_pipeline_effect, + gpu_core._memory_effect, + *itertools.chain.from_iterable( + (state.ReadEffect(i), state.WriteEffect(i)) + for i, r in enumerate(flat_args) + if isinstance(r, pallas_core.AbstractMemoryRef) + ), + } + + +@discharge.register_partial_discharge_rule(inline_mgpu_p) +def _inline_mgpu_discharge(*args, **kwargs): + del args, kwargs + raise NotImplementedError("inline_mgpu_p does not support discharge.") + + +def _type_check_mgpu(v, ty): + match (ty, v): + case (RefType(), ir.Value()) if ir.MemRefType.isinstance(v.type): + pass + case (ShapeDtypeStruct(), mgpu.FragmentedArray()): + mlir_dtype = mgpu_utils.dtype_to_ir_type(ty.dtype) + if v.mlir_dtype != mlir_dtype: + raise ValueError( + f"Array dtype mismatch: expected {v.mlir_dtype} got {mlir_dtype}." + ) + if ty.shape != v.shape: + raise ValueError( + f"Array shape mismatch: expected {ty.shape} got {v.shape}." + ) + if v.layout != ty.layout.to_mgpu(): + raise ValueError( + f"Array layout mismatch: expected {v.layout} got {ty.layout.to_mgpu()}." 
+ ) + case (Layout() , mgpu.FragmentedArray()) | (ParameterizedLayout(), mgpu.FragmentedArray()): + if ty.to_mgpu() != v.layout: + raise ValueError(f"Unexpected layout for {v} (expected: {ty})") + case _: + raise ValueError(f"Unexpected type {ty} for value {v}") + + +@lowering.register_lowering_rule(inline_mgpu_p, mgpu.LoweringSemantics.Lane) +def _inline_mgpu_lowering_rule( + ctx: lowering.LoweringRuleContext, + *flat_args_and_transforms, + mgpu_fn: Callable[..., Any], + flat_arg_types, + flat_ret_ty, + pytree_args, + pytree_ref_transforms, + pytree_ret_ty, +): + flat_args = flat_args_and_transforms[:pytree_args.num_leaves] + flat_arg_avals = ctx.avals_in[:pytree_args.num_leaves] + ref_transforms = pytree_ref_transforms.unflatten(flat_args_and_transforms[pytree_args.num_leaves:]) + for a, t in zip(flat_args, flat_arg_types): + _type_check_mgpu(a, t) + + flat_transformed = [] + for a, aval, t, transforms in zip( + flat_args, flat_arg_avals, flat_arg_types, ref_transforms, strict=True + ): + if not isinstance(t, RefType): + flat_transformed.append(a) + assert transforms is None + continue + assert isinstance(aval, pallas_core.AbstractMemoryRef) + a, user_transforms = lowering._handle_transforms( + ctx, a, transforms, handle_transposes=False + ) + # Transforms that do not originate from a MemoryRefTransform are + # applied implicitly (eg by emit-pipeline) and therefore we do not + # expect the user to pass them to the type. The transforms not + # passed by the user here will be discharged. + ty_transforms = _undo_transforms(aval, t.transforms) + if ty_transforms != tuple(user_transforms): + raise ValueError(f"Transform mismatch: got {user_transforms}, expected {ty_transforms}") + flat_transformed.append(a) + + args = jax.tree.unflatten(pytree_args, flat_transformed) + ret = mgpu_fn(ctx.launch_ctx, *args) + ret_leaves, ret_tree = jax.tree.flatten( + ret, is_leaf=lambda x: isinstance(x, mgpu.FragmentedArray) + ) + + if ret_tree != pytree_ret_ty: + return_type = jax.tree.unflatten(pytree_ret_ty, flat_ret_ty) + raise ValueError( + f"inline_mgpu_p return type tree mismatch: {ret} != {return_type}" + ) + + for ty, r in zip(flat_ret_ty, ret_leaves): + _type_check_mgpu(r, ty) + + return ret_leaves diff --git a/jax/_src/pallas/pallas_call.py b/jax/_src/pallas/pallas_call.py index d0b74b2e5148..295e60bb0b4d 100644 --- a/jax/_src/pallas/pallas_call.py +++ b/jax/_src/pallas/pallas_call.py @@ -15,12 +15,11 @@ """Module for calling pallas functions from JAX.""" from __future__ import annotations -from collections.abc import Callable, Sequence -import dataclasses +from collections.abc import Callable, Mapping, Sequence import enum from functools import partial, reduce import types -from typing import Any, Literal, cast +from typing import Any, cast import jax from jax import lax @@ -33,6 +32,7 @@ from jax._src import linear_util as lu from jax._src import state from jax._src import tree_util +from jax._src.frozen_dict import FrozenDict from jax._src.interpreters import ad from jax._src.interpreters import batching from jax._src.interpreters import mlir @@ -68,6 +68,8 @@ no_block_spec = pallas_core.no_block_spec ScratchShapeTree = pallas_core.ScratchShapeTree CostEstimate = pallas_core.CostEstimate +Backend = pallas_core.Backend +CompilerParams = pallas_core.CompilerParams # See the docstring for GridMapping for the calling convention pallas_call_p = jax_core.Primitive('pallas_call') @@ -79,7 +81,9 @@ def _pallas_call_impl(*args, **params): @partial(jax.jit, inline=True) def _jit_run(*args): return 
pallas_call_p.bind(*args, **params) - return _jit_run(*args) + + with config.disable_jit(False): + return _jit_run(*args) pallas_call_p.def_impl(_pallas_call_impl) @@ -93,7 +97,7 @@ def _pallas_call_abstract_eval( ): del avals - if isinstance(interpret, mosaic_tpu_interpret.TPUInterpretParams): + if isinstance(interpret, mosaic_tpu_interpret.InterpretParams): # Report effects that will be introduced when running/lowering # mosaic_tpu_interpret.mosaic_tpu_interpret.interpret_pallas_call . effs = mosaic_tpu_interpret.get_interpret_effects() @@ -121,11 +125,11 @@ def _pallas_call_jvp_rule( grid_mapping: GridMapping, mesh: pallas_core.Mesh | None, debug: bool, - interpret: bool, + interpret: Any, compiler_params: Any, cost_estimate: CostEstimate | None, out_avals: tuple[jax_core.AbstractValue, ...], - backend: _Backend | None, + backend: Backend | None, ): debug_info = jaxpr.debug_info if grid_mapping.num_dynamic_grid_bounds: @@ -220,14 +224,19 @@ def _block_map_function(new_idx, *args): block_mapping.index_map_jaxpr.consts, *drop_last_args, ) + unflat_indices = tree_util.tree_unflatten( + block_mapping.index_map_out_tree, indices) + if not isinstance(unflat_indices, tuple): + unflat_indices = (unflat_indices,) + unflat_indices = list(unflat_indices) if dim is not batching.not_mapped: if isinstance(dim, batching.RaggedAxis): assert for_ragged, "Ragged axis not supported for non-ragged batching." stacked_axis = dim.stacked_axis - indices.insert(stacked_axis, new_idx) + unflat_indices.insert(stacked_axis, new_idx) else: - indices.insert(dim, new_idx) - return tuple(indices) + unflat_indices.insert(dim, new_idx) + return tuple(unflat_indices) idx_avals = [pallas_core.index_map_grid_aval, *block_mapping.index_map_jaxpr.in_avals] if for_ragged: @@ -242,11 +251,15 @@ def _block_map_function(new_idx, *args): ) idx_avals = [*idx_avals, i32_aval_memref] + block_mapping_flat_fn, out_tree_thunk = api_util.flatten_fun_nokwargs( + lu.wrap_init(_block_map_function, + debug_info=block_mapping.index_map_jaxpr.jaxpr.debug_info), + tree_util.tree_structure(idx_avals)) with grid_mapping.trace_env(): block_mapping_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic( - lu.wrap_init(_block_map_function, - debug_info=block_mapping.index_map_jaxpr.jaxpr.debug_info), + block_mapping_flat_fn, idx_avals) + new_index_map_out_tree = out_tree_thunk() shape = block_mapping.block_shape if dim is batching.not_mapped: new_block_shape = shape @@ -257,10 +270,10 @@ def _block_map_function(new_idx, *args): new_block_shape = shape stacked_axis = dim.stacked_axis new_block_shape = tuple_insert( - new_block_shape, stacked_axis, pallas_core.mapped + new_block_shape, stacked_axis, pallas_core.squeezed ) else: - new_block_shape = tuple_insert(shape, dim, pallas_core.mapped) + new_block_shape = tuple_insert(shape, dim, pallas_core.squeezed) array_shape = block_mapping.array_shape_dtype.shape if isinstance(dim, batching.RaggedAxis): @@ -277,7 +290,8 @@ def _block_map_function(new_idx, *args): jaxpr = jax_core.ClosedJaxpr(block_mapping_jaxpr, consts) return block_mapping.replace(block_shape=new_block_shape, array_shape_dtype=new_array_shape_dtype, - index_map_jaxpr=jaxpr) + index_map_jaxpr=jaxpr, + index_map_out_tree=new_index_map_out_tree) def _broadcast_input_output_aliases( @@ -291,7 +305,7 @@ def _broadcast_input_output_aliases( When we have input/output aliasing, since the output will be mapped, we need to make sure to broadcast the input across that dimension if it is not - mapped. 
If the input is mapped, but on a different axis, we tranpose the input + mapped. If the input is mapped, but on a different axis, we transpose the input to match the output. """ @@ -324,11 +338,11 @@ def _batch_with_explicit_loop( mesh: pallas_core.Mesh | None, input_output_aliases: tuple[tuple[int, int], ...], debug: bool, - interpret: bool, + interpret: Any, compiler_params: Any, cost_estimate: CostEstimate | None, out_avals: tuple[jax_core.AbstractValue, ...], - backend: _Backend | None, + backend: Backend | None, ): """Batch the pallas_call by calling it in loop over the batch size. @@ -357,7 +371,7 @@ def _batch_with_explicit_loop( axis_size=axis_size, ) - # The output arrays are completelly overwritten, so we can just initialize + # The output arrays are completely overwritten, so we can just initialize # empty arrays. initial_state = [ jnp.empty(tuple_insert(bm.array_shape_dtype.shape, 0, axis_size), @@ -422,11 +436,11 @@ def _pallas_call_batching_rule( mesh: pallas_core.Mesh | None, input_output_aliases: tuple[tuple[int, int], ...], debug: bool, - interpret: bool, + interpret: Any, compiler_params: Any, cost_estimate: CostEstimate | None, out_avals: tuple[jax_core.AbstractValue, ...], - backend: _Backend | None, + backend: Backend | None, ): if mesh is not None: raise NotImplementedError( @@ -663,7 +677,7 @@ def get_size(i, x, d): for block_mapping in batched_grid_mapping.block_mappings: mapped_dim_idxs = [] for i, d in enumerate(block_mapping.block_shape): - if d is pallas_core.mapped: + if isinstance(d, pallas_core.Squeezed): mapped_dim_idxs.append(i) else: mapped_dim_idxs.append(None) # type: ignore[arg-type] @@ -754,7 +768,7 @@ def when_wrapped_kernel(lengths_ref, *args, **kwargs): continue arg_i_idx = ( primitives.program_id(ragged_axis_dim) - * block_shapes[i][ragged_axis_dim] + * pallas_core._get_block_dim_size(block_shapes[i][ragged_axis_dim]) ) run_kernel = jnp.logical_and(run_kernel, arg_i_idx < b_len) @@ -788,7 +802,7 @@ def index_rewrite_kernel(*indexer_args): ragged_axis_dim = per_input_ragged_axis_dim[arg_pos] # the problem here seems to be that we are rnning this for all inputs, per input, because they each have an indexer - which means - # that the indexer for output isnt getting written - before, it always was + # that the indexer for output isn't getting written - before, it always was lengths_ref = indexer_args[-1] rest_indexer_args = indexer_args[:-1] @@ -800,7 +814,8 @@ def index_rewrite_kernel(*indexer_args): nargs = list(rest_indexer_args) if ragged_axis_dim is not None: - val_at_ragged_dim = batched_block_mapping.block_shape[ragged_axis_dim] + val_at_ragged_dim = pallas_core._get_block_dim_size( + batched_block_mapping.block_shape[ragged_axis_dim]) # The current index into the ragged dimension. # Invariant: There is only one ragged dimension, enforced above. @@ -882,7 +897,7 @@ def index_rewrite_kernel(*indexer_args): raise NotImplementedError("consts not supported in pallas_call") # We need to rewrite the input_output_aliases here, the initial call - # to broadcast is done, and we have inseted a new input (lengths), so + # to broadcast is done, and we have inserted a new input (lengths), so # there's an off-by-one here now. 
new_input_output_aliases = [] for k, v in input_output_aliases: @@ -895,7 +910,7 @@ def index_rewrite_kernel(*indexer_args): batched_out_avals = [] for aval in out_avals: - sharding = aval.sharding.with_spec(tuple_insert(aval.sharding.spec, 0, None)) + sharding = aval.sharding.update(spec=tuple_insert(aval.sharding.spec, 0, None)) shape = tuple_insert(aval.shape, 0, axis_size) batched_out_avals.append(aval.update(shape=shape, sharding=sharding)) batched_out_avals = tuple(batched_out_avals) @@ -965,16 +980,15 @@ def pallas_call_checkify_oob_grid(error: checkify.Error, num_iterations = 1 is_indexing_dim = [ - tuple(b is pallas_core.mapped for b in bm.block_shape) + tuple(isinstance(b, pallas_core.Squeezed) for b in bm.block_shape) for bm in grid_mapping.block_mappings ] block_shapes = [ - None if iid is None - else tuple(1 if i else b for i, b in zip(iid, bm.block_shape)) - for iid, bm in zip(is_indexing_dim, grid_mapping.block_mappings) + pallas_core._get_block_shape(bm.block_shape) + for bm in grid_mapping.block_mappings ] # The scan carry: (i, loop_idx, *consts, *ins, *outs, *scratch) - # i:int32 is the interation index + # i:int32 is the iteration index # loop_idx: tuple[int32] are the program ids for each grid axis def cond(carry): i, *_ = carry @@ -1021,7 +1035,7 @@ def pallas_call_checkify_rule(error: checkify.Error, enabled_errors, *args: jax_core.Value, jaxpr: jax_core.Jaxpr, - interpret: bool, + interpret: Any, input_output_aliases: tuple[tuple[int, int], ...], grid_mapping: GridMapping, out_avals: tuple[jax_core.AbstractValue, ...], @@ -1118,10 +1132,10 @@ def _ensure_2d_error_shape(arg): retrace_in_avals = [*shaped_scalar_avals, *error_memref_aval, *input_aval, *error_memref_aval, *output_aval, *scratch_aval] jaxpr_flat_avals, jaxpr_in_tree = tree_util.tree_flatten(retrace_in_avals) - debug = api_util.debug_info("checkify_pallas", checked_kernel_fn, + debug_info = api_util.debug_info("checkify_pallas", checked_kernel_fn, retrace_in_avals, {}) wrapped_kernel_with_err, out_tree_thunk = api_util.flatten_fun_nokwargs( - lu.wrap_init(checked_kernel_fn, debug_info=debug), jaxpr_in_tree) + lu.wrap_init(checked_kernel_fn, debug_info=debug_info), jaxpr_in_tree) with pallas_core.tracing_grid_env(grid_mapping.grid, ()): final_jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic( @@ -1131,15 +1145,20 @@ def _ensure_2d_error_shape(arg): # for the new error inputs and outputs. 
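For context on how this error plumbing is reached, a minimal sketch, not part of the patch; fn and x are hypothetical, with fn calling pl.pallas_call internally:

import jax
from jax.experimental import checkify

checked_fn = checkify.checkify(fn, errors=checkify.all_checks)
err, out = jax.jit(checked_fn)(x)
err.throw()  # re-raises any error recorded while the kernel ran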
error_block_specs = [pallas_core.BlockSpec(None, None)] * len(shaped_err_avals) error_paths, _ = unzip2(tree_util.tree_flatten_with_path(error_block_specs)[0]) - error_origins = tuple(f"errrors[{tree_util.keystr(p)}" for p in error_paths) + error_origins = tuple(f"errors[{tree_util.keystr(p)}" for p in error_paths) error_block_mappings = map( - partial( - pallas_core._convert_block_spec_to_block_mapping, - index_map_avals=grid_mapping.index_map_avals, - index_map_tree=grid_mapping.index_map_tree, - grid=grid_mapping.grid, - mapped_dims=grid_mapping.vmapped_dims), - error_block_specs, error_origins, shaped_err_avals) + partial( + pallas_core._convert_block_spec_to_block_mapping, + index_map_avals=grid_mapping.index_map_avals, + index_map_tree=grid_mapping.index_map_tree, + grid=grid_mapping.grid, + mapped_dims=grid_mapping.vmapped_dims, + debug=True, + ), + error_block_specs, + error_origins, + shaped_err_avals, + ) input_block_mappings, output_block_mappings = split_list( grid_mapping.block_mappings, [num_kernel_inputs,]) grid_mapping_with_error = grid_mapping.replace( @@ -1187,7 +1206,7 @@ def _trace_kernel_to_jaxpr( wrapped_kernel_fun = primitives.wrap_with_transforms( wrapped_kernel_fun, kernel_in_transforms ) - with grid_mapping.trace_env(): + with grid_mapping.trace_env(), config._check_vma(False): jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_kernel_fun, kernel_avals) if consts: @@ -1206,7 +1225,7 @@ def _trace_kernel_to_jaxpr( return jaxpr, tuple(consts) -_PALLAS_USE_MOSAIC_GPU = config.bool_flag( +_PALLAS_USE_MOSAIC_GPU = config.bool_state( "jax_pallas_use_mosaic_gpu", default=config.bool_env("JAX_PALLAS_USE_MOSAIC_GPU", False), help=( @@ -1218,7 +1237,7 @@ def _trace_kernel_to_jaxpr( _PALLAS_VERBOSE_ERRORS = config.bool_flag( "jax_pallas_verbose_errors", - default=config.bool_env("JAX_PALLAS_VERBOSE_ERRORS", True), + default=config.bool_env("JAX_PALLAS_VERBOSE_ERRORS", False), help=( "If True, print verbose error messages for Pallas kernels." ), @@ -1234,23 +1253,21 @@ def _unsupported_lowering_error(platform: str) -> Exception: f"Cannot lower pallas_call on platform: {platform}. To use Pallas on GPU," " install jaxlib GPU 0.4.24 or newer. To use Pallas on TPU, install" " jaxlib TPU and libtpu. See" - " https://jax.readthedocs.io/en/latest/installation.html." + " https://docs.jax.dev/en/latest/installation.html." ) -_Backend = Literal["mosaic_tpu", "triton", "mosaic_gpu"] - def _pallas_call_lowering( ctx: mlir.LoweringRuleContext, *in_nodes, - interpret: bool, - backend: _Backend | None, + interpret: Any, + backend: Backend | None, **params, ): if params['jaxpr'].constvars: raise ValueError('Cannot lower a pallas_call with constants.') if interpret: - if isinstance(interpret, mosaic_tpu_interpret.TPUInterpretParams): + if isinstance(interpret, mosaic_tpu_interpret.InterpretParams): impl = partial(mosaic_tpu_interpret.interpret_pallas_call, interpret_params=interpret, **params) @@ -1334,6 +1351,16 @@ def _pallas_call_typecheck_rule(*in_avals, grid_mapping, **params): def _convert_out_shape_to_aval(out_shape: Any) -> jax_core.AbstractValue: match out_shape: case jax.ShapeDtypeStruct(): + if config._check_vma.value: + if out_shape.vma is None: + raise ValueError( + "When `check_vma=True` on `jax.shard_map`, `vma` on" + " `jax.ShapeDtypeStruct` must not be `None`. 
Please specify how the" + " output should be varying across mesh axes using the `vma`" + " argument of `jax.ShapeDtypeStruct` or set `check_vma=False` on" + " `jax.shard_map`.") + return jax_core.ShapedArray( + shape=out_shape.shape, dtype=out_shape.dtype, vma=out_shape.vma) return jax_core.ShapedArray(shape=out_shape.shape, dtype=out_shape.dtype) case pallas_core.MemoryRef(): return out_shape.get_array_aval() @@ -1357,11 +1384,11 @@ def _pallas_call_state_discharge_rule( grid_mapping: GridMapping, mesh: pallas_core.Mesh | None, debug: bool, - interpret: bool, + interpret: Any, compiler_params: Any, cost_estimate: CostEstimate | None, out_avals: tuple[jax_core.AbstractValue, ...], - backend: _Backend | None = None, + backend: Backend | None = None, ): del avals_out assert all(isinstance(v.aval, state.AbstractRef) for v in jaxpr.constvars) @@ -1385,7 +1412,9 @@ def _pallas_call_state_discharge_rule( index_map_tree=grid_mapping.index_map_tree, grid=grid_mapping.grid, mapped_dims=grid_mapping.mapped_dims, - ) for ref_aval, block_spec in zip(ref_avals, ref_block_specs) + debug=debug, + ) + for ref_aval, block_spec in zip(ref_avals, ref_block_specs) ] in_block_mappings, out_block_mappings = split_list( grid_mapping.block_mappings, [grid_mapping.num_inputs] @@ -1455,7 +1484,7 @@ def _rewritten_body(*args): *ref_args, *rest_args, jaxpr=new_jaxpr, - input_output_aliases=new_input_output_aliases, + input_output_aliases=tuple(new_input_output_aliases), grid_mapping=new_grid_mapping, mesh=mesh, debug=debug, @@ -1479,17 +1508,19 @@ def pallas_call( in_specs: BlockSpecTree = no_block_spec, out_specs: BlockSpecTree = no_block_spec, scratch_shapes: ScratchShapeTree = (), - input_output_aliases: dict[int, int] = {}, + input_output_aliases: Mapping[int, int] = {}, debug: bool = False, - interpret: bool = False, + interpret: Any = False, name: str | None = None, - compiler_params: dict[str, Any] | pallas_core.CompilerParams | None = None, + compiler_params: ( + Mapping[Backend, pallas_core.CompilerParams] | pallas_core.CompilerParams | None + ) = None, cost_estimate: CostEstimate | None = None, - backend: _Backend | None = None, + backend: Backend | None = None, ) -> Callable[..., Any]: """Invokes a Pallas kernel on some inputs. - See `Pallas Quickstart `_. + See `Pallas Quickstart `_. Args: kernel: the kernel function, that receives a Ref for each input and output. @@ -1527,22 +1558,22 @@ def pallas_call( This is useful for debugging. name: if present, specifies the name to use for this kernel call in debugging and error messages. To this name we append the file and line - where the kernel function is defined, .e.g: - `{name} for kernel function {kernel_name} at {file}:{line}`. - If missing, then we use `{kernel_name} at {file}:{line}`. - compiler_params: Optional compiler parameters. If a dict is provided, it - should be of the form {platform: {param_name: param_value}}, where - platform is either 'mosaic' or 'triton'. It is also possible - to pass in `jax.experimental.pallas.tpu.TPUCompilerParams` for TPUs and - `jax.experimental.pallas.gpu.TritonCompilerParams` for Triton/GPUs. - backend: Optional string literal one of "mosaic_tpu", "triton" or "mosaic_gpu" - determining the backend to be used. None means let pallas decide. - + where the kernel function is defined, .e.g: `{name} for kernel function + {kernel_name} at {file}:{line}`. If missing, then we use `{kernel_name} at + {file}:{line}`. + compiler_params: Optional compiler parameters. 
The value should either be a + backend-specific dataclass + (:class:`jax.experimental.pallas.tpu.CompilerParams`, + :class:`jax.experimental.pallas.triton.CompilerParams`, + :class:`jax.experimental.pallas.mosaic_gpu.CompilerParams`) or a dict + mapping backend name to the corresponding platform-specific dataclass. + backend: Optional string literal one of ``"mosaic_tpu"``, ``"triton"`` or + ``"mosaic_gpu"`` determining the backend to be used. None means let Pallas + decide. Returns: A function that can be called on a number of positional array arguments to invoke the Pallas kernel. - """ if grid_spec is None: grid_spec = GridSpec(grid, in_specs, out_specs, scratch_shapes) @@ -1578,30 +1609,49 @@ def pallas_call( ) + +def _normalize_compiler_params( + compiler_params: Mapping[Backend, pallas_core.CompilerParams] | pallas_core.CompilerParams | None, +) -> Mapping[Backend, pallas_core.CompilerParams]: + if compiler_params is None: + return FrozenDict({}) + if isinstance(compiler_params, CompilerParams): + compiler_params = {compiler_params.BACKEND: compiler_params} + assert isinstance(compiler_params, Mapping) + for backend, params in compiler_params.items(): + if backend not in ["mosaic_tpu", "mosaic_gpu", "triton"]: + raise ValueError(f"Unknown backend in compiler_params: {backend}") + if not isinstance(params, CompilerParams): + raise ValueError( + f"Unexpected compiler_params for backend {backend}: {params}" + ) + if params.BACKEND != backend: + raise ValueError( + f"Inconsistent backend in compiler_params: {params.BACKEND} !=" + f" {backend}" + ) + if not isinstance(compiler_params, FrozenDict): + compiler_params = FrozenDict(compiler_params) + return compiler_params + + def _pallas_call( kernel: Callable[..., None], out_shape: Any, *, grid_spec: GridSpec, mesh: pallas_core.Mesh | None = None, - input_output_aliases: dict[int, int] = {}, + input_output_aliases: Mapping[int, int] = {}, debug: bool = False, - interpret: bool = False, + interpret: Any = False, name: str | None = None, - compiler_params: dict[str, Any] | pallas_core.CompilerParams | None = None, + compiler_params: ( + Mapping[Backend, CompilerParams] | CompilerParams | None + ) = None, cost_estimate: CostEstimate | None = None, - backend: _Backend | None = None, + backend: Backend | None = None, ): - if compiler_params is None: - compiler_params = {} - if isinstance(compiler_params, pallas_core.CompilerParams): - if compiler_params.PLATFORM not in ["mosaic", "mosaic_gpu", "triton"]: - raise ValueError( - f"Unknown platform in compiler params: {compiler_params.PLATFORM}" - ) - compiler_params = { - compiler_params.PLATFORM: dataclasses.asdict(compiler_params) - } + compiler_params = _normalize_compiler_params(compiler_params) if mesh is not None: if tuple(mesh.shape.values()) != grid_spec.grid: @@ -1611,7 +1661,7 @@ def _pallas_call( ) if backend is not None: raise ValueError("If `mesh` is specified, then `backend` must be `None`.") - backend = cast(_Backend, mesh.backend) + backend = cast(Backend, mesh.backend) grid_spec, dynamic_grid_bounds = pallas_core.unzip_dynamic_grid_bounds(grid_spec) # TODO(necula): this canonicalization may be convenient for some usage @@ -1636,13 +1686,22 @@ def wrapped(*args): # TODO(necula): check that input_output_aliases is well-formed: no duplicates, etc. 
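A short usage sketch of the two accepted compiler_params forms described in the docstring above, not part of the patch; kernel and out_shape are hypothetical, and the example assumes the renamed dataclass is re-exported as jax.experimental.pallas.triton.CompilerParams:

from jax.experimental import pallas as pl
from jax.experimental.pallas import triton as plgpu_triton  # assumed alias

# Form 1: pass the backend-specific dataclass directly.
call = pl.pallas_call(
    kernel, out_shape=out_shape,
    compiler_params=plgpu_triton.CompilerParams(num_warps=4, num_stages=3),
)

# Form 2: a mapping keyed by backend name; normalization checks that each key
# agrees with the dataclass's BACKEND and freezes the mapping.
call = pl.pallas_call(
    kernel, out_shape=out_shape,
    compiler_params={"triton": plgpu_triton.CompilerParams(num_warps=4)},
)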
kernel_args, grid_mapping = pallas_core.get_grid_mapping( grid_spec, - flat_in_avals, in_tree, in_origins, - flat_out_avals, out_tree, out_origins) + flat_in_avals, + in_tree, + in_origins, + flat_out_avals, + out_tree, + out_origins, + debug, + ) flat_kernel_args, kernel_in_tree = tree_util.tree_flatten(kernel_args) flat_kernel_avals = tuple( x.ref if isinstance(x, state_types.TransformedRef) else x for x in flat_kernel_args ) + if config._check_vma.value: + flat_kernel_avals = tuple(a.update_vma(frozenset()) + for a in flat_kernel_avals) # Note that only a subset of all transforms can be found here, and they are # never expected to contain any arrays. kernel_arg_transforms = tuple( @@ -1654,7 +1713,7 @@ def wrapped(*args): if name is not None: kernel_dbg = kernel_dbg.replace_func_name(mlir.sanitize_name(name)) jaxpr, consts = _trace_kernel_to_jaxpr( - kernel, kernel_dbg, grid_mapping, tuple(flat_kernel_avals), + kernel, kernel_dbg, grid_mapping, flat_kernel_avals, kernel_in_tree, kernel_arg_transforms) for i_idx, o_idx in input_output_aliases.items(): if i_idx not in range(len(flat_in_avals)): @@ -1733,7 +1792,7 @@ def in_path_to_input_origin( # We import the TPU backend at the top level because it defines flags. Note that -# we can only do that at the bottom of this file, beacuse it also depends on +# we can only do that at the bottom of this file, because it also depends on # this module already being initialized. try: @@ -1745,5 +1804,5 @@ def in_path_to_input_origin( from jax._src.pallas.mosaic import interpret as mosaic_tpu_interpret except ImportError: mosaic_tpu_interpret = types.SimpleNamespace( # type: ignore - TPUInterpretParams=types.new_class('_NoInstances', (enum.Enum,)), + InterpretParams=types.new_class('_NoInstances', (enum.Enum,)), ) diff --git a/jax/_src/pallas/primitives.py b/jax/_src/pallas/primitives.py index 3306649f24f3..a8fdf04d9d9f 100644 --- a/jax/_src/pallas/primitives.py +++ b/jax/_src/pallas/primitives.py @@ -19,7 +19,9 @@ import enum import functools import string -from typing import Any, Callable +from collections.abc import Hashable +from typing import Any +from collections.abc import Callable, Sequence import jax from jax import lax @@ -345,9 +347,9 @@ def _atomic_cas_discharge_rule(in_avals, out_avals, ref, cmp, val): mlir.register_lowering(max_contiguous_p, lambda _, x, **__: [x]) def max_contiguous(x, values): - if not isinstance(values, list): - values = [values] - return max_contiguous_p.bind(x, values=values) + if not isinstance(values, (list, tuple)): + values = (values,) + return max_contiguous_p.bind(x, values=tuple(values)) @max_contiguous_p.def_abstract_eval def _max_contiguous_abstract_eval(aval, **_): @@ -358,9 +360,8 @@ def _max_contiguous_abstract_eval(aval, **_): multiple_of_p.def_impl(lambda x, **_: x) mlir.register_lowering(multiple_of_p, lambda _, x, **__: [x]) -def multiple_of(x: jax.Array, values: list[int] | int) -> jax.Array: - if not isinstance(values, list): - values = [values] +def multiple_of(x: jax.Array, values: Sequence[int] | int) -> jax.Array: + values = (values,) if isinstance(values, int) else tuple(values) return multiple_of_p.bind(x, values=values) @multiple_of_p.def_abstract_eval @@ -489,7 +490,7 @@ def _load_discharge_rule(in_avals, out_avals, *args_flat, args_tree, **_): scalar_dims = [not isinstance(s, Slice) and not s.shape for s in indices] slice_starts = [s.start if isinstance(s, Slice) else s for s in indices] slice_sizes = tuple(s.size if isinstance(s, Slice) else 1 for s in indices) - # fixes an inconstency 
with lax.dynamic_slice where if the slice goes out + # fixes an inconsistency with lax.dynamic_slice where if the slice goes out # of bounds, it will instead move the start_index backwards so the slice # will fit in memory. ref = _pad_values_to_avoid_dynamic_slice_oob_shift(ref, slice_sizes) @@ -878,13 +879,25 @@ def wrap_with_transforms(f, transforms, *args): run_scoped_p.multiple_results = True -def run_scoped(f: Callable[..., Any], *types: Any, **kw_types: Any) -> Any: +def run_scoped( + f: Callable[..., Any], + *types: Any, + collective_axes: Hashable | tuple[Hashable, ...] = (), + **kw_types: Any, +) -> Any: """Calls the function with allocated references and returns the result. The positional and keyword arguments describe which reference types to allocate for each argument. Each backend has its own set of reference types in addition to :class:`jax.experimental.pallas.MemoryRef`. + + When `collective_axes` is specified, the same allocation will be returned for + all programs that only differ in their program ids along the collective axes. + It is an error not to call the same `run_scoped` in all programs along that + axis. """ + if not isinstance(collective_axes, tuple): + collective_axes = (collective_axes,) flat_types, in_tree = tree_util.tree_flatten((types, kw_types)) flat_fun, out_tree_thunk = api_util.flatten_fun( lu.wrap_init(f, @@ -908,13 +921,13 @@ def run_scoped(f: Callable[..., Any], *types: Any, **kw_types: Any) -> Any: # are not in the invars of an operation so we just put them all # there. jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, avals) - out = run_scoped_p.bind(*consts, jaxpr=jaxpr) + out = run_scoped_p.bind(*consts, jaxpr=jaxpr, collective_axes=collective_axes) return tree_util.tree_unflatten(out_tree_thunk(), out) @run_scoped_p.def_effectful_abstract_eval -def _run_scoped_abstract_eval(*args, jaxpr): - del args +def _run_scoped_abstract_eval(*args, jaxpr, collective_axes): + del args, collective_axes # jaxpr will have effects for its inputs (Refs that are allocated) and for # constvars (closed over Refs). The effects for the allocated Refs are local # to the jaxpr and shouldn't propagate out. @@ -935,8 +948,12 @@ def _run_scoped_discharge_rule( out_avals, *args_flat, jaxpr, - **_): + collective_axes): del out_avals + if collective_axes: + raise NotImplementedError( + "run_scoped discharge does not support collective_axes yet." + ) num_consts = len(args_flat) # discharge_state only discharges invars, not consts, so in order to # discharge the requested refs we need to move them to the invar set. @@ -956,7 +973,9 @@ def _run_scoped_discharge_rule( # Run_scoped discharged the external variables but the scoped ones # are not discharged. - out = run_scoped_p.bind(*args_flat, jaxpr=discharged_body) + out = run_scoped_p.bind( + *args_flat, jaxpr=discharged_body, collective_axes=collective_axes + ) # Order of outputs: # (1) return values, (2) closed refs, (3) scoped refs. return_values = out[:num_return_values] @@ -975,7 +994,12 @@ def _run_scoped_discharge_rule( @functools.partial(mlir.register_lowering, run_scoped_p) -def _run_scoped_lowering_rule(ctx, *args, jaxpr): +def _run_scoped_lowering_rule(ctx, *args, jaxpr, collective_axes): + if collective_axes: + raise ValueError( + "run_scoped lowering outside of Pallas does not support" + " collective_axes." 
+ ) jaxpr_noconst = pe.convert_constvars_jaxpr(jaxpr) num_return_values = len(jaxpr_noconst.outvars) discharged_body, new_consts = state_discharge.discharge_state( @@ -993,3 +1017,247 @@ def _lower_fun(*lower_fun_args): return out[:num_return_values] return mlir.lower_fun(_lower_fun, multiple_results=True)(ctx, *args) + + +def _get_ref_and_transforms(ref): + if isinstance(ref, state.TransformedRef): + return ref.ref, ref.transforms + return ref, () + + +class DeviceIdType(enum.Enum): + MESH = "mesh" + LOGICAL = "logical" + + +def check_sem_avals( + sem_aval, sem_transforms_avals, name, allowed_semaphore_types=None +): + if allowed_semaphore_types is None: + allowed_semaphore_types = { + pallas_core.semaphore, + pallas_core.barrier_semaphore, + # For interpret mode. + pallas_core.SEMAPHORE_INTERPRET_DTYPE, + } + if not isinstance(sem_aval, state.AbstractRef): + raise ValueError(f"Cannot {name} on a non-semaphore Ref: {sem_aval}") + sem_shape = sem_aval.shape + if sem_transforms_avals: + sem_shape = sem_transforms_avals[-1].get_indexer_shape() + if sem_shape: + raise ValueError(f"Cannot {name} on a non-()-shaped semaphore: {sem_shape}") + sem_dtype = sem_aval.dtype + if not any( + jnp.issubdtype(sem_dtype, sem_type) + for sem_type in allowed_semaphore_types + ): + raise ValueError( + f"Must {name} semaphores of the following types:" + f" {allowed_semaphore_types}." + ) + + +def _transform_semaphore(ref_value, transforms, ref_aval): + """Helper function for indexing into a semaphore during state_discharge.""" + if ref_value.shape == ref_aval.shape: + return state_discharge.transform_array(ref_value, transforms) + elif len(ref_value.shape) == 0: + return ref_value + else: + raise ValueError( + f"Semaphore value shape {ref_value.shape} does not match aval shape" + f" {ref_aval.shape}" + ) + + +semaphore_read_p = jax_core.Primitive("semaphore_read") +semaphore_read_p.multiple_results = False + + +def semaphore_read(sem_or_view): + ref, transforms = _get_ref_and_transforms(sem_or_view) + args = [ref, transforms] + flat_args, args_tree = tree_util.tree_flatten(args) + return semaphore_read_p.bind(*flat_args, args_tree=args_tree) + +@semaphore_read_p.def_abstract_eval +def _semaphore_read_abstract_eval( + *avals, + args_tree, +): + del avals, args_tree + return jax_core.ShapedArray((), jnp.dtype("int32")) + +def _semaphore_read_discharge_rule(in_avals, + out_avals, + *flat_args, + args_tree): + del out_avals + [ref, transforms] = args_tree.unflatten(flat_args) + sem_value = _transform_semaphore(ref, transforms, in_avals[0]) + sem_value = sem_value.astype(jnp.int32) + return (None,) * len(in_avals), sem_value +state_discharge.register_discharge_rule(semaphore_read_p)( + _semaphore_read_discharge_rule +) + + +semaphore_signal_p = jax_core.Primitive('semaphore_signal') +semaphore_signal_p.multiple_results = True + + +def semaphore_signal( + sem_or_view, + inc: int | jax.Array = 1, + *, + device_id: int | jax.Array | None | tuple[int | jax.Array, ...] 
= None, + device_id_type: DeviceIdType = DeviceIdType.MESH, + core_index: int | jax.Array | None = None, +): + ref, transforms = _get_ref_and_transforms(sem_or_view) + inc = jnp.asarray(inc, dtype=jnp.int32) + args = [ref, transforms, inc, device_id, core_index] + flat_args, args_tree = tree_util.tree_flatten(args) + semaphore_signal_p.bind( + *flat_args, + args_tree=args_tree, + device_id_type=device_id_type, + ) + + +@semaphore_signal_p.def_abstract_eval +def _semaphore_signal_abstract_eval( + *avals, + args_tree, + device_id_type: DeviceIdType, +): + del device_id_type + ( + sem_aval, + sem_transforms_avals, + value_aval, + device_id_avals, + core_index_aval, + ) = tree_util.tree_unflatten(args_tree, avals) + check_sem_avals(sem_aval, sem_transforms_avals, "signal") + if value_aval.dtype != jnp.dtype("int32"): + raise ValueError("Must signal an int32 value.") + if device_id_avals is not None: + device_id_flat_avals = tree_util.tree_leaves(device_id_avals) + for aval in device_id_flat_avals: + if aval.dtype != jnp.dtype("int32"): + raise ValueError("`device_id`s must be an int32 value.") + return [] + + +def _semaphore_signal_pp_eqn(eqn: jax_core.JaxprEqn, + context: jax_core.JaxprPpContext, + settings: jax_core.JaxprPpSettings): + del settings + invars = eqn.invars + tree = eqn.params["args_tree"] + ( + sem, + sem_transforms, + value, + device_ids, + _, + ) = tree_util.tree_unflatten(tree, invars) + out = pp.concat([ + pp.text("semaphore_signal"), + pp.text(" "), + sp.pp_ref_transforms(context, sem, sem_transforms), + pp.text(" "), + pp.text(jax_core.pp_var(value, context)), + ]) + if device_ids is not None: + flat_device_ids = tree_util.tree_leaves(device_ids) + if not flat_device_ids: + return out + device_ids_pp = [pp.text(jax_core.pp_var(flat_device_ids[0], context))] + for device_id in flat_device_ids[1:]: + device_ids_pp.append(pp.text(" ")) + device_ids_pp.append(pp.text(jax_core.pp_var(device_id, context))) + out = pp.concat([out, pp.concat(device_ids_pp)]) + return out +jax_core.pp_eqn_rules[semaphore_signal_p] = _semaphore_signal_pp_eqn + + +def _semaphore_signal_discharge_rule(in_avals, + out_avals, + *flat_args, + args_tree, + device_id_type): + del out_avals, device_id_type + [ref, transforms, inc, device_id, core_index] = args_tree.unflatten(flat_args) + if device_id is not None: + raise NotImplementedError("Remote signal not implemented.") + if core_index is not None: + raise NotImplementedError("Multiple core support not implemented.") + sem_value = _transform_semaphore(ref, transforms, in_avals[0]) + inc = inc.astype(pallas_core.SEMAPHORE_INTERPRET_DTYPE) + _, new_sem_value = state_discharge.transform_swap_array( + ref, transforms, sem_value + inc + ) + return (new_sem_value,) + (None,) * (len(in_avals) - 1), () +state_discharge.register_discharge_rule(semaphore_signal_p)( + _semaphore_signal_discharge_rule +) + + +semaphore_wait_p = jax_core.Primitive('semaphore_wait') +semaphore_wait_p.multiple_results = True + +def semaphore_wait(sem_or_view, dec: int | jax.Array = 1): + ref, transforms = _get_ref_and_transforms(sem_or_view) + dec = jnp.asarray(dec, dtype=jnp.int32) + args = [ref, transforms, dec] + flat_args, args_tree = tree_util.tree_flatten(args) + semaphore_wait_p.bind(*flat_args, args_tree=args_tree) + +@semaphore_wait_p.def_abstract_eval +def _semaphore_wait_abstract_eval(*avals, args_tree): + sem_aval, sem_transforms_avals, value_aval = tree_util.tree_unflatten( + args_tree, avals + ) + check_sem_avals(sem_aval, sem_transforms_avals, "wait") + if 
value_aval.dtype != jnp.dtype("int32"): + raise ValueError("Must wait an int32 value.") + return [] + +def _semaphore_wait_pp_eqn(eqn: jax_core.JaxprEqn, + context: jax_core.JaxprPpContext, + settings: jax_core.JaxprPpSettings): + del settings + invars = eqn.invars + tree = eqn.params["args_tree"] + ( + sem, + sem_transforms, + value, + ) = tree_util.tree_unflatten(tree, invars) + return pp.concat([ + pp.text("semaphore_wait"), + pp.text(" "), + sp.pp_ref_transforms(context, sem, sem_transforms), + pp.text(" "), + pp.text(jax_core.pp_var(value, context)), + ]) +jax_core.pp_eqn_rules[semaphore_wait_p] = _semaphore_wait_pp_eqn + +def _semaphore_wait_discharge_rule(in_avals, + out_avals, + *flat_args, + args_tree): + del out_avals + [ref, transforms, dec] = args_tree.unflatten(flat_args) + sem_value = _transform_semaphore(ref, transforms, in_avals[0]) + dec = dec.astype(pallas_core.SEMAPHORE_INTERPRET_DTYPE) + _, new_sem_value = state_discharge.transform_swap_array( + ref, transforms, sem_value - dec + ) + return (new_sem_value,) + (None,) * (len(in_avals) - 1), () +state_discharge.register_discharge_rule(semaphore_wait_p)( + _semaphore_wait_discharge_rule +) diff --git a/jax/_src/pallas/triton/BUILD b/jax/_src/pallas/triton/BUILD index cde2aadd6013..f7c4a05205d3 100644 --- a/jax/_src/pallas/triton/BUILD +++ b/jax/_src/pallas/triton/BUILD @@ -60,12 +60,16 @@ pytype_strict_library( deps = [ "//jax", "//jax:ad_util", + "//jax:api", "//jax:api_util", "//jax:config", "//jax:core", + "//jax:custom_derivatives", + "//jax:lax", "//jax:mlir", "//jax:partial_eval", "//jax:source_info_util", + "//jax:state_types", "//jax:util", "//jax/_src/lib", "//jax/_src/pallas", @@ -76,12 +80,11 @@ pytype_strict_library( name = "pallas_call_registration", srcs = ["pallas_call_registration.py"], deps = [ + ":core", ":lowering", "//jax", - "//jax:config", "//jax:core", "//jax:mlir", - "//jax:util", "//jax/_src/lib", "//jax/_src/pallas", ], diff --git a/jax/_src/pallas/triton/core.py b/jax/_src/pallas/triton/core.py index 097f8497e8f7..7b6e69dc8dd8 100644 --- a/jax/_src/pallas/triton/core.py +++ b/jax/_src/pallas/triton/core.py @@ -21,7 +21,7 @@ from jax._src.pallas import core as pallas_core @dataclasses.dataclass(frozen=True) -class TritonCompilerParams(pallas_core.CompilerParams): +class CompilerParams(pallas_core.CompilerParams): """Compiler parameters for Triton. Attributes: @@ -32,7 +32,7 @@ class TritonCompilerParams(pallas_core.CompilerParams): serialized_metadata: Additional compiler metadata. This field is unstable and may be removed in the future. """ - PLATFORM: ClassVar[str] = "triton" + BACKEND: ClassVar[pallas_core.Backend] = "triton" num_warps: int | None = None num_stages: int | None = None serialized_metadata: bytes | None = None diff --git a/jax/_src/pallas/triton/lowering.py b/jax/_src/pallas/triton/lowering.py index f3a8dd175ec1..e2fb6705de4c 100644 --- a/jax/_src/pallas/triton/lowering.py +++ b/jax/_src/pallas/triton/lowering.py @@ -21,7 +21,8 @@ import functools import math import operator -from typing import Any, Hashable, TypeVar +from typing import Any, TypeVar +from collections.abc import Hashable import jax from jax import lax @@ -90,7 +91,7 @@ class BlockInfo: full_shape_dtype: jax.ShapeDtypeStruct start_indices: Sequence[Any] start_indices_alignment: Sequence[int] - block_shape: tuple[int | pallas_core.Mapped, ...] + block_shape: tuple[int | pallas_core.Squeezed, ...] 
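A hedged usage sketch of the semaphore primitives defined above, not part of the patch; the export paths and the SemaphoreType.REGULAR spelling are assumptions based on the existing jax.experimental.pallas.tpu API:

from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu  # assumed export path

def kernel(x_ref, o_ref):
  def body(sem):
    pltpu.semaphore_signal(sem, 1)  # increment by one
    pltpu.semaphore_wait(sem, 1)    # block until the value reaches one, then decrement
    o_ref[...] = x_ref[...]
  pl.run_scoped(body, pltpu.SemaphoreType.REGULAR)

Separately, run_scoped now accepts collective_axes=; per its docstring above, the same allocation is then returned for all programs that differ only in their program ids along those axes.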
@dataclasses.dataclass @@ -121,32 +122,45 @@ def _eval_index_map( block_indices = lower_jaxpr_to_triton_ir( ctx, block_mapping.index_map_jaxpr.jaxpr, None, *idx ) - block_indices = ( + block_indices = tuple( _ensure_ir_value(i, jax_core.ShapedArray((), jnp.int32)) for i in block_indices ) - if isinstance(block_mapping.indexing_mode, pallas_core.Unblocked): - if block_mapping.indexing_mode.padding is not None: - raise NotImplementedError( - "Unblocked indexing with padding is not supported in Triton lowering." - ) - if block_mapping.pipeline_mode is not None: - raise NotImplementedError( - "Pipeline mode is not supported in Triton lowering." - ) - return tuple(block_indices) + block_indices = tree_util.tree_unflatten( + block_mapping.index_map_out_tree, block_indices) + if block_mapping.pipeline_mode is not None: + raise NotImplementedError( + "Pipeline mode is not supported in Triton lowering." + ) + if any( + isinstance(b, pallas_core.Element) and b.padding != (0, 0) + for b in block_mapping.block_shape + ): + raise NotImplementedError( + "Unblocked indexing with padding is not supported in Triton lowering." + ) + def _get_start_index(i, b): + match b: + case pallas_core.Squeezed() | pallas_core.Element(): + return i + case pallas_core.Blocked(): + return _mul(i, _ir_constant(b.block_size, i.type)) + case _: + raise ValueError(f"Unsupported block dim type: {type(b)}") return tuple( - i if b is pallas_core.mapped else _mul(i, _ir_constant(b, i.type)) - for i, b in zip(block_indices, block_mapping.block_shape) + _get_start_index(i, b) for i, b in + zip(block_indices, block_mapping.block_shape) ) def _get_index_alignment(block_mapping: BlockMapping) -> tuple[int, ...]: - if isinstance(block_mapping.indexing_mode, pallas_core.Unblocked): - return (1,) * len(block_mapping.block_shape) - return tuple( - 1 if b is pallas_core.mapped else b for b in block_mapping.block_shape - ) + def _get_bdim_alignment(b: pallas_core.BlockDim): + match b: + case pallas_core.Squeezed() | pallas_core.Element(): + return 1 + case pallas_core.Blocked(): + return b.block_size + return tuple(_get_bdim_alignment(b) for b in block_mapping.block_shape) def _bcast_to(a: ir.Value, shape: tuple[int, ...]) -> ir.Value: @@ -274,8 +288,9 @@ def _new_ir_context() -> ir.Context: # this). This check is only needed to obtain a nicer error message; the # Triton lowering will fail anyway but it will crash with a C++ exception. # We currently apply this check only to load/store operations. 
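A standalone paraphrase, not library code, of the start-index and alignment rules implemented above, with plain ints standing in for ir.Values; the private jax._src import mirrors the module being patched and is for illustration only:

from jax._src.pallas import core as pallas_core

def start_index(block_index: int, dim) -> int:
  # Squeezed and Element dims already carry element indices; Blocked dims
  # carry block indices that must be scaled by the block size.
  if isinstance(dim, (pallas_core.Squeezed, pallas_core.Element)):
    return block_index
  if isinstance(dim, pallas_core.Blocked):
    return block_index * dim.block_size
  raise ValueError(f"Unsupported block dim type: {type(dim)}")

def index_alignment(dim) -> int:
  # The Triton lowering may only assume alignment for Blocked dims.
  return dim.block_size if isinstance(dim, pallas_core.Blocked) else 1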
-def _check_tensor_size(shape: tuple[int | pallas_core.Mapped, ...]): - size = math.prod(1 if d is pallas_core.mapped else d for d in shape) +def _check_tensor_size(shape: tuple[int | pallas_core.Squeezed, ...]): + size = math.prod(1 if isinstance(d, pallas_core.Squeezed) else d + for d in shape) power_of_2 = (size & (size - 1)) == 0 if not power_of_2: raise ValueError( @@ -347,7 +362,9 @@ def lower_jaxpr_to_triton_module( block_mapping.array_shape_dtype, _eval_index_map(ctx, program_ids, block_mapping), _get_index_alignment(block_mapping), - block_mapping.block_shape, + tuple(pallas_core.squeezed if isinstance(b, pallas_core.Squeezed) + else pallas_core._get_block_dim_size(b) + for b in block_mapping.block_shape), ) for block_mapping in grid_mapping.block_mappings ] @@ -654,7 +671,9 @@ def _make_dispatch_table( name: str, **tables: Sequence[_Extern | _Fallback] ) -> Callable[..., ir.Value]: - def inner(ctx: LoweringRuleContext, *args: ir.Value) -> ir.Value: + def inner( + ctx: LoweringRuleContext, *args: ir.Value, **_ + ) -> ir.Value: table = tables[ctx.context.platform] h = next((e for e in table if e.matches(ctx.avals_in)), None) if h is None: @@ -1120,7 +1139,7 @@ def inner(ctx: LoweringRuleContext, *args: ir.Value) -> ir.Value: def _minus(x: ir.Value) -> ir.Value: if tt_dialect.PointerType.isinstance(_element_type(x.type)): raise NotImplementedError(f"unsupported type: {x.type}") - return _sub(_full(x.type, 0), x) + return _sub(_zeros_like(x), x) def _add(x: ir.Value, y: ir.Value): @@ -1260,6 +1279,10 @@ def _cmp( ) +def _is_nan(x: ir.Value) -> ir.Value: + return arith_dialect.cmpf(arith_dialect.CmpFPredicate.UNO, x, x) + + _JAX_TO_TRITON_BINARY = { lax.add_p: _add, lax.sub_p: _sub, @@ -1373,7 +1396,7 @@ def _broadcast_to_rule(ctx: LoweringRuleContext, x, shape: Sequence[int]): @register_lowering(lax.integer_pow_p) def _integer_pow_rule(ctx: LoweringRuleContext, x, *, y: int): if y == 0: - return _full(x.type, 1) + return _ones_like(x) is_reciprocal = y < 0 if is_reciprocal: @@ -1393,14 +1416,14 @@ def _integer_pow_rule(ctx: LoweringRuleContext, x, *, y: int): acc = _cast(acc, x_aval.dtype, out_aval.dtype) if is_reciprocal: signed = jnp.issubdtype(out_aval.dtype, jnp.signedinteger) - return _truediv(_full(acc.type, 1), acc, signed=signed) + return _truediv(_ones_like(acc), acc, signed=signed) else: return acc _JAX_FN_MAPPING = { lax.clamp_p: lambda min, a, max: jnp.minimum(jnp.maximum(min, a), max), - lax.logistic_p: lambda a: 1 / (1 + jnp.exp(-a)), + lax.logistic_p: lambda a, accuracy: 1 / (1 + jnp.exp(-a)), } for prim, fn in _JAX_FN_MAPPING.items(): @@ -1514,6 +1537,22 @@ def _full(t: ir.Type, v: object) -> ir.Type: return result +def _zeros(t: ir.Type) -> ir.Value: + return _full(t, 0) + + +def _zeros_like(x: ir.Value) -> ir.Value: + return _full(x.type, 0) + + +def _ones(t: ir.Type) -> ir.Value: + return _full(t, 1) + + +def _ones_like(x: ir.Value) -> ir.Value: + return _full(x.type, 1) + + def _splat(x: ir.value, shape: Sequence[int]) -> ir.Value: if ir.RankedTensorType.isinstance(x.type): raise TypeError("cannot splat a tensor") @@ -1534,11 +1573,10 @@ def _float_float_cast(src: ir.Value, dst_type: ir.Type) -> ir.Value: src_element_type = ir.FloatType(_element_type(src.type)) dst_element_type = ir.FloatType(_element_type(dst_type)) if src_element_type.width == 8 or dst_element_type.width == 8: - return tt_dialect.fp_to_fp( - dst_type, - src, - rounding=tt_dialect.RoundingMode.RTNE, + rounding = ( + tt_dialect.RoundingMode.RTNE if src_element_type.width > 8 else None ) + return 
tt_dialect.fp_to_fp(dst_type, src, rounding=rounding) if src_element_type.width > dst_element_type.width: return arith_dialect.truncf(dst_type, src) elif src_element_type.width < dst_element_type.width: @@ -1552,7 +1590,7 @@ def _int_int_cast(src: ir.Value, dst_type: ir.Type, signed: bool) -> ir.Value: dst_element_type = ir.IntegerType(_element_type(dst_type)) assert src_element_type != dst_element_type if dst_element_type.width == 1: - return _not_equal(src, _full(src.type, 0), signed=signed) + return _not_equal(src, _zeros_like(src), signed=signed) if src_element_type.width == dst_element_type.width: return arith_dialect.bitcast(dst_type, src) @@ -1572,7 +1610,7 @@ def _float_int_cast( raise NotImplementedError(f"cannot cast {src} tp {dst_type}") dst_element_type = ir.IntegerType(_element_type(dst_type)) if dst_element_type.width == 1: - return _not_equal(src, _full(src.type, 0), signed=signed) + return _not_equal(src, _zeros_like(src), signed=signed) else: # We clamp the float value to the min/max integer destination value # in order to match JAX/XLA casting behavior. Note that this differs @@ -1675,7 +1713,7 @@ def _ir_cast(src: ir.Value, dst_type: ir.Type, *, return tt_dialect.ptr_to_int(dst_type, src) elif dst_element_type.width == 1: x = _ir_cast(src, ir.IntegerType.get_signless(64), signed=signed) - zero = _full(x.type, 0) + zero = _zeros_like(x) return _ir_cast(_not_equal(x, zero, signed=signed), dst_type, signed=signed) if isinstance( src_element_type, ir.IntegerType @@ -1759,6 +1797,12 @@ def _reshape(a: ir.Value, shape: Sequence[int]) -> ir.Value: ) +def get_join_type(old_type: ir.RankedTensorType): + shape = old_type.shape + shape.append(2) + return ir.RankedTensorType.get(shape, old_type.element_type, old_type.encoding) + + @register_lowering(lax.concatenate_p) def _concatenate_lowering_rule(ctx: LoweringRuleContext, *args, dimension): if len(args) != 2: @@ -1773,16 +1817,40 @@ def _concatenate_lowering_rule(ctx: LoweringRuleContext, *args, dimension): raise NotImplementedError( "Only arguments with shape [..., 1] are supported." ) - return tt_dialect.join( - _reshape(x, x_aval.shape[:-1]), _reshape(y, y_aval.shape[:-1]) - ) + lhs = _reshape(x, x_aval.shape[:-1]) + rhs = _reshape(y, y_aval.shape[:-1]) + ret_type = get_join_type(ir.RankedTensorType(rhs.type)) + return tt_dialect.join(ret_type, lhs, rhs) + + +@register_lowering(lax.split_p) +def _split_lowering_rule(ctx: LoweringRuleContext, x, *, sizes, axis): + pass + # TODO(cjfj): Add support for larger powers of 2. 
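Note on the split rule whose body follows: it builds an N-way split out of repeated two-way splits (reshape the split axis into `(2, half)`, move the 2 to the last axis, peel the halves apart), which is why only power-of-two counts of equal-sized parts are reachable. A NumPy stand-in for the same scheme, with a local `next_power_of_2` mirroring `pallas_utils.next_power_of_2`:

```python
# Sketch under the stated assumptions; np.moveaxis/indexing stand in for tt.trans/tt.split.
import numpy as np

def next_power_of_2(n: int) -> int:
  return 1 if n == 0 else 2 ** (n - 1).bit_length()

def split_pow2(x: np.ndarray, num_parts: int, axis: int) -> tuple[np.ndarray, ...]:
  if num_parts != next_power_of_2(num_parts):
    raise NotImplementedError("Only power-of-2 num parts supported.")
  def split_into_2(y):
    shape = list(y.shape)
    y = y.reshape(shape[:axis] + [2, shape[axis] // 2] + shape[axis + 1:])
    y = np.moveaxis(y, axis, -1)        # move the new size-2 axis to the end
    return y[..., 0], y[..., 1]          # the two halves along `axis`
  parts = (x,)
  while len(parts) < num_parts:          # each round doubles the part count
    parts = sum(map(split_into_2, parts), ())
  return parts

a, b, c, d = split_pow2(np.arange(16).reshape(2, 8), 4, axis=1)
np.testing.assert_array_equal(np.concatenate([a, b, c, d], axis=1),
                              np.arange(16).reshape(2, 8))
```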
+ num_parts = len(sizes) + if num_parts != pallas_utils.next_power_of_2(num_parts): + raise NotImplementedError("Only power-of-2 num parts supported.") + if any(size != sizes[0] for size in sizes): + raise NotImplementedError("Only equal-sized splits are supported.") + + def split_into_2(x): + shape = ir.RankedTensorType(x.type).shape + x = _reshape(x, shape[:axis] + [2, shape[axis] // 2] + shape[axis + 1 :]) + permutation = tuple(d for d in range(len(shape) + 1) if d != axis) + (axis,) + return tuple(tt_dialect.split(tt_dialect.trans(x, permutation))) + + x_parts = (x,) + while len(x_parts) < num_parts: + x_parts = sum(map(split_into_2, x_parts), ()) + return x_parts def _compute_offsets_from_indices( block_info: BlockInfo, nd_indexer: NDIndexer ) -> ir.Value: full_shape = block_info.full_shape_dtype.shape - num_mapped_dims = sum(b is pallas_core.mapped for b in block_info.block_shape) + num_squeezed_dims = sum(isinstance(b, pallas_core.Squeezed) + for b in block_info.block_shape) strides = pallas_utils.strides_from_shape(full_shape) indexer_shape = nd_indexer.get_indexer_shape() int_indexer_shape = nd_indexer.int_indexer_shape @@ -1790,7 +1858,7 @@ def _compute_offsets_from_indices( indices = nd_indexer.indices other_shape = indexer_shape[len(int_indexer_shape) :] other_shape_idx = 0 - assert len(indices) + num_mapped_dims == len(full_shape) + assert len(indices) + num_squeezed_dims == len(full_shape) assert len(block_info.start_indices) == len(full_shape) array_dtype = jnp.dtype(block_info.full_shape_dtype.dtype) @@ -1798,7 +1866,7 @@ def _compute_offsets_from_indices( # Use 64-bit indexing when offset might be >= 2**32 bytes. offset_eltype = ir.IntegerType.get_signless(64 if full_size > 2**32 else 32) if indexer_shape: - offsets = _full(ir.RankedTensorType.get(indexer_shape, offset_eltype), 0) + offsets = _zeros(ir.RankedTensorType.get(indexer_shape, offset_eltype)) else: offsets = _ir_constant(0, offset_eltype) @@ -1806,10 +1874,11 @@ def _compute_offsets_from_indices( for dim_stride, dim_block_size, start_offset in zip( strides, block_info.block_shape, block_info.start_indices ): - if dim_block_size is pallas_core.mapped: - index = _ir_constant(0, offset_eltype) - else: - index = next(indexer_iter) + match dim_block_size: + case pallas_core.Squeezed(): + index = _ir_constant(0, offset_eltype) + case int(): + index = next(indexer_iter) if isinstance(index, slice): index = primitives.Slice.from_slice(index, dim_block_size) @@ -2060,17 +2129,18 @@ def _masked_load_lowering_rule( # most significant. Before jaxlib 0.5.2, the order was reversed. 
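Note: the comment above describes the int4 packing convention that the `shrui`/`join` sequence below implements. A NumPy stand-in for the nibble layout (not the Triton IR ops themselves):

```python
import numpy as np

packed = np.array([0x21, 0x43], dtype=np.uint8)   # nibbles: 1,2 and 3,4
low  = packed & 0x0F            # first logical element (least significant bits)
high = packed >> 4              # second logical element (most significant bits)
unpacked = np.stack([low, high], axis=-1).reshape(-1)       # join + flatten -> [1 2 3 4]
unpacked_pre_052 = np.stack([high, low], axis=-1).reshape(-1)  # pre-0.5.2 order -> [2 1 4 3]
```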
if is_contiguous_int4: msb_values = arith_dialect.shrui(values, _full(values.type, 4)) + join_type = get_join_type(ir.RankedTensorType(values.type)) if jaxlib_version < (0, 5, 2): - values = tt_dialect.join(msb_values, values) + values = tt_dialect.join(join_type, msb_values, values) else: - values = tt_dialect.join(values, msb_values) + values = tt_dialect.join(join_type, values, msb_values) shape = ir.RankedTensorType(values.type).shape values = _reshape(values, (*shape[:-2], shape[-2] * shape[-1])) else: offsets = _ir_cast(offsets, ir.IntegerType.get_signless(32), signed=False) in_msb = _mod(offsets, _full(offsets.type, 2), signed=False) if jaxlib_version < (0, 5, 2): - in_msb = arith_dialect.xori(in_msb, _full(in_msb.type, 1)) + in_msb = arith_dialect.xori(in_msb, _ones_like(in_msb)) shift = _mul(in_msb, _full(in_msb.type, 4)) shift = _ir_cast(shift, values.type, signed=False) values = arith_dialect.shrui(values, shift) @@ -2198,6 +2268,14 @@ def _transpose_lowering(ctx: LoweringRuleContext, x, *, permutation): _TF32_PRECISIONS = (lax.Precision.HIGH, lax.Precision.DEFAULT) +def _as_bf16(x): + return _ir_cast(x, _dtype_to_ir_type(jnp.bfloat16), signed=False) + + +def _as_f32(x): + return _ir_cast(x, _dtype_to_ir_type(jnp.float32), signed=False) + + @register_lowering(lax.dot_general_p) def _dot_general_lowering( ctx: LoweringRuleContext, @@ -2237,6 +2315,9 @@ def _dot_general_lowering( | lax.DotAlgorithmPreset.F16_F16_F32 | lax.DotAlgorithmPreset.BF16_BF16_BF16 | lax.DotAlgorithmPreset.BF16_BF16_F32 + | lax.DotAlgorithmPreset.BF16_BF16_F32_X3 + | lax.DotAlgorithmPreset.BF16_BF16_F32_X6 + | lax.DotAlgorithmPreset.BF16_BF16_F32_X9 ): input_precision = None case _: @@ -2275,7 +2356,40 @@ def _dot_general_lowering( m, _ = a_type.shape _, n = b_type.shape - acc = _full(ir.RankedTensorType.get([m, n], _dtype_to_ir_type(acc_dtype)), 0) + acc = _zeros(ir.RankedTensorType.get([m, n], _dtype_to_ir_type(acc_dtype))) + + if precision in ( + lax.DotAlgorithmPreset.BF16_BF16_F32_X3, + lax.DotAlgorithmPreset.BF16_BF16_F32_X6, + lax.DotAlgorithmPreset.BF16_BF16_F32_X9, + ): + a_bf16 = _as_bf16(a) + b_bf16 = _as_bf16(b) + a_err0 = _sub(a, _as_f32(a_bf16)) + b_err0 = _sub(b, _as_f32(b_bf16)) + a_err0_bf16 = _as_bf16(a_err0) + b_err0_bf16 = _as_bf16(b_err0) + a_err1_bf16 = _as_bf16(_sub(a_err0, _as_f32(a_err0_bf16))) + b_err1_bf16 = _as_bf16(_sub(b_err0, _as_f32(b_err0_bf16))) + # Accumulate the smallest values first to reduce the numeric error. + if precision == lax.DotAlgorithmPreset.BF16_BF16_F32_X9: + acc = tt_dialect.dot(a_err1_bf16, b_err0_bf16, acc) + acc = tt_dialect.dot(a_err1_bf16, b_err1_bf16, acc) + acc = tt_dialect.dot(a_err0_bf16, b_err1_bf16, acc) + if precision in ( + lax.DotAlgorithmPreset.BF16_BF16_F32_X6, + lax.DotAlgorithmPreset.BF16_BF16_F32_X9, + ): + acc = tt_dialect.dot(a_err1_bf16, b_bf16, acc) + acc = tt_dialect.dot(a_bf16, b_err1_bf16, acc) + acc = tt_dialect.dot(a_err0_bf16, b_err0_bf16, acc) + acc = tt_dialect.dot(a_err0_bf16, b_bf16, acc) + acc = tt_dialect.dot(a_bf16, b_err0_bf16, acc) + # If `a` rounding error is zero and `b` is `inf` then `acc` may contain + # `NaN`s (as `0 * inf = NaN`), and vice versa. 
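Note: as context for the error-term accumulation above and the NaN select that follows (a zero error term multiplied by an infinite operand produces NaN, which the select zeroes out), here is a JAX-level sketch of the three-way bf16 splitting behind BF16_BF16_F32_X3/X6/X9. `split_bf16` and `bf16_dot` are illustrative helpers, not functions from this module, and the dots here model, rather than reuse, `tt.dot`.

```python
import jax.numpy as jnp

def split_bf16(x):
  # x ~= x0 + x1 + x2, each term representable in bfloat16.
  x0 = x.astype(jnp.bfloat16)
  e0 = x - x0.astype(jnp.float32)
  x1 = e0.astype(jnp.bfloat16)
  x2 = (e0 - x1.astype(jnp.float32)).astype(jnp.bfloat16)
  return x0, x1, x2

def bf16_dot(a, b):
  # Models a bf16 x bf16 -> f32 dot; the cast back to f32 loses nothing.
  return jnp.dot(a.astype(jnp.float32), b.astype(jnp.float32))

a = jnp.array([[1.2345678, -2.7182818]], dtype=jnp.float32)
b = jnp.array([[3.1415927], [0.5772157]], dtype=jnp.float32)
a0, a1, a2 = split_bf16(a)
b0, b1, b2 = split_bf16(b)
# X3 keeps the three largest cross terms; X6 and X9 add progressively smaller ones.
x3 = bf16_dot(a1, b0) + bf16_dot(a0, b1) + bf16_dot(a0, b0)
print(jnp.dot(a, b), x3)  # x3 is far closer to the f32 result than a0 @ b0 alone
```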
+ acc = arith_dialect.select(_is_nan(acc), _zeros_like(acc), acc) + a, b = a_bf16, b_bf16 + acc = tt_dialect.dot(a, b, acc, input_precision=input_precision) return _cast(acc, acc_dtype, out_aval.dtype) diff --git a/jax/_src/pallas/triton/pallas_call_registration.py b/jax/_src/pallas/triton/pallas_call_registration.py index 4e8775e514f0..9bb5c8f21628 100644 --- a/jax/_src/pallas/triton/pallas_call_registration.py +++ b/jax/_src/pallas/triton/pallas_call_registration.py @@ -17,16 +17,17 @@ from __future__ import annotations import io -from typing import Any +from typing import cast import zlib import jax import jax._src.core as jax_core from jax._src.interpreters import mlir -from jax._src.lib import triton from jax._src.lib import gpu_triton as triton_kernel_call_lib +from jax._src.lib import triton from jax._src.lib.mlir import ir from jax._src.pallas import core as pallas_core +from jax._src.pallas.triton import core as triton_core from jax._src.pallas.triton import lowering @@ -39,7 +40,7 @@ def normalize_grid(grid: pallas_core.StaticGrid) -> tuple[int, int, int]: def avals_to_layouts(avals): - return [list(reversed(range(aval.ndim))) for aval in avals] + return [list(reversed(range(aval.ndim))) for aval in avals] # pytype: disable=attribute-error def pallas_call_lowering( @@ -51,7 +52,7 @@ def pallas_call_lowering( input_output_aliases: tuple[tuple[int, int], ...], grid_mapping: pallas_core.GridMapping, mesh: pallas_core.Mesh | None, - compiler_params: dict[str, Any], + compiler_params: dict[str, pallas_core.CompilerParams], cost_estimate: pallas_core.CostEstimate | None, out_avals: tuple[jax_core.AbstractValue, ...], ): @@ -67,16 +68,17 @@ def pallas_call_lowering( ) if mesh is not None: raise NotImplementedError("mesh is not supported in the Triton backend") - triton_params = compiler_params.get("triton", compiler_params) - num_warps = triton_params.get("num_warps", 4) - num_warps = 4 if num_warps is None else num_warps + [lowering_platform] = ctx.platforms or ctx.module_context.platforms - if lowering_platform == "rocm": - num_stages = triton_params.get("num_stages", 1) - num_stages = 1 if num_stages is None else num_stages + + if "triton" in compiler_params: + params = cast(triton_core.CompilerParams, compiler_params["triton"]) else: - num_stages = triton_params.get("num_stages", 3) - num_stages = 3 if num_stages is None else num_stages + params = triton_core.CompilerParams() + num_warps = 4 if params.num_warps is None else params.num_warps + num_stages = params.num_stages + if num_stages is None: + num_stages = 1 if lowering_platform == "rocm" else 3 if debug: print(f"\nThe kernel jaxpr for pallas_call {debug_info.func_src_info}:") @@ -117,12 +119,11 @@ def pallas_call_lowering( grid_z=mlir.i32_attr(grid_z), debug=ir.BoolAttr.get(debug), ) - if "serialized_metadata" in (triton_params or {}): + if params.serialized_metadata is not None: # This field is unstable and may be removed in the future. 
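Note on the compiler-parameter handling above: the defaults it applies are warps=4 always, and stages=3 on CUDA or 1 on ROCm when unset. A plain-Python restatement of that resolution (not a public API, just the same logic spelled out):

```python
def resolve_defaults(num_warps, num_stages, platform):
  num_warps = 4 if num_warps is None else num_warps
  if num_stages is None:
    num_stages = 1 if platform == "rocm" else 3
  return num_warps, num_stages

assert resolve_defaults(None, None, "cuda") == (4, 3)
assert resolve_defaults(None, None, "rocm") == (4, 1)
assert resolve_defaults(8, 2, "cuda") == (8, 2)
```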
- if triton_params["serialized_metadata"] is not None: - backend_config["serialized_metadata"] = ir.StringAttr.get( - triton_params["serialized_metadata"] - ) + backend_config["serialized_metadata"] = ir.StringAttr.get( + params.serialized_metadata + ) return mlir.custom_call( call_target_name="__gpu$xla.gpu.triton", result_types=out_types, @@ -178,10 +179,10 @@ def pallas_call_lowering( call_target_name="triton_kernel_call", result_types=[*map(mlir.aval_to_ir_type, ctx.avals_out)], operands=in_nodes, - backend_config=zlib.compress( + backend_config=zlib.compress( kernel_call.to_proto( debug_info.func_name, - triton_params.get("serialized_metadata") or b"", + params.serialized_metadata or b"", ) ), operand_layouts=avals_to_layouts(ctx.avals_in), diff --git a/jax/_src/pallas/triton/primitives.py b/jax/_src/pallas/triton/primitives.py index b845a4079ff4..2a15b3dbd47d 100644 --- a/jax/_src/pallas/triton/primitives.py +++ b/jax/_src/pallas/triton/primitives.py @@ -83,7 +83,7 @@ def elementwise_inline_asm( asm=asm, constraints=constraints, pack=pack, - result_shape_dtypes=result_shape_dtypes, + result_shape_dtypes=tuple(result_shape_dtypes), ) diff --git a/jax/_src/pallas/utils.py b/jax/_src/pallas/utils.py index a78c5487a4d6..15844da927e3 100644 --- a/jax/_src/pallas/utils.py +++ b/jax/_src/pallas/utils.py @@ -44,7 +44,7 @@ def cdiv(a: jax.Array, b: jax.Array) -> jax.Array: def cdiv(a: int | jax.Array, b: int | jax.Array) -> int | jax.Array: if isinstance(a, int) and isinstance(b, int): return (a + b - 1) // b - return lax.div(a + b - 1, b) + return lax.div(a + (b - 1), b) def strides_from_shape(shape: tuple[int, ...]) -> tuple[int, ...]: diff --git a/jax/_src/partition_spec.py b/jax/_src/partition_spec.py index bf6a90060bc8..5f743c9c141b 100644 --- a/jax/_src/partition_spec.py +++ b/jax/_src/partition_spec.py @@ -13,37 +13,31 @@ # limitations under the License. from __future__ import annotations +from typing import Any -class UnconstrainedSingleton: - - def __repr__(self): - return "UNCONSTRAINED" - - def __reduce__(self): - return (_get_default_unconstrained, ()) +from jax._src.lib import _jax +from jax._src.util import use_cpp_class, use_cpp_method +_UNCONSTRAINED_PARTITION = _jax.UNCONSTRAINED_PARTITION +_canonicalize_partition = _jax.canonicalize_partition -# Unconstrained sentinel value for PartitionSpec, representing a dimension for -# which the user wants XLA to assign the best partitioning. -# TODO(yashkatariya): May rename to AUTO. -_UNCONSTRAINED_PARTITION = UnconstrainedSingleton() -def _get_default_unconstrained(): - return _UNCONSTRAINED_PARTITION +def unpickle_pspec(partitions, unreduced, reduced): + return PartitionSpec(*partitions, unreduced=unreduced, reduced=reduced) -def _canonicalize_partition(partition): - if not partition: - return None - if partition is _UNCONSTRAINED_PARTITION: - return _UNCONSTRAINED_PARTITION - if isinstance(partition, (tuple, list)): - if len(partition) == 1: - return partition[0] - return tuple(partition) - return partition +def _get_ur_str(unreduced, reduced): + if unreduced and reduced: + return f"unreduced={set(unreduced)!r}, reduced={set(reduced)!r}" + elif unreduced and not reduced: + return f"unreduced={set(unreduced)!r}" + elif not unreduced and reduced: + return f"reduced={set(reduced)!r}" + assert False # unreachable +AxisName = Any -class PartitionSpec(tuple): +@use_cpp_class(_jax.PartitionSpec) +class PartitionSpec: """Tuple describing how to partition an array across a mesh of devices. 
Each element is either ``None``, a string, or a tuple of strings. @@ -52,38 +46,120 @@ class PartitionSpec(tuple): This class exists so JAX's pytree utilities can distinguish a partition specifications from tuples that should be treated as pytrees. """ + __match_args__ = ("_partitions",) # A sentinel value representing a dim is unconstrained. UNCONSTRAINED = _UNCONSTRAINED_PARTITION - def __init__(self, *partitions): - pass - - def __new__(cls, *partitions): - partitions = tuple(_canonicalize_partition(p) for p in partitions) - return tuple.__new__(PartitionSpec, partitions) + @use_cpp_method() + def __init__(self, *partitions, unreduced=frozenset(), reduced=frozenset()): + self._partitions = tuple(_canonicalize_partition(p) for p in partitions) + if not isinstance(unreduced, (set, frozenset)): + raise TypeError( + "`unreduced` argument of PartitionSpec should be of type" + f" `frozenset` or `set`. Got type {type(unreduced)}") + if not isinstance(reduced, (set, frozenset)): + raise TypeError( + "`reduced` argument of PartitionSpec should be of type" + f" `frozenset` or `set`. Got type {type(reduced)}") + self.unreduced = frozenset(unreduced) + self.reduced = frozenset(reduced) + # `__init__` is implemented in C++ so this check happens in C++ + # _check(self._partitions, self.unreduced, self.reduced) def __repr__(self): - return f"PartitionSpec{tuple.__repr__(self)}" + pr = repr(self._partitions)[1:-1] + if not self.unreduced and not self.reduced: + return f"PartitionSpec({pr})" + ur_str = _get_ur_str(self.unreduced, self.reduced) + pr = '' if not pr else f"{pr} " if pr.endswith(',') else f"{pr}, " + return (f"PartitionSpec({pr}{ur_str})") def __reduce__(self): - return (PartitionSpec, tuple(self)) + return (unpickle_pspec, (self._partitions, self.unreduced, self.reduced)) + + def __getitem__(self, i): + return self._partitions[i] + + def __iter__(self): + return iter(self._partitions) + def __len__(self): + return len(self._partitions) + + @use_cpp_method() def __eq__(self, other): - if not isinstance(other, tuple): + if isinstance(other, PartitionSpec): + return (self._partitions == other._partitions and + self.unreduced == other.unreduced and + self.reduced == other.reduced) + elif isinstance(other, tuple): + if self.unreduced: + raise TypeError( + f"other {other} cannot be of instance `tuple` when self {self} has" + " unreduced in `__eq__` of PartitionSpec.") + if self.reduced: + raise TypeError( + f"other {other} cannot be of instance `tuple` when self {self} has" + " reduced in `__eq__` of PartitionSpec.") + other_p = tuple(_canonicalize_partition(o) for o in other) + return self._partitions == other_p + else: return False - other = tuple(_canonicalize_partition(o) for o in other) - return super().__eq__(other) + @use_cpp_method() def __hash__(self): - return super().__hash__() + return hash((self._partitions, self.unreduced, self.reduced)) + + def __add__(self, other): + if isinstance(other, PartitionSpec): + return PartitionSpec( + *self, *other, + unreduced={*self.unreduced, *other.unreduced}, + reduced={*self.reduced, *other.reduced}) + elif isinstance(other, tuple): + if self.unreduced: + raise TypeError( + f"other {other} cannot be of instance `tuple` when self {self} has" + " unreduced in `__add__` of PartitionSpec.") + if self.reduced: + raise TypeError( + f"other {other} cannot be of instance `tuple` when self {self} has" + " reduced in `__add__` of PartitionSpec.") + return PartitionSpec(*self, *other) + else: + raise NotImplementedError + + def __radd__(self, other): + if 
not isinstance(other, tuple): + raise NotImplementedError + # other will always be a tuple. + if self.unreduced: + raise TypeError( + f"other {other} cannot be of instance `tuple` when self {self} has" + " unreduced in `__radd__` of PartitionSpec.") + if self.reduced: + raise TypeError( + f"other {other} cannot be of instance `tuple` when self {self} has" + " reduced in `__radd__` of PartitionSpec.") + return PartitionSpec(*other, *self) def index(self, value): - value = _canonicalize_partition(value) - return super().index(value) + return self._partitions.index(_canonicalize_partition(value)) + + def count(self, value): + return self._partitions.count(_canonicalize_partition(value)) + + def update(self, **kwargs): + return PartitionSpec(*kwargs.pop("partitions", self._partitions), + unreduced=kwargs.pop("unreduced", self.unreduced), + reduced=kwargs.pop("reduced", self.reduced)) def _normalized_spec_for_aval(self, ndim: int) -> PartitionSpec: - out = [None if p is _UNCONSTRAINED_PARTITION else p for p in self] + out = [None if p is _UNCONSTRAINED_PARTITION else p + for p in self._partitions] if len(out) < ndim: out.extend([None] * (ndim - len(out))) - return PartitionSpec(*out) + return self.update(partitions=out) + +PartitionSpec.__module__ = 'jax.sharding' diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index f7a4361ffee2..1aa6a4d73daf 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -20,8 +20,8 @@ import dataclasses from functools import partial import inspect +import itertools as it import logging -import operator as op import weakref from typing import NamedTuple, Any, Union, cast import warnings @@ -68,18 +68,20 @@ NamedSharding, GSPMDSharding, SingleDeviceSharding, PmapSharding, AUTO, UNSPECIFIED, UnspecifiedValue, prepare_axis_resources, parse_flatten_op_sharding, canonicalize_sharding, - flatten_spec) -from jax._src.layout import Layout, DeviceLocalLayout, AutoLayout -from jax._src.state import discharge as state_discharge, RefEffect, AbstractRef + flatten_spec, _internal_use_concrete_mesh) +from jax._src.layout import Format, DeviceLocalLayout, AutoLayout +from jax._src.state.types import RefEffect from jax._src.traceback_util import api_boundary from jax._src.tree_util import ( tree_flatten, tree_unflatten, treedef_is_leaf, tree_structure, tree_leaves, treedef_children, broadcast_prefix, all_leaves, prefix_errors, keystr, - PyTreeDef, none_leaf_registry as none_lr, tree_map) + PyTreeDef, none_leaf_registry as none_lr, tree_map, tree_flatten_with_path) from jax._src.util import ( - HashableFunction, safe_map, safe_zip, wraps, tuple_insert, - distributed_debug_log, split_list, weakref_lru_cache, + HashableFunction, safe_map, safe_zip, wraps, + distributed_debug_log, split_list, split_list_checked, weakref_lru_cache, merge_lists, subs_list, fun_name, fun_qual_name) +from jax._src.attrs import (Box, List, dne_sentinel, jax_setattr, jax_getattr, + jax_extendattr) map, unsafe_map = safe_map, map zip, unsafe_zip = safe_zip, zip @@ -94,48 +96,6 @@ logger = logging.getLogger(__name__) -def _find_arg_mismatch(arg_list, fails, fun_name): - mismatched_args_msg = [] - def mismatch(err): - for name, inp_da, aval in arg_list: - if err.m_type == pxla.MismatchType.ARG_SHARDING and err.da == inp_da: - mismatched_args_msg.append( - f"argument {name} of {fun_name} with shape {aval.str_short()} and " - f"{err._dev_ids_plat_str}") - break - first_err, second_err = fails - mismatch(first_err) - mismatch(second_err) - return mismatched_args_msg - - -def 
_device_assignment_mismatch_error(fun_name, fails, args_flat, api_name, - arg_names): - arg_list = [] - if arg_names is None: - arg_names = [''] * len(args_flat) - for a, n in zip(args_flat, arg_names): - da = (a.sharding._device_assignment - if getattr(a, 'sharding', None) is not None else None) - arg_list.append((n, da, core.shaped_abstractify(a))) - - mismatched_args_msg = _find_arg_mismatch(arg_list, fails, fun_name) - - if len(mismatched_args_msg) == 2: - first, second = mismatched_args_msg # pytype: disable=bad-unpacking - extra_msg = f" Got {first} and {second}" - elif len(mismatched_args_msg) == 1: - first, second = fails - # Choose the failure left which is not already covered by ARG_SHARDING. - left = second if first.m_type == pxla.MismatchType.ARG_SHARDING else first - extra_msg = f" Got {mismatched_args_msg[0]} and{left._str(api_name)}" - else: - first, second = fails - extra_msg = f" Got{first._str(api_name)} and{second._str(api_name)}" - msg = (f"Received incompatible devices for {api_name}ted computation.{extra_msg}") - return msg - - class PjitInfo(NamedTuple): """Things that we know about a jit instance before it is called. @@ -186,9 +146,8 @@ def _python_pjit_helper(fun: Callable, jit_info: PjitInfo, *args, **kwargs): args_flat = [*init_states, *args_flat] try: - if (core.trace_state_clean() and - not config.debug_key_reuse.value and - not config.data_dependent_tracing_fallback.value): + if (core.trace_state_clean() and not config.debug_key_reuse.value + and not p.params['jaxpr'].jaxpr.is_high): args_flat = map(core.full_lower, args_flat) core.check_eval_args(args_flat) out_flat, compiled, profiler = _pjit_call_impl_python(*args_flat, **p.params) @@ -196,10 +155,10 @@ def _python_pjit_helper(fun: Callable, jit_info: PjitInfo, *args, **kwargs): out_flat = pjit_p.bind(*args_flat, **p.params) compiled = None profiler = None - except pxla.DeviceAssignmentMismatchError as e: + except stages.DeviceAssignmentMismatchError as e: fails, = e.args fun_name = getattr(fun, '__qualname__', getattr(fun, '__name__', str(fun))) - msg = _device_assignment_mismatch_error( + msg = stages._device_assignment_mismatch_error( fun_name, fails, args_flat, 'jit', p.arg_names) raise ValueError(msg) from None except xla.InvalidInputException as e: @@ -217,44 +176,43 @@ def _python_pjit_helper(fun: Callable, jit_info: PjitInfo, *args, **kwargs): f"Argument '{name}' of shape {aval.str_short()} of type" f' {type(arg)} is not a valid JAX type.') from e raise AssertionError("Unreachable") from e - except dispatch.InternalFloatingPointError as e: + except api_util.InternalFloatingPointError as e: if getattr(fun, '_apply_primitive', False): raise FloatingPointError(f"invalid value ({e.ty}) encountered in {fun.__qualname__}") from None - dispatch.maybe_recursive_nan_check(e, fun, args, kwargs) + api_util.maybe_recursive_nan_check(e, fun, args, kwargs) + + if p.box_data: + box_treedef, out_tree = p.out_tree.children() + box_flat, out_flat = split_list_checked(out_flat, [box_treedef.num_leaves, out_tree.num_leaves]) + box_out = tree_unflatten(box_treedef, box_flat) + leaves = tree_leaves((args, kwargs)) + for (i, kind), b in zip(p.box_data, box_out): + if kind is pe.BoxAttr: + leaves[i].set(tree_unflatten(b.treedef, b.leaves)) + elif kind is pe.ListAttr: + for item in tree_unflatten(b.treedef, b.leaves): + leaves[i].append(item) + else: + assert False + else: + out_tree = p.out_tree if p.attrs_tracked: num_states_out = sum(end_tree.num_leaves for _, end_tree, _ in p.attrs_tracked) final_states, out_flat = 
split_list(out_flat, [num_states_out]) _set_states(p.attrs_tracked, final_states) - outs = tree_unflatten(p.out_tree, out_flat) - return (outs, out_flat, p.out_tree, args_flat, p.params['jaxpr'], - p.attrs_tracked, compiled, profiler) + outs = tree_unflatten(out_tree, out_flat) + return (outs, out_flat, out_tree, args_flat, p.params['jaxpr'], + p.attrs_tracked, p.box_data, compiled, profiler) -def _set_states(attrs_tracked, vals): - from jax.experimental.attrs import jax_setattr - valss = split_list(vals, [td.num_leaves for _, td, _ in attrs_tracked[:-1]]) - for ((_, treedef, (obj, attr)), leaves) in zip(attrs_tracked, valss): - val = tree_unflatten(treedef, leaves) - jax_setattr(obj, attr, val) - -def _get_states(attrs_tracked): - from jax.experimental.attrs import jax_getattr - vals = [] - for treedef, _, (obj, attr) in attrs_tracked: - tree = jax_getattr(obj, attr) - leaves, treedef_ = tree_flatten(tree) - assert treedef == treedef_ - vals.extend(leaves) - return vals - def _need_to_rebuild_with_fdo(pgle_profiler): return (pgle_profiler is not None and pgle_profiler.is_enabled() and not pgle_profiler.is_fdo_consumed()) def _get_fastpath_data( - executable, out_tree, args_flat, out_flat, attrs_tracked, effects, + executable, out_tree, args_flat, out_flat, attrs_tracked, box_data, effects, consts, abstracted_axes, pgle_profiler ) -> pxla.MeshExecutableFastpathData | None: out_reflattened, out_tree = pxla.reflatten_outputs_for_dispatch(out_tree, out_flat) @@ -271,6 +229,7 @@ def _get_fastpath_data( and abstracted_axes is None # no attr state effects and not attrs_tracked + and not box_data # no ref state effects and not any(isinstance(e, RefEffect) for e in effects) # no prng reuse checking @@ -300,12 +259,6 @@ def _get_fastpath_data( return fastpath_data -def _cpp_pjit_evict_fn(self): - self._clear_cache() - _create_pjit_jaxpr.evict_function(self._fun) # pytype: disable=attribute-error - _infer_params_cached.cache_clear() - - # The entries are doubled here from the default 4096 because _pjit_call_impl # also has a cpp dispatch path and that would double the number of entries in # the global shared cache. 
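Note, stepping back to the `PartitionSpec` rework in jax/_src/partition_spec.py above: a usage sketch of the keyword-only `unreduced`/`reduced` sets, assuming no validation beyond what the Python reference implementation in this diff shows.

```python
from jax.sharding import PartitionSpec as P

spec = P("x", None, unreduced={"y"})
assert spec[0] == "x" and len(spec) == 2       # still indexable and sized
assert spec != P("x", None)                    # unreduced participates in __eq__

combined = P("x") + P(None, reduced={"z"})     # concatenates partitions, merges sets
assert tuple(combined) == ("x", None)
assert combined.reduced == frozenset({"z"})

# Comparison against a plain tuple is only allowed when both sets are empty.
assert P("x", None) == ("x", None)
```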
@@ -335,13 +288,12 @@ def cache_miss(*args, **kwargs): raise RuntimeError(f"re-tracing function {jit_info.fun_sourceinfo} for " "`jit`, but 'no_tracing' is set") - (outs, out_flat, out_tree, args_flat, jaxpr, attrs_tracked, executable, - pgle_profiler) = _python_pjit_helper(fun, jit_info, *args, **kwargs) + (outs, out_flat, out_tree, args_flat, jaxpr, attrs_tracked, box_data, + executable, pgle_profiler) = _python_pjit_helper(fun, jit_info, *args, **kwargs) maybe_fastpath_data = _get_fastpath_data( - executable, out_tree, args_flat, out_flat, attrs_tracked, jaxpr.effects, - jaxpr.consts, jit_info.abstracted_axes, - pgle_profiler) + executable, out_tree, args_flat, out_flat, attrs_tracked, box_data, + jaxpr.effects, jaxpr.consts, jit_info.abstracted_axes, pgle_profiler) return outs, maybe_fastpath_data, _need_to_rebuild_with_fdo(pgle_profiler) @@ -366,22 +318,69 @@ def cache_miss(*args, **kwargs): cpp_pjitted_f = wraps(fun)(cpp_pjit_f) cpp_pjitted_f._fun = fun - type(cpp_pjitted_f).clear_cache = _cpp_pjit_evict_fn + cpp_pjitted_f._jit_info = jit_info + # TODO(necula): move these to top-level; we don't need to do this for + # every jit + cpp_jitted_f_class = type(cpp_pjitted_f) + # TODO(necula): make clear_cache private, no need to have it part of the API + cpp_jitted_f_class.clear_cache = jit_evict_fn + cpp_jitted_f_class.lower = jit_lower + cpp_jitted_f_class.trace = jit_trace + cpp_jitted_f_class.eval_shape = jit_eval_shape + # We return directly the function produced by _xla.pjit, because we do not + # want to have Python in the dispatch path. return cpp_pjitted_f +@api_boundary +def jit_trace(jit_func, *args, **kwargs) -> stages.Traced: + p, args_flat = _infer_params(jit_func._fun, jit_func._jit_info, args, kwargs) + donate_argnums = tuple(i for i, d in enumerate(p.donated_invars) if d) + args_info = stages.make_args_info(p.in_tree, p.in_avals, donate_argnums) + lower_callable = partial(_resolve_and_lower, args_flat, **p.params, + pgle_profiler=None) + return stages.Traced( + p.params['jaxpr'], args_info, p.params["name"], p.out_tree, + lower_callable, args_flat, p.arg_names, p.num_consts) + + +@api_boundary +def jit_lower(jit_func, *args, **kwargs): + return jit_trace(jit_func, *args, **kwargs).lower() + +@api_boundary +def jit_eval_shape(jit_func, *args, **kwargs): + p, _ = _infer_params(jit_func._fun, jit_func._jit_info, args, kwargs) + out_shardings = [None if isinstance(s, UnspecifiedValue) else s + for s in p.params['out_shardings']] + out = [] + for a, out_s in zip(p.params['jaxpr'].out_avals, out_shardings): + if out_s is None: + s = a.sharding if a.sharding.mesh._are_all_axes_explicit else out_s + else: + s = out_s + # TODO(yashkatariya): Add `Layout` to SDS. + out.append(api.ShapeDtypeStruct(a.shape, a.dtype, sharding=s, + weak_type=a.weak_type)) + return tree_unflatten(p.out_tree, out) + +def jit_evict_fn(self): + self._clear_cache() + _create_pjit_jaxpr.evict_function(self._fun) # pytype: disable=attribute-error + _infer_params_cached.cache_clear() + def _split_layout_and_sharding(entries): entries_flat, treedef = tree_flatten(entries, is_leaf=lambda x: x is None) layouts, shardings = [], [] for e in entries_flat: - if isinstance(e, Layout): + if isinstance(e, Format): layouts.append(e.device_local_layout) shardings.append(e.sharding) elif isinstance(e, (DeviceLocalLayout, AutoLayout)): raise ValueError( '`jax.jit` does not accept device-local layouts directly. 
Create ' - 'a `Layout` instance wrapping this device-local layout and pass ' + 'a `Format` instance wrapping this device-local layout and pass ' f'that to `jit` instead. Got {e}') else: layouts.append(None) @@ -391,14 +390,16 @@ def _split_layout_and_sharding(entries): return tree_unflatten(treedef, layouts), tree_unflatten(treedef, shardings) -def _parse_jit_arguments(fun: Callable, in_shardings: Any, out_shardings: Any, - donate_argnums: int | Sequence[int] | None, - donate_argnames: str | Iterable[str] | None, +def _parse_jit_arguments(fun: Callable, *, in_shardings: Any, + out_shardings: Any, static_argnums: int | Sequence[int] | None, static_argnames: str | Iterable[str] | None, - device: xc.Device | None, backend: str | None, - abstracted_axes: Any | None, keep_unused: bool, - inline: bool, compiler_options: dict[str, Any] | None, + donate_argnums: int | Sequence[int] | None, + donate_argnames: str | Iterable[str] | None, + keep_unused: bool, device: xc.Device | None, + backend: str | None, inline: bool, + abstracted_axes: Any | None, + compiler_options: dict[str, Any] | None, use_resource_env: bool) -> PjitInfo: """Parses the arguments to jit/pjit. @@ -413,8 +414,10 @@ def _parse_jit_arguments(fun: Callable, in_shardings: Any, out_shardings: Any, if backend is not None or device is not None: warnings.warn( 'backend and device argument on jit is deprecated. You can use' - ' `jax.device_put(..., jax.local_devices("cpu")[0])` on the inputs to' - ' the jitted function to get the same behavior.', DeprecationWarning) + ' `jax.device_put(..., jax.local_devices(backend="cpu")[0])` on the' + ' inputs to the jitted function to get the same behavior.', + DeprecationWarning, + ) if device is not None and backend is not None: raise ValueError("can't specify both a device and a backend for jit, " f"got {device=} and {backend=}") @@ -437,7 +440,8 @@ def _parse_jit_arguments(fun: Callable, in_shardings: Any, out_shardings: Any, out_layouts, out_shardings = _split_layout_and_sharding(out_shardings) in_shardings = prepare_axis_resources(in_shardings, 'in_shardings') - out_shardings = prepare_axis_resources(out_shardings, 'out_shardings') + out_shardings = prepare_axis_resources(out_shardings, 'out_shardings', + allow_unconstrained_dims=True) user_specified_in_shardings = (in_shardings is not None and not isinstance(in_shardings, UnspecifiedValue)) @@ -476,56 +480,30 @@ def _parse_jit_arguments(fun: Callable, in_shardings: Any, out_shardings: Any, use_resource_env=use_resource_env, compiler_options_kvs=compiler_options_kvs) - -def _make_jit_wrapper(fun: Callable, jit_info: PjitInfo): - - @api_boundary - def lower(*args, **kwargs): - return trace(*args, **kwargs).lower() - - @api_boundary - def eval_shape(*args, **kwargs): - p, _ = _infer_params(fun, jit_info, args, kwargs) - out_s = [None if isinstance(s, UnspecifiedValue) else s for s in p.params['out_shardings']] - # TODO(yashkatariya): Add `Layout` to SDS. 
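Note: the per-instance `lower`/`eval_shape`/`trace` wrappers removed here are superseded by the `jit_trace`/`jit_lower`/`jit_eval_shape` methods installed on the C++ jit wrapper class above, so the user-facing surface stays the same. A minimal usage sketch:

```python
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
  return jnp.sin(x) * 2.0

x = jnp.arange(4.0)
print(f.eval_shape(x))          # ShapeDtypeStruct((4,), float32); no compilation
traced = f.trace(x)             # stages.Traced: jaxpr plus argument info
lowered = traced.lower()        # equivalent to f.lower(x)
print(lowered.as_text()[:200])  # StableHLO text
```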
- out = [api.ShapeDtypeStruct(x.shape, x.dtype, sharding=s, - weak_type=x.weak_type) - for x, s in zip(p.params['jaxpr'].out_avals, out_s)] - return tree_unflatten(p.out_tree, out) - - @api_boundary - def trace(*args, **kwargs) -> stages.Traced: - p, args_flat = _infer_params(fun, jit_info, args, kwargs) - donate_argnums = tuple(i for i, d in enumerate(p.donated_invars) if d) - args_info = stages.make_args_info(p.in_tree, p.in_avals, donate_argnums) - lower_callable = partial(_resolve_and_lower, args_flat, **p.params, - pgle_profiler=None) - return stages.Traced( - p.params['jaxpr'], args_info, p.params["name"], p.out_tree, - lower_callable, args_flat, p.arg_names, p.num_consts) - - wrapped = _cpp_pjit(fun, jit_info) - wrapped.lower = lower - wrapped.eval_shape = eval_shape - wrapped.trace = trace - return wrapped - - -def make_jit(fun: Callable, in_shardings: Any, out_shardings: Any, - donate_argnums: int | Sequence[int] | None, - donate_argnames: str | Iterable[str] | None, +def make_jit(fun: Callable, + *, + in_shardings: Any, + out_shardings: Any, static_argnums: int | Sequence[int] | None, static_argnames: str | Iterable[str] | None, - device: xc.Device | None, backend: str | None, - abstracted_axes: Any | None, keep_unused: bool, - inline: bool, compiler_options: dict[str, Any] | None, + donate_argnums: int | Sequence[int] | None, + donate_argnames: str | Iterable[str] | None, + keep_unused: bool, + device: xc.Device | None, + backend: str | None, + inline: bool, + abstracted_axes: Any | None, + compiler_options: dict[str, Any] | None, use_resource_env: bool) -> Any: """jit() and pjit() are thin wrappers around this function.""" jit_info = _parse_jit_arguments( - fun, in_shardings, out_shardings, donate_argnums, donate_argnames, - static_argnums, static_argnames, device, backend, abstracted_axes, - keep_unused, inline, compiler_options, use_resource_env) - return _make_jit_wrapper(fun, jit_info) + fun, in_shardings=in_shardings, out_shardings=out_shardings, + static_argnums=static_argnums, static_argnames=static_argnames, + donate_argnums=donate_argnums, donate_argnames=donate_argnames, + keep_unused=keep_unused, device=device, backend=backend, inline=inline, + abstracted_axes=abstracted_axes, compiler_options=compiler_options, + use_resource_env=use_resource_env) + return _cpp_pjit(fun, jit_info) class PjitParams(NamedTuple): @@ -537,7 +515,8 @@ class PjitParams(NamedTuple): donated_invars: tuple[bool, ...] arg_names: tuple[str, ...] 
num_consts: int - attrs_tracked: list[tuple[PyTreeDef, PyTreeDef, tuple[Any, str]]] + attrs_tracked: list[tuple[PyTreeDef, PyTreeDef, tuple[Any, str, Any]]] + box_data: list def _infer_params_impl( @@ -566,8 +545,13 @@ def _infer_params_impl( f = lu.wrap_init(fun, debug_info=dbg) f, dyn_args = argnums_partial_except(f, ji.static_argnums, args, allow_invalid=True) del args - f, dyn_kwargs = argnames_partial_except(f, ji.static_argnames, kwargs) + del kwargs + + dyn_args, dyn_kwargs, box_data = _flatten_boxes(dbg, dyn_args, dyn_kwargs) + if box_data: + f = _handle_boxes(f, dbg) + explicit_args, in_tree = tree_flatten((dyn_args, dyn_kwargs)) flat_fun, out_tree = flatten_fun(f, in_tree) flat_fun, explicit_args = hoist_obj_attrs(flat_fun, explicit_args) @@ -604,8 +588,12 @@ def _infer_params_impl( assert in_avals is None in_type = pe.infer_lambda_input_type(axes_specs, explicit_args) in_avals = tuple(a for a, e in in_type if e) + elif box_data: + in_type = in_avals = tuple(core.shaped_abstractify(x) for x in explicit_args) # type: ignore else: in_type = in_avals # type: ignore + in_type = tuple(core.AvalQDD(a, core.cur_qdd(x)) if a.has_qdd # type: ignore + else a for a, x in zip(in_type, explicit_args)) assert in_avals is not None in_shardings_flat, in_layouts_flat = _process_in_axis_resources( @@ -613,14 +601,14 @@ def _infer_params_impl( ji.in_layouts_treedef, ji.in_layouts_leaves, in_avals, in_tree, flat_fun.debug_info, device_or_backend_set, have_kwargs) - attr_token = _attr_token(flat_fun, in_type) + attr_token = _attr_cache_index(flat_fun, in_type) jaxpr, consts, out_avals, attrs_tracked = _create_pjit_jaxpr( flat_fun, in_type, attr_token, IgnoreKey(ji.inline)) if config.mutable_array_checks.value: _check_no_aliased_closed_over_refs(dbg, (*jaxpr.consts, *consts), explicit_args) - _attr_update(flat_fun, in_type, attr_token, attrs_tracked) + _attr_cachedata_update(flat_fun, in_type, attr_token, attrs_tracked) out_shardings_flat, out_layouts_flat = _check_and_canonicalize_out_shardings( out_shardings_treedef, out_shardings_leaves, ji.out_layouts_treedef, @@ -636,13 +624,14 @@ def _infer_params_impl( implicit_args = [] args_flat = [*implicit_args, *explicit_args] - num_states_in = sum(init_tree.num_leaves for init_tree, _, _ in attrs_tracked) - num_extra_args = len(implicit_args) + num_states_in + len(consts) + num_attrs_in = sum(init_tree.num_leaves for init_tree, _, (_, _, kind) + in attrs_tracked if kind in (pe.ReadWrite, pe.BoxAttr)) + num_extra_args = len(implicit_args) + num_attrs_in + len(consts) in_shardings_flat = (UNSPECIFIED,) * num_extra_args + in_shardings_flat in_layouts_flat = (None,) * num_extra_args + in_layouts_flat donated_invars = (False,) * num_extra_args + donated_invars assert (len(in_shardings_flat) == len(in_layouts_flat) == - len(donated_invars) == num_states_in + len(consts) + len(args_flat)) + len(donated_invars) == num_attrs_in + len(consts) + len(args_flat)) params = dict( jaxpr=jaxpr, @@ -659,7 +648,7 @@ def _infer_params_impl( ) return PjitParams(consts, params, in_avals, in_tree, out_tree(), donated_invars, dbg.arg_names, len(consts), - attrs_tracked), args_flat + attrs_tracked, box_data), args_flat class InferParamsCacheEntry: @@ -672,7 +661,7 @@ def __init__(self): # We use an outer cache that is keyed on the signature of the arguments, but # when populating a cache entry using _infer_params_impl, we need to provide -# actual arguments. In principle we could refactor _infer_params_impl to look +# actual arguments. 
In principle, we could refactor _infer_params_impl to look # only at an argument signature instead of args/kwargs in those cases that we # cache, but this was a more minimal change. @util.weakref_lru_cache @@ -689,8 +678,10 @@ def _infer_params_cached( def _infer_params( fun: Callable, ji: PjitInfo, args: tuple[Any, ...], kwargs: dict[str, Any] ) -> tuple[PjitParams, list[Any]]: - if ji.use_resource_env: - with sharding_impls.use_mesh(mesh_lib.thread_resources.env.physical_mesh): + if ji.use_resource_env: # pjit + phys_mesh = mesh_lib.thread_resources.env.physical_mesh + with (_internal_use_concrete_mesh(phys_mesh), + mesh_lib.use_abstract_mesh(phys_mesh.abstract_mesh)): return _infer_params_internal(fun, ji, args, kwargs) return _infer_params_internal(fun, ji, args, kwargs) @@ -703,7 +694,8 @@ def _infer_params_internal( static_argnames=ji.static_argnames, sourceinfo=ji.fun_sourceinfo, signature=ji.fun_signature) - if config.dynamic_shapes.value: # if dynamic shapes, don't use the cache + any_boxes = any(isinstance(x, (Box, List)) for x in tree_leaves((args, kwargs))) + if config.dynamic_shapes.value or any_boxes: # don't use the cache p, args_flat = _infer_params_impl(fun, ji, ctx_mesh, dbg, args, kwargs, in_avals=None) return p, p.consts + args_flat @@ -717,7 +709,7 @@ def _infer_params_internal( if entry.pjit_params is None: p, args_flat = _infer_params_impl( fun, ji, ctx_mesh, dbg, args, kwargs, in_avals=avals) - if p.attrs_tracked: # if attrs, don't popoulate the cache + if p.attrs_tracked or p.box_data or p.params['jaxpr'].jaxpr.is_high: return p, p.consts + args_flat entry.pjit_params = p return entry.pjit_params, entry.pjit_params.consts + dynargs @@ -729,16 +721,16 @@ def _infer_input_type(fun: Callable, dbg: core.DebugInfo, for i, x in enumerate(explicit_args): avals.append(core.shaped_abstractify(x)) except OverflowError: - arg_path = f"argument path is {dbg.arg_names[i]}" + arg_path = f"argument path is {dbg.arg_names[i]}" # pytype: disable=name-error raise OverflowError( "An overflow was encountered while parsing an argument to a jitted " f"computation, whose {arg_path}." ) from None except TypeError: - arg_description = f"path {dbg.arg_names[i]}" + arg_description = f"path {dbg.arg_names[i]}" # pytype: disable=name-error raise TypeError( f"Error interpreting argument to {fun} as an abstract array." - f" The problematic value is of type {type(x)} and was passed to" + f" The problematic value is of type {type(x)} and was passed to" # pytype: disable=name-error f" the function at {arg_description}.\n" "This typically means that a jit-wrapped function was called with a non-array" " argument, and this argument was not marked as static using the" @@ -942,7 +934,7 @@ def pjit( be donated. For more details on buffer donation see the - `FAQ `_. + `FAQ `_. donate_argnames: An optional string or collection of strings specifying which named arguments are donated to the computation. See the comment on ``donate_argnums`` for details. If not @@ -984,9 +976,12 @@ def pjit( [ 0.5 2. 4. 6. 8. 10. 12. 10. 
] """ return make_jit( - fun, in_shardings, out_shardings, donate_argnums, donate_argnames, - static_argnums, static_argnames, device, backend, abstracted_axes, - keep_unused, inline, compiler_options, use_resource_env=True) + fun, in_shardings=in_shardings, out_shardings=out_shardings, + static_argnums=static_argnums, static_argnames=static_argnames, + donate_argnums=donate_argnums, donate_argnames=donate_argnames, + keep_unused=keep_unused, device=device, backend=backend, inline=inline, + abstracted_axes=abstracted_axes, compiler_options=compiler_options, + use_resource_env=True) def hashable_pytree(pytree): @@ -1131,139 +1126,269 @@ def _process_in_axis_resources(in_shardings_treedef, in_shardings_leaves, debug_info.safe_arg_names(len(in_avals)), "jit arguments") # type: ignore[arg-type] return in_shardings_flat, in_layouts_flat -callsites: set[str] = set() +callsites_with_tracing_cache_miss: set[str] = set() + +def diff_tracing_cache_keys( + k: tuple, oldk: tuple, debug_info: lu.DebugInfo) -> tuple[Sequence[str], int]: + """Explanations of differences between the cache keys, along with diff sizes. + + Result: a pair of a list of explanations for differences, and the total size + of the differences. The sizes are used to pick the old key with the smallest + different size for the explanation that is shown to the user. + """ + (fun_transforms_k, fun_params_k, fun_in_type_k, + (arg_in_type_k, arg_attr_data_k, arg_inline_k), ctx_k) = k + (fun_transforms_ok, fun_params_ok, fun_in_type_ok, + (arg_in_type_ok, arg_attr_data_ok, arg_inline_ok), ctx_ok) = oldk + + diffs: list[tuple[str, int]] = [] # each difference with its size + def unavailable(key_field: str, what_k, what_ok): + diffs.append( + (f"different {key_field}:\n now: {what_k}\n != before: {what_ok}.\n" + "explanation unavailable! 
" + "please open an issue at https://github.com/jax-ml/jax.", + 10)) + + def list_diff_size(s1: Sequence, s2: Sequence) -> int: + min_len = min(len(s1), len(s2)) + diff_size = max(len(s1), len(s2)) - min_len + diff_size += sum(e1 != e2 for e1, e2 in zip(s1[:min_len], + s2[:min_len])) + return diff_size + + different_leaf_count = False + + def explain_transform_argnums_partial(param_k: tuple, param_ok: tuple): + dyn_argnums_k, static_args_k = param_k + dyn_argnums_ok, static_args_ok = param_ok + if dyn_argnums_k != dyn_argnums_ok: + diffs.append( + ("different static_argnums:\n" + f" dynamic argnums now {dyn_argnums_k} and before {dyn_argnums_ok}", + 1)) + if static_args_k != static_args_ok: + diffs.append( + ("different value of static args:\n" + f" now {', '.join(repr(a.val) for a in static_args_k)}" + f" and before {', '.join(repr(a.val) for a in static_args_ok)}", + list_diff_size(static_args_k, static_args_ok))) + + def explain_transform_argnames_partial(param_k: tuple, param_ok: tuple): + static_kwargs_k, = param_k + static_kwargs_ok, = param_ok + static_kwargs_k = [(k, v.val) for k, v in + sorted(static_kwargs_k.val.items())] + static_kwargs_ok = [(k, v.val) for k, v in + sorted(static_kwargs_ok.val.items())] + if static_kwargs_k != static_kwargs_ok: + diffs.append( + ("different value of static kwargs:\n" + f" now {{{', '.join(f'{k}: {repr(v)}' for k, v in static_kwargs_k)}}}" + f" and before {{{', '.join(f'{k}: {repr(v)}' for k, v in static_kwargs_ok)}}}", + list_diff_size(static_kwargs_k, static_kwargs_ok))) + + def explain_in_tree_diff(in_tree_k: PyTreeDef, in_tree_ok: PyTreeDef): + nonlocal different_leaf_count + different_leaf_count = (in_tree_k.num_leaves != in_tree_ok.num_leaves) + if not different_leaf_count: + # Look for the special case of passing positional args as kwargs or + # vice-versa; the common prefix of positional args match. 
+ args_tree_k, kwargs_tree_k = treedef_children(in_tree_k) + nr_args_k = len(treedef_children(args_tree_k)) + args_tree_ok, kwargs_tree_ok = treedef_children(in_tree_ok) + nr_args_ok = len(treedef_children(args_tree_k)) + if (treedef_children(args_tree_k)[:min(nr_args_k, nr_args_ok)] == + treedef_children(args_tree_ok)[:min(nr_args_k, nr_args_ok)]): + keys_k = kwargs_tree_k.node_data()[1] # type: ignore[index] + keys_ok = kwargs_tree_ok.node_data()[1] # type: ignore[index] + diffs.append( + (("different number of args and kwargs, but same total number.\n" + f" now {nr_args_k} args and kwargs " + f"with keys {keys_k}\n" + f" before {nr_args_ok} args and kwargs " + f"with keys {keys_ok}"), + abs(nr_args_ok - nr_args_k))) + return + + in_tree_k_str = str(in_tree_k) + in_tree_k_str = (in_tree_k_str if len(in_tree_k_str) < 73 + else in_tree_k_str[:73] + "...") + in_tree_ok_str = str(in_tree_ok) + in_tree_ok_str = (in_tree_ok_str if len(in_tree_ok_str) < 73 + else in_tree_ok_str[:73] + "...") + diff = [f"different input pytree:\n now: {in_tree_k_str}\n" + f" before: {in_tree_ok_str}"] + + errs = list(tree_util.equality_errors_pytreedef(in_tree_k, in_tree_ok)) + for path, thing1, thing2, explanation in errs: + fst, *path = path # type: ignore + base = ["args", "kwargs"][fst.idx] + diff.append( + f" * at {base}{keystr(tuple(path))}, now {thing1} and before {thing2}," + f" so {explanation}") + diffs.append(("\n".join(diff), len(errs))) + + def explain_args_type_diff(args_k: tuple[core.AbstractValue], + args_ok: tuple[core.AbstractValue]): + diff_size = 0 + arg_names = debug_info.safe_arg_names(len(args_k)) + def arg_type_to_str(at): + if hasattr(at, "str_short"): + return at.str_short(short_dtypes=True) + else: + return str(at) + args_k_str = ", ".join(f"{an}: {arg_type_to_str(at)}" + for an, at in zip(arg_names, args_k)) + args_k_str = args_k_str if len(args_k_str) < 73 else args_k_str[:73] + "..." 
+ diff = [f"different input types:\n types now: {args_k_str}"] + add_weak_type_hint = False + + for name, arg_t_k, arg_t_ok in zip(arg_names, args_k, args_ok): + if arg_t_k == arg_t_ok: continue + this_arg_diff_size = 0 + if type(arg_t_k) == type(arg_t_ok) == core.ShapedArray: + s1, s2 = arg_type_to_str(arg_t_k), arg_type_to_str(arg_t_ok) + this_arg_diff_size += list_diff_size(arg_t_k.shape, arg_t_ok.shape) # type: ignore + + if arg_t_k.weak_type != arg_t_ok.weak_type: # type: ignore + s1 += f"{{weak_type={arg_t_k.weak_type}}}" # type: ignore + s2 += f"{{weak_type={arg_t_ok.weak_type}}}" # type: ignore + add_weak_type_hint = True + this_arg_diff_size += 1 + elif arg_t_k.sharding != arg_t_ok.sharding: # type: ignore + s1 = arg_t_k.str_short(short_dtypes=True, mesh_axis_types=True) # type: ignore + s2 = arg_t_ok.str_short(short_dtypes=True, mesh_axis_types=True) # type: ignore + this_arg_diff_size += 1 + else: + s1, s2 = str(arg_t_k), str(arg_t_ok) + diff_size += max(1, this_arg_diff_size) + diff.append(f" * at {name}, now {s1} and before {s2}") + + if add_weak_type_hint: + diff.append( + "where weak_type=True often means a Python builtin numeric value, and \n" + "weak_type=False means a jax.Array.\n" + "See https://docs.jax.dev/en/latest/type_promotion.html#weak-types.") + diffs.append(("\n".join(diff), diff_size)) + + if fun_transforms_k != fun_transforms_ok: + if len(fun_transforms_k) != len(fun_transforms_ok): + different_leaf_count = True # Skip other more precise checks + unavailable("fun_transforms length", + fun_transforms_k, fun_transforms_ok) + else: + for i, (t, ot) in enumerate(zip(fun_transforms_k, fun_transforms_ok)): + t_name = t[0].__name__ + if t == ot: continue + + # TODO(mattjj): explain box cache misses + if t_name == '_handle_boxes': continue + + if t[0] != ot[0]: + unavailable(f"fun_transforms[{i}] transform", t, ot) + continue + if t_name == "flatten_fun": + explain_in_tree_diff(t[1][0], ot[1][0]) + continue + if t_name == "_argnums_partial": + explain_transform_argnums_partial(t[1], ot[1]) + continue + if t_name == "_argnames_partial": + explain_transform_argnames_partial(t[1], ot[1]) + continue + unavailable(f"fun_transforms.{t_name} params", t[1:], ot[1:]) + continue + + # If we had different leaf counts, we can discard the _argnums_partial + # difference. That transform sometimes occurs before the flatten_fun + if different_leaf_count: + diffs = [d for d in diffs if "fun_transforms._argnums_partial" not in d[0]] + if fun_params_k != fun_params_ok: + unavailable("fun_params", fun_params_k, fun_params_ok) + if fun_in_type_k != fun_in_type_ok: + unavailable("fun_in_type", fun_params_k, fun_params_ok) + if arg_in_type_k != arg_in_type_ok and not different_leaf_count: + explain_args_type_diff(arg_in_type_k, arg_in_type_ok) + if arg_attr_data_k != arg_attr_data_ok: + unavailable("arg_attr_data", arg_attr_data_k, arg_attr_data_ok) + if arg_inline_k != arg_inline_ok: + unavailable("arg_inline", arg_inline_k, arg_inline_ok) + if ctx_k != ctx_ok: + assert len(ctx_k) == len(ctx_ok) + idxs = [f" [{i}]: now {c_k} and before {c_ok}" + for i, (c_k, c_ok) in enumerate(zip(ctx_k, ctx_ok)) if c_k != c_ok] + diffs.append( + ("different tracing context, e.g. 
due to config or context manager.\n" + "found differences at positions\n" + + ", and\n".join(idxs) + + "\ncompare to tuple returned by " + "config.trace_context() in jax/_src/config.py.", + len(idxs))) + if not diffs: # Should never happen, but let's not crash + unavailable("something (unexpected empty diffs)", k, oldk) + diffs_and_sizes = util.unzip2(sorted(diffs, key=lambda d: d[1])) + return (diffs_and_sizes[0], sum(diffs_and_sizes[1])) + def explain_tracing_cache_miss( - fun: lu.WrappedFun, unseen_f: bool, cache: dict, key: tuple): + fun: lu.WrappedFun, unseen_f: bool, cache: dict, + key: tuple, elapsed_sec: float): if config.check_tracer_leaks.value: return - - def unpack(key): - transforms, (), _, (in_type, _, inline), *_, ctx = key - # TODO(dougalm,mattjj): enable cache miss explanation with attrs - _, (_, (in_tree,)), *_ = transforms - return in_tree, in_type, inline.val, ctx - in_tree, in_type, inline, ctx = unpack(key) - if inline: return + if key[3][2].val: return # No explanations for "inline" functions debug_info = fun.debug_info + func_filename = debug_info.func_filename + if func_filename and not source_info_util.is_user_filename(func_filename): + return + msg: list[str] = [] p = msg.append - done = lambda: logger.log(logging.WARNING, '\n'.join(msg)) + done = lambda: logger.log(logging.WARNING, "\n".join(msg)) callsite = source_info_util.summarize(source_info_util.current()) - p(f"TRACING CACHE MISS at {callsite} because:") + p(f"TRACING CACHE MISS at {callsite} costing {elapsed_sec * 1e3:.3f} ms because:") # have we seen this function before at all? - fun_name = getattr(fun.f, '__qualname__', fun.f) - if debug_info.func_src_info: - # TODO(necula): clean up the extraction of the source info - _, *rest = debug_info.func_src_info.split(' at ') - src_info = " defined at " + ' '.join(rest) - else: - src_info = '' - if unseen_f: - p(f" never seen function:\n {fun_name} id={id(fun.f)}{src_info}") - if callsite in callsites: + src_info = "" + if func_filename: + src_info += f" defined at {func_filename}" + if func_lineno := debug_info.func_lineno: + src_info += f":{func_lineno}" + func_name = debug_info.func_name + if unseen_f or not cache: + p(f" never seen function:\n {func_name} id={id(fun.f)}{src_info}") + if callsite in callsites_with_tracing_cache_miss: p(" but seen another function defined on the same line; maybe the function is\n" " being re-defined repeatedly, preventing caching?") - callsites.add(callsite) - return done() - else: - p(f" for {fun_name}{src_info}") - - seen_keys = map(unpack, cache.keys()) - - # have we maybe switched some args to be kwargs or visa-versa? 
- args_tree, kwargs_tree = treedef_children(in_tree) - args_kwargs_trees = [treedef_children(k) for k, *_ in seen_keys] - args_kwargs_match = [t for t in args_kwargs_trees - if t == [args_tree, kwargs_tree]] - if not args_kwargs_match: - num_args = len(treedef_children(args_tree)) - _, kwarg_keys = kwargs_tree.node_data() # type: ignore - p(f" never seen passing {num_args} positional args and {len(kwarg_keys)} " - "keyword args with keys:\n" - f" {', '.join(map(repr, kwarg_keys))}") - dont_match = [set(t[1].node_data()[1]) for t in args_kwargs_trees # type: ignore - if t != [args_tree, kwargs_tree]] - close_kwargs = min( - dont_match, key=set(kwarg_keys).symmetric_difference, default=None - ) - if not close_kwargs: - p(" closest seen is passing no keyword args") else: - p(f" closest seen passes {len(close_kwargs)} keyword args with keys:\n" - f" {', '.join(map(repr, close_kwargs))}") + callsites_with_tracing_cache_miss.add(callsite) return done() - # have we never seen this tracing context before? - ctxs_match = [c for *_, c in seen_keys if c == ctx] - if not ctxs_match: - p(" tracing context doesn't match, e.g. due to config or context manager") - dont_match = [c for *_, c in seen_keys if c != ctx] - closest_ctx = min(dont_match, key=lambda c: sum(map(op.ne, c, ctx))) - idxs = [i for i, (c1, c2) in enumerate(zip(ctx, closest_ctx)) if c1 != c2] - p(" closest seen context tuple differs at positions:\n" - f" {', '.join(map(str, idxs))}\n" - " compare to tuple returned by config._trace_context() in jax/_src/config.py.") - return done() + p(f" for {func_name}{src_info}") + + diffs = [diff_tracing_cache_keys(key, ok, debug_info) + for ok in cache.keys() if key != ok] + assert diffs, "we must find some diffs if key differs from all cache keys" + min_diff = min(diffs, key=lambda v: v[1]) + smallest_diffs: Sequence[Sequence[str]] # the diffs for the closest keys + smallest_diffs = [d[0] for d in diffs if d[1] == min_diff[1]] + def indent_subsequent_lines(indent: int, msg: str) -> str: + return msg.replace("\n", "\n" + " " * indent) + def p_one_diff(diff: Sequence[str]): + for d in diff: + p(" * key with " + indent_subsequent_lines(4, d)) + + if len(smallest_diffs) == 1: + p(" all previously seen cache keys are different. Closest previous key:") + p_one_diff(smallest_diffs[0]) + else: + p(" all previously seen cache keys are different. " + "Several previous keys are closest:") + for d in smallest_diffs: + p_one_diff(d) - # have we never seen this input pytree before? - trees_match = [k for k in seen_keys if k[0] == in_tree] - if not trees_match: - in_tree_str = f':\n {in_tree}' if len(str(in_tree)) < 76 else '' - p(f" never seen input pytree{in_tree_str}") - dont_match = [t for t, *_ in seen_keys if t != in_tree] - closest_tree = min(dont_match, key=lambda t: abs(t.num_leaves - in_tree.num_leaves)) - errs = list(tree_util.equality_errors_pytreedef(in_tree, closest_tree)) # type: ignore[arg-type] - p(f" closest seen input pytree has {len(errs)} mismatches, including:") - for path, thing1, thing2, explanation in errs: - fst, *path = path # type: ignore - base = ['args', 'kwargs'][fst.idx] - p(f" * at {base}{keystr(tuple(path))}, seen {thing2} but now given {thing1}," - f" so {explanation}") - return done() + done() - # have we never seen these input types (eg shapes, dtypes) before? 
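Note: the rewritten explanation machinery is reached the same way as the removed code around here, on a tracing cache miss with cache-miss explanations turned on (assuming the existing `jax_explain_cache_misses` flag still gates it). For example:

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_explain_cache_misses", True)

@jax.jit
def f(x):
  return x + 1

f(jnp.zeros((4,), jnp.float32))
f(jnp.zeros((8,), jnp.float32))  # logs a WARNING starting with "TRACING CACHE MISS",
                                 # now including the closest previously seen cache key
```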
- types_match = [k for k in trees_match if k[1] == in_type] - if not types_match: - if len(in_type) < 5: - in_type_str = ':\n {}'.format(', '.join( - f'{n}: {ty.str_short(short_dtypes=True)}' - for n, ty in zip(debug_info.arg_names, in_type))) - else: - in_type_str = '' - p(f" never seen input type signature{in_type_str}") - dont_match = [t for _, t, *_ in trees_match if t != in_type] - closest_ty = min(dont_match, key=lambda t: sum(map(op.ne, t, in_type))) - num_mismatch = sum(map(op.ne, closest_ty, in_type)) - p(f" closest seen input type signature has {num_mismatch} mismatches, including:") - add_weak_type_hint = False - arg_names = debug_info.safe_arg_names(len(in_type)) - - for name, ty1, ty2 in zip(arg_names, closest_ty, in_type): - if ty1 != ty2: - if type(ty1) == type(ty2) == core.ShapedArray: - s1, s2 = ty1.str_short(True), ty2.str_short(True) - if ty1.weak_type != ty2.weak_type: - s1 += f'{{weak_type={ty1.weak_type}}}' - s2 += f'{{weak_type={ty2.weak_type}}}' - add_weak_type_hint = True - elif ty1.sharding != ty2.sharding: - s1 = ty1.str_short(short_dtypes=True, mesh_axis_types=True) - s2 = ty2.str_short(short_dtypes=True, mesh_axis_types=True) - else: - s1, s2 = str(ty1), str(ty2) - p(f" * at {name}, seen {s1}, but now given {s2}") - if add_weak_type_hint: - p('where weak_type=True often means a Python builtin numeric value, and ') - p('weak_type=False means a jax.Array.') - p('See https://jax.readthedocs.io/en/latest/type_promotion.html#weak-types') - return done() - - # we think this is unreachable... - p("explanation unavailable! please open an issue at https://github.com/jax-ml/jax") - return done() @partial(lu.cache, explain=explain_tracing_cache_miss) def _create_pjit_jaxpr( @@ -1272,7 +1397,7 @@ def _create_pjit_jaxpr( attr_data: int, ignored_inline: IgnoreKey ) -> tuple[core.ClosedJaxpr, list[Any], list[core.AbstractValue], - list[tuple[PyTreeDef, PyTreeDef, tuple[Any, str]]]]: + list[tuple[PyTreeDef, PyTreeDef, tuple[Any, str, Any]]]]: util.test_event("create_pjit_jaxpr") del ignored_inline # just for explain_cache_miss if config.no_tracing.value: @@ -1286,16 +1411,16 @@ def _create_pjit_jaxpr( lu.annotate(fun, cast(core.InputType, in_type))) attrs_tracked = [] else: - jaxpr, global_out_avals, consts, attrs_tracked = pe.trace_to_jaxpr_dynamic( - fun, in_type) - # assert attr_data is sentinel or attr_data matches attrs_tracked + jaxpr, global_out_avals, consts, attrs_tracked = pe.trace_to_jaxpr_dynamic(fun, in_type) if config.debug_key_reuse.value: # Import here to avoid circular imports - from jax.experimental.key_reuse._core import check_key_reuse_jaxpr + from jax.experimental.key_reuse._core import check_key_reuse_jaxpr # pytype: disable=import-error check_key_reuse_jaxpr(jaxpr) - if any(isinstance(c, core.Tracer) for c in consts): + # TODO(mattjj,yashkatariya): if we take the 'true' path then we *must* fall + # off the C++ dispatch fast path for correctness. Ensure that happens. + if any(isinstance(c, core.Tracer) or core.typeof(c).has_qdd for c in consts): closed_jaxpr = pe.close_jaxpr(pe.convert_constvars_jaxpr(jaxpr)) final_consts = consts else: @@ -1348,32 +1473,29 @@ def seen_attrs_get( assert fun.in_type is None or fun.in_type == in_type return cache[(fun.transforms, fun.params, in_type)] -def _attr_token( +def _attr_cache_index( fun: lu.WrappedFun, in_type: core.InputType | tuple[core.AbstractValue, ...] 
) -> int: - from jax.experimental.attrs import jax_getattr cases = seen_attrs_get(fun, in_type) for i, records in enumerate(cases): - for obj, attr, treedef, avals in records: - val = jax_getattr(obj, attr) - vals, treedef_ = tree_flatten(val) - avals_ = map(core.shaped_abstractify, vals) - if treedef != treedef_ or avals != avals_: break + for obj, attr, kind, treedef, avals in records: + if kind in (pe.ReadWrite, pe.BoxAttr): + val = getattr(obj, attr, dne_sentinel) + vals, treedef_ = tree_flatten(val) + avals_ = map(core.shaped_abstractify, vals) + if treedef != treedef_ or avals != avals_: break else: return i return len(cases) -def _attr_update(fun, in_type, i, attrs_tracked): - from jax.experimental.attrs import jax_getattr - leaves = lambda obj, attr: tree_leaves(jax_getattr(obj, attr)) - records = [(obj, attr, init_tree, map(core.shaped_abstractify, leaves(obj, attr))) - for init_tree, _, (obj, attr) in attrs_tracked] +def _attr_cachedata_update(fun, in_type, i, attrs_tracked): + leaves = lambda obj, attr: tree_leaves(getattr(obj, attr, dne_sentinel)) + records = [(obj, attr, kind, init_tree, map(core.typeof, leaves(obj, attr))) + for init_tree, _, (obj, attr, kind) in attrs_tracked] cases = seen_attrs_get(fun, in_type) if i == len(cases): cases.append(records) - else: - assert i < len(cases) and cases[i] == records @dataclasses.dataclass(frozen=True) @@ -1425,9 +1547,8 @@ def check_aval_layout_compatibility( if l is None or isinstance(l, AutoLayout): continue name_str = f' with pytree key path {name}' if name else '' - shape = aval.shape try: - l.check_compatible_aval(shape) + l.check_compatible_aval(aval.shape) except ValueError as e: raise ValueError( f'One of {what_aval}{name_str} is incompatible with its layout ' @@ -1440,6 +1561,84 @@ def check_aval_layout_compatibility( pjit_p.multiple_results = True pjit_p.skip_canonicalization = True +def _is_high(jaxpr, **_) -> bool: + return jaxpr.jaxpr.is_high +pjit_p.is_high = _is_high # type: ignore + +def _to_lojax(*hi_args, jaxpr, **params): + # convert closed-over boxes to explicit args + jaxpr, closed_over_himutables = pe.convert_const_himutables(jaxpr) + hi_args = [*closed_over_himutables, *hi_args] + params = _converted_mutables_add_params(len(closed_over_himutables), **params) + + + # expand pjit params that must match number of lo inputs/outputs + lo_nums_in = [len(aval.lo_ty()) for aval in jaxpr.in_aval_qdds] + lo_nums_out = [len(t.lo_ty()) for t in jaxpr.out_avals] + lo_muts_out = sum(len(aval.lo_ty()) for aval in jaxpr.final_aval_qdds if aval.has_qdd) + params = _lojax_expand_params(lo_nums_in, lo_nums_out, lo_muts_out, **params) + + # collect lo input values + lo_args = [lo_val for aval, x in zip(jaxpr.in_aval_qdds, hi_args) + for lo_val in (aval.read_loval(x) if aval.has_qdd + else aval.lower_val(x))] + + # lower the jaxpr and bind it using lo input values + lo_jaxpr = pe.lower_jaxpr(jaxpr) + all_outs = pjit_p.bind(*lo_args, jaxpr=lo_jaxpr, **params) + out_mut, lo_outs = split_list(all_outs, [lo_muts_out]) + + # collect and apply mutations + out_mut_ = iter(out_mut) + in_idx = {v: i for i, v in enumerate(jaxpr.jaxpr.invars)} + for v in jaxpr.jaxpr.invars: + if v.final_qdd is not None: + qdd = v.final_qdd + lo_vals = it.islice(out_mut_, len(v.aval.lo_ty_qdd(qdd))) + v.aval.update_from_loval(qdd, hi_args[in_idx[v]], *lo_vals) + assert next(out_mut_, None) is None + + # collect output values into hi types + lo_outs_ = iter(lo_outs) + hi_outs = [t.raise_val(*it.islice(lo_outs_, len(t.lo_ty()))) + for t in jaxpr.out_avals] + 
assert next(lo_outs_, None) is None + + return hi_outs +pjit_p.to_lojax = _to_lojax + +def _converted_mutables_add_params( + n, *, donated_invars, in_shardings, in_layouts, **params): + donated_invars = (False,) * n + donated_invars + in_shardings = (UNSPECIFIED,) * n + in_shardings + in_layouts = (None,) * n + in_layouts + return dict(params, donated_invars=donated_invars, in_shardings=in_shardings, + in_layouts=in_layouts) + + +def _lojax_expand_params( + nums_in, nums_out, muts_out, *, donated_invars, in_shardings, in_layouts, + out_shardings, out_layouts, **params): + # some pjit params match the length of hi_jaxpr.invars/outvars, so when + # lowering we must expand them to match their number of lojax types + def expand(ns, xs): + return tuple(y for n, x in zip(ns, xs) for y in (x,) * n) + donated_invars = expand(nums_in , donated_invars) + in_shardings = expand(nums_in , in_shardings ) + in_layouts = expand(nums_in , in_layouts ) + out_shardings = expand(nums_out, out_shardings ) + out_layouts = expand(nums_out, out_layouts ) + + # also, the lo_jaxpr has pure outputs corresponding to mutable hi_jaxpr types + out_shardings = (UNSPECIFIED,) * muts_out + out_shardings + out_layouts = (None,) * muts_out + out_layouts + + new_params = dict(params, donated_invars=donated_invars, + in_shardings=in_shardings, in_layouts=in_layouts, + out_shardings=out_shardings, out_layouts=out_layouts) + return new_params + + def _resolve_in_layouts(args, jit_in_layouts, resolved_in_shardings, in_avals): # If device or backend is set, return the default layout. This is because you # can pass arrays on cpu (with untiled layouts) to jit with backend='tpu' @@ -1456,8 +1655,8 @@ def _resolve_in_layouts(args, jit_in_layouts, resolved_in_shardings, in_avals): # below. We cannot replace default layout with None to raise nicer errors. # `dispatch_arg_layout` replaces default layouts with `None` to simplify # dispatch and lowering logic downstream. - if hasattr(arg, 'layout'): - arg_layout = arg.layout.device_local_layout + if hasattr(arg, 'format'): + arg_layout = arg.format.device_local_layout dispatch_arg_layout = (None if pxla.is_default_layout(arg_layout, rs, aval) else arg_layout) else: @@ -1475,8 +1674,8 @@ def _resolve_in_layouts(args, jit_in_layouts, resolved_in_shardings, in_avals): resolved_in_layouts.append(None) else: # arg_layout can be None because some backends don't implement the - # required layout methods. Hence `arr.layout` can return - # `Layout(None, sharding)` + # required layout methods. Hence `arr.format` can return + # `Format(None, sharding)` if (committed and not is_pmap_sharding and arg_layout is not None @@ -1513,6 +1712,20 @@ def _resolve_out_layouts(out_layouts, out_shardings, out_avals): new_out_layouts.append(out_l) return tuple(new_out_layouts) +def finalize_arg_sharding(arg_s, committed): + if isinstance(arg_s, UnspecifiedValue): + return arg_s + else: + if committed: + # If the arg has a PmapSharding, then reshard it unconditionally. 
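# --- Editorial aside (illustrative sketch; not part of this patch) -----------
# _lojax_expand_params above repeats each per-hi-input/output pjit param once
# per lo value that the corresponding hi type lowers to.  The core expansion,
# shown on plain Python values (the counts and flags below are made up):

def expand(ns, xs):
  # ns[i] = how many lo values the i-th hi value lowers to
  return tuple(y for n, x in zip(ns, xs) for y in (x,) * n)

# Two hi inputs: the first lowers to 2 lo values, the second to 1, so its
# donation flag and sharding entry are duplicated accordingly.
assert expand((2, 1), (True, False)) == (True, True, False)
assert expand((2, 1), ("s0", "s1")) == ("s0", "s0", "s1")
# -----------------------------------------------------------------------------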
+ return UNSPECIFIED if isinstance(arg_s, PmapSharding) else arg_s + else: + assert isinstance(arg_s, Sharding) + if dispatch.is_single_device_sharding(arg_s): + return UNSPECIFIED + raise NotImplementedError('Having uncommitted Array sharded on ' + 'multiple devices is not supported.') + def _resolve_in_shardings(args, pjit_in_shardings: Sequence[PjitSharding] ) -> Sequence[PjitSharding]: # If True, means that device or backend is set by the user on pjit and it @@ -1535,7 +1748,7 @@ def _resolve_in_shardings(args, pjit_in_shardings: Sequence[PjitSharding] if isinstance(arg_s, PmapSharding): continue if getattr(a, '_committed', True): - committed_arg_shardings.append((arg_s, pxla.MismatchType.ARG_SHARDING, None)) + committed_arg_shardings.append((arg_s, stages.MismatchType.ARG_SHARDING, None)) resolved_in_shardings: list[PjitSharding] = [] for arg, pjit_in_s in zip(args, pjit_in_shardings): @@ -1547,22 +1760,7 @@ def _resolve_in_shardings(args, pjit_in_shardings: Sequence[PjitSharding] if isinstance(arg_s, NamedSharding) and arg_s.mesh.empty: arg_s, committed = UNSPECIFIED, False if isinstance(pjit_in_s, UnspecifiedValue): - if isinstance(arg_s, UnspecifiedValue): - resolved_in_shardings.append(arg_s) - else: - if committed: - # If the arg has a PmapSharding, then reshard it unconditionally. - if isinstance(arg_s, PmapSharding): - resolved_in_shardings.append(UNSPECIFIED) - else: - resolved_in_shardings.append(arg_s) - else: - assert isinstance(arg_s, Sharding) - if dispatch.is_single_device_sharding(arg_s): - resolved_in_shardings.append(UNSPECIFIED) - else: - raise NotImplementedError('Having uncommitted Array sharded on ' - 'multiple devices is not supported.') + resolved_in_shardings.append(finalize_arg_sharding(arg_s, committed)) else: if (isinstance(arg, np.ndarray) and not pjit_in_s.is_fully_replicated and # type: ignore[union-attr] @@ -1571,14 +1769,12 @@ def _resolve_in_shardings(args, pjit_in_shardings: Sequence[PjitSharding] 'Passing non-trivial shardings for numpy ' 'inputs is not allowed. To fix this error, either specify a ' 'replicated sharding explicitly or use ' - '`jax.experimental.multihost_utils.host_local_array_to_global_array(...)` ' + '`jax.make_array_from_process_local_data(...)` ' 'to convert your host local numpy inputs to a jax.Array which you ' - 'can pass to pjit. ' + 'can pass to jit. ' 'If the numpy input is the same on each process, then you can use ' '`jax.make_array_from_callback(...) to create a `jax.Array` which ' - 'you can pass to pjit. ' - 'Please see the jax.Array migration guide for more information ' - 'https://jax.readthedocs.io/en/latest/jax_array_migration.html#handling-of-host-local-inputs-to-pjit-like-batch-etc. ' + 'you can pass to jit. 
' f'Got arg shape: {arg.shape}, arg value: {arg}') if not isinstance(arg_s, UnspecifiedValue) and arg_s._is_concrete: # jax.jit does not allow resharding across different memory kinds even @@ -1706,8 +1902,8 @@ def call_impl_cache_miss(*args_, **kwargs_): ctx_mesh=ctx_mesh, name=name, keep_unused=keep_unused, inline=inline, compiler_options_kvs=compiler_options_kvs) fastpath_data = _get_fastpath_data( - compiled, tree_structure(out_flat), args, out_flat, [], jaxpr.effects, - jaxpr.consts, None, pgle_profiler) + compiled, tree_structure(out_flat), args, out_flat, [], [], + jaxpr.effects, jaxpr.consts, None, pgle_profiler) return out_flat, fastpath_data, _need_to_rebuild_with_fdo(pgle_profiler) f = _get_jaxpr_as_fun( @@ -1730,7 +1926,13 @@ def call_impl_cache_miss(*args_, **kwargs_): pjit_p.def_impl(_pjit_call_impl) -def _pjit_lower( +def _pjit_lower(*args, **kwargs): + util.test_event("pjit_lower") + return _pjit_lower_cached(*args, **kwargs) + +# This cache is important for python dispatch performance. +@weakref_lru_cache +def _pjit_lower_cached( jaxpr: core.ClosedJaxpr, in_shardings, out_shardings, @@ -1746,7 +1948,6 @@ def _pjit_lower( lowering_platforms: tuple[str, ...] | None, lowering_parameters: mlir.LoweringParameters, pgle_profiler: profiler.PGLEProfiler | None): - util.test_event("pjit_lower") return pxla.lower_sharding_computation( jaxpr, 'jit', name, in_shardings, out_shardings, in_layouts, out_layouts, tuple(donated_invars), @@ -1757,7 +1958,7 @@ def _pjit_lower( pgle_profiler=pgle_profiler) -def pjit_staging_rule(trace, *args, **params): +def pjit_staging_rule(trace, source_info, *args, **params): # If we're inlining, no need to compute forwarding information; the inlined # computation will in effect forward things. if (params["inline"] and @@ -1771,18 +1972,19 @@ def pjit_staging_rule(trace, *args, **params): # shapes are enabled, use eval_jaxpr, which uses the tracing machinery, # but redundantly performs abstract evaluation again. 
with core.set_current_trace(trace): - return core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *args, - propagate_source_info=False) + out = core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *args, + propagate_source_info=False) else: - return pe.inline_jaxpr_into_trace( - trace, jaxpr.jaxpr, jaxpr.consts, *args) + out = pe.inline_jaxpr_into_trace( + trace, source_info, jaxpr.jaxpr, jaxpr.consts, *args) + return [trace.to_jaxpr_tracer(x, source_info) for x in out] - jaxpr, in_fwd, out_shardings, out_layouts = _pjit_forwarding( - params['jaxpr'], params['out_shardings'], params['out_layouts']) - params = dict(params, jaxpr=jaxpr, out_shardings=out_shardings, - out_layouts=out_layouts) + jaxpr = params['jaxpr'] if config.dynamic_shapes.value: - source_info = source_info_util.current() + jaxpr, in_fwd, out_shardings, out_layouts = _pjit_forwarding( + jaxpr, params['out_shardings'], params['out_layouts']) + params = dict(params, jaxpr=jaxpr, out_shardings=out_shardings, + out_layouts=out_layouts) out_tracers = [] for aval in _out_type(jaxpr): if type(aval) is core.DShapedArray: @@ -1795,46 +1997,49 @@ def pjit_staging_rule(trace, *args, **params): map(trace.getvar, args), map(trace.makevar, out_tracers), pjit_p, params, jaxpr.effects, source_info) trace.frame.add_eqn(eqn) + out_tracers_ = iter(out_tracers) + out_tracers = [args[f] if type(f) is int else next(out_tracers_) + for f in in_fwd] + assert next(out_tracers_, None) is None elif any(isinstance(c, core.MutableArray) for c in jaxpr.consts): jaxpr, consts = pxla._move_mutable_consts(jaxpr) - consts = map(trace.new_const, consts) + consts = [trace.new_const(c, source_info) for c in consts] in_shardings = (*params['in_shardings'],) + (UNSPECIFIED,) * len(consts) in_layouts = (*params['in_layouts'],) + (None,) * len(consts) donated_invars = (*params['donated_invars'],) + (False,) * len(consts) new_params = dict(params, jaxpr=jaxpr, in_shardings=in_shardings, in_layouts=in_layouts, donated_invars=donated_invars) out_tracers = trace.default_process_primitive( - pjit_p, (*args, *consts), new_params) + pjit_p, (*args, *consts), new_params, source_info=source_info) else: - out_tracers = trace.default_process_primitive(pjit_p, args, params) + out_tracers = trace.default_process_primitive( + pjit_p, args, params, source_info=source_info) - out_tracers_ = iter(out_tracers) - out_tracers = [args[f] if type(f) is int else next(out_tracers_) - for f in in_fwd] - assert next(out_tracers_, None) is None return out_tracers pe.custom_staging_rules[pjit_p] = pjit_staging_rule def _pjit_forwarding(jaxpr, out_shardings, out_layouts): in_fwd: list[int | None] = pe._jaxpr_forwarding(jaxpr.jaxpr) - in_fwd = [fwd if isinstance(os, UnspecifiedValue) and ol is None else None for fwd, os, ol - in zip(in_fwd, out_shardings, out_layouts)] + in_fwd = [fwd if isinstance(os, UnspecifiedValue) and ol is None else None + for fwd, os, ol in zip(in_fwd, out_shardings, out_layouts)] keep = [f is None for f in in_fwd] jaxpr = pe.prune_closed_jaxpr_outputs(jaxpr, keep) - out_shardings = [o for o, k in zip(out_shardings, keep) if k] - out_layouts = [o for o, k in zip(out_layouts , keep) if k] + out_shardings = tuple(o for o, k in zip(out_shardings, keep) if k) + out_layouts = tuple(o for o, k in zip(out_layouts , keep) if k) return jaxpr, in_fwd, out_shardings, out_layouts def pjit_forwarding_rule(eqn): + if not config.dynamic_shapes.value: + return [None] * len(eqn.outvars), eqn jaxpr, in_fwd, out_shardings, out_layouts = _pjit_forwarding( eqn.params['jaxpr'], eqn.params['out_shardings'], 
eqn.params['out_layouts']) new_outvars = [v for v, f in zip(eqn.outvars, in_fwd) if f is None] - new_params = dict(eqn.params, jaxpr=jaxpr, out_shardings=(*out_shardings,), - out_layouts=(*out_layouts,)) + new_params = dict(eqn.params, jaxpr=jaxpr, out_shardings=out_shardings, + out_layouts=out_layouts) new_eqn = eqn.replace(params=new_params, outvars=new_outvars) - fwd_vars = [eqn.invars[f] if f is not None else None for f in in_fwd] - return fwd_vars, new_eqn + return in_fwd, new_eqn +# TODO(mattjj): Remove pjit_forwarding_rule and also in staging rule. pe.forwarding_rules[pjit_p] = pjit_forwarding_rule @@ -1880,8 +2085,8 @@ def _pjit_cached_lower_jaxpr_to_fun(ctx: mlir.LoweringRuleContext, elif isinstance(axis_ctx, sharding_impls.SPMDAxisContext): num_devices = axis_ctx.mesh.size key = (pjit_p, name, jaxpr, effects, num_devices, - pxla.SemanticallyEqualShardings(in_shardings, jaxpr.in_avals), - pxla.SemanticallyEqualShardings(out_shardings, jaxpr.out_avals), + pxla.SemanticallyEqualShardings(in_shardings, jaxpr.in_avals), # pytype: disable=wrong-arg-types + pxla.SemanticallyEqualShardings(out_shardings, jaxpr.out_avals), # pytype: disable=wrong-arg-types in_layouts, out_layouts, api_name) func = mod_ctx.cached_primitive_lowerings.get(key, None) @@ -1891,12 +2096,19 @@ def _pjit_cached_lower_jaxpr_to_fun(ctx: mlir.LoweringRuleContext, # TODO(b/228598865): inlined calls cannot have shardings set directly on the # inputs or outputs because they are lost during MLIR->HLO conversion. # using_sharding_annotation=False means we add an identity operation instead. + num_callbacks = len(mod_ctx.host_callbacks) func = mlir.lower_jaxpr_to_fun( mod_ctx, name, jaxpr, effects, ctx.name_stack, arg_shardings=arg_shardings, result_shardings=result_shardings, use_sharding_annotations=False, api_name=api_name, arg_layouts=in_layouts, result_layouts=out_layouts) - mod_ctx.cached_primitive_lowerings[key] = func + + # If this Jaxpr includes callbacks, we can't cache the lowering because + # on TPU every callback must have a globally unique channel, but the + # channel gets assigned during lowering. 
+ has_callbacks = len(mod_ctx.host_callbacks) > num_callbacks + if not has_callbacks or "tpu" not in mod_ctx.platforms: + mod_ctx.cached_primitive_lowerings[key] = func return func @@ -1976,12 +2188,6 @@ def _pjit_batcher(axis_data, vals_in, batching.fancy_primitive_batchers[pjit_p] = _pjit_batcher batching.ragged_prop_rules[pjit_p] = batching.ragged_mask_no_op_rule -def _insert_axis_partitions(spec, dim, val): - too_short = dim - len(spec) - if too_short > 0: - spec += (None,) * too_short - new_partitions = tuple_insert(spec, dim, val) - return PartitionSpec(*new_partitions) def _pjit_batcher_for_sharding( s: Sharding | UnspecifiedValue, @@ -1995,7 +2201,7 @@ def _pjit_batcher_for_sharding( return s if isinstance(s, NamedSharding) and isinstance(s.mesh, AbstractMesh): return NamedSharding( - s.mesh, _insert_axis_partitions(s.spec, dim, PartitionSpec.UNCONSTRAINED)) + s.mesh, pxla.batch_spec(s.spec, dim, PartitionSpec.UNCONSTRAINED)) new_op = hlo_s.to_proto().clone() tad = list(new_op.tile_assignment_dimensions) tad.insert(dim, 1) # type: ignore @@ -2007,7 +2213,7 @@ def _pjit_batcher_for_sharding( else: if isinstance(s, NamedSharding) and isinstance(s.mesh, AbstractMesh): return NamedSharding( - s.mesh, _insert_axis_partitions(s.spec, dim, spmd_axis_name)) + s.mesh, pxla.batch_spec(s.spec, dim, spmd_axis_name)) if isinstance(s, NamedSharding): mesh = s.mesh if mesh is None or mesh.empty: @@ -2020,7 +2226,7 @@ def _pjit_batcher_for_sharding( f' manager scope{s!r}') spec = parse_flatten_op_sharding(hlo_s, mesh)[0] return NamedSharding( - mesh, _insert_axis_partitions(spec, dim, spmd_axis_name)) + mesh, pxla.batch_spec(spec, dim, spmd_axis_name)) def _pjit_jvp(primals_in, tangents_in, @@ -2062,14 +2268,47 @@ def _pjit_linearization(nzs, *primals_in, jaxpr, donated_invars, ctx_mesh, name, keep_unused, inline, compiler_options_kvs): primal_jaxpr, num_residuals, nzs_out, tangent_jaxpr = ad.linearize_jaxpr(jaxpr, nzs) - # constvars will become residuals. Move them to the end of the ordinary args. res_shardings = (UNSPECIFIED,) * num_residuals res_layouts = (None,) * num_residuals res_donated = (False,) * num_residuals + primal_out_shardings = res_shardings + tuple(out_shardings) + primal_out_layouts = res_layouts + tuple(out_layouts) + + def keep_where(l, should_keep): + return tuple(x for x, keep in zip(l, should_keep) if keep) + + # Input-to-output forwarding. + in_fwd = pe._jaxpr_forwarding(primal_jaxpr.jaxpr) + in_fwd_res, in_fwd_primal = split_list(in_fwd, [num_residuals]) + in_fwd = in_fwd_res + [ + fwd if isinstance(os, UnspecifiedValue) and ol is None else None + for os, ol, fwd in zip(out_shardings, out_layouts, in_fwd_primal) + ] + del in_fwd_res, in_fwd_primal + keep = [f is None for f in in_fwd] + primal_jaxpr = pe.prune_closed_jaxpr_outputs(primal_jaxpr, keep) + primal_out_shardings = keep_where(primal_out_shardings, keep) + primal_out_layouts = keep_where(primal_out_layouts, keep) + kept_res, _ = split_list(keep, [num_residuals]) + num_kept_residuals = sum(kept_res) + del keep, kept_res + + # Output-to-output forwarding. 
+ num_out_primals = len(primal_jaxpr.jaxpr.outvars) - num_kept_residuals + res_vars, out_vars = split_list(primal_jaxpr.jaxpr.outvars, [num_kept_residuals]) + idx_map = {id(v): i for i, v in enumerate(out_vars)} + offset = sum(id(v) not in idx_map for v in res_vars) + idx_map = {k: v + offset for k, v in idx_map.items()} + out_fwd = [idx_map.get(id(v)) for v in res_vars] + [None] * num_out_primals + keep = [f is None for f in out_fwd] + primal_jaxpr = pe.prune_closed_jaxpr_outputs(primal_jaxpr, keep) + primal_out_shardings = keep_where(primal_out_shardings, keep) + primal_out_layouts = keep_where(primal_out_layouts, keep) + del keep + def tangent_fun(consts_, *tangents): tangents_nz = _filter_zeros(nzs, tangents) - assert len(consts_) == num_residuals - nz_tangents_out = pjit_p.bind(*(*tangents_nz, *consts_), + nz_tangents_out = pjit_p.bind(*tangents_nz, *consts_, jaxpr=tangent_jaxpr, in_shardings=_filter_zeros(nzs, in_shardings) + res_shardings, out_shardings=_filter_zeros(nzs_out, out_shardings), @@ -2092,15 +2331,17 @@ def _filter_zeros(is_nz_l, l): ans = pjit_p.bind(*primals_in, jaxpr=primal_jaxpr, in_shardings=in_shardings, - out_shardings=(*res_shardings, *out_shardings), + out_shardings=primal_out_shardings, in_layouts=in_layouts, - out_layouts=(*res_layouts, *out_layouts), + out_layouts=primal_out_layouts, donated_invars=donated_invars, ctx_mesh=ctx_mesh, name=name, keep_unused=keep_unused, inline=inline, compiler_options_kvs=compiler_options_kvs) + ans = subs_list(out_fwd, ans, ans) + ans = subs_list(in_fwd, primals_in, ans) residuals_ans, primal_ans = split_list(ans, [num_residuals]) return primal_ans, nzs_out, residuals_ans, tangent_fun @@ -2117,42 +2358,32 @@ def _pjit_partial_eval(trace: pe.JaxprTrace, known_ins = tuple(pv.is_known() for pv in in_pvals) unknown_ins = tuple(not k for k in known_ins) - if any(isinstance(e, (RefEffect, core.InternalMutableArrayEffect)) - for e in jaxpr.effects): - known_jaxpr_, unknown_jaxpr_, unknown_outs, _, num_res_val, num_res_ref = \ - pe.partial_eval_jaxpr_stateful(jaxpr.jaxpr, unknown_ins, unknown_ins, - False, False, None) - if num_res_ref: raise NotImplementedError - known_jaxpr = pe.ClosedJaxpr(known_jaxpr_, jaxpr.consts) - unknown_jaxpr = pe.ClosedJaxpr(unknown_jaxpr_, jaxpr.consts) - res_avals = unknown_jaxpr.in_avals[:num_res_val] - else: - known_jaxpr, unknown_jaxpr, unknown_outs, res_avals = \ - pe.partial_eval_jaxpr_nounits(jaxpr, unknown_ins, instantiate=False) + known_jaxpr, unknown_jaxpr, unknown_outs, res_out_avals, in_fwd_res = \ + pe.partial_eval_jaxpr_nounits_fwd(jaxpr, unknown_ins, instantiate=False) unknown_outs = tuple(unknown_outs) # type: ignore[assignment] known_outs = tuple(not uk for uk in unknown_outs) - num_residuals = len(res_avals) - res_shardings = (UNSPECIFIED,) * num_residuals - res_layouts = (None,) * num_residuals + # out_shardings and out_layouts for residual values output by known_jaxpr def keep_where(l, should_keep): return tuple(x for x, keep in zip(l, should_keep) if keep) - known_out_shardings = keep_where(out_shardings, known_outs) + res_shardings - known_out_layouts = keep_where(out_layouts, known_outs) + res_layouts + known_out_shardings = (keep_where(out_shardings, known_outs) + + (UNSPECIFIED,) * len(res_out_avals)) + known_out_layouts = (keep_where(out_layouts, known_outs) + + (None,) * len(res_out_avals)) # Input-to-output forwarding: compute which outputs are just forwarded inputs. 
- num_out_primals = len(known_jaxpr.out_avals) - num_residuals + num_out_primals = len(known_jaxpr.out_avals) - len(res_out_avals) in_fwd: list[int | None] = pe._jaxpr_forwarding(known_jaxpr.jaxpr) - # Only forward primal outputs when corresponding out_sharding is UNSPECIFIED. - in_fwd_primal, in_fwd_res = split_list(in_fwd, [num_out_primals]) + in_fwd_primal, in_fwd_res_ = split_list(in_fwd, [num_out_primals]) + assert all(f is None for f in in_fwd_res_) in_fwd = [ fwd if isinstance(os, UnspecifiedValue) and ol is None else None for os, ol, fwd in zip( keep_where(out_shardings, known_outs), keep_where(out_layouts, known_outs), in_fwd_primal) - ] + in_fwd_res - del in_fwd_primal, in_fwd_res + ] + in_fwd_res_ + del in_fwd_primal, in_fwd_res_ # Prune jaxpr outputs and out_shardings by removing the input-forwards. keep = [f is None for f in in_fwd] known_jaxpr = pe.prune_closed_jaxpr_outputs(known_jaxpr, keep) @@ -2195,7 +2426,11 @@ def keep_where(l, should_keep): all_known_outs = subs_list(in_fwd, known_inputs, all_known_outs) known_out_vals, residual_vals = \ - split_list(all_known_outs, [len(all_known_outs) - num_residuals]) + split_list(all_known_outs, [len(all_known_outs) - len(res_out_avals)]) + residual_vals_ = iter(residual_vals) + residual_vals = [next(residual_vals_) if f is None + else [*jaxpr.consts, *known_inputs][f] for f in in_fwd_res] + assert next(residual_vals_, None) is None residual_tracers = map(trace.new_instantiated_const, residual_vals) # The convention of partial_eval_jaxpr_nounits is to place residual binders at @@ -2203,16 +2438,22 @@ def keep_where(l, should_keep): # jaxpr equation built below and the pjit transpose rule assume a # residual-inputs-last convention. unknown_jaxpr = pe.move_binders_to_back( - unknown_jaxpr, [True] * num_residuals + [False] * sum(unknown_ins)) - # Prepare unknown tracers + unknown_jaxpr, [True] * len(residual_vals) + [False] * sum(unknown_ins)) + + # Set up staged-out 'unknown' eqn + unknown_in_shardings = (keep_where(in_shardings, unknown_ins) + + (UNSPECIFIED,) * len(residual_tracers)) + unknown_in_layouts = (keep_where(in_layouts, unknown_ins) + + (None,) * len(residual_tracers)) + unknown_donated_invars = (keep_where(donated_invars, unknown_ins) + + (False,) * len(residual_tracers)) unknown_params = dict( jaxpr=unknown_jaxpr, - in_shardings=(keep_where(in_shardings, unknown_ins) + res_shardings), + in_shardings=unknown_in_shardings, + in_layouts=unknown_in_layouts, out_shardings=keep_where(out_shardings, unknown_outs), - in_layouts=(keep_where(in_layouts, unknown_ins) + res_layouts), out_layouts=keep_where(out_layouts, unknown_outs), - donated_invars=(keep_where(donated_invars, unknown_ins) + - (False,) * num_residuals), + donated_invars=unknown_donated_invars, ctx_mesh=ctx_mesh, name=name, keep_unused=keep_unused, @@ -2224,7 +2465,7 @@ def keep_where(l, should_keep): pe.JaxprTracer(trace, pe.PartialVal.unknown(aval), None) for aval in unknown_out_avals ] - eqn = pe.new_eqn_recipe((*unknown_tracers_in, *residual_tracers), + eqn = pe.new_eqn_recipe(trace, (*unknown_tracers_in, *residual_tracers), unknown_tracers_out, pjit_p, unknown_params, @@ -2304,8 +2545,7 @@ def _pjit_transpose(cts_in, *primals_in, def prune_type(ty, xs, maybe_zeros): return tuple(x for x, mz in zip(xs, maybe_zeros) if type(mz) is not ty) - body = lu.wrap_init(ad.closed_backward_pass, - debug_info=jaxpr.jaxpr._debug_info) + body = lu.wrap_init(ad.closed_backward_pass, debug_info=jaxpr.jaxpr._debug_info) body = lu.hashable_partial(body, jaxpr, False) 
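# --- Editorial aside (illustrative sketch; not part of this patch) -----------
# The forwarding logic in the linearization and partial-eval rules above
# follows one pattern: compute in_fwd (for each output, the index of the input
# it merely forwards, or None), run the pruned computation for the remaining
# outputs, then splice the forwarded inputs back in.  A standalone toy version;
# prune_and_restore is a hypothetical name, not a JAX API.

def prune_and_restore(inputs, in_fwd, run_pruned):
  # in_fwd[i] is an input index if output i just forwards that input, else None
  keep = [f is None for f in in_fwd]
  pruned_outs = run_pruned(keep)              # computes only the kept outputs
  pruned_iter = iter(pruned_outs)
  outs = [inputs[f] if f is not None else next(pruned_iter) for f in in_fwd]
  assert next(pruned_iter, None) is None
  return outs

# Three outputs: (x0 forwarded, x0 + x1 computed, x1 forwarded).
inputs = [10, 20]
in_fwd = [0, None, 1]
run = lambda keep: [sum(inputs)]              # stands in for the pruned jaxpr;
                                              # only one output is computed here
assert prune_and_restore(inputs, in_fwd, run) == [10, 30, 20]
# -----------------------------------------------------------------------------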
primals_and_nz_cts_in, in_treedef = tree_flatten((primals_in, cts_in)) body, cts_out_treedef_thunk = flatten_fun_nokwargs(body, in_treedef) @@ -2334,11 +2574,12 @@ def prune_type(ty, xs, maybe_zeros): if attrs_tracked: init_states = _get_states(attrs_tracked) + num_attr_outs = sum(final_tree.num_leaves for _, final_tree, _ in attrs_tracked) primals_and_nz_cts_in = [*init_states, *primals_and_nz_cts_in] - transpose_in_shardings = (UNSPECIFIED,) * len(attrs_tracked) + transpose_in_shardings - transpose_out_shardings = (UNSPECIFIED,) * len(attrs_tracked) + transpose_out_shardings - transpose_in_layouts = (None,) * len(attrs_tracked) + transpose_in_layouts - transpose_out_layouts = (None,) * len(attrs_tracked) + transpose_out_layouts + transpose_in_shardings = (UNSPECIFIED,) * len(init_states) + transpose_in_shardings + transpose_out_shardings = (UNSPECIFIED,) * num_attr_outs + transpose_out_shardings + transpose_in_layouts = (None,) * len(init_states) + transpose_in_layouts + transpose_out_layouts = (None,) * num_attr_outs + transpose_out_layouts try: nz_cts_out = pjit_p.bind( @@ -2354,7 +2595,7 @@ def prune_type(ty, xs, maybe_zeros): keep_unused=keep_unused, inline=inline, compiler_options_kvs=compiler_options_kvs) - except dispatch.InternalFloatingPointError as e: + except api_util.InternalFloatingPointError as e: print("Invalid nan value encountered in the backward pass of a jax.jit " "function. Calling the de-optimized backward pass.") try: @@ -2364,10 +2605,10 @@ def prune_type(ty, xs, maybe_zeros): else: # If control reaches this line, we got a NaN on the output of `compiled` # but not `fun.call_wrapped` on the same arguments. Let's tell the user. - dispatch._raise_no_nan_in_deoptimized(e) + api_util._raise_no_nan_in_deoptimized(e) if attrs_tracked: - final_states, nz_cts_out = split_list(nz_cts_out, [len(init_states)]) + final_states, nz_cts_out = split_list(nz_cts_out, [num_attr_outs]) _set_states(attrs_tracked, final_states) return tree_unflatten(cts_out_treedef, nz_cts_out) @@ -2449,38 +2690,6 @@ def _pjit_pp_rule(eqn: core.JaxprEqn, core.pp_eqn_rules[pjit_p] = _pjit_pp_rule -def _pjit_state_discharge_rule( - in_avals, out_avals, *args, jaxpr, in_shardings, out_shardings, - in_layouts, out_layouts, **params): - if not all(isinstance(s, UnspecifiedValue) for s in (*in_shardings, *out_shardings)): - raise NotImplementedError - - if not (all(l is None for l in in_layouts) and - all(l is None for l in out_layouts)): - raise NotImplementedError - - jaxpr, consts = jaxpr.jaxpr, jaxpr.consts - num_outs = len(jaxpr.outvars) - discharged_jaxpr, discharged_consts = state_discharge.discharge_state(jaxpr, consts) - discharged_closed_jaxpr = core.ClosedJaxpr(discharged_jaxpr, discharged_consts) - new_in_shardings = (UnspecifiedValue(),) * len(discharged_jaxpr.invars) - new_out_shardings = (UnspecifiedValue(),) * len(discharged_jaxpr.outvars) - new_in_layouts = (None,) * len(discharged_jaxpr.invars) - new_out_layouts = (None,) * len(discharged_jaxpr.outvars) - out_and_ref_vals = pjit_p.bind( - *args, jaxpr=discharged_closed_jaxpr, in_shardings=new_in_shardings, - out_shardings=new_out_shardings, in_layouts=new_in_layouts, - out_layouts=new_out_layouts, **params) - out_vals, ref_vals = split_list(out_and_ref_vals, [num_outs]) - ref_vals_iter = iter(ref_vals) - new_invals = tuple(next(ref_vals_iter) if isinstance(aval, AbstractRef) - else None for aval in in_avals) - sentinel = object() - assert next(ref_vals_iter, sentinel) is sentinel - return new_invals, out_vals 
-state_discharge.register_discharge_rule(pjit_p)(_pjit_state_discharge_rule) - - # -------------------- with_sharding_constraint -------------------- def check_shardings_are_auto(shardings_flat): @@ -2495,7 +2704,13 @@ def check_shardings_are_auto(shardings_flat): raise ValueError( 'The spec of NamedSharding passed to with_sharding_constraint can' f' only refer to Auto axes of the mesh. Got spec={s.spec} and' - f' mesh={mesh}') + f' mesh={mesh}. You probably meant to use `reshard` API?') + + cur_mesh = mesh_lib.get_abstract_mesh() + if cur_mesh._are_all_axes_explicit: + raise ValueError( + 'with_sharding_constraint cannot be used when all axes of the mesh are' + ' of type `Explicit`. Please use the `reshard` API.') def with_sharding_constraint(x, shardings): @@ -2516,10 +2731,10 @@ def with_sharding_constraint(x, shardings): Returns: x_with_shardings: PyTree of jax.Arrays with specified sharding constraints. - .. _Distributed arrays and automatic parallelization: https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html + .. _Distributed arrays and automatic parallelization: https://docs.jax.dev/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html """ x_flat, tree = tree_flatten(x) - + x_avals_flat = [core.shaped_abstractify(x) for x in x_flat] layouts, shardings = _split_layout_and_sharding(shardings) user_shardings = prepare_axis_resources( @@ -2551,17 +2766,15 @@ def with_sharding_constraint(x, shardings): # TODO(bartchr): remove `unconstrained_dims` after migrating to Shardy. It's # already part of the shardings. unconstrained_dims = [get_unconstrained_dims(s) - if isinstance(s, NamedSharding) else {} + if isinstance(s, NamedSharding) else frozenset() for s in shardings_flat] pjit_check_aval_sharding( - shardings_flat, x_flat, ("",) * len(shardings_flat), + shardings_flat, x_avals_flat, ("",) * len(shardings_flat), "with_sharding_constraint arguments", allow_uneven_sharding=True) - check_shardings_are_auto(shardings_flat) - - check_aval_layout_compatibility(user_layouts_flat, x_flat, + check_aval_layout_compatibility(user_layouts_flat, x_avals_flat, ("",) * len(user_layouts_flat), "with_sharding_constraint arguments") @@ -2608,10 +2821,10 @@ def _sharding_constraint_impl(x, sharding, layout, context_mesh, # Run a jit here to raise good errors when device assignment don't match. 
return api.jit(_identity_fn, out_shardings=sharding)(x) else: - if (hasattr(x, 'layout') and x.layout.device_local_layout == layout and + if (hasattr(x, 'format') and x.format.device_local_layout == layout and x.sharding.is_equivalent_to(sharding, x.ndim)): return x - return api.jit(_identity_fn, out_shardings=Layout(layout, sharding))(x) + return api.jit(_identity_fn, out_shardings=Format(layout, sharding))(x) sharding_constraint_p = core.Primitive("sharding_constraint") @@ -2622,20 +2835,22 @@ def _sharding_constraint_impl(x, sharding, layout, context_mesh, def _sharding_constraint_hlo_lowering(ctx, x_node, *, sharding, layout, context_mesh, unconstrained_dims): - aval, = ctx.avals_in + in_aval, = ctx.avals_in out_aval, = ctx.avals_out axis_ctx = ctx.module_context.axis_context + if dtypes.issubdtype(in_aval.dtype, dtypes.extended): + in_aval = core.physical_aval(in_aval) if (isinstance(axis_ctx, sharding_impls.SPMDAxisContext) and axis_ctx.manual_axes): - sharding = mlir.add_manual_axes(axis_ctx, sharding, aval.ndim) + sharding = mlir.add_manual_axes(axis_ctx, sharding, in_aval.ndim) if config.use_shardy_partitioner.value: - sharding = sharding._to_sdy_sharding(aval.ndim) + sharding = sharding._to_sdy_sharding(in_aval.ndim) else: - sharding = sharding._to_xla_hlo_sharding(aval.ndim).to_proto() + sharding = sharding._to_xla_hlo_sharding(in_aval.ndim).to_proto() out = mlir.wrap_with_sharding_op( ctx, x_node, out_aval, sharding, unspecified_dims=unconstrained_dims) if layout is not None: - out = mlir.wrap_with_layout_op(ctx, out, out_aval, layout, aval) + out = mlir.wrap_with_layout_op(ctx, out, out_aval, layout, in_aval) return [out] mlir.register_lowering(sharding_constraint_p, _sharding_constraint_hlo_lowering) @@ -2677,7 +2892,7 @@ def _sharding_constraint_batcher( sharding=vmapped_sharding, layout=layout, context_mesh=context_mesh, - unconstrained_dims=unconstrained_dims) + unconstrained_dims=frozenset(unconstrained_dims)) return y, d batching.fancy_primitive_batchers[sharding_constraint_p] = _sharding_constraint_batcher batching.skippable_batchers[sharding_constraint_p] = lambda _: () @@ -2773,7 +2988,12 @@ def reshard(xs, out_shardings): out_flat = [] for x, x_aval, s in safe_zip(x_flat, x_avals_flat, shardings_flat): ds = canonicalize_sharding(s, 'reshard') - ds = ds.with_spec(ds.spec._normalized_spec_for_aval(x_aval.ndim)) # pytype: disable=attribute-error + if ds is None: + raise ValueError( + 'Reshard should only be used with out_shardings which are non-None ' + f'and have a nonempty mesh. Got sharding {s}.' + ) + ds = ds.update(spec=ds.spec._normalized_spec_for_aval(x_aval.ndim)) # pytype: disable=attribute-error out_flat.append(reshard_p.bind(x, dst_sharding=ds)) return tree_unflatten(treedef, out_flat) @@ -2795,7 +3015,7 @@ def _reshard_impl(x, dst_sharding): reshard_p.def_impl(_reshard_impl) def _reshard_transpose_rule(ct, x, dst_sharding): - return [reshard_p.bind(ct, dst_sharding=x.aval.sharding)] + return [reshard_p.bind(ct, dst_sharding=x.aval.to_cotangent_aval().sharding)] ad.deflinear2(reshard_p, _reshard_transpose_rule) def _reshard_hlo_lowering(ctx, x_node, *, dst_sharding): @@ -2821,48 +3041,70 @@ def _reshard_batcher(axis_data, vals_in, dims_in, dst_sharding): # -------------------- auto and user mode ------------------------- def _get_new_mesh(axes: str | tuple[str, ...] 
| None, - axis_type: mesh_lib.AxisType, name: str, - error_on_manual_to_auto_explict=False): + axis_type: mesh_lib.AxisType, name: str, shardings=None, + error_on_manual_to_auto_explicit=False): cur_mesh = mesh_lib.get_abstract_mesh() - # TODO(yashkatariya): Maybe allow fetching mesh from the args to enable - # computation follows data? - if cur_mesh.empty: + flat_shardings, _ = tree_flatten(shardings) + sharding_mesh = mesh_lib.empty_abstract_mesh + for i in flat_shardings: + if isinstance(i, NamedSharding): + if not sharding_mesh.empty and sharding_mesh != i.mesh: + raise ValueError( + f'Shardings passed to {name} should have the same mesh. Got one' + f' mesh {sharding_mesh} and another {i.mesh}') + sharding_mesh = i.mesh.abstract_mesh + + if sharding_mesh.empty and cur_mesh.empty: raise ValueError( f'Context mesh {cur_mesh} cannot be empty. Please use' ' `jax.sharding.use_mesh` API to enter into a mesh context when using' f' `{name}` API.') + if not sharding_mesh.empty and not cur_mesh.empty: + if sharding_mesh != cur_mesh: + raise ValueError( + f'Context mesh {cur_mesh} must match the mesh passed to shardings' + f' {sharding_mesh}. Recommended approach is to use' + ' `jax.sharding.use_mesh` context manager.') + mesh_to_use = cur_mesh + elif sharding_mesh.empty and not cur_mesh.empty: + mesh_to_use = cur_mesh + else: + assert not sharding_mesh.empty and cur_mesh.empty + mesh_to_use = sharding_mesh + if axes is None: - axes = cur_mesh.axis_names + axes = mesh_to_use.axis_names if not isinstance(axes, tuple): axes = (axes,) for a in axes: - if (error_on_manual_to_auto_explict and - cur_mesh._name_to_type[a] == mesh_lib.AxisType.Manual and + if (error_on_manual_to_auto_explicit and + mesh_to_use._name_to_type[a] == mesh_lib.AxisType.Manual and axis_type in {mesh_lib.AxisType.Auto, mesh_lib.AxisType.Explicit}): raise NotImplementedError( 'Going from `Manual` AxisType to `Auto` or `Explicit` AxisType is not' ' allowed. Please file a bug at https://github.com/jax-ml/jax/issues' ' with your use case') - return cur_mesh.update_axis_types({a: axis_type for a in axes}) + return mesh_to_use.update_axis_types({a: axis_type for a in axes}) def auto_axes(fun, *, axes: str | tuple[str, ...] | None = None, - out_shardings=None): + out_sharding=None): def decorator(*args, **kwargs): - if out_shardings is None: - if "out_shardings" in kwargs: - _out_shardings = kwargs.pop("out_shardings") + if out_sharding is None: + if "out_sharding" in kwargs: + _out_sharding = kwargs.pop("out_sharding") else: - raise TypeError("Missing required keyword argument: 'out_shardings'") + raise TypeError("Missing required keyword argument: 'out_sharding'") else: - _out_shardings = out_shardings - new_mesh = _get_new_mesh(axes, mesh_lib.AxisType.Auto, 'auto_axes', - error_on_manual_to_auto_explict=True) + _out_sharding = out_sharding + new_mesh = _get_new_mesh( + axes, mesh_lib.AxisType.Auto, 'auto_axes', shardings=_out_sharding, + error_on_manual_to_auto_explicit=True) with mesh_lib.use_abstract_mesh(new_mesh): in_specs = tree_map(lambda a: core.modify_spec_for_auto_manual( core.get_aval(a).sharding.spec, new_mesh), args) args = mesh_cast(args, in_specs) out = fun(*args, **kwargs) - return mesh_cast(out, _out_shardings) + return mesh_cast(out, _out_sharding) return decorator @contextlib.contextmanager @@ -2873,19 +3115,19 @@ def use_auto_axes(*axes): def explicit_axes(fun, *, axes: str | tuple[str, ...] 
| None = None, - in_shardings=None): + in_sharding=None): def decorator(*args, **kwargs): - if in_shardings is None: - if "in_shardings" in kwargs: - _in_shardings = kwargs.pop("in_shardings") + if in_sharding is None: + if "in_sharding" in kwargs: + _in_sharding = kwargs.pop("in_sharding") else: - raise TypeError("Missing required keyword argument: 'in_shardings'") + raise TypeError("Missing required keyword argument: 'in_sharding'") else: - _in_shardings = in_shardings + _in_sharding = in_sharding new_mesh = _get_new_mesh(axes, mesh_lib.AxisType.Explicit, 'explicit_axes', - error_on_manual_to_auto_explict=True) + error_on_manual_to_auto_explicit=True) with mesh_lib.use_abstract_mesh(new_mesh): - args = mesh_cast(args, _in_shardings) + args = mesh_cast(args, _in_sharding) out = fun(*args, **kwargs) out_specs = tree_map(lambda o: core.modify_spec_for_auto_manual( core.get_aval(o).sharding.spec, mesh_lib.get_abstract_mesh()), out) @@ -2899,47 +3141,187 @@ def use_explicit_axes(*axes): with mesh_lib.use_abstract_mesh(new_mesh): yield -# -------------------- helpers -------------------- +# -------------------- with_layout_constraint -------------------- -def get_unconstrained_dims(sharding: NamedSharding): - assert sharding.spec is not None - return {i for i, axes in enumerate(sharding.spec) - if axes is PartitionSpec.UNCONSTRAINED} +def with_layout_constraint(x, layouts): + x_flat, tree = tree_flatten(x) + x_avals_flat = [core.shaped_abstractify(x) for x in x_flat] + layouts_flat = tuple(flatten_axes("with_layout_constraint layouts", tree, + layouts)) + if any(not isinstance(l, DeviceLocalLayout) for l in layouts_flat): + raise ValueError( + 'layouts passed to `with_layout_constraint` must be of type' + f' `DeviceLocalLayout`. Got {[type(l) for l in layouts_flat]}') + check_aval_layout_compatibility( + layouts_flat, x_avals_flat, ("",) * len(layouts_flat), + "with_layout_constraint arguments") + outs = [layout_constraint_p.bind(xf, layout=l) + for xf, l in zip(x_flat, layouts_flat)] + return tree_unflatten(tree, outs) +layout_constraint_p = core.Primitive('layout_constraint') +layout_constraint_p.def_abstract_eval(lambda x, **_: x) +ad.deflinear2(layout_constraint_p, + lambda ct, _, **params: (layout_constraint_p.bind(ct, **params),)) + +def _layout_constraint_impl(x, *, layout): + if not isinstance(x, xc.ArrayImpl): + raise ValueError( + 'with_layout_constraint in eager mode can only be applied to' + f' jax.Arrays. 
Got {type(x)}') + if x.format.device_local_layout == layout: # type: ignore + return x + return api.jit(_identity_fn, out_shardings=Format(layout, x.sharding))(x) +layout_constraint_p.def_impl(_layout_constraint_impl) -def get_op_sharding_from_executable( - executable) -> tuple[Sequence[xc.OpSharding], Sequence[xc.OpSharding]]: - in_op_shardings: list[xc.OpSharding] = [] - parameter_shardings_from_xla = executable.get_parameter_shardings() - if parameter_shardings_from_xla is not None: - in_op_shardings = parameter_shardings_from_xla +def _layout_constraint_hlo_lowering(ctx, x_node, *, layout): + aval, = ctx.avals_in + out_aval, = ctx.avals_out + return [mlir.wrap_with_layout_op(ctx, x_node, out_aval, layout, aval)] +mlir.register_lowering(layout_constraint_p, + _layout_constraint_hlo_lowering) - out_op_shardings: list[xc.OpSharding] = [] - output_shardings_from_xla = executable.get_output_shardings() - if output_shardings_from_xla is not None: - out_op_shardings = output_shardings_from_xla +def _layout_constraint_batcher(axis_data, vals_in, dims_in, layout): + raise NotImplementedError +batching.fancy_primitive_batchers[layout_constraint_p] = _layout_constraint_batcher +batching.skippable_batchers[layout_constraint_p] = lambda _: () - return in_op_shardings, out_op_shardings +# -------------------- helpers -------------------- +def get_unconstrained_dims(sharding: NamedSharding): + assert sharding.spec is not None + return frozenset(i for i, axes in enumerate(sharding.spec) + if axes is PartitionSpec.UNCONSTRAINED) -def _get_ppspec_from_executable( - executable, mesh - ) -> tuple[Sequence[PartitionSpec], Sequence[PartitionSpec]]: - input_op_shardings, output_op_sharding = get_op_sharding_from_executable( - executable - ) - in_pspec: list[PartitionSpec] = [] - for s in input_op_shardings: - in_pspec.extend(parse_flatten_op_sharding(s, mesh)) +# -------------------- attrs etc -------------------- - out_pspec: list[PartitionSpec] = [] - for s in output_op_sharding: - out_pspec.extend(parse_flatten_op_sharding(s, mesh)) - return in_pspec, out_pspec +def _set_states(attrs_tracked, vals): + valss = split_list(vals, [td.num_leaves for _, td, _ in attrs_tracked[:-1]]) + for ((_, treedef, (obj, attr, kind)), leaves) in zip(attrs_tracked, valss): + if kind is pe.ReadWrite: + val = tree_unflatten(treedef, leaves) + jax_setattr(obj, attr, val) + elif kind is pe.Append: + del treedef + val, = leaves + jax_extendattr(obj, attr, val) + elif kind is pe.BoxAttr: + val = tree_unflatten(treedef, leaves) + obj.set(val) + elif kind is pe.ListAttr: + for item in tree_unflatten(treedef, leaves): + obj.append(item) + else: + assert False +def _get_states(attrs_tracked): + vals = [] + for treedef, _, (obj, attr, kind) in attrs_tracked: + if kind is pe.ReadWrite: + tree = jax_getattr(obj, attr) if hasattr(obj, attr) else dne_sentinel + leaves, treedef_ = tree_flatten(tree) + assert treedef == treedef_ + vals.extend(leaves) + elif kind is pe.Append: + pass + elif kind is pe.BoxAttr: + tree = obj.get() # not getattr! 
+ leaves, treedef_ = tree_flatten(tree) + assert treedef == treedef_ + vals.extend(leaves) + elif kind is pe.ListAttr: + pass + else: + assert False + return vals -def get_pspec_from_executable( - executable, mesh: pxla.Mesh -) -> tuple[tuple[PartitionSpec, ...], tuple[PartitionSpec, ...]]: - in_pspec, out_pspec = _get_ppspec_from_executable(executable, mesh) - return tuple(in_pspec), tuple(out_pspec) +def static(): + return dataclasses.field(metadata=dict(static=True)) + +@tree_util.register_dataclass +@dataclasses.dataclass +class BoxTree: + leaves: list + treedef: PyTreeDef = static() + +@tree_util.register_dataclass +@dataclasses.dataclass +class ListTree: + leaves: list + treedef: PyTreeDef | None = static() + +def _flatten_boxes(dbg, args, kwargs): + # TODO(mattjj,dougalm): refine this implementation of box-handling... + if all(not isinstance(x, (Box, List)) for x in tree_leaves((args, kwargs))): + return args, kwargs, [] + box_data = [] + id_first_occurrences = {} + idxs = it.count() + def visit(x): + i = next(idxs) + if (isinstance(x, (Box, List)) and + (dup_idx := id_first_occurrences.setdefault(id(x), i)) != i): + type_name = type(x).__name__ + raise ValueError( + f"a {type_name} instance can't be passed as an argument more than " + f"once, but when tracing {dbg.func_src_info} for {dbg.traced_for}, " + f"the object {x} appeared at both arguments " + f"{dbg.arg_names[dup_idx]} and {dbg.arg_names[i]}" + if dbg else + f"at both flat index {dup_idx} and flat index {i}") + if type(x) is Box: + leaves, treedef = tree_flatten(x._val) + ty = tuple(core.shaped_abstractify(l) for l in leaves) + box_data.append((i, pe.BoxAttr)) + return BoxTree(leaves, treedef) + elif type(x) is List: + box_data.append((i, pe.ListAttr)) + return ListTree([], None) + else: + return x + args, kwargs = tree_map(visit, (args, kwargs)) + return args, kwargs, box_data + +# TODO(mattjj): because _handle_boxes's caller passes arguments splatted, the +# names of its first two parameters must not collide with user-suppliedkwargs. +# Using obscure names is a temporary workaround; revise! 
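# --- Editorial aside (illustrative sketch; not part of this patch) -----------
# _flatten_boxes / _handle_boxes above replace a mutable Box argument with a
# flattenable (leaves, treedef) record on the way in, rebuild a fresh Box for
# the wrapped function, and carry the Box's final contents back out so the
# caller's object is updated.  A toy round trip of that idea; ToyBox and
# call_with_box are hypothetical stand-ins, not the Box machinery used above.

import jax

class ToyBox:
  def __init__(self, val):
    self._val = val

def call_with_box(f, box, *args):
  # Flatten the box contents at the call boundary, as BoxTree does above.
  leaves, treedef = jax.tree_util.tree_flatten(box._val)
  inner_box = ToyBox(jax.tree_util.tree_unflatten(treedef, leaves))
  out = f(inner_box, *args)
  box._val = inner_box._val      # write the final contents back to the caller
  return out

def accumulate(box, x):
  box._val = box._val + x        # mutation visible to the caller afterwards
  return x * 2

b = ToyBox(1.0)
y = call_with_box(accumulate, b, 3.0)
assert (y, b._val) == (6.0, 4.0)
# -----------------------------------------------------------------------------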
+@lu.transformation2 +def _handle_boxes(__f, __dbg, *args, **kwargs): + f, dbg = __f, __dbg + new_args = [] + arg_mutables = [] + def visit(x): + if type(x) is BoxTree: + box = Box(tree_unflatten(x.treedef, x.leaves)) + arg_mutables.append(box) + return box + elif type(x) is ListTree: + lst = List() + lst._is_arg = True + arg_mutables.append(lst) + return lst + else: + return x + args, kwargs = tree_map(visit, (args, kwargs), + is_leaf=lambda x: isinstance(x, (BoxTree, ListTree))) + out = f(*args, **kwargs) + for path, leaf in tree_flatten_with_path(out)[0]: + if isinstance(leaf, (Box, List)): + type_name = type(leaf).__name__ + raise ValueError( + f"a {type_name} instance can't be returned from a transformed " + f"function, but when tracing {dbg.func_src_info} for {dbg.traced_for} " + f"the object {leaf} appeared at result{keystr(path)}") + if not arg_mutables: + return out + extra_outs = [] + for mutable in arg_mutables: + if type(mutable) is Box: + leaves, treedef = tree_flatten(mutable._val) + extra_outs.append(BoxTree(leaves, treedef)) + elif type(mutable) is List: + leaves, treedef = tree_flatten(mutable._val) + extra_outs.append(ListTree(leaves, treedef)) + else: + assert False + return extra_outs, out diff --git a/jax/_src/pretty_printer.py b/jax/_src/pretty_printer.py index e8fdff497445..d2850c814bb6 100644 --- a/jax/_src/pretty_printer.py +++ b/jax/_src/pretty_printer.py @@ -31,15 +31,11 @@ import enum from functools import partial import sys -from typing import Any, NamedTuple +from typing import Any, NamedTuple, TYPE_CHECKING from jax._src import config from jax._src import util - -try: - import colorama # pytype: disable=import-error -except ImportError: - colorama = None +from jax._src.lib import _pretty_printer as _pretty_printer _PPRINT_USE_COLOR = config.bool_state( @@ -66,409 +62,464 @@ def _can_use_color() -> bool: CAN_USE_COLOR = _can_use_color() -class Doc(util.StrictABC): - __slots__ = () +# TODO(phawkins): remove this condition after the jaxlib 0.6.3 release. +if TYPE_CHECKING or _pretty_printer is None: + try: + import colorama # pytype: disable=import-error + except ImportError: + colorama = None + + class Doc(util.StrictABC): + __slots__ = () + + def format( + self, width: int = 80, *, use_color: bool | None = None, + annotation_prefix: str = " # ", + source_map: list[list[tuple[int, int, Any]]] | None = None + ) -> str: + """ + Formats a pretty-printer document as a string. + + Args: + source_map: for each line in the output, contains a list of + (start column, end column, source) tuples. Each tuple associates a + region of output text with a source. 
+ """ + if use_color is None: + use_color = CAN_USE_COLOR and _PPRINT_USE_COLOR.value + return _format(self, width, use_color=use_color, + annotation_prefix=annotation_prefix, source_map=source_map) + + def __str__(self): + return self.format() + + def __add__(self, other: Doc) -> Doc: + return concat([self, other]) + + def num_annotations(self) -> int: + raise NotImplementedError() + + class _NilDoc(Doc): + def __repr__(self): return "nil" + + def num_annotations(self) -> int: + return 0 + + _nil = _NilDoc() + + class _TextDoc(Doc): + __slots__ = ("text", "annotation") + text: str + annotation: str | None + + def __init__(self, text: str, annotation: str | None = None): + assert isinstance(text, str), text + assert annotation is None or isinstance(annotation, str), annotation + self.text = text + self.annotation = annotation + + def __repr__(self): + if self.annotation is not None: + return f"text(\"{self.text}\", annotation=\"{self.annotation}\")" + else: + return f"text(\"{self.text}\")" - def format( - self, width: int = 80, *, use_color: bool | None = None, - annotation_prefix: str = " # ", - source_map: list[list[tuple[int, int, Any]]] | None = None - ) -> str: - """ - Formats a pretty-printer document as a string. + def num_annotations(self) -> int: + return 1 if self.annotation is not None else 0 - Args: - source_map: for each line in the output, contains a list of - (start column, end column, source) tuples. Each tuple associates a - region of output text with a source. - """ - if use_color is None: - use_color = CAN_USE_COLOR and _PPRINT_USE_COLOR.value - return _format(self, width, use_color=use_color, - annotation_prefix=annotation_prefix, source_map=source_map) + class _ConcatDoc(Doc): + __slots__ = ("children", "_num_annotations") + children: list[Doc] + _num_annotations: int - def __str__(self): - return self.format() + def __init__(self, children: Sequence[Doc]): + self.children = list(children) + self._num_annotations = sum(child.num_annotations() for child in children) - def __add__(self, other: Doc) -> Doc: - return concat([self, other]) + def __repr__(self): return f"concat({self.children})" -class _NilDoc(Doc): - def __repr__(self): return "nil" + def num_annotations(self) -> int: + return self._num_annotations -_nil = _NilDoc() + class _BreakDoc(Doc): + __slots__ = ("text",) + text: str -class _TextDoc(Doc): - __slots__ = ("text", "annotation") - text: str - annotation: str | None + def __init__(self, text: str): + assert isinstance(text, str), text + self.text = text - def __init__(self, text: str, annotation: str | None = None): - assert isinstance(text, str), text - assert annotation is None or isinstance(annotation, str), annotation - self.text = text - self.annotation = annotation + def __repr__(self): return f"break({self.text})" - def __repr__(self): - if self.annotation is not None: - return f"text(\"{self.text}\", annotation=\"{self.annotation}\")" - else: - return f"text(\"{self.text}\")" + def num_annotations(self) -> int: + return 0 -class _ConcatDoc(Doc): - __slots__ = ("children",) - children: list[Doc] + class _GroupDoc(Doc): + __slots__ = ("child",) + child: Doc - def __init__(self, children: Sequence[Doc]): - self.children = list(children) - assert all(isinstance(doc, Doc) for doc in self.children), self.children - - def __repr__(self): return f"concat({self.children})" - -class _BreakDoc(Doc): - __slots__ = ("text",) - text: str + def __init__(self, child: Doc): + assert isinstance(child, Doc), child + self.child = child - def __init__(self, text: 
str): - assert isinstance(text, str), text - self.text = text - - def __repr__(self): return f"break({self.text})" - -class _GroupDoc(Doc): - __slots__ = ("child",) - child: Doc - - def __init__(self, child: Doc): - assert isinstance(child, Doc), child - self.child = child - - def __repr__(self): return f"group({self.child})" - -class _NestDoc(Doc): - __slots__ = ("n", "child",) - n: int - child: Doc - - def __init__(self, n: int, child: Doc): - assert isinstance(child, Doc), child - self.n = n - self.child = child - - def __repr__(self): return f"nest({self.n, self.child})" - - -_NO_SOURCE = object() - -class _SourceMapDoc(Doc): - __slots__ = ("child", "source") - child: Doc - source: Any - - def __init__(self, child: Doc, source: Any): - assert isinstance(child, Doc), child - self.child = child - self.source = source - - def __repr__(self): return f"source({self.child}, {self.source})" - - -Color = enum.Enum("Color", ["BLACK", "RED", "GREEN", "YELLOW", "BLUE", - "MAGENTA", "CYAN", "WHITE", "RESET"]) -Intensity = enum.Enum("Intensity", ["DIM", "NORMAL", "BRIGHT"]) - -class _ColorDoc(Doc): - __slots__ = ("foreground", "background", "intensity", "child") - foreground: Color | None - background: Color | None - intensity: Intensity | None - child: Doc - - def __init__(self, child: Doc, *, foreground: Color | None = None, - background: Color | None = None, - intensity: Intensity | None = None): - assert isinstance(child, Doc), child - self.child = child - self.foreground = foreground - self.background = background - self.intensity = intensity - - -_BreakMode = enum.Enum("_BreakMode", ["FLAT", "BREAK"]) - - -# In Lindig's paper fits() and format() are defined recursively. This is a -# non-recursive formulation using an explicit stack, necessary because Python -# doesn't have a tail recursion optimization. - -def _fits(doc: Doc, width: int, agenda: list[tuple[int, _BreakMode, Doc]] - ) -> bool: - while width >= 0 and len(agenda) > 0: - i, m, doc = agenda.pop() - if isinstance(doc, _NilDoc): - pass - elif isinstance(doc, _TextDoc): - width -= len(doc.text) - elif isinstance(doc, _ConcatDoc): - agenda.extend((i, m, d) for d in reversed(doc.children)) - elif isinstance(doc, _BreakDoc): - if m == _BreakMode.BREAK: - return True - width -= len(doc.text) - elif isinstance(doc, _NestDoc): - agenda.append((i + doc.n, m, doc.child)) - elif isinstance(doc, _GroupDoc): - agenda.append((i, _BreakMode.FLAT, doc.child)) - elif isinstance(doc, _ColorDoc) or isinstance(doc, _SourceMapDoc): - agenda.append((i, m, doc.child)) - else: - raise ValueError("Invalid document ", doc) - - return width >= 0 - - -# Annotation layout: A flat group is sparse if there are no breaks between -# annotations. 
-def _sparse(doc: Doc) -> bool: - agenda = [doc] - num_annotations = 0 - seen_break = False - while len(agenda) > 0: - doc = agenda.pop() - if isinstance(doc, _NilDoc): - pass - elif isinstance(doc, _TextDoc): - if doc.annotation is not None: - if num_annotations >= 1 and seen_break: - return False - num_annotations += 1 - elif isinstance(doc, _ConcatDoc): - agenda.extend(reversed(doc.children)) - elif isinstance(doc, _BreakDoc): - seen_break = True - elif isinstance(doc, _NestDoc): - agenda.append(doc.child) - elif isinstance(doc, _GroupDoc): - agenda.append(doc.child) - elif isinstance(doc, _ColorDoc) or isinstance(doc, _SourceMapDoc): - agenda.append(doc.child) - else: - raise ValueError("Invalid document ", doc) - - return True - -class _ColorState(NamedTuple): - foreground: Color - background: Color - intensity: Intensity - -class _State(NamedTuple): - indent: int - mode: _BreakMode - doc: Doc - color: _ColorState - source_map: Any - -class _Line(NamedTuple): - text: str - width: int - annotations: str | None | list[str] - - -def _update_color(use_color: bool, state: _ColorState, update: _ColorState - ) -> tuple[_ColorState, str]: - if not use_color or colorama is None: - return update, "" - color_str = "" - if state.foreground != update.foreground: - color_str += getattr(colorama.Fore, str(update.foreground.name)) - if state.background != update.background: - color_str += getattr(colorama.Back, str(update.background.name)) - if state.intensity != update.intensity: - color_str += colorama.Style.NORMAL # pytype: disable=unsupported-operands - color_str += getattr(colorama.Style, str(update.intensity.name)) - return update, color_str - - -def _align_annotations(lines): - # TODO: Hafiz also implements a local alignment mode, where groups of lines - # with annotations are aligned together. - maxlen = max(l.width for l in lines) - out = [] - for l in lines: - if len(l.annotations) == 0: - out.append(l._replace(annotations=None)) - elif len(l.annotations) == 1: - out.append(l._replace(text=l.text + " " * (maxlen - l.width), - annotations=l.annotations[0])) - else: - out.append(l._replace(text=l.text + " " * (maxlen - l.width), - annotations=l.annotations[0])) - for a in l.annotations[1:]: - out.append(_Line(text=" " * maxlen, width=l.width, annotations=a)) - return out - - - -def _format( - doc: Doc, width: int, *, use_color: bool, annotation_prefix: str, - source_map: list[list[tuple[int, int, Any]]] | None -) -> str: - lines = [] - default_colors = _ColorState(Color.RESET, Color.RESET, Intensity.NORMAL) - annotation_colors = _ColorState(Color.RESET, Color.RESET, Intensity.DIM) - color_state = default_colors - source_start = 0 # The column at which the current source region starts. - source = _NO_SOURCE # The currently active source region. - line_source_map = [] # Source maps for the current line of text. 
- agenda = [_State(0, _BreakMode.BREAK, doc, default_colors, source)] - k = 0 - line_text = "" - line_annotations = [] - while len(agenda) > 0: - i, m, doc, color, agenda_source = agenda.pop() - if source_map is not None and agenda_source != source: - pos = len(line_text) - if source_start != pos and source is not _NO_SOURCE: - line_source_map.append((source_start, pos, source)) - source = agenda_source - source_start = pos - if isinstance(doc, _NilDoc): - pass - elif isinstance(doc, _TextDoc): - color_state, color_str = _update_color(use_color, color_state, color) - line_text += color_str - line_text += doc.text - if doc.annotation is not None: - line_annotations.append(doc.annotation) - k += len(doc.text) - elif isinstance(doc, _ConcatDoc): - agenda.extend(_State(i, m, d, color, source) - for d in reversed(doc.children)) - elif isinstance(doc, _BreakDoc): - if m == _BreakMode.BREAK: - if len(line_annotations) > 0: - color_state, color_str = _update_color(use_color, color_state, - annotation_colors) - line_text += color_str - lines.append(_Line(line_text, k, line_annotations)) - if source_map is not None: - pos = len(line_text) - if source_start != pos and source is not _NO_SOURCE: - line_source_map.append((source_start, pos, source)) - source_map.append(line_source_map) - line_source_map = [] - source_start = i - line_text = " " * i - line_annotations = [] - k = i + def __repr__(self): return f"group({self.child})" + + def num_annotations(self) -> int: + return self.child.num_annotations() + + class _NestDoc(Doc): + __slots__ = ("n", "child",) + n: int + child: Doc + + def __init__(self, n: int, child: Doc): + assert isinstance(child, Doc), child + self.n = n + self.child = child + + def __repr__(self): return f"nest({self.n, self.child})" + + def num_annotations(self) -> int: + return self.child.num_annotations() + + _NO_SOURCE = object() + + class _SourceMapDoc(Doc): + __slots__ = ("child", "source") + child: Doc + source: Any + + def __init__(self, child: Doc, source: Any): + assert isinstance(child, Doc), child + self.child = child + self.source = source + + def __repr__(self): return f"source({self.child}, {self.source})" + + def num_annotations(self) -> int: + return self.child.num_annotations() + + Color = enum.Enum("Color", ["BLACK", "RED", "GREEN", "YELLOW", "BLUE", + "MAGENTA", "CYAN", "WHITE", "RESET"]) + Intensity = enum.Enum("Intensity", ["DIM", "NORMAL", "BRIGHT"]) + + class _ColorDoc(Doc): + __slots__ = ("foreground", "background", "intensity", "child") + foreground: Color | None + background: Color | None + intensity: Intensity | None + child: Doc + + def __init__(self, child: Doc, *, foreground: Color | None = None, + background: Color | None = None, + intensity: Intensity | None = None): + assert isinstance(child, Doc), child + self.child = child + self.foreground = foreground + self.background = background + self.intensity = intensity + + def num_annotations(self) -> int: + return self.child.num_annotations() + + _BreakMode = enum.Enum("_BreakMode", ["FLAT", "BREAK"]) + + + # In Lindig's paper fits() and format() are defined recursively. This is a + # non-recursive formulation using an explicit stack, necessary because Python + # doesn't have a tail recursion optimization. 
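An illustrative sketch of the behaviour this formatter implements, using the public combinators defined later in this module (assuming it imports as jax._src.pretty_printer; the document shape and widths are arbitrary): a group prints its breaks as plain text when the whole group fits in the requested width, and as indented newlines otherwise.

from jax._src import pretty_printer as pp

# A call-like document: "add(" followed by two arguments under nest(2).
doc = pp.group(pp.concat([
    pp.text("add("),
    pp.nest(2, pp.concat([pp.brk(""), pp.text("x,"), pp.brk(), pp.text("y")])),
    pp.brk(""),
    pp.text(")"),
]))
print(doc.format(width=80))  # wide enough: should render flat, e.g. "add(x, y)"
print(doc.format(width=6))   # too narrow: breaks become newlines, indented by 2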
+ + def _fits(doc: Doc, width: int) -> bool: + agenda = [doc] + while width >= 0 and len(agenda) > 0: + doc = agenda.pop() + if isinstance(doc, _NilDoc): + pass + elif isinstance(doc, _TextDoc): + width -= len(doc.text) + elif isinstance(doc, _ConcatDoc): + agenda.extend(reversed(doc.children)) + elif isinstance(doc, _BreakDoc): + width -= len(doc.text) + elif isinstance(doc, (_NestDoc, _GroupDoc, _ColorDoc, _SourceMapDoc)): + agenda.append(doc.child) + else: + raise ValueError("Invalid document ", doc) + + return width >= 0 + + + # Annotation layout: A flat group is sparse if there are no breaks between + # annotations. + def _sparse(doc: Doc) -> bool: + agenda = [doc] + if doc.num_annotations() == 0: + return True + num_annotations = 0 + seen_break = False + while len(agenda) > 0: + doc = agenda.pop() + if isinstance(doc, _NilDoc): + pass + elif isinstance(doc, _TextDoc): + if doc.annotation is not None: + if num_annotations >= 1 and seen_break: + return False + num_annotations += 1 + elif isinstance(doc, _ConcatDoc): + agenda.extend(reversed(doc.children)) + elif isinstance(doc, _BreakDoc): + seen_break = True + elif isinstance(doc, _NestDoc): + agenda.append(doc.child) + elif isinstance(doc, _GroupDoc): + agenda.append(doc.child) + elif isinstance(doc, _ColorDoc) or isinstance(doc, _SourceMapDoc): + agenda.append(doc.child) else: + raise ValueError("Invalid document ", doc) + + return True + + class _ColorState(NamedTuple): + foreground: Color + background: Color + intensity: Intensity + + class _State(NamedTuple): + indent: int + mode: _BreakMode + doc: Doc + color: _ColorState + source_map: Any + + class _Line(NamedTuple): + text: str + width: int + annotations: list[str] + + + def _update_color(use_color: bool, state: _ColorState, update: _ColorState + ) -> tuple[_ColorState, str]: + if not use_color or colorama is None: + return update, "" + color_str = "" + if state.foreground != update.foreground: + color_str += getattr(colorama.Fore, str(update.foreground.name)) + if state.background != update.background: + color_str += getattr(colorama.Back, str(update.background.name)) + if state.intensity != update.intensity: + color_str += colorama.Style.NORMAL # pytype: disable=unsupported-operands + color_str += getattr(colorama.Style, str(update.intensity.name)) + return update, color_str + + + def _align_annotations(lines: list[_Line], annotation_prefix: str) -> list[str]: + # TODO: Hafiz also implements a local alignment mode, where groups of lines + # with annotations are aligned together. + maxlen = max(l.width for l in lines) + out = [] + for l in lines: + if len(l.annotations) == 0: + out.append(l.text) + else: + out.append(f"{l.text}{' ' * (maxlen - l.width)}" + f"{annotation_prefix}{l.annotations[0]}") + for a in l.annotations[1:]: + out.append(f"{' ' * maxlen}{annotation_prefix}{a}") + return out + + + + def _format( + doc: Doc, width: int, *, use_color: bool, annotation_prefix: str, + source_map: list[list[tuple[int, int, Any]]] | None + ) -> str: + lines = [] + default_colors = _ColorState(Color.RESET, Color.RESET, Intensity.NORMAL) + annotation_colors = _ColorState(Color.RESET, Color.RESET, Intensity.DIM) + color_state = default_colors + source_start = 0 # The column at which the current source region starts. + source = _NO_SOURCE # The currently active source region. + line_source_map = [] # Source maps for the current line of text. 
+ agenda = [_State(0, _BreakMode.BREAK, doc, default_colors, source)] + k = 0 + line_text = "" + line_annotations = [] + while len(agenda) > 0: + i, m, doc, color, agenda_source = agenda.pop() + if source_map is not None and agenda_source != source: + pos = len(line_text) + if source_start != pos and source is not _NO_SOURCE: + line_source_map.append((source_start, pos, source)) + source = agenda_source + source_start = pos + if isinstance(doc, _NilDoc): + pass + elif isinstance(doc, _TextDoc): color_state, color_str = _update_color(use_color, color_state, color) line_text += color_str line_text += doc.text + if doc.annotation is not None: + line_annotations.append(doc.annotation) k += len(doc.text) - elif isinstance(doc, _NestDoc): - agenda.append(_State(i + doc.n, m, doc.child, color, source)) - elif isinstance(doc, _GroupDoc): - # In Lindig's paper, _fits is passed the remainder of the document. - # I'm pretty sure that's a bug and we care only if the current group fits! - if (_sparse(doc) - and _fits(doc, width - k, [(i, _BreakMode.FLAT, doc.child)])): - agenda.append(_State(i, _BreakMode.FLAT, doc.child, color, source)) + elif isinstance(doc, _ConcatDoc): + agenda.extend(_State(i, m, d, color, source) + for d in reversed(doc.children)) + elif isinstance(doc, _BreakDoc): + if m == _BreakMode.BREAK: + if len(line_annotations) > 0: + color_state, color_str = _update_color(use_color, color_state, + annotation_colors) + line_text += color_str + lines.append(_Line(line_text, k, line_annotations)) + if source_map is not None: + pos = len(line_text) + if source_start != pos and source is not _NO_SOURCE: + line_source_map.append((source_start, pos, source)) + source_map.append(line_source_map) + line_source_map = [] + source_start = i + line_text = " " * i + line_annotations = [] + k = i + else: + color_state, color_str = _update_color(use_color, color_state, color) + line_text += color_str + line_text += doc.text + k += len(doc.text) + elif isinstance(doc, _NestDoc): + agenda.append(_State(i + doc.n, m, doc.child, color, source)) + elif isinstance(doc, _GroupDoc): + # In Lindig's paper, _fits is passed the remainder of the document. + # I'm pretty sure that's a bug and we care only if the current group fits! 
+ if (_fits(doc, width - k) and _sparse(doc)): + agenda.append(_State(i, _BreakMode.FLAT, doc.child, color, source)) + else: + agenda.append(_State(i, _BreakMode.BREAK, doc.child, color, source)) + elif isinstance(doc, _ColorDoc): + color = _ColorState(doc.foreground or color.foreground, + doc.background or color.background, + doc.intensity or color.intensity) + agenda.append(_State(i, m, doc.child, color, source)) + elif isinstance(doc, _SourceMapDoc): + agenda.append(_State(i, m, doc.child, color, doc.source)) else: - agenda.append(_State(i, _BreakMode.BREAK, doc.child, color, source)) - elif isinstance(doc, _ColorDoc): - color = _ColorState(doc.foreground or color.foreground, - doc.background or color.background, - doc.intensity or color.intensity) - agenda.append(_State(i, m, doc.child, color, source)) - elif isinstance(doc, _SourceMapDoc): - agenda.append(_State(i, m, doc.child, color, doc.source)) - else: - raise ValueError("Invalid document ", doc) - - if len(line_annotations) > 0: - color_state, color_str = _update_color(use_color, color_state, - annotation_colors) - line_text += color_str - if source_map is not None: - pos = len(line_text) - if source_start != pos and source is not _NO_SOURCE: - line_source_map.append((source_start, pos, source)) - source_map.append(line_source_map) - lines.append(_Line(line_text, k, line_annotations)) - lines = _align_annotations(lines) - out = "\n".join( - l.text if l.annotations is None - else f"{l.text}{annotation_prefix}{l.annotations}" for l in lines) - color_state, color_str = _update_color(use_color, color_state, - default_colors) - return out + color_str - - - - -# Public API. - -def nil() -> Doc: - """An empty document.""" - return _nil - -def text(s: str, annotation: str | None = None) -> Doc: - """Literal text.""" - return _TextDoc(s, annotation) - -def concat(docs: Sequence[Doc]) -> Doc: - """Concatenation of documents.""" - docs = list(docs) - if len(docs) == 1: - return docs[0] - return _ConcatDoc(docs) + raise ValueError("Invalid document ", doc) + + if len(line_annotations) > 0: + color_state, color_str = _update_color(use_color, color_state, + annotation_colors) + line_text += color_str + if source_map is not None: + pos = len(line_text) + if source_start != pos and source is not _NO_SOURCE: + line_source_map.append((source_start, pos, source)) + source_map.append(line_source_map) + lines.append(_Line(line_text, k, line_annotations)) + out = "\n".join(_align_annotations(lines, annotation_prefix)) + _, color_str = _update_color(use_color, color_state, + default_colors) + return out + color_str -def brk(text: str = " ") -> Doc: - """A break. - Prints either as a newline or as `text`, depending on the enclosing group. - """ - return _BreakDoc(text) -def group(doc: Doc) -> Doc: - """Layout alternative groups. - Prints the group with its breaks as their text (typically spaces) if the - entire group would fit on the line when printed that way. Otherwise, breaks - inside the group as printed as newlines. - """ - return _GroupDoc(doc) + # Public API. -def nest(n: int, doc: Doc) -> Doc: - """Increases the indentation level by `n`.""" - return _NestDoc(n, doc) + def nil() -> Doc: + """An empty document.""" + return _nil + def text(s: str, annotation: str | None = None) -> Doc: + """Literal text.""" + return _TextDoc(s, annotation) -def color(doc: Doc, *, foreground: Color | None = None, - background: Color | None = None, - intensity: Intensity | None = None): - """ANSI colors. 
+ def concat(docs: Sequence[Doc]) -> Doc: + """Concatenation of documents.""" + docs = list(docs) + if len(docs) == 1: + return docs[0] + return _ConcatDoc(docs) - Overrides the foreground/background/intensity of the text for the child doc. - Requires use_colors=True to be set when printing and the `colorama` package - to be installed; otherwise does nothing. - """ - return _ColorDoc(doc, foreground=foreground, background=background, - intensity=intensity) + def brk(text: str = " ") -> Doc: + """A break. + Prints either as a newline or as `text`, depending on the enclosing group. + """ + return _BreakDoc(text) + + def group(doc: Doc) -> Doc: + """Layout alternative groups. -def source_map(doc: Doc, source: Any): - """Source mapping. + Prints the group with its breaks as their text (typically spaces) if the + entire group would fit on the line when printed that way. Otherwise, breaks + inside the group as printed as newlines. + """ + return _GroupDoc(doc) + + def nest(n: int, doc: Doc) -> Doc: + """Increases the indentation level by `n`.""" + return _NestDoc(n, doc) + + + def color(doc: Doc, *, foreground: Color | None = None, + background: Color | None = None, + intensity: Intensity | None = None): + """ANSI colors. + + Overrides the foreground/background/intensity of the text for the child doc. + Requires use_colors=True to be set when printing and the `colorama` package + to be installed; otherwise does nothing. + """ + return _ColorDoc(doc, foreground=foreground, background=background, + intensity=intensity) + + + def source_map(doc: Doc, source: Any): + """Source mapping. + + A source map associates a region of the pretty-printer's text output with a + source location that produced it. For the purposes of the pretty printer a + ``source`` may be any object: we require only that we can compare sources for + equality. A text region to source object mapping can be populated as a side + output of the ``format`` method. + """ + return _SourceMapDoc(doc, source) + +else: + Color = _pretty_printer.Color + Intensity = _pretty_printer.Intensity + Doc = _pretty_printer.Doc + def _format( + self, width: int = 80, *, use_color: bool | None = None, + annotation_prefix: str = " # ", + source_map: list[list[tuple[int, int, Any]]] | None = None + ) -> str: + """ + Formats a pretty-printer document as a string. + + Args: + source_map: for each line in the output, contains a list of + (start column, end column, source) tuples. Each tuple associates a + region of output text with a source. + """ + if use_color is None: + use_color = CAN_USE_COLOR and _PPRINT_USE_COLOR.value + return self._format( + width, use_color=use_color, annotation_prefix=annotation_prefix, + source_map=source_map) + Doc.format = _format + Doc.__str__ = lambda self: self.format() + nil = _pretty_printer.nil + text = _pretty_printer.text + concat = _pretty_printer.concat + brk = _pretty_printer.brk + group = _pretty_printer.group + nest = _pretty_printer.nest + color = _pretty_printer.color + source_map = _pretty_printer.source_map - A source map associates a region of the pretty-printer's text output with a - source location that produced it. For the purposes of the pretty printer a - ``source`` may be any object: we require only that we can compare sources for - equality. A text region to source object mapping can be populated as a side - output of the ``format`` method. 
- """ - return _SourceMapDoc(doc, source) type_annotation = partial(color, intensity=Intensity.NORMAL, foreground=Color.MAGENTA) @@ -480,6 +531,8 @@ def join(sep: Doc, docs: Sequence[Doc]) -> Doc: docs = list(docs) if len(docs) == 0: return nil() + if len(docs) == 1: + return docs[0] xs = [docs[0]] for doc in docs[1:]: xs.append(sep) diff --git a/jax/_src/prng.py b/jax/_src/prng.py index 2fa9b2b37aa4..47ce6f6a7ed9 100644 --- a/jax/_src/prng.py +++ b/jax/_src/prng.py @@ -31,6 +31,7 @@ from jax._src import core from jax._src import dispatch from jax._src import dtypes +from jax._src import ffi from jax._src import pretty_printer as pp from jax._src import source_info_util from jax._src import tree_util as tree_util_internal @@ -50,7 +51,8 @@ from jax._src.numpy.array_methods import ( _array_operators, _set_array_base_attributes, _IndexUpdateHelper) from jax._src.sharding_impls import ( - NamedSharding, PmapSharding, physical_sharding, logical_sharding) + NamedSharding, PmapSharding, SingleDeviceSharding, physical_sharding, + logical_sharding) from jax._src.typing import Array from jax._src.util import safe_map, safe_zip @@ -64,6 +66,13 @@ UINT_DTYPES = { 8: jnp.uint8, 16: jnp.uint16, 32: jnp.uint32, 64: jnp.uint64} +if hasattr(gpu_prng, "registrations"): + for platform, targets in gpu_prng.registrations().items(): + for name, value, api_version in targets: + ffi.register_ffi_target( + name, value, platform=platform, api_version=api_version + ) + # -- PRNG implementation interface class PRNGImpl(NamedTuple): @@ -105,7 +114,7 @@ def pprint(self): ])))) -prngs = {} +prngs: dict[str, PRNGImpl] = {} def register_prng(impl: PRNGImpl): if impl.name in prngs: @@ -148,7 +157,7 @@ class behave like an array whose base elements are keys, hiding the # device_buffer, device_buffers, __cuda_interface__() _impl: PRNGImpl - _base_array: typing.Array + _base_array: jax.Array _consumed: bool | np.ndarray # Used in jax.experimental.key_reuse. _source_info: None | source_info_util.SourceInfo = None @@ -156,8 +165,16 @@ def __init__(self, impl, key_data: Any): assert not isinstance(key_data, core.Tracer) _check_prng_key_data(impl, key_data) self._impl = impl - self._base_array = key_data self._consumed = False # TODO(jakevdp): default to True here? 
+ if isinstance(key_data, np.ndarray): + aval = core.get_aval(key_data) + device = pxla.get_default_device() + key_data = pxla.batched_device_put(aval, SingleDeviceSharding(device), + [key_data], [device], committed=False) + self._base_array = key_data + + def _replace_with(self, value: PRNGKeyArray): + self._base_array._replace_with(value._base_array) def block_until_ready(self): _ = self._base_array.block_until_ready() @@ -168,9 +185,8 @@ def copy_to_host_async(self): @property def aval(self): - logical_sharding = (self.sharding if hasattr(self._base_array, 'sharding') - else None) - return keys_shaped_array(self._impl, self.shape, logical_sharding) + vma = self._base_array.aval.vma + return keys_shaped_array(self._impl, self.shape, self.sharding, vma) @property def shape(self): @@ -188,6 +204,10 @@ def ndim(self): def dtype(self): return KeyTy(self._impl) + @property + def nbytes(self): + return self.itemsize * self.size + @property def itemsize(self): return self.dtype.itemsize @@ -321,8 +341,8 @@ def seed_with_impl(impl: PRNGImpl, seed: int | typing.ArrayLike) -> PRNGKeyArray return random_seed(seed, impl=impl) -def keys_shaped_array(impl, shape, sharding): - aval = core.ShapedArray(shape, KeyTy(impl)) +def keys_shaped_array(impl, shape, sharding, vma): + aval = core.ShapedArray(shape, KeyTy(impl), vma=vma) return core.update_aval_with_sharding(aval, sharding) def base_arr_shape_to_keys_shape(impl, base_arr_shape): @@ -415,7 +435,6 @@ def device_put_sharded(vals, aval, sharding, devices): @staticmethod def device_put_replicated(val, aval, sharding, devices): physical_aval = core.physical_aval(aval) - assert len(xla.aval_to_xla_shapes(physical_aval)) == 1 physical_buf = random_unwrap(val) phys_sharding = physical_sharding(aval, sharding) physical_result = pxla.batched_device_put( @@ -542,7 +561,8 @@ def random_seed(seeds: int | typing.ArrayLike, impl: PRNGImpl) -> PRNGKeyArray: @random_seed_p.def_abstract_eval def random_seed_abstract_eval(seeds_aval, *, impl): - return keys_shaped_array(impl, seeds_aval.shape, seeds_aval.sharding) + return keys_shaped_array(impl, seeds_aval.shape, seeds_aval.sharding, + seeds_aval.vma) @random_seed_p.def_impl def random_seed_impl(seeds, *, impl): @@ -577,7 +597,7 @@ def random_split_abstract_eval(keys_aval, *, shape): # don't choose None here? 
new_spec = (*keys_aval.sharding.spec, *[None] * len(shape)) return keys_shaped_array(keys_aval.dtype._impl, (*keys_aval.shape, *shape), - keys_aval.sharding.with_spec(new_spec)) + keys_aval.sharding.update(spec=new_spec), keys_aval.vma) @random_split_p.def_impl def random_split_impl(keys, *, shape): @@ -603,7 +623,9 @@ def random_split_lowering(ctx, keys, *, shape): def random_fold_in(keys, msgs): - return random_fold_in_p.bind(keys, jnp.asarray(msgs)) + msgs = jnp.asarray(msgs) + keys, msgs = core.standard_insert_pvary(keys, msgs) + return random_fold_in_p.bind(keys, msgs) random_fold_in_p = core.Primitive('random_fold_in') ad.defjvp_zero(random_fold_in_p) @@ -613,7 +635,10 @@ def random_fold_in(keys, msgs): def random_fold_in_abstract_eval(keys_aval, msgs_aval): shape = lax_internal.broadcasting_shape_rule( 'random_fold_in', keys_aval, msgs_aval) - return core.ShapedArray(shape, keys_aval.dtype) + sharding = lax_internal.broadcasting_sharding_rule( + 'random_fold_in', keys_aval, msgs_aval) + vma = core.standard_vma_rule('random_fold_in', keys_aval, msgs_aval) + return core.ShapedArray(shape, keys_aval.dtype, sharding=sharding, vma=vma) @random_fold_in_p.def_impl def random_fold_in_impl(keys, msgs): @@ -651,7 +676,7 @@ def random_bits(keys, bit_width, shape): def random_bits_abstract_eval(keys_aval, *, bit_width, shape): out_shape = (*keys_aval.shape, *shape) out_dtype = dtypes.dtype(f'uint{bit_width}') - return core.ShapedArray(out_shape, out_dtype) + return core.ShapedArray(out_shape, out_dtype, vma=keys_aval.vma) @random_bits_p.def_impl def random_bits_impl(keys, *, bit_width, shape): @@ -708,7 +733,7 @@ def random_wrap(base_arr, *, impl): def random_wrap_abstract_eval(base_arr_aval, *, impl): shape = base_arr_shape_to_keys_shape(impl, base_arr_aval.shape) sharding = logical_sharding(shape, KeyTy(impl), base_arr_aval.sharding) - return keys_shaped_array(impl, shape, sharding) + return keys_shaped_array(impl, shape, sharding, base_arr_aval.vma) @random_wrap_p.def_impl def random_wrap_impl(base_arr, *, impl): @@ -902,7 +927,7 @@ def _threefry2x32_lowering(key1, key2, x1, x2, use_rolled_loops=True): multiple_results=True) -def _threefry2x32_gpu_lowering_rule(lowering_func, ctx, k1, k2, x1, x2): +def _threefry2x32_gpu_lowering_rule(ctx, k1, k2, x1, x2, *, target_name_prefix): if not config.threefry_gpu_kernel_lowering.value: # back to default lowering return _threefry2x32_lowering_rule(ctx, k1, k2, x1, x2) @@ -917,23 +942,11 @@ def _broadcast(x, aval): return mlir.broadcast_in_dim(ctx, x, aval_out, broadcast_dimensions=range(rank - len(aval.shape), rank)) - out_len = reduce(op.mul, aval_out.shape, 1) - if not core.is_constant_dim(out_len): - length = mlir.eval_dynamic_shape_as_tensor(ctx, [out_len]) - length = mlir.hlo.convert( - ir.RankedTensorType.get((1,), ir.IntegerType.get_signless(64)), - length) - output_shape = mlir.eval_dynamic_shape_as_tensor(ctx, aval_out.shape) - else: - length = int(out_len) # will be passed statically - output_shape = None - - return lowering_func( - (_broadcast(k1, k1_aval), _broadcast(k2, k2_aval)), - (_broadcast(x1, x1_aval), _broadcast(x2, x2_aval)), length, - output_shape, - False, # forward_compatibility_mode - ) + sub_ctx = ctx.replace(avals_in=(aval_out,) * 4) + rule = ffi.ffi_lowering( + f"{target_name_prefix}_threefry2x32_ffi") + return rule(sub_ctx, _broadcast(k1, k1_aval), _broadcast(k2, k2_aval), + _broadcast(x1, x1_aval), _broadcast(x2, x2_aval)) threefry2x32_p = core.Primitive("threefry2x32") @@ -947,11 +960,11 @@ def _broadcast(x, aval): 
threefry2x32_p, _threefry2x32_cpu_lowering_rule, platform='cpu') mlir.register_lowering( threefry2x32_p, - partial(_threefry2x32_gpu_lowering_rule, gpu_prng.cuda_threefry2x32), + partial(_threefry2x32_gpu_lowering_rule, target_name_prefix='cu'), platform='cuda') mlir.register_lowering( threefry2x32_p, - partial(_threefry2x32_gpu_lowering_rule, gpu_prng.rocm_threefry2x32), + partial(_threefry2x32_gpu_lowering_rule, target_name_prefix='hip'), platform='rocm') @@ -1294,3 +1307,20 @@ def _unsafe_rbg_fold_in(key: typing.Array, data: typing.Array) -> typing.Array: tag='urbg') register_prng(unsafe_rbg_prng_impl) + + +# Register export serialization for PRNG key types. +try: + from jax._src.export import serialization # pytype: disable=import-error + from jax._src.export import serialization_generated as ser_flatbuf # pytype: disable=import-error +except ImportError: + # This can happen if flatbuffers is not installed, in which case export + # serialization is not supported and it is safe to skip the registration. + pass +else: + serialization.register_dtype_kind( + KeyTy(prngs["threefry2x32"]), ser_flatbuf.DType.key_fry) + serialization.register_dtype_kind( + KeyTy(prngs["rbg"]), ser_flatbuf.DType.key_rbg) + serialization.register_dtype_kind( + KeyTy(prngs["unsafe_rbg"]), ser_flatbuf.DType.key_unsafe_rbg) diff --git a/jax/_src/profiler.py b/jax/_src/profiler.py index f06933f57e22..efdd7bd1e2a1 100644 --- a/jax/_src/profiler.py +++ b/jax/_src/profiler.py @@ -32,14 +32,18 @@ traceback_util.register_exclusion(__file__) from jax._src import xla_bridge -from jax._src.lib import xla_client +from jax._src.lib import _profiler -_profiler_server: xla_client.profiler.ProfilerServer | None = None +_profiler_server: _profiler.ProfilerServer | None = None logger = logging.getLogger(__name__) -def start_server(port: int) -> xla_client.profiler.ProfilerServer: +class ProfileOptions(_profiler.ProfileOptions): + """Profiler Options to configure the collectors for the profiler.""" + + +def start_server(port: int) -> _profiler.ProfilerServer: """Starts the profiler server on port `port`. Using the "TensorFlow profiler" feature in `TensorBoard @@ -59,7 +63,7 @@ def start_server(port: int) -> xla_client.profiler.ProfilerServer: # is for start_trace), but I'm putting it here to be safe. xla_bridge.get_backend() - _profiler_server = xla_client.profiler.start_server(port) + _profiler_server = _profiler.start_server(port) return _profiler_server @@ -89,12 +93,17 @@ def reset(self): _profile_state = _ProfileState() -def start_trace(log_dir: os.PathLike | str, create_perfetto_link: bool = False, - create_perfetto_trace: bool = False) -> None: +def start_trace( + log_dir: os.PathLike | str, + create_perfetto_link: bool = False, + create_perfetto_trace: bool = False, + profiler_options: ProfileOptions | None = None, +) -> None: """Starts a profiler trace. The trace will capture CPU, GPU, and/or TPU activity, including Python - functions and JAX on-device operations. Use :func:`stop_trace` to end the trace + functions and JAX on-device operations. Use :func:`stop_trace` to end the + trace and save the results to ``log_dir``. The resulting trace can be viewed with TensorBoard. Note that TensorBoard @@ -113,8 +122,8 @@ def start_trace(log_dir: os.PathLike | str, create_perfetto_link: bool = False, ``perfetto_trace.json.gz`` file that is compatible for upload with the Perfetto trace viewer UI (https://ui.perfetto.dev). The file will also be generated if ``create_perfetto_link`` is true. 
This could be useful if you - want to generate a Perfetto-compatible trace without blocking the - process. + want to generate a Perfetto-compatible trace without blocking the process. + profiler_options: Profiler options to configure the profiler for collection. """ with _profile_state.lock: if _profile_state.profile_session is not None: @@ -126,7 +135,12 @@ def start_trace(log_dir: os.PathLike | str, create_perfetto_link: bool = False, # fail and no TPU operations will be included in the profile. xla_bridge.get_backend() - _profile_state.profile_session = xla_client.profiler.ProfilerSession() + if profiler_options is None: + _profile_state.profile_session = _profiler.ProfilerSession() + else: + _profile_state.profile_session = _profiler.ProfilerSession( + profiler_options + ) _profile_state.create_perfetto_link = create_perfetto_link _profile_state.create_perfetto_trace = ( create_perfetto_trace or create_perfetto_link) @@ -201,7 +215,7 @@ def stop_trace(): if _profile_state.profile_session is None: raise RuntimeError("No profile started") sess = _profile_state.profile_session - sess.export(sess.stop(), str(_profile_state.log_dir)) + sess.stop_and_export(str(_profile_state.log_dir)) # type: ignore if _profile_state.create_perfetto_trace: abs_filename = _write_perfetto_trace_file(_profile_state.log_dir) if _profile_state.create_perfetto_link: @@ -219,13 +233,18 @@ def stop_and_get_fdo_profile() -> bytes | str: if _profile_state.profile_session is None: raise RuntimeError("No profile started") xspace = _profile_state.profile_session.stop() - fdo_profile = xla_client.profiler.get_fdo_profile(xspace) + fdo_profile = _profiler.get_fdo_profile(xspace) _profile_state.reset() return fdo_profile @contextmanager -def trace(log_dir: os.PathLike | str, create_perfetto_link=False, create_perfetto_trace=False): +def trace( + log_dir: os.PathLike | str, + create_perfetto_link=False, + create_perfetto_trace=False, + profiler_options: ProfileOptions | None = None, +): """Context manager to take a profiler trace. The trace will capture CPU, GPU, and/or TPU activity, including Python @@ -247,17 +266,19 @@ def trace(log_dir: os.PathLike | str, create_perfetto_link=False, create_perfett ``perfetto_trace.json.gz`` file that is compatible for upload with the Perfetto trace viewer UI (https://ui.perfetto.dev). The file will also be generated if ``create_perfetto_link`` is true. This could be useful if you - want to generate a Perfetto-compatible trace without blocking the - process. + want to generate a Perfetto-compatible trace without blocking the process. + profiler_options: Profiler options to configure the profiler for collection. """ - start_trace(log_dir, create_perfetto_link, create_perfetto_trace) + start_trace( + log_dir, create_perfetto_link, create_perfetto_trace, profiler_options + ) try: yield finally: stop_trace() -class TraceAnnotation(xla_client.profiler.TraceMe): +class TraceAnnotation(_profiler.TraceMe): """Context manager that generates a trace event in the profiler. The trace event spans the duration of the code enclosed by the context. @@ -271,7 +292,6 @@ class TraceAnnotation(xla_client.profiler.TraceMe): This will cause a "my_label" event to show up on the trace timeline if the event occurs while the process is being traced. 
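A rough usage sketch combining the new profiler_options plumbing with a trace annotation. It assumes ProfileOptions is re-exported under jax.profiler alongside the other profiler symbols; the log directory is arbitrary.

import jax
import jax.numpy as jnp

x = jnp.ones((1024, 1024))
options = jax.profiler.ProfileOptions()  # assumed public alias of the options class added above
with jax.profiler.trace("/tmp/jax-trace", profiler_options=options):
    with jax.profiler.TraceAnnotation("my_label"):
        jnp.dot(x, x).block_until_ready()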
""" - pass class StepTraceAnnotation(TraceAnnotation): @@ -332,7 +352,6 @@ def annotate_function(func: Callable, name: str | None = None, def wrapper(*args, **kwargs): with TraceAnnotation(name, **decorator_kwargs): return func(*args, **kwargs) - return wrapper return wrapper @@ -361,7 +380,8 @@ def device_memory_profile(backend: str | None = None) -> bytes: Returns: A byte string containing a binary `pprof`-format protocol buffer. """ - return xla_client.heap_profile(xla_bridge.get_backend(backend)) + client = xla_bridge.get_backend(backend) + return gzip.compress(client.heap_profile()) def save_device_memory_profile(filename, backend: str | None = None) -> None: @@ -382,7 +402,7 @@ def save_device_memory_profile(filename, backend: str | None = None) -> None: # Allows to run model with profiler given amount of times. After required amount -# of retries achived client can collect FDO data. +# of retries achieved client can collect FDO data. class PGLEProfiler: def __init__(self, retries: int, percentile: int): @@ -391,7 +411,7 @@ def __init__(self, retries: int, percentile: int): self.collected_fdo: str | None = None self.called_times: int = 0 self.fdo_profiles: list[Any] = [] - self.current_session: xla_client.profiler.ProfilerSession | None = None + self.current_session: _profiler.ProfilerSession | None = None def consume_fdo_profile(self) -> str | None: if self.collected_fdo is not None: @@ -400,7 +420,7 @@ def consume_fdo_profile(self) -> str | None: if not self.is_enabled() or self.called_times != self.retries: return None - self.collected_fdo = xla_client.profiler.aggregate_profiled_instructions( + self.collected_fdo = _profiler.aggregate_profiled_instructions( self.fdo_profiles, self.percentile ) return self.collected_fdo @@ -424,16 +444,17 @@ def trace(cls, runner: PGLEProfiler | None): or not runner.is_enabled() or runner.is_fdo_consumed()): yield else: - options = xla_client.profiler.ProfileOptions() + options = _profiler.ProfileOptions() options.enable_hlo_proto = True - runner.current_session = xla_client.profiler.ProfilerSession(options) + options.raise_error_on_start_failure = True + runner.current_session = _profiler.ProfilerSession(options) try: yield finally: xspace = runner.current_session.stop() runner.fdo_profiles.append( - xla_client.profiler.get_fdo_profile(xspace) + _profiler.get_fdo_profile(xspace) ) runner.current_session = None diff --git a/jax/_src/public_test_util.py b/jax/_src/public_test_util.py index 455a3b98cce2..3b1e24bc9c50 100644 --- a/jax/_src/public_test_util.py +++ b/jax/_src/public_test_util.py @@ -14,6 +14,7 @@ from functools import partial import operator +from typing import Any, TypeAlias from jax._src import api from jax._src import config @@ -32,7 +33,7 @@ EPS = 1e-4 -def _dtype(x): +def _dtype(x: Any) -> np.dtype: if hasattr(x, 'dtype'): return x.dtype elif type(x) in _dtypes.python_scalar_dtypes: @@ -40,20 +41,27 @@ def _dtype(x): else: return np.asarray(x).dtype +ToleranceDict: TypeAlias = dict[np.dtype, int | float] -_default_tolerance = { +_default_tolerance: ToleranceDict = { _dtypes.float0: 0, np.dtype(np.bool_): 0, + np.dtype(_dtypes.int2): 0, np.dtype(_dtypes.int4): 0, np.dtype(np.int8): 0, np.dtype(np.int16): 0, np.dtype(np.int32): 0, np.dtype(np.int64): 0, + np.dtype(_dtypes.uint2): 0, np.dtype(_dtypes.uint4): 0, np.dtype(np.uint8): 0, np.dtype(np.uint16): 0, np.dtype(np.uint32): 0, np.dtype(np.uint64): 0, + np.dtype(_dtypes.float4_e2m1fn): 1e0, + np.dtype(_dtypes.float8_e3m4): 1e-1, + np.dtype(_dtypes.float8_e4m3): 1e-1, + 
np.dtype(_dtypes.float8_e8m0fnu): 1e0, np.dtype(_dtypes.float8_e4m3b11fnuz): 1e-1, np.dtype(_dtypes.float8_e4m3fn): 1e-1, np.dtype(_dtypes.float8_e4m3fnuz): 1e-1, @@ -67,16 +75,15 @@ def _dtype(x): np.dtype(np.complex128): 1e-15, } -if _dtypes.int2 is not None: - assert _dtypes.uint2 is not None - _default_tolerance[np.dtype(_dtypes.int2)] = 0 - _default_tolerance[np.dtype(_dtypes.uint2)] = 0 - def default_tolerance(): return _default_tolerance -default_gradient_tolerance = { +default_gradient_tolerance: ToleranceDict = { + np.dtype(_dtypes.float4_e2m1fn): 1e0, + np.dtype(_dtypes.float8_e3m4): 1e-1, + np.dtype(_dtypes.float8_e4m3): 1e-1, + np.dtype(_dtypes.float8_e8m0fnu): 1e0, np.dtype(_dtypes.float8_e4m3b11fnuz): 1e-1, np.dtype(_dtypes.float8_e4m3fn): 1e-1, np.dtype(_dtypes.float8_e4m3fnuz): 1e-1, @@ -90,21 +97,8 @@ def default_tolerance(): np.dtype(np.complex128): 1e-5, } -# TODO: make this unconditional when ml_dtypes>=0.5.0 is required -if _dtypes.float8_e3m4 is not None: - _default_tolerance[np.dtype(_dtypes.float8_e3m4)] = 1e-1 - default_gradient_tolerance[np.dtype(_dtypes.float8_e3m4)] = 1e-1 -if _dtypes.float8_e4m3 is not None: - _default_tolerance[np.dtype(_dtypes.float8_e4m3)] = 1e-1 - default_gradient_tolerance[np.dtype(_dtypes.float8_e4m3)] = 1e-1 -if _dtypes.float8_e8m0fnu is not None: - _default_tolerance[np.dtype(_dtypes.float8_e8m0fnu)] = 1e0 - default_gradient_tolerance[np.dtype(_dtypes.float8_e8m0fnu)] = 1e0 -if _dtypes.float4_e2m1fn is not None: - _default_tolerance[np.dtype(_dtypes.float4_e2m1fn)] = 1e0 - default_gradient_tolerance[np.dtype(_dtypes.float4_e2m1fn)] = 1e0 - -def is_python_scalar(val): + +def is_python_scalar(val: Any) -> bool: return not isinstance(val, np.generic) and isinstance(val, (bool, int, float, complex)) def _assert_numpy_allclose(a, b, atol=None, rtol=None, err_msg=''): @@ -113,6 +107,10 @@ def _assert_numpy_allclose(a, b, atol=None, rtol=None, err_msg=''): return custom_float_dtypes = [ + _dtypes.float4_e2m1fn, + _dtypes.float8_e8m0fnu, + _dtypes.float8_e3m4, + _dtypes.float8_e4m3, _dtypes.float8_e4m3b11fnuz, _dtypes.float8_e4m3fn, _dtypes.float8_e4m3fnuz, @@ -121,15 +119,6 @@ def _assert_numpy_allclose(a, b, atol=None, rtol=None, err_msg=''): _dtypes.bfloat16, ] - if _dtypes.float8_e4m3 is not None: - custom_float_dtypes.insert(0, _dtypes.float8_e4m3) - if _dtypes.float8_e3m4 is not None: - custom_float_dtypes.insert(0, _dtypes.float8_e3m4) - if _dtypes.float8_e8m0fnu is not None: - custom_float_dtypes.insert(0, _dtypes.float8_e8m0fnu) - if _dtypes.float4_e2m1fn is not None: - custom_float_dtypes.insert(0, _dtypes.float4_e2m1fn) - def maybe_upcast(x): if x.dtype in custom_float_dtypes: return x.astype(np.float32) @@ -151,7 +140,8 @@ def maybe_upcast(x): # value errors. It should not do that. 
np.testing.assert_allclose(a, b, **kw, err_msg=err_msg) -def tolerance(dtype, tol=None): + +def tolerance(dtype: np.dtype, tol: int | float | ToleranceDict | None = None) -> int | float: tol = {} if tol is None else tol if not isinstance(tol, dict): return tol diff --git a/jax/_src/random.py b/jax/_src/random.py index 094268c65825..d44361ebb3a6 100644 --- a/jax/_src/random.py +++ b/jax/_src/random.py @@ -39,8 +39,11 @@ from jax._src.interpreters import batching from jax._src.interpreters import mlir from jax._src.lax import lax as lax_internal +from jax._src.nn.functions import softmax from jax._src.numpy.lax_numpy import _convert_and_clip_integer from jax._src.numpy.util import _arraylike, check_arraylike, promote_dtypes_inexact +from jax._src.pjit import auto_axes +from jax._src.sharding_impls import canonicalize_sharding from jax._src.typing import Array, ArrayLike, DTypeLike from jax._src.util import canonicalize_axis @@ -223,7 +226,7 @@ def PRNGKey(seed: int | ArrayLike, *, This function produces old-style legacy PRNG keys, which are arrays of dtype ``uint32``. For more, see the note in the `PRNG keys - `_ + `_ section. When possible, :func:`jax.random.key` is recommended for use instead. @@ -346,9 +349,19 @@ def _check_shape(name: str, shape: Shape, *param_shapes) -> None: raise ValueError(msg.format(name, shape_, shape)) +def maybe_auto_axes(f, out_sharding, **hoist_kwargs): + f_ = partial(f, **hoist_kwargs) + if out_sharding is None: + return f_ + else: + return auto_axes(f_, out_sharding=out_sharding) + + def bits(key: ArrayLike, shape: Shape = (), - dtype: DTypeLikeUInt | None = None) -> Array: + dtype: DTypeLikeUInt | None = None, + *, + out_sharding=None) -> Array: """Sample uniform bits in the form of unsigned integers. Args: @@ -371,15 +384,19 @@ def bits(key: ArrayLike, f"got {dtype}") dtype = dtypes.canonicalize_dtype(dtype) shape = core.canonicalize_shape(shape) + out_sharding = canonicalize_sharding(out_sharding, "bits") bit_width = dtype.itemsize * 8 - return _random_bits(key, bit_width, shape) + return maybe_auto_axes(_random_bits, out_sharding, + bit_width=bit_width, shape=shape)(key) def uniform(key: ArrayLike, shape: Shape = (), dtype: DTypeLikeFloat = float, minval: RealArray = 0., - maxval: RealArray = 1.) -> Array: + maxval: RealArray = 1., + *, + out_sharding=None) -> Array: """Sample uniform random values in [minval, maxval) with given shape/dtype. 
Args: @@ -397,15 +414,17 @@ def uniform(key: ArrayLike, key, _ = _check_prng_key("uniform", key) dtypes.check_user_dtype_supported(dtype) shape = core.canonicalize_shape(shape) + out_sharding = canonicalize_sharding(out_sharding, "uniform") if not dtypes.issubdtype(dtype, np.floating): raise ValueError(f"dtype argument to `uniform` must be a float dtype, " f"got {dtype}") dtype = dtypes.canonicalize_dtype(dtype) - return _uniform(key, shape, dtype, minval, maxval) + return maybe_auto_axes(_uniform, out_sharding, + shape=shape,dtype=dtype)(key, minval, maxval) -@partial(jit, static_argnums=(1, 2)) -def _uniform(key, shape, dtype, minval, maxval) -> Array: +@partial(jit, static_argnums=(3, 4)) +def _uniform(key, minval, maxval, shape, dtype) -> Array: _check_shape("uniform", shape) if not jnp.issubdtype(dtype, np.floating): raise TypeError("uniform only accepts floating point dtypes.") @@ -449,7 +468,9 @@ def randint(key: ArrayLike, shape: Shape, minval: IntegerArray, maxval: IntegerArray, - dtype: DTypeLikeInt = int) -> Array: + dtype: DTypeLikeInt = int, + *, + out_sharding=None) -> Array: """Sample uniform random values in [minval, maxval) with given shape/dtype. Args: @@ -469,10 +490,12 @@ def randint(key: ArrayLike, dtypes.check_user_dtype_supported(dtype) dtype = dtypes.canonicalize_dtype(dtype) shape = core.canonicalize_shape(shape) - return _randint(key, shape, minval, maxval, dtype) + out_sharding = canonicalize_sharding(out_sharding, "randint") + return maybe_auto_axes(_randint, out_sharding, shape=shape, dtype=dtype)( + key, minval, maxval) -@partial(jit, static_argnums=(1, 4)) -def _randint(key, shape, minval, maxval, dtype) -> Array: +@partial(jit, static_argnums=(3, 4)) +def _randint(key, minval, maxval, shape, dtype) -> Array: _check_shape("randint", shape, np.shape(minval), np.shape(maxval)) if not jnp.issubdtype(dtype, np.integer): raise TypeError(f"randint only accepts integer dtypes, got {dtype}") @@ -537,7 +560,9 @@ def _randint(key, shape, minval, maxval, dtype) -> Array: def permutation(key: ArrayLike, x: int | ArrayLike, axis: int = 0, - independent: bool = False) -> Array: + independent: bool = False, + *, + out_sharding=None) -> Array: """Returns a randomly permuted array or range. Args: @@ -554,11 +579,17 @@ def permutation(key: ArrayLike, key, _ = _check_prng_key("permutation", key) check_arraylike("permutation", x) axis = canonicalize_axis(axis, np.ndim(x) or 1) + out_sharding = canonicalize_sharding(out_sharding, "permutation") if not np.ndim(x): if not np.issubdtype(lax.dtype(x), np.integer): raise TypeError("x must be an integer or at least 1-dimensional") - r = core.concrete_or_error(int, x, 'argument x of jax.random.permutation()') - return _shuffle(key, jnp.arange(r), axis) + r = core.concrete_or_error(int, x, "argument x of jax.random.permutation()") + return maybe_auto_axes(lambda key: _shuffle(key, jnp.arange(r), axis), + out_sharding)(key) + return maybe_auto_axes( + _permutation, out_sharding, axis=axis, independent=independent)(key, x) + +def _permutation(key, x, axis, independent): if independent or np.ndim(x) == 1: return _shuffle(key, x, axis) ind = _shuffle(key, jnp.arange(x.shape[axis]), 0) # type: ignore[union-attr] @@ -602,7 +633,8 @@ def choice(key: ArrayLike, shape: Shape = (), replace: bool = True, p: RealArray | None = None, - axis: int = 0) -> Array: + axis: int = 0, + mode: str | None = None) -> Array: """Generates a random sample from a given array. .. warning:: @@ -625,6 +657,12 @@ def choice(key: ArrayLike, entries in a. 
axis: int, optional. The axis along which the selection is performed. The default, 0, selects by row. + mode: optional, "high" or "low" for how many bits to use in the gumbel sampler + when `p is None` and `replace = False`. The default is determined by the + ``use_high_dynamic_range_gumbel`` config, which defaults to "low". With mode="low", + in float32 sampling will be biased for choices with probability less than about + 1E-7; with mode="high" this limit is pushed down to about 1E-14. mode="high" + approximately doubles the cost of sampling. Returns: An array of shape `shape` containing samples from `a`. @@ -670,7 +708,7 @@ def choice(key: ArrayLike, ind = jnp.searchsorted(p_cuml, r).astype(int) else: # Gumbel top-k trick: https://timvieira.github.io/blog/post/2019/09/16/algorithms-for-sampling-without-replacement/ - g = gumbel(key, (n_inputs,), dtype=p_arr.dtype) + jnp.log(p_arr) + g = gumbel(key, (n_inputs,), dtype=p_arr.dtype, mode=mode) + jnp.log(p_arr) ind = lax.top_k(g, k=n_draws)[1].astype(int) result = ind if arr.ndim == 0 else jnp.take(arr, ind, axis) @@ -680,7 +718,9 @@ def choice(key: ArrayLike, def normal(key: ArrayLike, shape: Shape = (), - dtype: DTypeLikeFloat = float) -> Array: + dtype: DTypeLikeFloat = float, + *, + out_sharding=None) -> Array: r"""Sample standard normal random values with given shape and float dtype. The values are returned according to the probability density function: @@ -702,12 +742,13 @@ def normal(key: ArrayLike, """ key, _ = _check_prng_key("normal", key) shape = core.canonicalize_shape(shape) + out_sharding = canonicalize_sharding(out_sharding, "normal") dtypes.check_user_dtype_supported(dtype) if not dtypes.issubdtype(dtype, np.inexact): raise ValueError(f"dtype argument to `normal` must be a float or complex dtype, " f"got {dtype}") dtype = dtypes.canonicalize_dtype(dtype) - return _normal(key, shape, dtype) + return maybe_auto_axes(_normal, out_sharding, shape=shape, dtype=dtype)(key) @partial(jit, static_argnums=(1, 2)) def _normal(key, shape, dtype) -> Array: @@ -818,7 +859,8 @@ def truncated_normal(key: ArrayLike, lower: RealArray, upper: RealArray, shape: Shape | None = None, - dtype: DTypeLikeFloat = float) -> Array: + dtype: DTypeLikeFloat = float, + *, out_sharding=None) -> Array: r"""Sample truncated standard normal random values with given shape and dtype. The values are returned according to the probability density function: @@ -849,12 +891,14 @@ def truncated_normal(key: ArrayLike, if shape is not None: shape = core.canonicalize_shape(shape) key, _ = _check_prng_key("truncated_normal", key) + out_sharding = canonicalize_sharding(out_sharding, "truncated_normal") dtypes.check_user_dtype_supported(dtype) if not dtypes.issubdtype(dtype, np.floating): raise ValueError(f"dtype argument to `truncated_normal` must be a float " f"dtype, got {dtype}") dtype = dtypes.canonicalize_dtype(dtype) - return _truncated_normal(key, lower, upper, shape, dtype) + return maybe_auto_axes(_truncated_normal, out_sharding, + shape=shape, dtype=dtype)(key, lower, upper) @partial(jit, static_argnums=(3, 4)) def _truncated_normal(key, lower, upper, shape, dtype) -> Array: @@ -882,7 +926,8 @@ def _truncated_normal(key, lower, upper, shape, dtype) -> Array: def bernoulli(key: ArrayLike, p: RealArray = np.float32(0.5), - shape: Shape | None = None) -> Array: + shape: Shape | None = None, + mode: str = 'low') -> Array: r"""Sample Bernoulli random values with given shape and mean. 
The values are distributed according to the probability mass function: @@ -899,6 +944,11 @@ def bernoulli(key: ArrayLike, shape: optional, a tuple of nonnegative integers representing the result shape. Must be broadcast-compatible with ``p.shape``. The default (None) produces a result shape equal to ``p.shape``. + mode: optional, "high" or "low" for how many bits to use when sampling. + default='low'. Set to "high" for correct sampling at small values of + `p`. When sampling in float32, bernoulli samples with mode='low' produce + incorrect results for p < ~1E-7. mode="high" approximately doubles the + cost of sampling. Returns: A random array with boolean dtype and shape given by ``shape`` if ``shape`` @@ -906,23 +956,33 @@ def bernoulli(key: ArrayLike, """ if shape is not None: shape = core.canonicalize_shape(shape) + if mode not in ['high', 'low']: + raise ValueError(f"got {mode=}, expected 'high' or 'low'") key, _ = _check_prng_key("bernoulli", key) dtype = dtypes.canonicalize_dtype(lax.dtype(p)) if not jnp.issubdtype(dtype, np.floating): msg = "bernoulli probability `p` must have a floating dtype, got {}." raise TypeError(msg.format(dtype)) p = lax.convert_element_type(p, dtype) - return _bernoulli(key, p, shape) + return _bernoulli(key, p, shape, mode=mode) -@partial(jit, static_argnums=(2,)) -def _bernoulli(key, p, shape) -> Array: + +@partial(jit, static_argnames=['shape', 'mode']) +def _bernoulli(key: Array, p: Array, shape: Shape | None, mode: str) -> Array: if shape is None: # TODO: Use the named part of `p` as well shape = np.shape(p) else: _check_shape("bernoulli", shape, np.shape(p)) + dtype = lax.dtype(p) - return uniform(key, shape, lax.dtype(p)) < p + if mode == 'high': + u1, u2 = uniform(key, (2, *shape), dtype) + # resolution of uniform samples is 2 ** -n_mantissa + u2 *= 2 ** -dtypes.finfo(dtype).nmant + return u2 < p - u1 + else: + return uniform(key, shape, lax.dtype(p)) < p def beta(key: ArrayLike, @@ -1085,16 +1145,7 @@ def _dirichlet(key, alpha, shape, dtype) -> Array: # Compute gamma in log space, otherwise small alpha can lead to poor behavior. log_gamma_samples = loggamma(key, alpha, shape + np.shape(alpha)[-1:], dtype) - return _softmax(log_gamma_samples, -1) - - -def _softmax(x, axis) -> Array: - """Utility to compute the softmax of x along a given axis.""" - if not dtypes.issubdtype(x.dtype, np.floating): - raise TypeError(f"_softmax only accepts floating dtypes, got {x.dtype}") - x_max = jnp.max(x, axis, keepdims=True) - unnormalized = jnp.exp(x - lax.stop_gradient(x_max)) - return unnormalized / unnormalized.sum(axis, keepdims=True) + return softmax(log_gamma_samples, -1) def exponential(key: ArrayLike, @@ -1501,7 +1552,7 @@ def poisson(key: ArrayLike, def gumbel(key: ArrayLike, shape: Shape = (), dtype: DTypeLikeFloat = float, - mode: str | None =None) -> Array: + mode: str | None = None) -> Array: """Sample Gumbel random values with given shape and float dtype. The values are distributed according to the probability density function: @@ -1516,6 +1567,11 @@ def gumbel(key: ArrayLike, dtype: optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32). mode: optional, "high" or "low" for how many bits to use when sampling. + The default is determined by the ``use_high_dynamic_range_gumbel`` config, + which defaults to "low". 
When drawing float32 samples, with mode="low" the + uniform resolution is such that the largest possible gumbel logit is ~16; + with mode="high" this is increased to ~32, at approximately double the + computational cost. Returns: A random array with the specified shape and dtype. @@ -1537,7 +1593,8 @@ def gumbel(key: ArrayLike, def _gumbel(key, shape, dtype, mode) -> Array: _check_shape("gumbel", shape) if mode == "high": - high, low = _uniform(key, (2,) + shape, dtype, minval=0., maxval=1.) + high, low = _uniform(key, minval=0., maxval=1., + shape=(2,) + shape, dtype=dtype) # TODO(parkers): The condition is to protect against rounding up but # we should be able to add safely with the right addition operation. x = jnp.where(high >= 0.5, high, @@ -1545,7 +1602,8 @@ def _gumbel(key, shape, dtype, mode) -> Array: return -jnp.log(-jnp.log1p(-x)) else: return -jnp.log(-jnp.log( - _uniform(key, shape, dtype, minval=jnp.finfo(dtype).tiny, maxval=1.))) + _uniform(key, minval=jnp.finfo(dtype).tiny, maxval=1., + shape=shape, dtype=dtype))) def categorical( @@ -1554,6 +1612,7 @@ def categorical( axis: int = -1, shape: Shape | None = None, replace: bool = True, + mode: str | None = None, ) -> Array: """Sample random values from categorical distributions. @@ -1568,8 +1627,14 @@ def categorical( shape: Optional, a tuple of nonnegative integers representing the result shape. Must be broadcast-compatible with ``np.delete(logits.shape, axis)``. The default (None) produces a result shape equal to ``np.delete(logits.shape, axis)``. - replace: If True, perform sampling without replacement. Default (False) is to - perform sampling with replacement. + replace: If True (default), perform sampling with replacement. If False, perform + sampling without replacement. + mode: optional, "high" or "low" for how many bits to use in the gumbel sampler. + The default is determined by the ``use_high_dynamic_range_gumbel`` config, + which defaults to "low". With mode="low", in float32 sampling will be biased + for events with probability less than about 1E-7; with mode="high" this limit + is pushed down to about 1E-14. mode="high" approximately doubles the cost of + sampling. Returns: A random array with int dtype and shape given by ``shape`` if ``shape`` @@ -1599,11 +1664,11 @@ def categorical( logits_shape = list(shape[len(shape) - len(batch_shape):]) logits_shape.insert(axis % len(logits_arr.shape), logits_arr.shape[axis]) return jnp.argmax( - gumbel(key, (*shape_prefix, *logits_shape), logits_arr.dtype) + + gumbel(key, (*shape_prefix, *logits_shape), logits_arr.dtype, mode=mode) + lax.expand_dims(logits_arr, tuple(range(len(shape_prefix)))), axis=axis) else: - logits_arr += gumbel(key, logits_arr.shape, logits_arr.dtype) + logits_arr += gumbel(key, logits_arr.shape, logits_arr.dtype, mode=mode) k = math.prod(shape_prefix) if k > logits_arr.shape[axis]: raise ValueError( @@ -2110,7 +2175,7 @@ def orthogonal( m: an integer indicating the number of columns. Defaults to `n`. Returns: - A random array of shape `(*shape, n, n)` and specified dtype. + A random array of shape `(*shape, n, m)` and specified dtype. References: .. [1] Mezzadri, Francesco. (2007). 
"How to generate random matrices from diff --git a/jax/_src/scipy/linalg.py b/jax/_src/scipy/linalg.py index 9917cbaa0b12..5a1f6d988740 100644 --- a/jax/_src/scipy/linalg.py +++ b/jax/_src/scipy/linalg.py @@ -26,10 +26,10 @@ from jax import lax from jax._src import dtypes from jax._src.lax import linalg as lax_linalg -from jax._src.lax import qdwh from jax._src.numpy.util import ( check_arraylike, promote_dtypes, promote_dtypes_inexact, promote_dtypes_complex) +from jax._src.tpu.linalg import qdwh from jax._src.typing import Array, ArrayLike @@ -486,6 +486,8 @@ def _schur(a: Array, output: str) -> tuple[Array, Array]: def schur(a: ArrayLike, output: str = 'real') -> tuple[Array, Array]: """Compute the Schur decomposition + Only implemented on CPU. + JAX implementation of :func:`scipy.linalg.schur`. The Schur form `T` of a matrix `A` satisfies: @@ -1832,6 +1834,9 @@ def _sqrtm(A: ArrayLike) -> Array: def sqrtm(A: ArrayLike, blocksize: int = 1) -> Array: """Compute the matrix square root + This function is implemented using :func:`scipy.linalg.schur`, which is only + supported on CPU. + JAX implementation of :func:`scipy.linalg.sqrtm`. Args: @@ -2182,3 +2187,64 @@ def hilbert(n: int) -> Array: """ a = lax.broadcasted_iota(jnp.float64, (n, 1), 0) return 1/(a + a.T + 1) + +@partial(jit, static_argnames=("n", "kind",)) +def pascal(n: int, kind: str | None = None) -> Array: + r"""Create a Pascal matrix approximation of order n. + + JAX implementation of :func:`scipy.linalg.pascal`. + + The elements of the Pascal matrix approximate the binomial coefficients. This + implementation is not exact as JAX does not support exact factorials. + + Args: + n: the size of the matrix to create. + kind: (optional) must be one of ``lower``, ``upper``, or ``symmetric`` (default). + + Returns: + A Pascal matrix of shape ``(n, n)`` + + Examples: + >>> with jnp.printoptions(precision=3): + ... print(jax.scipy.linalg.pascal(3, kind="lower")) + ... print(jax.scipy.linalg.pascal(4, kind="upper")) + ... print(jax.scipy.linalg.pascal(5)) + [[1. 0. 0.] + [1. 1. 0.] + [1. 2. 1.]] + [[1. 1. 1. 1.] + [0. 1. 2. 3.] + [0. 0. 1. 3.] + [0. 0. 0. 1.]] + [[ 1. 1. 1. 1. 1.] + [ 1. 2. 3. 4. 5.] + [ 1. 3. 6. 10. 15.] + [ 1. 4. 10. 20. 35.] + [ 1. 5. 15. 35. 70.]] + """ + if kind is None: + kind = "symmetric" + + valid_kind = ["symmetric", "lower", "upper"] + + if kind not in valid_kind: + raise ValueError(f"Expected kind to be on of: {valid_kind}; got {kind}") + + a = jnp.arange(n, dtype=jnp.float32) + + L_n = _binom(a[:, None], a[None, :]) + + if kind == "lower": + return L_n + + if kind == "upper": + return L_n.T + + return jnp.dot(L_n, L_n.T) + +@jit +def _binom(n, k): + a = lax.lgamma(n + 1.0) + b = lax.lgamma(n - k + 1.0) + c = lax.lgamma(k + 1.0) + return lax.exp(a - b - c) diff --git a/jax/_src/scipy/signal.py b/jax/_src/scipy/signal.py index d950cd2ea395..d4ca7c2c6147 100644 --- a/jax/_src/scipy/signal.py +++ b/jax/_src/scipy/signal.py @@ -148,7 +148,7 @@ def _fftconvolve_unbatched(in1: Array, in2: Array, mode: str) -> Array: return lax.dynamic_slice(conv, start_indices, out_shape) -# Note: we do not re-use the code from jax.numpy.convolve here, because the handling +# Note: we do not reuse the code from jax.numpy.convolve here, because the handling # of padding differs slightly between the two implementations (particularly for # mode='same'). 
def _convolve_nd(in1: Array, in2: Array, mode: str, *, precision: PrecisionLike) -> Array: @@ -566,13 +566,9 @@ def _fft_helper(x: Array, win: Array, detrend_func: Callable[[Array], Array], result = x[..., np.newaxis] else: step = nperseg - noverlap - batch_shape = list(batch_shape) - x = x.reshape((math.prod(batch_shape), signal_length, 1)) - result = jax.lax.conv_general_dilated_patches( - x, (nperseg,), (step,), - 'VALID', - dimension_numbers=('NTC', 'OIT', 'NTC')) - result = result.reshape(*batch_shape, *result.shape[-2:]) + starts = jnp.arange(signal_length - nperseg + 1, step=step) + slice_func = partial(jax.lax.dynamic_slice_in_dim, operand=x, slice_size=nperseg, axis=-1) + result = jax.vmap(slice_func, out_axes=-2)(start_index=starts) # Detrend each data segment individually result = detrend_func(result) @@ -1034,16 +1030,16 @@ def _overlap_and_add(x: Array, step_size: int) -> Array: x = x.reshape((flat_batchsize, nframes, nstep_per_segment, step_size)) # For obtaining shifted signals, this routine reinterprets flattened array - # with a shrinked axis. With appropriate truncation/ padding, this operation + # with a shrunken axis. With appropriate truncation/ padding, this operation # pushes the last padded elements of the previous row to the head of the # current row. # See implementation of `overlap_and_add` in Tensorflow for details. x = x.transpose((0, 2, 1, 3)) # x: (B, S, N, T) x = jnp.pad(x, ((0, 0), (0, 0), (0, nframes), (0, 0))) # x: (B, S, N*2, T) - shrinked = x.shape[2] - 1 + shrunken = x.shape[2] - 1 x = x.reshape((flat_batchsize, -1)) - x = x[:, :(nstep_per_segment * shrinked * step_size)] - x = x.reshape((flat_batchsize, nstep_per_segment, shrinked * step_size)) + x = x[:, :(nstep_per_segment * shrunken * step_size)] + x = x.reshape((flat_batchsize, nstep_per_segment, shrunken * step_size)) # Finally, sum shifted segments, and truncate results to the output_size. x = x.sum(axis=1)[:, :output_size] @@ -1071,7 +1067,7 @@ def istft(Zxx: Array, fs: ArrayLike = 1.0, window: str = 'hann', noverlap: Number of points to overlap between segments (default: ``nperseg // 2``). nfft: Number of FFT points used in the STFT. If ``None`` (default), the value is determined from the size of ``Zxx``. - input_onesided: If Tru` (default), interpret the input as a one-sided STFT + input_onesided: If True (default), interpret the input as a one-sided STFT (positive frequencies only). If False, interpret the input as a two-sided STFT. boundary: If True (default), it is assumed that the input signal was extended at its boundaries by ``stft``. If `False`, the input signal is assumed to have been truncated at the boundaries by `stft`. 
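The `_fft_helper` change earlier in this file replaces `conv_general_dilated_patches` with a vmapped dynamic slice to extract the overlapping segments. A simplified standalone sketch of that pattern, using a toy 1-D signal and illustrative variable names rather than the library code itself:

from functools import partial
import jax
import jax.numpy as jnp

x = jnp.arange(10.0)                    # toy signal
nperseg, noverlap = 4, 2
step = nperseg - noverlap
starts = jnp.arange(x.shape[-1] - nperseg + 1, step=step)  # window start indices
# Each start index selects one length-nperseg window along the last axis;
# vmap stacks the windows into a (num_segments, nperseg) array.
slice_fn = partial(jax.lax.dynamic_slice_in_dim, x, slice_size=nperseg, axis=-1)
segments = jax.vmap(slice_fn)(starts)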
@@ -1108,7 +1104,7 @@ def istft(Zxx: Array, fs: ArrayLike = 1.0, window: str = 'hann', raise ValueError('Must specify differing time and frequency axes!') Zxx = jnp.asarray(Zxx, dtype=jax.dtypes.canonicalize_dtype( - np.result_type(Zxx, np.complex64))) + dtypes.to_complex_dtype(Zxx.dtype))) n_default = (2 * (Zxx.shape[freq_axis] - 1) if input_onesided else Zxx.shape[freq_axis]) @@ -1147,7 +1143,7 @@ def istft(Zxx: Array, fs: ArrayLike = 1.0, window: str = 'hann', xsubs = ifunc(Zxx, axis=-2, n=nfft)[..., :nperseg_int, :] # Get window as array - if window == 'hann': + if isinstance(window, str) and window == 'hann': # Implement the default case without scipy win = jnp.array([1.0]) if nperseg_int == 1 else jnp.sin(jnp.linspace(0, jnp.pi, nperseg_int, endpoint=False)) ** 2 win = win.astype(xsubs.dtype) diff --git a/jax/_src/scipy/special.py b/jax/_src/scipy/special.py index a24736ccfec0..ba0ea75b21d9 100644 --- a/jax/_src/scipy/special.py +++ b/jax/_src/scipy/special.py @@ -26,9 +26,12 @@ from jax import vmap from jax import lax +from jax._src import api_util +from jax._src import config from jax._src import core from jax._src import custom_derivatives from jax._src import deprecations +from jax._src import dispatch from jax._src import dtypes from jax._src.lax.lax import _const as _lax_const from jax._src.numpy.util import promote_args_inexact, promote_dtypes_inexact @@ -1048,10 +1051,17 @@ def _create_polynomial(var, coeffs): jnp.where(z >= dtype(8.0), x_for_small_p, x_otherwise)) x = jnp.where(p > dtype(1. - np.exp(-2.)), x, -x) - infinity = jnp.full(shape, dtype(np.inf)) - x_fix_boundaries = jnp.where( - p == dtype(0.0), -infinity, jnp.where(p == dtype(1.0), infinity, x)) - return x_fix_boundaries + with config.debug_infs(False): + infinity = jnp.full(shape, dtype(np.inf)) + x = jnp.where( + p == dtype(0.0), -infinity, jnp.where(p == dtype(1.0), infinity, x)) + if not isinstance(x, core.Tracer): + try: + dispatch.check_special("ndtri", [x]) + except api_util.InternalFloatingPointError as e: + raise FloatingPointError( + f"invalid value ({e.ty}) encountered in ndtri.") from None + return x @partial(custom_derivatives.custom_jvp, nondiff_argnums=(1,)) @@ -2637,6 +2647,260 @@ def hyp1f1(a: ArrayLike, b: ArrayLike, x: ArrayLike) -> Array: ) +def _hyp2f1_terminal(a, b, c, x): + """ + The Taylor series representation of the 2F1 hypergeometric function + terminates when either a or b is a non-positive integer. See Eq. 4.1 and + Taylor Series Method (a) from PEARSON, OLVER & PORTER 2014 + https://doi.org/10.48550/arXiv.1407.7786 + """ + # Ensure that between a and b, the negative integer parameter with the greater + # absolute value - that still has a magnitude less than the absolute value of + # c if c is non-positive - is used for the upper limit in the loop. + eps = jnp.finfo(x.dtype).eps * 50 + ib = jnp.round(b) + mask = jnp.logical_and( + b < a, + jnp.logical_and( + jnp.abs(b - ib) < eps, + jnp.logical_not( + jnp.logical_and( + c % 1 == 0, + jnp.logical_and( + c <= 0, + c > b + ) + ) + ) + ) + ) + orig_a = a + a = jnp.where(mask, b, a) + b = jnp.where(mask, orig_a, b) + + a = jnp.abs(a) + + def body(i, state): + serie, term = state + + term *= -(a - i + 1) / (c + i - 1) * (b + i - 1) / i * x + serie += term + + return serie, term + + init = (jnp.array(1, dtype=x.dtype), jnp.array(1, dtype=x.dtype)) + + return lax.fori_loop(jnp.array(1, dtype=a.dtype), + a + 1, + body, + init)[0] + + +def _hyp2f1_serie(a, b, c, x): + """ + Compute the 2F1 hypergeometric function using the Taylor expansion. 
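The series is accumulated term by term and truncated once the next term is negligible relative to the running sum (machine-epsilon tolerance) or after 250 terms.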
+ See Eq. 4.1 from PEARSON, OLVER & PORTER 2014 + https://doi.org/10.48550/arXiv.1407.7786 + """ + rtol = jnp.finfo(x.dtype).eps + + def body(state): + serie, k, term = state + + serie += term + term *= (a + k - 1) * (b + k - 1) / (c + k - 1) / k * x + k += 1 + + return serie, k, term + + def cond(state): + serie, k, term = state + + return (k < 250) & (lax.abs(term) > rtol * lax.abs(serie)) + + init = (jnp.array(0, dtype=x.dtype), + jnp.array(1, dtype=x.dtype), + jnp.array(1, dtype=x.dtype)) + + return lax.while_loop(cond, body, init)[0] + + +def _hyp2f1_terminal_or_serie(a, b, c, x): + """ + Check for recurrence relations along with whether or not the series + terminates. True recursion is not possible; however, the recurrence + relation may still be approximated. + See 4.6.1. Recurrence Relations from PEARSON, OLVER & PORTER 2014 + https://doi.org/10.48550/arXiv.1407.7786 + """ + eps = jnp.finfo(x.dtype).eps * 50 + + d = c - a - b + + ia = jnp.round(a) + ib = jnp.round(b) + id = jnp.round(d) + + neg_int_a = jnp.logical_and(a <= 0, jnp.abs(a - ia) < eps) + neg_int_b = jnp.logical_and(b <= 0, jnp.abs(b - ib) < eps) + neg_int_a_or_b = jnp.logical_or(neg_int_a, neg_int_b) + not_neg_int_a_or_b = jnp.logical_not(neg_int_a_or_b) + + index = jnp.where(jnp.logical_and(x > 0.9, not_neg_int_a_or_b), + jnp.where(jnp.abs(d - id) >= eps, 0, 1), + jnp.where(neg_int_a_or_b, 2, 0)) + + return lax.select_n(index, + _hyp2f1_serie(a, b, c, x), + _hyp2f1_digamma_transform(a, b, c, x), + _hyp2f1_terminal(a, b, c, x)) + + +def _hyp2f1_digamma_transform(a, b, c, x): + """ + Digamma transformation of the 2F1 hypergeometric function. + See AMS55 #15.3.10, #15.3.11, #15.3.12 + """ + rtol = jnp.finfo(x.dtype).eps + + d = c - a - b + s = 1 - x + rd = jnp.round(d) + + e = jnp.where(rd >= 0, d, -d) + d1 = jnp.where(rd >= 0, d, jnp.array(0, dtype=d.dtype)) + d2 = jnp.where(rd >= 0, jnp.array(0, dtype=d.dtype), d) + ard = jnp.where(rd >= 0, rd, -rd).astype('int32') + + ax = jnp.log(s) + + y = digamma(1.0) + digamma(1.0 + e) - digamma(a + d1) - digamma(b + d1) - ax + y /= gamma(e + 1.0) + + p = (a + d1) * (b + d1) * s / gamma(e + 2.0) + + def cond(state): + _, _, _, _, _, _, q, _, _, t, y = state + + return jnp.logical_and( + t < 250, + jnp.abs(q) >= rtol * jnp.abs(y) + ) + + def body(state): + a, ax, b, d1, e, p, q, r, s, t, y = state + + r = digamma(1.0 + t) + digamma(1.0 + t + e) - digamma(a + t + d1) \ + - digamma(b + t + d1) - ax + q = p * r + y += q + p *= s * (a + t + d1) / (t + 1.0) + p *= (b + t + d1) / (t + 1.0 + e) + t += 1.0 + + return a, ax, b, d1, e, p, q, r, s, t, y + + init = (a, ax, b, d1, e, p, y, jnp.array(0, dtype=x.dtype), s, + jnp.array(1, dtype=x.dtype), y) + _, _, _, _, _, _, q, r, _, _, y = lax.while_loop(cond, body, init) + + def compute_sum(y): + y1 = jnp.array(1, dtype=x.dtype) + t = jnp.array(0, dtype=x.dtype) + p = jnp.array(1, dtype=x.dtype) + + def for_body(i, state): + a, b, d2, e, p, s, t, y1 = state + + r = 1.0 - e + t + p *= s * (a + t + d2) * (b + t + d2) / r + t += 1.0 + p /= t + y1 += p + + return a, b, d2, e, p, s, t, y1 + + init_val = a, b, d2, e, p, s, t, y1 + y1 = lax.fori_loop(1, ard, for_body, init_val)[-1] + + p = gamma(c) + y1 *= gamma(e) * p / (gamma(a + d1) * gamma(b + d1)) + y *= p / (gamma(a + d2) * gamma(b + d2)) + + y = jnp.where((ard & 1) != 0, -y, y) + q = s ** rd + + return jnp.where(rd > 0, y * q + y1, y + y1 * q) + + return jnp.where( + rd == 0, + y * gamma(c) / (gamma(a) * gamma(b)), + compute_sum(y) + ) + + +@jit +@jnp.vectorize +def hyp2f1(a: ArrayLike, b: 
ArrayLike, c: ArrayLike, x: ArrayLike) -> Array: + r"""The 2F1 hypergeometric function. + + JAX implementation of :obj:`scipy.special.hyp2f1`. + + .. math:: + + \mathrm{hyp2f1}(a, b, c, x) = {}_2F_1(a; b; c; x) = \sum_{k=0}^\infty \frac{(a)_k(b)_k}{(c)_k}\frac{x^k}{k!} + + where :math:`(\cdot)_k` is the Pochammer symbol. + + The JAX version only accepts positive and real inputs. Values of + ``a``, ``b``, ``c``, and ``x`` leading to high values of 2F1 may + lead to erroneous results; consider enabling double precision in this case. + + Args: + a: arraylike, real-valued + b: arraylike, real-valued + c: arraylike, real-valued + x: arraylike, real-valued + + Returns: + array of 2F1 values. + """ + # This is backed by https://doi.org/10.48550/arXiv.1407.7786 + a, b, c, x = promote_args_inexact('hyp2f1', a, b, c, x) + eps = jnp.finfo(x.dtype).eps * 50 + + d = c - a - b + s = 1 - x + ca = c - a + cb = c - b + + id = jnp.round(d) + ica = jnp.round(ca) + icb = jnp.round(cb) + + neg_int_ca = jnp.logical_and(ca <= 0, jnp.abs(ca - ica) < eps) + neg_int_cb = jnp.logical_and(cb <= 0, jnp.abs(cb - icb) < eps) + neg_int_ca_or_cb = jnp.logical_or(neg_int_ca, neg_int_cb) + + index = jnp.where(jnp.logical_or(x == 0, jnp.logical_and(jnp.logical_or(a == 0, b == 0), c != 0)), 0, + jnp.where(jnp.logical_or(c == 0, jnp.logical_and(c < 0, c % 1 == 0)), 1, + jnp.where(jnp.logical_and(d <= -1, jnp.logical_not(jnp.logical_and(jnp.abs(d - id) >= eps, s < 0))), 2, + jnp.where(jnp.logical_and(d <= 0, x == 1), 1, + jnp.where(jnp.logical_and(x < 1, b == c), 3, + jnp.where(jnp.logical_and(x < 1, a == c), 4, + jnp.where(x > 1, 1, + jnp.where(x == 1, 5, 6)))))))) + + return lax.select_n(index, + jnp.array(1, dtype=x.dtype), + jnp.array(jnp.inf, dtype=x.dtype), + s ** d * _hyp2f1_terminal_or_serie(ca, cb, c, x), + s ** (-a), + s ** (-b), + gamma(c) * gamma(d) / (gamma(ca) * gamma(cb)), + _hyp2f1_terminal_or_serie(a, b, c, x)) + + def softmax(x: ArrayLike, /, *, diff --git a/jax/_src/scipy/stats/_core.py b/jax/_src/scipy/stats/_core.py index 65c457f79cc8..ae93dd793844 100644 --- a/jax/_src/scipy/stats/_core.py +++ b/jax/_src/scipy/stats/_core.py @@ -285,7 +285,7 @@ def sem(a: ArrayLike, axis: int | None = 0, ddof: int = 1, nan_policy: str = "pr Array([1.73, nan, 1.53, nan, nan, nan], dtype=float32) If ``nan_policy='omit```, ``sem`` omits the ``nan`` values and computes the error - for the remainging values along the specified axis. + for the remaining values along the specified axis. >>> with jnp.printoptions(precision=2, suppress=True): ... jax.scipy.stats.sem(x2, nan_policy='omit') diff --git a/jax/_src/shard_map.py b/jax/_src/shard_map.py new file mode 100644 index 000000000000..69c1b7d264f6 --- /dev/null +++ b/jax/_src/shard_map.py @@ -0,0 +1,1843 @@ +# Copyright 2023 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
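# Module overview: this file defines the user-facing shard_map and smap APIs
# and the shard_map_p primitive, together with its staging, type-checking,
# MLIR lowering (including the Shardy path), eager-evaluation, and batching
# rules (see the section comments below).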
+from __future__ import annotations + +from collections.abc import Callable, Hashable, Sequence, Set +import enum +from functools import partial +import inspect +from math import prod +import operator as op +from typing import Any, TypeVar, Union + +import numpy as np + +import jax +import jax.numpy as jnp +from jax.sharding import NamedSharding, PartitionSpec +from jax._src import ad_util +from jax._src import api_util +from jax._src import config +from jax._src import core +from jax._src import debugging +from jax._src import dispatch +from jax._src import dtypes +from jax._src import linear_util as lu +from jax._src import sharding_impls +from jax._src import source_info_util +from jax._src import traceback_util +from jax._src import util +from jax._src.core import pvary +from jax._src.core import Tracer, typeof +from jax._src.mesh import (AbstractMesh, Mesh, AxisType, use_abstract_mesh, + get_abstract_mesh, get_concrete_mesh) +from jax._src.api import _shared_code_pmap, _prepare_pmap +from jax._src.lib.mlir import ir +from jax._src.lib.mlir.dialects import hlo, sdy +from jax._src.util import (HashableFunction, HashablePartial, unzip2, + as_hashable_function, memoize, partition_list, + merge_lists, split_list, subs_list2, + fun_name as util_fun_name) +from jax._src.interpreters import batching +from jax._src.interpreters import mlir +from jax._src.interpreters import partial_eval as pe +from jax._src.interpreters import pxla +from jax._src.interpreters import ad +from jax.tree_util import (tree_map, tree_flatten, tree_unflatten, + tree_structure, tree_leaves, keystr) +from jax._src.tree_util import (broadcast_prefix, prefix_errors, PyTreeDef, + generate_key_paths, KeyPath) +from jax.experimental.multihost_utils import (host_local_array_to_global_array, + global_array_to_host_local_array) + +P = PartitionSpec + +map, unsafe_map = util.safe_map, map +zip, unsafe_zip = util.safe_zip, zip +traceback_util.register_exclusion(__file__) + +# API + +Specs = Any # PyTree[PartitionSpec] +AxisName = Hashable + + +def shard_map(f=None, /, *, out_specs: Specs, axis_names: Set[AxisName] = set(), + in_specs: Specs | None = None, + mesh: Mesh | AbstractMesh | None = None, check_vma: bool = True): + """Map a function over shards of data using a mesh of devices. + + See the docs at https://docs.jax.dev/en/latest/notebooks/shard_map.html. + + Args: + f: callable to be mapped. Each application of ``f``, or "instance" of ``f``, + takes as input a shard of the mapped-over arguments and produces a shard + of the output. + mesh: (optional, default None) a ``jax.sharding.Mesh`` representing the + array of devices over which to shard the data and on which to execute + instances of ``f``. The names of the ``Mesh`` can be used in collective + communication operations in ``f``. If mesh is None, it will be inferred + from the context which can be set via `jax.sharding.use_mesh` context + manager. + in_specs: (optional, default None) a pytree with + ``jax.sharding.PartitionSpec`` instances as leaves, with a tree structure + that is a tree prefix of the args tuple to be mapped over. Similar to + ``jax.sharding.NamedSharding``, each ``PartitionSpec`` represents how the + corresponding argument (or subtree of arguments) should be sharded along + the named axes of ``mesh``. In each ``PartitionSpec``, mentioning a + ``mesh`` axis name at a position expresses sharding the corresponding + argument array axis along that positional axis; not mentioning an axis + name expresses replication. 
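For example, for a rank-2 argument, ``P('i', None)`` splits the first array axis across mesh axis ``'i'`` and leaves the second array axis unsplit.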
If ``None``, all mesh axes must be of type + `Explicit`, in which case the in_specs are inferred from the argument types. + out_specs: a pytree with ``PartitionSpec`` instances as leaves, with a tree + structure that is a tree prefix of the output of ``f``. Each + ``PartitionSpec`` represents how the corresponding output shards should be + concatenated. In each ``PartitionSpec``, mentioning a ``mesh`` axis name + at a position expresses concatenation of that mesh axis's shards along the + corresponding positional axis; not mentioning a ``mesh`` axis name + expresses a promise that the output values are equal along that mesh axis, + and that rather than concatenating only a single value should be produced. + axis_names: (optional, default set()) set of axis names from ``mesh`` over + which the function ``f`` is manual. If empty, ``f``, is manual + over all mesh axes. + check_vma: (optional) boolean (default True) representing whether to enable + additional validity checks and automatic differentiation optimizations. + The validity checks concern whether any mesh axis names not mentioned in + ``out_specs`` are consistent with how the outputs of ``f`` are replicated. + + Returns: + A callable representing a mapped version of ``f``, which accepts positional + arguments corresponding to those of ``f`` and produces output corresponding + to that of ``f``. + """ + kwargs = dict(mesh=mesh, in_specs=in_specs, out_specs=out_specs, + axis_names=axis_names, check_vma=check_vma) + if f is None: + return lambda g: _shard_map(g, **kwargs) + return _shard_map(f, **kwargs) + +def _axes_to_pspec(axis_name, axis): + if axis is None: + return P() + return P(*[None] * axis + [axis_name]) + +class InferFromArgs: + + def __repr__(self): + return "jax.sharding.Infer" + + def __reduce__(self): + return (_get_default_infer, ()) + +Infer = InferFromArgs() + +def _get_default_infer(): + return Infer + +# TODO(yashkatariya): We need a singleton which users can provide to `in_axes` +# to tell smap to infer in_specs from args when mesh is fully explicit. 
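# Illustrative note (annotation, not part of this change): with axis_name='i',
# _axes_to_pspec maps smap's integer/None axis entries to PartitionSpecs as
#   0    -> P('i')         (split the leading array axis over mesh axis 'i')
#   1    -> P(None, 'i')   (split the second array axis)
#   None -> P()            (the argument is replicated over 'i')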
+def smap(f, /, *, in_axes=Infer, out_axes, axis_name: AxisName): + if isinstance(axis_name, (list, tuple)): + raise TypeError( + f"smap axis_name should be a `str` or a `Hashable`, but got {axis_name}") + if (in_axes is not None and in_axes is not Infer and + not isinstance(in_axes, (int, tuple))): + raise TypeError( + "smap in_axes must be an int, None, jax.sharding.Infer, or a tuple of" + " entries corresponding to the positional arguments passed to the" + f" function, but got {in_axes}.") + if (in_axes is not Infer and + not all(isinstance(l, int) for l in tree_leaves(in_axes))): + raise TypeError( + "smap in_axes must be an int, None, jax.sharding.Infer, or (nested)" + f" container with those types as leaves, but got {in_axes}.") + if not all(isinstance(l, int) for l in tree_leaves(out_axes)): + raise TypeError("smap out_axes must be an int, None, or (nested) container " + f"with those types as leaves, but got {out_axes}.") + + in_specs = (None if in_axes is Infer else + tree_map(partial(_axes_to_pspec, axis_name), in_axes, + is_leaf=lambda x: x is None)) + out_specs = tree_map(partial(_axes_to_pspec, axis_name), out_axes, + is_leaf=lambda x: x is None) + return _shard_map(f, mesh=None, in_specs=in_specs, out_specs=out_specs, + axis_names={axis_name}, check_vma=True, _smap=True) + + +def _shard_map(f: Callable, *, mesh: Mesh | AbstractMesh | None, + in_specs: Specs, out_specs: Specs | Callable[[], Specs], + axis_names: Set[AxisName], check_vma: bool, + _skip_mesh_check: bool = False, _smap: bool = False) -> Callable: + if not callable(f): + raise TypeError("shard_map requires a callable for its first argument, " + f"but got {f} of type {type(f)}.") + + @util.wraps(f) + @traceback_util.api_boundary + def wrapped(*args): + nonlocal mesh, axis_names + mesh, axis_names = _shmap_checks(mesh, axis_names, in_specs, out_specs, + _skip_mesh_check, _smap) + fun = lu.wrap_init( + f, debug_info=api_util.debug_info("shard_map", f, args, {})) + args_flat, in_tree = tree_flatten(args) + fun, out_tree = api_util.flatten_fun_nokwargs(fun, in_tree) + + try: + in_specs_flat = broadcast_prefix( + in_specs, args, is_leaf=lambda x: x is None) + except ValueError: + e, *_ = prefix_errors(in_specs, args) + raise e('shard_map in_specs') from None + + if (in_specs is None and + all(mesh._name_to_type[a] == AxisType.Explicit for a in axis_names)): + arg_s = [typeof(a).sharding for a in args_flat] + assert all(i is None for i in in_specs_flat), in_specs_flat + in_specs_flat = [_manual_spec(axis_names, s.spec) for s in arg_s] + + dyn_argnums, in_specs_flat = unzip2((i, s) for i, s in enumerate(in_specs_flat) + if s is not None) + fun, args_flat = api_util.argnums_partial(fun, dyn_argnums, args_flat, False) + _check_specs_vs_args(f, mesh, in_tree, in_specs, dyn_argnums, in_specs_flat, + args_flat) + + @memoize + def out_specs_thunk(): + if callable(out_specs): + out_specs_ = out_specs() + _check_specs(SpecErrorType.out, out_specs_, axis_names) + else: + out_specs_ = out_specs + dummy = tree_unflatten(out_tree(), [object()] * out_tree().num_leaves) + try: + out_specs_flat = broadcast_prefix(out_specs_, dummy) + except ValueError: + e, *_ = prefix_errors(out_specs_, dummy) + raise e('shard_map out_specs') from None + return tuple(out_specs_flat) + + if check_vma: + fun = _implicit_pvary_on_output(fun, out_specs_thunk) + + try: + out_flat = shard_map_p.bind( + fun, *args_flat, mesh=mesh, in_specs=in_specs_flat, + out_specs_thunk=out_specs_thunk, check_vma=check_vma, + manual_axes=axis_names) + except _SpecError as e: 
+ fails, = e.args + if not callable(out_specs): + msg = _spec_rank_error(SpecErrorType.out, f, out_tree(), out_specs, fails) + if any(fail is not no_fail and not fail.shape for fail in fails): + msg += (" In particular, for rank 0 outputs which are not constant " + "over the mesh, add at least one (singleton) axis to them so " + "that they can be concatenated using out_specs.") + raise ValueError(msg) from None + except _RepError as e: + fails, = e.args + if not callable(out_specs): + msg = _inout_vma_error(f, mesh, out_tree(), out_specs, fails) + raise ValueError(msg) from None + return tree_unflatten(out_tree(), out_flat) + return wrapped + + +def _shmap_checks(mesh, axis_names, in_specs, out_specs, _skip_mesh_check, + _smap): + if mesh is None: + mesh = get_abstract_mesh() + if mesh.empty: + raise ValueError( + "The context mesh cannot be empty. Use" + " `jax.sharding.use_mesh(mesh)` to enter into a mesh context") + else: + ctx_mesh = get_abstract_mesh() + if (not _skip_mesh_check and not ctx_mesh.empty and + mesh.abstract_mesh != ctx_mesh): + raise ValueError( + f"The context mesh {ctx_mesh} should match the mesh passed to" + f" shard_map {mesh}") + + if not isinstance(mesh, (Mesh, AbstractMesh)): + raise TypeError("shard_map requires a `jax.sharding.Mesh` or a " + "`jax.sharding.AbstractMesh` instance for its " + f"second argument, but got {mesh} of type {type(mesh)}.") + + if not isinstance(axis_names, (frozenset, set)): + raise TypeError( + "`axis_names` argument of shard_map should be of type `frozenset` or" + f" `set`. Got type: {type(axis_names)}") + if isinstance(axis_names, set): + axis_names = frozenset(axis_names) + if not axis_names: + axis_names = frozenset(mesh.axis_names) + if not axis_names.issubset(mesh.axis_names): + raise ValueError( + f"jax.shard_map requires axis_names={axis_names} to be a subset of " + f"mesh.axis_names={mesh.axis_names}") + + if (in_specs is None and + not all(mesh._name_to_type[a] == AxisType.Explicit for a in axis_names)): + axis_types = ', '.join(str(mesh._name_to_type[a]) for a in axis_names) + if _smap: + msg = (f"in_axes was not specified when axis_name={axis_names} was of" + f" type {axis_types}") + else: + msg = ("shard_map in_specs argument must be a pytree of" + " `jax.sharding.PartitionSpec` instances, but it was `None` when" + f" {axis_names=} are of type {axis_types}") + raise TypeError(msg) + + if in_specs is not None: + _check_specs(SpecErrorType.input, in_specs, axis_names) + if not callable(out_specs): + _check_specs(SpecErrorType.out, out_specs, axis_names) + return mesh, axis_names + +def _manual_spec(manual_axes, spec: P) -> P: + out = [] # type: ignore + for s in spec: + if s is None: + out.append(s) + elif isinstance(s, tuple): + temp = [p if p in manual_axes else None for p in s] + while temp and temp[-1] is None: + temp.pop() + if None in temp: + raise ValueError(f"Invalid spec: {spec}") + out.append(None if len(temp) == 0 else tuple(temp)) + else: + out.append(s if s in manual_axes else None) + return P(*out) + + +# Error checking and messages + +SpecErrorType = enum.Enum('SpecErrorType', ['input', 'out']) + +def _check_specs(error_type: SpecErrorType, specs: Any, manual_axes) -> None: + if error_type == SpecErrorType.input and specs is None: + raise TypeError( + "shard_map in_specs argument must be a pytree of " + "`jax.sharding.PartitionSpec` instances, but it was None.\n" + "Instead of `in_specs=None`, did you mean `in_specs=P()`, " + "where `P = jax.sharding.PartitionSpec`?") + + def check_spec(p): + if not 
isinstance(p, PartitionSpec): + return False + for names in p: + names = (names,) if not isinstance(names, tuple) else names + for name in names: + if name is not None and name not in manual_axes: + return False + return True + + if all(check_spec(p) for p in tree_leaves(specs)): + return + prefix = 'in' if error_type == SpecErrorType.input else 'out' + msgs = [f" {prefix}_specs{keystr(key)} is {x} of type {type(x).__name__}, " + for key, x in generate_key_paths(specs) if not isinstance(x, P)] + if not msgs: + for key, p in generate_key_paths(specs): + for names in p: + names = (names,) if not isinstance(names, tuple) else names + for name in names: + if name is not None and name not in manual_axes: + msgs.append(f" {prefix}_specs{keystr(key)} refers to {repr(name)}") + raise ValueError( + f"shard_map {prefix}_specs argument must refer to an axis " + f"marked as manual ({manual_axes}), but:\n\n" + + '\n\n'.join(msgs) + '\n\n' + f"Check the {prefix}_specs values passed to shard_map.") + raise TypeError( + f"shard_map {prefix}_specs argument must be a pytree of " + f"`jax.sharding.PartitionSpec` instances, but:\n\n" + + '\n\n'.join(msgs) + '\n\n' + f"Check the {prefix}_specs values passed to shard_map.") + +class NoFail: pass +no_fail = NoFail() + +def _check_specs_vs_args( + f: Callable, mesh: Mesh | AbstractMesh, in_tree: PyTreeDef, in_specs: Specs, + dyn_argnums: Sequence[int], in_specs_flat: Sequence[P], + xs: Sequence) -> None: + in_avals = map(core.shaped_abstractify, xs) + fail = [a if not len(p) <= a.ndim else no_fail + for p, a in zip(in_specs_flat, in_avals)] + if any(f is not no_fail for f in fail): + fail = _expand_fail(in_tree, dyn_argnums, fail) + msg = _spec_rank_error(SpecErrorType.input, f, in_tree, in_specs, fail) + raise ValueError(msg) + in_names_flat = tuple(map(_spec_to_names, in_specs_flat)) + fail = [a if any(a.shape[d] % prod(mesh.shape[n] for n in ns) + for d, ns in names.items()) else no_fail + for a, names in zip(in_avals, in_names_flat)] + if any(f is not no_fail for f in fail): + fail = _expand_fail(in_tree, dyn_argnums, fail) + msg = _spec_divisibility_error(f, mesh, in_tree, in_specs, fail) + raise ValueError(msg) + +def _expand_fail(in_tree: PyTreeDef, dyn_argnums: Sequence[int], + fail: Sequence[core.ShapedArray | NoFail] + ) -> list[core.ShapedArray | NoFail]: + fail_: list[core.ShapedArray | NoFail] = [no_fail] * in_tree.num_leaves + for i, f in zip(dyn_argnums, fail): + fail_[i] = f + return fail_ + +def _spec_rank_error( + error_type: SpecErrorType, f: Callable, tree: PyTreeDef, specs: Specs, + fails: list[core.ShapedArray | NoFail]) -> str: + fun_name = util_fun_name(f) + if error_type == SpecErrorType.input: + prefix, base = 'in', 'args' + ba = _try_infer_args(f, tree) + else: + prefix, base = 'out', f'{fun_name}(*args)' + msgs = [] + for (spec_key, spec), (fail_key, aval) in _iter_paths(tree, specs, fails): + extra = "" + if error_type == SpecErrorType.input and ba is not None: + arg_key, *_ = fail_key + if arg_key.idx < len(ba.arguments): + param_name = list(ba.arguments.keys())[arg_key.idx] + extra = (f", where {base}{arg_key} is bound to {fun_name}'s " + f"parameter '{param_name}',") + else: + param = list(ba.signature.parameters.values())[-1] + assert param.kind == inspect.Parameter.VAR_POSITIONAL + extra = (f", where {base}{arg_key} is the index " + f"{arg_key.idx - len(ba.signature.parameters) + 1} component " + f"of {fun_name}'s varargs parameter '{param.name}',") + msgs.append( + f"* {prefix}_specs{keystr(spec_key)} is {spec} which has length 
" + f"{len(spec)}, but " + f"{base}{keystr(fail_key)}{extra} has shape {aval.str_short()}, " + f"which has rank {aval.ndim} (and {aval.ndim} < {len(spec)})") + assert msgs + if len(msgs) == 1: msgs = [msgs[0][2:]] # remove the bullet point + msg = (f"shard_map applied to the function '{fun_name}' was given an " + f"{prefix}_specs entry which is too long to be compatible with the " + f"corresponding {prefix}put value from the function:\n\n" + + '\n\n'.join(msgs) + '\n\n' + + f"Entries in {prefix}_specs must be of length no greater than the " + f"number of axes in the corresponding {prefix}put value.\n\n" + f"Either revise the spec to be shorter, or modify '{fun_name}' so " + f"that its {prefix}puts have sufficient rank.") + if any(not aval.ndim for _, (_, aval) in _iter_paths(tree, specs, fails)): + msg += (f"\n\nFor scalar values (rank 0), consider using an {prefix}_specs " + "entry of `P()`, where `P = jax.sharding.PartitionSpec`.") + return msg + +def _spec_divisibility_error( + f: Callable, mesh: Mesh | AbstractMesh, tree: PyTreeDef, specs: Specs, + fails: list[core.ShapedArray | NoFail]) -> str: + ba = _try_infer_args(f, tree) + fun_name = getattr(f, '__name__', str(f)) + msgs = [] + for (spec_key, spec), (fail_key, aval) in _iter_paths(tree, specs, fails): + extra = "" + if ba is not None: + arg_key, *_ = fail_key + if arg_key.idx < len(ba.arguments): + param_name = list(ba.arguments.keys())[arg_key.idx] + extra = (f", where args{arg_key} is bound to {fun_name}'s " + f"parameter '{param_name}',") + else: + param = list(ba.signature.parameters.values())[-1] + assert param.kind == inspect.Parameter.VAR_POSITIONAL + extra = (f", where args{arg_key} is the index " + f"{arg_key.idx - len(ba.signature.parameters) + 1} component " + f"of {fun_name}'s varargs parameter '{param.name}',") + names = _spec_to_names(spec) + for d, ns in names.items(): + if aval.shape[d] % prod(mesh.shape[n] for n in ns): + axis = f"axes {ns}" if len(ns) > 1 else f"axis '{ns[0]}'" + total = 'total ' if len(ns) > 1 else '' + sz = prod(mesh.shape[n] for n in ns) + msgs.append( + f"* args{keystr(fail_key)} of shape {aval.str_short()}{extra} " + f"corresponds to in_specs{keystr(spec_key)} of value {spec}, " + f"which maps array axis {d} (of size {aval.shape[d]}) to mesh " + f"{axis} (of {total}size {sz}), but {sz} does not evenly divide " + f"{aval.shape[d]}") + assert msgs + if len(msgs) == 1: msgs = [msgs[0][2:]] # remove the bullet point + msg = (f"shard_map applied to the function '{fun_name}' was given argument " + f"arrays with axis sizes that are not evenly divisible by the " + f"corresponding mesh axis sizes:\n\n" + f"The mesh given has shape {tuple(mesh.shape.values())} with " + f"corresponding axis names {mesh.axis_names}.\n\n" + + '\n\n'.join(msgs) + '\n\n' + + f"Array arguments' axis sizes must be evenly divisible by the mesh " + f"axis or axes indicated by the corresponding elements of the " + f"argument's in_specs entry. 
Consider checking that in_specs are " + f"correct, and if so consider changing the mesh axis sizes or else " + f"padding the input and adapting '{fun_name}' appropriately.") + return msg + +def _inout_vma_error(f: Callable, mesh: Mesh | AbstractMesh, tree: PyTreeDef, + specs: Specs, fails: list[set | NoFail]) -> str: + fun_name = getattr(f, '__name__', str(f)) + msgs = [] + for (spec_key, spec), (fail_key, vma) in _iter_paths(tree, specs, fails): + unmentioned = _unmentioned(mesh, spec) + if len(unmentioned) > 1: + need_vma = ','.join(map(str, order_wrt_mesh(mesh, _spec_to_vma(spec)))) + got_vma = ','.join(map(str, order_wrt_mesh(mesh, vma))) + diff = ','.join(map(str, order_wrt_mesh( + mesh, [n for n in unmentioned if n in vma]))) + msgs.append( + f"* out_specs{keystr(spec_key)} is {spec} which implies that the " + f"corresponding output value is only varying across mesh axes " + f"{{{need_vma}}} and not {{{diff}}}, but it was inferred to be " + f"possibly varying over {{{got_vma}}}") + else: + need_rep_, = unmentioned + msgs.append( + f"* out_specs{keystr(spec_key)} is {spec} which implies that the " + f"corresponding output value is replicated across mesh axis " + f"'{need_rep_}', but could not infer replication over any axes") + assert msgs + if len(msgs) == 1: msgs = [msgs[0][2:]] # remove the bullet point + msg = (f"shard_map applied to the function '{fun_name}' was given " + f"out_specs which require replication which can't be statically " + f"inferred given the mesh:\n\n" + f"The mesh given has shape {tuple(mesh.shape.values())} with " + f"corresponding axis names {mesh.axis_names}.\n\n" + + '\n\n'.join(msgs) + '\n\n' + + "Check if these output values are meant to be replicated over those " + "mesh axes. If not, consider revising the corresponding out_specs " + "entries. 
If so, consider disabling the check by passing the " + "check_vma=False argument to `jax.shard_map`.") + return msg + +def _unmentioned(mesh: Mesh | AbstractMesh, spec) -> list[AxisName]: + vma_set = _spec_to_vma(spec) + return [n for n in mesh.axis_names if n not in vma_set] + + +def _try_infer_args(f, tree): + dummy_args = tree_unflatten(tree, [False] * tree.num_leaves) + try: + return inspect.signature(f).bind(*dummy_args) + except (TypeError, ValueError): + return None + +T = TypeVar('T') +def _iter_paths(tree: PyTreeDef, specs: Specs, fails: list[T | NoFail] + ) -> list[tuple[tuple[KeyPath, P], tuple[KeyPath, T]]]: + failures = tree_unflatten(tree, fails) + failures_aug = generate_key_paths(failures) + specs_ = tree_unflatten(tree_structure(specs), generate_key_paths(specs)) + leaf = lambda x: x is None or type(x) is tuple and len(x) == 2 and type(x[1]) is P + specs_aug = broadcast_prefix(specs_, failures, is_leaf=leaf) + return [(s, (fail_key, fail_data)) for s, (fail_key, fail_data) + in zip(specs_aug, failures_aug) + if s is not None and fail_data is not no_fail] + +# Primitive + +@lu.transformation2 +def _implicit_pvary_on_output(f, out_specs_thunk, *args, **kwargs): + out_flat = f(*args, **kwargs) + return [pvary(o, tuple(_spec_to_vma(sp) - typeof(o).vma)) + for o, sp in zip(out_flat, out_specs_thunk())] + +JaxType = Any +MaybeTracer = Union[JaxType, Tracer] + +class ShardMapPrimitive(core.Primitive): + multiple_results = True + + def bind(self, *args, **params): + return self._true_bind(*args, **params) + + def bind_with_trace(self, trace, fun_and_args, params): + fun: lu.WrappedFun + fun, *args = fun_and_args + return trace.process_shard_map(shard_map_p, fun, args, **params) + + def get_bind_params(self, params): + new_params = dict(params) + jaxpr: core.Jaxpr = new_params.pop('jaxpr') + subfun = lu.hashable_partial(lu.wrap_init(core.eval_jaxpr, + debug_info=jaxpr.debug_info), + jaxpr, ()) + axes = new_params.pop('out_specs') + new_params['out_specs_thunk'] = HashableFunction(lambda: axes, closure=axes) + return [subfun], new_params + +shard_map_p = ShardMapPrimitive('shard_map') + +# Staging + +@util.cache(max_size=256, trace_context_in_key=True) +def _as_manual_mesh(mesh, manual_axes: frozenset): + not_manual = set(mesh.axis_names) - manual_axes + cur_mesh = get_abstract_mesh() + if cur_mesh.empty: + cur_mesh = mesh + explicit_axes, auto_axes = set(), set() # type: ignore + for a in not_manual: + if cur_mesh._name_to_type[a] == AxisType.Auto: + auto_axes.add(a) + else: + assert cur_mesh._name_to_type[a] == AxisType.Explicit, ( + a, cur_mesh._name_to_type[a]) + explicit_axes.add(a) + + new_axis_types = [] + for n in mesh.axis_names: + if n in manual_axes: + new_axis_types.append(AxisType.Manual) + elif n in auto_axes: + new_axis_types.append(AxisType.Auto) + else: + assert n in explicit_axes + new_axis_types.append(AxisType.Explicit) + return AbstractMesh(mesh.axis_sizes, mesh.axis_names, + axis_types=tuple(new_axis_types)) + + +def _extend_axis_env(mesh, manual_axes): + return core.extend_axis_env_nd([(k, v) for k, v in mesh.shape.items() + if k in manual_axes]) + +def _shard_map_staging( + trace: pe.DynamicJaxprTrace, prim: core.Primitive, f: lu.WrappedFun, + in_tracers: Sequence[Any], *, mesh: Mesh, + in_specs, out_specs_thunk, check_vma: bool, manual_axes: frozenset, + ) -> Sequence[pe.DynamicJaxprTracer]: + source_info = source_info_util.current() + to_jaxpr_tracer = partial(trace.to_jaxpr_tracer, source_info=source_info) + in_tracers = map(to_jaxpr_tracer, in_tracers) + 
inner_mesh = _as_manual_mesh(mesh, manual_axes | set(mesh.manual_axes)) + in_avals = [t.aval for t in in_tracers] + in_avals_ = map(partial(_shard_aval, mesh, manual_axes, check_vma), in_specs, + in_avals) + with (_extend_axis_env(mesh, manual_axes), use_abstract_mesh(inner_mesh), + config._check_vma(check_vma)): + jaxpr, out_avals_, consts, () = pe.trace_to_jaxpr_dynamic(f, in_avals_) + _check_names(out_specs_thunk(), out_avals_) + if check_vma: + out_vma = [v.aval.vma for v in jaxpr.outvars] + _check_vmas(mesh, out_specs_thunk(), out_vma) + out_avals = map(_check_shapedarray, out_avals_) + out_avals = [_check_shapedarray(_unshard_aval(mesh, check_vma, spec, aval)) + for spec, aval in zip(out_specs_thunk(), out_avals)] + out_tracers = [pe.DynamicJaxprTracer(trace, a, source_info) for a in out_avals] + invars = map(trace.getvar, in_tracers) + constvars = map(trace.getvar, map(to_jaxpr_tracer, consts)) + outvars = map(trace.makevar, out_tracers) + in_specs_staged = (P(),) * len(consts) + tuple(in_specs) # type: ignore + with (_extend_axis_env(mesh, manual_axes), use_abstract_mesh(inner_mesh), + config._check_vma(check_vma)): + jaxpr = pe.convert_constvars_jaxpr(jaxpr) + params = dict(mesh=mesh, in_specs=in_specs_staged, + out_specs=tuple(out_specs_thunk()), jaxpr=jaxpr, + check_vma=check_vma, manual_axes=manual_axes) + effs = core.filter_named_axis_effects(jaxpr.effects, mesh.axis_names) + eqn = pe.new_jaxpr_eqn([*constvars, *invars], outvars, prim, params, + effs, source_info) + trace.frame.add_eqn(eqn) + return out_tracers +pe.DynamicJaxprTrace.process_shard_map = _shard_map_staging + +# TODO add underscore version, for direct-linearize to consume + +def _spec_to_names(spec: PartitionSpec): + return {i: names if isinstance(names, tuple) else (names,) + for i, names in enumerate(spec) if names is not None} + +def _check_shapedarray(aval: core.AbstractValue) -> core.ShapedArray: + assert isinstance(aval, core.ShapedArray) + return aval + +def _shard_aval(mesh: Mesh, manual_axes, check_vma, spec, + aval: core.AbstractValue) -> core.AbstractValue: + if type(aval) in core.shard_aval_handlers: + return core.shard_aval_handlers[type(aval)](mesh, manual_axes, check_vma, + spec, aval) + raise NotImplementedError(f"Unsupported aval type: {type(aval)}") + +def _unshard_aval(mesh: Mesh, check_vma, spec, + aval: core.AbstractValue) -> core.AbstractValue: + if type(aval) in core.unshard_aval_handlers: + return core.unshard_aval_handlers[type(aval)](mesh, check_vma, spec, aval) + else: + raise NotImplementedError(f"Unsupported aval type: {type(aval)}") + +def _shard_shaped_array(mesh: Mesh, manual_axes: frozenset, check_vma, + spec, aval: core.AbstractValue) -> core.AbstractValue: + assert isinstance(aval, core.ShapedArray) + names = _spec_to_names(spec) + new_shape = tuple(sz // prod(mesh.shape[n] for n in names.get(i, ())) + for i, sz in enumerate(aval.shape)) + manual_mesh = _as_manual_mesh(mesh, manual_axes | set(mesh.manual_axes)) + new_sharding = NamedSharding(manual_mesh, aval.sharding.spec) + vma = _spec_to_vma(spec) if check_vma else frozenset() + vma = vma | aval.vma + return aval.update(shape=new_shape, sharding=new_sharding, vma=vma) +core.shard_aval_handlers[core.ShapedArray] = _shard_shaped_array + +def _unshard_shaped_array(mesh: Mesh, check_vma, spec, aval: core.AbstractValue + ) -> core.AbstractValue: + assert isinstance(aval, core.ShapedArray) + names = _spec_to_names(spec) + new_shape = tuple(sz * prod(mesh.shape[n] for n in names.get(i, ())) + for i, sz in enumerate(aval.shape)) + 
names_spec = spec._normalized_spec_for_aval(aval.ndim) + if aval.ndim == 0: + out_spec = P() + else: + out_spec = [] # type: ignore + for name_s, aval_s in zip(names_spec, aval.sharding.spec): + if name_s and not aval_s: + out_spec.append(name_s) + elif aval_s and not name_s: + out_spec.append(aval_s) + elif not name_s and not aval_s: + out_spec.append(None) + else: + assert name_s and aval_s + name_s = name_s if isinstance(name_s, tuple) else (name_s,) + aval_s = aval_s if isinstance(aval_s, tuple) else (aval_s,) + out_spec.append(name_s + aval_s) + out_spec = PartitionSpec(*out_spec) + new_mesh = (mesh.abstract_mesh if get_abstract_mesh().empty else + get_abstract_mesh()) + new_sharding = NamedSharding(new_mesh, out_spec) + manual_axes = set(new_mesh.manual_axes) + vma = (frozenset(v for v in aval.vma if v in manual_axes) + if check_vma else frozenset()) + return aval.update(shape=new_shape, sharding=new_sharding, vma=vma) +core.unshard_aval_handlers[core.ShapedArray] = _unshard_shaped_array + +# Type-checking + +def _shard_map_typecheck(_, *in_atoms, jaxpr, mesh, in_specs, out_specs, + check_vma, manual_axes): + # TODO(mattjj,parkers): check auto + for v, x, in_spec in zip(jaxpr.invars, in_atoms, in_specs): + if not core.typecompat(v.aval, _shard_aval( + mesh, manual_axes, check_vma, in_spec, x.aval)): + raise core.JaxprTypeError("shard_map argument avals not compatible with " + "jaxpr binder avals and in_specs") + with _extend_axis_env(mesh, manual_axes), config._check_vma(check_vma): + core.check_jaxpr(jaxpr) + if check_vma: + out_vma = [v.aval.vma for v in jaxpr.outvars] + for vma, out_spec in zip(out_vma, out_specs): + if not _valid_repeats(mesh, vma, out_spec): + raise core.JaxprTypeError( + "shard_map can't prove output is sufficiently replicated") + out_avals_sharded = [x.aval for x in jaxpr.outvars] + out_avals = map(partial(_unshard_aval, mesh, check_vma), out_specs, + out_avals_sharded) + effs = core.filter_named_axis_effects(jaxpr.effects, mesh.axis_names) + return out_avals, effs +core.custom_typechecks[shard_map_p] = _shard_map_typecheck + + +def _valid_repeats(mesh: Mesh, vma: Set[AxisName], spec) -> bool: + um = set(_unmentioned(mesh, spec)) - set(mesh.manual_axes) + if any(u in vma for u in um): + return False + return True + +# Lowering + +def _shardy_shard_map_sharding( + ctx: mlir.LoweringRuleContext, mesh, manual_axes, spec, aval_in +) -> sharding_impls.SdyArray: + ns = _make_scoped_manual_sharding(ctx, mesh, spec) + if dtypes.issubdtype(aval_in.dtype, dtypes.extended): + ns = sharding_impls.physical_sharding(aval_in, ns) + aval_in = core.physical_aval(aval_in) + sdy_sharding = ns._to_sdy_sharding(aval_in.ndim) + if len(manual_axes) < len(mesh.axis_names): + for dim_sharding in sdy_sharding.dim_shardings: + dim_sharding.is_open = True + return sdy_sharding + + +def _shardy_shard_map_token_sharding( + ctx: mlir.LoweringRuleContext, mesh + ) -> ir.Attribute: + ns = _make_scoped_manual_sharding(ctx, mesh, P()) + return ns._to_sdy_sharding(0) + + +def _get_spmdaxis_ctx_mesh(mesh): + if isinstance(mesh, AbstractMesh): + concrete_mesh = get_concrete_mesh() + return concrete_mesh if concrete_mesh is not None else mesh + return mesh + + +def _shard_map_lowering_shardy( + ctx, in_nodes, jaxpr, mesh, in_specs, out_specs, manual_axes, check_vma): + axis_ctx = ctx.module_context.axis_context + in_avals_ = [v.aval for v in jaxpr.invars] + if isinstance(axis_ctx, sharding_impls.SPMDAxisContext): + # Nested `ManualComputationOp`s cannot refer to axes that are already + # manual. 
So figure out what axes are free thus far. + shardy_manual_axes = frozenset(mesh.axis_names) - axis_ctx.manual_axes + else: + shardy_manual_axes = manual_axes + new_axis_context = sharding_impls.SPMDAxisContext( + _get_spmdaxis_ctx_mesh(mesh), manual_axes) + sub_ctx = ctx.module_context.replace(axis_context=new_axis_context) + + tokens = [ctx.tokens_in.get(eff) for eff in ctx.tokens_in.effects()] + num_tokens = len(tokens) + manual_axes = order_wrt_mesh(mesh, shardy_manual_axes) + if np.prod([mesh.shape[a] for a in manual_axes]) == 1: + # No need for a `ManualComputationOp` if all manual axes are size 1. + with _extend_axis_env(mesh, manual_axes), config._check_vma(check_vma): + out_nodes, tokens_out = mlir.jaxpr_subcomp( + sub_ctx, jaxpr, ctx.name_stack, + mlir.TokenSet(zip(ctx.tokens_in.effects(), tokens)), + (), *in_nodes, + dim_var_values=ctx.dim_var_values) + ctx.set_tokens_out(tokens_out) + return out_nodes + + in_shardings = list( + map(partial(_shardy_shard_map_sharding, ctx, mesh, manual_axes), + in_specs, ctx.avals_in)) + num_dim_vars = len(ctx.dim_var_values) + in_shardings = ([_shardy_shard_map_token_sharding(ctx, mesh)] + * (num_tokens + num_dim_vars) + in_shardings) + in_shardings = sharding_impls.SdyArrayList(in_shardings).build() + + out_shardings = list( + map(partial(_shardy_shard_map_sharding, ctx, mesh, manual_axes), + out_specs, ctx.avals_out)) + out_shardings = [ + _shardy_shard_map_token_sharding(ctx, mesh)] * num_tokens + out_shardings + out_shardings = sharding_impls.SdyArrayList(out_shardings).build() + + output_types = ([hlo.TokenType.get()] * num_tokens + + list(map(mlir.aval_to_ir_type, ctx.avals_out))) + + args = (*ctx.dim_var_values, *tokens, *in_nodes) + manual_computation_op = sdy.ManualComputationOp( + output_types, + mlir.flatten_ir_values(args), + in_shardings, out_shardings, + sdy.ManualAxesAttr.get( + ir.ArrayAttr.get([ir.StringAttr.get(i) for i in manual_axes]))) + block = ir.Block.create_at_start( + manual_computation_op.body, + (*(i if isinstance(i, ir.Type) else i.type for i in ctx.dim_var_values), + *([hlo.TokenType.get()] * num_tokens), + *map(mlir.aval_to_ir_type, in_avals_))) + with (ir.InsertionPoint(block), _extend_axis_env(mesh, manual_axes), + config._check_vma(check_vma)): + out_nodes_, tokens_out = mlir.jaxpr_subcomp( + sub_ctx, jaxpr, ctx.name_stack, + mlir.TokenSet(zip( + ctx.tokens_in.effects(), block.arguments[:num_tokens])), + (), *block.arguments[num_tokens+num_dim_vars:], + dim_var_values=ctx.dim_var_values) + sdy.ReturnOp([ir.Value(x) for x in (*[v for _, v in tokens_out.items()], + *out_nodes_)]) + num_tokens = len(tokens_out.effects()) + tokens_out = tokens_out.update_tokens(mlir.TokenSet(zip( + ctx.tokens_in.effects(), manual_computation_op.results[:num_tokens]))) + ctx.set_tokens_out(tokens_out) + + return manual_computation_op.results[num_tokens:] + + +def _shard_map_lowering(ctx, *in_nodes, jaxpr, mesh, in_specs, out_specs, + check_vma, manual_axes): + if config.use_shardy_partitioner.value: + return _shard_map_lowering_shardy( + ctx, in_nodes, jaxpr, mesh, in_specs, out_specs, manual_axes, check_vma) + + in_avals_ = [v.aval for v in jaxpr.invars] + out_avals_ = [x.aval for x in jaxpr.outvars] + in_nodes_ = map(partial(_xla_shard, ctx, mesh, manual_axes), in_specs, + ctx.avals_in, in_avals_, in_nodes) + new_axis_context = sharding_impls.SPMDAxisContext( + _get_spmdaxis_ctx_mesh(mesh), manual_axes) + sub_ctx = ctx.module_context.replace(axis_context=new_axis_context) + with _extend_axis_env(mesh, manual_axes), 
config._check_vma(check_vma): + out_nodes_, tokens_out = mlir.call_lowering( + "shmap_body", ctx.name_stack, jaxpr, None, sub_ctx, in_avals_, + out_avals_, ctx.tokens_in, *in_nodes_, + dim_var_values=ctx.dim_var_values, + arg_names=map(_pspec_mhlo_attrs, in_specs, in_avals_), + result_names=map(_pspec_mhlo_attrs, out_specs, out_avals_)) + ctx.set_tokens_out(tokens_out) + return map(partial(_xla_unshard, ctx, mesh, manual_axes), out_specs, + out_avals_, ctx.avals_out, out_nodes_) +mlir.register_lowering(shard_map_p, _shard_map_lowering) + +def _make_scoped_manual_sharding(ctx, mesh, spec): + axis_ctx = ctx.module_context.axis_context + mesh = mesh.abstract_mesh + if isinstance(axis_ctx, sharding_impls.SPMDAxisContext): + mesh = mesh.update_axis_types( + {a: AxisType.Manual for a in axis_ctx.manual_axes}) + return NamedSharding(mesh, spec) + +def _xla_shard(ctx: mlir.LoweringRuleContext, mesh, manual_axes, spec, + aval_in, aval_out, x): + if prod([size for n, size in mesh.shape.items() if n in manual_axes]) == 1: + return x + ns = _make_scoped_manual_sharding(ctx, mesh, spec) + if dtypes.issubdtype(aval_in.dtype, dtypes.extended): + ns = sharding_impls.physical_sharding(aval_in, ns) + aval_in = core.physical_aval(aval_in) + shard_proto = ns._to_xla_hlo_sharding(aval_in.ndim).to_proto() + unspecified = (set(range(aval_in.ndim)) + if len(manual_axes) < len(mesh.axis_names) else set()) + sx = mlir.wrap_with_sharding_op(ctx, x, aval_in, shard_proto, + unspecified_dims=unspecified) + manual_proto = pxla.manual_proto( + aval_in, manual_axes | set(mesh.manual_axes), mesh) + return mlir.wrap_with_full_to_shard_op(ctx, sx, aval_out, manual_proto, + unspecified) + +def _xla_unshard(ctx: mlir.LoweringRuleContext, mesh, manual_axes, spec, + aval_in, aval_out, x): + if prod([size for n, size in mesh.shape.items() if n in manual_axes]) == 1: + return x + ns = _make_scoped_manual_sharding(ctx, mesh, spec) + if dtypes.issubdtype(aval_out.dtype, dtypes.extended): + ns = sharding_impls.physical_sharding(aval_out, ns) + aval_out = core.physical_aval(aval_out) + unspecified = (set(range(aval_in.ndim)) + if len(manual_axes) < len(mesh.axis_names) else set()) + if dtypes.issubdtype(aval_in.dtype, dtypes.extended): + aval_in = core.physical_aval(aval_in) + manual_proto = pxla.manual_proto( + aval_in, manual_axes | set(mesh.manual_axes), mesh) + sx = mlir.wrap_with_sharding_op(ctx, x, aval_in, manual_proto, + unspecified_dims=unspecified) + shard_proto = ns._to_xla_hlo_sharding(aval_out.ndim).to_proto() + return mlir.wrap_with_shard_to_full_op(ctx, sx, aval_out, shard_proto, + unspecified) + +def _pspec_mhlo_attrs(spec, aval: core.AbstractValue) -> str: + if isinstance(aval, core.ShapedArray): + names = _spec_to_names(spec) + return str(map(names.get, range(aval.ndim))) + return '' + +# Eager evaluation + +def get_mesh_from_args(args_flat, mesh): + for a in args_flat: + if hasattr(a, 'sharding') and isinstance(a.sharding, NamedSharding): + if a.sharding.mesh.shape_tuple != mesh.shape_tuple: + aval = core.shaped_abstractify(a) + raise ValueError( + f"Mesh shape of the input {a.sharding.mesh.shape_tuple} does not" + " match the mesh shape passed to shard_map " + f" {mesh.shape_tuple} for shape {aval.str_short()}") + mesh = a.sharding.mesh + if isinstance(mesh, AbstractMesh): + raise ValueError( + "Please pass `jax.Array`s with a `NamedSharding` as input to" + " `shard_map` when passing `AbstractMesh` to the mesh argument.") + assert isinstance(mesh, Mesh) + return mesh + +def _vma_to_spec(mesh, vma): + return 
P(order_wrt_mesh(mesh, vma)) + +def _spec_to_vma(spec): + return frozenset(p for s in spec if s is not None + for p in (s if isinstance(s, tuple) else (s,))) + +def order_wrt_mesh(mesh, x): + return tuple(a for a in mesh.axis_names if a in x) + +def _shard_map_impl(trace, prim, fun, args, *, mesh, in_specs, out_specs_thunk, + check_vma, manual_axes): + if len(manual_axes) < len(mesh.axis_names): + raise NotImplementedError + del prim + if isinstance(mesh, AbstractMesh): + concrete_mesh = get_concrete_mesh() + mesh = concrete_mesh if concrete_mesh is not None else mesh + mesh = get_mesh_from_args(args, mesh) + cur_mesh = get_abstract_mesh() + args = map(partial(_unmatch_spec, mesh, check_vma, context_mesh=cur_mesh), + in_specs, args) + in_vma = map(_spec_to_vma, in_specs) + outs, out_vma = _run_shmap(fun, mesh, manual_axes, args, in_vma, check_vma, + cur_mesh) + out_avals = [core.mapped_aval(x.shape[0], 0, core.get_aval(x)) for x in outs] + _check_names(out_specs_thunk(), out_avals) # pytype: disable=wrong-arg-types + if check_vma: + _check_vmas(mesh, out_specs_thunk(), out_vma) + src_pspecs = tuple(_vma_to_spec(mesh, r) for r in out_vma) + else: + src_pspecs = tuple(P(mesh.axis_names) for _ in out_vma) + dst_pspecs = out_specs_thunk() + return map(partial(_match_spec, mesh, check_vma), src_pspecs, dst_pspecs, + outs) +core.EvalTrace.process_shard_map = _shard_map_impl + +def _run_shmap(f, mesh, manual_axes, args, vmas, check_vma, context_mesh): + trace = ShardMapTrace(mesh, manual_axes, check_vma, context_mesh) + in_tracers = map(partial(ShardMapTracer, trace), vmas, args) + inner_mesh = _as_manual_mesh(mesh, manual_axes | set(mesh.manual_axes)) + with (core.set_current_trace(trace), _extend_axis_env(mesh, manual_axes), + use_abstract_mesh(inner_mesh), config._check_vma(check_vma)): + ans = f.call_wrapped(*in_tracers) + outs, out_vma = unzip2(map(trace.to_val_vma_pair, ans)) + return outs, out_vma + + +def _unmatch_spec(mesh: Mesh, check_vma, in_spec, x: JaxType, context_mesh + ) -> JaxType: + with (core.eval_context(), jax.disable_jit(False), + use_abstract_mesh(context_mesh)): + return jax.jit(HashablePartial(_unmatch, mesh, check_vma, in_spec))(x) + +def _unmatch(mesh, check_vma, in_spec, x): + if check_vma: + used_axes = _spec_to_vma(in_spec) + dst = P(order_wrt_mesh(mesh, used_axes)) + else: + dst = P(mesh.axis_names) + check_vma = False + return shard_map(_add_singleton, mesh=mesh, in_specs=(in_spec,), + out_specs=dst, check_vma=check_vma)(x) + +def _check_names(specs, avals: Sequence[core.ShapedArray]) -> None: + fail = [a if sp and len(sp) > a.ndim else no_fail + for sp, a in zip(specs, avals)] + if any(f is not no_fail for f in fail): + raise _SpecError(fail) + +class _SpecError(Exception): + pass + +def _check_vmas(mesh, specs, vmas): + fail = [vma if not _valid_repeats(mesh, vma, sp) else no_fail + for sp, vma in zip(specs, vmas)] + if any(f is not no_fail for f in fail): + raise _RepError(fail) + +class _RepError(Exception): + pass + +def _match_spec(mesh: Mesh, check_vma, src_pspec: PartitionSpec, + dst_pspec: PartitionSpec, x: JaxType) -> JaxType: + fn = HashablePartial(_match, mesh, check_vma, src_pspec, dst_pspec) + with core.eval_context(), jax.disable_jit(False): + return jax.jit(fn, out_shardings=NamedSharding(mesh, dst_pspec))(x) + +def _match(mesh, check_vma, src_pspec, dst_pspec, x): + return shard_map(_rem_singleton, mesh=mesh, in_specs=src_pspec, + out_specs=dst_pspec, check_vma=check_vma)(x) + +def _rem_singleton(x): return jnp.squeeze(x, axis=0) +def 
_add_singleton(x): return jnp.expand_dims(x, axis=0) + +def _maybe_check_special(outs): + if not config.debug_nans.value and not config.debug_infs.value: return + bufs = [s.data for leaf in tree_leaves(outs) + for s in getattr(leaf, 'addressable_shards', [])] + try: + dispatch.check_special('shard_map', bufs) + except api_util.InternalFloatingPointError as e: + raise FloatingPointError(f'Invalid value ({e.ty}) encountered in sharded computation.') from None + +class ShardMapTrace(core.Trace): + __slots__ = ("mesh", "manual_axes", "check", "context_mesh") + + mesh: Mesh + manual_axes: frozenset[AxisName] + check: bool + context_mesh: AbstractMesh + + def __init__(self, mesh, manual_axes, check, context_mesh): + super().__init__() + self.mesh = mesh + self.manual_axes = manual_axes + self.check = check + self.context_mesh = context_mesh + + def to_val_vma_pair(self, val): + if isinstance(val, ShardMapTracer): + return val.val, val.vma + elif isinstance(val, Tracer): + raise Exception(f"Shouldn't have any non-shard_map tracers: {val}") + else: + val_ = _unmatch_spec(self.mesh, self.check, P(), val, self.context_mesh) + return val_, frozenset() + + def process_primitive(self, prim, tracers, params): + in_vals, in_vma = unzip2(map(self.to_val_vma_pair, tracers)) + if self.check: + out_avals, _ = prim.abstract_eval(*(typeof(t) for t in tracers), **params) + out_avals = tuple(out_avals) if type(out_avals) is list else out_avals + out_vma = tree_map(lambda a: a.vma, out_avals) + in_specs = tuple(map(partial(_vma_to_spec, self.mesh), in_vma)) + out_specs = tree_map(partial(_vma_to_spec, self.mesh), out_vma) + else: + out_vma = frozenset() + in_specs = out_specs = P(self.mesh.axis_names) + + eager_rule = eager_rules.get(prim) + if eager_rule: + out_vals = eager_rule(self.mesh, *in_vals, **params) + else: + f = HashablePartial( + _prim_applier, prim, self.check, tuple(params.items()), self.mesh, + in_specs, out_specs) + with (core.eval_context(), jax.disable_jit(False), jax.debug_nans(False), + jax.debug_infs(False), use_abstract_mesh(self.context_mesh)): + out_vals = jax.jit(f)(*in_vals) + _maybe_check_special(out_vals) + if prim.multiple_results: + out_vma = (out_vma if isinstance(out_vma, (list, tuple)) + else [out_vma] * len(out_vals)) + return map(partial(ShardMapTracer, self), out_vma, out_vals) + return ShardMapTracer(self, out_vma, out_vals) + + def process_call(self, call_primitive, fun, tracers, params): + raise NotImplementedError( + f"Eager evaluation of `{call_primitive}` inside a `shard_map` isn't " + "yet supported. Put a `jax.jit` around the `shard_map`-decorated " + "function, and open a feature request at " + "https://github.com/jax-ml/jax/issues !") + + def process_map(self, map_primitive, fun, tracers, params): + raise NotImplementedError( + "Eager evaluation of `pmap` inside a `shard_map` isn't yet supported." + "Put a `jax.jit` around the `shard_map`-decorated function, and open " + "a feature request at https://github.com/jax-ml/jax/issues !") + + def process_custom_jvp_call(self, prim, fun, jvp, tracers, *, symbolic_zeros): + # Since ShardMapTrace is only used as a base main, we can drop the jvp. 
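# (Descriptive note: under eager shard_map only the primal computation is
# evaluated at this trace level, so `fun` is run directly via _run_shmap below
# and the custom `jvp` rule can be discarded safely.)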
+ del prim, jvp, symbolic_zeros + in_vals, in_vma = unzip2(map(self.to_val_vma_pair, tracers)) + out_vals, out_vma = _run_shmap(fun, self.mesh, self.manual_axes, in_vals, + in_vma, self.check, self.context_mesh) + return map(partial(ShardMapTracer, self), out_vma, out_vals) + + def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees, + symbolic_zeros): + if symbolic_zeros: + msg = ("custom_vjp symbolic_zeros support with shard_map is not " + "implemented; please open an issue at " + "https://github.com/jax-ml/jax/issues") + raise NotImplementedError(msg) + del prim, fwd, bwd, out_trees, symbolic_zeros + in_vals, in_vma = unzip2(map(self.to_val_vma_pair, tracers)) + out_vals, out_vma = _run_shmap(fun, self.mesh, self.manual_axes, in_vals, + in_vma, self.check, self.context_mesh) + return map(partial(ShardMapTracer, self), out_vma, out_vals) + + +class ShardMapTracer(core.Tracer): + vma: frozenset[AxisName] + val: JaxType + + def __init__(self, trace, vma, val): + self._trace = trace + if isinstance(vma, set): + vma = frozenset(vma) + assert isinstance(vma, frozenset) + self.vma = vma + self.val = val + + @property + def aval(self): + aval = core.get_aval(self.val) + out = core.mapped_aval(self._trace.mesh.size, 0, aval) + new_sharding = NamedSharding( + _as_manual_mesh(self._trace.mesh, self._trace.manual_axes), + out.sharding.spec) # pytype: disable=attribute-error + vma = self.vma if config._check_vma.value else frozenset() + return out.update(sharding=new_sharding, vma=vma) + + def to_concrete_value(self): + if self.vma == frozenset(): + with core.eval_context(), use_abstract_mesh(self._trace.context_mesh): + return core.to_concrete_value(self.val[0]) + else: + return None + + def __str__(self) -> str: + pb_names = set(self._trace.mesh.axis_names) - self.vma + self = pvary(self, tuple(pb_names)) + with core.eval_context(), use_abstract_mesh(self._trace.context_mesh): + blocks = list(self.val) + mesh = self._trace.mesh + axis_names = f"({', '.join(map(str, mesh.axis_names))},)" + return '\n'.join( + f"On {device} at mesh coordinates {axis_names} = {idx}:\n{block}\n" + for (idx, device), block in zip(np.ndenumerate(mesh.devices), blocks)) + + __repr__ = __str__ # for debuggers, like `p x` + +def _prim_applier(prim, check_vma, params_tup, mesh, in_specs, out_specs, *args): + def apply(*args): + outs = prim.bind(*map(_rem_singleton, args), **dict(params_tup)) + return tree_map(_add_singleton, outs) + out_specs = list(out_specs) if type(out_specs) is tuple else out_specs + return shard_map(apply, mesh=mesh, in_specs=in_specs, out_specs=out_specs, + check_vma=check_vma)(*args) + +eager_rules: dict[core.Primitive, Callable] = {} + + +# TODO(mattjj): working around an apparent XLA or PjRt bug, remove eventually +def _debug_callback_eager_rule( + mesh, + *args, + callback: Callable[..., Any], + effect: debugging.DebugEffect, + partitioned: bool, +): + del effect + with core.eval_context(): + all_blocks = zip(*map(list, args)) + for (idx, device), blocks in zip(np.ndenumerate(mesh.devices), all_blocks): + callback(*blocks) + return [] + + +eager_rules[debugging.debug_callback_p] = _debug_callback_eager_rule + +def _device_put_eager_rule(mesh, *xs, srcs, devices, copy_semantics): + del mesh, srcs, copy_semantics + for device in devices: + if device is not None: + raise ValueError("device_put with explicit device not allowed within " + f"shard_map-decorated functions, but got device {device}") + return xs +eager_rules[dispatch.device_put_p] = _device_put_eager_rule + + +# Batching 
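The batching rules in this section let `jax.vmap` move a batch axis through a `shard_map`-decorated function by rewriting its in/out specs (the batched dimension gets a `None` entry unless a `spmd_axis_name` or an explicit mesh axis is supplied). A minimal usage sketch follows; the mesh layout, the axis name 'x', the shapes, and the `jax.experimental.shard_map` import path are illustrative assumptions, not part of this patch:

    import jax
    import jax.numpy as jnp
    from functools import partial
    from jax.sharding import Mesh, PartitionSpec as P
    from jax.experimental.shard_map import shard_map  # import path varies by JAX version

    mesh = Mesh(jax.devices(), ('x',))  # 1-D device mesh named 'x' (assumed layout)

    @partial(shard_map, mesh=mesh, in_specs=P('x'), out_specs=P('x'))
    def double(block):
      # each device sees its own block of the sharded axis
      return 2.0 * block

    xs = jnp.ones((4, 2 * len(jax.devices())))
    # vmap adds a leading batch axis; the batching rule threads it through the
    # shard_map by inserting None into the specs for that dimension
    ys = jax.jit(jax.vmap(double))(xs)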
+ +def _modify_specs_axis_data(trace, name, mesh, in_specs, in_dims): + new_in_specs = [sp if d is batching.not_mapped else pxla.batch_spec(sp, d, name) + for sp, d in zip(in_specs, in_dims)] + new_size = trace.axis_data.size // prod(mesh.shape[n] for n in name) + new_axis_data = batching.AxisData( + trace.axis_data.name, new_size, trace.axis_data.spmd_name, + trace.axis_data.explicit_mesh_axis) + return new_in_specs, new_axis_data + +def _shard_map_batch( + trace: batching.BatchTrace, prim: core.Primitive, fun: lu.WrappedFun, + in_tracers: Sequence[batching.BatchTracer], mesh: Mesh, + in_specs, out_specs_thunk, check_vma: bool, manual_axes: frozenset + ) -> Sequence[batching.BatchTracer]: + in_vals, in_dims = unzip2(map(trace.to_batch_info, in_tracers)) + if any(isinstance(d, batching.RaggedAxis) for d in in_dims): + raise NotImplementedError + spmd_axis_name = trace.axis_data.spmd_name + explicit_mesh_axis = trace.axis_data.explicit_mesh_axis + if spmd_axis_name is not None: + used = {n for spec in in_specs for n in _spec_to_vma(spec)} + if not config.disable_vmap_shmap_error.value and set(spmd_axis_name) & used: + raise ValueError("vmap spmd_axis_name cannot appear in shard_map in_specs") + new_in_specs, new_axis_data = _modify_specs_axis_data( + trace, spmd_axis_name, mesh, in_specs, in_dims) + elif explicit_mesh_axis is not None: + used = {n for spec in in_specs for n in _spec_to_vma(spec)} + if set(explicit_mesh_axis) & used: + raise ValueError("vmapped away explicit mesh axis cannot appear in " + "shard_map in_specs") + new_in_specs, new_axis_data = _modify_specs_axis_data( + trace, explicit_mesh_axis, mesh, in_specs, in_dims) + else: + new_in_specs = [sp if d is batching.not_mapped else pxla.batch_spec(sp, d, None) + for sp, d in zip(in_specs, in_dims)] + new_axis_data = trace.axis_data + fun, out_dims = batching.batch_subtrace(fun, trace.tag, new_axis_data, tuple(in_dims)) + + @as_hashable_function(closure=out_specs_thunk) + def new_out_specs_thunk(): + return _batch_out_specs(spmd_axis_name, explicit_mesh_axis, out_dims(), + out_specs_thunk()) + + new_params = dict(mesh=mesh, in_specs=new_in_specs, + out_specs_thunk=new_out_specs_thunk, check_vma=check_vma, + manual_axes=manual_axes) + with core.set_current_trace(trace.parent_trace): + out_vals = prim.bind(fun, *in_vals, **new_params) + make_tracer = partial(batching.BatchTracer, trace, + source_info=source_info_util.current()) + return map(make_tracer, out_vals, out_dims()) +batching.BatchTrace.process_shard_map = _shard_map_batch + +def _batch_out_specs(spmd_name, explicit_mesh_axis, dims, out_specs): + if spmd_name is not None: + used = {n for spec in out_specs for n in _spec_to_vma(spec)} + if not config.disable_vmap_shmap_error.value and set(spmd_name) & used: + raise ValueError("vmap spmd_axis_name cannot appear in shard_map out_specs") + return [sp if d is batching.not_mapped else pxla.batch_spec(sp, d, spmd_name) + for sp, d in zip(out_specs, dims)] + elif explicit_mesh_axis is not None: + used = {n for spec in out_specs for n in _spec_to_vma(spec)} + if set(explicit_mesh_axis) & used: + raise ValueError("vmapped away explicit mesh axis cannot appear in " + "shard_map out_specs") + return [sp if d is batching.not_mapped else + pxla.batch_spec(sp, d, explicit_mesh_axis) + for sp, d in zip(out_specs, dims)] + else: + return [sp if d is batching.not_mapped else pxla.batch_spec(sp, d, None) + for sp, d in zip(out_specs, dims)] + + +# Autodiff + +def _shard_map_jvp(trace, shard_map_p, f, tracers, mesh, in_specs, + 
out_specs_thunk, check_vma, manual_axes): + primals, tangents = unzip2(map(trace.to_primal_tangent_pair, tracers)) + which_nz = [ type(t) is not ad.Zero for t in tangents] + tangents = [t if type(t) is not ad.Zero else None for t in tangents] + args, in_tree = tree_flatten((primals, tangents)) + f_jvp = ad.jvp_subtrace(f, trace.tag) + f_jvp, which_nz_out = ad.nonzero_tangent_outputs(f_jvp) + tangent_in_specs = [sp for sp, nz in zip(in_specs, which_nz) if nz] + + @as_hashable_function(closure=out_specs_thunk) + def new_out_specs_thunk(): + out_ax = out_specs_thunk() + return (*out_ax, *(ax for ax, nz in zip(out_ax, which_nz_out()) if nz)) + params = dict(mesh=mesh, in_specs=(*in_specs, *tangent_in_specs), + out_specs_thunk=new_out_specs_thunk, check_vma=check_vma, + manual_axes=manual_axes) + f_jvp, out_tree = ad.traceable(f_jvp, in_tree) + result = shard_map_p.bind_with_trace(trace.parent_trace, (f_jvp,) + tuple(args), params) + primal_out, tangent_out = tree_unflatten(out_tree(), result) + tangent_out = [ad.Zero(core.get_aval(p).to_tangent_aval()) if t is None else t + for p, t in zip(primal_out, tangent_out)] + return [ad.JVPTracer(trace, p, t) for p, t in zip(primal_out, tangent_out)] +ad.JVPTrace.process_shard_map = _shard_map_jvp + +def _shard_map_partial_eval(trace: pe.JaxprTrace, shard_map_p, + f: lu.WrappedFun, tracers, mesh, in_specs, + out_specs_thunk, check_vma, manual_axes): + tracers = map(trace.to_jaxpr_tracer, tracers) + in_pvals = [t.pval for t in tracers] + in_knowns, in_avals, in_consts = pe.partition_pvals(in_pvals) + unk_in_specs, known_in_specs = pe.partition_list(in_knowns, in_specs) + in_avals_sharded = map(partial(_shard_aval, mesh, manual_axes, check_vma), + unk_in_specs, in_avals) + f = pe.trace_to_subjaxpr_nounits_fwd2(f, trace.tag, f.debug_info, False) + f = _promote_scalar_residuals(f) + f_known, aux = pe.partial_eval_wrapper_nounits2( + f, (*in_knowns,), (*in_avals_sharded,)) + all_names = _all_newly_manual_mesh_names(mesh, manual_axes) + + @as_hashable_function(closure=out_specs_thunk) + def known_out_specs(): + _, _, out_knowns, res_avals, _, _ = aux() + _, out_known_specs = pe.partition_list(out_knowns, out_specs_thunk()) + if check_vma: + res_specs = [P(order_wrt_mesh(mesh, a.vma)) for a in res_avals] + else: + res_specs = [P(all_names)] * len(res_avals) + return (*out_known_specs, *res_specs) + + known_params = dict(mesh=mesh, in_specs=(*known_in_specs,), + out_specs_thunk=known_out_specs, check_vma=check_vma, + manual_axes=manual_axes) + out = shard_map_p.bind_with_trace(trace.parent_trace, (f_known, *in_consts), + known_params) + in_fwd, out_fwd, out_knowns, res_avals, jaxpr, env = aux() + num_res = sum(f1 is None and f2 is None for f1, f2 in zip(in_fwd, out_fwd)) + out_consts, non_fwd_res = split_list(out, [len(out) - num_res]) + assert not jaxpr.constvars + unk_out_specs, _ = pe.partition_list(out_knowns, out_specs_thunk()) + known_out_specs_ = known_out_specs() + res = subs_list2(in_fwd, out_fwd, in_consts, out_consts, non_fwd_res) + # TODO make res_avals be the full set, not just the non-fwd ones + res_avals_iter = iter(res_avals) + res_specs = [] + for f1, f2 in zip(in_fwd, out_fwd): + if f1 is not None: + res_specs.append(known_in_specs[f1]) + elif f2 is not None: + res_specs.append(known_out_specs_[f2]) + else: + if check_vma: + res_vma = next(res_avals_iter).vma + res_specs.append(P(order_wrt_mesh(mesh, res_vma))) + else: + res_specs.append(P(all_names)) + unk_in_specs = (*res_specs,) + (P(),) * len(env) + (*unk_in_specs,) # type: 
ignore[assignment] + const_tracers = map(trace.new_instantiated_const, res) + env_tracers = map(trace.to_jaxpr_tracer, env) + unk_arg_tracers = [t for t in tracers if not t.is_known()] + out_avals_sharded = [v.aval for v in jaxpr.outvars] + unk_params = dict(mesh=mesh, in_specs=unk_in_specs, + out_specs=unk_out_specs, jaxpr=jaxpr, + check_vma=check_vma, manual_axes=manual_axes) + out_avals = map(partial(_unshard_aval, mesh, check_vma), unk_out_specs, + out_avals_sharded) + out_tracers = [pe.JaxprTracer(trace, pe.PartialVal.unknown(a), None) + for a in out_avals] + effs = core.filter_named_axis_effects(jaxpr.effects, mesh.axis_names) + eqn = pe.new_eqn_recipe(trace, (*const_tracers, *env_tracers, *unk_arg_tracers), + out_tracers, shard_map_p, unk_params, + effs, source_info_util.current()) + for t in out_tracers: t.recipe = eqn + return merge_lists(out_knowns, out_tracers, out_consts) +pe.JaxprTrace.process_shard_map = _shard_map_partial_eval + +def _shard_map_linearize(trace, shard_map_p, f: lu.WrappedFun, + tracers, mesh, in_specs, out_specs_thunk, check_vma, + manual_axes): + primals, tangents = unzip2(map(trace.to_primal_tangent_pair, tracers)) + nzs_in = tuple(type(t) is not ad.Zero for t in tangents) + f_primal, linearize_outs_thunk = ad.linearize_subtrace(f, trace.tag, nzs_in, f.debug_info) + f_primal = _promote_scalar_residuals_lin(f_primal, linearize_outs_thunk) + all_names = _all_newly_manual_mesh_names(mesh, manual_axes) + + @as_hashable_function(closure=linearize_outs_thunk) + def fwd_out_specs_thunk(): + res_avals, _, _, _, in_fwd, out_fwd = linearize_outs_thunk() + res_avals = [r for r, f1, f2 in zip(res_avals, in_fwd, out_fwd) + if f1 is None and f2 is None] + out_specs = out_specs_thunk() + if check_vma: + res_specs = [P(order_wrt_mesh(mesh, a.vma)) for a in res_avals] + else: + res_specs = [P(all_names)] * len(res_avals) + return (*res_specs, *out_specs) + fwd_params = dict( + mesh=mesh, in_specs=in_specs, + out_specs_thunk=fwd_out_specs_thunk, check_vma=check_vma, + manual_axes=manual_axes) + all_fwd_results = shard_map_p.bind_with_trace( + trace.parent_trace, (f_primal, *primals), fwd_params) + res_avals, nzs_out, lin_jaxpr, env, in_fwd, out_fwd = linearize_outs_thunk() + num_res_out = sum(f1 is None and f2 is None for f1, f2 in zip(in_fwd, out_fwd)) + non_fwd_res = all_fwd_results[:num_res_out] + primals_out = all_fwd_results[num_res_out:] + residuals = subs_list2(in_fwd, out_fwd, primals, primals_out, non_fwd_res) + args_to_promote = [getattr(aval, 'shape', ()) == () and f1 is None and f2 is None + for aval, f1, f2 in zip(res_avals, in_fwd, out_fwd)] + with (_extend_axis_env(mesh, manual_axes), + use_abstract_mesh(_as_manual_mesh(mesh, manual_axes | set(mesh.manual_axes))), + config._check_vma(check_vma)): + lin_jaxpr = _promote_scalar_residuals_jaxpr(lin_jaxpr, args_to_promote) + out_specs = out_specs_thunk() + res_avals2 = [r for r, f1, f2 in zip(res_avals, in_fwd, out_fwd) + if f1 is None and f2 is None] + res_avals_iter = iter(res_avals2) + res_specs = [] + for f1, f2 in zip(in_fwd, out_fwd): + if f1 is not None: + res_specs.append(in_specs[f1]) + elif f2 is not None: + res_specs.append(out_specs[f2]) + else: + if check_vma: + res_vma = next(res_avals_iter).vma + res_specs.append(P(order_wrt_mesh(mesh, res_vma))) + else: + res_specs.append(P(all_names)) + new_in_specs = (*res_specs, *(P(),) * len(env), + *(ax for ax, nz in zip(in_specs, nzs_in) if nz)) + tangent_out_specs = tuple(ax for ax, nz in zip(out_specs_thunk(), nzs_out) + if nz) + 
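# (res_specs above reuses the shard_map's declared in/out specs for residuals
# that are just forwarded primal inputs or outputs; only genuinely new
# residuals get specs derived from their vma, or from all newly-manual mesh
# axes when check_vma=False.)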
@as_hashable_function(closure=tangent_out_specs) + def tangent_out_specs_thunk(): + return tangent_out_specs + tangent_params = dict( + mesh=mesh, in_specs=new_in_specs, out_specs_thunk=tangent_out_specs_thunk, + check_vma=check_vma, manual_axes=manual_axes) + + # TODO(mattjj): avoid round-tripping the jaxpr through eval_jaxpr here + def f_tangent(*args): + return core.eval_jaxpr(lin_jaxpr, (), *args) + + nz_tangents_in = [t for (t, nz) in zip(tangents, nzs_in) if nz] + nz_tangents_out = shard_map_p.bind_with_trace( + trace.tangent_trace, + (lu.wrap_init(f_tangent, debug_info=lin_jaxpr.debug_info), + *residuals, *env, *nz_tangents_in), tangent_params) + nz_tangents_out_iter = iter(nz_tangents_out) + tangents_out = [next(nz_tangents_out_iter) if nz else ad.Zero.from_primal_value(primal) + for nz, primal in zip(nzs_out, primals_out)] + return map(partial(ad.maybe_linearize_tracer, trace), primals_out, nzs_out, tangents_out) +ad.LinearizeTrace.process_shard_map = _shard_map_linearize + +@lu.transformation2 +def _promote_scalar_residuals_lin(f, linearize_outs_thunk, *args, **kwargs): + ans = f(*args, **kwargs) + _, _, _, _, in_fwd, out_fwd = linearize_outs_thunk() + num_res_out = sum(f1 is None and f2 is None for f1, f2 in zip(in_fwd, out_fwd)) + residuals = ans[:num_res_out] + primals = ans[num_res_out:] + residuals = [jax.lax.broadcast(x, (1,)) if not getattr(x, 'shape', ()) else x + for x in residuals] + return *residuals, *primals + +@lu.transformation2 +def _promote_scalar_residuals(f: Callable, *args, **kwargs): + jaxpr, (in_fwds, out_fwds, out_pvals, out_consts, env) = f(*args, **kwargs) + which = [f1 is None and f2 is None and not v.aval.shape + for f1, f2, v in zip(in_fwds, out_fwds, jaxpr.constvars)] + jaxpr = _promote_scalar_residuals_jaxpr(jaxpr, which) + out_consts = [jax.lax.broadcast(x, (1,)) if not getattr(x, 'shape', ()) else x + for x in out_consts] + return jaxpr, (in_fwds, out_fwds, out_pvals, out_consts, env) + +def _promote_scalar_residuals_jaxpr(jaxpr: core.Jaxpr, which: Sequence[bool]): + def fun(*res_and_args): + res, args = split_list(res_and_args, [len(jaxpr.constvars)]) + res = [_rem_singleton(x) if w else x for x, w in zip(res, which)] + return core.eval_jaxpr(jaxpr, res, *args) + res_avals = [core.unmapped_aval(1, 0, v.aval) if w else v.aval + for v, w in zip(jaxpr.constvars, which)] + in_avals = [*res_avals, *[v.aval for v in jaxpr.invars]] + jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic( + lu.wrap_init(fun, debug_info=jaxpr.debug_info), in_avals) + return jaxpr + + +def _unmentioned2(mesh: Mesh, spec, manual_axes: frozenset[AxisName] + ) -> list[AxisName]: + # We use a filtered-down version of unmentioned to avoid defensive-psum over + # more chips than required in the transpose-no-check-vma case. 
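# Illustrative example (axis names assumed): with mesh axes ('x', 'y') both
# manual and a spec of P('x'), the unmentioned axes are ('y',); with
# check_vma=False the transpose rule below psums input cotangents over exactly
# those axes (and divides output cotangents by their total size).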
+ name_set = _spec_to_vma(spec) + return [n for n in _all_mesh_names_except_spmd(mesh, manual_axes) + if n not in name_set] + + +def _shard_map_transpose(out_cts, *args, + jaxpr: core.Jaxpr, mesh, in_specs, out_specs, + check_vma, manual_axes): + mb_div = lambda x, y: x / y if y != 1 else x + out_cts = [ + ad.Zero(_shard_aval(mesh, manual_axes, check_vma, sp, x.aval)) + if type(x) is ad.Zero else x if check_vma or dtypes.dtype(x) == dtypes.float0 + else mb_div(x, prod(map(mesh.shape.get, _unmentioned2(mesh, sp, manual_axes)))) + for sp, x in zip(out_specs, out_cts) + ] + args = tuple(x if type(x) is not ad.UndefinedPrimal else + ad.UndefinedPrimal( + _shard_aval(mesh, manual_axes, check_vma, sp, x.aval)) + for sp, x in zip(in_specs, args)) + all_args, in_tree = tree_flatten((out_cts, args)) + + def fun_trans_callable(out_cts, args): + # TODO(mattjj): when #26811 lands, delete this and just run backward_pass + in_undef = map(ad.is_undefined_primal, args) + res, undefs = partition_list(in_undef, args) + jaxpr_known, jaxpr_unknown, _, _ = pe.partial_eval_jaxpr_nounits( + pe.close_jaxpr(jaxpr), in_undef, False) + res_reshaped = core.jaxpr_as_fun(jaxpr_known)(*res) + in_cts = ad.backward_pass( + jaxpr_unknown.jaxpr, False, (), (*res_reshaped, *undefs), out_cts + )[len(res_reshaped):] + _, in_ct_specs = partition_list(in_undef, in_specs) + in_cts = [ad.Zero(_unshard_aval(mesh, check_vma, sp, x.aval)) + if type(x) is ad.Zero else x if check_vma + else jax.lax.psum(x, tuple(_unmentioned2(mesh, sp, manual_axes))) + for sp, x in zip(in_ct_specs, in_cts)] + res_zeros = [ad_util.zero_from_primal(r) for r in res] + return merge_lists(in_undef, res_zeros, in_cts) + + fun_trans = lu.wrap_init(fun_trans_callable, debug_info=jaxpr.debug_info) + fun_trans, nz_arg_cts = ad.nonzero_outputs(fun_trans) + fun_trans_flat, out_tree = api_util.flatten_fun_nokwargs(fun_trans, in_tree) + + new_in_specs = ( + [n for n, x in zip(out_specs, out_cts) if type(x) is not ad.Zero] + + [n for n, x in zip(in_specs, args) if type(x) is not ad.UndefinedPrimal]) + + def new_out_specs_thunk(): + return tuple(sp for sp, nz in zip(in_specs, nz_arg_cts()) if nz) + + try: + out_flat = shard_map_p.bind( + fun_trans_flat, *all_args, mesh=mesh, in_specs=tuple(new_in_specs), + out_specs_thunk=new_out_specs_thunk, check_vma=check_vma, + manual_axes=manual_axes) + except (FloatingPointError, ZeroDivisionError) as e: + print("Invalid nan value encountered in the backward pass of a shard_map " + "function. Calling the de-optimized backward pass.") + try: + # TODO(mattjj): Remove this and do `fun_trans.call_wrapped(out_cts, args)` + # in eager mode so that output of shmap are not manual. 
+ with jax.disable_jit(True): + _ = shard_map_p.bind( + fun_trans_flat, *all_args, mesh=mesh, in_specs=tuple(new_in_specs), + out_specs_thunk=new_out_specs_thunk, check_vma=check_vma, + manual_axes=manual_axes) + except (FloatingPointError, ZeroDivisionError) as e2: + raise e2 from None + else: + api_util._raise_no_nan_in_deoptimized(e) + return tree_unflatten(out_tree(), out_flat) +ad.primitive_transposes[shard_map_p] = _shard_map_transpose + +# Remat + +def _partial_eval_jaxpr_custom_rule( + saveable: Callable[..., pe.RematCases_], unks_in: Sequence[bool], + inst_in: Sequence[bool], eqn: core.JaxprEqn +) -> tuple[core.JaxprEqn, core.JaxprEqn, Sequence[bool], Sequence[bool], + list[core.Var]]: + jaxpr, mesh = eqn.params['jaxpr'], eqn.params['mesh'] + check_vma, manual_axes = eqn.params['check_vma'], eqn.params['manual_axes'] + with (_extend_axis_env(mesh, manual_axes), config._check_vma(check_vma), + use_abstract_mesh(_as_manual_mesh(mesh, manual_axes | set(mesh.manual_axes)))): + jaxpr_known, jaxpr_staged, unks_out, inst_out, num_res = \ + pe.partial_eval_jaxpr_custom(jaxpr, unks_in, inst_in, False, False, saveable) + num_out_primals = len(jaxpr_known.outvars) - num_res + in_fwd = pe._jaxpr_forwarding(jaxpr_known)[num_out_primals:] + out_vars, res_vars = split_list(jaxpr_known.outvars, [num_out_primals]) + idx_map = {id(v): i for i, v in enumerate(out_vars)} + out_fwd = [idx_map.get(id(v)) for v in res_vars] + which = [f1 is None and f2 is None for f1, f2 in zip(in_fwd, out_fwd)] + mesh = eqn.params['mesh'] + with (_extend_axis_env(mesh, manual_axes), + use_abstract_mesh(_as_manual_mesh(mesh, manual_axes | set(mesh.manual_axes))), + config._check_vma(check_vma)): + jaxpr_known = pe.prune_jaxpr_outputs(jaxpr_known, [True] * num_out_primals + which) + jaxpr_known, jaxpr_staged = _add_reshapes(which, jaxpr_known, jaxpr_staged) + jaxpr_known = core.remove_named_axis_effects(jaxpr_known, mesh.axis_names) + jaxpr_staged = core.remove_named_axis_effects(jaxpr_staged, mesh.axis_names) + ins_known, _ = partition_list(unks_in, eqn.invars) + out_binders_known, _ = partition_list(unks_out, eqn.outvars) + _, ins_staged = partition_list(inst_in, eqn.invars) + _, out_binders_staged = partition_list(inst_out, eqn.outvars) + newvar = core.gensym() + residuals, staged_in_res_specs = [], [] + for var, w in zip(jaxpr_staged.invars[:num_res], which): + if w: + rn = (P(order_wrt_mesh(mesh, var.aval.vma)) # type: ignore + if check_vma else P(_all_newly_manual_mesh_names(mesh, manual_axes))) + residuals.append(newvar(_unshard_aval(mesh, check_vma, rn, var.aval))) + staged_in_res_specs.append(rn) + if check_vma: + out_res_specs_known = [P(order_wrt_mesh(mesh, var.aval.vma)) # type: ignore + for var, o in zip(res_vars, out_fwd) if o is None] + else: + out_res_specs_known = [ + P(_all_newly_manual_mesh_names(mesh, manual_axes))] * sum(which) + params_known, params_staged = _pe_custom_params( + unks_in, inst_in, map(op.not_, unks_out), inst_out, in_fwd, out_fwd, + out_res_specs_known, staged_in_res_specs, + dict(eqn.params, jaxpr=jaxpr_known), dict(eqn.params, jaxpr=jaxpr_staged)) + eqn_known = pe.new_jaxpr_eqn(ins_known, [*out_binders_known, *residuals], + eqn.primitive, params_known, jaxpr_known.effects, + eqn.source_info, eqn.ctx) + full_res = subs_list2(in_fwd, out_fwd, ins_known, out_binders_known, residuals) + eqn_staged = pe.new_jaxpr_eqn([*full_res, *ins_staged], out_binders_staged, + eqn.primitive, params_staged, + jaxpr_staged.effects, eqn.source_info, eqn.ctx) + assert len(eqn_staged.invars) == 
len(jaxpr_staged.invars) + new_inst = [x for x, inst in zip(eqn.invars, inst_in) + if type(x) is core.Var and not inst] + new_inst += [out_binders_known[f] for f in {i for i in out_fwd if i is not None}] + return eqn_known, eqn_staged, unks_out, inst_out, new_inst + residuals +pe.partial_eval_jaxpr_custom_rules[shard_map_p] = \ + _partial_eval_jaxpr_custom_rule + +def _add_reshapes(which: Sequence[bool], + jaxpr_known: core.Jaxpr, + jaxpr_staged: core.Jaxpr) -> tuple[core.Jaxpr, core.Jaxpr]: + # add singleton axes to residuals which are from jaxpr_known and are scalars + which_ = [w and not v.aval.shape # pytype: disable=attribute-error + for w, v in zip(which, jaxpr_staged.invars[:len(which)])] + if not any(which_): return jaxpr_known, jaxpr_staged + assert not jaxpr_known.constvars and not jaxpr_staged.constvars + + def known(*args): + out = core.eval_jaxpr(jaxpr_known, (), *args) + out_known, res = split_list(out, [len(out) - sum(which)]) + res = [_add_singleton(x) if not x.shape else x for x in res] + return [*out_known, *res] + avals_in = [v.aval for v in jaxpr_known.invars] + jaxpr_known, _, (), () = pe.trace_to_jaxpr_dynamic( + lu.wrap_init(known, debug_info=jaxpr_known.debug_info), avals_in) + + def staged(*args): + res_, ins = split_list(args, [len(which)]) + res = [_rem_singleton(x) if w else x for x, w in zip(res_, which_)] + return core.eval_jaxpr(jaxpr_staged, (), *res, *ins) + res_avals = [core.unmapped_aval(1, 0, v.aval) if w else v.aval + for w, v in zip(which_, jaxpr_staged.invars[:len(which)])] + avals_in = [*res_avals, *[v.aval for v in jaxpr_staged.invars[len(which):]]] + jaxpr_staged, _, (), () = pe.trace_to_jaxpr_dynamic( + lu.wrap_init(staged, debug_info=jaxpr_staged.debug_info), avals_in) + + return jaxpr_known, jaxpr_staged + +def _pe_custom_params(unks_in, inst_in, kept_outs_known, kept_outs_staged, + in_fwd, out_fwd, out_res_specs_known, staged_in_res_specs, + params_known, params_staged): + # prune inputs to jaxpr_known according to unks_in + in_specs_known, _ = partition_list(unks_in, params_known['in_specs']) + _, out_specs_known = partition_list(kept_outs_known, params_known['out_specs']) + out_specs_known = out_specs_known + out_res_specs_known + assert len(out_specs_known) == len(params_known['jaxpr'].outvars) + new_params_known = dict(params_known, in_specs=tuple(in_specs_known), + out_specs=tuple(out_specs_known)) + + # added num_res new inputs to jaxpr_staged, pruning according to inst_in + _, in_specs_staged = partition_list(inst_in, params_staged['in_specs']) + iter_staged = iter(staged_in_res_specs) + res_specs = [in_specs_known[f1] if f1 is not None else + out_specs_known[f2] if f2 is not None else + next(iter_staged) for f1, f2 in zip(in_fwd, out_fwd)] + + in_specs_staged = res_specs + in_specs_staged + _, out_specs_staged = partition_list(kept_outs_staged, params_staged['out_specs']) + new_params_staged = dict(params_staged, in_specs=tuple(in_specs_staged), + out_specs=tuple(out_specs_staged)) + return new_params_known, new_params_staged + +# TODO(mattjj): remove this mechanism when we revise mesh scopes +def _all_mesh_names_except_spmd( + mesh: Mesh, manual_axes: frozenset[AxisName]) -> tuple[AxisName, ...]: + axis_env = core.get_axis_env() + spmd_names = axis_env.spmd_axis_names + return tuple(name for name in mesh.axis_names + if name not in spmd_names and name in manual_axes) + +def _all_newly_manual_mesh_names( + mesh: Mesh, manual_axes: frozenset[AxisName]) -> tuple[AxisName, ...]: + axis_env = core.get_axis_env() + vmap_spmd_names = 
set(axis_env.spmd_axis_names) + if not (ctx_mesh := get_abstract_mesh()).empty: + mesh = ctx_mesh + already_manual_names = set(ctx_mesh.manual_axes) + else: + # TODO(mattjj): remove this mechanism when we revise mesh scopes + already_manual_names = set(axis_env.axis_sizes) # may include vmap axis_names + return tuple(name for name in mesh.axis_names + if (name not in vmap_spmd_names | already_manual_names and + name in manual_axes)) + + +# DCE + +# TODO(mattjj): de-duplicate with pe.dce_jaxpr_call_rule, and/or _pmap_dce_rule? +def _shard_map_dce(used_outputs: list[bool], eqn: core.JaxprEqn + ) -> tuple[list[bool], core.JaxprEqn | None]: + if not any(used_outputs) and not pe.has_effects(eqn): + return [False] * len(eqn.invars), None + mesh = eqn.params["mesh"] + manual_axes = eqn.params["manual_axes"] + check_vma = eqn.params["check_vma"] + with (_extend_axis_env(mesh, manual_axes), config._check_vma(check_vma), + use_abstract_mesh(_as_manual_mesh(mesh, manual_axes | set(mesh.manual_axes)))): + jaxpr, used_inputs = pe.dce_jaxpr(eqn.params['jaxpr'], used_outputs) + if not any(used_inputs) and not any(used_outputs) and not jaxpr.effects: + return used_inputs, None + else: + _, in_specs = partition_list(used_inputs, eqn.params['in_specs']) + _, out_specs = partition_list(used_outputs, eqn.params['out_specs']) + new_params = dict(eqn.params, jaxpr=jaxpr, in_specs=tuple(in_specs), + out_specs=tuple(out_specs)) + effs = core.filter_named_axis_effects(jaxpr.effects, mesh.axis_names) + new_eqn = pe.new_jaxpr_eqn( + [v for v, used in zip(eqn.invars, used_inputs) if used], + [x for x, used in zip(eqn.outvars, used_outputs) if used], + eqn.primitive, new_params, effs, eqn.source_info, eqn.ctx) + return used_inputs, new_eqn +pe.dce_rules[shard_map_p] = _shard_map_dce + +# Implementing pmap in terms of shard_map + +def pmap(f, axis_name=None, *, in_axes=0, out_axes=0, + static_broadcasted_argnums=(), devices=None, backend=None, + axis_size=None, donate_argnums=(), global_arg_shapes=None): + devices = tuple(devices) if devices is not None else devices + axis_name, static_broadcasted_tuple, donate_tuple = _shared_code_pmap( + f, axis_name, static_broadcasted_argnums, donate_argnums, in_axes, out_axes) + + def infer_params(*args, **kwargs): + p = _prepare_pmap(f, in_axes, out_axes, static_broadcasted_tuple, + donate_tuple, devices, backend, axis_size, args, kwargs) + for arg in p.flat_args: + dispatch.check_arg(arg) + mesh = Mesh(_get_devices(p, backend), (axis_name,)) + _pmapped, in_specs, out_specs = _cached_shard_map( + p.flat_fun, mesh, p.in_axes_flat, p.out_axes_thunk, axis_name) + flat_global_args = host_local_array_to_global_array( + p.flat_args, mesh, list(in_specs)) + jitted_f = jax.jit( + _pmapped, + donate_argnums=[i for i, val in enumerate(p.donated_invars) if val]) + return jitted_f, flat_global_args, p.out_tree, mesh, out_specs + + def wrapped(*args, **kwargs): + (jitted_f, flat_global_args, out_tree, mesh, + out_specs) = infer_params(*args, **kwargs) + outs = jitted_f(*flat_global_args) + outs = global_array_to_host_local_array(outs, mesh, out_specs()) + return tree_unflatten(out_tree(), outs) + + def lower(*args, **kwargs): + jitted_f, _, _, _, _ = infer_params(*args, **kwargs) + return jitted_f.lower(*args, **kwargs) + wrapped.lower = lower + + return wrapped + + +@lu.cache +def _cached_shard_map(flat_fun, mesh, in_axes_flat, out_axes_thunk, axis_name): + in_specs = tuple(map(partial(_axis_to_spec, axis_name), in_axes_flat)) + out_specs = lambda: map(partial(_axis_to_spec, axis_name), 
out_axes_thunk()) + fun = _handle_reshapes(flat_fun, in_axes_flat, out_axes_thunk) + return (_shard_map(fun.call_wrapped, mesh=mesh, in_specs=in_specs, + out_specs=out_specs, check_vma=False, + axis_names=set(mesh.axis_names)), + in_specs, out_specs) + +@lu.transformation2 +def _handle_reshapes(f, in_axes, out_axes_thunk, *args, **kwargs): + args = tree_map(lambda x, ax: x if ax is None else jnp.squeeze(x, axis=ax), + list(args), list(in_axes)) + out = f(*args) + return tree_map(lambda x, ax: x if ax is None else jnp.expand_dims(x, axis=ax), + list(out), list(out_axes_thunk())) + +def _axis_to_spec(axis_name, ax): + if isinstance(ax, int): + specs = [None] * ax + [axis_name] + return P(*specs) + elif ax is None: + return P() + else: + raise TypeError(ax) + +def _get_devices(p, backend): + if backend is not None and p.devices is None: + devs = jax.devices(backend=backend) + else: + devs = jax.devices() if p.devices is None else p.devices + if jax.process_count() > 1: + return devs[:p.global_axis_size] + return devs[:p.local_axis_size] diff --git a/jax/_src/sharding.py b/jax/_src/sharding.py index a9bf62b46473..f4b342deafcc 100644 --- a/jax/_src/sharding.py +++ b/jax/_src/sharding.py @@ -36,11 +36,8 @@ def _addressable_devices_indices_map( global_map = sharding.devices_indices_map(global_shape) if sharding.is_fully_addressable: return global_map - if hasattr(sharding, '_internal_device_list'): - return {d: global_map[d] - for d in sharding._internal_device_list.addressable_device_list} - return {d: ind for d, ind in global_map.items() - if d.process_index == d.client.process_index()} + return {d: global_map[d] + for d in sharding._internal_device_list.addressable_device_list} # type: ignore @cache(max_size=4096, trace_context_in_key=False) def common_devices_indices_map( @@ -174,10 +171,7 @@ def devices_indices_map(self, global_shape: Shape) -> Mapping[Device, Index]: def _addressable_device_assignment(self) -> XLADeviceAssignment: if self.is_fully_addressable: return self._device_assignment - if hasattr(self, '_internal_device_list'): - return tuple(self._internal_device_list.addressable_device_list) - return tuple(d for d in self._device_assignment - if d.process_index == d.client.process_index()) + return tuple(self._internal_device_list.addressable_device_list) # type: ignore def shard_shape(self, global_shape: Shape) -> Shape: """Returns the shape of the data on each device. @@ -192,10 +186,6 @@ def is_equivalent_to(self: Sharding, other: Sharding, ndim: int) -> bool: Two shardings are equivalent if they place the same logical array shards on the same devices. - - For example, a :class:`NamedSharding` may be equivalent - to a :class:`PositionalSharding` if both place the same shards of the array - on the same devices. 
""" try: return (are_op_shardings_equal(self._to_xla_hlo_sharding(ndim), diff --git a/jax/_src/sharding_impls.py b/jax/_src/sharding_impls.py index 2bbf913783e3..1d77874c9420 100644 --- a/jax/_src/sharding_impls.py +++ b/jax/_src/sharding_impls.py @@ -33,20 +33,20 @@ from jax._src import xla_bridge as xb from jax._src import mesh_utils from jax._src.lib import xla_client as xc -from jax._src.lib import xla_extension_version from jax._src.lib.mlir.dialects import sdy from jax._src.named_sharding import ( # noqa: F401 - SdyArraySharding, SdyDimSharding, UnspecifiedValue, AUTO, - ParsedPartitionSpec, _check_unique_resources, NamedSharding, UNSPECIFIED, + SdyArray, SdyDim, UnspecifiedValue, AUTO, + _check_unique_resources, NamedSharding, UNSPECIFIED, ArrayMapping, ArrayMappingOrAutoOrUnspecified, get_array_mapping, - array_mapping_to_axis_resources, get_single_pspec, preprocess, - named_sharding_to_xla_hlo_sharding) + array_mapping_to_axis_resources, named_sharding_to_xla_hlo_sharding, + modify_sdy_sharding_wrt_axis_types) from jax._src.op_shardings import ( are_op_shardings_equal, get_num_ways_dim_sharded, is_op_sharding_replicated) from jax._src.partition_spec import PartitionSpec -from jax._src.util import safe_map, safe_zip, use_cpp_class, use_cpp_method +from jax._src.util import safe_zip, use_cpp_class, use_cpp_method import numpy as np +config_ext = xc._xla.config Shape = tuple[int, ...] Device = xc.Device @@ -88,33 +88,19 @@ def device_replica_id_map(sharding, global_shape: Shape) -> Mapping[Device, int] @dataclasses.dataclass -class SdyArrayShardingList: - shardings: Sequence[SdyArraySharding] +class SdyArrayList: + shardings: Sequence[SdyArray] def build(self) -> sdy.TensorShardingPerValueAttr: return sdy.TensorShardingPerValueAttr.get( [sharding.build() for sharding in self.shardings]) -# TODO(yashkatariya): Upstream this into `_to_sdy_sharding` maybe with an extra -# parameter to it `_to_sdy_sharding(self, ndim, modify_wrt_axis_types=False)` -def modify_sdy_sharding_wrt_axis_types(sdy_sharding: SdyArraySharding, mesh): - if mesh._any_axis_auto: - dim_shardings, used_axes = [], [] # type: ignore - for d in sdy_sharding.dimension_shardings: - # TODO(yashkatariya): Maybe if any mesh axis is auto, mark all axes as open? 
- dim_shardings.append(SdyDimSharding(axes=[], is_closed=False) - if not d.axes and d.is_closed else d) - used_axes.extend(d.axes) - remaining_axes = set(mesh.axis_names) - set(used_axes) - replicated_axes = tuple(r for r in remaining_axes - if mesh._name_to_type[r] == mesh_lib.AxisType.Explicit) - return SdyArraySharding(sdy_sharding.mesh_shape, dim_shardings, - sdy_sharding.logical_device_ids, replicated_axes) - return sdy_sharding +replicated_hlo_sharding = xc.HloSharding.replicate() -replicated_hlo_sharding = xc.HloSharding.replicate() +def _unpickle_single_device_sharding(device, memory_kind): + return SingleDeviceSharding(device, memory_kind=memory_kind) @use_cpp_class(xc.SingleDeviceSharding) @@ -139,7 +125,7 @@ def __init__(self, device: Device, *, memory_kind: str | None = None): self._memory_kind = memory_kind def __reduce__(self): - return type(self), (self._device,), {'memory_kind': self._memory_kind} + return (_unpickle_single_device_sharding, (self._device, self._memory_kind)) def __repr__(self): mem = '' if self._memory_kind is None else f', memory_kind={self._memory_kind}' @@ -183,10 +169,10 @@ def _device_assignment(self) -> XLADeviceAssignment: def _to_xla_hlo_sharding(self, num_dimensions: int) -> xc.HloSharding: return replicated_hlo_sharding - def _to_sdy_sharding(self, num_dimensions: int) -> SdyArraySharding: - sdy_dim_sharding = [SdyDimSharding(axes=[], is_closed=True) + def _to_sdy_sharding(self, num_dimensions: int) -> SdyArray: + sdy_dim_sharding = [SdyDim(axes=[], is_open=False) for _ in range(num_dimensions)] - return SdyArraySharding(None, sdy_dim_sharding) + return SdyArray(mesh_shape=None, dim_shardings=sdy_dim_sharding) @property def is_fully_replicated(self) -> bool: @@ -198,6 +184,7 @@ def is_fully_addressable(self) -> bool: return xb.process_index(self._device.client) == self._device.process_index return True +SingleDeviceSharding.__module__ = 'jax.sharding' @util.cache(max_size=4096, trace_context_in_key=False) def pmap_sharding_devices_indices_map( @@ -222,8 +209,7 @@ def __init__(self, devices: Sequence[Device] | np.ndarray, self.sharding_spec = sharding_spec def __reduce__(self): - return (type(self), (self.devices, self.sharding_spec), - {'memory_kind': self.memory_kind}) + return (type(self), (self.devices, self.sharding_spec)) def __eq__(self, other): if not isinstance(other, PmapSharding): @@ -327,8 +313,8 @@ def with_memory_kind(self, kind: str): def _to_xla_hlo_sharding(self, num_dimensions: int) -> xc.HloSharding: raise NotImplementedError("pmap doesn't use OpSharding.") - def _to_sdy_sharding(self, num_dimensions: int) -> SdyArraySharding: - raise NotImplementedError("pmap doesn't use SdyArraySharding.") + def _to_sdy_sharding(self, num_dimensions: int) -> SdyArray: + raise NotImplementedError("pmap doesn't use SdyArray.") @functools.cached_property def is_fully_replicated(self) -> bool: @@ -366,213 +352,11 @@ def shard_shape(self, global_shape: Shape) -> Shape: f'the number of devices={len(self._device_assignment)}') return sharded_shape +PmapSharding.__module__ = 'jax.sharding' -def _op_sharding_to_pos_sharding( - op_sharding: xc.OpSharding | xc.HloSharding, - device_assignment: Sequence[xc.Device], - memory_kind: str | None = None) -> PositionalSharding: - if isinstance(op_sharding, xc.OpSharding): - op_sharding = xc.HloSharding.from_proto(op_sharding) - - if op_sharding.is_replicated(): - return PositionalSharding( - device_assignment, memory_kind=memory_kind).replicate() - - if len(op_sharding.subgroup_types()) > 1: - raise 
NotImplementedError( - 'Unhandled HloSharding type. Please open a bug report!' - ) - - name = device_assignment[0].platform.upper() - ids = np.array( - [DeviceIdSet(name, i) for i in op_sharding.tile_assignment_devices()] - ) - p = PositionalSharding._remake(tuple(device_assignment), ids, - memory_kind=memory_kind) - p = p.reshape(op_sharding.tile_assignment_dimensions()) - if op_sharding.replicate_on_last_tile_dim(): - p = p.replicate(-1, keepdims=False) - return p - - -@util.cache(max_size=4096, trace_context_in_key=False) -def _positional_sharding_to_xla_hlo_sharding( - self, num_dimensions: int) -> xc.HloSharding: - if self.shape == (1,) * self.ndim: - return replicated_hlo_sharding - - pbuf = xc.OpSharding() - shape = self.shape[self.ndim - num_dimensions:] # 'rank promotion' of val - set_size, = {len(device_set) for device_set in self._ids.flat} - pbuf.type = xc.OpSharding.Type.OTHER - if set_size > 1: - pbuf.last_tile_dims = [xc.OpSharding.Type.REPLICATED] - pbuf.tile_assignment_dimensions = (*shape, set_size) - else: - pbuf.tile_assignment_dimensions = shape - pbuf.tile_assignment_devices = [i for ids in self._ids.flat for i in ids] - product_of_dims = math.prod(pbuf.tile_assignment_dimensions) - num_devices = len(pbuf.tile_assignment_devices) - assert product_of_dims == num_devices, (product_of_dims, num_devices) - return xc.HloSharding.from_proto(pbuf) - - -class PositionalSharding(jsharding.Sharding): - _devices: tuple[xc.Device, ...] - _memory_kind: str | None - _ids: np.ndarray # dtype DeviceIdSet - - def __init__(self, devices: Sequence[xc.Device] | np.ndarray, - *, memory_kind: str | None = None): - super().__init__() - if not isinstance(devices, np.ndarray): - devices = np.array(devices, dtype='object') - if not devices.size: - raise ValueError(f"{self.__class__.__name__}.__init__ requires at least " - f"one device, got {devices}") - self._devices = tuple(devices.flat) - self._memory_kind = memory_kind - name = self._devices[0].platform.upper() - self._ids = np.array([DeviceIdSet(name, i) for i in range(devices.size)], - dtype='object').reshape(devices.shape) - self._internal_device_list = xc.DeviceList(self._devices) - self._memory_kind = xc.check_and_canonicalize_memory_kind( - self._memory_kind, self._internal_device_list) - - @property - def shape(self): - return self._ids.shape - - @property - def ndim(self): - return self._ids.ndim - - def __repr__(self) -> str: - cls_name = self.__class__.__name__ - ids = self._ids.copy() - platform_name = self._devices[0].platform.upper() - for idx, x in np.ndenumerate(ids): - ids[idx] = DeviceIdSet(platform_name, *(self._devices[i].id for i in x)) - body = np.array2string(ids, prefix=cls_name + '(', suffix=')', - max_line_width=100) - mem = '' if self._memory_kind is None else f', memory_kind={self._memory_kind}' - return f'{cls_name}({body}{mem}, shape={self.shape})' - - def reshape(self, *shape) -> PositionalSharding: - return self._remake(self._devices, self._ids.reshape(*shape), - memory_kind=self.memory_kind) - - def transpose(self, *axes) -> PositionalSharding: - return self._remake(self._devices, self._ids.transpose(*axes), - memory_kind=self.memory_kind) - T = property(transpose) - - def replicate(self, axis=None, keepdims=True) -> PositionalSharding: - new_ids = self._ids.sum(axis=axis, keepdims=keepdims) # union - return self._remake(self._devices, new_ids, - memory_kind=self.memory_kind) - - def check_compatible_aval(self, aval_shape: Shape) -> None: - if len(aval_shape) != len(self.shape) and not 
self.is_fully_replicated: - raise ValueError( - f"Sharding {self} is only valid for values of rank " - f"{len(self.shape)}, but was applied to a value of rank " - f"{len(aval_shape)}") - - @classmethod - def _remake( - cls, devices: tuple[xc.Device, ...], ids: np.ndarray, - *, memory_kind: str | None = None) -> PositionalSharding: - sharding = cls(devices, memory_kind=memory_kind) - sharding._ids = ids - return sharding - - # Hashable - - def __hash__(self) -> int: - if not hasattr(self, '_hash'): - self._hash = hash((self._internal_device_list, self.memory_kind)) - return self._hash - - def __eq__(self, other) -> bool: - if not isinstance(other, PositionalSharding): - return False - if self is other: - return True - all_ids_equal = np.array_equal(self._ids,other._ids) - mem_kind_equal = self.memory_kind == other.memory_kind - if self._devices is other._devices and mem_kind_equal and all_ids_equal: - return True - return (mem_kind_equal and all_ids_equal and - self._internal_device_list == other._internal_device_list) - - # Sharding interface - - @property - def num_devices(self) -> int: - return len(self.device_set) - - @functools.cached_property - def device_set(self) -> set[xc.Device]: - return set(self._devices) - - @property - def memory_kind(self) -> str | None: - return self._memory_kind - - def with_memory_kind(self, kind: str) -> PositionalSharding: - return PositionalSharding(self._devices, memory_kind=kind) - - @functools.cached_property - def is_fully_replicated(self) -> bool: - return self.shape == (1,) * self.ndim - - # jsharding.Sharding interface - - @property - def _device_assignment(self) -> XLADeviceAssignment: - return self._devices - - def _to_xla_hlo_sharding(self, num_dimensions: int) -> xc.HloSharding: - return _positional_sharding_to_xla_hlo_sharding(self, num_dimensions) - - def _to_sdy_sharding(self, num_dimensions: int) -> SdyArraySharding: - raise NotImplementedError( - "PositionalSharding can't be converted to an SdyArraySharding.") - - @functools.cached_property - def is_fully_addressable(self) -> bool: - return self._internal_device_list.is_fully_addressable - - -class DeviceIdSet: - _name: str - _ids: frozenset[int] - def __init__(self, name, *ids): - self._name = name - self._ids = frozenset(ids) - - def __iter__(self): - return iter(sorted(self._ids)) - - def __add__(self, other) -> DeviceIdSet: - assert isinstance(other, DeviceIdSet) - return DeviceIdSet(self._name, *(self._ids | other._ids)) - - def __len__(self) -> int: - return len(self._ids) - - def __repr__(self) -> str: - ids = ', '.join(safe_map(str, sorted(self._ids))) - return f'{{{self._name} {ids}}}' - - def __hash__(self) -> int: - return hash((self._name, self._ids)) - - def __eq__(self, other) -> bool: - return (isinstance(other, DeviceIdSet) and self._name == other._name and - self._ids == other._ids) +def _unpickle_gspmd_sharding(devices, op_sharding, memory_kind): + return GSPMDSharding(devices, op_sharding, memory_kind=memory_kind) @use_cpp_class(xc.GSPMDSharding) class GSPMDSharding(jsharding.Sharding): @@ -595,8 +379,8 @@ def __init__(self, devices: Sequence[Device], self._memory_kind = memory_kind def __reduce__(self): - return (type(self), (self._devices, self._hlo_sharding.to_proto()), - {'memory_kind': self._memory_kind}) + return (_unpickle_gspmd_sharding, + (self._devices, self._hlo_sharding.to_proto(), self._memory_kind)) @functools.cached_property def _hlo_sharding_hash(self): @@ -653,9 +437,26 @@ def _device_assignment(self) -> XLADeviceAssignment: def 
_to_xla_hlo_sharding(self, num_dimensions: int) -> xc.HloSharding: return self._hlo_sharding - def _to_sdy_sharding(self, num_dimensions: int) -> SdyArraySharding: - raise NotImplementedError( - "GSPMDSharding can't be converted to SdyArraySharding.") + def _to_sdy_sharding(self, num_dimensions: int) -> SdyArray: + if self._hlo_sharding.tuple_elements(): + raise TypeError( + f'Cannot convert GSPMDSharding {self._hlo_sharding} into SdyArray.') + elif self._hlo_sharding.is_replicated(): + empty_mesh = mesh_lib.AbstractMesh((), ()) + return NamedSharding(empty_mesh, PartitionSpec())._to_sdy_sharding( + num_dimensions) + elif self._hlo_sharding.is_tiled(): + if not self._hlo_sharding.is_tile_assignment_iota(): + raise TypeError( + f'Cannot convert GSPMDSharding {self._hlo_sharding} into SdyArray.') + axis_sizes = tuple(self._hlo_sharding.get_axis_sizes()) + axis_names = tuple(f'_axis_{i}' for i in range(len(axis_sizes))) + mesh = mesh_lib.AbstractMesh(axis_sizes, axis_names) + return _gspmd_to_named_sharding_via_mesh(self, mesh)._to_sdy_sharding( + num_dimensions) + else: + raise TypeError( + f'Cannot convert GSPMDSharding {self._hlo_sharding} into SdyArray.') @functools.cached_property def is_fully_replicated(self) -> bool: @@ -689,14 +490,20 @@ def prepare_axis_resources(axis_resources, arg_name, if isinstance(entry, PmapSharding): raise ValueError(f'One of {what} got sharding {entry} which is not ' 'allowed.') + if (not allow_unconstrained_dims and isinstance(entry, NamedSharding) and + PartitionSpec.UNCONSTRAINED in entry.spec): + raise ValueError( + f'Unconstrained dims are not allowed when passed to {arg_name}:' + f' {entry}') new_entries.append(entry) else: if not isinstance(entry, PartitionSpec): raise TypeError(f"{what} are expected to be " f"PartitionSpec instances or None, but got {entry}") - for e in entry: - if e is PartitionSpec.UNCONSTRAINED and not allow_unconstrained_dims: - raise ValueError(f"Unconstrained dims are not allowed: {entry}") + if not allow_unconstrained_dims and PartitionSpec.UNCONSTRAINED in entry: + raise ValueError( + f'Unconstrained dims are not allowed when passed to {arg_name}:' + f' {entry}') _check_unique_resources(entry, arg_name) new_entries.append(entry) @@ -882,8 +689,7 @@ def parse_flatten_op_sharding( return out elif hlo_sharding.is_replicated(): return [PartitionSpec()] - elif (xla_extension_version >= 319 and hlo_sharding.is_maximal() - and mesh.size == 1): + elif hlo_sharding.is_maximal() and mesh.size == 1: return [PartitionSpec()] elif hlo_sharding.is_tiled(): mesh_shape = mesh.shape @@ -898,7 +704,11 @@ def parse_flatten_op_sharding( while dim_size > 1: axis = next(mesh_axis) axis_size = mesh_shape[axis] - assert dim_size % axis_size == 0 + if dim_size % axis_size != 0: + raise ValueError( + f'{shape=} is incompatible with {mesh_shape=}: ' + f'{dim_size=} is not divisible by {axis_size=}.' 
+ ) dim_size //= axis_size dim_partitions.append(axis) partitions.append(tuple(dim_partitions)) @@ -1169,7 +979,7 @@ def make_key_array_phys_sharding(aval, sharding): elif isinstance(sharding, NamedSharding): elt_aval = core.physical_element_aval(aval.dtype) trailing_spec = [None] * elt_aval.ndim - return sharding.with_spec(PartitionSpec(*sharding.spec, *trailing_spec)) + return sharding.update(spec=PartitionSpec(*sharding.spec, *trailing_spec)) else: hlos = sharding._to_xla_hlo_sharding(aval.ndim) return GSPMDSharding( @@ -1231,15 +1041,16 @@ def logical_sharding(logical_shape, dtype, phys_sharding) -> jsharding.Sharding: phys_spec = (*phys_sharding.spec, *[None] * (len(phys_shape) - len(phys_sharding.spec))) else: - phys_spec = phys_sharding.spec - return phys_sharding.with_spec(phys_spec[:-elt_aval.ndim]) + phys_spec = phys_sharding.spec # type: ignore + return phys_sharding.update(spec=phys_spec[:-elt_aval.ndim]) else: return get_logical_gspmd_sharding(logical_shape, dtype, phys_sharding) @util.cache() def create_mesh_pspec_sharding( - mesh: mesh_lib.Mesh, pspec: PartitionSpec | None, + mesh: mesh_lib.Mesh | mesh_lib.AbstractMesh, + pspec: PartitionSpec | None, memory_kind: str | None = None) -> NamedSharding: if pspec is None: pspec = PartitionSpec() @@ -1247,11 +1058,13 @@ def create_mesh_pspec_sharding( def _gspmd_to_named_sharding_via_mesh( - out_s: GSPMDSharding, mesh: mesh_lib.Mesh) -> NamedSharding: + out_s: GSPMDSharding, mesh: mesh_lib.Mesh | mesh_lib.AbstractMesh +) -> NamedSharding: spec = parse_flatten_op_sharding(out_s._hlo_sharding, mesh)[0] return create_mesh_pspec_sharding( mesh, spec, memory_kind=out_s.memory_kind) + def flatten_spec(spec): out = [] for s in spec: @@ -1347,9 +1160,14 @@ def make_mesh(axis_shapes: Sequence[int], axis_names: Sequence[str], axis_names: Names of the mesh axes. For example, axis_names=('x', 'y') devices: Optional keyword only argument, that allows you to specify the devices you want to create a mesh with. + axis_types: an optional tuple of :class:`jax.sharding.AxisType` entries + corresponding to the ``axis_names``. See `Explicit Sharding`_ for more + information. Returns: - A `jax.sharding.Mesh` object. + A :class:`jax.sharding.Mesh` object. + + .. _Explicit Sharding: https://docs.jax.dev/en/latest/notebooks/explicit-sharding.html """ if devices is None: devices = xb.devices() @@ -1382,10 +1200,8 @@ def use_mesh(mesh: mesh_lib.Mesh): if not isinstance(mesh, mesh_lib.Mesh): raise ValueError( f"Expected mesh of type `jax.sharding.Mesh`. Got {type(mesh)}") - - # TODO(yashkatariya): Enable this. - # if not core.trace_state_clean(): - # raise ValueError('`use_mesh` can only be used outside of `jax.jit`') + if not core.trace_state_clean(): + raise ValueError('`use_mesh` can only be used outside of `jax.jit`') with mesh_lib.use_abstract_mesh(mesh.abstract_mesh), use_concrete_mesh(mesh): yield @@ -1396,27 +1212,30 @@ def set_mesh(mesh: mesh_lib.Mesh | None) -> mesh_lib.Mesh | None: if mesh is not None and not isinstance(mesh, mesh_lib.Mesh): raise ValueError( f"Expected mesh of type `jax.sharding.Mesh`.
Got {type(mesh)}") + assert mesh is None or isinstance(mesh, mesh_lib.Mesh) if not core.trace_state_clean(): raise ValueError('`set_mesh` can only be used outside of `jax.jit`.') if mesh is None: - config.abstract_mesh_context_manager.set_global(mesh_lib.empty_abstract_mesh) # type: ignore + config.abstract_mesh_context_manager.set_local(mesh_lib.empty_abstract_mesh) # type: ignore else: - config.abstract_mesh_context_manager.set_global(mesh.abstract_mesh) # type: ignore + config.abstract_mesh_context_manager.set_local(mesh.abstract_mesh) # type: ignore - prev_mesh = config.device_context.get_global() - config.device_context.set_global(mesh) - return prev_mesh + prev_mesh = config.device_context.swap_local(mesh) + return None if prev_mesh is config_ext.unset else prev_mesh @contextlib.contextmanager def use_concrete_mesh(mesh: mesh_lib.Mesh | None): + if not core.trace_state_clean(): + raise ValueError('`use_concrete_mesh` can only be used outside of `jax.jit`.') + with _internal_use_concrete_mesh(mesh): + yield + +@contextlib.contextmanager +def _internal_use_concrete_mesh(mesh: mesh_lib.Mesh | None): if mesh is not None and not isinstance(mesh, mesh_lib.Mesh): raise ValueError( f"Expected mesh of type `jax.sharding.Mesh`. Got {type(mesh)}") - # TODO(yashkatariya): Enable this. - # if not core.trace_state_clean(): - # raise ValueError('`use_concrete_mesh` can only be used outside of `jax.jit`.') - prev_val = config.device_context.swap_local(mesh) try: yield diff --git a/jax/_src/source_info_util.py b/jax/_src/source_info_util.py index b1901f44f022..396ea541c75f 100644 --- a/jax/_src/source_info_util.py +++ b/jax/_src/source_info_util.py @@ -21,7 +21,6 @@ import itertools import os.path import re -import sys import sysconfig import threading import types @@ -159,21 +158,15 @@ def is_user_filename(filename: str) -> bool: return (_include_path_regex().search(filename) is not None or _exclude_path_regex().search(filename) is None) -if sys.version_info >= (3, 11): - def raw_frame_to_frame(code: types.CodeType, lasti: int) -> Frame: - loc = xla_client.Traceback.code_addr2location(code, lasti) - start_line, start_column, end_line, end_column = loc - return Frame(file_name=code.co_filename, - function_name=code.co_qualname, - start_line=start_line, start_column=start_column, - end_line=end_line, end_column=end_column) -else: - def raw_frame_to_frame(code: types.CodeType, lasti: int) -> Frame: - # pre-3.11 co_qualname does not exist, use co_name - return Frame(file_name=code.co_filename, - function_name=code.co_name, - start_line=xla_client.Traceback.code_addr2line(code, lasti), - start_column=0, end_line=0, end_column=0) + +def raw_frame_to_frame(code: types.CodeType, lasti: int) -> Frame: + loc = xla_client.Traceback.code_addr2location(code, lasti) + start_line, start_column, end_line, end_column = loc + return Frame(file_name=code.co_filename, + function_name=code.co_qualname, + start_line=start_line, start_column=start_column, + end_line=end_line, end_column=end_column) + def user_frames(source_info: SourceInfo) -> Iterator[Frame]: """Iterator over the user's frames, filtering jax-internal frames.""" diff --git a/jax/_src/stages.py b/jax/_src/stages.py index 19cd0822aa58..fcf3d5f6176d 100644 --- a/jax/_src/stages.py +++ b/jax/_src/stages.py @@ -30,23 +30,26 @@ """ from __future__ import annotations +import dataclasses +import enum import functools from collections.abc import Sequence from dataclasses import dataclass from typing import Any, NamedTuple, Protocol, Union, runtime_checkable 
-import jax - from jax._src import core from jax._src import config +from jax._src import sharding as sharding_lib from jax._src import source_info_util from jax._src import traceback_util from jax._src import tree_util +from jax._src import typing from jax._src import util from jax._src.sharding_impls import UnspecifiedValue, AUTO -from jax._src.layout import Layout +from jax._src.layout import Format, DeviceLocalLayout from jax._src.interpreters import mlir from jax._src.lib.mlir import ir +from jax._src.lib import _jax from jax._src.lib import xla_client as xc @@ -54,44 +57,55 @@ traceback_util.register_exclusion(__file__) -xla_extension = xc._xla map, unsafe_map = util.safe_map, map zip, unsafe_zip = util.safe_zip, zip CompilerOptions = dict[str, Union[str, bool]] -# -- Internal protocols +# -- Internal types + -class Executable(Protocol): - """Protocol for executables, which a user-facing ``Compiled`` encapsulates.""" +class Executable: + + def xla_extension_executable(self) -> xc.LoadedExecutable: + raise NotImplementedError( + "compiled executable carries no loaded XLA executable. It may be " + f"that {type(self)} defines an incomplete implementation.") def call(self, *args_flat) -> Sequence[Any]: """Execute on the flat list of arguments, returning flat outputs.""" - # TODO(frostig): improve annotation (sequences of arrays/buffers) - raise NotImplementedError + raise NotImplementedError("compiled executable does not support invocation") - def input_shardings(self) -> Sequence[jax.sharding.Sharding]: + def create_cpp_call(self, no_kwargs, in_tree, out_tree) -> Any: + """Optionally constructs a fast c++ dispatcher.""" + return None + + def input_shardings(self) -> Sequence[sharding_lib.Sharding]: """Flat sequence of input shardings. May raise ``NotImplementedError`` if unavailable, e.g. based on backend, compiler, or runtime. """ - raise NotImplementedError + raise NotImplementedError( + "compiled executable carries no input sharding information") - def output_shardings(self) -> Sequence[jax.sharding.Sharding]: + def output_shardings(self) -> Sequence[sharding_lib.Sharding]: """Flat sequence of output shardings. May raise ``NotImplementedError`` if unavailable, e.g. based on backend, compiler, or runtime. """ - raise NotImplementedError + raise NotImplementedError( + "compiled executable carries no output sharding information") - def input_layouts(self): - raise NotImplementedError + def input_formats(self): + raise NotImplementedError( + "compiled executable carries no input layout information") - def output_layouts(self): - raise NotImplementedError + def output_formats(self): + raise NotImplementedError( + "compiled executable carries no output layout information") def as_text(self) -> str: """A human-readable text representation of this executable. @@ -102,89 +116,19 @@ def as_text(self) -> str: May raise ``NotImplementedError`` if unavailable, e.g. based on backend, compiler, or runtime. """ - raise NotImplementedError - - def cost_analysis(self) -> Any: - """A summary of execution cost estimates. - - Intended for visualization and debugging purposes. The object output by - this is some simple data structure that can easily be printed or serialized - (e.g. nested dicts, lists, and tuples with numeric leaves). However, its - structure can be arbitrary: it need not be consistent across versions of JAX - and jaxlib, or even across invocations. It is relayed directly to external - callers. - - May raise ``NotImplementedError`` if unavailable, e.g. 
based on backend, - compiler, or runtime. - """ - # TODO(frostig): improve annotation (arbitrary pytree) - raise NotImplementedError - - def memory_analysis(self) -> Any: - """A summary of estimated memory requirements. - - Intended for visualization and debugging purposes. The object output by - this is some simple data structure that can easily be printed or serialized - (e.g. nested dicts, lists, and tuples with numeric leaves). However, its - structure can be arbitrary: it need not be consistent across versions of JAX - and jaxlib, or even across invocations. It is relayed directly to external - callers. - - May raise ``NotImplementedError`` if unavailable, e.g. based on backend, - compiler, or runtime. - """ - # TODO(frostig): improve annotation (arbitrary pytree) - raise NotImplementedError - - def runtime_executable(self) -> Any: - """An arbitrary object representation of this executable. - - Intended for debugging purposes. This need not be a valid nor reliable - serialization. It is relayed directly to external callers, with no - guarantee on type, structure, or consistency across invocations. - - May raise ``NotImplementedError`` if unavailable, e.g. based on backend or - compiler. - """ - raise NotImplementedError - - def create_cpp_call(self, no_kwargs, in_tree, out_tree) -> Any: - """Optionally constructs a fast c++ dispatcher.""" - return None - - -class Lowering(Protocol): - """Protocol for lowerings, which a user-facing ``Lowered`` encapsulates.""" - - def compile( - self, compiler_options: CompilerOptions | None = None) -> Executable: - """Compile and return a corresponding ``Executable``.""" - raise NotImplementedError - - def as_text(self, dialect: str | None = None, *, - debug_info: bool = False) -> str: - """A human-readable text representation of this lowering. - - Intended for visualization and debugging purposes. This need not be a valid - nor reliable serialization. It is relayed directly to external callers. - """ - raise NotImplementedError - - def compiler_ir(self, dialect: str | None = None) -> Any: - """An arbitrary object representation of this lowering. - - Intended for debugging purposes. This need not be a valid nor reliable - serialization. It is relayed directly to external callers, with no - guarantee on type, structure, or consistency across invocations. - - May raise ``NotImplementedError`` if unavailable, e.g. based on backend or - compiler. - - Args: - dialect: Optional string specifying a representation dialect - (e.g. "stablehlo") - """ - raise NotImplementedError + xla_ext_exe = self.xla_extension_executable() + err_msg = ("text view unsupported on current XLA backend: " + f"{type(xla_ext_exe)}") + if not hasattr(xla_ext_exe, "hlo_modules"): + raise NotImplementedError(err_msg) + try: + return "\n\n".join([m.to_string() for m in xla_ext_exe.hlo_modules()]) + except _jax.XlaRuntimeError as e: + msg, *_ = e.args + if type(msg) is str and msg.startswith("UNIMPLEMENTED"): + raise NotImplementedError(err_msg) from e + else: + raise def cost_analysis(self) -> Any: """A summary of execution cost estimates. @@ -196,66 +140,15 @@ def cost_analysis(self) -> Any: and jaxlib, or even across invocations. It is relayed directly to external callers. - This function estimates execution cost in the absence of compiler - optimizations, which may drastically affect the cost. For execution cost - estimates after optimizations, compile this lowering and see - ``Compiled.cost_analysis``. - May raise ``NotImplementedError`` if unavailable, e.g. 
based on backend, compiler, or runtime. """ - # TODO(frostig): improve annotation (arbitrary pytree) - raise NotImplementedError - - -# -- Internal adapters from XLA-related objects to the above protocols - -class XlaExecutable(Executable): - - def xla_extension_executable(self) -> xc.LoadedExecutable: - raise NotImplementedError("must override") - - def call(self, *args_flat) -> Sequence[Any]: - raise NotImplementedError("must override") - - def input_shardings(self) -> Sequence[jax.sharding.Sharding]: - raise NotImplementedError( - "compiled executable carries no input sharding information") - - def output_shardings(self) -> Sequence[jax.sharding.Sharding]: - raise NotImplementedError( - "compiled executable carries no output sharding information") - - def input_layouts(self): - raise NotImplementedError( - "compiled executable carries no input layout information") - - def output_layouts(self): - raise NotImplementedError( - "compiled executable carries no input layout information") - - def as_text(self) -> str: - xla_ext_exe = self.xla_extension_executable() - err_msg = ("text view unsupported on current XLA backend: " - f"{type(xla_ext_exe)}") - if not hasattr(xla_ext_exe, "hlo_modules"): - raise NotImplementedError(err_msg) - try: - return "\n\n".join([m.to_string() for m in xla_ext_exe.hlo_modules()]) - except xla_extension.XlaRuntimeError as e: - msg, *_ = e.args - if type(msg) is str and msg.startswith("UNIMPLEMENTED"): - raise NotImplementedError(err_msg) from e - else: - raise - - def cost_analysis(self) -> dict[str, float]: xla_ext_exe = self.xla_extension_executable() if hasattr(xla_ext_exe, "cost_analysis"): try: return xla_ext_exe.cost_analysis() - except xla_extension.XlaRuntimeError as e: + except _jax.XlaRuntimeError as e: msg, *_ = e.args if not (type(msg) is str and msg.startswith("UNIMPLEMENTED")): raise @@ -273,6 +166,18 @@ def cost_analysis(self) -> dict[str, float]: ) def memory_analysis(self) -> Any: + """A summary of estimated memory requirements. + + Intended for visualization and debugging purposes. The object output by + this is some simple data structure that can easily be printed or serialized + (e.g. nested dicts, lists, and tuples with numeric leaves). However, its + structure can be arbitrary: it need not be consistent across versions of JAX + and jaxlib, or even across invocations. It is relayed directly to external + callers. + + May raise ``NotImplementedError`` if unavailable, e.g. based on backend, + compiler, or runtime. + """ xla_ext_exe = self.xla_extension_executable() err_msg = ("memory analysis unsupported on current XLA backend: " f"{type(xla_ext_exe)}") @@ -280,7 +185,7 @@ def memory_analysis(self) -> Any: raise NotImplementedError(err_msg) try: return xla_ext_exe.get_compiled_memory_stats() - except xla_extension.XlaRuntimeError as e: + except _jax.XlaRuntimeError as e: msg, *_ = e.args if type(msg) is str and msg.startswith("UNIMPLEMENTED"): raise NotImplementedError(err_msg) from e @@ -288,11 +193,19 @@ def memory_analysis(self) -> Any: raise def runtime_executable(self) -> Any: + """An arbitrary object representation of this executable. + + Intended for debugging purposes. This need not be a valid nor reliable + serialization. It is relayed directly to external callers, with no + guarantee on type, structure, or consistency across invocations. + + May raise ``NotImplementedError`` if unavailable, e.g. based on backend or + compiler. 
+ """ return self.xla_extension_executable() -class XlaLowering(Lowering): - """Adapts our various internal XLA-backed computations into a ``Lowering``.""" +class Lowering: compile_args: dict[str, Any] @@ -301,20 +214,28 @@ def hlo(self) -> xc.XlaComputation: hlo = self.stablehlo() m: str | bytes m = mlir.module_to_bytecode(hlo) - return xla_extension.mlir.mlir_module_to_xla_computation( + return _jax.mlir.mlir_module_to_xla_computation( m, use_tuple_args=self.compile_args["tuple_args"]) def stablehlo(self) -> ir.Module: """Return a StableHLO representation of this computation.""" - raise NotImplementedError("must override") + raise NotImplementedError( + f"cost analysis unsupported on XLA computation: {type(self)}") def compile( self, compiler_options: CompilerOptions | None = None) -> Executable: - raise NotImplementedError("must override") + """Compile and return a corresponding ``Executable``.""" + raise NotImplementedError( + f"cost analysis unsupported on XLA computation: {type(self)}") def as_text(self, dialect: str | None = None, *, debug_info: bool = False) -> str: + """A human-readable text representation of this lowering. + + Intended for visualization and debugging purposes. This need not be a valid + nor reliable serialization. It is relayed directly to external callers. + """ if dialect is None: dialect = "stablehlo" if dialect == "stablehlo": @@ -328,6 +249,19 @@ def as_text(self, dialect: str | None = None, raise ValueError(f"unknown dialect: {dialect}") def compiler_ir(self, dialect: str | None = None) -> Any: + """An arbitrary object representation of this lowering. + + Intended for debugging purposes. This need not be a valid nor reliable + serialization. It is relayed directly to external callers, with no + guarantee on type, structure, or consistency across invocations. + + May raise ``NotImplementedError`` if unavailable, e.g. based on backend or + compiler. + + Args: + dialect: Optional string specifying a representation dialect + (e.g. "stablehlo") + """ if dialect is None: dialect = "stablehlo" if dialect == "stablehlo": @@ -337,8 +271,26 @@ def compiler_ir(self, dialect: str | None = None) -> Any: else: raise ValueError(f"unknown dialect: {dialect}") - def cost_analysis(self) -> dict[str, float]: - raise NotImplementedError("must override") + def cost_analysis(self) -> Any: + """A summary of execution cost estimates. + + Intended for visualization and debugging purposes. The object output by + this is some simple data structure that can easily be printed or serialized + (e.g. nested dicts, lists, and tuples with numeric leaves). However, its + structure can be arbitrary: it need not be consistent across versions of JAX + and jaxlib, or even across invocations. It is relayed directly to external + callers. + + This function estimates execution cost in the absence of compiler + optimizations, which may drastically affect the cost. For execution cost + estimates after optimizations, compile this lowering and see + ``Compiled.cost_analysis``. + + May raise ``NotImplementedError`` if unavailable, e.g. based on backend, + compiler, or runtime. + """ + raise NotImplementedError( + f"cost analysis unsupported on XLA computation: {type(self)}") # -- Public-facing API, plus helpers @@ -360,8 +312,8 @@ def dtype(self): @dataclass(frozen=True) class OutInfo: shape: tuple[int, ...] 
- dtype: jax.typing.DTypeLike - sharding: jax.sharding.Sharding | None = None + dtype: typing.DTypeLike + sharding: sharding_lib.Sharding | None = None class Stage: @@ -486,37 +438,48 @@ def runtime_executable(self) -> Any | None: """ return self._executable.runtime_executable() - @property - def input_shardings(self): # PyTree[sharding.Sharding] - shardings_flat = self._executable.input_shardings() + def _input_shardings_flat(self): + shardings_flat = self._executable._in_shardings # Some input shardings got DCE'd if self.in_tree.num_leaves > len(shardings_flat): iter_shardings_flat = iter(shardings_flat) shardings_flat = [next(iter_shardings_flat) if i in self._executable._kept_var_idx else None for i in range(self.in_tree.num_leaves)] + return shardings_flat + + @property + def input_shardings(self): # -> PyTree[sharding.Sharding] + shardings_flat = self._input_shardings_flat() return tree_util.tree_unflatten(self.in_tree, shardings_flat) # pytype: disable=attribute-error @property - def output_shardings(self): # PyTree[sharding.Sharding] - shardings_flat = self._executable.output_shardings() + def output_shardings(self): # -> PyTree[sharding.Sharding] + shardings_flat = self._executable._out_shardings return tree_util.tree_unflatten(self.out_tree, shardings_flat) # pytype: disable=attribute-error - @property - def input_layouts(self): - layouts_flat = self._executable.input_layouts() - assert all(isinstance(l, Layout) for l in layouts_flat) + def _input_layouts_flat(self): + layouts_flat = self._executable._xla_in_layouts # Some input layouts got DCE'd if self.in_tree.num_leaves > len(layouts_flat): iter_layouts_flat = iter(layouts_flat) layouts_flat = [next(iter_layouts_flat) if i in self._executable._kept_var_idx - else Layout() for i in range(self.in_tree.num_leaves)] - return tree_util.tree_unflatten(self.in_tree, layouts_flat) # pytype: disable=attribute-error + else None for i in range(self.in_tree.num_leaves)] + return layouts_flat + + @property + def input_formats(self): + layouts_flat = self._input_layouts_flat() + shardings_flat = self._input_shardings_flat() + formats_flat = [Format(l, s) for l, s in zip(layouts_flat, shardings_flat)] + return tree_util.tree_unflatten(self.in_tree, formats_flat) # pytype: disable=attribute-error @property - def output_layouts(self): - layouts_flat = self._executable.output_layouts() - assert all(isinstance(l, Layout) for l in layouts_flat) - return tree_util.tree_unflatten(self.out_tree, layouts_flat) # pytype: disable=attribute-error + def output_formats(self): + layouts_flat = self._executable._xla_out_layouts + shardings_flat = self._executable._out_shardings + assert all(isinstance(l, DeviceLocalLayout) for l in layouts_flat) + formats_flat = [Format(l, s) for l, s in zip(layouts_flat, shardings_flat)] + return tree_util.tree_unflatten(self.out_tree, formats_flat) # pytype: disable=attribute-error @staticmethod def call(*args, **kwargs): @@ -593,14 +556,14 @@ class Lowered(Stage): lowering paths (:func:`~jax.jit`, :func:`~jax.pmap`, etc.). 
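  A minimal illustrative sketch of the typical flow (assuming a trivial jitted
  function)::

    lowered = jax.jit(lambda x: x + 1).lower(1.0)   # -> Lowered
    print(lowered.as_text())                        # StableHLO text by default
    compiled = lowered.compile()                    # -> Compiled
    compiled(1.0)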
""" __slots__ = ["_lowering", "args_info", "out_tree", "_no_kwargs"] - _lowering: XlaLowering + _lowering: Lowering args_info: Any # PyTree of ArgInfo out_tree: tree_util.PyTreeDef _no_kwargs: bool def __init__( self, - lowering: XlaLowering, + lowering: Lowering, args_info, # PyTree of ArgInfo out_tree: tree_util.PyTreeDef, no_kwargs: bool = False): @@ -612,7 +575,7 @@ def __init__( @classmethod def from_flat_info(cls, - lowering: XlaLowering, + lowering: Lowering, in_tree: tree_util.PyTreeDef, in_avals, donate_argnums: tuple[int, ...], @@ -737,9 +700,6 @@ def out_info(self): def lower(self, *, lowering_platforms: tuple[str, ...] | None = None, _private_parameters: mlir.LoweringParameters | None = None): """Lower to compiler input, returning a ``Lowered`` instance.""" - from jax._src.interpreters import pxla - from jax._src import pjit - if _private_parameters is None: _private_parameters = mlir.LoweringParameters() new_callable = functools.partial( @@ -747,9 +707,9 @@ def lower(self, *, lowering_platforms: tuple[str, ...] | None = None, lowering_parameters=_private_parameters) try: lowering = new_callable() - except pxla.DeviceAssignmentMismatchError as e: + except DeviceAssignmentMismatchError as e: fails, = e.args - msg = pjit._device_assignment_mismatch_error( + msg = _device_assignment_mismatch_error( self.fun_name, fails, self._args_flat, 'jit', self._arg_names) raise ValueError(msg) from None return Lowered(lowering, self.args_info, self._out_tree) @@ -793,3 +753,108 @@ def lower(self, *args, **kwargs) -> Lowered: A ``Lowered`` instance representing the lowering. """ raise NotImplementedError + + +class MismatchType(enum.Enum): + ARG_SHARDING = 0 + OUT_SHARDING = 1 + SHARDING_INSIDE_COMPUTATION = 2 + CONTEXT_DEVICES = 3 + IN_SHARDING = 4 + + def __str__(self): + if self.name == 'IN_SHARDING': + return 'explicit input sharding' + elif self.name == 'OUT_SHARDING': + return 'explicit output sharding' + elif self.name == 'CONTEXT_DEVICES': + return 'context mesh' + return f'{self.name}' + + +class SourceInfo(NamedTuple): + source_info: source_info_util.SourceInfo + eqn_name: str + + +@dataclasses.dataclass +class DeviceAssignmentMismatch: + da: Sequence[xc.Device] + m_type: MismatchType + source_info: SourceInfo | None + + @property + def device_ids(self) -> Sequence[int]: + return [d.id for d in self.da] + + @property + def platform(self) -> str: + return self.da[0].platform.upper() + + def _maybe_api_name(self, api_name) -> str: + return f" {api_name}'s" if self.m_type == MismatchType.CONTEXT_DEVICES else "" + + @property + def source_info_str(self): + return ( + "" if self.source_info is None + else f" at {source_info_util.summarize(self.source_info.source_info)}" + ) + + @property + def _dev_ids_plat_str(self): + return f"device ids {self.device_ids} on platform {self.platform}" + + def m_type_str(self, api_name): + return (f'{self.source_info and self.source_info.eqn_name} inside {api_name}' + if self.m_type == MismatchType.SHARDING_INSIDE_COMPUTATION else self.m_type) + + def _str(self, api_name): + return (f"{self._maybe_api_name(api_name)} {self.m_type_str(api_name)} with " + f"{self._dev_ids_plat_str}{self.source_info_str}") + + +class DeviceAssignmentMismatchError(Exception): + pass + + +def _find_arg_mismatch(arg_list, fails, fun_name): + mismatched_args_msg = [] + def mismatch(err): + for name, inp_da, aval in arg_list: + if err.m_type == MismatchType.ARG_SHARDING and err.da == inp_da: + mismatched_args_msg.append( + f"argument {name} of {fun_name} with shape 
{aval.str_short()} and " + f"{err._dev_ids_plat_str}") + break + first_err, second_err = fails + mismatch(first_err) + mismatch(second_err) + return mismatched_args_msg + + +def _device_assignment_mismatch_error(fun_name, fails, args_flat, api_name, + arg_names): + arg_list = [] + if arg_names is None: + arg_names = [''] * len(args_flat) + for a, n in zip(args_flat, arg_names): + da = (a.sharding._device_assignment + if getattr(a, 'sharding', None) is not None else None) + arg_list.append((n, da, core.shaped_abstractify(a))) + + mismatched_args_msg = _find_arg_mismatch(arg_list, fails, fun_name) + + if len(mismatched_args_msg) == 2: + first, second = mismatched_args_msg # pytype: disable=bad-unpacking + extra_msg = f" Got {first} and {second}" + elif len(mismatched_args_msg) == 1: + first, second = fails + # Choose the failure left which is not already covered by ARG_SHARDING. + left = second if first.m_type == MismatchType.ARG_SHARDING else first + extra_msg = f" Got {mismatched_args_msg[0]} and{left._str(api_name)}" + else: + first, second = fails + extra_msg = f" Got{first._str(api_name)} and{second._str(api_name)}" + msg = (f"Received incompatible devices for {api_name}ted computation.{extra_msg}") + return msg diff --git a/jax/_src/state/discharge.py b/jax/_src/state/discharge.py index 7ab77d5b1c37..9dce3297b947 100644 --- a/jax/_src/state/discharge.py +++ b/jax/_src/state/discharge.py @@ -25,6 +25,8 @@ from jax._src import api_util from jax._src import core from jax._src import linear_util as lu +from jax._src import pjit +from jax._src import sharding_impls from jax._src import source_info_util from jax._src import tree_util from jax._src.interpreters import ad @@ -208,14 +210,13 @@ def _eval_jaxpr_discharge_state( return out_vals + ref_vals def _is_trivial_indexer(indexer: indexing.NDIndexer): + """Returns whether the indexer selects the entire shape.""" for s, idx in zip(indexer.shape, indexer.indices): if not isinstance(idx, indexing.Slice): return False - if not isinstance(idx.start, int): + if idx.is_dynamic_start or idx.is_dynamic_size: return False - if idx.start: - return False - if idx.size != s: + if idx.start != 0 or idx.size != s: return False return True @@ -275,33 +276,97 @@ def _maybe_convert_to_dynamic_slice( return starts, sizes, squeeze_dims -def _convert_to_array_indexer(indexer: indexing.NDIndexer - ) -> tuple[int | Array, ...]: - # This is the general gather case. We need to create the gather arrays. - is_integer_indexer, _, integer_indexer = ( - indexing.unpack_ndindexer(indexer) +# In this code, indexing is handled in three ways: `slice`, `dynamic_slice`, and +# gather. For the gather case, the goal is to create a gather array, which means +# that we need to convert all other types of indexers into integer array +# indexers. This is done by looping over all indexers and checking if they are +# not integer array indexers, and if not, performing the conversion. However, +# during this process, the indexing semantics may change. Specifically, +# according to the indexing rules of NumPy, when there are integer array +# indexers separated by other indexers, the axes corresponding to the integer +# array indexers need to be moved to the front. After we convert all other +# indexers to integer array indexers, the distinction between integer array +# indexers and other types of indexers is lost. As a result, it becomes +# impossible to determine which axes should be moved to the front. In this case, +# we need to transpose the target array before the gather operation. 
We also +# need to transpose the target array back after the gather operation, if it is +# used in subsequent computations. +def _maybe_transpose_before_gather( + indexer: indexing.NDIndexer +) -> tuple[int, ...] | None: + is_int_indexing, _, _ = indexing.unpack_ndindexer(indexer) + + int_indexers_contiguous = bool( + np.all(np.diff(np.where(is_int_indexing)[0]) == 1) ) - total_shape = indexer.get_indexer_shape() - int_indexer_shape = indexer.int_indexer_shape - slice_shape = total_shape[len(int_indexer_shape):] - slice_dims = tuple( - i + len(int_indexer_shape) for i in range(len(slice_shape)) + if int_indexers_contiguous: + return None # no transpose needed + + int_indexer_idxs: list[int] = [] + non_int_indexer_idxs: list[int] = [] + for i, is_int_index in enumerate(is_int_indexing): + (int_indexer_idxs if is_int_index else non_int_indexer_idxs).append(i) + transpose_order = (*int_indexer_idxs, *non_int_indexer_idxs) + return transpose_order + + +def _perform_transpose_before_gather( + target_arr: Array, + indexer: indexing.NDIndexer, + transpose_order: tuple[int, ...], +) -> tuple[Array, indexing.NDIndexer]: + new_target_arr = target_arr.transpose(transpose_order) + reordered_indices = tuple(indexer.indices[i] for i in transpose_order) + new_indexer = indexing.NDIndexer( + indices=reordered_indices, + shape=indexer.shape, + int_indexer_shape=indexer.int_indexer_shape, ) - slice_dim_iter = iter(slice_dims) - slice_indexer: list[Array] = [] - for idx, is_int_index in zip(indexer.indices, is_integer_indexer): - if not is_int_index: - assert isinstance(idx, indexing.Slice) - slice_indices = lax.broadcasted_iota( - np.dtype("int32"), total_shape, next(slice_dim_iter) - ) * idx.stride + idx.start - slice_indexer.append(slice_indices) - integer_indexer = tuple( - lax.expand_dims(idx, (-1,)) for idx in integer_indexer + return new_target_arr, new_indexer + + +def _convert_to_gather_arrays(indexer: indexing.NDIndexer) -> tuple[Array, ...]: + # This is the general gather case. We need to create the gather arrays. 
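+  # Illustrative sketch of the behavior below: for a ref of shape (8, 6)
+  # indexed with (Slice(0, 8), jnp.array([0, 2, 5])), the gathered result has
+  # shape (8, 3) and this function returns, roughly,
+  #   (lax.iota(np.int32, 8).reshape(8, 1), jnp.array([0, 2, 5]))
+  # so that indexing as x[arrays] performs the gather via NumPy-style
+  # broadcasting of the index arrays.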
+ total_shape = indexer.get_indexer_shape() + is_int_indexing, _, _ = indexing.unpack_ndindexer(indexer) + + if any(is_int_indexing): + n_idxers = len(indexer.indices) + int_indexer_shape = indexer.int_indexer_shape + n_int_indexers = sum(1 for p in is_int_indexing if p) + last_int_index_idx = n_idxers - 1 - is_int_indexing[::-1].index(True) + n_slice_index_dims_after_int = n_idxers - last_int_index_idx - 1 + + def get_idx_in_shape_after_indexing(i): + if not any(is_int_indexing): + return i + + if i < n_idxers - n_slice_index_dims_after_int - n_int_indexers: + return i + if i < n_idxers - n_slice_index_dims_after_int: + raise ValueError + return i - n_int_indexers + len(int_indexer_shape) + + arrs = [] + for i, idxer in enumerate(indexer.indices): + if isinstance(idxer, indexing.Slice): + idx_in_shape_after_indexing = get_idx_in_shape_after_indexing(i) + arr = ( + lax.iota(np.int32, total_shape[idx_in_shape_after_indexing]) + * idxer.stride + + idxer.start ) - continue - assert next(slice_dim_iter, None) is None - return tuple(merge_lists(is_integer_indexer, slice_indexer, integer_indexer)) + diff = len(total_shape) - idx_in_shape_after_indexing - 1 + arr = arr.reshape(arr.shape + (1,) * diff) + arrs.append(arr) + elif isinstance(idxer, (np.ndarray, Array)): + diff = n_idxers - 1 - last_int_index_idx + arr = idxer.reshape(idxer.shape + (1,) * diff) + arrs.append(arr) + else: + raise ValueError(f"Invalid type of idxer: {type(idxer).__name__}") + + return tuple(arrs) @register_discharge_rule(get_p) @@ -313,20 +378,8 @@ def _get_discharge_rule( y = _get_discharge(x, idx, tree) return (None,) * (len(idx) + 1), y -def _prepend_gather(x, indexer): - # NumPy advanced int indexing won't prepend w/ only one dim, so add dummy. - return x[None][(np.array(0, 'int32'), *indexer)] - -def _prepend_scatter(x, indexer, val, *, add=False): - # NumPy advanced int indexing won't prepend w/ only one dim, so add dummy. - # However, since this is scatter, we need to remove the 1-sized dimension - # we added at the front. - if add: - return x[None].at[(0, *indexer)].add(val)[0] - return x[None].at[(0, *indexer)].set(val)[0] - -def _index_array(x, indexer): +def _index_array(x, indexer: indexing.NDIndexer): if _is_trivial_indexer(indexer): return x # Try the three APIs in the following order: `lax.slice`, @@ -336,13 +389,16 @@ def _index_array(x, indexer): # If everything in the indexer is a slice or ()-shaped, we can also # use `lax.dynamic_slice` with 1-sized slices for ()-shaped indices. # We need to squeeze out the 1-sized slices at the end. - elif maybe_slice := _maybe_convert_to_dynamic_slice(indexer): - starts, sizes, squeeze_dims = maybe_slice + elif maybe_dynamic_slice := _maybe_convert_to_dynamic_slice(indexer): + starts, sizes, squeeze_dims = maybe_dynamic_slice y = lax_slicing.dynamic_slice(x, starts, sizes) x = lax.squeeze(y, squeeze_dims) else: - indexer = _convert_to_array_indexer(indexer) - x = x[None][(np.array(0, "int32"), *indexer)] + transpose_order = _maybe_transpose_before_gather(indexer) + if transpose_order is not None: + x, indexer = _perform_transpose_before_gather(x, indexer, transpose_order) + arrays = _convert_to_gather_arrays(indexer) + x = x[arrays] return x @@ -367,53 +423,79 @@ def transform_array(x, transforms): def transform_swap_array(x, transforms, val): if transforms is None: transforms = [] - result = x - result_val = val - # Compute updated "val" (result). - _results = [x] + + # Will hold the value read from `x` before the swap, and will have the same + # shape as `val`. 
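+  # Illustrative sketch, assuming two indexing transforms [idx0, idx1]: the
+  # forward (read) loop below conceptually computes
+  #   intermediates = [x, x[idx0], x[idx0][idx1]]
+  #   new_val       = x[idx0][idx1]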
+ new_val = x + # List of intermediate results by transforming `x`. + intermediates = [x] + + # Read phase (forward loop) for transform in transforms: match transform: case indexing.NDIndexer(): indexer = transform if _is_trivial_indexer(indexer): - _results.append(_results[-1]) + intermediates.append(intermediates[-1]) continue # If everything in the indexer is a slice or ()-shaped, we can also # use `lax.dynamic_slice` with 1-sized slices for ()-shaped indices. # We need to squeeze out the 1-sized slices at the end. if maybe_slice := _maybe_convert_to_dynamic_slice(indexer): starts, sizes, squeeze_dims = maybe_slice - result_old = lax_slicing.dynamic_slice(result, starts, sizes) - result = lax.squeeze(result_old, squeeze_dims) + new_val = lax.squeeze( + lax_slicing.dynamic_slice(new_val, starts, sizes), squeeze_dims + ) else: - indexer = _convert_to_array_indexer(indexer) - result = _prepend_gather(result, indexer) - _results.append(result) + transpose_order = _maybe_transpose_before_gather(indexer) + if transpose_order is not None: + new_val, indexer = _perform_transpose_before_gather( + new_val, indexer, transpose_order + ) + arrays = _convert_to_gather_arrays(indexer) + new_val = new_val[arrays] + # Here, we don't need to transpose `new_val` back because it now holds + # the result of the indexing, and is no longer the original array that + # was indexed into. + intermediates.append(new_val) case RefBitcaster(): - _results.append(bitcast(result, transform.dtype)) + intermediates.append(bitcast(new_val, transform.dtype)) case RefReshaper(): - _results.append(result.reshape(transform.shape)) + intermediates.append(new_val.reshape(transform.shape)) case _: raise NotImplementedError(f"Unsupported transform: {transform}") - # Compute updated "x" (result_val) - for i, transform in reversed(list(enumerate(transforms))): + # Will hold the final state of the `x` after `val` has been written to the + # transformed location, and will have the same shape as `x`. 
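+  # Continuing the sketch above: the reversed (write) loop below conceptually
+  # rebuilds
+  #   new_x = x.at[idx0].set(x[idx0].at[idx1].set(val))
+  # reusing the intermediates recorded during the read phase.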
+ new_x = val + + # Write phase (reversed loop) + for intermediate, transform in reversed(zip(intermediates[:-1], transforms)): if isinstance(transform, indexing.NDIndexer): indexer = transform if _is_trivial_indexer(indexer): continue if maybe_slice := _maybe_convert_to_dynamic_slice(indexer): starts, _, squeeze_dims = maybe_slice - result_val = lax.expand_dims(result_val, squeeze_dims) - result_val = lax_slicing.dynamic_update_slice( - _results[i], result_val, starts + new_x = lax_slicing.dynamic_update_slice( + intermediate, lax.expand_dims(new_x, squeeze_dims), starts ) else: - indexer = _convert_to_array_indexer(indexer) - result_val = _prepend_scatter(_results[i], indexer, result_val) + transpose_order = _maybe_transpose_before_gather(indexer) + if transpose_order is not None: + intermediate, indexer = _perform_transpose_before_gather( + intermediate, indexer, transpose_order + ) + arrays = _convert_to_gather_arrays(indexer) + new_x = intermediate.at[arrays].set(new_x) # pytype: disable=attribute-error + if transpose_order is not None: + transpose_order_inversed = np.argsort(transpose_order) + new_x = new_x.transpose(transpose_order_inversed) else: raise NotImplementedError(f"Unsupported transform: {transform}") - return result, result_val + + return new_val, new_x + def _get_discharge(x, idx, tree): transforms = tree_util.tree_unflatten(tree, idx) @@ -446,8 +528,10 @@ def _addupdate_discharge(x, val, idx, tree): if len(transforms) > 1: raise NotImplementedError("Only single indexer is supported.") indexer = transforms[0] + if _is_trivial_indexer(indexer): return x + val + # If everything in the indexer is a slice or ()-shaped, we can also # use `lax.dynamic_slice` with 1-sized slices for ()-shaped indices. # We need to squeeze out the 1-sized slices at the end. 
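  # (Illustrative note on the gather fallback below.) For a non-contiguous
  # advanced index, the add-update is performed in a transposed layout and then
  # undone, conceptually:
  #   order = _maybe_transpose_before_gather(indexer)
  #   x.transpose(order).at[arrays].add(val).transpose(np.argsort(order))
  # where `arrays` come from the correspondingly reordered indexer.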
@@ -457,8 +541,17 @@ def _addupdate_discharge(x, val, idx, tree): val = lax.expand_dims(val, squeeze_dims) y = lax_slicing.dynamic_update_slice(x, x_old + val, starts) return y - indexer = _convert_to_array_indexer(indexer) - return _prepend_scatter(x, indexer, val, add=True) + + transpose_order = _maybe_transpose_before_gather(indexer) + if transpose_order is not None: + x, indexer = _perform_transpose_before_gather(x, indexer, transpose_order) + arrays = _convert_to_gather_arrays(indexer) + x = x.at[arrays].add(val) + if transpose_order is not None: + transpose_order_inversed = np.argsort(transpose_order) + x = x.transpose(transpose_order_inversed) + return x + @weakref_lru_cache def _cached_closed_jaxpr_discharge(closed_jaxpr: core.ClosedJaxpr): @@ -737,7 +830,7 @@ def _run_state_partial_eval(trace: pe.JaxprTrace, *tracers: pe.JaxprTracer, is_initialized=(True,) * len(jaxpr_unknown.invars)) _, eqn_effects = run_state_p.abstract_eval(*[v.aval for v in unknown_inputs], **uk_params) - eqn = pe.new_eqn_recipe(unknown_inputs, res_ref_unknown_outputs, + eqn = pe.new_eqn_recipe(trace, unknown_inputs, res_ref_unknown_outputs, run_state_p, uk_params, eqn_effects, source) for t in res_ref_unknown_outputs: t.recipe = eqn @@ -1054,3 +1147,35 @@ def wrapped(args): _, out_flat = split_list(out_const_flat, [len(consts)]) return in_tree.unflatten(out_flat) return wrapped + + +@register_discharge_rule(pjit.pjit_p) +def _pjit_state_discharge_rule( + in_avals, out_avals, *args, jaxpr, in_shardings, out_shardings, + in_layouts, out_layouts, **params): + if not all(isinstance(s, sharding_impls.UnspecifiedValue) for s in (*in_shardings, *out_shardings)): + raise NotImplementedError + + if not (all(l is None for l in in_layouts) and + all(l is None for l in out_layouts)): + raise NotImplementedError + + jaxpr, consts = jaxpr.jaxpr, jaxpr.consts + num_outs = len(jaxpr.outvars) + discharged_jaxpr, discharged_consts = discharge_state(jaxpr, consts) + discharged_closed_jaxpr = core.ClosedJaxpr(discharged_jaxpr, discharged_consts) + new_in_shardings = (sharding_impls.UNSPECIFIED,) * len(discharged_jaxpr.invars) + new_out_shardings = (sharding_impls.UNSPECIFIED,) * len(discharged_jaxpr.outvars) + new_in_layouts = (None,) * len(discharged_jaxpr.invars) + new_out_layouts = (None,) * len(discharged_jaxpr.outvars) + out_and_ref_vals = pjit.pjit_p.bind( + *args, jaxpr=discharged_closed_jaxpr, in_shardings=new_in_shardings, + out_shardings=new_out_shardings, in_layouts=new_in_layouts, + out_layouts=new_out_layouts, **params) + out_vals, ref_vals = split_list(out_and_ref_vals, [num_outs]) + ref_vals_iter = iter(ref_vals) + new_invals = tuple(next(ref_vals_iter) if isinstance(aval, AbstractRef) + else None for aval in in_avals) + sentinel = object() + assert next(ref_vals_iter, sentinel) is sentinel + return new_invals, out_vals diff --git a/jax/_src/state/indexing.py b/jax/_src/state/indexing.py index 4b627c1cd581..a0d1d85d09b4 100644 --- a/jax/_src/state/indexing.py +++ b/jax/_src/state/indexing.py @@ -17,9 +17,11 @@ from __future__ import annotations import dataclasses -from typing import Any, Sequence, Union +from typing import Any, Union +from collections.abc import Sequence from jax._src import core +from jax._src import pretty_printer as pp from jax._src import tree_util from jax._src.typing import Array from jax._src.util import merge_lists @@ -78,6 +80,30 @@ def from_slice(cls, slc: slice, size: int) -> Slice: return cls(start, size, step) +def _pp_slice(context: core.JaxprPpContext, dim, slc: Slice) -> str: + 
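+  # Illustrative renderings (assuming a dimension of size 8; `d` and `n` stand
+  # for dynamic values):
+  #   Slice(2, 3)   -> "2:5"
+  #   Slice(0, 8)   -> ":"
+  #   Slice(0, 4)   -> ":4"
+  #   Slice(d, 4)   -> "d:d+4"
+  #   Slice(2, n)   -> "2:2+n"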
start, size = slc.start, slc.size + if isinstance(start, core.Var): + start_str = core.pp_var(start, context) + size_str = ( + core.pp_var(size, context) if isinstance(size, core.Var) else str(size) + ) + return f"{start_str}:{start_str}+{size_str}" + else: + start_str = str(start) + if start == 0: + start_str = "" + if isinstance(size, core.Var): + size_str = core.pp_var(size, context) + if start_str: + return f"{start_str}:{start_str}+{size_str}" + else: + return f":{size_str}" + else: + end = start + size + end_str = "" if end == dim else str(end) + return f"{start_str}:{end_str}" + + def dslice( start: int | Array | None, size: int | Array | None = None, @@ -247,11 +273,21 @@ def from_indices_shape(cls, indices, shape) -> NDIndexer: return cls(indices, shape, int_indexer_shape, validate=True) def get_indexer_shape(self) -> tuple[int | Array, ...]: - _, slice_indexers, _ = unpack_ndindexer(self) - slice_shape = [s.size for s in slice_indexers] - # In NDIndexers, the int_indexer_shape is *always* at the front of the - # result. - return (*self.int_indexer_shape, *slice_shape) + is_int_indexing, slice_indexers, _ = unpack_ndindexer(self) + + slice_shape = tuple(s.size for s in slice_indexers) + int_indexers_contiguous = bool( + np.all(np.diff(np.where(is_int_indexing)[0]) == 1) + ) + if not int_indexers_contiguous: + return self.int_indexer_shape + slice_shape + + has_int_indexers = any(is_int_indexing) + if has_int_indexers: + pos = is_int_indexing.index(True) + return slice_shape[:pos] + self.int_indexer_shape + slice_shape[pos:] + + return slice_shape def transform_shape(self, shape: None | tuple[int | Array, ...]) -> None | tuple[int | Array, ...]: del shape # Unused @@ -282,3 +318,12 @@ def transform_sharding(self, sharding): f"along unsharded axes, but ref of shape {self.shape} " f"was sliced on axis {i}, which is sharded like {s}") return sharding + + def pretty_print(self, context: core.JaxprPpContext) -> pp.Doc: + indices = [] + for idx, dim in zip(self.indices, self.shape): + if isinstance(idx, Slice): + indices.append(_pp_slice(context, dim, idx)) + else: + indices.append(core.pp_var(idx, context, print_literal_dtype=False)) # type: ignore + return pp.concat([pp.text("["), pp.text(",".join(indices)), pp.text("]")]) diff --git a/jax/_src/state/primitives.py b/jax/_src/state/primitives.py index 6f7570a5f3cd..5b83b6a3cb64 100644 --- a/jax/_src/state/primitives.py +++ b/jax/_src/state/primitives.py @@ -18,9 +18,12 @@ import types from typing import Any, Union +import numpy as np + from jax._src import ad_util from jax._src import core from jax._src import dispatch +from jax._src import dtypes from jax._src import pretty_printer as pp from jax._src import traceback_util from jax._src import tree_util @@ -34,15 +37,12 @@ AbstractRef, AccumEffect, ReadEffect, - RefBitcaster, - RefReshaper, Transform, TransformedRef, WriteEffect, ) from jax._src.typing import Array from jax._src.util import safe_map, safe_zip -import numpy as np ## General utilities @@ -144,10 +144,25 @@ def ref_swap( _function_name: str = "ref_swap", ) -> Array: """Sets a `Ref`'s value and returns the original value.""" + if hasattr(ref_or_view, 'dtype'): + value = _maybe_implicit_cast(ref_or_view.dtype, value) ref, transforms = get_ref_and_transforms(ref_or_view, idx, _function_name) flat_transforms, tree = tree_util.tree_flatten(transforms) return swap_p.bind(ref, value, *flat_transforms, tree=tree) +# TODO(slebedev,mattjj): replace with special handling of Python numeric types: +# if (isinstance(value, (int, float, 
complex)) and +# value == np.array(value, dtype).item()): return cast +def _maybe_implicit_cast(dtype, value): + aval = core.typeof(value) + if (aval.weak_type and + (dtypes.issubdtype(dtype, np.floating) and + dtypes.issubdtype(aval.dtype, np.floating)) or + (dtypes.issubdtype(dtype, np.integer) and + dtypes.issubdtype(aval.dtype, np.integer))): + return lax.convert_element_type(value, dtype) + return value + def ref_set( ref_or_view: AbstractRef | TransformedRef, @@ -248,7 +263,7 @@ def _swap_abstract_eval(ref_aval: AbstractRef, f"Expected shape: {expected_out_shape}. " f"Value shape: {val_aval.shape}. " f"Transforms: {transforms}. ") - if expected_out_dtype != val_aval.dtype and not val_aval.weak_type: + if expected_out_dtype != val_aval.dtype: raise ValueError( "Invalid dtype for `swap`. " f"Ref dtype: {expected_out_dtype}. " @@ -297,70 +312,6 @@ def _addupdate_abstract_eval(ref_aval: AbstractRef, pp_ref_var = partial(pp.color, intensity=pp.Intensity.NORMAL, foreground=pp.Color.GREEN) -def _pp_slice(context: core.JaxprPpContext, dim, slc: indexing.Slice - ) -> str: - start, size = slc.start, slc.size - if isinstance(start, core.Var): - start_str = core.pp_var(start, context) - size_str = ( - core.pp_var(size, context) - if isinstance(size, core.Var) - else str(size) - ) - return f'{start_str}:{start_str}+{size_str}' - else: - start_str = str(start) - if start == 0: - start_str = '' - if isinstance(size, core.Var): - size_str = core.pp_var(size, context) - if start_str: - return f'{start_str}:{start_str}+{size_str}' - else: - return f':{size_str}' - else: - end = start + size - end_str = '' if end == dim else str(end) - return f'{start_str}:{end_str}' - -def pp_indexer(context: core.JaxprPpContext,indexer: indexing.NDIndexer - ) -> pp.Doc: - indices = [] - for idx, dim in zip(indexer.indices, indexer.shape): - if isinstance(idx, indexing.Slice): - indices.append(_pp_slice(context, dim, idx)) - else: - indices.append(core.pp_var(idx, context)) # type: ignore - return pp.concat([pp.text("["), pp.text(','.join(indices)), pp.text("]")]) - - -def pp_bitcaster( - context: core.JaxprPpContext, bitcaster: RefBitcaster -) -> pp.Doc: - del context - return pp.text( - f"[bitcast({bitcaster.dtype}[{','.join(str(d) for d in bitcaster.shape)}])]" - ) - - -def pp_reshaper(context: core.JaxprPpContext, reshaper: RefReshaper) -> pp.Doc: - del context - return pp.text( - f"[reshape({reshaper.dtype}[{','.join(str(d) for d in reshaper.shape)}])]" - ) - - -def pp_transform(context: core.JaxprPpContext, transform: Transform) -> pp.Doc: - match transform: - case indexing.NDIndexer(): - return pp_indexer(context, transform) - case RefBitcaster(): - return pp_bitcaster(context, transform) - case RefReshaper(): - return pp_reshaper(context, transform) - case _: - return pp.text(f"[{transform}]") - def _pp_transforms( context: core.JaxprPpContext, @@ -369,7 +320,7 @@ def _pp_transforms( if not transforms: return pp.text("[...]") return pp.concat( - [pp_transform(context, transform) for transform in transforms] + [transform.pretty_print(context) for transform in transforms] ) @@ -503,11 +454,52 @@ def _state_partial_eval_custom(prim, saveable, unks_in, inst_in, eqn): ## get/swap/addupdate batching rules -def _batch_indexer(indexer: indexing.NDIndexer, dims, - axis_size: int, - ref_shape: tuple[int, ...], - ref_dim: int | batching.NotMapped, - idx_is_batched: bool) -> indexing.NDIndexer: +def _batch_indexer( + indexer: indexing.NDIndexer, + dims, + axis_size: int, + ref_shape: tuple[int, ...], + ref_dim: int | 
batching.NotMapped,
+    idx_is_batched: bool,
+) -> indexing.NDIndexer:
+  """Converts a batched indexer into an unbatched one.
+
+  This function handles the complexity of `vmap`-style batching, where the
+  `ref` being indexed, the indexer, or both may have batched dimensions. The
+  goal is to produce a new indexer that acts as if applied in a batched
+  context, but without actual batching, so that downstream code can process it
+  as usual.
+
+  If any index in `indexer` is batched, all array indexers are normalized: if
+  an array indexer contains a batched dimension, that dimension is moved to
+  the front (axis 0); if an array indexer is not batched, it is broadcast to
+  include a batch dimension at the front. This guarantees that all array
+  indexers end up with the same shape.
+
+  Slices are passed through unchanged unless they contain dynamic elements and
+  are themselves batched, which is currently unsupported.
+
+  If `ref` is batched (`ref_dim` is not `NotMapped`), we simulate per-example
+  indexing by inserting a new iota array at the position corresponding to
+  `ref_dim` in the indexer.
+
+  Note that if the array indexers in the original indexer are contiguous but
+  become non-contiguous in the new indexer due to the inserted iota, the
+  dimensions corresponding to the array indexers are moved to the front of the
+  indexing result: the batched dimension ends up at axis 0, and the dimensions
+  corresponding to the original array indexers start at axis 1. This causes a
+  mismatch between the original and the new indexer, so callers must transpose
+  the arrays involved accordingly.
+
+  Args:
+    indexer: An `NDIndexer` that indexes into `ref`.
+    dims: A pytree with the same structure as `indexer`, indicating which
+      dimension (if any) is batched for each array indexer.
+    axis_size: Size of the batch dimension.
+    ref_shape: Shape of `ref`.
+    ref_dim: The dimension of `ref` that is batched (if any).
+    idx_is_batched: Whether any index in the `indexer` is batched.
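+
+  For illustration only: with `axis_size=4`, `ref` batched along axis 0, an
+  unbatched indexer `[jnp.array([0, 2])]`, and `idx_is_batched=False`, the
+  array indexer is aligned with a new leading batch dimension and an iota over
+  the batch is inserted at position 0, so each batch element `b` effectively
+  reads `ref[b, [0, 2]]`.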
+ """ indices = indexer.indices indices_dims = dims.indices new_indices: list[Array | indexing.Slice | int] = [] @@ -545,7 +537,7 @@ def _batch_indexer(indexer: indexing.NDIndexer, dims, idx = lax.broadcast_in_dim(idx, new_integer_indexer_shape, bcast_dims) else: - idx = batching.moveaxis(idx, dim, 0) + idx = batching.moveaxis(idx, dim, 0) # type: ignore[arg-type] new_indices.append(idx) else: if ref_dim is not batching.not_mapped: @@ -559,9 +551,9 @@ def _batch_indexer(indexer: indexing.NDIndexer, dims, if ref_dim is not batching.not_mapped: iota = lax.broadcasted_iota(np.dtype('int32'), new_integer_indexer_shape, 0) new_indices.insert(ref_dim, iota) - return indexing.NDIndexer(tuple(new_indices), ref_shape, - new_integer_indexer_shape, - validate=True) + return indexing.NDIndexer( + tuple(new_indices), ref_shape, new_integer_indexer_shape, validate=True + ) def _get_vmap(batched_args, batched_dims, *, tree): axis_size, = {x.shape[d] for x, d in zip(batched_args, batched_dims) @@ -576,11 +568,42 @@ def _get_vmap(batched_args, batched_dims, *, tree): if len(indexers) > 1: raise NotImplementedError("Batching with multiple indexers not supported.") # TODO(sharadmv): handle vmap of multiple indexers - indexers = tuple(_batch_indexer(indexer, dims, axis_size, + new_indexers = tuple(_batch_indexer(indexer, dims, axis_size, ref.shape, ref_dim, idx_is_batched) for indexer, dims in zip(indexers, indexers_dims)) - flat_indexers, tree = tree_util.tree_flatten(indexers) - return get_p.bind(ref, *flat_indexers, tree=tree), 0 + flat_indexers, tree = tree_util.tree_flatten(new_indexers) + + is_int_indexing, _, _ = indexing.unpack_ndindexer(indexers[0]) + int_indexers_contiguous = bool( + np.all(np.diff(np.where(is_int_indexing)[0]) == 1) + ) + is_new_int_indexing, _, _ = indexing.unpack_ndindexer(new_indexers[0]) + new_int_indexers_contiguous = bool( + np.all(np.diff(np.where(is_new_int_indexing)[0]) == 1) + ) + + out = get_p.bind(ref, *flat_indexers, tree=tree) + if not int_indexers_contiguous: # will always be moved to the front + out_bdim = 0 + else: # originally not going to be moved to the front + if new_int_indexers_contiguous: # now not going to be moved to the front + out_bdim = is_new_int_indexing.index(True) + else: # now going to be moved to the front + original_pos = is_int_indexing.index(True) + array_indexer_shape = new_indexers[0].int_indexer_shape + array_indexer_len = len(array_indexer_shape) + + transpose_order = list(range(len(out.shape))) + transpose_order = ( + transpose_order[0], + *transpose_order[array_indexer_len:array_indexer_len+original_pos], + *transpose_order[1:array_indexer_len], + *transpose_order[array_indexer_len+original_pos:], + ) + + out = lax.transpose(out, transpose_order) + out_bdim = 0 + return out, out_bdim batching.primitive_batchers[get_p] = _get_vmap def _swap_vmap(batched_args, batched_dims, *, tree): @@ -595,18 +618,69 @@ def _swap_vmap(batched_args, batched_dims, *, tree): val_is_batched = val_dim is not batching.not_mapped idx_is_batched = any(i_dim is not batching.not_mapped for i_dim in flat_idx_dims) + + if not ref_is_batched: + raise Exception("performing a set/swap operation with vmapped value on " + "an unbatched mutable array reference " + f"of type {core.typeof(ref)}. 
Move the mutable array to be " + "an argument to the vmapped function?") + if len(indexers) > 1: raise NotImplementedError("Batching with multiple indexers not supported.") # TODO(sharadmv): handle vmap of multiple indexers - indexers = tuple(_batch_indexer(indexer, dims, axis_size, + new_indexers = tuple(_batch_indexer(indexer, dims, axis_size, ref.shape, ref_dim, idx_is_batched) for indexer, dims in zip(indexers, indexers_dims)) - flat_indexers, tree = tree_util.tree_flatten(indexers) - if (ref_is_batched or idx_is_batched) and not val_is_batched: - val = batching.broadcast(val, axis_size, 0) - if val_is_batched: - val = batching.moveaxis(val, val_dim, 0) - return swap_p.bind(ref, val, *flat_indexers, tree=tree), 0 + flat_indexers, tree = tree_util.tree_flatten(new_indexers) + + is_int_indexing, _, _ = indexing.unpack_ndindexer(indexers[0]) + int_indexers_contiguous = bool( + np.all(np.diff(np.where(is_int_indexing)[0]) == 1) + ) + is_new_int_indexing, _, _ = indexing.unpack_ndindexer(new_indexers[0]) + new_int_indexers_contiguous = bool( + np.all(np.diff(np.where(is_new_int_indexing)[0]) == 1) + ) + + if not new_int_indexers_contiguous: # will be moved to the front + batched_dim_in_result = 0 + else: + batched_dim_in_result = is_new_int_indexing.index(True) + 0 + + if not val_is_batched: + if ref_is_batched or idx_is_batched: + val = batching.broadcast(val, axis_size, batched_dim_in_result) + else: + val = batching.moveaxis(val, val_dim, batched_dim_in_result) + + transpose_order_inversed = None + + # Originally not going to be moved to the front, but now going to be moved to + # the front. + if int_indexers_contiguous and not new_int_indexers_contiguous: + original_pos = is_int_indexing.index(True) + array_indexer_shape = new_indexers[0].int_indexer_shape + array_indexer_len = len(array_indexer_shape) + + transpose_order = list(range(len(val.shape))) + transpose_order = ( + transpose_order[0], + *transpose_order[1+original_pos:(1+original_pos)+(array_indexer_len-1)], + *transpose_order[1:1+original_pos], + *transpose_order[(1+original_pos)+(array_indexer_len-1):], + ) + val = val.transpose(transpose_order) + transpose_order_inversed = np.argsort(transpose_order) + + out = swap_p.bind(ref, val, *flat_indexers, tree=tree) + + # `val` should not be transposed, but we needed to transpose it to match + # `swap_p`. As a result, the output of `swap_p` is also transposed. Now we + # need to transpose it back. 
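+  # Illustrative layout of the round trip (only relevant when the int-indexer
+  # block had to move to the front): `val` arrives as
+  #   (batch, leading slice dims, original int-indexer dims, trailing slice dims)
+  # while `swap_p` with the new batched indexer produces and expects
+  #   (batch, int-indexer dims, slice dims...),
+  # so `val` was transposed into that layout above, and the old value returned
+  # by `swap_p` is transposed back below.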
+ if transpose_order_inversed is not None: + out = out.transpose(transpose_order_inversed) + + return out, batched_dim_in_result batching.primitive_batchers[swap_p] = _swap_vmap def _addupdate_vmap(batched_args, batched_dims, *, tree): @@ -624,14 +698,47 @@ def _addupdate_vmap(batched_args, batched_dims, *, tree): if len(indexers) > 1: raise NotImplementedError("Batching with multiple indexers not supported.") # TODO(sharadmv): handle vmap of multiple indexers - indexers = tuple(_batch_indexer(indexer, dims, axis_size, + new_indexers = tuple(_batch_indexer(indexer, dims, axis_size, ref.shape, ref_dim, idx_is_batched) for indexer, dims in zip(indexers, indexers_dims)) - flat_indexers, tree = tree_util.tree_flatten(indexers) - if (ref_is_batched or idx_is_batched) and not val_is_batched: - val = batching.broadcast(val, axis_size, 0) - if val_is_batched: - val = batching.moveaxis(val, val_dim, 0) + flat_indexers, tree = tree_util.tree_flatten(new_indexers) + + is_int_indexing, _, _ = indexing.unpack_ndindexer(indexers[0]) + int_indexers_contiguous = bool( + np.all(np.diff(np.where(is_int_indexing)[0]) == 1) + ) + is_new_int_indexing, _, _ = indexing.unpack_ndindexer(new_indexers[0]) + new_int_indexers_contiguous = bool( + np.all(np.diff(np.where(is_new_int_indexing)[0]) == 1) + ) + + if not new_int_indexers_contiguous: # will be moved to the front + batched_dim_in_result = 0 + else: + batched_dim_in_result = is_new_int_indexing.index(True) + + if not val_is_batched: + if ref_is_batched or idx_is_batched: + val = batching.broadcast(val, axis_size, batched_dim_in_result) + else: + val = batching.moveaxis(val, val_dim, batched_dim_in_result) + + # Originally not going to be moved to the front, but now going to be moved to + # the front. + if int_indexers_contiguous and not new_int_indexers_contiguous: + original_pos = is_int_indexing.index(True) + array_indexer_shape = new_indexers[0].int_indexer_shape + array_indexer_len = len(array_indexer_shape) + + transpose_order = list(range(len(val.shape))) + transpose_order = ( + transpose_order[0], + *transpose_order[1+original_pos:(1+original_pos)+(array_indexer_len-1)], + *transpose_order[1:1+original_pos], + *transpose_order[(1+original_pos)+(array_indexer_len-1):], + ) + val = val.transpose(transpose_order) + return addupdate_p.bind(ref, val, *flat_indexers, tree=tree), [] batching.primitive_batchers[addupdate_p] = _addupdate_vmap @@ -644,7 +751,7 @@ def _addupdate_vmap(batched_args, batched_dims, *, tree): broadcast_to_p = core.Primitive('broadcast_to') def broadcast_to(a: Array, shape: tuple[int, ...]) -> Array: - import jax.numpy as jnp + import jax.numpy as jnp # pytype: disable=import-error a = jnp.asarray(a) if a.shape == shape: return a @@ -652,7 +759,7 @@ def broadcast_to(a: Array, shape: tuple[int, ...]) -> Array: @broadcast_to_p.def_impl def _broadcast_to_impl(a, *, shape): - import jax.numpy as jnp + import jax.numpy as jnp # pytype: disable=import-error return jnp.broadcast_to(a, shape) @broadcast_to_p.def_abstract_eval diff --git a/jax/_src/state/types.py b/jax/_src/state/types.py index 057242f4c1ac..95e298b532e8 100644 --- a/jax/_src/state/types.py +++ b/jax/_src/state/types.py @@ -18,7 +18,8 @@ from collections.abc import Sequence import dataclasses import math -from typing import Any, Callable, Protocol, Union +from typing import Any, Protocol, Union +from collections.abc import Callable from jax._src import core from jax._src import dtypes @@ -75,6 +76,7 @@ class AccumEffect(RefEffect): name: str = "Accum" 
effects.control_flow_allowed_effects.add_type(RefEffect) +effects.partial_eval_kept_effects.add_type(RefEffect) StateEffect = Union[ReadEffect, WriteEffect, AccumEffect] @@ -125,6 +127,10 @@ def transform_sharding(self, sharding): return sharding raise NotImplementedError + def pretty_print(self, context: core.JaxprPpContext) -> pp.Doc: + del context # Unused. + return pp.text(f"{{bitcast({self.dtype}{list(self.shape)}])}}") + @tree_util.register_pytree_node_class @dataclasses.dataclass(frozen=True) @@ -178,6 +184,10 @@ def transform_sharding(self, sharding): return sharding raise NotImplementedError + def pretty_print(self, context: core.JaxprPpContext) -> pp.Doc: + del context # Unused. + return pp.text(f"{{reshape({self.dtype}{list(self.shape)})}}") + class Transform(Protocol): @@ -205,6 +215,9 @@ def transform_sharding(self, sharding): if all(p is None for p in sharding.spec): return sharding # no explicit axes raise NotImplementedError + def pretty_print(self, context: core.JaxprPpContext) -> pp.Doc: + return pp.text(f"{{{self}}}") + @dataclasses.dataclass class RefIndexer: @@ -243,7 +256,7 @@ def shape(self) -> tuple[int | Array, ...]: if not unprocessed: return shape # If there are any unprocessed transforms left, we apply them to the shape - # we've found previuously. + # we've found previously. for t in self.transforms[-unprocessed:]: shape = t.transform_shape(shape) assert shape is not None @@ -266,6 +279,9 @@ def dtype(self): assert dtype is not None return dtype + ndim = property(lambda self: len(self.shape)) + size = property(lambda self: math.prod(self.shape)) + @property def at(self) -> RefIndexer: return RefIndexer(self) @@ -330,6 +346,12 @@ def update(self, inner_aval=None): ndim = property(lambda self: len(self.shape)) size = property(lambda self: math.prod(self.shape)) + def _len(self, ignored_tracer) -> int: + try: + return self.shape[0] + except IndexError as err: + raise TypeError("len() of unsized object") from err # same as numpy error + @property def shape(self): try: @@ -357,6 +379,15 @@ def sharding(self): f"`Ref{{{self.inner_aval.str_short()}}} has no `sharding`." ) from None + @property + def vma(self): + try: + return self.inner_aval.vma # pytype: disable=attribute-error + except AttributeError: + raise AttributeError( + f"`Ref{{{self.inner_aval.str_short()}}} has no `vma`." + ) from None + @core.aval_property def at(self): return RefIndexer(self) @@ -427,7 +458,7 @@ def shaped_array_ref( shape: tuple[int, ...], dtype, weak_type: bool = False) -> AbstractRef: return AbstractRef(core.ShapedArray(shape, dtype, weak_type=weak_type)) -def _shard_ref(mesh, auto, names, ref_aval: AbstractRef): +def _shard_ref(mesh, auto, check_rep, names, ref_aval: AbstractRef): del mesh if names: # Can't actually shard a ref, can only close over it. @@ -435,7 +466,7 @@ def _shard_ref(mesh, auto, names, ref_aval: AbstractRef): return ref_aval core.shard_aval_handlers[AbstractRef] = _shard_ref -def _unshard_ref(mesh, names, ref_aval: AbstractRef): +def _unshard_ref(mesh, check_rep, names, ref_aval: AbstractRef): del mesh if names: # Can't actually shard a ref, can only close over it. 
diff --git a/jax/_src/state/utils.py b/jax/_src/state/utils.py index 2dd57dcde0ca..a07ebd626dee 100644 --- a/jax/_src/state/utils.py +++ b/jax/_src/state/utils.py @@ -14,15 +14,16 @@ """Utilities for tracing stateful functions.""" from functools import partial -from typing import Callable +from collections.abc import Callable -import jax +from jax._src import api from jax._src import core from jax._src import dtypes from jax._src import linear_util as lu from jax._src.interpreters import partial_eval as pe -from jax._src.state import AbstractRef +from jax._src.lax import lax from jax._src.state.primitives import ref_get +from jax._src.state.types import AbstractRef from jax._src.typing import DTypeLike from jax._src.util import safe_map, safe_zip, split_list @@ -112,7 +113,7 @@ def bitcast(x, dtype: DTypeLike): x = x.reshape(*x.shape[:-2], x.shape[-2] // ratio, ratio, -1).swapaxes( -1, -2 ) - y = jax.lax.bitcast_convert_type(x, dtype) + y = lax.bitcast_convert_type(x, dtype) if x_bitwidth > y_bitwidth: y = y.swapaxes(-1, -2).reshape(shape) return y @@ -120,4 +121,4 @@ def bitcast(x, dtype: DTypeLike): def eval_bitcast_shape(x, dtype: DTypeLike): f = partial(bitcast, dtype=dtype) - return jax.eval_shape(f, jax.ShapeDtypeStruct(x.shape, x.dtype)).shape + return api.eval_shape(f, api.ShapeDtypeStruct(x.shape, x.dtype)).shape diff --git a/jax/_src/test_loader.py b/jax/_src/test_loader.py new file mode 100644 index 000000000000..8f97cea1e7bc --- /dev/null +++ b/jax/_src/test_loader.py @@ -0,0 +1,222 @@ +# Copyright 2018 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Contains a custom unittest loader and test suite. + +Implements: +- A test filter based on the JAX_TEST_TARGETS and JAX_EXCLUDE_TEST_TARGETS + environment variables. +- A test suite that runs tests in parallel using threads if JAX_TEST_NUM_THREADS + is >= 1. +- Test decorators that mark a test case or test class as thread-hostile. +""" + +from __future__ import annotations + +from collections.abc import Callable +from concurrent.futures import ThreadPoolExecutor +from contextlib import contextmanager +import logging +import os +import re +import threading +import time +import unittest + +from absl.testing import absltest +from jax._src import config +from jax._src import test_warning_util +from jax._src import util + +logger = logging.getLogger(__name__) + + +_TEST_TARGETS = config.string_flag( + 'test_targets', os.getenv('JAX_TEST_TARGETS', ''), + 'Regular expression specifying which tests to run, called via re.search on ' + 'the test name. If empty or unspecified, run all tests.' +) + +_EXCLUDE_TEST_TARGETS = config.string_flag( + 'exclude_test_targets', os.getenv('JAX_EXCLUDE_TEST_TARGETS', ''), + 'Regular expression specifying which tests NOT to run, called via re.search ' + 'on the test name. If empty or unspecified, run all tests.' +) + +TEST_NUM_THREADS = config.int_flag( + 'jax_test_num_threads', int(os.getenv('JAX_TEST_NUM_THREADS', '0')), + help='Number of threads to use for running tests. 
0 means run everything ' + 'in the main thread. Using > 1 thread is experimental.' +) + +# We use a reader-writer lock to protect test execution. Tests that may run in +# parallel acquire a read lock; tests that are not thread-safe acquire a write +# lock. +_test_rwlock = util.Mutex() + +def _run_one_test(test: unittest.TestCase, result: ThreadSafeTestResult): + if getattr(test.__class__, "thread_hostile", False): + _test_rwlock.writer_lock() + try: + test(result) # type: ignore + finally: + _test_rwlock.writer_unlock() + else: + _test_rwlock.reader_lock() + try: + test(result) # type: ignore + finally: + _test_rwlock.reader_unlock() + + +@contextmanager +def thread_unsafe_test(): + """Decorator for tests that are not thread-safe. + + Note: this decorator (naturally) only applies to what it wraps, not to, say, + code in separate setUp() or tearDown() methods. + """ + if TEST_NUM_THREADS.value <= 0: + yield + return + + _test_rwlock.assert_reader_held() + _test_rwlock.reader_unlock() + _test_rwlock.writer_lock() + try: + yield + finally: + _test_rwlock.writer_unlock() + _test_rwlock.reader_lock() + + +def thread_unsafe_test_class(): + """Decorator that marks a TestCase class as thread-hostile.""" + def f(klass): + assert issubclass(klass, unittest.TestCase), type(klass) + klass.thread_hostile = True + return klass + return f + + +class ThreadSafeTestResult: + """ + Wraps a TestResult to make it thread safe. + + We do this by accumulating API calls and applying them in a batch under a + lock at the conclusion of each test case. + + We duck type instead of inheriting from TestResult because we aren't actually + a perfect implementation of TestResult, and would rather get a loud error + for things we haven't implemented. + """ + def __init__(self, lock: threading.Lock, result: unittest.TestResult): + self.lock = lock + self.test_result = result + self.actions: list[Callable[[], None]] = [] + + def startTest(self, test: unittest.TestCase): + logger.info("Test start: %s", test.id()) + self.start_time = time.time() + + def stopTest(self, test: unittest.TestCase): + logger.info("Test stop: %s", test.id()) + stop_time = time.time() + with self.lock: + # If test_result is an ABSL _TextAndXMLTestResult we override how it gets + # the time. This affects the timing that shows up in the XML output + # consumed by CI. 
+ time_getter = getattr(self.test_result, "time_getter", None) + try: + self.test_result.time_getter = lambda: self.start_time + self.test_result.startTest(test) + for callback in self.actions: + callback() + self.test_result.time_getter = lambda: stop_time + self.test_result.stopTest(test) + finally: + if time_getter is not None: + self.test_result.time_getter = time_getter + + def addSuccess(self, test: unittest.TestCase): + self.actions.append(lambda: self.test_result.addSuccess(test)) + + def addSkip(self, test: unittest.TestCase, reason: str): + self.actions.append(lambda: self.test_result.addSkip(test, reason)) + + def addError(self, test: unittest.TestCase, err): + self.actions.append(lambda: self.test_result.addError(test, err)) + + def addFailure(self, test: unittest.TestCase, err): + self.actions.append(lambda: self.test_result.addFailure(test, err)) + + def addExpectedFailure(self, test: unittest.TestCase, err): + self.actions.append(lambda: self.test_result.addExpectedFailure(test, err)) + + def addDuration(self, test: unittest.TestCase, elapsed): + self.actions.append(lambda: self.test_result.addDuration(test, elapsed)) + + +class JaxTestSuite(unittest.TestSuite): + """Runs tests in parallel using threads if TEST_NUM_THREADS is > 1. + + Caution: this test suite does not run setUpClass or setUpModule methods if + thread parallelism is enabled. + """ + + def __init__(self, suite: unittest.TestSuite): + super().__init__(list(suite)) + + def run(self, result: unittest.TestResult, debug: bool = False) -> unittest.TestResult: + if TEST_NUM_THREADS.value <= 0: + return super().run(result) + + test_warning_util.install_threadsafe_warning_handlers() + + executor = ThreadPoolExecutor(TEST_NUM_THREADS.value) + lock = threading.Lock() + futures = [] + + def run_test(test): + """Recursively runs tests in a test suite or test case.""" + if isinstance(test, unittest.TestSuite): + for subtest in test: + run_test(subtest) + else: + test_result = ThreadSafeTestResult(lock, result) + futures.append(executor.submit(_run_one_test, test, test_result)) + + with executor: + run_test(self) + for future in futures: + future.result() + + return result + + +class JaxTestLoader(absltest.TestLoader): + suiteClass = JaxTestSuite + + def getTestCaseNames(self, testCaseClass): + names = super().getTestCaseNames(testCaseClass) + if _TEST_TARGETS.value: + pattern = re.compile(_TEST_TARGETS.value) + names = [name for name in names + if pattern.search(f"{testCaseClass.__name__}.{name}")] + if _EXCLUDE_TEST_TARGETS.value: + pattern = re.compile(_EXCLUDE_TEST_TARGETS.value) + names = [name for name in names + if not pattern.search(f"{testCaseClass.__name__}.{name}")] + return names diff --git a/jax/_src/test_multiprocess.py b/jax/_src/test_multiprocess.py new file mode 100644 index 000000000000..8a5b6ee1df00 --- /dev/null +++ b/jax/_src/test_multiprocess.py @@ -0,0 +1,254 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Helper for running multi-process tests.""" + +import os +import pathlib +import re +import signal +import subprocess +import time + +from absl import app +from absl import flags +import jax +from jax import config +from jax._src import distributed +try: + import portpicker +except ImportError: + portpicker = None + +from absl.testing import absltest +from jax._src import test_util as jtu + + +_NUM_PROCESSES = flags.DEFINE_integer( + "num_processes", None, "Number of processes to use." +) + +_GPUS_PER_PROCESS = flags.DEFINE_integer( + "gpus_per_process", + 0, + "Number of GPUs per worker process.", +) + +_MULTIPROCESS_TEST_WORKER_ID = flags.DEFINE_integer( + "multiprocess_test_worker_id", + -1, + "Worker id. Set by main test process; should not be set by users.", +) + +_MULTIPROCESS_TEST_CONTROLLER_ADDRESS = flags.DEFINE_string( + "multiprocess_test_controller_address", + "", + "Address of the JAX controller. Set by the main test process; should not be" + " set by users.", +) + + +expect_failures_with_regex = None + + +def main(): + config.config_with_absl() + app.run(_main) + + +class GracefulKiller: + """Add a signal handler that sets a flag if SIGINT or SIGTERM are caught.""" + + # From https://stackoverflow.com/a/31464349 + kill_now = False + + def __init__(self): + signal.signal(signal.SIGINT, self.exit_gracefully) + signal.signal(signal.SIGTERM, self.exit_gracefully) + + def exit_gracefully(self, sig_num, unused_stack_frame): + print(f"Caught signal: {signal.Signals(sig_num).name} ({sig_num})") + self.kill_now = True + + +def _main(argv): + if _MULTIPROCESS_TEST_WORKER_ID.value >= 0: + jax.distributed.initialize( + _MULTIPROCESS_TEST_CONTROLLER_ADDRESS.value, + num_processes=_NUM_PROCESSES.value, + process_id=_MULTIPROCESS_TEST_WORKER_ID.value, + initialization_timeout=10, + ) + absltest.main(testLoader=jtu.JaxTestLoader()) + + if not argv[0].endswith(".py"): # Skip the interpreter path if present. + argv = argv[1:] + + num_processes = _NUM_PROCESSES.value + if num_processes is None: + raise ValueError("num_processes must be set") + gpus_per_process = _GPUS_PER_PROCESS.value + if portpicker is None: + jax_port = 9876 + else: + jax_port = portpicker.pick_unused_port() + subprocesses = [] + output_filenames = [] + output_files = [] + for i in range(num_processes): + env = os.environ.copy() + + args = [ + "/proc/self/exe", + *argv, + f"--num_processes={num_processes}", + f"--multiprocess_test_worker_id={i}", + f"--multiprocess_test_controller_address=localhost:{jax_port}", + "--logtostderr", + ] + + if gpus_per_process > 0: + gpus = range(i * gpus_per_process, (i + 1) * gpus_per_process) + env["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, gpus)) + + undeclared_outputs = os.environ.get("TEST_UNDECLARED_OUTPUTS_DIR", "/tmp") + stdout_name = f"{undeclared_outputs}/jax_{i}_stdout.log" + stderr_name = f"{undeclared_outputs}/jax_{i}_stderr.log" + stdout = open(stdout_name, "wb") + stderr = open(stderr_name, "wb") + print(f"Launching process {i}:") + print(f" stdout: {stdout_name}") + print(f" stderr: {stderr_name}") + proc = subprocess.Popen(args, env=env, stdout=stdout, stderr=stderr) + subprocesses.append(proc) + output_filenames.append((stdout_name, stderr_name)) + output_files.append((stdout, stderr)) + + print(" All launched, running ".center(80, "="), flush=True) + + # Wait for all the children to finish or for a SIGTERM from bazel. If we get + # SIGTERM, we still want to collect their logs, so kill them and continue. 
+ killer = GracefulKiller() + running_procs = dict(enumerate(subprocesses)) + while not killer.kill_now and running_procs: + time.sleep(0.1) + for i, proc in list(running_procs.items()): + if proc.poll() is not None: + print(f"Process {i} finished.", flush=True) + running_procs.pop(i) + if killer.kill_now and running_procs: + print("Caught termination, terminating remaining children.", flush=True) + + # Send a SIGTERM to each child process, to let it know it should terminate. + for i, proc in running_procs.items(): + proc.terminate() + print(f"Process {i} terminated.", flush=True) + + # We give the child process(es) a few seconds for their own cleanup, and + # keep the rest (up to 15s) for copying the children logs into our own. + time.sleep(5) + + # Send a SIGKILL (a "hard" kill) to each child process. This is CRITICAL: + # without it, this process may end up waiting a long time on the proc.wait() + # below, and never get to saving the children logs, making test timeouts + # very hard to debug. + for i, proc in running_procs.items(): + proc.kill() + print(f"Process {i} killed.") + print("Killed all child processes.", flush=True) + + retvals = [] + stdouts = [] + stderrs = [] + for proc, fds, (stdout, stderr) in zip( + subprocesses, output_files, output_filenames + ): + retvals.append(proc.wait()) + for fd in fds: + fd.close() + stdouts.append(pathlib.Path(stdout).read_text(errors="replace")) + stderrs.append(pathlib.Path(stderr).read_text(errors="replace")) + + print(" All finished ".center(80, "="), flush=True) + + print(" Summary ".center(80, "=")) + for i, (retval, stdout, stderr) in enumerate(zip(retvals, stdouts, stderrs)): + m = re.search(r"Ran \d+ tests? in [\d.]+s\n\n.*", stderr, re.MULTILINE) + result = m.group().replace("\n\n", "; ") if m else "Test crashed?" + print( + f"Process {i}, ret: {retval}, len(stdout): {len(stdout)}, " + f"len(stderr): {len(stderr)}; {result}" + ) + + print(" Detailed logs ".center(80, "=")) + for i, (retval, stdout, stderr) in enumerate(zip(retvals, stdouts, stderrs)): + print(f" Process {i}: return code: {retval} ".center(80, "=")) + if stdout: + print(f" Process {i} stdout ".center(80, "-")) + print(stdout) + if stderr: + print(f" Process {i} stderr ".center(80, "-")) + print(stderr) + + print(" Done detailed logs ".center(80, "="), flush=True) + for i, (retval, stderr) in enumerate(zip(retvals, stderrs)): + if retval != 0: + if expect_failures_with_regex is not None: + assert re.search( + expect_failures_with_regex, stderr + ), f"process {i} failed, expected regex: {expect_failures_with_regex}" + else: + assert retval == 0, f"process {i} failed, return value: {retval}" + + +class MultiProcessTest(absltest.TestCase): + + def setUp(self): + """Start tests together.""" + super().setUp() + assert jax.process_count() == _NUM_PROCESSES.value, ( + jax.process_count(), + _NUM_PROCESSES.value, + ) + # Make sure all processes are at the same test case. + client = distributed.global_state.client + try: + client.wait_at_barrier(self._testMethodName + "_start", 10000) + except jax.errors.JaxRuntimeError as e: + msg, *_ = e.args + if msg.startswith("DEADLINE_EXCEEDED"): + raise RuntimeError( + f"Init or some test executed earlier than {self._testMethodName} " + "failed. Check logs from earlier tests to debug further. We " + "recommend debugging that specific failed test with " + "`--test_filter` before running the full test suite again." 
+ ) from e + + def tearDown(self): + """End tests together.""" + client = distributed.global_state.client + # Ensure a shared fate for tests where a subset of processes run different + # test assertions (i.e. some processes may pass and some processes fail - + # but the overall test should fail). + try: + client.wait_at_barrier(self._testMethodName + "_end", 10000) + except jax.errors.JaxRuntimeError as e: + msg, *_ = e.args + if msg.startswith("DEADLINE_EXCEEDED"): + raise RuntimeError( + f"Test {self._testMethodName} failed in another process. We " + "recommend debugging that specific failed test with " + "`--test_filter` before running the full test suite again." + ) from e + super().tearDown() diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py index c55dc2a560e0..f6810b533b31 100644 --- a/jax/_src/test_util.py +++ b/jax/_src/test_util.py @@ -17,7 +17,6 @@ import collections from collections.abc import Callable, Generator, Iterable, Sequence -from concurrent.futures import ThreadPoolExecutor from contextlib import ExitStack, contextmanager import datetime import functools @@ -32,12 +31,10 @@ import tempfile import textwrap import threading -import time from typing import Any, TextIO import unittest import zlib -from absl.testing import absltest from absl.testing import parameterized import jax from jax import lax @@ -51,21 +48,28 @@ from jax._src import lib as _jaxlib from jax._src import monitoring from jax._src import test_warning_util +from jax._src.typing import ArrayLike, DTypeLike from jax._src import xla_bridge from jax._src import util from jax._src import mesh as mesh_lib from jax._src.cloud_tpu_init import running_in_cloud_tpu_vm from jax._src.interpreters import mlir +from jax._src.lib.mlir.dialects import hlo from jax._src.numpy.util import promote_dtypes, promote_dtypes_inexact from jax._src.public_test_util import ( # noqa: F401 _assert_numpy_allclose, _check_dtypes_match, _default_tolerance, _dtype, check_close, check_grads, - check_jvp, check_vjp, default_gradient_tolerance, default_tolerance, rand_like, tolerance) + check_jvp, check_vjp, default_gradient_tolerance, default_tolerance, rand_like, tolerance, ToleranceDict) +from jax._src.test_loader import thread_unsafe_test as thread_unsafe_test +from jax._src.test_loader import thread_unsafe_test_class as thread_unsafe_test_class +from jax._src.test_loader import JaxTestLoader as JaxTestLoader +from jax._src.test_loader import TEST_NUM_THREADS as TEST_NUM_THREADS from jax._src.util import unzip2 from jax.tree_util import tree_all, tree_flatten, tree_map, tree_unflatten import numpy as np import numpy.random as npr + # This submodule includes private test utilities that are not exported to # jax.test_util. Functionality appearing here is for internal use only, and # may be changed or removed at any time and without any deprecation cycle. @@ -89,22 +93,12 @@ 'sampling process is terminated.' ) -_SKIP_SLOW_TESTS = config.bool_flag( +SKIP_SLOW_TESTS = config.bool_flag( 'jax_skip_slow_tests', config.bool_env('JAX_SKIP_SLOW_TESTS', False), help='Skip tests marked as slow (> 5 sec).' ) -_TEST_TARGETS = config.string_flag( - 'test_targets', os.getenv('JAX_TEST_TARGETS', ''), - 'Regular expression specifying which tests to run, called via re.search on ' - 'the test name. If empty or unspecified, run all tests.' 
-) -_EXCLUDE_TEST_TARGETS = config.string_flag( - 'exclude_test_targets', os.getenv('JAX_EXCLUDE_TEST_TARGETS', ''), - 'Regular expression specifying which tests NOT to run, called via re.search ' - 'on the test name. If empty or unspecified, run all tests.' -) TEST_WITH_PERSISTENT_COMPILATION_CACHE = config.bool_flag( 'jax_test_with_persistent_compilation_cache', config.bool_env('JAX_TEST_WITH_PERSISTENT_COMPILATION_CACHE', False), @@ -118,11 +112,6 @@ 'deterministic, interactive'), ) -TEST_NUM_THREADS = config.int_flag( - 'jax_test_num_threads', int(os.getenv('JAX_TEST_NUM_THREADS', '0')), - help='Number of threads to use for running tests. 0 means run everything ' - 'in the main thread. Using > 1 thread is experimental.' -) # We sanitize test names to ensure they work with "unitttest -k" and # "pytest -k" test filtering. pytest accepts '[' and ']' but unittest -k @@ -131,10 +120,10 @@ def sanitize_test_name(s: str) -> str: return kSanitizeNameRE.sub("_", s) -def num_float_bits(dtype): +def num_float_bits(dtype: DTypeLike) -> int: return _dtypes.finfo(_dtypes.canonicalize_dtype(dtype)).bits -def to_default_dtype(arr): +def to_default_dtype(arr: ArrayLike) -> np.ndarray: """Convert a value to an array with JAX's default dtype. This is generally used for type conversions of values returned by numpy functions, @@ -145,7 +134,7 @@ def to_default_dtype(arr): dtype = _dtypes._default_types.get(arr.dtype.kind) return arr.astype(_dtypes.canonicalize_dtype(dtype)) if dtype else arr -def with_jax_dtype_defaults(func, use_defaults=True): +def with_jax_dtype_defaults(func: Callable[..., Any], use_defaults: bool = True): """Return a version of a function with outputs that match JAX's default dtypes. This is generally used to wrap numpy functions within tests, in order to make @@ -168,7 +157,7 @@ def wrapped(*args, **kwargs): return tree_map(f, result, use_defaults) return wrapped -def is_sequence(x): +def is_sequence(x: Any) -> bool: try: iter(x) except TypeError: @@ -176,14 +165,16 @@ def is_sequence(x): else: return True -def _normalize_tolerance(tol): +def _normalize_tolerance(tol: int | float | ToleranceDict | None) -> ToleranceDict: tol = tol or 0 if isinstance(tol, dict): return {np.dtype(k): v for k, v in tol.items()} else: return dict.fromkeys(_default_tolerance, tol) -def join_tolerance(tol1, tol2): +def join_tolerance( + tol1: int | float | ToleranceDict | None, + tol2: int | float | ToleranceDict | None) -> ToleranceDict: tol1 = _normalize_tolerance(tol1) tol2 = _normalize_tolerance(tol2) out = tol1 @@ -192,7 +183,7 @@ def join_tolerance(tol1, tol2): return out -def check_eq(xs, ys, err_msg=''): +def check_eq(xs: Any, ys: Any, err_msg: str = '') -> None: assert_close = partial(_assert_numpy_allclose, err_msg=err_msg) tree_all(tree_map(assert_close, xs, ys)) @@ -363,6 +354,18 @@ def assert_num_jit_and_pmap_compilations(times): raise AssertionError(f"Expected exactly {times} XLA compilations, " f"but executed {count()}") +@contextmanager +def count_internal_device_puts(): + before = jax._src.lib._jax.get_internal_device_put_info() + counts = {} + try: + yield lambda: counts + finally: + after = jax._src.lib._jax.get_internal_device_put_info() + for k, v in after.items(): + diff = v - before.get(k, 0) + if diff != 0: + counts[k] = diff def jaxlib_version() -> tuple[int, ...]: return _jaxlib.version @@ -373,8 +376,9 @@ def device_under_test(): def supported_dtypes(): if device_under_test() == "tpu": - types = {np.bool_, np.int8, np.int16, np.int32, np.uint8, np.uint16, - np.uint32, 
_dtypes.bfloat16, np.float16, np.float32, np.complex64, + types = {np.bool_, _dtypes.int4, np.int8, np.int16, np.int32, + _dtypes.uint4, np.uint8, np.uint16, np.uint32, + _dtypes.bfloat16, np.float16, np.float32, np.complex64, _dtypes.float8_e4m3fn, _dtypes.float8_e4m3b11fnuz, _dtypes.float8_e5m2} elif device_under_test() == "gpu": @@ -386,8 +390,8 @@ def supported_dtypes(): elif device_under_test() == "METAL": types = {np.int32, np.uint32, np.float32} else: - types = {np.bool_, np.int8, np.int16, np.int32, np.int64, - np.uint8, np.uint16, np.uint32, np.uint64, + types = {np.bool_, _dtypes.int4, np.int8, np.int16, np.int32, np.int64, + _dtypes.uint4, np.uint8, np.uint16, np.uint32, np.uint64, _dtypes.bfloat16, np.float16, np.float32, np.float64, np.complex64, np.complex128} if not config.enable_x64.value: @@ -428,14 +432,22 @@ def pjrt_c_api_version_at_least(major_version: int, minor_version: int): return True return pjrt_c_api_versions >= (major_version, minor_version) +def stablehlo_version_at_least(required_version: str): + plugin_version = xla_bridge.backend_stablehlo_version() + if plugin_version is None: + return True + return hlo.get_smaller_version( + ".".join(map(str, plugin_version)), required_version + ) == plugin_version + def get_tpu_version() -> int: if device_under_test() != "tpu": raise ValueError("Device is not TPU") kind = jax.devices()[0].device_kind - if kind.endswith(' lite'): - kind = kind[:-len(' lite')] - assert kind[:-1] == "TPU v", kind - return int(kind[-1]) + match = re.match(r"TPU[^\d]*(\d+)", kind) + if match is None: + raise ValueError(f"Device kind {kind} is not supported") + return int(match.group(1)) def is_device_tpu_at_least(version: int) -> bool: if device_under_test() != "tpu": @@ -1044,165 +1056,6 @@ def sample_product(*args, **kw): """ return parameterized.parameters(*sample_product_testcases(*args, **kw)) -# We use a reader-writer lock to protect test execution. Tests that may run in -# parallel acquire a read lock; tests that are not thread-safe acquire a write -# lock. -_test_rwlock = util.Mutex() - -def _run_one_test(test: unittest.TestCase, result: ThreadSafeTestResult): - if getattr(test.__class__, "thread_hostile", False): - _test_rwlock.writer_lock() - try: - test(result) # type: ignore - finally: - _test_rwlock.writer_unlock() - else: - _test_rwlock.reader_lock() - try: - test(result) # type: ignore - finally: - _test_rwlock.reader_unlock() - - -@contextmanager -def thread_unsafe_test(): - """Decorator for tests that are not thread-safe. - - Note: this decorator (naturally) only applies to what it wraps, not to, say, - code in separate setUp() or tearDown() methods. - """ - if TEST_NUM_THREADS.value <= 0: - yield - return - - _test_rwlock.assert_reader_held() - _test_rwlock.reader_unlock() - _test_rwlock.writer_lock() - try: - yield - finally: - _test_rwlock.writer_unlock() - _test_rwlock.reader_lock() - - -def thread_unsafe_test_class(): - "Decorator that marks a TestCase class as thread-hostile." - def f(klass): - assert issubclass(klass, unittest.TestCase), type(klass) - klass.thread_hostile = True - return klass - return f - - -class ThreadSafeTestResult: - """ - Wraps a TestResult to make it thread safe. - - We do this by accumulating API calls and applying them in a batch under a - lock at the conclusion of each test case. - - We duck type instead of inheriting from TestResult because we aren't actually - a perfect implementation of TestResult, and would rather get a loud error - for things we haven't implemented. 
- """ - def __init__(self, lock: threading.Lock, result: unittest.TestResult): - self.lock = lock - self.test_result = result - self.actions: list[Callable] = [] - - def startTest(self, test: unittest.TestCase): - del test - self.start_time = time.time() - - def stopTest(self, test: unittest.TestCase): - stop_time = time.time() - with self.lock: - # If test_result is an ABSL _TextAndXMLTestResult we override how it gets - # the time. This affects the timing that shows up in the XML output - # consumed by CI. - time_getter = getattr(self.test_result, "time_getter", None) - try: - self.test_result.time_getter = lambda: self.start_time - self.test_result.startTest(test) - for callback in self.actions: - callback() - self.test_result.time_getter = lambda: stop_time - self.test_result.stopTest(test) - finally: - if time_getter is not None: - self.test_result.time_getter = time_getter - - def addSuccess(self, test: unittest.TestCase): - self.actions.append(lambda: self.test_result.addSuccess(test)) - - def addSkip(self, test: unittest.TestCase, reason: str): - self.actions.append(lambda: self.test_result.addSkip(test, reason)) - - def addError(self, test: unittest.TestCase, err): - self.actions.append(lambda: self.test_result.addError(test, err)) - - def addFailure(self, test: unittest.TestCase, err): - self.actions.append(lambda: self.test_result.addFailure(test, err)) - - def addExpectedFailure(self, test: unittest.TestCase, err): - self.actions.append(lambda: self.test_result.addExpectedFailure(test, err)) - - def addDuration(self, test: unittest.TestCase, elapsed): - self.actions.append(lambda: self.test_result.addDuration(test, elapsed)) - - -class JaxTestSuite(unittest.TestSuite): - """Runs tests in parallel using threads if TEST_NUM_THREADS is > 1. - - Caution: this test suite does not run setUpClass or setUpModule methods if - thread parallelism is enabled. - """ - - def __init__(self, suite: unittest.TestSuite): - super().__init__(list(suite)) - - def run(self, result: unittest.TestResult, debug: bool = False) -> unittest.TestResult: - if TEST_NUM_THREADS.value <= 0: - return super().run(result) - - test_warning_util.install_threadsafe_warning_handlers() - - executor = ThreadPoolExecutor(TEST_NUM_THREADS.value) - lock = threading.Lock() - futures = [] - - def run_test(test): - "Recursively runs tests in a test suite or test case." 
- if isinstance(test, unittest.TestSuite): - for subtest in test: - run_test(subtest) - else: - test_result = ThreadSafeTestResult(lock, result) - futures.append(executor.submit(_run_one_test, test, test_result)) - - with executor: - run_test(self) - for future in futures: - future.result() - - return result - - -class JaxTestLoader(absltest.TestLoader): - suiteClass = JaxTestSuite - - def getTestCaseNames(self, testCaseClass): - names = super().getTestCaseNames(testCaseClass) - if _TEST_TARGETS.value: - pattern = re.compile(_TEST_TARGETS.value) - names = [name for name in names - if pattern.search(f"{testCaseClass.__name__}.{name}")] - if _EXCLUDE_TEST_TARGETS.value: - pattern = re.compile(_EXCLUDE_TEST_TARGETS.value) - names = [name for name in names - if not pattern.search(f"{testCaseClass.__name__}.{name}")] - return names - def with_config(**kwds): """Test case decorator for subclasses of JaxTestCase""" @@ -1348,15 +1201,15 @@ def assertDeprecationWarnsOrRaises(self, deprecation_id: str, message: str): else: return self.assertWarnsRegex(DeprecationWarning, message) - def assertArraysEqual(self, x, y, *, check_dtypes=True, err_msg='', + def assertArraysEqual(self, actual, desired, *, check_dtypes=True, err_msg='', allow_object_dtype=False, verbose=True): """Assert that x and y arrays are exactly equal.""" if check_dtypes: - self.assertDtypesMatch(x, y) - x = np.asarray(x) - y = np.asarray(y) + self.assertDtypesMatch(actual, desired) + actual = np.asarray(actual) + desired = np.asarray(desired) - if (not allow_object_dtype) and (x.dtype == object or y.dtype == object): + if (not allow_object_dtype) and (actual.dtype == object or desired.dtype == object): # See https://github.com/jax-ml/jax/issues/17867 raise TypeError( "assertArraysEqual may be poorly behaved when np.asarray casts to dtype=object. 
" @@ -1366,57 +1219,57 @@ def assertArraysEqual(self, x, y, *, check_dtypes=True, err_msg='', # Work around https://github.com/numpy/numpy/issues/18992 with np.errstate(over='ignore'): - np.testing.assert_array_equal(x, y, err_msg=err_msg, + np.testing.assert_array_equal(actual, desired, err_msg=err_msg, verbose=verbose) - def assertArraysAllClose(self, x, y, *, check_dtypes=True, atol=None, + def assertArraysAllClose(self, actual, desired, *, check_dtypes=True, atol=None, rtol=None, err_msg=''): - """Assert that x and y are close (up to numerical tolerances).""" - self.assertEqual(x.shape, y.shape) - atol = max(tolerance(_dtype(x), atol), tolerance(_dtype(y), atol)) - rtol = max(tolerance(_dtype(x), rtol), tolerance(_dtype(y), rtol)) + """Assert that actual and desired are close (up to numerical tolerances).""" + self.assertEqual(actual.shape, desired.shape) + atol = max(tolerance(_dtype(actual), atol), tolerance(_dtype(desired), atol)) + rtol = max(tolerance(_dtype(actual), rtol), tolerance(_dtype(desired), rtol)) - _assert_numpy_allclose(x, y, atol=atol, rtol=rtol, err_msg=err_msg) + _assert_numpy_allclose(actual, desired, atol=atol, rtol=rtol, err_msg=err_msg) if check_dtypes: - self.assertDtypesMatch(x, y) + self.assertDtypesMatch(actual, desired) - def assertDtypesMatch(self, x, y, *, canonicalize_dtypes=True): + def assertDtypesMatch(self, actual, desired, *, canonicalize_dtypes=True): if not config.enable_x64.value and canonicalize_dtypes: - self.assertEqual(_dtypes.canonicalize_dtype(_dtype(x), allow_extended_dtype=True), - _dtypes.canonicalize_dtype(_dtype(y), allow_extended_dtype=True)) + self.assertEqual(_dtypes.canonicalize_dtype(_dtype(actual), allow_extended_dtype=True), + _dtypes.canonicalize_dtype(_dtype(desired), allow_extended_dtype=True)) else: - self.assertEqual(_dtype(x), _dtype(y)) + self.assertEqual(_dtype(actual), _dtype(desired)) - def assertAllClose(self, x, y, *, check_dtypes=True, atol=None, rtol=None, + def assertAllClose(self, actual, desired, *, check_dtypes=True, atol=None, rtol=None, canonicalize_dtypes=True, err_msg=''): - """Assert that x and y, either arrays or nested tuples/lists, are close.""" - if isinstance(x, dict): - self.assertIsInstance(y, dict) - self.assertEqual(set(x.keys()), set(y.keys())) - for k in x.keys(): - self.assertAllClose(x[k], y[k], check_dtypes=check_dtypes, atol=atol, + """Assert that actual and desired, either arrays or nested tuples/lists, are close.""" + if isinstance(actual, dict): + self.assertIsInstance(desired, dict) + self.assertEqual(set(actual.keys()), set(desired.keys())) + for k in actual.keys(): + self.assertAllClose(actual[k], desired[k], check_dtypes=check_dtypes, atol=atol, rtol=rtol, canonicalize_dtypes=canonicalize_dtypes, err_msg=err_msg) - elif is_sequence(x) and not hasattr(x, '__array__'): - self.assertTrue(is_sequence(y) and not hasattr(y, '__array__')) - self.assertEqual(len(x), len(y)) - for x_elt, y_elt in zip(x, y): - self.assertAllClose(x_elt, y_elt, check_dtypes=check_dtypes, atol=atol, + elif is_sequence(actual) and not hasattr(actual, '__array__'): + self.assertTrue(is_sequence(desired) and not hasattr(desired, '__array__')) + self.assertEqual(len(actual), len(desired)) + for actual_elt, desired_elt in zip(actual, desired): + self.assertAllClose(actual_elt, desired_elt, check_dtypes=check_dtypes, atol=atol, rtol=rtol, canonicalize_dtypes=canonicalize_dtypes, err_msg=err_msg) - elif hasattr(x, '__array__') or np.isscalar(x): - self.assertTrue(hasattr(y, '__array__') or np.isscalar(y)) + elif 
hasattr(actual, '__array__') or np.isscalar(actual): + self.assertTrue(hasattr(desired, '__array__') or np.isscalar(desired)) if check_dtypes: - self.assertDtypesMatch(x, y, canonicalize_dtypes=canonicalize_dtypes) - x = np.asarray(x) - y = np.asarray(y) - self.assertArraysAllClose(x, y, check_dtypes=False, atol=atol, rtol=rtol, + self.assertDtypesMatch(actual, desired, canonicalize_dtypes=canonicalize_dtypes) + actual = np.asarray(actual) + desired = np.asarray(desired) + self.assertArraysAllClose(actual, desired, check_dtypes=False, atol=atol, rtol=rtol, err_msg=err_msg) - elif x == y: + elif actual == desired: return else: - raise TypeError((type(x), type(y))) + raise TypeError((type(actual), type(desired))) def assertMultiLineStrippedEqual(self, expected, what): """Asserts two strings are equal, after dedenting and stripping each line.""" @@ -1431,7 +1284,6 @@ def assertMultiLineStrippedEqual(self, expected, what): self.assertMultiLineEqual(expected_clean, what_clean, msg=f"Found\n{what}\nExpecting\n{expected}") - @contextmanager def assertNoWarnings(self): with test_warning_util.raise_on_warnings(): @@ -1501,9 +1353,9 @@ def wrapped_fun(*args): python_should_be_executing = False compiled_ans = cfun(*args) - self.assertAllClose(python_ans, monitored_ans, check_dtypes=check_dtypes, + self.assertAllClose(monitored_ans, python_ans, check_dtypes=check_dtypes, atol=atol or tol, rtol=rtol or tol) - self.assertAllClose(python_ans, compiled_ans, check_dtypes=check_dtypes, + self.assertAllClose(compiled_ans, python_ans, check_dtypes=check_dtypes, atol=atol or tol, rtol=rtol or tol) args = args_maker() @@ -1514,7 +1366,7 @@ def wrapped_fun(*args): python_should_be_executing = False compiled_ans = cfun(*args) - self.assertAllClose(python_ans, compiled_ans, check_dtypes=check_dtypes, + self.assertAllClose(compiled_ans, python_ans, check_dtypes=check_dtypes, atol=atol or tol, rtol=rtol or tol) def _CheckAgainstNumpy(self, numpy_reference_op, lax_op, args_maker, @@ -1523,7 +1375,7 @@ def _CheckAgainstNumpy(self, numpy_reference_op, lax_op, args_maker, args = args_maker() lax_ans = lax_op(*args) numpy_ans = numpy_reference_op(*args) - self.assertAllClose(numpy_ans, lax_ans, check_dtypes=check_dtypes, + self.assertAllClose(lax_ans, numpy_ans, check_dtypes=check_dtypes, atol=atol or tol, rtol=rtol or tol, canonicalize_dtypes=canonicalize_dtypes) @@ -1575,12 +1427,12 @@ def with_and_without_mesh(f): ('Mesh', (('x', 2),), (('i', 'x'),)) ))(with_mesh_from_kwargs(f)) -def with_user_mesh(sizes, names, axis_types=None): +def with_explicit_mesh(sizes, names, axis_types=None, iota_order=False): axis_types = ((mesh_lib.AxisType.Explicit,) * len(names) if axis_types is None else axis_types) def decorator(fn): def mesh_fn(*args, **kwargs): - mesh = create_mesh(sizes, names, axis_types=axis_types) + mesh = create_mesh(sizes, names, iota_order, axis_types=axis_types) with jax.sharding.use_mesh(mesh): return fn(*args, **kwargs, mesh=mesh) return mesh_fn @@ -1630,15 +1482,11 @@ def custom_floats(self): _dtypes.float8_e4m3fnuz, _dtypes.float8_e5m2, _dtypes.float8_e5m2fnuz, + _dtypes.float8_e3m4, + _dtypes.float8_e4m3, + _dtypes.float8_e8m0fnu, + _dtypes.float4_e2m1fn, ] - if _dtypes.float8_e3m4 is not None: - float_dtypes += [_dtypes.float8_e3m4] - if _dtypes.float8_e4m3 is not None: - float_dtypes += [_dtypes.float8_e4m3] - if _dtypes.float8_e8m0fnu is not None: - float_dtypes += [_dtypes.float8_e8m0fnu] - if _dtypes.float4_e2m1fn is not None: - float_dtypes += [_dtypes.float4_e2m1fn] return 
self.supported(float_dtypes) @_cached_property diff --git a/jax/_src/third_party/scipy/signal_helper.py b/jax/_src/third_party/scipy/signal_helper.py index 4a021675804d..ad7bdfbef62a 100644 --- a/jax/_src/third_party/scipy/signal_helper.py +++ b/jax/_src/third_party/scipy/signal_helper.py @@ -57,7 +57,7 @@ def _triage_segments(window: ArrayLike | str | tuple[Any, ...], nperseg: int | N win = get_window(window, nperseg_int) win = jnp.array(win, dtype=dtype) else: - win = jnp.asarray(window) + win = jnp.asarray(window, dtype=dtype) nperseg_int = win.size if nperseg is None else int(nperseg) if win.ndim != 1: raise ValueError('window must be 1-D') diff --git a/jax/_src/tpu/__init__.py b/jax/_src/tpu/__init__.py new file mode 100644 index 000000000000..1337256a5074 --- /dev/null +++ b/jax/_src/tpu/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/jax/experimental/pallas/gpu.py b/jax/_src/tpu/linalg/__init__.py similarity index 63% rename from jax/experimental/pallas/gpu.py rename to jax/_src/tpu/linalg/__init__.py index 0ee84c8453ec..8c09b25d1e08 100644 --- a/jax/experimental/pallas/gpu.py +++ b/jax/_src/tpu/linalg/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2023 The JAX Authors. +# Copyright 2025 The JAX Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,13 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -from jax._src import deprecations +import os -deprecations.warn( - "pallas-gpu-triton", - "The ``jax.experimental.pallas.gpu`` submodule is deprecated. 
" - " Use ``jax.experimental.pallas.triton`` instead.", - stacklevel=1, +from jax._src.tpu.linalg import ( + eigh as eigh, + qdwh as qdwh, + svd as svd, ) -from jax.experimental.pallas.triton import * # noqa: F403 +from jax._src import traceback_util +traceback_util.register_exclusion(os.path.dirname(__file__)) diff --git a/jax/_src/lax/eigh.py b/jax/_src/tpu/linalg/eigh.py similarity index 90% rename from jax/_src/lax/eigh.py rename to jax/_src/tpu/linalg/eigh.py index 99711dc6bf0e..dda254579459 100644 --- a/jax/_src/lax/eigh.py +++ b/jax/_src/tpu/linalg/eigh.py @@ -33,15 +33,20 @@ import numpy as np import jax +from jax._src import core +from jax._src import dtypes import jax._src.numpy.lax_numpy as jnp import jax._src.numpy.linalg as jnp_linalg +from jax._src.interpreters import mlir from jax._src.numpy import tensor_contractions from jax._src.numpy import reductions from jax._src.numpy import ufuncs from jax import lax -from jax._src.lax import qdwh +from jax._src.lax import control_flow +from jax._src.lax import lax as lax_internal from jax._src.lax import linalg as lax_linalg -from jax._src.lax.stack import Stack +from jax._src.tpu.linalg import qdwh +from jax._src.tpu.linalg.stack import Stack # QDWH-eigh is a recursive algorithm where the structure of the recursion @@ -573,3 +578,63 @@ def eigh( eig_vecs = eig_vecs[:, sort_idxs] return eig_vals, eig_vecs + + +def _T(x: jax.Array) -> jax.Array: + return lax.transpose(x, (*range(x.ndim - 2), x.ndim - 1, x.ndim - 2)) + + +def _eigh_tpu_impl(x, *, lower, sort_eigenvalues, subset_by_index): + *_, m, n = x.shape + assert m == n, (m, n) + + termination_size = 256 + if not core.is_constant_dim(m): + # TODO: maybe we can relax the check below for shape polymorphism? + raise NotImplementedError( + "Shape polymorphism for native lowering for eigh is implemented " + f"only for the batch dimensions: {x.shape}") + if m <= termination_size and ( + subset_by_index is None or subset_by_index == (0, n) + ): + eig_vals, eig_vecs = lax_linalg.eigh_jacobi(x, lower=lower, + sort_eigenvalues=sort_eigenvalues) + return eig_vecs, eig_vals + + def eigh_qdwh(x): + if len(x.shape) > 2: + return control_flow.map(eigh_qdwh, x) + + # We should only look at elements from the lower/upper triangle. Reflects + # that triangle into the other triangle to form a Hermitian matrix. 
+ if lower: + mask = lax_internal._tri(bool, (n, n), 0) + else: + mask = lax.bitwise_not(lax_internal._tri(bool, (n, n), -1)) + if dtypes.issubdtype(x.dtype, np.complexfloating): + re = lax.select(mask, lax.real(x), _T(lax.real(x))) + if lower: + im_mask = lax_internal._tri(bool, (n, n), -1) + else: + im_mask = lax.bitwise_not(lax_internal._tri(bool, (n, n), 0)) + im = lax.imag(x) + im = lax.select(im_mask, im, lax.full_like(im, 0)) + im = lax.select(mask, im, -_T(im)) + x = lax.complex(re, im) + else: + x = lax.select(mask, x, _T(x)) + + return eigh( + x, + sort_eigenvalues=sort_eigenvalues, + termination_size=termination_size, + subset_by_index=subset_by_index, + ) + + eig_vals, eig_vecs = eigh_qdwh(x) + return eig_vecs, eig_vals + + +mlir.register_lowering( + lax_linalg.eigh_p, mlir.lower_fun(_eigh_tpu_impl, multiple_results=True), + platform='tpu') diff --git a/jax/_src/lax/qdwh.py b/jax/_src/tpu/linalg/qdwh.py similarity index 100% rename from jax/_src/lax/qdwh.py rename to jax/_src/tpu/linalg/qdwh.py diff --git a/jax/_src/lax/stack.py b/jax/_src/tpu/linalg/stack.py similarity index 87% rename from jax/_src/lax/stack.py rename to jax/_src/tpu/linalg/stack.py index 882195f17d51..0225e66f43d8 100644 --- a/jax/_src/lax/stack.py +++ b/jax/_src/tpu/linalg/stack.py @@ -22,10 +22,12 @@ from typing import Any -import jax from jax import lax import jax.numpy as jnp +from jax._src import tree_util + + class Stack: """A bounded functional stack implementation. Elements may be pytrees.""" def __init__(self, size, data): @@ -45,7 +47,7 @@ def create(capacity: int, prototype: Any) -> Stack: """ return Stack( jnp.array(0, jnp.int32), - jax.tree_util.tree_map( + tree_util.tree_map( lambda x: jnp.zeros((capacity,) + tuple(x.shape), x.dtype), prototype)) def empty(self) -> Any: @@ -56,23 +58,23 @@ def push(self, elem: Any) -> Stack: """Pushes `elem` onto the stack, returning the updated stack.""" return Stack( self._size + 1, - jax.tree_util.tree_map( + tree_util.tree_map( lambda x, y: lax.dynamic_update_index_in_dim(x, y, self._size, 0), self._data, elem)) def pop(self) -> tuple[Any, Stack]: """Pops from the stack, returning an (elem, updated stack) pair.""" - elem = jax.tree_util.tree_map( + elem = tree_util.tree_map( lambda x: lax.dynamic_index_in_dim(x, self._size - 1, 0, keepdims=False), self._data) return elem, Stack(self._size - 1, self._data) def flatten(self): - leaves, treedef = jax.tree_util.tree_flatten(self._data) + leaves, treedef = tree_util.tree_flatten(self._data) return ([self._size] + leaves), treedef @staticmethod def unflatten(treedef, leaves): - return Stack(leaves[0], jax.tree_util.tree_unflatten(treedef, leaves[1:])) + return Stack(leaves[0], tree_util.tree_unflatten(treedef, leaves[1:])) -jax.tree_util.register_pytree_node(Stack, Stack.flatten, Stack.unflatten) +tree_util.register_pytree_node(Stack, Stack.flatten, Stack.unflatten) diff --git a/jax/_src/lax/svd.py b/jax/_src/tpu/linalg/svd.py similarity index 86% rename from jax/_src/lax/svd.py rename to jax/_src/tpu/linalg/svd.py index 9f22f130cbb2..298d6650b618 100644 --- a/jax/_src/lax/svd.py +++ b/jax/_src/tpu/linalg/svd.py @@ -43,6 +43,8 @@ import jax from jax import lax from jax._src import core +from jax._src.interpreters import mlir +from jax._src.lax import linalg as lax_linalg import jax.numpy as jnp @@ -110,6 +112,7 @@ def correct_rank_deficiency(u_out): u_out, _ = lax.while_loop(cond_f, body_f, (u_out, do_correction)) return (u_out, s_out, v_out) + @functools.partial(jax.jit, static_argnums=(1, 2, 3, 4, 5)) def svd( 
a: Any, @@ -241,3 +244,52 @@ def svd( return (v_out, s_out, u_out.T.conj()) return (u_out, s_out, v_out.T.conj()) + + +def _svd_tpu(a, *, full_matrices, compute_uv, subset_by_index, algorithm=None): + if algorithm is not None and algorithm != lax_linalg.SvdAlgorithm.DEFAULT: + raise NotImplementedError( + "The SVD algorithm parameter is not implemented on TPU.") + + batch_dims = a.shape[:-2] + fn = functools.partial( + svd, + full_matrices=full_matrices, + compute_uv=compute_uv, + subset_by_index=subset_by_index, + ) + for _ in range(len(batch_dims)): + fn = jax.vmap(fn) + + if compute_uv: + u, s, vh = fn(a) + return [s, u, vh] + else: + s = fn(a) + return [s] + + +def _svd_tpu_lowering_rule( + ctx, operand, *, full_matrices, compute_uv, subset_by_index, algorithm=None +): + del algorithm # unused + operand_aval, = ctx.avals_in + m, n = operand_aval.shape[-2:] + + if m == 0 or n == 0: + return mlir.lower_fun(lax_linalg._empty_svd, multiple_results=True)( + ctx, + operand, + full_matrices=full_matrices, + compute_uv=compute_uv, + ) + + return mlir.lower_fun(_svd_tpu, multiple_results=True)( + ctx, + operand, + full_matrices=full_matrices, + compute_uv=compute_uv, + subset_by_index=subset_by_index, + ) + +mlir.register_lowering(lax_linalg.svd_p, _svd_tpu_lowering_rule) diff --git a/jax/_src/tpu_custom_call.py b/jax/_src/tpu_custom_call.py index 4089e047f8b0..d0070f5a73ae 100644 --- a/jax/_src/tpu_custom_call.py +++ b/jax/_src/tpu_custom_call.py @@ -24,20 +24,18 @@ import enum import functools import io -import os -import time from typing import Any import jax from jax._src import config from jax._src import core from jax._src import sharding_impls +from jax._src.cloud_tpu_init import is_cloud_tpu_older_than from jax._src.interpreters import mlir from jax._src.lib import tpu from jax._src.lib import xla_client from jax.interpreters import xla from jaxlib.mlir import ir -from jaxlib.mlir.dialects import stablehlo from jaxlib.mlir.passmanager import PassManager try: @@ -46,16 +44,6 @@ except ImportError: FLAGS = {} -_MOSAIC_USE_PYTHON_PIPELINE = config.bool_state( - name="mosaic_use_python_pipeline", - default=False, - help=( - "Run the initial Mosaic MLIR passes from Python, when as_tpu_kernel" - " is called (for Pallas, this happens at JAX lowering time), instead of" - " later within XLA." - ), -) - _MOSAIC_ALLOW_HLO = config.bool_state( name="jax_mosaic_allow_hlo", default=False, @@ -63,8 +51,22 @@ ) -# This tracks the latest Mosaic IR version with a monthly delay. -FWD_COMPAT_IR_VERSION = 3 +# Controls the IR serialization version. Upon incrementing the +# default version in jaxlib/mosaic/dialect/tpu/transforms/serde.cc we must +# continue to use the old serialization version when in forward compatibility +# mode: for 1 month when exporting, or when using old cloud TPU. +# +# This can be achieved by adding: +# if ctx.is_forward_compat() or is_cloud_tpu_older_than(): +# return +# return None +# +# We should also add a TODO to remove the conditional one month later. 
+def get_ir_version(ctx: mlir.LoweringRuleContext) -> int | None: + # TODO: b/423649694 - remove the forward compatibility check after 2025-07-18 + if ctx.is_forward_compat() or is_cloud_tpu_older_than(2025, 6, 19): + return 4 + return None tpu_custom_call_p = core.Primitive("tpu_custom_call") @@ -73,12 +75,6 @@ tpu_custom_call_p.multiple_results = True -def get_target_shape(hardware_generation: int) -> tuple[int, int]: - """Returns the target shape for the given hardware generation.""" - del hardware_generation - return (8, 128) - - class MemorySpace(enum.Enum): HBM = enum.auto() VMEM = enum.auto() @@ -124,10 +120,18 @@ class CustomCallBackendConfig: needs_layout_passes: bool vmem_limit_bytes: int | None flags: dict[str, bool | int | float] | None - allow_input_fusion: list[bool] | None + allow_input_fusion: Sequence[bool] | None serialization_format: int | None internal_scratch_in_bytes: int | None output_memory_spaces: tuple[MemorySpace | None, ...] | None + disable_bounds_checks: bool + active_core_count: int | None + input_memory_spaces: tuple[MemorySpace | None, ...] | None + + def __post_init__(self): + if self.allow_input_fusion is not None: + object.__setattr__(self, "allow_input_fusion", + tuple(self.allow_input_fusion)) # We omit the body while printing, because primitive params get embedded # in HLO metadata, and the body blows up its size. @@ -171,13 +175,56 @@ def to_json(self) -> bytes: config.write(b', "internal_scratch_in_bytes": ') config.write(str(self.internal_scratch_in_bytes).encode("ascii")) if self.output_memory_spaces is not None: - config.write(b', "output_memory_colors": [') - for i, memory_space in enumerate(self.output_memory_spaces): - if i: + if len(self.output_memory_spaces) == 1: + output_memory_space = self.output_memory_spaces[0] + if output_memory_space is not None: + config.write(b', "output_memory_space_colors": [') + config.write( + f'{{"color":{output_memory_space.color}}}'.encode("ascii") + ) + config.write(b"]") + else: + comma = False + for i, output_memory_space in enumerate(self.output_memory_spaces): + if output_memory_space is None: + continue + if comma: + config.write(b",") + else: + config.write(b', "output_memory_space_colors": [') + config.write( + f'{{"shape_index":[{i}],"color":{output_memory_space.color}}}' + .encode("ascii") + ) + comma = True + if comma: + config.write(b"]") + if self.input_memory_spaces is not None: + comma = False + for i, input_memory_space in enumerate(self.input_memory_spaces): + if input_memory_space is None: + continue + if input_memory_space not in ( + MemorySpace.HBM, + MemorySpace.VMEM, + ): + raise NotImplementedError( + "input_memory_space_colors only supports HBM and VMEM" + ) + if comma: config.write(b",") - color = memory_space.color if memory_space is not None else -1 - config.write(str(color).encode("ascii")) - config.write(b"]") + else: + config.write(b', "input_memory_space_colors": [') + config.write( + f'{{"operand_index":{i},"color":{input_memory_space.color}}}' + .encode("ascii") + ) + comma = True + if comma: + config.write(b"]") + if self.disable_bounds_checks: + config.write(b', "disable_bounds_checks": ') + config.write(str(self.disable_bounds_checks).lower().encode("ascii")) config.write(b"}") # End of custom_call_config. 
if self.device_type is not None: config.write(b', "device_type": ') @@ -212,6 +259,8 @@ def to_json(self) -> bytes: if i + 1 != len(self.flags): config.write(b",") config.write(b"]") + if self.device_type == "sparsecore" and self.active_core_count == 1: + config.write(b', "megachip_parallelism_config": {"cores": ["0"]}') config.write(b"}") return config.getvalue() @@ -222,7 +271,7 @@ def _tpu_custom_call_abstract_eval(*_, out_avals, **__): def _avals_to_layouts(avals) -> Sequence[Sequence[int]]: - return [tuple(range(a.ndim - 1, -1, -1)) for a in avals] + return [tuple(range(a.ndim - 1, -1, -1)) for a in avals] # pytype: disable=attribute-error def _tpu_custom_call_lowering( @@ -233,7 +282,7 @@ def _tpu_custom_call_lowering( kernel_name: str | None, out_avals: Any, input_output_aliases: tuple[tuple[int, int], ...], -) -> ...: +) -> ir.OpResultList: result_types = [mlir.aval_to_ir_type(aval) for aval in out_avals] axis_context = ctx.module_context.axis_context if isinstance(axis_context, sharding_impls.SPMDAxisContext): @@ -286,166 +335,6 @@ def _tpu_custom_call_lowering( platform="tpu") -def _lower_tpu_kernel( - module: ir.Module, - hardware_generation: int, - target_shape: tuple[int, int], - kernel_name: str | None = None, -) -> ir.Module: - """Runs MLIR passes lowering the given module to an MLIR module. - - Uses Python versions of canonicalize-mosaic,infer-memref-layout and - apply-vector-layout. - - Args: - module: The MLIR module to lower. - hardware_generation: The TPU hardware generation to target. - target_shape: The target shape of (sublane_count, lane_count). - - Returns: - An MLIR module implementing the kernel. - """ - try: - module.operation.verify() - except ir.MLIRError as e: - raise ValueError("The compiled module fails MLIR verification") from e - - timestamp = time.time_ns() - dump_cnt = [0] - - def get_dump_file_prefix() -> str: - s = f"{timestamp}-{dump_cnt[0]:04}" - dump_cnt[0] += 1 - return s - - with module.context as ctx, module.operation.location as _: - ctx.append_dialect_registry(mlir.upstream_dialects) - ctx.load_all_available_dialects() - tpu.register_dialect(ctx) - stablehlo.register_dialect(ctx) - dump_mlir(module, "original", get_dump_file_prefix(), kernel_name) - - if _MOSAIC_ALLOW_HLO.value: - # Run dialect conversion: StableHLO -> linalg -> vector. - pipeline = [ - "func.func(stablehlo-legalize-to-linalg)", - "func.func(linalg-vectorization)", - ] - pipeline = PassManager.parse(f"builtin.module({','.join(pipeline)})") - pipeline.run(module.operation) - dump_mlir(module, "post-hlo-conversion", get_dump_file_prefix(), kernel_name) - - sl_cnt, l_cnt = target_shape - # Note: we don't pass the TpuTilingFlags here, since we don't know the - # tiling decisions made by the compiler / what flags are enabled at this - # point, so we assume everything can be tiled up to default tiling. 
- pipeline = [ - "func.func(tpu-infer-memref-layout{" - f" hardware-generation={hardware_generation}" - f" sublane-count={sl_cnt}" - f" lane-count={l_cnt}" - "})" - ] - pipeline = PassManager.parse(f"builtin.module({','.join(pipeline)})") - pipeline.run(module.operation) - dump_mlir(module, "post-infer-memref-layout", get_dump_file_prefix(), kernel_name) - - pipeline = [ - "canonicalize", - "cse", - ] - pipeline = PassManager.parse(f"builtin.module({','.join(pipeline)})") - pipeline.run(module.operation) - dump_mlir( - module, - "post-infer-memref-layout-simplify", - get_dump_file_prefix(), - kernel_name, - ) - - try: - on_device_checks = FLAGS["xla_mosaic_on_device_checks"].value - except KeyError: - on_device_checks = False - - if checks := on_device_checks: - checks = set(checks.split(",")) - if checks == {"bounds"}: # We only support one kind of checks now. - pipeline = PassManager.parse( - "builtin.module(func.func(debug-assert-insertion))" - ) - pipeline.run(module.operation) - dump_mlir(module, "post-assert-insertion", get_dump_file_prefix(), kernel_name) - elif checks: - checks.discard("bounds") - raise ValueError( - f"Unrecognized on-device check categories: {', '.join(checks)}" - ) - - # Legacy pipeline always runs in compatibility mode. - compatibility_mode = True - pipeline = [ - ( - f"func.func(tpu-canonicalize-mosaic{{hardware-generation={hardware_generation} compatibility-mode={compatibility_mode}}})" - ), - ] - pipeline = PassManager.parse(f"builtin.module({','.join(pipeline)})") - pipeline.run(module.operation) - dump_mlir(module, "post-canonicalize-mosaic", get_dump_file_prefix(), kernel_name) - - pipeline = [ - ( - "func.func(tpu-infer-vector-layout{" - f" hardware-generation={hardware_generation}" - f" sublane-count={sl_cnt} lane-count={l_cnt}" - "})" - ), - ] - pipeline = PassManager.parse(f"builtin.module({','.join(pipeline)})") - pipeline.run(module.operation) - dump_mlir(module, "post-infer-vector-layout", get_dump_file_prefix(), kernel_name) - - pipeline = [ - ( - "func.func(tpu-relayout-insertion{" - f" sublane-count={sl_cnt} lane-count={l_cnt}" - f" hardware-generation={hardware_generation}" - "})" - ), - ] - pipeline = PassManager.parse(f"builtin.module({','.join(pipeline)})") - pipeline.run(module.operation) - dump_mlir(module, "post-relayout-insertion", get_dump_file_prefix(), kernel_name) - - mxu_size = 128 if hardware_generation < 6 else 256 - pipeline = [ - "func.func(tpu-apply-vector-layout{" - f" sublane-count={sl_cnt} lane-count={l_cnt}" - f" hardware-generation={hardware_generation}" - f" mxu-contracting-size={mxu_size} mxu-noncontracting-size={mxu_size}" - f" max-sublanes-in-scratch={sl_cnt * (sl_cnt + 1)}" - "})" - ] - pipeline = PassManager.parse(f"builtin.module({','.join(pipeline)})") - pipeline.run(module.operation) - dump_mlir(module, "post-apply-vector-layout", get_dump_file_prefix(), kernel_name) - - pipeline = [ - "canonicalize", - "cse", - ] - pipeline = PassManager.parse(f"builtin.module({','.join(pipeline)})") - pipeline.run(module.operation) - dump_mlir( - module, - "post-apply-vector-layout-simplify", - get_dump_file_prefix(), - kernel_name, - ) - - return module - - def _lower_mosaic_module_to_asm( module: ir.Module, *, @@ -461,36 +350,12 @@ def _lower_mosaic_module_to_asm( needs_layout_passes = not device_type # We'll mutate the module, so clone it with module.context as ctx, module.operation.location as _: - if needs_layout_passes and _MOSAIC_USE_PYTHON_PIPELINE.value: - module = ir.Module.parse( - module.operation.get_asm(binary=True, 
enable_debug_info=True) - ) - module_op = module.operation - some_tpu = jax.devices(backend)[0] - device_kind = some_tpu.device_kind - if not device_kind.startswith("TPU v"): - raise ValueError( - f"Unrecognized TPU device kind: {device_kind}. " - "tpu_custom_call cannot be lowered on a machine without TPUs " - "when mosaic_use_python_pipeline=True.") - hardware_generation = int(device_kind[len("TPU v")]) - target_shape = get_target_shape(hardware_generation) - module = _lower_tpu_kernel( - module, hardware_generation, target_shape=target_shape, kernel_name=kernel_name, - ) - needs_hlo_passes = False - needs_layout_passes = False - else: - module_op = module.operation.clone() + module_op = module.operation.clone() prev_allow_unregistered_dialects = ctx.allow_unregistered_dialects ctx.allow_unregistered_dialects = True - # TODO(apaszke): Remove once the minimum jaxlib version is at least 0.4.37. - if jax.version._version_as_tuple(jax.lib.__version__) < (0, 4, 37): - target_version = "" - else: - target_version = ( - f"target-version={ir_version}" if ir_version is not None else "" - ) + target_version = ( + f"target-version={ir_version}" if ir_version is not None else "" + ) try: pipeline = PassManager.parse( "builtin.module(mosaic-serde{serialize=true " + target_version + "})" ) @@ -539,30 +404,107 @@ def assign_device_type_based_on_core_type(op: ir.Operation) -> ir.WalkResult: ) if tensorcore_func_found and sparsecore_func_found: raise ValueError( - "A single Mosaic kernel cannot contain both " - "TensorCore and SparseCore functions." + "A single Mosaic kernel cannot contain both TensorCore and SparseCore" + " functions." ) if sparsecore_func_found: return "sparsecore" return None +def _get_active_core_count(module: ir.Module) -> int | None: + + def get_core_parallel_dim_size( + dim_semantics: ir.ArrayAttr, + iter_bounds: ir.DenseI64ArrayAttr, + other_subkernel_core_dim_size: int | None = None) -> int | None: + + if len(iter_bounds) != len(dim_semantics): + raise ValueError( + "The iteration bounds and dimension semantics attributes must have" + " the same number of elements." + ) + + subkernel_core_dim_size = None + + for dim_idx, (dim_size, dim_sem) in enumerate( + zip(iter_bounds, dim_semantics) + ): + if str(dim_sem) != "#tpu.dimension_semantics<core_parallel>": + continue + + if ir.ShapedType.is_dynamic_size(dim_size): + raise ValueError( + "The iteration bound corresponding to the core-parallel dimension " + f"{dim_idx} must be statically known." + ) + if subkernel_core_dim_size is not None: + raise ValueError( + "A single Mosaic subkernel cannot contain multiple core sharding " + "dimensions." + ) + if ( + other_subkernel_core_dim_size is not None + and other_subkernel_core_dim_size != dim_size + ): + raise ValueError( + "The iteration bound corresponding to the core-parallel dimension " + "must be the same across all subkernels."
+ ) + subkernel_core_dim_size = dim_size + + return subkernel_core_dim_size + + core_parallel_dim_size = None + + for op in module.body.operations: + if op.operation.name != "func.func": + continue + + if ( + "iteration_bounds" not in op.attributes + or "dimension_semantics" not in op.attributes + ): + continue + + try: + iter_bounds = ir.DenseI64ArrayAttr(op.attributes["iteration_bounds"]) + except ValueError as e: + e.add_note("The iteration bounds attribute must be an array.") + raise + try: + dim_semantics = ir.ArrayAttr(op.attributes["dimension_semantics"]) + except ValueError as e: + e.add_note("The dimension semantics attribute must be an array.") + raise + + core_parallel_dim_size = get_core_parallel_dim_size( + dim_semantics=dim_semantics, + iter_bounds=iter_bounds, + other_subkernel_core_dim_size=core_parallel_dim_size, + ) + + return core_parallel_dim_size + + def _lower_to_custom_call_config( module: ir.Module, *, backend: str, - device_type: str | None, vmem_limit_bytes: int | None, cost_estimate: CostEstimate | None, flags: dict[str, bool | int | float] | None, - allow_input_fusion: list[bool] | None, + allow_input_fusion: Sequence[bool] | None, internal_scratch_in_bytes: int | None, collective_id: int | None, serialization_format: int | None, output_memory_spaces: tuple[MemorySpace | None, ...] | None = None, kernel_name: str | None = None, ir_version: int | None = None, + disable_bounds_checks: bool = False, + input_memory_spaces: tuple[MemorySpace | None, ...] | None = None, ) -> CustomCallBackendConfig: + device_type = _get_device_type(module) lowered_module_asm, ( has_communication, has_custom_barrier, @@ -575,6 +517,7 @@ def _lower_to_custom_call_config( kernel_name=kernel_name, ir_version=ir_version, ) + active_core_count = _get_active_core_count(module) return _lowered_to_custom_call_config( lowered_module_asm, vmem_limit_bytes=vmem_limit_bytes, @@ -590,6 +533,9 @@ def _lower_to_custom_call_config( needs_hlo_passes=needs_hlo_passes, needs_layout_passes=needs_layout_passes, output_memory_spaces=output_memory_spaces, + disable_bounds_checks=disable_bounds_checks, + active_core_count=active_core_count, + input_memory_spaces=input_memory_spaces, ) @@ -599,7 +545,7 @@ def _lowered_to_custom_call_config( vmem_limit_bytes: int | None, cost_estimate: CostEstimate | None, flags: dict[str, bool | int | float] | None, - allow_input_fusion: list[bool] | None, + allow_input_fusion: Sequence[bool] | None, internal_scratch_in_bytes: int | None, collective_id: int | None, serialization_format: int | None, @@ -609,6 +555,9 @@ def _lowered_to_custom_call_config( needs_layout_passes: bool, device_type: str | None, output_memory_spaces: tuple[MemorySpace | None, ...] | None = None, + disable_bounds_checks: bool = False, + active_core_count: int | None = None, + input_memory_spaces: tuple[MemorySpace | None, ...] 
| None = None, ): if has_custom_barrier: if collective_id is None: @@ -639,6 +588,9 @@ def _lowered_to_custom_call_config( serialization_format, internal_scratch_in_bytes, output_memory_spaces, + disable_bounds_checks, + active_core_count=active_core_count, + input_memory_spaces=input_memory_spaces, ) return config @@ -653,14 +605,15 @@ def lower_module_to_custom_call( cost_estimate: CostEstimate | None, vmem_limit_bytes: int | None, flags: dict[str, bool | int | float] | None, - allow_input_fusion: list[bool] | None, + allow_input_fusion: Sequence[bool] | None, input_output_aliases: tuple[tuple[int, int], ...], internal_scratch_in_bytes: int | None, collective_id: int | None, has_side_effects: bool, serialization_format: int | None, output_memory_spaces: tuple[MemorySpace | None, ...] | None, - device_type: str | None, + disable_bounds_checks: bool = False, + input_memory_spaces: tuple[MemorySpace | None, ...] | None, ) -> Sequence[ir.Value]: config = _lower_to_custom_call_config( module, @@ -671,11 +624,12 @@ def lower_module_to_custom_call( allow_input_fusion=allow_input_fusion, internal_scratch_in_bytes=internal_scratch_in_bytes, collective_id=collective_id, - device_type=device_type, serialization_format=serialization_format, output_memory_spaces=output_memory_spaces, kernel_name=kernel_name, - ir_version=FWD_COMPAT_IR_VERSION if ctx.is_forward_compat() else None, + ir_version=get_ir_version(ctx), + disable_bounds_checks=disable_bounds_checks, + input_memory_spaces=input_memory_spaces, ) return _tpu_custom_call_lowering( ctx, @@ -697,20 +651,20 @@ def as_tpu_kernel( kernel_name: str | None = None, vmem_limit_bytes: int | None = None, flags: dict[str, bool | int | float] | None = None, - allow_input_fusion: list[bool] | None = None, + allow_input_fusion: Sequence[bool] | None = None, input_output_aliases: tuple[tuple[int, int], ...] = (), internal_scratch_in_bytes: int | None = None, collective_id: int | None = None, has_side_effects: bool = False, serialization_format: int | None = 1, output_memory_spaces: tuple[MemorySpace | None, ...] | None = None, + disable_bounds_checks: bool = False, + input_memory_spaces: tuple[MemorySpace | None, ...] | None = None, ) -> Callable[..., Any]: """Turns an MLIR Mosaic kernel into a JAX-compatible function.""" - device_type = _get_device_type(module) config = _lower_to_custom_call_config( module, backend=backend, - device_type=device_type, vmem_limit_bytes=vmem_limit_bytes, cost_estimate=cost_estimate, flags=flags, @@ -720,6 +674,8 @@ def as_tpu_kernel( serialization_format=serialization_format, output_memory_spaces=output_memory_spaces, kernel_name=kernel_name, + disable_bounds_checks=disable_bounds_checks, + input_memory_spaces=input_memory_spaces, ) return _as_jax_callable( config, @@ -738,18 +694,19 @@ def lowered_as_tpu_kernel( cost_estimate: CostEstimate | None = None, needs_hlo_passes: bool = False, needs_layout_passes: bool = False, - device_type: str | None = None, has_communication: bool = False, has_side_effects: bool = False, has_custom_barrier: bool = False, kernel_name: str | None = None, vmem_limit_bytes: int | None = None, flags: dict[str, bool | int | float] | None = None, - allow_input_fusion: list[bool] | None = None, + allow_input_fusion: Sequence[bool] | None = None, input_output_aliases: tuple[tuple[int, int], ...] 
= (), serialization_format: int | None = None, internal_scratch_in_bytes: int | None = None, + disable_bounds_checks: bool = False, ) -> Callable[..., Any]: + device_type = _get_device_type(lowered_module) lowered_module_asm = lowered_module.operation.get_asm( binary=True, enable_debug_info=True ) @@ -767,6 +724,7 @@ def lowered_as_tpu_kernel( has_communication=has_communication, needs_hlo_passes=needs_hlo_passes, needs_layout_passes=needs_layout_passes, + disable_bounds_checks=disable_bounds_checks, ) return _as_jax_callable( config, @@ -804,21 +762,3 @@ def apply_kernel(*args): return result[0] if unpack else result return jax.jit(apply_kernel) - - -def dump_mlir( - module: ir.Module, name: str, prefix: str, kernel_name: str | None = None -): - """A helper function to dump mosaic mlir module""" - try: - should_dump = FLAGS["xla_mosaic_dump_to"].value - except KeyError: - return - if should_dump == "sponge": - outdir = os.environ.get("TEST_UNDECLARED_OUTPUTS_DIR", None) - if outdir: - if kernel_name: - name = f"{kernel_name}-{name}" - path = os.path.join(outdir, f"{prefix}-mosaic-dump-{name}-py.txt") - with open(path, "w") as f: - f.write(str(module)) diff --git a/jax/_src/traceback_util.py b/jax/_src/traceback_util.py index d66cbb912a99..f1cdf86d9929 100644 --- a/jax/_src/traceback_util.py +++ b/jax/_src/traceback_util.py @@ -17,14 +17,12 @@ from collections.abc import Callable import functools import os -import sys import traceback import types from typing import Any, TypeVar, cast from jax._src import config from jax._src import util -from jax._src.lib import xla_extension C = TypeVar("C", bound=Callable[..., Any]) @@ -56,8 +54,10 @@ def _path_starts_with(path: str, path_prefix: str) -> bool: return False def include_frame(f: types.FrameType) -> bool: - return not any(_path_starts_with(f.f_code.co_filename, path) - for path in _exclude_paths) + return include_filename(f.f_code.co_filename) + +def include_filename(filename: str) -> bool: + return not any(_path_starts_with(filename, path) for path in _exclude_paths) # When scanning stack traces, we might encounter frames from cpython that are # removed from printed stack traces, such as frames from parts of importlib. We @@ -191,25 +191,10 @@ def reraise_with_filtered_traceback(*args, **kwargs): tb = e.__traceback__ filtered_tb = filter_traceback(tb) e.with_traceback(filtered_tb) - # In Python < 3.11, there seems to be no way to alter the currently - # raised exception traceback, except via the C API. The interpreter - # keeps a copy of the traceback (exc_traceback) that is separate to the - # __traceback__ of exc_value. Python 3.11 removes exc_traceback and - # just setting __traceback__ is enough. Since it is no longer needed, - # the XLA extension no longer defines a traceback-replacing method at - # Python 3.11 and onward. - if hasattr(xla_extension, "replace_thread_exc_traceback"): - # TODO(kidger): remove this line once Python 3.11 is the minimum supported - # version. - xla_extension.replace_thread_exc_traceback(filtered_tb) - if sys.version_info >= (3, 11) and mode == "quiet_remove_frames": + if mode == "quiet_remove_frames": e.add_note("--------------------\n" + _simplified_tb_msg) else: - if mode == "quiet_remove_frames": - # TODO(kidger): remove `SimplifiedTraceback` once Python 3.11 is the minimum - # supported version. 
- jax_error = SimplifiedTraceback() - elif mode == "remove_frames": + if mode == "remove_frames": msg = format_exception_only(e) msg = f'{msg}\n\n{_jax_message_append}' jax_error = UnfilteredStackTrace(msg) diff --git a/jax/_src/tree.py b/jax/_src/tree.py index 70d75a126804..d1d3be41b917 100644 --- a/jax/_src/tree.py +++ b/jax/_src/tree.py @@ -287,7 +287,8 @@ def unflatten(treedef: tree_util.PyTreeDef, def flatten_with_path( - tree: Any, is_leaf: Callable[[Any], bool] | None = None + tree: Any, is_leaf: Callable[..., bool] | None = None, + is_leaf_takes_path: bool = False, ) -> tuple[list[tuple[tree_util.KeyPath, Any]], tree_util.PyTreeDef]: """Flattens a pytree like ``tree_flatten``, but also returns each leaf's key path. @@ -313,11 +314,12 @@ def flatten_with_path( - :func:`jax.tree.map_with_path` - :func:`jax.tree_util.register_pytree_with_keys` """ - return tree_util.tree_flatten_with_path(tree, is_leaf) + return tree_util.tree_flatten_with_path(tree, is_leaf, is_leaf_takes_path) def leaves_with_path( - tree: Any, is_leaf: Callable[[Any], bool] | None = None + tree: Any, is_leaf: Callable[..., bool] | None = None, + is_leaf_takes_path: bool = False, ) -> list[tuple[tree_util.KeyPath, Any]]: """Gets the leaves of a pytree like ``tree_leaves`` and returns each leaf's key path. @@ -338,14 +340,15 @@ def leaves_with_path( - :func:`jax.tree.flatten_with_path` - :func:`jax.tree_util.register_pytree_with_keys` """ - return tree_util.tree_leaves_with_path(tree, is_leaf) + return tree_util.tree_leaves_with_path(tree, is_leaf, is_leaf_takes_path) def map_with_path( f: Callable[..., Any], tree: Any, *rest: Any, - is_leaf: Callable[[Any], bool] | None = None, + is_leaf: Callable[..., bool] | None = None, + is_leaf_takes_path: bool = False, ) -> Any: """Maps a multi-input function over pytree key path and args to produce a new pytree. @@ -377,4 +380,37 @@ def map_with_path( - :func:`jax.tree.leaves_with_path` - :func:`jax.tree_util.register_pytree_with_keys` """ - return tree_util.tree_map_with_path(f, tree, *rest, is_leaf=is_leaf) + return tree_util.tree_map_with_path( + f, tree, *rest, is_leaf=is_leaf, is_leaf_takes_path=is_leaf_takes_path + ) + + +def broadcast(prefix_tree: Any, full_tree: Any, + is_leaf: Callable[[Any], bool] | None = None + ) -> Any: + """Broadcasts a tree prefix into the full structure of a given tree. + + Args: + prefix_tree: a pytree that is a tree prefix of full_tree. + full_tree: a pytree with the structure to broadcast the prefix leaves into. + is_leaf: an optionally specified function that will be called at each + flattening step. It should return a boolean, with true stopping the + traversal and the whole subtree being treated as a leaf, and false + indicating the flattening should traverse the current object. + + Returns: + A pytree matching the structure of full_tree where the leaves of prefix_tree have been + broadcasted into the leaves of each corresponding subtree. 
+ + Examples: + >>> import jax + >>> prefix = (1, 2, 3) + >>> full = (0, {'a': 0, 'b': 0}, (0, 0)) + >>> jax.tree.broadcast(prefix, full) + (1, {'a': 2, 'b': 2}, (3, 3)) + + See Also: + - :func:`jax.tree.leaves` + - :func:`jax.tree.structure` + """ + return tree_util.tree_broadcast(prefix_tree, full_tree, is_leaf=is_leaf) diff --git a/jax/_src/tree_util.py b/jax/_src/tree_util.py index 6c7e15a042e5..c57ab2109c56 100644 --- a/jax/_src/tree_util.py +++ b/jax/_src/tree_util.py @@ -21,7 +21,7 @@ from functools import partial import operator as op import textwrap -from typing import Any, NamedTuple, TypeVar, overload +from typing import Any, TypeVar, overload from jax._src import traceback_util from jax._src.lib import pytree @@ -202,8 +202,11 @@ def all_leaves(iterable: Iterable[Any], if is_leaf is None: return pytree.all_leaves(default_registry, iterable) else: - lst = list(iterable) - return lst == tree_leaves(lst, is_leaf) + items = list(iterable) + leaves = tree_leaves(items, is_leaf) + return len(leaves) == len(items) and all( + item is leaf for item, leaf in zip(items, leaves, strict=True) + ) _Children = TypeVar("_Children", bound=Iterable[Any]) @@ -362,6 +365,8 @@ def tree_map(f: Callable[..., Any], def build_tree(treedef: PyTreeDef, xs: Any) -> Any: """Build a treedef from a nested iterable structure + DEPRECATED: Use :func:`jax.tree.unflatten` instead. + Args: treedef: the PyTreeDef structure to build. xs: nested iterables matching the arity as the treedef @@ -376,13 +381,6 @@ def build_tree(treedef: PyTreeDef, xs: Any) -> Any: >>> import jax >>> tree = [(1, 2), {'a': 3, 'b': 4}] >>> treedef = jax.tree.structure(tree) - - Both ``build_tree`` and :func:`jax.tree_util.tree_unflatten` can reconstruct - the tree from new values, but ``build_tree`` takes these values in terms of - a nested rather than flat structure: - - >>> jax.tree_util.build_tree(treedef, [[10, 11], [12, 13]]) - [(10, 11), {'a': 12, 'b': 13}] >>> jax.tree_util.tree_unflatten(treedef, [10, 11, 12, 13]) [(10, 11), {'a': 12, 'b': 13}] """ @@ -534,7 +532,7 @@ class Partial(functools.partial): >>> print_zero() 0 >>> call_func(print_zero) # doctest:+ELLIPSIS - Tracedwith + Traced<~int32[]>with """ def __new__(klass, func, *args, **kw): @@ -562,17 +560,42 @@ def __new__(klass, func, *args, **kw): ) -# broadcast_prefix is not exported. +@export +def tree_broadcast(prefix_tree: Any, full_tree: Any, + is_leaf: Callable[[Any], bool] | None = None + ) -> Any: + """Alias of :func:`jax.tree.broadcast`.""" + broadcast_leaves = broadcast_prefix(prefix_tree, full_tree, is_leaf=is_leaf) + return tree_structure(full_tree).unflatten(broadcast_leaves) + + +# broadcast_prefix is not exported def broadcast_prefix(prefix_tree: Any, full_tree: Any, is_leaf: Callable[[Any], bool] | None = None ) -> list[Any]: - # If prefix_tree is not a tree prefix of full_tree, this code can raise a - # ValueError; use prefix_errors to find disagreements and raise more precise - # error messages. + """Broadcasts tree prefix leaves into the full set of leaves for a given full tree. + + Args: + prefix_tree: a pytree that is a tree prefix of full_tree. + full_tree: a pytree with the structure to broadcast the prefix leaves into. + is_leaf: an optionally specified function that will be called at each + flattening step. It should return a boolean, with true stopping the + traversal and the whole subtree being treated as a leaf, and false + indicating the flattening should traverse the current object. 
+ + Returns: + A list of leaves matching the expected count for the full tree, + with the leaf of each prefix tree being duplicated to match the count of + its corresponding subtree. + """ result = [] num_leaves = lambda t: tree_structure(t).num_leaves add_leaves = lambda x, subtree: result.extend([x] * num_leaves(subtree)) - tree_map(add_leaves, prefix_tree, full_tree, is_leaf=is_leaf) + try: + tree_map(add_leaves, prefix_tree, full_tree, is_leaf=is_leaf) + except ValueError: + e, *_ = prefix_errors(prefix_tree, full_tree) + raise e('broadcast_prefix prefix_tree') from None return result @@ -767,42 +790,6 @@ def _simple_entrystr(key: KeyEntry) -> str: return str(key) -# TODO(ivyzheng): remove this after another jaxlib release. -class _RegistryWithKeypathsEntry(NamedTuple): - flatten_with_keys: Callable[..., Any] - unflatten_func: Callable[..., Any] - - -def _register_keypaths( - ty: type[T], handler: Callable[[T], tuple[KeyEntry, ...]] -) -> None: - def flatten_with_keys(xs): - children, treedef = _registry[ty].to_iter(xs) - return list(zip(handler(xs), children)), treedef - if ty in _registry: - _registry_with_keypaths[ty] = _RegistryWithKeypathsEntry( - flatten_with_keys, _registry[ty].from_iter - ) - -_registry_with_keypaths: dict[type[Any], _RegistryWithKeypathsEntry] = {} - -_register_keypaths( - tuple, lambda xs: tuple(SequenceKey(i) for i in range(len(xs))) -) -_register_keypaths( - list, lambda xs: tuple(SequenceKey(i) for i in range(len(xs))) -) -_register_keypaths(dict, lambda xs: tuple(DictKey(k) for k in sorted(xs))) - -_register_keypaths( - collections.defaultdict, lambda x: tuple(DictKey(k) for k in x.keys()) -) - -_register_keypaths( - collections.OrderedDict, lambda x: tuple(DictKey(k) for k in x.keys()) -) - - @export def register_pytree_with_keys( nodetype: type[T], @@ -872,9 +859,6 @@ def flatten_func_impl(tree): register_pytree_node( nodetype, flatten_func, unflatten_func, flatten_with_keys ) - _registry_with_keypaths[nodetype] = _RegistryWithKeypathsEntry( - flatten_with_keys, unflatten_func - ) @export @@ -1067,11 +1051,6 @@ def register_dataclass( msg += f" Unexpected fields: {unexpected}." 
raise ValueError(msg) - def flatten_with_keys(x): - meta = tuple(getattr(x, name) for name in meta_fields) - data = tuple((GetAttrKey(name), getattr(x, name)) for name in data_fields) - return data, meta - def unflatten_func(meta, data): meta_args = tuple(zip(meta_fields, meta)) data_args = tuple(zip(data_fields, data)) @@ -1087,9 +1066,6 @@ def flatten_func(x): none_leaf_registry.register_dataclass_node(nodetype, list(data_fields), list(meta_fields)) dispatch_registry.register_dataclass_node(nodetype, list(data_fields), list(meta_fields)) _registry[nodetype] = _RegistryEntry(flatten_func, unflatten_func) - _registry_with_keypaths[nodetype] = _RegistryWithKeypathsEntry( - flatten_with_keys, unflatten_func - ) return nodetype @@ -1150,34 +1126,38 @@ def register_static(cls: type[H]) -> type[H]: @export def tree_flatten_with_path( - tree: Any, is_leaf: Callable[[Any], bool] | None = None + tree: Any, is_leaf: Callable[..., bool] | None = None, + is_leaf_takes_path: bool = False, ) -> tuple[list[tuple[KeyPath, Any]], PyTreeDef]: """Alias of :func:`jax.tree.flatten_with_path`.""" - return default_registry.flatten_with_path(tree, is_leaf) + is_leaf_with_kp: Callable[[Any, Any], bool] | None = is_leaf + if not is_leaf_takes_path and is_leaf is not None: + is_leaf_with_kp = lambda _, x: is_leaf(x) + return default_registry.flatten_with_path(tree, is_leaf_with_kp) @export def tree_leaves_with_path( - tree: Any, is_leaf: Callable[[Any], bool] | None = None + tree: Any, is_leaf: Callable[..., bool] | None = None, + is_leaf_takes_path: bool = False, ) -> list[tuple[KeyPath, Any]]: """Alias of :func:`jax.tree.leaves_with_path`.""" - return tree_flatten_with_path(tree, is_leaf)[0] - - -# generate_key_paths is not exported. -def generate_key_paths( - tree: Any, is_leaf: Callable[[Any], bool] | None = None -) -> list[tuple[KeyPath, Any]]: - return tree_leaves_with_path(tree, is_leaf) -_generate_key_paths = generate_key_paths # alias for backward compat + return tree_flatten_with_path(tree, is_leaf, is_leaf_takes_path)[0] +generate_key_paths = tree_leaves_with_path @export -def tree_map_with_path(f: Callable[..., Any], - tree: Any, *rest: Any, - is_leaf: Callable[[Any], bool] | None = None) -> Any: +def tree_map_with_path( + f: Callable[..., Any], + tree: Any, + *rest: Any, + is_leaf: Callable[..., bool] | None = None, + is_leaf_takes_path: bool = False, +) -> Any: """Alias of :func:`jax.tree.map_with_path`.""" - keypath_leaves, treedef = tree_flatten_with_path(tree, is_leaf) + keypath_leaves, treedef = tree_flatten_with_path( + tree, is_leaf, is_leaf_takes_path + ) keypath_leaves = list(zip(*keypath_leaves)) all_keypath_leaves = keypath_leaves + [treedef.flatten_up_to(r) for r in rest] return treedef.unflatten(f(*xs) for xs in zip(*all_keypath_leaves)) diff --git a/jax/_src/typing.py b/jax/_src/typing.py index 010841b45dd2..ee2422dd2d73 100644 --- a/jax/_src/typing.py +++ b/jax/_src/typing.py @@ -47,7 +47,19 @@ @typing.runtime_checkable class SupportsDType(Protocol): @property - def dtype(self) -> DType: ... + def dtype(self, /) -> DType: ... + +class SupportsShape(Protocol): + @property + def shape(self, /) -> tuple[int, ...]: ... + +class SupportsSize(Protocol): + @property + def size(self, /) -> int: ... + +class SupportsNdim(Protocol): + @property + def ndim(self, /) -> int: ... # DTypeLike is meant to annotate inputs to np.dtype that return # a valid JAX dtype. 
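The `is_leaf_takes_path` flag added to the `*_with_path` helpers above changes the contract of `is_leaf`: when it is true, the predicate receives the key path as well as the value (the `lambda _, x: is_leaf(x)` shim above is the backwards-compatible case). A rough usage sketch, not part of this patch; the printed key paths are indicative only:

import jax
from jax.tree_util import DictKey, keystr

tree = {"params": {"w": 1.0, "b": 2.0}, "step": 3}

# Treat everything under the top-level "params" key as a single leaf.
def stop_at_params(path, value):
  return bool(path) and path[0] == DictKey("params")

flat, treedef = jax.tree.flatten_with_path(
    tree, is_leaf=stop_at_params, is_leaf_takes_path=True
)
for key_path, leaf in flat:
  print(keystr(key_path), leaf)
# Expected, roughly:
#   ['params'] {'w': 1.0, 'b': 2.0}
#   ['step'] 3
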
It's different than numpy.typing.DTypeLike diff --git a/jax/_src/util.py b/jax/_src/util.py index 0e28aea04b5a..1b9102af2d95 100644 --- a/jax/_src/util.py +++ b/jax/_src/util.py @@ -20,20 +20,24 @@ from functools import partial import itertools as it import logging +import math import operator -from typing import (Any, Generic, TypeVar, overload, TYPE_CHECKING, cast) +from typing import (Any, Generic, SupportsIndex, TypeVar, overload, TYPE_CHECKING, cast) import weakref import numpy as np from jax._src import config -from jax._src.lib import xla_client as xc +from jax._src.lib import weakref_lru_cache as _weakref_lru_cache from jax._src.lib import utils as jaxlib_utils logger = logging.getLogger(__name__) Seq = Sequence +# TODO(jakevdp): fix import cycles and import Array. +Array = Any + T = TypeVar("T") T1 = TypeVar("T1") T2 = TypeVar("T2") @@ -54,12 +58,19 @@ def safe_zip(__arg1: Iterable[T1], __arg2: Iterable[T2], __arg3: Iterable[T3]) - def safe_zip(__arg1: Iterable[Any], __arg2: Iterable[Any], __arg3: Iterable[Any], __arg4: Iterable[Any], *args) -> list[tuple[Any, ...]]: ... def safe_zip(*args): - args = list(map(list, args)) - n = len(args[0]) - for arg in args[1:]: - assert len(arg) == n, f'length mismatch: {list(map(len, args))}' - return list(zip(*args)) - + """ + Like builtin :func:`zip`, but with additional safety checks. + + The differences from :func:`zip` are: + + - :func:`safe_zip` checks that at least one argument is provided. + - :func:`safe_zip` checks that all arguments have the same length. + - :func:`safe_zip` returns an eagerly-evaluated list instead of a + lazily-evaluated iterator. + """ + if not args: + raise TypeError("safe_zip requires at least 1 argument.") + return list(zip(*args, strict=True)) else: safe_zip = jaxlib_utils.safe_zip @@ -108,11 +119,7 @@ def foreach(f, *args): return None else: - # TODO(phawkins): remove after jaxlib 0.5.2 is the minimum. 
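The rewritten `safe_zip` above now defers to `zip(..., strict=True)`; a standalone sketch of that branch (runnable without jaxlib, Python 3.10+), restating the documented behaviour:

def safe_zip_sketch(*args):
  # Mirrors the pure-Python branch above: at least one iterable, equal
  # lengths enforced, result materialized as an eager list.
  if not args:
    raise TypeError("safe_zip requires at least 1 argument.")
  return list(zip(*args, strict=True))

assert safe_zip_sketch([1, 2], ["a", "b"]) == [(1, "a"), (2, "b")]

try:
  safe_zip_sketch([1, 2, 3], ["a", "b"])  # mismatched lengths
except ValueError as err:
  print("rejected:", err)
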
- if hasattr(jaxlib_utils, 'foreach'): - foreach = jaxlib_utils.foreach - else: - foreach = safe_map + foreach = jaxlib_utils.foreach def unzip2(xys: Iterable[tuple[T1, T2]] @@ -141,13 +148,15 @@ def unzip3(xyzs: Iterable[tuple[T1, T2, T3]] zs.append(z) return tuple(xs), tuple(ys), tuple(zs) -def subvals(lst, replace): +def subvals(lst: Sequence[T], replace: Iterable[tuple[int, T]]) -> tuple[T, ...]: + """Substitute values within a list.""" lst = list(lst) for i, v in replace: lst[i] = v return tuple(lst) def split_list(args: Sequence[T], ns: Sequence[int]) -> list[list[T]]: + """Split list into sublists of the specified sizes.""" args = list(args) lists = [] for n in ns: @@ -157,8 +166,9 @@ def split_list(args: Sequence[T], ns: Sequence[int]) -> list[list[T]]: return lists def split_list_checked(args: Sequence[T], ns: Sequence[int]) -> list[list[T]]: + """Split list into sublists of the specified sizes.""" args = list(args) - assert sum(ns) == len(args) + assert sum(ns) == len(args) and all(n >= 0 for n in ns) lists = [] for n in ns: lists.append(args[:n]) @@ -166,8 +176,9 @@ def split_list_checked(args: Sequence[T], ns: Sequence[int]) -> list[list[T]]: return lists def partition_list(bs: Sequence[bool], l: Sequence[T]) -> tuple[list[T], list[T]]: + """Partition a list into two based on a mask.""" assert len(bs) == len(l) - lists = [], [] # type: ignore + lists: tuple[list[T], list[T]] = ([], []) for b, x in zip(bs, l): lists[b].append(x) return lists @@ -176,6 +187,7 @@ def merge_lists(bs: Sequence[bool], l0: Sequence[T1], l1: Sequence[T2] ) -> list[T1 | T2]: + """Merge the elements of two lists based on a mask.""" assert sum(bs) == len(l1) and len(bs) - sum(bs) == len(l0) i0, i1 = iter(l0), iter(l1) out: list[T1 | T2] = [next(i1) if b else next(i0) for b in bs] @@ -204,7 +216,7 @@ def subs_list2( assert next(base_, sentinel) is sentinel return out -def split_dict(dct, names): +def split_dict(dct: dict[T1, T2], names: Sequence[T1]) -> list[T2]: dct = dict(dct) lst = [dct.pop(name) for name in names] assert not dct @@ -244,64 +256,14 @@ def curry(f): """ return wraps(f)(partial(partial, f)) -# TODO(phawkins): make this unconditional after jaxlib 0.5.3 is the minimum. 
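To make the one-line docstrings added above concrete, a small round-trip illustration using the helpers as written (assuming the `jax._src.util` import path stays as in this patch):

from jax._src.util import merge_lists, partition_list

mask = [False, True, True, False]
values = ["a", "b", "c", "d"]

# partition_list groups by mask value: the first list holds the entries
# whose mask is False, the second those whose mask is True.
false_part, true_part = partition_list(mask, values)
# false_part == ["a", "d"], true_part == ["b", "c"]

# merge_lists is the inverse: it re-interleaves the two partitions
# according to the same mask.
assert merge_lists(mask, false_part, true_part) == ["a", "b", "c", "d"]
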
toposort: Callable[[Iterable[Any]], list[Any]] -if hasattr(jaxlib_utils, "topological_sort"): - toposort = partial(jaxlib_utils.topological_sort, "parents") -else: - - def toposort(end_nodes): - if not end_nodes: - return [] - end_nodes = _remove_duplicates(end_nodes) - - child_counts = {} - stack = list(end_nodes) - while stack: - node = stack.pop() - if id(node) in child_counts: - child_counts[id(node)] += 1 - else: - child_counts[id(node)] = 1 - stack.extend(node.parents) - for node in end_nodes: - child_counts[id(node)] -= 1 - - sorted_nodes = [] - childless_nodes = [ - node for node in end_nodes if child_counts[id(node)] == 0 - ] - assert childless_nodes - while childless_nodes: - node = childless_nodes.pop() - sorted_nodes.append(node) - for parent in node.parents: - if child_counts[id(parent)] == 1: - childless_nodes.append(parent) - else: - child_counts[id(parent)] -= 1 - sorted_nodes = sorted_nodes[::-1] - - check_toposort(sorted_nodes) - return sorted_nodes - - def check_toposort(nodes): - visited = set() - for node in nodes: - assert all(id(parent) in visited for parent in node.parents) - visited.add(id(node)) - - def _remove_duplicates(node_list): - seen = set() - out = [] - for n in node_list: - if id(n) not in seen: - seen.add(id(n)) - out.append(n) - return out +toposort = partial(jaxlib_utils.topological_sort, "parents") -def split_merge(predicate, xs): +def split_merge( + predicate: Callable[[T], bool], + xs: Sequence[T] +) -> tuple[list[T], list[T], Callable[[Sequence[T], Sequence[T]], list[T]]]: sides = list(map(predicate, xs)) lhs = [x for x, s in zip(xs, sides) if s] rhs = [x for x, s in zip(xs, sides) if not s] @@ -320,7 +282,6 @@ def merge(new_lhs, new_rhs): return lhs, rhs, merge - def _ignore(): return None @@ -362,8 +323,9 @@ def weakref_lru_cache(call: Callable, maxsize=2048, behave similar to `functools.lru_cache`. 
""" global _weakref_lru_caches - cached_call = xc.weakref_lru_cache( - config.trace_context if trace_context_in_key else _ignore, call, maxsize) + cached_call = _weakref_lru_cache.weakref_lru_cache( + config.trace_context if trace_context_in_key else _ignore, call, maxsize + ) _weakref_lru_caches.add(cached_call) return cached_call @@ -406,19 +368,21 @@ def __hash__(self): def __eq__(self, other): return self.val == other.val -def wrap_name(name, transform_name): +def wrap_name(name: str, transform_name: str) -> str: return transform_name + '(' + name + ')' -def fun_name(fun: Callable): + +def fun_name(fun: Callable, default_name: str = "") -> str: name = getattr(fun, "__name__", None) if name is not None: return name if isinstance(fun, partial): return fun_name(fun.func) else: - return "" + return default_name + -def fun_qual_name(fun: Callable): +def fun_qual_name(fun: Callable) -> str: qual_name = getattr(fun, "__qualname__", None) if qual_name is not None: return qual_name @@ -426,7 +390,7 @@ def fun_qual_name(fun: Callable): return fun_qual_name(fun.func) return fun_name(fun) -def canonicalize_axis(axis, num_dims) -> int: +def canonicalize_axis(axis: SupportsIndex, num_dims: int) -> int: """Canonicalize an axis in [-num_dims, num_dims) to [0, num_dims).""" axis = operator.index(axis) if not -num_dims <= axis < num_dims: @@ -435,7 +399,7 @@ def canonicalize_axis(axis, num_dims) -> int: axis = axis + num_dims return axis -def moveaxis(x, src, dst): +def moveaxis(x: Array, src: int | Sequence[int], dst: int | Sequence[int]) -> Array: if src == dst: return x if isinstance(src, int): @@ -449,7 +413,7 @@ def moveaxis(x, src, dst): perm.insert(d, s) return x.transpose(perm) -def ceil_of_ratio(x, y): +def ceil_of_ratio(x: int, y: int) -> int: return -(-x // y) @@ -475,8 +439,9 @@ def wrapper(fun: T) -> T: else docstr.format(fun=name, doc=doc, **kwargs)) fun.__qualname__ = getattr(wrapped, "__qualname__", fun.__name__) fun.__wrapped__ = wrapped - finally: - return fun + except Exception: + pass + return fun return wrapper @@ -485,22 +450,18 @@ def wrapper(fun: T) -> T: def assert_unreachable(x): raise AssertionError(f"Unhandled case: {type(x).__name__}") -def tuple_insert(t, idx, val): +def tuple_insert(t: tuple[T, ...], idx: int, val: T) -> tuple[T, ...]: assert 0 <= idx <= len(t), (idx, len(t)) return t[:idx] + (val,) + t[idx:] -def tuple_delete(t, idx): +def tuple_delete(t: tuple[T, ...], idx: int) -> tuple[T, ...]: assert 0 <= idx < len(t), (idx, len(t)) return t[:idx] + t[idx + 1:] -def tuple_update(t, idx, val): +def tuple_update(t: tuple[T, ...], idx: int, val: T) -> tuple[T, ...]: assert 0 <= idx < len(t), (idx, len(t)) return t[:idx] + (val,) + t[idx+1:] -def tuple_replace(tupl, index, item): - # unlike tuple_update, works with negative indices as well - return tupl[:index] + (item,) + tupl[index:][1:] - class HashableFunction: """Decouples function equality and hash from its identity. 
@@ -554,13 +515,8 @@ def __eq__(self, other): self.args == other.args and self.kwargs == other.kwargs) def __hash__(self): - return hash( - ( - self.f.__code__, - self.args, - tuple(sorted(self.kwargs.items(), key=lambda kv: kv[0])), - ), - ) + kwargs = tuple(sorted(self.kwargs.items(), key=lambda kv: kv[0])) + return hash((self.f.__code__, self.args, kwargs)) def __call__(self, *args, **kwargs): return self.f(*self.args, *args, **self.kwargs, **kwargs) @@ -643,7 +599,7 @@ def __eq__(self, other): return self.x == other.x if self.hash is not None else self.x is other.x -def _original_func(f): +def _original_func(f: Callable) -> Callable: if isinstance(f, property): return cast(property, f).fget elif isinstance(f, functools.cached_property): @@ -690,14 +646,6 @@ def decorator(f): return decorator -try: - # numpy 1.25.0 or newer - NumpyComplexWarning: type[Warning] = np.exceptions.ComplexWarning -except AttributeError: - # legacy numpy - NumpyComplexWarning = np.ComplexWarning - - class StrictABCMeta(abc.ABCMeta): """A variant of `abc.ABCMeta` which does not allow virtual subclasses. @@ -726,3 +674,12 @@ def test_event(name: str, *args) -> None: if hasattr(jaxlib_utils, "Mutex"): Mutex = jaxlib_utils.Mutex + + +def pprint_bytes(num_bytes: int | float) -> str: + prefixes = ("", "K", "M", "G", "T") + if num_bytes <= 0: + return "0.00B" + exponent = min(math.floor(math.log(num_bytes, 1000)), len(prefixes) - 1) + scaled_value = num_bytes / (1000**exponent) + return f"{scaled_value:.2f}{prefixes[exponent]}B" diff --git a/jax/_src/xla_bridge.py b/jax/_src/xla_bridge.py index be96deab81d8..bb77b4f1ff0f 100644 --- a/jax/_src/xla_bridge.py +++ b/jax/_src/xla_bridge.py @@ -31,8 +31,8 @@ import pkgutil import platform as py_platform import threading -import traceback from typing import Any, Union +from collections.abc import Sequence import warnings from jax._src import config @@ -43,7 +43,8 @@ from jax._src.cloud_tpu_init import get_tpu_library_path from jax._src.lib import cuda_versions from jax._src.lib import xla_client -from jax._src.lib import xla_extension +from jax._src.lib import _jax +from jax._src.lib import _profiler logger = logging.getLogger(__name__) @@ -60,9 +61,10 @@ XlaBackend = xla_client.Client -MIN_COMPUTE_CAPABILITY = 52 +# The platforms in this set will force forward compatibility for lowering. +FORCE_FORWARD_COMPAT_LOWERING_PLATFORMS: set[str] = set() -_DEFAULT_CPU_COLLECTIVES_IMPL = 'gloo' +MIN_COMPUTE_CAPABILITY = 52 # TODO(phawkins): Remove jax_xla_backend. _XLA_BACKEND = config.string_flag( @@ -86,13 +88,13 @@ 'Restricts the set of ROCM devices that JAX will use. Either "all", or a ' 'comma-separate list of integer device IDs.') -_MOCK_NUM_GPU_PROCESSES = config.int_flag( +MOCK_NUM_GPU_PROCESSES = config.int_flag( name="mock_num_gpu_processes", default=0, help="Mock number of JAX processes in GPU client. Value zero turns " "off mocking.", ) -_MOCK_GPU_TOPOLOGY = config.string_flag( +MOCK_GPU_TOPOLOGY = config.string_flag( name="jax_mock_gpu_topology", default="", help='Mock multi-host GPU topology in GPU client. The value should ' @@ -125,13 +127,31 @@ def _at_fork(): # Backends +_NameValueMapping = Mapping[str, Union[str, int, list[int], float, bool]] + +def make_tpu_client( + library_path: str | None = None, options: _NameValueMapping | None = None +): + """Returns a TPU client. 
Defaults to allowing 32 in-flight computations.""" + if not _jax.pjrt_plugin_loaded('tpu'): + c_api = xla_client.load_pjrt_plugin_dynamically( + "tpu", library_path or "libtpu.so" + ) + _profiler.register_plugin_profiler(c_api) + assert _jax.pjrt_plugin_loaded('tpu') + if not _jax.pjrt_plugin_initialized('tpu'): + _jax.initialize_pjrt_plugin('tpu') + if options is None: + options = {} + return _jax.get_c_api_client('tpu', options) + def tpu_client_timer_callback(timer_secs: float) -> xla_client.Client | None: def _log_warning(): warnings.warn( f'TPU backend initialization is taking more than {timer_secs} seconds. ' 'Did you run your code on all TPU hosts? ' - 'See https://jax.readthedocs.io/en/latest/multi_process.html ' + 'See https://docs.jax.dev/en/latest/multi_process.html ' 'for more information.') # Will log a warning after `timer_secs`. @@ -139,7 +159,7 @@ def _log_warning(): t.start() try: - client = xla_client.make_tpu_client( + client = make_tpu_client( get_tpu_library_path(), _options_from_jax_configs("tpu")) finally: @@ -248,8 +268,6 @@ def make_cpu_client( '"jax_cpu_collectives_implementation", "gloo")` instead.', DeprecationWarning, ) - if collectives_impl is None: - collectives_impl = _DEFAULT_CPU_COLLECTIVES_IMPL if collectives_impl == 'gloo': collectives = xla_client._xla.make_gloo_tcp_collectives( @@ -287,149 +305,14 @@ def _check_cuda_compute_capability(devices_to_check): f"Device {idx} has CUDA compute capability {compute_cap/10} which is " "lower than the minimum supported compute capability " f"{MIN_COMPUTE_CAPABILITY/10}. See " - "https://jax.readthedocs.io/en/latest/installation.html#nvidia-gpu for " + "https://docs.jax.dev/en/latest/installation.html#nvidia-gpu for " "more details", RuntimeWarning ) -def _check_cuda_versions(raise_on_first_error: bool = False, - debug: bool = False): - assert cuda_versions is not None - results: list[dict[str, Any]] = [] - - def _make_msg(name: str, - runtime_version: int, - build_version: int, - min_supported: int, - debug_msg: bool = False): - if debug_msg: - return (f"Package: {name}\n" - f"Version JAX was built against: {build_version}\n" - f"Minimum supported: {min_supported}\n" - f"Installed version: {runtime_version}") - if min_supported: - req_str = (f"The local installation version must be no lower than " - f"{min_supported}.") - else: - req_str = ("The local installation must be the same version as " - "the version against which JAX was built.") - msg = (f"Outdated {name} installation found.\n" - f"Version JAX was built against: {build_version}\n" - f"Minimum supported: {min_supported}\n" - f"Installed version: {runtime_version}\n" - f"{req_str}") - return msg - - def _version_check(name: str, - get_version, - get_build_version, - scale_for_comparison: int = 1, - min_supported_version: int = 0): - """Checks the runtime CUDA component version against the JAX one. - - Args: - name: Of the CUDA component. - get_version: A function to get the local runtime version of the component. - get_build_version: A function to get the build version of the component. - scale_for_comparison: For rounding down a version to ignore patch/minor. - min_supported_version: An absolute minimum version required. Must be - passed without rounding down. - - Raises: - RuntimeError: If the component is not found, or is of unsupported version, - and if raising the error is not deferred till later. - """ - - build_version = get_build_version() - try: - version = get_version() - except Exception as e: - err_msg = f"Unable to load {name}. 
Is it installed?" - if raise_on_first_error: - raise RuntimeError(err_msg) from e - err_msg += f"\n{traceback.format_exc()}" - results.append({"name": name, "installed": False, "msg": err_msg}) - return - - if not min_supported_version: - min_supported_version = build_version // scale_for_comparison - passed = min_supported_version <= version - - if not passed or debug: - msg = _make_msg(name=name, - runtime_version=version, - build_version=build_version, - min_supported=min_supported_version, - debug_msg=passed) - if not passed and raise_on_first_error: - raise RuntimeError(msg) - else: - record = {"name": name, - "installed": True, - "msg": msg, - "passed": passed, - "build_version": build_version, - "version": version, - "minimum_supported": min_supported_version} - results.append(record) - - _version_check("CUDA", cuda_versions.cuda_runtime_get_version, - cuda_versions.cuda_runtime_build_version, - scale_for_comparison=10, - min_supported_version=12010) - _version_check( - "cuDNN", - cuda_versions.cudnn_get_version, - cuda_versions.cudnn_build_version, - # NVIDIA promise both backwards and forwards compatibility for cuDNN patch - # versions: - # https://docs.nvidia.com/deeplearning/cudnn/developer-guide/index.html#api-compat - scale_for_comparison=100, - min_supported_version=9100 - ) - _version_check("cuFFT", cuda_versions.cufft_get_version, - cuda_versions.cufft_build_version, - # Ignore patch versions. - scale_for_comparison=100) - _version_check("cuSOLVER", cuda_versions.cusolver_get_version, - cuda_versions.cusolver_build_version, - # Ignore patch versions. - scale_for_comparison=100, - min_supported_version=11400) - _version_check("cuPTI", cuda_versions.cupti_get_version, - cuda_versions.cupti_build_version, - min_supported_version=18) - _version_check("cuBLAS", cuda_versions.cublas_get_version, - cuda_versions.cublas_build_version, - # Ignore patch versions. - scale_for_comparison=100, - min_supported_version=120100) - _version_check("cuSPARSE", cuda_versions.cusparse_get_version, - cuda_versions.cusparse_build_version, - # Ignore patch versions. 
- scale_for_comparison=100, - min_supported_version=12100) - - errors = [] - debug_results = [] - for result in results: - message: str = result['msg'] - if not result['installed'] or not result['passed']: - errors.append(message) - else: - debug_results.append(message) - - join_str = f'\n{"-" * 50}\n' - if debug_results: - print(f'CUDA components status (debug):\n' - f'{join_str.join(debug_results)}') - if errors: - raise RuntimeError(f'Unable to use CUDA because of the ' - f'following issues with CUDA components:\n' - f'{join_str.join(errors)}') - -def _get_num_nodes_from_gpu_topology(topology: str) -> int: + +def get_num_nodes_from_gpu_topology(topology: str) -> int: try: slices_str, hosts_per_slice_str, _ = topology.split("x", 2) return int(slices_str) * int(hosts_per_slice_str) @@ -438,75 +321,11 @@ def _get_num_nodes_from_gpu_topology(topology: str) -> int: '" x x ' '".') -def make_gpu_client( - *, platform_name: str, visible_devices_flag: config.Flag[str] -) -> xla_client.Client: - visible_devices = visible_devices_flag.value - allowed_devices = None - if visible_devices != "all": - allowed_devices = {int(x) for x in visible_devices.split(",")} - - mock_gpu_topology = _MOCK_GPU_TOPOLOGY.value or None - mock_num_gpu_processes = (_get_num_nodes_from_gpu_topology(mock_gpu_topology) if - mock_gpu_topology else _MOCK_NUM_GPU_PROCESSES.value) - - use_mock_gpu_client = mock_num_gpu_processes > 0 - num_nodes = (mock_num_gpu_processes if use_mock_gpu_client - else distributed.global_state.num_processes) - - if platform_name == "cuda": - if not os.getenv("JAX_SKIP_CUDA_CONSTRAINTS_CHECK"): - _check_cuda_versions() - else: - print('Skipped CUDA versions constraints check due to the ' - 'JAX_SKIP_CUDA_CONSTRAINTS_CHECK env var being set.') - - devices_to_check = ( - allowed_devices - if allowed_devices - else range(cuda_versions.cuda_device_count()) - ) - _check_cuda_compute_capability(devices_to_check) - - return xla_client.make_gpu_client( - distributed_client=distributed.global_state.client, - node_id=distributed.global_state.process_id, - num_nodes=num_nodes, - platform_name=platform_name, - allowed_devices=allowed_devices, - mock=use_mock_gpu_client, - ) - - -if hasattr(xla_client, "make_gpu_client"): - register_backend_factory( - "cuda", - partial( - make_gpu_client, - platform_name="cuda", - visible_devices_flag=CUDA_VISIBLE_DEVICES, - ), - priority=200, - fail_quietly=True, - ) - register_backend_factory( - "rocm", - partial( - make_gpu_client, - platform_name="rocm", - visible_devices_flag=_ROCM_VISIBLE_DEVICES, - ), - priority=200, - fail_quietly=True, - ) - - -if hasattr(xla_client, "make_tpu_client"): - # TODO(phawkins,skyewm): switch TPU plugin to use the PJRT plugin mechanism, - # and then fail loudly on initialization failure. - register_backend_factory( - 'tpu', partial(tpu_client_timer_callback, timer_secs=60.0), priority=300, - fail_quietly=True) +# TODO(phawkins,skyewm): switch TPU plugin to use the PJRT plugin mechanism, +# and then fail loudly on initialization failure. +register_backend_factory( + 'tpu', partial(tpu_client_timer_callback, timer_secs=60.0), priority=300, + fail_quietly=True) def _get_pjrt_plugin_names_and_library_paths( @@ -563,7 +382,7 @@ def discover_pjrt_plugins() -> None: """Discovers plugins in the namespace package `jax_plugins` and import them. There are two methods used to discover plugin modules. 
They are intended - to be used together by implementors in order to cover all packaging and + to be used together by implementers in order to cover all packaging and development cases: 1. Define a globally unique module under the `jax_plugins` namespace @@ -631,27 +450,30 @@ def _options_from_jax_configs(plugin_name): options = {} pjrt_client_options = config.jax_pjrt_client_create_options.value - pjrt_client_option_list = [] - if pjrt_client_options: - pjrt_client_option_list = pjrt_client_options.split(";") - - for option in pjrt_client_option_list: - option_list = option.split(":") - if (len(option_list) != 2): - raise RuntimeError( - "Multiple ':' separators for option in " - f"jax_pjrt_client_create_options: '{option}'. " - "Should be in format 'key:value'") - options[option_list[0]] = option_list[1] + if isinstance(pjrt_client_options, str): + pjrt_client_option_list = [] + if pjrt_client_options: + pjrt_client_option_list = pjrt_client_options.split(";") + + for option in pjrt_client_option_list: + option_list = option.split(":") + if (len(option_list) != 2): + raise RuntimeError( + "Multiple ':' separators for option in " + f"jax_pjrt_client_create_options: '{option}'. " + "Should be in format 'key:value'") + options[option_list[0]] = option_list[1] + elif isinstance(pjrt_client_options, dict): + options.update(pjrt_client_options) if plugin_name in ("cuda", "rocm"): visible_devices = (CUDA_VISIBLE_DEVICES.value if plugin_name == "cuda" else _ROCM_VISIBLE_DEVICES.value) if visible_devices != 'all': options['visible_devices'] = [int(x) for x in visible_devices.split(',')] - mock_gpu_topology = _MOCK_GPU_TOPOLOGY.value or None - mock_num_processes = (_get_num_nodes_from_gpu_topology(mock_gpu_topology) if - mock_gpu_topology else _MOCK_NUM_GPU_PROCESSES.value) + mock_gpu_topology = MOCK_GPU_TOPOLOGY.value or None + mock_num_processes = (get_num_nodes_from_gpu_topology(mock_gpu_topology) if + mock_gpu_topology else MOCK_NUM_GPU_PROCESSES.value) options['enable_mock_nccl'] = mock_num_processes > 0 if mock_num_processes > 0: options['num_nodes'] = mock_num_processes @@ -660,6 +482,8 @@ def _options_from_jax_configs(plugin_name): return options +OptionsDict = Mapping[str, str | int | list[int] | float | bool] + # TODO(b/261345120): decide on a public name and expose a public method which is # an alias of this method. @@ -668,7 +492,7 @@ def register_plugin( *, priority: int = 400, library_path: str | None = None, - options: Mapping[str, str | int | list[int] | float | bool] | None = None, + options: OptionsDict | Callable[[], OptionsDict] | None = None, c_api: Any | None = None, ) -> Any: """Registers a backend factory for the PJRT plugin. @@ -679,7 +503,9 @@ def register_plugin( Default to be 400. library_path: Optional. The full path to the .so file of the plugin. The plugin needs to provide either the library_path or the c_api. - options: Optional. It is used when creating a PJRT plugin client. + options: Optional. It is used when creating a PJRT plugin client. Can be a + callable, in which case it will be invoked upon plugin initialization + time, and will be expected to return an option dictionary. c_api: Optional. The plugin can provide a PJRT C API to be registered. 
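With `options` widened above to also accept a callable, option construction can be deferred until the backend factory actually runs. A hypothetical registration sketch (plugin name, library path, and option keys are invented for illustration):

from jax._src import xla_bridge as xb

def _my_accel_options():
  # Evaluated lazily inside factory(), not at registration time.
  return {"visible_devices": [0, 1]}

xb.register_plugin(
    "my_accel",                                   # hypothetical plugin name
    library_path="/opt/my_accel/pjrt_plugin.so",  # hypothetical path
    options=_my_accel_options,                    # callable, per the new signature
)
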
""" def factory(): @@ -687,7 +513,7 @@ def factory(): xla_client.initialize_pjrt_plugin(plugin_name) updated_options = {} if options is not None: - updated_options.update(options) + updated_options.update(options() if callable(options) else options) updated_options.update(_options_from_jax_configs(plugin_name)) if distributed.global_state.client is None: return xla_client.make_c_api_client(plugin_name, updated_options, None) @@ -696,6 +522,8 @@ def factory(): 'node_id': distributed.global_state.process_id, 'num_nodes': distributed.global_state.num_processes, } + if (slice_index := distributed.global_state.slice_index) is not None: + distribute_options['slice_index'] = slice_index if options is not None: distribute_options.update(updated_options) return xla_client.make_c_api_client( @@ -722,7 +550,7 @@ def factory(): ) if library_path is not None: c_api = xla_client.load_pjrt_plugin_dynamically(plugin_name, library_path) - xla_client.profiler.register_plugin_profiler(c_api) + _profiler.register_plugin_profiler(c_api) else: assert c_api is not None xla_client.load_pjrt_plugin_with_c_api(plugin_name, c_api) @@ -950,14 +778,14 @@ def _suggest_missing_backends(): assert _default_backend is not None default_platform = _default_backend.platform if "cuda" not in _backends and hardware_utils.has_visible_nvidia_gpu(): - if hasattr(xla_extension, "GpuAllocatorConfig") and "cuda" in _backend_errors: + if hasattr(_jax, "GpuAllocatorConfig") and "cuda" in _backend_errors: err = _backend_errors["cuda"] warning_msg = f"CUDA backend failed to initialize: {err}." if "no supported devices found for platform CUDA." in err: warning_msg += ( "This may be due to JAX pre-allocating too much device " "memory, leaving too little for CUDA library initialization. See " - "https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html " + "https://docs.jax.dev/en/latest/gpu_memory_allocation.html " "for more details and potential workarounds." ) warning_msg += "(Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)" @@ -1092,7 +920,7 @@ def devices( ) -> list[xla_client.Device]: """Returns a list of all devices for a given backend. - .. currentmodule:: jaxlib.xla_extension + .. currentmodule:: jaxlib._jax Each device is represented by a subclass of :class:`Device` (e.g. :class:`CpuDevice`, :class:`GpuDevice`). The length of the returned list is @@ -1137,13 +965,23 @@ def backend_xla_version(platform=None) -> int | None: """Returns the XLA version of the backend. Returns None if the backend does not use PJRT C API or does not have - xla_version in the plugin attributes. This methon can be used to skip features + xla_version in the plugin attributes. This method can be used to skip features that are not available before certain xla_version if the backend is a plugin and uses xla_version. """ backend = get_backend(platform) return getattr(backend, "xla_version", None) +def backend_stablehlo_version(platform=None) -> Sequence[int] | None: + """Returns the StableHLO version of the backend. + + Returns None if the backend does not use PJRT C API or does not have + stablehlo_current_version in the plugin attributes. This method can be used to + skip features that are not available before certain stablehlo_current_version + if the backend is a plugin and uses stablehlo_current_version. 
+ """ + backend = get_backend(platform) + return getattr(backend, "stablehlo_current_version", None) @lru_cache def local_devices(process_index: int | None = None, @@ -1266,7 +1104,7 @@ def make_pjrt_tpu_topology(topology_name='', **kwargs): "JAX TPU support not installed; cannot generate TPU topology. See" " https://github.com/jax-ml/jax#installation") c_api = xla_client.load_pjrt_plugin_dynamically("tpu", library_path) - xla_client.profiler.register_plugin_profiler(c_api) + _profiler.register_plugin_profiler(c_api) assert xla_client.pjrt_plugin_loaded("tpu") if not xla_client.pjrt_plugin_initialized("tpu"): xla_client.initialize_pjrt_plugin("tpu") diff --git a/jax/_src/xla_metadata.py b/jax/_src/xla_metadata.py index 91895b4e7851..77c0e2ff9910 100644 --- a/jax/_src/xla_metadata.py +++ b/jax/_src/xla_metadata.py @@ -24,6 +24,8 @@ class XlaMetadata: __slots__ = ['val', 'hash'] + val: dict[str, Any] + def __init__(self, val): self.val = val self.hash = hash(tuple(sorted(self.val.items()))) @@ -35,14 +37,19 @@ def __eq__(self, other): return other is not None and self.val == other.val +def filter_nones(d: dict) -> dict: + return {k: v for k, v in d.items() if v is not None} + + def update_metadata(a, b: dict[str, Any]): if not b: return a if a is None or a is config_ext.unset: - return XlaMetadata(b) - val = a.val.copy() + val = {} + else: + val = a.val.copy() val.update(b) - return XlaMetadata(val) + return XlaMetadata(filter_nones(val)) def current_xla_metadata(): diff --git a/jax/collect_profile.py b/jax/collect_profile.py index d1309e0c5bca..2c725ce8e9e2 100644 --- a/jax/collect_profile.py +++ b/jax/collect_profile.py @@ -23,15 +23,11 @@ # pytype: disable=import-error from jax._src import profiler as jax_profiler try: - from tensorflow.python.profiler import profiler_v2 as profiler - from tensorflow.python.profiler import profiler_client -except ImportError: - raise ImportError("This script requires `tensorflow` to be installed.") -try: - from tensorboard_plugin_profile.convert import raw_to_tool_data as convert + from xprof.convert import _pywrap_profiler_plugin + from xprof.convert import raw_to_tool_data as convert except ImportError: raise ImportError( - "This script requires `tensorboard_plugin_profile` to be installed.") + "This script requires `xprof` to be installed.") # pytype: enable=import-error @@ -69,13 +65,13 @@ def collect_profile(port: int, duration_in_ms: int, host: str, log_dir: os.PathLike | str | None, host_tracer_level: int, device_tracer_level: int, python_tracer_level: int, no_perfetto_link: bool): - options = profiler.ProfilerOptions( - host_tracer_level=host_tracer_level, - device_tracer_level=device_tracer_level, - python_tracer_level=python_tracer_level, - ) + options = { + "host_tracer_level": host_tracer_level, + "device_tracer_level": device_tracer_level, + "python_tracer_level": python_tracer_level, + } log_dir_ = pathlib.Path(log_dir if log_dir is not None else tempfile.mkdtemp()) - profiler_client.trace( + _pywrap_profiler_plugin.trace( f"{host}:{port}", str(log_dir_), duration_in_ms, @@ -91,7 +87,7 @@ def collect_profile(port: int, duration_in_ms: int, host: str, in root_trace_folder.iterdir()] latest_folder = max(trace_folders, key=os.path.getmtime) xplane = next(latest_folder.glob("*.xplane.pb")) - result, _ = convert.xspace_to_tool_data([xplane], "trace_viewer^", {}) + result, _ = convert.xspace_to_tool_data([xplane], "trace_viewer", {}) with gzip.open(str(latest_folder / "remote.trace.json.gz"), "wb") as fp: fp.write(result.encode("utf-8")) diff 
--git a/jax/core.py b/jax/core.py index 3fd7af440d4a..50b50d935024 100644 --- a/jax/core.py +++ b/jax/core.py @@ -81,154 +81,79 @@ from jax._src import core as _src_core _deprecations = { - # Added 2024-12-16 - "ClosedJaxpr": ("jax.core.ClosedJaxpr is deprecated. Use jax.extend.core.ClosedJaxpr instead, " - "and see https://jax.readthedocs.io/en/latest/jax.extend.html for details.", - _src_core.ClosedJaxpr), - "Jaxpr": ("jax.core.Jaxpr is deprecated. Use jax.extend.core.Jaxpr instead, " - "and see https://jax.readthedocs.io/en/latest/jax.extend.html for details.", - _src_core.Jaxpr), - "JaxprEqn": ("jax.core.JaxprEqn is deprecated. Use jax.extend.core.JaxprEqn instead, " - "and see https://jax.readthedocs.io/en/latest/jax.extend.html for details.", - _src_core.JaxprEqn), - "Literal": ("jax.core.Literal is deprecated. Use jax.extend.core.Literal instead, " - "and see https://jax.readthedocs.io/en/latest/jax.extend.html for details.", - _src_core.Literal), - "Primitive": ("jax.core.Primitive is deprecated. Use jax.extend.core.Primitive instead, " - "and see https://jax.readthedocs.io/en/latest/jax.extend.html for details.", - _src_core.Primitive), - "Token": ("jax.core.Token is deprecated. Use jax.extend.core.Token instead, " - "and see https://jax.readthedocs.io/en/latest/jax.extend.html for details.", - _src_core.Token), - "Var": ("jax.core.Var is deprecated. Use jax.extend.core.Var instead, " - "and see https://jax.readthedocs.io/en/latest/jax.extend.html for details.", - _src_core.Var), # Added 2024-12-11 "axis_frame": ("jax.core.axis_frame is deprecated.", _src_core.axis_frame), "AxisName": ("jax.core.AxisName is deprecated.", _src_core.AxisName), - "AxisSize": ("jax.core.AxisSize is deprecated.", _src_core.AxisSize), "ConcretizationTypeError": ("jax.core.ConcretizationTypeError is deprecated; " "use jax.errors.ConcretizationTypeError.", _src_core.ConcretizationTypeError), - "EvalTrace": ("jax.core.EvalTrace is deprecated.", _src_core.EvalTrace), - "InDBIdx": ("jax.core.InDBIdx is deprecated.", _src_core.InDBIdx), - "InputType": ("jax.core.InputType is deprecated.", _src_core.InputType), - "MapPrimitive": ("jax.core.MapPrimitive is deprecated.", _src_core.MapPrimitive), - "OpaqueTraceState": ("jax.core.OpaqueTraceState is deprecated.", _src_core.OpaqueTraceState), - "OutDBIdx": ("jax.core.OutDBIdx is deprecated.", _src_core.OutDBIdx), - "TRACER_LEAK_DEBUGGER_WARNING": ("jax.core.TRACER_LEAK_DEBUGGER_WARNING is deprecated.", - _src_core.TRACER_LEAK_DEBUGGER_WARNING), "call_p": ("jax.core.call_p is deprecated. Use jax.extend.core.primitives.call_p", _src_core.call_p), "closed_call_p": ("jax.core.closed_call_p is deprecated. 
Use jax.extend.core.primitives.closed_call_p", _src_core.closed_call_p), - "concrete_aval": ("jax.core.concrete_aval is deprecated.", _src_core.abstractify), - "dedup_referents": ("jax.core.dedup_referents is deprecated.", _src_core.dedup_referents), - "escaped_tracer_error": ("jax.core.escaped_tracer_error is deprecated.", - _src_core.escaped_tracer_error), - "extend_axis_env_nd": ("jax.core.extend_axis_env_nd is deprecated.", - _src_core.extend_axis_env_nd), "get_type": ("jax.core.get_type is deprecated.", _src_core.get_aval), - "get_referent": ("jax.core.get_referent is deprecated.", _src_core.get_referent), - "join_effects": ("jax.core.join_effects is deprecated.", _src_core.join_effects), - "leaked_tracer_error": ("jax.core.leaked_tracer_error is deprecated.", - _src_core.leaked_tracer_error), - "maybe_find_leaked_tracers": ("jax.core.maybe_find_leaked_tracers is deprecated.", - _src_core.maybe_find_leaked_tracers), - "raise_to_shaped_mappings": ("jax.core.raise_to_shaped_mappings is deprecated." - " It is unused as of jax v0.4.36.", - _src_core.raise_to_shaped_mappings), - "reset_trace_state": ("jax.core.reset_trace_state is deprecated.", - _src_core.reset_trace_state), - "str_eqn_compact": ("jax.core.str_eqn_compact is deprecated.", _src_core.str_eqn_compact), - "substitute_vars_in_output_ty": ("jax.core.substitute_vars_in_output_ty is deprecated.", - _src_core.substitute_vars_in_output_ty), "trace_state_clean": ("jax.core.trace_state_clean is deprecated.", _src_core.trace_state_clean), "typecheck": ("jax.core.typecheck is deprecated.", _src_core.typecheck), - "typecompat": ("jax.core.typecompat is deprecated.", _src_core.typecompat), "typematch": ("jax.core.typematch is deprecated.", _src_core.typematch), - "used_axis_names_jaxpr": ("jax.core.used_axis_names_jaxpr is deprecated.", - _src_core.used_axis_names_jaxpr), # Added 2024-12-10 - "full_lower": ("jax.core.full_lower is deprecated. It is a no-op as of JAX v0.4.36.", - _src_core.full_lower), - "jaxpr_as_fun": ("jax.core.jaxpr_as_fun is deprecated. Use jax.extend.core.jaxpr_as_fun instead, " - "and see https://jax.readthedocs.io/en/latest/jax.extend.html for details.", - _src_core.jaxpr_as_fun), - "lattice_join": ("jax.core.lattice_join is deprecated. It is a no-op as of JAX v0.4.36.", - _src_core.lattice_join), - "raise_to_shaped": ("jax.core.raise_to_shaped is deprecated. It is a no-op as of JAX v0.4.36.", - _src_core.raise_to_shaped), - # Finalized 2024-12-11; remove after 2025-3-11 - "check_eqn": ("jax.core.check_eqn was removed in JAX v0.4.38.", None), - "check_type": ("jax.core.check_type was removed in JAX v0.4.38.", None), - "check_valid_jaxtype": ( - ("jax.core.check_valid_jaxtype was removed in JAX v0.4.38. Instead, you can manually" - " raise an error if core.valid_jaxtype() returns False."), - None), - "non_negative_dim": ( - "jax.core.non_negative_dim was removed in JAX v0.4.38. 
Use max_dim(..., 0).", None, - ), - # Finalized 2024-09-25; remove after 2024-12-25 - "pp_aval": ("jax.core.pp_aval was removed in JAX v0.4.34.", None), - "pp_eqn": ("jax.core.pp_eqn was removed in JAX v0.4.34.", None), - "pp_eqn_rules": ("jax.core.pp_eqn_rules was removed in JAX v0.4.34.", None), - "pp_eqns": ("jax.core.pp_eqns was removed in JAX v0.4.34.", None), - "pp_jaxpr": ("jax.core.pp_jaxpr was removed in JAX v0.4.34.", None), - "pp_jaxpr_eqn_range": ("jax.core.pp_jaxpr_eqn_range was removed in JAX v0.4.34.", None), - "pp_jaxpr_skeleton": ("jax.core.pp_jaxpr_skeleton was removed in JAX v0.4.34.", None), - "pp_jaxprs": ("jax.core.pp_jaxprs was removed in JAX v0.4.34.", None), - "pp_kv_pair": ("jax.core.pp_kv_pair was removed in JAX v0.4.34.", None), - "pp_kv_pairs": ("jax.core.pp_kv_pairs was removed in JAX v0.4.34.", None), - "pp_var": ("jax.core.pp_var was removed in JAX v0.4.34.", None), - "pp_vars": ("jax.core.pp_vars was removed in JAX v0.4.34.", None), + "full_lower": ("jax.core.full_lower is deprecated. It is a no-op as of JAX v0.4.36.", None), + "jaxpr_as_fun": ("jax.core.jaxpr_as_fun was removed in JAX v0.6.0. Use jax.extend.core.jaxpr_as_fun instead, " + "and see https://docs.jax.dev/en/latest/jax.extend.html for details.", + None), + "lattice_join": ("jax.core.lattice_join is deprecated. It is a no-op as of JAX v0.4.36.", None), + # Finalized 2025-03-25 for JAX v0.6.0; remove after 2025-06-25 + "AxisSize": ("jax.core.AxisSize was removed in JAX v0.6.0.", None), + "ClosedJaxpr": ("jax.core.ClosedJaxpr was removed in JAX v0.6.0. Use jax.extend.core.ClosedJaxpr instead, " + "and see https://docs.jax.dev/en/latest/jax.extend.html for details.", None), + "EvalTrace": ("jax.core.EvalTrace was removed in JAX v0.6.0.", None), + "InDBIdx": ("jax.core.InDBIdx was removed in JAX v0.6.0.", None), + "InputType": ("jax.core.InputType was removed in JAX v0.6.0.", None), + "Jaxpr": ("jax.core.Jaxpr was removed in JAX v0.6.0. Use jax.extend.core.Jaxpr instead, " + "and see https://docs.jax.dev/en/latest/jax.extend.html for details.", None), + "JaxprEqn": ("jax.core.JaxprEqn was removed in JAX v0.6.0. Use jax.extend.core.JaxprEqn instead, " + "and see https://docs.jax.dev/en/latest/jax.extend.html for details.", None), + "Literal": ("jax.core.Literal was removed in JAX v0.6.0. Use jax.extend.core.Literal instead, " + "and see https://docs.jax.dev/en/latest/jax.extend.html for details.", None), + "MapPrimitive": ("jax.core.MapPrimitive was removed in JAX v0.6.0.", None), + "OpaqueTraceState": ("jax.core.OpaqueTraceState was removed in JAX v0.6.0.", None), + "OutDBIdx": ("jax.core.OutDBIdx was removed in JAX v0.6.0.", None), + "Primitive": ("jax.core.Primitive was removed in JAX v0.6.0. Use jax.extend.core.Primitive instead, " + "and see https://docs.jax.dev/en/latest/jax.extend.html for details.", None), + "Token": ("jax.core.Token was removed in JAX v0.6.0. Use jax.extend.core.Token instead, " + "and see https://docs.jax.dev/en/latest/jax.extend.html for details.", None), + "TRACER_LEAK_DEBUGGER_WARNING": ("jax.core.TRACER_LEAK_DEBUGGER_WARNING was removed in JAX v0.6.0.", None), + "Var": ("jax.core.Var was removed in JAX v0.6.0. 
Use jax.extend.core.Var instead, " + "and see https://docs.jax.dev/en/latest/jax.extend.html for details.", None), + "concrete_aval": ("jax.core.concrete_aval was removed in JAX v0.6.0.", None), + "dedup_referents": ("jax.core.dedup_referents was removed in JAX v0.6.0.", None), + "escaped_tracer_error": ("jax.core.escaped_tracer_error was removed in JAX v0.6.0.", None), + "extend_axis_env_nd": ("jax.core.extend_axis_env_nd was removed in JAX v0.6.0.", None), + "get_referent": ("jax.core.get_referent was removed in JAX v0.6.0.", None), + "join_effects": ("jax.core.join_effects was removed in JAX v0.6.0.", None), + "leaked_tracer_error": ("jax.core.leaked_tracer_error was removed in JAX v0.6.0.", None), + "maybe_find_leaked_tracers": ("jax.core.maybe_find_leaked_tracers was removed in JAX v0.6.0.", None), + "raise_to_shaped": ("jax.core.raise_to_shaped was removed in JAX v0.6.0. It is a no-op as of JAX v0.4.36.", None), + "raise_to_shaped_mappings": ("jax.core.raise_to_shaped_mappings was removed in JAX v0.6.0." + " It is unused as of jax v0.4.36.", None), + "reset_trace_state": ("jax.core.reset_trace_state was removed in JAX v0.6.0.", None), + "str_eqn_compact": ("jax.core.str_eqn_compact was removed in JAX v0.6.0.", None), + "substitute_vars_in_output_ty": ("jax.core.substitute_vars_in_output_ty was removed in JAX v0.6.0.", None), + "typecompat": ("jax.core.typecompat was removed in JAX v0.6.0.", None), + "used_axis_names_jaxpr": ("jax.core.used_axis_names_jaxpr was removed in JAX v0.6.0.", None), } import typing if typing.TYPE_CHECKING: AxisName = _src_core.AxisName - AxisSize = _src_core.AxisSize - ClosedJaxpr = _src_core.ClosedJaxpr ConcretizationTypeError = _src_core.ConcretizationTypeError - EvalTrace = _src_core.EvalTrace - InDBIdx = _src_core.InDBIdx - InputType = _src_core.InputType - Jaxpr = _src_core.Jaxpr - JaxprEqn = _src_core.JaxprEqn - Literal = _src_core.Literal - MapPrimitive = _src_core.MapPrimitive - OpaqueTraceState = _src_core.OpaqueTraceState - OutDBIdx = _src_core.OutDBIdx - Primitive = _src_core.Primitive - Token = _src_core.Token - TRACER_LEAK_DEBUGGER_WARNING = _src_core.TRACER_LEAK_DEBUGGER_WARNING - Var = _src_core.Var axis_frame = _src_core.axis_frame call_p = _src_core.call_p closed_call_p = _src_core.closed_call_p - concrete_aval = _src_core.abstractify - dedup_referents = _src_core.dedup_referents - escaped_tracer_error = _src_core.escaped_tracer_error - extend_axis_env_nd = _src_core.extend_axis_env_nd - full_lower = _src_core.full_lower get_type = _src_core.get_aval - get_referent = _src_core.get_referent - jaxpr_as_fun = _src_core.jaxpr_as_fun - join_effects = _src_core.join_effects - lattice_join = _src_core.lattice_join - leaked_tracer_error = _src_core.leaked_tracer_error - maybe_find_leaked_tracers = _src_core.maybe_find_leaked_tracers - raise_to_shaped = _src_core.raise_to_shaped - raise_to_shaped_mappings = _src_core.raise_to_shaped_mappings - reset_trace_state = _src_core.reset_trace_state - str_eqn_compact = _src_core.str_eqn_compact - substitute_vars_in_output_ty = _src_core.substitute_vars_in_output_ty trace_state_clean = _src_core.trace_state_clean typecheck = _src_core.typecheck - typecompat = _src_core.typecompat typematch = _src_core.typematch - used_axis_names_jaxpr = _src_core.used_axis_names_jaxpr else: from jax._src.deprecations import deprecation_getattr as _deprecation_getattr __getattr__ = _deprecation_getattr(__name__, _deprecations) diff --git a/jax/custom_derivatives.py b/jax/custom_derivatives.py index 3628ae4aaa6e..6674046dd8e8 
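The jax/core.py table above relies on the module-level __getattr__ hook installed by jax._src.deprecations.deprecation_getattr: entries whose second element is a live object warn and still resolve, while entries whose second element is None are treated as removed. The following is a simplified sketch of that behaviour, not the actual implementation; the warning category and message wording are assumptions:

    import warnings

    def deprecation_getattr(module: str, deprecations: dict[str, tuple[str, object]]):
        """Build a module __getattr__ that warns for deprecated names and raises
        AttributeError for names whose replacement object is None (removed)."""
        def getattr_(name: str):
            if name in deprecations:
                message, obj = deprecations[name]
                if obj is None:  # finalized removal
                    raise AttributeError(message)
                warnings.warn(message, DeprecationWarning, stacklevel=2)
                return obj
            raise AttributeError(f"module {module!r} has no attribute {name!r}")
        return getattr_

    # The modules in this patch wire it up roughly as:
    #   __getattr__ = _deprecation_getattr(__name__, _deprecations)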
100644 --- a/jax/custom_derivatives.py +++ b/jax/custom_derivatives.py @@ -23,10 +23,9 @@ custom_gradient as custom_gradient, custom_jvp as custom_jvp, custom_jvp_call_p as custom_jvp_call_p, - custom_jvp_call_jaxpr_p as custom_jvp_call_jaxpr_p, + custom_jvp_call_jaxpr_p as _custom_jvp_call_jaxpr_p, custom_vjp as custom_vjp, custom_vjp_call_p as custom_vjp_call_p, - custom_vjp_call_jaxpr_p as custom_vjp_call_jaxpr_p, custom_vjp_primal_tree_values as custom_vjp_primal_tree_values, CustomVJPPrimal as CustomVJPPrimal, linear_call as linear_call, @@ -37,3 +36,22 @@ SymbolicZero as SymbolicZero, zero_from_primal as zero_from_primal ) + +_deprecations = { + # Added May 12, 2025 + "custom_jvp_call_jaxpr_p": ( + ("jax.custom_derivatives.custom_jvp_call_jaxpr_p is deprecated, use " + "jax.extend.core.primitives.custom_jvp_call_p instead."), + _custom_jvp_call_jaxpr_p, + ), +} + +import typing +if typing.TYPE_CHECKING: + custom_jvp_call_jaxpr_p = _custom_jvp_call_jaxpr_p +else: + from jax._src.deprecations import deprecation_getattr as _deprecation_getattr + __getattr__ = _deprecation_getattr(__name__, _deprecations) + del _deprecation_getattr +del typing +del _custom_jvp_call_jaxpr_p diff --git a/jax/dlpack.py b/jax/dlpack.py index a65496ec0cbf..c4b993195030 100644 --- a/jax/dlpack.py +++ b/jax/dlpack.py @@ -12,8 +12,35 @@ # See the License for the specific language governing permissions and # limitations under the License. + +import jax._src.dlpack +import jax._src.deprecations + from jax._src.dlpack import ( - to_dlpack as to_dlpack, from_dlpack as from_dlpack, SUPPORTED_DTYPES as SUPPORTED_DTYPES, ) + +_deprecations = { + "to_dlpack": ( + ( + "jax.dlpack.to_dlpack was deprecated in JAX v0.6.0 and" + " removed in JAX v0.7.0. Please use the newer DLPack API based on" + " __dlpack__ and __dlpack_device__ instead. Typically, you can pass" + " a JAX array directly to the `from_dlpack` function of another" + " framework without using `to_dlpack`." 
+ ), + None, + ), +} + + +import typing as _typing + +if _typing.TYPE_CHECKING: + to_dlpack = jax._src.dlpack.to_dlpack +else: + __getattr__ = jax._src.deprecations.deprecation_getattr( + __name__, _deprecations + ) +del _typing diff --git a/jax/errors.py b/jax/errors.py index 6da7b717cb5f..0dcf34bd4763 100644 --- a/jax/errors.py +++ b/jax/errors.py @@ -31,4 +31,19 @@ JaxRuntimeError = _xc.XlaRuntimeError del _xc -from jax._src.traceback_util import SimplifiedTraceback as SimplifiedTraceback +import jax._src.traceback_util +_deprecations = { + "SimplifiedTraceback": ( + "jax.errors.SimplifiedTraceback is deprecated and will be removed in JAX v0.8.", + jax._src.traceback_util.SimplifiedTraceback + ), +} + +import typing +if typing.TYPE_CHECKING: + SimplifiedTraceback = jax._src.traceback_util.SimplifiedTraceback +else: + from jax._src.deprecations import deprecation_getattr as _deprecation_getattr + __getattr__ = _deprecation_getattr(__name__, _deprecations) + del _deprecation_getattr +del typing diff --git a/jax/experimental/__init__.py b/jax/experimental/__init__.py index 375d058d0edc..1b4f7efedbe7 100644 --- a/jax/experimental/__init__.py +++ b/jax/experimental/__init__.py @@ -19,6 +19,10 @@ enable_x64 as enable_x64, disable_x64 as disable_x64, ) +from jax._src.api import ( + saved_input_vjp as saved_input_vjp, + si_vjp as si_vjp +) from jax._src.callback import ( io_callback as io_callback ) @@ -28,3 +32,7 @@ from jax._src.earray import ( EArray as EArray ) +from jax._src.core import ( + mutable_array as mutable_array, + MutableArray as MutableArray, +) diff --git a/jax/experimental/_private_mm/examples/example_overlap.py b/jax/experimental/_private_mm/examples/example_overlap.py index 022eb3293dcc..f3c3726ec347 100644 --- a/jax/experimental/_private_mm/examples/example_overlap.py +++ b/jax/experimental/_private_mm/examples/example_overlap.py @@ -14,7 +14,8 @@ """An example showcasing overlap on a (forward-only) PP-like workload.""" from dataclasses import dataclass -from typing import Any, Callable +from typing import Any +from collections.abc import Callable import time import numpy as np diff --git a/jax/experimental/_private_mm/examples/example_pp.py b/jax/experimental/_private_mm/examples/example_pp.py index b43d1c743c28..846d96cb34a9 100644 --- a/jax/experimental/_private_mm/examples/example_pp.py +++ b/jax/experimental/_private_mm/examples/example_pp.py @@ -15,7 +15,8 @@ from dataclasses import dataclass from functools import cached_property, partial -from typing import Any, Callable +from typing import Any +from collections.abc import Callable import numpy as np diff --git a/jax/experimental/_private_mm/mini_dime.py b/jax/experimental/_private_mm/mini_dime.py index 971d5a016817..f12084b3a1ce 100644 --- a/jax/experimental/_private_mm/mini_dime.py +++ b/jax/experimental/_private_mm/mini_dime.py @@ -49,8 +49,8 @@ import jax import jax.numpy as jnp -import jaxlib.xla_extension as xe from jax._src import array +from jax._src.lib import _jax from jax._src.op_shardings import are_op_shardings_equal @@ -66,10 +66,10 @@ def _get_nccl_dtype_and_count(arr, count=None): return nccl_dtype, count -def get_distributed_client() -> xe.DistributedRuntimeClient: +def get_distributed_client() -> _jax.DistributedRuntimeClient: from jax._src.distributed import global_state - assert isinstance(global_state.client, xe.DistributedRuntimeClient) + assert isinstance(global_state.client, _jax.DistributedRuntimeClient) return global_state.client diff --git a/jax/experimental/_private_mm/mm.py 
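Several of the _private_mm files above switch Callable imports from typing to collections.abc, since the typing aliases have been deprecated since Python 3.9. A tiny illustrative sketch of the updated idiom (the function name is hypothetical):

    from collections.abc import Callable
    from typing import Any

    def run_stage(fn: Callable[[Any], Any], x: Any) -> Any:
        # Annotate callables with collections.abc.Callable, not typing.Callable.
        return fn(x)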
b/jax/experimental/_private_mm/mm.py index f47724ce6ec4..b108fb3e2e35 100644 --- a/jax/experimental/_private_mm/mm.py +++ b/jax/experimental/_private_mm/mm.py @@ -16,7 +16,7 @@ from dataclasses import dataclass from functools import cached_property, lru_cache, partial, wraps -from typing import Callable +from collections.abc import Callable import jax import jax.numpy as jnp diff --git a/jax/experimental/array_serialization/BUILD b/jax/experimental/array_serialization/BUILD index ab1ee3fd393e..559d8eb16269 100644 --- a/jax/experimental/array_serialization/BUILD +++ b/jax/experimental/array_serialization/BUILD @@ -35,9 +35,37 @@ pytype_library( "serialization.py", ], visibility = ["//visibility:public"], - deps = ["//jax"] + py_deps([ + deps = [ + "//jax", + "//jax/experimental/array_serialization:tensorstore_impl", + ] + py_deps([ + "absl/logging", + "numpy", + ]), +) + +pytype_library( + name = "pytree_serialization", + srcs = ["pytree_serialization.py"], + visibility = ["//visibility:public"], + deps = [ + "//jax", + "//jax/experimental/array_serialization:pytree_serialization_utils", + "//jax/experimental/array_serialization:tensorstore_impl", + ] + py_deps([ + "absl/logging", "numpy", + ]), +) + +pytype_library( + name = "pytree_serialization_utils", + srcs = ["pytree_serialization_utils.py"], + deps = [ + "//jax", + ] + py_deps([ "absl/logging", + "numpy", ]), ) @@ -45,10 +73,19 @@ jax_multiplatform_test( name = "serialization_test", srcs = ["serialization_test.py"], enable_configs = [ - "tpu_v3_2x2", + "tpu_v3_x4", ], deps = [ - "//jax:experimental", + "//jax/experimental/array_serialization:pytree_serialization", "//jax/experimental/array_serialization:serialization", ], ) + +pytype_library( + name = "tensorstore_impl", + srcs = ["tensorstore_impl.py"], + visibility = ["//visibility:public"], + deps = ["//jax"] + py_deps([ + "numpy", + ]), +) diff --git a/jax/experimental/array_serialization/pytree_serialization.py b/jax/experimental/array_serialization/pytree_serialization.py new file mode 100644 index 000000000000..639d36a7c806 --- /dev/null +++ b/jax/experimental/array_serialization/pytree_serialization.py @@ -0,0 +1,506 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Serializations routines for pytrees including array and non-array serialization. 
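To make the scope of this new module concrete, here is a hedged round-trip sketch based on the save, load_pytreedef, and load docstrings defined below; the checkpoint path is illustrative, and a single sharding is used as a prefix pytree that broadcasts over every array leaf:

    import jax
    import jax.numpy as jnp
    from jax.sharding import SingleDeviceSharding
    from jax.experimental.array_serialization import pytree_serialization

    data = {"params": {"w": jnp.arange(4.0)}, "step": jnp.array(0), "meta": None}
    pytree_serialization.save(data, "/tmp/jax_ckpt", overwrite=True)

    # Inspect the stored structure without reading array data.
    skeleton = pytree_serialization.load_pytreedef("/tmp/jax_ckpt")

    # Restore, broadcasting one sharding over the whole tree.
    sharding = SingleDeviceSharding(jax.devices()[0])
    restored = pytree_serialization.load("/tmp/jax_ckpt", sharding)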
+""" + +from __future__ import annotations + +from os import PathLike +import os +import re +from typing import Any +from uuid import uuid4, UUID +import json +import asyncio +import threading +from concurrent.futures import ThreadPoolExecutor +import shutil +import logging + +import jax +from jax._src import distributed +from jax._src.api_util import flatten_axes + +from jax.experimental import multihost_utils +from jax.experimental.array_serialization import tensorstore_impl as ts_impl +import jax.experimental.array_serialization.pytree_serialization_utils as utils +from jax._src import path as pathlib +import numpy as np + +logger = logging.getLogger(__name__) + +_THREADING_SAVE_LOCK = threading.Lock() + +_REMOTE_URL_PREFIXES = ['gs://', 's3://'] +_PYTREEDEF_FILE = "pytreedef.json" +_ARCHIVE_NAME = "archive.zip" +_USE_OCDBT = True # a lot of the code relies on this being True +_MAX_PATH_LENGTH = 4096 +_ARRAY_STORE_DIRNAME = "array_store" +_ARRAY_TYPE_FORMAT = "Array({dtype}[{shape}])" +_ARRAY_TYPE_REGEX = r"Array\(([a-zA-Z0-9_]+)\[([0-9, ]*)\]\)" +_MAX_CONCURRENCY = 32 +_TIMEOUT_SEC = 30 + +PyTreeT = Any + +__all__ = ["save", "load", "load_pytreedef", + "nonblocking_load", "nonblocking_save"] + + +def _get_unique_sync_key() -> str | None: + """Generate a thread-local key for ensuring all host finish (de)serializing""" + if jax.process_count() == 1: + return None + # broadcast a thread-local unique barrier name + sync_key_unique = multihost_utils.broadcast_one_to_all( + np.frombuffer(uuid4().bytes, dtype=np.int32)) + sync_key_id = UUID(bytes=np.array(sync_key_unique).tobytes()) + return f"jax_sync_key_{str(sync_key_id)}" + + +def _is_str_same_on_all_hosts(path: str | PathLike[str]) -> bool: + """All-gather the location of the checkpoint and check if it's the same.""" + if jax.process_count() <= 1: + return False + path_b = str(path).encode("utf-8") + if len(path_b) > _MAX_PATH_LENGTH: + raise ValueError(f"Path exceeds maximum length of {_MAX_PATH_LENGTH} in" + " multiprocess case.") + path_array = np.concatenate([ + np.frombuffer(path_b, dtype=np.uint8), np.zeros( + _MAX_PATH_LENGTH - len(path_b), dtype=np.uint8)]) + path_array = multihost_utils.process_allgather(path_array) + return bool(np.all(path_array[0] == path_array[1:])) + + +def _sync_on_key(key: str | None, extra_tag: str = "") -> None: + if key is None: + return + full_key = f"{key}-{extra_tag}" if extra_tag else key + if (client := distributed.global_state.client) is not None: + client.wait_at_barrier(full_key, timeout_in_ms=_TIMEOUT_SEC * 1000) + + +def _is_array_like(x): + return isinstance(x, (jax.Array, np.ndarray)) + + +def _leaf_to_desc(leaf) -> str: + if leaf is None: + return "null" + elif _is_array_like(leaf): + return _ARRAY_TYPE_FORMAT.format( + dtype=leaf.dtype.name, shape=", ".join(map(str, leaf.shape))) + else: + return type(leaf).__name__ + + +def _desc_to_leaf(leaf_desc: str | None) -> str | None | jax.ShapeDtypeStruct: + if leaf_desc is None: + return None + if not re.match(_ARRAY_TYPE_REGEX, leaf_desc): + return leaf_desc + shape_dtype_match = re.match(_ARRAY_TYPE_REGEX, leaf_desc) + assert shape_dtype_match is not None + dtype_str, shape_str = shape_dtype_match.groups() + shape = [int(x.strip()) for x in shape_str.strip("]").strip().split(",") + if len(x.strip()) > 0] + return jax.ShapeDtypeStruct(shape, jax.numpy.dtype(dtype_str)) + + +def _is_remote_path(path: str | PathLike[str]): + """Check whether a path is remote by examining the prefix.""" + # we need to truncate e.g., gs:// to gs:/ because 
pathlib.Path collapses // + return any(str(path).startswith(prefix[:-1]) + for prefix in _REMOTE_URL_PREFIXES) + + +def _norm_path(path: str | PathLike[str]) -> Any: + if _is_remote_path(path): + return pathlib.Path(path) + return pathlib.Path(path).expanduser().resolve() + + +def _rm_dir(root: Any) -> None: + if _is_remote_path(root): + root.rmtree() # pytype: disable=attribute-error + else: + shutil.rmtree(root) + + +def _set_up_destination(root: str | PathLike[str], overwrite: bool, + pytree_repr: dict[str, Any], distinct_locations: bool, + sync_key: str | None) -> dict[str, Any]: + """Inspect the destination, set it up for writing, potentially read existing data.""" + root = _norm_path(root) + if overwrite: + if root.exists() and len(list(root.iterdir())) > 0: + # check that we're only deleting things that come from JAX + # refuse to rm directories containing additional entries + extra_member_paths = [ + path for path in list(root.iterdir()) if path.name not in + (_PYTREEDEF_FILE, _ARCHIVE_NAME, _ARRAY_STORE_DIRNAME)] + + if len(extra_member_paths) != 0: + raise RuntimeError( + "Refusing to work on a directory that is not a previous checkpoint." + f" Unrecognized paths: {extra_member_paths}. Remove them manually" + f" if you're sure you want to use {root} as the checkpoint" + " directory.") + + if (jax.process_index() == 0 or distinct_locations) and root.exists(): + _rm_dir(root) + _sync_on_key(sync_key, "overwrite") + return pytree_repr + else: + if (root.exists() and len(list(root.iterdir())) > 0): # not empty + raise ValueError(f"Files already exist at path: `{root}`, but you" + f" specified `{overwrite=}`") + return pytree_repr + + +def _prepare_directory(root: str | PathLike[str], overwrite: bool, + pytreedef_repr: dict[str, Any], distinct_locations: bool, + sync_key: str | None): + """Prepare the directory: check destination, potentially read existing data + and overwrite. + + Raises: + RuntimeError: If the destination directory cannot be created. 
+ """ + root = _norm_path(root) + # prepare the destination directory, overwrite destination directory or error + pytreedef_repr = _set_up_destination( + root, overwrite, pytreedef_repr, distinct_locations, sync_key) + + if not _is_remote_path(root) and (distinct_locations + or jax.process_index() == 0): + root.mkdir(exist_ok=True) # do not make parents, that's too much + if not root.exists() or not root.is_dir(): + raise RuntimeError(f"Could not create destination directory at {root}") + _sync_on_key(sync_key, "mkdir") + return pytreedef_repr + + +def _write_arrays(array_store_path: Any, arrs: list[Any], + arr_leaf_ids: list[int], ts_specs: list[Any | None], + distinct_locations: bool): + paths = [array_store_path / str(leaf_id) for leaf_id in arr_leaf_ids] + process_idx = None + if not distinct_locations and jax.process_count() > 1: + process_idx = jax.process_index() + default_ts_specs = [ts_impl.get_tensorstore_spec(path, ocdbt=_USE_OCDBT, + process_idx=process_idx, + arr=arr) + for (path, arr) in zip(paths, arrs)] + ts_specs = [ts_impl.merge_nested_ts_specs(default_ts_spec, ts_spec) + for (default_ts_spec, ts_spec) in zip(default_ts_specs, ts_specs)] + + # sanity check the ts specs + if len(ts_specs) > 0: # verify the base path is shared for all arrays + expected_path = ts_specs[0]["kvstore"]["base"]["path"] # shared base path + for ts_spec, arr in zip(ts_specs, arrs): + ts_impl.verify_tensorstore_spec(ts_spec, arr, expected_path, + ocdbt=_USE_OCDBT, check_metadata=True) + + async def _serialize_arrays(): + await asyncio.gather(*[ + ts_impl.async_serialize(arr, ts_spec, primary_host=None) + for (arr, ts_spec) in zip(arrs, ts_specs)]) + + asyncio.run(_serialize_arrays()) + + +def _finalize_array_store(kvstore_path, distinct_locations: bool): + """When multiple processes are writing, they must write to a per-process + location followed by combining them via no-copy links to the final location. + """ + # only in multiprocess case and only process 0 + if distinct_locations or jax.process_count() == 1 or jax.process_index() != 0: + return + dummy_key_path = os.path.join(kvstore_path, "dummy_key") + combined_kvstore = ts_impl.get_tensorstore_spec( + dummy_key_path, ocdbt=True, process_idx=None)["kvstore"] + children_kvstores = [ts_impl.get_tensorstore_spec( + dummy_key_path, ocdbt=True, process_idx=i)["kvstore"] + for i in range(jax.process_count())] + _ = combined_kvstore.pop("path") + _ = [kvstore.pop("path") for kvstore in children_kvstores] + asyncio.run(ts_impl.combine_kvstores(combined_kvstore, children_kvstores)) + + +def _write_pytreedef(directory: Any, pytree_repr: dict[str, Any], + distinct_locations: bool): + """Write the pytreedef to the destination directory and aux data to the archive.""" + if not (jax.process_index() == 0 or distinct_locations): + return + root = _norm_path(directory) + (root / _PYTREEDEF_FILE).write_text(json.dumps(pytree_repr, indent=2)) + + +def _tree_broadcast(a, b, is_leaf=lambda x: x is None): + """Broadcast the prefix tree `a` to the full tree `b` + + Uses `flatten_axes` for better error messages on mismatched arity but allowing + for custom is_leaf in the `a` and `b` trees. 
+ """ + a_leaves, a_struct = jax.tree.flatten(a, is_leaf=is_leaf) + a_idx2leaf_map = dict(enumerate(a_leaves)) + a_idx = jax.tree.unflatten(a_struct, a_idx2leaf_map.keys()) + a_idx_broadcast = flatten_axes("tree_broadcast", + jax.tree.structure(b, is_leaf=is_leaf), a_idx) + return jax.tree.map(lambda i: a_idx2leaf_map[i], a_idx_broadcast) + + +_serialization_executor = ThreadPoolExecutor(max_workers=_MAX_CONCURRENCY) + + +def save(data: PyTreeT, directory: str | PathLike[str], *, + overwrite: bool = True, ts_specs: PyTreeT | None = None) -> None: + """Saves the given data structure to the provided directory path. + + This function provides functionality to serialize and save a data structure + comprising JAX arrays, along with its structure to a given directory. It + leverages `PyTree` for flattening and reconstructing the data structure. + + This is a simple experimental array serialization API, for anything more + complex and for all checkpointing prefer: https://github.com/google/orbax + + Args: + data: The data structure to be saved. Arbitrary composition of JAX arrays, + including nested structures. + directory: The directory path where the data will be saved. A local path or + a remote URL (e.g., gs://, s3://). For remote URLs, `etils` is required. + overwrite: If True, any existing directory with the same name will be + overwritten. + ts_specs: Optional tensorstore specs to use for serialization. If None, + defaults to using the default tensorstore specs. + + Example: + >>> data = {"a": jnp.array([1, 2]), "b": None} + >>> save(data, directory) + """ + with _THREADING_SAVE_LOCK: + return _save(data, directory, overwrite=overwrite, ts_specs=ts_specs) + + +def _save(data: PyTreeT, directory: str | PathLike[str], *, + overwrite: bool = True, ts_specs: PyTreeT | None = None) -> None: + sync_key = _get_unique_sync_key() # get a synchronization key for multi-host + + if _is_remote_path(directory) and not pathlib.epath_installed: + raise RuntimeError("For saving to remote URLs (e.g., gs, s3) you need the" + " `etils` module installed. You can install it using" + " `pip install etils`.") + ts_specs = _tree_broadcast(ts_specs, data, + is_leaf=ts_impl.is_tensorstore_spec_leaf) + data_flat, pytreedef = jax.tree.flatten(data, is_leaf=lambda x: x is None) + if not all(x is None or _is_array_like(x) for x in data_flat): + raise ValueError("For serialization, all leaves must be either None or" + " jax.Array-like objects.") + distinct_locations = not _is_str_same_on_all_hosts(directory) + if jax.process_count() > 1 and distinct_locations: + raise ValueError( + "Saving to different locations on different hosts is not supported," + " because it is extremely fragile. Consider using a single location.") + root = _norm_path(directory) + + # 1. serialize the pytree ################################# + pytreedef_repr = utils.serialize_pytreedef(pytreedef) + pytreedef_repr[utils._LEAF_IDS_KEY] = jax.tree.map(_leaf_to_desc, data_flat) + + pytreedef_repr = _prepare_directory( + root, overwrite, pytreedef_repr, distinct_locations, sync_key) + futures = [] + futures.append(_serialization_executor.submit( + _write_pytreedef, root, pytreedef_repr, distinct_locations)) + + # 2. 
serialize arrays ##################################### + array_store_path = root / _ARRAY_STORE_DIRNAME + arrs = [data for data in data_flat if _is_array_like(data)] + arr_leaf_ids = [i for i, data in enumerate(data_flat) if _is_array_like(data)] + ts_specs_flat = jax.tree.leaves(ts_specs, + is_leaf=ts_impl.is_tensorstore_spec_leaf) + ts_specs_flat = [ts_specs_flat[i] for i in arr_leaf_ids] + futures.append(_serialization_executor.submit( + _write_arrays, array_store_path, arrs, arr_leaf_ids, ts_specs_flat, + distinct_locations)) + + # 3. wait for all futures to complete ##################### + _ = [fut.result() for fut in futures] + _sync_on_key(sync_key, "array_serialization") + + # 4. finalize the array writing ########################### + if len(arr_leaf_ids) > 0 and _USE_OCDBT: + _finalize_array_store(array_store_path, distinct_locations) + # we are done with all async ops here, we can block #### + _sync_on_key(sync_key, "end") + + +def _read_arrays(array_store_path: str | PathLike[str], arr_leaf_ids: list[int], + ts_specs: list[Any], shardings: list[Any]): + # array_store_path = root / _LEAF_DATA_DIR / _ARRAY_STORE_DIRNAME + arr_store_path = _norm_path(array_store_path) + arr_paths = [arr_store_path / str(leaf_id) for leaf_id in arr_leaf_ids] + + # byte limiter to limit number of parallel reads, resizes to largest read + byte_limiter = ts_impl._LimitInFlightBytes(10 * 1024 ** 3) # 10 GB + + default_ts_specs = [ts_impl.get_tensorstore_spec(path, ocdbt=_USE_OCDBT, + process_idx=None) + for path in arr_paths] + ts_specs = [ts_impl.merge_nested_ts_specs(default_ts_spec, ts_spec) + for (default_ts_spec, ts_spec) in zip(default_ts_specs, ts_specs)] + + if len(ts_specs) > 0: # verify the base path is shared for all arrays + expected_path = ts_specs[0]["kvstore"]["base"]["path"] # shared base path + for ts_spec in ts_specs: + ts_impl.verify_tensorstore_spec(ts_spec, arr=None, path=expected_path, + ocdbt=_USE_OCDBT, check_metadata=False) + + async def _deserialize_arrays(): + return await asyncio.gather(*[ + ts_impl.async_deserialize(sharding, ts_spec, byte_limiter=byte_limiter) + for (sharding, ts_spec) in zip(shardings, ts_specs)]) + + return dict(zip(arr_leaf_ids, asyncio.run(_deserialize_arrays()))) + + +def load_pytreedef(directory: str | PathLike[str]) -> PyTreeT: + """Loads a pytree from the given directory. + + This is a simple experimental array serialization API, for anything more + complex and for all checkpointing prefer: https://github.com/google/orbax + + Args: + directory: Directory path to load from. + Returns: + The loaded pytree with arrays represented as jax.ShapeDtypeStruct's. + """ + assert not _is_remote_path(directory) or pathlib.epath_installed, ( + "For checkpointing using remote URLs (e.g., gs, s3) you need `etils`" + " module installed. You can install it using `pip install etils`.") + json_content = (_norm_path(directory) / _PYTREEDEF_FILE).read_text() + raw_tree = json.loads(json_content) + leaves = map(_desc_to_leaf, raw_tree[utils._LEAF_IDS_KEY]) + return jax.tree.unflatten(utils.deserialize_pytreedef(raw_tree), leaves) + + +def load(directory: str | PathLike[str], shardings: PyTreeT, *, + mask: PyTreeT | None = None, ts_specs: PyTreeT | None = None + ) -> PyTreeT: + """Loads and reconstructs a data structure from a directory. + + This is a simple experimental array serialization API, for anything more + complex and for all checkpointing prefer: https://github.com/google/orbax + + Args: + directory: Directory path where the data is stored. 
+ shardings: Sharding strategy for array objects. If None, defaults to + single device sharding on the default device. + mask: boolean prefix tree for partial loading, will return None for False + leaves. + ts_specs: Optional tensorstore specs to use for deserialization. If None, + defaults to using the default tensorstore specs. + + Returns: + Reconstructed data. + + Example: + >>> save(data, directory) + >>> restored_data = load(directory, SingleDeviceSharding(jax.devices()[0])) + """ + assert not _is_remote_path(directory) or pathlib.epath_installed, ( + "For checkpointing using remote URLs (e.g., gs, s3) you need `etils`" + " module installed. You can install it using `pip install etils`.") + + root = _norm_path(directory) + assert root.is_dir(), f"Checkpoint directory {root} does not exist" + is_leaf = lambda x: x is None + + # deserialize PyTreeDef + pytree = load_pytreedef(directory) + # broadcast the (prefix) shardings and tensorstore specs to the full pytree + shardings = _tree_broadcast(shardings, pytree) + ts_specs = _tree_broadcast(ts_specs, pytree, + is_leaf=ts_impl.is_tensorstore_spec_leaf) + if mask is not None: + _prefix_mask = lambda m, x: jax.tree.map(lambda _: None, x) if not m else x + pytree = jax.tree.map(_prefix_mask, mask, pytree) + pytreedef = jax.tree.structure(pytree, is_leaf=is_leaf) + leaf_ids_flat = jax.tree.leaves(pytree, is_leaf=is_leaf) + shardings_flat = jax.tree.leaves(shardings, is_leaf=is_leaf) + ts_specs_flat = jax.tree.leaves(ts_specs, + is_leaf=ts_impl.is_tensorstore_spec_leaf) + + # deserialize array objects + arr_leaf_ids = [i for i, leaf_id in enumerate(leaf_ids_flat) + if leaf_id is not None] + shardings_flat = [shardings_flat[i] for i in arr_leaf_ids] + ts_specs_flat = [ts_specs_flat[i] for i in arr_leaf_ids] + + arrs_fut = _serialization_executor.submit( + _read_arrays, root / _ARRAY_STORE_DIRNAME, arr_leaf_ids, ts_specs_flat, + shardings_flat) + + arrs = arrs_fut.result() + filled_values = [arrs.get(i, None) for i, _ in enumerate(leaf_ids_flat)] + return jax.tree.unflatten(pytreedef, filled_values) + + +def nonblocking_save(data: PyTreeT, directory: str | PathLike[str], *, + overwrite: bool = True, ts_specs: PyTreeT | None = None + ) -> utils.PyTreeFuture: + """Nonblocking alias of save, return an awaitable future with a pytree stub. + + This is a simple experimental array serialization API, for anything more + complex and for all checkpointing prefer: https://github.com/google/orbax + + Examples: + >>> fut = nonblocking_save(data, directory) + >>> print(fut.pytree) # a pytree of jax.ShapeDtypeStruct's + >>> print(fut.result()) # None, blocking until the serialization is done + """ + # start serialization immediately + fut = utils.PyTreeFuture(_serialization_executor.submit( + save, data, directory, overwrite=overwrite, ts_specs=ts_specs)) + # construct a nice looking pytree representing the nodes being read + fut.pytree = jax.tree.map(lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype) + if _is_array_like(x) else x, data) + return fut + + +def nonblocking_load(directory: str | PathLike[str], shardings: PyTreeT, *, + mask: PyTreeT | None = None, + ts_specs: PyTreeT | None = None) -> utils.PyTreeFuture: + """Nonblocking alias of load, return an awaitable future with a pytree stub. 
+ + This is a simple experimental array serialization API, for anything more + complex and for all checkpointing prefer: https://github.com/google/orbax + + Examples: + >>> fut = nonblocking_load(directory) + >>> print(fut.pytree) # a pytree of jax.ShapeDtypeStruct + >>> print(fut.result()) # the fully populated pytree + """ + # TODO(rdyro): the awaitable future output is a workaround + # it should return the fully populated pytree instead of just + # jax.ShapeDtypeStruct for arrays by constructing them asynchronously + fut = utils.PyTreeFuture(_serialization_executor.submit( + load, directory, shardings, mask=mask, ts_specs=ts_specs)) + fut.pytree = load_pytreedef(directory) + return fut diff --git a/jax/experimental/array_serialization/pytree_serialization_utils.py b/jax/experimental/array_serialization/pytree_serialization_utils.py new file mode 100644 index 000000000000..a7d37eeab5f8 --- /dev/null +++ b/jax/experimental/array_serialization/pytree_serialization_utils.py @@ -0,0 +1,85 @@ +# Copyright 2021 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# + +# # Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Utilities for representing pytreedefs in a serializable format. +""" + +import base64 +import logging +from types import ModuleType +from concurrent.futures import Future +from typing import Any, TypeVar + +import jax +from jax._src.export.serialization import (flatbuffers, _serialize_pytreedef, + _deserialize_pytreedef_to_pytree, + ser_flatbuf) +from jax.export import register_pytree_node_serialization # pylint: disable=unused-import + +T = TypeVar("T") +PickleModule = ModuleType +logger = logging.getLogger(__name__) + +_READABLE_PYTREE_SERIALIZATION = True +_TREE_REPR_KEY = "__jax_pytreedef_repr" +_LEAF_IDS_KEY = "__jax_leaf_ids" + +_NOT_REGISTERED_MESSAGE = ( + " * If you want to register a custom leaf, register it via" + " `register_pytree_leaf_serialization` first.\n" + " * If you want to register a custom node, register is via" + " `register_pytree_node_serialization`") + +__all__ = ["serialize_pytreedef", "deserialize_pytreedef", + "register_pytree_node_serialization"] + +class PyTreeFuture(Future[Any]): + """A wrapper around a Future that makes it look like an async function.""" + def __init__(self, future: Future[Any]): + self._future, self.pytree = future, None + + def done(self): + return self._future.done() + + def result(self, *args, **kw): + return self._future.result(*args, **kw) + + def __await__(self): + while not self.done(): + yield + return self.result() + + +def _cls2typerepr(cls): + return f"{cls.__module__}.{cls.__name__}" + + +def serialize_pytreedef(node) -> dict[str, Any]: + builder = flatbuffers.Builder(65536) + exported = _serialize_pytreedef(builder, node) + builder.Finish(exported) + root_repr = base64.b64encode(builder.Output()).decode("utf-8") + leaf_count = node.num_leaves + pytree_repr = {_TREE_REPR_KEY: root_repr, + _LEAF_IDS_KEY: list(range(leaf_count))} + return pytree_repr + + +def deserialize_pytreedef(pytreedef_repr: dict[str, Any]): + buf = 
base64.b64decode(pytreedef_repr[_TREE_REPR_KEY]) + exp = ser_flatbuf.PyTreeDef.GetRootAs(buf) + treestruct = jax.tree.structure(_deserialize_pytreedef_to_pytree(exp)) + return treestruct diff --git a/jax/experimental/array_serialization/serialization.py b/jax/experimental/array_serialization/serialization.py index 8a082b6e912d..44b2eb9ccd03 100644 --- a/jax/experimental/array_serialization/serialization.py +++ b/jax/experimental/array_serialization/serialization.py @@ -17,34 +17,43 @@ import abc import asyncio -from collections.abc import Awaitable, Callable, Sequence -from functools import partial +from collections.abc import Callable, Sequence +import functools import itertools import logging -import os import re import threading import time -from typing import Any, Optional +from typing import Any import jax from jax._src import array from jax._src import distributed from jax._src import sharding -from jax._src.layout import Layout from jax._src import typing from jax._src import util -from jax._src.lib import xla_extension as xe -import jax.numpy as jnp -import numpy as np -import tensorstore as ts +from jax._src.layout import Format +from jax._src.lib import _jax +from jax.experimental.array_serialization import tensorstore_impl as ts_impl +# ruff: noqa: F401 +# pylint: disable=unused-import +# import tensorstore-backed methods for backward compatibility. +from jax.experimental.array_serialization.tensorstore_impl import ( + _run_deserialization as run_deserialization, + _run_serialization as run_serialization, + async_serialize, async_deserialize, _TS_CONTEXT as TS_CONTEXT, + _DEFAULT_BASE_DRIVER as _DEFAULT_DRIVER, _LimitInFlightBytes) + +# for compatibility with older zarr format +_get_metadata = functools.partial(ts_impl._get_tensorstore_metadata, + driver='zarr') +get_tensorstore_spec = functools.partial(ts_impl.get_tensorstore_spec, + driver='zarr', ocdbt=False) +# pylint: enable=unused-import -TS_CONTEXT = ts.Context({'file_io_concurrency': {'limit': 128}}) -_REMOVED_VALUE = 'Value removed' _CHECKPOINT_SUCCESS = 'checkpoint_write_success' _module_unique_count = itertools.count() -_DEFAULT_DRIVER = 'file' _DISTRIBUTED_SYSTEM_MSG = ( 'Please initialize the distributed system via ' '`jax.distributed.initialize()` at the start of your program.') @@ -54,7 +63,7 @@ {'driver': 's3', 'path_regex': None}, ] -class BarrierTimeoutException(Exception): +class BarrierTimeoutError(Exception): pass _BARRIER_TIMED_OUT_MSG = ( @@ -66,68 +75,6 @@ class BarrierTimeoutException(Exception): logger = logging.getLogger(__name__) -async def create_async_array_from_callback( - global_shape: array.Shape, - inp_sharding: jax.sharding.Sharding, - data_callback: Callable[[array.Index, jax.Device], Awaitable[jax.Array]], -): - device_to_index_map = inp_sharding.devices_indices_map(global_shape) - addressable_da = inp_sharding._addressable_device_assignment - future_arrays = [data_callback(device_to_index_map[d], d) - for d in addressable_da] - dbs = await asyncio.gather(*future_arrays) - return array.make_array_from_single_device_arrays( - global_shape, inp_sharding, dbs) - - -def _get_metadata(arr): - local_shape = arr.addressable_data(0).shape - return { - 'compressor': {'id': 'zstd'}, - 'shape': arr.shape, - 'chunks': np.array(np.maximum(1, local_shape)), - } - - -def _spec_has_metadata(tree): - if not isinstance(tree, dict): - return False - return 'metadata' in tree or any( - _spec_has_metadata(subtree) for _, subtree in tree.items()) - -def _get_kvstore_for_gcs(ckpt_path: str): - m = 
re.fullmatch('^gs://([^/]*)/(.*)$', ckpt_path, re.DOTALL) - if m is None: - raise ValueError('The ckpt_path should contain the bucket name and the ' - f'file path inside the bucket. Got: {ckpt_path}') - gcs_bucket = m.group(1) - path_without_bucket = m.group(2) - return {'driver': 'gcs', 'bucket': gcs_bucket, 'path': path_without_bucket} - -def get_tensorstore_spec(ckpt_path: str, ocdbt: bool = False): - # Normalize path to exclude trailing '/'. In GCS path case, we will need to - # fix the path prefix to add back the stripped '/'. - ckpt_path = os.path.normpath(ckpt_path).replace('gs:/', 'gs://') - is_gcs_path = ckpt_path.startswith('gs://') - spec = {'driver': 'zarr', 'kvstore': {}} - if ocdbt: - if not is_gcs_path and not os.path.isabs(ckpt_path): - raise ValueError(f'Checkpoint path should be absolute. Got {ckpt_path}') - base_path = os.path.dirname(ckpt_path) - spec['kvstore'] = { - 'driver': 'ocdbt', - 'base': base_path if is_gcs_path else f'{_DEFAULT_DRIVER}://{base_path}', - 'path': os.path.basename(ckpt_path), - } - else: - if is_gcs_path: - spec['kvstore'] = _get_kvstore_for_gcs(ckpt_path) - else: - spec['kvstore'] = {'driver': _DEFAULT_DRIVER, 'path': ckpt_path} - - return spec - - def is_remote_storage(tspec: dict[str, Any] | str) -> bool: """Detect if user is using cloud storages. @@ -157,278 +104,6 @@ def is_remote_storage(tspec: dict[str, Any] | str) -> bool: return False - -# Lifted from T5X. -class _LimitInFlightBytes: - """Limits in-flight bytes when reading/writing checkpoints per process.""" - - def __init__(self, num_bytes): - self._max_bytes = num_bytes - self._available_bytes = num_bytes - self._cv = asyncio.Condition(lock=asyncio.Lock()) - - async def wait_for_bytes(self, requested_bytes): - if requested_bytes > self._max_bytes: - raise ValueError('Requested more bytes than we reserved space for: ' - f'{requested_bytes} > {self._max_bytes}') - async with self._cv: - await self._cv.wait_for(lambda: self._available_bytes > requested_bytes) - self._available_bytes -= requested_bytes - assert self._available_bytes >= 0 - - async def release_bytes(self, requested_bytes): - async with self._cv: - self._available_bytes += requested_bytes - assert self._available_bytes <= self._max_bytes - self._cv.notify_all() - - -async def transfer_shard_to_host(shard: array.Shard) -> np.ndarray: - data = shard.data - has_pinned_host = any( - m.kind == "pinned_host" for m in shard.device.addressable_memories()) - if has_pinned_host: - # If available, transfer to pinned host memory - sharding = jax.sharding.SingleDeviceSharding(shard.device, - memory_kind="pinned_host") - data = jax.device_put(data, sharding) - else: - data.copy_to_host_async() - # Allow other transfers to be scheduled simultaneously - await asyncio.sleep(0) - # Ensure that jax.Array's internal numpy array can be zero-copied. Tensorstore - # implicitly converts the written data to a numpy array, and would otherwise - # silently copy host-to-host. - return np.array(data, copy=False) - - -async def async_serialize( - arr_inp, - tensorstore_spec, - commit_future=None, - context=TS_CONTEXT, - primary_host: int | None = 0, - replica_id: int = 0, - transaction: Optional[ts.Transaction] = None, -): - """Serialize an array using TensorStore. - - Args: - arr_inp: The array to serialize. - tensorstore_spec: The tensorstore spec to use. - commit_future: A list of futures that will be appended to. The futures can - be awaited asynchronously. If None, the futures will be awaited - synchronously by this method. 
- context: ts.Context instance. - primary_host: Primary host, which indicates the host that will be treated as - the "leader". If None, all hosts are treated as the primary. DO NOT USE - unless you are sure you know what you are doing. - replica_id: Allows overriding the shard replica id that will be saved. DO - NOT USE unless you are sure you know what you are doing. - transaction: TensorStore transaction to use for opening and writing the - array. If not specified, a non-transactional write will be used. - """ - if (isinstance(arr_inp, array.ArrayImpl) and jax.process_count() > 1 and - arr_inp.is_fully_addressable): - raise ValueError( - f'Passing fully addressable arrays to a multiprocess ' - f'serialization is not allowed, as this may lead to a race condition ' - f'between processes. Serialization have failed for the array with ' - f'the path "{tensorstore_spec["kvstore"]["path"]}".') - - # 'metadata' may not be present at the top level (for example, if we are using - # a 'cast' driver). - if not _spec_has_metadata(tensorstore_spec): - tensorstore_spec['metadata'] = _get_metadata(arr_inp) - - # Set dtype if it's not in spec - if 'dtype' not in tensorstore_spec: - tensorstore_spec['dtype'] = jnp.dtype(arr_inp.dtype).name - - # If primary_host is None, all hosts will checkpoint. This is used - # for checkpointing to local filesystem. - if primary_host is None or jax.process_index() == primary_host: - open_future = ts.open( - ts.Spec(tensorstore_spec), - create=True, - open=True, - context=context, - transaction=transaction, - ) - # Asynchronous case. - if commit_future is not None: - assert isinstance(commit_future, list) - commit_future.append(open_future) - else: - await open_future - - # `ts.open` runs twice for process `primary_host` because for the first time, - # we just get the future to be awaited upon in the background thread. The - # second one runs with `assume_metadata=True` which does no I/O operation and - # returns the tensorstore object. - # For every process other than `primary_host`, we open with - # `assume_metadata=True`. - t = await ts.open( - ts.Spec(tensorstore_spec), - open=True, - assume_metadata=True, - context=context, - transaction=transaction, - ) - - async def _write_array(shard): - if shard.replica_id == replica_id: - data = await transfer_shard_to_host(shard) - write_future = t[shard.index].write( - data, - # Avoid additional copy of input array into the TensorStore chunk - # cache. If `arr_inp` is a jax.Array, the result of converting - # it to a NumPy array, as is done internally by TensorStore, is - # guaranteed to be immutable and therefore it is safe to retain a - # reference indefinitely. 
- can_reference_source_data_indefinitely=isinstance( - arr_inp, array.ArrayImpl - ), - ) - if commit_future is not None: - assert isinstance(commit_future, list) - commit_future.append(write_future.commit) - await write_future.copy - else: - await write_future.commit - - local_shards = arr_inp.addressable_shards - future_write_state = jax.tree_util.tree_map(_write_array, local_shards) - return await asyncio.gather(*future_write_state) - - -def run_serialization(arrays, tensorstore_specs): - async def _run_serializer(): - future_writer = jax.tree_util.tree_map(async_serialize, arrays, tensorstore_specs) - return await asyncio.gather(*future_writer) - asyncio.run(_run_serializer()) - - -def estimate_read_memory_footprint(t: ts.TensorStore, - domain: ts.IndexDomain) -> int: - rank = t.rank - num_bytes = t.dtype.numpy_dtype.itemsize - chunk_template = t.chunk_layout.read_chunk_template - if domain is None: - domain = t.domain - origin = domain.origin - shape = domain.shape - chunk_origin = chunk_template.origin - chunk_shape = chunk_template.shape - - # Some TensorStore drivers are not chunked, e.g. the inline 'array' driver. - # For those, instead of returning a near-infinite memory footprint, estimate - # the footprint as the entire shape. - for i in range(rank): - if not chunk_template[i].finite: - return domain.size * num_bytes - - # Otherwise, if we have a chunked driver, estimate based on chunk size. - for i in range(rank): - origin_value = origin[i] - chunk_origin_value = chunk_origin[i] - chunk_size = chunk_shape[i] - lower = origin_value - chunk_origin_value - upper = origin_value + shape[i] - chunk_origin_value - lower_aligned = lower // chunk_size * chunk_size - upper_aligned = -(-upper // chunk_size) * chunk_size - num_bytes *= (upper_aligned - lower_aligned) - - return num_bytes - - -async def async_deserialize( - user_in_sharding: jax.sharding.Sharding | Layout, - tensorstore_spec: ts.Spec | dict[str, Any], - global_shape: Sequence[int] | None = None, - dtype=None, - byte_limiter: _LimitInFlightBytes | None = None, - context=TS_CONTEXT, - assume_metadata: bool = False, -): - in_sharding = (user_in_sharding.sharding - if isinstance(user_in_sharding, Layout) else user_in_sharding) - if not isinstance(in_sharding, jax.sharding.Sharding): - raise ValueError( - 'sharding passed to deserialization should be specified, concrete and' - f' an instance of `jax.sharding.Sharding`. Got {in_sharding}') - dll = (user_in_sharding.device_local_layout - if isinstance(user_in_sharding, Layout) else None) - t = await ts.open( - tensorstore_spec, - open=True, - assume_metadata=assume_metadata, - context=context, - ) - shape = t.shape if global_shape is None else global_shape - new_shard_shape = in_sharding.shard_shape(tuple(shape)) - - async def cb(index: array.Index, device: jax.Device): - requested_domain = ts.IndexTransform(input_shape=shape)[index].domain - restricted_domain = t.domain.intersect(requested_domain) - requested_bytes = estimate_read_memory_footprint(t, restricted_domain) - # Limit the bytes read for every shard. - if byte_limiter is not None: - await byte_limiter.wait_for_bytes(requested_bytes) - # This maybe needed because the shape the array was saved with is smaller - # than the requested shape of the array in which it will be reloaded. So - # the extra values will be filled with 0s. 
- out = np.zeros(new_shard_shape, dtype=t.dtype.numpy_dtype) - await ts.array(out)[ts.d[:].translate_to[requested_domain.origin]][ - restricted_domain].write(t[restricted_domain]) - if dtype is not None: - # Cast while reloading on process to avoid 2 copies on device if the - # casting is done on device. - out = out.astype(dtype) - # Convert to jnp array so that layouts are initialized properly for - # sub-byte dtypes. - # TODO(yashkatariya): This is a band-aid fix. Figure out a better way to - # make this work. - if out.dtype == jnp.int4: - out = jnp.asarray(out) # type: ignore - result = jax.device_put( - out, Layout(dll, jax.sharding.SingleDeviceSharding(device))) - if byte_limiter is not None: - # NB: `out` actually might not be ready for garbage collection by the - # time we call release_bytes . Thus peak memory usage still might grow - # beyond what byte_limiter limit suggests it should. The simplest option - # would be to call `result.block_until_ready()`` here. However it - # also comes with ~15-20% perf penalty as we would be waiting for CPU->GPU - # transfer instead of loading data. In the future, if memory pressure - # becomes a problem, we can instead instrument bytelimiter to - # keep track of all in-flight tensors and only block_until_ready, if byte - # limiter hits the limit to get reduced memory usage, without losing - # performance in common use cases. - await byte_limiter.release_bytes(requested_bytes) - return result - - return await create_async_array_from_callback(tuple(shape), in_sharding, cb) - - -def run_deserialization(shardings: Sequence[sharding.Sharding | Layout], - tensorstore_specs: Sequence[dict[str, Any]], - global_shapes: Sequence[array.Shape] | None = None, - dtypes: Sequence[typing.DTypeLike] | None = None, - concurrent_gb: int = 32): - concurrent_bytes = concurrent_gb * 10**9 - - async def _run_deserializer(): - # Object should be created once per process. - byte_limiter = _LimitInFlightBytes(concurrent_bytes) - future_arrays = jax.tree_util.tree_map( - partial(async_deserialize, byte_limiter=byte_limiter), - list(shardings), list(tensorstore_specs), - [None] * len(tensorstore_specs) if global_shapes is None else global_shapes, - [None] * len(tensorstore_specs) if dtypes is None else dtypes) - return await asyncio.gather(*future_arrays) - return asyncio.run(_run_deserializer()) - - def _get_key(key: int): return f'tensorstore_checkpoint_{key}' @@ -510,8 +185,7 @@ def __init__(self, timeout_secs=300): if jax.process_count() > 1 and distributed.global_state.client is None: raise ValueError(_DISTRIBUTED_SYSTEM_MSG) - if jax.process_count() > 1: - self._client = distributed.global_state.client + self._client = distributed.global_state.client self._count = None def __del__(self): @@ -533,7 +207,9 @@ def _thread_func(self): logger.info('Finished committing to storage layer by process: %s', current_process) + key_for_barrier = None if process_count > 1: + assert self._client is not None # All processes will wait at the barrier. When all processes are at the # barrier, the barrier will be satisfied. If not, then it will timeout. 
key_for_barrier = _get_key(self._count) @@ -544,9 +220,11 @@ def _thread_func(self): current_process) if current_process == 0: - self._on_commit_callback() - logger.info('on_commit_callback successfully ran!') + if self._on_commit_callback is not None: + self._on_commit_callback() + logger.info('on_commit_callback successfully ran!') if process_count > 1: + assert self._client is not None self._client.key_value_set(key_for_barrier, _CHECKPOINT_SUCCESS) logger.info('Process 0 successfully set key %s in the kv store', key_for_barrier) @@ -555,7 +233,7 @@ def _thread_func(self): '/jax/checkpoint/write/async/thread_duration_sec', time.time() - thread_start_time) - except Exception as e: + except Exception as e: # pylint: disable=broad-except self._exception = e def _start_async_commit(self, on_commit_callback): @@ -570,9 +248,9 @@ def check_for_errors(self): # Clears self._exception so it is only raised once. exception = self._exception self._exception = None - if (isinstance(exception, xe.XlaRuntimeError) and + if (isinstance(exception, _jax.XlaRuntimeError) and 'DEADLINE_EXCEEDED: Barrier timed out' in str(exception)): - raise BarrierTimeoutException( + raise BarrierTimeoutError( '\n'.join([str(exception), _BARRIER_TIMED_OUT_MSG])) raise exception # pylint: disable=raising-bad-type @@ -586,6 +264,7 @@ def wait_until_finished(self): logger.info('Error check finished successfully') if jax.process_count() > 1 and self._count is not None: + assert self._client is not None # Block until process 0 writes success value to the key value store. # If it fails to write it, then `blocking_key_value_get` will time out. get_key = _get_key(self._count) @@ -605,8 +284,8 @@ def serialize( arrays, tensorstore_specs, *, - on_commit_callback, - transaction: Optional[ts.Transaction] = None, + on_commit_callback: Callable[[], None] | None = None, + transaction: ts_impl.Transaction | None = None, ): """Serializes Arrays or Arrays via TensorStore asynchronously. 
@@ -635,11 +314,11 @@ def serialize( logger.info('Waiting for previous serialization to finish.') self.wait_until_finished() - commit_futures: list[ts.Future] = [] + commit_futures: list[ts_impl.Future] = [] async def _run_serializer(): future_writer = jax.tree_util.tree_map( - lambda arr_inp, tensorstore_spec: async_serialize( + lambda arr_inp, tensorstore_spec: ts_impl.async_serialize( arr_inp, tensorstore_spec, commit_future=commit_futures, @@ -649,7 +328,6 @@ async def _run_serializer(): tensorstore_specs, ) return await asyncio.gather(*future_writer) - asyncio.run(_run_serializer()) self._add_futures(commit_futures) @@ -663,25 +341,25 @@ def serialize_with_paths( arrays: Sequence[jax.Array], paths: Sequence[str], *, - on_commit_callback, - transaction: Optional[ts.Transaction] = None, + on_commit_callback: Callable[[], None] | None = None, + transaction: ts_impl.Transaction | None = None, ): tspecs = jax.tree.map(get_tensorstore_spec, paths) - self.serialize( + return self.serialize( arrays, tspecs, on_commit_callback=on_commit_callback, transaction=transaction, ) - def deserialize(self, shardings: Sequence[sharding.Sharding | Layout], + def deserialize(self, shardings: Sequence[sharding.Sharding | Format], tensorstore_specs: Sequence[dict[str, Any]], global_shapes: Sequence[array.Shape] | None = None, dtypes: Sequence[typing.DTypeLike] | None = None, concurrent_gb: int = 32): self.wait_until_finished() - return run_deserialization(shardings, tensorstore_specs, - global_shapes, dtypes, concurrent_gb) + return ts_impl._run_deserialization( + shardings, tensorstore_specs, global_shapes, dtypes, concurrent_gb) def deserialize_with_paths( self, shardings: Sequence[sharding.Sharding], diff --git a/jax/experimental/array_serialization/serialization_test.py b/jax/experimental/array_serialization/serialization_test.py index 9f4539fc63c8..eab23443f545 100644 --- a/jax/experimental/array_serialization/serialization_test.py +++ b/jax/experimental/array_serialization/serialization_test.py @@ -12,31 +12,69 @@ # See the License for the specific language governing permissions and # limitations under the License. 
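Taken together, the manager methods above give the following high-level flow. A minimal single-process sketch, assuming an absolute local checkpoint path (the path and array values here are made up for illustration):

import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental.array_serialization import serialization

arr = jnp.arange(16.0).reshape(4, 4)
path = '/tmp/ckpt/arr0'                      # must be an absolute path

manager = serialization.GlobalAsyncCheckpointManager()
manager.serialize_with_paths(
    [arr], [path], on_commit_callback=lambda: print('committed'))
manager.wait_until_finished()                # re-raises any background error

sharding = jax.sharding.SingleDeviceSharding(jax.devices()[0])
restored, = manager.deserialize_with_paths([sharding], [path])
np.testing.assert_array_equal(np.asarray(restored), np.asarray(arr))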
+# pylint: disable=g-importing-member import asyncio -import math +from dataclasses import dataclass from functools import partial +import json +import logging +import math import os import pathlib -import tracemalloc as tm +import pickle +import tempfile +import threading +import time +import tracemalloc as tm +from typing import Any from absl.testing import absltest from absl.testing import parameterized import jax -import jax.numpy as jnp +from jax._src import array from jax._src import config from jax._src import test_util as jtu -from jax._src import array -from jax.sharding import NamedSharding, GSPMDSharding, SingleDeviceSharding -from jax.sharding import PartitionSpec as P +from jax._src.export._export import ( + deserialization_registry as node_deserialization_registry) +from jax._src.export._export import ( + serialization_registry as node_serialization_registry) +from jax._src.layout import DeviceLocalLayout as DLL +from jax._src.layout import Format +from jax.experimental.array_serialization import pytree_serialization from jax.experimental.array_serialization import serialization -from jax.experimental.layout import Layout, DeviceLocalLayout as DLL +from jax.experimental.array_serialization import tensorstore_impl as ts_impl + +from jax.experimental.array_serialization.pytree_serialization_utils import ( + register_pytree_node_serialization) + +import jax.numpy as jnp + +from jax.sharding import NamedSharding +from jax.sharding import PartitionSpec as P +from jax.sharding import SingleDeviceSharding import numpy as np import tensorstore as ts +# pylint: enable=g-importing-member jax.config.parse_flags_with_absl() jtu.request_cpu_devices(8) +_default_sharding = None + + +def tree_load(*args, **kw): + return pytree_serialization.load(*args, shardings=_default_sharding, **kw) + +tree_save = pytree_serialization.save +tree_load_pytreedef = pytree_serialization.load_pytreedef + + +def _get_replicated_sharding(devices): + return NamedSharding( + jax.make_mesh(np.shape(devices), P('x'), devices=devices), P()) + + class CheckpointTest(jtu.JaxTestCase): def _on_commit_callback(self, temp_ckpt_dir, final_ckpt_dir): @@ -87,19 +125,21 @@ def test_memory_consumption(self): inp = array.make_array_from_callback( inp_shape, sharding, lambda idx: src[idx]) - ckpt_dir = pathlib.Path(self.create_tempdir('memprof').full_path) + ckpt_dir = pathlib.Path(self.create_tempdir( + 'memprof-deserialize').full_path) tspec = serialization.get_tensorstore_spec(str(ckpt_dir)) manager = serialization.GlobalAsyncCheckpointManager() manager.serialize( [inp], [tspec], - on_commit_callback=partial(self._on_commit_callback, ckpt_dir, ckpt_dir)) + on_commit_callback=partial( + self._on_commit_callback, ckpt_dir, ckpt_dir)) manager.wait_until_finished() async def deserialize_with_byte_limit(): r = await serialization.async_deserialize( - sharding, tspec, inp_shape, - byte_limiter=serialization._LimitInFlightBytes(4_200_000)) + sharding, tspec, inp_shape, + byte_limiter=serialization._LimitInFlightBytes(4_200_000)) r.block_until_ready() tm.start() @@ -122,6 +162,7 @@ async def deserialize_with_byte_limit(): self.assertGreater(peak, 30_000_000) tm.stop() + @jtu.thread_unsafe_test() def test_memory_consumption_for_save(self): global_mesh = jtu.create_mesh((1, 1), ('x', 'y')) inp_shape = (16 * 1024, 16 * 1024) @@ -132,25 +173,24 @@ def test_memory_consumption_for_save(self): inp = array.make_array_from_callback( inp_shape, sharding, lambda idx: src[idx] ) - ckpt_dir = 
pathlib.Path(self.create_tempdir('memprofsave').full_path) - tspec = serialization.get_tensorstore_spec(str(ckpt_dir)) + ckpt_dir = pathlib.Path(self.create_tempdir( + 'memprofsave-serialize').full_path) + tspec = ts_impl.get_tensorstore_spec(str(ckpt_dir), ocdbt=False, + driver='zarr3') tspec['metadata'] = { 'shape': inp.shape, - 'compressor': None, - 'chunks': inp.shape, + 'data_type': jnp.dtype(inp.dtype).name, + 'chunk_grid': { + 'name': 'regular', + 'configuration': {'chunk_shape': np.array(np.maximum(1, inp.shape))} + } } - is_cpu = jtu.test_device_matches(['cpu']) tm.start() try: manager = serialization.GlobalAsyncCheckpointManager() - manager.serialize( - [inp], - [tspec], - on_commit_callback=partial( - self._on_commit_callback, ckpt_dir, ckpt_dir - ), - ) + manager.serialize([inp], [tspec], on_commit_callback=partial( + self._on_commit_callback, ckpt_dir, ckpt_dir)) manager.wait_until_finished() unused_current, peak = tm.get_traced_memory() self.assertLess(peak, src.nbytes * (1 * (not is_cpu) + 0.5)) @@ -176,7 +216,8 @@ def test_checkpointing_with_path_variant(self): manager = serialization.GlobalAsyncCheckpointManager() manager.serialize_with_paths( [a1], ckpt_paths, - on_commit_callback=partial(self._on_commit_callback, ckpt_dir, ckpt_dir)) + on_commit_callback=partial( + self._on_commit_callback, ckpt_dir, ckpt_dir)) manager.wait_until_finished() m1, = manager.deserialize_with_paths( @@ -201,7 +242,8 @@ def test_checkpointing_jax_array(self): inp_shape, NamedSharding(global_mesh, pspec), lambda idx: global_input_data1[idx]) ckpt_dir = pathlib.Path(self.create_tempdir('ckpt').full_path) - ckpt_path1 = pathlib.Path(self.create_tempdir(f'{ckpt_dir}/first').full_path) + ckpt_path1 = pathlib.Path( + self.create_tempdir(f'{ckpt_dir}/first').full_path) # Second Array global_input_data2 = np.arange( @@ -209,7 +251,8 @@ def test_checkpointing_jax_array(self): a2 = array.make_array_from_callback( inp_shape, NamedSharding(global_mesh, pspec), lambda idx: global_input_data2[idx]) - ckpt_path2 = pathlib.Path(self.create_tempdir(f'{ckpt_dir}/second').full_path) + ckpt_path2 = pathlib.Path( + self.create_tempdir(f'{ckpt_dir}/second').full_path) # Third Array def cb3(_): @@ -217,15 +260,17 @@ def cb3(_): global_mesh1d = jtu.create_mesh((8,), ('x',)) a3 = array.make_array_from_callback( (0,), NamedSharding(global_mesh1d, P(None)), cb3) - ckpt_path3 = pathlib.Path(self.create_tempdir(f'{ckpt_dir}/third').full_path) + ckpt_path3 = pathlib.Path( + self.create_tempdir(f'{ckpt_dir}/third').full_path) ckpt_paths = [str(ckpt_path1), str(ckpt_path2), str(ckpt_path3)] - tspecs = jax.tree_util.tree_map(serialization.get_tensorstore_spec, ckpt_paths) + tspecs = jax.tree.map(serialization.get_tensorstore_spec, ckpt_paths) manager = serialization.GlobalAsyncCheckpointManager() manager.serialize( [a1, a2, a3], tspecs, - on_commit_callback=partial(self._on_commit_callback, ckpt_dir, ckpt_dir)) + on_commit_callback=partial( + self._on_commit_callback, ckpt_dir, ckpt_dir)) manager.wait_until_finished() m1, m2, m3 = serialization.run_deserialization( @@ -295,9 +340,8 @@ def cb3(_): ckpt_path3 = ckpt_dir / 'third' ckpt_paths = [str(ckpt_path1), str(ckpt_path2), str(ckpt_path3)] - tspecs = jax.tree_util.tree_map( - lambda p: serialization.get_tensorstore_spec(p, ocdbt=True), ckpt_paths - ) + tspecs = jax.tree.map(partial(ts_impl.get_tensorstore_spec, ocdbt=True), + ckpt_paths) manager = serialization.GlobalAsyncCheckpointManager() with ts.Transaction(atomic=True) as transaction: @@ -312,13 +356,8 @@ def cb3(_): 
manager.wait_until_finished() m1, m2, m3 = serialization.run_deserialization( - [ - NamedSharding(global_mesh, pspec), - NamedSharding(global_mesh, P('x')), - NamedSharding(global_mesh1d, P(None)), - ], - tspecs, - ) + [NamedSharding(global_mesh, pspec), NamedSharding(global_mesh, P('x')), + NamedSharding(global_mesh1d, P(None))], tspecs) self.assertIsInstance(m1, array.ArrayImpl) self.assertArraysEqual( @@ -367,12 +406,13 @@ def cb1(index): ckpt_dir = pathlib.Path(self.create_tempdir('first').full_path) ckpt_paths = [str(ckpt_dir)] - tspecs = jax.tree_util.tree_map(serialization.get_tensorstore_spec, ckpt_paths) + tspecs = jax.tree.map(serialization.get_tensorstore_spec, ckpt_paths) manager = serialization.GlobalAsyncCheckpointManager() manager.serialize( [arr], tspecs, - on_commit_callback=partial(self._on_commit_callback, ckpt_dir, ckpt_dir)) + on_commit_callback=partial( + self._on_commit_callback, ckpt_dir, ckpt_dir)) manager.wait_until_finished() ds = NamedSharding(jtu.create_mesh((4, 2), ('x', 'y'), iota_order=True), @@ -395,15 +435,16 @@ def cb1(index): for l in m1.addressable_shards: self.assertArraysEqual(np.asarray(l.data), expected_data[l.device.id]) - new_ds = GSPMDSharding.get_replicated(list(global_mesh.devices.flat)) - m2, = serialization.run_deserialization([new_ds], tspecs, [(8, 2)], [np.float32]) + new_ds = _get_replicated_sharding(list(global_mesh.devices.flat)) + m2, = serialization.run_deserialization([new_ds], tspecs, [(8, 2)], + [np.float32]) for l in m2.addressable_shards: self.assertArraysEqual(l.data, global_input_data1.astype('float32')) @parameterized.product(input_dtype=[jnp.int4, jnp.int8]) def test_checkpointing_with_int4(self, input_dtype): if config.use_shardy_partitioner.value: - self.skipTest("TODO(b/376077396): Fix XlaRuntimeError: INVALID_ARGUMENT") + self.skipTest('TODO(b/376077396): Fix XlaRuntimeError: INVALID_ARGUMENT') global_mesh = jtu.create_mesh((2, 2), ('x', 'y'), iota_order=True) global_input_shape = (8, 2) num = math.prod(global_input_shape) @@ -418,12 +459,13 @@ def cb(index): ckpt_dir = pathlib.Path(self.create_tempdir('first').full_path) ckpt_paths = [str(ckpt_dir)] - tspecs = jax.tree_util.tree_map(serialization.get_tensorstore_spec, ckpt_paths) + tspecs = jax.tree.map(serialization.get_tensorstore_spec, ckpt_paths) manager = serialization.GlobalAsyncCheckpointManager() manager.serialize( [arr], tspecs, - on_commit_callback=partial(self._on_commit_callback, ckpt_dir, ckpt_dir)) + on_commit_callback=partial( + self._on_commit_callback, ckpt_dir, ckpt_dir)) manager.wait_until_finished() ds = NamedSharding(jtu.create_mesh((4, 2), ('x', 'y'), iota_order=True), @@ -448,8 +490,9 @@ def cb(index): for l in m1.addressable_shards: self.assertArraysEqual(np.asarray(l.data), expected_data[l.device.id]) - new_ds = GSPMDSharding.get_replicated(list(global_mesh.devices.flat)) - m2, = serialization.run_deserialization([new_ds], tspecs, [(8, 2)], [target_dtype]) + new_ds = _get_replicated_sharding(list(global_mesh.devices.flat)) + m2, = serialization.run_deserialization([new_ds], tspecs, [(8, 2)], + [target_dtype]) for l in m2.addressable_shards: self.assertArraysEqual(l.data, global_input_data.astype(target_dtype)) @@ -463,22 +506,17 @@ def test_checkpointing_scalar_jax_array(self): ckpt_dir = pathlib.Path(self.create_tempdir('first').full_path) ckpt_paths = [str(ckpt_dir)] - tspecs = jax.tree_util.tree_map(serialization.get_tensorstore_spec, ckpt_paths) - + tspecs = jax.tree.map(serialization.get_tensorstore_spec, ckpt_paths) manager = 
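The OCDBT tests above bundle several arrays into one key-value store and commit them atomically. A condensed sketch of the same pattern outside the test harness, assuming a single process and an absolute local directory (path names are illustrative):

from functools import partial
import jax
import jax.numpy as jnp
import tensorstore as ts
from jax.experimental.array_serialization import serialization
from jax.experimental.array_serialization import tensorstore_impl as ts_impl

arrays = [jnp.zeros((4, 4)), jnp.ones((8,))]
paths = ['/tmp/ckpt/w', '/tmp/ckpt/b']       # last path component = OCDBT key
tspecs = jax.tree.map(partial(ts_impl.get_tensorstore_spec, ocdbt=True), paths)

manager = serialization.GlobalAsyncCheckpointManager()
with ts.Transaction(atomic=True) as txn:
  # All writes join the same transaction, so they become visible atomically.
  manager.serialize(arrays, tspecs, on_commit_callback=None, transaction=txn)
manager.wait_until_finished()

sharding = jax.sharding.SingleDeviceSharding(jax.devices()[0])
w, b = serialization.run_deserialization([sharding, sharding], tspecs)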
serialization.GlobalAsyncCheckpointManager() manager.serialize( [array1], tspecs, - on_commit_callback=partial(self._on_commit_callback, ckpt_dir, ckpt_dir)) + on_commit_callback=partial( + self._on_commit_callback, ckpt_dir, ckpt_dir)) manager.wait_until_finished() ds = NamedSharding(jtu.create_mesh((2,), ('x')), P(None)) - m1, = serialization.run_deserialization( - [ds], - tspecs, - [()], - [np.float32] - ) + m1, = serialization.run_deserialization([ds], tspecs, [()], [np.float32]) for l in m1.addressable_shards: self.assertArraysEqual(np.asarray(l.data), data.astype(np.float32)) @@ -488,9 +526,7 @@ def test_deserialize_tensorstore_array_jax_array(self): data = np.arange(1024) tspec = ts.array(data).spec() m1, = serialization.run_deserialization( - [NamedSharding(global_mesh, P(None))], - [tspec] - ) + [NamedSharding(global_mesh, P(None))], [tspec]) for l in m1.addressable_shards: self.assertArraysEqual(np.asarray(l.data), data) @@ -507,9 +543,9 @@ def test_spec_has_metadata(self): }, 'f': 4 } - self.assertTrue(serialization._spec_has_metadata(spec)) + self.assertTrue(ts_impl._spec_has_metadata(spec)) self.assertTrue( - serialization._spec_has_metadata({ + ts_impl._spec_has_metadata({ 'driver': 'zarr', 'kvstore': 'gfile', 'metadata': { @@ -531,39 +567,40 @@ def test_spec_has_no_metadata(self): }, 'f': 4 } - self.assertFalse(serialization._spec_has_metadata(spec)) + self.assertFalse(ts_impl._spec_has_metadata(spec)) def test_empty_spec_has_no_metadata(self): spec = {} - self.assertFalse(serialization._spec_has_metadata(spec)) + self.assertFalse(ts_impl._spec_has_metadata(spec)) @parameterized.named_parameters( ('gcs', 'gs://my/ckpt/dir/path'), ('file', '/my/ckpt/dir/path') ) def test_get_tensorstore_spec_ocdbt(self, path): - spec = serialization.get_tensorstore_spec(path, ocdbt=True) + spec = ts_impl.get_tensorstore_spec(path, ocdbt=True) is_gcs_path = path.startswith('gs://') + # for OCDBT the last part of the path is the key in the kvstore + expected_path = os.path.split(path)[0] if is_gcs_path: - self.assertEqual(spec['kvstore']['base'], os.path.dirname(path)) + self.assertEqual(spec['kvstore']['base']['driver'], 'gcs') + self.assertTrue(expected_path.endswith(spec['kvstore']['base']['path'])) else: - self.assertEqual(spec['kvstore']['base'], - f'{serialization._DEFAULT_DRIVER}://{os.path.dirname(path)}') - self.assertEqual(spec['kvstore']['path'], 'path') + self.assertEqual(spec['kvstore']['base']['path'], expected_path) def test_get_tensorstore_spec_not_absolute_path(self): path = 'my/ckpt/path' with self.assertRaisesRegex(ValueError, - "Checkpoint path should be absolute"): - serialization.get_tensorstore_spec(path, ocdbt=True) + 'Checkpoint path should be absolute'): + ts_impl.get_tensorstore_spec(path, ocdbt=True) def test_maybe_cloud_storage(self): - gs_path = 'gs://some-buck/path' - gs_spec = serialization.get_tensorstore_spec(gs_path, ocdbt=True) + gs_path = 'gs://some-buck/path/array_name' + gs_spec = ts_impl.get_tensorstore_spec(gs_path, ocdbt=True) self.assertTrue(serialization.is_remote_storage(gs_spec)) - local_path = '/tmp/checkpoint' - local_spec = serialization.get_tensorstore_spec(local_path, ocdbt=True) + local_path = '/tmp/checkpoint/array_name' + local_spec = ts_impl.get_tensorstore_spec(local_path, ocdbt=True) self.assertFalse(serialization.is_remote_storage(local_spec)) nested_tspec = { @@ -571,7 +608,8 @@ def test_maybe_cloud_storage(self): 'dtype': 'int32', 'base': { 'driver': 'zarr', - 'kvstore': {'driver': 'ocdbt', 'base': 's3://some-bucket/path'}, + 
'kvstore': {'driver': 'ocdbt', + 'base': 's3://some-bucket/path/array_name'}, }, } self.assertTrue(serialization.is_remote_storage(nested_tspec)) @@ -585,24 +623,25 @@ def test_load_with_layout(self): s = NamedSharding(mesh, P('x', 'y')) arr = jax.device_put(np_inp, s) - out_layout = jax.jit(lambda x: x.T, out_shardings=Layout(DLL.AUTO)).lower( - arr).compile().output_layouts - self.assertEqual(arr.layout.device_local_layout.major_to_minor, - out_layout.device_local_layout.major_to_minor[::-1]) + out_format = jax.jit(lambda x: x.T, out_shardings=Format(DLL.AUTO)).lower( + arr).compile().output_formats + self.assertEqual(arr.format.device_local_layout.major_to_minor, + out_format.device_local_layout.major_to_minor[::-1]) ckpt_dir = pathlib.Path(self.create_tempdir('ckpt').full_path) ckpt_path = pathlib.Path(self.create_tempdir(f'{ckpt_dir}/first').full_path) - tspecs = jax.tree_util.tree_map(serialization.get_tensorstore_spec, [ckpt_path]) + tspecs = jax.tree.map(ts_impl.get_tensorstore_spec, [ckpt_path]) manager = serialization.GlobalAsyncCheckpointManager() manager.serialize( [arr], tspecs, - on_commit_callback=partial(self._on_commit_callback, ckpt_dir, ckpt_dir)) + on_commit_callback=partial( + self._on_commit_callback, ckpt_dir, ckpt_dir)) manager.wait_until_finished() - out, = serialization.run_deserialization([out_layout], tspecs) + out, = serialization.run_deserialization([out_format], tspecs) - self.assertEqual(out.layout, out_layout) + self.assertEqual(out.format, out_format) self.assertIsInstance(out, array.ArrayImpl) self.assertArraysEqual(out, np_inp) for s in out.addressable_shards: @@ -610,7 +649,7 @@ def test_load_with_layout(self): def test_deserialization_with_int4(self): if config.use_shardy_partitioner.value: - self.skipTest("TODO(b/376077396): Fix XlaRuntimeError: INVALID_ARGUMENT") + self.skipTest('TODO(b/376077396): Fix XlaRuntimeError: INVALID_ARGUMENT') if jtu.test_device_matches(['gpu']): self.skipTest("Fails on GPU. Enable after it's fixed") dtype = jnp.int4 @@ -620,10 +659,8 @@ def test_deserialization_with_int4(self): ckpt_dir = pathlib.Path(self.create_tempdir('test_ckpt').full_path) # Run serialization. - sharding = jax.sharding.GSPMDSharding.get_replicated(jax.devices()) - tspecs = jax.tree_util.tree_map( - serialization.get_tensorstore_spec, [ckpt_dir] - ) + sharding = _get_replicated_sharding(list(jax.devices())) + tspecs = jax.tree.map(serialization.get_tensorstore_spec, [ckpt_dir]) manager = serialization.GlobalAsyncCheckpointManager() manager.serialize( [arr], @@ -634,11 +671,8 @@ def test_deserialization_with_int4(self): # Run deserialization. 
deserialized_arr, = serialization.run_deserialization( - shardings=[sharding], - tensorstore_specs=tspecs, - global_shapes=[shape], - dtypes=[dtype], - ) + shardings=[sharding], tensorstore_specs=tspecs, global_shapes=[shape], + dtypes=[dtype]) out = deserialized_arr.astype(jnp.int8) # doesn't crash self.assertEqual(out.dtype, jnp.int8) @@ -650,13 +684,397 @@ class TransferShardTest(jtu.JaxTestCase): @jtu.skip_on_devices('cpu') def test_transfer_shard_to_host(self): np_inp = np.arange(16).reshape((4, 4)) - sharding = SingleDeviceSharding(jax.devices()[0], memory_kind="device") + sharding = SingleDeviceSharding(jax.devices()[0], memory_kind='device') arr = jax.device_put(np_inp, sharding) shard = arr.addressable_shards[0] - np_out = asyncio.run(serialization.transfer_shard_to_host(shard)) + np_out = asyncio.run(ts_impl._transfer_shard_to_host(shard)) self.assertArraysEqual(np_out, np_inp) + +def _remove_from_serialization_registry(t: Any): + if t in node_serialization_registry: + serialized_name = node_serialization_registry[t][0] + del node_serialization_registry[t] + del node_deserialization_registry[serialized_name] + + +class UserAPITestCase(jtu.JaxTestCase): + name: str | None + path: pathlib.Path | None + + def setUp(self): + super().setUp() + tmpdir = tempfile.TemporaryDirectory() + self.enter_context(tmpdir) + self.name = tmpdir.name + self.path = pathlib.Path(self.name) + + def tearDown(self): + self.path = None + self.name = None + super().tearDown() + + def generate_random_fp32(self, shape, dtype=jnp.float32): + seed = round(time.time() * 1e6) % (2 ** 31) + key = jax.random.key(seed) + return jax.random.normal(key, shape=shape).astype(dtype) + + def generate_clean_tree(self, dtype=jnp.float32): + r1 = self.generate_random_fp32((), dtype=dtype) + r2 = self.generate_random_fp32((4,), dtype=dtype) + r3 = self.generate_random_fp32((2, 3), dtype=dtype) + return (r1, {'a': r2, 'rs': [r1, r2, r3], 'c': {'d': {'e': (r2,)}}}) + + def _is_equal(self, el1, el2): + if not isinstance(el1, type(el2)) or not isinstance(el2, type(el1)): + return False + if isinstance(el1, (np.ndarray, jax.Array)): + return (el1.dtype == el2.dtype and el1.shape == el2.shape + and jnp.allclose(el1, el2)) + else: + return el1 == el2 + + def assertPyTreeEqual(self, p1, p2, is_leaf=None): + leaves1, struct1 = jax.tree.flatten(p1, is_leaf=is_leaf) + leaves2, struct2 = jax.tree.flatten(p2, is_leaf=is_leaf) + self.assertEqual(struct1, struct2) + self.assertTrue(all(self._is_equal(el1, el2) + for (el1, el2) in zip(leaves1, leaves2))) + +_DTYPES_LIST = [ + jnp.uint8, + jnp.uint16, + jnp.uint32, + jnp.int8, + jnp.int16, + jnp.int32, + jnp.float8_e4m3fn, + jnp.float8_e4m3fnuz, + jnp.float8_e5m2, + jnp.float8_e5m2fnuz, + jnp.float8_e4m3b11fnuz, + jnp.bfloat16, + jnp.float16, + jnp.float32, + jnp.complex64, +] + +_X64_DTYPES_LIST = [ + jnp.uint64, + jnp.int64, + jnp.float64, + jnp.complex128, +] + +if jax.config.x64_enabled: + _DTYPES_LIST.extend(_X64_DTYPES_LIST) + + +@jax.tree_util.register_pytree_node_class +class CustomNode: + def __init__(self, a): + self.a = a + + def tree_flatten(self): + return (self.a,), None + + @classmethod + def tree_unflatten(cls, aux_data, children): + del aux_data + return cls(*children) + + +@partial(jax.tree_util.register_dataclass, data_fields=['a', 'd'], + meta_fields=['c']) +@dataclass +class CustomDataclass: + a: int + c: str + d: int + + +@jax.tree_util.register_static +class CustomStatic: + def __init__(self, a): + self.a = a + +# we're testing custom type registration which modifies 
the global registry +# so need to ensure we're not running multiple custom types tests in parallel +custom_types_threading_lock = threading.Lock() + + +class UserPytreeAPITest(UserAPITestCase): + def setUp(self): + super().setUp() + global _default_sharding + _default_sharding = SingleDeviceSharding(jax.devices()[0]) + self.tempdirs = [] + + def tearDown(self): + for tempdir in self.tempdirs: + tempdir.cleanup() + super().tearDown() + + def create_tempdir(self): + tempdir = tempfile.TemporaryDirectory() + self.tempdirs.append(tempdir) + return pathlib.Path(tempdir.name).resolve() + + @parameterized.product(tree=[{'a': 1}, [1, 2, 3], (1, 2, 3), 1, 2, 3]) + def test_save_then_load(self, tree): # pylint: disable=redefined-outer-name + path = self.create_tempdir() + tree = jax.tree.map(jnp.array, tree) + tree_save(tree, path) + tree2 = tree_load(path) + self.assertPyTreeEqual(tree, tree2) + + @parameterized.product(dtype=_DTYPES_LIST) + def test_saving_dtype(self, dtype): + if dtype in _X64_DTYPES_LIST and jtu.test_device_matches(['tpu']): + self.skipTest('Don\'t test x64 dtypes on TPUs') + path = self.create_tempdir() + test_tree = self.generate_clean_tree(dtype=dtype) + tree_save(test_tree, path) + new_tree = tree_load(path) + self.assertPyTreeEqual(test_tree, new_tree) + + def test_do_not_overwrite_noncheckpoint_directories(self): + path = self.create_tempdir() + path.mkdir(exist_ok=True) + (path / 'hello.txt').write_text('Hello World') + with self.assertRaisesRegex(RuntimeError, 'Refusing to work on a directory' + ' that is not a previous checkpoint.'): + tree_save({'a': jnp.ones(1)}, path) + + def test_checkpoint_exists(self): + path = self.create_tempdir() + tree_save({'a': jnp.ones(1)}, path) + with self.assertRaises(ValueError): + tree_save({'a': jnp.ones(1)}, path, overwrite=False) + + @parameterized.product(test_load_fail=[True, False]) + def test_custom_types(self, test_load_fail): + path = self.create_tempdir() + with custom_types_threading_lock: + magic_value = jnp.ones(()) * 37 + n = CustomNode(magic_value) + d = CustomDataclass(magic_value, 'hello', magic_value + 1) + s = CustomStatic(magic_value - 1) + tree_to_save = [n, (d, s)] + + register_pytree_node_serialization(CustomNode, + serialized_name='CustomNode', + serialize_auxdata=pickle.dumps, + deserialize_auxdata=pickle.loads) + register_pytree_node_serialization(CustomStatic, + serialized_name='CustomStatic', + serialize_auxdata=pickle.dumps, + deserialize_auxdata=pickle.loads) + register_pytree_node_serialization(CustomDataclass, + serialized_name='CustomDataclass', + serialize_auxdata=pickle.dumps, + deserialize_auxdata=pickle.loads) + tree_save(tree_to_save, path) + if test_load_fail: + _ = [_remove_from_serialization_registry(cls) + for cls in [CustomStatic, CustomNode, CustomDataclass]] + with self.assertRaises(ValueError): + _ = tree_load(path) + else: + tree2 = tree_load(path) + self.assertEqual(tree2[0].a, magic_value) + self.assertEqual(tree2[1][0].a, magic_value) + self.assertEqual(tree2[1][0].c, 'hello') + self.assertEqual(tree2[1][0].d, magic_value + 1) + self.assertEqual(tree2[1][1].a, magic_value - 1) + _ = [_remove_from_serialization_registry(cls) + for cls in [CustomStatic, CustomNode, CustomDataclass]] + + def test_flax_frozen_dict(self): + path = self.create_tempdir() + try: + # pylint: disable=g-import-not-at-top + # pylint: disable=g-importing-member + from flax.core.frozen_dict import FrozenDict + # pylint: enable=g-importing-member + # pylint: enable=g-import-not-at-top + except ImportError: + 
logging.warning('Skipping Flax FrozenDict tests as flax is not installed') + return + + try: + register_pytree_node_serialization(FrozenDict, + serialized_name='FrozenDict', + serialize_auxdata=pickle.dumps, + deserialize_auxdata=pickle.loads) + tree_save(FrozenDict(a=1, b=self.generate_clean_tree()), path) + tree_load(path) + finally: + _remove_from_serialization_registry(FrozenDict) + + def test_register_as_decorator(self): + @partial(register_pytree_node_serialization, + serialized_name='CustomDNode', + serialize_auxdata=json.dumps, + deserialize_auxdata=json.loads) + @partial(jax.tree_util.register_dataclass, data_fields=['a', 'b'], + meta_fields=[]) + @dataclass + class CustomDNode: + a: int + b: int + + # test whether the object can be created (is visible in this scope) + _ = CustomDNode(1, 2) + + def test_custom_node_registration(self): + path = self.create_tempdir() + + @jax.tree_util.register_static + @dataclass + class P: + a: int = 2 + + @partial(jax.tree_util.register_dataclass, data_fields=['a', 'b'], + meta_fields=['op']) + @dataclass + class D: + a: Any + b: Any + op: str + + def serialize_D(data): + return json.dumps(jax.tree.map(lambda x: np.array(x).tolist(), data) + ).encode('utf-8') + + def deserialize_D(data): + return jnp.array(json.loads(data)) + + data = [jnp.ones(1), {'world': [jnp.zeros(3), (jnp.ones(1), jnp.ones(2))]}, + 7 * jnp.ones(()), P()] + + serialize_fn = lambda p: json.dumps(int(p.a)).encode('utf-8') + deserialize_fn = lambda data: P(json.loads(data)) + + with self.assertRaises(ValueError): + tree_save(data, path) + + register_pytree_node_serialization(P, + serialized_name='P', + serialize_auxdata=serialize_fn, + deserialize_auxdata=deserialize_fn) + magic_value = -171 + data[-1].a = jnp.array(magic_value) + tree_save(data, path) + ret = tree_load(path) + self.assertLen(ret, len(data)) + self.assertEqual(ret[-1].a, magic_value) + + magic_val = 17 * jnp.ones(2) + data.append(D(jnp.ones(1), jax.numpy.zeros(2), magic_val)) + with self.assertRaises(ValueError): + tree_save(data, path) + + register_pytree_node_serialization(D, + serialized_name='D', + serialize_auxdata=serialize_D, + deserialize_auxdata=deserialize_D) + tree_save(data, path) + ret = tree_load(path) + self.assertLen(ret, len(data)) + self.assertLess(jnp.linalg.norm(ret[-1].op - magic_val), 1e-5) + + jax.tree.flatten(data) + + def test_masked_reading(self): + path = self.create_tempdir() + data = [jnp.ones(1), {'world': [jnp.zeros(3), (jnp.ones(1), jnp.ones(2))]}, + 7 * jnp.ones(())] + tree_save(data, path) + for mask in [False, True]: + ret = tree_load(path, mask=mask) + expected = jax.tree.map(lambda x: None if not mask else x, data) + self.assertPyTreeEqual(ret, expected, is_leaf=lambda x: x is None) + + mask = [True, False, False] + expected = data[:1] + jax.tree.map(lambda x: None, data[1:]) + ret = tree_load(path, mask=mask) + self.assertPyTreeEqual(ret, expected, is_leaf=lambda x: x is None) + + mask = [True, True, False] + expected = data[:2] + jax.tree.map(lambda x: None, data[2:]) + ret = tree_load(path, mask=mask) + self.assertPyTreeEqual(ret, expected, is_leaf=lambda x: x is None) + + mask = [True, {'world': [True, (False, True)]}, False] + data[1]['world'][1] = (None, data[1]['world'][1][1]) + ret = tree_load(path, mask=mask) + self.assertPyTreeEqual(ret, expected, is_leaf=lambda x: x is None) + + # TODO(rdyro): Remove when serialization supports non-arrays + @parameterized.product(obj=[b'hello', 'hello', 1, 1.0, 1j]) + def test_serialization_works_for_arrays_only(self, obj): + path 
= self.create_tempdir() + data = [{'world': [jnp.zeros(3), (jnp.ones(1), jnp.ones(2))]}, obj] + msg = ('For serialization, all leaves must be either None or' + ' jax.Array-like objects.') + with self.assertRaisesRegex(ValueError, msg): + tree_save(data, path) + + def test_load_pytreedef(self): + path = self.create_tempdir() + data = [jnp.ones(1), {'world': [jnp.zeros(3), (jnp.ones(1), jnp.ones(2))]}, + 7 * jnp.ones(())] + tree_save(data, path) + pytreedef = tree_load_pytreedef(path) + expected_pytreedef = jax.tree.map( + lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype), data) + self.assertPyTreeEqual(pytreedef, expected_pytreedef) + + @parameterized.product(data=[ + None, [None], [None, np.ones(())], + [None, {'world': [None, (np.ones(1), np.ones(2))]}, np.ones(())], + [None, {'world': [np.zeros(3), (None, np.ones(2))]}, None]]) + def test_save_and_load_null_leaves(self, data): + path = self.create_tempdir() + # TPUs might not have X64 enabled, so we need to convert to float32 + data = jax.tree.map(lambda x: jnp.array(x, dtype=jnp.float32), data) + tree_save(data, path) + pytreedef = tree_load_pytreedef(path) + is_leaf = lambda x: x is None + expected_pytreedef = jax.tree.map(lambda x: jax.ShapeDtypeStruct( + x.shape, x.dtype) if x is not None else x, data, is_leaf=is_leaf) + self.assertPyTreeEqual(pytreedef, expected_pytreedef) + load_data = tree_load(path) + load_leaves, load_struct = jax.tree.flatten(load_data, is_leaf=is_leaf) + expected_leaves, expected_struct = jax.tree.flatten(data, is_leaf=is_leaf) + self.assertEqual(load_struct, expected_struct) + self.assertLen(load_leaves, len(expected_leaves)) + for (l1, l2) in zip(load_leaves, expected_leaves): + if l1 is None: + self.assertIsNone(l2) + else: + self.assertArraysEqual(l1, l2) + + @parameterized.product(manually_broadcast_ts_specs=[True, False]) + def test_custom_ts_specs(self, manually_broadcast_ts_specs): + if ts_impl._TS_ARRAY_DRIVER == 'zarr': + self.skipTest('Skipping since this test assumes zarr is NOT the default') + path = self.create_tempdir() + data = [jnp.ones(()), (jnp.zeros(()), jnp.ones(())), None] + ts_spec = {'driver': 'zarr', 'metadata': {'shape': ()}} + if manually_broadcast_ts_specs: + ts_specs = [ts_spec, (ts_spec, None), None] # None ts_spec allowed + else: + ts_specs = ts_spec + tree_save(data, path, ts_specs=ts_specs) + load_data = tree_load(path, ts_specs=ts_specs) + self.assertPyTreeEqual(data, load_data) + with self.assertRaisesRegex(ValueError, + 'NOT_FOUND: Error opening "zarr3" driver:'): + _ = tree_load(path) # default attempts to open with zarr3 and fails + if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/jax/experimental/array_serialization/tensorstore_impl.py b/jax/experimental/array_serialization/tensorstore_impl.py new file mode 100644 index 000000000000..99f7a137f6f9 --- /dev/null +++ b/jax/experimental/array_serialization/tensorstore_impl.py @@ -0,0 +1,587 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
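The pytree save/load tests above revolve around registering custom node types before they can be checkpointed. A minimal end-to-end sketch of that workflow (the class and field names are invented for illustration; pickle is used for the auxdata only because the tests above do the same):

from dataclasses import dataclass
from functools import partial
import pickle
import jax
import jax.numpy as jnp
from jax.experimental.array_serialization import pytree_serialization
from jax.experimental.array_serialization.pytree_serialization_utils import (
    register_pytree_node_serialization)

@partial(register_pytree_node_serialization,
         serialized_name='MyState',
         serialize_auxdata=pickle.dumps,
         deserialize_auxdata=pickle.loads)
@partial(jax.tree_util.register_dataclass, data_fields=['w', 'b'],
         meta_fields=['name'])
@dataclass
class MyState:
  w: jax.Array
  b: jax.Array
  name: str

state = MyState(w=jnp.ones((4, 4)), b=jnp.zeros((4,)), name='layer0')
path = '/tmp/ckpt/state'                      # illustrative path
pytree_serialization.save(state, path)
restored = pytree_serialization.load(
    path, shardings=jax.sharding.SingleDeviceSharding(jax.devices()[0]))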
+ +import asyncio +from functools import partial +import functools +import os +from os import PathLike +import re +from typing import Any +from collections.abc import Awaitable, Callable, Sequence +import math +import logging + +import jax +from jax import numpy as jnp +from jax._src import array +from jax._src.layout import Format +from jax._src import typing +import numpy as np +import tensorstore as ts + +_TS_ARRAY_DRIVER = "zarr3" + +_TS_CONTEXT = ts.Context({ + 'file_io_concurrency': {'limit': 128}, + 'cache_pool': {'total_bytes_limit': 10_000_000_000}, # 10 GB RAM limit + 'cache_pool#remote': {'total_bytes_limit': 10_000_000_000}, + 'data_copy_concurrency': {'limit': 128} +}) +_TS_CHUNK_LAYOUT = ts.ChunkLayout({ + "chunk": {"elements": 100_000_000}, # 100M (800MB for float64) file size +}) + +_DEFAULT_BASE_DRIVER = 'file' +_PROCESS_DIR_FORMAT = "process_{}" +_FILE_SIZE_TARGET = 2 * 1024 ** 3 # 2 GB + +Future, Transaction = ts.Future, ts.Transaction + +logger = logging.getLogger(__name__) + +# Lifted from T5X. +class _LimitInFlightBytes: + """Limits host scratch memory usage when reading/writing checkpoints per process.""" + + def __init__(self, host_memory_bytes_limit: int): + self._max_bytes = host_memory_bytes_limit + self._available_bytes = host_memory_bytes_limit + self._cv = asyncio.Condition(lock=asyncio.Lock()) + + async def wait_for_bytes(self, requested_bytes): + if requested_bytes > self._max_bytes: + logger.debug("A single array item requests more bytes than we reserved" + " space for in the parallel pool: %d > %d. Increasing the" + " limit to %d.", requested_bytes, self._max_bytes, + requested_bytes) + self._max_bytes = requested_bytes + async with self._cv: + await self._cv.wait_for(lambda: self._available_bytes >= requested_bytes) + self._available_bytes -= requested_bytes + assert self._available_bytes >= 0 + + async def release_bytes(self, requested_bytes): + async with self._cv: + self._available_bytes += requested_bytes + assert self._available_bytes <= self._max_bytes + self._cv.notify_all() + +def is_tensorstore_spec_leaf(leaf: Any): + # TODO(rdyro): think of a better way to detect which leaf is a ts config + return leaf is None or (isinstance(leaf, dict) + and ("driver" in leaf or "kvstore" in leaf)) + +def _prime_factors(x: int) -> list[int]: + # find prime factors of axis sizes to help efficiently find divisor chunks + factors = [] + while x % 2 == 0: + factors.append(2) + x //= 2 + for i in range(3, int(math.sqrt(x)) + 1, 2): + while x % i == 0: + factors.append(i) + x //= i + if x > 1: + factors.append(x) + return sorted(factors) + +@functools.lru_cache(maxsize=1024) +def _compute_chunk_shape( + local_shape: Sequence[int], dtype: str | jnp.dtype, + file_size_target: int = _FILE_SIZE_TARGET) -> list[int]: + """Compute a chunk such that it divides the local shape and is less than + target file size. This helps the tensorstore kvstore driver limit the largest + file size on disk to below the ``file_size_target``. We compute a chunk with a + byte size at most 110% of the ``file_size_target``. + """ + local_shape = list(local_shape) + if len(local_shape) == 0 or math.prod(local_shape) == 0: + # a zero size array needs a non-zero chunk passed to tensorstore for compat. 
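A small sketch of how the _LimitInFlightBytes class above is intended to be driven by concurrent readers: reserve the estimated bytes before issuing a read, release them once the data has been handed off. Everything other than the class itself is illustrative:

import asyncio

async def _bounded_read(limiter: _LimitInFlightBytes, nbytes: int) -> bytes:
  # Block while the in-flight byte budget is exhausted.
  await limiter.wait_for_bytes(nbytes)
  try:
    return b'\x00' * nbytes            # stand-in for an actual chunk read
  finally:
    # Return the budget and wake up any waiting readers.
    await limiter.release_bytes(nbytes)

async def _bounded_read_all(sizes):
  limiter = _LimitInFlightBytes(host_memory_bytes_limit=1 << 30)  # 1 GiB
  return await asyncio.gather(*(_bounded_read(limiter, n) for n in sizes))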
+ return [max(z, 1) for z in local_shape] + total_size = math.prod(local_shape) * jnp.dtype(dtype).itemsize + axis_prime_factors = [_prime_factors(z) for z in local_shape] + chunk_shape, chunk_size = list(local_shape), total_size + # while chunk_size exceeds target size, reduce chunk_shape + while chunk_size > 1.1 * file_size_target: # 10% buffer + # 1. find the smallest axis divisor across all axes + chosen_axis_idx, chosen_divisor = None, 1 + for axis_idx in range(len(chunk_shape)): + if len(axis_prime_factors[axis_idx]) == 1: # ignore axes sizes == 1 + continue + if (chosen_axis_idx is None + or chosen_divisor > axis_prime_factors[axis_idx][0]): + chosen_axis_idx = axis_idx + chosen_divisor = axis_prime_factors[axis_idx][0] + # 2. if no divisor found, give up, return current chunk shape + if chosen_axis_idx is None: + return chunk_shape + # 3. remove the applied divisor from prime factors + prime_factors = axis_prime_factors[chosen_axis_idx] + prime_factors.pop(0) + # 4. apply the found divisor to reduce the chunk size + chunk_shape[chosen_axis_idx] //= chosen_divisor + chunk_size //= chosen_divisor + return chunk_shape + +def _get_tensorstore_metadata(arr, is_remote: bool = False, + file_size_target: int = _FILE_SIZE_TARGET, + driver: str = _TS_ARRAY_DRIVER) -> dict[str, Any]: + global_shape, dtype = arr.shape, arr.dtype + if hasattr(arr, 'addressable_data'): # jax.Array + local_shape = arr.addressable_data(0).shape + else: # np.ndarray + local_shape = global_shape + return _get_tensorstore_metadata_cached(global_shape, dtype, local_shape, + is_remote, file_size_target, driver) + +@functools.lru_cache(maxsize=1024) +def _get_tensorstore_metadata_cached( + global_shape: Sequence[int], dtype: jnp.dtype, local_shape: Sequence[int], + is_remote: bool = False, file_size_target: int = _FILE_SIZE_TARGET, + driver: str = _TS_ARRAY_DRIVER) -> dict[str, Any]: + if driver == "zarr3": + codecs = ([{"name": "zstd"}] if is_remote else []) + return { + 'codecs': codecs, + 'shape': global_shape, + 'data_type': jnp.dtype(dtype).name, + 'chunk_grid': { + 'name': 'regular', + 'configuration': {'chunk_shape': _compute_chunk_shape( + local_shape, dtype, file_size_target=file_size_target)} + } + } + elif driver == "zarr": # in zarr dtype goes in the base spec + return {'compressor': {'id': 'zstd'}, 'shape': global_shape, + 'chunks': np.array(np.maximum(1, local_shape)).tolist()} + else: + raise ValueError(f"Unsupported driver: {driver}") + +_divides = lambda x, y: np.all((np.array(x) % np.array(y)) == 0) + +def merge_nested_ts_specs(dict1: dict[Any, Any], dict2: dict[Any, Any] | None): + """Merge two ts specs, dict2 takes precedence.""" + if dict2 is None: # nothing to do + return dict1 + # TODO(rdyro): this is an opinionated merge, we should get user feedback + # merge kvstore explicitly + kvstore = dict1.get("kvstore", {}) | dict2.get("kvstore", {}) + return dict1 | dict(dict2, kvstore=kvstore) # merge with dict2 preferred + +def verify_tensorstore_spec(spec: dict[str, Any], arr: jax.Array | None, + path: str | os.PathLike[str], ocdbt: bool, + check_metadata: bool = True) -> None: + """Verify the minimum requirements for a tensorstore spec.""" + if ocdbt: + if spec.get("kvstore", {}).get("driver", "") != "ocdbt": + raise ValueError(f"Expected ocdbt driver, got {spec=}") + if check_metadata: + if arr is None: + raise ValueError("Array is required for metadata verification.") + metadata = spec['metadata'] + if spec.get("driver", "") == "zarr3": + if metadata['data_type'] != jnp.dtype(arr.dtype).name: + 
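To make the chunking helper above concrete: for a local shard that exceeds the ~2 GB file-size target, it repeatedly divides the axis with the smallest remaining prime factor until the chunk fits within roughly 110% of the target, while always dividing the local shape exactly. An illustrative call, using the private helper (treat it as documentation rather than public API):

import jax.numpy as jnp
from jax.experimental.array_serialization import tensorstore_impl as ts_impl

# A 32768 x 32768 float32 shard holds 4 GiB, which is above the 2 GiB target...
chunk = ts_impl._compute_chunk_shape((32768, 32768), jnp.dtype('float32'))
# ...so one factor of 2 is removed from the first axis, giving 2 GiB chunks
# (within the 10% headroom the helper allows over the target).
assert chunk == [16384, 32768]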
raise ValueError(f"Provided dtype ({metadata['data_type']=}) doesn't" + f" match ({arr.dtype=})") + if 'shape' in metadata: + if metadata['shape'] != arr.shape: + raise ValueError(f"Provided shape ({metadata['shape']=}) doesn't match" + f" ({arr.shape=})") + if hasattr(arr, 'addressable_data'): + local_shape = arr.addressable_data(0).shape + else: # np.ndarray + local_shape = arr.shape + if spec.get("driver", "") == "zarr3": + chunk_shape = metadata['chunk_grid']['configuration']['chunk_shape'] + if not _divides(local_shape, chunk_shape): + raise ValueError(f"Provided chunk shape {chunk_shape} does not divide" + f" the local shape of the array {local_shape}") + # check path is still the same one we expect + if ocdbt: + found_path = spec["kvstore"]['base']['path'] + else: + found_path = spec["kvstore"]['path'] + if str(found_path) != str(path): + raise ValueError(f"Provided {path=} does not match the spec path:" + f" {spec['kvstore']}") + +def _spec_has_metadata(tree): + if not isinstance(tree, dict): + return False + return 'metadata' in tree or any( + _spec_has_metadata(subtree) for _, subtree in tree.items()) + +def _get_kvstore_for_gcs(ckpt_path: str): + m = re.fullmatch('^gs://([^/]*)/(.*)$', ckpt_path) + if m is None: + raise ValueError('The ckpt_path should contain the bucket name and the ' + f'file path inside the bucket. Got: {ckpt_path}') + bucket = m.group(1) + path_without_bucket = m.group(2) + return {'driver': 'gcs', 'bucket': bucket, 'path': path_without_bucket} + +def _get_kvstore_for_s3(ckpt_path: str): + m = re.fullmatch('^s3://([^/]*)/(.*)$', ckpt_path, re.DOTALL) + if m is None: + raise ValueError('The ckpt_path should contain the bucket name and the ' + f'file path inside the bucket. Got: {ckpt_path}') + bucket = m.group(1) + path_without_bucket = m.group(2) + return {'driver': 's3', 'bucket': bucket, 'path': path_without_bucket} + +def get_tensorstore_spec( + ckpt_path: str | PathLike[str], ocdbt: bool = True, + process_idx: int | None = None, arr: jax.Array | None = None, + driver: str = _TS_ARRAY_DRIVER) -> dict[str, Any]: + + # Normalize path to exclude trailing '/'. 
In GCS path case, normpath will + # replace a the double '//' with a single '/' and we need to restore the + # filesystem type:// prefix for GCS (gs://) and S3 paths (s3://) + ckpt_path = os.path.normpath(str(ckpt_path)) + ckpt_path = re.sub(r"^([a-z]+):/", r"\1://", ckpt_path) + + # in cases of multi-process writes, we need to write to a different location + # for each process and finally created a combined symlink to the final + # location, tensorstore can do this via ts.KvStore.experimental_copy_range_to + if process_idx is not None: + _parent, _name = os.path.split(ckpt_path) + ckpt_path = os.path.join(_parent, _PROCESS_DIR_FORMAT.format(process_idx), + _name) + + is_gcs_path = ckpt_path.startswith('gs://') + is_s3_path = ckpt_path.startswith('s3://') + spec = {'driver': driver, 'kvstore': {}} + + # use a combined OCDBT store, the actual path is the parent path + # the name (filename/last part of the path) is the key in the ocdbt kvstore + entry_key = None + if ocdbt: + (ckpt_path, entry_key), org_ckpt_path = os.path.split(ckpt_path), ckpt_path + if is_gcs_path: + m = re.fullmatch('^gs://([^/]*)/(.*)$', ckpt_path) + elif is_s3_path: + m = re.fullmatch('^s3://([^/]*)/(.*)$', ckpt_path) + else: + m = re.match("a", "a") # make it True + if m is None: + raise ValueError('Using OCDBT requires the bucket name, the directory' + ' name and the array name, your path is: ' + f'{org_ckpt_path}') + + if is_gcs_path: + base_kvstore = _get_kvstore_for_gcs(ckpt_path) + elif is_s3_path: + base_kvstore = _get_kvstore_for_s3(ckpt_path) + else: + base_kvstore = {'driver': _DEFAULT_BASE_DRIVER, 'path': ckpt_path} + + if ocdbt: + if not is_gcs_path and not is_s3_path and not os.path.isabs(ckpt_path): + raise ValueError(f'Checkpoint path should be absolute. Got {ckpt_path}') + spec['kvstore'] = {'driver': 'ocdbt', 'base': base_kvstore, + 'path': entry_key} + else: + spec['kvstore'] = base_kvstore + # done writing tensorstore spec based on destination path + # optionally, if array is provided, we can add metadata to the spec + if arr is not None: + spec["metadata"] = _get_tensorstore_metadata( + arr, driver=str(spec["driver"])) + return spec + +async def _create_async_array_from_callback( + global_shape: array.Shape, + inp_sharding: jax.sharding.Sharding, + data_callback: Callable[[array.Index, jax.Device], Awaitable[jax.Array]], +): + device_to_index_map = inp_sharding.devices_indices_map(global_shape) + addressable_da = inp_sharding._addressable_device_assignment + future_arrays = [data_callback(device_to_index_map[d], d) + for d in addressable_da] + dbs = await asyncio.gather(*future_arrays) + return array.make_array_from_single_device_arrays( + global_shape, inp_sharding, dbs) + +async def _transfer_shard_to_host(shard: array.Shard) -> np.ndarray: + data = shard.data + has_pinned_host = any( + m.kind == "pinned_host" for m in shard.device.addressable_memories()) + if has_pinned_host: + # If available, transfer to pinned host memory + sharding = jax.sharding.SingleDeviceSharding(shard.device, + memory_kind="pinned_host") + data = jax.device_put(data, sharding) + else: + data.copy_to_host_async() + # Allow other transfers to be scheduled simultaneously + await asyncio.sleep(0) + # Ensure that jax.Array's internal numpy array can be zero-copied. Tensorstore + # implicitly converts the written data to a numpy array, and would otherwise + # silently copy host-to-host. 
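For reference, get_tensorstore_spec above is expected to produce specs of the following shape; the paths are illustrative and the dicts follow from the code above rather than from any published schema:

# Local file path with OCDBT: the parent directory becomes the kvstore base
# and the final path component becomes the key inside the OCDBT store.
# get_tensorstore_spec('/tmp/ckpt/arr0', ocdbt=True) ->
local_spec = {
    'driver': 'zarr3',
    'kvstore': {
        'driver': 'ocdbt',
        'base': {'driver': 'file', 'path': '/tmp/ckpt'},
        'path': 'arr0',
    },
}

# GCS path: the bucket name is split out of the base kvstore.
# get_tensorstore_spec('gs://my-bucket/ckpt/arr0', ocdbt=True) ->
gcs_spec = {
    'driver': 'zarr3',
    'kvstore': {
        'driver': 'ocdbt',
        'base': {'driver': 'gcs', 'bucket': 'my-bucket', 'path': 'ckpt'},
        'path': 'arr0',
    },
}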
+ return np.array(data, copy=False) + +async def combine_kvstores(combined_kvstore: dict[str, Any], + kvstores: list[dict[str, Any]], + context: ts.Context | dict[str, Any] = _TS_CONTEXT + ) -> None: + """Merge a list of kvstores into a single kvstore. NOT multi-process safe.""" + combined_fut = ts.KvStore.open(combined_kvstore, context=context) + kvstores_futs = [ts.KvStore.open(kvstore, context=context) + for kvstore in kvstores] + combined, kvstores = await asyncio.gather(combined_fut, + asyncio.gather(*kvstores_futs)) + tx = ts.Transaction() + await asyncio.gather(*[kvstore.experimental_copy_range_to( + combined.with_transaction(tx)) for kvstore in kvstores]) + await tx.commit_async() + +async def async_serialize( + arr_inp, + tensorstore_spec, + commit_future=None, + context=_TS_CONTEXT, + chunk_layout=_TS_CHUNK_LAYOUT, + primary_host: int | None = None, + replica_id: int = 0, + transaction: ts.Transaction | None = None, +): + """Serialize an array using TensorStore. + + Args: + arr_inp: The array to serialize. + tensorstore_spec: The tensorstore spec to use. + commit_future: A list of futures that will be appended to. The futures can + be awaited asynchronously. If None, the futures will be awaited + synchronously by this method. + context: ts.Context instance. + primary_host: Primary host, which indicates the host that will be treated as + the "leader". If None, all hosts are treated as the primary. DO NOT USE + unless you are sure you know what you are doing. + replica_id: Allows overriding the shard replica id that will be saved. DO + NOT USE unless you are sure you know what you are doing. + transaction: TensorStore transaction to use for opening and writing the + array. If not specified, a non-transactional write will be used. + """ + if (isinstance(arr_inp, array.ArrayImpl) and jax.process_count() > 1 and + arr_inp.is_fully_addressable): + raise ValueError( + f'Passing fully addressable arrays to a multiprocess ' + f'serialization is not allowed, as this may lead to a race condition ' + f'between processes. Serialization have failed for the array with ' + f'the path from kvstore: "{tensorstore_spec["kvstore"]}".') + + # 'metadata' may not be present at the top level (for example, if we are using + # a 'cast' driver). + if not _spec_has_metadata(tensorstore_spec): + tensorstore_spec['metadata'] = _get_tensorstore_metadata( + arr_inp, driver=tensorstore_spec['driver']) + ## zarr driver requires specifying the dtype in the spec base + if tensorstore_spec['driver'] == 'zarr' and 'dtype' not in tensorstore_spec: + tensorstore_spec['dtype'] = jnp.dtype(arr_inp.dtype).name + + # If primary_host is None, all hosts will checkpoint. This is used + # for checkpointing to local filesystem. + if primary_host is None or jax.process_index() == primary_host: + open_future = ts.open( + ts.Spec(tensorstore_spec), + create=True, + open=True, + context=context, + chunk_layout=chunk_layout, + transaction=transaction, + ) + # Asynchronous case. + if commit_future is not None: + assert isinstance(commit_future, list) + commit_future.append(open_future) + else: + await open_future + + # `ts.open` runs twice for process `primary_host` because for the first time, + # we just get the future to be awaited upon in the background thread. The + # second one runs with `assume_metadata=True` which does no I/O operation and + # returns the tensorstore object. + # For every process other than `primary_host`, we open with + # `assume_metadata=True`. 
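The process_idx plumbing in get_tensorstore_spec and the combine_kvstores coroutine above exist so that each process can write into its own OCDBT store and the per-process stores can later be merged into the final location. A rough single-machine sketch of that flow; the paths and the orchestration around them are illustrative, and the real multi-process coordination is not shown:

import asyncio
import jax
from jax.experimental.array_serialization import tensorstore_impl as ts_impl

final_path = '/tmp/ckpt/params'
per_process_specs = [
    ts_impl.get_tensorstore_spec(final_path, ocdbt=True, process_idx=i)
    for i in range(jax.process_count())
]
# ... each process i serializes its shards into per_process_specs[i] ...

# Afterwards, merge the per-process kvstores into the final combined store.
combined_spec = ts_impl.get_tensorstore_spec(final_path, ocdbt=True)
asyncio.run(ts_impl.combine_kvstores(
    combined_spec['kvstore'],
    [spec['kvstore'] for spec in per_process_specs]))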
+ t = await ts.open( + ts.Spec(tensorstore_spec), + open=True, + assume_metadata=True, + context=context, + chunk_layout=chunk_layout, + transaction=transaction, + ) + + async def _write_array(shard): + if shard.replica_id == replica_id: + data = await _transfer_shard_to_host(shard) + write_future = t[shard.index].write( + data, + # Avoid additional copy of input array into the TensorStore chunk + # cache. If `arr_inp` is a jax.Array, the result of converting + # it to a NumPy array, as is done internally by TensorStore, is + # guaranteed to be immutable and therefore it is safe to retain a + # reference indefinitely. + can_reference_source_data_indefinitely=isinstance( + arr_inp, array.ArrayImpl + ), + ) + if commit_future is not None: + assert isinstance(commit_future, list) + commit_future.append(write_future.commit) + await write_future.copy + else: + await write_future.commit + + local_shards = arr_inp.addressable_shards + future_write_state = jax.tree_util.tree_map(_write_array, local_shards) + return await asyncio.gather(*future_write_state) + + +# TODO(rdyro): Remove this function. +def _run_serialization(arrays, tensorstore_specs): + """Legacy serialization of a list of arrays.""" + async def _run_serializer(): + future_writer = jax.tree_util.tree_map(async_serialize, arrays, tensorstore_specs) + return await asyncio.gather(*future_writer) + asyncio.run(_run_serializer()) + + +def estimate_read_memory_footprint(t: ts.TensorStore, + domain: ts.IndexDomain) -> int: + rank = t.rank + num_bytes = t.dtype.numpy_dtype.itemsize + chunk_template = t.chunk_layout.read_chunk_template + if domain is None: + domain = t.domain + origin = domain.origin + shape = domain.shape + chunk_origin = chunk_template.origin + chunk_shape = chunk_template.shape + + # Some TensorStore drivers are not chunked, e.g. the inline 'array' driver. + # For those, instead of returning a near-infinite memory footprint, estimate + # the footprint as the entire shape. + for i in range(rank): + if not chunk_template[i].finite: + return domain.size * num_bytes + + # Otherwise, if we have a chunked driver, estimate based on chunk size. + for i in range(rank): + origin_value = origin[i] + chunk_origin_value = chunk_origin[i] + chunk_size = chunk_shape[i] + lower = origin_value - chunk_origin_value + upper = origin_value + shape[i] - chunk_origin_value + lower_aligned = lower // chunk_size * chunk_size + upper_aligned = -(-upper // chunk_size) * chunk_size + num_bytes *= (upper_aligned - lower_aligned) + + return num_bytes + + +async def async_deserialize( + user_in_sharding: jax.sharding.Sharding | Format, + tensorstore_spec: ts.Spec | dict[str, Any], + global_shape: Sequence[int] | None = None, + dtype=None, + byte_limiter: _LimitInFlightBytes | None = None, + context=_TS_CONTEXT, + chunk_layout=_TS_CHUNK_LAYOUT, + assume_metadata: bool = False, +): + """Main performant deserialization routine for arrays using tensorstore.""" + in_sharding = (user_in_sharding.sharding + if isinstance(user_in_sharding, Format) else user_in_sharding) + if not isinstance(in_sharding, jax.sharding.Sharding): + raise ValueError( + 'sharding passed to deserialization should be specified, concrete and' + f' an instance of `jax.sharding.Sharding`. 
Got {in_sharding}') + dll = (user_in_sharding.device_local_layout + if isinstance(user_in_sharding, Format) else None) + t = await ts.open( + tensorstore_spec, + open=True, + assume_metadata=assume_metadata, + context=context, + chunk_layout=chunk_layout, + ) + shape = t.shape if global_shape is None else global_shape + new_shard_shape = in_sharding.shard_shape(tuple(shape)) + + async def cb(index: array.Index, device: jax.Device): + requested_domain = ts.IndexTransform(input_shape=shape)[index].domain + restricted_domain = t.domain.intersect(requested_domain) + requested_bytes = estimate_read_memory_footprint(t, restricted_domain) + # Limit the bytes read for every shard. + if byte_limiter is not None: + await byte_limiter.wait_for_bytes(requested_bytes) + # This maybe needed because the shape the array was saved with is smaller + # than the requested shape of the array in which it will be reloaded. So + # the extra values will be filled with 0s. + out = np.zeros(new_shard_shape, dtype=t.dtype.numpy_dtype) + await ts.array(out)[ts.d[:].translate_to[requested_domain.origin]][ + restricted_domain].write(t[restricted_domain]) + if dtype is not None: + # Cast while reloading on process to avoid 2 copies on device if the + # casting is done on device. + out = out.astype(dtype) + # Convert to jnp array so that layouts are initialized properly for + # sub-byte dtypes. + # TODO(yashkatariya): This is a band-aid fix. Figure out a better way to + # make this work. + if out.dtype == jnp.int4: + out = jnp.asarray(out) # type: ignore + result = jax.device_put( + out, Format(dll, jax.sharding.SingleDeviceSharding(device))) + if byte_limiter is not None: + # NB: `out` actually might not be ready for garbage collection by the + # time we call release_bytes . Thus peak memory usage still might grow + # beyond what byte_limiter limit suggests it should. The simplest option + # would be to call `result.block_until_ready()`` here. However it + # also comes with ~15-20% perf penalty as we would be waiting for CPU->GPU + # transfer instead of loading data. In the future, if memory pressure + # becomes a problem, we can instead instrument bytelimiter to + # keep track of all in-flight tensors and only block_until_ready, if byte + # limiter hits the limit to get reduced memory usage, without losing + # performance in common use cases. + await byte_limiter.release_bytes(requested_bytes) + return result + + return await _create_async_array_from_callback(tuple(shape), in_sharding, cb) + + +# TODO(rdyro): Remove this function. +def _run_deserialization(shardings: Sequence[jax.sharding.Sharding | Format], + tensorstore_specs: Sequence[dict[str, Any]], + global_shapes: Sequence[array.Shape] | None = None, + dtypes: Sequence[typing.DTypeLike] | None = None, + concurrent_gb: int = 32): + """Legacy deserialization of a list of arrays. Optionally pass global_shapes + and dtypes for type-checking. + """ + concurrent_bytes = concurrent_gb * 10**9 + + async def _run_deserializer(): + # Object should be created once per process. 
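A minimal single-process round trip through the two coroutines above, assuming an absolute local path (illustrative values; multi-process use should go through GlobalAsyncCheckpointManager instead):

import asyncio
import numpy as np
import jax
import jax.numpy as jnp
from jax.experimental.array_serialization import tensorstore_impl as ts_impl

arr = jnp.arange(16.0).reshape(4, 4)
spec = ts_impl.get_tensorstore_spec('/tmp/ckpt/arr0', ocdbt=True, arr=arr)

# With commit_future=None the write is committed before the coroutine returns.
asyncio.run(ts_impl.async_serialize(arr, spec))

sharding = jax.sharding.SingleDeviceSharding(jax.devices()[0])
restored = asyncio.run(ts_impl.async_deserialize(sharding, spec))
np.testing.assert_array_equal(np.asarray(restored), np.asarray(arr))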
+ byte_limiter = _LimitInFlightBytes(concurrent_bytes) + + future_arrays = jax.tree_util.tree_map( + partial(async_deserialize, byte_limiter=byte_limiter), + list(shardings), list(tensorstore_specs), + [None] * len(tensorstore_specs) if global_shapes is None else global_shapes, + [None] * len(tensorstore_specs) if dtypes is None else dtypes) + return await asyncio.gather(*future_arrays) + return asyncio.run(_run_deserializer()) diff --git a/jax/experimental/attrs.py b/jax/experimental/attrs.py index 4e1dc4b8f493..8984b2159d82 100644 --- a/jax/experimental/attrs.py +++ b/jax/experimental/attrs.py @@ -12,225 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import annotations - -from typing import Any, Callable - -from jax._src import core -from jax._src import source_info_util -from jax._src import api_util -from jax._src import linear_util as lu -from jax._src.ad_util import (Zero) -from jax._src.api_util import flatten_fun_nokwargs -from jax._src.interpreters import ad -from jax._src.interpreters import partial_eval as pe -from jax._src.tree_util import (tree_flatten, tree_unflatten, tree_structure, - treedef_tuple) -from jax._src.util import unzip2, safe_map, safe_zip, split_list -from jax._src.dtypes import dtype, float0 - -map, unsafe_map = safe_map, map -zip, unsafe_zip = safe_zip, zip - -JaxVal = Any -Pytree = Any - -register = api_util.register_class_with_attrs - -def jax_getattr(obj: Any, attr: str): - with core.take_current_trace() as t: - return t.process_getattr(obj, attr) - -def jax_setattr(obj: Any, attr: str, val: Pytree): - with core.take_current_trace() as t: - return t.process_setattr(obj, attr, val) - -def _getattr_impl(_, obj, attr): - return getattr(obj, attr) -core.EvalTrace.process_getattr = _getattr_impl - -def _setattr_impl(_, obj, attr, val): - setattr(obj, attr, val) -core.EvalTrace.process_setattr = _setattr_impl - -def _ensure_tracked(trace: pe.DynamicJaxprTrace, obj: Any, attr: str): - frame = trace.frame - - def new_tracer(x): - aval = core.get_aval(x) - tracer = pe.DynamicJaxprTracer(trace, aval, pe.source_info_util.current()) - var = frame.tracer_to_var[id(tracer)] = frame.newvar(aval) - frame.attrs_vars.append(var) - frame.tracers.append(tracer) - return tracer - - if (obj, attr) not in frame.attrs_tracked: - init_val = getattr(obj, attr) - frame.attrs_inits.append(init_val) - init_vals, init_tree = tree_flatten(init_val) - tracers = map(new_tracer, init_vals) - setattr(obj, attr, tree_unflatten(init_tree, tracers)) - frame.attrs_tracked.append((obj, attr)) -pe.DynamicJaxprTrace._ensure_tracked = _ensure_tracked - -def _getattr_staging(trace, obj, attr): - trace._ensure_tracked(obj, attr) - return getattr(obj, attr) -pe.DynamicJaxprTrace.process_getattr = _getattr_staging - -def _setattr_staging(trace, obj, attr, val): - trace._ensure_tracked(obj, attr) - setattr(obj, attr, val) -pe.DynamicJaxprTrace.process_setattr = _setattr_staging - - -def jvp(f, primals, tangents, attr_tangents): - attrs, attr_tangents = unzip2(((o, a), t) for o, a, t in attr_tangents) - attr_primals = tuple(jax_getattr(o, a) for o, a in attrs) - primals_flat, in_tree = tree_flatten((attr_primals, *primals)) - tangents_flat, in_tree_ = tree_flatten((attr_tangents, *tangents)) - if in_tree != in_tree_: raise Exception - dbg = api_util.debug_info("attrs_jvp", f, primals, {}) - f_, out_tree = flatten_fun_nokwargs( - _set_attrs(lu.wrap_init(f, debug_info=dbg), attrs), in_tree) - out_primals_flat, 
out_tangents_flat, tangent_attrs_out = _jvp(f_).call_wrapped( - primals_flat, tangents_flat) - out_primals = tree_unflatten(out_tree(), out_primals_flat) - out_tangents = tree_unflatten(out_tree(), out_tangents_flat) - return out_primals, out_tangents, tangent_attrs_out - -@lu.transformation2 -def _set_attrs(f, attrs, attr_vals, *args): - for (o, a), x in zip(attrs, attr_vals): - jax_setattr(o, a, x) - return f(*args) - -def _jvp(fun: lu.WrappedFun): - return jvpfun2(jvp_subtrace2(fun)) - -@lu.transformation2 -def jvpfun2(f, primals, tangents): - tag = core.TraceTag() - tangents = [Zero.from_primal_value(t) if not isinstance(t, Zero) - and dtype(t) == float0 else t for t in tangents] - ctx = source_info_util.transform_name_stack('jvp') - with ctx: - out_primals, out_tangents, tangent_attrs_out = f(tag, primals, tangents) - return out_primals, out_tangents, tangent_attrs_out - -@lu.transformation2 -def jvp_subtrace2(f, tag, primals, tangents): - with core.take_current_trace() as parent_trace: - trace = ad.JVPTrace(parent_trace, tag) - tag.attrs_tracked = [] # attrs written to - in_tracers = [ad.JVPTracer(trace, x, t) if type(t) is not ad.Zero else x - for x, t in zip(primals, tangents)] - with core.set_current_trace(trace): - ans = f(*in_tracers) - out_primals, out_tangents = unzip2(map(trace.to_primal_tangent_pair, ans)) - tangent_attrs_out = [] - for (obj, name) in tag.attrs_tracked: - primal, tangent = trace.to_primal_tangent_pair(jax_getattr(obj, name)) - jax_setattr(obj, name, primal) - if type(tangent) is not ad.Zero: - tangent_attrs_out.append((obj, name, tangent)) - del tag.attrs_tracked - return out_primals, out_tangents, tangent_attrs_out - -def _setattr_jvp(trace, obj, attr, maybe_tracer): - primal, tangent = trace.to_primal_tangent_pair(maybe_tracer) - if isinstance(tangent, ad.Zero): - return setattr(obj, attr, primal) - if (obj, attr) not in trace.tag.attrs_tracked: - trace.tag.attrs_tracked.append((obj, attr)) - return setattr(obj, attr, ad.JVPTracer(trace, primal, tangent)) -ad.JVPTrace.process_setattr = _setattr_jvp - -def _getattr_jvp(trace, obj, attr): - return getattr(obj, attr) -ad.JVPTrace.process_getattr = _getattr_jvp - -ad.LinearizeTrace.process_setattr = _setattr_jvp -ad.LinearizeTrace.process_getattr = _getattr_jvp - -def linearize(f: Callable, *primals, attrs: list[tuple[Any, str]] = []): - attr_primals = [jax_getattr(o, a) for o, a in attrs] - attr_avals = [core.get_aval(p) for p in attr_primals] - primals_flat, in_tree = tree_flatten(primals) - tree = treedef_tuple((tree_structure(attr_primals), *in_tree.children())) - dbg = api_util.debug_info("attrs linearize", f, primals, {}) - f_, out_tree = flatten_fun_nokwargs( - _set_attrs(lu.wrap_init(f, debug_info=dbg), attrs), tree) - primal_out, out_pvals, jaxpr, consts, attrs_out = _linearize( - f_, *attr_primals, *primals_flat) - f_lin = _lin_wrap(jaxpr, consts, out_pvals, attr_avals, (in_tree, out_tree()), - attrs, attrs_out) - return tree_unflatten(out_tree(), primal_out), f_lin - -def _linearize(traceable: lu.WrappedFun, *primals): - jvpfun, attrs = _split_attrs(_jvp(traceable)) - in_pvals = (tuple(pe.PartialVal.known(p) for p in primals) - + tuple(pe.PartialVal.unknown(core.get_aval(p).to_tangent_aval()) - for p in primals)) - _, in_tree = tree_flatten((primals, primals)) - jvpfun_flat, out_tree = flatten_fun_nokwargs(jvpfun, in_tree) - jaxpr, out_pvals, consts = pe.trace_to_jaxpr_nounits(jvpfun_flat, in_pvals) - out_primals_pvals, out_tangents_pvals, out_tangent_attr_pvals = \ - tree_unflatten(out_tree(), 
out_pvals) - out_primals_consts = [pval.get_known() for pval in out_primals_pvals] - return (out_primals_consts, [*out_tangents_pvals, *out_tangent_attr_pvals], - jaxpr, consts, attrs()) - -@lu.transformation_with_aux2 -def _split_attrs(f, store, *args, **kwargs): - primals, tangents, tangent_attrs = f(*args, **kwargs) - attrs, tangent_attr_vals = unzip2(((o, a), t) for o, a, t in tangent_attrs) - store.store(attrs) - return primals, tangents, tangent_attr_vals - -def _lin_wrap(jaxpr, consts, out_pvals, attr_avals, io_tree, in_attrs, out_attrs): - in_tree, out_tree = io_tree - def f_lin(*tangents, attr_tangents): - if set(attr_tangents) - set(in_attrs): raise Exception - tangents_, in_tree_ = tree_flatten(tangents) - assert in_tree == in_tree_ - attr_tangents_ = [attr_tangents.get(a, ad.Zero(aval)) - for a, aval in zip(in_attrs, attr_avals)] - out = core.eval_jaxpr(jaxpr, consts, *attr_tangents_, *tangents_) - out_ = iter(out) - out = [p.get_known() if p.is_known() else next(out_) for p in out_pvals] - assert next(out_, None) is None - tangents_out, attr_tangents_out = split_list(out, [len(out)-len(out_attrs)]) - out_ct = tree_unflatten(out_tree, tangents_out) - return out_ct, dict(zip(out_attrs, attr_tangents_out)) - return f_lin - - -def vjp(f, *primals, attrs: list[tuple[Any, str]] = []): - attr_primals = [jax_getattr(o, a) for o, a in attrs] - primals_flat, in_tree = tree_flatten(primals) - tree = treedef_tuple((tree_structure(attr_primals), *in_tree.children())) - dbg = api_util.debug_info("attrs vjp", f, primals, {}) - f_, out_tree = flatten_fun_nokwargs( - _set_attrs(lu.wrap_init(f, debug_info=dbg), attrs), tree) - primal_out, out_pvals, jaxpr, consts, attrs_out = _linearize( - f_, *attr_primals, *primals_flat) - attr_avals = [core.get_aval(jax_getattr(o, a)).to_tangent_aval() - for o, a in attrs_out] - f_vjp = _vjp_wrap(jaxpr, consts, out_pvals, attr_avals, (in_tree, out_tree()), - attrs, attrs_out) - return tree_unflatten(out_tree(), primal_out), f_vjp - -def _vjp_wrap(jaxpr, consts, out_pvals, attr_avals, io_tree, in_attrs, out_attrs): - in_tree, out_tree = io_tree - dummies = [ad.UndefinedPrimal(v.aval) for v in jaxpr.invars] - def f_vjp(out_ct, *, attr_cotangents: dict[tuple[Any, str], JaxVal] = {}): - out_cts, out_tree_ = tree_flatten(out_ct) - assert out_tree == out_tree_ - attr_cts = [attr_cotangents.get(a, ad.Zero(aval)) - for a, aval in zip(out_attrs, attr_avals)] - out = ad.backward_pass(jaxpr, (), consts, dummies, (*out_cts, *attr_cts)) - in_attr_bars, arg_cts = split_list(out, [len(in_attrs)]) - args_ct = tree_unflatten(in_tree, map(ad.instantiate_zeros, arg_cts)) - return args_ct, dict(zip(in_attrs, in_attr_bars)) - return f_vjp +from jax._src.attrs import ( + jax_setattr as jax_setattr, + jax_getattr as jax_getattr, + jax_appendattr as jax_appendattr, + Box as Box, + List as List, +) diff --git a/jax/experimental/buffer_callback.py b/jax/experimental/buffer_callback.py new file mode 100644 index 000000000000..f919cfa10208 --- /dev/null +++ b/jax/experimental/buffer_callback.py @@ -0,0 +1,20 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from jax._src.buffer_callback import ( + Buffer as Buffer, + ExecutionContext as ExecutionContext, + ExecutionStage as ExecutionStage, + buffer_callback as buffer_callback, +) diff --git a/jax/experimental/colocated_python/api.py b/jax/experimental/colocated_python/api.py index b855bba48abb..363a2987cb9c 100644 --- a/jax/experimental/colocated_python/api.py +++ b/jax/experimental/colocated_python/api.py @@ -16,18 +16,50 @@ from __future__ import annotations import collections -from typing import Any, Callable, Sequence, Type +from typing import Any, overload +from collections.abc import Callable, Sequence import jax from jax._src import api_util +from jax._src import util from jax.experimental.colocated_python.func import make_callable from jax.experimental.colocated_python.obj import wrap_class +import numpy as np +@overload def colocated_cpu_devices( - devices: Sequence[jax.Device], + devices_or_mesh: Sequence[jax.Device], +) -> Sequence[jax.Device]: + ... + + +@overload +def colocated_cpu_devices( + devices_or_mesh: jax.sharding.Mesh, +) -> jax.sharding.Mesh: + ... + + +def colocated_cpu_devices(devices_or_mesh): + """Finds devices or a mesh that has CPU devices colocated with the given devices or mesh.""" + if isinstance(devices_or_mesh, jax.sharding.Mesh): + return _colocated_cpu_mesh_cached(devices_or_mesh) + + if not isinstance(devices_or_mesh, tuple): + devices_or_mesh = tuple(devices_or_mesh) + try: + return _colocated_cpu_devices_cached(devices_or_mesh) + except (ValueError, AttributeError): + return _colocated_cpu_devices_cached_fallback_to_cpu_backend( + devices_or_mesh + ) + + +@util.cache(max_size=1024, trace_context_in_key=False) +def _colocated_cpu_devices_cached( + devices: tuple[jax.Device, ...], ) -> Sequence[jax.Device]: - """Finds CPU devices colocated with the given devices.""" cpu_devices_by_colocation_id = collections.defaultdict(list) for device in devices[0].client._get_all_devices(): # pylint: disable=protected-access if device.device_kind == "cpu": @@ -49,13 +81,43 @@ def colocated_cpu_devices( return colocated_cpu_devices -def colocated_python(fun: Callable[..., Any]) -> Callable[..., Any]: +@util.cache(max_size=1024, trace_context_in_key=False) +def _colocated_cpu_devices_cached_fallback_to_cpu_backend( + devices: tuple[jax.Device, ...], +) -> Sequence[jax.Device]: + # PjRt-IFRT currently defines CPU devices by using a CPU backend. + # TODO(hyeontaek): Remove this fallback path once a PjRt-IFRT backend defines + # CPU devices by its own instead of using a separate CPU backend. 
+ cpu_backend_devices = jax.local_devices(backend="cpu") + device_index_map = {device.id: i for i, device in enumerate(jax.devices())} + + available_devices = devices[: min(len(cpu_backend_devices), len(devices))] + return [ + cpu_backend_devices[device_index_map[d.id]] for d in available_devices + ] + + +@util.cache(max_size=1024, trace_context_in_key=False) +def _colocated_cpu_mesh_cached(mesh: jax.sharding.Mesh) -> jax.sharding.Mesh: + """Returns a CPU mesh that is similar to the given mesh but has colocated CPU devices.""" + # Finding colocated CPU devices reuses the cache of `colocated_cpu_devices` + # called with devices. `_colocated_cpu_mesh` itself is also cached to avoid + # creating a new `Mesh` object repeatedly. + flat_cpu_devices = colocated_cpu_devices(tuple(mesh.devices.flat)) + return jax.sharding.Mesh( + np.array(flat_cpu_devices).reshape(mesh.axis_sizes), + mesh.axis_names, + axis_types=mesh.axis_types, + ) + + +def colocated_python(fun: Callable[..., Any]): """Executes the given Python function on the same devices as the arguments.""" return make_callable( fun, api_util.fun_sourceinfo(fun), api_util.fun_signature(fun) ) -def colocated_python_class(cls: Type[object]) -> Type[object]: +def colocated_python_class(cls: type[object]) -> type[object]: """Executes the given Python class methods on the same devices as the arguments.""" return wrap_class(cls, api_util.fun_sourceinfo(cls)) diff --git a/jax/experimental/colocated_python/func.py b/jax/experimental/colocated_python/func.py index effca1fe77b7..9ad84c7e06ad 100644 --- a/jax/experimental/colocated_python/func.py +++ b/jax/experimental/colocated_python/func.py @@ -19,7 +19,8 @@ import inspect import random import threading -from typing import Any, Callable, Sequence +from typing import Any +from collections.abc import Callable, Sequence import jax from jax._src import api @@ -65,7 +66,7 @@ def update( out_specs_treedef: tree_util.PyTreeDef | None = None, out_specs_leaves: tuple[api.ShapeDtypeStruct, ...] | None = None, devices: Sequence[jax.Device] | xc.DeviceList | None = None, - ) -> Any: + ): """Creates a new specialization with overrides.""" if in_specs_treedef is None: in_specs_treedef = self.in_specs_treedef @@ -169,7 +170,7 @@ def _compile_to_executable( program, compile_options ) out_handlers = pxla.global_avals_to_results_handler( - out_sdss, out_shardings, committed=True + out_sdss, out_shardings, committed=True # type: ignore ).handlers def call(*args, **kwargs): @@ -234,7 +235,7 @@ def _make_pop_result_fun( out_specs_treedef = specialization.out_specs_treedef - def lowered_fun() -> Any: + def lowered_fun(): result_leaves = func_backend.SINGLETON_RESULT_STORE.pop(uid) return tree_util.tree_unflatten(out_specs_treedef, result_leaves) @@ -279,7 +280,7 @@ def _make_async_execution_fun( ) -@jax.util.cache(max_size=None) +@jax._src.util.cache(max_size=None) def _get_specialized_func( info: FunctionInfo, specialization: Specialization ) -> Callable[..., Any]: @@ -294,7 +295,7 @@ def _get_specialized_func( # Asynchronous execution function that has known output_specs. 
async_execution_func = None - def specialized_func(*args, **kwargs) -> Any: + def specialized_func(*args, **kwargs): """Specialized function to be executed with given args and kwargs.""" nonlocal specialization, async_execution_func with mutex: @@ -356,24 +357,21 @@ def make_callable( fun: Callable[..., Any], fun_sourceinfo: str | None, fun_signature: inspect.Signature | None, -) -> Callable[..., Any]: +): """Makes a colocated Python callable.""" return _make_callable( FunctionInfo(fun, fun_sourceinfo, fun_signature), Specialization() ) -def _make_callable( - info: FunctionInfo, - specialization: Specialization, -) -> Callable[..., Any]: +def _make_callable(info: FunctionInfo, specialization: Specialization): """Internal implementation of make_callable.""" def specialize( in_specs: ShapeDtypeStructTree | None = None, out_specs_fn: Callable[..., ShapeDtypeStructTree] | None = None, devices: Sequence[jax.Device] | None = None, - ) -> Callable[..., Any]: + ): """Returns a colocated Python callable with extra specialization. Args: @@ -410,7 +408,7 @@ def specialize( ) @api_boundary - def __call__(*args, **kwargs) -> Any: + def __call__(*args, **kwargs): """Executes the function. If the output specs are not known, the very first execution will be diff --git a/jax/experimental/colocated_python/func_backend.py b/jax/experimental/colocated_python/func_backend.py index aa514015004d..4f1443da4b17 100644 --- a/jax/experimental/colocated_python/func_backend.py +++ b/jax/experimental/colocated_python/func_backend.py @@ -16,7 +16,7 @@ from __future__ import annotations import threading -from typing import Sequence +from collections.abc import Sequence import jax diff --git a/jax/experimental/colocated_python/obj.py b/jax/experimental/colocated_python/obj.py index b1e7a0b1eade..b962b82525fd 100644 --- a/jax/experimental/colocated_python/obj.py +++ b/jax/experimental/colocated_python/obj.py @@ -18,7 +18,8 @@ import inspect import random import threading -from typing import Any, Callable, Type +from typing import Any +from collections.abc import Callable import jax from jax._src import api_util @@ -58,7 +59,7 @@ def pop_instance(self, uid: int) -> set[jax.Device]: SINGLETON_INSTANCE_REGISTRY = _InstanceRegistry() -@jax.util.cache(max_size=4096) +@jax._src.util.cache(max_size=4096) def _update_instance_devices( uid: int, shardings: tuple[jax.sharding.Sharding, ...] 
) -> None: @@ -70,7 +71,7 @@ def _update_instance_devices( def _make_method( - cls: Type[object], + cls: type[object], cls_sourceinfo: str | None, uid: int, init_args: tuple[Any, ...], @@ -114,9 +115,9 @@ def method_wrapper(*args, **kwargs): def wrap_class( - cls: Type[object], + cls: type[object], cls_sourceinfo: str | None, -) -> Type[object]: +) -> type[object]: class WrappedClass: @wraps(cls.__init__) diff --git a/jax/experimental/colocated_python/obj_backend.py b/jax/experimental/colocated_python/obj_backend.py index ffa04a007818..eb3b2c4049d9 100644 --- a/jax/experimental/colocated_python/obj_backend.py +++ b/jax/experimental/colocated_python/obj_backend.py @@ -17,7 +17,8 @@ import dataclasses import threading -from typing import Any, Callable +from typing import Any +from collections.abc import Callable @dataclasses.dataclass(frozen=True) diff --git a/jax/experimental/colocated_python/serialization.py b/jax/experimental/colocated_python/serialization.py index 1ca29ab12660..83a12277a872 100644 --- a/jax/experimental/colocated_python/serialization.py +++ b/jax/experimental/colocated_python/serialization.py @@ -19,7 +19,8 @@ import collections import functools import io -from typing import Any, Callable, Sequence +from typing import Any +from collections.abc import Callable, Sequence try: import cloudpickle # type: ignore[import-not-found] @@ -35,7 +36,7 @@ DeviceList = xc.DeviceList -@jax.util.cache(max_size=None) +@jax._src.util.cache(max_size=None) def _get_cpu_device_map() -> dict[int, jax.Device]: """Returns a map from a device id to a matching device.""" cpu_device_map: dict[int, jax.Device] = {} @@ -99,6 +100,20 @@ def make_mesh( return make_mesh, (mesh_device_ids, mesh.axis_names) +def _reduce_named_sharding( + sharding: jax.sharding.NamedSharding, +) -> tuple[Callable[..., jax.sharding.NamedSharding], Any]: + # TODO(hyeontaek): Use `legacy_memory_space_behavior=false` for the + # CPU backend's `xla::CpuClientOptions`, and preserve the memory + # kind across serialization. + # Colocated Python implicitly relies on the default memory kind + # being reset to the default memory space when deserializing. + def _make_named_sharding(mesh, spec): + return jax.sharding.NamedSharding(mesh, spec) + + return _make_named_sharding, (sharding.mesh, sharding.spec) + + def _reduce_device_list( device_list: DeviceList, ) -> tuple[Callable[..., DeviceList], Any]: @@ -149,6 +164,7 @@ def _serialize(obj: Any) -> bytes: class _CustomPickler(cloudpickle.Pickler): dispatch_table = collections.ChainMap( {jax.sharding.Mesh: _reduce_mesh}, + {jax.sharding.NamedSharding: _reduce_named_sharding}, {DeviceList: _reduce_device_list}, {jax.sharding.SingleDeviceSharding: _reduce_single_device_sharding}, cloudpickle.CloudPickler.dispatch_table, # pylint: disable=attribute-error @@ -201,7 +217,7 @@ def _serialize_specs( if not hasattr(np.dtypes, "StringDType"): raise TypeError( "Serializing Colocated Python requires StringDType. Please use" - " numpy to 2.0.0 or later, or explicityly provide an output spec" + " numpy to 2.0.0 or later, or explicitly provide an output spec" " function." ) diff --git a/jax/experimental/host_callback.py b/jax/experimental/host_callback.py index 7d60f62e230f..4dcd9f66a961 100644 --- a/jax/experimental/host_callback.py +++ b/jax/experimental/host_callback.py @@ -16,7 +16,7 @@ .. warning:: The host_callback APIs are deprecated as of March 20, 2024. 
The functionality is subsumed by the - `new JAX external callbacks `_ + `new JAX external callbacks `_ See https://github.com/jax-ml/jax/issues/20385. """ diff --git a/jax/experimental/jax2tf/README.md b/jax/experimental/jax2tf/README.md index 0d827fbcc7a5..ac9829d69006 100644 --- a/jax/experimental/jax2tf/README.md +++ b/jax/experimental/jax2tf/README.md @@ -138,7 +138,7 @@ f_tf_graph = tf.function(f_tf, autograph=False) ``` Note that when using the default native serialization, the target JAX function -must be jittable (see [JAX - The Sharp Bits](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html)). +must be jittable (see [JAX - The Sharp Bits](https://docs.jax.dev/en/latest/notebooks/Common_Gotchas_in_JAX.html)). In the native serialization mode, under TensorFlow eager the whole JAX function executes as one op. @@ -461,7 +461,7 @@ presence of shape polymorphism, some dimensions may be dimension variables. The `polymorphic_shapes` parameter must be either `None`, or a pytree of shape specifiers corresponding to the pytree of arguments. (A value `None` for `polymorphic_shapes` is equivalent to a list of `None`. -See [how optional parameters are matched to arguments](https://jax.readthedocs.io/en/latest/pytrees.html#applying-optional-parameters-to-pytrees).) +See [how optional parameters are matched to arguments](https://docs.jax.dev/en/latest/pytrees.html#applying-optional-parameters-to-pytrees).) A shape specifier is combined with a `TensorSpec` as follows: * A shape specifier of `None` means that the shape is given @@ -568,6 +568,7 @@ because the shape abstraction that JAX tracing uses is given by the actual arguments are more specific and would actually work. Also, + ```python jax2tf.convert(lambda x: jnp.matmul(x, x), polymorphic_shapes=["(v, 4)"])(np.ones((4, 4))) @@ -808,6 +809,7 @@ TypeError: add got incompatible shapes for broadcasting: (a,), (floordiv(b, 2),) ``` You can fix this by adding a constraint: + ```python jax2tf.convert(lambda x, y: x + y[:y.shape[0] // 2], polymorphic_shapes=("a", "b"), @@ -826,19 +828,19 @@ For example, the following code will fail because `a1` and `a2` use different scopes (created by `export.symbolic_shape`): -````python +```python a1, = export.symbolic_shape("a,") a2, = export.symbolic_shape("a,", constraints=("a >= 8",)) a1 + a2 -```` +``` The symbolic expressions that originate from a single call to `export.symbolic_shape` share a scope and can be mixed up in arithmetic operations. The result would also share the same scope. -You can re-use scopes: +You can reuse scopes: ```python a, = export.symbolic_shape("a,", constraints=("a >= 8",)) @@ -1005,6 +1007,8 @@ We list here a history of the serialization version numbers: available in JAX since October 20th, 2023 (JAX 0.4.20), and the default since February 1st, 2024 (JAX 0.4.24). This is the only supported version as of 27th of March, 2024. + * Version 10 propagate the `jax.config.use_shardy_partitioner` value to + XlaCallModule. ## Known issues @@ -1024,7 +1028,7 @@ always behaves like the JAX function. JAX interprets the type of Python scalars differently based on `JAX_ENABLE_X64` flag. (See -[JAX - The Sharp Bits: Double (64bit) precision](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision).) +[JAX - The Sharp Bits: Double (64bit) precision](https://docs.jax.dev/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision).) 
In the default configuration, the flag is unset, and JAX interprets Python constants as 32-bit, e.g., the type of `3.14` is `float32`. This is also what @@ -1086,7 +1090,7 @@ Applies to both native and non-native serialization. `jax2tf` can lower functions with arguments and results that are nested collections (tuples, lists, dictionaries) of numeric values or JAX arrays -([pytrees](https://jax.readthedocs.io/en/latest/pytrees.html)). The +([pytrees](https://docs.jax.dev/en/latest/pytrees.html)). The resulting TensorFlow function will take the same kind of arguments except the leaves can be numeric values or TensorFlow tensors (`tf.Tensor`, `tf.TensorSpec`, `tf.Variable`). @@ -1285,7 +1289,7 @@ per PRNG operation. The "unsafe" part is that it doesn't guarantee determinism across JAX/XLA versions, and the quality of random streams it generates from different keys is less well understood. Nevertheless, this should be fine for most inference/serving cases. -See more details in the [JAX PRNG documentation](https://jax.readthedocs.io/en/latest/jax.random.html?highlight=unsafe_rbg#advanced-rng-configuration). +See more details in the [JAX PRNG documentation](https://docs.jax.dev/en/latest/jax.random.html?highlight=unsafe_rbg#advanced-rng-configuration). ### SavedModel supports only first-order gradients diff --git a/jax/experimental/jax2tf/call_tf.py b/jax/experimental/jax2tf/call_tf.py index 98c1c20cd6e5..2aadc3a9d512 100644 --- a/jax/experimental/jax2tf/call_tf.py +++ b/jax/experimental/jax2tf/call_tf.py @@ -40,12 +40,13 @@ from jax._src import core from jax._src import effects from jax._src import util -from jax._src.lib import xla_client +from jax._src.lib import _jax from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import func as func_dialect from jax._src.lib.mlir.dialects import hlo from jax.experimental.jax2tf import jax2tf as jax2tf_internal from jax._src.interpreters import mlir +import ml_dtypes import numpy as np import tensorflow as tf @@ -345,8 +346,7 @@ def _arg_jax_to_tf(arg_jax): if (isinstance(arg_jax, jax.Array) and list(arg_jax.devices())[0].platform in _DLPACK_PLATFORMS and arg_jax.dtype.type in dlpack.SUPPORTED_DTYPES): - arg_dlpack = jax.dlpack.to_dlpack(arg_jax) - return tf.experimental.dlpack.from_dlpack(arg_dlpack) + return tf.experimental.dlpack.from_dlpack(arg_jax.__dlpack__()) # The following avoids copies to the host on CPU, always for Array # and even for ndarray if they are sufficiently aligned. # TODO(necula): on TPU this copies to the host! 
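For context, the zero-copy handoff introduced above can be sketched in isolation as follows. This is only an illustration (it assumes TensorFlow is installed and a `jax.Array` on a DLPack-capable device with a supported dtype), not code taken from this change:

```python
import jax.numpy as jnp
import tensorflow as tf

x = jnp.arange(4.0)  # a jax.Array on the default device
# Hand the buffer to TensorFlow through the array's own __dlpack__ capsule,
# instead of going through an explicit jax.dlpack.to_dlpack step as before.
x_tf = tf.experimental.dlpack.from_dlpack(x.__dlpack__())
```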
@@ -468,6 +468,47 @@ def is_fully_known_shape(s): call_tf_p.def_effectful_abstract_eval(_call_tf_abstract_eval) +def _mlir_type_to_numpy_dtype(type: ir.Type) -> np.dtype: + """Converts an MLIR scalar type to a NumPy dtype.""" + + if ir.IntegerType.isinstance(type): + type = ir.IntegerType(type) + width = type.width + if width == 1: + return np.dtype(np.bool_) + elif width == 8: + return np.dtype(np.uint8 if type.is_unsigned else np.int8) + elif width == 16: + return np.dtype(np.uint16 if type.is_unsigned else np.int16) + elif width == 32: + return np.dtype(np.uint32 if type.is_unsigned else np.int32) + elif width == 64: + return np.dtype(np.uint64 if type.is_unsigned else np.int64) + else: + raise ValueError(f"Unsupported integer width: {width}") + + elif ir.F16Type.isinstance(type): + return np.dtype(np.float16) + elif ir.F32Type.isinstance(type): + return np.dtype(np.float32) + elif ir.F64Type.isinstance(type): + return np.dtype(np.float64) + elif ir.BF16Type.isinstance(type): + return np.dtype(ml_dtypes.bfloat16) + + elif ir.ComplexType.isinstance(type): + element_type = ir.ComplexType(type).element_type + if ir.F32Type.isinstance(element_type): + return np.dtype(np.complex64) + elif ir.F64Type.isinstance(element_type): + return np.dtype(np.complex128) + else: + raise ValueError(f"Unsupported complex element type: {element_type}") + + else: + raise TypeError(f"Unsupported MLIR type for NumPy conversion: {type}") + + def _call_tf_lowering( ctx: mlir.LoweringRuleContext, *args_op, @@ -555,33 +596,8 @@ def convert_to_spec(x): "\n\nCaught TensorFlow exception: " + str(e)) raise ValueError(msg) from e - xla_comp = xla_client.XlaComputation(func_tf_hlo) - - # Canonicalize the results; e.g., makes them x32 if JAX is in 32-bit mode - def canonical_res_aval(res_shape: xla_client.Shape) -> core.ShapedArray: - if not res_shape.is_static(): - msg = ("Compiled TensorFlow function has dynamic output shape " + - f"{res_shape}. call_tf can used " + - "in a staged context (under jax.jit, lax.scan, etc.) only with " + - "compilable functions with static output shapes. " + - "See https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#limitations-of-call_tf for a discussion.") - raise ValueError(msg) - - res_dtype = res_shape.numpy_dtype() - jax_res_dtype = dtypes.canonicalize_dtype(res_dtype) - return core.ShapedArray(res_shape.dimensions(), jax_res_dtype) - - result_shape = xla_comp.program_shape().result_shape() - if not result_shape.is_tuple(): - # TF does not wrap singletons as tuples, but JAX expects tuples because - # call_tf is a multiple_results primitive. 
- result_shapes = (result_shape,) - else: - result_shapes = result_shape.tuple_shapes() # type: ignore - - result_avals = tuple(map(canonical_res_aval, result_shapes)) - - submodule = mlir.xla_computation_to_mlir_module(xla_comp) + stablehlo = _jax.mlir.hlo_to_stablehlo(func_tf_hlo) + submodule = ir.Module.parse(stablehlo) symtab = ir.SymbolTable(submodule.operation) callee_result_types = symtab["main"].type.results fn = mlir.merge_mlir_modules(ctx.module_context.module, @@ -600,10 +616,26 @@ def canonical_res_aval(res_shape: xla_client.Shape) -> core.ShapedArray: ) outputs = [] - for op, res_aval, res_shape in zip(flat_results, result_avals, - result_shapes): - if res_aval.dtype != res_shape.numpy_dtype(): - op = hlo.ConvertOp(mlir.aval_to_ir_type(res_aval), op).result + for op, res_type in zip(flat_results, callee_result_types): + if not res_type.has_static_shape: + msg = ( + "Compiled TensorFlow function has dynamic output shape " + + f"{res_type}. call_tf can used in a staged context (under jax.jit," + " lax.scan, etc.) only with compilable functions with static" + " output shapes. See" + " https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#limitations-of-call_tf" + " for a discussion." + ) + raise ValueError(msg) + + res_dtype = _mlir_type_to_numpy_dtype(res_type.element_type) + # Canonicalize the results; e.g., makes them x32 if JAX is in 32-bit mode + jax_res_dtype = dtypes.canonicalize_dtype(res_dtype) + if res_dtype != jax_res_dtype: + op = hlo.ConvertOp( + mlir.aval_to_ir_type(core.ShapedArray(res_type.shape, jax_res_dtype)), + op, + ).result outputs.append(op) return outputs diff --git a/jax/experimental/jax2tf/g3doc/no_xla_limitations.md b/jax/experimental/jax2tf/g3doc/no_xla_limitations.md index 24a1d62ee67e..af092a218805 100644 --- a/jax/experimental/jax2tf/g3doc/no_xla_limitations.md +++ b/jax/experimental/jax2tf/g3doc/no_xla_limitations.md @@ -24,7 +24,7 @@ partial support. For a detailed description of these XLA ops, please see the [XLA Operation Semantics documentation](https://www.tensorflow.org/xla/operation_semantics). -| XLA ops ([documentation](https://www.tensorflow.org/xla/operation_semantics)) | JAX primitive(s) ([documentation](https://jax.readthedocs.io/en/latest/jax.lax.html)) | Supported | +| XLA ops ([documentation](https://www.tensorflow.org/xla/operation_semantics)) | JAX primitive(s) ([documentation](https://docs.jax.dev/en/latest/jax.lax.html)) | Supported | | ------- | ---------------- | ------- | | XlaDot | `lax.dot_general` | Full | | XlaDynamicSlice | `lax.dynamic_slice` | Full | @@ -47,7 +47,7 @@ support and which not. ### XlaConv JAX convolutions are done using -[`lax.conv_general_dilated`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv_general_dilated.html). +[`lax.conv_general_dilated`](https://docs.jax.dev/en/latest/_autosummary/jax.lax.conv_general_dilated.html). ``` lax.conv_general_dilated( @@ -88,7 +88,7 @@ instance, parallelization primitives `vmap` and `pmap` use gather to specify a batch dimension, and it is used for slices or multidimensional indexing as well, e.g. `x[0, 1]`, `x[:, :1]`, or `x[[0], [1]]`. -The signature of [`lax.gather`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.gather.html#jax.lax.gather) +The signature of [`lax.gather`](https://docs.jax.dev/en/latest/_autosummary/jax.lax.gather.html#jax.lax.gather) is as follows: ``` @@ -128,7 +128,7 @@ All other cases of `lax.gather` are currently not supported. 
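As a small, self-contained illustration (not taken from the change itself) of the indexing forms the text above mentions, where `x` is a hypothetical array:

```python
import jax.numpy as jnp

x = jnp.arange(12.0).reshape(3, 4)
a = x[0, 1]    # multidimensional indexing
b = x[:, :1]   # slicing
c = x[jnp.array([0]), jnp.array([1])]  # the x[[0], [1]] form, written with explicit index arrays
# Per the text above, indexing patterns like these are expressed through
# lax.gather during tracing.
```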
### XlaReduceWindow -The signature of [`lax.reduce_window`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.reduce_window.html) +The signature of [`lax.reduce_window`](https://docs.jax.dev/en/latest/_autosummary/jax.lax.reduce_window.html) is as follows: ``` diff --git a/jax/experimental/jax2tf/impl_no_xla.py b/jax/experimental/jax2tf/impl_no_xla.py index 644c3324b4e2..70a6dccf8915 100644 --- a/jax/experimental/jax2tf/impl_no_xla.py +++ b/jax/experimental/jax2tf/impl_no_xla.py @@ -659,7 +659,7 @@ def tf_pool(inputs, pooling_type): raise NotImplementedError( f"TODO: use tf.nn.pool with dynamic shapes¨{window_dimensions=} " f" {window_strides=} {dilations=}") - # tf.nn.pool() currently does not suport tf.int32 and so we cast back and + # tf.nn.pool() currently does not support tf.int32 and so we cast back and # forth in order to be able to convert. if (inputs.dtype in [tf.int16, tf.int32]) and computation_name == "add": original_dtype = inputs.dtype diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py index 7f98ce433815..088878bfbd04 100644 --- a/jax/experimental/jax2tf/jax2tf.py +++ b/jax/experimental/jax2tf/jax2tf.py @@ -272,7 +272,7 @@ def convert(fun_jax: Callable, should be `None` (monomorphic argument), or a Python object with the same pytree structure as the argument. See [how optional parameters are matched to - arguments](https://jax.readthedocs.io/en/latest/pytrees.html#applying-optional-parameters-to-pytrees). + arguments](https://docs.jax.dev/en/latest/pytrees.html#applying-optional-parameters-to-pytrees). A shape specification for an array argument should be an object `PolyShape(dim0, dim1, ..., dimn)` @@ -944,6 +944,11 @@ def _convert_value(val, aval): if DisabledSafetyCheck.platform() in exported.disabled_safety_checks: call_module_attrs["platforms"] = () # No platform checking + if version >= 10: + call_module_attrs["use_shardy_partitioner"] = ( + config.use_shardy_partitioner.value + ) + if logging.vlog_is_on(3): # We already logged the MLIR module when we exported it. 
logging.vlog(3, "XlaCallModule %s", str(call_module_attrs)) @@ -1521,6 +1526,8 @@ def _unexpected_primitive(p: core.Primitive, *args, **kwargs): "pmax_p", "pmin", "ppermute", + "psend", + "precv", "psum", "psum2", "pbroadcast", @@ -1529,6 +1536,7 @@ def _unexpected_primitive(p: core.Primitive, *args, **kwargs): "reduce_scatter", "axis_index", "all_gather", + "all_gather_invariant", "lu_pivots_to_permutation", "xla_pmap", "geqrf", @@ -1551,6 +1559,7 @@ def _unexpected_primitive(p: core.Primitive, *args, **kwargs): "bitcast", "repeat", "roll", + "with_memory_space_constraint", # temporary pending cudnn fix, see https://github.com/jax-ml/jax/pull/23740 "bias_fwd", "bias_bwd", @@ -1666,17 +1675,18 @@ def _integer_pow(x, *, y: int, _in_avals: Sequence[core.ShapedArray], tf_impl_with_avals[lax.integer_pow_p] = _integer_pow -tf_impl[lax.exp_p] = tf.math.exp -tf_impl[lax_internal.exp2_p] = lambda x: \ - tf.math.exp(tf.math.multiply(tf.math.log(tf.constant(2, x.dtype)), x)) -tf_impl[lax.expm1_p] = tf.math.expm1 -tf_impl[lax.log_p] = tf.math.log -tf_impl[lax.log1p_p] = tf.math.log1p -tf_impl[lax.tan_p] = tf.math.tan -tf_impl[lax.tanh_p] = tf.math.tanh -tf_impl[lax.sin_p] = tf.math.sin +tf_impl[lax.exp_p] = lambda x, accuracy: tf.math.exp(x) +tf_impl[lax_internal.exp2_p] = lambda x, accuracy: tf.math.exp( + tf.math.multiply(tf.math.log(tf.constant(2, x.dtype)), x) +) +tf_impl[lax.expm1_p] = lambda x, accuracy: tf.math.expm1(x) +tf_impl[lax.log_p] = lambda x, accuracy: tf.math.log(x) +tf_impl[lax.log1p_p] = lambda x, accuracy: tf.math.log1p(x) +tf_impl[lax.tan_p] = lambda x, accuracy: tf.math.tan(x) +tf_impl[lax.tanh_p] = lambda x, accuracy: tf.math.tanh(x) +tf_impl[lax.sin_p] = lambda x, accuracy: tf.math.sin(x) tf_impl[lax.sinh_p] = tf.math.sinh -tf_impl[lax.cos_p] = tf.math.cos +tf_impl[lax.cos_p] = lambda x, accuracy: tf.math.cos(x) tf_impl[lax.cosh_p] = tf.math.cosh tf_impl_with_avals[lax.atan_p] = _convert_jax_impl( lax_internal.atan_impl, multiple_results=False) @@ -1706,11 +1716,11 @@ def _atan2(y, x, **kwargs): tf_impl[lax.asin_p] = tf.math.asin tf_impl[lax.acos_p] = tf.math.acos -tf_impl[lax.sqrt_p] = tf.math.sqrt +tf_impl[lax.sqrt_p] = lambda x, accuracy: tf.math.sqrt(x) tf_impl[lax.square_p] = tf.math.square -tf_impl[lax.rsqrt_p] = tf.math.rsqrt +tf_impl[lax.rsqrt_p] = lambda x, accuracy: tf.math.rsqrt(x) -def _cbrt(x): +def _cbrt(x, accuracy): return tf.math.sign(x) * tf.math.pow(tf.math.abs(x), 1/3) tf_impl[lax.cbrt_p] = _cbrt @@ -2822,7 +2832,8 @@ def _threefry2x32_jax_impl(*args: TfVal, _in_avals, _out_aval): multiple_results=False, extra_name_stack="random_gamma") -def _rng_bit_generator(key: TfVal, *, shape, dtype, algorithm) -> Sequence[TfVal]: +def _rng_bit_generator(key: TfVal, *, shape, dtype, algorithm, + out_sharding) -> Sequence[TfVal]: is_uint32_key = key.dtype == _to_tf_dtype(jnp.uint32) if is_uint32_key: key = tf.reshape(key, (2, 2)) @@ -3060,8 +3071,11 @@ def update_computation(arg1: TfVal, arg2: TfVal) -> TfVal: def _cond( - index: TfVal, *operands: TfVal, branches: Sequence[core.ClosedJaxpr] + index: TfVal, *operands: TfVal, branches: Sequence[core.ClosedJaxpr], + **params ) -> Sequence[TfVal]: + if params: + raise NotImplementedError("jax2tf conversion for platform_dependent") # tf.cond needs lambdas with no arguments. 
branches_tf = [ partial(_interpret_jaxpr, jaxpr, *operands, @@ -3171,12 +3185,11 @@ def select_one_carry(new_c: TfVal, c: TfVal, c_aval: core.ShapedArray) -> TfVal: lax_control_flow._scan_impl, extra_name_stack="scan") -tf_impl_with_avals[ad_checkpoint.remat_p] = \ - _convert_jax_impl(partial(ad_checkpoint.remat_expansion, - # TODO: jax2tf cannot discriminate by platform - is_gpu_platform=False), - multiple_results=True, - extra_name_stack="checkpoint") +tf_impl_with_avals[ad_checkpoint.remat_p] = _convert_jax_impl( + ad_checkpoint.remat_expansion, + multiple_results=True, + extra_name_stack="checkpoint", +) tf_impl[ad_checkpoint.name_p] = lambda x, *, name: x @@ -3457,14 +3470,14 @@ def _custom_jvp_call(*args: TfVal, call_jaxpr: core.ClosedJaxpr, tf_impl[custom_derivatives.custom_jvp_call_p] = _custom_jvp_call -def _custom_vjp_call_jaxpr(*args: TfVal, fun_jaxpr: core.ClosedJaxpr, - **_) -> Sequence[TfVal]: +def _custom_vjp_call(*args: TfVal, call_jaxpr: core.ClosedJaxpr, + **_) -> Sequence[TfVal]: # TODO(necula): ensure that there is no AD transformation in scope - return _interpret_jaxpr(fun_jaxpr, *args, extra_name_stack="custom_vjp", + return _interpret_jaxpr(call_jaxpr, *args, extra_name_stack="custom_vjp", fresh_constant_cache=False) -tf_impl[custom_derivatives.custom_vjp_call_jaxpr_p] = _custom_vjp_call_jaxpr +tf_impl[custom_derivatives.custom_vjp_call_p] = _custom_vjp_call def _custom_lin(*args: TfVal, **_) -> Sequence[TfVal]: diff --git a/jax/experimental/jax2tf/tests/back_compat_tf_test.py b/jax/experimental/jax2tf/tests/back_compat_tf_test.py index 2cf363b0cfb2..0b75c679f5e6 100644 --- a/jax/experimental/jax2tf/tests/back_compat_tf_test.py +++ b/jax/experimental/jax2tf/tests/back_compat_tf_test.py @@ -30,7 +30,7 @@ import jax from jax._src import test_util as jtu from jax._src.internal_test_util import export_back_compat_test_util as bctu -from jax._src.lib import xla_extension +from jax._src.lib import _jax from jax.experimental import jax2tf from jax.experimental.jax2tf.tests.back_compat_testdata import tf_call_tf_function import jax.numpy as jnp @@ -96,7 +96,7 @@ def serialize( for op in tf_graph.get_operations(): if op.type == "XlaCallModule": serialized_module = op.get_attr("module") - module_str = xla_extension.mlir.deserialize_portable_artifact( + module_str = _jax.mlir.deserialize_portable_artifact( serialized_module ) module_version = op.get_attr("version") diff --git a/jax/experimental/jax2tf/tests/jax2tf_test.py b/jax/experimental/jax2tf/tests/jax2tf_test.py index bea2b76cb7cf..bde148cb514e 100644 --- a/jax/experimental/jax2tf/tests/jax2tf_test.py +++ b/jax/experimental/jax2tf/tests/jax2tf_test.py @@ -38,7 +38,7 @@ from jax._src import xla_bridge as xb from jax.experimental import jax2tf from jax.experimental.jax2tf.tests import tf_test_util -from jax.experimental.shard_map import shard_map +from jax._src.shard_map import shard_map from jax.experimental import pjit from jax.sharding import PartitionSpec as P @@ -52,6 +52,13 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase): def setUp(self): super().setUp() + versions = tf.version.VERSION.split(".") + if versions < ["2", "19", "1"]: + # StableHLO changed on March 18th, 2025 ,to version 1.10.0, and this + # introduces ops like vhlo_sine_v2. These ops require a TF version + # released after this date. 
+ self.skipTest("Need version of TensorFlow at least 2.19.1") + # One TF device of each device_type self.tf_devices = [] for tf_device in (tf.config.list_logical_devices("TPU") + @@ -832,11 +839,7 @@ def f(x1): arg = np.array(3.) f_tf = jax2tf.convert(jax.grad(remat_f)) f_tf_hlo = self.TfToHlo(f_tf, arg) - if config.remat_opt_barrier.value: - self.assertRegex(f_tf_hlo, r"opt-barrier") - else: - self.assertRegex(f_tf_hlo, - r'transpose/jax2tf_f_/jvp/checkpoint/cond/branch_1_fun/Sin') + self.assertRegex(f_tf_hlo, r"opt-barrier") def test_remat_free_var(self): def f(x): @@ -1476,7 +1479,7 @@ def apply_transform(func, transform: str): in_shardings=(sharding.NamedSharding(mesh, P("a")),), out_shardings=sharding.NamedSharding(mesh, P("a"))), shard_map=( - shard_map(func, mesh, in_specs=(P("a", None),), + shard_map(func, mesh=mesh, in_specs=(P("a", None),), out_specs=P("a", None))), pmap=jax.pmap(func, in_axes=0, out_axes=0), )[transform] @@ -1698,6 +1701,54 @@ def f_jax(x): "Unsupported precision in dot_general"): jax2tf.convert(f_jax, native_serialization=False)(x) + def test_jvp_through_loop(self): + # Context: b/388929258 + + num_actions = 512 + + def tf_preprocessor(features): + features["num_c_actions"] = tf.constant(256, tf.int32) + return features + + def postprocessor(prob, features): + actions = jnp.arange(num_actions, dtype=jnp.int32) + r = actions // features["num_c_actions"] + c = actions - r * features["num_c_actions"] + rr = jnp.array([0.12, 0.3])[r] * prob + rc = (jnp.arange(256) * 0.7)[c] * prob + return rr, rc + + def loop_step(features, params): + features = jax2tf.call_tf(tf_preprocessor)(features) + odds = features["f1"] @ params["w1"] + features["f2"] @ params["w2"] + prob = jax.nn.sigmoid(odds) + rr, rc = postprocessor(prob, features) + new_f1 = jnp.mean(rr, keepdims=True) + new_f2 = jnp.mean(rc, keepdims=True) + return new_f1, new_f2 + + def loop(init_features, params): + def body(carry, unused_x): + f1, f2 = carry + return loop_step({"f1": f1, "f2": f2}, params), None + + (rr, rc), _ = jax.lax.scan( + body, (init_features["f1"], init_features["f2"]), length=10 + ) + return rr, rc + + def loss(features, params): + rr, rc = loop(features, params) + return jnp.mean((rr - rc) ** 2) + + jax.grad(loss, argnums=(1,))( + {"f1": jnp.array([0.5]), "f2": jnp.array([0.7])}, + { + "w1": jnp.ones((1, num_actions)) * 0.01, + "w2": jnp.ones((1, num_actions)) * 0.01, + }, + ) + @jtu.with_config(jax_enable_custom_prng=True) class Jax2tfWithCustomPRNGTest(tf_test_util.JaxToTfTestCase): @@ -1738,10 +1789,17 @@ def func(): jax_result = func() self.assertEqual(tf_result, jax_result) + class Jax2TfVersioningTest(tf_test_util.JaxToTfTestCase): # Use a separate test case with the default jax_serialization_version def setUp(self): self.use_max_serialization_version = False + versions = tf.version.VERSION.split(".") + if versions < ["2", "19", "1"]: + # StableHLO changed on March 18th, 2025 ,to version 1.10.0, and this + # introduces ops like vhlo_sine_v2. These ops require a TF version + # released after this date. 
+ self.skipTest("Need version of TensorFlow at least 2.19.1") super().setUp() @jtu.ignore_warning( diff --git a/jax/experimental/jax2tf/tests/primitives_test.py b/jax/experimental/jax2tf/tests/primitives_test.py index 1ccd009f157c..f6ce4435e6a2 100644 --- a/jax/experimental/jax2tf/tests/primitives_test.py +++ b/jax/experimental/jax2tf/tests/primitives_test.py @@ -172,8 +172,14 @@ def test_primitive_coverage(self): continue if p.name == "composite": continue + if p.name == "pvary": + continue + if p.name == "psum_invariant": + continue if p.name == "sharding_constraint": continue + if p.name == "layout_constraint": + continue if p.name == "mesh_cast": continue if p.name == "reshard": diff --git a/jax/experimental/jax2tf/tests/shape_poly_test.py b/jax/experimental/jax2tf/tests/shape_poly_test.py index 09da97e8420a..17d03bc8c778 100644 --- a/jax/experimental/jax2tf/tests/shape_poly_test.py +++ b/jax/experimental/jax2tf/tests/shape_poly_test.py @@ -595,7 +595,7 @@ def conv_and_run(*, arg_shape: core.Shape, "Using the following polymorphic shapes specifications: args[0].shape = (2*b + a, a, c + b + a). " "Obtained dimension variables: 'a' = 2 from specification 'a' for dimension args[0].shape[1] (= 2), " "'b' = 0 from specification '2*b + a' for dimension args[0].shape[0] (= 2), . " - "Please see https://jax.readthedocs.io/en/latest/export/shape_poly.html#shape-assertion-errors for more details." + "Please see https://docs.jax.dev/en/latest/export/shape_poly.html#shape-assertion-errors for more details." )), dict(shape=(3, 2, 6), # a = 2, b = 0.5, c = 4 - b is not integer poly_spec="(a + 2*b, a, a + b + c)", @@ -604,7 +604,7 @@ def conv_and_run(*, arg_shape: core.Shape, "Division had remainder 1 when computing the value of 'b'. " "Using the following polymorphic shapes specifications: args[0].shape = (2*b + a, a, c + b + a). " "Obtained dimension variables: 'a' = 2 from specification 'a' for dimension args[0].shape[1] (= 2), . " - "Please see https://jax.readthedocs.io/en/latest/export/shape_poly.html#shape-assertion-errors for more details." + "Please see https://docs.jax.dev/en/latest/export/shape_poly.html#shape-assertion-errors for more details." )), dict(shape=(8, 2, 6), # a = 2, b = 3 - inconsistency poly_spec="(a + 2*b, a, a + b)", @@ -614,7 +614,7 @@ def conv_and_run(*, arg_shape: core.Shape, "Using the following polymorphic shapes specifications: args[0].shape = (2*b + a, a, b + a). " "Obtained dimension variables: 'a' = 2 from specification 'a' for dimension args[0].shape[1] (= 2), " "'b' = 4 from specification 'b + a' for dimension args[0].shape[2] (= 6), . " - "Please see https://jax.readthedocs.io/en/latest/export/shape_poly.html#shape-assertion-errors for more details." + "Please see https://docs.jax.dev/en/latest/export/shape_poly.html#shape-assertion-errors for more details." )), dict(shape=(7, 2, 36), # a = 2, b = 3, c = 6 - cannot solve c poly_spec="(2 * a + b, a, c * c)", @@ -623,7 +623,7 @@ def conv_and_run(*, arg_shape: core.Shape, "We can only solve linear uni-variate constraints. " "Using the following polymorphic shapes specifications: args[0].shape = (b + 2*a, a, c^2). " "Unprocessed specifications: 'c^2' for dimension size args[0].shape[2]. " - "Please see https://jax.readthedocs.io/en/latest/export/shape_poly.html#dimension-variables-must-be-solvable-from-the-input-shapes for more details." + "Please see https://docs.jax.dev/en/latest/export/shape_poly.html#dimension-variables-must-be-solvable-from-the-input-shapes for more details." 
)), ]) def test_shape_constraints_errors(self, *, diff --git a/jax/experimental/jax2tf/tests/sharding_test.py b/jax/experimental/jax2tf/tests/sharding_test.py index 653ddce7dca4..20193a931b63 100644 --- a/jax/experimental/jax2tf/tests/sharding_test.py +++ b/jax/experimental/jax2tf/tests/sharding_test.py @@ -33,10 +33,11 @@ from jax._src import config from jax._src import test_util as jtu from jax._src import xla_bridge +from jax._src.lib import xla_client as xc from jax import lax from jax.experimental import jax2tf from jax.experimental import pjit -from jax.experimental.shard_map import shard_map +from jax._src.shard_map import shard_map from jax.sharding import NamedSharding from jax.sharding import Mesh from jax.sharding import PartitionSpec as P @@ -109,8 +110,9 @@ def log_jax_hlo(self, f_jax, args: Sequence[Any], *, device_assignment=device_assignment, use_spmd_partitioning=use_spmd_partitioning, ) - jax_optimized_hlo = backend.compile( - jax_hlo, compile_options).hlo_modules()[0].to_string() + executable = backend.compile_and_load( + jax_hlo, xc.DeviceList(tuple(self.devices.flat)), compile_options) # type: ignore + jax_optimized_hlo = executable.hlo_modules()[0].to_string() logging.info("[%s] got JAX optimized HLO for platform %s %s", self._testMethodName, backend.platform, jax_optimized_hlo) @@ -231,10 +233,10 @@ def f_tf(x): jax2tf.convert(f_jax), [x], checks=[ # The argument - (r"f32\[10,20\].*custom_call_target.*Sharding.*sharding.*devices=\[1,2\]", + (r"f32\[10,20\].*custom_call_target.*\"Sharding.*sharding.*devices=\[1,2\]", count_in_P), # The result - (r"f32\[20,10\].*custom_call_target.*Sharding.*sharding.*devices=\[2,1\]", + (r"f32\[20,10\].*custom_call_target.*\"Sharding.*sharding.*devices=\[2,1\]", count_out_P), ]) # TODO(b/326476605): Change the condition below if required. 
@@ -242,11 +244,11 @@ def f_tf(x): self.check_sharding( jax2tf.convert(f_jax), [x], checks=[ - (r"f32\[10,20\].*custom_call_target.*Sharding.*sharding.*replicated", + (r"f32\[10,20\].*custom_call_target.*\"Sharding.*sharding.*replicated", count_in_replicated), - (r"f32\[20,10\].*custom_call_target.*Sharding.*sharding.*replicated", + (r"f32\[20,10\].*custom_call_target.*\"Sharding.*sharding.*replicated", count_out_replicated), - (r"custom_call_target.*Sharding", + (r"custom_call_target.*\"Sharding", count_in_P + count_in_replicated + count_out_P + count_out_replicated), ]) @@ -276,13 +278,13 @@ def f_jax(x, y): # f32[10,20] , f32[20,30] -> f32[10,30] f_tf, [y], checks=[ # The variable argument - (r"f32\[10,20\].*custom_call_target.*Sharding.*sharding.*devices=\[1,2\]", 1), + (r"f32\[10,20\].*custom_call_target.*\"Sharding.*sharding.*devices=\[1,2\]", 1), # The y argument - (r"f32\[20,30\].*custom_call_target.*Sharding.*sharding.*devices=\[2,1\]", 1), + (r"f32\[20,30\].*custom_call_target.*\"Sharding.*sharding.*devices=\[2,1\]", 1), # The output sharding - (r"f32\[10,30\].*custom_call_target.*Sharding.*sharding.*replicated", 1), + (r"f32\[10,30\].*custom_call_target.*\"Sharding.*sharding.*replicated", 1), # No other annotations - (r"custom_call_target.*Sharding", 3) + (r"custom_call_target.*\"Sharding", 3) ]) @jtu.with_mesh([("x", 2)]) @@ -310,10 +312,10 @@ def f_tf(x): jax2tf.convert(f_jax), [x], checks=[ # x - (r"f32\[10,20\].*custom_call_target.*Sharding.*sharding.*devices=\[2,1\]", + (r"f32\[10,20\].*custom_call_target.*\"Sharding.*sharding.*devices=\[2,1\]", 1), # The result - (r"f32\[20,10\].*custom_call_target.*Sharding.*sharding.*replicated", + (r"f32\[20,10\].*custom_call_target.*\"Sharding.*sharding.*replicated", self.GEQ(1)), ]) @@ -357,16 +359,16 @@ def f_jax(x): # x: f32[10, 20], optionally some axes as polymorphic f_tf, [x], checks=[ # The input argument - (r"f32\[10,20\].*custom_call_target.*Sharding.*sharding.*replicated", 1), + (r"f32\[10,20\].*custom_call_target.*\"Sharding.*sharding.*replicated", 1), # The y argument - (r"f32\[10,40\].*custom_call_target.*Sharding.*sharding.*devices=\[2,1\]", + (r"f32\[10,40\].*custom_call_target.*\"Sharding.*sharding.*devices=\[2,1\]", count_inner_sharding), - (r"f32\[10,40\].*custom_call_target.*Sharding.*sharding.*replicated", + (r"f32\[10,40\].*custom_call_target.*\"Sharding.*sharding.*replicated", count_inner_replicated), # The output sharding - (r"f32\[10,80\].*custom_call_target.*Sharding.*sharding.*replicated", 1), + (r"f32\[10,80\].*custom_call_target.*\"Sharding.*sharding.*replicated", 1), # No other annotations - (r"custom_call_target.*Sharding", 2 + count_inner_sharding + count_inner_replicated) + (r"custom_call_target.*\"Sharding", 2 + count_inner_sharding + count_inner_replicated) ]) @jtu.parameterized_filterable( @@ -427,17 +429,17 @@ def f_grad_tf(x_v, res_ct): self.check_sharding(f_grad_tf, [x, x.T], checks=[ # The input primal argument, and the output grad - (r"f32\[10,20\].*custom_call_target.*Sharding.*sharding.*devices=\[1,2\]", count_in_P), + (r"f32\[10,20\].*custom_call_target.*\"Sharding.*sharding.*devices=\[1,2\]", count_in_P), # The primal result, and the input cotangent - (r"f32\[20,10\].*custom_call_target.*Sharding.*sharding.*devices=\[2,1\]", count_out_P), + (r"f32\[20,10\].*custom_call_target.*\"Sharding.*sharding.*devices=\[2,1\]", count_out_P), ]) # TODO(b/326476605): Change the condition below if required. 
if out_shardings not in [None, "missing"] and in_shardings not in [None, "missing"]: self.check_sharding(f_grad_tf, [x, x.T], checks=[ - (r"f32\[10,20\].*custom_call_target.*Sharding.*sharding.*replicated", count_in_replicated), + (r"f32\[10,20\].*custom_call_target.*\"Sharding.*sharding.*replicated", count_in_replicated), # The primal result, and the input cotangent - (r"f32\[20,10\].*custom_call_target.*Sharding.*sharding.*devices=\[2,1\]", count_out_P), + (r"f32\[20,10\].*custom_call_target.*\"Sharding.*sharding.*devices=\[2,1\]", count_out_P), ]) def test_grad_sharding_different_mesh(self): @@ -576,7 +578,7 @@ def test_repro_xla_bug_shmap_collective_permute(self): @partial(shard_map, mesh=mesh, in_specs=(P('x', None),), out_specs=P('x', None)) def f_jax(b): # b: f32[2, 4] - axis_size = lax.psum(1, 'x') + axis_size = lax.axis_size('x') perm = [(j, (j + 1) % axis_size) for j in range(axis_size)] return lax.ppermute(b, 'x', perm=perm) @@ -612,7 +614,7 @@ def test_shmap_collective_permute(self, poly=None): @partial(shard_map, mesh=mesh, in_specs=(P('x', None),), out_specs=P('x', None)) def f_jax(b): # b: f32[2, 4] - axis_size = lax.psum(1, 'x') + axis_size = lax.axis_size('x') perm = [(j, (j + 1) % axis_size) for j in range(axis_size)] return lax.ppermute(b, 'x', perm=perm) diff --git a/jax/experimental/jax2tf/tests/tf_test_util.py b/jax/experimental/jax2tf/tests/tf_test_util.py index 32f89e533daf..faecf9f0f09e 100644 --- a/jax/experimental/jax2tf/tests/tf_test_util.py +++ b/jax/experimental/jax2tf/tests/tf_test_util.py @@ -34,6 +34,7 @@ from jax import export from jax._src import config from jax._src import xla_bridge +from jax._src.lib import xla_client as xc import numpy as np import tensorflow as tf from tensorflow.compiler.xla import xla_data_pb2 @@ -344,7 +345,9 @@ def log_message(extra): tf_hlo) backend = xla_bridge.get_backend() - modules = backend.compile(str(jax_lowered.compiler_ir())).hlo_modules() + device_list = xc.DeviceList(tuple(backend.local_devices())) + modules = backend.compile_and_load( + str(jax_lowered.compiler_ir()), device_list).hlo_modules() jax_opt_hlo = modules[0].to_string() logging.info("[%s] JAX OPT HLO\n%s", self._testMethodName, jax_opt_hlo) diff --git a/jax/experimental/jet.py b/jax/experimental/jet.py index 15273f0fd02a..acf8885b0f98 100644 --- a/jax/experimental/jet.py +++ b/jax/experimental/jet.py @@ -76,7 +76,7 @@ from jax._src.util import unzip2, weakref_lru_cache, safe_zip -def jet(fun, primals, series): +def jet(fun, primals, series, **_): r"""Taylor-mode higher-order automatic differentiation. Args: @@ -405,11 +405,11 @@ def deriv_prop(prim, deriv, primals_in, series_in): lax.exp(lax.neg(lax.square(x))))) -def def_comp(prim, comp): +def def_comp(prim, comp, **kwargs): """ Define the jet rule for a primitive in terms of a composition of simpler primitives. """ - jet_rules[prim] = partial(jet, comp) + jet_rules[prim] = partial(jet, comp, **kwargs) def_comp(lax.expm1_p, lambda x: lax.exp(x) - 1) @@ -478,7 +478,7 @@ def _scale(k, j): def _scale2(k, j): return 1. 
/ (fact(k - j) * fact(j)) -def _exp_taylor(primals_in, series_in): +def _exp_taylor(primals_in, series_in, **_): x, = primals_in series, = series_in u = [x] + series @@ -522,7 +522,7 @@ def _integer_pow_taylor(primals_in, series_in, *, y): jet_rules[lax.integer_pow_p] = _integer_pow_taylor -def _logistic_taylor(primals_in, series_in): +def _logistic_taylor(primals_in, series_in, **_): x, = primals_in series, = series_in u = [x] + series @@ -538,7 +538,7 @@ def _logistic_taylor(primals_in, series_in): jet_rules[lax.logistic_p] = _logistic_taylor -def _tanh_taylor(primals_in, series_in): +def _tanh_taylor(primals_in, series_in, **_): x, = primals_in series, = series_in u = [2*x] + [2 * series_ for series_ in series] @@ -548,7 +548,7 @@ def _tanh_taylor(primals_in, series_in): return 2 * primal_out - 1, series_out jet_rules[lax.tanh_p] = _tanh_taylor -def _log_taylor(primals_in, series_in): +def _log_taylor(primals_in, series_in, **_): x, = primals_in series, = series_in u = [x] + series @@ -590,7 +590,7 @@ def scale(k, j): return 1. / (fact(k - j) * fact(j)) return primal_out, series_out jet_rules[lax.div_p] = _div_taylor_rule -def _sinusoidal_rule(sign, prims, primals_in, series_in): +def _sinusoidal_rule(sign, prims, primals_in, series_in, **_): x, = primals_in series, = series_in u = [x] + series @@ -603,7 +603,7 @@ def _sinusoidal_rule(sign, prims, primals_in, series_in): return (s[0], s[1:]), (c[0], c[1:]) def _get_ind(f, ind): - return lambda *args: f(*args)[ind] + return lambda *args, **kwargs: f(*args, **kwargs)[ind] jet_rules[lax.sin_p] = _get_ind(partial(_sinusoidal_rule, -1, (lax.sin, lax.cos)), 0) jet_rules[lax.cos_p] = _get_ind(partial(_sinusoidal_rule, -1, (lax.sin, lax.cos)), 1) diff --git a/jax/experimental/key_reuse/_core.py b/jax/experimental/key_reuse/_core.py index 7275046f556d..7c7ffd17a56c 100644 --- a/jax/experimental/key_reuse/_core.py +++ b/jax/experimental/key_reuse/_core.py @@ -35,10 +35,11 @@ from jax._src import util from jax._src.ad_checkpoint import remat_p from jax._src.debugging import debug_callback_p +from jax._src.hashable_array import HashableArray from jax._src.interpreters import partial_eval as pe from jax._src.util import weakref_lru_cache -from jax.experimental.shard_map import shard_map_p +from jax._src.shard_map import shard_map_p import numpy as np @@ -212,7 +213,7 @@ def key_reuse_signature_from_eqn(eqn: core.JaxprEqn) -> KeyReuseSignature: return sig.signature(eqn) else: raise TypeError( - f"Unrecognized key reuse sigature of type {type(sig)}: {sig}") + f"Unrecognized key reuse signature of type {type(sig)}: {sig}") else: return unknown_signature(eqn) @@ -231,7 +232,7 @@ def key_reuse_signature_from_primitive(prim, *args, **params): return jaxpr_type_signature(jaxpr) else: raise TypeError( - f"Unrecognized key reuse sigature of type {type(sig)}: {sig}") + f"Unrecognized key reuse signature of type {type(sig)}: {sig}") consume_p = core.Primitive("consume") @@ -257,16 +258,16 @@ def consume(key): def assert_unconsumed(key): """Assert that a key is unconsumed""" - assert_consumed_value_p.bind(key, value=False) + assert_consumed_value_p.bind(key, value=HashableArray(False)) def assert_consumed(key, value=True): """Assert that a key is consumed""" - assert_consumed_value_p.bind(key, value=value) + assert_consumed_value_p.bind(key, value=HashableArray(value)) def _check_consumed_value(eqn, consumed): """Extra check for use with assert_consumed_value_p""" - expected = eqn.params['value'] + expected = eqn.params['value'].val if not np.all(consumed == 
expected): if np.all(expected): raise AssertionError(f"Expected key to be consumed in {eqn}") @@ -415,7 +416,7 @@ def check_key_reuse(fun: Callable[..., Any], /, *args: Any) -> None: function_type_signature(fun, *args) -#---------------------------------------------------------------------------------- +# ---------------------------------------------------------------------------------- # key reuse rules for particular primitives: @dynamic_key_reuse_signature diff --git a/jax/experimental/layout.py b/jax/experimental/layout.py index ed9f8931938e..daffedcd1739 100644 --- a/jax/experimental/layout.py +++ b/jax/experimental/layout.py @@ -14,5 +14,8 @@ from jax._src.layout import ( DeviceLocalLayout as DeviceLocalLayout, - Layout as Layout + Format as Format, +) +from jax._src.pjit import ( + with_layout_constraint as with_layout_constraint, ) diff --git a/jax/experimental/mesh_utils.py b/jax/experimental/mesh_utils.py index 075e4e6eed48..58d20c331d5f 100644 --- a/jax/experimental/mesh_utils.py +++ b/jax/experimental/mesh_utils.py @@ -1,4 +1,4 @@ -# Copyright 2021 The TensorFlow Authors. All Rights Reserved. +# Copyright 2021 The JAX Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/jax/experimental/mosaic/gpu/__init__.py b/jax/experimental/mosaic/gpu/__init__.py index d004c7deb3df..7094fc7352d3 100644 --- a/jax/experimental/mosaic/gpu/__init__.py +++ b/jax/experimental/mosaic/gpu/__init__.py @@ -23,15 +23,17 @@ Barrier as Barrier, ClusterBarrier as ClusterBarrier, TMABarrier as TMABarrier, - ThreadSemantics as ThreadSemantics, + LoweringSemantics as LoweringSemantics, TMEM as TMEM, Union as Union, as_gpu_kernel as as_gpu_kernel, + supports_cross_device_collectives as supports_cross_device_collectives, ) from .launch_context import ( LaunchContext as LaunchContext, MemRefTransform as MemRefTransform, + ReductionOp as ReductionOp, Rounding as Rounding, TileTransform as TileTransform, TransposeTransform as TransposeTransform, @@ -45,6 +47,10 @@ infer_layout as infer_layout, ) +from .layouts import ( + to_layout_attr as to_layout_attr, +) + from .transform_inference import ( infer_transforms as infer_transforms, ) @@ -52,19 +58,27 @@ from .fragmented_array import ( FragmentedArray as FragmentedArray, FragmentedLayout as FragmentedLayout, + TCGEN05_LAYOUT as TCGEN05_LAYOUT, + TCGEN05_ROW_LAYOUT as TCGEN05_ROW_LAYOUT, + TCGEN05_COL_LAYOUT as TCGEN05_COL_LAYOUT, + TiledLayout as TiledLayout, WGMMA_LAYOUT as WGMMA_LAYOUT, WGMMA_ROW_LAYOUT as WGMMA_ROW_LAYOUT, - WGMMARowFragLayout as WGMMARowFragLayout, + WGMMA_COL_LAYOUT as WGMMA_COL_LAYOUT, + WGMMA_TRANSPOSED_LAYOUT as WGMMA_TRANSPOSED_LAYOUT, WGSplatFragLayout as WGSplatFragLayout, WGStridedFragLayout as WGStridedFragLayout, optimization_barrier as optimization_barrier, ) from .utils import ( BarrierRef as BarrierRef, + DialectBarrierRef as DialectBarrierRef, CollectiveBarrierRef as CollectiveBarrierRef, DynamicSlice as DynamicSlice, Partition as Partition, Partition1D as Partition1D, + SemaphoreRef as SemaphoreRef, + ThreadSubset as ThreadSubset, bitwidth as bitwidth, bytewidth as bytewidth, c as c, @@ -72,6 +86,7 @@ debug_print as debug_print, ds as ds, fori as fori, + is_known_divisible as is_known_divisible, memref_fold as memref_fold, memref_slice as memref_slice, memref_reshape as memref_reshape, diff --git a/jax/experimental/mosaic/gpu/core.py b/jax/experimental/mosaic/gpu/core.py index b255893e2e2e..38d030fa765c 
100644 --- a/jax/experimental/mosaic/gpu/core.py +++ b/jax/experimental/mosaic/gpu/core.py @@ -18,16 +18,19 @@ import ctypes import dataclasses import enum -import functools import hashlib import math import os import pathlib import time -from typing import Any, Callable, Generic, TypeVar +from typing import Any, Generic, TypeVar +from collections.abc import Callable import weakref +import itertools import jax +from jax._src import lib +from jax._src import sharding_impls from jax._src.interpreters import mlir from jax._src.lib import mosaic_gpu_dialect as dialect from jaxlib.mlir import ir @@ -41,6 +44,7 @@ from jaxlib.mlir.dialects import nvvm import numpy as np + # mypy: ignore-errors from . import dialect_lowering @@ -52,15 +56,9 @@ from . import utils # MLIR can't find libdevice unless we point it to the CUDA path -# TODO(apaszke): Unify with jax._src.lib.cuda_path -CUDA_ROOT = "/usr/local/cuda" -if os.environ.get("CUDA_ROOT") is None: - os.environ["CUDA_ROOT"] = CUDA_ROOT -else: - CUDA_ROOT = os.environ["CUDA_ROOT"] - -PTXAS_PATH = os.path.join(CUDA_ROOT, "bin/ptxas") -NVDISASM_PATH = os.path.join(CUDA_ROOT, "bin/nvdisasm") +cuda_root = lib.cuda_path or "/usr/local/cuda" +os.environ["CUDA_ROOT"] = cuda_root +PYTHON_RUNFILES = os.environ.get("PYTHON_RUNFILES") # This tracks the latest Mosaic GPU IR version with a monthly delay. FWD_COMPAT_IR_VERSION = 1 @@ -84,14 +82,74 @@ os.environ["MOSAIC_GPU_RUNTIME_LIB_PATH"] = str(RUNTIME_PATH) +try: + from nvidia import nvshmem +except ImportError: + # Try to find the nvshmem library in Bazel runfiles. + if PYTHON_RUNFILES: + libdevice_path = os.path.join( + PYTHON_RUNFILES, "nvidia_nvshmem", "lib", "libnvshmem_device.bc" + ) + if os.path.exists(libdevice_path): + os.environ["MOSAIC_GPU_NVSHMEM_BC_PATH"] = libdevice_path + for root, _, files in os.walk(os.path.join(os.getcwd(), "_solib_local")): + if "libnvshmem_host.so.3" in files: + os.environ["MOSAIC_GPU_NVSHMEM_SO_PATH"] = os.path.join( + root, "libnvshmem_host.so.3" + ) + break + else: + pass +else: + if os.environ.get("MOSAIC_GPU_NVSHMEM_BC_PATH") is None: + os.environ["MOSAIC_GPU_NVSHMEM_BC_PATH"] = os.path.join( + nvshmem.__path__[0], "lib/libnvshmem_device.bc" + ) + if os.environ.get("MOSAIC_GPU_NVSHMEM_SO_PATH") is None: + os.environ["MOSAIC_GPU_NVSHMEM_SO_PATH"] = os.path.join( + nvshmem.__path__[0], "lib/libnvshmem_host.so.3" + ) + + +def supports_cross_device_collectives(): + try: + nvshmem_bc_path = os.environ["MOSAIC_GPU_NVSHMEM_BC_PATH"] + except KeyError: + return False + if nvshmem_so_path := os.environ.get("MOSAIC_GPU_NVSHMEM_SO_PATH", ""): + try: + # This both ensures that the file exists, and it populates the dlopen + # cache, helping XLA find the library even if the RPATH is not right... + ctypes.CDLL(nvshmem_so_path) + except OSError: + return False + xla_flags = os.environ.get("XLA_FLAGS", "") + return ( + os.path.exists(nvshmem_bc_path) + and "--xla_gpu_experimental_enable_nvshmem" in xla_flags + ) + + mosaic_gpu_p = jax._src.core.Primitive("mosaic_gpu_p") mosaic_gpu_p.multiple_results = True @mosaic_gpu_p.def_abstract_eval -def _mosaic_gpu_abstract_eval(*_, module, out_types): +def _mosaic_gpu_abstract_eval(*_, module, out_types, inout_types): del module # Unused. 
- return [jax._src.core.ShapedArray(t.shape, t.dtype) for t in out_types] + return [ + jax._src.core.ShapedArray(t.shape, t.dtype) + for t in itertools.chain(out_types, inout_types) + ] + + +def _has_communication(module, **_): + empty_str_attr = ir.StringAttr.get("") + for op in module.body: + if "nvshmem" in getattr(op, "sym_name", empty_str_attr).value: + return True + return False + # TODO(apaszke): Implement a proper system for managing kernel lifetimes KNOWN_KERNELS = {} @@ -102,9 +160,44 @@ def _mosaic_gpu_lowering_rule( *args, module, out_types, + inout_types, input_output_aliases: tuple[tuple[int, int], ...] = (), + use_custom_barrier: bool = False, ): - assert len(out_types) == len(ctx.avals_out) + axis_context = ctx.module_context.axis_context + if _has_communication(module): + # Those checks are trying to ensure that the logical device ids are + # consistent with the NVSHMEM PE ids that Mosaic will be using for + # communication. Any divergence here would require us to implement a logical + # to physical translation, which is currently not implemented. + if isinstance(axis_context, sharding_impls.SPMDAxisContext): + mesh = axis_context.mesh + if not np.array_equal(mesh.device_ids.ravel(), np.arange(mesh.size)): + raise NotImplementedError( + "Mosaic GPU only supports meshes with device ordering that follows" + " row-major device ids." + ) + elif isinstance(axis_context, sharding_impls.ShardingContext): + if axis_context.num_devices != 1: + raise NotImplementedError( + "Mosaic GPU only supports single-device meshes in ShardingContext." + ) + else: + raise NotImplementedError(f"Unsupported sharding context: {axis_context}") + + if inout_types: + if input_output_aliases: + raise ValueError( + "input_output_aliases and inout_types are mutually exclusive" + ) + num_inputs = len(ctx.avals_in) + num_outputs = len(ctx.avals_out) + input_output_aliases = tuple( + (num_inputs - 1 - i, num_outputs - 1 - i) + for i in range(len(inout_types)) + ) + assert len(ctx.avals_in) == len(args) + assert len(ctx.avals_out) == len(out_types) + len(inout_types) module = _run_serde_pass( module, serialize=True, @@ -120,15 +213,35 @@ def _mosaic_gpu_lowering_rule( raise RuntimeError("Hash collision!") else: KNOWN_KERNELS[kernel_id] = module_asm - op = mlir.custom_call( - "mosaic_gpu", - result_types=[mlir.aval_to_ir_type(aval) for aval in ctx.avals_out], - operands=args, - operand_layouts=[list(reversed(range(a.ndim))) for a in ctx.avals_in], - result_layouts=[list(reversed(range(a.ndim))) for a in ctx.avals_out], - backend_config=kernel_id + module_asm, - operand_output_aliases=dict(input_output_aliases), - ) + + if ctx.is_forward_compat(): + if use_custom_barrier: + raise ValueError("Barrier semaphore is not supported in forward compatibility mode. 
" + "Please, use 'export_ignore_forward_compatibility=True'.") + op = mlir.custom_call( + "mosaic_gpu", + result_types=[mlir.aval_to_ir_type(aval) for aval in ctx.avals_out], + operands=args, + operand_layouts=[list(reversed(range(a.ndim))) for a in ctx.avals_in], + result_layouts=[list(reversed(range(a.ndim))) for a in ctx.avals_out], + backend_config=kernel_id + module_asm, + operand_output_aliases=dict(input_output_aliases), + ) + else: + op = mlir.custom_call( + "mosaic_gpu_v2", + result_types=[mlir.aval_to_ir_type(aval) for aval in ctx.avals_out], + operands=args, + operand_layouts=[list(reversed(range(a.ndim))) for a in ctx.avals_in], + result_layouts=[list(reversed(range(a.ndim))) for a in ctx.avals_out], + backend_config=dict( + kernel_hash=ir.StringAttr.get(kernel_id), + module=ir.StringAttr.get(module_asm), + use_custom_barrier=ir.BoolAttr.get(use_custom_barrier), + ), + operand_output_aliases=dict(input_output_aliases), + api_version=4, + ) return op.results @@ -157,6 +270,12 @@ class Barrier: arrival_count: int num_barriers: int = 1 + def __post_init__(self): + if self.arrival_count < 1: + raise ValueError( + f"Arrival count must be at least 1, but got {self.arrival_count}" + ) + @dataclasses.dataclass(frozen=True) class ClusterBarrier: collective_dims: Sequence[gpu.Dimension] @@ -166,55 +285,95 @@ class ClusterBarrier: class TMEM: shape: tuple[int, int] dtype: Any + _: dataclasses.KW_ONLY layout: tcgen05.TMEMLayout | None = None collective: bool = False + packing: int | None = None def __post_init__(self): if self.layout is not None: - self.layout.check_shape(self.shape) + self.layout.check_type(self.shape, utils.dtype_to_ir_type(self.dtype)) + if self.packing is not None: + raise ValueError("Cannot specify both layout and packing") def _count_buffer_bytes(shape_dtype: jax.ShapeDtypeStruct) -> int: return math.prod(shape_dtype.shape) * np.dtype(shape_dtype.dtype).itemsize -class ThreadSemantics(enum.Enum): +class LoweringSemantics(enum.Enum): """Semantics for the kernel's instruction stream.""" Lane = enum.auto() Warpgroup = enum.auto() +@dataclasses.dataclass(frozen=True) +class _TMEMAlloc: + addr_ref: ir.Value + num_cols: int + collective: bool + + def alloc(self): + tcgen05.tmem_alloc( + self.addr_ref, self.num_cols, collective=self.collective, exact=False + ) + + def dealloc(self): + addr = memref.load(self.addr_ref, []) + tcgen05.tmem_dealloc( + addr, self.num_cols, collective=self.collective, exact=False + ) + + +def _slice_smem( + result: ir.Type, + smem_base: ir.Value, + offset: ir.Value, # This should be an ir.IndexType. + lowering_semantics: LoweringSemantics, +) -> ir.Value: + if lowering_semantics == LoweringSemantics.Warpgroup: + offset = arith.index_cast(ir.IntegerType.get_signless(32), offset) + return dialect.slice_smem(result, offset) + else: + return memref.view(result, smem_base, offset, []) + + def _construct_smem_reftree( cluster_shape: tuple[int, int, int], dynamic_smem: ir.Value, smem_buffers: ShapeTree, - delayed_warp_init: list[Callable[[], None]], # Mutated by this function! + tmem_allocs: list[_TMEMAlloc], # Mutated by this function! 
+ lowering_semantics: LoweringSemantics, dynamic_smem_offset: int = 0, ) -> Callable[[], RefTree]: index = ir.IndexType.get() - i8 = ir.IntegerType.get_signless(8) i32 = ir.IntegerType.get_signless(32) + i64 = ir.IntegerType.get_signless(64) smem = ir.Attribute.parse("#gpu.address_space") flat_ref_tys, smem_buffer_tree = jax.tree.flatten( smem_buffers, is_leaf=lambda x: isinstance(x, Union) ) smem_refs = [] + for ref_ty in flat_ref_tys: - def get_barrier_ptr(num_barriers: int) -> ir.Value: + def barrier_memref(num_barriers: int) -> ir.Value: nonlocal dynamic_smem_offset - workgroup_nvptx_address_space = ( - utils.gpu_address_space_to_nvptx(gpu.AddressSpace.Workgroup) - ) - smem_base_ptr = utils.memref_ptr( - dynamic_smem, memory_space=workgroup_nvptx_address_space - ) - smem_ptr_ty = ir.Type.parse(f"!llvm.ptr<{workgroup_nvptx_address_space}>") - barrier_base_ptr = llvm.getelementptr( - smem_ptr_ty, smem_base_ptr, [], [dynamic_smem_offset], i8 + barrier_ty = ir.MemRefType.get( + (num_barriers,), + ir.Type.parse("!mosaic_gpu.barrier") + if lowering_semantics == LoweringSemantics.Warpgroup + else i64, + memory_space=smem, ) + barrier_memref = _slice_smem( + barrier_ty, + dynamic_smem, + c(dynamic_smem_offset, index), + lowering_semantics, + ) dynamic_smem_offset += num_barriers * utils.MBARRIER_BYTES - return barrier_base_ptr + return barrier_memref match ref_ty: case Union(members): member_thunks = [ @@ -222,7 +381,8 @@ def get_barrier_ptr(num_barriers: int) -> ir.Value: cluster_shape, dynamic_smem, m, - delayed_warp_init, + tmem_allocs, + lowering_semantics, dynamic_smem_offset, ) for m in members @@ -234,36 +394,34 @@ def ref(member_thunks=member_thunks): return Union([t() for t in member_thunks]) case TMABarrier(num_barriers): - ref = utils.BarrierRef.initialize( - get_barrier_ptr(num_barriers), num_barriers, arrival_count=1 - ) + init_fn = utils.DialectBarrierRef.initialize if ( + lowering_semantics == LoweringSemantics.Warpgroup + ) else utils.BarrierRef.initialize + ref = init_fn(barrier_memref(num_barriers), arrival_count=1) case Barrier(arrival_count, num_barriers): - ref = utils.BarrierRef.initialize( - get_barrier_ptr(num_barriers), - num_barriers, - arrival_count=arrival_count, - ) + init_fn = utils.DialectBarrierRef.initialize if ( + lowering_semantics == LoweringSemantics.Warpgroup + ) else utils.BarrierRef.initialize + ref = init_fn(barrier_memref(num_barriers), arrival_count=arrival_count) case ClusterBarrier(collective_dims, num_barriers): ref = utils.CollectiveBarrierRef.initialize( - get_barrier_ptr(num_barriers), - num_barriers, + barrier_memref(num_barriers), collective_dims, cluster_shape, ) - case TMEM(shape, dtype, layout, collective): - addr_ref = memref.view( + case TMEM(shape, dtype, layout=layout, collective=collective, packing=packing): + addr_ref = _slice_smem( ir.MemRefType.get([], i32, memory_space=smem), - dynamic_smem, c(dynamic_smem_offset, index), [], + dynamic_smem, + c(dynamic_smem_offset, index), + lowering_semantics, ) if layout is None: - layout = tcgen05._infer_tmem_layout(shape, collective) + layout = tcgen05._infer_tmem_layout( + shape, 1 if packing is None else packing + ) num_cols = layout.cols_in_shape(shape) - delayed_warp_init.append( - functools.partial( - tcgen05.tmem_alloc, - addr_ref, num_cols, collective=collective, exact=False, - ) - ) + tmem_allocs.append(_TMEMAlloc(addr_ref, num_cols, collective)) def ref(addr_ref=addr_ref, shape=shape, dtype=dtype, layout=layout): addr = memref.load(addr_ref, []) return tcgen05.TMEMRef( @@ -272,9 
+430,11 @@ def ref(addr_ref=addr_ref, shape=shape, dtype=dtype, layout=layout): dynamic_smem_offset += 4 # i32 takes up 4 bytes case _: mlir_dtype = utils.dtype_to_ir_type(ref_ty.dtype) - tile_smem = memref.view( + tile_smem = _slice_smem( ir.MemRefType.get(ref_ty.shape, mlir_dtype, memory_space=smem), - dynamic_smem, c(dynamic_smem_offset, index), [], + dynamic_smem, + c(dynamic_smem_offset, index), + lowering_semantics, ) dynamic_smem_offset += _count_buffer_bytes(ref_ty) ref = tile_smem @@ -307,6 +467,8 @@ def _smem_tree_size(smem_buffers: ShapeTree) -> int: raise NotImplementedError("Misaligned barrier allocation") size += num_barriers * utils.MBARRIER_BYTES case TMEM(_): + # TODO(justinfu): This can trigger misaligned barrier allocations + # if TMEM is requested before barriers b/c it's not divisible by 8. size += 4 # i32 takes up 4 bytes case _: size += _count_buffer_bytes(l) @@ -320,8 +482,9 @@ def _launch( grid: tuple[int, int, int], cluster: tuple[int, int, int], block: tuple[int, int, int], - scratch_arr, smem_buffers: ShapeTree | Union[ShapeTree], + lowering_semantics: LoweringSemantics, + module: ir.Module, profiler_spec: profiler.ProfilerSpec | None = None, maybe_prof_buffer: ir.Value | None = None, ): @@ -337,7 +500,13 @@ def _launch( smem_bytes = user_smem_bytes if profiler_spec is not None: - smem_bytes += profiler_spec.smem_bytes(block=block) + # Profiler array stores values in 64 bit chunks (vectors of size 2 + # of 32-bit elements), and so the starting address needs to be 64 + # bit = 8 byte aligned. + # https://docs.nvidia.com/cuda/parallel-thread-execution/#addresses-as-operands:~:text=The%20address%20must%20be%20naturally%20aligned%20to%20a%20multiple%20of%20the%20access%20size. + align = 8 + profiler_start = (smem_bytes + align - 1) & ~(align - 1) + smem_bytes = profiler_start + profiler_spec.smem_bytes(block=block) # TODO(cperivol): Query the shared memory size programmatically. if smem_bytes > 228 * 1024: @@ -363,18 +532,19 @@ def _launch( smem = ir.Attribute.parse("#gpu.address_space") with ir.InsertionPoint(launch_op.body.blocks[0]): dynamic_smem = gpu.dynamic_shared_memory( - ir.MemRefType.get( - (ir.ShapedType.get_dynamic_size(),), i8, memory_space=smem - ) + ir.MemRefType.get((utils.DYNAMIC,), i8, memory_space=smem) ) if profiler_spec: - prof_smem = memref.view( + prof_smem = _slice_smem( ir.MemRefType.get( (profiler_spec.smem_i32_elements(block=block),), - i32, memory_space=smem, + i32, + memory_space=smem, ), - dynamic_smem, c(user_smem_bytes, index), [], + dynamic_smem, + c(profiler_start, index), + lowering_semantics, ) prof = profiler.OnDeviceProfiler( profiler_spec, prof_smem, maybe_prof_buffer @@ -382,13 +552,13 @@ def _launch( else: prof = None - ptr_ty = ir.Type.parse("!llvm.ptr") - scratch_ptr = builtin.unrealized_conversion_cast([ptr_ty], [scratch_arr]) - ctx = launch_context.LaunchContext(launch_op, scratch_ptr, cluster, prof) + ctx = launch_context.LaunchContext( + module, launch_context.Scratch(launch_op), cluster, prof + ) with ctx.named_region("Init"): - delayed_warp_init = [] + tmem_allocs: list[_TMEMAlloc] = [] smem_ref_tree_thunk = _construct_smem_reftree( - cluster, dynamic_smem, smem_buffers, delayed_warp_init + cluster, dynamic_smem, smem_buffers, tmem_allocs, lowering_semantics ) # TODO(apaszke): Skip fences if no barriers or TMEM is initialized. # TODO(apaszke): Only initialize cluster barriers before the cluster wait. 
@@ -396,17 +566,29 @@ def _launch( if math.prod(cluster) != 1: nvvm.cluster_arrive_relaxed(aligned=ir.UnitAttr.get()) nvvm.cluster_wait(aligned=ir.UnitAttr.get()) - if delayed_warp_init: + if tmem_allocs: eq = arith.CmpIPredicate.eq is_init_warp = arith.cmpi(eq, utils.warp_idx(sync=False), c(0, i32)) with utils.when(is_init_warp): - for init in delayed_warp_init: - init() - tcgen05.tmem_relinquish_alloc_permit() + for alloc in tmem_allocs: + alloc.alloc() + if any(alloc.collective for alloc in tmem_allocs): + tcgen05.tmem_relinquish_alloc_permit(collective=True) + if any(not alloc.collective for alloc in tmem_allocs): + tcgen05.tmem_relinquish_alloc_permit(collective=False) gpu.barrier() # Make sure the init is visible to all threads. smem_ref_tree = smem_ref_tree_thunk() yield ctx, smem_ref_tree + + if tmem_allocs: + gpu.barrier() # Make sure everyone is done before we release TMEM. + if any(alloc.collective for alloc in tmem_allocs): + nvvm.cluster_arrive_relaxed(aligned=ir.UnitAttr.get()) + nvvm.cluster_wait(aligned=ir.UnitAttr.get()) + with utils.when(is_init_warp): + for alloc in tmem_allocs: + alloc.dealloc() if prof is not None: prof.finalize(grid=grid, block=block) gpu.terminator() @@ -419,7 +601,9 @@ def _lower_as_gpu_kernel( block: tuple[int, int, int], in_shapes: tuple[Any, ...], out_shape, + inout_shape, smem_scratch_shape: ShapeTree | Union[ShapeTree], + lowering_semantics: LoweringSemantics, module_name: str, kernel_name: str | None = None, prof_spec: profiler.ProfilerSpec | None = None, @@ -427,32 +611,33 @@ def _lower_as_gpu_kernel( ptr_ty = ir.Type.parse("!llvm.ptr") token_ty = ir.Type.parse("!gpu.async.token") i32 = ir.IntegerType.get_signless(32) - i64 = ir.IntegerType.get_signless(64) def _shape_to_ref_ty(shape: jax.ShapeDtypeStruct) -> ir.MemRefType: return ir.MemRefType.get(shape.shape, utils.dtype_to_ir_type(shape.dtype)) in_ref_tys = [_shape_to_ref_ty(t) for t in in_shapes] + inout_ref_tys = [_shape_to_ref_ty(t) for t in inout_shape] unwrap_output_tuple = False if isinstance(out_shape, list): out_shape = tuple(out_shape) elif not isinstance(out_shape, tuple): out_shape = (out_shape,) - unwrap_output_tuple = True + unwrap_output_tuple = not inout_shape out_ref_tys = [_shape_to_ref_ty(t) for t in out_shape] if prof_spec is not None: out_shape = (*out_shape, prof_spec.jax_buffer_type(grid, block)) out_ref_tys.append(prof_spec.mlir_buffer_type(grid, block)) module = ir.Module.create() + dialect.register_dialect(module.context) attrs = module.operation.attributes attrs["sym_name"] = ir.StringAttr.get(module_name) if kernel_name is None: kernel_name = getattr(body, "__name__", "anonymous") # These are needed as nonlocal below. - launch_ctx, scratch_arr = None, None + launch_ctx = None with ir.InsertionPoint(module.body): _declare_runtime_functions() global_scratch = llvm.GlobalOp( @@ -463,35 +648,28 @@ def _shape_to_ref_ty(shape: jax.ShapeDtypeStruct) -> ir.MemRefType: ) @func.FuncOp.from_py_func(ptr_ty, ptr_ty, name=f"mosaic_gpu_{kernel_name}") def main(token_ptr, buffers): - nonlocal launch_ctx, scratch_arr + nonlocal launch_ctx token = builtin.unrealized_conversion_cast([token_ty], [token_ptr]) arg_refs = [] - for i, ref_ty in enumerate([*in_ref_tys, *out_ref_tys]): - ptr = llvm.LoadOp(ptr_ty, llvm.GEPOp(ptr_ty, buffers, [], [i], ptr_ty)) + # XLA will pass in inout refs again as outputs, but we ignore them. 
+ for i, ref_ty in enumerate([*in_ref_tys, *inout_ref_tys, *out_ref_tys]): + ptr = llvm.LoadOp(ptr_ty, llvm.GEPOp(ptr_ty, buffers, [], [i], ptr_ty, llvm.GEPNoWrapFlags.none)) arg_refs.append(utils.ptr_as_memref(ptr, ir.MemRefType(ref_ty))) - in_refs = arg_refs[:len(in_ref_tys)] - out_refs = arg_refs[len(in_ref_tys):] - prof_buffer = out_refs.pop() if prof_spec is not None else None - empty_arr_ty = ir.Type.parse("!llvm.array<0 x i8>") - scratch_alloc = llvm.AllocaOp( - ptr_ty, c(1, i64), empty_arr_ty, - alignment=launch_context.TMA_DESCRIPTOR_ALIGNMENT - ) - scratch_arr = llvm.load(empty_arr_ty, scratch_alloc.result) + prof_buffer = arg_refs.pop() if prof_spec is not None else None with _launch( - token, grid, cluster, block, scratch_arr, smem_scratch_shape, - prof_spec, prof_buffer + token, grid, cluster, block, smem_scratch_shape, + lowering_semantics, module, prof_spec, prof_buffer ) as (_launch_ctx, smem_refs): nonlocal launch_ctx launch_ctx = _launch_ctx - body(launch_ctx, *in_refs, *out_refs, smem_refs) + body(launch_ctx, *arg_refs, smem_refs) main.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get() sym_tab = ir.SymbolTable(module.operation) sym_tab.insert(main.func_op) sym_tab.insert(global_scratch) module.operation.verify() - return module, out_shape, unwrap_output_tuple, launch_ctx, scratch_arr + return module, out_shape, unwrap_output_tuple, launch_ctx def _run_serde_pass( @@ -518,27 +696,6 @@ def _run_serde_pass( return module -def _initialize_scratch( - launch_ctx : launch_context.LaunchContext, - scratch_arr: ir.Value, - ): - """ - Allocates and initializes the host buffer right before the launch. This needs - to be done after all TMA descriptors have been recorded by the launch context. - Only then we know what the scratch contains. - - When using the Mosaic GPU dialect, the necessary information is known only - after the lowering passes have run. 
- """ - with ir.InsertionPoint(scratch_arr.owner): - gmem_scratch_bytes = launch_ctx.next_scratch_offset - scratch_alloc_op = scratch_arr.owner.opview.addr.owner.opview - scratch_arr_ty = ir.Type.parse(f"!llvm.array<{gmem_scratch_bytes} x i8>") - scratch_alloc_op.elem_type = ir.TypeAttr.get(scratch_arr_ty) - scratch_arr.set_type(scratch_arr_ty) - for init_callback in launch_ctx.host_scratch_init: - init_callback(scratch_alloc_op.result) - def _declare_runtime_functions(): """Declares the runtime functions that can be used by the generated code.""" ptr_ty = ir.Type.parse("!llvm.ptr") @@ -562,31 +719,49 @@ def as_gpu_kernel( module_name: str = "unknown", kernel_name: str | None = None, ir_version: int | None = None, - thread_semantics: ThreadSemantics = ThreadSemantics.Lane, + thread_semantics: LoweringSemantics = LoweringSemantics.Lane, + inout_shape = (), ): if isinstance(in_shape, list): in_shape = tuple(in_shape) elif not isinstance(in_shape, tuple): in_shape = (in_shape,) - - module, out_shape, unwrap_output_tuple, launch_ctx, scratch_arr = ( + if isinstance(inout_shape, list): + inout_shape = tuple(inout_shape) + elif not isinstance(inout_shape, tuple): + inout_shape = (inout_shape,) + + inout_shape = jax.tree.map(lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype), + inout_shape) + out_shape = jax.tree.map(lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype), + out_shape) + module, out_shape, unwrap_output_tuple, launch_ctx = ( _lower_as_gpu_kernel( - body, grid, cluster, block, in_shape, out_shape, smem_scratch_shape, - module_name, kernel_name, prof_spec + body, grid, cluster, block, in_shape, out_shape, inout_shape, + smem_scratch_shape, thread_semantics, module_name, kernel_name, + prof_spec ) ) - if thread_semantics == ThreadSemantics.Warpgroup and dialect is not None: + if thread_semantics == LoweringSemantics.Warpgroup and dialect is not None: + # We need to run a pass that removes dead-code for which layout inference + # does not work. + pm = mlir.passmanager.PassManager.parse("builtin.module(canonicalize)", module.context) + pm.run(module.operation) + # Run Python lowering passes. The remaining passes will be run in C++ in # jax/jaxlib/mosaic/gpu/custom_call.cc layout_inference.infer_layout(module) # pytype: disable=attribute-error transform_inference.infer_transforms(module) # pytype: disable=attribute-error dialect_lowering.lower_mgpu_dialect(module, launch_ctx) # pytype: disable=attribute-error - _initialize_scratch(launch_ctx, scratch_arr) + launch_ctx.scratch.finalize_size() module.operation.verify() - expected_arg_treedef = jax.tree.structure(in_shape) + if launch_ctx.is_device_collective and not supports_cross_device_collectives(): + raise RuntimeError("Kernel is a cross-device collective but no support is available.") + + expected_arg_tys, expected_arg_treedef = jax.tree.flatten((*in_shape, *inout_shape)) def _check_args(*args): arg_treedef = jax.tree.structure(args) if arg_treedef != expected_arg_treedef: @@ -594,9 +769,23 @@ def _check_args(*args): f"Invalid argument structure: expected {expected_arg_treedef}, got" f" {arg_treedef}, ({args=})" ) + for arg, expected_ty in zip(args, expected_arg_tys): + if arg.shape != expected_ty.shape: + raise ValueError( + f"Argument shape mismatch: expected {expected_ty.shape}, got" + f" {arg.shape}" + ) + if arg.dtype != expected_ty.dtype: + hint = "" + if not arg.shape: + hint = f". Hint: cast the scalar to {expected_ty.dtype} explicitly." 
+ raise ValueError( + f"Argument dtype mismatch: expected {expected_ty.dtype}, got" + f" {arg.dtype}{hint}" + ) def bind(*args) -> Any: - return mosaic_gpu_p.bind(*args, module=module, out_types=out_shape) + return mosaic_gpu_p.bind(*args, module=module, out_types=out_shape, inout_types=inout_shape) if prof_spec is not None: @jax.jit @@ -605,7 +794,7 @@ def prof_kernel(*args): *results, prof_buffer = bind(*args) def dump_profile(prof_buffer): out_file = os.path.join( - os.getenv("TEST_UNDECLARED_OUTPUTS_DIR"), + os.getenv("TEST_UNDECLARED_OUTPUTS_DIR", "/tmp"), f"{time.time_ns()}-trace.json", ) try: @@ -636,7 +825,7 @@ def as_torch_gpu_kernel( cluster: tuple[int, int, int] = (1, 1, 1), module_name: str = "unknown", kernel_name: str | None = None, - thread_semantics: ThreadSemantics = ThreadSemantics.Lane, + lowering_semantics: LoweringSemantics = LoweringSemantics.Lane, ): try: import torch @@ -652,23 +841,31 @@ def as_torch_gpu_kernel( flat_out_types, out_treedef = jax.tree.flatten(out_shape) expected_arg_treedef = jax.tree.structure(in_shape) - module, out_shape, unwrap_output_tuple, launch_ctx, scratch_arr = ( + module, out_shape, unwrap_output_tuple, launch_ctx = ( _lower_as_gpu_kernel( body, grid, cluster, block, in_shape, out_shape, smem_scratch_shape, - module_name, kernel_name, prof_spec + lowering_semantics, module_name, kernel_name, prof_spec ) ) - if thread_semantics == ThreadSemantics.Warpgroup and dialect is not None: + if lowering_semantics == LoweringSemantics.Warpgroup and dialect is not None: + # We need to run a pass that removes dead-code for which layout inference + # does not work. + pm = mlir.passmanager.PassManager.parse("builtin.module(canonicalize)", module.context) + pm.run(module.operation) + # Run Python lowering passes. 
The remaining passes will be run in C++ in # jax/jaxlib/mosaic/gpu/custom_call.cc layout_inference.infer_layout(module) # pytype: disable=attribute-error transform_inference.infer_transforms(module) # pytype: disable=attribute-error dialect_lowering.lower_mgpu_dialect(module, launch_ctx) # pytype: disable=attribute-error - _initialize_scratch(launch_ctx, scratch_arr) + launch_ctx.scratch.finalize_size() module.operation.verify() + if launch_ctx.is_device_collective: + raise RuntimeError("Kernel is a cross-device collective but no support is available.") + # Get our hands on the compilation and unload functions try: import jax_plugins.xla_cuda12 as cuda_plugin diff --git a/jax/experimental/mosaic/gpu/dialect_lowering.py b/jax/experimental/mosaic/gpu/dialect_lowering.py index fedde5a00887..932ccc1a7980 100644 --- a/jax/experimental/mosaic/gpu/dialect_lowering.py +++ b/jax/experimental/mosaic/gpu/dialect_lowering.py @@ -14,13 +14,16 @@ """Lowering rules and pass for the MLIR Mosaic GPU dialect.""" -from collections.abc import Callable +from collections.abc import Callable, Iterable import dataclasses import functools import itertools +import math import operator -from typing import Any, Sequence, Type, cast +from typing import Any, cast +from collections.abc import Sequence +from jax._src import lib as jaxlib from jax._src.interpreters import mlir as mlir_interpreter from jax._src.lib import mosaic_gpu_dialect as mgpu from jax._src.lib.mlir import ir @@ -34,6 +37,9 @@ from jax._src.lib.mlir.dialects import nvvm from jax._src.lib.mlir.dialects import scf from jax._src.lib.mlir.dialects import vector +from jax._src.util import safe_zip +from jax.experimental.mosaic.gpu import layouts as layouts_lib +from jax.experimental.mosaic.gpu import utils as mgpu_utils import numpy as np from . import fragmented_array as fa @@ -151,8 +157,49 @@ def _fragmented_array_from_ir( ).to_layout(layouts.from_layout_attr(layout)) +def wrap_transformed_memref( + transformed_memref: ir.Value, + logical_type: ir.Type, + transforms: ir.ArrayAttr, +) -> ir.Value: + """Wraps a transformed memref to an unrealized cast with transforms. + + The return type of the cast is the untransformed logical type. + """ + conversion_cast = builtin.UnrealizedConversionCastOp( + [logical_type], [transformed_memref] + ) + conversion_cast.attributes["transforms"] = transforms + return conversion_cast.result + + +def unwrap_transformed_memref( + ref: ir.Value, expected_transforms: ir.ArrayAttr +) -> ir.Value: + """Unwraps a memref from an unrealized cast and verifies its transforms.""" + + conversion_cast = cast( + builtin.UnrealizedConversionCastOp, ref.owner.opview # pytype: disable=attribute-error + ) + + if not isinstance(conversion_cast, builtin.UnrealizedConversionCastOp): + raise ValueError(f"{conversion_cast} is not a conversion_cast") + + # Check that the actual transforms match the expected ones. 
+ if expected_transforms != conversion_cast.attributes["transforms"]: + raise ValueError( + f"Expected transforms {expected_transforms} do not match actual" + f" transforms {conversion_cast.attributes['transforms']}" + ) + + result = builtin.unrealized_conversion_cast( + [conversion_cast.operands[0].type], [conversion_cast] + ) + return result + + def _register_lowering( - op: str | Type[ir.OpView] | None + op: str | type[ir.OpView] | None ) -> Callable[[MlirLoweringRule], MlirLoweringRule]: def wrapper(f): if op is not None: @@ -185,22 +232,63 @@ def _initialize_barrier_op_lowering_rule( for i in range(num_barriers): nvvm.mbarrier_init_shared( - llvm.getelementptr(ptr_ty, initialize_barrier_op.base_pointer, [], [i], - lowered_barrier_type), - utils.c(initialize_barrier_op.arrival_count.value, i32), - predicate=ctx.single_thread_per_block_predicate + llvm.getelementptr( + ptr_ty, + initialize_barrier_op.base_pointer, + [], + [i], + lowered_barrier_type, + llvm.GEPNoWrapFlags.none, + ), + utils.c( + initialize_barrier_op.arrival_count.value * utils.WARPGROUP_SIZE, + i32, + ), + predicate=ctx.single_thread_per_block_predicate, ) gpu.barrier() barrier_base_ptr = llvm.getelementptr( ir.Type.parse("!llvm.ptr"), - initialize_barrier_op.base_pointer, [], [0], lowered_barrier_type) + initialize_barrier_op.base_pointer, [], [0], lowered_barrier_type, llvm.GEPNoWrapFlags.none) return utils.ptr_as_memref( barrier_base_ptr, initialize_barrier_op.barriers_ref.type), +# TODO(bchetioui): remove once minimum jaxlib >= 0.5.3. +OptimizationBarrierOp = getattr(mgpu, "OptimizationBarrierOp", None) + + +@_register_lowering(OptimizationBarrierOp) +def _optimization_barrier_op_lowering_rule( + _: LoweringContext, + op: OptimizationBarrierOp, +) -> Sequence[ir.Value]: + if not all(ir.VectorType.isinstance(operand.type) for operand in op.operands): + raise NotImplementedError( + f"Optimization barrier op {op} has non-vector operands." 
+ ) + + fragmented_arrays = [] + for operand, layout in safe_zip(op.operands, inference_utils.in_layouts(op)): + ty = ir.VectorType(operand.type) + is_signed = False if ir.IntegerType.isinstance(ty.element_type) else None + fragmented_arrays.append( + _fragmented_array_from_ir(operand, layout, is_signed=is_signed) + ) + + lowered_fragmented_arrays = fa.optimization_barrier(*fragmented_arrays) + if isinstance(lowered_fragmented_arrays, fa.FragmentedArray): + lowered_fragmented_arrays = [lowered_fragmented_arrays] + + return [ + _fragmented_array_to_ir(arr, result.type) + for arr, result in safe_zip(lowered_fragmented_arrays, op.results) + ] + + @_register_lowering(arith.ConstantOp) def _arith_constant_op_lowering_rule( _: LoweringContext, op: arith.ConstantOp @@ -315,12 +403,13 @@ def _vector_load_op_lowering_rule( vec_size=strided_layout.vec_size, ) elif layouts.from_layout_attr(out_layout_attr) == fa.WGMMA_LAYOUT: + transforms_attr = inference_utils.in_transforms(vector_load_op)[0] swizzle, transforms = swizzle_and_transforms_from_transforms_attr( - inference_utils.in_transforms(vector_load_op)[0] + transforms_attr ) ref_ty = ir.MemRefType(vector_load_op.base.type) _check_transforms_and_swizzle_are_supported(ref_ty, transforms, swizzle) - transformed_ref = transform_memref(vector_load_op.base, transforms) + transformed_ref = unwrap_transformed_memref(vector_load_op.base, transforms_attr) fragmented_array = fa.FragmentedArray.load_tiled( transformed_ref, swizzle=swizzle, @@ -355,22 +444,33 @@ def _vector_store_op_lowering_rule( vector_store_op.valueToStore, to_store_layout ) - if fragmented_array.layout == fa.WGMMA_LAYOUT: + mgpu_utils.warpgroup_barrier() # Make sure the reads have completed. + + unwrapped_ref = vector_store_op.base + swizzle = None + if inference_utils.should_have_transforms(vector_store_op): + # Not all vector loads have transforms. E.g. if the store is directly to + # gmem, it won't have any transforms. + transforms_attr = inference_utils.in_transforms(vector_store_op)[0] swizzle, transforms = swizzle_and_transforms_from_transforms_attr( - inference_utils.in_transforms(vector_store_op)[0] + transforms_attr ) ref_ty = ir.MemRefType(vector_store_op.base.type) _check_transforms_and_swizzle_are_supported(ref_ty, transforms, swizzle) - fragmented_array.store_tiled( - transform_memref(vector_store_op.base, transforms), swizzle - ) - elif (isinstance(fragmented_array.layout, fa.WGStridedFragLayout) or + unwrapped_ref = unwrap_transformed_memref(vector_store_op.base, transforms_attr) + + if fragmented_array.layout == fa.WGMMA_LAYOUT: + fragmented_array.store_tiled(unwrapped_ref, swizzle) + elif (fragmented_array.layout == fa.WGMMA_ROW_LAYOUT or + fragmented_array.layout == fa.WGMMA_COL_LAYOUT or + isinstance(fragmented_array.layout, fa.WGStridedFragLayout) or isinstance(fragmented_array.layout, fa.WGSplatFragLayout)): - fragmented_array.store_untiled(vector_store_op.base) + fragmented_array.store_untiled(unwrapped_ref) else: raise ValueError( f"{vector_store_op} has an unsupported layout: {to_store_layout}" ) + mgpu_utils.warpgroup_barrier() # Make sure the writes have completed. 
return [] @@ -423,7 +523,7 @@ def _vector_reduction_op_lowering_rule( ir.MemRefType.get([4], element_type, memory_space=smem), arith.constant(None, op.attributes["offset"]), ) - result = a.reduce_sum(scratch) + result = a.reduce("add", range(len(a.shape)), scratch) case ( "#vector.kind" | "#vector.kind" | "#vector.kind" ): @@ -433,6 +533,86 @@ def _vector_reduction_op_lowering_rule( raise NotImplementedError(f"Unsupported reduction kind: {op.kind}") return [_fragmented_array_to_ir(result, op.result.type)] +@_register_lowering(vector.MultiDimReductionOp) +def _vector_multi_dim_reduction_op_lowering_rule( + ctx: LoweringContext, op: vector.MultiDimReductionOp +) -> Sequence[ir.Value]: + del ctx + + [in_layout, acc_layout] = inference_utils.in_layouts(op) + [out_layout] = inference_utils.out_layouts(op) + if layouts.from_layout_attr(in_layout) != fa.WGMMA_LAYOUT: + raise NotImplementedError(f"Unsupported input layout: {in_layout}") + if layouts.from_layout_attr(out_layout) not in { + fa.WGMMA_ROW_LAYOUT, + fa.WGMMA_COL_LAYOUT, + }: + raise NotImplementedError(f"Unsupported output layout: {out_layout}") + if out_layout != acc_layout: + raise ValueError( + f"Output layout {out_layout} must match the accumulator layout" + f" {acc_layout}" + ) + + element_type = ir.VectorType(op.source.type).element_type + + is_signed = False if ir.IntegerType.isinstance(element_type) else None + source_fa = _fragmented_array_from_ir(op.source, in_layout, is_signed) + acc_fa = _fragmented_array_from_ir(op.acc, acc_layout, is_signed) + match vector.CombiningKind[ + str(op.kind).removeprefix("#vector.kind<").removesuffix(">").upper() + ]: + case vector.CombiningKind.ADD: + result = source_fa.reduce("add", op.reduction_dims[0]) + result += acc_fa + case ( + vector.CombiningKind.MAXIMUMF + | vector.CombiningKind.MAXSI + | vector.CombiningKind.MAXUI + ): + result = source_fa.reduce("max", op.reduction_dims[0]) + result = result.max(acc_fa) + case _: + raise NotImplementedError(f"Unsupported reduction kind: {op.kind}") + return [_fragmented_array_to_ir(result, op.result.type)] + + +@_register_lowering(mgpu.LayoutCastOp) +def _mgpu_layout_cast_op_lowering_rule( + _: LoweringContext, layout_cast_op: mgpu.LayoutCastOp +) -> Sequence[ir.Value]: + return [layout_cast_op.x] + + +# TODO(dasenov): Remove this after the minimal jaxlib version is 0.6.1. 
+if hasattr(mgpu, "BroadcastInDimOp"): + @_register_lowering(mgpu.BroadcastInDimOp) + def _mgpu_broadcast_in_dim_op_lowering_rule( + _: LoweringContext, op: mgpu.BroadcastInDimOp + ) -> Sequence[ir.Value]: + in_ty = ir.VectorType(op.operand.type) + out_ty = ir.VectorType(op.result.type) + if len(in_ty.shape) != 1 or len(out_ty.shape) != 2: + raise NotImplementedError( + "Broadcast in dim with non-trivial broadcast dimensions is not" + f" supported: {op}" + ) + + broadcast_dims = list(op.broadcast_dimensions) + in_layout = inference_utils.in_layouts(op)[0] + operand_fa = _fragmented_array_from_ir(op.operand, in_layout) + + if (operand_fa.layout == fa.WGMMA_ROW_LAYOUT and broadcast_dims == [0]): + out = operand_fa.broadcast_minor(out_ty.shape[1]) + elif (operand_fa.layout == fa.WGMMA_COL_LAYOUT and broadcast_dims == [1]): + out = operand_fa.broadcast_major(out_ty.shape[0]) + else: + raise NotImplementedError( + "Broadcast in dim with non-trivial broadcast dimensions is not" + f" supported: {op}" + ) + return [_fragmented_array_to_ir(out, out_ty)] + def swizzle_and_transforms_from_transforms_attr( transforms: ir.ArrayAttr, @@ -475,32 +655,92 @@ def swizzle_and_transforms_from_transforms_attr( return swizzle or mgpu.SwizzlingMode.kNoSwizzle, tuple(gmem_transforms) -def transform_memref( - mem_ref: ir.Value, transforms: tuple[launch_context.MemRefTransform, ...] -) -> ir.Value: - """Reinterprets the memref to one where the shape is transformed as given.""" - if not transforms: - return mem_ref +def _is_memref_transposed(mem_ref_type: ir.MemRefType) -> bool: + strides, _ = mem_ref_type.get_strides_and_offset() + prev_stride = math.inf + for stride in strides: + if stride > prev_stride: + return True + prev_stride = stride + return False - mem_ref_type = ir.MemRefType(mem_ref.type) - if mem_ref_type.memory_space != ir.Attribute.parse( - "#gpu.address_space" - ): - raise ValueError(f"Only workgroup memory is supported but got {mem_ref}.") - shape = mem_ref_type.shape +def _transformed_smem_ref_type( + ref_ty: ir.MemRefType, + transforms: tuple[launch_context.MemRefTransform, ...], +) -> ir.MemRefType: + """Returns the transformed ref type for the given logical ref and transforms. + """ + transposed = _is_memref_transposed(ref_ty) + if not transforms and not transposed: + return ref_ty + + if ref_ty.memory_space != ir.Attribute.parse("#gpu.address_space"): + raise ValueError(f"Only workgroup memory is supported but got {ref_ty}.") + + shape = ref_ty.shape + strides, offset = ref_ty.get_strides_and_offset() + if transposed: + if len(shape) != 2: + raise NotImplementedError( + f"Only 2D shapes can be transposed, but got {shape}" + ) + if strides[0] != 1 or strides[1] != shape[0]: + raise NotImplementedError( + f"Only contiguous 2D memrefs can be transposed, but got {ref_ty}" + ) + for t in transforms: - shape = t.transform_shape(shape) + shape = list(t.transform_shape(shape)) + + if transposed: + # The expected output is a transposed ref and `shape` is already transposed. + # We need to compute the correct strides to match the shape. 
+ if len(shape) == 2: + minor_to_major_stride_order = (1, 0) + elif len(shape) == 4: + minor_to_major_stride_order = (2, 3, 0, 1) + else: + raise NotImplementedError( + f"Expected a 2D or 4D shape after transforms, but got {shape}" + ) + else: + minor_to_major_stride_order = tuple(reversed(range(len(shape)))) - memref_new_type = ir.MemRefType.get( + new_strides = [1] * len(shape) + for i in range(1, len(shape)): + dim = minor_to_major_stride_order[i] + prev_dim = minor_to_major_stride_order[i-1] + new_strides[dim] = new_strides[prev_dim] * shape[prev_dim] + + new_ref_ty = ir.MemRefType.get( shape, - mem_ref_type.element_type, - memory_space=mem_ref_type.memory_space, + ref_ty.element_type, + memory_space=ref_ty.memory_space, + layout=ir.StridedLayoutAttr.get(offset, new_strides), ) + return new_ref_ty + + +def reinterpret_smem_ref( + ref: ir.Value, + transforms: tuple[launch_context.MemRefTransform, ...], +) -> ir.Value: + """Applies transforms on the ref, and makes sure that their effect is + propagated appropriately on the strides. + This function is used any time we lower from a dialect SMEM ref (2D for wgmma) + with given transforms to a "physical" SMEM ref (4D for wgmma) that is fully + transformed and transposed as needed. + """ + ref_ty = ir.MemRefType(ref.type) + new_ref_ty = _transformed_smem_ref_type(ref_ty, transforms) + if ref_ty == new_ref_ty: + return ref ms = utils.WORKGROUP_NVPTX_ADDRESS_SPACE - ptr = utils.memref_ptr(mem_ref, memory_space=ms) - return utils.ptr_as_memref(ptr, memref_new_type, ptr_memory_space=ms) + ptr = utils.memref_ptr(ref, memory_space=ms) + new_ref = utils.ptr_as_memref(ptr, new_ref_ty, ptr_memory_space=ms) + return new_ref @_register_lowering(mgpu.AsyncLoadOp) @@ -508,16 +748,15 @@ def _mgpu_async_load_op_lowering_rule( ctx: LoweringContext, load_op: mgpu.AsyncLoadOp ) -> Sequence[ir.Value]: assert ctx.launch_context is not None - barrier = utils.BarrierRef.from_dialect_barrier_memref(load_op.barrier) + barrier = utils.DialectBarrierRef.from_barrier_memref(load_op.barrier) - if inference_utils.has_in_transforms_set(load_op): - [transforms] = inference_utils.in_transforms(load_op) - swizzle, transforms = swizzle_and_transforms_from_transforms_attr( - transforms - ) - else: - swizzle = mgpu.SwizzlingMode.kNoSwizzle - transforms = () + [transforms_attr] = inference_utils.in_transforms(load_op) + swizzle, transforms = swizzle_and_transforms_from_transforms_attr( + transforms_attr + ) + unwrapped_destination = unwrap_transformed_memref( + load_op.destination, transforms_attr + ) gmem_slice = [] for idx_i32, size in zip(load_op.indices, load_op.slice_lengths): @@ -525,14 +764,19 @@ v = idx if size < 0 else utils.DynamicSlice(idx, size) gmem_slice.append(v) + # TODO(dasenov): async_copy requires all GMEM strides except the last one + # to be a multiple of 16 bytes. This restriction could be loosened with + # strided layouts when they are contiguous in GMEM. In that case, we could do: + # flatten -> async_copy -> unflatten here, as long as flattened size is a + # multiple of 16. + # TODO(dasenov): Add support for the remaining op properties. 
ctx.launch_context.async_copy( src_ref=load_op.source, - dst_ref=transform_memref(load_op.destination, transforms), + dst_ref=unwrapped_destination, gmem_slice=tuple(gmem_slice), - barrier=barrier, + barrier=barrier.barrier_ref, arrive=False, - uniform=True, swizzle=swizzle, gmem_transform=transforms, predicate=ctx.single_thread_per_warpgroup_predicate, @@ -546,14 +790,11 @@ def _mgpu_async_store_op_lowering_rule( ) -> Sequence[ir.Value]: assert ctx.launch_context is not None - if inference_utils.has_in_transforms_set(store_op): - [transforms] = inference_utils.in_transforms(store_op) - swizzle, transforms = swizzle_and_transforms_from_transforms_attr( - transforms - ) - else: - swizzle = mgpu.SwizzlingMode.kNoSwizzle - transforms = () + [transforms_attr] = inference_utils.in_transforms(store_op) + swizzle, transforms = swizzle_and_transforms_from_transforms_attr( + transforms_attr + ) + unwrapped_source = unwrap_transformed_memref(store_op.source, transforms_attr) gmem_slice = [] for idx_i32, size in zip(store_op.indices, store_op.slice_lengths): @@ -561,14 +802,19 @@ v = idx if size < 0 else utils.DynamicSlice(idx, size) gmem_slice.append(v) + # TODO(dasenov): async_copy requires all GMEM strides except the last one + # to be a multiple of 16 bytes. This restriction could be loosened with + # strided layouts when they are contiguous in GMEM. In that case, we could do: + # flatten -> async_copy -> unflatten here, as long as flattened size is a + # multiple of 16. + # TODO(dasenov): Add support for the remaining op properties. ctx.launch_context.async_copy( - src_ref=transform_memref(store_op.source, transforms), + src_ref=unwrapped_source, dst_ref=store_op.destination, gmem_slice=tuple(gmem_slice), swizzle=swizzle, gmem_transform=transforms, - uniform=True, predicate=ctx.single_thread_per_warpgroup_predicate, arrive=store_op.commit_group, ) @@ -761,9 +1007,6 @@ def _bitcast_op_lowering_rule( def _mgpu_wgmma_op_lowering_rule( _: LoweringContext, wgmma_op: mgpu.WGMMAOp ) -> Sequence[ir.Value]: - if wgmma_op.transpose_a or wgmma_op.transpose_b: - raise ValueError("Transpose arguments are to be deleted.") - fa_layouts = ( *inference_utils.in_layouts(wgmma_op), *inference_utils.out_layouts(wgmma_op), @@ -775,7 +1018,7 @@ raise ValueError("Layout mismatch") wgmma_layout = fa_layouts[0] - # TODO(dasenov): Move the value -> accumulator conversion outisde of wgmma. + # TODO(dasenov): Move the value -> accumulator conversion outside of wgmma. # The associated fence could be a little expensive and is not needed if the # result a wgmma feeds into another wgmma (even in another loop step). 
acc_in = _fragmented_array_from_ir(wgmma_op.accumulator, wgmma_layout) @@ -785,18 +1028,20 @@ def _mgpu_wgmma_op_lowering_rule( if ir.VectorType.isinstance(wgmma_op.a.type): a_transforms = None b_transforms = inference_utils.in_transforms(wgmma_op)[0] + unwrapped_a_ref = None + unwrapped_b_ref = unwrap_transformed_memref(wgmma_op.b, b_transforms) else: a_transforms, b_transforms = inference_utils.in_transforms(wgmma_op) + unwrapped_a_ref = unwrap_transformed_memref(wgmma_op.a, a_transforms) + unwrapped_b_ref = unwrap_transformed_memref(wgmma_op.b, b_transforms) b_swizzle, b_transforms = swizzle_and_transforms_from_transforms_attr( b_transforms ) minimum_swizzle = mgpu.SwizzlingMode.k32ByteSwizzle - ref_ty = ir.MemRefType(wgmma_op.b.type) _check_transforms_and_swizzle_are_supported( - ref_ty, b_transforms, b_swizzle, minimum_swizzle + ir.MemRefType(wgmma_op.b.type), b_transforms, b_swizzle, minimum_swizzle ) - b_operand = transform_memref(wgmma_op.b, b_transforms) if ir.VectorType.isinstance(wgmma_op.a.type): a_operand = _fragmented_array_from_ir(wgmma_op.a, wgmma_layout) @@ -804,18 +1049,17 @@ def _mgpu_wgmma_op_lowering_rule( a_swizzle, a_transforms = swizzle_and_transforms_from_transforms_attr( a_transforms ) - ref_ty = ir.MemRefType(wgmma_op.a.type) _check_transforms_and_swizzle_are_supported( - ref_ty, a_transforms, a_swizzle, minimum_swizzle + ir.MemRefType(wgmma_op.a.type), a_transforms, a_swizzle, minimum_swizzle ) if a_swizzle != b_swizzle: raise ValueError( f"Non-matching swizzles of operands a and b in WGMMA: {a_swizzle} !=" f" {b_swizzle}" ) - a_operand = transform_memref(wgmma_op.a, a_transforms) + a_operand = unwrapped_a_ref - new_acc = wgmma.wgmma(acc, a_operand, b_operand, swizzle=b_swizzle) + new_acc = wgmma.wgmma(acc, a_operand, unwrapped_b_ref, swizzle=b_swizzle) return [ _fragmented_array_to_ir( @@ -827,14 +1071,25 @@ def _mgpu_wgmma_op_lowering_rule( @_register_lowering(mgpu.ArriveExpectTxOp) def _mgpu_arrive_expect_tx_op_lowering_rule( - ctx: LoweringContext, arrive_expect_tx_op: mgpu.ArriveExpectTxOp + _: LoweringContext, arrive_expect_tx_op: mgpu.ArriveExpectTxOp ) -> Sequence[ir.Value]: - - barrier = utils.BarrierRef.from_dialect_barrier_memref(arrive_expect_tx_op.barrier) - barrier.arrive_expect_tx( - arrive_expect_tx_op.expect_tx.value, - ctx.single_thread_per_warpgroup_predicate, + bytes = arrive_expect_tx_op.expect_tx.value + if bytes % utils.WARPGROUP_SIZE: + raise NotImplementedError( + "Only copies of a multiple of 128 bytes are supported" + ) + # We arrive uniformly from each thread in the WG, so we need to divide the + # number of bytes by the number of threads in the WG. + # TODO: dasenov - Relax this. We can just select the WG leader and have it + # arrive with the whole transfer size, while everyone else arrives with 0. + # But we should continue using this scheme as it's likely to be faster. + bytes //= utils.WARPGROUP_SIZE + bytes = utils.c(bytes, ir.IntegerType.get_signless(32)) + + barrier = utils.DialectBarrierRef.from_barrier_memref( + arrive_expect_tx_op.barrier ) + nvvm.mbarrier_arrive_expect_tx_shared(barrier.get_ptr(), bytes) return [] @@ -844,22 +1099,30 @@ def _mgpu_wait_op_lowering_rule( _: LoweringContext, wait_op: mgpu.WaitOp ) -> Sequence[ir.Value]: - barrier = utils.BarrierRef.from_dialect_barrier_memref(wait_op.barrier) + barrier = utils.DialectBarrierRef.from_barrier_memref(wait_op.barrier) barrier.wait_parity(wait_op.parity) return [] -# TODO(bchetioui): remove this once jaxlib minimum version >= 0.5.2. 
-SliceSMEMOp = getattr(mgpu, "SliceSMEMOp", None) - - -@_register_lowering(SliceSMEMOp) +@_register_lowering(mgpu.SliceSMEMOp) def _mgpu_slice_smem_op_lowering_rule( - ctx: LoweringContext, op: SliceSMEMOp + ctx: LoweringContext, op: mgpu.SliceSMEMOp ) -> Sequence[ir.Value]: del ctx - return [_slice_smem(op.result.type, op.offset)] + sliced_ref = _slice_smem(op.result.type, op.offset) + + memref_ty = ir.MemRefType(sliced_ref.type) + if memref_ty.element_type == ir.Type.parse("!mosaic_gpu.barrier"): + # Barrier memrefs are not transformed and must not be wrapped. + assert not inference_utils.has_out_transforms_set(op) + return [sliced_ref] + + out_transforms = inference_utils.out_transforms(op)[0] + _, transforms = swizzle_and_transforms_from_transforms_attr(out_transforms) + transformed_ref = reinterpret_smem_ref(sliced_ref, transforms) + wrapped_ref = wrap_transformed_memref(transformed_ref, op.result.type, out_transforms) + return [wrapped_ref] def _slice_smem(result: ir.Type, offset: ir.Value): @@ -869,8 +1132,414 @@ ir.MemRefType.get((utils.DYNAMIC,), i8, memory_space=smem) ) offset = arith.index_cast(ir.IndexType.get(), offset) - return memref.view(result, smem_base, offset, []) + lowered_result_type = result + if ir.MemRefType.isinstance(result): + memref_ty = ir.MemRefType(result) + if memref_ty.element_type == ir.Type.parse("!mosaic_gpu.barrier"): + lowered_result_type = ir.MemRefType.get( + memref_ty.shape, _lowered_barrier_type(), memory_space=smem + ) + view = memref.view(lowered_result_type, smem_base, offset, []) + if result == lowered_result_type: + return view + return builtin.unrealized_conversion_cast([result], [view]) + + +# TODO(dasenov): Remove this after the minimal jaxlib version is 0.6.2. +if jaxlib.version >= (0, 6, 2): + @_register_lowering(mgpu.WithTransformsOp) + def _mgpu_with_transforms_op_lowering_rule( + ctx: LoweringContext, op: mgpu.WithTransformsOp + ) -> Sequence[ir.Value]: + """Lowering rule for mgpu.WithTransformsOp. + This is a noop that simply returns its input. + """ + del ctx + + [in_transforms] = inference_utils.in_transforms(op) + unwrapped_source_ref = unwrap_transformed_memref(op.ref, in_transforms) + out_transforms = inference_utils.out_transforms(op)[0] + wrapped_ref = wrap_transformed_memref( + unwrapped_source_ref, op.result.type, out_transforms + ) + return [wrapped_ref] + + +def _tile_transform_offsets( + tiling: Sequence[int], + static_offsets: Sequence[int], + dynamic_offsets: Sequence[ir.Value], +) -> tuple[Sequence[int], Sequence[ir.Value]]: + """Computes the static and dynamic offsets after the given tiling is applied. + + Conceptually, this function is analogous to + tile.transform_shape(static_offsets), except that it also handles dynamic offsets. + """ + dynamic_offset_index = 0 + new_static_offsets = [] + new_dynamic_offsets = [] + + # Preserve all offsets in non-tiled dimensions. + for offset in static_offsets[: -len(tiling)]: + new_static_offsets.append(offset) + if offset == ir.ShapedType.get_dynamic_stride_or_offset(): + new_dynamic_offsets.append(dynamic_offsets[dynamic_offset_index]) + dynamic_offset_index += 1 + + # Compute static and dynamic offsets of tiled dimensions. + for tile_size, offset in zip( + tiling, static_offsets[-len(tiling) :], strict=True + ): + if offset == ir.ShapedType.get_dynamic_stride_or_offset(): + # Here we assume that the offset is divisible by the tile size, but we + # don't check it. 
This has been established at the time the tiling was + # inferred. + dyn_offset = arith.divui( + dynamic_offsets[dynamic_offset_index], + utils.c(tile_size, ir.IndexType.get()), + ) + new_dynamic_offsets.append(dyn_offset) + new_static_offsets.append(ir.ShapedType.get_dynamic_stride_or_offset()) + dynamic_offset_index += 1 + else: + assert offset % tile_size == 0 + new_static_offsets.append(offset // tile_size) + + # Add 0 offsets for the newly created dimension of the tile. + new_static_offsets += [0] * len(tiling) + + return new_static_offsets, new_dynamic_offsets + + +@_register_lowering(memref.SubViewOp) +def _memref_subview_op_lowering_rule( + ctx: LoweringContext, op: memref.SubViewOp +) -> Sequence[ir.Value]: + del ctx + + in_transforms = inference_utils.in_transforms(op)[0] + out_transforms = inference_utils.out_transforms(op)[0] + + if in_transforms != out_transforms: + raise NotImplementedError( + "SubViewOp transforms for the input and output refs must be identical." + ) + + if any(s != 1 for s in op.static_strides): + raise NotImplementedError( + "SubViewOp only supports static strides of 1." + ) + + if _is_memref_transposed(op.source.type): + raise NotImplementedError( + "SubViewOp does not support transposed memrefs." + ) + + unwrapped_source_ref = unwrap_transformed_memref(op.source, in_transforms) + swizzle, transforms = swizzle_and_transforms_from_transforms_attr(out_transforms) + if swizzle != mgpu.SwizzlingMode.kNoSwizzle: + source_ty = ir.MemRefType(op.source.type) + source_strides, _ = source_ty.get_strides_and_offset() + for stride, slice, size in zip(source_strides, op.static_sizes, source_ty.shape, strict=True): + if stride != 1: + continue + # A dimension with stride 1 is a minor dimension and is swizzled. + if slice != size: + raise NotImplementedError("Slicing a swizzled dimension is unsupported.") + + match transforms: + case (): + new_subview_op = memref.SubViewOp( + op.result.type, + unwrapped_source_ref, + op.offsets, + None, + None, + static_offsets=op.static_offsets, + static_sizes=op.static_sizes, + static_strides=op.static_strides, + ) + case (tile_transform, ) if isinstance(tile_transform, launch_context.TileTransform): + in_transformed_ty = ir.MemRefType(unwrapped_source_ref.type) + tiling = tile_transform.tiling + if any( + ir.ShapedType.is_dynamic_size(s) + for s in list(op.static_sizes)[-len(tiling) :] + ): + raise NotImplementedError( + "SubViewOp only supports static sizes for the tiled dimensions." + ) + new_sizes = tile_transform.transform_shape(list(op.static_sizes)) + new_static_offsets, new_dynamic_offsets = _tile_transform_offsets( + tiling, list(op.static_offsets), list(op.offsets) + ) + + new_subview_op = memref.SubViewOp( + _transformed_smem_ref_type(op.result.type, transforms), + unwrapped_source_ref, + new_dynamic_offsets, + None, + None, + static_offsets=new_static_offsets, + static_sizes=new_sizes, + static_strides=[1] * len(in_transformed_ty.shape), + ) + case _: + raise NotImplementedError( + "SubViewOp only supports a single tile transform." + ) + + wrapped_ref = wrap_transformed_memref( + new_subview_op.result, op.result.type, out_transforms + ) + return [wrapped_ref] + + +@_register_lowering(memref.CastOp) +def _memref_cast_op_lowering_rule( + ctx: LoweringContext, op: memref.CastOp +) -> Sequence[ir.Value]: + """Lowering rule for memref.CastOp. + Only casts that add a dynamic offset are supported. 
+ """ + del ctx + + in_transforms = inference_utils.in_transforms(op)[0] + out_transforms = inference_utils.out_transforms(op)[0] + if in_transforms != out_transforms: + raise NotImplementedError( + "CastOp transforms for the input and output refs must be identical." + ) + + in_ty = ir.MemRefType(op.source.type) + out_ty = ir.MemRefType(op.result.type) + if in_ty.element_type != out_ty.element_type: + raise NotImplementedError( + "CastOp only supports casts between memrefs with the same element type." + ) + if in_ty.shape != out_ty.shape: + raise NotImplementedError( + "CastOp only supports casts between memrefs with the same shape." + ) + in_strides, _ = in_ty.get_strides_and_offset() + out_strides, out_offset = out_ty.get_strides_and_offset() + if in_strides != out_strides: + raise NotImplementedError( + "CastOp only supports casts between memrefs with the same strides." + ) + + unwrapped_source_ref = unwrap_transformed_memref(op.source, in_transforms) + in_transformed_ty = ir.MemRefType(unwrapped_source_ref.type) + transformed_strides, _ = in_transformed_ty.get_strides_and_offset() + out_layout = ir.StridedLayoutAttr.get(out_offset, transformed_strides) + out_transformed_ty = ir.MemRefType.get( + in_transformed_ty.shape, + in_transformed_ty.element_type, + memory_space=in_transformed_ty.memory_space, + layout=out_layout, + ) + new_cast_op = memref.CastOp(out_transformed_ty, unwrapped_source_ref) + wrapped_ref = wrap_transformed_memref( + new_cast_op.result, op.result.type, out_transforms + ) + return [wrapped_ref] + + +def _permutation_to_affine_map_attr( + permutation: Sequence[int], +) -> ir.AffineMapAttr: + return ir.AffineMapAttr.get(ir.AffineMap.get_permutation(permutation)) + + +@_register_lowering(memref.TransposeOp) +def _memref_transpose_op_lowering_rule( + ctx: LoweringContext, op: memref.TransposeOp +) -> Sequence[ir.Value]: + del ctx + + in_transforms = inference_utils.in_transforms(op)[0] + unwrapped_in_ref = unwrap_transformed_memref(op.in_, in_transforms) + in_transformed_ty = ir.MemRefType(unwrapped_in_ref.type) + if len(in_transformed_ty.shape) == 2: + new_permutation = op.permutation + elif len(in_transformed_ty.shape) == 4: + if op.permutation == _permutation_to_affine_map_attr([0, 1]): + new_permutation = _permutation_to_affine_map_attr([0, 1, 2, 3]) + elif op.permutation == _permutation_to_affine_map_attr([1, 0]): + new_permutation = _permutation_to_affine_map_attr([1, 0, 3, 2]) + else: + raise NotImplementedError("Unsupported permutation.") + else: + raise NotImplementedError( + "TransposeOp only supports transposing 2D and 4D memrefs." + ) + + out_transforms = inference_utils.out_transforms(op)[0] + _, transforms = swizzle_and_transforms_from_transforms_attr(out_transforms) + new_transpose_op = memref.TransposeOp( + _transformed_smem_ref_type(op.result.type, transforms), + unwrapped_in_ref, + new_permutation, + ) + + wrapped_ref = wrap_transformed_memref( + new_transpose_op.result, op.result.type, out_transforms + ) + return [wrapped_ref] + + +@_register_lowering(memref.LoadOp) +def _memref_load_op_lowering_rule( + ctx: LoweringContext, op: memref.LoadOp +) -> Sequence[ir.Value]: + """Lowering rule for memref.LoadOp. + + Loads are never transformed so this rule is mostly just a pass-through. 
+ """ + del ctx + + in_transforms = inference_utils.in_transforms(op)[0] + if in_transforms: + raise NotImplementedError(f"memref.LoadOp does not support transforms: {op}") + + new_load_op = memref.LoadOp( + memref=unwrap_transformed_memref(op.memref, in_transforms), + indices=op.indices, + nontemporal=op.nontemporal, + ) + return [new_load_op.result] + + +@_register_lowering(memref.StoreOp) +def _memref_store_op_lowering_rule( + ctx: LoweringContext, op: memref.StoreOp +) -> Sequence[ir.Value]: + """Lowering rule for memref.StoreOp. + + Stores are never transformed so this rule is mostly just a pass-through. + """ + del ctx + + in_transforms = inference_utils.in_transforms(op)[0] + if in_transforms: + raise NotImplementedError(f"memref.StoreOp does not support transforms: {op}") + + memref.StoreOp( + value=op.value, + memref=unwrap_transformed_memref(op.memref, in_transforms), + indices=op.indices, + nontemporal=op.nontemporal, + ) + return [] + + +# The metadata needed to recostruct a vector from its flattened representation. +_VectorTemplate = tuple[Sequence[int], fa.FragmentedLayout, ir.VectorType] + +def _flatten_ir_values( + values: Sequence[ir.Value], fa_layouts: Iterable[ir.Attribute] +) -> tuple[Sequence[ir.Value], Sequence[_VectorTemplate | None]]: + """Flattens a sequence of values. + + Non-vector values are preserved as is. Vectors are mapped to fragmented + arrays and then flattened into per-register values. + + Args: + values: The sequence of values to flatten. + fa_layouts: The layouts of vectors in ``values``. + + Returns: + A tuple of (flattened values, templates). The templates are used to + reconstruct the vectors from the per-register values. + """ + fa_layouts_it = iter(fa_layouts) + result = [] + templates = [] + for v in values: + if ir.VectorType.isinstance(v.type): + fa = _fragmented_array_from_ir(v, next(fa_layouts_it)) + result.extend(fa.registers.flat) + templates.append((fa.registers.shape, fa.layout, ir.VectorType(v.type))) + else: + result.append(v) + templates.append(None) + return result, templates + + +def _unflatten_ir_values( + flat_values: Sequence[ir.Value], templates: Sequence[_VectorTemplate | None] +) -> Sequence[ir.Value]: + """The inverse of ``_flatten_ir_values``.""" + result = [] + flat_values_it = iter(flat_values) + for template in templates: + if template is None: + result.append(next(flat_values_it)) + continue + registers_shape, layout, vec_type = template + value_registers = np.asarray( + [next(flat_values_it) for _ in range(math.prod(registers_shape))], + dtype=object, + ) + value = fa.FragmentedArray( + _registers=value_registers.reshape(registers_shape), + _layout=layout, + _is_signed=False + if ir.IntegerType.isinstance(vec_type.element_type) + else None, + ) + result.append(_fragmented_array_to_ir(value, vec_type)) + return result + + +def _move_scf_block_to_block_with_flattened_arguments( + ctx: LoweringContext, + old_block: ir.Block, + new_block: ir.Block, + last_op_type: type[ir.OpView], + args_template: Sequence[_VectorTemplate | None], + *new_leading_args: Sequence[ir.Value], +) -> Sequence[_VectorTemplate | None]: + """Moves the operations from `old_block` to `new_block`. + + The input arguments to the block, if any, are flattened using the provided + `args_template`, except for any new_leading_args which are simply prepended + to the flattened arguments and must be part of the template. + + The last operation of the old block must be of type `last_op_type` which + is expected to be either a `scf.YieldOp` or a `scf.ConditionOp`. 
This + operation is recreated with flattened output arguments. + """ + out_template = None + with ir.InsertionPoint(new_block): + new_carry = _unflatten_ir_values(new_block.arguments[len(new_leading_args):], args_template) + new_args = new_leading_args + tuple(new_carry) + for old_arg, new_arg in zip(old_block.arguments, new_args, strict=True): + old_arg.replace_all_uses_with(new_arg) + for op in [*old_block]: + if not isinstance(op, last_op_type): + mgpu.private_operation_remove_from_parent(op) + mgpu.private_block_append_owned_operation(new_block, op) + ctx.lower_op(op) + else: + assert out_template is None + layouts = ( + inference_utils.in_layouts(op) + if inference_utils.has_in_layouts_set(op) + else [] + ) + if isinstance(op, scf.YieldOp): + flat_operands, out_template = _flatten_ir_values(op.operands, layouts) + scf.yield_(flat_operands) + elif isinstance(op, scf.ConditionOp): + flat_carry, out_template = _flatten_ir_values(op.args, layouts) + scf.condition(op.condition, flat_carry) + else: + raise NotImplementedError(f"Unsupported op type: {op}") + op.erase() + assert out_template is not None + return out_template @_register_lowering(scf.ForOp) def _for_op_lowering_rule( @@ -884,84 +1553,145 @@ def _for_op_lowering_rule( yield_layouts = inference_utils.in_layouts(yield_op) if in_layouts != out_layouts or in_layouts != yield_layouts: raise ValueError("Layout mismatch") - fa_layouts = in_layouts - - fa_layouts_it = iter(fa_layouts) - arg_template = [ - (_fragmented_array_from_ir(arg, next(fa_layouts_it)), arg.type) - if ir.VectorType.isinstance(arg.type) - else (arg, arg.type) - for arg in for_op.initArgs - ] - def lower_carry(carry): - fa_layouts_it = iter(fa_layouts) - carry_with_fas = [ - _fragmented_array_from_ir(arg, next(fa_layouts_it)) - if ir.VectorType.isinstance(arg.type) - else arg - for arg in carry - ] - lowered_carry = [] - for c in carry_with_fas: - if isinstance(c, fa.FragmentedArray): - lowered_carry.extend(c.registers.flat) - else: - lowered_carry.append(c) - return lowered_carry - - def recreate_carry(lowered_carry): - recreated_carry = [] - arg_it = iter(lowered_carry) - for arg_value, arg_type in arg_template: - if isinstance(arg_value, fa.FragmentedArray): - carry_registers = np.asarray( - [next(arg_it) for _ in arg_value.registers.flat], dtype=object - ) - carry_registers = carry_registers.reshape(arg_value.registers.shape) - carry = fa.FragmentedArray( - _registers=carry_registers, - _layout=arg_value.layout, - _is_signed=arg_value.is_signed, - ) - recreated_carry.append(_fragmented_array_to_ir(carry, arg_type)) - else: - recreated_carry.append(next(arg_it)) - return recreated_carry + flat_init_args, args_template = _flatten_ir_values( + for_op.initArgs, in_layouts + ) new_for_op = scf.ForOp( for_op.lowerBound, for_op.upperBound, for_op.step, - lower_carry(for_op.initArgs), + flat_init_args, ) - with ir.InsertionPoint(new_for_op.body): - recreated_carry = recreate_carry(new_for_op.body.arguments[1:]) - ops_to_lower = [] - for op in for_op.body: - if op == yield_op: - continue - mgpu.private_operation_remove_from_parent(op) - mgpu.private_block_append_owned_operation(new_for_op.body, op) - ops_to_lower.append(op) - new_args = (new_for_op.induction_variable, *recreated_carry) - for old_carry, new_carry in zip(for_op.body.arguments, new_args, strict=True): - old_carry.replace_all_uses_with(new_carry) - - for op in ops_to_lower: - with ir.InsertionPoint(op): - ctx.lower_op(op) - with ir.InsertionPoint(new_for_op.body): - new_yield_operands = 
lower_carry(yield_op.operands) - yield_op.erase() - scf.yield_(new_yield_operands) - return recreate_carry(new_for_op.results) + _move_scf_block_to_block_with_flattened_arguments( + ctx, + for_op.body, + new_for_op.body, + scf.YieldOp, + args_template, + new_for_op.induction_variable, + ) + + return _unflatten_ir_values(new_for_op.results, args_template) + + +@_register_lowering(scf.WhileOp) +def _while_op_lowering_rule( + ctx: LoweringContext, while_op: scf.WhileOp +) -> MlirLoweringRuleResult: + if not inference_utils.should_have_layout(while_op): + return _traverse_op_lowering_rule(ctx, while_op) + + before_block = while_op.before.blocks[0] + after_block = while_op.after.blocks[0] + condition_op = before_block.operations[len(before_block.operations) - 1] + yield_op = after_block.operations[len(after_block.operations) - 1] + + in_layouts = inference_utils.in_layouts(while_op) + out_layouts = inference_utils.out_layouts(while_op) + + if in_layouts: + yield_layouts = inference_utils.in_layouts(yield_op) + if in_layouts != yield_layouts: + raise ValueError( + f"Input layouts {in_layouts} do not match yield layouts" + f" {yield_layouts}" + ) + + if out_layouts: + condition_layouts = inference_utils.in_layouts(condition_op) + if out_layouts != condition_layouts: + raise ValueError( + f"Output layouts {out_layouts} do not match condition layouts" + f" {condition_layouts}" + ) + + flat_inits, inits_template = _flatten_ir_values(while_op.inits, in_layouts) + result_types = _infer_flat_result_types(while_op, out_layouts) + new_while_op = scf.WhileOp(result_types, flat_inits) + + # Before block + init_types = [v.type for v in flat_inits] + new_before_block = new_while_op.before.blocks.append(*init_types) + results_template = _move_scf_block_to_block_with_flattened_arguments( + ctx, + before_block, + new_before_block, + scf.ConditionOp, + inits_template, + ) + + # After block + new_after_block = new_while_op.after.blocks.append(*result_types) + _move_scf_block_to_block_with_flattened_arguments( + ctx, + after_block, + new_after_block, + scf.YieldOp, + results_template, + ) + + return _unflatten_ir_values(new_while_op.results, results_template) + + +def _infer_flat_result_types( + op: ir.OpView, out_layouts: Sequence[ir.Attribute] +) -> Sequence[ir.Type]: + result_types: list[ir.Type] = [] + out_layouts_it = iter(out_layouts) + for r in op.results: + if not ir.VectorType.isinstance(r.type): + result_types.append(r.type) + continue + vec_type = ir.VectorType(r.type) + layout = layouts_lib.from_layout_attr(next(out_layouts_it)) + result_types.extend( + [layout.registers_element_type(vec_type.element_type)] + * math.prod(layout.registers_shape(tuple(vec_type.shape))) + ) + return result_types + + +@_register_lowering(scf.IfOp) +def _if_op_lowering_rule( + ctx: LoweringContext, if_op: scf.IfOp +) -> MlirLoweringRuleResult: + if not inference_utils.should_have_layout(if_op): + return _traverse_op_lowering_rule(ctx, if_op) + + raise NotImplementedError + + +@_register_lowering(scf.IndexSwitchOp) +def _index_switch_op_lowering_rule( + ctx: LoweringContext, switch_op: scf.IndexSwitchOp +) -> MlirLoweringRuleResult: + if not inference_utils.should_have_layout(switch_op): + return _traverse_op_lowering_rule(ctx, switch_op) + + out_layouts = inference_utils.out_layouts(switch_op) + new_switch_op = scf.IndexSwitchOp( + _infer_flat_result_types(switch_op, out_layouts), + switch_op.arg, + switch_op.cases, + len(switch_op.regions) - 1, + ) + + results_template: Sequence[_VectorTemplate | None] = [] + for region, 
new_region in zip( + switch_op.regions, new_switch_op.regions, strict=True + ): + [block] = region.blocks + new_block = new_region.blocks.append() + results_template = _move_scf_block_to_block_with_flattened_arguments( + ctx, block, new_block, scf.YieldOp, [] + ) + return _unflatten_ir_values(new_switch_op.results, results_template) @_register_lowering(func.FuncOp) @_register_lowering(gpu.LaunchOp) -@_register_lowering(scf.IfOp) # TODO(apaszke,bchetioui): Add a proper rule. -@_register_lowering(scf.IndexSwitchOp) # TODO(apaszke,bchetioui): Add a proper rule. def _traverse_op_lowering_rule( ctx: LoweringContext, op: ir.OpView ) -> MlirLoweringRuleResult: @@ -989,9 +1719,11 @@ def single_thread_predicates(module: ir.Module) -> tuple[ir.Value, ir.Value]: sub_op.operation.regions[0].blocks[0] ): assert block_predicate is None - block_predicate = utils.single_thread_predicate(per_block=True) + block_predicate = utils.single_thread_predicate( + scope=utils.ThreadSubset.BLOCK + ) warpgroup_predicate = utils.single_thread_predicate( - per_block=False + scope=utils.ThreadSubset.WARPGROUP ) if block_predicate is None: @@ -1008,6 +1740,7 @@ def _should_lower(op: ir.OpView) -> bool: return ( op.OPERATION_NAME.startswith("mosaic_gpu.") # pytype: disable=attribute-error or inference_utils.should_have_layout(op) + or inference_utils.should_have_transforms(op) or any(bool(b) for r in op.regions for b in r) # Does it have subblocks? ) diff --git a/jax/experimental/mosaic/gpu/examples/BUILD b/jax/experimental/mosaic/gpu/examples/BUILD index fe1a7e9180ac..b24c38b34235 100644 --- a/jax/experimental/mosaic/gpu/examples/BUILD +++ b/jax/experimental/mosaic/gpu/examples/BUILD @@ -39,6 +39,15 @@ py_library( ], ) +py_library( + name = "matmul_blackwell", + srcs = ["matmul_blackwell.py"], + deps = [ + "//jax", + "//jax:mosaic_gpu", + ], +) + py_library( name = "flash_attention", srcs = ["flash_attention.py"], diff --git a/jax/experimental/mosaic/gpu/examples/__init__.py b/jax/experimental/mosaic/gpu/examples/__init__.py new file mode 100644 index 000000000000..3da0dd1fa3ca --- /dev/null +++ b/jax/experimental/mosaic/gpu/examples/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2025 The JAX Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== diff --git a/jax/experimental/mosaic/gpu/examples/flash_attention.py b/jax/experimental/mosaic/gpu/examples/flash_attention.py index dc59dda3a6e5..280efd513187 100644 --- a/jax/experimental/mosaic/gpu/examples/flash_attention.py +++ b/jax/experimental/mosaic/gpu/examples/flash_attention.py @@ -17,7 +17,6 @@ import dataclasses import enum import itertools -import warnings import jax from jax import random @@ -244,8 +243,8 @@ def kv_loop(kv_step, carry): perform_schedule_barrier() - # This is quite suprising, but it seems like warp shuffles cannot - # run simutaneously with the WGMMA. 
For that reason we include it as + # This is quite surprising, but it seems like warp shuffles cannot + # run simultaneously with the WGMMA. For that reason we include it as # part of the TensorCore critical section and not the ALU section. with ctx.named_region("Softmax reduction"): l_i += p.reduce(arith.addf, axis=1) @@ -299,7 +298,7 @@ def kv_loop(kv_step, carry): scf.yield_([]) with ir.InsertionPoint(if_compute.else_block): nvvm.setmaxregister(40, nvvm.SetMaxRegisterAction.decrease) - with single_thread(per_block=False): + with single_thread(scope=ThreadSubset.WARPGROUP): k_tr = (TileTransform(tiling), TransposeTransform((1, 0, 2, 3))) v_tr = TileTransform(tiling) kv_head_idx = arith.divui(q_head_idx, c(q_heads_per_kv_head)) @@ -310,7 +309,7 @@ def start_kv_copy(slot, kv_seq_base, smem, gmem, barrier, transform): gmem_slice=(kv_head_idx, ds(kv_seq_base, blocks.kv)), gmem_transform=transform, barrier=barrier, - uniform=False, + predicate=None, swizzle=128, ) def start_k_copy(slot, kv_seq_base): @@ -391,7 +390,7 @@ def only_wg(idx): kv_head_idx = arith.divui(q_head_idx, c(q_heads_per_kv_head)) def kv_copy_init(slot, kv_seq_base): - with single_thread(per_block=False): + with single_thread(ThreadSubset.WARPGROUP): txcount = 2 * blocks.kv * head_dim * bytewidth(f16) barriers[slot].arrive_expect_tx(txcount) k_tr = (TileTransform(tiling), TransposeTransform((1, 0, 2, 3))) @@ -404,7 +403,7 @@ def kv_copy_init(slot, kv_seq_base): gmem_transform=t, barrier=barriers[slot], arrive=False, - uniform=False, + predicate=None, swizzle=128, ) @@ -601,7 +600,7 @@ def ref(q, k, v): if __name__ == "__main__": if (not jtu.test_device_matches(["cuda"]) or not jtu.is_cuda_compute_capability_equal("9.0")): - warnings.warn( + print( "Mosaic GPU Flash Attention requires compute capability 9.0a to run, " "skipping.") exit(0) diff --git a/jax/experimental/mosaic/gpu/examples/matmul.py b/jax/experimental/mosaic/gpu/examples/matmul.py index a5dd29e0dc4d..5c8363fa8b27 100644 --- a/jax/experimental/mosaic/gpu/examples/matmul.py +++ b/jax/experimental/mosaic/gpu/examples/matmul.py @@ -206,7 +206,7 @@ def fetch(slot, ki): rhs_tma_tile_bytes = int(np.prod(block_tiling.kn) * rhs_elem_bytes) txcount = lhs_tma_tile_bytes + rhs_tma_tile_bytes common_copy_args = dict( - swizzle=swizzle, barrier=barrier, arrive=False, uniform=False, + swizzle=swizzle, barrier=barrier, arrive=False, predicate=None, ) with single_thread(): barrier.arrive_expect_tx(txcount) diff --git a/jax/experimental/mosaic/gpu/examples/matmul_blackwell.py b/jax/experimental/mosaic/gpu/examples/matmul_blackwell.py index 6af394d00138..ac5a8985ebff 100644 --- a/jax/experimental/mosaic/gpu/examples/matmul_blackwell.py +++ b/jax/experimental/mosaic/gpu/examples/matmul_blackwell.py @@ -15,6 +15,7 @@ """Matmul kernel for Blackwell.""" import itertools +import math import jax from jax._src.interpreters import mlir @@ -41,7 +42,8 @@ def bytecount(shape, dtype): def build_kernel( - m, n, k, + m, k, n, + dtype: jnp.dtype, tile_m: int = 128, tile_n: int = 128, grid_tile_m: int = 1, @@ -51,12 +53,15 @@ def build_kernel( i1 = ir.IntegerType.get_signless(1) i32 = ir.IntegerType.get_signless(32) index = ir.IndexType.get() + if jnp.dtype(dtype).itemsize != 2: + raise NotImplementedError(f"Only tested with 16-bit dtypes, but got {dtype}") + if tile_m != 128: + raise NotImplementedError(f"Only tile_m=128 supported, but got {tile_m}") swizzle = 128 - swizzle_elems = tile_k = swizzle // 2 + swizzle_elems = tile_k = 8 * swizzle // jnp.finfo(dtype).bits tiling = (8, swizzle_elems) 
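The tile size along K is now derived from the element width instead of being hard-coded for f16. A minimal host-side sketch of that arithmetic (illustrative only; dtypes other than 16-bit ones are rejected above):

```python
# Sketch of the swizzle/tile_k arithmetic used by build_kernel (illustrative).
import jax.numpy as jnp

swizzle = 128  # bytes per swizzle atom
for dtype in (jnp.float16, jnp.bfloat16):
    bits = jnp.finfo(dtype).bits
    swizzle_elems = tile_k = 8 * swizzle // bits
    assert swizzle_elems == 64  # a 128-byte swizzle atom holds 64 16-bit elements
```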
- in_dtype = jnp.float16 k_loop_iter = k // tile_k max_concurrent_steps = min(max_concurrent_steps, k_loop_iter) @@ -74,132 +79,187 @@ def build_kernel( raise ValueError(f"{n=} must be divisible by {tile_n=}") if k % tile_k != 0: raise ValueError(f"{k=} must be divisible by {tile_k=}") - if (m // tile_m) % grid_tile_m: + if (m // block_tile_m) % grid_tile_m: raise ValueError(f"{m=} // {tile_m=} must be divisible by {grid_tile_m=}") + # We intend this to be iterated in column-major order. + logical_grid = (grid_tile_m, n // tile_n, m // (block_tile_m * grid_tile_m)) + def kernel(ctx, a, b, d, smem): - ((a_smem, b_smem), d_smem), barriers, mma_done_barrier, acc = smem + ((a_smem, b_smem), d_smem), barriers, mma_done_barrier, tmem_done_barrier, acc = smem (ab_full_barriers, ab_empty_barriers) = barriers warp_idx = mgpu.warp_idx(sync=True) is_warp_leader = nvvm.elect_sync(i1) - is_leader_of = lambda i: arith.andi(arith.cmpi(arith.CmpIPredicate.eq, warp_idx, c(i, i32)), is_warp_leader) - is_leader_block = arith.cmpi(arith.CmpIPredicate.eq, ctx.cluster_idx(gpu.Dimension.x), c(0, index)) - - m_idx = arith.addi( - gpu.block_id(gpu.Dimension.x), - arith.muli(gpu.block_id(gpu.Dimension.z), c(grid_tile_m, index)), + is_leader_of = lambda i: arith.andi( + arith.cmpi(arith.CmpIPredicate.eq, warp_idx, c(i, i32)), is_warp_leader + ) + is_leader_block = arith.cmpi( + arith.CmpIPredicate.eq, ctx.cluster_idx(gpu.Dimension.x), c(0, index) + ) + is_store_warpgroup = arith.cmpi( + arith.CmpIPredicate.eq, mgpu.warpgroup_idx(sync=True), c(1, i32) ) - n_idx = gpu.block_id(gpu.Dimension.y) - block_m_start = arith.muli(m_idx, c(block_tile_m, index)) - # All blocks in the cluster share the same m_start -- align it! - m_start = arith.muli(arith.divui(block_m_start, c(tile_m, index)), c(tile_m, index)) - n_start = arith.muli(n_idx, c(tile_n,index)) + def compute_output(block_m_start, n_start, call_counter): + """Compute and store a single output tile. - with mgpu.when(is_leader_of(TMA_WARP)): - @mgpu.fori(c(k_loop_iter, index), None) - def _tma_body(ki, _): - slot = arith.remui(ki, c(max_concurrent_steps, index)) - # TODO(apaszke): Use a predicate instead of a conditional. - with mgpu.when(arith.cmpi(arith.CmpIPredicate.uge, ki, c(max_concurrent_steps, index))): - ab_empty_barriers[slot].wait() - full_barrier = ab_full_barriers[slot] - with mgpu.when(is_leader_block): - full_barrier.arrive_expect_tx( - bytecount((tile_m, tile_k), in_dtype) + bytecount((tile_n, tile_k), in_dtype) + call_counter should be 0 the first time this function is called and + incremented by 1 before each subsequent call. + """ + acc_slot = arith.remui(call_counter, c(2, index)) + acc_slice = acc.slice(slice(None), mgpu.ds(arith.muli(acc_slot, c(tile_n, index)), tile_n)) + # All blocks in the cluster share the same m_start -- align it! 
+ m_start = arith.muli(arith.divui(block_m_start, c(tile_m, index)), c(tile_m, index)) + with mgpu.when(is_leader_of(TMA_WARP)): + @mgpu.fori(c(k_loop_iter, index), None) + def _tma_body(ki, _): + slot = arith.remui(ki, c(max_concurrent_steps, index)) + isnt_warmup = arith.cmpi( + arith.CmpIPredicate.uge, ki, c(max_concurrent_steps, index) + ) + isnt_first_call = arith.cmpi( + arith.CmpIPredicate.ne, call_counter, c(0, index) + ) + with mgpu.when(arith.ori(isnt_first_call, isnt_warmup)): + ab_empty_barriers[slot].wait() + full_barrier = ab_full_barriers[slot] + with mgpu.when(is_leader_block): + full_barrier.arrive_expect_tx( + bytecount((tile_m, tile_k), dtype) + bytecount((tile_n, tile_k), dtype) + ) + k_start = arith.muli(ki, c(tile_k, index)) + common_args = dict( + swizzle=swizzle, + barrier=full_barrier, + arrive=False, + predicate=None, + collective=gpu.Dimension.x, + partitioned=0, # Non-contracting dim is always 0. + ) + ctx.async_copy( + src_ref=a, + dst_ref=mgpu.memref_slice(a_smem, slot), + gmem_slice=(ds(m_start, tile_m), ds(k_start, tile_k)), + gmem_transform=mgpu.TileTransform(tiling), + **common_args, + ) + ctx.async_copy( + src_ref=b, + dst_ref=mgpu.memref_slice(b_smem, slot), + gmem_slice=(ds(n_start, tile_n), ds(k_start, tile_k)), + gmem_transform=mgpu.TileTransform(tiling), + **common_args, ) - k_start = arith.muli(ki, c(tile_k, index)) - common_args = dict( - swizzle=swizzle, - barrier=full_barrier, - arrive=False, - uniform=False, - collective=gpu.Dimension.x, - partitioned=0, # Non-contracting dim is always 0. - ) - ctx.async_copy( - src_ref=a, - dst_ref=mgpu.memref_slice(a_smem, slot), - gmem_slice=(ds(m_start, tile_m), ds(k_start, tile_k)), - gmem_transform=mgpu.TileTransform(tiling), - **common_args, - ) - ctx.async_copy( - src_ref=b, - dst_ref=mgpu.memref_slice(b_smem, slot), - gmem_slice=(ds(n_start, tile_n), ds(k_start, tile_k)), - gmem_transform=mgpu.TileTransform(tiling), - **common_args, - ) - with mgpu.when(arith.andi(is_leader_of(MMA_WARP), is_leader_block)): - @mgpu.fori(c(k_loop_iter, index), arith.constant(i1, 0)) - def _mma_body(ki, accumulate): - slot = arith.remui(ki, c(max_concurrent_steps, index)) - ab_full_barriers[slot].wait() - tcgen05.mma( - acc, - mgpu.memref_slice(a_smem, slot), - mgpu.memref_transpose(mgpu.memref_slice(b_smem, slot), (1, 0, 3, 2)), - a_swizzle=swizzle, - b_swizzle=swizzle, - accumulate=accumulate, - collective=collective, - ) - accumulate = arith.constant(i1, 1) - is_last_iter = arith.cmpi( - arith.CmpIPredicate.eq, ki, c(k_loop_iter - 1, index) - ) - barrier_ptr = arith.select( - is_last_iter, - mma_done_barrier.get_ptr(), - ab_empty_barriers[slot].get_ptr(), - ) - tcgen05.commit_arrive(barrier_ptr, collective=collective, ctx=ctx) - return accumulate + # We wait in all blocks in the cluster to avoid double arrival errors. 
+ reuses_tmem = arith.cmpi(arith.CmpIPredicate.uge, call_counter, c(2, index)) + with mgpu.when(arith.andi(is_leader_of(MMA_WARP), reuses_tmem)): + tmem_done_barrier[acc_slot].wait(for_tensor_core=True) + with mgpu.when(arith.andi(is_leader_of(MMA_WARP), is_leader_block)): + @mgpu.fori(c(k_loop_iter, index), arith.constant(i1, 0)) + def _mma_body(ki, accumulate): + slot = arith.remui(ki, c(max_concurrent_steps, index)) + ab_full_barriers[slot].wait() + tcgen05.mma( + acc_slice, + mgpu.memref_slice(a_smem, slot), + mgpu.memref_transpose(mgpu.memref_slice(b_smem, slot), (1, 0, 3, 2)), + a_swizzle=swizzle, + b_swizzle=swizzle, + accumulate=accumulate, + collective=collective, + ) + accumulate = arith.constant(i1, 1) + tcgen05.commit_arrive(ab_empty_barriers[slot], collective=collective, ctx=ctx) + is_last_iter = arith.cmpi( + arith.CmpIPredicate.eq, ki, c(k_loop_iter - 1, index) + ) + with mgpu.when(is_last_iter): + tcgen05.commit_arrive(mma_done_barrier[acc_slot], collective=collective, ctx=ctx) + return accumulate - gpu.barrier() - mma_done_barrier.wait(for_tensor_core=True) + with mgpu.when(is_store_warpgroup): + mma_done_barrier[acc_slot].wait(for_tensor_core=True) + final_acc = acc_slice.load().astype(mlir.dtype_to_ir_type(jnp.dtype(dtype))) + assert tile_n % epilogue_tile_n == 0 + for ni in range(tile_n // epilogue_tile_n): + n_slice = ds(ni * epilogue_tile_n, epilogue_tile_n) + final_acc[:, n_slice].store_tiled(d_smem, swizzle=128) + # We store the first tile before arriving to reduce register pressure. + mgpu.commit_shared() + store_n_start = arith.addi(n_start, c(ni * epilogue_tile_n, index)) + ctx.async_copy( + src_ref=d_smem, + dst_ref=d, + gmem_slice=( + ds(block_m_start, block_tile_m), + ds(store_n_start, epilogue_tile_n), + ), + gmem_transform=mgpu.TileTransform((128, swizzle_elems)), + swizzle=128, + ) + ctx.await_async_copy(0, await_read_only=True) + tmem_done_barrier[acc_slot].arrive(for_tensor_core=True) - acc[:].astype(ir.F16Type.get()).store_tiled(d_smem, swizzle=128) - mgpu.commit_shared() - ctx.async_copy( - src_ref=d_smem, - dst_ref=d, - gmem_slice=(ds(block_m_start, block_tile_m), ds(n_start, tile_n)), - gmem_transform=mgpu.TileTransform((128, swizzle_elems)), - swizzle=swizzle, + # We statically assign the tiles to SMs. + logical_grid_size = math.prod(logical_grid) + sm_id = gpu.block_id(gpu.Dimension.x) + extra_step = arith.cmpi( + arith.CmpIPredicate.slt, sm_id, c(logical_grid_size % num_sms, index) + ) # Some SMs do an extra step when grid size isn't divisible by SM count. 
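The static tile-to-SM assignment is easier to follow as a host-side analogue; the sketch below mirrors the `extra_step`/`mn_steps`/`global_mn_step` arithmetic with made-up sizes:

```python
# Host-side analogue of the persistent-kernel tile assignment (illustrative sizes).
num_sms = 148
logical_grid_size = 1000  # stands in for math.prod(logical_grid)

def tiles_for_sm(sm_id: int) -> list[int]:
    extra_step = sm_id < logical_grid_size % num_sms
    mn_steps = logical_grid_size // num_sms + int(extra_step)
    return [sm_id + step * num_sms for step in range(mn_steps)]

# Every logical tile is visited exactly once across all SMs.
all_tiles = sorted(t for sm in range(num_sms) for t in tiles_for_sm(sm))
assert all_tiles == list(range(logical_grid_size))
```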
+ mn_steps = arith.addi( + mgpu.c(logical_grid_size // num_sms, index), + arith.index_castui(index, extra_step), ) - ctx.await_async_copy(0) + + @mgpu.fori(mn_steps, None) + def _mn_loop(local_mn_step, _): + global_mn_step = arith.addi( + sm_id, arith.muli(local_mn_step, mgpu.c(num_sms, index)) + ) + logical_idxs = [] + for dim_size in logical_grid: + logical_idxs.append(arith.remui(global_mn_step, mgpu.c(dim_size, index))) + global_mn_step = arith.divui(global_mn_step, mgpu.c(dim_size, index)) + lx, ly, lz = logical_idxs + m_idx = arith.addi(lx, arith.muli(lz, c(grid_tile_m, index))) + n_idx = ly + + block_m_start = arith.muli(m_idx, c(block_tile_m, index)) + n_start = arith.muli(n_idx, c(tile_n,index)) + compute_output(block_m_start, n_start, local_mn_step) compute_buffers = ( jax.ShapeDtypeStruct( mgpu.tile_shape((max_concurrent_steps, block_tile_m, tile_k), tiling), - jnp.float16), + dtype), jax.ShapeDtypeStruct( - mgpu.tile_shape((max_concurrent_steps, block_tile_n, tile_k), tiling), - jnp.float16), + mgpu.tile_shape((max_concurrent_steps, block_tile_n, tile_k), tiling), + dtype), ) + epilogue_tile_n = 64 epilogue_buffer = jax.ShapeDtypeStruct( - mgpu.tile_shape((block_tile_m, tile_n), (128, swizzle_elems)), - jnp.float16) - smem_buffers = mgpu.Union([compute_buffers, epilogue_buffer]) + mgpu.tile_shape((block_tile_m, epilogue_tile_n), (128, swizzle_elems)), + dtype) + smem_buffers = [compute_buffers, epilogue_buffer] smem = ( smem_buffers, [mgpu.Barrier(arrival_count=1, num_barriers=max_concurrent_steps)] * 2, - mgpu.Barrier(arrival_count=1), - mgpu.TMEM((128, tile_n), jnp.float32, collective=collective), + mgpu.Barrier(arrival_count=1, num_barriers=2), + mgpu.ClusterBarrier(collective_dims=(gpu.Dimension.x,), num_barriers=2), + mgpu.TMEM((128, 2 * tile_n), jnp.float32, collective=collective), ) + num_sms = 148 return mgpu.as_gpu_kernel( kernel, - (grid_tile_m, n // tile_n, m // (block_tile_m * grid_tile_m)), - (128, 1, 1), + (num_sms, 1, 1), # This is a persistent kernel. 
+ (2 * 128, 1, 1), ( - jax.ShapeDtypeStruct((m, k), jnp.float16), - jax.ShapeDtypeStruct((n, k), jnp.float16), + jax.ShapeDtypeStruct((m, k), dtype), + jax.ShapeDtypeStruct((n, k), dtype), ), - jax.ShapeDtypeStruct((m, n), jnp.float16), + jax.ShapeDtypeStruct((m, n), dtype), smem, cluster=(2 if collective else 1, 1, 1), ) @@ -213,7 +273,7 @@ def main(unused_argv): b = jr.normal(key=kb, shape=(n, k), dtype=jnp.float16) tile_m = (128,) - tile_n = (128, 256, 512) + tile_n = (128, 256) max_concurrent_steps = (2, 4, 5, 6) grid_tile_m = (1, 2, 4, 8, 16) collective = (False, True) @@ -230,13 +290,13 @@ def main(unused_argv): tile_n *= 2 if m < tile_m or n < tile_n: continue - if tile_n > 512: + if 2 * tile_n > 512: continue if (m // tile_m) % kwargs["grid_tile_m"]: continue try: with mlir.make_ir_context(), ir.Location.unknown(): - f = build_kernel(m, n, k, **kwargs) + f = build_kernel(m, k, n, jnp.float16, **kwargs) _, runtime = profiler.measure(f)(a, b) except ValueError as e: if "Mosaic GPU kernel exceeds available shared memory" not in str(e): @@ -251,7 +311,7 @@ def main(unused_argv): raise ValueError("No valid configuration found") with mlir.make_ir_context(), ir.Location.unknown(): - d, runtime = profiler.measure(build_kernel(m, n, k, **best_kwargs))(a, b) + d, runtime = profiler.measure(build_kernel(m, k, n, jnp.float16, **best_kwargs))(a, b) d_ref, ref_runtime = profiler.measure(jax.jit(lambda a, b: a @ b.T))(a, b) tflops = float(2 * k * m * n) / (runtime / 1e3) / 1e12 diff --git a/jax/experimental/mosaic/gpu/fragmented_array.py b/jax/experimental/mosaic/gpu/fragmented_array.py index 5daed8416589..1c61807c83ad 100644 --- a/jax/experimental/mosaic/gpu/fragmented_array.py +++ b/jax/experimental/mosaic/gpu/fragmented_array.py @@ -16,25 +16,25 @@ from __future__ import annotations +from collections.abc import Callable import dataclasses import functools +import itertools import math -from collections.abc import Callable -from typing import Iterable, Protocol, Sequence, TypeVar +from typing import Protocol, TypeVar +from collections.abc import Generator, Iterable, Sequence -import itertools import jax +import jax.experimental.mosaic.gpu as mgpu from jaxlib.mlir import ir from jaxlib.mlir.dialects import arith from jaxlib.mlir.dialects import gpu from jaxlib.mlir.dialects import llvm from jaxlib.mlir.dialects import math as mlir_math from jaxlib.mlir.dialects import memref -from jaxlib.mlir.dialects import nvvm from jaxlib.mlir.dialects import vector import numpy as np -import jax.experimental.mosaic.gpu as mgpu from . import utils # mypy: ignore-errors @@ -68,15 +68,15 @@ class Tiling: def __post_init__(self): if not self.tiles: return - tiled_rank = len(self.tiles[0]) + last_tile_rank = len(self.tiles[0]) for tile in self.tiles: - if len(tile) > tiled_rank: - raise ValueError("Only the first tile can refer to value dimensions") + if len(tile) > last_tile_rank: + raise ValueError("Tiles must have a decreasing rank") if not tile: raise ValueError("Tiles must not be empty") if any(d <= 0 for d in tile): raise ValueError(f"Tile shape must only have positive sizes, got: {self.tiles}") - tiled_rank += len(tile) + last_tile_rank = len(tile) def __str__(self): return f"Tiling({''.join(map(str, self.tiles))})" @@ -111,6 +111,43 @@ def fail(): shape = (*untiled_dims, *(d * t for d, t in zip(tiled_dims, tile))) return shape + def canonicalize(self) -> Tiling: + """Returns a canonicalized version of the tiling. 
+ + We define a tiling to be canonical if, at each step (except the first one, + which defines the base tile shape): + + 1. The tiling partitions at least one dimension in more than 1 tile. For + example, the tiling `(8, 8)(8, 8)` is not canonical, as applying it + yields a shape `(1, 1, 8, 8)`. We canonicalize it to `(8, 8)`, which + allows getting rid of the unnecessary `1` dimensions. + 2. The leading dimensions of each tile are not `1`. If canonicalizing a + tile in this way leads to an empty tile, then the tile is given shape + `(1,)`---which is still a meaningful (final) tile. For example, the + tiling `(8, 8)(1, 4)` is not canonical, as applying it yields a shape + `(8, 2, 1, 4)`. We canonicalize it to `(8, 8)(4,)`, which allows + getting rid of the unnecessary `1` dimension, and yields a shape + `(8, 2, 4)`. + """ + if len(self.tiles) <= 1: + return self + + shape = self.tiles[0] + new_tiling = [self.tiles[0]] + for tile in self.tiles[1:]: + for i, d in enumerate(tile): + if d != 1: + canonical_tile = tile[i:] + break + else: + canonical_tile = (1,) + tiled_dims = shape[-len(canonical_tile):] + if tiled_dims == canonical_tile: + continue + shape = canonical_tile + new_tiling.append(canonical_tile) + return Tiling(tuple(new_tiling)) + def tile_strides(self, strides: tuple[int, ...]) -> tuple[int, ...]: """Computes the strides of an array after tiling.""" for tile in self.tiles: @@ -118,6 +155,33 @@ def tile_strides(self, strides: tuple[int, ...]) -> tuple[int, ...]: strides = (*untiled, *(s * t for s, t in zip(tiled, tile)), *tiled) return strides + def tile_dimension(self, dim: int) -> tuple[bool, ...]: + """Result is True whenever the tiled dim originated from the given input dim.""" + tiling_rank = len(self.tiles[0]) + if dim < 0 or dim >= tiling_rank: + raise ValueError(f"Invalid dimension {dim} for tiling {self}") + strides = [1] * tiling_rank + strides[dim] = 0 + return tuple(s == 0 for s in self.tile_strides(tuple(strides))) + + def remove_dimension(self, dim: int) -> Tiling: + """Returns a tiling with the given dimension removed.""" + tiling_rank = len(self.tiles[0]) + if dim < 0 or dim >= tiling_rank: + raise ValueError(f"Invalid dimension {dim} for tiling {self}") + dim_in_tile = dim + tiles = [] + last_tile_rank = len(self.tiles[0]) + for t in self.tiles: + assert last_tile_rank >= len(t) + dim_in_tile -= last_tile_rank - len(t) + if dim_in_tile >= 0: + t = t[:dim_in_tile] + t[dim_in_tile + 1:] + if not t: # If this tile is empty, all other tiles will be empty too. + break + tiles.append(t) + return Tiling(tuple(tiles)) + def tile_nested_shape_strides( self, shape: tuple[tuple[int, ...], ...], @@ -202,6 +266,11 @@ def enumerate_negative(elems: Sequence[T]) -> Iterable[tuple[int, T]]: yield i - offset, e +@dataclasses.dataclass(frozen=True) +class Replicated: + times: int + + @dataclasses.dataclass(frozen=True) class TiledLayout: """A FragmentedArray layout derived from a tiling expression. @@ -247,27 +316,51 @@ class TiledLayout: by a single (logical) register. """ tiling: Tiling - warp_dim: int - lane_dims: tuple[int, ...] # major-to-minor + warp_dim: int | Replicated + lane_dims: tuple[int | Replicated, ...] # major-to-minor vector_dim: int + # Whether to enforce that the layout is canonical. Users of `TiledLayout` + # should not set this to `False`, but it is helpful to be able to construct + # non-canonical layouts as an intermediate state when implementing layout + # transformations. 
+ _check_canonical: dataclasses.InitVar[bool] = True - def __post_init__(self): + def __post_init__(self, _check_canonical: bool): if not self.tiling.tiles: raise ValueError("Tiling must have at least one tile") min_shape = self.tiling.tiles[0] min_tiled_shape = self.tiling.tile_shape(min_shape) - dims_set = {self.warp_dim, *self.lane_dims, self.vector_dim} - if len(dims_set) != len(self.lane_dims) + 2: + dims_set = {*self.partitioned_lane_dims, self.vector_dim} + if partitions_warp_dim := not isinstance(self.warp_dim, Replicated): + dims_set.add(self.warp_dim) + if len(dims_set) != len(self.partitioned_lane_dims) + 1 + partitions_warp_dim: raise ValueError for d in dims_set: if d >= 0: raise ValueError("All dimensions must be negative") if d < -(len(min_tiled_shape) - len(min_shape)): raise ValueError("Dimension out of range") - if min_tiled_shape[self.warp_dim] != WARPS_IN_WARPGROUP: - raise ValueError - if math.prod(min_tiled_shape[d] for d in self.lane_dims) != WARP_SIZE: + if isinstance(self.warp_dim, Replicated): + if self.warp_dim.times != WARPS_IN_WARPGROUP: + raise ValueError + elif min_tiled_shape[self.warp_dim] != WARPS_IN_WARPGROUP: raise ValueError + lane_dims_prod = math.prod( + d.times if isinstance(d, Replicated) else min_tiled_shape[d] + for d in self.lane_dims + ) + if lane_dims_prod != WARP_SIZE: + raise ValueError("The product of lane dims does not equal the warp size") + if _check_canonical: + canonical_layout = self.canonicalize() + if self != canonical_layout: + raise ValueError(f"{self} is not canonical.") + + @functools.cached_property + def partitioned_lane_dims(self) -> tuple[int, ...]: + return tuple( + d for d in self.lane_dims if not isinstance(d, Replicated) + ) def thread_idxs(self, shape: tuple[int, ...]) -> Iterable[tuple[ir.Value, ...]]: # We first find the linear index and then divide by the shape to @@ -319,11 +412,15 @@ def tiled_tiling_rank(self) -> int: def vector_length(self) -> int: return self.tiled_tiling_shape[self.vector_dim] + def registers_element_type(self, t: ir.Type) -> ir.Type: + return ir.VectorType.get((self.vector_length,), t) + def registers_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: """Returns the shape of the register array needed to represent an array of the given logical shape.""" tiled_shape = list(self.tiling.tile_shape(shape)) - tiled_shape[self.warp_dim] = 1 - for d in self.lane_dims: + if not isinstance(self.warp_dim, Replicated): + tiled_shape[self.warp_dim] = 1 + for d in self.partitioned_lane_dims: tiled_shape[d] = 1 tiled_shape[self.vector_dim] = 1 return tuple(tiled_shape) @@ -335,16 +432,20 @@ def shape_from_registers_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: """ tiled_tiling = self.tiled_tiling_shape shape = list(shape) - shape[self.warp_dim] = WARPS_IN_WARPGROUP - for d in self.lane_dims: + if not isinstance(self.warp_dim, Replicated): + shape[self.warp_dim] = WARPS_IN_WARPGROUP + for d in self.partitioned_lane_dims: shape[d] = tiled_tiling[d] shape[self.vector_dim] = tiled_tiling[self.vector_dim] return self.tiling.untile_shape(tuple(shape)) - def lane_indices(self) -> tuple[ir.Value, ...]: + def _full_lane_indices(self) -> tuple[ir.Value, ...]: i32 = ir.IntegerType.get_signless(32) tiled_shape = self.tiled_tiling_shape - lanes_shape = tuple(tiled_shape[d] for d in self.lane_dims) + lanes_shape = tuple( + d.times if isinstance(d, Replicated) else tiled_shape[d] + for d in self.lane_dims + ) assert math.prod(lanes_shape) == WARP_SIZE lane_strides = utils.get_contiguous_strides(lanes_shape) 
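The decomposition below is ordinary mixed-radix indexing over `lanes_shape`. A plain-Python equivalent, using the shape produced by `WGMMA_ROW_LAYOUT` (`lane_dims=(-2, Replicated(4))`, i.e. `lanes_shape == (8, 4)`):

```python
# Plain-Python equivalent of the lane-index decomposition (illustrative shape).
lanes_shape = (8, 4)   # 8 rows, each replicated across 4 lanes
lane_strides = (4, 1)  # contiguous strides of lanes_shape

def lane_indices(lane_idx: int) -> tuple[int, ...]:
    return tuple(lane_idx // s % d for s, d in zip(lane_strides, lanes_shape))

# Lanes 0-3 map to row 0, lanes 4-7 to row 1, and so on.
assert [lane_indices(i)[0] for i in range(8)] == [0, 0, 0, 0, 1, 1, 1, 1]
```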
lane_idx = arith.remui(utils.thread_idx(), c(WARP_SIZE, i32)) @@ -352,22 +453,121 @@ def lane_indices(self) -> tuple[ir.Value, ...]: arith.remui(arith.divui(lane_idx, c(stride, i32)), c(size, i32)) for stride, size in zip(lane_strides, lanes_shape) ) + return lane_indices + + def lane_indices(self) -> tuple[ir.Value, ...]: + i32 = ir.IntegerType.get_signless(32) + tiled_shape = self.tiled_tiling_shape + lane_indices = self._full_lane_indices() full_indices = [arith.constant(i32, 0)] * len(tiled_shape) for d, i in zip(self.lane_dims, lane_indices): + if isinstance(d, Replicated): + continue full_indices[d] = i return tuple(full_indices) def warp_indices(self) -> tuple[ir.Value, ...]: i32 = ir.IntegerType.get_signless(32) tiled_shape_rank = len(self.tiled_tiling_shape) - warp_idx = arith.remui( - arith.divui(utils.thread_idx(), c(WARP_SIZE, i32)), - c(WARPS_IN_WARPGROUP, i32), - ) indices = [arith.constant(i32, 0)] * tiled_shape_rank - indices[self.warp_dim] = warp_idx + if not isinstance(self.warp_dim, Replicated): + warp_idx = arith.remui( + arith.divui(utils.thread_idx(), c(WARP_SIZE, i32)), + c(WARPS_IN_WARPGROUP, i32), + ) + indices[self.warp_dim] = warp_idx return tuple(indices) + def remove_dimension(self, dim: int) -> TiledLayout: + if dim < 0 or dim >= len(self.tiling.tiles[0]): + raise ValueError(f"Dimension {dim} is out of range for {self.tiling}") + new_tiling = self.tiling.remove_dimension(dim) + tiled_shape = self.tiled_tiling_shape + removed_dim = self.tiling.tile_dimension(dim) + dim_offsets = np.cumsum(removed_dim[::-1])[::-1].tolist() + if removed_dim[self.vector_dim]: + new_tiling = Tiling((*new_tiling.tiles, (1,))) + new_vector_dim = -1 + dim_offsets = [o - 1 for o in dim_offsets] # We inserted an extra dim. + else: + new_vector_dim = self.vector_dim + dim_offsets[self.vector_dim] + def replace_tiled_dim(d: int | Replicated, size: int): + if isinstance(d, Replicated): + return d + elif removed_dim[d]: + return Replicated(size) + else: + return d + dim_offsets[d] + return TiledLayout( + new_tiling, + replace_tiled_dim(self.warp_dim, WARPS_IN_WARPGROUP), + tuple( + d if isinstance(d, Replicated) else replace_tiled_dim(d, tiled_shape[d]) + for d in self.lane_dims + ), + new_vector_dim, + _check_canonical=False, + ).canonicalize() + + def reduce(self, axes: Sequence[int]) -> TiledLayout: + reduced_layout = self + for a in sorted(axes, reverse=True): + reduced_layout = reduced_layout.remove_dimension(a) + return reduced_layout + + def canonicalize(self) -> TiledLayout: + """Returns a version of this layout where tiling is canonical.""" + canonical_tiling = self.tiling.canonicalize() + if canonical_tiling == self.tiling: + return self + + s = self.base_tile_shape + canonical_tiled_tiling_shape = canonical_tiling.tile_shape(s)[len(s):] + offset = len(canonical_tiled_tiling_shape) - 1 + + rev_removed_dims = [] + # Iterate starting from the end in order to eliminate leading dimensions, + # whenever possible. For instance, say we have + # + # shape=(4, 32, 1, 1, 1, 1, 1) + # warp_dim=-7, + # lane_dims=(-6,) + # vector_dim=-1 + # + # and we want to canonicalize this to + # + # shape=(4, 32, 1) + # warp_dim=-3, + # lane_dims=(-2,) + # vector_dim=-1. + # + # After the loop below, we end up with + # + # rev_removed_dims=[False, True, True, True, True, False, False] + # + # which will yield offsets `4` for `warp_dim`, `4` for `lane_dims[0]`, and + # `0` for `vector_dim`. 
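The worked example above can be checked directly with a few lines of NumPy:

```python
# Runnable version of the worked example in the comment above.
import numpy as np

tiled_shape = (4, 32, 1, 1, 1, 1, 1)
canonical_shape = (4, 32, 1)

offset = len(canonical_shape) - 1
rev_removed_dims = []
for d in reversed(tiled_shape):
    if offset >= 0 and d == canonical_shape[offset]:
        rev_removed_dims.append(False)
        offset -= 1
    else:
        rev_removed_dims.append(True)
assert rev_removed_dims == [False, True, True, True, True, False, False]

dim_offsets = np.cumsum(rev_removed_dims)[::-1].tolist()
assert -7 + dim_offsets[-7] == -3  # warp_dim
assert -6 + dim_offsets[-6] == -2  # lane_dims[0]
assert -1 + dim_offsets[-1] == -1  # vector_dim
```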
+ for d in reversed(self.tiled_tiling_shape): + if offset >= 0 and d == canonical_tiled_tiling_shape[offset]: + rev_removed_dims.append(False) + offset -= 1 + else: + rev_removed_dims.append(True) + assert offset == -1 + + dim_offsets = np.cumsum(rev_removed_dims)[::-1].tolist() + + def replace_tiled_dim(d: int | Replicated): + return d if isinstance(d, Replicated) else d + dim_offsets[d] + + return TiledLayout( + canonical_tiling, + replace_tiled_dim(self.warp_dim), + tuple(replace_tiled_dim(d) for d in self.lane_dims), + replace_tiled_dim(self.vector_dim), + _check_canonical=True + ) + def _tiled_wgmma_layout(shape: tuple[int, ...]): """Returns the tiled layout relevant for WGMMA operations. @@ -382,28 +582,6 @@ def _tiled_wgmma_layout(shape: tuple[int, ...]): return WGMMA_LAYOUT -@dataclasses.dataclass(frozen=True) -class WGMMARowFragLayout: - """[m] matrix, where m % 64 == 0.""" - - def thread_idxs(self, shape): - index = ir.IndexType.get() - assert len(shape) == 1 - assert shape[0] % 64 == 0 - tid = arith.index_cast(ir.IndexType.get(), mgpu.thread_idx()) - tid_wg = arith.remui(tid, c(WARPGROUP_SIZE, index)) - warp_idx = arith.divui(tid_wg, c(32, index)) - lane_id = arith.remui(tid_wg, c(32, index)) - row_base = arith.addi( - arith.divui(lane_id, c(4, index)), arith.muli(warp_idx, c(16, index)) - ) - - for row_group in range(0, shape[0], 64): - for row_subgroup in (0, 8): - row = arith.addi(row_base, c(row_group + row_subgroup, index)) - yield (row,) - - @dataclasses.dataclass(frozen=True) class WGSplatFragLayout: """A fragmented array where all the values are equal represented as a register per thread. @@ -435,6 +613,14 @@ def can_broadcast_to(self, shape) -> bool: """ return all(dim1 == dim2 or dim1 == 1 for dim1, dim2 in zip(self.shape[::-1], shape[::-1])) + def registers_element_type(self, t: ir.Type) -> ir.Type: + return t + + def registers_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: + """Returns the shape of the register array needed to represent an array of the given logical shape.""" + del shape # Unused. 
+ return () + def thread_idxs(self, shape): assert shape == self.shape raise NotImplementedError @@ -469,6 +655,15 @@ def from_shaped_type(cls, shaped_ty: ir.Type): shape=tuple(shaped_ty.shape), vec_size=min(8 // bw, max_vec_size) ) + def registers_element_type(self, t: ir.Type) -> ir.Type: + return ir.VectorType.get((self.vec_size,), t) + + def registers_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: + """Returns the shape of the register array needed to represent an array of the given logical shape.""" + if shape != self.shape: + raise ValueError(f"Shape {shape} is not compatible with {self}") + return (math.prod(self.shape) // (WARPGROUP_SIZE * self.vec_size),) + def thread_idxs(self, shape): assert shape == self.shape index = ir.IndexType.get() @@ -497,14 +692,25 @@ def linear_thread_idxs(self): yield arith.addi(off, c(i * WARPGROUP_SIZE * self.vec_size, tidx.type)) -FragmentedLayout = WGSplatFragLayout | WGStridedFragLayout | WGMMARowFragLayout | TiledLayout +FragmentedLayout = WGSplatFragLayout | WGStridedFragLayout | TiledLayout -WGMMA_ROW_LAYOUT = WGMMARowFragLayout() +WGMMA_COL_LAYOUT = TiledLayout( + Tiling(((8,), (2,))), + warp_dim=Replicated(4), + lane_dims=(Replicated(8), -2), + vector_dim=-1, +) +WGMMA_ROW_LAYOUT = TiledLayout( + Tiling(((64,), (16,), (8,), (1,))), + warp_dim=-4, + lane_dims=(-2, Replicated(4)), + vector_dim=-1, +) # The tiled layout is equivalent to one described here in PTX documentation: # https://docs.nvidia.com/cuda/parallel-thread-execution/#wgmma-64n16-d -# In this layout, we partition the 64x8 tiles over 4 warpgroups into 16x8 tiles. +# In this layout, we partition the 64x8 tiles over 4 warps into 16x8 tiles. # Then, we further split the 16x8 tiles into 8x8 submatrices which are the unit # of data that is split across a warp. Since 8*8 = 64, but a warp has only 32 # threads, we vectorize pairs of elements along columns. @@ -516,9 +722,9 @@ def linear_thread_idxs(self): # 12 12 13 13 14 14 15 15 # ... WGMMA_LAYOUT = TiledLayout( - Tiling(((64, 8), (16, 8), (8, 8), (1, 2))), - warp_dim=-8, - lane_dims=(-4, -3), + Tiling(((64, 8), (16, 8), (8, 8), (2,))), + warp_dim=-7, + lane_dims=(-3, -2), vector_dim=-1, ) # This tiled layout is similar to the WGMMA layout, only the unit at which we @@ -570,7 +776,7 @@ def linear_thread_idxs(self): # ... # # You can see that we have taken 2x2 submatrices from the above layout and -# transposed them. The assigment of lanes to elements is such that in both +# transposed them. The assignment of lanes to elements is such that in both # layouts the same two lanes map to a single 2x2 submatrix, making the transpose # very cheap (one shuffle and permute suffices to change between those layouts). WGMMA_TRANSPOSED_LAYOUT = TiledLayout( @@ -580,6 +786,30 @@ def linear_thread_idxs(self): vector_dim=-2, ) +# Like WGMMA_LAYOUT, only each warp holds a 32xN strip instead of 16xN. +TCGEN05_LAYOUT = TiledLayout( + Tiling(((128, 8), (32, 8), (8, 8), (2,))), + warp_dim=-7, + lane_dims=(-3, -2), + vector_dim=-1, +) +# TCGEN05_ROW_LAYOUT is to TCGEN05_LAYOUT as WGMMA_ROW_LAYOUT is to +# WGMMA_LAYOUT. +TCGEN05_ROW_LAYOUT = TiledLayout( + Tiling(tiles=((128,), (32,), (8,), (1,))), + warp_dim=-4, + lane_dims=(-2, Replicated(times=4)), + vector_dim=-1, +) +# TCGEN05_COL_LAYOUT is to TCGEN05_LAYOUT as WGMMA_COL_LAYOUT is to +# WGMMA_LAYOUT. 
+TCGEN05_COL_LAYOUT = TiledLayout( + Tiling(tiles=((8,), (2,))), + warp_dim=Replicated(times=4), + lane_dims=(Replicated(times=8), -2), + vector_dim=-1, +) + @jax.tree_util.register_pytree_node_class @dataclasses.dataclass(init=False, eq=False, frozen=True, slots=True) class FragmentedArray: @@ -612,12 +842,6 @@ def __init__( ) match self.layout: - # Registers are [m_tiles, 2 rows] in WGMMA_ROW layout - # Each element is a dtype scalar - case WGMMARowFragLayout(): - if _registers.ndim != 2 or _registers.shape[-1] != 2: - raise ValueError(f"Invalid register array shape: {_registers.shape}") - # Registers are flat case WGStridedFragLayout(shape): [reg_size] = ir.VectorType(_registers.flat[0].type).shape @@ -626,8 +850,8 @@ def __init__( != math.prod(_registers.shape) * WARPGROUP_SIZE * reg_size ): raise ValueError( - "Invalid register array shape: math.prod({_registers.shape}) *" - " {WARPGROUP_SIZE} * {reg_size}, want: math.prod({shape})" + f"Invalid register array shape: math.prod({_registers.shape}) *" + f" {WARPGROUP_SIZE} * {reg_size}, want: math.prod({shape})" ) # Just a single register @@ -674,59 +898,19 @@ def load_strided( vecs = [vector.load(vec_ty, ref, vec_idx) for vec_idx in layout.thread_idxs(shape)] return cls(_registers=np.array(vecs), _layout=layout, _is_signed=is_signed) - @classmethod - def load_wgmma_row( - cls, - ref: ir.Value, - *, - is_signed: bool | None = None, - ): - if not ir.MemRefType.isinstance(ref.type): - raise TypeError(ref.type) - - ref_ty = ir.MemRefType(ref.type) - shape = tuple(ref_ty.shape) - if len(shape) != 1: - raise ValueError("WGMMARowFragLayout requires a 1D shape") - if shape[0] % 64: - raise ValueError( - "WGMMARowFragLayout requires shape[0] to be a multiple of 64" - ) - - layout = WGMMARowFragLayout() - registers = [memref.load(ref, [idx]) for (idx,) in layout.thread_idxs(shape)] - registers = np.array(registers).reshape(-1, 2) - return cls(_registers=registers, _layout=layout, _is_signed=is_signed) - - @classmethod def splat(cls, value, shape, layout=None, *, is_signed: bool | None = None): layout = layout or WGSplatFragLayout(shape) match layout: - case WGMMARowFragLayout(): - if len(shape) != 1: - raise ValueError("WGMMARowFragLayout requires a 1D shape") - if shape[0] % 64: - raise ValueError( - "WGMMARowFragLayout requires shape[0] to be a multiple of 64" - ) - reg_shape = (shape[0] // 64, 2) - case WGStridedFragLayout(vec_size=vec_size): - assert shape == layout.shape - elems = np.prod(shape) - reg_shape = (elems // (WARPGROUP_SIZE * vec_size),) - value = vector.splat(ir.VectorType.get((vec_size,), value.type), value) case WGSplatFragLayout(): - assert shape == layout.shape - reg_shape = () - case TiledLayout(): - value = vector.splat(ir.VectorType.get((layout.vector_length,), value.type), value) - reg_shape = layout.registers_shape(shape) + pass + case WGStridedFragLayout() | TiledLayout(): + value = vector.splat(layout.registers_element_type(value.type), value) case _: raise NotImplementedError(layout) return cls( - _registers=np.full(reg_shape, value, dtype=object), + _registers=np.full(layout.registers_shape(shape), value, dtype=object), _layout=layout, _is_signed=is_signed, ) @@ -734,9 +918,6 @@ def splat(cls, value, shape, layout=None, *, is_signed: bool | None = None): @property def shape(self): match self.layout: - case WGMMARowFragLayout(): - row_tiles = self.registers.shape[0] - return (row_tiles * 64,) case WGStridedFragLayout(shape): return shape case WGSplatFragLayout(shape=shape): @@ -752,7 +933,7 @@ def mlir_dtype(self): 
match self.layout: case WGStridedFragLayout() | TiledLayout(): return ir.VectorType(reg_ty).element_type - case WGMMARowFragLayout() | WGSplatFragLayout(): + case WGSplatFragLayout(): return reg_ty case _: raise NotImplementedError @@ -1173,7 +1354,7 @@ def _compare(self, other, *, f_pred, si_pred, ui_pred): if ir.FloatType.isinstance(self.mlir_dtype): pred = functools.partial(arith.cmpf, f_pred) elif ir.IntegerType.isinstance(self.mlir_dtype): - if ir.IntegerType(self.mlir_dtype).is_signed: + if self.is_signed: pred = functools.partial(arith.cmpi, si_pred) else: pred = functools.partial(arith.cmpi, ui_pred) @@ -1326,37 +1507,40 @@ def bitcast(self, elt: ir.Type, *, output_is_signed: bool | None = None): ) def __getitem__(self, idx): - if self.layout != WGMMA_LAYOUT: - raise NotImplementedError("Only WGMMA layouts support slicing") + if not isinstance(self.layout, TiledLayout): + raise NotImplementedError("Only arrays with tiled layouts can be sliced") base_idx, slice_shape, is_squeezed = utils.parse_indices(idx, self.shape) + if any(isinstance(idx, ir.Value) for idx in base_idx): + raise ValueError("Only static slicing allowed") if any(is_squeezed): raise NotImplementedError("Only slicing implemented") - if ( - base_idx[0] % 64 - or slice_shape[0] % 64 - or base_idx[1] % 8 - or slice_shape[1] % 8 + base_tile_shape = self.layout.base_tile_shape + if len(base_tile_shape) != len(self.shape): + raise NotImplementedError("Tiling has different rank than array") + if any( + b % t or l % t + for b, l, t in zip(base_idx, slice_shape, base_tile_shape, strict=True) ): raise NotImplementedError("Only tile aligned slicing supported") - base_idx[0] //= 64 - slice_shape[0] //= 64 - base_idx[1] //= 8 - slice_shape[1] //= 8 - new_regs = self.registers[ - base_idx[0] : base_idx[0] + slice_shape[0], - base_idx[1] : base_idx[1] + slice_shape[1], - ] + register_slices = tuple( + slice(b // t, (b + l) // t) + for b, l, t in zip(base_idx, slice_shape, base_tile_shape, strict=True) + ) + new_regs = self.registers[register_slices] return FragmentedArray( _registers=new_regs, _layout=self.layout, _is_signed=self.is_signed ) # TODO(apaszke): Support JAX dtypes here as well? def astype(self, new_dtype: ir.Type, *, is_signed: bool | None = None): + index = ir.IndexType.get() i4 = ir.IntegerType.get_signless(4) i8 = ir.IntegerType.get_signless(8) i16 = ir.IntegerType.get_signless(16) i32 = ir.IntegerType.get_signless(32) bf16 = ir.BF16Type.get() + f32 = ir.F32Type.get() + f8e4m3fn = ir.Float8E4M3FNType.get() cur_dtype = self.mlir_dtype if cur_dtype == new_dtype: @@ -1374,6 +1558,98 @@ def astype(self, new_dtype: ir.Type, *, is_signed: bool | None = None): "Register bitwidth in target type must be divisible by 8, got" f" {new_reg_bitwidth}" ) + if cur_dtype == i4 and new_dtype == f8e4m3fn: + # The algorithm here is taken from CUTLASS's `NumericArrayConverter` + # specialization for int4 -> f8e4m3, available at + # https://github.com/NVIDIA/cutlass/blob/5c6bca04414e06ce74458ab0a2018e2b8272701c/include/cutlass/numeric_conversion.h#L4982. + # Each call to the function below will upcast 4 contiguous nibbles of + # the input 32-bit register, and whether to select the 4 low nibbles or + # the 4 high nibbles is determined by the `part` argument. 
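The lookup tables in the helper below pack four f8e4m3fn encodings per 32-bit word. A quick host-side check of the first entry, assuming `ml_dtypes` is available (it is a JAX dependency); the remaining entries follow the same pattern:

```python
# Check that 0x44403800 is [0, 1, 2, 3] packed as little-endian f8e4m3fn bytes.
import ml_dtypes
import numpy as np

encoded = np.array([0, 1, 2, 3], dtype=ml_dtypes.float8_e4m3fn).view(np.uint8)
assert int.from_bytes(encoded.tobytes(), "little") == 0x44403800
```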
+ def upcast_to_f8e4m3fn(reg: ir.Value, part: int): + lut = [ + 0x44403800, # [0, 1, 2, 3] encoded as f8e4m3fn + 0x4E4C4A48, # [4, 5, 6, 7] encoded as f8e4m3fn + 0xCACCCED0, # [-8, -7, -6, -5] encoded as f8e4m3fn + 0xB8C0C4C8, # [-4, -3, -2, -1] encoded as f8e4m3fn + ] + + sign = arith.shrui(arith.andi(reg, c(0x88888888, i32)), c(1, i32)) + # Ignore the sign when indexing into the LUT. + lut_idx = arith.andi(reg, c(0x77777777, i32)) + + assert 0 <= part < 2 + if part == 1: + lut_idx = arith.shrui(lut_idx, c(16, i32)) + sign = arith.shrui(sign, c(16, i32)) + + prmt_sign_pattern = arith.ori(sign, c(0x32103210, i32)) + return llvm.inline_asm( + i32, + [lut_idx, prmt_sign_pattern], + f""" + {{ + .reg .b32 pos_f8s, neg_f8s; + prmt.b32 pos_f8s, {lut[0]}, {lut[1]}, $1; + prmt.b32 neg_f8s, {lut[2]}, {lut[3]}, $1; + prmt.b32 $0, pos_f8s, neg_f8s, $2; + }} + """, + "=r,r,r", + ) + new_registers = np.empty_like(self.registers) + + def packed_registers() -> Generator[tuple[list[index], ir.Value]]: + """Tries to pack registers into groups of 16 bits if vector_len < 4.""" + generator = np.ndenumerate(self.registers) + indices = [] + regs = [] + while True: + try: + for _ in range(max(4 // vector_len, 1)): + idx, reg = next(generator) + indices.append(idx) + regs.append(reg) + yield indices, utils.vector_concat(regs) + regs.clear() + indices.clear() + except StopIteration: + break + if regs: + yield indices, utils.vector_concat(regs) + + for indices, reg in packed_registers(): + group_size = ir.VectorType(reg.type).shape[0] + assert group_size % vector_len == 0 + int_ty = ir.IntegerType.get_signless(group_size * 4) + reg_as_i32 = utils.bitcast(reg, int_ty) + if int_ty != i32: + reg_as_i32 = arith.extsi(i32, reg_as_i32) + out_i32_regs = [ + upcast_to_f8e4m3fn(reg_as_i32, part=part) + for part in range(max(group_size // 4, 1)) + ] + out_vec_int = utils.vector_concat([ + vector.splat(ir.VectorType.get((1,), i32), out_i32_reg) + for out_i32_reg in out_i32_regs + ]) + out_vector_len = len(out_i32_regs) * 4 + # Bitcast to i8 first to allow slicing as necessary, since LLVM chokes + # on f8 types. + out_vec = utils.bitcast( + out_vec_int, ir.VectorType.get((out_vector_len,), i8) + ) + offset = 0 + for idx in indices: + sliced_out_vec = utils.vector_slice( + out_vec, slice(offset, offset + vector_len) + ) + new_registers[idx] = utils.bitcast( + sliced_out_vec, ir.VectorType.get((vector_len,), f8e4m3fn) + ) + offset += vector_len + return FragmentedArray( + _registers=new_registers, _layout=self.layout, _is_signed=None + ) if cur_dtype == i4 and self.is_signed and new_dtype == bf16: new_registers = np.empty_like(self.registers) out_vec_ty = ir.VectorType.get((vector_len,), new_dtype) @@ -1490,6 +1766,31 @@ def upcast_to_bf16(reg, high): return FragmentedArray( _registers=new_registers, _layout=self.layout, _is_signed=is_signed ) + # TODO(bchetioui): handle conversions to/from other float8 types. + if cur_dtype in {bf16, f32} and new_dtype == f8e4m3fn: + if vector_len != 2: + raise NotImplementedError(vector_len) + new_registers = np.empty_like(self.registers) + empty_vec_16 = llvm.mlir_undef(ir.VectorType.get((1,), i16)) + for idx, reg in np.ndenumerate(self.registers): + e0 = vector.extractelement(reg, position=c(0, index)) + e1 = vector.extractelement(reg, position=c(1, index)) + # TODO(bchetioui): can we do faster than this? 
+ if cur_dtype == bf16: + e0 = arith.extf(f32, e0) + e1 = arith.extf(f32, e1) + new_reg_16 = llvm.inline_asm( + i16, + [e1, e0], + "cvt.rn.satfinite.e4m3x2.f32 $0, $1, $2;", + "=h,f,f", + ) + new_registers[idx] = vector.bitcast( + ir.VectorType.get((2,), f8e4m3fn), + llvm.insertelement(empty_vec_16, new_reg_16, c(0, i32))) + return FragmentedArray( + _registers=new_registers, _layout=self.layout, _is_signed=is_signed + ) # Generic path. from_float = ir.FloatType.isinstance(cur_dtype) to_float = ir.FloatType.isinstance(new_dtype) @@ -1514,7 +1815,7 @@ def upcast_to_bf16(reg, high): case WGStridedFragLayout() | TiledLayout(): shape = ir.VectorType(self.registers.flat[0].type).shape upcast_ty = ir.VectorType.get(shape, larger_ty) - case WGMMARowFragLayout() | WGSplatFragLayout(): + case WGSplatFragLayout(): upcast_ty = larger_ty case _: raise NotImplementedError(f"Unsupported layout {self.layout}") @@ -1539,7 +1840,7 @@ def upcast_to_bf16(reg, high): case WGStridedFragLayout() | TiledLayout(): shape = ir.VectorType(self.registers.flat[0].type).shape new_reg_ty = ir.VectorType.get(shape, new_dtype) - case WGMMARowFragLayout() | WGSplatFragLayout(): + case WGSplatFragLayout(): new_reg_ty = new_dtype case _: raise NotImplementedError(f"Unsupported layout {self.layout}") @@ -1549,74 +1850,26 @@ def upcast_to_bf16(reg, high): _registers=new_registers, _layout=self.layout, _is_signed=is_signed ) - # NOTE: scratch can be reused immediately once this function returns. - def reduce_sum(self, scratch: ir.Value | None = None): - if isinstance(self.layout, WGSplatFragLayout): - [reg] = self.registers.flat - if ir.FloatType.isinstance(self.mlir_dtype): - op = mulf - elif ir.IntegerType.isinstance(self.mlir_dtype): - op = arith.muli - else: - raise NotImplementedError(self.mlir_dtype) - return FragmentedArray.splat( - op(reg, utils.c(math.prod(self.shape), self.mlir_dtype)), - (), - is_signed=self.is_signed, - ) - - if not isinstance(self.layout, WGStridedFragLayout): - raise NotImplementedError(f"Unsupported layout {self.layout}") - - if scratch is None: - raise ValueError("scratch must be provided") - - if ir.FloatType.isinstance(self.mlir_dtype): - op = addf - elif ir.IntegerType.isinstance(self.mlir_dtype): - op = arith.addi - else: - raise NotImplementedError(self.mlir_dtype) - - result = c(0, self.mlir_dtype) - for reg in self.registers: - result = op( - result, - vector.reduction(self.mlir_dtype, vector.CombiningKind.ADD, reg), - ) - scratch_ty = ir.MemRefType(scratch.type) - if scratch_ty.element_type != self.mlir_dtype or scratch_ty.shape != [4]: - raise ValueError(f"Expected shape={(4,)}, {self.mlir_dtype} (got {scratch_ty})") - - index = ir.IndexType.get() - warp_result = utils.warp_tree_reduce(result, op, 32) - warp_id = arith.divui(gpu.thread_id(gpu.Dimension.x), c(32, index)) - memref.store(warp_result, scratch, [warp_id]) - utils.warpgroup_barrier() - zero_index = c(0, index) - with mgpu.single_thread(per_block=False): - scratch_vec = vector.load( - ir.VectorType.get((4,), self.mlir_dtype), - scratch, - [zero_index], - ) - scratch_sum = vector.reduction( - self.mlir_dtype, vector.CombiningKind.ADD, scratch_vec - ) - memref.store(scratch_sum, scratch, [zero_index]) - utils.warpgroup_barrier() - result = memref.load(scratch, [zero_index]) - utils.warpgroup_barrier() # Make sure everyone is done using scratch. 
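# --- Editor's illustrative aside (not part of the patch) ---------------------
# Standalone sketch that cross-checks the int4 -> f8e4m3fn lookup-table
# constants used in upcast_to_f8e4m3fn above. It assumes the `ml_dtypes`
# package (which provides the float8_e4m3fn NumPy scalar type) is available;
# `pack_f8e4m3fn` is a hypothetical helper name used only for this sketch.
import numpy as np
import ml_dtypes

def pack_f8e4m3fn(values):
  """Packs four small ints into one little-endian u32 of f8e4m3fn encodings."""
  encoded = np.array(values, dtype=ml_dtypes.float8_e4m3fn)
  return int.from_bytes(encoded.tobytes(), "little")

assert pack_f8e4m3fn([0, 1, 2, 3]) == 0x44403800
assert pack_f8e4m3fn([4, 5, 6, 7]) == 0x4E4C4A48
assert pack_f8e4m3fn([-8, -7, -6, -5]) == 0xCACCCED0
assert pack_f8e4m3fn([-4, -3, -2, -1]) == 0xB8C0C4C8
# -----------------------------------------------------------------------------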
- return FragmentedArray.splat(result, (), is_signed=self.is_signed) - - def reduce(self, op: str | Callable[[ir.Value, ir.Value], ir.Value], axis): + def reduce( + self, + op: str | Callable[[ir.Value, ir.Value], ir.Value], + axis: int | Sequence[int, ...], + scratch: ir.Value | None = None, + ): + i32 = ir.IntegerType.get_signless(32) + if isinstance(axis, int): + axis = (axis,) + splat_op = None if isinstance(op, str): match op: case "add": + reduced_elems = math.prod(self.shape[a] for a in axis) if ir.FloatType.isinstance(self.mlir_dtype): op = addf + splat_op = lambda x: arith.mulf(x, c(reduced_elems, x.type)) elif ir.IntegerType.isinstance(self.mlir_dtype): op = arith.addi + splat_op = lambda x: arith.muli(x, c(reduced_elems, x.type)) else: raise NotImplementedError(self.mlir_dtype) case "max": @@ -1628,52 +1881,174 @@ def reduce(self, op: str | Callable[[ir.Value, ir.Value], ir.Value], axis): op = arith.maxsi if self.is_signed else arith.maxui else: raise NotImplementedError(self.mlir_dtype) + splat_op = lambda x: x case _: raise ValueError(f"Unrecognized reduction operator: {op}") - if self.layout != WGMMA_LAYOUT: - raise NotImplementedError(self.layout) - if axis != 1: + match self.layout: + case WGStridedFragLayout(shape=_, vec_size=vec_size): + if set(axis) != set(range(len(self.shape))): + raise NotImplementedError( + "Warpgroup strided layout only support reductions along all axes" + ) + # We reinterpret the data as a tiled layout. We're reducing it all anyway. + layout = TiledLayout( + tiling=Tiling(((128 * vec_size,), (32 * vec_size,), (vec_size,))), + warp_dim=-3, + lane_dims=(-2,), + vector_dim=-1, + ) + return FragmentedArray( + _registers=self.registers.reshape( + layout.registers_shape((math.prod(self.shape),)) + ), + _layout=layout, + _is_signed=self.is_signed, + ).reduce(op, 0, scratch) + case WGSplatFragLayout(): + if splat_op is None: + raise NotImplementedError( + "Splat reductions only supported when the operator is a string" + ) + assert not self.registers.shape + return FragmentedArray( + _registers=np.asarray( + splat_op(self.registers.item()), dtype=object + ), + _layout=WGSplatFragLayout( + tuple(d for a, d in enumerate(self.shape) if a not in axis) + ), + _is_signed=self.is_signed, + ) + case TiledLayout(): + pass + case _: + raise NotImplementedError(self.layout) + if len(self.layout.base_tile_shape) != len(self.shape): raise NotImplementedError + if isinstance(axis, int): + axis = (axis,) + layout = self.layout + tiled_tiling_shape = layout.tiled_tiling_shape + reduced_dims = layout.tiling.tile_dimension(axis[0]) + for a in axis[1:]: + reduced_dims = [ + r or d for r, d in zip(reduced_dims, layout.tiling.tile_dimension(a), strict=True) + ] + regs_shape = self.registers.shape + reduced_shape = tuple( + d if r else 1 for r, d in zip(reduced_dims, regs_shape, strict=True) + ) + remaining_shape = tuple( + 1 if r else d for r, d in zip(reduced_dims, regs_shape) + ) + out_regs = np.empty(remaining_shape, dtype=object) index = ir.IndexType.get() - i32 = ir.IntegerType.get_signless(32) - row_tile_dim = self.registers.shape[0] - row_subtile_dim = self.registers.shape[4] - new_regs = np.empty((row_tile_dim, row_subtile_dim), dtype=object) - assert self.registers.shape[-1] == 1 - for row_tile, row_subtile in np.ndindex(new_regs.shape): - # Reduce the registers owned by the current thread over n tiles - reg_index = [0] * self.registers.ndim - reg_index[0] = row_tile - reg_index[4] = row_subtile - thread_result_vec = self.registers[tuple(reg_index)] - for n_tile in 
range(1, self.registers.shape[1]): - reg_index[1] = n_tile - thread_result_vec = op( - thread_result_vec, self.registers[tuple(reg_index)] + for out_idx in np.ndindex(remaining_shape): + out_reg = None + for red_idx in np.ndindex(reduced_shape): + src_idx = tuple(o + r for o, r in zip(out_idx, red_idx)) + if out_reg is None: + out_reg = self.registers[src_idx] + else: + out_reg = op(out_reg, self.registers[src_idx]) + # Reduce within the vector dimension, if necessary. + if reduced_dims[layout.vector_dim]: + [vec_len] = ir.VectorType(out_reg.type).shape + scalar_out_reg = None + for i in range(vec_len): + scalar = vector.extractelement(out_reg, position=c(i, index)) + scalar_out_reg = ( + scalar if scalar_out_reg is None else op(scalar_out_reg, scalar) + ) + out_reg = vector.splat( + ir.VectorType.get((1,), out_reg.type.element_type), scalar_out_reg ) - - thread_result = vector.extractelement(thread_result_vec, position=c(0, index)) - for i in range(1, self.layout.vector_length): - thread_result = op( - thread_result, - vector.extractelement(thread_result_vec, position=c(i, index)), + # Reduce across warp lanes, if necessary (using warp shuffles). + if any(reduced_dims[d] for d in layout.partitioned_lane_dims): + if utils.bitwidth(out_reg.type) > 32: + raise NotImplementedError # Need to implement wide shfl_bfly. + lane_stride = 1 + for d in layout.lane_dims[::-1]: # Iterate minor-to-major + if isinstance(d, Replicated): + lane_stride *= d.times + elif not reduced_dims[d]: + lane_stride *= tiled_tiling_shape[d] + else: + assert lane_stride.bit_count() == 1 + reduction_size = tiled_tiling_shape[d] + while reduction_size > 1: + other_out_reg = utils.shfl_bfly(out_reg, lane_stride) + out_reg = op(out_reg, other_out_reg) + lane_stride *= 2 + reduction_size //= 2 + assert lane_stride == WARP_SIZE, lane_stride + # Reduce across warps in the warpgroup, if necessary. + if ( + not isinstance(layout.warp_dim, Replicated) + and reduced_dims[layout.warp_dim] + ): + if scratch is None: + raise ValueError( + "scratch must be provided when cross-warp reduction is required" + ) + [vec_len] = ir.VectorType(out_reg.type).shape + scratch_ty = ir.MemRefType(scratch.type) + if scratch_ty.rank != 1: + raise ValueError(f"Expected rank 1 for scratch, got {scratch_ty.rank}") + if scratch_ty.element_type != self.mlir_dtype: + raise ValueError( + f"Expected element type {self.mlir_dtype} for scratch, got" + f" {scratch_ty.element_type}" + ) + # TODO(apaszke): All lanes that replicate data can share the same scratch. + # For now we treat the complete reduction as a special case. + reduces_all_dims = set(axis) == set(range(len(self.shape))) + unique_lanes = 1 if reduces_all_dims else 32 + if scratch_ty.shape[0] < WARPS_IN_WARPGROUP * unique_lanes * vec_len: + raise ValueError("Insufficient scratch space for cross-warp reduction") + if scratch_ty.get_strides_and_offset()[0] != [1]: + raise ValueError("Expected scratch to be contiguous") + thread_idx = utils.thread_idx() + if reduces_all_dims: + lane_idx = c(0, i32) + else: + lane_idx = arith.remui(thread_idx, c(WARP_SIZE, i32)) + warp_idx = arith.divui( + arith.remui(thread_idx, c(WARPGROUP_SIZE, i32)), c(WARP_SIZE, i32) ) - - # Do a shuffle to reduce in groups of 4 consecutive threads. 
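# --- Editor's illustrative aside (not part of the patch) ---------------------
# NumPy model of the shfl.bfly loop above: every lane combines its value with
# the value held by lane (lane_id ^ stride), and the stride doubles until the
# whole group of `reduction_size` lanes (spaced `lane_stride` apart) has been
# folded together. `butterfly_reduce` is a hypothetical name; plain array
# indexing stands in for utils.shfl_bfly.
import numpy as np

def butterfly_reduce(lane_values, start_stride, group_size, op=np.add):
  lanes = np.arange(lane_values.size)
  stride, remaining = start_stride, group_size
  while remaining > 1:
    shuffled = lane_values[lanes ^ stride]  # stands in for utils.shfl_bfly
    lane_values = op(lane_values, shuffled)
    stride *= 2
    remaining //= 2
  return lane_values

vals = np.arange(32.0)
out = butterfly_reduce(vals, start_stride=4, group_size=8)
# Lane 0 now holds 0 + 4 + 8 + ... + 28 == 112; every lane holds the sum of its
# own group of 8 lanes spaced 4 apart, mirroring the in-register result above.
assert out[0] == sum(range(0, 32, 4))
# -----------------------------------------------------------------------------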
- result = thread_result - for i in (1, 2): - other_result = nvvm.shfl_sync( - result.type, - c(0xFFFFFFFF, i32), - result, - c(i, i32), - c(0x1F, i32), - nvvm.ShflKind.bfly, + spill_base = arith.muli(lane_idx, c(WARPS_IN_WARPGROUP, i32)) + store_idx = arith.index_cast(index, arith.addi(spill_base, warp_idx)) + vector.store( + out_reg, scratch, [arith.muli(store_idx, c(vec_len, index))] + ) + utils.warpgroup_barrier() + scratch_vec = vector.load( + ir.VectorType.get((WARPS_IN_WARPGROUP * vec_len,), self.mlir_dtype), + scratch, + [arith.muli(arith.index_cast(index, spill_base), c(vec_len, index))], ) - result = op(result, other_result) - new_regs[row_tile, row_subtile] = result + out_reg = None + for w in range(WARPS_IN_WARPGROUP): + part = utils.vector_slice(scratch_vec, slice(w * vec_len, (w + 1) * vec_len)) + out_reg = part if out_reg is None else op(out_reg, part) + utils.warpgroup_barrier() # Make sure everyone is done using scratch. + out_regs[out_idx] = out_reg + # Infer the output layout and reshape the registers accordingly. + reduced_logical_shape = list(self.shape) + for a in sorted(axis, reverse=True): + del reduced_logical_shape[a] + if not reduced_logical_shape: # Complete reduction results in a splat. + reduced_layout = WGSplatFragLayout(()) + assert out_regs.size == 1 + out_reg = out_regs.flat[0] + assert ir.VectorType(out_reg.type).shape == [1] + out_reg = vector.extractelement(out_reg, position=c(0, index)) + out_regs = np.asarray(out_reg, dtype=object) + else: + reduced_layout = layout.reduce(axis) + out_regs = out_regs.reshape(reduced_layout.registers_shape(reduced_logical_shape)) return FragmentedArray( - _registers=new_regs, _layout=WGMMA_ROW_LAYOUT, _is_signed=self.is_signed + _registers=out_regs, _layout=reduced_layout, _is_signed=self.is_signed ) def broadcast(self, shape): @@ -1709,22 +2084,47 @@ def reshape(self, shape): ) def broadcast_minor(self, n): - if self.layout != WGMMA_ROW_LAYOUT: - raise NotImplementedError + if self.layout == WGMMA_ROW_LAYOUT: + output_layout = WGMMA_LAYOUT + elif self.layout == TCGEN05_ROW_LAYOUT: + output_layout = TCGEN05_LAYOUT + else: + raise NotImplementedError(self.layout) if n % 8: raise ValueError("Number of columns must be divisible by 8") - reg_shape = WGMMA_LAYOUT.registers_shape((self.shape[0], n)) + reg_shape = output_layout.registers_shape((self.shape[0], n)) new_regs = np.empty(reg_shape, dtype=object) dtype = self.mlir_dtype - for (row_tile, row_subtile), reg in np.ndenumerate(self.registers): + i0 = arith.constant(ir.IndexType.get(), 0) + for (row_tile, _, row_subtile, *__), reg in np.ndenumerate(self.registers): tile = [slice(None)] * len(new_regs.shape) tile[0] = row_tile tile[4] = row_subtile new_regs[tuple(tile)] = vector.splat( - ir.VectorType.get((WGMMA_LAYOUT.vector_length,), dtype), reg + ir.VectorType.get((output_layout.vector_length,), dtype), + vector.extractelement(reg, position=i0), ) return FragmentedArray( - _registers=new_regs, _layout=WGMMA_LAYOUT, _is_signed=self.is_signed + _registers=new_regs, _layout=output_layout, _is_signed=self.is_signed + ) + + def broadcast_major(self, m): + if self.layout == WGMMA_COL_LAYOUT: + output_layout = WGMMA_LAYOUT + elif self.layout == TCGEN05_COL_LAYOUT: + output_layout = TCGEN05_LAYOUT + else: + raise NotImplementedError(self.layout) + if m % 64: + raise ValueError("Number of rows must be divisible by 64") + reg_shape = output_layout.registers_shape((m, self.shape[0])) + new_regs = np.empty(reg_shape, dtype=object) + for (col_tile, *_), reg in 
np.ndenumerate(self.registers): + tile = [slice(None)] * len(new_regs.shape) + tile[1] = col_tile + new_regs[tuple(tile)] = reg + return FragmentedArray( + _registers=new_regs, _layout=output_layout, _is_signed=self.is_signed ) def select(self, on_true, on_false): @@ -1739,6 +2139,21 @@ def select(self, on_true, on_false): lambda t, p, f: arith.select(p, t, f), self, on_false, ) + @classmethod + def build( + cls, + shape: tuple[int, ...], + layout: FragmentedLayout, + fn: Callable[..., ir.Value], # ir.Value varargs, one for each dim + *, + is_signed: bool | None = None, + ): + undef = llvm.mlir_undef(ir.IntegerType.get_signless(32)) + dummy = cls.splat(undef, shape, layout, is_signed=False) + return dummy.foreach( + lambda _, idx: fn(*idx), create_array=True, is_signed=is_signed + ) + def foreach( self, fn: Callable[[ir.Value, tuple[ir.Value, ...]], ir.Value | None], @@ -1749,17 +2164,33 @@ def foreach( """Call a function for each value and index.""" index = ir.IndexType.get() new_regs = None - if create_array: - new_regs = np.full_like(self.registers, llvm.mlir_undef(self.registers.flat[0].type)) + orig_fn = fn + def fn(*args): + nonlocal new_regs + result = orig_fn(*args) + old_reg_type = self.registers.flat[0].type + # Lazily create new_regs once we know the desired output type. + if create_array and new_regs is None: + if ir.VectorType.isinstance(old_reg_type): + new_reg_type = ir.VectorType.get(old_reg_type.shape, result.type) + else: + new_reg_type = result.type + new_regs = np.full_like(self.registers, llvm.mlir_undef(new_reg_type)) + return result for mlir_idx, reg_idx in zip(self.layout.thread_idxs(self.shape), np.ndindex(self.registers.shape), strict=True): reg = self.registers[reg_idx] assert len(mlir_idx) == len(self.shape), (mlir_idx, self.shape) - [elems] = ir.VectorType(reg.type).shape - for i in range(elems): - i = c(i, index) - val = fn(vector.extractelement(reg, position=i), (*mlir_idx[:-1], arith.addi(mlir_idx[-1], i))) + if ir.VectorType.isinstance(reg.type): + [elems] = ir.VectorType(reg.type).shape + for i in range(elems): + i = c(i, index) + val = fn(vector.extractelement(reg, position=i), (*mlir_idx[:-1], arith.addi(mlir_idx[-1], i))) + if create_array: + new_regs[reg_idx] = vector.insertelement(val, new_regs[reg_idx], position=i) + else: + val = fn(reg, mlir_idx) if create_array: - new_regs[reg_idx] = vector.insertelement(val, new_regs[reg_idx], position=i) + new_regs[reg_idx] = val if create_array: return FragmentedArray(_registers=new_regs, _layout=self.layout, _is_signed=is_signed) @@ -1771,37 +2202,59 @@ def _(val, idx): fmt_str = fmt.format(f"[{idx_fmt}]: {{}}") utils.debug_print(fmt_str, *idx, val, uniform=False) - def store_untiled(self, ref: ir.Value, *, vector_store: bool = True): + def store_untiled( + self, ref: ir.Value, *, swizzle: int = 16, optimized: bool = True + ): if not ir.MemRefType.isinstance(ref.type): raise ValueError(ref) - - def vs_unsupported(): - if not vector_store: - raise NotImplementedError( - f"Can't use non-vector stores with layout {self.layout}" - ) - match self.layout: - case WGMMARowFragLayout(): - self._store_untiled_wgmma_row(ref) case WGSplatFragLayout(): - vs_unsupported() + # All values are the same so swizzle does not affect anything here. 
self._store_untiled_splat(ref) case WGStridedFragLayout(): - vs_unsupported() + if swizzle != 16: + raise NotImplementedError self._store_untiled_wg_strided(ref) case TiledLayout(): - self._store_untiled_tiled(ref, vector_store=vector_store) + ref_shape = ir.MemRefType(ref.type).shape + ref = utils.memref_reshape(ref, (*(1 for _ in ref_shape), *ref_shape)) + self.store_tiled(ref, swizzle=swizzle, optimized=optimized) case _: raise NotImplementedError(self.layout) + @classmethod + def load_untiled( + cls, + ref: ir.Value, + *, + layout: TiledLayout, + swizzle: int = 16, + is_signed: bool | None = None, + optimized: bool = True, + ): + ref_shape = ir.MemRefType(ref.type).shape + ref = utils.memref_reshape(ref, (*(1 for _ in ref_shape), *ref_shape)) + return cls.load_tiled( + ref, swizzle=swizzle, is_signed=is_signed, layout=layout, optimized=optimized + ) + def _store_untiled_splat(self, ref: ir.Value): + if math.prod(self.shape) == 1: + c0 = c(0, ir.IndexType.get()) + memref.store( + self.registers.flat[0], ref, [c0] * len(ir.MemRefType(ref.type).shape) + ) + return + vec_size = 64 // mgpu.bitwidth(self.mlir_dtype) if np.prod(self.shape) < vec_size * WARPGROUP_SIZE: vec_size = 1 if np.prod(self.shape) % WARPGROUP_SIZE * vec_size: - raise ValueError(self.shape, WARPGROUP_SIZE, vec_size) + raise NotImplementedError( + "Arrays with the splat layout can only be stored when they have a" + f" single element or a multiple of {WARPGROUP_SIZE} elements" + ) fa = FragmentedArray.splat( self.registers.flat[0], @@ -1823,82 +2276,22 @@ def _store_untiled_wg_strided(self, ref: ir.Value): idxs = ([i] for i in self.layout.linear_thread_idxs()) except NotImplementedError: ref_ = ref - idxs = self.layout.thread_idxs() + idxs = self.layout.thread_idxs(self.shape) ref_shape = tuple(ref_ty.shape) if ref_shape != self.shape: raise ValueError((ref_shape, self.shape)) for idx, reg in zip(idxs, self.registers.flat): vector.store(reg, ref_, idx) - def _store_untiled_wgmma_row(self, ref: ir.Value): - """Stores an array with a WGMMA row layout.""" - assert self.layout == WGMMA_ROW_LAYOUT - index = ir.IndexType.get() - tid = arith.index_cast(ir.IndexType.get(), mgpu.thread_idx()) - - is_first = arith.cmpi( - arith.CmpIPredicate.eq, arith.remui(tid, c(4, index)), c(0, index) - ) - # Consecutive groups of 4 threads hold the same value in this layout, - # therefore we only need to transfer data from one of them. - with utils.when(is_first): - for (idx,), value in zip( - self.layout.thread_idxs(self.shape), self.registers.flatten() - ): - memref.store(value, ref, [idx]) - - def _store_untiled_tiled(self, ref: ir.Value, *, vector_store: bool = True): - """Stores an array with a tiled layout. 
Not optimized at the moment.""" - if utils.bitwidth(self.mlir_dtype) < 8: - raise NotImplementedError(f"Can't store sub-byte types ({self.mlir_dtype=})") - i32 = ir.IntegerType.get_signless(32) - layout = self.layout - assert isinstance(layout, TiledLayout) - ref_strides, _ = ir.MemRefType(ref.type).get_strides_and_offset() - if vector_store and ref_strides[layout.vector_dim] != 1: - raise NotImplementedError( - "Can't use vector stores with non-unit minormost stride" - ) - strides = layout.tiling.tile_strides(ref_strides) - smem_space = ir.Attribute.parse("#gpu.address_space") - ref_space = ir.MemRefType(ref.type).memory_space - memory_space = None - if str(ref_space) == str(smem_space): - memory_space = 3 - elif ref_space: - raise NotImplementedError(f"Unexpected ref space {ref_space}") - ptr = utils.memref_ptr(ref, memory_space=memory_space) - # Fold warp and lane offsets into the pointer once, since they are dynamic. - dyn_strides = [ - arith.constant(i32, s) for s in strides[-layout.tiled_tiling_rank :] - ] - warp_offset = utils.dyn_dot(layout.warp_indices(), dyn_strides) - lane_offset = utils.dyn_dot(layout.lane_indices(), dyn_strides) - dyn_offset = arith.addi(warp_offset, lane_offset) - ptr = utils.getelementptr(ptr, [dyn_offset], self.mlir_dtype) - # All warp tile offsets are static and can be fused into the store. - for tile_idx, reg in np.ndenumerate(self.registers): - if vector_store: - elems = [reg] - else: - index = ir.IndexType.get() - elems = [ - vector.extractelement(reg, position=c(i, index)) - for i in range(ir.VectorType(reg.type).shape[0]) - ] - for i, e in enumerate(elems): - tile_idx_local = list(tile_idx) - tile_idx_local[layout.vector_dim] += i - tile_idx_local = list(tile_idx_local) - lin_idx = sum(i * s for i, s in zip(tile_idx_local, strides, strict=True)) - reg_ptr = utils.getelementptr(ptr, [lin_idx], self.mlir_dtype) - llvm.store(e, reg_ptr) - - def store_tiled(self, ref, swizzle: int | None): + def store_tiled(self, ref, swizzle: int | None, optimized: bool = True): if not isinstance(self.layout, TiledLayout): raise NotImplementedError(self.layout) layout, shape = self.layout, self.shape - for get, _, ptr in self.transfer_tiled2(ref, swizzle, layout, shape): + # Note that the loop below will "race" for layouts that replicate data. + # However, in that case all of the racing writes store the same data, which + # is ok in the CUDA memory model. + stores = self.transfer_tiled2(ref, swizzle, layout, shape, optimized) + for get, _, ptr in stores: llvm.store(get(self.registers), ptr) @classmethod @@ -1909,6 +2302,7 @@ def load_tiled( *, is_signed: bool | None = None, layout: FragmentedLayout = WGMMA_LAYOUT, + optimized: bool = True, ): ref_ty = ir.MemRefType(ref.type) dtype = ref_ty.element_type @@ -1926,9 +2320,20 @@ def load_tiled( ), ) registers = np.full(layout.registers_shape(shape), zero, dtype=object) - reg_ty = ir.VectorType.get((layout.vector_length,), ref_ty.element_type) - for _, update, ptr in cls.transfer_tiled2(ref, swizzle, layout, shape): - update(registers, llvm.load(reg_ty, ptr)) + is_f8 = ir.FloatType.isinstance(dtype) and utils.bitwidth(dtype) == 8 + i8 = ir.IntegerType.get_signless(8) + reg_ty = ir.VectorType.get((layout.vector_length,), dtype) + # f8 data types are not handled by the LLVM dialect, so we need to + # transfer them as i8 and bitcast them back to f8. 
+ transfer_ty = ir.VectorType.get( + (layout.vector_length,), i8 if is_f8 else dtype + ) + loads = cls.transfer_tiled2(ref, swizzle, layout, shape, optimized) + for _, update, ptr in loads: + loaded_reg = llvm.load(transfer_ty, ptr) + if is_f8: + loaded_reg = vector.bitcast(reg_ty, loaded_reg) + update(registers, loaded_reg) case _: raise NotImplementedError(layout) return cls(_registers=registers, _layout=layout, _is_signed=is_signed) @@ -2023,6 +2428,7 @@ def transfer_tiled2( swizzle: int | None, layout: TiledLayout, shape: tuple[int, ...], + optimized: bool = True, ): """Generate a transfer schedule for a tiled layout. @@ -2053,10 +2459,12 @@ def transfer_tiled2( raise ValueError() nested_ref_shape = tuple( (ref_ty.shape[i], ref_ty.shape[i + ref_logical_rank]) + if ref_ty.shape[i + ref_logical_rank] != 1 else (ref_ty.shape[i],) for i in range(ref_logical_rank) ) nested_ref_strides = tuple( (ref_strides[i], ref_strides[i + ref_logical_rank]) + if ref_ty.shape[i + ref_logical_rank] != 1 else (ref_strides[i],) for i in range(ref_logical_rank) ) tiled_nested_shape, tiled_nested_strides = tiling.tile_nested_shape_strides( @@ -2074,12 +2482,18 @@ def transfer_tiled2( raise NotImplementedError("Memory and register tiling incompatible") tiled_shape = list(itertools.chain.from_iterable(tiled_nested_shape)) elem_tiled_strides = list(itertools.chain.from_iterable(tiled_nested_strides)) - elem_lane_strides = [elem_tiled_strides[d] for d in layout.lane_dims] - lane_shape = [tiled_shape[d] for d in layout.lane_dims] + lane_shape = [ + d.times if isinstance(d, Replicated) else tiled_shape[d] for d in layout.lane_dims + ] + lane_strides = [ + 0 if isinstance(d, Replicated) else elem_tiled_strides[d] for d in layout.lane_dims + ] if elem_tiled_strides[layout.vector_dim] != 1: raise ValueError("Stride of the vectorized dimension should be 1") - for d in (layout.warp_dim, *layout.lane_dims, layout.vector_dim): + for d in (*layout.partitioned_lane_dims, layout.vector_dim): tiled_shape[d] = 1 + if not isinstance(layout.warp_dim, Replicated): + tiled_shape[layout.warp_dim] = 1 element_bits = mgpu.bitwidth(dtype) if (layout.vector_length * element_bits) % 8 != 0: @@ -2112,12 +2526,29 @@ def transfer_tiled2( # Technically we should keep the vector_dim set to 1, but its shape is 1 # so it does not matter. transfer_tiled_strides = [s // layout.vector_length for s in elem_tiled_strides] - transfer_dtype = ir.VectorType.get((layout.vector_length,), dtype) + is_f8 = ir.FloatType.isinstance(dtype) and element_bits == 8 + i8 = ir.IntegerType.get_signless(8) + if is_f8: + transfer_dtype = ir.VectorType.get((layout.vector_length,), i8) + else: + transfer_dtype = ir.VectorType.get((layout.vector_length,), dtype) - plan = plan_tiled_transfer( - tiled_shape, elem_tiled_strides, lane_shape, elem_lane_strides, layout, - element_bits, swizzle - ) + if ref_ty.memory_space is None: + llvm_memory_space = None + elif ref_ty.memory_space == ir.Attribute.parse("#gpu.address_space"): + llvm_memory_space = 3 + else: + raise ValueError(f"Unsupported memory space: {ref_ty.memory_space}") + + if optimized: + if llvm_memory_space != 3: + raise NotImplementedError("Only optimized transfers to SMEM supported") + plan = plan_tiled_transfer( + tiled_shape, elem_tiled_strides, lane_shape, lane_strides, + layout, element_bits, swizzle + ) + else: + plan = TrivialTransferPlan() # All offsets are in units of transfer_dtype. 
dyn_tiled_strides = [ @@ -2126,9 +2557,7 @@ def transfer_tiled2( lane_offset = utils.dyn_dot(layout.lane_indices(), dyn_tiled_strides) warp_offset = utils.dyn_dot(layout.warp_indices(), dyn_tiled_strides) dyn_offset = arith.addi(lane_offset, warp_offset) - if ref_ty.memory_space != ir.Attribute.parse("#gpu.address_space"): - raise ValueError("Tiled stores can be performed into SMEM") - ptr = utils.memref_ptr(ref, memory_space=3) + ptr = utils.memref_ptr(ref, memory_space=llvm_memory_space) _as_consts = lambda consts: [c(const) for const in consts.tolist()] # This has bits set only for the offset bits that influence swizzling. swizzle_mask = swizzle_block_transfers - swizzle_tile_transfers @@ -2170,7 +2599,13 @@ def mem_idx_to_reg_idx(idx): return (*reg_tiled_idx, *idx[base_idx:]) reg_idxs = [mem_idx_to_reg_idx(idx) for idx in indices.tolist()] def get_register(regs, reg_idxs=reg_idxs): - return plan.select([regs[reg_idx] for reg_idx in reg_idxs]) + def cast_if_f8(x): + if is_f8: + return vector.bitcast(transfer_dtype, x) + return x + # f8 data types are not handled by the LLVM dialect, so we need to + # transfer them as i8 and bitcast them back to f8. + return plan.select([cast_if_f8(regs[reg_idx]) for reg_idx in reg_idxs]) def update_registers(regs, new, reg_idxs=reg_idxs): # TODO(apaszke): If the staggering forms a permutation with a small # cycle length, then instead of blending at each step we could construct @@ -2289,7 +2724,7 @@ def plan_tiled_transfer( raise ValueError( "Failed to prove that vector transfers don't cross swizzle tile" " boundaries. This check is incomplete, and does not guarantee that" - " this is a user error, but it might be." + str(transfer_alignment) + f" this is a user error, but it might be. {transfer_alignment=}" ) # 2. The transfer pattern does not cause bank conflicts. @@ -2307,9 +2742,14 @@ def plan_tiled_transfer( num_wavefronts = max(transfer_bytes // smem_bank_bytes, 1) wavefront_lanes = WARP_SIZE // num_wavefronts + lane_mask = np.full(lane_shape, False) + lane_mask[tuple(slice(0, 1) if s == 0 else slice(None) for s in lane_strides)] = True + wavefront_mask = lane_mask.reshape(num_wavefronts, wavefront_lanes) + lane_offsets_in_tile = np.dot(list(np.ndindex(*lane_shape)), lane_strides) def has_bank_conflicts(tile_idx_transform): - tile_idxs = np.unravel_index(np.arange(math.prod(tiled_shape)), tiled_shape) + num_tiles = math.prod(tiled_shape) + tile_idxs = np.unravel_index(np.arange(num_tiles), tiled_shape) tile_idxs = np.expand_dims(np.stack(tile_idxs, 1), 1) # [#tiles, 1, #dims] lane_tile_idx = tile_idx_transform(tile_idxs) # [#tiles, #lanes/1, #dims] assert lane_tile_idx.shape[1] in {1, WARP_SIZE} @@ -2320,10 +2760,17 @@ def has_bank_conflicts(tile_idx_transform): swizzle_bits = swizzle_groups * swizzle_tile_elems lane_banks = ((offsets ^ swizzle_bits) // elems_per_bank) % num_banks wavefront_banks = lane_banks.reshape(-1, num_wavefronts, wavefront_lanes) - # Order of threads within the wavefront is unimportant. - wavefront_banks = np.sort(wavefront_banks, axis=-1) - # There are no conflicts if each wavefront only contains unique banks. - return np.any(wavefront_banks[..., 1:] == wavefront_banks[..., :-1]) + # We step over wavefronts since they might have a different number of lanes. + wavefront_banks = wavefront_banks.swapaxes(0, 1) + for banks, mask in zip(wavefront_banks, wavefront_mask): + banks = banks[:, mask] + # Order of threads within the wavefront is unimportant. 
+ banks = np.sort(banks, axis=-1) + # There are no conflicts if each wavefront only contains unique banks. + repeats = np.any(banks[..., 1:] == banks[..., :-1]) + if repeats: + return True + return False # We don't need any special treatment if there are no conflicts when each lane # transfers the same tile at a time. @@ -2386,16 +2833,29 @@ def optimization_barrier(*arrays: mgpu.FragmentedArray): index = ir.IndexType.get() i32 = ir.IntegerType.get_signless(32) + def _repack(regs_it, reg_ty): + if not ir.VectorType.isinstance(reg_ty): + result_reg = next(regs_it) + assert result_reg.type == reg_ty + return result_reg + + num_i32_regs = utils.bitwidth(reg_ty) // 32 + i32_reg_ty = ir.VectorType.get((num_i32_regs,), i32) + reg = llvm.mlir_undef(i32_reg_ty) + for i_elem in range(num_i32_regs): + val = llvm.bitcast(i32, next(regs_it)) + reg = llvm.insertelement(reg, val, arith.constant(i32, i_elem)) + return vector.bitcast(reg_ty, reg) + regs = [] reg_dtypes = [] reg_constraints = [] - repack_fns = [] # We unpack each array into a flat list of registers, and prepare the # functions that invert the transform in repack_fns. for array in arrays: reg_ty = array.registers.flat[0].type dtype = array.mlir_dtype - if ir.F32Type.isinstance(dtype): + if ir.F32Type.isinstance(dtype) or dtype == i32: if ir.VectorType.isinstance(reg_ty): [vec_len] = ir.VectorType(reg_ty).shape array_regs = [ # pylint: disable=g-complex-comprehension @@ -2403,36 +2863,25 @@ def optimization_barrier(*arrays: mgpu.FragmentedArray): for reg in array.registers.flat for pos in range(vec_len) ] - def _repack(regs, reg_ty=reg_ty): - reg = llvm.mlir_undef(reg_ty) - [vec_len] = ir.VectorType(reg_ty).shape - for i_elem in range(vec_len): - reg = llvm.insertelement( - reg, next(regs), arith.constant(i32, i_elem) - ) - return reg - repack_fns.append(_repack) else: array_regs = list(array.registers.flat) - repack_fns.append(lambda regs: next(regs)) - reg_constraint = "f" + reg_constraint = "r" if dtype == i32 else "f" elif ir.BF16Type.isinstance(dtype) or ir.F16Type.isinstance(dtype): if not ir.VectorType.isinstance(reg_ty): raise NotImplementedError(array.mlir_dtype) [vec_len] = ir.VectorType(reg_ty).shape - if vec_len != 2: + if vec_len % 2: raise NotImplementedError(vec_len) - i32_reg_ty = ir.VectorType.get((1,), i32) + num_i32_regs = vec_len // 2 + i32_reg_ty = ir.VectorType.get((num_i32_regs,), i32) array_regs = [ vector.extractelement( - vector.bitcast(i32_reg_ty, reg), position=c(0, index) + vector.bitcast(i32_reg_ty, reg), position=c(i, index) ) + for i in range(num_i32_regs) for reg in array.registers.flat ] reg_constraint = "r" - def _repack(regs, reg_ty=reg_ty, i32_reg_ty=i32_reg_ty): - return vector.bitcast(reg_ty, vector.splat(i32_reg_ty, next(regs))) - repack_fns.append(_repack) else: raise NotImplementedError(array.mlir_dtype) regs += array_regs @@ -2446,28 +2895,39 @@ def _repack(regs, reg_ty=reg_ty, i32_reg_ty=i32_reg_ty): all_reg_constraints = ",".join( [*("=" + c for c in reg_constraints), *reg_constraints] ) - struct_ty = ir.Type.parse( - f"!llvm.struct<({','.join(map(str, reg_dtypes))})>" - ) - result_struct = llvm.inline_asm( - struct_ty, regs, ptx, all_reg_constraints, - asm_dialect=0, has_side_effects=True, - ) - regs = [ - llvm.extractvalue(dtype, result_struct, [i]) - for i, dtype in enumerate(reg_dtypes) - ] + + if len(reg_dtypes) == 1: + # The InlineAsm::verify() function doesn't allow a struct output when there + # is only one element (even though that seems to work for the case below). 
+ result_elem = llvm.inline_asm( + reg_dtypes[0], regs, ptx, all_reg_constraints, + asm_dialect=0, has_side_effects=True, + ) + regs = [result_elem] + else: + struct_ty = ir.Type.parse( + f"!llvm.struct<({','.join(map(str, reg_dtypes))})>" + ) + result_struct = llvm.inline_asm( + struct_ty, regs, ptx, all_reg_constraints, + asm_dialect=0, has_side_effects=True, + ) + regs = [ + llvm.extractvalue(dtype, result_struct, [i]) + for i, dtype in enumerate(reg_dtypes) + ] + i32 = ir.IntegerType.get_signless(32) results = [] regs_it = iter(regs) - for array, repack_fn in zip(arrays, repack_fns, strict=True): + for array in arrays: num_regs = array.registers.size reg_ty = array.registers.flat[0].type if ir.VectorType.isinstance(reg_ty): reg_ty = ir.VectorType(reg_ty) new_registers = np.empty((num_regs,), dtype=object) for i_vreg in range(num_regs): - reg = repack_fn(regs_it) + reg = _repack(regs_it, reg_ty) assert reg.type == reg_ty, (reg.type, reg_ty) new_registers[i_vreg] = reg results.append( diff --git a/jax/experimental/mosaic/gpu/inference_utils.py b/jax/experimental/mosaic/gpu/inference_utils.py index 6362626404c5..2641b76da8fc 100644 --- a/jax/experimental/mosaic/gpu/inference_utils.py +++ b/jax/experimental/mosaic/gpu/inference_utils.py @@ -18,11 +18,11 @@ import enum from functools import partial import itertools -from typing import cast +from typing import cast, Union from jax._src.lib.mlir import ir -MlirOperation = ir.Operation | ir.OpView +MlirOperation = Union[ir.Operation, ir.OpView] def in_layouts(op: MlirOperation) -> Sequence[ir.Attribute]: """Returns the in_layouts attribute of the given operation. @@ -95,6 +95,22 @@ def has_out_transforms_set(op: MlirOperation) -> bool: return "out_transforms" in op.attributes +def attr_element( + attr_name: str, op: MlirOperation, index: int +) -> ir.Attribute | None: + """Returns `op.attributes[attr_name][index]` if it exists, otherwise None. + + If `op.attributes[attr_name]` exists, then `index` must be a valid index into + the attribute array. 
+ """ + if attr_name not in op.attributes: + return None + attr = op.attributes[attr_name] + if not attr: + return None + return op.attributes[attr_name][index] # type: ignore + + def _in_attr_for_operand( op: MlirOperation, operand: ir.Value, @@ -109,9 +125,7 @@ def _in_attr_for_operand( operand_number = [o for o in op.operands if predicate(o)].index(operand) - if attr_name not in op.attributes: - return None - return op.attributes[attr_name][operand_number] # type: ignore + return attr_element(attr_name, op, operand_number) in_layout_for_operand = partial( @@ -121,6 +135,15 @@ def _in_attr_for_operand( _in_attr_for_operand, attr_name="in_transforms" ) +def should_have_transforms(op: ir.OpView) -> bool: + """Returns 'True' if the operation should be assigned in/out transforms.""" + return any( + map( + is_transformable_smem_memref, + itertools.chain(op.operands, op.results), + ) + ) + def is_transformable_smem_memref(v: ir.Value) -> bool: """Whether the value is a memref in SMEM on which transforms should be applied.""" barrier_ty = ir.Type.parse("!mosaic_gpu.barrier") diff --git a/jax/experimental/mosaic/gpu/launch_context.py b/jax/experimental/mosaic/gpu/launch_context.py index ce432f26dac2..852ac90c0d73 100644 --- a/jax/experimental/mosaic/gpu/launch_context.py +++ b/jax/experimental/mosaic/gpu/launch_context.py @@ -19,11 +19,13 @@ import enum import functools import math -from typing import Any +from typing import Any, Literal from jax._src.lib import mosaic_gpu_dialect as mgpu_dialect +from jax._src import lib as jaxlib from jaxlib.mlir import ir from jaxlib.mlir.dialects import arith +from jaxlib.mlir.dialects import builtin from jaxlib.mlir.dialects import func from jaxlib.mlir.dialects import gpu from jaxlib.mlir.dialects import llvm @@ -158,7 +160,7 @@ class TransposeTransform(MemRefTransform): def __post_init__(self): if len(self.permutation) != len(set(self.permutation)): - raise ValueError("Permutation must be a permutation") + raise ValueError("All elements of `permutation` must be unique") def apply(self, ref: ir.Value) -> ir.Value: return utils.memref_transpose(ref, self.permutation) @@ -228,21 +230,112 @@ def batch(self, leading_rank: int) -> MemRefTransform: OnDeviceProfiler = profiler.OnDeviceProfiler +ReductionOp = Literal["add", "min", "max", "inc", "dec", "and", "or", "xor"] + +class Scratch: + """Manages ops handling the GMEM scratch that contains the TMA descriptors. + + TMA descriptors are created on the host and then copied to GMEM. So there + needs to be some code on the host to allocate and initialize the TMA + descriptors. However, we only know what descriptors we need after we have + lowered the entire kernel. This class helps manage everything needed to + correctly allocate and initialize the scratch. + + To help reconcile the needs of kernels that use the dialect lowering with + those that use MGPU APIs directly, this class only creates the relevant ops + lazily. Eager creation would make them appear dead before dialect lowering + and MLIR's DCE would remove them. + + During the lowering, we collect information about how many bytes are needed + and also how each descriptor should be initialized on the host. At the end + of the lowering, the finalize_size() method should be called to add the + necessary code on the host to allocate and initialize all descriptors. 
+ """ + def __init__(self, gpu_launch_op: gpu.LaunchOp): + self.next_offset: int = 0 + self.host_init: list[Callable[[ir.Value], None]] = [] + self._alloc_op = None + self._load_op = None + self._scratch_ptr = None + + # Ideally, we would store the gpu.launch op directly. However, it gets + # invalidated by passes like "canonicalize". Thus we store the module and + # find the gpu.launch op from there when needed. + op = gpu_launch_op + while op.name != "builtin.module": + op = op.parent.opview + assert op is not None + self._module_op = op + + def _find_gpu_launch_op(self, block: ir.Block) -> ir.OpView | None: + for op in block: + if op.name == "gpu.launch": + return op + for region in op.regions: + for block in region: + child_op = self._find_gpu_launch_op(block) + if child_op is not None: + return child_op + return None + + def _create_ops_if_none(self): + if self._alloc_op is not None: + return + + gpu_launch_op = self._find_gpu_launch_op(self._module_op.body) + assert gpu_launch_op is not None + ptr_ty = ir.Type.parse("!llvm.ptr") + with ir.InsertionPoint(gpu_launch_op): + empty_arr_ty = ir.Type.parse("!llvm.array<0 x i8>") + i64 = ir.IntegerType.get_signless(64) + self._alloc_op = llvm.AllocaOp( + ptr_ty, c(1, i64), empty_arr_ty, + alignment=TMA_DESCRIPTOR_ALIGNMENT + ) + self._load_op = llvm.LoadOp(empty_arr_ty, self._alloc_op) + + with ir.InsertionPoint.at_block_begin(gpu_launch_op.body.blocks[0]): + self._scratch_ptr = builtin.unrealized_conversion_cast( + [ptr_ty], [self._load_op] + ) + + def device_ptr(self) -> ir.Value: + self._create_ops_if_none() + return self._scratch_ptr + + def finalize_size(self): + """ + Allocates and initializes the host buffer. This needs to be done after + lowering, i.e. after all TMA descriptors have been recorded. Only then we + know what the scratch contains. 
+ """ + if self.next_offset == 0: + return + assert self._alloc_op is not None + with ir.InsertionPoint(self._load_op): + gmem_scratch_bytes = self.next_offset + scratch_arr_ty = ir.Type.parse(f"!llvm.array<{gmem_scratch_bytes} x i8>") + self._alloc_op.elem_type = ir.TypeAttr.get(scratch_arr_ty) + self._load_op.result.set_type(scratch_arr_ty) + for init_callback in self.host_init: + init_callback(self._alloc_op.result) + + +class _DefaultPredicate: + pass + @dataclasses.dataclass() class LaunchContext: - launch_op: gpu.LaunchOp - gmem_scratch_ptr: ir.Value + module: ir.Module + scratch: Scratch cluster_size: tuple[int, int, int] profiler: OnDeviceProfiler | None = None - next_scratch_offset: int = 0 - host_scratch_init: list[Callable[[ir.Value], None]] = dataclasses.field( - default_factory=list, init=False - ) tma_descriptors: dict[ tuple[ir.Value, tuple[int, ...], int | None, tuple[MemRefTransform, ...]], ir.Value, ] = dataclasses.field(default_factory=dict, init=False) + is_device_collective: bool = False @contextlib.contextmanager def named_region(self, *args, **kwargs): @@ -286,32 +379,40 @@ def _alloc_scratch( ptr_ty = ir.Type.parse("!llvm.ptr") if alignment is None: alignment = size - if self.next_scratch_offset % alignment: + if self.scratch.next_offset % alignment: raise NotImplementedError # TODO(apaszke): Pad to match alignment - alloc_base = self.next_scratch_offset - self.next_scratch_offset += size + alloc_base = self.scratch.next_offset + self.scratch.next_offset += size def host_init_wrapped(host_ptr): host_init( - llvm.getelementptr(ptr_ty, host_ptr, [], [alloc_base], i8) + llvm.getelementptr(ptr_ty, host_ptr, [], [alloc_base], i8, llvm.GEPNoWrapFlags.none) ) - self.host_scratch_init.append(host_init_wrapped) + self.scratch.host_init.append(host_init_wrapped) # with ir.InsertionPoint(self.gmem_scratch_ptr.owner): # There is no way to create an insertion point after an operation... gep = llvm.GEPOp( - ptr_ty, self.gmem_scratch_ptr, [], [alloc_base], i8 + ptr_ty, self.scratch.device_ptr(), [], [alloc_base], i8, llvm.GEPNoWrapFlags.none ) - gep.move_after(self.gmem_scratch_ptr.owner) + gep.move_after(self.scratch.device_ptr().owner) return device_init(gep.result) def _get_tma_desc( self, gmem_ref, gmem_transform: tuple[MemRefTransform, ...], + gmem_peer_id: int | ir.Value | None, transformed_slice_shape: tuple[int, ...], swizzle: int | None, + reduction_op: Literal[ + "add","min","max","inc","dec","and","or","xor" + ] | None, ): - tma_desc_key = (gmem_ref, transformed_slice_shape, swizzle, gmem_transform) + # Using ir.Values in cache keys is a little sketchy, but I think it should + # be fine. Having it in the key will keep it alive, and if comparison and + # hashing is by identity then it should work out. + tma_desc_key = (gmem_ref, transformed_slice_shape, swizzle, gmem_transform, gmem_peer_id) if (tma_desc := self.tma_descriptors.get(tma_desc_key, None)) is None: + i32 = ir.IntegerType.get_signless(32) i64 = ir.IntegerType.get_signless(64) ptr_ty = ir.Type.parse("!llvm.ptr") def init_tma_desc(host_ptr): @@ -320,14 +421,41 @@ def init_tma_desc(host_ptr): ref = t.apply(ref) ref_ty = ir.MemRefType(ref.type) # TODO(apaszke): Use utils.memref_ptr to compute base_ptr + strides, _ = ref_ty.get_strides_and_offset() + if strides[-1] != 1: + raise ValueError( + "TMA requires the stride of the last dimension after" + " transforming the GMEM reference to be 1, but it is" + f" {strides[-1]}." 
+ ) + _, offset, *sizes_and_strides = memref.extract_strided_metadata(ref) aligned_ptr_idx = memref.extract_aligned_pointer_as_index(ref) as_i64 = lambda i: arith.index_cast(i64, i) alloc_ptr = llvm.inttoptr(ptr_ty, as_i64(aligned_ptr_idx)) llvm_dyn = -2147483648 # TODO(apaszke): Improve the MLIR bindings... base_ptr = llvm.getelementptr( - ptr_ty, alloc_ptr, [as_i64(offset)], [llvm_dyn], ref_ty.element_type, + ptr_ty, alloc_ptr, [as_i64(offset)], [llvm_dyn], ref_ty.element_type, llvm.GEPNoWrapFlags.none, ) + if gmem_peer_id is not None: + if not isinstance(gmem_peer_id, ir.Value): + peer_id = c(gmem_peer_id, i32) + else: + try: + # We try to reproduce the gmem_peer_id computation on the host. + peer_id = _recompute_peer_id(gmem_peer_id) + except ReplicationError as e: + raise ValueError( + "Failed to recompute the async_copy peer id on the host" + ) from e + self._ensure_nvshmem_decls() + base_ptr = llvm.call( + base_ptr.type, + [base_ptr, peer_id], + [], + [], + callee="nvshmem_ptr", + ) rank = ref_ty.rank assert rank * 2 == len(sizes_and_strides) swizzle_arg = ( @@ -337,10 +465,45 @@ def init_tma_desc(host_ptr): ) # TODO(apaszke): Better verification (e.g. slice is non-zero) # TODO(apaszke): We always know strides statically. + if jaxlib.version < (0, 5, 4): + dtype_or_bitwidth = c(utils.bitwidth(ref_ty.element_type), i64) + else: + if isinstance(ref_ty.element_type, ir.IntegerType): + if reduction_op is not None: + raise ValueError( + f"TMA with reduction_op={reduction_op} is not supported with Integers" + ) + bitwidth = utils.bitwidth_impl(ref_ty.element_type) + if bitwidth == 4: + tma_dtype = 0 + elif bitwidth == 8: + tma_dtype = 1 + elif bitwidth == 16: + tma_dtype = 2 + elif bitwidth == 32: + tma_dtype = 3 + elif bitwidth == 64: + tma_dtype = 4 + else: + raise ValueError(f"Unsupported integer bitwidth: {bitwidth}") + elif ir.F16Type.isinstance(ref_ty.element_type): + tma_dtype = 5 + elif ir.F32Type.isinstance(ref_ty.element_type): + tma_dtype = 6 + elif ir.BF16Type.isinstance(ref_ty.element_type): + tma_dtype = 7 + # We treat 8 bit floats as 8 bit integers + elif ir.Float8E5M2Type.isinstance(ref_ty.element_type): + tma_dtype = 1 + elif ir.Float8E4M3FNType.isinstance(ref_ty.element_type): + tma_dtype = 1 + else: + raise ValueError(f"unsupported TMA dtype {ref_ty.element_type}") + dtype_or_bitwidth = c(tma_dtype, i64) args = [ host_ptr, base_ptr, - c(utils.bitwidth(ref_ty.element_type), i64), + dtype_or_bitwidth, c(rank, i64), utils.pack_array([as_i64(i) for i in sizes_and_strides[:rank]]), utils.pack_array([as_i64(i) for i in sizes_and_strides[rank:]]), @@ -368,13 +531,15 @@ def async_copy( dst_ref, gmem_slice: Any = (), gmem_transform: MemRefTransform | tuple[MemRefTransform, ...] = (), + gmem_peer_id: int | ir.Value | None = None, barrier: utils.BarrierRef | None = None, swizzle: int | None = None, arrive: bool | None = None, - uniform: bool = True, collective: Sequence[gpu.Dimension] | gpu.Dimension | None = None, partitioned: int | None = None, - predicate: ir.Value | None = None, # Should select 0 or 1 threads from the WG. + # Should select 0 or 1 threads from the WG. + predicate: ir.Value | None | _DefaultPredicate = _DefaultPredicate(), + reduction_op: ReductionOp | None = None, ): """Initiates an async copy between GMEM and SMEM. 
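Editor's aside (not part of the patch): the dtype encoding chosen by init_tma_desc above (the jaxlib >= 0.5.4 branch) can be summarized by a small standalone table. Only the numeric codes and the "8-bit floats travel as 8-bit integers" convention come from the diff; the helper name and string keys below are made up for this sketch.

# Hypothetical standalone summary of the TMA dtype codes selected above.
_TMA_DTYPE_CODES = {
    "i4": 0, "i8": 1, "i16": 2, "i32": 3, "i64": 4,
    "f16": 5, "f32": 6, "bf16": 7,
    "f8e5m2": 1, "f8e4m3fn": 1,  # 8-bit floats are treated as 8-bit integers
}

def tma_dtype_code(name: str) -> int:
  return _TMA_DTYPE_CODES[name]

assert tma_dtype_code("f32") == 6
assert tma_dtype_code("f8e4m3fn") == tma_dtype_code("i8")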
@@ -415,8 +580,8 @@ def async_copy( f"Expected same element type, got {element_type} and" f" {dst_ref_ty.element_type}" ) - if predicate is not None and not uniform: - raise ValueError("Predicate can only be defined when uniform is True") + if isinstance(predicate, _DefaultPredicate): + predicate = utils.single_thread_predicate(utils.ThreadSubset.WARPGROUP) if not isinstance(gmem_transform, tuple): gmem_transform = (gmem_transform,) @@ -453,6 +618,19 @@ def async_copy( " multiple of 16 bytes" ) + if reduction_op is not None: + if not any( + t.isinstance(gmem_ref_ty.element_type) + for t in (ir.F32Type, ir.BF16Type, ir.F16Type) + ): + raise ValueError( + "TMA with reduction is only supported with f32, f16 and bf16" + ) + if reduction_op != "add": + raise ValueError( + "TMA with reduction is only supported with add operation" + ) + # NOTE: TMA supports OOB indices, so we skip the check. base_indices, slice_shape, is_squeezed = utils.parse_indices( gmem_slice, ir.MemRefType(gmem_ref.type).shape, check_oob=False @@ -597,20 +775,15 @@ def partition_dim(dim: int, idx: ir.Value, num_chunks: int): multicast_mask = None tma_desc = self._get_tma_desc( - gmem_ref, gmem_transform, tuple(slice_shape), swizzle, + gmem_ref, gmem_transform, gmem_peer_id, + tuple(slice_shape), swizzle, reduction_op, ) - # We constuct TMA descriptors in column-major order. + # We construct TMA descriptors in column-major order. rev_dyn_base_indices = [ arith.index_cast(i32, idx) for idx in reversed(dyn_base_indices) ] - uniform_ctx = ( - functools.partial(utils.single_thread, per_block=False) - if uniform and predicate is None - else contextlib.nullcontext - ) - if max(slice_shape) > 256: raise ValueError( "Async copies only support copying <=256 elements along each" @@ -618,8 +791,8 @@ def partition_dim(dim: int, idx: ir.Value, num_chunks: int): ) if (zeroth_bw := slice_shape[-1] * element_bitwidth) % 128 != 0: raise ValueError( - "Async copies require the number of bytes copied along the last" - f" dimension to be divisible by 16, but got {zeroth_bw}" + "Async copies require the number of bits copied along the last" + f" dimension to be divisible by 128, but got {zeroth_bw}" ) if ( swizzle is not None @@ -640,46 +813,60 @@ def partition_dim(dim: int, idx: ir.Value, num_chunks: int): np.prod(slice_shape) * element_bitwidth * collective_size // 8, i32 ) barrier_ptr = barrier.get_ptr() - with uniform_ctx(): - if collective_size > 1 and partitioned is not None: - if predicate is None: - predicate = c(1, ir.IntegerType.get_signless(1)) - if arrive: - first_block = arith.cmpi( - arith.CmpIPredicate.eq, self.cluster_idx(collective), c(0, index), - ) - arrive_predicate = arith.andi(predicate, first_block) - nvvm.mbarrier_arrive_expect_tx_shared( - barrier_ptr, transfer_bytes, predicate=arrive_predicate - ) - rank = len(slice_shape) - idx_operands = ",".join(f"${i}" for i in range(4, 4 + rank)) - llvm.inline_asm( - ir.Type.parse("!llvm.void"), - [predicate, smem_ptr, tma_desc, barrier_ptr, *rev_dyn_base_indices], - f""" - {{ - .reg .b32 mapped_addr; - @$0 mapa.shared::cluster.u32 mapped_addr, $3, 0; - @$0 cp.async.bulk.tensor.{rank}d.shared::cta.global.tile.mbarrier::complete_tx::bytes.cta_group::2 - [$1], [$2, {{{idx_operands}}}], [mapped_addr]; - }} - """, - "b,r,l,r" + ",r" * rank, - has_side_effects=True, + assert reduction_op is None + if collective_size > 1 and partitioned is not None: + if predicate is None: + predicate = c(1, ir.IntegerType.get_signless(1)) + if arrive: + first_block = arith.cmpi( + arith.CmpIPredicate.eq, 
self.cluster_idx(collective), c(0, index), ) - else: - if arrive: - nvvm.mbarrier_arrive_expect_tx_shared( - barrier_ptr, transfer_bytes, predicate=predicate - ) - nvvm.cp_async_bulk_tensor_shared_cluster_global( - smem_ptr, tma_desc, rev_dyn_base_indices, barrier_ptr, [], - multicast_mask=multicast_mask, predicate=predicate + arrive_predicate = arith.andi(predicate, first_block) + nvvm.mbarrier_arrive_expect_tx_shared( + barrier_ptr, transfer_bytes, predicate=arrive_predicate ) + rank = len(slice_shape) + idx_operands = ",".join(f"${i}" for i in range(4, 4 + rank)) + llvm.inline_asm( + ir.Type.parse("!llvm.void"), + [predicate, smem_ptr, tma_desc, barrier_ptr, *rev_dyn_base_indices], + f""" + {{ + .reg .b32 mapped_addr; + @$0 mapa.shared::cluster.u32 mapped_addr, $3, 0; + @$0 cp.async.bulk.tensor.{rank}d.shared::cta.global.tile.mbarrier::complete_tx::bytes.cta_group::2 + [$1], [$2, {{{idx_operands}}}], [mapped_addr]; + }} + """, + "b,r,l,r" + ",r" * rank, + has_side_effects=True, + ) + else: + if arrive: + nvvm.mbarrier_arrive_expect_tx_shared( + barrier_ptr, transfer_bytes, predicate=predicate + ) + nvvm.cp_async_bulk_tensor_shared_cluster_global( + smem_ptr, tma_desc, rev_dyn_base_indices, barrier_ptr, [], + multicast_mask=multicast_mask, predicate=predicate + ) else: assert multicast_mask is None - with uniform_ctx(): + if reduction_op is not None: + if predicate is None: + predicate = c(1, ir.IntegerType.get_signless(1)) + rank = len(slice_shape) + idx_operands = ",".join(f"${i}" for i in range(3, 3 + rank)) + llvm.inline_asm( + ir.Type.parse("!llvm.void"), + [predicate,smem_ptr,tma_desc,*rev_dyn_base_indices], + f"@$0 cp.reduce.async.bulk.tensor.{rank}d.global.shared::cta.{reduction_op}.tile.bulk_group [$2,{{{idx_operands}}}], [$1];", + "b,r,l" + ",r" * rank, + has_side_effects=True, + ) + if arrive: + nvvm.cp_async_bulk_commit_group() + else: nvvm.cp_async_bulk_tensor_global_shared_cta( tma_desc, smem_ptr, rev_dyn_base_indices, predicate=predicate ) @@ -691,3 +878,74 @@ def await_async_copy( ): nvvm.cp_async_bulk_wait_group(allow_groups, read=await_read_only) utils.warpgroup_barrier() + + def _ensure_nvshmem_decls(self): + if self.is_device_collective: + return + self.is_device_collective = True + with ir.InsertionPoint(self.module.body): + nvshmem_my_pe_type = ir.TypeAttr.get(ir.Type.parse("!llvm.func")) + llvm.LLVMFuncOp( + "nvshmem_my_pe", nvshmem_my_pe_type, sym_visibility="private" + ) + nvshmem_ptr_type = ir.TypeAttr.get( + ir.Type.parse("!llvm.func") + ) + llvm.LLVMFuncOp("nvshmem_ptr", nvshmem_ptr_type, sym_visibility="private") + + def to_remote(self, ref: ir.Value, peer: ir.Value): + self._ensure_nvshmem_decls() + if ir.MemRefType.isinstance(ref.type): + # We replace the offset in the ref type by 0, because memref_ptr always + # folds the offset into the pointer. 
+ ref_ty = ir.MemRefType(ref.type) + strides, _ = ref_ty.get_strides_and_offset() + result_type = ir.MemRefType.get( + ref_ty.shape, + ref_ty.element_type, + ir.StridedLayoutAttr.get(0, strides), + ref_ty.memory_space, + ) + return utils.ptr_as_memref( + self.to_remote(utils.memref_ptr(ref), peer), result_type + ) + if ref.type != ir.Type.parse("!llvm.ptr"): + raise ValueError(f"Unsupported type for to_remote: {ref.type}") + if peer.type != ir.IntegerType.get_signless(32): + raise ValueError(f"peer index must be an i32, got {peer.type}") + return llvm.call(ref.type, [ref, peer], [], [], callee="nvshmem_ptr") + + def device_id(self) -> ir.Value: + self._ensure_nvshmem_decls() + i32 = ir.IntegerType.get_signless(32) + return llvm.call(i32, [], [], [], callee="nvshmem_my_pe") + + +class ReplicationError(Exception): + pass + +def _recompute_peer_id(peer_id: ir.Value, fuel=8) -> ir.Value: + if fuel == 0: + raise ReplicationError( + "gmem_peer_id computation is too complicated to recompute on the host" + ) + if isinstance(peer_id, ir.BlockArgument): + raise ReplicationError("Can't recompute a value that's a block argument") + op = peer_id.owner.opview + # We accept all arith ops + if op.OPERATION_NAME.startswith("arith."): + new_operands = [_recompute_peer_id(x, fuel - 1) for x in op.operands] + result_types = [r.type for r in op.results] + new_attributes = {na.name: na.attr for na in op.attributes} + new_op = ir.Operation.create( + op.OPERATION_NAME, result_types, new_operands, new_attributes + ) + return new_op.results if len(new_op.results) > 1 else new_op.result + # nvshmem_my_pe queries the device id of the current process and works on both + # the host and the device. + if isinstance(op, llvm.CallOp) and op.callee.value == "nvshmem_my_pe": + i32 = ir.IntegerType.get_signless(32) + return llvm.call(i32, [], [], [], callee="nvshmem_my_pe") + raise ReplicationError( + f"Unrecognized op can't be recomputed on the host: {op}" + ) diff --git a/jax/experimental/mosaic/gpu/layout_inference.py b/jax/experimental/mosaic/gpu/layout_inference.py index 0d2811bb5610..29a92e2d5d43 100644 --- a/jax/experimental/mosaic/gpu/layout_inference.py +++ b/jax/experimental/mosaic/gpu/layout_inference.py @@ -25,7 +25,6 @@ from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import arith from jax._src.lib.mlir.dialects import math as mlir_math -from jax._src.lib.mlir.dialects import memref from jax._src.lib.mlir.dialects import scf from jax._src.lib.mlir.dialects import vector import numpy as np @@ -44,7 +43,9 @@ def _add_layout_inference_rule(op: type[ir.OpView], rule: LayoutInferenceRule): - _layout_inference_rules[op.OPERATION_NAME] = rule # pytype: disable=attribute-error + if op is not None: + _layout_inference_rules[op.OPERATION_NAME] = rule # pytype: disable=attribute-error + return rule def _set_layout_attributes( @@ -52,8 +53,8 @@ def _set_layout_attributes( in_layouts: list[ir.Attribute], out_layouts: list[ir.Attribute], ): - op.attributes["in_layouts"] = ir.ArrayAttr.get(in_layouts) - op.attributes["out_layouts"] = ir.ArrayAttr.get(out_layouts) + op.attributes["in_layouts"] = ir.ArrayAttr.get(in_layouts) + op.attributes["out_layouts"] = ir.ArrayAttr.get(out_layouts) def _choose_representative_layout( @@ -135,7 +136,6 @@ def _choose_representative_layout( def _infer_pointwise_op_layouts(op: ir.OpView) -> OptionalLayouts: - def is_array(v: ir.Value) -> bool: return ir.VectorType.isinstance(v.type) @@ -192,7 +192,7 @@ def is_array(v: ir.Value) -> bool: # This is left for a future change, and 
currently we only do "down # propagation". layout = _choose_representative_layout(layouts) - # It is unsafe to t conclude that this op produces a splat if not all inputs + # It is unsafe to conclude that this op produces a splat if not all inputs # have been inferred: some of them might turn out not to be splats! if layouts_lib.is_splat_fragmented_layout(layout) and not all_inputs_have_layout: return None @@ -247,6 +247,51 @@ def is_array(v: ir.Value) -> bool: _add_layout_inference_rule(op, _infer_pointwise_op_layouts) +# TODO(bchetioui): remove once minimum jaxlib >= 0.5.3. +OptimizationBarrierOp = getattr(mgpu, "OptimizationBarrierOp", None) + + +@partial(_add_layout_inference_rule, OptimizationBarrierOp) +def _infer_optimization_barrier_op_layout( + op: OptimizationBarrierOp, +) -> OptionalLayouts: + def is_array(v: ir.Value) -> bool: + return ir.VectorType.isinstance(v.type) + + if inference_utils.has_in_layouts_set(op): + op_in_layouts = list(inference_utils.in_layouts(op)) + return op_in_layouts, op_in_layouts + + if inference_utils.has_out_layouts_set(op): + op_out_layouts = list(inference_utils.out_layouts(op)) + return op_out_layouts, op_out_layouts + + layouts = [None] * len(op.operands) + for i, operand in enumerate(filter(is_array, op.operands)): + layouts[i] = inference_utils.value_layout(operand) + + for i, result in enumerate(filter(is_array, op.results)): + possible_layouts = set() + for op_operand_use in cast(ir.OpResult, result).uses: + consumer = op_operand_use.owner + op_user = consumer.operands[op_operand_use.operand_number] + layout = inference_utils.in_layout_for_operand(consumer, op_user) + if layout is not None: + possible_layouts.add(layout) + if possible_layouts and layouts[i] is None: + # TODO(bchetioui): we could actually just pick any user layout here, + # and optimize later. This is fine for now. + layouts[i] = _choose_representative_layout(possible_layouts) + + # TODO(bchetioui): handle annotating layout for only certain operands. + # Otherwise, layouts may not get propagated through optimization barriers, if + # a single branch does not carry any forcing layout, which is pretty bad. + if any(layout is None for layout in layouts): + return None + + return layouts, layouts + + @partial(_add_layout_inference_rule, arith.ConstantOp) def _infer_constant_op_layout(constant_op: arith.ConstantOp) -> OptionalLayouts: if not ir.VectorType.isinstance(constant_op.result.type): @@ -289,40 +334,107 @@ def _infer_constant_op_layout(constant_op: arith.ConstantOp) -> OptionalLayouts: return [], [layout] -@partial(_add_layout_inference_rule, scf.YieldOp) -def _infer_yield_op_layout(op: scf.YieldOp) -> OptionalLayouts: +def _layouts_from_values(values: Sequence[ir.Value]) -> list[ir.Attribute] | None: layouts = [] - for result in op.results_: - if not ir.VectorType.isinstance(result.type): + for value in values: + if not ir.VectorType.isinstance(value.type): continue - if (layout := inference_utils.value_layout(result)) is not None: + if (layout := inference_utils.value_layout(value)) is not None: if layouts_lib.is_splat_fragmented_layout(layout): return None layouts.append(layout) else: # Not all layouts could be inferred for vector ops. Return for now. 
return None + return layouts +@partial(_add_layout_inference_rule, scf.YieldOp) +def _infer_yield_op_layout(op: scf.YieldOp) -> OptionalLayouts: + layouts = _layouts_from_values(op.results_) + if layouts is None: + return None return (layouts, []) +@partial(_add_layout_inference_rule, scf.ConditionOp) +def _infer_condition_op_layout(op: scf.ConditionOp) -> OptionalLayouts: + layouts = _layouts_from_values(op.args) + if layouts is None: + return None + return (layouts, []) + + +def _last_op(region: ir.Region, expected_op_type: type[ir.OpView]): + [block] = region.blocks + last_op = block.operations[len(block.operations) - 1] + assert isinstance(last_op, expected_op_type) + return last_op + + +def _infer_from_op(op: ir.OpView) -> list[ir.Attribute] | None: + if not inference_utils.has_in_layouts_set(op): + return None + in_layouts = list(inference_utils.in_layouts(op)) + if any( + layouts_lib.is_splat_fragmented_layout(layout) + for layout in in_layouts + ): + return None + return in_layouts + + +def _infer_from_yield_ops(op: ir.Operation) -> list[ir.Attribute] | None: + candidates = [] + for region in op.regions: + yield_layouts = _infer_from_op(_last_op(region, scf.YieldOp)) + if yield_layouts is not None: + candidates.append(yield_layouts) + if not candidates: + return None + return [_choose_representative_layout(set(c)) for c in zip(*candidates)] + + @partial(_add_layout_inference_rule, scf.ForOp) def _infer_for_op_layout(op: scf.ForOp) -> OptionalLayouts: - yield_op = op.body.operations[len(op.body.operations) - 1] - assert isinstance(yield_op, scf.YieldOp) - - if inference_utils.has_in_layouts_set(yield_op): - yield_layouts = list(inference_utils.in_layouts(yield_op)) - if any( - layouts_lib.is_splat_fragmented_layout(layout) - for layout in yield_layouts - ): - return None - return (yield_layouts, yield_layouts) - # TODO(bchetioui): we don't attempt to propagate from outside for the moment. # For the existing kernels, propagating from the YieldOp should be enough. + if layouts := _infer_from_yield_ops(op): + return layouts, layouts + return None + + +@partial(_add_layout_inference_rule, scf.WhileOp) +def _infer_while_op_layout(op: scf.WhileOp) -> OptionalLayouts: + # TODO(dasenov): we don't attempt to propagate from outside for the moment. + # Note that the inputs or results do not necessarily contain vector types. If + # there is no vector type, the corresponding layouts (in_layouts or + # out_layouts) should be an empty list. 
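For readers less familiar with `scf.while`: the `before` region is terminated by `scf.condition`, whose forwarded values become the op's results, while the `after` region is terminated by `scf.yield`, whose operands are fed back as the next iteration's arguments. That is why out_layouts mirror the condition op and in_layouts mirror the yield op. The dialect-free sketch below is purely illustrative; `sketch_while_rule` is a hypothetical name and no MLIR is involved.

def sketch_while_rule(yield_layouts, condition_layouts):
  # in_layouts follow the after-region terminator (scf.yield); out_layouts
  # follow the before-region terminator (scf.condition). If either side is
  # still unknown, bail out so a later inference pass can retry.
  if yield_layouts is None or condition_layouts is None:
    return None
  return yield_layouts, condition_layouts

assert sketch_while_rule(["wgmma"], ["wgmma"]) == (["wgmma"], ["wgmma"])
assert sketch_while_rule(None, ["wgmma"]) is None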
+ + yield_op = _last_op(op.after, scf.YieldOp) + needs_in_layouts = inference_utils.should_have_layout(yield_op) + in_layouts = _infer_from_op(yield_op) if needs_in_layouts else [] + + condition_op = _last_op(op.before, scf.ConditionOp) + needs_out_layouts = inference_utils.should_have_layout(condition_op) + out_layouts = _infer_from_op(condition_op) if needs_out_layouts else [] + + if in_layouts is None or out_layouts is None: + return None + return in_layouts, out_layouts + + +@partial(_add_layout_inference_rule, scf.IfOp) +def _infer_if_op_layout(op: scf.IfOp) -> OptionalLayouts: + if layouts := _infer_from_yield_ops(op): + return [], layouts + return None + + +@partial(_add_layout_inference_rule, scf.IndexSwitchOp) +def _infer_index_switch_op_layout(op: scf.IndexSwitchOp) -> OptionalLayouts: + if layouts := _infer_from_yield_ops(op): + return [], layouts return None @@ -333,7 +445,6 @@ def _infer_splat_op_layout(splat_op: vector.SplatOp) -> OptionalLayouts: shape=cast(ir.ShapedType, splat_op.result.type).shape ) ) - return [], [layout] @@ -374,6 +485,161 @@ def _infer_reduction_op_layout(op: vector.ReductionOp) -> OptionalLayouts: return None +@partial(_add_layout_inference_rule, vector.MultiDimReductionOp) +def _infer_multi_dim_reduction_op_layout( + op: vector.MultiDimReductionOp, +) -> OptionalLayouts: + if inference_utils.has_any_layout_set(op): + # At the moment we either have all layouts or none. So if we found some + # layouts, set just return the same ones. + op_in_layouts = list(inference_utils.in_layouts(op)) + op_out_layouts = list(inference_utils.out_layouts(op)) + return op_in_layouts, op_out_layouts + + in_ty = ir.VectorType(op.source.type) + out_ty = ir.VectorType(op.result.type) + if len(in_ty.shape) != 2 or len(out_ty.shape) != 1: + raise NotImplementedError( + f"Only 2D -> 1D reductions are supported: {op}" + ) + + wgmma_layout = layouts_lib.to_layout_attr(fa.WGMMA_LAYOUT) + wgmma_row_layout = layouts_lib.to_layout_attr(fa.WGMMA_ROW_LAYOUT) + wgmma_col_layout = layouts_lib.to_layout_attr(fa.WGMMA_COL_LAYOUT) + reduction_dims = list(op.reduction_dims) + + # Find out the layout of the source. + in_layout = inference_utils.value_layout(op.source) + if in_layout is not None and in_layout == wgmma_layout: + if reduction_dims == [0]: + out_layout = wgmma_col_layout + elif reduction_dims == [1]: + out_layout = wgmma_row_layout + else: + raise NotImplementedError( + f"Invalid reduction dimensions: {reduction_dims}" + ) + return [in_layout, out_layout], [out_layout] + + # The source either has no layout or its layout is not WGMMA so we don't know + # yet how to handle it. Find out the layout of the result and see if that is + # WGMMA_ROW or WGMMA_COL which would imply the input is WGMMA. We can look at + # either the consumers or the acc input (they should have the same layout). + out_layouts = set() + + # Get acc layout. + acc_layout = inference_utils.value_layout(op.acc) + if acc_layout is not None: + out_layouts.add(acc_layout) + + # Get user layouts. + for use in cast(ir.OpResult, op.result).uses: + consumer = use.owner + operand = consumer.operands[use.operand_number] + layout = inference_utils.in_layout_for_operand(consumer, operand) + if layout: + out_layouts.add(layout) + + if not out_layouts: + # We couldn't find any definitive layouts, so we can't infer anything. 
+ return None + + out_layout = _choose_representative_layout(out_layouts) + if out_layout is None: + raise NotImplementedError( + f"Could not choose a best layout from {out_layouts}" + ) + if out_layout != wgmma_row_layout and out_layout != wgmma_col_layout: + # We don't have a layout we can handle in the output, so we can't infer + # anything. + return None + + if (out_layout == wgmma_row_layout and reduction_dims == [1]) or ( + out_layout == wgmma_col_layout and reduction_dims == [0] + ): + in_layout = wgmma_layout + else: + raise NotImplementedError( + f"Unsupported output layout: {out_layout} for reduction dimensions" + f" {reduction_dims}" + ) + + return [in_layout, out_layout], [out_layout] + + +@partial(_add_layout_inference_rule, mgpu.LayoutCastOp) +def _infer_layout_cast_op_layout( + layout_cast_op: mgpu.LayoutCastOp, +) -> OptionalLayouts: + return [layout_cast_op.new_layout], [layout_cast_op.new_layout] + + +# TODO(dasenov): Remove this after the minimal jaxlib version is 0.6.1. +if hasattr(mgpu, "BroadcastInDimOp"): + @partial(_add_layout_inference_rule, mgpu.BroadcastInDimOp) + def _infer_broadcast_in_dim_op_layout( + op: mgpu.BroadcastInDimOp, + ) -> OptionalLayouts: + if inference_utils.has_any_layout_set(op): + op_in_layouts = list(inference_utils.in_layouts(op)) + op_out_layouts = list(inference_utils.out_layouts(op)) + return op_in_layouts, op_out_layouts + + in_ty = ir.VectorType(op.operand.type) + out_ty = ir.VectorType(op.result.type) + if len(in_ty.shape) != 1 or len(out_ty.shape) != 2: + raise NotImplementedError( + "Broadcast in dim with non-trivial broadcast dimensions is not" + f" supported: {op}" + ) + + # Find out the layout of the output from the consumers. + user_layouts = set() + for use in cast(ir.OpResult, op.result).uses: + consumer = use.owner + operand = consumer.operands[use.operand_number] + layout = inference_utils.in_layout_for_operand(consumer, operand) + if layout is not None: + user_layouts.add(layout) + if user_layouts: + out_layout = _choose_representative_layout(user_layouts) + + if out_layout is None: + raise ValueError(f"Could not choose a best layout from {user_layouts}") + + if out_layout != layouts_lib.to_layout_attr(fa.WGMMA_LAYOUT): + raise NotImplementedError(f"Unsupported layout: {out_layout}") + + broadcast_dims = list(op.broadcast_dimensions) + if broadcast_dims == [0]: + in_layout = layouts_lib.to_layout_attr(fa.WGMMA_ROW_LAYOUT) + elif broadcast_dims == [1]: + in_layout = layouts_lib.to_layout_attr(fa.WGMMA_COL_LAYOUT) + else: + raise ValueError(f"Invalid broadcast dimensions: {broadcast_dims}") + + return [in_layout], [out_layout] + + # The consumers did not have any layouts set. Find out the layout of the + # input and infer the output layout from it. + in_layout = inference_utils.value_layout(op.operand) + if in_layout is None: + return None + + broadcast_dims = list(op.broadcast_dimensions) + if ( + broadcast_dims == [0] + and in_layout == layouts_lib.to_layout_attr(fa.WGMMA_ROW_LAYOUT) + ) or ( + broadcast_dims == [1] + and in_layout == layouts_lib.to_layout_attr(fa.WGMMA_COL_LAYOUT) + ): + out_layout = layouts_lib.to_layout_attr(fa.WGMMA_LAYOUT) + return [in_layout], [out_layout] + + return None + + @partial(_add_layout_inference_rule, mgpu.WGMMAOp) def _infer_wgmma_op_layout(wgmma_op: mgpu.WGMMAOp) -> OptionalLayouts: layout = layouts_lib.to_layout_attr(fa.WGMMA_LAYOUT) @@ -394,21 +660,6 @@ def _earliest_use(regions: list[ir.Region], uses: Sequence[ir.OpOperand]) -> ir. 
raise ValueError("None of uses are in the given block") -def _insert_memref_layout_cast(layout: ir.Attribute, view_op: memref.ViewOp): - mem_ref_type = ir.MemRefType(view_op.result.type) - memref_new_type = ir.MemRefType.get( - mem_ref_type.shape, - mem_ref_type.element_type, - layout, - mem_ref_type.memory_space, - ) - uses = list(view_op.result.uses) - with ir.InsertionPoint(_earliest_use(view_op.parent.regions, uses)): - cast_op = memref.cast(memref_new_type, view_op.result) - for use in uses: - use.owner.operands[use.operand_number] = cast_op - - class TraversalOrder(enum.Enum): """Traversal orders with respect to the data flow for IR.""" @@ -454,20 +705,20 @@ def inference_step(op: ir.Operation): # # We run two passes over the module, in order to make sure that layouts # defined in the middle of the computation are propagated wherever they need - # to be propagated. We start with a backwards (root-to-parameters) pass to - # propagate the information as far up as possible, and then a forward pass - # (parameters-to-root). + # to be propagated. We start with a forward (parameters-to-root) pass to + # preserve replicated layouts as far down as possible, and then do a + # backwards (root-to-parameters) pass. # - # Backwards pass + # Forward pass for op in module.body: inference_utils.traverse_op( - op, inference_step, inference_utils.TraversalOrder.BACKWARDS + op, inference_step, inference_utils.TraversalOrder.FORWARD ) - # Forward pass + # Backwards pass for op in module.body: inference_utils.traverse_op( - op, inference_step, inference_utils.TraversalOrder.FORWARD + op, inference_step, inference_utils.TraversalOrder.BACKWARDS ) # At this point, layouts have been propagated as far as they could be @@ -479,23 +730,33 @@ def inference_step(op: ir.Operation): # make sure to derive a single vector size in order to avoid relayouts at # lowering time. default_vector_size = math.inf - - def update_default_vector_size(op: ir.OpView): + def update_default_vector_size_from_vector(v: ir.Value): nonlocal default_vector_size - for v in list(op.operands) + list(op.results): - if ir.VectorType.isinstance(v.type): - max_vec_size_for_v = ( - np.prod(cast(ir.ShapedType, v.type).shape) // fa.WARPGROUP_SIZE - ) - desired_vec_size = 8 // utils.bytewidth(v.type.element_type) - default_vector_size = min( - default_vector_size, max_vec_size_for_v, desired_vec_size - ) + max_vec_size_for_v = ( + np.prod(cast(ir.ShapedType, v.type).shape) // fa.WARPGROUP_SIZE + ) + desired_vec_size = 64 // utils.bitwidth(v.type.element_type) # pytype: disable=attribute-error + default_vector_size = min( + default_vector_size, max_vec_size_for_v, desired_vec_size + ) + + def update_default_vector_size_from_op(op: ir.OpView): + for i, v in enumerate( + filter(lambda v: ir.VectorType.isinstance(v.type), op.operands) + ): + if inference_utils.attr_element("in_layouts", op, i) is None: + update_default_vector_size_from_vector(v) + + for i, v in enumerate( + filter(lambda v: ir.VectorType.isinstance(v.type), op.results) + ): + if inference_utils.attr_element("out_layouts", op, i) is None: + update_default_vector_size_from_vector(v) for op in module.body: - traverse_op(op, update_default_vector_size) + traverse_op(op, update_default_vector_size_from_op) - if default_vector_size is None: # Nothing to annotate. + if default_vector_size == math.inf: # Nothing to annotate. 
return def to_default_layout(ty: ir.Type) -> ir.Attribute | None: diff --git a/jax/experimental/mosaic/gpu/layouts.py b/jax/experimental/mosaic/gpu/layouts.py index 5c3b23119779..0a4f3ed09116 100644 --- a/jax/experimental/mosaic/gpu/layouts.py +++ b/jax/experimental/mosaic/gpu/layouts.py @@ -96,7 +96,7 @@ def is_strided_fragmented_layout(attr: ir.Attribute) -> bool: _tiled_layout_attr_pattern = re.compile( r"^#mosaic_gpu.TiledLayout<\[(?P.*)\]," - r" warp_dim\s*=\s*(?P[-\d]+)," + r" warp_dim\s*=\s*(?P.+)," r" lane_dims\s*=\s*\[(?P.*)\]," r" vector_dim\s*=\s*(?P[-\d]+)>$" ) @@ -107,15 +107,29 @@ def to_tiled_layout_attr( ) -> ir.Attribute: """Constructs a #mosaic_gpu.TiledLayout attribute from a TiledLayout.""" + def _int_or_replicated(d: int | fa.Replicated) -> str: + if isinstance(d, fa.Replicated): + return f"#mosaic_gpu.Replicated" + return str(d) + tile_str = lambda tile: "[" + ", ".join(str(d) for d in tile) + "]" tiling = "[" + ", ".join(tile_str(tile) for tile in layout.tiling.tiles) + "]" + lane_dims = ( + "[" + ",".join(_int_or_replicated(d) for d in layout.lane_dims) + "]" + ) + return ir.Attribute.parse( - f"#mosaic_gpu.TiledLayout<{tiling}, warp_dim={layout.warp_dim}," - f" lane_dims={list(layout.lane_dims)}, vector_dim={layout.vector_dim}>" + f"#mosaic_gpu.TiledLayout<{tiling}," + f" warp_dim={_int_or_replicated(layout.warp_dim)}," + f" lane_dims={lane_dims}, vector_dim={layout.vector_dim}>" ) _list_of_lists_delimiter = re.compile(r"\]\s*,\s*\[") +_int_pattern = re.compile(r"^(?P[-\d]+)(\s*:\s*\w+)?$") +_replicated_pattern = re.compile( + r"^#mosaic_gpu.Replicated<\s*times\s*=\s*(?P\d+)\s*>\s*$" +) def from_tiled_layout_attr( @@ -133,6 +147,15 @@ def from_tiled_layout_attr( f"Expected a #mosaic_gpu.TiledLayout attribute, got {attr}" ) + def _int_or_replicated(replicated_dim: str) -> int | fa.Replicated: + match = _replicated_pattern.fullmatch(replicated_dim) + if match: + return fa.Replicated(int(match.group("times"))) + match = _int_pattern.fullmatch(replicated_dim) + if match: + return int(match.group("num")) + raise ValueError(f"Unexpected format for replicated dim {replicated_dim}") + tiling_str = match.group("tiling") tile_strings = [] if len(tiling_str) > 2: @@ -140,9 +163,12 @@ def from_tiled_layout_attr( tiles = tuple(tuple(map(int, ts.split(","))) for ts in tile_strings) return fa.TiledLayout( tiling=fa.Tiling(tiles), - warp_dim=int(match.group("warp_dim")), - lane_dims=tuple(int(s) for s in match.group("lane_dims").split(",")), - vector_dim=int(match.group("vector_dim")) + warp_dim=_int_or_replicated(match.group("warp_dim")), + lane_dims=tuple( + _int_or_replicated(s.strip()) + for s in match.group("lane_dims").split(",") + ), + vector_dim=int(match.group("vector_dim")), ) @@ -155,7 +181,6 @@ def to_layout_attr( fa.WGSplatFragLayout | fa.WGStridedFragLayout | fa.TiledLayout - | fa.WGMMARowFragLayout ), ) -> ir.Attribute: """Constructs an MLIR attribute that corresponds to the given layout.""" @@ -166,30 +191,18 @@ def to_layout_attr( return to_strided_fragmented_layout_attr(layout) case fa.TiledLayout(): return to_tiled_layout_attr(layout) - case fa.WGMMARowFragLayout(): - return ir.Attribute.parse("#mosaic_gpu.WGMMARowFragLayout") case _: raise NotImplementedError( f"Unsupported layout for conversion to MLIR attribute: {layout}" ) -_wgmma_row_fragmented_layout_attr_pattern = re.compile( - r"^#mosaic_gpu.WGMMARowFragLayout$" -) - - -def is_wgmma_row_fragmented_layout(attr: ir.Attribute) -> bool: - return 
bool(_wgmma_row_fragmented_layout_attr_pattern.search(str(attr))) - - def from_layout_attr( attr: ir.Attribute, ) -> ( fa.WGSplatFragLayout | fa.WGStridedFragLayout | fa.TiledLayout - | fa.WGMMARowFragLayout ): """Constructs a layout from an MLIR attribute.""" if is_splat_fragmented_layout(attr): @@ -198,8 +211,6 @@ def from_layout_attr( return from_strided_fragmented_layout_attr(attr) elif is_tiled_layout(attr): return from_tiled_layout_attr(attr) - elif is_wgmma_row_fragmented_layout(attr): - return fa.WGMMARowFragLayout() else: raise NotImplementedError( f"Unsupported layout for conversion from MLIR attribute: {attr}" diff --git a/jax/experimental/mosaic/gpu/profiler.py b/jax/experimental/mosaic/gpu/profiler.py index 0c128f88d169..b4d06aba1671 100644 --- a/jax/experimental/mosaic/gpu/profiler.py +++ b/jax/experimental/mosaic/gpu/profiler.py @@ -17,10 +17,12 @@ import itertools import json import math -from typing import Callable, ParamSpec, TypeVar +from typing import ParamSpec, TypeAlias, TypeVar +from collections.abc import Callable import warnings import jax +from jax._src import stages from jax._src.lib import xla_client import jax.numpy as jnp from jaxlib.mlir import ir @@ -97,29 +99,44 @@ def run(*args, **kwargs): return outs, float(elapsed) -def _measure_cupti(f, aggregate): - def run(*args, **kwargs): - mosaic_gpu_lib._mosaic_gpu_ext._cupti_init() - try: - results = jax.block_until_ready(jax.jit(f)(*args, **kwargs)) - finally: - timings = mosaic_gpu_lib._mosaic_gpu_ext._cupti_get_timings() - return results, timings - - def wrapper(*args, **kwargs): - run(*args, **kwargs) # Warmup. - results, timings = run(*args, **kwargs) - if not timings: - return results, None - elif aggregate: - return results, sum(item[1] for item in timings) - else: - return results, timings - return wrapper - - -def measure(f: Callable, *, mode: str = "events", aggregate: bool = True -) -> Callable: +Timings: TypeAlias = list[tuple[str, float]] | float | None + + +@dataclasses.dataclass(frozen=True, kw_only=True) +class Cupti: + """CUPTI-based profiler.""" + + # If `True`, detach CUPTI from the process after measurement. + finalize: bool = True + + def measure( + self, f: Callable[P, T], *, aggregate: bool = True + ) -> Callable[P, tuple[T, Timings]]: + if not isinstance(f, (stages.Wrapped, stages.Compiled)): + f = jax.jit(f) + + def wrapper(*args: P.args, **kwargs: P.kwargs): + jax.block_until_ready(f(*args, **kwargs)) # Warmup. + ext = mosaic_gpu_lib._mosaic_gpu_ext + ext._cupti_init() + try: + results = jax.block_until_ready(f(*args, **kwargs)) + finally: + timings = ext._cupti_get_timings(self.finalize) + + if not timings: + return results, None + elif aggregate: + return results, sum(item[1] for item in timings) + else: + return results, timings + + return wrapper + + +def measure( + f: Callable[P, T], *, mode: str = "events", aggregate: bool = True +) -> Callable[P, tuple[T, Timings]]: """Sets up a function ``f`` for profiling on GPU. ``measure`` is a higher-order function that augments the argument ``f`` to @@ -173,10 +190,10 @@ def measure(f: Callable, *, mode: str = "events", aggregate: bool = True In an attempt to minimize the second effect, internally the events-based implementation may execute ``f`` more than once to "warm up" and exclude compilation time from the measurement. 
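A hedged usage sketch for the new CUPTI path (this assumes the module is importable as `jax.experimental.mosaic.gpu.profiler` and that a CUPTI-capable NVIDIA GPU is available; the array and the lambda are arbitrary examples, not part of the patch):

import jax.numpy as jnp
from jax.experimental.mosaic.gpu import profiler

x = jnp.ones((1024, 1024), jnp.float32)
timed = profiler.measure(lambda a: a @ a, mode="cupti", aggregate=True)
result, timing = timed(x)  # timing is a float, or None if no kernels ran
# With aggregate=False, `timing` is instead a list of (kernel name, time) pairs,
# matching the Timings alias introduced above.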
- """ + """ # fmt: skip match mode: case "cupti": - return _measure_cupti(f, aggregate) + return Cupti().measure(f, aggregate=aggregate) case "events": if not aggregate: raise ValueError(f"{aggregate=} is not supported with {mode=}") @@ -247,10 +264,20 @@ def dump(self, buffer, f, grid: tuple[int, ...], block: tuple[int, ...]): if np.any(entries_used > self.entries_per_warpgroup - 2): raise RuntimeError("Insufficient space to capture a full trace") traces = entries[..., 3:] + + # Estimate the overhead of profiling. + time_events = traces[:, :, 1::2] + valid_times_mask = np.arange(traces.shape[-1])[1::2] < (entries_used[..., None] - 3) + # 12 cycles is a ballpark estimate for H100 + profiling_overhead = (time_events[:, :, 1:] - time_events[:, :, :-1]).min( + where=valid_times_mask[:, :, 1:], initial=12 + ) + profiling_overhead = max(0, profiling_overhead - 1) + unintern = {v: k for k, v in self.interned_names.items()} events = [] for block_idx, wg_idx in np.ndindex(num_blocks, warpgroups_per_block): - valid_entries = entries_used[block_idx, wg_idx] - 3 + valid_entries = (entries_used[block_idx, wg_idx] - 3) local_clock_offset = None assert valid_entries % 2 == 0, valid_entries start_time = start_times[block_idx, wg_idx] @@ -262,7 +289,7 @@ def dump(self, buffer, f, grid: tuple[int, ...], block: tuple[int, ...]): if local_clock_offset is None: local_clock_offset = time time -= local_clock_offset - time -= i * 6 # Account for the overhead of profiling. + time -= (i // 2) * profiling_overhead # Account for the overhead of profiling. if time < 0: break # Detect a timer wraparound name_id = tag diff --git a/jax/experimental/mosaic/gpu/tcgen05.py b/jax/experimental/mosaic/gpu/tcgen05.py index 3330500cd6dc..904aec493b3b 100644 --- a/jax/experimental/mosaic/gpu/tcgen05.py +++ b/jax/experimental/mosaic/gpu/tcgen05.py @@ -35,6 +35,25 @@ TMEM_ROWS = 128 TCGEN05_SMEM_DESCRIPTOR_BIT = 1 << 46 +LAYOUT = fa.TCGEN05_LAYOUT +ROW_LAYOUT = fa.TCGEN05_ROW_LAYOUT +COL_LAYOUT = fa.TCGEN05_COL_LAYOUT + +# A layout resembling the logical organization of TMEM. The 128 rows in a tile +# are assigned to 128 lanes in the warpgroup. Useful when the result needs to be +# processed in registers and then stored back into TMEM. Should not be used if +# the result is to be written back to SMEM, as there is no good way to store it +# without bank conflicts. +# +# We use a vector_dim of 2, to be able to make sure that the vectors are always +# a multiple of 32-bits, even when the data is 16-bits. 
+TMEM_NATIVE_LAYOUT = fa.TiledLayout( + fa.Tiling(((128, 2), (32, 2))), + warp_dim=-4, + lane_dims=(-2,), + vector_dim=-1, +) + def create_instr_descriptor( m: int, @@ -44,20 +63,40 @@ def create_instr_descriptor( transpose_a: bool = False, transpose_b: bool = False, ): - f32 = ir.F32Type.get() - bf16 = ir.BF16Type.get() f16 = ir.F16Type.get() - if input_dtype not in {f16, bf16}: - raise NotImplementedError("Only float16 and bfloat16 inputs supported") - if acc_dtype not in {f32, f16}: - raise NotImplementedError("Only float32 and float16 accumulators supported") + f32 = ir.F32Type.get() + i32 = ir.IntegerType.get_signless(32) desc = 0 - # We ignore sparsity in bits 0-3 - desc |= (acc_dtype == f32) << 4 # D dtype, bits 4-5 + if acc_dtype == f16: + d_type_val = 0 + elif acc_dtype == f32: + d_type_val = 1 + elif acc_dtype == i32: + d_type_val = 2 + else: + raise NotImplementedError(f"Unsupported accumulator dtype: {acc_dtype}") + desc |= (d_type_val << 4) # D type, bits 4-5 # Bit 6 is reserved - desc |= (input_dtype == bf16) << 7 # A dtype, bits 7-9 - desc |= (input_dtype == bf16) << 10 # B dtype, bits 10-12 + if input_dtype == f16: + assert acc_dtype in {f16, f32} + ab_type_val = 0 + elif input_dtype == ir.BF16Type.get(): + assert acc_dtype == f32 + ab_type_val = 1 + elif input_dtype == ir.Float8E4M3FNType.get(): + assert acc_dtype in {f16, f32} + ab_type_val = 0 + elif input_dtype == ir.Float8E5M2Type.get(): + assert acc_dtype in {f16, f32} + ab_type_val = 1 + elif input_dtype == ir.IntegerType.get_signless(8): # Only s8 for now. + assert acc_dtype == i32 + ab_type_val = 1 + else: + raise NotImplementedError(f"Unsupported input dtype: {input_dtype}") + desc |= (ab_type_val << 7) # A dtype, bits 7-9 + desc |= (ab_type_val << 10) # B dtype, bits 10-12 # We ignore negate bits 13-14 desc |= transpose_a << 15 # Transpose A desc |= transpose_b << 16 # Transpose B @@ -75,7 +114,7 @@ def create_instr_descriptor( def mma( d: TMEMRef, - a: ir.Value, + a: ir.Value | TMEMRef, b: ir.Value, *, a_swizzle: int = 128, @@ -95,12 +134,22 @@ def mma( num_cta = 2 if collective else 1 # Step 1. Establish the shape and element type of the operation. 
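As a quick cross-check of the `create_instr_descriptor` bit layout above, the standalone sketch below packs only the fields visible in this hunk (D dtype in bits 4-5, A/B dtype in bits 7-9 and 10-12, transpose flags in bits 15-16); the higher-order fields are omitted and the helper name is made up for illustration.

def sketch_descriptor_low_bits(d_type_val, ab_type_val,
                               transpose_a=False, transpose_b=False):
  desc = 0
  desc |= d_type_val << 4    # D dtype, bits 4-5 (f16 -> 0, f32 -> 1, s32 -> 2)
  desc |= ab_type_val << 7   # A dtype, bits 7-9
  desc |= ab_type_val << 10  # B dtype, bits 10-12
  desc |= int(transpose_a) << 15
  desc |= int(transpose_b) << 16
  return desc

# f16 inputs with an f32 accumulator: only the D-dtype field is set.
assert sketch_descriptor_low_bits(d_type_val=1, ab_type_val=0) == 1 << 4
# bf16 inputs, f32 accumulator, B transposed.
assert sketch_descriptor_low_bits(1, 1, transpose_b=True) == (
    (1 << 4) | (1 << 7) | (1 << 10) | (1 << 16)
)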
- if not ir.MemRefType.isinstance(a.type): - raise ValueError(f"A must be a memref, got {a.type}") if not ir.MemRefType.isinstance(b.type): raise ValueError(f"B must be a memref, got: {b.type}") (k, n), element_type = mma_utils.tiled_memref_shape(b) - (m, k2), element_type2 = mma_utils.tiled_memref_shape(a) + if isinstance(a, TMEMRef): + m, k2 = a.shape + element_type2 = a.dtype + if collective and n * num_cta == 512: + raise NotImplementedError("Collective MMA with N=512 is not supported") + if a.layout != (expected_layout := _infer_tmem_layout(a.shape, packing=2)): + raise ValueError( + f"A layout mismatch: expected {expected_layout}, got {a.layout}" + ) + else: + if not ir.MemRefType.isinstance(a.type): + raise ValueError(f"A must be a memref, got {a.type}") + (m, k2), element_type2 = mma_utils.tiled_memref_shape(a) if k != k2: raise ValueError( "MMA requires A and B to have the same contraction dimension (K)," @@ -115,23 +164,41 @@ def mma( raise ValueError( f"Accumulator shape mismatch: expected {(m, n * num_cta)}, got {d.shape}" ) - if d.layout != (expected_layout := _infer_tmem_layout(d.shape, collective)): + expected_d_layout = ( + TMEM_COLLECTIVE_N512_LAYOUT + if collective and n * num_cta == 512 + else TMEM_DEFAULT_LAYOUT + ) + if d.layout != expected_d_layout: raise ValueError( - f"Accumulator layout mismatch: expected {expected_layout}, got {d.layout}" + f"Accumulator layout mismatch: expected {expected_d_layout}, got {d.layout}" ) f32 = ir.F32Type.get() + f16 = ir.F16Type.get() + s32 = ir.IntegerType.get_signless(32) if element_type == f32 or element_type == ir.BF16Type.get(): if d.dtype != f32: raise ValueError( f"MMA with element type {element_type} only supports accumulators" f" of type f32, but got: {d.dtype}" ) - elif element_type == ir.F16Type.get(): - if d.dtype != element_type and d.dtype != f32: + elif any( + t.isinstance(element_type) + for t in {ir.F16Type, ir.Float8E5M2Type, ir.Float8E4M3FNType} + ): + if d.dtype != f16 and d.dtype != f32: raise ValueError( - "MMA with element type f16 only supports accumulators of type f32" - f" or f16, but got: {d.dtype}" + f"MMA with element type {element_type} only supports accumulators of" + f" type f32 or f16, but got: {d.dtype}" ) + elif element_type == ir.IntegerType.get_signless(8): + if d.dtype != s32: + raise ValueError( + "MMA with element type s8 only supports s32 accumulators, but got:" + f" {d.dtype}" + ) + else: + raise NotImplementedError(f"Unsupported element type: {element_type}") # Step 2. Decide on the instruction shapes we'll use. Note that with swizzles, # instructions must be issued in groups of the same width as the swizzle. @@ -153,22 +220,27 @@ def mma( m_groups = m // m_group_elems k_groups = k // k_group_elems n_groups = n // n_group_elems - # TODO(apaszke): Require users to bitcast input refs to tf32 before WGMMA. - wgmma_element_type = ( + # TODO(apaszke): Require users to bitcast input refs to tf32 before MMA. + mma_element_type = ( ir.FloatTF32Type.get() if element_type == ir.F32Type.get() else element_type ) # Step 3. Compute the operand descriptors. 
- ( - (a_desc_base, a_k_instr_stride), - (a_m_group_stride, a_k_group_stride), - a_fastest, - ) = mma_utils.create_descriptor( - a, - swizzle=swizzle, - group_size=(m_group_elems, k_group_elems), - logical_k_major=False, - ) + if not isinstance(a, TMEMRef): + ( + (a_desc_base, a_k_instr_stride), + (a_m_group_stride, a_k_group_stride), + a_fastest, + ) = mma_utils.create_descriptor( + a, + swizzle=swizzle, + group_size=(m_group_elems, k_group_elems), + logical_k_major=False, + ) + else: + a_fastest = mma_utils.Dim.K + a_k_instr_stride = None + a_m_group_stride = a_k_group_stride = a_desc_base = None ( (b_desc_base, b_k_instr_stride), (b_n_group_stride, b_k_group_stride), @@ -184,8 +256,11 @@ def mma( true = arith.constant(ir.IntegerType.get_signless(1), 1) n_collective_group_elems = n_group_elems * num_cta for mi, ni, ki in np.ndindex(m_groups, n_groups, k_groups): - a_offset = mi * a_m_group_stride + ki * a_k_group_stride - a_mk = arith.addi(a_desc_base, utils.c(mma_utils.encode_addr(a_offset), i64)) + if isinstance(a, TMEMRef): + a_mk = a.slice(slice(None), utils.ds(ki * k_group_elems, k_group_elems)).address + else: + a_offset = mi * a_m_group_stride + ki * a_k_group_stride + a_mk = arith.addi(a_desc_base, utils.c(mma_utils.encode_addr(a_offset), i64)) b_offset = ni * b_n_group_stride + ki * b_k_group_stride b_nk = arith.addi(b_desc_base, utils.c(mma_utils.encode_addr(b_offset), i64)) if m_groups != 1: @@ -197,7 +272,7 @@ def mma( ), a_mk, b_nk, - d_type=ir.F32Type.get(), + d_type=d.dtype, m=m_group_elems, n=n_group_elems, collective=collective, @@ -207,17 +282,17 @@ def mma( b_k_stride=b_k_instr_stride, accumulate=acc, swizzle=swizzle, - element_type=wgmma_element_type, + element_type=mma_element_type, ) def _do_mma( d_addr: ir.Value, - a_desc: ir.Value, + a_desc_or_addr: ir.Value, # TMEM address if a_k_stride is None b_desc: ir.Value, a_transpose: bool, b_transpose: bool, - a_k_stride: int, + a_k_stride: int | None, b_k_stride: int, m: int, n: int, @@ -228,14 +303,23 @@ def _do_mma( collective: bool, ): i1 = ir.IntegerType.get_signless(1) + i32 = ir.IntegerType.get_signless(32) i64 = ir.IntegerType.get_signless(64) - kn_tiling = swizzle // utils.bytewidth(element_type) - instr_k = 32 // utils.bytewidth(element_type) - if a_k_stride % 16 or b_k_stride % 16: + elem_bytewidth = utils.bytewidth(element_type) + kn_tiling = swizzle // elem_bytewidth + instr_k = 32 // elem_bytewidth + packing = 4 // elem_bytewidth + if (a_k_stride is not None and a_k_stride % 16) or b_k_stride % 16: raise ValueError if ir.F16Type.isinstance(element_type) or ir.BF16Type.isinstance(element_type): kind = "f16" + elif ir.Float8E5M2Type.isinstance(element_type): + kind = "f8f6f4" + elif ir.Float8E4M3FNType.isinstance(element_type): + kind = "f8f6f4" + elif ir.IntegerType.get_signless(8).isinstance(element_type): + kind = "i8" else: raise NotImplementedError(f"Unsupported input element type: {element_type}") @@ -243,16 +327,27 @@ def _do_mma( i_desc = create_instr_descriptor( m * num_cta, n * num_cta, d_type, element_type, a_transpose, b_transpose ) + a_in_tmem = a_k_stride is None + a_ptx = "[$1]" if a_in_tmem else "$1" + a_ptx_constraint = "r" if a_in_tmem else "l" + assert a_desc_or_addr.type == ir.IntegerType.get_signless(32 if a_in_tmem else 64) for _ in range(kn_tiling // instr_k): llvm.inline_asm( ir.Type.parse("!llvm.void"), - [d_addr, a_desc, b_desc, i_desc, accumulate], - f"tcgen05.mma.cta_group::{num_cta}.kind::{kind} [$0], $1, $2, $3, $4;", - "r,l,l,r,b", + [d_addr, a_desc_or_addr, b_desc, i_desc, 
accumulate], + f"tcgen05.mma.cta_group::{num_cta}.kind::{kind} [$0], {a_ptx}, $2, $3, $4;", + f"r,{a_ptx_constraint},l,r,b", has_side_effects=True, ) accumulate = arith.constant(i1, 1) - a_desc = arith.addi(a_desc, arith.constant(i64, a_k_stride >> 4)) + if not a_in_tmem: + a_desc_or_addr = arith.addi( + a_desc_or_addr, arith.constant(i64, a_k_stride >> 4) + ) + else: + a_desc_or_addr = arith.addi( + a_desc_or_addr, arith.constant(i32, instr_k // packing) + ) b_desc = arith.addi(b_desc, arith.constant(i64, b_k_stride >> 4)) @@ -288,6 +383,19 @@ def commit_arrive( ) +def _alloc_ncols(ncols: int, exact: bool): + if exact: + if ncols.bit_count() != 1 or not 32 <= ncols <= 512: + raise ValueError(f"ncols must be a power of 2 and within [32, 512], got: {ncols}") + else: + ncols = max(32, 1 << (ncols - 1).bit_length()) + if ncols > 512: + raise ValueError( + f"After rounding up, got {ncols} columns, exceeding the limit of 512" + ) + return ncols + + def tmem_alloc(tmem_addr: ir.Value, ncols: int, collective: bool = False, exact: bool = True): if ir.MemRefType.isinstance(tmem_addr.type): ref_ty = ir.MemRefType(tmem_addr.type) @@ -300,15 +408,7 @@ def tmem_alloc(tmem_addr: ir.Value, ncols: int, collective: bool = False, exact: tmem_addr = utils.memref_ptr(tmem_addr, memory_space=3) elif tmem_addr.type != ir.Type.parse("!llvm.ptr<3>"): raise ValueError(f"tmem_addr must be an SMEM pointer or a memref, got: {tmem_addr.type}") - if exact: - if ncols.bit_count() != 1 or not 32 <= ncols <= 512: - raise ValueError(f"ncols must be a power of 2 and within [32, 512], got: {ncols}") - else: - ncols = max(32, 1 << (ncols - 1).bit_length()) - if ncols > 512: - raise ValueError( - f"After rounding up, got {ncols} columns, exceeding the limit of 512" - ) + ncols = _alloc_ncols(ncols, exact) num_cta = 2 if collective else 1 return llvm.inline_asm( ir.Type.parse("!llvm.void"), @@ -318,45 +418,93 @@ def tmem_alloc(tmem_addr: ir.Value, ncols: int, collective: bool = False, exact: has_side_effects=True, ) -def tmem_relinquish_alloc_permit(): + +def tmem_dealloc(tmem_addr: ir.Value, ncols: int, collective: bool = False, exact: bool = True): + if tmem_addr.type != ir.IntegerType.get_signless(32): + raise ValueError(f"tmem_addr must be an i32, got: {tmem_addr.type}") + ncols = _alloc_ncols(ncols, exact) + num_cta = 2 if collective else 1 + return llvm.inline_asm( + ir.Type.parse("!llvm.void"), + [tmem_addr], + f"tcgen05.dealloc.cta_group::{num_cta}.sync.aligned.b32 $0, {ncols};", + "r", + has_side_effects=True, + ) + + +def tmem_relinquish_alloc_permit(collective: bool): + num_cta = 2 if collective else 1 return llvm.inline_asm( ir.Type.parse("!llvm.void"), [], - "tcgen05.relinquish_alloc_permit.cta_group::1.sync.aligned;", + f"tcgen05.relinquish_alloc_permit.cta_group::{num_cta}.sync.aligned;", "", has_side_effects=True, ) -def tmem_load(tmem_addr, shape, num): +def _tmem_access_helper(shape, num): if num.bit_count() != 1 or num > 128: raise ValueError(f"num must be a power of 2 and <= 128, got: {num}") match shape: + case "32x32b": + num_regs = 1 case "16x128b": - num_out_regs = 2 + num_regs = 2 case "16x256b": - num_out_regs = 4 + num_regs = 4 case _: raise NotImplementedError(f"{shape=} is unsupported") - if num * num_out_regs >= 256: + num_regs *= num + if num_regs > 255: raise ValueError( - f"Loading too much TMEM at once: {num=} and each load requires" - f" {num_out_regs} registers, which exceeds the limit of 256" + f"TMEM translation too big : {shape=} and {num=} involve" + f" {num_regs} registers per-thread, 
which exceeds the limit of 255" ) - num_out_regs *= num + regs_vector = ",".join(f"${i}" for i in range(num_regs)) + regs_vector = "{" + regs_vector + "}" + return num_regs, regs_vector + + +def tmem_load(tmem_addr, shape, num, pack: bool): i32 = ir.IntegerType.get_signless(32) - out_regs = ",".join("$" + str(i) for i in range(num_out_regs)) + num_out_regs, regs_vector = _tmem_access_helper(shape, num) + pack_mod = ".pack::16b" if pack else "" regs = llvm.inline_asm( ir.Type.parse( "!llvm.struct<(" + ",".join("i32" for _ in range(num_out_regs)) + ")>" ), [tmem_addr], - f"tcgen05.ld.sync.aligned.{shape}.x{num}.b32 {{{out_regs}}}, [${num_out_regs}];", + f"tcgen05.ld.sync.aligned.{shape}.x{num}{pack_mod}.b32 {regs_vector}, [${num_out_regs}];", "=r," * num_out_regs + "r", has_side_effects=True, ) return [llvm.extractvalue(i32, regs, [i]) for i in range(num_out_regs)] +def wait_tmem_load(): + llvm.inline_asm( + ir.Type.parse("!llvm.void"), + [], + "tcgen05.wait::ld.sync.aligned;", + "", + has_side_effects=True, + ) + utils.warpgroup_barrier() + + +def tmem_store(tmem_addr, shape, num, regs, unpack: bool): + num_out_regs, regs_vector = _tmem_access_helper(shape, num) + pack_mod = ".unpack::16b" if unpack else "" + llvm.inline_asm( + ir.Type.parse("!llvm.void"), + [*regs, tmem_addr], + f"tcgen05.st.sync.aligned.{shape}.x{num}{pack_mod}.b32 [${num_out_regs}], {regs_vector};", + "r," * num_out_regs + "r", + has_side_effects=True, + ) + + @dataclasses.dataclass(frozen=True) class TMEMLayout: """Represents the way a shape is laid out in TMEM. @@ -390,6 +538,7 @@ class TMEMLayout: """ elements_in_tile: tuple[int, int] column_tile_stride: int = 1 + packing: int = 1 def __post_init__(self): row_tiling = self.elements_in_tile[0] @@ -399,24 +548,34 @@ def __post_init__(self): ) if row_tiling.bit_count() != 1: raise ValueError(f"Row tiling must be a power of 2, got: {row_tiling}") + if self.elements_in_tile[1] % self.packing: + raise ValueError( + f"Column tiling must be a multiple of packing={self.packing}, got:" + f" {self.elements_in_tile[1]}" + ) - def check_shape(self, shape: tuple[int, ...]): + def check_type(self, shape: tuple[int, ...], dtype: ir.Type): if len(shape) != 2: raise ValueError(f"TMEM can only represent 2D shapes, got {shape}") if any(s % t for s, t in zip(shape, self.elements_in_tile)): raise ValueError( f"{shape} is divisible into tiles of shape {self.elements_in_tile}" ) + if self.packing not in {1, fully_packed := 32 // utils.bitwidth(dtype)}: + raise ValueError( + f"For {utils.bitwidth(dtype)}-bit types, only packing=1 and" + f" packing={fully_packed} are supported, but got: {self.packing}" + ) def cols_in_shape(self, shape: tuple[int, int]): - cols_in_tile = self.elements_in_tile[1] + cols_in_tile = self.elements_in_tile[1] // self.packing tiles_in_row = TMEM_ROWS // self.elements_in_tile[0] num_tiles = math.prod(utils.tile_shape(shape, self.elements_in_tile)[:-2]) assert num_tiles % tiles_in_row == 0 return num_tiles // tiles_in_row * cols_in_tile -def _infer_tmem_layout(shape: tuple[int, int], collective: bool) -> TMEMLayout: +def _infer_tmem_layout(shape: tuple[int, int], packing: int = 1) -> TMEMLayout: if shape[0] > TMEM_ROWS: raise ValueError( "Can only infer TMEM layout for shapes with at most 128 rows, got:" @@ -437,11 +596,13 @@ def _infer_tmem_layout(shape: tuple[int, int], collective: bool) -> TMEMLayout: "Can only infer TMEM layout for shapes with column count that's a" f" multiple of 8, got: {shape[1]}" ) - if collective and shape[1] == 512: - return 
TMEMLayout(elements_in_tile=(shape[0], 128), column_tile_stride=2) - else: - return TMEMLayout(elements_in_tile=(shape[0], 8)) + return TMEMLayout(elements_in_tile=(shape[0], 8), packing=packing) + +TMEM_DEFAULT_LAYOUT = TMEMLayout(elements_in_tile=(TMEM_ROWS, 8), packing=1) +TMEM_COLLECTIVE_N512_LAYOUT = TMEMLayout( + elements_in_tile=(TMEM_ROWS, 128), column_tile_stride=2, packing=1 +) @dataclasses.dataclass(frozen=True) class TMEMRef: @@ -481,30 +642,41 @@ def from_alloc( ) layout = _infer_tmem_layout(shape, collective) else: - layout.check_shape(shape) + layout.check_type(shape, dtype) # TODO: Do we have to do this?? # warp_idx = utils.warp_idx(sync=False) # tmem_addr = arith.ori(tmem_addr, arith.shli(warp_idx, utils.c(21, i32))) return cls(tmem_addr, shape, dtype, layout) def slice(self, *idxs): + i32 = ir.IntegerType.get_signless(32) base_idx, slice_shape, is_squeezed = utils.parse_indices(idxs, self.shape) if any(is_squeezed): raise ValueError("TMEM can only be sliced, not indexed") - if self.layout != TMEMLayout(elements_in_tile=(TMEM_ROWS, 8)): - raise NotImplementedError( - "Slicing only implemented for refs with standard layout, got:" - f" {self.layout}" - ) + match self.layout: + case TMEMLayout(elements_in_tile=(r, 8), packing=packing) if ( + r == TMEM_ROWS + ): + pass + case _: + raise NotImplementedError( + "Slicing only implemented for refs with standard layout, got:" + f" {self.layout}" + ) if base_idx[0] != 0 or slice_shape[0] != TMEM_ROWS: raise NotImplementedError("TMEM cannot be sliced along rows") if slice_shape[1] % 8: raise NotImplementedError( - "TMEM column slice length must be a multiple of 8" + "TMEM column slice length must be a multiple of 8. " + f"Got {slice_shape[1]}." ) col_idx = base_idx[1] if not isinstance(col_idx, ir.Value): - col_idx = arith.constant(ir.IntegerType.get_signless(32), col_idx) + col_idx = arith.constant(i32, col_idx) + if col_idx.type == ir.IndexType.get(): + col_idx = arith.index_cast(i32, col_idx) + if packing != 1: + col_idx = arith.divui(col_idx, arith.constant(i32, packing)) return TMEMRef( address=arith.addi(self.address, col_idx), shape=tuple(slice_shape), @@ -512,91 +684,360 @@ def slice(self, *idxs): dtype=self.dtype, ) - def __getitem__(self, *idxs): + def load(self, layout: fa.TiledLayout = LAYOUT, is_signed: bool | None = None): i32 = ir.IntegerType.get_signless(32) - base_idxs, slice_shape, is_squeezed = utils.parse_indices(idxs, self.shape) - if any(is_squeezed): - raise ValueError("TMEM loads only support slicing") - if any(idx != 0 for idx in base_idxs) or tuple(slice_shape) != self.shape: - raise NotImplementedError("Slicing of TMEM not impelmented yet") if self.shape[1] % 8: raise NotImplementedError - if self.dtype != ir.F32Type.get(): - raise NotImplementedError(self.dtype) - layout = _m128_256bit_32bit_layout(self.shape) - regs_shape = layout.registers_shape(self.shape) - if self.layout == TMEMLayout(elements_in_tile=(TMEM_ROWS, 8)): - # load_32xcols returns a 4xN array, but the FA tiling we use here tiles - # columns before rows, and so it is Nx4 (after ignoring all 1 dims). 
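A numeric illustration of the packed-column arithmetic used in `slice` above (and in `cols_in_shape`): with a fully packed 16-bit layout, `packing == 2`, so two logical columns share one 32-bit TMEM column and a logical column offset is divided by the packing before being added to the base address. The helper name below is hypothetical.

def sketch_physical_col(logical_col: int, packing: int) -> int:
  # Mirrors the divui above: `packing` logical columns per physical TMEM column.
  return logical_col // packing

assert sketch_physical_col(24, packing=2) == 12  # 16-bit data, fully packed
assert sketch_physical_col(24, packing=1) == 24  # 32-bit data, or unpacked 16-bit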
- registers = _load_32xcols( - self.address, self.shape[1], self.dtype - ).T.reshape(regs_shape) - elif self.layout == TMEMLayout(elements_in_tile=(TMEM_ROWS, 128), column_tile_stride=2): - if self.shape[1] % 128 != 0: - raise ValueError( - f"TMEM layout {self.layout} is not compatible with shape {self.shape}" - ) - num_column_tiles = self.shape[1] // 128 - column_tile_stride = self.layout.column_tile_stride - num_strided_col_groups = utils.ceil_div(num_column_tiles, column_tile_stride) - tiles = [] - for col_tile_base in range(num_strided_col_groups): - for col_tile in range(col_tile_base, num_column_tiles, column_tile_stride): - tiles.append( - _load_32xcols( - arith.addi(self.address, arith.constant(i32, col_tile * 128)), - cols=128, - dtype=self.dtype, + if utils.bitwidth(self.dtype) not in {16, 32}: + raise NotImplementedError(f"Unsupported dtype: {self.dtype}") + if layout == LAYOUT: + regs_shape = layout.registers_shape(self.shape) + match self.layout: + case TMEMLayout(elements_in_tile=(r, 8), packing=packing) if ( + r == TMEM_ROWS + ): + # load_32xcols returns a 4xN array, but the FA tiling we use here tiles + # columns before rows, and so it is Nx4 (after ignoring all 1 dims). + registers = _load_32xcols( + self.address, self.shape[1], self.dtype, packing + ).T.reshape(regs_shape) + case TMEMLayout(elements_in_tile=(r, 128), column_tile_stride=2) if r == TMEM_ROWS: + if self.shape[1] % 128 != 0: + raise ValueError( + f"TMEM layout {self.layout} is not compatible with shape {self.shape}" + ) + num_column_tiles = self.shape[1] // 128 + column_tile_stride = self.layout.column_tile_stride + num_strided_col_groups = utils.ceil_div(num_column_tiles, column_tile_stride) + tiles = [] + for col_tile_base in range(num_strided_col_groups): + for col_tile in range(col_tile_base, num_column_tiles, column_tile_stride): + tiles.append( + _load_32xcols( + arith.addi(self.address, arith.constant(i32, col_tile * 128)), + cols=128, + dtype=self.dtype, + tmem_packing=1, + ) ) + registers = np.concatenate(tiles, axis=1).T.reshape(regs_shape) + case _: + raise NotImplementedError( + f"Loads only implemented for refs with standard layout, got: {self.layout}" + ) + elif layout == TMEM_NATIVE_LAYOUT: + regs_shape = layout.registers_shape(self.shape) + match self.layout: + case TMEMLayout(elements_in_tile=(r, c), packing=packing) if ( + r == TMEM_ROWS and c % 2 == 0 + ): + registers = _load_32xcols_native( + self.address, self.shape[1], self.dtype, packing + ).reshape(regs_shape) + case _: + raise NotImplementedError( + "Loads only implemented for refs with standard layout, got:" + f" {self.layout}" ) - registers = np.concatenate(tiles, axis=1).T.reshape(regs_shape) else: - raise NotImplementedError( - f"Loads only implemented for refs with standard layout, got: {self.layout}" + raise ValueError( + "TMEM loads can only produce results in the tcgen05 layouts" + f" ({LAYOUT} and {TMEM_NATIVE_LAYOUT}), but got: {layout}" + ) + return fa.FragmentedArray( + _registers=registers, _layout=layout, _is_signed=is_signed + ) + + def store(self, value): + if self.shape[1] % 8: + raise NotImplementedError + if utils.bitwidth(self.dtype) not in {16, 32}: + raise NotImplementedError(f"Unsupported dtype: {self.dtype}") + if not isinstance(value, fa.FragmentedArray): + raise ValueError(f"TMEM stores expect a FragmentedArray, got: {value}") + if value.shape != self.shape: + raise ValueError( + f"Stored array has shape {value.shape}, but TMEM has shape" + f" {self.shape}" + ) + if value.mlir_dtype != self.dtype: + raise 
ValueError( + f"Stored array has dtype {value.mlir_dtype}, but TMEM has dtype" + f" {self.dtype}" + ) + if value.layout == LAYOUT: + # TODO(apaszke): Collective MMA layout + match self.layout: + case TMEMLayout(elements_in_tile=(r, 8), packing=packing) if ( + r == TMEM_ROWS + ): + # store_32xcols needs a 4xN array, but the FA tiling we use here tiles + # columns before rows, and so it is Nx4 (after ignoring all 1 dims). + _store_32xcols( + self.address, value.registers.T.reshape((4, -1)), packing + ) + case _: + raise NotImplementedError( + f"Stores only implemented for refs with standard layout, got: {self.layout}" + ) + elif value.layout == TMEM_NATIVE_LAYOUT: + # TODO(apaszke): Collective MMA layout + match self.layout: + case TMEMLayout(elements_in_tile=(r, c), packing=packing) if ( + r == TMEM_ROWS and c % 2 == 0 + ): + _store_32xcols_native( + self.address, value.registers.reshape(-1), packing + ) + case _: + raise NotImplementedError( + f"Stores only implemented for refs with standard layout, got: {self.layout}" + ) + else: + raise ValueError( + f"Stored array has layout {value.layout}, but only tcgen05.LAYOUT and" + " tcgen05.TMEM_NATIVE_LAYOUT are supported" ) - return fa.FragmentedArray(_registers=registers, _layout=layout, _is_signed=None) -def _load_32xcols(base_addr, cols, dtype): - # See https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-matrix-fragments-shape-16256b + def _debug_print(self): + i32 = ir.IntegerType.get_signless(32) + num_cols = self.layout.cols_in_shape(self.shape) + lane = arith.remui(utils.thread_idx(), arith.constant(i32, utils.WARPGROUP_SIZE)) + for c in range(num_cols): + val = llvm.inline_asm( + i32, + [arith.addi(self.address, arith.constant(i32, c))], + "tcgen05.ld.sync.aligned.32x32b.x1.b32 {$0}, [$1];", + "=r,r", + ) + dtype_bitwidth = utils.bitwidth(self.dtype) + full_packing = 32 // dtype_bitwidth + if self.layout.packing == 1: + if dtype_bitwidth < 32: + val = arith.trunci(ir.IntegerType.get_signless(dtype_bitwidth), val) + val = utils.bitcast(val, self.dtype) + elif self.layout.packing == full_packing: + val = utils.bitcast(val, ir.VectorType.get((full_packing,), self.dtype)) + else: + raise NotImplementedError(f"Unsupported packing: {self.layout.packing}") + # TODO(apaszke): Make this print logical, not physical location. + utils.debug_print(f"[{{}}, {c}]: {{}}", lane, val, uniform=False) + + +def _transfer_32xcols( + base_addr: ir.Value, + cols: int, + atom_shape: tuple[int, int], + tmem_packing: int, + reg_packing: int, +): + """Generates a sequence of parameters for a given TMEM read or write. + + Arguments: + base_addr: The base address of the TMEM region. + cols: The number of logical columns to transfer. + atom_shape: The logical shape of the tile written by the warp in a single + TMEM transfer. + tmem_packing: Packing degree in TMEM. When packing is 1, but the data is + 16-bit, we expect that each transfer actually involves double the number + of physical columns. + reg_packing: The number of elements that fit in a single 32-bit register. + """ i32 = ir.IntegerType.get_signless(32) - assert cols % 8 == 0 - cols_per_num_tile = 8 - load_shape = "16x256b" - num = cols // 8 - if num <= 32: - num_tiling = num - elif num == 64: - num_tiling = 32 - else: - raise NotImplementedError(num) - vector_regs = np.ndarray((4, num), dtype=object) - # We load 16 lanes at a time, but need 32 in total. 
- for row_group in range(2): - addr_row = arith.addi(base_addr, arith.constant(i32, (row_group * 16) << 16)) - regs = [] - for num_group in range(num // num_tiling): + atom_rows, atom_cols = atom_shape + assert cols % atom_cols == 0 + total_num = cols // atom_cols + assert total_num.bit_count() == 1 + regs_per_instr = atom_shape[0] * atom_shape[1] // (utils.WARP_SIZE * reg_packing) + # We artificially lower the instr_num compared to its limits, because higher + # values can lead to register spills.. + instr_num = min(total_num, 32 // regs_per_instr) + assert 32 % atom_rows == 0 + num_row_steps = 32 // atom_rows + for lane_step in range(num_row_steps): + addr_row = arith.addi(base_addr, utils.c((lane_step * atom_rows) << 16, i32)) + cols_per_instr = instr_num * atom_cols + for num_step in range(total_num // instr_num): + num_slice = slice(num_step * instr_num, (num_step + 1) * instr_num) addr_row_col = arith.addi( - addr_row, - arith.constant(i32, num_tiling * num_group * cols_per_num_tile), + addr_row, utils.c(num_step * cols_per_instr // tmem_packing, i32) ) - regs += tmem_load(addr_row_col, load_shape, num_tiling) - regs = [llvm.bitcast(dtype, r) for r in regs] - undef = llvm.mlir_undef(ir.VectorType.get((2,), dtype)) - for r_low, r_high, idx in zip(regs[::2], regs[1::2], np.ndindex(num, 2)): - high_undef = llvm.insertelement(undef, r_low, utils.c(0, i32)) - vreg = llvm.insertelement(high_undef, r_high, utils.c(1, i32)) - vector_regs[idx[1] + 2 * row_group, idx[0]] = vreg + yield addr_row_col, instr_num, lane_step, num_slice + + +def _store_32xcols(base_addr, vector_regs, tmem_packing): + i32 = ir.IntegerType.get_signless(32) + assert vector_regs.ndim == 2 and vector_regs.shape[0] == 4 + cols = vector_regs.shape[1] * 8 + + reg_packing = 64 // utils.bitwidth(vector_regs.flat[0].type) + if reg_packing == 1: + store_shape = "16x256b" # 4 threads * 64 bits per vreg = 256 bits + regs = np.empty((4, vector_regs.shape[1], 2), dtype=object) + c0 = arith.constant(i32, 0) + c1 = arith.constant(i32, 1) + for idx, vreg in np.ndenumerate(vector_regs): + regs[(*idx, 0)] = llvm.extractelement(vreg, c0) + regs[(*idx, 1)] = llvm.extractelement(vreg, c1) + regs = regs.reshape(2, 2, vector_regs.shape[1], 2).swapaxes(1, 2) + # From a single lane perspective a num tile consists of a 2x2, with the + # minor dim traversing columns and major being 8 rows apart. + # See https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-matrix-fragments-shape-16256b + assert regs.shape[-2:] == (2, 2) + assert tmem_packing == 1 + unpack = False + elif reg_packing == 2: + store_shape = "16x128b" # 4 threads * 32 bits per vreg = 128 bits + # From a single lane perspective a num tile has 2 registers, 8 rows apart. 
+ # See https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-matrix-fragments-shape-16128b + regs = vector_regs.reshape(2, 2, vector_regs.shape[1]).swapaxes(1, 2) + assert 1 <= tmem_packing <= 2 + unpack = tmem_packing == 1 + else: + raise NotImplementedError(reg_packing) + + it = _transfer_32xcols(base_addr, cols, (16, 8), tmem_packing, reg_packing) + for addr_row_col, instr_num, lane_step, num_slice in it: + regs_slice = regs[lane_step, num_slice].flat + tmem_store(addr_row_col, store_shape, instr_num, regs_slice, unpack) + + +def _store_32xcols_native(base_addr, vector_regs, tmem_packing): + i32 = ir.IntegerType.get_signless(32) + assert vector_regs.ndim == 1 + cols = len(vector_regs) * TMEM_NATIVE_LAYOUT.vector_length + + reg_packing = 64 // utils.bitwidth(vector_regs.flat[0].type) + store_shape = "32x32b" + if reg_packing == 1: + store_atom_shape = (32, 1) + regs = [None] * (len(vector_regs) * 2) + c0 = arith.constant(i32, 0) + c1 = arith.constant(i32, 1) + for idx, vreg in enumerate(vector_regs): + regs[2 * idx] = llvm.extractelement(vreg, c0) + regs[2 * idx + 1] = llvm.extractelement(vreg, c1) + assert tmem_packing == 1 + unpack = False + elif reg_packing == 2: + store_atom_shape = (32, 2) + regs = vector_regs + assert 1 <= tmem_packing <= 2 + unpack = tmem_packing == 1 + else: + raise NotImplementedError(reg_packing) + + it = _transfer_32xcols(base_addr, cols, store_atom_shape, tmem_packing, reg_packing) + for addr_row_col, instr_num, lane_step, num_slice in it: + assert lane_step == 0 + regs_slice = regs[num_slice] + tmem_store(addr_row_col, store_shape, instr_num, regs_slice, unpack) + + +def _load_32xcols(base_addr, cols, dtype, tmem_packing): + i32 = ir.IntegerType.get_signless(32) + vec_ty = ir.VectorType.get((2,), dtype) + reg_packing = 32 // utils.bitwidth(dtype) + if reg_packing == 1: + load_shape = "16x256b" # 4 threads * 64 bits per vreg = 256 bits + assert tmem_packing == 1 + pack = False + elif reg_packing == 2: + load_shape = "16x128b" # 4 threads * 32 bits per vreg = 128 bits + assert 1 <= tmem_packing <= 2 + pack = tmem_packing == 1 + else: + raise NotImplementedError(reg_packing) + + vector_regs = np.ndarray((4, cols // 8), dtype=object) + + it = _transfer_32xcols(base_addr, cols, (16, 8), tmem_packing, reg_packing) + c0 = arith.constant(i32, 0) + c1 = arith.constant(i32, 1) + for addr_row_col, instr_num, lane_step, num_slice in it: + regs = tmem_load(addr_row_col, load_shape, instr_num, pack) + row_slice = slice(lane_step * 2, (lane_step + 1) * 2) + # This aliases the original array, so updates will be reflected there. + vector_regs_update = vector_regs[row_slice, num_slice] + assert vector_regs_update.shape == (2, instr_num), (vector_regs_update.shape, instr_num) + if reg_packing == 1: + regs = [llvm.bitcast(dtype, r) for r in regs] + # From a single lane perspective a num tile consists of a 2x2, with the + # minor dim traversing columns and major being 8 rows apart. 
+ # See https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-matrix-fragments-shape-16256b + regs = np.asarray(regs, dtype=object).reshape(instr_num, 2, 2).swapaxes(0, 1) + undef = llvm.mlir_undef(vec_ty) + assert regs.shape == (*vector_regs_update.shape, 2) + for idx in np.ndindex(vector_regs_update.shape): + high_undef = llvm.insertelement(undef, regs[(*idx, 0)], c0) + vreg = llvm.insertelement(high_undef, regs[(*idx, 1)], c1) + vector_regs_update[idx] = vreg + else: + assert reg_packing == 2 + regs = [llvm.bitcast(vec_ty, r) for r in regs] + # From a single lane perspective a num tile has 2 registers, 8 rows apart. + # See https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-matrix-fragments-shape-16128b + regs = np.asarray(regs, dtype=object).reshape(instr_num, 2).swapaxes(0, 1) + vector_regs_update[...] = regs + + return vector_regs + + +def _load_32xcols_native(base_addr, cols, dtype, tmem_packing): + i32 = ir.IntegerType.get_signless(32) + vec_ty = ir.VectorType.get((2,), dtype) + reg_packing = 32 // utils.bitwidth(dtype) + load_shape = "32x32b" + if reg_packing == 1: + load_atom_shape = (32, 1) + assert tmem_packing == 1 + pack = False + elif reg_packing == 2: + load_atom_shape = (32, 2) + assert 1 <= tmem_packing <= 2 + pack = tmem_packing == 1 + else: + raise NotImplementedError(reg_packing) + + it = _transfer_32xcols(base_addr, cols, load_atom_shape, tmem_packing, reg_packing) + c0 = arith.constant(i32, 0) + c1 = arith.constant(i32, 1) + regs = [None] * (cols // reg_packing) + for addr_row_col, instr_num, lane_step, num_slice in it: + assert lane_step == 0, lane_step + instr_regs = tmem_load(addr_row_col, load_shape, instr_num, pack) + if reg_packing == 1: + regs[num_slice] = [llvm.bitcast(dtype, r) for r in instr_regs] + else: + assert reg_packing == 2 + regs[num_slice] = [llvm.bitcast(vec_ty, r) for r in instr_regs] + + if reg_packing == 1: + vector_regs = np.ndarray((cols // 2,), dtype=object) + undef = llvm.mlir_undef(vec_ty) + for idx in range(vector_regs.size): + high_undef = llvm.insertelement(undef, regs[2 * idx], c0) + vreg = llvm.insertelement(high_undef, regs[2 * idx + 1], c1) + vector_regs[idx] = vreg + else: + assert reg_packing == 2 + vector_regs = np.asarray(regs, dtype=object) + + assert vector_regs.shape == (cols // TMEM_NATIVE_LAYOUT.vector_length,) return vector_regs -def _m128_256bit_32bit_layout(shape: tuple[int, ...]): +def _m128_layout(shape: tuple[int, ...]): if len(shape) != 2: raise ValueError(f"Shape {shape} is not 2D") if shape[0] % 128 != 0 or shape[1] % 8 != 0: raise ValueError(f"Shape {shape} is not a multiple of 64x8") - return fa.TiledLayout( - fa.Tiling(((128, 8), (32, 8), (8, 8), (1, 2))), - warp_dim=-8, - lane_dims=(-4, -3), - vector_dim=-1, + return LAYOUT + + +def commit_tmem(): + void = ir.Type.parse("!llvm.void") + llvm.inline_asm( + void, [], "tcgen05.wait::st.sync.aligned;", "", has_side_effects=True, ) + utils.warpgroup_barrier() diff --git a/jax/experimental/mosaic/gpu/transform_inference.py b/jax/experimental/mosaic/gpu/transform_inference.py index ef2d3661674c..f08027506334 100644 --- a/jax/experimental/mosaic/gpu/transform_inference.py +++ b/jax/experimental/mosaic/gpu/transform_inference.py @@ -20,13 +20,13 @@ from collections.abc import Callable from functools import partial -import itertools +import math from typing import cast +from jax._src import lib as jaxlib from jax._src.lib import mosaic_gpu_dialect as mgpu from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import arith -from 
jax._src.lib.mlir.dialects import builtin from jax._src.lib.mlir.dialects import gpu from jax._src.lib.mlir.dialects import memref from jax._src.lib.mlir.dialects import vector @@ -60,17 +60,76 @@ def _set_transform_attributes( op.attributes["out_transforms"] = ir.ArrayAttr.get(out_transforms) -def infer_transforms_for_wgmma_ref(ref_ty: ir.MemRefType) -> ir.ArrayAttr: +def _resolve_transforms( + transforms: ir.ArrayAttr | None, + other_transforms: ir.ArrayAttr | None, +) -> ir.ArrayAttr | None: + """Resolves two sets of competing transforms to a single compatible set. + + Args: + transforms: one optional set of transforms. + other_transforms: another optional set of transforms. + + Returns: + A single set of transforms that is compatible with both `transforms` and + `other_transforms`, or `None` if both transforms are `None`. + Raises: + NotImplementedError: if the two sets of transforms can't be resolved to a + single set. + """ + if transforms is None: + return other_transforms + + if other_transforms is None: + return transforms + + if len(transforms) != len(other_transforms): + raise NotImplementedError( + f"Conflicting transforms {transforms} != {other_transforms}." + ) + + new_transforms = [] + for a, b in zip(transforms, other_transforms, strict=True): + if a == b: + new_transforms.append(a) + elif mgpu.TileTransformAttr.isinstance(a) and mgpu.TileTransformAttr.isinstance(b): + a = mgpu.TileTransformAttr(a) + b = mgpu.TileTransformAttr(b) + if len(a.tiling) != len(b.tiling): + raise ValueError(f"Conflicting tile transforms {a} != {b}.") + new_tiling = [] + for tile_a, tile_b in zip(a.tiling, b.tiling): + new_tiling.append(math.gcd(tile_a, tile_b)) + new_transforms.append(mgpu.TileTransformAttr.get(new_tiling)) + else: + raise NotImplementedError(f"Unsupported transforms {a} and {b}") + + return ir.ArrayAttr.get(new_transforms) + + +def _transforms_from_uses(op: ir.OpView) -> ir.Attribute | None: + transforms = None + + for result_use in cast(ir.OpResult, op.result).uses: + consumer = result_use.owner + op_user = consumer.operands[result_use.operand_number] + user_transforms = inference_utils.in_transforms_for_operand( + consumer, op_user + ) + transforms = _resolve_transforms(transforms, user_transforms) + return transforms + + +def _infer_transforms_for_wgmma_ref( + ref_ty: ir.MemRefType, max_swizzle: mgpu.SwizzlingMode +) -> tuple[ir.ArrayAttr, mgpu.SwizzlingMode]: if len(ref_ty.shape) != 2: raise ValueError(f"Expected a 2D memref, got {ref_ty}") element_bytewidth = utils.bytewidth(ref_ty.element_type) strides, _ = ref_ty.get_strides_and_offset() - - if strides[0] < strides[1]: - raise NotImplementedError("Transpositions aren't handled yet.") - - minor_dim = ref_ty.shape[1] + transposed = strides[0] < strides[1] + minor_dim = ref_ty.shape[0 if transposed else 1] major_tiling = 8 # Try tiling with all swizzling modes starting from the largest one. @@ -80,29 +139,45 @@ def infer_transforms_for_wgmma_ref(ref_ty: ir.MemRefType) -> ir.ArrayAttr: mgpu.SwizzlingMode.k32ByteSwizzle, mgpu.SwizzlingMode.kNoSwizzle, ]: + if swizzle > max_swizzle: + continue swizzle_elems = swizzle // element_bytewidth if minor_dim % swizzle_elems == 0: minor_tiling = swizzle_elems + inferred_swizzle = swizzle break else: # No valid tile transform can be inferred. 
- raise ValueError( - f"{ref_ty.shape} is not a valid WGMMA shape" - ) + raise ValueError(f"{ref_ty.shape} is not a valid WGMMA shape") - return ir.ArrayAttr.get([ - mgpu.TileTransformAttr.get((major_tiling, minor_tiling)), - mgpu.SwizzleTransformAttr.get(minor_tiling * element_bytewidth), - ]) + if transposed: + tiling = (minor_tiling, major_tiling) + else: + tiling = (major_tiling, minor_tiling) + return ( + ir.ArrayAttr.get([ + mgpu.TileTransformAttr.get(tiling), + mgpu.SwizzleTransformAttr.get(minor_tiling * element_bytewidth), + ]), + inferred_swizzle, + ) @partial(_add_transform_inference_rule, mgpu.WGMMAOp) def infer_wgmma_transforms(op: mgpu.WGMMAOp) -> OptionalTransforms: - b_transforms = infer_transforms_for_wgmma_ref(ir.MemRefType(op.b.type)) + b_transforms, b_swizzle = _infer_transforms_for_wgmma_ref( + ir.MemRefType(op.b.type), max_swizzle=mgpu.SwizzlingMode.k128ByteSwizzle + ) if ir.MemRefType.isinstance(op.a.type): - a_transforms = infer_transforms_for_wgmma_ref( - cast(ir.MemRefType, op.a.type) + a_transforms, a_swizzle = _infer_transforms_for_wgmma_ref( + cast(ir.MemRefType, op.a.type), max_swizzle=b_swizzle ) + if a_swizzle != b_swizzle: + # The swizzle for a and b has to match. + b_transforms, b_swizzle = _infer_transforms_for_wgmma_ref( + ir.MemRefType(op.b.type), max_swizzle=a_swizzle + ) + assert a_swizzle == b_swizzle return [a_transforms, b_transforms], [] return [b_transforms], [] @@ -145,70 +220,42 @@ def _infer_vector_load_store_transforms( transforms = inference_utils.value_transforms(op.base) if layout == fa.WGMMA_LAYOUT: - layout_transforms = infer_transforms_for_wgmma_ref( - ir.MemRefType(op.base.type) + layout_transforms, _ = _infer_transforms_for_wgmma_ref( + ir.MemRefType(op.base.type), max_swizzle=mgpu.SwizzlingMode.k128ByteSwizzle ) - elif (isinstance(layout, fa.WGStridedFragLayout) or - isinstance(layout, fa.WGSplatFragLayout)): + elif ( + layout == fa.WGMMA_ROW_LAYOUT + or layout == fa.WGMMA_COL_LAYOUT + or isinstance(layout, fa.WGStridedFragLayout) + or isinstance(layout, fa.WGSplatFragLayout) + ): layout_transforms = None else: raise NotImplementedError( f"Got layout {layout} which is not yet supported" ) - if transforms is not None and layout_transforms is not None: - if transforms != layout_transforms: - raise NotImplementedError( - f"Conflicting transforms for {op.base} in {op}: " - f"{transforms} != {layout_transforms}." - ) - return [transforms], [] + transforms = _resolve_transforms(transforms, layout_transforms) + return None if transforms is None else ([transforms], []) - if transforms is not None: - return [transforms], [] - if layout_transforms is not None: - return [layout_transforms], [] +@partial(_add_transform_inference_rule, memref.StoreOp) +def _infer_memref_store_transforms(op: memref.StoreOp) -> OptionalTransforms: + # memref.store is only used for scalar operations, so there are no transforms. + ref_shape = ir.MemRefType(op.memref.type).shape + if ref_shape != [] and ref_shape != [1]: + raise NotImplementedError( + f"Only scalar memrefs are supported, got {ref_shape}" + ) return None -# TODO(bchetioui): remove this once jaxlib minimum version >= 0.5.2. 
-SliceSMEMOp = getattr(mgpu, "SliceSMEMOp", None) - -@partial(_add_transform_inference_rule, SliceSMEMOp) -def _infer_slice_smem_transforms(op: SliceSMEMOp) -> OptionalTransforms: - transforms = None - uses = cast(ir.OpResult, op.result).uses - - for op_operand_use in uses: - consumer = op_operand_use.owner - op_user = consumer.operands[op_operand_use.operand_number] - out_transforms = inference_utils.in_transforms_for_operand( - consumer, op_user - ) - if transforms is not None and out_transforms is not None: - if transforms != out_transforms: - raise NotImplementedError( - f"Conflicting transforms for {op_user} in {op}: " - f"{transforms} != {out_transforms}." - ) - elif out_transforms is not None: - transforms = out_transforms - +@partial(_add_transform_inference_rule, mgpu.SliceSMEMOp) +def _infer_slice_smem_transforms(op: mgpu.SliceSMEMOp) -> OptionalTransforms: + transforms = _transforms_from_uses(op) return None if transforms is None else ([], [transforms]) -# TODO(bchetioui,apaszke): this empty rule is necessary while Mosaic doesn't use -# the dialect in all cases. -# The rule is necessary in order to handle the lowering of `utils.memref_ptr` -# which is used in `_construct_smem_reftree`. -@partial(_add_transform_inference_rule, builtin.UnrealizedConversionCastOp) -def _infer_unrealized_conversion_cast_transforms( - _: builtin.UnrealizedConversionCastOp, -) -> OptionalTransforms: - return None - - @partial(_add_transform_inference_rule, memref.ViewOp) def _infer_memref_view_transforms(op: memref.ViewOp) -> OptionalTransforms: if not isinstance(op.source.owner.opview, gpu.DynamicSharedMemoryOp): @@ -221,45 +268,157 @@ def _infer_memref_view_transforms(op: memref.ViewOp) -> OptionalTransforms: raise NotImplementedError( "memref view with in_transforms aren't yet supported" ) - uses = cast(ir.OpResult, op.result).uses - - for op_operand_use in uses: - consumer = op_operand_use.owner - op_user = consumer.operands[op_operand_use.operand_number] - out_transforms = inference_utils.in_transforms_for_operand( - consumer, op_user - ) - if transforms is not None and out_transforms is not None: - if transforms != out_transforms: - raise ValueError( - f"Conflicting transforms for {op_user} in {op}: " - f"{transforms} != {out_transforms}." - ) - elif out_transforms is not None: - transforms = out_transforms + transforms = _transforms_from_uses(op) # TODO(bchetioui): do we actually need to assign a transform to the input of # the view op? Presumably, it'll only be used to access scratch memory. return None if transforms is None else ([], [transforms]) -# TODO(bchetioui,apaszke): this empty rule is necessary while Mosaic doesn't use -# the dialect in all cases. -@partial(_add_transform_inference_rule, gpu.DynamicSharedMemoryOp) -def _infer_dynamic_smem_transforms( - _: gpu.DynamicSharedMemoryOp, +def _get_tile_and_swizzle_transforms( + transforms: ir.ArrayAttr | None, +) -> tuple[ir.Attribute, ir.Attribute]: + if transforms is None: + return + + if len(transforms) == 2: + tile_transform, swizzle_transform = transforms + if not ( + mgpu.TileTransformAttr.isinstance(tile_transform) + and mgpu.SwizzleTransformAttr.isinstance(swizzle_transform) + ): + raise NotImplementedError(f"Unsupported transforms {transforms}.") + return tile_transform, swizzle_transform + else: + raise NotImplementedError(f"Unsupported transforms {transforms}.") + + +# This is used by Pallas' "_handle_indexing" memory transform. 
+@partial(_add_transform_inference_rule, memref.SubViewOp) +def _infer_memref_subview_transforms( + op: memref.SubViewOp, ) -> OptionalTransforms: - return None + transforms = _transforms_from_uses(op) + in_transforms = inference_utils.value_transforms(op.source) + transforms = _resolve_transforms(transforms, in_transforms) + + if transforms is None: + return None + + # Here, we have some transforms to propagate one way or the other. For now, + # we implement only the following basic propagation rules: + # - A tile transform can be propagated bidirectionally if the axes being + # tiled are not sliced, and are the logical minor axes of the source. + # - A swizzle transform can be propagated towards the input of a subview if + # the physical minormost dimension is unchanged. + # - We only propagate transforms if they consist of a single tile transform + # and a single swizzle transform. + # TODO(bchetioui): implement more complex propagation rules. + tile_transform, swizzle_transform = _get_tile_and_swizzle_transforms(transforms) + + # Check swizzle transform propagation. + strides, _ = ir.MemRefType.get_strides_and_offset(op.source.type) + minor_dim = strides.index(min(strides)) + if op.source.type.shape[minor_dim] != op.static_sizes[minor_dim]: + raise NotImplementedError( + "Swizzle transforms can only be propagated if the minor dimension is " + "unchanged." + ) + # Check tile transform propagation. + old_tiling = mgpu.TileTransformAttr(tile_transform).tiling + num_tiled_axes = len(old_tiling) + last_n_dims = op.source.type.shape[-num_tiled_axes:] + last_n_sizes = list(op.static_sizes)[-num_tiled_axes:] + last_n_offsets = list(op.static_offsets)[-num_tiled_axes:] -def _should_have_transforms(op: ir.OpView) -> bool: - """Returns 'True' if the operation should be assigned in/out transforms.""" - return any( - map( - inference_utils.is_transformable_smem_memref, - itertools.chain(op.operands, op.results), + if any(ir.ShapedType.is_dynamic_size(x) for x in last_n_sizes): + raise NotImplementedError( + "Subview transforms with dynamic sizes are not supported." + ) + + dynamic_index = 0 + for i in range(len(last_n_offsets)): + if ir.ShapedType.is_dynamic_size(last_n_offsets[i]): + if utils.is_known_divisible( + op.offsets[dynamic_index], last_n_sizes[i] + ): + last_n_offsets[i] = last_n_sizes[i] + else: + # This will force a tiling of 1 along this axis. This is a safe choice + # (since we couldn't infer a better one) but might not be optimal.
+ last_n_offsets[i] = 1 + dynamic_index += 1 + + new_tiling = [ + math.gcd(*xs) + for xs in zip( + last_n_sizes, last_n_dims, last_n_offsets, old_tiling, strict=True ) + ] + + new_transforms = ir.ArrayAttr.get( + [mgpu.TileTransformAttr.get(new_tiling), swizzle_transform] ) + return [new_transforms], [new_transforms] + + +@partial(_add_transform_inference_rule, memref.TransposeOp) +def _infer_memref_transpose_transforms( + op: memref.TransposeOp, +) -> OptionalTransforms: + in_ty = ir.MemRefType(op.in_.type) + if len(in_ty.shape) != 2: + raise NotImplementedError(f"Only 2D memrefs are supported, got {in_ty}") + in_strides, _ = in_ty.get_strides_and_offset() + out_strides, _ = ir.MemRefType(op.result.type).get_strides_and_offset() + transpose = in_strides != out_strides + + out_transforms = _transforms_from_uses(op) + in_transforms = [] + if not transpose: + in_transforms = out_transforms + else: + tile_transform, swizzle_transform = _get_tile_and_swizzle_transforms( + out_transforms + ) + transposed_tiling = mgpu.TileTransformAttr(tile_transform).tiling[::-1] + in_transforms.append(mgpu.TileTransformAttr.get(transposed_tiling)) + in_transforms.append(swizzle_transform) + + return [ir.ArrayAttr.get(in_transforms)], [out_transforms] + + +# `memref.load` is used to load barrier phases---the rule needn't do anything +# interesting, but we need to have it in order to avoid crashing on it. +@partial(_add_transform_inference_rule, memref.LoadOp) +def _infer_memref_load_transforms(op: memref.LoadOp) -> OptionalTransforms: + if not ir.MemRefType(op.memref.type).shape: + # memref.load returns a scalar, so there is nothing interesting to do here. + return None + raise NotImplementedError("Non-scalar memref.load transforms") + + +@partial(_add_transform_inference_rule, memref.CastOp) +def _infer_memref_cast_transforms( + op: memref.CastOp, +) -> OptionalTransforms: + transforms = _transforms_from_uses(op) + in_transforms = inference_utils.value_transforms(op.source) + transforms = _resolve_transforms(transforms, in_transforms) + if transforms is None: + return None + return [transforms], [transforms] + + +# TODO(dasenov): Remove this after the minimal jaxlib version is 0.6.2. +if jaxlib.version >= (0, 6, 2): + @partial(_add_transform_inference_rule, mgpu.WithTransformsOp) + def _infer_mgpu_with_transforms_transforms( + op: mgpu.WithTransformsOp, + ) -> OptionalTransforms: + # Do not change the manually provided transforms. + return [op.transforms], [op.transforms] def infer_transforms(module: ir.Module): @@ -275,7 +434,7 @@ def infer_transforms(module: ir.Module): annotate the same memref. """ def inference_step(op: ir.Operation): - if not _should_have_transforms(op): + if not inference_utils.should_have_transforms(op): return elif inference_rule := _transform_inference_rules.get(op.OPERATION_NAME, None): # pytype: disable=attribute-error pass @@ -288,14 +447,43 @@ def inference_step(op: ir.Operation): _set_transform_attributes(op, *maybe_transforms) - # It's enough to do a single backwards propagation (starting from vector - # users), and then a single forward propagation (to feed into the async loads - # and stores). - for op in module.body: - inference_utils.traverse_op( - op, inference_step, inference_utils.TraversalOrder.BACKWARDS - ) + # We alternate a few backwards propagation (starting from vector users), and + # forward propagation (to feed into the async loads and stores) passes in + # order to enable more complex inference situations. 
+ # + # TODO(bchetioui): Replace this with a more generic inference. + inference_passes = [ + inference_utils.TraversalOrder.BACKWARDS, + inference_utils.TraversalOrder.FORWARD, + inference_utils.TraversalOrder.BACKWARDS, + inference_utils.TraversalOrder.FORWARD, + ] + for traversal_order in inference_passes: + for op in module.body: + inference_utils.traverse_op(op, inference_step, traversal_order) + + # All ops that should have transforms but have no transforms inferred so far + # are assigned an empty sets of transforms. E.g., this happens in kernels with + # only pointwise operations. + def set_empty_transforms(op: ir.Operation): + if ( + inference_utils.should_have_transforms(op) + and not inference_utils.has_in_transforms_set(op) + and not inference_utils.has_out_transforms_set(op) + ): + ins = [ + ir.ArrayAttr.get([]) + for o in op.operands + if inference_utils.is_transformable_smem_memref(o) + ] + outs = [ + ir.ArrayAttr.get([]) + for r in op.results + if inference_utils.is_transformable_smem_memref(r) + ] + _set_transform_attributes(op, ins, outs) + for op in module.body: inference_utils.traverse_op( - op, inference_step, inference_utils.TraversalOrder.FORWARD + op, set_empty_transforms, inference_utils.TraversalOrder.FORWARD ) diff --git a/jax/experimental/mosaic/gpu/utils.py b/jax/experimental/mosaic/gpu/utils.py index 28534cf4025b..678aad0c91c9 100644 --- a/jax/experimental/mosaic/gpu/utils.py +++ b/jax/experimental/mosaic/gpu/utils.py @@ -40,6 +40,7 @@ # mypy: ignore-errors +WARP_SIZE: int = 32 WARPGROUP_SIZE: int = 128 DYNAMIC = -9223372036854775808 DYNAMIC32 = -2147483648 @@ -64,6 +65,9 @@ def gpu_address_space_to_nvptx(address_space: gpu.AddressSpace) -> int: def ptr_as_memref(ptr, memref_ty: ir.MemRefType, ptr_memory_space: int | None = None): + strides, offset = memref_ty.get_strides_and_offset() + if offset != 0: + raise ValueError("Non-zero offset is not supported for ptr_as_memref") i64 = ir.IntegerType.get_signless(64) rank = len(memref_ty.shape) ptr_ty = "ptr" if ptr_memory_space is None else f"ptr<{ptr_memory_space}>" @@ -84,7 +88,7 @@ def ptr_as_memref(ptr, memref_ty: ir.MemRefType, ptr_memory_space: int | None = desc = llvm.InsertValueOp( desc, llvm.ConstantOp(i64, ir.IntegerAttr.get(i64, s)), [3, i] ) - for i, s in enumerate(get_contiguous_strides(memref_ty.shape)): + for i, s in enumerate(strides): desc = llvm.InsertValueOp( desc, llvm.ConstantOp(i64, ir.IntegerAttr.get(i64, s)), [4, i] ) @@ -99,7 +103,7 @@ def pack_array(values): ptr_ty = ir.Type.parse("!llvm.ptr") arr_ptr = llvm.alloca(ptr_ty, c(len(values), i64), elem_ty) for i, v in enumerate(values): - elem_ptr = llvm.getelementptr(ptr_ty, arr_ptr, [], [i], elem_ty) + elem_ptr = llvm.getelementptr(ptr_ty, arr_ptr, [], [i], elem_ty, llvm.GEPNoWrapFlags.none) llvm.store(v, elem_ptr) return arr_ptr @@ -135,12 +139,16 @@ def _debug_scalar_ty_format(arg): return "%llu", arg if ir.F32Type.isinstance(arg.type): return "%f", arg - if ir.F16Type.isinstance(arg.type): + if ir.BF16Type.isinstance(arg.type) or ir.F16Type.isinstance(arg.type): arg = arith.extf(ir.F32Type.get(), arg) return "%f", arg raise NotImplementedError(f"Can't print the type {arg.type}") -def debug_print(fmt, *args, uniform=True): +def debug_print(fmt, *args, uniform=True, scope=None): + if not uniform and scope is not None: + raise ValueError("Cannot specify scope to a non-uniform debug_print.") + if scope is None: + scope = ThreadSubset.WARPGROUP type_formats = [] new_args = [] for arg in args: @@ -164,7 +172,7 @@ def debug_print(fmt, *args, 
uniform=True): raise NotImplementedError(arg.type) type_formats.append(ty_format) ctx = ( - functools.partial(single_thread, per_block=False) + functools.partial(single_thread, scope=scope) if uniform else contextlib.nullcontext ) @@ -222,15 +230,19 @@ def when(cond): scf.yield_([]) -def thread_idx(): +def _3d_to_1d_idx(dim_idx_fn, dim_size_fn): i32 = ir.IntegerType.get_signless(32) as_i32 = lambda x: arith.index_cast(i32, x) - tidx = as_i32(gpu.thread_id(gpu.Dimension.x)) - stride = as_i32(gpu.block_dim(gpu.Dimension.x)) + idx = as_i32(dim_idx_fn(gpu.Dimension.x)) + stride = as_i32(dim_size_fn(gpu.Dimension.x)) for dim in (gpu.Dimension.y, gpu.Dimension.z): - tidx = arith.addi(tidx, arith.muli(as_i32(gpu.thread_id(dim)), stride)) - stride = arith.muli(stride, as_i32(gpu.block_dim(dim))) - return tidx + idx = arith.addi(idx, arith.muli(as_i32(dim_idx_fn(dim)), stride)) + stride = arith.muli(stride, as_i32(dim_size_fn(dim))) + return idx + + +thread_idx = functools.partial(_3d_to_1d_idx, gpu.thread_id, gpu.block_dim) +block_idx = functools.partial(_3d_to_1d_idx, gpu.block_id, gpu.grid_dim) def _warp_bcast(val, lane_idx=0): @@ -258,33 +270,43 @@ def warpgroup_idx(sync=True): class ThreadSubset(enum.IntEnum): + WARP = enum.auto() WARPGROUP = enum.auto() BLOCK = enum.auto() -# True withon `once()` contexts. +# True within `once()` contexts. _ONCE_PER: ThreadSubset | None = None -def single_thread_predicate(per_block=True): +def single_thread_predicate(scope: ThreadSubset = ThreadSubset.BLOCK): + """Returns a predicate that selects a single thread. + + Args: + scope: What level of the thread hierarchy to select a thread from. + For example, if the scope is BLOCK, only one thread per block will be + selected. + """ + elected = nvvm.elect_sync(ir.IntegerType.get_signless(1)) + if scope == ThreadSubset.WARP: + return elected warp = warp_idx() - if not per_block: + if scope is not ThreadSubset.BLOCK: warp = arith.remui(warp, c(4, warp.type)) first_warp = arith.cmpi(arith.CmpIPredicate.eq, warp, c(0, warp.type)) - elected = nvvm.elect_sync(ir.IntegerType.get_signless(1)) return arith.andi(first_warp, elected) @contextlib.contextmanager -def single_thread(per_block=True): +def single_thread(scope: ThreadSubset = ThreadSubset.BLOCK): """Runs the context only from a single thread. Args: - per_block: If True, only one thread per block will run the context. - Otherwise, only one thread per warp group will run the context. + scope: What level of the thread hierarchy to select a thread from. + For example, if the scope is BLOCK, only one thread per block will be + selected. """ global _ONCE_PER - scope = ThreadSubset.BLOCK if per_block else ThreadSubset.WARPGROUP # If we're already in a single-thread context, we don't have to do anything. if _ONCE_PER is not None and _ONCE_PER >= scope: yield @@ -293,7 +315,7 @@ def single_thread(per_block=True): prev_scope = _ONCE_PER _ONCE_PER = scope try: - if_op = scf.IfOp(single_thread_predicate(per_block)) + if_op = scf.IfOp(single_thread_predicate(scope)) with ir.InsertionPoint(if_op.then_block): yield scf.YieldOp([]) @@ -446,7 +468,7 @@ def fold_until(shape, off , target) -> tuple[int, int]: # TODO(cperivol): Implement dependent fold-unfolds for subsections # of the shape eg (..., 4,5,5, ...) -> (..., 10,10, ...) could be # supported without touching any other dimensions. 
- raise NotImplementedError(f"Can't reshape {sh0} to {sh1} bu composing independent folds/unfolds.") + raise NotImplementedError(f"Can't reshape {sh0} to {sh1} by composing independent folds/unfolds.") raise AssertionError(f"Unreachable: number of elements don't match in each shape ({sh0} ans {sh1})") @@ -497,12 +519,42 @@ def memref_reshape(ref: ir.Value, shape: tuple[int, ...]) -> ir.Value: f" allowed) {shape}" ) - return _reshape(ref, list(ref_ty.shape), list(shape)) + src_shape = list(ref_ty.shape) + dst_shape = list(shape) + if src_shape == dst_shape: + return ref + if not src_shape: + _, offset = ref_ty.get_strides_and_offset() + identity = ir.AffineMapAttr.get(ir.AffineMap.get_identity(0)) + if ref_ty.layout == identity: + new_layout = ir.AffineMapAttr.get(ir.AffineMap.get_identity(len(dst_shape))) + else: + new_layout = ir.StridedLayoutAttr.get(offset, [1] * len(dst_shape)) + result_ty = ir.MemRefType.get(dst_shape, ref_ty.element_type, new_layout, ref_ty.memory_space) + return memref.expand_shape(result_ty, ref, [], [], dst_shape) + if not dst_shape: + _, offset = ref_ty.get_strides_and_offset() + identity = ir.AffineMapAttr.get(ir.AffineMap.get_identity(ref_ty.rank)) + contig_strided_1d = ir.Attribute.parse("strided<[1]>") + if ref_ty.layout == identity or ref_ty.layout == contig_strided_1d: + new_layout = ir.AffineMapAttr.get(ir.AffineMap.get_identity(0)) + else: + new_layout = ir.StridedLayoutAttr.get(offset, []) + result_ty = ir.MemRefType.get((), ref_ty.element_type, new_layout, ref_ty.memory_space) + return memref.collapse_shape(result_ty, ref, []) + return _reshape(ref, src_shape, dst_shape) def memref_fold(ref: ir.Value, dim, fold_rank) -> ir.Value: ref_ty = ir.MemRefType(ref.type) new_shape = list(ref_ty.shape) + if dim < 0: + raise ValueError(f"Dimension {dim} is negative") + if dim + fold_rank > len(new_shape): + raise ValueError( + f"Folding {fold_rank} dimensions starting from {dim} is out of bounds" + f" for shape {new_shape}" + ) new_shape[dim : dim + fold_rank] = [np.prod(new_shape[dim : dim + fold_rank])] identity = ir.AffineMapAttr.get(ir.AffineMap.get_identity(ref_ty.rank)) contig_strided_1d = ir.Attribute.parse("strided<[1]>") @@ -545,7 +597,8 @@ def memref_unfold(ref: ir.Value, dim, factors) -> ir.Value: ) new_shape[dim : dim + 1] = factors identity = ir.AffineMapAttr.get(ir.AffineMap.get_identity(ref_ty.rank)) - if ref_ty.layout == identity: + contig_strided_1d = ir.Attribute.parse("strided<[1]>") + if ref_ty.layout == identity or ref_ty.layout == contig_strided_1d: new_layout = ir.AffineMapAttr.get( ir.AffineMap.get_identity(ref_ty.rank + len(factors) - 1) ) @@ -691,6 +744,9 @@ def warpgroup_barrier(): has_side_effects=True, ) +def warp_barrier(): + nvvm.bar_warp_sync(c(0xffffffff, ir.IntegerType.get_signless(32))) + @dataclasses.dataclass(frozen=True) class BarrierRef: @@ -700,18 +756,23 @@ class BarrierRef: num_barriers: int @staticmethod - def initialize(address: ir.Value, num_barriers: int, arrival_count: int = 1) -> "BarrierRef": + def initialize(barrier_memref: ir.Value, arrival_count: int = 1) -> "BarrierRef": + barrier_ty = ir.MemRefType(barrier_memref.type) + [num_barriers] = barrier_ty.shape if num_barriers > 32: raise NotImplementedError("Only up to 32 barriers per group supported") i32 = ir.IntegerType.get_signless(32) i64 = ir.IntegerType.get_signless(64) ptr = ir.Type.parse(f"!llvm.ptr<{WORKGROUP_NVPTX_ADDRESS_SPACE}>") + address = memref_ptr( + barrier_memref, memory_space=WORKGROUP_NVPTX_ADDRESS_SPACE + ) phases = 
memref.alloca(ir.MemRefType.get((), i32), [], []) memref.store(c(0, i32), phases, []) - with single_thread(per_block=True): + with single_thread(scope=ThreadSubset.BLOCK): for i in range(num_barriers): nvvm.mbarrier_init_shared( - llvm.getelementptr(ptr, address, [], [i], i64), + llvm.getelementptr(ptr, address, [], [i], i64, llvm.GEPNoWrapFlags.none), c(arrival_count, i32), ) return BarrierRef(address, c(0, i32), phases, num_barriers) @@ -764,9 +825,27 @@ def update_parities(self, parities: ir.Value) -> tuple[ir.Value, ir.Value]: ) return parity, arith.xori(parities, bitmask) - def arrive(self): + def arrive( + self, + arrival_count: int = 1, + can_complete: bool = True, + for_tensor_core: bool = False, + ): i64 = ir.IntegerType.get_signless(64) - nvvm.mbarrier_arrive_shared(i64, self.get_ptr()) + if for_tensor_core: + llvm.inline_asm( + ir.Type.parse("!llvm.void"), + [], "tcgen05.fence::before_thread_sync;", "", + has_side_effects=True, + ) + if can_complete: + if arrival_count > 1: + count = c(arrival_count - 1, ir.IntegerType.get_signless(32)) + nvvm.mbarrier_arrive_nocomplete_shared(i64, self.get_ptr(), count) + nvvm.mbarrier_arrive_shared(i64, self.get_ptr()) + else: + count = c(arrival_count, ir.IntegerType.get_signless(32)) + nvvm.mbarrier_arrive_nocomplete_shared(i64, self.get_ptr(), count) def arrive_expect_tx( self, bytes: int | ir.Value, predicate: ir.Value | None = None @@ -783,11 +862,71 @@ def get_ptr(self): i64 = ir.IntegerType.get_signless(64) DYNAMIC32 = -2147483648 return llvm.getelementptr( - ptr, self.base_address, [self.offset], [DYNAMIC32], i64 + ptr, self.base_address, [self.offset], [DYNAMIC32], i64, llvm.GEPNoWrapFlags.none ) - def as_dialect_barrier_memref(self) -> ir.Value: - shape = () if self.num_barriers == 1 else (self.num_barriers,) + +@dataclasses.dataclass(frozen=True) +class DialectBarrierRef: + barrier_ref: BarrierRef + + @staticmethod + def initialize( + barrier_memref: ir.Value, + arrival_count: int = 1, + ) -> "DialectBarrierRef": + barrier_ty = ir.MemRefType(barrier_memref.type) + [num_barriers] = barrier_ty.shape + if num_barriers > 32: + raise NotImplementedError("Only up to 32 barriers per group supported") + + address = memref_ptr( + barrier_memref, memory_space=WORKGROUP_NVPTX_ADDRESS_SPACE + ) + dialect.InitializeBarrierOp( + barrier_ty, base_pointer=address, arrival_count=arrival_count + ) + + i32 = ir.IntegerType.get_signless(32) + phases = memref.alloca(ir.MemRefType.get((), i32), [], []) + memref.store(c(0, i32), phases, []) + return DialectBarrierRef( + barrier_ref=BarrierRef(address, c(0, i32), phases, num_barriers) + ) + + def __iter__(self) -> Iterator["DialectBarrierRef"]: + if self.barrier_ref.num_barriers == 1: + yield self + else: + for offset in range(self.barrier_ref.num_barriers): + yield self[offset] + + def __getitem__(self, offset: ir.Value | int) -> "DialectBarrierRef": + return DialectBarrierRef(self.barrier_ref[offset]) + + def wait_parity(self, parity, for_tensor_core=False): + self.barrier_ref.wait_parity(parity, for_tensor_core) + + def wait(self, for_tensor_core: bool = False): + assert self.barrier_ref.phases is not None + self.barrier_ref.wait(for_tensor_core) + + def update_parities(self, parities: ir.Value) -> tuple[ir.Value, ir.Value]: + return self.barrier_ref.update_parities(parities) + + def arrive(self): + self.barrier_ref.arrive() + + def arrive_expect_tx(self, bytes: int | ir.Value): + dialect.ArriveExpectTxOp( + barrier=self.as_barrier_memref(), expect_tx=bytes) + + def get_ptr(self): + return 
self.barrier_ref.get_ptr() + + def as_barrier_memref(self) -> ir.Value: + num_barriers = self.barrier_ref.num_barriers + shape = () if num_barriers == 1 else (num_barriers,) return ptr_as_memref( self.get_ptr(), ir.MemRefType.get(shape, ir.Type.parse("!mosaic_gpu.barrier")), @@ -795,8 +934,8 @@ def as_dialect_barrier_memref(self) -> ir.Value: ) @classmethod - def from_dialect_barrier_memref(cls, barrier: ir.Value): - """Creates a BarrierRef from a memref of a dialect barrier.""" + def from_barrier_memref(cls, barrier: ir.Value): + """Creates a DialectBarrierRef from a memref of a dialect barrier.""" memref_type = ir.MemRefType(barrier.type) if memref_type.rank > 1 or memref_type.element_type != ir.Type.parse( "!mosaic_gpu.barrier" @@ -807,15 +946,16 @@ def from_dialect_barrier_memref(cls, barrier: ir.Value): ) return cls( - base_address=memref_ptr( - barrier, memory_space=WORKGROUP_NVPTX_ADDRESS_SPACE - ), - offset=c(0, ir.IntegerType.get_signless(64)), - phases=None, - num_barriers=(1 if memref_type.rank == 0 else memref_type.shape[0]), + barrier_ref=BarrierRef( + base_address=memref_ptr( + barrier, memory_space=WORKGROUP_NVPTX_ADDRESS_SPACE + ), + offset=c(0, ir.IntegerType.get_signless(64)), + phases=None, + num_barriers=(1 if memref_type.rank == 0 else memref_type.shape[0]), + ) ) - @dataclasses.dataclass(frozen=True) class CollectiveBarrierRef: barrier: BarrierRef @@ -823,8 +963,7 @@ class CollectiveBarrierRef: @staticmethod def initialize( - address: ir.Value, - num_barriers: int, + barrier_memref: ir.Value, dims: Sequence[gpu.Dimension | Sequence[gpu.Dimension]], cluster_shape: tuple[int, int, int], ) -> "CollectiveBarrierRef": @@ -852,7 +991,7 @@ def initialize( cluster_mask = arith.ori( cluster_mask, cluster_collective_mask(cluster_shape, d) ) - barrier = BarrierRef.initialize(address, num_barriers, arrival_count=arrival_count) + barrier = BarrierRef.initialize(barrier_memref, arrival_count=arrival_count) return CollectiveBarrierRef(barrier, cluster_mask) def __iter__(self): @@ -862,15 +1001,21 @@ def __iter__(self): def __getitem__(self, offset): return CollectiveBarrierRef(self.barrier[offset], self.cluster_mask) - def arrive(self): + def arrive(self, for_tensor_core: bool = False): """Arrives on a barrier in all blocks that share at least one of the coordinates along the collective dimensions. Note that unlike in arrive, each warpgroup arrives once. 
""" + if for_tensor_core: + llvm.inline_asm( + ir.Type.parse("!llvm.void"), + [], "tcgen05.fence::before_thread_sync;", "", + has_side_effects=True, + ) if self.barrier.num_barriers != 1: raise ValueError("Can only arrive on a single barrier") if self.cluster_mask is None: - with single_thread(per_block=False): + with single_thread(scope=ThreadSubset.WARPGROUP): self.barrier.arrive() return i32 = ir.IntegerType.get_signless(32) @@ -909,6 +1054,67 @@ def wait_parity(self, *args, **kwargs): self.barrier.wait_parity(*args, **kwargs) +@dataclasses.dataclass(frozen=True) +class SemaphoreRef: + ptr: ir.Value + + def signal(self, value: ir.Value | int, predicate: ir.Value | None = None): + i32 = ir.IntegerType.get_signless(32) + if not isinstance(value, ir.Value): + value = c(value, i32) + elif value.type != i32: + raise ValueError(f"Expected a i32 value, got {value.type}") + if predicate is None: + predicate = single_thread_predicate(ThreadSubset.WARPGROUP) + llvm.inline_asm( + i32, + [self.ptr, value, predicate], + "@$3 atom.add.release.sys.global.u32 $0, [$1], $2;", + "=r,l,r,b", + has_side_effects=True, + ) + + def wait( + self, + value: ir.Value | int = 1, + scope: ThreadSubset = ThreadSubset.WARPGROUP, + ): + i32 = ir.IntegerType.get_signless(32) + if not isinstance(value, ir.Value): + value = c(value, i32) + elif value.type != i32: + raise ValueError(f"Expected a i32 value, got {value.type}") + + ne_pred = arith.CmpIPredicate.ne + + with single_thread(scope=scope): + # Create the while loop for busy waiting + while_op = scf.WhileOp([i32], [value]) + before_block = while_op.before.blocks.append(i32) + with ir.InsertionPoint.at_block_begin(before_block): + [expected_in_memory] = before_block.arguments + new_val = arith.subi(expected_in_memory, value) + in_memory = llvm.inline_asm( + i32, + [self.ptr, expected_in_memory, new_val], + "atom.acquire.sys.global.cas.b32 $0, [$1], $2, $3;", + "=r,l,r,r", + has_side_effects=True, + ) + comparison = arith.cmpi(ne_pred, in_memory, expected_in_memory) + new_expected_in_memory = arith.maxui(in_memory, value) + scf.condition(comparison, [new_expected_in_memory]) + after_block = while_op.after.blocks.append(i32) + with ir.InsertionPoint.at_block_begin(after_block): + scf.yield_(after_block.arguments) + if scope == ThreadSubset.WARPGROUP: + warpgroup_barrier() + elif scope == ThreadSubset.WARP: + warp_barrier() + else: + raise ValueError(f"Unsupported scope: {scope}") + + class Partition: source_bounds: tuple[int, ...] target_bounds: tuple[int, ...] @@ -1171,7 +1377,7 @@ def getelementptr( ) -> ir.Value: static_indices = [i if isinstance(i, int) else DYNAMIC32 for i in indices] dyn_indices = [i for i in indices if not isinstance(i, int)] - return llvm.getelementptr(ptr.type, ptr, dyn_indices, static_indices, dtype) + return llvm.getelementptr(ptr.type, ptr, dyn_indices, static_indices, dtype, llvm.GEPNoWrapFlags.none) def dyn_dot(x, y): @@ -1181,13 +1387,24 @@ def dyn_dot(x, y): def shfl_bfly(x: ir.Value, distance: int | ir.Value): i32 = ir.IntegerType.get_signless(32) + index = ir.IndexType.get() if isinstance(distance, int): distance = c(distance, i32) if (result_type := x.type) != i32: + if (x_bitwidth := bitwidth(x.type)) < 32: # Pad to 32-bits if necessary. 
+ x = bitcast(x, ir.IntegerType.get_signless(x_bitwidth)) + empty32 = llvm.mlir_undef(ir.VectorType.get((32 // x_bitwidth,), x.type)) + x = vector.insertelement(x, empty32, position=c(0, index)) + elif x_bitwidth != 32: + raise ValueError(f"Unsupported bitwidth {x_bitwidth}") x = bitcast(x, i32) y = nvvm.shfl_sync( i32, c(0xFFFFFFFF, i32), x, distance, c(0x1F, i32), nvvm.ShflKind.bfly, ) + if (x_bitwidth := bitwidth(result_type)) < 32: + bits_ty = ir.IntegerType.get_signless(x_bitwidth) + y_vec = bitcast(y, ir.VectorType.get((32 // x_bitwidth,), bits_ty)) + y = vector.extractelement(y_vec, position=c(0, index)) return bitcast(y, result_type) @@ -1210,6 +1427,11 @@ def prmt(high: ir.Value, low: ir.Value, permutation: ir.Value): def bitcast(x: ir.Value, new_type: ir.Type): if x.type == new_type: return x + if (x_bw := bitwidth(x.type)) != (new_bw := bitwidth(new_type)): + raise ValueError( + f"Can't bitcast {x.type} (of bitwidth {x_bw}) to {new_type} (of" + f" bitwidth {new_bw})" + ) if ir.VectorType.isinstance(x.type) and ir.IntegerType.isinstance(new_type): new_type = ir.IntegerType(new_type) x_ty = ir.VectorType(x.type) @@ -1229,6 +1451,12 @@ def bitcast(x: ir.Value, new_type: ir.Type): if bitwidth(x_ty) != bitwidth(new_ty): raise ValueError(f"Can't bitcast {x.type} to {new_type}") return vector.bitcast(new_type, x) + if ir.IntegerType.isinstance(x.type) and ir.FloatType.isinstance(new_type): + return arith.bitcast(new_type, x) + if ir.FloatType.isinstance(x.type) and ir.IntegerType.isinstance(new_type): + return arith.bitcast(new_type, x) + if ir.FloatType.isinstance(x.type) and ir.FloatType.isinstance(new_type): + return arith.bitcast(new_type, x) raise ValueError(f"Can't bitcast {x.type} to {new_type}") @@ -1270,3 +1498,43 @@ def vector_concat(vectors: Sequence[ir.Value]) -> ir.Value: result = vector.insertelement(elem, result, position=c(offset + i, index)) offset += vty.shape[0] return result + + +def is_known_divisible(value, divisor, max_depth=10) -> bool: + """Returns True if the value is statically known to be divisible by the divisor.""" + if divisor == 1: + return True + if max_depth < 0 or not isinstance(value.owner, ir.Operation): + return False + + new_depth = max_depth - 1 + def_op = value.owner.opview + + match def_op: + case arith.IndexCastOp(): + return is_known_divisible(value.owner.operands[0], divisor, max_depth - 1) + case arith.ConstantOp(): + return ir.IntegerAttr(def_op.value).value % divisor == 0 + case arith.MulIOp(): + # Only cover the case where one operand is divisible. It's still possible + # that the final product is divisible, but we don't check that here. + return (is_known_divisible(value.owner.operands[0], divisor, new_depth) or + is_known_divisible(value.owner.operands[1], divisor, new_depth)) + case arith.SelectOp(): + return (is_known_divisible(value.owner.operands[1], divisor, new_depth) and + is_known_divisible(value.owner.operands[2], divisor, new_depth)) + case arith.MaxSIOp() | arith.MinSIOp() | arith.MaxUIOp() | arith.MinUIOp(): + return (is_known_divisible(value.owner.operands[0], divisor, new_depth) and + is_known_divisible(value.owner.operands[1], divisor, new_depth)) + case arith.AddIOp() | arith.SubIOp(): + # Only cover the common case where both operads are divisible. + return (is_known_divisible(value.owner.operands[0], divisor, new_depth) and + is_known_divisible(value.owner.operands[1], divisor, new_depth)) + case arith.AndIOp(): + # Only cover the specific case where the divisor is a power of two. 
+ return divisor.bit_count() == 1 and ( + is_known_divisible(value.owner.operands[0], divisor, new_depth) + or is_known_divisible(value.owner.operands[1], divisor, new_depth) + ) + + return False diff --git a/jax/experimental/mosaic/gpu/wgmma.py b/jax/experimental/mosaic/gpu/wgmma.py index 8baa16d8a7e9..23d5174bb24b 100644 --- a/jax/experimental/mosaic/gpu/wgmma.py +++ b/jax/experimental/mosaic/gpu/wgmma.py @@ -63,7 +63,10 @@ def zero(cls, m, n, dtype=None, *, is_signed: bool | None = None): f32 = ir.F32Type.get() if dtype is None: dtype = f32 - zero = arith.constant(dtype, ir.FloatAttr.get(dtype, 0.0)) + if ir.IntegerType.isinstance(dtype): + zero = arith.constant(dtype, ir.IntegerAttr.get(dtype, 0)) + else: + zero = arith.constant(dtype, ir.FloatAttr.get(dtype, 0.0)) return cls( _value=fa.FragmentedArray.splat( zero, (m, n), fa.WGMMA_LAYOUT, is_signed=is_signed @@ -85,10 +88,13 @@ def tree_unflatten(cls, aux, value): def _supported_wgmma_types(dtype, abtype) -> bool: input_types_are = lambda ty: ty.isinstance(abtype) + f16_acc_types = (ir.F16Type, ir.Float8E5M2Type, ir.Float8E4M3FNType) if ir.F32Type.isinstance(dtype): - return any(input_types_are(ty) for ty in (ir.FloatTF32Type, ir.BF16Type, ir.F16Type)) + return any(input_types_are(ty) for ty in (ir.FloatTF32Type, ir.BF16Type, *f16_acc_types)) elif ir.F16Type.isinstance(dtype): - return input_types_are(ir.F16Type) + return any(input_types_are(ty) for ty in f16_acc_types) + elif ir.IntegerType.get_signless(32).isinstance(dtype): + return input_types_are(ir.IntegerType.get_signless(8)) else: return False @@ -107,7 +113,7 @@ def wgmma_m64( ): out_ty = ir.VectorType(acc.flat[0].type).element_type if not _supported_wgmma_types(out_ty, element_type): - raise ValueError(f"Usupported wgmma types {(out_ty, element_type)=}") + raise ValueError(f"Unsupported wgmma types {(out_ty, element_type)=}") if n % 8: raise ValueError @@ -134,7 +140,7 @@ def wgmma_m64( if a_transpose is None: raise ValueError - if ir.F32Type.isinstance(out_ty): + if ir.F32Type.isinstance(out_ty) or out_ty == i32: num_acc_regs = n // 2 out_ty_field = out_ty acc_regs = [ # pylint: disable=g-complex-comprehension @@ -142,8 +148,9 @@ def wgmma_m64( for reg in acc.flat for pos in range(2) ] - to_acc_vec_regs = functools.partial(_as_fragmented_reg_ndarray, dtype=out_ty, shape=acc.shape) - acc_constraint = "f" + to_acc_vec_regs = functools.partial( + _as_fragmented_reg_ndarray, dtype=out_ty, shape=acc.shape) + acc_constraint = "r" if ir.IntegerType.isinstance(out_ty) else "f" elif ir.F16Type.isinstance(out_ty): num_acc_regs = n // 4 out_ty_field = i32 @@ -152,9 +159,15 @@ def wgmma_m64( to_acc_vec_regs = lambda regs : np.array([_unpack_i32(vec_ty, reg) for reg in regs]).reshape(acc.shape) acc_constraint = "r" else: - raise ValueError(f"WGMMA instruciton only supports f32 and f16 out (got {out_ty})") + raise ValueError( + f"WGMMA instruction only supports f32, f16 and s32 out (got {out_ty})") - num_imm_regs = 4 if supports_transpose else 2 + if supports_transpose: + num_imm_regs = 4 + elif out_ty == i32: + num_imm_regs = 0 + else: + num_imm_regs = 2 if a_in_regs: a_reg_constraints = ["r"] * 4 # 4x f16x2 registers @@ -171,7 +184,6 @@ def wgmma_m64( + ["n"] * (1 + num_imm_regs) # literal constants ) reg_constraints = ",".join(reg_constraints_list) - reg_count = itertools.count() def take_regs(n): @@ -185,13 +197,28 @@ def take_regs(n): else: a_regs, = take_regs(1) b_desc_reg, use_out_reg = take_regs(2) - imm_regs = ", ".join(take_regs(num_imm_regs)) # Immediate regs (scale, ...). 
+ # Immediate regs (scale, ...). + imm_regs = "".join(f", {r}" for r in take_regs(num_imm_regs)) assert next(reg_count) == len(reg_constraints_list) - el_ty = element_type k_instr = 32 // bytewidth(element_type) + el_ty = str(element_type) + if ir.Float8E5M2Type.isinstance(element_type): + el_ty = "e5m2" + elif ir.Float8E4M3FNType.isinstance(element_type): + el_ty = "e4m3" + elif ir.IntegerType.get_signless(8).isinstance(element_type): + # TODO(bchetioui): add u8 support in the future. Currently we always assume + # that 8-bit integers are s8, and we would need to change the signature of + # `wgmma` to indicate whether the input should be treated as signed or not. + el_ty = "s8" + + out_ty_str = str(out_ty) + if out_ty == i32: + out_ty_str = "s32" + wgmma_instr = ( - f"wgmma.mma_async.sync.aligned.m64n{n}k{k_instr}.{out_ty}.{el_ty}.{el_ty} " - f"{acc_reg_vector}, {a_regs}, {b_desc_reg}, p, {imm_regs};" + f"wgmma.mma_async.sync.aligned.m64n{n}k{k_instr}.{out_ty_str}.{el_ty}.{el_ty} " + f"{acc_reg_vector}, {a_regs}, {b_desc_reg}, p{imm_regs};" ) ptx = f"{{ .reg .pred p; setp.ne.b32 p, {use_out_reg}, 0; {wgmma_instr} }}\n" @@ -199,12 +226,19 @@ def lc(x): return llvm.ConstantOp(i32, ir.IntegerAttr.get(i32, x)).result use_out = scale_a = scale_b = lc(1) - imms = [use_out, scale_a, scale_b] + if out_ty == i32: + imms = [use_out] + else: + imms = [use_out, scale_a, scale_b] + if supports_transpose and a_transpose is not None: imms += [lc(int(a_transpose)), lc(int(b_transpose))] elif supports_transpose: imms += [lc(int(b_transpose))] - if acc.ndim != 10 or acc.shape[0] != 1 or math.prod(acc.shape[2:]) != 2: + + assert len(imms) == num_imm_regs + 1 # +1 for the use_out_reg in setp.ne.b32 + + if acc.ndim != 9 or acc.shape[0] != 1 or math.prod(acc.shape[2:]) != 2: raise ValueError(acc.shape) acc_struct_type = ir.Type.parse( f"!llvm.struct<({','.join(str(out_ty_field) for _ in acc_regs)})>" @@ -291,18 +325,34 @@ def wgmma( f"Accumulator shape mismatch: expected {(m, n)}, got {acc.value.shape}" ) f32 = ir.F32Type.get() + f16 = ir.F16Type.get() + i32 = ir.IntegerType.get_signless(32) + i8 = ir.IntegerType.get_signless(8) if element_type == f32 or element_type == ir.BF16Type.get(): if acc.value.mlir_dtype != f32: raise ValueError( f"WGMMA with element type {element_type} only supports accumulators" f" of type f32, but got: {acc.value.mlir_dtype}" ) - elif element_type == ir.F16Type.get(): - if acc.value.mlir_dtype != element_type and acc.value.mlir_dtype != f32: + elif any( + t.isinstance(element_type) + for t in {ir.F16Type, ir.Float8E5M2Type, ir.Float8E4M3FNType} + ): + if acc.value.mlir_dtype != f16 and acc.value.mlir_dtype != f32: raise ValueError( - "WGMMA with element type f16 only supports accumulators of type f32" - f" or f16, but got: {acc.value.mlir_dtype}" + f"WGMMA with element type {element_type} only supports accumulators " + f"of type f32 or f16, but got: {acc.value.mlir_dtype}" ) + elif element_type == i8: + if a_in_regs and not a.is_signed: + raise NotImplementedError("WGMMA with lhs of type u8") + if acc.value.mlir_dtype != i32 or not acc.value.is_signed: + raise ValueError( + f"WGMMA with element type {element_type} only supports accumulators " + f"of type s32, but got: {acc.value.mlir_dtype}" + ) + else: + raise NotImplementedError(f"Unsupported element type: {element_type}") # Step 2. Decide on the instruction shapes we'll use. Note that with swizzles, # instructions must be issued in groups of the same width as the swizzle. 
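The integer WGMMA support added above changes only how the PTX instruction name and the immediate operand list are assembled. As an illustration (not part of the patch), the snippet below mirrors that formatting for the s8 x s8 -> s32 case, using the same `k_instr = 32 // bytewidth(element_type)` rule as the code in the hunk; the helper name `wgmma_instr_name` is made up for this sketch.

```python
# Illustrative sketch only: reproduces the instruction-name formatting used in
# wgmma_m64 above for the integer accumulator path.
def wgmma_instr_name(n: int, el_bytewidth: int = 1,
                     out_ty: str = "s32", el_ty: str = "s8") -> str:
  k_instr = 32 // el_bytewidth  # 32 bytes along K per wgmma instruction
  return f"wgmma.mma_async.sync.aligned.m64n{n}k{k_instr}.{out_ty}.{el_ty}.{el_ty}"

# For s8 inputs the accumulator is s32 and imms == [use_out], i.e. no
# scale_a/scale_b immediates are emitted.
assert wgmma_instr_name(128) == "wgmma.mma_async.sync.aligned.m64n128k32.s32.s8.s8"
```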
diff --git a/jax/experimental/multihost_utils.py b/jax/experimental/multihost_utils.py index 2bde1fbeadc4..ee7c4a8f9592 100644 --- a/jax/experimental/multihost_utils.py +++ b/jax/experimental/multihost_utils.py @@ -39,8 +39,8 @@ import numpy as np -def _psum(x: Any) -> Any: - return jax.tree.map(partial(jnp.sum, axis=0), x) +def _psum(xs: Any) -> Any: + return jax.tree.map(lambda x: jnp.sum(x, dtype=x.dtype, axis=0), xs) def broadcast_one_to_all(in_tree: Any, is_source: bool | None = None) -> Any: @@ -99,8 +99,11 @@ def _identity_fn(x): def _handle_array_process_allgather(inp, tiled): if isinstance(inp, array.ArrayImpl) and not inp.is_fully_addressable: - reps = sharding_impls.GSPMDSharding.get_replicated( - inp.sharding._device_assignment) + if isinstance(inp.sharding, sharding_impls.NamedSharding): + reps = inp.sharding.update(spec=P()) + else: + reps = sharding_impls.GSPMDSharding.get_replicated( + inp.sharding._device_assignment, memory_kind=inp.sharding.memory_kind) out = jax.jit(_identity_fn, out_shardings=reps)(inp) else: # All inputs here will be fully addressable. @@ -200,7 +203,7 @@ def should_save(step_id: int) -> bool: after some hosts are preempted. Raises: - RuntimeError: if preemption sync manager has not been inititialized. + RuntimeError: if preemption sync manager has not been initialized. """ if distributed.global_state.client is None: return False @@ -325,7 +328,7 @@ def host_local_array_to_global_array( >>> >>> host_local_output = multihost_utils.global_array_to_host_local_array(global_out, mesh, out_pspecs) # doctest: +SKIP - Please note ths function requires global mesh to be a continuous mesh, meaning + Please note this function requires global mesh to be a continuous mesh, meaning that devices that belong to each host should form a subcube in this mesh. To move local data to global array with non-continuous mesh use jax.make_array_from_callback or jax.make_array_from_single_device_arrays diff --git a/jax/experimental/ode.py b/jax/experimental/ode.py index db7865124687..bdbd52000b15 100644 --- a/jax/experimental/ode.py +++ b/jax/experimental/ode.py @@ -28,7 +28,7 @@ from functools import partial import operator as op -from typing import Callable +from collections.abc import Callable import jax from jax import api_util @@ -214,7 +214,7 @@ def body_fun(state): _, *carry = lax.while_loop(cond_fun, body_fun, [0] + carry) _, _, t, _, last_t, interp_coeff = carry relative_output_time = (target_t - last_t) / (t - last_t) - y_target = jnp.polyval(interp_coeff, relative_output_time.astype(interp_coeff.dtype)) + y_target = jnp.polyval(interp_coeff, relative_output_time.astype(interp_coeff.dtype)) # pytype: disable=attribute-error return carry, y_target f0 = func_(y0, ts[0]) diff --git a/jax/experimental/pallas/__init__.py b/jax/experimental/pallas/__init__.py index 1e0abacfc25f..5c0ef332454c 100644 --- a/jax/experimental/pallas/__init__.py +++ b/jax/experimental/pallas/__init__.py @@ -15,26 +15,33 @@ """Module for Pallas, a JAX extension for custom kernels. See the Pallas documentation at -https://jax.readthedocs.io/en/latest/pallas.html. +https://docs.jax.dev/en/latest/pallas.html. 
""" +from jax._src.pallas.core import BlockDim as BlockDim from jax._src.pallas.core import Blocked as Blocked from jax._src.pallas.core import BlockSpec as BlockSpec +from jax._src.pallas.core import BoundedSlice as BoundedSlice +from jax._src.pallas.core import Buffered as Buffered from jax._src.pallas.core import CompilerParams as CompilerParams from jax._src.pallas.core import core_map as core_map from jax._src.pallas.core import CostEstimate as CostEstimate +from jax._src.pallas.core import Element as Element from jax._src.pallas.core import GridSpec as GridSpec -from jax._src.pallas.core import IndexingMode as IndexingMode from jax._src.pallas.core import lower_as_mlir as lower_as_mlir from jax._src.pallas.core import MemoryRef as MemoryRef from jax._src.pallas.core import MemorySpace as MemorySpace -from jax._src.pallas.core import Buffered as Buffered from jax._src.pallas.core import no_block_spec as no_block_spec -from jax._src.pallas.core import Unblocked as Unblocked -from jax._src.pallas.core import unblocked as unblocked +from jax._src.pallas.core import semaphore as semaphore +from jax._src.pallas.core import Squeezed as Squeezed +from jax._src.pallas.core import squeezed as squeezed from jax._src.pallas.cost_estimate import estimate_cost as estimate_cost +from jax._src.pallas.helpers import debug_check as debug_check +from jax._src.pallas.helpers import debug_checks_enabled as debug_checks_enabled from jax._src.pallas.helpers import empty as empty from jax._src.pallas.helpers import empty_like as empty_like +from jax._src.pallas.helpers import enable_debug_checks as enable_debug_checks +from jax._src.pallas.helpers import loop as loop from jax._src.pallas.helpers import when as when from jax._src.pallas.pallas_call import pallas_call as pallas_call from jax._src.pallas.pallas_call import pallas_call_p as pallas_call_p @@ -47,6 +54,7 @@ from jax._src.pallas.primitives import atomic_xchg as atomic_xchg from jax._src.pallas.primitives import atomic_xor as atomic_xor from jax._src.pallas.primitives import debug_print as debug_print +from jax._src.pallas.primitives import DeviceIdType as DeviceIdType from jax._src.pallas.primitives import dot as dot from jax._src.pallas.primitives import load as load from jax._src.pallas.primitives import max_contiguous as max_contiguous @@ -55,6 +63,9 @@ from jax._src.pallas.primitives import program_id as program_id from jax._src.pallas.primitives import reciprocal as reciprocal from jax._src.pallas.primitives import run_scoped as run_scoped +from jax._src.pallas.primitives import semaphore_read as semaphore_read +from jax._src.pallas.primitives import semaphore_signal as semaphore_signal +from jax._src.pallas.primitives import semaphore_wait as semaphore_wait from jax._src.pallas.primitives import store as store from jax._src.pallas.primitives import swap as swap from jax._src.pallas.utils import cdiv as cdiv diff --git a/jax/experimental/pallas/fuser.py b/jax/experimental/pallas/fuser.py index 729a447b7408..d4ec7e89cc7d 100644 --- a/jax/experimental/pallas/fuser.py +++ b/jax/experimental/pallas/fuser.py @@ -19,6 +19,6 @@ from jax._src.pallas.fuser.block_spec import pull_block_spec as pull_block_spec from jax._src.pallas.fuser.block_spec import push_block_spec as push_block_spec from jax._src.pallas.fuser.custom_evaluate import evaluate as evaluate -from jax._src.pallas.fuser.fusable import fusable as fusable +from jax._src.pallas.fuser.fusible import fusible as fusible from jax._src.pallas.fuser.fusion import Fusion as Fusion from 
jax._src.pallas.fuser.jaxpr_fusion import fuse as fuse diff --git a/jax/experimental/pallas/g3doc/debugging.md b/jax/experimental/pallas/g3doc/debugging.md index 6dfa95eb16fa..f1f22999d3af 100644 --- a/jax/experimental/pallas/g3doc/debugging.md +++ b/jax/experimental/pallas/g3doc/debugging.md @@ -3,7 +3,7 @@ [TOC] @@ -16,10 +16,39 @@ a ticket on https://github.com/jax-ml/jax/issues. ### Interpret (HLO) Mode -Passing in `interpret=True` into `pl.pallas_call` will run the kernel in HLO instead of lowering to Mosaic/Triton. This is useful for checking correctness of your program and prototyping on smaller block sizes (as TPUs kernels require block sizes of at least 8x128). HLO is also more feature-complete so sometimes kernels will run in interpret mode but fail otherwise - this will make sure the bug is not in your kernel but in Pallas. +Passing in `interpret=True` into `pl.pallas_call` or `pl.core_map` will run the kernel in HLO instead of lowering to Mosaic/Triton. This is useful for checking correctness of your program and prototyping on smaller block sizes (as TPU kernels require block sizes of at least 8x128). HLO is also more feature-complete so sometimes kernels will run in interpret mode but fail otherwise - this will make sure the bug is not in your kernel but in Pallas. Note that interpret mode will not be able to fully replicate the behavior of programs that use communication (DMAs) between devices. This is because low-level communication APIs are more general than the interface that XLA provides via SPMD collective operations. +### TPU Interpret Mode + +TPU interpret mode is similar to [interpret (HLO) mode](#interpret-hlo-mode), +but TPU interpret mode explicitly simulates accesses to TPU memory (HBM, VMEM, +SMEM, etc.), communication via remote DMAs, TPU synchronization operations +(e.g., barriers and semaphores), and parallel execution of kernels distributed +across +[multiple TPUs](https://docs.jax.dev/en/latest/pallas/tpu/distributed.html) and +[Megacore cores](https://docs.jax.dev/en/latest/pallas/tpu/distributed.html#megacore). + +TPU interpret mode is slower than interpret (HLO) mode, but it can be useful for +developing and debugging distributed TPU kernels with explicit communication and +synchronization. With this mode, kernels can be run on CPU -- enabling local +development (with no TPU), using a debugger and inspecting the state of +simulated TPU buffers and semaphores, etc. + +To use TPU interpret mode, pass `interpret=pltpu.InterpretParams()` into +`pl.pallas_call` or `pl.core_map`. For examples, see +`test_matmul_example` in +[tpu_pallas_interpret_test.py](https://github.com/jax-ml/jax/blob/main/tests/pallas/tpu_pallas_interpret_test.py#:~:text=test_matmul_example) +and +`test_right_permute_example` and the other tests in +[tpu_pallas_interpret_distributed_test.py](https://github.com/jax-ml/jax/blob/main/tests/pallas/tpu_pallas_interpret_distributed_test.py#:~:text=test_right_permute_example). + +The behavior of TPU interpret mode can be configured via arguments to +[`pltpu.InterpretParams`](https://github.com/jax-ml/jax/blob/main/jax/_src/pallas/mosaic/interpret.py#:~:text=class%20InterpretParams). For example, use `num_cores_per_device=2` +to simulate Megacore or `uninitialized_memory='zero'` to initialize simulated +TPU buffers with zeros instead of NaNs. + ### debug_print The `pl.debug_print` function can be used to print runtime values inside of a kernel. @@ -45,16 +74,14 @@ as a Python error after the kernel has successfully executed.
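To make the TPU interpret mode usage described above concrete, here is a minimal, self-contained sketch (not part of the change itself; the toy kernel, the shapes, and the `detect_races` setting are illustrative assumptions) that runs a trivial Pallas kernel entirely on CPU:

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu


def add_one_kernel(x_ref, o_ref):
  # Runtime print, handy when bisecting a misbehaving kernel.
  pl.debug_print("first element: {}", x_ref[0, 0])
  o_ref[...] = x_ref[...] + 1


x = jnp.arange(8 * 128, dtype=jnp.float32).reshape(8, 128)
# InterpretParams() simulates TPU memories, DMAs and semaphores on CPU;
# detect_races=True also enables the dynamic race detector described below.
y = pl.pallas_call(
    add_one_kernel,
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    interpret=pltpu.InterpretParams(detect_races=True),
)(x)
```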
#### Hard assertion -Hard assertions can be inserted with `checkify.check` -and running your program with the `--jax_pallas_enable_runtime_assert` flag. +Hard assertions can be inserted with `pl.debug_check` +and enabled by running your program with the `--jax_pallas_enable_debug_checks` flag. Your code will look like the following: ```python -from jax.experimental import checkify - def kernel(...): - checkify.check(x > y, "Check x > y failed") # Will halt if x <= y + pl.debug_check(x > y, "Check x > y failed") # Will halt if x <= y ``` This will print a relatively lengthy dump which resembles the following: @@ -76,11 +103,10 @@ Functionalized asserts can be performed by checkify-ing the `pl.pallas_call` op from jax.experimental import checkify def kernel(...): - checkify.check(x > y, "Check x > y failed") # Will throw an error if x <= y + pl.debug_check(x > y, "Check x > y failed") # Will throw an error if x <= y kernel = pl.pallas_call(...) -checkified_kernel = checkify.checkify(kernel, - errors=checkify.all_checks) +checkified_kernel = checkify.checkify(kernel, errors=checkify.all_checks) error, result = checkified_kernel(x) error.throw() ``` @@ -163,11 +189,39 @@ spin -a dump.pml && gcc -o pan -O3 pan.c -Wno-format-overflow && time ./pan +### Dynamic Race Detection + +[TPU Interpret Mode](#tpu-interpret-mode) includes a dynamic race detector. +While running a kernel, it can detect and log data races -- pairs of accesses +to shared memory (HBM, VMEM, SMEM, etc.) that are not properly synchronized. + +To enable the dynamic race detector, use the option `detect_races=True` in the +`pltpu.InterpretParams` passed to `pl.pallas_call`: + +```python +pl.pallas_call( + kernel, + ..., + interpret=pltpu.InterpretParams(..., detect_races=True), +) +``` + +If any data races are detected while running the kernel, a message will be +printed -- for example: + +``` +RACE DETECTED + write ... from ...jax/tests/pallas/tpu_pallas_interpret_distributed_test.py:1038:10 (InterpretDistributedTest.test_race_detection..kernel.._) + write ... from .../jax/tests/pallas/tpu_pallas_interpret_distributed_test.py:1038:10 (InterpretDistributedTest.test_race_detection..kernel.._) +``` + + + ## Useful Command line flags * OOB Checks: `--xla_mosaic_on_device_checks=bounds` * Poison VMEM allocations: `--xla_jf_poison_vmem_allocations=true` - + * Dump Mosaic: `--xla_mosaic_dump_to=` * Enable trace markers in XProf: `--xla_enable_transpose_trace` @@ -203,5 +257,3 @@ In most cases the error message should hint at what is wrong.
For specific errors: * `Mixed dtype operands in cmp` when using `jnp.mod`: Use lax.rem instead of jnp.mod - - diff --git a/jax/experimental/pallas/mosaic_gpu.py b/jax/experimental/pallas/mosaic_gpu.py index 631b4f720984..1c47d391aa65 100644 --- a/jax/experimental/pallas/mosaic_gpu.py +++ b/jax/experimental/pallas/mosaic_gpu.py @@ -18,17 +18,27 @@ """ from jax._src.pallas.mosaic_gpu.core import Barrier as Barrier -from jax._src.pallas.mosaic_gpu.core import GPUBlockSpec as GPUBlockSpec -from jax._src.pallas.mosaic_gpu.core import GPUCompilerParams as GPUCompilerParams -from jax._src.pallas.mosaic_gpu.core import GPUMemorySpace as GPUMemorySpace -from jax._src.pallas.mosaic_gpu.core import GPUMesh as GPUMesh +from jax._src.pallas.mosaic_gpu.core import ClusterBarrier as ClusterBarrier +from jax._src.pallas.mosaic_gpu.core import BlockSpec as BlockSpec +from jax._src.pallas.mosaic_gpu.core import CompilerParams as CompilerParams +from jax._src.pallas.mosaic_gpu.core import Mesh as Mesh +from jax._src.pallas.mosaic_gpu.core import MemorySpace as MemorySpace from jax._src.pallas.mosaic_gpu.core import kernel as kernel +from jax._src.pallas.mosaic_gpu.core import PeerMemRef as PeerMemRef +from jax._src.pallas.mosaic_gpu.core import RefUnion as RefUnion +from jax._src.pallas.mosaic_gpu.core import remote_ref as remote_ref +from jax._src.pallas.mosaic_gpu.core import SemaphoreType as SemaphoreType from jax._src.pallas.mosaic_gpu.core import SwizzleTransform as SwizzleTransform from jax._src.pallas.mosaic_gpu.core import TilingTransform as TilingTransform +from jax._src.pallas.mosaic_gpu.core import transform_ref as transform_ref from jax._src.pallas.mosaic_gpu.core import transpose_ref as transpose_ref +from jax._src.pallas.mosaic_gpu.core import untile_ref as untile_ref +from jax._src.pallas.mosaic_gpu.core import unswizzle_ref as unswizzle_ref from jax._src.pallas.mosaic_gpu.core import TransposeTransform as TransposeTransform +from jax._src.pallas.mosaic_gpu.core import WarpMesh as WarpMesh from jax._src.pallas.mosaic_gpu.core import WGMMAAccumulatorRef as ACC # noqa: F401 from jax._src.pallas.mosaic_gpu.core import WGMMAAccumulatorRef as WGMMAAccumulatorRef +from jax._src.pallas.mosaic_gpu.helpers import nd_loop as nd_loop from jax._src.pallas.mosaic_gpu.pipeline import emit_pipeline as emit_pipeline from jax._src.pallas.mosaic_gpu.pipeline import emit_pipeline_warp_specialized as emit_pipeline_warp_specialized from jax._src.pallas.mosaic_gpu.primitives import barrier_arrive as barrier_arrive @@ -36,18 +46,26 @@ from jax._src.pallas.mosaic_gpu.primitives import broadcasted_iota as broadcasted_iota from jax._src.pallas.mosaic_gpu.primitives import commit_smem as commit_smem from jax._src.pallas.mosaic_gpu.primitives import commit_smem_to_gmem_group as commit_smem_to_gmem_group +from jax._src.pallas.mosaic_gpu.primitives import ShapeDtypeStruct as ShapeDtypeStruct from jax._src.pallas.mosaic_gpu.primitives import copy_gmem_to_smem as copy_gmem_to_smem from jax._src.pallas.mosaic_gpu.primitives import copy_smem_to_gmem as copy_smem_to_gmem +from jax._src.pallas.mosaic_gpu.primitives import inline_mgpu as inline_mgpu from jax._src.pallas.mosaic_gpu.primitives import Layout as Layout from jax._src.pallas.mosaic_gpu.primitives import layout_cast as layout_cast +from jax._src.pallas.mosaic_gpu.primitives import load as load +from jax._src.pallas.mosaic_gpu.primitives import RefType as RefType from jax._src.pallas.mosaic_gpu.primitives import set_max_registers as set_max_registers from 
jax._src.pallas.mosaic_gpu.primitives import wait_smem_to_gmem as wait_smem_to_gmem from jax._src.pallas.mosaic_gpu.primitives import wgmma as wgmma from jax._src.pallas.mosaic_gpu.primitives import wgmma_wait as wgmma_wait -from jax.experimental.mosaic.gpu.core import ThreadSemantics as ThreadSemantics +from jax._src.pallas.mosaic_gpu.primitives import tcgen05_mma as tcgen05_mma +from jax._src.pallas.mosaic_gpu.primitives import commit_tmem as commit_tmem +from jax.experimental.mosaic.gpu.core import LoweringSemantics as LoweringSemantics -#: Alias of :data:`jax.experimental.pallas.mosaic_gpu.GPUMemorySpace.GMEM`. -GMEM = GPUMemorySpace.GMEM -#: Alias of :data:`jax.experimental.pallas.mosaic_gpu.GPUMemorySpace.SMEM`. -SMEM = GPUMemorySpace.SMEM +#: Alias of :data:`jax.experimental.pallas.mosaic_gpu.MemorySpace.GMEM`. +GMEM = MemorySpace.GMEM +#: Alias of :data:`jax.experimental.pallas.mosaic_gpu.MemorySpace.SMEM`. +SMEM = MemorySpace.SMEM +#: Alias of :data:`jax.experimental.pallas.mosaic_gpu.MemorySpace.TMEM`. +TMEM = MemorySpace.TMEM diff --git a/jax/experimental/pallas/ops/gpu/attention.py b/jax/experimental/pallas/ops/gpu/attention.py index 8b83d24ea199..4782fc31226e 100644 --- a/jax/experimental/pallas/ops/gpu/attention.py +++ b/jax/experimental/pallas/ops/gpu/attention.py @@ -57,10 +57,10 @@ def get_default(cls): return BlockSizes( block_q=128, block_k=128, - block_q_dkv=128, - block_kv_dkv=128, - block_q_dq=128, - block_kv_dq=128, + block_q_dkv=32, + block_kv_dkv=32, + block_q_dq=32, + block_kv_dq=32, ) @property @@ -86,28 +86,29 @@ def mha_forward_kernel( segment_ids_ref: jax.Array | None, # segment_id arrays o_ref: Any, # Output *residual_refs: Any, # Residual outputs - num_heads: int, sm_scale: float, causal: bool, block_q: int, - block_d: int, block_k: int, + head_dim: int, ): seq_len = k_ref.shape[0] start_q = pl.program_id(0) + head_dim_padded = q_ref.shape[-1] # o is the buffer where we accumulate the output on sram. # m_i and l_i (see FlashAttention paper) are updated during the k,v loop. m_i = jnp.zeros(block_q, dtype=jnp.float32) - float('inf') l_i = jnp.zeros(block_q, dtype=jnp.float32) # acc is the buffer where we accumulate the output on sram. - o = jnp.zeros((block_q, block_d), dtype=jnp.float32) + o = jnp.zeros((block_q, head_dim_padded), dtype=jnp.float32) # Load q: it will stay in L1 throughout. Indices form a matrix because we # read, compute, and write all in 2d chunks. 1 element ~= 1 CUDA thread index. - # q tile has shape [block_q, block_d], block_d == head_dim. + # q tile has shape [block_q, head_dim_padded], head_dim_padded >= head_dim. curr_q_slice = pl.dslice(start_q * block_q, block_q) - q = q_ref[...] + head_mask = (jnp.arange(head_dim_padded) < head_dim)[None, :] + q = pl.load(q_ref, (slice(None), slice(None)), mask=head_mask, other=0.0) q_segment_ids = ( None if segment_ids_ref is None @@ -121,7 +122,7 @@ def body(start_k, carry): o_prev, m_prev, l_prev = carry curr_k_slice = pl.dslice(start_k * block_k, block_k) - k = pl.load(k_ref, (curr_k_slice, slice(None))) + k = pl.load(k_ref, (curr_k_slice, slice(None)), mask=head_mask, other=0.0) qk = pl.dot(q, k.T) # [block_q, block_k] # Scale logits to convert from base-2 to the natural log domain. @@ -151,7 +152,7 @@ def body(start_k, carry): # Apply mask to qk. 
qk = jnp.where(mask, qk, DEFAULT_MASK_VALUE) - m_curr = qk.max(axis=-1) + m_curr = jnp.max(qk, axis=-1) m_next = jnp.maximum(m_prev, m_curr) correction = jnp.exp2(m_prev - m_next) l_prev_corr = correction * l_prev @@ -161,7 +162,7 @@ def body(start_k, carry): l_curr = s_curr.sum(axis=-1) l_next = l_prev_corr + l_curr o_prev_corr = correction[:, None] * o_prev - v = pl.load(v_ref, (curr_k_slice, pl.dslice(block_d))) + v = pl.load(v_ref, (curr_k_slice, slice(None)), mask=head_mask) o_curr = pl.dot(s_curr.astype(v.dtype), v) o_next = o_prev_corr + o_curr @@ -182,7 +183,8 @@ def body(start_k, carry): lse_ref = residual_refs[0] lse_ref[...] = m_i + jnp.log2(l_i) # Write output to dram. - o_ref[...] = o.astype(o_ref.dtype) + pl.store(o_ref, (slice(None), slice(o.shape[-1])), o.astype(o_ref.dtype), + mask=head_mask) def segment_mask( q_segment_ids: jax.Array, @@ -199,7 +201,7 @@ def segment_mask( @functools.partial( - jax.custom_vjp, nondiff_argnums=[4, 5, 6, 7, 8, 9, 10, 11, 12] + jax.custom_vjp, nondiff_argnums=[4, 5, 6, 7, 8, 9, 10, 11, 12, 13] ) @functools.partial( jax.jit, @@ -213,6 +215,7 @@ def segment_mask( "grid", "interpret", "debug", + "return_residuals", ], ) def mha( @@ -229,12 +232,24 @@ def mha( grid: tuple[int, ...] | None = None, interpret: bool = False, debug: bool = False, + return_residuals: bool = False, ): del backward_pass_impl batch_size, q_seq_len, num_heads, head_dim = q.shape kv_seq_len = k.shape[1] block_q = min(block_sizes.block_q, q_seq_len) block_k = min(block_sizes.block_k, kv_seq_len) + head_dim_padded = pl.next_power_of_2(head_dim) + if (q.shape[-1] != k.shape[-1]) or (q.shape[-1] != v.shape[-1]): + raise ValueError( + f"This kernel expects q, k, and v to have the same head dimension, but" + f" found {q.shape=}, {k.shape=}, {v.shape=}." + ) + if q_seq_len % block_q != 0: + raise ValueError(f"{q_seq_len=} must be a multiple of {block_q=}") + if kv_seq_len % block_k != 0: + raise ValueError(f"{kv_seq_len=} must be a multiple of {block_k=}") + # Heuristics. 
grid_ = grid if grid_ is None: @@ -243,42 +258,44 @@ def mha( num_warps_ = num_warps if num_warps_ is None: num_warps_ = 4 if head_dim <= 64 else 8 - kernel = functools.partial(mha_forward_kernel, num_heads=num_heads, - sm_scale=sm_scale, block_q=block_q, - block_k=block_k, block_d=head_dim, - causal=causal) + kernel = functools.partial(mha_forward_kernel, sm_scale=sm_scale, + block_q=block_q, block_k=block_k, + head_dim=head_dim, causal=causal) in_specs = [ - pl.BlockSpec( - (None, block_q, None, head_dim), lambda i, j, k: (j, i, k, 0) - ), - pl.BlockSpec( - (None, kv_seq_len, None, head_dim), lambda _, j, k: (j, 0, k, 0) - ), - pl.BlockSpec( - (None, kv_seq_len, None, head_dim), lambda _, j, k: (j, 0, k, 0) - ), + pl.BlockSpec((None, block_q, None, head_dim_padded), + lambda i, j, k: (j, i, k, 0)), + pl.BlockSpec((None, kv_seq_len, None, head_dim_padded), + lambda _, j, k: (j, 0, k, 0)), + pl.BlockSpec((None, kv_seq_len, None, head_dim_padded), + lambda _, j, k: (j, 0, k, 0)), ] in_specs.append( None # type: ignore[arg-type] if segment_ids is None else pl.BlockSpec((None, kv_seq_len), lambda _, j, k: (j, 0)) ) - out_shape = jax.ShapeDtypeStruct(shape=q.shape, dtype=q.dtype) - return pl.pallas_call( + out_shape = [q] + out_specs = [pl.BlockSpec((None, block_q, None, head_dim_padded), + lambda i, j, k: (j, i, k, 0))] + if return_residuals: + out_shape.append(jax.ShapeDtypeStruct( + shape=(batch_size, num_heads, q_seq_len), dtype=jnp.float32)) # lse + out_specs.append( + pl.BlockSpec((None, None, block_q), lambda i, j, k: (j, k, i))) # lse + out = pl.pallas_call( kernel, grid=grid_, in_specs=in_specs, - out_specs=pl.BlockSpec( - (None, block_q, None, head_dim), lambda i, j, k: (j, i, k, 0) - ), - compiler_params=plgpu.TritonCompilerParams( + out_specs=out_specs, + compiler_params=plgpu.CompilerParams( num_warps=num_warps_, num_stages=num_stages), out_shape=out_shape, debug=debug, interpret=interpret, name="mha_forward", )(q, k, v, segment_ids) + return out if return_residuals else out[0] def _mha_forward( @@ -295,70 +312,24 @@ def _mha_forward( grid: Any, interpret: bool, debug: bool, + return_residuals: bool, ): - del backward_pass_impl - batch_size, q_seq_len, num_heads, head_dim = q.shape - kv_seq_len = k.shape[1] - block_q = min(block_sizes.block_q, q_seq_len) - block_k = min(block_sizes.block_k, kv_seq_len) - # Heuristics. 
- grid_ = grid - if grid_ is None: - grid_ = (pl.cdiv(q_seq_len, block_q), batch_size, num_heads) - - num_warps_ = num_warps - if num_warps_ is None: - num_warps_ = 4 if head_dim <= 64 else 8 - kernel = functools.partial(mha_forward_kernel, num_heads=num_heads, - sm_scale=sm_scale, causal=causal, block_q=block_q, - block_k=block_k, block_d=head_dim) - out_shape = [ - jax.ShapeDtypeStruct(shape=q.shape, dtype=q.dtype), # out - jax.ShapeDtypeStruct( - shape=(batch_size, num_heads, q_seq_len), dtype=jnp.float32 # lse - ), - ] - in_specs = [ - pl.BlockSpec( - (None, block_q, None, head_dim), lambda i, j, k: (j, i, k, 0) - ), - pl.BlockSpec( - (None, kv_seq_len, None, head_dim), lambda _, j, k: (j, 0, k, 0) - ), - pl.BlockSpec( - (None, kv_seq_len, None, head_dim), lambda _, j, k: (j, 0, k, 0) - ), - ] - in_specs.append( - None # type: ignore[arg-type] - if segment_ids is None - else pl.BlockSpec((None, kv_seq_len), lambda _, j, k: (j, 0)) - ) - out, lse = pl.pallas_call( - kernel, - grid=grid_, - in_specs=in_specs, - out_specs=[ - pl.BlockSpec( - (None, block_q, None, head_dim), lambda i, j, k: (j, i, k, 0) - ), - pl.BlockSpec((None, None, block_q), lambda i, j, k: (j, k, i)), - ], - compiler_params=plgpu.TritonCompilerParams( - num_warps=num_warps_, num_stages=num_stages - ), - out_shape=out_shape, - debug=debug, - interpret=interpret, - name="mha_forward", - )(q, k, v, segment_ids) - return out, (q, k, v, segment_ids, out, lse) - - -def _preprocess_backward_kernel(out_ref, dout_ref, delta_ref): + out, lse = mha(q, k, v, segment_ids=segment_ids, sm_scale=sm_scale, + causal=causal, block_sizes=block_sizes, + backward_pass_impl=backward_pass_impl, + num_warps=num_warps, num_stages=num_stages, + grid=grid, interpret=interpret, debug=debug, + return_residuals=True) + residuals = (q, k, v, segment_ids, out, lse) + ret = (out, lse) if return_residuals else out + return ret, residuals + + +def _preprocess_backward_kernel(out_ref, dout_ref, delta_ref, head_dim: int): # load - o = out_ref[...].astype(jnp.float32) - do = dout_ref[...].astype(jnp.float32) + head_mask = (jnp.arange(out_ref.shape[-1]) < head_dim)[None, :] + o = pl.load(out_ref, (slice(None), slice(None)), mask=head_mask, other=0.0) + do = pl.load(dout_ref, (slice(None), slice(None)), mask=head_mask, other=0.0) # compute delta = jnp.sum(o * do, axis=1) # write-back @@ -368,20 +339,19 @@ def _preprocess_backward_kernel(out_ref, dout_ref, delta_ref): def _preprocess_backward(out, do, lse, block_q: int, debug: bool, interpret: bool): batch_size, seq_len, num_heads, head_dim = out.shape + head_dim_padded = pl.next_power_of_2(head_dim) out_shape = jax.ShapeDtypeStruct(lse.shape, lse.dtype) delta = pl.pallas_call( - _preprocess_backward_kernel, + functools.partial(_preprocess_backward_kernel, head_dim=head_dim), grid=(pl.cdiv(seq_len, block_q), batch_size, num_heads), in_specs=[ - pl.BlockSpec( - (None, block_q, None, head_dim), lambda i, j, k: (j, i, k, 0) - ), - pl.BlockSpec( - (None, block_q, None, head_dim), lambda i, j, k: (j, i, k, 0) - ), + pl.BlockSpec((None, block_q, None, head_dim_padded), + lambda i, j, k: (j, i, k, 0)), + pl.BlockSpec((None, block_q, None, head_dim_padded), + lambda i, j, k: (j, i, k, 0)), ], out_specs=pl.BlockSpec((None, None, block_q), lambda i, j, k: (j, k, i)), - compiler_params=plgpu.TritonCompilerParams(num_warps=4, num_stages=3), + compiler_params=plgpu.CompilerParams(num_warps=4, num_stages=3), out_shape=out_shape, debug=debug, interpret=interpret, @@ -414,7 +384,7 @@ def mha_backward_kernel( block_kv_dkv: 
int, block_q_dq: int, block_kv_dq: int, - block_d: int, + head_dim: int, ): del out_ref # Not needed q_seq_len = q_ref.shape[0] @@ -427,11 +397,13 @@ def mha_backward_kernel( start_k = pl.program_id(2) curr_k_slice = pl.dslice(start_k * block_kv_dkv, block_kv_dkv) - dv = jnp.zeros([block_kv_dkv, block_d], dtype=jnp.float32) - dk = jnp.zeros([block_kv_dkv, block_d], dtype=jnp.float32) + head_dim_padded = q_ref.shape[-1] + dv = jnp.zeros([block_kv_dkv, head_dim_padded], dtype=jnp.float32) + dk = jnp.zeros([block_kv_dkv, head_dim_padded], dtype=jnp.float32) - v = pl.load(v_ref, (curr_k_slice, slice(None))) - k = pl.load(k_ref, (curr_k_slice, slice(None))) + head_mask = (jnp.arange(head_dim_padded) < head_dim)[None, :] + v = pl.load(v_ref, (curr_k_slice, slice(None)), mask=head_mask, other=0.0) + k = pl.load(k_ref, (curr_k_slice, slice(None)), mask=head_mask, other=0.0) span_k = start_k * block_kv_dkv + jnp.arange(block_kv_dkv) kv_segment_ids = ( None @@ -443,7 +415,7 @@ def inner_loop_dkdv(start_q, carry): dv, dk = carry curr_q_slice = pl.dslice(start_q * block_q_dkv, block_q_dkv) - q = pl.load(q_ref, (curr_q_slice, slice(None))) + q = pl.load(q_ref, (curr_q_slice, slice(None)), mask=head_mask, other=0.0) qk = pl.dot(q, k.T) qk_scale = math.log2(math.e) if sm_scale != 1.: @@ -466,7 +438,8 @@ def inner_loop_dkdv(start_q, carry): lse = pl.load(lse_ref, (curr_q_slice,)) di = pl.load(delta_ref, (curr_q_slice,)) - do = pl.load(do_scaled_ref, (curr_q_slice, slice(None))) + do = pl.load(do_scaled_ref, (curr_q_slice, slice(None)), mask=head_mask, + other=0.0) p = jnp.exp2(qk - lse[:, None]) dv = dv + pl.dot(p.astype(do.dtype).T, do) @@ -483,8 +456,10 @@ def inner_loop_dkdv(start_q, carry): dv, dk = lax.fori_loop( lower_bound, pl.cdiv(q_seq_len, block_q_dkv), inner_loop_dkdv, (dv, dk) ) - dv_ref[...] = dv.astype(dv_ref.dtype) - dk_ref[...] = dk.astype(dk_ref.dtype) + pl.store(dv_ref, (slice(None), slice(dv.shape[-1])), dv.astype(dv_ref.dtype), + mask=head_mask) + pl.store(dk_ref, (slice(None), slice(dk.shape[-1])), dk.astype(dk_ref.dtype), + mask=head_mask) # Scan #2: dQ # 1. Load a block of Q of size (block_q_dq, head_dim) in SMEM. @@ -493,22 +468,23 @@ def inner_loop_dkdv(start_q, carry): start_q = pl.program_id(2) curr_q_slice = pl.ds(start_q * block_q_dq, block_q_dq) span_q = start_q * block_q_dq + jnp.arange(block_q_dq) - dq = jnp.zeros([block_q_dq, block_d], dtype=jnp.float32) + dq = jnp.zeros([block_q_dq, head_dim_padded], dtype=jnp.float32) - q = pl.load(q_ref, (curr_q_slice, slice(None))) + q = pl.load(q_ref, (curr_q_slice, slice(None)), mask=head_mask, other=0.0) q_segment_ids = ( None if segment_ids_ref is None else pl.load(segment_ids_ref, (curr_q_slice,)) ) lse = pl.load(lse_ref, (curr_q_slice,)) - do = pl.load(do_scaled_ref, (curr_q_slice, slice(None))) + do = pl.load(do_scaled_ref, (curr_q_slice, slice(None)), mask=head_mask, + other=0.0) di = pl.load(delta_ref, (curr_q_slice,)) def inner_loop_dq(start_k, dq): curr_k_slice = pl.dslice(start_k * block_kv_dq, block_kv_dq) - k = pl.load(k_ref, (curr_k_slice, slice(None))) - v = pl.load(v_ref, (curr_k_slice, slice(None))) + k = pl.load(k_ref, (curr_k_slice, slice(None)), mask=head_mask, other=0.0) + v = pl.load(v_ref, (curr_k_slice, slice(None)), mask=head_mask, other=0.0) qk = pl.dot(q, k.T) qk_scale = math.log2(math.e) @@ -547,15 +523,19 @@ def inner_loop_dq(start_k, dq): upper_bound = pl.cdiv(kv_seq_len, block_kv_dq) dq = lax.fori_loop(0, upper_bound, inner_loop_dq, (dq)) - dq_ref[...] 
= dq.astype(dq_ref.dtype) + pl.store(dq_ref, (slice(None), slice(dq.shape[-1])), dq.astype(dq_ref.dtype), + mask=head_mask) def _mha_backward(sm_scale: float, causal: bool, block_sizes: BlockSizes, backward_pass_impl: str, num_warps: int | None, num_stages: int, grid: Any, interpret: bool, - debug: bool, res, do): - del num_stages, grid + debug: bool, return_residuals: bool, res, do): + if return_residuals: + raise ValueError( + "Kernel differentiation is not supported if return_residuals is True.") q, k, v, segment_ids, out, lse = res + del num_stages, grid, return_residuals if backward_pass_impl == "xla": return jax.vjp( @@ -576,6 +556,7 @@ def _mha_backward(sm_scale: float, causal: bool, block_sizes: BlockSizes, block_kv_dkv = min(block_sizes.block_kv_dkv, kv_seq_len) block_q_dq = min(block_sizes.block_q_dq, q_seq_len) block_kv_dq = min(block_sizes.block_kv_dq, kv_seq_len) + head_dim_padded = pl.next_power_of_2(head_dim) if q_seq_len // block_q_dq != kv_seq_len // block_kv_dkv: raise ValueError( @@ -591,28 +572,24 @@ def _mha_backward(sm_scale: float, causal: bool, block_sizes: BlockSizes, ] in_specs = [ - pl.BlockSpec( - (None, q_seq_len, None, head_dim), lambda i, j, _: (i, 0, j, 0) - ), - pl.BlockSpec( - (None, kv_seq_len, None, head_dim), lambda i, j, _: (i, 0, j, 0) - ), - pl.BlockSpec( - (None, kv_seq_len, None, head_dim), lambda i, j, _: (i, 0, j, 0) - ), - pl.BlockSpec( - (None, q_seq_len, None, head_dim), lambda i, j, _: (i, 0, j, 0) - ), - pl.BlockSpec( - (None, q_seq_len, None, head_dim), lambda i, j, _: (i, 0, j, 0) - ), + pl.BlockSpec((None, q_seq_len, None, head_dim_padded), + lambda i, j, _: (i, 0, j, 0)), + pl.BlockSpec((None, kv_seq_len, None, head_dim_padded), + lambda i, j, _: (i, 0, j, 0)), + pl.BlockSpec((None, kv_seq_len, None, head_dim_padded), + lambda i, j, _: (i, 0, j, 0)), + pl.BlockSpec((None, q_seq_len, None, head_dim_padded), + lambda i, j, _: (i, 0, j, 0)), + pl.BlockSpec((None, q_seq_len, None, head_dim_padded), + lambda i, j, _: (i, 0, j, 0)), pl.BlockSpec((None, None, q_seq_len), lambda i, j, _: (i, j, 0)), pl.BlockSpec((None, None, q_seq_len), lambda i, j, _: (i, j, 0)), ] if segment_ids is None: in_specs.insert(3, None) # type: ignore[arg-type] else: - in_specs.insert(3, pl.BlockSpec((None, kv_seq_len), lambda i, j, _: (i, 0))) + in_specs.insert(3, pl.BlockSpec((None, kv_seq_len), + lambda i, j, _: (i, 0))) grid = (batch_size, num_heads, pl.cdiv(kv_seq_len, block_kv_dkv)) num_warps_ = num_warps @@ -635,29 +612,29 @@ def _mha_backward(sm_scale: float, causal: bool, block_sizes: BlockSizes, block_kv_dkv=block_kv_dkv, block_q_dq=block_q_dq, block_kv_dq=block_kv_dq, - block_d=head_dim, + head_dim=head_dim, ), out_shape=out_shapes, in_specs=in_specs, grid=grid, out_specs=[ pl.BlockSpec( - (None, block_q_dq, None, head_dim), + (None, block_q_dq, None, head_dim_padded), lambda i, j, k: (i, k, j, 0), # dq ), pl.BlockSpec( - (None, block_kv_dkv, None, head_dim), + (None, block_kv_dkv, None, head_dim_padded), lambda i, j, k: (i, k, j, 0), # dk ), pl.BlockSpec( - (None, block_kv_dkv, None, head_dim), + (None, block_kv_dkv, None, head_dim_padded), lambda i, j, k: (i, k, j, 0), # dv ), ], name="mha_backward", debug=debug, interpret=interpret, - compiler_params=plgpu.TritonCompilerParams( + compiler_params=plgpu.CompilerParams( num_warps=num_warps_, num_stages=2 ), )(q, k, v, segment_ids, out, do, lse, delta) diff --git a/jax/experimental/pallas/ops/gpu/attention_mgpu.py b/jax/experimental/pallas/ops/gpu/attention_mgpu.py index 8883878f5f0e..90b8eb702db4 100644 --- 
a/jax/experimental/pallas/ops/gpu/attention_mgpu.py +++ b/jax/experimental/pallas/ops/gpu/attention_mgpu.py @@ -20,12 +20,13 @@ import jax from jax import lax from jax._src import test_util as jtu # noqa: F401 +from jax._src.lib import cuda_versions # noqa: F401 from jax.experimental.mosaic.gpu import profiler import jax.experimental.pallas as pl import jax.experimental.pallas.mosaic_gpu as plgpu import jax.numpy as jnp import numpy as np - +from functools import partial @dataclasses.dataclass(frozen=True) class TuningConfig: @@ -33,6 +34,13 @@ class TuningConfig: block_kv: int max_concurrent_steps: int use_schedule_barrier: bool = True + causal: bool = False + compute_wgs_bwd: int = 1 + + block_q_dkv: int | None = None + block_kv_dkv: int | None = None + block_q_dq: int | None = None + block_kv_dq: int | None = None def __post_init__(self): if self.block_q % 64: @@ -42,9 +50,26 @@ def __post_init__(self): if self.max_concurrent_steps < 2: raise ValueError(f"{self.max_concurrent_steps=} must be at least 2") + backward_blocks = [self.block_q_dkv, self.block_kv_dkv, self.block_q_dq, self.block_kv_dq] + block_is_set = [blk is not None for blk in backward_blocks] + if any(block_is_set) and not all(block_is_set): + raise ValueError( + "Backward block sizes (block_q_dkv, block_kv_dkv, block_q_dq, " + "block_kv_dq) must either all be specified or all be None." + ) -@functools.partial(jax.jit, static_argnames=["config"]) -def attention(q, k, v, config: TuningConfig): + @property + def has_backward_blocks(self) -> bool: + return self.block_q_dkv is not None + +def _attention_forward(q, k, v, config: TuningConfig, save_residuals: bool = False): + cuda_runtime_version = cuda_versions.cuda_runtime_get_version() + # TODO(pobudzey): Undo when we upgrade to cuda 12.9.1. + if config.causal and cuda_runtime_version >= 12080 and cuda_runtime_version < 12091: + raise ValueError( + "Causal masking not supported with cuda versions between 12.8.0 and" + " 12.9.1 due to a ptxas miscompilation." 
+ ) if q.ndim != 4 or k.ndim != 4 or v.ndim != 4: raise ValueError(f"q, k, and v should all be 4D, got: {q.ndim=}, {k.ndim=}, {v.ndim=}") batch_size, q_seq_len, num_q_heads, head_dim = q.shape @@ -68,25 +93,39 @@ def attention(q, k, v, config: TuningConfig): config.max_concurrent_steps, kv_seq_len // config.block_kv ) block_q, block_kv = config.block_q, config.block_kv + if kv_seq_len % block_kv: + raise ValueError(f"{kv_seq_len=} must be a multiple of {block_kv=}") - def kernel(q_ref, k_ref, v_ref, out_ref, scoped): + def kernel(q_ref, k_ref, v_ref, out_ref, lse_ref, scoped): batch = lax.axis_index("batch") q_head = lax.axis_index("heads") smem_buffers, buffer_barriers, consumed_barriers, schedule_barrier = scoped wg_idx = lax.axis_index("wg") - qo_smem2, k_smem, v_smem = smem_buffers + qo_smem2, k_smem, v_smem, lse_smem2 = smem_buffers k_barriers, v_barriers, q_barriers = buffer_barriers k_consumed_barriers, v_consumed_barriers = consumed_barriers def perform_schedule_barrier(): plgpu.barrier_arrive(schedule_barrier) plgpu.barrier_wait(schedule_barrier) + if config.causal: + block_q_end = (lax.axis_index("q_seq") + 1) * (2 * block_q) + block_max_kv_steps = pl.cdiv(block_q_end, jnp.array(block_kv, jnp.int32)) + else: + block_max_kv_steps = kv_seq_len // block_kv + @pl.when(wg_idx < 2) def _compute_wg(): plgpu.set_max_registers(232, action="increase") qo_smem = qo_smem2.at[wg_idx] + lse_smem = lse_smem2.at[wg_idx] if lse_smem2 is not None else None q_seq_base = lax.axis_index("q_seq") * (2 * block_q) + wg_idx * block_q + if config.causal: + kv_steps = pl.cdiv(q_seq_base + block_q, jnp.array(block_kv, jnp.int32)) + else: + kv_steps = block_max_kv_steps + plgpu.copy_gmem_to_smem( q_ref.at[batch, pl.ds(q_seq_base, block_q), q_head], qo_smem, @@ -104,12 +143,14 @@ def _compute_wg(): jnp.full((block_q, head_dim), 0, dtype=jnp.float32), plgpu.Layout.WGMMA, ) - plgpu.barrier_wait(k_barriers.at[0]) + @pl.when(kv_steps > 0) + def _(): + plgpu.barrier_wait(k_barriers.at[0]) pl.when(wg_idx == 1)(perform_schedule_barrier) - def kv_loop(kv_step, carry): + def kv_loop(kv_step, carry, causal: bool = False): acc, m_i, l_i = carry - slot = lax.rem(kv_step, max_concurrent_steps) + slot = lax.rem(kv_step, jnp.array(max_concurrent_steps, kv_step.dtype)) # QK def compute_qk(acc_ref): @@ -119,6 +160,12 @@ def compute_qk(acc_ref): qk = pl.run_scoped(compute_qk, plgpu.ACC((block_q, block_kv), jnp.float32)) plgpu.barrier_arrive(k_consumed_barriers.at[slot]) + if causal: + q_ids = plgpu.broadcasted_iota(jnp.int32, (block_q, block_kv), 0, layout=plgpu.Layout.WGMMA) + kv_ids = plgpu.broadcasted_iota(jnp.int32, (block_q, block_kv), 1, layout=plgpu.Layout.WGMMA) + mask = (q_ids + q_seq_base) >= (kv_ids + kv_step * block_kv) + qk = jnp.where(mask, qk, -jnp.inf) + # Softmax # We keep m scaled by log2e to use FMA instructions when computing p. 
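+      # (Since exp(x - m) == exp2(log2(e) * x - log2(e) * m), keeping m pre-scaled
+      # by log2(e) lets the scale and subtract fold into a single FMA feeding exp2.)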
log2e = math.log2(math.e) @@ -149,28 +196,53 @@ def compute_pv(acc_ref): plgpu.wgmma(acc_ref, p16, v_smem.at[slot]) wait_step = kv_step + 1 - wait_slot = lax.rem(wait_step, max_concurrent_steps) - @pl.when(wait_step < kv_seq_len // block_kv) + wait_slot = lax.rem(wait_step, jnp.array(max_concurrent_steps, kv_step.dtype)) + @pl.when(wait_step < kv_steps) def _wait(): plgpu.barrier_wait(k_barriers.at[wait_slot]) acc = pl.run_state(compute_pv)(plgpu.ACC.init(acc)) plgpu.barrier_arrive(v_consumed_barriers.at[slot]) return acc, m_i, l_i - if kv_seq_len % block_kv: - raise ValueError(f"{kv_seq_len=} must be a multiple of {block_kv=}") - acc, m_i, l_i = lax.fori_loop( - 0, kv_seq_len // block_kv, kv_loop, (acc, m_i, l_i) - ) + + if not config.causal: + acc, m_i, l_i = lax.fori_loop(0, block_max_kv_steps, kv_loop, (acc, m_i, l_i)) + else: + def epilogue_kv_loop(kv_step, _): + # This loop makes sure that all the pipelined KV data is processed, even + # if one compute wg finishes early like with causal masking. + slot = lax.rem(kv_step, jnp.array(max_concurrent_steps, kv_step.dtype)) + plgpu.barrier_arrive(k_consumed_barriers.at[slot]) + plgpu.barrier_arrive(v_consumed_barriers.at[slot]) + perform_schedule_barrier() + perform_schedule_barrier() + + causal_kv_loop = functools.partial(kv_loop, causal=True) + full_kv_steps = lax.div(q_seq_base, jnp.array(block_kv, jnp.int32)) + # With causal masking, the KV loop unrolling is split in 3 sections: + # 1. A fast path where no causal mask is needed. + acc, m_i, l_i = lax.fori_loop(0, full_kv_steps, kv_loop, (acc, m_i, l_i)) + # 2. Causal masking. + acc, m_i, l_i = lax.fori_loop(full_kv_steps, kv_steps, causal_kv_loop, (acc, m_i, l_i)) + # 3. Epilogue to flush the data pipeline. + lax.fori_loop(kv_steps, block_max_kv_steps, epilogue_kv_loop, None) pl.when(wg_idx == 0)(perform_schedule_barrier) - del m_i # Not needed anymore # TODO(apaszke): Invert and multiply to avoid expensive divisions. acc /= lax.broadcast_in_dim(l_i, (block_q, head_dim), [0]) qo_smem[...] = acc.astype(dtype) + if lse_smem is not None: + RCP_LN2 = 1.4426950408889634 + log2 = lambda x: jnp.log(x) * RCP_LN2 + lse_smem[...] 
= m_i + log2(l_i) plgpu.commit_smem() plgpu.copy_smem_to_gmem( qo_smem, out_ref.at[batch, pl.ds(q_seq_base, block_q), q_head], ) + if lse_smem is not None: + plgpu.copy_smem_to_gmem( + lse_smem, + lse_ref.at[batch, q_head, pl.ds(q_seq_base, block_q)], + ) plgpu.wait_smem_to_gmem(0) @pl.when(wg_idx == 2) def _memory_wg(): @@ -181,19 +253,19 @@ def _memory_wg(): plgpu.copy_gmem_to_smem(k_ref.at[s], k_smem.at[i], k_barriers.at[i]) plgpu.copy_gmem_to_smem(v_ref.at[s], v_smem.at[i], v_barriers.at[i]) - def kv_loop(kv_step, _): + @pl.loop(0, block_max_kv_steps - max_concurrent_steps) + def _kv_loop(kv_step): tma_step = kv_step + max_concurrent_steps - tma_slot = lax.rem(kv_step, max_concurrent_steps) + tma_slot = lax.rem(kv_step, jnp.array(max_concurrent_steps, kv_step.dtype)) s = (batch, pl.ds(tma_step * block_kv, block_kv), kv_head) plgpu.barrier_wait(k_consumed_barriers.at[tma_slot]) plgpu.copy_gmem_to_smem(k_ref.at[s], k_smem.at[tma_slot], k_barriers.at[tma_slot]) plgpu.barrier_wait(v_consumed_barriers.at[tma_slot]) plgpu.copy_gmem_to_smem(v_ref.at[s], v_smem.at[tma_slot], v_barriers.at[tma_slot]) - lax.fori_loop(0, kv_seq_len // block_kv - max_concurrent_steps, kv_loop, None) - def entry(q_ref, k_ref, v_ref, out_ref): + def entry(q_ref, k_ref, v_ref, out_ref, lse_ref): compute_wgs = 2 - tiling = plgpu.TilingTransform((64, 64)) + tiling = plgpu.TilingTransform((8, 64)) swizzle = plgpu.SwizzleTransform(128) qo_scratch = plgpu.SMEM( (compute_wgs, block_q, head_dim), jnp.float16, @@ -201,39 +273,371 @@ def entry(q_ref, k_ref, v_ref, out_ref): ) k_scratch = plgpu.SMEM( (max_concurrent_steps, block_kv, head_dim), jnp.float16, - transforms=(tiling, plgpu.TransposeTransform((0, 2, 1, 3, 4)), swizzle), + transforms=(tiling, swizzle), ) v_scratch = plgpu.SMEM( (max_concurrent_steps, block_kv, head_dim), jnp.float16, transforms=(tiling, swizzle), ) + scratch = [qo_scratch, k_scratch, v_scratch, None] + if save_residuals: + scratch[3] = plgpu.SMEM((compute_wgs, block_q), jnp.float32) pl.run_scoped( - lambda *args: kernel(q_ref, k_ref, v_ref, out_ref, args), - (qo_scratch, k_scratch, v_scratch), + lambda *args: kernel(q_ref, k_ref, v_ref, out_ref, lse_ref, args), + scratch, ( - plgpu.Barrier(1, num_barriers=max_concurrent_steps), - plgpu.Barrier(1, num_barriers=max_concurrent_steps), - plgpu.Barrier(1, num_barriers=compute_wgs), + plgpu.Barrier(num_barriers=max_concurrent_steps), + plgpu.Barrier(num_barriers=max_concurrent_steps), + plgpu.Barrier(num_barriers=compute_wgs), ), (plgpu.Barrier(num_arrivals=compute_wgs, num_barriers=max_concurrent_steps),) * 2, plgpu.Barrier(num_arrivals=compute_wgs), + collective_axes="wg", ) num_q_tiles, rem = divmod(q_seq_len, block_q * 2) if rem: raise NotImplementedError(f"{q_seq_len=} must be a multiple of {block_q * 2=}") - return plgpu.kernel( + out_shape = [q, None] + if save_residuals: + # Note that we keep seq_len in the minor-most dimension so that we can do + # 1D TMAs on chunks of `block_q`. 
+ out_shape[1] = jax.ShapeDtypeStruct( + (batch_size, num_q_heads, q_seq_len), jnp.float32 + ) + + out, lse = plgpu.kernel( entry, - out_shape=q, + out_shape=out_shape, grid=(batch_size, num_q_tiles, num_q_heads), + grid_names=("batch", "q_seq", "heads"), num_threads=3, - axis_names=("batch", "q_seq", "heads", "wg"), - compiler_params=plgpu.GPUCompilerParams(approx_math=True), + thread_name="wg", + compiler_params=plgpu.CompilerParams(approx_math=True), )(q, k, v) -@functools.partial(jax.jit, static_argnames=["config"]) -def attention_with_pipeline_emitter(q, k, v, config: TuningConfig): + if save_residuals: + assert lse is not None + return out, (lse,) + + return out + +@partial(jax.custom_vjp, nondiff_argnums=(3, 4)) +@partial(jax.jit, static_argnames=["config", "save_residuals"]) +def attention(q, k, v, config: TuningConfig, save_residuals: bool = False): + return _attention_forward(q, k, v, config, save_residuals) + +def _attention_fwd(q, k, v, config: TuningConfig, save_residuals: bool): + del save_residuals + + out, (lse,) = _attention_forward(q, k, v, config, save_residuals=True) + return out, (q, k, v, out, lse) + +def _attention_bwd(config: TuningConfig, save_residuals: bool, res, do): + del save_residuals + q, k, v, out, lse = res + + if config.causal: + raise NotImplementedError("Causal attention not supported in the backwards pass yet.") + + if not config.has_backward_blocks: + raise ValueError("Need to specify backward blocks.") + + assert config.block_q_dq is not None + assert config.block_kv_dq is not None + assert config.block_q_dkv is not None + assert config.block_kv_dkv is not None + + batch_size, q_seq_len, num_q_heads, head_dim = q.shape + _, kv_seq_len, num_kv_heads, _ = k.shape + q_heads_per_kv_head = num_q_heads // num_kv_heads + dtype = q.dtype + compute_wgs = config.compute_wgs_bwd + + num_q_tiles, rem = divmod(q_seq_len, config.block_q_dq * compute_wgs) + if rem: + raise NotImplementedError( + f"{q_seq_len=} must be a multiple of {config.block_q_dq=} * {compute_wgs=}") + + num_kv_tiles, rem = divmod(kv_seq_len, config.block_kv_dkv * compute_wgs) + if rem: + raise NotImplementedError( + f"{kv_seq_len=} must be a multiple of {config.block_kv_dkv=} * {compute_wgs=}") + + num_q_tiles_in_dkv, rem = divmod(q_seq_len, config.block_q_dkv) + if rem: + raise NotImplementedError(f"{q_seq_len=} must be a multiple of {config.block_q_dkv=}") + + num_kv_tiles_in_dq, rem = divmod(kv_seq_len, config.block_kv_dq) + if rem: + raise NotImplementedError(f"{kv_seq_len=} must be a multiple of {config.block_kv_dq=}") + + tiling = plgpu.TilingTransform((8, 64)) + swizzle = plgpu.SwizzleTransform(128) + + delta = jnp.einsum('bqhd,bqhd->bhq', out.astype(jnp.float32), do.astype(jnp.float32)) + del out # Not needed anymore. 
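+  # `delta` above is the FlashAttention backward "D" term: the per-query row sum
+  # of O * dO. Both backward kernels below consume it as dS = P * (dP - delta);
+  # computing it with a plain einsum avoids a dedicated preprocessing kernel.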
+ + def kernel_dq(q_ref, k_ref, v_ref, do_ref, lse_ref, delta_ref, dq_ref, + smem_buffers, buffer_barriers, block_q, block_kv): + batch = lax.axis_index("batch") + q_head = lax.axis_index("heads") + wg_idx = lax.axis_index("wg") + kv_head = lax.div(q_head, jnp.array(q_heads_per_kv_head, q_head.dtype)) + q_smem2, do_smem2, lse_smem2, delta_smem2 = smem_buffers + q_barriers, do_barriers, lse_barriers, delta_barriers = buffer_barriers + def _compute_thread(pipeline_callback): + q_smem, do_smem, lse_smem, delta_smem = q_smem2.at[wg_idx], do_smem2.at[wg_idx], lse_smem2.at[wg_idx], delta_smem2.at[wg_idx] + q_seq_base = lax.axis_index("q_seq") * (compute_wgs * block_q) + wg_idx * block_q + q_slice = (batch, pl.ds(q_seq_base, block_q), q_head) + plgpu.copy_gmem_to_smem(q_ref.at[q_slice], q_smem, q_barriers.at[wg_idx]) + plgpu.copy_gmem_to_smem(do_ref.at[q_slice], do_smem, do_barriers.at[wg_idx]) + plgpu.copy_gmem_to_smem( + delta_ref.at[batch, q_head, pl.ds(q_seq_base, block_q)], + delta_smem, + delta_barriers.at[wg_idx], + ) + plgpu.copy_gmem_to_smem( + lse_ref.at[batch, q_head, pl.ds(q_seq_base, block_q)], + lse_smem, + lse_barriers.at[wg_idx], + ) + for buffer in buffer_barriers: + plgpu.barrier_wait(buffer.at[wg_idx]) + + delta = plgpu.load(delta_smem, (), layout=plgpu.Layout.WGMMA_ROW) + lse = plgpu.load(lse_smem, (), layout=plgpu.Layout.WGMMA_ROW) + dq_acc = plgpu.layout_cast( + jnp.full((block_q, head_dim), 0, dtype=jnp.float32), plgpu.Layout.WGMMA, + ) + dq, _, _ = pipeline_callback((dq_acc, lse, delta)) + q_smem[...] = dq.astype(dtype) + plgpu.commit_smem() + plgpu.copy_smem_to_gmem(q_smem, dq_ref.at[q_slice]) + plgpu.wait_smem_to_gmem(0) + + def kv_pipeline(_, k_smem, v_smem, k_consumed_barrier, v_consumed_barrier, carry): + q_smem, do_smem = q_smem2.at[wg_idx], do_smem2.at[wg_idx] + (dq_acc, lse, delta) = carry + + def compute_s(acc_ref): + plgpu.wgmma(acc_ref, q_smem, plgpu.transpose_ref(k_smem, (1, 0))) + return acc_ref[...] + + s = pl.run_scoped(compute_s, plgpu.ACC((block_q, block_kv), jnp.float32)) + s *= math.log2(math.e) + p = jnp.exp2(s - lax.broadcast_in_dim(lse, (block_q, block_kv), [0])) + + # dP + def compute_dp(acc_ref): + plgpu.wgmma(acc_ref, do_smem, plgpu.transpose_ref(v_smem, (1, 0))) + return acc_ref[...] 
+ + dp = pl.run_scoped(compute_dp, plgpu.ACC((block_q, block_kv), jnp.float32)) + plgpu.barrier_arrive(v_consumed_barrier) + + # dS + ds = p * (dp - lax.broadcast_in_dim(delta, (block_q, block_kv), [0])) + + # dQ + def compute_dq(acc_ref): + plgpu.wgmma(acc_ref, ds.astype(k_ref.dtype), k_smem) + + dq_acc = pl.run_state(compute_dq)(plgpu.ACC.init(dq_acc)) + plgpu.barrier_arrive(k_consumed_barrier) + + return (dq_acc, lse, delta) + + pipeline = plgpu.emit_pipeline_warp_specialized( + kv_pipeline, + grid=(num_kv_tiles_in_dq,), + max_concurrent_steps=min([config.max_concurrent_steps, num_q_tiles]), + num_compute_wgs=compute_wgs, + memory_registers=40, + wg_axis="wg", + manual_consumed_barriers=True, + compute_context=_compute_thread, + in_specs=[ + plgpu.BlockSpec( # k + block_shape=(block_kv, head_dim), + index_map=lambda i: (i, 0), + transforms=[tiling, swizzle]), + plgpu.BlockSpec( # v + block_shape=(block_kv, head_dim), + index_map=lambda i: (i, 0), + transforms=[tiling, swizzle]), + ]) + k_ref = k_ref.at[batch, :, kv_head, :] + v_ref = v_ref.at[batch, :, kv_head, :] + pipeline(k_ref, v_ref) + + def kernel_dkv(q_ref, k_ref, v_ref, do_ref, lse_ref, delta_ref, + dk_ref, dv_ref, smem_buffers, buffer_barriers, block_q: int, block_kv: int): + batch = lax.axis_index("batch") + q_head = lax.axis_index("heads") + wg_idx = lax.axis_index("wg") + (k_smem2, v_smem2) = smem_buffers + (k_barriers, v_barriers) = buffer_barriers + + def _compute_thread(pipeline_callback): + k_smem, v_smem = k_smem2.at[wg_idx], v_smem2.at[wg_idx] + kv_seq_base = lax.axis_index("kv_seq") * (compute_wgs * block_kv) + wg_idx * block_kv + kv_head = lax.div(q_head, jnp.array(q_heads_per_kv_head, q_head.dtype)) + plgpu.copy_gmem_to_smem( + k_ref.at[(batch, pl.ds(kv_seq_base, block_kv), kv_head)], + k_smem, + k_barriers.at[wg_idx]) + plgpu.copy_gmem_to_smem( + v_ref.at[(batch, pl.ds(kv_seq_base, block_kv), kv_head)], + v_smem, + v_barriers.at[wg_idx]) + plgpu.barrier_wait(k_barriers.at[wg_idx]) + plgpu.barrier_wait(v_barriers.at[wg_idx]) + dk_acc = plgpu.layout_cast( + jnp.full((block_kv, head_dim), 0, dtype=jnp.float32), plgpu.Layout.WGMMA, + ) + dv_acc = plgpu.layout_cast( + jnp.full((block_kv, head_dim), 0, dtype=jnp.float32), plgpu.Layout.WGMMA, + ) + (dk, dv) = pipeline_callback((dv_acc, dk_acc)) + k_smem[...] = dk.astype(dtype) + v_smem[...] = dv.astype(dtype) + + plgpu.commit_smem() + plgpu.copy_smem_to_gmem( + k_smem, + dk_ref.at[(batch, pl.ds(kv_seq_base, block_kv), q_head)], + commit_group=False) + plgpu.copy_smem_to_gmem( + v_smem, + dv_ref.at[(batch, pl.ds(kv_seq_base, block_kv), q_head)], + commit_group=False) + plgpu.commit_smem_to_gmem_group() + plgpu.wait_smem_to_gmem(0) + + def q_pipeline(_, q_smem, do_smem, lse_smem, delta_smem, q_consumed_barrier, do_consumed_barrier, lse_consumed_barrier, delta_consumed_barrier, carry): + k_smem, v_smem = k_smem2.at[wg_idx], v_smem2.at[wg_idx] + dk_acc, dv_acc = carry + + def _compute_sT(acc_ref): + plgpu.wgmma(acc_ref, k_smem, plgpu.transpose_ref(q_smem, (1, 0))) + return acc_ref[...] + sT = pl.run_scoped(_compute_sT, plgpu.ACC((block_kv, block_q), jnp.float32)) + sT *= math.log2(math.e) + + lse = plgpu.load(lse_smem, (), layout=plgpu.Layout.WGMMA_COL) + plgpu.barrier_arrive(lse_consumed_barrier) + pT = jnp.exp2(sT - lax.broadcast_in_dim(lse, (block_kv, block_q), [1])) + + def _compute(refs): + # Combining two WGMMA calls in one block to avoid the unnecessary + # synchronization from two `wgmma.wait_group` calls. 
+ dv_acc_ref, dpT_acc_ref = refs + plgpu.wgmma(dv_acc_ref, pT.astype(dtype), do_smem) # dV + plgpu.wgmma(dpT_acc_ref, v_smem, plgpu.transpose_ref(do_smem, (1, 0))) # dpT + + zeros = plgpu.layout_cast( + jnp.full((block_kv, block_q), 0, dtype=jnp.float32), plgpu.Layout.WGMMA, + ) + dv_acc, dpT = pl.run_state(_compute)((plgpu.ACC.init(dv_acc), plgpu.ACC.init(zeros))) + plgpu.barrier_arrive(do_consumed_barrier) + + delta = plgpu.load(delta_smem, (), layout=plgpu.Layout.WGMMA_COL) + plgpu.barrier_arrive(delta_consumed_barrier) + + dsT = pT * (dpT - lax.broadcast_in_dim(delta, (block_kv, block_q), [1])) + + def compute_dk(acc_ref): + plgpu.wgmma(acc_ref, dsT.astype(dtype), q_smem) + + dk_acc = pl.run_state(compute_dk)(plgpu.ACC.init(dk_acc)) + plgpu.barrier_arrive(q_consumed_barrier) + + return (dk_acc, dv_acc) + + pipeline = plgpu.emit_pipeline_warp_specialized( + q_pipeline, + grid=(num_q_tiles_in_dkv,), + max_concurrent_steps=min([config.max_concurrent_steps, num_kv_tiles]), + num_compute_wgs=compute_wgs, + memory_registers=40, + wg_axis="wg", + manual_consumed_barriers=True, + compute_context=_compute_thread, + in_specs=[ + plgpu.BlockSpec( # q + block_shape=(block_q, head_dim), + index_map=lambda i: (i, 0), + transforms=[tiling, swizzle]), + plgpu.BlockSpec( # do + block_shape=(block_q, head_dim), + index_map=lambda i: (i, 0), + transforms=[tiling, swizzle]), + plgpu.BlockSpec(block_shape=(block_q,), index_map=lambda i: (i,)), + plgpu.BlockSpec(block_shape=(block_q,), index_map=lambda i: (i,)) + ]) + q_ref = q_ref.at[batch, :, q_head, :] + do_ref = do_ref.at[batch, :, q_head, :] + lse_ref = lse_ref.at[batch, q_head, :] + delta_ref = delta_ref.at[batch, q_head, :] + pipeline(q_ref, do_ref, lse_ref, delta_ref) + + q_scratch = plgpu.SMEM( + (compute_wgs, config.block_q_dq, head_dim), jnp.float16, + transforms=(tiling, swizzle), + ) + do_scratch = q_scratch + lse_scratch = plgpu.SMEM((compute_wgs, config.block_q_dq), jnp.float32) + delta_scratch = plgpu.SMEM((compute_wgs, config.block_q_dq), jnp.float32) + dq = plgpu.kernel( + partial(kernel_dq, block_q=config.block_q_dq, block_kv=config.block_kv_dq), + out_shape=q, + scratch_shapes=[ + (q_scratch, do_scratch, lse_scratch, delta_scratch), # type: ignore + (plgpu.Barrier(num_barriers=compute_wgs),) * 4 # type: ignore + ], + compiler_params=plgpu.CompilerParams(approx_math=True), + grid=(batch_size, num_q_tiles, num_q_heads), + grid_names=("batch", "q_seq", "heads"), + num_threads=compute_wgs + 1, + thread_name="wg", + )(q, k, v, do, lse, delta) + + k_scratch = plgpu.SMEM( + (compute_wgs, config.block_kv_dkv, head_dim), jnp.float16, + transforms=(tiling, swizzle), + ) + v_scratch = k_scratch + out_shape_kv = jax.ShapeDtypeStruct( + (batch_size, kv_seq_len, num_q_heads, head_dim), dtype=jnp.float16) + dk, dv = plgpu.kernel( + partial(kernel_dkv, block_q=config.block_q_dkv, block_kv=config.block_kv_dkv), + out_shape=[out_shape_kv, out_shape_kv], + scratch_shapes=[ + (k_scratch, v_scratch), # type: ignore + (plgpu.Barrier(num_barriers=compute_wgs),) * 2 # type: ignore + ], + compiler_params=plgpu.CompilerParams(approx_math=True), + grid=(batch_size, num_kv_tiles, num_q_heads), + grid_names=("batch", "kv_seq", "heads"), + num_threads=compute_wgs + 1, + thread_name="wg" + )(q, k, v, do, lse, delta) + + if q_heads_per_kv_head > 1: + sum_shape = (*k.shape[:-1], q_heads_per_kv_head, head_dim) + dk = dk.reshape(sum_shape).astype(jnp.float32).sum(axis=-2).astype(dk.dtype) + dv = dv.reshape(sum_shape).astype(jnp.float32).sum(axis=-2).astype(dv.dtype) + + 
return dq, dk, dv + +attention.defvjp(_attention_fwd, _attention_bwd) + +@functools.partial(jax.jit, static_argnames=["config", "save_residuals"]) +def attention_with_pipeline_emitter(q, k, v, config: TuningConfig, save_residuals=False): + if config.causal: + raise NotImplementedError("Causal attention is not supported with the pipeline emitter yet.") if q.ndim != 4 or k.ndim != 4 or v.ndim != 4: raise ValueError(f"q, k, and v should all be 4D, got: {q.ndim=}, {k.ndim=}, {v.ndim=}") batch_size, q_seq_len, num_q_heads, head_dim = q.shape @@ -262,14 +666,10 @@ def attention_with_pipeline_emitter(q, k, v, config: TuningConfig): if rem: raise NotImplementedError(f"{q_seq_len=} must be a multiple of {block_q * 2=}") - tiling = plgpu.TilingTransform((64, 64)) - swizzle = plgpu.SwizzleTransform(128) - transpose = plgpu.TransposeTransform((0, 2, 1, 3, 4)) - - def fa3_kernel(q_ref, k_ref, v_ref, out_ref, scoped): + def fa3_kernel(q_ref, k_ref, v_ref, out_ref, lse_ref, smem_buffers, q_barriers, schedule_barrier): batch = lax.axis_index("batch") wg_idx = lax.axis_index("wg") - qo_smem2, q_barriers, schedule_barrier = scoped + qo_smem2, lse_smem2 = smem_buffers q_seq_base = lax.axis_index("q_seq") * (2 * block_q) + wg_idx * block_q q_head = lax.axis_index("heads") kv_head = lax.div(q_head, jnp.array(q_heads_per_kv_head, q_head.dtype)) @@ -279,17 +679,12 @@ def perform_schedule_barrier(): plgpu.barrier_arrive(schedule_barrier) plgpu.barrier_wait(schedule_barrier) - def _compute_thread(): + def _compute_thread(pipeline_callback): qo_smem = qo_smem2.at[wg_idx] - m_i = plgpu.layout_cast( - jnp.full((block_q,), -jnp.inf, dtype=jnp.float32), plgpu.Layout.WGMMA_ROW, - ) - l_i = plgpu.layout_cast( - jnp.full((block_q,), 0, dtype=jnp.float32), plgpu.Layout.WGMMA_ROW, - ) - acc = plgpu.layout_cast( - jnp.full((block_q, head_dim), 0, dtype=jnp.float32), plgpu.Layout.WGMMA, - ) + lse_smem = lse_smem2.at[wg_idx] if lse_smem2 is not None else None + m_i = jnp.full((block_q,), -jnp.inf, dtype=jnp.float32) + l_i = jnp.full((block_q,), 0, dtype=jnp.float32) + acc = jnp.full((block_q, head_dim), 0, dtype=jnp.float32) # Q is not pipelined, so we load in with a manual DMA. plgpu.copy_gmem_to_smem( q_ref.at[batch, pl.ds(q_seq_base, block_q), q_head], @@ -298,19 +693,27 @@ def _compute_thread(): ) plgpu.barrier_wait(q_barriers.at[wg_idx]) pl.when(wg_idx == 1)(perform_schedule_barrier) - final_carry = (yield (acc, m_i, l_i)) - del m_i # Unused + final_carry = pipeline_callback((acc, m_i, l_i)) pl.when(wg_idx == 0)(perform_schedule_barrier) - acc, _, l_i = final_carry + acc, m_i, l_i = final_carry acc /= lax.broadcast_in_dim(l_i, (block_q, head_dim), [0]) qo_smem[...] = acc.astype(dtype) + if lse_smem is not None: + RCP_LN2 = 1.4426950408889634 + log2 = lambda x: jnp.log(x) * RCP_LN2 + lse_smem[...] 
= m_i + log2(l_i) plgpu.commit_smem() plgpu.copy_smem_to_gmem( qo_smem, out_ref.at[batch, pl.ds(q_seq_base, block_q), q_head], ) + if lse_smem is not None: + plgpu.copy_smem_to_gmem( + lse_smem, + lse_ref.at[batch, q_head, pl.ds(q_seq_base, block_q)], + ) plgpu.wait_smem_to_gmem(0) - def kv_pipeline(k_smem, v_smem, + def kv_pipeline(_, k_smem, v_smem, k_consumed_barrier, v_consumed_barrier, carry): acc, m_i, l_i = carry @@ -348,66 +751,82 @@ def compute_pv(acc_ref): memory_registers=40, wg_axis="wg", manual_consumed_barriers=True, - carry_coroutine=_compute_thread, + compute_context=_compute_thread, in_specs=[ - plgpu.GPUBlockSpec( # k + plgpu.BlockSpec( # k block_shape=(block_kv, head_dim), - index_map=lambda i: (i, 0), - transforms=[tiling, transpose, swizzle]), - plgpu.GPUBlockSpec( # v + index_map=lambda i: (i, 0)), + plgpu.BlockSpec( # v block_shape=(block_kv, head_dim), - index_map=lambda i: (i, 0), - transforms=[tiling, swizzle]), + index_map=lambda i: (i, 0)), ], out_specs=[], ) k_ref = k_ref.at[batch, :, kv_head, :] v_ref = v_ref.at[batch, :, kv_head, :] pipeline(k_ref, v_ref) - mesh = plgpu.GPUMesh( + + out_shape = [q, None] + if save_residuals: + out_shape[1] = jax.ShapeDtypeStruct((batch_size, num_q_heads, q_seq_len), jnp.float32) + + qo_scratch = plgpu.SMEM((compute_wgs, block_q, head_dim), jnp.float16) + smem_scratch = [qo_scratch, None] + if save_residuals: + smem_scratch[1] = plgpu.SMEM((compute_wgs, block_q), jnp.float32) + + out, lse = plgpu.kernel( + fa3_kernel, grid=(batch_size, num_q_tiles, num_q_heads), + grid_names=("batch", "q_seq", "heads"), num_threads=3, - axis_names=("batch", "q_seq", "heads", "wg"), - ) - def run(refs): - q_ref, k_ref, v_ref, out_ref = refs - @pl.core_map(mesh, - compiler_params=plgpu.GPUCompilerParams(approx_math=True), - ) - def _kernel_entry(): - qo_scratch = plgpu.SMEM( - (compute_wgs, block_q, head_dim), jnp.float16, - transforms=(tiling, swizzle), - ) - pl.run_scoped( - lambda *args: fa3_kernel(q_ref, k_ref, v_ref, out_ref, args), - qo_scratch, - plgpu.Barrier(1, num_barriers=compute_wgs), - plgpu.Barrier(num_arrivals=compute_wgs), - ) - @jax.jit - def run_function(q, k, v, o): - _, _, _, out = pl.run_state(run)((q, k, v, o)) - return out - out = run_function(q, k, v, jnp.full_like(q, jnp.inf)) + thread_name="wg", + out_shape=out_shape, + scratch_shapes=( + tuple(smem_scratch), # type: ignore + plgpu.Barrier(num_barriers=compute_wgs), # type: ignore + plgpu.Barrier(num_arrivals=compute_wgs),), # type: ignore + compiler_params=plgpu.CompilerParams( + approx_math=True, lowering_semantics=plgpu.LoweringSemantics.Warpgroup, + ), + )(q, k, v) + + if save_residuals: + assert lse is not None + return out, (lse,) + return out -@jax.jit -def attention_reference(q, k, v): +@functools.partial(jax.jit, static_argnames=["causal", "save_residuals"]) +def attention_reference(q, k, v, causal=False, save_residuals=False): batch_size, q_seq_len, num_q_heads, head_dim = q.shape - num_kv_heads = k.shape[2] + kv_seq_len, num_kv_heads = k.shape[1], k.shape[2] q, k, v = map(lambda x: x.astype(jnp.float32), (q, k, v)) q_reshaped = q.reshape( batch_size, q_seq_len, num_kv_heads, num_q_heads // num_kv_heads, head_dim ) logits = jnp.einsum("bqHhc,bkHc->bqHhk", q_reshaped, k) + + if causal: + mask = jnp.arange(q_seq_len)[:, None] >= jnp.arange(kv_seq_len)[None, :] + mask = jnp.broadcast_to(mask[:, None, None, :], logits.shape) + logits = jnp.where(mask, logits, -jnp.inf) + m = logits.max(axis=-1, keepdims=True) unnormalized = jnp.exp(logits - m) l = 
unnormalized.sum(axis=-1, keepdims=True) weights = unnormalized / l - return jnp.einsum("bqHhk,bkHc->bqHhc", weights, v).reshape(*q.shape) - + out = jnp.einsum("bqHhk,bkHc->bqHhc", weights, v).reshape(*q.shape) + + if save_residuals: + log2e = math.log2(math.e) + l = l.reshape(*q.shape[:-1]) + m = m.reshape(*q.shape[:-1]) + lse = m * log2e + jnp.log2(l) + return out, (lse.swapaxes(-1, -2),) + else: + return out def main(unused_argv): num_q_heads = 16 @@ -421,11 +840,18 @@ def main(unused_argv): schedule_barrier_opts = (True,) problem_it = itertools.product( - (1,), (4096, 32768,), (64, 128, 256,), schedule_barrier_opts) - for batch_size, seq_len, head_dim, use_schedule_barrier in problem_it: + (1,), (4096, 32768,), (64, 128, 256,), schedule_barrier_opts, (False, True)) + for batch_size, seq_len, head_dim, use_schedule_barrier, causal in problem_it: + cuda_runtime_version = cuda_versions.cuda_runtime_get_version() + # TODO(pobudzey): Undo when we upgrade to cuda 12.9.1. + if causal and cuda_runtime_version >= 12080 and cuda_runtime_version < 12091: + continue + + if causal and use_pipeline_emitter: + continue q_seq_len = kv_seq_len = seq_len print(f"==== {batch_size=:<6} {kv_seq_len=:<6} {q_seq_len=:<6}" - f"{num_q_heads=:<4} {head_dim=:<6} {use_schedule_barrier=:} ====") + f"{num_q_heads=:<4} {head_dim=:<6} {use_schedule_barrier=:} {causal=:} ====") k1, k2, k3 = jax.random.split(jax.random.key(42), 3) q = jax.random.normal(k1, (batch_size, q_seq_len, num_q_heads, head_dim), jnp.float16) k = jax.random.normal(k2, (batch_size, kv_seq_len, num_kv_heads, head_dim), jnp.float16) @@ -433,11 +859,11 @@ def main(unused_argv): block_q = 64 best = None for block_kv in (256, 128, 64): - config = TuningConfig(block_q=block_q, block_kv=block_kv, max_concurrent_steps=2, use_schedule_barrier=use_schedule_barrier) + config = TuningConfig(block_q=block_q, block_kv=block_kv, max_concurrent_steps=2, use_schedule_barrier=use_schedule_barrier, causal=causal) try: out, runtime_ms = profiler.measure(functools.partial(attention_impl, config=config))(q, k, v) if seq_len < 32768: - out_ref = attention_reference(q, k, v) + out_ref = attention_reference(q, k, v, causal=causal) np.testing.assert_allclose(out, out_ref, atol=2e-3, rtol=1e-3) except ValueError as e: if "exceeds available shared memory" in e.args[0]: @@ -447,6 +873,8 @@ def main(unused_argv): matmul_flops = ( 4 * q_seq_len * kv_seq_len * head_dim * num_q_heads * batch_size ) + if causal: + matmul_flops //= 2 peak_flops = 1e15 # f16 TensorCore peak = 1000TFLOPS optimal_time = matmul_flops / peak_flops * 1e6 # us achieved_tc_util = optimal_time / runtime_us * 100 diff --git a/jax/experimental/pallas/ops/gpu/blackwell_matmul_mgpu.py b/jax/experimental/pallas/ops/gpu/blackwell_matmul_mgpu.py new file mode 100644 index 000000000000..ad210066c5e0 --- /dev/null +++ b/jax/experimental/pallas/ops/gpu/blackwell_matmul_mgpu.py @@ -0,0 +1,247 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Matrix Multiplication kernel for Blackwell GPUs.""" +import dataclasses +import functools +import itertools +import jax +from jax import lax +from jax._src import test_util as jtu # noqa: F401 +from jax.experimental.mosaic.gpu import profiler +import jax.experimental.pallas as pl +import jax.experimental.pallas.mosaic_gpu as plgpu +import jax.numpy as jnp +import numpy as np + + +@dataclasses.dataclass(frozen=True) +class TuningConfig: + tile_m: int + tile_n: int + tile_k: int + max_concurrent_steps: int + collective: bool + + +def _find_swizzle(dim_size_bits: int): + """Finds the largest swizzle that fits the dimension size.""" + for swizzle_bytes in (128, 64, 32, 16): + if dim_size_bits % (swizzle_bytes * 8) == 0: + return swizzle_bytes + raise ValueError( + f"Dimension size has {dim_size_bits} bits, which is not a multiple of 128" + ) + + +def matmul_kernel(a, b, config: TuningConfig): + dtype = a.dtype + if a.dtype != b.dtype: + raise ValueError( + f"Matmul LHS and RHS have incompatible dtypes {a.dtype} vs {b.dtype}" + ) + m, k = a.shape + k2, n = b.shape + if k != k2: + raise ValueError( + f"Matmul LHS and RHS have incompatible shapes {a.shape} vs {b.shape}" + ) + collective = config.collective + tile_m, tile_n, tile_k = (config.tile_m, config.tile_n, config.tile_k) + block_tile_m = tile_m + block_tile_n = tile_n + if collective: + tile_m *= 2 + tile_n *= 2 + swizzle = _find_swizzle(tile_k * jnp.dtype(dtype).itemsize * 8) + swizzle_elems = swizzle // jnp.dtype(dtype).itemsize + transforms = ( + plgpu.TilingTransform((8, swizzle_elems)), + plgpu.SwizzleTransform(swizzle), + ) + block_lhs = (block_tile_m, tile_k) + block_rhs = (tile_k, block_tile_n) + block_out = (block_tile_m, tile_n) + if m % tile_m != 0: + raise ValueError(f"{m=} must be divisible by {tile_m=}") + if n % tile_n != 0: + raise ValueError(f"{n=} must be divisible by {tile_n=}") + if k % tile_k != 0: + raise ValueError(f"{k=} must be divisible by {tile_k=}") + m_iters = m // tile_m + n_iters = n // tile_n + k_iters = k // tile_k + max_concurrent_steps = config.max_concurrent_steps + + TMA_WARP = 0 + MMA_WARP = 1 + + def kernel(a_gmem, b_gmem, out_gmem, + a_smem, b_smem, acc_tmem, acc_smem, + a_tma_barrier, b_tma_barrier, consumed_barrier): + m_index = lax.axis_index("m") + n_index = lax.axis_index("n") + if collective: + cluster_idx = lax.axis_index("x") + block_m_index = m_index * 2 + cluster_idx + is_lead_block = cluster_idx == 0 + else: + block_m_index = m_index + is_lead_block = True + block_slice_m = pl.ds(block_m_index * block_tile_m, block_tile_m) + slice_m = pl.ds(m_index * tile_m, tile_m) + slice_n = pl.ds(n_index * tile_n, tile_n) + + @pl.core_map(plgpu.WarpMesh(axis_name="warp")) + def _per_warp(): + warp_id = lax.axis_index("warp") + @pl.when(warp_id == TMA_WARP) + def _memory(): + def _loop_body(ki, _): + slice_k = pl.ds(ki * tile_k, tile_k) + slot = lax.rem(ki, max_concurrent_steps) + @pl.when(ki >= max_concurrent_steps) + def _(): + plgpu.barrier_wait(consumed_barrier.at[slot]) + plgpu.copy_gmem_to_smem( + a_gmem.at[slice_m, slice_k], + a_smem.at[slot], + a_tma_barrier.at[slot], + partitioned_axis=0 if collective else None, + collective_axes="x" if collective else None, + ) + plgpu.copy_gmem_to_smem( + b_gmem.at[slice_k, slice_n], + b_smem.at[slot], + b_tma_barrier.at[slot], + partitioned_axis=1 if collective else None, + collective_axes="x" if collective else None, + ) + + lax.fori_loop(0, k_iters, _loop_body, None) + + @pl.when(jnp.logical_and(warp_id == MMA_WARP, is_lead_block)) + def _compute(): + 
def _loop_body(ki, _): + slot = lax.rem(ki, max_concurrent_steps) + plgpu.barrier_wait(a_tma_barrier.at[slot]) + plgpu.barrier_wait(b_tma_barrier.at[slot]) + + is_last_iter = ki >= k_iters - 1 + barrier_slot = lax.select_n(is_last_iter, + slot, max_concurrent_steps) + plgpu.tcgen05_mma( + acc_tmem, + a_smem.at[slot], + b_smem.at[slot], + consumed_barrier.at[barrier_slot], + accumulate=(ki > 0), + collective_axis="x" if collective else None, + ) + + lax.fori_loop(0, k_iters, _loop_body, None) + + plgpu.barrier_wait(consumed_barrier.at[max_concurrent_steps]) + acc_smem[...] = acc_tmem[...].astype(dtype) + plgpu.commit_smem() + plgpu.copy_smem_to_gmem(acc_smem, out_gmem.at[block_slice_m, slice_n]) + plgpu.wait_smem_to_gmem(0) + + f = plgpu.kernel( + kernel, + out_shape=jax.ShapeDtypeStruct((m, n), dtype), + # n, m generally works better for most shapes. + grid=(n_iters, m_iters), + grid_names=("n", "m"), + cluster_names=("x",) if collective else (), + cluster=(2,) if collective else (), + scratch_shapes=( # type: ignore + plgpu.SMEM( + (max_concurrent_steps, *block_lhs), dtype, transforms=transforms + ), + plgpu.SMEM( + (max_concurrent_steps, *block_rhs), dtype, transforms=transforms + ), + plgpu.TMEM(block_out, jnp.float32, collective=collective), + plgpu.SMEM(block_out, dtype, transforms=transforms), + plgpu.Barrier(num_arrivals=1, num_barriers=max_concurrent_steps), + plgpu.Barrier(num_arrivals=1, num_barriers=max_concurrent_steps), + plgpu.Barrier( + num_arrivals=1, + num_barriers=max_concurrent_steps + 1, + for_tensor_core=True, + ), + ), + ) + return f(a, b) + + +def main(_) -> None: + problem_it = itertools.product( + (1024, 4096, 8192), (1024, 4096, 8192), (1024, 8192) + ) + for M, N, K in problem_it: + print(f"==== {M=} {N=} {K=} ====") + matmul_flops = 2 * M * N * K + peak_flops = 2.25e15 # f16 TensorCore peak = 2250 TFLOPS + a = jax.random.uniform(jax.random.key(0), (M, K), jnp.float16) + b = jax.random.uniform(jax.random.key(1), (K, N), jnp.float16) + tuning_it = itertools.product( + (128,), # tile_m + (128, 256), # tile_n + (64, 128), # tile_k + (2, 3, 4, 6), # max_concurrent_steps + (False, True), # collective + ) + best_util = -float("inf") + for (tile_m, tile_n, tile_k, + max_concurrent_steps, collective) in tuning_it: + config = TuningConfig( + tile_m=tile_m, + tile_n=tile_n, + tile_k=tile_k, + max_concurrent_steps=max_concurrent_steps, + collective=collective, + ) + try: + out, runtime_ms = profiler.measure( + functools.partial(matmul_kernel, config=config) + )(a, b) + except ValueError as e: + if ("exceeds available shared memory" in e.args[0] or + "Accumulator layout mismatch:" in e.args[0]): + # Accumulator layout mismatch triggers for tile_n=256 on some configs. 
+ continue + raise + if M * N * K <= 1024 * 1024 * 1024: + expected = a @ b + np.testing.assert_allclose(out, expected) + runtime_us = runtime_ms * 1e3 # type: ignore + optimal_time = matmul_flops / peak_flops * 1e6 # us + achieved_tc_util = optimal_time / runtime_us * 100 + if achieved_tc_util > best_util: + best_util = achieved_tc_util + print( + f"{tile_m=} {tile_n=} {tile_k=} {max_concurrent_steps=} " + f"{collective=} : " + f"{runtime_us:<7.1f}us" + f" = {achieved_tc_util:4.1f}% TC utilization" + ) + print(f"\tBest utilization: {best_util:4.1f}%") + + +if __name__ == "__main__": + from absl import app + + jax.config.config_with_absl() + app.run(main) diff --git a/jax/experimental/pallas/ops/gpu/collective_matmul_mgpu.py b/jax/experimental/pallas/ops/gpu/collective_matmul_mgpu.py new file mode 100644 index 000000000000..5e4dda4494ba --- /dev/null +++ b/jax/experimental/pallas/ops/gpu/collective_matmul_mgpu.py @@ -0,0 +1,188 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""A collective matmul kernel implemented using Mosaic GPU.""" + +import functools +import jax +from jax import lax +from jax.experimental import pallas as pl +from jax.experimental.pallas import mosaic_gpu as plgpu +import jax.numpy as jnp + + +def _find_swizzle(dim_size_bits: int, what: str): + for swizzle_bytes in (128, 64, 32, 16): + if dim_size_bits % (swizzle_bytes * 8) == 0: + return swizzle_bytes + raise ValueError( + f"No valid out swizzle for {what}: its minor dimension has" + f" {dim_size_bits} bits, which is not a multiple of 128" + ) + + +# TODO(apaszke): Add grid tiling +def all_gather_lhs_matmul( + lhs: jax.Array, + rhs: jax.Array, + axis_name, + *, + block_m: int, + block_n: int, + block_k: int, + max_concurrent_steps: int, + dtype: jnp.dtype = jnp.float16, +) -> jax.Array: + if (num_devices := jax.device_count()) != jax.process_count(): + raise ValueError("The kernel only supports one device per process") + if (axis_size := lax.axis_size(axis_name)) != num_devices: + raise ValueError("The kernel can only work over all devices in a Mesh.") + if max_concurrent_steps < 2: + raise ValueError("max_concurrent_steps must be >= 2") + if jnp.dtype(dtype) not in map(jnp.dtype, [jnp.float16, jnp.bfloat16]): + raise NotImplementedError(f"Only f16 and bf16 are supported, got dtype: {dtype}") + + num_sms = 132 # There are 132 SMs on a H100 SXM GPU. + + m_shard, k = lhs.shape + k2, n_shard = rhs.shape + if k != k2: + raise ValueError( + f"lhs and rhs must have the same contraction size, got {k} and {k2}." + ) + if (element_type := lhs.dtype) != rhs.dtype: + raise ValueError( + f"lhs and rhs must have the same element type, got {element_type} and" + f" {rhs.dtype}." 
+ ) + if k % block_k != 0: + raise NotImplementedError(f"k={k} must be a multiple of block_k={block_k}") + if m_shard % block_m != 0: + raise NotImplementedError(f"m_shard={m_shard} must be a multiple of block_m={block_m}") + if n_shard % block_n != 0: + raise NotImplementedError(f"n_shard={n_shard} must be a multiple of block_n={block_n}") + if n_shard != block_n: + raise NotImplementedError( + f"n_shard={n_shard} must be equal to block_n={block_n}" + ) + + swizzle = min( + _find_swizzle(block_k * jnp.finfo(element_type).bits, "lhs"), + _find_swizzle(block_n * jnp.finfo(element_type).bits, "rhs"), + ) + transforms = ( + plgpu.TilingTransform((8, swizzle // jnp.dtype(element_type).itemsize)), + plgpu.SwizzleTransform(swizzle), + ) + + def kernel_body(lhs_ref, rhs_ref, out_ref, scratch_ref, capacity_sem, received_sem): + sm_id = lax.axis_index('sm') + scratch_ref = scratch_ref.at[sm_id] + + dev_id = lax.axis_index(axis_name) + send_dev_id = lax.rem(dev_id + axis_size - 1, axis_size) + recv_dev_id = lax.rem(dev_id + 1, axis_size) + # NOTE: Technically we should signal the recv_dev_id (and our signal would + # be received from send_dev_id), but if everyone signals in a ring after a + # barrier then it's equivalent to a local signal. + pl.semaphore_signal(capacity_sem) + send_scratch_ref = plgpu.remote_ref( + scratch_ref, send_dev_id, device_id_type=pl.DeviceIdType.LOGICAL + ) + + def m_loop(mi, _): + mi = mi * lax.axis_size('sm') + sm_id + m_tile_slice = pl.ds(mi * block_m, block_m) + + # For some reason ptxas spills if we unroll the loop over k + copy_block = 32 + @pl.loop(0, k // copy_block) + def _k_copy_loop(ki): + k_slice = pl.ds(ki * copy_block, copy_block) + scratch_ref[0, :, k_slice] = lhs_ref[m_tile_slice, k_slice] + + @pl.loop(0, num_devices) + def _device_loop(device_offset): + # Loop invariant: scratch_ref.at[scratch_slot] is ready to be used + # We're double buffering the scratch space. At each step, we read from + # scratch_ref.at[scratch_slot] and write to scratch_ref.at[next_scratch_slot] + # located on the send_dev_id. We swap the slots after completing a step, + # which lets us overlap the copy with compute. + scratch_slot = lax.rem(device_offset, 2) + next_scratch_slot = 1 - scratch_slot + + @functools.partial( + pl.run_scoped, + acc_ref=plgpu.ACC((block_m, block_n)), + out_smem=plgpu.SMEM((block_m, block_n), dtype, transforms=transforms), + ) + def _(acc_ref, out_smem): + pl.semaphore_wait(capacity_sem) + @functools.partial( + plgpu.emit_pipeline, + grid=(k // block_k,), + in_specs=[ + plgpu.BlockSpec((block_m, block_k), lambda k: (0, k), transforms=transforms), + plgpu.BlockSpec((block_k, block_n), lambda k: (k, 0), transforms=transforms), + ], + max_concurrent_steps=max_concurrent_steps, + delay_release=1, + ) + def k_loop(idxs, lhs_smem, rhs_smem): + (ki,) = idxs + plgpu.wgmma(acc_ref, lhs_smem, rhs_smem) + k_slice = pl.ds(ki * block_k, block_k) + # TODO(apaszke): No need to send on the last step + plgpu.copy_smem_to_gmem( + lhs_smem, send_scratch_ref.at[next_scratch_slot, :, k_slice] + ) + # We only delay release by 1 step, so we need to wait for the + # previous copies. + plgpu.wait_smem_to_gmem(1, wait_read_only=True) + k_loop(scratch_ref.at[scratch_slot], rhs_ref) + # Make sure the copy is fully done. + plgpu.wait_smem_to_gmem(0, wait_read_only=False) + # TODO(apaszke): Both of those semaphores perform a .sys release. + # This is very expensive and we should only do a single .sys fence. 
+ pl.semaphore_signal(capacity_sem, device_id=recv_dev_id, device_id_type=pl.DeviceIdType.LOGICAL) + pl.semaphore_signal(received_sem, device_id=send_dev_id, device_id_type=pl.DeviceIdType.LOGICAL) + # Make sure all TMAs have read SMEM before we overwrite it. + plgpu.wait_smem_to_gmem(0, wait_read_only=True) + out_smem[...] = acc_ref[...].astype(out_smem.dtype) + plgpu.commit_smem() + device_m_slice = pl.ds( + lax.rem(device_offset + dev_id, num_devices) * m_shard, block_m + ) + plgpu.copy_smem_to_gmem( + out_smem, out_ref.at[device_m_slice].at[m_tile_slice] + ) + # Wait for the next scratch to arrive --- see the loop invariant. + pl.semaphore_wait(received_sem) + + grid_size = m_shard // block_m + m_steps = grid_size // num_sms + jnp.int32(sm_id < grid_size % num_sms) + # TODO(apaszke): Use the ND-loop helper. + jax.lax.fori_loop(0, m_steps, m_loop, None) + + result, _ = plgpu.kernel( + kernel_body, + out_shape=[jax.ShapeDtypeStruct((axis_size * m_shard, n_shard), dtype), + jax.ShapeDtypeStruct((num_sms, 2, block_m, k), dtype)], + scratch_shapes=[ + plgpu.SemaphoreType.REGULAR, plgpu.SemaphoreType.REGULAR, + ], + grid=(num_sms,), + grid_names=('sm',), + )(lhs, rhs) + return result diff --git a/jax/experimental/pallas/ops/gpu/decode_attention.py b/jax/experimental/pallas/ops/gpu/decode_attention.py index e2c19b3eaf2d..ee8c22d1b3a4 100644 --- a/jax/experimental/pallas/ops/gpu/decode_attention.py +++ b/jax/experimental/pallas/ops/gpu/decode_attention.py @@ -193,7 +193,7 @@ def decode_attn_unbatched( pl.BlockSpec((None, block_h), lambda i, j: (j, i)), # l pl.BlockSpec((None, block_h), lambda i, j: (j, i)), # m ], - compiler_params=plgpu.TritonCompilerParams( + compiler_params=plgpu.CompilerParams( num_warps=num_warps_, num_stages=num_stages ), out_shape=[ diff --git a/jax/experimental/pallas/ops/gpu/layer_norm.py b/jax/experimental/pallas/ops/gpu/layer_norm.py index d37afaf4d9e0..b838885a9136 100644 --- a/jax/experimental/pallas/ops/gpu/layer_norm.py +++ b/jax/experimental/pallas/ops/gpu/layer_norm.py @@ -94,7 +94,7 @@ def layer_norm_forward( ] method = pl.pallas_call( kernel, - compiler_params=plgpu.TritonCompilerParams(num_warps=num_warps), + compiler_params=plgpu.CompilerParams(num_warps=num_warps), grid=(), out_shape=out_shape, debug=False, @@ -215,7 +215,7 @@ def layer_norm_backward( out_shape_dx = jax.ShapeDtypeStruct(shape=(n,), dtype=x.dtype) method = pl.pallas_call( kernel, - compiler_params=plgpu.TritonCompilerParams(num_warps=num_warps), + compiler_params=plgpu.CompilerParams(num_warps=num_warps), grid=(), out_shape=out_shape_dx, debug=False, @@ -247,7 +247,7 @@ def layer_norm_backward( grid_ = (pl.cdiv(reshaped_x.shape[1], block_n),) method = pl.pallas_call( kernel, - compiler_params=dict(triton=dict(num_warps=num_warps)), + compiler_params=plgpu.CompilerParams(num_warps=num_warps), grid=grid_, out_shape=out_shape_dwbias, debug=False, @@ -283,7 +283,7 @@ def layer_norm( out_shape = jax.ShapeDtypeStruct(shape=(n,), dtype=x.dtype) method = pl.pallas_call( kernel, - compiler_params=plgpu.TritonCompilerParams( + compiler_params=plgpu.CompilerParams( num_warps=num_warps, num_stages=num_stages), grid=(), out_shape=out_shape, diff --git a/jax/experimental/pallas/ops/gpu/paged_attention.py b/jax/experimental/pallas/ops/gpu/paged_attention.py index b30ef554fe12..ca21761cf3ed 100644 --- a/jax/experimental/pallas/ops/gpu/paged_attention.py +++ b/jax/experimental/pallas/ops/gpu/paged_attention.py @@ -33,7 +33,9 @@ def paged_attention_kernel( # inputs q_ref, # [block_h, head_dim] k_pages_ref, 
# [total_num_pages, page_size, head_dim] + k_scales_pages_ref, # [total_num_pages, page_size] v_pages_ref, # [total_num_pages, page_size, head_dim] + v_scales_pages_ref, # [total_num_pages, page_size] block_tables_ref, # [pages_per_partition] lengths_ref, # [1] # outputs @@ -65,7 +67,16 @@ def body(start_k, carry): block_tables = pl.load(block_tables_ref, block_tables_slice) k = k_pages_ref[block_tables].reshape(block_k, head_dim) v = v_pages_ref[block_tables].reshape(block_k, head_dim) + if k_scales_pages_ref is not None: + # dynamic lhs quantized dot is not currently implemented + # so we cast rhs to the lhs dtype + k = k.astype(q.dtype) uncapped_logits = pl.dot(q, k.T) # [block_h, block_k] + if k_scales_pages_ref is not None: + # k_scales_pages_ref are one per head + # they're laid out across the output dimension, so scale output + k_scale = k_scales_pages_ref[block_tables].reshape((1, block_k)) + uncapped_logits *= k_scale.astype(uncapped_logits.dtype) if attn_logits_soft_cap is not None: logits = jnp.tanh(uncapped_logits / attn_logits_soft_cap) logits = logits * attn_logits_soft_cap @@ -92,6 +103,14 @@ def body(start_k, carry): l_curr = s_curr.sum(axis=-1) l_next = l_prev_corr + l_curr o_prev_corr = correction[:, None] * o_prev + if v_scales_pages_ref is not None: + # v_scales are 1 per head + # they're laid out across the reduction dimension, so scale lhs + v_scale = v_scales_pages_ref[block_tables].reshape((1, block_k)) + s_curr *= v_scale.astype(s_curr.dtype) + # dynamic lhs quantized dot is not currently implemented + # so we cast rhs to the lhs dtype + v = v.astype(s_curr.dtype) o_curr = pl.dot(s_curr.astype(v.dtype), v) o_next = o_prev_corr + o_curr @@ -134,6 +153,8 @@ def paged_attention_unbatched( v_pages: jax.Array, # [num_kv_heads, total_num_pages, page_size, head_dim] block_tables: jax.Array, # [pages_per_sequence] lengths: jax.Array | None, # [1] + k_scales_pages: jax.Array | None = None, # [num_kv_heads, total_num_pages, page_size] + v_scales_pages: jax.Array | None = None, # [num_kv_heads, total_num_pages, page_size] *, block_h: int, pages_per_compute_block: int, @@ -179,6 +200,19 @@ def paged_attention_unbatched( mask_value=mask_value, attn_logits_soft_cap=attn_logits_soft_cap, ) + # set up quantization scales + if k_scales_pages is not None: + assert k_scales_pages.shape == (num_kv_heads, total_num_pages, page_size) + k_scales_spec = pl.BlockSpec((None, total_num_pages, page_size), + lambda h, i, k: (h, 0, 0)) + else: + k_scales_spec = None + if v_scales_pages is not None: + assert v_scales_pages.shape == (num_kv_heads, total_num_pages, page_size) + v_scales_spec = pl.BlockSpec((None, total_num_pages, page_size), + lambda h, i, k: (h, 0, 0)) + else: + v_scales_spec = None o, l, m = pl.pallas_call( kernel, @@ -191,10 +225,12 @@ def paged_attention_unbatched( (None, total_num_pages, page_size, head_dim), lambda h, i, k: (h, 0, 0, 0), ), # k_pages + k_scales_spec, # k_pages_scale pl.BlockSpec( (None, total_num_pages, page_size, head_dim), lambda h, i, k: (h, 0, 0, 0), ), # v_pages + v_scales_spec, # v_pages_scale pl.BlockSpec( (None, pages_per_partition), lambda h, i, k: (k, 0) ), # block_tables @@ -222,11 +258,11 @@ def paged_attention_unbatched( ], debug=debug, interpret=interpret, - compiler_params=plgpu.TritonCompilerParams( + compiler_params=plgpu.CompilerParams( num_warps=num_warps, num_stages=num_stages ), name=f"paged_attention_{block_h=}_{pages_per_compute_block=}", - )(q_reshaped, k_pages, v_pages, block_tables, lengths) + )(q_reshaped, k_pages, k_scales_pages, 
v_pages, v_scales_pages, block_tables, lengths) if q_heads_per_kv_head % block_h: o = o[..., :q_heads_per_kv_head, :] @@ -265,6 +301,8 @@ def paged_attention( v_pages: jax.Array, block_tables: jax.Array, lengths: jax.Array | None, + k_scales_pages: jax.Array | None = None, + v_scales_pages: jax.Array | None = None, *, block_h: int = 16, pages_per_compute_block: int = 8, @@ -286,6 +324,8 @@ def paged_attention( should be in the range of [0, total_num_pages), indicating where to locate the page in `k_pages` or `v_pages`. lengths: A i32[batch_size] jax.Array the length of each example. + k_scales_pages: A [num_kv_heads, total_num_pages, page_size] jax.Array. + v_scales_pages: A [num_kv_heads, total_num_pages, page_size] jax.Array. block_h: int The block size that partitions the number of head groups. pages_per_compute_block: int The maximum number of blocks per compute block. k_splits: int Number of partitions used to parallelize key-value sequence @@ -342,12 +382,14 @@ def paged_attention( attn_logits_soft_cap=attn_logits_soft_cap, ) - o = jax.vmap(impl, (0, None, None, 0, 0), 0)( + o = jax.vmap(impl, (0, None, None, 0, 0, None, None), 0)( q, k_pages, v_pages, block_tables, lengths[..., None] if lengths is not None else None, + k_scales_pages, + v_scales_pages, ) return o diff --git a/jax/experimental/pallas/ops/gpu/ragged_dot_mgpu.py b/jax/experimental/pallas/ops/gpu/ragged_dot_mgpu.py new file mode 100644 index 000000000000..ed23f5eb764d --- /dev/null +++ b/jax/experimental/pallas/ops/gpu/ragged_dot_mgpu.py @@ -0,0 +1,305 @@ +# Copyright 2025 The JAX Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Ragged dot Pallas-Mosaic-GPU implementation.""" + +import dataclasses +import functools +import itertools +import math +import jax +from jax import lax +from jax import numpy as jnp +from jax import random +from jax._src import test_util as jtu # noqa: F401 +from jax.experimental import pallas as pl +from jax.experimental.mosaic.gpu import profiler +from jax.experimental.pallas import mosaic_gpu as plgpu +import numpy as np + + +@dataclasses.dataclass(frozen=True) +class GroupInfo: + """Information regarding the group being processed in a block.""" + + group_id: jax.Array + block: jax.Array + block_start: jax.Array + actual_start: jax.Array + actual_end: jax.Array + start_within_block: jax.Array + actual_size: jax.Array + + @classmethod + def create(cls, group_lengths, tile, tid): + """Get the group info for the current block.""" + + tile = jnp.int32(tile) + group_boundaries = [group_lengths[i] for i in range(group_lengths.shape[0])] + + # We usually only have very few groups, so we unroll the loop processing + # them. Normally we'd break out of the loop early, once we'd have found our + # boundary, but we can't do that when unrolling, so we rely on many selects + # to mask out the epilogue of the loop. 
+ group_end = group_start = block = group = end = jnp.array( + 0, dtype=jnp.int32 + ) + + for i, b in enumerate(group_boundaries): + # Start/end are inclusive + start = end + end = start + b + final = end - 1 + start_block = lax.div(start, tile) + final_block = lax.div(final, tile) + block_end = final_block + 1 + tid_begin = start_block + i + tid_end = block_end + i + # How many blocks after is our block? + this_is_group = (tid_begin <= tid) & (tid < tid_end) + block = lax.select(this_is_group, tid - tid_begin + start_block, block) + group = lax.select(this_is_group, jnp.int32(i), group) + group_start = lax.select(this_is_group, start, group_start) + group_end = lax.select(this_is_group, end, group_end) + + block_start = block * tile + actual_start = jnp.maximum(group_start, block_start) + actual_end = jnp.minimum(group_end, block_start + tile) + start_within_block = actual_start - block_start + actual_size = actual_end - actual_start + return cls( + group_id=group, + block=block, + block_start=block_start, + actual_start=actual_start, + actual_end=actual_end, + start_within_block=start_within_block, + actual_size=actual_size, + ) + + +def _find_swizzle(dim_size_bits: int, what: str): + for swizzle_bytes in (128, 64, 32, 16): + if dim_size_bits % (swizzle_bytes * 8) == 0: + return swizzle_bytes + raise ValueError( + f"No valid out swizzle for {what}: its minor dimension has" + f" {dim_size_bits} bits, which is not a multiple of 128" + ) + + +def ragged_dot( + lhs, # (M, K) + rhs, # (G, K, N) + *, + group_sizes, # (G,) + block_m: int, + block_n: int, + block_k: int, + max_concurrent_steps: int, + grid_block_n: int, +) -> jax.Array: + if lhs.dtype != rhs.dtype: + raise NotImplementedError( + f"lhs and rhs must have the same dtype, got {lhs.dtype} and {rhs.dtype}" + ) + m, k = lhs.shape + g, k2, n = rhs.shape + + if group_sizes.shape[0] != g: + raise ValueError( + f"Expected group_sizes to have shape {g} but got {group_sizes.shape}" + ) + + if k != k2: + raise ValueError(f"lhs.shape={k} must match rhs.shape={k2}") + + if k % block_k != 0: + raise ValueError(f"k={k} must be a multiple of block_k={block_k}") + + def body(rows_per_expert_gmem, lhs_gmem, rhs_gmem, o_gmem): + grid = ( + grid_block_n, + pl.cdiv(m, block_m) + g - 1, + pl.cdiv(n, grid_block_n * block_n), + ) + + @plgpu.nd_loop(grid, collective_axes="sm") + def mn_loop(idx): # pylint: disable=unused-variable + block_ni, mi, remainder_ni = idx + ni = block_ni * pl.cdiv(n, block_n * grid_block_n) + remainder_ni + group_info = GroupInfo.create(rows_per_expert_gmem, block_m, mi) + + def acc_scope(acc_ref): + plgpu.emit_pipeline( + lambda _, lhs_smem, rhs_smem: plgpu.wgmma(acc_ref, lhs_smem, rhs_smem), + grid=(k // block_k,), + in_specs=[ + plgpu.BlockSpec((block_m, block_k), lambda k: (group_info.block, k)), + plgpu.BlockSpec((block_k, block_n), lambda k: (k, ni)), + ], + max_concurrent_steps=max_concurrent_steps, + delay_release=1, + )(lhs_gmem, rhs_gmem.at[group_info.group_id]) + return acc_ref[...] + + acc = pl.run_scoped(acc_scope, plgpu.ACC((block_m, block_n))) + + @functools.partial( + pl.run_scoped, + o_smem=plgpu.SMEM((block_m, block_n), dtype=o_gmem.dtype) + ) + def store_scope(o_smem): # pylint: disable=unused-variable + o_smem[...] = acc.astype(o_smem.dtype) + plgpu.commit_smem() + + smem_start = group_info.start_within_block + remaining_rows = min(block_m, m) + # TMA descriptors need to be generated with static tile sizes along each + # axis, but we do not know at compile time how many rows we will need to + # store. 
We only know that the number of rows to store is bounded by + # min(block_m, m). + # + # In order to work around that, we construct a logarithmic ladder of + # TMA descriptors, where each descriptor can store 2**i rows for some + # i between 0 and log2(min(block_m, m)). This allows storing any + # number of rows we will need to store, so long as this number of rows + # is between `1` and `min(block_m, m)`. + # + # E.g., imagine we have block_m = 8, m = 16. The loop below will be + # unrolled into 4 iterations, where the first one will generate a TMA + # descriptor that can store 8 rows, the second one will generate a TMA + # descriptor that can store 4 rows, etc. all the way to 1 row. + # + # At run time, we finally know the actual number of rows we need to + # store as we go through the unrolled loop iterations. Let's imagine + # that we need to store 5 rows. + # + # The first unrolled iteration will check whether we can store 8 rows. + # Since we only need to store 5 rows, we won't store anything then. + # + # The second unrolled iteration will check whether we can store 4 rows. + # We're able to store 4 rows, and are left with a single remaining row. + # + # The fourth unrolled iteration will store the single remaining row, and + # we end up with a storing scheme as follows for our 5 rows: + # + # ----------------------------------------------------------- + # 0 | | + # 1 | | + # 2 | Store 4 rows | + # 3 | | + # ----------------------------------------------------------- + # 4 | Store 1 row | + # ----------------------------------------------------------- + while remaining_rows > 0: + const_rows_len = 1 << int(math.log2(remaining_rows)) + remaining_rows //= 2 + + @pl.when(group_info.actual_size & const_rows_len != 0) + def _(): + o_smem_slice = o_smem.at[pl.ds(smem_start, const_rows_len)] + o_gref_slice = o_gmem.at[ + pl.ds(group_info.block_start + smem_start, const_rows_len), + pl.ds(ni * block_n, block_n), + ] + plgpu.copy_smem_to_gmem(o_smem_slice, o_gref_slice) + + smem_start += group_info.actual_size & const_rows_len + plgpu.wait_smem_to_gmem(0, wait_read_only=True) + + # There are 132 SMs on a H100 SXM GPU. 
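# A standalone sketch (hypothetical helper, not part of this kernel) of the
# power-of-two store ladder described in the comment above: walk candidate
# chunk sizes down from block_m and emit a store whenever that power of two
# fits into the dynamic row count.
import math

def store_ladder(block_m: int, actual_size: int, start: int = 0):
  """Decomposes a dynamic row count into power-of-two row stores."""
  stores, offset, remaining = [], start, block_m
  while remaining > 0:
    chunk = 1 << int(math.log2(remaining))  # largest power of two <= remaining
    if actual_size & chunk:                 # this chunk of rows is needed
      stores.append((offset, chunk))        # store `chunk` rows at `offset`
      offset += chunk
    remaining //= 2
  return stores

# block_m=8 with 5 rows to store -> one 4-row store, then one 1-row store,
# matching the worked example above; the chunks always sum to the row count.
assert store_ladder(8, 5) == [(0, 4), (4, 1)]
assert sum(chunk for _, chunk in store_ladder(192, 117)) == 117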
+ num_sms = 132 + kernel = plgpu.kernel( + body, + out_shape=jax.ShapeDtypeStruct((m, n), lhs.dtype), + grid=(num_sms,), + grid_names=("sm",), + compiler_params=plgpu.CompilerParams( + lowering_semantics=plgpu.LoweringSemantics.Warpgroup, + ), + ) + return kernel(group_sizes, lhs, rhs) + + +def main(unused_argv): + m, k, n, num_groups = 16 * 1024, 2048, 16 * 1024, 16 + kx, ky, kz = random.split(random.key(1234), num=3) + + lhs = jax.random.normal(kx, (m, k), jnp.float16) + rhs = jax.random.normal(ky, (num_groups, k, n), jnp.float16) + group_boundaries = jax.lax.sort( + jax.random.randint(kz, (num_groups - 1,), 0, m, jnp.int32) + ) + group_starts = lax.concatenate( + [jnp.array([0], dtype=jnp.int32), group_boundaries], 0 + ) + group_ends = lax.concatenate( + [group_boundaries, jnp.array([m], dtype=jnp.int32)], 0 + ) + group_sizes = group_ends - group_starts + assert group_sizes.shape == (num_groups,) + + block_m = block_n = (64, 128, 192) + block_k = (64,) + max_concurrent_steps = (2, 4, 5, 6) + grid_block_n = (1, 2, 4, 8, 16) + configs = itertools.product( + block_m, block_n, block_k, max_concurrent_steps, grid_block_n + ) + names = ( + "block_m", "block_n", "block_k", "max_concurrent_steps", "grid_block_n" + ) + best_runtime = float("inf") + best_kwargs = {} + for config in configs: + kwargs = dict(zip(names, config)) + if n % (kwargs["grid_block_n"] * kwargs["block_n"]): + continue + try: + f = functools.partial(ragged_dot, group_sizes=group_sizes, **kwargs) + _, runtime = profiler.measure(f, mode="cupti")(lhs, rhs) + except ValueError as e: + if "Mosaic GPU kernel exceeds available shared memory" not in str(e): + raise + runtime = float("inf") + # Enable this to get more detailed information. + else: + print(" ".join(f"{k}={v}" for k, v in kwargs.items()), int(runtime * 1000)) + if runtime < best_runtime: # pytype: disable=unsupported-operands + best_runtime = runtime + best_kwargs = kwargs + if not best_kwargs: + raise ValueError("No valid configuration found") + + ref, ref_runtime = profiler.measure(jax.lax.ragged_dot)( + lhs, rhs, group_sizes=group_sizes + ) + result = ragged_dot(lhs, rhs, group_sizes=group_sizes, **best_kwargs) + np.testing.assert_allclose(result, ref, atol=1e-3, rtol=1e-3) + + tflops = float(2 * k * m * n) / (best_runtime / 1e3) / 1e12 + ref_tflops = float(2 * k * m * n) / (ref_runtime / 1e3) / 1e12 + print( + "Best parameters: ", " ".join(f"{k}={v}" for k, v in best_kwargs.items()) + ) + print(f"Kernel: {best_runtime * 1000:.1f} us = {tflops:.1f} TFLOPS") + print(f"Reference: {ref_runtime * 1000:.1f} us = {ref_tflops:.1f} TFLOPS") + + +if __name__ == "__main__": + from absl import app + + jax.config.config_with_absl() + app.run(main) diff --git a/jax/experimental/pallas/ops/gpu/rms_norm.py b/jax/experimental/pallas/ops/gpu/rms_norm.py index ff224c6dfde7..a1b2b582f7bb 100644 --- a/jax/experimental/pallas/ops/gpu/rms_norm.py +++ b/jax/experimental/pallas/ops/gpu/rms_norm.py @@ -82,7 +82,7 @@ def rms_norm_forward( ] method = pl.pallas_call( kernel, - compiler_params=plgpu.TritonCompilerParams(num_warps=num_warps), + compiler_params=plgpu.CompilerParams(num_warps=num_warps), grid=(), out_shape=out_shape, debug=False, @@ -196,7 +196,7 @@ def rms_norm_backward( out_shape_dx = jax.ShapeDtypeStruct(shape=(n,), dtype=x.dtype) method = pl.pallas_call( kernel, - compiler_params=plgpu.TritonCompilerParams(num_warps=num_warps), + compiler_params=plgpu.CompilerParams(num_warps=num_warps), grid=(), out_shape=out_shape_dx, debug=False, @@ -228,7 +228,7 @@ def rms_norm_backward( 
grid_ = (pl.cdiv(reshaped_x.shape[1], block_n),) method = pl.pallas_call( kernel, - compiler_params=dict(triton=dict(num_warps=num_warps)), + compiler_params=plgpu.CompilerParams(num_warps=num_warps), grid=grid_, out_shape=out_shape_dwbias, debug=False, @@ -264,8 +264,8 @@ def rms_norm( out_shape = jax.ShapeDtypeStruct(shape=(n,), dtype=x.dtype) method = pl.pallas_call( kernel, - compiler_params=dict( - triton=dict(num_warps=num_warps, num_stages=num_stages) + compiler_params=plgpu.CompilerParams( + num_warps=num_warps, num_stages=num_stages ), grid=(), out_shape=out_shape, diff --git a/jax/experimental/pallas/ops/gpu/softmax.py b/jax/experimental/pallas/ops/gpu/softmax.py index 7fc6a0f50cb4..68960081288e 100644 --- a/jax/experimental/pallas/ops/gpu/softmax.py +++ b/jax/experimental/pallas/ops/gpu/softmax.py @@ -80,7 +80,7 @@ def softmax( kernel = functools.partial(_vmappable_softmax_kernel, block_row=block_row) f = pl.pallas_call( kernel, - compiler_params=plgpu.TritonCompilerParams( + compiler_params=plgpu.CompilerParams( num_warps=num_warps, num_stages=1), grid=(), out_shape=out_shape, diff --git a/jax/experimental/pallas/ops/tpu/all_gather.py b/jax/experimental/pallas/ops/tpu/all_gather.py index 8fb975504e26..ce80a443547e 100644 --- a/jax/experimental/pallas/ops/tpu/all_gather.py +++ b/jax/experimental/pallas/ops/tpu/all_gather.py @@ -30,7 +30,7 @@ import jax from jax import lax from jax.experimental import pallas as pl -from jax.experimental import shard_map +from jax._src import shard_map from jax.experimental.pallas import tpu as pltpu import jax.numpy as jnp @@ -48,7 +48,7 @@ def get_neighbor( idx if i == which_axis else lax.axis_index(a) for i, a in enumerate(axis_names) ] - axis_size = lax.psum(1, axis_name) + axis_size = lax.axis_size(axis_name) if direction == "right": next_idx = lax.rem(idx + 1, axis_size) else: @@ -67,7 +67,7 @@ def ag_kernel(x_ref, o_ref, send_sem, recv_sem, *, axis_name: str, pltpu.async_copy(x_ref, o_ref.at[my_id], recv_sem[0]).wait() with jax.named_scope("neighbour_lookup"): - axis_size = lax.psum(1, axis_name) + axis_size = lax.axis_size(axis_name) left_neighbor = get_neighbor(my_id, mesh, axis_name, direction="left") right_neighbor = get_neighbor(my_id, mesh, axis_name, direction="right") @@ -120,7 +120,7 @@ def ag_kernel(x_ref, o_ref, send_sem, recv_sem, *, axis_name: str, jax.jit, static_argnames=["mesh", "axis_name", "memory_space"] ) def all_gather(x, *, mesh: jax.sharding.Mesh, axis_name: str | Sequence[str], - memory_space: pltpu.TPUMemorySpace = pltpu.VMEM): + memory_space: pltpu.MemorySpace = pltpu.VMEM): if isinstance(axis_name, str): axis_name = (axis_name,) # TODO(sharadmv): enable all gather over multiple axes @@ -131,12 +131,12 @@ def all_gather(x, *, mesh: jax.sharding.Mesh, axis_name: str | Sequence[str], # We can short-circuit here if our axis size is 1 return x def ag_local(x_shard): - axis_size = lax.psum(1, axis_name) + axis_size = lax.axis_size(axis_name) out_shape = jax.ShapeDtypeStruct((axis_size, *x_shard.shape), x_shard.dtype) out = pl.pallas_call( functools.partial(ag_kernel, axis_name=axis_name, mesh=mesh), out_shape=out_shape, - compiler_params=pltpu.TPUCompilerParams(collective_id=0), + compiler_params=pltpu.CompilerParams(collective_id=0), grid_spec=pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=0, scratch_shapes=( @@ -151,5 +151,5 @@ def ag_local(x_shard): return shard_map.shard_map( ag_local, mesh=mesh, in_specs=P(axis_name), out_specs=P(None), - check_rep=False + check_vma=False )(x) diff --git 
a/jax/experimental/pallas/ops/tpu/flash_attention.py b/jax/experimental/pallas/ops/tpu/flash_attention.py index 0cb3d798d09e..27f66d34e354 100644 --- a/jax/experimental/pallas/ops/tpu/flash_attention.py +++ b/jax/experimental/pallas/ops/tpu/flash_attention.py @@ -383,17 +383,15 @@ def start_new_sequence(): @pl.when(should_run) def run(): - @functools.partial( - lax.fori_loop, 0, block_k_major // block_k, init_val=None, unroll=True - ) - def body(i, _): + @pl.loop(0, block_k_major // block_k, unroll=True) + def _body(i): m_prev = m_scratch_ref[batch_idx] l_prev = l_scratch_ref[batch_idx] q = q_tile_ref[batch_idx] # [block_q, head_dim] start_k = i * block_k - k = pl.load( - k_tile_ref, (*batch_idx, pl.dslice(start_k, block_k), slice(None)) - ) # [block_k, head_dim] + k = k_tile_ref[ + (*batch_idx, pl.dslice(start_k, block_k), slice(None)) + ] # [block_k, head_dim] s = jax.lax.dot_general( q, k, TRANS_B_DIM_NUMBERS, preferred_element_type=jnp.float32 @@ -403,10 +401,9 @@ def body(i, _): # TODO(tanburn) Should the attention bias be added before or after # multiplication by sm_scale? if ab_tile_ref is not None: - ab = pl.load( - ab_tile_ref, + ab = ab_tile_ref[ (*batch_idx, pl.dslice(None), pl.dslice(start_k, block_k)) - ).astype(jnp.float32) + ].astype(jnp.float32) s += ab if sm_scale != 1.0: @@ -422,10 +419,9 @@ def body(i, _): q_segment_ids = pltpu.repeat( q_segment_ids_tile_ref[batch_idx[0]], repeats, axis=1 ) # [block_q, block_k]. - kv_segment_ids = pl.load( - kv_segment_ids_tile_ref, - (batch_idx[0], pl.dslice(1), pl.dslice(start_k, block_k)), - ) # [1, block_k]. + kv_segment_ids = kv_segment_ids_tile_ref[ + batch_idx[0], :1, pl.dslice(start_k, block_k) + ] # [1, block_k]. mask = jnp.equal(q_segment_ids, kv_segment_ids).astype(jnp.bool_) if causal: @@ -471,9 +467,7 @@ def body(i, _): l_next_inv_safe = jnp.where(l_next == 0.0, 1.0, 1.0 / l_next) acc_scratch_ref[batch_idx] *= l_broadcast(l_corr * l_next_inv_safe) - v = pl.load( - v_tile_ref, (*batch_idx, pl.dslice(start_k, block_k), slice(None)) - ) + v = v_tile_ref[(*batch_idx, pl.dslice(start_k, block_k), slice(None))] o_curr = jax.lax.dot( p.astype(v.dtype), v, preferred_element_type=jnp.float32 ) @@ -529,15 +523,13 @@ def _flash_attention_kernel_single_batch_single_step( raise NotImplementedError( f"kv block size must be a multiple of {NUM_LANES}" ) - q_segment_ids = pl.load( - q_segment_ids_tile_ref, (batch_idx[0],) - ) # [block_q, NUM_LANES]. + q_segment_ids = q_segment_ids_tile_ref[ + batch_idx[0] + ] # [block_q, NUM_LANES]. q_segment_ids = pltpu.repeat( q_segment_ids, repeats, axis=1 ) # [block_q, block_k]. - kv_segment_ids = pl.load( - kv_segment_ids_tile_ref, (batch_idx[0], pl.dslice(1)) - ) # [1, block_k]. + kv_segment_ids = kv_segment_ids_tile_ref[batch_idx[0], :1] # [1, block_k]. 
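# A minimal, self-contained illustration (toy shapes, interpret mode; not part
# of this file) of the two mechanical rewrites applied throughout this kernel:
# `lax.fori_loop(..., unroll=True)` becomes the `pl.loop` decorator, and
# `pl.load`/`pl.store` become direct indexing of the refs.
import functools
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def _sum_blocks_kernel(x_ref, o_ref, *, block: int):
  o_ref[...] = jnp.zeros_like(o_ref)

  # Previously:
  #   def body(i, _):
  #     blk = pl.load(x_ref, (pl.ds(i * block, block),))
  #     pl.store(o_ref, (slice(None),), pl.load(o_ref, (slice(None),)) + blk)
  #   lax.fori_loop(0, x_ref.shape[0] // block, body, None, unroll=True)
  @pl.loop(0, x_ref.shape[0] // block, unroll=True)
  def _body(i):
    o_ref[...] += x_ref[pl.ds(i * block, block)]

x = jnp.arange(16, dtype=jnp.float32)
out = pl.pallas_call(
    functools.partial(_sum_blocks_kernel, block=4),
    out_shape=jax.ShapeDtypeStruct((4,), jnp.float32),
    interpret=True,
)(x)
assert jnp.array_equal(out, x.reshape(4, 4).sum(axis=0))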
mask = jnp.equal(q_segment_ids, kv_segment_ids).astype(jnp.bool_) if causal: @@ -775,7 +767,7 @@ def kv_segment_ids_index_map( ), out_shape=out_shape, debug=debug, - compiler_params=pltpu.TPUCompilerParams( + compiler_params=pltpu.CompilerParams( dimension_semantics=( "parallel", "parallel", @@ -840,33 +832,27 @@ def q_body(j, _): start_q = j * block_q def k_body(i, _): start_k = i * block_k - k = pl.load(k_tile_ref, (0, 0, pl.ds(start_k, block_k), slice(None))) - v = pl.load(v_tile_ref, (0, 0, pl.ds(start_k, block_k), slice(None))) - q = pl.load(q_tile_ref, (0, 0, pl.ds(start_q, block_q), slice(None)) - ) # [block_q, head_dim] - l = pl.load(l_tile_ref, (0, 0, pl.ds(start_q, block_q), slice(None)) - ) # [block_q, 128] - m = pl.load(m_tile_ref, (0, 0, pl.ds(start_q, block_q), slice(None)) - ) # [block_q, 128] - do = pl.load(do_tile_ref, (0, 0, pl.ds(start_q, block_q), slice(None)) - ) # [block_q, 128] - di = pl.load(di_tile_ref, (0, 0, pl.ds(start_q, block_q), slice(None)) - ).astype(jnp.float32) # [block_q, 128] + k = k_tile_ref[0, 0, pl.ds(start_k, block_k), :] + v = v_tile_ref[0, 0, pl.ds(start_k, block_k), :] + q = q_tile_ref[0, 0, pl.ds(start_q, block_q), :] # [block_q, head_dim] + l = l_tile_ref[0, 0, pl.ds(start_q, block_q), :] # [block_q, 128] + m = m_tile_ref[0, 0, pl.ds(start_q, block_q), :] # [block_q, 128] + do = do_tile_ref[0, 0, pl.ds(start_q, block_q), :] # [block_q, 128] + di = di_tile_ref[0, 0, pl.ds(start_q, block_q), :].astype( + jnp.float32 + ) # [block_q, 128] capped_logits = lax.dot_general( q, k, TRANS_B_DIM_NUMBERS, preferred_element_type=jnp.float32 ) # [block_q_major, block_k] if ab_tile_ref is not None: - ab = pl.load( - ab_tile_ref, - ( - 0, - 0, - pl.dslice(j * block_q, block_q), - pl.dslice(i * block_k, block_k), - ), - ).astype(jnp.float32) + ab = ab_tile_ref[ + 0, + 0, + pl.dslice(j * block_q, block_q), + pl.dslice(i * block_k, block_k), + ].astype(jnp.float32) capped_logits += ab if sm_scale != 1.0: @@ -878,15 +864,15 @@ def k_body(i, _): if rem: raise NotImplementedError( ) - q_segment_ids = pl.load( - q_segment_ids_tile_ref, (0, pl.ds(start_q, block_q), slice(None)) - ) # [block_q, NUM_LANES]. + q_segment_ids = q_segment_ids_tile_ref[ + 0, pl.ds(start_q, block_q), : + ] # [block_q, NUM_LANES]. q_segment_ids = pltpu.repeat( q_segment_ids, repeats, axis=1 ) # [block_q, block_k]. - kv_segment_ids = pl.load( - kv_segment_ids_tile_ref, (slice(None), 0, pl.ds(start_k, block_k)) - ) # [1, block_k]. + kv_segment_ids = kv_segment_ids_tile_ref[ + :, 0, pl.ds(start_k, block_k) + ] # [1, block_k]. 
mask = jnp.equal(q_segment_ids, kv_segment_ids).astype(jnp.bool_) if causal: @@ -913,9 +899,9 @@ def k_body(i, _): 1 / l, block_k // MIN_BLOCK_SIZE, axis=1 ) # [block_q_major, block_k_major] dv = lax.dot(p.T.astype(do.dtype), do, preferred_element_type=jnp.float32) - pl.store(dv_scratch_ref, (pl.ds(start_k, block_k), slice(None)), - pl.load(dv_scratch_ref, (pl.ds(start_k, block_k), slice(None))) - + dv.astype(dv_scratch_ref.dtype)) + dv_scratch_ref[pl.ds(start_k, block_k), :] += dv.astype( + dv_scratch_ref.dtype + ) # di: [block_q, 128] # do: [block_q, head_dim] @@ -931,9 +917,9 @@ def k_body(i, _): # ds: [block_q_major, block_k_major] # q: [block_q_major, head_dim] dk = lax.dot(ds.T.astype(do.dtype), q, preferred_element_type=jnp.float32) - pl.store(dk_scratch_ref, (pl.ds(start_k, block_k), slice(None)), - pl.load(dk_scratch_ref, (pl.ds(start_k, block_k), slice(None))) - + dk.astype(dk_scratch_ref.dtype)) + dk_scratch_ref[pl.ds(start_k, block_k), :] += dk.astype( + dk_scratch_ref.dtype + ) lax.fori_loop(0, block_k_major // block_k, k_body, None, unroll=True) if causal: @@ -1144,7 +1130,7 @@ def dkv_index_map(batch_index, head_index, kv_seq_index, _): ), out_shape=out_shapes, debug=debug, - compiler_params=pltpu.TPUCompilerParams( + compiler_params=pltpu.CompilerParams( dimension_semantics=( "parallel", "parallel", @@ -1192,12 +1178,8 @@ def start_new_sequence(): def body(i, _): k_slice = pl.ds(i * block_k, block_k) q = q_tile_ref[0, 0, :, :] - k = pl.load( - k_tile_ref, (0, 0, k_slice, slice(None)), - ) # [block_k, head_dim] - v = pl.load( - v_tile_ref, (0, 0, k_slice, slice(None)), - ) # [block_k, head_dim] + k = k_tile_ref[0, 0, k_slice, :] # [block_k, head_dim] + v = v_tile_ref[0, 0, k_slice, :] # [block_k, head_dim] l = l_tile_ref[0, 0, :, :] # [block_q_major, 128] m = m_tile_ref[0, 0, :, :] # [block_q_major, 128] do = do_tile_ref[0, 0, :, :] # [block_q_major, head_dim] @@ -1208,9 +1190,9 @@ def body(i, _): ) if ab_tile_ref is not None: - ab = pl.load( - ab_tile_ref, (0, 0, pl.dslice(None), pl.dslice(i * block_k, block_k)) - ).astype(jnp.float32) + ab = ab_tile_ref[0, 0, :, pl.dslice(i * block_k, block_k)].astype( + jnp.float32 + ) capped_logits += ab if sm_scale != 1.0: @@ -1226,9 +1208,7 @@ def body(i, _): q_segment_ids = pltpu.repeat( q_segment_ids_tile_ref[0], repeats, axis=1 ) # [block_q, block_k]. - kv_segment_ids = pl.load( - kv_segment_ids_tile_ref, (slice(None), 0, k_slice) - ) # [1, block_k]. + kv_segment_ids = kv_segment_ids_tile_ref[:, 0, k_slice] # [1, block_k]. mask = jnp.equal(q_segment_ids, kv_segment_ids).astype(jnp.bool_) if causal: @@ -1269,10 +1249,8 @@ def body(i, _): ds = ds * sm_scale if ds_tile_ref is not None: - pl.store( - ds_tile_ref, - (0, 0, pl.dslice(None), pl.dslice(i * block_k, block_k)), - ds.astype(ds_tile_ref.dtype), + ds_tile_ref[0, 0, :, pl.dslice(i * block_k, block_k)] = ds.astype( + ds_tile_ref.dtype ) # dp: [block_q_major, block_k] @@ -1487,7 +1465,7 @@ def kv_segment_ids_index_map( ), out_shape=out_shapes, debug=debug, - compiler_params=pltpu.TPUCompilerParams( + compiler_params=pltpu.CompilerParams( dimension_semantics=( "parallel", "parallel", diff --git a/jax/experimental/pallas/ops/tpu/matmul.py b/jax/experimental/pallas/ops/tpu/matmul.py index 4ff82acbb5dd..341aa93fa258 100644 --- a/jax/experimental/pallas/ops/tpu/matmul.py +++ b/jax/experimental/pallas/ops/tpu/matmul.py @@ -14,7 +14,7 @@ """Example matmul TPU kernel. -See discussion in https://jax.readthedocs.io/en/latest/pallas/tpu/matmul.html. 
+See discussion in https://docs.jax.dev/en/latest/pallas/tpu/matmul.html. """ import functools @@ -78,7 +78,7 @@ def matmul( grid=(x.shape[0] // l, y.shape[1] // r, x.shape[1] // block_k), scratch_shapes=[pltpu.VMEM((l, r), acc_dtype)], ), - compiler_params=pltpu.TPUCompilerParams( + compiler_params=pltpu.CompilerParams( dimension_semantics=("parallel", "parallel", "arbitrary")), debug=debug, )(x, y) diff --git a/jax/experimental/pallas/ops/tpu/megablox/gmm.py b/jax/experimental/pallas/ops/tpu/megablox/gmm.py index 5c2f938597e7..cb185fc45f1d 100644 --- a/jax/experimental/pallas/ops/tpu/megablox/gmm.py +++ b/jax/experimental/pallas/ops/tpu/megablox/gmm.py @@ -538,7 +538,7 @@ def out_transform_indices(n_i, grid_id, k_i, group_metadata, group_offset): scratch_shapes=[pltpu.VMEM((tm, tn), jnp.float32)], ), input_output_aliases=input_output_aliases, - compiler_params=pltpu.TPUCompilerParams( + compiler_params=pltpu.CompilerParams( dimension_semantics=("parallel", "arbitrary", "arbitrary")), interpret=interpret, cost_estimate=cost_estimate, @@ -777,7 +777,7 @@ def out_transform_indices(n_i, k_i, grid_id, group_metadata, group_offset): scratch_shapes=[pltpu.VMEM((tk, tn), jnp.float32)], ), input_output_aliases=input_output_aliases, - compiler_params=pltpu.TPUCompilerParams( + compiler_params=pltpu.CompilerParams( dimension_semantics=("parallel", "arbitrary", "arbitrary")), interpret=interpret, cost_estimate=cost_estimate, diff --git a/jax/experimental/pallas/ops/tpu/paged_attention/paged_attention_kernel.py b/jax/experimental/pallas/ops/tpu/paged_attention/paged_attention_kernel.py index eb1e11df17da..309858368896 100644 --- a/jax/experimental/pallas/ops/tpu/paged_attention/paged_attention_kernel.py +++ b/jax/experimental/pallas/ops/tpu/paged_attention/paged_attention_kernel.py @@ -114,7 +114,7 @@ def paged_flash_attention_kernel( lengths_ref, page_indices_ref, buffer_index_ref, - step_ref, + init_flag_ref, q_ref, k_pages_hbm_ref, k_scales_pages_hbm_ref, @@ -127,7 +127,8 @@ def paged_flash_attention_kernel( k_scales_vmem_buffer, v_vmem_buffer, v_scales_vmem_buffer, - sem, + k_sems, + v_sems, *, batch_size: int, pages_per_compute_block: int, @@ -176,7 +177,9 @@ def advance_to_next_non_zero_length(): return ( lax.cond( - jnp.logical_and(next_b < batch_size, lengths_ref[next_b] == 0), + jnp.logical_and( + next_b < batch_size, + lengths_ref[lax.clamp(0, next_b, batch_size - 1)] == 0), advance_to_next_non_zero_length, lambda: next_b, ), @@ -200,7 +203,7 @@ def create_kv_async_copy_descriptors(b, h, i, buffer_index): k_scales_vmem_buffer.at[buffer_index] if k_scales_vmem_buffer is not None else None, - sem, + k_sems.at[buffer_index], page_indices_ref, page_offset, pages_to_load, @@ -213,7 +216,7 @@ def create_kv_async_copy_descriptors(b, h, i, buffer_index): v_scales_vmem_buffer.at[buffer_index] if v_scales_vmem_buffer is not None else None, - sem, + v_sems.at[buffer_index], page_indices_ref, page_offset, pages_to_load, @@ -223,16 +226,12 @@ def create_kv_async_copy_descriptors(b, h, i, buffer_index): @pl.when(i * bk < length) def flash_attention(): # pylint: disable=unused-variable - step = step_ref[0] + init_flag = init_flag_ref[0] + init_flag_ref[0] = 0 buffer_index = buffer_index_ref[0] + next_b, next_h, next_i = compute_block_indices(b, h, i + 1) - @pl.when(i == 0) - def init(): # pylint: disable=unused-variable - m_ref[...] = jnp.full_like(m_ref, -jnp.inf) - l_ref[...] = jnp.zeros_like(l_ref) - o_ref[...] 
= jnp.zeros_like(o_ref) - - @pl.when(step == 0) + @pl.when(init_flag) def prefetch_first_block(): # pylint: disable=unused-variable async_copy_k, async_copy_v = create_kv_async_copy_descriptors( b, h, i, buffer_index @@ -240,7 +239,11 @@ def prefetch_first_block(): # pylint: disable=unused-variable async_copy_k.start() async_copy_v.start() - next_b, next_h, next_i = compute_block_indices(b, h, i + 1) + @pl.when(i == 0) + def init(): # pylint: disable=unused-variable + m_ref[...] = jnp.full_like(m_ref, -jnp.inf) + l_ref[...] = jnp.zeros_like(l_ref) + o_ref[...] = jnp.zeros_like(o_ref) @pl.when(next_b < batch_size) def prefetch_next_block(): # pylint: disable=unused-variable @@ -257,7 +260,7 @@ def prefetch_next_block(): # pylint: disable=unused-variable ) q = q_ref[...].astype(jnp.float32) k = async_copy_k.wait_and_get_loaded() - qk = jnp.einsum('hd,td->ht', q, k, preferred_element_type=jnp.float32) + qk = jnp.einsum("gd,td->gt", q, k, preferred_element_type=jnp.float32) if attn_logits_soft_cap is not None: capped_qk = jnp.tanh(qk / attn_logits_soft_cap) qk = capped_qk * attn_logits_soft_cap @@ -274,24 +277,21 @@ def prefetch_next_block(): # pylint: disable=unused-variable alpha = jnp.exp(m_prev - m_next) beta = jnp.exp(m_curr - m_next) l_next = alpha * l_prev + beta * l_curr - l_next_safe = jnp.where(l_next == 0.0, 1.0, l_next) + m_ref[...], l_ref[...] = m_next, l_next v = async_copy_v.wait_and_get_loaded() - o_curr_times_l_curr = jnp.dot(s_curr, v) + o_curr = jnp.einsum("gt,td->gd", s_curr, v) - m_ref[...], l_ref[...] = m_next, l_next_safe o_ref[...] = ( - (l_prev * alpha * o_ref[...] + beta * o_curr_times_l_curr) / l_next_safe + (l_prev * alpha * o_ref[...] + beta * o_curr) / l_next ).astype(o_ref.dtype) - step_ref[0] = step + 1 - def paged_flash_attention_kernel_inline_seq_dim( lengths_ref, page_indices_ref, buffer_index_ref, - step_ref, + init_flag_ref, q_ref, k_pages_hbm_ref, k_scales_pages_hbm_ref, @@ -304,7 +304,8 @@ def paged_flash_attention_kernel_inline_seq_dim( k_scales_vmem_buffer, v_vmem_buffer, v_scales_vmem_buffer, - sem, + k_sems, + v_sems, *, batch_size: int, pages_per_compute_block: int, @@ -326,7 +327,7 @@ def body(i, _): lengths_ref, page_indices_ref, buffer_index_ref, - step_ref, + init_flag_ref, q_ref, k_pages_hbm_ref, k_scales_pages_hbm_ref, @@ -339,7 +340,8 @@ def body(i, _): k_scales_vmem_buffer, v_vmem_buffer, v_scales_vmem_buffer, - sem, + k_sems, + v_sems, batch_size=batch_size, pages_per_compute_block=pages_per_compute_block, pages_per_sequence=pages_per_sequence, @@ -387,7 +389,7 @@ def paged_attention( """Paged grouped query attention. Args: - q: A [batch_size, num_heads, head_dim] jax.Array. + q: A [batch_size, num_q_heads, head_dim] jax.Array. k_pages: A [num_kv_heads, total_num_pages, page_size, head_dim] jax.Array. v_pages: A [num_kv_heads, total_num_pages, page_size, head_dim] jax.Array. lengths: A i32[batch_size] jax.Array the length of each example. @@ -412,7 +414,7 @@ def paged_attention( one kernel. Returns: - The output of attention([batch_size, num_heads, head_dim]). + The output of attention([batch_size, num_q_heads, head_dim]). 
""" if isinstance(k_pages, quantization_utils.QuantizedTensor): k_pages, k_scales_pages = k_pages.weight, k_pages.scales @@ -431,7 +433,7 @@ def paged_attention( else: v_scales_pages = None - batch_size, num_heads, head_dim = q.shape + batch_size, num_q_heads, head_dim = q.shape num_kv_heads, _, page_size, head_dim_k = k_pages.shape batch_size_paged_indices, pages_per_sequence = page_indices.shape @@ -440,10 +442,10 @@ def paged_attention( f"k_pages and v_pages must have the same shape. Got {k_pages.shape} and" f" {v_pages.shape}" # pytype: disable=attribute-error ) - if num_heads % num_kv_heads != 0: + if num_q_heads % num_kv_heads != 0: raise ValueError( "Number of Q heads must be divisible by number of KV heads. Got" - f" {num_heads} and {num_kv_heads}." + f" {num_q_heads} and {num_kv_heads}." ) if head_dim_k != head_dim: raise ValueError( @@ -480,40 +482,41 @@ def paged_attention( else: raise ValueError("megacore_mode must be one of ['kv_head', 'batch', None]") - if (num_heads // num_kv_heads) % 8 != 0: + num_groups = num_q_heads // num_kv_heads + if (num_groups) % 8 != 0: # Reshape q to hint XLA to pick a <1x128> layout otherwise it will pick a # <8x128> layout for a <1x128> memref inside the kernel and error out. - q = q.reshape(batch_size, num_heads, 1, head_dim) + q = q.reshape(batch_size, num_q_heads, 1, head_dim) if megacore_mode == "kv_head": q_block_spec = pl.BlockSpec( - (None, num_heads // num_kv_heads, None, head_dim), + (None, num_groups, None, head_dim), lambda core_index, b, h, *_: (b, h * num_cores + core_index, 0, 0), ) elif megacore_mode == "batch": q_block_spec = pl.BlockSpec( - (None, num_heads // num_kv_heads, None, head_dim), + (None, num_groups, None, head_dim), lambda core_index, b, h, *_: (b * num_cores + core_index, h, 0, 0), ) else: q_block_spec = pl.BlockSpec( - (None, num_heads // num_kv_heads, None, head_dim), + (None, num_groups, None, head_dim), lambda core_index, b, h, *_: (b, h, 0, 0), ) q_dtype_for_kernel_launch = jnp.float32 else: if megacore_mode == "kv_head": q_block_spec = pl.BlockSpec( - (None, num_heads // num_kv_heads, head_dim), + (None, num_groups, head_dim), lambda core_index, b, h, *_: (b, h * num_cores + core_index, 0), ) elif megacore_mode == "batch": q_block_spec = pl.BlockSpec( - (None, num_heads // num_kv_heads, head_dim), + (None, num_groups, head_dim), lambda core_index, b, h, *_: (b * num_cores + core_index, h, 0), ) else: q_block_spec = pl.BlockSpec( - (None, num_heads // num_kv_heads, head_dim), + (None, num_groups, head_dim), lambda core_index, b, h, *_: (b, h, 0), ) q_dtype_for_kernel_launch = q.dtype @@ -544,10 +547,10 @@ def paged_attention( if k_scales_pages is not None and v_scales_pages is not None: in_specs = [ q_block_spec, - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=pltpu.ANY), ] scratch_shapes = ( pltpu.VMEM( @@ -586,14 +589,15 @@ def paged_attention( ), v_scales_pages.dtype, # pytype: disable=attribute-error ), # v_scales_pages buffer - pltpu.SemaphoreType.DMA, + pltpu.SemaphoreType.DMA((2,)), + pltpu.SemaphoreType.DMA((2,)), ) else: in_specs = [ q_block_spec, - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + pl.BlockSpec(memory_space=pltpu.ANY), None, # type: ignore[list-item] - 
pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + pl.BlockSpec(memory_space=pltpu.ANY), None, # type: ignore[list-item] ] scratch_shapes = ( @@ -617,7 +621,8 @@ def paged_attention( v_pages.dtype, ), # v_pages buffer None, - pltpu.SemaphoreType.DMA, + pltpu.SemaphoreType.DMA((2,)), + pltpu.SemaphoreType.DMA((2,)), ) out, _, _ = pl.pallas_call( @@ -632,7 +637,7 @@ def paged_attention( ), grid_spec=pltpu.PrefetchScalarGridSpec( # There are 4 scalars prefetched per kernel call: `lengths_ref`, - # `page_indices_ref`, `buffer_index_ref`, `step_ref` + # `page_indices_ref`, `buffer_index_ref`, `init_flag_ref` num_scalar_prefetch=4, in_specs=in_specs, out_specs=[ @@ -643,8 +648,9 @@ def paged_attention( grid=grid, scratch_shapes=scratch_shapes, ), - compiler_params=pltpu.TPUCompilerParams( - dimension_semantics=dimension_semantics), + compiler_params=pltpu.CompilerParams( + dimension_semantics=dimension_semantics + ), out_shape=[ jax.ShapeDtypeStruct(q.shape, q_dtype_for_kernel_launch), jax.ShapeDtypeStruct((*q.shape[:-1], 1), jnp.float32), @@ -654,11 +660,11 @@ def paged_attention( lengths, page_indices.reshape(-1), jnp.zeros((1,), jnp.int32), # buffer index - jnp.zeros((1,), jnp.int32), # step + jnp.ones((1,), jnp.int32), # init flag q.astype(q_dtype_for_kernel_launch), k_pages, k_scales_pages, v_pages, v_scales_pages, ) - return out.reshape(batch_size, num_heads, head_dim).astype(q.dtype) + return out.reshape(batch_size, num_q_heads, head_dim).astype(q.dtype) diff --git a/jax/experimental/pallas/ops/tpu/paged_attention/util.py b/jax/experimental/pallas/ops/tpu/paged_attention/util.py new file mode 100644 index 000000000000..92aa3a7a1b2c --- /dev/null +++ b/jax/experimental/pallas/ops/tpu/paged_attention/util.py @@ -0,0 +1,82 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
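The paged attention kernel above now keeps its running output normalized by the running sum at every step (`o = (l_prev * alpha * o_prev + beta * o_curr) / l_next`). A NumPy sketch, with toy sizes and random inputs chosen purely for illustration, checking that this block-wise recurrence reproduces ordinary softmax attention:

import numpy as np

rng = np.random.default_rng(0)
qk = rng.normal(size=(12,)).astype(np.float32)    # logits for one query row
v = rng.normal(size=(12, 4)).astype(np.float32)   # 12 keys, head_dim=4

# Reference: ordinary softmax attention over the full sequence.
p = np.exp(qk - qk.max())
p /= p.sum()
o_full = p @ v

# Streaming update over blocks of 4 keys, mirroring the kernel's recurrence.
m_prev, l_prev, o_prev = -np.inf, 0.0, np.zeros(4, np.float32)
for start in range(0, 12, 4):
  qk_blk, v_blk = qk[start:start + 4], v[start:start + 4]
  m_curr = qk_blk.max()
  s_curr = np.exp(qk_blk - m_curr)
  l_curr = s_curr.sum()
  m_next = max(m_prev, m_curr)
  alpha, beta = np.exp(m_prev - m_next), np.exp(m_curr - m_next)
  l_next = alpha * l_prev + beta * l_curr
  o_curr = s_curr @ v_blk
  o_prev = (l_prev * alpha * o_prev + beta * o_curr) / l_next
  m_prev, l_prev = m_next, l_next

np.testing.assert_allclose(o_prev, o_full, rtol=1e-5, atol=1e-6)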
+ +"""JAX reference implementation of grouped query attention.""" + +import jax +from jax.experimental.pallas.ops.tpu.paged_attention import quantization_utils +import jax.numpy as jnp + +MASK_VALUE = -0.7 * float(jnp.finfo(jnp.dtype("float32")).max) + + +def grouped_query_attention_reference( + queries: jax.Array, # [batch_size, num_q_heads, head_dim] + k_pages: jax.Array, # [batch_size, num_kv_heads, max_seq_len, head_dim] + v_pages: jax.Array, # [batch_size, num_kv_heads, max_seq_len, head_dim] + seq_lens: jax.Array, # i32[batch_size] + soft_cap: float | None = None, + debug: bool = False, +) -> jax.Array: # [batch_size, num_q_heads, head_dim] + """Grouped query attention with a single query per request.""" + # Check input shapes + assert k_pages.shape == v_pages.shape + batch_size, num_q_heads, head_dim = queries.shape + batch_size2, num_kv_heads, max_seq_len, head_dim2 = k_pages.shape + assert batch_size2 == batch_size + assert head_dim2 == head_dim + + # Unquantize kv pages if necessary + if isinstance(k_pages, quantization_utils.QuantizedTensor): + k_pages = quantization_utils.unquantize_from_int8( + k_pages, dtype=jnp.float32 + ) + if isinstance(v_pages, quantization_utils.QuantizedTensor): + v_pages = quantization_utils.unquantize_from_int8( + v_pages, dtype=jnp.float32 + ) + + # Reshape for num_groups queries per k head + assert num_q_heads % num_kv_heads == 0 + num_groups = num_q_heads // num_kv_heads + queries = queries.reshape(batch_size, num_kv_heads, num_groups, head_dim) + + # Compute the dot product q*k and apply soft cap if necessary + qk = jnp.einsum( + "bhgd,bhtd->bhgt", + queries.astype(jnp.float32), + k_pages.astype(jnp.float32), + ) + if soft_cap is not None and soft_cap != 0.0: + qk = jnp.tanh(qk / soft_cap) * soft_cap + assert qk.shape == (batch_size, num_kv_heads, num_groups, max_seq_len) + if debug: + jax.debug.print("qk: {qk}", qk=qk) + + # Enforce causal mask (adding dimensions when necessary) + mask = jnp.arange(max_seq_len)[None] < seq_lens[:, None] + qk += jnp.where(mask, 0.0, MASK_VALUE)[:, None, None, :] + if debug: + jax.debug.print("masked: {qk}", qk=qk) + + # Generate probability distribution using softmax + probs = jax.nn.softmax(qk, axis=-1).astype(v_pages.dtype) + assert probs.shape == (batch_size, num_kv_heads, num_groups, max_seq_len) + if debug: + jax.debug.print("softmax: {probs}", probs=probs) + + # Attention is probability-weighted sum of v heads + attention = jnp.einsum("bhgt,bhtd->bhgd", probs, v_pages) + assert attention.shape == (batch_size, num_kv_heads, num_groups, head_dim) + return attention.reshape(batch_size, num_q_heads, head_dim) diff --git a/jax/experimental/pallas/ops/tpu/ragged_paged_attention/__init__.py b/jax/experimental/pallas/ops/tpu/ragged_paged_attention/__init__.py new file mode 100644 index 000000000000..3830adfa7fd6 --- /dev/null +++ b/jax/experimental/pallas/ops/tpu/ragged_paged_attention/__init__.py @@ -0,0 +1,23 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
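The grouped-query reference added in util.py above gives each KV head a contiguous group of num_q_heads // num_kv_heads query heads and masks out key positions at or beyond each request's seq_len. A minimal sketch of exercising it, assuming the new module is importable at its file path (all shapes below are arbitrary toy values, not from the patch):

    import jax.numpy as jnp
    from jax.experimental.pallas.ops.tpu.paged_attention import util

    batch_size, num_q_heads, num_kv_heads, max_seq_len, head_dim = 2, 8, 2, 16, 128
    queries = jnp.ones((batch_size, num_q_heads, head_dim), jnp.float32)
    k_pages = jnp.ones((batch_size, num_kv_heads, max_seq_len, head_dim), jnp.float32)
    v_pages = jnp.ones((batch_size, num_kv_heads, max_seq_len, head_dim), jnp.float32)
    seq_lens = jnp.array([3, 16], jnp.int32)  # keys at positions >= seq_len are masked

    out = util.grouped_query_attention_reference(queries, k_pages, v_pages, seq_lens)
    assert out.shape == (batch_size, num_q_heads, head_dim)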
+ +from jax.experimental.pallas.ops.tpu.ragged_paged_attention import kernel +from jax.experimental.pallas.ops.tpu.ragged_paged_attention import tuned_block_sizes + +cdiv = kernel.cdiv +dynamic_validate_inputs = kernel.dynamic_validate_inputs +ragged_paged_attention = kernel.ragged_paged_attention +ref_ragged_paged_attention = kernel.ref_ragged_paged_attention +static_validate_inputs = kernel.static_validate_inputs +get_tuned_block_sizes = tuned_block_sizes.get_tuned_block_sizes diff --git a/jax/experimental/pallas/ops/tpu/ragged_paged_attention.py b/jax/experimental/pallas/ops/tpu/ragged_paged_attention/kernel.py similarity index 55% rename from jax/experimental/pallas/ops/tpu/ragged_paged_attention.py rename to jax/experimental/pallas/ops/tpu/ragged_paged_attention/kernel.py index 6600d765024c..3f12448f2a9c 100644 --- a/jax/experimental/pallas/ops/tpu/ragged_paged_attention.py +++ b/jax/experimental/pallas/ops/tpu/ragged_paged_attention/kernel.py @@ -19,14 +19,16 @@ specifications. It supports mixed prefill and decoding, enhancing throughput during inference. """ - import functools import jax from jax import lax +from jax._src import dtypes from jax.experimental import pallas as pl from jax.experimental.pallas import tpu as pltpu +from jax.experimental.pallas.ops.tpu.ragged_paged_attention.tuned_block_sizes import get_tuned_block_sizes import jax.numpy as jnp + DEFAULT_MASK_VALUE = -0.7 * float(jnp.finfo(jnp.dtype("float32")).max) @@ -35,23 +37,20 @@ class MultiPageAsyncCopyDescriptor: def __init__( self, - pages_hbm_ref, # [total_num_pages, page_size, num_kv_heads_per_blk, head_dim] - vmem_buf, # [num_kv_pages_per_blk, page_size, num_kv_heads_per_blk, head_dim] + pages_hbm_ref, # [total_num_pages, page_size, num_combined_kv_heads_per_blk, head_dim] + vmem_buf, # [num_kv_pages_per_blk, page_size, num_combined_kv_heads_per_blk, head_dim] sem, page_indices_ref, # i32[max_num_seqs, pages_per_seq] - offset, # [seq_idx, kv_pages_start] + metadata, # [seq_idx, start_page_idx, end_page_idx] ): self._vmem_buf = vmem_buf - seq_id, kv_pages_start = offset - pages_per_seq = page_indices_ref.shape[1] + seq_id, start_page_idx, end_page_idx = metadata self._async_copies = [] # TODO(jevinjiang): Only fetch dynamic shape in need! This will insert # a bunch of if-ops. Check the performance when we have benchmarking setup. 
for i in range(vmem_buf.shape[0]): - page_idx = kv_pages_start + i - page_idx = jax.lax.select( - page_idx < pages_per_seq, page_idx, pages_per_seq - 1 - ) + page_idx = start_page_idx + i + page_idx = jax.lax.select(page_idx < end_page_idx, page_idx, 0) self._async_copies.append( pltpu.make_async_copy( pages_hbm_ref.at[page_indices_ref[seq_id, page_idx]], @@ -73,17 +72,38 @@ def wait(self): def ref_ragged_paged_attention( queries: jax.Array, # [max_num_batched_tokens, num_q_heads, head_dim] - k_pages: jax.Array, # [total_num_pages, page_size, num_kv_heads, head_dim] - v_pages: jax.Array, # [total_num_pages, page_size, num_kv_heads, head_dim] + kv_pages: jax.Array, # [total_num_pages, page_size, num_combined_kv_heads, head_dim] kv_lens: jax.Array, # i32[max_num_seqs] page_indices: jax.Array, # i32[max_num_seqs, pages_per_seq] cu_q_lens: jax.Array, # i32[max_num_seqs + 1] num_seqs: jax.Array, # i32[1], *, sm_scale: float = 1.0, - mask_value: float = DEFAULT_MASK_VALUE, + sliding_window: int | None = None, + soft_cap: float | None = None, + mask_value: float | None = DEFAULT_MASK_VALUE, + k_scale: float | None = None, + v_scale: float | None = None, ): - _, _, num_kv_heads, head_dim = k_pages.shape + static_validate_inputs( + queries, + kv_pages, + kv_lens, + page_indices, + cu_q_lens, + num_seqs, + sm_scale=sm_scale, + k_scale=k_scale, + v_scale=v_scale, + sliding_window=sliding_window, + soft_cap=soft_cap, + mask_value=mask_value, + ) + if mask_value is None: + mask_value = DEFAULT_MASK_VALUE + _, _, num_combined_kv_heads, head_dim = kv_pages.shape + assert num_combined_kv_heads % 2 == 0 + num_kv_heads = num_combined_kv_heads // 2 num_q_heads = queries.shape[1] assert num_q_heads % num_kv_heads == 0 num_query_per_kv = num_q_heads // num_kv_heads @@ -95,8 +115,18 @@ def ref_ragged_paged_attention( kv_len = kv_lens[i] indices = page_indices[i] q = queries[q_start:q_end] - k = k_pages[indices, :, :, :].reshape(-1, num_kv_heads, head_dim)[:kv_len] - v = v_pages[indices, :, :, :].reshape(-1, num_kv_heads, head_dim)[:kv_len] + k = kv_pages[indices, :, 0::2, :].reshape(-1, num_kv_heads, head_dim)[ + :kv_len + ] + v = kv_pages[indices, :, 1::2, :].reshape(-1, num_kv_heads, head_dim)[ + :kv_len + ] + if k_scale is not None: + k = k.astype(jnp.float32) * k_scale + k = k.astype(q.dtype) + if v_scale is not None: + v = v.astype(jnp.float32) * v_scale + v = v.astype(q.dtype) k = jnp.repeat(k, num_query_per_kv, axis=1) v = jnp.repeat(v, num_query_per_kv, axis=1) attn = jnp.einsum("qhd,khd->hqk", q, k, preferred_element_type=jnp.float32) @@ -105,7 +135,12 @@ def ref_ragged_paged_attention( jnp.int32, attn.shape, 1 ) kv_span = jax.lax.broadcasted_iota(jnp.int32, attn.shape, 2) - attn += jnp.where(q_span < kv_span, mask_value, 0.0) + mask = q_span < kv_span + if sliding_window is not None: + mask = jnp.logical_or(mask, q_span - sliding_window >= kv_span) + if soft_cap is not None: + attn = soft_cap * jnp.tanh(attn / soft_cap) + attn += jnp.where(mask, mask_value, 0.0) attn = jax.nn.softmax(attn, axis=-1).astype(v.dtype) out = jnp.einsum("hqk,khd->qhd", attn, v).astype(queries.dtype) outputs.append(out) @@ -113,26 +148,51 @@ def ref_ragged_paged_attention( return jnp.concatenate(outputs, axis=0) -# Expect to run these checkes during runtime. -def validate_inputs_on_runtime( +# Expect to run these checks during runtime. 
+def dynamic_validate_inputs( q: jax.Array, # [max_num_batched_tokens, num_q_heads, head_dim] - k_pages: jax.Array, # [total_num_pages, page_size, num_kv_heads, head_dim] - v_pages: jax.Array, # [total_num_pages, page_size, num_kv_heads, head_dim] + kv_pages: jax.Array, # [total_num_pages, page_size, num_combined_kv_heads, head_dim] kv_lens: jax.Array, # i32[max_num_seqs] page_indices: jax.Array, # i32[max_num_seqs, pages_per_seq] cu_q_lens: jax.Array, # i32[max_num_seqs + 1] - num_seqs, # i32[1] + num_seqs: jax.Array, # i32[1] + *, + # These inputs are optional. If not specified, we will not validate them. + sm_scale: float | None = None, + sliding_window: int | None = None, + soft_cap: float | None = None, + mask_value: float | None = None, + k_scale: float | None = None, + v_scale: float | None = None, + # Kernel tuning params. + num_kv_pages_per_block: int | None = None, + num_queries_per_block: int | None = None, + vmem_limit_bytes: int | None = None, ): - check_inputs_shapes( - q, k_pages, v_pages, kv_lens, page_indices, cu_q_lens, num_seqs + static_validate_inputs( + q, + kv_pages, + kv_lens, + page_indices, + cu_q_lens, + num_seqs, + sm_scale=sm_scale, + sliding_window=sliding_window, + soft_cap=soft_cap, + mask_value=mask_value, + k_scale=k_scale, + v_scale=v_scale, + num_kv_pages_per_block=num_kv_pages_per_block, + num_queries_per_block=num_queries_per_block, + vmem_limit_bytes=vmem_limit_bytes, ) max_num_batched_tokens = q.shape[0] - page_size = k_pages.shape[1] + page_size = kv_pages.shape[1] max_num_seqs, pages_per_seq = page_indices.shape if num_seqs[0] > max_num_seqs: raise ValueError(f"{num_seqs[0]=} must be less or equal to {max_num_seqs=}") max_kv_len = jnp.max(kv_lens) - min_pages_per_seq = ceil_div(max_kv_len, page_size) + min_pages_per_seq = cdiv(max_kv_len, page_size) if pages_per_seq < min_pages_per_seq: raise ValueError( f"{pages_per_seq=} must be greater or equal to" @@ -153,24 +213,35 @@ def validate_inputs_on_runtime( # Expect to run these checks during compile time. -def check_inputs_shapes( +def static_validate_inputs( q: jax.Array, # [max_num_batched_tokens, num_q_heads, head_dim] - k_pages: jax.Array, # [total_num_pages, page_size, num_kv_heads, head_dim] - v_pages: jax.Array, # [total_num_pages, page_size, num_kv_heads, head_dim] + kv_pages: jax.Array, # [total_num_pages, page_size, num_combined_kv_heads, head_dim] kv_lens: jax.Array, # i32[max_num_seqs] page_indices: jax.Array, # i32[max_num_seqs, pages_per_seq] cu_q_lens: jax.Array, # i32[max_num_seqs + 1] - num_seqs, # i32[1] + num_seqs: jax.Array, # i32[1] + *, + # These inputs are optional. If not specified, we will not validate them. + sm_scale: float | None = None, + sliding_window: int | None = None, + soft_cap: float | None = None, + mask_value: float | None = None, + k_scale: float | None = None, + v_scale: float | None = None, + # Kernel tuning params. 
+ num_kv_pages_per_block: int | None = None, + num_queries_per_block: int | None = None, + vmem_limit_bytes: int | None = None, ): _, num_q_heads, head_dim = q.shape - _, _, num_kv_heads, head_dim_k = k_pages.shape - max_num_seqs, _ = page_indices.shape + _, _, num_combined_kv_heads, head_dim_k = kv_pages.shape + assert num_combined_kv_heads % 2 == 0 + assert isinstance(k_scale, float) or k_scale is None + assert isinstance(v_scale, float) or v_scale is None + num_kv_heads = num_combined_kv_heads // 2 + max_num_seqs, pages_per_seq = page_indices.shape if num_seqs.shape != (1,): raise ValueError(f"{num_seqs.shape=} must be (1,)") - if k_pages.shape != v_pages.shape: - raise ValueError( - f"{k_pages.shape=} and {v_pages.shape=} must have the same shape." - ) if head_dim_k != head_dim: raise ValueError( f"Q head_dim {head_dim} must be the same as that of K/V {head_dim_k}." @@ -197,6 +268,23 @@ def check_inputs_shapes( ) if num_q_heads % num_kv_heads != 0: raise ValueError(f"{num_q_heads=} must be divisible by {num_kv_heads=}") + if sliding_window is not None and sliding_window <= 0: + raise ValueError(f"{sliding_window=} must be positive.") + if soft_cap is not None and soft_cap == 0.0: + raise ValueError(f"{soft_cap=} must not be 0.0.") + if ( + num_kv_pages_per_block is not None + and not 0 < num_kv_pages_per_block <= pages_per_seq + ): + raise ValueError( + f"{num_kv_pages_per_block=} must be in range (0, {pages_per_seq}]." + ) + if num_queries_per_block is not None and num_queries_per_block <= 0: + raise ValueError(f"{num_queries_per_block=} must be positive.") + if vmem_limit_bytes is not None and vmem_limit_bytes <= 0: + raise ValueError(f"{vmem_limit_bytes=} must be positive.") + del sm_scale # No constraints on sm_scale. + del mask_value # No constraints on mask_value.
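For intuition about the layout these checks assume: K and V now share a single kv_pages array, interleaved along the head axis, so num_combined_kv_heads == 2 * num_kv_heads with K at even indices and V at odd indices. A rough sketch, using hypothetical separate k_pages/v_pages arrays, of how such a combined cache can be assembled and how ref_ragged_paged_attention recovers the two halves (toy shapes for illustration only):

    import jax.numpy as jnp

    total_num_pages, page_size, num_kv_heads, head_dim = 8, 16, 2, 128
    k_pages = jnp.zeros((total_num_pages, page_size, num_kv_heads, head_dim), jnp.bfloat16)
    v_pages = jnp.ones((total_num_pages, page_size, num_kv_heads, head_dim), jnp.bfloat16)

    # Combined head 2*h holds K of head h; combined head 2*h + 1 holds V of head h.
    kv_pages = jnp.stack([k_pages, v_pages], axis=3).reshape(
        total_num_pages, page_size, 2 * num_kv_heads, head_dim)

    k_again = kv_pages[:, :, 0::2, :]  # matches the reference's even-head slice
    v_again = kv_pages[:, :, 1::2, :]  # matches the reference's odd-head slice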
def ragged_paged_attention_kernel( @@ -209,23 +297,32 @@ def ragged_paged_attention_kernel( num_seqs_ref, # Input q_ref, # [num_q_per_blk, num_q_heads_per_blk, head_dim] - k_pages_hbm_ref, # [total_num_pages, page_size, num_kv_heads, head_dim] - v_pages_hbm_ref, # [total_num_pages, page_size, num_kv_heads, head_dim] + kv_pages_hbm_ref, # [total_num_pages, page_size, num_combined_kv_heads, head_dim] # Output o_ref, # [num_q_per_blk, num_q_heads_per_blk, head_dim] # Scratch - k_bufs, # [2, num_kv_pages_per_blk, page_size, num_kv_heads_per_blk, head_dim] - v_bufs, # [2, num_kv_pages_per_blk, page_size, num_kv_heads_per_blk, head_dim] + kv_bufs, # [2, num_kv_pages_per_blk, page_size, num_combined_kv_heads_per_blk, head_dim] sems, # [2, 2] l_ref, # [num_kv_heads_per_blk, num_q_per_blk * num_q_heads_per_kv_head, 128] m_ref, # [num_kv_heads_per_blk, num_q_per_blk * num_q_heads_per_kv_head, 128] + acc_ref, # [num_q_per_blk, num_q_heads_per_blk, head_dim] *, sm_scale: float, - mask_value: float, + sliding_window: int | None = None, + soft_cap: float | None = None, + mask_value: float | None = DEFAULT_MASK_VALUE, + k_scale: float | None = None, + v_scale: float | None = None, ): + if mask_value is None: + mask_value = DEFAULT_MASK_VALUE num_q_per_blk, num_q_heads_per_blk, head_dim = q_ref.shape + pages_per_seq = page_indices_ref.shape[-1] num_seqs = num_seqs_ref[0] - _, num_kv_pages_per_blk, page_size, num_kv_heads_per_blk, _ = k_bufs.shape + _, num_kv_pages_per_blk, page_size, num_combined_kv_heads_per_blk, _ = ( + kv_bufs.shape + ) + num_kv_heads_per_blk = num_combined_kv_heads_per_blk // 2 num_kv_per_blk = num_kv_pages_per_blk * page_size num_q_heads_per_kv_head = num_q_heads_per_blk // num_kv_heads_per_blk heads_blk_idx, q_blk_idx = ( @@ -241,42 +338,59 @@ def ragged_paged_attention_kernel( def create_kv_async_copy_descriptors( heads_blk_idx, seq_idx, kv_blk_idx, buf_idx ): - offset = (seq_idx, kv_blk_idx * num_kv_pages_per_blk) - heads_start = heads_blk_idx * num_kv_heads_per_blk - async_copy_k = MultiPageAsyncCopyDescriptor( - k_pages_hbm_ref.at[:, :, pl.ds(heads_start, num_kv_heads_per_blk), :], - k_bufs.at[buf_idx], - sems.at[buf_idx, 0], - page_indices_ref, - offset, + start_kv_page_idx = kv_blk_idx * num_kv_pages_per_blk + end_kv_page_idx = jnp.minimum( + pages_per_seq, cdiv(kv_lens_ref[seq_idx], page_size) ) - async_copy_v = MultiPageAsyncCopyDescriptor( - v_pages_hbm_ref.at[:, :, pl.ds(heads_start, num_kv_heads_per_blk), :], - v_bufs.at[buf_idx], - sems.at[buf_idx, 1], + metadata = (seq_idx, start_kv_page_idx, end_kv_page_idx) + heads_start = heads_blk_idx * num_combined_kv_heads_per_blk + async_copy_kv = MultiPageAsyncCopyDescriptor( + kv_pages_hbm_ref.at[ + :, :, pl.ds(heads_start, num_combined_kv_heads_per_blk), : + ], + kv_bufs.at[buf_idx], + sems.at[buf_idx], page_indices_ref, - offset, + metadata, ) - return async_copy_k, async_copy_v + return async_copy_kv # TODO(jevinjiang): Add these to Mosaic: - # 1. Support arbitrary strided load/store for any dtype. + # 1. Support arbitrary strided load/store for int4 and int8 dtype. # 2. Support arbitrary strided load/store for any last dimension. 
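Inside the kernel below, strided_load_kv pulls matched K/V rows out of the flattened [rows, head_dim] KV buffer, where each group of num_combined_kv_heads_per_blk rows interleaves K at even offsets and V at odd offsets; the 32-bit load-and-shift path exists only because Mosaic cannot yet do arbitrary strided loads on sub-32-bit dtypes. Per (K, V) head pair its result is equivalent to this plain-slicing sketch (the helper name here is illustrative, not part of the patch):

    def strided_load_kv_reference(kv2d, start, step):
      # kv2d: [num_kv_per_blk * num_combined_kv_heads_per_blk, head_dim].
      # K rows for one head sit at offset `start`, the matching V rows at
      # `start + 1`, repeating every `step` rows.
      k = kv2d[start::step, :]
      v = kv2d[start + 1 :: step, :]
      return k, v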
def strided_load_kv(ref, start, step): - if ref.dtype == jnp.float32: - return ref[start::step, :] packing = get_dtype_packing(ref.dtype) - assert ref.dtype == jnp.bfloat16 + if packing == 1: + return [ref[start::step, :]], [ref[start + 1 :: step, :]] + assert packing in (2, 4, 8) assert step % packing == 0 + k_list, v_list = [], [] b_start = start // packing - b_offset = start % packing b_step = step // packing - b_ref = ref.bitcast(jnp.int32) + b_ref = ref.bitcast(jnp.uint32) b = b_ref[b_start::b_step, :] - bw = 32 // packing - b = jnp.right_shift(b, bw * b_offset) - b = jnp.left_shift(b, bw * (packing - 1)) - return pltpu.bitcast(b, jnp.float32).astype(jnp.bfloat16) + + # TODO(chengjiyao): use the general strided loading logic for bf16 after + # fixing the issue in mosaic's infer vector layout pass + if ref.dtype == jnp.bfloat16: + bk = b << 16 + bv = b & jnp.uint32(0xFFFF0000) + k = pltpu.bitcast(bk, jnp.float32).astype(jnp.bfloat16) + v = pltpu.bitcast(bv, jnp.float32).astype(jnp.bfloat16) + k_list.append(k) + v_list.append(v) + else: + bitwidth = 32 // packing + bitcast_dst_dtype = jnp.dtype(f"uint{bitwidth}") + for i in range(0, packing, 2): + bk = b >> (i * bitwidth) + k = pltpu.bitcast(bk.astype(bitcast_dst_dtype), ref.dtype) + k_list.append(k) + bv = b >> ((i + 1) * bitwidth) + v = pltpu.bitcast(bv.astype(bitcast_dst_dtype), ref.dtype) + v_list.append(v) + + return k_list, v_list def fold_on_2nd_minor(vec): assert vec.dtype == jnp.bfloat16 or vec.dtype == jnp.float32 @@ -289,15 +403,16 @@ def fold_on_2nd_minor(vec): @pl.when(heads_blk_idx + q_blk_idx == 0) def prefetch_first_kv_blk(): - async_copy_k, async_copy_v = create_kv_async_copy_descriptors( + async_copy_kv = create_kv_async_copy_descriptors( heads_blk_idx, init_seq_idx, 0, init_buf_idx ) - async_copy_k.start() - async_copy_v.start() + async_copy_kv.start() def is_cur_q_blk_needed(q_states): done, cur_seq_idx, _ = q_states - return jnp.logical_and(done == 0, cur_seq_idx < num_seqs) + should_run = jnp.logical_and(q_len_start < cu_q_lens_ref[num_seqs], + cur_seq_idx < num_seqs) + return jnp.logical_and(done == 0, should_run) def compute_with_cur_q_blk(q_states): done, cur_seq_idx, cur_buf_idx = q_states @@ -342,7 +457,7 @@ def flash_attention( v, # [num_kv_per_blk, head_dim] head_l_ref, # [num_q_per_blk * num_q_heads_per_kv_head, 128] head_m_ref, # [num_q_per_blk * num_q_heads_per_kv_head, 128] - head_o_ref, # [num_q_per_blk, num_q_heads_per_kv_head, head_dim] + head_acc_ref, # [num_q_per_blk, num_q_heads_per_kv_head, head_dim] *, kv_blk_idx, ): @@ -350,20 +465,24 @@ def flash_attention( num_q_per_blk * num_q_heads_per_kv_head, head_dim, ) - assert k.shape == ( - num_kv_per_blk, - head_dim, - ), f"{k.shape=}, {(num_kv_per_blk, head_dim)=} {k.dtype=}" - assert v.shape == (num_kv_per_blk, head_dim) - assert head_m_ref.shape == ( - num_q_per_blk * num_q_heads_per_kv_head, - 128, + assert ( + k.shape + == v.shape + == ( + num_kv_per_blk, + head_dim, + ) ) - assert head_l_ref.shape == ( - num_q_per_blk * num_q_heads_per_kv_head, - 128, + assert k.dtype == v.dtype + assert ( + head_m_ref.shape + == head_l_ref.shape + == ( + num_q_per_blk * num_q_heads_per_kv_head, + 128, + ) ) - assert head_o_ref.shape == ( + assert head_acc_ref.shape == ( num_q_per_blk, num_q_heads_per_kv_head, head_dim, @@ -373,7 +492,19 @@ def flash_attention( def masked_store(ref, val, start, end, group=1): iota = lax.broadcasted_iota(jnp.int32, ref.shape, 0) // group mask = jnp.logical_and(iota >= start, iota < end) - pl.store(ref, tuple(slice(None) for _ in 
ref.shape), val, mask=mask) + pl.store(ref, idx=tuple(slice(None) for _ in ref.shape), val=val, mask=mask) + + def load_with_init(ref, init_val): + return jnp.where( + kv_blk_idx == 0, jnp.full_like(ref, init_val), ref[...] + ) + + # kv lens will be contracting dim, we should mask out the NaNs. + kv_mask = ( + lax.broadcasted_iota(jnp.int32, k.shape, 0) < kv_len - kv_len_start + ) + k = jnp.where(kv_mask, k.astype(jnp.float32), 0).astype(k.dtype) + v = jnp.where(kv_mask, v.astype(jnp.float32), 0).astype(v.dtype) qk = ( jnp.einsum("nd,md->nm", q, k, preferred_element_type=jnp.float32) @@ -382,29 +513,6 @@ def masked_store(ref, val, start, end, group=1): store_start = jnp.maximum(q_start - q_len_start, 0) store_end = jnp.minimum(q_end - q_len_start, num_q_per_blk) - @pl.when(kv_blk_idx == 0) - def init_scratch_ref(): - masked_store( - head_m_ref, - jnp.full_like(head_m_ref, -jnp.inf), - store_start, - store_end, - num_q_heads_per_kv_head, - ) - masked_store( - head_l_ref, - jnp.zeros_like(head_l_ref), - store_start, - store_end, - num_q_heads_per_kv_head, - ) - masked_store( - head_o_ref, - jnp.zeros_like(head_o_ref), - store_start, - store_end, - ) - row_ids = ( (kv_len - q_len) + q_len_start @@ -422,6 +530,11 @@ def init_scratch_ref(): 1, ) causal_mask = row_ids < col_ids + if sliding_window is not None: + causal_mask = jnp.logical_or(causal_mask, + row_ids - sliding_window >= col_ids) + if soft_cap is not None: + qk = soft_cap * jnp.tanh(qk / soft_cap) qk += jnp.where(causal_mask, mask_value, 0.0) m_curr = jnp.max(qk, axis=1, keepdims=True) s_curr = jnp.exp(qk - m_curr) @@ -431,8 +544,8 @@ def init_scratch_ref(): l_curr = jnp.broadcast_to( s_curr.sum(axis=1, keepdims=True), lm_store_shape ) - m_prev = head_m_ref[...] - l_prev = head_l_ref[...] + m_prev = load_with_init(head_m_ref, -jnp.inf) + l_prev = load_with_init(head_l_ref, 0.0) m_next = jnp.maximum(m_prev, m_curr) masked_store( head_m_ref, m_next, store_start, store_end, num_q_heads_per_kv_head @@ -461,17 +574,17 @@ def broadcast_to_shape(arr, shape): [arr for _ in range(shape[1] // arr.shape[1])], axis=1 ) - o_curr = head_o_ref[...].reshape(-1, head_dim) + o_curr = load_with_init(head_acc_ref, 0.0).reshape(-1, head_dim) l_alpha = broadcast_to_shape(l_alpha, qkv.shape) beta = broadcast_to_shape(beta, qkv.shape) l_next_safe = broadcast_to_shape(l_next_safe, qkv.shape) out = lax.div( l_alpha * o_curr + beta * qkv, l_next_safe, - ).astype(head_o_ref.dtype) + ) masked_store( - head_o_ref, - out.reshape(head_o_ref.shape), + head_acc_ref, + out.reshape(head_acc_ref.shape), store_start, store_end, ) @@ -493,39 +606,54 @@ def prefetch_next_kv_blk(): # TODO(jevinjiang): reuse the same buffer if it is already prefetched! # TODO(jevinjiang): only fetch effective dynamic size to hold kv_len and # DMA to fixed size buffer! 
- next_async_copy_k, next_async_copy_v = create_kv_async_copy_descriptors( + next_async_copy_kv = create_kv_async_copy_descriptors( next_heads_blk_idx, next_seq_idx, next_kv_blk_idx, next_buf_idx ) - next_async_copy_k.start() - next_async_copy_v.start() + next_async_copy_kv.start() - cur_async_copy_k, cur_async_copy_v = create_kv_async_copy_descriptors( + cur_async_copy_kv = create_kv_async_copy_descriptors( heads_blk_idx, cur_seq_idx, kv_blk_idx, cur_buf_idx ) - kv_to_load_shape = ( - num_kv_pages_per_blk * page_size * num_kv_heads_per_blk, + kv_ref = cur_async_copy_kv.wait().reshape( + num_kv_pages_per_blk * page_size * num_combined_kv_heads_per_blk, head_dim, ) - k_ref = cur_async_copy_k.wait().reshape(kv_to_load_shape) - v_ref = cur_async_copy_v.wait().reshape(kv_to_load_shape) - for kv_head_idx in range(num_kv_heads_per_blk): - q_head_idx = kv_head_idx * num_q_heads_per_kv_head - # TODO(jevinjiang): extra handlig for packed type that can start at - # unaligned position! - q = fold_on_2nd_minor( - q_ref[:, q_head_idx : q_head_idx + num_q_heads_per_kv_head, :] - ) - k = strided_load_kv(k_ref, kv_head_idx, num_kv_heads_per_blk) - v = strided_load_kv(v_ref, kv_head_idx, num_kv_heads_per_blk) - flash_attention( - q, - k, - v, - l_ref.at[kv_head_idx], - m_ref.at[kv_head_idx], - o_ref.at[:, q_head_idx : q_head_idx + num_q_heads_per_kv_head, :], - kv_blk_idx=kv_blk_idx, + kv_packing = get_dtype_packing(kv_ref.dtype) + # NOTE: kv_packing is divided by 2 because k and v are packed together. + kv_load_step = max(1, kv_packing // 2) + for kv_head_chunk_idx in range(0, num_kv_heads_per_blk, kv_load_step): + k_list, v_list = strided_load_kv( + kv_ref, kv_head_chunk_idx * 2, num_combined_kv_heads_per_blk ) + for step_idx in range(kv_load_step): + k = k_list[step_idx] + v = v_list[step_idx] + if k_scale is not None: + # NOTE: Conversion between arbitrary data types is not supported. + # That's why it is converted to float32 first. + k = k.astype(jnp.float32) * k_scale + k = k.astype(q_ref.dtype) + if v_scale is not None: + v = v.astype(jnp.float32) * v_scale + v = v.astype(q_ref.dtype) + kv_head_idx = kv_head_chunk_idx + step_idx + q_head_idx = kv_head_idx * num_q_heads_per_kv_head + # TODO(jevinjiang): extra handling for packed type that can start at + # unaligned position! + q = fold_on_2nd_minor( + q_ref[:, q_head_idx : q_head_idx + num_q_heads_per_kv_head, :] + ) + flash_attention( + q, + k, + v, + l_ref.at[kv_head_idx], + m_ref.at[kv_head_idx], + acc_ref.at[ + :, q_head_idx : q_head_idx + num_q_heads_per_kv_head, : + ], + kv_blk_idx=kv_blk_idx, + ) return kv_blk_idx + 1, next_buf_idx _, next_buf_idx = lax.while_loop( @@ -545,26 +673,22 @@ def prefetch_next_kv_blk(): # Reset seq_idx for next kv_heads_blk if run out of seqs! seq_buf_idx_ref[0] = lax.select(seq_idx < num_seqs, seq_idx, 0) seq_buf_idx_ref[1] = buf_idx + o_ref[...] 
= acc_ref[...].astype(q_ref.dtype) -def ceil_div(a, b): +def cdiv(a, b): assert b != 0 return (a + b - 1) // b def get_dtype_packing(dtype): - if dtype == jnp.float32: - return 1 - if dtype == jnp.bfloat16: - return 2 - if dtype == jnp.int8: - return 4 - if dtype == jnp.int4: - return 8 - raise ValueError(f"Not implemented: unsupported {dtype=}") - - -def get_min_heads_per_blk(num_q_heads, num_kv_heads, q_dtype, kv_dtype): + bits = dtypes.bit_width(dtype) + return 32 // bits + + +def get_min_heads_per_blk( + num_q_heads, num_combined_kv_heads, q_dtype, kv_dtype +): q_packing = get_dtype_packing(q_dtype) kv_packing = get_dtype_packing(kv_dtype) @@ -575,22 +699,26 @@ def can_be_xla_fully_tiled(x, packing): return x in (1, 2, 4, 8) or x % 8 == 0 # TODO(jevinjiang): support unaligned number of heads! - if not can_be_xla_fully_tiled(num_kv_heads, kv_packing): + if not can_be_xla_fully_tiled(num_combined_kv_heads, kv_packing): raise ValueError( - f"Not implemented: {num_kv_heads=} can not be XLA fully tiled." + f"Not implemented: {num_combined_kv_heads=} can not be XLA fully tiled." ) + assert num_combined_kv_heads % 2 == 0 + num_kv_heads = num_combined_kv_heads // 2 assert num_q_heads % num_kv_heads == 0 ratio = num_q_heads // num_kv_heads # TODO(jevinjiang): we can choose smaller tiling for packed type if large # second minor tiling is not on. - max_kv_tiling = 8 * kv_packing - min_kv_heads = ( - max_kv_tiling if num_kv_heads % max_kv_tiling == 0 else num_kv_heads + max_combined_kv_tiling = 8 * kv_packing + min_combined_kv_heads = ( + max_combined_kv_tiling + if num_combined_kv_heads % max_combined_kv_tiling == 0 + else num_combined_kv_heads ) - min_q_heads = min_kv_heads * ratio + min_q_heads = min_combined_kv_heads // 2 * ratio if can_be_xla_fully_tiled(min_q_heads, q_packing): - return min_q_heads, min_kv_heads - return num_q_heads, num_kv_heads + return min_q_heads, min_combined_kv_heads + return num_q_heads, num_combined_kv_heads @functools.partial( @@ -601,30 +729,36 @@ def can_be_xla_fully_tiled(x, packing): "num_kv_pages_per_block", "num_queries_per_block", "vmem_limit_bytes", + "sliding_window", + "soft_cap", + "k_scale", + "v_scale", ], ) def ragged_paged_attention( q: jax.Array, # [max_num_batched_tokens, num_q_heads, head_dim] # TODO(jevinjiang): create a write_to_kv_cache kernel! - k_pages: jax.Array, # [total_num_pages, page_size, num_kv_heads, head_dim] - v_pages: jax.Array, # [total_num_pages, page_size, num_kv_heads, head_dim] + kv_pages: jax.Array, # [total_num_pages, page_size, num_combined_kv_heads, head_dim] kv_lens: jax.Array, # i32[max_num_seqs] page_indices: jax.Array, # i32[max_num_seqs, pages_per_seq] cu_q_lens: jax.Array, # i32[max_num_seqs + 1] num_seqs: jax.Array, # i32[1] *, sm_scale: float = 1.0, - mask_value: float = DEFAULT_MASK_VALUE, - num_kv_pages_per_block: int = 16, - num_queries_per_block: int = 128, + sliding_window: int | None = None, + soft_cap: float | None = None, + mask_value: float | None = DEFAULT_MASK_VALUE, + k_scale: float | None = None, + v_scale: float | None = None, + num_kv_pages_per_block: int | None = None, + num_queries_per_block: int | None = None, vmem_limit_bytes: int | None = None, ): """Ragged paged attention that supports mixed prefill and decode. Args: q: concatenated all sequences' queries. - k_pages: paged K cache. Normally in HBM. - v_pages: paged V cache. Normally in HBM. + kv_pages: paged KV cache. Normally in HBM. kv_lens: padded kv lengths. Only the first num_seqs values are valid. 
page_indices: the first index indicates which page to use in the kv cache for each sequence. Only the first num_seqs values are valid. @@ -632,7 +766,11 @@ def ragged_paged_attention( kv_lens, only the first num_seqs+1 values are valid. num_seqs: the dynamic number of sequences. sm_scale: the softmax scale which will be applied to the Q@K^T. + sliding_window: the sliding window size for the attention. + soft_cap: the logit soft cap for the attention. mask_value: mask value for causal mask. + k_scale: the scale for the key cache. + v_scale: the scale for the value cache. num_kv_pages_per_block: number of kv pages to be processed in one flash attention block in the pallas kernel. num_queries_per_block: number of kv pages to be processed in one flash @@ -642,18 +780,50 @@ def ragged_paged_attention( Returns: The output of the attention. """ - check_inputs_shapes( - q, k_pages, v_pages, kv_lens, page_indices, cu_q_lens, num_seqs + static_validate_inputs( + q, + kv_pages, + kv_lens, + page_indices, + cu_q_lens, + num_seqs, + sm_scale=sm_scale, + sliding_window=sliding_window, + soft_cap=soft_cap, + mask_value=mask_value, + k_scale=k_scale, + v_scale=v_scale, + num_kv_pages_per_block=num_kv_pages_per_block, + num_queries_per_block=num_queries_per_block, + vmem_limit_bytes=vmem_limit_bytes, + ) + if mask_value is None: + mask_value = DEFAULT_MASK_VALUE + num_q_tokens, num_q_heads, head_dim = q.shape + _, page_size, num_combined_kv_heads, _ = kv_pages.shape + assert num_combined_kv_heads % 2 == 0 + num_kv_heads = num_combined_kv_heads // 2 + _, pages_per_seq = page_indices.shape + num_q_heads_per_blk, num_combined_kv_heads_per_blk = get_min_heads_per_blk( + num_q_heads, num_combined_kv_heads, q.dtype, kv_pages.dtype ) - _, num_q_heads, head_dim = q.shape - _, page_size, num_kv_heads, _ = k_pages.shape num_q_per_blk = num_queries_per_block num_kv_pages_per_blk = num_kv_pages_per_block + if num_q_per_blk is None or num_kv_pages_per_blk is None: + num_kv_pages_per_blk, num_q_per_blk = get_tuned_block_sizes( + q.dtype, + kv_pages.dtype, + num_q_heads_per_blk, + num_combined_kv_heads_per_blk // 2, + head_dim, + page_size, + num_q_tokens, + pages_per_seq, + ) num_q_heads_per_kv_head = num_q_heads // num_kv_heads - num_q_blks = ceil_div(cu_q_lens[num_seqs[0]], num_q_per_blk) - num_q_heads_per_blk, num_kv_heads_per_blk = get_min_heads_per_blk( - num_q_heads, num_kv_heads, q.dtype, k_pages.dtype - ) + num_q_blks = cdiv(num_q_tokens, num_q_per_blk) + assert num_combined_kv_heads_per_blk % 2 == 0 + num_kv_heads_per_blk = num_combined_kv_heads_per_blk // 2 assert num_q_heads_per_blk % num_q_heads_per_kv_head == 0 num_heads_blks = num_q_heads // num_q_heads_per_blk grid = (num_heads_blks, num_q_blks) @@ -667,8 +837,7 @@ def q_index_map(heads_blk_idx, q_blk_idx, *_): ) in_specs = [ q_block_spec, - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + pl.BlockSpec(memory_space=pltpu.ANY), ] out_specs = q_block_spec lm_scratch = pltpu.VMEM( @@ -677,22 +846,26 @@ def q_index_map(heads_blk_idx, q_blk_idx, *_): (num_kv_heads_per_blk, num_q_per_blk * num_q_heads_per_kv_head, 128), jnp.float32, ) + acc_scratch = pltpu.VMEM( + (num_q_per_blk, num_q_heads_per_blk, head_dim), + jnp.float32, + ) double_buf_scratch = pltpu.VMEM( ( 2, # For double buffering during DMA copies. 
num_kv_pages_per_blk, page_size, - num_kv_heads_per_blk, + num_combined_kv_heads_per_blk, head_dim, ), - k_pages.dtype, + kv_pages.dtype, ) scratch_shapes = [ - double_buf_scratch, # k_bufs - double_buf_scratch, # v_bufs - pltpu.SemaphoreType.DMA((2, 2)), # [double_buffers, k_sem/v_sem] + double_buf_scratch, # kv_bufs + pltpu.SemaphoreType.DMA((2,)), # Semaphores for double buffers. lm_scratch, # l_ref lm_scratch, # m_ref + acc_scratch, ] scalar_prefetches = ( kv_lens, @@ -705,7 +878,11 @@ def q_index_map(heads_blk_idx, q_blk_idx, *_): functools.partial( ragged_paged_attention_kernel, sm_scale=sm_scale, + sliding_window=sliding_window, + soft_cap=soft_cap, mask_value=mask_value, + k_scale=k_scale, + v_scale=v_scale, ), grid_spec=pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=len(scalar_prefetches), @@ -714,16 +891,15 @@ def q_index_map(heads_blk_idx, q_blk_idx, *_): grid=grid, scratch_shapes=scratch_shapes, ), - compiler_params=pltpu.TPUCompilerParams( + compiler_params=pltpu.CompilerParams( dimension_semantics=( "arbitrary", "arbitrary", ), vmem_limit_bytes=vmem_limit_bytes, ), - out_shape=jax.ShapeDtypeStruct(shape=q.shape, dtype=jnp.float32), + out_shape=jax.ShapeDtypeStruct(shape=q.shape, dtype=q.dtype), name="ragged_paged_attention_kernel", ) - # TODO(jevinjiang): Use f32 acc scratch for output! So we only need - # to transfer output with desired dtype back to HBM. - return kernel(*scalar_prefetches, q, k_pages, v_pages).astype(q.dtype) + + return kernel(*scalar_prefetches, q, kv_pages) diff --git a/jax/experimental/pallas/ops/tpu/ragged_paged_attention/tuned_block_sizes.py b/jax/experimental/pallas/ops/tpu/ragged_paged_attention/tuned_block_sizes.py new file mode 100644 index 000000000000..df2f1c4ea83f --- /dev/null +++ b/jax/experimental/pallas/ops/tpu/ragged_paged_attention/tuned_block_sizes.py @@ -0,0 +1,457 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Auto-tuned block sizes for ragged paged attention.""" + +import jax +import jax.numpy as jnp + +# The page size is too small. We only have 32 SREGs in TC. If the pages +# per seq is too large, SREGs will spill. 
+MAX_PAGES_PER_SEQ = 16 + +# key: +# - q_dtype_name +# - kv_dtype_name +# - num_q_heads_per_blk +# - num_kv_heads_per_blk +# - head_dim +# - page_size +# - max_num_batched_tokens +# - max_model_len = page_size * pages_per_seq +# value: +# - num_kv_pages_per_block +# - num_queries_per_block +TUNED_BLOCK_SIZES = { + 'TPU v6': { + # go/keep-sorted start + ('bfloat16', 'bfloat16', 32, 8, 128, 128, 1024, 1024): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 128, 1024, 2048): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 128, 1024, 4096): (32, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 128, 1024, 512): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 128, 2048, 1024): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 128, 2048, 2048): (16, 64), + ('bfloat16', 'bfloat16', 32, 8, 128, 128, 2048, 4096): (32, 64), + ('bfloat16', 'bfloat16', 32, 8, 128, 128, 2048, 512): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 128, 4096, 1024): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 128, 4096, 2048): (16, 64), + ('bfloat16', 'bfloat16', 32, 8, 128, 128, 4096, 4096): (32, 64), + ('bfloat16', 'bfloat16', 32, 8, 128, 128, 4096, 512): (4, 64), + ('bfloat16', 'bfloat16', 32, 8, 128, 128, 512, 1024): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 128, 512, 2048): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 128, 512, 4096): (32, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 128, 512, 512): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 1024, 1024): (64, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 1024, 128): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 1024, 2048): (128, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 1024, 256): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 1024, 512): (32, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 1024, 64): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 2048, 1024): (64, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 2048, 128): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 2048, 2048): (128, 64), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 2048, 256): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 2048, 512): (32, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 2048, 64): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 4096, 1024): (64, 64), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 4096, 128): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 4096, 2048): (128, 64), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 4096, 256): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 4096, 512): (32, 64), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 4096, 64): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 512, 1024): (64, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 512, 128): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 512, 2048): (128, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 512, 256): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 512, 512): (32, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 512, 64): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 256, 1024, 1024): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 256, 1024, 2048): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 256, 1024, 4096): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 256, 2048, 1024): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 256, 2048, 2048): (8, 64), + ('bfloat16', 'bfloat16', 32, 8, 128, 256, 2048, 4096): (16, 64), + ('bfloat16', 'bfloat16', 32, 8, 128, 256, 4096, 1024): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 256, 4096, 2048): (8, 64), + ('bfloat16', 'bfloat16', 32, 8, 128, 256, 4096, 4096): (16, 
64), + ('bfloat16', 'bfloat16', 32, 8, 128, 256, 512, 1024): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 256, 512, 2048): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 256, 512, 4096): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 1024, 1024): (32, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 1024, 128): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 1024, 2048): (64, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 1024, 256): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 1024, 4096): (128, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 1024, 512): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 2048, 1024): (32, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 2048, 128): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 2048, 2048): (64, 64), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 2048, 256): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 2048, 4096): (128, 64), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 2048, 512): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 4096, 1024): (32, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 4096, 128): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 4096, 2048): (64, 64), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 4096, 256): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 4096, 4096): (128, 64), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 4096, 512): (16, 64), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 512, 1024): (32, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 512, 128): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 512, 2048): (64, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 512, 256): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 512, 4096): (128, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 512, 512): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 1024, 1024): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 1024, 2048): (32, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 1024, 256): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 1024, 4096): (64, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 1024, 512): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 2048, 1024): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 2048, 2048): (32, 64), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 2048, 256): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 2048, 4096): (64, 64), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 2048, 512): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 4096, 1024): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 4096, 2048): (32, 64), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 4096, 256): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 4096, 4096): (64, 64), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 4096, 512): (8, 64), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 512, 1024): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 512, 2048): (32, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 512, 256): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 512, 4096): (64, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 512, 512): (8, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 1024, 1024): (8, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 1024, 2048): (16, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 1024, 4096): (32, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 1024, 512): (4, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 2048, 1024): (8, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 2048, 2048): (16, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 2048, 4096): (32, 32), + ('bfloat16', 'bfloat16', 8, 1, 
128, 128, 2048, 512): (4, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 4096, 1024): (8, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 4096, 2048): (16, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 4096, 4096): (32, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 4096, 512): (4, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 512, 1024): (8, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 512, 2048): (16, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 512, 4096): (32, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 512, 512): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 1024, 1024): (64, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 1024, 128): (8, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 1024, 2048): (128, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 1024, 256): (16, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 1024, 512): (32, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 1024, 64): (4, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 2048, 1024): (64, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 2048, 128): (8, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 2048, 2048): (128, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 2048, 256): (16, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 2048, 512): (32, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 2048, 64): (4, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 4096, 1024): (64, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 4096, 128): (8, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 4096, 2048): (128, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 4096, 256): (16, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 4096, 512): (32, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 4096, 64): (4, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 512, 1024): (64, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 512, 128): (8, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 512, 2048): (128, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 512, 256): (16, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 512, 512): (32, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 512, 64): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 256, 1024, 1024): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 256, 1024, 2048): (8, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 256, 1024, 4096): (16, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 256, 2048, 1024): (4, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 256, 2048, 2048): (8, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 256, 2048, 4096): (16, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 256, 4096, 1024): (4, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 256, 4096, 2048): (8, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 256, 4096, 4096): (16, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 256, 512, 1024): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 256, 512, 2048): (8, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 256, 512, 4096): (16, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 1024, 1024): (32, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 1024, 128): (4, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 1024, 2048): (64, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 1024, 256): (8, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 1024, 4096): (128, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 1024, 512): (16, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 2048, 1024): (32, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 2048, 128): (4, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 2048, 2048): (64, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 2048, 256): (8, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 2048, 
4096): (64, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 2048, 512): (16, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 4096, 1024): (32, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 4096, 128): (4, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 4096, 2048): (64, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 4096, 256): (8, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 4096, 4096): (64, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 4096, 512): (16, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 512, 1024): (32, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 512, 128): (4, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 512, 2048): (64, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 512, 256): (8, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 512, 4096): (128, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 512, 512): (16, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 1024, 1024): (16, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 1024, 2048): (32, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 1024, 256): (4, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 1024, 4096): (64, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 1024, 512): (8, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 2048, 1024): (16, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 2048, 2048): (32, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 2048, 256): (4, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 2048, 4096): (64, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 2048, 512): (8, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 4096, 1024): (16, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 4096, 2048): (32, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 4096, 256): (4, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 4096, 4096): (64, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 4096, 512): (8, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 512, 1024): (16, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 512, 2048): (32, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 512, 256): (4, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 512, 4096): (64, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 512, 512): (8, 32), + # go/keep-sorted end + }, + 'TPU v5': { + # go/keep-sorted start + ('bfloat16', 'bfloat16', 32, 8, 128, 128, 1024, 1024): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 128, 1024, 2048): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 128, 1024, 512): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 128, 2048, 1024): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 128, 2048, 2048): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 128, 2048, 512): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 128, 4096, 1024): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 128, 4096, 2048): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 128, 4096, 512): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 128, 512, 1024): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 128, 512, 2048): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 128, 512, 512): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 1024, 128): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 1024, 256): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 1024, 64): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 2048, 128): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 2048, 256): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 2048, 64): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 4096, 128): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 4096, 256): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 4096, 64): (4, 
32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 512, 128): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 512, 256): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 16, 512, 64): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 256, 1024, 1024): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 256, 1024, 2048): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 256, 1024, 4096): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 256, 2048, 1024): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 256, 2048, 2048): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 256, 2048, 4096): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 256, 4096, 1024): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 256, 4096, 2048): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 256, 4096, 4096): (16, 64), + ('bfloat16', 'bfloat16', 32, 8, 128, 256, 512, 1024): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 256, 512, 2048): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 256, 512, 4096): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 1024, 128): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 1024, 256): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 1024, 512): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 2048, 128): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 2048, 256): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 2048, 512): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 4096, 128): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 4096, 256): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 4096, 512): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 512, 128): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 512, 256): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 32, 512, 512): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 1024, 1024): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 1024, 256): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 1024, 512): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 2048, 1024): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 2048, 256): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 2048, 512): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 4096, 1024): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 4096, 256): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 4096, 512): (8, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 512, 1024): (16, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 512, 256): (4, 32), + ('bfloat16', 'bfloat16', 32, 8, 128, 64, 512, 512): (8, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 1024, 1024): (8, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 1024, 2048): (16, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 1024, 512): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 2048, 1024): (8, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 2048, 2048): (16, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 2048, 512): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 4096, 1024): (8, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 4096, 2048): (16, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 4096, 512): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 512, 1024): (8, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 512, 2048): (16, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 128, 512, 512): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 1024, 128): (8, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 1024, 256): (16, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 1024, 64): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 2048, 128): (8, 
32), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 2048, 256): (16, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 2048, 64): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 4096, 128): (8, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 4096, 256): (16, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 4096, 64): (4, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 512, 128): (8, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 512, 256): (16, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 16, 512, 64): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 256, 1024, 1024): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 256, 1024, 2048): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 256, 1024, 4096): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 256, 2048, 1024): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 256, 2048, 2048): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 256, 2048, 4096): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 256, 4096, 1024): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 256, 4096, 2048): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 256, 4096, 4096): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 256, 512, 1024): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 256, 512, 2048): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 256, 512, 4096): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 1024, 128): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 1024, 256): (8, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 1024, 512): (16, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 2048, 128): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 2048, 256): (8, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 2048, 512): (16, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 4096, 128): (4, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 4096, 256): (8, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 4096, 512): (16, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 512, 128): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 512, 256): (8, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 32, 512, 512): (16, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 1024, 1024): (16, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 1024, 256): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 1024, 512): (8, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 2048, 1024): (16, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 2048, 256): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 2048, 512): (8, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 4096, 1024): (16, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 4096, 256): (4, 64), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 4096, 512): (8, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 512, 1024): (16, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 512, 256): (4, 32), + ('bfloat16', 'bfloat16', 8, 1, 128, 64, 512, 512): (8, 32), + # go/keep-sorted end + }, +} + +def next_power_of_2(x: int): + """Finds the smallest power of 2 >= x using bit manipulation. + + Args: + x: The input number (should be an integer). + + Returns: + The smallest integer power of 2 that is >= x. 
+ """ + assert x > 0 + if x == 1: + return 1 + return 1 << (x - 1).bit_length() + + +def simplify_key(key): + """Simplify the key to reduce the number of combinations.""" + ( + q_dtype, + kv_dtype, + num_q_heads_per_blk, + num_kv_heads_per_blk, + head_dim, + page_size, + max_num_batched_tokens, + pages_per_seq, + ) = key + return ( + jnp.dtype(q_dtype).name, + jnp.dtype(kv_dtype).name, + next_power_of_2(num_q_heads_per_blk), + next_power_of_2(num_kv_heads_per_blk), + (head_dim + 127) // 128 * 128, + next_power_of_2(page_size), + next_power_of_2(max_num_batched_tokens), + next_power_of_2(page_size * pages_per_seq), + ) + + +def get_tpu_version() -> int: + """Returns the numeric version of the TPU, or -1 if not on TPU.""" + kind = jax.devices()[0].device_kind + if 'TPU' not in kind: + return -1 + if kind.endswith(' lite'): + kind = kind[: -len(' lite')] + assert kind[:-1] == 'TPU v', kind + return int(kind[-1]) + + +def get_device_name(num_devices:int | None = None): + name = ' '.join(jax.devices()[0].device_kind.split()[:2]) + if num_devices is not None: + name += f'-{num_devices}' + return name + + +def get_tuned_block_sizes( + q_dtype, + kv_dtype, + num_q_heads_per_blk, + num_kv_heads_per_blk, + head_dim, + page_size, + max_num_batched_tokens, + pages_per_seq, +) -> tuple[int, int]: + """Look up for the best (num_kv_pages_per_blk, num_queries_per_blk) from auto-tuned table.""" + tpu_version = get_tpu_version() + if tpu_version < 4: + raise NotImplementedError('TPU version must be 4 or higher.') + key = ( + q_dtype, + kv_dtype, + num_q_heads_per_blk, + num_kv_heads_per_blk, + head_dim, + page_size, + max_num_batched_tokens, + pages_per_seq, + ) + key = simplify_key(key) + device_name = get_device_name() + + # Default block sizes. + bkv, bq = (128, 32) + if tpu_version == 4: + # This default block size is not tuned, only make sure there's no + # OOM in vmem + bkv, bq = (32, 32) + elif device_name in TUNED_BLOCK_SIZES: + if key in TUNED_BLOCK_SIZES[device_name]: + bkv, bq = TUNED_BLOCK_SIZES[device_name][key] + return (min(pages_per_seq, bkv), min(max_num_batched_tokens, bq)) + + +def get_min_page_size(max_model_len, min_page_size=16): + """Recommended min page size for high-performance kernel.""" + return max(next_power_of_2(max_model_len) // MAX_PAGES_PER_SEQ, min_page_size) diff --git a/jax/experimental/pallas/ops/tpu/random/__init__.py b/jax/experimental/pallas/ops/tpu/random/__init__.py new file mode 100644 index 000000000000..3da0dd1fa3ca --- /dev/null +++ b/jax/experimental/pallas/ops/tpu/random/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2025 The JAX Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
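To make the TUNED_BLOCK_SIZES lookup in tuned_block_sizes.py concrete: get_tuned_block_sizes first simplifies its key by rounding head_dim up to a multiple of 128, rounding the other integer fields up to the next power of two, and replacing the last field with the rounded max_model_len = page_size * pages_per_seq. A small sketch with made-up model dimensions:

    import jax.numpy as jnp
    from jax.experimental.pallas.ops.tpu.ragged_paged_attention import tuned_block_sizes

    key = (jnp.bfloat16, jnp.bfloat16, 8, 1, 96, 48, 1000, 20)
    print(tuned_block_sizes.simplify_key(key))
    # ('bfloat16', 'bfloat16', 8, 1, 128, 64, 1024, 1024)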
+# ============================================================================== diff --git a/jax/experimental/pallas/ops/tpu/random/philox.py b/jax/experimental/pallas/ops/tpu/random/philox.py index 28e627cfb298..9c1c3a829510 100644 --- a/jax/experimental/pallas/ops/tpu/random/philox.py +++ b/jax/experimental/pallas/ops/tpu/random/philox.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Implementation of the Philox PRNG as a Pallas kernel.""" -from typing import Sequence +from collections.abc import Sequence import jax from jax import typing from jax._src import prng @@ -140,8 +140,8 @@ def kernel(offset_ref, key_ref, out_ref): return pl.pallas_call( kernel, in_specs=[ - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM), - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM), + pl.BlockSpec(memory_space=pltpu.SMEM), + pl.BlockSpec(memory_space=pltpu.SMEM), ], out_specs=out_spec, grid=grid_dims, diff --git a/jax/experimental/pallas/ops/tpu/random/prng_utils.py b/jax/experimental/pallas/ops/tpu/random/prng_utils.py index e5a3ac155eea..3014c7748f22 100644 --- a/jax/experimental/pallas/ops/tpu/random/prng_utils.py +++ b/jax/experimental/pallas/ops/tpu/random/prng_utils.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Helper functions for PRNG kernels.""" -from typing import Sequence +from collections.abc import Sequence from jax import lax import jax.numpy as jnp diff --git a/jax/experimental/pallas/ops/tpu/random/threefry.py b/jax/experimental/pallas/ops/tpu/random/threefry.py index 5c460d491f48..71a314e09b2d 100644 --- a/jax/experimental/pallas/ops/tpu/random/threefry.py +++ b/jax/experimental/pallas/ops/tpu/random/threefry.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Implementation of the Threefry PRNG as a Pallas kernel.""" -from typing import Sequence +from collections.abc import Sequence import jax from jax._src import prng from jax.experimental import pallas as pl @@ -79,7 +79,7 @@ def kernel(key_ref, out_ref): block_shape = (1,) * (len(shape)-2) + block_size result = pl.pallas_call( kernel, - in_specs=[pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM)], + in_specs=[pl.BlockSpec(memory_space=pltpu.SMEM)], out_specs=pl.BlockSpec(block_shape, lambda *idxs: idxs), grid=grid_dims, out_shape=out, diff --git a/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py b/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py index 4b6e4a41c43b..34d8847e6193 100644 --- a/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py +++ b/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py @@ -599,9 +599,9 @@ def _apply_mask_and_soft_cap( masks = [] if mask_ref is not None: if k_in_lanes: - mask = pl.load(mask_ref, (slice(None), k_slice)) + mask = mask_ref[:, k_slice] else: - mask = pl.load(mask_ref, (k_slice, slice(None))) + mask = mask_ref[k_slice, :] masks.append( jnp.bitwise_or(mask, jnp.broadcast_to(should_not_mask, mask.shape)) @@ -630,7 +630,7 @@ def _apply_mask_and_soft_cap( k_sequence = k_offset + jax.lax.broadcasted_iota( jnp.int32, (k_slice.size, bq), 0 ) - q_sequence = pl.load(q_sequence_ref, (pl.ds(1), slice(None))) # [1, bq] + q_sequence = q_sequence_ref[:1, :] # [1, bq] q_sequence = jnp.broadcast_to(q_sequence, (k_slice.size, bq)) assert q_sequence.shape == k_sequence.shape @@ -644,7 +644,7 @@ def _apply_mask_and_soft_cap( if q_segment_ids_ref is not None: if k_in_lanes: - kv_ids = pl.load(kv_segment_ids_ref, (pl.ds(1), k_slice)) # [1, k_slice] + kv_ids = kv_segment_ids_ref[:1, k_slice] # [1, k_slice] repeats, rem = divmod(kv_ids.shape[1], NUM_LANES) if rem: raise NotImplementedError(f"block_kv must be a multiple of {NUM_LANES}") @@ -655,9 +655,9 @@ def _apply_mask_and_soft_cap( if rem: raise NotImplementedError(f"block_q must be a multiple of {NUM_LANES}") kv_ids = pltpu.repeat( - pl.load(kv_segment_ids_ref, (k_slice, slice(None))), repeats, axis=1 + kv_segment_ids_ref[k_slice, :], repeats, axis=1 ) # [k_slice, bq] - q_ids = pl.load(q_segment_ids_ref, (pl.ds(1), slice(None))) # [1, bq] + q_ids = q_segment_ids_ref[:1, :] # [1, bq] masks.append(q_ids == kv_ids) def cap_logits(logits): @@ -743,9 +743,9 @@ def body(kv_compute_index, _): q = q_ref[...] 
if q_layout == HEAD_DIM_MINOR else q_ref[...].T qk_dims = NT_DIM_NUMBERS if k_layout == HEAD_DIM_MINOR else NN_DIM_NUMBERS if k_layout == HEAD_DIM_MINOR: - k = pl.load(k_ref, (slice_k, slice(None))) + k = k_ref[slice_k, :] else: - k = pl.load(k_ref, (slice(None), slice_k)) + k = k_ref[:, slice_k] qk = lax.dot_general(q, k, qk_dims, preferred_element_type=float32) assert qk.shape == (bq, bkv_compute) @@ -794,9 +794,9 @@ def body(kv_compute_index, _): sv_dims = NN_DIM_NUMBERS if v_layout == HEAD_DIM_MINOR else NT_DIM_NUMBERS if v_layout == HEAD_DIM_MINOR: - v = pl.load(v_ref, (slice_k, slice(None))) + v = v_ref[slice_k, :] else: - v = pl.load(v_ref, (slice(None), slice_k)) + v = v_ref[:, slice_k] v = v.astype(float32) o_curr = lax.dot_general(s_curr, v, sv_dims) @@ -1118,7 +1118,7 @@ def logsumexp_index_map(h, i, *_): out_specs=out_specs, grid=grid, ), - compiler_params=pltpu.TPUCompilerParams( + compiler_params=pltpu.CompilerParams( dimension_semantics=("parallel", "arbitrary", "arbitrary"), ), out_shape=out_shapes, @@ -1577,7 +1577,7 @@ def logsumexp_index_map(h, i, *_): grid=grid, ), out_shape=out_shapes, - compiler_params=pltpu.TPUCompilerParams( + compiler_params=pltpu.CompilerParams( dimension_semantics=("arbitrary", "arbitrary", "arbitrary"), ), name=kernel_name, @@ -1688,13 +1688,13 @@ def body(i, _): q = q_ref[...] # We keep q potentially transposed, since it's always RHS def _load_kv(ref, layout): if layout == HEAD_DIM_MINOR: - return pl.load(ref, (slice_k, slice(None))) - return pl.load(ref, (slice(None), slice_k)).T + return ref[slice_k, :] + return ref[:, slice_k].T k = _load_kv(k_ref, k_layout) v = _load_kv(v_ref, v_layout) - logsumexp = pl.load(logsumexp_ref, (pl.ds(1), slice(None))) + logsumexp = logsumexp_ref[:1, :] do = do_ref[...] - di = pl.load(di_ref, (pl.ds(1), slice(None))) + di = di_ref[:1, :] qk_dims = NT_DIM_NUMBERS if q_layout == HEAD_DIM_MINOR else NN_DIM_NUMBERS qk_uncapped = lax.dot_general( @@ -1718,10 +1718,8 @@ def _load_kv(ref, layout): ) p = jnp.exp(qk - logsumexp) dv = lax.dot(p.astype(do.dtype), do, preferred_element_type=jnp.float32) - dv = dv.astype(dv_scratch_ref.dtype) + pl.load( - dv_scratch_ref, (slice_k, slice(None)) - ) - pl.store(dv_scratch_ref, (slice_k, slice(None)), dv) + dv = dv.astype(dv_scratch_ref.dtype) + dv_scratch_ref[slice_k, :] + dv_scratch_ref[slice_k, :] = dv dp = lax.dot_general( v, do, NT_DIM_NUMBERS, @@ -1737,10 +1735,8 @@ def _load_kv(ref, layout): dk = lax.dot_general( ds.astype(do.dtype), q, dk_dims, preferred_element_type=jnp.float32 ) - dk = dk.astype(dk_scratch_ref.dtype) + pl.load( - dk_scratch_ref, (slice_k, slice(None)) - ) - pl.store(dk_scratch_ref, (slice_k, slice(None)), dk) + dk = dk.astype(dk_scratch_ref.dtype) + dk_scratch_ref[slice_k, :] + dk_scratch_ref[slice_k, :] = dk if dq_scratch_ref is not None or dq_ref is not None: dq = lax.dot_general( ds.T.astype(k.dtype), k, NN_DIM_NUMBERS, @@ -2130,7 +2126,7 @@ def logsumexp_index_map( # megacore # 2) for heads, we are reducing over heads # 3) for q_seq_len, we are reducing over it to compute dkv - compiler_params=pltpu.TPUCompilerParams( + compiler_params=pltpu.CompilerParams( dimension_semantics=("arbitrary", "arbitrary", "arbitrary"), ), name=kernel_name, @@ -2293,6 +2289,26 @@ def _splash_attention( mask_function: MaskFunctionType | None, interpret: bool, ) -> SplashCustomReturnType: + """ + For dynamic masks, `partial_mask_blocks` has shape (head_count, q_blocks, kv_blocks, block_q, block_kv). 
+ This shape allows sharding across both head count and query sequence dimensions. + + Note: The leading dimensions (head_count, q_blocks, kv_blocks) must be + collapsed into a single dimension before being passed to the kernel. + """ + def _collapse_partial_mask_blocks(mask_info: mask_info_lib.MaskInfo | None): + if mask_info is None or mask_info.partial_mask_blocks is None: + return mask_info + + return mask_info._replace( + partial_mask_blocks=mask_info.partial_mask_blocks.reshape( + -1, *mask_info.partial_mask_blocks.shape[-2:] + ) + ) + + fwd_mask_info = _collapse_partial_mask_blocks(fwd_mask_info) + dq_mask_info = _collapse_partial_mask_blocks(dq_mask_info) + dkv_mask_info = _collapse_partial_mask_blocks(dkv_mask_info) return _splash_attention_custom( fwd_mask_info, dq_mask_info, @@ -2352,13 +2368,16 @@ def manual_sharding_spec(self, sharding: jax.sharding.NamedSharding): spec = sharding.spec assert len(spec) == 2 replicated = jax.sharding.PartitionSpec() + partial_mask_blocks_spec = ( + spec if self.fwd_mask_info.is_dynamic_mask else replicated + ) # Shard q_sequence over the sequence dimension only. q_sequence_spec = jax.sharding.PartitionSpec(spec[1]) mask_info_specs = mask_info_lib.MaskInfo( # pytype: disable=wrong-arg-types data_next=spec if self.fwd_mask_info.data_next is not None else None, mask_next=spec if self.fwd_mask_info.mask_next is not None else None, block_mask=spec if self.fwd_mask_info.block_mask is not None else None, - partial_mask_blocks=replicated + partial_mask_blocks=partial_mask_blocks_spec if self.fwd_mask_info.partial_mask_blocks is not None else None, q_sequence=q_sequence_spec diff --git a/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_mask.py b/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_mask.py index eab2a695dc02..354fdb24f9df 100644 --- a/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_mask.py +++ b/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_mask.py @@ -92,6 +92,35 @@ def make_local_attention_mask( return mask.astype(np.bool_) +def make_chunk_attention_mask( + shape: tuple[int, int], chunk_size: int +) -> np.ndarray: + """Makes a chunked causal attention mask. + + Args: + shape: The desired shape of the mask (q_seq_len, kv_seq_len). + chunk_size: The size of the attention chunks. + + Returns: + A boolean mask of shape `mask_shape` where True indicates attention is + allowed according to chunked causal rules, and False otherwise. + + Raises: + ValueError: If chunk_window_size is None or not positive. + """ + if chunk_size <= 0: + raise ValueError('chunk_size must be positive') + + q_seq_len, kv_seq_len = shape + q_idx = np.arange(q_seq_len, dtype=np.int32) + kv_idx = np.arange(kv_seq_len, dtype=np.int32) + + # chunk mask calculation + same_chunk = (q_idx[:, None] // chunk_size) == (kv_idx[None, :] // chunk_size) + mask = same_chunk & (q_idx[:, None] >= kv_idx[None, :]) + return mask + + def make_random_mask( shape: tuple[int, int], sparsity: float, seed: int ) -> np.ndarray: @@ -196,15 +225,20 @@ def __hash__(self): class _ComputableMask(Mask): """Superclass for all masks that can be computed inside the kernel using a callable object. + This subclass is designed to be used with Splash Attention. + It allows the mask logic to be computed on-the-fly or fused into the attention + kernel, avoiding the memory cost of materializing the full + (sequence_length, sequence_length) boolean mask array, which can be excessive + for long sequences. 
+ Attributes: _shape: Shape of the 2-dim mask: (q_seq_len, kv_seq_len). offset: Offset of q start wrt kv. A positive offset shifts the bottom triangle upward, a negative one shifts it downward. A negative offset makes the first 'offset' rows of the attention matrix all 0s which leads to undefined softmax. - q_sequence: Indices of Q sequence. - q_sequence is reused across __getitem__ calls which is important for - compile-time performance. + q_sequence: Indices of Q sequence. q_sequence is reused across __getitem__ + calls which is important for compile-time performance. mask_function: Function used by the SplashAttention kernel to compute the mask rather than loading it. """ @@ -314,26 +348,80 @@ def __hash__(self): )) -class LocalMask(Mask): +class ChunkedCausalMask(_ComputableMask): + """Lazy chunked causal mask. + + Attention is causal within each chunk (0, K), (K, 2K), (2K, 3K), ... tokens + attend to each other but not across chunks. + Llama4 models use interleaved chunk attention along with global attention. + + + Attributes: + chunk_size: The size of each attention chunk. + """ + + chunk_size: int + + def __init__( + self, + shape: tuple[int, int], + chunk_size: int, + shard_count: int = 1, + ): + if chunk_size <= 0: + raise ValueError('chunk_size must be positive') + self.chunk_size = chunk_size + + # Define the mask function for chunk attention + def chunked_causal_mask_function(q_ids, kv_ids): + """Computes the mask logic for the given slice indices.""" + # Condition 1: Same chunk + same_chunk = (q_ids // self.chunk_size) == (kv_ids // self.chunk_size) + + # Condition 2: Causal + causal = q_ids >= kv_ids + + return same_chunk & causal + + super().__init__( + shape=shape, + mask_function=chunked_causal_mask_function, + shard_count=shard_count, + ) + + def __eq__(self, other: object): + if not isinstance(other, type(self)): + return NotImplemented + + return ( + self.shape == other.shape + and self.chunk_size == other.chunk_size + and np.array_equal(self.q_sequence, other.q_sequence) + ) + + def __hash__(self): + return hash(( + type(self), + self.shape, + self.chunk_size, + self.q_sequence.tobytes() if self.q_sequence is not None else None, + )) + + +class LocalMask(_ComputableMask): """Lazy local mask, prevents model from attending to tokens outside window. Attributes: - _shape: Shape of the 2-dim mask: (q_seq_len, kv_seq_len). - window_size: Size of the two sides of the local window (None identifes no + window_size: Size of the two sides of the local window (None identifies no limit for the given side). offset: Offset of q start wrt kv. A positive offset shifts the bottom triangle upward, a negative one shifts it downward. A negative offset makes the first 'offset' rows of the attention matrix all 0s which leads to undefined softmax. - _q_sequence: Important for performance. """ - # TODO(amagni): Transform LocalMask into a _ComputableMask. - - _shape: tuple[int, int] window_size: tuple[int | None, int | None] offset: int - _q_sequence: np.ndarray | None = None def __init__( self, @@ -342,68 +430,50 @@ def __init__( offset: int, shard_count: int = 1, ): - self._shape = shape self.window_size = window_size self.offset = offset - if self.shape[0] % (shard_count * shard_count) != 0: - raise ValueError( - f'Shard count squared ({shard_count * shard_count}) must' - f' divide Q seq_len ({self.shape[0]}) evenly.' 
- ) - - @property - def shape(self) -> tuple[int, int]: - return self._shape - - def __getitem__(self, idx) -> np.ndarray: - if len(idx) != 2: - raise NotImplementedError(f'Unsupported slice: {idx}') - q_slice, kv_slice = idx - if not isinstance(q_slice, slice) or not isinstance(kv_slice, slice): - raise NotImplementedError(f'Unsupported slice: {idx}') - - q_slice = _fill_slice(q_slice, self.shape[0]) - kv_slice = _fill_slice(kv_slice, self.shape[1]) + def local_mask_function(q_ids, kv_ids): + """Computes the local attention mask for the given slice indices.""" + left_size, right_size = self.window_size - if self._q_sequence is None: - rows = np.arange(q_slice.start, q_slice.stop) - else: - rows = self._q_sequence[q_slice] - - cols = np.arange(kv_slice.start, kv_slice.stop) + assert q_ids.ndim == 2 + assert kv_ids.ndim == 2 - left_size, right_size = self.window_size + if left_size is None and right_size is None: + return np.ones((q_ids.shape[0], kv_ids.shape[1]), dtype=np.bool_) - if left_size is None and right_size is None: - return np.ones((rows.shape[0], cols.shape[0]), dtype=np.bool_) - else: - expanded_cols = cols[None, :] - if self.offset != 0: - expanded_rows = rows[:, None] + self.offset + # Avoid the addition when possible to avoid instantiating an actual array. + if offset != 0: + shifted_q_ids = q_ids + self.offset else: - expanded_rows = rows[:, None] - if left_size is not None and right_size is not None: - return (expanded_rows <= expanded_cols + left_size) & ( - expanded_cols - right_size <= expanded_rows - ) + shifted_q_ids = q_ids + + mask = None + if left_size is not None: + mask = shifted_q_ids - left_size <= kv_ids + if right_size is not None: + if mask is None: + mask = shifted_q_ids + right_size >= kv_ids + else: + mask &= shifted_q_ids + right_size >= kv_ids + return mask - elif left_size is not None and right_size is None: - return expanded_rows <= expanded_cols + left_size - else: - assert left_size is None and right_size is not None - return expanded_cols - right_size <= expanded_rows + super().__init__( + shape=shape, + mask_function=local_mask_function, + shard_count=shard_count, + ) def __eq__(self, other: object): if not isinstance(other, type(self)): - return NotImplemented + return False return ( self.shape == other.shape and self.window_size == other.window_size and self.offset == other.offset - and (True if self._q_sequence is None else - np.array_equal(self._q_sequence, other._q_sequence)) + and np.array_equal(self.q_sequence, other.q_sequence) ) def __hash__(self): @@ -412,7 +482,7 @@ def __hash__(self): self.shape, self.window_size, self.offset, - self._q_sequence.tobytes() if self._q_sequence is not None else None, + self.q_sequence.tobytes() if self.q_sequence is not None else None, )) diff --git a/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_mask_info.py b/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_mask_info.py index 65081e79c0cf..37ef92c2d33d 100644 --- a/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_mask_info.py +++ b/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_mask_info.py @@ -67,6 +67,10 @@ class MaskInfo(NamedTuple): q_sequence: A i32[q_sequence_length] NumPy array. When using causal masking, this contains the list of indices that correspond to q tokens. For plain causal this is just np.arange(q_sequence_length). + is_dynamic_mask: A bool indicating whether the mask is dynamic or static. 
+ When True, the leading dimensions of `partial_mask_blocks` (num_heads, + q_blocks, kv_blocks) are not collapsed, allowing us to shard it along + those dimensions. """ data_next: np.ndarray | jax.Array | None @@ -74,6 +78,7 @@ class MaskInfo(NamedTuple): block_mask: np.ndarray | jax.Array | None partial_mask_blocks: np.ndarray | jax.Array | None q_sequence: np.ndarray | None + is_dynamic_mask: bool = None def _downcast_to_small_type(array: np.ndarray) -> np.ndarray: @@ -168,7 +173,7 @@ def __eq__(self, other: object) -> bool: def _get_mask_info_for_shard( output_shape: tuple[int, int, int], has_mask_next: bool, - mask: mask_lib.MultiHeadMask, + mask: mask_lib.MultiHeadMask | jax.Array, block_shape: tuple[int, int], coords_to_partial_mask_block_index: dict[tuple[int, int, int], int], masks_per_head_shard: int, @@ -338,7 +343,8 @@ def _process_dynamic_mask( launched. q_seq_shards: Number of Q sequence shards of the mesh in which the kernel is launched. - shrink_grid: Whether or not we should apply the grid shrinking optimization. This is currently ignored. + shrink_grid: Whether or not we should apply the grid shrinking optimization. + This is currently ignored. Returns: `MaskInfo`, a sparse representation of the dense mask. @@ -349,11 +355,6 @@ def _process_dynamic_mask( """ del shrink_grid - - # TODO(pobudzey): Properly support sharding. - if head_shards != 1 or q_seq_shards != 1: - raise ValueError('Dynamic mask processing does not support sharding.') - if len(mask.shape) != 3: raise ValueError(f'Expected a 3-dim mask, instead got: {mask.shape}.') @@ -370,6 +371,18 @@ def _process_dynamic_mask( if kv_mod != 0: raise ValueError(f'{kv_block_size=} should divide {kv_seq_len=}.') + q_seq_len_per_shard, mod = divmod(q_seq_len, q_seq_shards) + if mod != 0: + raise ValueError(f'{q_seq_shards=} should divide {q_seq_len=}.') + + q_blocks_per_shard, mod = divmod(q_seq_len_per_shard, q_block_size) + if mod != 0: + raise ValueError(f'{q_block_size=} should divide {q_seq_len_per_shard=}.') + + heads_per_shard, mod = divmod(head_count, head_shards) + if mod != 0: + raise ValueError(f'{head_shards=} should divide {head_count=}.') + block_mask_shape = ( head_count, q_blocks_count, @@ -398,26 +411,66 @@ def _process_dynamic_mask( block_mask = jnp.where(is_full_mask, 2, block_mask) block_mask = jnp.where(is_empty_mask, 0, block_mask) - # TODO(pobudzey): Return the next valid mask index instead of 0 for a more efficient pipeline. - mask_next = jnp.where( - jnp.logical_or(is_empty_mask, is_full_mask), - 0, - jnp.arange(math.prod(block_mask_shape), dtype=np.int32).reshape( - block_mask_shape - ), - ) + q_sequence_axis = 1 + head_axis = 0 - # data_next stores the index of the next non-empty data block in the sequence. - # The indices of empty blocks are set to 0 to avoid copying extra data when - # pipeling. - if is_dkv: - data_next = jnp.arange(q_blocks_count, dtype=np.int32)[None, :, None] - else: - data_next = jnp.arange(kv_blocks_count, dtype=np.int32)[None, None, :] - data_next = jnp.broadcast_to(data_next, block_mask_shape) - data_next = jnp.where(is_empty_mask, 0, data_next) + # Each iteration of the loop processes a slice of the mask info + # tensors of this shape: + mask_info_slice_shape = (heads_per_shard, q_blocks_per_shard, kv_blocks_count) + + # Collect mask_info shards along the head dimension, concatenate (or + # broadcast) them after the loop. 
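+ # The data_next/mask_next indices computed below are local to each (head, q_sequence) shard, which is why mask_next is later downcast against the per-shard slice size rather than the full block_mask shape.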
+ data_next_per_head_list, mask_next_per_head_list = [], [] + for head_shard in range(head_shards): + head_start = head_shard * heads_per_shard + mask_head_slice = slice(head_start, head_start + heads_per_shard) + + # Collect mask_info shards along the q_sequence dimension, concatenate them + # after the loop. + data_next_sequence_slices, mask_next_sequence_slices = [], [] + for q_seq_len_shard in range(q_seq_shards): + q_seq_len_start = q_seq_len_shard * q_blocks_per_shard + blocked_q_seq_len_slice = slice( + q_seq_len_start, q_seq_len_start + q_blocks_per_shard + ) + local_block_mask = block_mask[mask_head_slice, blocked_q_seq_len_slice] + + mask_next_slice = jnp.arange( + math.prod(mask_info_slice_shape), dtype=np.int32 + ).reshape(mask_info_slice_shape) + mask_next_slice = jnp.where(local_block_mask == 1, mask_next_slice, 0) + + # data_next stores the index of the next non-empty data block in the sequence. + # The indices of empty blocks are set to 0 to avoid copying extra data when + # pipelining. + if is_dkv: + data_next_slice = jnp.arange(q_blocks_per_shard, dtype=np.int32)[ + None, :, None + ] + else: + data_next_slice = jnp.arange(kv_blocks_count, dtype=np.int32)[ + None, None, : + ] + data_next_slice = jnp.broadcast_to(data_next_slice, mask_info_slice_shape) + data_next_slice = jnp.where(local_block_mask == 0, 0, data_next_slice) + + data_next_sequence_slices.append(data_next_slice) + mask_next_sequence_slices.append(mask_next_slice) + + # Concatenate the sequence shards. + data_next_per_head = jnp.concatenate( + data_next_sequence_slices, axis=q_sequence_axis + ) + data_next_per_head_list.append(data_next_per_head) + mask_next_per_head = jnp.concatenate( + mask_next_sequence_slices, axis=q_sequence_axis + ) + mask_next_per_head_list.append(mask_next_per_head) + + # Concatenate (or broadcast) the head shards. + data_next = jnp.concatenate(data_next_per_head_list, axis=head_axis) + mask_next = jnp.concatenate(mask_next_per_head_list, axis=head_axis) - partial_mask_blocks = partial_mask_blocks.reshape(-1, *block_shape) if is_dkv: partial_mask_blocks = partial_mask_blocks.swapaxes(-1, -2) @@ -438,9 +491,11 @@ def _downcast(array: jax.Array, max_value: int) -> jax.Array: if downcast_smem_data: block_mask = block_mask.astype(np.int8) # values are in the range [0, 1, 2] data_next = _downcast( - data_next, q_blocks_count if is_dkv else kv_blocks_count + data_next, q_blocks_per_shard if is_dkv else kv_blocks_count + ) + mask_next = _downcast( + mask_next, heads_per_shard * q_blocks_per_shard * kv_blocks_count ) - mask_next = _downcast(mask_next, math.prod(block_mask_shape)) return ( MaskInfo( @@ -449,6 +504,7 @@ def _downcast(array: jax.Array, max_value: int) -> jax.Array: block_mask=block_mask, partial_mask_blocks=partial_mask_blocks, q_sequence=None, + is_dynamic_mask=True, ), None, ) @@ -577,7 +633,7 @@ def assign_unique_ids(objects): ] # TODO(amagni): checking the validity of the masks is slow for large masks. - # Disable it for now, reevalute in the future. + # Disable it for now, reevaluate in the future. partial_mask_block_ids: dict[_HashableNDArray, int] = collections.defaultdict( lambda: len(partial_mask_block_ids) @@ -691,7 +747,7 @@ def set_block_mask(mask_id: int, q_index: int, kv_index: int, value: int): q_sequence_axis = 1 head_axis = 0 - # Collect mask_info shards along the head dimension, concatentate (or + # Collect mask_info shards along the head dimension, concatenate (or + # broadcast) them after the loop.
data_next_per_head_list, mask_next_per_head_list = [], [] for head_shard in range(shards_to_process): diff --git a/jax/experimental/pallas/tpu.py b/jax/experimental/pallas/tpu.py index ecc9d0d15120..c96bc8291c4d 100644 --- a/jax/experimental/pallas/tpu.py +++ b/jax/experimental/pallas/tpu.py @@ -19,19 +19,19 @@ from jax._src.pallas.mosaic.core import create_tensorcore_mesh as create_tensorcore_mesh from jax._src.pallas.mosaic.core import dma_semaphore as dma_semaphore from jax._src.pallas.mosaic.core import GridDimensionSemantics as GridDimensionSemantics +from jax._src.pallas.mosaic.core import KernelType as KernelType from jax._src.pallas.mosaic.core import PARALLEL as PARALLEL from jax._src.pallas.mosaic.core import PrefetchScalarGridSpec as PrefetchScalarGridSpec -from jax._src.pallas.mosaic.core import semaphore as semaphore from jax._src.pallas.mosaic.core import SemaphoreType as SemaphoreType -from jax._src.pallas.mosaic.core import TPUMemorySpace as TPUMemorySpace -from jax._src.pallas.mosaic.core import TPUCompilerParams as TPUCompilerParams -from jax._src.pallas.mosaic.core import runtime_assert_enabled as runtime_assert_enabled -from jax._src.pallas.mosaic.core import _ENABLE_RUNTIME_ASSERT as enable_runtime_assert # noqa: F401 +from jax._src.pallas.mosaic.core import MemorySpace as MemorySpace +from jax._src.pallas.mosaic.core import CompilerParams as CompilerParams from jax._src.pallas.mosaic.helpers import sync_copy as sync_copy from jax._src.pallas.mosaic.helpers import core_barrier as core_barrier from jax._src.pallas.mosaic.helpers import run_on_first_core as run_on_first_core +from jax._src.pallas.mosaic.interpret import InterpretParams as InterpretParams from jax._src.pallas.mosaic.lowering import LoweringException as LoweringException from jax._src.pallas.mosaic.pipeline import BufferedRef as BufferedRef +from jax._src.pallas.mosaic.pipeline import BufferedRefBase as BufferedRefBase from jax._src.pallas.mosaic.pipeline import emit_pipeline as emit_pipeline from jax._src.pallas.mosaic.pipeline import emit_pipeline_with_allocations as emit_pipeline_with_allocations from jax._src.pallas.mosaic.pipeline import get_pipeline_schedule as get_pipeline_schedule @@ -40,21 +40,26 @@ from jax._src.pallas.mosaic.primitives import async_remote_copy as async_remote_copy from jax._src.pallas.mosaic.primitives import bitcast as bitcast from jax._src.pallas.mosaic.primitives import delay as delay -from jax._src.pallas.mosaic.primitives import device_id as device_id -from jax._src.pallas.mosaic.primitives import DeviceIdType as DeviceIdType from jax._src.pallas.mosaic.primitives import get_barrier_semaphore as get_barrier_semaphore +from jax._src.pallas.mosaic.primitives import get_memory_space as get_memory_space from jax._src.pallas.mosaic.primitives import make_async_copy as make_async_copy from jax._src.pallas.mosaic.primitives import make_async_remote_copy as make_async_remote_copy from jax._src.pallas.mosaic.primitives import prng_random_bits as prng_random_bits from jax._src.pallas.mosaic.primitives import prng_seed as prng_seed from jax._src.pallas.mosaic.primitives import repeat as repeat from jax._src.pallas.mosaic.primitives import roll as roll -from jax._src.pallas.mosaic.primitives import semaphore_read as semaphore_read -from jax._src.pallas.mosaic.primitives import semaphore_signal as semaphore_signal -from jax._src.pallas.mosaic.primitives import semaphore_wait as semaphore_wait +from jax._src.pallas.mosaic.primitives import with_memory_space_constraint as 
with_memory_space_constraint from jax._src.pallas.mosaic.random import sample_block as sample_block from jax._src.pallas.mosaic.random import to_pallas_key as to_pallas_key +# Those primitives got moved to Pallas core. Keeping the updated imports +# here for backward compatibility. +from jax._src.pallas.core import semaphore as semaphore +from jax._src.pallas.primitives import DeviceIdType as DeviceIdType +from jax._src.pallas.primitives import semaphore_read as semaphore_read +from jax._src.pallas.primitives import semaphore_signal as semaphore_signal +from jax._src.pallas.primitives import semaphore_wait as semaphore_wait + import types from jax._src.pallas.mosaic.verification import assume from jax._src.pallas.mosaic.verification import pretend @@ -65,8 +70,30 @@ ) del types, assume, pretend, skip, define_model # Clean up. -ANY = TPUMemorySpace.ANY -CMEM = TPUMemorySpace.CMEM -SMEM = TPUMemorySpace.SMEM -VMEM = TPUMemorySpace.VMEM -SEMAPHORE = TPUMemorySpace.SEMAPHORE +ANY = MemorySpace.ANY +CMEM = MemorySpace.CMEM +SMEM = MemorySpace.SMEM +VMEM = MemorySpace.VMEM +HBM = MemorySpace.HBM +SEMAPHORE = MemorySpace.SEMAPHORE + +import typing as _typing # pylint: disable=g-import-not-at-top +if _typing.TYPE_CHECKING: + TPUCompilerParams = CompilerParams + TPUMemorySpace = MemorySpace +else: + from jax._src.deprecations import deprecation_getattr as _deprecation_getattr + _deprecations = { + # Deprecated on May 30th 2025. + "TPUCompilerParams": ( + "TPUCompilerParams is deprecated, use CompilerParams instead.", + CompilerParams, + ), + "TPUMemorySpace": ( + "TPUMemorySpace is deprecated, use MemorySpace instead.", + MemorySpace, + ), + } + __getattr__ = _deprecation_getattr(__name__, _deprecations) + del _deprecation_getattr +del _typing diff --git a/jax/experimental/pallas/triton.py b/jax/experimental/pallas/triton.py index 06adb9e6da7e..1c512540adf2 100644 --- a/jax/experimental/pallas/triton.py +++ b/jax/experimental/pallas/triton.py @@ -14,7 +14,23 @@ """Triton-specific Pallas APIs.""" -from jax._src.pallas.triton.core import TritonCompilerParams as TritonCompilerParams +from jax._src.pallas.triton.core import CompilerParams as CompilerParams from jax._src.pallas.triton.primitives import approx_tanh as approx_tanh from jax._src.pallas.triton.primitives import debug_barrier as debug_barrier from jax._src.pallas.triton.primitives import elementwise_inline_asm as elementwise_inline_asm + +import typing as _typing # pylint: disable=g-import-not-at-top +if _typing.TYPE_CHECKING: + TritonCompilerParams = CompilerParams +else: + from jax._src.deprecations import deprecation_getattr as _deprecation_getattr + _deprecations = { + # Deprecated on May 27th 2025. + "TritonCompilerParams": ( + "TritonCompilerParams is deprecated, use CompilerParams instead.", + CompilerParams, + ), + } + __getattr__ = _deprecation_getattr(__name__, _deprecations) + del _deprecation_getattr +del _typing diff --git a/jax/experimental/profiler.py b/jax/experimental/profiler.py index 766d20472155..f22fba50092b 100644 --- a/jax/experimental/profiler.py +++ b/jax/experimental/profiler.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from jax._src.lib import xla_client +from jax._src.lib import _profiler def get_profiled_instructions_proto(tensorboard_dir: str) -> bytes: @@ -30,4 +30,4 @@ def get_profiled_instructions_proto(tensorboard_dir: str) -> bytes: Serialized [ProfiledInstructionsProto](https://github.com/openxla/xla/blob/main/third_party/tsl/tsl/profiler/protobuf/profiled_instructions.proto). """ - return xla_client.profiler.get_profiled_instructions_proto(tensorboard_dir) + return _profiler.get_profiled_instructions_proto(tensorboard_dir) diff --git a/jax/experimental/rnn.py b/jax/experimental/rnn.py index 55cf2b3bae70..be06aba2db13 100644 --- a/jax/experimental/rnn.py +++ b/jax/experimental/rnn.py @@ -463,7 +463,7 @@ def _gpu_lowering_strip_tf32(fn, *args, cudnn_allow_tf32, **kw): rnn_fwd_p.def_abstract_eval(rnn_abstract_eval) if gpu_rnn: mlir.register_lowering(rnn_fwd_p, gpu_rnn.cudnn_rnn_lowering, platform='cuda') - if hasattr(gpu_rnn, "miopen_rnn_fwd_lowering"): + if hasattr(gpu_rnn, "miopen_rnn_lowering"): mlir.register_lowering(rnn_fwd_p, gpu_rnn.miopen_rnn_lowering, platform='rocm') diff --git a/jax/experimental/roofline/roofline.py b/jax/experimental/roofline/roofline.py index 6a7f2916b503..b711fd62e069 100644 --- a/jax/experimental/roofline/roofline.py +++ b/jax/experimental/roofline/roofline.py @@ -14,7 +14,8 @@ from __future__ import annotations from dataclasses import dataclass, field -from typing import Any, Callable, Protocol, Sequence +from typing import Any, Protocol +from collections.abc import Callable, Sequence import numpy as np import jax.numpy as jnp @@ -29,11 +30,11 @@ from jax._src.mesh import AbstractMesh, Mesh from jax._src.tree_util import broadcast_prefix, tree_flatten, tree_unflatten, tree_map from jax._src.util import foreach -from jax.experimental import shard_map +from jax._src.shard_map import shard_map, shard_map_p ShapeDtypeStructTree = Any - +Specs = Any map = util.safe_map @@ -56,7 +57,7 @@ class RooflineShape: dtype: np.dtype @classmethod - def from_aval(cls, aval: core.AbstractValue) -> "RooflineShape": + def from_aval(cls, aval: core.AbstractValue) -> RooflineShape: if not isinstance(aval, core.ShapedArray): raise TypeError(f"Expected ShapedArray, got {type(aval)}.") if not isinstance(aval.dtype, np.dtype): @@ -87,10 +88,10 @@ class RooflineResult: unfused_hbm_bytes: int = 0 @classmethod - def zeros(cls) -> "RooflineResult": + def zeros(cls) -> RooflineResult: return cls() - def __add__(self, other: "RooflineResult") -> "RooflineResult": + def __add__(self, other: RooflineResult) -> RooflineResult: def merge_ici_dicts(d1: dict[str, int], d2: dict[str, int]) -> dict[str, int]: return {k: d1.get(k, 0) + d2.get(k, 0) for k in set(d1) | set(d2)} @@ -104,7 +105,7 @@ def merge_ici_dicts(d1: dict[str, int], d2: dict[str, int]) -> dict[str, int]: unfused_hbm_bytes=self.unfused_hbm_bytes + other.unfused_hbm_bytes, ) - def __mul__(self, constant: int | float) -> "RooflineResult": + def __mul__(self, constant: int | float) -> RooflineResult: return RooflineResult( flops=int(self.flops * constant), unfused_flops=int(self.unfused_flops * constant), @@ -115,7 +116,7 @@ def __mul__(self, constant: int | float) -> "RooflineResult": unfused_hbm_bytes=int(self.unfused_hbm_bytes * constant), ) - def __rmul__(self, constant: int | float) -> "RooflineResult": + def __rmul__(self, constant: int | float) -> RooflineResult: return self.__mul__(constant) @@ -188,6 +189,16 @@ def calculate_peak_hbm_bytes() -> int: pin_lhs_in_vmem=pin_lhs_in_vmem, pin_rhs_in_vmem=pin_rhs_in_vmem, ) + elif 
"call_jaxpr" in eqn.params: + # Used for custom_jvp_call_p. Recursively calculates roofline result for + # all primitives in the custom function. + result += _roofline_interpreter( + util.wrap_name(f_name, eqn.primitive.name), + eqn.params['call_jaxpr'], + mesh, + pin_lhs_in_vmem=pin_lhs_in_vmem, + pin_rhs_in_vmem=pin_rhs_in_vmem, + ) else: if eqn.primitive not in _rooflines: msg = f"No roofline rule for {eqn.primitive}." @@ -230,8 +241,8 @@ def wrapped(*args): def roofline( f: Callable, mesh: Mesh | AbstractMesh | None = None, - in_specs: shard_map.Specs | None = None, - out_specs: shard_map.Specs | None = None, + in_specs: Specs | None = None, + out_specs: Specs | None = None, *, pin_lhs_in_vmem: bool = False, pin_rhs_in_vmem: bool = False, @@ -243,14 +254,15 @@ def roofline( def wrapped(*args): wrapped_f = f if in_specs is not None and out_specs is not None and mesh is not None: - wrapped_f = shard_map.shard_map(wrapped_f, mesh, in_specs, out_specs) + wrapped_f = shard_map(wrapped_f, mesh=mesh, in_specs=in_specs, + out_specs=out_specs) if vjp: wrapped_f = _f_with_vjp(wrapped_f) jaxpr, out_shapes = make_jaxpr(wrapped_f, return_shape=True)(*args) def make_sharded_shape_dtype_struct( - shape: api.ShapeDtypeStruct, out_spec: shard_map.Specs + shape: api.ShapeDtypeStruct, out_spec: Specs ) -> api.ShapeDtypeStruct: return api.ShapeDtypeStruct( shape.shape, shape.dtype, sharding=NamedSharding(mesh, out_spec) # type: ignore @@ -267,7 +279,7 @@ def make_sharded_shape_dtype_struct( used_outputs = (True,) * len(jaxpr.jaxpr.outvars) jaxpr, _ = dce_jaxpr(jaxpr.jaxpr, used_outputs) shard_map_eqns = [ - e for e in jaxpr.eqns if e.primitive == shard_map.shard_map_p + e for e in jaxpr.eqns if e.primitive == shard_map_p ] if shard_map_eqns: try: @@ -307,8 +319,8 @@ def standard_rule(ctx: RooflineRuleContext, *args, **kwargs): def roofline_and_grad( f: Callable, mesh: Mesh | AbstractMesh, - in_specs: shard_map.Specs, - out_specs: shard_map.Specs, + in_specs: Specs, + out_specs: Specs, *, pin_lhs_in_vmem: bool = False, pin_rhs_in_vmem: bool = False, diff --git a/jax/experimental/roofline/rooflines.py b/jax/experimental/roofline/rooflines.py index 1edd1e0649b1..4941f95e8e1c 100644 --- a/jax/experimental/roofline/rooflines.py +++ b/jax/experimental/roofline/rooflines.py @@ -14,6 +14,7 @@ from collections import defaultdict from dataclasses import replace import itertools as it +from collections.abc import Sequence import numpy as np from jax._src import ad_util @@ -21,8 +22,10 @@ from jax._src import ops from jax._src import prng from jax._src import random +from jax._src import shard_map from jax._src.lax import ( ann, + control_flow, convolution, fft, lax, @@ -33,12 +36,14 @@ windowed_reductions, ) from jax.experimental import roofline -from jax.experimental import shard_map +# One FMA (Fused Multiply Add) takes 2 flops to compute. 
+_FMA_FLOPS_FACTOR = 2 for prim in it.chain( ad_util.__dict__.values(), ann.__dict__.values(), + control_flow.__dict__.values(), convolution.__dict__.values(), fft.__dict__.values(), lax.__dict__.values(), @@ -106,6 +111,8 @@ def _unary_p_roofline( roofline.register_roofline(special.erfc_p)(_unary_p_roofline) roofline.register_roofline(special.lgamma_p)(_unary_p_roofline) +roofline.register_standard_roofline(core.pvary_p) + def _binary_p_roofline( ctx: roofline.RooflineRuleContext, *args, @@ -143,6 +150,50 @@ def _binary_p_roofline( roofline.register_roofline(lax.min_p)(_binary_p_roofline) roofline.register_roofline(lax.max_p)(_binary_p_roofline) +def _cumulative_p_roofline( + ctx: roofline.RooflineRuleContext, + *args, + axis: int, + **kw, +) -> roofline.RooflineResult: + (x,) = (roofline.RooflineShape.from_aval(aval) for aval in ctx.avals_in) + out = roofline.RooflineShape.from_aval(ctx.avals_out[0]) + return roofline.RooflineResult( + # `cum{max, min, prod, sum}` only calculate values for one axis. + unfused_flops=x.shape[axis], + unfused_hbm_bytes=( + x.dtype.itemsize * x.size + out.dtype.itemsize * out.size + ), + ) + +roofline.register_roofline(control_flow.cummax_p)(_cumulative_p_roofline) +roofline.register_roofline(control_flow.cummin_p)(_cumulative_p_roofline) +roofline.register_roofline(control_flow.cumprod_p)(_cumulative_p_roofline) +roofline.register_roofline(control_flow.cumsum_p)(_cumulative_p_roofline) + +@roofline.register_roofline(control_flow.cumlogsumexp_p) +def _cumlogsumexp_p_roofline( + ctx: roofline.RooflineRuleContext, + *args, + axis: int, + **kw, +) -> roofline.RooflineResult: + (x,) = (roofline.RooflineShape.from_aval(aval) for aval in ctx.avals_in) + out = roofline.RooflineShape.from_aval(ctx.avals_out[0]) + return roofline.RooflineResult( + # Similar to `cum{max, min, prod, sum}`, `cumlogsumexp` only calculates + # values for one axis. But for `x.shape[axis] = S`, it computes (for a + # naive implementation): + # S `exp` ops. + # S-1 `add` ops. + # 1 log op. + # Thus, the total number of flops is 2 * S. + unfused_flops=x.shape[axis] * 2, + unfused_hbm_bytes=( + x.dtype.itemsize * x.size + out.dtype.itemsize * out.size + ), + ) + @roofline.register_roofline(lax.dot_general_p) def _dot_general_roofline( @@ -156,7 +207,7 @@ def _dot_general_roofline( (lhs_contract, _), (lhs_batch, _) = dimension_numbers flops = ( - 2 + _FMA_FLOPS_FACTOR * lhs.size * rhs.size / np.prod([lhs.shape[i] for i in lhs_contract]) @@ -177,16 +228,208 @@ def _dot_general_roofline( unfused_hbm_bytes=hbm_bytes, ) + +def _get_spatial_valid_position_count_for_one_dim( + window_dim_stride: int, + base_dilation: int, + window_dilation: int, + kernel_limit: int, + input_limit: int, + output_limit: int, + padding: tuple[int, int], +) -> int: + """Gets the valid position count for conv for a single spatial dimension. + + Args: + window_dim_stride: The stride of the window along this dimension. + base_dilation: The base dilation factor along this dimension. + window_dilation: The window dilation factor along this dimension. + kernel_limit: The size of the kernel along this dimension. + input_limit: The size of the input along this dimension. + output_limit: The size of the output along this dimension. + padding: The padding applied to the input along this dimension. + """ + padding_low = padding[0] + padding_high = padding[1] + + # These two conditions will create an N^2 iteration pattern with only N + # valid elements. 
This is a performance optimization and produces the same + # result as the whole loop. + if ( + input_limit == output_limit + and kernel_limit == output_limit + and input_limit == base_dilation + and window_dilation == 1 + and max(1, input_limit - 1) == window_dim_stride + and padding_low == 0 + and padding_high == 0 + ): + return input_limit + + if ( + input_limit == 1 + and kernel_limit == output_limit + and window_dilation == 1 + and base_dilation == 1 + and window_dim_stride == 1 + and padding_low == output_limit - 1 + and padding_high == output_limit - 1 + ): + return output_limit + + valid_position_count = 0 + # Loop over each point in the kernel + for kernel_idx in range(kernel_limit): + + # Skip loop for trivial stride and base_dilation + if window_dim_stride == 1 and base_dilation == 1: + undilated_index_base = padding_low - kernel_idx * window_dilation + upper_limit = min( + input_limit + undilated_index_base, + output_limit, + ) + lower_limit = max(0, undilated_index_base) + + valid_position_count += max(upper_limit - lower_limit, 0) + continue + + # Loop over each point in the output + for output_idx in range(output_limit): + # Calculate lhs (input) index without taking base dilation into account + undilated_index = ( + output_idx * window_dim_stride + - padding_low + + kernel_idx * window_dilation + ) + # Calculate the actual lhs (input) index after dilation + lhs_spatial_index = int(undilated_index / base_dilation) + + # Skip if the lhs (input) index is to be dilated. + if undilated_index != lhs_spatial_index * base_dilation: + continue + # Skip if input index is not in bound. + if lhs_spatial_index < 0 or lhs_spatial_index >= input_limit: + continue + + valid_position_count += 1 + return valid_position_count + + +def _get_spatial_valid_position_count( + dnums: convolution.ConvDimensionNumbers, + lhs: roofline.RooflineShape, + rhs: roofline.RooflineShape, + out: roofline.RooflineShape, + window_strides: Sequence[int], + padding: Sequence[tuple[int, int]], + lhs_dilation: Sequence[int], + rhs_dilation: Sequence[int], +) -> int: + """Gets the number of valid spatial positions for conv_general_dilated. + + Args: + dnums: The dimension numbers for the convolution. + lhs: The shape of the left-hand side of the convolution. + rhs: The shape of the right-hand side of the convolution. + out: The shape of the output of the convolution. + window_strides: The stride of the window along each spatial dimension. + padding: The padding applied to the input along each spatial dimension. + lhs_dilation: The dilation factor for the left-hand side along each spatial + dimension. + rhs_dilation: The dilation factor for the right-hand side along each spatial + dimension. + """ + input_spatial_dims, kernel_spatial_dims, out_spatial_dims = ( + dnums.lhs_spec[2:], + dnums.rhs_spec[2:], + dnums.out_spec[2:], + ) + + valid_position_counts = 1 + # Loop over each spatial dimension and determine how many valid positions + # there are for each dimension. 
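+ # For example, a 1-D convolution with input length 8, kernel length 3, stride 1 and + # no padding or dilation has output length 6, giving 3 * 6 = 18 valid positions for that dimension.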
+ for d in range(len(input_spatial_dims)): + valid_position_counts *= _get_spatial_valid_position_count_for_one_dim( + window_dim_stride=window_strides[d], + base_dilation=lhs_dilation[d], + window_dilation=rhs_dilation[d], + kernel_limit=rhs.shape[kernel_spatial_dims[d]], + input_limit=lhs.shape[input_spatial_dims[d]], + output_limit=out.shape[out_spatial_dims[d]], + padding=padding[d], + ) + + return valid_position_counts + + +def _calculate_conv_flops( + lhs: roofline.RooflineShape, + rhs: roofline.RooflineShape, + out: roofline.RooflineShape, + window_strides: Sequence[int], + padding: Sequence[tuple[int, int]], + lhs_dilation: Sequence[int], + rhs_dilation: Sequence[int], + dimension_numbers: convolution.ConvGeneralDilatedDimensionNumbers, + batch_group_count: int, +) -> int: + """Calculates roofline unfused flops for Jax's conv_general_dilated primitive. + + See `jax.lax.conv_general_dilated` for details on the arguments. + """ + dnums = convolution.conv_dimension_numbers( + lhs.shape, rhs.shape, dimension_numbers + ) + + spatial_valid_position_counts = _get_spatial_valid_position_count( + dnums, lhs, rhs, out, window_strides, padding, lhs_dilation, rhs_dilation + ) + + batch = lhs.shape[dnums.lhs_spec[0]] + num_output_features = out.shape[dnums.out_spec[1]] + num_input_features = rhs.shape[dnums.rhs_spec[1]] + num_output_batch = batch / batch_group_count + + non_spatial_dims_factor = ( + num_input_features * num_output_features * num_output_batch + ) + + fma_count = non_spatial_dims_factor * spatial_valid_position_counts + flops = fma_count * _FMA_FLOPS_FACTOR + return int(flops) + + @roofline.register_roofline(convolution.conv_general_dilated_p) def _conv_general_dilated_roofline( - ctx: roofline.RooflineRuleContext, - *args, - **kw, + ctx: roofline.RooflineRuleContext, + *args, + window_strides: Sequence[int], + padding: Sequence[tuple[int, int]], + lhs_dilation: Sequence[int], + rhs_dilation: Sequence[int], + dimension_numbers: convolution.ConvGeneralDilatedDimensionNumbers, + batch_group_count: int, + **kw, ) -> roofline.RooflineResult: + """Roofline for Jax's conv_general_dilated primitive. + + See `jax.lax.conv_general_dilated` for details on the arguments. + """ lhs, rhs = (roofline.RooflineShape.from_aval(aval) for aval in ctx.avals_in) out = roofline.RooflineShape.from_aval(ctx.avals_out[0]) - # TODO(b/394648206): support computing unfused_flops for conv. + return roofline.RooflineResult( + unfused_flops=_calculate_conv_flops( + lhs, + rhs, + out, + window_strides, + padding, + lhs_dilation, + rhs_dilation, + dimension_numbers, + batch_group_count, + ), unfused_hbm_bytes=( lhs.dtype.itemsize * lhs.size + rhs.dtype.itemsize * rhs.size @@ -257,11 +500,33 @@ def _ring_collective_roofline( ) +@roofline.register_roofline(slicing.gather_p) +def _gather_roofline( + ctx: roofline.RooflineRuleContext, + *args, + **kw, +) -> roofline.RooflineResult: + _, indices = (roofline.RooflineShape.from_aval(aval) for aval in ctx.avals_in) + out = roofline.RooflineShape.from_aval(ctx.avals_out[0]) + + # Gather doesn't read the whole input buffer, it's equivalent to a copy the + # size of the output shape and a read of the gather indices. + bytes = ( + out.dtype.itemsize * out.size * 2 + indices.dtype.itemsize * indices.size + ) + + return roofline.RooflineResult( + # Gather does not issue any flops. 
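+ # The factor of 2 in `bytes` above covers one output-sized read of the operand and one output-sized write of the result; the second term is the read of the indices.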
+ unfused_flops=0, + unfused_hbm_bytes=bytes, + ) + + def _scalar_collective_roofline( - ctx: roofline.RooflineRuleContext, - *args, - axes: tuple[str, ...], - **kw, + ctx: roofline.RooflineRuleContext, + *args, + axes: tuple[str, ...], + **kw, ) -> roofline.RooflineResult: shapes = [roofline.RooflineShape.from_aval(aval) for aval in ctx.avals_in] ctx = replace(ctx, avals_in=[core.ShapedArray((1,), shape.dtype) for shape in shapes]) @@ -272,7 +537,7 @@ def _scalar_collective_roofline( roofline.register_roofline(lax_parallel.pmax_p)(_scalar_collective_roofline) -@roofline.register_roofline(shard_map.psum2_p) +@roofline.register_roofline(lax_parallel.psum_invariant_p) def _psum2_roofline( ctx: roofline.RooflineRuleContext, *args, diff --git a/jax/experimental/serialize_executable.py b/jax/experimental/serialize_executable.py index 2d65141a22ea..e1d068ec789f 100644 --- a/jax/experimental/serialize_executable.py +++ b/jax/experimental/serialize_executable.py @@ -20,6 +20,7 @@ import jax from jax._src.lib import xla_client as xc +from collections.abc import Sequence def serialize(compiled: jax.stages.Compiled): @@ -43,14 +44,27 @@ def serialize(compiled: jax.stages.Compiled): def deserialize_and_load(serialized, in_tree, out_tree, - backend: str | xc.Client | None = None): + backend: str | xc.Client | None = None, + execution_devices: Sequence[xc.Device] | None = None): """Constructs a jax.stages.Compiled from a serialized executable.""" if backend is None or isinstance(backend, str): backend = jax.devices(backend)[0].client + if execution_devices is None: + execution_devices = backend.devices() + else: + device_backend = execution_devices[0].client + if device_backend != backend: + raise ValueError( + 'Execution devices belong to a client other than `backend`. Got ' + f'backend client: {(backend.platform, backend.platform_version)} and ' + 'execution devices client: ' + f'{(device_backend.platform, device_backend.platform_version)}') + (unloaded_executable, args_info_flat, - no_kwargs) = _JaxPjrtUnpickler(io.BytesIO(serialized), backend).load() + no_kwargs) = _JaxPjrtUnpickler( + io.BytesIO(serialized), backend, execution_devices).load() args_info = in_tree.unflatten(args_info_flat) @@ -77,14 +91,26 @@ def persistent_id(self, obj): class _JaxPjrtUnpickler(pickle.Unpickler): - def __init__(self, file, backend): + def __init__(self, file, backend, execution_devices=None): super().__init__(file) self.backend = backend - self.devices_by_id = {d.id: d for d in backend.devices()} + if execution_devices is None: + execution_devices = backend.devices() + else: + device_backend = execution_devices[0].client + if device_backend != backend: + raise ValueError( + 'Execution devices belong to a client other than `backend`. 
Got ' + f'backend client: {(backend.platform, backend.platform_version)} ' + 'and execution devices client: ' + f'{(device_backend.platform, device_backend.platform_version)}') + self.devices_by_id = {d.id: d for d in execution_devices} + self.execution_devices = xc.DeviceList(tuple(execution_devices)) def persistent_load(self, pid): if pid[0] == 'exec': - return self.backend.deserialize_executable(pid[1]) + return self.backend.deserialize_executable( + pid[1], executable_devices=self.execution_devices) if pid[0] == 'device': return self.devices_by_id[pid[1]] if pid[0] == 'client': diff --git a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py index 66b70c6c2d34..027cbd36feea 100644 --- a/jax/experimental/shard_map.py +++ b/jax/experimental/shard_map.py @@ -11,82 +11,20 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import annotations -from collections.abc import Callable, Hashable, Sequence -import enum -from functools import partial -import inspect -import itertools as it -from math import prod -import operator as op -from typing import Any, TypeVar, Union - -import numpy as np - -import jax -import jax.numpy as jnp -from jax.sharding import NamedSharding, PartitionSpec -from jax._src import ad_checkpoint -from jax._src import ad_util -from jax._src import api_util -from jax._src import callback -from jax._src import config -from jax._src import core -from jax._src import custom_derivatives -from jax._src import debugging -from jax._src import dispatch -from jax._src import dtypes -from jax._src import linear_util as lu -from jax._src import ops -from jax._src import pjit -from jax._src import prng -from jax._src import random -from jax._src import sharding_impls -from jax._src import source_info_util +from collections.abc import Callable, Hashable +from typing import Any from jax._src import traceback_util -from jax._src import util -from jax._src.core import Tracer -from jax._src.mesh import (AbstractMesh, Mesh, AxisType, use_abstract_mesh, - get_abstract_mesh) -from jax._src.api import _shared_code_pmap, _prepare_pmap -from jax._src.lax import (lax, parallel as lax_parallel, slicing, - windowed_reductions, convolution, fft, linalg, - special, control_flow, ann) -from jax._src import ffi -from jax._src.lib.mlir import ir -from jax._src.lib.mlir.dialects import sdy -from jax._src.util import (HashableFunction, HashablePartial, unzip2, - as_hashable_function, memoize, partition_list, - merge_lists, split_list, subs_list2, foreach) -from jax._src.interpreters import batching -from jax._src.interpreters import mlir -from jax._src.interpreters import partial_eval as pe -from jax._src.interpreters import pxla -from jax._src.interpreters import ad -from jax.tree_util import (tree_map, tree_flatten, tree_unflatten, - tree_structure, tree_leaves, keystr) -from jax._src.tree_util import (broadcast_prefix, prefix_errors, PyTreeDef, - generate_key_paths, KeyPath) -from jax.experimental.multihost_utils import (host_local_array_to_global_array, - global_array_to_host_local_array) - -P = PartitionSpec - -map, unsafe_map = util.safe_map, map -zip, unsafe_zip = util.safe_zip, zip -traceback_util.register_exclusion(__file__) +from jax.sharding import Mesh, AbstractMesh +from jax._src import shard_map as jshmap -# API - -Specs = Any # PyTree[PartitionSpec] +Specs = Any AxisName = Hashable - @traceback_util.api_boundary -def shard_map(f: 
Callable, mesh: Mesh | AbstractMesh, in_specs: Specs, - out_specs: Specs, check_rep: bool = True, - auto: frozenset[AxisName] = frozenset()): +def shard_map( + f: Callable, mesh: Mesh | AbstractMesh, in_specs: Specs, out_specs: Specs, + check_rep: bool = True, auto: frozenset[AxisName] = frozenset()): """Map a function over shards of data. Note: @@ -115,7 +53,7 @@ def shard_map(f: Callable, mesh: Mesh | AbstractMesh, in_specs: Specs, out_specs: a pytree with :class:`~jax.sharding.PartitionSpec` instances as leaves, with a tree structure that is a tree prefix of the output of ``f``. Each ``PartitionSpec`` represents how the corresponding output shards should be - concatenated. In each ``PartitionSpec``, metioning a ``mesh`` axis name at + concatenated. In each ``PartitionSpec``, mentioning a ``mesh`` axis name at a position expresses concatenation of that mesh axis's shards along the corresponding positional axis. Not mentioning a ``mesh`` axis name expresses a promise that the output values are equal along that mesh axis, @@ -136,2074 +74,9 @@ def shard_map(f: Callable, mesh: Mesh | AbstractMesh, in_specs: Specs, Examples: For examples, refer to :ref:`sharded-computation` or `SPMD multi-device parallelism with shard_map`_. - .. _SPMD multi-device parallelism with shard_map: https://jax.readthedocs.io/en/latest/notebooks/shard_map.html + .. _SPMD multi-device parallelism with shard_map: https://docs.jax.dev/en/latest/notebooks/shard_map.html """ - return _shard_map(f, mesh, in_specs, out_specs, check_rep, auto) - -def _shard_map(f: Callable, mesh: Mesh | AbstractMesh, in_specs: Specs, - out_specs: Specs | Callable[[], Specs], - check_rep: bool, auto: frozenset[AxisName]): - if not callable(f): - raise TypeError("shard_map requires a callable for its first argument, " - f"but got {f} of type {type(f)}.") - if not isinstance(mesh, (Mesh, AbstractMesh)): - raise TypeError("shard_map requires a `jax.sharding.Mesh` or a " - "`jax.sharding.AbstractMesh` instance for its " - f"second argument, but got {mesh} of type {type(mesh)}.") - if not auto.issubset(mesh.axis_names): - raise ValueError(f"shard_map requires auto={auto} to be a subset of " - f"mesh.axis_names={mesh.axis_names}") - _check_specs(SpecErrorType.input, in_specs, auto) - if not callable(out_specs): - _check_specs(SpecErrorType.out, out_specs, auto) - - @util.wraps(f) - @traceback_util.api_boundary - def wrapped(*args): - fun = lu.wrap_init(f, - debug_info=api_util.debug_info("shard_map", f, args, {})) - args_flat, in_tree = tree_flatten(args) - fun, out_tree = api_util.flatten_fun_nokwargs(fun, in_tree) - try: in_specs_flat = broadcast_prefix(in_specs, args, - is_leaf=lambda x: x is None) - except ValueError: - e, *_ = prefix_errors(in_specs, args) - raise e('shard_map in_specs') from None - dyn_argnums, in_specs_flat = unzip2((i, s) for i, s in enumerate(in_specs_flat) - if s is not None) - fun, args_flat = api_util.argnums_partial(fun, dyn_argnums, args_flat, False) - _check_specs_vs_args(f, mesh, in_tree, in_specs, dyn_argnums, in_specs_flat, args_flat) - in_names_flat = tuple(map(_canonicalize_spec, in_specs_flat)) - - @memoize - def out_names_thunk(): - if callable(out_specs): - out_specs_ = out_specs() - _check_specs(SpecErrorType.out, out_specs_, auto) - else: - out_specs_ = out_specs - dummy = tree_unflatten(out_tree(), [object()] * out_tree().num_leaves) - try: out_specs_flat = broadcast_prefix(out_specs_, dummy) - except ValueError: - e, *_ = prefix_errors(out_specs_, dummy) - raise e('shard_map out_specs') from None - 
return tuple(map(_canonicalize_spec, out_specs_flat)) - - if rewrite := check_rep: - fun = _efficient_transpose_rewrite(fun, mesh, in_names_flat, out_names_thunk) - - try: - out_flat = shard_map_p.bind( - fun, *args_flat, mesh=mesh, in_names=in_names_flat, - out_names_thunk=out_names_thunk, check_rep=check_rep, rewrite=rewrite, - auto=auto) - except _SpecError as e: - fails, = e.args - if not callable(out_specs): - msg = _spec_rank_error(SpecErrorType.out, f, out_tree(), out_specs, fails) - if any(fail is not no_fail and not fail.shape for fail in fails): - msg += (" In particular, for rank 0 outputs which are not constant " - "over the mesh, add at least one (singleton) axis to them so " - "that they can be concatenated using out_specs.") - raise ValueError(msg) from None - except _RepError as e: - fails, = e.args - if not callable(out_specs): - msg = _inout_rep_error(f, mesh, out_tree(), out_specs, fails) - raise ValueError(msg) from None - return tree_unflatten(out_tree(), out_flat) - return wrapped - -# Internally use AxisNames = dict[int, tuple[AxisName, ...]], not PartitionSpecs -AxisNames = dict[int, tuple[AxisName, ...]] # TODO(mattjj): make it hashable -def _canonicalize_spec(spec: PartitionSpec) -> AxisNames: - if isinstance(spec, PartitionSpec): - return {i: names if isinstance(names, tuple) else (names,) - for i, names in enumerate(spec) if names is not None} - else: - return spec - -# Error checking and messages - -SpecErrorType = enum.Enum('SpecErrorType', ['input', 'out']) - -def _check_specs(error_type: SpecErrorType, specs: Any, auto) -> None: - if error_type == SpecErrorType.input and specs is None: - raise TypeError( - "shard_map in_specs argument must be a pytree of " - "`jax.sharding.PartitionSpec` instances, but it was None.\n" - "Instead of `in_specs=None`, did you mean `in_specs=P()`, " - "where `P = jax.sharding.PartitionSpec`?") - def check_spec(p): - if not isinstance(p, PartitionSpec): - return False - for names in p: - if not isinstance(names, tuple): - names = (names,) - for name in names: - if name in auto: - return False - return True - if all(check_spec(p) for p in tree_leaves(specs)): return - prefix = 'in' if error_type == SpecErrorType.input else 'out' - msgs = [f" {prefix}_specs{keystr(key)} is {x} of type {type(x).__name__}, " - for key, x in generate_key_paths(specs) if not isinstance(x, P)] - if not msgs: - for key, p in generate_key_paths(specs): - for names in p: - if not isinstance(names, tuple): - names = (names,) - for name in names: - if name in auto: - msgs.append(f" {prefix}_specs{keystr(key)} refers to {repr(name)}") - raise ValueError( - f"shard_map {prefix}_specs argument cannot refer to an axis " - f"marked auto ({auto}), but:\n\n" - + '\n\n'.join(msgs) + '\n\n' - f"Check the {prefix}_specs values passed to shard_map.") - raise TypeError( - f"shard_map {prefix}_specs argument must be a pytree of " - f"`jax.sharding.PartitionSpec` instances, but:\n\n" - + '\n\n'.join(msgs) + '\n\n' - f"Check the {prefix}_specs values passed to shard_map.") - -class NoFail: pass -no_fail = NoFail() - -def _check_specs_vs_args( - f: Callable, mesh: Mesh, in_tree: PyTreeDef, in_specs: Specs, - dyn_argnums: Sequence[int], in_specs_flat: Sequence[P], - xs: Sequence) -> None: - in_avals = map(core.shaped_abstractify, xs) - fail = [a if not len(p) <= a.ndim else no_fail - for p, a in zip(in_specs_flat, in_avals)] - if any(f is not no_fail for f in fail): - fail = _expand_fail(in_tree, dyn_argnums, fail) - msg = _spec_rank_error(SpecErrorType.input, f, in_tree, 
in_specs, fail) - raise ValueError(msg) - in_names_flat = tuple(map(_canonicalize_spec, in_specs_flat)) - fail = [a if any(a.shape[d] % prod(mesh.shape[n] for n in ns) - for d, ns in names.items()) else no_fail - for a, names in zip(in_avals, in_names_flat)] - if any(f is not no_fail for f in fail): - fail = _expand_fail(in_tree, dyn_argnums, fail) - msg = _spec_divisibility_error(f, mesh, in_tree, in_specs, fail) - raise ValueError(msg) - -def _expand_fail(in_tree: PyTreeDef, dyn_argnums: Sequence[int], - fail: Sequence[core.ShapedArray | NoFail] - ) -> list[core.ShapedArray | NoFail]: - fail_: list[core.ShapedArray | NoFail] = [no_fail] * in_tree.num_leaves - for i, f in zip(dyn_argnums, fail): - fail_[i] = f - return fail_ - -def _spec_rank_error( - error_type: SpecErrorType, f: Callable, tree: PyTreeDef, specs: Specs, - fails: list[core.ShapedArray | NoFail]) -> str: - fun_name = getattr(f, '__name__', str(f)) - if error_type == SpecErrorType.input: - prefix, base = 'in', 'args' - ba = _try_infer_args(f, tree) - else: - prefix, base = 'out', f'{fun_name}(*args)' - msgs = [] - for (spec_key, spec), (fail_key, aval) in _iter_paths(tree, specs, fails): - extra = "" - if error_type == SpecErrorType.input and ba is not None: - arg_key, *_ = fail_key - if arg_key.idx < len(ba.arguments): - param_name = list(ba.arguments.keys())[arg_key.idx] - extra = (f", where {base}{arg_key} is bound to {fun_name}'s " - f"parameter '{param_name}',") - else: - param = list(ba.signature.parameters.values())[-1] - assert param.kind == inspect.Parameter.VAR_POSITIONAL - extra = (f", where {base}{arg_key} is the index " - f"{arg_key.idx - len(ba.signature.parameters) + 1} component " - f"of {fun_name}'s varargs parameter '{param.name}',") - msgs.append( - f"* {prefix}_specs{keystr(spec_key)} is {spec} which has length " - f"{len(spec)}, but " - f"{base}{keystr(fail_key)}{extra} has shape {aval.str_short()}, " - f"which has rank {aval.ndim} (and {aval.ndim} < {len(spec)})") - assert msgs - if len(msgs) == 1: msgs = [msgs[0][2:]] # remove the bullet point - msg = (f"shard_map applied to the function '{fun_name}' was given an " - f"{prefix}_specs entry which is too long to be compatible with the " - f"corresponding {prefix}put value from the function:\n\n" - + '\n\n'.join(msgs) + '\n\n' + - f"Entries in {prefix}_specs must be of length no greater than the " - f"number of axes in the corresponding {prefix}put value.\n\n" - f"Either revise the spec to be shorter, or modify '{fun_name}' so " - f"that its {prefix}puts have sufficient rank.") - if any(not aval.ndim for _, (_, aval) in _iter_paths(tree, specs, fails)): - msg += (f"\n\nFor scalar values (rank 0), consider using an {prefix}_specs " - "entry of `P()`, where `P = jax.sharding.PartitionSpec`.") - return msg - -def _spec_divisibility_error( - f: Callable, mesh: Mesh, tree: PyTreeDef, specs: Specs, - fails: list[core.ShapedArray | NoFail]) -> str: - ba = _try_infer_args(f, tree) - fun_name = getattr(f, '__name__', str(f)) - msgs = [] - for (spec_key, spec), (fail_key, aval) in _iter_paths(tree, specs, fails): - extra = "" - if ba is not None: - arg_key, *_ = fail_key - if arg_key.idx < len(ba.arguments): - param_name = list(ba.arguments.keys())[arg_key.idx] - extra = (f", where args{arg_key} is bound to {fun_name}'s " - f"parameter '{param_name}',") - else: - param = list(ba.signature.parameters.values())[-1] - assert param.kind == inspect.Parameter.VAR_POSITIONAL - extra = (f", where args{arg_key} is the index " - f"{arg_key.idx - len(ba.signature.parameters) 
+ 1} component " - f"of {fun_name}'s varargs parameter '{param.name}',") - names = _canonicalize_spec(spec) - for d, ns in names.items(): - if aval.shape[d] % prod(mesh.shape[n] for n in ns): - axis = f"axes {ns}" if len(ns) > 1 else f"axis '{ns[0]}'" - total = 'total ' if len(ns) > 1 else '' - sz = prod(mesh.shape[n] for n in ns) - msgs.append( - f"* args{keystr(fail_key)} of shape {aval.str_short()}{extra} " - f"corresponds to in_specs{keystr(spec_key)} of value {spec}, " - f"which maps array axis {d} (of size {aval.shape[d]}) to mesh " - f"{axis} (of {total}size {sz}), but {sz} does not evenly divide " - f"{aval.shape[d]}") - assert msgs - if len(msgs) == 1: msgs = [msgs[0][2:]] # remove the bullet point - msg = (f"shard_map applied to the function '{fun_name}' was given argument " - f"arrays with axis sizes that are not evenly divisible by the " - f"corresponding mesh axis sizes:\n\n" - f"The mesh given has shape {tuple(mesh.shape.values())} with " - f"corresponding axis names {mesh.axis_names}.\n\n" - + '\n\n'.join(msgs) + '\n\n' + - f"Array arguments' axis sizes must be evenly divisible by the mesh " - f"axis or axes indicated by the corresponding elements of the " - f"argument's in_specs entry. Consider checking that in_specs are " - f"correct, and if so consider changing the mesh axis sizes or else " - f"padding the input and adapting '{fun_name}' appropriately.") - return msg - -def _inout_rep_error(f: Callable, mesh: Mesh, tree: PyTreeDef, specs: Specs, - fails: list[set | NoFail]) -> str: - fun_name = getattr(f, '__name__', str(f)) - msgs = [] - for (spec_key, spec), (fail_key, rep) in _iter_paths(tree, specs, fails): - dst = _canonicalize_spec(spec) - unmentioned = _unmentioned(mesh, dst) - if len(unmentioned) > 1: - need_rep = ','.join(map(str, unmentioned)) - got_rep = ','.join(map(str, rep)) - diff = ','.join(map(str, [n for n in unmentioned if n not in rep])) - msgs.append( - f"* out_specs{keystr(spec_key)} is {spec} which implies that the " - f"corresponding output value is replicated across mesh axes " - f"{{{need_rep}}}, but could only infer replication over {{{got_rep}}}, " - f"which is missing the required axes {diff}") - else: - need_rep_, = unmentioned - msgs.append( - f"* out_specs{keystr(spec_key)} is {spec} which implies that the " - f"corresponding output value is replicated across mesh axis " - f"'{need_rep_}', but could not infer replication over any axes") - assert msgs - if len(msgs) == 1: msgs = [msgs[0][2:]] # remove the bullet point - msg = (f"shard_map applied to the function '{fun_name}' was given " - f"out_specs which require replication which can't be statically " - f"inferred given the mesh:\n\n" - f"The mesh given has shape {tuple(mesh.shape.values())} with " - f"corresponding axis names {mesh.axis_names}.\n\n" - + '\n\n'.join(msgs) + '\n\n' + - "Check if these output values are meant to be replicated over those " - "mesh axes. If not, consider revising the corresponding out_specs " - "entries. 
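The divisibility error message assembled above boils down to a simple shape constraint. A hedged sketch, under the same four-device assumption as before (names chosen only for illustration):

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

mesh = Mesh(jax.devices()[:4], axis_names=('i',))   # mesh axis 'i' has size 4
identity = lambda block: block

# 12 is divisible by 4, so each shard receives a (3,) block.
ok = shard_map(identity, mesh, in_specs=P('i'), out_specs=P('i'))(jnp.zeros(12))

# An axis of size 10 is not divisible by 4 and raises the ValueError whose
# text is constructed in _spec_divisibility_error above.
# shard_map(identity, mesh, in_specs=P('i'), out_specs=P('i'))(jnp.zeros(10))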
If so, consider disabling the check by passing the " - "check_rep=False argument to shard_map.") - return msg - -def _unmentioned(mesh: Mesh, names: AxisNames) -> list[AxisName]: - name_set = {n for ns in names.values() for n in ns} - return [n for n in mesh.axis_names if n not in name_set] - - -def _try_infer_args(f, tree): - dummy_args = tree_unflatten(tree, [False] * tree.num_leaves) - try: - return inspect.signature(f).bind(*dummy_args) - except (TypeError, ValueError): - return None - -T = TypeVar('T') -def _iter_paths(tree: PyTreeDef, specs: Specs, fails: list[T | NoFail] - ) -> list[tuple[tuple[KeyPath, P], tuple[KeyPath, T]]]: - failures = tree_unflatten(tree, fails) - failures_aug = generate_key_paths(failures) - specs_ = tree_unflatten(tree_structure(specs), generate_key_paths(specs)) - leaf = lambda x: x is None or type(x) is tuple and len(x) == 2 and type(x[1]) is P - specs_aug = broadcast_prefix(specs_, failures, is_leaf=leaf) - return [(s, (fail_key, fail_data)) for s, (fail_key, fail_data) - in zip(specs_aug, failures_aug) - if s is not None and fail_data is not no_fail] - -# Primitive - -JaxType = Any -MaybeTracer = Union[JaxType, Tracer] - -class ShardMapPrimitive(core.Primitive): - multiple_results = True - - def bind(self, *args, **params): - return self._true_bind(*args, **params) - - def bind_with_trace(self, trace, fun_and_args, params): - fun: lu.WrappedFun - fun, *args = fun_and_args - return trace.process_shard_map(shard_map_p, fun, args, **params) - - def get_bind_params(self, params): - new_params = dict(params) - jaxpr: core.Jaxpr = new_params.pop('jaxpr') - subfun = lu.hashable_partial(lu.wrap_init(core.eval_jaxpr, - debug_info=jaxpr.debug_info), - jaxpr, ()) - axes = new_params.pop('out_names') - new_params['out_names_thunk'] = HashableFunction(lambda: axes, closure=axes) - return [subfun], new_params - -shard_map_p = ShardMapPrimitive('shard_map') - -# Staging - -@util.cache(max_size=256, trace_context_in_key=True) -def _as_manual_mesh(mesh, auto: frozenset): - manual_axes = tuple(set(mesh.axis_names) - auto) - cur_mesh = get_abstract_mesh() - if cur_mesh.empty: - cur_mesh = mesh - explicit_axes, auto_axes = set(), set() # type: ignore - for a in auto: - if cur_mesh._name_to_type[a] == AxisType.Auto: - auto_axes.add(a) - else: - assert cur_mesh._name_to_type[a] == AxisType.Explicit - explicit_axes.add(a) - - new_axis_types = [] - for n in mesh.axis_names: - if n in manual_axes: - new_axis_types.append(AxisType.Manual) - elif n in auto_axes: - new_axis_types.append(AxisType.Auto) - else: - assert n in explicit_axes - new_axis_types.append(AxisType.Explicit) - return AbstractMesh(mesh.axis_sizes, mesh.axis_names, - axis_types=tuple(new_axis_types)) - - -def _extend_axis_env(mesh, auto): - return core.extend_axis_env_nd([(k, v) for k, v in mesh.shape.items() - if k not in auto]) - -def _shard_map_staging( - trace: pe.DynamicJaxprTrace, prim: core.Primitive, f: lu.WrappedFun, - in_tracers: Sequence[Any], *, mesh: Mesh, - in_names: tuple[AxisNames, ...], - out_names_thunk: Callable[[], tuple[AxisNames, ...]], - check_rep: bool, - rewrite: bool, - auto: frozenset, - ) -> Sequence[pe.DynamicJaxprTracer]: - in_tracers = map(trace.to_jaxpr_tracer, in_tracers) - in_avals = [t.aval for t in in_tracers] - in_avals_ = map(partial(_shard_aval, mesh, auto), in_names, in_avals) - manual_mesh = _as_manual_mesh(mesh, auto) - with _extend_axis_env(mesh, auto), use_abstract_mesh(manual_mesh): - jaxpr, out_avals_, consts, () = pe.trace_to_jaxpr_dynamic(f, in_avals_) - 
_check_names(out_names_thunk(), out_avals_) - if check_rep: - in_rep = map(partial(_in_names_to_rep, mesh), in_names) - out_rep = _check_rep(mesh, jaxpr, in_rep) - _check_reps(mesh, out_names_thunk(), out_rep) - out_avals = map(_check_shapedarray, out_avals_) - out_avals = [_check_shapedarray(_unshard_aval(mesh, names, aval)) - for names, aval in zip(out_names_thunk(), out_avals)] - source_info = source_info_util.current() - out_tracers = [pe.DynamicJaxprTracer(trace, a, source_info) for a in out_avals] - invars = map(trace.getvar, in_tracers) - constvars = map(trace.getvar, map(trace.to_jaxpr_tracer, consts)) - outvars = map(trace.makevar, out_tracers) - in_names_staged = ({},) * len(consts) + tuple(in_names) # type: ignore - with _extend_axis_env(mesh, auto), use_abstract_mesh(manual_mesh): - jaxpr = pe.convert_constvars_jaxpr(jaxpr) - params = dict(mesh=mesh, in_names=in_names_staged, - out_names=tuple(out_names_thunk()), jaxpr=jaxpr, - check_rep=check_rep, rewrite=rewrite, auto=auto) - effs = core.filter_named_axis_effects(jaxpr.effects, mesh.axis_names) - eqn = pe.new_jaxpr_eqn([*constvars, *invars], outvars, prim, params, - effs, source_info) - trace.frame.add_eqn(eqn) - return out_tracers -pe.DynamicJaxprTrace.process_shard_map = _shard_map_staging - -# TODO add underscore version, for direct-linearize to consume - -def _check_shapedarray(aval: core.AbstractValue) -> core.ShapedArray: - assert isinstance(aval, core.ShapedArray) - return aval - -def _shard_aval(mesh: Mesh, auto, names: AxisNames, aval: core.AbstractValue - ) -> core.AbstractValue: - if type(aval) in core.shard_aval_handlers: - return core.shard_aval_handlers[type(aval)](mesh, auto, names, aval) - raise NotImplementedError(f"Unsupported aval type: {type(aval)}") - -def _unshard_aval(mesh: Mesh, names: AxisNames, aval: core.AbstractValue - ) -> core.AbstractValue: - if type(aval) in core.unshard_aval_handlers: - return core.unshard_aval_handlers[type(aval)](mesh, names, aval) - else: - raise NotImplementedError(f"Unsupported aval type: {type(aval)}") - -def _shard_shaped_array(mesh: Mesh, auto: frozenset, names: AxisNames, - aval: core.AbstractValue) -> core.AbstractValue: - assert isinstance(aval, core.ShapedArray) - new_shape = tuple(sz // prod(mesh.shape[n] for n in names.get(i, ())) - for i, sz in enumerate(aval.shape)) - manual_mesh = _as_manual_mesh(mesh, auto) - new_sharding = NamedSharding(manual_mesh, aval.sharding.spec) - return aval.update(shape=new_shape, sharding=new_sharding) -core.shard_aval_handlers[core.ShapedArray] = _shard_shaped_array - -def _unshard_shaped_array(mesh: Mesh, names: AxisNames, - aval: core.AbstractValue,) -> core.AbstractValue: - assert isinstance(aval, core.ShapedArray) - new_shape = tuple(sz * prod(mesh.shape[n] for n in names.get(i, ())) - for i, sz in enumerate(aval.shape)) - names_spec = _names_to_pspec(names)._normalized_spec_for_aval(aval.ndim) - if aval.ndim == 0: - out_spec = names_spec - else: - out_spec = [] # type: ignore - for name_s, aval_s in zip(names_spec, aval.sharding.spec): - if name_s and not aval_s: - out_spec.append(name_s) - elif aval_s and not name_s: - out_spec.append(aval_s) - elif not name_s and not aval_s: - out_spec.append(None) - else: - assert name_s and aval_s - name_s = name_s if isinstance(name_s, tuple) else (name_s,) - aval_s = aval_s if isinstance(aval_s, tuple) else (aval_s,) - out_spec.append(name_s + aval_s) - out_spec = PartitionSpec(*out_spec) - new_mesh = (mesh.abstract_mesh if get_abstract_mesh().empty else - get_abstract_mesh()) - 
new_sharding = NamedSharding(new_mesh, out_spec) - return aval.update(shape=new_shape, sharding=new_sharding) -core.unshard_aval_handlers[core.ShapedArray] = _unshard_shaped_array - -# Type-checking - -RepType = Union[set[AxisName], None] - -def _shard_map_typecheck(_, *in_atoms, jaxpr, mesh, in_names, out_names, - check_rep, rewrite, auto): - # TODO(mattjj,parkers): check auto - for v, x, in_name in zip(jaxpr.invars, in_atoms, in_names): - if not core.typecompat(v.aval, _shard_aval(mesh, auto, in_name, x.aval)): - raise core.JaxprTypeError("shard_map argument avals not compatible with " - "jaxpr binder avals and in_names") - with _extend_axis_env(mesh, auto): - core.check_jaxpr(jaxpr) - if check_rep: - in_rep = map(partial(_in_names_to_rep, mesh), in_names) - out_rep = _check_rep(mesh, jaxpr, in_rep) - for rep, dst in zip(out_rep, out_names): - if not _valid_repeats(mesh, rep, dst): - raise core.JaxprTypeError("shard_map can't prove output is " - "sufficiently replicated") - out_avals_sharded = [x.aval for x in jaxpr.outvars] - out_avals = map(partial(_unshard_aval, mesh), out_names, out_avals_sharded) - effs = core.filter_named_axis_effects(jaxpr.effects, mesh.axis_names) - return out_avals, effs -core.custom_typechecks[shard_map_p] = _shard_map_typecheck - -def _in_names_to_rep(mesh: Mesh, names: AxisNames) -> set[AxisName]: - return set(mesh.axis_names) - {n for ns in names.values() for n in ns} - -def _check_rep(mesh: Mesh, jaxpr: core.Jaxpr, in_rep: Sequence[RepType] - ) -> Sequence[RepType]: - env: dict[core.Var, RepType] = {} - - def read(x: core.Atom) -> RepType: - return env[x] if type(x) is core.Var else None - - def write(v: core.Var, val: RepType) -> None: - env[v] = val - - foreach(write, jaxpr.constvars, [set(mesh.axis_names)] * len(jaxpr.constvars)) - foreach(write, jaxpr.invars, in_rep) - last_used = core.last_used(jaxpr) - for e in jaxpr.eqns: - rule = _check_rules.get(e.primitive, partial(_rule_missing, e.primitive)) - out_rep = rule(mesh, *map(read, e.invars), **e.params) - if e.primitive.multiple_results: - out_rep = [out_rep] * len(e.outvars) if type(out_rep) is set else out_rep - foreach(write, e.outvars, out_rep) - else: - write(e.outvars[0], out_rep) - core.clean_up_dead_vars(e, env, last_used) - return map(read, jaxpr.outvars) - -def _valid_repeats(mesh: Mesh, rep: RepType, dst: AxisNames) -> bool: - return rep is None or set(_unmentioned(mesh, dst)).issubset(rep) - -def _rule_missing(prim: core.Primitive, *_, **__): - raise NotImplementedError( - f"No replication rule for {prim}. As a workaround, pass the " - "`check_rep=False` argument to `shard_map`. To get this fixed, open an " - "issue at https://github.com/jax-ml/jax/issues") - -# Lowering - - -def _shardy_shard_map_sharding( - ctx: mlir.LoweringRuleContext, mesh, auto, names, aval_in -) -> sharding_impls.SdyArraySharding: - axes = {name: i for i, ns in names.items() for name in ns} - ns = _make_scoped_manual_sharding(ctx, mesh, axes) - if dtypes.issubdtype(aval_in.dtype, dtypes.extended): - ns = sharding_impls.physical_sharding(aval_in, ns) - aval_in = core.physical_aval(aval_in) - sdy_sharding = ns._to_sdy_sharding(aval_in.ndim) - if auto: - for dim_sharding in sdy_sharding.dimension_shardings: - # Only allow dimensions which have no sharding to be auto-sharded. 
- if not dim_sharding.axes: - dim_sharding.is_closed = False - return sdy_sharding - - -def _shard_map_lowering_shardy( - ctx, in_nodes, jaxpr, mesh, in_names, out_names, auto): - in_avals_ = [v.aval for v in jaxpr.invars] - if isinstance(ctx.module_context.axis_context, sharding_impls.SPMDAxisContext): - # Nested `ManualComputationOp`s cannot refer to axes that are already - # manual. So figure out what axes are free thus far. - free_axes = frozenset(mesh.axis_names) - ctx.module_context.axis_context.manual_axes - shardy_manual_axes = free_axes - auto - else: - shardy_manual_axes = frozenset(mesh.axis_names) - auto - new_axis_context = sharding_impls.SPMDAxisContext( - mesh, frozenset(mesh.axis_names) - auto) - sub_ctx = ctx.module_context.replace(axis_context=new_axis_context) - - # The order of manual axes should match the order of mesh.axis_names to avoid - # non-determinism issues. - manual_axes = [a for a in mesh.axis_names - if a in shardy_manual_axes] - if np.prod([mesh.shape[a] for a in manual_axes]) == 1: - # No need for a `ManualComputationOp` if all manual axes are size 1. - with _extend_axis_env(mesh, auto): - out_nodes, _ = mlir.jaxpr_subcomp( - sub_ctx, jaxpr, ctx.name_stack, mlir.TokenSet(), (), *in_nodes, - dim_var_values=ctx.dim_var_values) - return out_nodes - - in_shardings = sharding_impls.SdyArrayShardingList(map( - partial(_shardy_shard_map_sharding, ctx, mesh, auto), - in_names, ctx.avals_in)).build() - out_shardings = sharding_impls.SdyArrayShardingList(map( - partial(_shardy_shard_map_sharding, ctx, mesh, auto), - out_names, ctx.avals_out)).build() - output_types = map(mlir.aval_to_ir_type, ctx.avals_out) - manual_computation_op = sdy.ManualComputationOp( - output_types, in_nodes, in_shardings, out_shardings, - sdy.ManualAxesAttr.get( - ir.ArrayAttr.get([ir.StringAttr.get(i) for i in manual_axes]))) - block = ir.Block.create_at_start( - manual_computation_op.body, map(mlir.aval_to_ir_type, in_avals_)) - with ir.InsertionPoint(block), _extend_axis_env(mesh, auto): - out_nodes_, _ = mlir.jaxpr_subcomp( - sub_ctx, jaxpr, ctx.name_stack, mlir.TokenSet(), (), *block.arguments, - dim_var_values=ctx.dim_var_values) - sdy.ReturnOp([ir.Value(x) for x in out_nodes_]) - - return manual_computation_op.results - - -def _shard_map_lowering(ctx, *in_nodes, jaxpr, mesh, in_names, out_names, - check_rep, rewrite, auto): - del check_rep, rewrite - - if config.use_shardy_partitioner.value: - return _shard_map_lowering_shardy( - ctx, in_nodes, jaxpr, mesh, in_names, out_names, auto) - - in_avals_ = [v.aval for v in jaxpr.invars] - out_avals_ = [x.aval for x in jaxpr.outvars] - in_nodes_ = map(partial(_xla_shard, ctx, mesh, auto), in_names, ctx.avals_in, - in_avals_, in_nodes) - manual_axes = frozenset(mesh.axis_names) - auto - new_axis_context = sharding_impls.SPMDAxisContext(mesh, manual_axes) - sub_ctx = ctx.module_context.replace(axis_context=new_axis_context) - with _extend_axis_env(mesh, auto): - out_nodes_, tokens_out = mlir.call_lowering( - "shmap_body", ctx.name_stack, jaxpr, None, sub_ctx, in_avals_, - out_avals_, ctx.tokens_in, *in_nodes_, dim_var_values=ctx.dim_var_values, - arg_names=map(_pspec_mhlo_attrs, in_names, in_avals_), - result_names=map(_pspec_mhlo_attrs, out_names, out_avals_)) - ctx.set_tokens_out(tokens_out) - return map(partial(_xla_unshard, ctx, mesh, auto), out_names, out_avals_, - ctx.avals_out, out_nodes_) -mlir.register_lowering(shard_map_p, _shard_map_lowering) - -def _make_scoped_manual_sharding(ctx, mesh, axes): - axis_ctx = 
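Both lowering paths in this region (the Shardy ManualComputationOp path and the default call-based path) can be inspected from user code without touching internals. A minimal sketch, again assuming four devices; what appears in the printed text depends on whether the Shardy partitioner is enabled:

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

mesh = Mesh(jax.devices()[:4], axis_names=('i',))
f = shard_map(lambda b: 2.0 * b, mesh, in_specs=P('i'), out_specs=P('i'))

# The emitted StableHLO shows either a Shardy manual-computation region or the
# sharding annotations plus full-to-shard / shard-to-full wrappers produced by
# the non-Shardy lowering above.
print(jax.jit(f).lower(jnp.zeros(12)).as_text())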
ctx.module_context.axis_context - if isinstance(axis_ctx, sharding_impls.SPMDAxisContext): - manual_axes = axis_ctx.manual_axes - else: - manual_axes = frozenset({}) - return NamedSharding( - mesh, sharding_impls.array_mapping_to_axis_resources(axes), # pytype: disable=wrong-arg-types - _manual_axes=manual_axes) - -def _xla_shard(ctx: mlir.LoweringRuleContext, mesh, auto, names, - aval_in, aval_out, x): - if prod([size for n, size in mesh.shape.items() if n not in auto]) == 1: - return x - axes = {name: i for i, ns in names.items() for name in ns} - ns = _make_scoped_manual_sharding(ctx, mesh, axes) - if dtypes.issubdtype(aval_in.dtype, dtypes.extended): - ns = sharding_impls.physical_sharding(aval_in, ns) - aval_in = core.physical_aval(aval_in) - shard_proto = ns._to_xla_hlo_sharding(aval_in.ndim).to_proto() - unspecified = set(range(aval_in.ndim)) if auto else set() - sx = mlir.wrap_with_sharding_op(ctx, x, aval_in, shard_proto, - unspecified_dims=unspecified) - manual_proto = pxla.manual_proto(aval_in, frozenset(mesh.axis_names) - auto, mesh) - return mlir.wrap_with_full_to_shard_op(ctx, sx, aval_out, manual_proto, unspecified) - -def _xla_unshard(ctx: mlir.LoweringRuleContext, mesh, auto, names, - aval_in, aval_out, x): - if prod([size for n, size in mesh.shape.items() if n not in auto]) == 1: - return x - axes = {name: i for i, ns in names.items() for name in ns} - ns = _make_scoped_manual_sharding(ctx, mesh, axes) - if dtypes.issubdtype(aval_out.dtype, dtypes.extended): - ns = sharding_impls.physical_sharding(aval_out, ns) - aval_out = core.physical_aval(aval_out) - unspecified = set(range(aval_out.ndim)) if auto else set() - if dtypes.issubdtype(aval_in.dtype, dtypes.extended): - aval_in = core.physical_aval(aval_in) - manual_proto = pxla.manual_proto(aval_in, frozenset(mesh.axis_names) - auto, mesh) - sx = mlir.wrap_with_sharding_op(ctx, x, aval_in, manual_proto, unspecified_dims=unspecified) - shard_proto = ns._to_xla_hlo_sharding(aval_out.ndim).to_proto() - return mlir.wrap_with_shard_to_full_op(ctx, sx, aval_out, shard_proto, - unspecified) - -def _pspec_mhlo_attrs(names: AxisNames, aval: core.AbstractValue) -> str: - if isinstance(aval, core.ShapedArray): - return str(map(names.get, range(aval.ndim))) - return '' - -# Eager evaluation - -def get_mesh_from_args(args_flat, mesh): - for a in args_flat: - if hasattr(a, 'sharding') and isinstance(a.sharding, NamedSharding): - if a.sharding.mesh.shape_tuple != mesh.shape_tuple: - aval = core.shaped_abstractify(a) - raise ValueError( - f"Mesh shape of the input {a.sharding.mesh.shape_tuple} does not" - " match the mesh shape passed to shard_map " - f" {mesh.shape_tuple} for shape {aval.str_short()}") - mesh = a.sharding.mesh - if isinstance(mesh, AbstractMesh): - raise ValueError( - "Please pass `jax.Array`s with a `NamedSharding` as input to" - " `shard_map` when passing `AbstractMesh` to the mesh argument.") - assert isinstance(mesh, Mesh) - return mesh - -def _shard_map_impl(trace, prim, fun, args, *, mesh, in_names, out_names_thunk, - check_rep, rewrite, auto): - if auto: raise NotImplementedError - del prim - if isinstance(mesh, AbstractMesh): - mesh = get_mesh_from_args(args, mesh) - args = map(partial(_unmatch_spec, mesh, context_mesh=get_abstract_mesh()), - in_names, args) - in_rep = map(partial(_in_names_to_rep, mesh), in_names) - outs, out_rep = _run_shmap(fun, mesh, auto, args, in_rep, check_rep, - get_abstract_mesh()) - out_avals = [core.mapped_aval(x.shape[0], 0, core.get_aval(x)) for x in outs] - 
_check_names(out_names_thunk(), out_avals) # pytype: disable=wrong-arg-types - if check_rep: - _check_reps(mesh, out_names_thunk(), out_rep) - pspecs = map(_names_to_pspec, out_names_thunk()) - return map(partial(_match_spec, mesh, check_rep), pspecs, outs) -core.EvalTrace.process_shard_map = _shard_map_impl - -def _run_shmap(f, mesh, auto, args, reps, check_rep, context_mesh): - trace = ShardMapTrace(mesh, auto, check_rep, context_mesh) - in_tracers = map(partial(ShardMapTracer, trace), reps, args) - manual_mesh = _as_manual_mesh(mesh, auto) - with (core.set_current_trace(trace), _extend_axis_env(mesh, auto), - use_abstract_mesh(manual_mesh)): - ans = f.call_wrapped(*in_tracers) - outs, out_rep = unzip2(map(trace.to_val_rep_pair, ans)) - return outs, out_rep - -def _names_to_pspec(names: AxisNames) -> PartitionSpec: - ndmin = max(names) + 1 if names else 0 - unpack = lambda t: t[0] if t is not None and len(t) == 1 else t - return PartitionSpec(*(unpack(names.get(i)) for i in range(ndmin))) - -def _unmatch_spec(mesh: Mesh, src: AxisNames, x: JaxType, context_mesh) -> JaxType: - with (core.eval_context(), jax.disable_jit(False), - use_abstract_mesh(context_mesh)): - return jax.jit(HashablePartial(_unmatch, mesh, tuple(src.items())))(x) - -def _unmatch(mesh, src_tup, x): - src = _names_to_pspec(dict(src_tup)) - dst = P(mesh.axis_names) - return shard_map(_add_singleton, mesh, (src,), dst, check_rep=False)(x) - -def _check_names(names: Sequence[AxisNames], avals: Sequence[core.ShapedArray] - ) -> None: - fail = [a if n and not max(n) < a.ndim else no_fail - for n, a in zip(names, avals)] - if any(f is not no_fail for f in fail): raise _SpecError(fail) -class _SpecError(Exception): pass - -def _check_reps(mesh, names, reps): - fail = [r if not _valid_repeats(mesh, r, n) else no_fail - for n, r in zip(names, reps)] - if any(f is not no_fail for f in fail): raise _RepError(fail) -class _RepError(Exception): pass - -def _check_reps2(mesh, reps_dest, reps): - fail = [src if not dst.issubset(src) else no_fail - for dst, src in zip(reps_dest, reps)] - if any(f is not no_fail for f in fail): raise _RepError(fail) - -def _match_spec(mesh: Mesh, check_rep: bool, - pspec: PartitionSpec, x: JaxType) -> JaxType: - fn = HashablePartial(_match, mesh, check_rep, pspec) - with core.eval_context(), jax.disable_jit(False): - return jax.jit(fn, out_shardings=NamedSharding(mesh, pspec))(x) - -def _match(mesh, check_rep, pspec, x): - src = P(mesh.axis_names) - return shard_map(_rem_singleton, mesh, (src,), pspec, check_rep=False)(x) - -def _rem_singleton(x): return jnp.squeeze(x, axis=0) -def _add_singleton(x): return jnp.expand_dims(x, axis=0) - -def _maybe_check_special(outs): - if not config.debug_nans.value and not config.debug_infs.value: return - bufs = [s.data for leaf in tree_leaves(outs) - for s in getattr(leaf, 'addressable_shards', [])] - try: - dispatch.check_special('shard_map', bufs) - except dispatch.InternalFloatingPointError as e: - raise FloatingPointError(f'Invalid value ({e.ty}) encountered in sharded computation.') from None - -class ShardMapTrace(core.Trace): - __slots__ = ("mesh", "auto", "check", "context_mesh") - - mesh: Mesh - auto: frozenset[AxisName] - check: bool - context_mesh: AbstractMesh - - def __init__(self, mesh, auto, check, context_mesh): - super().__init__() - self.mesh = mesh - self.auto = auto - self.check = check - self.context_mesh = context_mesh - - def to_val_rep_pair(self, val): - if isinstance(val, ShardMapTracer): - return val.val, val.rep - elif isinstance(val, 
Tracer): - raise Exception(f"Shouldn't have any non-shard_map tracers: {val}") - else: - val_ = _unmatch_spec(self.mesh, {}, val, self.context_mesh) - return val_, None - - def process_primitive(self, prim, tracers, params): - in_vals, in_rep = unzip2(map(self.to_val_rep_pair, tracers)) - eager_rule = eager_rules.get(prim) - if eager_rule: - out_vals = eager_rule(self.mesh, *in_vals, **params) - else: - f = HashablePartial(_prim_applier, prim, tuple(params.items()), self.mesh) - with (core.eval_context(), jax.disable_jit(False), jax.debug_nans(False), - jax.debug_infs(False), use_abstract_mesh(self.context_mesh)): - out_vals = jax.jit(f)(*in_vals) - _maybe_check_special(out_vals) - rep_rule = _check_rules.get(prim, partial(_rule_missing, prim)) - out_rep = rep_rule(self.mesh, *in_rep, **params) if self.check else set() - if prim.multiple_results: - out_rep = [out_rep] * len(out_vals) if type(out_rep) is set else out_rep - return map(partial(ShardMapTracer, self), out_rep, out_vals) - return ShardMapTracer(self, out_rep, out_vals) - - def process_call(self, call_primitive, fun, tracers, params): - raise NotImplementedError( - f"Eager evaluation of `{call_primitive}` inside a `shard_map` isn't " - "yet supported. Put a `jax.jit` around the `shard_map`-decorated " - "function, and open a feature request at " - "https://github.com/jax-ml/jax/issues !") - - def process_map(self, map_primitive, fun, tracers, params): - raise NotImplementedError( - "Eager evaluation of `pmap` inside a `shard_map` isn't yet supported." - "Put a `jax.jit` around the `shard_map`-decorated function, and open " - "a feature request at https://github.com/jax-ml/jax/issues !") - - def process_custom_jvp_call(self, prim, fun, jvp, tracers, *, symbolic_zeros): - # Since ShardMapTrace is only used as a base main, we can drop the jvp. 
- if symbolic_zeros: - msg = ("custom_jvp symbolic_zeros support with shard_map is not " - "implemented; please open an issue at " - "https://github.com/jax-ml/jax/issues") - raise NotImplementedError(msg) - del prim, jvp, symbolic_zeros - in_vals, in_rep = unzip2(map(self.to_val_rep_pair, tracers)) - out_vals, out_rep = _run_shmap(fun, self.mesh, self.auto, in_vals, in_rep, self.check, - self.context_mesh) - return map(partial(ShardMapTracer, self), out_rep, out_vals) - - def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees, - symbolic_zeros): - if symbolic_zeros: - msg = ("custom_vjp symbolic_zeros support with shard_map is not " - "implemented; please open an issue at " - "https://github.com/jax-ml/jax/issues") - raise NotImplementedError(msg) - del prim, fwd, bwd, out_trees, symbolic_zeros - in_vals, in_rep = unzip2(map(self.to_val_rep_pair, tracers)) - out_vals, out_rep = _run_shmap(fun, self.mesh, self.auto, in_vals, in_rep, self.check, - self.context_mesh) - return map(partial(ShardMapTracer, self), out_rep, out_vals) - - -class ShardMapTracer(core.Tracer): - rep: RepType - val: JaxType - - def __init__(self, trace, rep, val): - self._trace = trace - self.rep = rep - self.val = val - - @property - def aval(self): - aval = core.get_aval(self.val) - out = core.mapped_aval(self._trace.mesh.size, 0, aval) - new_sharding = NamedSharding( - _as_manual_mesh(self._trace.mesh, self._trace.auto), - out.sharding.spec) # pytype: disable=attribute-error - return out.update(sharding=new_sharding) - - def to_concrete_value(self): - if self.rep == set(self._trace.mesh.axis_names): - with core.eval_context(), use_abstract_mesh(self._trace.context_mesh): - return core.to_concrete_value(self.val[0]) - else: - return None - - def __str__(self) -> str: - with core.eval_context(), use_abstract_mesh(self._trace.context_mesh): - blocks = list(self.val) - mesh = self._trace.mesh - axis_names = f"({', '.join(map(str, mesh.axis_names))},)" - return '\n'.join( - f"On {device} at mesh coordinates {axis_names} = {idx}:\n{block}\n" - for (idx, device), block in zip(np.ndenumerate(mesh.devices), blocks)) - __repr__ = __str__ # for debuggers, like `p x` - -def _prim_applier(prim, params_tup, mesh, *args): - def apply(*args): - outs = prim.bind(*map(_rem_singleton, args), **dict(params_tup)) - return tree_map(_add_singleton, outs) - spec = P(mesh.axis_names) - return shard_map(apply, mesh, spec, spec, False)(*args) - -eager_rules: dict[core.Primitive, Callable] = {} - -# TODO(mattjj): working around an apparent XLA or PjRt bug, remove eventually -def _debug_callback_eager_rule(mesh, *args, callback: Callable[..., Any], - effect: debugging.DebugEffect): - del effect - with core.eval_context(): - all_blocks = zip(*map(list, args)) - for (idx, device), blocks in zip(np.ndenumerate(mesh.devices), all_blocks): - callback(*blocks) - return [] -eager_rules[debugging.debug_callback_p] = _debug_callback_eager_rule - -def _device_put_eager_rule(mesh, *xs, srcs, devices, copy_semantics): - del mesh, srcs, copy_semantics - for device in devices: - if device is not None: - raise ValueError("device_put with explicit device not allowed within " - f"shard_map-decorated functions, but got device {device}") - return xs -eager_rules[dispatch.device_put_p] = _device_put_eager_rule - -# New primitives for efficient transposition - -# psum2_p is like psum_p except has a different transpose, so mostly copied: -psum2_p = core.Primitive('psum2') -psum2_p.multiple_results = True 
-psum2_p.def_impl(lax_parallel.psum_p.impl) -psum2_p.def_effectful_abstract_eval(lax_parallel.psum_p.abstract_eval) -mlir.register_lowering(psum2_p, mlir._lowerings[lax_parallel.psum_p]) -batching.fancy_primitive_batchers[psum2_p] = \ - partial(lax_parallel._batched_reduction_collective, psum2_p, - lambda v, axis_size: axis_size * v) -batching.skippable_batchers[psum2_p] = partial(lax_parallel._names_in_param, 'axes') - -def _psum2_transpose_rule(cts, *args, axes, axis_index_groups): - del args - return pbroadcast_p.bind(*cts, axes=axes, axis_index_groups=axis_index_groups) -ad.deflinear2(psum2_p, _psum2_transpose_rule) - -# pbroadcast_p is exactly the transpose of psum2_p -def pbroadcast(x, axis_name): - axes = (axis_name,) if not isinstance(axis_name, tuple) else axis_name - if not axis_name: return x - xs, treedef = tree_flatten(x) - ys = pbroadcast_p.bind(*xs, axes=axes, axis_index_groups=None) - return tree_unflatten(treedef, ys) -pbroadcast_p = core.Primitive('pbroadcast') -pbroadcast_p.multiple_results = True -pbroadcast_p.def_impl(lambda *args, axes, axis_index_groups: args) -pbroadcast_p.def_abstract_eval(lambda *args, axes, axis_index_groups: args) -mlir.register_lowering(pbroadcast_p, lambda ctx, *x, axes, axis_index_groups: x) -def _pbroadcast_batcher(vals_in, dims_in, *, axes, axis_index_groups): - if any(type(axis) is int for axis in axes): raise NotImplementedError - vals_out = pbroadcast_p.bind(*vals_in, axes=axes, - axis_index_groups=axis_index_groups) - return vals_out, dims_in -batching.primitive_batchers[pbroadcast_p] = _pbroadcast_batcher -ad.deflinear2(pbroadcast_p, - lambda cts, *_, axes, axis_index_groups: - psum2_p.bind(*cts, axes=axes, axis_index_groups=axis_index_groups)) - -# Rewrite rules and static replication checking for efficient transposition - -_rewrite_rules: dict[core.Primitive, Callable] = {} -register_rewrite = lambda prim: lambda r: _rewrite_rules.setdefault(prim, r) -register_standard_rewrite = lambda prim: \ - _rewrite_rules.setdefault(prim, partial(_standard_rewrite_rule, prim)) -register_norewrite = lambda p: \ - _rewrite_rules.setdefault(p, partial(_no_rewrite, p, _check_rules[p])) - -_check_rules: dict[core.Primitive, Callable] = {} -register_check = lambda prim: lambda rule: _check_rules.setdefault(prim, rule) -register_standard_check = \ - lambda prim: _check_rules.setdefault(prim, partial(_standard_check, prim)) - -def _no_rewrite(prim, rule, mesh, in_rep, *args, **params): - out_vals = prim.bind(*args,**params) - out_rep = rule(mesh, *in_rep, **params) - if prim.multiple_results: - out_rep_ = out_rep if type(out_rep) is list else [out_rep] * len(out_vals) - else: - out_vals, out_rep_ = [out_vals], [out_rep] - return out_vals, out_rep_ - -def _standard_rewrite_rule(prim, mesh, in_rep, *args, **params): - # The standard rewrite inserts pbroadcasts but doesn't change the primitive. - out_rep_ = set.intersection(*in_rep) if in_rep else set(mesh.axis_names) - args_ = [pbroadcast(x, tuple(n for n in src if n not in out_rep_)) - if src - out_rep_ else x for x, src in zip(args, in_rep)] - out_vals_ = prim.bind(*args_, **params) - out_rep = [out_rep_] * len(out_vals_) if prim.multiple_results else [out_rep_] - out_vals = [out_vals_] if not prim.multiple_results else out_vals_ - return out_vals, out_rep - -def _standard_check(prim, mesh, *in_rep, **__): - # The standard check require args' and outputs' replications to be the same, - # except for Nones which correspond to constants. 
- in_rep_ = [r for r in in_rep if r is not None] - if in_rep_ and not in_rep_[:-1] == in_rep_[1:]: - raise Exception(f"Primitive {prim} requires argument replication types " - f"to match, but got {in_rep}. Please open an issue at " - "https://github.com/jax-ml/jax/issues and as a temporary " - "workaround pass the check_rep=False argument to shard_map") - return in_rep_[0] if in_rep_ else None - -def register_standard_collective(prim): - register_check(prim)(partial(_standard_collective_check, prim)) - register_rewrite(prim)(partial(_standard_collective_rewrite, prim)) - -def register_reduction_collective(prim): - register_check(prim)(partial(_reduction_collective_check, prim)) - register_rewrite(prim)(partial(_reduction_collective_rewrite, prim)) - -def _standard_collective_check(prim, mesh, x_rep, *, axis_name, **params): - # The standard collective check is varying -> varying over axis_name. - del mesh, params - if x_rep is None or axis_name in x_rep: - raise Exception(f"Collective {prim} must be applied to a device-varying " - f"replication type, but got {x_rep} for collective acting " - f"over axis name {axis_name}. Please open an issue at " - "https://github.com/jax-ml/jax/issues and as a temporary " - "workaround pass the check_rep=False argument to shard_map") - return x_rep - -def _standard_collective_rewrite(prim, mesh, in_rep, x, axis_name, **params): - # The standard collective rewrite may insert a pbroadcast on the input. - axis_name = (axis_name,) if not isinstance(axis_name, tuple) else axis_name - x_rep, = in_rep - axis_name_set = set(axis_name) - if pbroadcast_axis_name := axis_name_set & x_rep: - x = pbroadcast(x, tuple(pbroadcast_axis_name)) - out_val = prim.bind(x, axis_name=axis_name, **params) - return [out_val], [x_rep - axis_name_set] - -def _reduction_collective_check(prim, mesh, x_rep, *, axes, **params): - # The reduction collective check is varying -> replicated over axes. - del mesh, params - axes = (axes,) if not isinstance(axes, tuple) else axes - if x_rep is None or any(a in x_rep for a in axes): - raise Exception(f"Collective {prim} must be applied to a device-varying " - f"replication type, but got {x_rep} for collective acting " - f"over axis name {axes}. Please open an issue at " - "https://github.com/jax-ml/jax/issues and as a temporary " - "workaround pass the check_rep=False argument to shard_map") - return x_rep | set(axes) - -def _reduction_collective_rewrite(prim, mesh, in_rep, x, axes, **params): - # The standard collective rewrite may insert a pbroadcast on the input. 
- axes = (axes,) if not isinstance(axes, tuple) else axes - x_rep, = in_rep - axes_set = set(axes) - if pbroadcast_axes := axes_set & x_rep: - x = pbroadcast(x, tuple(pbroadcast_axes)) - out_val, = prim.bind(x, axes=axes, **params) - return [out_val], [x_rep | axes_set] - - -for o in it.chain(lax.__dict__.values(), slicing.__dict__.values(), - windowed_reductions.__dict__.values(), - special.__dict__.values(), convolution.__dict__.values(), - fft.__dict__.values(), linalg.__dict__.values(), - ops.__dict__.values(), ad_util.__dict__.values(), - prng.__dict__.values(), ann.__dict__.values(), - random.__dict__.values()): - if isinstance(o, core.Primitive): - register_standard_check(o) - register_standard_rewrite(o) - -for p in [control_flow.loops.cumsum_p, control_flow.loops.cumlogsumexp_p, - control_flow.loops.cumprod_p, control_flow.loops.cummax_p, - control_flow.loops.cummin_p, pjit.sharding_constraint_p, - pjit.mesh_cast_p]: - register_standard_check(p) - register_standard_rewrite(p) - - -@register_check(lax_parallel.psum_p) -def _psum_check(_, *in_rep, axes, axis_index_groups): - assert False # should be rewritten away - -@register_rewrite(lax_parallel.psum_p) -def _psum_rewrite(mesh, in_rep, *args, axes, axis_index_groups): - # Replace the psum with psum2, insert pbroadcasts on input, replicated output. - if axis_index_groups is not None: raise NotImplementedError - axes = (axes,) if not isinstance(axes, tuple) else axes - axes_ = set(axes) - out_rep = [r | axes_ for r in in_rep] # TODO determinism (and elsewhere) - args_ = [pbroadcast(x, tuple(n for n in mesh.axis_names if n in axes_ & src)) - for x, src in zip(args, in_rep)] - out_val = psum2_p.bind(*args_, axes=axes, axis_index_groups=axis_index_groups) - return out_val, out_rep - - -@register_check(psum2_p) -def _psum2_check(mesh, *in_rep, axes, axis_index_groups): - assert type(axes) is tuple - if any(set(axes) & r for r in in_rep if r is not None): - raise Exception("Collective psum must be applied to a device-varying " - f"replication type, but got {in_rep} for collective acting " - f"over axis name {axes}. Please open an issue at " - "https://github.com/jax-ml/jax/issues, and as a temporary " - "workaround pass the check_rep=False argument to shard_map") - in_rep = tuple(set(mesh.axis_names) if r is None else r for r in in_rep) - return [r | set(axes) for r in in_rep] -register_norewrite(psum2_p) - - -@register_check(pbroadcast_p) -def _pbroadcast_check(mesh, *in_rep, axes, axis_index_groups): - assert type(axes) is tuple - if not all(r is None or set(axes) & r for r in in_rep): - raise Exception("Collective pbroadcast must be applied to a " - "non-device-varying " - f"replication type, but got {in_rep} for collective acting " - f"over axis name {axes}. 
Please open an issue at " - "https://github.com/jax-ml/jax/issues, and as a temporary " - "workaround pass the check_rep=False argument to shard_map") - in_rep = tuple(set(mesh.axis_names) if r is None else r for r in in_rep) - return [r - set(axes) for r in in_rep] -register_norewrite(pbroadcast_p) - - -register_standard_collective(lax_parallel.all_gather_p) -register_standard_collective(lax_parallel.all_to_all_p) -register_standard_collective(lax_parallel.ppermute_p) -register_standard_collective(lax_parallel.reduce_scatter_p) -register_reduction_collective(lax_parallel.pmin_p) -register_reduction_collective(lax_parallel.pmax_p) - - -@register_check(lax_parallel.axis_index_p) -def _axis_index_check(mesh, *, axis_name): - axis_name = (axis_name,) if not type(axis_name) is tuple else axis_name - return set(mesh.shape) - set(axis_name) -register_norewrite(lax_parallel.axis_index_p) - - -@register_rewrite(pjit.pjit_p) -def _pjit_rewrite(mesh, in_rep, *args, jaxpr, **kwargs): - jaxpr_, out_rep = _replication_rewrite_nomatch(mesh, jaxpr, in_rep) - out_vals = pjit.pjit_p.bind(*args, jaxpr=jaxpr_, **kwargs) - return out_vals, out_rep - -@register_check(pjit.pjit_p) -def _pjit_check(mesh, *in_rep, jaxpr, **kwargs): - return _check_rep(mesh, jaxpr.jaxpr, in_rep) - - -@register_rewrite(ad_checkpoint.remat_p) -def _remat_rewrite(mesh, in_rep, *args, jaxpr, **kwargs): - jaxpr_ = pe.close_jaxpr(jaxpr) - jaxpr_, out_rep = _replication_rewrite_nomatch(mesh, jaxpr_, in_rep) - jaxpr, () = jaxpr_.jaxpr, jaxpr_.consts - out_vals = ad_checkpoint.remat_p.bind(*args, jaxpr=jaxpr, **kwargs) - return out_vals, out_rep - -@register_check(ad_checkpoint.remat_p) -def _remat_check(mesh, *in_rep, jaxpr, **kwargs): - return _check_rep(mesh, jaxpr, in_rep) - - -@register_check(core.call_p) -def _core_call_check(mesh, *in_rep, call_jaxpr, **kwargs): - return _check_rep(mesh, call_jaxpr, in_rep) - - -@register_check(debugging.debug_callback_p) -def _debug_callback_rule(mesh, *in_rep, **_): - return [] -register_norewrite(debugging.debug_callback_p) - - -@register_check(callback.pure_callback_p) -def _pure_callback_rule(mesh, *_, result_avals, **__): - return [set()] * len(result_avals) -register_norewrite(callback.pure_callback_p) - - -@register_check(callback.io_callback_p) -def _io_callback_rule(mesh, *_, result_avals, **__): - return [set()] * len(result_avals) -register_norewrite(callback.io_callback_p) - - -@register_check(dispatch.device_put_p) -def _device_put_rule(mesh, *xs, **_): - return list(xs) -register_norewrite(dispatch.device_put_p) - - -@register_check(ad.custom_lin_p) -def _custom_lin_rule(mesh, *_, out_avals, **__): - return [set()] * len(out_avals) -register_norewrite(ad.custom_lin_p) - - -@register_check(control_flow.loops.scan_p) -def _scan_check(mesh, *in_rep, jaxpr, num_consts, num_carry, **_): - _, carry_rep_in, _ = split_list(in_rep, [num_consts, num_carry]) - out_rep = _check_rep(mesh, jaxpr.jaxpr, in_rep) - carry_rep_out, _ = split_list(out_rep, [num_carry]) - if not carry_rep_in == carry_rep_out: - raise Exception("Scan carry input and output got mismatched replication " - f"types {carry_rep_in} and {carry_rep_out}. 
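As a reading aid for the replication rules registered here (presumably relocated to jax._src.shard_map, which the new module imports): a psum over the mesh axis produces a value that is identical on every shard, so an out_specs of P() passes the check_rep verification. A hedged end-to-end sketch under the same four-device assumption:

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

mesh = Mesh(jax.devices()[:4], axis_names=('i',))

def shard_sum(block):
    # block is the local (3,) shard; psum over 'i' makes the scalar result
    # identical on every shard, i.e. replicated over the mesh axis.
    return jax.lax.psum(jnp.sum(block), 'i')

total = shard_map(shard_sum, mesh, in_specs=P('i'), out_specs=P())(
    jnp.arange(12.0))
assert float(total) == 66.0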
Please open an " - "issue at https://github.com/jax-ml/jax/issues, and as a " - "temporary workaround pass the check_rep=False argument to " - "shard_map") - return out_rep - -@register_rewrite(control_flow.loops.scan_p) -def _scan_rewrite(mesh, in_rep, *args, jaxpr, num_consts, num_carry, **params): - const_rep, carry_rep_in, xs_rep = split_list(in_rep, [num_consts, num_carry]) - for _ in range(1 + num_carry): - in_rep_ = [*const_rep, *carry_rep_in, *xs_rep] - _, out_rep = _replication_rewrite_nomatch(mesh, jaxpr, in_rep_) - carry_rep_out, ys_rep = split_list(out_rep, [num_carry]) - carry_rep_out = map(op.and_, carry_rep_in, carry_rep_out) - if carry_rep_in == carry_rep_out: - break - else: - carry_rep_in = carry_rep_out - else: - assert False, 'Fixpoint not reached' - - args = [pbroadcast(x, tuple(n for n in src if n not in dst)) - if src - dst else x for x, src, dst in zip(args, in_rep, in_rep_)] - out_rep = [*carry_rep_out, *ys_rep] - jaxpr_ = _replication_rewrite_match(mesh, jaxpr, in_rep_, out_rep) - - out_vals = control_flow.loops.scan_p.bind( - *args, jaxpr=jaxpr_, num_consts=num_consts, num_carry=num_carry, **params) - return out_vals, out_rep - -@register_check(control_flow.conditionals.cond_p) -def _cond_rule(mesh, *in_rep, branches): - _, *args_rep = in_rep - out_rep = _check_rep(mesh, branches[0].jaxpr, args_rep) - for branch in branches[1:]: - out_rep_ = _check_rep(mesh, branch.jaxpr, args_rep) - if not out_rep_ == out_rep: - raise Exception("The branches of cond produced mismatched replication " - "types. Please open an issue at " - "https://github.com/jax-ml/jax/issues, and as a " - "temporary workaround pass the check_rep=False argument " - "to shard_map") - return out_rep - -@register_rewrite(control_flow.conditionals.cond_p) -def _cond_rewrite(mesh, in_rep, *args, branches): - pred_rep, *args_rep = in_rep - _, out_rep = _replication_rewrite_nomatch(mesh, branches[0], args_rep) - for branch in branches[1:]: - _, out_rep_ = _replication_rewrite_nomatch(mesh, branch, args_rep) - if out_rep: - out_rep = map(op.and_, out_rep, out_rep_) - else: - out_rep = out_rep_ - out_rep = map(partial(op.and_, pred_rep), out_rep) - branches_ = tuple(_replication_rewrite_match(mesh, branch, args_rep, out_rep) - for branch in branches) - out_vals = control_flow.conditionals.cond_p.bind(*args, branches=branches_) - return out_vals, out_rep - -@register_check(control_flow.conditionals.platform_index_p) -def _platform_index_rule(mesh, *_, **__): - return set(mesh.axis_names) -register_norewrite(control_flow.conditionals.platform_index_p) - -@register_rewrite(core.closed_call_p) -def _closed_call_rewrite(mesh, in_rep, *args, call_jaxpr, **kwargs): - new_jaxpr, out_rep = _replication_rewrite_nomatch(mesh, call_jaxpr, in_rep) - out_vals = core.closed_call_p.bind(*args, jaxpr=new_jaxpr, **kwargs) - return out_vals, out_rep - -@register_check(core.closed_call_p) -def _closed_call_check(mesh, *in_rep, call_jaxpr, **kwargs): - return _check_rep(mesh, call_jaxpr.jaxpr, in_rep) - - -@register_check(custom_derivatives.custom_jvp_call_p) -def _custom_jvp_call_check(mesh, *in_rep, call_jaxpr, jvp_jaxpr_fun, - num_consts, symbolic_zeros): - return _check_rep(mesh, call_jaxpr.jaxpr, in_rep) - -@register_rewrite(custom_derivatives.custom_vjp_call_jaxpr_p) -def _custom_vjp_call_jaxpr_rewrite( - mesh, in_rep, *args, fun_jaxpr, fwd_jaxpr_thunk, bwd, num_consts, out_trees, - symbolic_zeros): - if symbolic_zeros: - msg = ("Please open an issue at https://github.com/jax-ml/jax/issues and as" - " a temporary 
workaround pass the check_rep=False argument to " - "shard_map") - raise NotImplementedError(msg) - - fun_jaxpr_, out_rep = _replication_rewrite_nomatch(mesh, fun_jaxpr, in_rep) - _, in_rep_ = split_list(in_rep, [num_consts]) - out_rep2 = [] - - @pe._memoize - def fwd_jaxpr_thunk_(*zeros): - fwd_jaxpr = core.ClosedJaxpr(*fwd_jaxpr_thunk(*zeros)) - fwd_jaxpr_, out_rep = _replication_rewrite_nomatch(mesh, fwd_jaxpr, in_rep_) - out_rep2.append(out_rep) - return fwd_jaxpr_.jaxpr, fwd_jaxpr_.consts - - bwd_ = _rewrite_bwd(bwd, mesh, lambda: out_rep2[0], in_rep_) - - outs = custom_derivatives.custom_vjp_call_jaxpr_p.bind( - *args, fun_jaxpr=fun_jaxpr_, fwd_jaxpr_thunk=fwd_jaxpr_thunk_, bwd=bwd_, - num_consts=num_consts, out_trees=out_trees, symbolic_zeros=symbolic_zeros) - out_rep = out_rep2[0] if out_rep2 else out_rep - return outs, out_rep - -@register_check(custom_derivatives.custom_vjp_call_jaxpr_p) -def _custom_vjp_call_jaxpr_check(mesh, *in_rep, fun_jaxpr, **_): - return _check_rep(mesh, fun_jaxpr.jaxpr, in_rep) - -@register_check(control_flow.solves.linear_solve_p) -def _linear_solve_check(mesh, *in_rep, jaxprs, **_): - out_rep = _standard_check(control_flow.solves.linear_solve_p, mesh, *in_rep) - return [out_rep] * len(jaxprs.solve.out_avals) -register_standard_rewrite(control_flow.solves.linear_solve_p) - -@register_check(ffi.ffi_call_p) -def _ffi_call_check(mesh, *in_rep, result_avals, **_): - out_rep = _standard_check(ffi.ffi_call_p, mesh, *in_rep) - return [out_rep] * len(result_avals) -register_standard_rewrite(ffi.ffi_call_p) - -del _check_rules[lax.tie_p] - -@register_check(lax.tie_p) -def _tie_check(mesh, x_rep, y_rep): - return x_rep -register_norewrite(lax.tie_p) - - -# Batching - -def _shard_map_batch( - trace: batching.BatchTrace, prim: core.Primitive, fun: lu.WrappedFun, - in_tracers: Sequence[batching.BatchTracer], mesh: Mesh, - in_names: tuple[AxisNames, ...], - out_names_thunk: Callable[[], tuple[AxisNames, ...]], - check_rep: bool, - rewrite: bool, - auto: frozenset) -> Sequence[batching.BatchTracer]: - in_vals, in_dims = unzip2(map(trace.to_batch_info, in_tracers)) - if any(isinstance(d, batching.RaggedAxis) for d in in_dims): - raise NotImplementedError - new_in_names = [{ax + (d is not batching.not_mapped and d <= ax): names[ax] - for ax in names} for names, d in zip(in_names, in_dims)] - spmd_axis_name = trace.axis_data.spmd_name - if spmd_axis_name is not None: - used = {n for names in in_names for ns in names.values() for n in ns} - if not config.disable_vmap_shmap_error.value and set(spmd_axis_name) & used: - raise ValueError("vmap spmd_axis_name cannot appear in shard_map in_specs") - new_in_names = [{**ns, d:spmd_axis_name} if d is not batching.not_mapped - else ns for ns, d in zip(new_in_names, in_dims)] - new_size = trace.axis_data.size // prod(mesh.shape[n] for n in spmd_axis_name) - new_axis_data = batching.AxisData(trace.axis_data.name, new_size, - trace.axis_data.spmd_name, None) - else: - new_axis_data = trace.axis_data - fun, out_dims = batching.batch_subtrace(fun, trace.tag, new_axis_data, tuple(in_dims)) - @as_hashable_function(closure=out_names_thunk) - def new_out_names_thunk(): - return _batch_out_names(spmd_axis_name, out_dims(), out_names_thunk()) - - new_params = dict(mesh=mesh, in_names=new_in_names, - out_names_thunk=new_out_names_thunk, check_rep=check_rep, - rewrite=rewrite, auto=auto) - with core.set_current_trace(trace.parent_trace): - out_vals = prim.bind(fun, *in_vals, **new_params) - make_tracer = partial(batching.BatchTracer, trace, - 
source_info=source_info_util.current()) - return map(make_tracer, out_vals, out_dims()) -batching.BatchTrace.process_shard_map = _shard_map_batch - -def _batch_out_names(spmd_axis_name, dims, out_names): - out_names_ = [{ax + (d is not batching.not_mapped and d <= ax): names[ax] - for ax in names} for names, d in zip(out_names, dims)] - if spmd_axis_name is not None: - used = {n for names in out_names for ns in names.values() for n in ns} - if not config.disable_vmap_shmap_error.value and set(spmd_axis_name) & used: - raise ValueError("vmap spmd_axis_name cannot appear in shard_map out_specs") - out_names_ = [{**ns, d:spmd_axis_name} if d is not batching.not_mapped - else ns for ns, d in zip(out_names_, dims)] - return out_names_ - - -# Autodiff - -def _shard_map_jvp(trace, shard_map_p, f, tracers, mesh, in_names, - out_names_thunk, check_rep, rewrite, auto): - primals, tangents = unzip2(map(trace.to_primal_tangent_pair, tracers)) - which_nz = [ type(t) is not ad.Zero for t in tangents] - tangents = [t if type(t) is not ad.Zero else None for t in tangents] - args, in_tree = tree_flatten((primals, tangents)) - f_jvp = ad.jvp_subtrace(f, trace.tag) - f_jvp, which_nz_out = ad.nonzero_tangent_outputs(f_jvp) - tangent_in_names = [ax for ax, nz in zip(in_names, which_nz) if nz] - - @as_hashable_function(closure=out_names_thunk) - def new_out_names_thunk(): - out_ax = out_names_thunk() - return (*out_ax, *(ax for ax, nz in zip(out_ax, which_nz_out()) if nz)) - params = dict(mesh=mesh, in_names=(*in_names, *tangent_in_names), - out_names_thunk=new_out_names_thunk, check_rep=check_rep, - rewrite=rewrite, auto=auto) - f_jvp, out_tree = ad.traceable(f_jvp, in_tree) - result = shard_map_p.bind_with_trace(trace.parent_trace, (f_jvp,) + tuple(args), params) - primal_out, tangent_out = tree_unflatten(out_tree(), result) - tangent_out = [ad.Zero(core.get_aval(p).to_tangent_aval()) if t is None else t - for p, t in zip(primal_out, tangent_out)] - return [ad.JVPTracer(trace, p, t) for p, t in zip(primal_out, tangent_out)] -ad.JVPTrace.process_shard_map = _shard_map_jvp - -def _shard_map_partial_eval(trace: pe.JaxprTrace, shard_map_p, - f: lu.WrappedFun, tracers, mesh, in_names, - out_names_thunk, check_rep, rewrite, auto): - tracers = map(trace.to_jaxpr_tracer, tracers) - in_pvals = [t.pval for t in tracers] - in_knowns, in_avals, in_consts = pe.partition_pvals(in_pvals) - unk_in_names, known_in_names = pe.partition_list(in_knowns, in_names) - all_names = _all_newly_manual_mesh_names(mesh, auto, trace) - in_avals_sharded = map(partial(_shard_aval, mesh, auto), unk_in_names, in_avals) - f = pe.trace_to_subjaxpr_nounits_fwd2(f, trace.tag, f.debug_info, False) - f = _promote_scalar_residuals(f) - f_known, aux = pe.partial_eval_wrapper_nounits( - f, (*in_knowns,), (*in_avals_sharded,)) - - @as_hashable_function(closure=out_names_thunk) - def known_out_names(): - in_fwd, out_fwd, out_knowns, _, jaxpr, _ = aux() - _, out_known_names = pe.partition_list(out_knowns, out_names_thunk()) - num_res = sum(f1 is None and f2 is None for f1, f2 in zip(in_fwd, out_fwd)) - return (*out_known_names, *({0: all_names},) * num_res) - - known_params = dict(mesh=mesh, in_names=(*known_in_names,), - out_names_thunk=known_out_names, check_rep=check_rep, - rewrite=rewrite, auto=auto) - out = shard_map_p.bind_with_trace(trace.parent_trace, (f_known, *in_consts), known_params) - in_fwd, out_fwd, out_knowns, out_avals_sharded, jaxpr, env = aux() - num_res = sum(f1 is None and f2 is None for f1, f2 in zip(in_fwd, out_fwd)) - out_consts, 
non_fwd_res = split_list(out, [len(out) - num_res]) - assert not jaxpr.constvars - unk_out_names, _ = pe.partition_list(out_knowns, out_names_thunk()) - known_out_names_ = known_out_names() - res = subs_list2(in_fwd, out_fwd, in_consts, out_consts, non_fwd_res) - res_names = [known_in_names[f1] if f1 is not None else - known_out_names_[f2] if f2 is not None else - {0: all_names} for f1, f2 in zip(in_fwd, out_fwd)] - unk_in_names = (*res_names,) + ({},) * len(env) + (*unk_in_names,) # type: ignore[assignment] - const_tracers = map(trace.new_instantiated_const, res) - env_tracers = map(trace.to_jaxpr_tracer, env) - unk_arg_tracers = [t for t in tracers if not t.is_known()] - unk_params = dict(mesh=mesh, in_names=unk_in_names, - out_names=unk_out_names, jaxpr=jaxpr, check_rep=False, - rewrite=rewrite, auto=auto) - out_avals = map(partial(_unshard_aval, mesh), unk_out_names, out_avals_sharded) - out_tracers = [pe.JaxprTracer(trace, pe.PartialVal.unknown(a), None) - for a in out_avals] - effs = core.filter_named_axis_effects(jaxpr.effects, mesh.axis_names) - eqn = pe.new_eqn_recipe((*const_tracers, *env_tracers, *unk_arg_tracers), - out_tracers, shard_map_p, unk_params, - effs, source_info_util.current()) - for t in out_tracers: t.recipe = eqn - return merge_lists(out_knowns, out_tracers, out_consts) -pe.JaxprTrace.process_shard_map = _shard_map_partial_eval - -def _shard_map_linearize(trace, shard_map_p, f: lu.WrappedFun, - tracers, mesh, in_names, - out_names_thunk, check_rep, rewrite, auto): - primals, tangents = unzip2(map(trace.to_primal_tangent_pair, tracers)) - nzs_in = tuple(type(t) is not ad.Zero for t in tangents) - f_primal, linearize_outs_thunk = ad.linearize_subtrace(f, trace.tag, nzs_in, f.debug_info) - f_primal = _promote_scalar_residuals_lin(f_primal, linearize_outs_thunk) - tangent_in_names = [ax for ax, nz in zip(in_names, nzs_in) if nz] - res_names = _all_newly_manual_mesh_names(mesh, auto, trace) - - @as_hashable_function(closure=linearize_outs_thunk) - def fwd_out_names_thunk(): - _, _, _, _, in_fwd, out_fwd = linearize_outs_thunk() - out_names = out_names_thunk() - num_res_out = sum(f1 is None and f2 is None for f1, f2 in zip(in_fwd, out_fwd)) - # This is incorrect so we set `check_rep=False` in the tangent (as in JVP). 
- return (*({0: res_names} for _ in range(num_res_out)), *out_names) - fwd_params = dict( - mesh=mesh, in_names=in_names, - out_names_thunk=fwd_out_names_thunk, check_rep=check_rep, - rewrite=rewrite, auto=auto) - all_fwd_results = shard_map_p.bind_with_trace( - trace.parent_trace, (f_primal, *primals), fwd_params) - residual_avals, nzs_out, lin_jaxpr, env, in_fwd, out_fwd = linearize_outs_thunk() - num_res_out = sum(f1 is None and f2 is None for f1, f2 in zip(in_fwd, out_fwd)) - non_fwd_res = all_fwd_results[:num_res_out] - primals_out = all_fwd_results[num_res_out:] - residuals = subs_list2(in_fwd, out_fwd, primals, primals_out, non_fwd_res) - args_to_promote = [getattr(aval, 'shape', ()) == () and f1 is None and f2 is None - for aval, f1, f2 in zip(residual_avals, in_fwd, out_fwd)] - with _extend_axis_env(mesh, auto), use_abstract_mesh(_as_manual_mesh(mesh, auto)): - lin_jaxpr = _promote_scalar_residuals_jaxpr(lin_jaxpr, args_to_promote) - out_names = out_names_thunk() - residual_names = [in_names[f1] if f1 is not None else - out_names[f2] if f2 is not None else - {0: res_names} for f1, f2 in zip(in_fwd, out_fwd)] - new_in_names = (*residual_names, *({} for _ in range(len(env))), - *(ax for ax, nz in zip(in_names, nzs_in) if nz)) - tangent_out_names = tuple(ax for ax, nz in zip(out_names_thunk(), nzs_out) if nz) - @as_hashable_function(closure=tangent_out_names) - def tangent_out_names_thunk(): - return tangent_out_names - tangent_params = dict( - mesh=mesh, in_names=new_in_names, out_names_thunk=tangent_out_names_thunk, - check_rep=False, rewrite=rewrite, auto=auto) - - # TODO(mattjj): avoid round-tripping the jaxpr through eval_jaxpr here - def f_tangent(*args): - return core.eval_jaxpr(lin_jaxpr, (), *args) - - nz_tangents_in = [t for (t, nz) in zip(tangents, nzs_in) if nz] - nz_tangents_out = shard_map_p.bind_with_trace( - trace.tangent_trace, - (lu.wrap_init(f_tangent, debug_info=lin_jaxpr.debug_info), - *residuals, *env, *nz_tangents_in), tangent_params) - nz_tangents_out_iter = iter(nz_tangents_out) - tangents_out = [next(nz_tangents_out_iter) if nz else ad.Zero.from_primal_value(primal) - for nz, primal in zip(nzs_out, primals_out)] - return map(partial(ad.maybe_linearize_tracer, trace), primals_out, nzs_out, tangents_out) -ad.LinearizeTrace.process_shard_map = _shard_map_linearize - -@lu.transformation2 -def _promote_scalar_residuals_lin(f, linearize_outs_thunk, *args, **kwargs): - ans = f(*args, **kwargs) - _, _, _, _, in_fwd, out_fwd = linearize_outs_thunk() - num_res_out = sum(f1 is None and f2 is None for f1, f2 in zip(in_fwd, out_fwd)) - residuals = ans[:num_res_out] - primals = ans[num_res_out:] - residuals = [jax.lax.broadcast(x, (1,)) if not getattr(x, 'shape', ()) else x - for x in residuals] - return *residuals, *primals - -@lu.transformation2 -def _promote_scalar_residuals(f: Callable, *args, **kwargs): - jaxpr, (in_fwds, out_fwds, out_pvals, out_consts, env) = f(*args, **kwargs) - which = [f1 is None and f2 is None and not v.aval.shape - for f1, f2, v in zip(in_fwds, out_fwds, jaxpr.constvars)] - jaxpr = _promote_scalar_residuals_jaxpr(jaxpr, which) - out_consts = [jax.lax.broadcast(x, (1,)) if not getattr(x, 'shape', ()) else x - for x in out_consts] - return jaxpr, (in_fwds, out_fwds, out_pvals, out_consts, env) - -def _promote_scalar_residuals_jaxpr(jaxpr: core.Jaxpr, - which: Sequence[bool]): - def fun(*res_and_args): - res, args = split_list(res_and_args, [len(jaxpr.constvars)]) - res = [_rem_singleton(x) if w else x for x, w in zip(res, which)] - return 
core.eval_jaxpr(jaxpr, res, *args) - res_avals = [core.unmapped_aval(1, 0, v.aval) if w else v.aval - for v, w in zip(jaxpr.constvars, which)] - in_avals = [*res_avals, *[v.aval for v in jaxpr.invars]] - jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic( - lu.wrap_init(fun, debug_info=jaxpr.debug_info), in_avals) - return jaxpr - - -def _unmentioned2(mesh: Mesh, names: AxisNames, - auto: frozenset[AxisName]) -> list[AxisName]: - # We use a filtered-down version of unmentioned to avoid defensive-psum over - # more chips than required in the transpose-no-check-rep case. - name_set = {n for ns in names.values() for n in ns} | auto - return [n for n in _all_mesh_names_except_spmd(mesh, auto) - if n not in name_set] - - -def _shard_map_transpose(out_cts, *args, - jaxpr: core.Jaxpr, mesh, in_names, out_names, - check_rep, rewrite, auto): - mb_div = lambda x, y: x / y if y != 1 else x - out_cts = [ - ad.Zero(_shard_aval(mesh, auto, ns, x.aval)) if type(x) is ad.Zero - else x if rewrite or dtypes.dtype(x) == dtypes.float0 - else mb_div(x, prod(map(mesh.shape.get, _unmentioned2(mesh, ns, auto)))) - for ns, x in zip(out_names, out_cts) - ] - args = tuple(x if type(x) is not ad.UndefinedPrimal else - ad.UndefinedPrimal(_shard_aval(mesh, auto, ns, x.aval)) - for ns, x in zip(in_names, args)) - all_args, in_tree = tree_flatten((out_cts, args)) - - def fun_trans_callable(out_cts, args): - # TODO(mattjj): when #26811 lands, delete this and just run backward_pass - in_undef = map(ad.is_undefined_primal, args) - res, undefs = partition_list(in_undef, args) - jaxpr_known, jaxpr_unknown, _, _ = pe.partial_eval_jaxpr_nounits( - pe.close_jaxpr(jaxpr), in_undef, False) - res_reshaped = core.jaxpr_as_fun(jaxpr_known)(*res) - in_cts = ad.backward_pass( - jaxpr_unknown.jaxpr, False, (), (*res_reshaped, *undefs), out_cts - )[len(res_reshaped):] - _, in_ct_names = partition_list(in_undef, in_names) - in_cts = [ad.Zero(_unshard_aval(mesh, ns, x.aval)) if type(x) is ad.Zero - else x if rewrite - else jax.lax.psum(x, tuple(_unmentioned2(mesh, ns, auto))) - for ns, x in zip(in_ct_names, in_cts)] - res_zeros = [ad_util.zero_from_primal(r) for r in res] - return merge_lists(in_undef, res_zeros, in_cts) - - fun_trans = lu.wrap_init(fun_trans_callable, debug_info=jaxpr.debug_info) - fun_trans, nz_arg_cts = ad.nonzero_outputs(fun_trans) - fun_trans_flat, out_tree = api_util.flatten_fun_nokwargs(fun_trans, in_tree) - - new_in_names = \ - [n for n, x in zip(out_names, out_cts) if type(x) is not ad.Zero] + \ - [n for n, x in zip(in_names, args) if type(x) is not ad.UndefinedPrimal] - - def new_out_names_thunk(): - return tuple(names for names, nz in zip(in_names, nz_arg_cts()) if nz) - - try: - out_flat = shard_map_p.bind( - fun_trans_flat, *all_args, mesh=mesh, in_names=tuple(new_in_names), - out_names_thunk=new_out_names_thunk, check_rep=check_rep, rewrite=rewrite, - auto=auto) - except (FloatingPointError, ZeroDivisionError) as e: - print("Invalid nan value encountered in the backward pass of a shard_map " - "function. Calling the de-optimized backward pass.") - try: - # TODO(mattjj): Remove this and do `fun_trans.call_wrapped(out_cts, args)` - # in eager mode so that output of shmap are not manual. 
- with jax.disable_jit(True): - _ = shard_map_p.bind( - fun_trans_flat, *all_args, mesh=mesh, in_names=tuple(new_in_names), - out_names_thunk=new_out_names_thunk, check_rep=check_rep, - rewrite=rewrite, auto=auto) - except (FloatingPointError, ZeroDivisionError) as e2: - raise e2 from None - else: - dispatch._raise_no_nan_in_deoptimized(e) - return tree_unflatten(out_tree(), out_flat) -ad.primitive_transposes[shard_map_p] = _shard_map_transpose - -# Remat - -def _partial_eval_jaxpr_custom_rule( - saveable: Callable[..., pe.RematCases_], unks_in: Sequence[bool], - inst_in: Sequence[bool], eqn: core.JaxprEqn -) -> tuple[core.JaxprEqn, core.JaxprEqn, Sequence[bool], Sequence[bool], - list[core.Var]]: - jaxpr, mesh = eqn.params['jaxpr'], eqn.params['mesh'] - auto = eqn.params['auto'] - with _extend_axis_env(mesh, auto): - jaxpr_known, jaxpr_staged, unks_out, inst_out, num_res = \ - pe.partial_eval_jaxpr_custom(jaxpr, unks_in, inst_in, False, False, saveable) - num_out_primals = len(jaxpr_known.outvars) - num_res - in_fwd = pe._jaxpr_forwarding(jaxpr_known)[num_out_primals:] - out_vars, res_vars = split_list(jaxpr_known.outvars, [num_out_primals]) - idx_map = {id(v): i for i, v in enumerate(out_vars)} - out_fwd = [idx_map.get(id(v)) for v in res_vars] - which = [f1 is None and f2 is None for f1, f2 in zip(in_fwd, out_fwd)] - mesh = eqn.params['mesh'] - with (_extend_axis_env(mesh, auto), - use_abstract_mesh(_as_manual_mesh(mesh, auto))): - jaxpr_known = pe.prune_jaxpr_outputs(jaxpr_known, [True] * num_out_primals + which) - jaxpr_known, jaxpr_staged = _add_reshapes(which, jaxpr_known, jaxpr_staged) - jaxpr_known = core.remove_named_axis_effects(jaxpr_known, mesh.axis_names) - jaxpr_staged = core.remove_named_axis_effects(jaxpr_staged, mesh.axis_names) - ins_known, _ = partition_list(unks_in, eqn.invars) - out_binders_known, _ = partition_list(unks_out, eqn.outvars) - _, ins_staged = partition_list(inst_in, eqn.invars) - _, out_binders_staged = partition_list(inst_out, eqn.outvars) - newvar = core.gensym() - params_known, params_staged, res_names = _pe_custom_params( - unks_in, inst_in, map(op.not_, unks_out), inst_out, in_fwd, out_fwd, which, - dict(eqn.params, jaxpr=jaxpr_known), dict(eqn.params, jaxpr=jaxpr_staged)) - residuals = [newvar(_unshard_aval(mesh, {0: res_names}, var.aval)) - for var, w in zip(jaxpr_staged.invars[:num_res], which) if w] - eqn_known = pe.new_jaxpr_eqn(ins_known, [*out_binders_known, *residuals], - eqn.primitive, params_known, jaxpr_known.effects, - eqn.source_info, eqn.ctx) - full_res = subs_list2(in_fwd, out_fwd, ins_known, out_binders_known, residuals) - eqn_staged = pe.new_jaxpr_eqn([*full_res, *ins_staged], out_binders_staged, - eqn.primitive, params_staged, - jaxpr_staged.effects, eqn.source_info, eqn.ctx) - assert len(eqn_staged.invars) == len(jaxpr_staged.invars) - new_inst = [x for x, inst in zip(eqn.invars, inst_in) - if type(x) is core.Var and not inst] - new_inst += [out_binders_known[f] for f in {i for i in out_fwd if i is not None}] - return eqn_known, eqn_staged, unks_out, inst_out, new_inst + residuals -pe.partial_eval_jaxpr_custom_rules[shard_map_p] = \ - _partial_eval_jaxpr_custom_rule - -def _add_reshapes(which: Sequence[bool], - jaxpr_known: core.Jaxpr, - jaxpr_staged: core.Jaxpr) -> tuple[core.Jaxpr, core.Jaxpr]: - # add singleton axes to residuals which are from jaxpr_known and are scalars - which_ = [w and not v.aval.shape # pytype: disable=attribute-error - for w, v in zip(which, jaxpr_staged.invars[:len(which)])] - if not any(which_): return 
jaxpr_known, jaxpr_staged - assert not jaxpr_known.constvars and not jaxpr_staged.constvars - - def known(*args): - out = core.eval_jaxpr(jaxpr_known, (), *args) - out_known, res = split_list(out, [len(out) - sum(which)]) - res = [_add_singleton(x) if not x.shape else x for x in res] - return [*out_known, *res] - avals_in = [v.aval for v in jaxpr_known.invars] - jaxpr_known, _, (), () = pe.trace_to_jaxpr_dynamic( - lu.wrap_init(known, debug_info=jaxpr_known.debug_info), avals_in) - - def staged(*args): - res_, ins = split_list(args, [len(which)]) - res = [_rem_singleton(x) if w else x for x, w in zip(res_, which_)] - return core.eval_jaxpr(jaxpr_staged, (), *res, *ins) - res_avals = [core.unmapped_aval(1, 0, v.aval) if w else v.aval - for w, v in zip(which_, jaxpr_staged.invars[:len(which)])] - avals_in = [*res_avals, *[v.aval for v in jaxpr_staged.invars[len(which):]]] - jaxpr_staged, _, (), () = pe.trace_to_jaxpr_dynamic( - lu.wrap_init(staged, debug_info=jaxpr_staged.debug_info), avals_in) - - return jaxpr_known, jaxpr_staged - -def _pe_custom_params(unks_in, inst_in, kept_outs_known, kept_outs_staged, - in_fwd, out_fwd, which, params_known, params_staged): - # prune inputs to jaxpr_known according to unks_in - mesh = params_known['mesh'] - auto = params_known['auto'] - res_names_ = _all_newly_manual_mesh_names(mesh, auto) - in_names_known, _ = partition_list(unks_in, params_known['in_names']) - _, out_names_known = partition_list(kept_outs_known, params_known['out_names']) - out_names_known = out_names_known + [{0: res_names_}] * sum(which) - new_params_known = dict(params_known, in_names=tuple(in_names_known), - out_names=tuple(out_names_known)) - - # added num_res new inputs to jaxpr_staged, pruning according to inst_in - _, in_names_staged = partition_list(inst_in, params_staged['in_names']) - res_names = [in_names_known[f1] if f1 is not None else - out_names_known[f2] if f2 is not None else - {0: res_names_} for f1, f2 in zip(in_fwd, out_fwd)] - in_names_staged = res_names + in_names_staged - _, out_names_staged = partition_list(kept_outs_staged, params_staged['out_names']) - new_params_staged = dict(params_staged, in_names=tuple(in_names_staged), - out_names=tuple(out_names_staged), check_rep=False) - return new_params_known, new_params_staged, res_names_ - -# TODO(mattjj): remove this mechanism when we revise mesh scopes -def _all_mesh_names_except_spmd( - mesh: Mesh, auto: frozenset[AxisName], trace=None -) -> tuple[AxisName, ...]: - axis_env = core.get_axis_env() - spmd_names = axis_env.spmd_axis_names - return tuple(name for name in mesh.axis_names if name not in spmd_names and - name not in auto) - -def _all_newly_manual_mesh_names( - mesh: Mesh, auto: frozenset[AxisName], trace=None -) -> tuple[AxisName, ...]: - axis_env = core.get_axis_env() - vmap_spmd_names = set(axis_env.spmd_axis_names) - if not (ctx_mesh := get_abstract_mesh()).empty: - mesh = ctx_mesh - already_manual_names = set(ctx_mesh._axis_types_dict.get(AxisType.Manual, ())) - else: - # TODO(mattjj): remove this mechanism when we revise mesh scopes - already_manual_names = set(axis_env.axis_sizes) # may include vmap axis_names - return tuple(name for name in mesh.axis_names - if name not in auto | vmap_spmd_names | already_manual_names) - - -# DCE - -# TODO(mattjj): de-duplicate with pe.dce_jaxpr_call_rule, and/or _pmap_dce_rule? 
-def _shard_map_dce(used_outputs: list[bool], eqn: core.JaxprEqn - ) -> tuple[list[bool], core.JaxprEqn | None]: - if not any(used_outputs) and not pe.has_effects(eqn): - return [False] * len(eqn.invars), None - mesh = eqn.params["mesh"] - auto = eqn.params["auto"] - with _extend_axis_env(mesh, auto): - jaxpr, used_inputs = pe.dce_jaxpr(eqn.params['jaxpr'], used_outputs) - if not any(used_inputs) and not any(used_outputs) and not jaxpr.effects: - return used_inputs, None - else: - _, in_names = partition_list(used_inputs, eqn.params['in_names']) - _, out_names = partition_list(used_outputs, eqn.params['out_names']) - new_params = dict(eqn.params, jaxpr=jaxpr, in_names=tuple(in_names), - out_names=tuple(out_names)) - effs = core.filter_named_axis_effects(jaxpr.effects, mesh.axis_names) - new_eqn = pe.new_jaxpr_eqn( - [v for v, used in zip(eqn.invars, used_inputs) if used], - [x for x, used in zip(eqn.outvars, used_outputs) if used], - eqn.primitive, new_params, effs, eqn.source_info, eqn.ctx) - return used_inputs, new_eqn -pe.dce_rules[shard_map_p] = _shard_map_dce - -# Implementing pmap in terms of shard_map - -def pmap(f, axis_name=None, *, in_axes=0, out_axes=0, - static_broadcasted_argnums=(), devices=None, backend=None, - axis_size=None, donate_argnums=(), global_arg_shapes=None): - devices = tuple(devices) if devices is not None else devices - axis_name, static_broadcasted_tuple, donate_tuple = _shared_code_pmap( - f, axis_name, static_broadcasted_argnums, donate_argnums, in_axes, out_axes) - - def infer_params(*args, **kwargs): - p = _prepare_pmap(f, in_axes, out_axes, static_broadcasted_tuple, - donate_tuple, devices, backend, axis_size, args, kwargs) - for arg in p.flat_args: - dispatch.check_arg(arg) - mesh = Mesh(_get_devices(p, backend), (axis_name,)) - _pmapped, in_specs, out_specs = _cached_shard_map( - p.flat_fun, mesh, p.in_axes_flat, p.out_axes_thunk, axis_name) - flat_global_args = host_local_array_to_global_array( - p.flat_args, mesh, list(in_specs)) - jitted_f = jax.jit( - _pmapped, - donate_argnums=(i for i, val in enumerate(p.donated_invars) if val)) - return jitted_f, flat_global_args, p.out_tree, mesh, out_specs - - def wrapped(*args, **kwargs): - (jitted_f, flat_global_args, out_tree, mesh, - out_specs) = infer_params(*args, **kwargs) - outs = jitted_f(*flat_global_args) - outs = global_array_to_host_local_array(outs, mesh, out_specs()) - return tree_unflatten(out_tree(), outs) - - def lower(*args, **kwargs): - jitted_f, _, _, _, _ = infer_params(*args, **kwargs) - return jitted_f.lower(*args, **kwargs) - wrapped.lower = lower - - return wrapped - - -@lu.cache -def _cached_shard_map(flat_fun, mesh, in_axes_flat, out_axes_thunk, axis_name): - in_specs = tuple(map(partial(_axis_to_spec, axis_name), in_axes_flat)) - out_specs = lambda: map(partial(_axis_to_spec, axis_name), out_axes_thunk()) - fun = _handle_reshapes(flat_fun, in_axes_flat, out_axes_thunk) - return (_shard_map(fun.call_wrapped, mesh, in_specs, out_specs, - check_rep=False, auto=frozenset()), - in_specs, out_specs) - -@lu.transformation2 -def _handle_reshapes(f, in_axes, out_axes_thunk, *args, **kwargs): - args = tree_map(lambda x, ax: x if ax is None else jnp.squeeze(x, axis=ax), - list(args), list(in_axes)) - out = f(*args) - return tree_map(lambda x, ax: x if ax is None else jnp.expand_dims(x, axis=ax), - list(out), list(out_axes_thunk())) - -def _axis_to_spec(axis_name, ax): - if isinstance(ax, int): - specs = [None] * ax + [axis_name] - return P(*specs) - elif ax is None: - return P() - else: - 
raise TypeError(ax) - -def _get_devices(p, backend): - if backend is not None and p.devices is None: - devs = jax.devices(backend=backend) - else: - devs = jax.devices() if p.devices is None else p.devices - if jax.process_count() > 1: - return devs[:p.global_axis_size] - return devs[:p.local_axis_size] - - -### Rewrite! - -Val = Any - -class RewriteTracer(core.Tracer): - rep: set[AxisName] - val: Val - - def __init__(self, trace, rep, val): - self._trace = trace - self.rep = rep - self.val = val - - @property - def aval(self) -> core.AbstractValue: - return core.get_aval(self.val) - - def to_concrete_value(self): - return core.to_concrete_value(self.val) - - def __str__(self) -> str: - return str(self.val) # TODO(mattjj): could show replication info here - __repr__ = __str__ # for debuggers, like `p x` - -class RewriteTrace(core.Trace): - __slots__ = ("parent_trace", "tag", "mesh") - - parent_trace : core.Trace - tag : core.TraceTag - mesh: Mesh - - def __init__(self, parent_trace, tag, mesh): - super().__init__() - self.parent_trace = parent_trace - self.tag = tag - self.mesh = mesh - - def to_val_rep_pair(self, val): - # TODO: add a tag to tell if self - if isinstance(val, RewriteTracer) and val._trace.tag is self.tag: - return val.val, val.rep - else: - return val, set(self.mesh.axis_names) - - def process_primitive(self, prim, in_tracers, params): - rule = _rewrite_rules.get(prim, partial(_rule_missing, prim)) - in_vals, in_reps = unzip2(map(self.to_val_rep_pair, in_tracers)) - with core.set_current_trace(self.parent_trace): - out_vals, out_reps = rule(self.mesh, in_reps, *in_vals, **params) - out_tracers = map(partial(RewriteTracer, self), out_reps, out_vals) - return out_tracers if prim.multiple_results else out_tracers[0] - - def process_call(self, call_primitive, f, in_tracers, params): - in_vals, in_reps = unzip2(map(self.to_val_rep_pair, in_tracers)) - f, out_reps = _rewrite_subtrace(f, self.tag, self.mesh, tuple(in_reps)) - with core.set_current_trace(self.parent_trace): - out_vals = call_primitive.bind(f, *in_vals, **params) - return map(partial(RewriteTracer, self), out_reps(), out_vals) - - def process_custom_jvp_call(self, prim, fun, jvp, tracers, *, symbolic_zeros): - if symbolic_zeros: - msg = ("Please open an issue at https://github.com/jax-ml/jax/issues and " - "as a temporary workaround pass the check_rep=False argument to " - "shard_map") - raise NotImplementedError(msg) - in_vals, in_reps = unzip2(map(self.to_val_rep_pair, tracers)) - fun, out_reps1 = _rewrite_subtrace(fun, self.tag, self.mesh, in_reps) - jvp, out_reps2 = _rewrite_subtrace(jvp, self.tag, self.mesh, in_reps * 2) - with core.set_current_trace(self.parent_trace): - out_vals = prim.bind(fun, jvp, *in_vals, symbolic_zeros=symbolic_zeros) - fst, out_reps = lu.merge_linear_aux(out_reps1, out_reps2) - if not fst: - assert out_reps == out_reps[:len(out_reps) // 2] * 2 - out_reps = out_reps[:len(out_reps) // 2] - return map(partial(RewriteTracer, self), out_reps, out_vals) - - def process_custom_vjp_call(self, prim: core.Primitive, fun: lu.WrappedFun, - fwd: lu.WrappedFun, bwd: lu.WrappedFun, - tracers, - out_trees: Callable[[], Sequence[PyTreeDef]], - symbolic_zeros: bool): - if symbolic_zeros: - msg = ("Please open an issue at https://github.com/jax-ml/jax/issues and " - "as a temporary workaround pass the check_rep=False argument to " - "shard_map") - raise NotImplementedError(msg) - in_vals, in_reps = unzip2(map(self.to_val_rep_pair, tracers)) - fun, out_reps1 = _rewrite_subtrace(fun, self.tag, self.mesh, 
in_reps) - fwd_in_reps = [r_ for r in in_reps for r_ in [r, set(self.mesh.axis_names)]] - fwd, out_reps2 = _rewrite_subtrace(fwd, self.tag, self.mesh, fwd_in_reps) - bwd = _rewrite_bwd(bwd, self.mesh, out_reps2, in_reps) - with core.set_current_trace(self.parent_trace): - out_vals = prim.bind(fun, fwd, bwd, *in_vals, out_trees=out_trees, - symbolic_zeros=symbolic_zeros) - fst, out_reps = lu.merge_linear_aux(out_reps1, out_reps2) - if not fst: - _, res_tree = out_trees() - _, out_reps = split_list(out_reps, [res_tree.num_leaves]) - return map(partial(RewriteTracer, self), out_reps, out_vals) - -def _efficient_transpose_rewrite(fun, mesh, in_names, out_names_thunk): - in_reps = map(partial(_in_names_to_rep, mesh), in_names) - out_reps_dst = lambda: [set(_unmentioned(mesh, n)) for n in out_names_thunk()] - fun, out_reps_src = _efficient_transpose_rewrite_nomatch(fun, mesh, in_reps) - return _match_rep(fun, mesh, out_reps_src, out_reps_dst) - -@lu.transformation_with_aux2 -def _efficient_transpose_rewrite_nomatch(f, store, mesh, in_reps, *args): - with core.take_current_trace() as parent: - tag = core.TraceTag() - t = RewriteTrace(parent_trace = parent, tag = tag, mesh=mesh) - in_tracers = map(partial(RewriteTracer, t), in_reps, args) - with core.set_current_trace(t): - ans = f(*in_tracers) - out_vals, out_reps = unzip2(map(t.to_val_rep_pair, ans)) - del t, in_tracers, ans - store.store(out_reps) - return out_vals - -@lu.transformation2 -def _match_rep(f, mesh, out_reps_src_, out_reps_dst_, *args): - outs = f(*args) - out_reps_src = out_reps_src_() if callable(out_reps_src_) else out_reps_src_ - out_reps_dst = out_reps_dst_() if callable(out_reps_dst_) else out_reps_dst_ - _check_reps2(mesh, out_reps_dst, out_reps_src) - outs = [pbroadcast(x, tuple(n for n in src if n not in dst)) if src - dst - else x for x, src, dst in zip(outs, out_reps_src, out_reps_dst)] - return outs - -# TODO(mattjj): caching -def _replication_rewrite_match( - mesh: Mesh, - jaxpr: core.ClosedJaxpr, - in_rep: Sequence[set[AxisName]], - out_rep_dst: Sequence[set[AxisName]], -) -> core.ClosedJaxpr: - f = lu.wrap_init(partial(core.eval_jaxpr, jaxpr.jaxpr, jaxpr.consts), - debug_info=jaxpr.jaxpr.debug_info) - f, out_rep = _efficient_transpose_rewrite_nomatch(f, mesh, in_rep) - f = _match_rep(f, mesh, out_rep, out_rep_dst) - jaxpr_, _, consts, () = pe.trace_to_jaxpr_dynamic(f, jaxpr.in_avals) - return core.ClosedJaxpr(jaxpr_, consts) - -# TODO(mattjj): caching -def _replication_rewrite_nomatch( - mesh: Mesh, - jaxpr: core.ClosedJaxpr, - in_rep: Sequence[set[AxisName]], -) -> tuple[core.ClosedJaxpr, list[set[AxisName]]]: - f = lu.wrap_init(partial(core.eval_jaxpr, jaxpr.jaxpr, jaxpr.consts), - debug_info=jaxpr.jaxpr.debug_info) - f, out_rep = _efficient_transpose_rewrite_nomatch(f, mesh, in_rep) - jaxpr_, _, consts, () = pe.trace_to_jaxpr_dynamic(f, jaxpr.in_avals) - return core.ClosedJaxpr(jaxpr_, consts), out_rep() - -@lu.transformation_with_aux2 -def _rewrite_subtrace(f: Callable, store: lu.Store, - tag: core.TraceTag, mesh: Mesh, in_reps, *in_vals): - with core.take_current_trace() as parent_trace: - assert len(in_reps) == len(in_vals), (len(in_reps), len(in_vals)) - t = RewriteTrace(parent_trace, tag, mesh) - in_tracers = map(partial(RewriteTracer, t), in_reps, in_vals) - with core.set_current_trace(t): - outs = f(*in_tracers) - out_vals, out_reps = unzip2(map(t.to_val_rep_pair, outs)) - store.store(out_reps) - return out_vals - -def _rewrite_bwd(bwd: lu.WrappedFun, - mesh: Mesh, in_reps, reps_dst) -> lu.WrappedFun: - def 
new_bwd(*args): - tag = core.TraceTag() - bwd_, reps_thunk = _rewrite_subtrace(bwd, tag, mesh, in_reps()) - out = bwd_.call_wrapped(*args) - return map(_match_replication, reps_thunk(), reps_dst, out) - return lu.wrap_init(new_bwd, debug_info=bwd.debug_info) - -def _match_replication(src, dst, x): - if dst - src: - x, = psum2_p.bind(x, axes=tuple(n for n in dst if n not in src), - axis_index_groups=None) - if src - dst: - x = pbroadcast(x, tuple(n for n in src if n not in dst)) - return x - -# TODO(parkers,mattjj): change implementation when we have sharding-in-types. -def get_replication(x: jax.Array) -> set[AxisName]: - """For a jax.Array, return what axes it is known to be replicated along.""" - - if isinstance(x, RewriteTracer): - return x.rep - if isinstance(x, batching.BatchTracer): - return get_replication(x.val) - raise ValueError("get_replication not defined on %s" % repr(type(x))) + axis_names = frozenset(mesh.axis_names) - auto + return jshmap._shard_map( + f, mesh=mesh, in_specs=in_specs, out_specs=out_specs, + check_vma=check_rep, axis_names=axis_names, _skip_mesh_check=True) diff --git a/jax/experimental/source_mapper/common.py b/jax/experimental/source_mapper/common.py index f7d10bc88f10..471fc0a7a877 100644 --- a/jax/experimental/source_mapper/common.py +++ b/jax/experimental/source_mapper/common.py @@ -15,7 +15,8 @@ import contextlib import dataclasses import re -from typing import Any, Protocol, Sequence +from typing import Any, Protocol +from collections.abc import Sequence from absl import flags import jax diff --git a/jax/experimental/source_mapper/generate_map.py b/jax/experimental/source_mapper/generate_map.py index 76fd0f744463..0066e35285fb 100644 --- a/jax/experimental/source_mapper/generate_map.py +++ b/jax/experimental/source_mapper/generate_map.py @@ -14,7 +14,8 @@ """Generates source maps for JAX functions.""" import os import tempfile -from typing import Sequence, Protocol +from typing import Protocol +from collections.abc import Sequence from jax.experimental.source_mapper import common diff --git a/jax/experimental/sparse/_base.py b/jax/experimental/sparse/_base.py index 7739af0291f1..36d84cb0db62 100644 --- a/jax/experimental/sparse/_base.py +++ b/jax/experimental/sparse/_base.py @@ -19,18 +19,8 @@ import jax from jax._src import core -from jax._src import ffi from jax._src import util from jax._src.typing import Array -from jax._src.lib import gpu_sparse - - -if hasattr(gpu_sparse, "registrations"): - for platform, targets in gpu_sparse.registrations().items(): - for name, value, api_version in targets: - ffi.register_ffi_target( - name, value, platform=platform, api_version=api_version - ) class JAXSparse(util.StrictABC): diff --git a/jax/experimental/sparse/_lowerings.py b/jax/experimental/sparse/_lowerings.py index 6962ef78bcff..c2c25db2c561 100644 --- a/jax/experimental/sparse/_lowerings.py +++ b/jax/experimental/sparse/_lowerings.py @@ -18,13 +18,40 @@ """ from functools import partial +from typing import Any from jax._src import core from jax._src import dispatch +from jax._src import ffi from jax._src.interpreters import mlir from jax._src.lib import gpu_sparse +from jax._src.lib import has_cpu_sparse import numpy as np +if hasattr(gpu_sparse, "registrations"): + for platform, targets in gpu_sparse.registrations().items(): + for name, value, api_version in targets: + ffi.register_ffi_target( + name, value, platform=platform, api_version=api_version + ) + +if has_cpu_sparse: + from jax._src.lib import cpu_sparse + + if hasattr(cpu_sparse, 
"registrations"): + for platform, targets in cpu_sparse.registrations().items(): + for name, value, api_version in targets: + ffi.register_ffi_target( + name, value, platform=platform, api_version=api_version + ) + +def _get_module(target_name_prefix: str) -> Any: + if target_name_prefix == "cu": + return gpu_sparse._cusparse + elif target_name_prefix == "hip": + return gpu_sparse._hipsparse + else: + raise ValueError(f"Unsupported target_name_prefix: {target_name_prefix}") SUPPORTED_DATA_DTYPES = [np.float32, np.float64, np.complex64, np.complex128] SUPPORTED_INDEX_DTYPES = [np.int32] @@ -54,27 +81,30 @@ def _coo_spmv_abstract_eval(data, row, col, x, *, transpose, shape): shape=shape[1:] if transpose else shape[:1], dtype=x.dtype) -def _coo_spmv_gpu_lowering(coo_spmv_hlo, ctx, data, row, col, x, *, transpose, shape): +def _coo_spmv_gpu_lowering(ctx, data, row, col, x, *, transpose, shape, + target_name_prefix): + rows, cols = shape data_aval, row_aval, _, x_aval = ctx.avals_in - return [coo_spmv_hlo( - data, row, col, x, - shape=shape, - transpose=transpose, - data_dtype=data_aval.dtype, - index_dtype=row_aval.dtype, - x_dtype=x_aval.dtype)] + nnz, = data_aval.shape + buffer_size, opaque = _get_module(target_name_prefix).build_coo_matvec_descriptor( + data_aval.dtype, x_aval.dtype, data_aval.dtype, row_aval.dtype, + rows, cols, nnz, transpose) + buffer_aval = core.ShapedArray(shape=(buffer_size,), dtype=np.int8) + sub_ctx = ctx.replace(avals_out=[ctx.avals_out[0], buffer_aval]) + rule = ffi.ffi_lowering(f"{target_name_prefix}sparse_coo_matvec_ffi") + return rule(sub_ctx, data, row, col, x, opaque=opaque)[:1] coo_spmv_p.def_abstract_eval(_coo_spmv_abstract_eval) dispatch.simple_impl(coo_spmv_p) if gpu_sparse.cuda_is_supported: mlir.register_lowering( coo_spmv_p, - partial(_coo_spmv_gpu_lowering, gpu_sparse.cuda_coo_matvec), + partial(_coo_spmv_gpu_lowering, target_name_prefix='cu'), platform='cuda') if gpu_sparse.rocm_is_supported: mlir.register_lowering( coo_spmv_p, - partial(_coo_spmv_gpu_lowering, gpu_sparse.rocm_coo_matvec), + partial(_coo_spmv_gpu_lowering, target_name_prefix='hip'), platform='rocm') @@ -103,27 +133,51 @@ def _coo_spmm_abstract_eval(data, row, col, x, *, transpose, shape): shape=(shape[1] if transpose else shape[0], x.shape[1]), dtype=x.dtype) -def _coo_spmm_gpu_lowering(coo_spmm_hlo, ctx, data, row, col, x, *, transpose, shape): +def _coo_spmm_gpu_lowering(ctx, data, row, col, x, *, transpose, shape, + target_name_prefix): data_aval, row_aval, _, x_aval = ctx.avals_in - return [coo_spmm_hlo( - data, row, col, x, - shape=shape, - transpose=transpose, - data_dtype=data_aval.dtype, - index_dtype=row_aval.dtype, - x_dtype=x_aval.dtype)] + nnz, = data_aval.shape + _, Ccols = x_aval.shape + + batch_count = 1 + if len(shape) == 2: + rows, cols = shape + elif len(shape) == 3: + batch_count, rows, cols = shape + nnz = nnz // batch_count + else: + raise NotImplementedError(f"Unsupported shape: {shape}") + + # TODO(tianjianlu): use batch stride to trigger different mode of batch + # computation. Currently batch_stride = 0 is not allowed because of the issue + # in cusparse https://github.com/NVIDIA/CUDALibrarySamples/issues/81#issuecomment-1205562643 + # Set batch stride to be the matrix size for now. 
+  lhs_batch_stride = nnz
+  B_rows = rows if transpose else cols
+  rhs_batch_stride = B_rows * Ccols
+
+  buffer_size, opaque = _get_module(target_name_prefix).build_coo_matmat_descriptor(
+      data_aval.dtype, x_aval.dtype, data_aval.dtype, row_aval.dtype,
+      rows, cols, Ccols, nnz, transpose, batch_count, lhs_batch_stride,
+      rhs_batch_stride)
+
+  buffer_aval = core.ShapedArray(shape=(buffer_size,), dtype=np.int8)
+  sub_ctx = ctx.replace(avals_out=[ctx.avals_out[0], buffer_aval])
+  rule = ffi.ffi_lowering(f"{target_name_prefix}sparse_coo_matmat_ffi")
+  return rule(sub_ctx, data, row, col, x, opaque=opaque)[:1]
+
 coo_spmm_p.def_abstract_eval(_coo_spmm_abstract_eval)
 dispatch.simple_impl(coo_spmm_p)
 if gpu_sparse.cuda_is_supported:
   mlir.register_lowering(
       coo_spmm_p,
-      partial(_coo_spmm_gpu_lowering, gpu_sparse.cuda_coo_matmat),
+      partial(_coo_spmm_gpu_lowering, target_name_prefix='cu'),
       platform='cuda')
 if gpu_sparse.rocm_is_supported:
   mlir.register_lowering(
       coo_spmm_p,
-      partial(_coo_spmm_gpu_lowering, gpu_sparse.rocm_coo_matmat),
+      partial(_coo_spmm_gpu_lowering, target_name_prefix='hip'),
       platform='rocm')
 
 # csr_spmv_p
@@ -151,30 +205,33 @@ def _csr_spmv_abstract_eval(data, indices, indptr, x, *, transpose, shape):
                          shape=shape[1:] if transpose else shape[:1],
                          dtype=x.dtype)
 
-def _csr_spmv_gpu_lowering(csr_spmv_hlo, ctx, data, indices, indptr, x, *, transpose, shape):
+def _csr_spmv_gpu_lowering(ctx, data, indices, indptr, x, *, transpose, shape,
+                           target_name_prefix):
+  rows, cols = shape
   data_aval, indices_aval, _, x_aval = ctx.avals_in
-  return [csr_spmv_hlo(
-      data, indices, indptr, x,
-      shape=shape,
-      transpose=transpose,
-      data_dtype=data_aval.dtype,
-      index_dtype=indices_aval.dtype,
-      x_dtype=x_aval.dtype)]
+  nnz, = data_aval.shape
+  buffer_size, opaque = _get_module(target_name_prefix).build_csr_matvec_descriptor(
+      data_aval.dtype, x_aval.dtype, data_aval.dtype, indices_aval.dtype,
+      rows, cols, nnz, transpose)
+  buffer_aval = core.ShapedArray(shape=(buffer_size,), dtype=np.int8)
+  sub_ctx = ctx.replace(avals_out=[ctx.avals_out[0], buffer_aval])
+  rule = ffi.ffi_lowering(f"{target_name_prefix}sparse_csr_matvec_ffi")
+  return rule(sub_ctx, data, indices, indptr, x, opaque=opaque)[:1]
 
 csr_spmv_p.def_abstract_eval(_csr_spmv_abstract_eval)
 dispatch.simple_impl(csr_spmv_p)
 if gpu_sparse.cuda_is_supported:
   mlir.register_lowering(
       csr_spmv_p,
-      partial(_csr_spmv_gpu_lowering, gpu_sparse.cuda_csr_matvec),
+      partial(_csr_spmv_gpu_lowering, target_name_prefix='cu'),
       platform='cuda')
 if gpu_sparse.rocm_is_supported:
   mlir.register_lowering(
       csr_spmv_p,
-      partial(_csr_spmv_gpu_lowering, gpu_sparse.rocm_csr_matvec),
+      partial(_csr_spmv_gpu_lowering, target_name_prefix='hip'),
       platform='rocm')
 
-  # csr_spmm_p
+# csr_spmm_p
 # This is an internal-only primitive that calls into cusparse CSR SpMM.
 # This is a raw lowering that does no validation of inputs; the indices are
 # assumed to be lexicographically sorted, deduplicated, and in-bounds.
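[Editor's note, not part of the patch] Every GPU lowering added in this file, including the csr_spmm one in the next hunk, follows the same recipe: ask the vendor sparse library for a workspace size plus an opaque descriptor, append an int8 scratch buffer to the lowering's expected outputs, lower through the registered FFI target, then drop the scratch buffer from the returned results. The sketch below distills that recipe; `build_example_descriptor` and the `"...sparse_example_ffi"` target name are hypothetical placeholders, while `_get_module`, `ffi.ffi_lowering`, `core.ShapedArray`, and `ctx.replace` are the calls that actually appear in this diff.

# Illustrative sketch only (assumed names marked in comments); mirrors the
# descriptor/scratch-buffer pattern used by the lowerings in this file.
import numpy as np

from jax._src import core
from jax._src import ffi


def _example_sparse_gpu_lowering(ctx, *operands, target_name_prefix):
  # 1) Query cuSPARSE/hipSPARSE (via the per-platform module selected by
  #    _get_module) for a workspace size and an opaque descriptor.
  #    `build_example_descriptor` is a hypothetical stand-in for the
  #    build_*_descriptor helpers used above.
  buffer_size, opaque = _get_module(target_name_prefix).build_example_descriptor(...)

  # 2) The FFI kernel writes into an extra int8 scratch buffer, so append it
  #    to the outputs expected by this lowering only.
  buffer_aval = core.ShapedArray(shape=(buffer_size,), dtype=np.int8)
  sub_ctx = ctx.replace(avals_out=[*ctx.avals_out, buffer_aval])

  # 3) Lower through the registered FFI target ("...sparse_example_ffi" is a
  #    placeholder) and drop the scratch buffer from the primitive's results.
  rule = ffi.ffi_lowering(f"{target_name_prefix}sparse_example_ffi")
  return rule(sub_ctx, *operands, opaque=opaque)[:len(ctx.avals_out)]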
@@ -199,25 +256,91 @@ def _csr_spmm_abstract_eval(data, indices, indptr, x, *, transpose, shape):
                          shape=(shape[1] if transpose else shape[0], x.shape[1]),
                          dtype=x.dtype)
 
-def _csr_spmm_gpu_lowering(csr_spmm_hlo, ctx, data, indices, indptr, x, *, transpose, shape):
+def _csr_spmm_gpu_lowering(ctx, data, indices, indptr, x, *, transpose, shape,
+                           target_name_prefix):
+  rows, cols = shape
   data_aval, indices_aval, _, x_aval = ctx.avals_in
-  return [csr_spmm_hlo(
-      data, indices, indptr, x,
-      shape=shape,
-      transpose=transpose,
-      data_dtype=data_aval.dtype,
-      index_dtype=indices_aval.dtype,
-      B_dtype=x_aval.dtype)]
+  nnz, = data_aval.shape
+  _, Ccols = x_aval.shape
+  buffer_size, opaque = _get_module(target_name_prefix).build_csr_matmat_descriptor(
+      data_aval.dtype, x_aval.dtype, data_aval.dtype, indices_aval.dtype,
+      rows, cols, Ccols, nnz, transpose)
+  buffer_aval = core.ShapedArray(shape=(buffer_size,), dtype=np.int8)
+  sub_ctx = ctx.replace(avals_out=[ctx.avals_out[0], buffer_aval])
+  rule = ffi.ffi_lowering(f"{target_name_prefix}sparse_csr_matmat_ffi")
+  return rule(sub_ctx, data, indices, indptr, x, opaque=opaque)[:1]
 
 csr_spmm_p.def_abstract_eval(_csr_spmm_abstract_eval)
 dispatch.simple_impl(csr_spmm_p)
 if gpu_sparse.cuda_is_supported:
   mlir.register_lowering(
       csr_spmm_p,
-      partial(_csr_spmm_gpu_lowering, gpu_sparse.cuda_csr_matmat),
+      partial(_csr_spmm_gpu_lowering, target_name_prefix='cu'),
       platform='cuda')
 if gpu_sparse.rocm_is_supported:
   mlir.register_lowering(
       csr_spmm_p,
-      partial(_csr_spmm_gpu_lowering, gpu_sparse.rocm_csr_matmat),
+      partial(_csr_spmm_gpu_lowering, target_name_prefix='hip'),
       platform='rocm')
+
+
+if has_cpu_sparse:
+  def _csr_spmm_cpu_lowering(ctx, data, outer_indices, inner_indices, rhs):
+    rule = ffi.ffi_lowering("cpu_csr_sparse_dense_ffi")
+    return rule(ctx, data, outer_indices, inner_indices, rhs)
+
+
+  # _csr_spmm_cpu_lowering can handle both matrix-matrix and matrix-vector
+  # multiplication.
+ mlir.register_lowering( + csr_spmv_p, + _csr_spmm_cpu_lowering, + platform="cpu", + ) + mlir.register_lowering( + csr_spmm_p, + _csr_spmm_cpu_lowering, + platform="cpu", + ) + +def coo_todense_gpu_lowering(ctx, data, row, col, *, shape, target_name_prefix): + data_aval, row_aval, _ = ctx.avals_in + nnz, = data_aval.shape + rows, cols = shape + buffer_size, opaque = _get_module(target_name_prefix).build_coo_todense_descriptor( + data_aval.dtype, row_aval.dtype, rows, cols, nnz) + buffer_aval = core.ShapedArray(shape=(buffer_size,), dtype=np.int8) + sub_ctx = ctx.replace(avals_out=[ctx.avals_out[0], buffer_aval]) + rule = ffi.ffi_lowering(f"{target_name_prefix}sparse_coo_todense_ffi") + return rule(sub_ctx, data, row, col, opaque=opaque)[0] + +def coo_fromdense_gpu_lowering(ctx, mat, *, nnz, index_dtype, target_name_prefix): + mat_aval, = ctx.avals_in + rows, cols = mat_aval.shape + buffer_size, opaque = _get_module(target_name_prefix).build_coo_fromdense_descriptor( + mat_aval.dtype, np.dtype(index_dtype), rows, cols, nnz) + buffer_aval = core.ShapedArray(shape=(buffer_size,), dtype=np.int8) + sub_ctx = ctx.replace(avals_out=[*ctx.avals_out, buffer_aval]) + rule = ffi.ffi_lowering(f"{target_name_prefix}sparse_coo_fromdense_ffi") + return rule(sub_ctx, mat, opaque=opaque)[:3] + +def csr_todense_gpu_lowering(ctx, data, indices, indptr, *, shape, target_name_prefix): + data_aval, indices_aval, _, = ctx.avals_in + nnz, = data_aval.shape + rows, cols = shape + buffer_size, opaque = _get_module(target_name_prefix).build_csr_todense_descriptor( + data_aval.dtype, indices_aval.dtype, rows, cols, nnz) + buffer_aval = core.ShapedArray(shape=(buffer_size,), dtype=np.int8) + sub_ctx = ctx.replace(avals_out=[ctx.avals_out[0], buffer_aval]) + rule = ffi.ffi_lowering(f"{target_name_prefix}sparse_csr_todense_ffi") + return rule(sub_ctx, data, indices, indptr, opaque=opaque)[0] + +def csr_fromdense_gpu_lowering(ctx, mat, *, nnz, index_dtype, target_name_prefix): + mat_aval, = ctx.avals_in + rows, cols = mat_aval.shape + buffer_size, opaque = _get_module(target_name_prefix).build_csr_fromdense_descriptor( + mat_aval.dtype, np.dtype(index_dtype), rows, cols, nnz) + buffer_aval = core.ShapedArray(shape=(buffer_size,), dtype=np.int8) + sub_ctx = ctx.replace(avals_out=[*ctx.avals_out, buffer_aval]) + rule = ffi.ffi_lowering(f"{target_name_prefix}sparse_csr_fromdense_ffi") + return rule(sub_ctx, mat, opaque=opaque)[:3] diff --git a/jax/experimental/sparse/ad.py b/jax/experimental/sparse/ad.py index 018047e3d5e1..861ef5289cdd 100644 --- a/jax/experimental/sparse/ad.py +++ b/jax/experimental/sparse/ad.py @@ -22,7 +22,7 @@ from jax._src import core from jax import tree_util from jax._src.api_util import _ensure_index, _ensure_index_tuple -from jax.util import safe_zip +from jax._src.util import safe_zip from jax._src.util import split_list, wraps from jax._src.traceback_util import api_boundary from jax.experimental.sparse._base import JAXSparse diff --git a/jax/experimental/sparse/bcoo.py b/jax/experimental/sparse/bcoo.py index 42820fe73651..0365f93d551a 100644 --- a/jax/experimental/sparse/bcoo.py +++ b/jax/experimental/sparse/bcoo.py @@ -38,7 +38,7 @@ from jax.experimental.sparse._lowerings import coo_spmv_p, coo_spmm_p from jax._src.interpreters import mlir import jax.numpy as jnp -from jax.util import safe_zip, unzip2, split_list +from jax._src.util import safe_zip, unzip2, split_list from jax._src import api_util from jax._src import config from jax._src import core diff --git 
a/jax/experimental/sparse/bcsr.py b/jax/experimental/sparse/bcsr.py index 7fefd1572f45..c7b056c5adfa 100644 --- a/jax/experimental/sparse/bcsr.py +++ b/jax/experimental/sparse/bcsr.py @@ -27,12 +27,13 @@ import jax.numpy as jnp from jax import lax from jax import tree_util +from jax.experimental.sparse import _lowerings from jax.experimental.sparse._base import JAXSparse from jax.experimental.sparse import bcoo from jax.experimental.sparse.util import ( - nfold_vmap, _count_stored_elements, - _csr_to_coo, CuSparseEfficiencyWarning, SparseInfo, Shape) -from jax.util import split_list, safe_zip + nfold_vmap, _count_stored_elements, _csr_to_coo, + SparseEfficiencyWarning, CuSparseEfficiencyWarning, SparseInfo, Shape) +from jax._src.util import split_list, safe_zip from jax._src import api_util from jax._src import config @@ -144,7 +145,7 @@ def _bcsr_to_bcoo(indices: jax.Array, indptr: jax.Array, *, def _bcoo_to_bcsr(indices: Array, *, shape: Sequence[int], - index_dtype: DTypeLike = jnp.int32) -> tuple[Array, Array]: + index_dtype: DTypeLike) -> tuple[Array, Array]: """Given BCOO (indices), return BCSR (indices, indptr). Note: this assumes that ``indices`` are lexicographically sorted within each batch. @@ -237,7 +238,9 @@ def _bcsr_fromdense_impl(mat, *, nse, n_batch, n_dense, index_dtype): raise ValueError("bcsr_fromdense: must have 2 sparse dimensions.") bcoo_mat = bcoo.bcoo_fromdense(mat, nse=nse, index_dtype=index_dtype, n_dense=n_dense, n_batch=n_batch) - indices, indptr = _bcoo_to_bcsr(bcoo_mat.indices, shape=mat.shape) + indices, indptr = _bcoo_to_bcsr( + bcoo_mat.indices, shape=mat.shape, index_dtype=index_dtype + ) return bcoo_mat.data, indices, indptr @@ -620,9 +623,9 @@ def _bcsr_correct_out_of_bound_indices(data, indices, indptr, rhs, *, shape): _bcsr_correct_out_of_bound_indices, multiple_results=True) def _bcsr_dot_general_gpu_lowering( - csr_matvec_lowering, csr_matmat_lowering, + # csr_matvec_lowering, csr_matmat_lowering, ctx, lhs_data, lhs_indices, lhs_indptr, rhs, *, dimension_numbers, - preferred_element_type, lhs_spinfo: SparseInfo): + preferred_element_type, lhs_spinfo: SparseInfo, target_name_prefix): if not config.bcoo_cusparse_lowering.value: return _bcsr_dot_general_default_lowering( @@ -674,22 +677,112 @@ def _bcsr_dot_general_gpu_lowering( lhs_data, lhs_indices = _bcsr_correct_out_of_bound_indices_lowered( ctx, lhs_data, lhs_indices, lhs_indptr, rhs, shape=lhs_spinfo.shape) + sub_ctx = ctx if rhs_aval.ndim == 1: - dot_general_fn = csr_matvec_lowering - x_dtype = 'x_dtype' + dot_general_fn = _lowerings._csr_spmv_gpu_lowering elif rhs_aval.ndim == 2: - dot_general_fn = csr_matmat_lowering - x_dtype = 'B_dtype' + dot_general_fn = _lowerings._csr_spmm_gpu_lowering if rhs_contract[0] == 1: rhs = hlo.transpose(rhs, permutation=mlir.dense_int_array([1, 0])) + *avals_in, rhs_aval = sub_ctx.avals_in + rhs_aval = core.ShapedArray( + shape=(rhs_aval.shape[1], rhs_aval.shape[0]), dtype=rhs_aval.dtype) + sub_ctx = sub_ctx.replace(avals_in=[*avals_in, rhs_aval]) else: raise ValueError(f"rhs has to be 1d or 2d; get {rhs_aval.ndim}d.") - return [dot_general_fn(lhs_data, lhs_indices, lhs_indptr, rhs, - shape=lhs_spinfo.shape, transpose=False, - data_dtype=lhs_data_aval.dtype, - index_dtype=lhs_indices_aval.dtype, - **{x_dtype: rhs_aval.dtype})] + return dot_general_fn(sub_ctx, lhs_data, lhs_indices, lhs_indptr, rhs, + shape=lhs_spinfo.shape, transpose=False, + target_name_prefix=target_name_prefix) + + +def _bcsr_dot_general_cpu_lowering( + # csr_matvec_lowering, 
csr_matmat_lowering, + ctx, + lhs_data, + lhs_indices, + lhs_indptr, + rhs, + *, + dimension_numbers, + preferred_element_type, + lhs_spinfo: SparseInfo, +): + + (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers + lhs_data_aval, lhs_indices_aval, lhs_indptr_aval, rhs_aval = ctx.avals_in + props = _validate_bcsr( + lhs_data_aval, lhs_indices_aval, lhs_indptr_aval, lhs_spinfo.shape + ) + + use_default_lowering = False + dtype = lhs_data_aval.dtype + if lhs_batch or rhs_batch: + # TODO(willfroom): Add support for batched matrices. + use_default_lowering = True + elif lhs_data_aval.dtype != rhs_aval.dtype: + use_default_lowering = True + elif ( + preferred_element_type is not None + and preferred_element_type != lhs_data_aval.dtype + ): + use_default_lowering = True + elif len(lhs_spinfo.shape) != 2 or rhs_aval.ndim not in [1, 2]: + # only matmat / matvec supported + use_default_lowering = True + elif props.n_batch or props.n_dense: + # batch and dense dimensions in BCSR not supported + use_default_lowering = True + elif list(lhs_contract) != [1] or list(rhs_contract) != [0]: + # TODO(willfroom): Add support for non-canonical dots. + use_default_lowering = True + elif lhs_indices_aval.dtype != lhs_indptr_aval.dtype: + warnings.warn( + "bcsr_dot_general cpu lowering not available, " + f" {lhs_indices_aval.dtype=} and {lhs_indptr_aval.dtype=} do not match." + " Falling back to default implementation.", + SparseEfficiencyWarning, + ) + use_default_lowering = True + elif lhs_indices_aval.dtype not in [np.int32, np.int64]: + use_default_lowering = True + warnings.warn( + "bcsr_dot_general cpu lowering not available for" + f" {lhs_indices_aval.dtype=}. Falling back to default implementation.", + SparseEfficiencyWarning, + ) + elif dtype not in [ + np.int32, + np.int64, + np.float32, + np.float64, + np.complex64, + np.complex128, + ]: + # This would be supported if not for the dtype. + warnings.warn( + "bcsr_dot_general cpu lowering not available " + f"for {dtype=}. Falling back to default implementation.", + SparseEfficiencyWarning, + ) + use_default_lowering = True + + if use_default_lowering: + return _bcsr_dot_general_default_lowering( + ctx, + lhs_data, + lhs_indices, + lhs_indptr, + rhs, + dimension_numbers=dimension_numbers, + preferred_element_type=preferred_element_type, + lhs_spinfo=lhs_spinfo, + ) + + return _lowerings._csr_spmm_cpu_lowering( + ctx, lhs_data, lhs_indptr, lhs_indices, rhs + ) + _bcsr_dot_general_default_lowering = mlir.lower_fun( _bcsr_dot_general_impl, multiple_results=False) @@ -700,17 +793,20 @@ def _bcsr_dot_general_gpu_lowering( if gpu_sparse.cuda_is_supported: mlir.register_lowering(bcsr_dot_general_p, partial(_bcsr_dot_general_gpu_lowering, - gpu_sparse.cuda_csr_matvec, - gpu_sparse.cuda_csr_matmat), + target_name_prefix='cu'), platform='cuda') if gpu_sparse.rocm_is_supported: mlir.register_lowering(bcsr_dot_general_p, partial(_bcsr_dot_general_gpu_lowering, - gpu_sparse.rocm_csr_matvec, - gpu_sparse.rocm_csr_matmat), + target_name_prefix='hip'), platform='rocm') +if _lowerings.has_cpu_sparse: + mlir.register_lowering( + bcsr_dot_general_p, _bcsr_dot_general_cpu_lowering, platform="cpu" + ) + #---------------------------------------------------------------------- # BCOO functions that maybe should be primitives? 
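[Editor's note, not part of the patch] With the CPU registration above in place, a plain 2-D BCSR-times-dense contraction can reach the `cpu_csr_sparse_dense_ffi` kernel, provided the data dtypes match, indices and indptr share an int32/int64 dtype, there are no batch or dense dimensions, and the contraction is the canonical ([1], [0]); anything else makes `_bcsr_dot_general_cpu_lowering` fall back to the default lowering, in some cases with a SparseEfficiencyWarning. A small usage sketch (written for illustration, not taken from the patch):

# Usage sketch: a canonical 2-D BCSR @ dense matmul on CPU that satisfies the
# conditions checked by _bcsr_dot_general_cpu_lowering.
import jax
import jax.numpy as jnp
from jax.experimental import sparse

mat = jnp.array([[1.0, 0.0, 2.0],
                 [0.0, 0.0, 3.0],
                 [4.0, 5.0, 0.0]], dtype=jnp.float32)
rhs = jnp.ones((3, 2), dtype=jnp.float32)   # same dtype as the sparse data
lhs = sparse.BCSR.fromdense(mat)            # n_batch=0, n_dense=0, int32 indices/indptr

@jax.jit
def spmm(lhs, rhs):
  # Canonical contraction ([1], [0]) with no batch dimensions.
  return sparse.bcsr_dot_general(
      lhs, rhs, dimension_numbers=(((1,), (0,)), ((), ())))

print(spmm(lhs, rhs).shape)  # (3, 2)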
@@ -867,7 +963,9 @@ def from_bcoo(cls, arr: bcoo.BCOO) -> BCSR: raise NotImplementedError(f"BSCR.from_bcoo requires n_sparse=2; got {arr.n_sparse=}") if not arr.indices_sorted: arr = arr.sort_indices() - indices, indptr = _bcoo_to_bcsr(arr.indices, shape=arr.shape) + indices, indptr = _bcoo_to_bcsr( + arr.indices, shape=arr.shape, index_dtype=arr.indices.dtype + ) return cls((arr.data, indices, indptr), shape=arr.shape) @classmethod diff --git a/jax/experimental/sparse/coo.py b/jax/experimental/sparse/coo.py index c65bc87235d6..014fe9128c1b 100644 --- a/jax/experimental/sparse/coo.py +++ b/jax/experimental/sparse/coo.py @@ -26,6 +26,7 @@ import jax from jax import lax from jax.interpreters import mlir +from jax.experimental.sparse import _lowerings from jax.experimental.sparse._base import JAXSparse from jax.experimental.sparse.util import _coo_extract, CuSparseEfficiencyWarning from jax import tree_util @@ -205,7 +206,7 @@ def _coo_todense_abstract_eval(data, row, col, *, spinfo): _coo_todense_lowering = mlir.lower_fun( _coo_todense_impl, multiple_results=False) -def _coo_todense_gpu_lowering(coo_todense_hlo, ctx, data, row, col, *, spinfo): +def _coo_todense_gpu_lowering(ctx, data, row, col, *, spinfo, target_name_prefix): data_aval, row_aval, _ = ctx.avals_in dtype = data_aval.dtype if not (np.issubdtype(dtype, np.floating) or np.issubdtype(dtype, np.complexfloating)): @@ -226,8 +227,13 @@ def _coo_todense_gpu_lowering(coo_todense_hlo, ctx, data, row, col, *, spinfo): "back to the default implementation.", CuSparseEfficiencyWarning) return _coo_todense_lowering(ctx, data, row, col, spinfo=spinfo) - result = coo_todense_hlo( - data, row, col, shape=shape, data_dtype=dtype, index_dtype=row_aval.dtype) + sub_ctx = ctx + if transpose: + out_aval, = ctx.avals_out + out_aval = core.ShapedArray(shape=out_aval.shape[::-1], dtype=out_aval.dtype) + sub_ctx = sub_ctx.replace(avals_out=[out_aval]) + result = _lowerings.coo_todense_gpu_lowering( + sub_ctx, data, row, col, shape=shape, target_name_prefix=target_name_prefix) return ( [hlo.transpose(result, mlir.dense_int_array([1, 0]))] if transpose else [result]) @@ -255,12 +261,12 @@ def _coo_todense_transpose(ct, data, row, col, *, spinfo): if gpu_sparse.cuda_is_supported: mlir.register_lowering( coo_todense_p, - partial(_coo_todense_gpu_lowering, gpu_sparse.cuda_coo_todense), + partial(_coo_todense_gpu_lowering, target_name_prefix='cu'), platform='cuda') if gpu_sparse.rocm_is_supported: mlir.register_lowering( coo_todense_p, - partial(_coo_todense_gpu_lowering, gpu_sparse.rocm_coo_todense), + partial(_coo_todense_gpu_lowering, target_name_prefix='hip'), platform='rocm') #-------------------------------------------------------------------- @@ -325,20 +331,15 @@ def _coo_fromdense_abstract_eval(mat, *, nse, index_dtype): _coo_fromdense_lowering = mlir.lower_fun( _coo_fromdense_impl, multiple_results=True) -def _coo_fromdense_gpu_lowering(coo_fromdense_hlo, ctx, mat, *, nse, - index_dtype): +def _coo_fromdense_gpu_lowering(ctx, mat, *, nse, index_dtype, target_name_prefix): dtype = ctx.avals_in[0].dtype if not (np.issubdtype(dtype, np.floating) or np.issubdtype(dtype, np.complexfloating)): warnings.warn(f"coo_fromdense cusparse/hipsparse lowering not available for {dtype=}. 
" "Falling back to default implementation.", CuSparseEfficiencyWarning) return _coo_fromdense_lowering(ctx, mat, nse=nse, index_dtype=index_dtype) - data, row, col = coo_fromdense_hlo( - mat, nnz=nse, - data_dtype=dtype, - index_dtype=np.dtype(index_dtype), - index_type=mlir.dtype_to_ir_type(np.dtype(index_dtype))) - return [data, row, col] - + return _lowerings.coo_fromdense_gpu_lowering( + ctx, mat, nnz=nse, index_dtype=index_dtype, + target_name_prefix=target_name_prefix) def _coo_fromdense_jvp(primals, tangents, *, nse, index_dtype): M, = primals @@ -373,12 +374,12 @@ def _coo_fromdense_transpose(ct, M, *, nse, index_dtype): if gpu_sparse.cuda_is_supported: mlir.register_lowering( coo_fromdense_p, - partial(_coo_fromdense_gpu_lowering, gpu_sparse.cuda_coo_fromdense), + partial(_coo_fromdense_gpu_lowering, target_name_prefix='cu'), platform='cuda') if gpu_sparse.rocm_is_supported: mlir.register_lowering( coo_fromdense_p, - partial(_coo_fromdense_gpu_lowering, gpu_sparse.rocm_coo_fromdense), + partial(_coo_fromdense_gpu_lowering, target_name_prefix='hip'), platform='rocm') #-------------------------------------------------------------------- @@ -444,8 +445,8 @@ def _coo_matvec_abstract_eval(data, row, col, v, *, spinfo, transpose): _coo_matvec_lowering = mlir.lower_fun( _coo_matvec_impl, multiple_results=False) -def _coo_matvec_gpu_lowering(coo_matvec_hlo, ctx, data, row, col, v, *, spinfo, - transpose): +def _coo_matvec_gpu_lowering(ctx, data, row, col, v, *, spinfo, transpose, + target_name_prefix): data_aval, row_aval, _, x_aval = ctx.avals_in dtype = data_aval.dtype if dtype not in [np.float32, np.float64, np.complex64, np.complex128]: @@ -466,9 +467,9 @@ def _coo_matvec_gpu_lowering(coo_matvec_hlo, ctx, data, row, col, v, *, spinfo, return _coo_matvec_lowering(ctx, data, row, col, v, spinfo=spinfo, transpose=transpose) - return [coo_matvec_hlo( - data, row, col, v, shape=shape, transpose=transpose, - index_dtype=row_aval.dtype, data_dtype=dtype, x_dtype=x_aval.dtype)] + return _lowerings._coo_spmv_gpu_lowering( + ctx, data, row, col, v, transpose=transpose, shape=shape, + target_name_prefix=target_name_prefix) def _coo_matvec_jvp_mat(data_dot, data, row, col, v, *, spinfo, transpose): @@ -497,12 +498,12 @@ def _coo_matvec_transpose(ct, data, row, col, v, *, spinfo, transpose): if gpu_sparse.cuda_is_supported: mlir.register_lowering( coo_matvec_p, - partial(_coo_matvec_gpu_lowering, gpu_sparse.cuda_coo_matvec), + partial(_coo_matvec_gpu_lowering, target_name_prefix='cu'), platform='cuda') if gpu_sparse.rocm_is_supported: mlir.register_lowering( coo_matvec_p, - partial(_coo_matvec_gpu_lowering, gpu_sparse.rocm_coo_matvec), + partial(_coo_matvec_gpu_lowering, target_name_prefix='hip'), platform='rocm') @@ -567,8 +568,8 @@ def _coo_matmat_abstract_eval(data, row, col, B, *, spinfo, transpose): _coo_matmat_lowering = mlir.lower_fun(_coo_matmat_impl, multiple_results=False) -def _coo_matmat_gpu_lowering(coo_matmat_hlo, ctx, data, row, col, B, *, spinfo, - transpose): +def _coo_matmat_gpu_lowering(ctx, data, row, col, B, *, spinfo, transpose, + target_name_prefix): data_aval, row_aval, _, B_aval = ctx.avals_in dtype = data_aval.dtype if dtype not in [np.float32, np.float64, np.complex64, np.complex128]: @@ -589,10 +590,9 @@ def _coo_matmat_gpu_lowering(coo_matmat_hlo, ctx, data, row, col, B, *, spinfo, return _coo_matmat_lowering(ctx, data, row, col, B, spinfo=spinfo, transpose=transpose) - return [coo_matmat_hlo(data, row, col, B, shape=shape, - transpose=transpose, x_dtype=B_aval.dtype, 
- data_dtype=data_aval.dtype, - index_dtype=row_aval.dtype)] + return _lowerings._coo_spmm_gpu_lowering( + ctx, data, row, col, B, transpose=transpose, shape=shape, + target_name_prefix=target_name_prefix) def _coo_matmat_jvp_left(data_dot, data, row, col, B, *, spinfo, transpose): @@ -618,10 +618,10 @@ def _coo_matmat_transpose(ct, data, row, col, B, *, spinfo, transpose): if gpu_sparse.cuda_is_supported: mlir.register_lowering( coo_matmat_p, - partial(_coo_matmat_gpu_lowering, gpu_sparse.cuda_coo_matmat), + partial(_coo_matmat_gpu_lowering, target_name_prefix='cu'), platform='cuda') if gpu_sparse.rocm_is_supported: mlir.register_lowering( coo_matmat_p, - partial(_coo_matmat_gpu_lowering, gpu_sparse.rocm_coo_matmat), + partial(_coo_matmat_gpu_lowering, target_name_prefix='hip'), platform='rocm') diff --git a/jax/experimental/sparse/csr.py b/jax/experimental/sparse/csr.py index 84171855b85e..cbc5bad1100b 100644 --- a/jax/experimental/sparse/csr.py +++ b/jax/experimental/sparse/csr.py @@ -23,6 +23,7 @@ import jax from jax.interpreters import mlir +from jax.experimental.sparse import _lowerings from jax.experimental.sparse._base import JAXSparse from jax.experimental.sparse.coo import _coo_matmat, _coo_matvec, _coo_todense, COOInfo from jax.experimental.sparse.util import _csr_to_coo, _csr_extract, CuSparseEfficiencyWarning @@ -249,17 +250,16 @@ def _csr_todense_abstract_eval(data, indices, indptr, *, shape): _csr_todense_lowering = mlir.lower_fun( _csr_todense_impl, multiple_results=False) -def _csr_todense_gpu_lowering(csr_todense_hlo, ctx, data, indices, indptr, *, - shape): +def _csr_todense_gpu_lowering(ctx, data, indices, indptr, *, shape, target_name_prefix): data_aval, indices_aval, _ = ctx.avals_in dtype = data_aval.dtype if not (np.issubdtype(dtype, np.floating) or np.issubdtype(dtype, np.complexfloating)): warnings.warn(f"csr_todense cusparse/hipsparse lowering not available for {dtype=}. " "Falling back to default implementation.", CuSparseEfficiencyWarning) return _csr_todense_lowering(ctx, data, indices, indptr, shape=shape) - return [csr_todense_hlo( - data, indices, indptr, shape=shape, data_dtype=dtype, - index_dtype=indices_aval.dtype)] + return [_lowerings.csr_todense_gpu_lowering( + ctx, data, indices, indptr, shape=shape, + target_name_prefix=target_name_prefix)] def _csr_todense_jvp(data_dot, data, indices, indptr, *, shape): @@ -284,12 +284,12 @@ def _csr_todense_transpose(ct, data, indices, indptr, *, shape): if gpu_sparse.cuda_is_supported: mlir.register_lowering( csr_todense_p, - partial(_csr_todense_gpu_lowering, gpu_sparse.cuda_csr_todense), + partial(_csr_todense_gpu_lowering, target_name_prefix='cu'), platform='cuda') if gpu_sparse.rocm_is_supported: mlir.register_lowering( csr_todense_p, - partial(_csr_todense_gpu_lowering, gpu_sparse.rocm_csr_todense), + partial(_csr_todense_gpu_lowering, target_name_prefix='hip'), platform='rocm') @@ -359,16 +359,16 @@ def _csr_fromdense_abstract_eval(mat, *, nse, index_dtype): _csr_fromdense_lowering = mlir.lower_fun(_csr_fromdense_impl, multiple_results=True) -def _csr_fromdense_gpu_lowering(csr_fromdense_hlo, ctx, mat, *, nse, index_dtype): +def _csr_fromdense_gpu_lowering(ctx, mat, *, nse, index_dtype, + target_name_prefix): dtype = ctx.avals_in[0].dtype if not (np.issubdtype(dtype, np.floating) or np.issubdtype(dtype, np.complexfloating)): warnings.warn(f"csr_fromdense cusparse/hipsparse lowering not available for {dtype=}. 
" "Falling back to default implementation.", CuSparseEfficiencyWarning) return _csr_fromdense_lowering(ctx, mat, nse=nse, index_dtype=index_dtype) - data, indices, indptr = csr_fromdense_hlo( - mat, nnz=nse, index_dtype=np.dtype(index_dtype), - data_dtype=dtype, index_type=mlir.dtype_to_ir_type(np.dtype(index_dtype))) - return [data, indices, indptr] + return _lowerings.csr_fromdense_gpu_lowering( + ctx, mat, nnz=nse, index_dtype=index_dtype, + target_name_prefix=target_name_prefix) def _csr_fromdense_jvp(primals, tangents, *, nse, index_dtype): @@ -404,12 +404,12 @@ def _csr_fromdense_transpose(ct, M, *, nse, index_dtype): if gpu_sparse.cuda_is_supported: mlir.register_lowering( csr_fromdense_p, - partial(_csr_fromdense_gpu_lowering, gpu_sparse.cuda_csr_fromdense), + partial(_csr_fromdense_gpu_lowering, target_name_prefix='cu'), platform='cuda') if gpu_sparse.rocm_is_supported: mlir.register_lowering( csr_fromdense_p, - partial(_csr_fromdense_gpu_lowering, gpu_sparse.rocm_csr_fromdense), + partial(_csr_fromdense_gpu_lowering, target_name_prefix='hip'), platform='rocm') #-------------------------------------------------------------------- @@ -470,8 +470,8 @@ def _csr_matvec_abstract_eval(data, indices, indptr, v, *, shape, transpose): _csr_matvec_lowering = mlir.lower_fun(_csr_matvec_impl, multiple_results=False) -def _csr_matvec_gpu_lowering(csr_matvec_hlo, ctx, data, indices, indptr, v, *, - shape, transpose): +def _csr_matvec_gpu_lowering(ctx, data, indices, indptr, v, *, shape, transpose, + target_name_prefix): data_aval, indices_aval, _, v_aval = ctx.avals_in dtype = data_aval.dtype if dtype not in [np.float32, np.float64, np.complex64, np.complex128]: @@ -479,10 +479,9 @@ def _csr_matvec_gpu_lowering(csr_matvec_hlo, ctx, data, indices, indptr, v, *, "Falling back to default implementation.", CuSparseEfficiencyWarning) return _csr_matvec_lowering(ctx, data, indices, indptr, v, shape=shape, transpose=transpose) - return [csr_matvec_hlo( - data, indices, indptr, v, shape=shape, transpose=transpose, - data_dtype=dtype, index_dtype=indices_aval.dtype, x_dtype=v_aval.dtype)] - + return _lowerings._csr_spmv_gpu_lowering( + ctx, data, indices, indptr, v, shape=shape, transpose=transpose, + target_name_prefix=target_name_prefix) def _csr_matvec_jvp_mat(data_dot, data, indices, indptr, v, *, shape, transpose): return _csr_matvec(data_dot, indices, indptr, v, shape=shape, transpose=transpose) @@ -511,12 +510,12 @@ def _csr_matvec_transpose(ct, data, indices, indptr, v, *, shape, transpose): if gpu_sparse.cuda_is_supported: mlir.register_lowering( csr_matvec_p, - partial(_csr_matvec_gpu_lowering, gpu_sparse.cuda_csr_matvec), + partial(_csr_matvec_gpu_lowering, target_name_prefix='cu'), platform='cuda') if gpu_sparse.rocm_is_supported: mlir.register_lowering( csr_matvec_p, - partial(_csr_matvec_gpu_lowering, gpu_sparse.rocm_csr_matvec), + partial(_csr_matvec_gpu_lowering, target_name_prefix='hip'), platform='rocm') @@ -580,8 +579,8 @@ def _csr_matmat_abstract_eval(data, indices, indptr, B, *, shape, transpose): _csr_matmat_lowering = mlir.lower_fun(_csr_matmat_impl, multiple_results=False) -def _csr_matmat_gpu_lowering(csr_matmat_hlo, ctx, data, indices, indptr, B, *, - shape, transpose): +def _csr_matmat_gpu_lowering(ctx, data, indices, indptr, B, *, shape, transpose, + target_name_prefix): data_aval, indices_aval, _, B_aval = ctx.avals_in dtype = data_aval.dtype if dtype not in [np.float32, np.float64, np.complex64, np.complex128]: @@ -589,11 +588,9 @@ def 
_csr_matmat_gpu_lowering(csr_matmat_hlo, ctx, data, indices, indptr, B, *, "Falling back to default implementation.", CuSparseEfficiencyWarning) return _csr_matmat_lowering(ctx, data, indices, indptr, B, shape=shape, transpose=transpose) - return [csr_matmat_hlo( - data, indices, indptr, B, shape=shape, transpose=transpose, - index_dtype=indices_aval.dtype, data_dtype=data_aval.dtype, - B_dtype=B_aval.dtype)] - + return _lowerings._csr_spmm_gpu_lowering( + ctx, data, indices, indptr, B, shape=shape, transpose=transpose, + target_name_prefix=target_name_prefix) def _csr_matmat_jvp_left(data_dot, data, indices, indptr, B, *, shape, transpose): return _csr_matmat(data_dot, indices, indptr, B, shape=shape, transpose=transpose) @@ -621,10 +618,10 @@ def _csr_matmat_transpose(ct, data, indices, indptr, B, *, shape, transpose): if gpu_sparse.cuda_is_supported: mlir.register_lowering( csr_matmat_p, - partial(_csr_matmat_gpu_lowering, gpu_sparse.cuda_csr_matmat), + partial(_csr_matmat_gpu_lowering, target_name_prefix='cu'), platform='cuda') if gpu_sparse.rocm_is_supported: mlir.register_lowering( csr_matmat_p, - partial(_csr_matmat_gpu_lowering, gpu_sparse.rocm_csr_matmat), + partial(_csr_matmat_gpu_lowering, target_name_prefix='hip'), platform='rocm') diff --git a/jax/experimental/sparse/linalg.py b/jax/experimental/sparse/linalg.py index a931b0a30dcf..b2e57caba9a6 100644 --- a/jax/experimental/sparse/linalg.py +++ b/jax/experimental/sparse/linalg.py @@ -29,7 +29,6 @@ from jax._src import core from jax._src import ffi from jax._src.interpreters import ad -from jax._src.lib import gpu_solver import numpy as np from scipy.sparse import csr_matrix, linalg @@ -534,11 +533,6 @@ def _spsolve_abstract_eval(data, indices, indptr, b, *, tol, reorder): def _spsolve_gpu_lowering(ctx, data, indices, indptr, b, *, tol, reorder): - # TODO(danfm): remove after JAX 0.5.1 release. 
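
For context, a small usage sketch of the sparse primitives whose GPU lowerings are rewired above. This is illustrative only and uses the public jax.experimental.sparse API; whether the cuSPARSE/hipSPARSE custom call or the reference lowering runs depends on platform and dtype, per the fallbacks shown in this diff:

    import jax.numpy as jnp
    from jax.experimental import sparse

    A = jnp.array([[0., 1., 0.],
                   [2., 0., 3.]])
    A_coo = sparse.COO.fromdense(A)   # coo_fromdense primitive
    x = jnp.array([1., 2., 3.])
    y = A_coo @ x                     # coo_matvec primitive
    # sparse.CSR.fromdense / CSR @ x exercise the csr_* primitives analogously.
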
- if hasattr(gpu_solver, "cuda_csrlsvqr"): - data_aval, _, _, _, = ctx.avals_in - return gpu_solver.cuda_csrlsvqr(data_aval.dtype, data, indices, - indptr, b, tol, reorder) return ffi.ffi_lowering("cusolver_csrlsvqr_ffi")( ctx, data, indices, indptr, b, tol=np.float64(tol), reorder=np.int32(reorder)) diff --git a/jax/experimental/sparse/random.py b/jax/experimental/sparse/random.py index f90c2572d282..a9146b7746e0 100644 --- a/jax/experimental/sparse/random.py +++ b/jax/experimental/sparse/random.py @@ -18,7 +18,7 @@ from jax import dtypes from jax import vmap from jax import random -from jax.util import split_list +from jax._src.util import split_list import jax.numpy as jnp from jax.experimental import sparse diff --git a/jax/experimental/sparse/test_util.py b/jax/experimental/sparse/test_util.py index 77c97513041c..63e035d2d1ac 100644 --- a/jax/experimental/sparse/test_util.py +++ b/jax/experimental/sparse/test_util.py @@ -29,7 +29,7 @@ from jax._src.typing import DTypeLike from jax.experimental import sparse import jax.numpy as jnp -from jax.util import safe_zip, split_list +from jax._src.util import safe_zip, split_list import numpy as np MATMUL_TOL = { diff --git a/jax/experimental/sparse/transform.py b/jax/experimental/sparse/transform.py index ce1d3f4af9d0..a16756d42c45 100644 --- a/jax/experimental/sparse/transform.py +++ b/jax/experimental/sparse/transform.py @@ -68,7 +68,7 @@ from jax._src.lib import pytree from jax._src.interpreters import partial_eval as pe from jax.tree_util import tree_flatten, tree_map, tree_unflatten -from jax.util import safe_map, safe_zip, split_list +from jax._src.util import safe_map, safe_zip, split_list from jax._src.lax.control_flow import _check_tree_and_avals from jax._src.numpy import indexing as jnp_indexing from jax.experimental import sparse diff --git a/jax/experimental/sparse/util.py b/jax/experimental/sparse/util.py index 36e9a9c51664..7c6bfb1ec345 100644 --- a/jax/experimental/sparse/util.py +++ b/jax/experimental/sparse/util.py @@ -25,7 +25,7 @@ from jax._src import core from jax._src.api_util import flatten_axes import jax.numpy as jnp -from jax.util import safe_zip +from jax._src.util import safe_zip from jax._src.lax.lax import _dot_general_shape_rule, DotDimensionNumbers from jax._src.typing import Array diff --git a/jax/experimental/topologies.py b/jax/experimental/topologies.py index 06be2b74853f..94b63769f101 100644 --- a/jax/experimental/topologies.py +++ b/jax/experimental/topologies.py @@ -19,7 +19,7 @@ import jax from jax.experimental import mesh_utils from jax._src.lib import xla_client as xc -from jax._src.lib import xla_extension +from jax._src.lib import _jax from jax._src import xla_bridge as xb Device = xc.Device @@ -46,7 +46,7 @@ def get_topology_desc( try: topology = xb.make_pjrt_topology(platform, topology_name, **kwargs) return TopologyDescription(topology._make_compile_only_devices()) - except xla_extension.XlaRuntimeError as e: + except _jax.XlaRuntimeError as e: msg, *_ = e.args if msg.startswith("UNIMPLEMENTED"): raise NotImplementedError(msg) from e diff --git a/jax/experimental/x64_context.py b/jax/experimental/x64_context.py index 1772d466b006..3ef5289df4f1 100644 --- a/jax/experimental/x64_context.py +++ b/jax/experimental/x64_context.py @@ -30,6 +30,13 @@ def enable_x64(new_val: bool = True): """Experimental context manager to temporarily enable X64 mode. + .. 
warning:: + + This context manager remains experimental because it is fundamentally broken + and can result in unexpected behavior, particularly when used in conjunction + with JAX transformations like :func:`jax.jit`, :func:`jax.vmap`, :func:`jax.grad`, + and others. See https://github.com/jax-ml/jax/issues/5982 for details. + Usage:: >>> x = np.arange(5, dtype='float64') @@ -40,7 +47,7 @@ def enable_x64(new_val: bool = True): See Also -------- - jax.experimental.enable_x64 : temporarily enable X64 mode. + jax.experimental.disable_x64 : temporarily disable X64 mode. """ with config.enable_x64(new_val): yield @@ -49,6 +56,13 @@ def enable_x64(new_val: bool = True): def disable_x64(): """Experimental context manager to temporarily disable X64 mode. + .. warning:: + + This context manager remains experimental because it is fundamentally broken + and can result in unexpected behavior, particularly when used in conjunction + with JAX transformations like :func:`jax.jit`, :func:`jax.vmap`, :func:`jax.grad`, + and others. See https://github.com/jax-ml/jax/issues/5982 for details. + Usage:: >>> x = np.arange(5, dtype='float64') diff --git a/jax/extend/BUILD b/jax/extend/BUILD index 59958c1da389..e6414305a51b 100644 --- a/jax/extend/BUILD +++ b/jax/extend/BUILD @@ -43,8 +43,12 @@ py_library_providing_imports_info( deps = [ "//jax", "//jax:abstract_arrays", + "//jax:ad", "//jax:ad_util", + "//jax:api", "//jax:core", + "//jax:custom_derivatives", + "//jax:lax", ], ) @@ -58,7 +62,8 @@ pytype_strict_library( name = "backend", srcs = ["backend.py"], deps = [ - "//jax", + "//jax:api", + "//jax:util", "//jax:xla_bridge", ], ) @@ -66,7 +71,16 @@ pytype_strict_library( pytype_strict_library( name = "random", srcs = ["random.py"], - deps = ["//jax"], + deps = [ + "//jax", + "//jax:extend_src", + ], +) + +pytype_strict_library( + name = "sharding", + srcs = ["sharding.py"], + deps = ["//jax:sharding_impls"], ) pytype_strict_library( @@ -78,7 +92,10 @@ pytype_strict_library( pytype_strict_library( name = "ffi", srcs = ["ffi.py"], - deps = ["//jax"], + deps = [ + "//jax", + "//jax:ffi", + ], ) pytype_strict_library( diff --git a/jax/extend/__init__.py b/jax/extend/__init__.py index bbb5925ab41a..c875abb9c598 100644 --- a/jax/extend/__init__.py +++ b/jax/extend/__init__.py @@ -16,24 +16,24 @@ The :mod:`jax.extend` module provides modules for access to JAX internal machinery. See -`JEP #15856 `_. +`JEP #15856 `_. This module is not the only means by which JAX aims to be extensible. For example, the main JAX API offers mechanisms for `customizing derivatives -`_, +`_, `registering custom pytree definitions -`_, +`_, and more. API policy ---------- Unlike the -`public API `_, +`public API `_, this module offers **no compatibility guarantee** across releases. Breaking changes will be announced via the -`JAX project changelog `_. +`JAX project changelog `_. 
""" from jax.extend import ( diff --git a/jax/extend/backend.py b/jax/extend/backend.py index 8d5488baba16..12c84ecd1f20 100644 --- a/jax/extend/backend.py +++ b/jax/extend/backend.py @@ -27,3 +27,8 @@ from jax._src.interpreters.pxla import ( get_default_device as get_default_device ) +from jax._src import ( + util as _util +) +add_clear_backends_callback = _util.cache_clearing_funs.add # type: ignore +del _util diff --git a/jax/extend/core/primitives.py b/jax/extend/core/primitives.py index d8a10154cf4a..5b790656271c 100644 --- a/jax/extend/core/primitives.py +++ b/jax/extend/core/primitives.py @@ -24,9 +24,7 @@ from jax._src.custom_derivatives import ( custom_jvp_call_p as custom_jvp_call_p, - custom_jvp_call_jaxpr_p as custom_jvp_call_jaxpr_p, custom_vjp_call_p as custom_vjp_call_p, - custom_vjp_call_jaxpr_p as custom_vjp_call_jaxpr_p, ) from jax._src.dispatch import device_put_p as device_put_p @@ -149,7 +147,6 @@ igamma_p as igamma_p, lgamma_p as lgamma_p, polygamma_p as polygamma_p, - random_gamma_grad_p as random_gamma_grad_p, regularized_incomplete_beta_p as regularized_incomplete_beta_p, zeta_p as zeta_p, ) @@ -226,7 +223,10 @@ schur_p as schur_p, ) -from jax._src.pjit import sharding_constraint_p as sharding_constraint_p +from jax._src.pjit import ( + pjit_p as pjit_p, + sharding_constraint_p as sharding_constraint_p, +) from jax._src.prng import ( random_bits_p as random_bits_p, diff --git a/jax/extend/ifrt_programs.py b/jax/extend/ifrt_programs.py index 715dfd43592c..13ba9088bc55 100644 --- a/jax/extend/ifrt_programs.py +++ b/jax/extend/ifrt_programs.py @@ -15,8 +15,8 @@ # Note: import as is required for names to be exported. # See PEP 484 & https://github.com/jax-ml/jax/issues/7570 -from jax._src.lib import xla_extension as _xe +from jax._src.lib import _jax -ifrt_programs = _xe.ifrt_programs +ifrt_programs = _jax.ifrt_programs -del _xe +del _jax diff --git a/jax/extend/linear_util.py b/jax/extend/linear_util.py index 0cf9a013a9e4..ad67f6ac8f73 100644 --- a/jax/extend/linear_util.py +++ b/jax/extend/linear_util.py @@ -15,7 +15,7 @@ # Note: import as is required for names to be exported. # See PEP 484 & https://github.com/jax-ml/jax/issues/7570 -from typing import Callable +from collections.abc import Callable from jax._src.linear_util import ( StoreException as StoreException, diff --git a/jax/extend/mlir/dialects/sdy.py b/jax/extend/mlir/dialects/sdy.py index 48586cc26760..d83fd90ecdf4 100644 --- a/jax/extend/mlir/dialects/sdy.py +++ b/jax/extend/mlir/dialects/sdy.py @@ -14,8 +14,4 @@ # ruff: noqa: F403 -# TODO(bartchr): Once JAX is released with SDY, remove the try/except. -try: - from jaxlib.mlir.dialects.sdy import * -except ImportError: - pass +from jaxlib.mlir.dialects.sdy import * diff --git a/jax/extend/sharding.py b/jax/extend/sharding.py new file mode 100644 index 000000000000..8af2bf397249 --- /dev/null +++ b/jax/extend/sharding.py @@ -0,0 +1,17 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# TODO(yashkatariya): Remove this after NamedSharding supports more complicated +# shardings like sub-axes, strided shardings, etc. +from jax._src.sharding_impls import GSPMDSharding as GSPMDSharding diff --git a/jax/interpreters/mlir.py b/jax/interpreters/mlir.py index 0f32799f7ea9..10c8d1e9e671 100644 --- a/jax/interpreters/mlir.py +++ b/jax/interpreters/mlir.py @@ -33,7 +33,7 @@ aval_to_ir_type as aval_to_ir_type, aval_to_ir_types as aval_to_ir_types, core_call_lowering as core_call_lowering, - custom_call as custom_call, + custom_call as _custom_call, dense_bool_elements as dense_bool_elements, dense_bool_array as dense_bool_array, dense_int_array as dense_int_array, @@ -43,8 +43,6 @@ flatten_ir_values as flatten_lowering_ir_args, # TODO(phawkins): remove me # noqa: F401 flatten_ir_values as flatten_ir_values, unflatten_ir_values_like_types as unflatten_ir_values_like_types, - func_dialect as func_dialect, - hlo as hlo, i32_attr as i32_attr, i64_attr as i64_attr, ir as ir, @@ -63,7 +61,6 @@ register_lowering as register_lowering, shape_tensor as shape_tensor, token_type as token_type, - xla_computation_to_mlir_module as xla_computation_to_mlir_module, ) from jax._src.mesh import Mesh as Mesh @@ -80,3 +77,23 @@ from jax._src.callback import ( emit_python_callback as emit_python_callback, ) + +_deprecations = { + # Added Apr 7 2025 + "custom_call": ( + "mlir.custom_call is deprecated; use the APIs provided by jax.ffi instead.", + _custom_call, + ) +} + +import typing as _typing + +if _typing.TYPE_CHECKING: + custom_call = _custom_call +else: + from jax._src.deprecations import deprecation_getattr as _deprecation_getattr + + __getattr__ = _deprecation_getattr(__name__, _deprecations) + del _deprecation_getattr +del _typing +del _custom_call diff --git a/jax/interpreters/partial_eval.py b/jax/interpreters/partial_eval.py index b546d774a2e9..a2d988f6bea3 100644 --- a/jax/interpreters/partial_eval.py +++ b/jax/interpreters/partial_eval.py @@ -81,7 +81,6 @@ trace_to_subjaxpr_nounits as trace_to_subjaxpr_nounits, trace_to_subjaxpr_nounits_fwd as trace_to_subjaxpr_nounits_fwd, tracers_to_jaxpr as tracers_to_jaxpr, - trivial_ctx as trivial_ctx, ) diff --git a/jax/interpreters/xla.py b/jax/interpreters/xla.py index bd3b83e37d24..2f8417ade1f8 100644 --- a/jax/interpreters/xla.py +++ b/jax/interpreters/xla.py @@ -38,19 +38,6 @@ "jax.interpreters.xla.pytype_aval_mappings is deprecated.", _src_core.pytype_aval_mappings ), - # Finalized 2024-10-24; remove after 2025-01-24 - "xb": ( - ("jax.interpreters.xla.xb was removed in JAX v0.4.36. " - "Use jax.lib.xla_bridge instead."), None - ), - "xc": ( - ("jax.interpreters.xla.xc was removed in JAX v0.4.36. " - "Use jax.lib.xla_client instead."), None - ), - "xe": ( - ("jax.interpreters.xla.xe was removed in JAX v0.4.36. 
" - "Use jax.lib.xla_extension instead."), None - ), } import typing as _typing diff --git a/jax/lax/__init__.py b/jax/lax/__init__.py index 4e376fb666d1..953259c6c1f1 100644 --- a/jax/lax/__init__.py +++ b/jax/lax/__init__.py @@ -198,6 +198,7 @@ select as select, select_n as select_n, select_n_p as select_n_p, + shape_as_value as shape_as_value, shift_left as shift_left, shift_left_p as shift_left_p, shift_right_arithmetic as shift_right_arithmetic, @@ -260,7 +261,6 @@ polygamma as polygamma, polygamma_p as polygamma_p, random_gamma_grad as random_gamma_grad, - random_gamma_grad_p as random_gamma_grad_p, regularized_incomplete_beta_p as regularized_incomplete_beta_p, zeta as zeta, zeta_p as zeta_p, @@ -356,11 +356,13 @@ ) from jax._src.lax.parallel import ( all_gather as all_gather, + all_gather_invariant as all_gather_invariant, all_gather_p as all_gather_p, all_to_all as all_to_all, all_to_all_p as all_to_all_p, axis_index as axis_index, axis_index_p as axis_index_p, + axis_size as axis_size, pbroadcast as pbroadcast, pmax as pmax, pmax_p as pmax_p, @@ -369,6 +371,8 @@ pmin_p as pmin_p, ppermute as ppermute, ppermute_p as ppermute_p, + psend as psend, + precv as precv, pshuffle as pshuffle, psum as psum, psum_p as psum_p, @@ -377,6 +381,9 @@ ragged_all_to_all as ragged_all_to_all, ragged_all_to_all_p as ragged_all_to_all_p, ) +from jax._src.core import ( + pvary as pvary, +) from jax._src.lax.other import ( conv_general_dilated_local as conv_general_dilated_local, conv_general_dilated_patches as conv_general_dilated_patches @@ -392,3 +399,50 @@ from jax._src.pjit import with_sharding_constraint as with_sharding_constraint from jax._src.pjit import sharding_constraint_p as sharding_constraint_p from jax._src.dispatch import device_put_p as device_put_p + +import jax._src.lax.lax + +_deprecations = { + "infeed": ( + ( + "jax.lax.infeed was deprecated in JAX v0.6.0 and will be removed in" + " JAX v0.7.0." + ), + jax._src.lax.lax.infeed, + ), + "infeed_p": ( + ( + "jax.lax.infeed_p was deprecated in JAX v0.6.0 and will be removed" + " in JAX v0.7.0." + ), + jax._src.lax.lax.infeed_p, + ), + "outfeed": ( + ( + "jax.lax.outfeed was deprecated in JAX v0.6.0 and will be removed" + " in JAX v0.7.0." + ), + jax._src.lax.lax.outfeed, + ), + "outfeed_p": ( + ( + "jax.lax.outfeed_p was deprecated in JAX v0.6.0 and will be removed" + " in JAX v0.7.0." 
+ ), + jax._src.lax.lax.outfeed_p, + ), +} + +import typing as _typing + +if _typing.TYPE_CHECKING: + infeed = jax._src.lax.lax.infeed + infeed_p = jax._src.lax.lax.infeed_p + outfeed = jax._src.lax.lax.outfeed + outfeed_p = jax._src.lax.lax.outfeed_p +else: + from jax._src.deprecations import deprecation_getattr as _deprecation_getattr + + __getattr__ = _deprecation_getattr(__name__, _deprecations) + del _deprecation_getattr +del _typing diff --git a/jax/lax/linalg.py b/jax/lax/linalg.py index 343073ca56d0..984592534656 100644 --- a/jax/lax/linalg.py +++ b/jax/lax/linalg.py @@ -46,6 +46,6 @@ tridiagonal_solve_p as tridiagonal_solve_p, ) -from jax._src.lax.qdwh import ( +from jax._src.tpu.linalg.qdwh import ( qdwh as qdwh ) diff --git a/jax/lib/xla_bridge.py b/jax/lib/xla_bridge.py index b158d9b1ff51..95598c447262 100644 --- a/jax/lib/xla_bridge.py +++ b/jax/lib/xla_bridge.py @@ -27,15 +27,6 @@ "jax.lib.xla_bridge.get_backend is deprecated; use jax.extend.backend.get_backend.", _deprecated_get_backend ), - # Finalized 2024-12-11; remove after 2025-3-11 - "xla_client": ( - "jax.lib.xla_bridge.xla_client was removed in JAX v0.4.38; use jax.lib.xla_client directly.", - None - ), - "default_backend": ( - "jax.lib.xla_bridge.default_backend was removed in JAX v0.4.38; use jax.default_backend.", - None - ), } import typing as _typing diff --git a/jax/lib/xla_client.py b/jax/lib/xla_client.py index 86e7307c804b..15cd62d6e245 100644 --- a/jax/lib/xla_client.py +++ b/jax/lib/xla_client.py @@ -12,116 +12,164 @@ # See the License for the specific language governing permissions and # limitations under the License. -from jax._src.lax.fft import FftType as _FftType +import gzip as _gzip from jax._src.lib import xla_client as _xc -get_topology_for_devices = _xc.get_topology_for_devices -heap_profile = _xc.heap_profile -mlir_api_version = _xc.mlir_api_version -Client = _xc.Client -CompileOptions = _xc.CompileOptions -DeviceAssignment = _xc.DeviceAssignment -Frame = _xc.Frame -HloSharding = _xc.HloSharding -OpSharding = _xc.OpSharding -Traceback = _xc.Traceback +def _heap_profile(client): + return _gzip.compress(client.heap_profile()) _deprecations = { - # Finalized 2024-12-11; remove after 2025-3-11 - "_xla": ( - "jax.lib.xla_client._xla was removed in JAX v0.4.38; use jax.lib.xla_extension.", - None, - ), - "bfloat16": ( - "jax.lib.xla_client.bfloat16 was removed in JAX v0.4.38; use ml_dtypes.bfloat16.", - None, - ), - # Finalized 2024-12-23; remove after 2024-03-23 - "Device": ( - "jax.lib.xla_client.Device is deprecated; use jax.Device instead.", - None, - ), - "XlaRuntimeError": ( + # Finalized 2025-03-25; remove after 2025-06-25 + "FftType": ( ( - "jax.lib.xla_client.XlaRuntimeError is deprecated; use" - " jax.errors.JaxRuntimeError." + "jax.lib.xla_client.FftType was removed in JAX v0.6.0; use" + " jax.lax.FftType." ), None, ), - # Added Oct 10 2024 - "FftType": ( - "jax.lib.xla_client.FftType is deprecated; use jax.lax.FftType.", - _FftType, - ), "PaddingType": ( ( - "jax.lib.xla_client.PaddingType is deprecated; this type is unused" - " by JAX so there is no replacement." + "jax.lib.xla_client.PaddingType was removed in JAX v0.6.0;" + " this type is unused by JAX so there is no replacement." 
), - _xc.PaddingType, + None, ), - # Added Oct 11 2024 "dtype_to_etype": ( - "dtype_to_etype is deprecated; use StableHLO instead.", - _xc.dtype_to_etype, + "dtype_to_etype was removed in JAX v0.6.0; use StableHLO instead.", + None, ), + "shape_from_pyval": ( + "shape_from_pyval was removed in JAX v0.6.0; use StableHLO instead.", + None, + ), + # Added Oct 11 2024, finalized 2025-04-09 "ops": ( - "ops is deprecated; use StableHLO instead.", - _xc.ops, + "ops has been removed in JAX v0.6.0; use StableHLO instead.", + None, ), "register_custom_call_target": ( - "register_custom_call_target is deprecated; use the JAX FFI instead " - "(https://jax.readthedocs.io/en/latest/ffi.html)", - _xc.register_custom_call_target, - ), - "shape_from_pyval": ( - "shape_from_pyval is deprecated; use StableHLO instead.", - _xc.shape_from_pyval, + ( + "register_custom_call_target has been removed in JAX v0.6.0; use" + " the JAX FFI instead (https://docs.jax.dev/en/latest/ffi.html)" + ), + None, ), "PrimitiveType": ( - "PrimitiveType is deprecated; use StableHLO instead.", - _xc.PrimitiveType, + "PrimitiveType has been removed in JAX v0.6.0; use StableHLO instead.", + None, ), "Shape": ( - "Shape is deprecated; use StableHLO instead.", - _xc.Shape, + "Shape has been removed in JAX v0.6.0; use StableHLO instead.", + None, ), "XlaBuilder": ( - "XlaBuilder is deprecated; use StableHLO instead.", - _xc.XlaBuilder, + "XlaBuilder has been removed in JAX v0.6.0; use StableHLO instead.", + None, ), "XlaComputation": ( - "XlaComputation is deprecated; use StableHLO instead.", - _xc.XlaComputation, + "XlaComputation has been removed in JAX v0.6.0; use StableHLO instead.", + None, ), - # Added Nov 20 2024 + # Added Nov 20 2024, finalized 2025-04-09 "ArrayImpl": ( - "jax.lib.xla_client.ArrayImpl is deprecated; use jax.Array instead.", - _xc.ArrayImpl, + ( + "jax.lib.xla_client.ArrayImpl has been removed in JAX v0.6.0; use" + " jax.Array instead." + ), + None, + ), + # Finalized for JAX v0.7.0 + "heap_profile": ( + ( + "jax.lib.xla_client.heap_profile was deprecated in JAX v0.6.0 and" + " removed in JAX v0.7.0" + ), + None, + ), + # Added April 4 2025. 
+ "get_topology_for_devices": ( + ( + "jax.lib.xla_client.get_topology_for_devices was deprecated in JAX" + " v0.6.0 and will be removed in JAX v0.7.0" + ), + _xc.get_topology_for_devices, + ), + "mlir_api_version": ( + ( + "jax.lib.xla_client.mlir_api_version was deprecated in JAX v0.6.0" + " and will be removed in JAX v0.7.0" + ), + 58, + ), + "Client": ( + ( + "jax.lib.xla_client.Client was deprecated in JAX v0.6.0 and will be" + " removed in JAX v0.7.0" + ), + _xc.Client, + ), + "CompileOptions": ( + ( + "jax.lib.xla_client.CompileOptions was deprecated in JAX v0.6.0 and" + " will be removed in JAX v0.7.0" + ), + _xc.CompileOptions, + ), + "DeviceAssignment": ( + ( + "jax.lib.xla_client.DeviceAssignment was deprecated in JAX v0.6.0" + " and will be removed in JAX v0.7.0" + ), + _xc.DeviceAssignment, + ), + "Frame": ( + ( + "jax.lib.xla_client.Frame was deprecated in JAX v0.6.0 and will be" + " removed in JAX v0.7.0" + ), + _xc.Frame, + ), + "HloSharding": ( + ( + "jax.lib.xla_client.HloSharding was deprecated in JAX v0.6.0 and" + " will be removed in JAX v0.7.0" + ), + _xc.HloSharding, + ), + "OpSharding": ( + ( + "jax.lib.xla_client.OpSharding was deprecated in JAX v0.6.0 and" + " will be removed in JAX v0.7.0" + ), + _xc.OpSharding, + ), + "Traceback": ( + ( + "jax.lib.xla_client.Traceback was deprecated in JAX v0.6.0 and will" + " be removed in JAX v0.7.0" + ), + _xc.Traceback, ), } import typing as _typing if _typing.TYPE_CHECKING: - dtype_to_etype = _xc.dtype_to_etype - ops = _xc.ops - register_custom_call_target = _xc.register_custom_call_target - shape_from_pyval = _xc.shape_from_pyval - ArrayImpl = _xc.ArrayImpl - Device = _xc.Device - FftType = _FftType - PaddingType = _xc.PaddingType - PrimitiveType = _xc.PrimitiveType - Shape = _xc.Shape - XlaBuilder = _xc.XlaBuilder - XlaComputation = _xc.XlaComputation - XlaRuntimeError = _xc.XlaRuntimeError + get_topology_for_devices = _xc.get_topology_for_devices + heap_profile = _heap_profile + mlir_api_version = 58 + Client = _xc.Client + CompileOptions = _xc.CompileOptions + DeviceAssignment = _xc.DeviceAssignment + Frame = _xc.Frame + HloSharding = _xc.HloSharding + OpSharding = _xc.OpSharding + Traceback = _xc.Traceback else: from jax._src.deprecations import deprecation_getattr as _deprecation_getattr __getattr__ = _deprecation_getattr(__name__, _deprecations) del _deprecation_getattr del _typing -del _FftType +del _heap_profile del _xc diff --git a/jax/lib/xla_extension.py b/jax/lib/xla_extension.py index 52fe94e231d1..7e183eab5c2c 100644 --- a/jax/lib/xla_extension.py +++ b/jax/lib/xla_extension.py @@ -12,48 +12,146 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from jax._src.lib import xla_extension as _xe - -get_distributed_runtime_client = _xe.get_distributed_runtime_client -get_distributed_runtime_service = _xe.get_distributed_runtime_service -hlo_module_cost_analysis = _xe.hlo_module_cost_analysis -hlo_module_to_dot_graph = _xe.hlo_module_to_dot_graph -ifrt_proxy = _xe.ifrt_proxy -jax_jit = _xe.jax_jit -mlir = _xe.mlir -pmap_lib = _xe.pmap_lib -profiler = _xe.profiler -pytree = _xe.pytree -Device = _xe.Device -DistributedRuntimeClient = _xe.DistributedRuntimeClient -HloModule = _xe.HloModule -HloPrintOptions = _xe.HloPrintOptions -OpSharding = _xe.OpSharding -PjitFunctionCache = _xe.PjitFunctionCache -PjitFunction = _xe.PjitFunction -PmapFunction = _xe.PmapFunction +import jax._src.lib +from jax._src.lib import _jax _deprecations = { - # Added Nov 20 2024 + # Finalized for JAX v0.6.0 "ArrayImpl": ( - "jax.lib.xla_extension.ArrayImpl is deprecated; use jax.Array instead.", - _xe.ArrayImpl, + ( + "jax.lib.xla_extension.ArrayImpl has been removed; use jax.Array" + " instead." + ), + None, ), "XlaRuntimeError": ( - "jax.lib.xla_extension.XlaRuntimeError is deprecated; use jax.errors.JaxRuntimeError instead.", - _xe.XlaRuntimeError, + ( + "jax.lib.xla_extension.XlaRuntimeError has been removed; use" + " jax.errors.JaxRuntimeError instead." + ), + None, + ), + # Finalized for JAX v0.7.0 + "Device": ( + ( + "jax.lib.xla_extension.Device was deprecated in JAX v0.6.0" + " and removed in JAX v0.7.0; use jax.Device instead." + ), + None, + ), + "DistributedRuntimeClient": ( + ( + "jax.lib.xla_extension.DistributedRuntimeClient deprecated in JAX" + " v0.6.0 and removed in JAX v0.7.0; use jax.distributed instead." + ), + None, + ), + "HloModule": ( + ( + "jax.lib.xla_extension.HloModule deprecated in JAX v0.6.0" + " and removed in JAX v0.7.0." + ), + None, + ), + "OpSharding": ( + ( + "jax.lib.xla_extension.OpSharding deprecated in JAX v0.6.0" + " and removed in JAX v0.7.0." + ), + None, + ), + "PjitFunctionCache": ( + ( + "jax.lib.xla_extension.PjitFunctionCache was deprecated in JAX v0.6.0" + " and removed in JAX v0.7.0." + ), + None, + ), + "get_distributed_runtime_client": ( + ( + "jax.lib.xla_extension.get_distributed_runtime_client was deprecated" + " in JAX v0.6.0 and removed in JAX v0.7.0; use jax.distributed instead." + ), + None, + ), + "get_distributed_runtime_service": ( + ( + "jax.lib.xla_extension.get_distributed_runtime_service was deprecated" + " in JAX v0.6.0 and removed in JAX v0.7.0; use jax.distributed instead." + ), + None, + ), + "jax_jit": ( + "jax.lib.xla_extension.jax_jit deprecated in JAX v0.6.0 and removed in JAX v0.7.0.", + None, + ), + "pmap_lib": ( + "jax.lib.xla_extension.pmap_lib deprecated in JAX v0.6.0 and removed in JAX v0.7.0.", + None + ), + "pytree": ( + "jax.lib.xla_extension.pytree deprecated in JAX v0.6.0 and removed in JAX v0.7.0.", + None, + ), + # Deprecated March 26 2025. 
+ "ifrt_proxy": ( + "jax.lib.xla_extension.ifrt_proxy is deprecated.", + _jax.ifrt_proxy, + ), + "mlir": ("jax.lib.xla_extension.mlir is deprecated.", _jax.mlir), + "profiler": ( + "jax.lib.xla_extension.profiler is deprecated.", + jax._src.lib._profiler, + ), + "hlo_module_cost_analysis": ( + "jax.lib.xla_extension.hlo_module_cost_analysis is deprecated.", + _jax.hlo_module_cost_analysis, + ), + "hlo_module_to_dot_graph": ( + "jax.lib.xla_extension.hlo_module_to_dot_graph is deprecated.", + _jax.hlo_module_to_dot_graph, + ), + "HloPrintOptions": ( + "jax.lib.xla_extension.HloPrintOptions is deprecated.", + _jax.HloPrintOptions, + ), + "PjitFunction": ( + "jax.lib.xla_extension.PjitFunction is deprecated.", + _jax.PjitFunction, + ), + "PmapFunction": ( + "jax.lib.xla_extension.PmapFunction is deprecated.", + _jax.PmapFunction, ), } import typing as _typing if _typing.TYPE_CHECKING: - ArrayImpl = _xe.ArrayImpl - XlaRuntimeError = _xe.XlaRuntimeError + Device = _jax.Device + DistributedRuntimeClient = _jax.DistributedRuntimeClient + HloModule = _jax.HloModule + HloPrintOptions = _jax.HloPrintOptions + OpSharding = _jax.OpSharding + PjitFunction = _jax.PjitFunction + PjitFunctionCache = _jax.PjitFunctionCache + PmapFunction = _jax.PmapFunction + + get_distributed_runtime_client = _jax.get_distributed_runtime_client + get_distributed_runtime_service = _jax.get_distributed_runtime_service + hlo_module_cost_analysis = _jax.hlo_module_cost_analysis + hlo_module_to_dot_graph = _jax.hlo_module_to_dot_graph + ifrt_proxy = _jax.ifrt_proxy + jax_jit = _jax.jax_jit + mlir = _jax.mlir + pmap_lib = _jax.pmap_lib + profiler = jax._src.lib._profiler + pytree = _jax.pytree + else: from jax._src.deprecations import deprecation_getattr as _deprecation_getattr __getattr__ = _deprecation_getattr(__name__, _deprecations) del _deprecation_getattr del _typing -del _xe +del _jax diff --git a/jax/monitoring.py b/jax/monitoring.py index 4c9996da582c..f4ab8124f219 100644 --- a/jax/monitoring.py +++ b/jax/monitoring.py @@ -26,7 +26,9 @@ record_event_duration_secs as record_event_duration_secs, record_event_time_span as record_event_time_span, record_event as record_event, + record_scalar as record_scalar, register_event_duration_secs_listener as register_event_duration_secs_listener, register_event_listener as register_event_listener, register_event_time_span_listener as register_event_time_span_listener, + register_scalar_listener as register_scalar_listener, ) diff --git a/jax/nn/__init__.py b/jax/nn/__init__.py index 3f08e1c0fd12..651d9cf4e47f 100644 --- a/jax/nn/__init__.py +++ b/jax/nn/__init__.py @@ -35,8 +35,10 @@ standardize as standardize, one_hot as one_hot, relu as relu, + identity as identity, relu6 as relu6, dot_product_attention as dot_product_attention, + get_scaled_dot_general_config as get_scaled_dot_general_config, scaled_dot_general as scaled_dot_general, scaled_matmul as scaled_matmul, selu as selu, diff --git a/jax/numpy/__init__.py b/jax/numpy/__init__.py index cb291bdca79a..24a0ca907567 100644 --- a/jax/numpy/__init__.py +++ b/jax/numpy/__init__.py @@ -24,6 +24,11 @@ isdtype as isdtype, ) +from jax._src.numpy.array import ( + array as array, + asarray as asarray, +) + from jax._src.numpy.lax_numpy import ( ComplexWarning as ComplexWarning, allclose as allclose, @@ -36,12 +41,10 @@ argmin as argmin, argwhere as argwhere, around as around, - array as array, array_equal as array_equal, array_equiv as array_equiv, array_split as array_split, astype as astype, - asarray as asarray, atleast_1d as 
atleast_1d, atleast_2d as atleast_2d, atleast_3d as atleast_3d, @@ -93,7 +96,6 @@ fromstring as fromstring, from_dlpack as from_dlpack, gcd as gcd, - geomspace as geomspace, get_printoptions as get_printoptions, gradient as gradient, histogram as histogram, @@ -118,9 +120,7 @@ ix_ as ix_, kron as kron, lcm as lcm, - linspace as linspace, load as load, - logspace as logspace, mask_indices as mask_indices, matrix_transpose as matrix_transpose, meshgrid as meshgrid, @@ -180,6 +180,9 @@ empty_like as empty_like, full as full, full_like as full_like, + geomspace as geomspace, + linspace as linspace, + logspace as logspace, ones as ones, ones_like as ones_like, zeros as zeros, @@ -211,13 +214,18 @@ double as double, float16 as float16, float32 as float32, + float4_e2m1fn as float4_e2m1fn, float64 as float64, + float8_e3m4 as float8_e3m4, + float8_e4m3 as float8_e4m3, float8_e4m3b11fnuz as float8_e4m3b11fnuz, float8_e4m3fn as float8_e4m3fn, float8_e4m3fnuz as float8_e4m3fnuz, float8_e5m2 as float8_e5m2, float8_e5m2fnuz as float8_e5m2fnuz, + float8_e8m0fnu as float8_e8m0fnu, float_ as float_, + int2 as int2, int4 as int4, int8 as int8, int16 as int16, @@ -226,6 +234,7 @@ int_ as int_, single as single, uint as uint, + uint2 as uint2, uint4 as uint4, uint8 as uint8, uint16 as uint16, @@ -295,26 +304,6 @@ unsignedinteger as unsignedinteger, ) -# TODO(slebedev): Remove the try-except once we upgrade to ml_dtypes 0.4.1. -try: - from jax._src.numpy.scalar_types import ( - int2 as int2, - uint2 as uint2, - ) -except ImportError: - pass - -# TODO: Remove the try-except once we upgrade to ml_dtypes 0.5.0 -try: - from jax._src.numpy.scalar_types import ( - float8_e3m4 as float8_e3m4, - float8_e4m3 as float8_e4m3, - float8_e8m0fnu as float8_e8m0fnu, - float4_e2m1fn as float4_e2m1fn, - ) -except ImportError: - pass - from jax._src.numpy.array_api_metadata import ( __array_api_version__ as __array_api_version__, __array_namespace_info__ as __array_namespace_info__, @@ -506,19 +495,3 @@ from jax._src.numpy.array_methods import register_jax_array_methods register_jax_array_methods() del register_jax_array_methods - - -_deprecations = { - # Finalized 2024-12-13; remove after 2024-3-13 - "round_": ( - "jnp.round_ was deprecated in JAX 0.4.38; use jnp.round instead.", - None - ), -} - -import typing -if not typing.TYPE_CHECKING: - from jax._src.deprecations import deprecation_getattr as _deprecation_getattr - __getattr__ = _deprecation_getattr(__name__, _deprecations) - del _deprecation_getattr -del typing diff --git a/jax/numpy/__init__.pyi b/jax/numpy/__init__.pyi index b73a3b95b9a5..0ff96a4394ce 100644 --- a/jax/numpy/__init__.pyi +++ b/jax/numpy/__init__.pyi @@ -15,7 +15,7 @@ from jax._src.numpy.index_tricks import _Mgrid, _Ogrid, CClass as _CClass, RClas from jax._src.numpy.array_api_metadata import ArrayNamespaceInfo from jax._src.typing import ( Array, ArrayLike, DType, DTypeLike, DeprecatedArg, - DimSize, DuckTypedArray, Shape, StaticScalar, + DimSize, DuckTypedArray, Shape, StaticScalar, SupportsNdim, SupportsShape, SupportsSize, ) from jax._src.sharding_impls import NamedSharding, PartitionSpec as P from jax.numpy import fft as fft, linalg as linalg @@ -253,7 +253,8 @@ def broadcast_shapes(*shapes: Sequence[int]) -> tuple[int, ...]: ... def broadcast_shapes(*shapes: Sequence[int | _core.Tracer] ) -> tuple[int | _core.Tracer, ...]: ... -def broadcast_to(array: ArrayLike, shape: DimSize | Shape) -> Array: ... 
+def broadcast_to(array: ArrayLike, shape: DimSize | Shape, *, + out_sharding: NamedSharding | P | None = None) -> Array: ... c_: _CClass can_cast = _np.can_cast def cbrt(x: ArrayLike, /) -> Array: ... @@ -267,6 +268,7 @@ def clip( /, min: ArrayLike | None = ..., max: ArrayLike | None = ..., + *, a: ArrayLike | DeprecatedArg | None = ..., a_min: ArrayLike | DeprecatedArg | None = ..., a_max: ArrayLike | DeprecatedArg | None = ... @@ -278,7 +280,7 @@ complex128: Any complex64: Any complex_: Any complexfloating = _np.complexfloating -def compress(condition: ArrayLike, a: ArrayLike, axis: int | None = ..., +def compress(condition: ArrayLike, a: ArrayLike, axis: int | None = ..., *, size: int | None = ..., fill_value: ArrayLike = ..., out: None = ...) -> Array: ... def concat(arrays: Sequence[ArrayLike], /, *, axis: int | None = 0) -> Array: ... def concatenate( @@ -314,9 +316,9 @@ def cross( axis: int | None = ..., ) -> Array: ... csingle: Any -def cumprod(a: ArrayLike, axis: int | None = ..., dtype: DTypeLike = ..., +def cumprod(a: ArrayLike, axis: int | None = ..., dtype: DTypeLike | None = ..., out: None = ...) -> Array: ... -def cumsum(a: ArrayLike, axis: int | None = ..., dtype: DTypeLike = ..., +def cumsum(a: ArrayLike, axis: int | None = ..., dtype: DTypeLike | None = ..., out: None = ...) -> Array: ... def cumulative_prod(x: ArrayLike, /, *, axis: int | None = ..., dtype: DTypeLike | None = ..., @@ -350,7 +352,8 @@ def divide(x: ArrayLike, y: ArrayLike, /) -> Array: ... def divmod(x: ArrayLike, y: ArrayLike, /) -> tuple[Array, Array]: ... def dot( a: ArrayLike, b: ArrayLike, *, precision: PrecisionLike = ..., - preferred_element_type: DTypeLike | None = ...) -> Array: ... + preferred_element_type: DTypeLike | None = ..., + out_sharding: NamedSharding | P | None = ...) -> Array: ... double: Any def dsplit( ary: ArrayLike, indices_or_sections: int | ArrayLike @@ -370,7 +373,6 @@ def einsum( optimize: str | builtins.bool | list[tuple[int, ...]] = ..., precision: PrecisionLike = ..., preferred_element_type: DTypeLike | None = ..., - _use_xeinsum: builtins.bool = False, _dot_general: Callable[..., Array] = ..., out_sharding: NamedSharding | P | None = ..., ) -> Array: ... @@ -384,7 +386,6 @@ def einsum( optimize: str | builtins.bool | list[tuple[int, ...]] = ..., precision: PrecisionLike = ..., preferred_element_type: DTypeLike | None = ..., - _use_xeinsum: builtins.bool = False, _dot_general: Callable[..., Array] = ..., out_sharding: NamedSharding | P | None = ..., ) -> Array: ... @@ -396,7 +397,6 @@ def einsum( optimize: str | builtins.bool | list[tuple[int, ...]] = ..., precision: PrecisionLike = ..., preferred_element_type: DTypeLike | None = ..., - _use_xeinsum: builtins.bool = ..., _dot_general: Callable[..., Array] = ..., out_sharding: NamedSharding | P | None = ..., ) -> Array: ... @@ -421,7 +421,7 @@ def einsum_path( optimize: str | builtins.bool | list[tuple[int, ...]] = ..., ) -> tuple[list[tuple[int, ...]], Any]: ... -def empty(shape: Any, dtype: DTypeLike | None = ..., +def empty(shape: Any, dtype: DTypeLike | None = ..., *, device: _Device | _Sharding | None = ...) -> Array: ... def empty_like(prototype: ArrayLike | DuckTypedArray, dtype: DTypeLike | None = ..., @@ -456,12 +456,16 @@ def fliplr(m: ArrayLike) -> Array: ... def flipud(m: ArrayLike) -> Array: ... 
float16: Any float32: Any +float4_e2m1fn: Any float64: Any +float8_e3m4: Any +float8_e4m3: Any float8_e4m3b11fnuz: Any float8_e4m3fn: Any float8_e4m3fnuz: Any float8_e5m2: Any float8_e5m2fnuz: Any +float8_e8m0fnu: Any float_: Any def float_power(x: ArrayLike, y: ArrayLike, /) -> Array: ... floating = _np.floating @@ -562,6 +566,7 @@ def inner( def insert(arr: ArrayLike, obj: ArrayLike | slice, values: ArrayLike, axis: int | None = ...) -> Array: ... int16: Any +int2: Any int32: Any int4: Any int64: Any @@ -578,17 +583,17 @@ def intersect1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: builtins.bool = . def invert(x: ArrayLike, /) -> Array: ... def isclose(a: ArrayLike, b: ArrayLike, rtol: ArrayLike = ..., atol: ArrayLike = ..., equal_nan: builtins.bool = ...) -> Array: ... -def iscomplex(m: ArrayLike) -> Array: ... +def iscomplex(x: ArrayLike) -> Array: ... def iscomplexobj(x: Any) -> builtins.bool: ... def isdtype(dtype: DTypeLike, kind: DType | str | tuple[DType | str, ...]) -> builtins.bool: ... def isfinite(x: ArrayLike, /) -> Array: ... -def isin(element: ArrayLike, test_elements: ArrayLike, - assume_unique: builtins.bool = ..., invert: builtins.bool = ...) -> Array: ... +def isin(element: ArrayLike, test_elements: ArrayLike, assume_unique: builtins.bool = ..., + invert: builtins.bool = ..., *, method: str = ...) -> Array: ... def isinf(x: ArrayLike, /) -> Array: ... def isnan(x: ArrayLike, /) -> Array: ... def isneginf(x: ArrayLike, /) -> Array: ... def isposinf(x: ArrayLike, /) -> Array: ... -def isreal(m: ArrayLike) -> Array: ... +def isreal(x: ArrayLike) -> Array: ... def isrealobj(x: Any) -> builtins.bool: ... def isscalar(element: Any) -> builtins.bool: ... def issubdtype(arg1: DTypeLike, arg2: DTypeLike) -> builtins.bool: ... @@ -643,7 +648,7 @@ def logspace(start: ArrayLike, stop: ArrayLike, num: int = ..., endpoint: builtins.bool = ..., base: ArrayLike = ..., dtype: DTypeLike | None = ..., axis: int = ...) -> Array: ... def mask_indices( - n: int, mask_func: Callable, k: int = ... + n: int, mask_func: Callable, k: int = ..., *, size: int | None = ... ) -> tuple[Array, ...]: ... def matmul( a: ArrayLike, b: ArrayLike, *, precision: PrecisionLike = ..., @@ -653,8 +658,8 @@ def matvec(x1: ArrayLike, x2: ArrayLike, /) -> Array: ... def max(a: ArrayLike, axis: _Axis = ..., out: None = ..., keepdims: builtins.bool = ..., initial: ArrayLike | None = ..., where: ArrayLike | None = ...) -> Array: ... -def maximum(x: ArrayLike, y: ArrayLike, /) -> Array: ... -def mean(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike = ..., +maximum: BinaryUfunc +def mean(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike | None = ..., out: None = ..., keepdims: builtins.bool = ..., *, where: ArrayLike | None = ...) -> Array: ... def median(a: ArrayLike, axis: int | tuple[int, ...] | None = ..., @@ -666,7 +671,7 @@ mgrid: _Mgrid def min(a: ArrayLike, axis: _Axis = ..., out: None = ..., keepdims: builtins.bool = ..., initial: ArrayLike | None = ..., where: ArrayLike | None = ...) -> Array: ... -def minimum(x: ArrayLike, y: ArrayLike, /) -> Array: ... +minimum: BinaryUfunc def mod(x: ArrayLike, y: ArrayLike, /) -> Array: ... def modf(x: ArrayLike, /, out=None) -> tuple[Array, Array]: ... def moveaxis(a: ArrayLike, source: int | Sequence[int], @@ -688,14 +693,14 @@ def nanargmin( out: None = ..., keepdims: builtins.bool | None = ..., ) -> Array: ... 
-def nancumprod(a: ArrayLike, axis: int | None = ..., dtype: DTypeLike = ..., +def nancumprod(a: ArrayLike, axis: int | None = ..., dtype: DTypeLike | None = ..., out: None = ...) -> Array: ... -def nancumsum(a: ArrayLike, axis: int | None = ..., dtype: DTypeLike = ..., +def nancumsum(a: ArrayLike, axis: int | None = ..., dtype: DTypeLike | None = ..., out: None = ...) -> Array: ... def nanmax(a: ArrayLike, axis: _Axis = ..., out: None = ..., keepdims: builtins.bool = ..., initial: ArrayLike | None = ..., where: ArrayLike | None = ...) -> Array: ... -def nanmean(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike = ..., +def nanmean(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike | None = ..., out: None = ..., keepdims: builtins.bool = ..., where: ArrayLike | None = ...) -> Array: ... @@ -709,26 +714,26 @@ def nanpercentile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = ..., out: None = ..., overwrite_input: builtins.bool = ..., method: str = ..., keepdims: builtins.bool = ..., *, interpolation: DeprecatedArg | str = ...) -> Array: ... -def nanprod(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike = ..., +def nanprod(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike | None = ..., out: None = ..., keepdims: builtins.bool = ..., initial: ArrayLike | None = ..., where: ArrayLike | None = ...) -> Array: ... def nanquantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = ..., out: None = ..., overwrite_input: builtins.bool = ..., method: str = ..., keepdims: builtins.bool = ..., *, interpolation: DeprecatedArg | str = ...) -> Array: ... -def nanstd(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike = ..., out: None = ..., - ddof: int = ..., keepdims: builtins.bool = ..., +def nanstd(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike | None = ..., + out: None = ..., ddof: int = ..., keepdims: builtins.bool = ..., where: ArrayLike | None = ...) -> Array: ... -def nansum(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike = ..., +def nansum(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike | None = ..., out: None = ..., keepdims: builtins.bool = ..., initial: ArrayLike | None = ..., where: ArrayLike | None = ...) -> Array: ... -def nanvar(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike = ..., +def nanvar(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike | None = ..., out: None = ..., ddof: int = 0, keepdims: builtins.bool = False, where: ArrayLike | None = ...) -> Array: ... ndarray = Array -def ndim(a: ArrayLike) -> int: ... +def ndim(a: ArrayLike | SupportsNdim) -> int: ... def negative(x: ArrayLike, /) -> Array: ... newaxis = None def nextafter(x: ArrayLike, y: ArrayLike, /) -> Array: ... @@ -739,7 +744,7 @@ def not_equal(x: ArrayLike, y: ArrayLike, /) -> Array: ... number = _np.number object_ = _np.object_ ogrid: _Ogrid -def ones(shape: Any, dtype: DTypeLike | None = ..., +def ones(shape: Any, dtype: DTypeLike | None = ..., *, device: _Device | _Sharding | None = ...) -> Array: ... def ones_like(a: ArrayLike | DuckTypedArray, dtype: DTypeLike | None = ..., @@ -781,7 +786,7 @@ def positive(x: ArrayLike, /) -> Array: ... def pow(x: ArrayLike, y: ArrayLike, /) -> Array: ... def power(x: ArrayLike, y: ArrayLike, /) -> Array: ... printoptions = _np.printoptions -def prod(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike = ..., +def prod(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike | None = ..., out: None = ..., keepdims: builtins.bool = ..., initial: ArrayLike | None = ..., where: ArrayLike | None = ..., promote_integers: builtins.bool = ...) -> Array: ... 
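
These .pyi changes also surface an out_sharding keyword on several array-producing functions (broadcast_to and dot above; ravel, repeat, and reshape in the next hunk). A minimal sketch of the call shape they describe, assuming explicit-sharding mode on a two-device mesh; the mesh setup here is illustrative and not part of this diff:

    import jax
    import jax.numpy as jnp
    from jax.sharding import AxisType, PartitionSpec as P

    # Assumes at least two devices; the 'x' axis uses explicit sharding.
    mesh = jax.make_mesh((2,), ('x',), axis_types=(AxisType.Explicit,))
    with jax.sharding.use_mesh(mesh):
        x = jnp.arange(8.)
        # Request the output sharding directly instead of applying a separate
        # sharding constraint afterwards.
        y = jnp.reshape(x, (2, 4), out_sharding=P('x', None))
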
@@ -798,18 +803,19 @@ def quantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = .. r_: _RClass def rad2deg(x: ArrayLike, /) -> Array: ... def radians(x: ArrayLike, /) -> Array: ... -def ravel(a: ArrayLike, order: str = ...) -> Array: ... +def ravel(a: ArrayLike, order: str = ..., *, + out_sharding: NamedSharding | P | None = ...) -> Array: ... def ravel_multi_index(multi_index: Sequence[ArrayLike], dims: Sequence[int], mode: str = ..., order: str = ...) -> Array: ... def real(x: ArrayLike, /) -> Array: ... def reciprocal(x: ArrayLike, /) -> Array: ... -register_jax_array_methods: Any def remainder(x: ArrayLike, y: ArrayLike, /) -> Array: ... def repeat(a: ArrayLike, repeats: ArrayLike, axis: int | None = ..., *, - total_repeat_length: int | None = ...) -> Array: ... + total_repeat_length: int | None = ..., + out_sharding: NamedSharding | P | None = None) -> Array: ... def reshape( - a: ArrayLike, shape: DimSize | Shape = ..., - newshape: DimSize | Shape | None = ..., order: str = ... + a: ArrayLike, shape: DimSize | Shape, order: str = ..., *, copy: bool | None = ..., + out_sharding: NamedSharding | P | None = ..., ) -> Array: ... def resize(a: ArrayLike, new_shape: Shape) -> Array: ... @@ -841,8 +847,9 @@ def setdiff1d( size: int | None = ..., fill_value: ArrayLike | None = ..., ) -> Array: ... -def setxor1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: builtins.bool = ...) -> Array: ... -def shape(a: ArrayLike) -> tuple[int, ...]: ... +def setxor1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: builtins.bool = ..., *, + size: int | None = ..., fill_value: ArrayLike | None = ...) -> Array: ... +def shape(a: ArrayLike | SupportsShape) -> tuple[int, ...]: ... def sign(x: ArrayLike, /) -> Array: ... def signbit(x: ArrayLike, /) -> Array: ... signedinteger = _np.signedinteger @@ -850,7 +857,7 @@ def sin(x: ArrayLike, /) -> Array: ... def sinc(x: ArrayLike, /) -> Array: ... single: Any def sinh(x: ArrayLike, /) -> Array: ... -def size(a: ArrayLike, axis: int | None = None) -> int: ... +def size(a: ArrayLike | SupportsSize, axis: int | None = None) -> int: ... def sort( a: ArrayLike, axis: int | None = ..., @@ -879,14 +886,14 @@ def stack( out: None = ..., dtype: DTypeLike | None = ..., ) -> Array: ... -def std(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike = ..., +def std(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike | None = ..., out: None = ..., ddof: int = ..., keepdims: builtins.bool = ..., *, where: ArrayLike | None = ..., correction: int | float | None = ...) -> Array: ... subtract: BinaryUfunc def sum( a: ArrayLike, axis: _Axis = ..., - dtype: DTypeLike = ..., + dtype: DTypeLike | None = ..., out: None = ..., keepdims: builtins.bool = ..., initial: ArrayLike | None = ..., @@ -924,24 +931,25 @@ def transpose(a: ArrayLike, axes: Sequence[int] | None = ...) -> Array: ... def trapezoid(y: ArrayLike, x: ArrayLike | None = None, dx: ArrayLike = ..., axis: int = ...) -> Array: ... def tri( - N: int, M: int | None = ..., k: int = ..., dtype: DTypeLike = ... + N: int, M: int | None = ..., k: int = ..., dtype: DTypeLike | None = ... ) -> Array: ... def tril(m: ArrayLike, k: int = ...) -> Array: ... def tril_indices( n: int, k: int = ..., m: int | None = ... ) -> tuple[Array, Array]: ... -def tril_indices_from(arr: ArrayLike, k: int = ...) -> tuple[Array, Array]: ... +def tril_indices_from(arr: ArrayLike | SupportsShape, k: int = ...) -> tuple[Array, Array]: ... def fill_diagonal(a: ArrayLike, val: ArrayLike, wrap: builtins.bool = ..., *, inplace: builtins.bool = ...) 
-> Array: ... def trim_zeros(filt: ArrayLike, trim: str = ...) -> Array: ... def triu(m: ArrayLike, k: int = ...) -> Array: ... def triu_indices( n: int, k: int = ..., m: int | None = ... ) -> tuple[Array, Array]: ... -def triu_indices_from(arr: ArrayLike, k: int = ...) -> tuple[Array, Array]: ... +def triu_indices_from(arr: ArrayLike | SupportsShape, k: int = ...) -> tuple[Array, Array]: ... def true_divide(x: ArrayLike, y: ArrayLike, /) -> Array: ... def trunc(x: ArrayLike, /) -> Array: ... uint: Any uint16: Any +uint2: Any uint32: Any uint4: Any uint64: Any @@ -967,7 +975,7 @@ class _UniqueInverseResult(NamedTuple): def unique(ar: ArrayLike, return_index: builtins.bool = ..., return_inverse: builtins.bool = ..., return_counts: builtins.bool = ..., axis: int | None = ..., *, equal_nan: builtins.bool = ..., size: int | None = ..., - fill_value: ArrayLike | None = ... + fill_value: ArrayLike | None = ..., sorted: bool = ..., ): ... def unique_all(x: ArrayLike, /, *, size: int | None = ..., fill_value: ArrayLike | None = ...) -> _UniqueAllResult: ... @@ -991,7 +999,7 @@ def unwrap(p: ArrayLike, discont: ArrayLike | None = ..., def vander( x: ArrayLike, N: int | None = ..., increasing: builtins.bool = ... ) -> Array: ... -def var(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike = ..., +def var(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike | None = ..., out: None = ..., ddof: int = ..., keepdims: builtins.bool = ..., *, where: ArrayLike | None = ..., correction: int | float | None = ...) -> Array: ... def vdot( @@ -1026,7 +1034,7 @@ def where(condition: ArrayLike, x: ArrayLike | None = ..., fill_value: None | ArrayLike | tuple[ArrayLike, ...] = ... ) -> Array | tuple[Array, ...]: ... -def zeros(shape: Any, dtype: DTypeLike | None = ..., +def zeros(shape: Any, dtype: DTypeLike | None = ..., *, device: _Device | _Sharding | None = ...) -> Array: ... def zeros_like(a: ArrayLike | DuckTypedArray, dtype: DTypeLike | None = ..., diff --git a/jax/profiler.py b/jax/profiler.py index 77157dc02a13..d776791e9200 100644 --- a/jax/profiler.py +++ b/jax/profiler.py @@ -14,16 +14,24 @@ # Note: import as is required for names to be exported. # See PEP 484 & https://github.com/jax-ml/jax/issues/7570 +from typing import Any from jax._src.profiler import ( - StepTraceAnnotation as StepTraceAnnotation, - TraceAnnotation as TraceAnnotation, - device_memory_profile as device_memory_profile, - save_device_memory_profile as save_device_memory_profile, - start_server as start_server, - stop_server as stop_server, - start_trace as start_trace, - stop_trace as stop_trace, - trace as trace, - annotate_function as annotate_function, + ProfileOptions as ProfileOptions, + StepTraceAnnotation as StepTraceAnnotation, + TraceAnnotation as TraceAnnotation, + annotate_function as annotate_function, + device_memory_profile as device_memory_profile, + save_device_memory_profile as save_device_memory_profile, + start_server as start_server, + start_trace as start_trace, + stop_server as stop_server, + stop_trace as stop_trace, + trace as trace, ) + +# this is a temporary shim to please pytype in the meantime before the migration +# is complete for cl/760646494 +ProfileData: Any = None +ProfileEvent: Any = None +ProfilePlane: Any = None diff --git a/jax/random.py b/jax/random.py index 9db584895cf1..89d68a24ccaf 100644 --- a/jax/random.py +++ b/jax/random.py @@ -92,7 +92,7 @@ To learn more about this upgrade, and the design of key types, see `JEP 9263 - `_. + `_. 
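
A short sketch of the typed-key workflow that the JEP 9263 reference above describes, shown here only for context (standard public jax.random API):

    import jax

    key = jax.random.key(0)          # new-style typed PRNG key
    k1, k2 = jax.random.split(key)   # derive independent keys
    x = jax.random.normal(k1, (3,))
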
Advanced -------- @@ -178,7 +178,7 @@ ``XLA_FLAGS=--xla_tpu_spmd_rng_bit_generator_unsafe=1``. For more about ``jax_threefry_partitionable``, see -https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html#generating-random-numbers +https://docs.jax.dev/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html#generating-random-numbers **Summary:** diff --git a/jax/scipy/linalg.py b/jax/scipy/linalg.py index 64bc0544000b..c8a2d5f81957 100644 --- a/jax/scipy/linalg.py +++ b/jax/scipy/linalg.py @@ -31,6 +31,7 @@ lu as lu, lu_factor as lu_factor, lu_solve as lu_solve, + pascal as pascal, polar as polar, qr as qr, rsf2csf as rsf2csf, diff --git a/jax/scipy/special.py b/jax/scipy/special.py index 2ffc65a1abe1..e1330d4b6cf3 100644 --- a/jax/scipy/special.py +++ b/jax/scipy/special.py @@ -37,6 +37,7 @@ gammaln as gammaln, gammasgn as gammasgn, hyp1f1 as hyp1f1, + hyp2f1 as hyp2f1, i0 as i0, i0e as i0e, i1 as i1, diff --git a/jax/sharding.py b/jax/sharding.py index 55ff0f6aea0b..ef963d6a0138 100644 --- a/jax/sharding.py +++ b/jax/sharding.py @@ -20,8 +20,6 @@ NamedSharding as NamedSharding, SingleDeviceSharding as SingleDeviceSharding, PmapSharding as PmapSharding, - GSPMDSharding as GSPMDSharding, - PositionalSharding as PositionalSharding, use_mesh as use_mesh, set_mesh as set_mesh, ) @@ -36,16 +34,24 @@ ) _deprecations = { - # Finalized 2024-10-01; remove after 2025-01-01. - "XLACompatibleSharding": ( + # Added April 11, 2025. + "PositionalSharding": ( ( - "jax.sharding.XLACompatibleSharding was removed in JAX v0.4.34. " - "Use jax.sharding.Sharding instead." + "jax.sharding.PositionalSharding was deprecated in JAX v0.6.0 and" + " removed in JAX v0.7.0" ), None, - ) + ), + "GSPMDSharding": ( + ( + "jax.sharding.GSPMDSharding was deprecated in JAX v0.6.0 and" + " removed in JAX v0.7.0" + ), + None, + ), } + from jax._src.deprecations import deprecation_getattr as _deprecation_getattr __getattr__ = _deprecation_getattr(__name__, _deprecations) del _deprecation_getattr diff --git a/jax/stages.py b/jax/stages.py index 3e7e461c385b..aa4c96168b3f 100644 --- a/jax/stages.py +++ b/jax/stages.py @@ -18,7 +18,7 @@ lowering and compilation *ahead of time*. This module defines types that represent the stages of this process. -For more, see the `AOT walkthrough `_. +For more, see the `AOT walkthrough `_. """ # Note: import as is required for names to be exported. 
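
Several of the modules touched in this diff (jax.lax, jax.lib.xla_client, jax.lib.xla_extension, jax.sharding, jax.tree_util, jax.util) rely on the same module-level deprecation shim. A self-contained sketch of how that pattern behaves; this is illustrative only, and JAX's real helper lives in jax._src.deprecations:

    import warnings

    _deprecations = {
        # name -> (message, value); value is None once the name is fully removed.
        "old_name": ("old_name is deprecated; use new_name instead.", 42),
        "gone_name": ("gone_name was removed in an earlier release.", None),
    }

    def __getattr__(name):  # PEP 562: invoked for missing module attributes
        if name not in _deprecations:
            raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
        message, value = _deprecations[name]
        if value is None:
            # Fully removed: accessing the name fails with the guidance message.
            raise AttributeError(message)
        warnings.warn(message, DeprecationWarning, stacklevel=2)
        return value
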
diff --git a/jax/tools/jax_to_ir.py b/jax/tools/jax_to_ir.py index 904ce509a87e..47b85382f8bf 100644 --- a/jax/tools/jax_to_ir.py +++ b/jax/tools/jax_to_ir.py @@ -240,16 +240,12 @@ def parse_shape_str(s): _DT = { 'pred': jnp.bool_, - 'u4': jnp.uint4, 'u8': jnp.uint8, 'u16': jnp.uint16, 'u32': jnp.uint32, 'u64': jnp.uint64, - 's4': jnp.int4, 's8': jnp.int8, 's16': jnp.int16, 's32': jnp.int32, 's64': jnp.int64, + 'u2': jnp.uint2, 'u4': jnp.uint4, 'u8': jnp.uint8, 'u16': jnp.uint16, 'u32': jnp.uint32, 'u64': jnp.uint64, + 's2': jnp.int2, 's4': jnp.int4, 's8': jnp.int8, 's16': jnp.int16, 's32': jnp.int32, 's64': jnp.int64, 'bf16': jnp.bfloat16, 'f16': jnp.float16, 'f32': jnp.float32, 'f64': jnp.float64, 'c64': jnp.complex64, 'c128': jnp.complex128 } -if hasattr(jnp, 'int2'): - _DT['s2'] = jnp.int2 -if hasattr(jnp, 'uint2'): - _DT['u2'] = jnp.uint2 _SHAPE_RE = re.compile(f"^({'|'.join(_DT)})\\[\\s*(\\d*[\\s*,\\d+]*)\\s*\\]$") diff --git a/jax/tools/pgo_nsys_converter.py b/jax/tools/pgo_nsys_converter.py index 10209c9a85ba..3e961733f435 100644 --- a/jax/tools/pgo_nsys_converter.py +++ b/jax/tools/pgo_nsys_converter.py @@ -64,7 +64,7 @@ m = thunk_re.search(name) if m is not None: if args.post_process: - cost_dictionary.setdefault(m.group(1), []).append((time_ns/1000.0)) + cost_dictionary.setdefault(m.group(1), []).append(time_ns/1000.0) else: protofile.write(f'costs {{ name: "{m.group(1)}" cost_us: {time_ns / 1000.0} }}\n') if args.post_process: diff --git a/jax/tree.py b/jax/tree.py index 270c34fe9647..03ca503f3a41 100644 --- a/jax/tree.py +++ b/jax/tree.py @@ -19,6 +19,7 @@ from jax._src.tree import ( all as all, + broadcast as broadcast, flatten_with_path as flatten_with_path, flatten as flatten, leaves_with_path as leaves_with_path, diff --git a/jax/tree_util.py b/jax/tree_util.py index 956d79b9b4ef..ad864def3b44 100644 --- a/jax/tree_util.py +++ b/jax/tree_util.py @@ -48,16 +48,16 @@ PyTreeDef as PyTreeDef, SequenceKey as SequenceKey, all_leaves as all_leaves, - build_tree as build_tree, default_registry as default_registry, keystr as keystr, + register_dataclass as register_dataclass, register_pytree_node_class as register_pytree_node_class, register_pytree_node as register_pytree_node, register_pytree_with_keys_class as register_pytree_with_keys_class, - register_dataclass as register_dataclass, register_pytree_with_keys as register_pytree_with_keys, register_static as register_static, tree_all as tree_all, + tree_broadcast as tree_broadcast, tree_flatten_with_path as tree_flatten_with_path, tree_flatten as tree_flatten, tree_leaves_with_path as tree_leaves_with_path, @@ -72,3 +72,23 @@ treedef_is_leaf as treedef_is_leaf, treedef_tuple as treedef_tuple, ) + +_deprecations = { + # Added March 21, 2025: + "build_tree": ( + ( + "jax.tree_util.build_tree was deprecated in JAX v0.6.0 and removed in" + " JAX v0.7.0. Use jax.tree.unflatten instead." + ), + None + ), +} + +import typing as _typing +if _typing.TYPE_CHECKING: + from jax._src.tree_util import build_tree as build_tree +else: + from jax._src.deprecations import deprecation_getattr + __getattr__ = deprecation_getattr(__name__, _deprecations) + del deprecation_getattr +del _typing diff --git a/jax/typing.py b/jax/typing.py index 89efa1f2ca66..0530c69e60ca 100644 --- a/jax/typing.py +++ b/jax/typing.py @@ -15,7 +15,7 @@ """ The JAX typing module is where JAX-specific static type annotations live. 
This submodule is a work in progress; to see the proposal behind the types exported -here, see https://jax.readthedocs.io/en/latest/jep/12049-type-annotations.html. +here, see https://docs.jax.dev/en/latest/jep/12049-type-annotations.html. The currently-available types are: @@ -67,7 +67,7 @@ def my_function(x: ArrayLike) -> Array: batch-wise transforms like :func:`~jax.vmap` or :func:`jax.pmap`. For more information on this, see `Non-array inputs NumPy vs JAX`_ -.. _Non-array inputs NumPy vs JAX: https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#non-array-inputs-numpy-vs-jax +.. _Non-array inputs NumPy vs JAX: https://docs.jax.dev/en/latest/notebooks/Common_Gotchas_in_JAX.html#non-array-inputs-numpy-vs-jax """ from jax._src.typing import ( ArrayLike as ArrayLike, diff --git a/jax/util.py b/jax/util.py index 8071f77dffe2..1931a0293c09 100644 --- a/jax/util.py +++ b/jax/util.py @@ -15,19 +15,111 @@ # Note: import as is required for names to be exported. # See PEP 484 & https://github.com/jax-ml/jax/issues/7570 -from jax._src.util import ( - HashableFunction as HashableFunction, - as_hashable_function as as_hashable_function, - cache as cache, - safe_map as safe_map, - safe_zip as safe_zip, - split_dict as split_dict, - split_list as split_list, - split_list_checked as split_list_checked, - split_merge as split_merge, - subvals as subvals, - toposort as toposort, - unzip2 as unzip2, - wrap_name as wrap_name, - wraps as wraps, -) +import jax._src.deprecations +import jax._src.util + + +_deprecations = { + # Finalized in JAX v0.7.0; remove entries in JAX v0.8.0 + "to_dlpack": ( + ( + "jax.dlpack.to_dlpack was deprecated in JAX v0.6.0 and" + " removed in JAX v0.7.0. Please use the newer DLPack API based on" + " __dlpack__ and __dlpack_device__ instead. Typically, you can pass" + " a JAX array directly to the `from_dlpack` function of another" + " framework without using `to_dlpack`." + ), + None, + ), + "HashableFunction": ( + ( + "HashableFunction was deprecated in JAX v0.6.0 and removed" + " in JAX v0.7.0." + ), + None, + ), + "as_hashable_function": ( + ( + "as_hashable_function was deprecated in JAX v0.6.0 and" + " removed in JAX v0.7.0." + ), + None, + ), + "cache": ( + "cache was deprecated in JAX v0.6.0 and removed in JAX v0.7.0.", + None, + ), + "safe_map": ( + "safe_map was deprecated in JAX v0.6.0 and removed in JAX v0.7.0.", + None, + ), + "safe_zip": ( + ( + "safe_zip was deprecated in JAX v0.6.0 and removed in JAX v0.7.0." + ), + None, + ), + "split_dict": ( + "split_dict was deprecated in JAX v0.6.0 and removed in JAX v0.7.0.", + None, + ), + "split_list": ( + "split_list was deprecated in JAX v0.6.0 and removed in JAX v0.7.0.", + None, + ), + "split_list_checked": ( + ( + "split_list_checked was deprecated in JAX v0.6.0 and" + " removed in JAX v0.7.0." 
+ ), + None, + ), + "split_merge": ( + "split_merge was deprecated in JAX v0.6.0 and removed in JAX v0.7.0.", + None, + ), + "subvals": ( + "subvals was deprecated in JAX v0.6.0 and removed in JAX v0.7.0.", + None, + ), + "toposort": ( + "toposort was deprecated in JAX v0.6.0 and removed in JAX v0.7.0.", + None, + ), + "unzip2": ( + "unzip2 was deprecated in JAX v0.6.0 and removed in JAX v0.7.0.", + None, + ), + "wrap_name": ( + "wrap_name was deprecated in JAX v0.6.0 and removed in JAX v0.7.0.", + None, + ), + "wraps": ( + "wraps was deprecated in JAX v0.6.0 and removed in JAX v0.7.0.", + None, + ), +} + + +import typing as _typing + +if _typing.TYPE_CHECKING: + HashableFunction = jax._src.util.HashableFunction + as_hashable_function = jax._src.util.as_hashable_function + cache = jax._src.util.cache + safe_map = jax._src.util.safe_map + safe_zip = jax._src.util.safe_zip + split_dict = jax._src.util.split_dict + split_list = jax._src.util.split_list + split_list_checked = jax._src.util.split_list_checked + split_merge = jax._src.util.split_merge + subvals = jax._src.util.subvals + toposort = jax._src.util.toposort + unzip2 = jax._src.util.unzip2 + wrap_name = jax._src.util.wrap_name + wraps = jax._src.util.wraps +else: + __getattr__ = jax._src.deprecations.deprecation_getattr( + __name__, _deprecations + ) +del _typing diff --git a/jax/version.py b/jax/version.py index be20aca06358..e2d70eccc54e 100644 --- a/jax/version.py +++ b/jax/version.py @@ -21,7 +21,7 @@ import pathlib import subprocess -_version = "0.5.3" +_version = "0.6.3" # The following line is overwritten by build scripts in distributions & # releases. Do not modify this manually, or jax/jaxlib build will fail. _release_version: str | None = None @@ -93,6 +93,12 @@ def _get_version_for_build() -> str: return _version_from_git_tree(_version) or _version_from_todays_date(_version) +def _is_prerelease() -> bool: + """Determine if this is a pre-release ("rc" wheels) build.""" + rc_version = os.getenv("WHEEL_VERSION_SUFFIX", "") + return True if rc_version.startswith("rc") else False + + def _write_version(fname: str) -> None: """Used by setup.py to write the specified version info into the source tree.""" release_version = _get_version_for_build() @@ -146,7 +152,7 @@ def make_release_tree(self, base_dir, files): __version__ = _get_version_string() -_minimum_jaxlib_version = "0.5.1" +_minimum_jaxlib_version = "0.6.2" def _version_as_tuple(version_str): return tuple(int(i) for i in version_str.split(".") if i.isdigit()) diff --git a/jax_plugins/cuda/BUILD.bazel b/jax_plugins/cuda/BUILD.bazel index 1f4e5a08dcb9..7070bf6bc495 100644 --- a/jax_plugins/cuda/BUILD.bazel +++ b/jax_plugins/cuda/BUILD.bazel @@ -12,15 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-licenses(["notice"]) - load( - "//jaxlib:jax.bzl", - "if_windows", - "py_library_providing_imports_info", - "pytype_library", + "//jaxlib:jax.bzl", + "py_library_providing_imports_info", + "pytype_library", ) +licenses(["notice"]) + package( default_applicable_licenses = [], default_visibility = ["//:__subpackages__"], @@ -34,46 +33,27 @@ exports_files([ "setup.py", ]) +cc_binary( + name = "pjrt_c_api_gpu_plugin.so", + linkopts = [ + "-Wl,--version-script,$(location :gpu_version_script.lds)", + "-Wl,--no-undefined", + ], + linkshared = True, + deps = [ + ":gpu_version_script.lds", + "//jaxlib/mosaic/gpu:custom_call", + "@xla//xla/pjrt/c:pjrt_c_api_gpu", + "@xla//xla/service:gpu_plugin", + "@xla//xla/stream_executor:cuda_platform", + ], +) + py_library_providing_imports_info( name = "cuda_plugin", srcs = [ "__init__.py", ], - data = if_windows( - ["@xla//xla/pjrt/c/pjrt_c_api_gpu_plugin.pyd"], - ["@xla//xla/pjrt/c:pjrt_c_api_gpu_plugin.so"], - ), + data = [":pjrt_c_api_gpu_plugin.so"], lib_rule = pytype_library, ) - -config_setting( - name = "disable_jaxlib_for_cpu_build", - flag_values = { - "//jax:build_jaxlib": "false", - "@local_config_cuda//:enable_cuda": "False", - }, -) - -config_setting( - name = "disable_jaxlib_for_cuda12_build", - flag_values = { - "//jax:build_jaxlib": "false", - "@local_config_cuda//:enable_cuda": "True", - }, -) - -config_setting( - name = "enable_py_import_for_cpu_build", - flag_values = { - "//jax:build_jaxlib": "wheel", - "@local_config_cuda//:enable_cuda": "False", - }, -) - -config_setting( - name = "enable_py_import_for_cuda12_build", - flag_values = { - "//jax:build_jaxlib": "wheel", - "@local_config_cuda//:enable_cuda": "True", - }, -) diff --git a/jax_plugins/cuda/__init__.py b/jax_plugins/cuda/__init__.py index f6540e986024..de296a7a9e81 100644 --- a/jax_plugins/cuda/__init__.py +++ b/jax_plugins/cuda/__init__.py @@ -12,27 +12,41 @@ # See the License for the specific language governing permissions and # limitations under the License. +import ctypes import functools import importlib import logging import os import pathlib +import traceback +from typing import Any from jax._src.lib import triton from jax._src.lib import xla_client import jax._src.xla_bridge as xb -# cuda_plugin_extension locates inside jaxlib. `jaxlib` is for testing without -# preinstalled jax cuda plugin packages. -for pkg_name in ['jax_cuda12_plugin', 'jaxlib.cuda']: - try: - cuda_plugin_extension = importlib.import_module( - f'{pkg_name}.cuda_plugin_extension' - ) - except ImportError: - cuda_plugin_extension = None - else: - break +cuda_plugin_extension = None +cuda_versions = None + +def _import_extensions(): + global cuda_plugin_extension + global cuda_versions + + # cuda_plugin_extension locates inside jaxlib. `jaxlib` is for testing without + # preinstalled jax cuda plugin packages. 
+ for pkg_name in ['jax_cuda12_plugin', 'jaxlib.cuda']: + try: + cuda_plugin_extension = importlib.import_module( + f'{pkg_name}.cuda_plugin_extension' + ) + cuda_versions = importlib.import_module( + f'{pkg_name}._versions' + ) + except ImportError: + cuda_plugin_extension = None + cuda_versions = None + else: + break logger = logging.getLogger(__name__) @@ -51,7 +65,7 @@ def _get_library_path(): runfiles_dir = os.getenv('RUNFILES_DIR', None) if runfiles_dir: local_path = os.path.join( - runfiles_dir, 'xla/xla/pjrt/c/pjrt_c_api_gpu_plugin.so' + runfiles_dir, '__main__/jax_plugins/cuda/pjrt_c_api_gpu_plugin.so' ) if os.path.exists(local_path): @@ -76,11 +90,239 @@ def _get_library_path(): return None +def _load(module, libraries): + try: + m = importlib.import_module(f"nvidia.{module}") + except ImportError: + m = None + + for lib in libraries: + excs = [] + if m is not None: + path = pathlib.Path(m.__path__[0]) / "lib" / lib + try: + ctypes.cdll.LoadLibrary(path) + continue + except OSError as e: + excs.append(e) + + # TODO(phawkins): check the non-Python path here and error if not found. + # # Try again, without the Python module path. + # try: + # ctypes.cdll.LoadLibrary(lib) + # continue + # except OSError as e: + # excs.append(e) + # + # raise ExceptionGroup(f"Unable to load CUDA library {lib}", excs) # noqa: F821 + + +def _load_nvidia_libraries(): + """Attempts to load NVIDIA's libraries. + + We prefer the Python packages, if present. If not, we fall back to loading + them from LD_LIBRARY_PATH. By loading the libraries here, later lookups will + find these copies.""" + _load("cuda_runtime", ["libcudart.so.12"]) + # cuda_nvrtc isn't directly a dependency of JAX, but CUDNN appears to need it + # and at least in CUDA 12.9 has RUNPATHs misconfigured to refer to + # nvidia/nvrtc instead of nvidia/cuda_nvrtc. + _load("cuda_nvrtc", ["libnvrtc.so.12"]) + _load("cublas", ["libcublas.so.12", "libcublasLt.so.12"]) + _load("nccl", ["libnccl.so.2"]) + _load("cuda_cupti", ["libcupti.so.12"]) + _load("cusparse", ["libcusparse.so.12"]) + _load("cusolver", ["libcusolver.so.11"]) + _load("cufft", ["libcufft.so.11"]) + _load("nvshmem", ["libnvshmem_host.so.3"]) + _load("cudnn", ["libcudnn.so.9"]) + + +def _check_cuda_versions(raise_on_first_error: bool = False, + debug: bool = False): + assert cuda_versions is not None + results: list[dict[str, Any]] = [] + + def _make_msg(name: str, + runtime_version: int, + build_version: int, + min_supported: int, + debug_msg: bool = False): + if debug_msg: + return (f"Package: {name}\n" + f"Version JAX was built against: {build_version}\n" + f"Minimum supported: {min_supported}\n" + f"Installed version: {runtime_version}") + if min_supported: + req_str = (f"The local installation version must be no lower than " + f"{min_supported}.") + else: + req_str = ("The local installation must be the same version as " + "the version against which JAX was built.") + msg = (f"Outdated {name} installation found.\n" + f"Version JAX was built against: {build_version}\n" + f"Minimum supported: {min_supported}\n" + f"Installed version: {runtime_version}\n" + f"{req_str}") + return msg + + + def _version_check(name: str, + get_version, + get_build_version, + scale_for_comparison: int = 1, + min_supported_version: int = 0) -> int | None: + """Checks the runtime CUDA component version against the JAX one. + + Args: + name: Of the CUDA component. + get_version: A function to get the local runtime version of the component. 
+ get_build_version: A function to get the build version of the component. + scale_for_comparison: For rounding down a version to ignore patch/minor. + min_supported_version: An absolute minimum version required. Must be + passed without rounding down. + + Returns: the runtime version, or None if the component is not found. + + Raises: + RuntimeError: If the component is not found, or is of unsupported version, + and if raising the error is not deferred till later. + """ + + build_version = get_build_version() + try: + version = get_version() + except Exception as e: + err_msg = f"Unable to load {name}. Is it installed?" + if raise_on_first_error: + raise RuntimeError(err_msg) from e + err_msg += f"\n{traceback.format_exc()}" + results.append({"name": name, "installed": False, "msg": err_msg}) + return + + if not min_supported_version: + min_supported_version = build_version // scale_for_comparison + passed = min_supported_version <= version + + if not passed or debug: + msg = _make_msg(name=name, + runtime_version=version, + build_version=build_version, + min_supported=min_supported_version, + debug_msg=passed) + if not passed and raise_on_first_error: + raise RuntimeError(msg) + else: + record = {"name": name, + "installed": True, + "msg": msg, + "passed": passed, + "build_version": build_version, + "version": version, + "minimum_supported": min_supported_version} + results.append(record) + return version + + _version_check("CUDA", cuda_versions.cuda_runtime_get_version, + cuda_versions.cuda_runtime_build_version, + scale_for_comparison=10, + min_supported_version=12010) + cudnn_version = _version_check( + "cuDNN", + cuda_versions.cudnn_get_version, + cuda_versions.cudnn_build_version, + # NVIDIA promise both backwards and forwards compatibility for cuDNN patch + # versions: + # https://docs.nvidia.com/deeplearning/cudnn/backend/latest/developer/forward-compatibility.html#cudnn-api-compatibility + scale_for_comparison=100, + ) + _version_check("cuFFT", cuda_versions.cufft_get_version, + cuda_versions.cufft_build_version, + # Ignore patch versions. + scale_for_comparison=100) + # TODO(phawkins): for some reason this check fails with a cusolver internal + # error when fetching the version. This may be a path error from our stubs. + # Figure out what's happening here and re-enable. + # _version_check("cuSOLVER", cuda_versions.cusolver_get_version, + # cuda_versions.cusolver_build_version, + # # Ignore patch versions. + # scale_for_comparison=100, + # min_supported_version=11400) + _version_check("cuPTI", cuda_versions.cupti_get_version, + cuda_versions.cupti_build_version, + min_supported_version=18) + cublas_version = _version_check("cuBLAS", cuda_versions.cublas_get_version, + cuda_versions.cublas_build_version, + # Ignore patch versions. + scale_for_comparison=100, + min_supported_version=120100) + _version_check("cuSPARSE", cuda_versions.cusparse_get_version, + cuda_versions.cusparse_build_version, + # Ignore patch versions. + scale_for_comparison=100, + min_supported_version=12100) + + # https://docs.nvidia.com/deeplearning/cudnn/backend/latest/release-notes.html#cudnn-9-10-1 + if (cudnn_version is not None and cudnn_version == 91000 + and cuda_versions.cudnn_build_version() != 91000): + msg = ("cuDNN 9.10.0 had a binary backward-compatibility issue due to reordered enum " + f"values affecting block-scale datatypes. Found runtime version {cudnn_version} " + f"and build version {cuda_versions.cudnn_build_version()}. 
Please upgrade to " + "9.10.1 or above.") + if raise_on_first_error: + raise RuntimeError(msg) + else: + results.append({"installed": True, "msg": msg, "passed": False}) + # xb.local_device_count() cannot safely be called at this point + if xb.CUDA_VISIBLE_DEVICES.value == "all": + local_device_count = cuda_versions.cuda_device_count() + else: + local_device_count = len(xb.CUDA_VISIBLE_DEVICES.value.split(",")) + # https://docs.nvidia.com/deeplearning/cudnn/backend/latest/release-notes.html#cudnn-9-10-0 + if (cudnn_version is not None and cudnn_version < 91001 + and cublas_version is not None and cublas_version >= 120900 + and local_device_count > 1): + msg = (f"cuDNN < 9.10.0 ({cudnn_version} found) had an issue that caused some multi-GPU " + "matmuls, in which the same finalized execution plan is used across different " + f"GPUs, to be functionally incorrect when run with cublasLt >= 12.9 ({cublas_version} " + "found). Please upgrade to 9.10.1 or above.") + if raise_on_first_error: + raise RuntimeError(msg) + else: + results.append({"installed": True, "msg": msg, "passed": False}) + + errors = [] + debug_results = [] + for result in results: + message: str = result['msg'] + if not result['installed'] or not result['passed']: + errors.append(message) + else: + debug_results.append(message) + + join_str = f'\n{"-" * 50}\n' + if debug_results: + print(f'CUDA components status (debug):\n' + f'{join_str.join(debug_results)}') + if errors: + raise RuntimeError(f'Unable to use CUDA because of the ' + f'following issues with CUDA components:\n' + f'{join_str.join(errors)}') + + def initialize(): + _load_nvidia_libraries() + _import_extensions() path = _get_library_path() if path is None: return + if not os.getenv("JAX_SKIP_CUDA_CONSTRAINTS_CHECK"): + _check_cuda_versions(raise_on_first_error=True) + else: + print('Skipped CUDA versions constraints check due to the ' + 'JAX_SKIP_CUDA_CONSTRAINTS_CHECK env var being set.') + options = xla_client.generate_pjrt_gpu_plugin_options() c_api = xb.register_plugin( 'cuda', priority=500, library_path=str(path), options=options @@ -92,8 +334,10 @@ def initialize(): cuda_plugin_extension.register_custom_call_target, c_api ), ) - for _name, _value in cuda_plugin_extension.registrations().items(): - xla_client.register_custom_call_target(_name, _value, platform="CUDA") + for _name, _value in cuda_plugin_extension.ffi_registrations().items(): + xla_client.register_custom_call_target( + _name, _value, platform='CUDA', api_version=1 + ) xla_client.register_custom_type_id_handler( "CUDA", functools.partial( diff --git a/jaxlib/tools/gpu_version_script.lds b/jax_plugins/cuda/gpu_version_script.lds similarity index 100% rename from jaxlib/tools/gpu_version_script.lds rename to jax_plugins/cuda/gpu_version_script.lds diff --git a/jax_plugins/cuda/plugin_setup.py b/jax_plugins/cuda/plugin_setup.py index ce31684de46f..baa20f2419fc 100644 --- a/jax_plugins/cuda/plugin_setup.py +++ b/jax_plugins/cuda/plugin_setup.py @@ -49,15 +49,15 @@ def has_ext_modules(self): author="JAX team", author_email="jax-dev@google.com", packages=[package_name], - python_requires=">=3.10", + python_requires=">=3.11", install_requires=[f"jax-cuda{cuda_version}-pjrt=={__version__}"], extras_require={ - 'with_cuda': [ + 'with-cuda': [ "nvidia-cublas-cu12>=12.1.3.1", "nvidia-cuda-cupti-cu12>=12.1.105", "nvidia-cuda-nvcc-cu12>=12.6.85", "nvidia-cuda-runtime-cu12>=12.1.105", - "nvidia-cudnn-cu12>=9.1,<10.0", + "nvidia-cudnn-cu12>=9.8,<10.0", "nvidia-cufft-cu12>=11.0.2.54", 
"nvidia-cusolver-cu12>=11.4.5.107", "nvidia-cusparse-cu12>=12.1.0.106", @@ -70,15 +70,21 @@ def has_ext_modules(self): # Until NVIDIA add version constraints, add a version constraint # here. "nvidia-nvjitlink-cu12>=12.1.105", + # nvrtc is a transitive and undeclared dep of cudnn. + "nvidia-cuda-nvrtc-cu12>=12.1.55", + # NVSHMEM is used by Mosaic GPU collectives and can be used by XLA to + # speed up collectives too. + "nvidia-nvshmem-cu12>=3.2.5", ], }, url="https://github.com/jax-ml/jax", license="Apache-2.0", classifiers=[ - "Development Status :: 3 - Alpha", - "Programming Language :: Python :: 3.10", + "Development Status :: 5 - Production/Stable", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", + "Programming Language :: Python :: Free Threading :: 3 - Stable", ], package_data={ package_name: [ diff --git a/jax_plugins/cuda/setup.py b/jax_plugins/cuda/setup.py index 1ce555978dac..b2c89285e7fd 100644 --- a/jax_plugins/cuda/setup.py +++ b/jax_plugins/cuda/setup.py @@ -51,8 +51,9 @@ def load_version_module(pkg_path): url="https://github.com/jax-ml/jax", license="Apache-2.0", classifiers=[ - "Development Status :: 3 - Alpha", + "Development Status :: 5 - Production/Stable", "Programming Language :: Python :: 3", + "Programming Language :: Python :: Free Threading :: 3 - Stable", ], package_data={ package_name: ["xla_cuda_plugin.so"], diff --git a/jax_plugins/rocm/BUILD.bazel b/jax_plugins/rocm/BUILD.bazel index 6e265bcd18cf..7ee0726e7960 100644 --- a/jax_plugins/rocm/BUILD.bazel +++ b/jax_plugins/rocm/BUILD.bazel @@ -16,7 +16,6 @@ licenses(["notice"]) load( "//jaxlib:jax.bzl", - "if_windows", "py_library_providing_imports_info", "pytype_library", ) @@ -34,14 +33,26 @@ exports_files([ "setup.py", ]) +cc_binary( + name = "pjrt_c_api_gpu_plugin.so", + linkopts = [ + "-Wl,--version-script,$(location :gpu_version_script.lds)", + "-Wl,--no-undefined", + ], + linkshared = True, + deps = [ + ":gpu_version_script.lds", + "@xla//xla/pjrt/c:pjrt_c_api_gpu", + "@xla//xla/service:gpu_plugin", + "@xla//xla/stream_executor:rocm_platform", + ], +) + py_library_providing_imports_info( name = "rocm_plugin", srcs = [ "__init__.py", ], - data = if_windows( - ["@xla//xla/pjrt/c/pjrt_c_api_gpu_plugin.pyd"], - ["@xla//xla/pjrt/c:pjrt_c_api_gpu_plugin.so"], - ), + data = [":pjrt_c_api_gpu_plugin.so"], lib_rule = pytype_library, ) diff --git a/jax_plugins/rocm/__init__.py b/jax_plugins/rocm/__init__.py index c48a681bf337..cf2a625fa783 100644 --- a/jax_plugins/rocm/__init__.py +++ b/jax_plugins/rocm/__init__.py @@ -51,7 +51,7 @@ def _get_library_path(): runfiles_dir = os.getenv('RUNFILES_DIR', None) if runfiles_dir: local_path = pathlib.Path( - os.path.join(runfiles_dir, 'xla/xla/pjrt/c/pjrt_c_api_gpu_plugin.so') + os.path.join(runfiles_dir, '__main__/jax_plugins/rocm/pjrt_c_api_gpu_plugin.so') ) if local_path.exists(): @@ -92,8 +92,10 @@ def initialize(): rocm_plugin_extension.register_custom_call_target, c_api ), ) - for _name, _value in rocm_plugin_extension.registrations().items(): - xla_client.register_custom_call_target(_name, _value, platform="ROCM") + for _name, _value in rocm_plugin_extension.ffi_registrations().items(): + xla_client.register_custom_call_target( + _name, _value, platform='ROCM', api_version=1 + ) xla_client.register_custom_type_id_handler( "ROCM", functools.partial( diff --git a/jax_plugins/rocm/gpu_version_script.lds b/jax_plugins/rocm/gpu_version_script.lds new file mode 100644 index 
000000000000..cbac4549bde3 --- /dev/null +++ b/jax_plugins/rocm/gpu_version_script.lds @@ -0,0 +1,9 @@ +VERS_1.0 { + global: + extern "C" { + GetPjrtApi; + }; + + local: + *; +}; diff --git a/jax_plugins/rocm/plugin_setup.py b/jax_plugins/rocm/plugin_setup.py index d504d0a11666..aba9730b8baf 100644 --- a/jax_plugins/rocm/plugin_setup.py +++ b/jax_plugins/rocm/plugin_setup.py @@ -54,16 +54,15 @@ def has_ext_modules(self): author="Ruturaj4", author_email="Ruturaj.Vaidya@amd.com", packages=[package_name], - python_requires=">=3.9", + python_requires=">=3.11", install_requires=[f"jax-rocm{rocm_version}-pjrt=={__version__}"], url="https://github.com/jax-ml/jax", license="Apache-2.0", classifiers=[ "Development Status :: 3 - Alpha", - "Programming Language :: Python :: 3.9", - "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", ], package_data={ package_name: [ diff --git a/jaxlib/BUILD b/jaxlib/BUILD index a35eabc9a505..dd4b06b34bcd 100644 --- a/jaxlib/BUILD +++ b/jaxlib/BUILD @@ -16,9 +16,21 @@ load( "//jaxlib:jax.bzl", + "cc_proto_library", + "if_oss", + "jax_visibility", "nanobind_extension", - "py_library_providing_imports_info", + "proto_library", + "py_deps", + "py_strict_test", "pytype_library", + "pytype_strict_library", +) +load( + "//jaxlib:pywrap.bzl", + "nanobind_pywrap_extension", + "pywrap_binaries", + "pywrap_library", ) load("//jaxlib:symlink_files.bzl", "symlink_files") @@ -29,41 +41,32 @@ package( default_visibility = ["//jax:internal"], ) -# This makes xla_extension module accessible from jax._src.lib. -genrule( - name = "xla_extension_py", - outs = ["xla_extension.py"], - cmd = "echo 'from xla.xla.python.xla_extension import *\n' > $@", +package_group( + name = "xla_python", + includes = [ + "//jax:internal", + ], ) -py_library_providing_imports_info( +pytype_strict_library( name = "jaxlib", - srcs = [ - "gpu_common_utils.py", - "gpu_linalg.py", - "gpu_prng.py", - "gpu_rnn.py", - "gpu_solver.py", - "gpu_sparse.py", - "gpu_triton.py", - "hlo_helpers.py", - "init.py", - "lapack.py", - "plugin_support.py", - ":version", - ":xla_client", - ":xla_extension_py", - ], data = [":ffi_headers"], - lib_rule = pytype_library, deps = [ + ":_jax", + ":_pretty_printer", ":cpu_feature_guard", + ":jax", + ":jaxlib_files", ":utils", + ":weakref_lru_cache", + ":xla_client", "//jaxlib/cpu:_lapack", + "//jaxlib/cpu:_sparse", "//jaxlib/mlir", "//jaxlib/mlir:arithmetic_dialect", "//jaxlib/mlir:builtin_dialect", "//jaxlib/mlir:chlo_dialect", + "//jaxlib/mlir:control_flow_dialect", "//jaxlib/mlir:func_dialect", "//jaxlib/mlir:gpu_dialect", "//jaxlib/mlir:ir", @@ -79,22 +82,45 @@ py_library_providing_imports_info( "//jaxlib/mlir:sparse_tensor_dialect", "//jaxlib/mlir:stablehlo_dialect", "//jaxlib/mlir:vector_dialect", + "//jaxlib/mlir/_mlir_libs:register_jax_dialects", "//jaxlib/mosaic", + "//jaxlib/mosaic/python:gpu_dialect", + "//jaxlib/mosaic/python:tpu_dialect", "//jaxlib/triton", - "@xla//xla/python:xla_extension", + "@xla//xla/python:_profiler", ], ) -symlink_files( - name = "version", - srcs = ["//jax:version.py"], - dst = ".", - flatten = True, +pytype_library( + name = "jaxlib_files", + srcs = [ + "cpu_sparse.py", + "gpu_common_utils.py", + "gpu_linalg.py", + "gpu_prng.py", + "gpu_rnn.py", + "gpu_solver.py", + "gpu_sparse.py", + "gpu_triton.py", + "hlo_helpers.py", + "init.py", + "lapack.py", + "plugin_support.py", + "xla_client.py", + ":version", + ], + deps = [ + ":_jax", + 
"//jaxlib/cpu:_lapack", + "//jaxlib/cpu:_sparse", + "//jaxlib/mlir:ir", + "//jaxlib/mlir:stablehlo_dialect", + ], ) symlink_files( - name = "xla_client", - srcs = ["@xla//xla/python:xla_client"], + name = "version", + srcs = ["//jax:version.py"], dst = ".", flatten = True, ) @@ -111,6 +137,47 @@ exports_files([ "setup.py", ]) +pywrap_library( + name = "jax", + common_lib_def_files_or_filters = { + "jaxlib/jax_common": "jax_common.json", + }, + common_lib_version_scripts = { + "jaxlib/jax_common": select({ + "@bazel_tools//src/conditions:windows": None, + "@bazel_tools//src/conditions:darwin": "libjax_common_darwin.lds", + "//conditions:default": "libjax_common.lds", + }), + }, + deps = [ + ":_jax", + ":_pretty_printer", + ":utils", + ":weakref_lru_cache", + "//jaxlib/mlir/_mlir_libs:_chlo", + "//jaxlib/mlir/_mlir_libs:_mlir", + "//jaxlib/mlir/_mlir_libs:_mlirDialectsGPU", + "//jaxlib/mlir/_mlir_libs:_mlirDialectsLLVM", + "//jaxlib/mlir/_mlir_libs:_mlirDialectsNVGPU", + "//jaxlib/mlir/_mlir_libs:_mlirDialectsSparseTensor", + "//jaxlib/mlir/_mlir_libs:_mlirGPUPasses", + "//jaxlib/mlir/_mlir_libs:_mlirHlo", + "//jaxlib/mlir/_mlir_libs:_mlirSparseTensorPasses", + "//jaxlib/mlir/_mlir_libs:_mosaic_gpu_ext", + "//jaxlib/mlir/_mlir_libs:_sdy", + "//jaxlib/mlir/_mlir_libs:_stablehlo", + "//jaxlib/mlir/_mlir_libs:_tpu_ext", + "//jaxlib/mlir/_mlir_libs:_triton_ext", + "//jaxlib/mlir/_mlir_libs:register_jax_dialects", + "@xla//xla/python:_profiler", + ], +) + +pywrap_binaries( + name = "jaxlib_binaries", + dep = ":jax", +) + cc_library( name = "absl_status_casters", hdrs = ["absl_status_casters.h"], @@ -167,58 +234,1103 @@ cc_library( features = ["-use_header_modules"], deps = [ "@com_google_absl//absl/base", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", ], ) +# This isn't a CPU kernel. This exists to catch cases where jaxlib is built for the wrong +# target architecture. 
+nanobind_extension( + name = "cpu_feature_guard", + srcs = ["cpu_feature_guard.c"], + module_name = "cpu_feature_guard", + deps = [ + "@xla//third_party/python_runtime:headers", + ], +) + +nanobind_pywrap_extension( + name = "_pretty_printer", + srcs = ["_pretty_printer.cc"], + deps = [ + ":nb_class_ptr", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", + "@nanobind", + ], +) + +nanobind_pywrap_extension( + name = "weakref_lru_cache", + srcs = ["weakref_lru_cache.cc"], + pytype_srcs = ["weakref_lru_cache.pyi"], + deps = [ + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/cleanup", + "@com_google_absl//absl/hash", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", + "@nanobind", + "@xla//third_party/python_runtime:headers", + "@xla//xla/pjrt:lru_cache", + "@xla//xla/tsl/platform:logging", + ], +) + +py_strict_test( + name = "weakref_lru_cache_test", + srcs = ["weakref_lru_cache_test.py"], + deps = [ + ":weakref_lru_cache", + ] + py_deps([ + "absl/flags", + "absl/logging", + "absl/testing", + ]), +) + +nanobind_pywrap_extension( + name = "utils", + srcs = ["utils.cc"], + deps = [ + "@com_google_absl//absl/cleanup", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/synchronization", + "@nanobind", + "@xla//third_party/python_runtime:headers", + ], +) + +nanobind_pywrap_extension( + name = "_jax", + srcs = ["xla.cc"], + pytype_deps = py_deps(["numpy"]), + pytype_srcs = glob(["_jax/*.pyi"]), + visibility = ["//visibility:public"], + deps = [ + ":config", + ":custom_call_sharding", + ":dlpack", + ":ffi", + ":guard_lib", + ":ifrt_proxy", + ":jax_jit", + ":mlir", + ":nb_class_ptr", + ":pjit", + ":pmap_lib", + ":py_client", + ":python_ref_manager", + ":pytree", + ":sdy", + ":traceback", + ":util", + ":xla_compiler", + "@com_google_absl//absl/base", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/hash", + "@com_google_absl//absl/log:initialize", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/time", + "@com_google_absl//absl/types:span", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@nanobind", + "@tsl//tsl/platform", + "@xla//third_party/python_runtime:headers", # buildcleaner: keep + "@xla//xla:literal", + "@xla//xla:shape_util", + "@xla//xla:types", + "@xla//xla:util", + "@xla//xla/backends/cpu/collectives:cpu_collectives", + "@xla//xla/ffi:ffi_api", + "@xla//xla/hlo/builder/lib:approx_topk_shape", + "@xla//xla/pjrt:exceptions", + "@xla//xla/pjrt:mlir_to_hlo", + "@xla//xla/pjrt:pjrt_api", + "@xla//xla/pjrt:pjrt_c_api_client", + "@xla//xla/pjrt:pjrt_client", + "@xla//xla/pjrt:pjrt_common", + "@xla//xla/pjrt:pjrt_compiler", + "@xla//xla/pjrt:pjrt_executable", + "@xla//xla/pjrt:pjrt_layout", + "@xla//xla/pjrt:status_casters", + "@xla//xla/pjrt/c:pjrt_c_api_hdrs", + "@xla//xla/pjrt/c:pjrt_c_api_raw_buffer_external", + "@xla//xla/pjrt/distributed", + "@xla//xla/pjrt/distributed:client", + "@xla//xla/pjrt/distributed:key_value_store_interface", + "@xla//xla/pjrt/distributed:protocol_proto_cc", + 
"@xla//xla/pjrt/distributed:service", + "@xla//xla/pjrt/plugin/xla_cpu:cpu_client_options", + "@xla//xla/pjrt/plugin/xla_cpu:xla_cpu_pjrt_client", + "@xla//xla/python:logging", + "@xla//xla/python:nb_absl_flat_hash_map", + "@xla//xla/python:nb_absl_span", + "@xla//xla/python:pprof_profile_builder", + "@xla//xla/python:refine_polymorphic_shapes", + "@xla//xla/python:types", + "@xla//xla/python:version", + "@xla//xla/python/ifrt", + "@xla//xla/python/ifrt:plugin_program", + "@xla//xla/python/ifrt:plugin_program_serdes", + "@xla//xla/python/pjrt_ifrt", + "@xla//xla/python/pjrt_ifrt:pjrt_attribute_map_util", + "@xla//xla/python/pjrt_ifrt:xla_ifrt", + "@xla//xla/tsl/concurrency:ref_count", + "@xla//xla/tsl/distributed_runtime/preemption:preemption_sync_manager", + "@xla//xla/tsl/platform:logging", + "@xla//xla/tsl/platform:status", + "@xla//xla/tsl/platform:statusor", + "@xla//xla/tsl/platform/cloud:gcs_file_system", + "@xla//xla/tsl/python/lib/core:numpy", + ] + select({ + # gloo tcp transport only builds on linux + "@xla//xla/tsl:macos": [ + "@gloo//:transport_uv", + "@xla//xla/backends/cpu/collectives:gloo_collectives", + "@xla//xla/backends/cpu/collectives:gloo_kv_store", + ], + "@xla//xla/tsl:windows": [], + "//conditions:default": [ + ":py_socket_transfer", + "@gloo//:transport_tcp", + "@xla//xla/backends/cpu/collectives:gloo_collectives", + "@xla//xla/backends/cpu/collectives:gloo_kv_store", + ], + }) + select({ + # mpitrampoline does not build on windows + "@xla//xla/tsl:windows": [], + # we support MPI collectives only in OSS builds + "//conditions:default": if_oss(["@xla//xla/backends/cpu/collectives:mpi_collectives"]), + }), +) + cc_library( - name = "pass_boilerplate", - hdrs = ["pass_boilerplate.h"], - # compatible with libtpu + name = "callback", + srcs = [ + "callback.cc", + ], + hdrs = [ + "callback.h", + ], + compatible_with = [], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], deps = [ + ":python_ref_manager", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", + "@nanobind", + "@xla//third_party/python_runtime:headers", # buildcleaner: keep + "@xla//xla:comparison_util", + "@xla//xla:xla_data_proto_cc", + "@xla//xla/pjrt:host_callback", + "@xla//xla/pjrt:transpose", + "@xla//xla/python:nb_numpy", + "@xla//xla/tsl/platform:statusor", + "@xla//xla/tsl/python/lib/core:numpy", + ], +) + +cc_library( + name = "config", + srcs = ["config.cc"], + hdrs = ["config.h"], + compatible_with = [], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + deps = [ + ":python_ref_manager", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/types:span", + "@nanobind", + "@xla//third_party/python_runtime:headers", # buildcleaner: keep + "@xla//xla/tsl/platform:logging", + ], +) + +cc_library( + name = "custom_call_sharding", + srcs = ["custom_call_sharding.cc"], + hdrs = ["custom_call_sharding.h"], + compatible_with = [], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + deps = [ + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@nanobind", + 
"@xla//third_party/python_runtime:headers", + "@xla//xla:shape_util", + "@xla//xla:util", + "@xla//xla/hlo/ir:hlo", + "@xla//xla/hlo/utils:hlo_sharding_util", + "@xla//xla/pjrt:status_casters", + "@xla//xla/pjrt/c:pjrt_c_api_custom_partitioner_extension_hdrs", + "@xla//xla/pjrt/c:pjrt_c_api_hdrs", + "@xla//xla/pjrt/c:pjrt_c_api_helpers", + "@xla//xla/python:custom_call_batch_partitioner", + "@xla//xla/python:custom_partition_callback", + "@xla//xla/python:debug_callback_partitioner", + "@xla//xla/python:inspect_sharding", + "@xla//xla/tsl/platform:logging", + "@xla//xla/tsl/platform:statusor", + ], +) + +cc_library( + name = "dlpack_support", + srcs = ["dlpack_support.cc"], + hdrs = ["dlpack_support.h"], + compatible_with = [], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + deps = [ + "@com_google_absl//absl/status:statusor", + "@dlpack", + "@xla//xla:util", + "@xla//xla:xla_data_proto_cc", + ], +) + +cc_library( + name = "dlpack", + srcs = ["dlpack.cc"], + hdrs = ["dlpack.h"], + compatible_with = [], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + deps = [ + ":dlpack_support", + ":nb_class_ptr", + ":py_client", + ":python_ref_manager", + ":traceback", + ":util", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@dlpack", + "@llvm-project//llvm:Support", + "@nanobind", + "@xla//third_party/python_runtime:headers", # buildcleaner: keep + "@xla//xla:shape_util", + "@xla//xla:status_macros", + "@xla//xla:util", + "@xla//xla:xla_data_proto_cc", + "@xla//xla/pjrt:exceptions", + "@xla//xla/pjrt:pjrt_client", + "@xla//xla/pjrt:pjrt_common", + "@xla//xla/pjrt:pjrt_compiler", + "@xla//xla/python:types", + "@xla//xla/python/ifrt", + "@xla//xla/python/pjrt_ifrt", + "@xla//xla/tsl/platform:errors", + "@xla//xla/tsl/platform:logging", + "@xla//xla/tsl/platform:statusor", + ], +) + +cc_library( + name = "ffi", + srcs = ["ffi.cc"], + hdrs = ["ffi.h"], + compatible_with = [], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + deps = [ + ":dlpack_support", + "@com_google_absl//absl/base", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", + "@dlpack", + "@nanobind", + "@xla//third_party/python_runtime:headers", + "@xla//xla:xla_data_proto_cc", + "@xla//xla/ffi:ffi_api", + "@xla//xla/ffi/api:c_api", + "@xla//xla/ffi/api:ffi", + "@xla//xla/pjrt:host_callback", + "@xla//xla/pjrt:status_casters", + "@xla//xla/python:nb_numpy", + "@xla//xla/python:types", + "@xla//xla/tsl/platform:statusor", + ], +) + +cc_library( + name = "guard_lib", + srcs = ["guard_lib.cc"], + hdrs = ["guard_lib.h"], + compatible_with = [], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + deps = [ + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/functional:function_ref", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@nanobind", + "@xla//xla:util", + ], +) + +cc_library( + name = "ifrt_proxy", + srcs = ["ifrt_proxy.cc"], + hdrs = ["ifrt_proxy.h"], + compatible_with = [], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + deps = [ + ":nb_class_ptr", + ":py_client", + 
"@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/log:log_entry", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/time", + "@nanobind", + "@xla//xla/pjrt:status_casters", + "@xla//xla/python/ifrt", + "@xla//xla/python/ifrt:attribute_map", + "@xla//xla/python/ifrt_proxy/client:grpc_client", + "@xla//xla/python/ifrt_proxy/client:registry", + "@xla//xla/tsl/platform:env", + "@xla//xla/tsl/platform:statusor", + ], +) + +cc_library( + name = "jax_jit", + srcs = ["jax_jit.cc"], + hdrs = ["jax_jit.h"], + compatible_with = [], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + deps = [ + ":py_client", + ":python_ref_manager", + ":pytree", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/hash", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", + "@nanobind", + "@tsl//tsl/profiler/lib:traceme", + "@xla//third_party/python_runtime:headers", # build_cleaner: keep + "@xla//xla/pjrt:pjrt_client", + "@xla//xla/pjrt:pjrt_layout", + "@xla//xla/pjrt:status_casters", + "@xla//xla/python:nb_absl_inlined_vector", + "@xla//xla/python:nb_absl_span", + "@xla//xla/python:types", + "@xla//xla/tsl/platform:logging", + ], +) + +cc_library( + name = "mlir", + srcs = ["mlir.cc"], + hdrs = ["mlir.h"], + compatible_with = [], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + deps = [ + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:BytecodeWriter", "@llvm-project//mlir:IR", + "@llvm-project//mlir:Parser", "@llvm-project//mlir:Pass", + "@llvm-project//mlir:ReconcileUnrealizedCasts", "@llvm-project//mlir:Support", + "@nanobind", + "@stablehlo//:stablehlo_serialization", + "@xla//xla/hlo/builder:xla_computation", + "@xla//xla/hlo/translate:stablehlo", + "@xla//xla/mlir_hlo:mhlo_passes", + "@xla//xla/pjrt:mlir_to_hlo", + "@xla//xla/pjrt:status_casters", + "@xla//xla/python:refine_polymorphic_shapes", + "@xla//xla/service:hlo_proto_cc", + "@xla//xla/tsl/platform:errors", + "@xla//xla/tsl/platform:logging", + "@xla//xla/tsl/platform:statusor", ], ) cc_library( - name = "handle_pool", - hdrs = ["handle_pool.h"], + name = "nb_class_ptr", + hdrs = ["nb_class_ptr.h"], + copts = ["-fexceptions"], + features = ["-use_header_modules"], + visibility = jax_visibility("jaxlib/nb_class_ptr"), + deps = ["@nanobind"], +) + +cc_library( + name = "pjit", + srcs = ["pjit.cc"], + hdrs = ["pjit.h"], + compatible_with = [], copts = [ "-fexceptions", "-fno-strict-aliasing", ], features = ["-use_header_modules"], deps = [ + ":config", + ":guard_lib", + ":jax_jit", + ":nb_class_ptr", + ":py_client", + ":python_ref_manager", + ":pytree", + ":traceback", "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/cleanup", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/hash", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", 
+ "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/types:span", + "@nanobind", + "@tsl//tsl/profiler/lib:traceme", + "@xla//third_party/python_runtime:headers", # buildcleaner: keep + "@xla//xla:shape_util", + "@xla//xla:util", + "@xla//xla/pjrt:exceptions", + "@xla//xla/pjrt:lru_cache", + "@xla//xla/python:nb_helpers", + "@xla//xla/python:nb_numpy", + "@xla//xla/python/ifrt", + "@xla//xla/tsl/concurrency:ref_count", + "@xla//xla/tsl/platform:env", + "@xla//xla/tsl/platform:errors", + "@xla//xla/tsl/platform:logging", + "@xla//xla/tsl/platform:statusor", ], ) -# This isn't a CPU kernel. This exists to catch cases where jaxlib is built for the wrong -# target architecture. -nanobind_extension( - name = "cpu_feature_guard", - srcs = ["cpu_feature_guard.c"], - module_name = "cpu_feature_guard", +cc_library( + name = "pmap_lib", + srcs = ["pmap_lib.cc"], + hdrs = ["pmap_lib.h"], + compatible_with = [], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], deps = [ - "@xla//third_party/python_runtime:headers", + ":config", + ":jax_jit", + ":nb_class_ptr", + ":py_client", + ":python_ref_manager", + ":pytree", + ":traceback", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/hash", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/types:span", + "@nanobind", + "@tsl//tsl/profiler/lib:traceme", + "@xla//third_party/python_runtime:headers", # buildcleaner: keep + "@xla//xla:status_macros", + "@xla//xla:util", + "@xla//xla:xla_data_proto_cc", + "@xla//xla/pjrt:exceptions", + "@xla//xla/pjrt:status_casters", + "@xla//xla/python:nb_helpers", + "@xla//xla/python:nb_numpy", + "@xla//xla/python:safe_static_init", + "@xla//xla/python:types", + "@xla//xla/python/ifrt", + "@xla//xla/tsl/concurrency:ref_count", + "@xla//xla/tsl/platform:env", + "@xla//xla/tsl/platform:logging", + "@xla//xla/tsl/platform:statusor", + "@xla//xla/tsl/python/lib/core:numpy", ], ) -nanobind_extension( - name = "utils", - srcs = ["utils.cc"], - module_name = "utils", +cc_library( + name = "cached_py_object", + hdrs = ["cached_py_object.h"], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], deps = [ - "@com_google_absl//absl/cleanup", + "@com_google_absl//absl/functional:function_ref", + "@nanobind", + ], +) + +cc_library( + name = "py_client", + srcs = [ + "partition_spec.cc", + "py_array.cc", + "py_client.cc", + "py_compile_only_client.cc", + "py_device.cc", + "py_device_list.cc", + "py_executable.cc", + "py_memory_space.cc", + "py_program.cc", + "py_values.cc", + "sharding.cc", + "to_ifrt_sharding.cc", + ], + hdrs = [ + "partition_spec.h", + "py_array.h", + "py_client.h", + "py_compile_only_client.h", + "py_device.h", + "py_device_list.h", + "py_executable.h", + "py_memory_space.h", + "py_program.h", + "py_values.h", + "sharded_device_array.h", + "sharding.h", + "to_ifrt_sharding.h", + ], + compatible_with = [], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + visibility = jax_visibility("jaxlib/py_client"), + deps = [ + ":cached_py_object", + ":guard_lib", + ":nb_class_ptr", + ":py_client_cpu", + ":py_host_callback", + 
":python_ref_manager", + ":traceback", + ":util", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/functional:any_invocable", + "@com_google_absl//absl/hash", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:cord", + "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/types:span", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", "@nanobind", - "@xla//third_party/python_runtime:headers", + "@tsl//tsl/platform:fingerprint", + "@tsl//tsl/platform:ml_dtypes", + "@tsl//tsl/profiler/lib:traceme", + "@xla//third_party/python_runtime:headers", # buildcleaner: keep + "@xla//xla:literal", + "@xla//xla:shape_util", + "@xla//xla:status_macros", + "@xla//xla:types", + "@xla//xla:util", + "@xla//xla:xla_data_proto_cc", + "@xla//xla/hlo/ir:hlo", + "@xla//xla/pjrt:exceptions", + "@xla//xla/pjrt:lru_cache", + "@xla//xla/pjrt:mlir_to_hlo", + "@xla//xla/pjrt:pjrt_client", + "@xla//xla/pjrt:pjrt_compiler", + "@xla//xla/pjrt:pjrt_executable", + "@xla//xla/pjrt:pjrt_future", + "@xla//xla/pjrt:pjrt_layout", + "@xla//xla/pjrt:status_casters", + "@xla//xla/python:nb_absl_span", + "@xla//xla/python:nb_helpers", + "@xla//xla/python:nb_numpy", + "@xla//xla/python:pprof_profile_builder", + "@xla//xla/python:safe_static_init", + "@xla//xla/python:types", + "@xla//xla/python:version", + "@xla//xla/python/compile_only_ifrt:client", + "@xla//xla/python/ifrt", + "@xla//xla/python/ifrt:attribute_map", + "@xla//xla/python/ifrt:custom_call_program", + "@xla//xla/python/ifrt:plugin_program", + "@xla//xla/python/ifrt:plugin_program_serdes", + "@xla//xla/python/ifrt:user_context", + "@xla//xla/python/ifrt/hlo:hlo_program", + "@xla//xla/python/pjrt_ifrt", + "@xla//xla/python/pjrt_ifrt:pjrt_dtype", + "@xla//xla/python/pjrt_ifrt:xla_ifrt", + "@xla//xla/service:platform_util", + "@xla//xla/tsl/concurrency:ref_count", + "@xla//xla/tsl/framework:allocator", + "@xla//xla/tsl/platform:env", + "@xla//xla/tsl/platform:errors", + "@xla//xla/tsl/platform:logging", + "@xla//xla/tsl/platform:status", + "@xla//xla/tsl/platform:statusor", + "@xla//xla/tsl/python/lib/core:numpy", + ], +) + +cc_library( + name = "py_client_cpu", + srcs = ["py_client_cpu.cc"], + hdrs = ["py_client_cpu.h"], + compatible_with = [], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + deps = [ + ":ffi", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + "@dlpack", + "@nanobind", + "@xla//third_party/python_runtime:headers", # buildcleaner: keep + "@xla//xla:shape_util", + "@xla//xla:util", + "@xla//xla:xla_data_proto_cc", + "@xla//xla/ffi:ffi_api", + "@xla//xla/ffi/api:ffi", + "@xla//xla/pjrt:host_callback", + "@xla//xla/pjrt:transpose", + "@xla//xla/python:nb_numpy", + "@xla//xla/python:types", + ], + alwayslink = 1, +) + +cc_library( + name = "py_host_callback", + srcs = 
["py_host_callback.cc"], + hdrs = ["py_host_callback.h"], + compatible_with = [], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + deps = [ + ":callback", + ":py_host_callback_cc_proto", + ":python_ref_manager", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@llvm-project//llvm:Support", + "@nanobind", + "@xla//xla:shape_util", + "@xla//xla:status_macros", + "@xla//xla:util", + "@xla//xla:xla_data_proto_cc", + "@xla//xla/pjrt:host_callback", + "@xla//xla/python:types", + "@xla//xla/python/ifrt", + "@xla//xla/python/pjrt_ifrt", + "@xla//xla/python/pjrt_ifrt:xla_host_callback_proto_cc", + "@xla//xla/tsl/concurrency:ref_count", + "@xla//xla/tsl/platform:statusor", + ], +) + +proto_library( + name = "py_host_callback_proto", + srcs = ["py_host_callback.proto"], +) + +cc_proto_library( + name = "py_host_callback_cc_proto", + visibility = jax_visibility("jaxlib/py_host_callback_cc_proto"), + deps = [":py_host_callback_proto"], +) + +cc_library( + name = "py_socket_transfer", + srcs = ["py_socket_transfer.cc"], + hdrs = ["py_socket_transfer.h"], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + deps = [ + ":nb_class_ptr", + ":py_client", + ":traceback", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", + "@llvm-project//llvm:Support", + "@nanobind", + "@tsl//tsl/platform:casts", + "@xla//xla:util", + "@xla//xla/pjrt:pjrt_client", + "@xla//xla/pjrt:status_casters", + "@xla//xla/python:nb_numpy", + "@xla//xla/python:types", + "@xla//xla/python:version", + "@xla//xla/python/ifrt", + "@xla//xla/python/pjrt_ifrt", + "@xla//xla/python/pjrt_ifrt:pjrt_dtype", + "@xla//xla/python/transfer:event_loop", + "@xla//xla/python/transfer:socket-server", + "@xla//xla/python/transfer:socket_bulk_transport", + "@xla//xla/python/transfer:streaming", + "@xla//xla/python/transfer:streaming_ifrt", + "@xla//xla/python/transfer:transfer_socket_proto_cc", + "@xla//xla/tsl/concurrency:ref_count", + "@xla//xla/tsl/platform:statusor", + ], +) + +cc_library( + name = "python_ref_manager", + srcs = ["python_ref_manager.cc"], + hdrs = ["python_ref_manager.h"], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + visibility = jax_visibility("jaxlib/python_ref_manager"), + deps = [ + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/types:span", + "@nanobind", + "@tsl//tsl/profiler/lib:traceme", + "@xla//third_party/python_runtime:headers", # buildcleaner: keep + ], +) + +proto_library( + name = "pytree_proto", + srcs = ["pytree.proto"], +) + +cc_proto_library( + name = "pytree_cc_proto", + deps = [":pytree_proto"], +) + +cc_library( + name = "pytree", + srcs = ["pytree.cc"], + hdrs = ["pytree.h"], + compatible_with = [], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + visibility = jax_visibility("jaxlib/pytree"), + deps = [ + ":nb_class_ptr", + ":pytree_cc_proto", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + 
"@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/hash", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", + "@nanobind", + "@xla//third_party/python_runtime:headers", # buildcleaner: keep + "@xla//xla/pjrt:exceptions", + "@xla//xla/tsl/platform:logging", + ], +) + +cc_library( + name = "sdy", + srcs = ["sdy.cc"], + hdrs = ["sdy.h"], + compatible_with = [], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + deps = [ + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:BytecodeWriter", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@nanobind", + "@shardy//shardy/dialect/sdy/ir:dialect", + "@xla//xla/hlo/translate/hlo_to_mhlo:hlo_to_mlir_hlo", + "@xla//xla/mlir_hlo:all_passes", + "@xla//xla/pjrt:mlir_to_hlo", + "@xla//xla/pjrt:status_casters", + "@xla//xla/service/spmd/shardy:constants", + "@xla//xla/service/spmd/shardy:utils", + "@xla//xla/service/spmd/shardy/sdy_round_trip:import_shardy_attrs", + "@xla//xla/service/spmd/shardy/sdy_round_trip:pipelines", + "@xla//xla/tsl/framework/mlir:status_scoped_diagnostic_handler", + ], +) + +cc_library( + name = "traceback", + srcs = ["traceback.cc"], + hdrs = ["traceback.h"], + compatible_with = [], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + visibility = jax_visibility("jaxlib/traceback"), + deps = [ + ":nb_class_ptr", + "@com_google_absl//absl/base", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/hash", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@nanobind", + "@tsl//tsl/platform", + "@xla//third_party/python_runtime:headers", # buildcleaner: keep + "@xla//xla/pjrt:exceptions", + "@xla//xla/python:nb_helpers", + "@xla//xla/tsl/platform:logging", + ], +) + +cc_library( + name = "util", + srcs = ["util.cc"], + hdrs = ["util.h"], + compatible_with = [], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", ], + features = ["-use_header_modules"], + deps = [ + "@com_google_absl//absl/status", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/time", + "@com_google_absl//absl/types:span", + "@nanobind", + "@xla//xla:util", + "@xla//xla/pjrt:pjrt_future", + "@xla//xla/python:version", + "@xla//xla/python/ifrt", + "@xla//xla/tsl/concurrency:async_value", + "@xla//xla/tsl/concurrency:ref_count", + ], +) + +cc_library( + name = "xla_compiler", + srcs = ["xla_compiler.cc"], + hdrs = ["xla_compiler.h"], + compatible_with = [], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + deps = [ + ":dlpack", + ":py_client", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/hash", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/types:span", + "@llvm-project//mlir:Support", + "@nanobind", + "@xla//xla:array", + "@xla//xla:debug_options_flags", + "@xla//xla:literal", + "@xla//xla:shape_util", + "@xla//xla:util", + "@xla//xla:xla_data_proto_cc", + "@xla//xla:xla_proto_cc", + 
"@xla//xla/client:executable_build_options", + "@xla//xla/ffi", + "@xla//xla/ffi:ffi_api", + "@xla//xla/ffi/api:c_api", + "@xla//xla/hlo/builder:xla_computation", + "@xla//xla/hlo/ir:hlo", + "@xla//xla/hlo/parser:hlo_parser", + "@xla//xla/pjrt:exceptions", + "@xla//xla/pjrt:pjrt_executable", + "@xla//xla/pjrt:status_casters", + "@xla//xla/pjrt/proto:compile_options_proto_cc", + "@xla//xla/python:nb_absl_span", + "@xla//xla/python:nb_helpers", + "@xla//xla/python:nb_numpy", + "@xla//xla/python:types", + "@xla//xla/service:computation_placer", + "@xla//xla/service:custom_call_target_registry", + "@xla//xla/service:hlo_graph_dumper", + "@xla//xla/service:hlo_module_config", + "@xla//xla/service:hlo_proto_cc", + "@xla//xla/service:name_uniquer", + "@xla//xla/service/spmd/shardy/stablehlo_round_trip:stablehlo_import", + "@xla//xla/tsl/lib/strings:proto_serialization", + "@xla//xla/tsl/platform:env", + "@xla//xla/tsl/platform:errors", + "@xla//xla/tsl/platform:logging", + "@xla//xla/tsl/platform:statusor", + ], +) + +pytype_strict_library( + name = "xla_client", + srcs = ["xla_client.py"], + pytype_srcs = ["xla_client.pyi"], + visibility = [":xla_python"], + deps = py_deps([ + "numpy", + "ml_dtypes", + ]) + [":_jax"], +) + +py_strict_test( + name = "pytree_test", + srcs = ["pytree_test.py"], + deps = [ + ":xla_client", + ] + py_deps([ + "absl/flags", + "absl/logging", + "absl/testing", + ]), +) + +py_strict_test( + name = "config_test", + srcs = ["config_test.py"], + deps = [ + ":xla_client", + ] + py_deps([ + "absl/flags", + "absl/logging", + "absl/testing", + ]), +) + +py_strict_test( + name = "jax_jit_test", + srcs = ["jax_jit_test.py"], + deps = [ + ":xla_client", + ] + py_deps([ + "absl/flags", + "absl/logging", + "absl/testing", + "numpy", + ]), ) diff --git a/jaxlib/_jax/__init__.pyi b/jaxlib/_jax/__init__.pyi new file mode 100644 index 000000000000..afa9e633a391 --- /dev/null +++ b/jaxlib/_jax/__init__.pyi @@ -0,0 +1,1026 @@ +# Copyright 2021 The JAX Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from __future__ import annotations + +import builtins +from collections.abc import Callable, Iterator, Mapping, Sequence, Set +import enum +import inspect +import types +from typing import Any, ClassVar, TypeVar, overload + +import numpy as np + +from . import config as config +from . import ffi as ffi +from . import guard_lib as guard_lib +from . import ifrt_programs as ifrt_programs +from . import ifrt_proxy as ifrt_proxy +from . import jax_jit as jax_jit +from . import mlir as mlir +from . import pmap_lib as pmap_lib +from . import profiler as profiler +from . import pytree as pytree +from . 
import transfer_guard_lib as transfer_guard_lib + +custom_call_targets = Any +hlo_sharding_util = Any + +_LiteralSlice = Any +_Status = Any +_Dtype = Any + +ifrt_version_number: int + +_T = TypeVar("_T") + +class XlaRuntimeError(RuntimeError): + pass + +class PrimitiveType(enum.IntEnum): + PRIMITIVE_TYPE_INVALID = ... + PRED = ... + S2 = ... + S4 = ... + S8 = ... + S16 = ... + S32 = ... + S64 = ... + U2 = ... + U4 = ... + U8 = ... + U16 = ... + U32 = ... + U64 = ... + F4E2M1FN = ... + F8E3M4 = ... + F8E4M3 = ... + F8E4M3FN = ... + F8E4M3B11FNUZ = ... + F8E4M3FNUZ = ... + F8E5M2 = ... + F8E5M2FNUZ = ... + F8E8M0FNU = ... + BF16 = ... + F16 = ... + F32 = ... + F64 = ... + C64 = ... + C128 = ... + TUPLE = ... + OPAQUE_TYPE = ... + TOKEN = ... + +# === BEGIN xla_compiler.cc + +class ArrayCopySemantics(enum.IntEnum): + ALWAYS_COPY = ... + REUSE_INPUT = ... + DONATE_INPUT = ... + +class Layout: + @overload + def __init__(self, minor_to_major: tuple[int, ...]): ... + @overload + def __init__( + self, + minor_to_major: tuple[int, ...], + tiling: tuple[tuple[int, ...], ...], + element_size_in_bits: int, + ): ... + def minor_to_major(self) -> tuple[int, ...]: ... + def tiling(self) -> Sequence[tuple[int, ...]]: ... + def element_size_in_bits(self) -> int: ... + def to_string(self) -> str: ... + def __eq__(self, other: Any) -> bool: ... + def __ne__(self, other: Any) -> bool: ... + def __hash__(self) -> int: ... + +class Shape: + def __init__(self, s: str): ... + @staticmethod + def tuple_shape(shapes: Sequence[Shape]) -> Shape: ... + @staticmethod + def array_shape( + type: np.dtype | PrimitiveType, + dims_seq: Any = ..., + layout_seq: Any = ..., + dynamic_dimensions: list[bool] | None = ..., + ) -> Shape: ... + @staticmethod + def token_shape() -> Shape: ... + @staticmethod + def scalar_shape(type: np.dtype | PrimitiveType) -> Shape: ... + def dimensions(self) -> tuple[int, ...]: ... + def layout(self) -> Layout: ... + def xla_element_type(self) -> PrimitiveType: ... + def element_type(self) -> np.dtype: ... + def numpy_dtype(self) -> np.dtype: ... + def is_tuple(self) -> bool: ... + def is_array(self) -> bool: ... + def is_token(self) -> bool: ... + def is_static(self) -> bool: ... + def is_dynamic(self) -> bool: ... + def is_dynamic_dimension(self, dimension: int) -> bool: ... + def set_dynamic_dimension(self, dimension: int, is_dynamic: bool) -> None: ... + def rank(self) -> int: ... + def to_serialized_proto(self) -> bytes: ... + def tuple_shapes(self) -> list[Shape]: ... + def leaf_count(self) -> int: ... + def with_major_to_minor_layout_if_absent(self) -> Shape: ... + def __eq__(self, other: Any) -> bool: ... + def __ne__(self, other: Any) -> bool: ... + def __hash__(self) -> int: ... + def __repr__(self) -> str: ... + +class ProgramShape: + def __init__(self, params: Sequence[Shape], result: Shape) -> None: ... + def parameter_shapes(self) -> list[Shape]: ... + def result_shape(self) -> Shape: ... + def __repr__(self) -> str: ... + +class Literal: + def __init__(self, shape: Shape) -> None: ... + def __repr__(self) -> str: ... + def __array__( + self, dtype: np.dtype | None = None, copy: bool | None = None + ) -> np.ndarray: ... + def shape(self) -> Shape: ... + +class XlaComputation: + def __init__(self, serialized_hlo_module_proto: bytes) -> None: ... + def get_hlo_module(self) -> HloModule: ... + def program_shape(self) -> ProgramShape: ... + def as_serialized_hlo_module_proto(self) -> bytes: ... + def as_hlo_text(self, print_large_constants: bool = False) -> str: ... 
+ def as_hlo_dot_graph(self) -> str: ... + def hash(self) -> int: ... + def as_hlo_module(self) -> HloModule: ... + +class HloPrintOptions: + def __init__(self) -> None: ... + @staticmethod + def short_parsable() -> HloPrintOptions: ... + @staticmethod + def canonical() -> HloPrintOptions: ... + @staticmethod + def fingerprint() -> HloPrintOptions: ... + print_large_constants: bool + print_metadata: bool + print_backend_config: bool + print_result_shape: bool + print_operand_shape: bool + print_operand_names: bool + print_ids: bool + print_extra_attributes: bool + print_program_shape: bool + print_percent: bool + print_control_dependencies: bool + compact_operands: bool + include_layout_in_shapes: bool + canonicalize_instruction_names: bool + canonicalize_computations: bool + indent_amount: int + is_in_nested_computation: bool + +class HloComputation: + def render_html(self) -> None: ... + +class HloModule: + spmd_output_sharding: OpSharding | None + spmd_parameters_shardings: list[OpSharding] | None + @property + def name(self) -> str: ... + def to_string(self, options: HloPrintOptions = ...) -> str: ... + def as_serialized_hlo_module_proto(self) -> bytes: ... + @staticmethod + def from_serialized_hlo_module_proto( + serialized_hlo_module_proto: bytes, + ) -> HloModule: ... + def computations(self) -> list[HloComputation]: ... + +class HloModuleGroup: + def __init__(self, name: str, modules: list[HloModule]) -> None: ... + @property + def name(self) -> str: ... + def to_string(self) -> str: ... + def to_modules(self) -> list[HloModule]: ... + +def hlo_module_to_dot_graph(hlo_module: HloModule) -> str: ... +def hlo_module_from_text(hlo_module_text: str) -> HloModule: ... +def hlo_module_cost_analysis( + client: Client, module: HloModule +) -> dict[str, float]: ... + +class DeviceAssignment: + @staticmethod + def create(array: np.ndarray) -> DeviceAssignment: ... + def replica_count(self) -> int: ... + def computation_count(self) -> int: ... + def __repr__(self) -> str: ... + def serialize(self) -> bytes: ... + +class CompileOptions: + @staticmethod + def ParseFromString(s: bytes) -> CompileOptions: ... + def __init__(self) -> None: ... + def SerializeAsString(self) -> bytes: ... + argument_layouts: list[Shape] | None + parameter_is_tupled_arguments: bool + executable_build_options: ExecutableBuildOptions + tuple_arguments: bool + num_replicas: int + num_partitions: int + profile_version: int + device_assignment: DeviceAssignment | None + compile_portable_executable: bool + env_option_overrides: list[tuple[str, str]] + +def register_custom_call_target( + fn_name: str, + capsule: Any, + platform: str, + api_version: int = ..., +) -> _Status: ... +def register_custom_call_partitioner( + name: str, + prop_user_sharding: Callable, + partition: Callable, + infer_sharding_from_operands: Callable, + can_side_effecting_have_replicated_sharding: bool = ..., + c_api: Any | None = ..., +) -> None: ... +def encode_inspect_sharding_callback(handler: Any) -> bytes: ... +def register_custom_call_as_batch_partitionable( + target_name: str, + c_api: Any | None = ..., +) -> None: ... +def register_custom_type_id(type_name: str, type_id: Any) -> None: ... + +class AutotuneCacheMode(enum.IntEnum): + UNSPECIFIED = ... + UPDATE = ... + READ = ... + +class DebugOptions: + def __repr__(self) -> str: ... 
+ xla_cpu_enable_fast_math: bool + xla_cpu_fast_math_honor_infs: bool + xla_cpu_fast_math_honor_nans: bool + xla_cpu_fast_math_honor_division: bool + xla_cpu_fast_math_honor_functions: bool + xla_gpu_enable_fast_min_max: bool + xla_backend_optimization_level: int + xla_cpu_enable_xprof_traceme: bool + xla_llvm_disable_expensive_passes: bool + xla_test_all_input_layouts: bool + xla_disable_hlo_passes: str + xla_enable_hlo_passes_only: str + xla_force_host_platform_device_count: int + xla_dump_to: str + xla_dump_hlo_module_re: str + xla_dump_hlo_pass_re: str + xla_dump_hlo_as_text: bool + xla_dump_hlo_as_proto: bool + xla_dump_hlo_as_dot: bool + xla_dump_hlo_as_url: bool + xla_dump_hlo_as_html: bool + xla_dump_fusion_visualization: bool + xla_dump_hlo_snapshots: bool + xla_dump_max_hlo_modules: bool + xla_dump_module_metadata: bool + xla_dump_compress_protos: bool + xla_dump_hlo_as_long_text: bool + xla_dump_disable_metadata: bool + xla_dump_hlo_pipeline_re: str + xla_gpu_cuda_data_dir: str + xla_detailed_logging: bool + xla_enable_dumping: bool + xla_gpu_dump_autotune_results_to: str + xla_gpu_load_autotune_results_from: str + xla_gpu_dump_autotune_logs_to: str + xla_gpu_kernel_cache_file: str + xla_gpu_enable_llvm_module_compilation_parallelism: bool + xla_gpu_per_fusion_autotune_cache_dir: str + xla_gpu_experimental_autotune_cache_mode: AutotuneCacheMode + +class CompiledMemoryStats: + generated_code_size_in_bytes: int + argument_size_in_bytes: int + output_size_in_bytes: int + alias_size_in_bytes: int + temp_size_in_bytes: int + host_generated_code_size_in_bytes: int + host_argument_size_in_bytes: int + host_output_size_in_bytes: int + host_alias_size_in_bytes: int + host_temp_size_in_bytes: int + serialized_buffer_assignment_proto: bytes + def __str__(self) -> str: ... + +class ExecutableBuildOptions: + def __init__(self) -> None: ... + def __repr__(self) -> str: ... + result_layout: Shape | None + fdo_profile: bytes | None + num_replicas: int + num_partitions: int + debug_options: DebugOptions + device_assignment: DeviceAssignment | None + use_spmd_partitioning: bool + use_auto_spmd_partitioning: bool + auto_spmd_partitioning_mesh_shape: list[int] + auto_spmd_partitioning_mesh_ids: list[int] + use_shardy_partitioner: bool + def compilation_environments_from_serialized_proto( + self, serialized_proto: bytes + ) -> None: ... + +class OpSharding_Type(enum.IntEnum): + REPLICATED = ... + MAXIMAL = ... + TUPLE = ... + OTHER = ... + MANUAL = ... + UNKNOWN = ... + +class OpSharding_ShardGroupType(enum.IntEnum): + AS = ... + LIKE = ... + +class OpSharding: + Type: type[OpSharding_Type] + type: OpSharding_Type + replicate_on_last_tile_dim: bool + last_tile_dims: Sequence[OpSharding_Type] + tile_assignment_dimensions: Sequence[int] + tile_assignment_devices: Sequence[int] + iota_reshape_dims: Sequence[int] + iota_transpose_perm: Sequence[int] + tuple_shardings: Sequence[OpSharding] + is_shard_group: bool + shard_group_id: int + ShardGroupType: builtins.type[OpSharding_ShardGroupType] + shard_group_type: OpSharding_ShardGroupType + def ParseFromString(self, s: bytes) -> None: ... + def SerializeToString(self) -> bytes: ... + def clone(self) -> OpSharding: ... + +class HloSharding: + @staticmethod + def from_proto(proto: OpSharding) -> HloSharding: ... + @staticmethod + def from_string(sharding: str) -> HloSharding: ... + @staticmethod + def tuple_sharding( + shape: Shape, shardings: Sequence[HloSharding] + ) -> HloSharding: ... 
+ @staticmethod + def iota_tile( + dims: Sequence[int], + reshape_dims: Sequence[int], + transpose_perm: Sequence[int], + subgroup_types: Sequence[OpSharding_Type], + ) -> HloSharding: ... + @staticmethod + def replicate() -> HloSharding: ... + @staticmethod + def manual() -> HloSharding: ... + @staticmethod + def unknown() -> HloSharding: ... + @staticmethod + def subgroup_with_device_ordering( + tile_assignment: np.ndarray, subgroup_types: Sequence[OpSharding_Type] + ) -> HloSharding: ... + def __eq__(self, other: Any) -> bool: ... + def __hash__(self) -> int: ... + def __repr__(self) -> str: ... + def tile(self, shape: Shape) -> Shape: ... + def is_replicated(self) -> bool: ... + def is_manual(self) -> bool: ... + def is_unknown(self) -> bool: ... + def is_tiled(self) -> bool: ... + def is_maximal(self) -> bool: ... + def tuple_elements(self) -> list[HloSharding]: ... + def num_devices(self) -> int: ... + def num_dimensions(self) -> int: ... + def is_tile_assignment_iota(self) -> bool: ... + def tile_assignment_dimensions(self) -> Sequence[int]: ... + def tile_assignment_devices(self) -> Sequence[int]: ... + def subgroup_types(self) -> Sequence[OpSharding_Type]: ... + def replicate_on_last_tile_dim(self) -> bool: ... + def to_proto(self) -> OpSharding: ... + def get_axis_sizes(self) -> list[int]: ... + +# === END xla_compiler.cc + +class Device: + id: int + host_id: int + process_index: int + platform: str + device_kind: str + client: Client + local_hardware_id: int | None + def __repr__(self) -> str: ... + def __str__(self) -> str: ... + def transfer_to_infeed(self, literal: _LiteralSlice): ... + def transfer_from_outfeed(self, shape: Shape): ... + def memory(self, kind: str) -> Memory: ... + def default_memory(self) -> Memory: ... + def addressable_memories(self) -> list[Memory]: ... + def live_buffers(self) -> list[Any]: ... + def memory_stats(self) -> dict[str, int] | None: ... + def get_stream_for_external_ready_events(self) -> int: ... + def __getattr__(self, name: str) -> Any: ... + +class Memory: + process_index: int + platform: str + kind: str + def __repr__(self) -> str: ... + def __str__(self) -> str: ... + def addressable_by_devices(self) -> list[Device]: ... + +class PjRtLayout: + def __str__(self) -> str: ... + def __eq__(self, other: Any) -> bool: ... + def __hash__(self) -> int: ... + def __getstate__(self) -> Any: ... + def __setstate__(self, _: Any): ... + def _xla_layout(self) -> Layout: ... + +class GpuAllocatorConfig: + class Kind(enum.IntEnum): + DEFAULT = ... + PLATFORM = ... + BFC = ... + CUDA_ASYNC = ... + + def __init__( + self, + kind: Kind = ..., + memory_fraction: float = ..., + preallocate: bool = ..., + collective_memory_size: int = ..., + ) -> None: ... + +class HostBufferSemantics(enum.IntEnum): + IMMUTABLE_ONLY_DURING_CALL = ... + IMMUTABLE_UNTIL_TRANSFER_COMPLETES = ... + ZERO_COPY = ... + +class Client: + platform: str + _raw_platform: str + platform_version: str + runtime_type: str + def device_count(self) -> int: ... + def local_device_count(self) -> int: ... + def devices(self) -> list[Device]: ... + def local_devices(self) -> list[Device]: ... + def _get_all_devices(self) -> list[Device]: ... + def device_from_local_hardware_id(self, int) -> Device: ... + def live_buffers(self) -> list[Any]: ... + def live_arrays(self) -> list[ArrayImpl]: ... + def live_executables(self) -> list[LoadedExecutable]: ... + def host_id(self) -> int: ... + def process_index(self) -> int: ... 
+ def buffer_from_pyval( + self, + argument: Any, + device: Device | None = ..., + force_copy: bool = ..., + host_buffer_semantics: HostBufferSemantics = ..., + ) -> ArrayImpl: ... + def compile( + self, + computation: str | bytes, + executable_devices: DeviceList | Sequence[Device], + compile_options: CompileOptions = ..., + ) -> Executable: ... + def compile_and_load( + self, + computation: str | bytes, + executable_devices: DeviceList | Sequence[Device], + compile_options: CompileOptions = ..., + host_callbacks: Sequence[Any] = ..., + ) -> LoadedExecutable: ... + def compile_ifrt_program( + self, + program: ifrt_programs.Program, + program_options: ifrt_programs.CompileOptions, + ) -> LoadedExecutable: ... + def compile_and_load_ifrt_program( + self, + program: ifrt_programs.Program, + program_options: ifrt_programs.CompileOptions, + ) -> LoadedExecutable: ... + def serialize_executable(self, executable: LoadedExecutable) -> bytes: ... + def deserialize_executable( + self, + serialized: bytes, + executable_devices: DeviceList | Sequence[Device], + options: CompileOptions | None, + host_callbacks: Sequence[Any] = ..., + ) -> LoadedExecutable: ... + def heap_profile(self) -> bytes: ... + def make_python_callback_from_host_send_and_recv( + self, + callable: Callable, + operand_shapes: Sequence[Shape], + result_shapes: Sequence[Shape], + send_channel_ids: Sequence[int], + recv_channel_ids: Sequence[int], + serializer: Callable | None = ..., + ) -> Any: ... + def get_default_layout( + self, dtype: np.dtype, shard_shape: Sequence[int], device: Device + ) -> PjRtLayout: ... + def __getattr__(self, name: str) -> Any: ... + +class CompileOnlyPyClient(Client): + def compile( + self, + computation: str | bytes, + executable_devices: DeviceList | Sequence[Device], + compile_options: CompileOptions = ..., + ) -> Executable: ... + +class CpuCollectives: ... + +def make_gloo_tcp_collectives( + distributed_client: DistributedRuntimeClient | None = ..., + hostname: str | None = ..., + interface: str | None = ..., +) -> CpuCollectives: ... + +class MpiCollectives(CpuCollectives): + def Init(self): ... + def Finalize(self): ... + +def make_mpi_collectives() -> MpiCollectives: ... +def get_tfrt_cpu_client( + asynchronous: bool = ..., + distributed_client: DistributedRuntimeClient | None = ..., + node_id: int = ..., + num_nodes: int = ..., + collectives: CpuCollectives | None = ..., + num_devices: int | None = ..., + get_local_topology_timeout_minutes: int | None = ..., + get_global_topology_timeout_minutes: int | None = ..., +) -> Client: ... +def get_mock_gpu_client( + asynchronous: bool = ..., + allocator_config: GpuAllocatorConfig = ..., + distributed_client: DistributedRuntimeClient | None = ..., + node_id: int = ..., + allowed_devices: Any | None = ..., + platform_name: str | None = ..., +) -> Client: ... +def get_c_api_client( + platform_name: str, + options: Mapping[str, str | int | list[int] | float | bool], + distributed_client: DistributedRuntimeClient | None = ..., +) -> Client: ... +def get_default_c_api_topology( + platform_name: str, + topology_name: str, + options: dict[str, str | int | list[int] | float], +) -> DeviceTopology: ... +def get_c_api_topology( + c_api: Any, + topology_name: str, + options: dict[str, str | int | list[int] | float], +) -> DeviceTopology: ... +def get_topology_for_devices(devices: list[Device]) -> DeviceTopology: ... +def load_pjrt_plugin( + platform_name: str, library_path: str | None, c_api: Any | None +) -> _Status: ... 
+def pjrt_plugin_loaded(plugin_name: str) -> bool: ... +def pjrt_plugin_initialized(plugin_name: str) -> bool: ... +def initialize_pjrt_plugin(platform_name: str) -> _Status: ... + +Array = Any +ArrayImpl = Any + +# TODO(phawkins): this type is problematic because it is not a subtype of +# jax.Array, and pytype notices. +# class ArrayImpl: +# def __init__(self, +# aval: Any, +# sharding: Any, +# arrays: Sequence[ArrayImpl], +# committed: bool, +# _skip_checks: bool = ...): ... +# def block_until_ready(self) -> ArrayImpl: ... +# def is_deleted(self) -> bool: ... +# def is_ready(self) -> bool: ... +# def delete(self): ... +# def unsafe_buffer_pointer(self) -> Any: ... +# def clone(self) -> ArrayImpl: ... +# def _copy_single_device_array_to_host_async(self): ... +# def _single_device_array_to_np_array_did_copy(self) -> tuple[np.ndarray, bool]: ... +# def on_device_size_in_bytes(self) -> int: ... +# def _fully_replicated_shard(self) -> ArrayImpl: ... +# __cuda_array_interface__: Dict[str, Any] +# dtype: np.dtype +# shape: Tuple[int, ...] +# _arrays: Any +# _npy_value: Any +# traceback: Traceback +# _HAS_DYNAMIC_ATTRIBUTES: bool = ... + +def batched_copy_array_to_devices_with_sharding( + arrays: Sequence[ArrayImpl], + devices: Sequence[list[Device]], + sharding: Sequence[Any], + array_copy_semantics: Sequence[ArrayCopySemantics], +) -> Sequence[ArrayImpl]: ... +def batched_block_until_ready(x: Sequence[ArrayImpl]) -> None: ... +def batched_device_put( + aval: Any, + sharding: Any, + shards: Sequence[Any], + devices: list[Device], + committed: bool = ..., + force_copy: bool = ..., + host_buffer_semantics: Any = ..., +) -> ArrayImpl: ... +def reorder_shards( + x: ArrayImpl, + dst_sharding: Any, + array_copy_semantics: ArrayCopySemantics, +) -> ArrayImpl: ... +def check_and_canonicalize_memory_kind( + memory_kind: str | None, device_list: DeviceList +) -> str | None: ... +def array_result_handler( + aval: Any, sharding: Any, committed: bool, _skip_checks: bool = ... +) -> Callable: ... + +class Token: + def block_until_ready(self): ... + +class ShardedToken: + def block_until_ready(self): ... + def get_token(self, device_id: int): ... + +class ExecuteResults: + def __len__(self) -> int: ... + def disassemble_into_single_device_arrays(self) -> list[list[ArrayImpl]]: ... + def disassemble_prefix_into_single_device_arrays( + self, n: int + ) -> list[list[ArrayImpl]]: ... + def consume_with_handlers(self, handlers: list[Callable]) -> list[Any]: ... + def consume_token(self) -> ShardedToken: ... + +def get_execution_stream_id() -> int: ... + +def set_execution_stream_id(new_id: int): ... + +class LoadedExecutable: + client: Client + def local_devices(self) -> list[Device]: ... + def size_of_generated_code_in_bytes(self) -> int: ... + def execute(self, arguments: Sequence[ArrayImpl]) -> list[ArrayImpl]: ... + def execute_with_token( + self, arguments: Sequence[ArrayImpl] + ) -> tuple[list[ArrayImpl], Token]: ... + def execute_sharded( + self, arguments: Sequence[list[ArrayImpl]], with_tokens: bool = ... + ) -> ExecuteResults: ... + def hlo_modules(self) -> list[HloModule]: ... + def get_output_memory_kinds(self) -> list[list[str]]: ... + def get_compiled_memory_stats(self) -> CompiledMemoryStats: ... + def get_output_shardings(self) -> list[OpSharding] | None: ... + def get_parameter_shardings(self) -> list[OpSharding] | None: ... + def get_parameter_layouts(self) -> list[Layout]: ... + def get_output_layouts(self) -> list[Layout]: ... + def keep_alive(self) -> None: ... 
+ def cost_analysis(self) -> dict[str, Any]: ... + traceback: Traceback + fingerprint: bytes | None + +class Executable: + def hlo_modules(self) -> list[HloModule]: ... + def get_output_memory_kinds(self) -> list[list[str]]: ... + def get_output_shardings(self) -> list[OpSharding] | None: ... + def get_parameter_shardings(self) -> list[OpSharding] | None: ... + def get_parameter_layouts(self) -> list[Layout]: ... + def get_output_layouts(self) -> list[Layout]: ... + def get_compiled_memory_stats(self) -> CompiledMemoryStats: ... + def serialize(self) -> str: ... + def cost_analysis(self) -> dict[str, Any]: ... + +class DeviceTopology: + platform: str + platform_version: str + def _make_compile_only_devices(self) -> list[Device]: ... + def serialize(self) -> bytes: ... + def __getattr__(self, name: str) -> Any: ... + +def buffer_to_dlpack_managed_tensor( + buffer: ArrayImpl, stream: int | None = None +) -> Any: ... +@overload +def dlpack_managed_tensor_to_buffer( + tensor: Any, device: Device, stream: int | None +) -> ArrayImpl: ... +@overload +def dlpack_managed_tensor_to_buffer( # Legacy overload + tensor: Any, + cpu_backend: Client | None = ..., + gpu_backend: Client | None = ..., +) -> ArrayImpl: ... +def cuda_array_interface_to_buffer( + cai: dict[ + str, + ( + str + | int + | None + | tuple[int, ...] + | tuple[int, bool] + | list[tuple[str, str]] + | list[tuple[str, str, tuple[int, ...]]] + ), + ], + gpu_backend: Client | None = ..., + device_id: int | None = None, +) -> ArrayImpl: ... + +# === BEGIN py_traceback.cc + +class Frame: + file_name: str + function_name: str + function_line_start: int + line_num: int + def __init__( + self, + file_name: str, + function_name: str, + function_line_start: int, + line_num: int, + ): ... + def __repr__(self) -> str: ... + +class Traceback: + enabled: ClassVar[bool] + @staticmethod + def get_traceback() -> Traceback: ... + @staticmethod + def traceback_from_frames(frames: Sequence[Frame]) -> Any: ... + frames: Sequence[Frame] + def __str__(self) -> str: ... + def as_python_traceback(self) -> Any: ... + def raw_frames(self) -> tuple[list[types.CodeType], list[int]]: ... + @staticmethod + def code_addr2line(code: types.CodeType, lasti: int) -> int: ... + @staticmethod + def code_addr2location( + code: types.CodeType, lasti: int + ) -> tuple[int, int, int, int]: ... + +def tracebacks_enabled() -> bool: ... +def set_tracebacks_enabled(enabled: bool) -> None: ... + +# === END py_traceback.cc + +class DistributedRuntimeService: + def shutdown(self) -> None: ... + +class DistributedRuntimeClient: + def connect(self) -> _Status: ... + def shutdown(self) -> _Status: ... + def blocking_key_value_get(self, key: str, timeout_in_ms: int) -> _Status: ... + def blocking_key_value_get_bytes( + self, key: str, timeout_in_ms: int + ) -> _Status: ... + def key_value_try_get(self, key: str) -> _Status: ... + def key_value_try_get_bytes(self, key: str) -> _Status: ... + def key_value_dir_get(self, key: str) -> _Status: ... + def key_value_dir_get_bytes(self, key: str) -> _Status: ... + def key_value_set( + self, key: str, value: str, allow_overwrite: bool = False + ) -> _Status: ... + def key_value_set_bytes( + self, key: str, value: bytes, allow_overwrite: bool = False + ) -> _Status: ... + def key_value_delete(self, key: str) -> _Status: ... + def wait_at_barrier( + self, + barrier_id: str, + timeout_in_ms: int, + process_ids: list[int] | None = None, + ) -> _Status: ... + def get_live_nodes(self, process_ids: list[int]) -> _Status: ... 
+ +def get_distributed_runtime_service( + address: str, + num_nodes: int, + heartbeat_interval: int | None = ..., + max_missing_heartbeats: int | None = ..., + cluster_register_timeout: int | None = ..., + shutdown_timeout: int | None = ..., +) -> DistributedRuntimeService: ... +def get_distributed_runtime_client( + address: str, + node_id: int, + rpc_timeout: int | None = ..., + init_timeout: int | None = ..., + shutdown_timeout: int | None = ..., + heartbeat_interval: int | None = ..., + max_missing_heartbeats: int | None = ..., + missed_heartbeat_callback: Any | None = ..., + shutdown_on_destruction: bool | None = ..., + use_compression: bool | None = ..., +) -> DistributedRuntimeClient: ... + +class PreemptionSyncManager: + def initialize(self, client: DistributedRuntimeClient) -> _Status: ... + def reached_sync_point(self, step_counter: int) -> bool: ... + def shutdown(self) -> None: ... + +def create_preemption_sync_manager() -> PreemptionSyncManager: ... +def collect_garbage() -> None: ... +def is_optimized_build() -> bool: ... +def json_to_pprof_profile(json: str) -> bytes: ... +def pprof_profile_to_json(proto: bytes) -> str: ... + +class PmapFunction: + def __call__(self, *args, **kwargs) -> Any: ... + def __getstate__(self) -> Any: ... + def __setstate__(self, Any): ... + __signature__: inspect.Signature + def _cache_size(self) -> int: ... + def _cache_clear(self) -> None: ... + +class DeviceList: + def __init__(self, device_assignment: tuple[Device, ...]): ... + def __hash__(self) -> int: ... + def __eq__(self, other: Any) -> bool: ... + def __ne__(self, other: Any) -> bool: ... + def __len__(self) -> int: ... + def __getitem__(self, index: Any) -> Any: ... + def __iter__(self) -> Iterator[Device]: ... + def __str__(self) -> str: ... + def __repr__(self) -> str: ... + def __getstate__(self) -> Any: ... + def __setstate__(self, state: Any): ... + @property + def is_fully_addressable(self) -> bool: ... + @property + def addressable_device_list(self) -> DeviceList: ... + @property + def process_indices(self) -> set[int]: ... + @property + def default_memory_kind(self) -> str | None: ... + @property + def memory_kinds(self) -> tuple[str, ...]: ... + @property + def device_kind(self) -> str: ... + +class Sharding: ... + +class NamedSharding(Sharding): + def __init__( + self, + mesh: Any, + spec: Any, + *, + memory_kind: str | None = None, + _logical_device_ids: tuple[int, ...] | None = None, + ): ... + mesh: Any + spec: Any + _memory_kind: str | None + _internal_device_list: DeviceList + _logical_device_ids: tuple[int, ...] | None + +class SingleDeviceSharding(Sharding): + def __init__(self, device: Device, *, memory_kind: str | None = None): ... + _device: Device + _memory_kind: str | None + _internal_device_list: DeviceList + +class PmapSharding(Sharding): + def __init__( + self, devices: Sequence[Any], sharding_spec: pmap_lib.ShardingSpec + ): ... + devices: list[Any] + sharding_spec: pmap_lib.ShardingSpec + _internal_device_list: DeviceList + +class GSPMDSharding(Sharding): + def __init__( + self, + devices: Sequence[Device], + op_sharding: OpSharding | HloSharding, + *, + memory_kind: str | None = None, + _device_list: DeviceList | None = None, + ): ... + _devices: tuple[Device, ...] + _hlo_sharding: HloSharding + _memory_kind: str | None + _internal_device_list: DeviceList + +class PjitFunction: + def __call__(self, *args, **kwargs) -> Any: ... + +class PjitFunctionCache: + def __init__(self, capacity: int = ...): ... + def __getstate__(self) -> Any: ... 
+ def __setstate__(self, Any): ... + def size(self) -> int: ... + def capacity(self) -> int: ... + def clear(self): ... + @staticmethod + def clear_all(): ... + +def pjit( + function_name: str, + fun: Callable | None, + cache_miss: Callable, + static_argnums: Sequence[int], + static_argnames: Sequence[str], + global_cache_key: Any, + pytree_registry: pytree.PyTreeRegistry, + shard_arg_fallback: Callable, + cache: PjitFunctionCache | None = ..., +) -> PjitFunction: ... + +class WeakrefLRUCacheInfo: + @property + def hits(self) -> int: ... + @property + def misses(self) -> int: ... + @property + def maxsize(self) -> int: ... + @property + def currsize(self) -> int: ... + +class WeakrefLRUCache: + def __call__(self, weakref_key: Any, *args, **kwargs) -> Any: ... + def cache_keys(self) -> list[Any]: ... + def cache_info(self) -> WeakrefLRUCacheInfo: ... + def cache_clear(self): ... + +def is_asan() -> bool: ... +def is_msan() -> bool: ... +def is_tsan() -> bool: ... +def is_sanitized() -> bool: ... + +class TransferConnection: + def address(self) -> str: ... + def _pull_flat(self, uuid, backend, avals_flat) -> list[Any]: ... + +class TransferServer: + def _await_pull_flat(self, uuid, args: list[ArrayImpl]): ... + def connect(self, address: str) -> TransferConnection: ... + +def start_transfer_server( + client: Client, + address: str = "", + transport_addresses: list[str] = [], + max_num_parallel_copies: int = 0, + transfer_size: int = 0, +) -> TransferServer: ... +def approx_top_k_reduction_output_size( + input_size: int, + rank: int, + top_k: int, + recall_target: float, + aggregate_to_topk: bool | None = ..., + input_size_override: int | None = ..., +) -> tuple[int, int]: ... +def get_internal_device_put_info() -> dict[str, int]: ... + +class UnconstrainedSingleton: + def __repr__(self) -> str: ... + def __reduce__(self) -> Any: ... + +UNCONSTRAINED_PARTITION: UnconstrainedSingleton + +class PartitionSpec: + def __init__(self, *partitions, unreduced: Set[Any] | None = None): ... + def __hash__(self): ... + def __eq__(self, other): ... + _HAS_DYNAMIC_ATTRIBUTES: bool = ... + +def canonicalize_partition(partition: Any) -> Any: ... diff --git a/jaxlib/_jax/config.pyi b/jaxlib/_jax/config.pyi new file mode 100644 index 000000000000..535554559180 --- /dev/null +++ b/jaxlib/_jax/config.pyi @@ -0,0 +1,32 @@ +# Copyright 2024 The JAX Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from typing import Any, Generic, TypeVar + +unset: object + +_T = TypeVar('_T') + +class Config(Generic[_T]): + def __init__(self, value: _T, include_in_jit_key: bool = False): ... + + @property + def value(self) -> _T: ... + + def get_local(self) -> Any: ... + def get_global(self) -> _T: ... + def set_local(self, value: Any) -> None: ... + def swap_local(self, value: Any) -> Any: ... + def set_global(self, value: _T) -> None: ... 
diff --git a/jaxlib/_jax/ffi.pyi b/jaxlib/_jax/ffi.pyi new file mode 100644 index 000000000000..b92575e77c96 --- /dev/null +++ b/jaxlib/_jax/ffi.pyi @@ -0,0 +1,47 @@ +# Copyright 2025 The JAX Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import enum +from typing import Any + +class Buffer: + @property + def dtype(self) -> Any: ... + @property + def ndim(self) -> int: ... + @property + def shape(self) -> tuple[int, ...]: ... + @property + def writeable(self) -> bool: ... + def __array__(self, dtype: Any = None, copy: bool | None = None) -> Any: ... + def __cuda_array_interface__(self) -> Any: ... + def __dlpack__( + self, + stream: Any = None, + max_version: Any = None, + dl_device: Any = None, + copy: Any = None, + ) -> Any: ... + def __dlpack_device__(self) -> tuple[int, int]: ... + +class ExecutionStage(enum.IntEnum): + INSTANTIATE = ... + PREPARE = ... + INITIALIZE = ... + EXECUTE = ... + +class ExecutionContext: + def stage(self) -> ExecutionStage: ... + def stream(self) -> int: ... diff --git a/jaxlib/_jax/guard_lib.pyi b/jaxlib/_jax/guard_lib.pyi new file mode 100644 index 000000000000..7f8896a4f75a --- /dev/null +++ b/jaxlib/_jax/guard_lib.pyi @@ -0,0 +1,46 @@ +# Copyright 2024 The JAX Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from typing import Any + +class TransferGuardLevel: + ALLOW: Any + LOG: Any + DISALLOW: Any + LOG_EXPLICIT: Any + DISALLOW_EXPLICIT: Any + +class GarbageCollectionGuardLevel: + ALLOW: Any + LOG: Any + FATAL: Any + +class GuardState: + host_to_device: TransferGuardLevel | None + device_to_device: TransferGuardLevel | None + device_to_host: TransferGuardLevel | None + + explicit_device_put: bool + explicit_device_get: bool + + garbage_collect_array: GarbageCollectionGuardLevel | None + +def global_state() -> GuardState: ... +def thread_local_state() -> GuardState: ... + +class _TestingScopedLogSink: + def __enter__(self) -> _TestingScopedLogSink: ... + def __exit__(self, *args, **kwargs) -> None: ... + def logs(self) -> list[str]: ... 
diff --git a/jaxlib/_jax/ifrt_programs.pyi b/jaxlib/_jax/ifrt_programs.pyi new file mode 100644 index 000000000000..5e426b070c21 --- /dev/null +++ b/jaxlib/_jax/ifrt_programs.pyi @@ -0,0 +1,45 @@ +# Copyright 2024 The JAX Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from typing import Any +from collections.abc import Sequence + +from jaxlib import _jax + +class Program: ... + +class CompileOptions: ... + +def make_hlo_program(mlir_module: str | bytes) -> Program: ... + +def make_colocated_python_program( + name : str, + picked_function: bytes, + devices: Sequence[_jax.Device] | _jax.DeviceList, + input_avals: Sequence[Any], + output_avals: Sequence[Any], +) -> Program: ... + +def make_plugin_program(data: str | bytes) -> Program: ... + +def make_colocated_python_compile_options() -> CompileOptions: ... + +def make_xla_compile_options( + compile_options: _jax.CompileOptions, + executable_devices: _jax.DeviceList, + host_callbacks: Sequence[Any] +) -> CompileOptions: ... + +def make_plugin_compile_options() -> CompileOptions: ... diff --git a/jaxlib/_jax/ifrt_proxy.pyi b/jaxlib/_jax/ifrt_proxy.pyi new file mode 100644 index 000000000000..73688b2d9696 --- /dev/null +++ b/jaxlib/_jax/ifrt_proxy.pyi @@ -0,0 +1,34 @@ +# Copyright 2024 The JAX Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from typing import Any +from collections.abc import Callable + +from jaxlib import _jax + +_Status = Any +Client = _jax.Client + + +class ClientConnectionOptions: + on_disconnect: Callable[[_Status], None] | None = None + on_connection_update: Callable[[str], None] | None = None + connection_timeout_in_seconds: int | None = None + + +def get_client( + proxy_server_address: str, + options: ClientConnectionOptions +) -> Client: ... diff --git a/jaxlib/_jax/jax_jit.pyi b/jaxlib/_jax/jax_jit.pyi new file mode 100644 index 000000000000..be7687f4eaa1 --- /dev/null +++ b/jaxlib/_jax/jax_jit.pyi @@ -0,0 +1,77 @@ +# Copyright 2021 The JAX Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from typing import Any +from collections.abc import Callable, Sequence + +import numpy as np +from jaxlib import _jax + +from . import pytree + +Client = _jax.Client +Device = _jax.Device + + +class JitState: + disable_jit: bool | None + enable_x64: bool | None + default_device: Any | None + extra_jit_context: Any | None + post_hook: Callable[..., Any] | None + +def global_state() -> JitState: ... +def thread_local_state() -> JitState: ... + +def get_enable_x64() -> bool: ... +def set_thread_local_state_initialization_callback( + function: Callable[[], None]): ... + +def swap_thread_local_state_disable_jit( + value: bool | None) -> bool | None: ... + +class ArgSignature: + dtype: np.dtype + shape: tuple[int, ...] + weak_type: bool + +def _ArgSignatureOfValue( + __arg: Any, + __jax_enable_x64: bool) -> ArgSignature: ... + +def _is_float0(__arg: Any) -> bool: ... + + +class ArgumentSignature: + static_args: Sequence[Any] + static_arg_names: Sequence[str] + dynamic_arg_names: Sequence[str] + dynamic_arg_treedefs: Sequence[pytree.PyTreeDef] + + def __eq__(self, value, /): ... + def __ne__(self, value, /): ... + def __hash__(self, /): ... + def __str__(self): ... + def __repr__(self): ... + + +def parse_arguments( + positional_args: Sequence[Any], + keyword_args: Sequence[Any], + kwnames: tuple[str, ...], + static_argnums: Sequence[int], + static_argnames: Sequence[str], + pytree_registry: pytree.PyTreeRegistry, +) -> tuple[ArgumentSignature, Sequence[Any]]: ... diff --git a/jaxlib/_jax/mlir.pyi b/jaxlib/_jax/mlir.pyi new file mode 100644 index 000000000000..9be8ef71b50d --- /dev/null +++ b/jaxlib/_jax/mlir.pyi @@ -0,0 +1,34 @@ +# Copyright 2021 The JAX Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from . import XlaComputation + +def hlo_to_stablehlo(computation: bytes) -> bytes: ... +def xla_computation_to_mlir_module(computation: XlaComputation) -> str: ... +def mlir_module_to_xla_computation( + mlir_module: bytes | str, + use_tuple_args: bool = ..., + return_tuple: bool = ..., +) -> XlaComputation: ... +def mhlo_to_stablehlo(mlir_module: bytes | str) -> bytes: ... +def stablehlo_to_mhlo(mlir_module: bytes | str) -> bytes: ... +def serialize_portable_artifact(mlir_module: str, target: str) -> bytes: ... +def deserialize_portable_artifact(mlir_module: bytes) -> str: ... 
+def refine_polymorphic_shapes( + mlir_module: bytes | str, + enable_shape_assertions: bool = ..., + validate_static_shapes: bool = ..., + enable_shardy: bool = ..., +) -> bytes: ... diff --git a/jaxlib/_jax/pmap_lib.pyi b/jaxlib/_jax/pmap_lib.pyi new file mode 100644 index 000000000000..3e26e7e1da84 --- /dev/null +++ b/jaxlib/_jax/pmap_lib.pyi @@ -0,0 +1,84 @@ +# Copyright 2021 The JAX Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import inspect +from typing import Any +from collections.abc import Callable, Sequence, Iterable + +from . import pytree + +_AvalDimSharding = Any +_MeshDimAssignment = Any + +class NoSharding: + def __init__(self) -> None: ... + def __repr__(self) -> str: ... + def __eq__(self, __other: Any) -> bool: ... + +class Chunked: + @property + def chunks(self) -> Sequence[int]: ... + def __init__(self, __chunks: Sequence[int]) -> None: ... + def __repr__(self) -> str: ... + def __eq__(self, __other: Any) -> bool: ... + +class Unstacked: + @property + def size(self) -> int: ... + def __init__(self, __sz: int) -> None: ... + def __repr__(self) -> str: ... + def __eq__(self, __other: Any) -> bool: ... + +class ShardedAxis: + @property + def axis(self) -> int: ... + def __init__(self, __axis: int) -> None: ... + def __repr__(self) -> str: ... + def __eq__(self, __other: Any) -> bool: ... + +class Replicated: + @property + def replicas(self) -> int: ... + def __init__(self, __replicas: int) -> None: ... + def __repr__(self) -> str: ... + def __eq__(self, __other: Any) -> bool: ... + +class ShardingSpec: + def __init__(self, + sharding: Iterable[_AvalDimSharding], + mesh_mapping: Iterable[_MeshDimAssignment]) -> None: ... + @property + def sharding(self) -> tuple[_AvalDimSharding, ...]: ... + @property + def mesh_mapping(self) -> tuple[_MeshDimAssignment]: ... + def __eq__(self, __other: Any) -> bool: ... + def __hash__(self) -> int: ... + + _HAS_DYNAMIC_ATTRIBUTES = True + +class PmapFunction: + def __call__(self, *args, **kwargs) -> Any: ... + def __getstate__(self) -> Any: ... + def __setstate__(self, Any): ... + __signature__: inspect.Signature + def _cache_size(self) -> int: ... + def _cache_clear(self) -> None: ... + def _debug_cache_keys(self) -> str: ... + +def pmap(fun: Callable[..., Any], + cache_miss: Callable[..., Any], + static_argnums: Sequence[int], + shard_arg_fallback: Callable[..., Any], + pytree_registry: pytree.PyTreeRegistry) -> PmapFunction: ... diff --git a/jaxlib/_jax/profiler.pyi b/jaxlib/_jax/profiler.pyi new file mode 100644 index 000000000000..a2fcc67fbcb7 --- /dev/null +++ b/jaxlib/_jax/profiler.pyi @@ -0,0 +1,59 @@ +# Copyright 2021 The JAX Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from types import TracebackType +from typing import Any + +_Status = Any + +class ProfilerServer: ... +def start_server(port: int) -> ProfilerServer: ... + +def register_plugin_profiler(c_api: Any) -> None: ... + +def get_profiled_instructions_proto(tensorboard_dir: str) -> bytes: ... +def get_instructins_profile(tensorboard_dir: str) -> list[tuple[str, float]]: ... +def get_fdo_profile( + xspace: bytes, as_textproto: bool = ... +) -> bytes | str: ... + +class ProfilerSession: + def __init__(self, options: ProfileOptions | None = ...) -> None: ... + def stop(self) -> bytes: ... + def export(self, xspace: bytes, tensorboard_dir: str) -> _Status:... + +class ProfileOptions: + include_dataset_ops: bool + host_tracer_level: int + python_tracer_level: int + enable_hlo_proto: bool + start_timestamp_ns: int + duration_ms: int + repository_path: str + raise_error_on_start_failure: bool + +def aggregate_profiled_instructions(profiles: list[bytes], percentile: int) -> str: ... + +class TraceMe: + def __init__(self, name: str, **kwargs: Any) -> None: ... + def __enter__(self) -> TraceMe: ... + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + exc_tb: TracebackType | None) -> bool | None:... + def set_metadata(self, **kwargs): ... + @staticmethod + def is_enabled() -> bool: ... diff --git a/jaxlib/_jax/pytree.pyi b/jaxlib/_jax/pytree.pyi new file mode 100644 index 000000000000..2a33203abb1d --- /dev/null +++ b/jaxlib/_jax/pytree.pyi @@ -0,0 +1,143 @@ +# Copyright 2021 The JAX Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from builtins import tuple as Tuple +from typing import Any, TypeVar +from collections.abc import Callable, Hashable, Iterable, Sequence + +_T = TypeVar("_T") + +version: int + +class PyTreeRegistry: + def __init__( + self, + *, + enable_none: bool = ..., + enable_tuple: bool = ..., + enable_namedtuple: bool = ..., + enable_list: bool = ..., + enable_dict: bool = ... + ): ... + def flatten( + self, + tree: Any, + leaf_predicate: Callable[[Any], bool] | None = ..., + ) -> Tuple[list[Any], PyTreeDef]: ... + def flatten_one_level( + self, tree: Any + ) -> Tuple[Iterable[Any], Any] | None: ... + def flatten_one_level_with_keys( + self, tree: Any + ) -> Tuple[Iterable[_KeyLeafPair], Any] | None: ... + def flatten_with_path( + self, + tree: Any, + leaf_predicate: Callable[[Any, Any], bool] | None = ..., + ) -> Tuple[list[Tuple[_KeyPath, Any]], PyTreeDef]: ... 
+ def register_node( + self, + __type: type[_T], + to_iterable: Callable[[_T], Tuple[_Children, _AuxData]], + from_iterable: Callable[[_AuxData, _Children], _T], + to_iterable_with_keys: ( + Callable[[_T], Tuple[_KeyLeafPairs, _AuxData]] | None + ) = ..., + ) -> Any: ... + def register_dataclass_node( + self, __type: type[_T], meta_fields: list[str], data_fields: list[str] + ) -> Any: ... + +def default_registry() -> PyTreeRegistry: ... +def tuple(registry: PyTreeRegistry, arg0: Sequence[PyTreeDef]) -> PyTreeDef: ... +def all_leaves(registry: PyTreeRegistry, arg0: Iterable[Any]) -> bool: ... + +class SequenceKey(Hashable): + idx: int + __match_args__: Tuple = ... + def __init__(self, idx: int): ... + def __str__(self) -> str: ... + def __repr__(self) -> str: ... + def __hash__(self) -> int: ... + def __getstate__(self) -> Any: ... + def __setstate__(self, state: Any): ... + def __eq__(self, __other: Any) -> bool: ... + +class DictKey(Hashable): + key: Hashable + __match_args__: Tuple = ... + def __init__(self, key: Hashable): ... + def __str__(self) -> str: ... + def __repr__(self) -> str: ... + def __hash__(self) -> int: ... + def __getstate__(self) -> Any: ... + def __setstate__(self, state: Any): ... + def __eq__(self, __other: Any) -> bool: ... + +class GetAttrKey(Hashable): + name: str + __match_args__: Tuple = ... + def __init__(self, name: str): ... + def __str__(self) -> str: ... + def __repr__(self) -> str: ... + def __hash__(self) -> int: ... + def __getstate__(self) -> Any: ... + def __setstate__(self, state: Any): ... + def __eq__(self, __other: Any) -> bool: ... + +class FlattenedIndexKey(Hashable): + key: int + __match_args__: Tuple = ... + def __init__(self, key: int): ... + def __str__(self) -> str: ... + def __repr__(self) -> str: ... + def __hash__(self) -> int: ... + def __getstate__(self) -> Any: ... + def __setstate__(self, state: Any): ... + def __eq__(self, __other: Any) -> bool: ... + +class PyTreeDef: + def unflatten(self, __leaves: Iterable[Any]) -> Any: ... + def flatten_up_to(self, __xs: Any) -> list[Any]: ... + def compose(self, __inner: PyTreeDef) -> PyTreeDef: ... + def walk( + self, + __f_node: Callable[[Any, Any], Any], + __f_leaf: Callable[[_T], Any] | None, + leaves: Iterable[Any], + ) -> Any: ... + def from_iterable_tree(self, __xs: Any): ... + def node_data(self) -> Tuple[type, Any] | None: ... + def children(self) -> list[PyTreeDef]: ... + + num_leaves: int + num_nodes: int + def __repr__(self) -> str: ... + def __eq__(self, __other: Any) -> bool: ... + def __ne__(self, __other: Any) -> bool: ... + def __hash__(self) -> int: ... + def __getstate__(self) -> Any: ... + def __setstate__(self, state: Any): ... + def serialize_using_proto(self) -> bytes: ... + @staticmethod + def deserialize_using_proto( + registry: PyTreeRegistry, data: bytes + ) -> PyTreeDef: ... + +_Children = TypeVar("_Children", bound=Iterable[Any]) +_KeyLeafPair = TypeVar("_KeyLeafPair", bound=Tuple[Any, Any]) +_KeyLeafPairs = TypeVar("_KeyLeafPairs", bound=Iterable[Tuple[Any, Any]]) +_KeyPath = TypeVar("_KeyPath", bound=Tuple[Any, ...]) +_AuxData = TypeVar("_AuxData", bound=Hashable) diff --git a/jaxlib/_jax/sdy.pyi b/jaxlib/_jax/sdy.pyi new file mode 100644 index 000000000000..520f93f11bc6 --- /dev/null +++ b/jaxlib/_jax/sdy.pyi @@ -0,0 +1,32 @@ +# Copyright 2021 The JAX Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from jaxlib.mlir import ir + +def sdy_round_trip_export_pipeline( + module: ir.module +) -> str: ... + +def sdy_round_trip_import_shardings( + module: ir.module +) -> str: ... + +def get_mesh( + module: ir.module +) -> tuple[tuple[str, int], ...]: ... + +def lowered_with_shardy( + module: ir.module +) -> bool: ... diff --git a/jaxlib/_jax/transfer_guard_lib.pyi b/jaxlib/_jax/transfer_guard_lib.pyi new file mode 100644 index 000000000000..d293f7c59798 --- /dev/null +++ b/jaxlib/_jax/transfer_guard_lib.pyi @@ -0,0 +1,39 @@ +# Copyright 2022 The JAX Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from typing import Any + +class TransferGuardLevel: + ALLOW: Any + LOG: Any + DISALLOW: Any + LOG_EXPLICIT: Any + DISALLOW_EXPLICIT: Any + +class TransferGuardState: + host_to_device: TransferGuardLevel | None + device_to_device: TransferGuardLevel | None + device_to_host: TransferGuardLevel | None + + explicit_device_put: bool + explicit_device_get: bool + +def global_state() -> TransferGuardState: ... +def thread_local_state() -> TransferGuardState: ... + +class _TestingScopedLogSink: + def __enter__(self) -> _TestingScopedLogSink: ... + def __exit__(self, *args, **kwargs) -> None: ... + def logs(self) -> list[str]: ... diff --git a/jaxlib/_pretty_printer.cc b/jaxlib/_pretty_printer.cc new file mode 100644 index 000000000000..1bf6f8d2f541 --- /dev/null +++ b/jaxlib/_pretty_printer.cc @@ -0,0 +1,755 @@ +/* Copyright 2025 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include +#include +#include +#include +#include + +#include "absl/container/inlined_vector.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "absl/types/span.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/optional.h" // IWYU pragma: keep +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "nanobind/stl/vector.h" // IWYU pragma: keep +#include "jaxlib/nb_class_ptr.h" + +namespace nb = nanobind; + +namespace jax { + +enum class Color { + kBlack = 30, + kRed = 31, + kGreen = 32, + kYellow = 33, + kBlue = 34, + kMagenta = 35, + kCyan = 36, + kWhite = 37, + kReset = 39, +}; + +std::string ColorToString(Color color) { + switch (color) { + case Color::kBlack: + return "black"; + case Color::kRed: + return "red"; + case Color::kGreen: + return "green"; + case Color::kYellow: + return "yellow"; + case Color::kBlue: + return "blue"; + case Color::kMagenta: + return "magenta"; + case Color::kCyan: + return "cyan"; + case Color::kWhite: + return "white"; + case Color::kReset: + return "reset"; + } +} + +enum class Intensity { + kNormal = 22, + kDim = 2, + kBright = 1, +}; + +std::string IntensityToString(Intensity intensity) { + switch (intensity) { + case Intensity::kNormal: + return "normal"; + case Intensity::kDim: + return "dim"; + case Intensity::kBright: + return "bright"; + } +} + +struct FormatState; +struct FormatAgendum; + +class Doc { + public: + Doc(int num_annotations) : num_annotations_(num_annotations) {} + virtual ~Doc() = default; + virtual std::string Repr() const = 0; + + int num_annotations() const { return num_annotations_; } + + virtual void Fits(std::stack& agenda, int& width) const = 0; + + // Returns true if the doc may be sparse, i.e. there are no breaks between + // annotations. Returns false if the doc is known not to be sparse. + virtual bool Sparse(std::stack& agenda, int& num_annotations, + bool& seen_break) const = 0; + + virtual void Format(const FormatAgendum& agendum, + FormatState& state) const = 0; + + private: + int num_annotations_; +}; + +class NilDoc final : public Doc { + public: + NilDoc() : Doc(/*num_annotations=*/0) {} + std::string Repr() const override; + + void Fits(std::stack& agenda, int& width) const override; + bool Sparse(std::stack& agenda, int& num_annotations, + bool& seen_break) const override; + virtual void Format(const FormatAgendum& agendum, + FormatState& state) const override; +}; + +class TextDoc final : public Doc { + public: + TextDoc(std::string text, std::optional annotation) + : Doc(annotation.has_value() ? 
1 : 0), + text_(std::move(text)), + annotation_(std::move(annotation)) {} + std::string Repr() const override; + void Fits(std::stack& agenda, int& width) const override; + bool Sparse(std::stack& agenda, int& num_annotations, + bool& seen_break) const override; + virtual void Format(const FormatAgendum& agendum, + FormatState& state) const override; + + private: + std::string text_; + std::optional annotation_; +}; + +class ConcatDoc final : public Doc { + public: + explicit ConcatDoc(std::vector> children) + : Doc(TotalNumAnnotations(children)), children_(std::move(children)) {} + std::string Repr() const override; + + void Fits(std::stack& agenda, int& width) const override; + bool Sparse(std::stack& agenda, int& num_annotations, + bool& seen_break) const override; + virtual void Format(const FormatAgendum& agendum, + FormatState& state) const override; + + private: + static int TotalNumAnnotations( + absl::Span> children) { + int total = 0; + for (const auto& child : children) { + total += child->num_annotations(); + } + return total; + } + std::vector> children_; +}; + +class BreakDoc final : public Doc { + public: + explicit BreakDoc(std::string text) + : Doc(/*num_annotations=*/0), text_(std::move(text)) {} + std::string Repr() const override; + void Fits(std::stack& agenda, int& width) const override; + bool Sparse(std::stack& agenda, int& num_annotations, + bool& seen_break) const override; + virtual void Format(const FormatAgendum& agendum, + FormatState& state) const override; + + private: + std::string text_; +}; + +class GroupDoc final : public Doc { + public: + explicit GroupDoc(xla::nb_class_ptr child) + : Doc(/*num_annotations=*/child->num_annotations()), + child_(std::move(child)) {} + std::string Repr() const override; + void Fits(std::stack& agenda, int& width) const override; + bool Sparse(std::stack& agenda, int& num_annotations, + bool& seen_break) const override; + virtual void Format(const FormatAgendum& agendum, + FormatState& state) const override; + + private: + xla::nb_class_ptr child_; +}; + +class NestDoc final : public Doc { + public: + explicit NestDoc(int n, xla::nb_class_ptr child) + : Doc(child->num_annotations()), n_(n), child_(std::move(child)) {} + std::string Repr() const override; + void Fits(std::stack& agenda, int& width) const override; + bool Sparse(std::stack& agenda, int& num_annotations, + bool& seen_break) const override; + virtual void Format(const FormatAgendum& agendum, + FormatState& state) const override; + + private: + int n_; + xla::nb_class_ptr child_; +}; + +class SourceMapDoc final : public Doc { + public: + explicit SourceMapDoc(xla::nb_class_ptr child, nb::object source) + : Doc(child->num_annotations()), + child_(std::move(child)), + source_(std::move(source)) {} + std::string Repr() const override; + void Fits(std::stack& agenda, int& width) const override; + bool Sparse(std::stack& agenda, int& num_annotations, + bool& seen_break) const override; + virtual void Format(const FormatAgendum& agendum, + FormatState& state) const override; + + private: + xla::nb_class_ptr child_; + nb::object source_; +}; + +class ColorDoc final : public Doc { + public: + explicit ColorDoc(xla::nb_class_ptr child, + std::optional foreground, + std::optional background, + std::optional intensity) + : Doc(child->num_annotations()), + child_(std::move(child)), + foreground_(foreground), + background_(background), + intensity_(intensity) {} + + std::string Repr() const override; + void Fits(std::stack& agenda, int& width) const override; + bool 
Sparse(std::stack& agenda, int& num_annotations, + bool& seen_break) const override; + virtual void Format(const FormatAgendum& agendum, + FormatState& state) const override; + + private: + xla::nb_class_ptr child_; + std::optional foreground_; + std::optional background_; + std::optional intensity_; +}; + +std::string NilDoc::Repr() const { return "nil"; } + +std::string TextDoc::Repr() const { + if (annotation_.has_value()) { + return absl::StrFormat("text(\"%s\", annotation=\"%s\")", text_, + *annotation_); + } else { + return absl::StrFormat("text(\"%s\")", text_); + } +} + +std::string ConcatDoc::Repr() const { + return absl::StrFormat( + "concat(%s)", + absl::StrJoin(children_, ", ", [](std::string* out, const auto& child) { + absl::StrAppend(out, child->Repr()); + })); +} + +std::string BreakDoc::Repr() const { + return absl::StrFormat("break(\"%s\")", text_); +} + +std::string GroupDoc::Repr() const { + return absl::StrFormat("group(%s)", child_->Repr()); +} + +std::string NestDoc::Repr() const { + return absl::StrFormat("nest(%d, %s)", n_, child_->Repr()); +} + +std::string SourceMapDoc::Repr() const { + return absl::StrFormat("source(%s, %s)", child_->Repr(), + nb::cast(nb::repr(source_))); +} + +std::string ColorDoc::Repr() const { + std::string foreground_str = + foreground_.has_value() ? ColorToString(*foreground_) : "None"; + std::string background_str = + background_.has_value() ? ColorToString(*background_) : "None"; + std::string intensity_str = + intensity_.has_value() ? IntensityToString(*intensity_) : "None"; + return absl::StrFormat("color(%s, %s, %s, %s)", child_->Repr(), + foreground_str, background_str, intensity_str); +} + +// Fits method implementations + +void NilDoc::Fits(std::stack& agenda, int& width) const {} + +void TextDoc::Fits(std::stack& agenda, int& width) const { + width -= text_.size(); +} + +void ConcatDoc::Fits(std::stack& agenda, int& width) const { + for (auto it = children_.rbegin(); it != children_.rend(); ++it) { + agenda.push(it->get()); + } +} + +void BreakDoc::Fits(std::stack& agenda, int& width) const { + width -= static_cast(text_.size()); +} + +void GroupDoc::Fits(std::stack& agenda, int& width) const { + agenda.push(child_.get()); +} + +void NestDoc::Fits(std::stack& agenda, int& width) const { + agenda.push(child_.get()); +} + +void SourceMapDoc::Fits(std::stack& agenda, int& width) const { + agenda.push(child_.get()); +} + +void ColorDoc::Fits(std::stack& agenda, int& width) const { + agenda.push(child_.get()); +} + +bool Fits(const Doc* doc, int width) { + std::stack agenda; + agenda.push(doc); + while (width >= 0 && !agenda.empty()) { + const Doc* doc = agenda.top(); + agenda.pop(); + doc->Fits(agenda, width); + } + return width >= 0; +} + +// Sparse method implementations + +bool NilDoc::Sparse(std::stack& agenda, int& num_annotations, + bool& seen_break) const { + return true; +} + +bool TextDoc::Sparse(std::stack& agenda, int& num_annotations, + bool& seen_break) const { + if (annotation_.has_value()) { + if (num_annotations >= 1 && seen_break) { + return false; + } + num_annotations -= 1; + } + return true; +} + +bool ConcatDoc::Sparse(std::stack& agenda, int& num_annotations, + bool& seen_break) const { + for (auto it = children_.rbegin(); it != children_.rend(); ++it) { + agenda.push(it->get()); + } + return true; +} + +bool BreakDoc::Sparse(std::stack& agenda, int& num_annotations, + bool& seen_break) const { + seen_break = true; + return true; +} + +bool GroupDoc::Sparse(std::stack& agenda, int& num_annotations, + bool& 
seen_break) const { + agenda.push(child_.get()); + return true; +} + +bool NestDoc::Sparse(std::stack& agenda, int& num_annotations, + bool& seen_break) const { + agenda.push(child_.get()); + return true; +} + +bool SourceMapDoc::Sparse(std::stack& agenda, int& num_annotations, + bool& seen_break) const { + agenda.push(child_.get()); + return true; +} + +bool ColorDoc::Sparse(std::stack& agenda, int& num_annotations, + bool& seen_break) const { + agenda.push(child_.get()); + return true; +} + +// Returns true if the doc is sparse, i.e. there are no breaks between +// annotations. +bool Sparse(const Doc* doc) { + if (doc->num_annotations() == 0) { + return true; + } + std::stack agenda; + agenda.push(doc); + int num_annotations = 0; + bool seen_break = false; + while (!agenda.empty()) { + const Doc* doc = agenda.top(); + agenda.pop(); + if (!doc->Sparse(agenda, num_annotations, seen_break)) { + return false; + } + } + return true; +} + +struct ColorState { + Color foreground; + Color background; + Intensity intensity; + + bool operator==(const ColorState& other) const { + return foreground == other.foreground && background == other.background && + intensity == other.intensity; + } + bool operator!=(const ColorState& other) const { return !operator==(other); } +}; + +constexpr ColorState kDefaultColors = + ColorState{Color::kReset, Color::kReset, Intensity::kNormal}; +constexpr ColorState kAnnotationColors = + ColorState{Color::kReset, Color::kReset, Intensity::kDim}; + +enum class BreakMode { kFlat, kBreak }; + +struct FormatAgendum { + int indent; + BreakMode mode; + const Doc* doc; + ColorState color; + nb::object source; +}; + +struct Line { + std::string text; + int width; + std::vector annotations; +}; + +// Format method implementations + +struct FormatState { + int width; + std::stack agenda; + std::string line_text; + int k; + std::vector line_annotations; + std::optional color; + std::optional source_map; + nb::list line_source_map; + int source_start; + nb::object source; + std::vector lines; +}; + +std::string UpdateColor(std::optional& state, + const ColorState& update) { + if (!state.has_value() || *state == update) { + return ""; + } + std::string result = "\033["; + absl::InlinedVector codes; + if (state->foreground != update.foreground) { + codes.push_back(absl::StrCat(static_cast(update.foreground))); + } + if (state->background != update.background) { + codes.push_back(absl::StrCat(static_cast(update.background) + 10)); + } + if (state->intensity != update.intensity) { + codes.push_back(absl::StrCat(static_cast(update.intensity))); + } + absl::StrAppend(&result, absl::StrJoin(codes, ";"), "m"); + state = update; + return result; +} + +void NilDoc::Format(const FormatAgendum& agendum, FormatState& state) const {} + +void TextDoc::Format(const FormatAgendum& agendum, FormatState& state) const { + absl::StrAppend(&state.line_text, UpdateColor(state.color, agendum.color), + text_); + if (annotation_.has_value()) { + state.line_annotations.push_back(*annotation_); + } + state.k += text_.size(); +} + +void ConcatDoc::Format(const FormatAgendum& agendum, FormatState& state) const { + for (auto it = children_.rbegin(); it != children_.rend(); ++it) { + state.agenda.push(FormatAgendum{agendum.indent, agendum.mode, it->get(), + agendum.color, state.source}); + } +} + +void BreakDoc::Format(const FormatAgendum& agendum, FormatState& state) const { + if (agendum.mode == BreakMode::kBreak) { + if (!state.line_annotations.empty()) { + absl::StrAppend(&state.line_text, + 
UpdateColor(state.color, kAnnotationColors)); + } + if (state.source_map.has_value()) { + int pos = state.line_text.size(); + if (state.source_start != pos && state.source.ptr() != nullptr) { + state.line_source_map.append( + nb::make_tuple(state.source_start, pos, state.source)); + } + state.source_map->append(state.line_source_map); + state.line_source_map = nb::list(); + state.source_start = agendum.indent; + } + state.lines.push_back(Line{std::move(state.line_text), state.k, + std::move(state.line_annotations)}); + state.line_text = std::string(agendum.indent, ' '); + state.line_annotations.clear(); + state.k = agendum.indent; + } else { + absl::StrAppend(&state.line_text, UpdateColor(state.color, agendum.color), + text_); + state.k += text_.size(); + } +} + +void GroupDoc::Format(const FormatAgendum& agendum, FormatState& state) const { + // In Lindig's paper, _fits is passed the remainder of the document. + // I'm pretty sure that's a bug and we care only if the current group fits! + bool fits = ::jax::Fits(agendum.doc, state.width - state.k) && + ::jax::Sparse(agendum.doc); + state.agenda.push(FormatAgendum{agendum.indent, + fits ? BreakMode::kFlat : BreakMode::kBreak, + child_.get(), agendum.color, state.source}); +} + +void NestDoc::Format(const FormatAgendum& agendum, FormatState& state) const { + state.agenda.push(FormatAgendum{agendum.indent + n_, agendum.mode, + child_.get(), agendum.color, state.source}); +} + +void SourceMapDoc::Format(const FormatAgendum& agendum, + FormatState& state) const { + state.agenda.push(FormatAgendum{agendum.indent, agendum.mode, child_.get(), + agendum.color, source_}); +} + +void ColorDoc::Format(const FormatAgendum& agendum, FormatState& state) const { + ColorState color = agendum.color; + if (foreground_.has_value()) { + color.foreground = *foreground_; + } + if (background_.has_value()) { + color.background = *background_; + } + if (intensity_.has_value()) { + color.intensity = *intensity_; + } + state.agenda.push(FormatAgendum{agendum.indent, agendum.mode, child_.get(), + color, state.source}); +} + +std::string Format(const Doc* doc, int width, bool use_color, + std::string annotation_prefix, + std::optional source_map) { + FormatState state; + if (use_color) { + state.color = kDefaultColors; + } + state.width = width; + state.source_start = 0; + state.source_map = source_map; + state.agenda.push( + FormatAgendum{0, BreakMode::kBreak, doc, kDefaultColors, nb::object()}); + state.k = 0; + while (!state.agenda.empty()) { + FormatAgendum agendum = state.agenda.top(); + state.agenda.pop(); + if (source_map.has_value() && agendum.source.ptr() != state.source.ptr()) { + int pos = state.line_text.size(); + if (state.source_start != pos && state.source.ptr() != nullptr) { + state.line_source_map.append( + nb::make_tuple(state.source_start, pos, state.source)); + } + state.source = agendum.source; + state.source_start = pos; + } + agendum.doc->Format(agendum, state); + } + if (!state.line_annotations.empty()) { + absl::StrAppend(&state.line_text, + UpdateColor(state.color, kAnnotationColors)); + } + if (state.source_map.has_value()) { + int pos = state.line_text.size(); + if (state.source_start != pos && state.source.ptr() != nullptr) { + state.line_source_map.append( + nb::make_tuple(state.source_start, pos, state.source)); + } + state.source_map->append(state.line_source_map); + } + state.lines.push_back(Line{std::move(state.line_text), state.k, + std::move(state.line_annotations)}); + + int max_width = 0; + for (const auto& line : state.lines) { 
+    max_width = std::max(max_width, line.width);
+  }
+  std::string out =
+      absl::StrJoin(state.lines, "\n", [&](std::string* out, const Line& line) {
+        if (line.annotations.empty()) {
+          absl::StrAppend(out, line.text);
+        } else {
+          absl::StrAppend(out, line.text,
+                          std::string(max_width - line.width, ' '),
+                          annotation_prefix, line.annotations[0]);
+          for (int i = 1; i < line.annotations.size(); ++i) {
+            absl::StrAppend(out, std::string(max_width, ' '), annotation_prefix,
+                            line.annotations[i]);
+          }
+        }
+      });
+  absl::StrAppend(&out, UpdateColor(state.color, kDefaultColors));
+  return out;
+}
+
+NB_MODULE(_pretty_printer, m) {
+  nb::enum_<Color>(m, "Color")
+      .value("BLACK", Color::kBlack)
+      .value("RED", Color::kRed)
+      .value("GREEN", Color::kGreen)
+      .value("YELLOW", Color::kYellow)
+      .value("BLUE", Color::kBlue)
+      .value("MAGENTA", Color::kMagenta)
+      .value("CYAN", Color::kCyan)
+      .value("WHITE", Color::kWhite)
+      .value("RESET", Color::kReset);
+
+  nb::enum_<Intensity>(m, "Intensity")
+      .value("DIM", Intensity::kDim)
+      .value("NORMAL", Intensity::kNormal)
+      .value("BRIGHT", Intensity::kBright);
+
+  nb::class_<Doc>(m, "Doc")
+      .def("__repr__", &Doc::Repr)
+      .def("__add__",
+           [](xla::nb_class_ptr<Doc> self, xla::nb_class_ptr<Doc> other) {
+             return xla::make_nb_class<ConcatDoc>(
+                 std::vector<xla::nb_class_ptr<Doc>>{std::move(self),
+                                                     std::move(other)});
+           })
+      .def("_format", &Format, nb::arg("width"), nb::arg("use_color"),
+           nb::arg("annotation_prefix"), nb::arg("source_map").none());
+
+  nb::class_<NilDoc, Doc>(m, "NilDoc");
+  nb::class_<TextDoc, Doc>(m, "TextDoc");
+  nb::class_<ConcatDoc, Doc>(m, "ConcatDoc");
+  nb::class_<BreakDoc, Doc>(m, "BreakDoc");
+  nb::class_<GroupDoc, Doc>(m, "GroupDoc");
+  nb::class_<NestDoc, Doc>(m, "NestDoc");
+  nb::class_<ColorDoc, Doc>(m, "ColorDoc");
+  nb::class_<SourceMapDoc, Doc>(m, "SourceMapDoc");
+
+  m.def(
+      "nil", []() { return xla::make_nb_class<NilDoc>(); },
+      "An empty document.");
+  m.def(
+      "text",
+      [](std::string text, std::optional<std::string> annotation) {
+        return xla::make_nb_class<TextDoc>(std::move(text),
+                                           std::move(annotation));
+      },
+      nb::arg("text"), nb::arg("annotation").none() = std::nullopt,
+      "Literal text.");
+  m.def(
+      "concat",
+      [](std::vector<xla::nb_class_ptr<Doc>> children) {
+        return xla::make_nb_class<ConcatDoc>(std::move(children));
+      },
+      nb::arg("children"), "Concatenation of documents.");
+  m.def(
+      "brk",
+      [](std::string text) { return xla::make_nb_class<BreakDoc>(text); },
+      nb::arg("text") = std::string(" "),
+      R"(A break.
+
+Prints either as a newline or as `text`, depending on the enclosing group.
+)");
+  m.def(
+      "group",
+      [](xla::nb_class_ptr<Doc> child) {
+        return xla::make_nb_class<GroupDoc>(std::move(child));
+      },
+      R"(Layout alternative groups.
+
+Prints the group with its breaks as their text (typically spaces) if the
+entire group would fit on the line when printed that way. Otherwise, breaks
+inside the group are printed as newlines.
+)");
+  m.def(
+      "nest",
+      [](int n, xla::nb_class_ptr<Doc> child) {
+        return xla::make_nb_class<NestDoc>(n, std::move(child));
+      },
+      "Increases the indentation level by `n`.");
+  m.def(
+      "color",
+      [](xla::nb_class_ptr<Doc> child, std::optional<Color> foreground,
+         std::optional<Color> background, std::optional<Intensity> intensity) {
+        return xla::make_nb_class<ColorDoc>(std::move(child), foreground,
+                                            background, intensity);
+      },
+      nb::arg("child"), nb::arg("foreground").none() = std::nullopt,
+      nb::arg("background").none() = std::nullopt,
+      nb::arg("intensity").none() = std::nullopt,
+      R"(ANSI colors.
+
+Overrides the foreground/background/intensity of the text for the child doc.
+Requires use_colors=True to be set when printing; otherwise does nothing.
+)"); + m.def( + "source_map", + [](xla::nb_class_ptr child, nb::object source) { + return xla::make_nb_class(std::move(child), + std::move(source)); + }, + nb::arg("doc"), nb::arg("source"), + R"(Source mapping. + +A source map associates a region of the pretty-printer's text output with a +source location that produced it. For the purposes of the pretty printer a +``source`` may be any object: we require only that we can compare sources for +equality. A text region to source object mapping can be populated as a side +output of the ``format`` method. +)"); +} + +} // namespace jax diff --git a/jaxlib/cached_py_object.h b/jaxlib/cached_py_object.h new file mode 100644 index 000000000000..b934fa203a44 --- /dev/null +++ b/jaxlib/cached_py_object.h @@ -0,0 +1,61 @@ +/* Copyright 2025 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAX_JAXLIB_CACHED_PY_OBJECT_H_ +#define JAX_JAXLIB_CACHED_PY_OBJECT_H_ + +#include + +#include "absl/functional/function_ref.h" +#include "nanobind/nanobind.h" + +namespace jax { + +// A lock-free thread-safe cache for a single Python object. +// Example use case: caching a hash value in an object. +class CachedPyObject { + public: + CachedPyObject() = default; + ~CachedPyObject() { + PyObject* value = value_.load(); + Py_XDECREF(value); + } + + // Returns the cached value of the object. If the object is not present, + // factory() will be called to create it and the cache will be populated. + // Note: factory() may be called multiple times if used concurrently. The + // returned value will be one of the returned values of factory(). + // Thread-safe. + nanobind::object Get(absl::FunctionRef factory) { + PyObject* v = value_.load(); + if (v) { + return nanobind::borrow(v); + } + nanobind::object new_value = factory(); + if (value_.compare_exchange_strong(v, new_value.inc_ref().ptr())) { + return new_value; + } else { + new_value.dec_ref(); + return nanobind::borrow(v); + } + } + + private: + std::atomic value_ = nullptr; +}; + +} // namespace jax + +#endif // JAX_JAXLIB_CACHED_PY_OBJECT_H_ diff --git a/jaxlib/callback.cc b/jaxlib/callback.cc new file mode 100644 index 000000000000..1262a534961c --- /dev/null +++ b/jaxlib/callback.cc @@ -0,0 +1,173 @@ +/* Copyright 2022 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "jaxlib/callback.h" + +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/string_view.h" // IWYU pragma: keep +#include "jaxlib/python_ref_manager.h" +#include "xla/pjrt/host_callback.h" +#include "xla/pjrt/transpose.h" +#include "xla/primitive_util.h" +#include "xla/python/nb_numpy.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/tsl/python/lib/core/numpy.h" +#include "xla/xla_data.pb.h" + +namespace nb = nanobind; + +namespace xla { + +CpuCallback::~CpuCallback() { + // The destructor may be called without GIL held. In that case, we defer it + // to GlobalPyRefManager. + std::vector objects; + objects.push_back(std::move(callable_)); + for (auto& arg : args_) { + objects.push_back(std::move(arg.dtype)); + } + + GlobalPyRefManager()->AddGarbage(absl::MakeSpan(objects)); +} + +absl::Status CpuCallback::PrepareAndCall(void* result, void** arg_ptrs) { + absl::Span inputs(arg_ptrs, args_.size()); + absl::Span outputs(reinterpret_cast(result), + results_.size()); + + nb::gil_scoped_acquire gil; + nb::tuple args = nb::steal(PyTuple_New(inputs.size())); + for (size_t i = 0; i < inputs.size(); ++i) { + if (args_[i].type == xla::TOKEN) { + PyTuple_SET_ITEM(args.ptr(), i, nb::none().release().ptr()); + } else { + nb_numpy_ndarray array = + nb_numpy_ndarray(args_[i].dtype, args_[i].dims, args_[i].strides, + const_cast(inputs[i])); + array.attr("flags").attr("writeable") = nb::bool_(false); + PyTuple_SET_ITEM(args.ptr(), i, array.release().ptr()); + } + } + + EnterHostCallback(); + absl::StatusOr maybe_result_tuple = Call(std::move(args)); + LeaveHostCallback(); + TF_ASSIGN_OR_RETURN(auto result_tuple, maybe_result_tuple); + + for (size_t i = 0; i < results_.size(); ++i) { + if (results_[i].type == xla::TOKEN) { + continue; + } + nb::object output = + nb::borrow(PyTuple_GetItem(result_tuple.ptr(), i)); + nb_numpy_ndarray array = nb_numpy_ndarray::ensure(std::move(output)); + absl::Span dims( + reinterpret_cast(array.shape()), array.ndim()); + absl::Span strides( + reinterpret_cast(array.strides()), array.ndim()); + if (strides == results_[i].expected_strides) { + std::memcpy(outputs[i], array.data(), results_[i].size_in_bytes); + } else { + xla::TransposePlan::Options options; + options.elem_size_in_bytes = + xla::primitive_util::ByteWidth(results_[i].type); + options.dims = dims; + options.permutation = results_[i].reversed_layout; + options.input_layout = xla::TransposePlan::Striding{strides}; + absl::StatusOr> plan = + transpose_cache_.GetOrCreate(options); + if (!plan.ok()) { + return std::move(plan).status(); + } + plan.value()->Execute(array.data(), outputs[i]); + } + } + + return absl::OkStatus(); +} + +absl::StatusOr CpuCallback::Call(nb::tuple args) { + auto py_error_to_status = [](nb::python_error& e) { + std::string error_message = e.what(); + return absl::InternalError( + absl::StrFormat("CpuCallback error: %s", error_message)); + }; + nb::object result_object; + try { + result_object = callable_(*nb::borrow(args)); + } catch (nb::python_error& e) { + return py_error_to_status(e); + } + if (!PyTuple_Check(result_object.ptr())) { + return absl::InternalError( + absl::StrFormat("CPU 
callback expected a tuple result, got %s", + nb::cast(nb::repr(result_object)))); + } + if (PyTuple_Size(result_object.ptr()) != results_.size()) { + return absl::InternalError( + absl::StrFormat("CPU callback expected a tuple with %d results, got %d", + results_.size(), PyTuple_Size(result_object.ptr()))); + } + nb::tuple result_tuple = nb::cast(result_object); + for (size_t i = 0; i < results_.size(); ++i) { + nb::object output = + nb::borrow(PyTuple_GetItem(result_tuple.ptr(), i)); + if (results_[i].type == xla::TOKEN) { + if (!output.is_none()) { + return absl::InternalError(absl::StrFormat( + "Token output from Python callback should be None, got %s", + nb::cast(nb::repr(output)))); + } + continue; + } + nb_numpy_ndarray array; + try { + array = nb_numpy_ndarray::from_any(output, NPY_ARRAY_ENSUREARRAY); + } catch (nb::python_error& e) { + return py_error_to_status(e); + } + static_assert(sizeof(ssize_t) == sizeof(int64_t), + "Expected ssize_t to be of equal size to int64_t"); + absl::Span dims( + reinterpret_cast(array.shape()), array.ndim()); + if (dims != results_[i].expected_dims) { + return absl::InternalError(absl::StrFormat( + "Mismatched result shape for %d-th return value from CPU callback; " + "expected array with dimensions %s, got %s", + i, absl::StrJoin(results_[i].expected_dims, ","), + absl::StrJoin(dims, ","))); + } + } + return result_tuple; +} + +} // namespace xla diff --git a/jaxlib/callback.h b/jaxlib/callback.h new file mode 100644 index 000000000000..59844ebf73b9 --- /dev/null +++ b/jaxlib/callback.h @@ -0,0 +1,87 @@ +/* Copyright 2022 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_CALLBACK_H_ +#define JAXLIB_CALLBACK_H_ + +#include +#include +#include +#include + +#include "absl/container/inlined_vector.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "nanobind/nanobind.h" +#include "xla/pjrt/transpose.h" +#include "xla/python/nb_numpy.h" +#include "xla/xla_data.pb.h" + +namespace xla { + +class CpuCallback { + public: + struct Arg { + xla::PrimitiveType type; // XLA type + nb_dtype dtype; // NumPy type, for array types. + absl::InlinedVector dims; // Dimensions, for array types. + std::vector strides; // Byte strides, for array types. + size_t size_in_bytes; // Size of the array in bytes. + }; + struct Result { + xla::PrimitiveType type; // XLA type + // Expected output shape, for array types + absl::InlinedVector expected_dims; + // Expected output byte strides, for array types. If the strides do not + // match the output will be transposed into the expected layout. + std::vector expected_strides; + // The desired order of output dimensions in major-to-minor order. + absl::InlinedVector reversed_layout; + // Size of the array in bytes. 
+ size_t size_in_bytes; + }; + + explicit CpuCallback(nanobind::callable callable, std::vector args, + std::vector results) + : callable_(std::move(callable)), + args_(std::move(args)), + results_(std::move(results)), + transpose_cache_(/*capacity=*/16) {} + + ~CpuCallback(); + + const std::vector& args() const { return args_; } + size_t num_args() const { return args_.size(); } + + const std::vector& results() const { return results_; } + size_t num_results() const { return results_.size(); } + void* callback() const { return callable_.ptr(); } + + xla::TransposePlanCache& transpose_cache() { return transpose_cache_; } + + absl::Status PrepareAndCall(void* result, void** arg_ptrs); + + absl::StatusOr Call(nanobind::tuple args); + + private: + nanobind::callable callable_; + std::vector args_; + std::vector results_; + xla::TransposePlanCache transpose_cache_; +}; + +} // namespace xla + +#endif // JAXLIB_CALLBACK_H_ diff --git a/jaxlib/config.cc b/jaxlib/config.cc new file mode 100644 index 000000000000..625a5aa5a319 --- /dev/null +++ b/jaxlib/config.cc @@ -0,0 +1,348 @@ +/* Copyright 2024 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/config.h" + +#include + +#include +#include +#include + +#include "absl/base/thread_annotations.h" +#include "absl/container/flat_hash_set.h" +#include "absl/synchronization/mutex.h" +#include "absl/types/span.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/optional.h" // IWYU pragma: keep +#include "jaxlib/python_ref_manager.h" +#include "xla/tsl/platform/logging.h" + +namespace jax { + +namespace nb = nanobind; + +// Singleton object used to represent "value not set" in thread-local configs. +nb::object UnsetObject() { + return nb::steal(PyObject_CallObject( + reinterpret_cast(&PyBaseObject_Type), nullptr)); +} + +// Each configuration object has: +// * a global value, and +// * a thread-local value. +// When querying the state of a config, the thread-local value is used if it is +// set. Otherwise, the global value is used. + +// This class represents all of the thread-local configuration state for a +// thread. +class ThreadLocalConfigState { + public: + ThreadLocalConfigState(); + ~ThreadLocalConfigState(); + + static ThreadLocalConfigState& Instance() { + thread_local auto state = std::make_unique(); + return *state; + } + + nb::object Get(int key) { + DCHECK_GE(key, 0); + return key >= entries_.size() ? nb::object() : entries_[key]; + } + + void Set(int key, nb::object value); + + private: + friend class GlobalConfigState; + + // These values are accessed in one of two ways: + // * The owning thread reads or writes them, while holding the GIL, or, under + // free-threading, while the owning thread is in ATTACHED gc state. + // * Other threads may read or clear values while performing a garbage + // collection. 
+ // No locking is needed because a GC thread cannot run concurrently with other + // Python threads; even under free-threading Python uses a stop-the-world GC. + std::vector entries_; +}; + +// This class represents all of the global configuration state. +// TODO(phawkins): to support free-threading, we will need to add locking to +// this class. +class GlobalConfigState { + public: + static GlobalConfigState& Instance() { + static auto state = new GlobalConfigState(); + return *state; + } + + nb::object Get(int key) const; + void Set(int key, nb::object value); + + // Adds or removes a thread-local state from the set of thread-local states. + void AddThreadLocalState(ThreadLocalConfigState* state) { + absl::MutexLock lock(&mu_); + thread_local_states_.insert(state); + } + void RemoveThreadLocalState(ThreadLocalConfigState* state) { + absl::MutexLock lock(&mu_); + thread_local_states_.erase(state); + } + + // Python GC helpers. These are called from the tp_traverse and tp_clear + // methods of the Config class. + int tp_traverse(int key, PyObject* self, visitproc visit, void* arg); + int tp_clear(int key, PyObject* self); + + // Returns the singleton object representing "value not set". + const nb::object& unset() const { return unset_; } + + // Returns the set of keys that should be included in the jit key. + absl::Span include_in_jit_key() const { + return include_in_jit_key_; + } + + private: + friend class Config; + + // The set of thread-local states. This is used during garbage collection to + // visit thread-local values. + absl::Mutex mu_; + absl::flat_hash_set thread_local_states_ + ABSL_GUARDED_BY(mu_); + std::vector entries_; + std::vector include_in_jit_key_; + nb::object unset_ = UnsetObject(); +}; + +ThreadLocalConfigState::ThreadLocalConfigState() { + GlobalConfigState::Instance().AddThreadLocalState(this); +} + +ThreadLocalConfigState::~ThreadLocalConfigState() { + // It's important that we remove the thread-local state before we access + // entries_. This ensures that accesses to entries_ are ordered with respect + // any garbage collection. + GlobalConfigState::Instance().RemoveThreadLocalState(this); + // We do not hold the GIL, so we must use deferred destruction. + xla::GlobalPyRefManager()->AddGarbage(absl::MakeSpan(entries_)); +} + +void ThreadLocalConfigState::Set(int key, nb::object value) { + DCHECK_GE(key, 0); + if (key >= entries_.size()) { + entries_.resize(key + 1); + } + std::swap(entries_[key], value); +} + +nb::object GlobalConfigState::Get(int key) const { + DCHECK_GE(key, 0); + DCHECK_LT(key, entries_.size()); + return entries_[key]; +} + +void GlobalConfigState::Set(int key, nb::object value) { + DCHECK_GE(key, 0); + DCHECK_LT(key, entries_.size()); + std::swap(entries_[key], value); +} + +int GlobalConfigState::tp_traverse(int key, PyObject* self, visitproc visit, + void* arg) { + DCHECK_GE(key, 0); + if (key < entries_.size()) { + PyObject* value = entries_[key].ptr(); + Py_VISIT(value); + } + absl::MutexLock lock(&mu_); + for (const auto* state : thread_local_states_) { + if (key < state->entries_.size()) { + PyObject* value = state->entries_[key].ptr(); + Py_VISIT(value); + } + } + return 0; +} + +int GlobalConfigState::tp_clear(int key, PyObject* self) { + if (key < entries_.size()) { + nb::object tmp; + std::swap(entries_[key], tmp); + } + // We destroy the python objects outside of the lock out of an abundance of + // caution. 
+ std::vector to_destroy; + absl::MutexLock lock(&mu_); + to_destroy.reserve(thread_local_states_.size()); + for (auto* state : thread_local_states_) { + if (key < state->entries_.size()) { + nb::object tmp; + std::swap(state->entries_[key], tmp); + to_destroy.push_back(std::move(tmp)); + } + } + return 0; +} + +// A Config object represents a configurable object with both global and +// thread-local state. This class is wrapped using nanobind and exposed to +// Python. +class Config { + public: + Config(nb::object value, bool include_in_jit_key); + + // Returns the thread-local value if it is set, otherwise the global value. + nb::object Get(); + + // Returns the global value. + nb::object GetGlobal(); + + // Sets the global value. + void SetGlobal(nb::object value); + + // Returns the thread-local value. + nb::object GetLocal(); + + // Sets the thread-local value. May be `unset`. + void SetLocal(nb::object value); + + // Swaps the thread-local value with `value`. Returns the previous value. + // Either may be `unset`. + nb::object SwapLocal(nb::object value); + + // This class doesn't actually hold any data, but it's the only type + // known to Python. We pretend that this object owns both the global and any + // thread-local values corresponding to this key. + static int tp_traverse(PyObject* self, visitproc visit, void* arg); + static int tp_clear(PyObject* self); + static PyType_Slot slots_[]; + + private: + int key_; +}; + +Config::Config(nb::object value, bool include_in_jit_key) { + auto& instance = GlobalConfigState::Instance(); + key_ = instance.entries_.size(); + instance.entries_.push_back(std::move(value)); + if (include_in_jit_key) { + instance.include_in_jit_key_.push_back(key_); + } +} + +nb::object Config::GetLocal() { + nb::object result = ThreadLocalConfigState::Instance().Get(key_); + if (!result.is_valid()) { + return GlobalConfigState::Instance().unset(); + } + return result; +} + +nb::object Config::GetGlobal() { + return GlobalConfigState::Instance().Get(key_); +} + +nb::object Config::Get() { + nb::object local = ThreadLocalConfigState::Instance().Get(key_); + if (local.is_valid()) { + return local; + } + return GetGlobal(); +} + +void Config::SetLocal(nb::object value) { + const auto& instance = GlobalConfigState::Instance(); + if (value.ptr() == instance.unset().ptr()) { + value = nb::object(); + } + ThreadLocalConfigState::Instance().Set(key_, std::move(value)); +} + +nb::object Config::SwapLocal(nb::object value) { + const auto& global_instance = GlobalConfigState::Instance(); + auto& instance = ThreadLocalConfigState::Instance(); + auto result = instance.Get(key_); + if (value.ptr() == global_instance.unset().ptr()) { + value = nb::object(); + } + instance.Set(key_, std::move(value)); + if (!result.is_valid()) { + return global_instance.unset(); + } + return result; +} + +void Config::SetGlobal(nb::object value) { + GlobalConfigState::Instance().Set(key_, value); +} + +/* static */ int Config::tp_traverse(PyObject* self, visitproc visit, + void* arg) { + Py_VISIT(Py_TYPE(self)); + if (!nb::inst_ready(self)) { + return 0; + } + Config* c = nb::inst_ptr(self); + // For the purposes of GC, we pretend that this object owns both the global + // and any thread-local values corresponding to this key. 
+ return GlobalConfigState::Instance().tp_traverse(c->key_, self, visit, arg); +} + +/* static */ int Config::tp_clear(PyObject* self) { + Config* c = nb::inst_ptr(self); + return GlobalConfigState::Instance().tp_clear(c->key_, self); +} + +PyType_Slot Config::slots_[] = { + {Py_tp_traverse, reinterpret_cast(Config::tp_traverse)}, + {Py_tp_clear, reinterpret_cast(Config::tp_clear)}, + {0, nullptr}, +}; + +void BuildConfigSubmodule(nanobind::module_& m) { + nb::module_ config_module = m.def_submodule("config", "Config library"); + + config_module.attr("unset") = GlobalConfigState::Instance().unset(); + + nb::class_ config(config_module, "Config", + nb::type_slots(Config::slots_), nb::is_generic()); + config.def(nb::init(), nb::arg("value").none(), + nb::arg("include_in_jit_key") = false); + config.def_prop_ro("value", &Config::Get); + config.def("get_local", &Config::GetLocal); + config.def("get_global", &Config::GetGlobal); + config.def("set_local", &Config::SetLocal, nb::arg("value").none()); + config.def("swap_local", &Config::SwapLocal, nb::arg("value").none()); + config.def("set_global", &Config::SetGlobal, nb::arg("value").none()); +} + +std::vector JitConfigs() { + auto& instance = GlobalConfigState::Instance(); + auto& thread_local_instance = ThreadLocalConfigState::Instance(); + std::vector result; + result.reserve(instance.include_in_jit_key().size()); + for (int i : instance.include_in_jit_key()) { + nb::object local = thread_local_instance.Get(i); + if (local.is_valid()) { + result.push_back(std::move(local)); + } else { + result.push_back(instance.Get(i)); + } + } + return result; +} + +} // namespace jax diff --git a/jaxlib/config.h b/jaxlib/config.h new file mode 100644 index 000000000000..e42673cb66fa --- /dev/null +++ b/jaxlib/config.h @@ -0,0 +1,34 @@ +/* Copyright 2024 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_CONFIG_H_ +#define JAXLIB_CONFIG_H_ + +#include + +// placeholder for index annotation headers +#include "nanobind/nanobind.h" + +namespace jax { + +// Returns the set of configuration values that should be included in the JIT +// cache key. +std::vector JitConfigs(); + +void BuildConfigSubmodule(nanobind::module_& m); + +} // namespace jax + +#endif // JAXLIB_CONFIG_H_ diff --git a/jaxlib/config_test.py b/jaxlib/config_test.py new file mode 100644 index 000000000000..734e9ed78896 --- /dev/null +++ b/jaxlib/config_test.py @@ -0,0 +1,71 @@ +# Copyright 2024 The JAX Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import threading + +from absl.testing import absltest + +from jax.jaxlib import xla_client + +config = xla_client._xla.config + + +class ConfigTest(absltest.TestCase): + + def testBasic(self): + c = config.Config(1) + self.assertEqual(c.value, 1) + self.assertEqual(c.get_global(), 1) + self.assertEqual(c.get_local(), config.unset) + + c.set_global(2) + self.assertEqual(c.value, 2) + self.assertEqual(c.get_global(), 2) + self.assertEqual(c.get_local(), config.unset) + + c.set_local(3) + self.assertEqual(c.value, 3) + self.assertEqual(c.get_global(), 2) + self.assertEqual(c.get_local(), 3) + + c.set_global(4) + self.assertEqual(c.value, 3) + self.assertEqual(c.get_global(), 4) + self.assertEqual(c.get_local(), 3) + + c.set_local(config.unset) + self.assertEqual(c.value, 4) + self.assertEqual(c.get_global(), 4) + self.assertEqual(c.get_local(), config.unset) + + def testThreading(self): + c = config.Config(1) + + def Body(): + for i in range(100): + c.set_local(i) + self.assertEqual(c.get_local(), i) + self.assertEqual(c.get_global(), 1) + self.assertEqual(c.value, i) + + threads = [threading.Thread(target=Body) for _ in range(4)] + for t in threads: + t.start() + for t in threads: + t.join() + + +if __name__ == "__main__": + absltest.main() diff --git a/jaxlib/cpu/BUILD b/jaxlib/cpu/BUILD index 76934df6c37b..349b64b4ce3b 100644 --- a/jaxlib/cpu/BUILD +++ b/jaxlib/cpu/BUILD @@ -42,7 +42,6 @@ cc_library( "@com_google_absl//absl/types:span", "@xla//xla/ffi/api:c_api", "@xla//xla/ffi/api:ffi", - "@xla//xla/service:custom_call_status", ], ) @@ -85,9 +84,41 @@ cc_library( deps = [ ":lapack_kernels", ":lapack_kernels_using_lapack", + ":sparse_kernels", "@xla//xla/ffi/api:c_api", "@xla//xla/ffi/api:ffi", - "@xla//xla/service:custom_call_target_registry", ], alwayslink = 1, ) + +cc_library( + name = "sparse_kernels", + srcs = ["sparse_kernels.cc"], + hdrs = ["sparse_kernels.h"], + deps = [ + "@eigen_archive//:eigen3", + "@xla//xla/ffi/api:ffi", + ], +) + +nanobind_extension( + name = "_sparse", + srcs = ["sparse.cc"], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + enable_stub_generation = False, + features = ["-use_header_modules"], + module_name = "_sparse", + pytype_srcs = [ + "_sparse/__init__.pyi", + ], + deps = [ + ":sparse_kernels", + "//jaxlib:kernel_nanobind_helpers", + "@com_google_absl//absl/base", + "@nanobind", + "@xla//xla/ffi/api:ffi", + ], +) diff --git a/jaxlib/cpu/_lapack/__init__.pyi b/jaxlib/cpu/_lapack/__init__.pyi index 4275d8e48813..f8b9a023b480 100644 --- a/jaxlib/cpu/_lapack/__init__.pyi +++ b/jaxlib/cpu/_lapack/__init__.pyi @@ -17,39 +17,3 @@ from . import eig as eig def initialize() -> None: ... def registrations() -> dict: ... - - -# Old-style LAPACK Workspace Size Queries -def cgesdd_rwork_size(m: int, n: int, compute_uv: int) -> int: ... -def cgesdd_work_size(m: int, n: int, job_opt_compute_uv: bool, job_opt_full_matrices: bool) -> int: ... -def dgesdd_work_size(m: int, n: int, job_opt_compute_uv: bool, job_opt_full_matrices: bool) -> int: ... -def gesdd_iwork_size(m: int, n: int) -> int: ... -def heevd_rwork_size(n: int) -> int: ... -def heevd_work_size(n: int) -> int: ... -def lapack_cgehrd_workspace(lda: int, n: int, ilo: int, ihi: int) -> int: ... -def lapack_cgeqrf_workspace(m: int, n: int) -> int: ... -def lapack_chetrd_workspace(lda: int, n: int) -> int: ... 
-def lapack_cungqr_workspace(m: int, n: int, k: int) -> int: ... -def lapack_dgehrd_workspace(lda: int, n: int, ilo: int, ihi: int) -> int: ... -def lapack_dgeqrf_workspace(m: int, n: int) -> int: ... -def lapack_dorgqr_workspace(m: int, n: int, k: int) -> int: ... -def lapack_dsytrd_workspace(lda: int, n: int) -> int: ... -def lapack_sgehrd_workspace(lda: int, n: int, ilo: int, ihi: int) -> int: ... -def lapack_sgeqrf_workspace(m: int, n: int) -> int: ... -def lapack_sorgqr_workspace(m: int, n: int, k: int) -> int: ... -def lapack_ssytrd_workspace(lda: int, n: int) -> int: ... -def lapack_zgehrd_workspace(lda: int, n: int, ilo: int, ihi: int) -> int: ... -def lapack_zgeqrf_workspace(m: int, n: int) -> int: ... -def lapack_zhetrd_workspace(lda: int, n: int) -> int: ... -def lapack_zungqr_workspace(m: int, n: int, k: int) -> int: ... -def sgesdd_work_size(m: int, n: int, job_opt_compute_uv: bool, job_opt_full_matrices: bool) -> int: ... -def syevd_iwork_size(n: int) -> int: ... -def syevd_work_size(n: int) -> int: ... -def zgesdd_work_size(m: int, n: int, job_opt_compute_uv: bool, job_opt_full_matrices: bool) -> int: ... - - -# FFI Kernel LAPACK Workspace Size Queries -def lapack_cungqr_workspace_ffi(m: int, n: int, k: int) -> int: ... -def lapack_dorgqr_workspace_ffi(m: int, n: int, k: int) -> int: ... -def lapack_sorgqr_workspace_ffi(m: int, n: int, k: int) -> int: ... -def lapack_zungqr_workspace_ffi(m: int, n: int, k: int) -> int: ... diff --git a/jaxlib/cpu/_sparse/__init__.pyi b/jaxlib/cpu/_sparse/__init__.pyi new file mode 100644 index 000000000000..a82f83b267b7 --- /dev/null +++ b/jaxlib/cpu/_sparse/__init__.pyi @@ -0,0 +1,15 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +def registrations() -> dict: ... diff --git a/jaxlib/cpu/cpu_kernels.cc b/jaxlib/cpu/cpu_kernels.cc index 6ed42496f2f2..3d75a02f6ae3 100644 --- a/jaxlib/cpu/cpu_kernels.cc +++ b/jaxlib/cpu/cpu_kernels.cc @@ -19,9 +19,9 @@ limitations under the License. #include #include "jaxlib/cpu/lapack_kernels.h" +#include "jaxlib/cpu/sparse_kernels.h" #include "xla/ffi/api/c_api.h" #include "xla/ffi/api/ffi.h" -#include "xla/service/custom_call_target_registry.h" #define JAX_CPU_REGISTER_HANDLER(name) \ XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), #name, "Host", name); @@ -29,94 +29,6 @@ limitations under the License. namespace jax { namespace { -// Old-style kernels -// TODO(b/344892332): To be removed after the 6M compatibility period is over. 
- -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("blas_strsm", Trsm::Kernel, - "Host"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("blas_dtrsm", Trsm::Kernel, - "Host"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("blas_ctrsm", - Trsm>::Kernel, - "Host"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("blas_ztrsm", - Trsm>::Kernel, - "Host"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("lapack_sgetrf", Getrf::Kernel, - "Host"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("lapack_dgetrf", Getrf::Kernel, - "Host"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("lapack_cgetrf", - Getrf>::Kernel, - "Host"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("lapack_zgetrf", - Getrf>::Kernel, - "Host"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("lapack_sgeqrf", Geqrf::Kernel, - "Host"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("lapack_dgeqrf", Geqrf::Kernel, - "Host"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("lapack_cgeqrf", - Geqrf>::Kernel, - "Host"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("lapack_zgeqrf", - Geqrf>::Kernel, - "Host"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("lapack_sorgqr", Orgqr::Kernel, - "Host"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("lapack_dorgqr", Orgqr::Kernel, - "Host"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("lapack_cungqr", - Orgqr>::Kernel, - "Host"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("lapack_zungqr", - Orgqr>::Kernel, - "Host"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("lapack_spotrf", Potrf::Kernel, - "Host"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("lapack_dpotrf", Potrf::Kernel, - "Host"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("lapack_cpotrf", - Potrf>::Kernel, - "Host"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("lapack_zpotrf", - Potrf>::Kernel, - "Host"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("lapack_sgesdd", - RealGesdd::Kernel, "Host"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("lapack_dgesdd", - RealGesdd::Kernel, "Host"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM( - "lapack_cgesdd", ComplexGesdd>::Kernel, "Host"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM( - "lapack_zgesdd", ComplexGesdd>::Kernel, "Host"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("lapack_ssyevd", - RealSyevd::Kernel, "Host"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("lapack_dsyevd", - RealSyevd::Kernel, "Host"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM( - "lapack_cheevd", ComplexHeevd>::Kernel, "Host"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM( - "lapack_zheevd", ComplexHeevd>::Kernel, "Host"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("lapack_sgeev", - RealGeev::Kernel, "Host"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("lapack_dgeev", - RealGeev::Kernel, "Host"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM( - "lapack_cgeev", ComplexGeev>::Kernel, "Host"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM( - "lapack_zgeev", ComplexGeev>::Kernel, "Host"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("lapack_sgees", - RealGees::Kernel, "Host"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("lapack_dgees", - RealGees::Kernel, "Host"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM( - "lapack_cgees", ComplexGees>::Kernel, "Host"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM( - "lapack_zgees", ComplexGees>::Kernel, "Host"); - -// FFI Kernels - JAX_CPU_REGISTER_HANDLER(lapack_strsm_ffi); JAX_CPU_REGISTER_HANDLER(lapack_dtrsm_ffi); JAX_CPU_REGISTER_HANDLER(lapack_ctrsm_ffi); @@ -174,6 +86,8 @@ JAX_CPU_REGISTER_HANDLER(lapack_dgtsv_ffi); JAX_CPU_REGISTER_HANDLER(lapack_cgtsv_ffi); JAX_CPU_REGISTER_HANDLER(lapack_zgtsv_ffi); +JAX_CPU_REGISTER_HANDLER(cpu_csr_sparse_dense_ffi); + #undef JAX_CPU_REGISTER_HANDLER 
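With the old-style custom-call targets removed above, the remaining FFI handlers reach XLA through the registrations() dicts exported by the extension modules. A hedged loader sketch follows; the module paths (jaxlib.cpu._lapack, jaxlib.cpu._sparse), the xla_client import location, and the api_version convention are assumptions based on the BUILD targets above, not taken from this change.

    from jaxlib.cpu import _lapack, _sparse
    from jax._src.lib import xla_client

    for module in (_lapack, _sparse):
        for name, target in module.registrations().items():
            # FFI handlers (the "*_ffi" entries) use the typed-FFI API version.
            api_version = 1 if name.endswith("_ffi") else 0
            xla_client.register_custom_call_target(
                name, target, platform="cpu", api_version=api_version)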
} // namespace diff --git a/jaxlib/cpu/lapack.cc b/jaxlib/cpu/lapack.cc index c104019777e5..b9c92210f311 100644 --- a/jaxlib/cpu/lapack.cc +++ b/jaxlib/cpu/lapack.cc @@ -13,10 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include - -#include "nanobind/nanobind.h" #include "absl/base/call_once.h" +#include "nanobind/nanobind.h" #include "jaxlib/cpu/lapack_kernels.h" #include "jaxlib/kernel_nanobind_helpers.h" @@ -45,10 +43,6 @@ void GetLapackKernelsFromScipy() { return nb::cast(blas_capi[name]).data(); }; - AssignKernelFn>(blas_ptr("strsm")); - AssignKernelFn>(blas_ptr("dtrsm")); - AssignKernelFn>>(blas_ptr("ctrsm")); - AssignKernelFn>>(blas_ptr("ztrsm")); AssignKernelFn>(blas_ptr("strsm")); AssignKernelFn>(blas_ptr("dtrsm")); AssignKernelFn>(blas_ptr("ctrsm")); @@ -58,19 +52,11 @@ void GetLapackKernelsFromScipy() { auto lapack_ptr = [&](const char* name) { return nb::cast(lapack_capi[name]).data(); }; - AssignKernelFn>(lapack_ptr("sgetrf")); - AssignKernelFn>(lapack_ptr("dgetrf")); - AssignKernelFn>>(lapack_ptr("cgetrf")); - AssignKernelFn>>(lapack_ptr("zgetrf")); AssignKernelFn>(lapack_ptr("sgetrf")); AssignKernelFn>(lapack_ptr("dgetrf")); AssignKernelFn>(lapack_ptr("cgetrf")); AssignKernelFn>(lapack_ptr("zgetrf")); - AssignKernelFn>(lapack_ptr("sgeqrf")); - AssignKernelFn>(lapack_ptr("dgeqrf")); - AssignKernelFn>>(lapack_ptr("cgeqrf")); - AssignKernelFn>>(lapack_ptr("zgeqrf")); AssignKernelFn>(lapack_ptr("sgeqrf")); AssignKernelFn>(lapack_ptr("dgeqrf")); AssignKernelFn>(lapack_ptr("cgeqrf")); @@ -85,28 +71,16 @@ void GetLapackKernelsFromScipy() { AssignKernelFn>( lapack_ptr("zgeqp3")); - AssignKernelFn>(lapack_ptr("sorgqr")); - AssignKernelFn>(lapack_ptr("dorgqr")); - AssignKernelFn>>(lapack_ptr("cungqr")); - AssignKernelFn>>(lapack_ptr("zungqr")); AssignKernelFn>(lapack_ptr("sorgqr")); AssignKernelFn>(lapack_ptr("dorgqr")); AssignKernelFn>(lapack_ptr("cungqr")); AssignKernelFn>(lapack_ptr("zungqr")); - AssignKernelFn>(lapack_ptr("spotrf")); - AssignKernelFn>(lapack_ptr("dpotrf")); - AssignKernelFn>>(lapack_ptr("cpotrf")); - AssignKernelFn>>(lapack_ptr("zpotrf")); AssignKernelFn>(lapack_ptr("spotrf")); AssignKernelFn>(lapack_ptr("dpotrf")); AssignKernelFn>(lapack_ptr("cpotrf")); AssignKernelFn>(lapack_ptr("zpotrf")); - AssignKernelFn>(lapack_ptr("sgesdd")); - AssignKernelFn>(lapack_ptr("dgesdd")); - AssignKernelFn>>(lapack_ptr("cgesdd")); - AssignKernelFn>>(lapack_ptr("zgesdd")); AssignKernelFn>(lapack_ptr("sgesdd")); AssignKernelFn>(lapack_ptr("dgesdd")); AssignKernelFn>(lapack_ptr("cgesdd")); @@ -116,10 +90,6 @@ void GetLapackKernelsFromScipy() { AssignKernelFn>(lapack_ptr("cgesvd")); AssignKernelFn>(lapack_ptr("zgesvd")); - AssignKernelFn>(lapack_ptr("ssyevd")); - AssignKernelFn>(lapack_ptr("dsyevd")); - AssignKernelFn>>(lapack_ptr("cheevd")); - AssignKernelFn>>(lapack_ptr("zheevd")); AssignKernelFn>( lapack_ptr("ssyevd")); AssignKernelFn>( @@ -129,10 +99,6 @@ void GetLapackKernelsFromScipy() { AssignKernelFn>( lapack_ptr("zheevd")); - AssignKernelFn>(lapack_ptr("sgeev")); - AssignKernelFn>(lapack_ptr("dgeev")); - AssignKernelFn>>(lapack_ptr("cgeev")); - AssignKernelFn>>(lapack_ptr("zgeev")); AssignKernelFn>(lapack_ptr("sgeev")); AssignKernelFn>(lapack_ptr("dgeev")); AssignKernelFn>( @@ -140,10 +106,6 @@ void GetLapackKernelsFromScipy() { AssignKernelFn>( lapack_ptr("zgeev")); - AssignKernelFn>(lapack_ptr("sgees")); - 
AssignKernelFn>(lapack_ptr("dgees")); - AssignKernelFn>>(lapack_ptr("cgees")); - AssignKernelFn>>(lapack_ptr("zgees")); AssignKernelFn>(lapack_ptr("sgees")); AssignKernelFn>(lapack_ptr("dgees")); AssignKernelFn>( @@ -151,10 +113,6 @@ void GetLapackKernelsFromScipy() { AssignKernelFn>( lapack_ptr("zgees")); - AssignKernelFn>(lapack_ptr("sgehrd")); - AssignKernelFn>(lapack_ptr("dgehrd")); - AssignKernelFn>>(lapack_ptr("cgehrd")); - AssignKernelFn>>(lapack_ptr("zgehrd")); AssignKernelFn>( lapack_ptr("sgehrd")); AssignKernelFn>( @@ -164,10 +122,6 @@ void GetLapackKernelsFromScipy() { AssignKernelFn>( lapack_ptr("zgehrd")); - AssignKernelFn>(lapack_ptr("ssytrd")); - AssignKernelFn>(lapack_ptr("dsytrd")); - AssignKernelFn>>(lapack_ptr("chetrd")); - AssignKernelFn>>(lapack_ptr("zhetrd")); AssignKernelFn>(lapack_ptr("ssytrd")); AssignKernelFn>(lapack_ptr("dsytrd")); AssignKernelFn>(lapack_ptr("chetrd")); @@ -182,74 +136,6 @@ void GetLapackKernelsFromScipy() { nb::dict Registrations() { nb::dict dict; - dict["blas_strsm"] = EncapsulateFunction(Trsm::Kernel); - dict["blas_dtrsm"] = EncapsulateFunction(Trsm::Kernel); - dict["blas_ctrsm"] = EncapsulateFunction(Trsm>::Kernel); - dict["blas_ztrsm"] = EncapsulateFunction(Trsm>::Kernel); - dict["lapack_sgetrf"] = EncapsulateFunction(Getrf::Kernel); - dict["lapack_dgetrf"] = EncapsulateFunction(Getrf::Kernel); - dict["lapack_cgetrf"] = - EncapsulateFunction(Getrf>::Kernel); - dict["lapack_zgetrf"] = - EncapsulateFunction(Getrf>::Kernel); - dict["lapack_sgeqrf"] = EncapsulateFunction(Geqrf::Kernel); - dict["lapack_dgeqrf"] = EncapsulateFunction(Geqrf::Kernel); - dict["lapack_cgeqrf"] = - EncapsulateFunction(Geqrf>::Kernel); - dict["lapack_zgeqrf"] = - EncapsulateFunction(Geqrf>::Kernel); - dict["lapack_sorgqr"] = EncapsulateFunction(Orgqr::Kernel); - dict["lapack_dorgqr"] = EncapsulateFunction(Orgqr::Kernel); - dict["lapack_cungqr"] = - EncapsulateFunction(Orgqr>::Kernel); - dict["lapack_zungqr"] = - EncapsulateFunction(Orgqr>::Kernel); - dict["lapack_spotrf"] = EncapsulateFunction(Potrf::Kernel); - dict["lapack_dpotrf"] = EncapsulateFunction(Potrf::Kernel); - dict["lapack_cpotrf"] = - EncapsulateFunction(Potrf>::Kernel); - dict["lapack_zpotrf"] = - EncapsulateFunction(Potrf>::Kernel); - dict["lapack_sgesdd"] = EncapsulateFunction(RealGesdd::Kernel); - dict["lapack_dgesdd"] = EncapsulateFunction(RealGesdd::Kernel); - dict["lapack_cgesdd"] = - EncapsulateFunction(ComplexGesdd>::Kernel); - dict["lapack_zgesdd"] = - EncapsulateFunction(ComplexGesdd>::Kernel); - dict["lapack_ssyevd"] = EncapsulateFunction(RealSyevd::Kernel); - dict["lapack_dsyevd"] = EncapsulateFunction(RealSyevd::Kernel); - dict["lapack_cheevd"] = - EncapsulateFunction(ComplexHeevd>::Kernel); - dict["lapack_zheevd"] = - EncapsulateFunction(ComplexHeevd>::Kernel); - dict["lapack_sgeev"] = EncapsulateFunction(RealGeev::Kernel); - dict["lapack_dgeev"] = EncapsulateFunction(RealGeev::Kernel); - dict["lapack_cgeev"] = - EncapsulateFunction(ComplexGeev>::Kernel); - dict["lapack_zgeev"] = - EncapsulateFunction(ComplexGeev>::Kernel); - - dict["lapack_sgees"] = EncapsulateFunction(RealGees::Kernel); - dict["lapack_dgees"] = EncapsulateFunction(RealGees::Kernel); - dict["lapack_cgees"] = - EncapsulateFunction(ComplexGees>::Kernel); - dict["lapack_zgees"] = - EncapsulateFunction(ComplexGees>::Kernel); - - dict["lapack_sgehrd"] = EncapsulateFunction(Gehrd::Kernel); - dict["lapack_dgehrd"] = EncapsulateFunction(Gehrd::Kernel); - dict["lapack_cgehrd"] = - EncapsulateFunction(Gehrd>::Kernel); - 
dict["lapack_zgehrd"] = - EncapsulateFunction(Gehrd>::Kernel); - - dict["lapack_ssytrd"] = EncapsulateFunction(Sytrd::Kernel); - dict["lapack_dsytrd"] = EncapsulateFunction(Sytrd::Kernel); - dict["lapack_chetrd"] = - EncapsulateFunction(Sytrd>::Kernel); - dict["lapack_zhetrd"] = - EncapsulateFunction(Sytrd>::Kernel); - dict["lapack_strsm_ffi"] = EncapsulateFunction(lapack_strsm_ffi); dict["lapack_dtrsm_ffi"] = EncapsulateFunction(lapack_dtrsm_ffi); dict["lapack_ctrsm_ffi"] = EncapsulateFunction(lapack_ctrsm_ffi); @@ -335,73 +221,6 @@ NB_MODULE(_lapack, m) { nb::enum_(schur, "Sort") .value("kNoSortEigenvalues", schur::Sort::kNoSortEigenvalues) .value("kSortEigenvalues", schur::Sort::kSortEigenvalues); - - // Old-style LAPACK Workspace Size Queries - m.def("lapack_sgeqrf_workspace", &Geqrf::Workspace, nb::arg("m"), - nb::arg("n")); - m.def("lapack_dgeqrf_workspace", &Geqrf::Workspace, nb::arg("m"), - nb::arg("n")); - m.def("lapack_cgeqrf_workspace", &Geqrf>::Workspace, - nb::arg("m"), nb::arg("n")); - m.def("lapack_zgeqrf_workspace", &Geqrf>::Workspace, - nb::arg("m"), nb::arg("n")); - m.def("lapack_sorgqr_workspace", &Orgqr::Workspace, nb::arg("m"), - nb::arg("n"), nb::arg("k")); - m.def("lapack_dorgqr_workspace", &Orgqr::Workspace, nb::arg("m"), - nb::arg("n"), nb::arg("k")); - m.def("lapack_cungqr_workspace", &Orgqr>::Workspace, - nb::arg("m"), nb::arg("n"), nb::arg("k")); - m.def("lapack_zungqr_workspace", &Orgqr>::Workspace, - nb::arg("m"), nb::arg("n"), nb::arg("k")); - m.def("gesdd_iwork_size", &GesddIworkSize, nb::arg("m"), nb::arg("n")); - m.def("sgesdd_work_size", &RealGesdd::Workspace, nb::arg("m"), - nb::arg("n"), nb::arg("job_opt_compute_uv"), - nb::arg("job_opt_full_matrices")); - m.def("dgesdd_work_size", &RealGesdd::Workspace, nb::arg("m"), - nb::arg("n"), nb::arg("job_opt_compute_uv"), - nb::arg("job_opt_full_matrices")); - m.def("cgesdd_rwork_size", &ComplexGesddRworkSize, nb::arg("m"), nb::arg("n"), - nb::arg("compute_uv")); - m.def("cgesdd_work_size", &ComplexGesdd>::Workspace, - nb::arg("m"), nb::arg("n"), nb::arg("job_opt_compute_uv"), - nb::arg("job_opt_full_matrices")); - m.def("zgesdd_work_size", &ComplexGesdd>::Workspace, - nb::arg("m"), nb::arg("n"), nb::arg("job_opt_compute_uv"), - nb::arg("job_opt_full_matrices")); - m.def("syevd_work_size", &SyevdWorkSize, nb::arg("n")); - m.def("syevd_iwork_size", &SyevdIworkSize, nb::arg("n")); - m.def("heevd_work_size", &HeevdWorkSize, nb::arg("n")); - m.def("heevd_rwork_size", &HeevdRworkSize, nb::arg("n")); - - m.def("lapack_sgehrd_workspace", &Gehrd::Workspace, nb::arg("lda"), - nb::arg("n"), nb::arg("ilo"), nb::arg("ihi")); - m.def("lapack_dgehrd_workspace", &Gehrd::Workspace, nb::arg("lda"), - nb::arg("n"), nb::arg("ilo"), nb::arg("ihi")); - m.def("lapack_cgehrd_workspace", &Gehrd>::Workspace, - nb::arg("lda"), nb::arg("n"), nb::arg("ilo"), nb::arg("ihi")); - m.def("lapack_zgehrd_workspace", &Gehrd>::Workspace, - nb::arg("lda"), nb::arg("n"), nb::arg("ilo"), nb::arg("ihi")); - m.def("lapack_ssytrd_workspace", &Sytrd::Workspace, nb::arg("lda"), - nb::arg("n")); - m.def("lapack_dsytrd_workspace", &Sytrd::Workspace, nb::arg("lda"), - nb::arg("n")); - m.def("lapack_chetrd_workspace", &Sytrd>::Workspace, - nb::arg("lda"), nb::arg("n")); - m.def("lapack_zhetrd_workspace", &Sytrd>::Workspace, - nb::arg("lda"), nb::arg("n")); - // FFI Kernel LAPACK Workspace Size Queries - m.def("lapack_sorgqr_workspace_ffi", - &OrthogonalQr::GetWorkspaceSize, nb::arg("m"), - nb::arg("n"), nb::arg("k")); - m.def("lapack_dorgqr_workspace_ffi", - 
&OrthogonalQr::GetWorkspaceSize, nb::arg("m"), - nb::arg("n"), nb::arg("k")); - m.def("lapack_cungqr_workspace_ffi", - &OrthogonalQr::GetWorkspaceSize, nb::arg("m"), - nb::arg("n"), nb::arg("k")); - m.def("lapack_zungqr_workspace_ffi", - &OrthogonalQr::GetWorkspaceSize, nb::arg("m"), - nb::arg("n"), nb::arg("k")); } } // namespace diff --git a/jaxlib/cpu/lapack_kernels.cc b/jaxlib/cpu/lapack_kernels.cc index ddc93261eeb5..4ec8a73801a6 100644 --- a/jaxlib/cpu/lapack_kernels.cc +++ b/jaxlib/cpu/lapack_kernels.cc @@ -18,14 +18,11 @@ limitations under the License. #include #include #include -#include #include -#include #include -#include #include #include -#include +#include #include "absl/algorithm/container.h" #include "absl/base/dynamic_annotations.h" @@ -34,45 +31,25 @@ limitations under the License. #include "jaxlib/ffi_helpers.h" #include "xla/ffi/api/c_api.h" #include "xla/ffi/api/ffi.h" -#include "xla/service/custom_call_status.h" static_assert(sizeof(jax::lapack_int) == sizeof(int32_t), "Expected LAPACK integers to be 32-bit"); namespace ffi = xla::ffi; -#define REGISTER_CHAR_ENUM_ATTR_DECODING(type) \ - std::optional xla::ffi::AttrDecoding::Decode( \ - XLA_FFI_AttrType attr_type, void* attr, DiagnosticEngine& diagnostic) { \ - if (attr_type != XLA_FFI_AttrType_SCALAR) [[unlikely]] { \ - return diagnostic.Emit("Wrong attribute type: expected ") \ - << XLA_FFI_AttrType_SCALAR << " but got" << attr_type; \ - } \ - auto* scalar = reinterpret_cast(attr); \ - if (scalar->dtype != XLA_FFI_DataType_U8) [[unlikely]] { \ - return diagnostic.Emit("Wrong scalar data type: expected ") \ - << XLA_FFI_DataType_U8 << " but got " << scalar->dtype; \ - } \ - auto underlying = \ - *reinterpret_cast*>(scalar->value); \ - return static_cast(underlying); \ - } - -REGISTER_CHAR_ENUM_ATTR_DECODING(jax::MatrixParams::Side); -REGISTER_CHAR_ENUM_ATTR_DECODING(jax::MatrixParams::Transpose); -REGISTER_CHAR_ENUM_ATTR_DECODING(jax::MatrixParams::Diag); -REGISTER_CHAR_ENUM_ATTR_DECODING(jax::MatrixParams::UpLo); -REGISTER_CHAR_ENUM_ATTR_DECODING(jax::svd::ComputationMode); -REGISTER_CHAR_ENUM_ATTR_DECODING(jax::eig::ComputationMode); -REGISTER_CHAR_ENUM_ATTR_DECODING(jax::schur::ComputationMode); -REGISTER_CHAR_ENUM_ATTR_DECODING(jax::schur::Sort); - -#undef REGISTER_CHAR_ENUM_ATTR_DECODING +XLA_FFI_REGISTER_ENUM_ATTR_DECODING(jax::MatrixParams::Side); +XLA_FFI_REGISTER_ENUM_ATTR_DECODING(jax::MatrixParams::Transpose); +XLA_FFI_REGISTER_ENUM_ATTR_DECODING(jax::MatrixParams::Diag); +XLA_FFI_REGISTER_ENUM_ATTR_DECODING(jax::MatrixParams::UpLo); +XLA_FFI_REGISTER_ENUM_ATTR_DECODING(jax::svd::ComputationMode); +XLA_FFI_REGISTER_ENUM_ATTR_DECODING(jax::eig::ComputationMode); +XLA_FFI_REGISTER_ENUM_ATTR_DECODING(jax::schur::ComputationMode); +XLA_FFI_REGISTER_ENUM_ATTR_DECODING(jax::schur::Sort); namespace jax { template -inline T CastNoOverflow(int64_t value, const std::string& source = __FILE__) { +inline T CastNoOverflow(int64_t value, std::string_view source = __FILE__) { auto result = MaybeCastNoOverflow(value, source); if (!result.ok()) { throw std::overflow_error{std::string(result.status().message())}; @@ -90,67 +67,12 @@ void CopyIfDiffBuffer(ffi::Buffer x, ffi::ResultBuffer x_out) { //== Triangular System Solver ==// -// lapack trsm - -template -typename Trsm::FnType* Trsm::fn = nullptr; - -template -void Trsm::Kernel(void* out, void** data, XlaCustomCallStatus*) { - int32_t left_side = *reinterpret_cast(data[0]); - int32_t lower = *reinterpret_cast(data[1]); - int32_t trans_a = *reinterpret_cast(data[2]); - 
int32_t diag = *reinterpret_cast(data[3]); - int m = *reinterpret_cast(data[4]); - int n = *reinterpret_cast(data[5]); - int batch = *reinterpret_cast(data[6]); - T* alpha = reinterpret_cast(data[7]); - T* a = reinterpret_cast(data[8]); - T* b = reinterpret_cast(data[9]); - - T* x = reinterpret_cast(out); - if (x != b) { - std::memcpy(x, b, - static_cast(batch) * static_cast(m) * - static_cast(n) * sizeof(T)); - } - - char cside = left_side ? 'L' : 'R'; - char cuplo = lower ? 'L' : 'U'; - char ctransa = 'N'; - if (trans_a == 1) { - ctransa = 'T'; - } else if (trans_a == 2) { - ctransa = 'C'; - } - char cdiag = diag ? 'U' : 'N'; - int lda = left_side ? m : n; - int ldb = m; - - int64_t x_plus = static_cast(m) * static_cast(n); - int64_t a_plus = static_cast(lda) * static_cast(lda); - - for (int i = 0; i < batch; ++i) { - fn(&cside, &cuplo, &ctransa, &cdiag, &m, &n, alpha, a, &lda, x, &ldb); - x += x_plus; - a += a_plus; - } -} - -template struct Trsm; -template struct Trsm; -template struct Trsm>; -template struct Trsm>; - -// FFI Kernel - template ffi::Error TriMatrixEquationSolver::Kernel( ffi::Buffer x, ffi::Buffer y, // TODO(b/397715595): Remove RemainingArgs no earlier than 180 days after // the release of JAX 0.5.2. - ffi::RemainingArgs, - ffi::ResultBuffer y_out, MatrixParams::Side side, + ffi::RemainingArgs, ffi::ResultBuffer y_out, MatrixParams::Side side, MatrixParams::UpLo uplo, MatrixParams::Transpose trans_x, MatrixParams::Diag diag) { CopyIfDiffBuffer(y, y_out); @@ -189,42 +111,6 @@ template struct TriMatrixEquationSolver; //== LU Decomposition ==// -// lapack getrf - -template -typename Getrf::FnType* Getrf::fn = nullptr; - -template -void Getrf::Kernel(void* out_tuple, void** data, XlaCustomCallStatus*) { - int b = *(reinterpret_cast(data[0])); - int m = *(reinterpret_cast(data[1])); - int n = *(reinterpret_cast(data[2])); - const T* a_in = reinterpret_cast(data[3]); - - void** out = reinterpret_cast(out_tuple); - T* a_out = reinterpret_cast(out[0]); - int* ipiv = reinterpret_cast(out[1]); - int* info = reinterpret_cast(out[2]); - if (a_out != a_in) { - std::memcpy(a_out, a_in, - static_cast(b) * static_cast(m) * - static_cast(n) * sizeof(T)); - } - for (int i = 0; i < b; ++i) { - fn(&m, &n, a_out, &m, ipiv, info); - a_out += static_cast(m) * static_cast(n); - ipiv += std::min(m, n); - ++info; - } -} - -template struct Getrf; -template struct Getrf; -template struct Getrf>; -template struct Getrf>; - -// FFI Kernel - template ffi::Error LuDecomposition::Kernel( ffi::Buffer x, ffi::ResultBuffer x_out, @@ -261,55 +147,6 @@ template struct LuDecomposition; //== QR Factorization ==// -// lapack geqrf - -template -typename Geqrf::FnType* Geqrf::fn = nullptr; - -template -void Geqrf::Kernel(void* out_tuple, void** data, XlaCustomCallStatus*) { - int b = *(reinterpret_cast(data[0])); - int m = *(reinterpret_cast(data[1])); - int n = *(reinterpret_cast(data[2])); - int lwork = *(reinterpret_cast(data[3])); - const T* a_in = reinterpret_cast(data[4]); - - void** out = reinterpret_cast(out_tuple); - T* a_out = reinterpret_cast(out[0]); - T* tau = reinterpret_cast(out[1]); - int* info = reinterpret_cast(out[2]); - T* work = reinterpret_cast(out[3]); - - if (a_out != a_in) { - std::memcpy(a_out, a_in, - static_cast(b) * static_cast(m) * - static_cast(n) * sizeof(T)); - } - - for (int i = 0; i < b; ++i) { - fn(&m, &n, a_out, &m, tau, work, &lwork, info); - a_out += static_cast(m) * static_cast(n); - tau += std::min(m, n); - ++info; - } -} - -template -int64_t Geqrf::Workspace(lapack_int 
m, lapack_int n) { - T work = 0; - lapack_int lwork = -1; - lapack_int info = 0; - fn(&m, &n, nullptr, &m, nullptr, &work, &lwork, &info); - return info == 0 ? static_cast(std::real(work)) : -1; -} - -template struct Geqrf; -template struct Geqrf; -template struct Geqrf>; -template struct Geqrf>; - -// FFI Kernel - template ffi::Error QrFactorization::Kernel(ffi::Buffer x, ffi::ResultBuffer x_out, @@ -430,56 +267,6 @@ template struct PivotingQrFactorization; //== Orthogonal QR ==// //== Computes orthogonal matrix Q from QR Decomposition ==// -// lapack orgqr - -template -typename Orgqr::FnType* Orgqr::fn = nullptr; - -template -void Orgqr::Kernel(void* out_tuple, void** data, XlaCustomCallStatus*) { - int b = *(reinterpret_cast(data[0])); - int m = *(reinterpret_cast(data[1])); - int n = *(reinterpret_cast(data[2])); - int k = *(reinterpret_cast(data[3])); - int lwork = *(reinterpret_cast(data[4])); - const T* a_in = reinterpret_cast(data[5]); - T* tau = reinterpret_cast(data[6]); - - void** out = reinterpret_cast(out_tuple); - T* a_out = reinterpret_cast(out[0]); - int* info = reinterpret_cast(out[1]); - T* work = reinterpret_cast(out[2]); - - if (a_out != a_in) { - std::memcpy(a_out, a_in, - static_cast(b) * static_cast(m) * - static_cast(n) * sizeof(T)); - } - - for (int i = 0; i < b; ++i) { - fn(&m, &n, &k, a_out, &m, tau, work, &lwork, info); - a_out += static_cast(m) * static_cast(n); - tau += k; - ++info; - } -} - -template -int64_t Orgqr::Workspace(int m, int n, int k) { - T work = 0; - int lwork = -1; - int info = 0; - fn(&m, &n, &k, nullptr, &m, nullptr, &work, &lwork, &info); - return info ? -1 : static_cast(std::real(work)); -} - -template struct Orgqr; -template struct Orgqr; -template struct Orgqr>; -template struct Orgqr>; - -// FFI Kernel - template ffi::Error OrthogonalQr::Kernel(ffi::Buffer x, ffi::Buffer tau, @@ -535,42 +322,6 @@ template struct OrthogonalQr; //== Cholesky Factorization ==// -// lapack potrf - -template -typename Potrf::FnType* Potrf::fn = nullptr; - -template -void Potrf::Kernel(void* out_tuple, void** data, XlaCustomCallStatus*) { - int32_t lower = *(reinterpret_cast(data[0])); - int b = *(reinterpret_cast(data[1])); - int n = *(reinterpret_cast(data[2])); - const T* a_in = reinterpret_cast(data[3]); - char uplo = lower ? 
'L' : 'U'; - - void** out = reinterpret_cast(out_tuple); - T* a_out = reinterpret_cast(out[0]); - int* info = reinterpret_cast(out[1]); - if (a_out != a_in) { - std::memcpy(a_out, a_in, - static_cast(b) * static_cast(n) * - static_cast(n) * sizeof(T)); - } - - for (int i = 0; i < b; ++i) { - fn(&uplo, &n, a_out, &n, info); - a_out += static_cast(n) * static_cast(n); - ++info; - } -} - -template struct Potrf; -template struct Potrf; -template struct Potrf>; -template struct Potrf>; - -// FFI Kernel - template ffi::Error CholeskyFactorization::Kernel( ffi::Buffer x, MatrixParams::UpLo uplo, @@ -604,162 +355,6 @@ template struct CholeskyFactorization; //== Singular Value Decomposition (SVD) ==// //== using a divide and conquer method ==// -// lapack gesdd - -static char GesddJobz(bool job_opt_compute_uv, bool job_opt_full_matrices) { - if (!job_opt_compute_uv) { - return 'N'; - } else if (!job_opt_full_matrices) { - return 'S'; - } - return 'A'; -} - -lapack_int GesddIworkSize(int64_t m, int64_t n) { - return CastNoOverflow(8 * std::min(m, n), "gesdd iwork"); -} - -template -typename RealGesdd::FnType* RealGesdd::fn = nullptr; - -template -void RealGesdd::Kernel(void* out_tuple, void** data, XlaCustomCallStatus*) { - int32_t job_opt_full_matrices = *(reinterpret_cast(data[0])); - int32_t job_opt_compute_uv = *(reinterpret_cast(data[1])); - int b = *(reinterpret_cast(data[2])); - int m = *(reinterpret_cast(data[3])); - int n = *(reinterpret_cast(data[4])); - int lwork = *(reinterpret_cast(data[5])); - T* a_in = reinterpret_cast(data[6]); - - void** out = reinterpret_cast(out_tuple); - T* a_out = reinterpret_cast(out[0]); - T* s = reinterpret_cast(out[1]); - T* u = reinterpret_cast(out[2]); - T* vt = reinterpret_cast(out[3]); - int* info = reinterpret_cast(out[4]); - int* iwork = reinterpret_cast(out[5]); - T* work = reinterpret_cast(out[6]); - - if (a_out != a_in) { - std::memcpy(a_out, a_in, - static_cast(b) * static_cast(m) * - static_cast(n) * sizeof(T)); - } - - char jobz = GesddJobz(job_opt_compute_uv, job_opt_full_matrices); - - int lda = m; - int ldu = m; - int tdu = job_opt_full_matrices ? m : std::min(m, n); - int ldvt = job_opt_full_matrices ? n : std::min(m, n); - - for (int i = 0; i < b; ++i) { - fn(&jobz, &m, &n, a_out, &lda, s, u, &ldu, vt, &ldvt, work, &lwork, iwork, - info); - a_out += static_cast(m) * n; - s += std::min(m, n); - u += static_cast(m) * tdu; - vt += static_cast(ldvt) * n; - ++info; - } -} - -template -int64_t RealGesdd::Workspace(lapack_int m, lapack_int n, - bool job_opt_compute_uv, - bool job_opt_full_matrices) { - T work = 0; - int lwork = -1; - int info = 0; - int ldvt = job_opt_full_matrices ? n : std::min(m, n); - char jobz = GesddJobz(job_opt_compute_uv, job_opt_full_matrices); - fn(&jobz, &m, &n, nullptr, &m, nullptr, nullptr, &m, nullptr, &ldvt, &work, - &lwork, nullptr, &info); - return info ? 
-1 : static_cast(work); -} - -lapack_int ComplexGesddRworkSize(int64_t m, int64_t n, int compute_uv) { - int64_t mn = std::min(m, n); - if (compute_uv == 0) { - return CastNoOverflow(7 * mn, "complex gesdd rwork"); - } - int64_t mx = std::max(m, n); - return CastNoOverflow( - std::max(5 * mn * mn + 5 * mn, 2 * mx * mn + 2 * mn * mn + mn), - "complex gesdd rwork"); -} - -template -typename ComplexGesdd::FnType* ComplexGesdd::fn = nullptr; - -template -void ComplexGesdd::Kernel(void* out_tuple, void** data, - XlaCustomCallStatus*) { - int32_t job_opt_full_matrices = *(reinterpret_cast(data[0])); - int32_t job_opt_compute_uv = *(reinterpret_cast(data[1])); - int b = *(reinterpret_cast(data[2])); - int m = *(reinterpret_cast(data[3])); - int n = *(reinterpret_cast(data[4])); - int lwork = *(reinterpret_cast(data[5])); - T* a_in = reinterpret_cast(data[6]); - - void** out = reinterpret_cast(out_tuple); - T* a_out = reinterpret_cast(out[0]); - typename T::value_type* s = reinterpret_cast(out[1]); - T* u = reinterpret_cast(out[2]); - T* vt = reinterpret_cast(out[3]); - int* info = reinterpret_cast(out[4]); - int* iwork = reinterpret_cast(out[5]); - typename T::value_type* rwork = - reinterpret_cast(out[6]); - T* work = reinterpret_cast(out[7]); - - if (a_out != a_in) { - std::memcpy(a_out, a_in, - static_cast(b) * static_cast(m) * - static_cast(n) * sizeof(T)); - } - - char jobz = GesddJobz(job_opt_compute_uv, job_opt_full_matrices); - - int lda = m; - int ldu = m; - int tdu = job_opt_full_matrices ? m : std::min(m, n); - int ldvt = job_opt_full_matrices ? n : std::min(m, n); - - for (int i = 0; i < b; ++i) { - fn(&jobz, &m, &n, a_out, &lda, s, u, &ldu, vt, &ldvt, work, &lwork, rwork, - iwork, info); - a_out += static_cast(m) * n; - s += std::min(m, n); - u += static_cast(m) * tdu; - vt += static_cast(ldvt) * n; - ++info; - } -} - -template -int64_t ComplexGesdd::Workspace(lapack_int m, lapack_int n, - bool job_opt_compute_uv, - bool job_opt_full_matrices) { - T work = 0; - int lwork = -1; - int info = 0; - int ldvt = job_opt_full_matrices ? n : std::min(m, n); - char jobz = GesddJobz(job_opt_compute_uv, job_opt_full_matrices); - fn(&jobz, &m, &n, nullptr, &m, nullptr, nullptr, &m, nullptr, &ldvt, &work, - &lwork, nullptr, nullptr, &info); - return info ? 
-1 : static_cast(work.real()); -} - -template struct RealGesdd; -template struct RealGesdd; -template struct ComplexGesdd>; -template struct ComplexGesdd>; - -// FFI Kernel - namespace internal { template @@ -949,16 +544,16 @@ static ffi::Error SvdQRKernel( for (int64_t i = 0; i < batch_count; ++i) { if constexpr (ffi::IsComplexType()) { - svd::SVDQRType::fn(&mode_v, &mode_v, &x_rows_v, &x_cols_v, x_out_data, - &x_leading_dim_v, singular_values_data, u_data, - &u_leading_dim_v, vt_data, &vt_leading_dim_v, - work_data.get(), &workspace_dim_v, rwork.get(), - info_data); + svd::SVDQRType::fn(&mode_v, &mode_v, &x_rows_v, &x_cols_v, + x_out_data, &x_leading_dim_v, + singular_values_data, u_data, &u_leading_dim_v, + vt_data, &vt_leading_dim_v, work_data.get(), + &workspace_dim_v, rwork.get(), info_data); } else { - svd::SVDQRType::fn(&mode_v, &mode_v, &x_rows_v, &x_cols_v, x_out_data, - &x_leading_dim_v, singular_values_data, u_data, - &u_leading_dim_v, vt_data, &vt_leading_dim_v, - work_data.get(), &workspace_dim_v, info_data); + svd::SVDQRType::fn( + &mode_v, &mode_v, &x_rows_v, &x_cols_v, x_out_data, &x_leading_dim_v, + singular_values_data, u_data, &u_leading_dim_v, vt_data, + &vt_leading_dim_v, work_data.get(), &workspace_dim_v, info_data); } x_out_data += x_out_step; singular_values_data += singular_values_step; @@ -970,9 +565,8 @@ static ffi::Error SvdQRKernel( } template -static absl::StatusOr SvdQRGetWorkspaceSize(lapack_int x_rows, - lapack_int x_cols, - svd::ComputationMode mode) { +static absl::StatusOr SvdQRGetWorkspaceSize( + lapack_int x_rows, lapack_int x_cols, svd::ComputationMode mode) { ffi::NativeType optimal_size = {}; lapack_int info = 0; lapack_int workspace_query = -1; @@ -994,7 +588,8 @@ static absl::StatusOr SvdQRGetWorkspaceSize(lapack_int x_rows, &u_leading_dim_v, nullptr, &vt_leading_dim_v, &optimal_size, &workspace_query, &info); } - return info == 0 ? MaybeCastNoOverflow(std::real(optimal_size)) : -1; + return info == 0 ? MaybeCastNoOverflow(std::real(optimal_size)) + : -1; } } // namespace internal @@ -1053,7 +648,8 @@ ffi::Error SingularValueDecompositionQRComplex::Kernel( } template -absl::StatusOr SingularValueDecompositionQR::GetWorkspaceSize( +absl::StatusOr +SingularValueDecompositionQR::GetWorkspaceSize( lapack_int x_rows, lapack_int x_cols, svd::ComputationMode mode) { return internal::SvdQRGetWorkspaceSize(x_rows, x_cols, mode); } @@ -1077,7 +673,8 @@ absl::StatusOr svd::GetRealWorkspaceSize( 2 * max_dim * min_dim + 2 * min_dim * min_dim + min_dim)); } -absl::StatusOr svd::GetRealWorkspaceSizeQR(int64_t x_rows, int64_t x_cols) { +absl::StatusOr svd::GetRealWorkspaceSizeQR(int64_t x_rows, + int64_t x_cols) { return CastNoOverflow(5 * std::min(x_rows, x_cols)); } @@ -1098,109 +695,6 @@ template struct SingularValueDecompositionQRComplex; //== Eigenvalues and eigenvectors ==// -// lapack syevd/heevd - -// # Workspace sizes, taken from the LAPACK documentation. 
-lapack_int SyevdWorkSize(int64_t n) { - return CastNoOverflow(1 + 6 * n + 2 * n * n, "syevd lwork"); -} - -lapack_int SyevdIworkSize(int64_t n) { - return CastNoOverflow(3 + 5 * n, "syevd iwork"); -} - -template -typename RealSyevd::FnType* RealSyevd::fn = nullptr; - -template -void RealSyevd::Kernel(void* out_tuple, void** data, XlaCustomCallStatus*) { - int32_t lower = *(reinterpret_cast(data[0])); - int b = *(reinterpret_cast(data[1])); - int n = *(reinterpret_cast(data[2])); - const T* a_in = reinterpret_cast(data[3]); - void** out = reinterpret_cast(out_tuple); - T* a_out = reinterpret_cast(out[0]); - T* w_out = reinterpret_cast(out[1]); - int* info_out = reinterpret_cast(out[2]); - T* work = reinterpret_cast(out[3]); - int* iwork = reinterpret_cast(out[4]); - if (a_out != a_in) { - std::memcpy(a_out, a_in, - static_cast(b) * static_cast(n) * - static_cast(n) * sizeof(T)); - } - - char jobz = 'V'; - char uplo = lower ? 'L' : 'U'; - - lapack_int lwork = SyevdWorkSize(n); - lapack_int liwork = SyevdIworkSize(n); - for (int i = 0; i < b; ++i) { - fn(&jobz, &uplo, &n, a_out, &n, w_out, work, &lwork, iwork, &liwork, - info_out); - a_out += static_cast(n) * n; - w_out += n; - ++info_out; - } -} - -// Workspace sizes, taken from the LAPACK documentation. -lapack_int HeevdWorkSize(int64_t n) { - return CastNoOverflow(1 + 2 * n + n * n, "heevd work"); -} - -lapack_int HeevdRworkSize(int64_t n) { - return CastNoOverflow(1 + 5 * n + 2 * n * n, "heevd rwork"); -} - -template -typename ComplexHeevd::FnType* ComplexHeevd::fn = nullptr; - -template -void ComplexHeevd::Kernel(void* out_tuple, void** data, - XlaCustomCallStatus*) { - int32_t lower = *(reinterpret_cast(data[0])); - int b = *(reinterpret_cast(data[1])); - int n = *(reinterpret_cast(data[2])); - const T* a_in = reinterpret_cast(data[3]); - - void** out = reinterpret_cast(out_tuple); - T* a_out = reinterpret_cast(out[0]); - typename T::value_type* w_out = - reinterpret_cast(out[1]); - int* info_out = reinterpret_cast(out[2]); - T* work = reinterpret_cast(out[3]); - typename T::value_type* rwork = - reinterpret_cast(out[4]); - int* iwork = reinterpret_cast(out[5]); - if (a_out != a_in) { - std::memcpy(a_out, a_in, - static_cast(b) * static_cast(n) * - static_cast(n) * sizeof(T)); - } - - char jobz = 'V'; - char uplo = lower ? 
'L' : 'U'; - - lapack_int lwork = HeevdWorkSize(n); - lapack_int lrwork = HeevdRworkSize(n); - lapack_int liwork = SyevdIworkSize(n); - for (int i = 0; i < b; ++i) { - fn(&jobz, &uplo, &n, a_out, &n, w_out, work, &lwork, rwork, &lrwork, iwork, - &liwork, info_out); - a_out += static_cast(n) * n; - w_out += n; - ++info_out; - } -} - -template struct RealSyevd; -template struct RealSyevd; -template struct ComplexHeevd>; -template struct ComplexHeevd>; - -// FFI Kernel - absl::StatusOr eig::GetWorkspaceSize(int64_t x_cols, ComputationMode mode) { switch (mode) { @@ -1339,155 +833,6 @@ template struct EigenvalueDecompositionSymmetric; template struct EigenvalueDecompositionHermitian; template struct EigenvalueDecompositionHermitian; -// lapack geev - -template -typename RealGeev::FnType* RealGeev::fn = nullptr; - -template -void RealGeev::Kernel(void* out_tuple, void** data, XlaCustomCallStatus*) { - int b = *(reinterpret_cast(data[0])); - int n_int = *(reinterpret_cast(data[1])); - int64_t n = n_int; - char jobvl = *(reinterpret_cast(data[2])); - char jobvr = *(reinterpret_cast(data[3])); - - const T* a_in = reinterpret_cast(data[4]); - - void** out = reinterpret_cast(out_tuple); - T* a_work = reinterpret_cast(out[0]); - T* vl_work = reinterpret_cast(out[1]); - T* vr_work = reinterpret_cast(out[2]); - - T* wr_out = reinterpret_cast(out[3]); - T* wi_out = reinterpret_cast(out[4]); - std::complex* vl_out = reinterpret_cast*>(out[5]); - std::complex* vr_out = reinterpret_cast*>(out[6]); - int* info_out = reinterpret_cast(out[7]); - - // TODO(phawkins): preallocate workspace using XLA. - T work_query; - int lwork = -1; - fn(&jobvl, &jobvr, &n_int, a_work, &n_int, wr_out, wi_out, vl_work, &n_int, - vr_work, &n_int, &work_query, &lwork, info_out); - ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(&work_query, sizeof(work_query)); - lwork = static_cast(work_query); - T* work = new T[lwork]; - - auto is_finite = [](T* a_work, int64_t n) { - for (int64_t j = 0; j < n; ++j) { - for (int64_t k = 0; k < n; ++k) { - if (!std::isfinite(a_work[j * n + k])) { - return false; - } - } - } - return true; - }; - for (int i = 0; i < b; ++i) { - size_t a_size = n * n * sizeof(T); - std::memcpy(a_work, a_in, a_size); - if (is_finite(a_work, n)) { - fn(&jobvl, &jobvr, &n_int, a_work, &n_int, wr_out, wi_out, vl_work, - &n_int, vr_work, &n_int, work, &lwork, info_out); - ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(a_work, a_size); - ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(wr_out, sizeof(T) * n); - ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(wi_out, sizeof(T) * n); - ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(vl_work, sizeof(T) * n * n); - ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(vr_work, sizeof(T) * n * n); - ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(info_out, sizeof(int)); - if (info_out[0] == 0) { - UnpackEigenvectors(n, wi_out, vl_work, vl_out); - UnpackEigenvectors(n, wi_out, vr_work, vr_out); - } - } else { - *info_out = -4; - } - a_in += n * n; - wr_out += n; - wi_out += n; - vl_out += n * n; - vr_out += n * n; - ++info_out; - } - delete[] work; -} - -template -typename ComplexGeev::FnType* ComplexGeev::fn = nullptr; - -template -void ComplexGeev::Kernel(void* out_tuple, void** data, - XlaCustomCallStatus*) { - int b = *(reinterpret_cast(data[0])); - int n_int = *(reinterpret_cast(data[1])); - int64_t n = n_int; - char jobvl = *(reinterpret_cast(data[2])); - char jobvr = *(reinterpret_cast(data[3])); - - const T* a_in = reinterpret_cast(data[4]); - - void** out = reinterpret_cast(out_tuple); - T* a_work = reinterpret_cast(out[0]); - typename T::value_type* 
r_work = - reinterpret_cast(out[1]); - - T* w_out = reinterpret_cast(out[2]); - T* vl_out = reinterpret_cast(out[3]); - T* vr_out = reinterpret_cast(out[4]); - int* info_out = reinterpret_cast(out[5]); - - // TODO(phawkins): preallocate workspace using XLA. - T work_query; - int lwork = -1; - fn(&jobvl, &jobvr, &n_int, a_work, &n_int, w_out, vl_out, &n_int, vr_out, - &n_int, &work_query, &lwork, r_work, info_out); - ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(&work_query, sizeof(work_query)); - lwork = static_cast(work_query.real()); - T* work = new T[lwork]; - - auto is_finite = [](T* a_work, int64_t n) { - for (int64_t j = 0; j < n; ++j) { - for (int64_t k = 0; k < n; ++k) { - T v = a_work[j * n + k]; - if (!std::isfinite(v.real()) || !std::isfinite(v.imag())) { - return false; - } - } - } - return true; - }; - - for (int i = 0; i < b; ++i) { - size_t a_size = n * n * sizeof(T); - std::memcpy(a_work, a_in, a_size); - if (is_finite(a_work, n)) { - fn(&jobvl, &jobvr, &n_int, a_work, &n_int, w_out, vl_out, &n_int, vr_out, - &n_int, work, &lwork, r_work, info_out); - ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(a_work, a_size); - ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(w_out, sizeof(T) * n); - ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(vl_out, sizeof(T) * n * n); - ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(vr_out, sizeof(T) * n * n); - ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(info_out, sizeof(int)); - } else { - *info_out = -4; - } - a_in += n * n; - w_out += n; - vl_out += n * n; - vr_out += n * n; - info_out += 1; - } - delete[] work; -} - -template struct RealGeev; -template struct RealGeev; -template struct ComplexGeev>; -template struct ComplexGeev>; - -// FFI Kernel - template ffi::Error EigenvalueDecomposition::Kernel( ffi::Buffer x, eig::ComputationMode compute_left, @@ -1662,138 +1007,6 @@ template struct EigenvalueDecompositionComplex; //== Schur Decomposition ==// -// lapack gees - -template -typename RealGees::FnType* RealGees::fn = nullptr; - -template -void RealGees::Kernel(void* out_tuple, void** data, XlaCustomCallStatus*) { - int b = *(reinterpret_cast(data[0])); - int n_int = *(reinterpret_cast(data[1])); - int64_t n = n_int; - char jobvs = *(reinterpret_cast(data[2])); - char sort = *(reinterpret_cast(data[3])); - - const T* a_in = reinterpret_cast(data[4]); - - // bool* select (T, T) = reinterpret_cast(data[5]); - bool (*select)(T, T) = nullptr; - - void** out = reinterpret_cast(out_tuple); - T* a_out = reinterpret_cast(out[0]); - - T* wr_out = reinterpret_cast(out[1]); - T* wi_out = reinterpret_cast(out[2]); - T* vs_out = reinterpret_cast(out[3]); - int* sdim_out = reinterpret_cast(out[4]); - int* info_out = reinterpret_cast(out[5]); - - bool* b_work = (sort != 'N') ? 
(new bool[n]) : nullptr; - - T work_query; - int lwork = -1; - fn(&jobvs, &sort, select, &n_int, a_out, &n_int, sdim_out, wr_out, wi_out, - vs_out, &n_int, &work_query, &lwork, b_work, info_out); - ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(&work_query, sizeof(work_query)); - lwork = static_cast(work_query); - T* work = new T[lwork]; - - size_t a_size = static_cast(n) * static_cast(n) * sizeof(T); - if (a_out != a_in) { - std::memcpy(a_out, a_in, static_cast(b) * a_size); - } - - for (int i = 0; i < b; ++i) { - fn(&jobvs, &sort, select, &n_int, a_out, &n_int, sdim_out, wr_out, wi_out, - vs_out, &n_int, work, &lwork, b_work, info_out); - ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(a_out, a_size); - ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(sdim_out, sizeof(int)); - ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(wr_out, sizeof(T) * n); - ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(wi_out, sizeof(T) * n); - ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(vs_out, sizeof(T) * n * n); - ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(info_out, sizeof(int)); - - a_in += n * n; - a_out += n * n; - wr_out += n; - wi_out += n; - vs_out += n * n; - ++sdim_out; - ++info_out; - } - delete[] work; - delete[] b_work; -} - -template -typename ComplexGees::FnType* ComplexGees::fn = nullptr; - -template -void ComplexGees::Kernel(void* out_tuple, void** data, - XlaCustomCallStatus*) { - int b = *(reinterpret_cast(data[0])); - int n_int = *(reinterpret_cast(data[1])); - int64_t n = n_int; - char jobvs = *(reinterpret_cast(data[2])); - char sort = *(reinterpret_cast(data[3])); - - const T* a_in = reinterpret_cast(data[4]); - - // bool* select (T, T) = reinterpret_cast(data[5]); - bool (*select)(T) = nullptr; - - void** out = reinterpret_cast(out_tuple); - T* a_out = reinterpret_cast(out[0]); - typename T::value_type* r_work = - reinterpret_cast(out[1]); - T* w_out = reinterpret_cast(out[2]); - T* vs_out = reinterpret_cast(out[3]); - int* sdim_out = reinterpret_cast(out[4]); - int* info_out = reinterpret_cast(out[5]); - - bool* b_work = (sort != 'N') ? 
(new bool[n]) : nullptr; - - T work_query; - int lwork = -1; - fn(&jobvs, &sort, select, &n_int, a_out, &n_int, sdim_out, w_out, vs_out, - &n_int, &work_query, &lwork, r_work, b_work, info_out); - ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(&work_query, sizeof(work_query)); - lwork = static_cast(work_query.real()); - T* work = new T[lwork]; - - if (a_out != a_in) { - std::memcpy(a_out, a_in, - static_cast(b) * static_cast(n) * - static_cast(n) * sizeof(T)); - } - - for (int i = 0; i < b; ++i) { - fn(&jobvs, &sort, select, &n_int, a_out, &n_int, sdim_out, w_out, vs_out, - &n_int, work, &lwork, r_work, b_work, info_out); - ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(w_out, sizeof(T) * n); - ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(vs_out, sizeof(T) * n * n); - ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(info_out, sizeof(int)); - ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(sdim_out, sizeof(int)); - - a_in += n * n; - a_out += n * n; - w_out += n; - vs_out += n * n; - ++info_out; - ++sdim_out; - } - delete[] work; - delete[] b_work; -} - -template struct RealGees; -template struct RealGees; -template struct ComplexGees>; -template struct ComplexGees>; - -// FFI Kernel - template ffi::Error SchurDecomposition::Kernel( ffi::Buffer x, schur::ComputationMode mode, schur::Sort sort, @@ -1968,60 +1181,6 @@ template struct SchurDecompositionComplex; //== Hessenberg Decomposition ==// -// lapack gehrd - -template -typename Gehrd::FnType* Gehrd::fn = nullptr; - -template -void Gehrd::Kernel(void* out_tuple, void** data, XlaCustomCallStatus*) { - int32_t n = *reinterpret_cast(data[0]); - int32_t ilo = *reinterpret_cast(data[1]); - int32_t ihi = *reinterpret_cast(data[2]); - int32_t lda = *reinterpret_cast(data[3]); - int32_t batch = *reinterpret_cast(data[4]); - int32_t lwork = *reinterpret_cast(data[5]); - T* a = reinterpret_cast(data[6]); - - void** out = reinterpret_cast(out_tuple); - T* a_out = reinterpret_cast(out[0]); - T* tau = reinterpret_cast(out[1]); - int* info = reinterpret_cast(out[2]); - T* work = reinterpret_cast(out[3]); - - if (a_out != a) { - std::memcpy(a_out, a, - static_cast(batch) * static_cast(n) * - static_cast(n) * sizeof(T)); - } - - int64_t a_plus = static_cast(lda) * static_cast(n); - - for (int i = 0; i < batch; ++i) { - fn(&n, &ilo, &ihi, a_out, &lda, tau, work, &lwork, info); - a_out += a_plus; - tau += n - 1; - ++info; - } -} - -template -int64_t Gehrd::Workspace(lapack_int lda, lapack_int n, lapack_int ilo, - lapack_int ihi) { - T work = 0; - lapack_int lwork = -1; - lapack_int info = 0; - fn(&n, &ilo, &ihi, nullptr, &lda, nullptr, &work, &lwork, &info); - return info == 0 ? 
static_cast(std::real(work)) : -1; -} - -template struct Gehrd; -template struct Gehrd; -template struct Gehrd>; -template struct Gehrd>; - -// FFI Kernel - template ffi::Error HessenbergDecomposition::Kernel( ffi::Buffer x, lapack_int low, lapack_int high, @@ -2075,67 +1234,6 @@ template struct HessenbergDecomposition; //== Tridiagonal Reduction ==// -// lapack sytrd/hetrd - -template -typename Sytrd::FnType* Sytrd::fn = nullptr; - -template -void Sytrd::Kernel(void* out_tuple, void** data, XlaCustomCallStatus*) { - int32_t n = *reinterpret_cast(data[0]); - int32_t lower = *reinterpret_cast(data[1]); - int32_t lda = *reinterpret_cast(data[2]); - int32_t batch = *reinterpret_cast(data[3]); - int32_t lwork = *reinterpret_cast(data[4]); - T* a = reinterpret_cast(data[5]); - - void** out = reinterpret_cast(out_tuple); - T* a_out = reinterpret_cast(out[0]); - typedef typename real_type::type Real; - Real* d = reinterpret_cast(out[1]); - Real* e = reinterpret_cast(out[2]); - T* tau = reinterpret_cast(out[3]); - int* info = reinterpret_cast(out[4]); - T* work = reinterpret_cast(out[5]); - - if (a_out != a) { - std::memcpy(a_out, a, - static_cast(batch) * static_cast(n) * - static_cast(n) * sizeof(T)); - } - - char cuplo = lower ? 'L' : 'U'; - - int64_t a_plus = static_cast(lda) * static_cast(n); - - for (int i = 0; i < batch; ++i) { - fn(&cuplo, &n, a_out, &lda, d, e, tau, work, &lwork, info); - a_out += a_plus; - d += n; - e += n - 1; - tau += n - 1; - ++info; - } -} - -template -int64_t Sytrd::Workspace(lapack_int lda, lapack_int n) { - char cuplo = 'L'; - T work = 0; - lapack_int lwork = -1; - lapack_int info = 0; - fn(&cuplo, &n, nullptr, &lda, nullptr, nullptr, nullptr, &work, &lwork, - &info); - return info == 0 ? static_cast(std::real(work)) : -1; -} - -template struct Sytrd; -template struct Sytrd; -template struct Sytrd>; -template struct Sytrd>; - -// FFI Kernel - template ffi::Error TridiagonalReduction::Kernel( ffi::Buffer x, MatrixParams::UpLo uplo, diff --git a/jaxlib/cpu/lapack_kernels.h b/jaxlib/cpu/lapack_kernels.h index e075ff29387f..572f67b7744b 100644 --- a/jaxlib/cpu/lapack_kernels.h +++ b/jaxlib/cpu/lapack_kernels.h @@ -18,13 +18,10 @@ limitations under the License. 
#include #include -#include #include #include "absl/status/statusor.h" -#include "xla/ffi/api/c_api.h" #include "xla/ffi/api/ffi.h" -#include "xla/service/custom_call_status.h" // Underlying function pointers (i.e., KERNEL_CLASS::Fn) are initialized either // by the nanobind wrapper that links them to an existing SciPy lapack instance, @@ -93,26 +90,6 @@ void AssignKernelFn(typename KernelType::FnType* func) { } // namespace jax -#define DEFINE_CHAR_ENUM_ATTR_DECODING(ATTR) \ - template <> \ - struct xla::ffi::AttrDecoding { \ - using Type = ATTR; \ - static std::optional Decode(XLA_FFI_AttrType type, void* attr, \ - DiagnosticEngine& diagnostic); \ - } - -// XLA needs attributes to have deserialization method specified -DEFINE_CHAR_ENUM_ATTR_DECODING(jax::MatrixParams::Side); -DEFINE_CHAR_ENUM_ATTR_DECODING(jax::MatrixParams::UpLo); -DEFINE_CHAR_ENUM_ATTR_DECODING(jax::MatrixParams::Transpose); -DEFINE_CHAR_ENUM_ATTR_DECODING(jax::MatrixParams::Diag); -DEFINE_CHAR_ENUM_ATTR_DECODING(jax::svd::ComputationMode); -DEFINE_CHAR_ENUM_ATTR_DECODING(jax::eig::ComputationMode); -DEFINE_CHAR_ENUM_ATTR_DECODING(jax::schur::ComputationMode); -DEFINE_CHAR_ENUM_ATTR_DECODING(jax::schur::Sort); - -#undef DEFINE_CHAR_ENUM_ATTR_DECODING - namespace jax { using lapack_int = int; @@ -122,20 +99,6 @@ static_assert( //== Triangular System Solver ==// -// lapack trsm - -template -struct Trsm { - using FnType = void(char* side, char* uplo, char* transa, char* diag, - lapack_int* m, lapack_int* n, T* alpha, T* a, - lapack_int* lda, T* b, lapack_int* ldb); - - static FnType* fn; - static void Kernel(void* out, void** data, XlaCustomCallStatus*); -}; - -// FFI Kernel - template <::xla::ffi::DataType dtype> struct TriMatrixEquationSolver { using ValueType = ::xla::ffi::NativeType; @@ -154,19 +117,6 @@ struct TriMatrixEquationSolver { //== LU Decomposition ==// -// lapack getrf - -template -struct Getrf { - using FnType = void(lapack_int* m, lapack_int* n, T* a, lapack_int* lda, - lapack_int* ipiv, lapack_int* info); - - static FnType* fn; - static void Kernel(void* out, void** data, XlaCustomCallStatus*); -}; - -// FFI Kernel - template <::xla::ffi::DataType dtype> struct LuDecomposition { using ValueType = ::xla::ffi::NativeType; @@ -182,21 +132,6 @@ struct LuDecomposition { //== QR Factorization ==// -// lapack geqrf - -template -struct Geqrf { - using FnType = void(lapack_int* m, lapack_int* n, T* a, lapack_int* lda, - T* tau, T* work, lapack_int* lwork, lapack_int* info); - - static FnType* fn; - static void Kernel(void* out, void** data, XlaCustomCallStatus*); - - static int64_t Workspace(lapack_int m, lapack_int n); -}; - -// FFI Kernel - template <::xla::ffi::DataType dtype> struct QrFactorization { using ValueType = ::xla::ffi::NativeType; @@ -240,23 +175,8 @@ struct PivotingQrFactorization { static int64_t GetWorkspaceSize(lapack_int x_rows, lapack_int x_cols); }; - //== Orthogonal QR ==// -// lapack orgqr - -template -struct Orgqr { - using FnType = void(lapack_int* m, lapack_int* n, lapack_int* k, T* a, - lapack_int* lda, T* tau, T* work, lapack_int* lwork, - lapack_int* info); - static FnType* fn; - static void Kernel(void* out, void** data, XlaCustomCallStatus*); - static int64_t Workspace(lapack_int m, lapack_int n, lapack_int k); -}; - -// FFI Kernel - template <::xla::ffi::DataType dtype> struct OrthogonalQr { using ValueType = ::xla::ffi::NativeType; @@ -276,16 +196,6 @@ struct OrthogonalQr { //== Cholesky Factorization ==// -// lapack potrf - -template -struct Potrf { - using FnType = void(char* 
uplo, lapack_int* n, T* a, lapack_int* lda, - lapack_int* info); - static FnType* fn; - static void Kernel(void* out, void** data, XlaCustomCallStatus*); -}; - template <::xla::ffi::DataType dtype> struct CholeskyFactorization { using ValueType = ::xla::ffi::NativeType; @@ -302,41 +212,6 @@ struct CholeskyFactorization { //== Singular Value Decomposition (SVD) ==// -// lapack gesdd - -lapack_int GesddIworkSize(int64_t m, int64_t n); - -template -struct RealGesdd { - using FnType = void(char* jobz, lapack_int* m, lapack_int* n, T* a, - lapack_int* lda, T* s, T* u, lapack_int* ldu, T* vt, - lapack_int* ldvt, T* work, lapack_int* lwork, - lapack_int* iwork, lapack_int* info); - static FnType* fn; - static void Kernel(void* out, void** data, XlaCustomCallStatus*); - - static int64_t Workspace(lapack_int m, lapack_int n, bool job_opt_compute_uv, - bool job_opt_full_matrices); -}; - -lapack_int ComplexGesddRworkSize(int64_t m, int64_t n, int compute_uv); - -template -struct ComplexGesdd { - using FnType = void(char* jobz, lapack_int* m, lapack_int* n, T* a, - lapack_int* lda, typename T::value_type* s, T* u, - lapack_int* ldu, T* vt, lapack_int* ldvt, T* work, - lapack_int* lwork, typename T::value_type* rwork, - lapack_int* iwork, lapack_int* info); - static FnType* fn; - static void Kernel(void* out, void** data, XlaCustomCallStatus*); - - static int64_t Workspace(lapack_int m, lapack_int n, bool job_opt_compute_uv, - bool job_opt_full_matrices); -}; - -// FFI Kernel - template <::xla::ffi::DataType dtype> struct SingularValueDecomposition { static_assert(!::xla::ffi::IsComplexType(), @@ -407,8 +282,8 @@ struct SingularValueDecompositionQR { ::xla::ffi::ResultBuffer info, svd::ComputationMode mode); static absl::StatusOr GetWorkspaceSize(lapack_int x_rows, - lapack_int x_cols, - svd::ComputationMode mode); + lapack_int x_cols, + svd::ComputationMode mode); }; template <::xla::ffi::DataType dtype> @@ -432,8 +307,8 @@ struct SingularValueDecompositionQRComplex { ::xla::ffi::ResultBuffer info, svd::ComputationMode mode); static absl::StatusOr GetWorkspaceSize(lapack_int x_rows, - lapack_int x_cols, - svd::ComputationMode mode); + lapack_int x_cols, + svd::ComputationMode mode); }; namespace svd { @@ -451,42 +326,13 @@ using SVDQRType = std::conditional_t<::xla::ffi::IsComplexType(), absl::StatusOr GetIntWorkspaceSize(int64_t x_rows, int64_t x_cols); absl::StatusOr GetRealWorkspaceSize(int64_t x_rows, int64_t x_cols, ComputationMode mode); -absl::StatusOr GetRealWorkspaceSizeQR(int64_t x_rows, int64_t x_cols); +absl::StatusOr GetRealWorkspaceSizeQR(int64_t x_rows, + int64_t x_cols); } // namespace svd //== Eigenvalues and eigenvectors ==// -// lapack syevd/heevd - -lapack_int SyevdWorkSize(int64_t n); -lapack_int SyevdIworkSize(int64_t n); - -template -struct RealSyevd { - using FnType = void(char* jobz, char* uplo, lapack_int* n, T* a, - lapack_int* lda, T* w, T* work, lapack_int* lwork, - lapack_int* iwork, lapack_int* liwork, lapack_int* info); - static FnType* fn; - static void Kernel(void* out, void** data, XlaCustomCallStatus*); -}; - -lapack_int HeevdWorkSize(int64_t n); -lapack_int HeevdRworkSize(int64_t n); - -template -struct ComplexHeevd { - using FnType = void(char* jobz, char* uplo, lapack_int* n, T* a, - lapack_int* lda, typename T::value_type* w, T* work, - lapack_int* lwork, typename T::value_type* rwork, - lapack_int* lrwork, lapack_int* iwork, lapack_int* liwork, - lapack_int* info); - static FnType* fn; - static void Kernel(void* out, void** data, XlaCustomCallStatus*); -}; - 
-// FFI Kernel - namespace eig { // Eigenvalue Decomposition @@ -544,8 +390,6 @@ struct EigenvalueDecompositionHermitian { ::xla::ffi::ResultBuffer info, eig::ComputationMode mode); }; -// lapack geev - // LAPACK uses a packed representation to represent a mixture of real // eigenvectors and complex conjugate pairs. This helper unpacks the // representation into regular complex matrices. @@ -574,28 +418,6 @@ static void UnpackEigenvectors(Int n, const T* eigenvals_imag, const T* packed, } } -template -struct RealGeev { - using FnType = void(char* jobvl, char* jobvr, lapack_int* n, T* a, - lapack_int* lda, T* wr, T* wi, T* vl, lapack_int* ldvl, - T* vr, lapack_int* ldvr, T* work, lapack_int* lwork, - lapack_int* info); - static FnType* fn; - static void Kernel(void* out, void** data, XlaCustomCallStatus*); -}; - -template -struct ComplexGeev { - using FnType = void(char* jobvl, char* jobvr, lapack_int* n, T* a, - lapack_int* lda, T* w, T* vl, lapack_int* ldvl, T* vr, - lapack_int* ldvr, T* work, lapack_int* lwork, - typename T::value_type* rwork, lapack_int* info); - static FnType* fn; - static void Kernel(void* out, void** data, XlaCustomCallStatus*); -}; - -// FFI Kernel - template <::xla::ffi::DataType dtype> struct EigenvalueDecomposition { static_assert(!::xla::ffi::IsComplexType(), @@ -653,31 +475,6 @@ struct EigenvalueDecompositionComplex { //== Schur Decomposition ==// -// lapack gees - -template -struct RealGees { - using FnType = void(char* jobvs, char* sort, bool (*select)(T, T), - lapack_int* n, T* a, lapack_int* lda, lapack_int* sdim, - T* wr, T* wi, T* vs, lapack_int* ldvs, T* work, - lapack_int* lwork, bool* bwork, lapack_int* info); - static FnType* fn; - static void Kernel(void* out, void** data, XlaCustomCallStatus*); -}; - -template -struct ComplexGees { - using FnType = void(char* jobvs, char* sort, bool (*select)(T), lapack_int* n, - T* a, lapack_int* lda, lapack_int* sdim, T* w, T* vs, - lapack_int* ldvs, T* work, lapack_int* lwork, - typename T::value_type* rwork, bool* bwork, - lapack_int* info); - static FnType* fn; - static void Kernel(void* out, void** data, XlaCustomCallStatus*); -}; - -// FFI Kernel - template <::xla::ffi::DataType dtype> struct SchurDecomposition { static_assert(!::xla::ffi::IsComplexType(), @@ -737,32 +534,6 @@ struct SchurDecompositionComplex { //== Hessenberg Decomposition ==// //== Reduces a non-symmetric square matrix to upper Hessenberg form ==// -// lapack gehrd - -template -struct Gehrd { - using FnType = void(lapack_int* n, lapack_int* ilo, lapack_int* ihi, T* a, - lapack_int* lda, T* tau, T* work, lapack_int* lwork, - lapack_int* info); - - static FnType* fn; - static void Kernel(void* out, void** data, XlaCustomCallStatus*); - - static int64_t Workspace(lapack_int lda, lapack_int n, lapack_int ilo, - lapack_int ihi); -}; - -template -struct real_type { - typedef T type; -}; -template -struct real_type> { - typedef T type; -}; - -// FFI Kernel - template <::xla::ffi::DataType dtype> struct HessenbergDecomposition { using ValueType = ::xla::ffi::NativeType; @@ -785,23 +556,6 @@ struct HessenbergDecomposition { //== Tridiagonal Reduction ==// //== Reduces a Symmetric/Hermitian square matrix to tridiagonal form ==// -// lapack sytrd/hetrd - -template -struct Sytrd { - using FnType = void(char* uplo, lapack_int* n, T* a, lapack_int* lda, - typename real_type::type* d, - typename real_type::type* e, T* tau, T* work, - lapack_int* lwork, lapack_int* info); - - static FnType* fn; - static void Kernel(void* out, void** data, 
XlaCustomCallStatus*); - - static int64_t Workspace(lapack_int lda, lapack_int n); -}; - -// FFI Kernel - template <::xla::ffi::DataType dtype> struct TridiagonalReduction { using ValueType = ::xla::ffi::NativeType; diff --git a/jaxlib/cpu/lapack_kernels_using_lapack.cc b/jaxlib/cpu/lapack_kernels_using_lapack.cc index 3c8ddf11cf29..c4b154f1ae70 100644 --- a/jaxlib/cpu/lapack_kernels_using_lapack.cc +++ b/jaxlib/cpu/lapack_kernels_using_lapack.cc @@ -13,9 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include -#include - #include "jaxlib/cpu/lapack_kernels.h" // From a Python binary, JAX obtains its LAPACK/BLAS kernels from Scipy, but @@ -100,241 +97,7 @@ jax::TridiagonalSolver::FnType zgtsv_; namespace jax { -#define JAX_KERNEL_FNTYPE_MISMATCH_MSG "FFI Kernel FnType mismatch" - -static_assert( - std::is_same_v::FnType, - jax::Trsm::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert( - std::is_same_v::FnType, - jax::Trsm::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert( - std::is_same_v::FnType, - jax::Trsm>::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert( - std::is_same_v::FnType, - jax::Trsm>::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert(std::is_same_v::FnType, - jax::Getrf::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert(std::is_same_v::FnType, - jax::Getrf::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert(std::is_same_v::FnType, - jax::Getrf>::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert(std::is_same_v::FnType, - jax::Getrf>::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert(std::is_same_v::FnType, - jax::Geqrf::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert(std::is_same_v::FnType, - jax::Geqrf::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert(std::is_same_v::FnType, - jax::Geqrf>::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert(std::is_same_v::FnType, - jax::Geqrf>::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert(std::is_same_v::FnType, - jax::Orgqr::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert(std::is_same_v::FnType, - jax::Orgqr::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert(std::is_same_v::FnType, - jax::Orgqr>::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert(std::is_same_v::FnType, - jax::Orgqr>::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert( - std::is_same_v::FnType, - jax::Potrf::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert( - std::is_same_v::FnType, - jax::Potrf::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert( - std::is_same_v::FnType, - jax::Potrf>::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert( - std::is_same_v::FnType, - jax::Potrf>::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert( - std::is_same_v::FnType, - jax::RealGesdd::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert( - std::is_same_v::FnType, - jax::RealGesdd::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert( - std::is_same_v< - jax::SingularValueDecompositionComplex::FnType, - jax::ComplexGesdd>::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert( - std::is_same_v< - jax::SingularValueDecompositionComplex::FnType, - jax::ComplexGesdd>::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert( - std::is_same_v< - jax::EigenvalueDecompositionSymmetric::FnType, - jax::RealSyevd::FnType>, - 
JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert( - std::is_same_v< - jax::EigenvalueDecompositionSymmetric::FnType, - jax::RealSyevd::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert( - std::is_same_v< - jax::EigenvalueDecompositionHermitian::FnType, - jax::ComplexHeevd>::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert( - std::is_same_v< - jax::EigenvalueDecompositionHermitian::FnType, - jax::ComplexHeevd>::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert( - std::is_same_v::FnType, - jax::RealGeev::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert( - std::is_same_v::FnType, - jax::RealGeev::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert( - std::is_same_v< - jax::EigenvalueDecompositionComplex::FnType, - jax::ComplexGeev>::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert( - std::is_same_v< - jax::EigenvalueDecompositionComplex::FnType, - jax::ComplexGeev>::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert( - std::is_same_v::FnType, - jax::Sytrd::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert( - std::is_same_v::FnType, - jax::Sytrd::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert( - std::is_same_v::FnType, - jax::Sytrd>::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert( - std::is_same_v::FnType, - jax::Sytrd>::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert( - std::is_same_v::FnType, - jax::RealGees::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert( - std::is_same_v::FnType, - jax::RealGees::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert( - std::is_same_v::FnType, - jax::ComplexGees>::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert( - std::is_same_v::FnType, - jax::ComplexGees>::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert( - std::is_same_v::FnType, - jax::Gehrd::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert( - std::is_same_v::FnType, - jax::Gehrd::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert( - std::is_same_v::FnType, - jax::Gehrd>::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); -static_assert( - std::is_same_v::FnType, - jax::Gehrd>::FnType>, - JAX_KERNEL_FNTYPE_MISMATCH_MSG); - -#undef JAX_KERNEL_FNTYPE_MISMATCH_MSG - static auto init = []() -> int { - AssignKernelFn>(strsm_); - AssignKernelFn>(dtrsm_); - AssignKernelFn>>(ctrsm_); - AssignKernelFn>>(ztrsm_); - - AssignKernelFn>(sgetrf_); - AssignKernelFn>(dgetrf_); - AssignKernelFn>>(cgetrf_); - AssignKernelFn>>(zgetrf_); - - AssignKernelFn>(sgeqrf_); - AssignKernelFn>(dgeqrf_); - AssignKernelFn>>(cgeqrf_); - AssignKernelFn>>(zgeqrf_); - - AssignKernelFn>(sorgqr_); - AssignKernelFn>(dorgqr_); - AssignKernelFn>>(cungqr_); - AssignKernelFn>>(zungqr_); - - AssignKernelFn>(spotrf_); - AssignKernelFn>(dpotrf_); - AssignKernelFn>>(cpotrf_); - AssignKernelFn>>(zpotrf_); - - AssignKernelFn>(sgesdd_); - AssignKernelFn>(dgesdd_); - AssignKernelFn>>(cgesdd_); - AssignKernelFn>>(zgesdd_); - - AssignKernelFn>(ssyevd_); - AssignKernelFn>(dsyevd_); - AssignKernelFn>>(cheevd_); - AssignKernelFn>>(zheevd_); - - AssignKernelFn>(sgeev_); - AssignKernelFn>(dgeev_); - AssignKernelFn>>(cgeev_); - AssignKernelFn>>(zgeev_); - - AssignKernelFn>(sgees_); - AssignKernelFn>(dgees_); - AssignKernelFn>>(cgees_); - AssignKernelFn>>(zgees_); - - AssignKernelFn>(sgehrd_); - AssignKernelFn>(dgehrd_); - AssignKernelFn>>(cgehrd_); - AssignKernelFn>>(zgehrd_); - - AssignKernelFn>(ssytrd_); - AssignKernelFn>(dsytrd_); - AssignKernelFn>>(chetrd_); - AssignKernelFn>>(zhetrd_); - - // FFI Kernels 
- AssignKernelFn>(strsm_); AssignKernelFn>(dtrsm_); AssignKernelFn>(ctrsm_); diff --git a/jaxlib/cpu/sparse.cc b/jaxlib/cpu/sparse.cc new file mode 100644 index 000000000000..15f5c0f1984f --- /dev/null +++ b/jaxlib/cpu/sparse.cc @@ -0,0 +1,37 @@ +/* Copyright 2025 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "nanobind/nanobind.h" +#include "jaxlib/cpu/sparse_kernels.h" +#include "jaxlib/kernel_nanobind_helpers.h" + +namespace jax { +namespace { + +namespace nb = nanobind; + +nb::dict Registrations() { + nb::dict dict; + + dict["cpu_csr_sparse_dense_ffi"] = + EncapsulateFunction(cpu_csr_sparse_dense_ffi); + + return dict; +} + +NB_MODULE(_sparse, m) { m.def("registrations", &Registrations); } + +} // namespace +} // namespace jax diff --git a/jaxlib/cpu/sparse_kernels.cc b/jaxlib/cpu/sparse_kernels.cc new file mode 100644 index 000000000000..8000abca65cc --- /dev/null +++ b/jaxlib/cpu/sparse_kernels.cc @@ -0,0 +1,215 @@ +/* Copyright 2025 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/cpu/sparse_kernels.h" + +#include +#include +#include +#include + +#include "Eigen/Core" +#include "Eigen/SparseCore" +#include "xla/ffi/api/ffi.h" + +namespace ffi = xla::ffi; + +namespace jax { + +template +using SparseMatrixType = + Eigen::SparseMatrix; +template +using DenseMatrixType = + Eigen::Matrix; + +template +using InputMap = Eigen::Map; +template +using OutputMap = Eigen::Map; + +template +static ffi::Future CsrSparseDenseKernelImpl( + const InputMap>& lhs_matrix, + const InputMap>& rhs_matrix, + OutputMap>& out_matrix, + ffi::ThreadPool& thread_pool) { + // Rule of thumb to give each task at least 100k cycles to hide the cost of + // task scheduling. + // TODO(willfroom) Do we want to make this configurable? 
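// [Editor's note: illustrative arithmetic, not part of the original patch.]
// With the AVX assumption used just below (two 32-byte vector operations per
// cycle), a float element yields 2 * 32 / sizeof(float) = 16 scalar products
// per cycle, so one task covers roughly 100'000 * 16 = 1.6M scalar products;
// for double it is 8 per cycle, i.e. about 800'000 products per task.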
+ constexpr int64_t kTargetCyclesPerTask = 100'000;
+ // Based on AVX (CPI 0.5 -> 2 IPC)
+ constexpr int64_t kScalarProductsPerCycle = 2 * 32 / sizeof(ElementType);
+ constexpr int64_t kTaskSize = kTargetCyclesPerTask * kScalarProductsPerCycle;
+
+ if (lhs_matrix.nonZeros() * rhs_matrix.cols() <= kTaskSize ||
+ thread_pool.num_threads() == 0) {
+ out_matrix.noalias() = lhs_matrix * rhs_matrix;
+
+ ffi::Promise promise;
+ promise.SetAvailable();
+ return ffi::Future(promise);
+ } else {
+ std::vector batch_sizes;
+ {
+ int64_t running_batch_nnz = 0;
+ int64_t running_number_rows = 0;
+ for (int row = 0; row < lhs_matrix.rows(); ++row) {
+ int64_t row_nnz = lhs_matrix.outerIndexPtr()[row + 1] -
+ lhs_matrix.outerIndexPtr()[row];
+ // If a row has no non-zero elements, the task still needs to write out a
+ // zero row, so we give each row a non-zero contribution; this avoids the
+ // pathological case of a single task having to write out many rows that
+ // cover a large block of zero inputs.
+ running_batch_nnz += std::max(row_nnz, static_cast(1));
+ running_number_rows++;
+ if (running_batch_nnz * rhs_matrix.cols() > kTaskSize) {
+ batch_sizes.push_back(running_number_rows);
+ running_batch_nnz = 0;
+ running_number_rows = 0;
+ } else if (row == lhs_matrix.rows() - 1 && running_number_rows > 0) {
+ batch_sizes.push_back(running_number_rows);
+ }
+ }
+ }
+
+ ffi::CountDownPromise promise(batch_sizes.size());
+ ffi::Future future(promise);
+ int64_t batch_start = 0;
+ for (int64_t size : batch_sizes) {
+ thread_pool.Schedule([out_matrix, lhs_matrix, rhs_matrix, batch_start,
+ size, promise]() mutable {
+ out_matrix.middleRows(batch_start, size).noalias() =
+ lhs_matrix.middleRows(batch_start, size) * rhs_matrix;
+ promise.CountDown();
+ });
+ batch_start += size;
+ }
+ return future;
+ }
+}
+
+template
+static ffi::Future CsrSparseDenseKernelTypedDispatch(
+ ffi::AnyBuffer lhs_data, ffi::AnyBuffer lhs_outer_indicies,
+ ffi::AnyBuffer lhs_inner_indicies, ffi::AnyBuffer rhs,
+ ffi::Result out, ffi::ThreadPool thread_pool) {
+ ffi::Span rhs_shape = rhs.dimensions();
+ ffi::Span out_shape = out->dimensions();
+
+ InputMap> lhs_matrix(
+ out_shape[0], rhs_shape[0], lhs_data.element_count(),
+ lhs_outer_indicies.reinterpret_data(),
+ lhs_inner_indicies.reinterpret_data(),
+ lhs_data.reinterpret_data());
+
+ InputMap> rhs_matrix(
+ rhs.reinterpret_data(), rhs_shape[0],
+ rhs_shape.size() > 1 ?
rhs_shape[1] : 1); + OutputMap> out_matrix( + out->reinterpret_data(), lhs_matrix.rows(), + rhs_matrix.cols()); + + return CsrSparseDenseKernelImpl( + lhs_matrix, rhs_matrix, out_matrix, thread_pool); +} + +template +static ffi::Future CsrSparseDenseKernelTypedDispatch( + ffi::AnyBuffer lhs_data, ffi::AnyBuffer lhs_outer_indicies, + ffi::AnyBuffer lhs_inner_indicies, ffi::AnyBuffer rhs, + ffi::Result out, ffi::ThreadPool thread_pool) { + if (lhs_outer_indicies.element_type() != lhs_inner_indicies.element_type()) { + ffi::Promise promise; + promise.SetError(ffi::Error(ffi::ErrorCode::kInvalidArgument, + "Sparse index type mismatch")); + return ffi::Future(promise); + } + + switch (lhs_outer_indicies.element_type()) { + case ffi::DataType::S32: + return CsrSparseDenseKernelTypedDispatch( + lhs_data, lhs_outer_indicies, lhs_inner_indicies, rhs, out, + thread_pool); + case ffi::DataType::S64: + return CsrSparseDenseKernelTypedDispatch( + lhs_data, lhs_outer_indicies, lhs_inner_indicies, rhs, out, + thread_pool); + default: + ffi::Promise promise; + promise.SetError(ffi::Error(ffi::ErrorCode::kInvalidArgument, + "Invalid index data type")); + return ffi::Future(promise); + } +} + +static ffi::Future CsrSparseDenseKernelDispatch( + ffi::AnyBuffer lhs_data, ffi::AnyBuffer lhs_outer_indicies, + ffi::AnyBuffer lhs_inner_indicies, ffi::AnyBuffer rhs, + ffi::Result out, ffi::ThreadPool thread_pool) { + if (lhs_data.element_type() != rhs.element_type() || + lhs_data.element_type() != out->element_type()) { + ffi::Promise promise; + promise.SetError( + ffi::Error(ffi::ErrorCode::kInvalidArgument, "Element type mismatch")); + return ffi::Future(promise); + } + + switch (lhs_data.element_type()) { + case ffi::DataType::S32: + return CsrSparseDenseKernelTypedDispatch( + lhs_data, lhs_outer_indicies, lhs_inner_indicies, rhs, out, + thread_pool); + case ffi::DataType::S64: + return CsrSparseDenseKernelTypedDispatch( + lhs_data, lhs_outer_indicies, lhs_inner_indicies, rhs, out, + thread_pool); + case ffi::DataType::F32: + return CsrSparseDenseKernelTypedDispatch( + lhs_data, lhs_outer_indicies, lhs_inner_indicies, rhs, out, + thread_pool); + case ffi::DataType::F64: + return CsrSparseDenseKernelTypedDispatch( + lhs_data, lhs_outer_indicies, lhs_inner_indicies, rhs, out, + thread_pool); + case ffi::DataType::C64: + return CsrSparseDenseKernelTypedDispatch>( + lhs_data, lhs_outer_indicies, lhs_inner_indicies, rhs, out, + thread_pool); + case ffi::DataType::C128: + return CsrSparseDenseKernelTypedDispatch>( + lhs_data, lhs_outer_indicies, lhs_inner_indicies, rhs, out, + thread_pool); + default: + ffi::Promise promise; + promise.SetError( + ffi::Error(ffi::ErrorCode::kInvalidArgument, "Invalid data type")); + return ffi::Future(promise); + } +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(cpu_csr_sparse_dense_ffi, + CsrSparseDenseKernelDispatch, + (ffi::Ffi::Bind() + .Arg(/*lhs_data*/) + .Arg( + /*lhs_outer_indicies*/) + .Arg( + /*lhs_inner_indicies*/) + .Arg(/*rhs*/) + .Ret(/*out*/) + .Ctx(/*thread_pool*/))); + +} // namespace jax diff --git a/jaxlib/cpu/sparse_kernels.h b/jaxlib/cpu/sparse_kernels.h new file mode 100644 index 000000000000..856b1da9d36c --- /dev/null +++ b/jaxlib/cpu/sparse_kernels.h @@ -0,0 +1,27 @@ +/* Copyright 2025 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_PY_JAX_JAXLIB_CPU_SPARSE_KERNELS_H_ +#define THIRD_PARTY_PY_JAX_JAXLIB_CPU_SPARSE_KERNELS_H_ + +#include "xla/ffi/api/ffi.h" + +namespace jax { + +XLA_FFI_DECLARE_HANDLER_SYMBOL(cpu_csr_sparse_dense_ffi); + +} // namespace jax + +#endif // THIRD_PARTY_PY_JAX_JAXLIB_CPU_SPARSE_KERNELS_H_ diff --git a/jaxlib/cpu_sparse.py b/jaxlib/cpu_sparse.py new file mode 100644 index 000000000000..ed43b3ee0f92 --- /dev/null +++ b/jaxlib/cpu_sparse.py @@ -0,0 +1,27 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any + +from .cpu import _sparse + + +def registrations() -> dict[str, list[tuple[str, Any, int]]]: + api_version = 1 + return { + "cpu": [ + (name, value, api_version) + for name, value in _sparse.registrations().items() + ] + } diff --git a/jaxlib/cuda/BUILD b/jaxlib/cuda/BUILD index a9bd35b7768d..5cc401e14eb0 100644 --- a/jaxlib/cuda/BUILD +++ b/jaxlib/cuda/BUILD @@ -64,7 +64,6 @@ cc_library( "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", - "@local_config_cuda//cuda:cublas_headers", "@local_config_cuda//cuda:cuda_headers", "@xla//xla/tsl/cuda:cupti", "@xla//xla/tsl/cuda:cusolver", @@ -89,7 +88,7 @@ cc_library( deps = [ ":cuda_gpu_kernel_helpers", ":cuda_vendor", - "//jaxlib:handle_pool", + "//jaxlib/gpu:handle_pool", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/synchronization", "@local_config_cuda//cuda:cuda_headers", @@ -98,55 +97,6 @@ cc_library( ], ) -cc_library( - name = "cublas_kernels", - srcs = ["//jaxlib/gpu:blas_kernels.cc"], - hdrs = ["//jaxlib/gpu:blas_kernels.h"], - deps = [ - ":cuda_blas_handle_pool", - ":cuda_gpu_kernel_helpers", - ":cuda_make_batch_pointers", - ":cuda_vendor", - "//jaxlib:kernel_helpers", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/base", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/hash", - "@com_google_absl//absl/memory", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - "@local_config_cuda//cuda:cublas_headers", - "@local_config_cuda//cuda:cuda_headers", - "@xla//xla/service:custom_call_status", - "@xla//xla/tsl/cuda:cublas", - "@xla//xla/tsl/cuda:cudart", - ], -) - -nanobind_extension( - name = "_blas", - srcs = ["//jaxlib/gpu:blas.cc"], - copts = [ - "-fexceptions", - "-fno-strict-aliasing", - ], - features = ["-use_header_modules"], - 
module_name = "_blas", - deps = [ - ":cublas_kernels", - ":cuda_vendor", - "//jaxlib:kernel_nanobind_helpers", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/strings:str_format", - "@nanobind", - "@xla//xla/tsl/cuda:cublas", - "@xla//xla/tsl/cuda:cudart", - "@xla//xla/tsl/python/lib/core:numpy", - ], -) - cc_library( name = "cudnn_rnn_kernels", srcs = ["//jaxlib/gpu:rnn_kernels.cc"], @@ -155,14 +105,14 @@ cc_library( ":cuda_gpu_kernel_helpers", ":cuda_vendor", ":ffi_wrapper", - "//jaxlib:handle_pool", "//jaxlib:kernel_helpers", + "//jaxlib/gpu:handle_pool", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/synchronization", "@local_config_cuda//cuda:cuda_headers", "@xla//xla/ffi/api:ffi", - "@xla//xla/service:custom_call_status", "@xla//xla/tsl/cuda:cudart", "@xla//xla/tsl/cuda:cudnn", ], @@ -195,7 +145,7 @@ cc_library( deps = [ ":cuda_gpu_kernel_helpers", ":cuda_vendor", - "//jaxlib:handle_pool", + "//jaxlib/gpu:handle_pool", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/synchronization", "@local_config_cuda//cuda:cuda_headers", @@ -204,24 +154,6 @@ cc_library( ], ) -cc_library( - name = "cusolver_kernels", - srcs = ["//jaxlib/gpu:solver_kernels.cc"], - hdrs = ["//jaxlib/gpu:solver_kernels.h"], - deps = [ - ":cuda_gpu_kernel_helpers", - ":cuda_solver_handle_pool", - ":cuda_vendor", - "//jaxlib:kernel_helpers", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@local_config_cuda//cuda:cuda_headers", - "@xla//xla/service:custom_call_status", - "@xla//xla/tsl/cuda:cudart", - "@xla//xla/tsl/cuda:cusolver", - ], -) - cc_library( name = "cusolver_interface", srcs = ["//jaxlib/gpu:solver_interface.cc"], @@ -272,21 +204,14 @@ nanobind_extension( features = ["-use_header_modules"], module_name = "_solver", deps = [ - ":cuda_gpu_kernel_helpers", - ":cuda_solver_handle_pool", ":cuda_vendor", - ":cusolver_kernels", ":cusolver_kernels_ffi", "//jaxlib:kernel_nanobind_helpers", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:str_format", "@local_config_cuda//cuda:cuda_headers", "@nanobind", "@xla//xla/tsl/cuda:cublas", "@xla//xla/tsl/cuda:cudart", "@xla//xla/tsl/cuda:cusolver", - "@xla//xla/tsl/python/lib/core:numpy", ], ) @@ -308,15 +233,16 @@ cc_library( ":cuda_gpu_kernel_helpers", ":cuda_vendor", ":ffi_wrapper", - "//jaxlib:handle_pool", + "//jaxlib:ffi_helpers", "//jaxlib:kernel_helpers", + "//jaxlib/gpu:handle_pool", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", "@local_config_cuda//cuda:cuda_headers", "@xla//xla/ffi/api:ffi", - "@xla//xla/service:custom_call_status", "@xla//xla/tsl/cuda:cudart", "@xla//xla/tsl/cuda:cusparse", ], @@ -423,7 +349,6 @@ cc_library( "@local_config_cuda//cuda:cuda_headers", "@xla//xla/ffi/api:c_api", "@xla//xla/ffi/api:ffi", - "@xla//xla/service:custom_call_status", ], ) @@ -439,7 +364,6 @@ cuda_library( "//jaxlib:kernel_helpers", "@local_config_cuda//cuda:cuda_headers", "@xla//xla/ffi/api:ffi", - "@xla//xla/service:custom_call_status", ], ) @@ -455,6 +379,7 @@ nanobind_extension( deps = [ ":cuda_gpu_kernel_helpers", ":cuda_prng_kernels", + ":cuda_vendor", "//jaxlib:kernel_nanobind_helpers", "@local_config_cuda//cuda:cuda_headers", "@nanobind", @@ -511,15 +436,14 @@ 
cc_library( srcs = ["//jaxlib/gpu:gpu_kernels.cc"], visibility = ["//visibility:public"], deps = [ - ":cublas_kernels", ":cuda_linalg_kernels", ":cuda_prng_kernels", ":cuda_vendor", ":cudnn_rnn_kernels", - ":cusolver_kernels", ":cusolver_kernels_ffi", ":cusparse_kernels", ":triton_kernels", + "//jaxlib/mosaic/gpu:mosaic_gpu_support", "@xla//xla/ffi/api:c_api", "@xla//xla/ffi/api:ffi", "@xla//xla/service:custom_call_target_registry", @@ -545,8 +469,8 @@ cc_library( "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/synchronization", - "@tsl//tsl/platform:env", "@xla//xla/service:custom_call_status", "@xla//xla/stream_executor/cuda:cuda_asm_compiler", "@xla//xla/tsl/cuda:cudart", @@ -586,7 +510,9 @@ nanobind_extension( "//jaxlib:absl_status_casters", "//jaxlib:kernel_nanobind_helpers", "//jaxlib/gpu:triton_cc_proto", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", "@nanobind", ], ) @@ -644,7 +570,6 @@ nanobind_extension( py_library( name = "cuda_gpu_support", deps = [ - ":_blas", ":_hybrid", ":_linalg", ":_prng", @@ -657,11 +582,52 @@ py_library( ], ) +cc_library( + name = "py_client_gpu", + srcs = ["//jaxlib/gpu:py_client_gpu.cc"], + hdrs = ["//jaxlib/gpu:py_client_gpu.h"], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + deps = [ + ":cuda_vendor", + "//jaxlib:ffi", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + "@dlpack", + "@nanobind", + "@xla//third_party/python_runtime:headers", # buildcleaner: keep + "@xla//xla:comparison_util", + "@xla//xla:shape_util", + "@xla//xla:util", + "@xla//xla:xla_data_proto_cc", + "@xla//xla/ffi:ffi_api", + "@xla//xla/ffi/api:ffi", + "@xla//xla/pjrt:host_callback", + "@xla//xla/pjrt:transpose", + "@xla//xla/python:nb_numpy", + "@xla//xla/python:types", + "@xla//xla/service:platform_util", + ], +) + nanobind_extension( name = "cuda_plugin_extension", srcs = ["cuda_plugin_extension.cc"], module_name = "cuda_plugin_extension", deps = [ + ":py_client_gpu", + "//jaxlib:kernel_nanobind_helpers", "//jaxlib/gpu:gpu_plugin_extension", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", diff --git a/jaxlib/cuda/cuda_plugin_extension.cc b/jaxlib/cuda/cuda_plugin_extension.cc index 8d8514bd2740..383bbf7731aa 100644 --- a/jaxlib/cuda/cuda_plugin_extension.cc +++ b/jaxlib/cuda/cuda_plugin_extension.cc @@ -16,17 +16,20 @@ limitations under the License. 
#include #include -#include "nanobind/nanobind.h" #include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "third_party/gpus/cuda/include/cuda.h" +#include "nanobind/nanobind.h" #include "jaxlib/gpu/gpu_plugin_extension.h" +#include "jaxlib/gpu/py_client_gpu.h" +#include "jaxlib/kernel_nanobind_helpers.h" #include "xla/pjrt/status_casters.h" namespace nb = nanobind; namespace xla { namespace { + static std::string ToString(CUresult result) { const char* error_name; if (cuGetErrorName(result, &error_name)) { @@ -38,10 +41,30 @@ static std::string ToString(CUresult result) { } return absl::StrCat(error_name, ": ", error_string); } + +nb::dict FfiRegistrations() { + nb::dict dict; + nb::dict gpu_callback_dict; + gpu_callback_dict["instantiate"] = + jax::EncapsulateFfiHandler(jax::cuda::kGpuTransposePlanCacheInstantiate); + gpu_callback_dict["execute"] = + jax::EncapsulateFfiHandler(jax::cuda::kXlaFfiPythonGpuCallback); + dict["xla_ffi_python_gpu_callback"] = gpu_callback_dict; + dict["xla_ffi_partitioned_python_gpu_callback"] = gpu_callback_dict; + dict["xla_buffer_python_gpu_callback"] = + jax::EncapsulateFfiHandler(jax::cuda::kXlaBufferPythonGpuCallback); + dict["xla_buffer_python_gpu_callback_cmd_buffer"] = + jax::EncapsulateFfiHandler( + jax::cuda::kXlaBufferPythonGpuCallbackCmdBuffer); + return dict; +} + } // namespace NB_MODULE(cuda_plugin_extension, m) { BuildGpuPluginExtension(m); + m.def("ffi_registrations", &FfiRegistrations); + m.def( "get_device_ordinal", [](std::intptr_t data_value) { diff --git a/jaxlib/cuda/versions.cc b/jaxlib/cuda/versions.cc index 8d6577f46709..d9f9f4c86865 100644 --- a/jaxlib/cuda/versions.cc +++ b/jaxlib/cuda/versions.cc @@ -13,9 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "jaxlib/cuda/versions_helpers.h" - #include "nanobind/nanobind.h" +#include "jaxlib/cuda/versions_helpers.h" #include "jaxlib/gpu/vendor.h" namespace jax::cuda { diff --git a/jaxlib/cuda/versions_helpers.cc b/jaxlib/cuda/versions_helpers.cc index d42199d37467..508a92c326cb 100644 --- a/jaxlib/cuda/versions_helpers.cc +++ b/jaxlib/cuda/versions_helpers.cc @@ -16,6 +16,7 @@ limitations under the License. #include "jaxlib/cuda/versions_helpers.h" #include +#include #include #include "absl/base/dynamic_annotations.h" diff --git a/jaxlib/custom_call_sharding.cc b/jaxlib/custom_call_sharding.cc new file mode 100644 index 000000000000..3e16768d0c29 --- /dev/null +++ b/jaxlib/custom_call_sharding.cc @@ -0,0 +1,346 @@ +/* Copyright 2022 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "jaxlib/custom_call_sharding.h" + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/optional.h" // IWYU pragma: keep +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "nanobind/stl/string_view.h" // IWYU pragma: keep +#include "nanobind/stl/tuple.h" // IWYU pragma: keep +#include "nanobind/stl/vector.h" // IWYU pragma: keep +#include "xla/hlo/ir/hlo_sharding.h" +#include "xla/hlo/utils/hlo_sharding_util.h" +#include "xla/pjrt/c/pjrt_c_api.h" +#include "xla/pjrt/c/pjrt_c_api_custom_partitioner_extension.h" +#include "xla/pjrt/c/pjrt_c_api_helpers.h" +#include "xla/pjrt/status_casters.h" +#include "xla/python/custom_call_batch_partitioner.h" +#include "xla/python/custom_partition_callback.h" +#include "xla/python/inspect_sharding.h" +#include "xla/shape.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/util.h" + +namespace xla { + +namespace nb = ::nanobind; + +class PyCustomCallPartitionerCallbacks { + public: + PyCustomCallPartitionerCallbacks(nb::object prop_user_sharding, + nb::object partition, + nb::object infer_sharding_from_operands) + : prop_user_sharding_(prop_user_sharding), + partition_(partition), + infer_sharding_from_operands_(infer_sharding_from_operands) { + callbacks_.version = 0; + callbacks_.private_data = this; + callbacks_.dtor = +[](JAX_CustomCallPartitioner_Callbacks* self) { + delete GetSelfPtr(self); + }; + callbacks_.partition = +[](JAX_CustomCallPartitioner_Callbacks* self, + JAX_CustomCallPartitioner_Partition_Args* args) { + jax::PopulateResults(GetSelfPtr(self)->CallPartition(args), args); + }; + callbacks_.infer_sharding = + +[](JAX_CustomCallPartitioner_Callbacks* self, + JAX_CustomCallPartitioner_InferShardingFromOperands_Args* args) { + jax::PopulateResults( + GetSelfPtr(self)->CallInferShardingFromOperands(args), args); + }; + callbacks_.propagate_user_sharding = + +[](JAX_CustomCallPartitioner_Callbacks* self, + JAX_CustomCallPartitioner_PropagateUserSharding_Args* args) { + jax::PopulateResults( + GetSelfPtr(self)->CallPropagateUserSharding(args), args); + }; + } + + absl::StatusOr< + std::tuple, xla::HloSharding>> + CallPartition(JAX_CustomCallPartitioner_Partition_Args* args) const { + if (args->header.api_version != 0) { + return absl::InternalError("API version mismatch."); + } + TF_ASSIGN_OR_RETURN(auto args_tuple, jax::ReadArgs(args)); + std::vector shapes = std::move(std::get<0>(args_tuple)); + std::vector> shardings = + std::move(std::get<1>(args_tuple)); + xla::Shape result_shape = std::move(std::get<2>(args_tuple)); + std::optional result_sharding = + std::move(std::get<3>(args_tuple)); + absl::string_view backend_config = std::move(std::get<4>(args_tuple)); + + { + nb::gil_scoped_acquire gil; + try { + auto py_result = + partition_(shapes, shardings, result_shape, result_sharding, + nb::bytes(backend_config.data(), backend_config.size())); + try { + auto [ir, arg_shardings, result_sharding] = nb::cast< + std::tuple, HloSharding>>( + py_result); + if (arg_shardings.size() != args->num_args) { + return xla::Internal( + "Shardings returned from partitioning: lengths must match: %d " + "vs %d", + arg_shardings.size(), args->num_args); + } + return std::make_tuple(std::string(ir.c_str(), 
ir.size()), + std::move(arg_shardings), + std::move(result_sharding)); + } catch (const nb::cast_error& e) { + return xla::Internal( + "Shardings returned from partitioning: expected " + "Tuple[bytes, List[HloSharding], HloSharding] got: %s", + nb::cast(nb::repr(py_result))); + } + } catch (const nb::python_error& e) { + return xla::Internal("custom_partitioner: %s", e.what()); + } + } + } + + absl::StatusOr> CallInferShardingFromOperands( + JAX_CustomCallPartitioner_InferShardingFromOperands_Args* args) const { + if (args->header.api_version != 0) { + return absl::InternalError("API version mismatch."); + } + TF_ASSIGN_OR_RETURN(auto args_tuple, jax::ReadArgs(args)); + std::vector arg_shapes = std::move(std::get<0>(args_tuple)); + std::vector> arg_shardings = + std::move(std::get<1>(args_tuple)); + xla::Shape result_shape = std::move(std::get<2>(args_tuple)); + absl::string_view backend_config = std::move(std::get<3>(args_tuple)); + + std::optional result; + nb::gil_scoped_acquire gil; + try { + auto py_result = infer_sharding_from_operands_( + arg_shapes, arg_shardings, result_shape, + nb::bytes(backend_config.data(), backend_config.size())); + if (py_result.is_none()) { + return std::nullopt; + } + return nb::cast(py_result); + } catch (const nb::python_error& e) { + return xla::Internal("custom_partitioner: %s", e.what()); + } + } + + absl::StatusOr CallPropagateUserSharding( + JAX_CustomCallPartitioner_PropagateUserSharding_Args* args) const { + if (args->header.api_version != 0) { + return absl::InternalError("API version mismatch."); + } + TF_ASSIGN_OR_RETURN(auto args_tuple, jax::ReadArgs(args)); + xla::HloSharding result_sharding = std::move(std::get<0>(args_tuple)); + xla::Shape result_shape = std::move(std::get<1>(args_tuple)); + absl::string_view backend_config = std::move(std::get<2>(args_tuple)); + + nb::gil_scoped_acquire gil; + try { + // TODO(parkers): expand this API to handle the `user` sharding. + // The user is used when the custom call returns a Tuple and + // the user is a get-tuple-element. In this case we must update only + // part of the sharding spec. 
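// [Editor's note.] The callback below is invoked under the GIL with the
// result sharding, the result shape, and the backend_config bytes; its return
// value is cast back to an HloSharding, and any Python exception it raises is
// converted into an internal error status by the catch clause that follows.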
+ auto result = nb::cast(prop_user_sharding_( + result_sharding, result_shape, + nb::bytes(backend_config.data(), backend_config.size()))); + return result; + } catch (const nb::python_error& e) { + return xla::Internal("custom_partitioner: %s", e.what()); + } + } + + JAX_CustomCallPartitioner_Callbacks* callbacks() { return &callbacks_; } + + private: + static PyCustomCallPartitionerCallbacks* GetSelfPtr( + JAX_CustomCallPartitioner_Callbacks* callbacks) { + return reinterpret_cast( + callbacks->private_data); + } + + JAX_CustomCallPartitioner_Callbacks callbacks_; + nb::object prop_user_sharding_; + nb::object partition_; + nb::object infer_sharding_from_operands_; +}; + +namespace { + +void CallInspectSharding(void* obj, JAX_InspectSharding_Callback_Args* args) { + std::optional arg = jax::InspectShardingReadArgs(args); + if (!arg.has_value()) { + return; + } + try { + nb::gil_scoped_acquire gil; + nb::handle(reinterpret_cast(obj))(*std::move(arg)); + } catch (const nb::python_error& e) { + jax::InspectShardingSetError(args, std::string(e.what())); + } +} + +} // namespace + +void BuildCustomCallShardingPybindAPI(nb::module_& m) { + m.def( + "register_custom_call_partitioner", + [](std::string name, nb::object prop_user_sharding, nb::object partition, + nb::object infer_sharding_from_operands, + bool can_side_effecting_have_replicated_sharding, + std::optional c_api) { + auto* c_fns = + (new PyCustomCallPartitionerCallbacks(prop_user_sharding, partition, + infer_sharding_from_operands)) + ->callbacks(); + c_fns->can_side_effecting_have_replicated_sharding = + can_side_effecting_have_replicated_sharding; + if (!c_api.has_value()) { + RegisterCustomCallPartitioner( + name, jax::CreateCApiCustomCallPartitioner(c_fns)); + return; + } + + if (absl::string_view(c_api->name()) != "pjrt_c_api") { + throw absl::InvalidArgumentError( + "Argument to register_custom_call_partitioner was not a " + "pjrt_c_api capsule."); + } + auto* c_api_value = static_cast(c_api->data()); + PJRT_Custom_Partitioner_Extension* extension = + pjrt::FindExtension( + c_api_value, + PJRT_Extension_Type::PJRT_Extension_Type_Custom_Partitioner); + if (extension == nullptr) { + return; + } + PJRT_Register_Custom_Partitioner_Args args; + args.struct_size = PJRT_Register_Custom_Partitioner_Args_STRUCT_SIZE; + args.name = name.c_str(); + args.name_size = name.size(); + args.callbacks = c_fns; + PJRT_Error* error = + reinterpret_cast( + extension) + ->register_custom_partitioner(&args); + std::unique_ptr error_ptr( + error, pjrt::MakeErrorDeleter(c_api_value)); + ThrowIfError(pjrt::PjrtErrorToStatus(error_ptr.get(), c_api_value)); + }, + R"(Registers a partitioner for a custom-call operation. + +Args: + name: custom_call_target to match. + prop_user_sharding: Custom backwards sharding propagation rule. + Takes result sharding and returns the instruction sharding. + partition: Lowering rule. Takes operand and result shardings and returns + a generated HLO and sharding specs. The spmd lowerer first reshards + to match the returned sharding specs and then inserts the generated hlo. + infer_sharding_from_operands: Custom forwards sharding propagation rule. + Takes operand sharding and returns the instruction sharding. + can_side_effecting_have_replicated_sharding: Side effecting ops are not + allowed to have replicated sharding. Pass true to disable this check. + c_api: Optional `PJRT_Api*` if it is called with a plugin. 
This is safe to + call on plugins that do not implement the custom partitioner extension +)", + nb::arg("name"), nb::arg("prop_user_sharding"), nb::arg("partition"), + nb::arg("infer_sharding_from_operands"), + nb::arg("can_side_effecting_have_replicated_sharding") = false, + nb::arg("c_api").none() = std::nullopt); + m.def("encode_inspect_sharding_callback", + [](nb::object handler) -> nb::bytes { + JAX_InspectSharding_Callback cb; + cb.call = &CallInspectSharding; + cb.data = handler.ptr(); + char bytes[sizeof(JAX_InspectSharding_Callback)]; + std::memcpy(&bytes, &cb, sizeof(JAX_InspectSharding_Callback)); + return nb::bytes(bytes, sizeof(JAX_InspectSharding_Callback)); + }); + + nb::module_ hlo_sharding_util_m = m.def_submodule( + "hlo_sharding_util", "Utilities for manipulating HloSharding."); + hlo_sharding_util_m.def( + "PartiallyReplicateTiledShardingOnDims", + [](const HloSharding& sharding, std::vector dims) { + return hlo_sharding_util::PartiallyReplicateTiledShardingOnDims( + sharding, dims); + }); + + m.def( + "register_custom_call_as_batch_partitionable", + [](std::string target_name, std::optional c_api) { + if (!c_api.has_value()) { + RegisterCustomCallPartitioner( + target_name, std::make_unique()); + return; + } + if (absl::string_view(c_api->name()) != "pjrt_c_api") { + throw absl::InvalidArgumentError( + "Argument to register_custom_call_partitioner was not a " + "pjrt_c_api capsule."); + } + auto* c_api_value = static_cast(c_api->data()); + PJRT_Custom_Partitioner_Extension* extension = + pjrt::FindExtension( + c_api_value, + PJRT_Extension_Type::PJRT_Extension_Type_Custom_Partitioner); + if (extension == nullptr) { + return; + } + PJRT_Register_Batch_Partitionable_Args args; + args.struct_size = PJRT_Register_Batch_Partitionable_Args_STRUCT_SIZE; + args.name = target_name.c_str(); + args.name_size = target_name.size(); + PJRT_Error* error = extension->register_batch_partitionable(&args); + std::unique_ptr error_ptr( + error, pjrt::MakeErrorDeleter(c_api_value)); + ThrowIfError(pjrt::PjrtErrorToStatus(error_ptr.get(), c_api_value)); + }, + R"(Registers a custom call as batch partitionable. + +If a custom call is "batch partitionable", it means that it can be trivially +partitioned on some number of (leading) dimensions, with the same call being +executed independently on each shard of data. If the data are sharded on +non-batch dimensions, partitioning will re-shard the data to be replicated on +the non-batch dimensions. + +Args: + target_name: the target name of the batch partitionable custom call. + c_api: optional `PJRT_Api*` to support registration via a PJRT plugin. +)", + nb::arg("target_name"), nb::arg("c_api").none() = std::nullopt); +} + +} // namespace xla diff --git a/jaxlib/custom_call_sharding.h b/jaxlib/custom_call_sharding.h new file mode 100644 index 000000000000..454f60c3a03d --- /dev/null +++ b/jaxlib/custom_call_sharding.h @@ -0,0 +1,28 @@ +/* Copyright 2022 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef JAXLIB_CUSTOM_CALL_SHARDING_H_ +#define JAXLIB_CUSTOM_CALL_SHARDING_H_ + +// placeholder for index annotation headers +#include "nanobind/nanobind.h" + +namespace xla { + +void BuildCustomCallShardingPybindAPI(nanobind::module_& m); + +} // namespace xla + +#endif // JAXLIB_CUSTOM_CALL_SHARDING_H_ diff --git a/jaxlib/dlpack.cc b/jaxlib/dlpack.cc new file mode 100644 index 000000000000..c58eac81b9a7 --- /dev/null +++ b/jaxlib/dlpack.cc @@ -0,0 +1,503 @@ +/* Copyright 2020 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/dlpack.h" + +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "include/dlpack/dlpack.h" +#include "llvm/Support/Casting.h" +#include "nanobind/nanobind.h" +#include "nanobind/ndarray.h" +#include "jaxlib/nb_class_ptr.h" +#include "jaxlib/py_array.h" +#include "jaxlib/py_client.h" +#include "jaxlib/python_ref_manager.h" +#include "jaxlib/traceback.h" +#include "jaxlib/util.h" +#include "jaxlib/dlpack_support.h" +#include "xla/layout.h" +#include "xla/pjrt/exceptions.h" +#include "xla/pjrt/pjrt_client.h" +#include "xla/pjrt/pjrt_common.h" +#include "xla/pjrt/pjrt_compiler.h" +#include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/device.h" +#include "xla/python/pjrt_ifrt/pjrt_array.h" +#include "xla/python/pjrt_ifrt/pjrt_client.h" +#include "xla/python/pjrt_ifrt/pjrt_device.h" +#include "xla/python/types.h" +#include "xla/shape_util.h" +#include "xla/status_macros.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/util.h" +#include "xla/xla_data.pb.h" + +namespace nb = nanobind; + +namespace xla { +namespace { + +const char* const kDlTensorCapsuleName = "dltensor"; + +struct DLPackTensor { + ~DLPackTensor(); + + // `buffer_reference` is populated if we have shared (read-only) access. + nb::object buffer_reference; + + // `external_reference` is always populated. 
+ std::unique_ptr external_reference; + + std::vector shape; + std::vector strides; + DLManagedTensor tensor; +}; + +DLPackTensor::~DLPackTensor() { + if (buffer_reference) { + GlobalPyRefManager()->AddGarbage( + absl::MakeSpan(&buffer_reference, /*size=*/1)); + } +} + +void DLPackTensorDeleter(DLManagedTensor* t) { + if (t) { + delete static_cast(t->manager_ctx); + } +} + +absl::StatusOr> StridesToLayout( + absl::Span dims, absl::Span strides) { + CHECK_EQ(dims.size(), strides.size()); + std::vector minor_to_major(dims.size()); + std::iota(minor_to_major.begin(), minor_to_major.end(), 0); + absl::c_sort(minor_to_major, [&](int a, int b) { + if (strides[a] < strides[b]) { + return true; + } + if (strides[a] > strides[b]) { + return false; + } + // If two dimensions have the same stride, prefer the major-to-minor + // interpretation of the ordering, since that's what JAX wants. + return b < a; + }); + int64_t stride = 1; + for (int64_t d : minor_to_major) { + if (dims[d] > 1 && strides[d] != stride) { + return Unimplemented( + "Only DLPack tensors with trivial (compact) striding are supported; " + "i.e., tensors whose striding represents a transposition of the " + "underlying buffer but not broadcasting. Dimensions were: [%s], " + "strides were [%s].", + absl::StrJoin(dims, ","), absl::StrJoin(strides, ",")); + } + stride *= dims[d]; + } + return minor_to_major; +} + +absl::StatusOr DLDeviceTypeForDevice(const PjRtDevice& device) { + if (device.client()->platform_id() == CpuId()) { + return kDLCPU; + } else if (device.client()->platform_id() == CudaId()) { + return kDLCUDA; + } else if (device.client()->platform_id() == RocmId()) { + return kDLROCM; + } + return InvalidArgument("Device %s cannot be used as a DLPack device.", + device.DebugString()); +} + +absl::StatusOr DLDeviceForDevice(const PjRtDevice& device) { + DLDevice context; + TF_ASSIGN_OR_RETURN(context.device_type, DLDeviceTypeForDevice(device)); + context.device_id = device.local_hardware_id().value(); + return context; +} + +absl::StatusOr DeviceForDLDevice(const PjRtClient* cpu_client, + const PjRtClient* gpu_client, + const DLDevice& context) { + switch (context.device_type) { + case kDLCPU: + if (cpu_client == nullptr) { + return InvalidArgument( + "DLPack tensor is on CPU, but no CPU backend was provided."); + } + TF_RET_CHECK(cpu_client->platform_id() == CpuId()); + return cpu_client->LookupAddressableDevice( + xla::PjRtLocalDeviceId(context.device_id)); + case kDLCUDA: + if (gpu_client == nullptr) { + return InvalidArgument( + "DLPack tensor is on GPU, but no GPU backend was provided."); + } + TF_RET_CHECK(gpu_client->platform_id() == CudaId()); + return gpu_client->LookupAddressableDevice( + xla::PjRtLocalDeviceId(context.device_id)); + case kDLROCM: + if (gpu_client == nullptr) { + return InvalidArgument( + "DLPack tensor is on GPU, but no GPU backend was provided."); + } + TF_RET_CHECK(gpu_client->platform_id() == RocmId()); + return gpu_client->LookupAddressableDevice( + xla::PjRtLocalDeviceId(context.device_id)); + default: + return InvalidArgument("Unknown/unsupported DLPack device type %d", + context.device_type); + } +} + +absl::Status VerifyDType(const DLTensor& dl_tensor) { + if (dl_tensor.dtype.bits % 8 != 0) { + return InvalidArgument( + "Unsupported DLPack tensor dtype: bits should be a multiple of 8, got " + "%d", + dl_tensor.dtype.bits); + } + + if (dl_tensor.dtype.lanes != 1) { + return InvalidArgument( + "Unsupported DLPack tensor dtype: lanes should be equal to 1, got %d", + dl_tensor.dtype.lanes); + } 
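// [Editor's note.] Reaching this point means the dtype has a byte-aligned
// width (bits % 8 == 0) and a single lane, which is all the bits-to-bytes
// stride arithmetic in GetByteStrides below relies on.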
+ + return absl::OkStatus(); +} + +absl::StatusOr> GetByteStrides(const DLTensor& dl_tensor) { + TF_RETURN_IF_ERROR(VerifyDType(dl_tensor)); + + // Convert element strides from the number of elements to the number of bytes. + std::vector strides; + strides.reserve(dl_tensor.ndim); + for (int i = 0; i < dl_tensor.ndim; ++i) { + strides.push_back(dl_tensor.strides[i] * dl_tensor.dtype.bits / 8); + } + return strides; +} + +absl::StatusOr> MakePjrtBuffer( + PjRtDevice& device, ::DLManagedTensor* dlmt, const Shape& shape, + PrimitiveType element_type, absl::Span dimensions, + std::optional stream = std::nullopt) { + std::function on_delete_callback; + if (dlmt->deleter) { + on_delete_callback = [dlmt]() { dlmt->deleter(dlmt); }; + } + + // First try to create a view. + void* data = + static_cast(dlmt->dl_tensor.data) + dlmt->dl_tensor.byte_offset; + auto result = device.client()->CreateViewOfDeviceBuffer( + data, shape, *device.default_memory_space(), on_delete_callback, stream); + + // If that fails with invalid argument, it's possibly because of the incorrect + // alignment. If we're on CPU, we can create a copy of buffer. + if (result.status().code() == absl::StatusCode::kInvalidArgument && + dlmt->dl_tensor.device.device_type == kDLCPU) { + LOG(WARNING) << "DLPack buffer is not aligned (data at: " << data + << "). Creating a copy."; + + // Convert tensor strides (expressed in number of elements) to byte strides. + std::optional> byte_strides; + if (dlmt->dl_tensor.strides) { + TF_ASSIGN_OR_RETURN(byte_strides, GetByteStrides(dlmt->dl_tensor)); + } + + TF_ASSIGN_OR_RETURN(auto* memory_space, device.default_memory_space()); + + // Create a copy. + result = device.client()->BufferFromHostBuffer( + data, element_type, dimensions, byte_strides, + PjRtClient::HostBufferSemantics::kMutableZeroCopy, on_delete_callback, + memory_space, /*device_layout=*/nullptr); + } + return result; +} + +} // namespace + +absl::StatusOr BufferToDLPackManagedTensor( + nb::handle py_buffer, std::optional stream) { + ifrt::Array* ifrt_array = nb::cast(py_buffer).ifrt_array(); + if (ifrt_array == nullptr) { + return Unimplemented( + "BufferToDLPackManagedTensor called on deleted array."); + } + auto* arr = llvm::dyn_cast_or_null(ifrt_array); + if (arr == nullptr) { + throw XlaRuntimeError( + "This operation is implemented for a PjRt-compatible backend only."); + } + PjRtBuffer* pjrt_buffer = arr->pjrt_buffers().front().get(); + + if (pjrt_buffer->IsTuple()) { + return Unimplemented( + "BufferToDLPackManagedTensor is not implemented for tuple " + "buffers."); + } + if (pjrt_buffer->has_dynamic_dimensions()) { + return Unimplemented("DynamicShape is not implemented in DLPack."); + } + + auto pack = std::make_unique(); + DLTensor& dt = pack->tensor.dl_tensor; + { + // AcquireExternalReference may block; there are no API guarantees. 
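// [Editor's note.] Deferred Python object deletions are flushed first, and
// the GIL is released around the acquire/wait below so that other Python
// threads can make progress while the buffer is being readied.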
+ GlobalPyRefManager()->CollectGarbage(); + nb::gil_scoped_release gil_release; + TF_ASSIGN_OR_RETURN(pack->external_reference, + pjrt_buffer->AcquireExternalReference()); + if (stream) { + TF_RETURN_IF_ERROR( + pack->external_reference->WaitUntilBufferReadyOnStream(*stream)); + } else { + TF_RETURN_IF_ERROR( + AwaitBuffersReady(absl::MakeConstSpan(&ifrt_array, 1))); + } + } + pack->buffer_reference = nb::borrow(py_buffer); + + dt.data = pack->external_reference->OpaqueDeviceMemoryDataPointer(); + pack->tensor.manager_ctx = pack.get(); + pack->tensor.deleter = DLPackTensorDeleter; + TF_ASSIGN_OR_RETURN(dt.device, DLDeviceForDevice(*pjrt_buffer->device())); + dt.device.device_id = pjrt_buffer->device()->local_hardware_id().value(); + dt.ndim = pjrt_buffer->dimensions().size(); + TF_ASSIGN_OR_RETURN(dt.dtype, + PrimitiveTypeToDLDataType(pjrt_buffer->element_type())); + + pack->shape = std::vector(pjrt_buffer->dimensions().begin(), + pjrt_buffer->dimensions().end()); + + // TODO(b/327524065): use PjRtLayout directly instead of xla::Layout + Layout xla_layout = pjrt_buffer->layout()->xla_layout(); + pack->strides = StridesForShape(pjrt_buffer->element_type(), + pjrt_buffer->dimensions(), xla_layout); + + dt.shape = reinterpret_cast(pack->shape.data()); + dt.strides = reinterpret_cast(pack->strides.data()); + dt.byte_offset = 0; + + // We cannot use nanobind's capsule object constructor because we need to + // detect if the capsule name has been changed in the deleter, but nanobind + // hides the underlying Python object from the deleter. + nb::capsule capsule = nb::steal( + PyCapsule_New(&pack.release()->tensor, kDlTensorCapsuleName, + [](PyObject* obj) noexcept { + DLManagedTensor* dlmt = static_cast( + PyCapsule_GetPointer(obj, kDlTensorCapsuleName)); + if (dlmt) { + DLPackTensorDeleter(dlmt); + } else { + // The tensor has been deleted. Clear any error from + // PyCapsule_GetPointer. + PyErr_Clear(); + } + })); + if (!capsule.ptr()) { + throw nb::python_error(); + } + return capsule; +} + +absl::StatusOr DLPackManagedTensorToBuffer( + const nb::capsule& tensor, std::optional> cpu_client, + std::optional> gpu_client) { + // TODO(hyeontaek): This is a potential target for an IFRT client to multiplex + // multiple PjRt clients. Devices from these PjRt clients could be expressed + // as a unified set of IFRT devices. + auto* cpu_pjrt_client = cpu_client ? (*cpu_client)->pjrt_client() : nullptr; + auto* gpu_pjrt_client = gpu_client ? (*gpu_client)->pjrt_client() : nullptr; + + if (absl::string_view(tensor.name()) != kDlTensorCapsuleName) { + return InvalidArgument( + "DLPack tensor must be a capsule with name \"dltensor\", got \"%s\". " + "Note that a DLPack tensor may be consumed at most once.", + absl::string_view(tensor.name())); + } + DLManagedTensor* dlmt = static_cast(tensor.data()); + if (dlmt->dl_tensor.ndim < 0) { + return InvalidArgument( + "Number of dimensions in DLManagedTensor must be nonnegative, got %d", + dlmt->dl_tensor.ndim); + } + TF_ASSIGN_OR_RETURN(PjRtDevice * device, + DeviceForDLDevice(cpu_client ? cpu_pjrt_client : nullptr, + gpu_client ? 
gpu_pjrt_client : nullptr, + dlmt->dl_tensor.device)); + absl::Span dimensions( + reinterpret_cast(dlmt->dl_tensor.shape), dlmt->dl_tensor.ndim); + TF_ASSIGN_OR_RETURN(PrimitiveType element_type, + DLDataTypeToPrimitiveType(dlmt->dl_tensor.dtype)); + + std::vector minor_to_major; + if (dlmt->dl_tensor.strides && + absl::c_find(dimensions, 0) == dimensions.end()) { + absl::Span strides( + reinterpret_cast(dlmt->dl_tensor.strides), + dlmt->dl_tensor.ndim); + TF_ASSIGN_OR_RETURN(minor_to_major, StridesToLayout(dimensions, strides)); + } else { + minor_to_major.resize(dlmt->dl_tensor.ndim); + std::iota(minor_to_major.rbegin(), minor_to_major.rend(), 0); + } + Shape shape = ShapeUtil::MakeShapeWithDenseLayout(element_type, dimensions, + minor_to_major); + + // Raise an error if the resulting PjRtBuffer would have a non-default layout. + // TODO(skyewm): we do this because JAX doesn't currently have good support + // for non-default layouts, and will return wrong results if a non-default + // layout is passed to a computation expecting default layouts. Remove this + // special case when non-default layouts are better supported by JAX. + TF_ASSIGN_OR_RETURN(Layout default_layout, device->client()->GetDefaultLayout( + element_type, dimensions)); + if (shape.layout() != default_layout) { + return Unimplemented( + "from_dlpack got array with non-default layout with minor-to-major " + "dimensions (%s), expected (%s)", + absl::StrJoin(shape.layout().minor_to_major(), ","), + absl::StrJoin(default_layout.minor_to_major(), ",")); + } + + std::function on_delete_callback; + if (dlmt->deleter) { + on_delete_callback = [dlmt]() { dlmt->deleter(dlmt); }; + } + + TF_ASSIGN_OR_RETURN( + auto pjrt_buffer, + MakePjrtBuffer(*device, dlmt, shape, element_type, dimensions)); + + // We have taken ownership of the array inside the capsule; make sure the + // capsule it cannot be used again. + PyCapsule_SetName(tensor.ptr(), "used_dltensor"); + PyCapsule_SetDestructor(tensor.ptr(), nullptr); + // TODO(phawkins): simplify the expression below once we know cpu_client is + // always non-null. + auto client = (cpu_client && device->client() == cpu_pjrt_client) + ? std::move(*cpu_client) + : std::move(*gpu_client); + auto* ifrt_client = + llvm::dyn_cast_or_null(client->ifrt_client()); + if (ifrt_client == nullptr) { + throw XlaRuntimeError( + "This operation is implemented for a PjRt-compatible backend only."); + } + TF_ASSIGN_OR_RETURN(auto ifrt_array, + ifrt_client->CreatePjRtArray(std::move(pjrt_buffer))); + return PyArray::MakeFromSingleDeviceArray(std::move(client), Traceback::Get(), + std::move(ifrt_array), false, true); +} + +absl::StatusOr DLPackManagedTensorToBuffer( + const nb::capsule& tensor, ifrt::Device* ifrt_device, + nb_class_ptr client, std::optional stream) { + ifrt::PjRtDevice* device = + llvm::dyn_cast_or_null(ifrt_device); + if (device == nullptr) { + throw XlaRuntimeError( + "DLPack is supported for PjRt-compatible backends only."); + } + if (!device->IsAddressable()) { + throw XlaRuntimeError( + "DLPack is only supported for devices addressable by the current " + "process."); + } + if (absl::string_view(tensor.name()) != kDlTensorCapsuleName) { + return InvalidArgument( + "DLPack tensor must be a capsule with name \"dltensor\", got \"%s\". 
" + "Note that a DLPack tensor may be consumed at most once.", + absl::string_view(tensor.name())); + } + DLManagedTensor* dlmt = static_cast(tensor.data()); + if (dlmt->dl_tensor.ndim < 0) { + return InvalidArgument( + "Number of dimensions in DLManagedTensor must be nonnegative, got %d", + dlmt->dl_tensor.ndim); + } + absl::Span dimensions( + reinterpret_cast(dlmt->dl_tensor.shape), dlmt->dl_tensor.ndim); + TF_ASSIGN_OR_RETURN(PrimitiveType element_type, + DLDataTypeToPrimitiveType(dlmt->dl_tensor.dtype)); + + std::vector minor_to_major; + if (dlmt->dl_tensor.strides && + absl::c_find(dimensions, 0) == dimensions.end()) { + absl::Span strides( + reinterpret_cast(dlmt->dl_tensor.strides), + dlmt->dl_tensor.ndim); + TF_ASSIGN_OR_RETURN(minor_to_major, StridesToLayout(dimensions, strides)); + } else { + minor_to_major.resize(dlmt->dl_tensor.ndim); + std::iota(minor_to_major.rbegin(), minor_to_major.rend(), 0); + } + Shape shape = ShapeUtil::MakeShapeWithDenseLayout(element_type, dimensions, + minor_to_major); + + TF_ASSIGN_OR_RETURN(auto pjrt_buffer, + MakePjrtBuffer(*device->pjrt_device(), dlmt, shape, + element_type, dimensions, stream)); + + // We have taken ownership of the array inside the capsule; make sure the + // capsule it cannot be used again. + PyCapsule_SetName(tensor.ptr(), "used_dltensor"); + PyCapsule_SetDestructor(tensor.ptr(), nullptr); + + auto* ifrt_client = + llvm::dyn_cast_or_null(client->ifrt_client()); + if (ifrt_client == nullptr) { + throw XlaRuntimeError( + "This operation is implemented for a PjRt-compatible backend only."); + } + TF_ASSIGN_OR_RETURN(auto ifrt_array, + ifrt_client->CreatePjRtArray(std::move(pjrt_buffer))); + return PyArray::MakeFromSingleDeviceArray(std::move(client), Traceback::Get(), + std::move(ifrt_array), false, true); +} + +absl::StatusOr PrimitiveTypeToNbDLDataType( + PrimitiveType type) { + TF_ASSIGN_OR_RETURN(DLDataType dl_type, PrimitiveTypeToDLDataType(type)); + + nanobind::dlpack::dtype nb_type; + nb_type.lanes = dl_type.lanes; + nb_type.bits = dl_type.bits; + nb_type.code = dl_type.code; + + return nb_type; +} + +} // namespace xla diff --git a/jaxlib/dlpack.h b/jaxlib/dlpack.h new file mode 100644 index 000000000000..54feb2b45dba --- /dev/null +++ b/jaxlib/dlpack.h @@ -0,0 +1,58 @@ +/* Copyright 2020 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_DLPACK_H_ +#define JAXLIB_DLPACK_H_ + +#include +#include + +#include "absl/status/statusor.h" +#include "nanobind/nanobind.h" +#include "nanobind/ndarray.h" +#include "jaxlib/nb_class_ptr.h" +#include "jaxlib/py_client.h" +#include "xla/python/ifrt/device.h" +#include "xla/xla_data.pb.h" + +namespace xla { + +// If take_ownership is true, ownership of the buffer is handed to DLPack, and +// the receiver may mutate the buffer as they see fit. Otherwise PjRt retains +// ownership of the buffer and it should be immutable. +// +// stream, if set, is a GPU stream, e.g. 
cudaStream_t for CUDA GPUs, that should +// be synchronized to the buffer as per +// https://dmlc.github.io/dlpack/latest/python_spec.html#python-specification-for-dlpack. +absl::StatusOr BufferToDLPackManagedTensor( + nanobind::handle buffer, std::optional stream); + +absl::StatusOr DLPackManagedTensorToBuffer( + const nanobind::capsule& tensor, + std::optional> cpu_client, + std::optional> gpu_client); + +absl::StatusOr DLPackManagedTensorToBuffer( + const nanobind::capsule& tensor, ifrt::Device* device, + nb_class_ptr client, std::optional stream); + +// Converts a PrimitiveType to the nanobind specific implementation of +// DLDataType. +absl::StatusOr PrimitiveTypeToNbDLDataType( + PrimitiveType type); + +} // namespace xla + +#endif // JAXLIB_DLPACK_H_ diff --git a/jaxlib/dlpack_support.cc b/jaxlib/dlpack_support.cc new file mode 100644 index 000000000000..9e851842ed14 --- /dev/null +++ b/jaxlib/dlpack_support.cc @@ -0,0 +1,223 @@ +/* Copyright 2025 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/dlpack_support.h" + +#include "absl/status/statusor.h" +#include "include/dlpack/dlpack.h" +#include "xla/util.h" +#include "xla/xla_data.pb.h" + +namespace xla { + +absl::StatusOr PrimitiveTypeToDLDataType(PrimitiveType type) { + switch (type) { + case S8: + return DLDataType{kDLInt, 8, 1}; + case S16: + return DLDataType{kDLInt, 16, 1}; + case S32: + return DLDataType{kDLInt, 32, 1}; + case S64: + return DLDataType{kDLInt, 64, 1}; + case U8: + return DLDataType{kDLUInt, 8, 1}; + case U16: + return DLDataType{kDLUInt, 16, 1}; + case U32: + return DLDataType{kDLUInt, 32, 1}; + case U64: + return DLDataType{kDLUInt, 64, 1}; + case F4E2M1FN: + return DLDataType{kDLFloat4_e2m1fn, 4, 1}; + case F8E3M4: + return DLDataType{kDLFloat8_e3m4, 8, 1}; + case F8E4M3: + return DLDataType{kDLFloat8_e4m3, 8, 1}; + case F8E4M3B11FNUZ: + return DLDataType{kDLFloat8_e4m3b11fnuz, 8, 1}; + case F8E4M3FN: + return DLDataType{kDLFloat8_e4m3fn, 8, 1}; + case F8E4M3FNUZ: + return DLDataType{kDLFloat8_e4m3fnuz, 8, 1}; + case F8E5M2: + return DLDataType{kDLFloat8_e5m2, 8, 1}; + case F8E5M2FNUZ: + return DLDataType{kDLFloat8_e5m2fnuz, 8, 1}; + case F8E8M0FNU: + return DLDataType{kDLFloat8_e8m0fnu, 8, 1}; + case BF16: + return DLDataType{kDLBfloat, 16, 1}; + case F16: + return DLDataType{kDLFloat, 16, 1}; + case F32: + return DLDataType{kDLFloat, 32, 1}; + case F64: + return DLDataType{kDLFloat, 64, 1}; + case PRED: + return DLDataType{kDLBool, 8, 1}; + case C64: + return DLDataType{kDLComplex, 64, 1}; + case C128: + return DLDataType{kDLComplex, 128, 1}; + default: + return Unimplemented("XLA type %s has no DLPack equivalent", + PrimitiveType_Name(type)); + } +} + +absl::StatusOr DLDataTypeToPrimitiveType(DLDataType type) { + if (type.lanes != 1) { + return Unimplemented("DLPack types with lanes != 1 not implemented, got %d", + type.lanes); + } + switch (type.code) { + case kDLBool: + switch (type.bits) { + case 8: + return 
PRED; + default: + return Unimplemented( + "Only 8-bit DLPack booleans are supported, got %d bits", + type.bits); + } + case kDLInt: + switch (type.bits) { + case 8: + return S8; + case 16: + return S16; + case 32: + return S32; + case 64: + return S64; + default: + return Unimplemented( + "Invalid or unsupported DLPack integer width: %d bits", + type.bits); + } + case kDLUInt: + switch (type.bits) { + case 8: + return U8; + case 16: + return U16; + case 32: + return U32; + case 64: + return U64; + default: + return Unimplemented( + "Invalid or unsupported DLPack unsigned integer width: %d bits", + type.bits); + } + case kDLFloat4_e2m1fn: + if (type.bits == 4) { + return F4E2M1FN; + } + return Unimplemented( + "Invalid or unsupported DLPack float4_e2m1fn width: %d bits", + type.bits); + case kDLFloat8_e3m4: + if (type.bits == 8) { + return F8E3M4; + } + return Unimplemented( + "Invalid or unsupported DLPack float8_e3m4 width: %d bits", + type.bits); + case kDLFloat8_e4m3: + if (type.bits == 8) { + return F8E4M3; + } + return Unimplemented( + "Invalid or unsupported DLPack float8_e4m3 width: %d bits", + type.bits); + case kDLFloat8_e4m3b11fnuz: + if (type.bits == 8) { + return F8E4M3B11FNUZ; + } + return Unimplemented( + "Invalid or unsupported DLPack float8_e4m3b11fnuz width: %d bits", + type.bits); + case kDLFloat8_e4m3fn: + if (type.bits == 8) { + return F8E4M3FN; + } + return Unimplemented( + "Invalid or unsupported DLPack float8_e4m3fn width: %d bits", + type.bits); + case kDLFloat8_e4m3fnuz: + if (type.bits == 8) { + return F8E4M3FNUZ; + } + return Unimplemented( + "Invalid or unsupported DLPack float8_e4m3fnuz width: %d bits", + type.bits); + case kDLFloat8_e5m2: + if (type.bits == 8) { + return F8E5M2; + } + return Unimplemented( + "Invalid or unsupported DLPack float8_e5m2 width: %d bits", + type.bits); + case kDLFloat8_e5m2fnuz: + if (type.bits == 8) { + return F8E5M2FNUZ; + } + return Unimplemented( + "Invalid or unsupported DLPack float8_e5m2fnuz width: %d bits", + type.bits); + case kDLFloat8_e8m0fnu: + if (type.bits == 8) { + return F8E8M0FNU; + } + return Unimplemented( + "Invalid or unsupported DLPack float8_e8m0fnu width: %d bits", + type.bits); + case kDLBfloat: + if (type.bits == 16) { + return BF16; + } + return Unimplemented( + "Invalid or unsupported DLPack bfloat width: %d bits", type.bits); + case kDLFloat: + switch (type.bits) { + case 16: + return F16; + case 32: + return F32; + case 64: + return F64; + default: + return Unimplemented( + "Invalid or unsupported DLPack float width: %d bits", type.bits); + } + case kDLComplex: + switch (type.bits) { + case 64: + return C64; + case 128: + return C128; + default: + return Unimplemented( + "Invalid or unsupported DLPack complex width: %d bits", + type.bits); + } + default: + return Unimplemented("Unknown or invalid DLPack type code %d", type.code); + } +} + +} // namespace xla diff --git a/jaxlib/dlpack_support.h b/jaxlib/dlpack_support.h new file mode 100644 index 000000000000..25e862353bab --- /dev/null +++ b/jaxlib/dlpack_support.h @@ -0,0 +1,30 @@ +/* Copyright 2025 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_XLA_DLPACK_SUPPORT_H_ +#define JAXLIB_XLA_DLPACK_SUPPORT_H_ + +#include "absl/status/statusor.h" +#include "include/dlpack/dlpack.h" +#include "xla/xla_data.pb.h" + +namespace xla { + +absl::StatusOr PrimitiveTypeToDLDataType(PrimitiveType type); +absl::StatusOr DLDataTypeToPrimitiveType(DLDataType type); + +} // namespace xla + +#endif // JAXLIB_XLA_DLPACK_SUPPORT_H_ diff --git a/jaxlib/ffi.cc b/jaxlib/ffi.cc new file mode 100644 index 000000000000..790a9876dd10 --- /dev/null +++ b/jaxlib/ffi.cc @@ -0,0 +1,373 @@ +/* Copyright 2025 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/ffi.h" + +#include + +#include +#include +#include +#include +#include + +#include "absl/base/casts.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/types/span.h" +#include "include/dlpack/dlpack.h" +#include "nanobind/nanobind.h" +#include "jaxlib/dlpack_support.h" +#include "xla/ffi/api/c_api.h" +#include "xla/ffi/api/ffi.h" +#include "xla/ffi/ffi_api.h" +#include "xla/pjrt/status_casters.h" +#include "xla/python/nb_numpy.h" +#include "xla/python/types.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/xla_data.pb.h" + +namespace jax { + +namespace ffi = xla::ffi; +namespace nb = nanobind; + +namespace { +const char* const kDlTensorCapsuleName = "dltensor"; +const char* const kDlTensorVersionedCapsuleName = "dltensor_versioned"; + +template +struct DLPackTensor { + std::vector shape; + ManagedTensor tensor; +}; + +template +void DLPackTensorDeleter(ManagedTensor* t) { + if (t) { + delete static_cast*>(t->manager_ctx); + } +} + +xla::PrimitiveType PrimitiveTypeForFfiDataType(ffi::DataType dtype) { + switch (dtype) { + case ffi::DataType::INVALID: + return xla::PrimitiveType::PRIMITIVE_TYPE_INVALID; + case ffi::PRED: + return xla::PrimitiveType::PRED; + case ffi::S1: + return xla::PrimitiveType::S1; + case ffi::S2: + return xla::PrimitiveType::S2; + case ffi::S4: + return xla::PrimitiveType::S4; + case ffi::S8: + return xla::PrimitiveType::S8; + case ffi::S16: + return xla::PrimitiveType::S16; + case ffi::S32: + return xla::PrimitiveType::S32; + case ffi::S64: + return xla::PrimitiveType::S64; + case ffi::U1: + return xla::PrimitiveType::U1; + case ffi::U2: + return xla::PrimitiveType::U2; + case ffi::U4: + return xla::PrimitiveType::U4; + case ffi::U8: + return xla::PrimitiveType::U8; + case ffi::U16: + return xla::PrimitiveType::U16; + case ffi::U32: + return xla::PrimitiveType::U32; + case ffi::U64: + return xla::PrimitiveType::U64; + case ffi::F16: + return xla::PrimitiveType::F16; + case ffi::F32: + return xla::PrimitiveType::F32; + case ffi::F64: + return xla::PrimitiveType::F64; + case ffi::BF16: + return xla::PrimitiveType::BF16; + case ffi::C64: + return 
xla::PrimitiveType::C64; + case ffi::C128: + return xla::PrimitiveType::C128; + case ffi::TOKEN: + return xla::PrimitiveType::TOKEN; + case ffi::F8E5M2: + return xla::PrimitiveType::F8E5M2; + case ffi::F8E4M3: + return xla::PrimitiveType::F8E4M3; + case ffi::F8E4M3FN: + return xla::PrimitiveType::F8E4M3FN; + case ffi::F8E4M3B11FNUZ: + return xla::PrimitiveType::F8E4M3B11FNUZ; + case ffi::F8E5M2FNUZ: + return xla::PrimitiveType::F8E5M2FNUZ; + case ffi::F8E4M3FNUZ: + return xla::PrimitiveType::F8E4M3FNUZ; + case ffi::F8E3M4: + return xla::PrimitiveType::F8E3M4; + case ffi::F4E2M1FN: + return xla::PrimitiveType::F4E2M1FN; + case ffi::F8E8M0FNU: + return xla::PrimitiveType::F8E8M0FNU; + } +} +} // namespace + +PyFfiContext::PyFfiContext(const XLA_FFI_Api* api, + XLA_FFI_ExecutionContext* ctx, + XLA_FFI_ExecutionStage stage) + : api_(api), ctx_(ctx), stage_(stage) {} + +PyFfiContext::Stage PyFfiContext::stage() const { + return static_cast(stage_); +} + +absl::StatusOr PyFfiContext::stream() const { + XLA_FFI_Stream_Get_Args args; + args.struct_size = XLA_FFI_Stream_Get_Args_STRUCT_SIZE; + args.extension_start = nullptr; + args.ctx = ctx_; + args.stream = nullptr; + if (XLA_FFI_Error* error = api_->XLA_FFI_Stream_Get(&args)) { + return ffi::TakeStatus(error); + } + return absl::bit_cast(args.stream); +} + +PyFfiAnyBuffer::PyFfiAnyBuffer(DLDeviceType device_type, int32_t device_ordinal, + void* data, ffi::Span dimensions, + ffi::DataType element_type, bool writeable) + : device_type_(device_type), + device_ordinal_(device_ordinal), + data_(data), + dimensions_(dimensions.begin(), dimensions.size()), + element_type_(PrimitiveTypeForFfiDataType(element_type)), + writeable_(writeable) {} + +PyFfiAnyBuffer::PyFfiAnyBuffer(DLDeviceType device_type, int32_t device_ordinal, + ffi::AnyBuffer buf) + : PyFfiAnyBuffer(device_type, device_ordinal, buf.untyped_data(), + buf.dimensions(), buf.element_type(), + /*writeable=*/false) {} + +PyFfiAnyBuffer::PyFfiAnyBuffer(DLDeviceType device_type, int32_t device_ordinal, + ffi::Result buf) + : PyFfiAnyBuffer(device_type, device_ordinal, buf->untyped_data(), + buf->dimensions(), buf->element_type(), + /*writeable=*/true) {} + +absl::StatusOr PyFfiAnyBuffer::dtype() const { + return xla::PrimitiveTypeToNbDtype(element_type_); +} + +size_t PyFfiAnyBuffer::ndim() const { return dimensions_.size(); } + +nb::tuple PyFfiAnyBuffer::shape() const { + return xla::SpanToNbTuple(dimensions_); +} + +bool PyFfiAnyBuffer::writeable() const { return writeable_; } + +absl::StatusOr PyFfiAnyBuffer::NumpyArray() const { + if (device_type_ != kDLCPU) { + return absl::UnimplementedError( + "Buffer.__array__ is only supported on CPU."); + } + + TF_ASSIGN_OR_RETURN(auto dtype, this->dtype()); + xla::nb_numpy_ndarray array(dtype, dimensions_, /* strides= */ std::nullopt, + data_, nb::cast(this)); + + // TODO(danfm): We don't seem to be allowed to set this flag like this + // because the array doesn't own its data. 
+ // array.attr("flags").attr("writeable") = nb::bool_(writeable_); + + return array; +} + +absl::StatusOr PyFfiAnyBuffer::CudaArrayInterface() const { + if (device_type_ != kDLCUDA) { + return absl::UnimplementedError( + "Buffer.__cuda_array_interface__ is only supported on CUDA."); + } + + nb::dict result; + result["shape"] = xla::SpanToNbTuple(dimensions_); + TF_ASSIGN_OR_RETURN(result["typestr"], + TypeDescriptorForPrimitiveType(element_type_)); + result["data"] = nb::make_tuple( + nb::int_(absl::bit_cast(data_)), !writeable_); + result["version"] = nb::int_(2); + return result; +} + +absl::StatusOr PyFfiAnyBuffer::DLPack() const { + auto pack = std::make_unique>(); + pack->tensor.manager_ctx = pack.get(); + pack->tensor.deleter = DLPackTensorDeleter; + + DLTensor& dt = pack->tensor.dl_tensor; + dt.data = data_; + dt.device = DLDevice{device_type_, device_ordinal_}; + dt.ndim = dimensions_.size(); + TF_ASSIGN_OR_RETURN(dt.dtype, xla::PrimitiveTypeToDLDataType(element_type_)); + pack->shape = std::vector(dimensions_.begin(), dimensions_.end()); + dt.shape = reinterpret_cast(pack->shape.data()); + dt.strides = nullptr; + dt.byte_offset = 0; + + // We cannot use nanobind's capsule object constructor because we need to + // detect if the capsule name has been changed in the deleter, but nanobind + // hides the underlying Python object from the deleter. + nb::capsule capsule = nb::steal( + PyCapsule_New(&pack.release()->tensor, kDlTensorCapsuleName, + [](PyObject* obj) noexcept { + DLManagedTensor* dlmt = static_cast( + PyCapsule_GetPointer(obj, kDlTensorCapsuleName)); + if (dlmt) { + DLPackTensorDeleter(dlmt); + } else { + // The tensor has been deleted. Clear any error from + // PyCapsule_GetPointer. + PyErr_Clear(); + } + })); + if (!capsule.ptr()) { + throw nb::python_error(); + } + + return capsule; +} + +absl::StatusOr PyFfiAnyBuffer::DLPackVersioned() const { + auto pack = std::make_unique>(); + pack->tensor.version = + DLPackVersion{DLPACK_MAJOR_VERSION, DLPACK_MINOR_VERSION}; + pack->tensor.manager_ctx = pack.get(); + pack->tensor.deleter = DLPackTensorDeleter; + pack->tensor.flags = writeable_ ? 0 : DLPACK_FLAG_BITMASK_READ_ONLY; + + DLTensor& dt = pack->tensor.dl_tensor; + dt.data = data_; + dt.device = DLDevice{device_type_, device_ordinal_}; + dt.ndim = dimensions_.size(); + TF_ASSIGN_OR_RETURN(dt.dtype, xla::PrimitiveTypeToDLDataType(element_type_)); + pack->shape = std::vector(dimensions_.begin(), dimensions_.end()); + dt.shape = reinterpret_cast(pack->shape.data()); + dt.strides = nullptr; + dt.byte_offset = 0; + + // We cannot use nanobind's capsule object constructor because we need to + // detect if the capsule name has been changed in the deleter, but nanobind + // hides the underlying Python object from the deleter. + nb::capsule capsule = nb::steal(PyCapsule_New( + &pack.release()->tensor, kDlTensorVersionedCapsuleName, + [](PyObject* obj) noexcept { + DLManagedTensorVersioned* dlmt = static_cast( + PyCapsule_GetPointer(obj, kDlTensorVersionedCapsuleName)); + if (dlmt) { + DLPackTensorDeleter(dlmt); + } else { + // The tensor has been deleted. Clear any error from + // PyCapsule_GetPointer. 
+ PyErr_Clear(); + } + })); + if (!capsule.ptr()) { + throw nb::python_error(); + } + + return capsule; +} + +nb::tuple PyFfiAnyBuffer::DLPackDevice() const { + return nb::make_tuple(static_cast(device_type_), device_ordinal_); +} + +void BuildFfiSubmodule(nb::module_& m) { + tsl::ImportNumpy(); + + nb::module_ ffi_module = + m.def_submodule("ffi", "Python bindings for the XLA FFI."); + + nb::class_ buffer(ffi_module, "Buffer"); + buffer.def_prop_ro("dtype", xla::ValueOrThrowWrapper(&PyFfiAnyBuffer::dtype)); + buffer.def_prop_ro("ndim", &PyFfiAnyBuffer::ndim); + buffer.def_prop_ro("shape", &PyFfiAnyBuffer::shape); + buffer.def_prop_ro("writeable", &PyFfiAnyBuffer::writeable); + buffer.def( + "__array__", + [](PyFfiAnyBuffer self, nb::object dtype, nb::object copy) { + if (!dtype.is_none()) { + throw nb::value_error( + "dtype parameter is not supported by Buffer.__array__."); + } + if (!copy.is_none() && nb::cast(copy)) { + throw nb::value_error( + "Buffer.__array__ with copy=True is not supported."); + } + return xla::ValueOrThrow(self.NumpyArray()); + }, + nb::arg("dtype") = nb::none(), nb::arg("copy") = nb::none()); + buffer.def_prop_ro( + "__cuda_array_interface__", + xla::ValueOrThrowWrapper(&PyFfiAnyBuffer::CudaArrayInterface)); + buffer.def( + "__dlpack__", + [](PyFfiAnyBuffer self, nb::object stream, nb::object max_version, + nb::object dl_device, nb::object copy) { + if (!copy.is_none() && nb::cast(copy)) { + throw nb::value_error( + "Buffer.__dlpack__ with copy=True is not supported."); + } + + // Fall back on the non-versioned API if unsupported by the requested + // max_version. + nb::tuple max_version_tuple; + int64_t max_version_major; + if (!nb::try_cast(max_version, max_version_tuple) || + max_version_tuple.size() < 2 || + !nb::try_cast(max_version_tuple[0], max_version_major) || + max_version_major < 1) { + return xla::ValueOrThrow(self.DLPack()); + } + + // TODO(danfm): Handle other optional inputs. + return xla::ValueOrThrow(self.DLPackVersioned()); + }, + nb::arg("stream") = nb::none(), nb::arg("max_version") = nb::none(), + nb::arg("dl_device") = nb::none(), nb::arg("copy") = nb::none()); + buffer.def("__dlpack_device__", &PyFfiAnyBuffer::DLPackDevice); + + nb::enum_(ffi_module, "ExecutionStage") + .value("INSTANTIATE", PyFfiContext::Stage::kInstantiate) + .value("PREPARE", PyFfiContext::Stage::kPrepare) + .value("INITIALIZE", PyFfiContext::Stage::kInitialize) + .value("EXECUTE", PyFfiContext::Stage::kExecute) + .export_values(); + + nb::class_ context(ffi_module, "ExecutionContext"); + context.def_prop_ro("stage", &PyFfiContext::stage); + context.def_prop_ro("stream", + xla::ValueOrThrowWrapper(&PyFfiContext::stream)); +} + +} // namespace jax diff --git a/jaxlib/ffi.h b/jaxlib/ffi.h new file mode 100644 index 000000000000..e2045a0f513c --- /dev/null +++ b/jaxlib/ffi.h @@ -0,0 +1,152 @@ +/* Copyright 2025 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef JAXLIB_XLA_FFI_H_ +#define JAXLIB_XLA_FFI_H_ + +#include + +#include +#include + +#include "absl/status/statusor.h" +#include "absl/strings/str_format.h" +#include "absl/types/span.h" +#include "include/dlpack/dlpack.h" +#include "nanobind/nanobind.h" +#include "xla/ffi/api/c_api.h" +#include "xla/ffi/api/ffi.h" +#include "xla/pjrt/host_callback.h" +#include "xla/python/nb_numpy.h" +#include "xla/xla_data.pb.h" + +namespace jax { + +namespace ffi = xla::ffi; +namespace nb = nanobind; + +// Wrapper class for XLA FFI execution context. +// +// This class provides a Python interface to the XLA FFI execution context, +// exposing metadata such as the execution stage, device ordinal, and stream. +class PyFfiContext { + public: + enum class Stage { + kInstantiate, + kPrepare, + kInitialize, + kExecute, + }; + + PyFfiContext(const XLA_FFI_Api* api, XLA_FFI_ExecutionContext* ctx, + XLA_FFI_ExecutionStage stage); + Stage stage() const; + absl::StatusOr stream() const; + + private: + const XLA_FFI_Api* api_; + XLA_FFI_ExecutionContext* ctx_; + XLA_FFI_ExecutionStage stage_; +}; + +// Wrapper class for XLA FFI AnyBuffer. +// +// This class provides a Python interface to the XLA FFI `AnyBuffer` class. +// From Python, this object looks like an array (with `.dtype` and `.shape` +// attributes), but it also provides methods zero-copy conversions to standard +// transport formats: `__array__`, `__cuda_array_interface__`, and `__dlpack__`. +class PyFfiAnyBuffer { + public: + PyFfiAnyBuffer(DLDeviceType device_type, int32_t device_ordinal, void* data, + ffi::Span dimensions, + ffi::DataType element_type, bool writeable); + PyFfiAnyBuffer(DLDeviceType device_type, int32_t device_ordinal, + ffi::AnyBuffer buf); + PyFfiAnyBuffer(DLDeviceType device_type, int32_t device_ordinal, + ffi::Result buf); + + absl::StatusOr dtype() const; + size_t ndim() const; + nb::tuple shape() const; + bool writeable() const; + + absl::StatusOr NumpyArray() const; + absl::StatusOr CudaArrayInterface() const; + absl::StatusOr DLPack() const; + absl::StatusOr DLPackVersioned() const; + nb::tuple DLPackDevice() const; + + private: + DLDeviceType device_type_; + int32_t device_ordinal_; + void* data_; + absl::Span dimensions_; + xla::PrimitiveType element_type_; + bool writeable_; +}; + +template +ffi::Error XlaBufferCallback(int32_t device_ordinal, const XLA_FFI_Api* api, + XLA_FFI_ExecutionContext* ctx, + xla::FfiLoadedHostCallbacks* callbacks, + uint64_t index, ffi::RemainingArgs args, + ffi::RemainingRets rets) { + nb::gil_scoped_acquire gil; + auto callback = nb::borrow( + static_cast(callbacks->callbacks[index])); + auto nb_args = + nb::steal(PyTuple_New(1 + args.size() + rets.size())); + + jax::PyFfiContext py_ctx(api, ctx, XLA_FFI_ExecutionStage_EXECUTE); + PyTuple_SET_ITEM(nb_args.ptr(), 0, nb::cast(py_ctx).release().ptr()); + + size_t offset = 1; + for (size_t i = 0; i < args.size(); ++i, ++offset) { + auto arg = args.get(i); + if (arg.has_error()) { + return arg.error(); + } + jax::PyFfiAnyBuffer py_buffer(DeviceType, device_ordinal, arg.value()); + PyTuple_SET_ITEM(nb_args.ptr(), offset, + nb::cast(py_buffer).release().ptr()); + } + + for (size_t i = 0; i < rets.size(); ++i, ++offset) { + auto ret = rets.get(i); + if (ret.has_error()) { + return ret.error(); + } + jax::PyFfiAnyBuffer py_buffer(DeviceType, device_ordinal, ret.value()); + PyTuple_SET_ITEM(nb_args.ptr(), offset, + nb::cast(py_buffer).release().ptr()); + } + + 
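+  // At this point nb_args holds (ExecutionContext, input buffers...,
+  // result buffers...). The Python callable is expected to write its outputs
+  // in place through the writeable result buffers; its return value is
+  // ignored, and any Python exception is converted into an ffi::Error below.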
xla::EnterHostCallback(); + try { + callback(*nb::borrow(nb_args)); + } catch (nb::python_error& e) { + return ffi::Error::Internal( + absl::StrFormat("Error when calling buffer callback: %s", e.what())); + } + xla::LeaveHostCallback(); + + return ffi::Error::Success(); +} + +void BuildFfiSubmodule(nanobind::module_& m); + +} // namespace jax + +#endif // JAXLIB_XLA_FFI_H_ diff --git a/jaxlib/ffi_helpers.h b/jaxlib/ffi_helpers.h index 5c6d80093df5..7c4dfce81311 100644 --- a/jaxlib/ffi_helpers.h +++ b/jaxlib/ffi_helpers.h @@ -1,3 +1,18 @@ +/* Copyright 2024 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + #ifndef JAXLIB_FFI_HELPERS_H_ #define JAXLIB_FFI_HELPERS_H_ @@ -74,7 +89,7 @@ namespace jax { template inline absl::StatusOr MaybeCastNoOverflow( - std::int64_t value, const std::string& source = __FILE__) { + std::int64_t value, std::string_view source = __FILE__) { if constexpr (sizeof(T) == sizeof(std::int64_t)) { return value; } else { diff --git a/jaxlib/gpu/BUILD b/jaxlib/gpu/BUILD index b5292746dd10..98f0f6cfe624 100644 --- a/jaxlib/gpu/BUILD +++ b/jaxlib/gpu/BUILD @@ -30,11 +30,8 @@ package( ) exports_files(srcs = [ - "blas.cc", "blas_handle_pool.cc", "blas_handle_pool.h", - "blas_kernels.cc", - "blas_kernels.h", "ffi_wrapper.h", "gpu_kernel_helpers.cc", "gpu_kernel_helpers.h", @@ -52,6 +49,8 @@ exports_files(srcs = [ "prng_kernels.cc", "prng_kernels.cu.cc", "prng_kernels.h", + "py_client_gpu.cc", + "py_client_gpu.h", "rnn.cc", "rnn_kernels.cc", "rnn_kernels.h", @@ -60,8 +59,6 @@ exports_files(srcs = [ "solver_handle_pool.h", "solver_interface.cc", "solver_interface.h", - "solver_kernels.cc", - "solver_kernels.h", "solver_kernels_ffi.cc", "solver_kernels_ffi.h", "sparse.cc", @@ -82,6 +79,7 @@ proto_library( cc_proto_library( name = "triton_cc_proto", + compatible_with = None, deps = [":triton_proto"], ) @@ -91,6 +89,21 @@ xla_py_proto_library( deps = [":triton_proto"], ) +cc_library( + name = "handle_pool", + hdrs = ["handle_pool.h"], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + deps = [ + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/synchronization", + ], +) + cc_library( name = "gpu_plugin_extension", srcs = ["gpu_plugin_extension.cc"], @@ -115,7 +128,7 @@ cc_library( "@xla//xla/pjrt/c:pjrt_c_api_hdrs", "@xla//xla/pjrt/c:pjrt_c_api_helpers", "@xla//xla/pjrt/c:pjrt_c_api_triton_extension_hdrs", - "@xla//xla/python:py_client_gpu", + "@xla//xla/tsl/platform:statusor", "@xla//xla/tsl/python/lib/core:numpy", ], ) diff --git a/jaxlib/gpu/blas.cc b/jaxlib/gpu/blas.cc deleted file mode 100644 index e8761bd32ac9..000000000000 --- a/jaxlib/gpu/blas.cc +++ /dev/null @@ -1,85 +0,0 @@ -/* Copyright 2019 The JAX Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include -#include - -#include "nanobind/nanobind.h" -#include "nanobind/stl/pair.h" -#include "absl/container/flat_hash_map.h" -#include "absl/strings/str_format.h" -#include "jaxlib/gpu/blas_kernels.h" -#include "jaxlib/gpu/vendor.h" -#include "jaxlib/kernel_nanobind_helpers.h" -#include "xla/tsl/python/lib/core/numpy.h" - -namespace jax { -namespace JAX_GPU_NAMESPACE { -namespace { - -namespace nb = nanobind; - -// Converts a NumPy dtype to a Type. -BlasType DtypeToBlasType(const dtype& np_type) { - static auto* types = new absl::flat_hash_map, BlasType>({ - {{'f', 4}, BlasType::F32}, - {{'f', 8}, BlasType::F64}, - {{'c', 8}, BlasType::C64}, - {{'c', 16}, BlasType::C128}, - }); - auto it = types->find({np_type.kind(), np_type.itemsize()}); - if (it == types->end()) { - nb::str repr = nb::repr(np_type); - throw std::invalid_argument( - absl::StrFormat("Unsupported dtype %s", repr.c_str())); - } - return it->second; -} - -// Returns the descriptor for a GetrfBatched operation. -std::pair BuildGetrfBatchedDescriptor(const dtype& dtype, - int b, int n) { - BlasType type = DtypeToBlasType(dtype); - size_t size = b * sizeof(void*); - return {size, PackDescriptor(GetrfBatchedDescriptor{type, b, n})}; -} - -// Returns the descriptor for a GetrfBatched operation. -std::pair BuildGeqrfBatchedDescriptor(const dtype& dtype, - int b, int m, int n) { - BlasType type = DtypeToBlasType(dtype); - size_t size = b * sizeof(void*); - return {size, PackDescriptor(GeqrfBatchedDescriptor{type, b, m, n})}; -} - -nb::dict Registrations() { - nb::dict dict; - dict[JAX_GPU_PREFIX "blas_getrf_batched"] = EncapsulateFunction(GetrfBatched); - dict[JAX_GPU_PREFIX "blas_geqrf_batched"] = EncapsulateFunction(GeqrfBatched); - return dict; -} - -NB_MODULE(_blas, m) { - tsl::ImportNumpy(); - - m.def("registrations", &Registrations); - m.def("build_getrf_batched_descriptor", &BuildGetrfBatchedDescriptor); - m.def("build_geqrf_batched_descriptor", &BuildGeqrfBatchedDescriptor); -} - -} // namespace -} // namespace JAX_GPU_NAMESPACE -} // namespace jax diff --git a/jaxlib/gpu/blas_handle_pool.cc b/jaxlib/gpu/blas_handle_pool.cc index 2ce204453039..ff381b802ab2 100644 --- a/jaxlib/gpu/blas_handle_pool.cc +++ b/jaxlib/gpu/blas_handle_pool.cc @@ -19,7 +19,7 @@ limitations under the License. #include "absl/synchronization/mutex.h" #include "jaxlib/gpu/gpu_kernel_helpers.h" #include "jaxlib/gpu/vendor.h" -#include "jaxlib/handle_pool.h" +#include "jaxlib/gpu/handle_pool.h" namespace jax { diff --git a/jaxlib/gpu/blas_handle_pool.h b/jaxlib/gpu/blas_handle_pool.h index b3cdbaa88867..43724baab45e 100644 --- a/jaxlib/gpu/blas_handle_pool.h +++ b/jaxlib/gpu/blas_handle_pool.h @@ -18,7 +18,7 @@ limitations under the License. 
#include "absl/status/statusor.h" #include "jaxlib/gpu/vendor.h" -#include "jaxlib/handle_pool.h" +#include "jaxlib/gpu/handle_pool.h" namespace jax { diff --git a/jaxlib/gpu/blas_kernels.cc b/jaxlib/gpu/blas_kernels.cc deleted file mode 100644 index ac30aa9cc520..000000000000 --- a/jaxlib/gpu/blas_kernels.cc +++ /dev/null @@ -1,198 +0,0 @@ -/* Copyright 2019 The JAX Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "jaxlib/gpu/blas_kernels.h" - -#include -#include -#include -#include -#include - -#include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "absl/strings/str_format.h" -#include "jaxlib/gpu/blas_handle_pool.h" -#include "jaxlib/gpu/gpu_kernel_helpers.h" -#include "jaxlib/gpu/make_batch_pointers.h" -#include "jaxlib/gpu/vendor.h" -#include "jaxlib/kernel_helpers.h" -#include "xla/service/custom_call_status.h" - -namespace jax { - -namespace JAX_GPU_NAMESPACE { - -namespace { - -int SizeOfBlasType(BlasType type) { - switch (type) { - case BlasType::F32: - return sizeof(float); - case BlasType::F64: - return sizeof(double); - case BlasType::C64: - return sizeof(gpublasComplex); - case BlasType::C128: - return sizeof(gpublasDoubleComplex); - } -} - -} // namespace - -// Batched LU decomposition: getrfbatched - -static absl::Status GetrfBatched_(gpuStream_t stream, void** buffers, - const char* opaque, size_t opaque_len) { - auto s = UnpackDescriptor(opaque, opaque_len); - JAX_RETURN_IF_ERROR(s.status()); - const GetrfBatchedDescriptor& d = **s; - auto h = BlasHandlePool::Borrow(stream); - JAX_RETURN_IF_ERROR(h.status()); - auto& handle = *h; - if (buffers[0] != buffers[1]) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuMemcpyAsync( - buffers[1], buffers[0], SizeOfBlasType(d.type) * d.batch * d.n * d.n, - gpuMemcpyDeviceToDevice, stream))); - } - - int* ipiv = static_cast(buffers[2]); - int* info = static_cast(buffers[3]); - MakeBatchPointersAsync(stream, buffers[1], buffers[4], d.batch, - SizeOfBlasType(d.type) * d.n * d.n); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuGetLastError())); - switch (d.type) { - case BlasType::F32: { - float** batch_ptrs = static_cast(buffers[4]); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpublasSgetrfBatched( - handle.get(), d.n, batch_ptrs, d.n, ipiv, info, d.batch))); - break; - } - case BlasType::F64: { - double** batch_ptrs = static_cast(buffers[4]); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpublasDgetrfBatched( - handle.get(), d.n, batch_ptrs, d.n, ipiv, info, d.batch))); - break; - } - case BlasType::C64: { - gpublasComplex** batch_ptrs = static_cast(buffers[4]); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpublasCgetrfBatched( - handle.get(), d.n, batch_ptrs, d.n, ipiv, info, d.batch))); - break; - } - case BlasType::C128: { - gpublasDoubleComplex** batch_ptrs = - static_cast(buffers[4]); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpublasZgetrfBatched( - handle.get(), d.n, batch_ptrs, d.n, ipiv, info, d.batch))); - break; - } - } - return absl::OkStatus(); -} - -void 
GetrfBatched(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status) { - auto s = GetrfBatched_(stream, buffers, opaque, opaque_len); - if (!s.ok()) { - XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(), - s.message().length()); - } -} - -// Batched QR decomposition: geqrfbatched - -static absl::Status GeqrfBatched_(gpuStream_t stream, void** buffers, - const char* opaque, size_t opaque_len) { - auto s = UnpackDescriptor(opaque, opaque_len); - JAX_RETURN_IF_ERROR(s.status()); - const GeqrfBatchedDescriptor& d = **s; - auto h = BlasHandlePool::Borrow(stream); - JAX_RETURN_IF_ERROR(h.status()); - auto& handle = *h; - if (buffers[0] != buffers[1]) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuMemcpyAsync( - buffers[1], buffers[0], SizeOfBlasType(d.type) * d.batch * d.m * d.n, - gpuMemcpyDeviceToDevice, stream))); - } - - std::vector info(d.batch); - MakeBatchPointersAsync(stream, buffers[1], buffers[3], d.batch, - SizeOfBlasType(d.type) * d.m * d.n); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuGetLastError())); - MakeBatchPointersAsync(stream, buffers[2], buffers[4], d.batch, - SizeOfBlasType(d.type) * std::min(d.m, d.n)); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuGetLastError())); - switch (d.type) { - case BlasType::F32: { - float** a_batch_ptrs = static_cast(buffers[3]); - float** tau_batch_ptrs = static_cast(buffers[4]); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS( - gpublasSgeqrfBatched(handle.get(), d.m, d.n, a_batch_ptrs, d.m, - tau_batch_ptrs, info.data(), d.batch))); - break; - } - case BlasType::F64: { - double** a_batch_ptrs = static_cast(buffers[3]); - double** tau_batch_ptrs = static_cast(buffers[4]); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS( - gpublasDgeqrfBatched(handle.get(), d.m, d.n, a_batch_ptrs, d.m, - tau_batch_ptrs, info.data(), d.batch))); - break; - } - case BlasType::C64: { - gpublasComplex** a_batch_ptrs = static_cast(buffers[3]); - gpublasComplex** tau_batch_ptrs = - static_cast(buffers[4]); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS( - gpublasCgeqrfBatched(handle.get(), d.m, d.n, a_batch_ptrs, d.m, - tau_batch_ptrs, info.data(), d.batch))); - break; - } - case BlasType::C128: { - gpublasDoubleComplex** a_batch_ptrs = - static_cast(buffers[3]); - gpublasDoubleComplex** tau_batch_ptrs = - static_cast(buffers[4]); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS( - gpublasZgeqrfBatched(handle.get(), d.m, d.n, a_batch_ptrs, d.m, - tau_batch_ptrs, info.data(), d.batch))); - break; - } - } - auto it = - std::find_if(info.begin(), info.end(), [](int i) { return i != 0; }); - - if (it != info.end()) { - return absl::InvalidArgumentError( - absl::StrFormat("QR decomposition failed with status %d for batch " - "element %d", - *it, std::distance(info.begin(), it))); - } - - return absl::OkStatus(); -} - -void GeqrfBatched(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status) { - auto s = GeqrfBatched_(stream, buffers, opaque, opaque_len); - if (!s.ok()) { - XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(), - s.message().length()); - } -} - -} // namespace JAX_GPU_NAMESPACE -} // namespace jax diff --git a/jaxlib/gpu/blas_kernels.h b/jaxlib/gpu/blas_kernels.h deleted file mode 100644 index 724565ea73d1..000000000000 --- a/jaxlib/gpu/blas_kernels.h +++ /dev/null @@ -1,58 +0,0 @@ -/* Copyright 2019 The JAX Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef JAXLIB_GPU_BLAS_KERNELS_H_ -#define JAXLIB_GPU_BLAS_KERNELS_H_ - -#include - -#include "jaxlib/gpu/vendor.h" -#include "xla/service/custom_call_status.h" - -namespace jax { -namespace JAX_GPU_NAMESPACE { - -// Set of types known to Cusolver. -enum class BlasType { - F32, - F64, - C64, - C128, -}; - -// Batched LU decomposition: getrfbatched - -struct GetrfBatchedDescriptor { - BlasType type; - int batch, n; -}; - -void GetrfBatched(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status); - -// Batched QR decomposition: geqrfbatched - -struct GeqrfBatchedDescriptor { - BlasType type; - int batch, m, n; -}; - -void GeqrfBatched(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status); - -} // namespace JAX_GPU_NAMESPACE -} // namespace jax - -#endif // JAXLIB_GPU_BLAS_KERNELS_H_ diff --git a/jaxlib/gpu/gpu_kernel_helpers.cc b/jaxlib/gpu/gpu_kernel_helpers.cc index 5a434f4b6ad5..5b509ad9912d 100644 --- a/jaxlib/gpu/gpu_kernel_helpers.cc +++ b/jaxlib/gpu/gpu_kernel_helpers.cc @@ -15,12 +15,15 @@ limitations under the License. #include "jaxlib/gpu/gpu_kernel_helpers.h" +#include +#include + #include "absl/base/optimization.h" #include "absl/log/check.h" -#include "absl/memory/memory.h" #include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" +#include "jaxlib/gpu/vendor.h" namespace jax { namespace JAX_GPU_NAMESPACE { diff --git a/jaxlib/gpu/gpu_kernel_helpers.h b/jaxlib/gpu/gpu_kernel_helpers.h index aecb8a4fdcf1..0326d7f44620 100644 --- a/jaxlib/gpu/gpu_kernel_helpers.h +++ b/jaxlib/gpu/gpu_kernel_helpers.h @@ -16,11 +16,10 @@ limitations under the License. #ifndef JAXLIB_GPU_GPU_KERNEL_HELPERS_H_ #define JAXLIB_GPU_GPU_KERNEL_HELPERS_H_ -#include +#include #include "absl/base/optimization.h" #include "absl/status/status.h" -#include "absl/status/statusor.h" #include "jaxlib/gpu/vendor.h" #define JAX_AS_STATUS(expr) \ diff --git a/jaxlib/gpu/gpu_kernels.cc b/jaxlib/gpu/gpu_kernels.cc index 242078357254..1f6e5f75315d 100644 --- a/jaxlib/gpu/gpu_kernels.cc +++ b/jaxlib/gpu/gpu_kernels.cc @@ -16,11 +16,9 @@ limitations under the License. // This file is not used by JAX itself, but exists to assist with running // JAX-generated HLO code from outside of JAX. 
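The hunk below replaces the remaining legacy custom-call registrations (XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM, which pass raw void** buffers and an XlaCustomCallStatus) with typed XLA FFI handlers registered through XLA_FFI_REGISTER_HANDLER. As a minimal sketch of the pattern those handlers follow — the kernel name, the "scale" attribute, and the "Host" platform here are hypothetical and not part of this change; only the macros and the Ffi::Bind() API mirror what the registrations below use — a handler is bound from a typed signature and then registered for a platform:

#include "xla/ffi/api/c_api.h"
#include "xla/ffi/api/ffi.h"
#include "xla/ffi/ffi_api.h"

namespace ffi = xla::ffi;

// Hypothetical example: scale a float32 buffer on the host. Handlers
// referenced below, such as GetrfFfi, are bound similarly in their own
// translation units.
static ffi::Error ExampleScaleImpl(ffi::Buffer<ffi::F32> x,
                                   ffi::Result<ffi::Buffer<ffi::F32>> y,
                                   float scale) {
  for (size_t i = 0; i < x.element_count(); ++i) {
    y->typed_data()[i] = scale * x.typed_data()[i];
  }
  return ffi::Error::Success();
}

// Bind the typed signature to a handler symbol, then register it for a
// platform ("CUDA" in the registrations below; "Host" here because this
// sketch touches the data on the CPU).
XLA_FFI_DEFINE_HANDLER_SYMBOL(kExampleScale, ExampleScaleImpl,
                              ffi::Ffi::Bind()
                                  .Arg<ffi::Buffer<ffi::F32>>()
                                  .Ret<ffi::Buffer<ffi::F32>>()
                                  .Attr<float>("scale"));
XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "example_scale", "Host",
                         kExampleScale);

The handlers named in the registrations that follow (GetrfFfi, SyrkFfi, GeqrfFfi, and so on) are declared with XLA_FFI_DECLARE_HANDLER_SYMBOL in their headers and defined in essentially the same way in the corresponding kernel files.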
-#include "jaxlib/gpu/blas_kernels.h" #include "jaxlib/gpu/linalg_kernels.h" #include "jaxlib/gpu/prng_kernels.h" #include "jaxlib/gpu/rnn_kernels.h" -#include "jaxlib/gpu/solver_kernels.h" #include "jaxlib/gpu/solver_kernels_ffi.h" #include "jaxlib/gpu/sparse_kernels.h" #include "jaxlib/gpu/triton_kernels.h" @@ -33,37 +31,25 @@ namespace jax { namespace JAX_GPU_NAMESPACE { namespace { -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cublas_getrf_batched", GetrfBatched, - "CUDA"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cublas_geqrf_batched", GeqrfBatched, - "CUDA"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cudnn_rnn", RNNForward, "CUDA"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cudnn_rnn_bwd", RNNBackward, "CUDA"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_getrf", Getrf, "CUDA"); +XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cudnn_rnn", "CUDA", RNNForwardFfi); +XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cudnn_rnn_bwd", "CUDA", + RNNBackwardFfi); XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusolver_getrf_ffi", "CUDA", GetrfFfi); XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusolver_syrk_ffi", "CUDA", SyrkFfi); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_geqrf", Geqrf, "CUDA"); XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusolver_geqrf_ffi", "CUDA", GeqrfFfi); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_csrlsvqr", Csrlsvqr, "CUDA"); XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusolver_csrlsvqr_ffi", "CUDA", CsrlsvqrFfi); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_orgqr", Orgqr, "CUDA"); XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusolver_orgqr_ffi", "CUDA", OrgqrFfi); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_syevd", Syevd, "CUDA"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_syevj", Syevj, "CUDA"); XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusolver_syevd_ffi", "CUDA", SyevdFfi); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_sytrd", Sytrd, "CUDA"); XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusolver_sytrd_ffi", "CUDA", SytrdFfi); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_gesvd", Gesvd, "CUDA"); XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusolver_gesvd_ffi", "CUDA", GesvdFfi); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_gesvdj", Gesvdj, "CUDA"); XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusolver_gesvdj_ffi", "CUDA", GesvdjFfi); @@ -74,28 +60,26 @@ XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cu_lu_pivots_to_permutation", XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cu_threefry2x32_ffi", "CUDA", ThreeFry2x32Ffi); -#if JAX_CUSPARSE_11300 -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusparse_csr_todense", CsrToDense, - "CUDA"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusparse_csr_fromdense", CsrFromDense, - "CUDA"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusparse_csr_matvec", CsrMatvec, - "CUDA"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusparse_csr_matmat", CsrMatmat, - "CUDA"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusparse_coo_todense", CooToDense, - "CUDA"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusparse_coo_fromdense", CooFromDense, - "CUDA"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusparse_coo_matvec", CooMatvec, - "CUDA"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusparse_coo_matmat", CooMatmat, - "CUDA"); +#if JAX_GPU_HAVE_SPARSE +XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusparse_csr_todense_ffi", "CUDA", + CsrToDenseFfi); +XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusparse_csr_fromdense_ffi", "CUDA", + CsrFromDenseFfi); +XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), 
"cusparse_csr_matvec_ffi", "CUDA", + CsrMatvecFfi); +XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusparse_csr_matmat_ffi", "CUDA", + CsrMatmatFfi); +XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusparse_coo_todense_ffi", "CUDA", + CooToDenseFfi); +XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusparse_coo_fromdense_ffi", "CUDA", + CooFromDenseFfi); +XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusparse_coo_matvec_ffi", "CUDA", + CooMatvecFfi); +XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusparse_coo_matmat_ffi", "CUDA", + CooMatmatFfi); #endif -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusparse_gtsv2_f32", gtsv2_f32, - "CUDA"); -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusparse_gtsv2_f64", gtsv2_f64, - "CUDA"); +XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusparse_gtsv2_ffi", "CUDA", + kGtsv2); XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("triton_kernel_call", TritonKernelCall, "CUDA"); diff --git a/jaxlib/gpu/gpu_plugin_extension.cc b/jaxlib/gpu/gpu_plugin_extension.cc index b56cb8337f1b..cca615cfb260 100644 --- a/jaxlib/gpu/gpu_plugin_extension.cc +++ b/jaxlib/gpu/gpu_plugin_extension.cc @@ -20,13 +20,13 @@ limitations under the License. #include #include -#include "nanobind/nanobind.h" -#include "nanobind/stl/string.h" // IWYU pragma: keep -#include "nanobind/stl/string_view.h" // IWYU pragma: keep #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_format.h" #include "absl/strings/string_view.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "nanobind/stl/string_view.h" // IWYU pragma: keep #include "jaxlib/kernel_nanobind_helpers.h" #include "xla/ffi/api/c_api.h" #include "xla/pjrt/c/pjrt_c_api.h" @@ -35,7 +35,7 @@ limitations under the License. #include "xla/pjrt/c/pjrt_c_api_helpers.h" #include "xla/pjrt/c/pjrt_c_api_triton_extension.h" #include "xla/pjrt/status_casters.h" -#include "xla/python/py_client_gpu.h" +#include "xla/tsl/platform/statusor.h" #include "xla/tsl/python/lib/core/numpy.h" #include "xla/util.h" @@ -202,13 +202,6 @@ absl::Status RegisterCustomTypeId(const PJRT_Api* c_api, return absl::OkStatus(); } -nb::dict Registrations() { - nb::dict dict; - dict["xla_python_gpu_callback"] = - jax::EncapsulateFunction(xla::XlaPythonGpuCallback); - return dict; -} - } // namespace void BuildGpuPluginExtension(nanobind::module_& m) { @@ -264,7 +257,6 @@ void BuildGpuPluginExtension(nanobind::module_& m) { type_name_size, std::move(type_id))); }, nb::arg("c_api"), nb::arg("type_name"), nb::arg("type_id")); - m.def("registrations", &Registrations); } } // namespace xla diff --git a/jaxlib/handle_pool.h b/jaxlib/gpu/handle_pool.h similarity index 96% rename from jaxlib/handle_pool.h rename to jaxlib/gpu/handle_pool.h index 9201d8d579c5..9189bb174b06 100644 --- a/jaxlib/handle_pool.h +++ b/jaxlib/gpu/handle_pool.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef JAXLIB_HANDLE_POOL_H_ -#define JAXLIB_HANDLE_POOL_H_ +#ifndef JAXLIB_GPU_HANDLE_POOL_H_ +#define JAXLIB_GPU_HANDLE_POOL_H_ #include #include @@ -107,4 +107,4 @@ void HandlePool::Return(HandleType handle, } // namespace jax -#endif // JAXLIB_HANDLE_POOL_H_ +#endif // JAXLIB_GPU_HANDLE_POOL_H_ diff --git a/jaxlib/gpu/hybrid.cc b/jaxlib/gpu/hybrid.cc index 94975a5b969f..71c320a60f02 100644 --- a/jaxlib/gpu/hybrid.cc +++ b/jaxlib/gpu/hybrid.cc @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "nanobind/nanobind.h" #include "absl/base/call_once.h" +#include "nanobind/nanobind.h" #include "jaxlib/cpu/lapack_kernels.h" #include "jaxlib/gpu/hybrid_kernels.h" #include "jaxlib/gpu/vendor.h" diff --git a/jaxlib/gpu/linalg_kernels.cc b/jaxlib/gpu/linalg_kernels.cc index 2293bef89b7d..b48e64f2181d 100644 --- a/jaxlib/gpu/linalg_kernels.cc +++ b/jaxlib/gpu/linalg_kernels.cc @@ -90,8 +90,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(CholeskyUpdateFfi, CholeskyUpdateFfiImpl, namespace { ffi::Error LuPivotsToPermutationImpl( - gpuStream_t stream, ffi::Dictionary /* unused */, - ffi::Buffer pivots, + gpuStream_t stream, ffi::Buffer pivots, ffi::Result> permutation) { FFI_ASSIGN_OR_RETURN((auto [batch_size, pivot_size]), SplitBatch1D(pivots.dimensions())); @@ -119,10 +118,6 @@ ffi::Error LuPivotsToPermutationImpl( XLA_FFI_DEFINE_HANDLER_SYMBOL(LuPivotsToPermutation, LuPivotsToPermutationImpl, ffi::Ffi::Bind() .Ctx>() - // TODO(b/358275922): remove Attrs (and the - // unused Dictionary above) 12 weeks after - // release of jaxlib v0.4.32. - .Attrs() .Arg>() .Ret>()); diff --git a/jaxlib/gpu/make_batch_pointers.cu.cc b/jaxlib/gpu/make_batch_pointers.cu.cc index 3a24e355ead0..1d05fa8adcac 100644 --- a/jaxlib/gpu/make_batch_pointers.cu.cc +++ b/jaxlib/gpu/make_batch_pointers.cu.cc @@ -16,6 +16,7 @@ limitations under the License. #include "jaxlib/gpu/make_batch_pointers.h" #include +#include #include #include "jaxlib/gpu/vendor.h" diff --git a/jaxlib/gpu/prng.cc b/jaxlib/gpu/prng.cc index 1ce428d7f9dc..007e51b76de7 100644 --- a/jaxlib/gpu/prng.cc +++ b/jaxlib/gpu/prng.cc @@ -15,6 +15,7 @@ limitations under the License. #include "nanobind/nanobind.h" #include "jaxlib/gpu/prng_kernels.h" +#include "jaxlib/gpu/vendor.h" #include "jaxlib/kernel_nanobind_helpers.h" namespace jax { diff --git a/jaxlib/gpu/prng_kernels.cc b/jaxlib/gpu/prng_kernels.cc index f5d6abef83f8..1dac1e47bd44 100644 --- a/jaxlib/gpu/prng_kernels.cc +++ b/jaxlib/gpu/prng_kernels.cc @@ -17,16 +17,12 @@ limitations under the License. #include #include -#include #include "absl/algorithm/container.h" -#include "absl/status/status.h" #include "jaxlib/gpu/gpu_kernel_helpers.h" #include "jaxlib/gpu/vendor.h" #include "jaxlib/ffi_helpers.h" -#include "jaxlib/kernel_helpers.h" #include "xla/ffi/api/ffi.h" -#include "xla/service/custom_call_status.h" namespace jax { namespace JAX_GPU_NAMESPACE { diff --git a/jaxlib/gpu/prng_kernels.cu.cc b/jaxlib/gpu/prng_kernels.cu.cc index d4aaec62320d..e42165f95d15 100644 --- a/jaxlib/gpu/prng_kernels.cu.cc +++ b/jaxlib/gpu/prng_kernels.cu.cc @@ -15,8 +15,7 @@ limitations under the License. 
#include "jaxlib/gpu/prng_kernels.h" -#include -#include +#include #include #include "jaxlib/gpu/vendor.h" diff --git a/jaxlib/gpu/prng_kernels.h b/jaxlib/gpu/prng_kernels.h index c98fd485700d..4d64d2b4a4e4 100644 --- a/jaxlib/gpu/prng_kernels.h +++ b/jaxlib/gpu/prng_kernels.h @@ -16,12 +16,10 @@ limitations under the License. #ifndef JAXLIB_GPU_PRNG_KERNELS_H_ #define JAXLIB_GPU_PRNG_KERNELS_H_ -#include #include #include "jaxlib/gpu/vendor.h" #include "xla/ffi/api/ffi.h" -#include "xla/service/custom_call_status.h" namespace jax { namespace JAX_GPU_NAMESPACE { diff --git a/jaxlib/gpu/py_client_gpu.cc b/jaxlib/gpu/py_client_gpu.cc new file mode 100644 index 000000000000..cb618890023b --- /dev/null +++ b/jaxlib/gpu/py_client_gpu.cc @@ -0,0 +1,309 @@ +/* Copyright 2022 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/gpu/py_client_gpu.h" + +#include + +#include +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/container/inlined_vector.h" +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/ascii.h" +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "include/dlpack/dlpack.h" +#include "nanobind/nanobind.h" +#include "jaxlib/ffi.h" +#include "jaxlib/gpu/vendor.h" +#include "xla/ffi/api/ffi.h" +#include "xla/ffi/ffi_api.h" +#include "xla/pjrt/host_callback.h" +#include "xla/pjrt/transpose.h" +#include "xla/primitive_util.h" +#include "xla/python/nb_numpy.h" +#include "xla/python/types.h" +#include "xla/shape_util.h" +#include "xla/util.h" +#include "xla/xla_data.pb.h" + +namespace nb = nanobind; + +namespace jax { +namespace JAX_GPU_NAMESPACE { + +struct GpuTransposePlanCache { + static xla::ffi::TypeId id; + explicit GpuTransposePlanCache(int capacity) : cache(capacity) {} + xla::TransposePlanCache cache; +}; +xla::ffi::TypeId GpuTransposePlanCache::id = {}; + +XLA_FFI_REGISTER_TYPE(xla::ffi::GetXlaFfiApi(), "GpuTransposePlanCache", + &GpuTransposePlanCache::id); + +static xla::ffi::ErrorOr> +GpuTransposePlanCacheInstantiate(uint64_t index) { + return std::make_unique(16); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL( + kGpuTransposePlanCacheInstantiate, GpuTransposePlanCacheInstantiate, + xla::ffi::Ffi::BindInstantiate().Attr("index")); +xla::ffi::Error XlaFfiPythonGpuCallback(gpuStream_t stream, + xla::FfiLoadedHostCallbacks* callbacks, + GpuTransposePlanCache* transpose_cache, + uint64_t index, + xla::ffi::RemainingArgs args, + xla::ffi::RemainingRets rets) { + size_t arity = args.size(); + std::vector host_input_buffers(arity); + // Copy input GPU buffers to host + for (size_t i = 0; i < arity; ++i) { + auto arg = args.get(i); + auto ptype = static_cast(arg->element_type()); + // TODO(b/395428868): Remove this check once we support subbyte types. 
+ if (ptype == xla::S1 || ptype == xla::U1) { + return xla::ffi::Error(xla::ffi::ErrorCode::kUnimplemented, + absl::StrFormat("Unsupported primitive type: %s", + PrimitiveType_Name(ptype))); + } + if (ptype == xla::TOKEN) { + host_input_buffers[i] = nullptr; + continue; + } + size_t size_bytes = arg->size_bytes(); + // NOTE(dsuo): FFI arguments and return buffers are sized assuming + // minimum 1-byte element sizes, even if the data itself is packed. We + // assume that 2-bit and 4-bit types are packed. + size_t bits_per_element = xla::primitive_util::BitWidth(ptype); + if (bits_per_element == 2 || bits_per_element == 4) { + size_bytes = arg->element_count() * bits_per_element / 8; + } + host_input_buffers[i] = new char[size_bytes]; + // TODO(b/238441608): Use pinned memory here to speed up the transfer. + auto gpu_res = + gpuMemcpyAsync(host_input_buffers[i], arg.value().untyped_data(), + size_bytes, gpuMemcpyDeviceToHost, stream); + CHECK_EQ(gpu_res, gpuSuccess) << "Failed to gpuMemcpyAsync"; + } + CHECK_EQ(gpuStreamSynchronize(stream), gpuSuccess) + << "Failed to gpuStreamSynchronize"; + nb::gil_scoped_acquire gil; + auto callback = nb::borrow( + static_cast(callbacks->callbacks[index])); + nb::tuple host_input_arrays = nb::steal(PyTuple_New(arity)); + for (size_t i = 0; i < arity; ++i) { + auto arg = args.get(i); + auto ptype = static_cast(arg->element_type()); + if (ptype == xla::TOKEN) { + PyTuple_SET_ITEM(host_input_arrays.ptr(), i, nb::none().inc_ref().ptr()); + continue; + } + auto maybe_dtype = PrimitiveTypeToNbDtype(ptype); + if (!maybe_dtype.ok()) { + return xla::ffi::Error::Internal(maybe_dtype.status().ToString()); + } + auto dtype = maybe_dtype.value(); + auto dims = absl::Span(arg->dimensions().begin(), + arg->dimensions().size()); + // TODO(b/402422886): Remove this once we form Jax arrays directly instead + // of packing/unpacking to/from numpy arrays. + // We pass in data using default numpy layout i.e., std::nullopt. + size_t bits_per_element = xla::primitive_util::BitWidth(ptype); + if (bits_per_element == 2 || bits_per_element == 4) { + // NOTE(dsuo): FFI arguments and return buffers are sized assuming + // minimum 1-byte element sizes, even if the data itself is packed. We + // assume that 2-bit and 4-bit types are packed. + auto size_bytes = arg->element_count() * bits_per_element / 8; + auto buffer = xla::UnpackIntN( + bits_per_element, static_cast(host_input_buffers[i]), + size_bytes); + delete[] static_cast(host_input_buffers[i]); + host_input_buffers[i] = buffer.release(); + } + nb::capsule base(host_input_buffers[i], [](void* ptr) noexcept { + delete[] static_cast(ptr); + }); + auto array = xla::nb_numpy_ndarray(dtype, dims, std::nullopt, + host_input_buffers[i], base); + array.attr("flags").attr("writeable") = nb::bool_(false); + PyTuple_SET_ITEM(host_input_arrays.ptr(), i, array.inc_ref().ptr()); + } + + xla::EnterHostCallback(); + // TODO(dsuo): Change this to use the Python vectorcall protocol, which allows + // you to avoid constructing a tuple for the arguments. 
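+  // Unlike the in-place buffer callback, this callback returns its results as
+  // a tuple of numpy arrays; each result is transposed to the expected device
+  // layout if its strides differ, repacked for 2- and 4-bit types, and copied
+  // back to device memory with gpuMemcpyAsync below.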
+ nb::tuple result_tuple; + try { + auto result_object = callback(*nb::borrow(host_input_arrays)); + result_tuple = nb::cast(result_object); + } catch (nb::python_error& e) { + return xla::ffi::Error::Internal( + absl::StrFormat("CpuCallback error calling callback: %s", e.what())); + } + xla::LeaveHostCallback(); + + std::vector temp_buffers; + for (size_t i = 0; i < rets.size(); ++i) { + auto ret = rets.get(i).value(); + auto ptype = static_cast(ret->element_type()); + // TODO(b/395428868): Remove this check once we support subbyte types. + if (ptype == xla::S1 || ptype == xla::U1) { + return xla::ffi::Error(xla::ffi::ErrorCode::kUnimplemented, + absl::StrFormat("Unsupported primitive type: %s", + PrimitiveType_Name(ptype))); + } + if (ptype == xla::TOKEN) continue; + nb::object output = + nb::borrow(PyTuple_GetItem(result_tuple.ptr(), i)); + auto array = xla::nb_numpy_ndarray::ensure(std::move(output)); + absl::Span strides( + reinterpret_cast(array.strides()), array.ndim()); + // We expect the output to be in default numpy layout. + auto dims = absl::Span(ret->dimensions().begin(), + ret->dimensions().size()); + auto maybe_expected_shape = xla::ShapeUtil::MakeValidatedShape(ptype, dims); + if (!maybe_expected_shape.ok()) { + return xla::ffi::Error::Internal( + maybe_expected_shape.status().ToString()); + } + auto expected_shape = maybe_expected_shape.value(); + auto expected_strides = xla::ByteStridesForShape(expected_shape); + + const void* data = array.data(); + size_t size_bytes = array.size() * array.itemsize(); + if (strides != expected_strides) { + xla::TransposePlan::Options options; + options.elem_size_in_bytes = xla::primitive_util::ByteWidth(ptype); + options.dims = absl::Span( + reinterpret_cast(array.shape()), array.ndim()); + absl::InlinedVector reversed_layout; + reversed_layout.resize(expected_shape.dimensions().size()); + absl::c_reverse_copy(expected_shape.layout().minor_to_major(), + reversed_layout.begin()); + options.permutation = reversed_layout; + options.input_layout = xla::TransposePlan::Striding{strides}; + auto maybe_plan = transpose_cache->cache.GetOrCreate(options); + if (!maybe_plan.ok()) { + return xla::ffi::Error::Internal(maybe_plan.status().ToString()); + } + auto plan = maybe_plan.value(); + void* temp = new char[size_bytes]; + temp_buffers.push_back(temp); + plan->Execute(data, temp); + data = temp; + } + + // TODO(b/402422886): Remove this once we form Jax arrays directly instead + // of packing/unpacking to/from numpy arrays. + std::unique_ptr buffer; + size_t bits_per_element = xla::primitive_util::BitWidth(ptype); + if (bits_per_element == 2 || bits_per_element == 4) { + // NOTE(dsuo): FFI arguments and return buffers are sized assuming + // minimum 1-byte element sizes, even if the data itself is packed. We + // assume that 2-bit and 4-bit types are packed. 
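+      // For example, a result holding 1000 int4 values arrives here as 1000
+      // one-byte elements; after packing it occupies 1000 * 4 / 8 = 500
+      // bytes, and size_bytes is updated to match before the copy back.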
+ buffer = xla::PackIntN(bits_per_element, static_cast(data), + size_bytes); + data = buffer.get(); + size_bytes = (size_bytes * bits_per_element) / 8; + } + + auto gpu_res = gpuMemcpyAsync(ret->untyped_data(), data, size_bytes, + gpuMemcpyHostToDevice, stream); + CHECK_EQ(gpu_res, gpuSuccess) << "Failed to gpuMemcpyAsync"; + } + nb::gil_scoped_release release; + CHECK_EQ(gpuStreamSynchronize(stream), gpuSuccess) + << "Failed to gpuStreamSynchronize"; + for (int i = 0; i < temp_buffers.size(); ++i) { + delete[] static_cast(temp_buffers[i]); + } + return xla::ffi::Error::Success(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL( + kXlaFfiPythonGpuCallback, XlaFfiPythonGpuCallback, + xla::ffi::Ffi::Bind() + .Ctx>() + .Ctx>() + .Ctx>() + .Attr("index") + .RemainingArgs() + .RemainingRets()); +XLA_FFI_REGISTER_HANDLER(xla::ffi::GetXlaFfiApi(), + "xla_ffi_python_gpu_callback", + absl::AsciiStrToUpper(JAX_GPU_PLUGIN_NAME), + {kGpuTransposePlanCacheInstantiate, nullptr, nullptr, + kXlaFfiPythonGpuCallback}); +XLA_FFI_REGISTER_HANDLER(xla::ffi::GetXlaFfiApi(), + "xla_ffi_partitioned_python_gpu_callback", + absl::AsciiStrToUpper(JAX_GPU_PLUGIN_NAME), + {kGpuTransposePlanCacheInstantiate, nullptr, nullptr, + kXlaFfiPythonGpuCallback}); + +XLA_FFI_DEFINE_HANDLER_SYMBOL( + kXlaBufferPythonGpuCallback, +#ifdef JAX_GPU_CUDA + (jax::XlaBufferCallback), +#else + (jax::XlaBufferCallback), +#endif + xla::ffi::Ffi::Bind() + .Ctx() + .Ctx() + .Ctx() + .Ctx>() + .Attr("index") + .RemainingArgs() + .RemainingRets()); + +XLA_FFI_DEFINE_HANDLER_SYMBOL( + kXlaBufferPythonGpuCallbackCmdBuffer, +#ifdef JAX_GPU_CUDA + (jax::XlaBufferCallback), +#else + (jax::XlaBufferCallback), +#endif + xla::ffi::Ffi::Bind() + .Ctx() + .Ctx() + .Ctx() + .Ctx>() + .Attr("index") + .RemainingArgs() + .RemainingRets(), + {ffi::Traits::kCmdBufferCompatible}); + +XLA_FFI_REGISTER_HANDLER(xla::ffi::GetXlaFfiApi(), + "xla_buffer_python_gpu_callback", + absl::AsciiStrToUpper(JAX_GPU_PLUGIN_NAME), + kXlaBufferPythonGpuCallback); +XLA_FFI_REGISTER_HANDLER(xla::ffi::GetXlaFfiApi(), + "xla_buffer_python_gpu_callback_cmd_buffer", + absl::AsciiStrToUpper(JAX_GPU_PLUGIN_NAME), + kXlaBufferPythonGpuCallbackCmdBuffer); + +} // namespace JAX_GPU_NAMESPACE +} // namespace jax diff --git a/jaxlib/gpu/py_client_gpu.h b/jaxlib/gpu/py_client_gpu.h new file mode 100644 index 000000000000..0df0891ceae5 --- /dev/null +++ b/jaxlib/gpu/py_client_gpu.h @@ -0,0 +1,31 @@ +/* Copyright 2022 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef JAX_JAXLIB_GPU_PY_CLIENT_GPU_H_ +#define JAX_JAXLIB_GPU_PY_CLIENT_GPU_H_ + +#include "jaxlib/gpu/vendor.h" +#include "xla/ffi/api/ffi.h" + +namespace jax { +namespace JAX_GPU_NAMESPACE { +XLA_FFI_DECLARE_HANDLER_SYMBOL(kGpuTransposePlanCacheInstantiate); +XLA_FFI_DECLARE_HANDLER_SYMBOL(kXlaFfiPythonGpuCallback); +XLA_FFI_DECLARE_HANDLER_SYMBOL(kXlaBufferPythonGpuCallback); +XLA_FFI_DECLARE_HANDLER_SYMBOL(kXlaBufferPythonGpuCallbackCmdBuffer); +} // namespace JAX_GPU_NAMESPACE +} // namespace jax + +#endif // JAX_JAXLIB_GPU_PY_CLIENT_GPU_H_ diff --git a/jaxlib/gpu/rnn.cc b/jaxlib/gpu/rnn.cc index eaa815d33e68..c235aa9fecfb 100644 --- a/jaxlib/gpu/rnn.cc +++ b/jaxlib/gpu/rnn.cc @@ -16,7 +16,7 @@ limitations under the License. #include #include "nanobind/nanobind.h" -#include "nanobind/stl/pair.h" +#include "nanobind/stl/pair.h" // IWYU pragma: keep #include "jaxlib/absl_status_casters.h" #include "jaxlib/gpu/rnn_kernels.h" #include "jaxlib/gpu/vendor.h" @@ -39,8 +39,6 @@ nb::bytes BuildRnnDescriptor(int input_size, int hidden_size, int num_layers, nb::dict Registrations() { nb::dict dict; - dict[JAX_GPU_PREFIX "dnn_rnn"] = EncapsulateFunction(RNNForward); - dict[JAX_GPU_PREFIX "dnn_rnn_bwd"] = EncapsulateFunction(RNNBackward); dict[JAX_GPU_PREFIX "dnn_rnn_ffi"] = EncapsulateFfiHandler(RNNForwardFfi); dict[JAX_GPU_PREFIX "dnn_rnn_bwd_ffi"] = EncapsulateFfiHandler(RNNBackwardFfi); diff --git a/jaxlib/gpu/rnn_kernels.cc b/jaxlib/gpu/rnn_kernels.cc index e9820bc31f1e..44864d6a2663 100644 --- a/jaxlib/gpu/rnn_kernels.cc +++ b/jaxlib/gpu/rnn_kernels.cc @@ -16,16 +16,20 @@ limitations under the License. #include "jaxlib/gpu/rnn_kernels.h" #include +#include +#include #include #include #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/str_format.h" +#include "absl/synchronization/mutex.h" #include "jaxlib/gpu/ffi_wrapper.h" #include "jaxlib/gpu/gpu_kernel_helpers.h" -#include "jaxlib/handle_pool.h" +#include "jaxlib/gpu/handle_pool.h" +#include "jaxlib/gpu/vendor.h" #include "jaxlib/kernel_helpers.h" -#include "xla/service/custom_call_status.h" namespace jax { @@ -536,24 +540,6 @@ static absl::Status DnnRNNBackward_(gpuStream_t stream, void** buffers, return absl::OkStatus(); } -void RNNForward(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status) { - auto s = DnnRNNForward_(stream, buffers, opaque, opaque_len); - if (!s.ok()) { - XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(), - s.message().length()); - } -} - -void RNNBackward(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status) { - auto s = DnnRNNBackward_(stream, buffers, opaque, opaque_len); - if (!s.ok()) { - XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(), - s.message().length()); - } -} - JAX_GPU_REGISTER_WRAPPED_LEGACY_KERNEL(RNNForwardFfi, DnnRNNForward_); JAX_GPU_REGISTER_WRAPPED_LEGACY_KERNEL(RNNBackwardFfi, DnnRNNBackward_); diff --git a/jaxlib/gpu/rnn_kernels.h b/jaxlib/gpu/rnn_kernels.h index e95b7788382a..c1d6712a9eac 100644 --- a/jaxlib/gpu/rnn_kernels.h +++ b/jaxlib/gpu/rnn_kernels.h @@ -17,11 +17,11 @@ limitations under the License. 
#define JAXLIB_GPU_RNN_KERNELS_H_ #include +#include #include "absl/status/statusor.h" #include "jaxlib/gpu/vendor.h" #include "xla/ffi/api/ffi.h" -#include "xla/service/custom_call_status.h" namespace jax { namespace JAX_GPU_NAMESPACE { @@ -46,12 +46,6 @@ absl::StatusOr> RnnComputeWorkspaceReserveSpaceSizes( int max_seq_length, float dropout, bool bidirectional, bool cudnn_allow_tf32); -void RNNForward(gpuStream_t stream, void **buffers, const char *opaque, - size_t opaque_len, XlaCustomCallStatus *status); - -void RNNBackward(gpuStream_t stream, void **buffers, const char *opaque, - size_t opaque_len, XlaCustomCallStatus *status); - XLA_FFI_DECLARE_HANDLER_SYMBOL(RNNForwardFfi); XLA_FFI_DECLARE_HANDLER_SYMBOL(RNNBackwardFfi); diff --git a/jaxlib/gpu/solver.cc b/jaxlib/gpu/solver.cc index 357a38eecfd5..08d25948d893 100644 --- a/jaxlib/gpu/solver.cc +++ b/jaxlib/gpu/solver.cc @@ -13,22 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include -#include -#include - #include "nanobind/nanobind.h" -#include "nanobind/stl/pair.h" -#include "absl/container/flat_hash_map.h" -#include "absl/status/statusor.h" -#include "absl/strings/str_format.h" -#include "jaxlib/gpu/gpu_kernel_helpers.h" -#include "jaxlib/gpu/solver_handle_pool.h" -#include "jaxlib/gpu/solver_kernels.h" #include "jaxlib/gpu/solver_kernels_ffi.h" #include "jaxlib/gpu/vendor.h" #include "jaxlib/kernel_nanobind_helpers.h" -#include "xla/tsl/python/lib/core/numpy.h" namespace jax { namespace JAX_GPU_NAMESPACE { @@ -36,445 +24,8 @@ namespace { namespace nb = nanobind; -// Converts a NumPy dtype to a Type. -SolverType DtypeToSolverType(const dtype& np_type) { - static auto* types = - new absl::flat_hash_map, SolverType>({ - {{'f', 4}, SolverType::F32}, - {{'f', 8}, SolverType::F64}, - {{'c', 8}, SolverType::C64}, - {{'c', 16}, SolverType::C128}, - }); - auto it = types->find({np_type.kind(), np_type.itemsize()}); - if (it == types->end()) { - nb::str repr = nb::repr(np_type); - throw std::invalid_argument( - absl::StrFormat("Unsupported dtype %s", repr.c_str())); - } - return it->second; -} - -// getrf: LU decomposition - -// Returns the workspace size and a descriptor for a getrf operation. -std::pair BuildGetrfDescriptor(const dtype& dtype, int b, int m, - int n) { - SolverType type = DtypeToSolverType(dtype); - auto h = SolverHandlePool::Borrow(/*stream=*/nullptr); - JAX_THROW_IF_ERROR(h.status()); - auto& handle = *h; - int lwork; - switch (type) { - case SolverType::F32: - JAX_THROW_IF_ERROR( - JAX_AS_STATUS(gpusolverDnSgetrf_bufferSize(handle.get(), m, n, - /*A=*/nullptr, - /*lda=*/m, &lwork))); - break; - case SolverType::F64: - JAX_THROW_IF_ERROR( - JAX_AS_STATUS(gpusolverDnDgetrf_bufferSize(handle.get(), m, n, - /*A=*/nullptr, - /*lda=*/m, &lwork))); - break; - case SolverType::C64: - JAX_THROW_IF_ERROR( - JAX_AS_STATUS(gpusolverDnCgetrf_bufferSize(handle.get(), m, n, - /*A=*/nullptr, - /*lda=*/m, &lwork))); - break; - case SolverType::C128: - JAX_THROW_IF_ERROR( - JAX_AS_STATUS(gpusolverDnZgetrf_bufferSize(handle.get(), m, n, - /*A=*/nullptr, - /*lda=*/m, &lwork))); - break; - } - return {lwork, PackDescriptor(GetrfDescriptor{type, b, m, n, lwork})}; -} - -// geqrf: QR decomposition - -// Returns the workspace size and a descriptor for a geqrf operation. 
-std::pair BuildGeqrfDescriptor(const dtype& dtype, int b, int m, - int n) { - SolverType type = DtypeToSolverType(dtype); - auto h = SolverHandlePool::Borrow(/*stream=*/nullptr); - JAX_THROW_IF_ERROR(h.status()); - auto& handle = *h; - int lwork; - switch (type) { - case SolverType::F32: - JAX_THROW_IF_ERROR( - JAX_AS_STATUS(gpusolverDnSgeqrf_bufferSize(handle.get(), m, n, - /*A=*/nullptr, - /*lda=*/m, &lwork))); - break; - case SolverType::F64: - JAX_THROW_IF_ERROR( - JAX_AS_STATUS(gpusolverDnDgeqrf_bufferSize(handle.get(), m, n, - /*A=*/nullptr, - /*lda=*/m, &lwork))); - break; - case SolverType::C64: - JAX_THROW_IF_ERROR( - JAX_AS_STATUS(gpusolverDnCgeqrf_bufferSize(handle.get(), m, n, - /*A=*/nullptr, - /*lda=*/m, &lwork))); - break; - case SolverType::C128: - JAX_THROW_IF_ERROR( - JAX_AS_STATUS(gpusolverDnZgeqrf_bufferSize(handle.get(), m, n, - /*A=*/nullptr, - /*lda=*/m, &lwork))); - break; - } - return {lwork, PackDescriptor(GeqrfDescriptor{type, b, m, n, lwork})}; -} - -#ifdef JAX_GPU_CUDA - -// csrlsvqr: Linear system solve via Sparse QR - -// Returns a descriptor for a csrlsvqr operation. -nb::bytes BuildCsrlsvqrDescriptor(const dtype& dtype, int n, int nnzA, - int reorder, double tol) { - SolverType type = DtypeToSolverType(dtype); - return PackDescriptor(CsrlsvqrDescriptor{type, n, nnzA, reorder, tol}); -} - -#endif // JAX_GPU_CUDA - -// orgqr/ungqr: apply elementary Householder transformations - -// Returns the workspace size and a descriptor for a geqrf operation. -std::pair BuildOrgqrDescriptor(const dtype& dtype, int b, int m, - int n, int k) { - SolverType type = DtypeToSolverType(dtype); - auto h = SolverHandlePool::Borrow(/*stream=*/nullptr); - JAX_THROW_IF_ERROR(h.status()); - auto& handle = *h; - int lwork; - switch (type) { - case SolverType::F32: - JAX_THROW_IF_ERROR( - JAX_AS_STATUS(gpusolverDnSorgqr_bufferSize(handle.get(), m, n, k, - /*A=*/nullptr, - /*lda=*/m, - /*tau=*/nullptr, &lwork))); - break; - case SolverType::F64: - JAX_THROW_IF_ERROR( - JAX_AS_STATUS(gpusolverDnDorgqr_bufferSize(handle.get(), m, n, k, - /*A=*/nullptr, - /*lda=*/m, - /*tau=*/nullptr, &lwork))); - break; - case SolverType::C64: - JAX_THROW_IF_ERROR( - JAX_AS_STATUS(gpusolverDnCungqr_bufferSize(handle.get(), m, n, k, - /*A=*/nullptr, - /*lda=*/m, - /*tau=*/nullptr, &lwork))); - break; - case SolverType::C128: - JAX_THROW_IF_ERROR( - JAX_AS_STATUS(gpusolverDnZungqr_bufferSize(handle.get(), m, n, k, - /*A=*/nullptr, - /*lda=*/m, - /*tau=*/nullptr, &lwork))); - break; - } - return {lwork, PackDescriptor(OrgqrDescriptor{type, b, m, n, k, lwork})}; -} - -// Symmetric (Hermitian) eigendecomposition, QR algorithm: syevd/heevd - -// Returns the workspace size and a descriptor for a syevd operation. -std::pair BuildSyevdDescriptor(const dtype& dtype, bool lower, - int b, int n) { - SolverType type = DtypeToSolverType(dtype); - auto h = SolverHandlePool::Borrow(/*stream=*/nullptr); - JAX_THROW_IF_ERROR(h.status()); - auto& handle = *h; - int lwork; - gpusolverEigMode_t jobz = GPUSOLVER_EIG_MODE_VECTOR; - gpusolverFillMode_t uplo = - lower ? 
GPUSOLVER_FILL_MODE_LOWER : GPUSOLVER_FILL_MODE_UPPER; - switch (type) { - case SolverType::F32: - JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusolverDnSsyevd_bufferSize( - handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n, /*W=*/nullptr, - &lwork))); - break; - case SolverType::F64: - JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusolverDnDsyevd_bufferSize( - handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n, /*W=*/nullptr, - &lwork))); - break; - case SolverType::C64: - JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusolverDnCheevd_bufferSize( - handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n, /*W=*/nullptr, - &lwork))); - break; - case SolverType::C128: - JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusolverDnZheevd_bufferSize( - handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n, /*W=*/nullptr, - &lwork))); - break; - } - return {lwork, PackDescriptor(SyevdDescriptor{type, uplo, b, n, lwork})}; -} - -// Symmetric (Hermitian) eigendecomposition, Jacobi algorithm: syevj/heevj -// Supports batches of matrices up to size 32. - -// Returns the workspace size and a descriptor for a syevj_batched operation. -std::pair BuildSyevjDescriptor(const dtype& dtype, bool lower, - int batch, int n) { - SolverType type = DtypeToSolverType(dtype); - auto h = SolverHandlePool::Borrow(/*stream=*/nullptr); - JAX_THROW_IF_ERROR(h.status()); - auto& handle = *h; - int lwork; - gpuSyevjInfo_t params; - JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusolverDnCreateSyevjInfo(¶ms))); - std::unique_ptr params_cleanup( - params, [](gpuSyevjInfo_t p) { gpusolverDnDestroySyevjInfo(p); }); - gpusolverEigMode_t jobz = GPUSOLVER_EIG_MODE_VECTOR; - gpusolverFillMode_t uplo = - lower ? GPUSOLVER_FILL_MODE_LOWER : GPUSOLVER_FILL_MODE_UPPER; - if (batch == 1) { - switch (type) { - case SolverType::F32: - JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusolverDnSsyevj_bufferSize( - handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n, - /*W=*/nullptr, &lwork, params))); - break; - case SolverType::F64: - JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusolverDnDsyevj_bufferSize( - handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n, - /*W=*/nullptr, &lwork, params))); - break; - case SolverType::C64: - JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusolverDnCheevj_bufferSize( - handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n, - /*W=*/nullptr, &lwork, params))); - break; - case SolverType::C128: - JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusolverDnZheevj_bufferSize( - handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n, - /*W=*/nullptr, &lwork, params))); - break; - } - } else { - switch (type) { - case SolverType::F32: - JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusolverDnSsyevjBatched_bufferSize( - handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n, - /*W=*/nullptr, &lwork, params, batch))); - break; - case SolverType::F64: - JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusolverDnDsyevjBatched_bufferSize( - handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n, - /*W=*/nullptr, &lwork, params, batch))); - break; - case SolverType::C64: - JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusolverDnCheevjBatched_bufferSize( - handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n, - /*W=*/nullptr, &lwork, params, batch))); - break; - case SolverType::C128: - JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusolverDnZheevjBatched_bufferSize( - handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n, - /*W=*/nullptr, &lwork, params, batch))); - break; - } - } - return {lwork, PackDescriptor(SyevjDescriptor{type, uplo, batch, n, lwork})}; -} - -// Singular value decomposition using QR algorithm: gesvd - -// Returns the workspace size and a 
descriptor for a gesvd operation. -std::pair BuildGesvdDescriptor(const dtype& dtype, int b, int m, - int n, bool compute_uv, - bool full_matrices) { - SolverType type = DtypeToSolverType(dtype); - auto h = SolverHandlePool::Borrow(/*stream=*/nullptr); - JAX_THROW_IF_ERROR(h.status()); - auto& handle = *h; - int lwork; - signed char jobu, jobvt; - if (compute_uv) { - if (full_matrices) { - jobu = jobvt = 'A'; - } else { - jobu = jobvt = 'S'; - } - } else { - jobu = jobvt = 'N'; - } - switch (type) { - case SolverType::F32: - JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusolverDnSgesvd_bufferSize( - handle.get(), jobu, jobvt, m, n, &lwork))); - break; - case SolverType::F64: - JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusolverDnDgesvd_bufferSize( - handle.get(), jobu, jobvt, m, n, &lwork))); - break; - case SolverType::C64: - JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusolverDnCgesvd_bufferSize( - handle.get(), jobu, jobvt, m, n, &lwork))); - break; - case SolverType::C128: - JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusolverDnZgesvd_bufferSize( - handle.get(), jobu, jobvt, m, n, &lwork))); - break; - } - return {lwork, - PackDescriptor(GesvdDescriptor{type, b, m, n, lwork, jobu, jobvt})}; -} - -#ifdef JAX_GPU_CUDA - -// Singular value decomposition using Jacobi algorithm: gesvdj - -// Returns the workspace size and a descriptor for a gesvdj operation. -std::pair BuildGesvdjDescriptor(const dtype& dtype, int batch, - int m, int n, bool compute_uv, - int econ) { - SolverType type = DtypeToSolverType(dtype); - auto h = SolverHandlePool::Borrow(/*stream=*/nullptr); - JAX_THROW_IF_ERROR(h.status()); - auto& handle = *h; - int lwork; - gpusolverEigMode_t jobz = - compute_uv ? GPUSOLVER_EIG_MODE_VECTOR : CUSOLVER_EIG_MODE_NOVECTOR; - gesvdjInfo_t params; - JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusolverDnCreateGesvdjInfo(¶ms))); - std::unique_ptr params_cleanup( - params, [](gesvdjInfo* p) { cusolverDnDestroyGesvdjInfo(p); }); - if (batch <= 1 || m > 32 || n > 32 || econ) { - switch (type) { - case SolverType::F32: - JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusolverDnSgesvdj_bufferSize( - handle.get(), jobz, econ, m, n, - /*A=*/nullptr, /*lda=*/m, /*S=*/nullptr, - /*U=*/nullptr, /*ldu=*/m, /*V=*/nullptr, - /*ldv=*/n, &lwork, params))); - break; - case SolverType::F64: - JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusolverDnDgesvdj_bufferSize( - handle.get(), jobz, econ, m, n, - /*A=*/nullptr, /*lda=*/m, /*S=*/nullptr, - /*U=*/nullptr, /*ldu=*/m, /*V=*/nullptr, - /*ldv=*/n, &lwork, params))); - break; - case SolverType::C64: - JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusolverDnCgesvdj_bufferSize( - handle.get(), jobz, econ, m, n, - /*A=*/nullptr, /*lda=*/m, /*S=*/nullptr, - /*U=*/nullptr, /*ldu=*/m, /*V=*/nullptr, - /*ldv=*/n, &lwork, params))); - break; - case SolverType::C128: - JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusolverDnZgesvdj_bufferSize( - handle.get(), jobz, econ, m, n, - /*A=*/nullptr, /*lda=*/m, /*S=*/nullptr, - /*U=*/nullptr, /*ldu=*/m, /*V=*/nullptr, - /*ldv=*/n, &lwork, params))); - break; - } - } else { - switch (type) { - case SolverType::F32: - JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusolverDnSgesvdjBatched_bufferSize( - handle.get(), jobz, m, n, - /*A=*/nullptr, /*lda=*/m, /*S=*/nullptr, - /*U=*/nullptr, /*ldu=*/m, /*V=*/nullptr, - /*ldv=*/n, &lwork, params, batch))); - break; - case SolverType::F64: - JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusolverDnDgesvdjBatched_bufferSize( - handle.get(), jobz, m, n, - /*A=*/nullptr, /*lda=*/m, /*S=*/nullptr, - /*U=*/nullptr, /*ldu=*/m, /*V=*/nullptr, - /*ldv=*/n, &lwork, params, batch))); - break; - case SolverType::C64: 
- JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusolverDnCgesvdjBatched_bufferSize( - handle.get(), jobz, m, n, - /*A=*/nullptr, /*lda=*/m, /*S=*/nullptr, - /*U=*/nullptr, /*ldu=*/m, /*V=*/nullptr, - /*ldv=*/n, &lwork, params, batch))); - break; - case SolverType::C128: - JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusolverDnZgesvdjBatched_bufferSize( - handle.get(), jobz, m, n, - /*A=*/nullptr, /*lda=*/m, /*S=*/nullptr, - /*U=*/nullptr, /*ldu=*/m, /*V=*/nullptr, - /*ldv=*/n, &lwork, params, batch))); - break; - } - } - return {lwork, PackDescriptor( - GesvdjDescriptor{type, batch, m, n, lwork, jobz, econ})}; -} - -#endif // JAX_GPU_CUDA - -// Returns the workspace size and a descriptor for a geqrf operation. -std::pair BuildSytrdDescriptor(const dtype& dtype, bool lower, - int b, int n) { - SolverType type = DtypeToSolverType(dtype); - auto h = SolverHandlePool::Borrow(/*stream=*/nullptr); - JAX_THROW_IF_ERROR(h.status()); - auto& handle = *h; - int lwork; - gpusolverFillMode_t uplo = - lower ? GPUSOLVER_FILL_MODE_LOWER : GPUSOLVER_FILL_MODE_UPPER; - switch (type) { - case SolverType::F32: - JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusolverDnSsytrd_bufferSize( - handle.get(), uplo, n, /*A=*/nullptr, /*lda=*/n, /*D=*/nullptr, - /*E=*/nullptr, /*tau=*/nullptr, &lwork))); - break; - case SolverType::F64: - JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusolverDnDsytrd_bufferSize( - handle.get(), uplo, n, /*A=*/nullptr, /*lda=*/n, /*D=*/nullptr, - /*E=*/nullptr, /*tau=*/nullptr, &lwork))); - break; - case SolverType::C64: - JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusolverDnChetrd_bufferSize( - handle.get(), uplo, n, /*A=*/nullptr, /*lda=*/n, /*D=*/nullptr, - /*E=*/nullptr, /*tau=*/nullptr, &lwork))); - break; - case SolverType::C128: - JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusolverDnZhetrd_bufferSize( - handle.get(), uplo, n, /*A=*/nullptr, /*lda=*/n, /*D=*/nullptr, - /*E=*/nullptr, /*tau=*/nullptr, &lwork))); - break; - } - return {lwork, PackDescriptor(SytrdDescriptor{type, uplo, b, n, n, lwork})}; -} - nb::dict Registrations() { nb::dict dict; - dict[JAX_GPU_PREFIX "solver_getrf"] = EncapsulateFunction(Getrf); - dict[JAX_GPU_PREFIX "solver_geqrf"] = EncapsulateFunction(Geqrf); - dict[JAX_GPU_PREFIX "solver_orgqr"] = EncapsulateFunction(Orgqr); - dict[JAX_GPU_PREFIX "solver_syevd"] = EncapsulateFunction(Syevd); - dict[JAX_GPU_PREFIX "solver_syevj"] = EncapsulateFunction(Syevj); - dict[JAX_GPU_PREFIX "solver_gesvd"] = EncapsulateFunction(Gesvd); - dict[JAX_GPU_PREFIX "solver_sytrd"] = EncapsulateFunction(Sytrd); - -#ifdef JAX_GPU_CUDA - dict["cusolver_csrlsvqr"] = EncapsulateFunction(Csrlsvqr); - dict["cusolver_gesvdj"] = EncapsulateFunction(Gesvdj); - -#endif // JAX_GPU_CUDA dict[JAX_GPU_PREFIX "solver_getrf_ffi"] = EncapsulateFfiHandler(GetrfFfi); dict[JAX_GPU_PREFIX "solver_geqrf_ffi"] = EncapsulateFfiHandler(GeqrfFfi); @@ -494,19 +45,7 @@ nb::dict Registrations() { } NB_MODULE(_solver, m) { - tsl::ImportNumpy(); m.def("registrations", &Registrations); - m.def("build_getrf_descriptor", &BuildGetrfDescriptor); - m.def("build_geqrf_descriptor", &BuildGeqrfDescriptor); - m.def("build_orgqr_descriptor", &BuildOrgqrDescriptor); - m.def("build_syevd_descriptor", &BuildSyevdDescriptor); - m.def("build_syevj_descriptor", &BuildSyevjDescriptor); - m.def("build_gesvd_descriptor", &BuildGesvdDescriptor); - m.def("build_sytrd_descriptor", &BuildSytrdDescriptor); -#ifdef JAX_GPU_CUDA - m.def("build_csrlsvqr_descriptor", &BuildCsrlsvqrDescriptor); - m.def("build_gesvdj_descriptor", &BuildGesvdjDescriptor); -#endif // JAX_GPU_CUDA } } // namespace diff 
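For context on the code deleted above: every removed Build*Descriptor helper returned a (workspace size, opaque descriptor) pair, where the descriptor is a trivially copyable struct serialized to bytes by PackDescriptor and recovered inside the kernel by UnpackDescriptor; the FFI handlers that replace them take typed attributes instead. A minimal sketch of that pack/unpack round trip (PackPod/UnpackPod are illustrative names, not the jaxlib helpers, which also plumb errors through absl::Status):

#include <cstddef>
#include <cstring>
#include <stdexcept>
#include <string>
#include <type_traits>

// Serializes a trivially copyable descriptor struct into an opaque byte
// string that can travel alongside a custom call.
template <typename T>
std::string PackPod(const T& descriptor) {
  static_assert(std::is_trivially_copyable_v<T>, "descriptor must be POD-like");
  std::string opaque(sizeof(T), '\0');
  std::memcpy(opaque.data(), &descriptor, sizeof(T));
  return opaque;
}

// Recovers the descriptor on the kernel side, rejecting mismatched sizes.
template <typename T>
T UnpackPod(const char* opaque, size_t opaque_len) {
  static_assert(std::is_trivially_copyable_v<T>, "descriptor must be POD-like");
  if (opaque_len != sizeof(T)) {
    throw std::invalid_argument("descriptor size mismatch");
  }
  T descriptor;
  std::memcpy(&descriptor, opaque, sizeof(T));
  return descriptor;
}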
--git a/jaxlib/gpu/solver_handle_pool.cc b/jaxlib/gpu/solver_handle_pool.cc index c55ea923b21b..416ccf9d1bbc 100644 --- a/jaxlib/gpu/solver_handle_pool.cc +++ b/jaxlib/gpu/solver_handle_pool.cc @@ -19,7 +19,7 @@ limitations under the License. #include "absl/synchronization/mutex.h" #include "jaxlib/gpu/gpu_kernel_helpers.h" #include "jaxlib/gpu/vendor.h" -#include "jaxlib/handle_pool.h" +#include "jaxlib/gpu/handle_pool.h" #ifdef JAX_GPU_CUDA #include "third_party/gpus/cuda/include/cusolverSp.h" diff --git a/jaxlib/gpu/solver_handle_pool.h b/jaxlib/gpu/solver_handle_pool.h index c46c062b3054..4e369ea85520 100644 --- a/jaxlib/gpu/solver_handle_pool.h +++ b/jaxlib/gpu/solver_handle_pool.h @@ -18,7 +18,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "jaxlib/gpu/vendor.h" -#include "jaxlib/handle_pool.h" +#include "jaxlib/gpu/handle_pool.h" #ifdef JAX_GPU_CUDA #include "third_party/gpus/cuda/include/cusolverSp.h" diff --git a/jaxlib/gpu/solver_kernels.cc b/jaxlib/gpu/solver_kernels.cc deleted file mode 100644 index 8c22dfcdbca7..000000000000 --- a/jaxlib/gpu/solver_kernels.cc +++ /dev/null @@ -1,978 +0,0 @@ -/* Copyright 2019 The JAX Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "jaxlib/gpu/solver_kernels.h" - -#include -#include -#include -#include -#include - -#include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "jaxlib/gpu/gpu_kernel_helpers.h" -#include "jaxlib/gpu/solver_handle_pool.h" -#include "jaxlib/gpu/vendor.h" -#include "jaxlib/kernel_helpers.h" -#include "xla/service/custom_call_status.h" - -#ifdef JAX_GPU_CUDA -#include "third_party/gpus/cuda/include/cusolverSp.h" -#endif // JAX_GPU_CUDA - -namespace jax { - -namespace JAX_GPU_NAMESPACE { - -static int SizeOfSolverType(SolverType type) { - switch (type) { - case SolverType::F32: - return sizeof(float); - case SolverType::F64: - return sizeof(double); - case SolverType::C64: - return sizeof(gpuComplex); - case SolverType::C128: - return sizeof(gpuDoubleComplex); - } -} - -// getrf: LU decomposition - -static absl::Status Getrf_(gpuStream_t stream, void** buffers, - const char* opaque, size_t opaque_len) { - auto s = UnpackDescriptor(opaque, opaque_len); - JAX_RETURN_IF_ERROR(s.status()); - const GetrfDescriptor& d = **s; - auto h = SolverHandlePool::Borrow(stream); - JAX_RETURN_IF_ERROR(h.status()); - auto& handle = *h; - if (buffers[1] != buffers[0]) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuMemcpyAsync( - buffers[1], buffers[0], - SizeOfSolverType(d.type) * static_cast(d.batch) * - static_cast(d.m) * static_cast(d.n), - gpuMemcpyDeviceToDevice, stream))); - } - - int* ipiv = static_cast(buffers[2]); - int* info = static_cast(buffers[3]); - void* workspace = buffers[4]; - switch (d.type) { - case SolverType::F32: { - float* a = static_cast(buffers[1]); - for (int i = 0; i < d.batch; ++i) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnSgetrf( - handle.get(), d.m, d.n, a, d.m, 
static_cast(workspace), - d.lwork, ipiv, info))); - a += d.m * d.n; - ipiv += std::min(d.m, d.n); - ++info; - } - break; - } - case SolverType::F64: { - double* a = static_cast(buffers[1]); - for (int i = 0; i < d.batch; ++i) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnDgetrf( - handle.get(), d.m, d.n, a, d.m, static_cast(workspace), - d.lwork, ipiv, info))); - a += d.m * d.n; - ipiv += std::min(d.m, d.n); - ++info; - } - break; - } - case SolverType::C64: { - gpuComplex* a = static_cast(buffers[1]); - for (int i = 0; i < d.batch; ++i) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnCgetrf( - handle.get(), d.m, d.n, a, d.m, static_cast(workspace), - d.lwork, ipiv, info))); - a += d.m * d.n; - ipiv += std::min(d.m, d.n); - ++info; - } - break; - } - case SolverType::C128: { - gpuDoubleComplex* a = static_cast(buffers[1]); - for (int i = 0; i < d.batch; ++i) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnZgetrf( - handle.get(), d.m, d.n, a, d.m, - static_cast(workspace), d.lwork, ipiv, info))); - a += d.m * d.n; - ipiv += std::min(d.m, d.n); - ++info; - } - break; - } - } - return absl::OkStatus(); -} - -void Getrf(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status) { - auto s = Getrf_(stream, buffers, opaque, opaque_len); - if (!s.ok()) { - XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(), - s.message().length()); - } -} - -// geqrf: QR decomposition - -static absl::Status Geqrf_(gpuStream_t stream, void** buffers, - const char* opaque, size_t opaque_len) { - auto s = UnpackDescriptor(opaque, opaque_len); - JAX_RETURN_IF_ERROR(s.status()); - const GeqrfDescriptor& d = **s; - auto h = SolverHandlePool::Borrow(stream); - JAX_RETURN_IF_ERROR(h.status()); - auto& handle = *h; - if (buffers[1] != buffers[0]) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuMemcpyAsync( - buffers[1], buffers[0], - SizeOfSolverType(d.type) * static_cast(d.batch) * - static_cast(d.m) * static_cast(d.n), - gpuMemcpyDeviceToDevice, stream))); - } - - int* info = static_cast(buffers[3]); - void* workspace = buffers[4]; - switch (d.type) { - case SolverType::F32: { - float* a = static_cast(buffers[1]); - float* tau = static_cast(buffers[2]); - for (int i = 0; i < d.batch; ++i) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS( - gpusolverDnSgeqrf(handle.get(), d.m, d.n, a, d.m, tau, - static_cast(workspace), d.lwork, info))); - a += d.m * d.n; - tau += std::min(d.m, d.n); - ++info; - } - break; - } - case SolverType::F64: { - double* a = static_cast(buffers[1]); - double* tau = static_cast(buffers[2]); - for (int i = 0; i < d.batch; ++i) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS( - gpusolverDnDgeqrf(handle.get(), d.m, d.n, a, d.m, tau, - static_cast(workspace), d.lwork, info))); - a += d.m * d.n; - tau += std::min(d.m, d.n); - ++info; - } - break; - } - case SolverType::C64: { - gpuComplex* a = static_cast(buffers[1]); - gpuComplex* tau = static_cast(buffers[2]); - for (int i = 0; i < d.batch; ++i) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnCgeqrf( - handle.get(), d.m, d.n, a, d.m, tau, - static_cast(workspace), d.lwork, info))); - a += d.m * d.n; - tau += std::min(d.m, d.n); - ++info; - } - break; - } - case SolverType::C128: { - gpuDoubleComplex* a = static_cast(buffers[1]); - gpuDoubleComplex* tau = static_cast(buffers[2]); - for (int i = 0; i < d.batch; ++i) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnZgeqrf( - handle.get(), d.m, d.n, a, d.m, tau, - static_cast(workspace), d.lwork, info))); - a += d.m * d.n; - tau += std::min(d.m, d.n); - 
++info; - } - break; - } - } - return absl::OkStatus(); -} - -void Geqrf(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status) { - auto s = Geqrf_(stream, buffers, opaque, opaque_len); - if (!s.ok()) { - XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(), - s.message().length()); - } -} - -#ifdef JAX_GPU_CUDA - -// csrlsvqr: Linear system solve via Sparse QR - -static absl::Status Csrlsvqr_(gpuStream_t stream, void** buffers, - const char* opaque, size_t opaque_len, - int& singularity) { - auto s = UnpackDescriptor(opaque, opaque_len); - JAX_RETURN_IF_ERROR(s.status()); - const CsrlsvqrDescriptor& d = **s; - - // This is the handle to the CUDA session. Gets a cusolverSp handle. - auto h = SpSolverHandlePool::Borrow(stream); - JAX_RETURN_IF_ERROR(h.status()); - auto& handle = *h; - - cusparseMatDescr_t matdesc = nullptr; - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseCreateMatDescr(&matdesc))); - JAX_RETURN_IF_ERROR( - JAX_AS_STATUS(cusparseSetMatType(matdesc, CUSPARSE_MATRIX_TYPE_GENERAL))); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS( - cusparseSetMatIndexBase(matdesc, CUSPARSE_INDEX_BASE_ZERO))); - - switch (d.type) { - case SolverType::F32: { - float* csrValA = static_cast(buffers[0]); - int* csrRowPtrA = static_cast(buffers[1]); - int* csrColIndA = static_cast(buffers[2]); - float* b = static_cast(buffers[3]); - float* x = static_cast(buffers[4]); - - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverSpScsrlsvqr( - handle.get(), d.n, d.nnz, matdesc, csrValA, csrRowPtrA, csrColIndA, b, - (float)d.tol, d.reorder, x, &singularity))); - - break; - } - case SolverType::F64: { - double* csrValA = static_cast(buffers[0]); - int* csrRowPtrA = static_cast(buffers[1]); - int* csrColIndA = static_cast(buffers[2]); - double* b = static_cast(buffers[3]); - double* x = static_cast(buffers[4]); - - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverSpDcsrlsvqr( - handle.get(), d.n, d.nnz, matdesc, csrValA, csrRowPtrA, csrColIndA, b, - d.tol, d.reorder, x, &singularity))); - - break; - } - case SolverType::C64: { - gpuComplex* csrValA = static_cast(buffers[0]); - int* csrRowPtrA = static_cast(buffers[1]); - int* csrColIndA = static_cast(buffers[2]); - gpuComplex* b = static_cast(buffers[3]); - gpuComplex* x = static_cast(buffers[4]); - - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverSpCcsrlsvqr( - handle.get(), d.n, d.nnz, matdesc, csrValA, csrRowPtrA, csrColIndA, b, - (float)d.tol, d.reorder, x, &singularity))); - - break; - } - case SolverType::C128: { - gpuDoubleComplex* csrValA = static_cast(buffers[0]); - int* csrRowPtrA = static_cast(buffers[1]); - int* csrColIndA = static_cast(buffers[2]); - gpuDoubleComplex* b = static_cast(buffers[3]); - gpuDoubleComplex* x = static_cast(buffers[4]); - - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverSpZcsrlsvqr( - handle.get(), d.n, d.nnz, matdesc, csrValA, csrRowPtrA, csrColIndA, b, - (float)d.tol, d.reorder, x, &singularity))); - - break; - } - } - - cusparseDestroyMatDescr(matdesc); - return absl::OkStatus(); -} - -void Csrlsvqr(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status) { - // Is >= 0 if A is singular. 
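The removed Getrf/Geqrf/Csrlsvqr wrappers above, and the Orgqr/Syevd/Syevj/Gesvd/Sytrd wrappers that follow, all repeat one adapter shape: run an absl::Status-returning implementation and copy any error message into the legacy XlaCustomCallStatus out-parameter. The FFI registrations keep only the implementations, so these hand-written adapters go away. Schematically (a sketch of the shared shape, not of the JAX_GPU_REGISTER_WRAPPED_LEGACY_KERNEL machinery itself):

#include <cstddef>
#include <string>

#include "absl/status/status.h"
#include "jaxlib/gpu/vendor.h"               // gpuStream_t
#include "xla/service/custom_call_status.h"  // XlaCustomCallStatus

// Adapts an absl::Status-returning kernel to the legacy custom-call ABI by
// copying any error message into the XlaCustomCallStatus out-parameter.
template <absl::Status (*kImpl)(gpuStream_t, void**, const char*, size_t)>
void LegacyStatusAdapter(gpuStream_t stream, void** buffers,
                         const char* opaque, size_t opaque_len,
                         XlaCustomCallStatus* status) {
  absl::Status s = kImpl(stream, buffers, opaque, opaque_len);
  if (!s.ok()) {
    std::string message(s.message());
    XlaCustomCallStatusSetFailure(status, message.c_str(), message.length());
  }
}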
- int singularity = -1; - - auto s = Csrlsvqr_(stream, buffers, opaque, opaque_len, singularity); - if (!s.ok()) { - XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(), - s.message().length()); - } - - if (singularity >= 0) { - auto s = std::string("Singular matrix in linear solve."); - XlaCustomCallStatusSetFailure(status, s.c_str(), s.length()); - } -} - -#endif // JAX_GPU_CUDA - -// orgqr/ungqr: apply elementary Householder transformations - -static absl::Status Orgqr_(gpuStream_t stream, void** buffers, - const char* opaque, size_t opaque_len) { - auto s = UnpackDescriptor(opaque, opaque_len); - JAX_RETURN_IF_ERROR(s.status()); - const OrgqrDescriptor& d = **s; - auto h = SolverHandlePool::Borrow(stream); - JAX_RETURN_IF_ERROR(h.status()); - auto& handle = *h; - if (buffers[2] != buffers[0]) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuMemcpyAsync( - buffers[2], buffers[0], - SizeOfSolverType(d.type) * static_cast(d.batch) * - static_cast(d.m) * static_cast(d.n), - gpuMemcpyDeviceToDevice, stream))); - } - - int* info = static_cast(buffers[3]); - void* workspace = buffers[4]; - switch (d.type) { - case SolverType::F32: { - float* a = static_cast(buffers[2]); - float* tau = static_cast(buffers[1]); - for (int i = 0; i < d.batch; ++i) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS( - gpusolverDnSorgqr(handle.get(), d.m, d.n, d.k, a, d.m, tau, - static_cast(workspace), d.lwork, info))); - a += d.m * d.n; - tau += d.k; - ++info; - } - break; - } - case SolverType::F64: { - double* a = static_cast(buffers[2]); - double* tau = static_cast(buffers[1]); - for (int i = 0; i < d.batch; ++i) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS( - gpusolverDnDorgqr(handle.get(), d.m, d.n, d.k, a, d.m, tau, - static_cast(workspace), d.lwork, info))); - a += d.m * d.n; - tau += d.k; - ++info; - } - break; - } - case SolverType::C64: { - gpuComplex* a = static_cast(buffers[2]); - gpuComplex* tau = static_cast(buffers[1]); - for (int i = 0; i < d.batch; ++i) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnCungqr( - handle.get(), d.m, d.n, d.k, a, d.m, tau, - static_cast(workspace), d.lwork, info))); - a += d.m * d.n; - tau += d.k; - ++info; - } - break; - } - case SolverType::C128: { - gpuDoubleComplex* a = static_cast(buffers[2]); - gpuDoubleComplex* tau = static_cast(buffers[1]); - for (int i = 0; i < d.batch; ++i) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnZungqr( - handle.get(), d.m, d.n, d.k, a, d.m, tau, - static_cast(workspace), d.lwork, info))); - a += d.m * d.n; - tau += d.k; - ++info; - } - break; - } - } - return absl::OkStatus(); -} - -void Orgqr(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status) { - auto s = Orgqr_(stream, buffers, opaque, opaque_len); - if (!s.ok()) { - XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(), - s.message().length()); - } -} - -// Symmetric (Hermitian) eigendecomposition, QR algorithm: syevd/heevd - -static absl::Status Syevd_(gpuStream_t stream, void** buffers, - const char* opaque, size_t opaque_len) { - auto s = UnpackDescriptor(opaque, opaque_len); - JAX_RETURN_IF_ERROR(s.status()); - const SyevdDescriptor& d = **s; - auto h = SolverHandlePool::Borrow(stream); - JAX_RETURN_IF_ERROR(h.status()); - auto& handle = *h; - - std::int64_t batch = d.batch; - int output_idx = 1; // with static shapes buffers[1] is the first output - if (d.batch == -1) { - // the batch is passed as a second operand - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuMemcpyAsync( - (void*)&batch, 
reinterpret_cast(buffers[1]), - sizeof(batch), gpuMemcpyDeviceToHost, stream))); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuStreamSynchronize(stream))); - output_idx = 2; - } - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuMemcpyAsync( - buffers[output_idx], buffers[0], - SizeOfSolverType(d.type) * batch * static_cast(d.n) * - static_cast(d.n), - gpuMemcpyDeviceToDevice, stream))); - gpusolverEigMode_t jobz = GPUSOLVER_EIG_MODE_VECTOR; - int* info = static_cast(buffers[output_idx + 2]); - void* work = buffers[output_idx + 3]; - switch (d.type) { - case SolverType::F32: { - float* a = static_cast(buffers[output_idx]); - float* w = static_cast(buffers[output_idx + 1]); - for (int i = 0; i < batch; ++i) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS( - gpusolverDnSsyevd(handle.get(), jobz, d.uplo, d.n, a, d.n, w, - static_cast(work), d.lwork, info))); - a += d.n * d.n; - w += d.n; - ++info; - } - break; - } - case SolverType::F64: { - double* a = static_cast(buffers[output_idx]); - double* w = static_cast(buffers[output_idx + 1]); - for (int i = 0; i < batch; ++i) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS( - gpusolverDnDsyevd(handle.get(), jobz, d.uplo, d.n, a, d.n, w, - static_cast(work), d.lwork, info))); - a += d.n * d.n; - w += d.n; - ++info; - } - break; - } - case SolverType::C64: { - gpuComplex* a = static_cast(buffers[output_idx]); - float* w = static_cast(buffers[output_idx + 1]); - for (int i = 0; i < batch; ++i) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS( - gpusolverDnCheevd(handle.get(), jobz, d.uplo, d.n, a, d.n, w, - static_cast(work), d.lwork, info))); - a += d.n * d.n; - w += d.n; - ++info; - } - break; - } - case SolverType::C128: { - gpuDoubleComplex* a = static_cast(buffers[output_idx]); - double* w = static_cast(buffers[output_idx + 1]); - for (int i = 0; i < batch; ++i) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnZheevd( - handle.get(), jobz, d.uplo, d.n, a, d.n, w, - static_cast(work), d.lwork, info))); - a += d.n * d.n; - w += d.n; - ++info; - } - break; - } - } - return absl::OkStatus(); -} - -void Syevd(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status) { - auto s = Syevd_(stream, buffers, opaque, opaque_len); - if (!s.ok()) { - XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(), - s.message().length()); - } -} - -// Symmetric (Hermitian) eigendecomposition, Jacobi algorithm: syevj/heevj -// Supports batches of matrices up to size 32. 
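Both the QR-based eigendecomposition path above and the Jacobi path below begin by borrowing a cuSOLVER handle keyed on the stream via SolverHandlePool::Borrow. Conceptually the pool amounts to the sketch below; the real implementation in jaxlib/gpu/handle_pool.h returns an RAII holder that gives the handle back on destruction, which this simplified version leaves to the caller:

#include <map>
#include <vector>

#include "absl/synchronization/mutex.h"

// A much-simplified, stream-keyed handle pool. HandleT, StreamT, and the
// `create` callback are placeholders for the vendor handle machinery.
template <typename HandleT, typename StreamT>
class SimpleHandlePool {
 public:
  // Hands out a handle for `stream`, creating a new one when none is cached.
  HandleT Acquire(StreamT stream, HandleT (*create)(StreamT)) {
    absl::MutexLock lock(&mu_);
    std::vector<HandleT>& free_list = free_[stream];
    if (free_list.empty()) return create(stream);
    HandleT handle = free_list.back();
    free_list.pop_back();
    return handle;
  }

  // Returns a handle so that later borrows on the same stream can reuse it.
  void Release(StreamT stream, HandleT handle) {
    absl::MutexLock lock(&mu_);
    free_[stream].push_back(handle);
  }

 private:
  absl::Mutex mu_;
  std::map<StreamT, std::vector<HandleT>> free_;
};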
- -absl::Status Syevj_(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len) { - auto s = UnpackDescriptor(opaque, opaque_len); - JAX_RETURN_IF_ERROR(s.status()); - const SyevjDescriptor& d = **s; - auto h = SolverHandlePool::Borrow(stream); - JAX_RETURN_IF_ERROR(h.status()); - auto& handle = *h; - if (buffers[1] != buffers[0]) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuMemcpyAsync( - buffers[1], buffers[0], - SizeOfSolverType(d.type) * static_cast(d.batch) * - static_cast(d.n) * static_cast(d.n), - gpuMemcpyDeviceToDevice, stream))); - } - gpuSyevjInfo_t params; - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnCreateSyevjInfo(¶ms))); - std::unique_ptr params_cleanup( - params, [](gpuSyevjInfo_t p) { gpusolverDnDestroySyevjInfo(p); }); - - gpusolverEigMode_t jobz = GPUSOLVER_EIG_MODE_VECTOR; - int* info = static_cast(buffers[3]); - void* work = buffers[4]; - if (d.batch == 1) { - switch (d.type) { - case SolverType::F32: { - float* a = static_cast(buffers[1]); - float* w = static_cast(buffers[2]); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnSsyevj( - handle.get(), jobz, d.uplo, d.n, a, d.n, w, - static_cast(work), d.lwork, info, params))); - break; - } - case SolverType::F64: { - double* a = static_cast(buffers[1]); - double* w = static_cast(buffers[2]); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnDsyevj( - handle.get(), jobz, d.uplo, d.n, a, d.n, w, - static_cast(work), d.lwork, info, params))); - break; - } - case SolverType::C64: { - gpuComplex* a = static_cast(buffers[1]); - float* w = static_cast(buffers[2]); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnCheevj( - handle.get(), jobz, d.uplo, d.n, a, d.n, w, - static_cast(work), d.lwork, info, params))); - break; - } - case SolverType::C128: { - gpuDoubleComplex* a = static_cast(buffers[1]); - double* w = static_cast(buffers[2]); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnZheevj( - handle.get(), jobz, d.uplo, d.n, a, d.n, w, - static_cast(work), d.lwork, info, params))); - break; - } - } - } else { - switch (d.type) { - case SolverType::F32: { - float* a = static_cast(buffers[1]); - float* w = static_cast(buffers[2]); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnSsyevjBatched( - handle.get(), jobz, d.uplo, d.n, a, d.n, w, - static_cast(work), d.lwork, info, params, d.batch))); - break; - } - case SolverType::F64: { - double* a = static_cast(buffers[1]); - double* w = static_cast(buffers[2]); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnDsyevjBatched( - handle.get(), jobz, d.uplo, d.n, a, d.n, w, - static_cast(work), d.lwork, info, params, d.batch))); - break; - } - case SolverType::C64: { - gpuComplex* a = static_cast(buffers[1]); - float* w = static_cast(buffers[2]); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnCheevjBatched( - handle.get(), jobz, d.uplo, d.n, a, d.n, w, - static_cast(work), d.lwork, info, params, d.batch))); - break; - } - case SolverType::C128: { - gpuDoubleComplex* a = static_cast(buffers[1]); - double* w = static_cast(buffers[2]); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS( - gpusolverDnZheevjBatched(handle.get(), jobz, d.uplo, d.n, a, d.n, w, - static_cast(work), - d.lwork, info, params, d.batch))); - break; - } - } - } - return absl::OkStatus(); -} - -void Syevj(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status) { - auto s = Syevj_(stream, buffers, opaque, opaque_len); - if (!s.ok()) { - XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(), - s.message().length()); - } -} - -// Singular value decomposition using 
QR algorithm: gesvd - -static absl::Status Gesvd_(gpuStream_t stream, void** buffers, - const char* opaque, size_t opaque_len) { - auto s = UnpackDescriptor(opaque, opaque_len); - JAX_RETURN_IF_ERROR(s.status()); - const GesvdDescriptor& d = **s; - auto h = SolverHandlePool::Borrow(stream); - JAX_RETURN_IF_ERROR(h.status()); - auto& handle = *h; - if (buffers[1] != buffers[0]) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuMemcpyAsync( - buffers[1], buffers[0], - SizeOfSolverType(d.type) * static_cast(d.batch) * - static_cast(d.m) * static_cast(d.n), - gpuMemcpyDeviceToDevice, stream))); - } - int* info = static_cast(buffers[5]); - void* work = buffers[6]; - int64_t k = d.jobu == 'A' ? d.m : d.n; - switch (d.type) { - case SolverType::F32: { - float* a = static_cast(buffers[1]); - float* s = static_cast(buffers[2]); - float* u = static_cast(buffers[3]); - float* vt = static_cast(buffers[4]); - for (int i = 0; i < d.batch; ++i) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnSgesvd( - handle.get(), d.jobu, d.jobvt, d.m, d.n, a, d.m, s, u, d.m, vt, d.n, - static_cast(work), d.lwork, - /*rwork=*/nullptr, info))); - a += d.m * d.n; - s += std::min(d.m, d.n); - u += d.m * k; - vt += d.n * d.n; - ++info; - } - break; - } - case SolverType::F64: { - double* a = static_cast(buffers[1]); - double* s = static_cast(buffers[2]); - double* u = static_cast(buffers[3]); - double* vt = static_cast(buffers[4]); - for (int i = 0; i < d.batch; ++i) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnDgesvd( - handle.get(), d.jobu, d.jobvt, d.m, d.n, a, d.m, s, u, d.m, vt, d.n, - static_cast(work), d.lwork, - /*rwork=*/nullptr, info))); - a += d.m * d.n; - s += std::min(d.m, d.n); - u += d.m * k; - vt += d.n * d.n; - ++info; - } - break; - } - case SolverType::C64: { - gpuComplex* a = static_cast(buffers[1]); - float* s = static_cast(buffers[2]); - gpuComplex* u = static_cast(buffers[3]); - gpuComplex* vt = static_cast(buffers[4]); - for (int i = 0; i < d.batch; ++i) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnCgesvd( - handle.get(), d.jobu, d.jobvt, d.m, d.n, a, d.m, s, u, d.m, vt, d.n, - static_cast(work), d.lwork, /*rwork=*/nullptr, info))); - a += d.m * d.n; - s += std::min(d.m, d.n); - u += d.m * k; - vt += d.n * d.n; - ++info; - } - break; - } - case SolverType::C128: { - gpuDoubleComplex* a = static_cast(buffers[1]); - double* s = static_cast(buffers[2]); - gpuDoubleComplex* u = static_cast(buffers[3]); - gpuDoubleComplex* vt = static_cast(buffers[4]); - for (int i = 0; i < d.batch; ++i) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnZgesvd( - handle.get(), d.jobu, d.jobvt, d.m, d.n, a, d.m, s, u, d.m, vt, d.n, - static_cast(work), d.lwork, - /*rwork=*/nullptr, info))); - a += d.m * d.n; - s += std::min(d.m, d.n); - u += d.m * k; - vt += d.n * d.n; - ++info; - } - break; - } - } - return absl::OkStatus(); -} - -void Gesvd(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status) { - auto s = Gesvd_(stream, buffers, opaque, opaque_len); - if (!s.ok()) { - XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(), - s.message().length()); - } -} - -#ifdef JAX_GPU_CUDA - -// Singular value decomposition using Jacobi algorithm: gesvdj - -static absl::Status Gesvdj_(gpuStream_t stream, void** buffers, - const char* opaque, size_t opaque_len) { - auto s = UnpackDescriptor(opaque, opaque_len); - JAX_RETURN_IF_ERROR(s.status()); - const GesvdjDescriptor& d = **s; - auto h = SolverHandlePool::Borrow(stream); - JAX_RETURN_IF_ERROR(h.status()); - auto& 
handle = *h; - if (buffers[1] != buffers[0]) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuMemcpyAsync( - buffers[1], buffers[0], - SizeOfSolverType(d.type) * static_cast(d.batch) * - static_cast(d.m) * static_cast(d.n), - gpuMemcpyDeviceToDevice, stream))); - } - int* info = static_cast(buffers[5]); - void* work = buffers[6]; - gesvdjInfo_t params; - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverDnCreateGesvdjInfo(¶ms))); - std::unique_ptr params_cleanup( - params, [](gesvdjInfo* p) { cusolverDnDestroyGesvdjInfo(p); }); - if (d.batch <= 1 || d.m > 32 || d.n > 32 || d.econ) { - int k = std::min(d.m, d.n); - switch (d.type) { - case SolverType::F32: { - float* a = static_cast(buffers[1]); - float* s = static_cast(buffers[2]); - float* u = static_cast(buffers[3]); - float* v = static_cast(buffers[4]); - for (int i = 0; i < d.batch; ++i) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverDnSgesvdj( - handle.get(), d.jobz, d.econ, d.m, d.n, a, d.m, s, u, d.m, v, d.n, - static_cast(work), d.lwork, info, params))); - a += d.m * d.n; - s += k; - u += d.m * (d.econ ? k : d.m); - v += (d.econ ? k : d.n) * d.n; - ++info; - } - break; - } - case SolverType::F64: { - double* a = static_cast(buffers[1]); - double* s = static_cast(buffers[2]); - double* u = static_cast(buffers[3]); - double* v = static_cast(buffers[4]); - for (int i = 0; i < d.batch; ++i) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverDnDgesvdj( - handle.get(), d.jobz, d.econ, d.m, d.n, a, d.m, s, u, d.m, v, d.n, - static_cast(work), d.lwork, info, params))); - a += d.m * d.n; - s += k; - u += d.m * (d.econ ? k : d.m); - v += (d.econ ? k : d.n) * d.n; - ++info; - } - break; - } - case SolverType::C64: { - gpuComplex* a = static_cast(buffers[1]); - float* s = static_cast(buffers[2]); - gpuComplex* u = static_cast(buffers[3]); - gpuComplex* v = static_cast(buffers[4]); - for (int i = 0; i < d.batch; ++i) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverDnCgesvdj( - handle.get(), d.jobz, d.econ, d.m, d.n, a, d.m, s, u, d.m, v, d.n, - static_cast(work), d.lwork, info, params))); - a += d.m * d.n; - s += k; - u += d.m * (d.econ ? k : d.m); - v += (d.econ ? k : d.n) * d.n; - ++info; - } - break; - } - case SolverType::C128: { - gpuDoubleComplex* a = static_cast(buffers[1]); - double* s = static_cast(buffers[2]); - gpuDoubleComplex* u = static_cast(buffers[3]); - gpuDoubleComplex* v = static_cast(buffers[4]); - for (int i = 0; i < d.batch; ++i) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverDnZgesvdj( - handle.get(), d.jobz, d.econ, d.m, d.n, a, d.m, s, u, d.m, v, d.n, - static_cast(work), d.lwork, info, params))); - a += d.m * d.n; - s += k; - u += d.m * (d.econ ? k : d.m); - v += (d.econ ? 
k : d.n) * d.n; - ++info; - } - break; - } - } - } else { - switch (d.type) { - case SolverType::F32: { - float* a = static_cast(buffers[1]); - float* s = static_cast(buffers[2]); - float* u = static_cast(buffers[3]); - float* v = static_cast(buffers[4]); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverDnSgesvdjBatched( - handle.get(), d.jobz, d.m, d.n, a, d.m, s, u, d.m, v, d.n, - static_cast(work), d.lwork, info, params, d.batch))); - break; - } - case SolverType::F64: { - double* a = static_cast(buffers[1]); - double* s = static_cast(buffers[2]); - double* u = static_cast(buffers[3]); - double* v = static_cast(buffers[4]); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverDnDgesvdjBatched( - handle.get(), d.jobz, d.m, d.n, a, d.m, s, u, d.m, v, d.n, - static_cast(work), d.lwork, info, params, d.batch))); - break; - } - case SolverType::C64: { - gpuComplex* a = static_cast(buffers[1]); - float* s = static_cast(buffers[2]); - gpuComplex* u = static_cast(buffers[3]); - gpuComplex* v = static_cast(buffers[4]); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverDnCgesvdjBatched( - handle.get(), d.jobz, d.m, d.n, a, d.m, s, u, d.m, v, d.n, - static_cast(work), d.lwork, info, params, d.batch))); - break; - } - case SolverType::C128: { - gpuDoubleComplex* a = static_cast(buffers[1]); - double* s = static_cast(buffers[2]); - gpuDoubleComplex* u = static_cast(buffers[3]); - gpuDoubleComplex* v = static_cast(buffers[4]); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverDnZgesvdjBatched( - handle.get(), d.jobz, d.m, d.n, a, d.m, s, u, d.m, v, d.n, - static_cast(work), d.lwork, info, params, - d.batch))); - break; - } - } - } - return absl::OkStatus(); -} - -void Gesvdj(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status) { - auto s = Gesvdj_(stream, buffers, opaque, opaque_len); - if (!s.ok()) { - XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(), - s.message().length()); - } -} - -#endif // JAX_GPU_CUDA - -// sytrd/hetrd: symmetric (Hermitian) tridiagonal reduction - -static absl::Status Sytrd_(gpuStream_t stream, void** buffers, - const char* opaque, size_t opaque_len) { - auto s = UnpackDescriptor(opaque, opaque_len); - JAX_RETURN_IF_ERROR(s.status()); - const SytrdDescriptor& d = **s; - auto h = SolverHandlePool::Borrow(stream); - JAX_RETURN_IF_ERROR(h.status()); - auto& handle = *h; - if (buffers[1] != buffers[0]) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuMemcpyAsync( - buffers[1], buffers[0], - SizeOfSolverType(d.type) * static_cast(d.batch) * - static_cast(d.n) * static_cast(d.lda), - gpuMemcpyDeviceToDevice, stream))); - } - - int* info = static_cast(buffers[5]); - void* workspace = buffers[6]; - switch (d.type) { - case SolverType::F32: { - float* a = static_cast(buffers[1]); - float* d_out = static_cast(buffers[2]); - float* e_out = static_cast(buffers[3]); - float* tau = static_cast(buffers[4]); - for (int i = 0; i < d.batch; ++i) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnSsytrd( - handle.get(), d.uplo, d.n, a, d.lda, d_out, e_out, tau, - static_cast(workspace), d.lwork, info))); - a += d.lda * d.n; - d_out += d.n; - e_out += d.n - 1; - tau += d.n - 1; - ++info; - } - break; - } - case SolverType::F64: { - double* a = static_cast(buffers[1]); - double* d_out = static_cast(buffers[2]); - double* e_out = static_cast(buffers[3]); - double* tau = static_cast(buffers[4]); - for (int i = 0; i < d.batch; ++i) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnDsytrd( - handle.get(), d.uplo, d.n, a, d.lda, d_out, e_out, tau, - 
static_cast(workspace), d.lwork, info))); - a += d.lda * d.n; - d_out += d.n; - e_out += d.n - 1; - tau += d.n - 1; - ++info; - } - break; - } - case SolverType::C64: { - gpuComplex* a = static_cast(buffers[1]); - float* d_out = static_cast(buffers[2]); - float* e_out = static_cast(buffers[3]); - gpuComplex* tau = static_cast(buffers[4]); - for (int i = 0; i < d.batch; ++i) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnChetrd( - handle.get(), d.uplo, d.n, a, d.lda, d_out, e_out, tau, - static_cast(workspace), d.lwork, info))); - a += d.lda * d.n; - d_out += d.n; - e_out += d.n - 1; - tau += d.n - 1; - ++info; - } - break; - } - case SolverType::C128: { - gpuDoubleComplex* a = static_cast(buffers[1]); - double* d_out = static_cast(buffers[2]); - double* e_out = static_cast(buffers[3]); - gpuDoubleComplex* tau = static_cast(buffers[4]); - for (int i = 0; i < d.batch; ++i) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnZhetrd( - handle.get(), d.uplo, d.n, a, d.lda, d_out, e_out, tau, - static_cast(workspace), d.lwork, info))); - a += d.lda * d.n; - d_out += d.n; - e_out += d.n - 1; - tau += d.n - 1; - ++info; - } - break; - } - } - return absl::OkStatus(); -} - -void Sytrd(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status) { - auto s = Sytrd_(stream, buffers, opaque, opaque_len); - if (!s.ok()) { - XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(), - s.message().length()); - } -} - -} // namespace JAX_GPU_NAMESPACE -} // namespace jax diff --git a/jaxlib/gpu/solver_kernels.h b/jaxlib/gpu/solver_kernels.h deleted file mode 100644 index 51082f2fe812..000000000000 --- a/jaxlib/gpu/solver_kernels.h +++ /dev/null @@ -1,148 +0,0 @@ -/* Copyright 2019 The JAX Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef JAXLIB_CUSOLVER_KERNELS_H_ -#define JAXLIB_CUSOLVER_KERNELS_H_ - -#include - -#include "jaxlib/gpu/vendor.h" -#include "xla/service/custom_call_status.h" - -namespace jax { - -namespace JAX_GPU_NAMESPACE { - -// Set of types known to Cusolver. 
-enum class SolverType { - F32, - F64, - C64, - C128, -}; - -// getrf: LU decomposition - -struct GetrfDescriptor { - SolverType type; - int batch, m, n, lwork; -}; - -void Getrf(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status); - -// geqrf: QR decomposition - -struct GeqrfDescriptor { - SolverType type; - int batch, m, n, lwork; -}; - -void Geqrf(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status); - -#ifdef JAX_GPU_CUDA - -// csrlsvpr: Linear system solve via Sparse QR - -struct CsrlsvqrDescriptor { - SolverType type; - int n, nnz, reorder; - double tol; -}; - -void Csrlsvqr(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status); - -#endif // JAX_GPU_CUDA - -// orgqr/ungqr: apply elementary Householder transformations - -struct OrgqrDescriptor { - SolverType type; - int batch, m, n, k, lwork; -}; - -void Orgqr(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status); - -// Symmetric (Hermitian) eigendecomposition, QR algorithm: syevd/heevd - -struct SyevdDescriptor { - SolverType type; - gpusolverFillMode_t uplo; - int batch, n; // batch may be -1 in which case it is passed as operand. - int lwork; -}; - -void Syevd(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status); - -// Symmetric (Hermitian) eigendecomposition, Jacobi algorithm: syevj/heevj -// Supports batches of matrices up to size 32. - -struct SyevjDescriptor { - SolverType type; - gpusolverFillMode_t uplo; - int batch, n; - int lwork; -}; - -void Syevj(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status); - -// Singular value decomposition using QR algorithm: gesvd - -struct GesvdDescriptor { - SolverType type; - int batch, m, n; - int lwork; - signed char jobu, jobvt; -}; - -void Gesvd(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status); - -#ifdef JAX_GPU_CUDA - -// Singular value decomposition using Jacobi algorithm: gesvdj - -struct GesvdjDescriptor { - SolverType type; - int batch, m, n; - int lwork; - gpusolverEigMode_t jobz; - int econ; -}; - -void Gesvdj(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status); -#endif // JAX_GPU_CUDA - -// sytrd/hetrd: Reduction of a symmetric (Hermitian) matrix to tridiagonal form. -struct SytrdDescriptor { - SolverType type; - gpusolverFillMode_t uplo; - int batch, n, lda, lwork; -}; - -void Sytrd(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status); - - -} // namespace JAX_GPU_NAMESPACE -} // namespace jax - -#endif // JAXLIB_CUSOLVER_KERNELS_H_ diff --git a/jaxlib/gpu/sparse.cc b/jaxlib/gpu/sparse.cc index 429c8018dc7a..21f567e79f92 100644 --- a/jaxlib/gpu/sparse.cc +++ b/jaxlib/gpu/sparse.cc @@ -13,19 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include -#include +#include #include #include -#include -#include "nanobind/nanobind.h" -#include "nanobind/stl/pair.h" -#include "absl/base/casts.h" -#include "absl/base/thread_annotations.h" #include "absl/container/flat_hash_map.h" #include "absl/strings/str_format.h" -#include "absl/synchronization/mutex.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/pair.h" // IWYU pragma: keep #include "jaxlib/absl_status_casters.h" #include "jaxlib/gpu/gpu_kernel_helpers.h" #include "jaxlib/gpu/sparse_kernels.h" @@ -147,45 +142,6 @@ std::pair BuildCsrToDenseDescriptor(const dtype& data_dtype, return {buffer_size, PackDescriptor(d)}; } -absl::Status CsrToDense_(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len) { - auto s = UnpackDescriptor(opaque, opaque_len); - JAX_RETURN_IF_ERROR(s.status()); - const SparseMatDescriptor& d = **s; - auto h = SparseHandlePool::Borrow(stream); - JAX_RETURN_IF_ERROR(h.status()); - auto& handle = *h; - - gpusparseSpMatDescr_t mat_a = 0; - gpusparseDnMatDescr_t mat_b = 0; - JAX_RETURN_IF_ERROR(JAX_AS_STATUS( - gpusparseCreateCsr(&mat_a, d.rows, d.cols, d.nnz, - /*csrRowOffsets=*/buffers[2], - /*csrColInd=*/buffers[1], - /*csrValues=*/buffers[0], d.index_type, d.index_type, - GPUSPARSE_INDEX_BASE_ZERO, d.value_type))); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseCreateDnMat( - &mat_b, d.rows, d.cols, - /*ld=*/d.cols, buffers[3], d.value_type, GPUSPARSE_ORDER_ROW))); - - JAX_RETURN_IF_ERROR(JAX_AS_STATUS( - gpusparseSparseToDense(handle.get(), mat_a, mat_b, - GPUSPARSE_SPARSETODENSE_ALG_DEFAULT, buffers[4]))); - - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseDestroySpMat(mat_a))); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseDestroyDnMat(mat_b))); - return absl::OkStatus(); -} - -void CsrToDense(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status) { - auto s = CsrToDense_(stream, buffers, opaque, opaque_len); - if (!s.ok()) { - XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(), - s.message().length()); - } -} - // CsrFromDense: Convert dense matrix to CSR matrix // Returns the descriptor for a CsrFromDense operation. 
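The deleted CsrToDense_ above drove the cuSPARSE/hipSPARSE descriptor API to expand a CSR matrix into a dense row-major buffer; the surviving FFI handler (CsrToDenseFfi) keeps that behavior. As a plain-CPU reference for the semantics only, not the device call sequence:

#include <cstdint>
#include <vector>

// Expands a CSR matrix (row_offsets/col_indices/values) into a dense
// row-major rows x cols matrix. Indices are zero-based, as in the patch.
std::vector<float> CsrToDenseReference(const std::vector<int32_t>& row_offsets,
                                       const std::vector<int32_t>& col_indices,
                                       const std::vector<float>& values,
                                       int64_t rows, int64_t cols) {
  std::vector<float> dense(rows * cols, 0.0f);
  for (int64_t r = 0; r < rows; ++r) {
    for (int32_t k = row_offsets[r]; k < row_offsets[r + 1]; ++k) {
      dense[r * cols + col_indices[k]] = values[k];
    }
  }
  return dense;
}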
@@ -222,46 +178,6 @@ std::pair BuildCsrFromDenseDescriptor( return {buffer_size, PackDescriptor(d)}; } -absl::Status CsrFromDense_(gpuStream_t stream, void** buffers, - const char* opaque, size_t opaque_len) { - auto s = UnpackDescriptor(opaque, opaque_len); - JAX_RETURN_IF_ERROR(s.status()); - const SparseMatDescriptor& d = **s; - auto h = SparseHandlePool::Borrow(stream); - JAX_RETURN_IF_ERROR(h.status()); - auto& handle = *h; - - gpusparseDnMatDescr_t mat_a = 0; - gpusparseSpMatDescr_t mat_b = 0; - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseCreateDnMat( - &mat_a, d.rows, d.cols, - /*ld=*/d.cols, buffers[0], d.value_type, GPUSPARSE_ORDER_ROW))); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS( - gpusparseCreateCsr(&mat_b, d.rows, d.cols, d.nnz, - /*csrRowOffsets=*/buffers[3], - /*csrColInd=*/buffers[2], - /*csrValues=*/buffers[1], d.index_type, d.index_type, - GPUSPARSE_INDEX_BASE_ZERO, d.value_type))); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseDenseToSparse_analysis( - handle.get(), mat_a, mat_b, GPUSPARSE_DENSETOSPARSE_ALG_DEFAULT, - buffers[4]))); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseDenseToSparse_convert( - handle.get(), mat_a, mat_b, GPUSPARSE_DENSETOSPARSE_ALG_DEFAULT, - buffers[4]))); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseDestroyDnMat(mat_a))); - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseDestroySpMat(mat_b))); - return absl::OkStatus(); -} - -void CsrFromDense(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status) { - auto s = CsrFromDense_(stream, buffers, opaque, opaque_len); - if (!s.ok()) { - XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(), - s.message().length()); - } -} - // CsrMatvec: Product of CSR matrix and dense vector. // Returns the descriptor for a CsrMatvec operation. 
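The Build*Descriptor functions kept in this file still communicate problem sizes to the kernels through an opaque byte string. A minimal sketch of that round trip, assuming the PackDescriptor/UnpackDescriptor helpers already used here and in sparse_kernels.cc; the struct and function names below are illustrative only:

// Plain-old-data descriptor describing one operation.
struct ExampleDescriptor {
  int rows, cols, nnz;
};

// Binding side: serialize the descriptor. The bytes are forwarded by the
// lowering rules as the custom call's opaque/backend_config payload.
nb::bytes BuildExampleDescriptor(int rows, int cols, int nnz) {
  return PackDescriptor(ExampleDescriptor{rows, cols, nnz});
}

// Kernel side: recover the descriptor from the opaque payload at run time.
absl::Status ExampleKernel(gpuStream_t stream, void** buffers,
                           const char* opaque, size_t opaque_len) {
  auto s = UnpackDescriptor<ExampleDescriptor>(opaque, opaque_len);
  JAX_RETURN_IF_ERROR(s.status());
  const ExampleDescriptor& d = **s;
  if (d.rows < 0 || d.cols < 0 || d.nnz < 0) {
    return absl::InvalidArgumentError("ExampleKernel: negative dimensions");
  }
  // ... enqueue work on `stream` using `buffers` and the unpacked sizes ...
  return absl::OkStatus();
}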
@@ -554,44 +470,9 @@ std::pair BuildCooMatmatDescriptor( #endif // if JAX_GPU_HAVE_SPARSE -nb::bytes BuildGtsv2Descriptor(int b, int m, int n, int ldb) { - return PackDescriptor(Gtsv2Descriptor{b, m, n, ldb}); -} - -template -size_t Gtsv2BufferSize(F f, int m, int n, int ldb) { - auto h = SparseHandlePool::Borrow(/*stream=*/nullptr); - JAX_THROW_IF_ERROR(h.status()); - auto& handle = *h; - size_t size; - JAX_THROW_IF_ERROR( - JAX_AS_STATUS(f(handle.get(), m, n, /*dl=*/nullptr, /*d=*/nullptr, - /*du=*/nullptr, /*B=*/nullptr, ldb, &size))); - return size; -} - -size_t Gtsv2BufferSizeF32(int m, int n, int ldb) { - return Gtsv2BufferSize(gpusparseSgtsv2_bufferSizeExt, m, n, ldb); -} - -size_t Gtsv2BufferSizeF64(int m, int n, int ldb) { - return Gtsv2BufferSize(gpusparseDgtsv2_bufferSizeExt, m, n, ldb); -} - nb::dict Registrations() { nb::dict dict; #if JAX_GPU_HAVE_SPARSE - dict[JAX_GPU_PREFIX "sparse_csr_todense"] = EncapsulateFunction(CsrToDense); - dict[JAX_GPU_PREFIX "sparse_csr_fromdense"] = - EncapsulateFunction(CsrFromDense); - dict[JAX_GPU_PREFIX "sparse_csr_matvec"] = EncapsulateFunction(CsrMatvec); - dict[JAX_GPU_PREFIX "sparse_csr_matmat"] = EncapsulateFunction(CsrMatmat); - dict[JAX_GPU_PREFIX "sparse_coo_todense"] = EncapsulateFunction(CooToDense); - dict[JAX_GPU_PREFIX "sparse_coo_fromdense"] = - EncapsulateFunction(CooFromDense); - dict[JAX_GPU_PREFIX "sparse_coo_matvec"] = EncapsulateFunction(CooMatvec); - dict[JAX_GPU_PREFIX "sparse_coo_matmat"] = EncapsulateFunction(CooMatmat); - dict[JAX_GPU_PREFIX "sparse_csr_todense_ffi"] = EncapsulateFfiHandler(CsrToDenseFfi); dict[JAX_GPU_PREFIX "sparse_csr_fromdense_ffi"] = @@ -609,12 +490,8 @@ nb::dict Registrations() { dict[JAX_GPU_PREFIX "sparse_coo_matmat_ffi"] = EncapsulateFfiHandler(CooMatmatFfi); #endif - dict[JAX_GPU_PREFIX "sparse_gtsv2_f32"] = EncapsulateFunction(gtsv2_f32); - dict[JAX_GPU_PREFIX "sparse_gtsv2_f64"] = EncapsulateFunction(gtsv2_f64); - dict[JAX_GPU_PREFIX "sparse_gtsv2_f32_ffi"] = - EncapsulateFfiHandler(gtsv2_f32_ffi); - dict[JAX_GPU_PREFIX "sparse_gtsv2_f64_ffi"] = - EncapsulateFfiHandler(gtsv2_f64_ffi); + dict[JAX_GPU_PREFIX "sparse_gtsv2_ffi"] = EncapsulateFfiHandler(kGtsv2); + // TODO(tomhennigan): Add support for gtsv2 complex 32/64. return dict; } @@ -633,9 +510,6 @@ NB_MODULE(_sparse, m) { m.def("build_coo_matvec_descriptor", &BuildCooMatvecDescriptor); m.def("build_coo_matmat_descriptor", &BuildCooMatmatDescriptor); #endif - m.def("gtsv2_f32_buffer_size", &Gtsv2BufferSizeF32); - m.def("gtsv2_f64_buffer_size", &Gtsv2BufferSizeF64); - m.def("build_gtsv2_descriptor", &BuildGtsv2Descriptor); } } // namespace diff --git a/jaxlib/gpu/sparse_kernels.cc b/jaxlib/gpu/sparse_kernels.cc index 5b620a05236d..139fbc73f8ce 100644 --- a/jaxlib/gpu/sparse_kernels.cc +++ b/jaxlib/gpu/sparse_kernels.cc @@ -15,22 +15,28 @@ limitations under the License. #include "jaxlib/gpu/sparse_kernels.h" -#include +#include #include -#include -#include -#include +#include +#include #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" #include "absl/synchronization/mutex.h" +#include "jaxlib/ffi_helpers.h" #include "jaxlib/gpu/ffi_wrapper.h" #include "jaxlib/gpu/gpu_kernel_helpers.h" +#include "jaxlib/gpu/handle_pool.h" #include "jaxlib/gpu/vendor.h" -#include "jaxlib/handle_pool.h" #include "jaxlib/kernel_helpers.h" -#include "xla/service/custom_call_status.h" +#include "xla/ffi/api/ffi.h" + +#define JAX_FFI_RETURN_IF_GPU_ERROR(...) 
\ + FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(__VA_ARGS__)) + +namespace ffi = ::xla::ffi; namespace jax { @@ -182,15 +188,6 @@ static absl::Status CsrToDense_(gpuStream_t stream, void** buffers, JAX_GPU_REGISTER_WRAPPED_LEGACY_KERNEL(CsrToDenseFfi, CsrToDense_); -void CsrToDense(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status) { - auto s = CsrToDense_(stream, buffers, opaque, opaque_len); - if (!s.ok()) { - XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(), - s.message().length()); - } -} - // CsrFromDense: Convert dense matrix to CSR matrix static absl::Status CsrFromDense_(gpuStream_t stream, void** buffers, @@ -226,15 +223,6 @@ static absl::Status CsrFromDense_(gpuStream_t stream, void** buffers, JAX_GPU_REGISTER_WRAPPED_LEGACY_KERNEL(CsrFromDenseFfi, CsrFromDense_); -void CsrFromDense(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status) { - auto s = CsrFromDense_(stream, buffers, opaque, opaque_len); - if (!s.ok()) { - XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(), - s.message().length()); - } -} - // CsrMatvec: Product of CSR matrix and dense vector. static absl::Status CsrMatvec_(gpuStream_t stream, void** buffers, @@ -285,15 +273,6 @@ static absl::Status CsrMatvec_(gpuStream_t stream, void** buffers, JAX_GPU_REGISTER_WRAPPED_LEGACY_KERNEL(CsrMatvecFfi, CsrMatvec_); -void CsrMatvec(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status) { - auto s = CsrMatvec_(stream, buffers, opaque, opaque_len); - if (!s.ok()) { - XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(), - s.message().length()); - } -} - // CsrMatmat: Product of CSR matrix and dense matrix. 
static absl::Status CsrMatmat_(gpuStream_t stream, void** buffers, @@ -345,15 +324,6 @@ static absl::Status CsrMatmat_(gpuStream_t stream, void** buffers, JAX_GPU_REGISTER_WRAPPED_LEGACY_KERNEL(CsrMatmatFfi, CsrMatmat_); -void CsrMatmat(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status) { - auto s = CsrMatmat_(stream, buffers, opaque, opaque_len); - if (!s.ok()) { - XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(), - s.message().length()); - } -} - // CooToDense: Convert COO matrix to dense matrix static absl::Status CooToDense_(gpuStream_t stream, void** buffers, @@ -388,15 +358,6 @@ static absl::Status CooToDense_(gpuStream_t stream, void** buffers, JAX_GPU_REGISTER_WRAPPED_LEGACY_KERNEL(CooToDenseFfi, CooToDense_); -void CooToDense(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status) { - auto s = CooToDense_(stream, buffers, opaque, opaque_len); - if (!s.ok()) { - XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(), - s.message().length()); - } -} - // CooFromDense: Convert dense matrix to COO matrix static absl::Status CooFromDense_(gpuStream_t stream, void** buffers, @@ -432,15 +393,6 @@ static absl::Status CooFromDense_(gpuStream_t stream, void** buffers, JAX_GPU_REGISTER_WRAPPED_LEGACY_KERNEL(CooFromDenseFfi, CooFromDense_); -void CooFromDense(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status) { - auto s = CooFromDense_(stream, buffers, opaque, opaque_len); - if (!s.ok()) { - XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(), - s.message().length()); - } -} - // CooMatvec: Product of COO matrix and dense vector. static absl::Status CooMatvec_(gpuStream_t stream, void** buffers, @@ -490,15 +442,6 @@ static absl::Status CooMatvec_(gpuStream_t stream, void** buffers, JAX_GPU_REGISTER_WRAPPED_LEGACY_KERNEL(CooMatvecFfi, CooMatvec_); -void CooMatvec(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status) { - auto s = CooMatvec_(stream, buffers, opaque, opaque_len); - if (!s.ok()) { - XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(), - s.message().length()); - } -} - // CooMatmat: Product of COO matrix and dense matrix. 
static absl::Status CooMatmat_(gpuStream_t stream, void** buffers, @@ -557,91 +500,164 @@ static absl::Status CooMatmat_(gpuStream_t stream, void** buffers, } JAX_GPU_REGISTER_WRAPPED_LEGACY_KERNEL(CooMatmatFfi, CooMatmat_); +#endif // if JAX_GPU_HAVE_SPARSE -void CooMatmat(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status) { - auto s = CooMatmat_(stream, buffers, opaque, opaque_len); - if (!s.ok()) { - XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(), - s.message().length()); +template +ffi::Error Gtsv2Impl(BufferSizeF getBufferSize, KernelF kernel, int64_t batch, + int64_t rows, int64_t cols, gpuStream_t stream, + ffi::ScratchAllocator& scratch, ffi::AnyBuffer dl, + ffi::AnyBuffer d, ffi::AnyBuffer du, ffi::AnyBuffer b, + ffi::Result out) { + FFI_ASSIGN_OR_RETURN(auto m, MaybeCastNoOverflow(rows)); + FFI_ASSIGN_OR_RETURN(auto n, MaybeCastNoOverflow(cols)); + + FFI_ASSIGN_OR_RETURN(auto handle, SparseHandlePool::Borrow(stream)); + size_t buffer_size_in_bytes; + JAX_FFI_RETURN_IF_GPU_ERROR(getBufferSize(handle.get(), m, n, nullptr, + nullptr, nullptr, nullptr, m, + &buffer_size_in_bytes)); + auto maybe_workspace = scratch.Allocate(buffer_size_in_bytes); + if (!maybe_workspace.has_value()) { + return ffi::Error::Internal("Unable to allocate workspace for gtsv2"); + } + void* workspace = maybe_workspace.value(); + + auto dl_data = static_cast(dl.untyped_data()); + auto d_data = static_cast(d.untyped_data()); + auto du_data = static_cast(du.untyped_data()); + auto b_data = static_cast(b.untyped_data()); + auto out_data = static_cast(out->untyped_data()); + if (b_data != out_data) { + JAX_FFI_RETURN_IF_GPU_ERROR(gpuMemcpyAsync( + out_data, b_data, b.size_bytes(), gpuMemcpyDeviceToDevice, stream)); } -} -#endif // if JAX_GPU_HAVE_SPARSE -template -static absl::Status gtsv2(F computeGtsv2, gpuStream_t stream, void** buffers, - const char* opaque, std::size_t opaque_len) { - auto h = SparseHandlePool::Borrow(stream); - JAX_RETURN_IF_ERROR(h.status()); - auto& handle = *h; + for (int64_t i = 0; i < batch; ++i) { + JAX_FFI_RETURN_IF_GPU_ERROR(kernel(handle.get(), m, n, dl_data, d_data, + du_data, out_data, m, workspace)); + dl_data += m; + d_data += m; + du_data += m; + out_data += m * n; + } + return ffi::Error::Success(); +} - auto s = UnpackDescriptor(opaque, opaque_len); - JAX_RETURN_IF_ERROR(s.status()); - const Gtsv2Descriptor& descriptor = **s; - int batch = descriptor.batch; - int m = descriptor.m; - int n = descriptor.n; - int ldb = descriptor.ldb; - - T* dl = static_cast(buffers[0]); - T* d = static_cast(buffers[1]); - T* du = static_cast(buffers[2]); - T* B = static_cast(buffers[3]); - T* X = static_cast(buffers[4]); - void* buffer = static_cast(buffers[5]); - - // The solution X is written in place to B. We need to therefore copy the - // contents of B into the output buffer X and pass that into the kernel as B. - // Once copy insertion is supported for custom call aliasing, we could alias B - // with X and avoid the copy, the code below is written defensively assuming B - // and X might alias, but today we know they will not. - // TODO(b/182906199): Update the comment here once copy insertion is WAI. 
- if (X != B) { - size_t B_bytes = ldb * n * sizeof(T) * batch; - JAX_RETURN_IF_ERROR(JAX_AS_STATUS( - gpuMemcpyAsync(X, B, B_bytes, gpuMemcpyDeviceToDevice, stream))); +template +ffi::Error Gtsv2BatchedImpl(BufferSizeF getBufferSize, KernelF kernel, + int64_t batch, int64_t rows, gpuStream_t stream, + ffi::ScratchAllocator& scratch, ffi::AnyBuffer dl, + ffi::AnyBuffer d, ffi::AnyBuffer du, + ffi::AnyBuffer b, ffi::Result out) { + FFI_ASSIGN_OR_RETURN(auto batch_count, MaybeCastNoOverflow(batch)); + FFI_ASSIGN_OR_RETURN(auto m, MaybeCastNoOverflow(rows)); + + FFI_ASSIGN_OR_RETURN(auto handle, SparseHandlePool::Borrow(stream)); + size_t buffer_size_in_bytes; + JAX_FFI_RETURN_IF_GPU_ERROR(getBufferSize(handle.get(), m, nullptr, nullptr, + nullptr, nullptr, batch_count, m, + &buffer_size_in_bytes)); + auto maybe_workspace = scratch.Allocate(buffer_size_in_bytes); + if (!maybe_workspace.has_value()) { + return ffi::Error::Internal("Unable to allocate workspace for gtsv2"); } - for (int i = 0; i < batch; ++i) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS( - computeGtsv2(handle.get(), m, n, dl, d, du, X, ldb, buffer))); - dl += m; - d += m; - du += m; - X += m * n; + void* workspace = maybe_workspace.value(); + + auto dl_data = static_cast(dl.untyped_data()); + auto d_data = static_cast(d.untyped_data()); + auto du_data = static_cast(du.untyped_data()); + auto b_data = static_cast(b.untyped_data()); + auto out_data = static_cast(out->untyped_data()); + if (b_data != out_data) { + JAX_FFI_RETURN_IF_GPU_ERROR(gpuMemcpyAsync( + out_data, b_data, b.size_bytes(), gpuMemcpyDeviceToDevice, stream)); } - return absl::OkStatus(); + + JAX_FFI_RETURN_IF_GPU_ERROR(kernel(handle.get(), m, dl_data, d_data, du_data, + out_data, batch_count, m, workspace)); + return ffi::Error::Success(); } -JAX_GPU_REGISTER_WRAPPED_LEGACY_KERNEL( - gtsv2_f32_ffi, [](gpuStream_t stream, void** buffers, const char* opaque, - std::size_t opaque_len) { - return gtsv2(gpusparseSgtsv2, stream, buffers, opaque, opaque_len); - }); - -JAX_GPU_REGISTER_WRAPPED_LEGACY_KERNEL( - gtsv2_f64_ffi, [](gpuStream_t stream, void** buffers, const char* opaque, - std::size_t opaque_len) { - return gtsv2(gpusparseDgtsv2, stream, buffers, opaque, - opaque_len); - }); - -void gtsv2_f32(gpuStream_t stream, void** buffers, const char* opaque, - std::size_t opaque_len, XlaCustomCallStatus* status) { - auto s = gtsv2(gpusparseSgtsv2, stream, buffers, opaque, opaque_len); - if (!s.ok()) { - XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(), - s.message().length()); +ffi::Error Gtsv2(gpuStream_t stream, ffi::ScratchAllocator scratch, + ffi::AnyBuffer dl, ffi::AnyBuffer d, ffi::AnyBuffer du, + ffi::AnyBuffer b, ffi::Result out) { + auto dataType = dl.element_type(); + if (dataType != d.element_type() || dataType != du.element_type() || + dataType != b.element_type() || dataType != out->element_type()) { + return ffi::Error::InvalidArgument( + "The inputs and outputs to gtsv2 must have the same element type"); } -} + FFI_ASSIGN_OR_RETURN((auto [batch, rows, cols]), + SplitBatch2D(b.dimensions())); + FFI_RETURN_IF_ERROR( + CheckShape(out->dimensions(), {batch, rows, cols}, "out", "gtsv2")); + FFI_RETURN_IF_ERROR( + CheckShape(dl.dimensions(), {batch, rows}, "dl", "gtsv2")); + FFI_RETURN_IF_ERROR(CheckShape(d.dimensions(), {batch, rows}, "d", "gtsv2")); + FFI_RETURN_IF_ERROR( + CheckShape(du.dimensions(), {batch, rows}, "du", "gtsv2")); + if (batch > 1 && cols == 1) { + switch (dataType) { + case ffi::F32: + return Gtsv2BatchedImpl( + 
gpusparseSgtsv2StridedBatch_bufferSizeExt, + gpusparseSgtsv2StridedBatch, batch, rows, stream, scratch, dl, d, + du, b, out); + case ffi::F64: + return Gtsv2BatchedImpl( + gpusparseDgtsv2StridedBatch_bufferSizeExt, + gpusparseDgtsv2StridedBatch, batch, rows, stream, scratch, dl, d, + du, b, out); + case ffi::C64: + return Gtsv2BatchedImpl( + gpusparseCgtsv2StridedBatch_bufferSizeExt, + gpusparseCgtsv2StridedBatch, batch, rows, stream, scratch, dl, d, + du, b, out); + case ffi::C128: + return Gtsv2BatchedImpl( + gpusparseZgtsv2StridedBatch_bufferSizeExt, + gpusparseZgtsv2StridedBatch, batch, rows, stream, scratch, dl, d, + du, b, out); + default: + break; + } -void gtsv2_f64(gpuStream_t stream, void** buffers, const char* opaque, - std::size_t opaque_len, XlaCustomCallStatus* status) { - auto s = gtsv2(gpusparseDgtsv2, stream, buffers, opaque, opaque_len); - if (!s.ok()) { - XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(), - s.message().length()); + } else { + switch (dataType) { + case ffi::F32: + return Gtsv2Impl(gpusparseSgtsv2_bufferSizeExt, gpusparseSgtsv2, + batch, rows, cols, stream, scratch, dl, d, du, + b, out); + case ffi::F64: + return Gtsv2Impl(gpusparseDgtsv2_bufferSizeExt, gpusparseDgtsv2, + batch, rows, cols, stream, scratch, dl, d, du, + b, out); + case ffi::C64: + return Gtsv2Impl(gpusparseCgtsv2_bufferSizeExt, + gpusparseCgtsv2, batch, rows, cols, stream, + scratch, dl, d, du, b, out); + case ffi::C128: + return Gtsv2Impl(gpusparseZgtsv2_bufferSizeExt, + gpusparseZgtsv2, batch, rows, cols, + stream, scratch, dl, d, du, b, out); + default: + break; + } } + return ffi::Error::InvalidArgument(absl::StrFormat( + "Unsupported dtype %s in gtsv2", absl::FormatStreamed(dataType))); } +XLA_FFI_DEFINE_HANDLER_SYMBOL(kGtsv2, Gtsv2, + ffi::Ffi::Bind() + .Ctx>() + .Ctx() + .Arg() // dl + .Arg() // d + .Arg() // du + .Arg() // b + .Ret() // out +); + } // namespace JAX_GPU_NAMESPACE } // namespace jax diff --git a/jaxlib/gpu/sparse_kernels.h b/jaxlib/gpu/sparse_kernels.h index 323431812758..75f83752be15 100644 --- a/jaxlib/gpu/sparse_kernels.h +++ b/jaxlib/gpu/sparse_kernels.h @@ -16,17 +16,12 @@ limitations under the License. #ifndef JAXLIB_GPU_SPARSE_KERNELS_H_ #define JAXLIB_GPU_SPARSE_KERNELS_H_ -#include #include -#include -#include -#include #include "absl/status/statusor.h" +#include "jaxlib/gpu/handle_pool.h" #include "jaxlib/gpu/vendor.h" -#include "jaxlib/handle_pool.h" #include "xla/ffi/api/ffi.h" -#include "xla/service/custom_call_status.h" namespace jax { @@ -75,17 +70,6 @@ struct DenseVecDescriptor { }; #if JAX_GPU_HAVE_SPARSE -// CsrToDense: Convert CSR matrix to dense matrix - -void CsrToDense(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status); - -// CsrFromDense: Convert dense matrix to CSR matrix - -void CsrFromDense(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status); - -// CsrMatvec: Product of CSR matrix and dense vector. struct CsrMatvecDescriptor { SparseMatDescriptor A; @@ -93,63 +77,24 @@ struct CsrMatvecDescriptor { gpusparseOperation_t op; }; -void CsrMatvec(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status); - -// CsrMatmat: Product of CSR matrix and dense matrix. 
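Two details of the gtsv2 handler defined above are worth spelling out: the element type selects the S/D/C/Z family of cuSPARSE/hipSPARSE entry points (anything other than F32/F64/C64/C128 falls through to the InvalidArgument error), and the strided-batch variant is only applicable when each system has a single right-hand side. A small restatement of that selection rule; the helper name is illustrative, not part of the library:

// Element type dispatch: F32 -> Sgtsv2*, F64 -> Dgtsv2*, C64 -> Cgtsv2*,
// C128 -> Zgtsv2*, via the vendor-neutral gpusparse* aliases in vendor.h.
//
// gtsv2StridedBatch solves many tridiagonal systems with one RHS each, so it
// is chosen only when there are multiple systems and a single column; every
// other case loops the per-system gtsv2 kernel over the batch, as Gtsv2Impl
// does above.
inline bool UseStridedBatchGtsv2(int64_t batch, int64_t cols) {
  return batch > 1 && cols == 1;
}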
- struct CsrMatmatDescriptor { SparseMatDescriptor A; DenseMatDescriptor B, C; gpusparseOperation_t op_A; }; -void CsrMatmat(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status); - -// CooToDense: Convert COO matrix to dense matrix - -void CooToDense(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status); - -// CooFromDense: Convert dense matrix to COO matrix - -void CooFromDense(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status); - -// CooMatvec: Product of COO matrix and dense vector. - struct CooMatvecDescriptor { SparseMatDescriptor A; DenseVecDescriptor x, y; gpusparseOperation_t op; }; -void CooMatvec(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status); - -// CooMatmat: Product of COO matrix and dense matrix. - struct CooMatmatDescriptor { SparseMatDescriptor A; DenseMatDescriptor B, C; gpusparseOperation_t op_A; }; -void CooMatmat(gpuStream_t stream, void** buffers, const char* opaque, - size_t opaque_len, XlaCustomCallStatus* status); -#endif // JAX_GPU_HAVE_SPARSE - -struct Gtsv2Descriptor { - int batch, m, n, ldb; -}; - -void gtsv2_f32(gpuStream_t stream, void** buffers, const char* opaque, - std::size_t opaque_len, XlaCustomCallStatus* status); - -void gtsv2_f64(gpuStream_t stream, void** buffers, const char* opaque, - std::size_t opaque_len, XlaCustomCallStatus* status); - XLA_FFI_DECLARE_HANDLER_SYMBOL(CsrToDenseFfi); XLA_FFI_DECLARE_HANDLER_SYMBOL(CsrFromDenseFfi); XLA_FFI_DECLARE_HANDLER_SYMBOL(CsrMatvecFfi); @@ -158,8 +103,10 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(CooToDenseFfi); XLA_FFI_DECLARE_HANDLER_SYMBOL(CooFromDenseFfi); XLA_FFI_DECLARE_HANDLER_SYMBOL(CooMatvecFfi); XLA_FFI_DECLARE_HANDLER_SYMBOL(CooMatmatFfi); -XLA_FFI_DECLARE_HANDLER_SYMBOL(gtsv2_f32_ffi); -XLA_FFI_DECLARE_HANDLER_SYMBOL(gtsv2_f64_ffi); + +#endif // JAX_GPU_HAVE_SPARSE + +XLA_FFI_DECLARE_HANDLER_SYMBOL(kGtsv2); } // namespace JAX_GPU_NAMESPACE } // namespace jax diff --git a/jaxlib/gpu/triton.cc b/jaxlib/gpu/triton.cc index 500034af3ebb..b3f313e4f7ea 100644 --- a/jaxlib/gpu/triton.cc +++ b/jaxlib/gpu/triton.cc @@ -1,17 +1,35 @@ +/* Copyright 2022 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include #include -#include #include #include #include +#include #include -#include "nanobind/nanobind.h" -#include "nanobind/stl/pair.h" -#include "nanobind/stl/string.h" -#include "nanobind/stl/string_view.h" -#include "nanobind/stl/tuple.h" -#include "nanobind/stl/vector.h" +#include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/pair.h" // IWYU pragma: keep +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "nanobind/stl/string_view.h" // IWYU pragma: keep +#include "nanobind/stl/tuple.h" // IWYU pragma: keep +#include "nanobind/stl/vector.h" // IWYU pragma: keep #include "jaxlib/absl_status_casters.h" #include "jaxlib/gpu/gpu_kernel_helpers.h" #include "jaxlib/gpu/triton.pb.h" diff --git a/jaxlib/gpu/triton_kernels.cc b/jaxlib/gpu/triton_kernels.cc index 22397ff908bc..9e0dc6c855ac 100644 --- a/jaxlib/gpu/triton_kernels.cc +++ b/jaxlib/gpu/triton_kernels.cc @@ -1,3 +1,18 @@ +/* Copyright 2023 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + #include "jaxlib/gpu/triton_kernels.h" #include @@ -25,6 +40,7 @@ #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" #include "jaxlib/gpu/gpu_kernel_helpers.h" #include "jaxlib/gpu/triton.pb.h" @@ -37,7 +53,8 @@ #endif // JAX_GPU_CUDA #ifdef JAX_GPU_HIP -#include "tsl/platform/env.h" +#include "xla/tsl/platform/env.h" +#include "xla/tsl/platform/errors.h" #endif // JAX_GPU_HIP #define GPU_RETURN_IF_ERROR(expr) JAX_RETURN_IF_ERROR(JAX_AS_STATUS(expr)) diff --git a/jaxlib/gpu/triton_kernels.h b/jaxlib/gpu/triton_kernels.h index c3457093c4f8..3ab3e9143fb8 100644 --- a/jaxlib/gpu/triton_kernels.h +++ b/jaxlib/gpu/triton_kernels.h @@ -1,8 +1,23 @@ +/* Copyright 2023 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + #ifndef JAXLIB_GPU_TRITON_H_ #define JAXLIB_GPU_TRITON_H_ +#include #include -#include #include #include #include @@ -10,7 +25,6 @@ #include "absl/status/status.h" #include "absl/status/statusor.h" -#include "jaxlib/gpu/gpu_kernel_helpers.h" #include "jaxlib/gpu/triton.pb.h" #include "jaxlib/gpu/vendor.h" #include "xla/service/custom_call_status.h" diff --git a/jaxlib/gpu/triton_utils.cc b/jaxlib/gpu/triton_utils.cc index b3a0779118de..fd63435da177 100644 --- a/jaxlib/gpu/triton_utils.cc +++ b/jaxlib/gpu/triton_utils.cc @@ -1,3 +1,18 @@ +/* Copyright 2023 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + #include "jaxlib/gpu/triton_utils.h" #include @@ -9,6 +24,7 @@ #include "absl/strings/string_view.h" #include "jaxlib/gpu/gpu_kernel_helpers.h" #include "jaxlib/gpu/triton.pb.h" +#include "jaxlib/gpu/vendor.h" namespace jax::JAX_GPU_NAMESPACE { diff --git a/jaxlib/gpu/triton_utils.h b/jaxlib/gpu/triton_utils.h index 0c286391e296..a79c098373d1 100644 --- a/jaxlib/gpu/triton_utils.h +++ b/jaxlib/gpu/triton_utils.h @@ -1,9 +1,23 @@ +/* Copyright 2023 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + #ifndef JAXLIB_GPU_TRITON_UTILS_H_ #define JAXLIB_GPU_TRITON_UTILS_H_ #include -#include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "jaxlib/gpu/vendor.h" diff --git a/jaxlib/gpu/vendor.h b/jaxlib/gpu/vendor.h index 7334d4690b59..b96552f81bd1 100644 --- a/jaxlib/gpu/vendor.h +++ b/jaxlib/gpu/vendor.h @@ -20,6 +20,7 @@ limitations under the License. #ifndef JAXLIB_GPU_VENDOR_H_ #define JAXLIB_GPU_VENDOR_H_ +#include #if defined(JAX_GPU_CUDA) // IWYU pragma: begin_exports @@ -29,7 +30,7 @@ limitations under the License. #include "third_party/gpus/cuda/include/cublas_v2.h" #include "third_party/gpus/cuda/include/cuda.h" #include "third_party/gpus/cuda/include/cuda_fp8.h" -#include "third_party/gpus/cuda/include/cuda_runtime_api.h" +#include "cuda_runtime_api.h" #include "third_party/gpus/cuda/include/cufft.h" #include "third_party/gpus/cuda/include/cusolverDn.h" #include "third_party/gpus/cuda/include/cusolver_common.h" @@ -48,6 +49,7 @@ limitations under the License. 
#define JAX_GPU_NAMESPACE cuda #define JAX_GPU_PREFIX "cu" +#define JAX_GPU_PLUGIN_NAME "cuda" typedef cuComplex gpuComplex; typedef cuDoubleComplex gpuDoubleComplex; @@ -150,7 +152,8 @@ typedef cusparseDnVecDescr_t gpusparseDnVecDescr_t; #define GPUDNN_STATUS_SUCCESS CUDNN_STATUS_SUCCESS #define GPUDNN_WGRAD_MODE_ADD CUDNN_WGRAD_MODE_ADD #define GPUDNN_RNN_ALGO_STANDARD CUDNN_RNN_ALGO_STANDARD -#define GPUDNN_RNN_DATA_LAYOUT_BATCH_MAJOR_UNPACKED CUDNN_RNN_DATA_LAYOUT_BATCH_MAJOR_UNPACKED +#define GPUDNN_RNN_DATA_LAYOUT_BATCH_MAJOR_UNPACKED \ + CUDNN_RNN_DATA_LAYOUT_BATCH_MAJOR_UNPACKED #define GPUDNN_RNN_PADDED_IO_ENABLED CUDNN_RNN_PADDED_IO_ENABLED #define GPUDNN_DEFAULT_MATH CUDNN_DEFAULT_MATH #define GPUDNN_FMA_MATH CUDNN_FMA_MATH @@ -287,10 +290,28 @@ typedef cusparseDnVecDescr_t gpusparseDnVecDescr_t; #define gpusparseSpMM_bufferSize cusparseSpMM_bufferSize #define gpusparseSpMV cusparseSpMV #define gpusparseSpMV_bufferSize cusparseSpMV_bufferSize + #define gpusparseSgtsv2 cusparseSgtsv2 #define gpusparseDgtsv2 cusparseDgtsv2 +#define gpusparseCgtsv2 cusparseCgtsv2 +#define gpusparseZgtsv2 cusparseZgtsv2 #define gpusparseSgtsv2_bufferSizeExt cusparseSgtsv2_bufferSizeExt #define gpusparseDgtsv2_bufferSizeExt cusparseDgtsv2_bufferSizeExt +#define gpusparseCgtsv2_bufferSizeExt cusparseCgtsv2_bufferSizeExt +#define gpusparseZgtsv2_bufferSizeExt cusparseZgtsv2_bufferSizeExt + +#define gpusparseSgtsv2StridedBatch_bufferSizeExt \ + cusparseSgtsv2StridedBatch_bufferSizeExt +#define gpusparseDgtsv2StridedBatch_bufferSizeExt \ + cusparseDgtsv2StridedBatch_bufferSizeExt +#define gpusparseCgtsv2StridedBatch_bufferSizeExt \ + cusparseCgtsv2StridedBatch_bufferSizeExt +#define gpusparseZgtsv2StridedBatch_bufferSizeExt \ + cusparseZgtsv2StridedBatch_bufferSizeExt +#define gpusparseSgtsv2StridedBatch cusparseSgtsv2StridedBatch +#define gpusparseDgtsv2StridedBatch cusparseDgtsv2StridedBatch +#define gpusparseCgtsv2StridedBatch cusparseCgtsv2StridedBatch +#define gpusparseZgtsv2StridedBatch cusparseZgtsv2StridedBatch #define GPUSPARSE_INDEX_16U CUSPARSE_INDEX_16U #define GPUSPARSE_INDEX_32I CUSPARSE_INDEX_32I @@ -413,6 +434,7 @@ constexpr uint32_t kNumThreadsPerWarp = 32; #define JAX_GPU_NAMESPACE hip #define JAX_GPU_PREFIX "hip" +#define JAX_GPU_PLUGIN_NAME "rocm" #define JAX_GPU_HAVE_SPARSE 1 #define JAX_GPU_HAVE_64_BIT 0 @@ -633,10 +655,28 @@ typedef hipsparseDnVecDescr_t gpusparseDnVecDescr_t; #define gpusparseSpMM_bufferSize hipsparseSpMM_bufferSize #define gpusparseSpMV hipsparseSpMV #define gpusparseSpMV_bufferSize hipsparseSpMV_bufferSize + #define gpusparseSgtsv2 hipsparseSgtsv2 #define gpusparseDgtsv2 hipsparseDgtsv2 +#define gpusparseCgtsv2 hipsparseCgtsv2 +#define gpusparseZgtsv2 hipsparseZgtsv2 #define gpusparseSgtsv2_bufferSizeExt hipsparseSgtsv2_bufferSizeExt #define gpusparseDgtsv2_bufferSizeExt hipsparseDgtsv2_bufferSizeExt +#define gpusparseCgtsv2_bufferSizeExt hipsparseCgtsv2_bufferSizeExt +#define gpusparseZgtsv2_bufferSizeExt hipsparseZgtsv2_bufferSizeExt + +#define gpusparseSgtsv2StridedBatch_bufferSizeExt \ + hipsparseSgtsv2StridedBatch_bufferSizeExt +#define gpusparseDgtsv2StridedBatch_bufferSizeExt \ + hipsparseDgtsv2StridedBatch_bufferSizeExt +#define gpusparseCgtsv2StridedBatch_bufferSizeExt \ + hipsparseCgtsv2StridedBatch_bufferSizeExt +#define gpusparseZgtsv2StridedBatch_bufferSizeExt \ + hipsparseZgtsv2StridedBatch_bufferSizeExt +#define gpusparseSgtsv2StridedBatch hipsparseSgtsv2StridedBatch +#define gpusparseDgtsv2StridedBatch hipsparseDgtsv2StridedBatch +#define 
gpusparseCgtsv2StridedBatch hipsparseCgtsv2StridedBatch +#define gpusparseZgtsv2StridedBatch hipsparseZgtsv2StridedBatch #define GPUSPARSE_INDEX_16U HIPSPARSE_INDEX_16U #define GPUSPARSE_INDEX_32I HIPSPARSE_INDEX_32I diff --git a/jaxlib/gpu_linalg.py b/jaxlib/gpu_linalg.py index c747c0abbe8b..967dacdbacff 100644 --- a/jaxlib/gpu_linalg.py +++ b/jaxlib/gpu_linalg.py @@ -19,12 +19,17 @@ _cuda_linalg = import_from_plugin("cuda", "_linalg") _hip_linalg = import_from_plugin("rocm", "_linalg") + def registrations() -> dict[str, list[tuple[str, Any, int]]]: - registrations = {"CUDA": [], "ROCM": []} + registrations: dict[str, list[tuple[str, Any, int]]] = { + "CUDA": [], + "ROCM": [], + } for platform, module in [("CUDA", _cuda_linalg), ("ROCM", _hip_linalg)]: if module: registrations[platform].extend( - (*i, 1) for i in module.registrations().items()) + (*i, 1) for i in module.registrations().items() + ) return registrations # pytype: disable=bad-return-type diff --git a/jaxlib/gpu_prng.py b/jaxlib/gpu_prng.py index 6f74d5813ce4..17da46de699f 100644 --- a/jaxlib/gpu_prng.py +++ b/jaxlib/gpu_prng.py @@ -12,79 +12,23 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import annotations -from functools import partial -import itertools +from typing import Any -import jaxlib.mlir.ir as ir - -from jaxlib import xla_client - -from .hlo_helpers import custom_call from .plugin_support import import_from_plugin _cuda_prng = import_from_plugin("cuda", "_prng") _hip_prng = import_from_plugin("rocm", "_prng") -if _cuda_prng: - for _name, _value in _cuda_prng.registrations().items(): - # TODO(danfm): remove after JAX 0.5.1 release - api_version = 1 if "_ffi" in _name else 0 - xla_client.register_custom_call_target(_name, _value, platform="CUDA", - api_version=api_version) - -if _hip_prng: - for _name, _value in _hip_prng.registrations().items(): - # TODO(danfm): remove after JAX 0.5.1 release - api_version = 1 if "_ffi" in _name else 0 - xla_client.register_custom_call_target(_name, _value, platform="ROCM", - api_version=api_version) - - -def _threefry2x32_lowering(prng, platform: str, keys, data, - length: int | ir.Value | None = None, - output_shape: ir.Value | None = None, - forward_compatibility_mode: bool = False): - """ThreeFry2x32 kernel for GPU. - - In presence of dynamic shapes, `length` is an `ir.Value` and `output_shape` - is a 1D tensor describing the shape of the two outputs. - """ - del forward_compatibility_mode - assert len(keys) == 2, keys - assert len(data) == 2, data - assert (ir.RankedTensorType(keys[0].type).element_type == - ir.IntegerType.get_unsigned(32)), keys[0].type - - typ = keys[0].type - dims = ir.RankedTensorType(typ).shape - - for x in itertools.chain(keys, data): - assert x.type == typ, (x.type, typ) - ndims = len(dims) - layout = tuple(range(ndims - 1, -1, -1)) - operand_layouts = [layout] * 4 - operands = [keys[0], keys[1], data[0], data[1]] - - opaque = {} # Use if not forward_compatibility_mode to trigger the FFI (v4). - if isinstance(length, int): - result_shapes = None - else: - assert output_shape is not None - # We also need to pass separately the shapes of the outputs. 
- result_shapes = [output_shape, output_shape] - - custom_call_target = f"{platform}_threefry2x32_ffi" - return custom_call( - custom_call_target, - api_version=4, - result_types=[typ, typ], - operands=operands, - backend_config=opaque, - operand_layouts=operand_layouts, - result_layouts=[layout] * 2, - result_shapes=result_shapes).results - -cuda_threefry2x32 = partial(_threefry2x32_lowering, _cuda_prng, "cu") -rocm_threefry2x32 = partial(_threefry2x32_lowering, _hip_prng, "hip") +def registrations() -> dict[str, list[tuple[str, Any, int]]]: + registrations: dict[str, list[tuple[str, Any, int]]] = { + "CUDA": [], + "ROCM": [], + } + for platform, module in [("CUDA", _cuda_prng), ("ROCM", _hip_prng)]: + if module: + registrations[platform].extend( + (name, value, int(name.endswith("_ffi"))) + for name, value in module.registrations().items() + ) + return registrations diff --git a/jaxlib/gpu_solver.py b/jaxlib/gpu_solver.py index efb58f9a4164..cdcd2b6199f9 100644 --- a/jaxlib/gpu_solver.py +++ b/jaxlib/gpu_solver.py @@ -16,21 +16,18 @@ from .plugin_support import import_from_plugin -_cublas = import_from_plugin("cuda", "_blas") _cusolver = import_from_plugin("cuda", "_solver") _cuhybrid = import_from_plugin("cuda", "_hybrid") -_hipblas = import_from_plugin("rocm", "_blas") _hipsolver = import_from_plugin("rocm", "_solver") _hiphybrid = import_from_plugin("rocm", "_hybrid") def registrations() -> dict[str, list[tuple[str, Any, int]]]: - registrations = {"CUDA": [], "ROCM": []} - for platform, module in [("CUDA", _cublas), ("ROCM", _hipblas)]: - if module: - registrations[platform].extend( - (*i, 0) for i in module.registrations().items()) + registrations: dict[str, list[tuple[str, Any, int]]] = { + "CUDA": [], + "ROCM": [], + } for platform, module in [("CUDA", _cusolver), ("ROCM", _hipsolver)]: if module: registrations[platform].extend( @@ -40,17 +37,17 @@ def registrations() -> dict[str, list[tuple[str, Any, int]]]: for platform, module in [("CUDA", _cuhybrid), ("ROCM", _hiphybrid)]: if module: registrations[platform].extend( - (*i, 1) for i in module.registrations().items()) + (*i, 1) for i in module.registrations().items() + ) return registrations # pytype: disable=bad-return-type def batch_partitionable_targets() -> list[str]: - targets = [] + targets: list[str] = [] for module in [_cusolver, _hipsolver]: if module: targets.extend( - name for name in module.registrations() - if name.endswith("_ffi") + name for name in module.registrations() if name.endswith("_ffi") ) for module in [_cuhybrid, _hiphybrid]: if module: diff --git a/jaxlib/gpu_sparse.py b/jaxlib/gpu_sparse.py index d8645041c946..bf1dc6f64ec1 100644 --- a/jaxlib/gpu_sparse.py +++ b/jaxlib/gpu_sparse.py @@ -11,373 +11,36 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
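The registrations() helpers in these Python wrappers all return, per platform, a list of (target_name, capsule, api_version) tuples; api_version is 1 for the typed FFI targets (names ending in "_ffi") and 0 for any remaining legacy targets, which is why several of the wrappers derive it from the name suffix. On the C++ side the capsules come from each extension's Registrations() function, roughly as sketched below (illustrative names; kExampleCopy is the handler sketched earlier in this patch):

nb::dict ExampleRegistrations() {
  nb::dict dict;
  // Typed FFI handler: exported with EncapsulateFfiHandler and registered by
  // the Python side with api_version == 1.
  dict[JAX_GPU_PREFIX "example_ffi"] = EncapsulateFfiHandler(kExampleCopy);
  // A legacy XlaCustomCallStatus kernel would instead be exported with
  // EncapsulateFunction and registered with api_version == 0, e.g.:
  //   dict[JAX_GPU_PREFIX "example"] = EncapsulateFunction(ExampleLegacy);
  return dict;
}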
-""" -cusparse wrappers for performing sparse matrix computations in JAX -""" -import math -from functools import partial from typing import Any -import jaxlib.mlir.ir as ir - -import numpy as np - -from .hlo_helpers import custom_call, mk_result_types_and_shapes - from .plugin_support import import_from_plugin _cusparse = import_from_plugin("cuda", "_sparse") _hipsparse = import_from_plugin("rocm", "_sparse") +cuda_is_supported = bool(_cusparse and _cusparse.sparse_supported) +rocm_is_supported = bool(_hipsparse and _hipsparse.sparse_supported) + + def registrations() -> dict[str, list[tuple[str, Any, int]]]: - registrations = {"CUDA": [], "ROCM": []} + registrations: dict[str, list[tuple[str, Any, int]]] = { + "CUDA": [], + "ROCM": [], + } for platform, module in [("CUDA", _cusparse), ("ROCM", _hipsparse)]: if module: registrations[platform].extend( (name, value, int(name.endswith("_ffi"))) - for name, value in module.registrations().items()) + for name, value in module.registrations().items() + ) return registrations # pytype: disable=bad-return-type - -cuda_is_supported = bool(_cusparse and _cusparse.sparse_supported) -rocm_is_supported = bool(_hipsparse and _hipsparse.sparse_supported) - - -def _validate_csr_hlo(data, indices, indptr, shape): - data_type = ir.RankedTensorType(data.type) - indices_type = ir.RankedTensorType(indices.type) - indptr_type = ir.RankedTensorType(indptr.type) - - nnz, = data_type.shape - assert indices_type.shape == [nnz] - assert indptr_type.element_type == indices_type.element_type - assert indptr_type.shape == [shape[0] + 1] - return data_type.element_type, indices_type.element_type, nnz - -def _validate_coo_hlo(data, row, col): - data_type = ir.RankedTensorType(data.type) - row_type = ir.RankedTensorType(row.type) - col_type = ir.RankedTensorType(col.type) - - nnz, = data_type.shape - assert row_type.shape == [nnz] - assert col_type.element_type == row_type.element_type - assert col_type.shape == [nnz] - return data_type.element_type, row_type.element_type, nnz - - -def _csr_todense_hlo(platform, gpu_sparse, data, indices, indptr, *, shape, - data_dtype, index_dtype): - """CSR to dense matrix.""" - data_type, index_type, nnz = _validate_csr_hlo(data, indices, indptr, shape) - rows, cols = shape - - buffer_size, opaque = gpu_sparse.build_csr_todense_descriptor( - data_dtype, index_dtype, rows, cols, nnz) - - out = custom_call( - f"{platform}sparse_csr_todense_ffi", - result_types=[ - ir.RankedTensorType.get(shape, data_type), - ir.RankedTensorType.get([buffer_size], - ir.IntegerType.get_signless(8)), - ], - operands=[data, indices, indptr], - backend_config={"opaque": ir.StringAttr.get(opaque)}, - api_version=4, - operand_layouts=[[0]] * 3, - result_layouts=[[1, 0], [0]]).results - return out[0] - -cuda_csr_todense = partial(_csr_todense_hlo, "cu", _cusparse) -rocm_csr_todense = partial(_csr_todense_hlo, "hip", _hipsparse) - - -def _csr_fromdense_hlo(platform, gpu_sparse, mat, *, nnz, index_dtype, - data_dtype, index_type): - """CSR from dense matrix.""" - mat_type = ir.RankedTensorType(mat.type) - rows, cols = mat_type.shape - - buffer_size, opaque = gpu_sparse.build_csr_fromdense_descriptor( - data_dtype, index_dtype, rows, cols, nnz) - - out = custom_call( - f"{platform}sparse_csr_fromdense_ffi", - result_types=[ - ir.RankedTensorType.get([nnz], mat_type.element_type), - ir.RankedTensorType.get([nnz], index_type), - ir.RankedTensorType.get([rows + 1], index_type), - ir.RankedTensorType.get([buffer_size], - ir.IntegerType.get_signless(8)), - ], - 
operands=[mat], - backend_config={"opaque": ir.StringAttr.get(opaque)}, - api_version=4, - operand_layouts=[[1, 0]], - result_layouts=[[0]] * 4).results - return out[:3] - -cuda_csr_fromdense = partial(_csr_fromdense_hlo, "cu", _cusparse) -rocm_csr_fromdense = partial(_csr_fromdense_hlo, "hip", _hipsparse) - - -def _csr_matvec_hlo(platform, gpu_sparse, data, indices, indptr, x, *, shape, - transpose=False, compute_dtype=None, compute_type=None, - data_dtype, index_dtype, x_dtype): - """CSR matrix/vector multiply.""" - data_type, index_type, nnz = _validate_csr_hlo(data, indices, indptr, shape) - rows, cols = shape - - if compute_dtype is None: - compute_dtype = data_dtype - compute_type = data_type - - buffer_size, opaque = gpu_sparse.build_csr_matvec_descriptor( - data_dtype, x_dtype, compute_dtype, index_dtype, - rows, cols, nnz, transpose) - out_size = cols if transpose else rows - - out = custom_call( - f"{platform}sparse_csr_matvec_ffi", - result_types=[ - ir.RankedTensorType.get([out_size], compute_type), - ir.RankedTensorType.get([buffer_size], - ir.IntegerType.get_signless(8)), - ], - operands=[data, indices, indptr, x], - backend_config={"opaque": ir.StringAttr.get(opaque)}, - api_version=4, - operand_layouts=[[0]] * 4, - result_layouts=[[0]] * 2).results - return out[0] - -cuda_csr_matvec = partial(_csr_matvec_hlo, "cu", _cusparse) -rocm_csr_matvec = partial(_csr_matvec_hlo, "hip", _hipsparse) - - -def _csr_matmat_hlo(platform, gpu_sparse, data, indices, indptr, B, *, shape, - transpose=False, compute_dtype=None, compute_type=None, - index_dtype, data_dtype, B_dtype): - """CSR from dense matrix.""" - data_type, index_type, nnz = _validate_csr_hlo(data, indices, indptr, shape) - rows, cols = shape - B_shape = ir.RankedTensorType(B.type).shape - _, Ccols = B_shape - - if compute_dtype is None: - compute_dtype = data_dtype - compute_type = data_type - - buffer_size, opaque = gpu_sparse.build_csr_matmat_descriptor( - data_dtype, B_dtype, compute_dtype, index_dtype, - rows, cols, Ccols, nnz, transpose) - out_size = cols if transpose else rows - - out = custom_call( - f"{platform}sparse_csr_matmat_ffi", - result_types=[ - ir.RankedTensorType.get([out_size, Ccols], compute_type), - ir.RankedTensorType.get([buffer_size], - ir.IntegerType.get_signless(8)), - ], - operands=[data, indices, indptr, B], - backend_config={"opaque": ir.StringAttr.get(opaque)}, - api_version=4, - operand_layouts=[[0], [0], [0], [1, 0]], - result_layouts=[[1, 0], [0]]).results - return out[0] - -cuda_csr_matmat = partial(_csr_matmat_hlo, "cu", _cusparse) -rocm_csr_matmat = partial(_csr_matmat_hlo, "hip", _hipsparse) - - -def _coo_todense_hlo(platform, gpu_sparse, data, row, col, *, shape, - data_dtype, index_dtype): - """COO to dense matrix.""" - data_type, _, nnz = _validate_coo_hlo(data, row, col) - rows, cols = shape - - buffer_size, opaque = gpu_sparse.build_coo_todense_descriptor( - data_dtype, index_dtype, rows, cols, nnz) - - out = custom_call( - f"{platform}sparse_coo_todense_ffi", - result_types=[ - ir.RankedTensorType.get(shape, data_type), - ir.RankedTensorType.get([buffer_size], - ir.IntegerType.get_signless(8)), - ], - operands=[data, row, col], - backend_config={"opaque": ir.StringAttr.get(opaque)}, - api_version=4, - operand_layouts=[[0]] * 3, - result_layouts=[[1, 0], [0]]).results - return out[0] - -cuda_coo_todense = partial(_coo_todense_hlo, "cu", _cusparse) -rocm_coo_todense = partial(_coo_todense_hlo, "hip", _hipsparse) - - -def _coo_fromdense_hlo(platform, gpu_sparse, mat, *, nnz, 
data_dtype, - index_dtype, index_type): - """COO from dense matrix.""" - mat_type = ir.RankedTensorType(mat.type) - rows, cols = mat_type.shape - - buffer_size, opaque = gpu_sparse.build_coo_fromdense_descriptor( - data_dtype, index_dtype, rows, cols, nnz) - - out = custom_call( - f"{platform}sparse_coo_fromdense_ffi", - result_types=[ - ir.RankedTensorType.get([nnz], mat_type.element_type), - ir.RankedTensorType.get([nnz], index_type), - ir.RankedTensorType.get([nnz], index_type), - ir.RankedTensorType.get([buffer_size], - ir.IntegerType.get_signless(8)), - ], - operands=[mat], - backend_config={"opaque": ir.StringAttr.get(opaque)}, - api_version=4, - operand_layouts=[[1, 0]], - result_layouts=[[0]] * 4).results - return out[:3] - -cuda_coo_fromdense = partial(_coo_fromdense_hlo, "cu", _cusparse) -rocm_coo_fromdense = partial(_coo_fromdense_hlo, "hip", _hipsparse) - - -def _coo_matvec_hlo(platform, gpu_sparse, data, row, col, x, *, shape, - transpose=False, compute_dtype=None, compute_type=None, - index_dtype, data_dtype, x_dtype): - """COO matrix/vector multiply.""" - data_type, _, nnz = _validate_coo_hlo(data, row, col) - rows, cols = shape - - if compute_dtype is None: - compute_dtype = data_dtype - compute_type = data_type - - buffer_size, opaque = gpu_sparse.build_coo_matvec_descriptor( - data_dtype, x_dtype, compute_dtype, index_dtype, - rows, cols, nnz, transpose) - out_size = cols if transpose else rows - - out = custom_call( - f"{platform}sparse_coo_matvec_ffi", - result_types=[ - ir.RankedTensorType.get([out_size], compute_type), - ir.RankedTensorType.get([buffer_size], - ir.IntegerType.get_signless(8)), - ], - operands=[data, row, col, x], - backend_config={"opaque": ir.StringAttr.get(opaque)}, - api_version=4, - operand_layouts=[[0]] * 4, - result_layouts=[[0]] * 2).results - return out[0] - -cuda_coo_matvec = partial(_coo_matvec_hlo, "cu", _cusparse) -rocm_coo_matvec = partial(_coo_matvec_hlo, "hip", _hipsparse) - - -def _coo_matmat_hlo(platform, gpu_sparse, data, row, col, B, *, shape, - transpose=False, compute_dtype=None, compute_type=None, - x_dtype, data_dtype, index_dtype): - """COO from dense matrix.""" - data_type, _, nnz = _validate_coo_hlo(data, row, col) - is_batched_matmat = False - batch_count = 1 - if len(shape) == 2: - rows, cols = shape - elif len(shape) == 3: - is_batched_matmat = True - batch_count, rows, cols = shape - # Redefine nnz as nnz per batch. - nnz = nnz // batch_count - - B_shape = ir.RankedTensorType(B.type).shape - _, Ccols = B_shape - - if compute_dtype is None: - compute_dtype = data_dtype - compute_type = data_type - - # TODO(tianjianlu): use batch stride to trigger different mode of batch - # computation. Currently batch_stride = 0 is not allowed because of the issue - # in cusparse https://github.com/NVIDIA/CUDALibrarySamples/issues/81#issuecomment-1205562643 - # Set batch stride to be the matrix size for now. 
- lhs_batch_stride = nnz - B_rows = rows if transpose else cols - rhs_batch_stride = B_rows * Ccols - - buffer_size, opaque = gpu_sparse.build_coo_matmat_descriptor( - data_dtype, x_dtype, compute_dtype, index_dtype, - rows, cols, Ccols, nnz, transpose, batch_count, lhs_batch_stride, - rhs_batch_stride) - out_size = cols if transpose else rows - - if is_batched_matmat: - out_shape = [batch_count, out_size, Ccols] - out_layout = [2, 1, 0] - else: - out_shape = [out_size, Ccols] - out_layout = [1, 0] - - out = custom_call( - f"{platform}sparse_coo_matmat_ffi", - result_types=[ - ir.RankedTensorType.get(out_shape, compute_type), - ir.RankedTensorType.get([buffer_size], - ir.IntegerType.get_signless(8)), - ], - operands=[data, row, col, B], - backend_config={"opaque": ir.StringAttr.get(opaque)}, - api_version=4, - operand_layouts=[[0], [0], [0], [1, 0]], - result_layouts=[out_layout, [0]]).results - return out[0] - -cuda_coo_matmat = partial(_coo_matmat_hlo, "cu", _cusparse) -rocm_coo_matmat = partial(_coo_matmat_hlo, "hip", _hipsparse) - - -def _gtsv2_hlo( - platform, gpu_sparse, dl, d, du, B, *, m, n, ldb, t, b_shape_vals=None): - """Calls `cusparsegtsv2(dl, d, du, B, m, n, ldb)`.""" - assert len(b_shape_vals) >= 2 - batch_dim_vals = b_shape_vals[:-2] - batch_size = math.prod(batch_dim_vals) - num_bd = len(b_shape_vals) - 2 - f32 = (t == np.float32) - if f32: - buffer_size = gpu_sparse.gtsv2_f32_buffer_size(m, n, ldb) - else: - buffer_size = gpu_sparse.gtsv2_f64_buffer_size(m, n, ldb) - - b_layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1)) - d_layout = (num_bd,) + tuple(range(num_bd - 1, -1, -1)) - b_type = ir.RankedTensorType(B.type) - - shape_type_pairs = [ - (batch_dim_vals + (ldb, n), b_type.element_type), - ((buffer_size,), ir.IntegerType.get_signless(8)) - ] - result_types, result_shapes = mk_result_types_and_shapes(shape_type_pairs) - opaque = gpu_sparse.build_gtsv2_descriptor(batch_size, m, n, ldb) - out = custom_call( - f"{platform}sparse_gtsv2_" + ("f32" if f32 else "f64") + "_ffi", - result_types=result_types, - operands=[dl, d, du, B], - backend_config={"opaque": ir.StringAttr.get(opaque)}, - api_version=4, - operand_layouts=[d_layout] * 3 + [b_layout], - result_layouts=[b_layout, [0]], - operand_output_aliases={3: 0}, - result_shapes=result_shapes).results - return out[0] - -cuda_gtsv2 = partial(_gtsv2_hlo, "cu", _cusparse) -rocm_gtsv2 = partial(_gtsv2_hlo, "hip", _hipsparse) +def batch_partitionable_targets() -> list[str]: + targets: list[str] = [] + for module in [_cusparse, _hipsparse]: + if module: + targets.extend( + name for name in module.registrations() if name.endswith("gtsv2_ffi") + ) + return targets diff --git a/jaxlib/guard_lib.cc b/jaxlib/guard_lib.cc new file mode 100644 index 000000000000..6ad1f8e5366c --- /dev/null +++ b/jaxlib/guard_lib.cc @@ -0,0 +1,197 @@ +/* Copyright 2024 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +// This files implements the configuration management for different types of +// guards. +// C++ backends are responsible for enforcing transfer guard levels. + +#include "jaxlib/guard_lib.h" + +#include +#include + +#include "absl/base/attributes.h" +#include "absl/functional/function_ref.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/optional.h" // IWYU pragma: keep +#include "xla/util.h" + +namespace jax { + +namespace nb = ::nanobind; + +namespace { + +// Protected by the GIL. +GuardState& global_state = *new GuardState(); + +ABSL_CONST_INIT thread_local GuardState thread_local_state; + +// The default transfer guard level. +constexpr TransferGuardLevel kDefaultGuardLevel = TransferGuardLevel::kAllow; + +// The default garbage collection guard level. +constexpr GarbageCollectionGuardLevel kDefaultGarbageCollectionGuardLevel = + GarbageCollectionGuardLevel::kAllow; + +// Returns the transfer guard action for a transfer. +TransferGuardAction GetTransferGuardAction(TransferGuardLevel guard_level, + bool explicit_transfer) { + switch (guard_level) { + case TransferGuardLevel::kAllow: + return TransferGuardAction::kAllow; + case TransferGuardLevel::kLog: + if (explicit_transfer) { + return TransferGuardAction::kAllow; + } else { + return TransferGuardAction::kLog; + } + case TransferGuardLevel::kDisallow: + if (explicit_transfer) { + return TransferGuardAction::kAllow; + } else { + return TransferGuardAction::kDisallow; + } + case TransferGuardLevel::kLogExplicit: + return TransferGuardAction::kLog; + case TransferGuardLevel::kDisallowExplicit: + return TransferGuardAction::kDisallow; + default: + // Unreachable; gracefully handle the unexpected guard level and prevent a + // compiler warning. + return TransferGuardAction::kDisallow; + } +} + +// Returns the transfer guard action for a host-to-device transfer. +// REQUIRES: Python GIL. +TransferGuardAction GetTransferGuardActionForHostToDevice() { + return GetTransferGuardAction( + thread_local_state.host_to_device.value_or( + global_state.host_to_device.value_or(kDefaultGuardLevel)), + thread_local_state.explicit_device_put); +} + +// Returns the transfer guard action for a device-to-device transfer. +// REQUIRES: Python GIL. +TransferGuardAction GetTransferGuardActionForDeviceToDevice() { + return GetTransferGuardAction( + thread_local_state.device_to_device.value_or( + global_state.device_to_device.value_or(kDefaultGuardLevel)), + thread_local_state.explicit_device_put); +} + +// Returns the transfer guard action for a device-to-host transfer. +// REQUIRES: Python GIL. 
+TransferGuardAction GetTransferGuardActionForDeviceToHost() { + return GetTransferGuardAction( + thread_local_state.device_to_host.value_or( + global_state.device_to_host.value_or(kDefaultGuardLevel)), + thread_local_state.explicit_device_get); +} + +} // namespace + +absl::Status ApplyTransferGuardToHostToDevice( + absl::FunctionRef formatter) { + switch (GetTransferGuardActionForHostToDevice()) { + case TransferGuardAction::kAllow: + break; + case TransferGuardAction::kLog: + LOG(WARNING) << "host-to-device transfer: " << formatter(); + break; + case TransferGuardAction::kDisallow: + return xla::InvalidArgument("Disallowed host-to-device transfer: %s", + formatter()); + } + return absl::OkStatus(); +} + +absl::Status ApplyTransferGuardToDeviceToDevice( + absl::FunctionRef formatter) { + switch (GetTransferGuardActionForDeviceToDevice()) { + case TransferGuardAction::kAllow: + break; + case TransferGuardAction::kLog: + LOG(WARNING) << "device-to-device transfer: " << formatter(); + break; + case TransferGuardAction::kDisallow: + return xla::InvalidArgument("Disallowed device-to-device transfer: %s", + formatter()); + } + return absl::OkStatus(); +} + +absl::Status ApplyTransferGuardToDeviceToHost( + absl::FunctionRef formatter) { + switch (GetTransferGuardActionForDeviceToHost()) { + case TransferGuardAction::kAllow: + break; + case TransferGuardAction::kLog: + LOG(WARNING) << "device-to-host transfer: " << formatter(); + break; + case TransferGuardAction::kDisallow: + return xla::InvalidArgument("Disallowed device-to-host transfer: %s", + formatter()); + } + return absl::OkStatus(); +} + +GarbageCollectionGuardLevel GetGarbageCollectArrayGuard() { + return thread_local_state.garbage_collect_array.value_or( + global_state.garbage_collect_array.value_or( + kDefaultGarbageCollectionGuardLevel)); +} + +void BuildGuardSubmodule(nb::module_& m) { + nb::module_ glib = + m.def_submodule("guard_lib", "Jax support library for guards"); + + nb::enum_ tglevel(glib, "TransferGuardLevel"); + tglevel.value("ALLOW", TransferGuardLevel::kAllow); + tglevel.value("LOG", TransferGuardLevel::kLog); + tglevel.value("DISALLOW", TransferGuardLevel::kDisallow); + tglevel.value("LOG_EXPLICIT", TransferGuardLevel::kLogExplicit); + tglevel.value("DISALLOW_EXPLICIT", TransferGuardLevel::kDisallowExplicit); + + nb::enum_ gcglevel( + glib, "GarbageCollectionGuardLevel"); + gcglevel.value("ALLOW", GarbageCollectionGuardLevel::kAllow); + gcglevel.value("LOG", GarbageCollectionGuardLevel::kLog); + gcglevel.value("FATAL", GarbageCollectionGuardLevel::kFatal); + + nb::class_ tgstate(glib, "GuardState"); + tgstate.def_rw("host_to_device", &GuardState::host_to_device, + nb::arg().none()); + tgstate.def_rw("device_to_device", &GuardState::device_to_device, + nb::arg().none()); + tgstate.def_rw("device_to_host", &GuardState::device_to_host, + nb::arg().none()); + tgstate.def_rw("explicit_device_put", &GuardState::explicit_device_put); + tgstate.def_rw("explicit_device_get", &GuardState::explicit_device_get); + tgstate.def_rw("garbage_collect_array", &GuardState::garbage_collect_array, + nb::arg().none()); + + glib.def( + "global_state", [&]() { return &global_state; }, + nb::rv_policy::reference); + glib.def( + "thread_local_state", [&]() { return &thread_local_state; }, + nb::rv_policy::reference); +} + +} // namespace jax diff --git a/jaxlib/guard_lib.h b/jaxlib/guard_lib.h new file mode 100644 index 000000000000..93e632fb7c9a --- /dev/null +++ b/jaxlib/guard_lib.h @@ -0,0 +1,115 @@ +/* Copyright 2024 The JAX Authors + 
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef JAXLIB_GUARD_LIB_H_
+#define JAXLIB_GUARD_LIB_H_
+
+#include <optional>
+#include <string>
+
+// placeholder for index annotation headers
+#include "absl/functional/function_ref.h"
+#include "absl/status/status.h"
+#include "nanobind/nanobind.h"
+
+namespace jax {
+
+// Transfer guard level chosen by the user code.
+enum class TransferGuardLevel {
+  // Explicit transfers: allow
+  // Implicit transfers: allow
+  kAllow,
+  // Explicit transfers: allow
+  // Implicit transfers: log
+  kLog,
+  // Explicit transfers: allow
+  // Implicit transfers: disallow
+  kDisallow,
+  // Explicit transfers: log
+  // Implicit transfers: log
+  kLogExplicit,
+  // Explicit transfers: disallow
+  // Implicit transfers: disallow
+  kDisallowExplicit,
+};
+
+// Garbage collection guard level chosen by the user code.
+enum class GarbageCollectionGuardLevel {
+  // Silently allow the object to be garbage collected.
+  kAllow,
+  // Log and allow the object to be garbage collected.
+  kLog,
+  // Fatal crash on object garbage collection.
+  kFatal,
+};
+
+// Flags for guard levels are controlled by:
+// - a global flag value,
+//   e.g., associated to --jax_transfer_guard_device_to_host
+//   which defaults to TransferGuardLevel::kAllow.
+// - possibly a thread-local value, which initially is std::nullopt and
+//   overrides the global value if set. The thread-local state is used to
+//   implement context managers that locally override the global state.
+//
+// Explicit device_put/device_get contexts are tracked by context managers.
+struct GuardState {
+  std::optional<TransferGuardLevel> host_to_device;
+  std::optional<TransferGuardLevel> device_to_device;
+  std::optional<TransferGuardLevel> device_to_host;
+  bool explicit_device_put = false;
+  bool explicit_device_get = false;
+
+  std::optional<GarbageCollectionGuardLevel> garbage_collect_array;
+};
+
+// Resulting action for a transfer given the transfer guard level and the
+// transfer type.
+enum class TransferGuardAction {
+  // Silently allow the transfer.
+  kAllow,
+  // Log and allow the transfer.
+  kLog,
+  // Disallow the transfer.
+  kDisallow,
+};
+
+// Guards a host-to-device transfer. formatter is called to describe the
+// transfer in a log message or error status.
+// REQUIRES: Python GIL.
+absl::Status ApplyTransferGuardToHostToDevice(
+    absl::FunctionRef<std::string()> formatter);
+
+// Guards a device-to-device transfer. formatter is called to describe the
+// transfer in a log message or error status.
+// REQUIRES: Python GIL.
+absl::Status ApplyTransferGuardToDeviceToDevice(
+    absl::FunctionRef<std::string()> formatter);
+
+// Guards a device-to-host transfer. formatter is called to describe the
+// transfer in a log message or error status.
+// REQUIRES: Python GIL.
+absl::Status ApplyTransferGuardToDeviceToHost(
+    absl::FunctionRef<std::string()> formatter);
+
+// Returns the garbage collection guard level for "jax.Array" objects.
+// REQUIRES: Python GIL.
+GarbageCollectionGuardLevel GetGarbageCollectArrayGuard();
+
+// The function to call in `xla.cc` to add the bindings for this module.
+void BuildGuardSubmodule(nanobind::module_& m); + +} // namespace jax + +#endif // JAXLIB_GUARD_LIB_H_ diff --git a/jaxlib/hlo_helpers.py b/jaxlib/hlo_helpers.py index 0d57a04f1aa7..11ff844ae53f 100644 --- a/jaxlib/hlo_helpers.py +++ b/jaxlib/hlo_helpers.py @@ -19,11 +19,22 @@ from collections.abc import Callable, Sequence from functools import partial from typing import Union +import warnings import jaxlib.mlir.ir as ir import jaxlib.mlir.dialects.stablehlo as hlo import numpy as np +# TODO(danfm): This module isn't covered by JAX's compatibility policy, so no +# formal deprecation period is required, but there are enough users that we +# should keep this warning for at least one full release cycle. +# Deprecation added 2025-03-19 after the release of v0.5.3. Remove this whole +# module after the release of v0.5.4 or later. +warnings.warn( + "The jaxlib.hlo_helpers submodule is deprecated. Instead, use jax.ffi if " + "possible or, for lower-level operations, jax.interpreters.mlir.", + DeprecationWarning, +) _dtype_to_ir_type_factory : dict[np.dtype, Callable[[], ir.Type]] = { np.dtype(np.bool_): partial(ir.IntegerType.get_signless, 1), diff --git a/jaxlib/ifrt_proxy.cc b/jaxlib/ifrt_proxy.cc new file mode 100644 index 000000000000..e91c4d9a3859 --- /dev/null +++ b/jaxlib/ifrt_proxy.cc @@ -0,0 +1,162 @@ +// Copyright 2023 The JAX Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
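Returning to the guard_lib module completed above: the thread-local GuardState overrides the global one, and explicit device_put/device_get transfers are exempt from the "log" and "disallow" levels. Below is a minimal Python sketch of that behaviour using the public jax.transfer_guard API, which is the user-facing entry point to this state; the exact exception type raised for a disallowed transfer is backend-defined, so it is caught broadly here.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Global level: log implicit transfers (backs guard_lib's global_state).
jax.config.update("jax_transfer_guard", "log")

x = jnp.ones(3) + np.ones(3)  # implicit host-to-device transfer -> logged

# Thread-local override (backs guard_lib's thread_local_state).
with jax.transfer_guard("disallow"):
    jax.device_put(np.ones(3))        # explicit transfer: still allowed
    try:
        jnp.ones(3) + np.ones(3)      # implicit transfer: rejected
    except Exception as e:            # surfaced from xla::InvalidArgument
        print("blocked:", e)
```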
+ +#include "jaxlib/ifrt_proxy.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/log/log_entry.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/time/time.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/function.h" // IWYU pragma: keep +#include "nanobind/stl/optional.h" // IWYU pragma: keep +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "nanobind/stl/unordered_map.h" // IWYU pragma: keep +#include "nanobind/stl/variant.h" // IWYU pragma: keep +#include "jaxlib/nb_class_ptr.h" +#include "jaxlib/py_client.h" +#include "xla/pjrt/status_casters.h" +#include "xla/python/ifrt/attribute_map.h" +#include "xla/python/ifrt/client.h" +#include "xla/python/ifrt_proxy/client/registry.h" +#include "xla/tsl/platform/env.h" +#include "xla/tsl/platform/statusor.h" + +namespace nb = ::nanobind; + +namespace xla { +namespace ifrt { +namespace proxy { +namespace { + +struct PyClientConnectionOptions { + std::optional> on_disconnect; + std::optional> on_connection_update; + std::optional connection_timeout_in_seconds; + std::optional< + std::unordered_map>> + initialization_data; +}; + +absl::StatusOr> GetClient( + std::string proxy_server_address, + const PyClientConnectionOptions& py_options) { + DCHECK(PyGILState_Check()); + std::unique_ptr client; + + ClientConnectionOptions options; + if (py_options.on_disconnect) { + // While it is possible to pass around `py_options.on_disconnect` without + // wrapping it via a shared_ptr, copying the `py_options.on_disconnect` + // object can internally attempt to acquire the GIL [1], and can thus block + // or even deadlock. A unique_ptr or `absl::AnyInvocable` is not sufficient + // because downstream code can make copies. 
Reference: + // https://pybind11.readthedocs.io/en/stable/advanced/misc.html#common-sources-of-global-interpreter-lock-errors + auto py_on_disconnect = std::make_shared>( + std::move(*py_options.on_disconnect)); + + options.on_disconnect = + [on_disconnect = std::move(py_on_disconnect)](absl::Status s) mutable { + LOG(WARNING) << "Connection to server failed, calling supplied " + << "`on_disconnect` function: " << s; + tsl::Env::Default()->SchedClosure([s, on_disconnect]() mutable { + nb::gil_scoped_acquire gil_acquire; + (*on_disconnect)(s.ToString()); + on_disconnect = nullptr; + }); + }; + } + + if (py_options.on_connection_update) { + auto fn = std::make_shared>( + std::move(*py_options.on_connection_update)); + options.on_connection_update = [fn](absl::string_view log_line) -> void { + tsl::Env::Default()->SchedClosure([fn, str = std::string(log_line)] { + nb::gil_scoped_acquire gil_acquire; + (*fn)(std::string(str)); + }); + }; + } + + if (py_options.connection_timeout_in_seconds.has_value()) { + options.connection_timeout = + absl::Seconds(*py_options.connection_timeout_in_seconds); + } + + if (py_options.initialization_data.has_value()) { + AttributeMap::Map attribute_map; + for (const auto& [key, py_value] : *py_options.initialization_data) { + if (std::holds_alternative(py_value)) { + nb::bytes value = std::get(py_value); + attribute_map.insert({key, AttributeMap::StringValue(std::string( + value.c_str(), value.size()))}); + } else if (std::holds_alternative(py_value)) { + attribute_map.insert( + {key, AttributeMap::BoolValue(std::get(py_value))}); + } else { + CHECK(std::holds_alternative(py_value)); + attribute_map.insert( + {key, AttributeMap::Int64Value(std::get(py_value))}); + } + } + options.initialization_data = AttributeMap(std::move(attribute_map)); + } + + { + nb::gil_scoped_release gil_release; + TF_ASSIGN_OR_RETURN(client, CreateClient(proxy_server_address, options)); + } + + // Constructing `xla::PyClient` requires GIL as it may dec-ref Python objects. + return xla::PyClient::Make(std::move(client)); +} + +} // namespace + +void BuildIfrtProxySubmodule(nb::module_& m) { + nb::module_ sub_module = m.def_submodule("ifrt_proxy", "IFRT proxy"); + + nb::class_(sub_module, "ClientConnectionOptions") + .def(nb::init<>()) + .def_rw("on_disconnect", &PyClientConnectionOptions::on_disconnect, + nb::arg().none()) + .def_rw("on_connection_update", + &PyClientConnectionOptions::on_connection_update, + nb::arg().none()) + .def_rw("connection_timeout_in_seconds", + &PyClientConnectionOptions::connection_timeout_in_seconds, + nb::arg().none()) + .def_rw("initialization_data", + &PyClientConnectionOptions::initialization_data, + nb::arg().none()); + + sub_module.def("get_client", xla::ValueOrThrowWrapper(GetClient), + nb::arg("proxy_server_address"), nb::arg("options")); +} + +} // namespace proxy +} // namespace ifrt +} // namespace xla diff --git a/jaxlib/ifrt_proxy.h b/jaxlib/ifrt_proxy.h new file mode 100644 index 000000000000..2bfc19062012 --- /dev/null +++ b/jaxlib/ifrt_proxy.h @@ -0,0 +1,31 @@ +/* Copyright 2024 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
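For orientation, here is a hypothetical sketch of how the ifrt_proxy bindings above might be driven from Python. Only the ClientConnectionOptions fields and the get_client signature come from this diff; the import path, the proxy address format, and the keys placed in initialization_data are assumptions for illustration.

```python
# Hypothetical usage of the ifrt_proxy submodule registered by
# BuildIfrtProxySubmodule; the extension module name is an assumption.
from jax._src.lib import _jax

ifrt_proxy = _jax.ifrt_proxy

opts = ifrt_proxy.ClientConnectionOptions()
opts.connection_timeout_in_seconds = 60
# Values may be bytes, bool, or int, matching the C++ variant above.
opts.initialization_data = {"auth_token": b"secret", "use_tls": True}

# Both callbacks are invoked off-thread via SchedClosure after the GIL is
# re-acquired, so they must be thread-safe.
opts.on_disconnect = lambda msg: print("proxy disconnected:", msg)
opts.on_connection_update = lambda line: print("proxy:", line)

client = ifrt_proxy.get_client("grpc://proxy.example.com:1234", opts)
```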
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_IFRT_PROXY_CLIENT_PY_MODULE_H_ +#define JAXLIB_IFRT_PROXY_CLIENT_PY_MODULE_H_ + +#include "nanobind/nanobind.h" + +namespace xla { +namespace ifrt { +namespace proxy { + +void BuildIfrtProxySubmodule(nanobind::module_& m); + +} // namespace proxy +} // namespace ifrt +} // namespace xla + +#endif // JAXLIB_IFRT_PROXY_CLIENT_PY_MODULE_H_ diff --git a/jaxlib/jax.bzl b/jaxlib/jax.bzl index 89f1545995d5..c6dd9b1bdb3f 100644 --- a/jaxlib/jax.bzl +++ b/jaxlib/jax.bzl @@ -22,8 +22,9 @@ load("@local_config_cuda//cuda:build_defs.bzl", _cuda_library = "cuda_library", load("@local_config_rocm//rocm:build_defs.bzl", _if_rocm_is_configured = "if_rocm_is_configured", _rocm_library = "rocm_library") load("@python_version_repo//:py_version.bzl", "HERMETIC_PYTHON_VERSION") load("@rules_cc//cc:defs.bzl", _cc_proto_library = "cc_proto_library") -load("@rules_python//python:defs.bzl", "py_test") -load("@xla//xla/tsl:tsl.bzl", _if_windows = "if_windows", _pybind_extension = "tsl_pybind_extension_opensource") +load("@rules_python//python:defs.bzl", "py_library", "py_test") +load("@xla//third_party/py:python_wheel.bzl", "collect_data_files", "transitive_py_deps") +load("@xla//xla/tsl:tsl.bzl", "transitive_hdrs", _if_windows = "if_windows", _pybind_extension = "tsl_pybind_extension_opensource") load("@xla//xla/tsl/platform:build_config_root.bzl", _tf_cuda_tests_tags = "tf_cuda_tests_tags", _tf_exec_properties = "tf_exec_properties") # Explicitly re-exports names to avoid "unused variable" warnings from .bzl @@ -31,7 +32,7 @@ load("@xla//xla/tsl/platform:build_config_root.bzl", _tf_cuda_tests_tags = "tf_c cc_proto_library = _cc_proto_library cuda_library = _cuda_library rocm_library = _rocm_library -pytype_test = native.py_test +proto_library = native.proto_library nanobind_extension = _pybind_extension if_cuda_is_configured = _if_cuda_is_configured if_rocm_is_configured = _if_rocm_is_configured @@ -49,6 +50,7 @@ pallas_tpu_internal_users = [] pallas_fuser_users = [] mosaic_extension_deps = [] serialize_executable_internal_users = [] +buffer_callback_internal_users = [] jax_internal_export_back_compat_test_util_visibility = [] jax_internal_test_harnesses_visibility = [] @@ -64,32 +66,37 @@ PLATFORM_TAGS_DICT = { ("Windows", "AMD64"): ("win", "amd64"), } -# TODO(vam): remove this once zstandard builds against Python 3.13 +# TODO(vam): remove this once zstandard builds against Python >3.13 def get_zstandard(): - if HERMETIC_PYTHON_VERSION == "3.13" or HERMETIC_PYTHON_VERSION == "3.13-ft": + if HERMETIC_PYTHON_VERSION in ("3.13", "3.13-ft", "3.14", "3.14-ft"): return [] - return ["@pypi_zstandard//:pkg"] + return ["@pypi//zstandard"] + +def get_optional_dep(package, excluded_py_versions = ["3.14", "3.14-ft"]): + if HERMETIC_PYTHON_VERSION in excluded_py_versions: + return [] + return [package] _py_deps = { - "absl/logging": ["@pypi_absl_py//:pkg"], - "absl/testing": ["@pypi_absl_py//:pkg"], - "absl/flags": ["@pypi_absl_py//:pkg"], - "cloudpickle": ["@pypi_cloudpickle//:pkg"], - "colorama": ["@pypi_colorama//:pkg"], - "epath": ["@pypi_etils//:pkg"], # etils.epath - "filelock": ["@pypi_filelock//:pkg"], - "flatbuffers": ["@pypi_flatbuffers//:pkg"], - "hypothesis": ["@pypi_hypothesis//:pkg"], + "absl/logging": ["@pypi//absl_py"], + "absl/testing": ["@pypi//absl_py"], + "absl/flags": ["@pypi//absl_py"], + "cloudpickle": 
get_optional_dep("@pypi//cloudpickle"), + "epath": get_optional_dep("@pypi//etils"), # etils.epath + "filelock": get_optional_dep("@pypi//filelock"), + "flatbuffers": ["@pypi//flatbuffers"], + "hypothesis": ["@pypi//hypothesis"], "magma": [], - "matplotlib": ["@pypi_matplotlib//:pkg"], + "matplotlib": get_optional_dep("@pypi//matplotlib"), "mpmath": [], - "opt_einsum": ["@pypi_opt_einsum//:pkg"], - "pil": ["@pypi_pillow//:pkg"], - "portpicker": ["@pypi_portpicker//:pkg"], - "ml_dtypes": ["@pypi_ml_dtypes//:pkg"], - "numpy": ["@pypi_numpy//:pkg"], - "scipy": ["@pypi_scipy//:pkg"], + "opt_einsum": ["@pypi//opt_einsum"], + "pil": get_optional_dep("@pypi//pillow"), + "portpicker": get_optional_dep("@pypi//portpicker"), + "ml_dtypes": ["@pypi//ml_dtypes"], + "numpy": ["@pypi//numpy"], + "scipy": ["@pypi//scipy"], "tensorflow_core": [], + "tensorstore": get_optional_dep("@pypi//tensorstore"), "torch": [], "zstandard": get_zstandard(), } @@ -125,14 +132,17 @@ jax2tf_deps = [] def pytype_library(name, pytype_srcs = None, **kwargs): _ = pytype_srcs # @unused - native.py_library(name = name, **kwargs) + py_library(name = name, **kwargs) def pytype_strict_library(name, pytype_srcs = [], **kwargs): data = pytype_srcs + (kwargs["data"] if "data" in kwargs else []) new_kwargs = {k: v for k, v in kwargs.items() if k != "data"} - native.py_library(name = name, data = data, **new_kwargs) + py_library(name = name, data = data, **new_kwargs) + +py_strict_library = py_library +py_strict_test = py_test -def py_library_providing_imports_info(*, name, lib_rule = native.py_library, pytype_srcs = [], **kwargs): +def py_library_providing_imports_info(*, name, lib_rule = py_library, pytype_srcs = [], **kwargs): data = pytype_srcs + (kwargs["data"] if "data" in kwargs else []) new_kwargs = {k: v for k, v in kwargs.items() if k != "data"} lib_rule(name = name, data = data, **new_kwargs) @@ -140,119 +150,85 @@ def py_library_providing_imports_info(*, name, lib_rule = native.py_library, pyt def py_extension(name, srcs, copts, deps, linkopts = []): nanobind_extension(name, srcs = srcs, copts = copts, linkopts = linkopts, deps = deps, module_name = name) -def windows_cc_shared_mlir_library(name, out, deps = [], srcs = [], exported_symbol_prefixes = []): - """Workaround DLL building issue. - - 1. cc_binary with linkshared enabled cannot produce DLL with symbol - correctly exported. - 2. Even if the DLL is correctly built, the resulting target cannot be - correctly consumed by other targets. 
- - Args: - name: the name of the output target - out: the name of the output DLL filename - deps: deps - srcs: srcs - """ - - # create a dummy library to get the *.def file - dummy_library_name = name + ".dummy.dll" - native.cc_binary( - name = dummy_library_name, - linkshared = 1, - linkstatic = 1, - deps = deps, - target_compatible_with = ["@platforms//os:windows"], - ) - - # .def file with all symbols, not usable - full_def_name = name + ".full.def" - native.filegroup( - name = full_def_name, - srcs = [dummy_library_name], - output_group = "def_file", - target_compatible_with = ["@platforms//os:windows"], - ) - - # say filtered_symbol_prefixes == ["mlir", "chlo"], then construct the regex - # pattern as "^\\s*(mlir|clho)" to use grep - pattern = "^\\s*(" + "|".join(exported_symbol_prefixes) + ")" - - # filtered def_file, only the needed symbols are included - filtered_def_name = name + ".filtered.def" - filtered_def_file = out + ".def" - native.genrule( - name = filtered_def_name, - srcs = [full_def_name], - outs = [filtered_def_file], - cmd = """echo 'LIBRARY {}\nEXPORTS ' > $@ && grep -E '{}' $(location :{}) >> $@""".format(out, pattern, full_def_name), - target_compatible_with = ["@platforms//os:windows"], - ) - - # create the desired library - native.cc_binary( - name = out, # this name must be correct, it will be the filename - linkshared = 1, - deps = deps, - win_def_file = filtered_def_file, - target_compatible_with = ["@platforms//os:windows"], - ) - - # however, the created cc_library (a shared library) cannot be correctly - # consumed by other cc_*... - interface_library_file = out + ".if.lib" - native.filegroup( - name = interface_library_file, - srcs = [out], - output_group = "interface_library", - target_compatible_with = ["@platforms//os:windows"], - ) - - # but this one can be correctly consumed, this is our final product - native.cc_import( - name = name, - interface_library = interface_library_file, - shared_library = out, - target_compatible_with = ["@platforms//os:windows"], - ) - ALL_BACKENDS = ["cpu", "gpu", "tpu"] def if_building_jaxlib( if_building, if_not_building = [ - "@pypi_jaxlib//:pkg", - "@pypi_jax_cuda12_plugin//:pkg", - "@pypi_jax_cuda12_pjrt//:pkg", - ], - if_not_building_for_cpu = ["@pypi_jaxlib//:pkg"], - if_py_import = [ - "//jaxlib/tools:jaxlib_py_import", - "//jaxlib/tools:jax_cuda_plugin_py_import", - "//jaxlib/tools:jax_cuda_pjrt_py_import", - ], - if_py_import_for_cpu = [ - "//jaxlib/tools:jaxlib_py_import", + "@pypi//jaxlib", ]): - """Adds jaxlib and jaxlib cuda plugin wheels as dependencies instead of depending on sources. + """Adds jaxlib wheels as dependencies instead of depending on sources. This allows us to test prebuilt versions of jaxlib wheels against the rest of the JAX codebase. Args: if_building: the source code targets to depend on in case we don't depend on the jaxlib wheels - if_not_building: the jaxlib wheels to depend on including gpu-specific plugins in case of - gpu-enabled builds - if_not_building_for_cpu: the jaxlib wheels to depend on in case of cpu-only builds - if_py_import: the py_import targets to depend on in case of gpu-enabled builds - if_py_import_for_cpu: the py_import targets to depend on in case of cpu-only builds + if_not_building: the wheels to depend on if we are not depending directly on //jaxlib. 
""" + return select({ + "//jax:config_build_jaxlib_true": if_building, + "//jax:config_build_jaxlib_false": if_not_building, + "//jax:config_build_jaxlib_wheel": [], + }) +def _cpu_test_deps(): + """Returns the test dependencies needed for a CPU-only JAX test.""" return select({ - "//jax:enable_jaxlib_build": if_building, - "//jax_plugins/cuda:disable_jaxlib_for_cpu_build": if_not_building_for_cpu, - "//jax_plugins/cuda:disable_jaxlib_for_cuda12_build": if_not_building, - "//jax_plugins/cuda:enable_py_import_for_cpu_build": if_py_import_for_cpu, - "//jax_plugins/cuda:enable_py_import_for_cuda12_build": if_py_import, + "//jax:config_build_jaxlib_true": [], + "//jax:config_build_jaxlib_false": ["@pypi//jaxlib"], + "//jax:config_build_jaxlib_wheel": ["//jaxlib/tools:jaxlib_py_import"], + }) + +def _gpu_test_deps(): + """Returns the additional dependencies needed for a GPU test.""" + return select({ + "//jax:config_build_jaxlib_true": [ + "//jaxlib/cuda:gpu_only_test_deps", + "//jaxlib/rocm:gpu_only_test_deps", + "//jax_plugins:gpu_plugin_only_test_deps", + # TODO(ybaturina): Remove this once we can add NVSHMEM libraries in the dependencies. + "@pypi//nvidia_nvshmem_cu12", + ], + "//jax:config_build_jaxlib_false": [ + "//jaxlib/tools:pypi_jax_cuda_plugin_with_cuda_deps", + "//jaxlib/tools:pypi_jax_cuda_pjrt_with_cuda_deps", + ], + "//jax:config_build_jaxlib_wheel": [ + "//jaxlib/tools:jax_cuda_plugin_py_import", + "//jaxlib/tools:jax_cuda_pjrt_py_import", + ], + }) + +def _get_jax_test_deps(deps): + """Returns the jax build deps, pypi jax wheel dep, or jax py_import dep for the given backend. + + Args: + deps: the full list of test dependencies + + Returns: + A list of jax test deps. + + If --//jax:build_jax=true, returns jax build deps. + If --//jax:build_jax=false, returns jax pypi wheel dep and transitive pypi test deps. + If --//jax:build_jax=wheel, returns jax py_import dep and transitive pypi test deps. + """ + non_pypi_deps = [d for d in deps if not d.startswith("@pypi//")] + + # A lot of tests don't have explicit dependencies on scipy, ml_dtypes, etc. But the tests + # transitively depends on them via //jax. So we need to make sure that these dependencies are + # included in the test when JAX is built from source. 
+ pypi_deps = depset([d for d in deps if d.startswith("@pypi//")]) + pypi_deps = depset(py_deps([ + "ml_dtypes", + "scipy", + "opt_einsum", + "flatbuffers", + ]), transitive = [pypi_deps]).to_list() + + return pypi_deps + select({ + "//jax:config_build_jax_false": ["//:jax_wheel_with_internal_test_util"], + "//jax:config_build_jax_wheel": ["//:jax_py_import"], + "//jax:config_build_jax_true": non_pypi_deps, }) # buildifier: disable=function-docstring @@ -286,6 +262,9 @@ def jax_multiplatform_test( else: fail("Must set a main file to test multiple source files.") + env = dict(env) + env.setdefault("PYTHONWARNINGS", "error") + for backend in ALL_BACKENDS: if shard_count == None or type(shard_count) == type(0): test_shards = shard_count @@ -298,21 +277,21 @@ def jax_multiplatform_test( test_tags = list(tags) + ["jax_test_%s" % backend] + backend_tags.get(backend, []) if enable_backends != None and backend not in enable_backends and not any([config.startswith(backend) for config in enable_configs]): test_tags.append("manual") + test_deps = _cpu_test_deps() + _get_jax_test_deps([ + "//jax", + "//jax:test_util", + ] + deps) if backend == "gpu": + test_deps += _gpu_test_deps() test_tags += tf_cuda_tests_tags() + elif backend == "tpu": + test_deps += ["@pypi//libtpu"] native.py_test( name = name + "_" + backend, srcs = srcs, args = test_args, env = env, - deps = [ - "//jax", - "//jax:test_util", - ] + deps + if_building_jaxlib([ - "//jaxlib/cuda:gpu_only_test_deps", - "//jaxlib/rocm:gpu_only_test_deps", - "//jax_plugins:gpu_plugin_only_test_deps", - ]), + deps = test_deps, data = data, shard_count = test_shards, tags = test_tags, @@ -362,7 +341,7 @@ def _get_full_wheel_name( free_threaded_suffix = "t" if py_freethreaded.lower() == "yes" else "", ) -def _get_source_distribution_name(package_name, wheel_version): +def _get_source_package_name(package_name, wheel_version): return "{package_name}-{wheel_version}.tar.gz".format( package_name = package_name, wheel_version = wheel_version, @@ -370,7 +349,9 @@ def _get_source_distribution_name(package_name, wheel_version): def _jax_wheel_impl(ctx): include_cuda_libs = ctx.attr.include_cuda_libs[BuildSettingInfo].value + include_nvshmem_libs = ctx.attr.include_nvshmem_libs[BuildSettingInfo].value override_include_cuda_libs = ctx.attr.override_include_cuda_libs[BuildSettingInfo].value + override_include_nvshmem_libs = ctx.attr.override_include_nvshmem_libs[BuildSettingInfo].value output_path = ctx.attr.output_path[BuildSettingInfo].value git_hash = ctx.attr.git_hash[BuildSettingInfo].value py_freethreaded = ctx.attr.py_freethreaded[BuildSettingInfo].value @@ -381,6 +362,11 @@ def _jax_wheel_impl(ctx): " Please provide `--config=cuda_libraries_from_stubs` for bazel build command." + " If you absolutely need to build links directly against the CUDA libraries, provide" + " `--@local_config_cuda//cuda:override_include_cuda_libs=true`.") + if include_nvshmem_libs and not override_include_nvshmem_libs: + fail("JAX wheel shouldn't be built directly against the NVSHMEM libraries." + + " Please provide `--config=cuda_libraries_from_stubs` for bazel build command." 
+ + " If you absolutely need to build links directly against the NVSHMEM libraries," + + " `provide --@local_config_nvshmem//:override_include_nvshmem_libs=true`.") env = {} args = ctx.actions.args() @@ -394,37 +380,47 @@ def _jax_wheel_impl(ctx): no_abi = ctx.attr.no_abi platform_independent = ctx.attr.platform_independent build_wheel_only = ctx.attr.build_wheel_only + build_source_package_only = ctx.attr.build_source_package_only editable = ctx.attr.editable platform_name = ctx.attr.platform_name + + output_dir_path = "" + outputs = [] if editable: output_dir = ctx.actions.declare_directory(output_path + "/" + ctx.attr.wheel_name) - wheel_dir = output_dir.path + output_dir_path = output_dir.path outputs = [output_dir] args.add("--editable") else: - wheel_name = _get_full_wheel_name( - package_name = ctx.attr.wheel_name, - no_abi = no_abi, - platform_independent = platform_independent, - platform_name = platform_name, - cpu_name = cpu, - wheel_version = full_wheel_version, - py_freethreaded = py_freethreaded, - ) - wheel_file = ctx.actions.declare_file(output_path + - "/" + wheel_name) - wheel_dir = wheel_file.path[:wheel_file.path.rfind("/")] - outputs = [wheel_file] - if not build_wheel_only: - source_distribution_name = _get_source_distribution_name( + if build_wheel_only: + wheel_name = _get_full_wheel_name( package_name = ctx.attr.wheel_name, + no_abi = no_abi, + platform_independent = platform_independent, + platform_name = platform_name, + cpu_name = cpu, wheel_version = full_wheel_version, + py_freethreaded = py_freethreaded, ) - source_distribution_file = ctx.actions.declare_file(output_path + - "/" + source_distribution_name) - outputs.append(source_distribution_file) - - args.add("--output_path", wheel_dir) # required argument + wheel_file = ctx.actions.declare_file(output_path + + "/" + wheel_name) + output_dir_path = wheel_file.path[:wheel_file.path.rfind("/")] + outputs = [wheel_file] + if ctx.attr.wheel_name == "jax": + args.add("--build-wheel-only", "True") + if build_source_package_only: + source_package_name = _get_source_package_name( + package_name = ctx.attr.wheel_name, + wheel_version = full_wheel_version, + ) + source_package_file = ctx.actions.declare_file(output_path + + "/" + source_package_name) + output_dir_path = source_package_file.path[:source_package_file.path.rfind("/")] + outputs = [source_package_file] + if ctx.attr.wheel_name == "jax": + args.add("--build-source-package-only", "True") + + args.add("--output_path", output_dir_path) # required argument if not platform_independent: args.add("--cpu", cpu) args.add("--jaxlib_git_hash", git_hash) # required argument @@ -464,28 +460,30 @@ def _jax_wheel_impl(ctx): _jax_wheel = rule( attrs = { "wheel_binary": attr.label( - default = Label("//jaxlib/tools:build_wheel"), + default = Label("//jaxlib/tools:build_wheel_tool"), executable = True, - # b/365588895 Investigate cfg = "exec" for multi platform builds - cfg = "target", + cfg = "exec", ), "wheel_name": attr.string(mandatory = True), "no_abi": attr.bool(default = False), "platform_independent": attr.bool(default = False), - "build_wheel_only": attr.bool(default = True), + "build_wheel_only": attr.bool(mandatory = True, default = True), + "build_source_package_only": attr.bool(mandatory = True, default = False), "editable": attr.bool(default = False), - "cpu": attr.string(mandatory = True), - "platform_name": attr.string(mandatory = True), + "cpu": attr.string(), + "platform_name": attr.string(), "git_hash": attr.label(default = 
Label("//jaxlib/tools:jaxlib_git_hash")), "source_files": attr.label_list(allow_files = True), "output_path": attr.label(default = Label("//jaxlib/tools:output_path")), "enable_cuda": attr.bool(default = False), # A cuda/rocm version is required for gpu wheels; for cpu wheels, it can be an empty string. - "platform_version": attr.string(mandatory = True, default = ""), + "platform_version": attr.string(), "skip_gpu_kernels": attr.bool(default = False), "enable_rocm": attr.bool(default = False), "include_cuda_libs": attr.label(default = Label("@local_config_cuda//cuda:include_cuda_libs")), "override_include_cuda_libs": attr.label(default = Label("@local_config_cuda//cuda:override_include_cuda_libs")), + "include_nvshmem_libs": attr.label(default = Label("@local_config_nvshmem//:include_nvshmem_libs")), + "override_include_nvshmem_libs": attr.label(default = Label("@local_config_nvshmem//:override_include_nvshmem_libs")), "py_freethreaded": attr.label(default = Label("@rules_python//python/config_settings:py_freethreaded")), }, implementation = _jax_wheel_impl, @@ -498,7 +496,6 @@ def jax_wheel( wheel_name, no_abi = False, platform_independent = False, - build_wheel_only = True, editable = False, enable_cuda = False, enable_rocm = False, @@ -509,11 +506,10 @@ def jax_wheel( Common artifact attributes are grouped within a single macro. Args: - name: the name of the wheel + name: the target name wheel_binary: the binary to use to build the wheel wheel_name: the name of the wheel no_abi: whether to build a wheel without ABI - build_wheel_only: whether to build a wheel without source distribution editable: whether to build an editable wheel platform_independent: whether to build a wheel without platform tag enable_cuda: whether to build a cuda wheel @@ -522,7 +518,7 @@ def jax_wheel( source_files: the source files to include in the wheel Returns: - A directory containing the wheel + A wheel file or a wheel directory. """ _jax_wheel( name = name, @@ -530,7 +526,8 @@ def jax_wheel( wheel_name = wheel_name, no_abi = no_abi, platform_independent = platform_independent, - build_wheel_only = build_wheel_only, + build_wheel_only = True, + build_source_package_only = False, editable = editable, enable_cuda = enable_cuda, enable_rocm = enable_rocm, @@ -554,6 +551,34 @@ def jax_wheel( source_files = source_files, ) +def jax_source_package( + name, + source_package_binary, + source_package_name, + source_files = []): + """Create jax source package. + + Common artifact attributes are grouped within a single macro. + + Args: + name: the target name + source_package_binary: the binary to use to build the package + source_package_name: the name of the source package + source_files: the source files to include in the package + + Returns: + A jax source package file. 
+ """ + _jax_wheel( + name = name, + wheel_binary = source_package_binary, + wheel_name = source_package_name, + build_source_package_only = True, + build_wheel_only = False, + platform_independent = True, + source_files = source_files, + ) + jax_test_file_visibility = [] jax_export_file_visibility = [] @@ -566,6 +591,63 @@ def jax_py_test( env = {}, **kwargs): env = dict(env) - if "PYTHONWARNINGS" not in env: - env["PYTHONWARNINGS"] = "error" + env.setdefault("PYTHONWARNINGS", "error") + deps = kwargs.get("deps", []) + test_deps = _cpu_test_deps() + _get_jax_test_deps(deps) + kwargs["deps"] = test_deps py_test(name = name, env = env, **kwargs) + +def pytype_test(name, **kwargs): + deps = kwargs.get("deps", []) + test_deps = _cpu_test_deps() + _get_jax_test_deps(deps) + kwargs["deps"] = test_deps + native.py_test(name = name, **kwargs) + +def if_oss(oss_value, google_value = []): + """Returns one of the arguments based on the non-configurable build env. + + Specifically, it does not return a `select`, and can be used to e.g. + compute elements of list attributes. + """ + _ = (google_value, oss_value) # buildifier: disable=unused-variable + return oss_value + +def wheel_sources( + name, + py_srcs = [], + data_srcs = [], + symlink_data_srcs = [], + hdr_srcs = [], + static_srcs = []): + """Create a filegroup containing the list of source files for a wheel. + + The sources are collected from the static files and from the transitive dependencies of the + given srcs. + + Args: + name: the target name + py_srcs: targets which transitive python dependencies should be included in the wheel + data_srcs: targets which platform-dependent data dependencies should be included in the wheel + symlink_data_srcs: targets which symlinked data dependencies should be included in the wheel + hdr_srcs: targets which transitive header dependencies should be included in the wheel + static_srcs: the platform-independent file dependencies of the wheel + """ + transitive_py_deps(name = "{}_py".format(name), deps = py_srcs) + collect_data_files( + name = "{}_data".format(name), + deps = data_srcs, + symlink_deps = symlink_data_srcs, + ) + transitive_hdrs(name = "{}_hdrs".format(name), deps = hdr_srcs) + native.filegroup(name = name, srcs = [ + ":{}_py".format(name), + ":{}_data".format(name), + ":{}_hdrs".format(name), + ] + static_srcs) + +def if_pypi_cuda_wheel_deps(if_true, if_false = []): + """ select() on whether we're adding pypi CUDA wheel deps. """ + return select({ + "//jaxlib/tools:pypi_cuda_wheel_deps": if_true, + "//conditions:default": if_false, + }) diff --git a/jaxlib/jax_common.json b/jaxlib/jax_common.json new file mode 100644 index 000000000000..61a2c9313897 --- /dev/null +++ b/jaxlib/jax_common.json @@ -0,0 +1,8 @@ +{ + "global": [ + "Wrapped_PyInit_*" + ], + "local": [ + "*" + ] +} diff --git a/jaxlib/jax_jit.cc b/jaxlib/jax_jit.cc new file mode 100644 index 000000000000..e314213e055d --- /dev/null +++ b/jaxlib/jax_jit.cc @@ -0,0 +1,561 @@ +/* Copyright 2020 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +// This files implements the `jax.jit` dispatch and just-in-time feature. +// +// In a nutshell, `Jit(f)` returns a callable that will dispatch (i.e. forward +// based on passed arguments dtypes/shapes/identity) the execution to a +// just-in-time compiled XLA Executable. All of that is done in C++ for +// performance reasons. +// +// This file contains the utilities to: +// (a) inspect arguments and describe their structure, dtype/shapes, etc. +// (b) keep a mapping from function signatures to compiled XLA Executables. + +#include "jaxlib/jax_jit.h" + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/base/attributes.h" +#include "absl/container/inlined_vector.h" +#include "absl/hash/hash.h" +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/optional.h" // IWYU pragma: keep +#include "nanobind/stl/pair.h" // IWYU pragma: keep +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "nanobind/stl/string_view.h" // IWYU pragma: keep +#include "nanobind/stl/vector.h" // IWYU pragma: keep +#include "jaxlib/py_values.h" +#include "jaxlib/pytree.h" +#include "jaxlib/sharding.h" +#include "xla/pjrt/pjrt_client.h" +#include "xla/pjrt/pjrt_layout.h" +#include "xla/pjrt/status_casters.h" +#include "xla/python/nb_absl_inlined_vector.h" // IWYU pragma: keep +#include "xla/python/nb_absl_span.h" // IWYU pragma: keep +#include "xla/python/types.h" +#include "xla/tsl/platform/logging.h" +#include "tsl/profiler/lib/traceme.h" + +namespace jax { + +namespace nb = nanobind; + +// TODO(phawkins): Add support for Tracers. +// TODO(jblespiau): Use absl absl::Status. + +namespace { + +// `thread_local_state.extra_jit_context` is set from Python. It's done when +// loading the Python jax modules on the main-thread. For other threads, we +// need to initialize the field the first time we access `thread_local_state`. +nb::object& initialize_local_state = *new nb::object(); + +} // namespace + +JitState& GlobalJitState() { + // Protected by the GIL. + static JitState& global_state = *new JitState(); + return global_state; +} + +JitState& ThreadLocalJitState() { + // TODO(phawkins): Google style guide forbids thread-local values with + // non-trivial destructors. + ABSL_CONST_INIT thread_local JitState thread_local_state; // NOLINT + DCHECK(PyGILState_Check()); + if (thread_local_state.extra_jit_context == std::nullopt) { + CHECK(initialize_local_state.ptr() != nullptr); + // Avoids reentrant calls to the initialization function. 
+ thread_local_state.extra_jit_context = nb::none(); + initialize_local_state(); + } + return thread_local_state; +} + +bool GetDisableJit() { + auto& global_state = GlobalJitState(); + auto& thread_local_state = ThreadLocalJitState(); + CHECK(global_state.disable_jit.has_value()); + return thread_local_state.disable_jit.value_or(*global_state.disable_jit); +} + +bool GetEnableX64() { + auto& global_state = GlobalJitState(); + auto& thread_local_state = ThreadLocalJitState(); + CHECK(global_state.enable_x64.has_value()); + return thread_local_state.enable_x64.value_or(*global_state.enable_x64); +} + +std::optional GetDefaultDevice() { + auto& global_state = GlobalJitState(); + auto& thread_local_state = ThreadLocalJitState(); + return thread_local_state.default_device.has_value() + ? thread_local_state.default_device + : global_state.default_device; +} + +std::optional GetPostHook() { + auto& global_state = GlobalJitState(); + auto& thread_local_state = ThreadLocalJitState(); + return thread_local_state.post_hook.has_value() ? thread_local_state.post_hook + : global_state.post_hook; +} + +static std::string OptionalDebugString( + const std::optional optional) { + if (optional.has_value()) { + return nb::cast(nb::str(optional.value())); + } else { + return "None"; + } +} + +std::string ArgumentSignature::DebugString() const { + auto py_object_formatter = [](std::string* out, const nb::object& o) { + out->append(nb::cast(nb::str(o))); + }; + auto treedef_formatter = [](std::string* out, const xla::PyTreeDef& d) { + out->append(d.ToString()); + }; + return absl::StrFormat( + "static args (positional + keyword): [%s], " + "static arg keyword names: [%s], " + "dynamic arg signatures (positional + keyword): [%s], " + "dynamic arg shardings: [%s]", + absl::StrJoin(static_args, ",", py_object_formatter), + absl::StrJoin(static_arg_names, ",", py_object_formatter), + absl::StrJoin(dynamic_arg_names, ",", py_object_formatter), + absl::StrJoin(dynamic_arg_treedefs, "| ", treedef_formatter)); +} + +bool ArgumentSignature::operator==(const ArgumentSignature& other) const { + if (dynamic_arg_treedefs != other.dynamic_arg_treedefs) { + return false; + } + auto object_ptr_equality = [](nb::handle a, nb::handle b) { + return a.ptr() == b.ptr(); + }; + if (!absl::c_equal(dynamic_arg_names, other.dynamic_arg_names, + object_ptr_equality)) { + return false; + } + if (!absl::c_equal(static_arg_names, other.static_arg_names, + object_ptr_equality)) { + return false; + } + return absl::c_equal( + static_args, other.static_args, + [](const nb::object& a, const nb::object& b) { + try { + return a.type().ptr() == b.type().ptr() && a.equal(b); + } catch (const nb::python_error& e) { + throw std::invalid_argument(absl::StrCat( + "static arguments should be comparable using __eq__." + "The following error was raised when comparing two objects of " + "types ", + nb::cast(nb::str(a.type())), " and ", + nb::cast(nb::str(b.type())), + ". The error was:\n", e.what())); + } + }); +} + +std::string CallSignature::DebugString() const { + auto py_object_formatter = [](std::string* out, const nb::object& o) { + out->append(nb::cast(nb::str(o))); + }; + auto signature_formatter = [](std::string* out, + const xla::PyArgSignature& s) { + out->append(s.DebugString()); + }; + auto layout_formatter = [](std::string* out, + const std::shared_ptr& l) { + if (l != nullptr) { + out->append(l->ToString()); + } else { + out->append("None"); + } + }; + auto bool_formatter = [](std::string* out, bool o) { + out->append(o ? 
"true" : "false"); + }; + return absl::StrFormat( + "arg signature: %s\n" + "dynamic arg signatures (positional + keyword): %s\n" + "dynamic arg shardings: %s\n" + "dynamic arg layouts: %s\n" + "committed args: %s\n" + "device: %s\n" + "default_device: %s\n" + "jax_enable_x64: %d\n" + "global_extra_jit_context: %s\n" + "thread_local_extra_jit_context: %s\n" + "configs: %s\n", + arg_signature.DebugString(), + absl::StrJoin(dynamic_arg_signatures, ", ", signature_formatter), + absl::StrJoin(dynamic_arg_shardings, ", ", py_object_formatter), + absl::StrJoin(dynamic_arg_layouts, ", ", layout_formatter), + absl::StrJoin(committed_args, ",", bool_formatter), + device != nullptr ? device->DebugString() : "nullptr", + OptionalDebugString(default_device), jax_enable_x64, + OptionalDebugString(global_extra_jit_context), + OptionalDebugString(thread_local_extra_jit_context), + absl::StrJoin(configs, ", ", py_object_formatter)); +} + + +size_t HashShardingForJit(nb::handle sharding) { + auto type = sharding.type(); + + if (type.is(NamedSharding::type())) { + const auto* named_sharding = nb::inst_ptr(sharding); + return absl::Hash()(named_sharding->mesh().ptr()); + } + + if (type.is(GSPMDSharding::type())) { + auto* gspmd_sharding = nb::inst_ptr(sharding); + return gspmd_sharding->Hash(); + } + + if (type.is(SingleDeviceSharding::type())) { + auto* single_device_sharding = nb::inst_ptr(sharding); + return absl::Hash()(single_device_sharding->device().ptr()); + } + + return nb::hash(sharding); +} + +bool EqualShardingsForJit(nb::handle a, nb::handle b) { + if (a.ptr() == b.ptr()){ + return true; + } + + auto a_type = a.type(); + auto b_type = b.type(); + + if (!a_type.is(b_type)) { + return false; + } + + if (a_type.is(NamedSharding::type())) { + auto* a_named_sharding = nb::inst_ptr(a); + auto* b_named_sharding = nb::inst_ptr(b); + return a_named_sharding->mesh().ptr() == b_named_sharding->mesh().ptr() && + *a_named_sharding->spec() == *b_named_sharding->spec() && + a_named_sharding->memory_kind().equal( + b_named_sharding->memory_kind()) && + a_named_sharding->logical_device_ids().equal( + b_named_sharding->logical_device_ids()); + } + + if (a_type.is(GSPMDSharding::type())) { + auto* a_gspmd_sharding = nb::inst_ptr(a); + auto* b_gspmd_sharding = nb::inst_ptr(b); + return *a_gspmd_sharding == *b_gspmd_sharding; + } + + if (a_type.is(SingleDeviceSharding::type())) { + auto* a_single_device_sharding = + nb::inst_ptr(a); + auto* b_single_device_sharding = + nb::inst_ptr(b); + return a_single_device_sharding->device().ptr() == + b_single_device_sharding->device().ptr() && + a_single_device_sharding->memory_kind().equal( + b_single_device_sharding->memory_kind()); + } + + return a.equal(b); +} + +bool CallSignature::operator==(const CallSignature& other) const { + if (arg_signature != other.arg_signature) { + return false; + } + if (dynamic_arg_signatures != other.dynamic_arg_signatures) { + return false; + } + if (device != other.device) { + return false; + } + if (jax_enable_x64 != other.jax_enable_x64) { + return false; + } + if (committed_args != other.committed_args) { + return false; + } + return + // `==` on py:objects is the Python `is`. We need equal. + absl::c_equal(dynamic_arg_shardings, other.dynamic_arg_shardings, + EqualShardingsForJit) && + absl::c_equal(dynamic_arg_layouts, other.dynamic_arg_layouts, + [](const std::shared_ptr& a, + const std::shared_ptr& b) { + return (a && b) ? 
*a == *b : a == b; + }) && + (global_extra_jit_context.has_value() == + other.global_extra_jit_context.has_value()) && + (!global_extra_jit_context.has_value() || + global_extra_jit_context->equal(*other.global_extra_jit_context)) && + (default_device.has_value() == other.default_device.has_value()) && + (!default_device.has_value() || + default_device->equal(*other.default_device)) && + (thread_local_extra_jit_context.has_value() == + other.thread_local_extra_jit_context.has_value()) && + (!thread_local_extra_jit_context.has_value() || + thread_local_extra_jit_context->equal( + *other.thread_local_extra_jit_context)) && + configs.size() == other.configs.size() && + absl::c_equal( + configs, other.configs, + [](const nb::object& a, const nb::object& b) { return a.equal(b); }); +} + +// Filter out static arguments, flatten and concatenate other arguments (i.e. +// dynamic positional and keyword arguments), filling `arguments` in place. +absl::Status ParseArguments( + absl::Span positional_args, + absl::Span keyword_args, nb::handle kwnames, + absl::Span static_argnums, + absl::Span static_argnames, + xla::PyTreeRegistry* pytree_registry, ArgumentSignature& signature, + absl::InlinedVector& flat_dynamic_args) { + tsl::profiler::TraceMe traceme("ParseArguments"); + + DCHECK(absl::c_all_of(static_argnames, [](const nb::str& name) { + return PyUnicode_CHECK_INTERNED(name.ptr()); + })); + + flat_dynamic_args.reserve(positional_args.size() + keyword_args.size()); + if (static_argnums.empty()) { + signature.dynamic_arg_treedefs.reserve(positional_args.size()); + + // Positional arguments. + for (int i = 0; i < positional_args.size(); ++i) { + signature.dynamic_arg_treedefs.emplace_back(pytree_registry); + xla::PyTreeDef& pytree_def = signature.dynamic_arg_treedefs.back(); + pytree_def.Flatten(nb::handle(positional_args[i]), flat_dynamic_args); + } + } else { + signature.dynamic_arg_treedefs.reserve(positional_args.size()); + + // Positional arguments. + int num_positional_args = positional_args.size(); + for (int i = 0; i < positional_args.size(); ++i) { + if (std::find_if(static_argnums.begin(), static_argnums.end(), + [i, num_positional_args](int t) { + return t >= 0 ? i == t : i == t + num_positional_args; + }) == static_argnums.end()) { + signature.dynamic_arg_treedefs.emplace_back(pytree_registry); + xla::PyTreeDef& pytree_def = signature.dynamic_arg_treedefs.back(); + pytree_def.Flatten(positional_args[i], flat_dynamic_args); + } else { + signature.static_args.emplace_back( + nb::borrow(positional_args[i])); + } + } + } + + // Keyword arguments. + if (!keyword_args.empty()) { + std::vector> kwargs(keyword_args.size()); + // We first intern the keys, then sort them (by name, as in the Python path) + // (see also xla::PyTreeDef::Flatten) and then create the signatures. + // TODO(jblespiau): We should be able to sort the keys by interned-key + // pointers, but this requires the Python compilation to do the same. + for (int i = 0; i < keyword_args.size(); ++i) { + // Intern the key if not already interned. 
+ PyObject* key = PyTuple_GET_ITEM(kwnames.ptr(), i); + Py_INCREF(key); + if (!PyUnicode_CHECK_INTERNED(key)) { + PyUnicode_InternInPlace(&key); + } + kwargs[i].first = key; + kwargs[i].second = keyword_args[i]; + } + + std::sort(kwargs.begin(), kwargs.end(), + [](const std::pair& a, + const std::pair& b) { + return a.first < b.first; + }); + auto kwarg_is_static = [&](nb::handle name) { + for (const auto& kw : static_argnames) { + if (kw.ptr() == name.ptr()) return true; + } + return false; + }; + + signature.dynamic_arg_names.reserve(keyword_args.size()); + for (int i = 0; i < keyword_args.size(); ++i) { + if (kwarg_is_static(kwargs[i].first)) { + signature.static_arg_names.push_back( + nb::steal(kwargs[i].first)); + signature.static_args.push_back( + nb::borrow(kwargs[i].second)); + } else { + signature.dynamic_arg_names.push_back( + nb::steal(kwargs[i].first)); + signature.dynamic_arg_treedefs.emplace_back(pytree_registry); + xla::PyTreeDef& pytree_def = signature.dynamic_arg_treedefs.back(); + pytree_def.Flatten(nb::handle(kwargs[i].second.ptr()), + flat_dynamic_args); + } + } + } + return absl::OkStatus(); +} + +void BuildJaxjitSubmodule(nb::module_& m) { + nb::module_ jitlib = m.def_submodule("jax_jit", "Jax C++ jit library"); + + nb::class_ jit_state_(jitlib, "JitState"); + jit_state_.def_rw("disable_jit", &JitState::disable_jit, nb::arg().none()); + jit_state_.def_rw("enable_x64", &JitState::enable_x64, nb::arg().none()); + jit_state_.def_rw("default_device", &JitState::default_device, + nb::arg().none()); + jit_state_.def_rw("extra_jit_context", &JitState::extra_jit_context, + nb::arg().none()); + jit_state_.def_rw("post_hook", &JitState::post_hook, nb::arg().none()); + + jitlib.def( + "global_state", [&]() { return &GlobalJitState(); }, + nb::rv_policy::reference); + jitlib.def( + "thread_local_state", [&]() { return &ThreadLocalJitState(); }, + nb::rv_policy::reference); + + jitlib.def( + "swap_thread_local_state_disable_jit", + [&](std::optional value) -> std::optional { + auto tls = &ThreadLocalJitState(); + auto result = tls->disable_jit; + tls->disable_jit = value; + return result; + }, + nb::arg("value").none(), nb::rv_policy::reference); + + jitlib.def("get_enable_x64", &GetEnableX64); + jitlib.def("set_thread_local_state_initialization_callback", + [](nb::object f) { initialize_local_state = f; }); + + nb::class_ arg_signature(jitlib, "PyArgSignature"); + arg_signature + .def_prop_ro( + "dtype", + [](const xla::PyArgSignature& sig) { + return xla::ValueOrThrow(xla::PrimitiveTypeToNbDtype(sig.dtype)); + }) + .def_prop_ro("shape", + [](const xla::PyArgSignature& sig) { + return xla::SpanToNbTuple(absl::MakeConstSpan(sig.shape)); + }) + .def_ro("weak_type", &xla::PyArgSignature::weak_type); + jitlib.def("_ArgSignatureOfValue", + xla::ValueOrThrowWrapper(xla::PyArgSignatureOfValue)); + + jitlib.def("_is_float0", &xla::IsFloat0); + + nb::class_ argument_signature(jitlib, "ArgumentSignature"); + argument_signature.def_ro("static_args", &ArgumentSignature::static_args) + .def_ro("static_arg_names", &ArgumentSignature::static_arg_names) + .def_ro("dynamic_arg_names", &ArgumentSignature::dynamic_arg_names) + .def_ro("dynamic_arg_treedefs", &ArgumentSignature::dynamic_arg_treedefs) + .def("__repr__", &ArgumentSignature::DebugString) + .def("__str__", &ArgumentSignature::DebugString) + .def("__hash__", + [](const ArgumentSignature& s) { return absl::HashOf(s); }) + .def("__eq__", [](const ArgumentSignature& a, + const ArgumentSignature& b) { return a == b; }) + .def("__ne__", 
[](const ArgumentSignature& a, + const ArgumentSignature& b) { return a != b; }); + + jitlib.def( + "parse_arguments", + [](nb::sequence positional_args, nb::sequence keyword_args, + nb::tuple kwnames, absl::Span static_argnums, + absl::Span static_argnames, + xla::PyTreeRegistry* pytree_registry) { + ArgumentSignature signature; + absl::InlinedVector flat_dynamic_args; + nb::object positional_args_seq = nb::steal(PySequence_Fast( + positional_args.ptr(), "positional_args must be a list or tuple")); + if (!positional_args_seq.ptr()) { + throw nb::python_error(); + } + nb::object keyword_args_seq = nb::steal(PySequence_Fast( + keyword_args.ptr(), "keyword_args must be a list or tuple")); + if (!keyword_args_seq.ptr()) { + throw nb::python_error(); + } + absl::Span positional_args_span = + absl::MakeSpan(PySequence_Fast_ITEMS(positional_args_seq.ptr()), + PySequence_Fast_GET_SIZE(positional_args_seq.ptr())); + absl::Span keyword_args_span = + absl::MakeSpan(PySequence_Fast_ITEMS(keyword_args_seq.ptr()), + PySequence_Fast_GET_SIZE(keyword_args_seq.ptr())); + + // Intern the static argument names. + std::vector static_argnames_interned; + static_argnames_interned.reserve(static_argnames.size()); + for (const nb::str& name : static_argnames) { + PyObject* s = name.inc_ref().ptr(); + PyUnicode_InternInPlace(&s); + static_argnames_interned.push_back(nb::steal(s)); + } + + xla::ThrowIfError( + ParseArguments(positional_args_span, keyword_args_span, kwnames, + static_argnums, static_argnames_interned, + pytree_registry, signature, flat_dynamic_args)); + return std::make_pair(std::move(signature), + std::move(flat_dynamic_args)); + }, + nb::arg("positional_args"), nb::arg("keyword_args"), nb::arg("kwnames"), + nb::arg("static_argnums"), nb::arg("static_argnames"), + nb::arg("pytree_registry"), + R"doc(Parses the arguments to a function as jax.jit would. + +Returns a ArgumentSignature and the flattened dynamic arguments. + +Args: + positional_args: The positional arguments. + keyword_args: The keyword arguments. + kwnames: The keyword names. + static_argnums: The static argument numbers. + static_argnames: The static argument names. + pytree_registry: The pytree registry. +)doc"); +} + +} // namespace jax diff --git a/jaxlib/jax_jit.h b/jaxlib/jax_jit.h new file mode 100644 index 000000000000..0061514e3cfb --- /dev/null +++ b/jaxlib/jax_jit.h @@ -0,0 +1,274 @@ +/* Copyright 2020 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
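The jax_jit.parse_arguments binding documented above splits a call's arguments into static and dynamic parts, mirroring what jax.jit does on every dispatch. The sketch below is purely illustrative: the import paths and the pytree-registry handle are assumptions, and the printed results are only roughly what one would expect given the sorting and interning described in the C++ code.

```python
# Hypothetical sketch of calling the jax_jit.parse_arguments binding.
from jax._src.lib import _jax    # assumed extension module name
import jax

jax_jit = _jax.jax_jit
registry = jax.tree_util.default_registry  # assumed PyTreeRegistry handle

# Mimic a call f(1, 2.0, n=3, flag="fast") where argument 0 is static by
# position and "flag" is static by name.
sig, flat_dynamic = jax_jit.parse_arguments(
    positional_args=[1, 2.0],
    keyword_args=[3, "fast"],
    kwnames=("n", "flag"),
    static_argnums=[0],
    static_argnames=["flag"],
    pytree_registry=registry,
)
print(sig)            # roughly: static args [1, "fast"], dynamic kwarg names ["n"]
print(flat_dynamic)   # roughly: flattened dynamic leaves [2.0, 3]
```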
+==============================================================================*/ + +#ifndef JAXLIB_JAX_JIT_H_ +#define JAXLIB_JAX_JIT_H_ + +#include + +#include +#include +#include +#include +#include +#include +#include + +// placeholder for index annotation headers +#include "absl/container/inlined_vector.h" +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/string_view.h" // IWYU pragma: keep +#include "jaxlib/py_values.h" +#include "jaxlib/python_ref_manager.h" +#include "jaxlib/pytree.h" +#include "jaxlib/sharding.h" +#include "xla/pjrt/pjrt_client.h" +#include "xla/pjrt/pjrt_layout.h" +#include "xla/tsl/platform/logging.h" + +namespace jax { + +// Flags, such as JIT disable and the x64 mode, are controlled by: +// - a global flag value, e.g., associated to --jax_enable_x64 +// - possibly a thread-local value, which initially is std::nullopt and +// overrides the global value if set. The thread-local state is +// used to implement context managers that locally override the global state. +struct JitState { + ~JitState() { + if (extra_jit_context) { + // We likely do not hold the GIL if this JitState is thread-local, so we + // hand the Python object to the global reference manager to destroy. + nanobind::object o = std::move(*extra_jit_context); + xla::GlobalPyRefManager()->AddGarbage(absl::MakeSpan(&o, 1)); + extra_jit_context = std::nullopt; + } + } + + std::optional disable_jit; + std::optional enable_x64; + + // Used to manually set the default device jax should use. May be unset even + // in global state, indicating there is no manual override. + // TODO(skyewm): make this a C++ type when all JAX backends support a single + // C++ device interface + std::optional default_device; + + // Extra context that should be included in the JIT cache key. Must be + // hashable and have an equality defined. + std::optional extra_jit_context; + + // A callback that, if present, is called when a JITted function is executed + // from cache. May be unset even in global state. + std::optional post_hook; +}; + +JitState& GlobalJitState(); + +// Requires the GIL. +JitState& ThreadLocalJitState(); + +// Getters for JitState fields that first look in thread-local state, then +// fallback to global state. +bool GetDisableJit(); +bool GetEnableX64(); + +// TODO(skyewm): return a C++ type when all JAX backends support a single C++ +// device interface +std::optional GetDefaultDevice(); +std::optional GetPostHook(); + +// An ArgumentSignature describes the static arguments to a function call, and +// how the dynamic arguments are related to the arguments. Together with the +// values of the dynamic arguments, this fully describes the arguments. +struct ArgumentSignature { + // A PyTreeDef for each dynamic argument, positional arguments first + // followed by keyword arguments. Keyword arguments are in the order given + // by dynamic_arg_names. + absl::InlinedVector dynamic_arg_treedefs; + + // Dynamic keyword argument names. Interned, and sorted by the keyword + // name. Interned values are safe to compare by pointer. + std::vector dynamic_arg_names; + + // Static arguments. Contains the positional arguments sorted in argument + // order, followed by static keyword arguments in the order given by + // `static_arg_names`. + std::vector static_args; + + // Static keyword argument names. Interned, and sorted by keyword name. 
+ std::vector static_arg_names; + + bool operator==(const ArgumentSignature& other) const; + bool operator!=(const ArgumentSignature& other) const { + return !(*this == other); + } + + std::string DebugString() const; +}; + +template +H AbslHashValue(H h, const ArgumentSignature& s) { + h = H::combine(std::move(h), s.dynamic_arg_treedefs, + s.dynamic_arg_names.size(), s.static_args.size(), + s.static_arg_names.size()); + + for (const auto& name : s.dynamic_arg_names) { + h = H::combine(std::move(h), name.ptr()); + } + for (size_t i = 0; i < s.static_args.size(); ++i) { + const auto& static_arg = s.static_args[i]; + Py_hash_t hash; + try { + hash = nanobind::hash(static_arg); + } catch (const nanobind::python_error& e) { + if (!e.matches(PyExc_TypeError)) throw; + throw std::invalid_argument(absl::StrCat( + "Non-hashable static arguments are not supported. An error occurred " + "while trying to hash an object of type ", + nanobind::cast(nanobind::str(static_arg.type())), + ", ", nanobind::cast(nanobind::str(static_arg)), + ". The error was:\n", e.what(), "\n")); + } + h = H::combine(std::move(h), hash); + } + for (const auto& name : s.static_arg_names) { + h = H::combine(std::move(h), name.ptr()); + } + return h; +} + +// Filter out static arguments, flatten and concatenate other arguments (i.e. +// dynamic positional and keyword arguments), filling `signature` and `flat_dynamic_args` in place. +// Args: +// positional_args: positional arguments +// keyword_args: the values of the keyword arguments +// kwnames: either None or a tuple containing the keyword argument names +// static_argnums: the indices of the static arguments in the positional +// arguments +// static_argnames: the names of the static arguments, which must be interned. +// pytree_registry: the registry to use to convert the arguments to pytrees +// signature: output; describes the static arguments and the identities of the +// dynamic arguments. +// flat_dynamic_args: output; the concatenation of the dynamic positional +// arguments and sorted keyword arguments. +absl::Status ParseArguments( + absl::Span positional_args, + absl::Span keyword_args, nanobind::handle kwnames, + absl::Span static_argnums, + absl::Span static_argnames, + xla::PyTreeRegistry* pytree_registry, ArgumentSignature& signature, + absl::InlinedVector& flat_dynamic_args); + +// The signature of a Python jitted function call, partitioned into: +// - dynamic positional arguments (i.e. positional args which are not static) +// - static positional arguments (i.e. the args associated to static_argnums) +// - keyword arguments +// The CallSignature should unambiguously identify a function call, thus, +// equality is based on: +// (a) Same PyTree for all dynamic positional arguments and keyword arguments +// (b) equality of the arguments and keyword arguments ArgSignature +// (c) equality (delegated to Python) of the static arguments. +struct CallSignature { + // Not part of the signature, but we need it for error messages. + absl::string_view function_name; + + ArgumentSignature arg_signature; + + // Shape and dtype for both the dynamic positional arguments and the keyword + // arguments (sorted by keyword name). + absl::InlinedVector dynamic_arg_signatures; + + // The sharding of the jax.Array arguments. + std::vector dynamic_arg_shardings; + + // The layout of the jax.Array arguments.
+ std::vector> dynamic_arg_layouts; + + absl::InlinedVector committed_args; + + // For JIT, we need this in the key because computation follows the data, so + // we may have multiple executables depending on the devices the data is on. + // This is not the case for PMAP, and is set to `nullptr`. + xla::PjRtDevice* device = nullptr; + bool jax_enable_x64; + + // For JIT on PJIT, we need to fall back to Python whenever default_device + // changes. + std::optional default_device; + + // Opaque additional context that should be included as part of the cache key. + std::optional global_extra_jit_context; + std::optional thread_local_extra_jit_context; + + std::vector configs; + + bool operator==(const CallSignature& other) const; + bool operator!=(const CallSignature& other) const { + return !(*this == other); + } + + std::string DebugString() const; +}; + +// A hash and equality for shardings that may sometimes return different hashes +// for equal values, and may sometimes return "not equal" for equal values. +// These are not correct implementations of `__hash__` and `__eq__` in Python, +// but they are fine for jit/pjit dispatch since they only cause spurious cache +// misses. +size_t HashShardingForJit(nanobind::handle sharding); +bool EqualShardingsForJit(nanobind::handle a, nanobind::handle b); + +template +H AbslHashValue(H h, const CallSignature& s) { + h = H::combine(std::move(h), s.arg_signature, s.dynamic_arg_signatures); + + DCHECK(s.dynamic_arg_shardings.empty() || + s.dynamic_arg_shardings.size() == s.dynamic_arg_signatures.size()); + + DCHECK(s.dynamic_arg_layouts.empty() || + s.dynamic_arg_layouts.size() == s.dynamic_arg_signatures.size()); + + // TODO(chky): For now, we are only hashing the pointer of shardings to avoid + // the slow Python hashing function. Consider implementing a hashing function and + // equality checks in C++ in jax::Sharding and using those here. + for (const auto& sharding : s.dynamic_arg_shardings) { + h = H::combine(std::move(h), HashShardingForJit(sharding)); + } + + for (const auto& layout : s.dynamic_arg_layouts) { + if (layout != nullptr) { + h = H::combine(std::move(h), *layout); + } + } + + h = H::combine(std::move(h), s.committed_args, s.device, s.jax_enable_x64); + + // We do not hash the extra_jit_context fields since calling Python hash + // functions is expensive (~300ns) and we don't expect a large number of + // different contexts. + return h; +} + +// The function to call in `xla.cc` to add the bindings for this module. +void BuildJaxjitSubmodule(nanobind::module_& m); + +} // namespace jax + +#endif // JAXLIB_JAX_JIT_H_ diff --git a/jaxlib/jax_jit_test.py b/jaxlib/jax_jit_test.py new file mode 100644 index 000000000000..c242823566dc --- /dev/null +++ b/jaxlib/jax_jit_test.py @@ -0,0 +1,47 @@ +# Copyright 2024 The JAX Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+# ============================================================================== +"""Tests for jax_jit helper functions.""" + +from absl.testing import absltest + +from jax.jaxlib import xla_client + +jax_jit = xla_client._xla.jax_jit +pytree = xla_client._xla.pytree + +pytree_registry = pytree.default_registry() + + +class JaxJitTest(absltest.TestCase): + + def testParseArguments(self): + sig, args = jax_jit.parse_arguments( + positional_args=[1, 2, 3], + keyword_args=[4, 5], + kwnames=("a", "b"), + static_argnums=[0, 2], + static_argnames=["a"], + pytree_registry=pytree_registry, + ) + self.assertEqual(args, [2, 5]) + self.assertEqual(sig.static_args, [1, 3, 4]) + self.assertEqual(sig.static_arg_names, ["a"]) + _, leaf = pytree_registry.flatten(0) + self.assertEqual(sig.dynamic_arg_names, ["b"]) + self.assertEqual(sig.dynamic_arg_treedefs, [leaf, leaf]) + + +if __name__ == "__main__": + absltest.main() diff --git a/jaxlib/kernel_helpers.h b/jaxlib/kernel_helpers.h index dac0355fbde6..5a053f833ce4 100644 --- a/jaxlib/kernel_helpers.h +++ b/jaxlib/kernel_helpers.h @@ -17,10 +17,10 @@ limitations under the License. #define JAXLIB_KERNEL_HELPERS_H_ #include -#include #include #include "absl/base/casts.h" +#include "absl/status/status.h" #include "absl/status/statusor.h" namespace jax { diff --git a/jaxlib/kernel_nanobind_helpers.h b/jaxlib/kernel_nanobind_helpers.h index fde37e695349..127d89f702c8 100644 --- a/jaxlib/kernel_nanobind_helpers.h +++ b/jaxlib/kernel_nanobind_helpers.h @@ -19,8 +19,8 @@ limitations under the License. #include #include -#include "nanobind/nanobind.h" #include "absl/base/casts.h" +#include "nanobind/nanobind.h" #include "jaxlib/kernel_helpers.h" #include "xla/ffi/api/c_api.h" #include "xla/tsl/python/lib/core/numpy.h" // NOLINT diff --git a/jaxlib/libjax_common.lds b/jaxlib/libjax_common.lds new file mode 100644 index 000000000000..6130415a8d26 --- /dev/null +++ b/jaxlib/libjax_common.lds @@ -0,0 +1,7 @@ +{ + global: + Wrapped_PyInit_*; + + local: + *; +}; diff --git a/jaxlib/libjax_common_darwin.lds b/jaxlib/libjax_common_darwin.lds new file mode 100644 index 000000000000..aed9a1d7512a --- /dev/null +++ b/jaxlib/libjax_common_darwin.lds @@ -0,0 +1 @@ +*Wrapped_PyInit_* diff --git a/jaxlib/mlir.cc b/jaxlib/mlir.cc new file mode 100644 index 000000000000..4c8188b04a7f --- /dev/null +++ b/jaxlib/mlir.cc @@ -0,0 +1,234 @@ +/* Copyright 2021 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "jaxlib/mlir.h" + +#include + +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "llvm/Support/raw_ostream.h" +#include "mlir/Bytecode/BytecodeWriter.h" +#include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OperationSupport.h" +#include "mlir/IR/OwningOpRef.h" +#include "mlir/Parser/Parser.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Support/LogicalResult.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "nanobind/stl/string_view.h" // IWYU pragma: keep +#include "stablehlo/dialect/Serialization.h" +#include "xla/hlo/builder/xla_computation.h" +#include "xla/hlo/translate/stablehlo.h" +#include "xla/mlir_hlo/mhlo/transforms/passes.h" +#include "xla/pjrt/mlir_to_hlo.h" +#include "xla/pjrt/status_casters.h" +#include "xla/python/refine_polymorphic_shapes.h" +#include "xla/service/hlo.pb.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/statusor.h" + +namespace nb = nanobind; + +namespace xla { +namespace { + +std::string PrintModule(mlir::ModuleOp module) { + std::string s; + llvm::raw_string_ostream os(s); + mlir::OpPrintingFlags flags; + flags.enableDebugInfo(); + module->print(os, flags); + return s; +} + +absl::StatusOr SerializeUsingBytecode(mlir::ModuleOp module) { + std::string bytecode; + llvm::raw_string_ostream os(bytecode); + mlir::BytecodeWriterConfig config; + if (mlir::failed(mlir::writeBytecodeToFile(module, os, config))) { + return absl::InvalidArgumentError("mlir::writeBytecodeToFile failed"); + } + return bytecode; +} + +void EnablePrintBeforeAndAfter(mlir::PassManager& pm) { + auto print_before = [](mlir::Pass*, mlir::Operation*) { return true; }; + auto print_after = [](mlir::Pass*, mlir::Operation*) { return true; }; + pm.enableIRPrinting(print_before, print_after); +} + +absl::StatusOr HloToStableHlo(const nb::bytes& hlo_module_proto) { + mlir::MLIRContext context; + if (VLOG_IS_ON(3)) context.disableMultithreading(); + HloModuleProto proto; + proto.ParseFromArray(hlo_module_proto.c_str(), hlo_module_proto.size()); + TF_ASSIGN_OR_RETURN(mlir::OwningOpRef module, + ConvertHloToStablehlo(context, &proto)); + TF_ASSIGN_OR_RETURN(std::string bytecode, SerializeUsingBytecode(*module)); + return nb::bytes(bytecode.data(), bytecode.size()); +} + +// Converts an XlaComputation to a StableHLO mlir::Module string. +// Exists for backwards compatibility. +// TODO(phawkins): port remaining users of XlaComputations to use mlir::Modules +// instead and delete this function. 
+absl::StatusOr PyXlaComputationToMlirModule( + const XlaComputation& computation) { + mlir::MLIRContext context; + if (VLOG_IS_ON(3)) context.disableMultithreading(); + TF_ASSIGN_OR_RETURN(mlir::OwningOpRef module, + ConvertHloToStablehlo(context, &computation.proto())); + return PrintModule(*module); +} + +absl::StatusOr PyMlirModuleToXlaComputation( + absl::string_view mlir_module, bool use_tuple_args, bool return_tuple) { + mlir::MLIRContext context; + TF_ASSIGN_OR_RETURN(mlir::OwningOpRef module, + ParseMlirModuleString(mlir_module, context)); + XlaComputation computation; + TF_RETURN_IF_ERROR(MlirToXlaComputation(*module, computation, use_tuple_args, + return_tuple, + /*use_shardy=*/false)); + return computation; +} + +absl::StatusOr PyMhloToStablehlo(absl::string_view mlir_module) { + mlir::MLIRContext context; + if (VLOG_IS_ON(3)) context.disableMultithreading(); + // JAX can be customized in a way that involves operations from custom + // dialects showing up in JAX IR. + // `ParseMlirModuleString` won't know about these dialects, but that's fine + // since we just want to convert MHLO ops to StableHLO ops here and leave + // everything else unchanged. + // In order to achieve that, we're allowing unregistered dialects here. + context.allowUnregisteredDialects(true); + TF_ASSIGN_OR_RETURN(mlir::OwningOpRef module, + ParseMlirModuleString(mlir_module, context)); + mlir::PassManager pm(&context); + if (VLOG_IS_ON(3)) EnablePrintBeforeAndAfter(pm); + pm.addPass(mlir::mhlo::createHloLegalizeToStablehloPass()); + if (!mlir::succeeded(pm.run(*module))) { + return tsl::errors::InvalidArgument("MHLO => StableHLO failed"); + } + // Use bytecode, passing unregistered dialects with properties causes issues + // when using textual assembly. + TF_ASSIGN_OR_RETURN(std::string bytecode, SerializeUsingBytecode(*module)); + return nb::bytes(bytecode.data(), bytecode.size()); +} + +absl::StatusOr PySerializePortableArtifact( + absl::string_view mlir_module, absl::string_view target) { + mlir::MLIRContext context; + if (VLOG_IS_ON(3)) context.disableMultithreading(); + TF_ASSIGN_OR_RETURN(mlir::OwningOpRef module, + ParseMlirModuleString(mlir_module, context)); + + // Serialize portable artifact + TF_ASSIGN_OR_RETURN( + std::string bytecode, + SerializeUsingVersionedStablehlo(*module, target, /*inplace=*/true)); + return nb::bytes(bytecode.data(), bytecode.size()); +} + +absl::StatusOr PyDeserializePortableArtifact( + const nb::bytes& bytecode_str) { + mlir::MLIRContext context; + mlir::OwningOpRef module = + mlir::stablehlo::deserializePortableArtifact( + absl::string_view(bytecode_str.c_str(), bytecode_str.size()), + &context); + if (!module) + return tsl::errors::InvalidArgument("Failed to deserialize StableHLO"); + return PrintModule(*module); +} + +} // namespace + +void BuildMlirSubmodule(nb::module_& m) { + nb::module_ mlir_module = m.def_submodule("mlir", "MLIR/XLA integration"); + + mlir_module.def("hlo_to_stablehlo", xla::ValueOrThrowWrapper(HloToStableHlo), + nb::arg("computation")); + + mlir_module.def("xla_computation_to_mlir_module", + xla::ValueOrThrowWrapper(PyXlaComputationToMlirModule), + nb::arg("computation")); + mlir_module.def( + "mlir_module_to_xla_computation", + [](const nb::bytes& bytecode, bool use_tuple_args, bool return_tuple) { + return xla::ValueOrThrow(PyMlirModuleToXlaComputation( + absl::string_view(bytecode.c_str(), bytecode.size()), + use_tuple_args, return_tuple)); + }, + nb::arg("mlir_module"), nb::arg("use_tuple_args") = false, + nb::arg("return_tuple") = 
false); + mlir_module.def("mlir_module_to_xla_computation", + xla::ValueOrThrowWrapper(PyMlirModuleToXlaComputation), + nb::arg("mlir_module"), nb::arg("use_tuple_args") = false, + nb::arg("return_tuple") = false); + mlir_module.def( + "mhlo_to_stablehlo", + [](const nb::bytes& bytecode) { + return xla::ValueOrThrow(PyMhloToStablehlo( + absl::string_view(bytecode.c_str(), bytecode.size()))); + }, + nb::arg("mlir_module")); + mlir_module.def("mhlo_to_stablehlo", + xla::ValueOrThrowWrapper(PyMhloToStablehlo), + nb::arg("mlir_module")); + mlir_module.def( + "serialize_portable_artifact", + [](const nb::bytes& bytecode, absl::string_view target) { + return xla::ValueOrThrow(PySerializePortableArtifact( + absl::string_view(bytecode.c_str(), bytecode.size()), target)); + }, + nb::arg("mlir_module"), nb::arg("target")); + mlir_module.def("serialize_portable_artifact", + xla::ValueOrThrowWrapper(PySerializePortableArtifact), + nb::arg("mlir_module"), nb::arg("target")); + mlir_module.def("deserialize_portable_artifact", + xla::ValueOrThrowWrapper(PyDeserializePortableArtifact), + nb::arg("mlir_module")); + mlir_module.def( + "refine_polymorphic_shapes", + [](nb::bytes bytecode, bool enable_shape_assertions, + bool validate_static_shapes, bool enable_shardy) -> nb::bytes { + std::string buffer; + llvm::raw_string_ostream os(buffer); + xla::ThrowIfError(RefinePolymorphicShapes( + absl::string_view(bytecode.c_str(), bytecode.size()), os, + enable_shape_assertions, validate_static_shapes, enable_shardy)); + return nb::bytes(buffer.data(), buffer.size()); + }, + nb::arg("mlir_module"), nb::arg("enable_shape_assertions") = true, + nb::arg("validate_static_shapes") = true, + nb::arg("enable_shardy") = false, + R"(Refines the dynamic shapes for a module. + The "main" function must have static shapes and all the + intermediate dynamic shapes depend only on the input static + shapes. Optionally, also validates that the resulting module has + only static shapes. + )"); +} + +} // namespace xla diff --git a/jaxlib/mlir.h b/jaxlib/mlir.h new file mode 100644 index 000000000000..bcbacb57a485 --- /dev/null +++ b/jaxlib/mlir.h @@ -0,0 +1,28 @@ +/* Copyright 2021 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_MLIR_H_ +#define JAXLIB_MLIR_H_ + +// placeholder for index annotation headers +#include "nanobind/nanobind.h" + +namespace xla { + +void BuildMlirSubmodule(nanobind::module_& m); + +} // namespace xla + +#endif // JAXLIB_MLIR_H_ diff --git a/jaxlib/mlir/BUILD.bazel b/jaxlib/mlir/BUILD.bazel index de7b017355fc..3cc3003c8daa 100644 --- a/jaxlib/mlir/BUILD.bazel +++ b/jaxlib/mlir/BUILD.bazel @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+load("@rules_python//python:py_library.bzl", "py_library") load("//jaxlib:symlink_files.bzl", "symlink_files", "symlink_inputs") package( @@ -65,7 +66,7 @@ symlink_inputs( name = "func_dialect", rule = py_library, symlinked_inputs = {"srcs": { - "dialects": ["@llvm-project//mlir/python:FuncPyFiles"], + "dialects": ["@llvm-project//mlir/python:FuncPyFiles"], }}, deps = [ ":core", @@ -78,7 +79,7 @@ symlink_inputs( name = "vector_dialect", rule = py_library, symlinked_inputs = {"srcs": { - "dialects": ["@llvm-project//mlir/python:VectorOpsPyFiles"], + "dialects": ["@llvm-project//mlir/python:VectorOpsPyFiles"], }}, deps = [ ":core", @@ -91,7 +92,7 @@ symlink_inputs( name = "math_dialect", rule = py_library, symlinked_inputs = {"srcs": { - "dialects": ["@llvm-project//mlir/python:MathOpsPyFiles"], + "dialects": ["@llvm-project//mlir/python:MathOpsPyFiles"], }}, deps = [ ":core", @@ -104,7 +105,7 @@ symlink_inputs( name = "arithmetic_dialect", rule = py_library, symlinked_inputs = {"srcs": { - "dialects": ["@llvm-project//mlir/python:ArithOpsPyFiles"], + "dialects": ["@llvm-project//mlir/python:ArithOpsPyFiles"], }}, deps = [ ":core", @@ -117,7 +118,20 @@ symlink_inputs( name = "memref_dialect", rule = py_library, symlinked_inputs = {"srcs": { - "dialects": ["@llvm-project//mlir/python:MemRefOpsPyFiles"], + "dialects": ["@llvm-project//mlir/python:MemRefOpsPyFiles"], + }}, + deps = [ + ":core", + ":ir", + ":mlir", + ], +) + +symlink_inputs( + name = "control_flow_dialect", + rule = py_library, + symlinked_inputs = {"srcs": { + "dialects": ["@llvm-project//mlir/python:ControlFlowOpsPyFiles"], }}, deps = [ ":core", @@ -130,7 +144,7 @@ symlink_inputs( name = "scf_dialect", rule = py_library, symlinked_inputs = {"srcs": { - "dialects": ["@llvm-project//mlir/python:SCFPyFiles"], + "dialects": ["@llvm-project//mlir/python:SCFPyFiles"], }}, deps = [ ":core", @@ -143,7 +157,7 @@ symlink_inputs( name = "builtin_dialect", rule = py_library, symlinked_inputs = {"srcs": { - "dialects": ["@llvm-project//mlir/python:BuiltinOpsPyFiles"], + "dialects": ["@llvm-project//mlir/python:BuiltinOpsPyFiles"], }}, deps = [ ":core", @@ -157,7 +171,7 @@ symlink_inputs( name = "chlo_dialect", rule = py_library, symlinked_inputs = {"srcs": { - "dialects": ["@stablehlo//:chlo_ops_py_files"], + "dialects": ["@stablehlo//:chlo_ops_py_files"], }}, deps = [ ":core", @@ -171,7 +185,7 @@ symlink_inputs( name = "sparse_tensor_dialect", rule = py_library, symlinked_inputs = {"srcs": { - "dialects": ["@llvm-project//mlir/python:SparseTensorOpsPyFiles"], + "dialects": ["@llvm-project//mlir/python:SparseTensorOpsPyFiles"], }}, deps = [ ":core", @@ -186,7 +200,7 @@ symlink_inputs( name = "mhlo_dialect", rule = py_library, symlinked_inputs = {"srcs": { - "dialects": ["@xla//xla/mlir_hlo:MhloOpsPyFiles"], + "dialects": ["@xla//xla/mlir_hlo:MhloOpsPyFiles"], }}, deps = [ ":core", @@ -228,7 +242,7 @@ symlink_inputs( name = "stablehlo_dialect", rule = py_library, symlinked_inputs = {"srcs": { - "dialects": ["@stablehlo//:stablehlo_ops_py_files"], + "dialects": ["@stablehlo//:stablehlo_ops_py_files"], }}, deps = [ ":core", diff --git a/jaxlib/mlir/_mlir_libs/BUILD.bazel b/jaxlib/mlir/_mlir_libs/BUILD.bazel index fb94837cff37..2f0736c43f11 100644 --- a/jaxlib/mlir/_mlir_libs/BUILD.bazel +++ b/jaxlib/mlir/_mlir_libs/BUILD.bazel @@ -12,13 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+load("@rules_python//python:py_library.bzl", "py_library") load( "//jaxlib:jax.bzl", "if_windows", - "nanobind_extension", - "py_extension", - "windows_cc_shared_mlir_library", ) +load("//jaxlib:pywrap.bzl", "nanobind_pywrap_extension") load("//jaxlib:symlink_files.bzl", "symlink_inputs") package( @@ -33,134 +32,107 @@ COPTS = [ "-frtti", ] -LINKOPTS = select({ - "@xla//xla/tsl:macos": [ - "-Wl,-rpath,@loader_path/", - "-Wl,-rename_section,__TEXT,text_env,__TEXT,__text", - ], - "@xla//xla/tsl:windows": [], - "//conditions:default": [ - "-Wl,-rpath,$$ORIGIN/", - ], -}) - -py_extension( +nanobind_pywrap_extension( name = "_mlir", srcs = [ "@llvm-project//mlir:lib/Bindings/Python/MainModule.cpp", ], copts = COPTS, - linkopts = LINKOPTS, deps = [ - ":jaxlib_mlir_capi_shared_library", - "@llvm-project//mlir:MLIRBindingsPythonCoreNoCAPI", + "@llvm-project//mlir:MLIRBindingsPythonCore", "@llvm-project//mlir:MLIRBindingsPythonNanobindHeaders", "@nanobind", ], ) -py_extension( +nanobind_pywrap_extension( name = "_mlirDialectsGPU", srcs = [ "@llvm-project//mlir:lib/Bindings/Python/DialectGPU.cpp", ], copts = COPTS, - linkopts = LINKOPTS, deps = [ - ":jaxlib_mlir_capi_shared_library", - "@llvm-project//mlir:CAPIGPUHeaders", - "@llvm-project//mlir:CAPIIRHeaders", + "@llvm-project//mlir:CAPIGPU", + "@llvm-project//mlir:CAPIIR", "@llvm-project//mlir:MLIRBindingsPythonNanobindHeaders", "@nanobind", ], ) -py_extension( +nanobind_pywrap_extension( name = "_mlirGPUPasses", srcs = [ "@llvm-project//mlir:lib/Bindings/Python/GPUPasses.cpp", ], copts = COPTS, - linkopts = LINKOPTS, deps = [ - ":jaxlib_mlir_capi_shared_library", - "@llvm-project//mlir:CAPIGPUHeaders", + "@llvm-project//mlir:CAPIGPU", "@llvm-project//mlir:MLIRBindingsPythonNanobindHeaders", "@nanobind", ], ) -py_extension( +nanobind_pywrap_extension( name = "_mlirDialectsNVGPU", srcs = [ "@llvm-project//mlir:lib/Bindings/Python/DialectNVGPU.cpp", ], copts = COPTS, - linkopts = LINKOPTS, deps = [ - ":jaxlib_mlir_capi_shared_library", - "@llvm-project//mlir:CAPIIRHeaders", - "@llvm-project//mlir:CAPINVGPUHeaders", + "@llvm-project//mlir:CAPIIR", + "@llvm-project//mlir:CAPINVGPU", "@llvm-project//mlir:MLIRBindingsPythonNanobindHeaders", "@nanobind", ], ) -py_extension( +nanobind_pywrap_extension( name = "_mlirDialectsLLVM", srcs = [ "@llvm-project//mlir:lib/Bindings/Python/DialectLLVM.cpp", ], copts = COPTS, - linkopts = LINKOPTS, deps = [ - ":jaxlib_mlir_capi_shared_library", - "@llvm-project//mlir:CAPIIRHeaders", - "@llvm-project//mlir:CAPILLVMHeaders", + "@llvm-project//mlir:CAPIIR", + "@llvm-project//mlir:CAPILLVM", "@llvm-project//mlir:MLIRBindingsPythonNanobindHeaders", "@nanobind", ], ) -py_extension( +nanobind_pywrap_extension( name = "_mlirDialectsSparseTensor", srcs = [ "@llvm-project//mlir:lib/Bindings/Python/DialectSparseTensor.cpp", ], copts = COPTS, - linkopts = LINKOPTS, deps = [ - ":jaxlib_mlir_capi_shared_library", - "@llvm-project//mlir:CAPISparseTensorHeaders", + "@llvm-project//mlir:CAPISparseTensor", "@llvm-project//mlir:MLIRBindingsPythonNanobindHeaders", "@nanobind", ], ) -py_extension( +nanobind_pywrap_extension( name = "_mlirSparseTensorPasses", srcs = [ "@llvm-project//mlir:lib/Bindings/Python/SparseTensorPasses.cpp", ], copts = COPTS, - linkopts = LINKOPTS, deps = [ - ":jaxlib_mlir_capi_shared_library", - "@llvm-project//mlir:CAPISparseTensorHeaders", + "@llvm-project//mlir:CAPISparseTensor", "@llvm-project//mlir:MLIRBindingsPythonNanobindHeaders", "@nanobind", ], ) -py_extension( +nanobind_pywrap_extension( name = 
"_mosaic_gpu_ext", srcs = ["mosaic_gpu_ext.cc"], copts = COPTS, - linkopts = LINKOPTS, deps = [ - ":jaxlib_mlir_capi_shared_library", - "//jaxlib/mosaic/dialect/gpu:gpu_dialect_capi_headers", - "@llvm-project//mlir:CAPIIRHeaders", + "//jaxlib/mosaic/dialect/gpu:gpu_dialect_capi", + "@llvm-project//mlir:CAPIIR", "@llvm-project//mlir:MLIRBindingsPythonNanobindHeadersAndDeps", "@nanobind", ], @@ -171,17 +143,15 @@ py_extension( # :jaxlib_mlir_capi_shared_library). This ensures that the RPATH works correctly # across platforms. It's not clear if Windows supports RPATH-like functionality # across different directories at all. -py_extension( +nanobind_pywrap_extension( name = "_tpu_ext", srcs = ["tpu_ext.cc"], copts = COPTS, - linkopts = LINKOPTS, deps = [ - ":jaxlib_mlir_capi_shared_library", - "//jaxlib/mosaic:tpu_dialect_capi_headers", + "//jaxlib/mosaic:tpu_dialect_capi", "@com_google_absl//absl/log:check", "@llvm-project//llvm:Support", - "@llvm-project//mlir:CAPIIRHeaders", + "@llvm-project//mlir:CAPIIR", "@llvm-project//mlir:MLIRBindingsPythonNanobindHeadersAndDeps", "@nanobind", "@xla//xla/python:nb_numpy", @@ -190,7 +160,7 @@ py_extension( ) # This target contains the extension and it's Python dependencies, which are not -# supported by the `py_extension`/`nanobind_extension` macros. +# supported by the `nanobind_pywrap_extension`/`nanobind_extension` macros. py_library( name = "_tpu_ext_lib", deps = [ @@ -200,19 +170,21 @@ py_library( ], ) -nanobind_extension( +nanobind_pywrap_extension( name = "_triton_ext", srcs = ["triton_ext.cc"], copts = COPTS, - linkopts = LINKOPTS, pytype_srcs = ["_triton_ext.pyi"], deps = [ - ":jaxlib_mlir_capi_shared_library", - "//jaxlib/triton:triton_dialect_capi_headers", - "@llvm-project//mlir:CAPIIRHeaders", - "@llvm-project//mlir:MLIRBindingsPythonNanobindHeadersAndDeps", "@nanobind", - ], + ] + if_windows( + [], + [ + "//jaxlib/triton:triton_dialect_capi", + "@llvm-project//mlir:CAPIIR", + "@llvm-project//mlir:MLIRBindingsPythonNanobindHeadersAndDeps", + ], + ), ) symlink_inputs( @@ -229,55 +201,29 @@ symlink_inputs( ], ) -cc_library( - name = "jaxlib_mlir_capi_shims", - srcs = ["jaxlib_mlir_capi_shims.cc"], - hdrs = ["jaxlib_mlir_capi_shims.h"], - deps = [ - "@llvm-project//mlir:BuiltinToLLVMIRTranslation", - "@llvm-project//mlir:CAPIIRHeaders", - "@llvm-project//mlir:GPUPipelines", - "@llvm-project//mlir:GPUToLLVMIRTranslation", - "@llvm-project//mlir:LLVMToLLVMIRTranslation", - "@llvm-project//mlir:MemRefTransforms", - "@llvm-project//mlir:NVVMTarget", - "@llvm-project//mlir:NVVMToLLVMIRTranslation", - ], - alwayslink = 1, -) - -cc_library( - name = "jaxlib_mlir_capi_shims_hdrs", - hdrs = ["jaxlib_mlir_capi_shims.h"], - deps = [ - "@llvm-project//mlir:CAPIIRHeaders", - ], -) - # JAX-specific registrations. 
-py_extension( +nanobind_pywrap_extension( name = "register_jax_dialects", srcs = ["register_jax_dialects.cc"], copts = COPTS, - linkopts = LINKOPTS, deps = [ - ":jaxlib_mlir_capi_shared_library", - "//jaxlib/mosaic/gpu:mlir_capi_headers", - "@llvm-project//mlir:CAPIArithHeaders", - "@llvm-project//mlir:CAPIGPUHeaders", - "@llvm-project//mlir:CAPIIRHeaders", - "@llvm-project//mlir:CAPILLVMHeaders", - "@llvm-project//mlir:CAPIMathHeaders", - "@llvm-project//mlir:CAPIMemRefHeaders", - "@llvm-project//mlir:CAPINVGPUHeaders", - "@llvm-project//mlir:CAPINVVMHeaders", - "@llvm-project//mlir:CAPISCFHeaders", - "@llvm-project//mlir:CAPITransformsHeaders", - "@llvm-project//mlir:CAPIVectorHeaders", + "//jaxlib/mosaic/gpu:mlir_capi", + "@llvm-project//mlir:CAPIArith", + "@llvm-project//mlir:CAPICF", + "@llvm-project//mlir:CAPIGPU", + "@llvm-project//mlir:CAPIIR", + "@llvm-project//mlir:CAPILLVM", + "@llvm-project//mlir:CAPIMath", + "@llvm-project//mlir:CAPIMemRef", + "@llvm-project//mlir:CAPINVGPU", + "@llvm-project//mlir:CAPINVVM", + "@llvm-project//mlir:CAPISCF", + "@llvm-project//mlir:CAPITransforms", + "@llvm-project//mlir:CAPIVector", "@llvm-project//mlir:MLIRBindingsPythonNanobindHeaders", "@local_config_python//:headers", "@nanobind", - "@shardy//shardy/integrations/c:sdy_capi_headers", + "@shardy//shardy/integrations/c:sdy_capi", ], ) @@ -285,20 +231,18 @@ py_extension( # MHLO Extensions ##---------------------------------------------------------------------------## -py_extension( +nanobind_pywrap_extension( name = "_mlirHlo", srcs = [ "@xla//xla/mlir_hlo:bindings/python/MlirHloModule.cc", ], copts = COPTS, - linkopts = LINKOPTS, deps = [ - ":jaxlib_mlir_capi_shared_library", - "@llvm-project//mlir:CAPIIRHeaders", + "@llvm-project//mlir:CAPIIR", "@llvm-project//mlir:MLIRBindingsPythonNanobindHeaders", "@local_config_python//:headers", "@nanobind", - "@xla//xla/mlir_hlo:CAPIHeaders", + "@xla//xla/mlir_hlo:CAPI", ], ) @@ -306,21 +250,19 @@ py_extension( # Shardy Extensions ##---------------------------------------------------------------------------## -py_extension( +nanobind_pywrap_extension( name = "_sdy", srcs = [ "@shardy//shardy/integrations/python/ir:sdy_module.cc", ], copts = COPTS, - linkopts = LINKOPTS, deps = [ - ":jaxlib_mlir_capi_shared_library", - "@llvm-project//mlir:CAPIIRHeaders", + "@llvm-project//mlir:CAPIIR", "@llvm-project//mlir:IR", "@llvm-project//mlir:MLIRBindingsPythonNanobindHeaders", "@local_config_python//:headers", "@nanobind", - "@shardy//shardy/integrations/c:sdy_capi_headers", + "@shardy//shardy/integrations/c:sdy_capi", ], ) @@ -328,115 +270,33 @@ py_extension( # Stablehlo Extensions ##---------------------------------------------------------------------------## -py_extension( +nanobind_pywrap_extension( name = "_chlo", srcs = [ "@stablehlo//:chlo_py_api_files", ], copts = COPTS, - linkopts = LINKOPTS, deps = [ - ":jaxlib_mlir_capi_shared_library", - "@llvm-project//mlir:CAPIIRHeaders", + "@llvm-project//mlir:CAPIIR", "@llvm-project//mlir:MLIRBindingsPythonNanobindHeaders", "@local_config_python//:headers", "@nanobind", - "@stablehlo//:chlo_capi_headers", + "@stablehlo//:chlo_capi", ], ) -py_extension( +nanobind_pywrap_extension( name = "_stablehlo", srcs = [ "@stablehlo//:stablehlo_py_api_files", ], copts = COPTS, - linkopts = LINKOPTS, deps = [ - ":jaxlib_mlir_capi_shared_library", "@llvm-project//llvm:Support", - "@llvm-project//mlir:CAPIIRHeaders", + "@llvm-project//mlir:CAPIIR", "@llvm-project//mlir:MLIRBindingsPythonNanobindHeaders", 
"@local_config_python//:headers", "@nanobind", - "@stablehlo//:stablehlo_capi_headers", - ], -) - -# Shared C++ extension library - -cc_library( - name = "jaxlib_mlir_capi_shared_library", - srcs = select({ - "@xla//xla/tsl:windows": [":jaxlib_mlir_capi.dll"], - "@xla//xla/tsl:macos": [":libjaxlib_mlir_capi.dylib"], - "//conditions:default": [":libjaxlib_mlir_capi.so"], - }), - deps = select({ - "@xla//xla/tsl:windows": [":jaxlib_mlir_capi_dll"], - "//conditions:default": [], - }), -) - -cc_library( - name = "jaxlib_mlir_capi_objects", - deps = [ - "//jaxlib/mosaic:tpu_dialect_capi_objects", - "//jaxlib/mosaic/dialect/gpu:gpu_dialect_capi_objects", - "//jaxlib/mosaic/gpu:mlir_capi_objects", - "@llvm-project//mlir:CAPIArithObjects", - "@llvm-project//mlir:CAPIGPUObjects", - "@llvm-project//mlir:CAPIIRObjects", - "@llvm-project//mlir:CAPILLVMObjects", - "@llvm-project//mlir:CAPIMathObjects", - "@llvm-project//mlir:CAPIMemRefObjects", - "@llvm-project//mlir:CAPINVGPUObjects", - "@llvm-project//mlir:CAPINVVMObjects", - "@llvm-project//mlir:CAPISCFObjects", - "@llvm-project//mlir:CAPISparseTensorObjects", - "@llvm-project//mlir:CAPITransformsObjects", - "@llvm-project//mlir:CAPIVectorObjects", - "@llvm-project//mlir:MLIRBindingsPythonCAPIObjects", - "@shardy//shardy/integrations/c:sdy_capi_objects", - "@stablehlo//:chlo_capi_objects", - "@stablehlo//:stablehlo_capi_objects", - "@xla//xla/mlir_hlo:CAPIObjects", - ] + if_windows( - [], - [ - "//jaxlib/triton:triton_dialect_capi_objects", - ], - ), -) - -cc_binary( - name = "libjaxlib_mlir_capi.so", - linkopts = [ - "-Wl,-soname=libjaxlib_mlir_capi.so", - "-Wl,-rpath='$$ORIGIN'", - ], - linkshared = 1, - deps = [":jaxlib_mlir_capi_objects"], -) - -cc_binary( - name = "libjaxlib_mlir_capi.dylib", - linkopts = [ - "-Wl,-rpath,@loader_path/", - "-Wl,-install_name,@loader_path/libjaxlib_mlir_capi.dylib", - ], - linkshared = 1, - deps = [":jaxlib_mlir_capi_objects"], -) - -windows_cc_shared_mlir_library( - name = "jaxlib_mlir_capi_dll", - out = "jaxlib_mlir_capi.dll", - exported_symbol_prefixes = [ - "mlir", - "chlo", - "sdy", - "stablehlo", + "@stablehlo//:stablehlo_capi", ], - deps = [":jaxlib_mlir_capi_objects"], ) diff --git a/jaxlib/mlir/_mlir_libs/_triton_ext.pyi b/jaxlib/mlir/_mlir_libs/_triton_ext.pyi index 1e1a67405113..93a82010043c 100644 --- a/jaxlib/mlir/_mlir_libs/_triton_ext.pyi +++ b/jaxlib/mlir/_mlir_libs/_triton_ext.pyi @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from mlir import ir +from jaxlib.mlir import ir def register_dialect(context: ir.Context, load: bool = ...) -> None: ... 
diff --git a/jaxlib/mlir/_mlir_libs/mosaic_gpu_ext.cc b/jaxlib/mlir/_mlir_libs/mosaic_gpu_ext.cc index c73084abc99d..2751719fc61d 100644 --- a/jaxlib/mlir/_mlir_libs/mosaic_gpu_ext.cc +++ b/jaxlib/mlir/_mlir_libs/mosaic_gpu_ext.cc @@ -138,25 +138,4 @@ NB_MODULE(_mosaic_gpu_ext, m) { .def_property_readonly("swizzle", [](MlirAttribute self) { return mlirMosaicGpuSwizzleTransformAttrGetSwizzle(self); }); - - mlir::python::nanobind_adaptors::mlir_attribute_subclass( - m, "LayoutAttr", mlirMosaicGpuIsALayoutAttr) - .def_classmethod( - "get", - [](nb::object cls, int32_t num_dimensions, - std::vector& transforms, MlirContext ctx) { - return cls(mlirMosaicGpuLayoutAttrGet( - ctx, num_dimensions, transforms.data(), transforms.size())); - }, - nb::arg("cls"), nb::arg("num_dimensions"), nb::arg("transforms"), - nb::arg("context").none() = nb::none(), - "Creates a LayoutAttr with the given transforms.") - .def_property_readonly("transforms", [](MlirAttribute self) { - std::vector result; - for (int i = 0; i < mlirMosaicGpuLayoutAttrGetTransformsSize(self); - ++i) { - result.push_back(mlirMosaicGpuLayoutAttrGetTransform(self, i)); - } - return result; - }); } diff --git a/jaxlib/mlir/_mlir_libs/register_jax_dialects.cc b/jaxlib/mlir/_mlir_libs/register_jax_dialects.cc index 9da841acc7de..3c2604640a19 100644 --- a/jaxlib/mlir/_mlir_libs/register_jax_dialects.cc +++ b/jaxlib/mlir/_mlir_libs/register_jax_dialects.cc @@ -1,28 +1,44 @@ +/* Copyright 2022 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + // Registers MLIR dialects used by JAX. // This module is called by mlir/__init__.py during initialization. 
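For orientation, a rough sketch of how a dialect-registration module like the one below is typically driven from Python. The register_dialects entry-point name and the import path are assumptions (they are not shown in this hunk); ir.DialectRegistry, Context.append_dialect_registry, and Context.load_all_available_dialects are standard MLIR Python binding APIs.

# Hypothetical sketch; entry-point name and import path are assumed.
from jaxlib.mlir import ir
from jaxlib.mlir._mlir_libs import register_jax_dialects

registry = ir.DialectRegistry()
register_jax_dialects.register_dialects(registry)  # assumed entry point
ctx = ir.Context()
ctx.append_dialect_registry(registry)
ctx.load_all_available_dialects()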
#include -#include "mlir-c/Dialect/Arith.h" -#include "mlir-c/Dialect/Func.h" -#include "mlir-c/Dialect/GPU.h" -#include "mlir-c/Dialect/LLVM.h" -#include "mlir-c/Dialect/Math.h" -#include "mlir-c/Dialect/MemRef.h" -#include "mlir-c/Dialect/NVGPU.h" -#include "mlir-c/Dialect/NVVM.h" -#include "mlir-c/Dialect/SCF.h" -#include "mlir-c/Dialect/Vector.h" +#include "mlir-c/Dialect/Arith.h" // IWYU pragma: keep +#include "mlir-c/Dialect/ControlFlow.h" +#include "mlir-c/Dialect/Func.h" // IWYU pragma: keep +#include "mlir-c/Dialect/GPU.h" // IWYU pragma: keep +#include "mlir-c/Dialect/LLVM.h" // IWYU pragma: keep +#include "mlir-c/Dialect/Math.h" // IWYU pragma: keep +#include "mlir-c/Dialect/MemRef.h" // IWYU pragma: keep +#include "mlir-c/Dialect/NVGPU.h" // IWYU pragma: keep +#include "mlir-c/Dialect/NVVM.h" // IWYU pragma: keep +#include "mlir-c/Dialect/SCF.h" // IWYU pragma: keep +#include "mlir-c/Dialect/Vector.h" // IWYU pragma: keep +#include "mlir-c/IR.h" #include "mlir-c/Transforms.h" -#include "mlir/Bindings/Python/NanobindAdaptors.h" +#include "mlir/Bindings/Python/NanobindAdaptors.h" // IWYU pragma: keep #include "shardy/integrations/c/passes.h" #include "jaxlib/mosaic/gpu/integrations/c/passes.h" - namespace nb = nanobind; -#define REGISTER_DIALECT(name) \ - MlirDialectHandle name##_dialect = mlirGetDialectHandle__##name##__(); \ - mlirDialectHandleInsertDialect(name##_dialect, registry) +#define REGISTER_DIALECT(name) \ + MlirDialectHandle name##_dialect = mlirGetDialectHandle__##name##__(); \ + mlirDialectHandleInsertDialect(name##_dialect, registry) NB_MODULE(register_jax_dialects, m) { m.doc() = "Registers upstream MLIR dialects used by JAX."; @@ -35,6 +51,7 @@ NB_MODULE(register_jax_dialects, m) { REGISTER_DIALECT(scf); REGISTER_DIALECT(vector); // For Mosaic GPU + REGISTER_DIALECT(cf); REGISTER_DIALECT(gpu); REGISTER_DIALECT(nvgpu); REGISTER_DIALECT(nvvm); diff --git a/jaxlib/mlir/_mlir_libs/tpu_ext.cc b/jaxlib/mlir/_mlir_libs/tpu_ext.cc index 2b5ec898ad3e..8f751693e451 100644 --- a/jaxlib/mlir/_mlir_libs/tpu_ext.cc +++ b/jaxlib/mlir/_mlir_libs/tpu_ext.cc @@ -19,7 +19,6 @@ limitations under the License. #include #include #include -#include #include #include #include @@ -27,9 +26,8 @@ limitations under the License. #include #include -#include "llvm/ADT/ArrayRef.h" +#include "absl/types/span.h" #include "llvm/ADT/SmallVector.h" -#include "llvm/ADT/SmallVectorExtras.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringRef.h" #include "mlir-c/AffineMap.h" @@ -41,15 +39,14 @@ limitations under the License. 
#include "mlir-c/Support.h" #include "mlir/Bindings/Python/NanobindAdaptors.h" // IWYU pragma: keep // clang-format off -#include "mlir-c/Bindings/Python/Interop.h" // clang-format on +#include "absl/log/check.h" #include "nanobind/nanobind.h" #include "nanobind/stl/optional.h" // IWYU pragma: keep -#include "nanobind/stl/pair.h" // IWYU pragma: keep -#include "nanobind/stl/string.h" // IWYU pragma: keep -#include "nanobind/stl/variant.h" // IWYU pragma: keep -#include "nanobind/stl/vector.h" // IWYU pragma: keep -#include "absl/log/check.h" +#include "nanobind/stl/pair.h" // IWYU pragma: keep +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "nanobind/stl/variant.h" // IWYU pragma: keep +#include "nanobind/stl/vector.h" // IWYU pragma: keep #include "jaxlib/mosaic/dialect/tpu/integrations/c/tpu_dialect.h" #include "xla/python/nb_numpy.h" #include "xla/tsl/python/lib/core/numpy.h" diff --git a/jaxlib/mlir/_mlir_libs/triton_ext.cc b/jaxlib/mlir/_mlir_libs/triton_ext.cc index 2a13c40d963f..687ceec4cd33 100644 --- a/jaxlib/mlir/_mlir_libs/triton_ext.cc +++ b/jaxlib/mlir/_mlir_libs/triton_ext.cc @@ -13,6 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#ifndef _WIN32 + +#include #include #include "mlir-c/IR.h" @@ -73,3 +76,11 @@ NB_MODULE(_triton_ext, m) { return encoding; }); } + +#else // _WIN32 + +#include "nanobind/nanobind.h" + +NB_MODULE(_triton_ext, m) {} + +#endif // _WIN32 diff --git a/jaxlib/mosaic/BUILD b/jaxlib/mosaic/BUILD index 4cc2530dd7ca..a212f7afb8bd 100644 --- a/jaxlib/mosaic/BUILD +++ b/jaxlib/mosaic/BUILD @@ -60,9 +60,9 @@ cc_library( ]), # compatible with libtpu deps = [ + ":pass_boilerplate", + ":serde", ":tpu_inc_gen", - "//jaxlib:pass_boilerplate", - "//jaxlib/mosaic:serde", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/hash", @@ -95,75 +95,37 @@ cc_library( "@xla//xla:shape_util", "@xla//xla:util", "@xla//xla/tsl/platform:errors", + "@xla//xla/tsl/platform:statusor", ] + mosaic_extension_deps, ) gentbl_cc_library( name = "tpu_inc_gen", # compatible with libtpu - tbl_outs = [ - ( - ["-gen-op-decls"], - "dialect/tpu/tpu_ops.h.inc", - ), - ( - ["-gen-op-defs"], - "dialect/tpu/tpu_ops.cc.inc", - ), - ( - ["-gen-dialect-decls"], - "dialect/tpu/tpu_dialect.h.inc", - ), - ( - ["-gen-dialect-defs"], - "dialect/tpu/tpu_dialect.cc.inc", - ), - ( - ["-gen-enum-decls"], - "dialect/tpu/tpu_enums.h.inc", - ), - ( - ["-gen-enum-defs"], - "dialect/tpu/tpu_enums.cc.inc", - ), - ( - ["-gen-attrdef-decls"], - "dialect/tpu/tpu_attr_defs.h.inc", - ), - ( - ["-gen-attrdef-defs"], - "dialect/tpu/tpu_attr_defs.cc.inc", - ), - ( - ["-gen-typedef-decls"], - "dialect/tpu/tpu_type_defs.h.inc", - ), - ( - ["-gen-typedef-defs"], - "dialect/tpu/tpu_type_defs.cc.inc", - ), - ( - [ - "-gen-pass-decls", - "-name=TPU", - ], - "dialect/tpu/tpu_passes.h.inc", - ), - ( - [ - "-gen-pass-capi-header", - "--prefix=TPU", - ], - "dialect/tpu/integrations/c/tpu_passes.capi.h.inc", - ), - ( - [ - "-gen-pass-capi-impl", - "--prefix=TPU", - ], - "dialect/tpu/integrations/c/tpu_passes.capi.cc.inc", - ), - ], + tbl_outs = { + "dialect/tpu/tpu_ops.h.inc": ["-gen-op-decls"], + "dialect/tpu/tpu_ops.cc.inc": ["-gen-op-defs"], + "dialect/tpu/tpu_dialect.h.inc": ["-gen-dialect-decls"], + "dialect/tpu/tpu_dialect.cc.inc": ["-gen-dialect-defs"], + "dialect/tpu/tpu_enums.h.inc": ["-gen-enum-decls"], 
+ "dialect/tpu/tpu_enums.cc.inc": ["-gen-enum-defs"], + "dialect/tpu/tpu_attr_defs.h.inc": ["-gen-attrdef-decls"], + "dialect/tpu/tpu_attr_defs.cc.inc": ["-gen-attrdef-defs"], + "dialect/tpu/tpu_type_defs.h.inc": ["-gen-typedef-decls"], + "dialect/tpu/tpu_type_defs.cc.inc": ["-gen-typedef-defs"], + "dialect/tpu/tpu_passes.h.inc": [ + "-gen-pass-decls", + "-name=TPU", + ], + "dialect/tpu/integrations/c/tpu_passes.capi.h.inc": [ + "-gen-pass-capi-header", + "--prefix=TPU", + ], + "dialect/tpu/integrations/c/tpu_passes.capi.cc.inc": [ + "-gen-pass-capi-impl", + "--prefix=TPU", + ], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "dialect/tpu/tpu.td", deps = [":tpu_td_files"], @@ -270,6 +232,21 @@ cc_test( ], ) +cc_test( + name = "tpu_ops_verification_test", + srcs = ["dialect/tpu/tpu_ops_verification_test.cc"], + deps = [ + ":tpu_dialect", + "//testing/base/public:gunit_main", + "@com_google_absl//absl/status", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:MemRefDialect", + "@llvm-project//mlir:Support", + "@xla//xla/mlir/utils:error_util", + ], +) + filegroup( name = "extension_srcs", srcs = [ @@ -279,6 +256,17 @@ filegroup( # compatible with libtpu ) +cc_library( + name = "pass_boilerplate", + hdrs = ["pass_boilerplate.h"], + # compatible with libtpu + deps = [ + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + ], +) + cc_library( name = "serde", srcs = ["serde.cc"], diff --git a/jaxlib/mosaic/dialect/gpu/BUILD b/jaxlib/mosaic/dialect/gpu/BUILD index e21c8756a4e2..854955a60493 100644 --- a/jaxlib/mosaic/dialect/gpu/BUILD +++ b/jaxlib/mosaic/dialect/gpu/BUILD @@ -39,66 +39,36 @@ td_library( gentbl_cc_library( name = "mosaic_gpu_inc_gen", - tbl_outs = [ - ( - [ - "-gen-dialect-decls", - "-dialect=mosaic_gpu", - ], - "mosaic_gpu_dialect.h.inc", - ), - ( - [ - "-gen-dialect-defs", - "-dialect=mosaic_gpu", - ], - "mosaic_gpu_dialect.cc.inc", - ), - ( - ["-gen-op-decls"], - "mosaic_gpu_ops.h.inc", - ), - ( - ["-gen-op-defs"], - "mosaic_gpu_ops.cc.inc", - ), - ( - [ - "-gen-typedef-decls", - "--typedefs-dialect=mosaic_gpu", - ], - "mosaic_gpu_types.h.inc", - ), - ( - [ - "-gen-typedef-defs", - "--typedefs-dialect=mosaic_gpu", - ], - "mosaic_gpu_types.cc.inc", - ), - ( - ["-gen-enum-decls"], - "mosaic_gpu_enums.h.inc", - ), - ( - ["-gen-enum-defs"], - "mosaic_gpu_enums.cc.inc", - ), - ( - [ - "-gen-attrdef-decls", - "--attrdefs-dialect=mosaic_gpu", - ], - "mosaic_gpu_attrdefs.h.inc", - ), - ( - [ - "-gen-attrdef-defs", - "--attrdefs-dialect=mosaic_gpu", - ], - "mosaic_gpu_attrdefs.cc.inc", - ), - ], + tbl_outs = { + "mosaic_gpu_dialect.h.inc": [ + "-gen-dialect-decls", + "-dialect=mosaic_gpu", + ], + "mosaic_gpu_dialect.cc.inc": [ + "-gen-dialect-defs", + "-dialect=mosaic_gpu", + ], + "mosaic_gpu_ops.h.inc": ["-gen-op-decls"], + "mosaic_gpu_ops.cc.inc": ["-gen-op-defs"], + "mosaic_gpu_types.h.inc": [ + "-gen-typedef-decls", + "--typedefs-dialect=mosaic_gpu", + ], + "mosaic_gpu_types.cc.inc": [ + "-gen-typedef-defs", + "--typedefs-dialect=mosaic_gpu", + ], + "mosaic_gpu_enums.h.inc": ["-gen-enum-decls"], + "mosaic_gpu_enums.cc.inc": ["-gen-enum-defs"], + "mosaic_gpu_attrdefs.h.inc": [ + "-gen-attrdef-decls", + "--attrdefs-dialect=mosaic_gpu", + ], + "mosaic_gpu_attrdefs.cc.inc": [ + "-gen-attrdef-defs", + "--attrdefs-dialect=mosaic_gpu", + ], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "mosaic_gpu.td", deps = [ @@ -119,6 +89,7 @@ cc_library( "@llvm-project//llvm:Support", 
"@llvm-project//mlir:ArithDialect", "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:GPUDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:InferTypeOpInterface", "@llvm-project//mlir:LLVMCommonConversion", @@ -127,7 +98,7 @@ cc_library( "@llvm-project//mlir:MemRefUtils", "@llvm-project//mlir:SCFUtils", "@llvm-project//mlir:Support", - "@tsl//tsl/platform:statusor", + "@xla//xla/tsl/platform:statusor", ], ) @@ -151,7 +122,7 @@ cc_test( "@llvm-project//mlir:MemRefDialect", "@llvm-project//mlir:SCFUtils", "@llvm-project//mlir:Support", - "@tsl//tsl/platform:errors", + "@xla//xla/tsl/platform:errors", ], ) diff --git a/jaxlib/mosaic/dialect/gpu/integrations/c/attributes.cc b/jaxlib/mosaic/dialect/gpu/integrations/c/attributes.cc index eac1d104f07f..523b14e425c9 100644 --- a/jaxlib/mosaic/dialect/gpu/integrations/c/attributes.cc +++ b/jaxlib/mosaic/dialect/gpu/integrations/c/attributes.cc @@ -1,7 +1,21 @@ +/* Copyright 2025 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + #include "jaxlib/mosaic/dialect/gpu/integrations/c/attributes.h" #include -#include #include "mlir-c/IR.h" #include "mlir/CAPI/IR.h" @@ -82,36 +96,3 @@ int32_t mlirMosaicGpuSwizzleTransformAttrGetSwizzle(MlirAttribute attr) { .getSwizzle() .getValue()); } - -//===----------------------------------------------------------------------===// -// LayoutAttr -//===----------------------------------------------------------------------===// - -bool mlirMosaicGpuIsALayoutAttr(MlirAttribute attr) { - return mlir::isa(unwrap(attr)); -} - -MlirAttribute mlirMosaicGpuLayoutAttrGet(MlirContext ctx, - int32_t num_dimensions, - MlirAttribute* transforms, - int32_t transforms_size) { - std::vector unwrapped_transforms; - unwrapped_transforms.reserve(transforms_size); - for (int i = 0; i < transforms_size; ++i) { - unwrapped_transforms.push_back(unwrap(transforms[i])); - } - return wrap(mosaic_gpu::LayoutAttr::get(unwrap(ctx), num_dimensions, - unwrapped_transforms)); -} - -int32_t mlirMosaicGpuLayoutAttrGetTransformsSize(MlirAttribute attr) { - return mlir::cast(unwrap(attr)) - .getTransforms() - .size(); -} - -MlirAttribute mlirMosaicGpuLayoutAttrGetTransform(MlirAttribute attr, - int32_t index) { - return wrap( - mlir::cast(unwrap(attr)).getTransforms()[index]); -} \ No newline at end of file diff --git a/jaxlib/mosaic/dialect/gpu/integrations/c/attributes.h b/jaxlib/mosaic/dialect/gpu/integrations/c/attributes.h index 3b8425b6b142..3221b9220e5d 100644 --- a/jaxlib/mosaic/dialect/gpu/integrations/c/attributes.h +++ b/jaxlib/mosaic/dialect/gpu/integrations/c/attributes.h @@ -69,22 +69,6 @@ mlirMosaicGpuSwizzleTransformAttrGet(MlirContext ctx, int32_t swizzle); MLIR_CAPI_EXPORTED int32_t mlirMosaicGpuSwizzleTransformAttrGetSwizzle(MlirAttribute attr); -//===----------------------------------------------------------------------===// -// LayoutAttr -//===----------------------------------------------------------------------===// - -MLIR_CAPI_EXPORTED bool 
mlirMosaicGpuIsALayoutAttr(MlirAttribute attr); - -MLIR_CAPI_EXPORTED MlirAttribute -mlirMosaicGpuLayoutAttrGet(MlirContext ctx, int32_t num_dimensions, - MlirAttribute* transforms, int32_t transforms_size); - -MLIR_CAPI_EXPORTED int32_t -mlirMosaicGpuLayoutAttrGetTransformsSize(MlirAttribute attr); - -MLIR_CAPI_EXPORTED MlirAttribute -mlirMosaicGpuLayoutAttrGetTransform(MlirAttribute attr, int32_t index); - #ifdef __cplusplus } #endif diff --git a/jaxlib/mosaic/dialect/gpu/integrations/c/gpu_dialect.h b/jaxlib/mosaic/dialect/gpu/integrations/c/gpu_dialect.h index bb6cf6e3af4a..5fd0ce7a4f7a 100644 --- a/jaxlib/mosaic/dialect/gpu/integrations/c/gpu_dialect.h +++ b/jaxlib/mosaic/dialect/gpu/integrations/c/gpu_dialect.h @@ -18,7 +18,7 @@ limitations under the License. #include -#include "mlir/CAPI/Registration.h" +#include "mlir-c/IR.h" #ifdef __cplusplus extern "C" { diff --git a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.cc b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.cc index a1e7b571d20e..2f3bfb808981 100644 --- a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.cc +++ b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.cc @@ -18,7 +18,13 @@ limitations under the License. #include #include +#include "absl/algorithm/container.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/TypeSwitch.h" // IWYU pragma: keep #include "llvm/Support/Casting.h" #include "llvm/Support/FormatVariadic.h" @@ -26,13 +32,17 @@ limitations under the License. #include "mlir/Conversion/LLVMCommon/MemRefBuilder.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h" #include "mlir/Dialect/SCF/Utils/Utils.h" +#include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" #include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Diagnostics.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/DialectImplementation.h" // IWYU pragma: keep #include "mlir/IR/ImplicitLocOpBuilder.h" @@ -43,15 +53,7 @@ limitations under the License. #include "mlir/IR/Value.h" #include "mlir/IR/ValueRange.h" #include "mlir/Support/LLVM.h" -#include "absl/algorithm/container.h" -#include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/string_view.h" -#include "mlir/include/mlir/IR/BuiltinTypeInterfaces.h" -#include "mlir/include/mlir/IR/BuiltinTypes.h" -#include "mlir/include/mlir/IR/Diagnostics.h" -#include "tsl/platform/statusor.h" +#include "xla/tsl/platform/statusor.h" // Generated definitions. #include "jaxlib/mosaic/dialect/gpu/mosaic_gpu_dialect.cc.inc" @@ -371,12 +373,80 @@ llvm::LogicalResult WGMMAOp::verify() { return llvm::success(); } -mlir::AffineMap LayoutAttr::getAffineMap() const { - // This always returns an identity map. It's technically not correct, but we - // don't actually use it anywhere. It's only called during verification of the - // layout attribute and needs to be semi-valid. 
- return mlir::AffineMap::getMultiDimIdentityMap(getNumDimensions(), - getContext()); +llvm::LogicalResult CustomPrimitiveOp::verify() { + int num_vector_operands = 0; + int num_smem_ref_operands = 0; + mlir::Attribute smem = mlir::gpu::AddressSpaceAttr::get( + getContext(), mlir::gpu::AddressSpace::Workgroup); + for (auto operand : getOperands()) { + if (mlir::isa(operand.getType())) { + ++num_vector_operands; + } + + if (auto ref_ty = mlir::dyn_cast(operand.getType())) { + if (ref_ty.getMemorySpace() == smem) { + ++num_smem_ref_operands; + } + } + } + + if (num_vector_operands != getInLayouts().size()) { + return emitOpError( + "Custom primitive must have a layout for each vector operand."); + } + + if (num_smem_ref_operands != getInTransforms().size()) { + return emitOpError( + "Custom primitive must have transforms for each memref operand in " + "smem."); + } + + if (getResults().size() != getOutLayouts().size()) { + return emitOpError("Custom primitive must have a layout for each result."); + } + + return llvm::success(); +} + +llvm::LogicalResult BroadcastInDimOp::verify() { + auto error = [this](auto... params) { + return emitOpError(llvm::formatv(params...)); + }; + + auto operand_type = mlir::cast(getOperand().getType()); + auto result_type = mlir::cast(getResult().getType()); + + if (operand_type.getRank() == 0) { + return error("The input vector must have rank > 0."); + } + + if (operand_type.getRank() > result_type.getRank()) { + return error( + "The rank of the input vector must be smaller or equal to the rank " + "of the result vector."); + } + + if (operand_type.getRank() != getBroadcastDimensions().size()) { + return error( + "The size of the `broadcast_dimensions` attribute must be equal to " + "the rank of the input vector."); + } + auto dims = llvm::to_vector(getBroadcastDimensions()); + for (int i = 0; i < dims.size(); ++i) { + if (dims[i] < 0 || dims[i] >= result_type.getRank()) { + return error( + "The values in the `broadcast_dimensions` attribute must be in the " + "range [0, result.shape.rank={0}).", + result_type.getRank()); + } + if (i > 0 && dims[i] <= dims[i - 1]) { + return error( + "The values in the `broadcast_dimensions` attribute must be strictly " + "increasing."); + } + } + + return llvm::success(); } void MosaicGPUDialect::initialize() { diff --git a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.h b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.h index b4f13c50bd8c..474ed93806a1 100644 --- a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.h +++ b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.h @@ -19,17 +19,15 @@ limitations under the License. #include #include -#include "llvm/ADT/StringRef.h" +#include "absl/status/status.h" +#include "absl/strings/string_view.h" #include "llvm/Support/raw_ostream.h" #include "mlir/Dialect/LLVMIR/LLVMTypes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Value.h" -#include "mlir/Interfaces/InferTypeOpInterface.h" -#include "mlir/Interfaces/SideEffectInterfaces.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" // IWYU pragma: keep #include "mlir/Support/LLVM.h" -#include "absl/status/status.h" -#include "absl/strings/string_view.h" // Generated definitions. 
#include "jaxlib/mosaic/dialect/gpu/mosaic_gpu_dialect.h.inc" // IWYU pragma: keep diff --git a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td index 0882986fcf5e..1cf8ec11ae66 100644 --- a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td +++ b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td @@ -75,10 +75,6 @@ def MosaicGPU_InitializeBarrierOp : Op { let summary = "Executes an arrive.expect_tx operation on the given barrier."; - let description = [{ - A single thread in the warpgroup will execute an `arrive.expect_tx` - operation on the provided barrier with the provided `expect_tx`. - }]; let arguments = (ins MemRefRankOf<[MosaicGPU_Barrier], [0]>:$barrier, @@ -142,16 +138,15 @@ def MosaicGPU_WGSplatFragLayout : AttrDef { - let summary = "1D array that is a row that can be tiled by supported WGMMA shapes."; +def MosaicGPU_Replicated : AttrDef { + let summary = "Indicates a replicated dimension in a tiled layout."; let description = [{ - This layout is used to handle rows that are fragmented across all threads - in a warpgroup that is executing a WGMMA operation. The length of the array - must be divisible by 64. + See mosaic/gpu/fragmented_array.py -> Replicated for more details. }]; - let mnemonic = "WGMMARowFragLayout"; - let assemblyFormat = ""; + let parameters = (ins "int":$times); + let mnemonic = "Replicated"; + let assemblyFormat = "`<` `times` `=` $times `>`"; } def MosaicGPU_TiledLayout : AttrDef { @@ -162,7 +157,7 @@ def MosaicGPU_TiledLayout : AttrDef { let parameters = (ins "::mlir::ArrayAttr":$tiling, - "int":$warp_dim, + "::mlir::Attribute":$warp_dim, "::mlir::ArrayAttr":$lane_dims, "int":$vector_dim ); @@ -225,27 +220,6 @@ def SwizzleTransformAttr : MosaicGPU_Attr<"SwizzleTransform", "swizzle"> { let assemblyFormat = "`<` $swizzle `>`"; } -def LayoutAttr : MosaicGPU_Attr<"Layout", "layout", - [DeclareAttrInterfaceMethods]> { - let parameters = (ins - TypeParameter<"int32_t", "number of dimensions">:$num_dimensions, - ArrayRefParameter<"mlir::Attribute", "transforms">:$transforms - ); - - let summary = "Specifies a layout of a memref in SMEM."; - let description = [{ - This layout attribute is used to specify the layout of a memref in SMEM. - It is composed of a number of transforms, which are applied in the order - they are provided. The transforms can be any combination of: - - TileTransformAttr - - TransposeTransformAttr - - SwizzleTransformAttr - - The num_dimensions parameter must match the rank of the memref shape. 
- }]; - let assemblyFormat = "`<` $num_dimensions `,` $transforms `>`"; -} - def MosaicGPU_AsyncLoadOp : Op { let summary = "Schedules an async load of a MemRef from GMEM to SMEM"; @@ -265,28 +239,20 @@ def MosaicGPU_AsyncLoadOp : Op { + let summary = "Casts a vector to a new layout."; + let description = [{Casts a vector value to a new strided or tiled layout.}]; + let arguments = (ins + AnyVectorOfAnyRank:$x, + + // Attributes + AnyAttrOf<[ + MosaicGPU_WGStridedFragLayout, + MosaicGPU_TiledLayout + ]>:$new_layout + ); + + let results = (outs AnyVectorOfAnyRank); + + let assemblyFormat = "`x` `(` $x `:` type($x) `)` attr-dict"; + + let extraClassDeclaration = [{ + static llvm::LogicalResult inferReturnTypes( + mlir::MLIRContext *, + std::optional location, + mlir::ValueRange operands, + mlir::DictionaryAttr attributes, + mlir::OpaqueProperties properties, + mlir::RegionRange regions, + llvm::SmallVectorImpl &inferredReturnTypes) { + if (operands.empty()) { + return ::mlir::emitOptionalError( + location, "expected non-empty operands"); + } + inferredReturnTypes.assign({operands[0].getType()}); + return ::mlir::success(); + } + }]; +} + + +def MosaicGPU_BroadcastInDimOp : Op { + let summary = "Broadcasts a vector to a new shape."; + let description = [{ + `broadcast_dimensions` must have the same size as the rank of the input + vector and for each input dimension, specifies which output dimension it + corresponds to. + }]; + + let arguments = (ins + AnyVectorOfAnyRank:$operand, + + // Attributes + DenseI64ArrayAttr:$broadcast_dimensions + ); + + let results = (outs AnyVectorOfAnyRank); + let assemblyFormat = [{ + `(` $operand `:` type($operand) `)` attr-dict `->` type(results) + }]; + let hasVerifier = 1; +} + def MosaicGPU_SliceSMEMOp : Op { let summary = "Constructs an SMEM MemRef with the requested type that begins at the specified SMEM offset address."; @@ -385,7 +405,7 @@ def MosaicGPU_SliceSMEMOp : Op { } def MosaicGPU_WGMMAOp : Op { - let summary = "Multiply two matrices asyncronously using warpgroup level matrix multiply operations."; + let summary = "Multiply two matrices asynchronously using warpgroup level matrix multiply operations."; let description = [{ Schedules WGMMA operations that perform the following matrix multiply and accumulate: @@ -394,19 +414,14 @@ def MosaicGPU_WGMMAOp : Op { This operation supports larger inputs than the PTX-level WGMMA operation and will schedule as many PTX-level WGMMA operations as needed to - accomplish the calculation. The `b` matrix, and optionally `a`, needs to be - provided as a 2-dimensional memref. All memrefs may have transforms that - define swizzling, tiling, and transposition. + accomplish the calculation. The `b` matrix, and optionally `a`, need to be + provided as a 2-dimensional memref. The inputs should have the following shapes: - a: [groups_m * 64, groups_k * s] - b: [groups_k * s, groups_n * s] - accumulator: [groups_m * 64, groups_n * s] - Where: - - `s == swizzle/element_bytediwth` (for `kNoSwizzle`, `swizzle` is 16.) - and the tilings are [64, s] for `a` and [s, s] for `b`. - - `a` and/or `b` may be transposed if the corresponding attribute is set - to `true`. + where `s == swizzle / element_bytewidth`. The output has an identical shape and type as the input accumulator. @@ -419,7 +434,7 @@ def MosaicGPU_WGMMAOp : Op { registers need to be synchronized with a memory fence. Usually `a` is read from shared memory if it is used directly in the WGMMA - operation. 
If `a` needs to be transfromed before it is used in the WGMMA + operation. If `a` needs to be transformed before it is used in the WGMMA operation, it may be more convenient to read it directly form registers. This avoids the need to store the data and wait for a fence. }]; @@ -429,10 +444,7 @@ def MosaicGPU_WGMMAOp : Op { AnyTypeOf<[ MemRefOf<[MosaicGPU_WGMMASupportedType]>, VectorOfAnyRankOf<[MosaicGPU_WGMMASupportedType]>]>:$a, - MemRefOf<[MosaicGPU_WGMMASupportedType]>:$b, - - DefaultValuedOptionalAttr:$transpose_a, - DefaultValuedOptionalAttr:$transpose_b + MemRefOf<[MosaicGPU_WGMMASupportedType]>:$b ); let results = (outs VectorOfAnyRankOf<[MosaicGPU_WGMMASupportedType]>); @@ -465,4 +477,75 @@ def MosaicGPU_WGMMAOp : Op { let hasVerifier = 1; } +def MosaicGPU_OptimizationBarrierOp : Op { + let summary = "Prevents MLIR from moving operations across the barrier."; + + let arguments = (ins + Variadic:$operands + ); + let results = (outs Variadic); + + let extraClassDeclaration = [{ + static llvm::LogicalResult inferReturnTypes( + mlir::MLIRContext *, + std::optional location, + mlir::ValueRange operands, + mlir::DictionaryAttr attributes, + mlir::OpaqueProperties properties, + mlir::RegionRange regions, + llvm::SmallVectorImpl &inferredReturnTypes) { + if (operands.empty()) { + return ::mlir::emitOptionalError( + location, "expected non-empty operands"); + } + ::mlir::TypeRange operand_types = operands.getTypes(); + inferredReturnTypes.assign(operand_types.begin(), operand_types.end()); + return ::mlir::success(); + } + }]; +} + +def MosaicGPU_CustomPrimitiveOp : Op { + let summary = "Allows defining a custom Mosaic GPU primitive."; + let description = [{ + Allows defining a custom Mosaic GPU primitive. + + Custom primitives should carry input and output layouts for each of their + vector operands and outputs, and input transforms for each of their memref + operands that live in SMEM. + + Custom primitives can only return vectors. + }]; + + let arguments = ( + ins Variadic:$operands, + // Attributes + ArrayAttr:$in_layouts, + ArrayAttr:$in_transforms, + ArrayAttr:$out_layouts + ); + + let results = (outs Variadic>); + let regions = (region AnyRegion:$body); + + let hasVerifier = 1; +} + +def MosaicGPU_WithTransformsOp : Op { + let summary = "A noop that allows manually setting transforms on a memref."; + let description = [{ + This op enforces the provided transforms on the parameter memref. + }]; + + let arguments = ( + ins MemRefOf<[AnyType]>:$ref, + // Attributes + ArrayAttr:$transforms + ); + + let results = (outs MemRefOf<[AnyType]>); +} + #endif // THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_GPU_MOSAIC_GPU_TD_ diff --git a/jaxlib/mosaic/dialect/gpu/mosaic_gpu_test.cc b/jaxlib/mosaic/dialect/gpu/mosaic_gpu_test.cc index 527aa7c7ce25..5458ba7fac88 100644 --- a/jaxlib/mosaic/dialect/gpu/mosaic_gpu_test.cc +++ b/jaxlib/mosaic/dialect/gpu/mosaic_gpu_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include #include +#include #include #include @@ -25,26 +26,25 @@ limitations under the License. 
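Worked numbers for the `s == swizzle / element_bytewidth` rule in the WGMMA description above, assuming a 128-byte swizzle and 2-byte (bf16) elements; the group counts below are example values, not requirements:

#include <cstdint>

constexpr int64_t kSwizzle = 128;                     // bytes
constexpr int64_t kElementBytewidth = 2;              // bf16
constexpr int64_t kS = kSwizzle / kElementBytewidth;  // s = 64
static_assert(kS == 64);

// With groups_m = 2, groups_k = 3, groups_n = 4:
constexpr int64_t kGroupsM = 2, kGroupsK = 3, kGroupsN = 4;
// a:           [groups_m * 64, groups_k * s] = [128, 192]
static_assert(kGroupsM * 64 == 128 && kGroupsK * kS == 192);
// b:           [groups_k * s,  groups_n * s] = [192, 256]
static_assert(kGroupsK * kS == 192 && kGroupsN * kS == 256);
// accumulator: [groups_m * 64, groups_n * s] = [128, 256]
static_assert(kGroupsM * 64 == 128 && kGroupsN * kS == 256);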
#include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" -#include "llvm/include/llvm/ADT/ArrayRef.h" -#include "llvm/include/llvm/ADT/SmallVector.h" -#include "mlir/include/mlir/Conversion/LLVMCommon/MemRefBuilder.h" -#include "mlir/include/mlir/Conversion/LLVMCommon/StructBuilder.h" -#include "mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" -#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h" -#include "mlir/include/mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/include/mlir/Dialect/SCF/Utils/Utils.h" -#include "mlir/include/mlir/IR/Builders.h" -#include "mlir/include/mlir/IR/BuiltinOps.h" -#include "mlir/include/mlir/IR/Diagnostics.h" -#include "mlir/include/mlir/IR/MLIRContext.h" -#include "mlir/include/mlir/IR/OwningOpRef.h" -#include "mlir/include/mlir/IR/Types.h" -#include "mlir/include/mlir/IR/Value.h" -#include "mlir/include/mlir/IR/Verifier.h" -#include "mlir/include/mlir/Interfaces/DataLayoutInterfaces.h" -#include "mlir/include/mlir/Support/LLVM.h" -#include "tsl/platform/errors.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/Conversion/LLVMCommon/MemRefBuilder.h" +#include "mlir/Conversion/LLVMCommon/StructBuilder.h" +#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/Utils/Utils.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OwningOpRef.h" +#include "mlir/IR/Types.h" +#include "mlir/IR/Value.h" +#include "mlir/IR/Verifier.h" +#include "mlir/Interfaces/DataLayoutInterfaces.h" +#include "mlir/Support/LLVM.h" +#include "xla/tsl/platform/errors.h" namespace mosaic_gpu { namespace { diff --git a/jaxlib/mosaic/dialect/tpu/array_util.cc b/jaxlib/mosaic/dialect/tpu/array_util.cc index 4c1e79667c0f..f7d559fb08bc 100644 --- a/jaxlib/mosaic/dialect/tpu/array_util.cc +++ b/jaxlib/mosaic/dialect/tpu/array_util.cc @@ -19,8 +19,8 @@ limitations under the License. #include "absl/log/check.h" #include "absl/types/span.h" -#include "llvm/include/llvm/ADT/STLExtras.h" -#include "mlir/include/mlir/Support/LLVM.h" +#include "llvm/ADT/STLExtras.h" +#include "mlir/Support/LLVM.h" namespace mlir::tpu::internal { diff --git a/jaxlib/mosaic/dialect/tpu/array_util.h b/jaxlib/mosaic/dialect/tpu/array_util.h index 1b755dbf8495..ab8e98d17836 100644 --- a/jaxlib/mosaic/dialect/tpu/array_util.h +++ b/jaxlib/mosaic/dialect/tpu/array_util.h @@ -20,7 +20,7 @@ limitations under the License. #include "absl/log/check.h" #include "absl/types/span.h" -#include "mlir/include/mlir/Support/LLVM.h" +#include "mlir/Support/LLVM.h" #include "jaxlib/mosaic/dialect/tpu/util.h" #include "xla/array.h" diff --git a/jaxlib/mosaic/dialect/tpu/array_util_test.cc b/jaxlib/mosaic/dialect/tpu/array_util_test.cc index 18c2f94fa8b6..bcbf417a967b 100644 --- a/jaxlib/mosaic/dialect/tpu/array_util_test.cc +++ b/jaxlib/mosaic/dialect/tpu/array_util_test.cc @@ -20,7 +20,7 @@ limitations under the License. 
#include #include -#include "mlir/include/mlir/Support/LLVM.h" +#include "mlir/Support/LLVM.h" #include "xla/array.h" namespace mlir::tpu { diff --git a/jaxlib/mosaic/dialect/tpu/integrations/c/tpu_dialect.cc b/jaxlib/mosaic/dialect/tpu/integrations/c/tpu_dialect.cc index 772e87beff71..dee4f5de43d8 100644 --- a/jaxlib/mosaic/dialect/tpu/integrations/c/tpu_dialect.cc +++ b/jaxlib/mosaic/dialect/tpu/integrations/c/tpu_dialect.cc @@ -21,8 +21,13 @@ limitations under the License. #include #include #include +#include +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "llvm/Support/Casting.h" #include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/LogicalResult.h" #include "llvm/Support/MemAlloc.h" #include "llvm/Support/raw_ostream.h" #include "mlir-c/IR.h" @@ -31,16 +36,14 @@ limitations under the License. #include "mlir/CAPI/Registration.h" #include "mlir/CAPI/Utils.h" #include "mlir/CAPI/Wrap.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/MLIRContext.h" #include "mlir/IR/Value.h" #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" -#include "absl/log/check.h" -#include "absl/log/log.h" #include "jaxlib/mosaic/dialect/tpu/layout.h" #include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" #include "jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.h" @@ -410,7 +413,7 @@ MLIR_CAPI_EXPORTED void mlirTpuRegisterMosaicSerdePass() { mlir::tpu::registerMosaicSerdePass(); } -#include "mlir/CAPI/Pass.h" // IWYU pragma: keep +#include "mlir/CAPI/Pass.h" // IWYU pragma: keep #include "mlir/CAPI/Support.h" // IWYU pragma: keep extern "C" { diff --git a/jaxlib/mosaic/dialect/tpu/layout.cc b/jaxlib/mosaic/dialect/tpu/layout.cc index 172f2e91b41f..f041570b8371 100644 --- a/jaxlib/mosaic/dialect/tpu/layout.cc +++ b/jaxlib/mosaic/dialect/tpu/layout.cc @@ -19,7 +19,6 @@ limitations under the License. #include #include #include -#include #include #include #include @@ -27,6 +26,7 @@ limitations under the License. #include #include +#include "absl/log/check.h" #include "llvm/ADT/Hashing.h" #include "llvm/ADT/STLExtras.h" #include "llvm/Support/MathExtras.h" @@ -41,7 +41,6 @@ limitations under the License. #include "mlir/IR/ValueRange.h" #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" -#include "absl/log/check.h" #include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" #include "jaxlib/mosaic/dialect/tpu/util.h" @@ -250,25 +249,22 @@ class TiledRectangularVregBounds : public VRegDataBounds { FailureOr> getVectorMask( OpBuilder& builder, const Location loc, const int generation, const std::array target_shape) const override { + const int8_t bitwidth = layout_.bitwidth(); + const int packing = layout_.packing(); + const int max_subelems = generation < 4 ? 1 : generation < 5 ? 2 : 4; const IntegerType i1 = builder.getI1Type(); - FAILUREOR_ASSIGN_OR_RETURN( - const VectorType mask_vreg_ty, [&]() -> FailureOr { - // I'm pretty sure this works for all bitwidths, but it's untested. 
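The rewritten getVectorMask in this hunk derives the per-row bitmasks from `packing` and `bitwidth` instead of hard-coding 16-bit constants. A minimal sketch of that arithmetic, assuming 32-bit lanes; `packing == 2` reproduces the old `0xFFFF0000` / `0xFFFF` masks:

#include <cstdint>

// Start mask keeps the packed rows at or above `row_in_sublane` within a
// sublane; end mask keeps the packed rows below it.
constexpr uint32_t StartRowMask(int row_in_sublane, int bitwidth) {
  return 0xFFFFFFFFu << (row_in_sublane * bitwidth);
}
constexpr uint32_t EndRowMask(int row_in_sublane, int packing, int bitwidth) {
  return 0xFFFFFFFFu >> ((packing - row_in_sublane) * bitwidth);
}

// packing = 2 (16-bit elements) reproduces the constants the old code used:
static_assert(StartRowMask(/*row_in_sublane=*/1, /*bitwidth=*/16) == 0xFFFF0000u);
static_assert(EndRowMask(/*row_in_sublane=*/1, /*packing=*/2, /*bitwidth=*/16) == 0x0000FFFFu);
// packing = 4 (8-bit elements) follows the same pattern:
static_assert(StartRowMask(/*row_in_sublane=*/3, /*bitwidth=*/8) == 0xFF000000u);
static_assert(EndRowMask(/*row_in_sublane=*/1, /*packing=*/4, /*bitwidth=*/8) == 0x000000FFu);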
- if (maskVariesAlong(Direction::kSubelements, target_shape)) { - if (layout_.packing() != 2) { - // TODO(b/300082350): Generalize this - return emitError(loc, "Not implemented: packing != 2"); - } - // For older TPUs, we virtualize masking - if (generation < 4) { - return VectorType::get(target_shape, i1); - } else { - return VectorType::get( - {target_shape[0], target_shape[1], layout_.packing()}, i1); - } - } + const VectorType mask_vreg_ty = [&]() { + if (maskVariesAlong(Direction::kSubelements, target_shape)) { + // When CreateSubelementMask isn't supported, we virtualize masking. + if (packing > max_subelems) { return VectorType::get(target_shape, i1); - }()); + } else { + return VectorType::get( + {target_shape[0], target_shape[1], packing}, i1); + } + } + return VectorType::get(target_shape, i1); + }(); if (isComplete(target_shape)) { return cast>( builder @@ -280,7 +276,6 @@ class TiledRectangularVregBounds : public VRegDataBounds { } Value mask = nullptr; CHECK_GE(num_tiles_, 0); - const int packing = layout_.packing(); const int64_t start_sub = start_offsets_[0] / packing; const int64_t end_sub = llvm::divideCeil(end_offsets_[0], packing); CHECK_LE(0, start_sub); @@ -309,20 +304,20 @@ class TiledRectangularVregBounds : public VRegDataBounds { if (maskVariesAlong(Direction::kSubelements, target_shape)) { int64_t start_row = start_offsets_[0] + row_offset; int64_t end_row = end_offsets_[0] + row_offset; - if (generation >= 4) { + if (packing <= max_subelems) { // Only use non-trivial start/end if they don't fall on sublane // boundary. Otherwise CreateMaskOp already does the right thing. This // lets us use cheaper instruction sequences on TPUv4. - if (start_offsets_[0] % layout_.packing() == 0) { + if (start_offsets_[0] % packing == 0) { start_row = 0; } - if (end_offsets_[0] % layout_.packing() == 0) { - end_row = target_shape[0] * layout_.packing(); + if (end_offsets_[0] % packing == 0) { + end_row = target_shape[0] * packing; } auto submask = builder.create( loc, mask_vreg_ty, start_row, end_row); tile_mask = builder.create(loc, tile_mask, submask); - } else { // generation < 4 + } else { // packing > max_subelems const auto getMaskCst = [&](const uint64_t v) { const auto int_mask_ty = VectorType::get(target_shape, builder.getI32Type()); @@ -334,25 +329,33 @@ class TiledRectangularVregBounds : public VRegDataBounds { }; tile_mask = builder.create( loc, tile_mask, getMaskCst(0xFFFFFFFF), getMaskCst(0)); - if (start_row % 2 != 0) { + if (const int64_t row_in_sublane = start_row % packing; + row_in_sublane != 0) { auto row_mask = builder.create( loc, mask_vreg_ty, - ValueRange{boundIdxConst(start_row / 2), boundIdxConst(0)}, - ValueRange{boundIdxConst(start_row / 2 + 1), + ValueRange{boundIdxConst(start_row / packing), + boundIdxConst(0)}, + ValueRange{boundIdxConst(start_row / packing + 1), boundIdxConst(target_shape[1])}); auto row_bitmask = builder.create( - loc, row_mask, getMaskCst(0xFFFF0000), getMaskCst(0xFFFFFFFF)); + loc, row_mask, + getMaskCst(0xFFFFFFFF << row_in_sublane * bitwidth), + getMaskCst(0xFFFFFFFF)); tile_mask = builder.create(loc, tile_mask, row_bitmask); } - if (end_row % 2 != 0) { + if (const int64_t row_in_sublane = end_row % packing; + row_in_sublane != 0) { auto row_mask = builder.create( loc, mask_vreg_ty, - ValueRange{boundIdxConst(end_row / 2), boundIdxConst(0)}, - ValueRange{boundIdxConst(end_row / 2 + 1), + ValueRange{boundIdxConst(end_row / packing), boundIdxConst(0)}, + ValueRange{boundIdxConst(end_row / packing + 1), 
boundIdxConst(target_shape[1])}); auto row_bitmask = builder.create( - loc, row_mask, getMaskCst(0xFFFF), getMaskCst(0xFFFFFFFF)); + loc, row_mask, + getMaskCst(0xFFFFFFFFu >> + (packing - row_in_sublane) * bitwidth), + getMaskCst(0xFFFFFFFF)); tile_mask = builder.create(loc, tile_mask, row_bitmask); } diff --git a/jaxlib/mosaic/dialect/tpu/layout.h b/jaxlib/mosaic/dialect/tpu/layout.h index 2c45be62fa7d..dceee9cf41a8 100644 --- a/jaxlib/mosaic/dialect/tpu/layout.h +++ b/jaxlib/mosaic/dialect/tpu/layout.h @@ -18,15 +18,16 @@ limitations under the License. #include #include +#include #include -#include #include #include #include +#include "absl/log/check.h" +#include "llvm/ADT/Hashing.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/bit.h" -#include "llvm/Support/ErrorHandling.h" #include "llvm/Support/raw_ostream.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/IR/Builders.h" @@ -38,7 +39,6 @@ limitations under the License. #include "mlir/IR/Value.h" #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" -#include "absl/log/check.h" namespace mlir::tpu { @@ -168,7 +168,7 @@ class RectangularVregBounds : public VRegDataBounds { // --- // // The tiling attribute makes it possible to subdivide a single vector register -// into multiple subtiles that traverse the last dimension of a value. For +// into multiple sub-tiles that traverse the last dimension of a value. For // example, consider vregs of shape (4, 5) on (2, 10) array: // // a b c d e f g h i j @@ -233,6 +233,11 @@ class VectorLayout { implicit_dim_(implicit_dim) { // TODO(b/275751535): Allow more bitwidths. CHECK(llvm::has_single_bit(bitwidth_) && bitwidth_ <= 32); + CHECK_GT(tiling_[0], 0); + CHECK_GT(tiling_[1], 0); + CHECK_GE(offsets_[0].value_or(0), 0); + CHECK_GE(offsets_[1].value_or(0), 0); + CHECK_LT(offsets_[0].value_or(0), tiling_[0]); } static int num_implicit_dims(const ImplicitDim implicit_dim) { diff --git a/jaxlib/mosaic/dialect/tpu/tpu.td b/jaxlib/mosaic/dialect/tpu/tpu.td index 4b5ed34934d7..241f0a745928 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu.td +++ b/jaxlib/mosaic/dialect/tpu/tpu.td @@ -82,7 +82,8 @@ def TPU_SomeSemaphoreType : AnyTypeOf<[TPU_SemaphoreType, TPU_DMASemaphoreType]> def TPU_DimensionSemantics : I32EnumAttr<"DimensionSemantics", "Dimension semantics", [ I32EnumAttrCase<"parallel", 0>, - I32EnumAttrCase<"arbitrary", 1> + I32EnumAttrCase<"arbitrary", 1>, + I32EnumAttrCase<"core_parallel", 2> ]> { let genSpecializedAttr = 0; let cppNamespace = "::mlir::tpu"; @@ -156,9 +157,8 @@ def TPU_TiledLayoutAttr def TPU_MemorySpace : I32EnumAttr<"MemorySpace", "Memory space", [ I32EnumAttrCase<"kAny", 4294967295, "any">, - // TODO(apaszke): Rename to kXYZ in C++ - I32EnumAttrCase<"vmem", 0, "vmem">, - I32EnumAttrCase<"smem", 1, "smem">, + I32EnumAttrCase<"kVmem", 0, "vmem">, + I32EnumAttrCase<"kSmem", 1, "smem">, I32EnumAttrCase<"kHbm", 2, "hbm">, I32EnumAttrCase<"kCmem", 3, "cmem">, I32EnumAttrCase<"kSemaphoreMem", 4, "semaphore_mem"> @@ -176,6 +176,9 @@ class TPU_Op traits = []> : Op { } +def DefaultMemWrite : MemoryEffects<[MemWrite]>; +def DefaultMemRead : MemoryEffects<[MemRead]>; + def TPU_ReductionKind : I32EnumAttr<"ReductionKind", "Reduction kind", [ I32EnumAttrCase<"SUM", 0, "sum">, I32EnumAttrCase<"MAX", 1, "max">, @@ -198,7 +201,7 @@ def TPU_AllReduceOp : TPU_Op<"all_reduce", [Pure, SameOperandsAndResultType]> { }]; } -def TPU_StoreOp : TPU_Op<"store", [AttrSizedOperandSegments]> { +def TPU_StoreOp : TPU_Op<"store", [DefaultMemWrite, AttrSizedOperandSegments]> { let 
arguments = (ins TPU_Vreg:$valueToStore, AnyType:$base, @@ -213,7 +216,7 @@ def TPU_StoreOp : TPU_Op<"store", [AttrSizedOperandSegments]> { }]; } -def TPU_LoadOp : TPU_Op<"load"> { +def TPU_LoadOp : TPU_Op<"load", [DefaultMemRead]> { let arguments = (ins AnyType:$base, Variadic:$indices, @@ -227,7 +230,7 @@ def TPU_LoadOp : TPU_Op<"load"> { } // TODO(jevinjiang): migrate tpu.strided_store to general vector store op. -def TPU_VectorStoreOp :TPU_Op<"vector_store", [AttrSizedOperandSegments]> { +def TPU_VectorStoreOp :TPU_Op<"vector_store", [DefaultMemWrite, AttrSizedOperandSegments]> { let arguments = (ins AnyVectorOfNonZeroRank:$valueToStore, AnyMemRef:$base, @@ -242,7 +245,34 @@ def TPU_VectorStoreOp :TPU_Op<"vector_store", [AttrSizedOperandSegments]> { let hasVerifier = 1; } -def TPU_StridedLoadOp : TPU_Op<"strided_load"> { +// tpu.vector_load loads a vector from memory into a register. +// +// base : Memref to load from. +// indices: Scalar indices into base. indices must be of the same rank as the +// base memref shape. +// strides: The stride to use for calculating the address of subsequent +// elements. If left unspecified, the stride is implicitly 1 along +// each dimension. Otherwise the stride must match the rank of the +// memref shape. +// mask : Elementwise vector mask. Must be broadcastable to the shape of the +// result vector. Depending on the core type, this may be a dynamic +// (lane) mask consumed from a register or a static (sublane) mask +// that must be the result of arith.constant. +def TPU_VectorLoadOp :TPU_Op<"vector_load", [DefaultMemRead, AttrSizedOperandSegments]> { + let arguments = (ins + AnyMemRef:$base, + Variadic:$indices, + DenseI32ArrayAttr:$strides, + Optional:$mask // Elementwise mask. + ); + let results = (outs AnyVectorOfNonZeroRank:$result); + let assemblyFormat = [{ + $base `[` $indices `]` (`masked` $mask^)? attr-dict `:` type($base) `,` type($result) `,` type($mask) + }]; + let hasVerifier = 1; +} + +def TPU_StridedLoadOp : TPU_Op<"strided_load", [DefaultMemRead]> { let arguments = (ins AnyMemRef:$base, Variadic:$indices, @@ -255,7 +285,7 @@ def TPU_StridedLoadOp : TPU_Op<"strided_load"> { let hasVerifier = 1; } -def TPU_StridedStoreOp : TPU_Op<"strided_store"> { +def TPU_StridedStoreOp : TPU_Op<"strided_store", [DefaultMemWrite]> { let arguments = (ins AnyVectorOfNonZeroRank:$valueToStore, AnyMemRef:$base, @@ -269,7 +299,7 @@ def TPU_StridedStoreOp : TPU_Op<"strided_store"> { let hasVerifier = 1; } -def TPU_ShuffledLoadOp : TPU_Op<"shuffled_load"> { +def TPU_ShuffledLoadOp : TPU_Op<"shuffled_load", [DefaultMemRead]> { let arguments = (ins AnyMemRef:$base, Variadic:$indices, @@ -284,7 +314,7 @@ def TPU_ShuffledLoadOp : TPU_Op<"shuffled_load"> { let hasCanonicalizeMethod = 1; } -def TPU_ShuffledStoreOp : TPU_Op<"shuffled_store"> { +def TPU_ShuffledStoreOp : TPU_Op<"shuffled_store", [DefaultMemWrite]> { let arguments = (ins TPU_Vreg:$valueToStore, AnyMemRef:$base, @@ -302,6 +332,11 @@ def TPU_ShuffledStoreOp : TPU_Op<"shuffled_store"> { // TODO(jevinjiang): deprecate to use dynamic_rotate. 
def TPU_RotateOp : TPU_Op<"rotate", [Pure, SameOperandsAndResultType]> { + let description = [{ + Rotates the given vector by the given amount in the given dimension, i.e., + for a 2D vector of shape (m, n), rotating dim 0 by `amount` will shift a row + at index `i` to index `(i + amount) % m` + }]; let arguments = (ins AnyVectorOfNonZeroRank:$value, SI32Attr:$amount, @@ -405,6 +440,7 @@ def TPU_RelayoutOp : TPU_Op<"relayout", [SameOperandsAndResultType]> { let arguments = (ins AnyType:$input); let results = (outs AnyType:$output); let assemblyFormat = [{ $input attr-dict `:` type($input) `->` type($output) }]; + let hasVerifier = 1; } def TPU_PackMaskOp : TPU_Op<"pack_vmsk", [Pure, SameTypeOperands]> { @@ -429,15 +465,30 @@ def TPU_GatherOp : TPU_Op<"gather", [Pure]> { }]; } -def TPU_DynamicGatherOp : TPU_Op<"dynamic_gather", [Pure]> { +def TPU_DynamicGatherOp : TPU_Op<"dynamic_gather", [Pure, SameOperandsAndResultShape, AllTypesMatch<["source", "output"]>]> { + let description = [{ + Gathers elements from `source` using `indices`. + + The specified `dimensions` of `source` are collapsed together and indexed by + `indices`. + + Given a shape `N0 x N1 x ...`, the `output[i0, i1, ...]` is given by + `collapsed_source[j0, j1, ..., indices[i0, i1, ...] mod M]` where + - `collapsed_source` is the result of collapsing `dimensions` of `source` + into a new trailing dimension of size `M`. + - `jk` is the subsequence of `in` for `n` not in `dimensions`. + + When a single dimension is specified, this is similar to + `np.take_along_axis`, except that OOB indices wrap. + }]; let arguments = (ins AnyVectorOfNonZeroRank:$source, - AnyVectorOfNonZeroRank:$indices, - I32Attr:$dimension + VectorOfNonZeroRankOf<[AnyInteger]>:$indices, + DenseI32ArrayAttr:$dimensions ); let results = (outs AnyVectorOfNonZeroRank:$output); let assemblyFormat = [{ - $source `[` $indices `]` `in` $dimension attr-dict + $source `[` $indices `]` `in` $dimensions attr-dict `:` type($source) `,` type($indices) `->` type($output) }]; let hasVerifier = 1; @@ -464,6 +515,14 @@ def TPU_FPToSIOp : TPU_Op<"fptosi", [Pure, ElementwiseMappable]> { let hasCanonicalizeMethod = 1; } +// Internal operation. All arith.sitofp operations that change the bitwidth +// must be canonicalized to this operation. +def TPU_SIToFPOp : TPU_Op<"sitofp", [Pure, ElementwiseMappable]> { + let arguments = (ins AnyType:$in, TPU_RoundingModeEnum:$rounding_mode); + let results = (outs AnyType:$output); + let assemblyFormat = [{ $in attr-dict `:` type($in) `->` type($output) }]; +} + def TPU_DotDimensionNumbersAttr : TPU_Attr<"DotDimensionNumbers", "dot_dimension_numbers"> { let parameters = (ins ArrayRefParameter<"int64_t", "">:$lhs_contracting_dims, @@ -752,7 +811,9 @@ def TPU_EnqueueDMAOp : TPU_Op<"enqueue_dma", [AttrSizedOperandSegments]> { AnyMemRef:$target, MemRefOf<[TPU_DMASemaphoreType]>:$target_semaphore, Optional:$device_id, // For remote DMAs - Optional:$core_id // For megacore + Optional:$core_id, // For megacore + // Smaller number means higher priority. 0 is the highest and the default. + DefaultValuedAttr:$priority ); let hasVerifier = 1; } @@ -840,6 +901,61 @@ def TPU_PRNGRandomBitsOp : TPU_Op<"prng_random_bits"> { let results = (outs AnyVectorOfNonZeroRank:$output); } +def TPU_SublaneShuffleOp : TPU_Op<"sublane_shuffle", [SameOperandsAndResultType]> { + // This op takes 2 physical vregs and a pattern, applies the pattern, + // and returns the result as 1 vreg. 
+ // + // The pattern is a list of integers, where the integer value is the + // index of the sublane in the *combined input* [lhs, rhs], and the + // position of the integer in the list is the index of the sublane + // in the *output* vreg. + // + // The pattern size must match the operand/result sublane count. + // + // Example: + // %0 = tpu.single_output_sublane_shuffle %a, %b, + // [0, 1, 2, 3, 4, 5, 6, 7] // Result is %a + // %1 = tpu.single_output_sublane_shuffle %a, %b, + // [8, 9, 10, 11, 12, 13, 14, 15] // Result is %b + // %2 = tpu.single_output_sublane_shuffle %a, %b, + // [7, 6, 5, 4, 11, 10, 9, 8] // Result uses high half of a + // // and low half of b, reversed. + let arguments = (ins + TPU_Vreg:$lhs, + TPU_Vreg:$rhs, + DenseI32ArrayAttr:$pattern + ); + let results = (outs TPU_Vreg:$result); + let assemblyFormat = [{ + $lhs `,` $rhs `,` $pattern attr-dict `:` type($lhs) `,` type($rhs) `->` type($result) + }]; + + let hasVerifier = 1; +} + +def TPU_TransposeOp : TPU_Op<"transpose", [Pure]> { + let summary = "tpu transpose operation"; + let arguments = (ins AnyVectorOfAnyRank:$vector, + DenseI64ArrayAttr:$permutation); + let results = (outs AnyVectorOfAnyRank:$result); + + let builders = [ + OpBuilder<(ins "Value":$vector, "ArrayRef":$permutation)> + ]; + let assemblyFormat = [{ + $vector `,` $permutation attr-dict `:` type($vector) `->` type($result) + }]; + let extraClassDeclaration = [{ + VectorType getSourceVectorType() { + return ::llvm::cast(getVector().getType()); + } + VectorType getResultVectorType() { + return ::llvm::cast(getResult().getType()); + } + }]; + let hasVerifier = 1; +} + def TPU_LogOp : TPU_Op<"log"> { let arguments = (ins Variadic:$inputs, @@ -912,6 +1028,8 @@ def CanonicalizeMosaicPass : Pass<"tpu-canonicalize-mosaic", "::mlir::func::Func let options = [ Option<"hardware_generation", "hardware-generation", "int", /*default=*/"-1", "">, Option<"compatibility_mode", "compatibility-mode", "bool", /*default=*/"1", "">, + Option<"lane_count", "lane-count", "int", /*default=*/"128", "">, + Option<"sublane_count", "sublane-count", "int", /*default=*/"8", "">, ]; } diff --git a/jaxlib/mosaic/dialect/tpu/tpu_dialect.cc b/jaxlib/mosaic/dialect/tpu/tpu_dialect.cc index 59ca5d7a3437..e0e061fbd6dd 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu_dialect.cc +++ b/jaxlib/mosaic/dialect/tpu/tpu_dialect.cc @@ -15,27 +15,23 @@ limitations under the License. #include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" -#include -#include #include -#include #include -#include -#include +#include "absl/hash/hash.h" +#include "absl/log/log.h" +#include "llvm/ADT/Hashing.h" #include "llvm/ADT/TypeSwitch.h" // IWYU pragma: keep. +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/Builders.h" -#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Diagnostics.h" #include "mlir/IR/DialectImplementation.h" // IWYU pragma: keep. 
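Reference semantics for the tpu.sublane_shuffle pattern documented above, sketched as an ordinary gather over the concatenated [lhs, rhs] sublanes (illustrative C++, not the TPU lowering):

#include <cstdint>
#include <vector>

// `Sublane` stands in for one row of lane values; the type is illustrative.
using Sublane = std::vector<int32_t>;

// pattern[i] selects sublane i of the result from the concatenation
// [lhs sublanes, rhs sublanes].
std::vector<Sublane> ShuffleSublanes(const std::vector<Sublane>& lhs,
                                     const std::vector<Sublane>& rhs,
                                     const std::vector<int32_t>& pattern) {
  std::vector<Sublane> out(pattern.size());
  for (size_t i = 0; i < pattern.size(); ++i) {
    const int32_t idx = pattern[i];
    // Indices [0, lhs.size()) come from lhs, the rest from rhs.
    out[i] = idx < static_cast<int32_t>(lhs.size())
                 ? lhs[idx]
                 : rhs[idx - static_cast<int32_t>(lhs.size())];
  }
  return out;
}

// With 8 sublanes per vreg, pattern {0,...,7} returns lhs, {8,...,15} returns
// rhs, and {7,6,5,4,11,10,9,8} matches the third example in the comment above.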
#include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" -#include "absl/hash/hash.h" -#include "absl/log/log.h" -#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" +#include "jaxlib/mosaic/dialect/tpu/layout.h" #include "jaxlib/mosaic/dialect/tpu/tpu_dialect.cc.inc" #include "jaxlib/mosaic/dialect/tpu/tpu_enums.cc.inc" #include "xla/layout.h" diff --git a/jaxlib/mosaic/dialect/tpu/tpu_dialect.h b/jaxlib/mosaic/dialect/tpu/tpu_dialect.h index 0800a9e75087..798386b92744 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu_dialect.h +++ b/jaxlib/mosaic/dialect/tpu/tpu_dialect.h @@ -23,16 +23,14 @@ limitations under the License. #include #include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Value.h" #include "mlir/Pass/Pass.h" -#include "mlir/include/mlir/IR/BuiltinOps.h" -#include "mlir/include/mlir/IR/BuiltinTypes.h" -#include "mlir/include/mlir/IR/Value.h" -#include "mlir/include/mlir/Support/LogicalResult.h" -#include "jaxlib/mosaic/dialect/tpu/layout.h" +#include "mlir/Support/LogicalResult.h" +#include "jaxlib/mosaic/dialect/tpu/layout.h" // IWYU pragma: keep #include "jaxlib/mosaic/dialect/tpu/tpu_enums.h.inc" -#include "jaxlib/mosaic/dialect/tpu/transforms/serde.h" -#include "xla/layout.h" +#include "xla/layout.h" // IWYU pragma: keep namespace mlir::tpu { class TPUDialect; @@ -64,11 +62,11 @@ struct ApplyVectorLayoutContext { // mxu_shape = {contracting_size, non_contracting_size} std::array mxu_shape = {128, 128}; int64_t max_sublanes_in_scratch = 0; - int64_t vmem_banks = -1; // -1 means "unspecified". + int64_t vmem_banks = -1; // -1 means "unspecified". int32_t max_shuffle_sublane_offset = -1; // -1 means "unspecified". }; -std::pair mightCommunicateBetweenChips(Operation* op); +std::pair mightCommunicateBetweenChips(Operation *op); std::unique_ptr> createInferMemRefLayoutPass( int hardware_generation = -1, @@ -76,7 +74,8 @@ std::unique_ptr> createInferMemRefLayoutPass( const TpuTilingFlags &tpu_tiling_flags = {}); std::unique_ptr> createCanonicalizeMosaicPass( - int hardware_generation = -1, bool compatibility_mode = true); + int hardware_generation = -1, bool compatibility_mode = true, + std::array target_shape = {8, 128}); std::unique_ptr> createInferVectorLayoutPass( int hardware_generation = -1, diff --git a/jaxlib/mosaic/dialect/tpu/tpu_ops.cc b/jaxlib/mosaic/dialect/tpu/tpu_ops.cc index c73accb09b26..9449b2737918 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu_ops.cc +++ b/jaxlib/mosaic/dialect/tpu/tpu_ops.cc @@ -19,26 +19,30 @@ limitations under the License. 
#include #include +#include "absl/log/check.h" +#include "absl/strings/str_format.h" #include "llvm/ADT/STLExtras.h" -#include "llvm/Support/FormatVariadic.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Casting.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/Dialect/Utils/StaticValueUtils.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/OperationSupport.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/Value.h" #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" -#include "absl/log/check.h" -#include "absl/strings/str_format.h" -#include "mlir/include/mlir/Dialect/Math/IR/Math.h" -#include "mlir/include/mlir/IR/Builders.h" -#include "mlir/include/mlir/IR/BuiltinTypeInterfaces.h" -#include "mlir/include/mlir/IR/BuiltinTypes.h" -#include "mlir/include/mlir/IR/Diagnostics.h" -#include "mlir/include/mlir/IR/IRMapping.h" -#include "mlir/include/mlir/IR/OperationSupport.h" #include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" #include "jaxlib/mosaic/dialect/tpu/util.h" +#include "xla/layout.h" namespace mlir { namespace tpu { @@ -48,15 +52,15 @@ LogicalResult UnrollVectorsOp::canonicalize(UnrollVectorsOp op, RollVectorsOp roll_op = dyn_cast_or_null(op.getOperand().getDefiningOp()); if (!roll_op) { - return failure(); + return failure(); } if (roll_op.getNumOperands() != op.getNumResults()) { - return failure(); + return failure(); } for (auto [v1, v2] : llvm::zip(roll_op.getOperandTypes(), op.getResultTypes())) { if (v1 != v2) { - return failure(); + return failure(); } } rewriter.replaceOp(op, roll_op.getOperands()); @@ -94,6 +98,7 @@ LogicalResult BitcastOp::verify() { LogicalResult MemRefSliceOp::verify() { auto source_type = getMemRefType(getMemRef()); auto target_type = getType(); + auto source_layout = source_type.getLayout(); auto target_layout = target_type.getLayout(); auto target_memory_space = target_type.getMemorySpace(); auto indices = getBaseIdx(); @@ -127,12 +132,38 @@ LogicalResult MemRefSliceOp::verify() { return emitOpError( "Memory spaces must match if the target memory space is provided."); } - bool is_target_layout_identity_map = - isa(target_layout) && target_layout.isIdentity(); - if (!is_target_layout_identity_map && - target_type.getLayout() != source_type.getLayout()) { - return emitOpError( - "Layouts must match if the target layout is not an identity map."); + if (isa(target_layout)) { + SmallVector source_strides; + int64_t source_offset; + if (failed( + source_type.getStridesAndOffset(source_strides, source_offset))) { + return failure(); + } + int64_t target_offset = source_offset; + if (target_offset != ShapedType::kDynamic) { + for (auto [base_idx, source_stride] : + llvm::zip(getBaseIdx(), source_strides)) { + if (auto idx = getConstantIntValue(base_idx)) { + target_offset += *idx * source_stride; + } else { + target_offset = ShapedType::kDynamic; + break; + } + } + } + auto expected_layout = + StridedLayoutAttr::get(getContext(), target_offset, source_strides); + if (target_layout != expected_layout) { + return emitOpError("Layout mismatch: got ") + << target_layout << ", expected " << expected_layout << "."; + } + } else { + bool is_target_layout_identity_map = + isa(target_layout) && 
target_layout.isIdentity(); + if (!is_target_layout_identity_map && target_layout != source_layout) { + return emitOpError( + "Layouts must match if the target layout is not an identity map."); + } } if (getDynamicSizes().size() != target_type.getNumDynamicDims()) { return emitOpError( @@ -165,50 +196,91 @@ LogicalResult MemRefSliceOp::canonicalize(MemRefSliceOp op, LogicalResult MemRefSqueezeOp::verify() { auto source_type = getMemRefType(getInput()); auto target_type = getType(); - // Source and target attributes may be different before propagation is done by - // the canonicalizer, so we allow this when attributes are "unset" in the - // target type. + if (target_type.getMemorySpace() != nullptr && target_type.getMemorySpace() != source_type.getMemorySpace()) { - emitOpError("Memory spaces do not match."); - return failure(); + return emitOpError("Memory spaces do not match."); } + if (target_type.getElementType() != source_type.getElementType()) { - this->emitOpError("Element types don't match."); - return failure(); - } - if (!HasMemorySpace(source_type, tpu::MemorySpace::kSemaphoreMem) && - source_type.getRank() > 1 && target_type.getRank() == 1) { - return emitError("Not implemented: squeeze memref to 1d."); + return emitOpError("Element types don't match."); } + auto source_shape = source_type.getShape(); auto target_shape = target_type.getShape(); - int source_index = source_shape.size() - 1; - int target_index = target_shape.size() - 1; - auto error_msg = llvm::formatv( - "Target shape is not valid. " - "Source type: {0}. Target type: {1}.", - source_type, target_type); - while (source_index >= 0 || target_index >= 0) { - int target_dim = target_index < 0 ? -1 : target_shape[target_index]; - if (source_index < 0) { - // We have run out of source shape but target shape still remains. - emitOpError(error_msg); - return failure(); - } - int source_dim = source_shape[source_index]; - if (source_dim == target_dim) { - source_index--; - target_index--; - } else { - // Only the source dim can be 1 here. 
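The strided-layout branch added to MemRefSliceOp::verify above computes the expected slice offset by folding constant base indices into the source offset, going dynamic as soon as an index is non-constant, while the strides are carried over unchanged. A worked sketch (kDynamic only stands in for ShapedType::kDynamic):

#include <cstdint>
#include <optional>
#include <vector>

constexpr int64_t kDynamic = INT64_MIN;  // stand-in for ShapedType::kDynamic

int64_t ExpectedSliceOffset(
    int64_t source_offset, const std::vector<int64_t>& source_strides,
    const std::vector<std::optional<int64_t>>& base_idx) {
  if (source_offset == kDynamic) return kDynamic;
  int64_t offset = source_offset;
  for (size_t d = 0; d < source_strides.size(); ++d) {
    if (!base_idx[d].has_value()) return kDynamic;  // non-constant index
    offset += *base_idx[d] * source_strides[d];
  }
  return offset;
}

// Example: strides {128, 1}, source offset 0, base indices {2, 8}
//   -> expected offset 2 * 128 + 8 * 1 = 264, with strides still {128, 1},
//      which is exactly the StridedLayoutAttr the verifier compares against.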
- if (source_dim != 1) { - this->emitOpError(error_msg); - return failure(); - } - source_index--; + auto squeezed_or = + computeSqueezedDimsChecked(*this, source_shape, target_shape); + if (failed(squeezed_or)) { + return failure(); + } + + auto target_layout = target_type.getLayout(); + if (isa(target_layout)) { + SmallVector source_strides; + int64_t source_offset; + if (failed( + source_type.getStridesAndOffset(source_strides, source_offset))) { + return failure(); + } + SmallVector target_strides; + for (auto [i, stride] : llvm::enumerate(source_strides)) { + if (!llvm::is_contained(*squeezed_or, i)) { + target_strides.push_back(stride); + } + } + auto expected_layout = + StridedLayoutAttr::get(getContext(), source_offset, target_strides); + if (target_layout != expected_layout) { + return emitOpError("Layout mismatch: got ") + << target_layout << ", expected " << expected_layout << "."; + } + } + + auto erase_layout_op = getInput().getDefiningOp(); + if (!erase_layout_op) { + return success(); + } + + auto layout_ref = erase_layout_op.getOperand(); + MemRefType layout_ty = getMemRefType(layout_ref); + auto layout_attr = dyn_cast(layout_ty.getLayout()); + if (!layout_attr) { + return emitOpError( + "Input from EraseLayoutOp is expected to have a TiledLayoutAttr."); + } + auto &squeezed = squeezed_or.value(); + if (squeezed.empty() && source_shape != target_shape) { + return failure(); + } + + auto tiles = layout_attr.getTiles(); + if (tiles.size() == 1) { + auto tile = layout_attr.getTiles().front(); + auto tile_dims = tile.dimensions(); + int first_tiled = source_shape.size() - tile_dims.size(); + for (int dim : squeezed) { + if (dim >= first_tiled) { + int tile_idx = dim - first_tiled; + if (tile_idx < 0 || tile_idx >= static_cast(tile_dims.size())) { + return emitOpError() << "Internal error: tile index out of bounds."; + } + if (tile_dims[tile_idx] != 1) { + return emitOpError() + << "All tiled squeezed dimensions must be of size 1."; + } + } + } + } else { + auto first_tile = tiles.front(); + for (int dim : squeezed) { + int first_tiled = source_shape.size() - first_tile.dimensions().size(); + if (dim >= first_tiled) { + return emitOpError() << "When multiple tiles are present, no tiled " + "dimensions can be squeezed."; + } } } + return success(); } @@ -220,42 +292,108 @@ LogicalResult MemRefSqueezeOp::canonicalize(MemRefSqueezeOp op, if (!erase_layout) { return failure(); } - // Push layout erasure through squeezing. It is important we see the layout - // for lowering and don't make it hard for other ops to query it. + auto layout_ref = erase_layout.getOperand(); - MemRefType layout_ty = layout_ref.getType(); + MemRefType layout_ty = getMemRefType(layout_ref); + auto layout_attr = dyn_cast(layout_ty.getLayout()); + if (!layout_attr) { + return failure(); + } + auto source_shape = source_type.getShape(); auto target_shape = target_type.getShape(); - int source_index = source_shape.size() - 1; - int target_index = target_shape.size() - 1; - auto old_layout = dyn_cast(layout_ty.getLayout()); - auto target_strides = old_layout.getTileStrides(); - SmallVector tile_strides(target_strides.begin(), - target_strides.end()); - // We want to remove all strides that correspond to squeezed dimensions and - // update the corresponding output layout. - while (source_index >= 0 || target_index >= 0) { - int target_dim = target_index < 0 ? 
-1 : target_shape[target_index]; - int source_dim = source_shape[source_index]; - if (source_dim == target_dim) { - source_index--; - target_index--; - } else { - // Source index must be 1 here (otherwise verification will have failed). - // We are safe to mutate the strides vector here because we are looping - // backwards. - tile_strides.erase(tile_strides.begin() + source_index); - source_index--; + auto squeezed_or = computeSqueezedDimsChecked(op, source_shape, target_shape); + if (failed(squeezed_or)) { + return failure(); + } + auto &squeezed = squeezed_or.value(); + if (squeezed.empty() && source_shape != target_shape) { + return failure(); + } + + SmallVector tile_strides = + llvm::to_vector(layout_attr.getTileStrides()); + for (int i = squeezed.size() - 1; i >= 0; --i) { + tile_strides.erase(tile_strides.begin() + squeezed[i]); + } + + tpu::TiledLayoutAttr new_layout; + bool target_is_1d = target_shape.size() == 1; + auto tiles = layout_attr.getTiles(); + if (target_is_1d && tiles.size() == 1) { + auto tile_dims = llvm::to_vector(tiles.front().dimensions()); + int first_tiled = source_shape.size() - tile_dims.size(); + for (int i = squeezed.size() - 1; i >= 0; --i) { + int dim = squeezed[i]; + if (dim >= first_tiled) { + int tile_idx = dim - first_tiled; + if (tile_idx < 0 || tile_idx >= static_cast(tile_dims.size())) { + return op.emitError() << "Internal error: tile index out of bounds."; + } + tile_dims.erase(tile_dims.begin() + tile_idx); + } + } + new_layout = tpu::TiledLayoutAttr::get( + op.getContext(), {xla::Tile(tile_dims)}, tile_strides); + } else { + new_layout = tpu::TiledLayoutAttr::get( + op.getContext(), layout_attr.getTiles(), tile_strides); + } + + auto new_ty = MemRefType::get(target_shape, layout_ty.getElementType(), + new_layout, layout_ty.getMemorySpace()); + + auto new_squeeze = + rewriter.create(op.getLoc(), new_ty, layout_ref); + rewriter.replaceOpWithNewOp(op, target_type, new_squeeze); + return success(); +} + +LogicalResult RelayoutOp::verify() { + auto in_layout_array_attr = + getOperation()->getAttrOfType("in_layout"); + if (!in_layout_array_attr || in_layout_array_attr.empty()) { + return emitOpError("missing or empty 'in_layout' attribute"); + } + if (in_layout_array_attr.size() != 1) { + return emitOpError( + "'in_layout' attribute must be an array containing a single " + "VectorLayoutAttr"); + } + auto src_vla = dyn_cast(in_layout_array_attr[0]); + if (!src_vla) { + return emitOpError("'in_layout' attribute is not a VectorLayoutAttr"); + } + + auto out_layout_array_attr = + getOperation()->getAttrOfType("out_layout"); + if (!out_layout_array_attr || out_layout_array_attr.empty()) { + return emitOpError("missing or empty 'out_layout' attribute"); + } + if (out_layout_array_attr.size() != 1) { + return emitOpError( + "'out_layout' attribute must be an array containing a single " + "VectorLayoutAttr"); + } + auto dst_vla = dyn_cast(out_layout_array_attr[0]); + if (!dst_vla) { + return emitOpError("'out_layout' attribute is not a VectorLayoutAttr"); + } + + VectorType input_type = cast(getInput().getType()); + VectorType output_type = cast(getOutput().getType()); + + if (input_type.getShape() != output_type.getShape()) { + return emitOpError("input and output shapes must match"); + } + if (input_type.getElementType() != output_type.getElementType()) { + // Allow i1 to i1 even if bitwidth in layout changes. 
+ if (!(input_type.getElementType().isInteger(1) && + output_type.getElementType().isInteger(1))) { + return emitOpError( + "input and output element types must match for non-mask relayouts"); } } - auto new_layout = tpu::TiledLayoutAttr::get( - source_type.getContext(), old_layout.getTiles(), tile_strides); - auto new_result_type = MemRefType::get(op.getResult().getType().getShape(), - layout_ty.getElementType(), new_layout, - layout_ty.getMemorySpace()); - auto squeeze = rewriter.create(op.getLoc(), new_result_type, - layout_ref); - rewriter.replaceOpWithNewOp(op, op.getType(), squeeze); return success(); } @@ -322,6 +460,41 @@ LogicalResult MemRefReshapeOp::verify() { return success(); } +LogicalResult TransposeOp::verify() { + auto source_type = getSourceVectorType(); + auto permutation = getPermutation(); + auto output_type = getResultVectorType(); + auto input_shape = source_type.getShape(); + auto output_shape = output_type.getShape(); + if (source_type.getElementType() != output_type.getElementType()) { + return emitOpError("Expected input and output element types to match"); + } + if (permutation.size() != source_type.getRank()) { + return emitOpError("Expected permutation rank to match input rank"); + } + if (permutation.size() != output_type.getRank()) { + return emitOpError("Expected permutation rank to match output rank"); + } + std::vector seen_dims(source_type.getRank(), false); + for (int64_t dim : permutation) { + if (dim < 0 || dim >= source_type.getRank()) { + return emitOpError("Permutation element out of bounds: ") << dim; + } + if (seen_dims[dim]) { + return emitOpError("Permutation element repeated: ") << dim; + } + seen_dims[dim] = true; + } + for (int i = 0; i < source_type.getRank(); ++i) { + if (input_shape[permutation[i]] != output_shape[i]) { + return emitOpError( + "Expected input shape permuted by the given permutation to match the " + "output shape"); + } + } + return success(); +} + LogicalResult MemRefReshapeOp::canonicalize(MemRefReshapeOp op, PatternRewriter &rewriter) { auto src_ty = op.getInput().getType(); @@ -332,8 +505,7 @@ LogicalResult MemRefReshapeOp::canonicalize(MemRefReshapeOp op, } auto layout_ref = erase_layout_op.getOperand(); auto layout_ty = layout_ref.getType(); - auto layout = - dyn_cast(layout_ty.getLayout()); + auto layout = dyn_cast(layout_ty.getLayout()); CHECK(!layout.getTiles().empty()); auto tile = layout.getTiles().front().dimensions(); auto new_tile_strides = ComputeTileStrides(dst_ty, tile); @@ -427,8 +599,8 @@ LogicalResult MemRefBitcastOp::canonicalize(MemRefBitcastOp op, if (tile[0] * src_bitwidth % tgt_bitwidth != 0) { return failure(); } - SmallVector new_tiles = - {xla::Tile({tile[0] * src_bitwidth / tgt_bitwidth, 128})}; + SmallVector new_tiles = { + xla::Tile({tile[0] * src_bitwidth / tgt_bitwidth, 128})}; if (tgt_bitwidth < 32) { new_tiles.push_back(xla::Tile({32 / tgt_bitwidth, 1})); } @@ -507,6 +679,36 @@ LogicalResult VectorStoreOp::verify() { return success(); } +LogicalResult VectorLoadOp::verify() { + const MemRefType ref_ty = getBase().getType(); + if (!getStrides().empty()) { + if (llvm::size(getStrides()) != ref_ty.getRank()) { + return emitOpError("Expected ") << ref_ty.getRank() << " strides."; + } + return emitError("Not implemented: general vector load with strides."); + } + const VectorType value_ty = getResult().getType(); + + if (value_ty.getElementType() != ref_ty.getElementType()) { + return emitOpError("Expected base and result element type to match."); + } + if (llvm::size(getIndices()) != 
ref_ty.getRank()) { + return emitOpError("Expected ") << ref_ty.getRank() << " indices."; + } + if (getMask()) { + if (value_ty.getElementTypeBitWidth() != 32) { + return emitError( + "Not implemented: masked load with non-32-bit element type"); + } + if (vector::isBroadcastableTo(getMask().getType(), value_ty) != + vector::BroadcastableToResult::Success) { + return emitOpError( + "Expected mask shape to be broadcastable to result shape."); + } + } + return success(); +} + LogicalResult ReinterpretCastOp::verify() { auto source_type = getMemRefType(getInput()); auto target_type = getType(); @@ -954,13 +1156,24 @@ LogicalResult EnqueueDMAOp::verify() { "device_id or core_id is specified"); } } + bool is_remote = getDeviceId() || getCoreId(); if (getSourceSemaphore()) { - if (!getDeviceId() && !getCoreId()) { + if (!is_remote) { return emitOpError( "DMA destination device_id or core_id must be specified when source " "semaphore is specified"); } } + int priority = getPriority(); + if (priority < 0 || priority > 1) { + return emitOpError( + "Not implemented: only support priority 0 or 1, but got ") + << priority; + } + if (priority != 0 && is_remote) { + return emitOpError( + "Not implemented: non-zero priority is not supported for remote DMA"); + } return success(); } @@ -1084,7 +1297,7 @@ LogicalResult ConcatenateOp::verify() { if (getOperands().size() < 2) { return emitOpError("Expected at least 2 operands for concatenate op."); } - auto first_type = getOperand(0).getType().cast(); + auto first_type = cast(getOperand(0).getType()); auto first_shape = first_type.getShape(); auto first_dtype = first_type.getElementType(); for (auto operand : getOperands()) { @@ -1117,11 +1330,21 @@ LogicalResult LogOp::verify() { return failure(); } CoreType logging_core_type = logging_core_type_maybe->value_or(CoreType::kTc); - if ((logging_core_type == CoreType::kScScalarSubcore || - logging_core_type == CoreType::kScVectorSubcore) && - getFormattedAttr() != nullptr && getFormattedAttr().getValue()) { + bool is_sc_core = logging_core_type == CoreType::kScScalarSubcore || + logging_core_type == CoreType::kScVectorSubcore; + if (is_sc_core && getFormattedAttr() != nullptr && + getFormattedAttr().getValue()) { return emitOpError("Formatted logging is not supported on SC"); } + if (is_sc_core && getInputs().size() > 1) { + return emitOpError("SC logging only supports 0 or 1 inputs"); + } + if (is_sc_core && getInputs().size() == 1) { + Type input_type = getInputs().front().getType(); + if (!llvm::isa(input_type)) { + return emitOpError("SC logging only supports memrefs or scalars"); + } + } switch (logging_core_type) { case CoreType::kTc: case CoreType::kScScalarSubcore: @@ -1231,8 +1454,16 @@ LogicalResult DynamicGatherOp::verify() { if (getIndices().getType().getShape() != getIndices().getType().getShape()) { return emitOpError("Expected indices and result shapes must match"); } - if (!getIndices().getType().getElementType().isInteger(32)) { - return emitOpError("Not implemented: Only i32 indices supported"); + const int64_t rank = getSource().getType().getRank(); + SmallVector seen(rank, false); + for (int32_t d : getDimensions()) { + if (d < 0 || d >= rank) { + return emitOpError("Dimensions must be in [0, rank), but got ") << d; + } + if (seen[d]) { + return emitOpError("Dimensions must be unique"); + } + seen[d] = true; } return success(); } @@ -1258,6 +1489,50 @@ LogicalResult AssumeMultipleOp::verify() { return success(); } +LogicalResult SublaneShuffleOp::verify() { + auto lhs = getLhs(); + auto rhs = 
getRhs(); + auto result = getResult(); + auto lhs_ty = dyn_cast(lhs.getType()); + auto rhs_ty = dyn_cast(rhs.getType()); + auto result_ty = dyn_cast(result.getType()); + + if (!lhs_ty || !rhs_ty || !result_ty) { + return emitOpError("Expected operands and result to be vector types"); + } + + if (lhs_ty.getShape() != rhs_ty.getShape() || + lhs_ty.getShape() != result_ty.getShape()) { + return emitOpError("Expected lhs, rhs, and result shapes to match"); + } + if (lhs_ty.getElementType() != rhs_ty.getElementType() || + lhs_ty.getElementType() != result_ty.getElementType()) { + return emitOpError("Expected lhs, rhs, and result element types to match"); + } + + auto pattern = getPattern(); + auto shape = result_ty.getShape(); + if (shape.size() < 2 || shape.size() > 3) { + return emitOpError("Vreg rank should be 2 or 3"); + } + auto sublane_count = shape[0]; + + if (pattern.size() != sublane_count) { + return emitOpError("Expected pattern size (") + << pattern.size() << ") to match result/operand sublanes (" + << sublane_count << ")"; + } + + int64_t total_input_sublanes = sublane_count * 2; + for (int32_t idx : pattern) { + if (idx < 0 || idx >= total_input_sublanes) { + return emitOpError("Pattern index ") << idx << " out of bounds [0, " + << (total_input_sublanes - 1) << "]"; + } + } + return success(); +} + } // namespace tpu } // namespace mlir diff --git a/jaxlib/mosaic/dialect/tpu/tpu_ops_verification_test.cc b/jaxlib/mosaic/dialect/tpu/tpu_ops_verification_test.cc new file mode 100644 index 000000000000..e92403c21ad0 --- /dev/null +++ b/jaxlib/mosaic/dialect/tpu/tpu_ops_verification_test.cc @@ -0,0 +1,246 @@ +/* Copyright 2025 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include +#include +#include "absl/status/status.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributeInterfaces.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/IR/DialectRegistry.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/ValueRange.h" +#include "mlir/Support/LLVM.h" +#include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" +#include "xla/mlir/utils/error_util.h" + +namespace mlir::tpu { +namespace { + +using ::testing::_; +using ::testing::HasSubstr; +using ::testing::status::StatusIs; + +class TpuOpsVerificationTest : public ::testing::Test { + protected: + TpuOpsVerificationTest() + : context_([]() { + DialectRegistry registry; + registry + .insert(); + return registry; + }()), + builder_(UnknownLoc::get(&context_), &context_) { + context_.loadAllAvailableDialects(); + context_.printOpOnDiagnostic(true); + } + ~TpuOpsVerificationTest() { + for (int i = ops_.size() - 1; i >= 0; --i) { + ops_[i]->erase(); + } + } + + template + OpTy Create(Args&&... 
args) { + OpTy op = builder_.create(std::forward(args)...); + ops_.push_back(op.getOperation()); + return op; + } + + template + absl::Status VerifyOp(OpTy op) { + BaseScopedDiagnosticHandler diag(&context_); + if (op.verify().succeeded()) { + return absl::OkStatus(); + } + return diag.ConsumeStatus(); + } + + ImplicitLocOpBuilder& builder() { return builder_; } + + private: + MLIRContext context_; + ImplicitLocOpBuilder builder_; + std::vector ops_; +}; + +TEST_F(TpuOpsVerificationTest, VectorLoadVerificationWorks) { + auto c0 = Create(0); + auto memref = + Create(MemRefType::get({8}, builder().getI32Type())); + auto vl = Create( + /*result=*/VectorType::get({8}, builder().getI32Type()), + /*base=*/memref.getMemref(), + /*indices=*/ValueRange{c0}, + /*strides=*/builder().getDenseI32ArrayAttr({}), + /*mask=*/nullptr); + + ASSERT_OK(VerifyOp(vl)); +} + +TEST_F(TpuOpsVerificationTest, + VectorLoadRankOfStridesDoesNotMatchBaseMemrefRank) { + auto c0 = Create(0); + auto memref = + Create(MemRefType::get({8}, builder().getI32Type())); + auto vl = Create( + /*result=*/VectorType::get({8}, builder().getI32Type()), + /*base=*/memref.getMemref(), + /*indices=*/ValueRange{c0}, + /*strides=*/builder().getDenseI32ArrayAttr({1, 1, 1, 1}), + /*mask=*/nullptr); + ASSERT_THAT(VerifyOp(vl), StatusIs(_, HasSubstr("Expected 1 strides."))); +} + +TEST_F(TpuOpsVerificationTest, VectorLoadStridesFeatureNotImplemented) { + auto c0 = Create(0); + auto memref = + Create(MemRefType::get({8}, builder().getI32Type())); + auto vl = Create( + /*result=*/VectorType::get({8}, builder().getI32Type()), + /*base=*/memref.getMemref(), + /*indices=*/ValueRange{c0}, + /*strides=*/builder().getDenseI32ArrayAttr({1}), + /*mask=*/nullptr); + ASSERT_THAT( + VerifyOp(vl), + StatusIs( + _, HasSubstr("Not implemented: general vector load with strides."))); +} + +TEST_F(TpuOpsVerificationTest, VectorLoadBaseAndResultTypesDoNotMatch) { + auto c0 = Create(0); + auto memref = + Create(MemRefType::get({8}, builder().getI32Type())); + auto vl = Create( + /*result=*/VectorType::get({8}, builder().getF32Type()), + /*base=*/memref.getMemref(), + /*indices=*/ValueRange{c0}, + /*strides=*/builder().getDenseI32ArrayAttr({}), + /*mask=*/nullptr); + + ASSERT_THAT( + VerifyOp(vl), + StatusIs(_, + HasSubstr("Expected base and result element type to match."))); +} + +TEST_F(TpuOpsVerificationTest, + VectorLoadRankOfIndicesDoesNotMatchBaseMemrefRank) { + auto c0 = Create(0); + auto memref = + Create(MemRefType::get({8}, builder().getI32Type())); + auto vl = Create( + /*result=*/VectorType::get({8}, builder().getI32Type()), + /*base=*/memref.getMemref(), + /*indices=*/ValueRange{c0, c0, c0}, + /*strides=*/builder().getDenseI32ArrayAttr({}), + /*mask=*/nullptr); + + ASSERT_THAT(VerifyOp(vl), StatusIs(_, HasSubstr("Expected 1 indices."))); +} + +TEST_F(TpuOpsVerificationTest, VectorLoadValidMaskSucceeds) { + auto c0 = Create(0); + auto memref = Create( + MemRefType::get({8, 128}, builder().getI32Type())); + auto mask = Create( + /*result=*/VectorType::get({8, 1}, builder().getI32Type()), + /*value=*/dyn_cast( + builder().getDenseI32ArrayAttr({1, 1, 1, 1, 1, 1, 1, 1}))); + auto vl = Create( + /*result=*/VectorType::get({8, 128}, builder().getI32Type()), + /*base=*/memref.getMemref(), + /*indices=*/ValueRange{c0, c0}, + /*strides=*/builder().getDenseI32ArrayAttr({}), + /*mask=*/mask.getResult()); + + ASSERT_OK(VerifyOp(vl)); +} + +TEST_F(TpuOpsVerificationTest, VectorLoadMaskInvalidResultBitWidth) { + auto c0 = Create(0); + auto memref = Create( + 
MemRefType::get({8, 128}, builder().getI64Type())); + auto mask = Create( + /*result=*/VectorType::get({8, 1}, builder().getI32Type()), + /*value=*/dyn_cast( + builder().getDenseI32ArrayAttr({1, 1, 1, 1, 1, 1, 1, 1}))); + auto vl = Create( + /*result=*/VectorType::get({8, 128}, builder().getI64Type()), + /*base=*/memref.getMemref(), + /*indices=*/ValueRange{c0, c0}, + /*strides=*/builder().getDenseI32ArrayAttr({}), + /*mask=*/mask.getResult()); + + ASSERT_THAT( + VerifyOp(vl), + StatusIs( + _, HasSubstr( + "Not implemented: masked load with non-32-bit element type"))); +} + +TEST_F(TpuOpsVerificationTest, + VectorLoadMaskNotBroadcastableToResultShapeInvalidMinor) { + auto c0 = Create(0); + auto memref = Create( + MemRefType::get({8, 128}, builder().getI32Type())); + auto mask = Create( + /*result=*/VectorType::get({8, 2}, builder().getI32Type()), + /*value=*/dyn_cast(builder().getDenseI32ArrayAttr({1}))); + auto vl = Create( + /*result=*/VectorType::get({8, 128}, builder().getI32Type()), + /*base=*/memref.getMemref(), + /*indices=*/ValueRange{c0, c0}, + /*strides=*/builder().getDenseI32ArrayAttr({}), + /*mask=*/mask.getResult()); + + ASSERT_THAT( + VerifyOp(vl), + StatusIs( + _, HasSubstr( + "Expected mask shape to be broadcastable to result shape."))); +} + +TEST_F(TpuOpsVerificationTest, + VectorLoadMaskNotBroadcastableToResultShapeInvalidMajor) { + auto c0 = Create(0); + auto memref = Create( + MemRefType::get({8, 128}, builder().getI32Type())); + auto mask = Create( + /*result=*/VectorType::get({5, 1}, builder().getI32Type()), + /*value=*/dyn_cast(builder().getDenseI32ArrayAttr({1}))); + auto vl = Create( + /*result=*/VectorType::get({8, 128}, builder().getI32Type()), + /*base=*/memref.getMemref(), + /*indices=*/ValueRange{c0, c0}, + /*strides=*/builder().getDenseI32ArrayAttr({}), + /*mask=*/mask.getResult()); + + ASSERT_THAT( + VerifyOp(vl), + StatusIs( + _, HasSubstr( + "Expected mask shape to be broadcastable to result shape."))); +} + +} // namespace +} // namespace mlir::tpu diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc index 1997ffe34535..c3b2e8cc7f38 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc @@ -1,3 +1,18 @@ +/* Copyright 2021 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + #include "jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.h" #include @@ -13,15 +28,23 @@ #include #include +#include "absl/algorithm/container.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/types/span.h" +#include "llvm/ADT/APInt.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/SmallVectorExtras.h" #include "llvm/ADT/StringMap.h" #include "llvm/ADT/iterator_range.h" -#include "llvm/Support/Compiler.h" +#include "llvm/Support/LogicalResult.h" #include "llvm/Support/MathExtras.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Traits.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" @@ -33,9 +56,11 @@ #include "mlir/IR/BuiltinTypeInterfaces.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Diagnostics.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/OpDefinition.h" #include "mlir/IR/Operation.h" +#include "mlir/IR/OperationSupport.h" #include "mlir/IR/Region.h" #include "mlir/IR/TypeRange.h" #include "mlir/IR/Types.h" @@ -45,21 +70,6 @@ #include "mlir/Pass/Pass.h" #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" -#include "absl/algorithm/container.h" -#include "absl/log/check.h" -#include "absl/log/log.h" -#include "absl/status/status.h" -#include "absl/types/span.h" -#include "llvm/include/llvm/ADT/APInt.h" -#include "llvm/include/llvm/Support/LogicalResult.h" -#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/include/mlir/Dialect/Math/IR/Math.h" -#include "mlir/include/mlir/Dialect/Vector/IR/VectorOps.h" -#include "mlir/include/mlir/IR/Attributes.h" -#include "mlir/include/mlir/IR/Builders.h" -#include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h" -#include "mlir/include/mlir/IR/OperationSupport.h" #include "jaxlib/mosaic/dialect/tpu/array_util.h" #include "jaxlib/mosaic/dialect/tpu/layout.h" #include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" @@ -199,23 +209,18 @@ bool incrementIndex(const MutableArrayRef idx, return false; } -FailureOr getIntConst(Value v, bool silent = false) { - if (auto constant_op = v.getDefiningOp()) { - if (auto integer_attr = dyn_cast(constant_op.getValue())) { - return integer_attr.getValue().getSExtValue(); - } - } - if (silent) { - return failure(); +FailureOr expectIntConst(Value v) { + if (auto cst = getIntConst(v)) { + return cst.value(); } return emitError(v.getLoc(), "Expected an integer constant"); } -FailureOr> getIntConstsFromOperandRange( - ValueRange vals, bool silent = false) { +FailureOr> expectIntConstsFromOperandRange( + ValueRange vals) { SmallVector res(vals.size()); for (int i = 0; i < vals.size(); ++i) { - FAILUREOR_ASSIGN_OR_RETURN(res[i], getIntConst(vals[i], silent)); + FAILUREOR_ASSIGN_OR_RETURN(res[i], expectIntConst(vals[i])); } return res; } @@ -255,7 +260,7 @@ FailureOr>> sliceRef( Value c0 = nullptr; SmallVector indices_within_slice(indices.size() - tiling.size(), 0); for (auto tiled_idx : indices.take_back(tiling.size())) { - if (auto cst = getIntConst(tiled_idx, /*silent=*/true); succeeded(cst)) { + if (auto cst = getIntConst(tiled_idx)) { indices_within_slice.push_back(*cst); if (!c0) { c0 = builder.create(i32, @@ -436,6 +441,121 
@@ FailureOr maskOOB(RewriteContext &ctx, ImplicitLocOpBuilder &builder, .getResult(); } +// Transpose the 2nd minor dimension of the implicit shape. +// +// Shape of (..., N, 1) becomes (..., 1, N) +FailureOr> transposeSingletonMinorDimension( + RewriteContext &ctx, OpBuilder &builder, const Location loc, + xla::Array vregs, const ArrayRef ishape, + VectorLayout layout, const int64_t new_minor_offset) { + if (layout.bitwidth() != 32 || !layout.hasNativeTiling(ctx.target_shape)) { + // Note: For non-native tilings it is probably better to retile first, to + // to make the most out of each lane rotate (they are expensive). + return emitError(loc, "Not implemented: Unsupported bitwidth or tiling"); + } + auto create_index_const = [&](const int64_t idx) { + return builder.create(loc, idx); + }; + auto create_i32_vreg_const = [&](const int64_t val) { + return I32Const(val, ctx.target_shape, builder, loc); + }; + if (layout.offsets()[1].has_value()) { + // Replicate minor dimension + // TODO(tlongeri): Move into its own function (it will be needed for + // relayout) and make this a precondition of this function, so that we have + // "building block" functions with minimal overlap + vregs.Each([&](const absl::Span idxs, Value *vreg) { + *vreg = builder.create( + loc, vreg->getType(), *vreg, + create_i32_vreg_const(*layout.offsets()[1]), 1); + }); + layout = + VectorLayout(layout.bitwidth(), {layout.offsets()[0], std::nullopt}, + layout.tiling(), VectorLayout::ImplicitDim::kNone); + } + if (!layout.offsets()[0].has_value()) { + return vregs; + } + const int64_t old_2nd_minor_offset = *layout.offsets()[0]; + SmallVector new_ishape(ishape); + CHECK_EQ(new_ishape.back(), 1); + std::iter_swap(new_ishape.end() - 2, new_ishape.end() - 1); + // new_layout is only to get the new vreg array shape, the implicit dim is + // irrelevant (since we already have the implicit shape): + const VectorLayout new_layout( + layout.bitwidth(), {std::nullopt, new_minor_offset}, layout.tiling(), + VectorLayout::ImplicitDim::kNone); + xla::Array new_vregs(new_layout.tileArrayShape( + /*src_is_implicit=*/true, /*res_is_implicit=*/true, new_ishape, + ctx.target_shape)); + VectorType iota_vreg_ty = + getNativeVregType(builder.getI32Type(), ctx.target_shape); + // Preallocate an indices vector to avoid repeated allocations: + SmallVector old_idxs; + new_vregs.Each([&](const absl::Span new_idxs, + Value *new_vreg) { + const int64_t uncorrected_shape_start = + ctx.target_shape[1] * new_idxs.back() - new_minor_offset; + // The start and end of the data contained by new_vreg in the implicit shape + const int64_t shape_start = std::max(uncorrected_shape_start, 0); + const int64_t shape_end = std::min( + uncorrected_shape_start + ctx.target_shape[1], new_ishape.back()); + old_idxs.assign(new_idxs.begin(), new_idxs.end()); + CHECK_EQ(*(old_idxs.end() - 2), 0); + old_idxs.back() = 0; + *new_vreg = nullptr; + VectorType vmask_ty = + getNativeVregOrVmaskType(builder.getI1Type(), 32, ctx.target_shape); + int64_t shape_offset = shape_start; + // The data in the new vreg is composed of data from multiple of the old + // vregs, so iterate over them until the new vreg is full + while (shape_offset < shape_end) { + // Find the vreg that contains the data at shape_offset + *(old_idxs.end() - 2) = + (shape_offset + old_2nd_minor_offset) / ctx.target_shape[0]; + const int64_t old_sublane_offset = + (shape_offset + old_2nd_minor_offset) % ctx.target_shape[0]; + const int64_t new_lane_offset = + (shape_offset + new_minor_offset) % 
ctx.target_shape[1]; + // We will blend in all the relevant data contained by the old vreg + const int64_t data_size = + std::min(ctx.target_shape[0] - old_sublane_offset, + ctx.target_shape[1] - new_lane_offset); + // [ a a a a a a a a ] [ . . a b c . . . ] + // [ b b b b b b b b ] => [ . . a b c . . . ] + // [ c c c c c c c c ] [ . . a b c . . . ] + // [ . . . . . . . . ] [ . . a b c . . . ] + // Every lane has all the data, so at each sublane we can just pick out + // the element that we want using a sublane shuffle. + Value vreg = vregs(old_idxs); + Value iota_vreg = builder.create( + loc, iota_vreg_ty, + /*dimension =*/builder.getI32IntegerAttr(1)); + iota_vreg = builder.create( + loc, iota_vreg, + create_i32_vreg_const(old_sublane_offset - new_lane_offset)); + vreg = builder.create(loc, vreg.getType(), vreg, + iota_vreg, 0); + // Now, blend the transposed data into new_vreg + if (*new_vreg == nullptr) { + *new_vreg = vreg; + } else { + Value mask = builder.create( + loc, vmask_ty, + ArrayRef{create_index_const(0), + create_index_const(new_lane_offset)}, + ArrayRef{create_index_const(ctx.target_shape[0]), + create_index_const(new_lane_offset + data_size)}); + *new_vreg = builder.create(loc, mask, vreg, *new_vreg); + } + shape_offset += data_size; + ++*(old_idxs.end() - 2); + } + CHECK(*new_vreg != nullptr); + }); + return new_vregs; +} + // Insert a minor dimension to the implicit shape. The original minor dimension // becomes the new second minor dimension, laid out across sublanes. // @@ -1022,6 +1142,40 @@ LogicalResult tpu_fptosi_rule(RewriteContext &ctx, Operation &op, return op.emitOpError("Unsupported FPToSI conversion"); } +LogicalResult tpu_sitofp_rule(RewriteContext &ctx, Operation &op, + const ArrayRef layouts_in, + const ArrayRef layouts_out) { + TPU_ASSERT_EQ_OP(layouts_in.size(), 1); + TPU_ASSERT_OP(layouts_in.front().has_value()); + TPU_ASSERT_EQ_OP(layouts_out.size(), 1); + TPU_ASSERT_OP(layouts_out.front().has_value()); + auto &layout_in = *layouts_in.front(); + auto &layout_out = *layouts_out.front(); + if (layout_in.bitwidth() == layout_out.bitwidth()) { + return elementwise_op_rule(ctx, op, layouts_in, layouts_out); + } else if (layout_in.bitwidth() < layout_out.bitwidth()) { + auto sitofp_op = cast(op); + switch (sitofp_op.getRoundingMode()) { + case tpu::RoundingMode::kToNearestEven: { + ImplicitLocOpBuilder builder(op.getLoc(), &op); + FAILUREOR_ASSIGN_OR_RETURN( + xla::Array vregs, + ext_op_rule_impl(ctx, builder, sitofp_op, layout_in, layout_out)); + sitofp_op.replaceAllUsesWith( + assemble(builder, cast(sitofp_op.getType()), layout_out, + std::move(vregs), ctx.target_shape) + .getResult()); + sitofp_op.erase(); + return success(); + } + case tpu::RoundingMode::kTowardsZero: + return op.emitOpError( + "Not implemented: SIToFP with rounding mode kTowardsZero"); + } + } + return op.emitOpError("Unsupported SIToFP conversion"); +} + LogicalResult func_return_rule(RewriteContext &ctx, Operation &op, const ArrayRef layouts_in, const ArrayRef layouts_out) { @@ -1538,7 +1692,7 @@ LogicalResult tpu_load_rule(RewriteContext &ctx, Operation &op, } FAILUREOR_ASSIGN_OR_RETURN( const SmallVector indices, - getIntConstsFromOperandRange(load_op.getIndices())); + expectIntConstsFromOperandRange(load_op.getIndices())); TPU_ASSERT_EQ_OP(indices.size(), 2); if (indices[1] % ctx.target_shape[1] != 0) { return op.emitOpError("Not implemented: Lane index is not a multiple of ") @@ -1596,8 +1750,8 @@ LogicalResult strided_op_rule_impl(RewriteContext &ctx, Operation &op, if 
(strides[rank - 1] != 1) { return op.emitOpError("Not Implemented: Stride on last dim is not 1"); } - auto last_idx = getIntConst(indices[rank - 1], /*silent=*/true); - if (failed(last_idx)) { + auto last_idx = getIntConst(indices[rank - 1]); + if (!last_idx.has_value()) { return op.emitOpError("Not Implemented: Dynamic index on last dim"); } else if (last_idx.value() != 0) { return op.emitOpError("Not Implemented: Index on last dim is not 0"); @@ -1965,7 +2119,7 @@ LogicalResult tpu_store_rule(RewriteContext &ctx, Operation &op, tpu::StoreOp store_op = cast(op); FAILUREOR_ASSIGN_OR_RETURN( const SmallVector indices, - getIntConstsFromOperandRange(store_op.getIndices())); + expectIntConstsFromOperandRange(store_op.getIndices())); TPU_ASSERT_EQ_OP(indices.size(), 2); if (indices[1] % ctx.target_shape[1] != 0) { return op.emitOpError("Not implemented: Lane index is not a multiple of ") @@ -2092,72 +2246,32 @@ LogicalResult tpu_assume_layout_rule(RewriteContext &ctx, Operation &op, return success(); } -LogicalResult tpu_relayout_rule(RewriteContext &ctx, Operation &op, - const ArrayRef layouts_in, - const ArrayRef layouts_out) { - TPU_ASSERT_EQ_OP(op.getNumOperands(), 1); - TPU_ASSERT_EQ_OP(op.getNumResults(), 1); - TPU_ASSERT_EQ_OP(layouts_in.size(), 1); - TPU_ASSERT_EQ_OP(layouts_out.size(), 1); - TPU_ASSERT_OP(layouts_in[0].has_value()); - TPU_ASSERT_OP(layouts_out[0].has_value()); - const auto& in_layout = *layouts_in[0]; - const auto& out_layout = *layouts_out[0]; - auto realyout_op = cast(op); - auto in_bitwidth = in_layout.bitwidth(); - auto out_bitwidth = out_layout.bitwidth(); - auto vty = cast(realyout_op.getType()); - ImplicitLocOpBuilder builder(op.getLoc(), &op); - if (in_layout == out_layout) { - realyout_op.replaceAllUsesWith(realyout_op.getInput()); - realyout_op.erase(); - return success(); - } - FAILUREOR_ASSIGN_OR_RETURN( - xla::Array vals, - disassemble(builder, in_layout, - cast>(realyout_op.getInput()), - ctx.target_shape, - /*use_implicit_shape=*/true)); - // Packing vector masks from 32-bit to 16-bit. - if (vty.getElementType() == builder.getI1Type() && in_bitwidth == 32 && - out_bitwidth == 16 && - in_layout.tiling()[0] == in_layout.packing() * ctx.target_shape[0] && - in_layout.tiling()[1] == ctx.target_shape[1] && - in_layout.tiling() == out_layout.tiling() && - in_layout.offsets() == out_layout.offsets() && - in_layout.implicit_dim() == out_layout.implicit_dim()) { - std::vector vmsks_shape(vals.dimensions().begin(), - vals.dimensions().end()); - *(vmsks_shape.end() - 1) = llvm::divideCeil(vmsks_shape.back(), 2); - xla::Array out_vmsks(vmsks_shape, nullptr); - SmallVector val_idx; - Value default_val = - getFullLikeVector(builder, cast>(*vals.begin()), - IntegerAttr::get(builder.getI1Type(), 0)); - out_vmsks.Each([&](absl::Span idx, Value *v) { - val_idx.assign(idx.begin(), idx.end()); - // TODO(jevinjiang): can be simplified when offset is replicated. - *(val_idx.end() - 1) *= 2; - Value low_part = *(val_idx.end() - 1) < *(vals.dimensions().end() - 1) - ? vals(val_idx) - : default_val; - *(val_idx.end() - 1) += 1; - Value high_part = *(val_idx.end() - 1) < *(vals.dimensions().end() - 1) - ? 
vals(val_idx) - : default_val; - const VectorType mask_ty = getNativeVregOrVmaskType( - builder.getI1Type(), in_bitwidth / 2, ctx.target_shape); - *v = builder.create(mask_ty, low_part, high_part); - }); - const RollVectorsOp rolled_op = - assemble(builder, vty, out_layout, out_vmsks, ctx.target_shape, - /*use_implicit_shape=*/true); - op.replaceAllUsesWith(rolled_op); - op.erase(); - return success(); - } - return op.emitOpError("Not implemented: unsupported layout change"); +Value createSubelementMask(OpBuilder &builder, const Location loc, + const int bitwidth, const int64_t from, + const int64_t to, + const std::array target_shape) { + auto create_index_const = [&](const int64_t idx) { + return builder.create( + loc, builder.getIntegerAttr(builder.getIndexType(), idx)); + }; + const int packing = 32 / bitwidth; + const VectorType vmask_ty = + getNativeVregOrVmaskType(builder.getI1Type(), bitwidth, target_shape); + // Prefer CreateMaskOp if possible - more efficient and supports unpacked + // TODO: b/412754162 - We can probably always use the CreateSubelementMaskOp + // if (1) optimize it on TPUv4 and (2) Add support for unpacked types in some + // of the invariants in lower_to_llo. + if (from % packing == 0 && to % packing == 0) { + const int64_t from_sublane = from / packing; + const int64_t to_sublane = to / packing; + return builder.create( + loc, vmask_ty, + ArrayRef{create_index_const(from_sublane), + create_index_const(0)}, + ArrayRef{create_index_const(to_sublane), + create_index_const(target_shape[1])}); + } + return builder.create(loc, vmask_ty, from, to); } // TODO(b/347016737): Deprecate tpu.rotate and only use tpu.dynamic_rotate. So @@ -2172,16 +2286,40 @@ LogicalResult rotate_rule_impl(RewriteContext &ctx, OpTy op, Value amount, if (layout_in != layout) { return op.emitOpError("Not implemented: unsupported layout for input"); } - if (layout_out != layout) { + LayoutOffsets expected_offsets_out = layout_in.offsets(); + auto shift = getIntConst(amount); + int rotated_tiled_dim = op.getDimension() - (op.getType().getRank() - 2); + bool has_padding_along_rotation = + (rotated_tiled_dim == 0 || rotated_tiled_dim == 1) && + op.getType().getShape()[op.getDimension()] % + layout.tiling()[rotated_tiled_dim] != + 0; + if (shift.has_value() && has_padding_along_rotation) { + // We checked above that there are no implicit dims. + const int64_t dim_size = op.getType().getShape()[op.getDimension()]; + // TODO(b/337384645): Currently we assume {0, 0} offsets in the input + // layout. Relax this assumption. + expected_offsets_out[rotated_tiled_dim] = + (dim_size - (shift.value() % dim_size)) % + layout.tiling()[rotated_tiled_dim]; + } + if (layout_out.bitwidth() != layout.bitwidth() || + layout_out.offsets() != expected_offsets_out || + layout_out.tiling() != layout.tiling() || + layout_out.implicit_dim() != layout.implicit_dim()) { return op.emitOpError("Not implemented: unsupported layout for output"); } auto vty = op.getResult().getType(); if (vty.getRank() < 2) { return op.emitOpError("Not implemented: unsupported 1D shape"); } - if (*(vty.getShape().end() - 2) % *(layout.tiling().end() - 2) != 0 || - *(vty.getShape().end() - 1) % *(layout.tiling().end() - 1) != 0) { - return op.emitOpError("Not implemented: unsupported unaliged shape"); + // TODO(b/411170715): Allow sublane rotation once the bug is fixed. + // TODO(b/337384645): Support non-zero stride. 
+ if (has_padding_along_rotation && + (!shift.has_value() || + (rotated_tiled_dim == 0 || + (rotated_tiled_dim == 1 && op.getStride().value_or(0) != 0)))) { + return op.emitOpError("Not implemented: unsupported unaligned shape"); } ImplicitLocOpBuilder builder(op.getLoc(), op.getOperation()); @@ -2205,19 +2343,19 @@ LogicalResult rotate_rule_impl(RewriteContext &ctx, OpTy op, Value amount, builder.getIntegerAttr(builder.getIndexType(), d)); }; auto modI = [&](const Value &v, unsigned d) -> Value { - if (auto cst = getIntConst(v, /*silent=*/true); succeeded(cst)) { + if (auto cst = getIntConst(v)) { return mlirI32Const(cst.value() % d); } return builder.create(v, mlirI32Const(d)); }; auto divI = [&](const Value &v, unsigned d) -> Value { - if (auto cst = getIntConst(v, /*silent=*/true); succeeded(cst)) { + if (auto cst = getIntConst(v)) { return mlirI32Const(cst.value() / d); } return builder.create(v, mlirI32Const(d)); }; auto addI = [&](const Value &v, unsigned d) -> Value { - if (auto cst = getIntConst(v, /*silent=*/true); succeeded(cst)) { + if (auto cst = getIntConst(v)) { return mlirI32Const(cst.value() + d); } return builder.create(v, mlirI32Const(d)); @@ -2244,8 +2382,7 @@ LogicalResult rotate_rule_impl(RewriteContext &ctx, OpTy op, Value amount, auto getVmaskByPaddingEnd = [&](Value padding, int dim, int stride = 0) { CHECK(dim == 0 || dim == 1); Value padding_vreg; - if (auto padding_cst = getIntConst(padding, /*silent=*/true); - succeeded(padding_cst)) { + if (auto padding_cst = getIntConst(padding)) { CHECK_GE(padding_cst.value(), 0); CHECK_LE(padding_cst.value(), ctx.target_shape[dim]); padding_vreg = builder.create(DenseElementsAttr::get( @@ -2274,8 +2411,7 @@ LogicalResult rotate_rule_impl(RewriteContext &ctx, OpTy op, Value amount, // and blend the data from contiguous vregs to emulate circular rotation. auto rotateOnTilingDim = [&](const xla::Array &vregs, const Value &shift, int axis, int stride = 0) { - if (auto shift_cst = getIntConst(shift, /*silent=*/true); - succeeded(shift_cst)) { + if (auto shift_cst = getIntConst(shift)) { if (shift_cst.value() == 0 && stride == 0) { return vregs; } @@ -2308,6 +2444,88 @@ LogicalResult rotate_rule_impl(RewriteContext &ctx, OpTy op, Value amount, return concatenate(chunks, axis); }; + // Applies lazy rotation (see go/pltpu-roll for details). + auto lazyRotate = [&](const xla::Array &vregs, int64_t shift, + int axis) { + const int tiling_dim = axis - (vregs.num_dimensions() - 2); + const int64_t tile_size = ctx.target_shape[tiling_dim]; + const int64_t input_size = vty.getShape()[axis]; + const int64_t normalized_shift = shift % input_size; + const int64_t start_idx = input_size - normalized_shift; + const int64_t start_vreg_idx = start_idx / tile_size; + const int64_t valid_amount = input_size % tile_size; + + // We start with the following: + // + // vregs: + // +------+ +------+ +------+ + // |░░░ 0 | | 1 | | 2 XXX| + // +------+ +------+ +------+ + // + // where XXX is the padding and ░░░ is the prefix of the same size as the + // padding. 
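+  // (Here valid_amount = input_size % tile_size is the number of valid rows
+  // in the last vreg; the padding XXX spans the remaining
+  // tile_size - valid_amount rows, the same size as the ░░░ prefix.)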
+ + // After concatenation: + // + // concat: + // +------+ +------+ +------+ +------+ +------+ +------+ + // |░░░ 0 | | 1 | | 2 XXX| |░░░ 0 | | 1 | | 2 XXX| + // +------+ +------+ +------+ +------+ +------+ +------+ + auto concat = concatenate({vregs, vregs}, axis); + auto chunks = split(concat, axis); + int64_t original_num_chunks = chunks.size() / 2; + + Value rotate_amount = mlirI32Const(valid_amount); + SmallVector low = {mlirIndexConst(0), mlirIndexConst(0)}; + low[tiling_dim] = mlirIndexConst(valid_amount); + auto mask = builder.create( + VectorType::get(ctx.target_shape, builder.getI1Type()), low, + /*high=*/ + ArrayRef{mlirIndexConst(ctx.target_shape[0]), + mlirIndexConst(ctx.target_shape[1])}); + // overwrite padding in the last vreg with valid data from the first vreg, + // yielding: + // + // +------+ +------+ +------+ +------+ +------+ +------+ + // |░░░ 0 | | 1 | | 2 XXX| |░░░ 0 | | 1 | | 2 ░░░| + // +------+ +------+ +------+ +------+ +------+ +------+ + chunks.back().Each([&](absl::Span idxs, Value *v) { + *v = builder.create( + mask, + builder.create( + res_vreg_ty, chunks.front()(idxs), rotate_amount, tiling_dim, + nullptr, nullptr), + *v); + }); + // rotate the vregs starting from the middle vreg and then blend the vregs + // to overwrite the padding, yielding: + // + // +------+ +------+ +---+ +------+ +------+ +------+ + // |░░░ 0 | | 1 | | 2 | |░░░ 0 | | 1 | | 2 ░░░| + // +------+ +------+ +---+ +------+ +------+ +------+ + for (int64_t i = original_num_chunks; i < chunks.size(); ++i) { + chunks[i].Each([&](absl::Span idxs, Value *v) { + *v = builder.create( + res_vreg_ty, *v, rotate_amount, tiling_dim, nullptr, nullptr); + }); + } + for (int64_t i = original_num_chunks - 1; i < chunks.size() - 1; ++i) { + chunks[i].Each([&](absl::Span idxs, Value *v) { + *v = builder.create(mask, chunks[i + 1](idxs), *v); + }); + } + SmallVector result_dimensions = + layout_out.tileArrayImplicitShape(vty.getShape(), ctx.target_shape); + // assemble the result + xla::Array result(result_dimensions); + SmallVector starts(result.num_dimensions(), 0); + for (int64_t i = 0; i < result_dimensions[axis]; ++i) { + starts[axis] = i; + result.UpdateSlice(chunks[i + start_vreg_idx], starts); + } + return result; + }; + std::function(const xla::Array &, Value, int, int)> rotate; rotate = [&](const xla::Array &vregs, Value shift, int axis, @@ -2318,9 +2536,11 @@ LogicalResult rotate_rule_impl(RewriteContext &ctx, OpTy op, Value amount, CHECK((tiling_dim != 1 && stride == 0) || (tiling_dim == 1 && stride >= 0)); SmallVector, 4> chunks; // Handle rotation with static shift. 
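+    // A constant shift on a dimension with trailing padding takes the
+    // lazyRotate path below; all other cases fall through to the generic
+    // rotate-and-blend handling.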
- if (auto shift_cst = getIntConst(shift, /*silent=*/true); - succeeded(shift_cst)) { + if (auto shift_cst = getIntConst(shift)) { int64_t static_shift = shift_cst.value(); + if (has_padding_along_rotation) { + return lazyRotate(vregs, static_shift, axis); + } if (tiling_dim >= 0) { shift = mlirI32Const(static_shift % ctx.target_shape[tiling_dim]); static_shift /= ctx.target_shape[tiling_dim]; @@ -2410,7 +2630,9 @@ LogicalResult rotate_rule_impl(RewriteContext &ctx, OpTy op, Value amount, return result; }; - xla::Array out_tiles(in_tiles.dimensions()); + SmallVector out_dimensions = + layout_out.tileArrayImplicitShape(vty.getShape(), ctx.target_shape); + xla::Array out_tiles(out_dimensions); const auto dim = op.getDimension(); amount = modI(amount, vty.getDimSize(dim)); @@ -2437,8 +2659,7 @@ LogicalResult rotate_rule_impl(RewriteContext &ctx, OpTy op, Value amount, vty.getDimSize(dim)); // After applying stride, we expect all shifts in a vreg are less or // equal to the vreg's lane count for now. - if (auto base_amount_cst = getIntConst(base_amount, /*silent=*/true); - succeeded(base_amount_cst)) { + if (auto base_amount_cst = getIntConst(base_amount)) { int64_t static_base_amount = base_amount_cst.value(); auto max_shift_in_vreg = static_base_amount % ctx.target_shape[1] + (ctx.target_shape[0] - 1) * stride; @@ -2689,23 +2910,9 @@ LogicalResult tpu_concatenate_rule(RewriteContext &ctx, Operation &op, const VectorType vmask_ty = getNativeVregOrVmaskType( builder.getI1Type(), bitwidth, ctx.target_shape); if (tiling_dim.value() == 0) { // sublane - if (operand_offset % packing != 0) { - // Packed case, degenerate where we have a half or quarter - // sublane. - // TODO(mvoz): We can probably always use the - // CreateSubelementMaskOp if (1) optimize it on TPUv4 and (2) Add - // support for unpacked types in some of the invariants in - // lower_to_llo. - mask = builder.create( - op.getLoc(), vmask_ty, 0, operand_offset); - } else { - auto sublane_offset = operand_offset / packing; - mask = builder.create( - op.getLoc(), vmask_ty, - ArrayRef{boundIdxConst(0), boundIdxConst(0)}, - ArrayRef{boundIdxConst(sublane_offset), - boundIdxConst(layout->tiling()[1])}); - } + mask = createSubelementMask(builder, op.getLoc(), bitwidth, + /*from=*/0, /*to=*/operand_offset, + ctx.target_shape); } else { // lane mask = builder.create( op.getLoc(), vmask_ty, @@ -2939,11 +3146,18 @@ LogicalResult tpu_dynamic_gather_rule(RewriteContext &ctx, Operation &op, OpBuilder builder(&op); auto dy_gather_op = cast(op); - // TODO(jevinjiang): we need to think harder for general vector shape. - if (dy_gather_op.getType().getShape() != - ArrayRef(ctx.target_shape)) { + // TODO: b/423658138 - we need to think harder for general vector shape. 
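+  // Only single-vreg inputs are handled here: a 32-bit vreg covers
+  // target_shape[0] x target_shape[1] elements, while an 8-bit vreg packs four
+  // elements per 32-bit word and so spans (4 * target_shape[0]) x target_shape[1].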
+ const bool is_8bit_vreg = + dy_gather_op.getType().getElementTypeBitWidth() == 8 && + dy_gather_op.getType().getShape() == + ArrayRef{4 * ctx.target_shape[0], ctx.target_shape[1]}; + const bool is_32bit_vreg = + dy_gather_op.getType().getElementTypeBitWidth() == 32 && + dy_gather_op.getType().getShape() == ArrayRef(ctx.target_shape); + if (!is_32bit_vreg && !is_8bit_vreg) { return op.emitOpError( - "Not implemented: DynamicGatherOp only supports 32-bit VREG shape"); + "Not implemented: DynamicGatherOp only supports 8- or 32-bit VREG " + "shape"); } if (src_layout != out_layout || idx_layout != out_layout) { @@ -2952,7 +3166,7 @@ LogicalResult tpu_dynamic_gather_rule(RewriteContext &ctx, Operation &op, "result"); } - if (!out_layout.hasNaturalTopology(ctx.target_shape)) { + if (!out_layout.hasNativeTiling(ctx.target_shape)) { return op.emitOpError( "Not implemented: unsupported layout for DynamicGatherOp"); } @@ -2970,11 +3184,75 @@ LogicalResult tpu_dynamic_gather_rule(RewriteContext &ctx, Operation &op, TPU_ASSERT_EQ_OP(src_vregs.dimensions(), idx_vregs.dimensions()); TPU_ASSERT_EQ_OP(src_vregs.num_elements(), 1); + Location loc = dy_gather_op.getLoc(); + SmallVector dimensions(dy_gather_op.getDimensions()); + if (dy_gather_op.getType().getElementTypeBitWidth() == 8) { + if (dy_gather_op.getDimensions() != ArrayRef{0}) { + return dy_gather_op.emitOpError( + "Not implemented: 8-bit dynamic gather only supported along " + "dimension 0"); + } + // Vreg shape is 8x128x4, and lowering only supports dimensions == {2, 0}, + // i.e. byte index is in the upper bits and sublane index in the lower bits. + // However, the input indices effectively have sublane index in the upper + // bits and byte index in the lower bits. + VectorType i32_vreg_ty = + getNativeVregType(builder.getI32Type(), ctx.target_shape); + VectorType i8_vreg_ty = + getNativeVregType(builder.getI8Type(), ctx.target_shape); + auto i8_const_vreg = [&](const int8_t value) { + return getFullVector(builder, loc, i8_vreg_ty, + builder.getI8IntegerAttr(value)); + }; + idx_vregs.Each([&](absl::Span idxs, Value *v) { + const int sublane_bits = llvm::Log2_64(ctx.target_shape[0]); + const int byte_bits = 2; + // This check ensures that e.g. when right shifting below, the bits from + // the higher bytes don't influence the indices of the lower bytes. Lets + // us mask just once. + const bool mask_once = + sublane_bits + byte_bits + std::max(byte_bits, sublane_bits) <= 8; + if (mask_once) { + // Zero out the high bits that specify neither byte nor index (they + // might not be zero since op semantics allow wrapping). 
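+      // The rearrangement below turns each index from
+      // (sublane << byte_bits) | byte into (byte << sublane_bits) | sublane,
+      // e.g. with byte_bits = 2 and sublane_bits = 3, index 0b10110
+      // (sublane 5, byte 2) becomes 0b10101 (byte 2, sublane 5).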
+ Value mask = i8_const_vreg((1 << (byte_bits + sublane_bits)) - 1); + *v = builder.create(loc, mask, *v); + } + Value shifted_byte = *v; + if (!mask_once) { + Value mask = i8_const_vreg((1 << byte_bits) - 1); + shifted_byte = builder.create(loc, mask, shifted_byte); + } + shifted_byte = + builder.create(loc, i32_vreg_ty, shifted_byte); + shifted_byte = builder.create( + loc, shifted_byte, + getFullVector(builder, loc, i32_vreg_ty, + builder.getI32IntegerAttr(sublane_bits))); + Value shifted_sublane = *v; + if (!mask_once) { + Value mask = + i8_const_vreg((1 << (byte_bits + sublane_bits)) - (1 << byte_bits)); + shifted_sublane = + builder.create(loc, mask, shifted_sublane); + } + shifted_sublane = + builder.create(loc, i32_vreg_ty, shifted_sublane); + shifted_sublane = builder.create( + loc, shifted_sublane, + getFullVector(builder, loc, i32_vreg_ty, + builder.getI32IntegerAttr(byte_bits))); + *v = builder.create(loc, shifted_byte, shifted_sublane); + *v = builder.create(loc, i8_vreg_ty, *v); + }); + dimensions = SmallVector{2, 0}; + } + xla::Array out_vregs(src_vregs.dimensions()); out_vregs.Each([&](absl::Span idxs, Value *v) { - *v = builder.create( - op.getLoc(), src_vregs(idxs).getType(), src_vregs(idxs), - idx_vregs(idxs), dy_gather_op.getDimension()); + *v = builder.create(loc, src_vregs(idxs).getType(), + src_vregs(idxs), idx_vregs(idxs), + dimensions); }); dy_gather_op.replaceAllUsesWith( @@ -3085,7 +3363,7 @@ LogicalResult vector_load_rule(RewriteContext &ctx, Operation &op, // a bunch of loads! } else { return op.emitOpError( - "Not implemented: dismatch in memref tiling and vector tiling in " + "Not implemented: mismatch in memref tiling and vector tiling in " "load"); } } @@ -3095,7 +3373,7 @@ LogicalResult vector_load_rule(RewriteContext &ctx, Operation &op, bool must_support_unaligned_dynamic_index = false; if (load_op.getIndices().size() > 1) { auto second_minor_idx = load_op.getIndices().take_back(2)[0]; - if (failed(getIntConst(second_minor_idx, /*silent=*/true)) && + if (!getIntConst(second_minor_idx).has_value() && !isGuaranteedDivisible(second_minor_idx, memref_tiling[0])) { must_support_unaligned_dynamic_index = true; } @@ -3128,7 +3406,7 @@ LogicalResult vector_load_rule(RewriteContext &ctx, Operation &op, } auto add_idx = [&](const Value &v, int64_t d) -> Value { - if (auto cst = getIntConst(v, /*silent=*/true); succeeded(cst)) { + if (auto cst = getIntConst(v)) { return IdxConst(cst.value() + d, builder, op.getLoc()); } return builder.create(v, IdxConst(d, builder, op.getLoc())); @@ -3421,20 +3699,25 @@ LogicalResult vector_broadcast_rule(RewriteContext &ctx, Operation &op, if (tiling[1] != ctx.target_shape[1]) { return op.emitOpError("Not implemented: unsupported tiling"); } - int64_t num_tiles = layout_in.tilesPerVreg(ctx.target_shape); + const int64_t num_tiles = layout_in.tilesPerVreg(ctx.target_shape); + const int64_t sublanes_per_tile = + layout_in.sublanesPerTile(ctx.target_shape); if (needs_physical_broadcast == std::array{true, false}) { // Sublane broadcast const int packing = layout_in.packing(); - if (num_tiles != 1) { - return op.emitOpError( - "Not implemented: Only native tiling supported"); - } TPU_ASSERT_EQ_OP(*(src_tiles.dimensions().end() - 2), 1); TPU_ASSERT_OP(offsets_in[0].has_value()); const int64_t sublane_offset = *offsets_in[0] / packing; const int64_t subelement_offset = *offsets_in[0] % packing; - const DenseI32ArrayAttr indices = builder.getDenseI32ArrayAttr( - SmallVector(ctx.target_shape[0], sublane_offset)); + SmallVector pattern; + 
pattern.reserve(ctx.target_shape[0]); + for (int32_t t = 0; t < num_tiles; ++t) { + for (int32_t i = 0; i < sublanes_per_tile; ++i) { + pattern.push_back(sublanes_per_tile * t + sublane_offset); + } + } + const DenseI32ArrayAttr sublane_pattern = + builder.getDenseI32ArrayAttr(pattern); const absl::Status status = src_tiles.EachStatus([&](const absl::Span src_idx, Value *const src_vreg) { @@ -3443,16 +3726,15 @@ LogicalResult vector_broadcast_rule(RewriteContext &ctx, Operation &op, if (packing != 1) { if (auto new_dst_vreg = broadcastSubelements( builder, cast>(dst_vreg), - subelement_offset, ctx.target_shape, - ctx.hardware_generation); + subelement_offset, ctx.target_shape); succeeded(new_dst_vreg)) { dst_vreg = *new_dst_vreg; } else { return absl::InternalError(""); } } - dst_vreg = builder.create(dst_vreg.getType(), - dst_vreg, indices, 0); + dst_vreg = builder.create( + dst_vreg.getType(), dst_vreg, sublane_pattern, 0); SmallVector dst_starts(dst_tiles_implicit_shape.size()); SmallVector dst_limits(dst_tiles_implicit_shape.size()); for (int64_t i = 0; i < dst_tiles.num_dimensions(); ++i) { @@ -3474,14 +3756,13 @@ LogicalResult vector_broadcast_rule(RewriteContext &ctx, Operation &op, std::array{false, true}) { // Lane broadcast TPU_ASSERT_EQ_OP(*(src_tiles.dimensions().end() - 1), 1); TPU_ASSERT_OP(offsets_in[1].has_value()); - const int64_t sublanes_per_tile = - layout_in.sublanesPerTile(ctx.target_shape); + VectorType i32_vreg_ty = + getNativeVregType(builder.getI32Type(), ctx.target_shape); const int64_t offset = *offsets_in[1]; const int64_t lane_offset = offset % ctx.target_shape[1]; const int64_t tile_offset = offset / ctx.target_shape[1]; Value lane_offset_cst = getFullVector( - builder, getNativeVregType(builder.getI32Type(), ctx.target_shape), - builder.getI32IntegerAttr(lane_offset)); + builder, i32_vreg_ty, builder.getI32IntegerAttr(lane_offset)); DenseI32ArrayAttr sublane_pattern; if (num_tiles != 1) { SmallVector pattern; @@ -3494,7 +3775,7 @@ LogicalResult vector_broadcast_rule(RewriteContext &ctx, Operation &op, sublane_pattern = builder.getDenseI32ArrayAttr(pattern); } src_tiles.Each([&](const absl::Span src_idx, - Value *const src_tile) { + Value *const src_vreg) { SmallVector dst_starts(dst_tiles_implicit_shape.size()); SmallVector dst_limits(dst_tiles_implicit_shape.size()); for (int64_t i = 0; i < dst_tiles.num_dimensions(); ++i) { @@ -3506,10 +3787,13 @@ LogicalResult vector_broadcast_rule(RewriteContext &ctx, Operation &op, dst_limits[i] = dst_starts[i] + 1; } } - Value res_vreg = builder.create( - broadcast_op.getLoc(), src_tile->getType(), *src_tile, - lane_offset_cst, + Value src_vreg_i32 = + builder.create(i32_vreg_ty, *src_vreg); + Value res_vreg_i32 = builder.create( + broadcast_op.getLoc(), i32_vreg_ty, src_vreg_i32, lane_offset_cst, /*dimension=*/1); + Value res_vreg = builder.create( + src_vreg->getType(), res_vreg_i32); if (num_tiles != 1) { res_vreg = builder.create( broadcast_op.getLoc(), res_vreg.getType(), res_vreg, @@ -3728,10 +4012,6 @@ LogicalResult vector_extract_rule(RewriteContext &ctx, Operation &op, TPU_ASSERT_EQ_OP(layouts_out.size(), 1); TPU_ASSERT_OP(layouts_in.front().has_value()); const VectorLayout &layout_in = *layouts_in.front(); - if (layout_in.bitwidth() != 32) { - return op.emitOpError( - "Not implemented: Only 32-bit vector.extract supported"); - } const VectorType res_vty = dyn_cast(extract_op.getResult().getType()); if (res_vty != nullptr) { @@ -3760,6 +4040,10 @@ LogicalResult vector_extract_rule(RewriteContext &ctx, 
Operation &op, op.erase(); return success(); } else { + if (layout_in.bitwidth() != 32) { + return op.emitOpError( + "Not implemented: Only 32-bit vector.extract supported"); + } // TODO(b/367459476): Support non-zero offsets. if (layout_in.offsets() != LayoutOffsets{0, 0}) { return op.emitOpError("Not implemented: Unsupported layout"); @@ -4240,6 +4524,43 @@ LogicalResult vector_multi_reduction_rule(RewriteContext &ctx, Operation &op, return success(); } +// Copy one sublane from a vreg to another vreg. +// +// Arguments: +// src_vreg: The source vreg to copy a sublane from. +// src_sl_idx: The sublane index in src_vreg to copy from. +// dst_vreg: The base vreg to copy the sublane into. May be null. +// dst_sl_idx: The sublane index in the result. +// +// Returns: +// A new dst_vreg with the copied sublane. +Value copyOneSublane(OpBuilder &builder, Value src_vreg, int src_sl_idx, + Value dst_vreg, int dst_sl_idx, + const std::array target_shape) { + src_vreg = builder.create( + src_vreg.getLoc(), src_vreg, + /*amount=*/(dst_sl_idx - src_sl_idx + target_shape[0]) % target_shape[0], + /*dimension=*/0, /*stride=*/nullptr, /*stride_dimension=*/nullptr); + if (dst_vreg) { + auto boundIdxConst = + std::bind(IdxConst, std::placeholders::_1, builder, src_vreg.getLoc()); + const int bitwidth = + cast(src_vreg.getType()).getElementTypeBitWidth(); + CHECK_EQ(bitwidth, + cast(dst_vreg.getType()).getElementTypeBitWidth()); + const VectorType vmask_ty = + getNativeVregOrVmaskType(builder.getI1Type(), bitwidth, target_shape); + auto sublanes_mask = builder.create( + src_vreg.getLoc(), vmask_ty, + ValueRange{boundIdxConst(dst_sl_idx), boundIdxConst(0)}, + ValueRange{boundIdxConst(dst_sl_idx + 1), + boundIdxConst(target_shape[1])}); + src_vreg = builder.create(src_vreg.getLoc(), sublanes_mask, + src_vreg, dst_vreg); + } + return src_vreg; +} + LogicalResult vector_shape_cast_rule(RewriteContext &ctx, Operation &op, const ArrayRef layouts_in, const ArrayRef layouts_out) { @@ -4336,6 +4657,132 @@ LogicalResult vector_shape_cast_rule(RewriteContext &ctx, Operation &op, dst_vregs_local.Reshape( layout_out.tileArrayImplicitShape(dst_shape, ctx.target_shape)); return dst_vregs_local; + } else if ( + // Lower shape_casts for 32-bit types where the minor dimension both + // before and after the shape cast is a multiple of 128. We allow + // folding or unfolding multiple number of minor dimensions and folding + // or unfolding some number of leading dimensions. 
For example (given + // k % 128 == 0 in the following): + // (q, m, n, k) -> (q, m, n * k) + // (p, q, m, n, k) -> (p, q * m * n * k) + // (q, m, n, k) -> (q, m, 1, n * k) (in 2 steps, first to fold n, k then + // to add the unit dimension) + // (q, m, n, k) -> (q * m, n * k) + // (q * m, n, k) -> (q, m, n * k) + // (q * m, n * k) -> (q, m, n, k) + // (q, m, n * k) -> (q * m, n, k) + dst_shape.size() > 1 && src_shape.size() > 1 && + (mlir::tpu::canFoldMinorDimsToSize(src_shape, dst_shape.back()) || + mlir::tpu::canFoldMinorDimsToSize(dst_shape, src_shape.back())) && + dst_shape.back() % ctx.target_shape[1] == 0 && + src_shape.back() % ctx.target_shape[1] == 0 && + layout_in.offsets() == LayoutOffsets{0, 0} && + layout_in.hasNativeTiling(ctx.target_shape) && + layout_in.bitwidth() == 32 && + layout_in.implicit_dim() == VectorLayout::ImplicitDim::kNone && + layout_out == layout_in) { + auto target_sublanes = ctx.target_shape[0]; + auto target_lanes = ctx.target_shape[1]; + xla::Array dst_vregs( + layout_out.tileArrayShape(false, false, dst_shape, ctx.target_shape)); + + auto to_linear_index = [&](absl::Span indices, + absl::Span bounds) { + CHECK_EQ(indices.size(), bounds.size()); + int linear_index = 0; + int multiplier = 1; + for (int i = indices.size() - 1; i >= 0; --i) { + linear_index += multiplier * indices[i]; + multiplier *= bounds[i]; + } + return linear_index; + }; + auto from_linear_index = [&](int linear_index, + absl::Span bounds) { + SmallVector indices(bounds.size(), 0); + int64_t divisor = std::accumulate(bounds.begin(), bounds.end(), 1, + std::multiplies()); + CHECK_GT(divisor, 0); + int64_t remainder = linear_index % divisor; + for (int i = 0; i < bounds.size(); ++i) { + int64_t radix = bounds[i]; + CHECK_GT(radix, 0); + divisor /= radix; + CHECK_GT(divisor, 0); + indices[i] = remainder / divisor; + remainder = remainder % divisor; + } + return indices; + }; + // Gather sublanes from src_vregs via rotating and selecting each relevant + // sublane from the source, into the destination vreg. + // Args: + // * src_sublane_indices: the mixed-radix indices of the sublanes to + // gather in the order they should be gathered. + // * src_vregs: the vregs to gather from. + // Returns: + // * a vreg with the gathered sublanes. + auto gather_sublanes = [target_sublanes]( + RewriteContext &ctx, Operation &op, + SmallVector> + src_sublane_indices, + const xla::Array &src_vregs) { + ImplicitLocOpBuilder builder(op.getLoc(), &op); + Value dst_vreg = getZerosVector( + builder, cast(src_vregs.begin()->getType())); + for (int sublane_number = 0; + sublane_number < src_sublane_indices.size(); ++sublane_number) { + SmallVector src_vreg_index = + src_sublane_indices[sublane_number]; + src_vreg_index[src_vreg_index.size() - 2] /= target_sublanes; + Value src_vreg = src_vregs(src_vreg_index); + int sublane_within_src_vreg = + src_sublane_indices[sublane_number] + [src_sublane_indices[sublane_number].size() - + 2] % + target_sublanes; + dst_vreg = copyOneSublane(builder, src_vreg, sublane_within_src_vreg, + dst_vreg, sublane_number, ctx.target_shape); + } + return dst_vreg; + }; + SmallVector dst_shape_in_sublanes(dst_shape); + dst_shape_in_sublanes[dst_shape.size() - 1] = + dst_shape[dst_shape.size() - 1] / target_lanes; + SmallVector src_shape_in_sublanes(src_shape); + src_shape_in_sublanes[src_shape.size() - 1] = + src_shape[src_shape.size() - 1] / target_lanes; + // The algorithm operates on 1 destination vreg at a time: + // 1. 
For each destination vreg, compute the linear index of each sublane + // within it + // 2. Map the destination sublane linear index to a source sublane linear + // index + // 3. convert that to a mixed-radix index into the source shape + // 4. Gather from those source sublane indices. + SmallVector indices; + dst_vregs.Each([&](absl::Span dst_vreg_indices, + Value *dst_vreg) { + indices.assign(dst_vreg_indices.begin(), dst_vreg_indices.end()); + indices[indices.size() - 2] *= target_sublanes; + int sublane_offset = to_linear_index(indices, dst_shape_in_sublanes); + + // Only move non-padding sublanes to the destination vreg. + int num_non_padding_sublanes = std::min( + dst_shape_in_sublanes[dst_shape_in_sublanes.size() - 2] - + dst_vreg_indices[dst_vreg_indices.size() - 2] * target_sublanes, + target_sublanes); + CHECK_EQ(dst_shape.back() % target_lanes, 0); + int stride_in_sublanes = dst_shape.back() / target_lanes; + SmallVector> gathered_sublanes( + num_non_padding_sublanes); + for (int i = 0; i < gathered_sublanes.size(); ++i) { + gathered_sublanes[i] = + from_linear_index(sublane_offset, src_shape_in_sublanes); + sublane_offset += stride_in_sublanes; + } + *dst_vreg = gather_sublanes(ctx, op, gathered_sublanes, src_vregs); + }); + return dst_vregs; } else { return shape_cast_op.emitOpError( "Not implemented: Unsupported vector.shape_cast: ") @@ -4396,7 +4843,7 @@ LogicalResult vector_store_impl(RewriteContext &ctx, Op store_op, // us a bunch of stores! } else { return op.emitOpError( - "Not implemented: dismatch in memref tiling and vector tiling in " + "Not implemented: mismatch in memref tiling and vector tiling in " "store"); } } @@ -4405,7 +4852,7 @@ LogicalResult vector_store_impl(RewriteContext &ctx, Op store_op, bool must_support_unaligned_dynamic_index = false; if (store_op.getIndices().size() > 1) { auto second_minor_idx = store_op.getIndices().take_back(2)[0]; - if (failed(getIntConst(second_minor_idx, /*silent=*/true)) && + if (!getIntConst(second_minor_idx).has_value() && !isGuaranteedDivisible(second_minor_idx, memref_tiling[0])) { must_support_unaligned_dynamic_index = true; } @@ -4436,7 +4883,7 @@ LogicalResult vector_store_impl(RewriteContext &ctx, Op store_op, } auto add_idx = [&](const Value &v, int64_t d) -> Value { - if (auto cst = getIntConst(v, /*silent=*/true); succeeded(cst)) { + if (auto cst = getIntConst(v)) { return IdxConst(cst.value() + d, builder, op.getLoc()); } return builder.create(v, IdxConst(d, builder, op.getLoc())); @@ -4626,7 +5073,7 @@ LogicalResult vector_transpose_rule(RewriteContext &ctx, Operation &op, return op.emitOpError("Not implemented: Unsupported 2D layouts"); } ImplicitLocOpBuilder builder(op.getLoc(), &op); - auto transpose_op = cast(op); + auto transpose_op = cast(op); VectorType src_ty = transpose_op.getSourceVectorType(); VectorType dst_ty = transpose_op.getResultVectorType(); const int64_t rank = src_ty.getRank(); @@ -4636,11 +5083,315 @@ LogicalResult vector_transpose_rule(RewriteContext &ctx, Operation &op, ctx.target_shape)); ArrayRef permutation = transpose_op.getPermutation(); const auto tile_perm = permutation.take_back(2); + + // Major minor pemute if (tile_perm != ArrayRef{rank - 2, rank - 1} && tile_perm != ArrayRef{rank - 1, rank - 2}) { - return transpose_op->emitOpError( - "Not implemented: Unsupported permutation"); + // This is a 3 stage algorithm that uses combinations and shuffles + // to do a transposition of an 8x8 block of sublanes. 
+ // In the following algorithm description, A, B, ..., H represent 8 + // distinct input vregs that form an 8x8 block of data + // to be transposed. In our notation, B2 identifies the third + // sublane (2) of the second vreg (B)". + // + // + // If we think of each starting input vreg as a row in an 8x8 block of + // elements: + // A: A0 A1 A2 A3 A4 A5 A6 A7 + // B: B0 B1 B2 B3 B4 B5 B6 B7 + // ... + // H: H0 H1 H2 H3 H4 H5 H6 H7 + // + // The goal is to transpose this block, so the output vregs are: + // out0: A0 B0 C0 D0 E0 F0 G0 H0 + // out1: A1 B1 C1 D1 E1 F1 G1 H1 + // ... + // out7: A7 B7 C7 D7 E7 F7 G7 H7 + // + // Stage 1: Operates on pairs of input vregs (e.g., A and B). + // + // Input to Stage 1 (example pair A, B): + // A: A0 A1 A2 A3 A4 A5 A6 A7 + // B: B0 B1 B2 B3 B4 B5 B6 B7 + // + // Step 1.1: Combine low/high halves. + // combine_low(A, B) -> CL_AB: [A0 A1 A2 A3 | B0 B1 B2 B3] (8 elements) + // combine_high(A, B) -> CH_AB: [A4 A5 A6 A7 | B4 B5 B6 B7] (8 elements) + // (Notation: '|' separates the 4 elements from A and 4 from B) + // + // Step 1.2: Shuffle. + // The shuffle pattern for the low part (applied to CL_AB using + // `shuffle(CL_AB, CH_AB, pattern)`) is {0, 4, 1, 5, 2, 6, 3, 7}. + // The shuffle pattern for the high part (applied to CH_AB using + // `shuffle(CL_AB, CH_AB, pattern)`) is {8, 12, 9, 13, 10, 14, 11, 15}. + // (Indices 0-7 in shuffle refer to CL_AB, 8-15 to CH_AB). + // This results in: + // s1_AB_0: A0 B0 A1 B1 A2 B2 A3 B3 (from shuffling CL_AB elements) + // s1_AB_1: A4 B4 A5 B5 A6 B6 A7 B7 (from shuffling CH_AB elements) + // + // Output of Stage 1 / Input to Stage 2 (example for A,B,C,D processing): + // s1_vregs[0] (from A,B): A0 B0 A1 B1 A2 B2 A3 B3 + // s1_vregs[1] (from A,B): A4 B4 A5 B5 A6 B6 A7 B7 + // s1_vregs[2] (from C,D): C0 D0 C1 D1 C2 D2 C3 D3 + // s1_vregs[3] (from C,D): C4 D4 C5 D5 C6 D6 C7 D7 + // ... (and so on for E,F,G,H into s1_vregs[4-7]) + + // Stage 2: Operates on groups of 4 vregs from Stage 1 output. + // (e.g., s1_vregs[0], s1_vregs[1], s1_vregs[2], s1_vregs[3]) + // + // Input to Stage 2 (example processing s1_vregs[0] and s1_vregs[2]): + // X = s1_vregs[0] = [A0 B0 A1 B1 | A2 B2 A3 B3] + // Y = s1_vregs[2] = [C0 D0 C1 D1 | C2 D2 C3 D3] + // + // Step 2.1: Combine low/high halves. + // combine_low(X, Y) -> CL_XY: [A0 B0 A1 B1 | C0 D0 C1 D1] + // combine_high(X, Y) -> CH_XY: [A2 B2 A3 B3 | C2 D2 C3 D3] + // + // (Similarly for s1_vregs[1] and s1_vregs[3], let them be X' and Y') + // combine_low(X', Y') -> CL_X'Y': [A4 B4 A5 B5 | C4 D4 C5 D5] + // combine_high(X', Y') -> CH_X'Y': [A6 B6 A7 B7 | C6 D6 C7 D7] + // + // Step 2.2: Shuffle. + // The shuffle pattern for the low part (e.g., applied to CL_XY) is {0, 1, + // 4, 5, 2, 3, 6, 7}. The shuffle pattern for the high part (e.g., applied + // to CH_XY, effectively) is {8, 9, 12, 13, 10, 11, 14, 15}. 
+ // + // This results in (for the first group of 4 input vregs A,B,C,D): + // s2_vregs[0]: A0 B0 C0 D0 A1 B1 C1 D1 (from shuffling CL_XY elements) + // s2_vregs[1]: A2 B2 C2 D2 A3 B3 C3 D3 (from shuffling CH_XY elements) + // s2_vregs[2]: A4 B4 C4 D4 A5 B5 C5 D5 (from shuffling CL_X'Y' elements) + // s2_vregs[3]: A6 B6 C6 D6 A7 B7 C7 D7 (from shuffling CH_X'Y' elements) + // + // Output of Stage 2 / Input to Stage 3: + // s2_vregs[0]: A0 B0 C0 D0 A1 B1 C1 D1 + // s2_vregs[1]: A2 B2 C2 D2 A3 B3 C3 D3 + // s2_vregs[2]: A4 B4 C4 D4 A5 B5 C5 D5 + // s2_vregs[3]: A6 B6 C6 D6 A7 B7 C7 D7 + // s2_vregs[4]: E0 F0 G0 H0 E1 F1 G1 H1 (from E,F,G,H processing) + // s2_vregs[5]: E2 F2 G2 H2 E3 F3 G3 H3 + // s2_vregs[6]: E4 F4 G4 H4 E5 F5 G5 H5 + // s2_vregs[7]: E6 F6 G6 H6 E7 F7 G7 H7 + + // Stage 3: Combine results from Stage 2. No shuffle needed after combine. + // Input to Stage 3 (example for the first two rows of the final transpose): + // L = s2_vregs[0] = [A0 B0 C0 D0 | A1 B1 C1 D1] + // R = s2_vregs[4] = [E0 F0 G0 H0 | E1 F1 G1 H1] + // + // Step 3.1: Combine low/high halves. + // combine_low(L, R) -> [A0 B0 C0 D0 | E0 F0 G0 H0] -> + // Final out0: A0 B0 C0 D0 E0 F0 G0 H0 + // combine_high(L, R) -> [A1 B1 C1 D1 | E1 F1 G1 H1] -> + // Final out1: A1 B1 C1 D1 E1 F1 G1 H1 + // ... and so on for other pairs from Stage 2 output + // (e.g. L=s2_vregs[1], R=s2_vregs[5]). + // + // This results in the correctly transposed 8x8 block. + + constexpr int64_t kMajorDimOriginalIdx = 0; + constexpr int64_t kSecondMinorDimOriginalIdx = 1; + constexpr int64_t kMinorMostDimOriginalIdx = 2; + + auto vec_shape = src_ty.getShape(); + auto major_dim_size = vec_shape[kMajorDimOriginalIdx]; + auto second_minor_dim_size = vec_shape[kSecondMinorDimOriginalIdx]; + + if (layout_in.offsets() != LayoutOffsets{0, 0}) { + return transpose_op.emitOpError("Not implemented: Layout with offset."); + } + if (layout_in.implicit_dim() != VectorLayout::ImplicitDim::kNone) { + return transpose_op.emitOpError( + "Not implemented: Layout with implicit dimension."); + } + + auto sublane_count = ctx.target_shape[0]; + if (second_minor_dim_size % sublane_count != 0 || + major_dim_size % sublane_count != 0) { + return transpose_op.emitOpError( + "Not implemented: Swapping major and second minor dimensions must " + "result in dimension sizes that are multiples of sublane_count."); + } + + if (!layout_in.hasNativeTiling(ctx.target_shape)) { + return transpose_op.emitOpError( + "Not implemented: Expected native input tiling."); + } + if (layout_in != layout_out) { + return transpose_op.emitOpError( + "Not implemented: Expected same input and output layouts."); + } + xla::Array dst_vregs( + layout_out.tileArrayShape(dst_ty.getShape(), ctx.target_shape)); + + if (layout_in.bitwidth() != 32) { + return transpose_op.emitOpError( + "Not implemented: Major-second-minor transpose only supported for " + "32-bit vectors. Also, input must be a vector type."); + } + if (ctx.target_shape[0] != 8) { + return transpose_op.emitOpError( + "Not implemented: Major-second-minor transpose expects 8 sublanes."); + } + + auto vreg_dimensions = src_vregs.dimensions(); + // Note(mvoz): Slice is a weird word here, This is used for constructing + // the output vregs - the reason we divide here is because we multiply it + // back later on to get the correct index into src_vregs, but the reason + // we cannot just resolve that in our outer loop is because of the nature + // of a transpose - this dim value goes unmultiplied into the output vregs. 
+ // effectively, our indexing: + // {major_dim_slice_idx * sublane_count, second_minor_dim_slice_idx, + // minor_most_dim_slice_idx} becomes {second_minor_dim_slice_idx * + // sublane_count, major_dim_slice_idx, minor_most_dim_slice_idx} + auto num_slices_in_major_dim = + vreg_dimensions[kMajorDimOriginalIdx] / sublane_count; + auto num_slices_in_second_minor_dim = + vreg_dimensions[kSecondMinorDimOriginalIdx]; + auto num_slices_in_minor_most_dim = + vreg_dimensions[kMinorMostDimOriginalIdx]; + + auto shuffle = [&](Value lhs_vreg, Value rhs_vreg, ArrayRef pattern) { + auto lhs_vreg_type = lhs_vreg.getType(); + auto pattern_attr = builder.getDenseI32ArrayAttr(pattern); + return builder + .create(transpose_op.getLoc(), lhs_vreg_type, + lhs_vreg, rhs_vreg, pattern_attr) + .getResult(); + }; + + static constexpr std::array combine_low_pattern = {0, 1, 2, 3, + 8, 9, 10, 11}; + static constexpr std::array combine_high_pattern = {4, 5, 6, 7, + 12, 13, 14, 15}; + + auto combine_low = [&](Value lhs_vreg, Value rhs_vreg) { + return shuffle(lhs_vreg, rhs_vreg, combine_low_pattern); + }; + auto combine_high = [&](Value lhs_vreg, Value rhs_vreg) { + return shuffle(lhs_vreg, rhs_vreg, combine_high_pattern); + }; + + // Shuffle patterns for Stage 1 + // Input to shuffle: (combine_low_val, combine_high_val) + // combine_low_val has A0-A3, B0-B3. Indices 0-7 for shuffle. + // combine_high_val has A4-A7, B4-B7. Indices 8-15 for shuffle. + static constexpr std::array permute_pattern_stage1_low_arr = { + 0, 4, 1, 5, + 2, 6, 3, 7}; // Selects from combine_low_val to make A0B0A1B1A2B2A3B3 + static constexpr std::array permute_pattern_stage1_high_arr = { + 8, 12, 9, 13, 10, + 14, 11, 15}; // Selects from combine_high_val to make A4B4A5B5A6B6A7B7 + + // Shuffle patterns for Stage 2 + // Input to shuffle: (CL_XY, CH_XY) from Step 2.1 in comments. + // CL_XY has A0B0A1B1C0D0C1D1. Indices 0-7 for shuffle. + // CH_XY has A2B2A3B3C2D2C3D3. Indices 8-15 for shuffle. + static constexpr std::array permute_pattern_stage2_low_arr = { + 0, 1, 4, 5, 2, 3, 6, 7}; // Selects from CL_XY to make A0B0C0D0A1B1C1D1 + static constexpr std::array permute_pattern_stage2_high_arr = { + 8, 9, 12, 13, + 10, 11, 14, 15}; // Selects from CH_XY to make A2B2C2D2A3B3C3D3 + + for (int major_dim_slice_idx = 0; + major_dim_slice_idx < num_slices_in_major_dim; ++major_dim_slice_idx) { + for (int second_minor_dim_slice_idx = 0; + second_minor_dim_slice_idx < num_slices_in_second_minor_dim; + ++second_minor_dim_slice_idx) { + for (int minor_most_dim_slice_idx = 0; + minor_most_dim_slice_idx < num_slices_in_minor_most_dim; + ++minor_most_dim_slice_idx) { + // STAGE 1! + std::array + stage1_output_vregs; // Stores s1_vregs from comments + constexpr int num_pairs_stage1 = + 4; // Processes 4 pairs of vregs (A,B), (C,D), (E,F), (G,H) + + for (int i = 0; i < num_pairs_stage1; ++i) { + Value first_vreg = src_vregs( + {(2 * i) + (sublane_count * major_dim_slice_idx), + second_minor_dim_slice_idx, minor_most_dim_slice_idx}); + Value second_vreg = src_vregs( + {(2 * i) + (sublane_count * major_dim_slice_idx) + 1, + second_minor_dim_slice_idx, minor_most_dim_slice_idx}); + + auto combined_low_val = combine_low(first_vreg, second_vreg); + auto combined_high_val = combine_high(first_vreg, second_vreg); + + stage1_output_vregs[2 * i] = + shuffle(combined_low_val, combined_high_val, + permute_pattern_stage1_low_arr); + stage1_output_vregs[2 * i + 1] = + shuffle(combined_low_val, combined_high_val, + permute_pattern_stage1_high_arr); + } + + // STAGE 2! 
+ std::array + stage2_output_vregs; // Stores s2_vregs from comments + constexpr int num_pairs_stage2 = + 4; // Processes 4 pairs of vregs from stage1_output_vregs + + for (int i = 0; i < num_pairs_stage2; ++i) { + // Determine the indices for the input pair from + // stage1_output_vregs. The 4 pairs processed in this stage are: + // i=0: (s1_vregs[0], s1_vregs[2]) + // i=1: (s1_vregs[1], s1_vregs[3]) + // i=2: (s1_vregs[4], s1_vregs[6]) + // i=3: (s1_vregs[5], s1_vregs[7]) + int s1_lhs_idx = (i / 2) * 4 + (i % 2); + int s1_rhs_idx = s1_lhs_idx + 2; + + Value s1_lhs_vreg = stage1_output_vregs[s1_lhs_idx]; + Value s1_rhs_vreg = stage1_output_vregs[s1_rhs_idx]; + + auto combined_low_val = combine_low(s1_lhs_vreg, s1_rhs_vreg); + auto combined_high_val = combine_high(s1_lhs_vreg, s1_rhs_vreg); + + // Determine the output indices for stage2_output_vregs. + // Each pair from Stage 1 produces a pair of vregs for Stage 2. + // Results are stored pair-wise: + // i=0 -> s2_vregs[0], s2_vregs[1] + // i=1 -> s2_vregs[2], s2_vregs[3] + // i=2 -> s2_vregs[4], s2_vregs[5] + // i=3 -> s2_vregs[6], s2_vregs[7] + int s2_out_idx_base = 2 * i; + + stage2_output_vregs[s2_out_idx_base] = + shuffle(combined_low_val, combined_high_val, + permute_pattern_stage2_low_arr); + stage2_output_vregs[s2_out_idx_base + 1] = + shuffle(combined_low_val, combined_high_val, + permute_pattern_stage2_high_arr); + } + + // STAGE 3! Combine results from stage 2. + std::array output_idx_parts{ + second_minor_dim_slice_idx * sublane_count, major_dim_slice_idx, + minor_most_dim_slice_idx}; + + constexpr int num_final_combines = + 4; // Corresponds to s2_vregs[0]..s2_vregs[3] pairing with + // s2_vregs[4]..s2_vregs[7] + for (int i = 0; i < num_final_combines; ++i) { + Value lhs = stage2_output_vregs[i]; // e.g., s2_ABCD_0 + Value rhs = stage2_output_vregs[i + 4]; // e.g., s2_EFGH_0 + auto final_combined_low = combine_low(lhs, rhs); + auto final_combined_high = combine_high(lhs, rhs); + + dst_vregs(output_idx_parts) = final_combined_low; + output_idx_parts[0] += 1; + dst_vregs(output_idx_parts) = final_combined_high; + output_idx_parts[0] += 1; + } + } + } + } + auto assembled = + assemble(builder, dst_ty, layout_out, dst_vregs, ctx.target_shape); + transpose_op.getOperation()->replaceAllUsesWith(assembled); + transpose_op.erase(); + return success(); } + { SmallVector p(permutation); p[rank - 2] = rank - 2; @@ -4709,7 +5460,7 @@ LogicalResult vector_transpose_rule(RewriteContext &ctx, Operation &op, const Value src_tile = assemble(builder, tile_ty_in, layout_in, src_tile_vregs, ctx.target_shape); auto new_transpose_op = - builder.create(tile_ty_out, src_tile, minor_perm); + builder.create(tile_ty_out, src_tile, minor_perm); new_transpose_op->setAttr("out_layout", builder.getAttr(layout_out)); auto unroll_vectors_op = builder.create( @@ -4801,60 +5552,6 @@ LogicalResult tpu_prng_random_bits_rule(RewriteContext &ctx, Operation &op, return success(); } -const llvm::StringMap &rules() { - static const llvm::StringMap *rules = [] { - static auto rules = new llvm::StringMap{ - {arith::ConstantOp::getOperationName(), arith_constant_rule}, - {arith::ExtFOp::getOperationName(), arith_extf_rule}, - {arith::ExtSIOp::getOperationName(), arith_extsi_rule}, - {arith::ExtUIOp::getOperationName(), arith_extui_rule}, - {arith::TruncFOp::getOperationName(), arith_truncf_rule}, - {arith::TruncIOp::getOperationName(), arith_trunci_rule}, - {func::ReturnOp::getOperationName(), func_return_rule}, - {scf::ForOp::getOperationName(), scf_for_rule}, - 
{scf::WhileOp::getOperationName(), scf_while_rule}, - {scf::ConditionOp::getOperationName(), scf_condition_rule}, - {scf::IfOp::getOperationName(), scf_if_rule}, - {scf::YieldOp::getOperationName(), yield_rule}, - {tpu::YieldOp::getOperationName(), yield_rule}, - {tpu::RotateOp::getOperationName(), tpu_rotate_rule}, - {tpu::DynamicRotateOp::getOperationName(), tpu_dynamic_rotate_rule}, - {tpu::ConcatenateOp::getOperationName(), tpu_concatenate_rule}, - {tpu::IotaOp::getOperationName(), tpu_iota_rule}, - {tpu::GatherOp::getOperationName(), tpu_gather_rule}, - {tpu::DynamicGatherOp::getOperationName(), tpu_dynamic_gather_rule}, - {tpu::LoadOp::getOperationName(), tpu_load_rule}, - {tpu::StoreOp::getOperationName(), tpu_store_rule}, - {tpu::StridedLoadOp::getOperationName(), tpu_strided_load_rule}, - {tpu::StridedStoreOp::getOperationName(), tpu_strided_store_rule}, - {tpu::VectorStoreOp::getOperationName(), tpu_vector_store_rule}, - {tpu::MatmulOp::getOperationName(), tpu_matmul_rule}, - {tpu::RegionOp::getOperationName(), tpu_region_rule}, - {tpu::BitcastOp::getOperationName(), tpu_bitcast_rule}, - {tpu::TraceOp::getOperationName(), tpu_trace_rule}, - {tpu::AssumeLayoutOp::getOperationName(), tpu_assume_layout_rule}, - {tpu::PRNGRandomBitsOp::getOperationName(), tpu_prng_random_bits_rule}, - {tpu::RelayoutOp::getOperationName(), tpu_relayout_rule}, - {tpu::FPToSIOp::getOperationName(), tpu_fptosi_rule}, - {vector::BroadcastOp::getOperationName(), vector_broadcast_rule}, - {vector::ExtractOp::getOperationName(), vector_extract_rule}, - {vector::LoadOp::getOperationName(), vector_load_rule}, - {vector::MultiDimReductionOp::getOperationName(), - vector_multi_reduction_rule}, - {vector::ExtractStridedSliceOp::getOperationName(), - vector_extract_strided_slice_rule}, - {vector::ShapeCastOp::getOperationName(), vector_shape_cast_rule}, - {vector::StoreOp::getOperationName(), vector_store_rule}, - {vector::TransposeOp::getOperationName(), vector_transpose_rule}}; - - for (const auto &[name, rule] : mlir::tpu::extensions::rules()) { - rules->insert({name, rule}); - } - return rules; - }(); - return *rules; -} - // Determines whether we should handle bank conflict for the given stride and // max_sublane_offset. // @@ -5255,221 +5952,6 @@ xla::Array retileToReducedSublanes( return dst_vreg_array; } - -// Copy one sublane from a vreg to another vreg. -// -// Arguments: -// src_vreg: The source vreg to copy a sublane from. -// src_sl_idx: The sublane index in src_vreg to copy from. -// dst_vreg: The base vreg to copy the sublane into. May be null. -// dst_sl_idx: The sublane index in the result. -// -// Returns: -// A new dst_vreg with the copied sublane. 
-Value copy_one_sublane(OpBuilder &builder, Value src_vreg, int src_sl_idx, - Value dst_vreg, int dst_sl_idx, - const std::array target_shape) { - src_vreg = builder.create( - src_vreg.getLoc(), src_vreg, - /*amount=*/(dst_sl_idx - src_sl_idx + target_shape[0]) % target_shape[0], - /*dimension=*/0, /*stride=*/nullptr, /*stride_dimension=*/nullptr); - if (dst_vreg) { - auto boundIdxConst = - std::bind(IdxConst, std::placeholders::_1, builder, src_vreg.getLoc()); - const int bitwidth = - cast(src_vreg.getType()).getElementTypeBitWidth(); - CHECK_EQ(bitwidth, - cast(dst_vreg.getType()).getElementTypeBitWidth()); - const VectorType vmask_ty = - getNativeVregOrVmaskType(builder.getI1Type(), bitwidth, target_shape); - auto sublanes_mask = builder.create( - src_vreg.getLoc(), vmask_ty, - ValueRange{boundIdxConst(dst_sl_idx), boundIdxConst(0)}, - ValueRange{boundIdxConst(dst_sl_idx + 1), - boundIdxConst(target_shape[1])}); - src_vreg = builder.create(src_vreg.getLoc(), sublanes_mask, - src_vreg, dst_vreg); - } - return src_vreg; -} - -// This function is based on tpu_rotate_rule. It applies a shift of amount to -// a given dim. A major difference is that it "overflows", i.e. if the shift -// amount is such that it pushes us into a new vreg, we create a new vreg and -// fill it in with the remaining rows. -// -// The shift is the difference between layout_in and layout_out, on the -// given dim. -FailureOr> tpu_rotate_with_overflow( - OpBuilder &builder, const std::array target_shape, - const Location loc, const VectorType vty, xla::Array in_tiles, - int64_t dim, const VectorLayout &layout_in, - const LayoutOffsets offsets_out) { - if (!layout_in.hasNativeTiling(target_shape)) { - return emitError(loc, "Not implemented: non-native tiling for layout"); - } - if (layout_in.bitwidth() != 32) { - return emitError(loc, - "Not implemented: multi-row shift with " - "bitwidth != 32"); - } - // TODO(apaszke,mvoz): Just use offsets_out instead of this. - VectorLayout layout_out(layout_in.bitwidth(), offsets_out, layout_in.tiling(), - layout_in.implicit_dim()); - - int64_t tiling_dim = dim - (in_tiles.num_dimensions() - 2); - if (tiling_dim != 0) { - return emitError(loc, - "Rotate with overflow untested for " - "dim != 0"); - } - auto amount = - *layout_out.offsets()[tiling_dim] - *layout_in.offsets()[tiling_dim]; - - SmallVector dst_tiles_shape = - layout_out.tileArrayImplicitShape(vty.getShape(), target_shape); - - const VectorType res_vreg_ty = - getNativeVregType(vty.getElementType(), target_shape); - - xla::Array out_tiles(dst_tiles_shape); - - // We update the result vregs in the following way: - // - If the offset is positive, write the first tile as is, if the offset - // is negative, blend it with the next tile. - // - Blend the rest of the tiles with the prior (positive offset) or next - // (negative offset) tile. - // - (In positive cases, we can get an extra vreg (overflow)) we write the - // remaining tiles. - // This only happens if the original input vreg size is smaller than the - // result vreg size (an offset) can "push" us into a new vreg. - // - // Ex: (30, 128), starting offset 0, shift by 6, native tiling (8, 128) - // The input is (4, 1), where the first 3 vregs are full (0-24) - // and the last vreg is filled in rows 0-6. When we offset it by 6, we - // need a 4th vreg, as now vreg 0 is filled in 6-8 (2 total), vreg 1, 2, 3 - // are filled in fully (8-16, 16-24, 24-32) (2 + 24 total), and vreg 4 is - // filled in 0-4. (2 + 24 + 4 = 30). 
- - // Negative offset amount means we: - // - // Ex 1: (30, 128), input offset 6, shift by -2, native tiling (8, 128) - // (The result of the last example, for simplicity). In this case, we have - // (5, 1) vregs as decribed above. Because the shift does not cause us to - // shift back from the 5th vreg, we still need it. In such a case, the result - // vreg is still (5, 1). - // - // - Write the first vreg as is. - // - The next vregs are blended with the prior one (except the last), - // where we blend by the shift amount. Ex: Vreg 1 goes from 6-8 to 4-8, - // pulling 2 rows from the next vreg. - // - The last tile is masked to only write the remaining rows. - // Ex: Vreg 4 goes from 0-4 to 0-2. - // - // Ex 2: (30, 128), starting offset 6, shift by -6, native tiling (8, 128) - // In this case, we have (5, 1) vregs as described above. Because the shift - // causes us to shift back from the 5th vreg, we don't need it anymore. - // In such a case, the result vreg is (4, 1). - // - // - All vregs are blended with the next one (except the last), - // where we blend by the shift amount. Ex: Vreg 1 goes from 6-8 to 0-8, - // pulling 6 rows from the next vreg. - // - The last tile is discarded - it was fully subsumed by the prior blends. - // - // Ex 3: (30, 128), starting offset 0, shift by -6, native tiling (8, 128) - // In this case, we have (4, 1) vregs as described above. - // In such a case, the result vreg is (4, 1), where the first vreg is filled - // in rows 2-8 (6), and vregs 1 and 2 are filled in fully (8-16, 16-24), and - // vreg 3 is filled in rows 0-6. - // - // NOTE - in such cases, where the abs(shift) in a negative shift > starting - // offset, we can actually implement this as a positive shift of the delta - // from the native tile size. - // in the example above, the delta is 8 - 6 + 0 = 2. The resulting vregs are - // the same as if we had shifted by 2, starting at offset 0. - // - // Another example to demonstrate the point: - // Ex 4: (30, 128), starting offset 2, shift by -4, native tiling (8, 128) - // In this case, we start with (4, 1) vregs as described above. - // (2-8)(8-16)(16-24)(0-4). Shifting by -4 is the same as 8 - 4 + 2 = 6. - // So we can just shift by 6, starting at offset 0. - // Vreg 0 is filled in 6-8 (2 total), vreg 1, 2 and 3 are filled in fully - // (8-16, 16-24, 24-32) (2 + 24 total = 26) vreg 4 is filled with the - // remainder, 0-4 (30 total). - // - // This means that no matter what the shift is, we should always - // rotate and compute the shift amount in such a way that the first input - // vreg is the first output vreg. - - // Compute the mask for the blend. - // Positive blends blend "forward" and negative blends blend "backward". - auto mask_val = amount; - auto vreg_rot_amount = amount; - if (amount < 0) { - mask_val = layout_in.tiling()[tiling_dim] - std::abs(amount); - vreg_rot_amount += target_shape[tiling_dim]; - } - auto boundIdxConst = std::bind(IdxConst, std::placeholders::_1, builder, loc); - auto mask = builder.create( - loc, VectorType::get(target_shape, builder.getI1Type()), - ValueRange{boundIdxConst(0), boundIdxConst(0)}, - ValueRange{boundIdxConst(mask_val), boundIdxConst(target_shape[1])}); - - // Actually do the rotation. - in_tiles.Each([&](absl::Span idxs, Value *v) { - if (dim >= in_tiles.num_dimensions() - 2) { - *v = builder.create(loc, res_vreg_ty, in_tiles(idxs), - vreg_rot_amount, tiling_dim, nullptr, - nullptr); - } - }); - - // Walk the result tiles. 
- // TODO(mvoz): There is a micro-optimization here where we can avoid - // allocating blend indices per vreg. - out_tiles.Each([&](absl::Span idxs, Value *v) { - if (idxs[dim] == 0) { - // A negative shift amount means we need to blend the first tile with the - // next one, but only if we're not at the end of the input. - if (amount < 0 && (idxs[dim] + 1 < in_tiles.dim(dim))) { - SmallVector next_idx = {idxs.begin(), idxs.end()}; - next_idx[dim] = idxs[dim] + 1; - *v = builder.create(loc, mask, in_tiles(idxs), - in_tiles(next_idx)); - } else { - // Positive shift, or negative shift at the end of the input. - *v = in_tiles(idxs); - } - } else if (idxs[dim] < in_tiles.dim(dim)) { - // write the rest as blended up to the end of the input - if (amount < 0) { - if (idxs[dim] + 1 < in_tiles.dim(dim)) { - SmallVector next_idx = {idxs.begin(), idxs.end()}; - next_idx[dim] = idxs[dim] + 1; - *v = builder.create(loc, mask, in_tiles(idxs), - in_tiles(next_idx)); - } else { - // Nothing to blend with, just write the last tile. - *v = in_tiles(idxs); - } - } else { - SmallVector prior_idx = {idxs.begin(), idxs.end()}; - prior_idx[dim] = idxs[dim] - 1; - *v = builder.create(loc, mask, in_tiles(prior_idx), - in_tiles(idxs)); - } - } else { - // write trailing if it's there (positive shift, increasing vreg count) - // Use the last prior - SmallVector prior_idx = {idxs.begin(), idxs.end()}; - prior_idx[dim] = idxs[dim] - 1; - *v = in_tiles(prior_idx); - } - }); - - return out_tiles; -} - void rotateVregs(OpBuilder &builder, xla::Array &vregs, const int64_t amount, const int dimension) { if (amount != 0) { @@ -5496,6 +5978,214 @@ void rotateLanes(OpBuilder &builder, xla::Array &vregs, rotateVregs(builder, vregs, amount, 1); } +// Rotate a vreg by a certain amount of rows, and get the low or high bits of +// each sublane after rotation. +// +// For these purposes, the vreg is considered to have shape (row_packing * +// target_shape[0], target_shape[1]) +// +// Note: When rotating by a whole number of sublanes, there are no low bits, so +// null is returned when is_high is false. +// +// Args: +// vreg: The vreg to rotate +// rotate_amount: The amount to rotate the vreg by. +// rows_per_sublane: The number of rows in a sublane. +// is_high: If true, get the high bits of each sublane, otherwise get low bits. +// +// Returns: +// The rotated vreg. +Value rotateVregRows(OpBuilder &builder, Location loc, Value vreg, + const int64_t rotate_amount, + const int64_t rows_per_sublane, const bool is_high, + const std::array target_shape) { + CHECK_LE(0, rotate_amount); + CHECK_LT(0, rows_per_sublane); + const int64_t bits_per_row = 32 / rows_per_sublane; + const int64_t sublane_rotate_amount = + (rotate_amount / rows_per_sublane + (is_high ? 
0 : 1)) % target_shape[0]; + const int64_t within_sublane_rotate_amount = rotate_amount % rows_per_sublane; + if (within_sublane_rotate_amount == 0 && !is_high) { + return nullptr; + } + if (within_sublane_rotate_amount != 0) { + const VectorType vreg_ty = cast(vreg.getType()); + const VectorType i32_vreg_ty = + getNativeVregType(builder.getI32Type(), target_shape); + vreg = builder.create(loc, i32_vreg_ty, vreg); + if (is_high) { + auto shift_amt = builder.create( + loc, + DenseElementsAttr::get( + i32_vreg_ty, static_cast(bits_per_row * + within_sublane_rotate_amount))); + vreg = builder.create(loc, vreg, shift_amt); + } else { + auto shift_amt = builder.create( + loc, + DenseElementsAttr::get( + i32_vreg_ty, static_cast( + bits_per_row * (rows_per_sublane - + within_sublane_rotate_amount)))); + vreg = builder.create(loc, vreg, shift_amt); + } + vreg = builder.create(loc, vreg_ty, vreg); + } + return builder.create(vreg.getLoc(), vreg, + /*amount=*/sublane_rotate_amount, + /*dimension=*/0, /*stride=*/nullptr, + /*stride_dimension=*/nullptr); +} + +FailureOr> doRowShiftRelayout( + OpBuilder &builder, const Location loc, const ArrayRef shape, + xla::Array src_vregs, const VectorLayout &src_layout, + const int64_t dst_row_offset, const std::array target_shape) { + constexpr int32_t kNativeBitwidth = 32; + const std::array tiling = src_layout.tiling(); + const std::array tiled_ishape = + src_layout.getImplicitTiledDims(shape, 1); + const int64_t sublanes_per_tile = src_layout.sublanesPerTile(target_shape); + const int64_t tiles_per_vreg = src_layout.tilesPerVreg(target_shape); + const LayoutOffsets &src_offsets = src_layout.offsets(); + CHECK(src_offsets[0].has_value()); + CHECK_GE(*src_offsets[0], 0); + CHECK_LT(*src_offsets[0], tiling[0]); + CHECK_GE(dst_row_offset, 0); + CHECK_LT(dst_row_offset, tiling[0]); + CHECK_EQ(tiling[0] % sublanes_per_tile, 0); + const int64_t rows_per_sublane = tiling[0] / sublanes_per_tile; + const int64_t bits_per_row = kNativeBitwidth / rows_per_sublane; + const int64_t row_shift_amount = dst_row_offset - *src_offsets[0]; + // How many rows to shift (positive): + const int64_t shift_in_tile = (row_shift_amount + tiling[0]) % tiling[0]; + // How many rows to shift within a single sublane: + const int64_t shift_in_sublane = shift_in_tile % rows_per_sublane; + CHECK(src_vregs.begin() != src_vregs.end()); + const VectorType vreg_ty = cast(src_vregs.begin()->getType()); + const VectorType int_vreg_ty = + getNativeVregType(builder.getIntegerType(bits_per_row), target_shape); + + // The mask selects the first row_shift_amount full/half/quarter/etc-sublanes + // of each tile that contains data. + Value mask = nullptr; + for (int64_t i = 0; i < tiles_per_vreg; ++i) { + const int64_t start = i * sublanes_per_tile * rows_per_sublane; + // TODO: b/412753800 - Skip tiles that never contain data + Value tile_mask = + createSubelementMask(builder, loc, bits_per_row, /*from=*/start, + /*to=*/start + shift_in_tile, target_shape); + mask = mask == nullptr ? tile_mask + : builder.create(loc, mask, tile_mask); + } + + xla::Array res_vregs( + VectorLayout(src_layout.bitwidth(), {dst_row_offset, src_offsets[1]}, + src_layout.tiling(), src_layout.implicit_dim()) + .tileArrayImplicitShape(shape, target_shape)); + // rotate_rows_and_blend returns the combined high and low bits of a vreg + // after rotation by shift_in_tile. data_start and data_end (exclusive) are + // the rows of interest in the resulting vreg. 
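The rotate-and-blend step is easier to see on a scalar model. The sketch below (standalone, not part of the patch) treats one lane of a vreg as 8 packed 32-bit sublanes and assumes, purely for illustration, that row j of a sublane occupies bits [j * bits_per_row, (j + 1) * bits_per_row). Under that assumption the "high" result shifts left within the sublane and rotates by the whole-sublane amount, the "low" result shifts right and rotates one sublane further, and OR-ing the two reproduces a full row rotation, which is what rotate_rows_and_blend relies on.

#include <array>
#include <cassert>
#include <cstdint>
#include <optional>

using Sublanes = std::array<uint32_t, 8>;

Sublanes rotate_sublanes(const Sublanes& v, int amt) {  // Sublane i -> i + amt.
  Sublanes out{};
  for (int i = 0; i < 8; ++i) out[(i + amt) % 8] = v[i];
  return out;
}

// Model of rotateVregRows: a whole-sublane rotate plus an in-sublane bit
// shift; returns nullopt when there are no low bits (whole-sublane rotation).
std::optional<Sublanes> rotate_vreg_rows(Sublanes v, int amount,
                                         int rows_per_sublane, bool is_high) {
  const int bits_per_row = 32 / rows_per_sublane;
  const int sublane_amt = (amount / rows_per_sublane + (is_high ? 0 : 1)) % 8;
  const int within = amount % rows_per_sublane;
  if (within == 0 && !is_high) return std::nullopt;
  if (within != 0) {
    for (uint32_t& w : v)
      w = is_high ? w << (bits_per_row * within)
                  : w >> (bits_per_row * (rows_per_sublane - within));
  }
  return rotate_sublanes(v, sublane_amt);
}

int main() {
  const int rows_per_sublane = 4;  // e.g. 8-bit rows packed into 32-bit sublanes.
  const int bits_per_row = 32 / rows_per_sublane;
  const int total_rows = rows_per_sublane * 8;

  auto get_row = [&](const Sublanes& v, int r) -> uint32_t {
    return (v[r / rows_per_sublane] >>
            (bits_per_row * (r % rows_per_sublane))) &
           ((1u << bits_per_row) - 1);
  };

  Sublanes v{};  // Row r holds the value r.
  for (int r = 0; r < total_rows; ++r)
    v[r / rows_per_sublane] |= uint32_t(r)
                               << (bits_per_row * (r % rows_per_sublane));

  for (int amount = 0; amount < total_rows; ++amount) {
    auto high = rotate_vreg_rows(v, amount, rows_per_sublane, /*is_high=*/true);
    auto low = rotate_vreg_rows(v, amount, rows_per_sublane, /*is_high=*/false);
    Sublanes blended{};
    for (int i = 0; i < 8; ++i) blended[i] = (*high)[i] | (low ? (*low)[i] : 0);
    // Blending the high and low bits reproduces a full rotation by `amount`.
    for (int r = 0; r < total_rows; ++r)
      assert(get_row(blended, (r + amount) % total_rows) == uint32_t(r));
  }
}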
+ auto rotate_rows_and_blend = [&](Value vreg, const int64_t data_start, + const int64_t data_end) -> Value { + CHECK(vreg != nullptr); + // The split between low and high bits is at shift_in_sublane rows. + Value low_bits, high_bits; + // start_sublane is the first sublane in a tile that contains data + const int64_t start_sublane = data_start / rows_per_sublane; + // end_sublane the last sublane in a tile that contains data, inclusive + const int64_t end_sublane = (data_end - 1) / rows_per_sublane; + + // If data is in the high bits only, skip low bits + // This happens iff data is in a single sublane and begins after the split + if (start_sublane != end_sublane || + data_start % rows_per_sublane < shift_in_sublane) { + // Note that if shift_in_sublane is 0, rotateVregRows will return null + // since there are no low bits. + low_bits = + rotateVregRows(builder, loc, vreg, shift_in_tile, rows_per_sublane, + /*is_high=*/false, target_shape); + } + // If data is in the low bits only, skip high bits + // This happens iff data is in a single sublane and ends before the split + if (start_sublane != end_sublane || + (data_end - 1) % rows_per_sublane >= shift_in_sublane) { + high_bits = + rotateVregRows(builder, loc, vreg, shift_in_tile, rows_per_sublane, + /*is_high=*/true, target_shape); + } + if (low_bits != nullptr && high_bits != nullptr) { + return builder.create(loc, low_bits, high_bits); + } else if (low_bits != nullptr) { + return low_bits; + } else { + CHECK(high_bits != nullptr); + return high_bits; + } + }; + const int64_t res_low_idx_delta = *src_offsets[0] < dst_row_offset ? -1 : 0; + const int64_t res_high_idx_delta = *src_offsets[0] < dst_row_offset ? 0 : 1; + res_vregs.Each([&](absl::Span idxs, Value *v) { + // Each vreg of the result is (usually) a combination of two vregs from the + // source. If we are shifting *down* by 5 rows, the first 5 rows of result + // vreg i (along 2nd minor) will come from source vreg i-1, while the + // following rows will come from source vreg i. + + // The split of data between low and high is at shift_in_tile rows. + Value low, high; + // The start row of data in the vreg + const int64_t res_data_start = *(idxs.end() - 2) == 0 ? dst_row_offset : 0; + // The end row of data in the vreg, exclusive + const int64_t res_data_end = + *(idxs.end() - 2) == *(res_vregs.dimensions().end() - 2) - 1 + // -+ 1 before/after modulo so result is (1, tiling[0]) inclusive + ? (dst_row_offset + tiled_ishape[0] - 1) % tiling[0] + 1 + : tiling[0]; + // If data begins after the split, skip the low rows + if (res_data_start < shift_in_tile) { + SmallVector low_idxs(toArrayRef(idxs)); + *(low_idxs.end() - 2) += res_low_idx_delta; + low = builder.create(loc, int_vreg_ty, + src_vregs(low_idxs)); + low = rotate_rows_and_blend( + low, res_data_start, + /*data_end=*/std::min(res_data_end, shift_in_tile)); + // By doing the tile rotate after, rotate_rows_and_blend can be CSE'd + // since the low part of this vreg is the high part of the previous vreg. + // If there is no next previous or there is no benefit in CSE (e.g. we + // only use high bits and next vreg only uses low bits), the rotates + // should get merged anyway. + // TODO(tlongeri): Think more about the order in which rotates happen. + // Doing OR before rotate may be better. + low = builder.create( + loc, low, (tiles_per_vreg - 1) * sublanes_per_tile, 0, nullptr, + nullptr); + } + // If data ends before the split, skip high rows. 
+ if (res_data_end > shift_in_tile) { + SmallVector high_idxs(toArrayRef(idxs)); + *(high_idxs.end() - 2) += res_high_idx_delta; + high = builder.create(loc, int_vreg_ty, + src_vregs(high_idxs)); + high = rotate_rows_and_blend( + high, + /*data_start=*/std::max(res_data_start, shift_in_tile), res_data_end); + } + + if (low != nullptr && high != nullptr) { + *v = builder.create(loc, mask, low, high); + } else if (low != nullptr) { + *v = low; + } else { + CHECK(high != nullptr); + *v = high; + } + *v = builder.create(loc, vreg_ty, *v); + }); + + return res_vregs; +} + // Relayout src_vregs from layout src to layout dst, where dst is the same as // src except that the column offset is dst_col_offset. FailureOr> doColumnShiftRelayout( @@ -5751,10 +6441,6 @@ FailureOr>> changeOffsets( const VectorType vty, const VectorLayout src, xla::Array vregs, const LayoutOffsets dst_offsets) { const auto &target_shape = ctx.target_shape; - const VectorLayout dst(src.bitwidth(), dst_offsets, src.tiling(), - src.implicit_dim()); - const int packing = src.packing(); - const int8_t bitwidth = src.bitwidth(); int row_diff; if (!src.offsets()[0].has_value()) { @@ -5774,77 +6460,28 @@ FailureOr>> changeOffsets( col_diff = *dst_offsets[1] - *src.offsets()[1]; } + VectorLayout src_after_row_shift(src.bitwidth(), + {dst_offsets[0], src.offsets()[1]}, + src.tiling(), src.implicit_dim()); if (row_diff != 0) { - if (col_diff != 0) { - return emitError(loc, "Not implemented: Row and column offset changes"); - } - const SmallVector implicit_shape = - src.implicitShape(vty.getShape()); - if (implicit_shape[implicit_shape.size() - 2] != 1) { - // Multi row shift - // TODO(mvoz): This should take the vregs array, not the value. - FAILUREOR_ASSIGN_OR_RETURN( - vregs, tpu_rotate_with_overflow( - builder, target_shape, loc, vty, std::move(vregs), - /*dim*/ implicit_shape.size() - 2, src, dst_offsets)); - } else { - // Single row case - // TODO(mvoz): The single row case has a broader set of supported - // operations: non-native tiling, packed types, implicit dim. We should - // support these cases in tpu_rotate_with_overflow and remove this - // branch. - const int64_t src_sublane = *src.offsets()[0] / packing; - const int64_t dst_sublane = *dst_offsets[0] / packing; - if (int64_t sublane_diff = dst_sublane - src_sublane) { - if (sublane_diff < 0) { - sublane_diff += target_shape[0]; - } - rotateSublanes(builder, vregs, sublane_diff); - } - const int src_subelem = *src.offsets()[0] % packing; - const int dst_subelem = *dst.offsets()[0] % packing; - if (src_subelem != dst_subelem) { - const int subelem_diff = dst_subelem - src_subelem; - const int shift_bits = bitwidth * std::abs(subelem_diff); - VectorType bits_vreg_ty = - VectorType::get(target_shape, builder.getI32Type()); - auto shift_vreg = builder.create( - loc, bits_vreg_ty, - DenseElementsAttr::get(bits_vreg_ty, shift_bits)); - vregs.Each([&](absl::Span /*idx*/, Value *tile) { - auto bit_tile = - builder.create(loc, bits_vreg_ty, *tile); - Operation *shift_tile; - if (subelem_diff > 0) { - shift_tile = - builder.create(loc, bit_tile, shift_vreg); - } else { // subelem_diff < 0 - CHECK_LT(subelem_diff, 0); - shift_tile = - builder.create(loc, bit_tile, shift_vreg); - } - *tile = builder - .create(loc, tile->getType(), - shift_tile->getResult(0)) - .getResult(); - }); - } - } + FAILUREOR_ASSIGN_OR_RETURN( + vregs, doRowShiftRelayout(builder, loc, vty.getShape(), vregs, src, + *dst_offsets[0], ctx.target_shape)); + // Make sure the shape is as expected. 
+ SmallVector current_tiles_shape = + src_after_row_shift.tileArrayImplicitShape(vty.getShape(), + target_shape); + CHECK_EQ(*(current_tiles_shape.end() - 2), *(vregs.dimensions().end() - 2)); } - // Rows are now correctly aligned. Time to offset columns. - // TODO(apaszke, mvoz): Changing an offset might add or remove one vreg. - // Note - this is handled for row shifts via tpu_rotate_with_overflow - SmallVector dst_tiles_shape = - dst.tileArrayImplicitShape(vty.getShape(), target_shape); - CHECK_EQ(*(dst_tiles_shape.end() - 2), *(vregs.dimensions().end() - 2)); - - // TODO(tlongeri): Clean up col_diff and pass the dst offset directly. if (col_diff != 0) { FAILUREOR_ASSIGN_OR_RETURN( vregs, doColumnShiftRelayout(builder, vty.getShape(), std::move(vregs), - src, *dst.offsets()[1], target_shape)); + src_after_row_shift, *dst_offsets[1], + target_shape)); } + VectorLayout dst(src.bitwidth(), dst_offsets, src.tiling(), + src.implicit_dim()); return std::make_pair(dst, std::move(vregs)); } @@ -6622,7 +7259,7 @@ FailureOr>> changeTiling( FailureOr>> changeImplicitDim( RewriteContext &ctx, OpBuilder &builder, const Location loc, VectorType vty, - const VectorLayout src, xla::Array vregs, + VectorLayout src, xla::Array vregs, const VectorLayout::ImplicitDim dst_implicit_dim, const LayoutOffsets dst_offset_hints) { const auto &target_shape = ctx.target_shape; @@ -6637,6 +7274,59 @@ FailureOr>> changeImplicitDim( src_candidate.tileArrayImplicitShape(vty.getShape(), target_shape)); return std::make_pair(src_candidate, vregs); } + const int64_t sublanes_per_tile = src.sublanesPerTile(target_shape); + CHECK_GT(sublanes_per_tile, 0); + if (src.tiling()[0] % sublanes_per_tile != 0) { + // Tilings such as 32-bit (4, 256) are not used and not supported. + return emitError( + loc, "Not implemented: Rows within tile span multiple sublanes"); + } + const int64_t rows_per_sublane = src.tiling()[0] / sublanes_per_tile; + // Add second minor implicit dim + if (src.implicit_dim() == VectorLayout::ImplicitDim::kNone && + dst_implicit_dim == VectorLayout::ImplicitDim::kSecondMinor) { + // TODO(tlongeri): Detect replicated source 2nd minor as a no-op above + const int64_t src_offset = src.offsets()[0].value_or(0); + // TODO(tlongeri): Do broadcast (different path) for replicated output + const int64_t dst_offset = dst_offset_hints[0].value_or(0); + VectorLayout dst(src.bitwidth(), {dst_offset, src.offsets()[1]}, + src.tiling(), dst_implicit_dim); + xla::Array new_vregs( + dst.tileArrayImplicitShape(vty.getShape(), target_shape)); + DCHECK_EQ(*(new_vregs.dimensions().end() - 2), 1); + // Define src_idx outside loop to avoid reallocation + SmallVector src_idx; + new_vregs.Each([&](const absl::Span idx, Value *new_vreg) { + // Shift the desired row from the source vreg to the desired offset for + // the destination vreg. This is done with rotates and, for packed types + // with multiple rows per sublane, bitshifts. + // Note that the offset of the source row varies but the destination + // offset is always the same. + const int64_t dst_offset_in_sublane = dst_offset % rows_per_sublane; + // src_row_with_offset is the row of the padded implicit shape that we + // will place in the destination vreg. The first dst vreg along the + // non-implicit 2nd minor has the source row at offset src_offset, the + // second has the source row at offset src_offset+1, etc. 
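A standalone sketch (not part of the patch, hypothetical sizes) of just the index arithmetic in this loop: dst_idx stands in for *(idx.end() - 3), and the printed mapping shows which source vreg and row each destination vreg pulls from, and by how many rows it has to be rotated.

#include <cstdint>
#include <cstdio>

int main() {
  const int64_t tiling0 = 8;           // Rows per tile (native 32-bit tiling).
  const int64_t rows_per_sublane = 1;  // 32-bit elements: one row per sublane.
  const int64_t sublanes = 8;
  const int64_t src_offset = 3;        // 2nd minor offset of the source layout.
  const int64_t dst_offset = 5;        // Requested offset of the destination.
  const int64_t num_rows = 10;         // Extent of the dim that becomes implicit.

  for (int64_t dst_idx = 0; dst_idx < num_rows; ++dst_idx) {
    // Row of the padded implicit shape held by this destination vreg.
    const int64_t src_row_with_offset = dst_idx + src_offset;
    const int64_t src_vreg_idx = src_row_with_offset / tiling0;
    const int64_t src_offset_in_vreg = src_row_with_offset % tiling0;
    int64_t row_rotate_amt = dst_offset - src_offset_in_vreg;
    if (row_rotate_amt < 0) row_rotate_amt += rows_per_sublane * sublanes;
    std::printf("dst vreg %2lld <- src vreg %lld row %lld, rotate by %lld\n",
                static_cast<long long>(dst_idx),
                static_cast<long long>(src_vreg_idx),
                static_cast<long long>(src_offset_in_vreg),
                static_cast<long long>(row_rotate_amt));
  }
  return 0;
}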
+ const int64_t src_row_with_offset = *(idx.end() - 3) + src_offset; + src_idx.assign(idx.begin(), idx.end() - 3); + src_idx.push_back(src_row_with_offset / src.tiling()[0]); + src_idx.push_back(idx.back()); + Value vreg = vregs(src_idx); + const int64_t src_offset_in_vreg = src_row_with_offset % src.tiling()[0]; + const int64_t src_offset_in_sublane = + src_row_with_offset % rows_per_sublane; + int64_t row_rotate_amt = dst_offset - src_offset_in_vreg; + if (row_rotate_amt < 0) { + row_rotate_amt += rows_per_sublane * target_shape[0]; + } + *new_vreg = rotateVregRows( + builder, loc, vreg, row_rotate_amt, rows_per_sublane, + /*is_high=*/src_offset_in_sublane <= dst_offset_in_sublane, + ctx.target_shape); + }); + return std::make_pair(dst, new_vregs); + } + // Remove second minor implicit dim, for values that have (m, 128) tiling (for // m that is a power of 2). if (src.implicit_dim() == VectorLayout::ImplicitDim::kSecondMinor && @@ -6663,7 +7353,6 @@ FailureOr>> changeImplicitDim( // For example, extended offsets allow us to skip copies of low sublanes // in tiles with idx.back() == 0. const int tiles_per_vreg = src.tilesPerVreg(target_shape); - const int sublanes_per_tile = src.sublanesPerTile(target_shape); src_idx[dst_2nd_minor_idx] = src.tiling()[0] * idx[dst_2nd_minor_idx] + dst_sl_start - dst_sublane_offset; for (int dst_sl_idx = dst_sl_start; @@ -6674,9 +7363,9 @@ FailureOr>> changeImplicitDim( for (int tile_idx = 0; tile_idx < tiles_per_vreg; ++tile_idx) { int tile_off = tile_idx * sublanes_per_tile; *tile = - copy_one_sublane(builder, vregs(src_idx), - tile_off + src.offsets()[0].value_or(dst_sl_idx), - *tile, tile_off + dst_sl_idx, target_shape); + copyOneSublane(builder, vregs(src_idx), + tile_off + src.offsets()[0].value_or(dst_sl_idx), + *tile, tile_off + dst_sl_idx, target_shape); } } }); @@ -6702,6 +7391,31 @@ FailureOr>> changeImplicitDim( dst.offsets())); return std::make_pair(dst, std::move(dst_vregs)); } + if (src.implicit_dim() == VectorLayout::ImplicitDim::kMinor && + dst_implicit_dim == VectorLayout::ImplicitDim::kSecondMinor && + src.bitwidth() == 32 && src.hasNativeTiling(ctx.target_shape)) { + const int64_t dst_minor_offset = dst_offset_hints[1].value_or(0); + FAILUREOR_ASSIGN_OR_RETURN( + xla::Array dst_vregs, + transposeSingletonMinorDimension(ctx, builder, loc, vregs, + src.implicitShape(vty.getShape()), src, + dst_minor_offset)); + VectorLayout dst(src.bitwidth(), {std::nullopt, dst_minor_offset}, + src.tiling(), VectorLayout::ImplicitDim::kSecondMinor); + return std::make_pair(dst, std::move(dst_vregs)); + } + if (src.implicit_dim() == VectorLayout::ImplicitDim::kMinor && + dst_implicit_dim == VectorLayout::ImplicitDim::kNone && + src.bitwidth() == 32 && src.hasNativeTiling(ctx.target_shape)) { + FAILUREOR_ASSIGN_OR_RETURN( + std::tie(src, vregs), + changeImplicitDim(ctx, builder, loc, vty, src, std::move(vregs), + VectorLayout::ImplicitDim::kSecondMinor, + dst_offset_hints)); + return changeImplicitDim(ctx, builder, loc, vty, src, std::move(vregs), + VectorLayout::ImplicitDim::kNone, + dst_offset_hints); + } return emitError(loc, "Not implemented: Unsupported implicit dim change: from ") << src << " to " << dst_implicit_dim; @@ -6714,12 +7428,20 @@ FailureOr> relayout(RewriteContext &ctx, VectorLayout src, VectorLayout dst) { const auto target_shape = ctx.target_shape; + VectorType vty = v.getType(); const int8_t bitwidth = src.bitwidth(); - if (bitwidth != dst.bitwidth()) { + const bool is_mask = vty.getElementTypeBitWidth() == 1; + const bool is_mask_pack = 
+ is_mask && bitwidth == 32 && dst.bitwidth() == 16 && + src.tiling()[0] == src.packing() * target_shape[0] && + src.tiling()[1] == target_shape[1] && src.tiling() == dst.tiling() && + src.offsets() == dst.offsets() && + src.implicit_dim() == dst.implicit_dim(); + + if (bitwidth != dst.bitwidth() && !is_mask_pack) { return emitError(v.getLoc(), "Can't change bitwidth during a relayout"); } - VectorType vty = v.getType(); - const bool is_mask = vty.getElementTypeBitWidth() == 1; + { // Replication imposes a replication constraint on the *logical* value of // the vector: When moving along a replicated axis, all elements must be @@ -6753,6 +7475,39 @@ FailureOr> relayout(RewriteContext &ctx, FAILUREOR_ASSIGN_OR_RETURN( xla::Array src_tiles, disassemble(builder, src, v, target_shape, /*use_implicit_shape=*/true)); + + if (is_mask_pack) { + std::vector vmsks_shape(src_tiles.dimensions().begin(), + src_tiles.dimensions().end()); + *(vmsks_shape.end() - 1) = llvm::divideCeil(vmsks_shape.back(), 2); + xla::Array out_vmsks(vmsks_shape, nullptr); + SmallVector val_idx; + Value default_val = getFullVector( + builder, v.getLoc(), + cast>(*src_tiles.begin()).getType(), + IntegerAttr::get(builder.getI1Type(), 0)); + out_vmsks.Each([&](absl::Span idx, Value *v_slot_in_array) { + val_idx.assign(idx.begin(), idx.end()); + *(val_idx.end() - 1) *= 2; + Value low_part = + *(val_idx.end() - 1) < *(src_tiles.dimensions().end() - 1) + ? src_tiles(val_idx) + : default_val; + *(val_idx.end() - 1) += 1; + Value high_part = + *(val_idx.end() - 1) < *(src_tiles.dimensions().end() - 1) + ? src_tiles(val_idx) + : default_val; + const VectorType mask_ty = getNativeVregOrVmaskType( + builder.getI1Type(), bitwidth / 2, target_shape); + *v_slot_in_array = + builder.create(v.getLoc(), mask_ty, low_part, high_part); + }); + return assemble(builder, vty, dst, out_vmsks, target_shape, + /*use_implicit_shape=*/true) + .getResult(); + } + if (is_mask) { auto new_tile_ty = getNativeVregOrVmaskType( builder.getIntegerType(bitwidth), bitwidth, target_shape); @@ -6764,6 +7519,8 @@ FailureOr> relayout(RewriteContext &ctx, } auto assemble_with_mask_check = [&](xla::Array &tiles, bool use_implicit_shape = false) { + + if (is_mask) { auto zeros_tile = builder.create( tiles.begin()->getLoc(), @@ -6882,14 +7639,97 @@ FailureOr> relayout(RewriteContext &ctx, changeOffsets(ctx, builder, v.getLoc(), vty, src, std::move(src_tiles), dst.offsets())); - CHECK_EQ(src, dst); // At this point we've should be done. 
- return assemble_with_mask_check(src_tiles, - /*use_implicit_shape=*/true); + CHECK_EQ(src, dst); + return assemble_with_mask_check(src_tiles, /*use_implicit_shape=*/true); +} + +LogicalResult tpu_relayout_rule(RewriteContext &ctx, Operation &op, + const ArrayRef layouts_in, + const ArrayRef layouts_out) { + auto tpu_relayout_op = cast(op); + auto input_val = dyn_cast>(tpu_relayout_op.getInput()); + + auto in_layout_array_attr = + tpu_relayout_op->getAttrOfType("in_layout"); + auto src_vla = dyn_cast(in_layout_array_attr[0]); + VectorLayout src_layout = src_vla.getLayout().value(); + + auto out_layout_array_attr = + tpu_relayout_op->getAttrOfType("out_layout"); + auto dst_vla = dyn_cast(out_layout_array_attr[0]); + VectorLayout dst_layout = dst_vla.getLayout().value(); + + if (src_layout == dst_layout) { + return op.emitError( + "Source and destination layouts are the same - did you forget to run " + "relayout-insertion-pass?"); + } + + OpBuilder builder(&op); + FAILUREOR_ASSIGN_OR_RETURN( + TypedValue new_v, + relayout(ctx, builder, input_val, src_layout, dst_layout)); + + tpu_relayout_op.replaceAllUsesWith(new_v); + tpu_relayout_op.erase(); + return success(); +} + +const llvm::StringMap &rules() { + static const llvm::StringMap *rules = [] { + static auto rules = new llvm::StringMap{ + {arith::ConstantOp::getOperationName(), arith_constant_rule}, + {arith::ExtFOp::getOperationName(), arith_extf_rule}, + {arith::ExtSIOp::getOperationName(), arith_extsi_rule}, + {arith::ExtUIOp::getOperationName(), arith_extui_rule}, + {arith::TruncFOp::getOperationName(), arith_truncf_rule}, + {arith::TruncIOp::getOperationName(), arith_trunci_rule}, + {func::ReturnOp::getOperationName(), func_return_rule}, + {scf::ForOp::getOperationName(), scf_for_rule}, + {scf::WhileOp::getOperationName(), scf_while_rule}, + {scf::ConditionOp::getOperationName(), scf_condition_rule}, + {scf::IfOp::getOperationName(), scf_if_rule}, + {scf::YieldOp::getOperationName(), yield_rule}, + {tpu::YieldOp::getOperationName(), yield_rule}, + {tpu::RotateOp::getOperationName(), tpu_rotate_rule}, + {tpu::DynamicRotateOp::getOperationName(), tpu_dynamic_rotate_rule}, + {tpu::ConcatenateOp::getOperationName(), tpu_concatenate_rule}, + {tpu::IotaOp::getOperationName(), tpu_iota_rule}, + {tpu::GatherOp::getOperationName(), tpu_gather_rule}, + {tpu::DynamicGatherOp::getOperationName(), tpu_dynamic_gather_rule}, + {tpu::LoadOp::getOperationName(), tpu_load_rule}, + {tpu::StoreOp::getOperationName(), tpu_store_rule}, + {tpu::StridedLoadOp::getOperationName(), tpu_strided_load_rule}, + {tpu::StridedStoreOp::getOperationName(), tpu_strided_store_rule}, + {tpu::VectorStoreOp::getOperationName(), tpu_vector_store_rule}, + {tpu::MatmulOp::getOperationName(), tpu_matmul_rule}, + {tpu::RegionOp::getOperationName(), tpu_region_rule}, + {tpu::BitcastOp::getOperationName(), tpu_bitcast_rule}, + {tpu::TraceOp::getOperationName(), tpu_trace_rule}, + {tpu::AssumeLayoutOp::getOperationName(), tpu_assume_layout_rule}, + {tpu::PRNGRandomBitsOp::getOperationName(), tpu_prng_random_bits_rule}, + {tpu::RelayoutOp::getOperationName(), tpu_relayout_rule}, + {tpu::FPToSIOp::getOperationName(), tpu_fptosi_rule}, + {tpu::SIToFPOp::getOperationName(), tpu_sitofp_rule}, + {vector::BroadcastOp::getOperationName(), vector_broadcast_rule}, + {vector::ExtractOp::getOperationName(), vector_extract_rule}, + {vector::LoadOp::getOperationName(), vector_load_rule}, + {vector::MultiDimReductionOp::getOperationName(), + vector_multi_reduction_rule}, + 
{vector::ExtractStridedSliceOp::getOperationName(), + vector_extract_strided_slice_rule}, + {vector::ShapeCastOp::getOperationName(), vector_shape_cast_rule}, + {vector::StoreOp::getOperationName(), vector_store_rule}, + {tpu::TransposeOp::getOperationName(), vector_transpose_rule}}; + + for (const auto &[name, rule] : mlir::tpu::extensions::rules()) { + rules->insert({name, rule}); + } + return rules; + }(); + return *rules; } -// TODO(apaszke): Implement a debug mode that inserts additional assertions. -// For example, we should verify that ops that were supposed to generate -// replicated outputs satisfy that requirement. LogicalResult applyLayoutOp(RewriteContext &ctx, Operation &op) { // When an operation does not have any operands, the layout_in tuple is empty. // If one of the operands is not of vector type, the corresponding entry in @@ -6925,14 +7765,11 @@ LogicalResult applyLayoutOp(RewriteContext &ctx, Operation &op) { getOutLayouts(*def_op, ctx.target_shape)); const Layout lo = def_layouts[res_idx]; TPU_ASSERT_OP(lo.has_value()); - if (*lo == *li) { - continue; + if (*lo != *li) { + return op.emitError( + "Invariant violation: Input layout does not match output layout - " + "did you forget to run relayout-insertion?"); } - OpBuilder builder(&op); - FAILUREOR_ASSIGN_OR_RETURN( - Value new_v, relayout(ctx, builder, vector_operand, /*src=*/*lo, - /*dst=*/*li)); - op.setOperand(idx, new_v); } } @@ -6940,7 +7777,8 @@ LogicalResult applyLayoutOp(RewriteContext &ctx, Operation &op) { // support for offsets outside of the first tile. When support is more broad, // any op without support should check it within their own rule. if (!isa(op)) { + vector::ExtractStridedSliceOp, vector::ShapeCastOp, tpu::RelayoutOp>( + op)) { for (const Layout &layout : layouts_in) { if (layout && layout->offsets()[1].has_value() && layout->offsets()[1].value() >= layout->tiling()[1]) { diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.h b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.h index ed72a21028eb..bbf23a9f3844 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.h +++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.h @@ -1,3 +1,18 @@ +/* Copyright 2023 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + #ifndef THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_TRANSFORMS_APPLY_VECTOR_LAYOUT_H_ #define THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_TRANSFORMS_APPLY_VECTOR_LAYOUT_H_ diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout_extensions.h b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout_extensions.h index 33c9e7421004..72bd8ca370c8 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout_extensions.h +++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout_extensions.h @@ -1,11 +1,26 @@ +/* Copyright 2024 The JAX Authors. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + #ifndef THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_TRANFORMS_APPLY_VECTOR_LAYOUT_EXTENSIONS_H_ #define THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_TRANFORMS_APPLY_VECTOR_LAYOUT_EXTENSIONS_H_ #include -#include "llvm/include/llvm/ADT/StringMap.h" -#include "mlir/include/mlir/IR/Operation.h" -#include "mlir/include/mlir/Support/LLVM.h" +#include "llvm/ADT/StringMap.h" +#include "mlir/IR/Operation.h" +#include "mlir/Support/LLVM.h" #include "jaxlib/mosaic/dialect/tpu/layout.h" #include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" diff --git a/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc b/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc index 5efbdb9cb437..b74a5ae15137 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc @@ -1,44 +1,57 @@ +/* Copyright 2024 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + #include +#include #include #include +#include #include #include #include #include #include -#include "llvm/ADT/STLExtras.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -// It requires these headers, but does not include them. 
-// NOLINTNEXTLINE(misc-include-cleaner) -#include "mlir/Dialect/MemRef/IR/MemRef.h" -// NOLINTNEXTLINE(misc-include-cleaner) +#include "absl/log/check.h" #include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringMap.h" -#include "llvm/ADT/StringSet.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Math/IR/Math.h" -#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" // IWYU pragma: keep +#include "mlir/Dialect/SCF/IR/SCF.h" // IWYU pragma: keep +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Block.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributeInterfaces.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/Region.h" +#include "mlir/IR/Value.h" #include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" -#include "absl/log/check.h" -#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/include/mlir/Dialect/Math/IR/Math.h" -#include "mlir/include/mlir/Dialect/Vector/IR/VectorOps.h" -#include "mlir/include/mlir/Dialect/Vector/Transforms/VectorTransforms.h" -#include "mlir/include/mlir/IR/AffineExpr.h" -#include "mlir/include/mlir/IR/Attributes.h" -#include "mlir/include/mlir/IR/Block.h" -#include "mlir/include/mlir/IR/Builders.h" -#include "mlir/include/mlir/IR/BuiltinAttributes.h" -#include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h" -#include "mlir/include/mlir/IR/OpDefinition.h" -#include "mlir/include/mlir/IR/Operation.h" -#include "mlir/include/mlir/IR/PatternMatch.h" -#include "mlir/include/mlir/IR/Region.h" -#include "mlir/include/mlir/IR/Value.h" -#include "mlir/include/mlir/Support/LLVM.h" #include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" +#include "jaxlib/mosaic/dialect/tpu/util.h" #include "jaxlib/mosaic/dialect/tpu/vreg_util.h" namespace mlir::tpu { @@ -54,8 +67,17 @@ struct CanonicalizeContext { bool compatibility_mode; int hardware_generation; + + std::array target_shape; }; +Value create_transpose_op(const CanonicalizeContext &ctx, + ImplicitLocOpBuilder &builder, VectorType input_ty, + Value input, ArrayRef permutation); + +bool need_elementwise_canonicalization(const CanonicalizeContext &ctx, + Operation &op); + LogicalResult tpu_matmul_rule(const CanonicalizeContext &ctx, tpu::MatmulOp op) { ImplicitLocOpBuilder builder(op.getLoc(), op.getOperation()); @@ -190,7 +212,38 @@ LogicalResult tpu_matmul_rule(const CanonicalizeContext &ctx, } } - auto dot_dim_matmul = [&](auto lhs, auto rhs, auto acc) { + // Attempt to canonicalize matmul(x, transpose(y)) to a matmul with the + // dimension numbers changed which will later be lowered into a more efficient + // operation that fuses the transpose into the matmul. 
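The fusibility condition used by the matmul(x, transpose(y)) canonicalization below can be stated as a small predicate. This is a standalone sketch, not part of the patch, and the function name is made up: the transpose folds into the matmul's dimension numbers only if it swaps the single contracting and single non-contracting rhs dims and leaves every rhs batch dim in place.

#include <cassert>
#include <cstdint>
#include <vector>

bool transpose_fuses_into_matmul(const std::vector<int64_t>& permutation,
                                 int64_t rhs_contracting_dim,
                                 int64_t rhs_non_contracting_dim,
                                 const std::vector<int64_t>& rhs_batch_dims) {
  // The permutation must exchange the contracting and non-contracting dims.
  if (permutation[rhs_contracting_dim] != rhs_non_contracting_dim ||
      permutation[rhs_non_contracting_dim] != rhs_contracting_dim) {
    return false;
  }
  // Every batch dim must be left untouched.
  for (int64_t batch_dim : rhs_batch_dims) {
    if (permutation[batch_dim] != batch_dim) return false;
  }
  return true;
}

int main() {
  // rhs shape [B, K, N]: batch dim 0, contracting dim 1, non-contracting dim 2.
  assert(transpose_fuses_into_matmul({0, 2, 1}, 1, 2, {0}));   // Swaps K and N.
  assert(!transpose_fuses_into_matmul({2, 1, 0}, 1, 2, {0}));  // Moves the batch dim.
}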
+ auto transpose_op = + dyn_cast_if_present(rhs.getDefiningOp()); + auto dimension_numbers = op.getDimensionNumbers(); + if (transpose_op && transpose_op->hasOneUse() && + dimension_numbers->getRhsContractingDims().size() == 1 && + dimension_numbers->getRhsNonContractingDims().size() == 1) { + auto rhs_non_contracting_dim = + dimension_numbers->getRhsNonContractingDims()[0]; + auto rhs_contracting_dim = dimension_numbers->getRhsContractingDims()[0]; + auto permutation = transpose_op.getPermutation(); + if (permutation[rhs_contracting_dim] == rhs_non_contracting_dim && + permutation[rhs_non_contracting_dim] == rhs_contracting_dim && + std::all_of(dimension_numbers->getRhsBatchDims().begin(), + dimension_numbers->getRhsBatchDims().end(), + [&](long batch_dim) { + return permutation[batch_dim] == batch_dim; + })) { + if (auto transpose_op_vector_operand = + dyn_cast>(transpose_op.getOperand())) { + // The transpose is DCE'ed away at a later point. + rhs = transpose_op_vector_operand; + transpose_rhs = !transpose_rhs; + } else { + return op->emitOpError("Unexpected operand type for transpose op."); + } + } + } + + auto dot_dim_matmul = [&](Value lhs, auto rhs, auto acc) { auto precision_attr = op.getPrecisionAttr(); // If we are transposing the lhs, we need to transpose the lhs before @@ -209,13 +262,12 @@ LogicalResult tpu_matmul_rule(const CanonicalizeContext &ctx, std::vector shape(lhs_ty.getShape()); std::swap(shape[rank - 2], shape[rank - 1]); - auto lhs_ty_transposed = VectorType::get(shape, lhs_ty.getElementType()); + VectorType lhs_ty_transposed = + VectorType::get(shape, lhs_ty.getElementType()); const SmallVector perm_vec = SmallVector(perm.begin(), perm.end()); - lhs = builder.create( - lhs_ty_transposed, lhs, - DenseI64ArrayAttr::get(builder.getContext(), perm_vec)); + lhs = create_transpose_op(ctx, builder, lhs_ty_transposed, lhs, perm_vec); } auto ddn = defaultDimensionNumbers(builder, /*transpose_lhs=*/false, transpose_rhs); @@ -246,7 +298,7 @@ LogicalResult tpu_matmul_rule(const CanonicalizeContext &ctx, auto matmul_res = dot_dim_matmul(sliced_lhs.getResult(), sliced_rhs.getResult(), sliced_acc.getResult()); - auto res_ty = matmul_res.getType().cast(); + auto res_ty = cast(matmul_res.getType()); auto res_shape = res_ty.getShape(); // reshape to 1x[prior_shape] auto reshape_shape = llvm::to_vector(res_shape); @@ -310,12 +362,7 @@ LogicalResult canonicalize_elementwise(const CanonicalizeContext &ctx, auto element_type = ty.getElementType(); // There's an annoying hodgepodge of elementwise ops that need to be // rewritten to f32 on later hardware. - // TODO(mvoz): Look into (1) what it would take to support these ops - // natively on later hardware, and (2) how to better organize this list. 
- bool needs_cast = ctx.hardware_generation <= 5 || isa(op) || - isa(op) || isa(op) || - isa(op); - if (needs_cast && element_type.isBF16()) { + if (element_type.isBF16()) { if (ctx.compatibility_mode) { auto target_f32 = builder.create(op.getLoc(), target_f32_ty, operand) @@ -340,21 +387,23 @@ LogicalResult canonicalize_elementwise(const CanonicalizeContext &ctx, } } if (should_rewrite_op) { - auto result_ty = dyn_cast(op.getResult(0).getType()); - if (!result_ty) { + if (!res_ty) { op.emitOpError("Not implemented: Unexpected result type"); return failure(); } - auto result_element_type = result_ty.getElementType(); - if (!result_element_type.isF32() && !result_element_type.isBF16()) { - op.emitOpError("Not implemented: Unexpected result element type"); - return failure(); - } - // Do the new op in f32, then truncate to the original element type. + // Do the new op in f32, then truncate to the original element type if + // needed. For example, result of arith::CmpF is i1 and doesn't need to be + // truncated. + bool should_truncate = !isa(op); + auto new_res_ty = + VectorType::get(shape, should_truncate ? builder.getF32Type() + : res_ty.getElementType()); auto new_op = builder.create(op.getLoc(), op.getName().getIdentifier(), - new_operands, target_f32_ty); - new_op = builder.create(op.getLoc(), res_ty, - new_op->getResult(0)); + new_operands, new_res_ty, op.getAttrs()); + if (should_truncate) { + new_op = builder.create(op.getLoc(), res_ty, + new_op->getResult(0)); + } op.replaceAllUsesWith(new_op); op.erase(); } @@ -520,6 +569,44 @@ LogicalResult canonicalize_extract(const CanonicalizeContext &ctx, return success(); } +LogicalResult canonicalize_broadcast(const CanonicalizeContext &ctx, + Operation &raw_op) { + auto op = dyn_cast(raw_op); + auto src_ty = op.getSource().getType(); + auto src_vty = dyn_cast(src_ty); + if ((src_vty && src_vty.getElementType().isSignlessInteger(1)) || + op.getSource().getType().isSignlessInteger(1)) { + // Canonicalize i1 broadcast. + // i1 represents vmsk in Mosaic and TPU doesn't support vmsk replication + // directly. + // Instead, convert i1 to i32 vector, broadcast i32, and then convert it + // back to i1. 
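A scalar/array analogue of this rewrite (standalone sketch, not part of the patch), assuming the i1-to-i32 widening maps true to 1, which the comparison against a splat of ones in the code below depends on:

#include <array>
#include <cassert>
#include <cstdint>

int main() {
  const bool mask_bit = true;

  const int32_t widened = mask_bit ? 1 : 0;  // i1 -> i32

  std::array<int32_t, 8> bcast;  // broadcast of the i32 value
  bcast.fill(widened);

  std::array<bool, 8> mask;  // cmpi(eq, bcast, splat(1)) recovers the i1 mask
  for (int i = 0; i < 8; ++i) mask[i] = (bcast[i] == 1);

  for (bool b : mask) assert(b == mask_bit);
}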
+ ImplicitLocOpBuilder builder(op->getLoc(), op.getOperation()); + Value i32_src; + if (src_vty) { + i32_src = builder.create( + VectorType::get(src_vty.getShape(), builder.getI32Type()), + op.getSource()); + } else { + i32_src = + builder.create(builder.getI32Type(), op.getSource()); + } + auto i32_res_vty = + VectorType::get(op.getType().getShape(), builder.getI32Type()); + auto bcast = builder.create(i32_res_vty, i32_src); + auto ones = builder.create( + i32_res_vty, + SplatElementsAttr::get(i32_res_vty, + builder.getOneAttr(builder.getI32Type()))); + auto cmp = + builder.create(arith::CmpIPredicate::eq, bcast, ones); + op.replaceAllUsesWith(cmp.getResult()); + op.erase(); + return success(); + } + return success(); +} + LogicalResult canonicalize_select(const CanonicalizeContext &ctx, Operation &raw_op) { auto op = dyn_cast(raw_op); @@ -536,6 +623,9 @@ LogicalResult canonicalize_select(const CanonicalizeContext &ctx, op.getLoc(), cond, op.getTrueValue(), op.getFalseValue()); op.replaceAllUsesWith(new_op.getResult()); op.erase(); + if (need_elementwise_canonicalization(ctx, *new_op.getOperation())) { + return canonicalize_elementwise(ctx, *new_op.getOperation()); + } return success(); } @@ -550,14 +640,10 @@ LogicalResult canonicalize_fptosi(const CanonicalizeContext &ctx, return op.emitOpError("Vector/scalar mismatch between input and output"); } bool is_vector = static_cast(src_vty); - unsigned src_bitwidth, dst_bitwidth; - if (is_vector) { - src_bitwidth = src_vty.getElementTypeBitWidth(); - dst_bitwidth = dst_vty.getElementTypeBitWidth(); - } else { - src_bitwidth = op.getIn().getType().getIntOrFloatBitWidth(); - dst_bitwidth = op.getType().getIntOrFloatBitWidth(); - } + FAILUREOR_ASSIGN_OR_RETURN(const unsigned src_bitwidth, + getElementTypeBitwidth(op.getIn().getType())); + FAILUREOR_ASSIGN_OR_RETURN(const unsigned dst_bitwidth, + getElementTypeBitwidth(op.getType())); if (dst_bitwidth > 32) { return op.emitOpError("Target bitwidth too large"); } @@ -570,17 +656,16 @@ LogicalResult canonicalize_fptosi(const CanonicalizeContext &ctx, op.getType(), op.getIn(), tpu::RoundingMode::kTowardsZero); op.replaceAllUsesWith(new_op.getResult()); op.erase(); - // We briefly trigger canonicalization here to potentially fuse the rounding - // ops into the newly created tpu.fptosi. - { - PatternRewriter rewriter(new_op.getContext()); - rewriter.setInsertionPoint(new_op); - // We don't care if the canonicalization pattern matched or not. - (void)tpu::FPToSIOp::canonicalize(new_op, rewriter); - new_op = nullptr; // Canonicalization may have erased the op! - } return success(); } + + if ((src_bitwidth < 32 || dst_bitwidth < 32) && !ctx.compatibility_mode) { + return op.emitOpError( + "On this target float-to-integer conversions can only happen on " + "32-bit values. Enable compatibility mode or upcast to float32, cast " + "to int32 and truncate to desired bitwidth."); + } + Value x = op.getIn(); // Upcast the input to f32. if (src_bitwidth < 32) { @@ -592,11 +677,6 @@ LogicalResult canonicalize_fptosi(const CanonicalizeContext &ctx, } } if (dst_bitwidth < 32) { - if (!ctx.compatibility_mode) { - return op.emitOpError( - "On this target only float-to-integer conversions can only happen on " - "32-bit values. 
Enable compatibility mode or upcast to float32."); - } // Need to clip values to match XLA auto clip = [&](Value x, Value low, Value high) { x = builder.create(x, low); @@ -624,12 +704,6 @@ LogicalResult canonicalize_fptosi(const CanonicalizeContext &ctx, x = builder.create(builder.getI32Type(), x); } if (dst_bitwidth < 32) { - if (!ctx.compatibility_mode) { - return op.emitOpError( - "On this target only float-to-integer conversions can only happen on " - "32-bit values. Enable compatibility mode or cast to int32 and " - "truncate later."); - } x = builder.create(op.getType(), x); } op.replaceAllUsesWith(x); @@ -637,6 +711,66 @@ LogicalResult canonicalize_fptosi(const CanonicalizeContext &ctx, return success(); } +LogicalResult canonicalize_sitofp(const CanonicalizeContext &ctx, + Operation &raw_op) { + auto op = cast(raw_op); + ImplicitLocOpBuilder builder(op->getLoc(), op.getOperation()); + auto src_vty = dyn_cast(op.getIn().getType()); + auto dst_vty = dyn_cast(op.getType()); + if (static_cast(src_vty) != static_cast(dst_vty)) { + return op.emitOpError("Vector/scalar mismatch between input and output"); + } + bool is_vector = static_cast(src_vty); + FAILUREOR_ASSIGN_OR_RETURN(const unsigned src_bitwidth, + getElementTypeBitwidth(op.getIn().getType())); + FAILUREOR_ASSIGN_OR_RETURN(const unsigned dst_bitwidth, + getElementTypeBitwidth(op.getType())); + + // We have low-level optimized code for s8->bf16 and s4->bf16 casts on v6. + if (ctx.hardware_generation >= 6 && is_vector && + (src_vty.getElementType().isSignlessInteger(8) || + src_vty.getElementType().isSignlessInteger(4)) && + dst_vty.getElementType().isBF16()) { + auto new_op = builder.create( + op.getType(), op.getIn(), tpu::RoundingMode::kToNearestEven); + op.replaceAllUsesWith(new_op.getResult()); + op.erase(); + return success(); + } + + if ((src_bitwidth < 32 || dst_bitwidth < 32) && !ctx.compatibility_mode) { + return op.emitOpError( + "On this target integer-to-float conversions can only happen on " + "32-bit values. Enable compatibility mode or upcast to int32, cast to " + "float32 and truncate to desired bitwidth."); + } + + // Canonicalize (intX -> floatY) to (intX -> int32 -> float32 -> floatY). 
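The compatibility-mode path above clamps values to the destination range before converting, to match XLA's saturating behaviour for sub-32-bit targets. A minimal scalar C++ sketch of that clamp-then-truncate sequence (the pass itself operates on vectors and converts at 32 bits with round-towards-zero):
```
#include <cstdint>
#include <cstdio>

// f32 -> i8 as in the compatibility path: clamp to the destination range in
// float, convert to i32 (towards zero), then truncate to i8.
int8_t fptosi_sat8(float x) {
  const float lo = -128.0f, hi = 127.0f;     // int8 bounds as floats
  if (x < lo) x = lo;                        // clamp low
  if (x > hi) x = hi;                        // clamp high
  int32_t as_i32 = static_cast<int32_t>(x);  // 32-bit float->int conversion
  return static_cast<int8_t>(as_i32);        // truncate to 8 bits
}

int main() {
  for (float v : {-1000.0f, -128.9f, -0.7f, 42.5f, 300.0f})
    std::printf("%8.1f -> %d\n", v, fptosi_sat8(v));
}
```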
+ Value x = op.getIn(); + if (src_bitwidth < 32) { + if (is_vector) { + x = builder.create( + VectorType::get(src_vty.getShape(), builder.getI32Type()), x); + } else { + x = builder.create(builder.getI32Type(), x); + } + } + if (is_vector) { + x = builder.create( + VectorType::get(src_vty.getShape(), builder.getF32Type()), x, + tpu::RoundingMode::kToNearestEven); + } else { + x = builder.create(builder.getF32Type(), x, + tpu::RoundingMode::kToNearestEven); + } + if (dst_bitwidth < 32) { + x = builder.create(op.getType(), x); + } + op.replaceAllUsesWith(x); + op.erase(); + return success(); +} + LogicalResult canonicalize_repeat(const CanonicalizeContext &ctx, Operation &raw_op) { auto op = dyn_cast(raw_op); @@ -661,6 +795,271 @@ LogicalResult canonicalize_repeat(const CanonicalizeContext &ctx, return success(); } +LogicalResult canonicalize_vector_transpose(const CanonicalizeContext &ctx, + Operation &raw_op) { + auto op = cast(raw_op); + ImplicitLocOpBuilder builder(op->getLoc(), op.getOperation()); + auto new_op = builder.create(op.getType(), op.getVector(), + op.getPermutation()); + op.replaceAllUsesWith(new_op.getResult()); + op.erase(); + return success(); +} + +LogicalResult canonicalize_reshape(const CanonicalizeContext &ctx, + Operation &raw_op) { + auto op = cast(raw_op); + // We can canonicalize some reshape(load(x)) -> strided load + ALU ops. + auto src = op.getSource(); + auto src_ty = src.getType(); + auto tgt_ty = op.getType(); + if (auto load_op = src.getDefiningOp()) { + // Pattern match (..., M, N, 128) -> (..., M, N * 128). + // This reshape can be folded into the load for any dtype and tiling + // as long as the minormost dim is 128 and N is aligned to packing. The + // pseudo code is: + // ``` + // src_ref: (M, N, 128) with src_ty + // + // def load_to_reshape(src_ref): + // b_ref = src_ref.bitcast(i32) # i32[M, N / packing, 128] + // r_ref = b_ref.reshape(M * N / packing, 128) + // chunks = [] + // for i in range(N / packing): + // v = r_ref[i::N / packing, :] # i32[M, 128] + // for j in range(packing): + // chunk = v >> (j * bitwidth) + // chunks.append(chunk) + // res = concat(chunks, axis=-1) # i32[M, N * 128] + // # int_src_ty refers to int type with the same bitwidth as src_ty. + // res = res.astype(int_src_ty) # Trigger i32 -> int_src_ty packing. + // return bitcast(res, src_ty) # src_ty[M, N * 128] + // ``` + // TODO(jevinjiang): we can extend this to support folding more dims to last + // dim not just last 2 dims. + auto bitwidth = src_ty.getElementTypeBitWidth(); + auto packing = 32 / bitwidth; + if (packing <= 0) { + return op.emitOpError("Unsupported bitwidth = ") << bitwidth; + } + // Memref bitcast is not supported if HW generation is below 4. We don't + // return failure because we will rely on vector reshape. + if ((ctx.hardware_generation < 4 && packing > 1) || + (ctx.hardware_generation == 4 && packing > 2)) { + return success(); + } + auto ref = load_op.getBase(); + auto indices = load_op.getIndices(); + auto ref_shape = ref.getType().getShape(); + auto src_shape = src_ty.getShape(); + auto tgt_shape = tgt_ty.getShape(); + int ref_rank = ref_shape.size(); + int src_rank = src_shape.size(); + int tgt_rank = tgt_shape.size(); + if (ref_rank != src_rank) { + return op.emitOpError("Loaded vector rank and memref rank mismatch"); + } + // Check the memref's eligibility. 
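The sub-32-bit integer-to-float path above goes through i32 and f32 before narrowing. A standalone C++ sketch of the i8 -> i32 -> f32 -> bf16 chain; the bf16 rounding helper here is only an illustration of round-to-nearest-even, not code from the patch:
```
#include <cstdint>
#include <cstdio>
#include <cstring>

// Round an f32 to the nearest bf16 (round-to-nearest-even) and widen it back
// to f32 so it can be printed.
float to_bf16_rne(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  uint32_t rounding_bias = 0x7fff + ((bits >> 16) & 1);
  uint16_t bf16 = static_cast<uint16_t>((bits + rounding_bias) >> 16);
  uint32_t widened = static_cast<uint32_t>(bf16) << 16;
  std::memcpy(&f, &widened, sizeof(f));
  return f;
}

int main() {
  // (i8 -> bf16) canonicalized as (i8 -> i32 -> f32 -> bf16).
  for (int8_t v : {int8_t{-128}, int8_t{-3}, int8_t{0}, int8_t{77}, int8_t{127}}) {
    int32_t wide = v;                    // sign-extend to i32
    float f = static_cast<float>(wide);  // int->float at 32 bits
    float bf = to_bf16_rne(f);           // narrow to bf16
    std::printf("%4d -> %g\n", v, bf);
  }
}
```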
+ if (!isContiguousMemref(ref) || ref_rank <= 2 || + // TODO(jevinjiang): add support for partial load on last 2 dims where + // last 2 indices are not necessarily 0 or load shape is not full. + getIntConst(indices[ref_rank - 1]) != 0 || + getIntConst(indices[ref_rank - 2]) != 0 || + ref_shape[ref_rank - 1] != src_shape[src_rank - 1] || + ref_shape[ref_rank - 2] != src_shape[src_rank - 2]) { + return success(); + } + // Check the reshape's eligibility. + if (src_rank != tgt_rank + 1 || src_shape[src_rank - 2] % packing != 0 || + src_shape[src_rank - 1] != ctx.target_shape[1] || + src_shape[src_rank - 2] * src_shape[src_rank - 1] != + tgt_shape[tgt_rank - 1]) { + return success(); + } + // At this point, the pattern is matched. + ImplicitLocOpBuilder builder(op->getLoc(), op.getOperation()); + auto loc = op.getLoc(); + // First, we bitcast and reshape src ref from (..., M, N, 128) to + // i32(..., M * N / packing, 128). + SmallVector bitcast_shape(ref_shape); + // TODO(jevinjiang): once we have memref pad op, we can use ceiling + // division to ref_shape[ref_rank - 2] and packing to get sublane_cnt. + CHECK_EQ(ref_shape[ref_rank - 2] % packing, 0); + auto i32_2nd_minor_size = ref_shape[ref_rank - 2] / packing; + bitcast_shape[ref_rank - 2] = i32_2nd_minor_size; + auto i32_ref = builder.create( + MemRefType::get(bitcast_shape, builder.getI32Type()), ref); + + SmallVector reshape_shape(ref_shape.begin(), + ref_shape.begin() + tgt_rank); + reshape_shape[tgt_rank - 1] = ctx.target_shape[1]; + reshape_shape[tgt_rank - 2] = ref_shape[ref_rank - 3] * i32_2nd_minor_size; + auto reshape_ref = builder.create( + MemRefType::get(reshape_shape, builder.getI32Type()), i32_ref); + + // We also need to transform the indices while transforming the memref. + SmallVector new_indices(indices.begin(), indices.begin() + tgt_rank); + new_indices[tgt_rank - 1] = IdxConst(0, builder, loc); + new_indices[tgt_rank - 2] = builder.create( + builder.getIndexType(), indices[ref_rank - 3], + IdxConst(i32_2nd_minor_size, builder, loc)); + // Then, we strided load the bitcasted ref by stride (N / packing). + int stride = i32_2nd_minor_size; + // Expect to hold src_shape[src_rank - 2] number of chunks which have the + // shape (..., src_shape[src_rank - 3], 128) and wait to be concatenated + // along the last dim. + SmallVector chunks(src_shape[src_rank - 2]); + SmallVector chunk_shape(tgt_shape); + chunk_shape[tgt_rank - 1] = ctx.target_shape[1]; + SmallVector strides(tgt_rank, 1); + strides[tgt_rank - 2] = stride; + auto tgt_2nd_minor_idx = new_indices[tgt_rank - 2]; + for (int i = 0; i < stride; ++i) { + new_indices[tgt_rank - 2] = builder.create( + builder.getIndexType(), tgt_2nd_minor_idx, IdxConst(i, builder, loc)); + auto chunk = builder.create( + VectorType::get(chunk_shape, builder.getI32Type()), reshape_ref, + new_indices, strides); + for (int j = 0; j < packing; ++j) { + int idx = i * packing + j; + chunks[idx] = builder.create( + chunk.getType(), chunk, + I32Const(j * bitwidth, chunk_shape, builder, loc)); + } + } + // Concatenate the chunks along the last dim to get i32(..., M, N * 128). + CHECK_GT(chunks.size(), 0); + Value i32_tgt = chunks[0]; + if (chunks.size() > 1) { + i32_tgt = builder.create( + VectorType::get(tgt_shape, builder.getI32Type()), chunks, + /*dimension=*/tgt_rank - 1); + } + Value tgt = i32_tgt; + // Convert to target dtype. 
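A host-side C++ simulation of the access pattern in the pseudocode above, using 16-bit elements (packing = 2) and 4 lanes in place of the real 128; the final pack-back to the source dtype happens in the code that follows below. This only models the indexing and bit manipulation, not the MLIR rewrite:
```
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  // Toy sizes: L=4 stands in for the 128-lane dimension.
  constexpr int M = 2, N = 2, L = 4, kBits = 16, kPacking = 32 / kBits;

  // Source values, logically i16[M][N][L], stored contiguously.
  std::vector<uint16_t> src(M * N * L);
  for (int i = 0; i < (int)src.size(); ++i) src[i] = 100 + i;

  // "Bitcast" the buffer to i32[M][N/kPacking][L]: the word at (m, n/2, l)
  // holds src[m][n][l] in bit range [(n%2)*16, (n%2)*16 + 16).
  std::vector<uint32_t> as_i32(M * (N / kPacking) * L);
  for (int m = 0; m < M; ++m)
    for (int n = 0; n < N; ++n)
      for (int l = 0; l < L; ++l)
        as_i32[(m * (N / kPacking) + n / kPacking) * L + l] |=
            uint32_t(src[(m * N + n) * L + l]) << ((n % kPacking) * kBits);

  // Strided reads + shifts reproduce reshape (M, N, L) -> (M, N*L).
  const int stride = N / kPacking;  // i32 rows per m in the reshaped ref
  std::vector<uint16_t> dst(M * N * L);
  for (int i = 0; i < stride; ++i)
    for (int j = 0; j < kPacking; ++j) {
      int chunk = i * kPacking + j;  // which block of L lanes in the output
      for (int m = 0; m < M; ++m)
        for (int l = 0; l < L; ++l) {
          uint32_t word = as_i32[(m * stride + i) * L + l];
          dst[m * (N * L) + chunk * L + l] = uint16_t(word >> (j * kBits));
        }
    }

  // Check against a plain reshape of src.
  bool ok = true;
  for (int i = 0; i < (int)src.size(); ++i) ok &= (dst[i] == src[i]);
  std::printf("match: %s\n", ok ? "yes" : "no");
}
```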
+ if (packing > 1) { + tgt = builder.create( + VectorType::get(tgt_shape, builder.getIntegerType(bitwidth)), + i32_tgt); + } + tgt = builder.create(tgt_ty, tgt); + op.replaceAllUsesWith(tgt); + op.erase(); + } + return success(); +} + +namespace { +// TODO(mvoz): We can refactor a lot of other canonicalization rules to use +// these functions. +// TODO(mvoz): I think we can eventually do direct conversion to bf16 +// without going through f32? +Value upcastInt8ToBf16(ImplicitLocOpBuilder &builder, Value input) { + auto vty = cast(input.getType()); + auto shape = vty.getShape(); + auto int_ty = cast(vty.getElementType()); + + auto i32_vty = VectorType::get(shape, builder.getI32Type()); + auto val_i32 = int_ty.isUnsigned() + ? builder.create(i32_vty, input) + : builder.create(i32_vty, input); + + auto f32_vty = VectorType::get(shape, builder.getF32Type()); + auto val_f32 = builder.create( + f32_vty, val_i32->getResult(0), tpu::RoundingMode::kToNearestEven); + + auto bf16_vty = VectorType::get(shape, builder.getBF16Type()); + return builder.create(bf16_vty, val_f32); +} + +Value downcastBf16ToInt8(ImplicitLocOpBuilder &builder, Value input_bf16, + Type target_vty) { + auto shape = cast(input_bf16.getType()).getShape(); + + auto f32_vty = VectorType::get(shape, builder.getF32Type()); + auto val_f32 = builder.create(f32_vty, input_bf16); + + auto i32_vty = VectorType::get(shape, builder.getI32Type()); + auto val_i32 = builder.create(i32_vty, val_f32); + + return builder.create(target_vty, val_i32); +} + +Value upcastFp8ToBf16(ImplicitLocOpBuilder &builder, Value input) { + auto shape = cast(input.getType()).getShape(); + auto f32_vty = VectorType::get(shape, builder.getF32Type()); + auto val_f32 = builder.create(f32_vty, input); + auto bf16_vty = VectorType::get(shape, builder.getBF16Type()); + return builder.create(bf16_vty, val_f32); +} + +Value downcastBf16ToFp8(ImplicitLocOpBuilder &builder, Value input_bf16, + Type target_vty) { + auto shape = cast(input_bf16.getType()).getShape(); + auto f32_vty = VectorType::get(shape, builder.getF32Type()); + auto val_f32 = builder.create(f32_vty, input_bf16); + return builder.create(target_vty, val_f32); +} +} // namespace + +// Note(mvoz): Returns optional to signal no replacement, simplifying downstream +// .replace() and .erase() calls. +std::optional canonicalize_transpose_impl(const CanonicalizeContext &ctx, + ImplicitLocOpBuilder &builder, + tpu::TransposeOp op) { + auto input_ty = dyn_cast(op.getOperand().getType()); + auto element_type = input_ty.getElementType(); + // TODO(mvoz): Even gen 7 support is spotty on all test targets. 
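The helpers above widen 8-bit element types to bf16 so the transpose can run in a supported type and then narrow the result back; the rewrite below wires them around the transpose. A small standalone C++ sketch of that round trip, using f32 as the wide type; the trip is lossless because every 8-bit integer is exactly representable in bf16:
```
#include <array>
#include <cstdint>
#include <cstdio>

int main() {
  constexpr int R = 2, C = 3;
  std::array<std::array<int8_t, C>, R> a = {{{1, -2, 3}, {-4, 5, -6}}};

  // Upcast i8 -> wide float, transpose in the wide type, downcast back to i8.
  std::array<std::array<float, R>, C> t_wide{};
  for (int r = 0; r < R; ++r)
    for (int c = 0; c < C; ++c) t_wide[c][r] = static_cast<float>(a[r][c]);

  std::array<std::array<int8_t, R>, C> t{};
  for (int c = 0; c < C; ++c)
    for (int r = 0; r < R; ++r) t[c][r] = static_cast<int8_t>(t_wide[c][r]);

  for (int c = 0; c < C; ++c) {
    for (int r = 0; r < R; ++r) std::printf("%4d", t[c][r]);
    std::printf("\n");
  }
}
```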
+ if (element_type.getIntOrFloatBitWidth() == 8 && ctx.compatibility_mode && + ctx.hardware_generation > 3) { + Value val_bf16; + if (isa(element_type)) { + val_bf16 = upcastInt8ToBf16(builder, op.getOperand()); + } else { + val_bf16 = upcastFp8ToBf16(builder, op.getOperand()); + } + + auto original_output_ty = cast(op.getType()); + auto post_transpose_bf16_vty = + VectorType::get(original_output_ty.getShape(), builder.getBF16Type()); + + auto new_t = builder.create( + post_transpose_bf16_vty, val_bf16, op.getPermutation()); + + Value final_val; + if (isa(element_type)) { + final_val = downcastBf16ToInt8(builder, new_t.getResult(), op.getType()); + } else { + final_val = downcastBf16ToFp8(builder, new_t.getResult(), op.getType()); + } + return final_val; + } + return std::nullopt; +} + +Value create_transpose_op(const CanonicalizeContext &ctx, + ImplicitLocOpBuilder &builder, VectorType input_ty, + Value input, ArrayRef permutation) { + auto t = builder.create(input_ty, input, permutation); + auto new_op_opt = canonicalize_transpose_impl(ctx, builder, t); + if (new_op_opt.has_value()) { + return new_op_opt.value(); + } + return t; +} + +LogicalResult canonicalize_transpose(const CanonicalizeContext &ctx, + Operation &raw_op) { + auto op = cast(raw_op); + auto builder = ImplicitLocOpBuilder(op->getLoc(), op.getOperation()); + auto new_op_opt = canonicalize_transpose_impl(ctx, builder, op); + if (new_op_opt.has_value()) { + op.replaceAllUsesWith(new_op_opt.value()); + op.erase(); + } + return success(); +} + using canonicalize_rule_type = std::function; @@ -671,34 +1070,62 @@ const llvm::StringMap &rules() { {vector::ExtractOp::getOperationName(), canonicalize_extract}, {vector::MultiDimReductionOp::getOperationName(), canonicalize_multi_dim_reduction}, + {vector::TransposeOp::getOperationName(), canonicalize_vector_transpose}, + {vector::ShapeCastOp::getOperationName(), canonicalize_reshape}, + {vector::BroadcastOp::getOperationName(), canonicalize_broadcast}, {arith::SelectOp::getOperationName(), canonicalize_select}, {arith::FPToSIOp::getOperationName(), canonicalize_fptosi}, + {arith::SIToFPOp::getOperationName(), canonicalize_sitofp}, + {tpu::TransposeOp::getOperationName(), canonicalize_transpose}, {tpu::RepeatOp::getOperationName(), canonicalize_repeat}}; return *rules; } -bool need_elementwise_canonicalization(CanonicalizeContext ctx, Operation &op) { - if (isa(op)) { - auto vec_ty = dyn_cast(op.getOperand(0).getType()); - if (vec_ty && vec_ty.getElementType().isBF16() && - ctx.hardware_generation >= 4) { - return false; - } - return true; +const llvm::StringMap &bf16_ops_min_supported_versions() { + static const auto m = new llvm::StringMap{ + {arith::DivFOp::getOperationName(), 4}, + {arith::SelectOp::getOperationName(), 5}, + {arith::CmpFOp::getOperationName(), 5}, + {arith::MulFOp::getOperationName(), 6}, + {arith::AddFOp::getOperationName(), 6}, + {arith::SubFOp::getOperationName(), 6}, + {arith::MaximumFOp::getOperationName(), 6}, + {arith::MinimumFOp::getOperationName(), 6}, + {math::PowFOp::getOperationName(), 6}, + {math::TanhOp::getOperationName(), 6}, + {math::ExpOp::getOperationName(), 6}, + {math::Exp2Op::getOperationName(), 6}, + {math::LogOp::getOperationName(), 6}, + }; + return *m; +} + +bool need_elementwise_canonicalization(const CanonicalizeContext &ctx, + Operation &op) { + // Only rewrite when the hardware generation is below the minimum supported + // version. 
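A minimal standalone C++ sketch of the predicate the comment above describes, using an illustrative subset of the version table and std::map in place of llvm::StringMap:
```
#include <cstdio>
#include <map>
#include <string>

// Minimum hardware generation that supports each op natively on bf16
// (illustrative subset of the table above).
const std::map<std::string, int> kBf16MinVersion = {
    {"arith.divf", 4}, {"arith.select", 5}, {"arith.cmpf", 5},
    {"arith.mulf", 6}, {"arith.addf", 6},   {"math.tanh", 6},
};

bool needsRewrite(const std::string &op, int hw_gen, bool has_bf16_operand) {
  auto it = kBf16MinVersion.find(op);
  if (it == kBf16MinVersion.end() || hw_gen >= it->second) return false;
  return has_bf16_operand;  // only bf16 operands need the f32 round trip
}

int main() {
  std::printf("%d\n", needsRewrite("arith.addf", 5, /*has_bf16_operand=*/true));   // 1
  std::printf("%d\n", needsRewrite("arith.addf", 6, /*has_bf16_operand=*/true));   // 0
  std::printf("%d\n", needsRewrite("arith.divf", 5, /*has_bf16_operand=*/false));  // 0
}
```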
+ auto it = bf16_ops_min_supported_versions().find(op.getName().getStringRef()); + if (it == bf16_ops_min_supported_versions().end() || + ctx.hardware_generation >= it->second) { + return false; } - return isa(op); + return llvm::any_of(op.getOperands(), [](Value operand) { + auto vty = dyn_cast(operand.getType()); + return vty && vty.getElementType().isBF16(); + }); } class MosaicCanonicalizer { public: - MosaicCanonicalizer(int hardware_generation, bool compatibility_mode) + MosaicCanonicalizer(int hardware_generation, bool compatibility_mode, + std::array target_shape) : hardware_generation_(hardware_generation), - compatibility_mode_(compatibility_mode) {} + compatibility_mode_(compatibility_mode), + target_shape_(target_shape) {} int hardware_generation_; bool compatibility_mode_; + std::array target_shape_; LogicalResult canonicalize(func::FuncOp op) { if (!op.getBody().hasOneBlock()) { @@ -719,10 +1146,11 @@ class MosaicCanonicalizer { } LogicalResult canonicalizeOp(Operation &any_op) { - CanonicalizeContext ctx({compatibility_mode_, hardware_generation_}); + CanonicalizeContext ctx( + {compatibility_mode_, hardware_generation_, target_shape_}); // We must iterate over the op first, because canonicalization can cause - // us to .erase() an op, and accessing getRegions on it after is not sound. - // Invariant - top level ops with regions may never be invalidated. + // us to .erase() an op, and accessing getRegions on it after is not + // sound. Invariant - top level ops with regions may never be invalidated. for (Region ®ion : any_op.getRegions()) { for (Block &block : region) { if (canonicalizeBlock(block).failed()) { @@ -744,14 +1172,18 @@ class MosaicCanonicalizer { struct CanonicalizeMosaicPass : public impl::CanonicalizeMosaicPassBase { - CanonicalizeMosaicPass(int hardware_generation_p, bool compatibility_mode_p) + CanonicalizeMosaicPass(int hardware_generation_p, bool compatibility_mode_p, + std::array target_shape) : compatibility_mode_(compatibility_mode_p) { this->hardware_generation = hardware_generation_p; + this->sublane_count = target_shape[0]; + this->lane_count = target_shape[1]; } void runOnOperation() override { func::FuncOp func = getOperation(); - MosaicCanonicalizer vlc(hardware_generation, compatibility_mode_); + MosaicCanonicalizer vlc(hardware_generation, compatibility_mode_, + {sublane_count, lane_count}); if (vlc.canonicalize(func).failed()) { signalPassFailure(); } @@ -763,9 +1195,10 @@ struct CanonicalizeMosaicPass } // namespace std::unique_ptr> createCanonicalizeMosaicPass( - int hardware_generation, bool compatibility_mode) { - return std::make_unique(hardware_generation, - compatibility_mode); + int hardware_generation, bool compatibility_mode, + std::array target_shape) { + return std::make_unique( + hardware_generation, compatibility_mode, target_shape); } } // namespace mlir::tpu diff --git a/jaxlib/mosaic/dialect/tpu/transforms/communication.cc b/jaxlib/mosaic/dialect/tpu/transforms/communication.cc index 89e3a8bb9f70..7798b0027369 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/communication.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/communication.cc @@ -17,13 +17,16 @@ limitations under the License. 
#include #include +#include "absl/log/check.h" #include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/IR/Builders.h" -#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Diagnostics.h" +#include "mlir/IR/ValueRange.h" #include "mlir/IR/Visitors.h" -#include "mlir/Support/LogicalResult.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" #include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" #include "xla/layout.h" @@ -107,9 +110,13 @@ struct LogicalToPhysicalDeviceIdPass auto device_assignment_type = MemRefType::get( {total_devices}, IntegerType::get(func.getContext(), 32), TiledLayoutAttr::get(func.getContext(), {xla::Tile({128})}, {1}), - MemorySpaceAttr::get(func.getContext(), MemorySpace::smem)); - func.insertArgument(func.getNumArguments(), device_assignment_type, - nullptr, UnknownLoc::get(func.getContext())); + MemorySpaceAttr::get(func.getContext(), MemorySpace::kSmem)); + + if (failed(func.insertArgument(func.getNumArguments(), + device_assignment_type, nullptr, + UnknownLoc::get(func.getContext())))) { + return signalPassFailure(); + } auto device_assignment_arg = func.getArgument(func.getNumArguments() - 1); func.walk([device_assignment_arg](Operation *some_op) { if (auto op = dyn_cast(some_op)) { diff --git a/jaxlib/mosaic/dialect/tpu/transforms/extensions/apply_vector_layout_extensions.cc b/jaxlib/mosaic/dialect/tpu/transforms/extensions/apply_vector_layout_extensions.cc index e7528533938f..d2c149a47150 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/extensions/apply_vector_layout_extensions.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/extensions/apply_vector_layout_extensions.cc @@ -1,7 +1,22 @@ +/* Copyright 2024 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + #include "jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout_extensions.h" -#include "llvm/include/llvm/ADT/StringMap.h" -#include "mlir/include/mlir/IR/Operation.h" +#include "llvm/ADT/StringMap.h" +#include "mlir/IR/Operation.h" namespace mlir::tpu::extensions { diff --git a/jaxlib/mosaic/dialect/tpu/transforms/extensions/infer_vector_layout_extensions.cc b/jaxlib/mosaic/dialect/tpu/transforms/extensions/infer_vector_layout_extensions.cc index c9c4a97e6222..e34ef7fcb261 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/extensions/infer_vector_layout_extensions.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/extensions/infer_vector_layout_extensions.cc @@ -1,11 +1,26 @@ +/* Copyright 2024 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + #include "jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout_extensions.h" #include #include -#include "mlir/include/mlir/IR/Operation.h" -#include "mlir/include/mlir/Support/LLVM.h" -#include "mlir/include/mlir/Support/LogicalResult.h" +#include "mlir/IR/Operation.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" namespace mlir::tpu::extensions { diff --git a/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.cc index 0926f8a3c7b5..f96989c0fd95 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.cc @@ -1,3 +1,18 @@ +/* Copyright 2023 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + #include "jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.h" #include @@ -6,6 +21,7 @@ #include #include +#include "absl/log/check.h" #include "llvm/ADT/bit.h" #include "llvm/Support/MathExtras.h" #include "mlir/Dialect/Func/IR/FuncOps.h" @@ -14,7 +30,6 @@ #include "mlir/IR/Block.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributes.h" -#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Diagnostics.h" #include "mlir/IR/Location.h" @@ -23,7 +38,6 @@ #include "mlir/Pass/Pass.h" #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" -#include "absl/log/check.h" #include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" #include "jaxlib/mosaic/dialect/tpu/util.h" #include "xla/layout.h" @@ -44,10 +58,12 @@ namespace mlir::tpu { // enabled by XLA for memrefs. // bitwidth: The bitwidth of the element type of the operand. // is_kernel_argument: Whether the operand is a kernel argument. +// is_1d: Whether the operand is 1D. int getTilingFactor(const int src_sublane, const int hardware_generation, const int64_t target_sublane_count, const TpuTilingFlags &tpu_tiling_flags, - const int8_t bitwidth, const bool is_kernel_argument) { + const int8_t bitwidth, const bool is_kernel_argument, + const bool is_1d) { CHECK(llvm::isPowerOf2_32(bitwidth)); CHECK_LE(2, bitwidth); CHECK_LE(bitwidth, 32); @@ -62,18 +78,25 @@ int getTilingFactor(const int src_sublane, const int hardware_generation, const int max_normal_tiling = tiling_sublane; int large_tiling = [&] { + if (is_1d) { + // 1D tiling is always compact. 
+ return tiling_sublane; + } + if (bitwidth == 2) { + return target_sublane_count * 16; + } if (bitwidth == 4 && tpu_tiling_flags.use_x4_large_second_minor) { - return tiling_sublane * 8; + return target_sublane_count * 8; } if (bitwidth == 8 && tpu_tiling_flags.use_x8_large_second_minor) { - return tiling_sublane * 4; + return target_sublane_count * 4; } // 16-bit values are generally always possible to relayout on the fly in v6, // so we allow large 2nd minor tiling whenever possible. We can't do this // for kernel arguments, because the layout of those is controlled by XLA. if (bitwidth == 16 && (tpu_tiling_flags.use_x16_large_second_minor || (!is_kernel_argument && hardware_generation >= 6))) { - return tiling_sublane * 2; + return target_sublane_count * 2; } return tiling_sublane; }(); @@ -134,9 +157,9 @@ FailureOr inferLayout(MemRefType memref_ty, auto src_sublane = llvm::divideCeil(memref_ty.getShape().back(), lane_count); const int64_t leading_tile = - getTilingFactor(src_sublane, hardware_generation, - sublane_count, tpu_tiling_flags, bitwidth, - is_kernel_argument) * + getTilingFactor(src_sublane, hardware_generation, sublane_count, + tpu_tiling_flags, bitwidth, is_kernel_argument, + /*is_1d=*/true) * lane_count; SmallVector tiles{xla::Tile({leading_tile})}; if (bitwidth != 32) { @@ -156,8 +179,8 @@ FailureOr inferLayout(MemRefType memref_ty, const int64_t src_sublane = shape[shape.size() - 2]; if (leading_tile_rows == 0) { leading_tile_rows = getTilingFactor( - src_sublane, hardware_generation, sublane_count, - tpu_tiling_flags, bitwidth, is_kernel_argument); + src_sublane, hardware_generation, sublane_count, tpu_tiling_flags, + bitwidth, is_kernel_argument, /*is_1d=*/false); } SmallVector tiles{xla::Tile({leading_tile_rows, lane_count})}; if (bitwidth != 32) { @@ -222,7 +245,7 @@ FailureOr inferMemref(MemRefType memref, semaphore_mem); } const Attribute vmem = - tpu::MemorySpaceAttr::get(memref.getContext(), MemorySpace::vmem); + tpu::MemorySpaceAttr::get(memref.getContext(), MemorySpace::kVmem); const Attribute memory_space = memref.getMemorySpace() == nullptr ? vmem : memref.getMemorySpace(); FAILUREOR_ASSIGN_OR_RETURN( diff --git a/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.h b/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.h index f2ab7c624eb1..a6dd8ad1dbd3 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.h +++ b/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.h @@ -1,3 +1,18 @@ +/* Copyright 2023 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
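A simplified standalone C++ sketch of the large second-minor tiling choice above; the compact fallback (`tiling_sublane` in the pass) is taken as a parameter here because its computation is outside this hunk, and the later clamping against the source sublane count is omitted:
```
#include <cstdio>

struct TpuTilingFlags {
  bool use_x4_large_second_minor = false;
  bool use_x8_large_second_minor = false;
  bool use_x16_large_second_minor = false;
};

int largeTiling(int bitwidth, int target_sublanes, int hw_gen,
                bool is_kernel_argument, bool is_1d, int compact_tiling,
                const TpuTilingFlags &flags) {
  if (is_1d) return compact_tiling;  // 1D tiling is always compact
  if (bitwidth == 2) return target_sublanes * 16;
  if (bitwidth == 4 && flags.use_x4_large_second_minor) return target_sublanes * 8;
  if (bitwidth == 8 && flags.use_x8_large_second_minor) return target_sublanes * 4;
  if (bitwidth == 16 && (flags.use_x16_large_second_minor ||
                         (!is_kernel_argument && hw_gen >= 6)))
    return target_sublanes * 2;
  return compact_tiling;
}

int main() {
  TpuTilingFlags flags;
  // 16-bit on v6, not a kernel argument: 2x large second-minor tiling.
  std::printf("%d\n", largeTiling(16, 8, 6, false, false, 8, flags));  // 16
  // Same type as a kernel argument: layout is owned by XLA, keep compact.
  std::printf("%d\n", largeTiling(16, 8, 6, true, false, 8, flags));   // 8
  // 2-bit always gets the 16x tiling.
  std::printf("%d\n", largeTiling(2, 8, 6, false, false, 8, flags));   // 128
}
```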
+==============================================================================*/ + #ifndef THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_TRANSFORMS_INFER_MEMREF_LAYOUT_H_ #define THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_TRANSFORMS_INFER_MEMREF_LAYOUT_H_ diff --git a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc index 0081feba985b..388adc421a09 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc @@ -19,35 +19,29 @@ limitations under the License. #include #include #include -#include -#include #include +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/types/span.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVectorExtras.h" -#include "llvm/Support/raw_ostream.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/Attributes.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" +#include "mlir/IR/OpDefinition.h" #include "mlir/IR/Value.h" +#include "mlir/IR/Visitors.h" +#include "mlir/Pass/Pass.h" #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" -#include "absl/log/check.h" -#include "absl/log/log.h" -#include "absl/types/span.h" -#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/include/mlir/Dialect/SCF/IR/SCF.h" -#include "mlir/include/mlir/Dialect/Vector/IR/VectorOps.h" -#include "mlir/include/mlir/IR/Attributes.h" -#include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h" -#include "mlir/include/mlir/IR/OpDefinition.h" -#include "mlir/include/mlir/IR/Visitors.h" -#include "mlir/include/mlir/Pass/Pass.h" #include "jaxlib/mosaic/dialect/tpu/layout.h" #include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" #include "jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout_extensions.h" @@ -66,7 +60,6 @@ using ImplicitDim = VectorLayout::ImplicitDim; static constexpr int kLayoutLog = 10; - bool is_fully_replicated(const Layout &layout) { static LayoutOffsets replicated_offsets = {std::nullopt, std::nullopt}; return layout.has_value() && layout->offsets() == replicated_offsets; @@ -142,10 +135,10 @@ class VectorLayoutInferer { bool has_vector_io = false; for (auto op : any_op.getOperands()) { - has_vector_io |= op.getType().isa(); + has_vector_io |= isa(op.getType()); } for (auto r : any_op.getResults()) { - has_vector_io |= r.getType().isa(); + has_vector_io |= isa(r.getType()); } if (!has_vector_io && any_op.getRegions().empty()) { SmallVector in_layout(any_op.getNumOperands(), kNoLayout); @@ -161,6 +154,14 @@ class VectorLayoutInferer { if (inferExt(&any_op).failed()) { return failure(); } + } else if (auto op = dyn_cast(any_op); + op && + cast(op.getIn().getType()) + .getElementTypeBitWidth() < + cast(op.getType()).getElementTypeBitWidth()) { + if (inferExt(&any_op).failed()) { + return failure(); + } } else if (isa(any_op)) { if (inferTrunc(&any_op).failed()) { return failure(); @@ -326,7 +327,7 @@ class VectorLayoutInferer { if (inferStore(op).failed()) { return failure(); } - } else if (auto op = dyn_cast(any_op)) { + } else if (auto op = dyn_cast(any_op)) { if (infer(op).failed()) { return failure(); } @@ -468,7 +469,7 @@ class VectorLayoutInferer { 
TPU_CHECK_OP(else_yield->getOperandTypes() == op->getResultTypes(), "scf if results and else branch yield operands do not match"); auto else_yield_in_layouts = getLayoutFromOperands(else_yield); - // Find a compatible layout from then and else branches for each reuslt. For + // Find a compatible layout from then and else branches for each result. For // example, if we yield offset (*, *) in then branch and offset (*, 0) in // else branch, the result offset should be (*, 0). SmallVector out_layouts; @@ -648,7 +649,7 @@ class VectorLayoutInferer { auto yield_in_layouts = getLayoutFromOperands(yield_op); // Find a compatible layout from condition body and loop body for each - // reuslt. For example, if we yield offset (*, *) in condition body and + // result. For example, if we yield offset (*, *) in condition body and // offset (*, 0) in loop body, the result offset should be (*, 0). SmallVector out_layouts; out_layouts.reserve(op->getNumResults()); @@ -763,9 +764,28 @@ class VectorLayoutInferer { if (op.getType().getRank() < 2) { NYI("Unsupported 1D shape"); } + // TODO(b/337384645): Currently we assume {0, 0} offsets in the input + // layout. Relax this assumption. auto layout = VectorLayout(bitwidth, {0, 0}, nativeTiling(bitwidth), ImplicitDim::kNone); - setLayout(op, {layout, kNoLayout}, layout); + // Calculate the offsets for the output layout. + LayoutOffsets offsets_out = layout.offsets(); + // We assume there are no implicit dims. + int tiling_dim = op.getDimension() - (op.getType().getRank() - 2); + if (auto amount = op.getAmount().getDefiningOp(); + amount && (tiling_dim == 0 || tiling_dim == 1)) { + if (auto integer_attr = dyn_cast(amount.getValue())) { + const int64_t tile_size = layout.tiling()[tiling_dim]; + const int64_t dim_size = op.getType().getShape()[op.getDimension()]; + const int64_t shift = integer_attr.getValue().getSExtValue(); + if (dim_size % tile_size != 0) { + offsets_out[tiling_dim] = (dim_size - (shift % dim_size)) % tile_size; + } + } + } + auto out_layout = VectorLayout(bitwidth, offsets_out, + nativeTiling(bitwidth), ImplicitDim::kNone); + setLayout(op, {layout, kNoLayout}, out_layout); return success(); } @@ -947,22 +967,23 @@ class VectorLayoutInferer { } LogicalResult infer(tpu::DynamicGatherOp op) { - if (op.getType().getShape() != ArrayRef(target_shape_) && - op.getType().getElementTypeBitWidth() != 32) { - return op.emitOpError( - "Not implemented: DynamicGatherOp only supports 32-bit VREG shape"); - } - if (op.getDimension() != 0 && op.getDimension() != 1) { - return op.emitOpError( - "Not implemented: Only dimension 0 and 1 are supported"); - } // TODO(jevinjiang): we could preserve some offsets such as replicated // offset but since we are forcing all operands and result to be the same // layout, we can set all offsets to zero for now. Also maybe we should // consider adding this to elementwise rule. 
- auto layout = VectorLayout(kNativeBitwidth, {0, 0}, default_tiling_, - ImplicitDim::kNone); - setLayout(op, {layout, layout}, layout); + if (op.getType().getShape() == ArrayRef(target_shape_) && + op.getType().getElementTypeBitWidth() == 32) { + VectorLayout layout(kNativeBitwidth, {0, 0}, default_tiling_, + ImplicitDim::kNone); + setLayout(op, {layout, layout}, layout); + } else if (op.getIndices().getType().getShape() == + ArrayRef{4 * target_shape_[0], target_shape_[1]} && + op.getType().getElementTypeBitWidth() == 8) { + VectorLayout layout(8, {0, 0}, nativeTiling(8), ImplicitDim::kNone); + setLayout(op, {layout, layout}, layout); + } else { + return op.emitOpError("Not implemented"); + } return success(); } @@ -1095,13 +1116,11 @@ class VectorLayoutInferer { } auto src_tiled_ishape = layout.getImplicitTiledDims(src_ty.getShape(), 1); auto dst_tiled_ishape = layout.getImplicitTiledDims(res_ty.getShape(), 1); - // Since we can only do sublane broadcasts in the (8, 128) tiling, we - // should always use that when sublane broadcasting is required. if (src_tiled_ishape[0] != dst_tiled_ishape[0] && layout.offsets()[0] != std::nullopt) { + // TODO(tlongeri): Remove this. We support non-native tiling now, but + // things may still break downstream due to missing relayouts. LayoutOffsets offsets = layout.offsets(); - // At the moment relayout can only produce replicated sublanes when - // converting to (8, 128) if the input was in (1, 128) tiling if (layout.tiling()[0] == 1 && layout.bitwidth() == kNativeBitwidth) { offsets[0] = std::nullopt; } @@ -1301,7 +1320,7 @@ class VectorLayoutInferer { (*(offsets.end() - 1) + *input_layout->offsets()[1]) % vreg_slice[1]; } for (auto stride : strides_attr) { - TPU_CHECK_OP(stride.cast().getInt() == 1, + TPU_CHECK_OP(cast(stride).getInt() == 1, "Only trivial strides supported."); } @@ -1509,7 +1528,30 @@ class VectorLayoutInferer { native_tiling, ImplicitDim::kNone)); return success(); } - op.emitOpError("unsupported shape cast"); + + // Shape cast (..., m, n, k * target_shape_[1]) -> (..., m, n * k * + // target_shape_[1]) for 32-bit types. We allow multiple major or minor + // dimensions to be folded or unfolded. + if (kNativeBitwidth == bitwidth && res_shape.size() >= 2 && + src_shape.size() >= 2 && src_shape.back() % native_tiling[1] == 0 && + res_shape.back() % native_tiling[1] == 0 && + (mlir::tpu::canFoldMinorDimsToSize(src_shape, res_shape.back()) || + mlir::tpu::canFoldMinorDimsToSize(res_shape, src_shape.back()))) { + // TODO(jsreeram): Add support for picking space-efficient tilings for + // small 2nd minor dim shapes. + // Example 1: (4, 2, 1024) -> (4, 2048) If we infer src and tgt layout to + // be (1, 128), it is no-op because essentially we just shufflle the VREGs + // in VREG array. + // Example 2: (4, 256) -> (1, 1024) is actually sublane + // shuffle inside each vreg from [0, 1, 2, 3, 4,..7] to [0, 4, 1, 5, ...] 
+ setLayout(op, + VectorLayout(layout.bitwidth(), {0, 0}, native_tiling, + ImplicitDim::kNone), + VectorLayout(layout.bitwidth(), {0, 0}, native_tiling, + ImplicitDim::kNone)); + return success(); + } + op.emitOpError("infer-vector-layout: unsupported shape cast"); return failure(); } @@ -1630,7 +1672,7 @@ class VectorLayoutInferer { return success(); } - LogicalResult infer(vector::TransposeOp op) { + LogicalResult infer(tpu::TransposeOp op) { auto permutation = op.getPermutation(); TPU_CHECK_OP(permutation.size() > 1, "Vector and scalar transpose should be a no-op and removed"); @@ -1641,17 +1683,27 @@ class VectorLayoutInferer { auto src_ty = op.getSourceVectorType(); TPU_CHECK_OP(permutation.size() == src_ty.getRank(), "Transpose permutation has incorrect rank"); - for (auto dim : permutation.drop_back(2)) { - TPU_CHECK_OP(dim < src_ty.getRank() - 2, - "Unsupported transpose permutation - minor dims into major"); - } - for (auto dim : permutation.take_back(2)) { - TPU_CHECK_OP(dim >= src_ty.getRank() - 2, - "Unsupported transpose permutation - major dims into minor"); + bool untiled_tiled_swap = false; + // TODO(mvoz): Expand to more general cases. b/419268277 + if (permutation.size() == 3 && permutation[0] == 1 && permutation[1] == 0) { + untiled_tiled_swap = true; + } else { + for (auto dim : permutation.drop_back(2)) { + TPU_CHECK_OP(dim < src_ty.getRank() - 2, + "Unsupported transpose permutation - minor dims into " + "major > 3 dimensions"); + } + for (auto dim : permutation.take_back(2)) { + TPU_CHECK_OP(dim >= src_ty.getRank() - 2, + "Unsupported transpose permutation - major dims into " + "minor > 3 dimensions"); + } } Layout required_layout = some_layout; - // Require native tiling if we're going to use the XLU. - if (permutation[permutation.size() - 1] == permutation.size() - 2) { + // Require native tiling if we're going to use the XLU, or doing a + // major/minor permute. + if (untiled_tiled_swap || + permutation[permutation.size() - 1] == permutation.size() - 2) { auto native_tiling = nativeTiling(layout.bitwidth()); required_layout = VectorLayout(layout.bitwidth(), LayoutOffsets{0, 0}, native_tiling, ImplicitDim::kNone); @@ -1901,12 +1953,11 @@ class VectorLayoutInferer { } bool allUsersRequireNativeTiling(Value x) { - for (OpOperand &operand : x.getUses()) { - if (isa(operand.getOwner())) { + for (Operation *user : getNontrivialTransitiveUsers(x)) { + if (isa(user)) { continue; } - if (auto reduce = - dyn_cast(operand.getOwner())) { + if (auto reduce = dyn_cast(user)) { bool reduces_tiled_dims = false; for (int64_t dim : reduce.getReductionDims()) { if (dim >= reduce.getSourceVectorType().getRank() - 2) { @@ -1918,7 +1969,7 @@ class VectorLayoutInferer { continue; } } - if (auto transpose = dyn_cast(operand.getOwner())) { + if (auto transpose = dyn_cast(user)) { auto perm = transpose.getPermutation(); auto rank = perm.size(); // Only permutations that actually swap the last two dims need it. @@ -1928,7 +1979,7 @@ class VectorLayoutInferer { } // Fall through. 
} - if (auto store = dyn_cast(operand.getOwner())) { + if (auto store = dyn_cast(user)) { auto maybe_tiling = verifyMemoryTiling( store, getMemRefLayout(store.getBase()).getTiles(), store.getMemRefType().getRank(), diff --git a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout_extensions.h b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout_extensions.h index d240f27fd42d..a81e982f8e1a 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout_extensions.h +++ b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout_extensions.h @@ -1,11 +1,26 @@ +/* Copyright 2024 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + #ifndef THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_TRANSFORMS_INFER_VECTOR_LAYOUT_EXTENSIONS_H_ #define THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_TRANSFORMS_INFER_VECTOR_LAYOUT_EXTENSIONS_H_ #include #include -#include "mlir/include/mlir/IR/Operation.h" -#include "mlir/include/mlir/Support/LLVM.h" +#include "mlir/IR/Operation.h" +#include "mlir/Support/LLVM.h" namespace mlir::tpu::extensions { diff --git a/jaxlib/mosaic/dialect/tpu/transforms/linalg_vectorization.cc b/jaxlib/mosaic/dialect/tpu/transforms/linalg_vectorization.cc index 949a26a4f593..0d310ff45b30 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/linalg_vectorization.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/linalg_vectorization.cc @@ -19,32 +19,32 @@ limitations under the License. 
#include "absl/algorithm/container.h" #include "absl/container/flat_hash_set.h" #include "absl/strings/string_view.h" -#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/include/mlir/Dialect/Linalg/Transforms/Hoisting.h" -#include "mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h" -#include "mlir/include/mlir/Dialect/Math/IR/Math.h" -#include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/include/mlir/Dialect/Utils/StaticValueUtils.h" -#include "mlir/include/mlir/Dialect/Vector/IR/VectorOps.h" -#include "mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h" -#include "mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" -#include "mlir/include/mlir/IR/AffineMap.h" -#include "mlir/include/mlir/IR/Attributes.h" -#include "mlir/include/mlir/IR/BuiltinAttributes.h" -#include "mlir/include/mlir/IR/BuiltinTypeInterfaces.h" -#include "mlir/include/mlir/IR/BuiltinTypes.h" -#include "mlir/include/mlir/IR/DialectRegistry.h" -#include "mlir/include/mlir/IR/Matchers.h" -#include "mlir/include/mlir/IR/Operation.h" -#include "mlir/include/mlir/IR/OperationSupport.h" -#include "mlir/include/mlir/IR/PatternMatch.h" -#include "mlir/include/mlir/IR/Types.h" -#include "mlir/include/mlir/IR/Value.h" -#include "mlir/include/mlir/Pass/Pass.h" -#include "mlir/include/mlir/Support/LLVM.h" -#include "mlir/include/mlir/Support/LogicalResult.h" -#include "mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Linalg/Transforms/Hoisting.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Utils/StaticValueUtils.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" +#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/DialectRegistry.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/OperationSupport.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Types.h" +#include "mlir/IR/Value.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" namespace mlir::tpu { diff --git a/jaxlib/mosaic/dialect/tpu/transforms/memory_space_specialization.cc b/jaxlib/mosaic/dialect/tpu/transforms/memory_space_specialization.cc index b73ea0f1250f..1cfb797c5478 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/memory_space_specialization.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/memory_space_specialization.cc @@ -13,12 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ +#include + #include "absl/log/check.h" -#include "mlir/include/mlir/IR/Attributes.h" -#include "mlir/include/mlir/IR/BuiltinTypes.h" -#include "mlir/include/mlir/IR/Value.h" -#include "mlir/include/mlir/Support/LLVM.h" -#include "mlir/include/mlir/Support/LogicalResult.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Value.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" #include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" namespace mlir { diff --git a/jaxlib/mosaic/dialect/tpu/transforms/relayout_insertion.cc b/jaxlib/mosaic/dialect/tpu/transforms/relayout_insertion.cc index b88504e35068..178b97876b49 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/relayout_insertion.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/relayout_insertion.cc @@ -1,22 +1,36 @@ +/* Copyright 2024 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + #include #include #include #include +#include "absl/log/check.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/Support/MathExtras.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/Diagnostics.h" #include "mlir/IR/Types.h" #include "mlir/IR/Value.h" #include "mlir/IR/Visitors.h" #include "mlir/Pass/Pass.h" -#include "absl/log/check.h" -#include "llvm/include/llvm/ADT/STLExtras.h" -#include "llvm/include/llvm/Support/MathExtras.h" -#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/include/mlir/IR/Builders.h" -#include "mlir/include/mlir/IR/Diagnostics.h" -#include "mlir/include/mlir/Support/LLVM.h" +#include "mlir/Support/LLVM.h" #include "jaxlib/mosaic/dialect/tpu/layout.h" #include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" #include "jaxlib/mosaic/dialect/tpu/util.h" @@ -105,7 +119,26 @@ FailureOr> relayout( dst_bitwidth_layout); return cast>(cmp_op.getResult()); } - return v; + // Fall through to generic relayout. + auto relayout_op = + builder.create(v.getLoc(), v.getType(), v); + setLayout(relayout_op, src, dst); + + return cast>(relayout_op.getResult()); +} + +LogicalResult insertRelayout(Operation &op, int hardware_generation, + std::array target_shape); + +LogicalResult insertRelayoutBlock(Block &block, int hardware_generation, + const std::array target_shape) { + // We'll be modifying the block, so use early increment. 
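A minimal standalone C++ sketch of why the early-increment idiom mentioned above is safe, using std::list in place of an MLIR block:
```
#include <cstdio>
#include <list>

// Advance the iterator before the current element may be erased, so erasing
// it cannot invalidate the loop. This is what llvm::make_early_inc_range
// does for the block walk below.
int main() {
  std::list<int> ops = {1, 2, 3, 4, 5};
  for (auto it = ops.begin(); it != ops.end();) {
    auto current = it++;   // step past the element first
    if (*current % 2 == 0) {
      ops.erase(current);  // the "rewrite" erases the op; loop stays valid
    }
  }
  for (int v : ops) std::printf("%d ", v);  // 1 3 5
  std::printf("\n");
}
```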
+ for (Operation &op : make_early_inc_range(block)) { + if (failed(insertRelayout(op, hardware_generation, target_shape))) { + return failure(); + } + } + return success(); } // TODO(jevinjiang): make relayout to an op so we don't need decide when to @@ -153,6 +186,15 @@ LogicalResult insertRelayout(Operation &op, int hardware_generation, /*dst=*/*li, hardware_generation, target_shape)); op.setOperand(idx, new_v); } + + for (auto ®ion : op.getRegions()) { + for (auto &block : region.getBlocks()) { + if (failed( + insertRelayoutBlock(block, hardware_generation, target_shape))) { + return failure(); + } + } + } return success(); } diff --git a/jaxlib/mosaic/dialect/tpu/transforms/serde.cc b/jaxlib/mosaic/dialect/tpu/transforms/serde.cc index 0981c263d252..1f1b97e205d8 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/serde.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/serde.cc @@ -18,19 +18,17 @@ limitations under the License. #include #include +#include "llvm/ADT/StringMap.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/Attributes.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/OpDefinition.h" #include "mlir/IR/OperationSupport.h" #include "mlir/IR/Value.h" #include "mlir/IR/Visitors.h" #include "mlir/Support/LLVM.h" -#include "llvm/include/llvm/ADT/StringMap.h" -#include "mlir/include/mlir/IR/Attributes.h" -#include "mlir/include/mlir/IR/BuiltinAttributes.h" -#include "mlir/include/mlir/IR/OpDefinition.h" -#include "mlir/include/mlir/IR/OperationSupport.h" -#include "mlir/include/mlir/Support/LogicalResult.h" +#include "mlir/Support/LogicalResult.h" #include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" #include "jaxlib/mosaic/serde.h" @@ -42,10 +40,45 @@ constexpr StringRef kMangledDialect = "stable_mosaic."; constexpr StringRef kVersionAttrName = "stable_mosaic.version"; // When this is bumped, we should file a TODO to update the forward-compatible // version in tpu_custom_call.py in a month! -constexpr int kVersion = 3; +constexpr int kVersion = 5; using SerdeRuleType = jaxlib::mosaic::SerdeRuleType; +LogicalResult dynamic_gather_upgrade(Operation* op, int version) { + if (version < 5) { + auto dimension_attr = op->getAttrOfType("dimension"); + if (!dimension_attr || dimension_attr.getValue().getBitWidth() != 32) { + return op->emitError("Missing or invalid dimension attribute"); + } + const int32_t dimension = dimension_attr.getInt(); + op->removeAttr("dimension"); + op->setAttr("dimensions", + DenseI32ArrayAttr::get(op->getContext(), {dimension})); + } + return success(); +} + +LogicalResult dynamic_gather_downgrade(Operation* op, int version) { + if (version < 5) { + auto dimensions_attr = op->getAttrOfType("dimensions"); + if (!dimensions_attr) { + return op->emitError("Missing or invalid dimensions attribute"); + } + const ArrayRef dimensions = dimensions_attr.asArrayRef(); + if (dimensions.size() != 1) { + return op->emitError( + "Can only downgrade below version 5 when a single dimension is " + "specified."); + } + const int32_t dimension = dimensions.front(); + op->removeAttr("dimensions"); + op->setAttr("dimension", + mlir::IntegerAttr::get( + mlir::IntegerType::get(op->getContext(), 32), dimension)); + } + return success(); +} + LogicalResult enqueue_dma_upgrade(Operation* op, int version) { // Added AttrSizedOperandSegments and core_id in version 2. 
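A standalone C++ sketch of the version-gated attribute migration performed by the dynamic gather serde rules above, using a toy attribute dictionary rather than MLIR attributes:
```
#include <cstdio>
#include <map>
#include <string>
#include <vector>

// Below version 5 the op carried a scalar "dimension"; newer IR carries a
// "dimensions" array. Upgrade wraps the scalar; downgrade only succeeds when
// a single dimension is set.
using AttrDict = std::map<std::string, std::vector<int>>;

bool upgrade(AttrDict &attrs, int version) {
  if (version < 5) {
    auto it = attrs.find("dimension");
    if (it == attrs.end() || it->second.size() != 1) return false;
    attrs["dimensions"] = it->second;
    attrs.erase(it);
  }
  return true;
}

bool downgrade(AttrDict &attrs, int version) {
  if (version < 5) {
    auto it = attrs.find("dimensions");
    if (it == attrs.end() || it->second.size() != 1) return false;
    attrs["dimension"] = it->second;
    attrs.erase(it);
  }
  return true;
}

int main() {
  AttrDict op{{"dimension", {0}}};
  bool up_ok = upgrade(op, /*version=*/3);
  std::printf("upgrade ok: %d, has dimensions: %d\n", up_ok,
              (int)op.count("dimensions"));
  bool down_ok = downgrade(op, /*version=*/3);
  std::printf("downgrade ok: %d, has dimension: %d\n", down_ok,
              (int)op.count("dimension"));
}
```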
if (version < 2) { @@ -64,6 +97,11 @@ LogicalResult enqueue_dma_upgrade(Operation* op, int version) { << op->getNumOperands(); } } + if (version < 4) { + op->setAttr("priority", + mlir::IntegerAttr::get( + mlir::IntegerType::get(op->getContext(), 32), 0)); + } return success(); } @@ -71,6 +109,9 @@ LogicalResult enqueue_dma_downgrade(Operation* op, int version) { if (version < 2) { return op->emitError("Downgrade to version ") << version << " unsupported"; } + if (version < 4) { + op->removeAttr("priority"); + } return success(); } @@ -148,15 +189,18 @@ LogicalResult vector_multi_dim_reduce_downgrade(Operation* op, int version) { const llvm::StringMap& upgrade_rules() { static auto rules = new llvm::StringMap{ {EnqueueDMAOp::getOperationName(), enqueue_dma_upgrade}, + {DynamicGatherOp::getOperationName(), dynamic_gather_upgrade}, {SemaphoreSignalOp::getOperationName(), semaphore_signal_upgrade}, {vector::MultiDimReductionOp::getOperationName(), - vector_multi_dim_reduce_upgrade}}; + vector_multi_dim_reduce_upgrade}, + }; return *rules; } const llvm::StringMap& downgrade_rules() { static auto rules = new llvm::StringMap{ {EnqueueDMAOp::getOperationName(), enqueue_dma_downgrade}, + {DynamicGatherOp::getOperationName(), dynamic_gather_downgrade}, {SemaphoreSignalOp::getOperationName(), semaphore_signal_downgrade}, {vector::MultiDimReductionOp::getOperationName(), vector_multi_dim_reduce_downgrade}}; diff --git a/jaxlib/mosaic/dialect/tpu/transforms/serde.h b/jaxlib/mosaic/dialect/tpu/transforms/serde.h index 8685918d3b39..e5617ef151f7 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/serde.h +++ b/jaxlib/mosaic/dialect/tpu/transforms/serde.h @@ -1,15 +1,31 @@ +/* Copyright 2024 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + #ifndef THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_TRANSFORMS_SERDE_H_ #define THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_TRANSFORMS_SERDE_H_ #include #include -#include "llvm/include/llvm/ADT/StringRef.h" -#include "llvm/include/llvm/Support/CommandLine.h" -#include "mlir/include/mlir/Interfaces/DataLayoutInterfaces.h" -#include "mlir/include/mlir/Pass/Pass.h" -#include "mlir/include/mlir/Pass/PassRegistry.h" -#include "jaxlib/pass_boilerplate.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/CommandLine.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/Interfaces/DataLayoutInterfaces.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassRegistry.h" +#include "jaxlib/mosaic/pass_boilerplate.h" namespace mlir::tpu { diff --git a/jaxlib/mosaic/dialect/tpu/util.cc b/jaxlib/mosaic/dialect/tpu/util.cc index 651cef85f740..ace5a67a4a42 100644 --- a/jaxlib/mosaic/dialect/tpu/util.cc +++ b/jaxlib/mosaic/dialect/tpu/util.cc @@ -15,6 +15,7 @@ limitations under the License. #include "jaxlib/mosaic/dialect/tpu/util.h" +#include #include #include #include @@ -22,16 +23,21 @@ limitations under the License. 
#include #include -#include "llvm/Support/MathExtras.h" #include "absl/log/check.h" #include "absl/types/span.h" -#include "llvm/include/llvm/Support/raw_ostream.h" -#include "mlir/include/mlir/IR/Attributes.h" -#include "mlir/include/mlir/IR/BuiltinAttributes.h" -#include "mlir/include/mlir/IR/BuiltinTypes.h" -#include "mlir/include/mlir/IR/Value.h" -#include "mlir/include/mlir/IR/ValueRange.h" -#include "mlir/include/mlir/Support/LLVM.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/Support/FormatVariadic.h" +#include "llvm/Support/MathExtras.h" +#include "llvm/Support/raw_ostream.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/Value.h" +#include "mlir/IR/ValueRange.h" +#include "mlir/Support/LLVM.h" #include "jaxlib/mosaic/dialect/tpu/layout.h" #include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" @@ -45,23 +51,64 @@ std::ostream &operator<<(std::ostream &os, Print p) { return os; } -SmallVector ComputeTileStrides(MemRefType memref_ty, +SmallVector ComputeTileStrides(absl::Span shape, absl::Span tiling) { - SmallVector tile_strides(memref_ty.getRank()); + SmallVector tile_strides(shape.size()); int64_t stride = 1; - for (int64_t i = 0; i < memref_ty.getRank(); ++i) { - int64_t idx = memref_ty.getRank() - 1 - i; + for (int64_t i = 0; i < shape.size(); ++i) { + int64_t idx = shape.size() - 1 - i; int64_t tiling_idx = tiling.size() - 1 - i; tile_strides[idx] = stride; if (tiling_idx >= 0) { - stride *= llvm::divideCeil(memref_ty.getShape()[idx], tiling[tiling_idx]); + stride *= llvm::divideCeil(shape[idx], tiling[tiling_idx]); } else { - stride *= memref_ty.getShape()[idx]; + stride *= shape[idx]; } } return tile_strides; } +FailureOr> computeSqueezedDimsChecked( + Operation *op, ArrayRef source_shape, + ArrayRef target_shape) { + SmallVector squeezed; + int source_index = source_shape.size() - 1; + int target_index = target_shape.size() - 1; + + while (source_index >= 0 || target_index >= 0) { + int64_t target_dim = (target_index >= 0) ? target_shape[target_index] : -1; + if (source_index < 0) { + op->emitError() << llvm::formatv( + "Target shape is not valid. Source: {0}, Target: {1}.", + shapeToString(source_shape), shapeToString(target_shape)); + return failure(); + } + int64_t source_dim = source_shape[source_index]; + if (source_dim == target_dim) { + source_index--; + target_index--; + } else { + if (source_dim != 1) { + op->emitError() << llvm::formatv( + "Target shape is not valid. Source: {0}, Target: {1}.", + shapeToString(source_shape), shapeToString(target_shape)); + return failure(); + } + squeezed.push_back(source_index); + source_index--; + } + } + + if (source_index != -1 || target_index != -1) { + op->emitError() << "Shape mismatch after traversal. 
Source shape: " + << shapeToString(source_shape) + << ", target shape: " << shapeToString(target_shape); + return failure(); + } + std::reverse(squeezed.begin(), squeezed.end()); + return squeezed; +} + std::optional> isTransposedMatmul( DotDimensionNumbersAttr dim_numbers) { auto lhs_contracting_dims = dim_numbers.getLhsContractingDims(); @@ -158,6 +205,17 @@ bool canReinterpretToUntiledMemref(TypedValue tiled_memref, *(tiled_layout.getTileStrides().end() - 2) == 1; } +bool isContiguousMemref(TypedValue memref) { + auto memref_ty = getMemRefType(memref); + if (auto tiled_layout = + dyn_cast(memref_ty.getLayout())) { + auto contiguous_tile_strides = ComputeTileStrides( + memref_ty, tiled_layout.getTiles().front().dimensions()); + return contiguous_tile_strides == tiled_layout.getTileStrides(); + } + return true; +} + bool HasMemorySpace(MemRefType ty, tpu::MemorySpace space) { auto memory_space = dyn_cast_or_null(ty.getMemorySpace()); @@ -209,7 +267,10 @@ FailureOr> getOutLayouts( FAILUREOR_ASSIGN_OR_RETURN(const SmallVector out_layouts, getLayoutArrayFromAttr(op.getAttr("out_layout"))); if (out_layouts.size() != op.getNumResults()) { - return op.emitOpError("out_layout size does not match number of results"); + return op.emitOpError("out_layout size does not match number of results") + << " results: " << op.getNumResults() + << " vs layout size: " << out_layouts.size() << " for " + << op.getName(); } for (const auto [l, res] : llvm::zip_equal(out_layouts, op.getResults())) { if (!layoutIsValidForValue(l, res, target_shape)) { @@ -274,4 +335,49 @@ void setLayout(Operation *op, ArrayRef in, ArrayRef out) { setInLayout(op, in); setOutLayout(op, out); } + +std::optional getIntConst(Value v) { + if (auto const_op = v.getDefiningOp()) { + if (auto cst_attr = dyn_cast(const_op.getValue())) { + return cst_attr.getValue().getSExtValue(); + } + } + return std::nullopt; +} + +bool canFoldMinorDimsToSize(ArrayRef shape, int64_t target_size) { + CHECK_GE(shape.size(), 2); + int64_t product = shape.back(); + for (int i = shape.size() - 2; i >= 1; --i) { + product *= shape[i]; + if (product >= target_size) { + break; + } + } + return product == target_size; +} + +SmallVector getNontrivialTransitiveUsers(Value v) { + auto isUnaryElementwise = [](Operation *op) { + if (!op->hasTrait()) { + return false; + } + return op->getNumOperands() == 1 && op->getNumResults() == 1; + }; + SmallVector users; + SmallVector candidates; + candidates.push_back(v); + while (!candidates.empty()) { + Value candidate = candidates.back(); + candidates.pop_back(); + for (const auto &user : candidate.getUsers()) { + if (isa(user) || isUnaryElementwise(user)) + candidates.push_back(user->getResult(0)); + else + users.push_back(user); + } + } + return users; +} + } // namespace mlir::tpu diff --git a/jaxlib/mosaic/dialect/tpu/util.h b/jaxlib/mosaic/dialect/tpu/util.h index 2e19cb820b5b..3d7f6315b695 100644 --- a/jaxlib/mosaic/dialect/tpu/util.h +++ b/jaxlib/mosaic/dialect/tpu/util.h @@ -1,3 +1,18 @@ +/* Copyright 2023 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + #ifndef THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_UTIL_H_ #define THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_UTIL_H_ @@ -10,22 +25,21 @@ #include #include +#include "absl/status/status.h" +#include "absl/types/span.h" +#include "llvm/Support/Compiler.h" #include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Diagnostics.h" #include "mlir/IR/Location.h" +#include "mlir/IR/Value.h" #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" -#include "absl/status/status.h" -#include "absl/types/span.h" -#include "mlir/include/mlir/IR/Attributes.h" -#include "mlir/include/mlir/IR/BuiltinTypes.h" -#include "mlir/include/mlir/IR/Diagnostics.h" -#include "mlir/include/mlir/IR/Value.h" #include "jaxlib/mosaic/dialect/tpu/layout.h" #include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" -#include "tsl/platform/statusor.h" +#include "xla/tsl/platform/statusor.h" // TODO: Instead of CHECK_EQs, can we do something like TF_RET_CHECK but with // MLIR diagnostics? @@ -166,16 +180,21 @@ FailureOr getTypeBitwidth(Type ty) { << ty; } +// Returns the bitwidth of the element type. The function works for both +// scalar and vector types. +template +inline FailureOr getElementTypeBitwidth(Type ty) { + if (auto vty = dyn_cast(ty)) { + return getTypeBitwidth(vty.getElementType()); + } + return getTypeBitwidth(ty); +} + template ArrayRef> toArrayRef(absl::Span span) { return ArrayRef>(span.data(), span.size()); } -inline arith::ConstantOp IdxConst(int64_t idx, OpBuilder &builder, - Location loc) { - return builder.create(loc, builder.getIndexType(), - builder.getIndexAttr(idx)); -} // Debug only util. template @@ -192,8 +211,22 @@ std::string shapeToString(const T &shape) { return os.str(); } -SmallVector ComputeTileStrides(MemRefType memref_ty, +SmallVector ComputeTileStrides(absl::Span shape, absl::Span tiling); + +inline SmallVector ComputeTileStrides( + MemRefType memref_ty, absl::Span tiling) { + absl::Span shape(memref_ty.getShape().data(), + memref_ty.getShape().size()); + return ComputeTileStrides(shape, tiling); +} + +// Computes the dimensions that were squeezed from the source shape to match the +// target shape. Returns the dimensions in increasing order. +FailureOr> computeSqueezedDimsChecked( + Operation *op, ArrayRef source_shape, + ArrayRef target_shape); + // Assuming MKN matmul - This function must only be called after // canonicalization passes. // @@ -211,6 +244,8 @@ bool canReinterpretToUntiledMemref(TypedValue tiled_memref, const std::array &target_shape, bool allow_minormost_padding = false); +bool isContiguousMemref(TypedValue memref); + // Determines whether the given MemRefType has the given memory space. bool HasMemorySpace(MemRefType ty, tpu::MemorySpace space); @@ -233,6 +268,41 @@ void setLayout(Operation *op, Layout in, Layout out); void setLayout(Operation *op, ArrayRef in, Layout out); void setLayout(Operation *op, Layout in, ArrayRef out); void setLayout(Operation *op, ArrayRef in, ArrayRef out); + +// Helper functions to create constants. 
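+// IdxConst builds an index-typed constant; the I32Const overloads below build an i32 scalar constant or an i32 splat vector constant of the given shape.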
+inline arith::ConstantOp IdxConst(int64_t idx, OpBuilder &builder, + Location loc) { + return builder.create(loc, builder.getIndexType(), + builder.getIndexAttr(idx)); +} + +inline arith::ConstantOp I32Const(int32_t value, OpBuilder &builder, + Location loc) { + return builder.create(loc, builder.getI32Type(), + builder.getI32IntegerAttr(value)); +} + +inline arith::ConstantOp I32Const(int32_t value, ArrayRef shape, + OpBuilder &builder, Location loc) { + return builder.create( + loc, DenseElementsAttr::get( + VectorType::get(shape, builder.getI32Type()), + builder.getIntegerAttr(builder.getI32Type(), value))); +} + +std::optional getIntConst(Value v); + +// Returns true if the product of up to `shape.size() - 1` minor-most dimensions +// in `shape` equals `target_size`. The major-most dimension is not considered. +// Precondition: `shape` has at least 2 dimensions. +bool canFoldMinorDimsToSize(ArrayRef shape, int64_t target_size); + +// Recursively finds all non-trivial users of a given value, including those +// accessed via `tpu.bitcast` or unary elementwise operations. However, +// `tpu.bitcast` and unary element-wise operations are excluded from the +// results. +SmallVector getNontrivialTransitiveUsers(Value v); + } // namespace mlir::tpu #endif // THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_UTIL_H_ diff --git a/jaxlib/mosaic/dialect/tpu/vreg_util.cc b/jaxlib/mosaic/dialect/tpu/vreg_util.cc index 1f59ee13a311..90efacf0c676 100644 --- a/jaxlib/mosaic/dialect/tpu/vreg_util.cc +++ b/jaxlib/mosaic/dialect/tpu/vreg_util.cc @@ -19,16 +19,16 @@ limitations under the License. #include #include "absl/log/check.h" -#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/include/mlir/IR/Attributes.h" -#include "mlir/include/mlir/IR/BuiltinAttributes.h" -#include "mlir/include/mlir/IR/BuiltinTypes.h" -#include "mlir/include/mlir/IR/Diagnostics.h" -#include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h" -#include "mlir/include/mlir/IR/Types.h" -#include "mlir/include/mlir/IR/Value.h" -#include "mlir/include/mlir/IR/ValueRange.h" -#include "mlir/include/mlir/Support/LLVM.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" +#include "mlir/IR/Types.h" +#include "mlir/IR/Value.h" +#include "mlir/IR/ValueRange.h" +#include "mlir/Support/LLVM.h" #include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" #include "jaxlib/mosaic/dialect/tpu/util.h" #include "xla/array.h" @@ -79,6 +79,19 @@ TypedValue getFullLikeVector(ImplicitLocOpBuilder &builder, return getFullVector(builder, vec.getType(), value); } +TypedValue getFullVector(OpBuilder &builder, Location loc, + VectorType vty, Attribute value) { + return cast>( + builder.create(loc, DenseElementsAttr::get(vty, value)) + .getResult()); +} + +TypedValue getFullLikeVector(OpBuilder &builder, Location loc, + TypedValue vec, + Attribute value) { + return getFullVector(builder, loc, vec.getType(), value); +} + TypedValue getZerosVector(ImplicitLocOpBuilder &builder, VectorType vty) { return getFullVector(builder, vty, builder.getZeroAttr(vty.getElementType())); @@ -211,8 +224,7 @@ LogicalResult maskNativeTilingVregs(ImplicitLocOpBuilder &builder, FailureOr> broadcastSubelements( ImplicitLocOpBuilder &builder, TypedValue vec, - int subelement_idx, std::array target_shape, - int hardware_generation) { + int subelement_idx, std::array target_shape) { int bitwidth = 
vec.getType().getElementTypeBitWidth(); int packing = 32 / bitwidth; if (subelement_idx < 0 || subelement_idx >= packing) { @@ -234,17 +246,9 @@ FailureOr> broadcastSubelements( src_vreg_int, getFullVector(builder, vreg_native_int_ty, builder.getI32IntegerAttr(subelement_idx * bitwidth))); - Value vreg_result_int; - if (hardware_generation >= 5) { - SmallVector packed_vregs(packing, vreg_subelement_low); - vreg_result_int = builder.create( - vreg_packed_int_ty, packed_vregs, tpu::PackFormat::kInterleaved); - } else { - // This can be virtualized as a tree of shifts and ORs. - return builder.emitError() - << "broadcastSubelements not implemented for hardware generation " - << hardware_generation; - } + SmallVector packed_vregs(packing, vreg_subelement_low); + Value vreg_result_int = builder.create( + vreg_packed_int_ty, packed_vregs, tpu::PackFormat::kInterleaved); return cast>( builder.create(vec.getType(), vreg_result_int) .getResult()); diff --git a/jaxlib/mosaic/dialect/tpu/vreg_util.h b/jaxlib/mosaic/dialect/tpu/vreg_util.h index 86955e128f59..90e802fcb8fc 100644 --- a/jaxlib/mosaic/dialect/tpu/vreg_util.h +++ b/jaxlib/mosaic/dialect/tpu/vreg_util.h @@ -19,12 +19,12 @@ limitations under the License. #include #include -#include "mlir/include/mlir/IR/Attributes.h" -#include "mlir/include/mlir/IR/Builders.h" -#include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h" -#include "mlir/include/mlir/IR/Types.h" -#include "mlir/include/mlir/IR/Value.h" -#include "mlir/include/mlir/Support/LLVM.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" +#include "mlir/IR/Types.h" +#include "mlir/IR/Value.h" +#include "mlir/Support/LLVM.h" #include "xla/array.h" namespace mlir::tpu { @@ -50,6 +50,15 @@ TypedValue getFullLikeVector(ImplicitLocOpBuilder &builder, TypedValue vec, Attribute value); +// Same as above, but takes a `loc` as input, in case of an OpBuilder. +TypedValue getFullVector(OpBuilder &builder, Location loc, + VectorType vty, Attribute value); + +// Same as above, but takes a `vec` as input. +TypedValue getFullLikeVector(OpBuilder &builder, Location loc, + TypedValue vec, + Attribute value); + // Creates a vmask with false flags to bottom (dim = 0) // or right (dim = 1) where the flag count corresponds to the (dim_size - // padding). @@ -81,8 +90,7 @@ LogicalResult maskNativeTilingVregs(ImplicitLocOpBuilder &builder, // subelement_idx must be between 0 and packing. FailureOr> broadcastSubelements( ImplicitLocOpBuilder &builder, TypedValue vec, - int subelement_idx, std::array target_shape, - int hardware_generation); + int subelement_idx, std::array target_shape); } // namespace mlir::tpu diff --git a/jaxlib/mosaic/dialect/tpu/vreg_util_test.cc b/jaxlib/mosaic/dialect/tpu/vreg_util_test.cc index ea3063361e1a..8a6d437ab73c 100644 --- a/jaxlib/mosaic/dialect/tpu/vreg_util_test.cc +++ b/jaxlib/mosaic/dialect/tpu/vreg_util_test.cc @@ -21,20 +21,20 @@ limitations under the License. 
#include #include -#include "llvm/include/llvm/ADT/TypeSwitch.h" -#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/include/mlir/Dialect/Vector/IR/VectorOps.h" -#include "mlir/include/mlir/IR/Attributes.h" -#include "mlir/include/mlir/IR/Builders.h" -#include "mlir/include/mlir/IR/BuiltinAttributes.h" -#include "mlir/include/mlir/IR/BuiltinOps.h" -#include "mlir/include/mlir/IR/BuiltinTypes.h" -#include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h" -#include "mlir/include/mlir/IR/MLIRContext.h" -#include "mlir/include/mlir/IR/OwningOpRef.h" -#include "mlir/include/mlir/IR/Value.h" -#include "mlir/include/mlir/Support/DebugStringHelper.h" -#include "mlir/include/mlir/Support/LLVM.h" +#include "llvm/ADT/TypeSwitch.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OwningOpRef.h" +#include "mlir/IR/Value.h" +#include "mlir/Support/DebugStringHelper.h" +#include "mlir/Support/LLVM.h" #include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" namespace mlir::tpu { diff --git a/jaxlib/mosaic/gpu/BUILD b/jaxlib/mosaic/gpu/BUILD index 9249ae256901..e50ecfaa63ec 100644 --- a/jaxlib/mosaic/gpu/BUILD +++ b/jaxlib/mosaic/gpu/BUILD @@ -26,6 +26,14 @@ py_library( deps = [":_mosaic_gpu_ext"], ) +cc_library( + name = "mosaic_gpu_support", + deps = [ + ":custom_call", + ":runtime", + ], +) + cc_library( name = "target", srcs = ["target.cc"], @@ -52,7 +60,7 @@ cc_library( "serde.h", ], deps = [ - "//jaxlib:pass_boilerplate", + "//jaxlib/mosaic:pass_boilerplate", "//jaxlib/mosaic:serde", "@llvm-project//llvm:Support", "@llvm-project//mlir:DataLayoutInterfaces", @@ -111,22 +119,46 @@ cc_library( cc_library( name = "runtime", srcs = ["runtime.cc"], + # Linker may prune these symbols if they are not explicitly exported. 
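+ # They are resolved at runtime (e.g. by the JIT-loaded Mosaic GPU kernels), so they must stay visible in the dynamic symbol table.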
+ linkopts = [ + "-Wl,--export-dynamic-symbol='mosaic_gpu_*'", + "-Wl,--export-dynamic-symbol='nvshmem_my_pe'", + "-Wl,--export-dynamic-symbol='nvshmem_ptr'", + "-Wl,--export-dynamic-symbol='nvshmemx_barrier_all_on_stream'", + "-Wl,--export-dynamic-symbol='nvshmemx_cumodule_init'", + "-Wl,--export-dynamic-symbol='nvshmemx_init_status'", + ], deps = [ + ":nvshmem", "@local_config_cuda//cuda:cuda_headers", ], + alwayslink = True, +) + +cc_library( + name = "nvshmem", + hdrs = ["nvshmem.h"], + deps = [ + "@local_config_cuda//cuda:cuda_headers", + "@xla//xla/tsl/cuda:cudart", + ], ) cc_library( name = "custom_call", srcs = ["custom_call.cc"], deps = [ + ":library_paths", + ":nvshmem", ":passes", ":target", "//jaxlib/cuda:cuda_vendor", "//jaxlib/mosaic/dialect/gpu:mosaic_gpu", + "@com_google_absl//absl/base", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/cleanup", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -139,6 +171,7 @@ cc_library( "@llvm-project//mlir:ArithTransforms", "@llvm-project//mlir:BuiltinToLLVMIRTranslation", "@llvm-project//mlir:ComplexToLLVM", + "@llvm-project//mlir:ControlFlowDialect", "@llvm-project//mlir:ControlFlowToLLVM", "@llvm-project//mlir:ConversionPasses", "@llvm-project//mlir:ExecutionEngine", @@ -151,6 +184,7 @@ cc_library( "@llvm-project//mlir:IR", "@llvm-project//mlir:IndexToLLVM", "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:LLVMIRTransforms", "@llvm-project//mlir:LLVMToLLVMIRTranslation", "@llvm-project//mlir:MathDialect", "@llvm-project//mlir:MathToLLVM", @@ -170,6 +204,9 @@ cc_library( "@llvm-project//mlir:UBToLLVM", "@llvm-project//mlir:VectorDialect", "@llvm-project//mlir:VectorToLLVM", + "@tsl//tsl/profiler/lib:traceme", + "@xla//xla/ffi", + "@xla//xla/ffi:ffi_api", "@xla//xla/service:custom_call_status", "@xla//xla/service:custom_call_target_registry", ], @@ -205,7 +242,13 @@ cc_binary( "notap", ], deps = [ + ":nvshmem", "@local_config_cuda//cuda:cuda_headers", "@xla//xla/tsl/cuda:cudart", ], ) + +cc_library( + name = "library_paths", + hdrs = ["library_paths.h"], +) diff --git a/jaxlib/mosaic/gpu/custom_call.cc b/jaxlib/mosaic/gpu/custom_call.cc index 402e099c8d6b..39f9635b043b 100644 --- a/jaxlib/mosaic/gpu/custom_call.cc +++ b/jaxlib/mosaic/gpu/custom_call.cc @@ -18,10 +18,13 @@ limitations under the License. #include #include +#include #include +#include #include #include #include +#include #include #include #include @@ -31,73 +34,257 @@ limitations under the License. 
#include #include +#include "jaxlib/mosaic/gpu/library_paths.h" +#include "absl/base/call_once.h" #include "absl/base/optimization.h" #include "absl/cleanup/cleanup.h" #include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/log/check.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/numbers.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" +#include "absl/strings/str_split.h" +#include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" -#include "llvm/include/llvm/ADT/SmallVector.h" -#include "llvm/include/llvm/Support/CodeGen.h" -#include "llvm/include/llvm/Support/TargetSelect.h" -#include "mlir/include/mlir/Conversion/ArithToLLVM/ArithToLLVM.h" -#include "mlir/include/mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h" -#include "mlir/include/mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" -#include "mlir/include/mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h" -#include "mlir/include/mlir/Conversion/IndexToLLVM/IndexToLLVM.h" -#include "mlir/include/mlir/Conversion/MathToLLVM/MathToLLVM.h" -#include "mlir/include/mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" -#include "mlir/include/mlir/Conversion/NVVMToLLVM/NVVMToLLVM.h" -#include "mlir/include/mlir/Conversion/Passes.h" -#include "mlir/include/mlir/Conversion/UBToLLVM/UBToLLVM.h" -#include "mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" -#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/include/mlir/Dialect/Arith/Transforms/Passes.h" -#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h" -#include "mlir/include/mlir/Dialect/GPU/Transforms/Passes.h" -#include "mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h" -#include "mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h" -#include "mlir/include/mlir/Dialect/Math/IR/Math.h" -#include "mlir/include/mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h" -#include "mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h" -#include "mlir/include/mlir/Dialect/SCF/IR/SCF.h" -#include "mlir/include/mlir/Dialect/Vector/IR/VectorOps.h" -#include "mlir/include/mlir/ExecutionEngine/ExecutionEngine.h" -#include "mlir/include/mlir/ExecutionEngine/OptUtils.h" -#include "mlir/include/mlir/IR/AsmState.h" -#include "mlir/include/mlir/IR/DialectRegistry.h" -#include "mlir/include/mlir/IR/MLIRContext.h" -#include "mlir/include/mlir/Parser/Parser.h" -#include "mlir/include/mlir/Pass/PassManager.h" -#include "mlir/include/mlir/Pass/PassRegistry.h" -#include "mlir/include/mlir/Support/LLVM.h" -#include "mlir/include/mlir/Support/LogicalResult.h" -#include "mlir/include/mlir/Target/LLVM/NVVM/Target.h" -#include "mlir/include/mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h" -#include "mlir/include/mlir/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.h" -#include "mlir/include/mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" -#include "mlir/include/mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h" -#include "mlir/include/mlir/Transforms/Passes.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/CodeGen.h" +#include "llvm/Support/TargetSelect.h" +#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" +#include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h" +#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" +#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h" +#include 
"mlir/Conversion/IndexToLLVM/IndexToLLVM.h" +#include "mlir/Conversion/MathToLLVM/MathToLLVM.h" +#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" +#include "mlir/Conversion/NVVMToLLVM/NVVMToLLVM.h" +#include "mlir/Conversion/Passes.h" +#include "mlir/Conversion/UBToLLVM/UBToLLVM.h" +#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Arith/Transforms/Passes.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/GPU/Transforms/Passes.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LLVMIR/NVVMDialect.h" +#include "mlir/Dialect/LLVMIR/Transforms/Passes.h" +#include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/MemRef/Transforms/Passes.h" +#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/ExecutionEngine/ExecutionEngine.h" +#include "mlir/ExecutionEngine/OptUtils.h" +#include "mlir/IR/AsmState.h" +#include "mlir/IR/DialectRegistry.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/Parser/Parser.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Pass/PassRegistry.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Target/LLVM/NVVM/Target.h" +#include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h" +#include "mlir/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.h" +#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" +#include "mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h" +#include "mlir/Transforms/Passes.h" #include "jaxlib/gpu/vendor.h" #include "jaxlib/mosaic/dialect/gpu/mosaic_gpu.h" #include "jaxlib/mosaic/gpu/launch_lowering.h" +#include "jaxlib/mosaic/gpu/nvshmem.h" #include "jaxlib/mosaic/gpu/passes.h" #include "jaxlib/mosaic/gpu/serde.h" #include "jaxlib/mosaic/gpu/target.h" +#include "xla/ffi/ffi.h" +#include "xla/ffi/ffi_api.h" #include "xla/service/custom_call_status.h" #include "xla/service/custom_call_target_registry.h" +#include "tsl/profiler/lib/traceme.h" namespace { +namespace ffi = xla::ffi; + using MosaicInitFunc = void(void****); using MosaicHostFunc = void(void**); -absl::StatusOr> GetSmAndPtxIsaVersion() { +class TemporaryDirectory { + private: + TemporaryDirectory(std::string path) : path(std::move(path)) {} + // TODO(apaszke): Unlink in destructor. 
+ + public: + static absl::StatusOr Create() { + std::string pattern = "/tmp/mosaic-gpu-XXXXXX"; + if (mkdtemp(pattern.data()) == NULL) { + return absl::InternalError("Failed to create temporary directory"); + } + return TemporaryDirectory(std::move(pattern)); + } + + std::string_view GetPath() { return path; } + + private: + std::string path; +}; + +absl::StatusOr RunCUDATool(const char* tool, + const std::vector& args, + bool stderr_to_stdout = true) { + CHECK(!args.empty() && args.back() == nullptr); + const char* cuda_path_ptr = mosaic::gpu::GetCUDARoot(); + if (!cuda_path_ptr) + return absl::InternalError("Failed to get the CUDA toolkit path"); + std::string tool_path(cuda_path_ptr); + tool_path += "/bin/"; + tool_path += tool; + int stdout_pipe[2] = {-1, -1}; + pid_t child_pid; + posix_spawn_file_actions_t file_actions; + if (posix_spawn_file_actions_init(&file_actions)) { + return absl::InternalError("Failed to initialize spawn file actions"); + } + absl::Cleanup file_actions_destroyer = [&file_actions] { + posix_spawn_file_actions_destroy(&file_actions); + }; + if (pipe(stdout_pipe) == -1) { + return absl::InternalError("Failed to set up pipe"); + } + absl::Cleanup pipe_closer = [&stdout_pipe] { + if (stdout_pipe[0] != -1) close(stdout_pipe[0]); + if (stdout_pipe[1] != -1) close(stdout_pipe[1]); + }; + // close read end in child + if (posix_spawn_file_actions_addclose(&file_actions, stdout_pipe[0])) { + return absl::InternalError("Failed to close read end of the pipe in child"); + } + if (posix_spawn_file_actions_adddup2(&file_actions, stdout_pipe[1], + STDOUT_FILENO)) { + return absl::InternalError("Failed to redirect stdout to pipe"); + } + if (stderr_to_stdout && posix_spawn_file_actions_adddup2( + &file_actions, STDOUT_FILENO, STDERR_FILENO)) { + return absl::InternalError("Failed to redirect stderr to stdout"); + } + // execv is guaranteed by POSIX to not modify the args (other than + // replacing the whole process image), so the const_cast is valid. + if (int status = + posix_spawn(&child_pid, tool_path.c_str(), &file_actions, nullptr, + const_cast(args.data()), environ)) { + return absl::InternalError( + absl::StrCat("Process spawn failed: ", strerror(status))); + } + // Proactively close write end in parent. If we don't do this, read + // will block since the pipe will have an open write end in the + // parent process. + if (close(stdout_pipe[1]) == -1) { + return absl::InternalError( + absl::StrCat("Failed to close write end of pipe in parent process: ", + strerror(errno))); + } + // Mark the write end as successfully closed, so it doesn't get + // closed a second time by the deferred pipe_closer. 
+ stdout_pipe[1] = -1; + std::string stdout; + char buf[1024]; + while (int bytes_read = read(stdout_pipe[0], buf, sizeof buf)) { + if (bytes_read == -1) { + return absl::InternalError( + absl::StrCat("Failed to read from pipe: ", strerror(errno))); + } + stdout.append(buf, bytes_read); + } + int status; + if (waitpid(child_pid, &status, 0) == -1) { + return absl::InternalError("Failed to wait for CUDA tool invocation"); + } + if (status != 0) { + std::string error_message = "CUDA tool failed"; + if (!stdout.empty()) { + error_message += ": "; + error_message += stdout; + } + return absl::InternalError(error_message); + } + return stdout; +} + +void EnsureLLVMNVPTXTargetIsRegistered() { + static absl::once_flag register_nvptx_target_flag; + absl::call_once(register_nvptx_target_flag, []() { + LLVMInitializeNVPTXTarget(); + LLVMInitializeNVPTXTargetInfo(); + LLVMInitializeNVPTXTargetMC(); + LLVMInitializeNVPTXAsmPrinter(); + }); +} + +absl::StatusOr GetLatestPtxasPtxIsaVersion() { + std::vector ptxas_args = {"ptxas", "--input-as-string", + ".version 99.99", nullptr}; + auto status = RunCUDATool("ptxas", ptxas_args).status(); + if (status.ok()) { + return absl::InternalError("ptxas succeeded where it was expected to fail"); + } + // Output message is of the form: + // ptxas application ptx input, line 1; fatal : + // Unsupported .version 99.99; current version is '8.8' + std::vector chunks = absl::StrSplit(status.message(), '\''); + if (chunks.size() != 3) { + return absl::InternalError(absl::StrCat( + "Failed to locate PTX ISA version in ptxas error message: ", + status.message())); + } + std::vector major_minor = absl::StrSplit(chunks[1], '.'); + if (major_minor.size() != 2) { + return absl::InternalError( + absl::StrFormat("Expected PTX ISA version to be formatted as " + "MAJOR.MINOR, instead got: %s", + chunks[1])); + } + int major; + if (!absl::SimpleAtoi(major_minor[0], &major)) { + return absl::InternalError( + absl::StrFormat("Failed to parse PTX ISA major version, expected a " + "parsable integer, instead got: %s", + major_minor[0])); + } + int minor; + if (!absl::SimpleAtoi(major_minor[1], &minor)) { + return absl::InternalError( + absl::StrFormat("Failed to parse PTX ISA minor version, expected a " + "parsable integer, instead got: %s", + major_minor[1])); + } + if (minor >= 10) { + return absl::InternalError( + absl::StrFormat("PTX ISA minor version %d is not less than or equal to " + "9, which is assumed for version comparison", + minor)); + } + return major * 10 + minor; +} + +absl::StatusOr GetPtxIsaVersion() { + TF_ASSIGN_OR_RETURN(int ptxas_latest_version, GetLatestPtxasPtxIsaVersion()); + // We'd like to target the latest PTX ISA version supported by + // ptxas. However, it doesn't make sense to ask LLVM to target a PTX + // ISA that it isn't aware of yet. Find the latest version supported + // by LLVM and return the minimum of the two versions, one from + // ptxas and the other from LLVM. + TF_ASSIGN_OR_RETURN(int llvm_latest_version, + mosaic::gpu::GetLatestLlvmPtxIsaVersion()); + int final_version = std::min(ptxas_latest_version, llvm_latest_version); + return absl::StrFormat("ptx%d", final_version); +} + +absl::StatusOr GetSmVersion() { // Assumes driver has been initialized and a context exists. XLA already has // some utilities to query this, but we try to stay runtime-agnostic, so we // build our own here. 
@@ -115,13 +302,17 @@ absl::StatusOr> GetSmAndPtxIsaVersion() { device) != CUDA_SUCCESS) { return absl::InternalError("Failed to get minor compute capability"); } - return mosaic::gpu::GetSmAndPtxIsaVersion(major, minor); + EnsureLLVMNVPTXTargetIsRegistered(); + return mosaic::gpu::GetSmVersion(major, minor); } mlir::FailureOr GetPassPipeline( mlir::MLIRContext* ctx, mlir::gpu::CompilationTarget target, - const std::string& sm, const std::string& ptx_isa) { - static bool register_once = []() { + const std::string& sm, const std::string& ptx_isa, const std::string& nvshmem_path) { + static absl::once_flag register_passes_flag; + absl::call_once(register_passes_flag, []() { + EnsureLLVMNVPTXTargetIsRegistered(); + llvm::InitializeNativeTarget(); llvm::InitializeNativeTarget(); llvm::InitializeNativeTargetAsmPrinter(); @@ -148,9 +339,13 @@ mlir::FailureOr GetPassPipeline( mosaic::gpu::registerConvertGpuToLLVMPass(); mosaic::gpu::registerByvalInsertionPass(); mlir::arith::registerArithExpandOpsPass(); + mlir::LLVM::registerDIScopeForLLVMFuncOpPass(); return true; - }(); - (void)register_once; + }); + const char *cuda_root = mosaic::gpu::GetCUDARoot(); + if (!cuda_root) { + return mlir::failure(); + } return mlir::parsePassPipeline(absl::StrCat( R"( builtin.module( @@ -164,22 +359,27 @@ mlir::FailureOr GetPassPipeline( convert-nvvm-to-llvm, expand-strided-metadata, nvvm-attach-target{O=3 chip=)", - sm, R"( fast=false features=+)", ptx_isa, + sm, " fast=false features=+", ptx_isa, R"( ftz=false module= triple=nvptx64-nvidia-cuda}, lower-affine, convert-arith-to-llvm{index-bitwidth=0}, convert-index-to-llvm{index-bitwidth=64}, canonicalize{max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true}, cse, - gpu.module(strip-debuginfo), + )", + R"( gpu.module(convert-gpu-to-nvvm{has-redux=false index-bitwidth=64 use-bare-ptr-memref-call-conv=false}), gpu.module(canonicalize{max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true}), gpu.module(cse), gpu.module(mosaic-byval-insertion), gpu.module(reconcile-unrealized-casts), mosaic-convert-gpu-to-llvm, + ensure-debug-info-scope-on-llvm-func{emission-kind=DebugDirectivesOnly}, gpu-module-to-binary{format=)", - mlir::gpu::stringifyCompilationTarget(target).str(), R"(}, + mlir::gpu::stringifyCompilationTarget(target).str(), + (!nvshmem_path.empty() ? 
" l=" + nvshmem_path : ""), + " opts=-lineinfo toolkit=", cuda_root, + R"(}, convert-math-to-llvm{approximate-log1p=true}, canonicalize{max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true}, cse, @@ -205,12 +405,12 @@ mlir::LogicalResult RunPasses(mlir::OpPassManager&& passes, void InitContext(mlir::MLIRContext* context) { mlir::DialectRegistry registry; - registry.insert(); + registry.insert(); mlir::registerConvertNVVMToLLVMInterface(registry); mlir::registerConvertComplexToLLVMInterface(registry); mlir::registerConvertMemRefToLLVMInterface(registry); @@ -232,63 +432,8 @@ void InitContext(mlir::MLIRContext* context) { context->loadAllAvailableDialects(); } -absl::Status RunCUDATool(const char* tool, - const std::vector& args, - bool stderr_to_stdout = false) { - CHECK(!args.empty() && args.back() == nullptr); - const char * cuda_path_ptr = getenv("CUDA_ROOT"); - if (!cuda_path_ptr) return absl::InternalError("Failed to get CUDA_ROOT"); - std::string tool_path(cuda_path_ptr); - tool_path += "/bin/"; - tool_path += tool; - pid_t child_pid; - posix_spawn_file_actions_t file_actions; - if (posix_spawn_file_actions_init(&file_actions)) { - return absl::InternalError("Failed to initialize spawn file actions"); - } - if (posix_spawn_file_actions_adddup2(&file_actions, STDOUT_FILENO, - STDERR_FILENO)) { - return absl::InternalError("Failed to set up spawn file actions"); - } - // execv is guaranteed by POSIX to not modify the args (other than - // replacing the whole process image), so the const_cast is valid. - if (posix_spawn(&child_pid, tool_path.c_str(), &file_actions, nullptr, - const_cast(args.data()), environ)) { - return absl::InternalError("Process spawn failed"); - } - int status; - if (waitpid(child_pid, &status, 0) == -1) { - return absl::InternalError("Failed to wait for CUDA tool invocation"); - } - if (status != 0) return absl::InternalError("CUDA tool failed"); - if (posix_spawn_file_actions_destroy(&file_actions) != 0) { - return absl::InternalError("Failed to clean up after posix_spawn"); - } - return absl::OkStatus(); -} - -class TemporaryDirectory { - private: - TemporaryDirectory(std::string path) : path(std::move(path)) {} - // TODO(apaszke): Unlink in destructor. - - public: - static absl::StatusOr Create() { - std::string pattern = "/tmp/mosaic-gpu-XXXXXX"; - if (mkdtemp(pattern.data()) == NULL) { - return absl::InternalError("Failed to create temporary directory"); - } - return TemporaryDirectory(std::move(pattern)); - } - - std::string_view GetPath() { return path; } - - private: - std::string path; -}; - void DumpCompilationOutput(mlir::ModuleOp module, const std::string& sm, - const std::string& ptx_isa) { + const std::string& ptx_isa, const std::string& nvshmem_path) { bool dump_ptx = getenv("MOSAIC_GPU_DUMP_PTX") != nullptr; bool dump_ptxas = getenv("MOSAIC_GPU_DUMP_PTXAS") != nullptr; bool dump_sass = getenv("MOSAIC_GPU_DUMP_SASS") != nullptr; @@ -299,7 +444,8 @@ void DumpCompilationOutput(mlir::ModuleOp module, const std::string& sm, module = module.clone(); // Prevent accidental modification. 
absl::Cleanup module_destroyer = [module] { module->erase(); }; auto passes = GetPassPipeline( - module.getContext(), mlir::gpu::CompilationTarget::Assembly, sm, ptx_isa); + module.getContext(), mlir::gpu::CompilationTarget::Assembly, + sm, ptx_isa, nvshmem_path); if (mlir::failed(passes) || mlir::failed(RunPasses(std::move(*passes), module))) { return; @@ -341,33 +487,69 @@ void DumpCompilationOutput(mlir::ModuleOp module, const std::string& sm, ptxas_args.push_back("-v"); } ptxas_args.push_back(nullptr); - if (auto status = RunCUDATool("ptxas", ptxas_args); !status.ok()) { - std::cerr << "ptxas invocation failed: " << status.message() << std::endl; + if (auto result = RunCUDATool("ptxas", ptxas_args); !result.ok()) { + std::cerr << "ptxas invocation failed: " << result.status() << std::endl; continue; + } else if (dump_ptxas) { + std::cout << *result << std::endl; } if (!dump_sass) { continue; } // We're done. // Call nvdisasm to pretty-print SASS. - if (auto status = RunCUDATool( - "nvdisasm", {"nvdisasm", "-ndf", "-c", elf_path.c_str(), nullptr}); - !status.ok()) { - std::cerr << "nvdisasm invocation failed: " << status.message() + auto result = RunCUDATool( + "nvdisasm", {"nvdisasm", "-ndf", "-c", elf_path.c_str(), nullptr}); + if (!result.ok()) { + std::cerr << "nvdisasm invocation failed: " << result.status() << std::endl; continue; } + // Dump SASS. + std::cout << *result << std::endl; + } +} + +bool is_nvshmem_used(mlir::ModuleOp module) { + constexpr std::string_view prefix1 = "nvshmem_"; + constexpr std::string_view prefix2 = "nvshmemx_"; + for (mlir::LLVM::LLVMFuncOp llvm_func : module.getOps()) { + const auto& func_name = llvm_func.getName(); + if (!func_name.starts_with(prefix1) && !func_name.starts_with(prefix2)) { + continue; + } + auto uses = mlir::SymbolTable::getSymbolUses(llvm_func, module.getOperation()); + if (uses && !uses->empty()) { + return true; + } } + return false; +} + +absl::StatusOr get_nvshmem_llvm_lib_path() { + const char* nvshmem_path_ptr = getenv("MOSAIC_GPU_NVSHMEM_BC_PATH"); + if (!nvshmem_path_ptr) + return absl::InternalError("Failed to get MOSAIC_GPU_NVSHMEM_BC_PATH"); + return nvshmem_path_ptr; } -absl::StatusOr> Compile( +absl::StatusOr, bool>> Compile( mlir::ModuleOp module) { - auto sm_and_ptx_isa = GetSmAndPtxIsaVersion(); - if (!sm_and_ptx_isa.ok()) { - return sm_and_ptx_isa.status(); + tsl::profiler::TraceMe trace("Compile"); + TF_ASSIGN_OR_RETURN(std::string sm, GetSmVersion()); + TF_ASSIGN_OR_RETURN(std::string ptx_isa, GetPtxIsaVersion()); + bool is_comm_used = is_nvshmem_used(module); + std::string nvshmem_path = ""; + if (is_comm_used) { + TF_ASSIGN_OR_RETURN(nvshmem_path, get_nvshmem_llvm_lib_path()); + if (!mosaic::gpu::NvshmemApi::Default(/*assert_ok=*/false).is_loaded()) { + return absl::InternalError( + "Failed to load the NVSHMEM library. Make sure it is installed (e.g. 
" + "`pip install nvidia-nvshmem-cu12`)."); + } } - const std::string sm = sm_and_ptx_isa.value().first; - const std::string ptx_isa = sm_and_ptx_isa.value().second; - DumpCompilationOutput(module, sm, ptx_isa); + DumpCompilationOutput(module, sm, ptx_isa, nvshmem_path); auto passes = GetPassPipeline( - module.getContext(), mlir::gpu::CompilationTarget::Binary, sm, ptx_isa); + module.getContext(), + mlir::gpu::CompilationTarget::Binary, + sm, ptx_isa, nvshmem_path); if (mlir::failed(passes)) { return absl::InternalError("Failed to construct pass pipeline"); } @@ -375,9 +557,12 @@ absl::StatusOr> Compile( return absl::InternalError("Pass pipeline failed"); } - llvm::SmallVector runtime_lib; - if (const char* lib_path = getenv("MOSAIC_GPU_RUNTIME_LIB_PATH")) { - runtime_lib.emplace_back(lib_path); + llvm::SmallVector runtime_libs; + if (const char* runtime_lib_path = getenv("MOSAIC_GPU_RUNTIME_LIB_PATH")) { + runtime_libs.emplace_back(runtime_lib_path); + } + if (const char* nvshmem_path = getenv("MOSAIC_GPU_NVSHMEM_SO_PATH")) { + runtime_libs.emplace_back(nvshmem_path); } // Create a transformer to run all LLVM optimization passes at the // specified optimization level. @@ -386,28 +571,30 @@ absl::StatusOr> Compile( mlir::ExecutionEngineOptions options; options.transformer = transformer; options.jitCodeGenOptLevel = llvm::CodeGenOptLevel::Aggressive; - options.sharedLibPaths = runtime_lib; + options.sharedLibPaths = runtime_libs; auto maybe_execution_engine = mlir::ExecutionEngine::create(module, options); if (!maybe_execution_engine) { return absl::InternalError("Failed to compile kernel"); } - return std::move(*maybe_execution_engine); + return std::make_pair(std::move(*maybe_execution_engine), is_comm_used); } class CompiledKernel { public: CompiledKernel(std::unique_ptr engine, void* ctx, - MosaicHostFunc* host_launch) - : engine_(std::move(engine)), ctx_(ctx), host_launch_(host_launch) {} + MosaicHostFunc* host_launch, bool is_comm_used) + : engine_(std::move(engine)), ctx_(ctx), host_launch_(host_launch), + is_comm_used_(is_comm_used) {} - std::tuple GetHostLaunch() { - return std::make_tuple(ctx_, host_launch_); + std::tuple GetHostLaunch() { + return std::make_tuple(ctx_, host_launch_, is_comm_used_); } private: std::unique_ptr engine_; void* ctx_; // TODO(apaszke): Destroy this properly MosaicHostFunc* host_launch_; + bool is_comm_used_; }; using KernelHash = std::array; @@ -476,7 +663,8 @@ absl::StatusOr CompileAndInit(const char* module) { if (!maybe_engine.ok()) { return maybe_engine.status(); } - mlir::ExecutionEngine* execution_engine = maybe_engine->get(); + mlir::ExecutionEngine* execution_engine = maybe_engine.value().first.get(); + bool is_comm_used = maybe_engine.value().second; auto host_and_init_func_names = GetHostAndInitFuncNames(*module_op); if (!host_and_init_func_names.ok()) { @@ -495,14 +683,15 @@ absl::StatusOr CompileAndInit(const char* module) { void** kernel_ptr_ptr = &kernel_ptr; void*** init_args[2] = {&module_ptr_ptr, &kernel_ptr_ptr}; reinterpret_cast(*init)(init_args); - return CompiledKernel(std::move(*maybe_engine), kernel_ptr, - reinterpret_cast(*host)); + return CompiledKernel(std::move(maybe_engine.value().first), kernel_ptr, + reinterpret_cast(*host), + is_comm_used); } // Each compiled kernel has a unique init func, and each kernel is used from // a single HLO module. So it should be safe to not include the CUDA context // in the key. 
-absl::StatusOr> CachedCompileAndInit( +absl::StatusOr CachedCompileAndInit( CacheKey key, const char* module) { auto cache_and_mutex = GetKernelCache(); auto* cache = cache_and_mutex.first; @@ -513,23 +702,25 @@ absl::StatusOr> CachedCompileAndInit( absl::ReaderMutexLock lock(mutex); auto it = cache->find(key); if (ABSL_PREDICT_TRUE(it != cache->end())) - return it->second.GetHostLaunch(); + return &it->second; } absl::MutexLock lock(mutex); // We released the reader lock, another thread might have initialized it. if (cache->find(key) == cache->end()) { + tsl::profiler::TraceMe trace("Compilation cache miss"); auto compiled = CompileAndInit(module); if (!compiled.ok()) { return compiled.status(); } cache->insert_or_assign(key, std::move(*compiled)); } - return cache->at(key).GetHostLaunch(); + return &cache->at(key); } void MosaicGPUCustomCall(void* stream, void** buffers, char* opaque, size_t opaque_len, XlaCustomCallStatus* status) { + // Forward-compatible version using the legacy FFI API if (reinterpret_cast(opaque) % alignof(KernelHash)) { fprintf(stderr, "Misaligned opaque pointer\n"); abort(); @@ -541,20 +732,92 @@ void MosaicGPUCustomCall(void* stream, void** buffers, char* opaque, abort(); } CacheKey key(hash, reinterpret_cast(ctx)); - auto ctx_and_kernel = CachedCompileAndInit(key, opaque + sizeof(KernelHash)); - if (!ctx_and_kernel.ok()) { + auto compiled_kernel = CachedCompileAndInit(key, opaque + sizeof(KernelHash)); + if (!compiled_kernel.ok()) { XlaCustomCallStatusSetFailure(status, - ctx_and_kernel.status().message().data(), - ctx_and_kernel.status().message().size()); + compiled_kernel.status().message().data(), + compiled_kernel.status().message().size()); return; } - void* args[4] = {&std::get<0>(*ctx_and_kernel), &stream, &buffers}; - std::get<1>(*ctx_and_kernel)(args); + auto ctx_kernel_comm = (*compiled_kernel)->GetHostLaunch(); + bool is_comm_used = std::get<2>(ctx_kernel_comm); + void* args[4] = {&std::get<0>(ctx_kernel_comm), &stream, &buffers}; + if (is_comm_used) { + mosaic::gpu::NvshmemApi::Default().barrier_all_on_stream( + reinterpret_cast(stream)); + } + std::get<1>(ctx_kernel_comm)(args); } XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("mosaic_gpu", &MosaicGPUCustomCall, "CUDA"); +absl::Status MosaicGpuExecute(gpuStream_t stream, ffi::RemainingArgs inputs, + ffi::RemainingRets results, + absl::string_view kernel_hash, + absl::string_view module, + bool use_custom_barrier, + xla::RunId run_id) { + // Updated version using the new FFI API supporting custom barrier + // for distributed kernels + if (use_custom_barrier) { + fprintf(stderr, "Custom barrier is not supported on GPUs.\n"); + abort(); + } + if (reinterpret_cast(kernel_hash.data()) % + alignof(KernelHash) || + kernel_hash.size() != sizeof(KernelHash)) { + fprintf(stderr, "Misaligned opaque pointer\n"); + abort(); + } + auto hash = *reinterpret_cast(kernel_hash.data()); + CUcontext ctx; + if (cuCtxGetCurrent(&ctx) != CUDA_SUCCESS) { + fprintf(stderr, "Failed to get current CUDA context\n"); + abort(); + } + CacheKey key(hash, reinterpret_cast(ctx)); + TF_ASSIGN_OR_RETURN(auto compiled_kernel, CachedCompileAndInit(key, module.data())); + auto ctx_kernel_comm = compiled_kernel->GetHostLaunch(); + bool is_comm_used = std::get<2>(ctx_kernel_comm); + + std::vector buffers; + buffers.reserve(inputs.size() + results.size()); + for (int i = 0; i < inputs.size(); ++i) { + buffers.push_back(inputs.get(i)->untyped_data()); + } + for (int i = 0; i < results.size(); ++i) { + 
buffers.push_back((*results.get(i))->untyped_data()); + } + void **buffers_ptr = buffers.data(); + void *args[4] = {&std::get<0>(ctx_kernel_comm), &stream, &buffers_ptr}; + + if (is_comm_used) { + mosaic::gpu::NvshmemApi::Default().barrier_all_on_stream( + reinterpret_cast(stream)); + } + std::get<1>(ctx_kernel_comm)(args); + return absl::OkStatus(); +} + +XLA_FFI_DEFINE_HANDLER(kMosaicGpuExecute, MosaicGpuExecute, + ffi::Ffi::Bind() + .Ctx>() + .RemainingArgs() + .RemainingRets() + .Attr("kernel_hash") + .Attr("module") + .Attr("use_custom_barrier") + .Ctx()); + +XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "mosaic_gpu_v2", "CUDA", + { + /*instantiate=*/nullptr, + /*prepare=*/nullptr, + /*initialize=*/nullptr, + /*execute=*/kMosaicGpuExecute, + }); + } // namespace extern "C" { @@ -565,7 +828,7 @@ void** MosaicGpuCompile(const char* module) { if (!compiled.ok()) { return nullptr; } - auto [ctx, launch] = compiled->GetHostLaunch(); + auto [ctx, launch, is_comm_used] = compiled->GetHostLaunch(); auto tuple_ptr = std::unique_ptr(new void*[3]); if (!tuple_ptr) { return nullptr; diff --git a/jaxlib/mosaic/gpu/launch_lowering.cc b/jaxlib/mosaic/gpu/launch_lowering.cc index 0331d800ec50..44362e825345 100644 --- a/jaxlib/mosaic/gpu/launch_lowering.cc +++ b/jaxlib/mosaic/gpu/launch_lowering.cc @@ -31,29 +31,29 @@ limitations under the License. #include #include -#include "llvm/include/llvm/ADT/STLExtras.h" -#include "llvm/include/llvm/ADT/StringRef.h" -#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h" -#include "mlir/include/mlir/Dialect/LLVMIR/LLVMAttrs.h" -#include "mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h" -#include "mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h" -#include "mlir/include/mlir/IR/Builders.h" -#include "mlir/include/mlir/IR/BuiltinAttributes.h" -#include "mlir/include/mlir/IR/BuiltinOps.h" -#include "mlir/include/mlir/IR/DialectRegistry.h" -#include "mlir/include/mlir/IR/Location.h" -#include "mlir/include/mlir/IR/SymbolTable.h" -#include "mlir/include/mlir/IR/TypeRange.h" -#include "mlir/include/mlir/IR/Value.h" -#include "mlir/include/mlir/IR/ValueRange.h" -#include "mlir/include/mlir/IR/Visitors.h" -#include "mlir/include/mlir/Interfaces/DataLayoutInterfaces.h" -#include "mlir/include/mlir/Pass/Pass.h" -#include "mlir/include/mlir/Pass/PassRegistry.h" -#include "mlir/include/mlir/Support/LLVM.h" -#include "mlir/include/mlir/Support/LogicalResult.h" -#include "mlir/include/mlir/Support/TypeID.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/StringRef.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMAttrs.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMTypes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/DialectRegistry.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/SymbolTable.h" +#include "mlir/IR/TypeRange.h" +#include "mlir/IR/Value.h" +#include "mlir/IR/ValueRange.h" +#include "mlir/IR/Visitors.h" +#include "mlir/Interfaces/DataLayoutInterfaces.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassRegistry.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Support/TypeID.h" namespace mosaic { namespace gpu { @@ -238,7 +238,7 @@ mlir::LogicalResult launchPreloadedKernel(mlir::func::FuncOp func, cluster = as_32bit(launch.getClusterSizeOperandValues()); } else { 
cluster.x = cluster.y = cluster.z = builder.create( - launch.getLoc(), builder.getI32Type(), builder.getI32IntegerAttr(0)); + launch.getLoc(), builder.getI32Type(), builder.getI32IntegerAttr(0)); } mlir::Value stream = launch.getAsyncObject(); builder.create( @@ -299,6 +299,7 @@ class GpuLaunchLoweringPass : public ::mlir::OperationPass { init_func->setAttr(mlir::LLVM::LLVMDialect::getEmitCWrapperAttrName(), mlir::UnitAttr::get(func->getContext())); bool had_launch = false; + mlir::Operation *gpu_binary = nullptr; auto result = getOperation()->walk([&](mlir::gpu::LaunchFuncOp launch) -> mlir::WalkResult { if (had_launch) { @@ -314,6 +315,7 @@ class GpuLaunchLoweringPass : public ::mlir::OperationPass { << launch.getKernelModuleName(); return mlir::WalkResult::interrupt(); } + gpu_binary = binary.getOperation(); if (binary.getObjects().size() != 1) { binary.emitOpError("Expected exactly one object in the binary."); return mlir::WalkResult::interrupt(); @@ -335,15 +337,16 @@ class GpuLaunchLoweringPass : public ::mlir::OperationPass { launch.getDynamicSharedMemorySize(), cluster_shape); // Add a new function argument for the kernel handle. - func.insertArgument(0, ptr_ty, - mlir::DictionaryAttr::get(func.getContext()), - mlir::UnknownLoc::get(func.getContext())); + if (failed(func.insertArgument( + 0, ptr_ty, mlir::DictionaryAttr::get(func.getContext()), + mlir::UnknownLoc::get(func.getContext())))) { + return mlir::WalkResult::interrupt(); + } mlir::Value kernel_handle = func.getArgument(0); if (launchPreloadedKernel(func, launch, kernel_handle).failed()) { return mlir::WalkResult::interrupt(); } launch.erase(); - // TODO(apaszke): Generate a destructor function. // builder.CreateCall(getModuleUnloadFn(), {moduleObject}); @@ -352,6 +355,13 @@ class GpuLaunchLoweringPass : public ::mlir::OperationPass { if (!had_launch) { init_func.erase(); } + if (gpu_binary) { + // This deletion is load-bearing: the conversion of `gpu.binary` to + // LLVM is side-effecting, as it creates module constructors and + // destructors which create an assumption that symbols from the MLIR + // runtime are available. + gpu_binary->erase(); + } if (result == mlir::WalkResult::interrupt()) { signalPassFailure(); } diff --git a/jaxlib/mosaic/gpu/library_paths.h b/jaxlib/mosaic/gpu/library_paths.h new file mode 100644 index 000000000000..83d523ac3ccc --- /dev/null +++ b/jaxlib/mosaic/gpu/library_paths.h @@ -0,0 +1,31 @@ +/* Copyright 2025 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef JAXLIB_MOSAIC_GPU_LIBRARY_PATHS_H_ +#define JAXLIB_MOSAIC_GPU_LIBRARY_PATHS_H_ + +#include + +namespace mosaic { +namespace gpu { + +inline const char *GetCUDARoot() { + return getenv("CUDA_ROOT"); +} + +} // namespace gpu +} // namespace mosaic + +#endif // JAXLIB_MOSAIC_GPU_LIBRARY_PATHS_H_ diff --git a/jaxlib/mosaic/gpu/mosaic_gpu_ext.cc b/jaxlib/mosaic/gpu/mosaic_gpu_ext.cc index 4f804c9e2116..decdbaef28e1 100644 --- a/jaxlib/mosaic/gpu/mosaic_gpu_ext.cc +++ b/jaxlib/mosaic/gpu/mosaic_gpu_ext.cc @@ -22,11 +22,11 @@ limitations under the License. #include #include -#include "nanobind/nanobind.h" -#include "nanobind/stl/tuple.h" -#include "nanobind/stl/vector.h" #include "absl/cleanup/cleanup.h" #include "absl/strings/str_cat.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/tuple.h" // IWYU pragma: keep +#include "nanobind/stl/vector.h" // IWYU pragma: keep #include "jaxlib/gpu/vendor.h" #include "jaxlib/kernel_nanobind_helpers.h" #include "xla/ffi/api/c_api.h" @@ -98,19 +98,21 @@ static const auto* kEventElapsed = .Ret>() // elapsed_ms .To([](gpuStream_t stream, auto start, auto end, auto out) { gpuStreamSynchronize(stream); - auto start_event = std::make_unique(); - auto end_event = std::make_unique(); - absl::MakeCleanup([&]() { - gpuEventDestroy(*start_event); - gpuEventDestroy(*end_event); - }); - gpuMemcpy(start_event.get(), start.untyped_data(), sizeof(gpuEvent_t), + gpuEvent_t start_event = nullptr; + gpuEvent_t end_event = nullptr; + + absl::Cleanup cleanup = [&]() { + gpuEventDestroy(start_event); + gpuEventDestroy(end_event); + }; + + gpuMemcpy(&start_event, start.untyped_data(), sizeof(gpuEvent_t), gpuMemcpyDeviceToHost); - gpuMemcpy(end_event.get(), end.untyped_data(), sizeof(gpuEvent_t), + gpuMemcpy(&end_event, end.untyped_data(), sizeof(gpuEvent_t), gpuMemcpyDeviceToHost); + float elapsed; - if (auto res = - gpuEventElapsedTime(&elapsed, *start_event, *end_event); + if (auto res = gpuEventElapsedTime(&elapsed, start_event, end_event); res) { return ffi::Error::Internal(absl::StrCat( "Failed to get elapsed time between events: ", ToString(res))); @@ -193,6 +195,12 @@ void callback_complete(CUcontext context, uint32_t streamId, THROW_IF_CUPTI_ERROR(status); } } + + size_t num_dropped; + THROW_IF_CUPTI_ERROR( + cuptiActivityGetNumDroppedRecords(context, streamId, &num_dropped), + "failed to get number of dropped activity records"); + THROW_IF(num_dropped > 0, "activity records were dropped"); } NB_MODULE(_mosaic_gpu_ext, m) { @@ -237,15 +245,23 @@ NB_MODULE(_mosaic_gpu_ext, m) { cuptiActivityEnable(CUPTI_ACTIVITY_KIND_CONCURRENT_KERNEL), "failed to enable tracking of kernel activity by CUPTI"); }); - m.def("_cupti_get_timings", []() { - THROW_IF_CUPTI_ERROR( - cuptiActivityFlushAll(CUPTI_ACTIVITY_FLAG_FLUSH_FORCED), - "failed to flush CUPTI activity buffers"); - THROW_IF_CUPTI_ERROR(cuptiFinalize(), "failed to detach CUPTI"); - THROW_IF_CUPTI_ERROR(cuptiUnsubscribe(profiler_state.subscriber), - "failed to unsubscribe from CUPTI"); - return profiler_state.timings; - }); + m.def( + "_cupti_get_timings", + [](bool finalize) { + THROW_IF_CUPTI_ERROR( + cuptiActivityDisable(CUPTI_ACTIVITY_KIND_CONCURRENT_KERNEL), + "failed to disable tracking of kernel activity by CUPTI"); + THROW_IF_CUPTI_ERROR( + cuptiActivityFlushAll(CUPTI_ACTIVITY_FLAG_FLUSH_FORCED), + "failed to flush CUPTI activity buffers"); + if (finalize) { + THROW_IF_CUPTI_ERROR(cuptiFinalize(), "failed to detach 
CUPTI"); + } + THROW_IF_CUPTI_ERROR(cuptiUnsubscribe(profiler_state.subscriber), + "failed to unsubscribe from CUPTI"); + return profiler_state.timings; + }, + nb::arg("finalize") = true); } } // namespace diff --git a/jaxlib/mosaic/gpu/nvshmem.h b/jaxlib/mosaic/gpu/nvshmem.h new file mode 100644 index 000000000000..dbd11aa1d373 --- /dev/null +++ b/jaxlib/mosaic/gpu/nvshmem.h @@ -0,0 +1,94 @@ +/* Copyright 2025 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_MOSAIC_GPU_COMM_H_ +#define JAXLIB_MOSAIC_GPU_COMM_H_ + +#include + +#include +#include +#include + +#include "third_party/gpus/cuda/include/cuda.h" +#include "cuda_runtime_api.h" + +#define NVSHMEM_SUCCESS 0 + +namespace mosaic { +namespace gpu { + +#define NVSHMEM_SET_FN(FnName) \ + FnName = reinterpret_cast(dlsym(library, #FnName)); \ + if (!FnName) { \ + fprintf(stderr, #FnName " not available in this library."); \ + } + +class NvshmemApi { + public: + // Returns a default NvshmemApi for a current process. + // NvshmemApi follows the Singleton design pattern + static NvshmemApi& Default(bool assert_ok = true) { + static NvshmemApi instance; + if (assert_ok && !instance.is_loaded()) { + fprintf(stderr, "Failed to load the NVSHMEM library.\n"); + abort(); + } + return instance; + } + + int cumodule_init(CUmodule module) { + std::lock_guard lock(mutex_); + return nvshmemx_cumodule_init(module); + } + + void barrier_all_on_stream(cudaStream_t stream) { + nvshmemx_barrier_all_on_stream(stream); + } + + bool is_loaded() { + return nvshmemx_init_status != nullptr && nvshmemx_init_status() == 2; + } + + NvshmemApi(NvshmemApi const&) = delete; + void operator=(NvshmemApi const&) = delete; + + private: + NvshmemApi() { + const char* env_value = getenv("MOSAIC_GPU_NVSHMEM_SO_PATH"); + const char* libnvshmem_path = + env_value && *env_value != 0 ? env_value : nullptr; + void* library = dlopen(libnvshmem_path, RTLD_LAZY); + if (library == nullptr) { + fprintf(stderr, "Failed to open library (from %s): %s", + libnvshmem_path ? libnvshmem_path : "", dlerror()); + } + + NVSHMEM_SET_FN(nvshmemx_barrier_all_on_stream) + NVSHMEM_SET_FN(nvshmemx_cumodule_init) + NVSHMEM_SET_FN(nvshmemx_init_status) + } + + int (*nvshmemx_barrier_all_on_stream)(cudaStream_t); + int (*nvshmemx_cumodule_init)(CUmodule); + int (*nvshmemx_init_status)(); + + std::mutex mutex_; +}; + +} // namespace gpu +} // namespace mosaic + +#endif // JAXLIB_MOSAIC_GPU_COMM_H_ diff --git a/jaxlib/mosaic/gpu/passes.cc b/jaxlib/mosaic/gpu/passes.cc index b8c3fbb74c81..b5325e97d4ad 100644 --- a/jaxlib/mosaic/gpu/passes.cc +++ b/jaxlib/mosaic/gpu/passes.cc @@ -14,24 +14,28 @@ limitations under the License. 
==============================================================================*/ #include "jaxlib/mosaic/gpu/passes.h" + #include #include #include #include -#include "llvm/include/llvm/ADT/StringRef.h" -#include "mlir/include/mlir/Conversion/GPUCommon/GPUCommonPass.h" -#include "mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h" -#include "mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h" -#include "mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h" -#include "mlir/include/mlir/Dialect/Vector/IR/VectorOps.h" -#include "mlir/include/mlir/IR/BuiltinAttributes.h" -#include "mlir/include/mlir/IR/BuiltinOps.h" -#include "mlir/include/mlir/IR/SymbolTable.h" -#include "mlir/include/mlir/Pass/PassRegistry.h" -#include "mlir/include/mlir/Support/LLVM.h" -#include "mlir/include/mlir/Transforms/DialectConversion.h" -#include "jaxlib/pass_boilerplate.h" +#include "llvm/ADT/StringRef.h" +#include "mlir/Conversion/GPUCommon/GPUCommonPass.h" +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMTypes.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/SymbolTable.h" +#include "mlir/IR/Value.h" +#include "mlir/IR/Visitors.h" +#include "mlir/Pass/PassRegistry.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Transforms/DialectConversion.h" +#include "jaxlib/mosaic/pass_boilerplate.h" namespace mosaic { namespace gpu { @@ -50,14 +54,14 @@ struct ConvertExtractStridedSlicePattern final return rewriter.notifyMatchFailure(op, "only 1-D vectors are supported"); } int64_t size = - (*op.getSizes().getAsRange().begin()).getSInt(); + (*op.getSizes().getAsRange().begin()).getInt(); if (size < 0) { return rewriter.notifyMatchFailure(op, "size is negative"); } int64_t start = - (*op.getOffsets().getAsRange().begin()).getSInt(); + (*op.getOffsets().getAsRange().begin()).getInt(); int64_t stride = - (*op.getStrides().getAsRange().begin()).getSInt(); + (*op.getStrides().getAsRange().begin()).getInt(); if (stride != 1) { return rewriter.notifyMatchFailure(op, "only stride 1 is supported"); } diff --git a/jaxlib/mosaic/gpu/runtime.cc b/jaxlib/mosaic/gpu/runtime.cc index ad3cd0e19644..da7b0159d7b2 100644 --- a/jaxlib/mosaic/gpu/runtime.cc +++ b/jaxlib/mosaic/gpu/runtime.cc @@ -18,11 +18,12 @@ limitations under the License. 
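For context on the getSInt() -> getInt() switch in passes.cc above (my reading; the patch itself does not explain it): the offsets/sizes/strides of vector.extract_strided_slice are signless i64 IntegerAttrs, and mlir::IntegerAttr::getSInt() asserts an explicitly signed integer type, whereas getInt() is the accessor for signless integers. A minimal, self-contained illustration of the access pattern:

#include <cstdint>

#include "mlir/IR/BuiltinAttributes.h"

// Reads the first element of an ArrayAttr holding (signless) IntegerAttrs,
// mirroring ConvertExtractStridedSlicePattern above.
int64_t FirstIntElement(mlir::ArrayAttr array) {
  return (*array.getAsRange<mlir::IntegerAttr>().begin()).getInt();
}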
#include #include "third_party/gpus/cuda/include/cuda.h" +#include "jaxlib/mosaic/gpu/nvshmem.h" extern "C" { void mosaic_gpu_init_tma_desc(CUtensorMap *tma_desc, void *base_addr, - int64_t elem_bitwidth, int64_t rank, + int64_t elem_type, int64_t rank, int64_t *sizes, int64_t *strides, int64_t swizzle_bytes, int64_t *window_shape) { if (((uintptr_t)tma_desc) % 64 != 0) { @@ -32,6 +33,39 @@ void mosaic_gpu_init_tma_desc(CUtensorMap *tma_desc, void *base_addr, abort(); } + CUtensorMapDataType data_type; + int64_t elem_bitwidth; + // types are defined in: LaunchContext._get_tma_desc() + if (elem_type == 0){ + // this is for int4s + data_type = CU_TENSOR_MAP_DATA_TYPE_UINT8; + elem_bitwidth = 4; + } else if (elem_type == 1){ + data_type = CU_TENSOR_MAP_DATA_TYPE_UINT8; + elem_bitwidth = 8; + } else if (elem_type == 2){ + data_type = CU_TENSOR_MAP_DATA_TYPE_UINT16; + elem_bitwidth = 16; + } else if (elem_type == 3){ + data_type = CU_TENSOR_MAP_DATA_TYPE_UINT32; + elem_bitwidth = 32; + } else if (elem_type == 4){ + data_type = CU_TENSOR_MAP_DATA_TYPE_UINT64; + elem_bitwidth = 64; + } else if (elem_type == 5){ + data_type = CU_TENSOR_MAP_DATA_TYPE_FLOAT16; + elem_bitwidth = 16; + } else if (elem_type == 6){ + data_type = CU_TENSOR_MAP_DATA_TYPE_FLOAT32; + elem_bitwidth = 32; + } else if (elem_type == 7){ + data_type = CU_TENSOR_MAP_DATA_TYPE_BFLOAT16; + elem_bitwidth = 16; + } else{ + fprintf(stderr, "Unsupported element type: %ld \n", elem_type); + abort(); + } + // Pack 4 bit types in 8 bit pairs. int64_t elem_bytewidth; if (elem_bitwidth < 8) { @@ -54,19 +88,6 @@ void mosaic_gpu_init_tma_desc(CUtensorMap *tma_desc, void *base_addr, elem_bytewidth = elem_bitwidth / 8; } - CUtensorMapDataType data_type; - if (elem_bytewidth == 1) { - data_type = CU_TENSOR_MAP_DATA_TYPE_UINT8; - } else if (elem_bytewidth == 2) { - data_type = CU_TENSOR_MAP_DATA_TYPE_UINT16; - } else if (elem_bytewidth == 4) { - data_type = CU_TENSOR_MAP_DATA_TYPE_UINT32; - } else if (elem_bytewidth == 8) { - data_type = CU_TENSOR_MAP_DATA_TYPE_UINT64; - } else { - fprintf(stderr, "Unsupported element size: %ld\n", elem_bytewidth); - abort(); - } if (rank < 1 || rank > 5) { fprintf(stderr, "Rank must be in [1, 5], but got %ld\n", rank); abort(); @@ -94,7 +115,7 @@ void mosaic_gpu_init_tma_desc(CUtensorMap *tma_desc, void *base_addr, if (tma_stride_i % 16 != 0 || tma_stride_i >= static_cast(1) << 40) { fprintf(stderr, - "Byte strides must be divisble by 16 and less than 2**40, but " + "Byte strides must be divisible by 16 and less than 2**40, but " "got %ld (item stride = %ld, item size = %ld) at index %ld\n", tma_stride_i, strides[rank - 1], elem_bytewidth, rank - i - 2); abort(); @@ -154,6 +175,20 @@ void* mosaic_gpu_module_load(void *data) { fprintf(stderr, "cuModuleLoadData failed: %s\n", ptr); abort(); } + + { // Set the NVSHMEM state if it's used by the module. + CUdeviceptr ptr = 0; + size_t size = 0; + if (cuModuleGetGlobal(&ptr, &size, module, + "nvshmemi_device_lib_version_d") == CUDA_SUCCESS) { + if (mosaic::gpu::NvshmemApi::Default().cumodule_init(module) != + NVSHMEM_SUCCESS) { + fprintf(stderr, "nvshmemx_cumodule_init failed.\n"); + abort(); + } + } + } + return module; } diff --git a/jaxlib/mosaic/gpu/serde.cc b/jaxlib/mosaic/gpu/serde.cc index f4cf846acc11..5fca1d445774 100644 --- a/jaxlib/mosaic/gpu/serde.cc +++ b/jaxlib/mosaic/gpu/serde.cc @@ -15,10 +15,10 @@ limitations under the License. 
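To summarize the elem_type decoding that runtime.cc now performs in mosaic_gpu_init_tma_desc above: each integer code selects both a CUtensorMapDataType and an element bit width, with code 0 reserved for int4 values that are packed two per byte. The table below is only an illustration of that mapping, not code from this change; it assumes the CUDA driver header that defines CUtensorMapDataType is available.

#include <cstdint>

#include "third_party/gpus/cuda/include/cuda.h"

struct ElemTypeInfo {
  CUtensorMapDataType data_type;
  int64_t bitwidth;
};

// Indexed by the elem_type code produced on the Python side
// (LaunchContext._get_tma_desc()).
constexpr ElemTypeInfo kElemTypeInfo[] = {
    {CU_TENSOR_MAP_DATA_TYPE_UINT8, 4},      // 0: int4, packed in pairs
    {CU_TENSOR_MAP_DATA_TYPE_UINT8, 8},      // 1
    {CU_TENSOR_MAP_DATA_TYPE_UINT16, 16},    // 2
    {CU_TENSOR_MAP_DATA_TYPE_UINT32, 32},    // 3
    {CU_TENSOR_MAP_DATA_TYPE_UINT64, 64},    // 4
    {CU_TENSOR_MAP_DATA_TYPE_FLOAT16, 16},   // 5
    {CU_TENSOR_MAP_DATA_TYPE_FLOAT32, 32},   // 6
    {CU_TENSOR_MAP_DATA_TYPE_BFLOAT16, 16},  // 7
};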
#include "jaxlib/mosaic/gpu/serde.h" -#include "llvm/include/llvm/ADT/StringMap.h" -#include "llvm/include/llvm/ADT/StringRef.h" -#include "mlir/include/mlir/IR/BuiltinOps.h" -#include "mlir/include/mlir/Support/LLVM.h" +#include "llvm/ADT/StringMap.h" +#include "llvm/ADT/StringRef.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/Support/LLVM.h" #include "jaxlib/mosaic/serde.h" namespace mosaic::gpu { diff --git a/jaxlib/mosaic/gpu/serde.h b/jaxlib/mosaic/gpu/serde.h index 6187d72b4cd5..29dda33d0c5a 100644 --- a/jaxlib/mosaic/gpu/serde.h +++ b/jaxlib/mosaic/gpu/serde.h @@ -19,13 +19,13 @@ limitations under the License. #include #include -#include "llvm/include/llvm/ADT/StringRef.h" -#include "llvm/include/llvm/Support/CommandLine.h" -#include "mlir/include/mlir/IR/BuiltinOps.h" -#include "mlir/include/mlir/Interfaces/DataLayoutInterfaces.h" -#include "mlir/include/mlir/Pass/Pass.h" -#include "mlir/include/mlir/Pass/PassRegistry.h" -#include "jaxlib/pass_boilerplate.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/CommandLine.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/Interfaces/DataLayoutInterfaces.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassRegistry.h" +#include "jaxlib/mosaic/pass_boilerplate.h" namespace mosaic::gpu { diff --git a/jaxlib/mosaic/gpu/target.cc b/jaxlib/mosaic/gpu/target.cc index a1a66a709cbe..d26b1f1ccbf7 100644 --- a/jaxlib/mosaic/gpu/target.cc +++ b/jaxlib/mosaic/gpu/target.cc @@ -16,20 +16,21 @@ limitations under the License. #include #include -#include #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/match.h" +#include "absl/strings/numbers.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" -#include "llvm/include/llvm/MC/MCSubtargetInfo.h" -#include "llvm/include/llvm/MC/TargetRegistry.h" +#include "absl/strings/string_view.h" +#include "absl/strings/strip.h" +#include "llvm/MC/MCSubtargetInfo.h" +#include "llvm/MC/TargetRegistry.h" namespace mosaic::gpu { -absl::StatusOr> GetSmAndPtxIsaVersion( - int major, int minor) { +absl::StatusOr GetSmVersion(int major, int minor) { // "base" compute capability as reported by the driver. // For example for a Hopper H200 GPU this would return sm_90, and never // sm_90a. @@ -64,25 +65,41 @@ absl::StatusOr> GetSmAndPtxIsaVersion( } } } + return sm_arch_specific ? sm_arch_specific : sm_base; +} - const std::string sm = sm_arch_specific ? 
sm_arch_specific : sm_base; - +absl::StatusOr GetLatestLlvmPtxIsaVersion() { + const std::string triple = "nvptx64-nvidia-cuda"; + std::string error; + const llvm::Target* target = + llvm::TargetRegistry::lookupTarget(triple, error); + if (target == nullptr) { + return absl::InternalError(absl::StrFormat( + "Failed to lookup LLVM target based on triple %s: %s", triple, error)); + } + // generic subtarget std::unique_ptr subtarget_info{ - target->createMCSubtargetInfo(triple, sm, "")}; + target->createMCSubtargetInfo(triple, "", "")}; if (subtarget_info == nullptr) { - return absl::InternalError( - absl::StrFormat("Failed to get LLVM subtarget info for sm %s", sm)); + return absl::InternalError(absl::StrFormat( + "Failed to get generic LLVM subtarget info for triple %s", triple)); } - + int llvm_latest_version = 0; for (const llvm::SubtargetFeatureKV& feature : - subtarget_info->getEnabledProcessorFeatures()) { - if (absl::StartsWith(feature.Key, "ptx")) { - std::string ptx_isa = feature.Key; - return std::make_pair(sm, ptx_isa); + subtarget_info->getAllProcessorFeatures()) { + absl::string_view version_string = feature.Key; + if (absl::ConsumePrefix(&version_string, "ptx")) { + int version; + if (!absl::SimpleAtoi(version_string, &version)) { + return absl::InternalError( + absl::StrFormat("Failed to convert PTX ISA version to integer: %s", + version_string)); + } + llvm_latest_version = + version > llvm_latest_version ? version : llvm_latest_version; } } - return absl::InternalError(absl::StrFormat( - "Failed to find a PTX ISA LLVM subtarget feature for %s", sm)); + return llvm_latest_version; } } // namespace mosaic::gpu diff --git a/jaxlib/mosaic/gpu/target.h b/jaxlib/mosaic/gpu/target.h index 070ecedebd01..5a2a240d8db1 100644 --- a/jaxlib/mosaic/gpu/target.h +++ b/jaxlib/mosaic/gpu/target.h @@ -22,8 +22,8 @@ limitations under the License. namespace mosaic::gpu { -absl::StatusOr> GetSmAndPtxIsaVersion( - int major, int minor); +absl::StatusOr GetSmVersion(int major, int minor); +absl::StatusOr GetLatestLlvmPtxIsaVersion(); } // namespace mosaic::gpu diff --git a/jaxlib/pass_boilerplate.h b/jaxlib/mosaic/pass_boilerplate.h similarity index 87% rename from jaxlib/pass_boilerplate.h rename to jaxlib/mosaic/pass_boilerplate.h index b9754a8738ee..96d9e85a1d2d 100644 --- a/jaxlib/pass_boilerplate.h +++ b/jaxlib/mosaic/pass_boilerplate.h @@ -13,15 +13,15 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef JAXLIB_PASS_BOILERPLATE_H_ -#define JAXLIB_PASS_BOILERPLATE_H_ +#ifndef JAXLIB_MOSAIC_PASS_BOILERPLATE_H_ +#define JAXLIB_MOSAIC_PASS_BOILERPLATE_H_ #include -#include "mlir/include/mlir/IR/DialectRegistry.h" -#include "mlir/include/mlir/Pass/Pass.h" -#include "mlir/include/mlir/Support/LLVM.h" -#include "mlir/include/mlir/Support/TypeID.h" +#include "mlir/IR/DialectRegistry.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/TypeID.h" namespace jaxlib { namespace mlir { @@ -64,4 +64,4 @@ class Pass : public ::mlir::OperationPass { } // namespace mlir } // namespace jaxlib -#endif // JAXLIB_PASS_BOILERPLATE_H_ +#endif // JAXLIB_MOSAIC_PASS_BOILERPLATE_H_ diff --git a/jaxlib/mosaic/python/tpu.py b/jaxlib/mosaic/python/tpu.py index a1c7f79ba769..8083b9759f1b 100644 --- a/jaxlib/mosaic/python/tpu.py +++ b/jaxlib/mosaic/python/tpu.py @@ -19,6 +19,7 @@ # pylint: disable=g-bad-import-order +from . 
import _tpu_gen from ._tpu_gen import * # pylint: disable=wildcard-import from ._tpu_gen import _Dialect from jaxlib.mlir._mlir_libs._tpu_ext import * # pylint: disable=wildcard-import @@ -32,7 +33,7 @@ @_cext.register_operation(_Dialect, replace=True) -class TraceOp(TraceOp): # noqa: F405 +class TraceOp(_tpu_gen.TraceOp): # noqa: F405 """An extension to the automatically generated TraceOp bindings.""" def __init__(self, results, message, level, *, loc=None, ip=None): @@ -45,7 +46,7 @@ def body(self): @_cext.register_operation(_Dialect, replace=True) -class RegionOp(RegionOp): # noqa: F405 +class RegionOp(_tpu_gen.RegionOp): # noqa: F405 """An extension to the automatically generated RegionOp bindings.""" def __init__(self, results, *, loc=None, ip=None): diff --git a/jaxlib/mosaic/serde.cc b/jaxlib/mosaic/serde.cc index 88bca44bf181..307164d91dd9 100644 --- a/jaxlib/mosaic/serde.cc +++ b/jaxlib/mosaic/serde.cc @@ -18,15 +18,15 @@ limitations under the License. #include #include -#include "llvm/include/llvm/ADT/StringMap.h" -#include "llvm/include/llvm/ADT/StringRef.h" -#include "mlir/include/mlir/IR/BuiltinAttributes.h" -#include "mlir/include/mlir/IR/BuiltinOps.h" -#include "mlir/include/mlir/IR/Operation.h" -#include "mlir/include/mlir/IR/OperationSupport.h" -#include "mlir/include/mlir/IR/Visitors.h" -#include "mlir/include/mlir/Interfaces/DataLayoutInterfaces.h" -#include "mlir/include/mlir/Support/LLVM.h" +#include "llvm/ADT/StringMap.h" +#include "llvm/ADT/StringRef.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/OperationSupport.h" +#include "mlir/IR/Visitors.h" +#include "mlir/Interfaces/DataLayoutInterfaces.h" +#include "mlir/Support/LLVM.h" namespace jaxlib::mosaic { diff --git a/jaxlib/mosaic/serde.h b/jaxlib/mosaic/serde.h index 762d9e5dad73..fdcaf58d4a8e 100644 --- a/jaxlib/mosaic/serde.h +++ b/jaxlib/mosaic/serde.h @@ -18,11 +18,11 @@ limitations under the License. #include -#include "llvm/include/llvm/ADT/StringMap.h" -#include "llvm/include/llvm/ADT/StringRef.h" -#include "mlir/include/mlir/IR/BuiltinAttributes.h" -#include "mlir/include/mlir/IR/BuiltinOps.h" -#include "mlir/include/mlir/Support/LLVM.h" +#include "llvm/ADT/StringMap.h" +#include "llvm/ADT/StringRef.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/Support/LLVM.h" namespace jaxlib::mosaic { diff --git a/jaxlib/nb_class_ptr.h b/jaxlib/nb_class_ptr.h new file mode 100644 index 000000000000..f1214c19369e --- /dev/null +++ b/jaxlib/nb_class_ptr.h @@ -0,0 +1,59 @@ +/* Copyright 2024 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_NB_CLASS_PTR_H_ +#define JAXLIB_NB_CLASS_PTR_H_ + +#include "nanobind/nanobind.h" + +namespace xla { + +// A reference-counting smart pointer to a nanobind-wrapped class on the Python +// heap. Type T must be a class known to nanobind via a nanobind::class_ +// declaration. 
nb_class_ptr is useful for managing C++ classes that may be +// allocated inline in Python objects on the Python heap. +template +class nb_class_ptr : public nanobind::object { + public: + inline nb_class_ptr() : nanobind::object() {} + inline nb_class_ptr(nanobind::handle h, ::nanobind::detail::borrow_t) + : nanobind::object(h, ::nanobind::detail::borrow_t{}) {} + inline nb_class_ptr(nanobind::handle h, ::nanobind::detail::steal_t) + : nanobind::object(h, ::nanobind::detail::steal_t{}) {} + inline static bool check_(nanobind::handle h) { + nanobind::handle type = nanobind::type(); + return nanobind::isinstance(h, type); + }; + + T* operator->() const { return nanobind::inst_ptr(ptr()); } + T& operator*() const { return *nanobind::inst_ptr(ptr()); } + T* get() const { return ptr() ? nanobind::inst_ptr(ptr()) : nullptr; } +}; + +// This function is analogous to std::make_unique(...), but instead it +// allocates the object on the Python heap +template +nb_class_ptr make_nb_class(Args&&... args) { + nanobind::handle type = nanobind::type(); + nanobind::object instance = nanobind::inst_alloc(type); + T* ptr = nanobind::inst_ptr(instance); + new (ptr) T(std::forward(args)...); + nanobind::inst_mark_ready(instance); + return nb_class_ptr(instance.release(), ::nanobind::detail::steal_t{}); +} + +} // namespace xla + +#endif // JAXLIB_NB_CLASS_PTR_H_ diff --git a/jaxlib/partition_spec.cc b/jaxlib/partition_spec.cc new file mode 100644 index 000000000000..2535c38b977b --- /dev/null +++ b/jaxlib/partition_spec.cc @@ -0,0 +1,245 @@ +/* Copyright 2025 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "jaxlib/partition_spec.h" + +#include +#include +#include + +#include "absl/base/casts.h" +#include "absl/hash/hash.h" +#include "absl/strings/str_format.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/optional.h" // IWYU pragma: keep +#include "nanobind/stl/string.h" // IWYU pragma: keep + +namespace nb = nanobind; + +namespace jax { + +/*static*/ PyObject* nb_frozenset::nb_frozenset_from_obj(PyObject* o) { + PyObject* result = PyFrozenSet_New(o); + if (!result) { + throw nb::python_error(); + } + return result; +} + +template +bool nb_frozenset::contains(T&& key) const { + object o = nanobind::cast((nb::detail::forward_t)key); + int rv = PySet_Contains(m_ptr, o.ptr()); + if (rv == -1) { + throw nb::python_error(); + } + return rv == 1; +} + +namespace { + +bool IsTrue(nb::handle x) { + int ret = PyObject_IsTrue(x.ptr()); + if (ret == -1) { + throw nb::python_error(); + } + return static_cast(ret); +} + +nb::object CanonicalizePartition(nb::object unconstrained_singleton, + nb::object partition) { + if (!IsTrue(partition)) { + return nb::none(); + } + if (partition.is(unconstrained_singleton)) { + return unconstrained_singleton; + } + bool is_tuple = nb::isinstance(partition); + if (is_tuple || nb::isinstance(partition)) { + if (nb::len(partition) == 1) { + return partition[0]; + } + if (!is_tuple) { + return nb::tuple(partition); + } + return partition; + } + return partition; +} + +void CheckPartitionSpec(nb::tuple partitions, nb_frozenset unreduced, + nb_frozenset reduced) { + if (unreduced.contains(nb::none())) { + throw nb::value_error( + "unreduced cannot contain None. All elements in unreduced should " + "refer to the mesh axes."); + } + if (reduced.contains(nb::none())) { + throw nb::value_error( + "reduced cannot contain None. All elements in reduced should " + "refer to the mesh axes."); + } + auto check_overlap = [&](nb::handle partition) { + if (unreduced.contains(partition)) { + throw nb::value_error( + absl::StrFormat( + "partitions cannot overlap with unreduced axes passed to " + "PartitionSpec. Got partitions: %s and unreduced axes: %s", + nb::cast(nb::str(partitions)), + nb::cast(nb::str(unreduced))) + .c_str()); + } + if (reduced.contains(partition)) { + throw nb::value_error( + absl::StrFormat( + "partitions cannot overlap with reduced axes passed to " + "PartitionSpec. Got partitions: %s and reduced axes: %s", + nb::cast(nb::str(partitions)), + nb::cast(nb::str(reduced))) + .c_str()); + } + }; + for (nb::handle partition : partitions) { + if (nb::isinstance(partition)) { + for (nb::handle p : partition) { + check_overlap(p); + } + } else { + check_overlap(partition); + } + } + // TODO(yashkatariya, phawkins): Update this to `!(unreduced & + // reduced).empty()` after nanobind's version > 2.7.0 + if (nb::len((unreduced & reduced)) != 0) { + throw nb::value_error( + absl::StrFormat("`unreduced` and `reduced` argument to PartitionSpec " + "cannot overlap. 
" + "Got unreduced: %s and reduced: %s", + nb::cast(nb::str(unreduced)), + nb::cast(nb::str(reduced))) + .c_str()); + } +} + +} // namespace + +PartitionSpec::PartitionSpec(nb::tuple partitions, nb_frozenset unreduced, + nb_frozenset reduced) + : partitions_(std::move(partitions)), + unreduced_(std::move(unreduced)), + reduced_(std::move(reduced)) {} + +Py_hash_t PartitionSpec::Hash() const { + size_t h = absl::HashOf(nb::hash(partitions_), nb::hash(unreduced_), + nb::hash(reduced_)); + Py_hash_t s = absl::bit_cast(h); // Python hashes are signed. + return s == -1 ? -2 : s; // -1 must not be used as a Python hash value. +} + +bool PartitionSpec::operator==(const PartitionSpec& other) const { + return partitions().equal(other.partitions()) && + unreduced().equal(other.unreduced()) && + reduced().equal(other.reduced()); +} + +bool PartitionSpec::Eq(const nb::object& other) const { + if (!other.ptr() || other.is_none()) { + return false; + } + PartitionSpec* other_spec; + if (nb::try_cast(other, other_spec)) { + return *this == *other_spec; + } + nb::tuple other_tuple; + if (nb::try_cast(other, other_tuple)) { + if (unreduced().size() > 0 || reduced().size() > 0 || + partitions().size() != other_tuple.size()) { + return false; + } + for (size_t i = 0; i < partitions().size(); ++i) { + if (!partitions()[i].equal(CanonicalizePartition( + *unconstrained_singleton_, other_tuple[i]))) { + return false; + } + } + return true; + } + return false; +} + +nb::object* PartitionSpec::unconstrained_singleton_ = nullptr; + +void PartitionSpec::Register(nb::module_& m) { + nb::class_(m, "UnconstrainedSingleton") + .def("__repr__", [](nb::handle self) { return nb::str("UNCONSTRAINED"); }) + .def("__reduce__", + [](nb::handle self) { return nb::str("UNCONSTRAINED_PARTITION"); }); + + unconstrained_singleton_ = new nb::object(nb::cast(UnconstrainedSingleton())); + m.attr("UNCONSTRAINED_PARTITION") = *unconstrained_singleton_; + + m.def("canonicalize_partition", [](nb::object partition) { + return CanonicalizePartition(*unconstrained_singleton_, partition); + }); + + nb::class_(m, "PartitionSpec") + .def( + "__init__", + [](PartitionSpec* self, nb::args partition_args, + nb::object unreduced_arg, nb::object reduced_arg) { + nb::tuple partitions = + nb::steal(PyTuple_New(partition_args.size())); + for (size_t i = 0; i < partition_args.size(); ++i) { + PyTuple_SET_ITEM(partitions.ptr(), i, + CanonicalizePartition( + *PartitionSpec::unconstrained_singleton_, + partition_args[i]) + .release() + .ptr()); + } + nb_frozenset unreduced; + nb_frozenset reduced; + if (!PyAnySet_Check(unreduced_arg.ptr())) { + throw nb::type_error( + absl::StrFormat( + "unreduced argument of PartitionSpec should " + "of type `frozenset` or `set`. Got type %s", + nb::cast(nb::repr(unreduced_arg.type()))) + .c_str()); + } + if (!PyAnySet_Check(reduced_arg.ptr())) { + throw nb::type_error( + absl::StrFormat( + "reduced argument of PartitionSpec should " + "of type `frozenset` or `set`. 
Got type %s", + nb::cast(nb::repr(reduced_arg.type()))) + .c_str()); + } + unreduced = nb_frozenset(unreduced_arg); + reduced = nb_frozenset(reduced_arg); + CheckPartitionSpec(partitions, unreduced, reduced); + new (self) PartitionSpec(std::move(partitions), + std::move(unreduced), std::move(reduced)); + }, + nb::arg("partitions"), nb::arg("unreduced") = nb_frozenset(), + nb::arg("reduced") = nb_frozenset()) + .def_prop_ro("_partitions", &PartitionSpec::partitions) + .def_prop_ro("unreduced", &PartitionSpec::unreduced) + .def_prop_ro("reduced", &PartitionSpec::reduced) + .def("__eq__", &PartitionSpec::Eq, nb::arg().none()) + .def("__hash__", &PartitionSpec::Hash); +} + +} // namespace jax diff --git a/jaxlib/partition_spec.h b/jaxlib/partition_spec.h new file mode 100644 index 000000000000..fc207cfe7a28 --- /dev/null +++ b/jaxlib/partition_spec.h @@ -0,0 +1,67 @@ +/* Copyright 2025 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAX_JAXLIB_PARTITION_SPEC_H_ +#define JAX_JAXLIB_PARTITION_SPEC_H_ + +#include + +#include "nanobind/nanobind.h" + +namespace jax { + +struct UnconstrainedSingleton {}; + +class nb_frozenset : public nanobind::object { + NB_OBJECT(nb_frozenset, object, "frozenset", PyFrozenSet_Check) + nb_frozenset() + : object(PyFrozenSet_New(nullptr), nanobind::detail::steal_t()) {} + explicit nb_frozenset(handle h) + : object(nb_frozenset_from_obj(h.ptr()), nanobind::detail::steal_t{}) {} + size_t size() const { return (size_t)NB_SET_GET_SIZE(m_ptr); } + template + bool contains(T&& key) const; + + private: + static PyObject* nb_frozenset_from_obj(PyObject* o); +}; + +class PartitionSpec { + public: + PartitionSpec(nanobind::tuple partitions, nb_frozenset unreduced, + nb_frozenset reduced); + + nanobind::tuple partitions() const { return partitions_; } + nb_frozenset unreduced() const { return unreduced_; } + nb_frozenset reduced() const { return reduced_; } + + bool operator==(const PartitionSpec& other) const; + + bool Eq(const nanobind::object& other) const; // Python __eq__ + Py_hash_t Hash() const; // Python __hash__ + + static void Register(nanobind::module_& m); + + private: + nanobind::tuple partitions_; + nb_frozenset unreduced_; + nb_frozenset reduced_; + + static nanobind::object* unconstrained_singleton_; +}; + +} // namespace jax + +#endif // JAX_JAXLIB_PARTITION_SPEC_H_ diff --git a/jaxlib/pjit.cc b/jaxlib/pjit.cc new file mode 100644 index 000000000000..13b314d7c0c3 --- /dev/null +++ b/jaxlib/pjit.cc @@ -0,0 +1,1397 @@ +/* Copyright 2022 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/pjit.h" + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include // NOLINT +#include +#include +#include + +#include "absl/base/thread_annotations.h" +#include "absl/cleanup/cleanup.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/inlined_vector.h" +#include "absl/hash/hash.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "absl/synchronization/mutex.h" +#include "absl/synchronization/notification.h" +#include "absl/types/span.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/optional.h" // IWYU pragma: keep +#include "nanobind/stl/shared_ptr.h" // IWYU pragma: keep +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "nanobind/stl/string_view.h" // IWYU pragma: keep +#include "nanobind/stl/vector.h" // IWYU pragma: keep +#include "jaxlib/config.h" +#include "jaxlib/guard_lib.h" +#include "jaxlib/jax_jit.h" +#include "jaxlib/nb_class_ptr.h" +#include "jaxlib/py_array.h" +#include "jaxlib/py_client.h" +#include "jaxlib/py_executable.h" +#include "jaxlib/py_values.h" +#include "jaxlib/python_ref_manager.h" +#include "jaxlib/pytree.h" +#include "jaxlib/sharding.h" +#include "jaxlib/traceback.h" +#include "xla/layout.h" +#include "xla/pjrt/exceptions.h" +#include "xla/pjrt/lru_cache.h" +#include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/device.h" +#include "xla/python/ifrt/device_list.h" +#include "xla/python/ifrt/executable.h" +#include "xla/python/ifrt/memory.h" +#include "xla/python/ifrt/sharding.h" +#include "xla/python/nb_helpers.h" +#include "xla/python/nb_numpy.h" +#include "xla/tsl/concurrency/ref_count.h" +#include "xla/tsl/platform/env.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/util.h" +#include "tsl/profiler/lib/traceme.h" + +namespace jax { +namespace { + +namespace nb = nanobind; + +struct PjitCacheEntry { + explicit PjitCacheEntry(xla::PyTreeRegistry* registry) + : out_pytree_def(registry) {} + std::shared_ptr executable; + std::vector in_shardings; + std::vector out_avals; + std::vector out_dtypes; + std::vector> out_shapes; + std::vector out_weak_types; + std::vector out_shardings; + std::vector out_committed; + xla::PyTreeDef out_pytree_def; + // Bitvector of kept arguments from Jaxpr DCE pass. Used to drop some `args` + // in PjitFunction::Call before calling into compiled computation. + std::vector kept_var_bitvec; + std::vector in_device_local_layouts; + + // Ensures a single thread performs the compilation for a given executable. + // + // The first thread (holding the GIL) will create the CacheEntry associated to + // a signature and if the object has been inserted already, other threads + // will wait for the notification. + absl::Notification compilation_complete; + + std::thread::id thread_id = std::this_thread::get_id(); + + bool fall_back_to_python = false; +}; + +// A PjitFunctionCache represents a cache of compiled functions that can be +// shared between one or more PjitFunction objects. It serves two goals: +// - reduce the number of lru caches (hash map) across multiple JITs. 
+// - make the cache global to increase cache hits (e.g. calling jit(f)(3) twice)
+// keeping entries alive as long as the underlying function f is alive.
+// Assume the cache is protected by the GIL.
+class PjitFunctionCache {
+ public:
+  static constexpr int kDefaultCapacity = 4096;
+  explicit PjitFunctionCache(int capacity);
+
+  // Cache entries are shared_ptr<>s because it's possible the cache entry
+  // might be evicted before we finish tracing/compiling.
+  typedef xla::LRUCache> Cache;
+
+  // We include as part of the cache key `global_cache_key` (and any other
+  // fields that aren't subsumed by the CallSignature we compute for each call).
+  static std::shared_ptr Lookup(
+      xla::nb_class_ptr self, nb::handle function,
+      nb::object global_cache_key);
+  std::shared_ptr DefaultCache();
+
+  // These methods require the GIL or the object's lock in no-GIL mode.
+  int Size() const { return lru_list_.Size(); }
+  int Capacity() const { return lru_list_.Capacity(); }
+  void Clear() {
+    lru_list_.Clear();
+    functions_.clear();
+  }
+
+ private:
+  struct Key {
+    nb::handle function;  // Does not hold a reference.
+
+    // Other fields that are part of the arguments to `jit`, but are not
+    // otherwise part of CallSignature.
+    nb::object global_cache_key;
+
+    size_t cached_hash;
+
+    bool operator==(const Key& other) const {
+      bool global_cache_eq;
+      try {
+        global_cache_eq = global_cache_key.equal(other.global_cache_key);
+      } catch (const nanobind::python_error& e) {
+        throw std::invalid_argument(
+            absl::StrCat("Equality of global cache key led to an exception. "
+                         "The error was:\n",
+                         e.what(), "\n"));
+      }
+      return function.ptr() == other.function.ptr() && global_cache_eq;
+    }
+
+    struct Hash {
+      size_t operator()(const Key& key) const { return key.cached_hash; }
+    };
+  };
+
+  template <typename H>
+  friend H AbslHashValue(H h, const Key& key) {
+    h = H::combine(std::move(h), key.function.ptr());
+    Py_hash_t hash;
+    try {
+      hash = nb::hash(key.global_cache_key);
+    } catch (const nanobind::python_error& e) {
+      if (!e.matches(PyExc_TypeError)) throw;
+      throw std::invalid_argument(absl::StrCat(
+          "Hashing global cache key led to an exception. The error was:\n",
+          e.what(), "\n"));
+    }
+    h = H::combine(std::move(h), hash);
+    return h;
+  }
+
+  struct Value {
+    explicit Value(std::shared_ptr cache) : cache(std::move(cache)) {}
+    std::shared_ptr cache;
+
+    // A weak reference to the key function. We use the weak reference to
+    // register a callback that is triggered when the key function is destroyed.
+    // We use a weak pointer because we want to allow caching across multiple
+    // calls to `pjit(f)` if `f` remains alive, but we do not want the cache
+    // to keep `f` alive if all other references are dropped.
+    std::optional weakref;
+  };
+
+  // lru_list_ and functions_ are protected by the GIL in GIL mode, and by the
+  // self object lock in freethreading mode.
+  Cache::LRUList lru_list_;
+  // We use std::unordered_map because ABSL containers are not exception safe:
+  std::unordered_map, Key::Hash> functions_;
+  // mu_ prevents concurrent insertions into functions_ if the gil or critical
+  // section lock is released during insertion.
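+  // Lock ordering: Lookup() releases the GIL (or the nanobind critical
+  // section) before acquiring mu_, so mu_ is always taken first and the
+  // GIL/critical section second; see PjitFunctionCache::Lookup() below.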
+ absl::Mutex mu_; +}; + +PjitFunctionCache::PjitFunctionCache(int capacity) : lru_list_(capacity) {} + +std::shared_ptr PjitFunctionCache::DefaultCache() { + return std::make_shared(&lru_list_); +} + +/*static*/ std::shared_ptr PjitFunctionCache::Lookup( + xla::nb_class_ptr self, nb::handle function, + nb::object global_cache_key) ABSL_NO_THREAD_SAFETY_ANALYSIS { + // In no-GIL mode, a critical section on self plays the same role that + // the GIL plays in GIL mode. + nb::ft_object_guard lock(self); + { + // Because the gil (or the critical section lock) can be released during + // cache insertion, this forces the lock order to be mu_ then gil so we + // must release the gil first. + nb::gil_scoped_release release; + // Acquire a mutex to avoid problems where the gil is released during + // cache insertion and then a second thread invalidates the cache order. + self->mu_.Lock(); + } + absl::Cleanup unlock = [&self]() ABSL_UNLOCK_FUNCTION(self->mu_) { + self->mu_.Unlock(); + }; + Key key; + key.function = function; + key.global_cache_key = global_cache_key; + key.cached_hash = absl::HashOf(key); + auto insert = self->functions_.emplace(key, nullptr); + if (!insert.second) { + return insert.first->second->cache; + } + std::shared_ptr cache = std::make_shared(&self->lru_list_); + auto callback = + nb::cpp_function([self, key{std::move(key)}](nb::handle weakref) { + nb::ft_object_guard lock(self); + auto it = self->functions_.find(key); + if (it == self->functions_.end()) { + return; + } + // Remove the value from the map before destroying it. Destroying + // the value may release `lock` since it may call arbitrary Python + // code. + std::unique_ptr value = std::move(it->second); + self->functions_.erase(it); + value.reset(); + }); + PyObject* weakref = PyWeakref_NewRef(function.ptr(), callback.ptr()); + if (weakref) { + std::unique_ptr& entry = insert.first->second; + entry = std::make_unique(cache); + entry->weakref = nb::steal(weakref); + } else { + PyErr_Clear(); + // `function` is not weak-referenceable. Don't bother adding it to the + // shared cache in that case; the `jit` object will hold the only shared + // reference to the cache entry. + self->functions_.erase(insert.first); + } + return cache; +} + +class PjitFunction { + public: + PjitFunction(std::string function_name, std::optional fun, + nb::callable cache_miss, std::vector static_argnums, + std::vector static_argnames, + nb::object global_cache_key, + xla::nb_class_ptr pytree_registry, + nb::callable shard_arg_fallback, + xla::nb_class_ptr cache); + ~PjitFunction(); + + PjitFunction(const PjitFunction&) = delete; + PjitFunction& operator=(const PjitFunction&) = delete; + PjitFunction(PjitFunction&&) = default; + PjitFunction& operator=(PjitFunction&&) = default; + + // nb::object typed subclass for PjitFunction objects. + class pyobject : public nb::object { + public: + NB_OBJECT(pyobject, nb::object, "PjitFunction", + PjitFunction::IsPjitFunction); + pyobject() = default; + PjitFunction* func() const { + return PjitFunction::AsPjitFunctionUnchecked(*this); + } + }; + // Alias as ::object; outside the scope above we won't confuse nanobind's + // macros. + using object = pyobject; + + // Returns true if `h` is a PjitFunction. + static bool IsPjitFunction(nb::handle handle); + // Converts `handle` to a PjitFunction*. Does not do any checking. 
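+  // Callers that cannot guarantee the handle's type should go through the
+  // checked AsPjitFunction() helper defined later in this file, which throws
+  // for non-PjitFunction handles.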
+ static PjitFunction* AsPjitFunctionUnchecked(nb::handle handle); + + absl::StatusOr Call(nb::handle callable, PyObject* const* args, + size_t nargs, PyObject* kwnames); + + void InitExecutables(); + + void ClearPythonReferences(); + + const std::string& function_name() const { return function_name_; } + const std::optional& fun() const { return fun_; } + const nb::callable& cache_miss() const { return cache_miss_; } + const xla::nb_class_ptr& pytree_registry() const { + return pytree_registry_; + } + const nb::callable& shard_arg_fallback() const { return shard_arg_fallback_; } + + const std::vector& static_argnums() const { return static_argnums_; } + const std::vector& static_argnames() const { + return static_argnames_; + } + const nb::object& global_cache_key() const { return global_cache_key_; } + const xla::nb_class_ptr& cache() const { return cache_; } + + int cache_capacity() const { + nb::ft_object_guard lock(cache_); + return executables_->Size(); + } + + void ClearCache() { + nb::ft_object_guard lock(cache_); + executables_->Clear(); + } + + std::shared_ptr executables() { + nb::ft_object_guard lock(cache_); + return executables_; + } + + nb::object PythonSignature() { + if (!fun_.has_value()) { + throw nb::value_error( + absl::StrFormat( + "Calling __signature__ on PjitFunction(%s) not supported.", + function_name_) + .c_str()); + } + static const auto* inspect = + new nb::module_(nb::module_::import_("inspect")); + return inspect->attr("signature")(*fun_); + } + + private: + absl::Status ComputeCallSignature( + absl::Span flat_dynamic_args, + CallSignature& call_signature); + + void PopulateCacheEntry(PjitCacheEntry& cache_entry, + const nb::tuple& out_and_fastpath_data); + + std::string function_name_; + std::optional fun_; + nb::callable cache_miss_; + std::vector static_argnums_; + std::vector static_argnames_; + nb::object global_cache_key_; + + xla::nb_class_ptr pytree_registry_; + nb::callable shard_arg_fallback_; + xla::nb_class_ptr cache_; + + // In no-GIL mode executables_ is protected by the object lock on cache_, + // because it shared an LRU list with cache_. + std::shared_ptr executables_; +}; + +PjitFunction::PjitFunction( + std::string function_name, std::optional fun, + nb::callable cache_miss, std::vector static_argnums, + std::vector static_argnames, nb::object global_cache_key, + xla::nb_class_ptr pytree_registry, + nb::callable shard_arg_fallback, xla::nb_class_ptr cache) + : function_name_(std::move(function_name)), + fun_(std::move(fun)), + cache_miss_(std::move(cache_miss)), + static_argnums_(std::move(static_argnums)), + global_cache_key_(std::move(global_cache_key)), + pytree_registry_(std::move(pytree_registry)), + shard_arg_fallback_(std::move(shard_arg_fallback)), + cache_(std::move(cache)) { + std::sort(static_argnums_.begin(), static_argnums_.end()); + static_argnames_.reserve(static_argnames.size()); + for (nb::str& name : static_argnames) { + PyObject* s = name.inc_ref().ptr(); + PyUnicode_InternInPlace(&s); + static_argnames_.push_back(nb::steal(s)); + } +} + +void PjitFunction::InitExecutables() { + // Construction of the object hasn't completed yet, so we don't need to hold + // the cache lock to mutate executables_. 
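+  // After construction completes, all other accesses to executables_ go
+  // through the cache_ object lock (see cache_capacity(), ClearCache(),
+  // executables() and the destructor).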
+ if (!fun_.has_value()) { + executables_ = cache_->DefaultCache(); + } else { + executables_ = cache_->Lookup(cache_, fun_.value(), global_cache_key_); + } +} + +PjitFunction::~PjitFunction() { + nb::ft_object_guard lock(cache_); + executables_ = nullptr; +} + +void CallShardArgFallback(nb::handle arg, nb::handle sharding, + nb::handle layout, const nb::callable& fallback, + std::vector& num_args_arrays, + std::vector& keep_alive_objects) { + tsl::profiler::TraceMe traceme("cpp_pjit_shard_arg_fallback"); + auto py_array_or_bufs = fallback(arg, sharding, layout); + auto py_array = nb::cast(py_array_or_bufs); + num_args_arrays.push_back(tsl::FormRef(py_array.ifrt_array())); + keep_alive_objects.push_back(std::move(py_array_or_bufs)); +} + +// Prepares the input PjRtBuffers from the python arguments. This is equivalent +// to shard_args() in pxla.py but for only a few supported cases. +absl::StatusOr> PrepareIfrtInputs( + const xla::PyLoadedExecutable& executable, + absl::Span flat_dynamic_args, + absl::Span flat_dynamic_arg_signatures, + bool enable_x64, const std::vector& kept_args, + const std::vector& in_shardings, + const std::vector& in_device_local_layouts, + const nb::callable& shard_arg_fallback, + std::vector& keep_alive_objects) { + const auto& addressable_devices = + executable.ifrt_loaded_executable()->addressable_devices(); + const auto& num_global_devices = + executable.ifrt_loaded_executable()->num_devices(); + int num_args = flat_dynamic_args.size(); + + std::vector num_args_arrays; + num_args_arrays.reserve(num_args); + + struct CopyGroup { + std::vector indices; + std::vector arrays; + }; + absl::flat_hash_map, + CopyGroup> + copy_groups; + + xla::DevicePutOptions options; + options.squash_64bit_types = !enable_x64; + options.allow_zero_copy = true; + xla::ifrt::Device* data_device = nullptr; + if (executable.ifrt_loaded_executable()->num_devices() == 1) { + data_device = executable.ifrt_loaded_executable()->addressable_devices()[0]; + } + int dce_i = 0; + for (int i = 0; i < num_args; ++i) { + if (!kept_args[i]) { + continue; + } + int dce_index = dce_i; + ++dce_i; + + const nb::object& arg = flat_dynamic_args[i]; + const nb::object& in_device_local_layout = + in_device_local_layouts[dce_index]; + + auto transfer_guard_formatter = [] { return std::string(""); }; + + if (arg.type().ptr() != xla::PyArray::type().ptr()) { + if (data_device != nullptr && in_device_local_layout.is_none()) { + TF_RETURN_IF_ERROR( + jax::ApplyTransferGuardToHostToDevice(transfer_guard_formatter)); + TF_ASSIGN_OR_RETURN( + auto device_put_result, + DevicePutWithDevice(arg, + executable.ifrt_loaded_executable()->client(), + data_device, xla::ifrt::MemoryKind(), options)); + num_args_arrays.push_back(std::move(device_put_result.ifrt_array)); + continue; + } else { + CallShardArgFallback(arg, in_shardings[dce_index], + in_device_local_layout, shard_arg_fallback, + num_args_arrays, keep_alive_objects); + continue; + } + } + + xla::PyArray py_array = nb::borrow(arg); + const auto& sharding = py_array.sharding(); + int sharding_num_devices = jax::Sharding::SafeNumDevices(sharding); + + // Currently only committed PyArray inputs or uncommitted PyArray on a + // single device inputs are allowed. This is checked previously in the entry + // point of PjitFunction::Call(). 
+ DCHECK(py_array.committed() || + (!py_array.committed() && sharding_num_devices == 1)); + + if (!in_device_local_layout.is_none()) { + TF_ASSIGN_OR_RETURN(auto arr_layout, py_array.ifrt_array()->layout()); + xla::Layout in_xc_layout = nb::cast( + in_device_local_layout.attr("_to_xla_layout")(py_array.dtype())); + if (in_xc_layout != arr_layout->xla_layout()) { + CallShardArgFallback(arg, in_shardings[dce_index], + in_device_local_layout, shard_arg_fallback, + num_args_arrays, keep_alive_objects); + continue; + } + } + + if (sharding.type().ptr() == jax::PmapSharding::type().ptr()) { + CallShardArgFallback(arg, in_shardings[dce_index], in_device_local_layout, + shard_arg_fallback, num_args_arrays, + keep_alive_objects); + continue; + } + + if (sharding_num_devices != num_global_devices) { + CallShardArgFallback(arg, in_shardings[dce_index], in_device_local_layout, + shard_arg_fallback, num_args_arrays, + keep_alive_objects); + continue; + } + + xla::ifrt::Array* ifrt_array = py_array.ifrt_array(); + // PyArray inputs should have already been checked in + // `xla::PyArgSignatureOfValue()` called by + // `PjitFunction::ComputeCallSignature()`. + DCHECK(ifrt_array != nullptr) << "PyArray has been unexpectedly deleted."; + + const auto& ifrt_sharding = ifrt_array->sharding(); + if (sharding_num_devices == 1 && + ifrt_sharding.devices()->devices().front() != addressable_devices[0]) { + auto& copy_group = + copy_groups[std::make_pair(ifrt_sharding.devices()->devices().front(), + ifrt_sharding.memory_kind())]; + copy_group.indices.push_back(num_args_arrays.size()); + copy_group.arrays.push_back(tsl::FormRef(ifrt_array)); + num_args_arrays.push_back({}); + } else { + num_args_arrays.push_back(tsl::FormRef(ifrt_array)); + } + + keep_alive_objects.push_back(arg); + } + + if (!copy_groups.empty()) { + xla::ifrt::Client* const ifrt_client = + executable.ifrt_loaded_executable()->client(); + xla::ifrt::DeviceListRef ifrt_devices = + ifrt_client->MakeDeviceList({addressable_devices[0]}); + for (auto& [key, group] : copy_groups) { + TF_ASSIGN_OR_RETURN( + auto copied_ifrt_arrays, + ifrt_client->CopyArrays(absl::MakeSpan(group.arrays), ifrt_devices, + /*memory_kind=*/std::nullopt, + xla::ifrt::ArrayCopySemantics::kReuseInput)); + for (int i = 0; i < copied_ifrt_arrays.size(); ++i) { + num_args_arrays[group.indices[i]] = std::move(copied_ifrt_arrays[i]); + } + } + } + + return num_args_arrays; +} + +absl::StatusOr PjitFunction::Call(nb::handle callable, + PyObject* const* args, + size_t nargs, PyObject* kwnames) { + tsl::profiler::TraceMe traceme( + [&] { return absl::StrCat("PjitFunction(", function_name_, ")"); }); + + // Make sure we trigger a garbage collection on JIT function calls. Otherwise + // code like + // f = jit(...) + // while True: + // f(x) + // may never free temporary buffers for copies of arguments. + xla::GlobalPyRefManager()->MaybeCollectGarbage(); + + if (GetDisableJit()) { + if (!fun_.has_value()) { + throw nb::value_error( + absl::StrFormat("Disable jit is not supported in the AOT path since " + "the function is not available for (%s)", + function_name_) + .c_str()); + } + return nb::steal( + PyObject_Vectorcall(fun_.value().ptr(), args, nargs, kwnames)); + } + + // Calls the cache_miss_ function. This just calls the Python function; it may + // return nullptr value if a Python exception is thrown. 
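+  // On success the tuple's first element is the computation's result and the
+  // second element is the fastpath data consumed by PopulateCacheEntry()
+  // below.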
+ auto cache_miss = [&]() -> nb::tuple { + return nb::steal( + PyObject_Vectorcall(cache_miss_.ptr(), args, nargs, kwnames)); + }; + + // Call the cache_miss() function, extracting the output data and ignoring + // the fastpath data. If the cache miss returns a Python error, returns + // nullptr and leaves the Python error set. + auto fallback_to_cache_miss = [&]() { + nb::tuple cache_miss_output = cache_miss(); + if (!cache_miss_output.ptr()) { + return nb::object(); + } + return nb::object(cache_miss_output[0]); + }; + + size_t num_positional_args = PyVectorcall_NARGS(nargs); + size_t num_keyword_args = kwnames ? PyTuple_GET_SIZE(kwnames) : 0; + absl::Span positional_args(args, num_positional_args); + absl::Span keyword_args(args + num_positional_args, + num_keyword_args); + + CallSignature call_signature; + std::vector keep_alive_objects; + absl::InlinedVector flat_dynamic_args; + auto status = ParseArguments( + positional_args, keyword_args, kwnames, static_argnums_, static_argnames_, + pytree_registry_.get(), call_signature.arg_signature, flat_dynamic_args); + if (!status.ok()) { + VLOG(2) << "ParseArguments failed: " << status; + return fallback_to_cache_miss(); + } + + // Perform a few checks for the arguments. Currently we are only allowing + // committed PyArray inputs. For other cases, e.g. Tracers or ShapedArray, it + // will fallback to python. For jit, numpy arrays and scalars are also + // allowed, which we will check later. + for (const auto& arg : flat_dynamic_args) { + if (arg.type().ptr() != xla::PyArray::type().ptr()) { + continue; + } + + xla::PyArray py_array = nb::borrow(arg); + + // Only allow committed PyArray in cpp pjit for now as the logic on handling + // sharding for uncommitted PyArray is complicated and still under + // development. + // + // TODO(chky): Consider support uncommitted PyArray in cpp when the python + // side stabilizes. + if (!py_array.committed() && + jax::Sharding::SafeNumDevices(py_array.sharding()) > 1) { + VLOG(2) << "PyArray argument is not committed and number of global " + "devices is more than 1; fallback to python."; + return fallback_to_cache_miss(); + } + } + + status = ComputeCallSignature(flat_dynamic_args, call_signature); + if (!status.ok()) { + VLOG(2) << "ComputeCallSignature failed: " << status; + return fallback_to_cache_miss(); + } + + VLOG(2) << "CallSignature:\n" << call_signature.DebugString(); + bool inserted = false; + std::shared_ptr cache_entry; + { + nb::ft_object_guard lock(cache_); + cache_entry = executables_->GetOrCreateIfAbsent( + call_signature, [this, &inserted](const CallSignature& unused) { + inserted = true; + return std::make_shared(pytree_registry_.get()); + }); + } + + if (!cache_entry->compilation_complete.HasBeenNotified()) { + // In case of several threads attempting to compile the executable, only + // the one that inserted the item will perform the compilation. + if (inserted) { + nb::object out_and_fastpath_data; + nb::tuple out_tuple; + VLOG(2) << "Cache miss for " << call_signature.DebugString(); + bool remove_cache = false; + try { + // Calls Python and may release the GIL. May also throw if + // compilation/tracing fails. 
+ out_and_fastpath_data = cache_miss(); + if (!out_and_fastpath_data.ptr()) { + throw nb::python_error(); + } + out_tuple = nb::cast(out_and_fastpath_data); + + PopulateCacheEntry(*cache_entry, out_tuple); + + if (out_tuple.size() > 2 && out_tuple[2].is_valid()) { + remove_cache = nb::cast(out_tuple[2]); + } + } catch (const std::exception& e) { + VLOG(2) << "cache miss fail: " << e.what(); + cache_entry->fall_back_to_python = true; + cache_entry->compilation_complete.Notify(); + throw; + } + cache_entry->compilation_complete.Notify(); + + if (remove_cache) { + nb::ft_object_guard lock(cache_); + executables_->Remove(call_signature); + } + + // We have already computed the result in the miss path so we can return + // it. We are even *required* to do so if there are donated arguments, + // because any donated buffers will now be invalid. + return nb::object(out_tuple[0]); + } else { + if (cache_entry->thread_id == std::this_thread::get_id()) { + auto error_string = absl::StrCat("Recursively calling jit: ", + call_signature.DebugString()); + PyErr_SetString(PyExc_RecursionError, error_string.c_str()); + throw nb::python_error(); + } + // Release the GIL while we wait, making sure the compile thread can + // lock it. + nb::gil_scoped_release release; + cache_entry->compilation_complete.WaitForNotification(); + } + } + + if (cache_entry->fall_back_to_python) { + VLOG(2) << "cpp pjit fallback to python."; + return fallback_to_cache_miss(); + } + + // A vector of [num_inputs]. + auto num_args_arrays = PrepareIfrtInputs( + *cache_entry->executable, flat_dynamic_args, + call_signature.dynamic_arg_signatures, call_signature.jax_enable_x64, + cache_entry->kept_var_bitvec, cache_entry->in_shardings, + cache_entry->in_device_local_layouts, shard_arg_fallback_, + keep_alive_objects); + + if (!num_args_arrays.ok()) { + VLOG(2) << "Failed to prepare IFRT inputs: " << num_args_arrays.status(); + return fallback_to_cache_miss(); + } + + xla::ifrt::ExecuteOptions execute_options = + cache_entry->executable->options(); + execute_options.launch_id = cache_entry->executable->GetNextLaunchId(); + execute_options.execution_stream_id = xla::GetExecutionStreamId(); + if (execute_options.execution_stream_id == 0) { + execute_options.execution_stream_id = + tsl::Env::Default()->GetCurrentThreadId(); + } + + // A vector of [num_outputs]. + std::vector output_arrays; + { + nb::gil_scoped_release gil_release; + TF_ASSIGN_OR_RETURN(auto result, + cache_entry->executable->ifrt_executable()->Execute( + absl::MakeSpan(*num_args_arrays), execute_options, + /*devices=*/std::nullopt)); + output_arrays = std::move(result.outputs); + } + + auto traceback = xla::Traceback::Get(); + + // Convert the ifrt::Array objects to PyArray. + int num_outputs = output_arrays.size(); + absl::InlinedVector outputs; + outputs.reserve(num_outputs); + for (int i = 0; i < num_outputs; ++i) { + // Creating the PyArray result. In addition to the IFRT arrays, the metadata + // like `aval` and `sharding` are retrieved from the cache for this + // function, which are produced by the python path in `cache_miss`. 
+ xla::PyArray py_array( + cache_entry->out_avals[i], cache_entry->out_weak_types[i], + cache_entry->out_dtypes[i], cache_entry->out_shapes[i], + cache_entry->out_shardings[i], cache_entry->executable->client(), + traceback, std::move(output_arrays[i]), + /*committed=*/cache_entry->out_committed.at(i), /*skip_checks=*/true); + + outputs.push_back(std::move(py_array)); + } + + nb::object out = nb::steal( + cache_entry->out_pytree_def.Unflatten(outputs).release().ptr()); + + // If there is a post-hook function, call it with the inputs and the outputs. + std::optional post_hook = GetPostHook(); + if (post_hook) { + nb::tuple args_tuple = + nb::steal(PyTuple_New(num_positional_args)); + for (size_t i = 0; i < num_positional_args; ++i) { + Py_INCREF(args[i]); + PyTuple_SET_ITEM(args_tuple.ptr(), i, args[i]); + } + nb::dict kwargs; + if (kwnames) { + for (size_t i = 0; i < num_keyword_args; ++i) { + kwargs[nb::handle(PyTuple_GET_ITEM(kwnames, i))] = + nb::borrow(args[num_positional_args + i]); + } + } + (*post_hook)(nb::handle(callable.ptr()), args_tuple, kwargs, + nb::handle(out.ptr())); + } + + return out; +} + +absl::Status PjitFunction::ComputeCallSignature( + absl::Span flat_dynamic_args, CallSignature& signature) { + signature.function_name = function_name_; + + // Get dynamic argument signatures. + JitState& global_state = jax::GlobalJitState(); + JitState& tls = jax::ThreadLocalJitState(); + bool jax_enable_x64 = GetEnableX64(); + + signature.default_device = GetDefaultDevice(); + signature.jax_enable_x64 = jax_enable_x64; + + auto& dynamic_arg_signatures = signature.dynamic_arg_signatures; + dynamic_arg_signatures.reserve(flat_dynamic_args.size()); + auto& dynamic_arg_shardings = signature.dynamic_arg_shardings; + dynamic_arg_shardings.reserve(flat_dynamic_args.size()); + auto& dynamic_arg_layouts = signature.dynamic_arg_layouts; + dynamic_arg_layouts.reserve(flat_dynamic_args.size()); + + for (nb::handle arg : flat_dynamic_args) { + TF_ASSIGN_OR_RETURN(auto arg_signature, + xla::PyArgSignatureOfValue(arg, jax_enable_x64)); + signature.dynamic_arg_signatures.push_back(std::move(arg_signature)); + + // It should be already checked previously in the entry point of + // PjitFunction::Call(). 
+ if (arg.type().ptr() == xla::PyArray::type().ptr()) { + auto py_array = nb::borrow(arg); + signature.dynamic_arg_shardings.push_back(py_array.sharding()); + auto layout = py_array.layout(); + if (absl::IsUnimplemented(layout.status())) { + signature.dynamic_arg_layouts.push_back(nullptr); + } else { + signature.dynamic_arg_layouts.push_back(*std::move(layout)); + } + signature.committed_args.push_back(py_array.committed()); + } else { + signature.dynamic_arg_shardings.push_back(nb::none()); + signature.dynamic_arg_layouts.push_back(nullptr); + signature.committed_args.push_back(false); + } + } + + signature.thread_local_extra_jit_context = tls.extra_jit_context; + signature.global_extra_jit_context = global_state.extra_jit_context; + signature.configs = JitConfigs(); + + return absl::OkStatus(); +} + +void PjitFunction::PopulateCacheEntry(PjitCacheEntry& cache_entry, + const nb::tuple& out_and_fastpath_data) { + DCHECK_GE(out_and_fastpath_data.size(), 2); + + if (out_and_fastpath_data[1].is_none()) { + VLOG(2) << "fastpath_data is none"; + cache_entry.fall_back_to_python = true; + return; + } + + nb::tuple fastpath_data = nb::cast(out_and_fastpath_data[1]); + + cache_entry.executable = nb::cast>( + fastpath_data.attr("xla_executable")); + + nb::sequence in_shardings = fastpath_data.attr("in_shardings"); + cache_entry.in_shardings.reserve(nb::len(in_shardings)); + for (nb::handle sharding : in_shardings) { + cache_entry.in_shardings.push_back(nb::borrow(sharding)); + } + + nb::sequence out_shardings = fastpath_data.attr("out_shardings"); + cache_entry.out_shardings.reserve(nb::len(out_shardings)); + for (nb::handle sharding : out_shardings) { + cache_entry.out_shardings.push_back(nb::borrow(sharding)); + } + + nb::sequence out_committed = fastpath_data.attr("out_committed"); + cache_entry.out_committed.reserve(nb::len(out_committed)); + for (nb::handle c : out_committed) { + cache_entry.out_committed.push_back(nb::cast(c)); + } + + nb::sequence out_avals = fastpath_data.attr("out_avals"); + cache_entry.out_avals.reserve(nb::len(out_avals)); + cache_entry.out_dtypes.reserve(nb::len(out_avals)); + cache_entry.out_shapes.reserve(nb::len(out_avals)); + cache_entry.out_weak_types.reserve(nb::len(out_avals)); + for (nb::handle aval : out_avals) { + cache_entry.out_avals.push_back(nb::borrow(aval)); + cache_entry.out_dtypes.push_back(aval.attr("dtype")); + cache_entry.out_shapes.push_back( + nb::cast>(aval.attr("shape"))); + cache_entry.out_weak_types.push_back( + nb::cast(aval.attr("weak_type"))); + } + + cache_entry.out_pytree_def = nb::cast( + nb::handle(fastpath_data.attr("out_pytree_def").ptr())); + + nb::sequence kept_var_bitvec = fastpath_data.attr("kept_var_bitvec"); + cache_entry.kept_var_bitvec.reserve(nb::len(kept_var_bitvec)); + for (nb::handle k : kept_var_bitvec) { + cache_entry.kept_var_bitvec.push_back(nb::cast(k)); + } + + nb::sequence in_device_local_layouts = + fastpath_data.attr("in_device_local_layouts"); + cache_entry.in_device_local_layouts.reserve(nb::len(in_device_local_layouts)); + for (nb::handle dll : in_device_local_layouts) { + cache_entry.in_device_local_layouts.push_back(nb::borrow(dll)); + } +} + +// Helper function used by the tp_clear GC method. 
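+// tp_clear must drop the strong Python references owned by the C++ object so
+// that the cyclic garbage collector can break reference cycles; the
+// swap-to-a-local-then-destroy pattern below mirrors Py_CLEAR().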
+void PjitFunction::ClearPythonReferences() { + // TODO(mattjj): phawkins@ observed that the xla::PyTreeRegistry + // pytree_registry_ attribute of PjitFunction could in principle also have + // python references to clear + nb::callable cache_miss; + std::optional fun; + nb::callable shard_arg_fallback; + // Swap values for nulls before they are destroyed. See the Python + // Py_CLEAR() documentation for a discussion of this topic. + std::swap(cache_miss_, cache_miss); + std::swap(fun_, fun); + std::swap(shard_arg_fallback_, shard_arg_fallback); +} + +struct PjitFunctionObject { + PyObject_HEAD; +#if PY_VERSION_HEX < 0x030C0000 + PyObject* dict; // Dictionary for __dict__ + PyObject* weakrefs; // Weak references; for use by the Python interpreter. +#endif // PY_VERSION_HEX < 0x030C0000 + vectorcallfunc vectorcall; + PjitFunction fun; + + // Doubly-linked list of PjitFunctionObjects, protected by + // PjitFunctionStore::mu_ or the GIL in GIL mode. + PjitFunctionObject* next; + PjitFunctionObject* prev; +}; + +// Contains a list of all PjitFunctionObjects. +// Thread-safe. +class PjitFunctionStore { + public: + void Insert(PjitFunctionObject* o) { + nb::ft_lock_guard lock(mu_); + o->next = compiled_functions_; + o->prev = nullptr; + if (o->next) { + o->next->prev = o; + } + compiled_functions_ = o; + } + + void Remove(PjitFunctionObject* o) { + nb::ft_lock_guard lock(mu_); + if (o->next) { + o->next->prev = o->prev; + } + if (o->prev) { + o->prev->next = o->next; + } else { + compiled_functions_ = o->next; + } + } + + void ClearCaches() { + std::vector< + std::pair>> + caches; + { + nb::ft_lock_guard lock(mu_); + for (PjitFunctionObject* fn = compiled_functions_; fn != nullptr; + fn = fn->next) { + caches.emplace_back(fn->fun.cache(), fn->fun.executables()); + } + } + for (auto& [cache, executables] : caches) { + nb::ft_object_guard lock(cache); + executables->Clear(); + } + }; + + private: + // Protected by the GIL in GIL mode, and by mu_ in freethreading mode. 
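+  // Any other traversal of the store should follow the same two-phase
+  // pattern as ClearCaches() above: snapshot the list under mu_, then act on
+  // the snapshot without holding mu_. A minimal sketch, assuming a
+  // hypothetical helper (ForEachFunctionName is not a real API):
+  //
+  //   void ForEachFunctionName(absl::FunctionRef<void(const std::string&)> f) {
+  //     std::vector<std::string> names;
+  //     {
+  //       nb::ft_lock_guard lock(mu_);
+  //       for (PjitFunctionObject* fn = compiled_functions_; fn != nullptr;
+  //            fn = fn->next) {
+  //         names.push_back(fn->fun.function_name());
+  //       }
+  //     }
+  //     for (const std::string& name : names) f(name);  // Outside the lock.
+  //   }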
+ nb::ft_mutex mu_; + PjitFunctionObject* compiled_functions_; +}; + +PjitFunctionStore pjit_function_store; + +PyObject* PjitFunction_Type = nullptr; + +bool PjitFunction::IsPjitFunction(nb::handle handle) { + return handle.type().ptr() == PjitFunction_Type; +} + +PjitFunction* PjitFunction::AsPjitFunctionUnchecked(nb::handle handle) { + return &(reinterpret_cast(handle.ptr())->fun); +} + +PjitFunction* AsPjitFunction(nb::handle handle) { + if (!PjitFunction::IsPjitFunction(handle)) { + throw xla::XlaRuntimeError(xla::InvalidArgument("Expected a PjitFunction")); + } + return PjitFunction::AsPjitFunctionUnchecked(handle); +} + +extern "C" { + +PyObject* PjitFunction_tp_vectorcall(PyObject* callable, PyObject* const* args, + size_t nargs, PyObject* kwnames) { + PjitFunctionObject* o = reinterpret_cast(callable); + tsl::profiler::TraceMe traceme([&] { + return absl::StrCat("PjitFunction(", o->fun.function_name(), ")"); + }); + try { + absl::StatusOr out = + o->fun.Call(callable, args, nargs, kwnames); + if (!out.ok()) { + PyErr_SetString(PyExc_ValueError, out.status().ToString().c_str()); + return nullptr; + } + return out.value().release().ptr(); + } catch (nb::python_error& e) { + e.restore(); + return nullptr; + } catch (nb::cast_error& e) { + PyErr_SetString(PyExc_ValueError, e.what()); + return nullptr; + } catch (std::invalid_argument& e) { + PyErr_SetString(PyExc_ValueError, e.what()); + return nullptr; + } catch (std::runtime_error& e) { + PyErr_SetString(PyExc_ValueError, e.what()); + return nullptr; + } +} + +PyObject* PjitFunction_tp_new(PyTypeObject* subtype, PyObject* args, + PyObject* kwds) { + PjitFunctionObject* self = + reinterpret_cast(subtype->tp_alloc(subtype, 0)); + if (!self) return nullptr; +#if PY_VERSION_HEX < 0x030C0000 + self->dict = nullptr; + self->weakrefs = nullptr; +#endif // PY_VERSION_HEX < 0x030C0000 + self->vectorcall = PjitFunction_tp_vectorcall; + return reinterpret_cast(self); +} + +void PjitFunction_tp_dealloc(PyObject* self) { + PyObject_GC_UnTrack(self); + PyTypeObject* tp = Py_TYPE(self); + PjitFunctionObject* o = reinterpret_cast(self); + pjit_function_store.Remove(o); + PyObject_ClearWeakRefs(self); +#if PY_VERSION_HEX < 0x030C0000 + Py_CLEAR(o->dict); +#elif PY_VERSION_HEX < 0x030D0000 + _PyObject_ClearManagedDict(self); +#else + PyObject_ClearManagedDict(self); +#endif // PY_VERSION_HEX < 0x030C0000 + o->fun.~PjitFunction(); + tp->tp_free(self); + Py_DECREF(tp); +} + +int PjitFunction_tp_traverse(PyObject* self, visitproc visit, void* arg) { + // TODO(mattjj): phawkins@ observed that the xla::PyTreeRegistry + // pytree_registry_ attribute of PjitFunction could in principle also have + // python references to visit + PjitFunctionObject* o = reinterpret_cast(self); + // https://docs.python.org/3/c-api/typeobj.html#c.PyTypeObject.tp_traverse + Py_VISIT(Py_TYPE(self)); +#if PY_VERSION_HEX < 0x030C0000 + Py_VISIT(o->dict); +#elif PY_VERSION_HEX < 0x030D0000 + _PyObject_VisitManagedDict(self, visit, arg); +#else + PyObject_VisitManagedDict(self, visit, arg); +#endif // PY_VERSION_HEX < 0x030C0000 + Py_VISIT(o->fun.cache_miss().ptr()); + Py_VISIT(o->fun.shard_arg_fallback().ptr()); + if (o->fun.fun()) { + Py_VISIT(o->fun.fun()->ptr()); + } + return 0; +} + +int PjitFunction_tp_clear(PyObject* self) { + PjitFunctionObject* o = reinterpret_cast(self); +#if PY_VERSION_HEX < 0x030C0000 + Py_CLEAR(o->dict); +#elif PY_VERSION_HEX < 0x030D0000 + _PyObject_ClearManagedDict(self); +#else + PyObject_ClearManagedDict(self); +#endif // PY_VERSION_HEX < 0x030C0000 
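+  // tp_clear only drops references that may participate in a reference
+  // cycle; the object itself stays alive and is freed later by tp_dealloc. A
+  // minimal sketch of a matching traverse/clear pair for a hypothetical heap
+  // type holding a single `PyObject* payload`:
+  //
+  //   int Holder_tp_traverse(PyObject* self, visitproc visit, void* arg) {
+  //     Py_VISIT(Py_TYPE(self));  // Heap types must visit their type.
+  //     Py_VISIT(reinterpret_cast<HolderObject*>(self)->payload);
+  //     return 0;
+  //   }
+  //   int Holder_tp_clear(PyObject* self) {
+  //     Py_CLEAR(reinterpret_cast<HolderObject*>(self)->payload);
+  //     return 0;
+  //   }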
+  o->fun.ClearPythonReferences();
+  return 0;
+}
+
+// Implements the Python descriptor protocol so JIT-compiled functions can be
+// used as bound methods. See:
+// https://docs.python.org/3/howto/descriptor.html#functions-and-methods
+PyObject* PjitFunction_tp_descr_get(PyObject* self, PyObject* obj,
+                                    PyObject* type) {
+  if (obj == nullptr || obj == Py_None) {
+    Py_INCREF(self);
+    return self;
+  }
+  return PyMethod_New(self, obj);
+}
+
+static PyGetSetDef PjitFunction_tp_getset[] = {
+    // Having a __dict__ seems necessary to allow functools.wraps to override
+    // __doc__.
+    {const_cast<char*>("__dict__"), PyObject_GenericGetDict,
+     PyObject_GenericSetDict, nullptr, nullptr},
+    {nullptr, nullptr, nullptr, nullptr, nullptr}};
+
+PyObject* PjitFunction_tp_repr(PyObject* self) {
+  try {
+    const std::string& repr = absl::StrFormat(
+        "<PjitFunction of %s>",
+        nb::cast<std::string_view>(
+            nb::repr(nb::getattr(self, "__wrapped__"))));
+    return PyUnicode_FromString(repr.c_str());
+  } catch (...) {
+    // Ignore all errors when accessing a repr.
+    return PyUnicode_FromString("<PjitFunction>");
+  }
+}
+
+}  // extern "C"
+
+void InitializePjitFunction(
+    PjitFunctionObject* fn_obj, std::string function_name,
+    std::optional<nb::callable> fun, nb::callable cache_miss,
+    std::vector<int> static_argnums, std::vector<nb::str> static_argnames,
+    nb::object global_cache_key,
+    xla::nb_class_ptr<xla::PyTreeRegistry> pytree_registry,
+    nb::callable shard_arg_fallback,
+    xla::nb_class_ptr<PjitFunctionCache> cache) {
+  fn_obj->next = fn_obj->prev = nullptr;
+  if (nb::isinstance<nb::list>(global_cache_key)) {
+    global_cache_key = nb::tuple(global_cache_key);
+  }
+  new (&fn_obj->fun) PjitFunction(
+      std::move(function_name), std::move(fun), std::move(cache_miss),
+      std::move(static_argnums), std::move(static_argnames),
+      std::move(global_cache_key), std::move(pytree_registry),
+      std::move(shard_arg_fallback), std::move(cache));
+  // Handled separately because it is not exception safe to call this
+  // in the constructor because it leaves the object improperly constructed.
+  fn_obj->fun.InitExecutables();
+
+  // Only add the executable to the store after executables_ has been
+  // initialized. We want only fully constructed executables in the store.
+  pjit_function_store.Insert(fn_obj);
+}
+
+nb::object MakePjitFunction(
+    std::string function_name, std::optional<nb::callable> fun,
+    nb::callable cache_miss, std::vector<int> static_argnums,
+    std::vector<nb::str> static_argnames, nb::object global_cache_key,
+    xla::nb_class_ptr<xla::PyTreeRegistry> pytree_registry,
+    nb::callable shard_arg_fallback,
+    std::optional<xla::nb_class_ptr<PjitFunctionCache>> cache) {
+  nb::object obj = nb::steal(PjitFunction_tp_new(
+      reinterpret_cast<PyTypeObject*>(PjitFunction_Type), nullptr, nullptr));
+  PjitFunctionObject* fn_obj = reinterpret_cast<PjitFunctionObject*>(obj.ptr());
+  if (!cache) {
+    cache = xla::make_nb_class<PjitFunctionCache>(
+        PjitFunctionCache::kDefaultCapacity);
+  }
+  InitializePjitFunction(
+      fn_obj, std::move(function_name), std::move(fun), std::move(cache_miss),
+      std::move(static_argnums), std::move(static_argnames),
+      std::move(global_cache_key), std::move(pytree_registry),
+      std::move(shard_arg_fallback), std::move(*cache));
+  return obj;
+}
+
+// Version numbers for the pickled representations of
+// PjitFunction. Increment these if changing them.
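+// The version is written by __getstate__ and checked by __setstate__, so a
+// pickle produced by one jaxlib can only be restored by a jaxlib that uses
+// the same constant. A minimal sketch of the round trip (condensed from the
+// bindings below):
+//
+//   nb::dict state;
+//   state["version"] = kPjitFunctionPickleVersion;        // __getstate__
+//   ...
+//   if (nb::cast<int>(state["version"]) != kPjitFunctionPickleVersion) {
+//     throw std::invalid_argument("incompatible PjitFunction pickle");
+//   }                                                     // __setstate__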
+const int kPjitFunctionPickleVersion = 1; + +PyMemberDef PjitFunction_members[] = { + {"__vectorcalloffset__", T_PYSSIZET, + static_cast(offsetof(PjitFunctionObject, vectorcall)), + READONLY, nullptr}, +#if PY_VERSION_HEX < 0x030C0000 + {"__dictoffset__", T_PYSSIZET, + static_cast(offsetof(PjitFunctionObject, dict)), READONLY, + nullptr}, + {"__weaklistoffset__", T_PYSSIZET, + static_cast(offsetof(PjitFunctionObject, weakrefs)), READONLY, + nullptr}, +#endif // PY_VERSION_HEX < 0x030C0000 + {nullptr, 0, 0, 0, nullptr}, +}; + +PyType_Slot PjitFunction_slots[] = { + {Py_tp_new, reinterpret_cast(PjitFunction_tp_new)}, + {Py_tp_dealloc, reinterpret_cast(PjitFunction_tp_dealloc)}, + {Py_tp_traverse, reinterpret_cast(PjitFunction_tp_traverse)}, + {Py_tp_clear, reinterpret_cast(PjitFunction_tp_clear)}, + {Py_tp_getset, reinterpret_cast(PjitFunction_tp_getset)}, + {Py_tp_descr_get, reinterpret_cast(PjitFunction_tp_descr_get)}, + {Py_tp_call, reinterpret_cast(PyVectorcall_Call)}, + {Py_tp_repr, reinterpret_cast(PjitFunction_tp_repr)}, + {Py_tp_members, reinterpret_cast(PjitFunction_members)}, + {0, nullptr}, +}; + +} // namespace + +void BuildPjitSubmodule(nb::module_& m) { + nb::class_ cache(m, "PjitFunctionCache"); + cache.def(nb::init(), + nb::arg("capacity") = PjitFunctionCache::kDefaultCapacity); + cache.def("size", &PjitFunctionCache::Size, nb::lock_self()); + cache.def("capacity", &PjitFunctionCache::Capacity, nb::lock_self()); + cache.def("clear", &PjitFunctionCache::Clear, nb::lock_self()); + cache.def_static("clear_all", []() { pjit_function_store.ClearCaches(); }); + cache.def( + "__getstate__", + // Pickles as an empty cache; the client can repopulate as needed. + [](const PjitFunctionCache& cache) { + nb::dict pickle; + pickle["version"] = kPjitFunctionPickleVersion; + pickle["capacity"] = cache.Capacity(); + return pickle; + }, + nb::lock_self()); + cache.def("__setstate__", + [](PjitFunctionCache* cache, const nb::dict& pickle) { + int version = nb::cast(pickle["version"]); + if (version != kPjitFunctionPickleVersion) { + throw std::invalid_argument(absl::StrFormat( + "Invalid PjitFunction pickle version, got %d, expected %d", + version, kPjitFunctionPickleVersion)); + } + int capacity = nb::cast(pickle["capacity"]); + new (cache) PjitFunctionCache(capacity); + }); + + // We need to use heap-allocated type objects because we want to add + // additional methods dynamically. + std::string name = + absl::StrCat(nb::cast(m.attr("__name__")), ".PjitFunction"); + PyType_Spec PjitFunction_spec = { + /*.name=*/name.c_str(), + /*.basicsize=*/static_cast(sizeof(PjitFunctionObject)), + /*.itemsize=*/0, +#if PY_VERSION_HEX < 0x030C0000 + /*.flags=*/Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC | + Py_TPFLAGS_HAVE_VECTORCALL, +#else // PY_VERSION_HEX < 0x030C0000 + /*.flags=*/Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC | + Py_TPFLAGS_HAVE_VECTORCALL | Py_TPFLAGS_MANAGED_DICT | + Py_TPFLAGS_MANAGED_WEAKREF, +#endif // PY_VERSION_HEX < 0x030C0000 + /*.slots=*/PjitFunction_slots, + }; + PjitFunction_Type = PyType_FromSpec(&PjitFunction_spec); + if (!PjitFunction_Type) { + throw nb::python_error(); + } + nb::object cfun = nb::borrow(PjitFunction_Type); + + // Add PjitFunction to the _jax module so it can be pickled. 
+ m.attr("PjitFunction") = cfun; + cfun.attr("__getstate__") = nb::cpp_function( + [](const PjitFunction::object& self) { + PjitFunction* fn = self.func(); + nb::dict pickle; + pickle["version"] = kPjitFunctionPickleVersion; + pickle["function_name"] = fn->function_name(); + if (fn->fun().has_value()) { + pickle["fun"] = *fn->fun(); + } + pickle["cache_miss"] = fn->cache_miss(); + pickle["static_argnums"] = fn->static_argnums(); + pickle["static_argnames"] = nb::cast(fn->static_argnames()); + pickle["global_cache_key"] = fn->global_cache_key(); + pickle["pytree_registry"] = nb::cast(fn->pytree_registry()); + pickle["shard_arg_fallback"] = fn->shard_arg_fallback(); + pickle["cache"] = fn->cache(); + return pickle; + }, + nb::is_method()); + cfun.attr("__setstate__") = nb::cpp_function( + [](nb::object& self, const nb::dict& pickle) { + int version = nb::cast(pickle["version"]); + if (version != kPjitFunctionPickleVersion) { + throw std::invalid_argument(absl::StrFormat( + "Invalid PjitFunction pickle version, got %d, expected %d. " + "Pickling/Unpickling jitted functions using different JAX " + "versions is not supported.", + version, kPjitFunctionPickleVersion)); + } + std::string function_name = + nb::cast(pickle["function_name"]); + std::optional fun; + if (pickle.contains("fun")) { + fun = nb::cast(pickle["fun"]); + } + nb::callable cache_miss = nb::cast(pickle["cache_miss"]); + std::vector static_argnums = + nb::cast>(pickle["static_argnums"]); + std::vector static_argnames = + nb::cast>(pickle["static_argnames"]); + nb::object global_cache_key = pickle["global_cache_key"]; + xla::nb_class_ptr pytree_registry = + nb::cast>( + nb::handle(pickle["pytree_registry"].ptr())); + nb::callable shard_arg_fallback = + nb::cast(pickle["shard_arg_fallback"]); + xla::nb_class_ptr cache = + nb::cast>(pickle["cache"]); + InitializePjitFunction( + reinterpret_cast(self.ptr()), + std::move(function_name), std::move(fun), std::move(cache_miss), + std::move(static_argnums), std::move(static_argnames), + std::move(global_cache_key), std::move(pytree_registry), + std::move(shard_arg_fallback), std::move(cache)); + }, + nb::is_method()); + cfun.attr("__signature__") = + xla::nb_property_readonly([](nb::handle self) -> nb::object { + return AsPjitFunction(self)->PythonSignature(); + }); + cfun.attr("_cache_miss") = + xla::nb_property_readonly([](nb::handle self) -> nb::object { + return AsPjitFunction(self)->cache_miss(); + }); + // All private members are only for testing/debugging purposes + cfun.attr("_cache_size") = nb::cpp_function( + [](nb::handle self) -> int { + return AsPjitFunction(self)->cache_capacity(); + }, + nb::is_method()); + cfun.attr("_clear_cache") = nb::cpp_function( + [](nb::handle self) { AsPjitFunction(self)->ClearCache(); }, + nb::is_method()); + + m.def( + "pjit", + [](std::string function_name, std::optional fun, + nb::callable cache_miss, std::vector static_argnums, + std::vector static_argnames, nb::object global_cache_key, + nb::object pytree_registry, nb::callable shard_arg_fallback, + std::optional> cache) { + xla::nb_class_ptr registry = + nb::cast>( + nb::handle(pytree_registry.ptr())); + return MakePjitFunction( + std::move(function_name), std::move(fun), std::move(cache_miss), + std::move(static_argnums), std::move(static_argnames), + std::move(global_cache_key), std::move(registry), + std::move(shard_arg_fallback), std::move(cache)); + }, + nb::arg("function_name"), nb::arg("fun").none(), nb::arg("cache_miss"), + nb::arg("static_argnums"), nb::arg("static_argnames"), + 
nb::arg("global_cache_key"), nb::arg("pytree_registry"), + nb::arg("shard_arg_fallback"), nb::arg("cache").none() = nb::none()); +} + +} // namespace jax diff --git a/jaxlib/pjit.h b/jaxlib/pjit.h new file mode 100644 index 000000000000..d86fa6bddc3c --- /dev/null +++ b/jaxlib/pjit.h @@ -0,0 +1,27 @@ +/* Copyright 2022 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_PJIT_H_ +#define JAXLIB_PJIT_H_ + +// placeholder for index annotation headers +#include "nanobind/nanobind.h" + +namespace jax { + +void BuildPjitSubmodule(nanobind::module_& m); +} + +#endif // JAXLIB_PJIT_H_ diff --git a/jaxlib/pmap_lib.cc b/jaxlib/pmap_lib.cc new file mode 100644 index 000000000000..f49954e7df90 --- /dev/null +++ b/jaxlib/pmap_lib.cc @@ -0,0 +1,1137 @@ +/* Copyright 2021 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "jaxlib/pmap_lib.h" + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/inlined_vector.h" +#include "absl/hash/hash.h" +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "absl/synchronization/notification.h" +#include "absl/types/span.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/shared_ptr.h" // IWYU pragma: keep +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "nanobind/stl/variant.h" // IWYU pragma: keep +#include "nanobind/stl/vector.h" // IWYU pragma: keep +#include "jaxlib/config.h" +#include "jaxlib/jax_jit.h" +#include "jaxlib/nb_class_ptr.h" +#include "jaxlib/py_array.h" +#include "jaxlib/py_client.h" +#include "jaxlib/py_device.h" +#include "jaxlib/py_executable.h" +#include "jaxlib/py_values.h" +#include "jaxlib/python_ref_manager.h" +#include "jaxlib/pytree.h" +#include "jaxlib/sharded_device_array.h" +#include "jaxlib/sharding.h" +#include "jaxlib/traceback.h" +#include "xla/pjrt/exceptions.h" +#include "xla/pjrt/status_casters.h" +#include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/device.h" +#include "xla/python/ifrt/device_list.h" +#include "xla/python/ifrt/executable.h" +#include "xla/python/ifrt/memory.h" +#include "xla/python/ifrt/sharding.h" +#include "xla/python/nb_helpers.h" +#include "xla/python/nb_numpy.h" +#include "xla/python/safe_static_init.h" +#include "xla/python/types.h" +#include "xla/status_macros.h" +#include "xla/tsl/concurrency/ref_count.h" +#include "xla/tsl/platform/env.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/tsl/python/lib/core/numpy.h" +#include "xla/util.h" +#include "xla/xla_data.pb.h" +#include "tsl/profiler/lib/traceme.h" + +namespace jax { + +namespace nb = nanobind; + +namespace { + +// Specifies how to shard the inputs. Even though everything could be computed +// from `sharding_specs` and the argument shape, we cache derived computations +// for performance. +struct InputSpec { + InputSpec(nb::object indices, nb::object array_sharding) + : indices(std::move(indices)), + array_sharding(std::move(array_sharding)) {} + nb::object indices; + nb::object array_sharding; +}; + +// An object containing the arguments to create Array from the +// output buffers. +struct ResultSpec { + public: + explicit ResultSpec(nb::object aval) + : out_aval(std::move(aval)), + weak_type(nb::cast(out_aval.attr("weak_type"))) {} + nb::object out_aval; + bool weak_type; +}; + +// The result of `ShardArg`. +struct ShardArgResult { + // Points to the on-device array. + // ifrt_array->sharding().num_shards() == `num_devices`. + xla::ifrt::ArrayRef ifrt_array; + // The Python argument will be always be copied to `owning_sda`. + nb::object owning_sda; +}; + +// Shards a single argument over devices. +// +// We currently only support fully in C++, C++ Array. For all +// other usages, we call a Python function returning C++ Array +// that will be casted back to the C++ objects. +// +// This function is not usable for JAX extensions that do not comply with the +// PjRt interfaces. +// +// Arguments: +// `arg`: The object to shard across `devices`. 
If a `Array`, +// a fast-path will be executed if it's already correctly sharded. +// +// Returns a failure absl::Status when an unrecoverable error occurred, so we +// don't need to fallback to Python. +// +// Both `devices` and `sharding_spec` has the same length. +absl::StatusOr ShardArg( + nb::handle arg, absl::Span devices, + const InputSpec& input_spec, nb::handle py_devices, + const nb::callable& python_fallback) { + if (arg.type().ptr() == xla::PyArray::type().ptr()) { + auto py_array = nb::borrow(arg); + if (py_array.sharding().type().ptr() == + input_spec.array_sharding.type().ptr()) { + auto* pmap_sharding = nb::cast(py_array.sharding()); + auto* cached_pmap_sharding = + nb::cast(input_spec.array_sharding); + + if (pmap_sharding->sharding_spec() == + cached_pmap_sharding->sharding_spec()) { + ShardArgResult result; + result.owning_sda = nb::borrow(arg); + result.ifrt_array = tsl::FormRef(py_array.ifrt_array()); + if (result.ifrt_array == nullptr) { + return xla::InvalidArgument("Array has been deleted."); + } + if (result.ifrt_array->sharding().devices()->devices() != devices) { + absl::InlinedVector ifrt_devices; + ifrt_devices.reserve(devices.size()); + ifrt_devices.insert(ifrt_devices.end(), devices.begin(), + devices.end()); + // pmap does not support memory_kind for now. + auto* ifrt_client = result.ifrt_array->client(); + TF_ASSIGN_OR_RETURN(auto copied_ifrt_arrays, + ifrt_client->CopyArrays( + absl::MakeSpan(&result.ifrt_array, 1), + ifrt_client->MakeDeviceList(ifrt_devices), + xla::ifrt::MemoryKind(), + xla::ifrt::ArrayCopySemantics::kReuseInput)); + result.ifrt_array = std::move(copied_ifrt_arrays.front()); + } + return result; + } + } + } + + auto ndarray = xla::nb_numpy_ndarray::ensure(arg); + if (ndarray && PyArray_CheckExact(arg.ptr()) && + xla::DtypeToPrimitiveType(ndarray.dtype()).status().ok()) { + tsl::profiler::TraceMe traceme("ndarray pmap ShardArg"); + nb::list indices = nb::list(input_spec.indices); + nb::list py_devices_list = nb::cast(py_devices); + auto n_devices = py_devices_list.size(); + if (indices.size() != n_devices) { + return xla::InvalidArgument("indices vs devices mismatch: %d vs %d", + indices.size(), n_devices); + } + + ShardArgResult result; + const bool jax_enable_x64 = GetEnableX64(); + + std::vector owning_args; + std::vector args; + owning_args.reserve(n_devices); + args.reserve(n_devices); + xla::DevicePutOptions options; + options.squash_64bit_types = !jax_enable_x64; + options.allow_zero_copy = true; + xla::ifrt::Client* ifrt_client = nullptr; + for (size_t i = 0; i < n_devices; ++i) { + auto to_device = nb::cast(py_devices_list[i]); + if (to_device->client().get() == nullptr) { + return xla::InvalidArgument("Cannot copy to unattached devices."); + } + if (i == 0) { + ifrt_client = to_device->client()->ifrt_client(); + } + owning_args.push_back(arg[indices[i]]); + args.push_back(owning_args.back()); + } + CHECK(ifrt_client != nullptr); + TF_ASSIGN_OR_RETURN( + xla::DevicePutResult device_put_result, + xla::DevicePutWithSharding( + args, ifrt_client, ndarray.dtype(), + nb::cast>(ndarray.attr("shape")), + input_spec.array_sharding, options)); + result.ifrt_array = std::move(device_put_result.ifrt_array); + return result; + } + tsl::profiler::TraceMe traceme("pmap_lib_shard_arg_python_fallback"); + auto py_array_or_bufs = python_fallback(arg, input_spec.array_sharding); + + auto py_array = nb::cast(py_array_or_bufs); + ShardArgResult result; + result.owning_sda = nb::borrow(py_array_or_bufs); + result.ifrt_array = 
tsl::FormRef(py_array.ifrt_array()); + return result; +} + +struct PmapCacheEntry { + explicit PmapCacheEntry(xla::PyTreeRegistry* registry) + : out_pytree_def(registry) {} + std::shared_ptr executable; + // The value `backend.local_devices()`. + nb::object py_devices; // To pass back to Python. + std::vector devices; + std::vector input_specs; + xla::PyTreeDef out_pytree_def; + // Objects necessary to build the out Array objects. + std::vector out_result_specs; + + std::vector out_array_shardings; + std::vector out_dtypes; + std::vector> out_shapes; + std::vector out_committed; + + // Ensures a single thread performs the compilation for a given executable. + // + // The first thread (holding the GIL) will create the CacheEntry associated to + // a signature and if the object has been inserted already, other threads + // will wait for the notification. + absl::Notification compilation_complete; + + bool fall_back_to_python = false; +}; + +} // namespace + +// A `PmapFunction` is associated to a `jax.pmap(f)` and takes care of the +// bookkeeping of the different signatures used and the dispatch of calls to +// the correct underlying `PyLoadedExecutable`. This class is thread-safe. +class PmapFunction { + public: + PmapFunction(nb::callable fun, nb::callable cache_miss, + std::vector static_argnums, + nb::callable python_shard_arg_fallback, + xla::nb_class_ptr pytree_registry) + : fun_(std::move(fun)), + cache_miss_(std::move(cache_miss)), + static_argnums_(std::move(static_argnums)), + pytree_registry_(std::move(pytree_registry)), + python_shard_arg_fallback_(std::move(python_shard_arg_fallback)) { + std::sort(static_argnums_.begin(), static_argnums_.end()); + + function_name_ = + nb::cast(nb::str(nb::getattr(fun_, "__name__", fun_))); + } + PmapFunction(const PmapFunction&) = delete; + PmapFunction& operator=(const PmapFunction& other) = delete; + PmapFunction(PmapFunction&&) = default; + PmapFunction& operator=(PmapFunction&&) = default; + + // This function will: + // (a) flatten the inputs using pytree + // (b) get buffer objects from the arguments + // (c) call the executable + // (d) construct `Array` objects from the outputs + // (e) reconstruct the `PyTree`. + absl::StatusOr Call(nb::handle callable, PyObject* const* args, + size_t nargs, PyObject* kwnames); + + nb::object PythonSignature() { + const nb::module_& inspect = xla::SafeStaticInit([]() { + return std::make_unique(nb::module_::import_("inspect")); + }); + return inspect.attr("signature")(fun_); + } + + int cache_size() { + nb::ft_lock_guard lock(mu_); + return executables_.size(); + } + void cache_clear() { + nb::ft_lock_guard lock(mu_); + return executables_.clear(); + } + const nb::callable& fun() const { return fun_; } + const nb::callable& cache_miss() const { return cache_miss_; } + const std::string& function_name() const { return function_name_; } + const xla::nb_class_ptr& pytree_registry() const { + return pytree_registry_; + } + const nb::callable& python_shard_arg_fallback() const { + return python_shard_arg_fallback_; + } + const std::vector& static_argnums() const { return static_argnums_; } + + // nb::object typed subclass for PmapFunction objects. + class pyobject : public nb::object { + public: + NB_OBJECT(pyobject, nb::object, "PmapFunction", + PmapFunction::IsPmapFunction); + pyobject() = default; + PmapFunction* func() const { + return PmapFunction::AsPmapFunctionUnchecked(*this); + } + }; + // Alias as ::object; outside the scope above we won't confuse nanobind's + // macros. 
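+  // The typed wrapper lets bound lambdas take a `PmapFunction::object`
+  // instead of a raw nb::handle and reach the C++ object directly. A minimal
+  // usage sketch (hypothetical lambda, mirroring the __getstate__ binding
+  // further below):
+  //
+  //   auto cache_size_of = [](const PmapFunction::object& self) {
+  //     return self.func()->cache_size();  // No explicit cast needed.
+  //   };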
+ using object = pyobject; + + // Returns true if `h` is a PmapFunction. + static bool IsPmapFunction(nb::handle handle); + // Converts `handle` to a PmapFunction*. Does not do any checking. + static PmapFunction* AsPmapFunctionUnchecked(nb::handle handle); + + // Helper function used by the tp_clear GC method. + void ClearPythonReferences() { + nb::callable fun, cache_miss, python_shard_arg_fallback; + // Swap values for nulls before they are destroyed. See the Python + // Py_CLEAR() documentation for a discussion of this topic. + std::swap(fun_, fun); + std::swap(cache_miss_, cache_miss); + std::swap(python_shard_arg_fallback_, python_shard_arg_fallback); + } + + // Updates the signature of arguments for a pmapped function. + // + // It deals with the arguments signatures and also of the global and + // thread-local jit context. + absl::Status ComputeCallSignature( + absl::Span flat_dynamic_args, + CallSignature& signature) { + signature.function_name = function_name_; + + // Get dynamic argument signatures. + JitState& global_state = jax::GlobalJitState(); + JitState& tls = jax::ThreadLocalJitState(); + const bool jax_enable_x64 = GetEnableX64(); + signature.jax_enable_x64 = jax_enable_x64; + for (nb::handle arg : flat_dynamic_args) { + auto signature_or_error = xla::PyArgSignatureOfValue(arg, jax_enable_x64); + if (!signature_or_error.ok()) { + VLOG(2) << "PyArgSignatureOfValue failed: " + << signature_or_error.status(); + return signature_or_error.status(); + } + signature.dynamic_arg_signatures.push_back( + std::move(signature_or_error).value()); + } + signature.thread_local_extra_jit_context = tls.extra_jit_context; + signature.global_extra_jit_context = global_state.extra_jit_context; + signature.configs = JitConfigs(); + return absl::Status(); + } + + // Returns, for debugging purposes (e.g. finding why some call misses the + // cache and recompiles), the list of the string representations of the keys. + // + // The format can change at any time. + std::string DebugCacheKeys() { + nb::ft_lock_guard lock(mu_); + std::vector key_strings = { + absl::StrCat("The cache contains ", executables_.size(), " elements:")}; + // We will be able to use auto& [key, _] when TF uses C++ 17. + for (auto& pair : executables_) { + key_strings.push_back(pair.first.DebugString()); + } + return absl::StrJoin(key_strings, "\n\n"); + } + + private: + // Mutates `cache_entry` in place. + void PopulateCacheEntry(PmapCacheEntry& cache_entry, + const nb::tuple& out_and_fastpath_data); + + bool always_fallback_to_python_ = false; + + nb::callable fun_; // The Python function to pmap. + std::string function_name_; + // See JAX _cpp_pmap in api.py for documentation. + nb::callable cache_miss_; + + // We need to know the static arguments to remove them from the arguments + // passed to the underlying PyLoadedExecutable. In sorted order. + std::vector static_argnums_; + xla::nb_class_ptr pytree_registry_; + // We need a `shared_ptr` here to ensure value pointer stability, and to + // ensure that the cache entry remains alive in the presence of concurrent + // removals. + absl::flat_hash_map> + executables_; + + // The fallback function to use with `ShardArgs`. + // TODO(jblespiau): Add support for more types from C++. 
+ nb::callable python_shard_arg_fallback_; + + // Protect methods in FT: + nb::ft_mutex mu_; +}; + +void PmapFunction::PopulateCacheEntry(PmapCacheEntry& cache_entry, + const nb::tuple& out_and_fastpath_data) { + CHECK_EQ(out_and_fastpath_data.size(), 2); + if (out_and_fastpath_data[1].is_none()) { + cache_entry.fall_back_to_python = true; + return; + } + + nb::tuple pmap_data = nb::cast(out_and_fastpath_data[1]); + if (nb::cast(pmap_data.attr("version")) != 1) { + throw xla::XlaRuntimeError(absl::StrCat( + "The versions of jaxlib and Jax are incompatible (pmap cpp version 1 " + "expected, but got ", + nb::cast(pmap_data.attr("version")), + "Upgrade jaxlib and jax. Provided data was:", + nb::cast(nb::str(nb::repr(pmap_data))))); + } + // See api.nb::_PmapFastpathData in the JAX code base for the expected + // namedtuple. + std::shared_ptr executable; + try { + executable = nb::cast>( + pmap_data.attr("xla_executable")); + } catch (const nb::cast_error& e) { + // Backends that don't implement the C++ PjRt APIs + cache_entry.fall_back_to_python = true; + always_fallback_to_python_ = true; + return; + } + cache_entry.executable = std::move(executable); + const std::vector>& devices = + cache_entry.executable->AddressableDevices(); + cache_entry.devices.reserve(devices.size()); + for (auto& device : devices) { + cache_entry.devices.push_back(device->device()); + } + + // Inputs shard args details. + nb::list input_indices = pmap_data.attr("input_indices"); + + cache_entry.py_devices = pmap_data.attr("input_devices"); + auto input_devices = nb::cast>>( + pmap_data.attr("input_devices")); + + nb::list input_array_shardings = pmap_data.attr("input_array_shardings"); + + cache_entry.input_specs.reserve(input_array_shardings.size()); + + for (int i = 0; i < input_array_shardings.size(); ++i) { + cache_entry.input_specs.emplace_back(input_indices[i], + input_array_shardings[i]); + } + + // Outputs specs. + auto out_tree = nb::cast(pmap_data.attr("out_pytree_def")); + cache_entry.out_pytree_def = std::move(out_tree); + nb::list out_avals = pmap_data.attr("out_avals"); + + cache_entry.out_result_specs.reserve(out_avals.size()); + cache_entry.out_dtypes.reserve(out_avals.size()); + cache_entry.out_shapes.reserve(out_avals.size()); + + for (int i = 0; i < out_avals.size(); ++i) { + cache_entry.out_dtypes.push_back(out_avals[i].attr("dtype")); + cache_entry.out_shapes.push_back( + nb::cast>(out_avals[i].attr("shape"))); + cache_entry.out_result_specs.emplace_back(out_avals[i]); + } + + nb::list out_array_shardings = pmap_data.attr("out_array_shardings"); + + DCHECK(out_array_shardings.size() == 0 || + out_avals.size() == out_array_shardings.size()); + + cache_entry.out_array_shardings.reserve(out_array_shardings.size()); + for (nb::handle out_array_sharding : out_array_shardings) { + cache_entry.out_array_shardings.push_back( + nb::borrow(out_array_sharding)); + } + + nb::list out_committed = pmap_data.attr("out_committed"); + + DCHECK(out_committed.size() == 0 || out_avals.size() == out_committed.size()); + + cache_entry.out_committed.reserve(out_committed.size()); + for (nb::handle c : out_committed) { + cache_entry.out_committed.push_back(nb::cast(c)); + } +} + +absl::StatusOr PmapFunction::Call(nb::handle callable, + PyObject* const* args, + size_t nargs, PyObject* kwnames) { + xla::GlobalPyRefManager()->MaybeCollectGarbage(); + + // Calls the cache_miss_ function. This just calls the Python function; it may + // return nullptr value if a Python exception is thrown. 
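+  // This follows the usual CPython calling convention: a null result means a
+  // Python exception is already set, and the caller must either propagate it
+  // (as the fast path below does by returning an empty nb::object) or turn it
+  // into a C++ exception. A minimal sketch, assuming a hypothetical
+  // `PyObject* fn`:
+  //
+  //   nb::object result = nb::steal(PyObject_CallNoArgs(fn));
+  //   if (!result.ptr()) {
+  //     throw nb::python_error();  // Captures the pending Python error.
+  //   }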
+ auto cache_miss = [&]() -> nb::tuple { + return nb::steal( + PyObject_Vectorcall(cache_miss_.ptr(), args, nargs, kwnames)); + }; + + // Call the cache_miss() function, extracting the output data and ignoring + // the fastpath data. If the cache miss returns a Python error, returns + // nullptr and leaves the Python error set. + auto fallback_to_cache_miss = [&]() { + nb::tuple cache_miss_output = cache_miss(); + if (!cache_miss_output.ptr()) { + return nb::object(); + } + return nb::object(cache_miss_output[0]); + }; + + if (always_fallback_to_python_) { + return fallback_to_cache_miss(); + } + + size_t num_positional_args = PyVectorcall_NARGS(nargs); + size_t num_keyword_args = kwnames ? PyTuple_GET_SIZE(kwnames) : 0; + absl::Span positional_args(args, num_positional_args); + absl::Span keyword_args(args + num_positional_args, + num_keyword_args); + CallSignature call_signature; + absl::InlinedVector flat_dynamic_args; + std::vector keep_alive_objects; + absl::Status status = + ParseArguments(positional_args, keyword_args, kwnames, static_argnums_, + /*static_argnames=*/{}, pytree_registry_.get(), + call_signature.arg_signature, flat_dynamic_args); + if (!status.ok()) { + VLOG(2) << "ParseArguments failed: " << status; + return fallback_to_cache_miss(); + } + + status = ComputeCallSignature(flat_dynamic_args, call_signature); + if (!status.ok()) { + return fallback_to_cache_miss(); + } + + // Retrieve/Maybe add the executable to the cache. + bool inserted = false; + std::shared_ptr cache_entry_ptr; + { + nb::ft_lock_guard lock(mu_); + std::shared_ptr& entry_ref = executables_[call_signature]; + if (!entry_ref) { + inserted = true; + entry_ref = std::make_shared(pytree_registry_.get()); + } + cache_entry_ptr = entry_ref; + } + PmapCacheEntry& cache_entry = *cache_entry_ptr; + + if (!cache_entry.compilation_complete.HasBeenNotified()) { + // In case of several threads attempting to compile the executable, only + // the one that inserted the item will perform the compilation. + if (inserted) { + nb::object out_and_fastpath_data; + nb::tuple out_tuple; + VLOG(2) << "Cache miss for " << call_signature.DebugString(); + try { + // Calls Python and may release the GIL. May also throw if + // compilation/tracing fails. + out_and_fastpath_data = cache_miss(); + if (!out_and_fastpath_data.ptr()) { + throw nb::python_error(); + } + out_tuple = nb::cast(out_and_fastpath_data); + + PopulateCacheEntry(cache_entry, out_tuple); + } catch (const std::exception& e) { + cache_entry.fall_back_to_python = true; + cache_entry.compilation_complete.Notify(); + throw; + } + cache_entry.compilation_complete.Notify(); + + // We have already computed the result in the miss path so we can return + // it. We are even *required* to do so if there are donated arguments, + // because any donated buffers will now be invalid. + return nb::object(out_tuple[0]); + } else { + // Release the GIL while we wait, making sure the compile thread can + // lock it. + nb::gil_scoped_release release; + cache_entry.compilation_complete.WaitForNotification(); + } + } + if (cache_entry.fall_back_to_python) { + return fallback_to_cache_miss(); + } + + // 1. Parse arguments. + std::vector& input_devices = cache_entry.devices; + std::vector& input_specs = cache_entry.input_specs; + const int num_args = flat_dynamic_args.size(); + + // We need [num_args] for the `Execute` call below. 
+ std::vector num_args_arrays(num_args); + for (int i = 0; i < num_args; ++i) { + TF_ASSIGN_OR_RETURN( + ShardArgResult sharded_arg, + ShardArg(flat_dynamic_args[i], input_devices, input_specs[i], + cache_entry.py_devices, python_shard_arg_fallback_)); + + num_args_arrays[i] = std::move(sharded_arg.ifrt_array); + if (sharded_arg.owning_sda) { + keep_alive_objects.push_back(std::move(sharded_arg.owning_sda)); + } + } + + xla::ifrt::ExecuteOptions execute_options = cache_entry.executable->options(); + execute_options.launch_id = cache_entry.executable->GetNextLaunchId(); + execute_options.execution_stream_id = xla::GetExecutionStreamId(); + if (execute_options.execution_stream_id == 0) { + execute_options.execution_stream_id = + tsl::Env::Default()->GetCurrentThreadId(); + } + + // A vector of [num_outputs]. + std::vector output_arrays; + { + nb::gil_scoped_release gil_release; + auto ifrt_executable = cache_entry.executable->ifrt_executable(); + TF_ASSIGN_OR_RETURN( + auto result, ifrt_executable->Execute(absl::MakeSpan(num_args_arrays), + execute_options, + /*devices=*/std::nullopt)); + output_arrays = std::move(result.outputs); + } + + // TODO(jblespiau): We don't need to create the PyBuffer objects. + // Having a C++ `Array`, keeping internally the PjRtBuffer + // objects is sufficient, and we can lazily create the `PyBuffer` only if + // we access them from Python. + auto traceback = xla::Traceback::Get(); + // TODO(jblespiau): Change the `client` function to return a reference. + xla::nb_class_ptr client = cache_entry.executable->client(); + + // Convert the PjRtBuffer objects to PyBuffer, and invert the order from + // [num_devices, num_args] to [num_args, num_devices]. + const int num_outputs = output_arrays.size(); + std::vector flat_sharded_device_arrays; + flat_sharded_device_arrays.reserve(num_outputs); + + const auto& output_specs = cache_entry.out_result_specs; + + TF_RET_CHECK(cache_entry.out_array_shardings.size() == num_outputs); + for (int i = 0; i < num_outputs; ++i) { + const ResultSpec& result_spec = output_specs[i]; + xla::PyArray py_array( + result_spec.out_aval, result_spec.weak_type, cache_entry.out_dtypes[i], + cache_entry.out_shapes[i], cache_entry.out_array_shardings[i], client, + traceback, std::move(output_arrays[i]), cache_entry.out_committed[i], + /*skip_checks=*/true); + + flat_sharded_device_arrays.push_back(std::move(py_array)); + } + + nb::object out = + cache_entry.out_pytree_def.Unflatten(flat_sharded_device_arrays); + + // If there is a post-hook function, call it with the inputs and the outputs. + std::optional post_hook = GetPostHook(); + if (post_hook) { + nb::tuple args_tuple = + nb::steal(PyTuple_New(num_positional_args)); + for (size_t i = 0; i < num_positional_args; ++i) { + Py_INCREF(args[i]); + PyTuple_SET_ITEM(args_tuple.ptr(), i, args[i]); + } + nb::dict kwargs; + if (kwnames) { + for (size_t i = 0; i < num_keyword_args; ++i) { + kwargs[nb::handle(PyTuple_GET_ITEM(kwnames, i))] = + nb::borrow(args[num_positional_args + i]); + } + } + + (*post_hook)(callable, args_tuple, kwargs, out); + } + + return out; +} + +struct JaxPmapFunctionObject { + PyObject_HEAD; +#if PY_VERSION_HEX < 0x030C0000 + PyObject* dict; // Dictionary for __dict__ + PyObject* weakrefs; // Weak references; for use by the Python interpreter. 
+#endif // PY_VERSION_HEX < 0x030C0000 + vectorcallfunc vectorcall; + PmapFunction fun; +}; + +PyObject* JaxPmapFunction_Type = nullptr; + +bool PmapFunction::IsPmapFunction(nb::handle handle) { + return handle.type().ptr() == JaxPmapFunction_Type; +} + +PmapFunction* PmapFunction::AsPmapFunctionUnchecked(nb::handle handle) { + return &(reinterpret_cast(handle.ptr())->fun); +} + +absl::StatusOr AsPmapFunction(nb::handle handle) { + if (!PmapFunction::IsPmapFunction(handle)) { + return xla::InvalidArgument("Expected a PmapFunction"); + } + return PmapFunction::AsPmapFunctionUnchecked(handle); +} + +namespace { + +extern "C" { + +PyObject* JaxPmapFunction_tp_vectorcall(PyObject* callable, + PyObject* const* args, size_t nargs, + PyObject* kwnames) { + JaxPmapFunctionObject* o = reinterpret_cast(callable); + tsl::profiler::TraceMe traceme([&] { + return absl::StrCat("JaxPmapFunction(", o->fun.function_name(), ")"); + }); + try { + absl::StatusOr out = + o->fun.Call(callable, args, nargs, kwnames); + if (!out.ok()) { + PyErr_SetString(PyExc_ValueError, out.status().ToString().c_str()); + return nullptr; + } + return out.value().release().ptr(); + } catch (nb::python_error& e) { + e.restore(); + return nullptr; + } catch (nb::cast_error& e) { + PyErr_SetString(PyExc_ValueError, e.what()); + return nullptr; + } catch (std::invalid_argument& e) { + PyErr_SetString(PyExc_ValueError, e.what()); + return nullptr; + } +} + +PyObject* JaxPmapFunction_tp_new(PyTypeObject* subtype, PyObject* args, + PyObject* kwds) { + JaxPmapFunctionObject* self = + reinterpret_cast(subtype->tp_alloc(subtype, 0)); + if (!self) return nullptr; +#if PY_VERSION_HEX < 0x030C0000 + self->dict = nullptr; + self->weakrefs = nullptr; +#endif // PY_VERSION_HEX < 0x030C0000 + self->vectorcall = JaxPmapFunction_tp_vectorcall; + return reinterpret_cast(self); +} + +void JaxPmapFunction_tp_dealloc(PyObject* self) { + PyObject_GC_UnTrack(self); + PyTypeObject* tp = Py_TYPE(self); + JaxPmapFunctionObject* o = reinterpret_cast(self); + PyObject_ClearWeakRefs(self); +#if PY_VERSION_HEX < 0x030C0000 + Py_CLEAR(o->dict); +#elif PY_VERSION_HEX < 0x030D0000 + _PyObject_ClearManagedDict(self); +#else + PyObject_ClearManagedDict(self); +#endif // PY_VERSION_HEX < 0x030C0000 + o->fun.~PmapFunction(); + tp->tp_free(self); + Py_DECREF(tp); +} + +int JaxPmapFunction_tp_traverse(PyObject* self, visitproc visit, void* arg) { + JaxPmapFunctionObject* o = reinterpret_cast(self); + // https://docs.python.org/3/c-api/typeobj.html#c.PyTypeObject.tp_traverse + Py_VISIT(Py_TYPE(self)); +#if PY_VERSION_HEX < 0x030C0000 + Py_VISIT(o->dict); +#elif PY_VERSION_HEX < 0x030D0000 + _PyObject_VisitManagedDict(self, visit, arg); +#else + PyObject_VisitManagedDict(self, visit, arg); +#endif // PY_VERSION_HEX < 0x030C0000 + Py_VISIT(o->fun.fun().ptr()); + Py_VISIT(o->fun.cache_miss().ptr()); + return 0; +} + +int JaxPmapFunction_tp_clear(PyObject* self) { + JaxPmapFunctionObject* o = reinterpret_cast(self); +#if PY_VERSION_HEX < 0x030C0000 + Py_CLEAR(o->dict); +#elif PY_VERSION_HEX < 0x030D0000 + _PyObject_ClearManagedDict(self); +#else + PyObject_ClearManagedDict(self); +#endif // PY_VERSION_HEX < 0x030C0000 + o->fun.ClearPythonReferences(); + return 0; +} + +// Implements the Python descriptor protocol so PMAP-compiled functions can be +// used as bound methods. 
See: +// https://docs.python.org/3/howto/descriptor.html#functions-and-methods +PyObject* JaxPmapFunction_tp_descr_get(PyObject* self, PyObject* obj, + PyObject* type) { + if (obj == nullptr || obj == Py_None) { + Py_INCREF(self); + return self; + } + return PyMethod_New(self, obj); +} + +static PyGetSetDef JaxPmapFunction_tp_getset[] = { + // Having a __dict__ seems necessary to allow !functool.wraps to override + // __doc__. + {const_cast("__dict__"), PyObject_GenericGetDict, + PyObject_GenericSetDict, nullptr, nullptr}, + {nullptr, nullptr, nullptr, nullptr, nullptr}}; + +PyMemberDef JaxPmapFunction_members[] = { + {"__vectorcalloffset__", T_PYSSIZET, + static_cast(offsetof(JaxPmapFunctionObject, vectorcall)), + READONLY, nullptr}, +#if PY_VERSION_HEX < 0x030C0000 + {"__dictoffset__", T_PYSSIZET, + static_cast(offsetof(JaxPmapFunctionObject, dict)), READONLY, + nullptr}, + {"__weaklistoffset__", T_PYSSIZET, + static_cast(offsetof(JaxPmapFunctionObject, weakrefs)), + READONLY, nullptr}, +#endif // PY_VERSION_HEX < 0x030C0000 + {nullptr, 0, 0, 0, nullptr}, +}; + +PyType_Slot JaxPmapFunction_slots[] = { + {Py_tp_new, reinterpret_cast(JaxPmapFunction_tp_new)}, + {Py_tp_dealloc, reinterpret_cast(JaxPmapFunction_tp_dealloc)}, + {Py_tp_traverse, reinterpret_cast(JaxPmapFunction_tp_traverse)}, + {Py_tp_clear, reinterpret_cast(JaxPmapFunction_tp_clear)}, + {Py_tp_getset, reinterpret_cast(JaxPmapFunction_tp_getset)}, + {Py_tp_descr_get, reinterpret_cast(JaxPmapFunction_tp_descr_get)}, + {Py_tp_call, reinterpret_cast(PyVectorcall_Call)}, + {Py_tp_members, reinterpret_cast(JaxPmapFunction_members)}, + {0, nullptr}, +}; + +} // extern "C" + +nb::object MakePmapFunction( + nb::callable fun, nb::callable cache_miss, std::vector static_argnums, + nb::callable python_shard_arg_fallback, + xla::nb_class_ptr pytree_registry) { + nb::object obj = nb::steal(JaxPmapFunction_tp_new( + reinterpret_cast(JaxPmapFunction_Type), nullptr, nullptr)); + JaxPmapFunctionObject* buf = + reinterpret_cast(obj.ptr()); + new (&buf->fun) PmapFunction( + std::move(fun), std::move(cache_miss), std::move(static_argnums), + std::move(python_shard_arg_fallback), std::move(pytree_registry)); + return obj; +} + +// Version numbers for the pickled representations. +// Increment these if changing them. 
+const int kPmapFunctionPickleVersion = 1; + +} // namespace + +void BuildPmapSubmodule(nb::module_& m) { + nb::module_ pmap_lib = m.def_submodule("pmap_lib", "Jax C++ pmap library"); + + nb::class_ no_sharding(pmap_lib, "NoSharding"); + no_sharding.def(nb::init<>()) + .def("__getstate__", + [](const NoSharding& self) { return nb::make_tuple(); }) + .def("__setstate__", + [](NoSharding& self, nb::tuple t) { new (&self) NoSharding(); }) + .def("__repr__", [](const NoSharding& self) { return "NoSharding()"; }) + .def("__eq__", + [](const NoSharding& self, nb::object obj) { + return nb::isinstance(obj); + }) + .def("__hash__", [](const NoSharding& self) { + const size_t hash = absl::HashOf(self); + return nb::int_(hash); + }); + + nb::class_ chunked(pmap_lib, "Chunked"); + chunked.def(nb::init>()) + .def("__getstate__", + [](const Chunked& self) { return nb::make_tuple(self.chunks); }) + .def("__setstate__", + [](Chunked& self, nb::tuple t) { + new (&self) Chunked{nb::cast>(t[0])}; + }) + .def_ro("chunks", &Chunked::chunks) + .def("__repr__", + [](const Chunked& self) { + return absl::StrCat("Chunked(", absl::StrJoin(self.chunks, ","), + ")"); + }) + .def("__eq__", [](const Chunked& self, nb::object other) { + if (!nb::isinstance(other)) { + return false; + } + return self == nb::cast(other); + }); + + nb::class_ unstacked(pmap_lib, "Unstacked"); + unstacked.def(nb::init()) + .def("__getstate__", + [](const Unstacked& self) { return nb::make_tuple(self.size); }) + .def("__setstate__", + [](Unstacked& self, nb::tuple t) { + new (&self) Unstacked{nb::cast(t[0])}; + }) + .def_ro("size", &Unstacked::size) + .def("__repr__", + [](const Unstacked& x) { + return absl::StrCat("Unstacked(", x.size, ")"); + }) + .def("__eq__", [](const Unstacked& self, nb::object other) { + if (!nb::isinstance(other)) { + return false; + } + return self == nb::cast(other); + }); + + nb::class_ sharded_axis(pmap_lib, "ShardedAxis"); + sharded_axis.def(nb::init()) + .def("__getstate__", + [](const ShardedAxis& self) { return nb::make_tuple(self.axis); }) + .def("__setstate__", + [](ShardedAxis& self, nb::tuple t) { + new (&self) ShardedAxis{nb::cast(t[0])}; + }) + .def_ro("axis", &ShardedAxis::axis) + .def("__repr__", + [](const ShardedAxis& x) { + return absl::StrCat("ShardedAxis(axis=", x.axis, ")"); + }) + .def("__eq__", [](const ShardedAxis& self, const ShardedAxis& other) { + return self == other; + }); + + nb::class_ replicated(pmap_lib, "Replicated"); + replicated.def(nb::init()) + .def("__getstate__", + [](const Replicated& self) { return nb::make_tuple(self.replicas); }) + .def("__setstate__", + [](Replicated& self, nb::tuple t) { + new (&self) Replicated{nb::cast(t[0])}; + }) + .def_ro("replicas", &Replicated::replicas) + .def("__repr__", + [](const Replicated& x) { + return absl::StrCat("Replicated(replicas=", x.replicas, ")"); + }) + .def("__eq__", [](const Replicated& self, const Replicated& other) { + return self == other; + }); + + nb::class_ sharding_spec(pmap_lib, "ShardingSpec"); + sharding_spec + .def(nb::init(), nb::arg("sharding"), + nb::arg("mesh_mapping")) + .def("__getstate__", + [](const ShardingSpec& self) { + auto sharding = + xla::SpanToNbTuple(absl::MakeConstSpan(self.GetSharding())); + auto mesh_mapping = + xla::SpanToNbTuple(absl::MakeConstSpan(self.GetMeshMapping())); + return nb::make_tuple(sharding, mesh_mapping); + }) + .def("__setstate__", + [](ShardingSpec& self, nb::tuple t) { + new (&self) + ShardingSpec{nb::cast>(t[0]), + nb::cast>(t[1])}; + }) + .def_prop_ro( + "sharding", + 
[](const ShardingSpec& self) { + return xla::SpanToNbTuple(absl::MakeConstSpan(self.GetSharding())); + }) + .def_prop_ro("mesh_mapping", + [](const ShardingSpec& self) { + return xla::SpanToNbTuple( + absl::MakeConstSpan(self.GetMeshMapping())); + }) + .def("__eq__", [](const ShardingSpec& self, + const ShardingSpec& other) { return self == other; }) + .def("__hash__", [](const ShardingSpec& self) { + const size_t hash = absl::HashOf(self); + return nb::int_(hash); + }); + + // We need to use heap-allocated type objects because we want to add + // additional methods dynamically. + + std::string name = + absl::StrCat(nb::cast(m.attr("__name__")), ".PmapFunction"); + PyType_Spec pmap_function_spec = { + /*.name=*/name.c_str(), + /*.basicsize=*/static_cast(sizeof(JaxPmapFunctionObject)), + /*.itemsize=*/0, +#if PY_VERSION_HEX < 0x030C0000 + /*.flags=*/Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC | + Py_TPFLAGS_HAVE_VECTORCALL, +#else // PY_VERSION_HEX >= 0x030C0000 + /*.flags=*/Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC | + Py_TPFLAGS_HAVE_VECTORCALL | Py_TPFLAGS_MANAGED_DICT | + Py_TPFLAGS_MANAGED_WEAKREF, +#endif // PY_VERSION_HEX >= 0x030C0000 + /*.slots=*/JaxPmapFunction_slots, + }; + + JaxPmapFunction_Type = PyType_FromSpec(&pmap_function_spec); + if (!JaxPmapFunction_Type) { + throw nb::python_error(); + } + nb::object cfun = nb::borrow(JaxPmapFunction_Type); + + // Add PmapFunction to the _jax module so it can be pickled. + m.attr("PmapFunction") = cfun; + + cfun.attr("__signature__") = + xla::nb_property_readonly([](nb::handle self) -> nb::object { + PmapFunction* fun = xla::ValueOrThrow(AsPmapFunction(self)); + return fun->PythonSignature(); + }); + // Required by `post_hook`. + cfun.attr("_cache_miss") = + xla::nb_property_readonly([](nb::handle self) -> nb::object { + PmapFunction* fun = xla::ValueOrThrow(AsPmapFunction(self)); + return fun->cache_miss(); + }); + cfun.attr("__getstate__") = nb::cpp_function( + [](const PmapFunction::object& self) { + PmapFunction* fn = self.func(); + nb::dict pickle; + pickle["version"] = kPmapFunctionPickleVersion; + pickle["fun"] = fn->fun(); + pickle["cache_miss"] = fn->cache_miss(); + pickle["static_argnums"] = fn->static_argnums(); + pickle["python_shard_arg_fallback"] = fn->python_shard_arg_fallback(); + pickle["pytree_registry"] = nb::cast(fn->pytree_registry()); + return pickle; + }, + nb::is_method()); + cfun.attr("__setstate__") = nb::cpp_function( + [](PmapFunction::object& self, const nb::dict& pickle) { + int version = nb::cast(pickle["version"]); + if (version != kPmapFunctionPickleVersion) { + throw std::invalid_argument(absl::StrFormat( + "Invalid PmapFunction pickle version, got %d, expected %d. " + "Pickling/Unpickling jitted functions using different JAX " + "versions is not supported.", + version, kPmapFunctionPickleVersion)); + } + nb::callable fun = nb::cast(pickle["fun"]); + nb::callable cache_miss = nb::cast(pickle["cache_miss"]); + std::vector static_argnums = + nb::cast>(pickle["static_argnums"]); + nb::callable python_shard_arg_fallback = + nb::cast(pickle["python_shard_arg_fallback"]); + xla::nb_class_ptr pytree_registry = + nb::cast>( + pickle["pytree_registry"]); + new (&(reinterpret_cast(self.ptr())->fun)) + PmapFunction(std::move(fun), std::move(cache_miss), + std::move(static_argnums), + std::move(python_shard_arg_fallback), + std::move(pytree_registry)); + }, + nb::is_method()); + + // This is only for testing/debugging purposes. 
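+  // These attributes are intended for tests; they are reachable through
+  // ordinary attribute access on a pmapped function object. A minimal usage
+  // sketch (assuming `pmapped` is an instance of the PmapFunction type built
+  // above):
+  //
+  //   int entries = nb::cast<int>(pmapped.attr("_cache_size"));
+  //   pmapped.attr("_cache_clear")();
+  //   std::string keys =
+  //       nb::cast<std::string>(pmapped.attr("_debug_cache_keys")());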
+ cfun.attr("_cache_size") = + xla::nb_property_readonly([](nb::handle self) -> nb::object { + PmapFunction* fun = xla::ValueOrThrow(AsPmapFunction(self)); + return nb::cast(fun->cache_size()); + }); + + cfun.attr("_cache_clear") = nb::cpp_function( + [](nb::handle self) { + PmapFunction* fun = xla::ValueOrThrow(AsPmapFunction(self)); + fun->cache_clear(); + }, + nb::is_method()); + + cfun.attr("_debug_cache_keys") = nb::cpp_function( + [](nb::handle self) -> std::string { + PmapFunction* fun = xla::ValueOrThrow(AsPmapFunction(self)); + return fun->DebugCacheKeys(); + }, + nb::is_method()); + + pmap_lib.def( + "pmap", + [](nb::callable fun, nb::callable cache_miss, + std::vector static_argnums, nb::callable shard_arg_fallback, + nb::object pytree_registry) -> nb::object { + xla::nb_class_ptr registry = + nb::cast>(pytree_registry); + return MakePmapFunction( + std::move(fun), std::move(cache_miss), std::move(static_argnums), + std::move(shard_arg_fallback), std::move(registry)); + }, + nb::arg("fun"), nb::arg("cache_miss"), nb::arg("static_argnums"), + nb::arg("shard_arg_fallback"), nb::arg("pytree_registry")); +} + +} // namespace jax diff --git a/jaxlib/pmap_lib.h b/jaxlib/pmap_lib.h new file mode 100644 index 000000000000..b7cc2cc13f36 --- /dev/null +++ b/jaxlib/pmap_lib.h @@ -0,0 +1,34 @@ +/* Copyright 2021 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_PMAP_LIB_H_ +#define JAXLIB_PMAP_LIB_H_ + + +// placeholder for index annotation headers +#include "nanobind/nanobind.h" + +// TODO(jblespiau): The current implementation moves the Python logic to C++, +// as a preliminary step to executing the `pmap` execution path from C++. +// It implements the current Python behavior (thus, it may not be optimal, and +// we will be able to modify it later). + +namespace jax { + +void BuildPmapSubmodule(nanobind::module_& m); + +} // namespace jax + +#endif // JAXLIB_PMAP_LIB_H_ diff --git a/jaxlib/py_array.cc b/jaxlib/py_array.cc new file mode 100644 index 000000000000..d7ff7ee6e3f7 --- /dev/null +++ b/jaxlib/py_array.cc @@ -0,0 +1,2141 @@ +/* Copyright 2022 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "jaxlib/py_array.h" + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include // NOLINT +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/base/casts.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/container/inlined_vector.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "llvm/Support/Casting.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/optional.h" // IWYU pragma: keep +#include "nanobind/stl/pair.h" // IWYU pragma: keep +#include "nanobind/stl/shared_ptr.h" // IWYU pragma: keep +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "nanobind/stl/string_view.h" // IWYU pragma: keep +#include "nanobind/stl/unique_ptr.h" // IWYU pragma: keep +#include "nanobind/stl/vector.h" // IWYU pragma: keep +#include "jaxlib/guard_lib.h" +#include "jaxlib/nb_class_ptr.h" +#include "jaxlib/py_client.h" +#include "jaxlib/py_device.h" +#include "jaxlib/py_device_list.h" +#include "jaxlib/py_values.h" +#include "jaxlib/python_ref_manager.h" +#include "jaxlib/sharding.h" +#include "jaxlib/to_ifrt_sharding.h" +#include "jaxlib/traceback.h" +#include "jaxlib/util.h" +#include "xla/layout.h" +#include "xla/layout_util.h" +#include "xla/pjrt/exceptions.h" +#include "xla/pjrt/lru_cache.h" +#include "xla/pjrt/pjrt_client.h" +#include "xla/pjrt/pjrt_compiler.h" +#include "xla/pjrt/pjrt_future.h" +#include "xla/pjrt/status_casters.h" +#include "xla/primitive_util.h" +#include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/array_spec.h" +#include "xla/python/ifrt/device.h" +#include "xla/python/ifrt/device_list.h" +#include "xla/python/ifrt/dtype.h" +#include "xla/python/ifrt/future.h" +#include "xla/python/ifrt/memory.h" +#include "xla/python/ifrt/remap_plan.h" +#include "xla/python/ifrt/shape.h" +#include "xla/python/ifrt/sharding.h" +#include "xla/python/nb_absl_span.h" // IWYU pragma: keep +#include "xla/python/nb_helpers.h" +#include "xla/python/nb_numpy.h" +#include "xla/python/pjrt_ifrt/pjrt_array.h" +#include "xla/python/pjrt_ifrt/pjrt_client.h" +#include "xla/python/pjrt_ifrt/pjrt_device.h" +#include "xla/python/pjrt_ifrt/pjrt_dtype.h" +#include "xla/python/safe_static_init.h" +#include "xla/python/types.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/status_macros.h" +#include "xla/tsl/concurrency/ref_count.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/tsl/python/lib/core/numpy.h" // IWYU pragma: keep +#include "xla/util.h" +#include "xla/xla_data.pb.h" +#include "tsl/profiler/lib/traceme.h" + +namespace xla { +namespace { + +namespace nb = nanobind; + +PjRtBuffer* GetPjrtBuffer(ifrt::Array* ifrt_array) { + auto* arr = llvm::dyn_cast_or_null(ifrt_array); + if (arr == nullptr) { + throw XlaRuntimeError( + "This operation is implemented for a PjRt-compatible backend only."); + } + return arr->pjrt_buffers().front().get(); +} + +absl::StatusOr XlaDynamicShape(ifrt::Array* ifrt_array, + std::optional& scratch) { + auto* pjrt_buffer = 
GetPjrtBuffer(ifrt_array); + + if (!scratch) { + absl::Span dims; + std::optional> logical_dims_storage; + if (pjrt_buffer->has_dynamic_dimensions()) { + { + nb::gil_scoped_release gil_release; + TF_ASSIGN_OR_RETURN(std::vector logical_dims, + pjrt_buffer->logical_dimensions()); + logical_dims_storage.emplace(std::move(logical_dims)); + } + dims = *logical_dims_storage; + } else { + dims = pjrt_buffer->dimensions(); + } + Shape shape = ShapeUtil::MakeShape(pjrt_buffer->element_type(), dims); + // TODO(b/327524065): fix this + *shape.mutable_layout() = pjrt_buffer->layout()->xla_layout(); + scratch = std::move(shape); + } + return &scratch.value(); +} + +ifrt::ArrayRef CreateIfRtArrayFromSingleDeviceShardedPyArrays( + nb_dtype dtype, absl::Span shape, + absl::Span py_arrays, const nb::object& sharding) { + const ifrt::MemoryKind dst_memory_kind = xla::GetMemoryKind(sharding); + + std::vector ifrt_arrays; + ifrt_arrays.reserve(py_arrays.size()); + absl::InlinedVector devices; + devices.reserve(py_arrays.size()); + absl::flat_hash_set device_set; + device_set.reserve(py_arrays.size()); + std::vector shapes; + shapes.reserve(py_arrays.size()); + + auto sharding_device_list = xla::GetIfrtDeviceList(sharding); + if (!sharding_device_list.ok()) { + // TODO(hyeontaek): Return a absl::Status. + throw nb::value_error(sharding_device_list.status().ToString().c_str()); + } + ifrt::Device* device = sharding_device_list.value()->devices().front(); + + // TODO(hyeontaek): Canonicalize every `ifrt::MemoryKind` at creation time to + // skip canonicalization here once JAX begins to do it for JAX shardings. + const ifrt::MemoryKind canonical_dst_memory_kind = + ifrt::CanonicalizeMemoryKind(dst_memory_kind, device); + for (const auto& py_array : py_arrays) { + if (py_array.num_shards() != 1) { + throw nb::value_error( + absl::StrFormat( + "When making an array from single-device arrays the input arrays " + "must have one shard each. An argument array had %d shard(s).", + py_array.num_shards()) + .c_str()); + } + ifrt_arrays.push_back(tsl::FormRef(py_array.ifrt_array())); + ifrt::Device* const device = + ifrt_arrays.back()->sharding().devices()->devices().front(); + devices.push_back(device); + device_set.insert(device); + shapes.push_back(ifrt_arrays.back()->shape()); + if (canonical_dst_memory_kind != + ifrt::CanonicalizeMemoryKind( + ifrt_arrays.back()->sharding().memory_kind(), device)) { + throw nb::value_error( + absl::StrFormat( + "Memory kind mismatch with PjRtBuffers. Got sharding with " + "memory kind '%v' and a buffer with memory_kind '%v'", + dst_memory_kind, ifrt_arrays.back()->sharding().memory_kind()) + .c_str()); + } + } + ifrt::DeviceListRef device_list = device->client()->MakeDeviceList(devices); + if (device_set.size() != device_list->size()) { + throw nb::value_error( + absl::StrFormat( + "When making an array from single-device arrays, the input arrays " + "must be from distinct devices, but got %v", + *device_list) + .c_str()); + } + + auto ifrt_dtype = DtypeToIfRtDType(dtype); + if (!ifrt_dtype.ok()) { + // TODO(hyeontaek): Return a absl::Status. + throw nb::value_error(ifrt_dtype.status().ToString().c_str()); + } + + absl::StatusOr ifrt_sharding = + sharding.type().is(jax::PmapSharding::type()) + ? xla::GetIfrtConcreteSharding(sharding, ifrt::Shape(shape), + std::move(shapes)) + : xla::GetIfrtHloSharding(sharding, ifrt::Shape(shape)); + if (!ifrt_sharding.ok()) { + // TODO(hyeontaek): Return a absl::Status. 
+ throw nb::value_error(ifrt_sharding.status().ToString().c_str()); + } + // TODO(emilyaf): Always use `ifrt_dtype` once tokens are handled correctly. + ifrt::DType array_dtype = + ifrt_arrays.empty() ? ifrt_dtype.value() : ifrt_arrays[0]->dtype(); + absl::StatusOr ifrt_array = + device->client()->AssembleArrayFromSingleDeviceArrays( + array_dtype, ifrt::Shape(shape), *std::move(ifrt_sharding), + absl::MakeSpan(ifrt_arrays), ifrt::ArrayCopySemantics::kReuseInput, + ifrt::SingleDeviceShardSemantics::kAddressableShards); + if (!ifrt_array.ok()) { + // TODO(hyeontaek): Return a absl::Status. + throw nb::value_error(ifrt_array.status().ToString().c_str()); + } + return *std::move(ifrt_array); +} + +struct PyBaseArrayObject { + PyObject_HEAD; +#if PY_VERSION_HEX < 0x030C0000 + PyObject* weakrefs; +#endif // PY_VERSION_HEX < 0x030C0000 +}; + +extern "C" void PyBaseArray_tp_dealloc(PyBaseArrayObject* self) { + PyObject_GC_UnTrack(self); + PyObject_ClearWeakRefs((PyObject*)self); + PyTypeObject* tp = Py_TYPE(self); + tp->tp_free((PyObject*)self); + Py_DECREF(tp); +} + +extern "C" int PyBaseArray_tp_traverse(PyObject* self, visitproc visit, + void* arg) { + Py_VISIT(Py_TYPE(self)); + return 0; +} + +struct PyArrayObject { + PyObject_HEAD; +#if PY_VERSION_HEX < 0x030C0000 + PyObject* weakrefs; + PyObject* dict; +#endif // PY_VERSION_HEX < 0x030C0000 + bool initialized; + alignas(PyArray::Storage) char array_storage[sizeof(PyArray::Storage)]; +}; +static_assert(std::is_standard_layout::value); + +PyArray::Storage* GetPyArrayStorageFromObject(PyArrayObject* py_array_object) { + return std::launder( + reinterpret_cast(py_array_object->array_storage)); +} + +extern "C" PyObject* PyArray_tp_new(PyTypeObject* type, PyObject*, PyObject*) { + PyObject* self = type->tp_alloc(type, 0); + auto* obj = reinterpret_cast(self); + obj->initialized = false; + return self; +} + +extern "C" void PyArray_tp_dealloc(PyObject* self) { + PyObject_GC_UnTrack(self); + PyTypeObject* tp = Py_TYPE(self); + auto* obj = reinterpret_cast(self); + + if (obj->initialized) { + GetPyArrayStorageFromObject(obj)->~PyArray_Storage(); + } + + PyObject_ClearWeakRefs(self); +#if PY_VERSION_HEX < 0x030C0000 + PyObject*& dict = *_PyObject_GetDictPtr(self); + Py_CLEAR(dict); +#elif PY_VERSION_HEX < 0x030D0000 + _PyObject_ClearManagedDict(self); +#else + PyObject_ClearManagedDict(self); +#endif // PY_VERSION_HEX < 0x030C0000 + + tp->tp_free(self); + Py_DECREF(tp); +} + +// dynamic_attr: Allow the garbage collector to traverse the internal instance +// `__dict__`. +extern "C" int PyArray_tp_traverse(PyObject* self, visitproc visit, void* arg) { +#if PY_VERSION_HEX < 0x030C0000 + PyObject*& dict = *_PyObject_GetDictPtr(self); + Py_VISIT(dict); +#elif PY_VERSION_HEX < 0x030D0000 + _PyObject_VisitManagedDict(self, visit, arg); +#else + PyObject_VisitManagedDict(self, visit, arg); +#endif // PY_VERSION_HEX < 0x030C0000 + // https://docs.python.org/3/c-api/typeobj.html#c.PyTypeObject.tp_traverse + Py_VISIT(Py_TYPE(self)); + return 0; +} + +// dynamic_attr: Allow the GC to clear the dictionary. 
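+// tp_clear is typically reached only when CPython's cycle collector breaks a
+// reference cycle that runs through a jax.Array's instance __dict__; plain
+// reference counting deletes arrays through tp_dealloc instead. A rough Python
+// sketch of a cycle that would route deletion through this slot (illustrative
+// only, and it assumes attribute assignment on jax.Array is permitted in the
+// running version):
+//
+//   import gc
+//   import jax.numpy as jnp
+//   x = jnp.ones((8,))
+//   x.ref = [x]        # the array's __dict__ now points back at the array
+//   del x              # the cycle keeps the refcount above zero
+//   gc.collect()       # the collector clears the cycle via tp_clear
+//
+// The guard checked below decides whether such collector-driven deletion is
+// silently allowed, logged with the creation traceback, or fatal.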
+extern "C" int PyArray_tp_clear(PyObject* self) { + switch (auto guard_level = jax::GetGarbageCollectArrayGuard(); guard_level) { + case jax::GarbageCollectionGuardLevel::kAllow: + break; + case jax::GarbageCollectionGuardLevel::kLog: + case jax::GarbageCollectionGuardLevel::kFatal: { + auto* obj = reinterpret_cast(self); + std::string traceback_str; + if (obj->initialized) { + auto traceback = GetPyArrayStorageFromObject(obj)->traceback; + if (traceback.has_value()) { + traceback_str = traceback.value().ToString(); + } + } + auto error_msg = absl::StrCat( + "`jax.Array` was deleted by the Python garbage collector " + "instead of reference counting. Break the reference cycle " + "that delays the deletion of this `jax.Array` to avoid hogging " + "memory. Traceback: \n", + traceback_str.empty() ? "not available" : traceback_str); + if (guard_level == jax::GarbageCollectionGuardLevel::kFatal) { + Py_FatalError(error_msg.c_str()); + } else { + PyErr_SetString(PyExc_RuntimeError, error_msg.c_str()); + PyErr_Print(); + PyErr_Clear(); + } + break; + } + } +#if PY_VERSION_HEX < 0x030C0000 + PyObject*& dict = *_PyObject_GetDictPtr(self); + Py_CLEAR(dict); +#elif PY_VERSION_HEX < 0x030D0000 + _PyObject_ClearManagedDict(self); +#else + PyObject_ClearManagedDict(self); +#endif // PY_VERSION_HEX < 0x030C0000 + return 0; +} + +template +PyArray::Storage* Construct(PyArrayObject* self, Args&&... args) { + PyArray::Storage* out = + new (self->array_storage) PyArray::Storage(std::forward(args)...); + self->initialized = true; + return out; +} + +struct ShapedArrayCacheKey { + std::vector dims; + ifrt::DType dtype{ifrt::DType::kInvalid}; + bool weak_type; + + template + friend H AbslHashValue(H h, const ShapedArrayCacheKey& value) { + return H::combine(std::move(h), value.dims, value.dtype, value.weak_type); + } + bool operator==(const ShapedArrayCacheKey& other) const { + return dims == other.dims && dtype == other.dtype && + weak_type == other.weak_type; + } +}; + +// Constructing ShapedArrays has gotten slow. Cache it. +nb::object MakeShapedArrayCached(const ShapedArrayCacheKey& key) { + using CacheT = + LRUCache>>; + static nb::ft_mutex mu; + static auto* lru_list = new CacheT::LRUList(4096); + static auto* cache = new CacheT(lru_list); + + const nb::object& shaped_array = SafeStaticInit([]() { + nb::object jax_core; + try { + jax_core = nb::module_::import_("jax.core"); + } catch (nb::python_error& e) { + return std::make_unique(); + } + return std::make_unique(jax_core.attr("ShapedArray")); + }); + if (!shaped_array.ptr()) { + return nb::none(); + } + + nb::ft_lock_guard lock(mu); + auto value = + cache->GetOrCreateIfAbsent(key, [](const ShapedArrayCacheKey& key) { + return std::make_shared>(); + }); + + if (!value->has_value()) { + nb_dtype dtype = + IfrtDtypeToDtypeWithTokenCanonicalization(key.dtype).value(); + nb::object aval = shaped_array( + SpanToNbTuple(absl::Span( + key.dtype.kind() == ifrt::DType::kToken ? std::vector{0} + : key.dims)), + dtype, key.weak_type); + *value = aval; + return aval; + } + return **value; +} + +// Grouping key used by BatchedCopyToDeviceWithSharding. +// Defined outside of the function as required by templatized function +// `AbslHashValue`. 
+struct BatchedCopyToDeviceWithShardingKey { + ifrt::DeviceListRef src_devices; + ifrt::MemoryKind src_memory_kind; + ifrt::DeviceListRef dst_devices; + ifrt::MemoryKind dst_memory_kind; + ifrt::ArrayCopySemantics array_copy_semantics; + + bool operator==(const BatchedCopyToDeviceWithShardingKey& other) const { + return *src_devices == *other.src_devices && + src_memory_kind == other.src_memory_kind && + *dst_devices == *other.dst_devices && + dst_memory_kind == other.dst_memory_kind && + array_copy_semantics == other.array_copy_semantics; + } + + template + friend H AbslHashValue(H h, const BatchedCopyToDeviceWithShardingKey& key) { + return H::combine(std::move(h), key.src_devices, key.src_memory_kind, + key.dst_devices, key.dst_memory_kind, + key.array_copy_semantics); + } +}; + +} // namespace + +PyArray_Storage::PyArray_Storage( + nb::object aval, bool weak_type, xla::nb_dtype dtype, + std::vector shape, nb::object sharding, bool committed, + nb_class_ptr py_client, std::optional traceback, + ifrt::ArrayRef ifrt_array, xla::PjRtFuture<> result_status) + : aval(std::move(aval)), + weak_type(weak_type), + dtype(std::move(dtype)), + shape(std::move(shape)), + sharding(std::move(sharding)), + committed(committed), + py_client(std::move(py_client)), + traceback(std::move(traceback)), + ifrt_array(std::move(ifrt_array)), + result_status(std::move(result_status)) { + static_assert(PyClient::kNumArraysShards < + std::numeric_limits::max()); + thread_id_bucket = std::hash()(std::this_thread::get_id()) % + PyClient::kNumArraysShards; + + PyClient::ArraysShard& shard = this->py_client->arrays_[thread_id_bucket]; + nanobind::ft_lock_guard lock(shard.mutex); + next = shard.arrays; + shard.arrays = this; + if (next) { + next->prev = this; + } + prev = nullptr; +} + +void PyInit_helper(PyArray self, nb::object aval, nb::object sharding, + absl::Span py_arrays, bool committed) { + auto dtype = nb::cast(aval.attr("dtype")); + auto shape = nb::cast>(aval.attr("shape")); + auto py_device_list = nb::cast( + sharding.attr("_internal_device_list")); + nb_class_ptr py_client = py_device_list->py_client(); + auto ifrt_array = CreateIfRtArrayFromSingleDeviceShardedPyArrays( + dtype, shape, py_arrays, sharding); + Construct(reinterpret_cast(self.ptr()), aval, + nb::cast(aval.attr("weak_type")), std::move(dtype), + std::move(shape), std::move(sharding), committed, py_client, + Traceback::Get(), std::move(ifrt_array), xla::PjRtFuture<>()); +} + +void PyArray::PyInit(PyArray self, nb::object aval, nb::object sharding, + absl::Span py_arrays, bool committed, + bool skip_checks) { + if (skip_checks) { + PyInit_helper(self, aval, sharding, py_arrays, committed); + } else { + nb::object rearranged_arrays = + self.CheckAndRearrange(py_arrays, sharding, aval); + auto rearranged_py_arrays = + nb::cast>(rearranged_arrays); + PyInit_helper(self, aval, sharding, rearranged_py_arrays, committed); + } +} + +PyArray PyArray::MakeFromSingleDeviceArray(nb_class_ptr py_client, + std::optional traceback, + ifrt::ArrayRef ifrt_array, + bool weak_type, bool committed, + xla::PjRtFuture<> result_status) { + if (!llvm::isa(ifrt_array->sharding())) { + throw XlaRuntimeError( + InvalidArgument("Constructing single device jax.Array from non-single " + "device ifrt array.")); + } + auto shape_span = ifrt_array->shape().dims(); + ShapedArrayCacheKey key; + key.dtype = ifrt_array->dtype(); + key.dims = key.dtype.kind() == ifrt::DType::kToken + ? 
std::vector{0} + : std::vector(shape_span.begin(), shape_span.end()); + key.weak_type = weak_type; + auto aval = MakeShapedArrayCached(key); + auto dtype = IfrtDtypeToDtypeWithTokenCanonicalization(key.dtype).value(); + const ifrt::MemoryKind memory_kind = ifrt_array->sharding().memory_kind(); + nb::object py_memory_kind = + (memory_kind.memory_kind().has_value()) + ? nb::object(nb::str(memory_kind.memory_kind()->data(), + memory_kind.memory_kind()->size())) + : nb::none(); + nb::object sharding = make_nb_class( + py_client, ifrt_array->sharding().devices(), std::move(py_memory_kind)); + return PyArray(std::move(aval), weak_type, dtype, std::move(key.dims), + std::move(sharding), std::move(py_client), + std::move(traceback), std::move(ifrt_array), committed, + /*skip_checks=*/true, std::move(result_status)); +} + +PyArray PyArray::MakeFromIfrtArrayAndSharding( + nb_class_ptr py_client, std::optional traceback, + ifrt::ArrayRef ifrt_array, nb::object sharding, bool weak_type, + bool committed, bool skip_checks) { + auto shape_span = ifrt_array->shape().dims(); + ShapedArrayCacheKey key; + key.dtype = ifrt_array->dtype(); + key.dims = key.dtype.kind() == ifrt::DType::kToken + ? std::vector{0} + : std::vector(shape_span.begin(), shape_span.end()); + key.weak_type = weak_type; + auto aval = MakeShapedArrayCached(key); + auto dtype = IfrtDtypeToDtypeWithTokenCanonicalization(key.dtype).value(); + return PyArray(std::move(aval), weak_type, dtype, std::move(key.dims), + std::move(sharding), std::move(py_client), + std::move(traceback), std::move(ifrt_array), committed, + skip_checks); +} + +PyArrayResultHandler::PyArrayResultHandler(nb::object aval, nb::object sharding, + bool committed, bool skip_checks) + : aval_(std::move(aval)), + sharding_(std::move(sharding)), + committed_(committed), + skip_checks_(skip_checks) { + weak_type_ = nb::cast(aval_.attr("weak_type")); + dtype_ = nb::cast(aval_.attr("dtype")); + shape_ = nb::cast>(aval_.attr("shape")); +} + +PyArray PyArrayResultHandler::Call(absl::Span py_arrays) const { + auto py_device_list = jax::GetPyDeviceList(sharding_); + if (!py_device_list.ok()) { + throw nb::value_error( + absl::StrCat("Failed to get py device list from sharding: ", + py_device_list.status().ToString()) + .c_str()); + } + return Call(py_device_list.value()->py_client(), + CreateIfRtArrayFromSingleDeviceShardedPyArrays( + dtype_, shape_, py_arrays, sharding_), + xla::PjRtFuture<>()); +} + +PyArray PyArrayResultHandler::Call(nb_class_ptr py_client, + ifrt::ArrayRef ifrt_array, + xla::PjRtFuture<> result_status) const { + return PyArray(aval_, weak_type_, dtype_, shape_, sharding_, + std::move(py_client), Traceback::Get(), std::move(ifrt_array), + committed_, skip_checks_, std::move(result_status)); +} + +PyArray PyArrayResultHandler::Call(PyArray py_array) const { + return Call(py_array.py_client(), tsl::FormRef(py_array.ifrt_array()), + xla::PjRtFuture<>()); +} + +PyArray::PyArray(nb::object aval, bool weak_type, nb_dtype dtype, + std::vector shape, nb::object sharding, + nb_class_ptr py_client, + std::optional traceback, ifrt::ArrayRef ifrt_array, + bool committed, bool skip_checks, + xla::PjRtFuture<> result_status) { + auto* self = + PyArray_tp_new(reinterpret_cast(type_), nullptr, nullptr); + m_ptr = self; + Construct(reinterpret_cast(self), std::move(aval), weak_type, + std::move(dtype), std::move(shape), std::move(sharding), committed, + std::move(py_client), std::move(traceback), std::move(ifrt_array), + std::move(result_status)); + + if (!skip_checks) { + 
this->attr("_arrays") = this->attr("_check_and_rearrange")( + this->attr("_arrays"), this->attr("_sharding"), this->attr("aval")); + } +} + +PyArray::Storage& PyArray::GetStorage() { + return *GetPyArrayStorageFromObject(reinterpret_cast(ptr())); +} + +const PyArray::Storage& PyArray::GetStorage() const { + return *GetPyArrayStorageFromObject(reinterpret_cast(ptr())); +} + +nb::object PyArray::CheckAndRearrange(const absl::Span py_arrays, + const nb::object sharding, + const nb::object aval) { + return this->attr("_check_and_rearrange")(py_arrays, sharding, aval); +} + +void PyArray::SetIfrtArray(ifrt::ArrayRef ifrt_array) { + GetStorage().ifrt_array = std::move(ifrt_array); +} + +const std::vector& PyArray::py_arrays_cached() { + auto& py_arrays = this->py_arrays(); + + if (py_arrays.empty()) { + auto ifrt_arrays = ifrt_array()->DisassembleIntoSingleDeviceArrays( + ifrt::ArrayCopySemantics::kReuseInput, + ifrt::SingleDeviceShardSemantics::kAddressableShards); + if (!ifrt_arrays.ok()) { + throw nb::value_error( + absl::StrCat("Failed to disassemble into single-device arrays: ", + ifrt_arrays.status().ToString()) + .c_str()); + } + py_arrays.reserve(ifrt_arrays->size()); + for (auto& ifrt_array : *ifrt_arrays) { + py_arrays.push_back(PyArray::MakeFromSingleDeviceArray( + py_client(), traceback(), std::move(ifrt_array), weak_type(), + committed(), result_status())); + } + } + + return py_arrays; +} + +nb::object PyArray::arrays() { + // For performance, we only keep pjrt buffers by default. But on python side + // "_arrays" returns PyArrays instead, and subsequent calls to "_arrays" + // should return the same PyArrays (to avoid duplicate device to host + // transfers). So we create PyArrays the first time it is called and reuse + // them later. + if (ifrt_array() == nullptr || ifrt_array()->IsDeleted()) return nb::none(); + + if (llvm::isa(&ifrt_array()->sharding())) { + std::vector py_arrays; + py_arrays.push_back(*this); + return nb::cast(py_arrays); + } + + return nb::cast(py_arrays_cached()); +} + +absl::Status PyArray::set_arrays(nb::object obj) { + if (obj.is_none()) { + SetIfrtArray(ifrt::ArrayRef()); + py_arrays().clear(); + return absl::OkStatus(); + } + + if (!nb::isinstance(obj)) { + return InvalidArgument("Unsupported arg when setting Array._arrays: %s", + nb::cast(nb::str(obj.type()))); + } + + nb::list list(obj); + + if (list.size() == 0) return absl::OkStatus(); + + SetIfrtArray(ifrt::ArrayRef()); + py_arrays().clear(); + std::vector ifrt_arrays; + ifrt_arrays.reserve(list.size()); + absl::InlinedVector devices; + devices.reserve(list.size()); + std::vector shapes; + shapes.reserve(list.size()); + for (nb::handle obj : list) { + if (obj.type().is(PyArray::type())) { + auto py_array = nb::borrow(obj); + if (py_array.py_client().get() != py_client().get()) { + return InvalidArgument("Client mismatch when assigning to _arrays."); + } + if (py_array.num_shards() != 1) { + return InvalidArgument("Wrong number of shards: %d", + py_array.num_shards()); + } + ifrt_arrays.push_back(tsl::FormRef(py_array.ifrt_array())); + devices.push_back( + ifrt_arrays.back()->sharding().devices()->devices().front()); + shapes.push_back(ifrt_arrays.back()->shape()); + } else { + return InvalidArgument("Unsupported arg when setting Array._arrays: %s", + nb::cast(nb::str(obj.type()))); + } + } + const ifrt::MemoryKind first_memory_kind = + ifrt_arrays.front()->sharding().memory_kind(); + // TODO(hyeontaek): Canonicalize every `ifrt::MemoryKind` at creation time to + // skip canonicalization here once 
JAX begins to do it for JAX shardings. + const ifrt::MemoryKind canonical_first_memory_kind = + ifrt::CanonicalizeMemoryKind( + first_memory_kind, + ifrt_arrays.front()->sharding().devices()->devices().front()); + for (const auto& ifrt_array : ifrt_arrays) { + if (canonical_first_memory_kind != + ifrt::CanonicalizeMemoryKind( + ifrt_array->sharding().memory_kind(), + ifrt_array->sharding().devices()->devices().front())) { + throw nb::value_error( + absl::StrFormat( + "Memory kind mismatch between single-device arrays. Got one " + "array with memory kind '%v' and another with memory_kind '%v'", + first_memory_kind, ifrt_array->sharding().memory_kind()) + .c_str()); + } + } + + TF_ASSIGN_OR_RETURN( + auto ifrt_sharding, + sharding().type().is(jax::PmapSharding::type()) + ? xla::GetIfrtConcreteSharding(sharding(), ifrt::Shape(shape()), + std::move(shapes)) + : xla::GetIfrtHloSharding(sharding(), ifrt::Shape(shape()))); + TF_ASSIGN_OR_RETURN( + auto array, + py_client()->ifrt_client()->AssembleArrayFromSingleDeviceArrays( + ifrt::Shape(shape()), std::move(ifrt_sharding), + absl::MakeSpan(ifrt_arrays), ifrt::ArrayCopySemantics::kReuseInput, + ifrt::SingleDeviceShardSemantics::kAddressableShards)); + SetIfrtArray(std::move(array)); + return absl::OkStatus(); +} + +absl::StatusOr PyArray::FullyReplicatedShard() { + auto& cached = GetStorage().fully_replicated_array; + if (!cached.is_none()) { + return nb::cast(cached); + } + + if (ifrt_array() == nullptr) { + return InvalidArgument( + "FullyReplicatedShard() called on deleted or donated buffer"); + } + + TF_ASSIGN_OR_RETURN(auto fully_replicated_ifrt_shard, + ifrt_array()->FullyReplicatedShard( + ifrt::ArrayCopySemantics::kReuseInput)); + auto array = MakeFromSingleDeviceArray( + py_client(), traceback(), std::move(fully_replicated_ifrt_shard), + weak_type(), committed(), result_status()); + cached = array; + return nb::cast(cached); +} + +absl::Status PyArray::BlockUntilReady() const { + nb::gil_scoped_release gil_release; + if (ifrt_array() == nullptr) { + return InvalidArgument( + "BlockHostUntilReady() called on deleted or donated buffer"); + } + ifrt::Array* ifrt_array = this->ifrt_array(); + return AwaitBuffersReady(absl::MakeConstSpan(&ifrt_array, 1)); +} + +absl::StatusOr PyArray::GetOnDeviceSizeInBytes() { + if (ifrt_array() == nullptr) { + return InvalidArgument( + "GetOnDeviceSizeInBytes() called on deleted or donated buffer"); + } + + TF_ASSIGN_OR_RETURN(size_t shard_size, + GetPjrtBuffer(ifrt_array())->GetOnDeviceSizeInBytes()); + return shard_size * nb::len(nb::object(sharding().attr("device_set"))); +} + +absl::Status PyArray::BlockUntilResultStatusIsReady() { + auto& result_status = GetStorage().result_status; + // If the result_status future is not valid, this result did not come directly + // from a computation that returns tokens, so we don't wait for the status. + if (!result_status.IsValid()) { + return absl::OkStatus(); + } + if (!result_status.IsReady()) { + // Only release the gil if we need to Await(). 
+ nb::gil_scoped_release release_gil; + BlockUntilReadyWithCancel(result_status); + return result_status.Await(); + } + return result_status.Await(); +} + +absl::StatusOr> +PyArray::SingleDeviceArrayToNumpyArrayDidCopy() { + TF_ASSIGN_OR_RETURN(auto arr, FullyReplicatedShard()); + auto result = arr.GetStorage().host_value.AsNumPyArray( + arr.GetStorage().dynamic_shape, arr.ifrt_array()); + TF_RETURN_IF_ERROR(arr.BlockUntilResultStatusIsReady()); + return result; +} + +absl::StatusOr PyArray::SingleDeviceArrayToNumpyArray() { + TF_ASSIGN_OR_RETURN(auto result, SingleDeviceArrayToNumpyArrayDidCopy()); + return result.first; +} + +absl::Status PyArray::CopySingleDeviceArrayToHostAsync() { + TF_ASSIGN_OR_RETURN(auto arr, FullyReplicatedShard()); + return arr.GetStorage().host_value.CopyToHostAsync( + arr.GetStorage().dynamic_shape, arr.ifrt_array()); +} + +absl::StatusOr PyArray::AssertUnsharded(absl::string_view api) { + if (ifrt_array() == nullptr) { + return InvalidArgument("%s( called on deleted or donated buffer", api); + } + + if (llvm::isa(&ifrt_array()->sharding())) { + return *this; + } + + auto& py_arrays = py_arrays_cached(); + if (py_arrays.size() != 1) { + return InvalidArgument("%s() is supported only for unsharded arrays.", api); + } + return py_arrays[0]; +} + +absl::StatusOr PyArray::UnsafeBufferPointer() { + TF_ASSIGN_OR_RETURN(auto arr, AssertUnsharded("UnsafeBufferPointer")); + + return py_client()->pjrt_client()->UnsafeBufferPointer( + GetPjrtBuffer(arr.ifrt_array())); +} + +nb::dict PyArray::CudaArrayInterface() { + auto arr_or_error = AssertUnsharded("UnsafeBufferPointer"); + if (!arr_or_error.ok()) { + throw nb::attribute_error( + "__cuda_array_interface__ is only supported for unsharded arrays."); + } + auto arr = *arr_or_error; + + ifrt::Array* ifrt_array = arr.ifrt_array(); + std::optional& scratch = arr.GetStorage().dynamic_shape; + auto* pjrt_buffer = GetPjrtBuffer(ifrt_array); + if (pjrt_buffer->client()->platform_id() != CudaId()) { + throw nb::attribute_error( + "__cuda_array_interface__ is only defined for NVidia GPU buffers."); + } + if (pjrt_buffer->IsTuple()) { + throw nb::attribute_error( + "__cuda_array_interface__ is only defined for array buffers."); + } + + switch (pjrt_buffer->element_type()) { + case PrimitiveType::PRED: + case PrimitiveType::S8: + case PrimitiveType::S16: + case PrimitiveType::S32: + case PrimitiveType::S64: + case PrimitiveType::U8: + case PrimitiveType::U16: + case PrimitiveType::U32: + case PrimitiveType::U64: + case PrimitiveType::F16: + case PrimitiveType::F32: + case PrimitiveType::F64: + case PrimitiveType::C64: + case PrimitiveType::C128: + break; + + default: + throw nb::attribute_error( + absl::StrFormat( + "__cuda_array_interface__ is not supported for %s buffers.", + PrimitiveType_Name(pjrt_buffer->element_type())) + .c_str()); + } + + nb::str typestr = + ValueOrThrow(TypeDescriptorForPrimitiveType(pjrt_buffer->element_type())); + + // TODO(b/327524065): use PjRtLayout directly instead of xla::Layout + Layout xla_layout = pjrt_buffer->layout()->xla_layout(); + if (!LayoutUtil::IsMonotonicWithDim0Major(xla_layout)) { + throw nb::attribute_error( + "__cuda_array_interface__ is only currently supported for " + "buffers in row-major order."); + } + + nb::dict result; + const auto* dynamic_shape = + ValueOrThrow(XlaDynamicShape(ifrt_array, scratch)); + result["shape"] = SpanToNbTuple(dynamic_shape->dimensions()); + result["typestr"] = std::move(typestr); + std::unique_ptr external_reference_hold = + 
ValueOrThrow(pjrt_buffer->AcquireExternalReference()); + const void* root_ptr = + external_reference_hold->OpaqueDeviceMemoryDataPointer(); + nb::tuple data = + nb::make_tuple(nb::int_(absl::bit_cast(root_ptr)), + nb::bool_(true) /* read-only */ + ); + result["data"] = std::move(data); + result["version"] = nb::int_(2); + return result; +} + +absl::StatusOr CudaArrayInterfaceToBuffer( + const nb::dict& cai, nb_class_ptr client, + std::optional device_id) { + if (!cai.contains("data")) { + return absl::InvalidArgumentError( + "CUDA Array Interface does not define `data`"); + } + if (!cai.contains("shape")) { + return absl::InvalidArgumentError( + "CUDA Array Interface does not define `shape`"); + } + if (!cai.contains("typestr")) { + return absl::InvalidArgumentError( + "CUDA Array Interface does not define `typestr`"); + } + if (!cai.contains("version")) { + return absl::InvalidArgumentError( + "CUDA Array Interface does not define `version`"); + } + auto version = nb::cast(cai["version"]); + if (version < 2 || version > 3) { + LOG(WARNING) << "CUDA Array Interface version " << version + << " support is undefined"; + } + auto data = nb::cast(cai["data"]); + auto data_value = nb::cast(data[0]); + void* data_ptr = reinterpret_cast(data_value); + auto dimensions = nb::cast>(cai["shape"]); + if (data_value == 0 && absl::c_find(dimensions, 0) == dimensions.end()) { + return absl::InvalidArgumentError( + "CUDA Array Interface `data`(=NULL) and `shape`(no zero-valued " + "dimensions) are inconsistent"); + } + auto ndim = dimensions.size(); + TF_ASSIGN_OR_RETURN( + PrimitiveType element_type, + DtypeToPrimitiveType(nb_dtype::from_args(cai["typestr"]))); + + if (!device_id.has_value()) { + throw XlaRuntimeError( + "This operation requires CUDA support from jaxlib or jax cuda plugin."); + } + TF_ASSIGN_OR_RETURN(auto device, + client->DeviceFromLocalHardwareId(*device_id)); + bool is_default_stream = + data_value == 0 || version == 2 || + (version == 3 && (!cai.contains("stream") || cai["stream"].is_none())); + TF_ASSIGN_OR_RETURN( + std::intptr_t stream, + ([is_default_stream, cai, device]() -> absl::StatusOr { + if (is_default_stream) { + return device->GetStreamForExternalReadyEvents(); + } else { + auto stream_ = nb::cast(cai["stream"]); + if (stream_ == 0) { + return absl::InvalidArgumentError( + "CUDA Array Interface does not allow zero stream value"); + } + return stream_; + } + }())); + + std::vector minor_to_major(ndim); + if (cai.contains("strides") && !cai["strides"].is_none() && data_value != 0) { + std::iota(minor_to_major.begin(), minor_to_major.end(), 0); + auto strides = nb::cast>(cai["strides"]); + if (strides.size() != ndim) { + return absl::InvalidArgumentError( + "CUDA Array Interface `shape` and `strides` dimensionalities are " + "inconsistent"); + } + absl::c_sort(minor_to_major, [&](int a, int b) { + // If two dimensions have the same stride, prefer the major-to-minor + // interpretation of the ordering, since that's what JAX wants. + return (strides[a] == strides[b] ? b < a : strides[a] < strides[b]); + }); + int64_t stride = ShapeUtil::ByteSizeOfPrimitiveType(element_type); + for (int64_t d : minor_to_major) { + if (dimensions[d] > 1 && strides[d] != stride) { + return absl::UnimplementedError(absl::StrCat( + "Only arrays with trivial (compact) striding are supported; " + "i.e., arrays whose striding represents a transposition of the " + "underlying buffer but not broadcasting. 
Dimensions were: [%s], " + "strides were [%s].", + absl::StrJoin(dimensions, ","), absl::StrJoin(strides, ","))); + } + stride *= dimensions[d]; + } + } else { + std::iota(minor_to_major.rbegin(), minor_to_major.rend(), 0); + } + Shape shape = ShapeUtil::MakeShapeWithDenseLayout(element_type, dimensions, + minor_to_major); + std::function on_delete_callback = []() {}; + auto* pjrt_device = + llvm::dyn_cast_or_null(device->device()); + if (pjrt_device == nullptr) { + return InvalidArgument( + "This operation is implemented for a PjRt-compatible backend only."); + } + TF_RET_CHECK(pjrt_device->IsAddressable()); + TF_ASSIGN_OR_RETURN( + auto pjrt_buffer, + device->client()->pjrt_client()->CreateViewOfDeviceBuffer( + static_cast(data_ptr), shape, + *pjrt_device->pjrt_device()->default_memory_space(), + on_delete_callback, + stream <= 2 ? std::nullopt : std::make_optional(stream))); + auto* ifrt_client = + llvm::dyn_cast_or_null(client->ifrt_client()); + if (ifrt_client == nullptr) { + throw XlaRuntimeError( + "This operation is implemented for a PjRt-compatible backend only."); + } + TF_ASSIGN_OR_RETURN(auto ifrt_array, + ifrt_client->CreatePjRtArray(std::move(pjrt_buffer))); + return PyArray::MakeFromSingleDeviceArray(std::move(client), Traceback::Get(), + std::move(ifrt_array), false, true); +} + +absl::Status PyArray::Delete() { + for (auto& arr : py_arrays()) { + TF_RETURN_IF_ERROR(arr.Delete()); + } + py_arrays().clear(); + if (ifrt_array() != nullptr) { + // We do not wait for the deletion to complete here. + // + // (1) Skipping blocking does not affect the correctness of deletion as long + // as the runtime preserves dispatch ordering of deletion w.r.t. other + // operations. + // + // (2) Synchronously waiting for the deletion to complete is very expensive + // when the deletion can return a status only after the underlying physical + // buffer has been deleted or a request must be processed via RPC, + // especially as this deletion is done per array. + ifrt_array()->Delete(); + SetIfrtArray(ifrt::ArrayRef()); + } + return absl::OkStatus(); +} + +bool PyArray::IsDeleted() const { + if (ifrt_array() == nullptr) { + return true; + } + + return ifrt_array()->IsDeleted(); +} + +PyArray PyArray::Clone() const { + auto array = tsl::FormRef(ifrt_array()); + auto* ifrt_client = py_client()->ifrt_client(); + ifrt::ArrayRef out = + ifrt_client + ->CopyArrays(absl::MakeSpan(&array, 1), /*devices=*/std::nullopt, + /*memory_kind=*/std::nullopt, + ifrt::ArrayCopySemantics::kReuseInput) + .value() + .front(); + return PyArray(aval(), weak_type(), dtype(), + std::vector(shape().begin(), shape().end()), + sharding(), py_client(), traceback(), std::move(out), + committed(), /*skip_checks=*/true, result_status()); +} + +nb::handle PyArray::Storage::AsHandle() { + return reinterpret_cast(reinterpret_cast(this) - + offsetof(PyArrayObject, array_storage)); +} + +PyArray::Storage::~PyArray_Storage() { + CHECK(PyGILState_Check()); + if (py_client) { + PyClient::ArraysShard& shard = py_client->arrays_[thread_id_bucket]; + nanobind::ft_lock_guard lock(shard.mutex); + if (shard.arrays == this) { + shard.arrays = next; + } + if (prev) { + prev->next = next; + } + if (next) { + next->prev = prev; + } + } + // Release GIL and then explicitly destroy `ifrt_array` to prevent deadlock on + // CPU backend caused by interactions between argument donations and host + // callbacks. 
+ nb::gil_scoped_release gil_release; + ifrt_array.reset(); +} + +absl::StatusOr> PyArray::BatchedCopyToDeviceWithSharding( + absl::Span py_arrays, + absl::Span dst_device_lists, + absl::Span dst_shardings, + absl::Span array_copy_semantics) { + if (py_arrays.empty()) { + return std::vector(); + } + + TF_RET_CHECK(py_arrays.size() == dst_device_lists.size()); + TF_RET_CHECK(py_arrays.size() == dst_shardings.size()); + + ifrt::Client* const client = py_arrays.front().ifrt_array()->client(); + std::vector results(py_arrays.size()); + + // Arrays to be copied, grouped by source/destination devices and memory + // kinds. The grouping is enforced by `ifrt::Client::CopyArrays()`. + struct Batch { + std::vector indexes; + std::vector ifrt_arrays; + }; + absl::flat_hash_map batches; + + auto traceback = Traceback::Get(); + { + tsl::profiler::TraceMe results_traceme( + "BatchedCopyToDeviceWithSharding create batch"); + for (int i = 0; i < py_arrays.size(); ++i) { + const auto& py_array = py_arrays[i]; + const auto& dst_sharding = dst_shardings[i]; + const auto& array_cs = array_copy_semantics[i]; + + auto* ifrt_array_ptr = py_array.ifrt_array(); + const ifrt::DeviceListRef& src_devices = + ifrt_array_ptr->sharding().devices(); + const ifrt::DeviceListRef& dst_devices = dst_device_lists[i]; + + ifrt::MemoryKind src_memory_kind = + ifrt::CanonicalizeMemoryKind(ifrt_array_ptr->sharding().memory_kind(), + src_devices->devices().front()); + ifrt::MemoryKind dst_memory_kind = ifrt::CanonicalizeMemoryKind( + xla::GetMemoryKind(dst_sharding), dst_devices->devices().front()); + + if (*src_devices == *dst_devices && src_memory_kind == dst_memory_kind && + array_cs == ifrt::ArrayCopySemantics::kReuseInput) { + if (py_array.sharding().equal(dst_sharding)) { + results[i] = py_arrays[i]; + } else { + absl::Span shape_span = py_array.shape(); + // We can reuse the input array despite the sharding being different. + // This is because this code expects no resharding is necessary, which + // has been verified by the code invoking this method. + results[i] = PyArray( + py_array.aval(), py_array.weak_type(), py_array.dtype(), + std::vector(shape_span.begin(), shape_span.end()), + dst_sharding, py_array.py_client(), traceback, + tsl::FormRef(ifrt_array_ptr), py_array.committed(), + /*skip_checks=*/true, py_array.result_status()); + } + continue; + } + + auto transfer_guard_formatter = [&py_array, &dst_sharding] { + return absl::StrCat( + "aval=", nb::cast(nb::repr(py_array.aval())), + ", sharding=", + nb::cast(nb::repr(py_array.sharding())), + ", dst_sharding=", + nb::cast(nb::repr(dst_sharding))); + }; + TF_RETURN_IF_ERROR( + jax::ApplyTransferGuardToDeviceToDevice(transfer_guard_formatter)); + + Batch& batch = batches[BatchedCopyToDeviceWithShardingKey{ + src_devices, src_memory_kind, dst_devices, dst_memory_kind, + array_cs}]; + batch.indexes.push_back(i); + batch.ifrt_arrays.push_back(tsl::FormRef(ifrt_array_ptr)); + } + } + + std::vector> ifrt_arrays; + { + GlobalPyRefManager()->CollectGarbage(); + nb::gil_scoped_release gil_release; + + tsl::profiler::TraceMe copy_traceme( + "BatchedCopyToDeviceWithSharding: dispatch"); + for (auto& [key, batch] : batches) { + TF_ASSIGN_OR_RETURN( + auto copied, + client->CopyArrays( + absl::MakeSpan(batch.ifrt_arrays), + // All arrays in `batch` have the same `key.dst_devices` and + // `key.dst_memory_kind` due to the grouping above. 
+ key.dst_devices, key.dst_memory_kind, key.array_copy_semantics)); + for (int i = 0; i < batch.indexes.size(); ++i) { + ifrt_arrays.push_back( + std::make_pair(batch.indexes[i], std::move(copied[i]))); + } + } + } + + tsl::profiler::TraceMe results_traceme( + "BatchedCopyToDeviceWithSharding create results"); + for (auto& [i, ifrt_array] : ifrt_arrays) { + const auto& py_array = py_arrays[i]; + absl::Span shape_span = py_array.shape(); + results[i] = + PyArray(py_array.aval(), py_array.weak_type(), py_array.dtype(), + std::vector(shape_span.begin(), shape_span.end()), + dst_shardings[i], py_array.py_client(), traceback, + std::move(ifrt_array), py_array.committed(), + /*skip_checks=*/true, py_array.result_status()); + } + return results; +} + +absl::StatusOr PyArray::BatchedDevicePut( + nb::object aval, nb::object sharding, std::vector xs, + absl::Span dst_devices, bool committed, + bool force_copy, PjRtClient::HostBufferSemantics host_buffer_semantics, + bool jax_enable_x64) { + if (dst_devices.size() != xs.size()) { + throw nb::value_error( + absl::StrCat("Argument sizes (xs and devices) must match %zu vs %zu", + dst_devices.size(), xs.size()) + .c_str()); + } + for (const PyDevice* device : dst_devices) { + if (device->client().get() == nullptr) { + return InvalidArgument("Cannot copy to unattached devices."); + } + } + auto transfer_guard_formatter = [&aval, &sharding] { + return absl::StrCat( + "aval=", nb::cast(nb::repr(aval)), + ", dst_sharding=", nb::cast(nb::repr(sharding))); + }; + + GlobalPyRefManager()->CollectGarbage(); + + auto n_devices = dst_devices.size(); + + DevicePutOptions options; + options.squash_64bit_types = !jax_enable_x64; + options.allow_zero_copy = + (!force_copy && (host_buffer_semantics == + ifrt::Client::HostBufferSemantics::kImmutableZeroCopy)); + + std::vector ifrt_arrays; + + absl::InlinedVector devices; + devices.reserve(n_devices); + std::vector shapes; + shapes.reserve(n_devices); + + std::vector args; + args.reserve(xs.size()); + for (const nb::object& x : xs) { + if (PyArray::IsPyArray(x)) { + TF_RETURN_IF_ERROR( + jax::ApplyTransferGuardToDeviceToDevice(transfer_guard_formatter)); + } else { + TF_RETURN_IF_ERROR( + jax::ApplyTransferGuardToHostToDevice(transfer_guard_formatter)); + } + args.push_back(x); + } + auto weak_type = nb::cast(aval.attr("weak_type")); + auto dtype = aval.attr("dtype"); + auto shape = nb::cast>(aval.attr("shape")); + TF_ASSIGN_OR_RETURN(nb_class_ptr py_device_list, + jax::GetPyDeviceList(sharding)); + + TF_ASSIGN_OR_RETURN( + DevicePutResult device_put_result, + DevicePutWithSharding(args, py_device_list->py_client()->ifrt_client(), + dtype, shape, sharding, options)); + + return PyArray(aval, weak_type, dtype, std::move(shape), std::move(sharding), + py_device_list->py_client(), Traceback::Get(), + std::move(device_put_result.ifrt_array), committed, + /*skip_checks=*/true); +} + +absl::StatusOr PyArray::ReorderShards( + PyArray x, nanobind::object dst_sharding, + ifrt::ArrayCopySemantics array_copy_semantics) { + xla::ifrt::Array* ifrt_array_ptr = x.ifrt_array(); + if (ifrt_array_ptr == nullptr) { + return absl::InvalidArgumentError( + "Reorder() called on deleted or donated buffer"); + } + + ifrt::Client* const client = ifrt_array_ptr->client(); + + const auto& device_list = ifrt_array_ptr->sharding().devices(); + TF_ASSIGN_OR_RETURN(auto dst_device_list, GetIfrtDeviceList(dst_sharding)); + if (device_list->AddressableDeviceList()->size() != + dst_device_list->AddressableDeviceList()->size()) { + return 
absl::InvalidArgumentError(absl::StrCat( + "Array is expected to have ", + dst_device_list->AddressableDeviceList()->size(), + " addressable shards, but has ", + device_list->AddressableDeviceList()->size(), " addressable shards")); + } + + TF_ASSIGN_OR_RETURN( + xla::ifrt::ShardingRef dst_ifrt_sharding, + GetIfrtConcreteEvenSharding(dst_sharding, ifrt_array_ptr->dtype(), + ifrt_array_ptr->shape())); + + xla::ifrt::ArrayRef new_ifrt_array; + { + nb::gil_scoped_release gil_release; + + const absl::Span addressable_devices = + device_list->AddressableDeviceList()->devices(); + const absl::Span dst_addressable_devices = + dst_device_list->AddressableDeviceList()->devices(); + + absl::flat_hash_map device_id_to_array_shard_index; + device_id_to_array_shard_index.reserve(dst_addressable_devices.size()); + for (int i = 0; i < dst_addressable_devices.size(); ++i) { + const int device_id = dst_addressable_devices[i]->Id().value(); + const bool inserted = + device_id_to_array_shard_index.insert({device_id, i}).second; + if (!inserted) { + return absl::InvalidArgumentError( + absl::StrCat("Sharding contains duplicate device id=", device_id)); + } + } + + std::vector from_shard_indices; + from_shard_indices.reserve(addressable_devices.size()); + std::vector to_shard_indices; + to_shard_indices.reserve(dst_addressable_devices.size()); + for (int i = 0; i < dst_addressable_devices.size(); ++i) { + from_shard_indices.push_back(i); + const int shard_device_id = addressable_devices[i]->Id().value(); + const auto it = device_id_to_array_shard_index.find(shard_device_id); + if (it == device_id_to_array_shard_index.end()) { + return absl::InvalidArgumentError(absl::StrCat( + "Array shard ", i, " is on device id=", shard_device_id, + ", but sharding does not have a shard on that device.")); + } + to_shard_indices.push_back(it->second); + } + + auto mappings = + std::make_shared>(); + { + auto& mapping = mappings->emplace_back(); + mapping.in_array = 0; + mapping.out_array = 0; + mapping.from.reserve(dst_addressable_devices.size()); + mapping.to.reserve(dst_addressable_devices.size()); + for (int64_t i = 0; i < dst_addressable_devices.size(); ++i) { + mapping.from.push_back(xla::ifrt::RemapPlan::Interval{ + from_shard_indices[i], from_shard_indices[i] + 1, 1}); + mapping.to.push_back(xla::ifrt::RemapPlan::Interval{ + to_shard_indices[i], to_shard_indices[i] + 1, 1}); + } + } + + xla::ifrt::RemapPlan plan = { + /*input_specs=*/{xla::ifrt::ArraySpec{ + /*dtype=*/ifrt_array_ptr->dtype(), + /*shape=*/ifrt_array_ptr->shape(), + /*sharding=*/ifrt_array_ptr->shared_ptr_sharding()}}, + /*output_specs=*/ + {xla::ifrt::ArraySpec{/*dtype=*/ifrt_array_ptr->dtype(), + /*shape=*/ifrt_array_ptr->shape(), + /*sharding=*/std::move(dst_ifrt_sharding)}}, + /*mappings=*/std::move(mappings), + }; + DCHECK_OK(plan.Validate()); + std::vector input; + input.push_back(tsl::FormRef(ifrt_array_ptr)); + TF_ASSIGN_OR_RETURN( + auto remapped, + client->RemapArrays(plan, absl::MakeSpan(input), array_copy_semantics)); + + TF_RET_CHECK(remapped.size() == 1); + new_ifrt_array = std::move(remapped.front()); + } + + return xla::PyArray(nb::borrow(x.aval().ptr()), x.weak_type(), + nb::borrow(x.dtype().ptr()), + std::vector(x.shape().begin(), x.shape().end()), + std::move(dst_sharding), x.py_client(), x.traceback(), + std::move(new_ifrt_array), + /*committed=*/true, + /*skip_checks=*/true); +} + +absl::Status PyArray::BatchedBlockUntilReady(std::vector objs) { + // Create ready futures for all arrays before blocking on their readiness. 
+ // This helps reduce the latency in some backend implementations where + // querying readiness of an array is not free. + + std::vector ifrt_arrays; + ifrt_arrays.reserve(objs.size()); + for (nb::handle obj : objs) { + if (obj.type().is(PyArray::type())) { + auto py_array = nb::borrow(obj); + ifrt::Array* const ifrt_array = py_array.ifrt_array(); + if (ifrt_array == nullptr) { + return absl::InvalidArgumentError( + "BlockHostUntilReady() called on deleted or donated buffer"); + } + ifrt_arrays.push_back(ifrt_array); + } else { + return absl::InvalidArgumentError( + "PyArray::BatchedBlockUntilReady can take PyArray only"); + } + } + + GlobalPyRefManager()->CollectGarbage(); + nb::gil_scoped_release gil_release; + return AwaitBuffersReady(absl::MakeConstSpan(ifrt_arrays)); +} + +absl::Status PyArray::ReplaceWithAlias(PyArray o) { + auto& storage = GetStorage(); + auto& o_storage = o.GetStorage(); + if (storage.py_client.get() != o_storage.py_client.get()) { + return absl::InvalidArgumentError( + "Unable to replace a PyArray with a PyArray from a different client."); + } + storage.aval = o_storage.aval; + storage.weak_type = o_storage.weak_type; + storage.dtype = o_storage.dtype; + storage.shape = o_storage.shape; + storage.sharding = o_storage.sharding; + storage.npy_value = o_storage.npy_value; + storage.committed = o_storage.committed; + storage.traceback = o_storage.traceback; + storage.ifrt_array = o_storage.ifrt_array; + storage.fully_replicated_array = o_storage.fully_replicated_array; + storage.py_arrays = o_storage.py_arrays; + storage.host_value.Clear(); + storage.dynamic_shape = o_storage.dynamic_shape; + storage.result_status = o_storage.result_status; + + return absl::OkStatus(); +} + +std::vector PyClient::LiveArrays() const { + std::vector result; + for (auto& shard : arrays_) { + nb::ft_lock_guard lock(shard.mutex); + for (PyArray::Storage* array = shard.arrays; array; array = array->next) { + bool all_deleted = + (array->ifrt_array == nullptr || array->ifrt_array->IsDeleted()); + if (!all_deleted) { + result.push_back(nb::borrow(array->AsHandle())); + } + } + } + return result; +} + +// PEP 3118 buffer protocol implementation. + +namespace { + +// Extra data to be kept alive by the consumer of the buffer protocol. +struct ExtraBufferInfo { + explicit ExtraBufferInfo( + std::shared_ptr buffer, + std::unique_ptr external_reference_hold) + : buffer(std::move(buffer)), + external_reference_hold(std::move(external_reference_hold)) {} + + std::vector strides; + // We keep an external reference hold to the PjRtBuffer. This prevents a + // use-after-free in the event that Delete() is called on a buffer with an + // live buffer protocol view. It does however mean that Delete() sometimes + // won't actually delete immediately. + std::shared_ptr buffer; + std::unique_ptr external_reference_hold; +}; + +// The default layout of a non-tuple array should have major-to-minor layout +// and no tiles. +bool HasDefaultLayout(const Layout& layout) { + return LayoutUtil::IsMonotonicWithDim0Major(layout) && layout.tiles().empty(); +} + +int PyArray_bf_getbuffer(PyObject* exporter, Py_buffer* view, int flags) { + absl::Status status = [&]() -> absl::Status { + PyArray py_array = nb::borrow(exporter); + if (py_array.ifrt_array() == nullptr) { + // TODO(phawkins): why is this happening? 
+ return InvalidArgument("Array is null"); + } + if (!llvm::isa(py_array.ifrt_array())) { + return InvalidArgument("Only local arrays are supported, got %s", + py_array.ifrt_array()->DebugString()); + } + auto* array = + static_cast(py_array.ifrt_array()); + absl::Span> buffers = + array->pjrt_buffers(); + + if (buffers.empty()) { + return InvalidArgument("Array has no buffers."); + } + PjRtBuffer& buffer = *buffers.front(); + if (!buffer.IsOnCpu()) { + return InvalidArgument( + "Python buffer protocol is only defined for CPU buffers."); + } + + if (buffers.size() != 1) { + return InvalidArgument( + "Python buffer protocol is only defined for buffers with a single " + "shard."); + } + if (!py_array.sharding().type().is(jax::SingleDeviceSharding::type())) { + return InvalidArgument( + "Python buffer protocol is only defined for single-device sharded " + "buffers."); + } + + const char* format = + PEP3118FormatDescriptorForPrimitiveType(buffer.element_type()); + // It isn't an option for us to export unknown types as, say, bytes. When + // converting an object to an ndarray, NumPy tries the buffer protocol + // first. We very much want NumPy to fail and fall back to using + // __array__, which allows us to handle custom dtypes correctly. + if (!format) { + return InvalidArgument( + "Buffers of type %s are not supported by the Python buffer protocol.", + PrimitiveType_Name(buffer.element_type())); + } + + std::unique_ptr external_reference_hold; + { + // We call BlockHostUntilReady() below, which may block. + nb::gil_scoped_release gil_release; + + if (buffer.IsTuple()) { + return InvalidArgument( + "Python buffer protocol is only defined for array buffers."); + } + if ((flags & PyBUF_WRITEABLE) == PyBUF_WRITEABLE) { + return InvalidArgument("XLA buffers are read-only."); + } + TF_ASSIGN_OR_RETURN(external_reference_hold, + buffer.AcquireExternalReference()); + if (buffer.IsDeleted()) { + return InvalidArgument("Deleted buffer used in buffer protocol."); + } + + // TODO(b/327524065): use PjRtLayout directly instead of xla::Layout + Layout xla_layout = buffer.layout()->xla_layout(); + + if (((flags & PyBUF_C_CONTIGUOUS) == PyBUF_C_CONTIGUOUS || + (flags & PyBUF_STRIDES) == PyBUF_ND) && + !LayoutUtil::IsMonotonicWithDim0Major(xla_layout)) { + return InvalidArgument("Buffer is not in C-contiguous layout."); + } else if ((flags & PyBUF_F_CONTIGUOUS) == PyBUF_F_CONTIGUOUS && + !LayoutUtil::IsMonotonicWithDim0Minor(xla_layout)) { + return InvalidArgument("Buffer is not in F-contiguous layout."); + } else if ((flags & PyBUF_ANY_CONTIGUOUS) == PyBUF_ANY_CONTIGUOUS && + !LayoutUtil::IsMonotonicWithDim0Major(xla_layout) && + !LayoutUtil::IsMonotonicWithDim0Minor(xla_layout)) { + return InvalidArgument("Buffer is not in contiguous layout."); + } else if (!HasDefaultLayout(xla_layout)) { + // Fail and fall back to using __array__ if the CPU buffer has a device + // specific layout. For instance, this happens for host buffers in + // pinned memories of the TPU device. + return InvalidArgument( + "Buffer is potentially a device buffer with non default layout."); + } + TF_RETURN_IF_ERROR(buffer.GetReadyFuture().Await()); + } + + // We must hold the GIL (or at least prevent Python GC) while writing to the + // view object, see https://github.com/python/cpython/issues/130409. 
+ std::memset(view, 0, sizeof(Py_buffer)); + const void* root_ptr = + external_reference_hold->OpaqueDeviceMemoryDataPointer(); + view->buf = const_cast(root_ptr); + auto extra = std::make_unique( + buffers.front(), std::move(external_reference_hold)); + view->itemsize = ShapeUtil::ByteSizeOfPrimitiveType(buffer.element_type()); + TF_ASSIGN_OR_RETURN(view->len, buffer.GetOnDeviceSizeInBytes()); + view->readonly = 1; + if ((flags & PyBUF_FORMAT) == PyBUF_FORMAT) { + view->format = const_cast(format); + } + if ((flags & PyBUF_ND) == PyBUF_ND) { + view->ndim = buffer.dimensions().size(); + static_assert(sizeof(int64_t) == sizeof(Py_ssize_t), + "Py_ssize_t must be 64 bits"); + if (view->ndim != 0) { + view->shape = reinterpret_cast( + const_cast(buffer.dimensions().data())); + if ((flags & PyBUF_STRIDES) == PyBUF_STRIDES) { + extra->strides = + ByteStridesForShape(buffer.element_type(), buffer.dimensions(), + buffer.layout()->xla_layout()); + view->strides = reinterpret_cast( + const_cast(extra->strides.data())); + } + } + } + view->internal = extra.release(); + return absl::OkStatus(); + }(); + if (!status.ok()) { + // numpy.asarray(...) eats the PyExc_BufferError. Adding a log here helps + // debugging when the error really occurs. + VLOG(1) << "Buffer Protocol Error: " << status; + PyErr_SetString(PyExc_BufferError, status.ToString().c_str()); + return -1; + } + view->obj = exporter; + Py_INCREF(view->obj); + return 0; +} + +void PyArray_bf_releasebuffer(PyObject*, Py_buffer* buffer) { + auto extra = static_cast(buffer->internal); + delete extra; +} + +// Returns if shape has a major-to-minor layout. +bool HasMajorToMinorLayout(const xla::Shape& shape) { + if (shape.has_layout()) { + for (int i = 0; i < shape.layout().minor_to_major().size(); ++i) { + if (shape.layout().minor_to_major(i) != + shape.layout().minor_to_major().size() - 1 - i) { + return false; + } + } + } + return true; +} + +// Returns byte_strides if shape has a non-major-to-minor layout. +std::optional> ByteStridesOrDefaultForShapeInt64( + const Shape& shape) { + if (!shape.has_layout() || HasMajorToMinorLayout(shape)) { + return std::nullopt; + } + return ByteStridesForShape(shape); +} + +bool IsZeroCopyableCpuBuffer(const PjRtBuffer* buf) { + // For CPU buffers with device-specific layouts, we must delinearize + // to unpack the array. This could happen for the host buffer + // pre-mapped to the TPU device, a.k.a., pinned host buffers for the + // device. + bool has_default_layout = + buf->layout() == nullptr || HasDefaultLayout(buf->layout()->xla_layout()); + // On CPU for values >= 8 bits, we can return the value in a zero-copy way. + // For sub-byte values, we must copy in order to unpack the array. + return buf->IsOnCpu() && + !primitive_util::IsSubByteNonPredType(buf->element_type()) && + has_default_layout; +} +} // namespace + +PyHostValue::PyHostValue() = default; +PyHostValue::~PyHostValue() = default; + +absl::StatusOr> PyHostValue::AsNumPyArray( + std::optional& dynamic_shape_holder, ifrt::Array* ifrt_array) { + if (ifrt_array->IsDeleted()) { + return InvalidArgument("DeviceArray has been deleted."); + } + // The only `jax.Array` with token-shape buffer is the one wrapped by + // `jax.core.Token`. Since it is an internal implementation detail, we + // don't support converting it to a numpy array. 
+ if (ifrt_array->dtype().kind() == ifrt::DType::kToken) { + return InvalidArgument( + "Cannot convert a token-shape buffer to a numpy array."); + } + auto* arr = llvm::dyn_cast_or_null(ifrt_array); + if (arr != nullptr) { + auto* pjrt_buffer = arr->pjrt_buffers().front().get(); + TF_RET_CHECK(!pjrt_buffer->IsTuple()); + // On CPU for values >= 8 bits, we can return the value in a zero-copy way. + // For sub-byte values, we must copy in order to unpack the array. + if (IsZeroCopyableCpuBuffer(pjrt_buffer)) { + TF_ASSIGN_OR_RETURN(const auto* shape, + XlaDynamicShape(ifrt_array, dynamic_shape_holder)); + TF_ASSIGN_OR_RETURN(nb_dtype dtype, + PrimitiveTypeToNbDtype(shape->element_type())); + // Objects that must be kept alive while the array is alive. + struct Hold { + ifrt::ArrayRef buffer; + std::unique_ptr external_reference_hold; + }; + auto hold = std::make_unique(); + hold->buffer = tsl::FormRef(ifrt_array); + auto* hold_ptr = hold.release(); + nb::capsule hold_capsule( + hold_ptr, [](void* h) noexcept { delete static_cast(h); }); + { + // Release the GIL as `AcquireExternalReference` may block. + nb::gil_scoped_release gil; + TF_ASSIGN_OR_RETURN(hold_ptr->external_reference_hold, + pjrt_buffer->AcquireExternalReference()); + auto fut = ifrt_array->GetReadyFuture(); + BlockUntilReadyWithCancel(fut); + TF_RETURN_IF_ERROR(fut.Await()); + } + void* data = + hold_ptr->external_reference_hold->OpaqueDeviceMemoryDataPointer(); + nb_numpy_ndarray array(dtype, shape->dimensions(), + ByteStridesForShape(*shape), data, hold_capsule); + array.attr("flags").attr("writeable") = nb::bool_(false); + return std::make_pair(array, false); + } + } + + TF_RETURN_IF_ERROR(CopyToHostAsync(dynamic_shape_holder, ifrt_array)); + if (!ready_.IsReady()) { + nb::gil_scoped_release gil; + BlockUntilReadyWithCancel(ready_); + TF_RETURN_IF_ERROR(ready_.Await()); + } else { + TF_RETURN_IF_ERROR(ready_.Await()); + } + if (string_array_contents_ != nullptr) { + TF_RETURN_IF_ERROR(ConvertStringArrayContentsToNumpyArray(ifrt_array)); + } + return std::make_pair(value_, true); +} + +absl::Status PyHostValue::ConvertStringArrayContentsToNumpyArray( + ifrt::Array* ifrt_array) { +#ifdef NPY_2_0_API_VERSION + if (PyArray_RUNTIME_VERSION < NPY_2_0_API_VERSION) { + return absl::FailedPreconditionError( + absl::StrCat("String arrays are not supported in NumPy version: ", + PyArray_RUNTIME_VERSION)); + } + auto numpy_dtype = nb::steal( + reinterpret_cast(PyArray_DescrFromType(NPY_VSTRING))); + value_ = nb_numpy_ndarray(numpy_dtype, ifrt_array->shape().dims(), + /*strides=*/std::nullopt); + + auto dst_py_array_obj = reinterpret_cast<::PyArrayObject*>(value_.ptr()); + auto iter = + nb::steal(PyArray_IterNew(reinterpret_cast(dst_py_array_obj))); + for (auto& cord : *string_array_contents_) { + absl::string_view input_str_view = cord.Flatten(); + auto py_unicode = nb::steal(PyUnicode_FromStringAndSize( + input_str_view.data(), input_str_view.size())); + if (py_unicode.ptr() == nullptr) { + return absl::InternalError("PyUnicode_FromStringAndSize failed"); + } + if (PyArray_SETITEM(dst_py_array_obj, + static_cast(PyArray_ITER_DATA(iter.ptr())), + py_unicode.ptr()) != 0) { + return absl::InternalError("PyArray_SETITEM failed"); + } + PyArray_ITER_NEXT(iter.ptr()); + } + + value_.attr("flags").attr("writeable") = nb::bool_(false); + + string_array_contents_.reset(); + + return absl::OkStatus(); +#else + return absl::FailedPreconditionError( + "String arrays are not supported in this NumPy version."); +#endif +} + +absl::Status 
PyHostValue::CopyStringArrayToHostAsync( + std::optional& dynamic_shape_holder, ifrt::Array* ifrt_array) { + auto transfer_guard_formatter = [ifrt_array] { + return absl::StrCat( + "shape=(", absl::StrJoin(ifrt_array->shape().dims(), ","), + "), dtype=", ifrt_array->dtype().DebugString(), ", device=", + ifrt_array->sharding().devices()->devices().front()->DebugString()); + }; + TF_RETURN_IF_ERROR( + jax::ApplyTransferGuardToDeviceToHost(transfer_guard_formatter)); + + TF_ASSIGN_OR_RETURN(nb_dtype dtype, IfrtDtypeToNbDtype(ifrt_array->dtype())); + auto shape = ifrt_array->shape(); + + // Allocate a vector of cords to hold the contents of the array until + // they are ultimately converted to a numpy array as part + // of the `AsNumPyArray` call. + string_array_contents_ = + std::make_shared>(shape.num_elements()); + ready_ = ifrt_array->CopyToHostBuffer(string_array_contents_->data(), + /*byte_strides=*/std::nullopt, + ifrt::ArrayCopySemantics::kAlwaysCopy); + + ready_.OnReady( + [string_array_contents = string_array_contents_](absl::Status) { + }); // Keeps the cords alive until the copy is done. + + return absl::OkStatus(); +} + +absl::Status PyHostValue::CopyToHostAsync( + std::optional& dynamic_shape_holder, ifrt::Array* ifrt_array) { + if (ready_.IsValid()) { + // The array value has been populated, so CopyToHostAsync has been called. + return absl::OkStatus(); + } + + // Copying in Arrays of type kString requires some special handling. + if (ifrt_array->dtype().kind() == ifrt::DType::kString) { + return CopyStringArrayToHostAsync(dynamic_shape_holder, ifrt_array); + } + + auto* arr = llvm::dyn_cast_or_null(ifrt_array); + if (arr != nullptr && !arr->pjrt_buffers().front()->IsTuple() && + IsZeroCopyableCpuBuffer(arr->pjrt_buffers().front().get())) { + return absl::OkStatus(); + } + auto transfer_guard_formatter = [ifrt_array] { + return absl::StrCat( + "shape=(", absl::StrJoin(ifrt_array->shape().dims(), ","), + "), dtype=", ifrt_array->dtype().DebugString(), ", device=", + ifrt_array->sharding().devices()->devices().front()->DebugString()); + }; + TF_RETURN_IF_ERROR( + jax::ApplyTransferGuardToDeviceToHost(transfer_guard_formatter)); + + // TODO(b/182461453): This is a blocking call. If we further implemented + // populating dynamic shape metadata while fetching the literal, we wouldn't + // need this static approach. + const xla::Shape* dynamic_shape; + std::optional shape_holder; + if (llvm::isa(ifrt_array)) { + TF_ASSIGN_OR_RETURN(dynamic_shape, + XlaDynamicShape(ifrt_array, dynamic_shape_holder)); + } else { + // Skip querying the dynamic shape for a non-PjRt Array. + TF_ASSIGN_OR_RETURN(xla::PrimitiveType type, + ifrt::ToPrimitiveType(ifrt_array->dtype())); + shape_holder = ShapeUtil::MakeShapeWithDescendingLayout( + type, ifrt_array->shape().dims()); + dynamic_shape = &*shape_holder; + } + + xla::Shape host_shape = ShapeUtil::DeviceShapeToHostShape(*dynamic_shape); + + auto strides = ByteStridesOrDefaultForShapeInt64(host_shape); + TF_ASSIGN_OR_RETURN(nb_dtype dtype, + PrimitiveTypeToNbDtype(host_shape.element_type())); + value_ = nb_numpy_ndarray(dtype, host_shape.dimensions(), strides); + // TODO(hyeontaek): Several PjRt runtimes assume that the host buffer uses + // the same transposition as the device buffer. This is different from + // PjRtBuffer::ToLiteral()'s semantics that the runtime respects the layout + // of the host buffer literal. On the other hand, the runtime often knows + // better about an efficient layout for the host buffer. 
It will be useful + // to revisit the semantics of PjRtBuffer::ToLiteral() to see if it is + // desirable for the runtime to choose the layout. + ready_ = ifrt_array->CopyToHostBuffer(value_.mutable_data(), strides, + ifrt::ArrayCopySemantics::kReuseInput); + // Make sure the destination of the copy remains alive until the copy is done. + value_.inc_ref(); + ready_.OnReady([array{value_.ptr()}](absl::Status status) { + GlobalPyRefManager()->AddGarbage(nb::steal(array)); + }); + value_.attr("flags").attr("writeable") = nb::bool_(false); + return absl::OkStatus(); +} + +void PyHostValue::Clear() { + ready_ = {}; + value_ = {}; + string_array_contents_ = {}; +} + +namespace { +PyMemberDef PyBaseArray_members[] = { +#if PY_VERSION_HEX < 0x030C0000 + {"__weaklistoffset__", T_PYSSIZET, + static_cast(offsetof(PyBaseArrayObject, weakrefs)), READONLY, + nullptr}, +#endif // PY_VERSION_HEX < 0x030C0000 + {nullptr, 0, 0, 0, nullptr}, +}; + +PyType_Slot PyBaseArray_slots[] = { + {Py_tp_dealloc, reinterpret_cast(PyBaseArray_tp_dealloc)}, + {Py_tp_members, reinterpret_cast(PyBaseArray_members)}, + {Py_tp_traverse, reinterpret_cast(PyBaseArray_tp_traverse)}, + {Py_tp_hash, reinterpret_cast(PyObject_HashNotImplemented)}, + {0, nullptr}, +}; + +PyGetSetDef PyArray_tp_getset[] = { + {"__dict__", PyObject_GenericGetDict, PyObject_GenericSetDict, nullptr, + nullptr}, + {nullptr, nullptr, nullptr, nullptr, nullptr}, +}; + +PyMemberDef PyArray_members[] = { +#if PY_VERSION_HEX < 0x030C0000 + {"__weaklistoffset__", T_PYSSIZET, + static_cast(offsetof(PyArrayObject, weakrefs)), READONLY, + nullptr}, + {"__dictoffset__", T_PYSSIZET, + static_cast(offsetof(PyArrayObject, dict)), READONLY, nullptr}, +#endif // PY_VERSION_HEX < 0x030C0000 + {nullptr, 0, 0, 0, nullptr}, +}; // namespace xla + +PyType_Slot PyArray_slots[] = { + {Py_tp_new, reinterpret_cast(PyArray_tp_new)}, + {Py_tp_dealloc, reinterpret_cast(PyArray_tp_dealloc)}, + {Py_tp_members, reinterpret_cast(PyArray_members)}, + {Py_tp_traverse, reinterpret_cast(PyArray_tp_traverse)}, + {Py_tp_clear, reinterpret_cast(PyArray_tp_clear)}, + {Py_tp_getset, reinterpret_cast(PyArray_tp_getset)}, + {Py_bf_getbuffer, reinterpret_cast(PyArray_bf_getbuffer)}, + {Py_bf_releasebuffer, reinterpret_cast(PyArray_bf_releasebuffer)}, + {0, nullptr}, +}; + +} // namespace + +absl::Status PyArray::RegisterTypes(nb::module_& m) { + // We are not using nanobind to avoid having a non-standard metaclass, which + // would make Array incompatible with abc.ABCMeta. 
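+  // Keeping the default `type` metaclass matters because Python code may mix
+  // these types with abc.ABCMeta-based classes. A hypothetical example (names
+  // below are illustrative only, not part of this change):
+  //
+  //   import abc
+  //   class Mixin(abc.ABC): ...
+  //   class MyArray(ArrayImpl, Mixin): ...  # derived metaclass is ABCMeta;
+  //                                         # this only works because
+  //                                         # ArrayImpl's metaclass is the
+  //                                         # plain `type`.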
+ std::string base_name = + absl::StrCat(nb::cast(m.attr("__name__")), ".Array"); + PyType_Spec PyBaseArray_spec = { + /*.name=*/base_name.c_str(), + /*.basicsize=*/static_cast(sizeof(PyBaseArrayObject)), + /*.itemsize=*/0, +#if PY_VERSION_HEX < 0x030C0000 + /*.flags=*/Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HAVE_GC, +#else // PY_VERSION_HEX >= 0x030C0000 + /*.flags=*/Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HAVE_GC | + Py_TPFLAGS_MANAGED_WEAKREF, +#endif // PY_VERSION_HEX >= 0x030C0000 + /*.slots=*/PyBaseArray_slots}; + auto* base_type = PyType_FromSpec(&PyBaseArray_spec); + if (!base_type) { + throw nb::python_error(); + } + m.attr("Array") = nb::borrow(base_type); + + std::string name = + absl::StrCat(nb::cast(m.attr("__name__")), ".ArrayImpl"); + + PyType_Spec PyArray_spec = { + /*.name=*/name.c_str(), + /*.basicsize=*/static_cast(sizeof(PyArrayObject)), + /*.itemsize=*/0, +#if PY_VERSION_HEX < 0x030C0000 + /*.flags=*/Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC, +#else // PY_VERSION_HEX >= 0x030C0000 + /*.flags=*/Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC | + Py_TPFLAGS_MANAGED_DICT | Py_TPFLAGS_MANAGED_WEAKREF, +#endif // PY_VERSION_HEX >= 0x030C0000 + /*.slots=*/PyArray_slots, + }; + + type_ = PyType_FromSpecWithBases(&PyArray_spec, base_type); + if (!type_) { + throw nb::python_error(); + } + auto type = nb::borrow(type_); + m.attr("ArrayImpl") = type; + + type.attr("__init__") = nb::cpp_function( + [](PyArray self, nb::object aval, nb::object sharding, nb::list arrays, + bool committed, bool skip_checks) { + if (!(arrays.size() == 0 || arrays[0].type().is(PyArray::type()))) { + throw nb::type_error( + absl::StrCat( + "Unsupported type for elements in `arrays`: ", + nb::cast(nb::str(arrays[0].type()))) + .c_str()); + } + auto py_arrays = nb::cast>(arrays); + PyArray::PyInit(self, std::move(aval), std::move(sharding), py_arrays, + committed, skip_checks); + }, + nb::is_method(), nb::arg("aval"), nb::arg("sharding"), nb::arg("arrays"), + nb::arg("committed"), nb::arg("_skip_checks") = false); + type.attr("delete") = nb::cpp_function( + [](PyArray& self) { xla::ThrowIfError(self.Delete()); }, nb::is_method()); + type.attr("_sharding") = nb_property_readonly(&PyArray::sharding); + type.attr("aval") = nb_property(&PyArray::aval, &PyArray::set_aval); + type.attr("_arrays") = + nb_property(&PyArray::arrays, [](PyArray& self, nb::object obj) { + xla::ThrowIfError(self.set_arrays(obj)); + }); + type.attr("_fully_replicated_shard") = nb::cpp_function( + [](PyArray self) { + return xla::ValueOrThrow(self.FullyReplicatedShard()); + }, + nb::is_method()); + type.attr("_npy_value") = + nb_property(&PyArray::npy_value, &PyArray::set_npy_value); + type.attr("_committed") = nb_property_readonly(&PyArray::committed); + type.attr("unsafe_buffer_pointer") = nb::cpp_function( + [](PyArray self) { + return xla::ValueOrThrow(self.UnsafeBufferPointer()); + }, + nb::is_method()); + type.attr("__cuda_array_interface__") = nb_property_readonly( + [](PyArray self) { return self.CudaArrayInterface(); }); + type.attr("_pjrt_layout") = + nb_property_readonly(xla::ValueOrThrowWrapper(&PyArray::layout)); + type.attr("on_device_size_in_bytes") = nb::cpp_function( + xla::ValueOrThrowWrapper(&PyArray::GetOnDeviceSizeInBytes), + nb::is_method()); + type.attr("_single_device_array_to_np_array_did_copy") = nb::cpp_function( + xla::ValueOrThrowWrapper(&PyArray::SingleDeviceArrayToNumpyArrayDidCopy), + nb::is_method()); + type.attr("_copy_single_device_array_to_host_async") = nb::cpp_function( + 
[](PyArray& self) { + xla::ThrowIfError(self.CopySingleDeviceArrayToHostAsync()); + }, + nb::is_method()); + type.attr("_replace_with") = nb::cpp_function( + [](PyArray& self, PyArray& o) { + xla::ThrowIfError(self.ReplaceWithAlias(o)); + }, + nb::is_method()); + type.attr("block_until_ready") = nb::cpp_function( + [](PyArray self) -> nb::object { + xla::ThrowIfError(self.BlockUntilReady()); + return self; + }, + nb::is_method()); + type.attr("platform") = nb::cpp_function( + [](PyArray self) { + if (self.ifrt_array()->client()->platform_name() == "cuda" || + self.ifrt_array()->client()->platform_name() == "rocm") { + return absl::string_view("gpu"); + } else { + return self.ifrt_array()->client()->platform_name(); + } + }, + nb::is_method()); + type.attr("is_ready") = nb::cpp_function( + [](PyArray self) { return xla::ValueOrThrow(self.IsReady()); }, + nb::is_method()); + type.attr("is_deleted") = + nb::cpp_function(&PyArray::IsDeleted, nb::is_method()); + type.attr("traceback") = nb_property_readonly(&PyArray::traceback); + type.attr("clone") = nb::cpp_function(&PyArray::Clone, nb::is_method()); + type.attr("__module__") = m.attr("__name__"); + + m.attr("batched_copy_array_to_devices_with_sharding") = nb::cpp_function( + [](absl::Span arrays, + absl::Span> dst_device_lists, + absl::Span shardings, + absl::Span array_copy_semantics) { + if (arrays.empty()) { + return std::vector(); + } + tsl::profiler::TraceMe traceme( + "batched_copy_array_to_devices_with_sharding"); + std::vector device_lists; + { + tsl::profiler::TraceMe device_list_traceme( + "batched_copy_array_to_devices_with_sharding: assemble device " + "lists"); + auto* client = arrays[0].ifrt_array()->client(); + device_lists.reserve(dst_device_lists.size()); + for (const auto& dst_devices : dst_device_lists) { + absl::InlinedVector devices; + devices.reserve(dst_devices.size()); + for (auto& d : dst_devices) { + devices.push_back(d->device()); + } + device_lists.push_back(client->MakeDeviceList(devices)); + } + } + return xla::ValueOrThrow(PyArray::BatchedCopyToDeviceWithSharding( + arrays, device_lists, shardings, array_copy_semantics)); + }); + m.attr("array_result_handler") = nb::cpp_function( + [](nb::object aval, nb::object sharding, bool committed, + bool skip_checks) -> nb_class_ptr { + return make_nb_class( + std::move(aval), std::move(sharding), committed, skip_checks); + }, + nb::arg("aval"), nb::arg("sharding"), nb::arg("committed"), + nb::arg("_skip_checks") = false); + + nb::class_(m, "ResultHandler") + .def("__call__", [](const PyArrayResultHandler& self, + PyArray arg) { return self.Call(arg); }) + .def("__call__", + [](const PyArrayResultHandler& self, + std::vector py_arrays) { return self.Call(py_arrays); }); + + return absl::OkStatus(); +} + +} // namespace xla diff --git a/jaxlib/py_array.h b/jaxlib/py_array.h new file mode 100644 index 000000000000..5b496be091f0 --- /dev/null +++ b/jaxlib/py_array.h @@ -0,0 +1,360 @@ +/* Copyright 2022 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef JAXLIB_PY_ARRAY_H_ +#define JAXLIB_PY_ARRAY_H_ + +#include + +#include +#include +#include +#include +#include +#include + +// placeholder for index annotation headers +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/cord.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "llvm/Support/Casting.h" +#include "nanobind/nanobind.h" +#include "jaxlib/nb_class_ptr.h" +#include "jaxlib/py_client.h" +#include "jaxlib/traceback.h" +#include "xla/pjrt/exceptions.h" +#include "xla/pjrt/pjrt_client.h" +#include "xla/pjrt/pjrt_future.h" +#include "xla/pjrt/pjrt_layout.h" +#include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/device_list.h" +#include "xla/python/ifrt/future.h" +#include "xla/python/nb_numpy.h" +#include "xla/python/pjrt_ifrt/pjrt_array.h" +#include "xla/shape.h" +#include "xla/util.h" + +namespace xla { + +// Private to PyArray, but you cannot forward declare member classes. +// Not thread safe; assumes the GIL is held. +class PyHostValue { + public: + PyHostValue(); + ~PyHostValue(); + + PyHostValue(const PyHostValue&) = delete; + PyHostValue(PyHostValue&&) = delete; + PyHostValue& operator=(const PyHostValue&) = delete; + PyHostValue& operator=(PyHostValue&&) = delete; + + absl::Status CopyToHostAsync(std::optional& dynamic_shape_holder, + ifrt::Array* ifrt_array); + + absl::StatusOr> AsNumPyArray( + std::optional& dynamic_shape_holder, ifrt::Array* ifrt_array); + + void Clear(); + + private: + absl::Status CopyStringArrayToHostAsync( + std::optional& dynamic_shape_holder, ifrt::Array* ifrt_array); + + absl::Status ConvertStringArrayContentsToNumpyArray(ifrt::Array* ifrt_array); + + ifrt::Future<> ready_; + nb_numpy_ndarray value_; + + // Optional field, only used for arrays of type kString. This vector of cords + // serves as input buffer for the CopyToHostBuffer call. It holds these + // contents until they are lazily converted to a numpy array when the user + // calls `AsNumPyArray`. + std::shared_ptr> string_array_contents_; +}; + +// Private to PyArray, but you cannot forward declare member classes. +struct PyArray_Storage { + PyArray_Storage(nanobind::object aval, bool weak_type, nb_dtype dtype, + std::vector shape, nanobind::object sharding, + bool committed, nb_class_ptr py_client, + std::optional traceback, ifrt::ArrayRef ifrt_array, + xla::PjRtFuture<> result_status); + + ~PyArray_Storage(); + nanobind::handle AsHandle(); + + nanobind::object aval; + bool weak_type = false; + nb_dtype dtype; + std::vector shape; + + nanobind::object sharding; + nanobind::object npy_value = nanobind::none(); + bool committed = false; + + nb_class_ptr py_client; + std::optional traceback; + ifrt::ArrayRef ifrt_array; + nanobind::object fully_replicated_array = nanobind::none(); + + // optional field, used only in python + std::vector py_arrays; + PyHostValue host_value; // Protected by the GIL. + std::optional dynamic_shape = std::nullopt; + // Only set if this Array was generated by a computation that has effects. + // This is the result status of the XLA computation that generated this + // array. + xla::PjRtFuture<> result_status; + + // Doubly-linked list of all PyArrays known to the client. Protected by the + // GIL. Since multiple PyArrays may share the same PjRtBuffer, there may be + // duplicate PjRtBuffers in this list. 
+ PyArray_Storage* next; + PyArray_Storage* prev; + + uint8_t thread_id_bucket; +}; + +// The C++ implementation of jax.Array. A few key methods and data members are +// implemented in C++ for performance, while most of the functionalities are +// still implemented in python. +class PyArray : public nanobind::object { + public: + NB_OBJECT(PyArray, nanobind::object, "Array", PyArray::IsPyArray); + PyArray() = default; + + // "__init__" methods. Only used in python + static void PyInit(PyArray self, nanobind::object aval, + nanobind::object sharding, + absl::Span py_arrays, bool committed, + bool skip_checks); + + // Only used in C++. `skip_checks` should only be set for Arrays created by + // jax that cannot possibly have consistency issues (e.g. `sharding` devices + // different than `ifrt_array` devices). Arrays created by users should be + // checked. + PyArray(nanobind::object aval, bool weak_type, nb_dtype dtype, + std::vector shape, nanobind::object sharding, + nb_class_ptr py_client, std::optional traceback, + ifrt::ArrayRef ifrt_array, bool committed, bool skip_checks, + xla::PjRtFuture<> result_status = xla::PjRtFuture<>()); + + static PyArray MakeFromSingleDeviceArray( + nb_class_ptr py_client, std::optional traceback, + ifrt::ArrayRef ifrt_array, bool weak_type, bool committed, + xla::PjRtFuture<> result_status = xla::PjRtFuture<>()); + + static PyArray MakeFromIfrtArrayAndSharding( + nb_class_ptr py_client, std::optional traceback, + ifrt::ArrayRef ifrt_array, nanobind::object sharding, bool weak_type, + bool committed, bool skip_checks); + + static absl::Status RegisterTypes(nanobind::module_& m); + + static PyArray borrow(PyObject* ptr) { + return nanobind::borrow(ptr); + } + + using Storage = PyArray_Storage; + + const nanobind::object& aval() const { return GetStorage().aval; } + void set_aval(nanobind::object aval) { GetStorage().aval = std::move(aval); } + + bool weak_type() const { return GetStorage().weak_type; } + + const nb_dtype& dtype() const { return GetStorage().dtype; } + absl::Span shape() const { return GetStorage().shape; } + + const nanobind::object& sharding() const { return GetStorage().sharding; } + + absl::StatusOr> layout() { + return ifrt_array()->layout(); + } + + bool committed() const { return GetStorage().committed; } + + const nanobind::object& npy_value() const { return GetStorage().npy_value; } + void set_npy_value(nanobind::object v) { + GetStorage().npy_value = std::move(v); + } + + const nb_class_ptr& py_client() const { + return GetStorage().py_client; + } + + const std::optional& traceback() const { + return GetStorage().traceback; + } + + // Returns xla::InvalidArgument if the buffer has been deleted. + // See `PjRtFuture` for the semantics of `IsReady` and `IsKnownReady`. + absl::StatusOr IsReady() { + ifrt::Array* ifrt_array_ptr = ifrt_array(); + if (ifrt_array_ptr->IsDeleted()) { + return InvalidArgument("Array has been deleted."); + } + return ifrt_array_ptr->GetReadyFuture().IsReady(); + } + + const xla::PjRtFuture<>& result_status() const { + return GetStorage().result_status; + } + + ifrt::Array* ifrt_array() const { return GetStorage().ifrt_array.get(); } + + // Short-term escape hatch to get PjRtBuffers from PyArray. + // TODO(hyeontaek): Migrate all users of this method to be agnostic of PjRt. 
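+  // Returns an empty span when there is no underlying IFRT array, and throws
+  // XlaRuntimeError when the backend is not PjRt-compatible.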
+ absl::Span> pjrt_buffers() const { + ifrt::Array* ifrt_array_ptr = ifrt_array(); + if (ifrt_array_ptr == nullptr) { + return {}; + } + auto* arr = + llvm::dyn_cast_or_null(ifrt_array_ptr); + if (arr == nullptr) { + throw XlaRuntimeError( + "This operation is implemented for a PjRt-compatible backend only."); + } + return arr->pjrt_buffers(); + } + + int num_addressable_shards() const { + ifrt::Array* ifrt_array_ptr = ifrt_array(); + if (ifrt_array_ptr == nullptr) { + return 0; + } + auto* arr = + llvm::dyn_cast_or_null(ifrt_array_ptr); + if (arr == nullptr) { + // TODO(hyeontaek): Add num_addressable_shards to ifrt. + return num_shards(); + } + return arr->pjrt_buffers().size(); + } + + std::vector& py_arrays() { return GetStorage().py_arrays; } + const std::vector& py_arrays() const { + return GetStorage().py_arrays; + } + const std::vector& py_arrays_cached(); + + nanobind::object arrays(); + absl::Status set_arrays(nanobind::object obj); + absl::StatusOr FullyReplicatedShard(); + + int num_shards() const { + ifrt::Array* ifrt_array_ptr = ifrt_array(); + if (ifrt_array_ptr == nullptr) { + return 0; + } + return ifrt_array_ptr->sharding().devices()->size(); + } + + static nanobind::handle type() { + DCHECK(type_); + return nanobind::handle(type_); + } + + static bool IsPyArray(nanobind::handle arg) { + return arg.type().is(PyArray::type()); + } + + absl::Status BlockUntilReady() const; + + absl::Status BlockUntilResultStatusIsReady(); + + absl::StatusOr GetOnDeviceSizeInBytes(); + absl::StatusOr> + SingleDeviceArrayToNumpyArrayDidCopy(); + absl::StatusOr SingleDeviceArrayToNumpyArray(); + absl::Status CopySingleDeviceArrayToHostAsync(); + nanobind::dict CudaArrayInterface(); + absl::StatusOr UnsafeBufferPointer(); + + absl::Status Delete(); + + bool IsDeleted() const; + + PyArray Clone() const; + + static absl::StatusOr> BatchedCopyToDeviceWithSharding( + absl::Span py_arrays, + absl::Span dst_device_lists, + absl::Span dst_shardings, + absl::Span array_copy_semantics); + + static absl::StatusOr BatchedDevicePut( + nanobind::object aval, nanobind::object sharding, + std::vector xs, + absl::Span dst_devices, bool committed, + bool force_copy, PjRtClient::HostBufferSemantics host_buffer_semantics, + bool jax_enable_x64); + + static absl::StatusOr ReorderShards( + PyArray x, nanobind::object dst_sharding, + ifrt::ArrayCopySemantics array_copy_semantics); + + static absl::Status BatchedBlockUntilReady( + std::vector objs); + + absl::Status ReplaceWithAlias(PyArray o); + + private: + absl::StatusOr AssertUnsharded(absl::string_view api); + + nanobind::object CheckAndRearrange(absl::Span py_arrays, + nanobind::object sharding, + nanobind::object aval); + + void SetIfrtArray(ifrt::ArrayRef ifrt_array); + + Storage& GetStorage(); + const Storage& GetStorage() const; + + inline static PyObject* type_ = nullptr; +}; + +class PyArrayResultHandler { + public: + PyArrayResultHandler(nanobind::object aval, nanobind::object sharding, + bool committed, bool skip_checks); + + PyArray Call(absl::Span py_arrays) const; + PyArray Call(PyArray py_array) const; + + PyArray Call(nb_class_ptr py_client, ifrt::ArrayRef ifrt_array, + xla::PjRtFuture<> result_status = xla::PjRtFuture<>()) const; + + private: + nanobind::object aval_; + nanobind::object sharding_; + bool weak_type_; + bool committed_; + bool skip_checks_; + + nb_dtype dtype_; + std::vector shape_; +}; + +absl::StatusOr CudaArrayInterfaceToBuffer( + const nanobind::dict& cai, nb_class_ptr cuda_client, + std::optional device_id); + +} // namespace 
xla + +#endif // JAXLIB_PY_ARRAY_H_ diff --git a/jaxlib/py_client.cc b/jaxlib/py_client.cc new file mode 100644 index 000000000000..fbbb803607ee --- /dev/null +++ b/jaxlib/py_client.cc @@ -0,0 +1,950 @@ +/* Copyright 2020 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/py_client.h" + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "llvm/Support/Casting.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OwningOpRef.h" +#include "mlir/Pass/PassManager.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/optional.h" // IWYU pragma: keep +#include "nanobind/stl/pair.h" // IWYU pragma: keep +#include "nanobind/stl/shared_ptr.h" // IWYU pragma: keep +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "nanobind/stl/string_view.h" // IWYU pragma: keep +#include "nanobind/stl/unique_ptr.h" // IWYU pragma: keep +#include "nanobind/stl/variant.h" // IWYU pragma: keep +#include "nanobind/stl/vector.h" // IWYU pragma: keep +#include "jaxlib/guard_lib.h" +#include "jaxlib/nb_class_ptr.h" +#include "jaxlib/py_array.h" +#include "jaxlib/py_device.h" +#include "jaxlib/py_device_list.h" +#include "jaxlib/py_executable.h" +#include "jaxlib/py_host_callback.h" +#include "jaxlib/py_memory_space.h" +#include "jaxlib/py_values.h" +#include "jaxlib/python_ref_manager.h" +#include "jaxlib/sharding.h" +#include "jaxlib/traceback.h" +#include "xla/literal.h" +#include "xla/pjrt/exceptions.h" +#include "xla/pjrt/mlir_to_hlo.h" +#include "xla/pjrt/pjrt_client.h" +#include "xla/pjrt/pjrt_compiler.h" +#include "xla/pjrt/pjrt_executable.h" +#include "xla/pjrt/pjrt_layout.h" +#include "xla/pjrt/status_casters.h" +#include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/client.h" +#include "xla/python/ifrt/compiler.h" +#include "xla/python/ifrt/device.h" +#include "xla/python/ifrt/device_list.h" +#include "xla/python/ifrt/dtype.h" +#include "xla/python/ifrt/executable.h" +#include "xla/python/ifrt/hlo/hlo_program.h" +#include "xla/python/ifrt/host_callback.h" +#include "xla/python/ifrt/memory.h" +#include "xla/python/ifrt/program.h" +#include "xla/python/nb_absl_span.h" // IWYU pragma: keep +#include "xla/python/nb_numpy.h" +#include "xla/python/pjrt_ifrt/pjrt_array.h" +#include "xla/python/pjrt_ifrt/pjrt_client.h" +#include "xla/python/pjrt_ifrt/pjrt_executable.h" +#include "xla/python/pjrt_ifrt/xla_compiler.h" +#include "xla/python/pprof_profile_builder.h" +#include "xla/python/types.h" +#include "xla/python/version.h" +#include "xla/service/platform_util.h" // IWYU pragma: keep +#include "xla/shape.h" +#include "xla/status_macros.h" +#include 
"xla/tsl/concurrency/ref_count.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/status.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/util.h" + +namespace xla { + +namespace nb = nanobind; + +/*static*/ nb_class_ptr PyClient::Make( + std::shared_ptr ifrt_client) { + auto client = make_nb_class(std::move(ifrt_client)); + Initialize(client); + return client; +} + +PyClient::PyClient(std::shared_ptr ifrt_client) + : ifrt_client_(std::move(ifrt_client)), + client_attributes_(ifrt_client_->Attributes()) { + CHECK(ifrt_client_); +} + +/* static */ void PyClient::Initialize(nb_class_ptr client) { + for (ifrt::Device* device : client->ifrt_client()->devices()) { + client->devices_[device] = make_nb_class(client, device); + + for (ifrt::Memory* memory : device->Memories()) { + auto& py_memory = client->memory_spaces_[memory]; + if (py_memory.get() == nullptr) { + py_memory = make_nb_class(client, memory); + } + } + } +} + +PyClient::~PyClient() { + nb::gil_scoped_release gil; + ifrt_client_ = nullptr; +} + +nb_class_ptr PyClient::GetPyDevice(ifrt::Device* device) { + auto& py_device = devices_[device]; + if (py_device.get() == nullptr) { + py_device = make_nb_class( + nb::borrow>(nb::find(this)), device); + } + return py_device; +} + +nb_class_ptr PyClient::GetPyMemorySpace( + ifrt::Memory* memory_space) { + auto& py_memory = memory_spaces_[memory_space]; + if (py_memory.get() == nullptr) { + py_memory = make_nb_class( + nb::borrow>(nb::find(this)), memory_space); + } + return py_memory; +} + +std::vector> PyClient::Devices() { + std::vector> devices; + auto span = ifrt_client_->devices(); + devices.reserve(span.size()); + for (ifrt::Device* device : span) { + devices.push_back(GetPyDevice(device)); + } + return devices; +} + +std::vector> PyClient::LocalDevices() { + std::vector> devices; + devices.reserve(ifrt_client_->addressable_devices().size()); + for (ifrt::Device* device : ifrt_client_->addressable_devices()) { + devices.push_back(GetPyDevice(device)); + } + return devices; +} + +std::vector> PyClient::GetAllDevices() { + std::vector> devices; + devices.reserve(ifrt_client_->GetAllDevices().size()); + for (ifrt::Device* device : ifrt_client_->GetAllDevices()) { + devices.push_back(GetPyDevice(device)); + } + return devices; +} + +absl::StatusOr> PyClient::DeviceFromLocalHardwareId( + int local_hardware_id) { + TF_ASSIGN_OR_RETURN(ifrt::Device * device, + ifrt_client_->LookupAddressableDevice(local_hardware_id)); + return GetPyDevice(device); +} + +nb::list PyClient::LiveExecutables() { + CHECK(PyGILState_Check()); + nb::ft_lock_guard lock(executables_mutex_); + nb::list executables; + for (PyLoadedExecutable* exec = executables_; exec; exec = exec->next_) { + executables.append(nb::find(exec)); + } + return executables; +} + +absl::Status PyClient::Defragment() { + CHECK(PyGILState_Check()); + if (!llvm::isa(ifrt_client_.get())) { + return absl::UnimplementedError( + "Defragmentation is not supported on this runtime."); + } + ifrt::PlatformId platform_id = ifrt_client_->platform_id(); + bool is_gpu_client = platform_id == CudaId() || platform_id == RocmId() || + platform_id == SyclId(); + + if (!is_gpu_client) { + return absl::UnimplementedError( + "Defragmentation is not supported on this runtime."); + } + + // TODO(b/399879011): This is a GPU-specific implementation of `Defragment`. + // Ideally, this would be replaced with some kind of auto-defrag-on-OOM, or at + // least would not live in this file. 
+ + struct TmpBuffer { + // Non-empty for buffers found in a PyArray_Storage. Multiple Arrays + // can reference the same PjRtBuffer. + std::vector*> pjrt_buffer_ptrs; + // TODO(skyewm): maybe use py_buffer's HostValue + std::shared_ptr host_copy; + }; + + // Synchronously copy all buffers to host + absl::flat_hash_map pjrt_buf_to_tmp_buffer; + + std::vector arrays = LiveArrays(); + for (const PyArray& array : arrays) { + // TODO(hyeontaek): Support non-PjRt Arrays. + // TODO(hyeontaek): Re-construct ifrt::Array with new PjRtBuffer so that + // std::shared_ptr does not need to be updated in-place. + if (array.ifrt_array() == nullptr) { + continue; + } + auto* arr = + llvm::dyn_cast_or_null(array.ifrt_array()); + if (arr == nullptr) { + throw XlaRuntimeError( + "This operation is implemented for a PjRt-compatible backend " + "only."); + } + TF_ASSIGN_OR_RETURN(absl::Span> pjrt_buffers, + arr->mutable_pjrt_buffers()); + for (int i = 0; i < pjrt_buffers.size(); ++i) { + std::shared_ptr& pjrt_buf_ptr = pjrt_buffers[i]; + if (pjrt_buf_ptr->IsDeleted()) { + continue; + } + auto [iter, inserted] = + pjrt_buf_to_tmp_buffer.insert({pjrt_buf_ptr.get(), TmpBuffer()}); + if (inserted) { + TF_ASSIGN_OR_RETURN(iter->second.host_copy, + pjrt_buf_ptr->ToLiteralSync()); + } + iter->second.pjrt_buffer_ptrs.push_back(&pjrt_buf_ptr); + } + } + + // All buffers successfully copied to host, delete on-device copies. + // + // Use blocking delete operation to ensure all memory is actually cleared + // before we start rewriting buffers. + // + // Die instead of returning a bad status because program presumably can't + // continue if we fail to reconstitute device buffers. + for (const auto& it : pjrt_buf_to_tmp_buffer) { + PjRtBuffer* pjrt_buf = it.first; + TF_CHECK_OK(pjrt_buf + ->ReleaseDeviceMemoryOwnership( + /*wait_for_operations_to_complete=*/true) + .status()); + } + + // Copy host copies back to device and update PyArrays in-place. + for (auto& it : pjrt_buf_to_tmp_buffer) { + PjRtBuffer* pjrt_buf = it.first; + TmpBuffer& tmp_buffer = it.second; + std::unique_ptr new_copy = + pjrt_client() + ->BufferFromHostLiteral(*tmp_buffer.host_copy, + pjrt_buf->memory_space()) + .value(); + TF_CHECK_OK(new_copy->GetReadyFuture().Await()); + + std::shared_ptr new_pjrt_buf_ptr(new_copy.release()); + for (std::shared_ptr* pjrt_buffer_ptr : + tmp_buffer.pjrt_buffer_ptrs) { + *pjrt_buffer_ptr = new_pjrt_buf_ptr; + } + } + + // TODO(skyewm): delete executables? + return absl::OkStatus(); +} + +/* static */ absl::StatusOr PyClient::BufferFromPyval( + nb_class_ptr client, nb::handle argument, ifrt::Device* device, + bool force_copy, ifrt::Client::HostBufferSemantics host_buffer_semantics) { + if (device == nullptr) { + TF_RET_CHECK(!client->ifrt_client_->addressable_devices().empty()); + device = client->ifrt_client_->addressable_devices().front(); + } + CHECK(device != nullptr); + + auto transfer_guard_formatter = [&argument, dst_device = device] { + auto type = nb::cast(nb::str(argument.type())); + // Catch exceptions because shape and dtype properties convertible to str + // are not guaranteed to present in an arbitrary argument. 
+ std::string shape; + std::string dtype; + try { + shape = + nb::cast(nb::str(nb::object(argument.attr("shape")))); + } catch (const std::exception& e) { + shape = ""; + } + try { + dtype = + nb::cast(nb::str(nb::object(argument.attr("dtype")))); + } catch (const std::exception& e) { + dtype = ""; + } + return absl::StrCat("type=", type, ", shape=", shape, ", dtype=", dtype, + ", dst_device=", dst_device->DebugString()); + }; + TF_RETURN_IF_ERROR( + jax::ApplyTransferGuardToHostToDevice(transfer_guard_formatter)); + + TF_ASSIGN_OR_RETURN(ifrt::Device * found_device, + client->ifrt_client_->LookupDevice(device->Id())); + if (found_device != device) { + return InvalidArgument("Cannot copy value to device '%s' with '%s' backend", + device->DebugString(), + client->ifrt_client_->platform_name()); + } + GlobalPyRefManager()->CollectGarbage(); + + DevicePutOptions options; + options.squash_64bit_types = false; + options.allow_zero_copy = + (!force_copy && (host_buffer_semantics == + ifrt::Client::HostBufferSemantics::kImmutableZeroCopy)); + TF_ASSIGN_OR_RETURN(DevicePutResult device_put_result, + DevicePutWithDevice(argument, client->ifrt_client_.get(), + device, ifrt::MemoryKind(), options)); + auto sharding = make_nb_class( + client, client->ifrt_client()->MakeDeviceList({device}), + /*memory_kind=*/nb::none()); + + auto traceback = Traceback::Get(); + return PyArray::MakeFromIfrtArrayAndSharding( + std::move(client), std::move(traceback), + std::move(device_put_result.ifrt_array), std::move(sharding), + /*weak_type=*/false, /*committed=*/false, + /*skip_checks=*/true); +} + +namespace { + +// Makes IFRT `CompileOptions` from XLA `CompileOptions` and optional host +// callbacks. +std::unique_ptr MakeIfrtCompileOptions( + CompileOptions options, ifrt::DeviceListRef executable_devices, + std::vector host_callbacks) { + std::vector> + ifrt_loaded_host_callbacks; + ifrt_loaded_host_callbacks.reserve(host_callbacks.size()); + // Extract `ifrt::LoadedHostCallback`s from host callback capsules that were + // created by `PyClient::MakePythonCallbackUsingHostSendAndRecv()`. + for (auto& host_callback : host_callbacks) { + ifrt_loaded_host_callbacks.push_back(tsl::FormRef( + static_cast(host_callback.data()))); + } + return std::make_unique( + std::move(options), std::move(executable_devices), + std::move(ifrt_loaded_host_callbacks)); +} + +// Makes IFRT `DeserializeExecutableOptions` from XLA `CompileOptions` and +// optional host callbacks. +std::unique_ptr +MakeIfrtDeserializeExecutableOptions(std::optional options, + ifrt::DeviceListRef executable_devices, + std::vector host_callbacks) { + std::vector> + ifrt_loaded_host_callbacks; + ifrt_loaded_host_callbacks.reserve(host_callbacks.size()); + // Extract `ifrt::LoadedHostCallback`s from host callback capsules that were + // created by `PyClient::MakePythonCallbackUsingHostSendAndRecv()`. 
+ for (auto& host_callback : host_callbacks) { + ifrt_loaded_host_callbacks.push_back(tsl::FormRef( + static_cast(host_callback.data()))); + } + return std::make_unique( + std::move(options), std::move(executable_devices), + std::move(ifrt_loaded_host_callbacks)); +} + +} // namespace + +/* static */ absl::StatusOr> +PyClient::CompileAndLoadIfrtProgram( + nb_class_ptr client, std::unique_ptr ifrt_program, + std::unique_ptr ifrt_options) { + auto* pjrt_compatible_client = + llvm::dyn_cast_or_null( + client->ifrt_client_.get()); + auto* ifrt_xla_options = + llvm::dyn_cast_or_null(ifrt_options.get()); + // For XLA programs, pass allocated device memory size to compile options for + // pjrt compatible backends. + if (pjrt_compatible_client != nullptr && ifrt_xla_options != nullptr) { + xla::CompileOptions& options = ifrt_xla_options->compile_options; + auto addressable_devices = + pjrt_compatible_client->pjrt_client()->addressable_devices(); + if (!addressable_devices.empty()) { + int device_ordinal = options.executable_build_options.device_ordinal(); + if (device_ordinal < 0) { + device_ordinal = 0; + } + CHECK_LT(device_ordinal, addressable_devices.size()); + auto stats = addressable_devices[device_ordinal]->GetAllocatorStats(); + if (stats.ok() && stats->bytes_limit) { + options.executable_build_options.set_device_memory_size( + *stats->bytes_limit); + } + } + + if (pjrt_compatible_client->pjrt_client()->key_value_store().has_value()) { + options.executable_build_options.set_key_value_store( + *pjrt_compatible_client->pjrt_client()->key_value_store()); + } + } + + ifrt::LoadedExecutableRef ifrt_loaded_executable; + std::optional fingerprint; + { + nb::gil_scoped_release gil_release; + TF_ASSIGN_OR_RETURN( + ifrt_loaded_executable, + client->ifrt_client_->GetDefaultCompiler()->CompileAndLoad( + std::move(ifrt_program), std::move(ifrt_options))); + TF_RETURN_IF_ERROR(ifrt_loaded_executable->GetReadyFuture().Await()); + TF_ASSIGN_OR_RETURN(fingerprint, ifrt_loaded_executable->Fingerprint()); + } + auto traceback = Traceback::Get(); + return make_nb_class( + std::move(client), std::move(ifrt_loaded_executable), + std::move(traceback), std::move(fingerprint)); +} + +/* static */ absl::StatusOr> PyClient::Compile( + nb_class_ptr client, std::string mlir_module, + ifrt::DeviceListRef executable_devices, CompileOptions options) { + ifrt::ExecutableRef executable_ref; + { + mlir::MLIRContext context; + nb::gil_scoped_release gil_release; + TF_ASSIGN_OR_RETURN(mlir::OwningOpRef module, + ParseMlirModuleString(mlir_module, context)); + TF_ASSIGN_OR_RETURN( + auto topology, + client->ifrt_client()->GetTopologyForDevices(executable_devices)); + auto xla_options = std::make_unique( + options, std::move(executable_devices)); + TF_ASSIGN_OR_RETURN(auto pjrt_executable, + PjRtCompile(std::move(options), module.get(), + *topology->description())); + TF_ASSIGN_OR_RETURN(executable_ref, ifrt::PjRtExecutable::Create( + std::move(pjrt_executable))); + } + return make_nb_class(executable_ref); +} + +/* static */ absl::StatusOr> +PyClient::CompileAndLoad(nb_class_ptr client, std::string mlir_module, + ifrt::DeviceListRef executable_devices, + CompileOptions options, + std::vector host_callbacks) { + mlir::MLIRContext context; + TF_ASSIGN_OR_RETURN(mlir::OwningOpRef module, + ParseMlirModuleString(mlir_module, context)); + return CompileAndLoadIfrtProgram( + client, std::make_unique(module.get()), + MakeIfrtCompileOptions(std::move(options), std::move(executable_devices), + std::move(host_callbacks))); +} + +/* 
static */ absl::StatusOr> +PyClient::CompileAndLoad(nb_class_ptr client, std::string mlir_module, + ifrt::DeviceListRef executable_devices, + CompileOptions options, + std::vector host_callbacks) { + mlir::MLIRContext context; + TF_ASSIGN_OR_RETURN(mlir::OwningOpRef module, + ParseMlirModuleString(mlir_module, context)); + + std::vector> + ifrt_loaded_host_callbacks; + ifrt_loaded_host_callbacks.reserve(host_callbacks.size()); + // Extract `ifrt::LoadedHostCallback`s from host callback capsules that were + // created by `PyClient::MakePythonCallbackUsingHostSendAndRecv()`. + for (auto& host_callback : host_callbacks) { + auto callback = tsl::MakeRef( + client->ifrt_client(), std::move(host_callback)); + ifrt_loaded_host_callbacks.push_back(callback); + } + auto compile_options = std::make_unique( + std::move(options), std::move(executable_devices), + std::move(ifrt_loaded_host_callbacks)); + return CompileAndLoadIfrtProgram( + client, std::make_unique(module.get()), + std::move(compile_options)); +} + +absl::StatusOr PyClient::SerializeExecutable( + const PyLoadedExecutable& executable) const { + TF_ASSIGN_OR_RETURN(auto serialized, + executable.ifrt_loaded_executable()->Serialize()); + return nb::bytes(serialized.data(), serialized.size()); +} + +/* static */ absl::StatusOr> +PyClient::DeserializeExecutable(nb_class_ptr client, + nb::bytes serialized, + ifrt::DeviceListRef executable_devices, + std::optional options, + std::vector host_callbacks) { + ifrt::LoadedExecutableRef ifrt_loaded_executable; + std::optional fingerprint; + auto ifrt_deserialize_options = MakeIfrtDeserializeExecutableOptions( + std::move(options), std::move(executable_devices), + std::move(host_callbacks)); + { + nb::gil_scoped_release gil_release; + TF_ASSIGN_OR_RETURN( + ifrt_loaded_executable, + client->ifrt_client_->GetDefaultCompiler()->DeserializeLoadedExecutable( + absl::string_view(serialized.c_str(), serialized.size()), + std::move(ifrt_deserialize_options))); + } + TF_ASSIGN_OR_RETURN(fingerprint, ifrt_loaded_executable->Fingerprint()); + auto traceback = Traceback::Get(); + return make_nb_class( + std::move(client), std::move(ifrt_loaded_executable), + std::move(traceback), std::move(fingerprint)); +} + +namespace { + +struct HeapProfileKey { + std::optional traceback; + int64_t size; + xla::PjRtDevice* device; + bool operator==(const HeapProfileKey& other) const; +}; + +bool HeapProfileKey::operator==(const HeapProfileKey& other) const { + if (size != other.size || device != other.device) { + return false; + } + if ((traceback.has_value()) != (other.traceback.has_value())) { + return false; + } + if (traceback.has_value() && traceback->not_equal(*other.traceback)) { + return false; + } + return true; +} + +template +H AbslHashValue(H h, const HeapProfileKey& key) { + if (key.traceback) { + h = H::combine(std::move(h), nb::hash(*key.traceback)); + } + h = H::combine(std::move(h), key.size, key.device); + return h; +} + +} // namespace + +absl::StatusOr PyClient::HeapProfile() { + CHECK(PyGILState_Check()); + absl::flat_hash_set buffer_set; + absl::flat_hash_map entries; + + auto add_buffer_to_profile = [&](PjRtBuffer* buffer, + std::optional traceback) { + // We only wish to count each PjRtBuffer once, even though they may be + // shared by multiple PyArrays. 
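+    // `buffer_set` deduplicates by raw PjRtBuffer pointer: insert().second is
+    // true only the first time a given buffer is seen, and deleted buffers
+    // are skipped. Samples are then aggregated by (traceback, on-device size,
+    // device) via HeapProfileKey.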
+ if (!buffer->IsDeleted() && buffer_set.insert(buffer).second) { + TF_ASSIGN_OR_RETURN(size_t size, buffer->GetOnDeviceSizeInBytes()); + HeapProfileKey key{traceback, static_cast(size), + buffer->device()}; + ++entries[key]; + } + return absl::OkStatus(); + }; + + std::vector arrays = LiveArrays(); + for (const PyArray& array : arrays) { + if (array.ifrt_array() == nullptr) { + continue; + } + auto* arr = + llvm::dyn_cast_or_null(array.ifrt_array()); + // TODO(hyeontaek): Support non-PjRt Arrays. + if (arr == nullptr) { + throw XlaRuntimeError( + "This operation is implemented for a PjRt-compatible backend " + "only."); + } + for (const auto& buffer : arr->pjrt_buffers()) { + TF_RETURN_IF_ERROR( + add_buffer_to_profile(buffer.get(), array.traceback())); + } + } + + for (PyLoadedExecutable* executable = executables_; executable; + executable = executable->next_) { + HeapProfileKey key{executable->traceback(), + executable->SizeOfGeneratedCodeInBytes(), nullptr}; + ++entries[key]; + } + + PprofProfileBuilder builder; + auto* allocations = builder.profile().add_sample_type(); + allocations->set_type(builder.StringId("allocations")); + allocations->set_unit(builder.StringId("count")); + auto* space = builder.profile().add_sample_type(); + space->set_type(builder.StringId("space")); + space->set_unit(builder.StringId("bytes")); + + const int kind_string_id = builder.StringId("kind"); + const int buffer_string_id = builder.StringId("buffer"); + const int executable_string_id = builder.StringId("executable"); + const int device_string_id = builder.StringId("device"); + for (const auto& entry : entries) { + auto* sample = builder.profile().add_sample(); + if (entry.first.traceback) { + for (const auto& frame : entry.first.traceback->RawFrames()) { + sample->add_location_id(builder.LocationId(frame.first, frame.second)); + } + } + sample->add_value(entry.second); + sample->add_value(entry.first.size * entry.second); + + auto* kind_label = sample->add_label(); + kind_label->set_key(kind_string_id); + if (entry.first.device) { + kind_label->set_str(buffer_string_id); + auto* device_label = sample->add_label(); + device_label->set_key(device_string_id); + std::string device_label_str(entry.first.device->DebugString()); + device_label->set_str(builder.StringId(device_label_str)); + } else { + kind_label->set_str(executable_string_id); + } + } + std::string serialized = builder.profile().SerializeAsString(); + return nb::bytes(serialized.data(), serialized.size()); +} + +absl::StatusOr PyClient::MakePythonCallbackUsingHostSendAndRecv( + nb::callable callable, absl::Span operand_shapes, + absl::Span result_shapes, + absl::Span send_channel_ids, + absl::Span recv_channel_ids, nb::callable serializer) { + TF_ASSIGN_OR_RETURN( + auto loaded_host_callback, + PyHostSendAndRecvLoadedHostCallback::Create( + ifrt_client(), std::move(callable), operand_shapes, result_shapes, + send_channel_ids, recv_channel_ids, std::move(serializer))); + nb::capsule callback_capsule( + loaded_host_callback.release(), [](void* ptr) noexcept { + static_cast(ptr)->DropRef(); + }); + return callback_capsule; +} + +/* static */ int PyClient::tp_traverse(PyObject* self, visitproc visit, + void* arg) { + Py_VISIT(Py_TYPE(self)); + if (!nb::inst_ready(self)) { + return 0; + } + PyClient* c = nb::inst_ptr(self); + for (const auto& [ifrt_device, py_device] : c->devices_) { + Py_VISIT(py_device.ptr()); + } + for (const auto& [ifrt_memory, py_memory] : c->memory_spaces_) { + Py_VISIT(py_memory.ptr()); + } + return 0; +} + +/* static */ int 
PyClient::tp_clear(PyObject* self) { + PyClient* c = nb::inst_ptr(self); + absl::flat_hash_map> devices; + std::swap(devices, c->devices_); + absl::flat_hash_map> memory_spaces; + std::swap(memory_spaces, c->memory_spaces_); + return 0; +} + +PyType_Slot PyClient::slots_[] = { + {Py_tp_traverse, (void*)PyClient::tp_traverse}, + {Py_tp_clear, (void*)PyClient::tp_clear}, + {0, nullptr}, +}; + +/* static */ void PyClient::RegisterPythonTypes(nb::module_& m) { + nb::enum_(m, "HostBufferSemantics") + .value("IMMUTABLE_ONLY_DURING_CALL", + PjRtClient::HostBufferSemantics::kImmutableOnlyDuringCall) + .value("IMMUTABLE_UNTIL_TRANSFER_COMPLETES", + PjRtClient::HostBufferSemantics::kImmutableUntilTransferCompletes) + .value("ZERO_COPY", PjRtClient::HostBufferSemantics::kImmutableZeroCopy); + + nb::class_ py_local_client(m, "Client", nb::is_weak_referenceable(), + nb::type_slots(PyClient::slots_)); + py_local_client.def_prop_ro("platform", &PyClient::platform_name) + .def_prop_ro("_raw_platform", &PyClient::raw_platform_name) + .def_prop_ro("platform_version", &PyClient::platform_version) + .def_prop_ro("runtime_type", &PyClient::runtime_type) + .def("device_count", &PyClient::device_count) + .def("local_device_count", &PyClient::addressable_device_count) + .def("devices", &PyClient::Devices) + .def("local_devices", &PyClient::LocalDevices) + // TODO(hyeontaek): Remove this method once we have a unified API for + // enumerating devices with different criteria. + .def("_get_all_devices", &PyClient::GetAllDevices) + .def("device_from_local_hardware_id", + xla::ValueOrThrowWrapper(&PyClient::DeviceFromLocalHardwareId)) + .def("live_executables", &PyClient::LiveExecutables) + .def("live_arrays", &PyClient::LiveArrays) + .def("live_buffers", &PyClient::LiveArrays) + .def("process_index", &PyClient::process_index) + .def("host_id", &PyClient::process_index) + .def("task_id", &PyClient::process_index) + .def( + "buffer_from_pyval", + [](nb_class_ptr client, nb::handle argument, + PyDevice* device, bool force_copy, + PjRtClient::HostBufferSemantics host_buffer_semantics) { + return ValueOrThrow( + PyClient::BufferFromPyval(std::move(client), argument, + device ? 
device->device() : nullptr, + force_copy, host_buffer_semantics)); + }, + nb::arg("argument"), nb::arg("device").none() = nullptr, + nb::arg("force_copy") = false, + nb::arg("host_buffer_semantics") = + PjRtClient::HostBufferSemantics::kImmutableZeroCopy) + .def( + "compile", + [](nb_class_ptr client, nb::bytes mlir_module, + jax::PyDeviceList& py_executable_devices, CompileOptions options) { + ifrt::DeviceListRef executable_devices = + ValueOrThrow(py_executable_devices.ifrt_device_list()); + return ValueOrThrow(PyClient::Compile( + std::move(client), + std::string(mlir_module.c_str(), mlir_module.size()), + std::move(executable_devices), std::move(options))); + }, + nb::arg("computation"), nb::arg("executable_devices"), + nb::arg("compile_options") = CompileOptions()) + .def( + "compile", + [](nb_class_ptr client, std::string mlir_module, + jax::PyDeviceList& py_executable_devices, CompileOptions options) { + ifrt::DeviceListRef executable_devices = + ValueOrThrow(py_executable_devices.ifrt_device_list()); + return ValueOrThrow(PyClient::Compile( + std::move(client), std::move(mlir_module), + std::move(executable_devices), std::move(options))); + }, + nb::arg("computation"), nb::arg("executable_devices"), + nb::arg("compile_options") = CompileOptions()) + .def( + "compile_and_load", + [](nb_class_ptr client, nb::bytes mlir_module, + jax::PyDeviceList& py_executable_devices, CompileOptions options, + std::vector host_callbacks) { + ifrt::DeviceListRef executable_devices = + ValueOrThrow(py_executable_devices.ifrt_device_list()); + return ValueOrThrow(PyClient::CompileAndLoad( + std::move(client), + std::string(mlir_module.c_str(), mlir_module.size()), + std::move(executable_devices), std::move(options), + std::move(host_callbacks))); + }, + nb::arg("computation"), nb::arg("executable_devices"), + nb::arg("compile_options") = CompileOptions(), + nb::arg("host_callbacks") = std::vector()) + .def( + "compile_and_load", + [](nb_class_ptr client, nb::bytes mlir_module, + jax::PyDeviceList& py_executable_devices, CompileOptions options, + std::vector host_callbacks) { + ifrt::DeviceListRef executable_devices = + ValueOrThrow(py_executable_devices.ifrt_device_list()); + return ValueOrThrow(PyClient::CompileAndLoad( + std::move(client), + std::string(mlir_module.c_str(), mlir_module.size()), + std::move(executable_devices), std::move(options), + std::move(host_callbacks))); + }, + nb::arg("computation"), nb::arg("executable_devices"), + nb::arg("compile_options") = CompileOptions(), + nb::arg("host_callbacks") = std::vector()) + .def( + "compile_and_load", + [](nb_class_ptr client, std::string mlir_module, + jax::PyDeviceList& py_executable_devices, CompileOptions options, + std::vector host_callbacks) { + ifrt::DeviceListRef executable_devices = + ValueOrThrow(py_executable_devices.ifrt_device_list()); + return ValueOrThrow(PyClient::CompileAndLoad( + std::move(client), std::move(mlir_module), + std::move(executable_devices), std::move(options), + std::move(host_callbacks))); + }, + nb::arg("computation"), nb::arg("executable_devices"), + nb::arg("compile_options") = CompileOptions(), + nb::arg("host_callbacks") = std::vector()) + .def( + "compile_and_load", + [](nb_class_ptr client, std::string mlir_module, + jax::PyDeviceList& py_executable_devices, CompileOptions options, + std::vector host_callbacks) { + ifrt::DeviceListRef executable_devices = + ValueOrThrow(py_executable_devices.ifrt_device_list()); + return ValueOrThrow(PyClient::CompileAndLoad( + std::move(client), 
std::move(mlir_module), + std::move(executable_devices), std::move(options), + std::move(host_callbacks))); + }, + nb::arg("computation"), nb::arg("executable_devices"), + nb::arg("compile_options") = CompileOptions(), + nb::arg("host_callbacks") = std::vector()) + // The following two overloads are for users of deprecated APIs who call + // `backend.compile` but do not have visibility to `DeviceList`. + .def( + "compile_and_load", + [](nb_class_ptr client, nb::bytes mlir_module, + nb::sequence& py_executable_devices, CompileOptions options) { + ifrt::DeviceListRef executable_devices = + ValueOrThrow(jax::PyDeviceList(nb::tuple(py_executable_devices)) + .ifrt_device_list()); + return ValueOrThrow(PyClient::CompileAndLoad( + std::move(client), + std::string(mlir_module.c_str(), mlir_module.size()), + std::move(executable_devices), std::move(options), + std::vector())); + }, + nb::arg("computation"), nb::arg("executable_devices"), + nb::arg("compile_options") = CompileOptions()) + .def( + "compile_and_load", + [](nb_class_ptr client, std::string mlir_module, + nb::sequence& py_executable_devices, CompileOptions options) { + ifrt::DeviceListRef executable_devices = + ValueOrThrow(jax::PyDeviceList(nb::tuple(py_executable_devices)) + .ifrt_device_list()); + return ValueOrThrow(PyClient::CompileAndLoad( + std::move(client), std::move(mlir_module), + std::move(executable_devices), std::move(options), + std::vector())); + }, + nb::arg("computation"), nb::arg("executable_devices"), + nb::arg("compile_options") = CompileOptions()) + .def("compile_ifrt_program", + xla::ValueOrThrowWrapper(PyClient::CompileAndLoadIfrtProgram)) + .def("compile_and_load_ifrt_program", + xla::ValueOrThrowWrapper(PyClient::CompileAndLoadIfrtProgram)) + .def("serialize_executable", + xla::ValueOrThrowWrapper(&PyClient::SerializeExecutable)) + .def( + "deserialize_executable", + [](nb_class_ptr client, nb::bytes serialized, + jax::PyDeviceList& py_executable_devices, + std::optional options, + std::vector host_callbacks) { + ifrt::DeviceListRef executable_devices = + ValueOrThrow(py_executable_devices.ifrt_device_list()); + return ValueOrThrow(PyClient::DeserializeExecutable( + std::move(client), std::move(serialized), + std::move(executable_devices), std::move(options), + std::move(host_callbacks))); + }, + nb::arg("serialized"), nb::arg("executable_devices"), + nb::arg("compile_options").none() = nb::none(), + nb::arg("host_callbacks") = std::vector()) + // The following overload is for users of deprecated APIs who call + // `deserialize_executable` but do not have visibility to `DeviceList`. + .def( + "deserialize_executable", + [](nb_class_ptr client, nb::bytes serialized, + nb::sequence& py_executable_devices, + std::optional options) { + ifrt::DeviceListRef executable_devices = + ValueOrThrow(jax::PyDeviceList(nb::tuple(py_executable_devices)) + .ifrt_device_list()); + return ValueOrThrow(PyClient::DeserializeExecutable( + std::move(client), std::move(serialized), + std::move(executable_devices), std::move(options), + std::vector())); + }, + nb::arg("serialized"), nb::arg("executable_devices"), + nb::arg("compile_options").none() = nb::none()) + .def("heap_profile", xla::ValueOrThrowWrapper(&PyClient::HeapProfile)) + // TODO(zhangqiaorjc): Experimental. 
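+      // Defragment() currently only works for GPU (CUDA/ROCm/SYCL)
+      // PJRT-compatible clients; other backends return an Unimplemented
+      // error.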
+ .def("defragment", + [](PyClient& self) { xla::ThrowIfError(self.Defragment()); }) + .def("make_python_callback_from_host_send_and_recv", + xla::ValueOrThrowWrapper( + &PyClient::MakePythonCallbackUsingHostSendAndRecv), + nb::arg("callable"), nb::arg("operand_shapes"), + nb::arg("result_shapes"), nb::arg("send_channel_ids"), + nb::arg("recv_channel_ids"), + nb::arg("serializer").none() = nb::none()) + .def( + "get_default_layout", + [](PyClient& self, nb_dtype dtype, nb::sequence shard_shape, + nb_class_ptr device) + -> std::shared_ptr { + ifrt::DType ifrt_type = xla::ValueOrThrow(DtypeToIfRtDType(dtype)); + std::vector dims = SequenceToVector(shard_shape); + return xla::ValueOrThrow(self.ifrt_client()->GetDefaultLayout( + ifrt_type, dims, device->device(), xla::ifrt::MemoryKind())); + }, + nb::arg("dtype"), nb::arg("shard_shape"), nb::arg("device")) + .def("__getattr__", + [](PyClient& client, absl::string_view name) -> nb::object { + const auto& attrs = client.Attributes().map(); + auto it = attrs.find(name); + if (it != attrs.end()) { + return std::visit([](auto&& v) { return nb::cast(v.value); }, + it->second); + } + throw nb::attribute_error( + absl::StrCat("Unknown attribute ", name).c_str()); + }); +} + +} // namespace xla diff --git a/jaxlib/py_client.h b/jaxlib/py_client.h new file mode 100644 index 000000000000..da89b4718f76 --- /dev/null +++ b/jaxlib/py_client.h @@ -0,0 +1,267 @@ +/* Copyright 2020 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_PY_CLIENT_H_ +#define JAXLIB_PY_CLIENT_H_ + +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "llvm/Support/Casting.h" +#include "nanobind/nanobind.h" +#include "jaxlib/nb_class_ptr.h" +#include "xla/pjrt/exceptions.h" +#include "xla/pjrt/pjrt_client.h" +#include "xla/pjrt/pjrt_executable.h" +#include "xla/python/ifrt/attribute_map.h" +#include "xla/python/ifrt/client.h" +#include "xla/python/ifrt/compiler.h" +#include "xla/python/ifrt/device.h" +#include "xla/python/ifrt/device_list.h" +#include "xla/python/ifrt/program.h" +#include "xla/python/pjrt_ifrt/pjrt_client.h" +#include "xla/shape.h" + +namespace xla { + +class PyClient; +class PyLoadedExecutable; +class PyExecutable; +class PyArray; +class PyDevice; +class PyMemorySpace; +struct PyArray_Storage; + +// Python wrapper around PjRtClient. +// We use a wrapper class to add Python-specific functionality. +class PyClient { + public: + static nb_class_ptr Make(std::shared_ptr ifrt_client); + + // Do not call the constructor directly. Use `PyClient::Make` instead. 
+ explicit PyClient(std::shared_ptr ifrt_client); + virtual ~PyClient(); + + ifrt::Client* ifrt_client() const { return ifrt_client_.get(); } + const std::shared_ptr& shared_ptr_ifrt_client() const { + return ifrt_client_; + } + + // Short-term escape hatch to get PjRtClient from PyClient. + // TODO(hyeontaek): Migrate all users of this method to be agnostic of PjRt. + xla::PjRtClient* pjrt_client() const { + auto* pjrt_client = + llvm::dyn_cast_or_null(ifrt_client_.get()); + if (pjrt_client == nullptr) { + throw XlaRuntimeError( + "This operation is implemented for a PjRt-compatible backend only."); + } + return pjrt_client->pjrt_client(); + } + std::shared_ptr shared_ptr_pjrt_client() { + auto* pjrt_client = + llvm::dyn_cast_or_null(ifrt_client_.get()); + if (pjrt_client == nullptr) { + throw XlaRuntimeError( + "This operation is implemented for a PjRt-compatible backend only."); + } + return pjrt_client->shared_ptr_pjrt_client(); + } + + // Legacy aliases. + std::shared_ptr shared_pjrt_client() { + return shared_ptr_pjrt_client(); + } + + absl::string_view platform_name() const { + // TODO(phawkins): this is a temporary backwards compatibility shim. We + // changed the name PJRT reports for GPU platforms to "cuda" or "rocm", but + // we haven't yet updated JAX clients that expect "gpu". Migrate users and + // remove this code. + if (ifrt_client_->platform_name() == "cuda" || + ifrt_client_->platform_name() == "rocm") { + return "gpu"; + } else { + return ifrt_client_->platform_name(); + } + } + absl::string_view raw_platform_name() const { + // TODO(parkers): Once platform_name() is the same, remove this. + return ifrt_client_->platform_name(); + } + absl::string_view platform_version() const { + return ifrt_client_->platform_version(); + } + absl::string_view runtime_type() const { + return ifrt_client_->runtime_type(); + } + + // Returns implementation-specific attributes about this client, e.g. the PJRT + // C API version if applicable. + const xla::ifrt::AttributeMap& Attributes() const { + return client_attributes_; + } + + int addressable_device_count() const { + return ifrt_client_->addressable_device_count(); + } + int device_count() const { return ifrt_client_->device_count(); } + int process_index() const { return ifrt_client_->process_index(); } + + std::vector> Devices(); + std::vector> LocalDevices(); + // Returns all devices in the client. Private API; only use this method for + // implementing backend._get_all_devices(). + // TODO(hyeontaek): Remove this method once we have a unified API for + // enumerating devices with different criteria. + std::vector> GetAllDevices(); + absl::StatusOr> DeviceFromLocalHardwareId( + int local_hardware_id); + + // Returns the PyDevice associated with the given ifrt::Device. + nb_class_ptr GetPyDevice(ifrt::Device* device); + + // Returns the PyMemorySpace associated with the given ifrt::Memory. + nb_class_ptr GetPyMemorySpace(ifrt::Memory* memory_space); + + // Returns a vector of live PyArray objects. PyArray objects may share + // PjRtBuffers, so there may be duplicates of the same underlying device + // buffer. + std::vector LiveBuffersOnDevice(ifrt::Device* device); + + nanobind::list LiveExecutables(); + + // TODO(zhangqiaorjc): Remove when we have transparent defragmentation. 
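// --- Illustrative aside (not part of the patch) -----------------------------
// The platform_name() shim above folds PJRT's "cuda"/"rocm" platform names
// back into the legacy "gpu" name that existing JAX callers still expect.
// A self-contained restatement of just that mapping:

#include <string_view>

constexpr std::string_view LegacyPlatformName(std::string_view raw) {
  // Only the two GPU backends are renamed; every other platform passes through.
  return (raw == "cuda" || raw == "rocm") ? std::string_view("gpu") : raw;
}
// e.g. LegacyPlatformName("cuda") == "gpu", LegacyPlatformName("tpu") == "tpu".
// ----------------------------------------------------------------------------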
+ absl::Status Defragment(); + + static absl::StatusOr BufferFromPyval( + nb_class_ptr client, nanobind::handle argument, + ifrt::Device* device, bool force_copy, + ifrt::Client::HostBufferSemantics host_buffer_semantics); + + static absl::StatusOr> + CompileAndLoadIfrtProgram(nb_class_ptr client, + std::unique_ptr ifrt_program, + std::unique_ptr ifrt_options); + + static absl::StatusOr> Compile( + nb_class_ptr client, std::string mlir_module, + ifrt::DeviceListRef executable_devices, CompileOptions options); + + static absl::StatusOr> CompileAndLoad( + nb_class_ptr client, std::string mlir_module, + ifrt::DeviceListRef executable_devices, CompileOptions options, + std::vector host_callbacks); + + static absl::StatusOr> CompileAndLoad( + nb_class_ptr client, std::string mlir_module, + ifrt::DeviceListRef executable_devices, CompileOptions options, + std::vector host_callbacks); + + absl::StatusOr SerializeExecutable( + const PyLoadedExecutable& executable) const; + static absl::StatusOr> DeserializeExecutable( + nb_class_ptr client, nanobind::bytes serialized, + ifrt::DeviceListRef executable_devices, + std::optional options, + std::vector host_callbacks); + + absl::StatusOr HeapProfile(); + + // `MakePythonCallbackUsingHostSendAndRecv` takes in an input Python callable + // that takes in arguments of shapes `operand_shapes` and returns results of + // shapes `result_shapes`. The arguments correspond to Send ops in the HLO + // program through `send_channel_ids` and the results correspond to Recv ops + // through `recv_channel_ids`. It returns the host callback as an opaque + // object whose reference will keep the Python callback alive. The host + // callback can be passed to `PyClient::CompileAndLoad` or + // `PyClient::DeserializeExecutable`. The corresponding Send/Recv ops in the + // XLA computation can trigger the execution of this host callback. + // `serializer` is a function that takes `callable` as an argument and returns + // a serialized callable as a string. + // + // The callable receives as arguments NumPy arrays for arguments with array + // types, and None for Token argument. The callable must return a tuple of + // either arrays or None values. + absl::StatusOr MakePythonCallbackUsingHostSendAndRecv( + nanobind::callable callable, absl::Span operand_shapes, + absl::Span result_shapes, + absl::Span send_channel_ids, + absl::Span recv_channel_ids, + nanobind::callable serializer); + + std::vector LiveArrays() const; + + static void RegisterPythonTypes(nanobind::module_& m); + + protected: + static void Initialize(nb_class_ptr client); + + private: + friend class PyLoadedExecutable; + friend class PyArray; + friend struct PyArray_Storage; + + static int tp_traverse(PyObject* self, visitproc visit, void* arg); + static int tp_clear(PyObject* self); + static PyType_Slot slots_[]; + + std::shared_ptr ifrt_client_; + xla::ifrt::AttributeMap client_attributes_; + // Pointers to intrusive doubly-linked lists of arrays and executables, used + // to iterate over all known objects when heap profiling. The list structure + // is protected by the GIL. + + nanobind::ft_mutex executables_mutex_; + // List guarded by executables_mutex_. 
+ PyLoadedExecutable* executables_ = nullptr; + +#ifdef NB_FREE_THREADING + static constexpr size_t kNumArraysShards = 16; +#else + static constexpr size_t kNumArraysShards = 1; +#endif + struct ArraysShard { + mutable nanobind::ft_mutex mutex; + PyArray_Storage* arrays; + }; + std::array arrays_; + + absl::flat_hash_map> devices_; + absl::flat_hash_map> + memory_spaces_; +}; + +// Returns the execution stream id set for the current thread. +inline int64_t& GetExecutionStreamId() { + thread_local int64_t execution_stream_id = 0; + return execution_stream_id; +} + +} // namespace xla + +#endif // JAXLIB_PY_CLIENT_H_ diff --git a/jaxlib/py_client_cpu.cc b/jaxlib/py_client_cpu.cc new file mode 100644 index 000000000000..1943244b51be --- /dev/null +++ b/jaxlib/py_client_cpu.cc @@ -0,0 +1,243 @@ +/* Copyright 2025 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/py_client_cpu.h" + +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/container/inlined_vector.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "include/dlpack/dlpack.h" +#include "nanobind/nanobind.h" +#include "jaxlib/ffi.h" +#include "xla/ffi/api/ffi.h" +#include "xla/ffi/ffi_api.h" +#include "xla/pjrt/host_callback.h" +#include "xla/pjrt/transpose.h" +#include "xla/primitive_util.h" +#include "xla/python/nb_numpy.h" +#include "xla/python/types.h" +#include "xla/shape_util.h" +#include "xla/util.h" +#include "xla/xla_data.pb.h" + +namespace nb = nanobind; + +namespace xla { + +struct CpuTransposePlanCache { + static ffi::TypeId id; + explicit CpuTransposePlanCache(int capacity) : cache(capacity) {} + xla::TransposePlanCache cache; +}; + +ffi::TypeId CpuTransposePlanCache::id = {}; + +XLA_FFI_REGISTER_TYPE(ffi::GetXlaFfiApi(), "CpuTransposePlanCache", + &CpuTransposePlanCache::id); + +static ffi::ErrorOr> +CpuTransposePlanCacheInstantiate(uint64_t index) { + return std::make_unique(/*capacity=*/16); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL( + kCpuTransposePlanCacheInstantiate, CpuTransposePlanCacheInstantiate, + ffi::Ffi::BindInstantiate().Attr("index")); + +ffi::Error XlaFfiPythonCpuCallback(FfiLoadedHostCallbacks* callbacks, + CpuTransposePlanCache* transpose_cache, + uint64_t index, ffi::RemainingArgs args, + ffi::RemainingRets rets) { + nb::gil_scoped_acquire gil; + auto callback = nb::borrow( + static_cast(callbacks->callbacks[index])); + auto nb_args = nb::steal(PyTuple_New(args.size())); + for (size_t i = 0; i < args.size(); ++i) { + auto arg = args.get(i); + auto ptype = static_cast(arg->element_type()); + // TODO(b/395428868): Remove this check once we support subbyte types. 
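// --- Illustrative aside (not part of the patch) -----------------------------
// In the argument loop below, each FFI operand becomes one read-only numpy
// array in the tuple handed to the Python callable, while TOKEN operands are
// passed as None (and are likewise skipped on the result side further down).
// Sub-byte integer operands are unpacked to one byte per element first.
// ----------------------------------------------------------------------------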
+ if (ptype == S1 || ptype == U1) { + return ffi::Error(ffi::ErrorCode::kUnimplemented, + absl::StrFormat("Unsupported primitive type: %s", + PrimitiveType_Name(ptype))); + } + if (ptype == TOKEN) { + PyTuple_SET_ITEM(nb_args.ptr(), i, nb::none().release().ptr()); + continue; + } + auto maybe_dtype = PrimitiveTypeToNbDtype(ptype); + if (!maybe_dtype.ok()) { + return ffi::Error::Internal(maybe_dtype.status().ToString()); + } + auto dtype = maybe_dtype.value(); + auto dims = absl::Span(arg->dimensions().begin(), + arg->dimensions().size()); + // TODO(b/402422886): Remove this once we form Jax arrays directly instead + std::unique_ptr buffer; + const void* data = arg->untyped_data(); + size_t bits_per_element = xla::primitive_util::BitWidth(ptype); + if (bits_per_element == 2 || bits_per_element == 4) { + // NOTE(dsuo): FFI arguments and return buffers are sized assuming + // minimum 1-byte element sizes, even if the data itself is packed. We + // assume that 2-bit and 4-bit types are packed. + size_t size_bytes = arg->element_count() * bits_per_element / 8; + buffer = xla::UnpackIntN(bits_per_element, static_cast(data), + size_bytes); + data = buffer.get(); + } + // We pass in data using default numpy layout i.e., std::nullopt. + auto array = nb_numpy_ndarray(dtype, dims, std::nullopt, data); + array.attr("flags").attr("writeable") = nb::bool_(false); + PyTuple_SET_ITEM(nb_args.ptr(), i, array.release().ptr()); + } + + EnterHostCallback(); + // TODO(dsuo): Change this to use the Python vectorcall protocol, which allows + // you to avoid constructing a tuple for the arguments. + nb::tuple result_tuple; + try { + auto result_object = callback(*nb::borrow(nb_args)); + result_tuple = nb::cast(result_object); + } catch (nb::python_error& e) { + return ffi::Error::Internal( + absl::StrFormat("CpuCallback error calling callback: %s", e.what())); + } + LeaveHostCallback(); + + for (size_t i = 0; i < rets.size(); ++i) { + auto ret = rets.get(i).value(); + auto ptype = static_cast(ret->element_type()); + // TODO(b/402422886): Remove this once we form Jax arrays directly instead + if (ptype == S1 || ptype == U1) { + return ffi::Error(ffi::ErrorCode::kUnimplemented, + absl::StrFormat("Unsupported primitive type: %s", + PrimitiveType_Name(ptype))); + } + if (ptype == TOKEN) continue; + nb::object output = + nb::borrow(PyTuple_GetItem(result_tuple.ptr(), i)); + nb_numpy_ndarray array = nb_numpy_ndarray::ensure(std::move(output)); + absl::Span strides( + reinterpret_cast(array.strides()), array.ndim()); + // We expect the output to be in default numpy layout. 
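// --- Illustrative aside (not part of the patch) -----------------------------
// "Default numpy layout" here means a dense C-order (row-major) array: the
// byte stride of each dimension is the product of the trailing dimensions
// times the element size. The callback result is transposed below only when
// the strides numpy reports differ from that dense row-major pattern.

#include <cstddef>
#include <vector>

std::vector<std::ptrdiff_t> DenseRowMajorByteStrides(
    const std::vector<std::ptrdiff_t>& dims, std::ptrdiff_t itemsize) {
  // Innermost dimension advances by one element; each outer dimension by the
  // full extent of everything inside it.
  std::vector<std::ptrdiff_t> strides(dims.size(), itemsize);
  for (int i = static_cast<int>(dims.size()) - 2; i >= 0; --i) {
    strides[i] = strides[i + 1] * dims[i + 1];
  }
  return strides;
}
// e.g. dims = {2, 3, 4}, itemsize = 4 bytes  =>  strides = {48, 16, 4}.
// ----------------------------------------------------------------------------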
+ auto dims = absl::Span(ret->dimensions().begin(), + ret->dimensions().size()); + auto maybe_expected_shape = ShapeUtil::MakeValidatedShape(ptype, dims); + if (!maybe_expected_shape.ok()) { + return ffi::Error::Internal(maybe_expected_shape.status().ToString()); + } + auto expected_shape = maybe_expected_shape.value(); + auto expected_strides = ByteStridesForShape(expected_shape); + + const void* data = array.data(); + std::unique_ptr buffer; + size_t bits_per_element = xla::primitive_util::BitWidth(ptype); + size_t size_bytes = array.size() * array.itemsize(); + if (strides != expected_strides) { + xla::TransposePlan::Options options; + options.elem_size_in_bytes = xla::primitive_util::ByteWidth(ptype); + options.dims = absl::Span( + reinterpret_cast(array.shape()), array.ndim()); + absl::InlinedVector reversed_layout; + reversed_layout.resize(expected_shape.dimensions().size()); + absl::c_reverse_copy(expected_shape.layout().minor_to_major(), + reversed_layout.begin()); + options.permutation = reversed_layout; + options.input_layout = xla::TransposePlan::Striding{strides}; + auto maybe_plan = transpose_cache->cache.GetOrCreate(options); + if (!maybe_plan.ok()) { + return ffi::Error::Internal(maybe_plan.status().ToString()); + } + auto plan = maybe_plan.value(); + if (bits_per_element == 2 || bits_per_element == 4) { + // NOTE(dsuo): If the data needs to be unpacked, don't use return buffer + // supplied by FFI directly. + buffer = std::make_unique(size_bytes); + plan->Execute(data, buffer.get()); + data = buffer.get(); + } else { + plan->Execute(data, ret->untyped_data()); + data = ret->untyped_data(); + } + } + + // TODO(b/402422886): Remove this once we form Jax arrays directly instead + // of packing/unpacking to/from numpy arrays. + if (bits_per_element == 2 || bits_per_element == 4) { + // NOTE(dsuo): FFI arguments and return buffers are sized assuming + // minimum 1-byte element sizes, even if the data itself is packed. We + // assume that 2-bit and 4-bit types are packed. + buffer = xla::PackIntN(bits_per_element, static_cast(data), + size_bytes); + data = buffer.get(); + size_bytes = (size_bytes * bits_per_element) / 8; + } + + // Copy data to output buffer if haven't already or modified the data to + // write back. 
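// --- Illustrative aside (not part of the patch) -----------------------------
// Size arithmetic behind the 2-bit/4-bit branches above: FFI buffers are
// sized at a minimum of one byte per element, while the sub-byte data itself
// is stored packed, so the packed byte count is
// element_count * bits_per_element / 8.

#include <cstddef>

constexpr std::size_t PackedSizeBytes(std::size_t element_count,
                                      std::size_t bits_per_element) {
  return element_count * bits_per_element / 8;
}
static_assert(PackedSizeBytes(/*element_count=*/12, /*bits_per_element=*/4) == 6,
              "twelve 4-bit values pack into six bytes");
static_assert(PackedSizeBytes(/*element_count=*/12, /*bits_per_element=*/2) == 3,
              "twelve 2-bit values pack into three bytes");
// ----------------------------------------------------------------------------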
+ if (data != ret->untyped_data()) { + std::memcpy(ret->untyped_data(), data, size_bytes); + } + } + + return ffi::Error::Success(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(kXlaFfiPythonCpuCallback, XlaFfiPythonCpuCallback, + ffi::Ffi::Bind() + .Ctx>() + .Ctx>() + .Attr("index") + .RemainingArgs() + .RemainingRets()); + +XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "xla_ffi_python_cpu_callback", + "HOST", + {kCpuTransposePlanCacheInstantiate, nullptr, nullptr, + kXlaFfiPythonCpuCallback}); +XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), + "xla_ffi_partitioned_python_cpu_callback", "HOST", + {kCpuTransposePlanCacheInstantiate, nullptr, nullptr, + kXlaFfiPythonCpuCallback}); + +XLA_FFI_DEFINE_HANDLER_SYMBOL(kXlaBufferPythonCpuCallback, + (jax::XlaBufferCallback), + ffi::Ffi::Bind() + .Ctx() + .Ctx() + .Ctx() + .Ctx>() + .Attr("index") + .RemainingArgs() + .RemainingRets()); + +XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "xla_buffer_python_cpu_callback", + "HOST", kXlaBufferPythonCpuCallback); + +} // namespace xla diff --git a/jaxlib/py_client_cpu.h b/jaxlib/py_client_cpu.h new file mode 100644 index 000000000000..275a57fa06b5 --- /dev/null +++ b/jaxlib/py_client_cpu.h @@ -0,0 +1,28 @@ +/* Copyright 2025 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_PY_CLIENT_CPU_H_ +#define JAXLIB_PY_CLIENT_CPU_H_ + +#include "xla/ffi/api/ffi.h" + +namespace xla { + +XLA_FFI_DECLARE_HANDLER_SYMBOL(kCpuTransposePlanCacheInstantiate); +XLA_FFI_DECLARE_HANDLER_SYMBOL(kXlaFfiPythonCpuCallback); + +} // namespace xla + +#endif // JAXLIB_PY_CLIENT_CPU_H_ diff --git a/jaxlib/py_compile_only_client.cc b/jaxlib/py_compile_only_client.cc new file mode 100644 index 000000000000..f23f09c265a1 --- /dev/null +++ b/jaxlib/py_compile_only_client.cc @@ -0,0 +1,133 @@ +/* Copyright 2023 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "jaxlib/py_compile_only_client.h" + +#include +#include +#include + +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "llvm/Support/Casting.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OwningOpRef.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/shared_ptr.h" // IWYU pragma: keep +#include "nanobind/stl/string_view.h" // IWYU pragma: keep +#include "nanobind/stl/vector.h" // IWYU pragma: keep +#include "jaxlib/nb_class_ptr.h" +#include "jaxlib/py_client.h" +#include "jaxlib/py_device_list.h" +#include "jaxlib/py_executable.h" +#include "xla/pjrt/mlir_to_hlo.h" +#include "xla/pjrt/pjrt_compiler.h" +#include "xla/pjrt/pjrt_executable.h" +#include "xla/pjrt/status_casters.h" +#include "xla/python/compile_only_ifrt/client.h" +#include "xla/python/ifrt/device_list.h" +#include "xla/python/ifrt/executable.h" +#include "xla/python/pjrt_ifrt/pjrt_executable.h" +#include "xla/python/pjrt_ifrt/pjrt_topology.h" +#include "xla/python/pjrt_ifrt/xla_compiler.h" +#include "xla/python/version.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/tsl/python/lib/core/numpy.h" +#include "xla/util.h" +#include "xla/xla_data.pb.h" + +namespace nb = nanobind; + +namespace xla { + +namespace { + +class CompileOnlyPyClient : public PyClient { + public: + using PyClient::PyClient; + + static nb_class_ptr Make( + std::shared_ptr topology) { + auto client = + nb::borrow>(make_nb_class( + std::make_unique(std::move(topology)))); + CompileOnlyPyClient::Initialize(client); + return client; + } + + absl::StatusOr> CompileUnloaded( + absl::string_view mlir_module, ifrt::DeviceListRef executable_devices, + CompileOptions options) { + ifrt::ExecutableRef ifrt_executable; + { + nb::gil_scoped_release gil_release; + mlir::MLIRContext context; + TF_ASSIGN_OR_RETURN(mlir::OwningOpRef module, + ParseMlirModuleString(mlir_module, context)); + auto* ifrt_client = + llvm::dyn_cast_or_null(this->ifrt_client()); + CHECK(ifrt_client) << "CompileOnlyPyClient requires ifrt_client be a " + "CompileOnlyIfRtClient"; + + auto xla_options = std::make_unique( + options, std::move(executable_devices)); + TF_ASSIGN_OR_RETURN(auto executable, + PjRtCompile(std::move(options), module.get(), + *ifrt_client->topology().description())); + TF_ASSIGN_OR_RETURN(ifrt_executable, + ifrt::PjRtExecutable::Create(std::move(executable))); + } + return make_nb_class(ifrt_executable); + } + + private: + static void Initialize(nb_class_ptr client) { + PyClient::Initialize(client); + } +}; + +} // namespace + +nb_class_ptr MakeCompileOnlyClient( + std::shared_ptr topology) { + return CompileOnlyPyClient::Make(std::move(topology)); +} + +void RegisterCompileOnlyClient(nb::module_& m) { + nb::class_(m, "CompileOnlyPyClient") + .def( + "compile", + [](CompileOnlyPyClient& self, nb::bytes mlir_module, + jax::PyDeviceList& py_executable_devices, CompileOptions options, + std::vector host_callbacks) { + ifrt::DeviceListRef executable_devices = + ValueOrThrow(py_executable_devices.ifrt_device_list()); + return ValueOrThrow(self.CompileUnloaded( + absl::string_view(mlir_module.c_str(), mlir_module.size()), + std::move(executable_devices), std::move(options))); + }, + nb::arg("computation"), nb::arg("executable_devices"), + nb::arg("compile_options") = CompileOptions(), + nb::arg("host_callbacks") = 
std::vector()) + .def("compile", + ValueOrThrowWrapper(&CompileOnlyPyClient::CompileUnloaded), + nb::arg("computation"), nb::arg("executable_devices"), + nb::arg("compile_options") = CompileOptions()); +} + +} // namespace xla diff --git a/jaxlib/py_compile_only_client.h b/jaxlib/py_compile_only_client.h new file mode 100644 index 000000000000..4b274871ee96 --- /dev/null +++ b/jaxlib/py_compile_only_client.h @@ -0,0 +1,45 @@ +/* Copyright 2023 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_PY_COMPILE_ONLY_CLIENT_H_ +#define JAXLIB_PY_COMPILE_ONLY_CLIENT_H_ + +#include + +// placeholder for index annotation headers +#include "nanobind/nanobind.h" +#include "jaxlib/nb_class_ptr.h" +#include "jaxlib/py_client.h" +#include "xla/python/pjrt_ifrt/pjrt_topology.h" + +namespace xla { + +// This is a workaround for AOT compilation until topologies and device +// descriptions are better integrated into jax's Python code. It returns a +// PyClient that will return errors for all non-AOT methods. It also exposes a +// different compile method that returns an unloaded executable (vs. PyClient +// usually returns a loaded executable). RegisterCompileOnlyClient() overloads +// the Python "compile" method to return the unloaded executable, and we rely on +// Python duck typing to treat the unloaded executable like a loaded executable +// (except it will raise errors if you try to run it, which is what we want for +// AOT environments). +nb_class_ptr MakeCompileOnlyClient( + std::shared_ptr); + +void RegisterCompileOnlyClient(nanobind::module_& m); + +} // namespace xla + +#endif // JAXLIB_PY_COMPILE_ONLY_CLIENT_H_ diff --git a/jaxlib/py_device.cc b/jaxlib/py_device.cc new file mode 100644 index 000000000000..f830b4f49448 --- /dev/null +++ b/jaxlib/py_device.cc @@ -0,0 +1,350 @@ +/* Copyright 2024 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "jaxlib/py_device.h" + +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" +#include "llvm/Support/Casting.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/optional.h" // IWYU pragma: keep +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "nanobind/stl/string_view.h" // IWYU pragma: keep +#include "nanobind/stl/variant.h" // IWYU pragma: keep +#include "nanobind/stl/vector.h" // IWYU pragma: keep +#include "jaxlib/nb_class_ptr.h" +#include "jaxlib/py_client.h" +#include "jaxlib/py_memory_space.h" +#include "jaxlib/python_ref_manager.h" +#include "xla/layout_util.h" +#include "xla/literal.h" +#include "xla/pjrt/status_casters.h" +#include "xla/python/ifrt/device.h" +#include "xla/python/nb_helpers.h" +#include "xla/python/pjrt_ifrt/pjrt_client.h" +#include "xla/python/pjrt_ifrt/pjrt_device.h" +#include "xla/python/types.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/tsl/framework/allocator.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/util.h" + +namespace nb = ::nanobind; + +namespace xla { + +PyDevice::PyDevice(nb_class_ptr client, ifrt::Device* device) + : client_(std::move(client)), device_(device) {} + +int PyDevice::id() const { return device_->Id().value(); } + +int PyDevice::process_index() const { return device_->ProcessIndex(); } + +absl::string_view PyDevice::platform() const { + // TODO(phawkins): this is a temporary backwards + // compatibility shim. We changed the name PJRT + // reports for GPU platforms to "cuda" or "rocm", + // but we haven't yet updated JAX clients that + // expect "gpu". Migrate users and remove this + // code. + if (client_->platform_name() == "cuda" || + client_->platform_name() == "rocm") { + return absl::string_view("gpu"); + } else { + return client_->platform_name(); + } +} + +absl::string_view PyDevice::device_kind() const { return device_->Kind(); } + +std::optional PyDevice::local_hardware_id() const { + // TODO(phawkins): consider supporting this for non-PJRT devices. 
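// --- Illustrative aside (not part of the patch) -----------------------------
// local_hardware_id() below deliberately returns std::nullopt for devices that
// are not PJRT-backed, not addressable from this process, or that report the
// sentinel value -1, so Python callers see None instead of a magic number.
// ----------------------------------------------------------------------------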
+ ifrt::PjRtDevice* device = llvm::dyn_cast(device_); + if (device == nullptr || !device->IsAddressable()) { + return std::nullopt; + } + int local_hardware_id = device->pjrt_device()->local_hardware_id().value(); + if (local_hardware_id == -1) { + return std::nullopt; + } + return local_hardware_id; +} + +absl::string_view PyDevice::Str() const { return device_->DebugString(); } + +absl::string_view PyDevice::Repr() const { return device_->ToString(); } + +absl::Status PyDevice::TransferToInfeed(LiteralSlice literal) { + GlobalPyRefManager()->CollectGarbage(); + nb::gil_scoped_release gil_release; + auto client = llvm::dyn_cast(client_->ifrt_client()); + auto device = llvm::dyn_cast(device_); + if (client == nullptr || device == nullptr) { + return xla::InvalidArgument( + "TransferToInfeed is only supported for PjRt devices."); + } + return client->TransferToInfeed(device, literal); +} + +absl::StatusOr PyDevice::TransferFromOutfeed(Shape shape) { + GlobalPyRefManager()->CollectGarbage(); + std::shared_ptr literal; + { + nb::gil_scoped_release gil_release; + auto client = llvm::dyn_cast(client_->ifrt_client()); + auto device = llvm::dyn_cast(device_); + if (client == nullptr || device == nullptr) { + return xla::InvalidArgument( + "TransferFromOutfeed is only supported for PjRt devices."); + } + ShapeUtil::ForEachMutableSubshape( + &shape, [](Shape* subshape, const ShapeIndex&) { + if (!subshape->has_layout()) { + LayoutUtil::SetToDefaultLayout(subshape); + } + }); + literal = std::make_shared(shape); + TF_RETURN_IF_ERROR(client->TransferFromOutfeed(device, literal.get())); + } + return LiteralToPython(std::move(literal)); +} + +absl::StatusOr> PyDevice::Memory( + absl::string_view kind) const { + ifrt::Memory* result_memory_space = nullptr; + for (auto* memory_space : device_->Memories()) { + if (memory_space->Kind().memory_kind() == kind) { + if (result_memory_space != nullptr) { + std::string memories = absl::StrJoin( + device_->Memories(), ", ", + [](std::string* out, const auto& memory_space) { + absl::StrAppend(out, *memory_space->Kind().memory_kind()); + }); + auto device_kind = device_->Kind(); + return xla::InvalidArgument( + "Found more than one addressable memory for " + "kind %s which is not allowed. There can only " + "be one memory for each " + "kind. Device %s can address the following " + "memory kinds: %s", + kind, device_kind, memories); + } + result_memory_space = memory_space; + } + } + if (result_memory_space == nullptr) { + std::string memories = absl::StrJoin( + device_->Memories(), ", ", + [](std::string* out, const auto& memory_space) { + absl::StrAppend(out, *memory_space->Kind().memory_kind()); + }); + auto device_kind = device_->Kind(); + return xla::InvalidArgument( + "Could not find memory addressable by device %s. Device %s " + "can address the following memory kinds: %s. 
" + "Got memory kind: %s", + device_kind, device_kind, memories, kind); + } + return client_->GetPyMemorySpace(result_memory_space); +} + +absl::StatusOr> PyDevice::DefaultMemory() const { + TF_ASSIGN_OR_RETURN(auto* memory_space, device_->DefaultMemory()); + return client_->GetPyMemorySpace(memory_space); +} + +nb::list PyDevice::AddressableMemories() const { + nb::list memory_spaces; + for (auto* memory_space : device_->Memories()) { + memory_spaces.append(client_->GetPyMemorySpace(memory_space)); + } + return memory_spaces; +} + +absl::StatusOr> PyDevice::MemoryStats() const { + GlobalPyRefManager()->CollectGarbage(); + ifrt::PjRtDevice* device = llvm::dyn_cast(device_); + if (device == nullptr || !device->IsAddressable()) { + return xla::InvalidArgument( + "MemoryStats is only supported for addressable PjRt devices."); + } + absl::StatusOr maybe_stats = + device->pjrt_device()->GetAllocatorStats(); + if (absl::IsUnimplemented(maybe_stats.status())) { + return std::nullopt; + } + // Raise error if any status other than Unimplemented is returned. + ThrowIfError(maybe_stats.status()); + + nb::dict result; + result["num_allocs"] = maybe_stats->num_allocs; + result["bytes_in_use"] = maybe_stats->bytes_in_use; + result["peak_bytes_in_use"] = maybe_stats->peak_bytes_in_use; + result["largest_alloc_size"] = maybe_stats->largest_alloc_size; + if (maybe_stats->bytes_limit) { + result["bytes_limit"] = *maybe_stats->bytes_limit; + } + result["bytes_reserved"] = maybe_stats->bytes_reserved; + result["peak_bytes_reserved"] = maybe_stats->peak_bytes_reserved; + if (maybe_stats->bytes_reservable_limit) { + result["bytes_reservable_limit"] = *maybe_stats->bytes_reservable_limit; + } + result["largest_free_block_bytes"] = maybe_stats->largest_free_block_bytes; + if (maybe_stats->pool_bytes) { + result["pool_bytes"] = *maybe_stats->pool_bytes; + } + if (maybe_stats->peak_pool_bytes) { + result["peak_pool_bytes"] = *maybe_stats->peak_pool_bytes; + } + return result; +} + +absl::StatusOr PyDevice::GetStreamForExternalReadyEvents() + const { + ifrt::PjRtDevice* device = llvm::dyn_cast(device_); + if (device == nullptr || !device->IsAddressable()) { + return xla::InvalidArgument( + "GetStreamForExternalReadyEvents is only supported for addressable " + "PjRt devices."); + } + return device->pjrt_device()->GetStreamForExternalReadyEvents(); +} + +/* static */ int PyDevice::tp_traverse(PyObject* self, visitproc visit, + void* arg) { + PyDevice* d = nb::inst_ptr(self); + Py_VISIT(d->client().ptr()); + return 0; +} + +/* static */ int PyDevice::tp_clear(PyObject* self) { + PyDevice* d = nb::inst_ptr(self); + nb_class_ptr client; + std::swap(client, d->client_); + return 0; +} + +PyType_Slot PyDevice::slots_[] = { + {Py_tp_traverse, (void*)PyDevice::tp_traverse}, + {Py_tp_clear, (void*)PyDevice::tp_clear}, + {0, nullptr}, +}; + +/* static */ void PyDevice::RegisterPythonType(nb::module_& m) { + nb::class_ device( + m, "Device", nb::type_slots(PyDevice::slots_), + "A descriptor of an available device.\n\nSubclasses are used to " + "represent specific types of devices, e.g. CPUs, GPUs. 
Subclasses may " + "have additional properties specific to that device type."); + device + .def_prop_ro( + "id", &PyDevice::id, + "Integer ID of this device.\n\nUnique across all available devices " + "of this type, including remote devices on multi-host platforms.") + .def_prop_ro("process_index", &PyDevice::process_index, + "Integer index of this device's process.\n\n" + "This is always 0 except on multi-process platforms.") + .def_prop_ro("host_id", &PyDevice::process_index, + "Deprecated; please use process_index") + .def_prop_ro("task_id", &PyDevice::process_index, + "Deprecated; please use process_index") + .def_prop_ro("platform", &PyDevice::platform) + .def_prop_ro("device_kind", &PyDevice::device_kind) + .def_prop_ro("client", &PyDevice::client) + .def_prop_ro( + "local_hardware_id", &PyDevice::local_hardware_id, + "Opaque hardware ID, e.g., the CUDA device number. In general, not " + "guaranteed to be dense, and not guaranteed to be defined on all " + "platforms.") + .def("__str__", &PyDevice::Str) + .def("__repr__", &PyDevice::Repr) + .def("transfer_to_infeed", + ThrowIfErrorWrapper(&PyDevice::TransferToInfeed)) + .def("transfer_from_outfeed", + ValueOrThrowWrapper(&PyDevice::TransferFromOutfeed)) + .def("memory", ValueOrThrowWrapper(&PyDevice::Memory), nb::arg("kind")) + .def("default_memory", ValueOrThrowWrapper(&PyDevice::DefaultMemory), + "Returns the default memory of a device.") + .def("addressable_memories", &PyDevice::AddressableMemories, + "Returns all the memories that a device can address.") + + .def("live_buffers", + [](nb::handle device) { + PythonDeprecationWarning( + /*stacklevel=*/1, + "Per device live_buffers() is deprecated. Please " + "use the jax.live_arrays() for jax.Arrays instead."); + return nb::list(); + }) + .def( + "memory_stats", ValueOrThrowWrapper(&PyDevice::MemoryStats), + "Returns memory statistics for this device keyed by name. May not " + "be implemented on all platforms, and different platforms may return " + "different stats, or -1 for unavailable stats. 'bytes_in_use' is " + "usually available. Intended for diagnostic use.") + .def( + "get_stream_for_external_ready_events", + xla::ValueOrThrowWrapper(&PyDevice::GetStreamForExternalReadyEvents)); + static PyMethodDef get_attr_method = { + "__getattr__", + +[](PyObject* self, PyObject* args) -> PyObject* { + PyObject* key; + if (!PyArg_ParseTuple(args, "O", &key)) { + PyErr_SetString(PyExc_TypeError, "__getattr__ must take 1 argument."); + return nullptr; + } + try { + auto device = nb::cast(nb::handle(self)); + auto name = nb::cast(nb::handle(key)); + const auto& attrs = device->device_->Attributes().map(); + auto it = attrs.find(name); + if (it != attrs.end()) { + auto result = std::visit([](auto&& v) { return nb::cast(v.value); }, + it->second); + return result.release().ptr(); + } + PyErr_SetNone(PyExc_AttributeError); + return nullptr; + } catch (std::exception& e) { + PyErr_Format(PyExc_SystemError, "Unhandled nanobind exception: %s", + e.what()); + return nullptr; + } catch (...) 
{ + PyErr_SetString(PyExc_SystemError, "Unhandled nanobind exception."); + return nullptr; + } + }, + METH_VARARGS, + nullptr, + }; + device.attr("__getattr__") = nb::steal(PyDescr_NewMethod( + reinterpret_cast(device.ptr()), &get_attr_method)); +} + +} // namespace xla diff --git a/jaxlib/py_device.h b/jaxlib/py_device.h new file mode 100644 index 000000000000..8366f8deae3e --- /dev/null +++ b/jaxlib/py_device.h @@ -0,0 +1,83 @@ +/* Copyright 2024 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_PY_DEVICE_H_ +#define JAXLIB_PY_DEVICE_H_ + +#include + +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "nanobind/nanobind.h" +#include "jaxlib/nb_class_ptr.h" +#include "jaxlib/py_client.h" +#include "xla/literal.h" +#include "xla/python/ifrt/device.h" +#include "xla/shape.h" + +namespace xla { + +class PyDevice { + public: + PyDevice(nb_class_ptr client, ifrt::Device* device); + + // Devices are compared using Python object identity, so we don't allow them + // to be copied or moved. + PyDevice(const PyDevice&) = delete; + PyDevice(PyDevice&&) = delete; + PyDevice& operator=(const PyDevice&) = delete; + PyDevice& operator=(PyDevice&&) = delete; + + const nb_class_ptr& client() const { return client_; } + ifrt::Device* device() const { return device_; } + + int id() const; + int process_index() const; + absl::string_view platform() const; + absl::string_view device_kind() const; + std::optional local_hardware_id() const; + + absl::string_view Str() const; + absl::string_view Repr() const; + + absl::Status TransferToInfeed(LiteralSlice literal); + absl::StatusOr TransferFromOutfeed(Shape shape); + + absl::StatusOr> Memory( + absl::string_view kind) const; + absl::StatusOr> DefaultMemory() const; + nanobind::list AddressableMemories() const; + absl::StatusOr> MemoryStats() const; + + absl::StatusOr GetStreamForExternalReadyEvents() const; + + static void RegisterPythonType(nanobind::module_& m); + + private: + static int tp_traverse(PyObject* self, visitproc visit, void* arg); + static int tp_clear(PyObject* self); + static PyType_Slot slots_[]; + + nb_class_ptr client_; + ifrt::Device* device_; +}; + +} // namespace xla + +#endif // JAXLIB_PY_DEVICE_H_ diff --git a/jaxlib/py_device_list.cc b/jaxlib/py_device_list.cc new file mode 100644 index 000000000000..71f1125c749b --- /dev/null +++ b/jaxlib/py_device_list.cc @@ -0,0 +1,497 @@ +/* Copyright 2023 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/py_device_list.h" + +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/container/inlined_vector.h" +#include "absl/hash/hash.h" +#include "absl/log/check.h" +#include "absl/status/statusor.h" +#include "absl/types/span.h" +#include "nanobind/make_iterator.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/set.h" // IWYU pragma: keep +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "nanobind/stl/string_view.h" // IWYU pragma: keep +#include "jaxlib/nb_class_ptr.h" +#include "jaxlib/py_client.h" +#include "jaxlib/py_device.h" +#include "jaxlib/python_ref_manager.h" +#include "xla/python/ifrt/device.h" +#include "xla/python/ifrt/device_list.h" +#include "xla/python/types.h" +#include "xla/util.h" + +namespace jax { + +namespace nb = ::nanobind; + +PyDeviceList::PyDeviceList(xla::nb_class_ptr py_client, + xla::ifrt::DeviceListRef device_list) + : py_client_(std::move(py_client)), device_list_(std::move(device_list)) {} + +PyDeviceList::PyDeviceList(nb::tuple py_device_assignment) + : device_list_(py_device_assignment) { + // Attempt to convert to Python devices into `ifrt::DeviceList`. + if (py_device_assignment.size() == 0) { + return; + } + absl::InlinedVector devices; + devices.reserve(py_device_assignment.size()); + for (nb::handle obj : py_device_assignment) { + if (!nb::isinstance(obj.ptr())) { + // Non-`xla::PyDevice` is used on an alternative JAX backend with device + // duck typing. Use Python device objects already set in `device_list_`. + return; + } + auto py_device = nb::cast(obj); + if (py_client_.get() == nullptr) { + py_client_ = py_device->client(); + } else if (py_device->client().get() != py_client_.get()) { + // If the list contains multiple clients, fall back to device duck typing. + return; + } + devices.push_back(py_device->device()); + } + device_list_ = py_client_->ifrt_client()->MakeDeviceList(devices); +} + +PyDeviceList::~PyDeviceList() { + if (device_list_.index() == 1) { + xla::GlobalPyRefManager()->AddGarbage( + std::move(std::get<1>(std::move(device_list_)))); + } +} + +absl::StatusOr PyDeviceList::ifrt_device_list() + const { + switch (device_list_.index()) { + case 0: + return std::get<0>(device_list_); + case 1: + return xla::InvalidArgument("DeviceList contains non-IFRT devices"); + default: + return xla::InvalidArgument("Unrecognized DeviceList type"); + } +} + +int64_t PyDeviceList::Hash() { + if (!hash_.has_value()) { + switch (device_list_.index()) { + case 0: + hash_ = absl::HashOf(std::get<0>(device_list_)); + break; + case 1: + hash_ = nb::hash(std::get<1>(device_list_)); + break; + default: + throw nb::value_error("Unrecognized DeviceList type"); + } + } + return *hash_; +} + +/*static*/ bool PyDeviceList::Equal(xla::nb_class_ptr self, + nb::handle other) { + if (!nb::isinstance(other)) { + return false; + } + auto o = nb::cast(other); + // Fast-path using a pointer equality check. 
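// --- Illustrative aside (not part of the patch) -----------------------------
// Equality is layered from cheapest to most expensive: identical wrapper
// objects compare equal immediately; differing cached hashes prove inequality
// without touching the device lists; otherwise the underlying IFRT device
// lists are compared (with the GIL released), falling back to an element-wise
// tuple comparison when either side holds duck-typed devices.
// ----------------------------------------------------------------------------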
+ if (self.get() == o) { + return true; + } + int64_t h1, h2; + { + nb::ft_object_guard lock(self); + h1 = self->Hash(); + } + { + nb::ft_object_guard lock(other); + h2 = o->Hash(); + } + if (h1 != h2) { + return false; + } + if (self->device_list_.index() == 0 && o->device_list_.index() == 0) { + nb::gil_scoped_release gil_release; + return *std::get<0>(self->device_list_) == *std::get<0>(o->device_list_); + } else { + return self->AsTuple().equal(o->AsTuple()); + } +} + +/*static*/ bool PyDeviceList::NotEqual(xla::nb_class_ptr self, + nb::handle other) { + return !Equal(std::move(self), other); +} + +int PyDeviceList::Len() const { + switch (device_list_.index()) { + case 0: + return std::get<0>(device_list_)->size(); + case 1: + return nb::len(std::get<1>(device_list_)); + default: + throw nb::value_error("Unrecognized DeviceList type"); + } +} + +nb::object PyDeviceList::GetItem(int index) { + switch (device_list_.index()) { + case 0: { + const xla::ifrt::DeviceListRef& device_list = std::get<0>(device_list_); + if (index < -device_list->size() || index >= device_list->size()) { + throw nb::index_error(); + } else if (index < 0) { + index += device_list->size(); + } + return py_client_->GetPyDevice(device_list->devices()[index]); + } + case 1: + return std::get<1>(device_list_).attr("__getitem__")(index); + default: + throw nb::value_error("Unrecognized DeviceList type"); + } +} + +nb::object PyDeviceList::GetSlice(nb::slice slice) { + switch (device_list_.index()) { + case 0: { + const xla::ifrt::DeviceListRef& device_list = std::get<0>(device_list_); + const absl::Span devices = + device_list->devices(); + Py_ssize_t start, stop, step, slicelength; + if (PySlice_GetIndicesEx(slice.ptr(), devices.size(), &start, &stop, + &step, &slicelength) != 0) { + throw nb::python_error(); + } + nb::tuple out = nb::steal(PyTuple_New(slicelength)); + for (size_t i = 0; i < slicelength; ++i) { + nb::object d = py_client_->GetPyDevice(devices[start]); + PyTuple_SET_ITEM(out.ptr(), i, d.release().ptr()); + start += step; + } + return std::move(out); + } + case 1: + return std::get<1>(device_list_).attr("__getitem__")(slice); + default: + throw nb::value_error("Unrecognized DeviceList type"); + } +} + +nb::tuple PyDeviceList::AsTuple() const { + switch (device_list_.index()) { + case 0: { + const xla::ifrt::DeviceListRef& device_list = std::get<0>(device_list_); + nb::tuple out = nb::steal(PyTuple_New(device_list->size())); + int i = 0; + for (xla::ifrt::Device* device : device_list->devices()) { + nb::object d = py_client_->GetPyDevice(device); + PyTuple_SET_ITEM(out.ptr(), i, d.release().ptr()); + ++i; + } + return out; + } + case 1: + return std::get<1>(device_list_); + default: + throw nb::value_error("Unrecognized DeviceList type"); + } +} + +nb::iterator PyDeviceList::Iter() { + switch (device_list_.index()) { + case 0: { + // Iterator whose deference converts `xla::ifrt::Device*` into JAX + // `PjRtDevice`. 
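// --- Illustrative aside (not part of the patch) -----------------------------
// The adapter below exposes just pre-increment, equality, and dereference;
// dereferencing routes through py_client->GetPyDevice(), so iteration yields
// the PyDevice wrappers rather than raw ifrt::Device* pointers.
// ----------------------------------------------------------------------------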
+ struct Iterator { + void operator++() { ++it; } + bool operator==(const Iterator& other) const { return it == other.it; } + xla::nb_class_ptr operator*() const { + return py_client->GetPyDevice(*it); + } + xla::nb_class_ptr py_client; + absl::Span::const_iterator it; + }; + return nb::make_iterator( + nb::type(), "ifrt_device_iterator", + Iterator{py_client_, std::get<0>(device_list_)->devices().cbegin()}, + Iterator{py_client_, std::get<0>(device_list_)->devices().cend()}); + } + case 1: + return nb::make_iterator( + nb::type(), "python_device_iterator", + std::get<1>(device_list_).begin(), std::get<1>(device_list_).end()); + default: + throw nb::value_error("Unrecognized DeviceList type"); + } +} + +std::string PyDeviceList::Str() { + return nb::cast(nb::str(AsTuple())); +} + +nb::tuple PyDeviceList::Dump() const { return AsTuple(); } + +bool PyDeviceList::IsFullyAddressable() { + if (!is_fully_addressable_.has_value()) { + ProcessIndices(); + CHECK(process_indices_.has_value()); + if (process_indices_->size() > 1) { + is_fully_addressable_ = false; + } else { + CHECK_EQ(process_indices_->size(), 1); + int process_index; + switch (device_list_.index()) { + case 0: { + process_index = py_client_ ? py_client_->process_index() : 0; + break; + } + case 1: { + process_index = + nb::cast(std::get<1>(device_list_)[0].attr("client").attr( + "process_index")()); + break; + } + default: + throw nb::value_error("Unrecognized DeviceList type"); + } + is_fully_addressable_ = *process_indices_->begin() == process_index; + } + } + return *is_fully_addressable_; +} + +/*static*/ xla::nb_class_ptr PyDeviceList::AddressableDeviceList( + xla::nb_class_ptr self) { + nb::ft_object_guard lock(self); + if (self->IsFullyAddressable()) { + // Do not cache this result in `addressable_device_list_`. Otherwise, it + // will create a cycle that prevents deletion of this object. + return self; + } + if (!self->addressable_device_list_.has_value()) { + switch (self->device_list_.index()) { + case 0: { + absl::InlinedVector addressable_devices; + const int process_index = + self->py_client_ ? 
self->py_client_->process_index() : 0; + for (xla::ifrt::Device* device : + std::get<0>(self->device_list_)->devices()) { + if (device->ProcessIndex() == process_index) { + addressable_devices.push_back(device); + } + } + self->addressable_device_list_ = xla::make_nb_class( + self->py_client_, self->py_client_->ifrt_client()->MakeDeviceList( + addressable_devices)); + break; + } + case 1: { + auto device_list = std::get<1>(self->device_list_); + std::vector addressable_devices; + for (size_t i = 0; i < device_list.size(); ++i) { + nb::object device = device_list[i]; + if (nb::cast(device.attr("process_index")) == + nb::cast(device.attr("client").attr("process_index")())) { + addressable_devices.push_back(std::move(device)); + } + } + self->addressable_device_list_ = xla::make_nb_class( + xla::MutableSpanToNbTuple(absl::MakeSpan(addressable_devices))); + break; + } + default: + throw nb::value_error("Unrecognized DeviceList type"); + } + } + return *self->addressable_device_list_; +} + +const std::set& PyDeviceList::ProcessIndices() { + if (!process_indices_.has_value()) { + process_indices_ = std::set{}; + switch (device_list_.index()) { + case 0: { + for (const xla::ifrt::Device* device : + std::get<0>(device_list_)->devices()) { + process_indices_->insert(device->ProcessIndex()); + } + break; + } + case 1: { + for (nb::handle device : std::get<1>(device_list_)) { + process_indices_->insert(nb::cast(device.attr("process_index"))); + } + break; + } + default: + throw nb::value_error("Unrecognized DeviceList type"); + } + } + return *process_indices_; +} + +const std::string& PyDeviceList::DeviceKind() { + if (!device_kind_.has_value()) { + auto device_list = ifrt_device_list(); + if (!device_list.ok()) { + throw nb::value_error(device_list.status().ToString().c_str()); + } + if (Len() == 0) { + throw nb::value_error("DeviceList is empty"); + } + device_kind_ = (*device_list)->devices()[0]->Kind(); + } + return *device_kind_; +} + +void PyDeviceList::PopulateMemoryKindInfo() { + if (device_list_.index() == 1) { + // Handle Python duck-type devices in a separate function for readability. + PopulateMemoryKindInfoForDuckTypedDevices(); + return; + } + if (device_list_.index() != 0) { + throw nb::value_error("Unrecognized DeviceList type"); + } + MemoryKindInfo info; + if (std::get<0>(device_list_)->size() == 0) { + info.default_memory_kind = nb::none(); + memory_kind_info_ = std::move(info); + return; + } + xla::ifrt::Device* device = std::get<0>(device_list_)->devices()[0]; + + auto default_memory = device->DefaultMemory(); + if (!default_memory.ok()) { + // Cache the error. + memory_kind_info_ = default_memory.status(); + return; + } + info.default_memory_kind = nb::cast(*(*default_memory)->Kind().memory_kind()); + nb::tuple memory_kinds = + nb::steal(PyTuple_New(device->Memories().size())); + for (size_t i = 0; i < device->Memories().size(); ++i) { + auto* memory = device->Memories()[i]; + nb::str s = nb::str(memory->Kind().memory_kind()->data(), + memory->Kind().memory_kind()->size()); + PyTuple_SET_ITEM(memory_kinds.ptr(), i, s.release().ptr()); + } + info.memory_kinds = std::move(memory_kinds); + memory_kind_info_ = std::move(info); +} + +void PyDeviceList::PopulateMemoryKindInfoForDuckTypedDevices() { + MemoryKindInfo info; + try { + if (std::get<1>(device_list_).size() == 0) { + info.default_memory_kind = nb::none(); + // info.memory_kinds is default-initialized to an empty tuple. 
+ memory_kind_info_ = std::move(info); + return; + } + nb::handle device = std::get<1>(device_list_)[0]; + auto default_memory = device.attr("default_memory")(); + info.default_memory_kind = default_memory.attr("kind"); + info.memory_kinds = nb::tuple( + nb::object(device.attr("addressable_memories")())); + memory_kind_info_ = std::move(info); + } catch (nb::python_error& e) { + // Cache the error. + memory_kind_info_ = xla::InvalidArgument("%s", e.what()); + } +} + +/*static*/ absl::StatusOr PyDeviceList::MemoryKinds( + xla::nb_class_ptr self) { + nb::ft_object_guard lock(self); + if (!self->memory_kind_info_.has_value()) { + self->PopulateMemoryKindInfo(); + } + if (!self->memory_kind_info_->ok()) { + return self->memory_kind_info_->status(); + } + return (*self->memory_kind_info_)->memory_kinds; +} + +/*static*/ absl::StatusOr PyDeviceList::DefaultMemoryKind( + xla::nb_class_ptr self) { + nb::ft_object_guard lock(self); + if (!self->memory_kind_info_.has_value()) { + self->PopulateMemoryKindInfo(); + } + if (!self->memory_kind_info_->ok()) { + return self->memory_kind_info_->status(); + } + return (*self->memory_kind_info_)->default_memory_kind; +} + +/*static*/ void PyDeviceList::Register(nb::module_& m) { + nb::class_(m, "DeviceList") + .def(nb::init()) + .def("__hash__", &PyDeviceList::Hash, nb::lock_self()) + .def("__eq__", &PyDeviceList::Equal) + .def("__ne__", &PyDeviceList::NotEqual) + .def("__len__", &PyDeviceList::Len) + .def("__getitem__", &PyDeviceList::GetItem) + .def("__getitem__", &PyDeviceList::GetSlice) + .def("__iter__", &PyDeviceList::Iter, nb::keep_alive<0, 1>()) + .def("__str__", &PyDeviceList::Str) + .def("__repr__", &PyDeviceList::Str) + .def("__getstate__", [](const PyDeviceList& l) { return l.Dump(); }) + .def("__setstate__", + [](PyDeviceList& self, nb::tuple t) { + new (&self) PyDeviceList(std::move(t)); + }) + .def_prop_ro("is_fully_addressable", &PyDeviceList::IsFullyAddressable, + nb::lock_self()) + .def_prop_ro("addressable_device_list", + &PyDeviceList::AddressableDeviceList) + .def_prop_ro("process_indices", &PyDeviceList::ProcessIndices, + nb::lock_self()) + // `xla::ValueOrThrowWrapper` does not work with + // `def_prop_ro()`. Manually convert an error into an exception. + .def_prop_ro("default_memory_kind", + [](xla::nb_class_ptr l) { + auto kind = DefaultMemoryKind(l); + if (!kind.ok()) { + throw nb::value_error(kind.status().ToString().c_str()); + } + return *kind; + }) + .def_prop_ro("memory_kinds", [](xla::nb_class_ptr l) { + auto kinds = MemoryKinds(l); + if (!kinds.ok()) { + throw nb::value_error(kinds.status().ToString().c_str()); + } + return *kinds; + }) + .def_prop_ro("device_kind", &PyDeviceList::DeviceKind, nb::lock_self()); +} + +} // namespace jax diff --git a/jaxlib/py_device_list.h b/jaxlib/py_device_list.h new file mode 100644 index 000000000000..8cc44206e734 --- /dev/null +++ b/jaxlib/py_device_list.h @@ -0,0 +1,147 @@ +/* Copyright 2023 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef JAXLIB_PY_DEVICE_LIST_H_ +#define JAXLIB_PY_DEVICE_LIST_H_ + +#include +#include +#include +#include +#include + +#include "absl/status/statusor.h" +#include "nanobind/nanobind.h" +#include "jaxlib/nb_class_ptr.h" +#include "jaxlib/py_client.h" +#include "xla/python/ifrt/device_list.h" + +namespace jax { + +// Device list with various caching and direct access to IFRT DeviceList. +class PyDeviceList { + public: + PyDeviceList(xla::nb_class_ptr py_client, + xla::ifrt::DeviceListRef device_list); + explicit PyDeviceList(nanobind::tuple py_device_assignment); + ~PyDeviceList(); + + PyDeviceList(const PyDeviceList&) = delete; + PyDeviceList(PyDeviceList&&) = delete; + PyDeviceList& operator=(const PyDeviceList&) = delete; + PyDeviceList& operator=(PyDeviceList&&) = delete; + + static nanobind::handle type() { + static auto type = nanobind::type(); + return type; + } + + // These two methods are safe to call from C++ without GIL. + xla::nb_class_ptr py_client() const { return py_client_; } + absl::StatusOr ifrt_device_list() const; + + int Len() const; // Requires the GIL in GIL mode. + nanobind::object GetItem(int index); // Requires the GIL in GIL mode. + + // Requires the GIL in GIL mode. Acquires the self lock in non-GIL mode. + static xla::nb_class_ptr AddressableDeviceList( + xla::nb_class_ptr self); + + // Requires the GIL in GIL mode. Acquires the self lock in non-GIL mode. + static absl::StatusOr DefaultMemoryKind( + xla::nb_class_ptr self); + + // Requires the GIL in GIL mode. Acquires the self lock in non-GIL mode. + static absl::StatusOr MemoryKinds( + xla::nb_class_ptr self); + + // go/pywald-pybind-annotation BEGIN + // refs { + // module_path: "third_party/py/jax/jaxlib/xla.cc" + // module_arg {} + // } + // go/pywald-pybind-annotation END + static void Register(nanobind::module_& m); + + private: + nanobind::tuple AsTuple() const; + + // Methods below require GIL. + nanobind::object GetSlice(nanobind::slice slice); + nanobind::iterator Iter(); + + std::string Str(); + + nanobind::tuple Dump() const; + + int64_t Hash(); // Mutates hash_, needs self lock. + + static bool Equal(xla::nb_class_ptr self, + nanobind::handle other); + static bool NotEqual(xla::nb_class_ptr self, + nanobind::handle other); + + // Finds the memory kind info from an addressable device. Requires the GIL + // or self lock. + void PopulateMemoryKindInfo(); + // Same as `PopulateMemoryKindInfo()`, but uses `py_device_assignment_` + // instead of `ifrt_device_list_` to support duck-typed device objects. + // Requires the GIL or self lock. + void PopulateMemoryKindInfoForDuckTypedDevices(); + + // Requires the self lock or GIL is held. + bool IsFullyAddressable(); + + // Requires the self lock or GIL. + const std::set& ProcessIndices(); + + // Requires the self lock or GIL. + const std::string& DeviceKind(); + + // Valid only if `device_list_` contains `xla::ifrt::DeviceList` and + // non-empty. + xla::nb_class_ptr py_client_; + + // Either C++ `ifrt::DeviceList` or Python duck-type devices. + // TODO(hyeontaek): Remove support for Python duck-type devices once all + // JAX backends and tests are migrated to use an `xla::ifrt::Device` type + // for JAX devices. + // Immutable after constructor; no locking needed. + std::variant device_list_; + + // Populated on demand. Guarded by the object's self lock. + std::optional hash_; + // TODO(hyeontaek): Make the following property cached within + // `xla::ifrt::DeviceList`. 
+ // Populated on demand. Guarded by the object's self lock. + std::optional is_fully_addressable_; + // Populated on demand. Guarded by the object's self lock. + std::optional> addressable_device_list_; + // Populated on demand. Guarded by the object's self lock. + std::optional> process_indices_; + // Populated on demand. Guarded by the object's self lock. + std::optional device_kind_; + + struct MemoryKindInfo { + nanobind::object default_memory_kind; + nanobind::tuple memory_kinds; + }; + // Populated on demand. Guarded by the object's self lock. + std::optional> memory_kind_info_; +}; + +} // namespace jax + +#endif // JAXLIB_PY_DEVICE_LIST_H_ diff --git a/jaxlib/py_executable.cc b/jaxlib/py_executable.cc new file mode 100644 index 000000000000..f3acfa8f62e3 --- /dev/null +++ b/jaxlib/py_executable.cc @@ -0,0 +1,429 @@ +/* Copyright 2020 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/py_executable.h" + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/base/casts.h" +#include "absl/container/inlined_vector.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "nanobind/nanobind.h" +#include "jaxlib/nb_class_ptr.h" +#include "jaxlib/py_array.h" +#include "jaxlib/py_client.h" +#include "jaxlib/py_device.h" +#include "jaxlib/traceback.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/pjrt/pjrt_future.h" +#include "xla/pjrt/pjrt_layout.h" +#include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/device.h" +#include "xla/python/ifrt/device_list.h" +#include "xla/python/ifrt/executable.h" +#include "xla/python/ifrt/future.h" +#include "xla/python/ifrt/memory.h" +#include "xla/python/ifrt/sharding.h" +#include "xla/tsl/concurrency/ref_count.h" +#include "xla/tsl/platform/env.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/status.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/util.h" +#include "xla/xla_data.pb.h" +#include "tsl/platform/fingerprint.h" +#include "tsl/profiler/lib/traceme.h" + +namespace xla { + +namespace nb = nanobind; + +absl::Status PyToken::Await() { + CHECK(future_.IsValid()); + nb::gil_scoped_release gil_release; + return future_.Await(); +} + +absl::Status PyShardedToken::Await() { + nb::gil_scoped_release gil_release; + absl::Status status = absl::OkStatus(); + for (auto& future : futures_) { + auto s = future.Await(); + if (!s.ok()) status = std::move(s); + } + return status; +} + +PyLoadedExecutable::PyLoadedExecutable( + nb_class_ptr client, + ifrt::LoadedExecutableRef ifrt_loaded_executable, + std::optional traceback, std::optional fingerprint) + : client_(std::move(client)), + ifrt_loaded_executable_(std::move(ifrt_loaded_executable)), + 
traceback_(std::move(traceback)), + fingerprint_(std::move(fingerprint)), + next_launch_id_( + fingerprint_.has_value() ? tsl::Fingerprint32(*fingerprint_) : 1) { + CHECK(PyGILState_Check()); + if (fingerprint_) { + VLOG(1) << "Fingerprint for executable " << ifrt_loaded_executable_->name() + << ": " << *fingerprint_; + } + nb::ft_lock_guard lock(client_->executables_mutex_); + next_ = client_->executables_; + client_->executables_ = this; + prev_ = nullptr; + if (next_) { + next_->prev_ = this; + } +} + +PyLoadedExecutable::~PyLoadedExecutable() { + CHECK(PyGILState_Check()); + nb::ft_lock_guard lock(client_->executables_mutex_); + if (client_->executables_ == this) { + client_->executables_ = next_; + } + if (prev_) { + prev_->next_ = next_; + } + if (next_) { + next_->prev_ = prev_; + } +} + +std::vector> PyLoadedExecutable::AddressableDevices() + const { + std::vector> devices; + devices.reserve(ifrt_loaded_executable_->addressable_devices().size()); + for (ifrt::Device* device : ifrt_loaded_executable_->addressable_devices()) { + devices.push_back(client_->GetPyDevice(device)); + } + return devices; +} + +namespace { + +static int GetNumDevices(const ExecuteShardedArg& arg) { + if (std::holds_alternative(arg)) { + return std::get(arg).num_addressable_shards(); + } else { + return std::get>(arg).size(); + } +} +static ifrt::ArrayRef GetIfRtArray(const ExecuteShardedArg& arg) { + if (std::holds_alternative(arg)) { + return tsl::FormRef(std::get(arg).ifrt_array()); + } + auto& arg_vector = std::get>(arg); + + // TODO(hyeontaek): This on-demand Array creation is not efficient and has + // insufficient information about the shape (a dummy shape is used). This + // should be removed if possible and only be used in the context where the + // shape information is unused. + std::vector ifrt_arrays; + ifrt_arrays.reserve(arg_vector.size()); + absl::InlinedVector devices; + devices.reserve(arg_vector.size()); + for (auto& arr : arg_vector) { + CHECK_EQ(arr.ifrt_array()->sharding().devices()->size(), 1) + << arr.ifrt_array()->sharding().DebugString(); + ifrt_arrays.push_back(tsl::FormRef(arr.ifrt_array())); + devices.push_back( + arr.ifrt_array()->sharding().devices()->devices().front()); + } + CHECK(!ifrt_arrays.empty()); + // Use a dummy shape. + // TODO(hyeontaek): Find a way to compute a correct shape. + // TODO(yashkatariya): Plumb sharding or memory_kind here. 
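+  // Assemble the collected single-device shards into one ifrt::Array that
+  // spans `devices`. An opaque sharding is sufficient here because, per the
+  // TODO above, callers of this path rely only on the per-shard contents and
+  // not on a meaningful global shape.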
+ ifrt::Client* client = ifrt_arrays.front()->client(); + auto ifrt_array = client->AssembleArrayFromSingleDeviceArrays( + ifrt_arrays.front()->shape(), + ifrt::OpaqueSharding::Create(client->MakeDeviceList(devices), + ifrt::MemoryKind()), + absl::MakeSpan(ifrt_arrays), ifrt::ArrayCopySemantics::kReuseInput, + ifrt::SingleDeviceShardSemantics::kAddressableShards); + TF_CHECK_OK(ifrt_array.status()); + return *ifrt_array; +} + +void PopulateExecuteShardedResults(const nb_class_ptr& client, + std::vector ifrt_arrays, + const PjRtFuture<>& result_status, + int num_computations, + std::vector>& outputs) { + auto traceback = Traceback::Get(); + DCHECK_GT(num_computations, 0); + int num_output_buffers = ifrt_arrays.size(); + outputs.resize(num_output_buffers); + for (int buffer_id = 0; buffer_id < num_output_buffers; ++buffer_id) { + outputs[buffer_id].reserve(num_computations); + auto exploded_arrays = + ifrt_arrays[buffer_id]->DisassembleIntoSingleDeviceArrays( + ifrt::ArrayCopySemantics::kReuseInput, + ifrt::SingleDeviceShardSemantics::kAddressableShards); + TF_CHECK_OK(exploded_arrays.status()); + for (auto& exploded_array : *exploded_arrays) { + outputs[buffer_id].push_back(PyArray::MakeFromSingleDeviceArray( + client, traceback, std::move(exploded_array), false, true, + result_status)); + } + } +} + +absl::StatusOr ExecuteShardedOnLocalDevicesInternal( + const ifrt::ExecuteOptions& options, const nb_class_ptr& client, + ifrt::LoadedExecutable* ifrt_loaded_executable, + absl::Span args, + std::optional>>& returned_futures) { + std::vector output_arrays; + std::unique_ptr> returned_future; + int num_computations = ifrt_loaded_executable->addressable_devices().size(); + PjRtFuture<> result_status; + { + nb::gil_scoped_release gil_release; + for (const auto& arg : args) { + if (GetNumDevices(arg) != num_computations) { + return InvalidArgument( + "Expected args to execute_sharded_on_local_devices to have %d " + "shards, got: [%s]", + num_computations, + absl::StrJoin(args, ", ", + [](std::string* out, const ExecuteShardedArg& arg) { + out->append(std::to_string(GetNumDevices(arg))); + })); + } + } + std::vector arg_arrays(args.size()); + absl::c_transform(args, arg_arrays.begin(), + [&](const ExecuteShardedArg& arg) mutable { + return GetIfRtArray(arg); + }); + TF_ASSIGN_OR_RETURN(auto result, ifrt_loaded_executable->Execute( + absl::MakeSpan(arg_arrays), options, + /*devices=*/std::nullopt)); + output_arrays = std::move(result.outputs); + // options.fill_status is only supposed to be true when the computation has + // tokens. + if (options.fill_status) { + result_status = result.status; + if (returned_futures.has_value()) { + returned_futures->resize(num_computations, std::move(result.status)); + } + } + } + + // TODO(b/240696624): Although the PjRt interface require `returned_futures` + // to be resized correctly if it is not nullopt, some implementation does not + // implement this. So we have to check whether returned_futures is empty. + // Remove this check once the implementation is fixed. + auto py_sharded_token = returned_futures.has_value() + ? 
PyShardedToken(std::move(*returned_futures)) + : PyShardedToken(); + + return PyExecuteResults(client, std::move(output_arrays), num_computations, + std::move(py_sharded_token), result_status); +} + +} // namespace + +PyExecuteResults::PyExecuteResults(const nb_class_ptr& client, + std::vector ifrt_arrays, + int num_computations, PyShardedToken token, + PjRtFuture<> result_status) + : client_(client), + ifrt_arrays_(std::move(ifrt_arrays)), + num_computations_(num_computations), + token_(std::move(token)), + result_status_(std::move(result_status)) {} + +void PyExecuteResults::CheckNotDisassembled() const { + if (is_exploded_) { + throw nb::value_error("ExecuteResults already exploded."); + } +} + +std::vector PyExecuteResults::Consume() { + CheckNotDisassembled(); + is_exploded_ = true; + return std::move(ifrt_arrays_); +} + +PyShardedToken PyExecuteResults::ConsumeToken() { + if (token_consumed_) { + throw nb::value_error("ExecuteResults token already consumed."); + } + token_consumed_ = true; + return std::move(token_); +} + +std::vector> +PyExecuteResults::DisassembleIntoSingleDeviceArrays() { + std::vector> outputs; + PopulateExecuteShardedResults( + client_, Consume(), + result_status_.IsValid() ? result_status_ : PjRtFuture<>(), + num_computations_, outputs); + return outputs; +} + +std::vector> +PyExecuteResults::DisassemblePrefixIntoSingleDeviceArrays(size_t n) { + CheckNotDisassembled(); + if (n > ifrt_arrays_.size()) { + throw nb::value_error( + absl::StrCat("In DisassemblePrefixIntoSingleDeviceArrays: ", n, " > ", + ifrt_arrays_.size()) + .c_str()); + } + std::vector ifrt_arrays; + ifrt_arrays.reserve(ifrt_arrays_.size() - n); + for (size_t i = n; i < ifrt_arrays_.size(); ++i) { + ifrt_arrays.push_back(std::move(ifrt_arrays_[i])); + } + ifrt_arrays_.erase(ifrt_arrays_.begin() + n, ifrt_arrays_.end()); + std::swap(ifrt_arrays_, ifrt_arrays); + std::vector> outputs; + PopulateExecuteShardedResults( + client_, std::move(ifrt_arrays), + result_status_.IsValid() ? result_status_ : PjRtFuture<>(), + num_computations_, outputs); + return outputs; +} + +std::vector PyExecuteResults::ConsumeWithHandlers( + std::vector> + out_handlers) { + std::vector outputs; + auto ifrt_arrays = Consume(); + auto traceback = Traceback::Get(); + int num_output_buffers = ifrt_arrays.size(); + outputs.reserve(num_output_buffers); + if (out_handlers.size() != num_output_buffers) { + throw nb::value_error( + absl::StrCat("Mismatch between out_handlers and num_results: ", + out_handlers.size(), " vs ", num_output_buffers) + .c_str()); + } + for (int buffer_id = 0; buffer_id < num_output_buffers; ++buffer_id) { + auto& handler = out_handlers[buffer_id]; + if (std::holds_alternative(handler)) { + outputs.push_back(std::get(handler)->Call( + client_, std::move(ifrt_arrays[buffer_id]), + result_status_.IsValid() ? result_status_ : PjRtFuture<>())); + } else { + tsl::profiler::TraceMe traceme("ConsumeWithHandlers fallback."); + auto disassembled_arrays = + ifrt_arrays[buffer_id]->DisassembleIntoSingleDeviceArrays( + ifrt::ArrayCopySemantics::kReuseInput, + ifrt::SingleDeviceShardSemantics::kAddressableShards); + TF_CHECK_OK(disassembled_arrays.status()); + nb::list bufs = + nb::steal(PyList_New(disassembled_arrays->size())); + int i = 0; + for (auto& disassembled_array : *disassembled_arrays) { + nb::object array = PyArray::MakeFromSingleDeviceArray( + client_, traceback, std::move(disassembled_array), false, true, + result_status_.IsValid() ? 
result_status_ : PjRtFuture<>()); + PyList_SET_ITEM(bufs.ptr(), i, array.release().ptr()); + ++i; + } + outputs.push_back(std::get(handler)(std::move(bufs))); + } + } + return outputs; +} + +absl::StatusOr PyLoadedExecutable::ExecuteSharded( + std::vector args, bool with_tokens) { + xla::ifrt::ExecuteOptions options = options_; + options.launch_id = GetNextLaunchId(); + options.fill_status = with_tokens; + options.execution_stream_id = GetExecutionStreamId(); + if (options.execution_stream_id == 0) { + options.execution_stream_id = tsl::Env::Default()->GetCurrentThreadId(); + } + std::optional>> returned_futures; + if (with_tokens) { + returned_futures.emplace(); + } + absl::Span span_args = args; + return ExecuteShardedOnLocalDevicesInternal(options, client_, + ifrt_loaded_executable_.get(), + span_args, returned_futures); +} + +absl::StatusOr>> +PyLoadedExecutable::HloModules() const { + nb::gil_scoped_release gil_release; + return ifrt_loaded_executable_->GetHloModules(); +} + +absl::StatusOr>> +PyLoadedExecutable::GetOutputMemoryKinds() const { + nb::gil_scoped_release gil_release; + return ifrt_loaded_executable_->GetOutputMemoryKinds(); +} + +absl::StatusOr>> +PyLoadedExecutable::GetParameterLayouts() const { + nb::gil_scoped_release gil_release; + return ifrt_loaded_executable_->GetParameterLayouts(); +} + +absl::StatusOr>> +PyLoadedExecutable::GetOutputLayouts() const { + nb::gil_scoped_release gil_release; + return ifrt_loaded_executable_->GetOutputLayouts(); +} + +std::optional> +PyLoadedExecutable::GetParameterShardings() const { + nb::gil_scoped_release gil_release; + return ifrt_loaded_executable_->GetParameterShardings(); +} + +std::optional> PyLoadedExecutable::GetOutputShardings() + const { + nb::gil_scoped_release gil_release; + return ifrt_loaded_executable_->GetOutputShardings(); +} + +int32_t PyLoadedExecutable::GetNextLaunchId() { + return absl::bit_cast( + next_launch_id_.fetch_add(1, std::memory_order_relaxed)); +} + +void PyLoadedExecutable::KeepAlive(nb::object obj) { + keepalives_.push_back(std::move(obj)); +} + +} // namespace xla diff --git a/jaxlib/py_executable.h b/jaxlib/py_executable.h new file mode 100644 index 000000000000..ee68e8388627 --- /dev/null +++ b/jaxlib/py_executable.h @@ -0,0 +1,294 @@ +/* Copyright 2020 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef JAXLIB_PY_EXECUTABLE_H_ +#define JAXLIB_PY_EXECUTABLE_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "llvm/Support/Casting.h" +#include "nanobind/nanobind.h" +#include "jaxlib/nb_class_ptr.h" +#include "jaxlib/py_array.h" +#include "jaxlib/py_client.h" +#include "jaxlib/traceback.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/pjrt/exceptions.h" +#include "xla/pjrt/pjrt_client.h" +#include "xla/pjrt/pjrt_executable.h" +#include "xla/pjrt/pjrt_future.h" +#include "xla/pjrt/pjrt_layout.h" +#include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/attribute_map.h" +#include "xla/python/ifrt/executable.h" +#include "xla/python/pjrt_ifrt/pjrt_executable.h" +#include "xla/tsl/concurrency/ref_count.h" +#include "xla/tsl/platform/status.h" +#include "xla/xla_data.pb.h" + +namespace xla { + +class PyToken { + public: + PyToken() = default; + explicit PyToken(PjRtFuture<> future) : future_(std::move(future)) {} + + static PyToken ReadyPyToken() { + return PyToken(PjRtFuture<>(absl::OkStatus())); + } + + absl::Status Await(); + + private: + PjRtFuture<> future_; +}; + +// PyShardedToken contains a PyToken for each device's execution. +class PyShardedToken { + public: + // Default construction creates a always-ready token. + PyShardedToken() = default; + explicit PyShardedToken(std::vector> futures) + : futures_(std::move(futures)) {} + + PyToken GetPyToken(int device_id) const { + if (futures_.empty()) return PyToken::ReadyPyToken(); + return PyToken(futures_.at(device_id)); + } + + absl::Status Await(); + + private: + std::vector> futures_; +}; + +class PyExecuteResults { + public: + PyExecuteResults(const nb_class_ptr& client, + std::vector ifrt_arrays, + int num_computations, PyShardedToken token, + PjRtFuture<> result_status = PjRtFuture<>()); + + std::vector> DisassembleIntoSingleDeviceArrays(); + + std::vector> DisassemblePrefixIntoSingleDeviceArrays( + size_t n); + + std::vector ConsumeWithHandlers( + std::vector> + out_handlers); + + std::vector Consume(); + + PyShardedToken ConsumeToken(); + + size_t Size() const { + CheckNotDisassembled(); + return ifrt_arrays_.size(); + } + + void CheckNotDisassembled() const; + + private: + bool is_exploded_ = false; + bool token_consumed_ = false; + nb_class_ptr client_; + std::vector ifrt_arrays_; + int num_computations_; + PyShardedToken token_; + // Only set if the computation has tokens. + PjRtFuture<> result_status_; +}; + +using ExecuteShardedArg = std::variant>; + +// Thin Python wrapper around ifrt::ExecutableRef. We use a wrapper class: +// a) Standardize around ifrt::ExecutableRef, which is +// std::shared_ptr. +// b) Concrete subclasses of ifrt::Executable have protected constructors. +class PyExecutable { + public: + PyExecutable(ifrt::ExecutableRef ifrt_executable) + : ifrt_executable_(std::move(ifrt_executable)) {}; + ~PyExecutable() = default; + + // NOTE(dsuo): For now, we only expose the ifrt::Executable members required + // by the Python bindings. 
+ absl::StatusOr>> GetHloModules() + const { + return ifrt_executable_->GetHloModules(); + } + absl::StatusOr>> + GetOutputMemoryKinds() const { + return ifrt_executable_->GetOutputMemoryKinds(); + } + std::optional> GetOutputShardings() const { + return ifrt_executable_->GetOutputShardings(); + } + absl::StatusOr>> + GetParameterLayouts() const { + return ifrt_executable_->GetParameterLayouts(); + } + absl::StatusOr>> + GetOutputLayouts() const { + return ifrt_executable_->GetOutputLayouts(); + } + std::optional> GetParameterShardings() const { + return ifrt_executable_->GetParameterShardings(); + } + absl::StatusOr GetCompiledMemoryStats() const { + return ifrt_executable_->GetCompiledMemoryStats(); + } + absl::StatusOr Serialize() const { + return ifrt_executable_->Serialize(); + } + absl::StatusOr GetCostAnalysis() const { + return ifrt_executable_->GetCostAnalysis(); + } + + private: + ifrt::ExecutableRef ifrt_executable_; +}; + +// Python wrapper around ifrt::LoadedExecutableRef. We use a wrapper class: +// a) to keep the PyClient alive via a std::shared_ptr<> +// b) to add Python-specific functionality. +class PyLoadedExecutable { + public: + PyLoadedExecutable(nb_class_ptr client, + ifrt::LoadedExecutableRef ifrt_loaded_executable, + std::optional traceback, + std::optional fingerprint); + ~PyLoadedExecutable(); + + nb_class_ptr client() const { return client_; } + ifrt::LoadedExecutable* ifrt_loaded_executable() const { + return ifrt_loaded_executable_.get(); + } + + ifrt::LoadedExecutableRef shared_ifrt_loaded_executable() { + return ifrt_loaded_executable_; + } + + std::vector> AddressableDevices() const; + + int64_t SizeOfGeneratedCodeInBytes() const { + return ifrt_loaded_executable_->SizeOfGeneratedCodeInBytes(); + } + + absl::StatusOr GetCompiledMemoryStats() const { + nanobind::gil_scoped_release scope; + return ifrt_loaded_executable_->GetCompiledMemoryStats(); + } + + absl::StatusOr GetCostAnalysis() const { + return ifrt_loaded_executable_->GetCostAnalysis(); + } + + // Takes args indexed by argid then deviceid, transposes them, and passes to + // ifrt::LoadedExecutable::Execute. The result is similarly transposed back + // into the argid,deviceid format. + // args is [num_args x num_devices]. + absl::StatusOr ExecuteSharded( + std::vector args, bool with_tokens); + + absl::StatusOr>> HloModules() const; + + absl::StatusOr>> + GetOutputMemoryKinds() const; + + absl::StatusOr>> + GetParameterLayouts() const; + + absl::StatusOr>> + GetOutputLayouts() const; + + std::optional> GetParameterShardings() const; + + std::optional> GetOutputShardings() const; + + const std::optional& traceback() { return traceback_; } + + ifrt::LoadedExecutable* ifrt_executable() const { + return ifrt_loaded_executable_.get(); + } + + // Short-term escape hatch to get PjRtLoadedExecutable from PyExecutable. + // TODO(hyeontaek): Migrate all users of this method to be agnostic of PjRt. + std::shared_ptr shared_ptr_pjrt_executable() { + auto* exec = llvm::dyn_cast_or_null( + ifrt_loaded_executable_.get()); + if (exec == nullptr) { + throw XlaRuntimeError( + "This operation is implemented for a PjRt-compatible backend only."); + } + return exec->shared_ptr_pjrt_loaded_executable(); + } + + // Returns a template of execute options to pass to + // `ifrt_executable()->Execute()`. Note that the caller may need to override + // some options such as `launch_id` that change at each execution. 
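+  //
+  // Minimal caller sketch (illustrative only; `args` stands for a span of
+  // ifrt::ArrayRef prepared by the caller):
+  //
+  //   ifrt::ExecuteOptions opts = py_exec->options();
+  //   opts.launch_id = py_exec->GetNextLaunchId();
+  //   auto result = py_exec->ifrt_executable()->Execute(
+  //       args, opts, /*devices=*/std::nullopt);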
+ const ifrt::ExecuteOptions& options() const { return options_; } + + // Returns a unique launch ID to use for the next execution. + int32_t GetNextLaunchId(); + + const std::optional& fingerprint() const { return fingerprint_; } + + // Keep `obj` alive as long as PyLoadedExecutable. + void KeepAlive(nanobind::object obj); + + private: + friend class PyClient; + + nb_class_ptr client_; + ifrt::LoadedExecutableRef ifrt_loaded_executable_; + std::optional traceback_; + + // Identical executables (i.e. representing the same program) will have the + // same fingerprint. nullopt on platforms or executables where fingerprints + // aren't implemented. + std::optional fingerprint_; + + // Launch ID to use for the next execution. + std::atomic next_launch_id_; + + // The options to pass to `executable_.Execute`. + ifrt::ExecuteOptions options_; + + // Python objects to keep alive as requested by user. + std::vector keepalives_; + + // Doubly-linked list of all executables known to the client. Protected by the + // GIL. + PyLoadedExecutable* next_; + PyLoadedExecutable* prev_; +}; + +} // namespace xla + +#endif // JAXLIB_PY_EXECUTABLE_H_ diff --git a/jaxlib/py_host_callback.cc b/jaxlib/py_host_callback.cc new file mode 100644 index 000000000000..49525db53ca5 --- /dev/null +++ b/jaxlib/py_host_callback.cc @@ -0,0 +1,259 @@ +/* Copyright 2023 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/py_host_callback.h" + +#include +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/types/span.h" +#include "llvm/Support/ExtensibleRTTI.h" +#include "nanobind/nanobind.h" +#include "jaxlib/callback.h" +#include "jaxlib/py_host_callback.pb.h" +#include "jaxlib/python_ref_manager.h" +#include "xla/layout_util.h" +#include "xla/pjrt/host_callback.h" +#include "xla/python/ifrt/client.h" +#include "xla/python/ifrt/host_callback.h" +#include "xla/python/pjrt_ifrt/pjrt_host_callback.h" +#include "xla/python/pjrt_ifrt/xla_host_callback.pb.h" +#include "xla/python/types.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/status_macros.h" +#include "xla/tsl/concurrency/ref_count.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/util.h" +#include "xla/xla_data.pb.h" + +namespace nb = nanobind; + +namespace xla { + +char PyFfiLoadedHostCallback::ID = 0; +char PyHostSendAndRecvLoadedHostCallback::ID = 0; + +namespace { + +absl::StatusOr> CreateCallbackArgs( + absl::Span operand_shapes) { + std::vector callback_args(operand_shapes.size()); + for (int i = 0; i < operand_shapes.size(); ++i) { + Shape shape = operand_shapes[i]; + + if (shape.IsArray()) { + Shape layout = + (shape.has_layout() ? 
shape + : LayoutUtil::GetWithDefaultLayout(shape)); + callback_args[i].dims.resize(shape.dimensions_size()); + absl::c_copy(shape.dimensions(), callback_args[i].dims.begin()); + callback_args[i].strides = ByteStridesForShape(layout); + callback_args[i].type = shape.element_type(); + callback_args[i].size_in_bytes = ShapeUtil::ByteSizeOf(layout); + TF_ASSIGN_OR_RETURN(callback_args[i].dtype, + PrimitiveTypeToNbDtype(shape.element_type())); + } else if (shape.IsToken()) { + callback_args[i].type = TOKEN; + } else { + return InvalidArgument( + "Only array and token arguments to Python callbacks are supported, " + "got %s", + shape.ToString()); + } + } + return callback_args; +} + +absl::StatusOr> CreateCallbackResults( + absl::Span result_shapes) { + std::vector callback_results(result_shapes.size()); + for (int i = 0; i < result_shapes.size(); ++i) { + if (result_shapes[i].IsArray()) { + const Shape& shape = + result_shapes[i].has_layout() + ? result_shapes[i] + : LayoutUtil::GetWithDefaultLayout(result_shapes[i]); + callback_results[i].expected_dims.resize(shape.dimensions_size()); + absl::c_copy(shape.dimensions(), + callback_results[i].expected_dims.begin()); + callback_results[i].expected_strides = ByteStridesForShape(shape); + callback_results[i].type = shape.element_type(); + callback_results[i].size_in_bytes = ShapeUtil::ByteSizeOf(shape); + callback_results[i].reversed_layout.resize(shape.dimensions_size()); + absl::c_reverse_copy(shape.layout().minor_to_major(), + callback_results[i].reversed_layout.begin()); + } else if (result_shapes[i].IsToken()) { + callback_results[i].type = TOKEN; + } else { + return InvalidArgument( + "Only array and token return values from Python callbacks are " + "supported, got %s", + result_shapes[i].ToString()); + } + } + return callback_results; +} + +} // namespace + +PyFfiLoadedHostCallback::~PyFfiLoadedHostCallback() { + // The destructor may be called without GIL held. In that case, we defer it + // to GlobalPyRefManager. + std::vector objects; + objects.push_back(std::move(callable_)); + GlobalPyRefManager()->AddGarbage(absl::MakeSpan(objects)); +} + +absl::StatusOr> +PyHostSendAndRecvLoadedHostCallback::Create( + ifrt::Client* ifrt_client, nb::callable callable, + absl::Span operand_shapes, + absl::Span result_shapes, + absl::Span send_channel_ids, + absl::Span recv_channel_ids, nb::callable serializer) { + TF_ASSIGN_OR_RETURN(auto callback_args, CreateCallbackArgs(operand_shapes)); + TF_ASSIGN_OR_RETURN(auto callback_results, + CreateCallbackResults(result_shapes)); + + // `callable` will be destroyed safely with `PythonRefManager` when + // `CpuCallback` is destroyed. + auto cpu_callback = + std::make_shared(callable, callback_args, callback_results); + + auto host_callback = std::make_unique(); + + auto assign_arg_info = [](absl::Span shapes, + absl::Span channel_ids, + std::vector& arg_infos) { + DCHECK_EQ(shapes.size(), channel_ids.size()); + arg_infos.reserve(shapes.size()); + for (int i = 0; i < shapes.size(); ++i) { + HostCallbackArgInfo host_callback_arg_info; + host_callback_arg_info.channel_id = channel_ids[i]; + const auto& shape = shapes[i]; + Shape layout = + (shape.has_layout() ? 
shape + : LayoutUtil::GetWithDefaultLayout(shape)); + host_callback_arg_info.shape = layout; + arg_infos.push_back(std::move(host_callback_arg_info)); + } + }; + + assign_arg_info(operand_shapes, send_channel_ids, host_callback->operands); + assign_arg_info(result_shapes, recv_channel_ids, host_callback->results); + + host_callback->callback = [cpu_callback = std::move(cpu_callback)]( + void** outputs, void** inputs) { + return cpu_callback->PrepareAndCall(outputs, inputs); + }; + return tsl::RCReference( + tsl::MakeRef( + ifrt_client, std::move(host_callback), callable, operand_shapes, + result_shapes, send_channel_ids, recv_channel_ids, + std::move(serializer))); +} + +PyHostSendAndRecvLoadedHostCallback::PyHostSendAndRecvLoadedHostCallback( + ifrt::Client* ifrt_client, + std::unique_ptr xla_host_callback, nb::callable callable, + absl::Span operand_shapes, + absl::Span result_shapes, + absl::Span send_channel_ids, + absl::Span recv_channel_ids, nb::callable serializer) + : llvm::RTTIExtends( + ifrt_client, std::move(xla_host_callback)), + callable_(std::move(callable)), + operand_shapes_(operand_shapes.begin(), operand_shapes.end()), + result_shapes_(result_shapes.begin(), result_shapes.end()), + send_channel_ids_(send_channel_ids.begin(), send_channel_ids.end()), + recv_channel_ids_(recv_channel_ids.begin(), recv_channel_ids.end()), + serializer_(serializer) {} + +PyHostSendAndRecvLoadedHostCallback::~PyHostSendAndRecvLoadedHostCallback() { + GlobalPyRefManager()->AddGarbage( + absl::MakeSpan(static_cast(&callable_), 1)); + GlobalPyRefManager()->AddGarbage( + absl::MakeSpan(static_cast(&serializer_), 1)); +} + +absl::StatusOr PyHostSendAndRecvLoadedHostCallback::Serialize() + const { + if (serializer_.is_none()) { + return InvalidArgument( + "Host callback cannot be serialized because serializer was not " + "provided by JAX"); + } + ifrt::XlaHostCallbackProto xla_host_callback_proto; + + TF_RET_CHECK(operand_shapes_.size() == send_channel_ids_.size()); + for (int i = 0; i < operand_shapes_.size(); ++i) { + ifrt::XlaHostCallbackProto::ArgInfo* const operand = + xla_host_callback_proto.add_operands(); + operand->set_channel_id(send_channel_ids_[i]); + *operand->mutable_shape() = operand_shapes_[i].ToProto(); + } + + TF_RET_CHECK(result_shapes_.size() == recv_channel_ids_.size()); + for (int i = 0; i < result_shapes_.size(); ++i) { + ifrt::XlaHostCallbackProto::ArgInfo* const result = + xla_host_callback_proto.add_results(); + result->set_channel_id(recv_channel_ids_[i]); + *result->mutable_shape() = result_shapes_[i].ToProto(); + } + + std::string callable; + { + nb::gil_scoped_acquire gil_acquire; + try { + nb::bytes bytes = nb::cast(serializer_(callable_)); + callable = std::string(bytes.c_str(), bytes.size()); + } catch (const nb::python_error& e) { + return absl::InternalError(absl::StrCat( + "Unable to pickle the host_callback callable: ", e.what())); + } catch (const std::exception& e) { + std::exception_ptr p = std::current_exception(); + return absl::InternalError(absl::StrCat( + "Exception while pickling the host_callback callable: ", e.what())); + } catch (...) { + // Ensure to avoid leaking any exception because this method could have + // been called outside of a Python context where C++ exceptions are not + // necessarily enabled. 
+ return absl::InternalError( + "Unknown exception while pickling the host_callback callable."); + } + } + PyHostCallbackProto py_host_callback_proto; + py_host_callback_proto.set_callable(std::move(callable)); + if (!xla_host_callback_proto.mutable_serialized_callback()->PackFrom( + py_host_callback_proto)) { + return absl::InternalError("Could not serialize a Python host callback"); + } + xla_host_callback_proto.set_use_major_to_minor_data_layout_for_callbacks( + true); + return xla_host_callback_proto.SerializeAsString(); +} + +} // namespace xla diff --git a/jaxlib/py_host_callback.h b/jaxlib/py_host_callback.h new file mode 100644 index 000000000000..b98338988bfd --- /dev/null +++ b/jaxlib/py_host_callback.h @@ -0,0 +1,119 @@ +/* Copyright 2023 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_PY_HOST_CALLBACK_H_ +#define JAXLIB_PY_HOST_CALLBACK_H_ + +#include +#include +#include +#include +#include + +#include "absl/status/statusor.h" +#include "absl/types/span.h" +#include "llvm/Support/ExtensibleRTTI.h" +#include "nanobind/nanobind.h" +#include "xla/pjrt/host_callback.h" +#include "xla/python/ifrt/client.h" +#include "xla/python/ifrt/host_callback.h" +#include "xla/python/pjrt_ifrt/pjrt_host_callback.h" +#include "xla/shape.h" +#include "xla/tsl/concurrency/ref_count.h" +#include "xla/util.h" + +namespace xla { + +using PyLoadedHostCallback = ::xla::ifrt::LoadedHostCallback; + +class PyFfiLoadedHostCallback final + : public llvm::RTTIExtends { + public: + PyFfiLoadedHostCallback(ifrt::Client* ifrt_client, + nanobind::callable callable) + : llvm::RTTIExtends(ifrt_client, + callable.ptr()), + callable_(std::move(callable)) {} + ~PyFfiLoadedHostCallback() override; + + ifrt::Client* client() const override { return ifrt_client_; } + absl::StatusOr Serialize() const override { + return Unimplemented( + "PyFfiLoadedHostCallback::Serialize() is not supported"); + }; + + static char ID; // NOLINT + + private: + ifrt::Client* ifrt_client_; + nanobind::callable callable_; +}; + +// `PyHostSendAndRecvLoadedHostCallback` implements a Python host callback that +// uses XLA host send and recv. This object should be passed to the compiler +// when creating `xla::ifrt::LoadedExecutable`. +// +// Serialization is supported if the Python host callback using the +// `cloudpickle` third-party library. +// +// TODO(hyeontaek): Update the comment ("compiler" to "client") after splitting +// compilation and loading. +class PyHostSendAndRecvLoadedHostCallback final + : public llvm::RTTIExtends { + public: + static absl::StatusOr> + Create(ifrt::Client* ifrt_client, nanobind::callable callable, + absl::Span operand_shapes, + absl::Span result_shapes, + absl::Span send_channel_ids, + absl::Span recv_channel_ids, + nanobind::callable serializer); + + // PjRtLoadedHostCallback implementation. 
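+  // Note: Serialize() only succeeds when a `serializer` callable was supplied
+  // to Create(); it is expected to pickle the Python callable (typically via
+  // cloudpickle, as described in the class comment above).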
+ + ~PyHostSendAndRecvLoadedHostCallback() override; + + absl::StatusOr Serialize() const override; + + static char ID; // NOLINT + + private: + PyHostSendAndRecvLoadedHostCallback( + ifrt::Client* ifrt_client, + std::unique_ptr xla_host_callback, + nanobind::callable callable, absl::Span operand_shapes, + absl::Span result_shapes, + absl::Span send_channel_ids, + absl::Span recv_channel_ids, + nanobind::callable serializer); + + template + friend tsl::RCReference tsl::MakeRef(Args&&... args); + + // Retained arguments for host callback serialization. + nanobind::callable callable_; + std::vector operand_shapes_; + std::vector result_shapes_; + std::vector send_channel_ids_; + std::vector recv_channel_ids_; + nanobind::callable serializer_; +}; + +} // namespace xla + +#endif // JAXLIB_PY_HOST_CALLBACK_H_ diff --git a/jaxlib/py_host_callback.proto b/jaxlib/py_host_callback.proto new file mode 100644 index 000000000000..997fc7fe450c --- /dev/null +++ b/jaxlib/py_host_callback.proto @@ -0,0 +1,25 @@ +/* Copyright 2023 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +syntax = "proto3"; + +package xla; + +// Represents a JAX host callback that is serialized using the 'cloudpickle' +// Python library. Typically used for +// `xla.ifrt.XlaHostCallbackProto.serialized_callback`. +message PyHostCallbackProto { + bytes callable = 1; +} diff --git a/jaxlib/py_memory_space.cc b/jaxlib/py_memory_space.cc new file mode 100644 index 000000000000..2c123942a92d --- /dev/null +++ b/jaxlib/py_memory_space.cc @@ -0,0 +1,102 @@ +/* Copyright 2024 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/py_memory_space.h" + +#include + +#include + +#include "absl/strings/string_view.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/string_view.h" // IWYU pragma: keep +#include "jaxlib/nb_class_ptr.h" +#include "jaxlib/py_client.h" +#include "xla/python/ifrt/device.h" + +namespace nb = ::nanobind; + +namespace xla { + +PyMemorySpace::PyMemorySpace(nb_class_ptr client, + ifrt::Memory* memory) + : client_(std::move(client)), memory_(memory) {} + +int PyMemorySpace::process_index() const { return client_->process_index(); } + +absl::string_view PyMemorySpace::platform() const { + // TODO(phawkins): this is a temporary backwards + // compatibility shim. 
We changed the name PJRT + // reports for GPU platforms to "cuda" or "rocm", + // but we haven't yet updated JAX clients that + // expect "gpu". Migrate users and remove this + // code. + if (client_->platform_name() == "cuda" || + client_->platform_name() == "rocm") { + return absl::string_view("gpu"); + } else { + return client_->platform_name(); + } +} + +absl::string_view PyMemorySpace::kind() const { + return *memory_->Kind().memory_kind(); +} + +absl::string_view PyMemorySpace::Str() const { return memory_->DebugString(); } + +absl::string_view PyMemorySpace::Repr() const { return memory_->ToString(); } + +nb::list PyMemorySpace::AddressableByDevices() const { + nb::list devices; + for (ifrt::Device* device : memory_->Devices()) { + devices.append(client_->GetPyDevice(device)); + } + return devices; +} + +/* static */ int PyMemorySpace::tp_traverse(PyObject* self, visitproc visit, + void* arg) { + PyMemorySpace* d = nb::inst_ptr(self); + Py_VISIT(d->client().ptr()); + return 0; +} + +/* static */ int PyMemorySpace::tp_clear(PyObject* self) { + PyMemorySpace* d = nb::inst_ptr(self); + nb_class_ptr client; + std::swap(client, d->client_); + return 0; +} + +PyType_Slot PyMemorySpace::slots_[] = { + {Py_tp_traverse, (void*)PyMemorySpace::tp_traverse}, + {Py_tp_clear, (void*)PyMemorySpace::tp_clear}, + {0, nullptr}, +}; + +/* static */ void PyMemorySpace::RegisterPythonType(nb::module_& m) { + nb::class_ device(m, "Memory", + nb::type_slots(PyMemorySpace::slots_)); + device.def_prop_ro("process_index", &PyMemorySpace::process_index) + .def_prop_ro("platform", &PyMemorySpace::platform) + .def_prop_ro("kind", &PyMemorySpace::kind) + .def("__str__", &PyMemorySpace::Str) + .def("__repr__", &PyMemorySpace::Repr) + .def("addressable_by_devices", &PyMemorySpace::AddressableByDevices, + "Returns devices that can address this memory."); +} + +} // namespace xla diff --git a/jaxlib/py_memory_space.h b/jaxlib/py_memory_space.h new file mode 100644 index 000000000000..2196a6cd9f30 --- /dev/null +++ b/jaxlib/py_memory_space.h @@ -0,0 +1,65 @@ +/* Copyright 2024 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_PY_MEMORY_SPACE_H_ +#define JAXLIB_PY_MEMORY_SPACE_H_ + +#include + +#include "absl/strings/string_view.h" +#include "nanobind/nanobind.h" +#include "jaxlib/nb_class_ptr.h" +#include "jaxlib/py_client.h" +#include "xla/python/ifrt/memory.h" + +namespace xla { + +class PyMemorySpace { + public: + PyMemorySpace(nb_class_ptr client, ifrt::Memory* memory_space); + + // Memory spaces are compared using Python object identity, so we don't allow + // them to be copied or moved. 
+ PyMemorySpace(const PyMemorySpace&) = delete; + PyMemorySpace(PyMemorySpace&&) = delete; + PyMemorySpace& operator=(const PyMemorySpace&) = delete; + PyMemorySpace& operator=(PyMemorySpace&&) = delete; + + const nb_class_ptr& client() const { return client_; } + ifrt::Memory* memory_space() const { return memory_; } + + int process_index() const; + absl::string_view platform() const; + absl::string_view kind() const; + + absl::string_view Str() const; + absl::string_view Repr() const; + + nanobind::list AddressableByDevices() const; + + static void RegisterPythonType(nanobind::module_& m); + + private: + static int tp_traverse(PyObject* self, visitproc visit, void* arg); + static int tp_clear(PyObject* self); + static PyType_Slot slots_[]; + + nb_class_ptr client_; + ifrt::Memory* memory_; +}; + +} // namespace xla + +#endif // JAXLIB_PY_MEMORY_SPACE_H_ diff --git a/jaxlib/py_program.cc b/jaxlib/py_program.cc new file mode 100644 index 000000000000..ee2d3eef9973 --- /dev/null +++ b/jaxlib/py_program.cc @@ -0,0 +1,296 @@ +/* Copyright 2024 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/py_program.h" + +#include +#include +#include +#include +#include + +#include "absl/container/inlined_vector.h" +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/cord.h" +#include "absl/strings/string_view.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OwningOpRef.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "nanobind/stl/unique_ptr.h" // IWYU pragma: keep +#include "nanobind/stl/vector.h" // IWYU pragma: keep +#include "jaxlib/nb_class_ptr.h" +#include "jaxlib/py_device.h" +#include "jaxlib/py_device_list.h" +#include "jaxlib/python_ref_manager.h" +#include "jaxlib/sharding.h" +#include "xla/hlo/ir/hlo_sharding.h" +#include "xla/pjrt/mlir_to_hlo.h" +#include "xla/pjrt/pjrt_executable.h" +#include "xla/pjrt/status_casters.h" +#include "xla/python/ifrt/array_spec.h" +#include "xla/python/ifrt/compiler.h" +#include "xla/python/ifrt/custom_call_program.h" +#include "xla/python/ifrt/device.h" +#include "xla/python/ifrt/device_list.h" +#include "xla/python/ifrt/hlo/hlo_program.h" +#include "xla/python/ifrt/host_callback.h" +#include "xla/python/ifrt/memory.h" +#include "xla/python/ifrt/plugin_program.h" +#include "xla/python/ifrt/program.h" +#include "xla/python/ifrt/shape.h" +#include "xla/python/ifrt/sharding.h" +#include "xla/python/nb_numpy.h" +#include "xla/python/pjrt_ifrt/xla_compiler.h" +#include "xla/python/pjrt_ifrt/xla_sharding.h" +#include "xla/python/types.h" +#include "xla/python/version.h" +#include "xla/tsl/concurrency/ref_count.h" +#include "xla/tsl/platform/statusor.h" + +namespace xla { + +namespace nb = ::nanobind; + +namespace { + +// Gets `ifrt::DeviceList` from a sequence of JAX devices. 
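+// Accepts either a `jax::PyDeviceList` or a plain Python sequence of
+// `PyDevice` objects; the latter must be non-empty so that a client can be
+// inferred from its first element.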
+absl::StatusOr GetDeviceList(nb::sequence devices) { + ifrt::DeviceListRef ifrt_device_list; + if (devices.type().is(jax::PyDeviceList::type())) { + return nb::cast(devices)->ifrt_device_list(); + } else { + auto py_devices = nb::cast>>(devices); + if (py_devices.empty()) { + return absl::InvalidArgumentError( + "Colocated Python program requires at least one device"); + } + absl::InlinedVector ifrt_devices; + ifrt_devices.reserve(py_devices.size()); + for (const nb_class_ptr& py_device : py_devices) { + ifrt_devices.push_back(py_device->device()); + } + return py_devices.front()->client()->ifrt_client()->MakeDeviceList( + ifrt_devices); + } +} + +// Gets `xla::HloSharding` from a JAX Sharding. +xla::HloSharding GetXlaHloSharding(nb::handle sharding, + int64_t num_dimensions) { + if (sharding.type().is(jax::GSPMDSharding::type())) { + return nb::cast(sharding)->hlo_sharding(); + } else { + return nb::cast( + sharding.attr("_to_xla_hlo_sharding")(num_dimensions)); + } +} + +// Gets `ifrt::DeviceList` from a JAX Sharding. +absl::StatusOr GetIfrtDeviceList(nb::handle sharding) { + if (sharding.type().is(jax::NamedSharding::type())) { + TF_ASSIGN_OR_RETURN( + auto ns_device_list, + nb::cast(sharding)->internal_device_list()); + return ns_device_list->ifrt_device_list(); + } else if (sharding.type().is(jax::SingleDeviceSharding::type())) { + return nb::cast(sharding) + ->internal_device_list() + ->ifrt_device_list(); + } else if (sharding.type().is(jax::PmapSharding::type())) { + return nb::cast(sharding) + ->internal_device_list() + ->ifrt_device_list(); + } else if (sharding.type().is(jax::GSPMDSharding::type())) { + return nb::cast(sharding) + ->internal_device_list() + ->ifrt_device_list(); + } else { + return nb::cast( + sharding.attr("_internal_device_list")) + ->ifrt_device_list(); + } +} + +// Gets `ifrt::MemoryKind` from a JAX Sharding. +ifrt::MemoryKind GetIfrtMemoryKind(nb::handle sharding) { + auto memory_kind = sharding.attr("memory_kind"); + if (memory_kind.is_none()) { + return ifrt::MemoryKind(); + } else { + return ifrt::MemoryKind(nb::cast(memory_kind)); + } +} + +// Makes `ifrt::Sharding` from a JAX Sharding. It requires the number of shape +// dimensions, which may become necessary when building an HLO sharding. +absl::StatusOr GetIfrtSharding(nb::handle sharding, + int64_t num_dimensions) { + auto ifrt_memory_kind = GetIfrtMemoryKind(sharding); + ifrt::ShardingRef ifrt_sharding; + if (sharding.type().is(jax::SingleDeviceSharding::type())) { + TF_ASSIGN_OR_RETURN(auto ifrt_device_list, + nb::cast(sharding) + ->internal_device_list() + ->ifrt_device_list()); + return ifrt::SingleDeviceSharding::Create( + ifrt_device_list->devices().front(), ifrt_memory_kind); + } else { + TF_ASSIGN_OR_RETURN(auto ifrt_device_list, GetIfrtDeviceList(sharding)); + auto xla_hlo_sharding = GetXlaHloSharding(sharding, num_dimensions); + return ifrt::HloSharding::Create(std::move(ifrt_device_list), + ifrt_memory_kind, + std::move(xla_hlo_sharding)); + } +} + +// Gets `ifrt::ArraySpec`s from a sequence of JAX avals (e.g., +// `jax.ShapeDtypeStruct`). 
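+// Each aval is expected to expose `shape`, `dtype`, and `sharding`
+// attributes, which are converted to their IFRT equivalents below.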
+absl::StatusOr> GetIfrtArraySpecs( + nb::sequence avals) { + std::vector ifrt_array_specs; + ifrt_array_specs.reserve(nb::len(avals)); + for (nb::handle aval : avals) { + ifrt::Shape ifrt_shape(nb::cast>(aval.attr("shape"))); + TF_ASSIGN_OR_RETURN( + auto ifrt_dtype, + DtypeToIfRtDType(nb::cast(aval.attr("dtype")))); + TF_ASSIGN_OR_RETURN( + auto ifrt_sharding, + GetIfrtSharding(aval.attr("sharding"), ifrt_shape.dims().size())); + ifrt_array_specs.push_back(ifrt::ArraySpec{ + ifrt_dtype, std::move(ifrt_shape), std::move(ifrt_sharding)}); + } + return ifrt_array_specs; +} + +absl::StatusOr> MakePluginProgramFromString( + std::string data) { + auto plugin_program = std::make_unique(); + plugin_program->data = std::move(data); + return plugin_program; +} + +absl::StatusOr> MakePluginProgramFromBytes( + nb::bytes data) { + auto plugin_program = std::make_unique(); + plugin_program->data = std::string(data.c_str(), data.size()); + return plugin_program; +} + +absl::StatusOr> +MakeColocatedPythonCompileOptions() { + return std::make_unique(); +} + +absl::StatusOr> +MakePluginCompileOptions() { + return std::make_unique(); +} + +absl::StatusOr> MakeHloProgram( + absl::string_view mlir_module) { + auto context = std::make_unique(); + TF_ASSIGN_OR_RETURN(mlir::OwningOpRef module, + ParseMlirModuleString(mlir_module, *context)); + return std::make_unique(std::move(context), + std::move(module)); +} + +absl::StatusOr> MakeHloProgramFromString( + std::string mlir_module) { + return MakeHloProgram(mlir_module); +} + +absl::StatusOr> MakeHloProgramFromBytes( + nb::bytes mlir_module) { + return MakeHloProgram( + absl::string_view(mlir_module.c_str(), mlir_module.size())); +} + +absl::StatusOr> MakeXlaCompileOptions( + CompileOptions options, jax::PyDeviceList& py_executable_devices, + std::vector host_callbacks) { + std::vector> + ifrt_loaded_host_callbacks; + ifrt_loaded_host_callbacks.reserve(host_callbacks.size()); + // Extract `ifrt::LoadedHostCallback`s from host callback capsules that were + // created by `PyClient::MakePythonCallbackUsingHostSendAndRecv()` or + // `PyClient::GetEmitPythonCallbackDescriptor()`. 
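+  // Each capsule holds a raw `ifrt::LoadedHostCallback*`; `tsl::FormRef`
+  // adds a reference so that the compile options keep the callback alive
+  // independently of the Python capsule.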
+ for (auto& host_callback : host_callbacks) { + ifrt_loaded_host_callbacks.push_back(tsl::FormRef( + static_cast(host_callback.data()))); + } + TF_ASSIGN_OR_RETURN(ifrt::DeviceListRef executable_devices, + py_executable_devices.ifrt_device_list()); + return std::make_unique( + std::move(options), std::move(executable_devices), + std::move(ifrt_loaded_host_callbacks)); +} + +constexpr absl::string_view kColocatedPythonProgramType = + "jax_colocated_python_v0.0.1"; + +absl::StatusOr> MakeColocatedPythonProgram( + std::string name, nb::bytes picked_function, nb::sequence devices, + nb::sequence input_avals, nb::sequence output_avals) { + auto ifrt_serialized_program_text = absl::MakeCordFromExternal( + absl::string_view(reinterpret_cast(picked_function.data()), + picked_function.size()), + /*releaser=*/[picked_function](absl::string_view) mutable { + GlobalPyRefManager()->AddGarbage(std::move(picked_function)); + }); + TF_ASSIGN_OR_RETURN(auto ifrt_device_list, GetDeviceList(devices)); + TF_ASSIGN_OR_RETURN(auto ifrt_input_specs, GetIfrtArraySpecs(input_avals)); + TF_ASSIGN_OR_RETURN(auto ifrt_output_specs, GetIfrtArraySpecs(output_avals)); + return std::make_unique( + std::string(kColocatedPythonProgramType), std::move(name), + std::move(ifrt_serialized_program_text), std::move(ifrt_device_list), + std::move(ifrt_input_specs), std::move(ifrt_output_specs)); +} + +} // namespace + +void BuildIfrtProgramsSubmodule(nanobind::module_& m) { + auto sub_module = m.def_submodule("ifrt_programs"); + nb::class_ ifrt_program_base_class(sub_module, "Program"); + nb::class_ ifrt_compile_options_base_class( + sub_module, "CompileOptions"); + sub_module + .def("make_hlo_program", ValueOrThrowWrapper(MakeHloProgramFromString), + nb::arg("mlir_module")) + .def("make_hlo_program", ValueOrThrowWrapper(MakeHloProgramFromBytes), + nb::arg("mlir_module")) + .def("make_colocated_python_program", + ValueOrThrowWrapper(MakeColocatedPythonProgram), nb::arg("name"), + nb::arg("pickled_function"), nb::arg("devices"), + nb::arg("input_avals"), nb::arg("output_avals")) + .def("make_plugin_program", + ValueOrThrowWrapper(MakePluginProgramFromString), nb::arg("data")) + .def("make_plugin_program", + ValueOrThrowWrapper(MakePluginProgramFromBytes), nb::arg("data")) + .def("make_xla_compile_options", + ValueOrThrowWrapper(MakeXlaCompileOptions), nb::arg("options"), + nb::arg("executable_devices"), nb::arg("host_callbacks")) + .def("make_colocated_python_compile_options", + ValueOrThrowWrapper(MakeColocatedPythonCompileOptions)) + .def("make_plugin_compile_options", + ValueOrThrowWrapper(MakePluginCompileOptions)); +} + +} // namespace xla diff --git a/jaxlib/py_program.h b/jaxlib/py_program.h new file mode 100644 index 000000000000..7772d740c41e --- /dev/null +++ b/jaxlib/py_program.h @@ -0,0 +1,27 @@ +/* Copyright 2024 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef JAXLIB_PY_PROGRAM_H_ +#define JAXLIB_PY_PROGRAM_H_ + +#include "nanobind/nanobind.h" + +namespace xla { + +void BuildIfrtProgramsSubmodule(nanobind::module_& m); + +} // namespace xla + +#endif // JAXLIB_PY_PROGRAM_H_ diff --git a/jaxlib/py_socket_transfer.cc b/jaxlib/py_socket_transfer.cc new file mode 100644 index 000000000000..69321aa788d5 --- /dev/null +++ b/jaxlib/py_socket_transfer.cc @@ -0,0 +1,409 @@ +/* Copyright 2025 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "jaxlib/py_socket_transfer.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "llvm/Support/Casting.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/array.h" // IWYU pragma: keep +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "nanobind/stl/string_view.h" // IWYU pragma: keep +#include "nanobind/stl/vector.h" // IWYU pragma: keep +#include "jaxlib/nb_class_ptr.h" +#include "jaxlib/py_array.h" +#include "jaxlib/py_client.h" +#include "jaxlib/to_ifrt_sharding.h" +#include "jaxlib/traceback.h" +#include "xla/pjrt/pjrt_client.h" +#include "xla/pjrt/status_casters.h" +#include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/array_spec.h" +#include "xla/python/ifrt/dtype.h" +#include "xla/python/ifrt/memory.h" +#include "xla/python/ifrt/shape.h" +#include "xla/python/ifrt/sharding.h" +#include "xla/python/nb_numpy.h" +#include "xla/python/pjrt_ifrt/pjrt_array.h" +#include "xla/python/pjrt_ifrt/pjrt_device.h" +#include "xla/python/pjrt_ifrt/pjrt_dtype.h" +#include "xla/python/pjrt_ifrt/pjrt_memory.h" +#include "xla/python/transfer/event_loop.h" +#include "xla/python/transfer/socket-server.h" +#include "xla/python/transfer/socket_bulk_transport.h" +#include "xla/python/transfer/streaming.h" +#include "xla/python/transfer/streaming_ifrt.h" +#include "xla/python/transfer/transfer_socket.pb.h" +#include "xla/python/types.h" +#include "xla/python/version.h" +#include "xla/tsl/concurrency/ref_count.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/util.h" +#include "tsl/platform/casts.h" + +namespace aux { + +namespace nb = nanobind; + +absl::StatusOr MemorySpaceFromSharding( + const xla::ifrt::Sharding& sharding) { + if (sharding.devices()->devices().size() != 1) { + return xla::InvalidArgument( + "Can only convert SingleDeviceSharding to MemorySpace not %s", + sharding.DebugString()); + } + auto* device = sharding.devices()->devices()[0]; + if (sharding.memory_kind().memory_kind().has_value()) { + // Find `PjRtMemorySpace` that is associated with the sharding's device + // and matches the sharding's memory_kind. 
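+    // A device can expose more than one memory space (for example device HBM
+    // and pinned host memory), so the kind has to be matched explicitly
+    // rather than falling back to the device's default memory space.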
+ xla::ifrt::Memory* memory = nullptr; + for (xla::ifrt::Memory* ms : device->Memories()) { + if (ms->Kind() == sharding.memory_kind()) { + memory = ms; + break; + } + } + if (memory == nullptr) { + return xla::InvalidArgument( + "Invalid memory kind: %s; available memory kinds: %s", + *sharding.memory_kind().memory_kind(), + absl::StrJoin(sharding.devices()->devices().front()->Memories(), ", ", + [](std::string* out, xla::ifrt::Memory* ms) { + absl::StrAppend(out, *ms->Kind().memory_kind()); + })); + } + return tensorflow::down_cast(memory)->pjrt_memory(); + } else { + if (!device->IsAddressable()) { + return xla::InvalidArgument( + "Cannot copy array to non-addressable device %s", + device->DebugString()); + } + return tensorflow::down_cast(device) + ->pjrt_device() + ->default_memory_space(); + } +} + +absl::StatusOr> CreatePullEntry( + const std::vector& arrs, + std::shared_ptr state, size_t xfer_size, + bool use_raw_buffers) { + if (use_raw_buffers) { + std::vector refs; + for (auto& arr : arrs) { + auto* pjrt_arr = llvm::dyn_cast_or_null(arr.get()); + if (pjrt_arr == nullptr) { + return absl::InvalidArgumentError( + "Cannot remote transfer non-pjrt arrays."); + } + for (auto& pjrt_buf : pjrt_arr->pjrt_buffers()) { + TF_ASSIGN_OR_RETURN(size_t buf_size, + pjrt_buf->GetOnDeviceSizeInBytes()); + TF_ASSIGN_OR_RETURN( + auto raw_buffer, + xla::PjRtRawBuffer::CreateRawAliasOfBuffer(pjrt_buf.get())); + refs.push_back( + {pjrt_buf->GetReadyFuture(), std::move(raw_buffer), buf_size}); + } + } + return tsl::MakeRef(std::move(refs), state, xfer_size); + } + + std::vector refs; + for (auto& arr : arrs) { + auto* pjrt_arr = llvm::dyn_cast_or_null(arr.get()); + if (pjrt_arr == nullptr) { + return absl::InvalidArgumentError( + "Cannot remote transfer non-pjrt arrays."); + } + for (auto& pjrt_buf : pjrt_arr->pjrt_buffers()) { + TF_ASSIGN_OR_RETURN(size_t buf_size, pjrt_buf->GetOnDeviceSizeInBytes()); + refs.push_back({pjrt_buf, buf_size}); + } + } + return tsl::MakeRef(std::move(refs), state, xfer_size); +} + +class PyTransferServerConnection { + public: + explicit PyTransferServerConnection( + tsl::RCReference conn) + : conn_(std::move(conn)) {} + + void Pull(uint64_t uuid, std::vector buffer_ids, + std::vector> pull_dests) { + for (size_t i = 0; i < buffer_ids.size(); ++i) { + conn_->Pull(uuid, buffer_ids[i], std::move(pull_dests[i])); + } + } + + SocketServer::Connection& conn() { return *conn_; } + + private: + tsl::RCReference conn_; +}; + +class PyTransferServer { + public: + PyTransferServer() = default; + absl::Status Start(xla::ifrt::Client* client, size_t max_num_parallel_copies, + size_t xfer_size, const SocketAddress& addr, + const std::vector& transport_addresses, + bool supports_pinned_allocator, bool use_raw_buffers) { + use_raw_buffers_ = use_raw_buffers; + std::shared_ptr factory; + if (transport_addresses.empty()) { + factory = BulkTransportFactory::CreateLocal(); + } else { + auto tmp = xla::ValueOrThrow( + AllocateAlignedMemory(xfer_size * max_num_parallel_copies)); + SlabAllocator uallocator(xla::ValueOrThrow(MapPjrtMemory( + client, tmp->data(), tmp->size(), tmp)), + xfer_size); + std::optional pinned_allocator; + if (supports_pinned_allocator) { + auto tmp = xla::ValueOrThrow( + AllocateNetworkPinnedMemory(xfer_size * max_num_parallel_copies)); + pinned_allocator.emplace(xla::ValueOrThrow(MapPjrtMemory( + client, tmp->data(), tmp->size(), tmp)), + xfer_size); + } + factory = xla::ValueOrThrow(CreateSocketBulkTransportFactory( + transport_addresses, pinned_allocator, 
uallocator)); + } + + server_ = std::make_shared(); + + TF_ASSIGN_OR_RETURN(auto mem, + AllocateAndMapPjrtMemory( + client, max_num_parallel_copies * xfer_size * 2)); + premapped_copier_ = std::make_shared( + mem, max_num_parallel_copies, xfer_size); + xfer_size_ = xfer_size; + return server_->Start(addr, factory); + } + std::string address() { return server_->addr().ToString(); } + + PyTransferServerConnection Connect(const std::string& saddr) { + return PyTransferServerConnection( + server_->Connect(xla::ValueOrThrow(SocketAddress::Parse(saddr)))); + } + + void AwaitPull(uint64_t uuid, const std::vector& arrs) { + server_->AwaitPull( + uuid, xla::ValueOrThrow(CreatePullEntry(arrs, premapped_copier_, + xfer_size_, use_raw_buffers_))); + } + + size_t xfer_size() { return xfer_size_; } + + std::shared_ptr premapped_copier() { + return premapped_copier_; + } + + private: + std::shared_ptr server_; + std::shared_ptr premapped_copier_; + size_t xfer_size_; + bool use_raw_buffers_ = false; +}; + +absl::StatusOr ArraySpecFromShapeDtypeStruct( + nb::handle aval) { + TF_ASSIGN_OR_RETURN(xla::ifrt::DType dtype, + xla::DtypeToIfRtDType( + nb::borrow(aval.attr("dtype").ptr()))); + auto shape_dims = nb::cast>(aval.attr("shape")); + auto shape = xla::ifrt::Shape( + xla::ifrt::Shape::Dimensions(shape_dims.begin(), shape_dims.end())); + TF_ASSIGN_OR_RETURN(auto sharding, + xla::GetIfrtHloSharding(aval.attr("sharding"), shape)); + return xla::ifrt::ArraySpec{dtype, std::move(shape), std::move(sharding)}; +} + +struct BufferSource { + xla::ifrt::ArrayRef arr; + xla::PjRtBuffer* buffer; +}; + +struct CopyDests { + std::vector shape_specs; + xla::PjRtMemorySpace* memory_space; +}; + +void RegisterTransferServerTypes(nanobind::module_& m) { + nb::class_(m, "TransferConnection") +#if JAX_IFRT_VERSION_NUMBER > 9 + .def( + "_testonly_inject_failure", + [](PyTransferServerConnection& self) { self.conn().InjectFailure(); }) +#endif + .def("_pull_flat", [](PyTransferServerConnection& self, uint64_t uuid, + xla::nb_class_ptr py_client, + std::vector py_avals) { + auto* ifrt_client = llvm::dyn_cast_or_null( + py_client->ifrt_client()); + if (ifrt_client == nullptr) { + xla::ThrowIfError(absl::InvalidArgumentError( + "_pull_flat only supported on pjrt-ifrt clients.")); + } + + std::vector avals; + std::vector shardings; + shardings.reserve(py_avals.size()); + avals.reserve(py_avals.size()); + for (const auto& py_aval : py_avals) { + avals.push_back( + xla::ValueOrThrow(ArraySpecFromShapeDtypeStruct(py_aval))); + shardings.push_back(py_aval.attr("sharding")); + } + + std::vector dests; + std::vector> fetch_idxs; + absl::flat_hash_map mapping; + std::vector>> buffer_list; + + for (auto& aval : avals) { + std::vector> buf_list; + auto prim_type = + xla::ValueOrThrow(xla::ifrt::ToPrimitiveType(aval.dtype)); + auto shards = xla::ValueOrThrow(aval.sharding->Disassemble( + aval.shape, + xla::ifrt::SingleDeviceShardSemantics::kAddressableShards)); + buf_list.reserve(shards.size()); + for (auto& shard : shards) { + auto* mem_space = + xla::ValueOrThrow(MemorySpaceFromSharding(*shard.second)); + int dest_idx = + mapping.emplace(mem_space, static_cast(dests.size())) + .first->second; + if (dest_idx == dests.size()) { + dests.emplace_back(); + dests.back().memory_space = mem_space; + } + fetch_idxs.push_back( + {dest_idx, + static_cast(dests[dest_idx].shape_specs.size())}); + buf_list.push_back(fetch_idxs.back()); + dests[dest_idx].shape_specs.push_back( + {prim_type, xla::DimensionVector(shard.first.dims().begin(), + 
shard.first.dims().end())}); + } + buffer_list.push_back(std::move(buf_list)); + } + + std::vector< + std::shared_ptr> + atms; + atms.reserve(dests.size()); + + for (auto& dest : dests) { + atms.push_back(xla::ValueOrThrow( + py_client->pjrt_client()->CreateBuffersForAsyncHostToDevice( + dest.shape_specs, std::nullopt, dest.memory_space))); + } + + std::vector> pull_dests; + std::vector buffer_ids; + pull_dests.reserve(fetch_idxs.size()); + buffer_ids.reserve(fetch_idxs.size()); + for (auto& fetch_idx : fetch_idxs) { + auto& atm = atms[fetch_idx.first]; + pull_dests.push_back(MakeDmaDestination( + atm, fetch_idx.second, atm->buffer_size(fetch_idx.second))); + buffer_ids.push_back(static_cast(buffer_ids.size())); + } + + self.Pull(uuid, buffer_ids, std::move(pull_dests)); + + std::vector out; + auto traceback = xla::Traceback::Get(); + for (size_t i = 0; i < buffer_list.size(); ++i) { + xla::ifrt::PjRtArray::PjRtBuffers buffers; + buffers.reserve(buffer_list[i].size()); + for (auto& v : buffer_list[i]) { + buffers.push_back(atms[v.first]->RetrieveBuffer(v.second)); + } + auto arr = xla::ValueOrThrow(xla::ifrt::PjRtArray::Create( + ifrt_client, avals[i].dtype, avals[i].shape, avals[i].sharding, + std::move(buffers), avals[i].layout)); + out.push_back(xla::PyArray::MakeFromIfrtArrayAndSharding( + py_client, traceback, std::move(arr), shardings[i], false, true, + /*skip_checks=*/false)); + } + + return out; + }); + + nb::class_(m, "TransferServer") + .def("address", [](PyTransferServer& self) { return self.address(); }) + .def("_await_pull_flat", + [](PyTransferServer& self, uint64_t uuid, + std::vector inputs) { + std::vector arrs; + arrs.reserve(inputs.size()); + for (const xla::PyArray& input : inputs) { + arrs.push_back(tsl::FormRef(input.ifrt_array())); + } + self.AwaitPull(uuid, arrs); + }) + .def("connect", [](PyTransferServer& self, const std::string& address) { + return self.Connect(address); + }); + + m.def( + "start_transfer_server", + [](xla::nb_class_ptr py_client, std::string address, + std::vector transport_addresses_str, + size_t max_num_parallel_copies, size_t transfer_size, + bool supports_pinned_allocator, + bool use_raw_buffers) -> PyTransferServer { + PyTransferServer result; + std::vector transport_addresses; + transport_addresses.reserve(transport_addresses_str.size()); + for (const std::string& addr : transport_addresses_str) { + transport_addresses.push_back( + xla::ValueOrThrow(SocketAddress::Parse(addr))); + } + xla::ThrowIfError(result.Start( + py_client->ifrt_client(), max_num_parallel_copies, transfer_size, + xla::ValueOrThrow(SocketAddress::Parse(address)), + transport_addresses, supports_pinned_allocator, use_raw_buffers)); + return result; + }, + nb::arg("client"), nb::arg("address") = SocketAddress().ToString(), + nb::arg("transport_addresses") = std::vector(), + nb::arg("max_num_parallel_copies") = 8, + nb::arg("transfer_size") = 256 * 1024 * 1024, + // Dual pinning not confirmed to be supported. + nb::arg("supports_pinned_allocator") = false, + // Technically unsafe (because a future donation won't wait for the + // transfer to complete). + nb::arg("use_raw_buffers") = false); +} + +} // namespace aux diff --git a/jaxlib/py_socket_transfer.h b/jaxlib/py_socket_transfer.h new file mode 100644 index 000000000000..1b0236b56889 --- /dev/null +++ b/jaxlib/py_socket_transfer.h @@ -0,0 +1,26 @@ +/* Copyright 2025 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef JAXLIB_TRANSFER_PY_SOCKET_TRANSFER_H_ +#define JAXLIB_TRANSFER_PY_SOCKET_TRANSFER_H_ + +#include "nanobind/nanobind.h" + +namespace aux { + +void RegisterTransferServerTypes(nanobind::module_& m); + +} // namespace aux + +#endif // JAXLIB_TRANSFER_PY_SOCKET_TRANSFER_H_ diff --git a/jaxlib/py_values.cc b/jaxlib/py_values.cc new file mode 100644 index 000000000000..987a51eb67cf --- /dev/null +++ b/jaxlib/py_values.cc @@ -0,0 +1,1100 @@ +/* Copyright 2020 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/py_values.h" + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/inlined_vector.h" +#include "absl/functional/any_invocable.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/cord.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/complex.h" // IWYU pragma: keep +#include "nanobind/stl/string_view.h" // IWYU pragma: keep +#include "jaxlib/py_array.h" +#include "jaxlib/python_ref_manager.h" +#include "jaxlib/sharding.h" +#include "jaxlib/to_ifrt_sharding.h" +#include "xla/primitive_util.h" +#include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/array_spec.h" +#include "xla/python/ifrt/device.h" +#include "xla/python/ifrt/device_list.h" +#include "xla/python/ifrt/dtype.h" +#include "xla/python/ifrt/memory.h" +#include "xla/python/ifrt/shape.h" +#include "xla/python/ifrt/sharding.h" +#include "xla/python/ifrt/user_context.h" +#include "xla/python/nb_numpy.h" +#include "xla/python/pjrt_ifrt/pjrt_dtype.h" +#include "xla/python/safe_static_init.h" +#include "xla/python/types.h" +#include "xla/shape.h" +#include "xla/tsl/concurrency/ref_count.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/tsl/python/lib/core/numpy.h" +#include "xla/types.h" +#include "xla/util.h" +#include "xla/xla_data.pb.h" +#include "tsl/platform/ml_dtypes.h" +#include "tsl/profiler/lib/traceme.h" + +namespace nb = nanobind; + +namespace xla { + +namespace { + +// Gets the thread-local instance. 
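+// The counters are thread-local, so `DevicePutInfo::GetInfo()` reports only
+// the device_put calls issued from the calling thread.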
+static DevicePutInfo& GetDevicePutInfo() { + thread_local DevicePutInfo device_put_info; + return device_put_info; +} + +// Prepared data for creating a single shard of an array. Holds a single-device +// IFRT array or a host buffer. +struct Shard { + explicit Shard(ifrt::ArrayRef ifrt_array, bool weak_type) + : ifrt_array_or_host_buffer(std::move(ifrt_array)), + weak_type(weak_type), + // host_buffer_semantics is not meaningful when + // `ifrt_array_or_host_buffer` is an IFRT Array. + host_buffer_semantics( + ifrt::Client::HostBufferSemantics::kImmutableOnlyDuringCall) {} + + Shard(ifrt::Client::HostBuffer ifrt_host_buffer, bool weak_type, + ifrt::Client::HostBufferSemantics host_buffer_semantics) + : ifrt_array_or_host_buffer(std::move(ifrt_host_buffer)), + weak_type(weak_type), + host_buffer_semantics(host_buffer_semantics) {} + + Shard(const Shard&) = delete; + Shard& operator=(const Shard&) = delete; + Shard(Shard&&) noexcept = default; + Shard& operator=(Shard&&) noexcept = default; + + bool is_ifrt_array() const { + return std::holds_alternative(ifrt_array_or_host_buffer); + } + ifrt::DType ifrt_dtype() const; + const ifrt::Shape& ifrt_shape() const; + + // Points to the on-device array or on-host buffer. + std::variant + ifrt_array_or_host_buffer; + bool weak_type; + ifrt::Client::HostBufferSemantics host_buffer_semantics; +}; + +// A function that creates a `Shard` from a Python object when called. +using ShardFn = absl::AnyInvocable() &&>; + +absl::StatusOr> StringDTypeArrayToCords( + PyArrayObject* py_array_obj) { + if (PyArray_SIZE(py_array_obj) == 0) { + return absl::InvalidArgumentError("empty numpy array"); + } + + std::vector cords; + cords.reserve(PyArray_SIZE(py_array_obj)); + + auto iter = + nb::steal(PyArray_IterNew(reinterpret_cast(py_array_obj))); + while (PyArray_ITER_NOTDONE(iter.ptr())) { + auto* iter_data = PyArray_ITER_DATA(iter.ptr()); + auto* item = PyArray_GETITEM(py_array_obj, static_cast(iter_data)); + if (!item) { + return absl::InternalError( + "Failed to get elements out of the ndarray iter."); + } + Py_ssize_t len; + auto str = PyUnicode_AsUTF8AndSize(item, &len); + cords.push_back(absl::Cord(absl::string_view(str, len))); + PyArray_ITER_NEXT(iter.ptr()); + } + return cords; +} + +// Handler that creates a `Shard` from a Python object. +using DevicePutHandler = std::function( + nb::handle obj, ifrt::Client* client, ifrt::Device* to_device, + ifrt::MemoryKind to_memory_kind, const DevicePutOptions& options)>; + +// Shared logic that makes an IFRT array (either single-device or multi-device) +// from a fully-replicated `shard` that is created from a host buffer (not from +// an existing IFRT array). `shard` will be consumed. +// +// `user_context` will be used for a new IFRT array created. +// +// Expected to be called without holding GIL. +absl::StatusOr> +MakeIfrtArrayFromFullyReplicatedShard( + ifrt::Client* ifrt_client, ifrt::ShardingRef ifrt_sharding, Shard& shard, + tsl::RCReference user_context) { + auto host_buffer_shard = std::get( + std::move(shard.ifrt_array_or_host_buffer)); + return ifrt_client->MakeArrayFromHostBuffer( + host_buffer_shard.data, host_buffer_shard.dtype, + std::move(host_buffer_shard.shape), + std::move(host_buffer_shard.byte_strides), std::move(ifrt_sharding), + shard.host_buffer_semantics, std::move(host_buffer_shard.on_done), + std::move(user_context)); +} + +// Shared logic that makes a single-device IFRT array from a `shard`. `shard` +// will be consumed. 
+// +// `user_context` will be used for a new IFRT array created from the host +// buffer, and be not applied when reusing an existing IFRT array. +// +// Expected to be called without holding GIL. +absl::StatusOr MakeSingleDeviceIfrtArrayFromShard( + xla::ifrt::Client* ifrt_client, xla::ifrt::Device* ifrt_device, + xla::ifrt::MemoryKind ifrt_memory_kind, Shard& shard, + tsl::RCReference user_context) { + if (auto* ifrt_array = + std::get_if(&shard.ifrt_array_or_host_buffer)) { + return std::move(*ifrt_array); + } + ifrt::ShardingRef ifrt_sharding = + ifrt::SingleDeviceSharding::Create(ifrt_device, ifrt_memory_kind); + return MakeIfrtArrayFromFullyReplicatedShard( + ifrt_client, std::move(ifrt_sharding), shard, std::move(user_context)); +} + +// Makes an IFRT Array from `shards` using a batched array creation API (fast +// path). `shards` will be consumed. +// +// Expected to be called without holding GIL. +absl::StatusOr MakeIfrtArrayFromShardsInBatch( + ifrt::Client* ifrt_client, ifrt::DType ifrt_dtype, ifrt::Shape ifrt_shape, + ifrt::ShardingRef ifrt_sharding, absl::Span shards, + tsl::RCReference user_context) { + absl::InlinedVector< + std::pair, ifrt::Client::HostBuffer>, 1> + host_buffers; + host_buffers.reserve(shards.size()); + ifrt::Client::HostBufferSemantics safe_host_semantics = + ifrt::Client::HostBufferSemantics::kImmutableZeroCopy; + // TODO(hyeontaek): Deduplicate shards here or early on to create a unique + // HostBuffer for each set of replicated shards. + for (int64_t i = 0; i < shards.size(); ++i) { + host_buffers.push_back({{i}, + std::get(std::move( + shards[i].ifrt_array_or_host_buffer))}); + // The minimum host buffer semantics is a safe semantics that can be used + // for all shards when they are created in a single batch. + safe_host_semantics = + std::min(safe_host_semantics, shards[i].host_buffer_semantics); + } + + std::vector specs; + specs.push_back(ifrt::Client::MakeArraysFromHostBufferShardsSpec{ + std::move(host_buffers), + ifrt::ArraySpec{/*dtype=*/ifrt_dtype, + /*shape=*/std::move(ifrt_shape), + /*sharding=*/std::move(ifrt_sharding), + /*layout=*/nullptr}}); + TF_ASSIGN_OR_RETURN( + auto arrays, + ifrt_client->MakeArraysFromHostBufferShards( + absl::MakeSpan(specs), safe_host_semantics, std::move(user_context))); + return std::move(arrays.front()); +} + +// Makes an IFRT Array from `shards` using an array assembly API (slow path). +// `shards` will be consumed. +// +// Expected to be called without holding GIL. 
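+//
+// This path is taken whenever at least one shard already exists as an IFRT
+// array (see `should_batch` in `DevicePutWithSharding`); host-buffer shards
+// are first turned into single-device arrays and then assembled into one
+// array.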
+absl::StatusOr MakeIfrtArrayFromShardsWithAssembly( + ifrt::Client* ifrt_client, ifrt::DType ifrt_dtype, ifrt::Shape ifrt_shape, + ifrt::ShardingRef ifrt_sharding, + ifrt::DeviceList* ifrt_addressable_device_list, + ifrt::MemoryKind ifrt_memory_kind, absl::Span shards, + tsl::RCReference user_context) { + absl::Span ifrt_addressable_devices = + ifrt_addressable_device_list->devices(); + std::vector ifrt_array_shards; + ifrt_array_shards.reserve(shards.size()); + for (int64_t i = 0; i < shards.size(); ++i) { + TF_ASSIGN_OR_RETURN(ifrt::ArrayRef ifrt_array_shard, + MakeSingleDeviceIfrtArrayFromShard( + ifrt_client, ifrt_addressable_devices[i], + ifrt_memory_kind, shards[i], user_context)); + ifrt_array_shards.push_back(std::move(ifrt_array_shard)); + } + return ifrt_client->AssembleArrayFromSingleDeviceArrays( + ifrt_dtype, std::move(ifrt_shape), std::move(ifrt_sharding), + absl::MakeSpan(ifrt_array_shards), ifrt::ArrayCopySemantics::kReuseInput, + ifrt::SingleDeviceShardSemantics::kAddressableShards); +} + +template +absl::StatusOr HandlePythonScalar(nb::handle obj, ifrt::Client* client, + ifrt::Device* to_device, + ifrt::MemoryKind to_memory_kind, + const DevicePutOptions& options) { + T value; + try { + value = nb::cast(obj); + } catch (const std::exception& e) { + return InvalidArgument( + "Unable to convert Python scalar to %s. This most likely means the " + "value (%s) overflows the range of the type.", + PrimitiveType_Name(primitive_util::NativeToPrimitiveType()), + nb::cast(nb::repr(obj))); + } + + std::variant data; + Shape shape; + PrimitiveType type; + if (std::is_same() || !options.squash_64bit_types) { + data.template emplace<0>(value); + type = primitive_util::NativeToPrimitiveType(); + } else { + // TODO(phawkins): we should check for overflow here, e.g., because of bugs + // like https://github.com/google/jax/issues/2006 + data.template emplace<1>(static_cast(value)); + type = primitive_util::NativeToPrimitiveType(); + } + TF_ASSIGN_OR_RETURN(ifrt::DType ifrt_dtype, ifrt::ToDType(type)); + + return [data, ifrt_dtype]() -> absl::StatusOr { + const void* ptr = std::visit( + [](const auto& v) { return static_cast(&v); }, data); + ifrt::Client::HostBuffer ifrt_host_buffer{ + ptr, ifrt_dtype, ifrt::Shape({}), + /*byte_strides=*/std::nullopt, + /*on_done_with_host_buffer=*/nullptr}; + return Shard(std::move(ifrt_host_buffer), /*weak_type=*/true, + ifrt::Client::HostBufferSemantics::kImmutableOnlyDuringCall); + }; +} + +absl::StatusOr HandlePythonInt(nb::handle obj, ifrt::Client* client, + ifrt::Device* to_device, + ifrt::MemoryKind to_memory_kind, + const DevicePutOptions& options) { + PrimitiveType type; + std::variant data; + + if (options.squash_64bit_types) { + try { + data.emplace<1>(nb::cast(obj)); + } catch (const std::exception& e) { + return InvalidArgument( + "Unable to convert Python scalar to %s. This most likely means the " + "value (%s) overflows the range of the type.", + PrimitiveType_Name(primitive_util::NativeToPrimitiveType()), + nb::cast(nb::repr(obj))); + } + type = S32; + } else { + try { + data.emplace<0>(nb::cast(obj)); + } catch (const std::exception& e) { + return InvalidArgument( + "Unable to convert Python scalar to %s. 
This most likely means the " + "value (%s) overflows the range of the type.", + PrimitiveType_Name(primitive_util::NativeToPrimitiveType()), + nb::cast(nb::repr(obj))); + } + type = S64; + } + TF_ASSIGN_OR_RETURN(ifrt::DType ifrt_dtype, ifrt::ToDType(type)); + return [data, ifrt_dtype]() -> absl::StatusOr { + const void* ptr = std::visit( + [](const auto& v) { return static_cast(&v); }, data); + ifrt::Client::HostBuffer ifrt_host_buffer{ + ptr, ifrt_dtype, ifrt::Shape({}), + /*byte_strides=*/std::nullopt, + /*on_done_with_host_buffer=*/nullptr}; + return Shard(std::move(ifrt_host_buffer), /*weak_type=*/true, + ifrt::Client::HostBufferSemantics::kImmutableOnlyDuringCall); + }; +} + +template +absl::StatusOr HandleNumpyScalar(nb::handle h, ifrt::Client* client, + ifrt::Device* to_device, + ifrt::MemoryKind to_memory_kind, + const DevicePutOptions& options) { + std::variant data; + PrimitiveType type; + // For extension types, ScalarAsCtype returns a pointer to the data. + if (std::is_same()) { + PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>()); + type = S2; + } else if (std::is_same()) { + PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>()); + type = S4; + } else if (std::is_same()) { + PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>()); + type = U2; + } else if (std::is_same()) { + PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>()); + type = U4; + } else if (std::is_same()) { + PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>()); + type = BF16; + } else if (std::is_same()) { + PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>()); + type = F4E2M1FN; + } else if (std::is_same()) { + PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>()); + type = F8E3M4; + } else if (std::is_same()) { + PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>()); + type = F8E4M3; + } else if (std::is_same()) { + PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>()); + type = F8E4M3FN; + } else if (std::is_same()) { + PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>()); + type = F8E4M3B11FNUZ; + } else if (std::is_same()) { + PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>()); + type = F8E5M2; + } else if (std::is_same()) { + PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>()); + type = F8E4M3FNUZ; + } else if (std::is_same()) { + PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>()); + type = F8E5M2FNUZ; + } else if (std::is_same()) { + PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>()); + type = F8E8M0FNU; + } else if (std::is_same() || !options.squash_64bit_types) { + PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<0>()); + type = primitive_util::NativeToPrimitiveType(); + } else { + T value; + PyArray_ScalarAsCtype(h.ptr(), &value); + data.template emplace<1>(static_cast(value)); + type = primitive_util::NativeToPrimitiveType(); + } + std::shared_ptr py_buffer_ref; + if (data.index() == 2) { + py_buffer_ref = + GlobalPyRefManager()->ManageReference(nb::cast(h)); + } + TF_ASSIGN_OR_RETURN(ifrt::DType ifrt_dtype, ifrt::ToDType(type)); + return [data, py_buffer_ref = std::move(py_buffer_ref), + ifrt_dtype]() mutable -> absl::StatusOr { + const void* ptr = std::visit( + [](const auto& v) -> const void* { + if constexpr (std::is_same_v, void*>) { + return v; + } else { + return static_cast(&v); + } + }, + data); + ifrt::Client::HostBuffer ifrt_host_buffer{ + ptr, ifrt_dtype, ifrt::Shape({}), + /*byte_strides=*/std::nullopt, + /*on_done_with_host_buffer=*/ + [py_buffer_ref = + std::move(py_buffer_ref)]() 
{ /* keeps py_buffer_ref alive */ }}; + return Shard(std::move(ifrt_host_buffer), /*weak_type=*/false, + ifrt::Client::HostBufferSemantics::kImmutableOnlyDuringCall); + }; +} + +absl::StatusOr HandleStringNumpyArray( + nb::handle h, ifrt::Client* client, ifrt::Device* to_device, + ifrt::MemoryKind to_memory_kind, const DevicePutOptions& options) { + xla::nb_numpy_ndarray array = nb::cast(h); + auto py_array_obj = reinterpret_cast(array.ptr()); + TF_ASSIGN_OR_RETURN(auto cords, StringDTypeArrayToCords(py_array_obj)); + + // Assemble all the parameters of MakeArrayFromHostBuffer + const void* data = cords.data(); + + // Make an explicit copy of the shape elements so we won't run into complex + // endianness and precision issues that might arise if we reinterpret-casted + // from npy_intp, that can be just 32 bits-wide in some environments + // such as macos_arm64 to const int64_t* that must be 64 bits-wide. + ifrt::Shape::Dimensions dims; + dims.reserve(array.ndim()); + for (int i = 0; i < array.ndim(); ++i) { + dims.push_back(array.shape(i)); + } + ifrt::Shape shape(std::move(dims)); + + auto on_done_with_host_buffer = [cords = std::move(cords)] {}; + + return [data, shape = std::move(shape), + on_done_with_host_buffer = std::move( + on_done_with_host_buffer)]() mutable -> absl::StatusOr { + ifrt::Client::HostBuffer ifrt_host_buffer{ + data, ifrt::DType(ifrt::DType::kString), std::move(shape), + /*byte_strides=*/std::nullopt, std::move(on_done_with_host_buffer)}; + return Shard( + std::move(ifrt_host_buffer), /*weak_type=*/false, + ifrt::Client::HostBufferSemantics::kImmutableUntilTransferCompletes); + }; +} + +absl::StatusOr HandleNumpyArray(nb::handle h, ifrt::Client* client, + ifrt::Device* to_device, + ifrt::MemoryKind to_memory_kind, + const DevicePutOptions& options) { + xla::nb_numpy_ndarray array = nb::cast(h); + + // String numpy arrays require substantially different processing. 
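+  // They are copied into `absl::Cord`s and transferred with
+  // `kImmutableUntilTransferCompletes` semantics rather than the zero-copy
+  // paths used for fixed-width dtypes below.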
+ if (array.dtype().char_() == (int)'T' || array.dtype().kind() == 'T') { + return HandleStringNumpyArray(h, client, to_device, to_memory_kind, + options); + } + + TF_ASSIGN_OR_RETURN(PrimitiveType type, DtypeToPrimitiveType(array.dtype())); + + PrimitiveType squashed_type; + if (options.squash_64bit_types) { + squashed_type = Squash64BitTypes(type); + if (squashed_type != type) { + TF_ASSIGN_OR_RETURN(xla::nb_dtype squashed_dtype, + PrimitiveTypeToNbDtype(squashed_type)); + array = nb::steal(PyArray_CastToType( + reinterpret_cast(array.ptr()), + reinterpret_cast(squashed_dtype.release().ptr()), + /*fortran=*/0)); + } + } else { + squashed_type = type; + } + + absl::InlinedVector dims(array.ndim()); + ifrt::Client::HostBuffer::ByteStrides byte_strides(array.ndim()); + for (int i = 0; i < array.ndim(); ++i) { + dims[i] = array.shape(i); + byte_strides[i] = array.strides(i); + } + const void* data = array.data(); + std::shared_ptr py_buffer_ref = + GlobalPyRefManager()->ManageReference(std::move(array)); + TF_ASSIGN_OR_RETURN(ifrt::DType ifrt_dtype, ifrt::ToDType(squashed_type)); + return [data, ifrt_dtype, dims = std::move(dims), + byte_strides = std::move(byte_strides), + py_buffer_ref = std::move(py_buffer_ref), + allow_zero_copy = + options.allow_zero_copy]() mutable -> absl::StatusOr { + ifrt::Client::HostBufferSemantics host_buffer_semantics = + ifrt::Client::HostBufferSemantics::kImmutableOnlyDuringCall; + std::function on_done_with_host_buffer; + if (allow_zero_copy) { + on_done_with_host_buffer = + [py_buffer_ref{ + std::move(py_buffer_ref)}]() { /* keeps py_buffer_ref alive */ }; + host_buffer_semantics = + ifrt::Client::HostBufferSemantics::kImmutableZeroCopy; + } + + ifrt::Client::HostBuffer ifrt_host_buffer{ + data, ifrt_dtype, ifrt::Shape(dims), std::move(byte_strides), + std::move(on_done_with_host_buffer)}; + return Shard(std::move(ifrt_host_buffer), /*weak_type=*/false, + host_buffer_semantics); + }; +} + +absl::StatusOr HandlePyArray(nb::handle obj, ifrt::Client* client, + ifrt::Device* to_device, + ifrt::MemoryKind to_memory_kind, + const DevicePutOptions& options) { + auto py_array = nb::borrow(obj); + + // We only allow single device case for PyArray in device put. + if (py_array.num_shards() != 1) { + return InvalidArgument( + "device_put expects an array with exactly one shard, got an array with " + "with %d shards.", + py_array.num_shards()); + } + + ifrt::Array* ifrt_array = py_array.ifrt_array(); + if (ifrt_array == nullptr) { + return InvalidArgument("Array has been deleted."); + } + + // Fallback to python for non-matching clients or pmap sharding. 
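+  // `_value` materializes the array's contents as a NumPy array on the host,
+  // so this fallback incurs a device-to-host copy before the data is uploaded
+  // again via `HandleNumpyArray`.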
+ if (py_array.sharding().type().ptr() == jax::PmapSharding::type().ptr() || + ifrt_array->sharding().devices()->devices().front()->client() != + to_device->client()) { + return HandleNumpyArray(obj.attr("_value"), client, to_device, + to_memory_kind, options); + } + + if (ifrt_array->sharding().devices()->devices().front() == to_device && + options.allow_zero_copy && + (!to_memory_kind.memory_kind().has_value() || + !ifrt_array->sharding().memory_kind().memory_kind().has_value() || + ifrt_array->sharding().memory_kind() == to_memory_kind)) { + Shard result(tsl::FormRef(ifrt_array), py_array.weak_type()); + return [result = std::move(result)]() mutable { return std::move(result); }; + } else { + return [ifrt_array = tsl::FormRef(ifrt_array), to_device, to_memory_kind, + weak_type = py_array.weak_type(), + allow_zero_copy = + options.allow_zero_copy]() mutable -> absl::StatusOr { + auto* ifrt_client = ifrt_array->client(); + TF_ASSIGN_OR_RETURN( + auto copied_ifrt_arrays, + ifrt_client->CopyArrays( + absl::MakeSpan(&ifrt_array, 1), + ifrt_client->MakeDeviceList({to_device}), to_memory_kind, + allow_zero_copy ? ifrt::ArrayCopySemantics::kReuseInput + : ifrt::ArrayCopySemantics::kAlwaysCopy)); + return Shard(std::move(copied_ifrt_arrays.front()), weak_type); + }; + } +} + +ifrt::DType Shard::ifrt_dtype() const { + if (is_ifrt_array()) { + return std::get(ifrt_array_or_host_buffer)->dtype(); + } else { + return std::get(ifrt_array_or_host_buffer).dtype; + } +} + +const ifrt::Shape& Shard::ifrt_shape() const { + if (is_ifrt_array()) { + return std::get(ifrt_array_or_host_buffer)->shape(); + } else { + return std::get(ifrt_array_or_host_buffer).shape; + } +} + +// Creates a `ShardFn` that copies `arg` to `to_device` and `to_memory_kind`. +// +// Requires GIL. The returned `ShardFn` should be called without GIL held. +absl::StatusOr MakeShardFn(nb::handle arg, ifrt::Client* client, + ifrt::Device* to_device, + ifrt::MemoryKind to_memory_kind, + const DevicePutOptions& options) { + using PyObjectDeviceHandlerMap = + absl::flat_hash_map; + + auto init_fn = []() { + std::unique_ptr p = + std::make_unique(); + + const NumpyScalarTypes& dtypes = GetNumpyScalarTypes(); + // Python scalar types. + static_assert(sizeof(bool) == 1, "Conversion code assumes bool is 1 byte"); + (*p)[reinterpret_cast(&PyBool_Type)] = + HandlePythonScalar; + (*p)[reinterpret_cast(&PyLong_Type)] = HandlePythonInt; + (*p)[reinterpret_cast(&PyFloat_Type)] = + HandlePythonScalar; + (*p)[reinterpret_cast(&PyComplex_Type)] = + HandlePythonScalar; + + (*p)[reinterpret_cast(&PyArray_Type)] = HandleNumpyArray; + + // Numpy scalar types. For some of them, we share the handler with + // Python types (np_int64, np_float64, np_complex128). 
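+    // Handlers are keyed on the exact type object; lookups in `MakeShardFn`
+    // also walk the argument type's `__mro__`, so subclasses of these types
+    // dispatch to the same handlers.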
+ (*p)[dtypes.np_bool.ptr()] = HandleNumpyScalar; + (*p)[dtypes.np_int4.ptr()] = HandleNumpyScalar; + if (dtypes.np_int2.has_value()) { + (*p)[dtypes.np_int2->ptr()] = HandleNumpyScalar; + } + (*p)[dtypes.np_int8.ptr()] = HandleNumpyScalar; + (*p)[dtypes.np_int16.ptr()] = HandleNumpyScalar; + (*p)[dtypes.np_int32.ptr()] = HandleNumpyScalar; + (*p)[dtypes.np_int64.ptr()] = HandleNumpyScalar; + if (dtypes.np_uint2.has_value()) { + (*p)[dtypes.np_uint2->ptr()] = HandleNumpyScalar; + } + (*p)[dtypes.np_uint4.ptr()] = HandleNumpyScalar; + (*p)[dtypes.np_uint8.ptr()] = HandleNumpyScalar; + (*p)[dtypes.np_uint16.ptr()] = HandleNumpyScalar; + (*p)[dtypes.np_uint32.ptr()] = HandleNumpyScalar; + (*p)[dtypes.np_uint64.ptr()] = HandleNumpyScalar; + if (dtypes.np_float4_e2m1fn.has_value()) { + (*p)[dtypes.np_float4_e2m1fn->ptr()] = + HandleNumpyScalar; + } + if (dtypes.np_float8_e3m4.has_value()) { + (*p)[dtypes.np_float8_e3m4->ptr()] = HandleNumpyScalar; + } + if (dtypes.np_float8_e4m3.has_value()) { + (*p)[dtypes.np_float8_e4m3->ptr()] = HandleNumpyScalar; + } + (*p)[dtypes.np_float8_e4m3fn.ptr()] = HandleNumpyScalar; + (*p)[dtypes.np_float8_e4m3b11fnuz.ptr()] = + HandleNumpyScalar; + (*p)[dtypes.np_float8_e5m2.ptr()] = HandleNumpyScalar; + (*p)[dtypes.np_float8_e4m3fnuz.ptr()] = + HandleNumpyScalar; + (*p)[dtypes.np_float8_e5m2fnuz.ptr()] = + HandleNumpyScalar; + if (dtypes.np_float8_e8m0fnu.has_value()) { + (*p)[dtypes.np_float8_e8m0fnu->ptr()] = + HandleNumpyScalar; + } + (*p)[dtypes.np_bfloat16.ptr()] = HandleNumpyScalar; + (*p)[dtypes.np_float16.ptr()] = HandleNumpyScalar; + (*p)[dtypes.np_float32.ptr()] = HandleNumpyScalar; + (*p)[dtypes.np_float64.ptr()] = HandleNumpyScalar; + (*p)[dtypes.np_complex64.ptr()] = HandleNumpyScalar; + (*p)[dtypes.np_complex128.ptr()] = HandleNumpyScalar; + static_assert(sizeof(long long) == sizeof(int64_t), // NOLINT + "long long must be the same size as int64_t"); + (*p)[dtypes.np_longlong.ptr()] = HandleNumpyScalar; + static_assert(sizeof(int) == sizeof(int32_t), + "int must be the same size as int32_t"); + (*p)[dtypes.np_intc.ptr()] = HandleNumpyScalar; + return p; + }; + const PyObjectDeviceHandlerMap& handlers = + xla::SafeStaticInit(init_fn); + + if (arg.type().ptr() == PyArray::type().ptr()) { + auto array = nb::borrow(arg); + return HandlePyArray(arg, client, to_device, to_memory_kind, options); + } + + auto res = handlers.find(arg.type().ptr()); + if (res == handlers.end()) { + for (auto base_class : arg.type().attr("__mro__")) { + res = handlers.find(base_class.ptr()); + if (res != handlers.end()) { + return res->second(arg, client, to_device, to_memory_kind, options); + } + } + return InvalidArgument( + "%s", absl::StrCat( + "Not supported: The C++ jax jit execution path, only accepts " + "DeviceArray, Numpy arrays scalars of supported types " + "(see implementation), or Python scalars. 
Got type ", + nb::cast(nb::str(arg.type())))); + } + return res->second(arg, client, to_device, to_memory_kind, options); +} + +} // namespace + +bool IsFloat0(xla::nb_numpy_ndarray arg) { + const nb::object& float0_dtype = SafeStaticInit([] { + nb::module_ dtypes_module = nb::module_::import_("jax.dtypes"); + nb::object float0_dtype = dtypes_module.attr("float0"); + return std::make_unique(float0_dtype); + }); + return float0_dtype.is(arg.attr("dtype")); +} + +std::string PyArgSignature::DebugString() const { + std::string result = ""; + if (weak_type) { + absl::StrAppend(&result, "weak_"); + } + absl::StrAppend(&result, xla::PrimitiveType_Name(dtype)); + absl::StrAppend(&result, "[", absl::StrJoin(shape, ","), "]"); + return result; +} + +using ToPyArgSignatureHandler = + std::function(nb::handle, bool)>; + +absl::StatusOr PyArgSignatureOfValue(nb::handle arg, + bool jax_enable_x64) { + const absl::flat_hash_map& handlers = + SafeStaticInit< + absl::flat_hash_map>([] { + auto p = std::make_unique< + absl::flat_hash_map>(); + + const NumpyScalarTypes& dtypes = GetNumpyScalarTypes(); + + // The 4 Python native types. + ToPyArgSignatureHandler bool_handler = + [](nb::handle, bool) -> absl::StatusOr { + return PyArgSignature(PrimitiveType::PRED, {}, true); + }; + ToPyArgSignatureHandler int_handler = + [](nb::handle h, + bool jax_enable_x64) -> absl::StatusOr { + // TODO(phawkins): we should consider checking for integer overflow. + if (jax_enable_x64) { + return PyArgSignature(PrimitiveType::S64, {}, true); + } else { + return PyArgSignature(PrimitiveType::S32, {}, true); + } + }; + ToPyArgSignatureHandler float_handler = + [&dtypes](nb::handle h, + bool jax_enable_x64) -> absl::StatusOr { + // Only Python native types has a True weak_type. + bool weak_type = !nb::isinstance(h, dtypes.np_float64); + if (jax_enable_x64) { + return PyArgSignature(PrimitiveType::F64, {}, weak_type); + } else { + return PyArgSignature(PrimitiveType::F32, {}, weak_type); + } + }; + ToPyArgSignatureHandler complex_handler = + [&dtypes](nb::handle h, + bool jax_enable_x64) -> absl::StatusOr { + // Note that this branch is also taken for np.complex128: + // isinstance(np.complex128(3), complex) returns True + // isinstance(np.complex64(3), complex) returns False + bool weak_type = !nb::isinstance(h, dtypes.np_complex128); + if (jax_enable_x64) { + return PyArgSignature(PrimitiveType::C128, {}, weak_type); + } else { + return PyArgSignature(PrimitiveType::C64, {}, weak_type); + } + }; + + (*p)[reinterpret_cast(&PyBool_Type)] = bool_handler; + (*p)[reinterpret_cast(&PyLong_Type)] = int_handler; + (*p)[reinterpret_cast(&PyFloat_Type)] = float_handler; + (*p)[reinterpret_cast(&PyComplex_Type)] = complex_handler; + + ToPyArgSignatureHandler numpy_handler = + [](nb::handle h, + bool jax_enable_x64) -> absl::StatusOr { + xla::nb_numpy_ndarray numpy_array = + nb::cast(h); + TF_ASSIGN_OR_RETURN(PrimitiveType dtype, + DtypeToPrimitiveType(numpy_array.dtype())); + if (!jax_enable_x64) { + dtype = Squash64BitTypes(dtype); + } + // We use reinterpret_cast<> to defend against environments where + // ssize_t may not be precisely the same type as int64_t, even if it + // is the same size (long vs long long). 
+ static_assert(sizeof(int64_t) == sizeof(ssize_t), + "Code assumes ssize_t is the same as int64_t"); + return PyArgSignature( + dtype, + absl::MakeConstSpan( + reinterpret_cast(numpy_array.shape()), + numpy_array.ndim()), + /*weak_type=*/false); + }; + (*p)[reinterpret_cast(&PyArray_Type)] = numpy_handler; + + ToPyArgSignatureHandler np_uint64_handler = + [](nb::handle h, + bool jax_enable_x64) -> absl::StatusOr { + if (jax_enable_x64) { + return PyArgSignature(PrimitiveType::U64, {}, /*weak_type=*/false); + } else { + return PyArgSignature(PrimitiveType::U32, {}, /*weak_type=*/false); + } + }; + ToPyArgSignatureHandler np_int_handler = + [](nb::handle h, + bool jax_enable_x64) -> absl::StatusOr { + if (jax_enable_x64) { + return PyArgSignature(PrimitiveType::S64, {}, /*weak_type=*/false); + } else { + return PyArgSignature(PrimitiveType::S32, {}, /*weak_type=*/false); + } + }; + ToPyArgSignatureHandler numpy_array_handler = + [](nb::handle h, + bool jax_enable_x64) -> absl::StatusOr { + // This block deals with all numpy scalar types, except for int64_dt, + // float64_dt and complex128_dt which are taken care of in previous if + // blocks. + TF_ASSIGN_OR_RETURN(auto dtype, + DtypeToPrimitiveType(h.attr("dtype"))); + return PyArgSignature(dtype, {}, /*weak_type=*/false); + }; + + // This block deals with all numpy scalar types, except for int64_dt, + // float64_dt and complex128_dt which are taken care of in previous if + // blocks. + (*p)[dtypes.np_bool.ptr()] = numpy_array_handler; + (*p)[dtypes.np_int4.ptr()] = numpy_array_handler; + (*p)[dtypes.np_int8.ptr()] = numpy_array_handler; + (*p)[dtypes.np_int16.ptr()] = numpy_array_handler; + (*p)[dtypes.np_int32.ptr()] = numpy_array_handler; + (*p)[dtypes.np_int64.ptr()] = np_int_handler; + (*p)[dtypes.np_uint4.ptr()] = numpy_array_handler; + (*p)[dtypes.np_uint8.ptr()] = numpy_array_handler; + (*p)[dtypes.np_uint16.ptr()] = numpy_array_handler; + (*p)[dtypes.np_uint32.ptr()] = numpy_array_handler; + (*p)[dtypes.np_uint64.ptr()] = np_uint64_handler; + // TODO(upwind): Explore if we can remove std::optional for these types + // in xla/python/types.h and xla/python/types.cc + if (dtypes.np_float4_e2m1fn.has_value()) { + (*p)[dtypes.np_float4_e2m1fn->ptr()] = numpy_array_handler; + } + if (dtypes.np_float8_e3m4.has_value()) { + (*p)[dtypes.np_float8_e3m4->ptr()] = numpy_array_handler; + } + if (dtypes.np_float8_e4m3.has_value()) { + (*p)[dtypes.np_float8_e4m3->ptr()] = numpy_array_handler; + } + (*p)[dtypes.np_float8_e4m3fn.ptr()] = numpy_array_handler; + (*p)[dtypes.np_float8_e4m3b11fnuz.ptr()] = numpy_array_handler; + (*p)[dtypes.np_float8_e4m3fnuz.ptr()] = numpy_array_handler; + (*p)[dtypes.np_float8_e5m2.ptr()] = numpy_array_handler; + (*p)[dtypes.np_float8_e5m2fnuz.ptr()] = numpy_array_handler; + if (dtypes.np_float8_e8m0fnu.has_value()) { + (*p)[dtypes.np_float8_e8m0fnu->ptr()] = numpy_array_handler; + } + (*p)[dtypes.np_float16.ptr()] = numpy_array_handler; + (*p)[dtypes.np_bfloat16.ptr()] = numpy_array_handler; + (*p)[dtypes.np_float32.ptr()] = numpy_array_handler; + (*p)[dtypes.np_float64.ptr()] = float_handler; + (*p)[dtypes.np_complex64.ptr()] = numpy_array_handler; + (*p)[dtypes.np_complex128.ptr()] = complex_handler; + (*p)[dtypes.np_longlong.ptr()] = np_int_handler; + (*p)[dtypes.np_intc.ptr()] = numpy_array_handler; + + return p; + }); + + if (arg.type().ptr() == PyArray::type().ptr()) { + auto array = nb::borrow(arg); + ifrt::Array* ifrt_array = array.ifrt_array(); + if (ifrt_array == nullptr) { + return 
xla::InvalidArgument("Array has been deleted."); + } + TF_ASSIGN_OR_RETURN(auto primitive_type, + ifrt::ToPrimitiveType(ifrt_array->dtype())); + return PyArgSignature(primitive_type, array.shape(), array.weak_type()); + } + + auto res = handlers.find(arg.type().ptr()); + if (res == handlers.end()) { + // We attempt to look at the MRO classes + for (auto base_class : arg.type().attr("__mro__")) { + res = handlers.find(base_class.ptr()); + if (res != handlers.end()) { + return res->second(arg, jax_enable_x64); + } + } + return InvalidArgument( + "%s", + absl::StrCat("Not supported: The C++ ToPyArgSignature only accepts " + "Buffer/DeviceArray, Numpy " + "arrays scalars of supported types " + "(see implementation), or Python scalars. Got type ", + nb::cast(nb::str(arg.type())))); + } + return res->second(arg, jax_enable_x64); +} + +absl::StatusOr DevicePutWithDevice( + nanobind::handle addressable_shard, ifrt::Client* ifrt_client, + ifrt::Device* ifrt_device, ifrt::MemoryKind ifrt_memory_kind, + const DevicePutOptions& options) { + tsl::profiler::TraceMe traceme("DevicePut"); + ++GetDevicePutInfo().device_put_with_device; + + if (!ifrt_device->IsAddressable()) { + return InvalidArgument("Cannot copy array to non-addressable device: %s", + ifrt_device->DebugString()); + } + + TF_ASSIGN_OR_RETURN(ShardFn shard_fn, + MakeShardFn(addressable_shard, ifrt_client, ifrt_device, + ifrt_memory_kind, options)); + + tsl::RCReference ifrt_user_context = + ifrt_client->CreateUserContext(); + + nb::gil_scoped_release gil_release; + + TF_ASSIGN_OR_RETURN(Shard shard, std::move(shard_fn)()); + TF_ASSIGN_OR_RETURN(ifrt::ArrayRef ifrt_array, + MakeSingleDeviceIfrtArrayFromShard( + ifrt_client, ifrt_device, ifrt_memory_kind, shard, + std::move(ifrt_user_context))); + return DevicePutResult(std::move(ifrt_array), shard.weak_type); +} + +absl::StatusOr DevicePutWithSharding( + absl::Span addressable_shards, + ifrt::Client* ifrt_client, const nb_dtype& dtype, + absl::Span shape, nanobind::handle sharding, + const DevicePutOptions& options) { + tsl::profiler::TraceMe traceme("DevicePutWithSharding"); + ++GetDevicePutInfo().device_put_with_sharding; + + TF_ASSIGN_OR_RETURN(ifrt::DeviceListRef ifrt_device_list, + GetIfrtDeviceList(sharding)); + ifrt::DeviceList* ifrt_addressable_device_list = + ifrt_device_list->AddressableDeviceList(); + absl::Span ifrt_addressable_devices = + ifrt_addressable_device_list->devices(); + // Pmap sharding requires special handling because it needs a shard shape + // upfront. + const bool is_pmap_sharding = sharding.type().is(jax::PmapSharding::type()); + + if (addressable_shards.size() != ifrt_addressable_devices.size()) { + // Try to generate a friendly error message if the user attempted to copy to + // a non-addressable device. + if (addressable_shards.size() > ifrt_addressable_devices.size()) { + for (ifrt::Device* device : ifrt_device_list->devices()) { + if (!device->IsAddressable()) { + return InvalidArgument( + "Cannot copy array to non-addressable device: %s", + device->DebugString()); + } + } + } + // Otherwise, generate a generic error message. + return InvalidArgument( + "Number of addressable shard data does not match the number " + "of addressable devices in the sharding: %d vs. 
%d", + addressable_shards.size(), ifrt_addressable_devices.size()); + } + if (is_pmap_sharding && addressable_shards.empty()) { + return InvalidArgument( + "Pmap sharding requires at least one addressable shard."); + } + + TF_ASSIGN_OR_RETURN(ifrt::DType ifrt_dtype, DtypeToIfRtDType(dtype)); + ifrt::Shape ifrt_shape(shape); + ifrt::MemoryKind ifrt_memory_kind = GetMemoryKind(sharding); + + std::vector shard_fns; + shard_fns.reserve(addressable_shards.size()); + for (int i = 0; i < addressable_shards.size(); ++i) { + TF_ASSIGN_OR_RETURN( + ShardFn shard, + MakeShardFn(addressable_shards[i], ifrt_client, + ifrt_addressable_devices[i], ifrt_memory_kind, options)); + shard_fns.push_back(std::move(shard)); + } + + ifrt::ShardingRef ifrt_sharding; + bool is_fully_replicated; + if (is_pmap_sharding) { + CHECK(!shard_fns.empty()); + // IFRT Sharding will be determined once we discover the shard shape. + is_fully_replicated = false; + } else { + TF_ASSIGN_OR_RETURN(ifrt_sharding, + GetIfrtHloSharding(sharding, ifrt_shape)); + // Fully-replicated shardings enable additional optimizations of using a + // single host buffer. + // TODO(hyeontaek): Enable a similar optimization for partially replicated + // cases to reduce the number of host buffers to obtain. + is_fully_replicated = ifrt_sharding->IsFullyReplicated(); + } + tsl::RCReference ifrt_user_context = + ifrt_client->CreateUserContext(); + + nb::gil_scoped_release gil_release; + + // Whether to build an IFRT array from host buffers as a single batch. We do + // not batch any shard is already an IFRT array. + bool should_batch = true; + + std::vector shards; + shards.reserve(shard_fns.size()); + for (int64_t i = 0; i < shard_fns.size(); ++i) { + TF_ASSIGN_OR_RETURN(Shard shard, std::move(shard_fns[i])()); + if (shard.is_ifrt_array()) { + // If any shard is an IFRT array, we should assemble shards. + should_batch = false; + } + shards.push_back(std::move(shard)); + if (should_batch && is_fully_replicated) { + // We need only one host buffer for a fully-replicated array. + break; + } + } + // While we have finished calling `shard_fns`, we cannot destroy them until we + // make a call to IFRT array creation. Destroying `shard_fns` would release + // host buffers prematurely and can cause the array creation API to see + // garbage data. + + // TODO(emilyaf): Remove the following and just use ifrt_dtype when tokens are + // supported. 
+ if (!shards.empty()) { + ifrt_dtype = shards.front().ifrt_dtype(); + } + if (is_pmap_sharding) { + ifrt_sharding = ifrt::ConcreteEvenSharding::Create( + ifrt::DeviceListRef(tsl::FormRef(ifrt_addressable_device_list)), + ifrt_memory_kind, ifrt_shape, + /*shard_shape=*/shards.front().ifrt_shape(), + /*is_fully_replicated=*/false); + } + + ifrt::ArrayRef ifrt_array; + if (should_batch) { + if (is_fully_replicated && shards.size() == 1) { + ++GetDevicePutInfo().device_put_fully_replicated; + TF_ASSIGN_OR_RETURN( + ifrt_array, MakeIfrtArrayFromFullyReplicatedShard( + ifrt_client, std::move(ifrt_sharding), shards.front(), + std::move(ifrt_user_context))); + } else { + ++GetDevicePutInfo().device_put_batched; + TF_ASSIGN_OR_RETURN(ifrt_array, + MakeIfrtArrayFromShardsInBatch( + ifrt_client, ifrt_dtype, std::move(ifrt_shape), + std::move(ifrt_sharding), absl::MakeSpan(shards), + std::move(ifrt_user_context))); + } + } else { + ++GetDevicePutInfo().device_put_assembled; + TF_ASSIGN_OR_RETURN( + ifrt_array, MakeIfrtArrayFromShardsWithAssembly( + ifrt_client, ifrt_dtype, std::move(ifrt_shape), + std::move(ifrt_sharding), ifrt_addressable_device_list, + ifrt_memory_kind, absl::MakeSpan(shards), + std::move(ifrt_user_context))); + } + const bool weak_type = shards.empty() ? false : shards.front().weak_type; + return DevicePutResult(std::move(ifrt_array), weak_type); +} + +std::unordered_map DevicePutInfo::GetInfo() { + const DevicePutInfo& info = GetDevicePutInfo(); + return std::unordered_map({ + {"device_put_with_device", info.device_put_with_device}, + {"device_put_with_sharding", info.device_put_with_sharding}, + {"device_put_fully_replicated", info.device_put_fully_replicated}, + {"device_put_batched", info.device_put_batched}, + {"device_put_assembled", info.device_put_assembled}, + }); +} + +} // namespace xla diff --git a/jaxlib/py_values.h b/jaxlib/py_values.h new file mode 100644 index 000000000000..d74cf9668a99 --- /dev/null +++ b/jaxlib/py_values.h @@ -0,0 +1,161 @@ +/* Copyright 2020 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Helpers for converting Python values into buffers. + +#ifndef JAXLIB_PY_VALUES_H_ +#define JAXLIB_PY_VALUES_H_ + +#include +#include +#include +#include +#include + +#include "absl/container/inlined_vector.h" +#include "absl/status/statusor.h" +#include "absl/types/span.h" +#include "nanobind/nanobind.h" +#include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/client.h" +#include "xla/python/ifrt/device.h" +#include "xla/python/ifrt/memory.h" +#include "xla/python/nb_numpy.h" +#include "xla/xla_data.pb.h" + +namespace xla { + +struct DevicePutResult { + DevicePutResult(ifrt::ArrayRef ifrt_array, bool weak_type) + : ifrt_array(std::move(ifrt_array)), weak_type(weak_type) {} + + // Disallow copy. `DevicePutResult` is expected to be consumed by one user. 
+ DevicePutResult(const DevicePutResult&) = delete; + DevicePutResult& operator=(const DevicePutResult&) = delete; + DevicePutResult(DevicePutResult&&) noexcept = default; + DevicePutResult& operator=(DevicePutResult&&) noexcept = default; + + // Points to the on-device array. + ifrt::ArrayRef ifrt_array; + bool weak_type; +}; + +// Options for `DevicePut`. +struct DevicePutOptions { + bool squash_64bit_types = false; + bool allow_zero_copy = true; +}; + +// Copies a buffer-like object to be on device. This version is designed for +// creating a single-device array. +// +// If `addressable_shard` is not convertible to a `PjRtBuffer` from C++, an +// error will be returned; float0s are not supported yet. +// +// If the value is known to be a PyBuffer object, py_buffer can be passed as an +// optimization to avoid a Python->C++ cast. +// +// Requires GIL. This function performs Python work inline, and runs expensive +// C++ work with GIL temporarily released. +// +// May throw exceptions from nanobind in addition to failing via an error +// absl::Status. (We could catch these if needed, but there seems little point.) +absl::StatusOr DevicePutWithDevice( + nanobind::handle addressable_shard, ifrt::Client* ifrt_client, + ifrt::Device* ifrt_device, ifrt::MemoryKind ifrt_memory_kind, + const DevicePutOptions& options); + +// Copies a buffer-like object to be on device. This version is optimized for +// creating a multi-device array. +// +// `addressable_shards` is a list of buffer-like objects to be copied to +// addressable devices specified in `sharding`. +// +// `shape` and `sharding` determine the shape and sharding of the returned IFRT +// Array. +// +// The size of `addressable_shards` must match the number of addressable devices +// in `sharding`. For a Pmap sharding, there must be at least one addressable +// device. +// +// Requires GIL. This function performs Python work inline, and runs expensive +// C++ work with GIL temporarily released. +// +// See the above `DevicePutWithDevice` for other details. +absl::StatusOr DevicePutWithSharding( + absl::Span addressable_shards, + ifrt::Client* ifrt_client, const nb_dtype& dtype, + absl::Span shape, nanobind::handle sharding, + const DevicePutOptions& options); + +// Returns `true` if `arg` is a JAX float0 array. +bool IsFloat0(xla::nb_numpy_ndarray arg); + +// Describes the abstract shape and dtype of an argument. +struct PyArgSignature { + PyArgSignature(PrimitiveType dtype, absl::Span shape, + bool weak_type) + : dtype(dtype), shape(shape.begin(), shape.end()), weak_type(weak_type) {} + // This is the XLA dtype of the object. + const PrimitiveType dtype; + const absl::InlinedVector shape; + // JAX arguments can be of weak type, if and only if they are Python scalars + // or `DeviceArray` values such that `aval.weak_type` is true. + const bool weak_type; + bool operator==(const PyArgSignature& other) const { + return std::tie(dtype, weak_type, shape) == + std::tie(other.dtype, other.weak_type, other.shape); + } + bool operator!=(const PyArgSignature& other) const { + return !(*this == other); + } + std::string DebugString() const; +}; + +// Returns the PyArgSignature associated with an argument. Returns an error if +// the argument is not supported. 
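+//
+// For example, with `jax_enable_x64` disabled a Python float yields a weakly
+// typed `F32[]` signature, while a NumPy float64 scalar yields a strongly
+// typed `F32[]` signature after 64-bit squashing.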
+absl::StatusOr PyArgSignatureOfValue(nanobind::handle arg, + bool jax_enable_x64); + +template +H AbslHashValue(H h, const xla::PyArgSignature& s) { + h = H::combine(std::move(h), s.dtype); + h = H::combine_contiguous(std::move(h), s.shape.data(), s.shape.size()); + return h; +} + +// Tracks the number of DevicePut calls and subcases. For testing. +struct DevicePutInfo { + // DevicePutWithDevice call count. + int device_put_with_device = 0; + + // DevicePutWithSharding call count. + int device_put_with_sharding = 0; + + // DevicePutWithSharding with a fully replicated sharding. + int device_put_fully_replicated = 0; + // DevicePutWithSharding that made a batched array creation call. + int device_put_batched = 0; + // DevicePutWithSharding that made per-shard creation calls followed by an + // assembly call. + int device_put_assembled = 0; + + // Returns a map of the counters for the current thread. + static std::unordered_map GetInfo(); +}; + +} // namespace xla + +#endif // JAXLIB_PY_VALUES_H_ diff --git a/jaxlib/pyinit_stub.c b/jaxlib/pyinit_stub.c new file mode 100644 index 000000000000..7fc873d9ae0e --- /dev/null +++ b/jaxlib/pyinit_stub.c @@ -0,0 +1,28 @@ +/* Copyright 2025 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Stub that reexports Wrapped_PyInit_module as PyInit_module. + +extern void* Wrapped_PyInit_@MODULE_NAME@(); + +#if defined(WIN32) || defined(_WIN32) +#define EXPORT_SYMBOL __declspec(dllexport) +#else +#define EXPORT_SYMBOL __attribute__ ((visibility("default"))) +#endif + +EXPORT_SYMBOL void* PyInit_@MODULE_NAME@() { + return Wrapped_PyInit_@MODULE_NAME@(); +} diff --git a/jaxlib/python_ref_manager.cc b/jaxlib/python_ref_manager.cc new file mode 100644 index 000000000000..6cc2714b75ad --- /dev/null +++ b/jaxlib/python_ref_manager.cc @@ -0,0 +1,108 @@ +/* Copyright 2019 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "jaxlib/python_ref_manager.h" + +#include + +#include +#include +#include +#include + +#include "absl/container/inlined_vector.h" +#include "absl/synchronization/mutex.h" +#include "absl/types/span.h" +#include "nanobind/nanobind.h" +#include "tsl/profiler/lib/traceme.h" + +namespace xla { + +namespace nb = nanobind; + +PythonRefManager::ManagedPyObjects::ManagedPyObjects( + PythonRefManager* manager, absl::Span objects) + : manager_(manager) { + objects_.reserve(objects.size()); + for (nb::object& object : objects) { + objects_.push_back(std::move(object)); + } +} + +PythonRefManager::ManagedPyObjects::~ManagedPyObjects() { + if (manager_ && !objects_.empty()) { + manager_->AddGarbage(absl::MakeSpan(objects_)); + } +} + +std::shared_ptr +PythonRefManager::ManageReference(nb::object object) { + return std::make_shared(this, + absl::Span(&object, 1)); +} + +std::shared_ptr +PythonRefManager::ManageReferences(absl::Span objects) { + return std::make_shared(this, objects); +} + +void PythonRefManager::AddGarbage(nb::object garbage) { + absl::MutexLock lock(&mu_); + // We want to collect arbitrary python garbage (e.g., buffers) aggressively. + garbage_count_.fetch_add(100, std::memory_order_relaxed); + python_garbage_.push_back(std::move(garbage)); +} + +void PythonRefManager::AddGarbage(absl::Span garbage) { + absl::MutexLock lock(&mu_); + // We want to collect arbitrary python garbage (e.g., buffers) aggressively. + garbage_count_.fetch_add(100, std::memory_order_relaxed); + for (nb::object& o : garbage) { + python_garbage_.push_back(std::move(o)); + } +} + +void PythonRefManager::AddGarbage( + absl::Span const> garbage) { + absl::MutexLock lock(&mu_); + // We don't care about collecting stack frame objects often. We grab a lot of + // tracebacks and the code objects are most likely live for the entire + // process. + garbage_count_.fetch_add(1, std::memory_order_relaxed); + for (const auto& o : garbage) { + python_garbage_.push_back(nb::steal(reinterpret_cast(o.first))); + } +} + +void PythonRefManager::CollectGarbage() { + // TODO(phawkins): we should CHECK(PyGILState_Check()); + tsl::profiler::TraceMe traceme("PythonRefManager::CollectGarbage"); + std::deque garbage; + { + absl::MutexLock lock(&mu_); + garbage_count_ = 0; + garbage.swap(python_garbage_); + } + // We defer deleting garbage until the lock is released. It's possible that + // deleting garbage will lead to more Python garbage being added; if we held + // the lock we would deadlock because absl::Mutex is not reentrant. +} + +PythonRefManager* GlobalPyRefManager() { + static PythonRefManager* static_ref_manager = new PythonRefManager(); + return static_ref_manager; +} + +} // namespace xla diff --git a/jaxlib/python_ref_manager.h b/jaxlib/python_ref_manager.h new file mode 100644 index 000000000000..37eae1cae84d --- /dev/null +++ b/jaxlib/python_ref_manager.h @@ -0,0 +1,108 @@ +/* Copyright 2019 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef JAXLIB_PYTHON_REF_MANAGER_H_ +#define JAXLIB_PYTHON_REF_MANAGER_H_ + +#include + +#include +#include +#include +#include + +#include "absl/base/thread_annotations.h" +#include "absl/container/inlined_vector.h" +#include "absl/synchronization/mutex.h" +#include "absl/types/span.h" +#include "nanobind/nanobind.h" + +namespace xla { + +// Class that manages destruction of Python objects. +// +// We must not destroy Python objects without holding the GIL. However, we +// frequently want to hold references to Python objects for the duration of +// an asynchronous transfer on a Stream, and release our reference when the +// transfer completes. +// +// This class holds references to Python objects outside a GIL scope, that can +// be collected later when the GIL is held by calling CollectGarbage(). +class PythonRefManager { + public: + PythonRefManager() = default; + + // Holds references to a set of nanobind::objects, adding the references to + // the PythonRefManager on destruction. + class ManagedPyObjects { + public: + ManagedPyObjects() = default; + ManagedPyObjects(PythonRefManager* manager, + absl::Span objects); + + ~ManagedPyObjects(); + + ManagedPyObjects(const ManagedPyObjects& other) = delete; + ManagedPyObjects(ManagedPyObjects&& other) = default; + ManagedPyObjects& operator=(const ManagedPyObjects& other) = delete; + ManagedPyObjects& operator=(ManagedPyObjects&& other) noexcept = default; + + private: + PythonRefManager* manager_ = nullptr; + absl::InlinedVector objects_; + }; + + // Creates a managed std::shared_ptr to an object. When the shared_ptr is + // destroyed, the reference to 'object' will be added to python_garbage_, + // and collected next time CollectGarbage() is called. + std::shared_ptr ManageReference(nanobind::object object); + std::shared_ptr ManageReferences( + absl::Span objects); + + // Adds garbage objects to the manager. + void AddGarbage(nanobind::object garbage); + void AddGarbage(absl::Span garbage); + void AddGarbage(absl::Span const> garbage); + + // Releases the contents of python_garbage_. Requires that the GIL is held. + // The client calls this method during API entry points where the GIL is held + // to free any garbage that has accumulated. + void CollectGarbage(); + + // Cheaper version of CollectGarbage() with relaxed consistency and frequency. + // The purpose of this function is to amortize lock acquisition costs over + // a larger number of API calls. + void MaybeCollectGarbage() { + if (garbage_count_.load(std::memory_order_relaxed) >= 100) { + CollectGarbage(); + } + } + + private: + absl::Mutex mu_; + std::deque python_garbage_ ABSL_GUARDED_BY(mu_); + + // Writes to garbage_count_ are protected by mu_, reads are not protected. + std::atomic garbage_count_{0}; +}; + +// A global PythonRefManager. Unless `CollectGarbage()` is called before +// shutdown, this container will hold on to Python objects and thus cause a +// leak. This behavior is similar to `tensorflow::ClearDecRefCache()`. +PythonRefManager* GlobalPyRefManager(); + +} // namespace xla + +#endif // JAXLIB_PYTHON_REF_MANAGER_H_ diff --git a/jaxlib/pytree.cc b/jaxlib/pytree.cc new file mode 100644 index 000000000000..272ac0c82859 --- /dev/null +++ b/jaxlib/pytree.cc @@ -0,0 +1,1788 @@ +/* Copyright 2019 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Caution: this code uses exceptions. The exception use is local to the +// binding code and the idiomatic way to emit Python exceptions. + +#include "jaxlib/pytree.h" + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/inlined_vector.h" +#include "absl/hash/hash.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/optional.h" // IWYU pragma: keep +#include "nanobind/stl/pair.h" // IWYU pragma: keep +#include "nanobind/stl/shared_ptr.h" // IWYU pragma: keep +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "nanobind/stl/string_view.h" // IWYU pragma: keep +#include "nanobind/stl/tuple.h" // IWYU pragma: keep +#include "nanobind/stl/vector.h" // IWYU pragma: keep +#include "jaxlib/nb_class_ptr.h" +#include "jaxlib/pytree.pb.h" +#include "xla/pjrt/exceptions.h" +#include "xla/tsl/platform/logging.h" + +namespace xla { + +namespace nb = nanobind; + +constexpr int kSequenceKeyHashSalt = 1; +constexpr int kFlattenedIndexKeyHashSalt = 42; + +PyTreeRegistry::PyTreeRegistry(bool enable_none, bool enable_tuple, + bool enable_namedtuple, bool enable_list, + bool enable_dict) { + auto add_builtin_type = [&](PyTypeObject* type_obj, PyTreeKind kind) { + nb::object type = + nb::borrow(reinterpret_cast(type_obj)); + auto registration = std::make_unique(); + registration->kind = kind; + registration->type = type; + CHECK(registrations_.emplace(type, std::move(registration)).second); + }; + if (enable_none) { + add_builtin_type(Py_TYPE(Py_None), PyTreeKind::kNone); + } + if (enable_tuple) { + add_builtin_type(&PyTuple_Type, PyTreeKind::kTuple); + } + enable_namedtuple_ = enable_namedtuple; + if (enable_list) { + add_builtin_type(&PyList_Type, PyTreeKind::kList); + } + if (enable_dict) { + add_builtin_type(&PyDict_Type, PyTreeKind::kDict); + } +} + +void PyTreeRegistry::Register( + nb::object type, nb::callable to_iterable, nb::callable from_iterable, + std::optional to_iterable_with_keys) { + auto registration = std::make_unique(); + registration->kind = PyTreeKind::kCustom; + registration->type = type; + registration->to_iterable = std::move(to_iterable); + registration->from_iterable = std::move(from_iterable); + registration->to_iterable_with_keys = std::move(to_iterable_with_keys); + nb::ft_lock_guard lock(mu_); + auto it = registrations_.emplace(type, std::move(registration)); + if (!it.second) { + throw std::invalid_argument( + absl::StrFormat("Duplicate custom PyTreeDef type registration for %s.", + nb::cast(nb::repr(type)))); + } +} + +void PyTreeRegistry::RegisterDataclass(nb::object type, + std::vector data_fields, + std::vector meta_fields) { + auto registration = std::make_unique(); + registration->kind = PyTreeKind::kDataclass; + 
registration->type = type; + registration->data_fields = std::move(data_fields); + registration->meta_fields = std::move(meta_fields); + nb::ft_lock_guard lock(mu_); + auto it = registrations_.emplace(type, std::move(registration)); + if (!it.second) { + throw std::invalid_argument(absl::StrFormat( + "Duplicate custom dataclass PyTreeDef type registration for %s.", + nb::cast(nb::repr(std::move(type))))); + } +} + +std::pair +PyTreeRegistry::Registration::ToIterable(nanobind::handle o) const { + nb::object out = to_iterable(o); + nb::tuple leaves_and_aux_data; + if (!nb::try_cast(out, leaves_and_aux_data) || + leaves_and_aux_data.size() != 2) { + throw std::invalid_argument(absl::StrCat( + "The to_iterable function for a custom PyTree node should return " + "a (children, aux_data) tuple, got ", + nb::cast(nb::repr(out)))); + } + nb::iterable leaves; + if (!nb::try_cast(leaves_and_aux_data[0], leaves)) { + throw std::invalid_argument(absl::StrCat( + "The to_iterable function for a custom PyTree node should return " + "a (children, aux_data) tuple where 'children' is iterable, " + "got ", + nb::cast(nb::repr(out)))); + } + return std::make_pair(std::move(leaves), nb::object(leaves_and_aux_data[1])); +} + +std::pair>, nb::object> +PyTreeRegistry::Registration::ToIterableWithKeys(nb::handle o) const { + // Backwards compatibility case: return dummy FlattenedIndexKey for each leaf. + std::vector> result; + if (!to_iterable_with_keys.has_value()) { + auto [leaves, aux_data] = ToIterable(o); + for (nb::handle leaf : leaves) { + result.push_back(std::make_pair( + make_nb_class(result.size()), nb::borrow(leaf))); + } + return std::make_pair(std::move(result), std::move(aux_data)); + } + nb::object out = to_iterable_with_keys.value()(o); + nb::tuple leaves_and_aux_data; + if (!nb::try_cast(out, leaves_and_aux_data) || + leaves_and_aux_data.size() != 2) { + throw std::invalid_argument(absl::StrCat( + "The to_iterable_with_keys function for a custom PyTree " + "node should return a (key_leaf_pairs, aux_data) tuple, got ", + nb::cast(nb::repr(out)))); + } + nb::iterable key_leaf_pairs; + if (!nb::try_cast(leaves_and_aux_data[0], key_leaf_pairs)) { + throw std::invalid_argument(absl::StrCat( + "The to_iterable_with_keys function for a custom PyTree node should " + "return a (key_leaf_pairs, aux_data) tuple where 'key_leaf_pairs' is " + "iterable, got ", + nb::cast(nb::repr(leaves_and_aux_data)))); + } + for (nb::handle key_leaf_pair : key_leaf_pairs) { + nb::tuple key_leaf_pair_tuple; + if (!nb::try_cast(key_leaf_pair, key_leaf_pair_tuple) || + key_leaf_pair_tuple.size() != 2) { + throw std::invalid_argument(absl::StrCat( + "The to_iterable_with_keys function for a custom PyTree node should " + "return a (key_leaf_pairs, aux_data) tuple where 'child", + nb::cast(nb::repr(key_leaf_pair)))); + } + result.push_back(std::make_pair(nb::borrow(key_leaf_pair_tuple[0]), + nb::borrow(key_leaf_pair_tuple[1]))); + } + return std::make_pair(std::move(result), nb::object(leaves_and_aux_data[1])); +} + +int PyTreeRegistry::Registration::tp_traverse(visitproc visit, void* arg) { + Py_VISIT(type.ptr()); + Py_VISIT(to_iterable.ptr()); + Py_VISIT(from_iterable.ptr()); + for (const auto& field : data_fields) { + Py_VISIT(field.ptr()); + } + for (const auto& field : meta_fields) { + Py_VISIT(field.ptr()); + } + return 0; +} + +// Computes the node kind of a given Python object. 
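+// Custom and dataclass registrations take precedence. Namedtuples are only
+// recognized heuristically, via the presence of a `_fields` attribute, and any
+// unregistered type is treated as a leaf.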
+PyTreeKind PyTreeRegistry::KindOfObject( + nb::handle obj, PyTreeRegistry::Registration const** custom) const { + const PyTreeRegistry::Registration* registration = Lookup(obj.type()); + if (registration) { + if (registration->kind == PyTreeKind::kCustom || + registration->kind == PyTreeKind::kDataclass) { + *custom = registration; + } else { + *custom = nullptr; + } + return registration->kind; + } else if (nb::isinstance(obj) && nb::hasattr(obj, "_fields")) { + // We can only identify namedtuples heuristically, here by the presence of + // a _fields attribute. + return PyTreeKind::kNamedTuple; + } else { + return PyTreeKind::kLeaf; + } +} + +/*static*/ const PyTreeRegistry::Registration* PyTreeRegistry::Lookup( + nb::handle type) const { + nb::ft_lock_guard lock(mu_); + auto it = registrations_.find(type); + return it == registrations_.end() ? nullptr : it->second.get(); +} + +/*static*/ std::vector GetSortedPyDictKeys(PyObject* py_dict) { + std::vector keys; + keys.reserve(PyDict_Size(py_dict)); + PyObject* key; + Py_ssize_t pos = 0; + while (PyDict_Next(py_dict, &pos, &key, /*value=*/nullptr)) { + keys.push_back(nb::borrow(key)); + } + + try { + std::stable_sort( + keys.begin(), keys.end(), [](const nb::object& a, const nb::object& b) { + int cmp = PyObject_RichCompareBool(a.ptr(), b.ptr(), Py_LT); + if (cmp == -1) { + throw nb::python_error(); + } + return cmp; + }); + } catch (nb::python_error& e) { + nb::raise_from(e, PyExc_ValueError, + "Comparator raised exception while sorting pytree " + "dictionary keys."); + } + return keys; +} + +/*static*/ bool IsSortedPyDictKeysEqual(absl::Span lhs, + absl::Span rhs) { + if (lhs.size() != rhs.size()) { + return false; + } + for (int i = 0; i < lhs.size(); ++i) { + if (lhs[i].not_equal(rhs[i])) { + return false; + } + } + return true; +} + +bool PyTreeDef::operator==(const PyTreeDef& other) const { + if (traversal_.size() != other.traversal_.size()) { + return false; + } + for (size_t i = 0; i < traversal_.size(); ++i) { + const Node& a = traversal_[i]; + const Node& b = other.traversal_[i]; + if (a.kind != b.kind || a.arity != b.arity || + (a.node_data.ptr() == nullptr) != (b.node_data.ptr() == nullptr) || + (a.sorted_dict_keys.size() != b.sorted_dict_keys.size()) || + a.custom != b.custom) { + return false; + } + try { + if (a.node_data && a.node_data.not_equal(b.node_data)) { + return false; + } + } catch (nb::python_error& e) { + nb::raise_from(e, PyExc_ValueError, + "Exception raised while checking equality of metadata " + "fields of pytree. Make sure that metadata fields are " + "hashable and have simple equality semantics. (Note: " + "arrays cannot be passed as metadata fields!)"); + } + if (!IsSortedPyDictKeysEqual(a.sorted_dict_keys, b.sorted_dict_keys)) { + return false; + } + // We don't need to test equality of num_leaves and num_nodes since they + // are derivable from the other node data. 
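+ // (SetNumLeavesAndNumNodes() recomputes both counters, e.g. during
+ // Compose() and proto deserialization.)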
+ } + return true; +} + +nb::object PyTreeRegistry::FlattenOneLevel(nb::handle x) const { + return FlattenOneLevelImpl(x, /*with_keys=*/false); +} + +nb::object PyTreeRegistry::FlattenOneLevelWithKeys(nb::handle x) const { + return FlattenOneLevelImpl(x, /*with_keys=*/true); +} + +nb::object PyTreeRegistry::FlattenOneLevelImpl(nb::handle x, + bool with_keys) const { + PyTreeRegistry::Registration const* custom; + PyTreeKind kind = KindOfObject(x, &custom); + switch (kind) { + case PyTreeKind::kNone: + return nb::make_tuple(nb::make_tuple(), nb::none()); + case PyTreeKind::kTuple: { + if (with_keys) { + auto size = PyTuple_GET_SIZE(x.ptr()); + nb::object key_leaves = nb::steal(PyTuple_New(size)); + for (int i = 0; i < size; ++i) { + nb::object key = make_nb_class(i); + nb::object value = + nb::borrow(PyTuple_GET_ITEM(x.ptr(), i)); + PyTuple_SET_ITEM(key_leaves.ptr(), i, + nb::make_tuple(key, value).release().ptr()); + } + return nb::make_tuple(std::move(key_leaves), nb::none()); + } + return nb::make_tuple(nb::borrow(x), nb::none()); + } + case PyTreeKind::kList: { + if (with_keys) { + auto size = PyList_GET_SIZE(x.ptr()); + nb::object key_leaves = nb::steal(PyTuple_New(size)); + for (int i = 0; i < size; ++i) { + nb::object key = make_nb_class(i); + nb::object value = + nb::borrow(PyList_GET_ITEM(x.ptr(), i)); + PyTuple_SET_ITEM(key_leaves.ptr(), i, + nb::make_tuple(key, value).release().ptr()); + } + return nb::make_tuple(std::move(key_leaves), nb::none()); + } + return nb::make_tuple(nb::borrow(x), nb::none()); + } + case PyTreeKind::kDict: { + nb::dict dict = nb::borrow(x); + std::vector sorted_keys = GetSortedPyDictKeys(dict.ptr()); + nb::tuple keys = nb::steal(PyTuple_New(sorted_keys.size())); + nb::tuple values = nb::steal(PyTuple_New(sorted_keys.size())); + for (size_t i = 0; i < sorted_keys.size(); ++i) { + nb::object& key = sorted_keys[i]; + nb::object value = nb::object(dict[key]); + if (with_keys) { + value = nb::make_tuple(make_nb_class(key), value); + } + PyTuple_SET_ITEM(values.ptr(), i, value.release().ptr()); + PyTuple_SET_ITEM(keys.ptr(), i, sorted_keys[i].release().ptr()); + } + return nb::make_tuple(std::move(values), std::move(keys)); + } + case PyTreeKind::kNamedTuple: { + nb::tuple in = nb::borrow(x); + nb::list out; + if (with_keys) { + // Get key names from NamedTuple fields. 
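+ // Each child is paired with a key built from its field name; these
+ // (key, leaf) pairs are what flatten_one_level_with_keys returns to Python.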
+ nb::tuple fields; + if (!nb::try_cast(nb::getattr(in, "_fields"), fields) || + in.size() != fields.size()) { + throw std::invalid_argument( + "A namedtuple's _fields attribute should have the same size as " + "the tuple."); + } + auto field_iter = fields.begin(); + for (nb::handle entry : in) { + out.append(nb::make_tuple( + make_nb_class(nb::str(*field_iter)), entry)); + } + return nb::make_tuple(std::move(out), x.type()); + } + for (size_t i = 0; i < in.size(); ++i) { + out.append(in[i]); + } + return nb::make_tuple(std::move(out), x.type()); + } + case PyTreeKind::kCustom: { + if (with_keys) { + auto [leaves, aux_data] = custom->ToIterableWithKeys(x); + return nb::make_tuple(std::move(leaves), std::move(aux_data)); + } + auto [leaves, aux_data] = custom->ToIterable(x); + return nb::make_tuple(std::move(leaves), std::move(aux_data)); + } + case PyTreeKind::kDataclass: { + auto data_size = custom->data_fields.size(); + nb::list leaves = nb::steal(PyList_New(data_size)); + for (int leaf = 0; leaf < data_size; ++leaf) { + nb::object value = nb::getattr(x, custom->data_fields[leaf]); + if (with_keys) { + value = nb::make_tuple( + make_nb_class(custom->data_fields[leaf]), value); + } + PyList_SET_ITEM(leaves.ptr(), leaf, value.release().ptr()); + } + auto meta_size = custom->meta_fields.size(); + nb::object aux_data = nb::steal(PyTuple_New(meta_size)); + for (int meta_leaf = 0; meta_leaf < meta_size; ++meta_leaf) { + PyTuple_SET_ITEM( + aux_data.ptr(), meta_leaf, + nb::getattr(x, custom->meta_fields[meta_leaf]).release().ptr()); + } + return nb::make_tuple(std::move(leaves), std::move(aux_data)); + } + default: + DCHECK(kind == PyTreeKind::kLeaf); + return nb::none(); + } +} + +/* static */ PyType_Slot PyTreeRegistry::slots_[] = { + {Py_tp_traverse, (void*)PyTreeRegistry::tp_traverse}, + {Py_tp_clear, (void*)PyTreeRegistry::tp_clear}, + {0, nullptr}, +}; + +/* static */ int PyTreeRegistry::tp_traverse(PyObject* self, visitproc visit, + void* arg) { + Py_VISIT(Py_TYPE(self)); + if (!nb::inst_ready(self)) { + return 0; + } + PyTreeRegistry* registry = nb::inst_ptr(self); + nb::ft_lock_guard lock(registry->mu_); + for (const auto& [key, value] : registry->registrations_) { + Py_VISIT(key.ptr()); + int rval = value->tp_traverse(visit, arg); + if (rval != 0) { + return rval; + } + } + return 0; +} + +/* static */ int PyTreeRegistry::tp_clear(PyObject* self) { + PyTreeRegistry* registry = nb::inst_ptr(self); + nb::ft_lock_guard lock(registry->mu_); + registry->registrations_.clear(); + return 0; +} + +/* static */ PyType_Slot DictKey::slots_[] = { + {Py_tp_traverse, (void*)DictKey::tp_traverse}, + {Py_tp_clear, (void*)DictKey::tp_clear}, + {0, nullptr}, +}; + +/* static */ int DictKey::tp_traverse(PyObject* self, visitproc visit, + void* arg) { + DictKey* key = nb::inst_ptr(self); + Py_VISIT(key->key_.ptr()); + return 0; +} + +/* static */ int DictKey::tp_clear(PyObject* self) { + DictKey* dictkey = nb::inst_ptr(self); + nb::object tmp; + std::swap(tmp, dictkey->key_); + return 0; +} + +std::string SequenceKey::ToString() const { + return absl::StrFormat("[%d]", idx_); +} + +std::string SequenceKey::ToReprString() const { + return absl::StrFormat("SequenceKey(idx=%d)", idx_); +} + +std::string DictKey::ToString() const { + return absl::StrFormat("[%s]", nb::cast(nb::repr(key_))); +} + +std::string DictKey::ToReprString() const { + return absl::StrFormat("DictKey(key=%s)", + nb::cast(nb::repr(key_))); +} + +std::string GetAttrKey::ToString() const { + return absl::StrFormat(".%s", 
nb::cast(name_)); +} + +std::string GetAttrKey::ToReprString() const { + return absl::StrFormat("GetAttrKey(name='%s')", + nb::cast(name_)); +} + +std::string FlattenedIndexKey::ToString() const { + return absl::StrFormat("[]", key_); +} + +std::string FlattenedIndexKey::ToReprString() const { + return absl::StrFormat("FlattenedIndexKey(key=%d)", key_); +} + +bool SequenceKey::Equals(const nb::object& other) { + SequenceKey other_key(0); + if (!nb::try_cast(other, other_key)) return false; + return idx_ == other_key.idx(); +} + +bool DictKey::Equals(const nb::object& other) { + DictKey other_key(nb::none()); + if (!nb::try_cast(other, other_key)) return false; + return key_.equal(other_key.key()); +} + +bool GetAttrKey::Equals(const nb::object& other) { + GetAttrKey other_key(nb::str("")); + if (!nb::try_cast(other, other_key)) return false; + return name_.equal(other_key.name()); +} + +bool FlattenedIndexKey::Equals(const nb::object& other) { + FlattenedIndexKey other_key(0); + if (!nb::try_cast(other, other_key)) return false; + return key_ == other_key.key(); +} + +nanobind::tuple SequenceKey::MatchArgs(nanobind::handle unused) { + return nanobind::make_tuple("idx"); +}; + +nanobind::tuple DictKey::MatchArgs(nanobind::handle unused) { + return nanobind::make_tuple("key"); +}; + +nanobind::tuple GetAttrKey::MatchArgs(nanobind::handle unused) { + return nanobind::make_tuple("name"); +}; + +nanobind::tuple FlattenedIndexKey::MatchArgs(nanobind::handle unused) { + return nanobind::make_tuple("key"); +}; + +/* static */ nb::object MakeKeyPathTuple(std::vector& keypath) { + const std::vector& frozen_keypath = keypath; + nb::object kp_tuple = nb::steal(PyTuple_New(frozen_keypath.size())); + for (int i = 0; i < frozen_keypath.size(); ++i) { + PyTuple_SET_ITEM(kp_tuple.ptr(), i, + nb::object(frozen_keypath[i]).release().ptr()); + } + return kp_tuple; +} + +template +void PyTreeDef::FlattenImpl( + nb::handle handle, T& leaves, + std::optional>& keypath, + const std::optional& leaf_predicate) { + Node node; + const int start_num_nodes = traversal_.size(); + const int start_num_leaves = leaves.size(); + bool is_known_leaf = false; + if (leaf_predicate) { + nb::object o; + if (keypath.has_value()) { + auto kp_tuple = MakeKeyPathTuple(keypath.value()); + o = (*leaf_predicate)(kp_tuple, handle); + } else { + o = (*leaf_predicate)(handle); + } + // Historically we accepted "truthy" values from leaf predicates. Accept + // None here to keep existing clients happy. + if (o.is_none()) { + is_known_leaf = false; + } else if (!nb::try_cast(o, is_known_leaf)) { + throw std::invalid_argument(absl::StrCat( + "is_leaf predicate returned a non-boolean value ", + nb::cast(nb::repr(o)), "; expected a boolean")); + } + } + if (is_known_leaf) { + nb::object value = nb::borrow(handle); + if (keypath.has_value()) { + auto kp_tuple = MakeKeyPathTuple(keypath.value()); + value = nb::make_tuple(std::move(kp_tuple), std::move(value)); + } + if constexpr (std::is_same_v) { + leaves.append(std::move(value)); + } else { + leaves.push_back(std::move(value)); + } + } else { + node.kind = registry_->KindOfObject(handle, &node.custom); + auto recurse = [this, &leaf_predicate, &leaves]( + nb::handle child, + std::optional>& keypath) { + if (Py_EnterRecursiveCall( + " in flatten; PyTree may have cyclical node references.")) { + return; + } + FlattenImpl(child, leaves, keypath, leaf_predicate); + Py_LeaveRecursiveCall(); + }; + switch (node.kind) { + case PyTreeKind::kNone: + // Nothing to do. 
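+ // None flattens to a zero-arity node with no leaves; Unflatten later
+ // reconstructs it as a literal None in MakeNode().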
+ break; + case PyTreeKind::kTuple: { + node.arity = PyTuple_GET_SIZE(handle.ptr()); + for (int i = 0; i < node.arity; ++i) { + if (keypath.has_value()) { + keypath->push_back(make_nb_class(i)); + } + recurse(PyTuple_GET_ITEM(handle.ptr(), i), keypath); + if (keypath.has_value()) { + keypath->pop_back(); + } + } + break; + } + case PyTreeKind::kList: { + node.arity = PyList_GET_SIZE(handle.ptr()); + for (int i = 0; i < node.arity; ++i) { + if (keypath.has_value()) { + keypath->push_back(make_nb_class(i)); + } + recurse(PyList_GET_ITEM(handle.ptr(), i), keypath); + if (keypath.has_value()) { + keypath->pop_back(); + } + } + break; + } + case PyTreeKind::kDict: { + nb::dict dict = nb::borrow(handle); + + std::vector keys = GetSortedPyDictKeys(dict.ptr()); + for (nb::object& key : keys) { + if (keypath.has_value()) { + keypath->push_back(make_nb_class(key)); + } + recurse(dict[key], keypath); + if (keypath.has_value()) { + keypath->pop_back(); + } + } + node.arity = dict.size(); + node.sorted_dict_keys = std::move(keys); + break; + } + case PyTreeKind::kCustom: { + if (keypath.has_value()) { + auto [leaves, aux_data] = node.custom->ToIterableWithKeys(handle); + node.node_data = std::move(aux_data); + node.arity = 0; + for (auto& [key, leaf] : leaves) { + keypath->push_back(key); + ++node.arity; + recurse(leaf, keypath); + keypath->pop_back(); + } + } else { + auto [leaves, aux_data] = node.custom->ToIterable(handle); + node.node_data = std::move(aux_data); + node.arity = 0; + for (nb::handle entry : leaves) { + ++node.arity; + recurse(entry, keypath); + } + } + break; + } + case PyTreeKind::kDataclass: { + auto meta_size = node.custom->meta_fields.size(); + nb::object aux_data = nb::steal(PyTuple_New(meta_size)); + for (int meta_leaf = 0; meta_leaf < meta_size; ++meta_leaf) { + PyTuple_SET_ITEM( + aux_data.ptr(), meta_leaf, + nb::getattr(handle, node.custom->meta_fields[meta_leaf]) + .release() + .ptr()); + } + node.node_data = std::move(aux_data); + auto data_size = node.custom->data_fields.size(); + node.arity = data_size; + for (int leaf = 0; leaf < data_size; ++leaf) { + if (keypath.has_value()) { + keypath->push_back( + make_nb_class(node.custom->data_fields[leaf])); + } + recurse(nb::getattr(handle, node.custom->data_fields[leaf]), keypath); + if (keypath.has_value()) { + keypath->pop_back(); + } + } + break; + } + case PyTreeKind::kNamedTuple: { + nb::tuple tuple = nb::borrow(handle); + node.arity = tuple.size(); + node.node_data = nb::borrow(tuple.type()); + if (keypath.has_value()) { + // Get key names from NamedTuple fields. 
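+ // As with the other containers above, each field name is pushed onto the
+ // running keypath before recursing into the child and popped afterwards.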
+ nb::tuple fields; + if (!nb::try_cast(nb::getattr(tuple, "_fields"), fields) || + tuple.size() != fields.size()) { + throw std::invalid_argument( + "A namedtuple's _fields attribute should have the same size as " + "the tuple."); + } + auto field_iter = fields.begin(); + for (nb::handle entry : tuple) { + keypath->push_back(make_nb_class(nb::str(*field_iter))); + field_iter++; + recurse(entry, keypath); + keypath->pop_back(); + } + } else { + for (nb::handle entry : tuple) { + recurse(entry, keypath); + } + } + break; + } + default: + DCHECK(node.kind == PyTreeKind::kLeaf); + auto value = nb::borrow(handle); + if (keypath.has_value()) { + auto kp_tuple = MakeKeyPathTuple(keypath.value()); + value = nb::make_tuple(std::move(kp_tuple), std::move(value)); + } + if constexpr (std::is_same_v) { + leaves.append(std::move(value)); + } else { + leaves.push_back(std::move(value)); + } + } + } + node.num_nodes = traversal_.size() - start_num_nodes + 1; + node.num_leaves = leaves.size() - start_num_leaves; + traversal_.push_back(std::move(node)); +} + +void PyTreeDef::Flatten(nb::handle handle, + absl::InlinedVector& leaves, + std::optional leaf_predicate) { + std::optional> keypath = std::nullopt; + FlattenImpl(handle, leaves, keypath, leaf_predicate); +} + +void PyTreeDef::Flatten(nb::handle handle, std::vector& leaves, + std::optional leaf_predicate) { + std::optional> keypath = std::nullopt; + FlattenImpl(handle, leaves, keypath, leaf_predicate); +} + +void PyTreeDef::Flatten(nb::handle handle, nb::list& leaves, + std::optional leaf_predicate) { + std::optional> keypath = std::nullopt; + FlattenImpl(handle, leaves, keypath, leaf_predicate); +} + +/*static*/ std::pair, nb_class_ptr> +PyTreeDef::Flatten(nb::handle x, nb_class_ptr registry, + std::optional leaf_predicate) { + auto def = make_nb_class(registry); + std::vector leaves; + def->Flatten(x, leaves, leaf_predicate); + return std::make_pair(std::move(leaves), std::move(def)); +} + +void PyTreeDef::FlattenWithPath( + nb::handle handle, nanobind::list& leaves, + std::optional leaf_predicate) { + std::optional> keypath = std::vector(); + FlattenImpl(handle, leaves, keypath, leaf_predicate); +} + +/*static*/ bool PyTreeDef::AllLeaves(PyTreeRegistry* registry, + const nb::iterable& x) { + const PyTreeRegistry::Registration* custom; + for (const nb::handle& h : x) { + if (registry->KindOfObject(h, &custom) != PyTreeKind::kLeaf) return false; + } + return true; +} + +template +nb::object PyTreeDef::UnflattenImpl(T leaves) const { + absl::InlinedVector agenda; + auto it = leaves.begin(); + int leaf_count = 0; + for (const Node& node : traversal_) { + if (agenda.size() < node.arity) { + throw std::logic_error("Too few elements for TreeDef node."); + } + switch (node.kind) { + case PyTreeKind::kLeaf: + if (it == leaves.end()) { + throw std::invalid_argument(absl::StrFormat( + "Too few leaves for PyTreeDef; expected %d, got %d", num_leaves(), + leaf_count)); + } + agenda.push_back(nb::borrow(*it)); + ++it; + ++leaf_count; + break; + + case PyTreeKind::kNone: + case PyTreeKind::kTuple: + case PyTreeKind::kNamedTuple: + case PyTreeKind::kList: + case PyTreeKind::kDict: + case PyTreeKind::kCustom: + case PyTreeKind::kDataclass: { + const int size = agenda.size(); + absl::Span span; + if (node.arity > 0) { + span = absl::Span(&agenda[size - node.arity], node.arity); + } + nb::object o = MakeNode(node, span); + agenda.resize(size - node.arity); + agenda.push_back(o); + break; + } + } + } + if (it != leaves.end()) { + throw 
std::invalid_argument(absl::StrFormat( + "Too many leaves for PyTreeDef; expected %d.", num_leaves())); + } + if (agenda.size() != 1) { + throw std::logic_error("PyTreeDef traversal did not yield a singleton."); + } + return std::move(agenda.back()); +} + +nb::object PyTreeDef::Unflatten(nb::iterable leaves) const { + return UnflattenImpl(leaves); +} + +nb::object PyTreeDef::Unflatten(absl::Span leaves) const { + return UnflattenImpl(leaves); +} + +/*static*/ nb::object PyTreeDef::MakeNode(const PyTreeDef::Node& node, + absl::Span children) { + if (children.size() != node.arity) { + throw std::logic_error("Node arity mismatch."); + } + switch (node.kind) { + case PyTreeKind::kLeaf: + throw std::logic_error("MakeNode not implemented for leaves."); + + case PyTreeKind::kNone: + return nb::none(); + + case PyTreeKind::kTuple: + case PyTreeKind::kNamedTuple: { + nb::object tuple = nb::steal(PyTuple_New(node.arity)); + for (int i = 0; i < node.arity; ++i) { + PyTuple_SET_ITEM(tuple.ptr(), i, children[i].release().ptr()); + } + if (node.kind == PyTreeKind::kNamedTuple) { + return node.node_data(*tuple); + } else { + return tuple; + } + } + + case PyTreeKind::kList: { + nb::object list = nb::steal(PyList_New(node.arity)); + for (int i = 0; i < node.arity; ++i) { + PyList_SET_ITEM(list.ptr(), i, children[i].release().ptr()); + } + return list; + } + + case PyTreeKind::kDict: { + nb::dict dict; + for (int i = 0; i < node.arity; ++i) { + dict[node.sorted_dict_keys[i]] = std::move(children[i]); + } + return std::move(dict); + break; + } + case PyTreeKind::kCustom: { + nb::object tuple = nb::steal(PyTuple_New(node.arity)); + for (int i = 0; i < node.arity; ++i) { + PyTuple_SET_ITEM(tuple.ptr(), i, children[i].release().ptr()); + } + return node.custom->from_iterable(node.node_data, tuple); + } + + case PyTreeKind::kDataclass: { + nb::kwargs kwargs; + auto meta_size = node.custom->meta_fields.size(); + for (int i = 0; i < meta_size; ++i) { + kwargs[node.custom->meta_fields[i]] = + nb::borrow(nb::tuple(node.node_data)[i]); + } + auto data_size = node.custom->data_fields.size(); + for (int i = 0; i < data_size; ++i) { + kwargs[node.custom->data_fields[i]] = std::move(children[i]); + } + return node.custom->type(**kwargs); + } + } + throw std::logic_error("Unreachable code."); +} + +nb::list PyTreeDef::FlattenUpTo(nb::handle xs) const { + nb::list leaves = nb::steal(PyList_New(num_leaves())); + std::vector agenda; + agenda.push_back(nb::borrow(xs)); + auto it = traversal_.rbegin(); + int leaf = num_leaves() - 1; + while (!agenda.empty()) { + if (it == traversal_.rend()) { + throw std::invalid_argument(absl::StrFormat( + "Tree structures did not match: %s vs %s", + nb::cast(nb::repr(xs)), ToString())); + } + const Node& node = *it; + nb::object object = agenda.back(); + agenda.pop_back(); + ++it; + + switch (node.kind) { + case PyTreeKind::kLeaf: + if (leaf < 0) { + throw std::logic_error("Leaf count mismatch."); + } + PyList_SET_ITEM(leaves.ptr(), leaf, object.release().ptr()); + --leaf; + break; + + case PyTreeKind::kNone: + if (!object.is_none()) { + throw std::invalid_argument(absl::StrFormat( + "Expected None, got %s.\n\n" + "In previous releases of JAX, flatten-up-to used to " + "consider None to be a tree-prefix of non-None values. 
To obtain " + "the previous behavior, you can usually write:\n" + " jax.tree.map(lambda x, y: None if x is None else f(x, y), a, " + "b, is_leaf=lambda x: x is None)", + nb::cast(nb::repr(object)))); + } + break; + + case PyTreeKind::kTuple: { + if (!PyTuple_CheckExact(object.ptr())) { + throw std::invalid_argument( + absl::StrFormat("Expected tuple, got %s.", + nb::cast(nb::repr(object)))); + } + nb::tuple tuple = nb::borrow(object); + if (tuple.size() != node.arity) { + throw std::invalid_argument(absl::StrFormat( + "Tuple arity mismatch: %d != %d; tuple: %s.", tuple.size(), + node.arity, nb::cast(nb::repr(object)))); + } + for (nb::handle entry : tuple) { + agenda.push_back(nb::borrow(entry)); + } + break; + } + + case PyTreeKind::kList: { + if (!PyList_CheckExact(object.ptr())) { + throw std::invalid_argument( + absl::StrFormat("Expected list, got %s.", + nb::cast(nb::repr(object)))); + } + nb::list list = nb::borrow(object); + if (list.size() != node.arity) { + throw std::invalid_argument(absl::StrFormat( + "List arity mismatch: %d != %d; list: %s.", list.size(), + node.arity, nb::cast(nb::repr(object)))); + } + for (nb::handle entry : list) { + agenda.push_back(nb::borrow(entry)); + } + break; + } + + case PyTreeKind::kDict: { + if (!PyDict_CheckExact(object.ptr())) { + throw std::invalid_argument( + absl::StrFormat("Expected dict, got %s.", + nb::cast(nb::repr(object)))); + } + nb::dict dict = nb::borrow(object); + std::vector keys = GetSortedPyDictKeys(dict.ptr()); + if (!IsSortedPyDictKeysEqual(keys, node.sorted_dict_keys)) { + // Convert to a nb::list for nb::repr to avoid having to stringify a + // vector. This is error path so it is fine to pay conversion cost. + throw std::invalid_argument( + absl::StrFormat("Dict key mismatch; expected keys: %s; dict: %s.", + nb::cast( + nb::repr(nb::cast(node.sorted_dict_keys))), + nb::cast(nb::repr(object)))); + } + for (nb::handle key : keys) { + agenda.push_back(dict[key]); + } + break; + } + + case PyTreeKind::kNamedTuple: { + if (!nb::isinstance(object) || + !nb::hasattr(object, "_fields")) { + throw std::invalid_argument( + absl::StrFormat("Expected named tuple, got %s.", + nb::cast(nb::repr(object)))); + } + nb::tuple tuple = nb::borrow(object); + if (tuple.size() != node.arity) { + throw std::invalid_argument(absl::StrFormat( + "Named tuple arity mismatch: %d != %d; tuple: %s.", tuple.size(), + node.arity, nb::cast(nb::repr(object)))); + } + if (tuple.type().not_equal(node.node_data)) { + throw std::invalid_argument(absl::StrFormat( + "Named tuple type mismatch: expected type: %s, tuple: %s.", + nb::cast(nb::repr(node.node_data)), + nb::cast(nb::repr(object)))); + } + for (nb::handle entry : tuple) { + agenda.push_back(nb::borrow(entry)); + } + break; + } + + case PyTreeKind::kCustom: { + auto* registration = registry_->Lookup(object.type()); + if (registration != node.custom) { + throw std::invalid_argument(absl::StrFormat( + "Custom node type mismatch: expected type: %s, value: %s.", + nb::cast(nb::repr(node.custom->type)), + nb::cast(nb::repr(object)))); + } + auto [leaves, aux_data] = node.custom->ToIterable(object); + if (node.node_data.not_equal(aux_data)) { + throw std::invalid_argument(absl::StrFormat( + "Mismatch custom node data: %s != %s; value: %s.", + nb::cast(nb::repr(node.node_data)), + nb::cast(nb::repr(aux_data)), + nb::cast(nb::repr(object)))); + } + int arity = 0; + for (nb::handle entry : leaves) { + ++arity; + agenda.push_back(nb::borrow(entry)); + } + if (arity != node.arity) { + throw 
std::invalid_argument(absl::StrFormat( + "Custom type arity mismatch: %d != %d; value: %s.", arity, + node.arity, nb::cast(nb::repr(object)))); + } + break; + } + + case PyTreeKind::kDataclass: { + auto* registration = registry_->Lookup(object.type()); + if (registration != node.custom) { + throw std::invalid_argument(absl::StrFormat( + "Custom dataclass node type mismatch: expected type: %s, value: " + "%s.", + nb::cast(nb::repr(node.custom->type)), + nb::cast(nb::repr(std::move(object))))); + } + auto meta_size = node.custom->meta_fields.size(); + nb::object aux_data = nb::steal(PyTuple_New(meta_size)); + for (int meta_leaf = 0; meta_leaf < meta_size; ++meta_leaf) { + PyTuple_SET_ITEM( + aux_data.ptr(), meta_leaf, + nb::getattr(object, node.custom->meta_fields[meta_leaf]) + .release() + .ptr()); + } + if (node.node_data.not_equal(aux_data)) { + throw std::invalid_argument(absl::StrFormat( + "Mismatch custom dataclass node data: %s != %s; value: %s.", + nb::cast(nb::repr(node.node_data)), + nb::cast(nb::repr(aux_data)), + nb::cast(nb::repr(object)))); + } + auto data_size = node.custom->data_fields.size(); + if (data_size != node.arity) { + throw std::invalid_argument(absl::StrFormat( + "Custom type arity mismatch: %d != %d; value: %s.", data_size, + node.arity, nb::cast(nb::repr(object)))); + } + for (int leaf = 0; leaf < data_size; ++leaf) { + agenda.push_back(nb::borrow( + nb::getattr(object, node.custom->data_fields[leaf]))); + } + break; + } + } + } + if (it != traversal_.rend() || leaf != -1) { + throw std::invalid_argument( + absl::StrFormat("Tree structures did not match: %s vs %s", + nb::cast(nb::repr(xs)), ToString())); + } + return leaves; +} + +nb::object PyTreeDef::Walk(const nb::callable& f_node, nb::handle f_leaf, + nb::iterable leaves) const { + std::vector agenda; + auto it = leaves.begin(); + for (const Node& node : traversal_) { + switch (node.kind) { + case PyTreeKind::kLeaf: { + if (it == leaves.end()) { + throw std::invalid_argument("Too few leaves for PyTreeDef"); + } + + nb::object leaf = nb::borrow(*it); + agenda.push_back(f_leaf.is_none() ? std::move(leaf) + : f_leaf(std::move(leaf))); + ++it; + break; + } + + case PyTreeKind::kNone: + case PyTreeKind::kTuple: + case PyTreeKind::kNamedTuple: + case PyTreeKind::kList: + case PyTreeKind::kDict: + case PyTreeKind::kCustom: + case PyTreeKind::kDataclass: { + if (agenda.size() < node.arity) { + throw std::logic_error("Too few elements for custom type."); + } + nb::object tuple = nb::steal(PyTuple_New(node.arity)); + for (int i = node.arity - 1; i >= 0; --i) { + PyTuple_SET_ITEM(tuple.ptr(), i, agenda.back().release().ptr()); + agenda.pop_back(); + } + nb::object node_data = node.node_data; + if (node.kind == PyTreeKind::kDict) { + // Convert to a nb::list for f_node invocation. + node_data = nb::cast(node.sorted_dict_keys); + } + agenda.push_back(f_node(tuple, node_data ? 
node_data : nb::none())); + } + } + } + if (it != leaves.end()) { + throw std::invalid_argument("Too many leaves for PyTreeDef"); + } + if (agenda.size() != 1) { + throw std::logic_error("PyTreeDef traversal did not yield a singleton."); + } + return std::move(agenda.back()); +} + +nb::object PyTreeDef::FromIterableTreeHelper( + nb::handle xs, + absl::InlinedVector::const_reverse_iterator* it) const { + if (*it == traversal_.rend()) { + throw std::invalid_argument("Tree structures did not match."); + } + const Node& node = **it; + ++*it; + if (node.kind == PyTreeKind::kLeaf) { + return nb::borrow(xs); + } + nb::iterable iterable = nb::borrow(xs); + std::vector ys; + ys.reserve(node.arity); + for (nb::handle x : iterable) { + ys.push_back(nb::borrow(x)); + } + if (ys.size() != node.arity) { + throw std::invalid_argument("Arity mismatch between trees"); + } + for (int j = node.arity - 1; j >= 0; --j) { + ys[j] = FromIterableTreeHelper(ys[j], it); + } + + return MakeNode(node, absl::MakeSpan(ys)); +} + +nb::object PyTreeDef::FromIterableTree(nb::handle xs) const { + auto it = traversal_.rbegin(); + nb::object out = FromIterableTreeHelper(xs, &it); + if (it != traversal_.rend()) { + throw std::invalid_argument("Tree structures did not match."); + } + return out; +} + +nb_class_ptr PyTreeDef::Compose(const PyTreeDef& inner) const { + if (inner.registry_ != registry_) { + throw std::invalid_argument( + "PyTree registries of PyTreeDefs passed to Compose() must match."); + } + auto out = make_nb_class(registry_ref_); + out->traversal_.reserve(static_cast(num_leaves()) * + inner.num_nodes() + + num_nodes() - num_leaves()); + for (const Node& n : traversal_) { + if (n.kind == PyTreeKind::kLeaf) { + absl::c_copy(inner.traversal_, std::back_inserter(out->traversal_)); + } else { + out->traversal_.push_back(n); + } + } + out->SetNumLeavesAndNumNodes(); + return out; +} + +/*static*/ nb_class_ptr PyTreeDef::Tuple( + nb_class_ptr registry, nb::list defs) { + auto out = make_nb_class(std::move(registry)); + int num_leaves = 0; + for (nb::handle def_handle : defs) { + const PyTreeDef* def = nb::cast(def_handle); + if (def->registry() != out->registry()) { + throw std::invalid_argument( + "PyTree registries of PyTreeDefs passed to Tuple() must match."); + } + absl::c_copy(def->traversal_, std::back_inserter(out->traversal_)); + num_leaves += def->num_leaves(); + } + Node node; + node.kind = PyTreeKind::kTuple; + node.arity = defs.size(); + node.num_leaves = num_leaves; + node.num_nodes = out->traversal_.size() + 1; + out->traversal_.push_back(node); + return out; +} + +std::vector> PyTreeDef::Children() const { + std::vector> children; + if (traversal_.empty()) { + return children; + } + Node const& root = traversal_.back(); + children.resize(root.arity); + int pos = traversal_.size() - 1; + for (int i = root.arity - 1; i >= 0; --i) { + children[i] = make_nb_class(registry_ref_); + const Node& node = traversal_.at(pos - 1); + if (pos < node.num_nodes) { + throw std::logic_error("children() walked off start of array"); + } + std::copy(traversal_.begin() + pos - node.num_nodes, + traversal_.begin() + pos, + std::back_inserter(children[i]->traversal_)); + pos -= node.num_nodes; + } + if (pos != 0) { + throw std::logic_error("pos != 0 at end of PyTreeDef::Children"); + } + return children; +} + +std::string PyTreeDef::ToString() const { + std::vector agenda; + for (const Node& node : traversal_) { + if (agenda.size() < node.arity) { + throw std::logic_error("Too few elements for container."); + } + + 
std::string children = + absl::StrJoin(agenda.end() - node.arity, agenda.end(), ", "); + std::string representation; + switch (node.kind) { + case PyTreeKind::kLeaf: + agenda.push_back("*"); + continue; + case PyTreeKind::kNone: + representation = "None"; + break; + case PyTreeKind::kTuple: + // Tuples with only one element must have a trailing comma. + if (node.arity == 1) children += ","; + representation = absl::StrCat("(", children, ")"); + break; + case PyTreeKind::kList: + representation = absl::StrCat("[", children, "]"); + break; + case PyTreeKind::kDict: { + if (node.sorted_dict_keys.size() != node.arity) { + throw std::logic_error("Number of keys and entries does not match."); + } + representation = "{"; + std::string separator; + auto child_iter = agenda.end() - node.arity; + for (const nb::handle& key : node.sorted_dict_keys) { + absl::StrAppendFormat(&representation, "%s%s: %s", separator, + nb::cast(nb::repr(key)), + *child_iter); + child_iter++; + separator = ", "; + } + representation += "}"; + break; + } + + case PyTreeKind::kNamedTuple: + case PyTreeKind::kCustom: + case PyTreeKind::kDataclass: { + std::string kind; + std::string data; + if (node.kind == PyTreeKind::kNamedTuple) { + kind = "namedtuple"; + if (node.node_data) { + // Node data for named tuples is the type. + data = absl::StrFormat( + "[%s]", nb::cast( + nb::str(nb::getattr(node.node_data, "__name__")))); + } + } else { + kind = nb::cast( + nb::str(nb::getattr(node.custom->type, "__name__"))); + if (node.node_data) { + data = absl::StrFormat( + "[%s]", nb::cast(nb::str(node.node_data))); + } + } + + representation = + absl::StrFormat("CustomNode(%s%s, [%s])", kind, data, children); + break; + } + } + agenda.erase(agenda.end() - node.arity, agenda.end()); + agenda.push_back(std::move(representation)); + } + if (agenda.size() != 1) { + throw std::logic_error("PyTreeDef traversal did not yield a singleton."); + } + return absl::StrCat("PyTreeDef(", agenda.back(), ")"); +} + +nb::object PyTreeDef::ToPickle() const { + nb::list traversal; + for (const auto& node : traversal_) { + nb::object node_data = node.node_data; + if (node.kind == PyTreeKind::kDict) { + // Convert to a nb::list for pickling to avoid having to pickle a vector. + // Pickle should be a rare operation so this conversion cost is hopefully + // on non-critical path. + node_data = nb::cast(node.sorted_dict_keys); + } + traversal.append( + nb::make_tuple(static_cast(node.kind), node.arity, + node_data ? node_data : nb::none(), + node.custom != nullptr ? node.custom->type : nb::none(), + node.num_leaves, node.num_nodes)); + } + return nb::make_tuple(nb::cast(registry_ref_), traversal); +} + +void PyTreeDef::FromPickle(nb::object pickle) { + for (const auto& item : nb::cast(pickle)) { + auto t = nb::cast(item); + if (t.size() != 6) { + throw xla::XlaRuntimeError("Malformed pickled PyTreeDef"); + } + Node& node = traversal_.emplace_back(); + node.kind = static_cast(nb::cast(t[0])); + node.arity = nb::cast(t[1]); + switch (node.kind) { + case PyTreeKind::kNamedTuple: + node.node_data = t[2]; + break; + case PyTreeKind::kDict: + node.sorted_dict_keys = nb::cast>(t[2]); + break; + case PyTreeKind::kCustom: + case PyTreeKind::kDataclass: + node.node_data = t[2]; + break; + default: + if (!t[2].is_none()) { + throw xla::XlaRuntimeError("Malformed pickled PyTreeDef"); + } + break; + } + if (node.kind == PyTreeKind::kCustom || + node.kind == PyTreeKind::kDataclass) { + node.custom = t[3].is_none() ? 
nullptr : registry()->Lookup(t[3]); + if (node.custom == nullptr) { + throw xla::XlaRuntimeError( + absl::StrCat("Unknown custom type in pickled PyTreeDef: ", + nb::cast(nb::repr(t[3])))); + } + } else { + if (!t[3].is_none()) { + throw xla::XlaRuntimeError("Malformed pickled PyTreeDef"); + } + } + node.num_leaves = nb::cast(t[4]); + node.num_nodes = nb::cast(t[5]); + } +} + +void PyTreeDef::SetNumLeavesAndNumNodes() { + // num_leaves and num_nodes are fully determined by arity. + std::vector> starts; + int num_leaves = 0; + for (int i = 0; i < traversal_.size(); ++i) { + std::pair start = {num_leaves, i}; + if (traversal_[i].kind == PyTreeKind::kLeaf) { + num_leaves += 1; + } + if (traversal_[i].arity == 0) { + starts.push_back(start); + } else { + starts.resize(starts.size() - (traversal_[i].arity - 1)); + } + traversal_[i].num_leaves = num_leaves - starts.back().first; + traversal_[i].num_nodes = i + 1 - starts.back().second; + } +} + +void PyTreeDef::SerializeTo(jax::PyTreeDefProto& result) const { + absl::flat_hash_map interned_strings; + auto intern_str = [&](const std::string& key) { + auto [it, added] = + interned_strings.emplace(key, result.interned_strings_size()); + if (added) { + result.add_interned_strings(key); + } + return it->second; + }; + for (const auto& node : traversal_) { + auto* node_data = result.add_nodes(); + node_data->set_arity(node.arity); + switch (node.kind) { + case PyTreeKind::kLeaf: + node_data->set_type(jax::PyTreeNodeType::PY_TREE_KIND_LEAF); + break; + case PyTreeKind::kList: + node_data->set_type(jax::PyTreeNodeType::PY_TREE_KIND_LIST); + break; + case PyTreeKind::kNone: + node_data->set_type(jax::PyTreeNodeType::PY_TREE_KIND_NONE); + break; + case PyTreeKind::kTuple: + node_data->set_type(jax::PyTreeNodeType::PY_TREE_KIND_TUPLE); + break; + case PyTreeKind::kDict: + node_data->set_type(jax::PyTreeNodeType::PY_TREE_KIND_DICT); + for (auto& key : node.sorted_dict_keys) { + if (!nb::isinstance(key)) { + throw std::invalid_argument( + "Only string keys are supported in proto pytree " + "serialization."); + } + node_data->mutable_dict_keys()->add_str_id( + intern_str(nb::cast(key))); + } + break; + default: + throw std::invalid_argument( + "User-defined nodes are not supported when serializing pytrees as " + "protocol buffers. 
You should either convert the user-defined " + "nodes to another type or use pickle instead."); + break; + } + } +} + +nb_class_ptr PyTreeDef::DeserializeFrom( + nb_class_ptr registry, const jax::PyTreeDefProto& input) { + std::vector interned_strings; + interned_strings.reserve(input.interned_strings().size()); + for (auto& s : input.interned_strings()) { + interned_strings.push_back(nb::cast(s)); + } + nb_class_ptr result = + make_nb_class(std::move(registry)); + for (auto& node_proto : input.nodes()) { + result->traversal_.emplace_back(); + auto& node = result->traversal_.back(); + node.arity = node_proto.arity(); + node.custom = nullptr; + switch (node_proto.type()) { + case jax::PyTreeNodeType::PY_TREE_KIND_LEAF: + node.kind = PyTreeKind::kLeaf; + break; + case jax::PyTreeNodeType::PY_TREE_KIND_LIST: + node.kind = PyTreeKind::kList; + break; + case jax::PyTreeNodeType::PY_TREE_KIND_NONE: + node.kind = PyTreeKind::kNone; + break; + case jax::PyTreeNodeType::PY_TREE_KIND_TUPLE: + node.kind = PyTreeKind::kTuple; + break; + case jax::PyTreeNodeType::PY_TREE_KIND_DICT: + node.kind = PyTreeKind::kDict; + for (uint32_t str_id : node_proto.dict_keys().str_id()) { + if (str_id >= interned_strings.size()) { + throw std::invalid_argument( + "Malformed pytree proto (dict_key out of range)."); + } + node.sorted_dict_keys.push_back(interned_strings.at(str_id)); + } + break; + default: + throw std::invalid_argument( + "Malformed pytree proto (invalid node type)"); + break; + } + } + result->SetNumLeavesAndNumNodes(); + return result; +} + +std::optional> PyTreeDef::GetNodeData() + const { + if (traversal_.empty()) { + throw std::logic_error("empty PyTreeDef traversal."); + } + auto builtin_type = [](PyTypeObject* type_obj) { + return nb::borrow(reinterpret_cast(type_obj)); + }; + const auto& node = traversal_.back(); + switch (node.kind) { + case PyTreeKind::kLeaf: + return std::nullopt; + case PyTreeKind::kNone: + return std::make_pair(builtin_type(Py_TYPE(Py_None)), nb::none()); + case PyTreeKind::kTuple: + return std::make_pair(builtin_type(&PyTuple_Type), nb::none()); + case PyTreeKind::kList: + return std::make_pair(builtin_type(&PyList_Type), nb::none()); + case PyTreeKind::kDict: + return std::make_pair(builtin_type(&PyDict_Type), + nb::cast(node.sorted_dict_keys)); + case PyTreeKind::kNamedTuple: + return std::make_pair(node.node_data, nb::none()); + case PyTreeKind::kCustom: + case PyTreeKind::kDataclass: + return std::make_pair(node.custom->type, node.node_data); + } +} + +int PyTreeDef::Node::tp_traverse(visitproc visit, void* arg) const { + Py_VISIT(node_data.ptr()); + for (const auto& key : sorted_dict_keys) { + Py_VISIT(key.ptr()); + } + return 0; +} + +/* static */ int PyTreeDef::tp_traverse(PyObject* self, visitproc visit, + void* arg) { + Py_VISIT(Py_TYPE(self)); + if (!nb::inst_ready(self)) { + return 0; + } + PyTreeDef* treedef = nb::inst_ptr(self); + Py_VISIT(treedef->registry_ref_.ptr()); + for (const auto& node : treedef->traversal_) { + node.tp_traverse(visit, arg); + } + return 0; +} + +/* static */ int PyTreeDef::tp_clear(PyObject* self) { + PyTreeDef* treedef = nb::inst_ptr(self); + treedef->registry_ref_.reset(); + treedef->traversal_.clear(); + return 0; +} + +/* static */ PyType_Slot PyTreeDef::slots_[] = { + {Py_tp_traverse, (void*)PyTreeDef::tp_traverse}, + {Py_tp_clear, (void*)PyTreeDef::tp_clear}, + {0, nullptr}, +}; + +void BuildPytreeSubmodule(nb::module_& m) { + nb::module_ pytree = m.def_submodule("pytree", "Python tree library"); + pytree.attr("version") = 
nb::int_(3); + + nb::class_ treedef(pytree, "PyTreeDef", + nb::type_slots(PyTreeDef::slots_)); + + nb::class_ registry(m, "PyTreeRegistry", nb::dynamic_attr(), + nb::type_slots(PyTreeRegistry::slots_)); + + registry.def(nb::init(), + nb::arg("enable_none") = true, nb::arg("enable_tuple") = true, + nb::arg("enable_namedtuple") = true, + nb::arg("enable_list") = true, nb::arg("enable_dict") = true); + registry.def( + "flatten", + [](nb_class_ptr registry, nb::object x, + std::optional leaf_predicate) { + nb::list leaves; + nb_class_ptr def = + make_nb_class(std::move(registry)); + def->Flatten(x, leaves, leaf_predicate); + return nb::make_tuple(std::move(leaves), std::move(def)); + }, + nb::arg("tree").none(), nb::arg("leaf_predicate").none() = std::nullopt); + registry.def("flatten_one_level", &PyTreeRegistry::FlattenOneLevel, + nb::arg("tree").none()); + registry.def("flatten_one_level_with_keys", + &PyTreeRegistry::FlattenOneLevelWithKeys, + nb::arg("tree").none()); + registry.def( + "flatten_with_path", + [](nb_class_ptr registry, nb::object x, + std::optional leaf_predicate) { + nb::list leaves; + nb_class_ptr def = + make_nb_class(std::move(registry)); + def->FlattenWithPath(x, leaves, leaf_predicate); + return nb::make_tuple(std::move(leaves), std::move(def)); + }, + nb::arg("tree").none(), nb::arg("leaf_predicate").none() = std::nullopt); + registry.def("register_node", &PyTreeRegistry::Register, + nb::arg("type").none(), nb::arg("to_iterable").none(), + nb::arg("from_iterable").none(), + nb::arg("to_iterable_with_keys").none() = std::nullopt); + registry.def("register_dataclass_node", &PyTreeRegistry::RegisterDataclass); + registry.def("__reduce__", + [](nb::object self) { return self.attr("__name__"); }); + + pytree.attr("_default_registry") = make_nb_class( + /*enable_none=*/true, /*enable_tuple=*/true, /*enable_namedtuple=*/true, + /*enable_list=*/true, /*enable_dict*/ true); + pytree.def("default_registry", + [registry = nb::cast>( + pytree.attr("_default_registry"))]() { return registry; }); + + pytree.attr("PyTreeRegistry") = m.attr("PyTreeRegistry"); + pytree.def("tuple", &PyTreeDef::Tuple); + pytree.def("all_leaves", &PyTreeDef::AllLeaves); + + treedef.def("unflatten", + static_cast( + &PyTreeDef::Unflatten)); + treedef.def("flatten_up_to", &PyTreeDef::FlattenUpTo, nb::arg("tree").none()); + treedef.def("compose", &PyTreeDef::Compose); + treedef.def( + "walk", &PyTreeDef::Walk, + "Walk pytree, calling f_node(node, node_data) at nodes, and f_leaf " + "at leaves", + nb::arg("f_node"), nb::arg("f_leaf"), nb::arg("leaves")); + treedef.def("from_iterable_tree", &PyTreeDef::FromIterableTree); + treedef.def("children", &PyTreeDef::Children); + treedef.def_prop_ro("num_leaves", &PyTreeDef::num_leaves); + treedef.def_prop_ro("num_nodes", &PyTreeDef::num_nodes); + treedef.def("__repr__", &PyTreeDef::ToString); + treedef.def("__eq__", + [](const PyTreeDef& a, const PyTreeDef& b) { return a == b; }); + treedef.def("__ne__", + [](const PyTreeDef& a, const PyTreeDef& b) { return a != b; }); + treedef.def("__hash__", [](const PyTreeDef& t) { return absl::HashOf(t); }); + treedef.def("serialize_using_proto", [](const PyTreeDef& a) { + jax::PyTreeDefProto result; + a.SerializeTo(result); + std::string serialized = result.SerializeAsString(); + return nb::bytes(serialized.data(), serialized.size()); + }); + treedef.def_static( + "deserialize_using_proto", + [](nb_class_ptr registry, nb::bytes data) { + jax::PyTreeDefProto input; + absl::string_view serialized(data.c_str(), data.size()); + if 
(serialized.size() > std::numeric_limits::max()) { + throw xla::XlaRuntimeError( + "Pytree serialization too large to deserialize."); + } + if (!input.ParseFromArray(serialized.data(), serialized.size())) { + throw xla::XlaRuntimeError("Could not deserialize PyTreeDefProto."); + } + return PyTreeDef::DeserializeFrom(std::move(registry), input); + }, + nb::arg("registry"), nb::arg("data")); + treedef.def("node_data", &PyTreeDef::GetNodeData, + "Returns None if a leaf-pytree, else (type, node_data)"); + treedef.def("__getstate__", &PyTreeDef::ToPickle); + treedef.def("__setstate__", [](PyTreeDef& t, nb::object o) { + nb::tuple pickle = nb::cast(o); + if (pickle.size() != 2) { + throw xla::XlaRuntimeError( + "Malformed pickled PyTreeDef, expected 2-tuple"); + } + auto registry = nb::cast>(pickle[0]); + new (&t) PyTreeDef(registry); + t.FromPickle(pickle[1]); + }); + + nb::class_ sequence_key(pytree, "SequenceKey"); + sequence_key.def(nb::init(), nb::arg("idx")); + sequence_key.def("__str__", &SequenceKey::ToString); + sequence_key.def("__repr__", &SequenceKey::ToReprString); + sequence_key.def("__eq__", &SequenceKey::Equals); + sequence_key.def("__hash__", [](const SequenceKey& key) { + return key.idx() + kSequenceKeyHashSalt; + }); + sequence_key.def_prop_ro("idx", &SequenceKey::idx); + sequence_key.def_prop_ro_static("__match_args__", &SequenceKey::MatchArgs); + sequence_key.def("__getstate__", + [](SequenceKey& key) { return nb::make_tuple(key.idx()); }); + sequence_key.def("__setstate__", + [](SequenceKey& key, const nb::tuple& state) { + if (state.size() != 1) { + throw xla::XlaRuntimeError( + "Malformed pickled SequenceKey, expected 1-tuple"); + } + new (&key) SequenceKey(nb::cast(state[0])); + }); + + nb::class_ dict_key(pytree, "DictKey", + nb::type_slots(DictKey::slots_)); + dict_key.def(nb::init(), nb::arg("key")); + dict_key.def("__str__", &DictKey::ToString); + dict_key.def("__repr__", &DictKey::ToReprString); + dict_key.def("__eq__", &DictKey::Equals); + dict_key.def("__hash__", + [](const DictKey& key) { return nanobind::hash(key.key()); }); + dict_key.def_prop_ro("key", &DictKey::key); + dict_key.def_prop_ro_static("__match_args__", &DictKey::MatchArgs); + dict_key.def("__getstate__", + [](DictKey& key) { return nb::make_tuple(key.key()); }); + dict_key.def("__setstate__", [](DictKey& key, const nb::tuple& state) { + if (state.size() != 1) { + throw xla::XlaRuntimeError("Malformed pickled DictKey, expected 1-tuple"); + } + new (&key) DictKey(nb::cast(state[0])); + }); + + nb::class_ get_attr_key(pytree, "GetAttrKey"); + get_attr_key.def(nb::init(), nb::arg("name")); + get_attr_key.def("__str__", &GetAttrKey::ToString); + get_attr_key.def("__repr__", &GetAttrKey::ToReprString); + get_attr_key.def("__eq__", &GetAttrKey::Equals); + get_attr_key.def("__hash__", + [](const GetAttrKey& key) { return nb::hash(key.name()); }); + get_attr_key.def_prop_ro("name", &GetAttrKey::name); + get_attr_key.def_prop_ro_static("__match_args__", &GetAttrKey::MatchArgs); + get_attr_key.def("__getstate__", + [](GetAttrKey& key) { return nb::make_tuple(key.name()); }); + get_attr_key.def("__setstate__", [](GetAttrKey& key, const nb::tuple& state) { + if (state.size() != 1) { + throw xla::XlaRuntimeError( + "Malformed pickled GetAttrKey, expected 1-tuple"); + } + new (&key) GetAttrKey(nb::str(state[0])); + }); + + nb::class_ flattened_index_key(pytree, + "FlattenedIndexKey"); + flattened_index_key.def(nb::init(), nb::arg("key")); + flattened_index_key.def("__str__", &FlattenedIndexKey::ToString); + 
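A rough usage sketch of the bindings above (illustrative, not part of the diff): it mirrors the access path used by jaxlib/pytree_test.py later in this change; the exact module path is an internal detail.

from jax.jaxlib import xla_client

pytree = xla_client._xla.pytree
registry = pytree.default_registry()

# Flatten a nested container into leaves plus a PyTreeDef, then rebuild it.
leaves, treedef = registry.flatten({"a": 1, "b": (2, [3, 4])})
assert leaves == [1, 2, 3, 4]
assert treedef.unflatten(leaves) == {"a": 1, "b": (2, [3, 4])}

# Round-trip the treedef through the proto serialization defined above. Only
# the built-in node kinds (leaf/None/tuple/list/dict) survive this path;
# custom nodes must use pickle instead.
blob = treedef.serialize_using_proto()
assert pytree.PyTreeDef.deserialize_using_proto(registry, blob) == treedef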
flattened_index_key.def("__repr__", &FlattenedIndexKey::ToReprString); + flattened_index_key.def("__eq__", &FlattenedIndexKey::Equals); + flattened_index_key.def("__hash__", [](const FlattenedIndexKey& key) { + return key.key() + kFlattenedIndexKeyHashSalt; + }); + flattened_index_key.def_prop_ro("key", &FlattenedIndexKey::key); + flattened_index_key.def_prop_ro_static("__match_args__", + &FlattenedIndexKey::MatchArgs); + flattened_index_key.def("__getstate__", [](FlattenedIndexKey& key) { + return nb::make_tuple(key.key()); + }); + flattened_index_key.def( + "__setstate__", [](FlattenedIndexKey& key, const nb::tuple& state) { + if (state.size() != 1) { + throw xla::XlaRuntimeError( + "Malformed pickled FlattenedIndexKey, expected 1-tuple"); + } + new (&key) FlattenedIndexKey(nb::cast(state[0])); + }); +} + +} // namespace xla diff --git a/jaxlib/pytree.h b/jaxlib/pytree.h new file mode 100644 index 000000000000..f36d6999c887 --- /dev/null +++ b/jaxlib/pytree.h @@ -0,0 +1,404 @@ +/* Copyright 2019 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_PYTREE_H_ +#define JAXLIB_PYTREE_H_ + +// See https://docs.jax.dev/en/latest/pytrees.html for the documentation +// about pytree. + +#include + +#include +#include +#include +#include +#include +#include + +// placeholder for index annotation headers +#include "absl/container/flat_hash_map.h" +#include "absl/container/inlined_vector.h" +#include "absl/hash/hash.h" +#include "absl/types/span.h" +#include "nanobind/nanobind.h" +#include "jaxlib/nb_class_ptr.h" +#include "jaxlib/pytree.pb.h" + +namespace xla { + +enum class PyTreeKind { + kLeaf, // An opaque leaf node + kNone, // None. + kTuple, // A tuple + kNamedTuple, // A collections.namedtuple + kList, // A list + kDict, // A dict + kCustom, // A custom type. + kDataclass, // A dataclass. +}; + +// Registry of custom node types. +class PyTreeRegistry { + public: + PyTreeRegistry(bool enable_none, bool enable_tuple, bool enable_namedtuple, + bool enable_list, bool enable_dict); + + PyTreeRegistry(const PyTreeRegistry&) = delete; + PyTreeRegistry(PyTreeRegistry&&) = delete; + PyTreeRegistry& operator=(const PyTreeRegistry&) = delete; + PyTreeRegistry& operator=(PyTreeRegistry&&) = delete; + + struct Registration { + PyTreeKind kind; + + // The following values are populated for custom types. + // The Python type object, used to identify the type. 
+ nanobind::object type; + // A function with signature: object -> (iterable, aux_data) + nanobind::callable to_iterable; + // A function with signature: (aux_data, iterable) -> object + nanobind::callable from_iterable; + // A function with signature: (aux_data, iterable(keypath, leaf)) -> object + std::optional to_iterable_with_keys; + + // Helper that calls to_iterable and validates that it returns a pair + // of an iterable and an aux_data object + std::pair ToIterable( + nanobind::handle o) const; + // Helper that calls to_iterable_with_keys and validates that it returns a + // pair of an iterable of key-leaf pairs and an aux_data object. If + // to_iterable_with_keys is not available, return a dummy key for each leaf, + // similar to the current jax.tree_util.FlattenedIndexKey. + std::pair>, + nanobind::object> + ToIterableWithKeys(nanobind::handle o) const; + + // For dataclasses. + std::vector data_fields; + std::vector meta_fields; + + int tp_traverse(visitproc visit, void* arg); + }; + + // Registers a new custom type. Objects of `type` will be treated as container + // node types in PyTrees. + void Register( + nanobind::object type, nanobind::callable to_iterable, + nanobind::callable from_iterable, + std::optional to_iterable_with_keys = std::nullopt); + // Same, but for dataclasses. + void RegisterDataclass(nanobind::object type, + std::vector data_fields, + std::vector meta_fields); + + // Finds the custom type registration for `type`. Returns nullptr if none + // exists. + const Registration* Lookup(nanobind::handle type) const; + + PyTreeKind KindOfObject(nanobind::handle obj, + PyTreeRegistry::Registration const** custom) const; + + // Flattens a pytree one level, returning either a tuple of the leaves and + // the node data, or None, if the entry is a leaf. + nanobind::object FlattenOneLevel(nanobind::handle x) const; + // Similar to above but returns a key-leaf pair for each leaf. + nanobind::object FlattenOneLevelWithKeys(nanobind::handle x) const; + // Underlying implementation of FlattenOneLevel and FlattenOneLevelWithKeys. 
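For illustration, here is how a custom container type is registered from Python through the register_node binding; the Point class and its callbacks are hypothetical, and the callback signatures follow the comments in this struct.

from jax.jaxlib import xla_client

pytree = xla_client._xla.pytree
registry = pytree.PyTreeRegistry()

class Point:
  def __init__(self, x, y):
    self.x, self.y = x, y

def point_to_iterable(p):            # object -> (iterable, aux_data)
  return (p.x, p.y), None

def point_from_iterable(aux, kids):  # (aux_data, iterable) -> object
  del aux
  return Point(*kids)

registry.register_node(Point, point_to_iterable, point_from_iterable)

leaves, treedef = registry.flatten(Point(1, 2))
assert leaves == [1, 2]

# flatten_one_level returns (children, aux_data) for a container node and
# None for a leaf.
children, aux = registry.flatten_one_level(Point(1, 2))
assert list(children) == [1, 2] and aux is None
assert registry.flatten_one_level(3) is None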
+ nanobind::object FlattenOneLevelImpl(nanobind::handle x, + bool with_keys) const; + + static PyType_Slot slots_[]; + + private: + struct TypeHash { + using is_transparent = void; + size_t operator()(const nanobind::object& t) const { + return absl::HashOf(t.ptr()); + } + size_t operator()(const nanobind::handle& t) const { + return absl::HashOf(t.ptr()); + } + }; + struct TypeEq { + using is_transparent = void; + bool operator()(const nanobind::object& a, + const nanobind::object& b) const { + return a.ptr() == b.ptr(); + } + bool operator()(const nanobind::object& a, + const nanobind::handle& b) const { + return a.ptr() == b.ptr(); + } + }; + mutable nanobind::ft_mutex mu_; + absl::flat_hash_map, TypeHash, + TypeEq> + registrations_; // Guarded by mu_ + bool enable_namedtuple_; + + static int tp_traverse(PyObject* self, visitproc visit, void* arg); + static int tp_clear(PyObject* self); +}; + +class SequenceKey { + public: + explicit SequenceKey(int idx) : idx_(idx) {}; + std::string ToReprString() const; + std::string ToString() const; + bool Equals(const nanobind::object& other); + int idx() const { return idx_; } + static nanobind::tuple MatchArgs(nanobind::handle unused); + + private: + int idx_; +}; + +class DictKey { + public: + explicit DictKey(nanobind::object key) : key_(key) {}; + std::string ToReprString() const; + std::string ToString() const; + bool Equals(const nanobind::object& other); + nanobind::object key() const { return key_; } + static nanobind::tuple MatchArgs(nanobind::handle unused); + static PyType_Slot slots_[]; + + private: + nanobind::object key_; + static int tp_traverse(PyObject* self, visitproc visit, void* arg); + static int tp_clear(PyObject* self); +}; + +class GetAttrKey { + public: + explicit GetAttrKey(nanobind::str name) : name_(name) {}; + std::string ToReprString() const; + std::string ToString() const; + bool Equals(const nanobind::object& other); + nanobind::str name() const { return name_; } + static nanobind::tuple MatchArgs(nanobind::handle unused); + + private: + nanobind::str name_; +}; + +class FlattenedIndexKey { + public: + explicit FlattenedIndexKey(int key) : key_(key) {}; + std::string ToReprString() const; + std::string ToString() const; + bool Equals(const nanobind::object& other); + int key() const { return key_; } + static nanobind::tuple MatchArgs(nanobind::handle unused); + + private: + int key_; +}; + +// A PyTreeDef describes the tree structure of a PyTree. A PyTree is a tree of +// Python values, where the interior nodes are tuples, lists, dictionaries, or +// user-defined containers, and the leaves are other objects. +class PyTreeDef { + public: + // Unowned registry: the registry must remain live at least as long as the + // PyTreeDef. It is the caller's responsibility to enforce this. + explicit PyTreeDef(PyTreeRegistry* registry) : registry_(registry) {} + + explicit PyTreeDef(nb_class_ptr registry) + : registry_(registry.get()), registry_ref_(std::move(registry)) {} + + // Flattens a Pytree into a list of leaves and a PyTreeDef. + // Returns references to the flattened objects, which might be temporary + // objects in the case of custom pytype handlers. + static std::pair, nb_class_ptr> + Flatten(nanobind::handle x, nb_class_ptr registry, + std::optional leaf_predicate = std::nullopt); + + // Flattens a Pytree into a list of `leaves` and a PyTreeDef (this). + // `leaves` owns references to the flattened objects, which might be + // temporary objects in the case of custom pytype handlers. 
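A minimal sketch of how the key classes declared above surface in Python via flatten_with_path (illustrative; the reprs shown in comments are approximate).

from jax.jaxlib import xla_client

pytree = xla_client._xla.pytree
registry = pytree.default_registry()

# Each leaf comes back with its key path: DictKey for dict entries,
# SequenceKey for tuple/list positions, GetAttrKey for dataclass fields, and
# FlattenedIndexKey for custom nodes without a keyed flattener.
key_leaves, treedef = registry.flatten_with_path({"a": (1, 2)})
for path, leaf in key_leaves:
  print(path, leaf)
# Roughly:
#   (DictKey(key='a'), SequenceKey(idx=0)) 1
#   (DictKey(key='a'), SequenceKey(idx=1)) 2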
+ void Flatten(nanobind::handle handle, std::vector& leaves, + std::optional leaf_predicate = std::nullopt); + void Flatten(nanobind::handle handle, + absl::InlinedVector& leaves, + std::optional leaf_predicate = std::nullopt); + void Flatten(nanobind::handle handle, nanobind::list& leaves, + std::optional leaf_predicate = std::nullopt); + + void FlattenWithPath( + nanobind::handle handle, nanobind::list& leaves, + std::optional leaf_predicate = std::nullopt); + + // Tests whether the given list is a flat list of leaves. + static bool AllLeaves(PyTreeRegistry* registry, const nanobind::iterable& x); + + // Flattens a Pytree up to this PyTreeDef. 'this' must be a tree prefix of + // the tree-structure of 'x'. For example, if we flatten a value + // [(1, (2, 3)), {"foo": 4}] with a treedef [(*, *), *], the result is the + // list of leaves [1, (2, 3), {"foo": 4}]. + nanobind::list FlattenUpTo(nanobind::handle x) const; + + // Returns an unflattened PyTree given an iterable of leaves and a PyTreeDef. + nanobind::object Unflatten(nanobind::iterable leaves) const; + nanobind::object Unflatten(absl::Span leaves) const; + + // Composes two PyTreeDefs, replacing the leaves of this tree with copies of + // `inner`. The returned PyTreeDef holds a reference to its registry. + nb_class_ptr Compose(const PyTreeDef& inner) const; + + // Makes a Tuple PyTreeDef out of a vector of PyTreeDefs. + static nb_class_ptr Tuple(nb_class_ptr registry, + nanobind::list defs); + + // The returned PyTreeDefs hold a reference to the registry. + std::vector> Children() const; + + // Maps a function over a PyTree structure, applying f_leaf to each leaf, and + // f_node(node, node_data) to each container node. + nanobind::object Walk(const nanobind::callable& f_node, + nanobind::handle f_leaf, + nanobind::iterable leaves) const; + + // Given a tree of iterables with the same node/leaf structure as this PyTree, + // build the corresponding PyTree. + // TODO(phawkins): use flattening everywhere instead and delete this method. + nanobind::object FromIterableTree(nanobind::handle xs) const; + + int num_leaves() const { + if (traversal_.empty()) { + return 0; + } + return traversal_.back().num_leaves; + } + + int num_nodes() const { return traversal_.size(); } + + PyTreeRegistry* registry() const { return registry_; } + + size_t Hash() const; + + bool operator==(const PyTreeDef& other) const; + bool operator!=(const PyTreeDef& other) const { return !(*this == other); } + + std::string ToString() const; + + // Transforms the PyTreeDef into a pickleable object. Used to implement + // `PyTreeDef.__getstate__`. + nanobind::object ToPickle() const; + + // Transforms the object returned by `ToPickleable()` back to PyTreeDef. Used + // to implement `PyTreeDef.__setstate__`. + void FromPickle(nanobind::object pickleable); + + void SerializeTo(jax::PyTreeDefProto& result) const; + + static nb_class_ptr DeserializeFrom( + nb_class_ptr registry, const jax::PyTreeDefProto& input); + + std::optional> GetNodeData() + const; + + static PyType_Slot slots_[]; + + private: + void SetNumLeavesAndNumNodes(); + + struct Node { + PyTreeKind kind = PyTreeKind::kLeaf; + + // Arity for non-kLeaf types. + int arity = 0; + + // Kind-specific auxiliary data. For a kNamedTuple, contains the tuple type + // object. For a kDict, use `sorted_dict_keys` field below. For a kCustom + // type, contains the auxiliary data returned by the `to_iterable` function. + nanobind::object node_data; + + // Kind-specific auxiliary data specialized for kDict. 
Use a c++ vector + // to hold the sorted dict keys instead of a py::list to avoid creating + // a new python list object when flattening kDict. For deeply nested dict, + // using c++ vector instead of py::list avoids creating too many python + // objects that make python gc sweep slow. + std::vector sorted_dict_keys; + + // Custom type registration. Must be null for non-custom types. + const PyTreeRegistry::Registration* custom = nullptr; + + // Number of leaf nodes in the subtree rooted at this node. + int num_leaves = 0; + + // Number of leaf and interior nodes in the subtree rooted at this node. + int num_nodes = 0; + + int tp_traverse(visitproc visit, void* arg) const; + }; + template + friend H AbslHashValue(H h, const Node& n); + + template + friend H AbslHashValue(H h, const PyTreeDef& t); + + // Helper that manufactures an instance of a node given its children. + static nanobind::object MakeNode(const Node& node, + absl::Span children); + + // Recursive helper used to implement FromIterableTree() + nanobind::object FromIterableTreeHelper( + nanobind::handle xs, + absl::InlinedVector::const_reverse_iterator* it) + const; + + template + void FlattenImpl( + nanobind::handle handle, T& leaves, + std::optional>& keypath, + const std::optional& leaf_predicate); + + template + nanobind::object UnflattenImpl(T leaves) const; + + static int tp_traverse(PyObject* self, visitproc visit, void* arg); + static int tp_clear(PyObject* self); + + // Pytree registry. Not owned. + PyTreeRegistry* registry_; + // If this class holds a reference to `registry`, it is held by + // `registry_ref_`. + nb_class_ptr registry_ref_; + + // Nodes, in a post-order traversal. We use an ordered traversal to minimize + // allocations, and post-order corresponds to the order we need to rebuild the + // tree structure. + absl::InlinedVector traversal_; +}; + +template +H AbslHashValue(H h, const PyTreeDef::Node& n) { + h = H::combine(std::move(h), n.kind, n.arity, n.custom); + return h; +} + +template +H AbslHashValue(H h, const PyTreeDef& t) { + h = H::combine(std::move(h), t.traversal_); + return h; +} + +void BuildPytreeSubmodule(nanobind::module_& m); + +} // namespace xla + +#endif // JAXLIB_PYTREE_H_ diff --git a/jaxlib/pytree.proto b/jaxlib/pytree.proto new file mode 100644 index 000000000000..73c087ef55ab --- /dev/null +++ b/jaxlib/pytree.proto @@ -0,0 +1,32 @@ +syntax = "proto3"; + +package jax; + +enum PyTreeNodeType { + PY_TREE_KIND_INVALID = 0; + PY_TREE_KIND_LEAF = 1; + PY_TREE_KIND_LIST = 2; + PY_TREE_KIND_NONE = 3; + PY_TREE_KIND_TUPLE = 4; + PY_TREE_KIND_DICT = 5; +} + +message DictKeysProto { + repeated uint32 str_id = 1; +} + +message PyTreeNodeDefProto { + // Recovers the tree structure. + uint32 arity = 1; + // Node type. + PyTreeNodeType type = 2; + // Only set when type == DICT. + DictKeysProto dict_keys = 3; +} + +// A Pytree. +message PyTreeDefProto { + repeated PyTreeNodeDefProto nodes = 1; + // Extra strings. + repeated string interned_strings = 2; +} diff --git a/jaxlib/pytree_test.py b/jaxlib/pytree_test.py new file mode 100644 index 000000000000..0e5ccf69bdbe --- /dev/null +++ b/jaxlib/pytree_test.py @@ -0,0 +1,116 @@ +# Copyright 2023 The JAX Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
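A small sketch (illustrative) of the post-order node layout and the num_leaves/num_nodes bookkeeping described in pytree.h, observed from Python:

from jax.jaxlib import xla_client

pytree = xla_client._xla.pytree
registry = pytree.default_registry()

# Flattening {"a": 1, "b": (2, 3)} stores nodes in post-order:
#   leaf, leaf, leaf, tuple(arity=2), dict(arity=2, keys=["a", "b"])
# so num_nodes counts leaves plus interior nodes, while num_leaves counts
# only the leaves.
_, treedef = registry.flatten({"a": 1, "b": (2, 3)})
assert treedef.num_leaves == 3
assert treedef.num_nodes == 5

# children() returns one PyTreeDef per top-level child, in sorted key order.
assert [c.num_leaves for c in treedef.children()] == [1, 2]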
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +import collections +import dataclasses +import gc + +from absl.testing import absltest + +from jax.jaxlib import xla_client + +pytree = xla_client._xla.pytree + + +ExampleType = collections.namedtuple("ExampleType", "field0 field1") + +registry = pytree.PyTreeRegistry() + + +class ExampleType2: + + def __init__(self, field0, field1): + self.field0 = field0 + self.field1 = field1 + + def to_iterable(self): + return [self.field0, self.field1], (None,) + +def from_iterable(state, values): + del state + return ExampleType2(field0=values[0], field1=values[1]) + + +registry.register_node(ExampleType2, ExampleType2.to_iterable, from_iterable) + + +@dataclasses.dataclass +class Custom: + a: int + b: str + + +registry.register_dataclass_node(Custom, ["a"], ["b"]) + + +class PyTreeTest(absltest.TestCase): + + def roundtrip_proto(self, example): + original = registry.flatten(example)[1] + self.assertEqual( + pytree.PyTreeDef.deserialize_using_proto( + registry, original.serialize_using_proto() + ), + original, + ) + + def testSerializeDeserializeNoPickle(self): + o = object() + self.roundtrip_proto(({"a": o, "b": o}, [o, (o, o), None])) + + def testSerializeWithFallback(self): + o = object() + with self.assertRaises(ValueError): + self.roundtrip_proto({"a": ExampleType(field0=o, field1=o)}) + + def testRegisteredType(self): + o = object() + with self.assertRaises(ValueError): + self.roundtrip_proto({"a": ExampleType2(field0=o, field1=o)}) + + def testCompose(self): + x = registry.flatten(0)[1] + y = registry.flatten((0, 0))[1] + self.assertEqual((x.compose(y)).num_leaves, 2) + + def testTpTraverse(self): + self.assertContainsSubset( + [ + pytree.PyTreeRegistry, + ExampleType2, + ExampleType2.to_iterable, + from_iterable, + ], + gc.get_referents(registry), + ) + k1 = "k1" + k2 = "k2" + + t = ExampleType("a", "b") + _, treedef = registry.flatten([1, {k1: 2, k2: t}, 5, t]) + + self.assertContainsSubset( + [ + pytree.PyTreeDef, + registry, + k1, + k2, + ExampleType, + ], + gc.get_referents(treedef), + ) + + +if __name__ == "__main__": + absltest.main() diff --git a/jaxlib/pywrap.bzl b/jaxlib/pywrap.bzl new file mode 100644 index 000000000000..e63bb0de9fd4 --- /dev/null +++ b/jaxlib/pywrap.bzl @@ -0,0 +1,83 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Wrappers around pywrap rules for JAX.""" + +load("@bazel_skylib//rules:expand_template.bzl", "expand_template") +load( + "@xla//third_party/py/rules_pywrap:pywrap.impl.bzl", + "pybind_extension", + _pywrap_binaries = "pywrap_binaries", + _pywrap_library = "pywrap_library", +) + +pywrap_library = _pywrap_library +pywrap_binaries = _pywrap_binaries + +def nanobind_pywrap_extension( + name, + srcs = [], + deps = [], + pytype_srcs = [], + pytype_deps = [], + copts = [], + linkopts = [], + visibility = None): + # buildifier: disable=function-docstring-args + "Python extension rule using nanobind and the pywrap rules." + module_name = name + lib_name = name + "_pywrap_library" + src_cc_name = name + "_pywrap_stub.c" + + # We put the entire contents of the extension in a single cc_library, which will become part of + # the common pywrap library. All the contents of all extensions will end up in the common + # library. + native.cc_library( + name = lib_name, + srcs = srcs, + copts = copts, + deps = deps, + local_defines = [ + "PyInit_{}=Wrapped_PyInit_{}".format(module_name, module_name), + ], + visibility = ["//visibility:private"], + ) + + # We build a small stub library as the extension that forwards to the PyInit_... symbol from the + # common pywrap library. + expand_template( + name = name + "_pywrap_stub", + testonly = True, + out = src_cc_name, + substitutions = { + "@MODULE_NAME@": module_name, + }, + template = "//jaxlib:pyinit_stub.c", + visibility = ["//visibility:private"], + ) + + # Despite its name "pybind_extension" has nothing to do with pybind. It is the Python extension + # rule from the pywrap rules. + pybind_extension( + name = name, + srcs = [src_cc_name], + deps = [":" + lib_name], + data = pytype_srcs, + linkopts = linkopts, + visibility = visibility, + default_deps = [], + common_lib_packages = [ + "jaxlib", + ], + ) diff --git a/jaxlib/rocm/BUILD b/jaxlib/rocm/BUILD index 9a25a795fd14..f265e6714c8e 100644 --- a/jaxlib/rocm/BUILD +++ b/jaxlib/rocm/BUILD @@ -79,7 +79,7 @@ cc_library( deps = [ ":hip_gpu_kernel_helpers", ":hip_vendor", - "//jaxlib:handle_pool", + "//jaxlib/gpu:handle_pool", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/synchronization", "@local_config_rocm//rocm:hipblas", @@ -87,54 +87,6 @@ cc_library( ], ) -cc_library( - name = "hipblas_kernels", - srcs = ["//jaxlib/gpu:blas_kernels.cc"], - hdrs = ["//jaxlib/gpu:blas_kernels.h"], - deps = [ - ":hip_blas_handle_pool", - ":hip_gpu_kernel_helpers", - ":hip_make_batch_pointers", - ":hip_vendor", - "//jaxlib:kernel_helpers", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/base", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/hash", - "@com_google_absl//absl/memory", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/synchronization", - "@local_config_rocm//rocm:hipblas", - "@local_config_rocm//rocm:rocm_headers", - "@xla//xla/service:custom_call_status", - ], -) - -nanobind_extension( - name = "_blas", - srcs = ["//jaxlib/gpu:blas.cc"], - copts = [ - "-fexceptions", - "-fno-strict-aliasing", - ], - features = ["-use_header_modules"], - module_name = "_blas", - deps = [ - ":hip_vendor", - ":hipblas_kernels", - "//jaxlib:kernel_nanobind_helpers", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/strings:str_format", - "@local_config_rocm//rocm:hipblas", - 
"@local_config_rocm//rocm:rocm_headers", - "@nanobind", - "@xla//xla/tsl/python/lib/core:numpy", - ], -) - cc_library( name = "miopen_rnn_kernels", srcs = ["//jaxlib/gpu:rnn_kernels.cc"], @@ -143,15 +95,15 @@ cc_library( ":ffi_wrapper", ":hip_gpu_kernel_helpers", ":hip_vendor", - "//jaxlib:handle_pool", "//jaxlib:kernel_helpers", + "//jaxlib/gpu:handle_pool", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/synchronization", "@local_config_rocm//rocm:miopen", "@local_config_rocm//rocm:rocm_headers", "@xla//xla/ffi/api:ffi", - "@xla//xla/service:custom_call_status", ], ) @@ -182,7 +134,7 @@ cc_library( deps = [ ":hip_gpu_kernel_helpers", ":hip_vendor", - "//jaxlib:handle_pool", + "//jaxlib/gpu:handle_pool", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/synchronization", "@local_config_rocm//rocm:hipsolver", @@ -190,24 +142,6 @@ cc_library( ], ) -cc_library( - name = "hipsolver_kernels", - srcs = ["//jaxlib/gpu:solver_kernels.cc"], - hdrs = ["//jaxlib/gpu:solver_kernels.h"], - deps = [ - ":hip_gpu_kernel_helpers", - ":hip_solver_handle_pool", - ":hip_vendor", - "//jaxlib:kernel_helpers", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/synchronization", - "@local_config_rocm//rocm:hipsolver", - "@local_config_rocm//rocm:rocm_headers", - "@xla//xla/service:custom_call_status", - ], -) - cc_library( name = "hipsolver_interface", srcs = ["//jaxlib/gpu:solver_interface.cc"], @@ -242,7 +176,6 @@ cc_library( "@local_config_rocm//rocm:hipsolver", "@local_config_rocm//rocm:rocm_headers", "@xla//xla/ffi/api:ffi", - "@xla//xla/service:custom_call_status", ], ) @@ -256,20 +189,13 @@ nanobind_extension( features = ["-use_header_modules"], module_name = "_solver", deps = [ - ":hip_gpu_kernel_helpers", - ":hip_solver_handle_pool", ":hip_vendor", - ":hipsolver_kernels", ":hipsolver_kernels_ffi", "//jaxlib:kernel_nanobind_helpers", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:str_format", "@local_config_rocm//rocm:hipblas", "@local_config_rocm//rocm:hipsolver", "@local_config_rocm//rocm:rocm_headers", "@nanobind", - "@xla//xla/tsl/python/lib/core:numpy", ], ) @@ -291,16 +217,17 @@ cc_library( ":ffi_wrapper", ":hip_gpu_kernel_helpers", ":hip_vendor", - "//jaxlib:handle_pool", + "//jaxlib:ffi_helpers", "//jaxlib:kernel_helpers", + "//jaxlib/gpu:handle_pool", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", "@local_config_rocm//rocm:hipsparse", "@local_config_rocm//rocm:rocm_headers", "@xla//xla/ffi/api:ffi", - "@xla//xla/service:custom_call_status", ], ) @@ -398,7 +325,6 @@ cc_library( "@local_config_rocm//rocm:rocm_headers", "@xla//xla/ffi/api:c_api", "@xla//xla/ffi/api:ffi", - "@xla//xla/service:custom_call_status", ], ) @@ -412,7 +338,6 @@ rocm_library( "//jaxlib:kernel_helpers", "@local_config_rocm//rocm:rocm_headers", "@xla//xla/ffi/api:ffi", - "@xla//xla/service:custom_call_status", ], ) @@ -496,9 +421,11 @@ cc_library( "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/synchronization", - "@tsl//tsl/platform:env", "@xla//xla/service:custom_call_status", + 
"@xla//xla/tsl/platform:env", + "@xla//xla/tsl/platform:errors", "@xla//xla/tsl/util:env_var", ], ) @@ -536,7 +463,9 @@ nanobind_extension( "//jaxlib:absl_status_casters", "//jaxlib:kernel_nanobind_helpers", "//jaxlib/gpu:triton_cc_proto", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", "@nanobind", ], ) @@ -544,7 +473,6 @@ nanobind_extension( py_library( name = "rocm_gpu_support", deps = [ - ":_blas", ":_hybrid", ":_linalg", ":_prng", @@ -555,11 +483,52 @@ py_library( ], ) +cc_library( + name = "py_client_gpu", + srcs = ["//jaxlib/gpu:py_client_gpu.cc"], + hdrs = ["//jaxlib/gpu:py_client_gpu.h"], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + deps = [ + ":hip_vendor", + "//jaxlib:ffi", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + "@dlpack", + "@nanobind", + "@xla//third_party/python_runtime:headers", # buildcleaner: keep + "@xla//xla:comparison_util", + "@xla//xla:shape_util", + "@xla//xla:util", + "@xla//xla:xla_data_proto_cc", + "@xla//xla/ffi:ffi_api", + "@xla//xla/ffi/api:ffi", + "@xla//xla/pjrt:host_callback", + "@xla//xla/pjrt:transpose", + "@xla//xla/python:nb_numpy", + "@xla//xla/python:types", + "@xla//xla/service:platform_util", + ], +) + nanobind_extension( name = "rocm_plugin_extension", srcs = ["rocm_plugin_extension.cc"], module_name = "rocm_plugin_extension", deps = [ + ":py_client_gpu", + "//jaxlib:kernel_nanobind_helpers", "//jaxlib/gpu:gpu_plugin_extension", "@com_google_absl//absl/log", "@com_google_absl//absl/strings", diff --git a/jaxlib/rocm/rocm_plugin_extension.cc b/jaxlib/rocm/rocm_plugin_extension.cc index 1dd1f1943fc8..37ae638a47fc 100644 --- a/jaxlib/rocm/rocm_plugin_extension.cc +++ b/jaxlib/rocm/rocm_plugin_extension.cc @@ -16,16 +16,19 @@ limitations under the License. 
#include #include -#include "nanobind/nanobind.h" #include "absl/log/log.h" #include "absl/strings/str_cat.h" #include "rocm/include/hip/hip_runtime.h" +#include "nanobind/nanobind.h" #include "jaxlib/gpu/gpu_plugin_extension.h" +#include "jaxlib/gpu/py_client_gpu.h" +#include "jaxlib/kernel_nanobind_helpers.h" namespace nb = nanobind; namespace xla { namespace { + std::string ToString(hipError_t result) { #define OSTREAM_ROCM_ERROR(__name) \ case hipError##__name: \ @@ -62,10 +65,30 @@ std::string ToString(hipError_t result) { return absl::StrCat("hipError_t(", static_cast(result), ")"); } } + +nb::dict FfiRegistrations() { + nb::dict dict; + nb::dict gpu_callback_dict; + gpu_callback_dict["instantiate"] = + jax::EncapsulateFfiHandler(jax::hip::kGpuTransposePlanCacheInstantiate); + gpu_callback_dict["execute"] = + jax::EncapsulateFfiHandler(jax::hip::kXlaFfiPythonGpuCallback); + dict["xla_ffi_python_gpu_callback"] = gpu_callback_dict; + dict["xla_ffi_partitioned_python_gpu_callback"] = gpu_callback_dict; + dict["xla_buffer_python_gpu_callback"] = + jax::EncapsulateFfiHandler(jax::hip::kXlaBufferPythonGpuCallback); + dict["xla_buffer_python_gpu_callback_cmd_buffer"] = + jax::EncapsulateFfiHandler( + jax::hip::kXlaBufferPythonGpuCallbackCmdBuffer); + return dict; +} + } // namespace NB_MODULE(rocm_plugin_extension, m) { BuildGpuPluginExtension(m); + m.def("ffi_registrations", &FfiRegistrations); + m.def( "get_device_ordinal", [](std::intptr_t data_value) { diff --git a/jaxlib/sdy.cc b/jaxlib/sdy.cc new file mode 100644 index 000000000000..c31d11bac0d0 --- /dev/null +++ b/jaxlib/sdy.cc @@ -0,0 +1,140 @@ +/* Copyright 2024 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
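For illustration, the registration table returned by ffi_registrations() above can be inspected from Python; the import path assumes a ROCm jaxlib plugin build is installed and is not defined by this change.

from jaxlib.rocm import rocm_plugin_extension  # assumed install location

for name, handler in rocm_plugin_extension.ffi_registrations().items():
  # A handler is either a single capsule or a dict of per-stage capsules,
  # e.g. {"instantiate": ..., "execute": ...} as built above.
  stages = sorted(handler) if isinstance(handler, dict) else ["execute"]
  print(name, stages)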
+==============================================================================*/ + +#include "jaxlib/sdy.h" + +#include +#include + +#include "mhlo/transforms/passes.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "llvm/Support/raw_ostream.h" +#include "mlir/Bytecode/BytecodeWriter.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OwningOpRef.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Support/LLVM.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "nanobind/stl/string_view.h" // IWYU pragma: keep +#include "nanobind/stl/tuple.h" // IWYU pragma: keep +#include "nanobind/stl/vector.h" // IWYU pragma: keep +#include "shardy/dialect/sdy/ir/dialect.h" +#include "shardy/dialect/sdy/ir/utils.h" +#include "xla/hlo/translate/hlo_to_mhlo/hlo_to_mlir_hlo.h" +#include "xla/mlir_hlo/mhlo/transforms/passes.h" +#include "xla/pjrt/mlir_to_hlo.h" +#include "xla/pjrt/status_casters.h" +#include "xla/service/spmd/shardy/constants.h" +#include "xla/service/spmd/shardy/sdy_round_trip/import_shardy_attrs.h" +#include "xla/service/spmd/shardy/sdy_round_trip/pipelines.h" +#include "xla/service/spmd/shardy/utils.h" +#include "xla/tsl/framework/mlir/status_scoped_diagnostic_handler.h" + +namespace nb = nanobind; + +namespace xla { + +namespace { + +absl::StatusOr SerializeUsingBytecode(mlir::ModuleOp module) { + std::string bytecode; + llvm::raw_string_ostream os(bytecode); + mlir::BytecodeWriterConfig config; + if (mlir::failed(mlir::writeBytecodeToFile(module, os, config))) { + return absl::InvalidArgumentError("mlir::writeBytecodeToFile failed"); + } + return bytecode; +} + +} // namespace + +void BuildSdySubmodule(nb::module_& m) { + nb::module_ mlir_module = m.def_submodule("sdy", "Shardy/XLA integration"); + + mlir_module + // TODO(b/707574930): define a C API for the XLA pipelines. 
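A usage sketch of the submodule methods registered just below (illustrative): it assumes the submodule is reachable as xla_client._xla.sdy, matching how the pytree submodule is accessed elsewhere in this change, and that lowered_bytecode holds MLIR bytecode for a lowered module; neither is defined here.

from jax.jaxlib import xla_client

sdy = xla_client._xla.sdy

def describe(lowered_bytecode: bytes) -> None:
  if sdy.lowered_with_shardy(lowered_bytecode):
    # get_mesh returns [(axis_name, axis_size), ...] for the "mesh" symbol.
    print("mesh:", sdy.get_mesh(lowered_bytecode))
    # Re-import the Shardy attributes; the result is again MLIR bytecode.
    imported = sdy.sdy_round_trip_import_shardings(lowered_bytecode)
    print("imported module:", len(imported), "bytes")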
+ .def( + "sdy_round_trip_export_pipeline", + [](const nb::bytes& bytecode) -> nb::bytes { + mlir::MLIRContext context; + mlir::OwningOpRef module = + xla::ValueOrThrow(ParseMlirModuleString( + absl::string_view(bytecode.c_str(), bytecode.size()), + context)); + mlir::PassManager pm(&context); + sdy::addSdyRoundTripExportPipeline(pm); + tsl::StatusScopedDiagnosticHandler diagnosticHandler(&context); + ThrowIfError(diagnosticHandler.consumeStatus(pm.run(module.get()))); + std::string module_str = + xla::ValueOrThrow(SerializeUsingBytecode(module.get())); + return nb::bytes(module_str.data(), module_str.size()); + }, + nb::arg("module")) + .def( + "sdy_round_trip_import_shardings", + [](const nb::bytes& bytecode) -> nb::bytes { + mlir::MLIRContext context; + mlir::OwningOpRef module = + xla::ValueOrThrow(ParseMlirModuleString( + absl::string_view(bytecode.c_str(), bytecode.size()), + context)); + mlir::PassManager pm(&context); + pm.addPass(xla::sdy::createSdyRoundTripImportShardyAttrsPass()); + tsl::StatusScopedDiagnosticHandler diagnosticHandler(&context); + ThrowIfError(diagnosticHandler.consumeStatus(pm.run(module.get()))); + std::string module_str = + xla::ValueOrThrow(SerializeUsingBytecode(module.get())); + return nb::bytes(module_str.data(), module_str.size()); + }, + nb::arg("module")) + .def("lowered_with_shardy", + [](const nb::bytes& bytecode) -> bool { + mlir::MLIRContext context; + mlir::OwningOpRef module = + xla::ValueOrThrow(ParseMlirModuleString( + absl::string_view(bytecode.c_str(), bytecode.size()), + context)); + return mlir::sdy::getMeshAttr(module.get(), "mesh") || + sdy::tryGetFrontendAttr( + module.get(), sdy::kMeshesRoundTripAttr) + .has_value(); + }) + // TODO(bartchr): delete this and all uses of it once I have JAX export + // support multiple meshes. + .def("get_mesh", [](const nb::bytes& bytecode) -> nb::list { + mlir::MLIRContext context; + mlir::OwningOpRef module = + xla::ValueOrThrow(ParseMlirModuleString( + absl::string_view(bytecode.c_str(), bytecode.size()), context)); + auto mesh_attr = mlir::sdy::getMeshAttr(module.get(), "mesh"); + if (!mesh_attr) { + return {}; + } + nb::list mesh_shape; + for (auto axis : mesh_attr.getAxes()) { + mesh_shape.append( + nb::make_tuple(axis.getName().str(), axis.getSize())); + } + return mesh_shape; + }); +} + +} // namespace xla diff --git a/jaxlib/sdy.h b/jaxlib/sdy.h new file mode 100644 index 000000000000..60ce012738fb --- /dev/null +++ b/jaxlib/sdy.h @@ -0,0 +1,28 @@ +/* Copyright 2024 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef JAXLIB_SDY_H_ +#define JAXLIB_SDY_H_ + +// placeholder for index annotation headers +#include "nanobind/nanobind.h" + +namespace xla { + +void BuildSdySubmodule(nanobind::module_& m); + +} // namespace xla + +#endif // JAXLIB_SDY_H_ diff --git a/jaxlib/setup.py b/jaxlib/setup.py index b3a37a25f1b2..6a0c6520af2b 100644 --- a/jaxlib/setup.py +++ b/jaxlib/setup.py @@ -58,24 +58,27 @@ def has_ext_modules(self): long_description_content_type='text/markdown', author='JAX team', author_email='jax-dev@google.com', - packages=['jaxlib', 'jaxlib.xla_extension'], - python_requires='>=3.10', + packages=['jaxlib'], + python_requires='>=3.11', install_requires=[ - 'scipy>=1.11.1', - 'numpy>=1.25', - 'ml_dtypes>=0.2.0', + 'scipy>=1.12', + 'numpy>=1.26', + 'ml_dtypes>=0.5.0', ], url='https://github.com/jax-ml/jax', license='Apache-2.0', classifiers=[ - "Programming Language :: Python :: 3.10", + "Development Status :: 5 - Production/Stable", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", "Programming Language :: Python :: 3.13", + "Programming Language :: Python :: Free Threading :: 3 - Stable", ], package_data={ 'jaxlib': [ '*.so', + '*.dylib', + '*.dll', '*.pyd*', 'py.typed', 'cpu/*', @@ -105,7 +108,6 @@ def has_ext_modules(self): 'triton/*.so', 'include/xla/ffi/api/*.h', ], - 'jaxlib.xla_extension': ['*.pyi'], }, zip_safe=False, distclass=BinaryDistribution, diff --git a/jaxlib/sharded_device_array.h b/jaxlib/sharded_device_array.h new file mode 100644 index 000000000000..97fb8702cae5 --- /dev/null +++ b/jaxlib/sharded_device_array.h @@ -0,0 +1,216 @@ +/* Copyright 2021 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_SHARDED_DEVICE_ARRAY_H_ +#define JAXLIB_SHARDED_DEVICE_ARRAY_H_ + +#include +#include +#include + +#include "nanobind/nanobind.h" +#include "nanobind/stl/variant.h" // IWYU pragma: keep +#include "xla/python/types.h" + +// TODO(jblespiau): The current implementation moves the Python logic to C++, +// as a preliminary step to executing the `pmap` execution path from C++. +// It implements the current Python behavior (thus, it may not be optimal, and +// we will be able to modify it later). + +namespace jax { + +// High level introduction. +// +// pmap and other parallel computation functions distribute some computation on +// several devices. On December 2020, the devices mesh (i.e. N-dimensional array +// of devices on which we map the computation) is defined by the user. +// +// We describe how to shard the inputs, and how to map it to the mesh of devices +// using `ShardingSpec`. It's mainly based on 2 components: +// - `sharding`, which specifies how to shard the inputs. +// - `mesh_mapping`, which specifies how to map shards to devices. +// +// The 3 following structs define how to shard one dimension of an ndarry. 
+// +// `NoSharding` (`None` in Python) means no sharding. +struct NoSharding { + bool operator==(const NoSharding& other) const { return true; } + bool operator!=(const NoSharding& other) const { return false; } +}; + +template <typename H> +H AbslHashValue(H h, const NoSharding& key) { + return h; +} + +// `Chunked` means that the dimension is split into np.prod(chunks) chunks +// and the split dimension itself is preserved inside the map. +// Those chunks are distributed over `len(chunks)` `ShardedAxis` axes +// (major-to-minor). +// For example, for a tensor `t` of shape [N] sharded using [Chunked([p])] (with +// p dividing N, let S = N // p) the tensor will be split into p chunks of +// shape [S], such that sharded_t[k] = t[k * S: (k+1)*S] (left included, right +// excluded) for k in {0, ... p-1}. +struct Chunked { + public: + explicit Chunked(std::vector<int> chunks_) : chunks(std::move(chunks_)) {} + // The number of chunks per axis. + std::vector<int> chunks; + + bool operator==(const Chunked& other) const { return chunks == other.chunks; } + bool operator!=(const Chunked& other) const { return chunks != other.chunks; } +}; + +template <typename H> +H AbslHashValue(H h, const Chunked& key) { + h = H::combine(std::move(h), key.chunks); + return h; +} + +// `Unstacked` means that the dimension is split into chunks of size 1, and +// doesn't appear inside the map. `size` is always the dimension size. +// For example, a Tensor t of shape [N] will be sharded into N tensors of shape +// [], when using `Unstacked(N)`. +struct Unstacked { + public: + explicit Unstacked(int sz) : size(sz) {} + int size; + + bool operator==(const Unstacked& other) const { return size == other.size; } + bool operator!=(const Unstacked& other) const { return size != other.size; } +}; + +template <typename H> +H AbslHashValue(H h, const Unstacked& key) { + h = H::combine(std::move(h), key.size); + return h; +} + +using AvalDimSharding = std::variant<NoSharding, Chunked, Unstacked>; + +// Assigns sharded axes to mesh dimensions. +// +// Each mesh dimension is either assigned one of the sharded axes of the value +// (`ShardedAxis`) or replicates the data along that dimension (`Replicated`). +// As indices are 0-indexed, `ShardedAxis(1)` refers to the second actually +// sharded axis (i.e. counting as if the None dimensions of sharding were +// filtered out). +// For example, given the sharding `[Unstacked(n), None, Chunked(m)]`, an entry +// of `ShardedAxis(1)` refers to the `Chunked(m)` axis, not the `None`. + +struct ShardedAxis { + int axis; + bool operator==(const ShardedAxis& other) const { return axis == other.axis; } + bool operator!=(const ShardedAxis& other) const { return axis != other.axis; } +}; + +template <typename H> +H AbslHashValue(H h, const ShardedAxis& key) { + h = H::combine(std::move(h), key.axis); + return h; +} + +struct Replicated { + int replicas; + bool operator==(const Replicated& other) const { + return replicas == other.replicas; + } + bool operator!=(const Replicated& other) const { + return replicas != other.replicas; + } +}; + +template <typename H> +H AbslHashValue(H h, const Replicated& key) { + h = H::combine(std::move(h), key.replicas); + return h; +} + +using MeshDimAssignment = std::variant<ShardedAxis, Replicated>; + +// Describes how each axis is sharded (if it is), and how it's mapped to the +// device mesh. See JAX's pxla.py for the documentation. +// +// ShardingSpec is shared across pmap, pjit and xmap. For pmap, an input +// `sharding` is composed of `NoSharding` and at most one `Unstacked`. +// If `axis_size=None`, at least one of the inputs has a dimension associated to +// `Unstacked`. +// +// Examples: +// +// 1.
For pmap, with a tensor of shape [8, 2, 2], to unstack along the first +// dimension into [8] devices: +// +// sharding = [Unstacked(8), NoSharding, NoSharding] +// mesh_mapping = [ShardedAxis(0)] +// +// 2. With an input array of shape [6], that we want to chunk into [2, 3] +// Assuming a device mesh [3, 4, 2] of devices, we will have: +// +// sharding = [Chunked([2, 3])] +// mesh_mapping = [ShardedAxis(1), Replicated, ShardedAxis(0)] +// +// In particular, in the above example, the ShardedAxis refers to indices +// of the sharded shape [2, 3]. (only the `Chunked` sharding can produce more +// than one dimension). +class ShardingSpec { + public: + ShardingSpec(std::vector sharding, + std::vector mesh_mapping) + : sharding_(std::move(sharding)), + mesh_mapping_(std::move(mesh_mapping)) {} + ShardingSpec(nanobind::iterable py_sharding, + nanobind::iterable py_mesh_mapping) + : sharding_(xla::IterableToVector(py_sharding)), + mesh_mapping_( + xla::IterableToVector(py_mesh_mapping)) {} + + const std::vector& GetSharding() const { return sharding_; } + const std::vector& GetMeshMapping() const { + return mesh_mapping_; + } + + bool operator==(const ShardingSpec& other) const { + return sharding_ == other.sharding_ && mesh_mapping_ == other.mesh_mapping_; + } + + bool operator!=(const ShardingSpec& other) const { return !(*this == other); } + + template + friend H AbslHashValue(H h, const ShardingSpec& key); + + private: + // `sharding` specifies how the array is supposed to get partitioned into + // chunks. Its length matches the rank of the array. See the docstring + // of `AvalDimSharding` for the supported partitioning schemes. + std::vector sharding_; + // `mesh_mapping` describes an assignments of the array chunks created by + // `sharding` to a logical device mesh. The length of the tuple is equal to + // the rank of the mesh. Each mesh dimension can either get partitions of + // data varying along one of the sharded dimensions, or the data can be + // replicated. + std::vector mesh_mapping_; +}; + +template +H AbslHashValue(H h, const ShardingSpec& key) { + h = H::combine(std::move(h), key.sharding_); + h = H::combine(std::move(h), key.mesh_mapping_); + return h; +} + +} // namespace jax + +#endif // JAXLIB_SHARDED_DEVICE_ARRAY_H_ diff --git a/jaxlib/sharding.cc b/jaxlib/sharding.cc new file mode 100644 index 000000000000..77e97c3654bc --- /dev/null +++ b/jaxlib/sharding.cc @@ -0,0 +1,365 @@ +/* Copyright 2022 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
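A worked sketch of example 2 above in plain NumPy (no jaxlib APIs involved), showing which chunk each mesh position receives; the device mesh is only simulated by the loop bounds.

import numpy as np

x = np.arange(6)             # shape [6]
chunks = x.reshape(2, 3, 1)  # Chunked([2, 3]): sharded shape [2, 3], chunks of size 1

# mesh_mapping = [ShardedAxis(1), Replicated, ShardedAxis(0)] over a [3, 4, 2]
# device mesh: mesh dim 0 walks sharded axis 1 (size 3), mesh dim 1 replicates
# (size 4), and mesh dim 2 walks sharded axis 0 (size 2).
for m0 in range(3):          # ShardedAxis(1)
  for m1 in range(4):        # Replicated
    for m2 in range(2):      # ShardedAxis(0)
      shard = chunks[m2, m0]
      assert shard.shape == (1,)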
+==============================================================================*/ + +#include "jaxlib/sharding.h" + +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/base/casts.h" +#include "absl/hash/hash.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "nanobind/stl/string_view.h" // IWYU pragma: keep +#include "jaxlib/nb_class_ptr.h" +#include "jaxlib/partition_spec.h" +#include "jaxlib/py_client.h" +#include "jaxlib/py_device_list.h" +#include "jaxlib/sharded_device_array.h" +#include "xla/hlo/ir/hlo_sharding.h" +#include "xla/pjrt/status_casters.h" +#include "xla/python/ifrt/device_list.h" +#include "xla/python/nb_numpy.h" +#include "xla/python/safe_static_init.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/xla_data.pb.h" + +namespace jax { + +namespace nb = nanobind; + +// Gets `jax::PyDeviceList` from a JAX Sharding. +absl::StatusOr> GetPyDeviceList( + nb::handle sharding) { + if (sharding.type().is(jax::NamedSharding::type())) { + TF_ASSIGN_OR_RETURN( + auto ns_device_list, + nb::cast(sharding)->internal_device_list()); + return ns_device_list; + } else if (sharding.type().is(jax::SingleDeviceSharding::type())) { + return nb::cast(sharding) + ->internal_device_list(); + } else if (sharding.type().is(jax::PmapSharding::type())) { + return nb::cast(sharding)->internal_device_list(); + } else if (sharding.type().is(jax::GSPMDSharding::type())) { + return nb::cast(sharding) + ->internal_device_list(); + } else { + return nb::cast>( + sharding.attr("_internal_device_list")); + } +} + +nb::object CheckAndCanonicalizeMemoryKind( + nb::object memory_kind, + const xla::nb_class_ptr& device_list) { + if (!memory_kind.is_none()) { + // If memory kind is not None, check if it's supported by the devices + // mentioned in the Sharding. + auto supported_memory_kinds = PyDeviceList::MemoryKinds(device_list); + if (!supported_memory_kinds.ok()) { + supported_memory_kinds = nb::tuple(); + } + for (nb::handle supported_memory_kind : *supported_memory_kinds) { + if (supported_memory_kind.equal(memory_kind)) { + return memory_kind; + } + } + auto addressable_device_list = + PyDeviceList::AddressableDeviceList(device_list); + if (addressable_device_list->Len() == 0) { + // If the device list is not addressable, we can't check if the memory + // kind is supported, so we assume it is. + return memory_kind; + } + nb::object device_kind = + addressable_device_list->GetItem(0).attr("device_kind"); + absl::string_view device_kind_str = + nb::cast(device_kind); + auto py_str_formatter = [](std::string* out, nb::handle h) { + *out += nb::cast(nb::str(h)); + }; + throw nb::value_error( + absl::StrCat( + "Could not find memory addressable by device ", device_kind_str, + ". Device ", device_kind_str, + " can address the following memory kinds: ", + absl::StrJoin(*supported_memory_kinds, ", ", py_str_formatter), + ". Got memory kind: ", nb::cast(memory_kind)) + .c_str()); + } + // If memory kind is None, canonicalize to default memory. 
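A usage sketch of the canonicalization and validation implemented here, as it surfaces on jax.sharding.NamedSharding (illustrative; assumes a working JAX install with at least one device, and the default memory kind reported depends on the platform).

import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(jax.devices()[:1], ("x",))

# memory_kind=None is canonicalized to the devices' default memory kind
# (or left as None if the backend does not report memory kinds).
print(NamedSharding(mesh, P()).memory_kind)

# An unsupported memory kind is rejected with the list of supported kinds.
try:
  NamedSharding(mesh, P(), memory_kind="not_a_memory_kind")
except ValueError as e:
  print(e)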
+ absl::StatusOr default_memory_kind = + PyDeviceList::DefaultMemoryKind(device_list); + if (!default_memory_kind.ok()) { + return nb::none(); + } + return *std::move(default_memory_kind); +} + +int Sharding::SafeNumDevices(nb::handle sharding) { + const jax::Sharding* cpp_sharding; + if (nb::try_cast(sharding, cpp_sharding)) { + if (cpp_sharding->num_devices_.has_value()) { + return (*cpp_sharding->num_devices_); + } + } + nb::set device_set = sharding.attr("device_set"); + return device_set.size(); +} + +// This list is to check for valid memory kinds when an AbstractMesh is passed +// to NamedSharding. +static const std::array valid_memory_kinds = { + "device", + "pinned_host", + "unpinned_host", +}; + +NamedSharding::NamedSharding(nb::object mesh, + xla::nb_class_ptr spec, + nb::object memory_kind, + nb::object logical_device_ids) + : Sharding(/*num_devices=*/[&mesh]() { + return nb::cast(mesh.attr("size")); + }()), + mesh_(std::move(mesh)), + spec_(std::move(spec)), + memory_kind_(std::move(memory_kind)), + logical_device_ids_(std::move(logical_device_ids)) { + nb::object idl = nb::object(mesh_.attr("_internal_device_list")); + if (idl.is_none()) { + internal_device_list_ = std::nullopt; + } else { + internal_device_list_ = nb::cast>(idl); + } + if (internal_device_list_) { + memory_kind_ = + CheckAndCanonicalizeMemoryKind(memory_kind_, *internal_device_list_); + } else { + if (!memory_kind_.is_none() && + (std::find(valid_memory_kinds.begin(), valid_memory_kinds.end(), + nb::cast(memory_kind_)) == + valid_memory_kinds.end())) { + throw nb::value_error( + absl::StrCat("Got invalid memory kind: ", + nb::cast(memory_kind_), + ". Valid memory kinds are: ", + absl::StrJoin(valid_memory_kinds, ", ")) + .c_str()); + } + } + + // TODO(phawkins): this leaks a reference to the check_pspec function. + // A better way to fix this would be to move PartitionSpec and this check into + // C++. + auto init_fn = []() { + nb::module_ si = nb::module_::import_("jax._src.named_sharding"); + return std::make_unique(si.attr("check_pspec")); + }; + nb::object& check_pspec = xla::SafeStaticInit(init_fn); + check_pspec(mesh_, spec_); +} + +/*static*/ PyObject* NamedSharding::type_ = nullptr; + +/*static*/ void NamedSharding::InitializeType() { + // Intentionally leaks a reference. + type_ = nanobind::type().inc_ref().ptr(); +} + +bool NamedSharding::operator==(const NamedSharding& other) const { + // Caution: you may need to update EqualShardingsForJit in jax_jit.cc as well. + return mesh().equal(other.mesh()) && *spec() == *other.spec() && + memory_kind().equal(other.memory_kind()) && + logical_device_ids().equal(other.logical_device_ids()); +} + +bool NamedSharding::Eq(const nanobind::object& other) const { + if (!other.ptr() || other.is_none()) { + return false; + } + const NamedSharding* other_sharding; + if (!nb::try_cast(other, other_sharding)) { + return false; + } + return this == other_sharding || *this == *other_sharding; +} + +nb::object NamedSharding::Hash() const { + // Caution: you may need to update HashShardingForJit in jax_jit.cc as well. + return hash_.Get([&]() { + size_t h = + absl::HashOf(nb::hash(mesh_), spec_->Hash(), nb::hash(memory_kind_), + nb::hash(logical_device_ids_)); + Py_hash_t s = absl::bit_cast(h); // Python hashes are signed. + return nb::cast( + s == -1 ? -2 : s); // -1 must not be used as a Python hash value. 
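The -1 special case above follows the CPython convention that -1 is the error return value of the C-level hash slot, so no object may hash to -1; the interpreter applies the same remapping, as this sketch shows.

# CPython never lets __hash__ return -1 to callers: -1 signals an error at the
# C level, so it is remapped to -2, exactly as Hash() does above.
assert hash(-1) == -2
assert hash(-2) == -2

class AlwaysMinusOne:
  def __hash__(self):
    return -1

assert hash(AlwaysMinusOne()) == -2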
+  });
+}
+
+SingleDeviceSharding::SingleDeviceSharding(nb::object device,
+                                           nb::object memory_kind)
+    : Sharding(/*num_devices=*/1),
+      device_(device),
+      memory_kind_(std::move(memory_kind)),
+      internal_device_list_(
+          xla::make_nb_class<PyDeviceList>(nb::make_tuple(std::move(device)))) {
+  memory_kind_ =
+      CheckAndCanonicalizeMemoryKind(memory_kind_, internal_device_list_);
+}
+
+/*static*/ PyObject* SingleDeviceSharding::type_ = nullptr;
+
+/*static*/ void SingleDeviceSharding::InitializeType() {
+  // Intentionally leaks a reference.
+  type_ = nanobind::type<SingleDeviceSharding>().inc_ref().ptr();
+}
+
+SingleDeviceSharding::SingleDeviceSharding(
+    xla::nb_class_ptr<xla::PyClient> client,
+    xla::ifrt::DeviceListRef device_list, nb::object memory_kind)
+    : Sharding(/*num_devices=*/1),
+      device_(client->GetPyDevice(device_list->devices().front())),
+      memory_kind_(std::move(memory_kind)),
+      internal_device_list_(xla::make_nb_class<PyDeviceList>(
+          std::move(client), std::move(device_list))) {
+  memory_kind_ =
+      CheckAndCanonicalizeMemoryKind(memory_kind_, internal_device_list_);
+}
+
+PmapSharding::PmapSharding(xla::nb_numpy_ndarray devices,
+                           ShardingSpec sharding_spec)
+    : Sharding(/*num_devices=*/devices.size()),
+      devices_(std::move(devices)),
+      sharding_spec_(std::move(sharding_spec)) {
+  nb::object flat_devices = devices_.attr("flat");
+  internal_device_list_ =
+      xla::make_nb_class<PyDeviceList>(nb::tuple(flat_devices));
+}
+
+/*static*/ PyObject* PmapSharding::type_ = nullptr;
+
+// /*static*/ nanobind::handle PmapSharding::type() { return type_; }
+
+/*static*/ void PmapSharding::InitializeType() {
+  // Intentionally leaks a reference.
+  type_ = nanobind::type<PmapSharding>().inc_ref().ptr();
+}
+
+GSPMDSharding::GSPMDSharding(nb::sequence devices, xla::HloSharding op_sharding,
+                             nb::object memory_kind, nb::object device_list)
+    : Sharding(/*num_devices=*/nb::len(devices.ptr())),
+      devices_(nb::tuple(devices)),
+      hlo_sharding_(std::move(op_sharding)),
+      memory_kind_(std::move(memory_kind)) {
+  if (device_list.is_none()) {
+    internal_device_list_ = xla::make_nb_class<PyDeviceList>(devices_);
+  } else {
+    internal_device_list_ =
+        nb::cast<xla::nb_class_ptr<PyDeviceList>>(std::move(device_list));
+  }
+  // This checks in python if the memory kind is correct for the given
+  // devices. Currently in python this check is optimized but we want to
+  // move that check to C++ after which we can remove this call.
+  CHECK(devices_.size() != 0)
+      << "Devices given to GSPMDSharding must not be empty";
+  memory_kind_ =
+      CheckAndCanonicalizeMemoryKind(memory_kind_, internal_device_list_);
+}
+
+/*static*/ PyObject* GSPMDSharding::type_ = nullptr;
+
+/*static*/ void GSPMDSharding::InitializeType() {
+  // Intentionally leaks a reference.
+  type_ = nanobind::type<GSPMDSharding>().inc_ref().ptr();
+}
+
+void RegisterSharding(nb::module_& m) {
+  nb::class_<Sharding>(m, "Sharding").def(nb::init<>());
+
+  nb::class_<NamedSharding, Sharding>(m, "NamedSharding", nb::dynamic_attr())
+      .def(nb::init<nb::object, xla::nb_class_ptr<PartitionSpec>, nb::object,
+                    nb::object>(),
+           nb::arg("mesh"), nb::arg("spec"),
+           nb::arg("memory_kind").none() = nb::none(),
+           nb::arg("_logical_device_ids").none() = nb::none())
+      .def_prop_ro("mesh", &NamedSharding::mesh)
+      .def_prop_ro("spec", &NamedSharding::spec)
+      .def_prop_ro("_memory_kind", &NamedSharding::memory_kind)
+      .def_prop_ro("_logical_device_ids", &NamedSharding::logical_device_ids)
+      .def_prop_ro("_internal_device_list",
+                   [](const NamedSharding& s) {
+                     return xla::ValueOrThrow(s.internal_device_list());
+                   })
+      .def("__eq__", &NamedSharding::Eq, nb::arg().none())
+      .def("__hash__", &NamedSharding::Hash);
+  NamedSharding::InitializeType();
+
+  nb::class_<SingleDeviceSharding, Sharding>(m, "SingleDeviceSharding",
+                                             nb::dynamic_attr())
+      .def(nb::init<nb::object, nb::object>(), nb::arg("device"),
+           nb::arg("memory_kind").none() = nb::none())
+      .def_prop_ro("_device", &SingleDeviceSharding::device)
+      .def_prop_ro("_memory_kind", &SingleDeviceSharding::memory_kind)
+      .def_prop_ro("_internal_device_list",
+                   &SingleDeviceSharding::internal_device_list);
+  SingleDeviceSharding::InitializeType();
+
+  nb::class_<PmapSharding, Sharding>(m, "PmapSharding", nb::dynamic_attr())
+      .def(
+          "__init__",
+          [](PmapSharding* self, nb::object devices,
+             ShardingSpec sharding_spec) {
+            new (self) PmapSharding(xla::nb_numpy_ndarray::ensure(devices),
+                                    std::move(sharding_spec));
+          },
+          nb::arg("devices"), nb::arg("sharding_spec"))
+      .def_prop_ro("devices", &PmapSharding::devices)
+      .def_prop_ro("sharding_spec", &PmapSharding::sharding_spec)
+      .def_prop_ro("_internal_device_list",
+                   &PmapSharding::internal_device_list);
+  PmapSharding::InitializeType();
+
+  nb::class_<GSPMDSharding, Sharding>(m, "GSPMDSharding", nb::dynamic_attr())
+      .def(nb::init<nb::sequence, xla::OpSharding, nb::object, nb::object>(),
+           nb::arg("devices"), nb::arg("op_sharding"),
+           nb::arg("memory_kind").none() = nb::none(),
+           nb::arg("_device_list").none() = nb::none())
+      .def(nb::init<nb::sequence, xla::HloSharding, nb::object, nb::object>(),
+           nb::arg("devices"), nb::arg("op_sharding"),
+           nb::arg("memory_kind").none() = nb::none(),
+           nb::arg("_device_list").none() = nb::none())
+      .def_prop_ro("_devices", &GSPMDSharding::devices)
+      .def_prop_ro("_hlo_sharding", &GSPMDSharding::hlo_sharding)
+      .def_prop_ro("_memory_kind", &GSPMDSharding::memory_kind)
+      .def_prop_ro("_internal_device_list",
+                   &GSPMDSharding::internal_device_list);
+  GSPMDSharding::InitializeType();
+}
+
+}  // namespace jax
diff --git a/jaxlib/sharding.h b/jaxlib/sharding.h
new file mode 100644
index 000000000000..083fb2b5d3ce
--- /dev/null
+++ b/jaxlib/sharding.h
@@ -0,0 +1,242 @@
+/* Copyright 2022 The JAX Authors
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/ + +#ifndef JAXLIB_SHARDING_H_ +#define JAXLIB_SHARDING_H_ + +#include + +#include +#include +#include + +// placeholder for index annotation headers +#include "absl/hash/hash.h" +#include "absl/status/statusor.h" +#include "nanobind/nanobind.h" +#include "jaxlib/cached_py_object.h" +#include "jaxlib/nb_class_ptr.h" +#include "jaxlib/partition_spec.h" +#include "jaxlib/py_client.h" +#include "jaxlib/py_device_list.h" +#include "jaxlib/sharded_device_array.h" +#include "xla/hlo/ir/hlo_sharding.h" +#include "xla/pjrt/status_casters.h" +#include "xla/python/ifrt/device_list.h" +#include "xla/python/nb_numpy.h" +#include "xla/util.h" +#include "xla/xla_data.pb.h" + +namespace jax { + +class Sharding { + public: + Sharding() = default; + + // This constructor is used in the fast path to retrieve the number of devices + // without falling back to python. This is only used in the cpp path. + explicit Sharding(int num_devices) : num_devices_(num_devices) {} + + virtual ~Sharding() = default; + + static int SafeNumDevices(nanobind::handle sharding); + + private: + std::optional num_devices_; +}; + +// Gets `jax::PyDeviceList` from a JAX Sharding. +absl::StatusOr> GetPyDeviceList( + nanobind::handle sharding); + +// Checks if the memory kind is valid, and canonicalizes the +// memory kind to default memory on backends that support memories. +nanobind::object CheckAndCanonicalizeMemoryKind( + nanobind::object memory_kind, + const xla::nb_class_ptr& device_list); + +class NamedSharding : public Sharding { + public: + NamedSharding(nanobind::object mesh, xla::nb_class_ptr spec, + nanobind::object memory_kind, + nanobind::object logical_device_ids); + + const nanobind::object& mesh() const { return mesh_; } + const xla::nb_class_ptr& spec() const { return spec_; } + const nanobind::object& memory_kind() const { return memory_kind_; } + const nanobind::object& logical_device_ids() const { + return logical_device_ids_; + } + + static nanobind::handle type() { return type_; } + static void InitializeType(); + + absl::StatusOr> internal_device_list() const { + if (internal_device_list_) { + return *internal_device_list_; + } + return xla::InvalidArgument( + "internal_device_list is not implemented for " + "`jax.sharding.AbstractMesh`"); + } + + bool operator==(const NamedSharding& other) const; + + bool Eq(const nanobind::object& other) const; // Python __eq__ + nanobind::object Hash() const; // Python __hash__ + + private: + nanobind::object mesh_; + xla::nb_class_ptr spec_; + nanobind::object memory_kind_; + nanobind::object logical_device_ids_; + std::optional> internal_device_list_; + mutable CachedPyObject hash_; + static PyObject* type_; +}; + +class SingleDeviceSharding : public Sharding { + public: + explicit SingleDeviceSharding( + nanobind::object device, nanobind::object memory_kind = nanobind::none()); + + // Used only in C++ to accelerate `PyArray::MakeFromSingleDeviceArray()`. 
+ SingleDeviceSharding(xla::nb_class_ptr client, + xla::ifrt::DeviceListRef device_list, + nanobind::object memory_kind); + + const nanobind::object& device() const { return device_; } + const nanobind::object& memory_kind() const { return memory_kind_; } + + static nanobind::handle type() { return type_; } + static void InitializeType(); + + xla::nb_class_ptr internal_device_list() const { + return internal_device_list_; + } + + private: + nanobind::object device_; + nanobind::object memory_kind_; + xla::nb_class_ptr internal_device_list_; + + static PyObject* type_; +}; + +// The C++ implementation of jax.PmapSharding in python. It contains a few key +// data members and methods that are performance-critical. +class PmapSharding : public Sharding { + public: + PmapSharding(xla::nb_numpy_ndarray devices, ShardingSpec sharding_spec); + + ~PmapSharding() override = default; + + xla::nb_numpy_ndarray devices() const { return devices_; } + + const ShardingSpec& sharding_spec() const { return sharding_spec_; } + + static nanobind::handle type() { return type_; } + static void InitializeType(); + + xla::nb_class_ptr internal_device_list() const { + return internal_device_list_; + } + + private: + xla::nb_numpy_ndarray devices_; + ShardingSpec sharding_spec_; + xla::nb_class_ptr internal_device_list_; + static PyObject* type_; +}; + +class GSPMDSharding : public Sharding { + public: + GSPMDSharding(nanobind::sequence devices, xla::OpSharding op_sharding, + nanobind::object memory_kind, nanobind::object device_list) + : GSPMDSharding( + std::move(devices), + xla::ValueOrThrow(xla::HloSharding::FromProto(op_sharding)), + std::move(memory_kind), std::move(device_list)) {} + + GSPMDSharding(nanobind::sequence devices, xla::HloSharding op_sharding, + nanobind::object memory_kind, nanobind::object device_list); + + const nanobind::tuple& devices() const { return devices_; } + const nanobind::object& memory_kind() const { return memory_kind_; } + + size_t Hash() { + if (!hash_.has_value()) { + hash_ = CalculateHash(); + } + return *hash_; + } + + static nanobind::handle type() { return type_; } + static void InitializeType(); + + const xla::HloSharding& hlo_sharding() const { return hlo_sharding_; } + + bool operator==(const GSPMDSharding& other) const { + return AreOpShardingsEqual(*this, other) && + this->devices().equal(other.devices()) && + this->memory_kind().equal(other.memory_kind()); + } + + xla::nb_class_ptr internal_device_list() const { + return internal_device_list_; + } + + private: + size_t CalculateHash() const { + // We only hash `hlo_sharding_` here for performance. + return absl::Hash()(hlo_sharding_); + } + + static bool AreOpShardingsEqual(const GSPMDSharding& a, + const GSPMDSharding& b) { + // If the OpSharding object is the same, return true + if (&a.hlo_sharding() == &b.hlo_sharding()) { + return true; + } + // If both OpShardings are replicated, return true + if (a.IsOpShardingReplicated() && b.IsOpShardingReplicated()) { + return true; + } + return a.hlo_sharding() == b.hlo_sharding(); + } + + bool IsOpShardingReplicated() const { + // For JAX, shardings with 1 device are considered as replicated in its + // semantics so that downstream things continue to work. 
+ if (hlo_sharding_.tile_assignment().num_elements() == 1) { + return true; + } + return hlo_sharding().IsReplicated(); + } + + nanobind::tuple devices_; + xla::HloSharding hlo_sharding_; + nanobind::object memory_kind_; + std::optional hash_; + xla::nb_class_ptr internal_device_list_; + + static PyObject* type_; +}; + +void RegisterSharding(nanobind::module_& m); + +} // namespace jax + +#endif // JAXLIB_SHARDING_H_ diff --git a/jaxlib/to_ifrt_sharding.cc b/jaxlib/to_ifrt_sharding.cc new file mode 100644 index 000000000000..220c54e7a1e5 --- /dev/null +++ b/jaxlib/to_ifrt_sharding.cc @@ -0,0 +1,140 @@ +/* Copyright 2025 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/to_ifrt_sharding.h" + +#include +#include +#include +#include + +#include "absl/status/statusor.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "jaxlib/nb_class_ptr.h" +#include "jaxlib/py_device_list.h" +#include "jaxlib/sharding.h" +#include "xla/hlo/ir/hlo_sharding.h" +#include "xla/python/ifrt/device_list.h" +#include "xla/python/ifrt/dtype.h" +#include "xla/python/ifrt/memory.h" +#include "xla/python/ifrt/shape.h" +#include "xla/python/ifrt/sharding.h" +#include "xla/python/pjrt_ifrt/pjrt_dtype.h" +#include "xla/python/pjrt_ifrt/xla_sharding.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/xla_data.pb.h" + +namespace xla { + +namespace nb = ::nanobind; + +// Gets `xla::HloSharding` from a JAX Sharding. +xla::HloSharding GetXlaHloSharding(nb::handle sharding, + int64_t num_dimensions) { + if (sharding.type().is(nb::handle(jax::GSPMDSharding::type().ptr()))) { + return nb::cast(nb::handle(sharding.ptr())) + ->hlo_sharding(); + } else { + return nb::cast( + sharding.attr("_to_xla_hlo_sharding")(num_dimensions)); + } +} + +// Gets `xla::ifrt::DeviceList` from a JAX Sharding. +absl::StatusOr GetIfrtDeviceList( + nb::handle sharding_py) { + TF_ASSIGN_OR_RETURN(auto py_device_list, jax::GetPyDeviceList(sharding_py)); + return py_device_list->ifrt_device_list(); +} + +// Gets `xla::ifrt::MemoryKind` from a JAX Sharding. +xla::ifrt::MemoryKind GetMemoryKind(nb::handle sharding) { + nb::object py_memory_kind = nb::none(); + + // sharding.attr("memory_kind") can crash if sharding was originally created + // from C++ and casted into a Python Sharding object. Thus, we cast sharding + // to a C++ type and use C++ `memory_kind()` method, which bypasses any Python + // attribute access. 
+ nb::handle type = sharding.type(); + if (type.is(jax::NamedSharding::type())) { + py_memory_kind = + nb::cast(sharding)->memory_kind(); + } else if (type.is(jax::SingleDeviceSharding::type())) { + py_memory_kind = + nb::cast(sharding)->memory_kind(); + } else if (type.is(jax::GSPMDSharding::type())) { + py_memory_kind = + nb::cast(sharding)->memory_kind(); + } else { + py_memory_kind = sharding.attr("memory_kind"); + } + + if (py_memory_kind.is_none()) { + return xla::ifrt::MemoryKind(); + } + return xla::ifrt::MemoryKind(nb::cast(py_memory_kind)); +} + +// Converts a JAX Sharding into `xla::ifrt::HloSharding`. +absl::StatusOr GetIfrtHloSharding( + nb::handle sharding, const xla::ifrt::Shape& shape) { + TF_ASSIGN_OR_RETURN(xla::ifrt::DeviceListRef device_list, + GetIfrtDeviceList(sharding)); + xla::ifrt::MemoryKind memory_kind = GetMemoryKind(sharding.ptr()); + xla::HloSharding hlo_sharding = + GetXlaHloSharding(sharding, shape.dims().size()); + return xla::ifrt::HloSharding::Create( + std::move(device_list), std::move(memory_kind), std::move(hlo_sharding)); +} + +// Converts a JAX Sharding into `xla::ifrt::ConcreteEvenSharding`. +absl::StatusOr GetIfrtConcreteEvenSharding( + nb::handle sharding, xla::ifrt::DType dtype, + const xla::ifrt::Shape& shape) { + TF_ASSIGN_OR_RETURN(xla::ifrt::DeviceListRef device_list, + GetIfrtDeviceList(sharding)); + xla::ifrt::MemoryKind memory_kind = GetMemoryKind(sharding.ptr()); + TF_ASSIGN_OR_RETURN(xla::PrimitiveType xla_primitive_type, + xla::ifrt::ToPrimitiveType(dtype)); + // The XLA shape's layout is irrelevant because we only need to know the + // tile shape, which is independent from the layout. + xla::Shape xla_shape = xla::ShapeUtil::MakeShapeWithDescendingLayout( + xla_primitive_type, shape.dims()); + xla::HloSharding hlo_sharding = + GetXlaHloSharding(sharding, shape.dims().size()); + xla::Shape tile_shape = hlo_sharding.TileShape(xla_shape); + xla::ifrt::Shape shard_shape(xla::ifrt::Shape::Dimensions( + tile_shape.dimensions().begin(), tile_shape.dimensions().end())); + return xla::ifrt::ConcreteEvenSharding::Create( + std::move(device_list), std::move(memory_kind), shape, + /*shard_shape=*/std::move(shard_shape)); +} + +// Converts a JAX Sharding into `xla::ifrt::ConcreteSharding`. +absl::StatusOr GetIfrtConcreteSharding( + nb::handle sharding, const xla::ifrt::Shape& shape, + std::vector shard_shapes) { + TF_ASSIGN_OR_RETURN(xla::ifrt::DeviceListRef device_list, + GetIfrtDeviceList(sharding)); + xla::ifrt::MemoryKind memory_kind = GetMemoryKind(sharding.ptr()); + return xla::ifrt::ConcreteSharding::Create( + std::move(device_list), std::move(memory_kind), shape, + /*shard_shapes=*/std::move(shard_shapes)); +} + +} // namespace xla diff --git a/jaxlib/to_ifrt_sharding.h b/jaxlib/to_ifrt_sharding.h new file mode 100644 index 000000000000..911a7caea368 --- /dev/null +++ b/jaxlib/to_ifrt_sharding.h @@ -0,0 +1,61 @@ +/* Copyright 2025 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef JAXLIB_TO_IFRT_SHARDING_H_ +#define JAXLIB_TO_IFRT_SHARDING_H_ + +#include +#include +#include + +#include "absl/status/statusor.h" +#include "nanobind/nanobind.h" +#include "xla/hlo/ir/hlo_sharding.h" +#include "xla/python/ifrt/device_list.h" +#include "xla/python/ifrt/dtype.h" +#include "xla/python/ifrt/memory.h" +#include "xla/python/ifrt/shape.h" +#include "xla/python/ifrt/sharding.h" + +namespace xla { + +// Gets `xla::HloSharding` from a JAX Sharding. +xla::HloSharding GetXlaHloSharding(nanobind::handle sharding, + int64_t num_dimensions); + +// Gets `xla::ifrt::DeviceList` from a JAX Sharding. +absl::StatusOr GetIfrtDeviceList( + nanobind::handle sharding_py); + +// Gets `xla::ifrt::MemoryKind` from a JAX Sharding. +xla::ifrt::MemoryKind GetMemoryKind(nanobind::handle sharding); + +// Converts a JAX Sharding into `xla::ifrt::HloSharding`. +absl::StatusOr GetIfrtHloSharding( + nanobind::handle sharding, const xla::ifrt::Shape& shape); + +// Converts a JAX Sharding into `xla::ifrt::ConcreteEvenSharding`. +absl::StatusOr GetIfrtConcreteEvenSharding( + nanobind::handle sharding, xla::ifrt::DType dtype, + const xla::ifrt::Shape& shape); + +// Converts a JAX Sharding into `xla::ifrt::ConcreteSharding`. +absl::StatusOr GetIfrtConcreteSharding( + nanobind::handle sharding, const xla::ifrt::Shape& shape, + std::vector shard_shapes); + +} // namespace xla + +#endif // JAXLIB_TO_IFRT_SHARDING_H_ diff --git a/jaxlib/tools/BUILD.bazel b/jaxlib/tools/BUILD.bazel index afa5866e286d..515b7be04f64 100644 --- a/jaxlib/tools/BUILD.bazel +++ b/jaxlib/tools/BUILD.bazel @@ -15,7 +15,7 @@ # JAX is Autograd and XLA load("@bazel_skylib//lib:selects.bzl", "selects") -load("@bazel_skylib//rules:common_settings.bzl", "string_flag") +load("@bazel_skylib//rules:common_settings.bzl", "bool_flag", "string_flag") load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm") load( @@ -29,16 +29,20 @@ load( load( "//jaxlib:jax.bzl", "PLATFORM_TAGS_DICT", - "if_windows", + "if_pypi_cuda_wheel_deps", "jax_py_test", "jax_wheel", "pytype_strict_library", + "pytype_test", + "wheel_sources", ) licenses(["notice"]) # Apache 2 package(default_visibility = ["//visibility:public"]) +exports_files(["wheel_size_test.py"]) + genrule( name = "platform_tags_py", srcs = [], @@ -61,21 +65,20 @@ py_binary( "LICENSE.txt", "//jaxlib", "//jaxlib:README.md", + "//jaxlib:_jax", + "//jaxlib:jaxlib_binaries", "//jaxlib:setup.py", + "//jaxlib:xla_client.py", "@xla//xla/ffi/api:api.h", "@xla//xla/ffi/api:c_api.h", "@xla//xla/ffi/api:ffi.h", - "@xla//xla/python:xla_client.py", - "@xla//xla/python:xla_extension", - ] + if_windows([ - "//jaxlib/mlir/_mlir_libs:jaxlib_mlir_capi.dll", - ]), + ], deps = [ ":build_utils", "@bazel_tools//tools/python/runfiles", - "@pypi_build//:pkg", - "@pypi_setuptools//:pkg", - "@pypi_wheel//:pkg", + "@pypi//build", + "@pypi//setuptools", + "@pypi//wheel", ], ) @@ -88,35 +91,15 @@ jax_py_test( ], ) -cc_binary( - name = "pjrt_c_api_gpu_plugin.so", - linkopts = [ - "-Wl,--version-script,$(location :gpu_version_script.lds)", - "-Wl,--no-undefined", - ], - linkshared = True, - deps = [ - ":gpu_version_script.lds", - "@xla//xla/pjrt/c:pjrt_c_api_gpu", - "@xla//xla/pjrt/c:pjrt_c_api_gpu_version_script.lds", - "@xla//xla/service:gpu_plugin", - ] + if_cuda([ - "//jaxlib/mosaic/gpu:custom_call", - "@xla//xla/stream_executor:cuda_platform", - ]) + if_rocm([ - 
"@xla//xla/stream_executor:rocm_platform", - ]), -) - py_binary( name = "build_gpu_plugin_wheel", srcs = ["build_gpu_plugin_wheel.py"], data = [ "LICENSE.txt", - ":pjrt_c_api_gpu_plugin.so", ] + if_cuda([ "//jaxlib:version", "//jaxlib/cuda:cuda_gpu_support", + "//jax_plugins/cuda:pjrt_c_api_gpu_plugin.so", "//jax_plugins/cuda:pyproject.toml", "//jax_plugins/cuda:setup.py", "//jax_plugins/cuda:__init__.py", @@ -124,6 +107,7 @@ py_binary( ]) + if_rocm([ "//jaxlib:version", "//jaxlib/rocm:rocm_gpu_support", + "//jax_plugins/rocm:pjrt_c_api_gpu_plugin.so", "//jax_plugins/rocm:pyproject.toml", "//jax_plugins/rocm:setup.py", "//jax_plugins/rocm:__init__.py", @@ -131,9 +115,9 @@ py_binary( deps = [ ":build_utils", "@bazel_tools//tools/python/runfiles", - "@pypi_build//:pkg", - "@pypi_setuptools//:pkg", - "@pypi_wheel//:pkg", + "@pypi//build", + "@pypi//setuptools", + "@pypi//wheel", ], ) @@ -160,12 +144,16 @@ py_binary( deps = [ ":build_utils", "@bazel_tools//tools/python/runfiles", - "@pypi_build//:pkg", - "@pypi_setuptools//:pkg", - "@pypi_wheel//:pkg", + "@pypi//build", + "@pypi//setuptools", + "@pypi//wheel", ], ) +# Targets and configurations for the new wheel build rules. + +# Platform configurations. + selects.config_setting_group( name = "macos", match_any = [ @@ -222,6 +210,8 @@ selects.config_setting_group( ], ) +# Flags for the new wheel build rules. + string_flag( name = "jaxlib_git_hash", build_setting_default = "", @@ -232,60 +222,129 @@ string_flag( build_setting_default = "dist", ) -NVIDIA_WHEELS_DEPS = [ - "@pypi_nvidia_cublas_cu12//:whl", - "@pypi_nvidia_cuda_cupti_cu12//:whl", - "@pypi_nvidia_cuda_runtime_cu12//:whl", - "@pypi_nvidia_cudnn_cu12//:whl", - "@pypi_nvidia_cufft_cu12//:whl", - "@pypi_nvidia_cusolver_cu12//:whl", - "@pypi_nvidia_cusparse_cu12//:whl", - "@pypi_nvidia_nccl_cu12//:whl", - "@pypi_nvidia_nvjitlink_cu12//:whl", -] +# Wheel targets. + +# Jaxlib wheel targets. +py_binary( + name = "build_wheel_tool", + srcs = ["build_wheel.py"], + main = "build_wheel.py", + deps = [ + ":build_utils", + "@bazel_tools//tools/python/runfiles", + "@pypi//build", + "@pypi//setuptools", + "@pypi//wheel", + ], +) + +wheel_sources( + name = "jaxlib_sources", + data_srcs = [ + "//jaxlib", + "//jaxlib:jaxlib_binaries", + "//jaxlib:_jax", + ], + hdr_srcs = [ + "@xla//xla/ffi/api:ffi", + ], + py_srcs = [ + "//jaxlib", + ], + static_srcs = [ + "//jaxlib:README.md", + "LICENSE.txt", + "//jaxlib:setup.py", + "//jaxlib:xla_client.py", + ], + symlink_data_srcs = [ + "//jaxlib", + ], +) jax_wheel( name = "jaxlib_wheel", no_abi = False, - wheel_binary = ":build_wheel", + source_files = [":jaxlib_sources"], + wheel_binary = ":build_wheel_tool", wheel_name = "jaxlib", ) -py_import( - name = "jaxlib_py_import", - wheel = ":jaxlib_wheel", -) - jax_wheel( name = "jaxlib_wheel_editable", editable = True, - wheel_binary = ":build_wheel", + source_files = [":jaxlib_sources"], + wheel_binary = ":build_wheel_tool", wheel_name = "jaxlib", ) +# JAX plugin wheel targets. 
+pytype_strict_library( + name = "version", + srcs = ["//jaxlib:version"], +) + +py_binary( + name = "build_gpu_kernels_wheel_tool", + srcs = ["build_gpu_kernels_wheel.py"], + main = "build_gpu_kernels_wheel.py", + deps = [ + ":build_utils", + "@bazel_tools//tools/python/runfiles", + "@pypi//build", + "@pypi//setuptools", + "@pypi//wheel", + ], +) + +wheel_sources( + name = "jax_plugin_sources", + data_srcs = [ + ] + if_cuda([ + "//jaxlib/cuda:cuda_gpu_support", + "@local_config_cuda//cuda:cuda-nvvm", + "//jaxlib/cuda:cuda_plugin_extension", + "//jaxlib/mosaic/gpu:mosaic_gpu", + ]) + if_rocm([ + "//jaxlib/rocm:rocm_gpu_support", + "//jaxlib/rocm:rocm_plugin_extension", + ]), + py_srcs = [":version"] + if_cuda([ + "//jaxlib/cuda:cuda_gpu_support", + "//jaxlib/mosaic/gpu:mosaic_gpu", + ]) + if_rocm([ + "//jaxlib/rocm:rocm_gpu_support", + ]), + static_srcs = [ + "LICENSE.txt", + ] + if_cuda([ + "//jax_plugins/cuda:plugin_pyproject.toml", + "//jax_plugins/cuda:plugin_setup.py", + ]) + if_rocm([ + "//jax_plugins/rocm:plugin_pyproject.toml", + "//jax_plugins/rocm:plugin_setup.py", + ]), +) + jax_wheel( name = "jax_cuda_plugin_wheel", enable_cuda = True, no_abi = False, # TODO(b/371217563) May use hermetic cuda version here. platform_version = "12", - wheel_binary = ":build_gpu_kernels_wheel", + source_files = [":jax_plugin_sources"], + wheel_binary = ":build_gpu_kernels_wheel_tool", wheel_name = "jax_cuda12_plugin", ) -py_import( - name = "jax_cuda_plugin_py_import", - wheel = ":jax_cuda_plugin_wheel", - wheel_deps = if_cuda(NVIDIA_WHEELS_DEPS), -) - jax_wheel( name = "jax_cuda_plugin_wheel_editable", editable = True, enable_cuda = True, # TODO(b/371217563) May use hermetic cuda version here. platform_version = "12", - wheel_binary = ":build_gpu_kernels_wheel", + source_files = [":jax_plugin_sources"], + wheel_binary = ":build_gpu_kernels_wheel_tool", wheel_name = "jax_cuda12_plugin", ) @@ -294,7 +353,8 @@ jax_wheel( enable_rocm = True, no_abi = False, platform_version = "60", - wheel_binary = ":build_gpu_kernels_wheel", + source_files = [":jax_plugin_sources"], + wheel_binary = ":build_gpu_kernels_wheel_tool", wheel_name = "jax_rocm60_plugin", ) @@ -303,33 +363,75 @@ jax_wheel( editable = True, enable_rocm = True, platform_version = "60", - wheel_binary = ":build_gpu_kernels_wheel", + source_files = [":jax_plugin_sources"], + wheel_binary = ":build_gpu_kernels_wheel_tool", wheel_name = "jax_rocm60_plugin", ) +# JAX PJRT wheel targets. 
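The plugin wheel targets above hand the same kind of source list to :build_gpu_kernels_wheel_tool, which resolves each entry with build_utils.create_wheel_sources_map (added near the end of this patch) instead of runfiles. A simplified, hypothetical Python restatement of that mapping with illustrative paths; sources_map is not a real function in the tree:

def sources_map(wheel_sources, root_packages):
    # Key each Bazel output / external path by its package-relative form, e.g.
    # "bazel-out/k8-opt/bin/jaxlib/version.py" -> key "jaxlib/version.py" and
    # "external/xla/xla/ffi/api/c_api.h" -> key "xla/xla/ffi/api/c_api.h".
    mapping = {}
    for source in wheel_sources or []:
        for package in root_packages:
            if source.startswith(package + "/"):
                mapping[source] = source
            else:
                idx = source.find("/" + package + "/")
                if idx >= 0:
                    mapping[source[idx + 1:]] = source
    return mapping

print(sources_map(
    ["bazel-out/k8-opt/bin/jaxlib/version.py",
     "external/xla/xla/ffi/api/c_api.h"],
    root_packages=["jaxlib", "xla"]))
# {'jaxlib/version.py': 'bazel-out/k8-opt/bin/jaxlib/version.py',
#  'xla/xla/ffi/api/c_api.h': 'external/xla/xla/ffi/api/c_api.h'}
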
+ +py_binary( + name = "build_gpu_plugin_wheel_tool", + srcs = ["build_gpu_plugin_wheel.py"], + main = "build_gpu_plugin_wheel.py", + deps = [ + ":build_utils", + "@bazel_tools//tools/python/runfiles", + "@pypi//build", + "@pypi//setuptools", + "@pypi//wheel", + ], +) + +wheel_sources( + name = "jax_pjrt_sources", + data_srcs = if_cuda([ + "//jax_plugins/cuda:cuda_plugin", + "//jaxlib/cuda:cuda_gpu_support", + "@local_config_cuda//cuda:cuda-nvvm", + ]) + if_rocm([ + "//jax_plugins/rocm:rocm_plugin", + "//jaxlib/rocm:rocm_gpu_support", + ]), + py_srcs = [ + ":version", + ] + if_cuda([ + "//jaxlib/cuda:cuda_gpu_support", + ]) + if_rocm([ + "//jaxlib/rocm:rocm_gpu_support", + ]), + static_srcs = [ + "LICENSE.txt", + ] + if_cuda([ + "//jax_plugins/cuda:pyproject.toml", + "//jax_plugins/cuda:setup.py", + "//jax_plugins/cuda:__init__.py", + ]) + if_rocm([ + "//jax_plugins/rocm:pyproject.toml", + "//jax_plugins/rocm:setup.py", + "//jax_plugins/rocm:__init__.py", + ]), +) + jax_wheel( name = "jax_cuda_pjrt_wheel", enable_cuda = True, no_abi = True, # TODO(b/371217563) May use hermetic cuda version here. platform_version = "12", - wheel_binary = ":build_gpu_plugin_wheel", + source_files = [":jax_pjrt_sources"], + wheel_binary = ":build_gpu_plugin_wheel_tool", wheel_name = "jax_cuda12_pjrt", ) -py_import( - name = "jax_cuda_pjrt_py_import", - wheel = ":jax_cuda_pjrt_wheel", - wheel_deps = if_cuda(NVIDIA_WHEELS_DEPS), -) - jax_wheel( name = "jax_cuda_pjrt_wheel_editable", editable = True, enable_cuda = True, # TODO(b/371217563) May use hermetic cuda version here. platform_version = "12", - wheel_binary = ":build_gpu_plugin_wheel", + source_files = [":jax_pjrt_sources"], + wheel_binary = ":build_gpu_plugin_wheel_tool", wheel_name = "jax_cuda12_pjrt", ) @@ -338,7 +440,8 @@ jax_wheel( enable_rocm = True, no_abi = True, platform_version = "60", - wheel_binary = ":build_gpu_plugin_wheel", + source_files = [":jax_pjrt_sources"], + wheel_binary = ":build_gpu_plugin_wheel_tool", wheel_name = "jax_rocm60_pjrt", ) @@ -347,10 +450,76 @@ jax_wheel( editable = True, enable_rocm = True, platform_version = "60", - wheel_binary = ":build_gpu_plugin_wheel", + source_files = [":jax_pjrt_sources"], + wheel_binary = ":build_gpu_plugin_wheel_tool", wheel_name = "jax_rocm60_pjrt", ) +# Py_import targets. +filegroup( + name = "nvidia_wheel_deps", + srcs = [ + "@pypi_nvidia_cublas_cu12//:whl", + "@pypi_nvidia_cuda_cupti_cu12//:whl", + "@pypi_nvidia_cuda_nvcc_cu12//:whl", + "@pypi_nvidia_cuda_nvrtc_cu12//:whl", + "@pypi_nvidia_cuda_runtime_cu12//:whl", + "@pypi_nvidia_cudnn_cu12//:whl", + "@pypi_nvidia_cufft_cu12//:whl", + "@pypi_nvidia_cusolver_cu12//:whl", + "@pypi_nvidia_cusparse_cu12//:whl", + "@pypi_nvidia_nccl_cu12//:whl", + "@pypi_nvidia_nvjitlink_cu12//:whl", + "@pypi_nvidia_nvshmem_cu12//:whl", + ], +) + +# The flag configures whether to add the pypi NVIDIA CUDA deps to py_import. 
+bool_flag( + name = "add_pypi_cuda_wheel_deps", + build_setting_default = True, +) + +config_setting( + name = "pypi_cuda_wheel_deps", + flag_values = { + ":add_pypi_cuda_wheel_deps": "True", + "@local_config_cuda//:enable_cuda": "True", + }, +) + +py_import( + name = "jaxlib_py_import", + wheel = ":jaxlib_wheel", +) + +py_import( + name = "jax_cuda_plugin_py_import", + wheel = ":jax_cuda_plugin_wheel", + wheel_deps = if_pypi_cuda_wheel_deps([":nvidia_wheel_deps"]), +) + +py_import( + name = "jax_cuda_pjrt_py_import", + wheel = ":jax_cuda_pjrt_wheel", + wheel_deps = if_pypi_cuda_wheel_deps([":nvidia_wheel_deps"]), +) + +# The targets below are used for GPU tests with `--//jax:build_jaxlib=false`. +py_import( + name = "pypi_jax_cuda_plugin_with_cuda_deps", + wheel = "@pypi_jax_cuda12_plugin//:whl", + wheel_deps = if_pypi_cuda_wheel_deps([":nvidia_wheel_deps"]), +) + +py_import( + name = "pypi_jax_cuda_pjrt_with_cuda_deps", + wheel = "@pypi_jax_cuda12_pjrt//:whl", + wheel_deps = if_pypi_cuda_wheel_deps([":nvidia_wheel_deps"]), +) + +# Wheel tests. + AARCH64_MANYLINUX_TAG = "_".join(PLATFORM_TAGS_DICT[("Linux", "aarch64")]) PPC64LE_MANYLINUX_TAG = "_".join(PLATFORM_TAGS_DICT[("Linux", "ppc64le")]) @@ -389,3 +558,48 @@ verify_manylinux_compliance_test( wheel = ":jax_cuda_pjrt_wheel", x86_64_compliance_tag = X86_64_MANYLINUX_TAG, ) + +pytype_test( + name = "jaxlib_wheel_size_test", + srcs = [":wheel_size_test.py"], + args = [ + "--wheel-path=$(location :jaxlib_wheel)", + "--max-size-mib=110", + ], + data = [":jaxlib_wheel"], + main = "wheel_size_test.py", + tags = [ + "manual", + "notap", + ], +) + +pytype_test( + name = "jax_cuda_plugin_wheel_size_test", + srcs = [":wheel_size_test.py"], + args = [ + "--wheel-path=$(location :jax_cuda_plugin_wheel)", + "--max-size-mib=20", + ], + data = [":jax_cuda_plugin_wheel"], + main = "wheel_size_test.py", + tags = [ + "manual", + "notap", + ], +) + +pytype_test( + name = "jax_cuda_pjrt_wheel_size_test", + srcs = [":wheel_size_test.py"], + args = [ + "--wheel-path=$(location :jax_cuda_pjrt_wheel)", + "--max-size-mib=120", + ], + data = [":jax_cuda_pjrt_wheel"], + main = "wheel_size_test.py", + tags = [ + "manual", + "notap", + ], +) diff --git a/jaxlib/tools/build_gpu_kernels_wheel.py b/jaxlib/tools/build_gpu_kernels_wheel.py index 2f81eacbdde4..835a8b72de9f 100644 --- a/jaxlib/tools/build_gpu_kernels_wheel.py +++ b/jaxlib/tools/build_gpu_kernels_wheel.py @@ -26,7 +26,7 @@ from bazel_tools.tools.python.runfiles import runfiles from jaxlib.tools import build_utils -parser = argparse.ArgumentParser() +parser = argparse.ArgumentParser(fromfile_prefix_chars="@") parser.add_argument( "--output_path", default=None, @@ -61,6 +61,9 @@ "--enable-rocm", default=False, help="Should we build with ROCM enabled?") +parser.add_argument( + "--srcs", help="source files for the wheel", action="append" +) args = parser.parse_args() r = runfiles.Create() @@ -79,80 +82,106 @@ def write_setup_cfg(sources_path, cpu): def prepare_wheel_cuda( - sources_path: pathlib.Path, *, cpu, cuda_version + wheel_sources_path: pathlib.Path, *, cpu, cuda_version, wheel_sources ): - """Assembles a source tree for the cuda kernel wheel in `sources_path`.""" - copy_runfiles = functools.partial(build_utils.copy_file, runfiles=r) + """Assembles a source tree for the cuda kernel wheel in `wheel_sources_path`.""" + source_file_prefix = build_utils.get_source_file_prefix(wheel_sources) + wheel_sources_map = build_utils.create_wheel_sources_map( + wheel_sources, + root_packages=[ + "jax_plugins", + 
f"jax_cuda{cuda_version}_plugin", + "jaxlib", + ], + ) + copy_files = functools.partial( + build_utils.copy_file, + runfiles=r, + wheel_sources_map=wheel_sources_map, + ) - copy_runfiles( - "__main__/jax_plugins/cuda/plugin_pyproject.toml", - dst_dir=sources_path, + copy_files( + f"{source_file_prefix}jax_plugins/cuda/plugin_pyproject.toml", + dst_dir=wheel_sources_path, dst_filename="pyproject.toml", ) - copy_runfiles( - "__main__/jax_plugins/cuda/plugin_setup.py", - dst_dir=sources_path, + copy_files( + f"{source_file_prefix}jax_plugins/cuda/plugin_setup.py", + dst_dir=wheel_sources_path, dst_filename="setup.py", ) - build_utils.update_setup_with_cuda_version(sources_path, cuda_version) - write_setup_cfg(sources_path, cpu) + build_utils.update_setup_with_cuda_version(wheel_sources_path, cuda_version) + write_setup_cfg(wheel_sources_path, cpu) - plugin_dir = sources_path / f"jax_cuda{cuda_version}_plugin" - copy_runfiles( + plugin_dir = wheel_sources_path / f"jax_cuda{cuda_version}_plugin" + copy_files( dst_dir=plugin_dir, src_files=[ - f"__main__/jaxlib/cuda/_solver.{pyext}", - f"__main__/jaxlib/cuda/_blas.{pyext}", - f"__main__/jaxlib/cuda/_linalg.{pyext}", - f"__main__/jaxlib/cuda/_prng.{pyext}", - f"__main__/jaxlib/cuda/_rnn.{pyext}", - f"__main__/jaxlib/cuda/_sparse.{pyext}", - f"__main__/jaxlib/cuda/_triton.{pyext}", - f"__main__/jaxlib/cuda/_hybrid.{pyext}", - f"__main__/jaxlib/cuda/_versions.{pyext}", - f"__main__/jaxlib/cuda/cuda_plugin_extension.{pyext}", - f"__main__/jaxlib/mosaic/gpu/_mosaic_gpu_ext.{pyext}", - "__main__/jaxlib/mosaic/gpu/libmosaic_gpu_runtime.so", - "__main__/jaxlib/version.py", + f"{source_file_prefix}jaxlib/cuda/_solver.{pyext}", + f"{source_file_prefix}jaxlib/cuda/_linalg.{pyext}", + f"{source_file_prefix}jaxlib/cuda/_prng.{pyext}", + f"{source_file_prefix}jaxlib/cuda/_rnn.{pyext}", + f"{source_file_prefix}jaxlib/cuda/_sparse.{pyext}", + f"{source_file_prefix}jaxlib/cuda/_triton.{pyext}", + f"{source_file_prefix}jaxlib/cuda/_hybrid.{pyext}", + f"{source_file_prefix}jaxlib/cuda/_versions.{pyext}", + f"{source_file_prefix}jaxlib/cuda/cuda_plugin_extension.{pyext}", + f"{source_file_prefix}jaxlib/mosaic/gpu/_mosaic_gpu_ext.{pyext}", + f"{source_file_prefix}jaxlib/mosaic/gpu/libmosaic_gpu_runtime.so", + f"{source_file_prefix}jaxlib/version.py", ], ) + def prepare_wheel_rocm( - sources_path: pathlib.Path, *, cpu, rocm_version + wheel_sources_path: pathlib.Path, *, cpu, rocm_version, wheel_sources ): - """Assembles a source tree for the rocm kernel wheel in `sources_path`.""" - copy_runfiles = functools.partial(build_utils.copy_file, runfiles=r) + """Assembles a source tree for the rocm kernel wheel in `wheel_sources_path`.""" + source_file_prefix = build_utils.get_source_file_prefix(wheel_sources) + wheel_sources_map = build_utils.create_wheel_sources_map( + wheel_sources, + root_packages=[ + "jax_plugins", + f"jax_rocm{rocm_version}_plugin", + "jaxlib", + ], + ) + copy_files = functools.partial( + build_utils.copy_file, + runfiles=r, + wheel_sources_map=wheel_sources_map, + ) - copy_runfiles( - "__main__/jax_plugins/rocm/plugin_pyproject.toml", - dst_dir=sources_path, + copy_files( + f"{source_file_prefix}jax_plugins/rocm/plugin_pyproject.toml", + dst_dir=wheel_sources_path, dst_filename="pyproject.toml", ) - copy_runfiles( - "__main__/jax_plugins/rocm/plugin_setup.py", - dst_dir=sources_path, + copy_files( + f"{source_file_prefix}jax_plugins/rocm/plugin_setup.py", + dst_dir=wheel_sources_path, dst_filename="setup.py", ) - 
build_utils.update_setup_with_rocm_version(sources_path, rocm_version) - write_setup_cfg(sources_path, cpu) + build_utils.update_setup_with_rocm_version(wheel_sources_path, rocm_version) + write_setup_cfg(wheel_sources_path, cpu) - plugin_dir = sources_path / f"jax_rocm{rocm_version}_plugin" - copy_runfiles( + plugin_dir = wheel_sources_path / f"jax_rocm{rocm_version}_plugin" + copy_files( dst_dir=plugin_dir, src_files=[ - f"__main__/jaxlib/rocm/_blas.{pyext}", - f"__main__/jaxlib/rocm/_linalg.{pyext}", - f"__main__/jaxlib/rocm/_prng.{pyext}", - f"__main__/jaxlib/rocm/_solver.{pyext}", - f"__main__/jaxlib/rocm/_sparse.{pyext}", - f"__main__/jaxlib/rocm/_hybrid.{pyext}", - f"__main__/jaxlib/rocm/_rnn.{pyext}", - f"__main__/jaxlib/rocm/_triton.{pyext}", - f"__main__/jaxlib/rocm/rocm_plugin_extension.{pyext}", - "__main__/jaxlib/version.py", + f"{source_file_prefix}jaxlib/rocm/_linalg.{pyext}", + f"{source_file_prefix}jaxlib/rocm/_prng.{pyext}", + f"{source_file_prefix}jaxlib/rocm/_solver.{pyext}", + f"{source_file_prefix}jaxlib/rocm/_sparse.{pyext}", + f"{source_file_prefix}jaxlib/rocm/_hybrid.{pyext}", + f"{source_file_prefix}jaxlib/rocm/_rnn.{pyext}", + f"{source_file_prefix}jaxlib/rocm/_triton.{pyext}", + f"{source_file_prefix}jaxlib/rocm/rocm_plugin_extension.{pyext}", + f"{source_file_prefix}jaxlib/version.py", ], ) + # Build wheel for cuda kernels if args.enable_rocm: tmpdir = tempfile.TemporaryDirectory(prefix="jax_rocm_plugin") @@ -163,12 +192,18 @@ def prepare_wheel_rocm( os.makedirs(args.output_path, exist_ok=True) if args.enable_cuda: prepare_wheel_cuda( - pathlib.Path(sources_path), cpu=args.cpu, cuda_version=args.platform_version + pathlib.Path(sources_path), + cpu=args.cpu, + cuda_version=args.platform_version, + wheel_sources=args.srcs, ) package_name = f"jax cuda{args.platform_version} plugin" elif args.enable_rocm: prepare_wheel_rocm( - pathlib.Path(sources_path), cpu=args.cpu, rocm_version=args.platform_version + pathlib.Path(sources_path), + cpu=args.cpu, + rocm_version=args.platform_version, + wheel_sources=args.srcs, ) package_name = f"jax rocm{args.platform_version} plugin" if args.editable: diff --git a/jaxlib/tools/build_gpu_plugin_wheel.py b/jaxlib/tools/build_gpu_plugin_wheel.py index 667807b51197..68e08d89338e 100644 --- a/jaxlib/tools/build_gpu_plugin_wheel.py +++ b/jaxlib/tools/build_gpu_plugin_wheel.py @@ -26,7 +26,7 @@ from bazel_tools.tools.python.runfiles import runfiles from jaxlib.tools import build_utils -parser = argparse.ArgumentParser() +parser = argparse.ArgumentParser(fromfile_prefix_chars="@") parser.add_argument( "--sources_path", default=None, @@ -67,6 +67,9 @@ "--enable-rocm", default=False, help="Should we build with ROCM enabled?") +parser.add_argument( + "--srcs", help="source files for the wheel", action="append" +) args = parser.parse_args() r = runfiles.Create() @@ -81,62 +84,81 @@ def write_setup_cfg(sources_path, cpu): [bdist_wheel] plat_name={tag} -python-tag=py3 +python_tag=py3 """ ) +def prepare_cuda_plugin_wheel( + wheel_sources_path: pathlib.Path, *, cpu, cuda_version, wheel_sources +): + """Assembles a source tree for the wheel in `wheel_sources_path`""" + source_file_prefix = build_utils.get_source_file_prefix(wheel_sources) + wheel_sources_map = build_utils.create_wheel_sources_map( + wheel_sources, root_packages=["jax_plugins", "jaxlib"] + ) + copy_files = functools.partial( + build_utils.copy_file, + runfiles=r, + wheel_sources_map=wheel_sources_map, + ) -def prepare_cuda_plugin_wheel(sources_path: pathlib.Path, *, cpu, 
cuda_version): - """Assembles a source tree for the wheel in `sources_path`.""" - copy_runfiles = functools.partial(build_utils.copy_file, runfiles=r) - - plugin_dir = sources_path / "jax_plugins" / f"xla_cuda{cuda_version}" - copy_runfiles( - dst_dir=sources_path, + plugin_dir = wheel_sources_path / "jax_plugins" / f"xla_cuda{cuda_version}" + copy_files( + dst_dir=wheel_sources_path, src_files=[ - "__main__/jax_plugins/cuda/pyproject.toml", - "__main__/jax_plugins/cuda/setup.py", + f"{source_file_prefix}jax_plugins/cuda/pyproject.toml", + f"{source_file_prefix}jax_plugins/cuda/setup.py", ], ) - build_utils.update_setup_with_cuda_version(sources_path, cuda_version) - write_setup_cfg(sources_path, cpu) - copy_runfiles( + build_utils.update_setup_with_cuda_version(wheel_sources_path, cuda_version) + write_setup_cfg(wheel_sources_path, cpu) + copy_files( dst_dir=plugin_dir, src_files=[ - "__main__/jax_plugins/cuda/__init__.py", - "__main__/jaxlib/version.py", + f"{source_file_prefix}jax_plugins/cuda/__init__.py", + f"{source_file_prefix}jaxlib/version.py", ], ) - copy_runfiles( - "__main__/jaxlib/tools/pjrt_c_api_gpu_plugin.so", + copy_files( + f"{source_file_prefix}jax_plugins/cuda/pjrt_c_api_gpu_plugin.so", dst_dir=plugin_dir, dst_filename="xla_cuda_plugin.so", ) -def prepare_rocm_plugin_wheel(sources_path: pathlib.Path, *, cpu, rocm_version): - """Assembles a source tree for the ROCm wheel in `sources_path`.""" - copy_runfiles = functools.partial(build_utils.copy_file, runfiles=r) +def prepare_rocm_plugin_wheel( + wheel_sources_path: pathlib.Path, *, cpu, rocm_version, wheel_sources +): + """Assembles a source tree for the ROCm wheel in `wheel_sources_path`.""" + source_file_prefix = build_utils.get_source_file_prefix(wheel_sources) + wheel_sources_map = build_utils.create_wheel_sources_map( + wheel_sources, root_packages=["jax_plugins", "jaxlib"] + ) + copy_files = functools.partial( + build_utils.copy_file, + runfiles=r, + wheel_sources_map=wheel_sources_map, + ) - plugin_dir = sources_path / "jax_plugins" / f"xla_rocm{rocm_version}" - copy_runfiles( - dst_dir=sources_path, - src_files=[ - "__main__/jax_plugins/rocm/pyproject.toml", - "__main__/jax_plugins/rocm/setup.py", + plugin_dir = wheel_sources_path / "jax_plugins" / f"xla_rocm{rocm_version}" + copy_files( + dst_dir=wheel_sources_path, + src_files=[ + f"{source_file_prefix}jax_plugins/rocm/pyproject.toml", + f"{source_file_prefix}jax_plugins/rocm/setup.py", ], ) - build_utils.update_setup_with_rocm_version(sources_path, rocm_version) - write_setup_cfg(sources_path, cpu) - copy_runfiles( + build_utils.update_setup_with_rocm_version(wheel_sources_path, rocm_version) + write_setup_cfg(wheel_sources_path, cpu) + copy_files( dst_dir=plugin_dir, src_files=[ - "__main__/jax_plugins/rocm/__init__.py", - "__main__/jaxlib/version.py", + f"{source_file_prefix}jax_plugins/rocm/__init__.py", + f"{source_file_prefix}jaxlib/version.py", ], ) - copy_runfiles( - "__main__/jaxlib/tools/pjrt_c_api_gpu_plugin.so", + copy_files( + f"{source_file_prefix}jax_plugins/rocm/pjrt_c_api_gpu_plugin.so", dst_dir=plugin_dir, dst_filename="xla_rocm_plugin.so", ) @@ -153,12 +175,18 @@ def prepare_rocm_plugin_wheel(sources_path: pathlib.Path, *, cpu, rocm_version): if args.enable_cuda: prepare_cuda_plugin_wheel( - pathlib.Path(sources_path), cpu=args.cpu, cuda_version=args.platform_version + pathlib.Path(sources_path), + cpu=args.cpu, + cuda_version=args.platform_version, + wheel_sources=args.srcs, ) package_name = "jax cuda plugin" elif args.enable_rocm: 
prepare_rocm_plugin_wheel( - pathlib.Path(sources_path), cpu=args.cpu, rocm_version=args.platform_version + pathlib.Path(sources_path), + cpu=args.cpu, + rocm_version=args.platform_version, + wheel_sources=args.srcs, ) package_name = "jax rocm plugin" else: diff --git a/jaxlib/tools/build_utils.py b/jaxlib/tools/build_utils.py index 4c50cff16743..bf64a36ef0b7 100644 --- a/jaxlib/tools/build_utils.py +++ b/jaxlib/tools/build_utils.py @@ -27,29 +27,65 @@ from jaxlib.tools import platform_tags +MAIN_RUNFILES_DIR = "__main__/" + + def is_windows() -> bool: return sys.platform.startswith("win32") +def create_wheel_sources_map(wheel_sources, root_packages): + """Returns a map of paths relative to the root package to the full paths.""" + wheel_sources_map = {} + if not wheel_sources: + return wheel_sources_map + for source in wheel_sources: + for package in root_packages: + if source.startswith("{}/".format(package)): + wheel_sources_map[source] = source + continue + root_package_ind = source.find("/{}/".format(package)) + if root_package_ind >= 0: + wheel_sources_map[source[root_package_ind + 1:]] = source + return wheel_sources_map + + +# TODO(ybaturina): remove the method when we switch to the new wheel build rules +# and the runfiles are not needed. +def get_source_file_prefix(wheel_sources): + return "" if wheel_sources else MAIN_RUNFILES_DIR + + def copy_file( src_files: str | Sequence[str], dst_dir: pathlib.Path, - dst_filename = None, - runfiles = None, + dst_filename=None, + runfiles=None, + wheel_sources_map=None, ) -> None: dst_dir.mkdir(parents=True, exist_ok=True) if isinstance(src_files, str): src_files = [src_files] for src_file in src_files: - src_file_rloc = runfiles.Rlocation(src_file) - if src_file_rloc is None: + if wheel_sources_map: + src_file_loc = wheel_sources_map.get(src_file, None) + # TODO(ybaturina): remove the runfiles part when we switch to the new wheel + # build rules and the runfiles are not needed. + elif runfiles: + src_file_loc = runfiles.Rlocation(src_file) + else: + raise RuntimeError( + "Either runfiles or wheel_sources_map should be provided!" 
+ ) + if src_file_loc is None: raise ValueError(f"Unable to find wheel source file {src_file}") - src_filename = os.path.basename(src_file_rloc) + + src_filename = os.path.basename(src_file_loc) dst_file = os.path.join(dst_dir, dst_filename or src_filename) if is_windows(): - shutil.copyfile(src_file_rloc, dst_file) + shutil.copyfile(src_file_loc, dst_file) else: - shutil.copy(src_file_rloc, dst_file) + shutil.copy(src_file_loc, dst_file) def platform_tag(cpu: str) -> str: @@ -65,6 +101,7 @@ def build_wheel( package_name: str, git_hash: str = "", build_wheel_only: bool = True, + build_source_package_only: bool = False, ) -> None: """Builds a wheel in `output_path` using the source tree in `sources_path`.""" env = dict(os.environ) @@ -78,7 +115,8 @@ def build_wheel( env["USERPROFILE"] = env.get("SYSTEMDRIVE", "C:") subprocess.run( [sys.executable, "-m", "build", "-n"] - + (["-w"] if build_wheel_only else []), + + (["-w"] if build_wheel_only else []) + + (["-s"] if build_source_package_only else []), check=True, cwd=sources_path, env=env, @@ -97,10 +135,10 @@ def build_wheel( sys.stderr.write(" bazel run //build:requirements.update" + f" --repo_env=HERMETIC_PYTHON_VERSION={py_version}\n\n") shutil.copy(wheel, output_path) - if not build_wheel_only: + if build_source_package_only: for dist in glob.glob(os.path.join(sources_path, "dist", "*.tar.gz")): output_file = os.path.join(output_path, os.path.basename(dist)) - sys.stderr.write(f"Output source distribution: {output_file}\n\n") + sys.stderr.write(f"Output source package: {output_file}\n\n") shutil.copy(dist, output_path) diff --git a/jaxlib/tools/build_wheel.py b/jaxlib/tools/build_wheel.py index 8632468acb97..cf1be5e5a8ed 100644 --- a/jaxlib/tools/build_wheel.py +++ b/jaxlib/tools/build_wheel.py @@ -29,7 +29,7 @@ from bazel_tools.tools.python.runfiles import runfiles from jaxlib.tools import build_utils -parser = argparse.ArgumentParser() +parser = argparse.ArgumentParser(fromfile_prefix_chars="@") parser.add_argument( "--sources_path", default=None, @@ -56,27 +56,38 @@ action="store_true", help="Create an 'editable' jaxlib build instead of a wheel.", ) +parser.add_argument( + "--srcs", help="source files for the wheel", action="append" +) args = parser.parse_args() r = runfiles.Create() - def _is_mac(): return platform.system() == "Darwin" +soext = "dll" if build_utils.is_windows() else ("dylib" if _is_mac() else "so") pyext = "pyd" if build_utils.is_windows() else "so" -def exists(src_file): - path = r.Rlocation(src_file) - if path is None: - return False - return os.path.exists(path) +def _get_file_path(src_file, runfiles=None, wheel_sources_map=None): + if wheel_sources_map: + return wheel_sources_map.get( + src_file.replace(build_utils.MAIN_RUNFILES_DIR, ""), None + ) + # TODO(ybaturina): remove the runfiles part when we switch to the new wheel + # build rules and the runfiles are not needed. 
+ elif runfiles: + return runfiles.Rlocation(src_file) + else: + raise RuntimeError("Either runfiles or wheel_sources should be provided!") -def patch_copy_mlir_import(src_file, dst_dir): - src_file = r.Rlocation(src_file) +def patch_copy_mlir_import( + src_file, dst_dir, runfiles=None, wheel_sources_map=None +): + src_file = _get_file_path(src_file, runfiles, wheel_sources_map) src_filename = os.path.basename(src_file) with open(src_file) as f: src = f.read() @@ -91,40 +102,10 @@ def patch_copy_mlir_import(src_file, dst_dir): f.write(replaced) -_XLA_EXTENSION_STUBS = [ - "__init__.pyi", - "guard_lib.pyi", - "ifrt_programs.pyi", - "ifrt_proxy.pyi", - "jax_jit.pyi", - "ops.pyi", - "pmap_lib.pyi", - "profiler.pyi", - "pytree.pyi", - "transfer_guard_lib.pyi", -] -_OPTIONAL_XLA_EXTENSION_STUBS = [] - - -def patch_copy_xla_extension_stubs(dst_dir): - xla_extension_dir = os.path.join(dst_dir, "xla_extension") - os.makedirs(xla_extension_dir) - for stub_name in _XLA_EXTENSION_STUBS: - stub_path = r.Rlocation("xla/xla/python/xla_extension/" + stub_name) - stub_path = str(stub_path) # Make pytype accept os.path.exists(stub_path). - if stub_name in _OPTIONAL_XLA_EXTENSION_STUBS and not os.path.exists(stub_path): - continue - with open(stub_path) as f: - src = f.read() - src = src.replace( - "from xla.python import xla_extension", "from .. import xla_extension" - ) - with open(os.path.join(xla_extension_dir, stub_name), "w") as f: - f.write(src) - - -def verify_mac_libraries_dont_reference_chkstack(): - """Verifies that xla_extension.so doesn't depend on ____chkstk_darwin. +def verify_mac_libraries_dont_reference_chkstack( + runfiles=None, wheel_sources_map=None +): + """Verifies that _jax.so doesn't depend on ____chkstk_darwin. We don't entirely know why this happens, but in some build environments we seem to target the wrong Mac OS version. @@ -134,8 +115,11 @@ def verify_mac_libraries_dont_reference_chkstack(): """ if not _is_mac(): return + file_path = _get_file_path( + f"__main__/jaxlib/_jax.{pyext}", runfiles, wheel_sources_map + ) nm = subprocess.run( - ["nm", "-g", r.Rlocation("xla/xla/python/xla_extension.so")], + ["nm", "-g", file_path], capture_output=True, text=True, check=False, @@ -162,214 +146,251 @@ def write_setup_cfg(sources_path, cpu): ) -def prepare_wheel(sources_path: pathlib.Path, *, cpu): - """Assembles a source tree for the wheel in `sources_path`.""" - copy_runfiles = functools.partial(build_utils.copy_file, runfiles=r) +def prepare_wheel(wheel_sources_path: pathlib.Path, *, cpu, wheel_sources): + """Assembles a source tree for the wheel in `wheel_sources_path`.""" + source_file_prefix = build_utils.get_source_file_prefix(wheel_sources) + # The wheel sources provided by the transitive rules might have different path + # prefixes, so we need to create a map of paths relative to the root package + # to the full paths. + # E.g. 
if we have the wheel sources paths like + # bazel-out/k8-opt/bin/jaxlib/mlir/_mlir_libs/register_jax_dialects.py and + # external/xla/xla/ffi/api/c_api.h, the resulting map will be + # {'jaxlib/mlir/_mlir_libs/register_jax_dialects.py': + # 'bazel-out/k8-opt/bin/jaxlib/mlir/_mlir_libs/register_jax_dialects.py', + # 'xla/ffi/api/c_api.h': 'external/xla/xla/ffi/api/c_api.h'} + wheel_sources_map = build_utils.create_wheel_sources_map( + wheel_sources, root_packages=["jaxlib", "xla"] + ) + copy_files = functools.partial( + build_utils.copy_file, + runfiles=r, + wheel_sources_map=wheel_sources_map, + ) - verify_mac_libraries_dont_reference_chkstack() - copy_runfiles( - dst_dir=sources_path, + verify_mac_libraries_dont_reference_chkstack( + runfiles=r, wheel_sources_map=wheel_sources_map + ) + copy_files( + dst_dir=wheel_sources_path, src_files=[ - "__main__/jaxlib/tools/LICENSE.txt", - "__main__/jaxlib/README.md", - "__main__/jaxlib/setup.py", + f"{source_file_prefix}jaxlib/tools/LICENSE.txt", + f"{source_file_prefix}jaxlib/README.md", + f"{source_file_prefix}jaxlib/setup.py", ], ) - write_setup_cfg(sources_path, cpu) + write_setup_cfg(wheel_sources_path, cpu) - jaxlib_dir = sources_path / "jaxlib" - copy_runfiles( - "__main__/jaxlib/init.py", dst_dir=jaxlib_dir, dst_filename="__init__.py" + jaxlib_dir = wheel_sources_path / "jaxlib" + copy_files( + f"{source_file_prefix}jaxlib/init.py", + dst_dir=jaxlib_dir, + dst_filename="__init__.py", ) - copy_runfiles( + copy_files( dst_dir=jaxlib_dir, src_files=[ - f"__main__/jaxlib/cpu_feature_guard.{pyext}", - f"__main__/jaxlib/utils.{pyext}", - "__main__/jaxlib/lapack.py", - "__main__/jaxlib/hlo_helpers.py", - "__main__/jaxlib/gpu_prng.py", - "__main__/jaxlib/gpu_linalg.py", - "__main__/jaxlib/gpu_rnn.py", - "__main__/jaxlib/gpu_triton.py", - "__main__/jaxlib/gpu_common_utils.py", - "__main__/jaxlib/gpu_solver.py", - "__main__/jaxlib/gpu_sparse.py", - "__main__/jaxlib/plugin_support.py", - "__main__/jaxlib/version.py", - "__main__/jaxlib/xla_client.py", - f"xla/xla/python/xla_extension.{pyext}", + f"{source_file_prefix}jaxlib/cpu_feature_guard.{pyext}", + f"{source_file_prefix}jaxlib/cpu_sparse.py", + f"{source_file_prefix}jaxlib/utils.{pyext}", + f"{source_file_prefix}jaxlib/jax_common.dll" + if build_utils.is_windows() + else f"{source_file_prefix}jaxlib/libjax_common.{soext}", + f"{source_file_prefix}jaxlib/lapack.py", + f"{source_file_prefix}jaxlib/hlo_helpers.py", + f"{source_file_prefix}jaxlib/gpu_prng.py", + f"{source_file_prefix}jaxlib/gpu_linalg.py", + f"{source_file_prefix}jaxlib/gpu_rnn.py", + f"{source_file_prefix}jaxlib/gpu_triton.py", + f"{source_file_prefix}jaxlib/gpu_common_utils.py", + f"{source_file_prefix}jaxlib/gpu_solver.py", + f"{source_file_prefix}jaxlib/gpu_sparse.py", + f"{source_file_prefix}jaxlib/plugin_support.py", + f"{source_file_prefix}jaxlib/_pretty_printer.{pyext}", + f"{source_file_prefix}jaxlib/version.py", + f"{source_file_prefix}jaxlib/xla_client.py", + f"{source_file_prefix}jaxlib/weakref_lru_cache.{pyext}", + f"{source_file_prefix}jaxlib/weakref_lru_cache.pyi", + f"{source_file_prefix}jaxlib/_jax.{pyext}", + f"{source_file_prefix}jaxlib/_profiler.{pyext}", ], ) # This file is required by PEP-561. It marks jaxlib as package containing # type stubs. 
with open(jaxlib_dir / "py.typed", "w"): pass - patch_copy_xla_extension_stubs(jaxlib_dir) - copy_runfiles( + copy_files( dst_dir=jaxlib_dir / "cpu", src_files=[ - f"__main__/jaxlib/cpu/_lapack.{pyext}", + f"{source_file_prefix}jaxlib/cpu/_lapack.{pyext}", + f"{source_file_prefix}jaxlib/cpu/_sparse.{pyext}", ], ) mosaic_python_dir = jaxlib_dir / "mosaic" / "python" - copy_runfiles( + copy_files( dst_dir=mosaic_python_dir, src_files=[ - "__main__/jaxlib/mosaic/python/layout_defs.py", - "__main__/jaxlib/mosaic/python/mosaic_gpu.py", - "__main__/jaxlib/mosaic/python/tpu.py", + f"{source_file_prefix}jaxlib/mosaic/python/layout_defs.py", + f"{source_file_prefix}jaxlib/mosaic/python/mosaic_gpu.py", + f"{source_file_prefix}jaxlib/mosaic/python/tpu.py", ], ) # TODO (sharadmv,skyewm): can we avoid patching this file? patch_copy_mlir_import( - "__main__/jaxlib/mosaic/python/_tpu_gen.py", dst_dir=mosaic_python_dir + f"{source_file_prefix}jaxlib/mosaic/python/_tpu_gen.py", + dst_dir=mosaic_python_dir, + runfiles=r, + wheel_sources_map=wheel_sources_map, ) mosaic_gpu_dir = jaxlib_dir / "mosaic" / "dialect" / "gpu" os.makedirs(mosaic_gpu_dir) patch_copy_mlir_import( - "__main__/jaxlib/mosaic/dialect/gpu/_mosaic_gpu_gen_ops.py", + f"{source_file_prefix}jaxlib/mosaic/dialect/gpu/_mosaic_gpu_gen_ops.py", dst_dir=mosaic_gpu_dir, + runfiles=r, + wheel_sources_map=wheel_sources_map, ) patch_copy_mlir_import( - "__main__/jaxlib/mosaic/dialect/gpu/_mosaic_gpu_gen_enums.py", + f"{source_file_prefix}jaxlib/mosaic/dialect/gpu/_mosaic_gpu_gen_enums.py", dst_dir=mosaic_gpu_dir, + runfiles=r, + wheel_sources_map=wheel_sources_map, ) - copy_runfiles( + copy_files( dst_dir=jaxlib_dir / "mlir", src_files=[ - "__main__/jaxlib/mlir/ir.py", - "__main__/jaxlib/mlir/ir.pyi", - "__main__/jaxlib/mlir/passmanager.py", - "__main__/jaxlib/mlir/passmanager.pyi", + f"{source_file_prefix}jaxlib/mlir/ir.py", + f"{source_file_prefix}jaxlib/mlir/ir.pyi", + f"{source_file_prefix}jaxlib/mlir/passmanager.py", + f"{source_file_prefix}jaxlib/mlir/passmanager.pyi", ], ) - copy_runfiles( + copy_files( dst_dir=jaxlib_dir / "mlir" / "dialects", src_files=[ - "__main__/jaxlib/mlir/dialects/_arith_enum_gen.py", - "__main__/jaxlib/mlir/dialects/_arith_ops_gen.py", - "__main__/jaxlib/mlir/dialects/_builtin_ops_gen.py", - "__main__/jaxlib/mlir/dialects/_chlo_ops_gen.py", - "__main__/jaxlib/mlir/dialects/_func_ops_gen.py", - "__main__/jaxlib/mlir/dialects/_math_ops_gen.py", - "__main__/jaxlib/mlir/dialects/_memref_ops_gen.py", - "__main__/jaxlib/mlir/dialects/_mhlo_ops_gen.py", - "__main__/jaxlib/mlir/dialects/_ods_common.py", - "__main__/jaxlib/mlir/dialects/_scf_ops_gen.py", - "__main__/jaxlib/mlir/dialects/_sdy_ops_gen.py", - "__main__/jaxlib/mlir/dialects/_sparse_tensor_enum_gen.py", - "__main__/jaxlib/mlir/dialects/_sparse_tensor_ops_gen.py", - "__main__/jaxlib/mlir/dialects/_stablehlo_ops_gen.py", - "__main__/jaxlib/mlir/dialects/_vector_enum_gen.py", - "__main__/jaxlib/mlir/dialects/_vector_ops_gen.py", - "__main__/jaxlib/mlir/dialects/_gpu_enum_gen.py", - "__main__/jaxlib/mlir/dialects/_gpu_ops_gen.py", - "__main__/jaxlib/mlir/dialects/_nvgpu_enum_gen.py", - "__main__/jaxlib/mlir/dialects/_nvgpu_ops_gen.py", - "__main__/jaxlib/mlir/dialects/_nvvm_enum_gen.py", - "__main__/jaxlib/mlir/dialects/_nvvm_ops_gen.py", - "__main__/jaxlib/mlir/dialects/_llvm_enum_gen.py", - "__main__/jaxlib/mlir/dialects/_llvm_ops_gen.py", - "__main__/jaxlib/mlir/dialects/arith.py", - "__main__/jaxlib/mlir/dialects/builtin.py", - 
"__main__/jaxlib/mlir/dialects/chlo.py", - "__main__/jaxlib/mlir/dialects/func.py", - "__main__/jaxlib/mlir/dialects/math.py", - "__main__/jaxlib/mlir/dialects/memref.py", - "__main__/jaxlib/mlir/dialects/mhlo.py", - "__main__/jaxlib/mlir/dialects/scf.py", - "__main__/jaxlib/mlir/dialects/sdy.py", - "__main__/jaxlib/mlir/dialects/sparse_tensor.py", - "__main__/jaxlib/mlir/dialects/stablehlo.py", - "__main__/jaxlib/mlir/dialects/vector.py", - "__main__/jaxlib/mlir/dialects/nvgpu.py", - "__main__/jaxlib/mlir/dialects/nvvm.py", - "__main__/jaxlib/mlir/dialects/llvm.py", + f"{source_file_prefix}jaxlib/mlir/dialects/_arith_enum_gen.py", + f"{source_file_prefix}jaxlib/mlir/dialects/_arith_ops_gen.py", + f"{source_file_prefix}jaxlib/mlir/dialects/_builtin_ops_gen.py", + f"{source_file_prefix}jaxlib/mlir/dialects/_cf_ops_gen.py", + f"{source_file_prefix}jaxlib/mlir/dialects/_chlo_ops_gen.py", + f"{source_file_prefix}jaxlib/mlir/dialects/_func_ops_gen.py", + f"{source_file_prefix}jaxlib/mlir/dialects/_math_ops_gen.py", + f"{source_file_prefix}jaxlib/mlir/dialects/_memref_ops_gen.py", + f"{source_file_prefix}jaxlib/mlir/dialects/_mhlo_ops_gen.py", + f"{source_file_prefix}jaxlib/mlir/dialects/_ods_common.py", + f"{source_file_prefix}jaxlib/mlir/dialects/_scf_ops_gen.py", + f"{source_file_prefix}jaxlib/mlir/dialects/_sdy_enums_gen.py", + f"{source_file_prefix}jaxlib/mlir/dialects/_sdy_ops_gen.py", + f"{source_file_prefix}jaxlib/mlir/dialects/_sparse_tensor_enum_gen.py", + f"{source_file_prefix}jaxlib/mlir/dialects/_sparse_tensor_ops_gen.py", + f"{source_file_prefix}jaxlib/mlir/dialects/_stablehlo_ops_gen.py", + f"{source_file_prefix}jaxlib/mlir/dialects/_vector_enum_gen.py", + f"{source_file_prefix}jaxlib/mlir/dialects/_vector_ops_gen.py", + f"{source_file_prefix}jaxlib/mlir/dialects/_gpu_enum_gen.py", + f"{source_file_prefix}jaxlib/mlir/dialects/_gpu_ops_gen.py", + f"{source_file_prefix}jaxlib/mlir/dialects/_nvgpu_enum_gen.py", + f"{source_file_prefix}jaxlib/mlir/dialects/_nvgpu_ops_gen.py", + f"{source_file_prefix}jaxlib/mlir/dialects/_nvvm_enum_gen.py", + f"{source_file_prefix}jaxlib/mlir/dialects/_nvvm_ops_gen.py", + f"{source_file_prefix}jaxlib/mlir/dialects/_llvm_enum_gen.py", + f"{source_file_prefix}jaxlib/mlir/dialects/_llvm_ops_gen.py", + f"{source_file_prefix}jaxlib/mlir/dialects/arith.py", + f"{source_file_prefix}jaxlib/mlir/dialects/builtin.py", + f"{source_file_prefix}jaxlib/mlir/dialects/cf.py", + f"{source_file_prefix}jaxlib/mlir/dialects/chlo.py", + f"{source_file_prefix}jaxlib/mlir/dialects/func.py", + f"{source_file_prefix}jaxlib/mlir/dialects/math.py", + f"{source_file_prefix}jaxlib/mlir/dialects/memref.py", + f"{source_file_prefix}jaxlib/mlir/dialects/mhlo.py", + f"{source_file_prefix}jaxlib/mlir/dialects/scf.py", + f"{source_file_prefix}jaxlib/mlir/dialects/sdy.py", + f"{source_file_prefix}jaxlib/mlir/dialects/sparse_tensor.py", + f"{source_file_prefix}jaxlib/mlir/dialects/stablehlo.py", + f"{source_file_prefix}jaxlib/mlir/dialects/vector.py", + f"{source_file_prefix}jaxlib/mlir/dialects/nvgpu.py", + f"{source_file_prefix}jaxlib/mlir/dialects/nvvm.py", + f"{source_file_prefix}jaxlib/mlir/dialects/llvm.py", ], ) - copy_runfiles( + copy_files( dst_dir=jaxlib_dir / "mlir" / "extras", src_files=[ - "__main__/jaxlib/mlir/extras/meta.py", + f"{source_file_prefix}jaxlib/mlir/extras/meta.py", ], ) - copy_runfiles( + copy_files( dst_dir=jaxlib_dir / "mlir" / "dialects" / "gpu", src_files=[ - "__main__/jaxlib/mlir/dialects/gpu/__init__.py", + 
f"{source_file_prefix}jaxlib/mlir/dialects/gpu/__init__.py", ], ) - copy_runfiles( + copy_files( dst_dir=jaxlib_dir / "mlir" / "dialects" / "gpu" / "passes", src_files=[ - "__main__/jaxlib/mlir/dialects/gpu/passes/__init__.py", + f"{source_file_prefix}jaxlib/mlir/dialects/gpu/passes/__init__.py", ], ) - - if build_utils.is_windows(): - capi_so = "__main__/jaxlib/mlir/_mlir_libs/jaxlib_mlir_capi.dll" - else: - so_ext = "dylib" if _is_mac() else "so" - capi_so = f"__main__/jaxlib/mlir/_mlir_libs/libjaxlib_mlir_capi.{so_ext}" - mlir_libs_dir = jaxlib_dir / "mlir" / "_mlir_libs" - copy_runfiles( + copy_files( dst_dir=mlir_libs_dir, src_files=[ - capi_so, - "__main__/jaxlib/mlir/_mlir_libs/__init__.py", - f"__main__/jaxlib/mlir/_mlir_libs/_mlir.{pyext}", - f"__main__/jaxlib/mlir/_mlir_libs/_chlo.{pyext}", - f"__main__/jaxlib/mlir/_mlir_libs/_mlirHlo.{pyext}", - f"__main__/jaxlib/mlir/_mlir_libs/_mlirDialectsSparseTensor.{pyext}", - f"__main__/jaxlib/mlir/_mlir_libs/_mlirSparseTensorPasses.{pyext}", - f"__main__/jaxlib/mlir/_mlir_libs/_mosaic_gpu_ext.{pyext}", - f"__main__/jaxlib/mlir/_mlir_libs/_tpu_ext.{pyext}", - f"__main__/jaxlib/mlir/_mlir_libs/_sdy.{pyext}", - f"__main__/jaxlib/mlir/_mlir_libs/_stablehlo.{pyext}", - f"__main__/jaxlib/mlir/_mlir_libs/register_jax_dialects.{pyext}", - f"__main__/jaxlib/mlir/_mlir_libs/_mlirDialectsGPU.{pyext}", - f"__main__/jaxlib/mlir/_mlir_libs/_mlirDialectsLLVM.{pyext}", - f"__main__/jaxlib/mlir/_mlir_libs/_mlirDialectsNVGPU.{pyext}", - f"__main__/jaxlib/mlir/_mlir_libs/_mlirGPUPasses.{pyext}", + f"{source_file_prefix}jaxlib/mlir/_mlir_libs/__init__.py", + f"{source_file_prefix}jaxlib/_mlir.{pyext}", + f"{source_file_prefix}jaxlib/_chlo.{pyext}", + f"{source_file_prefix}jaxlib/_mlirHlo.{pyext}", + f"{source_file_prefix}jaxlib/_mlirDialectsSparseTensor.{pyext}", + f"{source_file_prefix}jaxlib/_mlirSparseTensorPasses.{pyext}", + f"{source_file_prefix}jaxlib/_mosaic_gpu_ext.{pyext}", + f"{source_file_prefix}jaxlib/_tpu_ext.{pyext}", + f"{source_file_prefix}jaxlib/_sdy.{pyext}", + f"{source_file_prefix}jaxlib/_stablehlo.{pyext}", + f"{source_file_prefix}jaxlib/register_jax_dialects.{pyext}", + f"{source_file_prefix}jaxlib/_mlirDialectsGPU.{pyext}", + f"{source_file_prefix}jaxlib/_mlirDialectsLLVM.{pyext}", + f"{source_file_prefix}jaxlib/_mlirDialectsNVGPU.{pyext}", + f"{source_file_prefix}jaxlib/_mlirGPUPasses.{pyext}", ] + ( [] if build_utils.is_windows() else [ - f"__main__/jaxlib/mlir/_mlir_libs/_triton_ext.{pyext}", - "__main__/jaxlib/mlir/_mlir_libs/_triton_ext.pyi", + f"{source_file_prefix}jaxlib/_triton_ext.{pyext}", + f"{source_file_prefix}jaxlib/mlir/_mlir_libs/_triton_ext.pyi", ] ), ) triton_dir = jaxlib_dir / "triton" - copy_runfiles( + copy_files( dst_dir=triton_dir, src_files=[ - "__main__/jaxlib/triton/__init__.py", - "__main__/jaxlib/triton/dialect.py", + f"{source_file_prefix}jaxlib/triton/__init__.py", + f"{source_file_prefix}jaxlib/triton/dialect.py", ], ) patch_copy_mlir_import( - "__main__/jaxlib/triton/_triton_enum_gen.py", dst_dir=triton_dir + f"{source_file_prefix}jaxlib/triton/_triton_enum_gen.py", + dst_dir=triton_dir, + runfiles=r, + wheel_sources_map=wheel_sources_map, ) patch_copy_mlir_import( - "__main__/jaxlib/triton/_triton_ops_gen.py", dst_dir=triton_dir + f"{source_file_prefix}jaxlib/triton/_triton_ops_gen.py", + dst_dir=triton_dir, + runfiles=r, + wheel_sources_map=wheel_sources_map, ) - copy_runfiles( - dst_dir=jaxlib_dir / "include" / "xla" / "ffi" / "api", - src_files=[ - "xla/xla/ffi/api/c_api.h", - 
"xla/xla/ffi/api/api.h", - "xla/xla/ffi/api/ffi.h", - ], + copy_files( + dst_dir=jaxlib_dir / "include" / "xla" / "ffi" / "api", + src_files=[ + "xla/xla/ffi/api/c_api.h", + "xla/xla/ffi/api/api.h", + "xla/xla/ffi/api/ffi.h", + ], ) tmpdir = None @@ -383,6 +404,7 @@ def prepare_wheel(sources_path: pathlib.Path, *, cpu): prepare_wheel( pathlib.Path(sources_path), cpu=args.cpu, + wheel_sources=args.srcs, ) package_name = "jaxlib" if args.editable: diff --git a/jaxlib/tools/wheel_size_test.py b/jaxlib/tools/wheel_size_test.py new file mode 100644 index 000000000000..7e9c08ff9797 --- /dev/null +++ b/jaxlib/tools/wheel_size_test.py @@ -0,0 +1,56 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import logging +import os + + +def parse_args(): + """Arguments parser.""" + parser = argparse.ArgumentParser( + description="Helper for the wheel size verification", + fromfile_prefix_chars="@", + ) + parser.add_argument( + "--wheel-path", required=True, help="Path of the wheel, mandatory" + ) + parser.add_argument( + "--max-size-mib", + required=True, + help="Maximum size of the wheel in MiB", + ) + return parser.parse_args() + + +def verify_wheel_size(args): + wheel_size_mib = os.path.getsize(args.wheel_path) >> 20 + wheel_name = os.path.basename(args.wheel_path) + if wheel_size_mib > int(args.max_size_mib): + raise RuntimeError( + "The {name} size is {size} MiB, which is larger than the maximum size" + " {max_size} MiB".format( + name=wheel_name, + size=wheel_size_mib, + max_size=args.max_size_mb, + ) + ) + else: + logging.info( + "The %s size is %s MiB, which is less than the maximum size" + " %s MB", wheel_name, wheel_size_mib, args.max_size_mib) + + +if __name__ == "__main__": + verify_wheel_size(parse_args()) diff --git a/jaxlib/traceback.cc b/jaxlib/traceback.cc new file mode 100644 index 000000000000..8a309ebb6f8f --- /dev/null +++ b/jaxlib/traceback.cc @@ -0,0 +1,425 @@ +/* Copyright 2020 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "jaxlib/traceback.h" + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/base/casts.h" +#include "absl/hash/hash.h" +#include "absl/log/check.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/optional.h" // IWYU pragma: keep +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "nanobind/stl/string_view.h" // IWYU pragma: keep +#include "nanobind/stl/vector.h" // IWYU pragma: keep +#include "xla/pjrt/exceptions.h" +#include "xla/python/nb_helpers.h" +#include "tsl/platform/platform.h" + +#ifdef PLATFORM_GOOGLE +#define Py_BUILD_CORE +#include "internal/pycore_frame.h" +#undef Py_BUILD_CORE +#endif // PLATFORM_GOOGLE + +namespace nb = nanobind; + +namespace xla { + +namespace { + +std::atomic traceback_enabled_ = true; + +static constexpr int kMaxFrames = 512; + +PyTypeObject* traceback_type_ = nullptr; + +// Entry in a traceback. Must be POD. +struct TracebackEntry { + TracebackEntry() = default; + TracebackEntry(PyCodeObject* code, int lasti) : code(code), lasti(lasti) {} + PyCodeObject* code; + int lasti; + + bool operator==(const TracebackEntry& other) const { + return code == other.code && lasti == other.lasti; + } + bool operator!=(const TracebackEntry& other) const { + return !operator==(other); + } +}; +static_assert(std::is_trivial_v == true); + +template +H AbslHashValue(H h, const TracebackEntry& entry) { + h = H::combine(std::move(h), entry.code, entry.lasti); + return h; +} + +struct TracebackObject { + PyObject_VAR_HEAD; + TracebackEntry frames[]; +}; + +template +H AbslHashValue(H h, const TracebackObject& tb) { + h = H::combine_contiguous(std::move(h), &tb.frames[0], Py_SIZE(&tb)); + return h; +} + +static_assert(sizeof(TracebackObject) % alignof(PyObject) == 0); +static_assert(sizeof(TracebackEntry) % alignof(void*) == 0); + +bool traceback_check(nb::handle o) { + return Py_TYPE(o.ptr()) == traceback_type_; +} + +Py_hash_t traceback_tp_hash(PyObject* o) { + TracebackObject* tb = reinterpret_cast(o); + size_t h = absl::HashOf(*tb); + Py_hash_t s = absl::bit_cast(h); // Python hashes are signed. + return s == -1 ? -2 : s; // -1 must not be used as a Python hash value. +} + +PyObject* traceback_tp_richcompare(PyObject* self, PyObject* other, int op) { + if (op != Py_EQ && op != Py_NE) { + return Py_NewRef(Py_NotImplemented); + } + + if (!traceback_check(other)) { + return Py_NewRef(Py_False); + } + TracebackObject* tb_self = reinterpret_cast(self); + TracebackObject* tb_other = reinterpret_cast(other); + if (Py_SIZE(tb_self) != Py_SIZE(tb_other)) { + return Py_NewRef(op == Py_EQ ? Py_False : Py_True); + } + for (Py_ssize_t i = 0; i < Py_SIZE(tb_self); ++i) { + if ((tb_self->frames[i] != tb_other->frames[i])) { + return Py_NewRef(op == Py_EQ ? Py_False : Py_True); + } + } + return Py_NewRef(op == Py_EQ ? 
Py_True : Py_False); +} + +static void traceback_tp_dealloc(PyObject* self) { + TracebackObject* tb = reinterpret_cast(self); + for (Py_ssize_t i = 0; i < Py_SIZE(tb); ++i) { + Py_XDECREF(tb->frames[i].code); + } + PyTypeObject* tp = Py_TYPE(self); + tp->tp_free((PyObject*)self); + Py_DECREF(tp); +} + +Traceback::Frame DecodeFrame(const TracebackEntry& frame) { + return Traceback::Frame{ + .file_name = nb::borrow(frame.code->co_filename), + .function_name = nb::borrow(frame.code->co_name), + .function_start_line = frame.code->co_firstlineno, + .line_num = PyCode_Addr2Line(frame.code, frame.lasti), + }; +} + +std::string traceback_to_string(const TracebackObject* tb) { + std::vector frame_strs; + frame_strs.reserve(Py_SIZE(tb)); + for (Py_ssize_t i = 0; i < Py_SIZE(tb); ++i) { + frame_strs.push_back(DecodeFrame(tb->frames[i]).ToString()); + } + return absl::StrJoin(frame_strs, "\n"); +} + +PyObject* traceback_tp_str(PyObject* self) { + TracebackObject* tb = reinterpret_cast(self); + return nb::cast(traceback_to_string(tb)).release().ptr(); +} + +// It turns out to be slightly faster to define a tp_hash slot rather than +// defining __hash__ and __eq__ on the class. +PyType_Slot traceback_slots_[] = { + {Py_tp_hash, reinterpret_cast(traceback_tp_hash)}, + {Py_tp_richcompare, reinterpret_cast(traceback_tp_richcompare)}, + {Py_tp_dealloc, reinterpret_cast(traceback_tp_dealloc)}, + {Py_tp_str, reinterpret_cast(traceback_tp_str)}, + {0, nullptr}, +}; + +nb::object AsPythonTraceback(const Traceback& tb) { + nb::object traceback = nb::none(); + nb::dict globals; + nb::handle traceback_type(reinterpret_cast(&PyTraceBack_Type)); + TracebackObject* tb_obj = reinterpret_cast(tb.ptr()); + for (Py_ssize_t i = 0; i < Py_SIZE(tb_obj); ++i) { + const TracebackEntry& frame = tb_obj->frames[i]; + int lineno = PyCode_Addr2Line(frame.code, frame.lasti); + // Under Python 3.11 we observed crashes when using a fake PyFrameObject + // with a real PyCodeObject (https://github.com/google/jax/issues/16027). + // because the frame does not have fields necessary to compute the locals, + // notably the closure object, leading to crashes in CPython in + // _PyFrame_FastToLocalsWithError + // https://github.com/python/cpython/blob/deaf509e8fc6e0363bd6f26d52ad42f976ec42f2/Objects/frameobject.c#LL1116C2-L1116C2 + // We therefore always build a fake code object to go along with our fake + // frame. + PyCodeObject* py_code = + PyCode_NewEmpty(PyUnicode_AsUTF8(frame.code->co_filename), + PyUnicode_AsUTF8(frame.code->co_name), lineno); + PyFrameObject* py_frame = PyFrame_New(PyThreadState_Get(), py_code, + globals.ptr(), /*locals=*/nullptr); + Py_DECREF(py_code); + + traceback = traceback_type( + /*tb_next=*/std::move(traceback), + /*tb_frame=*/ + nb::steal(reinterpret_cast(py_frame)), + /*tb_lasti=*/0, + /*tb_lineno=*/lineno); + } + return traceback; +} + +} // namespace + +std::vector Traceback::Frames() const { + // We require the GIL because we manipulate Python strings. 
+ CHECK(PyGILState_Check()); + std::vector frames; + TracebackObject* tb = reinterpret_cast(ptr()); + frames.reserve(Py_SIZE(tb)); + for (Py_ssize_t i = 0; i < Py_SIZE(tb); ++i) { + const TracebackEntry& frame = tb->frames[i]; + frames.push_back(Frame{nb::borrow(frame.code->co_filename), + nb::borrow(frame.code->co_name), + frame.code->co_firstlineno, + PyCode_Addr2Line(frame.code, frame.lasti)}); + } + return frames; +} + +std::string Traceback::Frame::ToString() const { + return absl::StrFormat("%s:%d (%s)", nb::cast(file_name), + line_num, nb::cast(function_name)); +} + +std::string Traceback::ToString() const { + return traceback_to_string(reinterpret_cast(ptr())); +} + +std::vector> Traceback::RawFrames() const { + const TracebackObject* tb = reinterpret_cast(ptr()); + std::vector> frames; + frames.reserve(Py_SIZE(tb)); + for (Py_ssize_t i = 0; i < Py_SIZE(tb); ++i) { + frames.push_back(std::make_pair(tb->frames[i].code, tb->frames[i].lasti)); + } + return frames; +} + +/*static*/ bool Traceback::Check(PyObject* o) { return traceback_check(o); } + +/*static*/ std::optional Traceback::Get() { + // We use a thread_local here mostly to avoid requiring a large amount of + // space. + thread_local std::array frames; + int count = 0; + + DCHECK(PyGILState_Check()); + + if (!traceback_enabled_.load()) { + return std::nullopt; + } + + PyThreadState* thread_state = PyThreadState_GET(); + +#ifdef PLATFORM_GOOGLE +// This code is equivalent to the version using public APIs, but it saves us +// an allocation of one object per stack frame. However, this is definitely +// violating the API contract of CPython, so we only use this where we can be +// confident we know exactly which CPython we are using (internal to Google). +// Feel free to turn this on if you like, but it might break at any time! 
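Both branches of the stack walk below record, for every live frame, a code object together with the offset of the last executed instruction. As a point of orientation only, here is a rough Python-level analogue of what Traceback::Get collects; it is not part of this change, and sys._getframe is simply the usual Python-side way to reach the same frame chain.

import sys

def capture_raw_frames(limit=512):
  """Collects (code object, f_lasti) pairs for the calling thread's stack.

  Rough Python analogue of Traceback::Get; the C++ version stores the pairs
  in a compact variable-sized PyObject instead of a Python list.
  """
  frames = []
  frame = sys._getframe(1)  # Skip this helper's own frame.
  while frame is not None and len(frames) < limit:
    frames.append((frame.f_code, frame.f_lasti))
    frame = frame.f_back
  return frames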
+#if PY_VERSION_HEX < 0x030d0000 + for (_PyInterpreterFrame* f = thread_state->cframe->current_frame; + f != nullptr && count < kMaxFrames; f = f->previous) { + if (_PyFrame_IsIncomplete(f)) continue; + Py_INCREF(f->f_code); + frames[count] = {f->f_code, static_cast(_PyInterpreterFrame_LASTI(f) * + sizeof(_Py_CODEUNIT))}; + ++count; + } +#else // PY_VERSION_HEX < 0x030d0000 + for (_PyInterpreterFrame* f = thread_state->current_frame; + f != nullptr && count < kMaxFrames; f = f->previous) { + if (_PyFrame_IsIncomplete(f)) continue; + Py_INCREF(f->f_executable); + frames[count] = { + reinterpret_cast(f->f_executable), + static_cast(_PyInterpreterFrame_LASTI(f) * sizeof(_Py_CODEUNIT))}; + ++count; + } +#endif // PY_VERSION_HEX < 0x030d0000 + +#else // PLATFORM_GOOGLE + PyFrameObject* next; + for (PyFrameObject* py_frame = PyThreadState_GetFrame(thread_state); + py_frame != nullptr && count < kMaxFrames; py_frame = next) { + frames[count] = {PyFrame_GetCode(py_frame), PyFrame_GetLasti(py_frame)}; + ++count; + next = PyFrame_GetBack(py_frame); + Py_XDECREF(py_frame); + } +#endif // PLATFORM_GOOGLE + + Traceback traceback = + nb::steal(PyObject_NewVar(PyObject, traceback_type_, count)); + TracebackObject* tb = reinterpret_cast(traceback.ptr()); + std::memcpy(tb->frames, frames.data(), sizeof(TracebackEntry) * count); + return traceback; +} + +void BuildTracebackSubmodule(nb::module_& m) { + nb::class_(m, "Frame") + .def(nb::init()) + .def_ro("file_name", &Traceback::Frame::file_name) + .def_ro("function_name", &Traceback::Frame::function_name) + .def_ro("function_start_line", &Traceback::Frame::function_start_line) + .def_ro("line_num", &Traceback::Frame::line_num) + .def("__repr__", [](const Traceback::Frame& frame) { + return absl::StrFormat( + "%s;%s:%d", nb::cast(frame.function_name), + nb::cast(frame.file_name), frame.line_num); + }); + + std::string name = + absl::StrCat(nb::cast(m.attr("__name__")), ".Traceback"); + + PyType_Spec traceback_spec = { + /*.name=*/name.c_str(), + /*.basicsize=*/static_cast(sizeof(TracebackObject)), + /*.itemsize=*/static_cast(sizeof(TracebackEntry)), + /*.flags=*/Py_TPFLAGS_DEFAULT, + /*.slots=*/traceback_slots_, + }; + + traceback_type_ = + reinterpret_cast(PyType_FromSpec(&traceback_spec)); + if (!traceback_type_) { + throw nb::python_error(); + } + + auto type = nb::borrow(traceback_type_); + m.attr("Traceback") = type; + + m.def("tracebacks_enabled", []() { return traceback_enabled_.load(); }); + m.def("set_tracebacks_enabled", + [](bool value) { traceback_enabled_.store(value); }); + + type.attr("get_traceback") = nb::cpp_function(Traceback::Get, + R"doc( + Returns a :class:`Traceback` for the current thread. + + If ``Traceback.enabled`` is ``True``, returns a :class:`Traceback` + object that describes the Python stack of the calling thread. Stack + trace collection has a small overhead, so it is disabled by default. If + traceback collection is disabled, returns ``None``. )doc"); + type.attr("frames") = nb_property_readonly(&Traceback::Frames); + type.attr("raw_frames") = nb::cpp_function( + [](const Traceback& tb) -> nb::tuple { + // We return a tuple of lists, rather than a list of tuples, because it + // is cheaper to allocate only three Python objects for everything + // rather than one per frame. 
+ std::vector> frames = tb.RawFrames(); + nb::list out_code = nb::steal(PyList_New(frames.size())); + nb::list out_lasti = nb::steal(PyList_New(frames.size())); + for (size_t i = 0; i < frames.size(); ++i) { + const auto& frame = frames[i]; + PyObject* code = reinterpret_cast(frame.first); + Py_INCREF(code); + PyList_SET_ITEM(out_code.ptr(), i, code); + PyList_SET_ITEM(out_lasti.ptr(), i, + nb::int_(frame.second).release().ptr()); + } + return nb::make_tuple(out_code, out_lasti); + }, + nb::is_method()); + type.attr("as_python_traceback") = + nb::cpp_function(AsPythonTraceback, nb::is_method()); + + type.attr("traceback_from_frames") = nb::cpp_function( + [](std::vector frames) { + nb::object traceback = nb::none(); + nb::dict globals; + nb::handle traceback_type( + reinterpret_cast(&PyTraceBack_Type)); + for (const Traceback::Frame& frame : frames) { + PyCodeObject* py_code = + PyCode_NewEmpty(frame.file_name.c_str(), + frame.function_name.c_str(), frame.line_num); + PyFrameObject* py_frame = PyFrame_New(PyThreadState_Get(), py_code, + globals.ptr(), /*locals=*/ + nullptr); + Py_DECREF(py_code); + traceback = traceback_type( + /*tb_next=*/std::move(traceback), + /*tb_frame=*/ + nb::steal(reinterpret_cast(py_frame)), + /*tb_lasti=*/0, + /*tb_lineno=*/ + frame.line_num); + } + return traceback; + }, + "Creates a traceback from a list of frames."); + + type.attr("code_addr2line") = nb::cpp_function( + [](nb::handle code, int lasti) { + if (!PyCode_Check(code.ptr())) { + throw xla::XlaRuntimeError("code argument must be a code object"); + } + return PyCode_Addr2Line(reinterpret_cast(code.ptr()), + lasti); + }, + "Python wrapper around the Python C API function PyCode_Addr2Line"); + + type.attr("code_addr2location") = nb::cpp_function( + [](nb::handle code, int lasti) { + if (!PyCode_Check(code.ptr())) { + throw xla::XlaRuntimeError("code argument must be a code object"); + } + int start_line, start_column, end_line, end_column; + if (!PyCode_Addr2Location(reinterpret_cast(code.ptr()), + lasti, &start_line, &start_column, &end_line, + &end_column)) { + throw nb::python_error(); + } + return nb::make_tuple(start_line, start_column, end_line, end_column); + }, + "Python wrapper around the Python C API function PyCode_Addr2Location"); +} +} // namespace xla diff --git a/jaxlib/traceback.h b/jaxlib/traceback.h new file mode 100644 index 000000000000..9ae7e9e0836f --- /dev/null +++ b/jaxlib/traceback.h @@ -0,0 +1,63 @@ +/* Copyright 2020 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_TRACEBACK_H_ +#define JAXLIB_TRACEBACK_H_ + +#include + +#include +#include +#include +#include + +// placeholder for index annotation headers +#include "nanobind/nanobind.h" + +namespace xla { + +class Traceback : public nanobind::object { + public: + NB_OBJECT(Traceback, nanobind::object, "Traceback", Traceback::Check); + + // Returns a traceback if it is enabled, otherwise returns nullopt. 
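The bindings registered in BuildTracebackSubmodule above expose the type as Traceback together with get_traceback, frames, raw_frames, as_python_traceback and the code_addr2line helpers. A hedged usage sketch follows; the jaxlib._jax import path and the assumption that the module-level enable/disable functions land on that same module are inferred from other parts of this change rather than guaranteed by it.

from jaxlib import _jax  # assumed import path

_jax.set_tracebacks_enabled(True)          # collection is off by default
tb = _jax.Traceback.get_traceback()        # None if collection is disabled
if tb is not None:
  for frame in tb.frames:                  # decoded Frame objects
    print(f"{frame.file_name}:{frame.line_num} ({frame.function_name})")
  codes, lastis = tb.raw_frames()          # parallel lists of code objects and offsets
  print(_jax.Traceback.code_addr2line(codes[0], lastis[0]))
  py_tb = tb.as_python_traceback()         # ordinary traceback object chain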
+ static std::optional Get(); + + // Returns a string representation of the traceback. + std::string ToString() const; + + // Returns a list of (code, lasti) pairs for each frame in the traceback. + std::vector> RawFrames() const; + + struct Frame { + nanobind::str file_name; + nanobind::str function_name; + int function_start_line; + int line_num; + + std::string ToString() const; + }; + // Returns a list of Frames for the traceback. + std::vector Frames() const; + + private: + static bool Check(PyObject* o); +}; + +void BuildTracebackSubmodule(nanobind::module_& m); + +} // namespace xla + +#endif // JAXLIB_TRACEBACK_H_ diff --git a/jaxlib/triton/BUILD b/jaxlib/triton/BUILD index 99cddd9e6381..478ce31140a6 100644 --- a/jaxlib/triton/BUILD +++ b/jaxlib/triton/BUILD @@ -35,7 +35,9 @@ pytype_strict_library( "//jaxlib/mlir:ir", ] + if_windows( [], - ["//jaxlib/mlir/_mlir_libs:_triton_ext"], + [ + "//jaxlib/mlir/_mlir_libs:_triton_ext", + ], ), ) diff --git a/jaxlib/triton/triton_dialect_capi.cc b/jaxlib/triton/triton_dialect_capi.cc index 6a46d2914f57..8781fd16d76a 100644 --- a/jaxlib/triton/triton_dialect_capi.cc +++ b/jaxlib/triton/triton_dialect_capi.cc @@ -15,12 +15,12 @@ limitations under the License. #include "jaxlib/triton/triton_dialect_capi.h" -#include "llvm/include/llvm/Support/Casting.h" -#include "mlir/include/mlir-c/IR.h" -#include "mlir/include/mlir/CAPI/IR.h" -#include "mlir/include/mlir/CAPI/Registration.h" -#include "mlir/include/mlir/IR/Attributes.h" -#include "mlir/include/mlir/IR/Dialect.h" +#include "llvm/Support/Casting.h" +#include "mlir-c/IR.h" +#include "mlir/CAPI/IR.h" +#include "mlir/CAPI/Registration.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Dialect.h" #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/Triton/IR/Types.h" diff --git a/jaxlib/triton/triton_dialect_capi.h b/jaxlib/triton/triton_dialect_capi.h index 8c27b5b82500..7d2a2f10404a 100644 --- a/jaxlib/triton/triton_dialect_capi.h +++ b/jaxlib/triton/triton_dialect_capi.h @@ -16,8 +16,8 @@ limitations under the License. #ifndef JAXLIB_TRITON_TRITON_DIALECT_CAPI_H_ #define JAXLIB_TRITON_TRITON_DIALECT_CAPI_H_ -#include "mlir/include/mlir-c/IR.h" -#include "mlir/include/mlir-c/Support.h" +#include "mlir-c/IR.h" +#include "mlir-c/Support.h" #ifdef __cplusplus extern "C" { diff --git a/jaxlib/util.cc b/jaxlib/util.cc new file mode 100644 index 000000000000..a8d45749f4d1 --- /dev/null +++ b/jaxlib/util.cc @@ -0,0 +1,83 @@ +/* Copyright 2022 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "jaxlib/util.h" + +#include +#include + +#include "absl/status/status.h" +#include "absl/synchronization/notification.h" +#include "absl/time/time.h" +#include "absl/types/span.h" +#include "nanobind/nanobind.h" +#include "xla/pjrt/pjrt_future.h" +#include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/client.h" +#include "xla/python/ifrt/future.h" +#include "xla/python/ifrt/value.h" +#include "xla/python/version.h" +#include "xla/tsl/concurrency/async_value.h" +#include "xla/tsl/concurrency/ref_count.h" +#include "xla/util.h" + +namespace xla { + +void BlockUntilReadyWithCancel(xla::PjRtFuture<>& future) { + future.BlockUntilReady([](tsl::AsyncValue* value) { + auto state = std::make_shared(); + value->AndThen([state]() { state->Notify(); }); + while (true) { + if (state->WaitForNotificationWithTimeout(absl::Milliseconds(200))) { + break; + } + nanobind::gil_scoped_acquire gil_acquire; + if (PyErr_CheckSignals() != 0) { + throw nanobind::python_error(); + } + } + }); +} + +absl::Status AwaitBuffersReady(absl::Span ifrt_arrays) { + if (ifrt_arrays.empty()) { + return absl::OkStatus(); + } + + ifrt::Future<> future; + if (ifrt_arrays.size() == 1) { + future = ifrt_arrays[0]->GetReadyFuture(); + } else { + std::vector values; + values.reserve(ifrt_arrays.size()); + for (ifrt::Array* const ifrt_array : ifrt_arrays) { + values.push_back(tsl::FormRef(ifrt_array)); + } + ifrt::Client* const client = ifrt_arrays.front()->client(); + future = client->GetReadyFuture(values); + } + BlockUntilReadyWithCancel(future); + absl::Status s = future.Await(); + if (!s.ok()) { + // Fix up error string because some clients rely on it. + if (s.message() == "GetReadyFuture() called on deleted or donated buffer") { + s = InvalidArgument( + "BlockHostUntilReady() called on deleted or donated buffer"); + } + } + return s; +} + +} // namespace xla diff --git a/jaxlib/util.h b/jaxlib/util.h new file mode 100644 index 000000000000..14848bb0ccf8 --- /dev/null +++ b/jaxlib/util.h @@ -0,0 +1,34 @@ +/* Copyright 2022 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_UTIL_H_ +#define JAXLIB_UTIL_H_ + +#include "absl/status/status.h" +#include "absl/types/span.h" +#include "xla/python/ifrt/array.h" + +namespace xla { + +// Waits until future is ready but will cancel if ctrl-c is pressed. +void BlockUntilReadyWithCancel(xla::PjRtFuture<>& future); + +// Requests if given buffers are ready, awaits for results and returns OK if +// all of the buffers are ready or the last non-ok status. +absl::Status AwaitBuffersReady(absl::Span ifrt_arrays); + +} // namespace xla + +#endif // JAXLIB_UTIL_H_ diff --git a/jaxlib/utils.cc b/jaxlib/utils.cc index bf50b3a5254d..e5bb45e999da 100644 --- a/jaxlib/utils.cc +++ b/jaxlib/utils.cc @@ -19,12 +19,12 @@ limitations under the License. 
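BlockUntilReadyWithCancel in jaxlib/util.cc above waits in 200 ms slices and re-checks Python signals between slices so that Ctrl-C can interrupt a long device wait. The snippet below is a rough Python-level analogue of that polling pattern; the helper name is invented for illustration and is not jaxlib API.

import threading

def wait_with_cancel(notification: threading.Event, poll_seconds: float = 0.2):
  """Waits for `notification`, waking periodically so pending signal handlers
  (for example KeyboardInterrupt) get a chance to run between polls."""
  while not notification.wait(timeout=poll_seconds):
    # Each timeout returns control to the interpreter, mirroring the
    # PyErr_CheckSignals() call in the C++ implementation.
    pass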
#include #include -#include "nanobind/nanobind.h" #include "absl/cleanup/cleanup.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/container/inlined_vector.h" #include "absl/synchronization/mutex.h" +#include "nanobind/nanobind.h" namespace nb = nanobind; diff --git a/jaxlib/weakref_lru_cache.cc b/jaxlib/weakref_lru_cache.cc new file mode 100644 index 000000000000..0e3b9b831b82 --- /dev/null +++ b/jaxlib/weakref_lru_cache.cc @@ -0,0 +1,416 @@ +/* Copyright 2022 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include +#include +#include +#include +#include +#include +#include // NOLINT +#include +#include +#include + +#include "absl/base/thread_annotations.h" +#include "absl/cleanup/cleanup.h" +#include "absl/hash/hash.h" +#include "absl/strings/str_cat.h" +#include "absl/synchronization/mutex.h" +#include "absl/synchronization/notification.h" +#include "nanobind/nanobind.h" +#include "nanobind/stl/shared_ptr.h" // IWYU pragma: keep +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "nanobind/stl/vector.h" // IWYU pragma: keep +#include "xla/pjrt/lru_cache.h" +#include "xla/tsl/platform/logging.h" + +namespace nb = nanobind; + +namespace jax { +namespace { + +// Minimal wrapper to expose a nb::dict_iterator's value as something +// hashable with Abseil. +class HashablePyDictEntry { + public: + explicit HashablePyDictEntry(std::pair entry) + : entry_(entry) {} + + template + friend H AbslHashValue(H h, const HashablePyDictEntry& v) { + return H::combine(std::move(h), nb::hash(v.entry_.first), + nb::hash(v.entry_.second)); + } + + std::pair entry_; +}; + +// Similarly, a minimalist adaptor around the nb::detail::dict_iterator +// itself. Note that the iterator "is" also a Value. Does not meet the full +// standard iterator requirements, only enough to support H::combine_unordered. +class HashablePyDictIter { + public: + using iterator_category = std::input_iterator_tag; + + explicit HashablePyDictIter(nb::detail::dict_iterator& iter) : iter_(iter) {} + + // Minimal set of iterator operations. + HashablePyDictEntry operator*() const { return HashablePyDictEntry(*iter_); } + bool operator!=(const HashablePyDictIter& rhs) const { + return iter_ != rhs.iter_; + } + void operator++() { ++iter_; } + + private: + nb::detail::dict_iterator& iter_; +}; + +struct HashableKey { + nb::object context; + nb::args args; + nb::kwargs kwargs; + + template + friend H AbslHashValue(H h, const HashableKey& key) { + // Note: Despite the fact this is an ABSL hash function, it's safe to call + // functions that may throw exceptions such as nb::hash(), because it is + // used by an LRUCache, which uses a std::unordered_map, which is + // exception-safe. 
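HashablePyDictIter exists so the kwargs dict can be fed into H::combine_unordered, which makes the cache key's hash independent of keyword order while dict equality still distinguishes different bindings. A small Python sketch of the same idea, purely illustrative and not jaxlib code:

def hash_call_key(context, args, kwargs):
  unordered = 0
  for item in kwargs.items():
    # XOR-combining the per-item hashes makes the result independent of the
    # order in which keyword arguments were passed.
    unordered ^= hash(item)
  return hash((hash(context), hash(args), unordered, len(kwargs)))

assert (hash_call_key(None, (), {"kwkey1": "a", "kwkey2": "b"})
        == hash_call_key(None, (), {"kwkey2": "b", "kwkey1": "a"}))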
+ h = H::combine(std::move(h), nb::hash(key.context), nb::hash(key.args)); + nb::detail::dict_iterator begin = key.kwargs.begin(); + nb::detail::dict_iterator end = key.kwargs.end(); + h = H::combine_unordered(std::move(h), HashablePyDictIter(begin), + HashablePyDictIter(end)); + h = H::combine(std::move(h), key.kwargs.size()); + return h; + } +}; + +} // namespace + +class WeakrefLRUCache : public std::enable_shared_from_this { + public: + WeakrefLRUCache(nb::callable cache_context_fn, nb::callable fn, + int64_t maxsize) + : cache_context_fn_(cache_context_fn), fn_(fn), lru_list_(maxsize) {} + + nb::object Call(nb::object weakref_key, nb::args args, nb::kwargs kwargs); + + std::vector GetKeys(); + + struct CacheInfo { + int64_t hits; + int64_t misses; + int64_t maxsize; + int64_t currsize; + }; + CacheInfo GetCacheInfo() const; + + void Clear(); + + static PyType_Slot slots_[]; + + private: + class Key { + public: + Key(nb::object context, nb::args args, nb::kwargs kwargs) + : context_(std::move(context)), + args_(std::move(args)), + kwargs_(std::move(kwargs)), + cached_hash_(absl::HashOf(HashableKey{context_, args_, kwargs_})) {} + + bool operator==(const Key& other) const { + return context_.equal(other.context_) && args_.equal(other.args_) && + kwargs_.equal(other.kwargs_); + } + + template + friend H AbslHashValue(H h, const Key& key) { + return H::combine(std::move(h), key.cached_hash_); + } + + nb::object context() const { return context_; } + nb::args args() const { return args_; } + nb::kwargs kwargs() const { return kwargs_; } + + int tp_traverse(visitproc visit, void* arg) const { + Py_VISIT(context_.ptr()); + Py_VISIT(args_.ptr()); + Py_VISIT(kwargs_.ptr()); + return 0; + } + + private: + nb::object context_; + nb::args args_; + nb::kwargs kwargs_; + size_t cached_hash_; + }; + + struct CacheEntry { + bool has_result = false; + nb::object result; + absl::Notification completed; + std::thread::id thread_id = std::this_thread::get_id(); + + int tp_traverse(visitproc visit, void* arg) const { + Py_VISIT(result.ptr()); + return 0; + } + }; + + struct WeakrefCacheKey { + nb::weakref ref; + size_t cached_hash; + }; + + using Cache = xla::LRUCache>; + + struct WeakrefCacheValue { + std::shared_ptr cache; + }; + + struct WeakrefKeyHash { + size_t operator()(const WeakrefCacheKey& v) const { return v.cached_hash; } + }; + + struct WeakrefKeyEq { + bool operator()(const WeakrefCacheKey& lhs, + const WeakrefCacheKey& rhs) const { + return lhs.ref.equal(rhs.ref); + } + }; + + std::shared_ptr GetCache(WeakrefCacheKey key) { + WeakrefCacheValue& value = entries_[key]; + if (!value.cache) { + value.cache = std::make_shared(&lru_list_); + } + return value.cache; + } + + nb::callable cache_context_fn_; + nb::callable fn_; + Cache::LRUList lru_list_; + std::unordered_map + entries_; + int64_t misses_ = 0; + int64_t total_queries_ = 0; + absl::Mutex mu_; + + static int tp_traverse(PyObject* self, visitproc visit, void* arg); + static int tp_clear(PyObject* self); +}; + +nb::object WeakrefLRUCache::Call(nb::object weakref_key, nb::args args, + nb::kwargs kwargs) + ABSL_NO_THREAD_SAFETY_ANALYSIS { + nb::object context = cache_context_fn_(); + + // We precompute all of the hash values needed by the various maps rather + // than computing them during the std::unordered_map insertions. 
At the very + // least, MSVC's std::unordered_map has undefined behavior if the hash + // function throws an exception + // (https://learn.microsoft.com/en-us/cpp/standard-library/unordered-map-class?view=msvc-170#emplace). + Key key(context, args, kwargs); + size_t wrcache_hash = static_cast(nb::hash(weakref_key)); + + // No hash computations after this point. + + auto weakref_gc_callback = nb::cpp_function( + [this_weak = weak_from_this(), wrcache_hash](nb::handle weakref) { + auto cache = this_weak.lock(); + if (cache == nullptr) { + return; + } + // Set up PyCriticalSection for cache python associated object; + auto py_cache = nb::find(cache); + // This should never happen as python cache should always be found + CHECK(py_cache.ptr() != nullptr); + nb::ft_object_guard lock(py_cache); + + // The object the reference referred to is now in the process of being + // destroyed, so we cannot refer to its contents. Python weakref + // objects compare based on identity if the object they refer to is + // gone, so the hash lookup will work fine. + auto it = cache->entries_.find( + WeakrefCacheKey{nb::borrow(weakref), wrcache_hash}); + if (it == cache->entries_.end()) { + return; + } + // Create temp-var to avoid re-entrant erase. + auto tmp = std::move(it->second); + cache->entries_.erase(it); + }); + nb::weakref weakref = nb::weakref(weakref_key, weakref_gc_callback); + WeakrefCacheKey wrcache_key{weakref, wrcache_hash}; + std::shared_ptr cache_ptr = GetCache(wrcache_key); + Cache& cache = *cache_ptr; + ++total_queries_; + + bool inserted = false; + std::shared_ptr entry; + { + // Because the gil can be released during cache insertion, this forces + // the lock order to be mu_ then gil so we must release the gil first. + nb::gil_scoped_release release; + // Acquire a mutex to avoid problems where the gil is released during + // cache insertion and then a second thread invalidates the cache order. + mu_.Lock(); + } + { + // GetOrCreateIfAbsent calls into Python hash and equality functions, + // which may throw exceptions. The use of absl::Cleanup ensures mu_ is + // released if that happens. 
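The locking dance above (drop the GIL, take mu_, and rely on absl::Cleanup to release it) protects the step where the first caller for a key creates a CacheEntry and later callers block on its Notification. The toy Python sketch below shows that "first caller computes, the rest wait" shape for orientation; it is not the jaxlib implementation and omits eviction and error handling.

import threading

class OncePerKeyCache:
  """Toy analogue of the CacheEntry/Notification pattern used above."""

  def __init__(self, fn):
    self._fn = fn
    self._lock = threading.Lock()
    self._entries = {}  # key -> (threading.Event, [result])

  def __call__(self, key):
    with self._lock:
      entry = self._entries.get(key)
      owner = entry is None
      if owner:
        entry = (threading.Event(), [])
        self._entries[key] = entry
    event, box = entry
    if owner:
      box.append(self._fn(key))   # first caller computes the value
      event.set()                 # wake any concurrent callers
    else:
      event.wait()                # everyone else waits for the result
    return box[0]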
+ absl::Cleanup unlock = [this]() ABSL_UNLOCK_FUNCTION(mu_) { mu_.Unlock(); }; + entry = cache.GetOrCreateIfAbsent(key, [&inserted](const Key& key) { + inserted = true; + return std::make_shared(); + }); + } + if (!entry->completed.HasBeenNotified()) { + if (inserted) { + ++misses_; + absl::Cleanup notify = [&] { entry->completed.Notify(); }; + entry->result = fn_(weakref_key, *args, **kwargs); + entry->has_result = true; + } else { + if (entry->thread_id == std::this_thread::get_id()) { + auto error_string = + absl::StrCat("Recursively calling ", + nb::cast(nb::repr(weakref_key)), + nb::cast(nb::repr(args))); + PyErr_SetString(PyExc_RecursionError, error_string.c_str()); + throw nb::python_error(); + } + nb::gil_scoped_release release; + entry->completed.WaitForNotification(); + } + } + + if (entry->has_result) { + return entry->result; + } else { + ++misses_; + return fn_(weakref_key, *args, **kwargs); + } +} + +std::vector WeakrefLRUCache::GetKeys() { + std::vector results; + mu_.Lock(); + for (const auto& wr_entry : entries_) { + for (const auto& rest : *wr_entry.second.cache) { + nb::tuple result = + nb::make_tuple(*wr_entry.first.ref, rest.first.context(), + rest.first.args(), rest.first.kwargs()); + results.push_back(std::move(result)); + } + } + mu_.Unlock(); + return results; +} + +WeakrefLRUCache::CacheInfo WeakrefLRUCache::GetCacheInfo() const { + CacheInfo result; + result.hits = total_queries_ - misses_; + result.misses = misses_; + result.maxsize = lru_list_.Capacity(); + result.currsize = lru_list_.Size(); + return result; +} + +void WeakrefLRUCache::Clear() { + total_queries_ = misses_ = 0; + std::vector> deferred_deletes; + deferred_deletes.reserve(entries_.size()); + for (auto& entry : entries_) { + deferred_deletes.emplace_back(entry.first, std::move(entry.second)); + } + entries_.clear(); + deferred_deletes.clear(); +} + +/*static*/ int WeakrefLRUCache::tp_traverse(PyObject* self, visitproc visit, + void* arg) { + Py_VISIT(Py_TYPE(self)); + if (!nb::inst_ready(self)) { + return 0; + } + WeakrefLRUCache* cache = nb::inst_ptr(self); + Py_VISIT(cache->cache_context_fn_.ptr()); + Py_VISIT(cache->fn_.ptr()); + for (const auto& [wr_key, wr_value] : cache->entries_) { + Py_VISIT(wr_key.ref.ptr()); + for (const auto& [key, cache_value] : *wr_value.cache) { + int rval = key.tp_traverse(visit, arg); + if (rval != 0) { + return rval; + } + if (cache_value.value.has_value()) { + cache_value.value->get()->tp_traverse(visit, arg); + } + } + } + return 0; +} + +/*static*/ int WeakrefLRUCache::tp_clear(PyObject* self) { + WeakrefLRUCache* cache = nb::inst_ptr(self); + cache->Clear(); + cache->cache_context_fn_.reset(); + cache->fn_.reset(); + return 0; +} + +/* static */ PyType_Slot WeakrefLRUCache::slots_[] = { + {Py_tp_traverse, (void*)WeakrefLRUCache::tp_traverse}, + {Py_tp_clear, (void*)WeakrefLRUCache::tp_clear}, + {0, nullptr}, +}; + +NB_MODULE(weakref_lru_cache, m) { + auto weakref_lru_cache = + nb::class_(m, "WeakrefLRUCache", + nb::is_weak_referenceable(), + nb::type_slots(WeakrefLRUCache::slots_)) + .def("__call__", &WeakrefLRUCache::Call, nb::lock_self()) + .def("cache_keys", &WeakrefLRUCache::GetKeys, nb::lock_self()) + .def("cache_info", &WeakrefLRUCache::GetCacheInfo, nb::lock_self()) + .def("cache_clear", &WeakrefLRUCache::Clear, nb::lock_self()); + nb::class_(weakref_lru_cache, + "WeakrefLRUCacheInfo") + .def_ro("hits", &WeakrefLRUCache::CacheInfo::hits) + .def_ro("misses", &WeakrefLRUCache::CacheInfo::misses) + .def_ro("maxsize", &WeakrefLRUCache::CacheInfo::maxsize) + 
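The weakref_lru_cache.pyi stub above spells out the factory signature and the CacheInfo fields. The sketch below shows typical usage, with the import path taken from the accompanying test; the key class and the cached function are invented for illustration.

from jax.jaxlib import weakref_lru_cache

class Key:          # must be weak-referenceable; a plain class is
  pass

def expensive(key, x):
  return x * x

cache = weakref_lru_cache.weakref_lru_cache(lambda: None, expensive, 128)

k = Key()
assert cache(k, 3) == 9    # miss: calls expensive()
assert cache(k, 3) == 9    # hit: served from the LRU cache
print(cache.cache_info())  # e.g. WeakrefLRUCache(hits=1, misses=1, maxsize=128, currsize=1)
del k                      # the weakref callback drops the entries for k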
.def_ro("currsize", &WeakrefLRUCache::CacheInfo::currsize) + .def("__repr__", [](WeakrefLRUCache::CacheInfo& info) { + return absl::StrCat( + "WeakrefLRUCache(hits=", info.hits, ", misses=", info.misses, + ", maxsize=", info.maxsize, ", currsize=", info.currsize, ")"); + }); + m.def( + "weakref_lru_cache", + [](nb::callable cache_context_fn, nb::callable fn, int64_t maxsize) { + return std::make_shared(cache_context_fn, fn, maxsize); + }, + nb::arg("cache_context_fn"), nb::arg("fn"), nb::arg("maxsize") = 2048); +} + +} // namespace jax diff --git a/jaxlib/weakref_lru_cache.pyi b/jaxlib/weakref_lru_cache.pyi new file mode 100644 index 000000000000..ed965d7be811 --- /dev/null +++ b/jaxlib/weakref_lru_cache.pyi @@ -0,0 +1,38 @@ +# Copyright 2025 The JAX Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from collections.abc import Callable +from typing import Any + +class WeakrefLRUCache: + def __call__(self, arg0: Any, /, *args, **kwargs) -> Any: ... + def cache_keys(self) -> list[Any]: ... + def cache_info(self) -> WeakrefLRUCache.WeakrefLRUCacheInfo: ... + def cache_clear(self) -> None: ... + + class WeakrefLRUCacheInfo: + @property + def hits(self) -> int: ... + @property + def misses(self) -> int: ... + @property + def maxsize(self) -> int: ... + @property + def currsize(self) -> int: ... + def __repr__(self) -> str: ... + +def weakref_lru_cache( + cache_context_fn: Callable, fn: Callable, maxsize: int = 2048 +) -> WeakrefLRUCache: ... diff --git a/jaxlib/weakref_lru_cache_test.py b/jaxlib/weakref_lru_cache_test.py new file mode 100644 index 000000000000..a1016f397389 --- /dev/null +++ b/jaxlib/weakref_lru_cache_test.py @@ -0,0 +1,264 @@ +# Copyright 2023 The JAX Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import gc +import threading +import time +import weakref + +from absl.testing import absltest +from jax.jaxlib import weakref_lru_cache + + +class WeakrefLRUCacheTest(absltest.TestCase): + + def testMultiThreaded(self): + insert_evs = [threading.Event() for _ in range(2)] + insert_evs_i = 0 + + class WRKey: + pass + + class ClashingKey: + + def __eq__(self, other): + return False + + def __hash__(self): + return 333 # induce maximal caching problems. 
+ + class GilReleasingCacheKey: + + def __eq__(self, other): + nonlocal insert_evs_i + if isinstance(other, GilReleasingCacheKey) and insert_evs_i < len( + insert_evs + ): + insert_evs[insert_evs_i].set() + insert_evs_i += 1 + time.sleep(0.01) + return False + + def __hash__(self): + return 333 # induce maximal caching problems. + + def CacheFn(obj, gil_releasing_cache_key): + del obj + del gil_releasing_cache_key + return None + + cache = weakref_lru_cache.weakref_lru_cache(lambda: None, CacheFn, 2048) + + wrkey = WRKey() + + def Body(): + for insert_ev in insert_evs: + insert_ev.wait() + for _ in range(20): + cache(wrkey, ClashingKey()) + + t = threading.Thread(target=Body) + t.start() + for _ in range(3): + cache(wrkey, GilReleasingCacheKey()) + t.join() + + def testAnotherMultiThreaded(self): + num_workers = 5 + barrier = threading.Barrier(num_workers) + cache = weakref_lru_cache.weakref_lru_cache( + lambda: None, lambda x, y: y, 2048 + ) + + class WRKey: + pass + + def WorkerAddToCache(): + barrier.wait() + wrkey = WRKey() + for i in range(10): + cache(wrkey, i) + + def WorkerCleanCache(): + barrier.wait() + for _ in range(10): + cache.cache_clear() + + workers = [ + threading.Thread(target=WorkerAddToCache) + for _ in range(num_workers - 1) + ] + [threading.Thread(target=WorkerCleanCache)] + + for t in workers: + t.start() + + for t in workers: + t.join() + + def testKwargsDictOrder(self): + miss_id = 0 + + class WRKey: + pass + + def CacheFn(obj, kwkey1, kwkey2): + del obj, kwkey1, kwkey2 + nonlocal miss_id + miss_id += 1 + return miss_id + + cache = weakref_lru_cache.weakref_lru_cache(lambda: None, CacheFn, 4) + + wrkey = WRKey() + + self.assertEqual(cache(wrkey, kwkey1="a", kwkey2="b"), 1) + self.assertEqual(cache(wrkey, kwkey1="b", kwkey2="a"), 2) + self.assertEqual(cache(wrkey, kwkey2="b", kwkey1="a"), 1) + + def testGetKeys(self): + def CacheFn(obj, arg): + del obj + return arg + "extra" + + cache = weakref_lru_cache.weakref_lru_cache(lambda: None, CacheFn, 4) + + class WRKey: + pass + + wrkey = WRKey() + + self.assertEmpty(cache.cache_keys()) + cache(wrkey, "arg1") + cache(wrkey, "arg2") + self.assertLen(cache.cache_keys(), 2) + + def testNonWeakreferenceableKey(self): + class NonWRKey: + __slots__ = () + + non_wr_key = NonWRKey() + with self.assertRaises(TypeError): + weakref.ref(non_wr_key) + + cache = weakref_lru_cache.weakref_lru_cache(lambda: None, lambda x: 2048) + for _ in range(100): + with self.assertRaises(TypeError): + cache(non_wr_key) + + def testCrashingKey(self): + class WRKey: + pass + + class CrashingKey: + # A key that raises exceptions if eq or hash is called. 
+ + def __eq__(self, other): + raise ValueError("eq") + + def __hash__(self): + raise ValueError("hash") + + cache = weakref_lru_cache.weakref_lru_cache( + lambda: None, lambda x, y: y, 2048 + ) + wrkey = WRKey() + with self.assertRaises(ValueError): + for _ in range(100): + cache(wrkey, CrashingKey()) + + def testPrintingStats(self): + class WRKey: + pass + + cache = weakref_lru_cache.weakref_lru_cache( + lambda: None, lambda x, y: y, 2048 + ) + wrkey = WRKey() + for i in range(10): + cache(wrkey, i) + for i in range(5): + cache(wrkey, i) + + self.assertEqual( + repr(cache.cache_info()), + "WeakrefLRUCache(hits=5, misses=10, maxsize=2048, currsize=10)", + ) + + def testGCKeys(self): + class WRKey: + + def __init__(self, x): + self.x = x + + def __eq__(self, other): + return self.x == other.x + + def __hash__(self): + return hash(self.x) + + cache = weakref_lru_cache.weakref_lru_cache( + lambda: None, lambda x, y: y, 2048 + ) + keys = [WRKey(i) for i in range(10)] + for i in range(10): + cache(keys[i], i) + + # Delete some keys, to exercise the weakref callback behavior. + del keys[::2] + + for key in keys: + cache(key, 7) + + def testTpTraverse(self): + class WRKey: + pass + + def CacheContextFn(): + return None + + def CallFn(x, y, *args, **kwargs): + del x, args, kwargs + return y + + cache = weakref_lru_cache.weakref_lru_cache(CacheContextFn, CallFn, 2048) + + keys = [WRKey() for _ in range(10)] + values = [str(i) for i in range(10)] + args = [str(i) for i in range(10)] + kwargs = {"a": "b"} + + for key, value in zip(keys, values): + cache(key, value, *args, **kwargs) + + expected_refs = ( + [ + CacheContextFn, + CallFn, + weakref_lru_cache.WeakrefLRUCache, + kwargs, + ] + + [weakref.getweakrefs(key)[0] for key in keys] + + values + + args + ) + + # Can't use assertContainsSubset because it doesn't support kwargs since + # dicts aren't hashable. + for ref in expected_refs: + self.assertIn(ref, gc.get_referents(cache)) + + +if __name__ == "__main__": + absltest.main() diff --git a/jaxlib/xla.cc b/jaxlib/xla.cc new file mode 100644 index 000000000000..6c8e3fdb4ad1 --- /dev/null +++ b/jaxlib/xla.cc @@ -0,0 +1,993 @@ +/* Copyright 2019 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/base/casts.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/inlined_vector.h" +#include "absl/hash/hash.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/time/time.h" +#include "absl/types/span.h" +#include "llvm/Support/Casting.h" +#include "nanobind/nanobind.h" +#include "nanobind/nb_defs.h" +#include "nanobind/stl/function.h" // IWYU pragma: keep +#include "nanobind/stl/optional.h" // IWYU pragma: keep +#include "nanobind/stl/pair.h" // IWYU pragma: keep +#include "nanobind/stl/set.h" // IWYU pragma: keep +#include "nanobind/stl/shared_ptr.h" // IWYU pragma: keep +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "nanobind/stl/string_view.h" // IWYU pragma: keep +#include "nanobind/stl/unique_ptr.h" // IWYU pragma: keep +#include "nanobind/stl/unordered_map.h" // IWYU pragma: keep +#include "nanobind/stl/variant.h" // IWYU pragma: keep +#include "nanobind/stl/vector.h" // IWYU pragma: keep +#include "jaxlib/ffi.h" +#include "jaxlib/ifrt_proxy.h" +#include "jaxlib/py_client.h" +#include "jaxlib/py_program.h" +#include "jaxlib/sdy.h" +#include "xla/backends/cpu/collectives/cpu_collectives.h" +#include "xla/pjrt/c/pjrt_c_api.h" +#include "xla/pjrt/distributed/client.h" +#include "xla/pjrt/distributed/distributed.h" +#include "xla/pjrt/distributed/protocol.pb.h" +#include "xla/pjrt/distributed/service.h" +#include "xla/pjrt/pjrt_compiler.h" +#include "xla/pjrt/plugin/xla_cpu/cpu_client_options.h" +#include "xla/pjrt/plugin/xla_cpu/xla_cpu_pjrt_client.h" +#include "xla/pjrt/status_casters.h" +#include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/device.h" +#include "xla/python/ifrt/device_list.h" +#include "xla/python/ifrt/executable.h" +#include "xla/python/ifrt/topology.h" +#include "xla/python/pjrt_ifrt/pjrt_attribute_map_util.h" +#include "xla/python/version.h" +#include "xla/tsl/python/lib/core/numpy.h" // NOLINT + +#if defined(__linux__) +#include "gloo/transport/tcp/attr.h" +#include "gloo/transport/tcp/device.h" +#include "jaxlib/py_socket_transfer.h" +#include "xla/backends/cpu/collectives/gloo_collectives.h" +#include "xla/backends/cpu/collectives/gloo_kv_store.h" +#elif defined(__APPLE__) +#include "gloo/transport/uv/device.h" +#include "xla/backends/cpu/collectives/gloo_collectives.h" // NOLINT +#include "xla/backends/cpu/collectives/gloo_kv_store.h" // NOLINT +#endif // defined(__linux__) + +#if !defined(_WIN32) && !defined(PLATFORM_GOOGLE) +#include "xla/backends/cpu/collectives/mpi_collectives.h" +#endif // !_WIN32 && !PLATFORM_GOOGLE + +#include "jaxlib/config.h" +#include "jaxlib/custom_call_sharding.h" +#include "jaxlib/dlpack.h" +#include "jaxlib/guard_lib.h" +#include "jaxlib/jax_jit.h" +#include "jaxlib/mlir.h" +#include "jaxlib/nb_class_ptr.h" +#include "jaxlib/partition_spec.h" +#include "jaxlib/pjit.h" +#include "jaxlib/pmap_lib.h" +#include "jaxlib/py_array.h" +#include "jaxlib/py_compile_only_client.h" +#include "jaxlib/py_device.h" +#include "jaxlib/py_device_list.h" +#include "jaxlib/py_executable.h" +#include "jaxlib/py_memory_space.h" +#include "jaxlib/python_ref_manager.h" +#include "jaxlib/pytree.h" +#include "jaxlib/sharding.h" +#include "jaxlib/traceback.h" +#include "jaxlib/xla_compiler.h" +#include 
"xla/hlo/builder/lib/approx_topk_shape.h" +#include "xla/pjrt/distributed/key_value_store_interface.h" +#include "xla/pjrt/exceptions.h" +#include "xla/pjrt/pjrt_api.h" +#include "xla/pjrt/pjrt_c_api_client.h" +#include "xla/pjrt/pjrt_client.h" +#include "xla/pjrt/pjrt_common.h" +#include "xla/pjrt/pjrt_executable.h" +#include "xla/pjrt/pjrt_layout.h" +#include "xla/python/logging.h" // IWYU pragma: keep +#include "xla/python/nb_absl_flat_hash_map.h" // IWYU pragma: keep +#include "xla/python/nb_absl_span.h" // IWYU pragma: keep +#include "xla/python/pjrt_ifrt/pjrt_client.h" +#include "xla/python/pjrt_ifrt/pjrt_topology.h" +#include "xla/python/pprof_profile_builder.h" +#include "xla/tsl/distributed_runtime/preemption/preemption_sync_manager.h" +#include "xla/tsl/platform/status.h" +#include "tsl/platform/platform.h" + +// TODO(phawkins): remove host_id properties after JAX is update to avoid them. + +namespace xla { +namespace { + +namespace nb = nanobind; + +bool IsOptimizedBuild() { +#if NDEBUG + return true; +#else + return false; +#endif // NDEBUG +} + +// Is*san reports whether the build is under that particular sanitizer. +bool IsAsan() { +#if defined(ADDRESS_SANITIZER) + return true; +#else // defined(ADDRESS_SANITIZER) + return false; +#endif +} + +bool IsMsan() { +#if defined(MEMORY_SANITIZER) + return true; +#else // defined(MEMORY_SANITIZER) + return false; +#endif +} + +bool IsTsan() { +#if defined(THREAD_SANITIZER) + return true; +#else // defined(THREAD_SANITIZER) + return false; +#endif +} + +// IsSanitized reports whether the build is under any sanitizer. +bool IsSanitized() { return IsAsan() || IsMsan() || IsTsan(); } + +} // namespace + +NB_MODULE(_jax, m) { + // Initialize ABSL logging because code within XLA uses it. +#ifndef PLATFORM_GOOGLE + InitializeAbslLogging(); +#endif // PLATFORM_GOOGLE + + // We seem to get a fair number of leak warnings from nanobind. It's unclear + // whether these are false positives or not. + nb::set_leak_warnings(false); + + tsl::ImportNumpy(); + + // Exceptions + nb::exception xla_runtime_error(m, "XlaRuntimeError", + PyExc_RuntimeError); + xla_runtime_error.attr("__doc__") = nb::str( + "Runtime errors thrown by the JAX runtime. While the JAX runtime may " + "raise other exceptions as well, most exceptions thrown by the runtime " + "are instances of this class."); + + // Types + nb::enum_(m, "PrimitiveType", nb::is_arithmetic()) + .value("PRIMITIVE_TYPE_INVALID", PRIMITIVE_TYPE_INVALID) + .value("PRED", PRED) + .value("S4", S4) + .value("S8", S8) + .value("S16", S16) + .value("S32", S32) + .value("S64", S64) + .value("U4", U4) + .value("U8", U8) + .value("U16", U16) + .value("U32", U32) + .value("U64", U64) + .value("F16", F16) + .value("F4E2M1FN", F4E2M1FN) + .value("F8E3M4", F8E3M4) + .value("F8E4M3", F8E4M3) + .value("F8E4M3FN", F8E4M3FN) + .value("F8E4M3B11FNUZ", F8E4M3B11FNUZ) + .value("F8E4M3FNUZ", F8E4M3FNUZ) + .value("F8E5M2", F8E5M2) + .value("F8E5M2FNUZ", F8E5M2FNUZ) + .value("F8E8M0FNU", F8E8M0FNU) + .value("BF16", BF16) + .value("F32", F32) + .value("F64", F64) + .value("C64", C64) + .value("C128", C128) + .value("TUPLE", TUPLE) + .value("OPAQUE_TYPE", OPAQUE_TYPE) + .value("TOKEN", TOKEN); + + // Must be before PyClient.compile. 
+ BuildXlaCompilerSubmodule(m); + + PyDevice::RegisterPythonType(m); + PyMemorySpace::RegisterPythonType(m); + PyClient::RegisterPythonTypes(m); + + nb::enum_(m, "ArrayCopySemantics", + nb::is_arithmetic()) + .value("ALWAYS_COPY", ifrt::ArrayCopySemantics::kAlwaysCopy) + .value("REUSE_INPUT", ifrt::ArrayCopySemantics::kReuseInput) + .value("DONATE_INPUT", ifrt::ArrayCopySemantics::kDonateInput); + + nb::class_(m, "PjRtLayout") + .def("__str__", &PjRtLayout::ToString) + .def("__eq__", [](const PjRtLayout& layout, + const PjRtLayout& other) { return layout == other; }) + .def("__hash__", + [](const PjRtLayout& layout) { return absl::HashOf(layout); }) + .def("_xla_layout", &PjRtLayout::xla_layout) + .def("__getstate__", + [](const PjRtLayout& layout) -> nb::tuple { + absl::StatusOr serialized = layout.Serialize(); + ThrowIfError(serialized.status()); + return nb::make_tuple( + nb::bytes(serialized->data(), serialized->size())); + }) + .def("__setstate__", [](PjRtLayout* self, nb::tuple t) { + nb::bytes serialized = nb::cast(t[0]); + absl::StatusOr> layout = + PjRtLayout::Deserialize( + absl::string_view(serialized.c_str(), serialized.size())); + ThrowIfError(layout.status()); + new (self) PjRtLayout((*layout)->xla_layout()); + }); + + nb::class_ cpu_collectives(m, "CpuCollectives"); + + m.def( + "make_gloo_tcp_collectives", + [](std::shared_ptr distributed_client, + + std::optional hostname, + std::optional interface) + -> std::shared_ptr { +#if defined(__linux__) + std::shared_ptr kv_store = nullptr; + if (distributed_client != nullptr) { + kv_store = GetDistributedKeyValueStore(distributed_client, + /*key_prefix=*/"cpu:"); + } + auto gloo_kv_store = std::make_unique(kv_store); + auto tcp_attrs = gloo::transport::tcp::attr(); + if (hostname) { + tcp_attrs.hostname = *hostname; + } + if (interface) { + tcp_attrs.iface = *interface; + } + auto tcp_device = gloo::transport::tcp::CreateDevice(tcp_attrs); + return std::make_shared(std::move(gloo_kv_store), + std::move(tcp_device)); +#elif defined(__APPLE__) + std::shared_ptr kv_store = nullptr; + if (distributed_client != nullptr) { + kv_store = GetDistributedKeyValueStore(distributed_client, + /*key_prefix=*/"cpu:"); + } + auto gloo_kv_store = std::make_unique(kv_store); + auto uv_attrs = gloo::transport::uv::attr(); + if (hostname) { + uv_attrs.hostname = *hostname; + } + if (interface) { + uv_attrs.iface = *interface; + } + auto uv_device = gloo::transport::uv::CreateDevice(uv_attrs); + return std::make_shared(std::move(gloo_kv_store), + std::move(uv_device)); +#else // defined(__linux__) + throw xla::XlaRuntimeError( + "make_gloo_tcp_collectives only implemented for linux and macos"); +#endif // defined(__linux__) + }, + nb::arg("distributed_client"), nb::arg("hostname").none() = std::nullopt, + nb::arg("interface").none() = std::nullopt); + +#if !defined(_WIN32) && !defined(PLATFORM_GOOGLE) + nb::class_ mpi_collectives(m, "MpiCollectives", + cpu_collectives); + mpi_collectives.def("Init", &cpu::MpiCollectives::Init); + mpi_collectives.def("Finalize", &cpu::MpiCollectives::Finalize); + m.def("make_mpi_collectives", []() -> std::shared_ptr { + return std::make_shared(); + }); +#else // !_WIN32 && !PLATFORM_GOOGLE + m.def("make_mpi_collectives", + []() -> std::shared_ptr { + throw xla::XlaRuntimeError( + "make_mpi_collectives is not implemented for Windows"); + }); +#endif // !_WIN32 && !PLATFORM_GOOGLE + + m.def( + "get_tfrt_cpu_client", + [](bool asynchronous, + std::shared_ptr distributed_client, + int node_id, int num_nodes, + 
std::shared_ptr collectives, + std::optional num_devices, + std::optional get_local_topology_timeout_minutes, + std::optional get_global_topology_timeout_minutes) + -> nb_class_ptr { + std::unique_ptr ifrt_client; + { + nb::gil_scoped_release gil_release; + xla::CpuClientOptions options; + + options.asynchronous = asynchronous; + options.collectives = std::move(collectives); + options.process_id = node_id; + options.cpu_device_count = num_devices; + std::unique_ptr client = + xla::ValueOrThrow(xla::GetXlaPjrtCpuClient(std::move(options))); + ifrt::PjRtClient::CreateOptions ifrt_options; + ifrt_options.pjrt_client = + std::shared_ptr(std::move(client)); + if (distributed_client != nullptr) { + ifrt_options.kv_store = + GetDistributedKeyValueStore(distributed_client, + /*key_prefix=*/"cpu:"); + ifrt_options.process_id = node_id; + ifrt_options.num_processes = num_nodes; + } + if (get_local_topology_timeout_minutes.has_value()) { + ifrt_options.get_local_topology_timeout = + absl::Minutes(*get_local_topology_timeout_minutes); + } + if (get_global_topology_timeout_minutes.has_value()) { + ifrt_options.get_global_topology_timeout = + absl::Minutes(*get_global_topology_timeout_minutes); + } + ifrt_client = + ValueOrThrow(ifrt::PjRtClient::Create(std::move(ifrt_options))); + } + return PyClient::Make(std::move(ifrt_client)); + }, + nb::arg("asynchronous") = true, nb::arg("distributed_client") = nullptr, + nb::arg("node_id") = 0, nb::arg("num_nodes") = 1, + nb::arg("collectives").none() = + std::shared_ptr(), + nb::arg("num_devices").none() = std::nullopt, + nb::arg("get_local_topology_timeout_minutes").none() = std::nullopt, + nb::arg("get_global_topology_timeout_minutes").none() = std::nullopt); + m.def("pjrt_plugin_loaded", [](std::string platform_name) -> bool { + absl::StatusOr pjrt_api = pjrt::PjrtApi(platform_name); + return pjrt_api.ok(); + }); + m.def( + "load_pjrt_plugin", + [](std::string platform_name, std::optional library_path, + std::optional c_api) -> nb::capsule { + if (library_path.has_value()) { + const PJRT_Api* api = xla::ValueOrThrow( + pjrt::LoadPjrtPlugin(platform_name, *library_path)); + return nb::capsule(absl::bit_cast(api), "pjrt_c_api"); + } + if (absl::string_view(c_api->name()) != "pjrt_c_api") { + throw nb::value_error( + "c_api argument to load_pjrt_plugin is not a pjrt_c_api " + "capsule."); + } + xla::ThrowIfError(pjrt::SetPjrtApi( + platform_name, static_cast(c_api->data()))); + return *c_api; + }, + nb::arg("platform_name"), nb::arg("library_path").none() = std::nullopt, + nb::arg("c_api").none() = std::nullopt); + m.def("pjrt_plugin_initialized", [](std::string platform_name) -> bool { + return xla::ValueOrThrow(pjrt::IsPjrtPluginInitialized(platform_name)); + }); + m.def("initialize_pjrt_plugin", [](std::string platform_name) { + return xla::ThrowIfError(pjrt::InitializePjrtPlugin(platform_name)); + }); + + m.def( + "get_c_api_client", + [](std::string platform_name, + const absl::flat_hash_map& options, + std::shared_ptr distributed_client) + -> nb_class_ptr { + std::unique_ptr ifrt_client; + { + nb::gil_scoped_release gil_release; + std::shared_ptr kv_store = nullptr; + if (distributed_client != nullptr) { + kv_store = GetDistributedKeyValueStore( + distributed_client, + /*key_prefix=*/absl::StrCat(platform_name, ":")); + } + std::unique_ptr c_api_client = xla::ValueOrThrow( + GetCApiClient(platform_name, options, kv_store)); + ifrt_client = ifrt::PjRtClient::Create(std::move(c_api_client)); + } + return PyClient::Make(std::move(ifrt_client)); + }, + 
nb::arg("platform_name"), + nb::arg("options") = absl::flat_hash_map(), + nb::arg("distributed_client").none() = nullptr); + // TODO(b/322357665): Delete this method after TPU plugin changes to use the + // standard registration. + m.def("get_default_c_api_topology", + [](std::string platform_name, std::string topology_name, + const absl::flat_hash_map& options) + -> std::shared_ptr { + return std::make_shared(xla::ValueOrThrow( + GetCApiTopology(platform_name, topology_name, options))); + }); + m.def("get_c_api_topology", + [](nb::capsule c_api, std::string topology_name, + const absl::flat_hash_map& options) + -> std::shared_ptr { + if (absl::string_view(c_api.name()) != "pjrt_c_api") { + throw nb::value_error( + "Argument to get_c_api_topology was not a pjrt_c_api capsule."); + } + return std::make_shared(xla::ValueOrThrow( + GetCApiTopology(static_cast(c_api.data()), + topology_name, options))); + }); + m.def("get_topology_for_devices", + [](const std::vector>& py_devices) { + if (py_devices.empty()) { + throw nb::value_error( + "get_topology_for_devices requires >= 1 devices."); + } + auto client = py_devices[0]->client(); + absl::InlinedVector ifrt_devices; + ifrt_devices.reserve(py_devices.size()); + for (const auto& py_device : py_devices) { + if (py_device->client().get() != client.get()) { + throw nb::value_error( + "devices passed to get_topology_for_devices come from " + "different clients."); + } + ifrt_devices.push_back(py_device->device()); + } + ifrt::DeviceListRef device_list = + client->ifrt_client()->MakeDeviceList(ifrt_devices); + return xla::ValueOrThrow( + client->ifrt_client()->GetTopologyForDevices(device_list)); + }); + + TF_CHECK_OK(PyArray::RegisterTypes(m)); + jax::PyDeviceList::Register(m); + jax::RegisterSharding(m); + + nb::class_(m, "CompiledMemoryStats") + .def_rw("generated_code_size_in_bytes", + &CompiledMemoryStats::generated_code_size_in_bytes) + .def_rw("argument_size_in_bytes", + &CompiledMemoryStats::argument_size_in_bytes) + .def_rw("output_size_in_bytes", + &CompiledMemoryStats::output_size_in_bytes) + .def_rw("alias_size_in_bytes", &CompiledMemoryStats::alias_size_in_bytes) + .def_rw("temp_size_in_bytes", &CompiledMemoryStats::temp_size_in_bytes) + .def_rw("host_generated_code_size_in_bytes", + &CompiledMemoryStats::host_generated_code_size_in_bytes) + .def_rw("host_argument_size_in_bytes", + &CompiledMemoryStats::host_argument_size_in_bytes) + .def_rw("host_output_size_in_bytes", + &CompiledMemoryStats::host_output_size_in_bytes) + .def_rw("host_alias_size_in_bytes", + &CompiledMemoryStats::host_alias_size_in_bytes) + .def_rw("host_temp_size_in_bytes", + &CompiledMemoryStats::host_temp_size_in_bytes) + .def_prop_ro("serialized_buffer_assignment_proto", + [](const CompiledMemoryStats& cms) -> nb::bytes { + const std::string& s = cms.serialized_buffer_assignment; + return nb::bytes(s.data(), s.size()); + }) + .def("__str__", &CompiledMemoryStats::DebugString); + + nb::class_(m, "ExecuteResults") + .def("__len__", [](PyExecuteResults& results) { return results.Size(); }) + .def("disassemble_into_single_device_arrays", + &PyExecuteResults::DisassembleIntoSingleDeviceArrays) + .def("disassemble_prefix_into_single_device_arrays", + &PyExecuteResults::DisassemblePrefixIntoSingleDeviceArrays) + .def("consume_with_handlers", &PyExecuteResults::ConsumeWithHandlers) + .def("consume_token", &PyExecuteResults::ConsumeToken); + + m.def("get_execution_stream_id", []() { return GetExecutionStreamId(); }); + m.def("set_execution_stream_id", + [](int64_t id) { 
GetExecutionStreamId() = id; }); + + nb::class_(m, "LoadedExecutable") + .def_prop_ro("client", &PyLoadedExecutable::client) + .def("local_devices", &PyLoadedExecutable::AddressableDevices) + .def("size_of_generated_code_in_bytes", + &PyLoadedExecutable::SizeOfGeneratedCodeInBytes) + .def( + "get_compiled_memory_stats", + xla::ValueOrThrowWrapper(&PyLoadedExecutable::GetCompiledMemoryStats)) + .def("execute_sharded", + xla::ValueOrThrowWrapper(&PyLoadedExecutable::ExecuteSharded), + nb::arg("arguments"), nb::arg("with_tokens") = false) + .def("hlo_modules", ValueOrThrowWrapper(&PyLoadedExecutable::HloModules)) + .def("get_output_memory_kinds", + xla::ValueOrThrowWrapper(&PyLoadedExecutable::GetOutputMemoryKinds)) + .def("get_output_shardings", &PyLoadedExecutable::GetOutputShardings) + .def("get_parameter_layouts", + xla::ValueOrThrowWrapper(&PyLoadedExecutable::GetParameterLayouts)) + .def("get_output_layouts", + xla::ValueOrThrowWrapper(&PyLoadedExecutable::GetOutputLayouts)) + .def("get_parameter_shardings", + &PyLoadedExecutable::GetParameterShardings) + .def("keep_alive", &PyLoadedExecutable::KeepAlive) + .def("cost_analysis", + [](const PyLoadedExecutable& self) { + auto map = ValueOrThrow(self.GetCostAnalysis()); + return ifrt::ToPjRtAttributeMap(std::move(map)); + }) + .def_prop_ro("traceback", &PyLoadedExecutable::traceback) + .def_prop_ro("fingerprint", [](PyLoadedExecutable* exec) -> nb::object { + if (exec->fingerprint().has_value()) { + return nb::bytes(exec->fingerprint()->data(), + exec->fingerprint()->size()); + } else { + return nb::none(); + } + }); + nb::class_ token(m, "Token"); + token.def("block_until_ready", + [](PyToken& self) { xla::ThrowIfError(self.Await()); }); + + nb::class_ sharded_token(m, "ShardedToken"); + sharded_token.def("block_until_ready", [](PyShardedToken& self) { + xla::ThrowIfError(self.Await()); + }); + sharded_token.def("get_token", &PyShardedToken::GetPyToken); + + m.def("buffer_to_dlpack_managed_tensor", + xla::ValueOrThrowWrapper(BufferToDLPackManagedTensor), + nb::arg("buffer"), nb::arg("stream").none() = nb::none()); + m.def( + "dlpack_managed_tensor_to_buffer", + [](const nb::capsule& tensor, nb_class_ptr device, + std::optional stream) { + return xla::ValueOrThrow(DLPackManagedTensorToBuffer( + tensor, device->device(), device->client(), stream)); + }, + nb::arg("dlpack"), nb::arg("device"), nb::arg("stream").none()); + // Legacy overload + m.def( + "dlpack_managed_tensor_to_buffer", + [](const nb::capsule& tensor, + std::optional> cpu_client, + std::optional> gpu_client) { + return xla::ValueOrThrow(DLPackManagedTensorToBuffer( + tensor, std::move(cpu_client), std::move(gpu_client))); + }, + nb::arg("dlpack"), nb::arg("cpu_backend").none() = nb::none(), + nb::arg("gpu_backend").none() = nb::none()); + m.def("cuda_array_interface_to_buffer", + xla::ValueOrThrowWrapper(CudaArrayInterfaceToBuffer), nb::arg("cai"), + nb::arg("gpu_backend").none() = nb::none(), + nb::arg("device_id").none() = nb::none()); + + jax::BuildConfigSubmodule(m); + BuildIfrtProgramsSubmodule(m); + BuildPytreeSubmodule(m); + jax::BuildGuardSubmodule(m); + jax::BuildJaxjitSubmodule(m); + jax::BuildPmapSubmodule(m); + jax::BuildPjitSubmodule(m); + BuildTracebackSubmodule(m); + BuildMlirSubmodule(m); + BuildSdySubmodule(m); + BuildCustomCallShardingPybindAPI(m); + jax::BuildFfiSubmodule(m); +#if defined(__linux__) + aux::RegisterTransferServerTypes(m); +#endif // defined(__linux__) + + // The following uses python bindings for PyClient defined above using + // pybind11, 
and hence needs pybind11::module_ (not just nanobind::module_). + xla::ifrt::proxy::BuildIfrtProxySubmodule(m); + + nb::class_ preemption_sync_manager( + m, "PreemptionSyncManager"); + preemption_sync_manager + .def( + "initialize", + [](tsl::PreemptionSyncManager& manager, + DistributedRuntimeClient* client) { + tsl::CoordinationServiceAgent* agent = + xla::ValueOrThrow(client->GetCoordinationServiceAgent()); + xla::ThrowIfError(manager.Initialize(agent)); + }, + nb::arg("distributed_client")) + .def("reached_sync_point", + [](tsl::PreemptionSyncManager& manager, int step_counter) { + return manager.ReachedSyncPoint(step_counter); + }) + .def("shutdown", [](tsl::PreemptionSyncManager& manager) { + nb::gil_scoped_release gil_release; + manager.Shutdown(); + }); + m.def("create_preemption_sync_manager", + []() { return tsl::CreatePreemptionSyncManager(); }); + + nb::class_ distributed_runtime_service( + m, "DistributedRuntimeService"); + distributed_runtime_service.def("shutdown", + &DistributedRuntimeService::Shutdown, + nb::call_guard()); + nb::class_ distributed_runtime_client( + m, "DistributedRuntimeClient"); + distributed_runtime_client + .def("connect", + [](DistributedRuntimeClient& self) { + nb::gil_scoped_release gil_release; + xla::ThrowIfError(self.Connect()); + }) + .def("shutdown", + [](DistributedRuntimeClient& self) { + nb::gil_scoped_release gil_release; + xla::ThrowIfError(self.Shutdown()); + }) + // This method assumes that the value is a Python string. Use + // `blocking_key_value_get_bytes()` if key_value_set() was called with a + // Python bytes object as its value. + .def( + "blocking_key_value_get", + [](DistributedRuntimeClient& client, std::string key, + int64_t timeout_in_ms) { + nb::gil_scoped_release gil_release; + return xla::ValueOrThrow(client.BlockingKeyValueGet( + key, absl::Milliseconds(timeout_in_ms))); + }, + nb::arg("key"), nb::arg("timeout_in_ms")) + // Same as `blocking_key_value_get()`, but retrieves the raw Python byte + // values explicitly. 
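+      // Editorial usage sketch (not part of this change), using only the
+      // bindings defined in this chain; `client` is a DistributedRuntimeClient
+      // on the Python side:
+      //   client.key_value_set("k", "v")
+      //   client.blocking_key_value_get("k", timeout_in_ms=10_000)       # str
+      //   client.key_value_set_bytes("k", b"raw")
+      //   client.blocking_key_value_get_bytes("k", timeout_in_ms=10_000) # bytes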
+ .def( + "blocking_key_value_get_bytes", + [](DistributedRuntimeClient& client, std::string key, + int64_t timeout_in_ms) -> nb::bytes { + std::string result; + { + nb::gil_scoped_release gil_release; + result = xla::ValueOrThrow(client.BlockingKeyValueGet( + key, absl::Milliseconds(timeout_in_ms))); + } + return nb::bytes(result.data(), result.size()); + }, + nb::arg("key"), nb::arg("timeout_in_ms")) + .def( + "key_value_try_get", + [](DistributedRuntimeClient& client, std::string key) { + nb::gil_scoped_release gil_release; + return xla::ValueOrThrow(client.KeyValueTryGet(key)); + }, + nb::arg("key")) + .def( + "key_value_try_get_bytes", + [](DistributedRuntimeClient& client, std::string key) -> nb::bytes { + std::string result; + { + nb::gil_scoped_release gil_release; + result = xla::ValueOrThrow(client.KeyValueTryGet(key)); + } + return nb::bytes(result.data(), result.size()); + }, + nb::arg("key")) + .def( + "wait_at_barrier", + [](DistributedRuntimeClient& client, std::string barrier_id, + int64_t timeout_in_ms, + std::optional> process_ids) { + nb::gil_scoped_release gil_release; + xla::ThrowIfError(client.WaitAtBarrier( + barrier_id, absl::Milliseconds(timeout_in_ms), process_ids)); + }, + nb::arg("barrier_id"), nb::arg("timeout_in_ms"), + nb::arg("process_ids") = std::nullopt) + .def( + "get_live_nodes", + [](DistributedRuntimeClient& client, + std::vector process_ids) { + nb::gil_scoped_release gil_release; + return xla::ValueOrThrow(client.GetLiveNodes(process_ids)); + }, + nb::arg("process_ids")) + // The key must be a string, but the value can either be a Python string + // or bytes object. + // With Python string values, use `key_value_set()` and + // `blocking_key_value_get()`. + // With Python byte object values, use `key_value_set()` and + // `blocking_key_value_get_bytes()`. + .def( + "key_value_set", + [](DistributedRuntimeClient& client, absl::string_view key, + absl::string_view value, bool allow_overwrite) { + nb::gil_scoped_release gil_release; + xla::ThrowIfError(client.KeyValueSet(key, value, allow_overwrite)); + }, + nb::arg("key"), nb::arg("value"), nb::arg("allow_overwrite") = false) + // The key must be a string, but the value must a + // Python bytes object. + // Use `key_value_set_bytes()` and `blocking_key_value_get_bytes()`. + .def( + "key_value_set_bytes", + [](DistributedRuntimeClient& client, absl::string_view key, + nb::bytes value, bool allow_overwrite) { + nb::gil_scoped_release gil_release; + xla::ThrowIfError(client.KeyValueSet( + key, absl::string_view(value.c_str(), value.size()), + allow_overwrite)); + }, + nb::arg("key"), nb::arg("value"), nb::arg("allow_overwrite") = false) + // Assumes that all values in the directory are Python strings. + .def( + "key_value_dir_get", + [](DistributedRuntimeClient& client, absl::string_view key) { + nb::gil_scoped_release gil_release; + return xla::ValueOrThrow(client.KeyValueDirGet(key)); + }, + nb::arg("key")) + // Assumes that all values in the directory are Python byte objects. + // Same as `key_value_dir_get()`, but retrieves Python byte values + // explicitly. + .def( + "key_value_dir_get_bytes", + [](DistributedRuntimeClient& client, absl::string_view key) + -> std::vector> { + std::vector> result; + { + nb::gil_scoped_release gil_release; + result = xla::ValueOrThrow(client.KeyValueDirGet(key)); + } + // Convert std::string values to nb::bytes. 
+ std::vector> kvs; + kvs.reserve(result.size()); + for (auto& kv : result) { + kvs.push_back( + std::pair(std::move(kv.first), + nb::bytes(kv.second.data(), kv.second.size()))); + } + return kvs; + }, + nb::arg("key")) + .def( + "key_value_delete", + [](DistributedRuntimeClient& client, absl::string_view key) { + nb::gil_scoped_release gil_release; + return xla::ThrowIfError(client.KeyValueDelete(key)); + }, + nb::arg("key")); + + m.def( + "get_distributed_runtime_service", + [](std::string address, int num_nodes, + std::optional heartbeat_interval, + std::optional max_missing_heartbeats, + std::optional cluster_register_timeout, + std::optional shutdown_timeout) + -> std::unique_ptr { + CoordinationServiceImpl::Options options; + options.num_nodes = num_nodes; + options.heartbeat_timeout = + max_missing_heartbeats.value_or(10) * + absl::Seconds(heartbeat_interval.value_or(10)); + if (heartbeat_interval.has_value()) { + options.heartbeat_interval = absl::Seconds(*heartbeat_interval); + } + if (max_missing_heartbeats.has_value()) { + options.max_missing_heartbeats = *max_missing_heartbeats; + } + if (cluster_register_timeout.has_value()) { + options.cluster_register_timeout = + absl::Seconds(*cluster_register_timeout); + } + if (shutdown_timeout.has_value()) { + options.shutdown_timeout = absl::Seconds(*shutdown_timeout); + } + std::unique_ptr service = + xla::ValueOrThrow(GetDistributedRuntimeService(address, options)); + return service; + }, + nb::arg("address"), nb::arg("num_nodes"), + nb::arg("heartbeat_interval").none() = std::nullopt, + nb::arg("max_missing_heartbeats").none() = std::nullopt, + nb::arg("cluster_register_timeout").none() = std::nullopt, + nb::arg("shutdown_timeout").none() = std::nullopt); + + m.def( + "get_distributed_runtime_client", + [](std::string address, int node_id, std::optional rpc_timeout, + std::optional init_timeout, std::optional shutdown_timeout, + std::optional heartbeat_interval, + std::optional max_missing_heartbeats, + std::optional> + missed_heartbeat_callback, + std::optional shutdown_on_destruction, + std::optional use_compression) + -> std::shared_ptr { + bool compression = use_compression.value_or(false); + DistributedRuntimeClient::Options options; + options.node_id = node_id; + if (rpc_timeout.has_value()) { + options.rpc_timeout = absl::Seconds(*rpc_timeout); + } + if (init_timeout.has_value()) { + options.init_timeout = absl::Seconds(*init_timeout); + } + if (shutdown_timeout.has_value()) { + options.shutdown_timeout = absl::Seconds(*shutdown_timeout); + } + options.heartbeat_timeout = + max_missing_heartbeats.value_or(10) * + absl::Seconds(heartbeat_interval.value_or(10)); + if (heartbeat_interval.has_value()) { + options.heartbeat_interval = absl::Seconds(*heartbeat_interval); + } + if (max_missing_heartbeats.has_value()) { + options.max_missing_heartbeats = *max_missing_heartbeats; + } + if (missed_heartbeat_callback.has_value()) { + options.missed_heartbeat_callback = + std::move(*missed_heartbeat_callback); + } + if (shutdown_on_destruction.has_value()) { + options.shutdown_on_destruction = *shutdown_on_destruction; + } + return GetDistributedRuntimeClient(address, options, compression); + }, + nb::arg("address"), nb::arg("node_id"), + nb::arg("rpc_timeout").none() = std::nullopt, + nb::arg("init_timeout").none() = std::nullopt, + nb::arg("shutdown_timeout").none() = std::nullopt, + nb::arg("heartbeat_interval").none() = std::nullopt, + nb::arg("max_missing_heartbeats").none() = std::nullopt, + 
nb::arg("missed_heartbeat_callback").none() = std::nullopt, + nb::arg("shutdown_on_destruction").none() = std::nullopt, + nb::arg("use_compression").none() = std::nullopt); + + m.def("collect_garbage", []() { GlobalPyRefManager()->CollectGarbage(); }); + + m.def("is_optimized_build", &IsOptimizedBuild); + + m.def("json_to_pprof_profile", xla::ValueOrThrowWrapper(JsonToPprofProfile), + "Encodes the JSON representation of a pprof Profile into its binary " + "protocol buffer encoding."); + m.def("pprof_profile_to_json", xla::ValueOrThrowWrapper(PprofProfileToJson), + "Decodes an uncompressed pprof Profile protocol buffer into a JSON " + "representation"); + + RegisterCompileOnlyClient(m); + nb::class_(m, "DeviceTopology") + .def("_make_compile_only_devices", + [](std::shared_ptr topology) { + if (!llvm::isa(*topology)) { + throw xla::XlaRuntimeError("Only PjRtTopologies are supported."); + } + return MakeCompileOnlyClient( + std::dynamic_pointer_cast(topology)) + ->Devices(); + }) + .def_prop_ro( + "platform", + [](ifrt::Topology& topology) { return topology.platform_name(); }) + .def_prop_ro( + "platform_version", + [](ifrt::Topology& topology) { return topology.platform_version(); }) + .def("serialize", + [](ifrt::Topology& topology) -> nb::bytes { + std::string serialized = ValueOrThrow(topology.Serialize()); + return nb::bytes(serialized.data(), serialized.size()); + }) + .def("__getattr__", + [](ifrt::Topology& topology, absl::string_view name) -> nb::object { + const auto& attrs = topology.Attributes().map(); + auto it = attrs.find(name); + if (it != attrs.end()) { + return std::visit([](auto&& v) { return nb::cast(v.value); }, + it->second); + } + throw nb::attribute_error( + absl::StrCat("Unknown attribute ", name).c_str()); + }); + + nb::class_(m, "Executable") + .def("hlo_modules", ValueOrThrowWrapper(&PyExecutable::GetHloModules)) + .def("get_output_memory_kinds", + xla::ValueOrThrowWrapper(&PyExecutable::GetOutputMemoryKinds)) + .def("get_output_shardings", &PyExecutable::GetOutputShardings) + .def("get_parameter_layouts", + ValueOrThrowWrapper(&PyExecutable::GetParameterLayouts)) + .def("get_output_layouts", + xla::ValueOrThrowWrapper(&PyExecutable::GetOutputLayouts)) + .def("get_parameter_shardings", &PyExecutable::GetParameterShardings) + .def("get_compiled_memory_stats", + xla::ValueOrThrowWrapper(&PyExecutable::GetCompiledMemoryStats)) + .def("serialize", + [](const PyExecutable& exec) -> nb::bytes { + std::string serialized = ValueOrThrow(exec.Serialize()); + return nb::bytes(serialized.data(), serialized.size()); + }) + .def("cost_analysis", [](const PyExecutable& exec) { + auto attrs = ValueOrThrow(exec.GetCostAnalysis()); + return ifrt::ToPjRtAttributeMap(std::move(attrs)); + }); + + m.def("is_asan", IsAsan); + m.def("is_msan", IsMsan); + m.def("is_tsan", IsTsan); + m.def("is_sanitized", IsSanitized); + + m.def( + "batched_device_put", + [](nb::object aval, nb::object sharding, std::vector xs, + std::vector dst_devices, bool committed, + bool force_copy, + PjRtClient::HostBufferSemantics host_buffer_semantics) -> nb::object { + return ValueOrThrow(PyArray::BatchedDevicePut( + aval, sharding, std::move(xs), std::move(dst_devices), committed, + force_copy, host_buffer_semantics, jax::GetEnableX64())); + }, + nb::arg("aval"), nb::arg("sharding"), nb::arg("xs"), nb::arg("devices"), + nb::arg("committed") = true, nb::arg("force_copy") = false, + nb::arg("host_buffer_semantics") = + PjRtClient::HostBufferSemantics::kImmutableZeroCopy); + m.def( + "reorder_shards", + [](PyArray 
x, nb::object dst_sharding, + ifrt::ArrayCopySemantics array_copy_semantics) { + return ValueOrThrow(PyArray::ReorderShards( + std::move(x), std::move(dst_sharding), array_copy_semantics)); + }, + nb::arg("x"), nb::arg("dst_sharding"), nb::arg("array_copy_semantics")); + + m.def("batched_block_until_ready", [](std::vector xs) { + ThrowIfError(PyArray::BatchedBlockUntilReady(std::move(xs))); + }); + + m.def("check_and_canonicalize_memory_kind", + &jax::CheckAndCanonicalizeMemoryKind, nb::arg("memory_kind").none(), + nb::arg("device_list")); + + m.attr("ifrt_version_number") = JAX_IFRT_VERSION_NUMBER; + + m.def("approx_top_k_reduction_output_size", + xla::ValueOrThrowWrapper(ApproxTopKReductionOutputSize), + nb::arg("input_size"), nb::arg("rank"), nb::arg("top_k"), + nb::arg("recall_target"), nb::arg("aggregate_to_topk") = true, + nb::arg("input_size_override") = -1); + + m.def("get_internal_device_put_info", + []() { return DevicePutInfo::GetInfo(); }); + + jax::PartitionSpec::Register(m); +} // NOLINT(readability/fn_size) + +} // namespace xla diff --git a/jaxlib/xla_client.py b/jaxlib/xla_client.py new file mode 100644 index 000000000000..a5c85276ce2c --- /dev/null +++ b/jaxlib/xla_client.py @@ -0,0 +1,549 @@ +# Copyright 2017 The JAX Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""An XLA client in Python.""" + +from __future__ import annotations + +import atexit +from collections.abc import Mapping +import contextlib +import enum +import logging +import os +import threading +from typing import Any, Protocol, Union + +from jaxlib import _jax as _xla + +# Note this module does *not* depend on any Python protocol buffers. The XLA +# Python bindings are currently packaged both as part of jaxlib and as part +# of TensorFlow. If we use protocol buffers here, then importing both jaxlib +# and TensorFlow may fail with duplicate protocol buffer message definitions. + +# Most functions are snake_case for consistency with other modules, some +# method names are CamelCase for consistency with XLA. +# pylint: disable=invalid-name + +# Pylint has false positives for type annotations. +# pylint: disable=invalid-sequence-index + +ifrt_programs = _xla.ifrt_programs + +# Just an internal arbitrary increasing number to help with backward-compatible +# changes. In JAX, reference this via jax._src.lib.jaxlib_extension_version. +_version = 355 + +# An internal increasing version number for protecting jaxlib code against +# ifrt changes. +# lives in xla/python/version.h. +# In JAX, reference this via jax._src.lib.ifrt_version. 
+_ifrt_version = _xla.ifrt_version_number + +xla_platform_names = { + 'cpu': 'Host', + 'gpu': 'CUDA', +} + +logger = logging.getLogger(__name__) + +_NameValueMapping = Mapping[str, Union[str, int, list[int], float, bool]] + + +def make_cpu_client( + asynchronous=True, + distributed_client=None, + node_id=0, + num_nodes=1, + collectives=None, + num_devices=None, + get_local_topology_timeout_minutes=None, + get_global_topology_timeout_minutes=None, +) -> Client: + register_custom_call_handler('cpu', _xla.register_custom_call_target) + register_custom_type_id_handler('cpu', _xla.register_custom_type_id) + return _xla.get_tfrt_cpu_client( + asynchronous=asynchronous, + distributed_client=distributed_client, + node_id=node_id, + num_nodes=num_nodes, + collectives=collectives, + num_devices=num_devices, + get_local_topology_timeout_minutes=get_local_topology_timeout_minutes, + get_global_topology_timeout_minutes=get_global_topology_timeout_minutes, + ) + + +DeviceTopology = _xla.DeviceTopology +get_topology_for_devices = _xla.get_topology_for_devices + + +def make_tfrt_tpu_c_api_device_topology( + topology_name: str = '', **kwargs +) -> DeviceTopology: + """Creates a PJRT C API TopologyDescription.""" + return _xla.get_default_c_api_topology('tpu', topology_name, dict(**kwargs)) + + +def make_c_api_device_topology( + c_api: Any, topology_name: str = '', **kwargs +) -> DeviceTopology: + """Creates a PJRT C API TopologyDescription.""" + return _xla.get_c_api_topology(c_api, topology_name, dict(**kwargs)) + + +def pjrt_plugin_loaded(plugin_name: str) -> bool: + return _xla.pjrt_plugin_loaded(plugin_name) + + +def load_pjrt_plugin_dynamically(plugin_name: str, library_path: str) -> Any: + return _xla.load_pjrt_plugin(plugin_name, library_path, c_api=None) + + +def load_pjrt_plugin_with_c_api(plugin_name: str, c_api: Any) -> None: + return _xla.load_pjrt_plugin(plugin_name, None, c_api) + + +def pjrt_plugin_initialized(plugin_name: str) -> bool: + return _xla.pjrt_plugin_initialized(plugin_name) + + +def initialize_pjrt_plugin(plugin_name: str) -> None: + """Initializes a PJRT plugin. + + The plugin needs to be loaded first (through load_pjrt_plugin_dynamically or + static linking) before this method is called. + Args: + plugin_name: the name of the PJRT plugin. + """ + _xla.initialize_pjrt_plugin(plugin_name) + + +def make_c_api_client( + plugin_name: str, + options: _NameValueMapping | None = None, + distributed_client: _xla.DistributedRuntimeClient | None = None, +): + """Creates a PJRT C API client for a PJRT plugin. + + It is required that load_pjrt_plugin_dynamically is called once with the same + plugin_name before this method is called. + + Args: + plugin_name: the name of the PJRT plugin. + options: extra platform-specific options. + distributed_client: distributed client. + + Returns: + A PJRT C API client for plugin_name. + """ + if options is None: + options = {} + return _xla.get_c_api_client(plugin_name, options, distributed_client) + + +def generate_pjrt_gpu_plugin_options() -> _NameValueMapping: + """Generates the PjRt GPU plugin options. + + Returns: + A dictionary of plugin options. 
+ """ + + options = {} + options['platform_name'] = 'cuda' + allocator = os.getenv('XLA_PYTHON_CLIENT_ALLOCATOR', 'default').lower() + memory_fraction = os.getenv('XLA_CLIENT_MEM_FRACTION', '') + deprecated_memory_fraction = os.getenv('XLA_PYTHON_CLIENT_MEM_FRACTION', '') + if deprecated_memory_fraction: + if memory_fraction: + raise ValueError( + 'XLA_CLIENT_MEM_FRACTION is specified together ' + 'with XLA_PYTHON_CLIENT_MEM_FRACTION. ' + 'Remove the latter one, it is deprecated.' + ) + else: + memory_fraction = deprecated_memory_fraction + preallocate = os.getenv('XLA_PYTHON_CLIENT_PREALLOCATE', '') + collective_memory_size = os.getenv( + 'XLA_PYTHON_CLIENT_COLLECTIVE_MEM_SIZE_MB', '' + ) + if allocator not in ('default', 'platform', 'bfc', 'cuda_async'): + raise ValueError( + 'XLA_PYTHON_CLIENT_ALLOCATOR env var must be "default", "platform", ' + '"bfc", or "cuda_async", got "%s"' % allocator + ) + options['allocator'] = allocator + if memory_fraction: + options['memory_fraction'] = float(memory_fraction) + if preallocate: + options['preallocate'] = preallocate not in ('false', 'False', '0') + if collective_memory_size: + options['collective_memory_size'] = int(collective_memory_size) * (1 << 20) + return options + + +PrimitiveType = _xla.PrimitiveType + +Shape = _xla.Shape +Shape.__doc__ = """ +A Shape is an object defined in C++ that duck types like the following class: + +class Shape: + '''Represents an XLA shape. + + A shape is either an array shape, having rank-many integer + dimensions and an element type (represented by a Numpy dtype), or it + is a tuple shape, having a shape for every tuple component: + + type shape = + TupleShape of shape list + | ArrayShape of { dimensions: int list; element_type: dtype } + ''' + + @staticmethod + def tuple_shape(tuple_shapes) -> Shape: + "Construct a tuple shape." + + @staticmethod + def array_shape(element_type, dimensions, minor_to_major=None) -> Shape: + + @staticmethod + def from_pyval(pyval) -> Shape: + "Returns a Shape that describes a tuple-tree of Numpy arrays." + + def __init__(self, str) -> Shape: + "Parses a shape string." + def __eq__(self, other: Shape) -> bool: + def __ne__(self, other: Shape) -> bool: + def __hash__(self): + def __repr__(self): + def is_tuple(self) -> bool: + def is_array(self) -> bool: + def tuple_shapes(self) -> [Shape]: + def numpy_dtype(self) -> np.dtype: + "Like element_type(), but returns dtype('O') for a tuple shape." + def xla_element_type(self) -> PrimitiveType: + def element_type(self) -> np.dtype: + def dimensions(self) -> (int, int, ...): + def rank(self) -> int: + def with_major_to_minor_layout_if_absent(self) -> Shape: + "Returns a copy with missing layouts set to major-to-minor." + + def to_serialized_proto(self) -> bytes: + "Returns 'shape' as a serialized proto." +""" + +ProgramShape = _xla.ProgramShape +ProgramShape.__doc__ = """ +A ProgramShape is a C++ object that duck types like the following class. + +class ProgramShape: + def __init__(self, parameter_shapes, result_shape): + def parameter_shapes(self) -> [Shape]: + def result_shape(self) -> Shape: + def __repr__(self): +""" + +DeviceAssignment = _xla.DeviceAssignment +DeviceAssignment.__doc__ = """ +A DeviceAssignment is a C++ object with the following signature. + +def create(assignment): + '''Builds a device assignment. + + Args: + assignment: a 2D numpy array of device ordinal integers, indexed by + [replica][computation_in_replica]. + Returns: + A device assignment. 
+ ''' + +def replica_count(): + '''Returns the number of replicas.''' +def computation_count(): + '''Returns the number of computations per replica.''' +""" + +Device = _xla.Device +CompileOptions = _xla.CompileOptions + +HostBufferSemantics = _xla.HostBufferSemantics + +# An Executable is a C++ class that duck types with the following API: +# class Executable: +# def local_devices(self) -> [Device]: +# def execute(self, arguments : [Buffer]) -> Buffer: +# """Execute on one replica with Buffer arguments and return value.""" +# +# def size_of_generated_code_in_bytes(self) -> int: +# """Return generated binary size, or -1 if not known.""" +# +# def execute_sharded_on_local_devices(self, arguments: [[Buffer]]) +# -> [Buffer]: +# """Execute on many replicas with Buffer arguments and return value. +# +# Args: +# arguments: A sequence of sequences of Buffers. The i'th element of each +# sequence comprises the arguments for execution on the i'th local +# device. +# +# Returns: +# A list of the computation's outputs as a list of Buffers for each +# device. +# """ +# +# There are different implementations of Executable for different backends. + + +XlaComputation = _xla.XlaComputation +Client = _xla.Client +Memory = _xla.Memory +Array = _xla.Array +ArrayImpl = _xla.ArrayImpl +LoadedExecutable = _xla.LoadedExecutable +Executable = _xla.Executable +DeviceList = _xla.DeviceList +OpSharding = _xla.OpSharding +HloSharding = _xla.HloSharding +Sharding = _xla.Sharding +NamedSharding = _xla.NamedSharding +SingleDeviceSharding = _xla.SingleDeviceSharding +PmapSharding = _xla.PmapSharding +GSPMDSharding = _xla.GSPMDSharding +PjRtLayout = _xla.PjRtLayout +AutotuneCacheMode = _xla.AutotuneCacheMode + + +def LoadedExecutable_execute(self, arguments, device=None): + del device + results = self.execute_sharded(arguments) + return [x[0] for x in results.disassemble_into_single_device_arrays()] + + +def LoadedExecutable_execute_with_token(self, arguments, device=None): + del device + results = self.execute_sharded(arguments, with_tokens=True) + return ( + [x[0] for x in results.disassemble_into_single_device_arrays()], + results.consume_token().get_token(0), + ) + + +LoadedExecutable.execute = LoadedExecutable_execute +LoadedExecutable.execute_with_token = LoadedExecutable_execute_with_token + + +class CustomCallTargetTraits(enum.IntFlag): + DEFAULT = 0 + # Calls to custom call are safe to trace into the command buffer. It means + # that calls to custom call always launch exactly the same device operations + # (can depend on attribute values) that can be captured and then replayed. + # + # Supported only for custom calls implemented with XLA FFI. + COMMAND_BUFFER_COMPATIBLE = 1 + + +class CustomCallHandler(Protocol): + + def __call__( + self, + name: str, + fn: Any, + platform: str, + /, + api_version: int = ..., + traits: CustomCallTargetTraits = ..., + ) -> None: + ... + + +_custom_callback_handler: dict[str, CustomCallHandler] = {} +# Key is xla_platform_name, value is (function_name, function, api_version) +_custom_callback: dict[ + str, list[tuple[str, Any, int, CustomCallTargetTraits]] +] = {} +_custom_callback_lock = threading.Lock() + + +def register_custom_call_target( + name: str, + fn: Any, + platform: str = 'cpu', + api_version: int = 0, + traits: CustomCallTargetTraits = CustomCallTargetTraits.DEFAULT, +) -> None: + """Registers a custom call target. + + Args: + name: bytes containing the name of the function. + fn: a PyCapsule object containing the function pointer. + platform: the target platform. 
+    api_version: the XLA FFI version to use. Supported versions are: 0 for the
+      untyped FFI and 1 for the typed FFI.
+    traits: custom call traits corresponding to XLA FFI handler traits.
+  """
+  # To support AMD GPUs, we need to have xla_platform_names["gpu"] == "ROCM"
+  # Since that is hardcoded to CUDA, we are using the following as a workaround.
+  xla_platform_name = xla_platform_names.get(platform, platform)
+  with _custom_callback_lock:
+    if xla_platform_name in _custom_callback_handler:
+      _custom_callback_handler[xla_platform_name](
+          name, fn, xla_platform_name, api_version, traits
+      )
+    else:
+      _custom_callback.setdefault(xla_platform_name, []).append(
+          (name, fn, api_version, traits)
+      )
+
+
+def register_custom_call_handler(
+    platform: str, handler: CustomCallHandler
+) -> None:
+  """Registers a custom handler and uses it to register existing custom calls.
+
+  If a custom call handler for the platform already exists, calling this method
+  is a no-op and it will not register a new handler.
+
+  Args:
+    platform: the target platform.
+    handler: the function to register a custom call.
+  """
+  xla_platform_name = xla_platform_names.get(platform, platform)
+  with _custom_callback_lock:
+    if xla_platform_name in _custom_callback_handler:
+      logger.debug(
+          'Custom call handler for %s is already registered. Will not register a'
+          ' new one',
+          xla_platform_name,
+      )
+      return
+    _custom_callback_handler[xla_platform_name] = handler
+    if xla_platform_name in _custom_callback:
+      for name, fn, api_version, traits in _custom_callback[xla_platform_name]:
+        handler(name, fn, xla_platform_name, api_version, traits)
+      del _custom_callback[xla_platform_name]
+
+
+class CustomTypeIdHandler(Protocol):
+
+  def __call__(self, name: str, capsule: Any) -> None:
+    ...
+
+
+_custom_type_id_handler: dict[str, CustomTypeIdHandler] = {}
+_custom_type_id: dict[str, Any] = {}
+_custom_type_id_lock = threading.Lock()
+
+
+def register_custom_type_id(
+    type_name: str,
+    type_id: Any,
+    platform: str = 'cpu',
+) -> None:
+  """Register a custom type id for use with the FFI.
+
+  Args:
+    type_name: a unique name for the type.
+    type_id: a PyCapsule object containing a pointer to the ``ffi::TypeId``.
+    platform: the target platform.
+  """
+  xla_platform_name = xla_platform_names.get(platform, platform)
+  with _custom_type_id_lock:
+    if xla_platform_name in _custom_type_id_handler:
+      _custom_type_id_handler[xla_platform_name](type_name, type_id)
+    else:
+      _custom_type_id.setdefault(xla_platform_name, []).append(
+          (type_name, type_id)
+      )
+
+
+def register_custom_type_id_handler(
+    platform: str, handler: CustomTypeIdHandler
+) -> None:
+  """Register a custom type id handler and use it to register existing type ids.
+
+  If a custom type id handler for the platform already exists, calling this
+  method is a no-op and it will not register a new handler.
+
+  Args:
+    platform: the target platform.
+    handler: the function to register a custom type id.
+  """
+  xla_platform_name = xla_platform_names.get(platform, platform)
+  with _custom_type_id_lock:
+    if xla_platform_name in _custom_type_id_handler:
+      logger.debug(
+          'Custom type id handler for %s is already registered. 
Will not ' + 'register a new one', + xla_platform_name, + ) + return + _custom_type_id_handler[xla_platform_name] = handler + if xla_platform_name in _custom_type_id: + for name, capsule in _custom_type_id[xla_platform_name]: + handler(name, capsule) + del _custom_type_id[xla_platform_name] + + +register_custom_call_partitioner = _xla.register_custom_call_partitioner +encode_inspect_sharding_callback = _xla.encode_inspect_sharding_callback +hlo_sharding_util = _xla.hlo_sharding_util +register_custom_call_as_batch_partitionable = ( + _xla.register_custom_call_as_batch_partitionable +) + + +Traceback = _xla.Traceback +Frame = _xla.Frame + + +@contextlib.contextmanager +def tracebacks(enabled=True): + """Context manager that enables or disables traceback collection.""" + saved = _xla.tracebacks_enabled() + _xla.set_tracebacks_enabled(enabled) + try: + yield + finally: + _xla.set_tracebacks_enabled(saved) + + +@contextlib.contextmanager +def execution_stream_id(new_id: int): + """Context manager that overwrites and restores the current thread's execution_stream_id.""" + saved = _xla.get_execution_stream_id() + _xla.set_execution_stream_id(new_id) + try: + yield + finally: + _xla.set_execution_stream_id(saved) + + +XlaRuntimeError = _xla.XlaRuntimeError + +# Perform one last garbage collection of deferred Python references. This is +# mostly to keep ASAN happy. +atexit.register(_xla.collect_garbage) + +array_result_handler = _xla.array_result_handler +batched_copy_array_to_devices_with_sharding = ( + _xla.batched_copy_array_to_devices_with_sharding +) +batched_device_put = _xla.batched_device_put +reorder_shards = _xla.reorder_shards +batched_block_until_ready = _xla.batched_block_until_ready +check_and_canonicalize_memory_kind = _xla.check_and_canonicalize_memory_kind +Layout = _xla.Layout +custom_call_targets = _xla.custom_call_targets +ArrayCopySemantics = _xla.ArrayCopySemantics diff --git a/jaxlib/xla_client.pyi b/jaxlib/xla_client.pyi new file mode 100644 index 000000000000..ce9a2b815809 --- /dev/null +++ b/jaxlib/xla_client.pyi @@ -0,0 +1,153 @@ +# Copyright 2021 The JAX Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +from __future__ import annotations + +from collections.abc import Callable, Mapping, Sequence +import enum +from typing import Any, Union + +from jaxlib import _jax as _xla +from jaxlib._jax import ArrayCopySemantics as ArrayCopySemantics +from jaxlib._jax import ArrayImpl as ArrayImpl +from jaxlib._jax import AutotuneCacheMode as AutotuneCacheMode +from jaxlib._jax import Client as Client +from jaxlib._jax import CompileOptions as CompileOptions +from jaxlib._jax import Device as Device +from jaxlib._jax import DeviceAssignment as DeviceAssignment +from jaxlib._jax import DeviceList as DeviceList +from jaxlib._jax import DeviceTopology as DeviceTopology +from jaxlib._jax import DistributedRuntimeClient as DistributedRuntimeClient +from jaxlib._jax import Frame as Frame +from jaxlib._jax import GSPMDSharding as GSPMDSharding +from jaxlib._jax import HloSharding as HloSharding +from jaxlib._jax import HostBufferSemantics as HostBufferSemantics +from jaxlib._jax import ifrt_programs as ifrt_programs +from jaxlib._jax import Layout as Layout +from jaxlib._jax import LoadedExecutable as LoadedExecutable +from jaxlib._jax import Executable as Executable +from jaxlib._jax import Memory as Memory +from jaxlib._jax import NamedSharding as NamedSharding +from jaxlib._jax import OpSharding as OpSharding +from jaxlib._jax import PjRtLayout as PjRtLayout +from jaxlib._jax import PmapSharding as PmapSharding +from jaxlib._jax import PrimitiveType as PrimitiveType +from jaxlib._jax import Shape as Shape +from jaxlib._jax import Sharding as Sharding +from jaxlib._jax import SingleDeviceSharding as SingleDeviceSharding +from jaxlib._jax import Traceback as Traceback +from jaxlib._jax import XlaComputation as XlaComputation + +_version: int +_ifrt_version: int + +_NameValueMapping = Mapping[str, Union[str, int, list[int], float, bool]] + +XlaRuntimeError = _xla.XlaRuntimeError + +def make_cpu_client( + asynchronous: bool = ..., + distributed_client: DistributedRuntimeClient | None = ..., + node_id: int = ..., + num_nodes: int = ..., + collectives: _xla.CpuCollectives | None = ..., + num_devices: int | None = ..., + get_local_topology_timeout_minutes: int | None = ..., + get_global_topology_timeout_minutes: int | None = ..., +) -> Client: ... +def make_gpu_client( + distributed_client: DistributedRuntimeClient | None = ..., + node_id: int = ..., + num_nodes: int = ..., + platform_name: str | None = ..., + allowed_devices: set[int] | None = ..., + mock: bool | None = ..., + mock_gpu_topology: str | None = ..., +) -> Client: ... +def make_tfrt_tpu_c_api_device_topology( + topology_name: str | None = None, **kwargs +) -> DeviceTopology: ... +def make_c_api_device_topology( + c_api: Any, topology_name: str = '', **kwargs +) -> DeviceTopology: ... +def get_topology_for_devices(devices: list[Device]) -> DeviceTopology: ... +def make_c_api_client( + plugin_name: str, + options: _NameValueMapping | None = None, + distributed_client: DistributedRuntimeClient | None = None, +) -> Client: ... +def pjrt_plugin_loaded(plugin_name: str) -> bool: ... +def load_pjrt_plugin_dynamically( + plugin_name: str, library_path: str +) -> Any: ... +def load_pjrt_plugin_with_c_api(plugin_name: str, c_api: Any) -> None: ... +def pjrt_plugin_initialized(plugin_name: str) -> bool: ... +def initialize_pjrt_plugin(plugin_name: str) -> None: ... +def generate_pjrt_gpu_plugin_options() -> _NameValueMapping: ... 
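Editorial note, not part of the diff: the PJRT plugin entry points declared above compose into a short bootstrap sequence. The sketch below is illustrative only and uses only functions declared in xla_client.py/.pyi; the plugin name and library path are assumptions, not values taken from this change.

  from jaxlib import xla_client

  # Local CPU client; no plugin loading needed.
  cpu_client = xla_client.make_cpu_client()

  # Plugin-backed client (plugin name and .so path below are hypothetical).
  plugin_name = 'cuda'
  if not xla_client.pjrt_plugin_loaded(plugin_name):
    xla_client.load_pjrt_plugin_dynamically(plugin_name, '/path/to/pjrt_plugin.so')
  if not xla_client.pjrt_plugin_initialized(plugin_name):
    xla_client.initialize_pjrt_plugin(plugin_name)
  # For the CUDA plugin, options can be derived from the XLA_PYTHON_CLIENT_*
  # environment variables handled by generate_pjrt_gpu_plugin_options().
  options = xla_client.generate_pjrt_gpu_plugin_options()
  gpu_client = xla_client.make_c_api_client(plugin_name, options)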
+def batched_copy_array_to_devices_with_sharding( + arrays: Sequence[ArrayImpl], + devices: Sequence[list[Device]], + sharding: Sequence[Any], + array_copy_semantics: Sequence[ArrayCopySemantics], +) -> Sequence[ArrayImpl]: ... +def batched_device_put( + aval: Any, + sharding: Any, + shards: Sequence[Any], + devices: list[Device], + committed: bool = ..., + force_copy: bool = ..., + host_buffer_semantics: Any = ..., +) -> ArrayImpl: ... +def reorder_shards( + x: ArrayImpl, + dst_sharding: Any, + array_copy_semantics: ArrayCopySemantics, +) -> ArrayImpl: ... +def batched_block_until_ready(x: Sequence[ArrayImpl]) -> None: ... +def check_and_canonicalize_memory_kind( + memory_kind: str | None, device_list: DeviceList +) -> str | None: ... +def array_result_handler( + aval: Any, sharding: Any, committed: bool, _skip_checks: bool = ... +) -> Callable: ... + +class CustomCallTargetTraits(enum.IntFlag): + DEFAULT = 0 + COMMAND_BUFFER_COMPATIBLE = 1 + +def register_custom_call_target( + name: str, + fn: Any, + platform: str = ..., + api_version: int = ..., + traits: CustomCallTargetTraits = ..., +) -> None: ... +def register_custom_call_handler( + xla_platform_name: str, handler: Any +) -> None: ... +def custom_call_targets(platform: str) -> dict[str, Any]: ... +def register_custom_type_id( + type_name: str, + type_id: Any, + platform: str = ..., +) -> None: ... +def register_custom_type_id_handler(platform: str, handler: Any) -> None: ... +def encode_inspect_sharding_callback(handler: Any) -> bytes: ... + +register_custom_call_partitioner = _xla.register_custom_call_partitioner +register_custom_call_as_batch_partitionable = ( + _xla.register_custom_call_as_batch_partitionable +) diff --git a/jaxlib/xla_compiler.cc b/jaxlib/xla_compiler.cc new file mode 100644 index 000000000000..57de57b26aee --- /dev/null +++ b/jaxlib/xla_compiler.cc @@ -0,0 +1,1464 @@ +/* Copyright 2020 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "jaxlib/xla_compiler.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/container/inlined_vector.h" +#include "absl/hash/hash.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "absl/strings/str_split.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "mlir/Support/LLVM.h" +#include "nanobind/nanobind.h" +#include "nanobind/ndarray.h" +#include "nanobind/stl/optional.h" // IWYU pragma: keep +#include "nanobind/stl/pair.h" // IWYU pragma: keep +#include "nanobind/stl/shared_ptr.h" // IWYU pragma: keep +#include "nanobind/stl/string.h" // IWYU pragma: keep +#include "nanobind/stl/string_view.h" // IWYU pragma: keep +#include "nanobind/stl/variant.h" // IWYU pragma: keep +#include "nanobind/stl/vector.h" // IWYU pragma: keep +#include "jaxlib/dlpack.h" +#include "jaxlib/py_client.h" +#include "xla/array.h" +#include "xla/client/executable_build_options.h" +#include "xla/debug_options_flags.h" +#include "xla/ffi/api/c_api.h" +#include "xla/ffi/ffi.h" +#include "xla/ffi/ffi_api.h" +#include "xla/hlo/builder/xla_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_print_options.h" +#include "xla/hlo/ir/hlo_sharding.h" +#include "xla/hlo/parser/hlo_parser.h" +#include "xla/layout.h" +#include "xla/layout_util.h" +#include "xla/literal.h" +#include "xla/pjrt/exceptions.h" +#include "xla/pjrt/pjrt_executable.h" +#include "xla/pjrt/proto/compile_options.pb.h" +#include "xla/pjrt/status_casters.h" +#include "xla/python/nb_absl_span.h" // IWYU pragma: keep +#include "xla/python/nb_numpy.h" +#include "xla/python/types.h" +#include "xla/service/computation_placer.h" +#include "xla/service/custom_call_target_registry.h" +#include "xla/service/hlo.pb.h" +#include "xla/service/hlo_graph_dumper.h" +#include "xla/service/hlo_module_config.h" +#include "xla/service/spmd/shardy/stablehlo_round_trip/stablehlo_import.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/tsl/lib/strings/proto_serialization.h" +#include "xla/tsl/platform/env.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/util.h" +#include "xla/xla.pb.h" +#include "xla/xla_data.pb.h" + +namespace xla { +namespace { + +namespace nb = nanobind; + +// Converts a computation to a serialized HloModuleProto. +absl::StatusOr GetComputationSerializedProto( + const XlaComputation& computation) { + std::string result; + if (!tsl::SerializeToStringDeterministic(computation.proto(), &result)) { + return Unknown("Failed to serialize the HloModuleProto."); + } + return nb::bytes(result.data(), result.size()); +} + +// Converts a hlo module to a serialized HloModuleProto. +absl::StatusOr GetHloModuleSerializedProto(const HloModule& module) { + std::string result; + if (!tsl::SerializeToStringDeterministic(module.ToProto(), &result)) { + return Unknown("Failed to serialize the HloModuleProto."); + } + return nb::bytes(result.data(), result.size()); +} + +// Converts a serialized HloModuleProto into a HloModule. 
+absl::StatusOr> HloModuleFromSerializedProto( + const nb::bytes& bytes) { + HloModuleProto proto; + proto.ParseFromArray(bytes.c_str(), bytes.size()); + TF_ASSIGN_OR_RETURN(const HloModuleConfig module_config, + HloModule::CreateModuleConfigFromProto( + proto, GetDebugOptionsFromFlags())); + TF_ASSIGN_OR_RETURN(std::unique_ptr module, + HloModule::CreateFromProto(proto, module_config)); + return std::shared_ptr(std::move(module)); +} + +absl::StatusOr> GetHloModule( + const XlaComputation& computation) { + TF_ASSIGN_OR_RETURN(const HloModuleConfig module_config, + HloModule::CreateModuleConfigFromProto( + computation.proto(), GetDebugOptionsFromFlags())); + TF_ASSIGN_OR_RETURN( + std::unique_ptr module, + HloModule::CreateFromProto(computation.proto(), module_config)); + return std::shared_ptr(std::move(module)); +} + +// Converts a computation to textual HLO form. +absl::StatusOr GetComputationHloText( + const XlaComputation& computation, bool print_large_constants = false) { + TF_ASSIGN_OR_RETURN(std::shared_ptr hlo_module, + GetHloModule(computation)); + HloPrintOptions options; + options = HloPrintOptions::ShortParsable(); + options.set_print_large_constants(print_large_constants); + return hlo_module->ToString(options); +} + +// Converts a computation to HLO dot graph form. +absl::StatusOr GetComputationHloDotGraph( + const XlaComputation& computation) { + TF_ASSIGN_OR_RETURN(std::shared_ptr hlo_module, + GetHloModule(computation)); + return RenderGraph(*hlo_module->entry_computation(), /*label=*/"", + hlo_module->config().debug_options(), + RenderedGraphFormat::kDot); +} + +// Hashes the HLO module. +absl::StatusOr HashComputation(const XlaComputation& computation) { + TF_ASSIGN_OR_RETURN(std::shared_ptr hlo_module, + GetHloModule(computation)); + return absl::HashOf(*hlo_module); +} +// Safe version of ShapeUtil::MakeShapeWithDenseLayout that fails gracefully on +// invalid input. +absl::StatusOr MakeShapeWithDenseLayout( + PrimitiveType element_type, absl::Span dims, + std::optional> minor_to_major, + std::optional> dynamic_dimensions) { + Shape shape; + if (dynamic_dimensions) { + TF_ASSIGN_OR_RETURN( + shape, ShapeUtil::MakeValidatedShape(element_type, dims, + dynamic_dimensions.value())); + } else { + TF_ASSIGN_OR_RETURN(shape, + ShapeUtil::MakeValidatedShape(element_type, dims)); + } + if (minor_to_major) { + *shape.mutable_layout() = LayoutUtil::MakeLayout(*minor_to_major); + TF_RETURN_IF_ERROR( + LayoutUtil::ValidateLayoutForShape(shape.layout(), shape)); + } + + return shape; +} + +// Pybind function for HloSharding.iota_tile, which is a non-crashing factory +// that produces a HloSharding instance backed by tile assignment of a +// transposed and reshaped iota array of device ids. More specifically the tile +// assignment array is as if it is produced by the following numpy code: +// numpy.arange(math.prod(dims)).reshape(reshape_dims) +// .transpose(transpose_perm).reshape(math.prod(dims)) +// where: +// `dims`: is the dimensions of the tile assignment array, which corresponds to +// OpSharding.tile_assignment_dimensions. +// `reshape_dims`: is the dimensions the 1D iota array is reshaped to. +// `transpose_perm`: is the dimension permutation to transpose `reshape_dims`. +// `subgroup_types`: indicates the subgroups of the last `subgroup_types.size()` +// dimensions in `dims`. 
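+// Editorial illustration (not in the original comment): with dims=[4,2],
+// reshape_dims=[2,2,2] and transpose_perm=[1,0,2], the device order is
+// numpy.arange(8).reshape(2,2,2).transpose(1,0,2).reshape(8) =
+// [0,1,4,5,2,3,6,7], read back as the [4,2] tile assignment
+// [[0,1],[4,5],[2,3],[6,7]].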
+// +// In practice, `reshape_dims` often maps to the axes of user defined device +// mesh, and `transpose_perm` often maps to the user specification of how a +// tensor is partitioned based on the axes defined in the mesh, e.g. for a mesh +// of size 4x2x2 as AxBxC: +// PartitionSpec('A', 'B', 'C') corresponds to reshape_dims=[4,2,2], +// transpose_perm=[0,1,2] (no transpose) +// PartitionSpec('B', 'A', 'C') corresponds to reshape_dims=[4,2,2], +// transpose_perm=[1,0,2] (swap A and B) +absl::StatusOr IotaTileHelper( + absl::Span dims, absl::Span reshape_dims, + absl::Span transpose_perm, + absl::Span subgroup_types) { + if (dims.empty()) { + return InvalidArgument("`dims` should not be empty."); + } + if (reshape_dims.size() != transpose_perm.size()) { + return InvalidArgument( + "`reshape_dims` and `transpose_perm` should have the same size, saw " + "[%s] v.s. [%s]", + absl::StrJoin(reshape_dims, ","), absl::StrJoin(transpose_perm, ",")); + } + if (!reshape_dims.empty() && Product(dims) != Product(reshape_dims)) { + return InvalidArgument( + "Cannot reshape from `dims` [%s] to `reshape_dims` [%s].", + absl::StrJoin(dims, ","), absl::StrJoin(reshape_dims, ",")); + } + if (subgroup_types.size() > dims.size()) { + return InvalidArgument( + "`subgroup_types`(%lld) should not have more dimensions than " + "`dims`(%lld).", + subgroup_types.size(), dims.size()); + } + if (reshape_dims.empty()) { + return subgroup_types.empty() + ? HloSharding::IotaTile(dims) + : HloSharding::Subgroup(TileAssignment(dims), subgroup_types); + } + return subgroup_types.empty() + ? HloSharding::IotaTile(dims, reshape_dims, transpose_perm) + : HloSharding::Subgroup( + TileAssignment(dims, reshape_dims, transpose_perm), + subgroup_types); +} + +// Registers a 'fn' as a custom call target. +// +// `fn` must be a custom call implementation function pointer (XLA_FFI_Handler* +// when implemented as FFI handler) encapsulated in a PyCapsule object or a +// a dictionary of function pointers (also encapsulated in a PyCapsule). +// +// See XLA_FFI_ExecutionStage documentation for more details about the +// custom execution stages. +absl::Status PyRegisterCustomCallTarget(const std::string& fn_name, + nb::object fn, + const std::string& platform, + int api_version, + XLA_FFI_Handler_Traits traits) { + // Register legacy custom call target (untyped void* API). + if (api_version == 0) { + if (traits != 0) { + return absl::InvalidArgumentError( + "Custom call target registration with traits is not supported for " + "api_version=0"); + } + + nb::capsule capsule; + if (!nb::try_cast(fn, capsule)) { + return absl::InvalidArgumentError( + "Custom call target registration with api_version=0 requires a " + "PyCapsule fn object"); + } + + CustomCallTargetRegistry::Global()->Register( + fn_name, static_cast(capsule.data()), platform); + return absl::OkStatus(); + } + + // Register XLA FFI handler (typed API with explicit function signatures). 
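The numpy recipe documented for IotaTileHelper above, written out for the 4x2x2 AxBxC mesh example. HloSharding.iota_tile (bound later in this file) produces the same tile assignment without materializing the array; the `xc` alias is again a placeholder.

    import math
    import numpy as np

    dims = [2, 4, 2]            # tile_assignment_dimensions for PartitionSpec('B', 'A', 'C')
    reshape_dims = [4, 2, 2]    # device mesh axis sizes (A, B, C)
    transpose_perm = [1, 0, 2]  # swap A and B

    tile_assignment = (
        np.arange(math.prod(dims))
        .reshape(reshape_dims)      # lay device ids out along the mesh axes
        .transpose(transpose_perm)  # reorder axes to match the partition spec
        .reshape(dims)              # final tile assignment shape
    )
    # Equivalent, backed by TileAssignment's compact iota representation:
    # sharding = xc.HloSharding.iota_tile(dims, reshape_dims, transpose_perm)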
+ if (api_version == 1) { + nb::capsule capsule; + if (nb::try_cast(fn, capsule)) { + return ffi::TakeStatus(ffi::Ffi::RegisterStaticHandler( + xla::ffi::GetXlaFfiApi(), fn_name, platform, + reinterpret_cast( + static_cast(capsule.data())))); + } + + nb::dict bundle; + if (nb::try_cast(fn, bundle)) { + auto handler = [&](const char* name) -> absl::StatusOr { + if (!bundle.contains(name)) return nullptr; + + nb::capsule capsule; + if (!nb::try_cast(bundle[name], capsule)) { + return absl::InvalidArgumentError( + "Custom call target registration with api_version=1 requires a " + "PyCapsule fn object for all dict keys"); + } + + return reinterpret_cast(capsule.data()); + }; + + XLA_FFI_Handler_Bundle bundle; + TF_ASSIGN_OR_RETURN(bundle.instantiate, handler("instantiate")); + TF_ASSIGN_OR_RETURN(bundle.prepare, handler("prepare")); + TF_ASSIGN_OR_RETURN(bundle.initialize, handler("initialize")); + TF_ASSIGN_OR_RETURN(bundle.execute, handler("execute")); + + return ffi::TakeStatus(ffi::Ffi::RegisterStaticHandler( + xla::ffi::GetXlaFfiApi(), fn_name, platform, bundle, traits)); + } + + return absl::InvalidArgumentError( + "Unsupported custom call target type for api_version=1"); + } + + return absl::UnimplementedError(absl::StrFormat( + "API version %d is not supported by RegisterCustomCallTarget. " + "Supported versions are 0 and 1.", + api_version)); +} + +absl::Status PyRegisterCustomTypeId(absl::string_view type_name, + nb::object type_id) { + nb::capsule capsule; + if (!nb::try_cast(type_id, capsule)) { + return absl::InvalidArgumentError( + "The type_id argument to register_custom_call_type_id must be a " + "PyCapsule object holding a pointer to a XLA_FFI_TypeId."); + } + XLA_FFI_TypeId* type_id_ptr = + reinterpret_cast(static_cast(capsule.data())); + return ffi::TakeStatus(ffi::Ffi::RegisterTypeId(xla::ffi::GetXlaFfiApi(), + type_name, type_id_ptr)); +} + +template +void DefRepeatedProperty(nb::class_& cls, const char* name, + Container* (T::*getter)()) { + cls.def_prop_rw( + name, + [getter](T& obj) { + Container* elems = (obj.*getter)(); + std::vector result; + result.reserve(elems->size()); + std::copy(elems->begin(), elems->end(), std::back_inserter(result)); + return result; + }, + [getter](T& obj, std::vector new_elems) { + Container* elems = (obj.*getter)(); + elems->Clear(); + elems->Reserve(new_elems.size()); + for (typename Container::value_type& e : new_elems) { + elems->Add(std::move(e)); + } + }); +} + +template +void DefRepeatedEnumProperty(nb::class_& cls, const char* name, + Container* (T::*getter)()) { + cls.def_prop_rw( + name, + [getter](T& obj) { + Container* elems = (obj.*getter)(); + std::vector result; + result.reserve(elems->size()); + std::copy(elems->begin(), elems->end(), std::back_inserter(result)); + return result; + }, + [getter](T& obj, nb::sequence new_elems) { + Container* elems = (obj.*getter)(); + elems->Clear(); + for (nb::handle e : new_elems) { + elems->Add(nb::cast(e.attr("value"))); + } + }); +} + +template +Array NDArrayToArray(nb::ndarray ndarray) { + std::vector shapes; + shapes.reserve(ndarray.ndim()); + for (int i = 0; i < ndarray.ndim(); ++i) { + shapes.push_back(ndarray.shape(i)); + } + xla::Array array(shapes); + array.Each([&](absl::Span indices, int64_t* val) { + int64_t offset = indices.back(); + int64_t multiplier = 1; + for (int i = ndarray.ndim() - 1; i > 0; --i) { + multiplier *= ndarray.shape(i); + offset += indices[i - 1] * multiplier; + } + *val = *(ndarray.data() + offset); + }); + return array; +} + +absl::StatusOr 
SubgroupWithTileAssignmentHelper( + nb::ndarray tile_assignment, + absl::Span subgroup_types) { + return HloSharding::Subgroup(NDArrayToArray(tile_assignment), subgroup_types); +} + +nb::ndarray<> LiteralToNdarray(Literal& obj) { + const Shape& shape = obj.shape(); + + if (!shape.has_layout()) { + throw XlaRuntimeError( + "Creating an array is only supported for Literals with a layout."); + } + + const Layout& layout = shape.layout(); + + if (!layout.tiles().empty()) { + throw XlaRuntimeError( + "Creating an array from a tiled Literal is not supported."); + } + + if (!LayoutUtil::IsDenseArray(shape)) { + throw XlaRuntimeError( + "Creating an array is only supported for dense Literals."); + } + + xla::PrimitiveType primitive_type = shape.element_type(); + nb::dlpack::dtype dtype = + ValueOrThrow(PrimitiveTypeToNbDLDataType(primitive_type)); + + absl::Span dimensions = shape.dimensions(); + std::vector unsigned_dimensions(dimensions.begin(), dimensions.end()); + auto strides = StridesForShape(primitive_type, dimensions, layout); + + return nb::ndarray<>(obj.untyped_data(), unsigned_dimensions.size(), + unsigned_dimensions.data(), {}, strides.data(), dtype, + nb::device::cpu::value, 0); +} + +} // namespace + +void BuildXlaCompilerSubmodule(nb::module_& m) { + // Shapes + nb::class_ layout_class(m, "Layout"); + layout_class.def(nb::init>()) + .def("__init__", + [](Layout* self, nb::sequence minor_to_major, nb::sequence tiling, + int64_t element_size_in_bits) { + std::vector xla_tiles; + xla_tiles.reserve(nb::len(tiling.ptr())); + for (auto tile : tiling) { + xla_tiles.push_back(Tile( + SequenceToVector(nb::cast(tile)))); + } + std::vector xla_minor_to_major = + SequenceToVector(minor_to_major); + new (self) + Layout(xla_minor_to_major, xla_tiles, element_size_in_bits); + }) + .def("minor_to_major", + [](Layout layout) { return SpanToNbTuple(layout.minor_to_major()); }) + .def("element_size_in_bits", &Layout::element_size_in_bits) + .def("tiling", + [](Layout layout) { + std::vector result; + result.reserve(layout.tiles().size()); + for (auto& t : layout.tiles()) { + result.push_back(SpanToNbTuple(t.dimensions())); + } + return result; + }) + .def("__eq__", [](const Layout& layout, + const Layout& other) { return layout == other; }) + .def("__ne__", [](const Layout& layout, + const Layout& other) { return layout != other; }) + .def("__str__", &Layout::ToString) + .def("__hash__", + [](const Layout& layout) { return absl::HashOf(layout); }) + .def("to_string", &Layout::ToString) + .def("__getstate__", + [](const Layout& self) -> nb::tuple { + auto proto = self.ToProto(); + std::string result; + if (!tsl::SerializeToStringDeterministic(proto, &result)) { + // throw converted by PyBind to a Python RuntimeError. 
+ throw XlaRuntimeError( + absl::StrCat("Layout.py_pickle: ", + "SerializeToStringDeterministic failed")); + } + return nb::make_tuple(nb::bytes(result.data(), result.size())); + }) + .def("__setstate__", [](Layout* self, nb::tuple t) { + LayoutProto result; + nb::bytes serialized = nb::cast(t[0]); + result.ParseFromArray(serialized.c_str(), serialized.size()); + new (self) Layout(ValueOrThrow(Layout::FromProto(result))); + }); + + nb::class_ shape_class(m, "Shape"); + shape_class + .def("__init__", + [](Shape* self, const std::string& s) { + new (self) Shape(ValueOrThrow(ParseShape(s))); + }) + .def_static( + "tuple_shape", + [](std::vector shapes) -> Shape { + return ShapeUtil::MakeTupleShape(shapes); + }, + "Constructs a tuple shape.") + .def_static("array_shape", + xla::ValueOrThrowWrapper( + [](PrimitiveType type, nb::sequence dims_seq, + std::optional layout_seq, + std::optional> dynamic_dimensions) + -> absl::StatusOr { + std::vector dims = + SequenceToVector(dims_seq); + if (layout_seq) { + std::vector layout = + SequenceToVector(*layout_seq); + return MakeShapeWithDenseLayout(type, dims, layout, + dynamic_dimensions); + } else { + return MakeShapeWithDenseLayout( + type, dims, std::nullopt, dynamic_dimensions); + } + }), + "Constructs an array shape.", nb::arg("type"), + nb::arg("dims"), nb::arg("layout").none() = std::nullopt, + nb::arg("dynamic_dimensions").none() = std::nullopt) + .def_static( + "array_shape", + xla::ValueOrThrowWrapper( + [](nb_dtype dtype, nb::sequence dims_seq, + std::optional layout_seq, + std::optional> dynamic_dimensions) + -> absl::StatusOr { + PrimitiveType type = ValueOrThrow(DtypeToPrimitiveType(dtype)); + std::vector dims = SequenceToVector(dims_seq); + if (layout_seq) { + std::vector layout = + SequenceToVector(*layout_seq); + return MakeShapeWithDenseLayout(type, dims, layout, + dynamic_dimensions); + } else { + return MakeShapeWithDenseLayout(type, dims, std::nullopt, + dynamic_dimensions); + } + }), + "Constructs an array shape.", nb::arg("type"), nb::arg("dims"), + nb::arg("layout").none() = std::nullopt, + nb::arg("dynamic_dimensions").none() = std::nullopt) + .def_static("token_shape", []() { return ShapeUtil::MakeTokenShape(); }) + .def_static( + "scalar_shape", + [](PrimitiveType type) -> Shape { + return ShapeUtil::MakeScalarShape(type); + }, + "Constructs a scalar shape.", nb::arg("type")) + .def_static( + "scalar_shape", + [](nb_dtype dtype) -> Shape { + PrimitiveType type = xla::ValueOrThrow(DtypeToPrimitiveType(dtype)); + return ShapeUtil::MakeScalarShape(type); + }, + "Constructs a scalar shape.", nb::arg("type")) + .def("dimensions", + [](const Shape& shape) -> nb::tuple { + return SpanToNbTuple(shape.dimensions()); + }) + .def("layout", + [](const Shape& shape) -> Layout { return shape.layout(); }) + .def("xla_element_type", &Shape::element_type) + .def("element_type", + [](const Shape& shape) { + return xla::ValueOrThrow( + PrimitiveTypeToNbDtype(shape.element_type())); + }) + .def("numpy_dtype", + [](const Shape& shape) { + if (shape.IsTuple()) { + return nb_dtype("O"); + } + return xla::ValueOrThrow( + PrimitiveTypeToNbDtype(shape.element_type())); + }) + .def("is_tuple", &Shape::IsTuple) + .def("is_array", &Shape::IsArray) + .def("is_token", &Shape::IsToken) + .def("is_static", &Shape::is_static) + .def("is_dynamic", &Shape::is_dynamic) + .def("is_dynamic_dimension", &Shape::is_dynamic_dimension, + nb::arg("dimension")) + .def("set_dynamic_dimension", &Shape::set_dynamic_dimension, + nb::arg("dimension"), nb::arg("is_dynamic")) + 
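The Shape bindings above (the chain continues below with rank(), tuple_shapes(), and friends) give Python a thin, non-crashing wrapper over ShapeUtil. A rough sketch of that surface, with `xc` once more standing in for the extension module:

    import numpy as np

    s = xc.Shape.array_shape(np.dtype("float32"), (3, 4), layout=(0, 1))
    assert s.dimensions() == (3, 4)
    assert s.layout().minor_to_major() == (0, 1)

    t = xc.Shape.tuple_shape([s, xc.Shape.scalar_shape(np.dtype("int32"))])
    assert t.is_tuple() and t.leaf_count() == 2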
.def("rank", &Shape::dimensions_size) + .def("to_serialized_proto", + [](const Shape& shape) { + ShapeProto proto = shape.ToProto(); + std::string s = proto.SerializeAsString(); + return nb::bytes(s.data(), s.size()); + }) + .def("tuple_shapes", + [](const Shape& shape) { + return std::vector(shape.tuple_shapes()); + }) + .def("leaf_count", + [](const Shape& shape) { return ShapeUtil::GetLeafCount(shape); }) + .def( + "with_major_to_minor_layout_if_absent", + [](const Shape& shape) { + Shape out = shape; + ShapeUtil::ForEachMutableSubshape( + &out, [](Shape* subshape, const ShapeIndex&) { + if (!subshape->has_layout()) { + LayoutUtil::SetToDefaultLayout(subshape); + } + }); + return out; + }, + "Returns a copy of a shape with missing layouts set to " + "major-to-minor.") + .def("__eq__", [](const Shape& shape, + const Shape& other) { return shape == other; }) + .def("__ne__", [](const Shape& shape, + const Shape& other) { return shape != other; }) + .def("__hash__", [](const Shape& shape) { return absl::HashOf(shape); }) + .def("__repr__", [](const Shape& shape) { + return shape.ToString(/*print_layout=*/true); + }); + + nb::class_(m, "ProgramShape") + .def( + "__init__", + [](ProgramShape* self, absl::Span params, Shape result) { + new (self) ProgramShape(); + for (const Shape& param : params) { + self->AddParameter(param, ""); + } + *self->mutable_result() = result; + }) + .def("parameter_shapes", + static_cast& (ProgramShape::*)() const>( + &ProgramShape::parameters)) + .def("result_shape", &ProgramShape::result) + .def("__repr__", &ProgramShape::ToString); + + // Literals + nb::class_(m, "Literal") + .def(nb::init()) + .def("__repr__", &Literal::ToString) + .def( + "__array__", + [](std::shared_ptr obj, std::optional dtype, + std::optional copy) { + // Provides the interface required by numpy to create a np.ndarray. + // Currently don't support the __dl_pack__ interface but can be + // added with very little effort it if needed. 
+ + nb::ndarray np_array(LiteralToNdarray(*obj)); + + if (dtype.has_value()) { + throw XlaRuntimeError( + "Passing of dtype to __array__ not currently supported."); + } + + if (copy.has_value() && *copy) { + // when a copy is requested we _must_ return a copy: + // https://numpy.org/doc/2.1/reference/generated/numpy.ndarray.__array__.html + return np_array.cast(nb::rv_policy::copy); + } + + return np_array.cast(nb::rv_policy::reference_internal, + nb::cast(obj)); + }, + nb::arg("dtype").none() = nb::none(), + nb::arg("copy").none() = nb::none()) + .def("shape", &Literal::shape); + + nb::class_(m, "XlaComputation") + .def("__init__", + [](XlaComputation* self, + const nb::bytes& serialized_hlo_module_proto) { + HloModuleProto proto; + proto.ParseFromArray(serialized_hlo_module_proto.c_str(), + serialized_hlo_module_proto.size()); + new (self) XlaComputation(proto); + }) + .def("get_hlo_module", xla::ValueOrThrowWrapper(GetHloModule)) + .def("program_shape", + xla::ValueOrThrowWrapper(&XlaComputation::GetProgramShape)) + .def("name", &XlaComputation::name) + .def("as_serialized_hlo_module_proto", + xla::ValueOrThrowWrapper(GetComputationSerializedProto)) + .def("as_hlo_text", xla::ValueOrThrowWrapper(GetComputationHloText), + nb::arg("print_large_constants") = false) + .def("as_hlo_dot_graph", + xla::ValueOrThrowWrapper(GetComputationHloDotGraph)) + .def("hash", xla::ValueOrThrowWrapper(HashComputation)) + .def("as_hlo_module", xla::ValueOrThrowWrapper(GetHloModule)); + + nb::class_ hlo_print_options_class(m, "HloPrintOptions"); + hlo_print_options_class.def(nb::init<>()) + .def_static("short_parsable", &HloPrintOptions::ShortParsable) + .def_static("canonical", &HloPrintOptions::Canonical) + .def_static("fingerprint", &HloPrintOptions::Fingerprint) + .def_prop_rw("print_large_constants", + &HloPrintOptions::print_large_constants, + &HloPrintOptions::set_print_large_constants) + .def_prop_rw("print_metadata", &HloPrintOptions::print_metadata, + &HloPrintOptions::set_print_metadata) + .def_prop_rw("print_backend_config", + &HloPrintOptions::print_backend_config, + &HloPrintOptions::set_print_backend_config) + .def_prop_rw("print_result_shape", &HloPrintOptions::print_result_shape, + &HloPrintOptions::set_print_result_shape) + .def_prop_rw("print_operand_shape", &HloPrintOptions::print_operand_shape, + &HloPrintOptions::set_print_operand_shape) + .def_prop_rw("print_operand_names", &HloPrintOptions::print_operand_names, + &HloPrintOptions::set_print_operand_names) + .def_prop_rw("print_ids", &HloPrintOptions::print_ids, + &HloPrintOptions::set_print_ids) + .def_prop_rw("print_extra_attributes", + &HloPrintOptions::print_extra_attributes, + &HloPrintOptions::set_print_extra_attributes) + .def_prop_rw("print_program_shape", &HloPrintOptions::print_program_shape, + &HloPrintOptions::set_print_program_shape) + .def_prop_rw("print_percent", &HloPrintOptions::print_percent, + &HloPrintOptions::set_print_percent) + .def_prop_rw("print_control_dependencies", + &HloPrintOptions::print_control_dependencies, + &HloPrintOptions::set_print_control_dependencies) + .def_prop_rw("compact_operands", &HloPrintOptions::compact_operands, + &HloPrintOptions::set_compact_operands) + .def_prop_rw("include_layout_in_shapes", + &HloPrintOptions::include_layout_in_shapes, + &HloPrintOptions::set_include_layout_in_shapes) + .def_prop_rw("canonicalize_instruction_names", + &HloPrintOptions::canonicalize_instruction_names, + &HloPrintOptions::set_canonicalize_instruction_names) + .def_prop_rw("canonicalize_computations", 
+ &HloPrintOptions::canonicalize_computations, + &HloPrintOptions::set_canonicalize_computations) + .def_prop_rw("indent_amount", &HloPrintOptions::indent_amount, + &HloPrintOptions::set_indent_amount) + .def_prop_rw("is_in_nested_computation", + &HloPrintOptions::is_in_nested_computation, + &HloPrintOptions::set_is_in_nested_computation); + + // HloModule.computations() returns raw pointers. + // pybind seems to prefer smart pointers. + // We give pybind a smart pointer to a wrapper around a raw pointer to satisfy + // pybind and avoid double frees. + class ComputationWrapper { + public: + ComputationWrapper(const HloComputation* comp, + const std::shared_ptr module) + : comp_(comp), module_(module) {} + absl::string_view name() const { return comp_->name(); } + void render_html(const std::string& filename) { + std::string html = xla::ValueOrThrow(RenderGraph( + *comp_, /*label=*/"", comp_->parent()->config().debug_options(), + RenderedGraphFormat::kHtml, HloRenderOptions())); + xla::ThrowIfError(tsl::WriteStringToFile( + tsl::Env::Default(), absl::StrCat(filename, ".html"), html)); + } + + private: + const HloComputation* comp_; + // The module owns the computations: if its destructor is called, the + // computations are freed. To prevent that from happening in cases where the + // module Python object goes out of scope and gets garbage collected before + // the computations, we keep a shared_ptr to the module that originated the + // computation. + const std::shared_ptr module_; + }; + + nb::class_ hlo_computation_class(m, "HloComputation"); + + hlo_computation_class.def_prop_ro("name", &ComputationWrapper::name) + .def("render_html", &ComputationWrapper::render_html); + + nb::class_ hlo_module_class(m, "HloModule"); + hlo_module_class.def_prop_ro("name", &HloModule::name) + .def( + "to_string", + static_cast( + &HloModule::ToString), + nb::arg("options") = HloPrintOptions()) + .def("as_serialized_hlo_module_proto", + xla::ValueOrThrowWrapper(GetHloModuleSerializedProto)) + .def("from_serialized_hlo_module_proto", + xla::ValueOrThrowWrapper(HloModuleFromSerializedProto)) + .def("computations", + [](const std::shared_ptr m) + -> std::vector> { + std::vector> computations; + for (HloComputation* comp : m->computations()) + computations.push_back( + std::make_shared(comp, m)); + return computations; + }) + .def_prop_ro("spmd_output_sharding", + [](const HloModule& m) -> std::optional { + if (!m.has_spmd_output_sharding()) return std::nullopt; + return m.spmd_output_sharding().ToProto(); + }) + .def_prop_ro("spmd_parameters_shardings", + [](const HloModule& m) + -> std::optional> { + if (!m.has_spmd_parameters_shardings()) + return std::nullopt; + std::vector param_shardings; + for (const auto& parameter_sharding : + m.spmd_parameters_shardings()) { + param_shardings.push_back(parameter_sharding.ToProto()); + } + return param_shardings; + }); + + m.def("hlo_module_to_dot_graph", + [](const HloModule& hlo_module) -> std::string { + return xla::ValueOrThrow(RenderGraph( + *hlo_module.entry_computation(), /*label=*/"", + hlo_module.config().debug_options(), RenderedGraphFormat::kDot)); + }); + m.def( + "hlo_module_cost_analysis", + xla::ValueOrThrowWrapper([](PyClient* client, const HloModule& module) + -> absl::StatusOr { + TF_ASSIGN_OR_RETURN(auto analysis, + client->pjrt_client()->GetHloCostAnalysis()); + TF_RETURN_IF_ERROR(module.entry_computation()->Accept(analysis.get())); + + // Convert from HloCostAnalysis::Properties to a standard map. 
+ nb::dict ret; + analysis->properties().ForEach([&](absl::string_view key, float val) { + ret[nb::str(key.data(), key.size())] = nb::cast(val); + }); + return ret; + })); + m.def("hlo_module_from_text", + xla::ValueOrThrowWrapper( + [](const std::string& hlo_module_text) + -> absl::StatusOr> { + auto hlo_module = + xla::ParseAndReturnUnverifiedModule(hlo_module_text); + TF_RETURN_IF_ERROR(hlo_module.status()); + std::shared_ptr result(std::move(*hlo_module)); + return result; + })); + + // Device assignments + nb::class_(m, "DeviceAssignment") + .def_static( + "create", + xla::ValueOrThrowWrapper([](nb::ndarray> array) + -> absl::StatusOr { + if (array.ndim() != 2) { + return InvalidArgument( + "Argument to DeviceAssignment constructor must be a " + "2D array, received an %dD array.", + array.ndim()); + } + DeviceAssignment result(array.shape(0), array.shape(1)); + for (int i = 0; i < array.shape(0); ++i) { + for (int j = 0; j < array.shape(1); ++j) { + result(i, j) = array(i, j); + } + } + return result; + })) + .def("replica_count", &DeviceAssignment::replica_count) + .def("computation_count", &DeviceAssignment::computation_count) + .def("__repr__", &DeviceAssignment::ToString) + .def("serialize", + xla::ValueOrThrowWrapper( + [](const DeviceAssignment& da) -> absl::StatusOr { + DeviceAssignmentProto proto; + da.Serialize(&proto); + std::string result; + if (!tsl::SerializeToStringDeterministic(proto, &result)) { + return Unknown( + "Failed to serialize the DeviceAssignmentProto."); + } + return nb::bytes(result.data(), result.size()); + })); + + nb::class_ compile_options(m, "CompileOptions"); + compile_options + .def("__init__", + [](CompileOptions* self) { + new (self) CompileOptions(); + DebugOptions* debug_options = + self->executable_build_options.mutable_debug_options(); + // Sets fast-math-disabling default options expected by JAX. + debug_options->set_xla_cpu_enable_fast_min_max(false); + debug_options->set_xla_gpu_enable_fast_min_max(false); + }) + .def("__getstate__", + [](const CompileOptions& self) -> nb::tuple { + auto proto = ValueOrThrow(self.ToProto()); + std::string result; + if (!tsl::SerializeToStringDeterministic(proto, &result)) { + // throw converted by PyBind to a Python RuntimeError. + throw XlaRuntimeError( + absl::StrCat("CompileOptions.py_pickle: ", + "SerializeToStringDeterministic failed")); + } + return nb::make_tuple(nb::bytes(result.data(), result.size())); + }) + .def("__setstate__", + [](CompileOptions* self, nb::tuple t) { + CompileOptionsProto result; + nb::bytes serialized = nb::cast(t[0]); + result.ParseFromArray(serialized.c_str(), serialized.size()); + new (self) CompileOptions( + ValueOrThrow(CompileOptions::FromProto(result))); + }) + .def("SerializeAsString", + [](const CompileOptions& self) -> nb::bytes { + auto proto = ValueOrThrow(self.ToProto()); + std::string result; + if (!tsl::SerializeToStringDeterministic(proto, &result)) { + // throw converted by PyBind to a Python RuntimeError. 
+ throw XlaRuntimeError( + absl::StrCat("CompileOptions.SerializeAsString: ", + "SerializeToStringDeterministic failed")); + } + return nb::bytes(result.data(), result.size()); + }) + .def_static("ParseFromString", + [](nb::bytes s) { + CompileOptionsProto result; + result.ParseFromArray(s.c_str(), s.size()); + return ValueOrThrow(CompileOptions::FromProto(result)); + }) + .def_rw("argument_layouts", &CompileOptions::argument_layouts) + .def_rw("parameter_is_tupled_arguments", + &CompileOptions::parameter_is_tupled_arguments) + .def_rw("compile_portable_executable", + &CompileOptions::compile_portable_executable) + .def_ro("executable_build_options", + &CompileOptions::executable_build_options) + .def_rw("env_option_overrides", &CompileOptions::env_option_overrides) + // TODO(phawkins): the following fields exist for backward compatibility. + // Remove them after JAX has been updated not to use them. + .def_rw("tuple_arguments", &CompileOptions::parameter_is_tupled_arguments) + .def_prop_rw( + "num_replicas", + [](const CompileOptions& options) { + return options.executable_build_options.num_replicas(); + }, + [](CompileOptions& options, int num_replicas) { + options.executable_build_options.set_num_replicas(num_replicas); + }) + .def_prop_rw( + "num_partitions", + [](const CompileOptions& options) { + return options.executable_build_options.num_partitions(); + }, + [](CompileOptions& options, int num_partitions) { + options.executable_build_options.set_num_partitions(num_partitions); + }) + .def_prop_rw( + "profile_version", + [](const CompileOptions& options) { return options.profile_version; }, + [](CompileOptions& options, int64_t profile_version) { + options.profile_version = profile_version; + }) + .def_prop_rw( + "device_assignment", + [](const CompileOptions& options) -> std::optional { + return options.executable_build_options.has_device_assignment() + ? std::optional( + options.executable_build_options + .device_assignment()) + : std::nullopt; + }, + [](CompileOptions& options, + const DeviceAssignment& device_assignment) { + options.executable_build_options.set_device_assignment( + device_assignment); + }); + + // Custom-call targets. 
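The registration bindings defined just below are typically driven from Python as in the following sketch. `my_ffi_lib.registrations()` is a hypothetical helper returning PyCapsules that wrap XLA_FFI_Handler pointers, and the `xc` alias and platform strings are likewise illustrative.

    # api_version=1: a single capsule, or a dict of capsules keyed by stage
    # ("instantiate", "prepare", "initialize", "execute").
    for name, bundle in my_ffi_lib.registrations().items():
        xc.register_custom_call_target(name, bundle, platform="CUDA", api_version=1)

    # api_version=0 (legacy, untyped void* API) takes a single capsule and no traits:
    # xc.register_custom_call_target(b"legacy_op", legacy_capsule, platform="Host")

    # Inspect what ended up registered for a platform:
    print(sorted(xc.custom_call_targets("CUDA")))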
+ m.def( + "register_custom_call_target", + [](nb::object fn_name_py, nb::object fn, const std::string& platform, + int api_version, XLA_FFI_Handler_Traits traits) { + std::string fn_name; + if (!nb::try_cast(fn_name_py, fn_name)) { + nb::bytes bytes = nb::cast(fn_name_py); + fn_name = std::string(bytes.c_str(), bytes.size()); + } + xla::ThrowIfError(PyRegisterCustomCallTarget( + fn_name, std::move(fn), platform, api_version, traits)); + }, + nb::arg("fn_name"), nb::arg("fn"), nb::arg("platform"), + nb::arg("api_version") = 0, nb::arg("traits") = 0); + + m.def( + "custom_call_targets", + [](const std::string& platform) -> nb::dict { + nb::dict targets; + for (const auto& [name, target] : + CustomCallTargetRegistry::Global()->registered_symbols(platform)) { + targets[nb::str(name.data(), name.size())] = nb::capsule(target); + } + + auto ffi_handlers = ffi::StaticRegisteredHandlers(platform); + if (!ffi_handlers.ok()) return targets; + + for (const auto& [name, registration] : *ffi_handlers) { + nb::dict bundle; + auto export_handler = [&](absl::string_view name, + XLA_FFI_Handler* h) { + if (h != nullptr) { + bundle[nb::str(name.data(), name.size())] = + nb::capsule(reinterpret_cast(h)); + } + }; + export_handler("prepare", registration.bundle.prepare); + export_handler("initialize", registration.bundle.initialize); + export_handler("execute", registration.bundle.execute); + targets[nb::str(name.data(), name.size())] = std::move(bundle); + } + return targets; + }, + nb::arg("platform")); + + nb::enum_(m, "AutotuneCacheMode") + .value("UNSPECIFIED", DebugOptions::AUTOTUNE_CACHE_MODE_UNSPECIFIED) + .value("UPDATE", DebugOptions::AUTOTUNE_CACHE_MODE_UPDATE) + .value("READ", DebugOptions::AUTOTUNE_CACHE_MODE_READ); + + m.def( + "register_custom_type_id", + [](absl::string_view type_name, nb::object type_id) { + xla::ThrowIfError(PyRegisterCustomTypeId(type_name, type_id)); + }, + nb::arg("type_name"), nb::arg("type_id")); + + nb::class_(m, "DebugOptions") + .def("__repr__", &DebugOptions::DebugString) + .def_prop_rw("xla_backend_optimization_level", + &DebugOptions::xla_backend_optimization_level, + &DebugOptions::set_xla_backend_optimization_level) + .def_prop_rw("xla_cpu_enable_fast_math", + &DebugOptions::xla_cpu_enable_fast_math, + &DebugOptions::set_xla_cpu_enable_fast_math) + .def_prop_rw("xla_cpu_enable_xprof_traceme", + &DebugOptions::xla_cpu_enable_xprof_traceme, + &DebugOptions::set_xla_cpu_enable_xprof_traceme) + .def_prop_rw("xla_cpu_fast_math_honor_infs", + &DebugOptions::xla_cpu_fast_math_honor_infs, + &DebugOptions::set_xla_cpu_fast_math_honor_infs) + .def_prop_rw("xla_cpu_fast_math_honor_nans", + &DebugOptions::xla_cpu_fast_math_honor_nans, + &DebugOptions::set_xla_cpu_fast_math_honor_nans) + .def_prop_rw("xla_cpu_fast_math_honor_division", + &DebugOptions::xla_cpu_fast_math_honor_division, + &DebugOptions::set_xla_cpu_fast_math_honor_division) + .def_prop_rw("xla_cpu_fast_math_honor_functions", + &DebugOptions::xla_cpu_fast_math_honor_functions, + &DebugOptions::set_xla_cpu_fast_math_honor_functions) + .def_prop_rw("xla_detailed_logging", &DebugOptions::xla_detailed_logging, + &DebugOptions::set_xla_detailed_logging) + .def_prop_rw("xla_enable_dumping", &DebugOptions::xla_enable_dumping, + &DebugOptions::set_xla_enable_dumping) + .def_prop_rw("xla_gpu_enable_fast_min_max", + &DebugOptions::xla_gpu_enable_fast_min_max, + &DebugOptions::set_xla_gpu_enable_fast_min_max) + .def_prop_rw("xla_gpu_dump_autotune_results_to", + &DebugOptions::xla_gpu_dump_autotune_results_to, + 
[](DebugOptions* self, std::string value) { + self->set_xla_gpu_dump_autotune_results_to(value); + }) + .def_prop_rw("xla_gpu_load_autotune_results_from", + &DebugOptions::xla_gpu_load_autotune_results_from, + [](DebugOptions* self, std::string value) { + self->set_xla_gpu_load_autotune_results_from(value); + }) + .def_prop_rw("xla_gpu_cuda_data_dir", + &DebugOptions::xla_gpu_cuda_data_dir, + [](DebugOptions* self, std::string value) { + self->set_xla_gpu_cuda_data_dir(value); + }) + .def_prop_rw("xla_llvm_disable_expensive_passes", + &DebugOptions::xla_llvm_disable_expensive_passes, + &DebugOptions::set_xla_llvm_disable_expensive_passes) + .def_prop_rw( + "xla_disable_hlo_passes", + [](DebugOptions* self) { + return absl::StrJoin(self->xla_disable_hlo_passes(), ","); + }, + [](DebugOptions* self, std::string value) { + self->clear_xla_disable_hlo_passes(); + for (const auto& passname : + std::vector(absl::StrSplit(value, ','))) { + self->add_xla_disable_hlo_passes(passname); + } + }) + .def_prop_rw( + "xla_enable_hlo_passes_only", + [](DebugOptions* self) { + return absl::StrJoin(self->xla_enable_hlo_passes_only(), ","); + }, + [](DebugOptions* self, std::string value) { + self->clear_xla_enable_hlo_passes_only(); + for (const auto& passname : + std::vector(absl::StrSplit(value, ','))) { + self->add_xla_enable_hlo_passes_only(passname); + } + }) + .def_prop_rw("xla_test_all_input_layouts", + &DebugOptions::xla_test_all_input_layouts, + &DebugOptions::set_xla_test_all_input_layouts) + .def_prop_rw("xla_force_host_platform_device_count", + &DebugOptions::xla_force_host_platform_device_count, + &DebugOptions::set_xla_force_host_platform_device_count) + .def_prop_rw("xla_dump_to", &DebugOptions::xla_dump_to, + [](DebugOptions* self, std::string value) { + self->set_xla_dump_to(value); + }) + .def_prop_rw("xla_dump_hlo_module_re", + &DebugOptions::xla_dump_hlo_module_re, + [](DebugOptions* self, std::string value) { + self->set_xla_dump_hlo_module_re(value); + }) + .def_prop_rw("xla_dump_hlo_pass_re", &DebugOptions::xla_dump_hlo_pass_re, + [](DebugOptions* self, std::string value) { + self->set_xla_dump_hlo_pass_re(value); + }) + .def_prop_rw("xla_dump_hlo_as_text", &DebugOptions::xla_dump_hlo_as_text, + &DebugOptions::set_xla_dump_hlo_as_text) + .def_prop_rw("xla_dump_hlo_as_proto", + &DebugOptions::xla_dump_hlo_as_proto, + &DebugOptions::set_xla_dump_hlo_as_proto) + .def_prop_rw("xla_dump_hlo_as_dot", &DebugOptions::xla_dump_hlo_as_dot, + &DebugOptions::set_xla_dump_hlo_as_dot) + .def_prop_rw("xla_dump_hlo_as_url", &DebugOptions::xla_dump_hlo_as_url, + &DebugOptions::set_xla_dump_hlo_as_url) + .def_prop_rw("xla_dump_hlo_as_html", &DebugOptions::xla_dump_hlo_as_html, + &DebugOptions::set_xla_dump_hlo_as_html) + .def_prop_rw("xla_dump_fusion_visualization", + &DebugOptions::xla_dump_fusion_visualization, + &DebugOptions::set_xla_dump_fusion_visualization) + .def_prop_rw("xla_dump_hlo_snapshots", + &DebugOptions::xla_dump_hlo_snapshots, + &DebugOptions::set_xla_dump_hlo_snapshots) + .def_prop_rw("xla_dump_max_hlo_modules", + &DebugOptions::xla_dump_max_hlo_modules, + &DebugOptions::set_xla_dump_max_hlo_modules) + .def_prop_rw("xla_dump_module_metadata", + &DebugOptions::xla_dump_module_metadata, + &DebugOptions::set_xla_dump_module_metadata) + .def_prop_rw("xla_dump_compress_protos", + &DebugOptions::xla_dump_compress_protos, + &DebugOptions::set_xla_dump_compress_protos) + .def_prop_rw("xla_dump_hlo_as_long_text", + &DebugOptions::xla_dump_hlo_as_long_text, + 
&DebugOptions::set_xla_dump_hlo_as_long_text) + .def_prop_rw("xla_dump_disable_metadata", + &DebugOptions::xla_dump_disable_metadata, + &DebugOptions::set_xla_dump_disable_metadata) + .def_prop_rw("xla_dump_hlo_pipeline_re", + &DebugOptions::xla_dump_hlo_pipeline_re, + [](DebugOptions* self, std::string value) { + self->set_xla_dump_hlo_pipeline_re(value); + }) + .def_prop_rw("xla_gpu_dump_autotune_logs_to", + &DebugOptions::xla_gpu_dump_autotune_logs_to, + [](DebugOptions* self, std::string value) { + self->set_xla_gpu_dump_autotune_logs_to(value); + }) + .def_prop_rw("xla_gpu_kernel_cache_file", + &DebugOptions::xla_gpu_kernel_cache_file, + [](DebugOptions* self, std::string value) { + self->set_xla_gpu_kernel_cache_file(value); + }) + .def_prop_rw( + "xla_gpu_enable_llvm_module_compilation_parallelism", + &DebugOptions::xla_gpu_enable_llvm_module_compilation_parallelism, + &DebugOptions::set_xla_gpu_enable_llvm_module_compilation_parallelism) + .def_prop_rw("xla_gpu_per_fusion_autotune_cache_dir", + &DebugOptions::xla_gpu_per_fusion_autotune_cache_dir, + [](DebugOptions* self, std::string value) { + self->set_xla_gpu_per_fusion_autotune_cache_dir(value); + }) + .def_prop_rw("xla_gpu_experimental_autotune_cache_mode", + &DebugOptions::xla_gpu_experimental_autotune_cache_mode, + &DebugOptions::set_xla_gpu_experimental_autotune_cache_mode); + + nb::class_(m, "ExecutableBuildOptions") + .def(nb::init<>()) + .def("__repr__", &ExecutableBuildOptions::ToString) + .def_prop_rw( + "fdo_profile", + [](const ExecutableBuildOptions& options) { + return nb::bytes(options.fdo_profile().data(), + options.fdo_profile().size()); + }, + [](ExecutableBuildOptions& options, nb::bytes fdo_profile) { + options.set_fdo_profile( + std::string(fdo_profile.c_str(), fdo_profile.size())); + }) + .def_prop_rw( + "result_layout", + [](const ExecutableBuildOptions& options) -> std::optional { + return options.result_layout() + ? std::optional(*options.result_layout()) + : std::nullopt; + }, + &ExecutableBuildOptions::set_result_layout) + .def_prop_rw("num_replicas", &ExecutableBuildOptions::num_replicas, + &ExecutableBuildOptions::set_num_replicas) + .def_prop_rw("num_partitions", &ExecutableBuildOptions::num_partitions, + &ExecutableBuildOptions::set_num_partitions) + .def_prop_ro("debug_options", + &ExecutableBuildOptions::mutable_debug_options, + nb::rv_policy::reference, nb::keep_alive<1, 0>()) + .def_prop_rw( + "device_assignment", + [](const ExecutableBuildOptions& options) + -> std::optional { + return options.has_device_assignment() + ? 
std::optional( + options.device_assignment()) + : std::nullopt; + }, + &ExecutableBuildOptions::set_device_assignment) + .def("compilation_environments_from_serialized_proto", + [](ExecutableBuildOptions& options, + const nb::bytes& serialized_proto) { + xla::CompilationEnvironmentsProto env_proto; + env_proto.ParseFromArray(serialized_proto.c_str(), + serialized_proto.size()); + auto comp_envs = xla::ValueOrThrow( + xla::CompilationEnvironments::CreateFromProto(env_proto)); + *options.mutable_comp_envs() = std::move(*comp_envs); + }) + .def_prop_rw("exec_time_optimization_effort", + &ExecutableBuildOptions::exec_time_optimization_effort, + &ExecutableBuildOptions::set_exec_time_optimization_effort) + .def_prop_rw("memory_fitting_effort", + &ExecutableBuildOptions::memory_fitting_effort, + &ExecutableBuildOptions::set_memory_fitting_effort) + .def_prop_rw( + "optimization_level", &ExecutableBuildOptions::optimization_level, + [](ExecutableBuildOptions& options, int value) { + options.set_optimization_level( + static_cast(value)); + }) + .def_prop_rw( + "memory_fitting_level", &ExecutableBuildOptions::memory_fitting_level, + [](ExecutableBuildOptions& options, int value) { + options.set_memory_fitting_level( + static_cast(value)); + }) + .def_prop_rw("use_spmd_partitioning", + &ExecutableBuildOptions::use_spmd_partitioning, + &ExecutableBuildOptions::set_use_spmd_partitioning) + .def_prop_rw("use_auto_spmd_partitioning", + &ExecutableBuildOptions::use_auto_spmd_partitioning, + &ExecutableBuildOptions::set_use_auto_spmd_partitioning) + .def_prop_rw( + "auto_spmd_partitioning_mesh_shape", + &ExecutableBuildOptions::auto_spmd_partitioning_mesh_shape, + &ExecutableBuildOptions::set_auto_spmd_partitioning_mesh_shape) + .def_prop_rw("auto_spmd_partitioning_mesh_ids", + &ExecutableBuildOptions::auto_spmd_partitioning_mesh_ids, + &ExecutableBuildOptions::set_auto_spmd_partitioning_mesh_ids) + .def_prop_rw( + "allow_spmd_sharding_propagation_to_parameters", + [](const ExecutableBuildOptions& options) -> std::vector { + return std::vector( + options.allow_spmd_sharding_propagation_to_parameters().begin(), + options.allow_spmd_sharding_propagation_to_parameters().end()); + }, + [](ExecutableBuildOptions& options, std::vector values) { + absl::InlinedVector v(values.begin(), values.end()); + options.set_allow_spmd_sharding_propagation_to_parameters(v); + }) + .def_prop_rw( + "allow_spmd_sharding_propagation_to_output", + [](const ExecutableBuildOptions& options) -> std::vector { + return std::vector( + options.allow_spmd_sharding_propagation_to_output().begin(), + options.allow_spmd_sharding_propagation_to_output().end()); + }, + [](ExecutableBuildOptions& options, std::vector values) { + absl::InlinedVector v(values.begin(), values.end()); + options.set_allow_spmd_sharding_propagation_to_output(v); + }) + .def_prop_rw("use_shardy_partitioner", + &ExecutableBuildOptions::use_shardy_partitioner, + &ExecutableBuildOptions::set_use_shardy_partitioner); + + nb::enum_ op_sharding_type(m, "OpSharding_Type", + nb::is_arithmetic()); + op_sharding_type.value("REPLICATED", OpSharding::REPLICATED) + .value("MAXIMAL", OpSharding::MAXIMAL) + .value("MANUAL", OpSharding::MANUAL) + .value("TUPLE", OpSharding::TUPLE) + .value("OTHER", OpSharding::OTHER) + .value("UNKNOWN", OpSharding::UNKNOWN); + + nb::enum_ op_sharding_shard_group_type( + m, "OpSharding_ShardGroupType"); + op_sharding_shard_group_type.value("AS", OpSharding::AS) + .value("LIKE", OpSharding::LIKE); + + nb::class_ op_sharding(m, "OpSharding"); + 
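The OpSharding binding that follows exposes the proto's repeated fields as plain Python list attributes via DefRepeatedProperty, so a sharding proto can be assembled field by field and then promoted to an HloSharding. A rough sketch, with `xc` as a placeholder:

    op = xc.OpSharding()
    op.type = xc.OpSharding.Type.OTHER
    op.tile_assignment_dimensions = [2, 4, 2]
    # Either list the device ids explicitly...
    op.tile_assignment_devices = list(range(16))
    # ...or use the compact iota form instead:
    #   op.iota_reshape_dims = [4, 2, 2]
    #   op.iota_transpose_perm = [1, 0, 2]

    sharding = xc.HloSharding.from_proto(op)
    assert sharding.is_tiled()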
op_sharding + .def_prop_ro_static( + "Type", + [op_sharding_type](const nb::object&) { return op_sharding_type; }) + .def_prop_ro_static("ShardGroupType", + [op_sharding_shard_group_type](const nb::object&) { + return op_sharding_shard_group_type; + }) + .def(nb::init<>()) + .def("__getstate__", + [](const OpSharding& self) { + std::string serialized = self.SerializeAsString(); + return nb::make_tuple( + nb::bytes(serialized.data(), serialized.size())); + }) + .def("__setstate__", + [](OpSharding* self, nb::tuple t) { + new (self) OpSharding(); + nb::bytes serialized = nb::cast(t[0]); + self->ParseFromArray(serialized.c_str(), serialized.size()); + }) + .def_prop_rw("type", &xla::OpSharding::type, &xla::OpSharding::set_type) + .def_prop_rw("replicate_on_last_tile_dim", + &xla::OpSharding::replicate_on_last_tile_dim, + &xla::OpSharding::set_replicate_on_last_tile_dim) + .def_prop_rw("is_shard_group", &xla::OpSharding::is_shard_group, + &xla::OpSharding::set_is_shard_group) + .def_prop_rw("shard_group_id", &xla::OpSharding::shard_group_id, + &xla::OpSharding::set_shard_group_id) + .def_prop_rw("shard_group_type", &xla::OpSharding::shard_group_type, + &xla::OpSharding::set_shard_group_type) + .def("__repr__", + [](const xla::OpSharding& self) { return self.DebugString(); }) + .def("ParseFromString", + [](OpSharding& sharding, const nb::bytes& s) { + sharding.ParseFromArray(s.c_str(), s.size()); + }) + .def("SerializeToString", + [](const OpSharding& sharding) { + std::string serialized = sharding.SerializeAsString(); + return nb::bytes(serialized.data(), serialized.size()); + }) + .def("clone", + [](const OpSharding& sharding) { return OpSharding(sharding); }); + DefRepeatedProperty(op_sharding, "tile_assignment_dimensions", + &xla::OpSharding::mutable_tile_assignment_dimensions); + DefRepeatedProperty(op_sharding, "tile_assignment_devices", + &xla::OpSharding::mutable_tile_assignment_devices); + DefRepeatedProperty(op_sharding, "iota_reshape_dims", + &xla::OpSharding::mutable_iota_reshape_dims); + DefRepeatedProperty(op_sharding, "iota_transpose_perm", + &xla::OpSharding::mutable_iota_transpose_perm); + DefRepeatedProperty(op_sharding, "tuple_shardings", + &xla::OpSharding::mutable_tuple_shardings); + DefRepeatedEnumProperty(op_sharding, "last_tile_dims", + &xla::OpSharding::mutable_last_tile_dims); + + nb::class_ hlo_sharding(m, "HloSharding"); + hlo_sharding + .def_static("from_proto", + xla::ValueOrThrowWrapper(xla::HloSharding::FromProto)) + .def_static("from_string", xla::ValueOrThrowWrapper(xla::ParseSharding)) + .def_static( + "tuple_sharding", + [](xla::Shape shape, + std::vector shardings) -> xla::HloSharding { + return HloSharding::Tuple(shape, shardings); + }, + "Constructs a tuple sharding.") + .def_static( + "iota_tile", xla::ValueOrThrowWrapper(IotaTileHelper), + nb::arg("dims"), + nb::arg("reshape_dims") = absl::Span(), + nb::arg("transpose_perm") = absl::Span(), + nb::arg("subgroup_types") = absl::Span()) + .def_static("manual", [] { return HloSharding::Manual(); }) + .def_static("replicate", [] { return HloSharding::Replicate(); }) + .def_static("unknown", [] { return HloSharding::Unknown(); }) + .def_static( + "subgroup_with_device_ordering", + xla::ValueOrThrowWrapper(SubgroupWithTileAssignmentHelper), + nb::arg("tile_assignment"), + nb::arg("subgroup_types") = absl::Span()) + .def("__eq__", [](const xla::HloSharding& a, + const xla::HloSharding& b) { return a == b; }) + .def("__hash__", + [](const xla::HloSharding& self) { return absl::HashOf(self); }) + 
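Once constructed, a sharding can be inspected through the accessors bound below (several take nb::lock_self() because tile_assignment().array() is backed by an internal cache). An illustrative sketch, `xc` as before:

    sharding = xc.HloSharding.iota_tile([2, 4, 2], [4, 2, 2], [1, 0, 2])
    assert sharding.is_tiled() and sharding.is_tile_assignment_iota()
    assert sharding.num_dimensions() == 3 and sharding.num_devices() == 16

    dims = sharding.tile_assignment_dimensions()  # per-dimension tile counts
    devices = sharding.tile_assignment_devices()  # flat device-id order
    proto = sharding.to_proto()                   # round-trip back to OpSharding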
.def("is_replicated", &xla::HloSharding::IsReplicated) + .def("is_manual", &xla::HloSharding::IsManual) + .def("is_unknown", &xla::HloSharding::IsUnknown) + .def("is_tiled", &xla::HloSharding::IsTiled) + .def("is_maximal", &xla::HloSharding::IsTileMaximal) + .def("tile", [](const xla::HloSharding& self, + xla::Shape shape) { return self.TileShape(shape); }) + // tile_assignment.array() is computed using an internal cache, + // which is why nb::lock_self() is required. It may be preferable to move + // this locking into the TileAssignment class if we find it to race with + // non-Python users of that class. + .def( + "tuple_elements", + [](const xla::HloSharding& self) { return self.tuple_elements(); }, + nb::lock_self()) + .def( + "num_devices", + [](const xla::HloSharding& self) { + return self.tile_assignment().num_elements(); + }, + nb::lock_self()) + .def( + "num_dimensions", + [](const xla::HloSharding& self) { + return self.tile_assignment().num_dimensions(); + }, + nb::lock_self()) + .def("is_tile_assignment_iota", + [](const xla::HloSharding& self) { + return self.tile_assignment().iota().has_value(); + }) + .def( + "tile_assignment_dimensions", + [](const xla::HloSharding& self) { + absl::Span span = + self.tile_assignment().dimensions(); + CHECK(span.data()); + return span; + }, + nb::lock_self()) + .def( + "tile_assignment_devices", + [](const xla::HloSharding& self) { + auto span = + absl::MakeConstSpan(self.tile_assignment().array().data(), + self.tile_assignment().num_elements()); + CHECK(span.data()); + return span; + }, + nb::lock_self()) + .def("replicate_on_last_tile_dim", + &xla::HloSharding::ReplicateOnLastTileDim) + .def("subgroup_types", &xla::HloSharding::subgroup_types) + .def("__repr__", + [](const xla::HloSharding& self) { return self.ToString(); }) + .def("to_proto", &xla::HloSharding::ToProto) + .def("get_axis_sizes", [](const xla::HloSharding& self) { + // If returning the SmallVector, we encounter the error "unable to + // convert function return value to a Python type!". + mlir::SmallVector mesh_shape = + xla::sdy::getAxisSizes(self.tile_assignment()); + return std::vector(mesh_shape.begin(), mesh_shape.end()); + }); +} // NOLINT(readability/fn_size) +} // namespace xla diff --git a/jaxlib/xla_compiler.h b/jaxlib/xla_compiler.h new file mode 100644 index 000000000000..261f630d1cd3 --- /dev/null +++ b/jaxlib/xla_compiler.h @@ -0,0 +1,28 @@ +/* Copyright 2020 The JAX Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+
+#ifndef JAXLIB_XLA_COMPILER_H_
+#define JAXLIB_XLA_COMPILER_H_
+
+// placeholder for index annotation headers
+#include "nanobind/nanobind.h"
+
+namespace xla {
+
+void BuildXlaCompilerSubmodule(nanobind::module_& m);
+
+}  // namespace xla
+
+#endif  // JAXLIB_XLA_COMPILER_H_
diff --git a/pyproject.toml b/pyproject.toml
index a1b9e7dd446a..ff34488124e7 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -23,21 +23,32 @@ module = [
     "jax.experimental.jax2tf.tests.back_compat_testdata",
     "jax.experimental.jax2tf.tests.flax_models",
     "jax_cuda12_plugin.*",
-    "jaxlib.*",
+    "jaxlib.cpu_feature_guard",
+    "jaxlib.cuda.*",
     "jaxlib.mlir.*",
+    "jaxlib.mosaic.dialect.gpu.*",
+    "jaxlib.mosaic.python._tpu_gen",
+    "jaxlib.triton.*",
+    "jaxlib.utils",
+    "jaxlib.version",
+    "jaxlib._jax.utils",
+    "jaxlib._pretty_printer",
     "jraph.*",
     "libtpu.*",
     "matplotlib.*",
+    "mlir.*",
+    "ml_dtypes.*",
     "nvidia.*",
     "numpy.*",
     "opt_einsum.*",
     "optax.*",
+    "portpicker.*",
     "pygments.*",
     "pytest.*",
     "rich.*",
     "scipy.*",
     "setuptools.*",
-    "tensorboard_plugin_profile.convert.*",
+    "xprof.convert.*",
     "tensorflow.*",
     "tensorflow.io.*",
     "tensorflowjs.*",
@@ -78,7 +89,7 @@ doctest_optionflags = [
   "NUMBER",
   "NORMALIZE_WHITESPACE"
 ]
-addopts = "--doctest-glob='*.rst' --ignore='examples/ffi'"
+addopts = "--doctest-glob='*.rst' --ignore='examples/ffi' --import-mode=importlib"
 [tool.ruff]
 preview = true
@@ -103,6 +114,8 @@ ignore = [
   "C901",
   # Local variable is assigned to but never used
   "F841",
+  # Class could be dataclass or namedtuple
+  "B903",
   # Raise with from clause inside except block
   "B904",
   # Zip without explicit strict parameter
diff --git a/setup.py b/setup.py
index 80f45285ba61..a8bcdee95091 100644
--- a/setup.py
+++ b/setup.py
@@ -19,11 +19,11 @@ project_name = 'jax'
-_current_jaxlib_version = '0.5.1'
+_current_jaxlib_version = '0.6.2'
 # The following should be updated after each new jaxlib release.
-_latest_jaxlib_version_on_pypi = '0.5.1'
+_latest_jaxlib_version_on_pypi = '0.6.2'
-_libtpu_version = '0.0.10.*'
+_libtpu_version = '0.0.17.*'
 def load_version_module(pkg_path):
   spec = importlib.util.spec_from_file_location(
@@ -38,6 +38,13 @@ def load_version_module(pkg_path):
 _cmdclass = _version_module._get_cmdclass(project_name)
 _minimum_jaxlib_version = _version_module._minimum_jaxlib_version
+# If this is a pre-release ("rc" wheels), append "rc0" to
+# _minimum_jaxlib_version and _current_jaxlib_version so that we are able to
+# install the rc wheels.
+if _version_module._is_prerelease():
+  _minimum_jaxlib_version += "rc0"
+  _current_jaxlib_version += "rc0"
+
 with open('README.md', encoding='utf-8') as f:
   _long_description = f.read()
@@ -50,16 +57,15 @@ def load_version_module(pkg_path):
     long_description_content_type='text/markdown',
     author='JAX team',
     author_email='jax-dev@google.com',
-    packages=find_packages(exclude=["*examples*", "*internal_test_util*"]),
+    packages=find_packages(exclude=["examples"]),
     package_data={'jax': ['py.typed', "*.pyi", "**/*.pyi"]},
-    python_requires='>=3.10',
+    python_requires='>=3.11',
     install_requires=[
         f'jaxlib >={_minimum_jaxlib_version}, <={_jax_version}',
-        'ml_dtypes>=0.4.0',
-        'numpy>=1.25',
-        "numpy>=1.26.0; python_version>='3.12'",
+        'ml_dtypes>=0.5.0',
+        'numpy>=1.26',
         'opt_einsum',
-        'scipy>=1.11.1',
+        'scipy>=1.12',
     ],
     extras_require={
         # Minimum jaxlib version; used in testing.
@@ -81,32 +87,25 @@ def load_version_module(pkg_path): ], 'cuda': [ - f"jaxlib=={_current_jaxlib_version}", - f"jax-cuda12-plugin[with_cuda]>={_current_jaxlib_version},<={_jax_version}", + f"jaxlib>={_current_jaxlib_version},<={_jax_version}", + f"jax-cuda12-plugin[with-cuda]>={_current_jaxlib_version},<={_jax_version}", ], 'cuda12': [ - f"jaxlib=={_current_jaxlib_version}", - f"jax-cuda12-plugin[with_cuda]>={_current_jaxlib_version},<={_jax_version}", - ], - - # Deprecated alias for cuda12, kept to avoid breaking users who wrote - # cuda12_pip in their CI. - 'cuda12_pip': [ - f"jaxlib=={_current_jaxlib_version}", - f"jax-cuda12-plugin[with_cuda]>={_current_jaxlib_version},<={_jax_version}", + f"jaxlib>={_current_jaxlib_version},<={_jax_version}", + f"jax-cuda12-plugin[with-cuda]>={_current_jaxlib_version},<={_jax_version}", ], # Target that does not depend on the CUDA pip wheels, for those who want # to use a preinstalled CUDA. - 'cuda12_local': [ - f"jaxlib=={_current_jaxlib_version}", - f"jax-cuda12-plugin=={_current_jaxlib_version}", + 'cuda12-local': [ + f"jaxlib>={_current_jaxlib_version},<={_jax_version}", + f"jax-cuda12-plugin>={_current_jaxlib_version},<={_jax_version}", ], # ROCm support for ROCm 6.0 and above. 'rocm': [ - f"jaxlib=={_current_jaxlib_version}", + f"jaxlib>={_current_jaxlib_version},<={_jax_version}", f"jax-rocm60-plugin>={_current_jaxlib_version},<={_jax_version}", ], @@ -114,14 +113,20 @@ def load_version_module(pkg_path): 'k8s': [ 'kubernetes', ], + + # For including XProf server + 'xprof': [ + 'xprof', + ], }, url='https://github.com/jax-ml/jax', license='Apache-2.0', classifiers=[ - "Programming Language :: Python :: 3.10", + "Development Status :: 5 - Production/Stable", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", "Programming Language :: Python :: 3.13", + "Programming Language :: Python :: Free Threading :: 3 - Stable", ], zip_safe=False, ) diff --git a/tests/BUILD b/tests/BUILD index 0ffa68ed8eb3..15cf4330d28b 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -34,41 +34,69 @@ jax_generate_backend_suites() jax_multiplatform_test( name = "api_test", srcs = ["api_test.py"], - enable_configs = ["tpu_v3_2x2"], + enable_configs = ["tpu_v3_x4"], shard_count = 10, deps = [ "//jax:experimental", - ], + ] + py_deps([ + "absl/testing", + "numpy", + ]), +) + +jax_multiplatform_test( + name = "custom_api_test", + srcs = ["custom_api_test.py"], + shard_count = 10, + deps = [ + "//jax:custom_derivatives", + "//jax:experimental", + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( name = "debug_info_test", srcs = ["debug_info_test.py"], - enable_configs = ["tpu_v3_2x2"], + enable_configs = ["tpu_v3_x4"], deps = [ + "//jax:custom_transpose", "//jax:experimental", "//jax:pallas", "//jax:pallas_gpu", "//jax:pallas_gpu_ops", "//jax:pallas_tpu", "//jax:pallas_tpu_ops", - ] + py_deps("numpy"), + ] + py_deps([ + "numpy", + "absl/testing", + ]), ) jax_multiplatform_test( name = "device_test", srcs = ["device_test.py"], + deps = py_deps("absl/testing"), ) -jax_multiplatform_test( +jax_py_test( name = "dynamic_api_test", srcs = ["dynamic_api_test.py"], - shard_count = 2, + deps = [ + "//jax", + "//jax:test_util", + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( name = "api_util_test", srcs = ["api_util_test.py"], + deps = py_deps("absl/testing"), ) jax_py_test( @@ -80,21 +108,40 @@ jax_py_test( ] + py_deps("absl/testing"), ) +jax_py_test( + name = "array_extensibility_test", + srcs = 
["array_extensibility_test.py"], + deps = [ + "//jax", + "//jax:test_util", + ] + py_deps([ + "absl/testing", + "numpy", + ]), +) + jax_multiplatform_test( name = "array_interoperability_test", srcs = ["array_interoperability_test.py"], + disable_configs = [ + "gpu_h100_tfrt", # TODO(b/411472145): Re-enable once fixed. + "gpu_h100x2_tfrt", + ], enable_backends = [ "cpu", "gpu", ], enable_configs = [ - "gpu_p100x2", + "gpu_h100x2", ], - env = { - "PYTHONWARNINGS": "default", # TODO(b/394123878): protobuf, via TensorFlow, issues a Python warning under Python 3.12+ sometimes. - }, - tags = ["multiaccelerator"], - deps = py_deps("tensorflow_core"), + tags = [ + "multiaccelerator", + ], + deps = py_deps([ + "absl/testing", + "numpy", + "tensorflow_core", + ]), ) jax_multiplatform_test( @@ -103,6 +150,25 @@ jax_multiplatform_test( shard_count = { "gpu": 5, }, + deps = py_deps([ + "absl/testing", + "numpy", + ]), +) + +jax_multiplatform_test( + name = "buffer_callback_test", + srcs = ["buffer_callback_test.py"], + enable_backends = [ + "cpu", + "gpu", + ], + deps = [ + "//jax:experimental_buffer_callback", + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_py_test( @@ -121,16 +187,41 @@ jax_multiplatform_test( "cpu": 5, "gpu": 10, }, + deps = py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( name = "debug_nans_test", srcs = ["debug_nans_test.py"], + deps = py_deps([ + "absl/testing", + "numpy", + ]), +) + +jax_py_test( + name = "distributed_initialize_test", + srcs = ["distributed_initialize_test.py"], + deps = [ + "//jax", + "//jax:test_util", + ] + py_deps([ + "portpicker", + "absl/testing", + ]), ) jax_multiplatform_test( name = "distributed_test", srcs = ["distributed_test.py"], + enable_backends = ["gpu"], + deps = py_deps([ + "portpicker", + "absl/testing", + ]), ) jax_py_test( @@ -143,12 +234,19 @@ jax_py_test( deps = [ "//jax", "//jax:test_util", - ] + py_deps("portpicker"), + ] + py_deps([ + "portpicker", + "absl/testing", + ]), ) jax_multiplatform_test( name = "dtypes_test", srcs = ["dtypes_test.py"], + deps = py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -158,22 +256,29 @@ jax_multiplatform_test( enable_configs = [ "cpu", ], + deps = py_deps("absl/testing"), ) jax_multiplatform_test( name = "extend_test", srcs = ["extend_test.py"], - deps = ["//jax:extend"], + deps = ["//jax:extend"] + py_deps("absl/testing"), ) jax_multiplatform_test( name = "ffi_test", srcs = ["ffi_test.py"], enable_configs = [ - "gpu_p100x2", + "gpu_h100x2", ], # TODO(dfm): Remove after removal of jex.ffi imports. - deps = ["//jax:extend"], + deps = [ + "//jax:extend", + "//jax:ffi", + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -186,15 +291,23 @@ jax_multiplatform_test( ], # Times out on TPU with asan/tsan. 
}, shard_count = { - "tpu": 20, + "tpu": 10, "cpu": 20, "gpu": 10, }, + deps = py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( name = "generated_fun_test", srcs = ["generated_fun_test.py"], + deps = py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -205,6 +318,7 @@ jax_multiplatform_test( "XLA_PYTHON_CLIENT_PREALLOCATE": "0", }, main = "gpu_memory_flags_test.py", + deps = py_deps("absl/testing"), ) jax_multiplatform_test( @@ -214,30 +328,42 @@ jax_multiplatform_test( env = { "XLA_PYTHON_CLIENT_PREALLOCATE": "1", }, + deps = py_deps("absl/testing"), ) jax_multiplatform_test( name = "lobpcg_test", srcs = ["lobpcg_test.py"], - env = {"LOBPCG_EMIT_DEBUG_PLOTS": "1"}, + # Set LOBPCG_EMIT_DEBUG_PLOTS=1 to debug + # checkLobpcgMonotonicity and checkApproxEigs tests + # using matplotlib plots + # env = {"LOBPCG_EMIT_DEBUG_PLOTS": "1"}, shard_count = { - "cpu": 48, - "gpu": 48, - "tpu": 48, + "cpu": 8, }, deps = [ "//jax:experimental_sparse", - ] + py_deps("matplotlib"), + ] + py_deps([ + "matplotlib", + "absl/testing", + "numpy", + "scipy", + ]), ) jax_multiplatform_test( name = "svd_test", srcs = ["svd_test.py"], shard_count = { - "cpu": 10, + "cpu": 20, "gpu": 10, - "tpu": 40, + "tpu": 15, }, + deps = py_deps([ + "absl/testing", + "numpy", + "scipy", + ]), ) jax_py_test( @@ -246,7 +372,7 @@ jax_py_test( deps = [ "//jax", "//jax:test_util", - ], + ] + py_deps("absl/testing"), ) jax_multiplatform_test( @@ -254,20 +380,20 @@ jax_multiplatform_test( srcs = ["memories_test.py"], enable_configs = [ "cpu", - "gpu_p100x2", - "tpu_v3_2x2", - "tpu_v4_2x2", - "tpu_v5p_2x2", - "tpu_v5e_4x2", + "gpu_h100x2", + "tpu_v3_x4", + "tpu_v4_x4", + "tpu_v5p_x4", + "tpu_v5e_x8", "gpu_p100x2_shardy", - "tpu_v5e_4x2_shardy", + "tpu_v5e_x8_shardy", ], - shard_count = { - "tpu": 5, - }, deps = [ "//jax:experimental", - ], + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -279,9 +405,9 @@ jax_multiplatform_test( }, enable_configs = [ "gpu_p100x2_shardy", - "tpu_v3_2x2_shardy", - "tpu_v3_2x2", - "gpu_p100x2", + "tpu_v3_x4_shardy", + "tpu_v3_x4", + "gpu_h100x2", ], shard_count = { "cpu": 5, @@ -291,7 +417,10 @@ jax_multiplatform_test( tags = ["multiaccelerator"], deps = [ "//jax:experimental", - ], + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -301,26 +430,33 @@ jax_multiplatform_test( "tpu": ["requires-mem:16g"], # Under tsan on 2x2 this test exceeds the default 12G memory limit. 
}, enable_configs = [ - "tpu_v3_2x2_shardy", + "tpu_v3_x4_shardy", + "tpu_v3_x4", ], tags = ["multiaccelerator"], deps = [ "//jax:experimental", - ], + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( name = "shard_alike_test", srcs = ["shard_alike_test.py"], enable_configs = [ - "tpu_v3_2x2", - "tpu_v5e_4x2", - "tpu_v4_2x2", - "tpu_v3_2x2_shardy", + "tpu_v3_x4", + "tpu_v5e_x8", + "tpu_v4_x4", + "tpu_v3_x4_shardy", ], deps = [ "//jax:experimental", - ], + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -329,6 +465,9 @@ jax_multiplatform_test( backend_tags = { "gpu": ["noasan"], # Memory leaks in NCCL, see https://github.com/NVIDIA/nccl/pull/1143 }, + disable_configs = [ + "gpu_h100x2_tfrt", # TODO(b/419192167): Doesn't work + ], enable_backends = ["gpu"], tags = [ "config-cuda-only", @@ -336,7 +475,10 @@ jax_multiplatform_test( ], deps = [ "//jax:experimental", - ], + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -351,7 +493,10 @@ jax_multiplatform_test( ], deps = [ "//jax:experimental", - ], + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -360,13 +505,14 @@ jax_multiplatform_test( enable_backends = ["gpu"], enable_configs = [ "gpu_h100", + "gpu_h100_shardy", ], tags = [ "config-cuda-only", ], deps = [ "//jax:experimental", - ], + ] + py_deps("absl/testing"), ) jax_multiplatform_test( @@ -376,13 +522,16 @@ jax_multiplatform_test( "tpu": ["requires-mem:16g"], # Under tsan on 2x2 this test exceeds the default 12G memory limit. }, enable_configs = [ - "tpu_v3_2x2", + "tpu_v3_x4", ], tags = ["multiaccelerator"], deps = [ "//jax:experimental", "//jax:internal_test_util", - ], + ] + py_deps([ + "numpy", + "absl/testing", + ]), ) jax_multiplatform_test( @@ -391,7 +540,10 @@ jax_multiplatform_test( tags = ["multiaccelerator"], deps = [ "//jax:experimental", - ] + py_deps("numpy"), + ] + py_deps([ + "numpy", + "absl/testing", + ]), ) jax_multiplatform_test( @@ -399,24 +551,35 @@ jax_multiplatform_test( srcs = ["image_test.py"], shard_count = { "cpu": 10, - "gpu": 20, - "tpu": 10, + "gpu": 10, + "tpu": 8, }, tags = ["noasan"], # Linking TF causes a linker OOM. 
- deps = py_deps("pil") + py_deps("tensorflow_core"), + deps = py_deps([ + "pil", + "tensorflow_core", + "numpy", + "absl/testing", + ]), ) jax_multiplatform_test( name = "infeed_test", srcs = ["infeed_test.py"], - deps = [ - ], + deps = py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( name = "jax_jit_test", srcs = ["jax_jit_test.py"], main = "jax_jit_test.py", + deps = py_deps([ + "absl/testing", + "numpy", + ]), ) jax_py_test( @@ -426,7 +589,10 @@ jax_py_test( "//jax:test_util", "//jax/experimental/jax2tf", "//jax/tools:jax_to_ir", - ] + py_deps("tensorflow_core"), + ] + py_deps([ + "tensorflow_core", + "absl/testing", + ]), ) jax_py_test( @@ -436,7 +602,7 @@ jax_py_test( "//jax", "//jax:jaxpr_util", "//jax:test_util", - ], + ] + py_deps("absl/testing"), ) jax_multiplatform_test( @@ -444,12 +610,15 @@ jax_multiplatform_test( srcs = ["jet_test.py"], shard_count = { "cpu": 10, - "gpu": 10, + "gpu": 4, }, deps = [ "//jax:jet", "//jax:stax", - ], + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -457,25 +626,38 @@ jax_multiplatform_test( srcs = ["lax_control_flow_test.py"], shard_count = { "cpu": 30, - "gpu": 40, - "tpu": 30, + "gpu": 30, + "tpu": 20, }, + deps = py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( name = "custom_root_test", srcs = ["custom_root_test.py"], + deps = py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( name = "custom_linear_solve_test", srcs = ["custom_linear_solve_test.py"], + deps = py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( name = "lax_numpy_test", srcs = ["lax_numpy_test.py"], backend_tags = { + "tpu": ["notsan"], # Test times out. "cpu": ["notsan"], # Test times out. }, shard_count = { @@ -487,6 +669,10 @@ jax_multiplatform_test( "noasan", # Test times out on all backends "test_cpu_thunks", ], + deps = py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -497,6 +683,10 @@ jax_multiplatform_test( "gpu": 30, "tpu": 40, }, + deps = py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -507,6 +697,10 @@ jax_multiplatform_test( "gpu": 20, "tpu": 20, }, + deps = py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -517,16 +711,19 @@ jax_multiplatform_test( "gpu": 10, "tpu": 10, }, + deps = py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( name = "lax_numpy_einsum_test", srcs = ["lax_numpy_einsum_test.py"], - shard_count = { - "cpu": 10, - "gpu": 10, - "tpu": 10, - }, + deps = py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -534,25 +731,37 @@ jax_multiplatform_test( srcs = ["lax_numpy_ufuncs_test.py"], shard_count = { "cpu": 10, - "gpu": 10, - "tpu": 10, + "gpu": 5, + "tpu": 5, }, + deps = py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( name = "lax_numpy_vectorize_test", srcs = ["lax_numpy_vectorize_test.py"], + deps = py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( name = "lax_scipy_test", srcs = ["lax_scipy_test.py"], shard_count = { - "cpu": 20, + "cpu": 30, "gpu": 20, - "tpu": 20, + "tpu": 8, }, - deps = py_deps("numpy") + py_deps("scipy") + py_deps("absl/testing"), + deps = py_deps([ + "numpy", + "scipy", + "absl/testing", + ]), ) jax_multiplatform_test( @@ -563,37 +772,52 @@ jax_multiplatform_test( }, shard_count = { "cpu": 10, - "gpu": 10, - "tpu": 10, + "gpu": 5, + "tpu": 5, }, + deps = py_deps([ + "numpy", + "scipy", + "absl/testing", + ]), ) jax_multiplatform_test( name = 
"lax_scipy_special_functions_test", srcs = ["lax_scipy_special_functions_test.py"], backend_tags = { - "gpu": ["noasan"], # Times out. - "cpu": ["noasan"], # Times out. + "cpu": [ + "nomsan", # Times out. + "notsan", # Times out. + ], }, shard_count = { "cpu": 20, - "gpu": 20, + "gpu": 30, "tpu": 20, }, - deps = py_deps("numpy") + py_deps("scipy") + py_deps("absl/testing"), + tags = ["noasan"], # Times out under asan. + deps = py_deps([ + "numpy", + "scipy", + "absl/testing", + ]), ) jax_multiplatform_test( name = "lax_scipy_spectral_dac_test", srcs = ["lax_scipy_spectral_dac_test.py"], shard_count = { - "cpu": 40, - "gpu": 40, - "tpu": 40, + "cpu": 20, + "gpu": 8, + "tpu": 8, }, deps = [ "//jax:internal_test_util", - ] + py_deps("numpy") + py_deps("scipy") + py_deps("absl/testing"), + ] + py_deps([ + "numpy", + "absl/testing", + ]), ) jax_multiplatform_test( @@ -611,7 +835,11 @@ jax_multiplatform_test( deps = [ "//jax:internal_test_util", "//jax:lax_reference", - ] + py_deps("numpy") + py_deps("mpmath"), + ] + py_deps([ + "numpy", + "absl/testing", + "mpmath", + ]), ) jax_multiplatform_test( @@ -622,7 +850,10 @@ jax_multiplatform_test( deps = [ "//jax:internal_test_util", "//jax:lax_reference", - ] + py_deps("numpy"), + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -630,9 +861,13 @@ jax_multiplatform_test( srcs = ["lax_autodiff_test.py"], shard_count = { "cpu": 40, - "gpu": 40, + "gpu": 30, "tpu": 20, }, + deps = py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -643,7 +878,10 @@ jax_multiplatform_test( "gpu": 40, "tpu": 40, }, - deps = ["//jax:internal_test_util"] + py_deps("numpy") + py_deps("absl/testing"), + deps = ["//jax:internal_test_util"] + py_deps([ + "numpy", + "absl/testing", + ]), ) jax_multiplatform_test( @@ -654,7 +892,10 @@ jax_multiplatform_test( "gpu": 40, "tpu": 40, }, - deps = ["//jax:internal_test_util"] + py_deps("numpy") + py_deps("absl/testing"), + deps = ["//jax:internal_test_util"] + py_deps([ + "numpy", + "absl/testing", + ]), ) jax_py_test( @@ -665,7 +906,7 @@ jax_py_test( deps = [ "//jax:internal_test_util", "//jax:test_util", - ], + ] + py_deps("absl/testing"), ) jax_py_test( @@ -676,7 +917,7 @@ jax_py_test( deps = [ "//jax:internal_test_util", "//jax:test_util", - ], + ] + py_deps("absl/testing"), ) jax_multiplatform_test( @@ -696,6 +937,11 @@ jax_multiplatform_test( "gpu": 40, "tpu": 40, }, + deps = py_deps([ + "absl/testing", + "numpy", + "scipy", + ]), ) jax_multiplatform_test( @@ -705,31 +951,47 @@ jax_multiplatform_test( "cpu", ], enable_configs = [ - "gpu_p100x2", + "gpu_h100x2", "gpu_p100x2_shardy", "gpu_p100x2_pjrt_c_api", ], + shard_count = { + "cpu": 10, + "gpu": 10, + }, tags = [ "multiaccelerator", ], + deps = py_deps([ + "absl/testing", + ]), ) jax_multiplatform_test( name = "magma_linalg_test", srcs = ["magma_linalg_test.py"], enable_backends = ["gpu"], - deps = py_deps("magma"), + deps = py_deps([ + "magma", + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( name = "cholesky_update_test", srcs = ["cholesky_update_test.py"], + deps = py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( name = "metadata_test", srcs = ["metadata_test.py"], enable_backends = ["cpu"], + deps = py_deps("absl/testing"), ) jax_py_test( @@ -738,22 +1000,27 @@ jax_py_test( deps = [ "//jax", "//jax:test_util", - ], + ] + py_deps("absl/testing"), ) jax_multiplatform_test( name = "multibackend_test", srcs = ["multibackend_test.py"], enable_configs = [ - "tpu_v3_2x2", - "gpu_p100x2", + 
"tpu_v3_x4", + "gpu_h100x2", ], + deps = py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( name = "multi_device_test", srcs = ["multi_device_test.py"], enable_backends = ["cpu"], + deps = py_deps("absl/testing"), ) jax_multiplatform_test( @@ -772,12 +1039,20 @@ jax_multiplatform_test( "tpu": 10, "gpu": 10, }, + deps = py_deps([ + "absl/testing", + "numpy", + "scipy", + ]), ) jax_multiplatform_test( name = "optimizers_test", srcs = ["optimizers_test.py"], - deps = ["//jax:optimizers"], + deps = ["//jax:optimizers"] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -785,7 +1060,11 @@ jax_multiplatform_test( srcs = ["pickle_test.py"], deps = [ "//jax:experimental", - ] + py_deps("cloudpickle") + py_deps("numpy"), + ] + py_deps([ + "cloudpickle", + "numpy", + "absl/testing", + ]), ) jax_multiplatform_test( @@ -799,17 +1078,20 @@ jax_multiplatform_test( }, enable_configs = [ "gpu_v100", - "tpu_v3_2x2", + "tpu_v3_x4", ], shard_count = { "cpu": 30, - "gpu": 30, + "gpu": 10, "tpu": 30, }, tags = ["multiaccelerator"], deps = [ "//jax:internal_test_util", - ], + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -818,7 +1100,7 @@ jax_multiplatform_test( # No implementation of nonsymmetric Eigendecomposition. enable_backends = ["cpu"], shard_count = { - "cpu": 10, + "cpu": 5, }, # This test ends up calling Fortran code that initializes some memory and # passes it to C code. MSan is not able to detect that the memory was @@ -827,12 +1109,21 @@ jax_multiplatform_test( # in this case there's not a good place to do it, see b/197635968#comment19 # for details. tags = ["nomsan"], + deps = py_deps([ + "absl/testing", + "numpy", + "scipy", + ]), ) jax_multiplatform_test( name = "heap_profiler_test", srcs = ["heap_profiler_test.py"], enable_backends = ["cpu"], + deps = py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -848,15 +1139,20 @@ jax_multiplatform_test( enable_backends = [ "cpu", "gpu", + "tpu", ], deps = [ "//jax:profiler", - ], + ] + py_deps("absl/testing"), ) jax_multiplatform_test( name = "pytorch_interoperability_test", srcs = ["pytorch_interoperability_test.py"], + disable_configs = [ + "gpu_h100_tfrt", # TODO(b/411472145): Re-enable once fixed. + "gpu_h100x2_tfrt", + ], enable_backends = [ "cpu", "gpu", @@ -866,7 +1162,10 @@ jax_multiplatform_test( "nomsan", # TODO(b/355237462): msan false-positives in torch? 
"not_build:arm", ], - deps = py_deps("torch"), + deps = py_deps([ + "torch", + "absl/testing", + ]), ) jax_multiplatform_test( @@ -879,29 +1178,20 @@ jax_multiplatform_test( "notsan", # Times out ], }, - shard_count = 10, + shard_count = 8, + deps = py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( name = "random_test", srcs = ["random_test.py"], - backend_tags = { - "cpu": [ - "notsan", # Times out - "nomsan", # Times out - ], - "tpu": [ - "optonly", - "nomsan", # Times out - "notsan", # Times out - ], - }, - shard_count = { - "cpu": 30, - "gpu": 30, - "tpu": 40, - }, - tags = ["noasan"], # Times out + deps = py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -923,10 +1213,15 @@ jax_multiplatform_test( }, shard_count = { "cpu": 40, - "gpu": 40, + "gpu": 50, "tpu": 40, }, tags = ["noasan"], # Times out + deps = py_deps([ + "absl/testing", + "numpy", + "scipy", + ]), ) # TODO(b/199564969): remove once we always enable_custom_prng @@ -934,25 +1229,8 @@ jax_multiplatform_test( name = "random_test_with_custom_prng", srcs = ["random_test.py"], args = ["--jax_enable_custom_prng=true"], - backend_tags = { - "cpu": [ - "noasan", # Times out under asan/msan/tsan. - "nomsan", - "notsan", - ], - "tpu": [ - "noasan", # Times out under asan/msan/tsan. - "nomsan", - "notsan", - "optonly", - ], - }, main = "random_test.py", - shard_count = { - "cpu": 40, - "gpu": 40, - "tpu": 40, - }, + deps = py_deps("absl/testing"), ) jax_multiplatform_test( @@ -966,21 +1244,41 @@ jax_multiplatform_test( ], # Times out on TPU with asan/tsan/msan. }, shard_count = 12, + deps = py_deps([ + "absl/testing", + "numpy", + "scipy", + ]), ) jax_multiplatform_test( name = "scipy_interpolate_test", srcs = ["scipy_interpolate_test.py"], + deps = py_deps([ + "absl/testing", + "numpy", + "scipy", + ]), ) jax_multiplatform_test( name = "scipy_ndimage_test", srcs = ["scipy_ndimage_test.py"], + deps = py_deps([ + "absl/testing", + "numpy", + "scipy", + ]), ) jax_multiplatform_test( name = "scipy_optimize_test", srcs = ["scipy_optimize_test.py"], + deps = py_deps([ + "absl/testing", + "numpy", + "scipy", + ]), ) jax_multiplatform_test( @@ -1006,12 +1304,25 @@ jax_multiplatform_test( "gpu": 40, "tpu": 50, }, + deps = py_deps([ + "absl/testing", + "numpy", + "scipy", + ]), ) jax_multiplatform_test( name = "scipy_spatial_test", srcs = ["scipy_spatial_test.py"], - deps = py_deps("scipy"), + shard_count = { + "cpu": 4, + "gpu": 4, + }, + deps = py_deps([ + "absl/testing", + "numpy", + "scipy", + ]), ) jax_multiplatform_test( @@ -1021,14 +1332,19 @@ jax_multiplatform_test( "tpu": ["nomsan"], # Times out }, shard_count = { - "cpu": 40, - "gpu": 30, - "tpu": 40, + "cpu": 50, + "gpu": 50, + "tpu": 50, }, tags = [ "noasan", "notsan", ], # Times out + deps = py_deps([ + "absl/testing", + "numpy", + "scipy", + ]), ) jax_multiplatform_test( @@ -1050,8 +1366,8 @@ jax_multiplatform_test( }, shard_count = { "cpu": 50, - "gpu": 50, - "tpu": 50, + "gpu": 30, + "tpu": 20, }, tags = [ "noasan", @@ -1061,7 +1377,11 @@ jax_multiplatform_test( deps = [ "//jax:experimental_sparse", "//jax:sparse_test_util", - ] + py_deps("scipy"), + ] + py_deps([ + "scipy", + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -1098,21 +1418,10 @@ jax_multiplatform_test( deps = [ "//jax:experimental_sparse", "//jax:sparse_test_util", - ] + py_deps("scipy"), -) - -jax_multiplatform_test( - name = "sparse_nm_test", - srcs = ["sparse_nm_test.py"], - enable_backends = [], - enable_configs = [ - "gpu_a100", - "gpu_h100", - ], - 
deps = [ - "//jax:experimental_sparse", - "//jax:pallas_gpu", - ], + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -1122,10 +1431,11 @@ jax_multiplatform_test( backend_tags = { "cpu": [ "noasan", # Times out under asan - "notsan", # Times out under asan + "notsan", # Times out under tsan ], "tpu": [ - "noasan", # Times out under asan. + "noasan", # Times out under asan + "notsan", # Times out under tsan ], }, shard_count = { @@ -1136,48 +1446,70 @@ jax_multiplatform_test( deps = [ "//jax:experimental_sparse", "//jax:sparse_test_util", - ], + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( name = "stack_test", srcs = ["stack_test.py"], + deps = py_deps("absl/testing"), ) jax_multiplatform_test( name = "checkify_test", srcs = ["checkify_test.py"], - enable_configs = ["tpu_v3_2x2"], + enable_configs = ["tpu_v3_x4"], shard_count = { "gpu": 2, "tpu": 4, }, + deps = py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( name = "error_check_test", srcs = ["error_check_test.py"], + deps = py_deps("absl/testing"), +) + +jax_multiplatform_test( + name = "jax_numpy_error_test", + srcs = ["jax_numpy_error_test.py"], + deps = py_deps("absl/testing"), ) jax_multiplatform_test( name = "stax_test", srcs = ["stax_test.py"], - shard_count = { - "cpu": 5, - "gpu": 5, - }, - deps = ["//jax:stax"], + deps = ["//jax:stax"] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( name = "linear_search_test", srcs = ["third_party/scipy/line_search_test.py"], main = "third_party/scipy/line_search_test.py", + deps = py_deps([ + "absl/testing", + "scipy", + ]), ) jax_multiplatform_test( name = "blocked_sampler_test", srcs = ["blocked_sampler_test.py"], + deps = py_deps([ + "absl/testing", + "numpy", + ]), ) jax_py_test( @@ -1186,7 +1518,11 @@ jax_py_test( deps = [ "//jax", "//jax:test_util", - ], + ] + py_deps([ + "absl/testing", + "numpy", + "cloudpickle", + ]), ) pytype_test( @@ -1195,7 +1531,11 @@ pytype_test( deps = [ "//jax", "//jax:test_util", - ], + "//jax:typing", + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_py_test( @@ -1204,7 +1544,7 @@ jax_py_test( deps = [ "//jax", "//jax:test_util", - ], + ] + py_deps("absl/testing"), ) jax_py_test( @@ -1213,7 +1553,7 @@ jax_py_test( deps = [ "//jax", "//jax:test_util", - ], + ] + py_deps("absl/testing"), ) jax_py_test( @@ -1232,7 +1572,9 @@ jax_py_test( "//jax", "//jax:compiler", "//jax:test_util", - ] + py_deps("absl/logging"), + ] + py_deps([ + "absl/logging", + ]), ) jax_py_test( @@ -1242,7 +1584,10 @@ jax_py_test( "//jax", "//jax:lru_cache", "//jax:test_util", - ] + py_deps("filelock"), + ] + py_deps([ + "filelock", + "absl/logging", + ]), ) jax_multiplatform_test( @@ -1251,7 +1596,10 @@ jax_multiplatform_test( deps = [ "//jax:compilation_cache_internal", "//jax:compiler", - ], + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -1260,7 +1608,8 @@ jax_multiplatform_test( deps = [ "//jax:cache_key", "//jax:compiler", - ], + "//jax:custom_partitioning", + ] + py_deps("absl/testing"), ) jax_multiplatform_test( @@ -1269,18 +1618,27 @@ jax_multiplatform_test( shard_count = { "cpu": 10, }, - deps = ["//jax:ode"], + deps = ["//jax:ode"] + py_deps([ + "absl/testing", + "numpy", + "scipy", + ]), ) jax_multiplatform_test( name = "key_reuse_test", srcs = ["key_reuse_test.py"], + deps = py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( name = "roofline_test", srcs = ["roofline_test.py"], enable_backends = ["cpu"], + deps = 
py_deps("absl/testing"), ) jax_multiplatform_test( @@ -1288,13 +1646,24 @@ jax_multiplatform_test( srcs = ["x64_context_test.py"], deps = [ "//jax:experimental", - ], + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( name = "ann_test", srcs = ["ann_test.py"], - shard_count = 10, + shard_count = { + "cpu": 5, + "gpu": 5, + "tpu": 10, + }, + deps = py_deps([ + "numpy", + "absl/testing", + ]), ) jax_py_test( @@ -1304,22 +1673,35 @@ jax_py_test( "//jax", "//jax:mesh_utils", "//jax:test_util", - ], + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( name = "transfer_guard_test", srcs = ["transfer_guard_test.py"], + deps = py_deps([ + "absl/testing", + "numpy", + "cloudpickle", + ]), ) jax_multiplatform_test( name = "garbage_collection_guard_test", srcs = ["garbage_collection_guard_test.py"], + deps = py_deps("absl/testing"), ) -jax_multiplatform_test( +jax_py_test( name = "name_stack_test", srcs = ["name_stack_test.py"], + deps = [ + "//jax", + "//jax:test_util", + ] + py_deps("absl/testing"), ) jax_multiplatform_test( @@ -1331,11 +1713,15 @@ jax_multiplatform_test( enable_configs = [ "cpu", "gpu_h100", - "tpu_v2_1x1", - "tpu_v3_2x2", - "tpu_v4_2x2", + "tpu_v2", + "tpu_v3_x4", + "tpu_v4_x4", ], tags = ["multiaccelerator"], + deps = py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -1344,12 +1730,16 @@ jax_multiplatform_test( enable_configs = [ "cpu", "gpu_h100", - "tpu_v2_1x1", - "tpu_v3_2x2", - "tpu_v4_2x2", - "gpu_a100_shardy", - "tpu_v3_2x2_shardy", + "tpu_v2", + "tpu_v3_x4", + "tpu_v4_x4", + "gpu_h100_shardy", + "tpu_v3_x4_shardy", ], + deps = py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -1359,16 +1749,19 @@ jax_multiplatform_test( "gpu": ["noasan"], # Memory leaks in NCCL, see https://github.com/NVIDIA/nccl/pull/1143 }, enable_configs = [ - "tpu_v2_1x1", - "tpu_v3_2x2", - "tpu_v4_2x2", - "tpu_v3_2x2_shardy", + "tpu_v2", + "tpu_v3_x4", + "tpu_v4_x4", + "tpu_v3_x4_shardy", "gpu_p100x2_shardy", ], tags = ["multiaccelerator"], deps = [ "//jax:experimental", - ], + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -1380,10 +1773,14 @@ jax_multiplatform_test( enable_configs = [ "cpu", "gpu_h100", - "tpu_v2_1x1", - "tpu_v3_2x2", - "tpu_v4_2x2", + "tpu_v2", + "tpu_v3_x4", + "tpu_v4_x4", ], + deps = py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -1405,22 +1802,32 @@ jax_multiplatform_test( "gpu": 2, "tpu": 2, }, - deps = py_deps("hypothesis"), + deps = py_deps([ + "absl/testing", + "numpy", + "hypothesis", + ]), ) jax_multiplatform_test( name = "mutable_array_test", srcs = ["mutable_array_test.py"], + deps = py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( name = "for_loop_test", srcs = ["for_loop_test.py"], shard_count = { - "cpu": 20, - "gpu": 10, "tpu": 20, }, + deps = py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -1436,28 +1843,27 @@ jax_multiplatform_test( enable_configs = [ "gpu_p100x2_shardy", ], - shard_count = { - "gpu": 10, - "tpu": 10, - }, tags = [ "multiaccelerator", ], deps = [ "//jax:experimental", - ], + ] + py_deps("absl/testing"), ) jax_multiplatform_test( name = "shard_map_test", srcs = ["shard_map_test.py"], + disable_configs = [ + "gpu_h100x2_tfrt", # TODO(b/419192167): Doesn't work + ], enable_configs = [ "gpu_p100x2_shardy", - "tpu_v3_2x2_shardy", + "tpu_v3_x4_shardy", ], shard_count = { "cpu": 50, - "gpu": 10, + "gpu": 20, "tpu": 50, }, tags = [ @@ -1469,12 +1875,16 @@ 
jax_multiplatform_test( deps = [ "//jax:experimental", "//jax:tree_util", - ], + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( name = "clear_backends_test", srcs = ["clear_backends_test.py"], + deps = py_deps("absl/testing"), ) jax_multiplatform_test( @@ -1482,7 +1892,21 @@ jax_multiplatform_test( srcs = ["attrs_test.py"], deps = [ "//jax:experimental", - ], + ] + py_deps([ + "numpy", + "absl/testing", + ]), +) + +jax_multiplatform_test( + name = "hijax_test", + srcs = ["hijax_test.py"], + deps = [ + "//jax:experimental", + ] + py_deps([ + "numpy", + "absl/testing", + ]), ) jax_multiplatform_test( @@ -1491,7 +1915,10 @@ jax_multiplatform_test( deps = [ "//jax:experimental_colocated_python", "//jax/extend:ifrt_programs", - ], + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -1504,7 +1931,10 @@ jax_multiplatform_test( shard_count = 15, deps = [ "//jax:rnn", - ], + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_py_test( @@ -1514,7 +1944,7 @@ jax_py_test( "//jax", "//jax:mosaic", "//jax:test_util", - ], + ] + py_deps("absl/testing"), ) jax_py_test( @@ -1523,7 +1953,7 @@ jax_py_test( deps = [ "//jax", "//jax:test_util", - ], + ] + py_deps("absl/testing"), ) jax_py_test( @@ -1532,27 +1962,29 @@ jax_py_test( deps = [ "//jax", "//jax:test_util", - ], + ] + py_deps("absl/testing"), ) jax_multiplatform_test( name = "logging_test", srcs = ["logging_test.py"], + deps = py_deps("absl/testing"), ) jax_multiplatform_test( name = "export_test", srcs = ["export_test.py"], - disable_configs = [ - "cpu_shardy", # TODO(b/355263220): enable once export is supported. - ], enable_configs = [ "cpu_shardy", "gpu_p100x2_shardy", - "tpu_v3_2x2_shardy", - "tpu_v3_2x2", + "tpu_v3_x4_shardy", + "tpu_v3_x4", ], tags = [], + deps = py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -1566,9 +1998,9 @@ jax_multiplatform_test( "cpu_x32", ], shard_count = { - "cpu": 4, - "gpu": 6, - "tpu": 4, + "cpu": 30, + "gpu": 20, + "tpu": 25, }, tags = [ "noasan", # Times out @@ -1577,7 +2009,10 @@ jax_multiplatform_test( ], deps = [ "//jax:internal_test_harnesses", - ], + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -1590,7 +2025,7 @@ jax_multiplatform_test( ], shard_count = { "cpu": 40, - "gpu": 20, + "gpu": 30, "tpu": 20, }, tags = [ @@ -1599,33 +2034,60 @@ jax_multiplatform_test( ], deps = [ "//jax:internal_test_harnesses", - ], + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( name = "export_back_compat_test", srcs = ["export_back_compat_test.py"], + enable_configs = [ + "tpu_v3_x4_shardy", + ], tags = [], deps = [ "//jax:internal_export_back_compat_test_data", "//jax:internal_export_back_compat_test_util", - ], + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( name = "fused_attention_stablehlo_test", srcs = ["fused_attention_stablehlo_test.py"], enable_backends = ["gpu"], - shard_count = { - "gpu": 4, - }, tags = ["multiaccelerator"], + deps = py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( name = "xla_metadata_test", srcs = ["xla_metadata_test.py"], - deps = ["//jax:experimental"], + deps = ["//jax:experimental"] + py_deps("absl/testing"), +) + +jax_multiplatform_test( + name = "unary_ops_accuracy_test", + srcs = ["unary_ops_accuracy_test.py"], + disable_configs = [ + "tpu_pjrt_c_api", + ], + enable_backends = [ + "tpu", + ], + deps = [ + "//jax:experimental", + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_py_test( @@ 
-1634,7 +2096,7 @@ jax_py_test( deps = [ "//jax", "//jax:test_util", - ], + ] + py_deps("absl/testing"), ) jax_py_test( @@ -1644,7 +2106,7 @@ jax_py_test( "//jax", "//jax:source_mapper", "//jax:test_util", - ], + ] + py_deps("absl/testing"), ) jax_py_test( @@ -1652,13 +2114,21 @@ jax_py_test( srcs = ["sourcemap_test.py"], deps = [ "//jax", + "//jax:sourcemap", "//jax:test_util", - ], + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( name = "string_array_test", srcs = ["string_array_test.py"], + deps = py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -1670,6 +2140,7 @@ jax_multiplatform_test( "gpu_h100", ], tags = ["multiaccelerator"], + deps = py_deps("absl/testing"), ) jax_multiplatform_test( @@ -1679,6 +2150,10 @@ jax_multiplatform_test( shard_count = { "gpu": 4, }, + deps = py_deps([ + "absl/testing", + "numpy", + ]), ) jax_py_test( @@ -1686,14 +2161,16 @@ jax_py_test( srcs = ["custom_partitioning_sharding_rule_test.py"], deps = [ "//jax", + "//jax:custom_partitioning_sharding_rule", "//jax:experimental", "//jax:test_util", - ], + ] + py_deps("absl/testing"), ) exports_files( [ "api_test.py", + "custom_api_test.py", "array_test.py", "cache_key_test.py", "colocated_python_test.py", diff --git a/tests/ann_test.py b/tests/ann_test.py index 1d704c725c61..18bb51bec93b 100644 --- a/tests/ann_test.py +++ b/tests/ann_test.py @@ -179,7 +179,7 @@ def approx_max_k(qy, db): def test_vmap_after(self): - batch = 4 + batch = 8 qy_size = 128 db_size = 1024 feature_dim = 32 diff --git a/tests/aot_test.py b/tests/aot_test.py index daaeb8417d33..623c6aaed0cc 100644 --- a/tests/aot_test.py +++ b/tests/aot_test.py @@ -126,6 +126,20 @@ def my_function(x): hlo = lowered.as_text("hlo") self.assertNotRegex(hlo, r"sine.*metadata=.*source_file=.*") + @jtu.run_on_devices('gpu', 'tpu') + def test_mismatched_backends_raises(self): + @jax.jit + def f(x): + return x * 2 + + x = jnp.arange(1) + f_lowered = f.lower(x) + serialized, in_tree, out_tree = serialize(f_lowered.compile()) + with self.assertRaisesRegex( + ValueError, + 'Execution devices belong to a client other than `backend`'): + deserialize_and_load(serialized, in_tree, out_tree, backend='cpu', + execution_devices=jax.devices()[:1]) if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/api_test.py b/tests/api_test.py index aece7b19fdfb..74022d2207b5 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -16,7 +16,6 @@ import collections import collections.abc -from collections.abc import Callable import concurrent.futures from contextlib import contextmanager import copy @@ -43,7 +42,6 @@ from absl import logging from absl.testing import absltest, parameterized import jax -from jax import custom_derivatives as custom_derivatives_public from jax import device_put, float0, grad, hessian, jacfwd, jacrev, jit from jax import lax from jax import tree_util @@ -51,24 +49,20 @@ from jax._src import array from jax._src import config from jax._src import core -from jax._src import custom_derivatives from jax._src import linear_util as lu from jax._src import test_util as jtu from jax._src import xla_bridge from jax._src import debugging from jax._src import pjit as pjit_lib +from jax._src import sharding_impls from jax._src.ad_checkpoint import saved_residuals from jax._src.interpreters import ad as ad_internal from jax._src.interpreters import mlir from jax._src.interpreters import partial_eval as pe from jax._src.compilation_cache import is_persistent_cache_enabled -from 
jax._src.lib import xla_extension +from jax._src.lib import _jax import jax._src.util as jax_util from jax.ad_checkpoint import checkpoint_name, checkpoint as new_checkpoint -import jax.custom_batching -import jax.custom_derivatives -import jax.custom_transpose -import jax.experimental.custom_dce from jax.errors import (UnexpectedTracerError, TracerIntegerConversionError, ConcretizationTypeError, TracerBoolConversionError) from jax.experimental import pjit @@ -510,6 +504,27 @@ def test_device_put_aliasing(self): may_alias=False, donate=False) self.assertNotEqual(id(arr), id(out)) + def test_device_put_aliasing_with_diff_compatible_sharding(self): + if jax.device_count() < 2: + raise unittest.SkipTest("Test requires >= 2 devices") + + mesh = jax.sharding.Mesh( + np.array(jax.devices()[:2]).reshape((2, 1)), ("x", "y") + ) + x = jax.device_put( + np.arange(16).reshape((4, 4)), + jax.NamedSharding(mesh, P("x", None)), + ) + expanded_mesh = jax.sharding.Mesh( + np.array(jax.devices()[:2]).reshape((1, 2, 1)), ("replicas", "x", "y") + ) + dst_sharding = jax.NamedSharding(expanded_mesh, P("x", None)) + # No transfer should happen because the array is aliased to compatible + # sharding that only has a mesh with an additional dimension of size 1. + with jax.transfer_guard_device_to_device("disallow_explicit"): + res = jax.device_put(x, dst_sharding, may_alias=True) + self.assertEqual(dst_sharding, res.sharding) + @parameterized.named_parameters( ("argnums", "donate_argnums", 0), ("argnames", "donate_argnames", 'x'), @@ -1361,7 +1376,7 @@ def f(x): "exec_time_optimization_effort": 0.0, })(1.0) # doesn't crash. - with self.assertRaisesRegex(xla_extension.XlaRuntimeError, "No such"): + with self.assertRaisesRegex(_jax.XlaRuntimeError, "No such"): f_jit = jit( f, compiler_options={ @@ -1402,12 +1417,12 @@ def f(x): lowered = f_jit.lower(1.) self.assertRaisesRegex( - xla_extension.XlaRuntimeError, "No such compile option: 'invalid_key'", + _jax.XlaRuntimeError, "No such compile option: 'invalid_key'", lambda: lowered.compile( compiler_options={"invalid_key": "invalid_value"})) self.assertRaisesRegex( - xla_extension.XlaRuntimeError, "is not a valid bool value.", + _jax.XlaRuntimeError, "is not a valid bool value.", lambda: lowered.compile( compiler_options={"xla_embed_ir_in_executable": "invalid_value"})) @@ -1422,7 +1437,7 @@ def f(x): # We should still error on invalid options after some valid compiles with self.assertRaisesRegex( - xla_extension.XlaRuntimeError, "No such compile option: 'invalid_key'"): + _jax.XlaRuntimeError, "No such compile option: 'invalid_key'"): jit(f, compiler_options={"invalid_key": "invalid_value"})(1.) def test_lower_compile_with_compiler_options_multiple(self): @@ -1447,7 +1462,7 @@ def f(x): # We should still error on invalid options after some valid compiles self.assertRaisesRegex( - xla_extension.XlaRuntimeError, "No such compile option: 'invalid_key'", + _jax.XlaRuntimeError, "No such compile option: 'invalid_key'", lambda: lowered.compile( compiler_options={"invalid_key": "invalid_value"})) @@ -1482,7 +1497,7 @@ def f(k): def test_caches_depend_on_axis_env(self): # https://github.com/jax-ml/jax/issues/9187 - f = lambda: lax.psum(1, "i") + f = lambda: lax.axis_size("i") g = jax.jit(f) expected = jax.vmap(f, axis_name="i", axis_size=2, out_axes=None)() ans = jax.vmap(g, axis_name="i", axis_size=2, out_axes=None)() @@ -1626,6 +1641,27 @@ def f(x): assert g(2.0) == 4.0 assert len(side) == 1 + @jtu.thread_unsafe_test() # Concurrent cache eviction means we may retrace. 
+ def test_fwd_and_bwd(self): + def f(x, W): + return x @ W + + x = W = cot_out = jnp.ones((4,4)) + expected_y, f_vjp = api.vjp(f, x, W) + expected_cot_x, expected_cot_W = f_vjp(cot_out) + + fwd, bwd = api.fwd_and_bwd(f, argnums=(0,1)) + y, residuals = fwd(x, W) + cot_x, cot_W = bwd(residuals, cot_out) + + self.assertArraysAllClose(y, expected_y) + self.assertArraysAllClose(cot_x, expected_cot_x) + self.assertArraysAllClose(cot_W, expected_cot_W) + + with jax.no_tracing(): + y, residuals = fwd(x, W) + cot_x, cot_W = bwd(residuals, cot_out) # no recompilation + @parameterized.named_parameters( {"testcase_name": f"_{transform.__name__}", "transform": transform} for transform in [grad, jacfwd, jacrev]) @@ -1937,6 +1973,75 @@ def test_device_put_sharding_mismatched_tree_different_leaf_count(self): ): jax.device_put((x, y, z), device=(s1, s2)) + def test_internal_device_put_with_device(self): + # Hitting the cache for a single-device jitted execution while using a numpy + # array calls internal `DevicePutWithDevice`. + f = jax.jit(lambda x: x + 1) + f(np.arange(8)) + + with jtu.count_internal_device_puts() as counts: + f(np.arange(8)) + self.assertEqual(counts(), {"device_put_with_device": 1}) + + def test_internal_device_put_fully_replicated(self): + if jax.device_count() < 2: + raise unittest.SkipTest("Test requires >= 2 devices") + + # Creating an array from a numpy array with a fully-replicated sharding + # calls internal `DevicePutWithSharding`, taking the fully-replicated sub + # case. + mesh = jax.sharding.Mesh(np.array(jax.devices()[:2]), "x") + sharding = jax.NamedSharding(mesh, P()) + + with jtu.count_internal_device_puts() as counts: + jax.device_put(np.arange(8), sharding) + self.assertEqual( + counts(), + {"device_put_with_sharding": 1, "device_put_fully_replicated": 1}, + ) + + def test_internal_device_put_batched(self): + if jax.device_count() < 2: + raise unittest.SkipTest("Test requires >= 2 devices") + + # Creating an array from a numpy array with a non-fully-replicated sharding + # calls internal `DevicePutWithSharding`, performing batched creation of a + # multi-shard array. + mesh = jax.sharding.Mesh(np.array(jax.devices()[:2]), "x") + sharding = jax.NamedSharding(mesh, P("x")) + + with jtu.count_internal_device_puts() as counts: + jax.device_put(np.arange(8), sharding) + self.assertEqual( + counts(), {"device_put_with_sharding": 1, "device_put_batched": 1} + ) + + def test_internal_device_put_assembled(self): + if jax.device_count() < 2: + raise unittest.SkipTest("Test requires >= 2 devices") + + # Creating an array from per-device JAX arrays calls internal + # `DevicePutWithSharding`, performing per-shard array adoption followed by + # assembly. + mesh = jax.sharding.Mesh(np.array(jax.devices()[:2]), "x") + sharding = jax.NamedSharding(mesh, P("x")) + + arr = np.arange(8) + per_device_arrs = { + # Use uncommitted arrays that are not aligned with the destination + # sharding so that we trigger `BatchedDevicePut`. 
+ sharding_impls.hashed_index(index): jnp.array(arr[index]) + for _, index in sharding.devices_indices_map(arr.shape).items() + } + data_callback = lambda index: per_device_arrs[ + sharding_impls.hashed_index(index) + ] + with jtu.count_internal_device_puts() as counts: + jax.make_array_from_callback(arr.shape, sharding, data_callback) + self.assertEqual( + counts(), {"device_put_with_sharding": 1, "device_put_assembled": 1} + ) + def test_device_put_custom_type_not_accepting_none_leaves(self): class CustomNode(list): @@ -3120,7 +3225,6 @@ def test_error_for_invalid_dtype(self): def test_vmap_preserves_docstr(self): def superfun(a): """Does things with stuff.""" - pass self.assertRegex(api.vmap(superfun).__doc__, "\n".join([ "Vectorized version of superfun.*", @@ -3917,6 +4021,9 @@ def test_default_device(self): def test_dunder_jax_array(self): # https://github.com/jax-ml/jax/pull/4725 + @partial(jax.tree_util.register_dataclass, + data_fields=['jax_val'], + meta_fields=[]) class AlexArray: def __init__(self, jax_val): self.jax_val = jax_val @@ -3926,10 +4033,16 @@ def __jax_array__(self): shape = property(lambda self: self.jax_val.shape) x = AlexArray(jnp.array([1., 2., 3.])) + + y = jax.jit(lambda x: x)(x) + self.assertIsInstance(x, AlexArray) + self.assertArraysEqual(jnp.asarray(x), jnp.asarray(y)) + y = jnp.sin(x) self.assertAllClose(y, jnp.sin(jnp.array([1., 2., 3.]))) y = api.grad(api.jit(lambda x: jnp.sin(x).sum()))(x) - self.assertAllClose(y, jnp.cos(jnp.array([1., 2., 3.]))) + self.assertIsInstance(y, AlexArray) + self.assertAllClose(jnp.asarray(y), jnp.cos(jnp.array([1., 2., 3.]))) x = AlexArray(jnp.array([[1., 2., 3.]])) y = api.pmap(jnp.sin)(x) @@ -3947,6 +4060,19 @@ def __jax_array__(self): a2 = jnp.array(((x, x), [x, x])) self.assertAllClose(np.array(((1, 1), (1, 1))), a2) + def test_dunder_jax_array_warnings(self): + class AlexArray: + def __init__(self, jax_val): + self.jax_val = jax_val + def __jax_array__(self): + return self.jax_val + + f = jax.jit(lambda x: x) + a = AlexArray(jnp.arange(4)) + msg = r"Triggering of __jax_array__\(\) during abstractification is deprecated." 
+ with self.assertDeprecationWarnsOrRaises('jax-abstract-dunder-array', msg): + f(a) + @jtu.thread_unsafe_test() # count_jit_tracing_cache_miss() isn't thread-safe def test_eval_shape_weak_type(self): # https://github.com/jax-ml/jax/issues/23302 @@ -4313,6 +4439,21 @@ def g(x, y): for i in range(3): # Loop verifies we exercise both Python and C++ dispatch self.assertEqual(2 * i, g(2, i), msg=i) + def test_make_jaxpr_static_argnums_order(self): + # https://github.com/jax-ml/jax/issues/28065 + def f(a, b, c): + x = a + c + y = b * c + z = x - y + return z + + for static_argnums in [(1, 0), (0, 1)]: + val = jax.jit(f, static_argnums=static_argnums)(1, 2, 3) + self.assertEqual(val, -2) + jaxpr = jax.make_jaxpr(f, static_argnums=static_argnums)(1, 2, 3) + self.assertEqual(jaxpr.eqns[0].invars[0].val, 1) + self.assertEqual(jaxpr.eqns[1].invars[0].val, 2) + def test_fastpath_cache_confusion(self): # https://github.com/jax-ml/jax/issues/12542 @jax.jit @@ -4366,13 +4507,6 @@ def foo(x): with self.assertRaisesRegex(TypeError, "applied to foo"): f_vjp(1.0, 1.0) - def test_shapedtypestruct_sharding_error(self): - with self.assertRaisesRegex( - ValueError, - "sharding should be an instance of `jax.sharding.Sharding`."): - jax.ShapeDtypeStruct((8, 2), np.float32, - sharding=jax.sharding.PartitionSpec('x')) - def test_make_jaxpr_weakref(self): class Foo(NamedTuple): x: int @@ -4424,6 +4558,7 @@ def test_grad_conj_symbolic_zeros(self): out = jax.grad(f)(3.0) # doesn't crash self.assertAllClose(out, 1., check_dtypes=False) + @jtu.thread_unsafe_test() def test_cache_clear_pmap(self): @jax.pmap def f(i): @@ -4466,64 +4601,214 @@ def add(x): self.assertEqual(tracing_add_count, 2) @jtu.thread_unsafe_test() # logging is not thread-safe - def test_cache_miss_explanations(self): - @jax.jit - def f(x, y): - return jnp.sin(x) * y['hi'] + def test_cache_miss_explanations_skip_internals(self): + if is_persistent_cache_enabled(): + self.skipTest('With persistent cache, we see the cache misses') + + with config.explain_cache_misses(True): + with self.assertNoLogs(level='WARNING'): + for i in range(2): + jnp.sin(jnp.arange(i + 1, dtype=np.float32)) + @jtu.thread_unsafe_test() # logging is not thread-safe + def test_cache_miss_explanations_first_miss(self): + @jax.jit + def f(x): return x x = jnp.float32(1.) - y = {'hi': jnp.arange(3., dtype='float32')} expected_log_len = 1 if not is_persistent_cache_enabled() else 3 - # print on first miss, not on hit + with config.explain_cache_misses(True): + with self.assertLogs(level="WARNING") as cm: + f(x) + f(x) + self.assertLen(cm.output, expected_log_len) + msg = cm.output[0] + self.assertIn("TRACING CACHE MISS", msg) + self.assertIn("never seen function", msg) + self.assertNotIn("explanation unavailable!", msg) + + @jtu.thread_unsafe_test() # logging is not thread-safe + def test_cache_miss_explanations_other_in_tree(self): + @jax.jit + def f(*args, **kwargs): return args[0] + + f(0., 1., y=(2., 2.1)) + + with config.explain_cache_misses(True): + with self.assertLogs(level="WARNING") as cm: + # Same number of leaves but different trees + f(0., (1., 1.1), y=2.) + self.assertLen(cm.output, 1) + msg = cm.output[0] + self.assertIn("different input pytree", msg) + self.assertNotIn("explanation unavailable!", msg) + + @jtu.thread_unsafe_test() # logging is not thread-safe + def test_cache_miss_explanations_other_arg_passed_as_kwarg(self): + @jax.jit + def f(x, y): return jnp.sin(x) + y + + f(0., 1.) 
+ + # kwarg change + with config.explain_cache_misses(True): + with self.assertLogs(level="WARNING") as cm: + f(0., y=1.) + + self.assertLen(cm.output, 1) + msg = cm.output[0] + self.assertIn("different number of args and kwargs, but same total number", msg) + self.assertIn("now 1 args and kwargs with keys ['y']", msg) + self.assertIn("before 1 args and kwargs with keys []", msg) + self.assertNotIn("explanation unavailable!", msg) + + @jtu.thread_unsafe_test() # logging is not thread-safe + def test_cache_miss_explanations_other_static_argnums(self): + @partial(jax.jit, static_argnums=(0, 2)) + def f(x, y, z): + return y + + f(1., 2., "foo") + + with config.explain_cache_misses(True): + with self.assertLogs(level="WARNING") as cm: + f(1., 2., "bar") + self.assertLen(cm.output, 1) + msg = cm.output[0] + self.assertIn("different value of static args", msg) + self.assertIn("now 1.0, 'bar' and before 1.0, 'foo'", msg) + self.assertNotIn("explanation unavailable!", msg) + + @jtu.thread_unsafe_test() # logging is not thread-safe + def test_cache_miss_explanations_other_static_argnames(self): + @partial(jax.jit, static_argnames="foo") + def f(*, foo): + return 1 + + f(foo="foo") + + with config.explain_cache_misses(True): + with self.assertLogs(level="WARNING") as cm: + f(foo="bar") + self.assertLen(cm.output, 1) + msg = cm.output[0] + self.assertIn("different value of static kwargs", msg) + self.assertIn("now {foo: 'bar'} and before {foo: 'foo'}", msg) + self.assertNotIn('explanation unavailable!', msg) + + @jtu.thread_unsafe_test() # logging is not thread-safe + def test_cache_miss_explanations_other_dtype(self): + @jax.jit + def f(x, y): return x + f(np.float32(0), np.float32(1)) + with config.explain_cache_misses(True): with self.assertLogs(level='WARNING') as cm: - f(x, y) - f(x, y) + f(np.float32(0), np.int32(1)) + self.assertLen(cm.output, 1) + msg = cm.output[0] + self.assertIn("different input types", msg) + self.assertIn("at y, now i32[] and before f32[]", msg) + self.assertNotIn("explanation unavailable!", msg) + + @jtu.thread_unsafe_test() # logging is not thread-safe + def test_cache_miss_explanations_other_weak_type(self): + @jax.jit + def f(x, y): return jnp.sin(x) + y + + y = jnp.arange(4, dtype="float32") + f(jnp.float32(0.), y) + # weak type change (assuming no x64) + if config.enable_x64.value: + self.skipTest("Work only for 32 bit mode") + with config.explain_cache_misses(True): + with self.assertLogs(level="WARNING") as cm: + f(0., y) + expected_log_len = 1 if not is_persistent_cache_enabled() else 3 self.assertLen(cm.output, expected_log_len) msg = cm.output[0] - self.assertIn('TRACING CACHE MISS', msg) - self.assertIn('never seen function', msg) + self.assertIn("different input types", msg) + self.assertIn("at x, now f32[]{weak_type=True} and before f32[]{weak_type=False}", msg) + self.assertIn("https://docs.jax.dev/en/latest/type_promotion.html#weak-types", msg) + self.assertNotIn("explanation unavailable!", msg) + + @jtu.thread_unsafe_test() # logging is not thread-safe + def test_cache_miss_explanations_other_shape(self): + @jax.jit + def f(x, y): return jnp.sin(x) + y + f(np.float32(0), np.arange(1, dtype=np.float32)) - # shape change - y_ = {'hi': jnp.arange(4, dtype='float32')} with config.explain_cache_misses(True): with self.assertLogs(level='WARNING') as cm: - f(x, y_) + f(np.float32(0), np.arange(2, dtype=np.float32)) + expected_log_len = 1 if not is_persistent_cache_enabled() else 3 self.assertLen(cm.output, expected_log_len) msg = cm.output[0] - 
self.assertIn('never seen input type signature', msg) - self.assertIn('closest seen input type signature has 1 mismatches', msg) - self.assertIn('seen f32[3], but now given f32[4]', msg) + self.assertIn("different input types", msg) + self.assertIn("at y, now f32[2] and before f32[1]", msg) + self.assertNotIn("explanation unavailable!", msg) - # weak type change (assuming no x64) - if not config.enable_x64.value: - with config.explain_cache_misses(True): - with self.assertLogs(level='WARNING') as cm: - f(1., y) - self.assertLen(cm.output, expected_log_len) - msg = cm.output[0] - self.assertIn('weak_type=True', msg) - self.assertIn('https://jax.readthedocs.io/en/latest/type_promotion.html#weak-types', msg) + @jtu.thread_unsafe_test() # logging is not thread-safe + def test_cache_miss_explanations_other_shape_explain_closest(self): + @jax.jit + def f(x): return x + f(np.ones((1, 2), dtype=np.float32)) + f(np.ones((10, 20, 30), dtype=np.float32)) + f(np.ones((1, 2, 3), dtype=np.float32)) - # kwarg change with config.explain_cache_misses(True): with self.assertLogs(level='WARNING') as cm: - f(1, y=y) + f(np.ones((10, 2, 30), dtype=np.float32)) + expected_log_len = 1 if not is_persistent_cache_enabled() else 3 self.assertLen(cm.output, expected_log_len) msg = cm.output[0] - self.assertIn('never seen passing 1 positional args and 1 keyword args', msg) + self.assertIn("key with different input types", msg) + self.assertIn("at x, now f32[10,2,30] and before f32[10,20,30]", msg) + self.assertNotIn("explanation unavailable!", msg) + + @jtu.thread_unsafe_test() # logging is not thread-safe + def test_cache_miss_explanations_other_tracing_config(self): + @jax.jit + def f(x, y): return jnp.sin(x) + y + f(0., 1.) # tracing config change with config.explain_cache_misses(True): - with self.assertLogs(level='WARNING') as cm: - with jax.numpy_rank_promotion('warn'): - f(x, y) - # depending on the backend, we may or may not get persistent cache warnings + with self.assertLogs(level="WARNING") as cm: + with jax.numpy_rank_promotion("warn"): + with jax.default_matmul_precision("high"): + f(0., 1.) + + expected_log_len = 1 if not is_persistent_cache_enabled() else 3 self.assertTrue(1 <= len(cm.output) <= expected_log_len) msg = cm.output[0] - self.assertIn("tracing context doesn't match", msg) + self.assertIn("key with different tracing context", msg) + self.assertIn("now warn and before", msg) + self.assertIn("now high and before", msg) + self.assertNotIn("explanation unavailable!", msg) + + @jtu.thread_unsafe_test() # logging is not thread-safe + def test_cache_miss_explanations_multiple_changes(self): + @jax.jit + def f(x): return jnp.sin(x) + + call_1 = f(np.arange(4, dtype=np.float32)) + with jax.numpy_rank_promotion("warn"): + call_2 = f(np.arange(8, dtype=np.float32)) + + with config.explain_cache_misses(True): + with self.assertLogs(level='WARNING') as cm: + # Matches call_2 in shape but not context, and call_1 in context but + # not in shape. 
+ f(np.arange(8, dtype=np.float32)) + + self.assertLen(cm.output, 1) + msg = cm.output[0] + self.assertIn("key with different input types", msg) + self.assertIn("at x, now f32[8] and before f32[4]", msg) + self.assertIn("key with different tracing context", msg) + self.assertNotIn("explanation unavailable!", msg) @jtu.thread_unsafe_test() # logging is not thread-safe def test_cache_miss_explanations_new_function_in_loop(self): @@ -4547,28 +4832,6 @@ def f(x, y): _, msg = cm.output self.assertIn('another function defined on the same line', msg) - @jtu.thread_unsafe_test() # logging is not thread-safe - def test_cache_miss_explanations_unpacks_transforms(self): - # Tests that the explain_tracing_cache_miss() function does not throw an - # error when unpacking `transforms` with a length greater than 3. - @jax.jit - def f(key): - return jax.random.truncated_normal(key, 1, 1, dtype=jax.numpy.float32) - - with config.explain_cache_misses(True): - with self.assertLogs(level="WARNING") as cm: - f(jax.random.key(seed=123)) - - if is_persistent_cache_enabled(): - # 5 warnings from tracing cache, 5-10 from persistent cache depending on - # the backend - self.assertTrue(10 <= len(cm.output) <= 15) - self.assertTrue(any("TRACING CACHE MISS" in msg for msg in cm.output)) - else: - self.assertLen(cm.output, 5) - for msg in cm.output: - self.assertIn("TRACING CACHE MISS", msg) - def test_cache_miss_explanations_no_source_info(self): # ``operator.add`` is a built-in function and does not have source info. with config.explain_cache_misses(True): @@ -4687,6 +4950,8 @@ def f(inputs): @jtu.run_on_devices("cpu") def test_inner_jit_forwarding_happens(self): + if not config.dynamic_shapes.value: + self.skipTest("Only works for dynamic shapes") jaxpr = jax.make_jaxpr(lambda: jax.jit(lambda x: x)(3))() self.assertLen(jaxpr.jaxpr.outvars, 1) self.assertIsInstance(jaxpr.jaxpr.outvars[0], core.Literal) @@ -4695,6 +4960,8 @@ def test_inner_jit_forwarding_happens(self): @parameterized.parameters(range(8)) @jtu.run_on_devices("cpu") def test_inner_jit_forwarding_correctness(self, num_input_fwd): + if not config.dynamic_shapes.value: + self.skipTest("Only works for dynamic shapes") num_args = 8 rng = np.random.RandomState(0) @@ -4776,7 +5043,7 @@ def sin_of_sin(x): def test_deferred_primal_with_direct_linearize(self): def my_sin_lin(nzs, x): nz, = nzs - return (my_sin_p.bind(x), nz, x, lambda x, t: lax.mul(t, lax.cos(x))) + return (my_sin_p.bind(x, accuracy=None), nz, x, lambda x, t: lax.mul(t, lax.cos(x))) my_sin_p = core.Primitive("my_sin_p") my_sin_p.def_impl(lax.sin) @@ -4786,6 +5053,34 @@ def my_sin_lin(nzs, x): with config.use_direct_linearize(True): jax.grad(my_sin_p.bind)(1.0) # doesn't crash + def test_ensure_compile_time_eval_no_leaks(self): + # https://github.com/jax-ml/jax/issues/25847 + with jax.ensure_compile_time_eval(): + jnp.linalg.solve(jnp.eye(3), jnp.ones(3)) # doesn't crash + + def test_returned_non_jaxtype(self): + + class TestEnum(enum.Enum): + A = enum.auto() + + @jax.tree_util.register_dataclass + @dataclasses.dataclass + class TestClass3: + test_enum_field: TestEnum = dataclasses.field(metadata=dict(static=True)) + test_data_field: int + + def test_jax_function(test_class: TestClass3) -> TestEnum: + return test_class.test_enum_field + + jitted_test_function = jax.jit(test_jax_function) + with self.assertRaisesRegex(TypeError, "returned a value of type"): + jitted_test_function( + TestClass3( + test_data_field=1, + test_enum_field=TestEnum.A, + ) + ) + class RematTest(jtu.JaxTestCase): @@ -4823,8 
+5118,8 @@ def f(x): sin_impl = lax.sin_p.impl cos_impl = lax.cos_p.impl try: - lax.sin_p.def_impl(lambda x: sin_calls.append(1) or sin_impl(x)) - lax.cos_p.def_impl(lambda x: cos_calls.append(1) or cos_impl(x)) + lax.sin_p.def_impl(lambda x, **kwargs: sin_calls.append(1) or sin_impl(x, **kwargs)) + lax.cos_p.def_impl(lambda x, **kwargs: cos_calls.append(1) or cos_impl(x, **kwargs)) f_lin(3.) finally: lax.sin_p.def_impl(sin_impl) @@ -5019,7 +5314,7 @@ def g(x): # Make sure that introducing constants in vmap works. constant_introducing_p = core.Primitive('introduce_constant') - constant_introducing_p.def_abstract_eval(core.raise_to_shaped) + constant_introducing_p.def_abstract_eval(lambda x: x) def _constant_introducing_batcher(xs, ds): (x,), (d,) = xs, ds return (x + np.arange(x.size, dtype=x.dtype).reshape(x.shape)), d @@ -5117,7 +5412,7 @@ def f(x, y): called = [] sin_impl = lax.sin_p.impl try: - lax.sin_p.def_impl(lambda x: called.append(1) or sin_impl(x)) + lax.sin_p.def_impl(lambda x, **kwargs: called.append(1) or sin_impl(x, **kwargs)) api.grad(g)(3.) finally: lax.sin_p.def_impl(sin_impl) @@ -5809,17 +6104,17 @@ def f(x, y): res = saved_residuals(f, (2., 3.), y=4.) self.assertLen(res, 6) - self.assertEqual(res[0][0].shape, ()) - self.assertEqual(res[0][1], "from the argument x[0]") + self.assertEqual(res[0][0].shape, (1,)) + self.assertEqual(res[0][1], "from a constant") self.assertEqual(res[1][0].shape, ()) - self.assertEqual(res[1][1], "from the argument x[1]") + self.assertEqual(res[1][1], "from the argument x[0]") self.assertEqual(res[2][0].shape, ()) - self.assertEqual(res[2][1], "from the argument y") + self.assertEqual(res[2][1], "from the argument x[1]") self.assertEqual(res[3][0].shape, ()) - self.assertStartsWith(res[3][1], "output of jitted function 'f'") + self.assertEqual(res[3][1], "from the argument y") self.assertEqual(res[4][0].shape, ()) - self.assertEqual(res[5][0].shape, (1,)) - self.assertStartsWith(res[5][1], "output of jitted function 'f'") + self.assertStartsWith(res[4][1], "output of jitted function 'f'") + self.assertEqual(res[5][0].shape, ()) @parameterized.named_parameters( {"testcase_name": f"{suffix}", "remat": remat} @@ -5901,6 +6196,7 @@ def test_remat_of_scan(self, remat): jtu.check_grads(remat(f), (3.,), order=2, modes=['rev']) jaxpr = api.make_jaxpr(api.linearize(remat(f), 4.)[1])(1.) + print("debug jaxpr: ", str(jaxpr)) self.assertIn(' sin ', str(jaxpr)) self.assertIn(' cos ', str(jaxpr)) @@ -6482,7 +6778,8 @@ def test_const(self): def fun(x): return (x, 1., np.zeros(1, dtype=jnp.float32)) - expected = "{ lambda a:f32[1]; b:f32[]. let in (b, 1.0, a) }" + dtype = "f64" if config.enable_x64.value else "f32" + expected = f"{{ lambda a:f32[1]; b:f32[]. let in (b, 1.0:{dtype}[], a) }}" jaxpr = api.make_jaxpr(fun)(jnp.float32(0.)) self.assertMultiLineStrippedEqual(expected, str(jaxpr)) @@ -6494,9 +6791,9 @@ def f(x): x + 2., lambda xf: xf - x) expected = """{ lambda ; a:f32[]. 
let - b:bool[] = ge a 0.0 - c:f32[] = add a 1.0 - d:f32[] = add a 2.0 + b:bool[] = ge a 0.0:f32[] + c:f32[] = add a 1.0:f32[] + d:f32[] = add a 2.0:f32[] e:i32[] = convert_element_type[new_dtype=int32 weak_type=False] b f:f32[] = cond[ branches=( @@ -6678,13 +6975,13 @@ def body(c, _): self.assert_dce_result( jaxpr, used_outputs=used_outputs, expected_used_inputs=expected_used_inputs, - expected_num_eqns=1) # 1 b/c scan doesn't have fwding rule + expected_num_eqns=0) used_outputs[7] = expected_used_inputs[7] = True used_outputs[6] = expected_used_inputs[6] = True self.assert_dce_result( jaxpr, used_outputs=used_outputs, expected_used_inputs=expected_used_inputs, - expected_num_eqns=1) + expected_num_eqns=0) # If we use the value at index 3 only, some of the hidden sequence must be # kept but the rest pruned. @@ -6864,4506 +7161,60 @@ def f(x1, x2): self.assert_dce_result(jaxpr, [True, False], [True, True], 5) -class CustomJVPTest(jtu.JaxTestCase): +class BufferDonationTest(jtu.BufferDonationTestCase): - def test_basic(self): - @jax.custom_jvp - def f(x): - return jnp.sin(x) - def f_jvp(primals, tangents): - x, = primals - g, = tangents - return f(x), 2 * jnp.cos(x) * g - f.defjvp(f_jvp) + @jtu.device_supports_buffer_donation() + def test_pmap_donate_argnums_invalidates_input(self): + move = api.pmap(lambda x: x + x - x, donate_argnums=0) + n = jax.local_device_count() + x = api.pmap(lambda x: x)(jnp.ones([n])) + y = move(x) + self.assertDeleted(x) + np.testing.assert_allclose(y, [1.] * n) - x = 3. - self.assertAllClose(f(x), jnp.sin(x)) - self.assertAllClose(api.jvp(f, (x,), (1.,)), - (jnp.sin(x), 2 * jnp.cos(x))) - self.assertAllClose(api.grad(f)(x), 2 * jnp.cos(x)) + @jtu.device_supports_buffer_donation() + def test_pmap_nested_donate_ignored(self): + pmap_fun = jit(lambda x: api.pmap(lambda y: y ** 2, donate_argnums=0)(x)) + a = api.pmap(lambda x: x)(jnp.array([1])) - def test_invariance(self): - @jax.custom_jvp - def f(x): - return jnp.cos(2 * x) / 2. - def f_jvp(primals, tangents): - x, = primals - g, = tangents - return (f(x), 3 * g) - f.defjvp(f_jvp) - def f2(x): - y, _ = api.jvp(f, (x,), (x,)) - return y - def f3(x): - y, _ = api.jvp(f2, (x,), (x,)) - return y - x = 1. - self.assertAllClose(api.jvp(f, (x,), (x,)), - api.jvp(f2, (x,), (x,)), - check_dtypes=False) - self.assertAllClose(api.jvp(f, (x,), (x,)), - api.jvp(f3, (x,), (x,)), - check_dtypes=False) - - def test_python_control_flow(self): - @jax.custom_jvp - def f(x): - if x > 0: - return jnp.sin(x) - else: - return jnp.cos(x) - def f_jvp(primals, tangents): - x, = primals - g, = tangents - if x > 0: - return f(x), 2 * g - else: - return f(x), 3 * g - f.defjvp(f_jvp) - x = 2. - self.assertAllClose(f(x), jnp.sin(x)) - self.assertAllClose(f(-x), jnp.cos(-x)) - self.assertAllClose(api.jvp(f, (x,), (1.,)), - (jnp.sin(x), 2.), - check_dtypes=False) - self.assertAllClose(api.jvp(f, (-x,), (1.,)), - (jnp.cos(-x), 3.), - check_dtypes=False) - self.assertAllClose(api.grad(f)(x), 2., check_dtypes=False) - self.assertAllClose(api.grad(f)(-x), 3., check_dtypes=False) - - def test_vmap(self): - @jax.custom_jvp - def f(x): - assert jnp.ndim(x) == 0 - return jnp.sin(x) - def f_jvp(primals, tangents): - x, = primals - g, = tangents - assert jnp.ndim(x) == jnp.ndim(g) == 0 - return f(x), 2 * jnp.cos(x) * g - f.defjvp(f_jvp) + # NOTE(mattjj): stopped raising error here and instead just ignored + # with self.assertRaisesRegex(ValueError, "nested.*not supported"): + # pmap_fun(a) - x = jnp.arange(3.) 
- xx = jnp.arange(6.).reshape(2, 3) + pmap_fun(a) # doesn't crash - # vmap of f - self.assertAllClose(api.vmap(f)(x), jnp.sin(x)) - self.assertAllClose(api.vmap(api.vmap(f))(xx), jnp.sin(xx)) - # vmap of jvp of f - self.assertAllClose(api.vmap(lambda x: api.jvp(f, (x,), (x,)))(x), - (jnp.sin(x), 2 * jnp.cos(x) * x)) - self.assertAllClose(api.vmap(api.vmap(lambda x: api.jvp(f, (x,), (x,))))(xx), - (jnp.sin(xx), 2 * jnp.cos(xx) * xx)) +class NamedCallTest(jtu.JaxTestCase): - # jvp of vmap of f - self.assertAllClose(api.jvp(api.vmap(f), (x,), (x,)), - (jnp.sin(x), 2 * jnp.cos(x) * x)) - self.assertAllClose(api.jvp(api.vmap(api.vmap(f)), (xx,), (xx,)), - (jnp.sin(xx), 2 * jnp.cos(xx) * xx)) + def test_non_jaxtype_arg(self): + # For the test to fail without the invalid JaxType filter we need to pass + # in a valid JaxType that forces the invalid Jaxtype to be raised to an + # abstract value. + def f(not_a_jaxtype, a_jaxtype): + # then Jax needs to try and evaluate the abstractified non-JaxType + if not_a_jaxtype: + return a_jaxtype + return 0 - # vmap of jvp of vmap of f - self.assertAllClose(api.vmap(lambda x: api.jvp(api.vmap(f), (x,), (x,)))(xx), - (jnp.sin(xx), 2 * jnp.cos(xx) * xx)) + f = api.named_call(f, name="test") + out = jax.jit(f, static_argnums=(0,))("not a Jaxtype", 1) + self.assertEqual(out, 1) - def test_jit(self): - @jax.custom_jvp - def f(x): - return jnp.sin(x) - def f_jvp(primals, tangents): - x, = primals - g, = tangents - return f(x), 2 * jnp.cos(x) * g - f.defjvp(f_jvp) + @parameterized.parameters(jax.jit, jax.grad, jax.vmap, jax.remat) + def test_jax_transforms(self, transform): + f = jnp.sum + x = jnp.array([1.]) - x = 3. + unnamed_out = transform(f)(x) + named_out = transform(api.named_call(f, name="test"))(x) - # jit - self.assertAllClose(api.jit(f)(x), jnp.sin(x)) - self.assertAllClose(api.jit(api.jit(f))(x), jnp.sin(x)) + self.assertEqual(unnamed_out, named_out) - # jit of jvp - self.assertAllClose(api.jit(lambda x: api.jvp(f, (x,), (x,)))(x), - (jnp.sin(x), 2 * jnp.cos(x) * x), - check_dtypes=False) - - # jvp of jit - self.assertAllClose(api.jvp(api.jit(f), (x,), (x,)), - (jnp.sin(x), 2 * jnp.cos(x) * x), - check_dtypes=False) - - def test_pytrees(self): - @jax.custom_jvp - def f(x): - return {'b': jnp.sin(x['a'])} - def f_jvp(primals, tangents): - x, = primals - g, = tangents - return f(x), {'b': 2 * jnp.cos(x['a']) * g['a']} - f.defjvp(f_jvp) - x = {'a': 3.} - self.assertAllClose(f(x)['b'], jnp.sin(x['a'])) - self.assertAllClose(api.jvp(f, (x,), (x,)), - ({'b': jnp.sin(x['a'])}, - {'b': 2 * jnp.cos(x['a']) * x['a']}), - check_dtypes=False) - - def test_kwargs(self): - # from https://github.com/jax-ml/jax/issues/1938 - @jax.custom_jvp - def my_fun(x, y, c=1.): - return c * (x + y) - def my_jvp(primals, tangents): - x, y, c = primals - t_x, t_y, t_c = tangents - return my_fun(x, y, c), t_c - my_fun.defjvp(my_jvp) - f = lambda x, y: jnp.square(my_fun(x, y, c=2.)).sum() - f(10., 5.) # doesn't crash - api.jvp(f, (10., 5.), (1., 1.)) # doesn't crash - - def test_initial_style(self): - @jax.custom_jvp - def f(x): - return 3 * x - def f_jvp(primals, tangents): - x, = primals - g, = tangents - return f(x), 2 * g - f.defjvp(f_jvp) - - def foo(x): - out, _ = lax.scan(lambda c, _: (f(c), None), x, None, length=1) - return out - - ans = api.grad(foo)(3.) - expected = 2. - self.assertAllClose(ans, expected, check_dtypes=False) - - ans = api.grad(api.jit(foo))(3.) - expected = 2. - self.assertAllClose(ans, expected, check_dtypes=False) - - ans = api.jit(api.grad(foo))(3.) 
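# A hedged sketch (not one of the patch hunks) of the bookkeeping behind the
# expected_num_eqns changes earlier in this patch: pe.dce_jaxpr prunes
# equations whose outputs are not needed, so a scan whose results go unused
# leaves no equations behind.
import jax
from jax.interpreters import partial_eval as pe

def f(x):
  def body(c, _):
    return c + 1.0, None
  final, _ = jax.lax.scan(body, x, None, length=3)
  return x, final  # the second output is the only thing the scan produces

jaxpr = jax.make_jaxpr(f)(1.0).jaxpr
pruned, used_inputs = pe.dce_jaxpr(jaxpr, [True, False])  # keep only the first output
print(len(pruned.eqns), used_inputs)  # 0 equations remain; only x is still used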
- expected = 2. - self.assertAllClose(ans, expected, check_dtypes=False) - - ans = api.grad(api.grad(foo))(3.) - expected = 0. - self.assertAllClose(ans, expected, check_dtypes=False) - - ans = api.grad(api.grad(api.jit(foo)))(3.) - expected = 0. - self.assertAllClose(ans, expected, check_dtypes=False) - - ans = api.grad(api.jit(api.grad(foo)))(3.) - expected = 0. - self.assertAllClose(ans, expected, check_dtypes=False) - - ans = api.jit(api.grad(api.grad(foo)))(3.) - expected = 0. - self.assertAllClose(ans, expected, check_dtypes=False) - - def test_initial_style_vmap(self): - @jax.custom_jvp - def f(x): - assert jnp.ndim(x) == 0 - return 3 * x - def f_jvp(primals, tangents): - x, = primals - g, = tangents - return f(x), 2 * g - f.defjvp(f_jvp) - - def foo(x): - out, _ = lax.scan(lambda c, _: (f(c), None), x, None, length=1) - return out - - ans = api.vmap(foo)(jnp.ones(3)) - expected = 3. * jnp.ones(3) - self.assertAllClose(ans, expected, check_dtypes=False) - - ans = api.vmap(api.jit(foo))(jnp.ones(3)) - expected = 3. * jnp.ones(3) - self.assertAllClose(ans, expected, check_dtypes=False) - - ans = api.jit(api.vmap(foo))(jnp.ones(3)) - expected = 3. * jnp.ones(3) - self.assertAllClose(ans, expected, check_dtypes=False) - - ans = api.grad(lambda x: api.vmap(foo)(x).sum())(jnp.ones(3)) - expected = 2. * jnp.ones(3) - self.assertAllClose(ans, expected, check_dtypes=False) - - ans = api.grad(lambda x: api.vmap(api.jit(foo))(x).sum())(jnp.ones(3)) - expected = 2. * jnp.ones(3) - self.assertAllClose(ans, expected, check_dtypes=False) - - ans = api.grad(lambda x: api.jit(api.vmap(foo))(x).sum())(jnp.ones(3)) - expected = 2. * jnp.ones(3) - self.assertAllClose(ans, expected, check_dtypes=False) - - ans = api.grad(api.jit(lambda x: api.vmap(foo)(x).sum()))(jnp.ones(3)) - expected = 2. * jnp.ones(3) - self.assertAllClose(ans, expected, check_dtypes=False) - - ans = api.jit(api.grad(lambda x: api.vmap(foo)(x).sum()))(jnp.ones(3)) - expected = 2. * jnp.ones(3) - self.assertAllClose(ans, expected, check_dtypes=False) - - def test_initial_style_vmap_with_collective(self): - - @jax.custom_jvp - def f(x): - return lax.psum(x, 'foo') - - @f.defjvp - def f_jvp(xs, ts): - x, = xs - t, = ts - return lax.psum(x, 'foo'), t - - def g(x): - jaxpr = api.make_jaxpr(f)(x) - return core.eval_jaxpr(jaxpr.jaxpr, [], x)[0] - - v = api.vmap(lambda _, x: g(x), axis_name='foo', in_axes=(0, None), - out_axes=None)(jnp.arange(4.), 2.) - self.assertAllClose(v, 8.) - - def test_closed_over_tracers_error_message(self): - def f(x): - @jax.custom_jvp - def g(y): - return x + y - def g_jvp(primals, tangents): - return g(x), 2 * primals[0] - g.defjvp(g_jvp) - return g(1.) - - self.assertRaises(UnexpectedTracerError, lambda: api.jvp(f, (3.,), (1.,))) - self.assertRaises(UnexpectedTracerError, lambda: api.grad(f)(3.)) - - def test_nondiff_arg(self): - @partial(jax.custom_jvp, nondiff_argnums=(0,)) - def app(f, x): - return f(x) - def app_jvp(f, primals, tangents): - (x,), (t,) = primals, tangents - return app(f, x), 3 * t - app.defjvp(app_jvp) - - ans = app(lambda x: 2 * x, 1) - expected = 2 - self.assertAllClose(ans, expected, check_dtypes=False) - - ans = api.jvp(lambda x: app(lambda y: 2 * y, x), (1.,), (1.,)) - expected = (2., 3.) - self.assertAllClose(ans, expected, check_dtypes=False) - - def test_nondiff_arg_jit_tracer(self): - # This test would pass with "final-style" JIT tracing, but that was - # misleading: it doesn't work with "initial-style" staging, i.e. control - # flow primitives like jax.lax.scan or even pjit. 
The behavior isn't very - # useful either: instead of using nondiff_argnums here, a user can just pass - # such inputs as ordinary arguments, and ignore the corresponding tangents. - # Then nondiff_argnums can be reserved for (1) non jaxtype data (like a - # string- or callable-valued argument which parameterizes the function or - # rule) or (2) static data (e.g. integers which parameterize shapes). - raise unittest.SkipTest("behavior no longer supported") - - @partial(jax.custom_jvp, nondiff_argnums=(0,)) - def f(x, y): - return x * y - def f_jvp(x, primals, tangents): - (y,), (t_y,) = primals, tangents - return f(x, y), 5 * t_y - f.defjvp(f_jvp) - - @jit - def g(x, y): - return f(x, y) - - ans = api.jvp(lambda y: g(2., y), (3.,), (1.,)) - expected = (6., 5.) - self.assertAllClose(ans, expected, check_dtypes=False) - - def test_nondiff_arg_vmap_tracer(self): - @partial(jax.custom_jvp, nondiff_argnums=(0,)) - def f(x, y): - return x * y - def f_jvp(x, primals, tangents): - (y,), (t_y,) = primals, tangents - return f(x, y), 5 * t_y - f.defjvp(f_jvp) - - g = jax.vmap(f) - - ans = api.jvp(lambda y: g(jnp.array([2.]), y), - (jnp.array([3.]),), (jnp.array([1.]),)) - expected = (jnp.array([6.]), jnp.array([5.])) - self.assertAllClose(ans, expected, check_dtypes=False) - - def test_nondiff_arg_hiding_jvp_tracer(self): - def f(x): - @partial(jax.custom_jvp, nondiff_argnums=(0,)) - def g(h, x): - return h(x) - @g.defjvp - def g_jvp(h, primals, tangents): - x, = primals - t, = tangents - return g(h, x), 2. * t - h = lambda y: x + y # capture x - return g(h, x) - - with self.assertRaises(UnexpectedTracerError): - api.jvp(f, (2.,), (1.,)) - - def test_vmap_axes(self): - raise unittest.SkipTest("TODO") # TODO(mattjj): write test - - def test_pmap(self): - raise unittest.SkipTest("TODO") # TODO(mattjj): write test - - def test_missing_jvp_rule_error_message(self): - @jax.custom_jvp - def foo(x): - return x ** 2 - - self.assertRaisesRegex( - AttributeError, - r"No JVP defined for custom_jvp function foo using defjvp.", - lambda: foo(2)) - self.assertRaisesRegex( - AttributeError, - r"No JVP defined for custom_jvp function foo using defjvp.", - lambda: api.jvp(foo, (2.,), (1.,))) - self.assertRaisesRegex( - AttributeError, - r"No JVP defined for custom_jvp function foo using defjvp.", - lambda: api.grad(foo)(2.)) - - def test_jvp_rule_inconsistent_pytree_structures_error_message(self): - @jax.custom_jvp - def f(x): - return (x**2,) - - @f.defjvp - def foo_jvp(primals, tangents): - x, = primals - t, = tangents - return f(x), [2 * x * t, x] - - f(2.) # doesn't crash - self.assertRaisesRegex( - TypeError, - re.escape( - "Custom JVP rule foo_jvp for function f " - "must produce primal and tangent outputs " - "with equal container (pytree) structures, but got " - "{} and {} respectively.".format( - jax.tree.structure((1,)), - jax.tree.structure([1, 2])) - ), - lambda: api.jvp(f, (2.,), (1.,))) - - def test_primal_tangent_aval_disagreement_error_message(self): - @jax.custom_jvp - def f(x): - return x ** 2 - - @f.defjvp - def foo_jvp(primals, tangents): - x, = primals - t, = tangents - return f(x), jnp.reshape(t, (1,)) - - f(2.) # doesn't crash - self.assertRaisesRegex( - TypeError, - re.escape( - "Custom JVP rule must produce primal and tangent outputs " - "with corresponding shapes and dtypes. 
" - "Expected float32[] (tangent type of float32[]) but got float32[1]."), - lambda: api.jvp(f, (jnp.float32(2.),), (jnp.float32(1.),))) - - - def test_jvp_rule_doesnt_return_pair_error_message(self): - # https://github.com/jax-ml/jax/issues/2516 - - @jax.custom_jvp - def f(x): - return x ** 2 - - @f.defjvp - def foo_jvp(primals, tangents): - x, = primals - t, = tangents - return t - - f(2.) # doesn't crash - self.assertRaisesRegex( - TypeError, - re.escape( - "Custom JVP rule foo_jvp for function f " - "must produce a pair (list or tuple of length two) " - "representing primal and tangent outputs, but got 1.0"), - lambda: api.jvp(f, (2.,), (1.,))) - - def test_jvp_rule_primal_out_type_doesnt_match_primal_error_message(self): - # https://github.com/lucidrains/flash-attention-jax/issues/7 - - def scan_apply(f, x): - y, _ = jax.lax.scan(lambda x, _: (f(x), None), x, None, length=1) - return y - - @jax.custom_jvp - def f(x): - return x - - @f.defjvp - def f_jvp(primals, tangents): - (x,), (xdot,) = primals, tangents - return (x, x), (xdot, xdot) - - x = jnp.float32(1.) - self.assertRaisesRegex( - TypeError, - re.escape( - "Custom JVP rule f_jvp for function f must produce a pair " - "(list or tuple of length two) where the first element represents " - "the primal output (equal in value to the output of the " - "custom_jvp-decorated function f, and in particular of the " - "same container/pytree structure), but instead the JVP rule " - "output's first element had container/pytree structure:\n" - " (float32[], float32[])\n" - "while the custom_jvp-decorated function f had output " - "container/pytree structure:\n" - " float32[]." - ), - lambda: jax.jvp(lambda x: scan_apply(f, x), (x,), (x,))) - - @f.defjvp - def f_jvp2(primals, tangents): - (x,), (xdot,) = primals, tangents - return jnp.zeros((3, *x.shape), x.dtype), xdot - - self.assertRaisesRegex( - TypeError, - re.escape( - "Custom JVP rule f_jvp2 for function f must produce a pair " - "(list or tuple of length two) where the first element represents " - "the primal output (equal in value to the output of the " - "custom_jvp-decorated function f, and in particular " - "with leaves of the same shape/dtype), but instead the JVP rule " - "output's first element had shapes/dtypes of:\n" - " float32[3]\n" - "while the custom_jvp-decorated function f had output shapes/dtypes" - " of:\n" - " float32[]" - ), - lambda: jax.jvp(lambda x: scan_apply(f, x), (x,), (x,))) - - def test_multiple_rule_invocations(self): - @jax.custom_jvp - def expit(x): - return 1 / (1 + lax.exp(-x)) - - @expit.defjvp - def _expit_jvp(primals, tangents): - (x,), (t,) = primals, tangents - ans = expit(x) - t_out = t * ans * (1 - ans) - return ans, t_out - - def scanned_fun(c, _): - return [expit(c[0])] + [c[i-1] + c[i] for i in range(1, len(c))], None - - def foo(x): - zero = jnp.zeros_like(x) - c, _ = lax.scan(scanned_fun, [x, zero, zero, zero, zero], None, length=10) - return c[-1] - - # just make sure these don't crash - foo(3.) - grad(foo)(3.) 
- grad(lambda x: jax.vmap(foo)(x).sum())(jnp.arange(3.)) - - def test_hard_stuff(self): - arr = jnp.ones((5, 2, 2)) - api.jit(jax.vmap(jnp.linalg.det))(arr) # doesn't crash - - def test_hard_stuff2(self): - @jax.custom_jvp - def f(x): - return np.zeros(x.shape, x.dtype) - - @f.defjvp - def f_jvp(primals, tangents): - x, = primals - t, = tangents - return f(x), t - - # don't crash - jax.jit(jax.vmap(f))(jnp.arange(3.)) - jax.jit(jax.vmap(jax.grad(f)))(jnp.arange(3.)) - jax.jit(jax.grad(lambda x: jax.vmap(f)(x).sum()))(jnp.arange(3.)) - jax.grad(lambda x: jax.vmap(f)(x).sum())(jnp.arange(3.)) - jax.jvp(jax.vmap(f), (jnp.arange(3.),), (jnp.ones(3),)) - - def test_hard_stuff3(self): - @jax.custom_jvp - def relu(x): - return jnp.maximum(x, 0) - - @relu.defjvp - def _relu_jvp(primals, tangents): - x, = primals - t, = tangents - return relu(x), lax.select(x > 0, t, lax.full_like(t, 0)) - - def scanned_fun(c, _): - return [relu(c[0])] + [c[i-1] + c[i] for i in range(1, len(c))], None - - def f(x): - zero = jnp.zeros_like(x) - c, _ = lax.scan(scanned_fun, [x, zero, zero, zero, zero], None, length=10) - return c[-1] - - # don't crash - jax.jit(jax.vmap(f))(jnp.arange(3.)) - jax.jit(jax.vmap(jax.grad(f)))(jnp.arange(3.)) - jax.jit(jax.grad(lambda x: jax.vmap(f)(x).sum()))(jnp.arange(3.)) - jax.grad(lambda x: jax.vmap(f)(x).sum())(jnp.arange(3.)) - jax.jvp(jax.jit(jax.vmap(f)), (jnp.arange(3.),), (jnp.ones(3),)) - - def test_eval_shape(self): - @jax.custom_jvp - def expit(x): - return 1 / (1 + lax.exp(-x)) - - @expit.defjvp - def _expit_jvp(primals, tangents): - (x,), (t,) = primals, tangents - ans = expit(x) - t_out = t * ans * (1 - ans) - return ans, t_out - - # don't crash - api.eval_shape(expit, jnp.ones((2, 3))) - api.eval_shape(api.grad(lambda x: expit(x).sum()), jnp.ones((2, 3))) - - def test_jaxpr_zeros(self): - # from https://github.com/jax-ml/jax/issues/2657 - @jax.custom_jvp - def f(A, b): - return A @ b - - def f_jvp(primals, tangents): - A, b = primals - dA, db = tangents - z = f(A, b) - dz = A @ db + dA @ b - return z, dz - - f.defjvp(f_jvp) - - def experiment(theta): - def step(q, _): - z = f(jnp.eye(3), jnp.ones(3) * theta) - q += z[0] - return q, q - - q = 0. - q, _ = lax.scan(step, q, None, 4) - return q - - grad(experiment)(1.) # doesn't crash - - def test_linear_in_scan(self): - @jax.custom_jvp - def f(x): - return -x - - @f.defjvp - def f_jvp(primals, tangents): - x, = primals - x_dot, = tangents - return f(x), f(x_dot) - - def foo(x): - out, _ = lax.scan(lambda c, _: (f(c), None), x, None, length=1) - return out - - ans = api.grad(foo)(3.) - expected = -1. - self.assertAllClose(ans, expected, check_dtypes=False) - - def test_custom_jvps_first_rule_is_none(self): - # https://github.com/jax-ml/jax/issues/3389 - @jax.custom_jvp - def f(x, y): - return x ** 2 * y - - f.defjvps(None, lambda x_dot, primal_out, x, y: 2 * x * y * x_dot) - ans = grad(f, 1)(2., 3.) # doesn't crash - expected = 12. 
- self.assertAllClose(ans, expected, check_dtypes=False) - - def test_concurrent_initial_style(self): - # https://github.com/jax-ml/jax/issues/3843 - def unroll(param, sequence): - def scan_f(prev_state, inputs): - return prev_state, jax.nn.sigmoid(param * inputs) - return jnp.sum(jax.lax.scan(scan_f, None, sequence)[1]) - - def run(): - return jax.grad(unroll)(jnp.array(1.0), jnp.array([1.0])) - - expected = run() - - # we just don't want this to crash - n_workers = 2 - with concurrent.futures.ThreadPoolExecutor(max_workers=n_workers) as e: - futures = [] - for _ in range(n_workers): - futures.append(e.submit(run)) - results = [f.result() for f in futures] - for ans in results: - self.assertAllClose(ans, expected) - - def test_nondiff_argnums_vmap_tracer(self): - # https://github.com/jax-ml/jax/issues/3964 - @partial(jax.custom_jvp, nondiff_argnums=(0, 2)) - def sample(shape, param, seed): - return jax.random.uniform(key=seed, shape=shape, minval=param) - - @sample.defjvp - def sample_jvp(shape, seed, primals, tangents): - param, = primals - dparam, = tangents - dparam = jnp.broadcast_to(dparam, shape) - samples = sample(shape, param, seed) - return samples, samples * dparam # dummy jvp for proof of concept - - # check these don't crash - jax.vmap(lambda seed: sample((2,3), 1., seed))( - jax.random.split(jax.random.key(1), 10)) - jax.jvp(lambda x: sample((2, 3), x, jax.random.key(1)), - (1.,), (1.,)) - - def test_fun_with_nested_calls_2(self): - def call(f, *args): - f = jax.custom_jvp(f) - f.defjvp(lambda primals, tangents: (f(*primals), sum(tangents))) - return f(*args) - - def fun_with_nested_calls_2(x): - def bar(y): - def baz(w): - q = call(lambda x: y, x) - q = q + call(lambda: y) - q = q + call(lambda y: w + y, y) - q = call(lambda w: call(jnp.sin, x) * y, 1.0) + q - return q - return api.jit(baz)(x) - return call(bar, x) - - # test these don't crash - self.assertAllClose(api.jit(fun_with_nested_calls_2)(3.), - fun_with_nested_calls_2(3.)) - api.vmap(fun_with_nested_calls_2)(jnp.arange(3.)) - - def test_closure_with_vmap(self): - # https://github.com/jax-ml/jax/issues/3822 - alpha = np.float32(2.) 
- - def sample(seed): - @jax.custom_jvp - def f(alpha): - return jax.random.gamma(seed, alpha, shape=[]) - - @f.defjvp - def f_jvp(primal, tangent): - alpha = primal - dalpha = tangent - sample = f(alpha) - partial_alpha = lax.random_gamma_grad(alpha, sample) - return sample, partial_alpha * dalpha - return f(alpha) - - api.vmap(sample)(jax.random.split(jax.random.key(1), 3)) # don't crash - - def test_closure_with_vmap2(self): - # https://github.com/jax-ml/jax/issues/8783 - def h(z): - def f(x): - @jax.custom_jvp - def g(y): - return x * y - - # NOTE: rule closes over vmap tracer - @g.defjvp - def g_jvp(primals, tangents): - (y,), (ydot,) = primals, tangents - return x * y, x * ydot - - return g(z) # NOTE: no vmapped arg - - return jax.vmap(f)(jnp.arange(3., dtype='float32')) - - primals, tangents = jax.jvp(h, (jnp.float32(1.),), (jnp.float32(2.),)) - self.assertAllClose(primals , jnp.arange(3., dtype='float32')) - self.assertAllClose(tangents, 2 * jnp.arange(3., dtype='float32')) - - def test_float0(self): - scalar_float0 = jnp.zeros((), dtype=float0) - @jax.custom_jvp - def f(x, y): - return x, y - def f_jvp(primals, _): - x, y = primals - return (x, y), (2., custom_derivatives_public.zero_from_primal(y)) - f.defjvp(f_jvp) - - primals = (2., 3) - tangents = (np.ones(()), scalar_float0) - expected_tangents = (2., scalar_float0) - self.assertAllClose(api.jvp(f, primals, tangents), - (primals, expected_tangents)) - - def test_float0_initial_style(self): - scalar_float0 = jnp.zeros((), dtype=float0) - @jax.custom_jvp - def f(x, y): - return x, y - def f_jvp(primals, _): - x, y = primals - return (x, y), (2., custom_derivatives_public.zero_from_primal(y)) - f.defjvp(f_jvp) - - def foo(x, y): - out, _ = lax.scan(lambda c, _: (f(*c), None), (x, y), None, length=1) - return out - - primals = (2., 3) - tangents = (np.ones(()), scalar_float0) - expected_tangents = (2., scalar_float0) - - self.assertAllClose(api.jvp(foo, primals, tangents), - (primals, expected_tangents)) - - def test_remat(self): - @jax.custom_jvp - def f(x): - return jnp.sin(x) - def f_jvp(primals, tangents): - x, = primals - g, = tangents - return f(x), 2 * jnp.cos(x) * g - f.defjvp(f_jvp) - - @jax.remat - def g(x): - return f(f(x)) - - ans = g(2.) - expected = np.sin(np.sin(2.)) - self.assertAllClose(ans, expected, check_dtypes=False) - - ans = api.grad(g)(2.) - expected = 4. * api.grad(lambda x: jnp.sin(jnp.sin(x)))(2.) - self.assertAllClose(ans, expected, check_dtypes=False) - - def test_remat_higher_order(self): - @jax.custom_jvp - def f(x): - return jnp.sin(x) - def f_jvp(primals, tangents): - x, = primals - g, = tangents - return f(x), 2 * jnp.cos(x) * g - f.defjvp(f_jvp) - - def g(x): - return f(f(x)) - - ans = api.grad(api.grad(new_checkpoint(g)))(2.) - expected = api.grad(api.grad(g))(2.) - self.assertAllClose(ans, expected, check_dtypes=False) - - ans = api.grad(new_checkpoint(api.grad(g)))(2.) - expected = api.grad(api.grad(g))(2.) - self.assertAllClose(ans, expected, check_dtypes=False) - - ans = api.grad(api.grad(api.grad(new_checkpoint(g))))(2.) - expected = api.grad(api.grad(api.grad(g)))(2.) - self.assertAllClose(ans, expected, check_dtypes=False) - - def test_initial_style_vmap_2(self): - # This is like test_initial_style_vmap except the primal function closes - # over an array constant. - y = jnp.arange(1., 4.) 
- - @jax.custom_jvp - def f(x): - assert jnp.ndim(x) == 0 - return 3 * x * jnp.sum(y) - def f_jvp(primals, tangents): - x, = primals - g, = tangents - return f(x), 2 * g - f.defjvp(f_jvp) - - def foo(x): - out, _ = lax.scan(lambda c, _: (f(c), None), x, None, length=1) - return out - - ans = api.grad(lambda x: api.vmap(foo)(x).sum())(jnp.ones(3)) - expected = 2. * jnp.ones(3) - self.assertAllClose(ans, expected, check_dtypes=False) - - ans = api.grad(lambda x: api.vmap(api.jit(foo))(x).sum())(jnp.ones(3)) - expected = 2. * jnp.ones(3) - self.assertAllClose(ans, expected, check_dtypes=False) - - ans = api.grad(lambda x: api.jit(api.vmap(foo))(x).sum())(jnp.ones(3)) - expected = 2. * jnp.ones(3) - self.assertAllClose(ans, expected, check_dtypes=False) - - ans = api.grad(api.jit(lambda x: api.vmap(foo)(x).sum()))(jnp.ones(3)) - expected = 2. * jnp.ones(3) - self.assertAllClose(ans, expected, check_dtypes=False) - - ans = api.jit(api.grad(lambda x: api.vmap(foo)(x).sum()))(jnp.ones(3)) - expected = 2. * jnp.ones(3) - self.assertAllClose(ans, expected, check_dtypes=False) - - def test_custom_jvp_vmap_broadcasting_interaction(self): - # https://github.com/jax-ml/jax/issues/6452 - def f2(y, z): - v1 = z - v2 = jnp.sum(y) + z - return jnp.logaddexp(v1, v2) - - def f1(y, z): - v = api.vmap(lambda _y: f2(_y, z))(y) - return jnp.sum(v) - - y = jnp.ones((3, 2)) - f = lambda z: f1(y, z) - z = 0.1 - val, g = api.value_and_grad(f)(z) - self.assertEqual(val.shape, ()) - self.assertEqual(g.shape, ()) - - def test_custom_jvp_vmap_broadcasting_interaction_2(self): - # https://github.com/jax-ml/jax/issues/5849 - @jax.custom_jvp - def transform(box, R): - if jnp.isscalar(box) or box.size == 1: - return R * box - elif box.ndim == 2: - return jnp.einsum('ij,j->i', box, R) - raise ValueError() - - @transform.defjvp - def transform_jvp(primals, tangents): - box, R = primals - dbox, dR = tangents - return (transform(box, R), dR + transform(dbox, R)) - - def periodic_general(box): - def displacement_fn(Ra, Rb, **kwargs): - _box = kwargs.get('box', box) - return transform(_box, Ra - Rb) - - return displacement_fn - - N = 250 - - scalar_box = 1.0 - displacement = periodic_general(scalar_box) - - key = jax.random.key(0) - R = jax.random.uniform(key, (N, 2)) - - def energy_fn(box): - d = partial(displacement, box=box) - d = api.vmap(api.vmap(d, (None, 0)), (0, None)) - return jnp.sum(d(R, R) ** 2) - - self.assertEqual(grad(energy_fn)(scalar_box).shape, ()) - - def test_custom_jvp_implicit_broadcasting(self): - # https://github.com/jax-ml/jax/issues/6357 - if config.enable_x64.value: - raise unittest.SkipTest("test only applies when x64 is disabled") - - @jax.custom_jvp - def projection_unit_simplex(x: jax.Array) -> jax.Array: - """Projection onto the unit simplex.""" - s = 1.0 - n_features = x.shape[0] - u = jnp.sort(x)[::-1] - cssv = jnp.cumsum(u) - s - ind = jnp.arange(n_features, dtype=x.dtype) + 1 - cond = u - cssv / ind > 0 - idx = jnp.count_nonzero(cond) - threshold = cssv[idx - 1] / idx.astype(x.dtype) - return jax.nn.relu(x - threshold) - - - @projection_unit_simplex.defjvp - def projection_unit_simplex_jvp(primals, tangents): - x, = primals - x_dot, = tangents - primal_out = projection_unit_simplex(x) - supp = (primal_out > 0).astype(x_dot.dtype) - card = jnp.count_nonzero(supp).astype(x_dot.dtype) - tangent_out = supp * x_dot - (jnp.dot(supp, x_dot) / card) * supp - return primal_out, tangent_out - - rng = self.rng() - x = rng.rand(5).astype(np.float32) - - J_rev = jax.jacrev(projection_unit_simplex)(x) - 
J_fwd = jax.jacfwd(projection_unit_simplex)(x) - - p = projection_unit_simplex(x) - support = (p > 0).astype(jnp.float32) - cardinality = jnp.count_nonzero(support).astype(support.dtype) - J_true = jnp.diag(support) - jnp.outer(support, support) / cardinality - self.assertAllClose(J_true, J_fwd) - self.assertAllClose(J_true, J_rev) - - proj = jax.vmap(projection_unit_simplex) - - def fun(X): - return jnp.sum(proj(X) ** 2) - - rng = self.rng() - X = rng.rand(4, 5).astype(np.float32) - U = rng.rand(4, 5) - U /= np.sqrt(np.sum(U ** 2)) - U = U.astype(np.float32) - - eps = 1e-3 - dir_deriv_num = (fun(X + eps * U) - fun(X - eps * U)) / (2 * eps) - dir_deriv = jnp.vdot(jax.grad(fun)(X), U) - self.assertAllClose(dir_deriv, dir_deriv_num, atol=1e-3) - - def test_vmap_inside_defjvp(self): - # https://github.com/jax-ml/jax/issues/3201 - seed = 47 - key = jax.random.key(seed) - mat = jax.random.normal(key, (2, 3)) - - @jax.custom_jvp - def f(mat, aux): - num_rows, num_cols = mat.shape - return jnp.ones((num_rows, 1)) / num_cols - - @f.defjvp - def f_jvp(primals, tangents): - mat, aux = primals - vec, _ = tangents - output = f(*primals) - num_rows, num_cols = mat.shape - size = num_rows * num_cols - # ----- - bd_mat = mat.reshape(1, 1, num_rows, num_cols) - bd_mat = jnp.tile(bd_mat, reps=(num_rows, num_cols)) - bd_mat = bd_mat.reshape(size, num_rows, num_cols) - # ----- - rowsum = jnp.sum(mat, axis=1, keepdims=True) - colsum = jnp.sum(mat, axis=0, keepdims=True) - bd_rowsum = jnp.tile(rowsum, reps=(1, num_rows)) - bd_colsum = jnp.tile(colsum, reps=(num_cols, 1)) - # ----- - bd_vec = vec.reshape(size, 1) - # ----- - def operate(mx, val): - buf = 0 - for i in range(2): - buf = buf + jnp.matmul(mx, bd_colsum) / jnp.power(aux, i) - buf = jnp.matmul(bd_rowsum, buf) - return buf * val[None, :] - # ----- - # Vertorizing will raise shape error - bd_buf = jax.vmap(operate, in_axes=(0, 0), out_axes=0)(bd_mat, bd_vec) - # ----- - bd_buf = bd_buf / aux - jvp = jnp.sum(bd_buf, axis=0) - jvp = jnp.mean(jvp, axis=1, keepdims=True) - # ----- - # JVP ends successfully, but still raise an error - return (output, jvp) - - jax.grad(lambda mat, aux: jnp.sum(f(mat, aux)))(mat, 0.5) # doesn't crash - - def test_custom_jvp_unbroadcasting(self): - # https://github.com/jax-ml/jax/issues/3056 - a = jnp.array([1., 1.]) - - @jax.custom_jvp - def f(x): - return a * x - - @f.defjvp - def f_jvp(primals, tangents): - x, = primals - dx, = tangents - return a * x, a * dx - - shape = grad(lambda x: jnp.sum(f(x)))(jnp.array(1.)).shape - self.assertEqual(shape, ()) - - def test_maybe_perturbed_internal_helper_function(self): - # This is a unit test for an internal API. We include it so as not to - # regress https://github.com/jax-ml/jax/issues/9567. For an explanation of - # this helper function, see https://github.com/jax-ml/jax/issues/6415. 
- def f(x): - def g(y, _): - z = y * x - self.assertTrue(custom_derivatives._maybe_perturbed(z)) - return y, None - g(1, None) - return lax.scan(g, 1, xs=None, length=1)[0] - - jax.jvp(f, (1.0,), (1.0,)) # assertions inside f - - def test_maybe_perturbed_int_regression(self): - # see https://github.com/jax-ml/jax/discussions/9951 - - @jax.jit - def f(): - x = jnp.array(1) - _, aux_args = custom_derivatives.closure_convert(lambda: x) - self.assertEmpty(aux_args) - f() - - def test_sinc_constant_function_batching(self): - # https://github.com/jax-ml/jax/pull/10756 - batch_data = jnp.arange(15.).reshape(5, 3) - - @jax.vmap - def f(x): - return jax.lax.map(jnp.sinc, x) - g = lambda param: f(param * batch_data).sum() - - @jax.vmap - def f_ref(x): - return jnp.stack([jnp.sinc(x_) for x_ in x]) - g_ref = lambda param: f_ref(param * batch_data).sum() - - grad = jax.grad(g )(0.1) # doesn't crash - grad_ref = jax.grad(g_ref)(0.1) - self.assertAllClose(grad, grad_ref, check_dtypes=False) - - @parameterized.named_parameters( - ('jit_vmap', True, True), - ('jit', True, False), - ('vmap', False, True), - ('', False, False), - ) - def test_symbolic_zero_custom_jvp(self, maybe_jit, maybe_vmap): - def f(static_scalar, static_array, dyn_scalar, dyn_array): - out1 = static_scalar + dyn_scalar - out2 = static_array + dyn_array - return out1, out2 - - def _pack(x): - return lax.broadcast(x, (1,)) - - def _unpack(x): - (x,) = x - return x - - def _vmap(fun): - def _fun(*args): - args = jax.tree.map(_pack, args) - out = jax.vmap(fun)(*args) - out = jax.tree.map(_unpack, out) - return out - return _fun - - f = jax.custom_jvp(f) - - @partial(f.defjvp, symbolic_zeros=True) - def f_jvp(primals, tangents): - static_scalar, *_ = primals - t_static, t_static_arr, t_dyn_scalar, t_dyn_array = tangents - self.assertIs(type(t_static) , custom_derivatives_public.SymbolicZero) - self.assertIs(type(t_static_arr), custom_derivatives_public.SymbolicZero) - self.assertEqual(t_static.shape, ()) - self.assertEqual(t_static_arr.shape, (2,)) - return f(*primals), (static_scalar + 90, t_dyn_array + 91) - - def g(dyn_scalar, dyn_array): - if maybe_vmap: - f_ = _vmap(f) - else: - f_ = f - return f_(1., jnp.array([2., 3.]), dyn_scalar, dyn_array) - - def run(primal_ins, tangent_ins): - return jax.jvp(g, primal_ins, tangent_ins) - - if maybe_jit: - run = jax.jit(run) - - primal_ins = (4., jnp.array([5., 6.])) - tangent_ins = (7., jnp.array([8., 9.])) - primal_outs, tangent_outs = run(primal_ins, tangent_ins) - primal_out1, primal_out2 = primal_outs - tangent_out1, tangent_out2 = tangent_outs - scalar_type = jax.Array if maybe_jit or maybe_vmap else float - self.assertIsInstance(primal_out1, scalar_type) - self.assertAllClose(primal_out1, 5.) - self.assertIsInstance(tangent_out1, scalar_type) - self.assertAllClose(tangent_out1, 91.) 
- self.assertIsInstance(primal_out2, jax.Array) - self.assertArraysAllClose(primal_out2, jnp.array([7., 9.])) - self.assertIsInstance(tangent_out2, jax.Array) - self.assertArraysAllClose(tangent_out2, jnp.array([99., 100.])) - - def test_symbolic_zero_custom_jvp_vmap_output(self): - @jax.custom_jvp - def f(x, y): - return x * y - - @partial(f.defjvp, symbolic_zeros=True) - def f_jvp(primals, tangents): - x, y = primals - x_dot, y_dot = tangents - self.assertIs(type(y_dot), custom_derivatives_public.SymbolicZero) - return f(x, y), y_dot - - jax.grad(lambda x, y: jax.vmap(f)(x, y).sum())(jnp.ones(3), jnp.ones(3)) - - def test_symbolic_zeros_memoization_caching(self): - # Tests multiple zero patterns for partial_eval._memoize, and also tests - # that we're okay with stores being occupied with equal values. - - @jax.custom_jvp - def f(x, y): - return x * y - - @partial(f.defjvp, symbolic_zeros=True) - def f_jvp(primals, tangents): - x, y = primals - x_dot, y_dot = tangents - return f(x, y), y_dot - - f_ = core.jaxpr_as_fun(jax.make_jaxpr(f)(2., 3.)) - _ = jax.linearize(f_, 2., 3.) - _ = jax.linearize(lambda x: f_(x, 3.), 2.) # don't crash! - - def test_symbolic_zeros_under_jit(self): - # https://github.com/jax-ml/jax/issues/14833 - Zero = jax.custom_derivatives.SymbolicZero - - @jax.custom_jvp - def f(x, y): - return x * y - - @partial(f.defjvp, symbolic_zeros=True) - def fjvp(primals, tangents): - x, y = primals - tx, ty = tangents - assert type(tx) is not Zero or type(ty) is not Zero - return f(x, y), ( - ty if type(tx) is Zero else - tx if type(ty) is Zero else - tx + ty) - - jax.jacfwd(jax.jit(f))(0.1, 0.2) # don't crash - - def test_custom_jvp_functools_partial(self): - def fun(x, y, a): - return x + y * a - - fun_wrapped = functools.partial(fun, a = 0.1) - - def jvp_fn(primals, tangents): - return jax.jvp(fun_wrapped, primals, tangents) - - fn = jax.custom_jvp(fun_wrapped) - fn.defjvp(jvp_fn) - - self.assertEqual((1.0, 0.1), jax.grad(lambda args: fn(*args))((1.0, 2.0))) - - def test_run_rules_more_than_once(self): - # https://github.com/jax-ml/jax/issues/16614 - - @jax.custom_jvp - def f(x, y): - return x - - @partial(f.defjvp, symbolic_zeros=True) - def f_jvp(primals, tangents): - x, _ = primals - x_dot, _ = tangents - return x, x_dot - - def body(x_y, _): - x, y = x_y - return (f(x, y), x), None - - @jax.grad - def g(x): - (out, _), _ = lax.scan(body, (x, 1.), xs=None, length=2) - return out - - g(1.) 
# doesn't crash - - def test_dce(self): - @jax.custom_jvp - def f(x, y): - return jnp.sin(x), x + jnp.cos(y) - - @f.defjvp - def f_jvp(primals, tangents): - x, y = primals - dx, dy = tangents - return f(x, y), (2.0 * jnp.cos(x) * dx, 1.5 * dx - 0.5 * jnp.sin(y) * dy) - - def check_jaxpr(jaxpr, used_outs, includes, excludes): - dce_jaxpr, _ = pe.dce_jaxpr(jaxpr, used_outs) - if not dce_jaxpr.eqns: - assert not includes - return - call_jaxpr = dce_jaxpr.eqns[0].params["call_jaxpr"] - for prim in includes: - assert any(eqn.primitive == prim for eqn in call_jaxpr.eqns) - for prim in excludes: - assert all(eqn.primitive != prim for eqn in call_jaxpr.eqns) - - x, y = 0.1, -1.3 - jaxpr = jax.make_jaxpr(f)(x, y).jaxpr - check_jaxpr(jaxpr, [True, True], [lax.sin_p, lax.cos_p], []) - check_jaxpr(jaxpr, [True, False], [lax.sin_p], [lax.cos_p]) - check_jaxpr(jaxpr, [False, True], [lax.cos_p], [lax.sin_p]) - check_jaxpr(jaxpr, [False, False], [], [lax.sin_p, lax.cos_p]) - - def dce_jaxpr_as_fun(jaxpr, used_outs): - jaxpr_, _ = pe.dce_jaxpr(jaxpr, used_outs) - fun = core.jaxpr_as_fun(pe.close_jaxpr(jaxpr_)) - return lambda *args: fun(*args)[0] - - f0 = dce_jaxpr_as_fun(jaxpr, [True, False]) - f1 = dce_jaxpr_as_fun(jaxpr, [False, True]) - self.assertAllClose( - api.jvp(f0, (x, y), (1.0, 0.0)), (f0(x, y), 2.0 * jnp.cos(x))) - self.assertAllClose( - api.jvp(f0, (x, y), (0.0, 1.0)), (f0(x, y), 0.0)) - self.assertAllClose( - api.jvp(f1, (x, y), (1.0, 0.0)), (f1(x, y), 1.5)) - self.assertAllClose( - api.jvp(f1, (x, y), (0.0, 1.0)), (f1(x, y), -0.5 * jnp.sin(y))) - - def test_resolve_kwargs_error_message(self): - @jax.custom_jvp - def f(x, y, *, z=None): - return jnp.sin(x), x + jnp.cos(y) - - @f.defjvp - def f_jvp(primals, tangents): - self.fail("should not be executed") - - with self.assertRaisesRegex( - TypeError, - r"The input arguments to the custom_jvp-decorated function f(.*)\n" - r"missing a required argument: 'y'" - ): - f(0.5) - - with self.assertRaisesRegex( - TypeError, - r"The input arguments to the custom_jvp-decorated function f(.*)\n" - "The following keyword arguments could not be resolved to positions: z" - ): - f(0.5, 0.1, z=1.0) - - -class CustomVJPTest(jtu.JaxTestCase): - - def test_basic(self): - @jax.custom_vjp - def f(x): - return jnp.sin(x) - def f_fwd(x): - return f(x), jnp.cos(x) - def f_rev(cos_x, g): - return (2 * cos_x * g,) - f.defvjp(f_fwd, f_rev) - - x = 3. - self.assertAllClose(f(x), jnp.sin(x)) - self.assertAllClose(api.grad(f)(x), 2 * jnp.cos(x)) - self.assertAllClose(api.value_and_grad(f)(x), - (jnp.sin(x), 2 * jnp.cos(x))) - - def test_invariance(self): - @jax.custom_vjp - def f(x): - return jnp.cos(2 * x) / 2. - def f_fwd(x): - return (f(x), x) - def f_rev(x, g): - return (g * 3,) - f.defvjp(f_fwd, f_rev) - def f2(x): - y, _ = api.value_and_grad(f)(x) - return y - def f3(x): - y, _ = api.value_and_grad(f2)(x) - return y - x = 1. - self.assertAllClose(f(x), f2(x), check_dtypes=False) - self.assertAllClose(f(x), f3(x), check_dtypes=False) - self.assertAllClose(api.grad(f)(x), api.grad(f2)(x), - check_dtypes=False) - self.assertAllClose(api.grad(f)(x), api.grad(f3)(x), - check_dtypes=False) - - def test_python_control_flow(self): - @jax.custom_vjp - def f(x): - if x > 0: - return jnp.sin(x) - else: - return jnp.cos(x) - def f_fwd(x): - if x > 0: - return f(x), x - else: - return f(x), x - def f_rev(x, g): - if x > 0: - return (2 * g,) - else: - return (3 * g,) - f.defvjp(f_fwd, f_rev) - x = 2. 
- self.assertAllClose(f(x), jnp.sin(x)) - self.assertAllClose(f(-x), jnp.cos(-x)) - self.assertAllClose(api.value_and_grad(f)(x), (jnp.sin(x), 2.), - check_dtypes=False) - self.assertAllClose(api.value_and_grad(f)(-x), (jnp.cos(-x), 3.), - check_dtypes=False) - - def test_vmap(self): - @jax.custom_vjp - def f(x): - assert jnp.ndim(x) == 0 - return jnp.sin(x) - def f_fwd(x): - assert jnp.ndim(x) == 0 - return f(x), jnp.cos(x) - def f_rev(cos_x, g): - return (2 * cos_x * g,) - f.defvjp(f_fwd, f_rev) - - x = jnp.arange(3.) - xx = jnp.arange(6.).reshape(2, 3) - - # vmap of f - self.assertAllClose(api.vmap(f)(x), jnp.sin(x)) - self.assertAllClose(api.vmap(api.vmap(f))(xx), jnp.sin(xx)) - - # vmap of grad of f - self.assertAllClose(api.vmap(api.grad(f))(x), 2 * jnp.cos(x)) - self.assertAllClose(api.vmap(api.value_and_grad(f))(x), - (jnp.sin(x), 2 * jnp.cos(x))) - self.assertAllClose(api.vmap(api.vmap(api.grad(f)))(xx), 2 * jnp.cos(xx)) - self.assertAllClose(api.vmap(api.vmap(api.value_and_grad(f)))(xx), - (jnp.sin(xx), 2 * jnp.cos(xx))) - - # grad of vmap of f - self.assertAllClose(api.grad(lambda x: api.vmap(f)(x).sum())(x), - 2 * jnp.cos(x)) - self.assertAllClose(api.grad(lambda x: api.vmap(api.vmap(f))(x).sum())(xx), - 2 * jnp.cos(xx)) - - # vmap of grad of vmap of f - self.assertAllClose(api.vmap(api.grad(lambda x: api.vmap(f)(x).sum()))(xx), - 2 * jnp.cos(xx)) - - def test_jit(self): - @jax.custom_vjp - def f(x): - return jnp.sin(x) - def f_fwd(x): - return f(x), jnp.cos(x) - def f_rev(cos_x, g): - return (2 * cos_x * g,) - f.defvjp(f_fwd, f_rev) - - x = 3. - - # jit - self.assertAllClose(api.jit(f)(x), jnp.sin(x)) - self.assertAllClose(api.jit(api.jit(f))(x), jnp.sin(x)) - - # jit of grad - self.assertAllClose(api.jit(api.grad(f))(x), 2 * jnp.cos(x), - check_dtypes=False) - - # grad of jit - self.assertAllClose(api.grad(api.jit(f))(x), 2 * jnp.cos(x), - check_dtypes=False) - - def test_pytrees(self): - @jax.custom_vjp - def f(x): - return {'b': jnp.sin(x['a'])} - def f_fwd(x): - return f(x), {'r': jnp.cos(x['a'])} - def f_bwd(res, g): - cos_x = res['r'] - return ({'a': 2 * cos_x * g['b']},) - f.defvjp(f_fwd, f_bwd) - x = {'a': 3.} - self.assertAllClose(f(x)['b'], jnp.sin(x['a'])) - self.assertAllClose(api.grad(lambda x: f(x)['b'])(x), - {'a': 2 * jnp.cos(x['a'])}) - - def test_jvp_error(self): - @jax.custom_vjp - def f(x): - return jnp.sin(x) - def f_fwd(x): - return f(x), jnp.cos(x) - def f_rev(cos_x, g): - return (2 * cos_x * g,) - f.defvjp(f_fwd, f_rev) - - self.assertRaisesRegex( - TypeError, - r"can't apply forward-mode autodiff \(jvp\) to a custom_vjp function.", - lambda: api.jvp(f, (3.,), (1.,))) - self.assertRaisesRegex( - TypeError, - r"can't apply forward-mode autodiff \(jvp\) to a custom_vjp function.", - lambda: api.jvp(api.vmap(f), (jnp.arange(3.),), (jnp.ones(3),))) - self.assertRaisesRegex( - TypeError, - r"can't apply forward-mode autodiff \(jvp\) to a custom_vjp function.", - lambda: api.jvp(jit(f), (3.,), (1.,))) - - def test_kwargs(self): - # from https://github.com/jax-ml/jax/issues/1938 - @jax.custom_vjp - def my_fun(x, y, c=1.): - return c * (x + y) - my_fun.defvjp(lambda x, y, c=1.: (my_fun(c, y, c), None), - lambda _, g: (g, g, g)) - f = lambda x, y: jnp.square(my_fun(x, y, c=2.)).sum() - f(10., 5.) # doesn't crash - api.grad(f)(10., 5.) 
# doesn't crash - - def test_initial_style(self): - @jax.custom_vjp - def f(x): - return jnp.sin(x) - def f_fwd(x): - return f(x), jnp.cos(x) - def f_rev(cos_x, g): - return (2 * cos_x * g,) - f.defvjp(f_fwd, f_rev) - - def foo(x): - out, _ = lax.scan(lambda c, _: (f(c), None), x, None, length=1) - return out - - ans = api.grad(foo)(3.) - expected = 2. * jnp.cos(3.) - self.assertAllClose(ans, expected, check_dtypes=False) - - ans = api.grad(api.grad(foo))(3.) - expected = -2. * jnp.sin(3.) - self.assertAllClose(ans, expected) - - def test_initial_style_vmap(self): - @jax.custom_vjp - def f(x): - assert jnp.ndim(x) == 0 - return 3 * x - def f_fwd(x): - return f(x), jnp.cos(x) - def f_rev(cos_x, g): - return (2 * cos_x * g,) - f.defvjp(f_fwd, f_rev) - - def foo(x): - out, _ = lax.scan(lambda c, _: (f(c), None), x, None, length=1) - return out - - ans = api.vmap(foo)(jnp.arange(3.)) - expected = 3. * jnp.arange(3.) - self.assertAllClose(ans, expected, check_dtypes=False) - - ans = api.grad(lambda x: api.vmap(foo)(x).sum())(jnp.arange(3.)) - expected = 2. * jnp.cos(jnp.arange(3.)) - self.assertAllClose(ans, expected, check_dtypes=False) - - def test_nondiff_arg(self): - @partial(jax.custom_vjp, nondiff_argnums=(0,)) - def app(f, x): - return f(x) - def app_fwd(f, x): - return app(f, x), jnp.cos(x) - def app_rev(f, cos_x, g): - return (cos_x * g,) - app.defvjp(app_fwd, app_rev) - - ans = app(lambda x: 2 * x, 1) - expected = 2 - self.assertAllClose(ans, expected, check_dtypes=False) - - ans = api.value_and_grad(lambda x: app(lambda y: 2 * y, x))(1.) - expected = (2., jnp.cos(1.)) - self.assertAllClose(ans, expected, check_dtypes=False) - - def test_closed_over_jit_tracer(self): - # See the comment in CustomJVPTest.test_nondiff_arg_jit_tracer. - raise unittest.SkipTest("behavior no longer supported") - - # This test is similar to test_nondiff_arg_tracer except it uses lexical - # closure rather than the nondiff_argnums mechanism. We decided to disallow - # tracers in nondiff_argnums to greatly simplify bookkeeping while still - # supporting the cases for which it is necessary. - def outer(x): - @jax.custom_vjp - def f(y): - return x * y - def f_fwd(y): - return f(y), jnp.cos(y) - def f_rev(cos_y, g): - return (cos_y * g,) - f.defvjp(f_fwd, f_rev) - return f - - @jit - def g(x, y): - return outer(x)(y) - - ans = g(2, 3.) - expected = 6. - self.assertAllClose(ans, expected, check_dtypes=False) - - ans = api.grad(g, 1)(2., 3.) - expected = jnp.cos(3.) - self.assertAllClose(ans, expected, check_dtypes=False) - - def test_closed_over_vmap_tracer(self): - def outer(x): - @jax.custom_vjp - def f(y): - return x * y - def f_fwd(y): - return f(y), jnp.cos(y) - def f_rev(cos_y, g): - return (cos_y * g,) - f.defvjp(f_fwd, f_rev) - return f - - @api.vmap - def g(x): - return outer(x)(3.) - - ans = g(np.arange(3.)) - expected = np.arange(3.) * 3 - self.assertAllClose(ans, expected, check_dtypes=False) - - def test_closed_over_tracer3(self): - def outer(x): - @jax.custom_vjp - def f(y): - return x * y - def f_fwd(y): - return f(y), (x, jnp.cos(y)) - def f_rev(res, g): - x, cos_y = res - return (cos_y * g * x,) - f.defvjp(f_fwd, f_rev) - return api.grad(f) - - @api.vmap - def g(x): - return outer(x)(3.) - - ans = g(np.arange(3.)) - expected = np.cos(3.) * np.arange(3.) 
- self.assertAllClose(ans, expected, check_dtypes=False) - - def test_nondiff_arg_tracer_error(self): - # This is similar to the old (now skipped) test_nondiff_arg_tracer, except - # we're testing for the error message that usage pattern now raises. - - @partial(jax.custom_vjp, nondiff_argnums=(0,)) - def f(x, y): - return x * y - def f_fwd(x, y): - return f(x, y), jnp.cos(y) - def f_rev(x, cos_y, g): - return (cos_y * g,) - f.defvjp(f_fwd, f_rev) - - @jit - def g(x, y): - return f(x, y) - - with self.assertRaisesRegex(UnexpectedTracerError, "custom_vjp"): - _ = g(2, 3.) - with self.assertRaisesRegex(UnexpectedTracerError, "custom_vjp"): - _ = api.grad(g, 1)(2., 3.) - - def test_vmap_axes(self): - raise unittest.SkipTest("TODO") # TODO(mattjj): write test - - def test_pmap(self): - raise unittest.SkipTest("TODO") # TODO(mattjj): write test - - def test_missing_vjp_rule_error(self): - @jax.custom_vjp - def foo(x): - return x ** 2 - - self.assertRaisesRegex( - AttributeError, - r"No VJP defined for custom_vjp function foo using defvjp.", - lambda: foo(2)) - self.assertRaisesRegex( - AttributeError, - r"No VJP defined for custom_vjp function foo using defvjp.", - lambda: api.grad(foo)(2.)) - - def test_vjp_rule_inconsistent_pytree_structures_error(self): - @jax.custom_vjp - def f(x): - return x - - def foo_fwd(x): - return x, None - - def foo_bwd(_, g): - return (g, g) - - f.defvjp(foo_fwd, foo_bwd) - - f(2) # doesn't crash - self.assertRaisesRegex( - TypeError, - re.escape( - "Custom VJP bwd rule must produce an output with the same container " - "(pytree) structure as the args tuple of the primal function, " - "and in particular must produce a tuple of length equal to the " - "number of arguments to the primal function, but got bwd output " - "structure {} for primal input structure {}.".format( - jax.tree.structure((1, 1)), - jax.tree.structure((1,))) - ), - lambda: api.grad(f)(2.)) - - def test_vjp_bwd_returns_non_tuple_error(self): - @jax.custom_vjp - def f(x): - return x - - def foo_fwd(x): - return x, None - - def foo_bwd(_, g): - return 2. * g # Should be a tuple - - f.defvjp(foo_fwd, foo_bwd) - with self.assertRaisesRegex(TypeError, "Custom VJP bwd rule .* must produce a tuple"): - api.grad(f)(3.) - - def test_fwd_rule_primal_out_type_doesnt_match_primal_error_message(self): - # https://github.com/lucidrains/flash-attention-jax/issues/7 - - def scan_apply(f, x): - y, _ = jax.lax.scan(lambda x, _: (f(x), None), x, None, length=1) - return y - - @jax.custom_vjp - def f(x): - return x - - def f_fwd(x): - return (x, x), None - - def f_bwd(_, y_bar): - return (y_bar,) - - f.defvjp(f_fwd, f_bwd) - - self.assertRaisesRegex( - TypeError, - re.escape( - "Custom VJP fwd rule f_fwd for function f must produce a pair " - "(list or tuple of length two) where the first element represents " - "the primal output (equal to the output of the " - "custom_vjp-decorated function f) and the second element " - "represents residuals (i.e. values stored from the forward " - "pass for use on the backward pass), but instead the fwd rule " - "output's first element had container/pytree structure:\n" - " (float32[], float32[])\n" - "while the custom_vjp-decorated function f had output " - "container/pytree structure:\n" - " float32[]." 
- ), - lambda: jax.grad(lambda x: scan_apply(f, x))(jnp.float32(1.))) - - def f_fwd2(x): - return jnp.zeros((3, *x.shape), x.dtype), None - - def f_bwd2(_, y_bar): - return (y_bar,) - - f.defvjp(f_fwd2, f_bwd2) - - self.assertRaisesRegex( - TypeError, - re.escape( - "Custom VJP fwd rule f_fwd2 for function f must produce a pair " - "(list or tuple of length two) where the first element represents " - "the primal output (equal to the output of the " - "custom_vjp-decorated function f) and the second element " - "represents residuals (i.e. values stored from the forward " - "pass for use on the backward pass), but instead the fwd rule " - "output's first element had shapes/dtypes of:\n" - " float32[3]\n" - "while the custom_vjp-decorated function f had output " - "shapes/dtypes of:\n" - " float32[]" - ), - lambda: jax.grad(lambda x: scan_apply(f, x))(jnp.float32(1.))) - - def test_issue2511(self): - arr = jnp.ones((5, 2, 2)) - foo = lambda x: api.vmap(jnp.linalg.det, (0,))(x) - api.jit(foo)(arr) # doesn't crash - - def test_lowering_out_of_traces(self): - # https://github.com/jax-ml/jax/issues/2578 - - class F(collections.namedtuple("F", ["a"])): - def __call__(self, x): - return jax.nn.relu(self.a) * x - - @jax.jit - def g(f, x): - return f(x) - - jax.grad(g, argnums=(1,))(F(2.0), 0.) # doesn't crash - - def test_clip_gradient(self): - # https://github.com/jax-ml/jax/issues/2784 - @jax.custom_vjp - def _clip_gradient(lo, hi, x): - return x # identity function when not differentiating - - def clip_gradient_fwd(lo, hi, x): - return x, (lo, hi,) - - def clip_gradient_bwd(res, g): - lo, hi = res - return (None, None, jnp.clip(g, lo, hi),) - - _clip_gradient.defvjp(clip_gradient_fwd, clip_gradient_bwd) - - def clip_gradient(x): - lo = -0.1 - hi = x + 0.1 - return _clip_gradient(lo, hi, x) - - g = jax.grad(clip_gradient)(0.1) # doesn't crash - self.assertAllClose(g, jnp.array(0.2)) - - def test_nestable_vjp(self): - # Verify that https://github.com/jax-ml/jax/issues/3667 is resolved. - def f(x): - return x ** 2 - - @jax.custom_vjp - def g(x): - return f(x) - - def g_fwd(x): - y, f_vjp = api.vjp(f, x) - return y, f_vjp - - def g_bwd(f_vjp, y_bar): - return f_vjp(y_bar) - - g.defvjp(g_fwd, g_bwd) - - # Check that VJP can be nested in simple situations. For this to pass, - # vjp has to return a PyTree. - _, g_vjp = api.vjp(g, 1.0) - y, = g_vjp(1.0) - self.assertAllClose(y, jnp.array(2.0)) - - # Check that VJP can be nested in complex situations. For this to pass, - # vjp can't treat the closed-over tracer x as a static argument. - @jit - def z(x): - _, g_vjp = api.vjp(g, x) - return g_vjp - y, = z(1.0)(3.0) - self.assertAllClose(y, jnp.array(6.0)) - - def test_initial_style_vmap_2(self): - # https://github.com/jax-ml/jax/issues/4173 - x = jnp.ones((10, 3)) - - # Create the custom function - @jax.custom_vjp - def custom_fun(x): - return x.sum() - - def forward(x): - return x.sum(), (jnp.ones_like(x),) - - def backward(res, g): - return g * res[0], - - custom_fun.defvjp(forward, backward) - - def train_fun(x): - - def summed_fun(x): - return api.vmap(custom_fun)(x).sum() - - return api.grad(summed_fun)(x) - - def scan_body(carry, inputs): - x = carry - return carry, train_fun(x) - - scan_range = jnp.arange(4) - lax.scan(scan_body, x, scan_range) # don't crash - - def test_initial_style_vmap_3(self): - # This is like test_initial_style_vmap except the primal function closes - # over an array constant. - y = jnp.arange(1., 4.) 
- - @jax.custom_vjp - def f(x): - assert jnp.ndim(x) == 0 - return 3 * x * jnp.sum(y) - def f_fwd(x): - return f(x), jnp.cos(x) - def f_rev(cos_x, g): - return (2 * cos_x * g,) - f.defvjp(f_fwd, f_rev) - - def foo(x): - out, _ = lax.scan(lambda c, _: (f(c), None), x, None, length=1) - return out - - ans = api.vmap(foo)(jnp.arange(3.)) - expected = 3. * jnp.arange(3.) * 6 - self.assertAllClose(ans, expected, check_dtypes=False) - - ans = api.grad(lambda x: api.vmap(foo)(x).sum())(jnp.arange(3.)) - expected = 2. * jnp.cos(jnp.arange(3.)) - self.assertAllClose(ans, expected, check_dtypes=False) - - def test_initial_style_vmap_with_collective(self): - - @jax.custom_vjp - def f(x): - return lax.psum(x, 'foo') - - def f_fwd(x): - return lax.psum(x, 'foo'), None - - def f_bwd(res, dx): - return dx - f.defvjp(f_fwd, f_bwd) - - def g(x): - jaxpr = api.make_jaxpr(f)(x) - return core.eval_jaxpr(jaxpr.jaxpr, [], x)[0] - - out = api.vmap(lambda _, x: g(x), axis_name='foo', in_axes=(0, None), - out_axes=None)(jnp.arange(4.), 2.) - self.assertAllClose(out, 8.) - - def test_bwd_closes_over_tracer(self): - def f(y): - @jax.custom_vjp - def f(x): - return 2. * jnp.sin(x) - - def fwd(x): - return f(x), () - - def bwd(_, g): - return (2. * jnp.cos(y) * g,) # capture! - - f.defvjp(fwd, bwd) - - return jax.grad(f)(1.) - - ans = jax.jit(f)(2.) - self.assertAllClose(ans, 2. * jnp.cos(2.)) - - ans = jax.vmap(f)(jnp.arange(3.)) - self.assertAllClose(ans, 2. * jnp.cos(jnp.arange(3.))) - - ans = jax.jit(jax.vmap(f))(jnp.arange(3.)) - self.assertAllClose(ans, 2. * jnp.cos(jnp.arange(3.))) - - ans = jax.vmap(jax.jit(f))(jnp.arange(3.)) - self.assertAllClose(ans, 2. * jnp.cos(jnp.arange(3.))) - - ans = jax.grad(f)(4.) - self.assertAllClose(ans, -2. * jnp.sin(4.)) - - def test_fwd_closes_over_tracer(self): - def f(y): - @jax.custom_vjp - def f(x): - return 2. * jnp.sin(x) - - def fwd(x): - return f(x), y - - def bwd(y, g): - return (2. * jnp.cos(y) * g,) # capture! - - f.defvjp(fwd, bwd) - - return jax.grad(f)(1.) - - ans = jax.jit(f)(2.) - self.assertAllClose(ans, 2. * jnp.cos(2.)) - - ans = jax.vmap(f)(jnp.arange(3.)) - self.assertAllClose(ans, 2. * jnp.cos(jnp.arange(3.))) - - ans = jax.jit(jax.vmap(f))(jnp.arange(3.)) - self.assertAllClose(ans, 2. * jnp.cos(jnp.arange(3.))) - - ans = jax.vmap(jax.jit(f))(jnp.arange(3.)) - self.assertAllClose(ans, 2. * jnp.cos(jnp.arange(3.))) - - ans = jax.grad(f)(4.) - self.assertAllClose(ans, -2. * jnp.sin(4.)) - - def test_float0(self): - @jax.custom_vjp - def f(x, _): - return x - def f_fwd(x, _): - # we need a defined (non-float0) tangent to trigger the rule - return x, (2., 1) - def f_rev(*_): - return (2., 1) - f.defvjp(f_fwd, f_rev) - - x = 2. - y = 3 - self.assertEqual(api.grad(f, allow_int=True, argnums=(0, 1))(x, y), - (2., np.zeros(shape=(), dtype=float0))) - - def test_float0_initial_style(self): - @jax.custom_vjp - def f(x): - return x - def f_fwd(x): - return x, (2., x) - def f_rev(*_): - return ((2., jnp.zeros(shape=(), dtype=float0)),) - f.defvjp(f_fwd, f_rev) - - def foo(x, y): - out, _ = lax.scan(lambda c, _: (f(c), None), (x, y), None, length=1) - return out[0] - - x = 2. - y = 3 - self.assertEqual(api.grad(foo, allow_int=True, argnums=(0, 1))(x, y), - (2., np.zeros(shape=(), dtype=float0))) - - def test_remat(self): - @jax.custom_vjp - def f(x): - return jnp.sin(x) - def f_fwd(x): - return f(x), jnp.cos(x) - def f_rev(cos_x, g): - return (2 * cos_x * g,) - f.defvjp(f_fwd, f_rev) - - @jax.remat - def g(x): - return f(f(x)) - - ans = g(2.) 
- expected = np.sin(np.sin(2.)) - self.assertAllClose(ans, expected, check_dtypes=False) - - ans = api.grad(g)(2.) - expected = 4. * api.grad(lambda x: jnp.sin(jnp.sin(x)))(2.) - self.assertAllClose(ans, expected, check_dtypes=False) - - def test_remat_higher_order(self): - @jax.custom_vjp - def f(x): - return jnp.sin(x) - def f_fwd(x): - return f(x), jnp.cos(x) - def f_rev(cos_x, g): - return (2 * cos_x * g,) - f.defvjp(f_fwd, f_rev) - - def g(x): - return f(f(x)) - - ans = api.grad(api.grad(jax.remat(g)))(2.) - expected = api.grad(api.grad(g))(2.) - self.assertAllClose(ans, expected, check_dtypes=False) - - ans = api.grad(jax.remat(api.grad(g)))(2.) - expected = api.grad(api.grad(g))(2.) - self.assertAllClose(ans, expected, check_dtypes=False) - - ans = api.grad(api.grad(api.grad(jax.remat(g))))(2.) - expected = api.grad(api.grad(api.grad(g)))(2.) - self.assertAllClose(ans, expected, check_dtypes=False) - - def test_bwd_nones(self): - @jax.custom_vjp - def f(x, y): - return x * jnp.sin(y) - def f_fwd(x, y): - return f(x, y), jnp.cos(y) - def f_rev(cos, g): - return (None, 2 * cos * g) - f.defvjp(f_fwd, f_rev) - - ans = api.grad(lambda x: f(x, x))(3.) - expected = 2 * jnp.cos(3.) - self.assertAllClose(ans, expected, check_dtypes=False) - - def test_bwd_nones_vmap(self): - @jax.custom_vjp - def f(x, y): - return x * jnp.sin(y) - def f_fwd(x, y): - return f(x, y), jnp.cos(y) - def f_rev(cos, g): - return (None, 2 * cos * g) - f.defvjp(f_fwd, f_rev) - - ans = api.grad(lambda x: api.vmap(f)(x, x).sum())(jnp.arange(3.)) - expected = 2 * jnp.cos(jnp.arange(3.)) - self.assertAllClose(ans, expected, check_dtypes=False) - - def test_bwd_nones_pytree(self): - @jax.custom_vjp - def f(xs, y): - x1, x2 = xs - return x1 * x2 * jnp.sin(y) - def f_fwd(xs, y): - return f(xs, y), jnp.cos(y) - def f_rev(cos, g): - return (None, 2 * cos * g) - f.defvjp(f_fwd, f_rev) - - ans = api.grad(lambda x: f((x, x), x))(3.) - expected = 2 * jnp.cos(3.) - self.assertAllClose(ans, expected, check_dtypes=False) - - def test_custom_vjp_closure_4521(self): - # https://github.com/jax-ml/jax/issues/4521 - @jax.custom_vjp - def g(x, y): - return None - def g_fwd(x, y): - return None, y - def g_bwd(residuals, z_bar): - assert False - - g.defvjp(g_fwd, g_bwd) - - def f(xs, y): - v_g = api.vmap(g, in_axes=(0, None), out_axes=None) - v_g(xs, y) - - def scan_body(xs, _): - y = jnp.zeros(1) - _, vjp_f = api.vjp(f, xs, y) - vjp_f(None) - return xs, None - - lax.scan(scan_body, jnp.ones(5), None, 100) # doesn't crash - - def test_float0_bwd_none(self): - @jax.custom_vjp - def f(i, x): - return jnp.sin(x) - def f_fwd(i, x): - return f(i, x), jnp.cos(x) - def f_rev(cos_x, g): - return (None, 2 * cos_x * g) - f.defvjp(f_fwd, f_rev) - - ans = api.grad(f, 1)(jnp.array([1, 2]), 3.) # doesn't crash - expected = 2 * jnp.cos(3.) 
- self.assertAllClose(ans, expected, check_dtypes=False) - - def test_custom_gradient(self): - @jax.custom_gradient - def f(x): - return x ** 2, lambda g: (g * x,) - - self.assertAllClose(f(3.), 9., check_dtypes=False) - self.assertAllClose(api.grad(f)(3.), 3., check_dtypes=False) - self.assertAllClose(api.grad(api.grad(f))(3.), 1., check_dtypes=False) - - def test_custom_gradient_2(self): - @jax.custom_gradient - def f(x, y): - return x * y, lambda g: (y, x) - - self.assertAllClose(f(3., 4.), 12., check_dtypes=False) - self.assertAllClose(api.grad(f, argnums=(0, 1))(3., 4.), (4., 3.), - check_dtypes=False) - - def test_custom_gradient_3(self): - @jax.custom_gradient - def f(x): - vjp = lambda g: (jnp.cos(x) * jnp.arange(3., 6.),) - return jnp.sum(jnp.sin(x)), vjp - - self.assertAllClose(f(jnp.arange(3)), jnp.sum(jnp.sin(jnp.arange(3.))), - check_dtypes=False) - self.assertAllClose( - api.grad(f)(jnp.arange(3.)), - api.grad(lambda x: jnp.sum(jnp.sin(x)))(jnp.arange(3.)) * jnp.arange(3., 6.), - check_dtypes=False) - - def test_custom_gradient_can_return_singleton_value_in_vjp(self): - @jax.custom_gradient - def f(x): - return x ** 2, lambda g: g * x - - self.assertAllClose(f(3.), 9., check_dtypes=False) - self.assertAllClose(api.grad(f)(3.), 3., check_dtypes=False) - self.assertAllClose(api.grad(api.grad(f))(3.), 1., check_dtypes=False) - - def test_closure_convert(self): - def cos_after(fn, x): - converted_fn, aux_args = jax.closure_convert(fn, x) - self.assertLessEqual(len(aux_args), 1) - return _cos_after(converted_fn, x, *aux_args) - - @partial(jax.custom_vjp, nondiff_argnums=(0,)) - def _cos_after(fn, x, *args): - return jnp.cos(fn(x, *args)) - - def fwd(fn, x, *args): - y = _cos_after(fn, x, *args) - return y, (x, args) - - def rev(fn, res, g): - x, args = res - x_bar = 17. * x - args_bars = [42. * a for a in args] - return (x_bar, *args_bars) - - _cos_after.defvjp(fwd, rev) - - def dist(c, x): - return jnp.sum((x - c) ** 2.) - - def solve(c, x): - def closure(x): - return dist(c, x) - return cos_after(closure, x) - - c, x = 2. * jnp.ones(2), jnp.ones(2) - expected = jnp.cos(dist(c, x)) - self.assertAllClose(solve(c, x), expected, check_dtypes=False) - g_c, g_x = api.grad(solve, argnums=(0, 1))(c, x) - self.assertAllClose(g_c, 42. * c, check_dtypes=False) - self.assertAllClose(g_x, 17. * x, check_dtypes=False) - - def test_closure_convert_mixed_consts(self): - # Like test_closure_convert, but close over values that - # participate in AD as well as values that do not. - # See https://github.com/jax-ml/jax/issues/6415 - - def cos_after(fn, x): - converted_fn, aux_args = jax.closure_convert(fn, x) - self.assertLessEqual(len(aux_args), 1) - return _cos_after(converted_fn, x, *aux_args) - - @partial(jax.custom_vjp, nondiff_argnums=(0,)) - def _cos_after(fn, x, *args): - return jnp.cos(fn(x, *args)) - - def fwd(fn, x, *args): - y = _cos_after(fn, x, *args) - return y, (x, args) - - def rev(fn, res, g): - x, args = res - x_bar = 17. * x - args_bars = [42. * a for a in args] - return (x_bar, *args_bars) - - _cos_after.defvjp(fwd, rev) - - def dist(c, s, x): - return jnp.sum(s * (x - c) ** 2.) - - def solve(c, s, x): - def closure(x): - return dist(c, s, x) - return cos_after(closure, x) - - c, s, x = 2. * jnp.ones(2), 3. * jnp.ones(2), jnp.ones(2) - expected = jnp.cos(dist(c, s, x)) - self.assertAllClose(solve(c, s, x), expected, check_dtypes=False) - g_c, g_x = api.grad(solve, argnums=(0, 2))(c, s, x) - self.assertAllClose(g_c, 42. * c, check_dtypes=False) - self.assertAllClose(g_x, 17. 
* x, check_dtypes=False) - - def test_closure_convert_pytree_mismatch(self): - # See https://github.com/jax-ml/jax/issues/23588 - def f(x, z): - return z * x - - x, z = 2.0, 3.0 - _, vjp = api.vjp(f, x, z) - vjp_pure, vjp_aux_args = jax.closure_convert(vjp, x) - vjp_pure(x, *vjp_aux_args) - with self.assertRaisesRegex( - TypeError, "The inputs to the closure produced by closure_convert"): - vjp_pure(x, vjp_aux_args) - - def test_float0_cotangents_automatically_handled(self): - @jax.custom_vjp - def f(x, y): - return x - - def f_fwd(x, y): - return x, None - - def f_bwd(_, zbar): - return (0., 1) - - f.defvjp(f_fwd, f_bwd) - - jax.jit(lambda x: jax.vjp(f, 0., x)[1](1.))(1) # doesn't crash - - def test_custom_vjp_scan_batching_edge_case(self): - # https://github.com/jax-ml/jax/issues/5832 - @jax.custom_vjp - def mul(x, coeff): return x * coeff - def mul_fwd(x, coeff): return mul(x, coeff), (x, coeff) - def mul_bwd(res, g): - x, coeff = res - g_x = g * coeff - g_coeff = (x * g).sum() - return g_x, g_coeff - mul.defvjp(mul_fwd, mul_bwd) - - def scan_over_mul(x, coeff): - def f_(x, t): - return mul(x, coeff), None - y, _ = jax.lax.scan(f_, x, jnp.arange(3)) - return y - - key = jax.random.key(0) - key1, key2 = jax.random.split(key, 2) - x_batch = jax.random.normal(key1, (3, 2)) - covector_batch = jax.random.normal(key2, (3, 2)) - coeff = jnp.array(1., dtype=x_batch.dtype) - - batched_scan_over_mul = jax.vmap(scan_over_mul, in_axes=(0, None), out_axes=0) - res, vjp_fun = jax.vjp(batched_scan_over_mul, x_batch, coeff) - vjp_fun(covector_batch) # doesn't crash - - jtu.check_grads(batched_scan_over_mul, (x_batch, coeff), order=2, - modes=['rev']) - - def test_closure_with_vmap2(self): - # https://github.com/jax-ml/jax/issues/8783 - def h(z): - def f(x): - @jax.custom_vjp - def g(y): - return x * y - - def g_fwd(y): - return x * y, (x, x * y, y) - def g_rev(res, w_bar): - x, *_ = res - return (x * w_bar,) - g.defvjp(g_fwd, g_rev) - - return g(z) - - return jax.vmap(f)(jnp.arange(3., dtype='float32')).sum() - - jtu.check_grads(h, (jnp.float32(3.14),), order=1, modes=['rev']) - - def test_pytrees_not_required_to_contain_nones(self): - class A(list): - pass - - def unflatten(_, children): - assert children[0] is not None - return A(children) - - tree_util.register_pytree_node(A, lambda x: (x, None), unflatten) - - @jax.custom_vjp - def f(x): - return x[0] - def f_fwd(x): - return x[0], None - def f_bwd(_, g): - return A([g]), - f.defvjp(f_fwd, f_bwd) - - jax.grad(f)(A([1.])) # doesn't crash - - def test_vmap_vjp_called_twice(self): - # https://github.com/jax-ml/jax/pull/14728 - @jax.custom_vjp - def f(x): - return x - f.defvjp(lambda x: (x, None), lambda _, y_bar: (y_bar,)) - - _, f_vjp = jax.vjp(jax.vmap(f), jnp.array([3.])) - f_vjp(jnp.array([3.])) - f_vjp(jnp.array([3.])) # doesn't crash - - def test_symbolic_zero_custom_vjp_basic(self): - ZERO = custom_derivatives_public.SymbolicZero - - @jax.custom_vjp - def f(x, y, z): - return x, x - - def fwd(x, y, z): - self.assertIsInstance(x, jax.custom_derivatives.CustomVJPPrimal) - self.assertIsInstance(y, jax.custom_derivatives.CustomVJPPrimal) - self.assertIsInstance(z, jax.custom_derivatives.CustomVJPPrimal) - self.assertTrue(x.perturbed) - self.assertFalse(y.perturbed) - self.assertFalse(z.perturbed) - return (x.value, x.value), None - - def fwd_all(x, y, z): - self.assertIsInstance(x, jax.custom_derivatives.CustomVJPPrimal) - self.assertIsInstance(y, jax.custom_derivatives.CustomVJPPrimal) - self.assertIsInstance(z, 
jax.custom_derivatives.CustomVJPPrimal) - self.assertTrue(x.perturbed) - self.assertTrue(y.perturbed) - self.assertTrue(z.perturbed) - return (x.value, x.value), None - - def bwd_all(_, g): - x1, x2 = g - self.assertFalse(type(x1) is ZERO) - self.assertFalse(type(x2) is ZERO) - return x1, x1, x2 - - def bwd_fst(_, g): - x1, x2 = g - self.assertFalse(type(x1) is ZERO) - self.assertIs(type(x2), ZERO) - return x1, x1, x2 - - def bwd_snd(_, g): - x1, x2 = g - self.assertIs(type(x1), ZERO) - self.assertFalse(type(x2) is ZERO) - return x1, x1, x2 - - x, y, z = 4., 5., 6. - i = np.array(7, np.int32) - zero = np.array(0.) - - f.defvjp(fwd, bwd_all, symbolic_zeros=True) - h = jax.jit(f) - jax.jacrev(h)(x, y, z) - jax.jacrev(lambda x: h(x, y, z))(x) - jax.jacrev(h, argnums=(0, 1, 2), allow_int=True)(x, i, i) - - f.defvjp(fwd_all, bwd_fst, symbolic_zeros=True) - fst_f = lambda *xs: f(*xs)[0] - _, vjp = jax.vjp(fst_f, x, y, z) - _, _, gz = vjp(x) - self.assertArraysAllClose(gz, zero) - - f.defvjp(fwd_all, bwd_snd, symbolic_zeros=True) - snd_f = lambda *xs: f(*xs)[1] - _, vjp = jax.vjp(snd_f, x, y, z) - gx, gy, _ = vjp(x) - self.assertArraysAllClose(gx, zero) - self.assertArraysAllClose(gy, zero) - - f.defvjp(fwd, bwd_snd, symbolic_zeros=True) - _, vjp = jax.vjp(lambda x: snd_f(x, y, z), x) - gx, = vjp(x) - self.assertArraysAllClose(gx, zero) - - def test_symbolic_zero_custom_vjp_bwd_shape_error(self): - @jax.custom_vjp - def f(x, y, z): - return x, y, z - - def fwd(x, y, z): - return f(x.value, y.value, z.value), None - - def bwd(_, gs): - x_bar, y_bar, z_bar = gs - return y_bar, x_bar, z_bar # swapped! - - f.defvjp(fwd, bwd, symbolic_zeros=True) - - with self.assertRaisesRegex( - ValueError, - r'Consider just returning a None here'): - jax.grad(lambda x, y, z: f(x, y, z)[2].sum())( - jnp.ones(1), jnp.ones(2), jnp.ones(3)) - - @parameterized.named_parameters( - ('jit_vmap', True, True), - ('jit', True, False), - ('vmap', False, True), - ('', False, False), - ) - def test_symbolic_zero_custom_vjp(self, maybe_jit, maybe_vmap): - # below: - # * static_scalar will be static in and out - # * static_array will be static in, but dynamic out - # * dyn_scalar and dyn_array will be dynamic in and out - - ZERO = custom_derivatives_public.SymbolicZero - - def f(static_scalar, static_array, dyn_scalar, dyn_array): - out1 = static_scalar + dyn_scalar - out2 = static_array + dyn_array - return static_scalar, static_array, out1, out2 - - def _pack(x): - return lax.broadcast(x, (1,)) - - def _unpack(x): - (x,) = x - return x - - def _vmap(fun): - def _fun(*args): - args = jax.tree.map(_pack, args) - out = jax.vmap(fun)(*args) - out = jax.tree.map(_unpack, out) - return out - return _fun - - f = jax.custom_vjp(f) - - def fwd(*args): - xs, pert = [x.value for x in args], [x.perturbed for x in args] - self.assertFalse(pert[0]) - self.assertFalse(pert[1]) - self.assertTrue(pert[2]) - self.assertTrue(pert[3]) - return f(*xs), xs - - def bwd(res, g): - static_scalar, *_ = res - t_static, t_static_arr, t_dyn_scalar, t_dyn_array = g - self.assertIs(type(t_static), ZERO) - self.assertFalse(type(t_static_arr) is ZERO) - self.assertFalse(type(t_dyn_scalar) is ZERO) - self.assertFalse(type(t_dyn_array) is ZERO) - self.assertEqual(t_static.shape, ()) - self.assertEqual(t_static_arr.shape, (2,)) - return (static_scalar + 90, - t_static_arr + 91, - t_dyn_scalar + 92, - t_dyn_array + 93) - - f.defvjp(fwd, bwd, symbolic_zeros=True) - - def g(dyn_scalar, dyn_array): - if maybe_vmap: - f_ = _vmap(f) - else: - f_ = f - outs = f_(1., 
jnp.array([2., 3.]), dyn_scalar, dyn_array) - return outs[1:] - - def run(primal_ins, cotangent_outs): - primal_outs, vjp = jax.vjp(g, *primal_ins) - cotangent_ins = vjp(cotangent_outs) - return primal_outs, cotangent_ins - - if maybe_jit: - run = jax.jit(run) - - scalar_type = jax.Array if maybe_jit or maybe_vmap else float - primal_ins = (4., jnp.array([5., 6.])) - cotangent_outs = (jnp.array([10., 11.]), 7., jnp.array([8., 9.])) - primal_outs, cotangent_ins = run(primal_ins, cotangent_outs) - - primal_out1, primal_out2, primal_out3 = primal_outs - self.assertIsInstance(primal_out1, jax.Array) - self.assertAllClose(primal_out1, jnp.array([2., 3.])) - self.assertIsInstance(primal_out2, scalar_type) - self.assertAllClose(primal_out2, 5.) - self.assertIsInstance(primal_out3, jax.Array) - self.assertAllClose(primal_out3, jnp.array([7., 9.])) - - ct_in1, ct_in2 = cotangent_ins - self.assertIsInstance(ct_in1, scalar_type) - self.assertAllClose(ct_in1, 99.) - self.assertIsInstance(ct_in2, jax.Array) - self.assertArraysAllClose(ct_in2, jnp.array([101., 102.])) - - def test_symbolic_zero_custom_vjp_vmap_output(self): - @jax.custom_vjp - def f(x, y): - return x, y - - def fwd(x, y): - self.assertTrue(x.perturbed) - self.assertFalse(y.perturbed) - return f(x.value, y.value), None - - def bwd(_, g): - _, ct_y = g - self.assertIs(type(ct_y), custom_derivatives_public.SymbolicZero) - return g - - f.defvjp(fwd, bwd, symbolic_zeros=True) - jax.grad(lambda x, y: jax.vmap(f)(x, y)[0].sum())(jnp.ones(3), jnp.ones(3)) - - def test_symbolic_zero_custom_vjp_custom_pytree(self): - tree_values = custom_derivatives_public.custom_vjp_primal_tree_values - - @tree_util.register_pytree_node_class - class Box: - def __init__(self_, strict, val): - if strict: - # make sure we aren't getting special arguments that should only - # come up when symbolic_zeros is True - self.assertFalse(hasattr(val, 'perturbed')) - self_.strict = strict - self_.x = val - - def tree_flatten(self_): - return [self_.x], self_.strict - - @classmethod - def tree_unflatten(cls, strict, xs): - x, = xs - return cls(strict, x) - - x, y = Box(False, jnp.array(72.)), jnp.array(73.) - - @jax.custom_vjp - def f(box, y): - return box.x * y - - def fwd0(box, y): - self.assertTrue(box.x.perturbed) - self.assertFalse(y.perturbed) - box, y = map(tree_values, [box, y]) - return f(box, y), (box, y) - - def bwd0(res, g): - box, y = res - return y * g, box.x * g - - def fwd1(box, y): - self.assertFalse(box.x.perturbed) - self.assertTrue(y.perturbed) - box, y = map(tree_values, [box, y]) - return f(box, y), (box, y) - - def bwd1(res, g): - box, y = res - return y * g, box.x * g - - f.defvjp(fwd0, bwd0, symbolic_zeros=True) - jax.grad(f, argnums=0)(x, y) - f.defvjp(fwd1, bwd1, symbolic_zeros=True) - jax.grad(f, argnums=1)(x, y) - - def fwd_strict(box, y): - return f(box, y), (box, y) - - def bwd_strict(res, g): - box, y = res - return y * g, box.x * g - - f.defvjp(fwd_strict, bwd_strict) - jax.grad(f)(x, y) - - def test_symbolic_zeros_memoization_caching(self): - # Tests multiple zero patterns for partial_eval._memoize, and also tests - # that we're okay with stores being occupied with equal values. - @jax.custom_vjp - def f(x, y): - return x * y - - def f_fwd(x, y): - return x.value, None - - def f_bwd(_, z_bar): - return z_bar, None - - f.defvjp(f_fwd, f_bwd, symbolic_zeros=True) - - f_ = core.jaxpr_as_fun(jax.make_jaxpr(f)(2., 3.)) - _ = jax.linearize(f_, 2., 3.) - _ = jax.linearize(lambda x: f_(x, 3.), 2.) # don't crash! 
- - def test_run_rules_more_than_once(self): - # https://github.com/jax-ml/jax/issues/16614 - - @jax.custom_vjp - def f(x, y): - return x + y - - def f_fwd(x, y): - if y.perturbed: - res = None - else: - res = [] - return x.value + y.value, res - - def f_bwd(res, ct): - return ct, ct - - f.defvjp(f_fwd, f_bwd, symbolic_zeros=True) - - def body(x_y, _): - x, y = x_y - return (f(x, y), x), None - - @jax.grad - def g(x): - (out, _), _ = lax.scan(body, (x, 1.), xs=None, length=2) - return out - - g(1.) # doesn't crash - - def test_nones_representing_zeros_in_subtrees_returned_by_bwd(self): - # https://github.com/jax-ml/jax/issues/8356 - @jax.custom_vjp - def f(x): - return x[0] - - def f_fwd(x): - return f(x), None - - def f_bwd(_, z_bar): - return (z_bar, (None, None)), - - f.defvjp(f_fwd, f_bwd) - - jax.grad(f)((1.0, (2.0, 3.0))) # don't crash - - def test_pytree_nones_returned_by_bwd(self): - @jax.custom_vjp - def f(x): - return x[0] - - def f_fwd(x): - return f(x), None - - def f_bwd(_, z_bar): - return (z_bar, (None, None)), - - f.defvjp(f_fwd, f_bwd) - - jax.grad(f)((1.0, (2.0, None))) # don't crash - - def test_bwd_rule_shape_mismatch(self): - @jax.custom_vjp - def foo(x, y): - return x - - def foo_fwd(x, y): - return x, None - - def foo_bwd(_, g): - return jnp.zeros(3), jnp.zeros(3) - - foo.defvjp(foo_fwd, foo_bwd) - - with self.assertRaisesRegex( - ValueError, - r'output\[1\] the bwd rule produced an output of shape/dtype float..\[3\]'): - jax.grad(lambda x, y: foo(x, y * y).sum(), 1)(jnp.ones(3), jnp.ones(4)) - - def test_bwd_rule_shape_mismatch_disable(self): - # TODO(mattjj): remove this test when the config option is removed - @jax.custom_vjp - def foo(x, y): - return x - - def foo_fwd(x, y): - return x, None - - def foo_bwd(_, g): - return jnp.zeros(3), jnp.zeros(3) - - foo.defvjp(foo_fwd, foo_bwd) - - with config.custom_vjp_disable_shape_check(True): - jax.grad(lambda x, y: foo(x, y).sum(), 1)(jnp.ones(3), jnp.ones(4)) - - def test_bwd_rule_can_produce_list_or_tuple(self): - @jax.custom_vjp - def f(x, y): - return x * y - - def f_fwd(x, y): - return f(x, y), (x, y) - - def f_bwd(xy, g): - x, y = xy - return [g * y, x * g] # list, not tuple - - f.defvjp(f_fwd, f_bwd) - - jax.grad(f)(1., 2.) 
# don't crash - - def test_optimize_remat(self): - def fun(x): - # This array is included to make sure that we handle consts appropriately - return np.array([1.0])*x - - def fwd(x): - return np.array([2.0])*x*x/np.array([1.0]), (x,) - - x = jnp.linspace(0, 5.0, 10) - fwd = custom_derivatives.optimize_remat_of_custom_vjp_fwd( - fun, api_util.debug_info("custom_vjp fun", fun, (x,), {}), - fwd, api_util.debug_info("custom_vjp fwd", fwd, (x,), {})) - - self.assertAllClose(jax.jit(fwd)(x)[0], 2*x*x) # Shouldn't hit custom DCE - self.assertAllClose(jax.jit(lambda x: fwd(x)[0])(x), x) # Should be DCEed - - def test_optimize_remat_vmap(self): - def fun(x): - return (np.array([1.0])*x)[0] - def fwd(x): - return (np.array([2.0])*x*x/np.array([1.0]))[0], (x,) - x = jnp.linspace(0, 5.0, 10) - fwd = custom_derivatives.optimize_remat_of_custom_vjp_fwd( - fun, api_util.debug_info("custom_vjp fun", fun, (x,), {}), - fwd, api_util.debug_info("custom_vjp fwd", fwd, (x,), {})) - self.assertAllClose(jax.jit(jax.vmap(fwd))(x)[0], 2*x*x) - self.assertAllClose(jax.jit(lambda x: jax.vmap(fwd)(x)[0])(x), x) - - def test_optimize_remat_cond(self): - def fun(x): - return x - def fwd(x): - return x*x, (x,) - - x = jnp.linspace(0, 5.0, 10) - fwd = custom_derivatives.optimize_remat_of_custom_vjp_fwd( - fun, api_util.debug_info("custom_vjp fun", fun, (x,), {}), - fwd, api_util.debug_info("custom_vjp fwd", fwd, (x,), {})) - - def g(x): - return jax.lax.cond(True, fwd, lambda x: (2.0 * x, (x,)), x) - - self.assertAllClose(jax.jit(g)(x)[0], x*x) - self.assertAllClose(jax.jit(lambda x: g(x)[0])(x), x) - - def test_optimize_remat_jvp(self): - def fun(x): - return x**2 - def fwd_(x): - return x*x, (x,) - - fwd = custom_derivatives.optimize_remat_of_custom_vjp_fwd( - fun, api_util.debug_info("custom_vjp fun", fun, (3.2,), {}), - fwd_, api_util.debug_info("custom_vjp fwd", fwd_, (3.2,), {})) - calc = jax.jvp(fwd, (3.2,), (1.0,)) - expected = jax.jvp(fwd_, (3.2,), (1.0,)) - self.assertAllClose(calc, expected) - - @jax.jit - def g(x, t): - (y, r), (y_dot, r_dot) = jax.jvp(fwd, (x,), (t,)) - return y, y_dot - calc = g(3.2, 1.0) - expected = jax.jvp(fun, (3.2,), (1.0,)) - self.assertAllClose(calc, expected) - - def test_optimize_remat_gh21303(self): - @jax.custom_vjp - def f(x): - return jnp.tan(x) - - def f_fwd(x): - return jnp.sin(x), (x,) - - def f_bwd(res, g): - x, = res - cos_x = jnp.cos(x) - return (cos_x * g,) - - f.defvjp(f_fwd, f_bwd, optimize_remat=True) - - def temp(x): - out = jax.remat(f)(x) - out = out ** 2 - return out - - v, g = jax.value_and_grad(temp)(3.2) - self.assertAllClose(v, jnp.tan(3.2)**2) - - def test_optimize_remat_multiple_args(self): - def f_(x, y): - return jnp.sin(x) * y - - @jax.custom_vjp - def f(x, y): - return f_(x, y) - - def f_fwd(x, y): - return f(x, y), (jnp.cos(x), jnp.sin(x), y) - - def f_bwd(res, g): - cos_x, sin_x, y = res - return (cos_x * g * y, sin_x * g) - - f.defvjp(f_fwd, f_bwd, optimize_remat=True) - x, y = 3.2, 1.0 - self.assertAllClose(jax.grad(f)(x, y), jax.grad(f_)(x, y)) - - def test_optimize_remat_kwargs(self): - @jax.custom_vjp - def f(x, y): - return jnp.sin(x) * y - - def f_fwd(x, y, *, keyword=False): - del keyword - return f(x, y), (jnp.cos(x), jnp.sin(x), y) - - def f_bwd(res, g): - cos_x, sin_x, y = res - return (cos_x * g * y, sin_x * g) - - f.defvjp(f_fwd, f_bwd, optimize_remat=True) - x, y = 3.2, 1.0 - jax.grad(f)(x, y) # Doesn't error - - def test_optimize_remat_custom_vmap(self): - # See https://github.com/jax-ml/jax/pull/23000 - @jax.custom_vjp - def f(x, y): - 
return jnp.sin(x) * y - - @jax.custom_batching.custom_vmap - def f_fwd(x, y): - return f(x, y), (jnp.cos(x), jnp.sin(x), y) - - @f_fwd.def_vmap - def f_fwd_vmap(_, in_batched, x, y): - # Insert a new const here to test the optimize_remat batching rule. - out = np.array([2.0])*f(x, y) - out_batched = (True, (True, True, True)) - return (out, (jnp.cos(x), jnp.sin(x), y)), out_batched - - def f_bwd(res, g): - cos_x, sin_x, y = res - return (cos_x * g * y, sin_x * g) - - f.defvjp(f_fwd, f_bwd, optimize_remat=True) - x, y = jnp.linspace(0.0, 1.0, 5), jnp.linspace(2.0, 5.0, 5) - jax.jit(jax.vmap(jax.grad(f)))(x, y) # Doesn't error - - def test_dce(self): - @jax.custom_vjp - def f(x, y): - return jnp.sin(x), x + jnp.cos(y) - - def f_fwd(x, y): - return f(x, y), (jnp.cos(x), jnp.sin(y)) - - def f_bwd(res, cts): - cos_x, sin_y = res - ct_a, ct_b = cts - return 2.0 * cos_x * ct_a + 1.5 * ct_b, -0.5 * sin_y * ct_b - - f.defvjp(f_fwd, f_bwd) - - def check_jaxpr(jaxpr, used_outs, includes, excludes): - dce_jaxpr, _ = pe.dce_jaxpr(jaxpr, used_outs) - if not dce_jaxpr.eqns: - assert not includes - return - call_jaxpr = dce_jaxpr.eqns[0].params["fun_jaxpr"] - for prim in includes: - assert any(eqn.primitive == prim for eqn in call_jaxpr.eqns) - for prim in excludes: - assert all(eqn.primitive != prim for eqn in call_jaxpr.eqns) - - x, y = 0.1, -1.3 - jaxpr = jax.make_jaxpr(f)(x, y).jaxpr - check_jaxpr(jaxpr, [True, True], [lax.sin_p, lax.cos_p], []) - check_jaxpr(jaxpr, [True, False], [lax.sin_p], [lax.cos_p]) - check_jaxpr(jaxpr, [False, True], [lax.cos_p], [lax.sin_p]) - check_jaxpr(jaxpr, [False, False], [], [lax.sin_p, lax.cos_p]) - - def dce_jaxpr_as_fun(jaxpr, used_outs): - jaxpr_, _ = pe.dce_jaxpr(jaxpr, used_outs) - fun = core.jaxpr_as_fun(pe.close_jaxpr(jaxpr_)) - return lambda *args: fun(*args)[0] - - f0 = dce_jaxpr_as_fun(jaxpr, [True, False]) - f1 = dce_jaxpr_as_fun(jaxpr, [False, True]) - self.assertAllClose( - api.grad(f0, argnums=(0, 1))(x, y), (2.0 * jnp.cos(x), 0.0)) - self.assertAllClose( - api.grad(f1, argnums=(0, 1))(x, y), (1.5, -0.5 * jnp.sin(y))) - - def test_resolve_kwargs_error_message(self): - @jax.custom_vjp - def f(x, y, *, z=None): - return jnp.sin(x), x + jnp.cos(y) - - def f_fwd(x, y): - self.fail("should not be executed") - - def f_bwd(res, cts): - self.fail("should not be executed") - - f.defvjp(f_fwd, f_bwd) - - with self.assertRaisesRegex( - TypeError, - r"The input arguments to the custom_vjp-decorated function f(.*)\n" - r"missing a required argument: 'y'" - ): - f(0.5) - - with self.assertRaisesRegex( - TypeError, - r"The input arguments to the custom_vjp-decorated function f(.*)\n" - "The following keyword arguments could not be resolved to positions: z" - ): - f(0.5, 0.1, z=1.0) - - -def transpose_unary(f, x_example): - def transposed(y): - x, = api.linear_transpose(f, x_example)(y) - return x - return transposed - - -# This class wraps jax.custom_transpose.custom_transpose in order to pass in a -# particular tree of output type on each call. Otherwise it forwards -# all attribute access. -class _custom_transpose: - def __init__(self, out_types, fun): - self.out_types = out_types - self.fun = jax.custom_transpose.custom_transpose(fun) - - def __getattr__(self, name): - return getattr(self.fun, name) - - def __call__(self, *args): - return self.fun(self.out_types, *args) - - -# This function is meant to be used as a decorator that delegates to -# custom_transpose but makes it easy to specify output argument types -# by example. If used directly a decorator (i.e. 
not invoked with -# example arguments), assumes a scalar-valued function. -# -# TODO(frostig): remove this (and its uses) once custom_transpose offers -# an option of inferring output types. -def custom_transpose(example_out): - if isinstance(example_out, Callable): - out_type = core.get_aval(0.).to_tangent_aval() - return _custom_transpose(out_type, example_out) - return partial( - _custom_transpose, - jax.tree.map( - lambda x: core.get_aval(x).to_tangent_aval(), example_out)) - - -class CustomTransposeTest(jtu.JaxTestCase): - - def test_linear_call(self): - def f(x, y): - def fn(r, x): return x / r - def tp(r, t): return t / r - return x + jax.custom_derivatives.linear_call(fn, tp, y, x) - - def f_ref(x, y): - return x + x / y - - x = jnp.ones(2) * 6. - y = jnp.ones(2) * 3. - self.assertAllClose(f(x, y), f_ref(x, y)) - - f1 = lambda x: f(x, y) - f1_ref = lambda x: f_ref(x, y) - self.assertAllClose(transpose_unary(f1, x)(x), - transpose_unary(f1_ref, x)(x)) - - def test_linear_call_incorrect_transpose(self): - def f(x, y): - def fn(r, x): return x / r - def tp(r, t): return t / (2. * r) # nb: not the true transpose - return x + jax.custom_derivatives.linear_call(fn, tp, y, x) - - def f_ref(x, y): - return x + x / y - - x = jnp.ones(2) * 6. - y = jnp.ones(2) * 3. - self.assertAllClose(f(x, y), f_ref(x, y)) - - f1 = lambda x: f(x, y) - f1_ref = lambda x: f_ref(x, 2. * y) # nb: double the reference divisor - self.assertAllClose(transpose_unary(f1, x)(x), - transpose_unary(f1_ref, x)(x)) - - def test_linear_call_transpose_transpose_transpose(self): - def fn(r, x): return x / r - def tp(r, t): return t / (2. * r) # nb: untrue transpose - def f_(x, y): - return x + jax.custom_derivatives.linear_call(fn, tp, y, x) - - x = jnp.ones(2) * 6. - y = jnp.ones(2) * 3. - f = lambda x: f_(x, y) - ft = transpose_unary(f, x) - ftt = transpose_unary(ft, x) - fttt = transpose_unary(ftt, x) - self.assertAllClose(ft(x), x + tp(y, x)) - self.assertAllClose(f(x), ftt(x)) - self.assertAllClose(ft(x), fttt(x)) - - def test_linear_call_scalar_to_vector(self): - def f(c, x): - def fn(_, x): - return [x, x] - - def tp(_, t): - t1, t2 = t - return t1 + t2 - - return jax.custom_derivatives.linear_call(fn, tp, (), c * x) - - def f_ref(c, x): - return [c * x, c * x] - - c, x = 2., 3. - t = [4., 5.] - self.assertAllClose(f(c, x), f_ref(c, x)) - self.assertAllClose(transpose_unary(partial(f, c), x)(t), - transpose_unary(partial(f_ref, c), x)(t)) - - def test_linear_call_nested(self): - # identity function with an untrue transpose of 0 - def id_(x): - def f(_, x): return x - def t(_, t): return 0. - return jax.custom_derivatives.linear_call(f, t, (), x) - - # identity function with an untrue transpose of 7, and where both - # forward and transpose have custom transpositions that should - # never end up invoked. - def f(x): - def f_(_, x): return id_(x) - def t_(_, t): return id_(7.) - return jax.custom_derivatives.linear_call(f_, t_, (), x) - - x = 5. - id_t = transpose_unary(id_, x) - id_tt = transpose_unary(id_t, x) - ft = transpose_unary(f, x) - ftt = transpose_unary(ft, x) - fttt = transpose_unary(ftt, x) - - self.assertAllClose(id_(x), x) - self.assertAllClose(id_t(x), 0.) - self.assertAllClose(id_tt(x), x) - - self.assertAllClose(f(x), x) - self.assertAllClose(ft(x), 7.) - self.assertAllClose(ftt(x), x) - self.assertAllClose(fttt(x), 7.) 
- - def test_linear_call_jit(self): - def f(x, y): - def fn(r, x): return x / r - def tp(r, t): return t / r - return x + jax.custom_derivatives.linear_call(fn, tp, y, x) - - x = jnp.ones(2) * 6. - y = jnp.ones(2) * 3. - self.assertAllClose(f(x, y), jax.jit(f)(x, y)) - - f1 = lambda x: f(x, y) - self.assertAllClose(transpose_unary(f1, x)(x), - jax.jit(transpose_unary(f1, x))(x)) - - def test_basic(self): - def f(x, y): - @custom_transpose(jnp.ones(2)) - def fn(r, x): return x / r - @fn.def_transpose - def tp(r, t): return t / r - - return x + fn(y, x) - - def f_ref(x, y): - return x + x / y - - x = jnp.ones(2) * 6. - y = jnp.ones(2) * 3. - self.assertAllClose(f(x, y), f_ref(x, y)) - - f1 = lambda x: f(x, y) - f1_ref = lambda x: f_ref(x, y) - self.assertAllClose(transpose_unary(f1, x)(x), - transpose_unary(f1_ref, x)(x)) - - def test_incorrect_transpose(self): - def f(x, y): - @custom_transpose(jnp.ones(2)) - def fn(r, x): return x / r - @fn.def_transpose - def tp(r, t): return t / (2. * r) # nb: not the true transpose - - return x + fn(y, x) - - def f_ref(x, y): - return x + x / y - - x = jnp.ones(2) * 6. - y = jnp.ones(2) * 3. - self.assertAllClose(f(x, y), f_ref(x, y)) - - f1 = lambda x: f(x, y) - f1_ref = lambda x: f_ref(x, 2. * y) # nb: double the reference divisor - self.assertAllClose(transpose_unary(f1, x)(x), - transpose_unary(f1_ref, x)(x)) - - def test_transpose_transpose_transpose(self): - @custom_transpose(jnp.ones(2)) - def fn(r, x): return x / r - @custom_transpose(jnp.ones(2)) - def tp(r, t): return t / (2. * r) # nb: untrue transpose - - fn.def_transpose(tp) - tp.def_transpose(fn) - - def f_(x, y): - return x + fn(y, x) - - x = jnp.ones(2) * 6. - y = jnp.ones(2) * 3. - f = lambda x: f_(x, y) - ft = transpose_unary(f, x) - ftt = transpose_unary(ft, x) - fttt = transpose_unary(ftt, x) - self.assertAllClose(ft(x), x + tp(y, x)) - self.assertAllClose(f(x), ftt(x)) - self.assertAllClose(ft(x), fttt(x)) - - def test_scalar_to_vector(self): - def f(c, x): - @custom_transpose([0., 0.]) - def fn(_, x): - return [x, x] - - @fn.def_transpose - def tp(_, t): - t1, t2 = t - return t1 + t2 - - return fn((), c * x) - - def f_ref(c, x): - return [c * x, c * x] - - c, x = 2., 3. - t = [4., 5.] - self.assertAllClose(f(c, x), f_ref(c, x)) - self.assertAllClose(transpose_unary(partial(f, c), x)(t), - transpose_unary(partial(f_ref, c), x)(t)) - - def test_nested(self): - # identity function with an untrue transpose of 0 - def id_(x): - f = custom_transpose(lambda _, x: x) - t = custom_transpose(lambda _, t: 0.) - f.def_transpose(t) - t.def_transpose(f) - return f((), x) - - # identity function with an untrue transpose of 7, and where both - # forward and transpose have custom transpositions that should - # never end up invoked. - def f(x): - f_ = custom_transpose(lambda _, x: id_(x)) - t_ = custom_transpose(lambda _, t: id_(7.)) - f_.def_transpose(t_) - t_.def_transpose(f_) - return f_((), x) - - x = 5. - id_t = transpose_unary(id_, x) - id_tt = transpose_unary(id_t, x) - ft = transpose_unary(f, x) - ftt = transpose_unary(ft, x) - fttt = transpose_unary(ftt, x) - - self.assertAllClose(id_(x), x) - self.assertAllClose(id_t(x), 0.) - self.assertAllClose(id_tt(x), x) - - self.assertAllClose(f(x), x) - self.assertAllClose(ft(x), 7.) - self.assertAllClose(ftt(x), x) - self.assertAllClose(fttt(x), 7.) - - def test_one_degree(self): - T = lambda f: transpose_unary(f, 0.) - - @custom_transpose - def f(_, z): return 2. * z - @f.def_transpose - def ft(_, z): return 3. 
* z - - f = partial(f, ()) - self.assertAllClose(2., f(1.)) - self.assertAllClose(3., T(f)(1.)) - self.assertAllClose(3., T(T(f))(1.)) - self.assertAllClose(3., T(T(T(f)))(1.)) - self.assertAllClose(3., T(T(T(T(f))))(1.)) # ... - - def test_two_degrees(self): - T = lambda f: transpose_unary(f, 0.) - - @custom_transpose - def f(_, z): return 2. * z - - @f.def_transpose - @custom_transpose - def ft(_, z): return 3. * z - - @ft.def_transpose - def ftt(_, z): return 7. * z - - f = partial(f, ()) - self.assertAllClose(2., f(1.)) - self.assertAllClose(3., T(f)(1.)) - self.assertAllClose(7., T(T(f))(1.)) - self.assertAllClose(7., T(T(T(f)))(1.)) - self.assertAllClose(7., T(T(T(T(f))))(1.)) # ... - - def test_symmetric(self): - T = lambda f: transpose_unary(f, 0.) - - @custom_transpose - def f(_, z): return 2. * z - @custom_transpose - def g(_, z): return 3. * z - - f.def_transpose(g) - g.def_transpose(f) - - f = partial(f, ()) - self.assertAllClose(2., f(1.)) - self.assertAllClose(3., T(f)(1.)) - self.assertAllClose(2., T(T(f))(1.)) - self.assertAllClose(3., T(T(T(f)))(1.)) - self.assertAllClose(2., T(T(T(T(f))))(1.)) # ... - - def test_recursive(self): - T = lambda f: transpose_unary(f, 0.) - - @custom_transpose - def f(c, z): return c * z - - @f.def_transpose - def ft(c, z): return f(c + 1., z) - - g = partial(f, 1.) - self.assertAllClose(1., g(1.)) - self.assertAllClose(2., T(g)(1.)) - self.assertAllClose(3., T(T(g))(1.)) - self.assertAllClose(4., T(T(T(g)))(1.)) - self.assertAllClose(5., T(T(T(T(g))))(1.)) # ... - - def test_jvp_lin(self): - def f(x, y): - @custom_transpose(jnp.ones(2)) - def fn(r, x): return x / r - @fn.def_transpose - def tp(r, t): return t / r - return x + fn(y, x) - - def f_ref(x, y): return x + x / y - - x, y, tx = 6., 3., 1. - g = lambda x: f(x, y) - g_ref = lambda x: f_ref(x, y) - self.assertAllClose(api.jvp(g, [x], [tx]), api.jvp(g_ref, [x], [tx])) - - def test_jvp_res(self): - raise unittest.SkipTest('unimplemented') # TODO(frostig) - - def f(x, y): - @custom_transpose(jnp.ones(2)) - def fn(r, x): return x / r - @fn.def_transpose - def tp(r, t): return t / r - return x + fn(y, x) - - def f_ref(x, y): return x + x / y - - x, y, ty = 6., 3., 1. - g = lambda y: f(x, y) - g_ref = lambda y: f_ref(x, y) - self.assertAllClose(api.jvp(g, [y], [ty]), api.jvp(g_ref, [y], [ty])) - - def test_jvp_both(self): - raise unittest.SkipTest('unimplemented') # TODO(frostig) - - def f(x, y): - @custom_transpose(jnp.ones(2)) - def fn(r, x): return x / r - @fn.def_transpose - def tp(r, t): return t / r - return x + fn(y, x) - - def f_ref(x, y): return x + x / y - - x, y, tx, ty = 6., 3., 1., 1. - self.assertAllClose(api.jvp(f, [x, y], [tx, ty]), - api.jvp(f_ref, [x, y], [tx, ty])) - - def test_make_jaxpr(self): - def f(x, y): - @custom_transpose(jnp.ones(2)) - def fn(r, x): return x / r - @fn.def_transpose - def tp(r, t): return 2 * t / r - - return x + fn(y, x) - - x = jnp.ones(2) * 6. - y = jnp.ones(2) * 3. - f_ = lambda x: f(x, y) - f_t = transpose_unary(f_, x) - - jaxpr = api.make_jaxpr(f_)(x) - self.assertIn('custom_transpose_call', str(jaxpr)) - - jaxpr_t = api.make_jaxpr(f_t)(x) - self.assertNotIn('custom_transpose_call', str(jaxpr_t)) - - def test_jit(self): - def f(x, y): - @custom_transpose(jnp.ones(2)) - def fn(r, x): return x / r - @fn.def_transpose - def tp(r, t): return 2 * t / r - - return x + fn(y, x) - - x = jnp.ones(2) * 6. - y = jnp.ones(2) * 3. 
- self.assertAllClose(f(x, y), jax.jit(f)(x, y)) - - f_ = lambda x: f(x, y) - f_t = transpose_unary(f_, x) - g_ = jax.jit(f_) - g_t = transpose_unary(g_, x) - self.assertAllClose(f_(x), jax.jit(f_)(x)) - self.assertAllClose(f_t(x), jax.jit(f_t)(x)) - self.assertAllClose(f_(x), g_(x)) - self.assertAllClose(f_t(x), g_t(x)) - - def test_jit_recursive(self): - def f(x, y): - @custom_transpose(jnp.ones(2)) - def fn(r, x): return x / r - @fn.def_transpose - def tp(r, t): return 2 * fn(r, t) - - return x + fn(y, x) - - x = jnp.ones(2) * 6. - y = jnp.ones(2) * 3. - self.assertAllClose(f(x, y), jax.jit(f)(x, y)) - - f_ = lambda x: f(x, y) - f_t = transpose_unary(f_, x) - g_ = jax.jit(f_) - g_t = transpose_unary(g_, x) - self.assertAllClose(f_(x), jax.jit(f_)(x)) - self.assertAllClose(f_t(x), jax.jit(f_t)(x)) - self.assertAllClose(f_(x), g_(x)) - self.assertAllClose(f_t(x), g_t(x)) - - def test_cond(self): - def f(x, y): - @custom_transpose(jnp.ones(2)) - def fn(r, x): return x / r - @fn.def_transpose - def tp(r, t): return 2 * t / r - - return x + fn(y, x) - - def cond_wrap(f): - return lambda i, x: lax.cond(i > 0, f, lambda x: x, x) - - i = 7. - x = jnp.ones(2) * 6. - y = jnp.ones(2) * 3. - - f_ = lambda x: f(x, y) - f_t = transpose_unary(f_, x) - g_ = partial(cond_wrap(f_), i) - g_t = transpose_unary(g_, x) - - self.assertAllClose(f_(x), g_(x)) - self.assertAllClose(f_t(x), g_t(x)) - - def test_cond_recursive(self): - def f(x, y): - @custom_transpose(jnp.ones(2)) - def fn(r, x): return x / r - @fn.def_transpose - def tp(r, t): return 2 * fn(r, t) - - return x + fn(y, x) - - def cond_wrap(f): - return lambda i, x: lax.cond(i > 0, f, lambda x: x, x) - - i = 7. - x = jnp.ones(2) * 6. - y = jnp.ones(2) * 3. - - f_ = lambda x: f(x, y) - f_t = transpose_unary(f_, x) - g_ = partial(cond_wrap(f_), i) - g_t = transpose_unary(g_, x) - - self.assertAllClose(f_(x), g_(x)) - self.assertAllClose(f_t(x), g_t(x)) - - -class CustomDceTest(jtu.JaxTestCase): - - def test_basic(self): - @jax.experimental.custom_dce.custom_dce - def f(x): - return jnp.sin(x), jnp.cos(x) - - @f.def_dce - def rule(used_outs, x): - return ( - jnp.exp(x) if used_outs[0] else None, - jnp.sqrt(x) if used_outs[1] else None, - ) - - x = jnp.array(1.1234) - self.assertAllClose(jax.jit(lambda x: f(x)[0])(x), jnp.exp(x)) - self.assertAllClose(jax.jit(lambda x: f(x)[1])(x), jnp.sqrt(x)) - - def test_recursive(self): - @jax.experimental.custom_dce.custom_dce - def f(x): - return jnp.exp(x), 10 * jnp.sqrt(x) - - @f.def_dce - def f_dce(used_outs, x): - return [2 * v if used else None for used, v in zip(used_outs, f(x))] - - x = 1.1234 - expected = f(x) - self.assertAllClose(jax.jit(lambda x: f(x)[0])(x), 2 * expected[0]) - self.assertAllClose(jax.jit(lambda x: f(x)[1])(x), 2 * expected[1]) - - def test_multiple_rounds(self): - @jax.experimental.custom_dce.custom_dce - def f(x, y, z): - return jnp.sin(x), jnp.sin(y), jnp.sin(z) - - @f.def_dce - def rule(used_outs, x, y, z): - patterns.append(used_outs) - outs = [ - jnp.cos(v) if used else None for used, v in zip(used_outs, (x, y, z)) - ] - return outs - - patterns = [] - x, y, z = jnp.array(1.), jnp.array(2.), jnp.array(3.) 
- jaxpr = jax.make_jaxpr(f)(x, y, z).jaxpr - new_jaxpr, used_ins = pe.dce_jaxpr(jaxpr, [True, False, True]) - assert used_ins == [True, False, True] - new_jaxpr, used_ins = pe.dce_jaxpr(new_jaxpr, [True, False]) - assert used_ins == [True, False] - assert patterns == [(True, False, True), (True, False, False)], patterns - - def test_batching(self): - @jax.experimental.custom_dce.custom_dce - def f(x, y): - return jnp.sin(x), jnp.sin(y) - - @f.def_dce - def rule(used_outs, x, y): - return ( - jnp.cos(x) if used_outs[0] else None, - jnp.cos(y) if used_outs[1] else None, - ) - - x = jnp.linspace(-0.1, 0.2, 5) - y = jnp.linspace(3.0, 4.0, 5) - self.assertAllClose(jax.vmap(f)(x, y), f(x, y)) - self.assertAllClose( - jax.jit(lambda *args: jax.vmap(f)(*args)[0])(x, y), jnp.cos(x) - ) - self.assertAllClose( - jax.vmap(jax.jit(lambda *args: f(*args)[0]))(x, y), jnp.cos(x) - ) - self.assertAllClose( - jax.jit(lambda *args: jax.vmap(f)(*args)[1])(x, y), jnp.cos(y) - ) - self.assertAllClose( - jax.vmap(jax.jit(lambda *args: f(*args)[1]))(x, y), jnp.cos(y) - ) - - def test_composes_with_custom_vjp(self): - # custom_dce must be the "outer" decorator (for now!) because custom_vjp - # doesn't pass through DCE. - @jax.experimental.custom_dce.custom_dce - @jax.custom_vjp - def f(x, y): - return jnp.sin(x) * y, x * jnp.sin(y) - - @f.def_dce - def f_dce_rule(used_outs, x, y): - return ( - jnp.cos(x) * y if used_outs[0] else None, - x * jnp.cos(y) if used_outs[1] else None, - ) - - def f_fwd(x, y): - return f(x, y), (x, jnp.cos(x), jnp.sin(x), y, jnp.cos(y), jnp.sin(y)) - - def f_bwd(res, g): - ga, gb = g - x, cos_x, sin_x, y, cos_y, sin_y = res - return (cos_x * ga * y + sin_y * gb, sin_x * ga + x * cos_y * gb) - - f.defvjp(f_fwd, f_bwd) - - x, y = jnp.array(1.), jnp.array(2.) - self.assertAllClose(jax.jit(lambda *args: f(*args)[0])(x, y), - jnp.cos(x) * y) - jax.grad(lambda *args: f(*args)[0])(x, y) # Doesn't crash. 
- - def test_can_optimize_remat(self): - @jax.custom_vjp - def f(x): - return jnp.tan(x) - - @jax.experimental.custom_dce.custom_dce - def f_fwd(x): - return jnp.sin(x), (x,) - - @f_fwd.def_dce - def f_dce_rule(used_outs, x): - used_prim, used_res = used_outs - used_res, = used_res - if not used_res: - return f(x), None - prim, res = f_fwd(x) - return prim if used_prim else None, res - - def f_bwd(res, g): - x, = res - cos_x = jnp.cos(x) - return (cos_x * g,) - - f.defvjp(f_fwd, f_bwd) - - def temp(x): - out = jax.remat(f)(x) - out = out ** 2 - return out - - v, g = jax.value_and_grad(temp)(3.2) - self.assertAllClose(v, jnp.tan(3.2)**2) - - def test_static_argnums(self): - @partial(jax.experimental.custom_dce.custom_dce, static_argnums=(0,)) - def g(f, x): - return f(x), 10 * f(x) - - @g.def_dce - def g_dce(f, used_outs, x): # note: static_argnums are always passes first - self.assertTrue(callable(f)) - return [2 * v if used else None for used, v in zip(used_outs, g(f, x))] - - x = 1.1234 - f = lambda x: jnp.exp(x) - expected = g(f, x) - self.assertAllClose(jax.jit(lambda x: g(f, x)[0])(x), 2 * expected[0]) - self.assertAllClose(jax.jit(lambda x: g(f, x)[1])(x), 2 * expected[1]) - - def test_shape_mismatch_error(self): - @jax.experimental.custom_dce.custom_dce - def f(x): - return jnp.stack((x, x)), jnp.cos(x) - - @f.def_dce - def rule(used_outs, x): - return ( - jnp.exp(x) if used_outs[0] else None, - x.astype(jnp.int32) if used_outs[1] else None, - ) - - x = jnp.array(1.1234) - with self.assertRaisesRegex( - ValueError, - r'Custom DCE rule .* same shapes/dtypes .* output\[0\]', - ): - jax.jit(lambda x: f(x)[0])(x) - with self.assertRaisesRegex( - ValueError, - r'Custom DCE rule .* same shapes/dtypes .* output\[1\]', - ): - jax.jit(lambda x: f(x)[1])(x) - - def test_missing_output_error(self): - @jax.experimental.custom_dce.custom_dce - def f(x): - return jnp.sin(x), jnp.cos(x) - - @f.def_dce - def rule(used_outs, x): - return None, None - - x = jnp.array(1.1234) - with self.assertRaisesRegex( - ValueError, - r'Custom DCE rule .* produce values for all .* output\[0\]', - ): - jax.jit(lambda x: f(x)[0])(x) - - def test_consts(self): - @jax.experimental.custom_dce.custom_dce - def f(x): - return np.eye(1) * jnp.sin(x), jnp.cos(x) - - @f.def_dce - def rule(used_outs, x): - return ( - np.full((1, 1), 2.0) * jnp.exp(x) if used_outs[0] else None, - jnp.sqrt(x) if used_outs[1] else None, - ) - - x = jnp.array(1.1234) - expected = rule([True, True], x) - self.assertAllClose(jax.jit(lambda x: f(x)[0])(x), expected[0]) - self.assertAllClose(jax.jit(lambda x: f(x)[1])(x), expected[1]) - - def test_resolve_kwargs_error_message(self): - @jax.experimental.custom_dce.custom_dce - def f(x, y, *, z=None): - return jnp.sin(x) * y, x * jnp.sin(y) - - @f.def_dce - def f_dce_rule(used_outs, x, y): - self.fail("should not be executed") - - with self.assertRaisesRegex( - TypeError, - r"The input arguments to the custom_dce-decorated function f(.*)\n" - r"missing a required argument: 'y'" - ): - f(0.5) - - with self.assertRaisesRegex( - TypeError, - r"The input arguments to the custom_dce-decorated function f(.*)\n" - "The following keyword arguments could not be resolved to positions: z" - ): - f(0.5, 0.1, z=1.0) - - -class CustomVmapTest(jtu.JaxTestCase): - - def test_basic(self): - @jax.custom_batching.custom_vmap - def f(x): return jnp.sin(x) - - @f.def_vmap - def rule(axis_size, in_batched, xs): - xs_batched, = in_batched - self.assertEqual(xs_batched, True) - self.assertEqual(axis_size, xs.shape[0]) - 
return jnp.cos(xs), xs_batched - - x, xs = jnp.array(1.), jnp.arange(3) - y = f(x) - self.assertAllClose(y, jnp.sin(x)) - ys = api.vmap(f)(xs) - self.assertAllClose(ys, jnp.cos(xs)) - - @jax.numpy_dtype_promotion('standard') - def test_closure(self): - z = jnp.array([2., 1., 3.]) - - @jax.custom_batching.custom_vmap - def f(x): return z + jnp.sin(x) - - @f.def_vmap - def rule(axis_size, in_batched, *args): - self.assertEqual(len(in_batched), 1) - self.assertEqual(len(args), 1) - xs, = args - xs_batched, = in_batched - self.assertEqual(xs_batched, True) - self.assertEqual(axis_size, xs.shape[0]) - return z + jnp.cos(xs), xs_batched - - x, xs = jnp.array(1.), jnp.arange(3) - y = f(x) - self.assertAllClose(y, z + jnp.sin(x)) - ys = api.vmap(f)(xs) - self.assertAllClose(ys, z + jnp.cos(xs)) - - def test_rule_multi_output(self): - @jax.custom_batching.custom_vmap - def f(x): return jnp.sin(x), jnp.cos(x) - - @f.def_vmap - def rule(axis_size, in_batched, xs): - return (jnp.cos(xs), jnp.sin(xs)), tuple(in_batched * 2) - - x, xs = jnp.array(1.), jnp.arange(3) - y1, y2 = f(x) - self.assertAllClose(y1, jnp.sin(x)) - self.assertAllClose(y2, jnp.cos(x)) - ys1, ys2 = api.vmap(f)(xs) - self.assertAllClose(ys1, jnp.cos(xs)) - self.assertAllClose(ys2, jnp.sin(xs)) - - def test_nary(self): - @jax.custom_batching.custom_vmap - def f(x, y): return jnp.sin(x) + y ** 2. - - @f.def_vmap - def rule(axis_size, in_batched, xs, ys): - self.assertEqual(in_batched, [True, True]) - self.assertEqual(axis_size, 3) - self.assertEqual(axis_size, xs.shape[0]) - self.assertEqual(axis_size, ys.shape[0]) - return jnp.cos(xs) + ys ** 2., True - - xs, ys = jnp.arange(3.0), jnp.arange(3.0) - zs = api.vmap(f)(xs, ys) - self.assertAllClose(zs, jnp.cos(xs) + ys ** 2.) - - def test_nary_mixed_batching(self): - @jax.custom_batching.custom_vmap - def vector_dot(u, v): - self.assertEqual(u.ndim, 1) - self.assertEqual(v.ndim, 1) - return u @ v - - size = 4 - vlen = 3 - in_batched_log = [] - - @vector_dot.def_vmap - def vector_dot_vmap_rule(axis_size, in_batched, u, v): - in_batched_log.append(in_batched) - self.assertEqual(axis_size, size) - u_batched, v_batched = in_batched - if u_batched: - self.assertEqual(u.ndim, 2) - self.assertEqual(u.shape[0], size) - else: - self.assertEqual(u.ndim, 1) - self.assertEqual(u.shape[0], vlen) - if v_batched: - self.assertEqual(v.ndim, 2) - self.assertEqual(v.shape[0], size) - else: - self.assertEqual(v.ndim, 1) - self.assertEqual(v.shape[0], vlen) - if u_batched and v_batched: - out = jnp.sum(u * v, axis=1) - else: - out = u @ v if u_batched else v @ u - return out, u_batched or v_batched - - f = vector_dot - v = lambda *shape: jnp.ones(shape) - - y = api.vmap(f, in_axes=(0, None))(v(4, 3), v(3)) - self.assertAllClose(y, v(4, 3) @ v(3)) - y = api.vmap(f, in_axes=(1, None))(v(3, 4), v(3)) - self.assertAllClose(y, v(3, 4).T @ v(3)) - y = api.vmap(f, in_axes=(None, 0))(v(3), v(4, 3)) - self.assertAllClose(y, v(3) @ v(4, 3).T) - y = api.vmap(f, in_axes=(0, 0))(v(4, 3), v(4, 3)) - self.assertAllClose(y, jnp.sum(v(4, 3) * v(4, 3), axis=1)) - self.assertEqual(in_batched_log[0], [True, False]) - self.assertEqual(in_batched_log[1], [True, False]) - self.assertEqual(in_batched_log[2], [False, True]) - self.assertEqual(in_batched_log[3], [True, True]) - - def test_rule_input_signature(self): - @jax.custom_batching.custom_vmap - def f(x): return jnp.sin(x) - - rule_args = [] - - @f.def_vmap - def rule(axis_size, in_batched, xs): - rule_args.append((axis_size, in_batched)) - return jnp.cos(xs), in_batched[0] - - 
xs = jnp.arange(3) - _ = api.vmap(f)(xs) - (axis_size, in_batched), = rule_args - self.assertIs(type(axis_size), int) - self.assertIs(type(in_batched), list) - self.assertEqual(len(in_batched), 1) - - def test_rule_output_vs_batching_output_mismatch(self): - @jax.custom_batching.custom_vmap - def f(x): return jnp.sin(x) - - @f.def_vmap - def test_rule_abc(axis_size, in_batched, xs): - return [jnp.sin(xs), jnp.cos(xs)], in_batched - - xs = jnp.arange(3) - self.assertRaisesRegex( - ValueError, - 'structure of output value and output batching specification ' - r'returned by custom vmap rule \(test_rule_abc\) do not match.*', - lambda: api.vmap(f)(xs)) - - def test_rule_vs_call_output_mismatch(self): - @jax.custom_batching.custom_vmap - def f(x): return jnp.sin(x) - - @f.def_vmap - def test_rule_abc2(axis_size, in_batched, xs): - return [jnp.sin(xs)], in_batched - - xs = jnp.arange(3) - self.assertRaisesRegex( - ValueError, - r'structure of output returned by custom vmap rule \(test_rule_abc2\) ' - r'does not match that of original custom-vmapped function.*', - lambda: api.vmap(f)(xs)) - - def test_jvp_basic(self): - @jax.custom_batching.custom_vmap - def f(x): return jnp.sin(x) - - @f.def_vmap - def rule(axis_size, in_batched, xs): - self.assertEqual(axis_size, 3) - self.assertEqual(in_batched, [True]) - return jnp.cos(xs), in_batched[0] - - f_jvp = lambda x, tx: api.jvp(f, [x], [tx]) - - x, tx = jnp.array(1.), jnp.array(2.) - xs, txs = jnp.arange(3.), jnp.arange(3.) * 2. - - y, ty = f_jvp(x, tx) - self.assertAllClose(y, jnp.sin(x)) - self.assertAllClose(ty, jnp.cos(x) * tx) - - ys, tys = api.vmap(f_jvp)(xs, txs) - self.assertAllClose(ys, jnp.cos(xs)) - self.assertAllClose(tys, -jnp.sin(xs) * txs) - - ys, tys = api.jvp(api.vmap(f), [xs], [txs]) - self.assertAllClose(ys, jnp.cos(xs)) - self.assertAllClose(tys, -jnp.sin(xs) * txs) - - @jax.numpy_dtype_promotion('standard') - def test_jvp_closure(self): - z = jnp.array([2., 1., 3.]) - def bcast(x): return z + x - z - - @jax.custom_batching.custom_vmap - def f(x): return z + jnp.sin(x) - - @f.def_vmap - def rule(axis_size, in_batched, xs): - self.assertEqual(axis_size, 3) - self.assertEqual(in_batched, [True]) - return z + jnp.cos(xs), in_batched[0] - - f_jvp = lambda x, tx: api.jvp(f, [x], [tx]) - - x, tx = jnp.array(1.), jnp.array(2.) - xs, txs = jnp.arange(3.), jnp.arange(3.) * 2. - - y, ty = f_jvp(x, tx) - self.assertAllClose(y, z + jnp.sin(x)) - self.assertAllClose(ty, bcast(jnp.cos(x)) * tx) - - ys, tys = api.vmap(f_jvp)(xs, txs) - self.assertAllClose(ys, z + jnp.cos(xs)) - self.assertAllClose(tys, bcast(-jnp.sin(xs)) * txs) - - ys, tys = api.jvp(api.vmap(f), [xs], [txs]) - self.assertAllClose(ys, z + jnp.cos(xs)) - self.assertAllClose(tys, bcast(-jnp.sin(xs)) * txs) - - def test_jvp_nary(self): - @jax.custom_batching.custom_vmap - def f(x, y): return jnp.sin(x) + y - - @f.def_vmap - def rule(axis_size, in_batched, xs, ys): - self.assertEqual(axis_size, 3) - self.assertEqual(in_batched, [True, True]) - return jnp.cos(xs) + ys, True - - f_jvp = lambda x, y, tx, ty: api.jvp(f, [x, y], [tx, ty]) - - x, y, tx, ty = jnp.arange(4.) - xs, ys, txs, tys = 4. + jnp.arange(3. 
* 4).reshape((4, 3)) - - zs, tzs = api.vmap(f_jvp)(xs, ys, txs, tys) - self.assertAllClose(zs, jnp.cos(xs) + ys) - self.assertAllClose(tzs, -jnp.sin(xs) * txs + tys) - - zs, tzs = api.jvp(api.vmap(f), [xs, ys], [txs, tys]) - self.assertAllClose(zs, jnp.cos(xs) + ys) - self.assertAllClose(tzs, -jnp.sin(xs) * txs + tys) - - def test_jvp_extra_batched_tangents(self): - @jax.custom_batching.custom_vmap - def f(x): return jnp.sin(x) - - @f.def_vmap - def rule(axis_size, in_batched, xs): - self.assertEqual(axis_size, 3) - self.assertEqual(in_batched, [False]) - return jnp.cos(xs), in_batched[0] - - f_jvp = lambda x, tx: api.jvp(f, [x], [tx]) - - txs = 2. + jnp.arange(3.) - x = jnp.array(1, dtype=txs.dtype) - y, tys = api.vmap(f_jvp, in_axes=(None, 0), out_axes=(None, 0))(x, txs) - self.assertAllClose(y, jnp.cos(x)) - self.assertAllClose(tys, -jnp.sin(x) * txs) - - def test_jacfwd(self): - # jacfwd is another way to exercise extra-batched tangents - - @jax.custom_batching.custom_vmap - def f(x): return jnp.sin(x) - - @f.def_vmap - def rule(axis_size, in_batched, xs): - self.assertEqual(axis_size, 3) - self.assertEqual(in_batched, [False]) - return jnp.cos(xs), in_batched[0] - - x = jnp.arange(3.) + .72 - j = api.jacfwd(f)(x) - self.assertAllClose(j, -jnp.diag(jnp.sin(x))) - - def test_jvp_extra_batched_primals(self): - @jax.custom_batching.custom_vmap - def f(x): return jnp.sin(x) - - @f.def_vmap - def rule(axis_size, in_batched, xs): - self.assertEqual(axis_size, 3) - self.assertEqual(in_batched, [False]) - return jnp.cos(xs), in_batched[0] - - f_jvp = lambda x, tx: api.jvp(f, [x], [tx]) - - xs = jnp.arange(3.) - tx = jnp.array(4, dtype=xs.dtype) - ys, tys = api.vmap(f_jvp, in_axes=(0, None))(xs, tx) - self.assertAllClose(ys, jnp.cos(xs)) - self.assertAllClose(tys, -jnp.sin(xs) * tx) - - def test_jvp_extra_batched_primals_with_linear_vmap_rule(self): - # When a function is linear, its Jacobian is constant. JAX's JVP - # of linear functions takes advantage of this: when mapping over a - # batch of primals relative to a fixed (i.e. symbolically - # replicated) tangent, output tangents remain replicated as well - # (i.e. JAX will not broadcast them). This is true in general, and - # this test checks that vmapped JVPs continue to behave this way - # when custom_vmap is involved and the custom vmap rule is linear. - - @jax.custom_batching.custom_vmap - def f_linear(x): return 7. * x - - @f_linear.def_vmap - def linear_rule(axis_size, in_batched, xs): - return 11. * xs, in_batched[0] - - @jax.custom_batching.custom_vmap - def f_nonlinear(x): return jnp.sin(x) - - @f_nonlinear.def_vmap - def nonlinear_rule(axis_size, in_batched, xs): - return jnp.cos(xs), in_batched[0] - - f_lin_jvp = lambda x, tx: api.jvp(f_linear, [x], [tx]) - f_non_jvp = lambda x, tx: api.jvp(f_nonlinear, [x], [tx]) - xs = jnp.arange(3.) - tx = jnp.array(4., dtype=xs.dtype) - - # doesn't err - _ = api.vmap(f_lin_jvp, in_axes=(0, None), out_axes=(0, None))(xs, tx) - - # does err - self.assertRaisesRegex( - ValueError, "at vmap out_axes", - lambda: api.vmap( - f_non_jvp, in_axes=(0, None), out_axes=(0, None))(xs, tx)) - - def test_jvp_dataflow_violation(self): - # The jvp-of-custom-vmap machinery should not assume the standard - # dataflow constraint on the JVP of the custom vmap rule (primal - # outputs independent of tangent inputs). Both jvp and vmap are - # "forward" transformations under which, at present, we don't - # enforce the JVP dependence diagram. 
Because output primals can - # depend on input tangents, extra-batched input tangents can - # create batched output primals, as this test checks. - - @jax.custom_jvp - def cos_with_invalid_dataflow_jvp(x): return jnp.cos(x) - - @cos_with_invalid_dataflow_jvp.defjvp - def invalid_dataflow_jvp(x, tx): - [x], [tx] = x, tx - return jnp.cos(x * tx), tx - - @jax.custom_batching.custom_vmap - def f(x): return jnp.sin(x) - - @f.def_vmap - def rule(axis_size, in_batched, xs): - return cos_with_invalid_dataflow_jvp(xs), in_batched[0] - - f_jvp = lambda x, tx: api.jvp(f, [x], [tx]) - txs = 2. + jnp.arange(3.) - x = jnp.array(1, dtype=txs.dtype) - - # doesn't err - ys, tys = api.vmap(f_jvp, in_axes=(None, 0))(x, txs) - self.assertAllClose(ys, jnp.cos(x * txs)) - self.assertAllClose(tys, txs) - - # does err - self.assertRaisesRegex( - ValueError, "at vmap out_axes", - lambda: api.vmap( - f_jvp, in_axes=(None, 0), out_axes=(None, 0))(x, txs)) - - def test_tree(self): - tree_sin = partial(jax.tree.map, jnp.sin) - tree_cos = partial(jax.tree.map, jnp.cos) - - x, xs = jnp.array(1.), jnp.arange(3) - x = (x, [x + 1, x + 2], [x + 3], x + 4) - xs = (xs, [xs + 1, xs + 2], [xs + 3], xs + 4) - in_batched_ref = jax.tree.map(lambda _: True, x) - - @jax.custom_batching.custom_vmap - def f(xs): return tree_sin(xs) - - @f.def_vmap - def rule(axis_size, in_batched, xs): - self.assertEqual(in_batched, [in_batched_ref]) - sz, = {z.shape[0] for z in jax.tree.leaves(xs)} - self.assertEqual(axis_size, sz) - return tree_cos(xs), in_batched[0] - - y = f(x) - self.assertAllClose(y, tree_sin(x)) - ys = api.vmap(f)(xs) - self.assertAllClose(ys, tree_cos(xs)) - - def test_tree_with_nones(self): - tree_sin = partial(jax.tree.map, jnp.sin) - tree_cos = partial(jax.tree.map, jnp.cos) - - x, xs = jnp.array(1.), jnp.arange(3) - x = (x, [x + 1, None], [x + 3], None) - xs = (xs, [xs + 1, None], [xs + 3], None) - in_batched_ref = jax.tree.map(lambda _: True, x) - - @jax.custom_batching.custom_vmap - def f(xs): return tree_sin(xs) - - @f.def_vmap - def rule(axis_size, in_batched, xs): - self.assertEqual(in_batched, [in_batched_ref]) - sz, = {z.shape[0] for z in jax.tree.leaves(xs)} - self.assertEqual(axis_size, sz) - return tree_cos(xs), in_batched[0] - - y = f(x) - self.assertAllClose(y, tree_sin(x)) - ys = api.vmap(f)(xs) - self.assertAllClose(ys, tree_cos(xs)) - - def test_jit(self): - @jax.custom_batching.custom_vmap - def f(x): return jnp.sin(x) - - @f.def_vmap - def rule(axis_size, in_batched, xs): - self.assertEqual(in_batched, [True]) - self.assertEqual(axis_size, xs.shape[0]) - return jnp.cos(xs), in_batched[0] - - x, xs = jnp.array(1.), jnp.arange(3) - self.assertAllClose(f(x), jit(f)(x)) - self.assertAllClose(jit(api.vmap(f))(xs), api.vmap(f)(xs)) - self.assertAllClose(api.vmap(jit(f))(xs), api.vmap(f)(xs)) - - def test_sequential_vmap_basic(self): - @jax.custom_batching.sequential_vmap - def f(x): - return x + 1. - - def vmap_ref(xs): - return lax.map(f, xs) - - xs = jnp.arange(3.) - jaxpr = api.make_jaxpr(api.vmap(f))(xs) - jaxpr_ref = api.make_jaxpr(vmap_ref)(xs) - - self.assertEqual(str(jaxpr), str(jaxpr_ref)) - - def test_sequential_vmap_nary_same_batching(self): - @jax.custom_batching.sequential_vmap - def f(x, y): - return x + y - - def vmap_ref(xs, ys): - return lax.map(lambda args: f(*args), (xs, ys)) - - xs, ys = jnp.arange(3.), 4. + jnp.arange(3.) 
- jaxpr = api.make_jaxpr(api.vmap(f))(xs, ys) - jaxpr_ref = api.make_jaxpr(vmap_ref)(xs, ys) - - self.assertEqual(str(jaxpr), str(jaxpr_ref)) - - def test_sequential_vmap_nary_mixed_batching(self): - @jax.custom_batching.sequential_vmap - def f(x, y): - return x + y - - def vmap_ref(xs, y): - return lax.map(lambda x: f(x, y), xs) - - xs, y = jnp.arange(3.), 4. - jaxpr = api.make_jaxpr(api.vmap(f, in_axes=(0, None)))(xs, y) - jaxpr_ref = api.make_jaxpr(vmap_ref)(xs, y) - - self.assertEqual(str(jaxpr), str(jaxpr_ref)) - - @parameterized.named_parameters( - ("1", 1), - ("8", 4), - ("12", 8), - ("16", 16), - ) - def test_batch_map_basic(self, batch_size: int): - def f(x): - self.assertEqual(x.shape, ()) - return x**2 - - x = np.arange(16) - y = jax.lax.map(f, x, batch_size=batch_size) - - np.testing.assert_array_equal(y, x**2) - - @parameterized.named_parameters( - ("1", 1), - ("8", 4), - ("12", 8), - ("16", 16), - ) - def test_batch_map_pytrees(self, batch_size: int): - f = lambda x: {'b': x['a'] ** 2} - inputs = {'a': np.arange(16)} - expected = np.arange(16) ** 2 - - outputs = jax.lax.map(f, inputs, batch_size=batch_size) - self.assertAllClose(outputs['b'], expected) - - outputs = jax.lax.map( - f, inputs, batch_size=batch_size - ) - self.assertAllClose(outputs['b'], expected) - - def test_batch_divides_axis(self): - def f(t): - x, a = t - self.assertEqual(x.shape, (4,)) - return (x + a)**2 - - x = jax.random.randint(jax.random.key(0), (16, 4), -10, 10) - a = jax.random.randint(jax.random.key(1), (16, 4), -10, 10) - - @jax.jit - def g(x, a): - return jax.lax.map(f, (x, a), batch_size=8) - - y = g(x, a) - - self.assertAllClose(y, (x + a)**2) - - def test_undefined_rule(self): - @jax.custom_batching.custom_vmap - def f(x): return jnp.sin(x) - - with self.assertRaisesRegex( - AttributeError, "No batching rule defined for custom_vmap function f"): - f(0.5) - - def test_kwargs(self): - @jax.custom_batching.custom_vmap - def f(x): return jnp.sin(x) - - @f.def_vmap - def rule(axis_size, in_batched, xs): - xs_batched, = in_batched - self.assertEqual(xs_batched, True) - self.assertEqual(axis_size, xs.shape[0]) - return jnp.cos(xs), xs_batched - - x, xs = jnp.array(1.), jnp.arange(3) - y = f(x=x) - self.assertAllClose(y, jnp.sin(x)) - ys = api.vmap(f)(x=xs) - self.assertAllClose(ys, jnp.cos(xs)) - - def test_partial_eval_raises(self): - @jax.custom_batching.custom_vmap - def f(x): - return jnp.sin(x) - - @f.def_vmap - def rule(axis_size, in_batched, xs): - del axis_size # unused - return jnp.cos(xs), in_batched[0] - - with self.assertRaisesRegex( - ValueError, - "Linearization failed to produce known values for all output primals", - ): - jax.grad(f)(0.5) - - def test_compose_custom_vjp(self): - @jax.custom_vjp - @jax.custom_batching.custom_vmap - def f(x, y): - return jnp.sin(x) * y - - @f.def_vmap - def f_vmap_rule(axis_size, in_batched, xs, ys): - return jnp.cos(xs) * ys, True - - def f_fwd(x, y): - return f(x, y), (jnp.cos(x), jnp.sin(x), y) - - def f_bwd(res, g): - cos_x, sin_x, y = res - return (cos_x * g * y, sin_x * g) - - f.defvjp(f_fwd, f_bwd) - - xs = jnp.linspace(0, 1, 5) - ys = jnp.linspace(-0.1, 0.1, 5) - self.assertAllClose(jax.vmap(f)(xs, ys), jnp.cos(xs) * ys) - jax.grad(f)(xs[0], ys[0]) # Doesn't crash. - - def test_compose_custom_vjp_bwd_rule(self): - # This tests the case where both the forward and backward rules are wrapped - # in custom_vmap. 
- @jax.custom_batching.sequential_vmap - def fun_fwd(x, y): - return jnp.sin(x) * y, (x, y) - - @jax.custom_batching.sequential_vmap - def fun_bwd(res, ct): - x, y = res - return x * ct, y * ct - - fun = jax.custom_vjp(lambda *args: fun_fwd(*args)[0]) - fun.defvjp(fun_fwd, fun_bwd) - - xs = jnp.linspace(0, 1, 5) - y = jnp.array(0.5, dtype=xs.dtype) - f = jax.vmap(jax.jit(fun), in_axes=(0, None)) - out, f_vjp = jax.vjp(f, xs, y) - f_vjp(out) # Doesn't crash. - - def test_resolve_kwargs_error_message(self): - @jax.custom_batching.custom_vmap - def f(x, y, *, z=None): - return jnp.sin(x) * y - - @f.def_vmap - def f_vmap_rule(axis_size, in_batched, xs, ys): - self.fail("should not be executed") - - with self.assertRaisesRegex( - TypeError, - r"The input arguments to the custom_vmap-decorated function f(.*)\n" - r"missing a required argument: 'y'" - ): - f(0.5) - - with self.assertRaisesRegex( - TypeError, - r"The input arguments to the custom_vmap-decorated function f(.*)\n" - "The following keyword arguments could not be resolved to positions: z" - ): - f(0.5, 0.1, z=1.0) - - -class CustomApiTest(jtu.JaxTestCase): - """Test interactions among the custom_{vmap,jvp,vjp,transpose,*} APIs""" - - def test_method_forwarding(self): - @jax.custom_batching.custom_vmap - @jax.custom_jvp - @jax.custom_transpose.custom_transpose - def f(x): return 2. * x - - # none of these err: - @f.def_vmap - def f_batch(sz, b, xs): return 2. * xs - @f.defjvp - def f_jvp(x, tx): return 2. * x, 2. * tx - @f.def_transpose - def f_transpose(x): return 2. * x - - def test_def_method_forwarding_all_permutations(self): - for wraps in it.permutations([ - jax.custom_jvp, jax.custom_transpose.custom_transpose, jax.custom_batching.custom_vmap]): - f = lambda x: x + 1. - for wrap in wraps: - f = wrap(f) - for methods in it.permutations(['defjvp', 'def_vmap', 'def_transpose']): - for method in methods: - self.assertIsInstance(getattr(f, method), Callable) - - for decorators in it.permutations([ - jax.custom_vjp, jax.custom_transpose.custom_transpose, jax.custom_batching.custom_vmap]): - f = lambda x: x + 1. - for decorator in decorators: - f = decorator(f) - for methods in it.permutations(['defvjp', 'def_vmap', 'def_transpose']): - for method in methods: - self.assertIsInstance(getattr(f, method), Callable) - - -class BufferDonationTest(jtu.BufferDonationTestCase): - - @jtu.device_supports_buffer_donation() - def test_pmap_donate_argnums_invalidates_input(self): - move = api.pmap(lambda x: x + x - x, donate_argnums=0) - n = jax.local_device_count() - x = api.pmap(lambda x: x)(jnp.ones([n])) - y = move(x) - self.assertDeleted(x) - np.testing.assert_allclose(y, [1.] * n) - - @jtu.device_supports_buffer_donation() - def test_pmap_nested_donate_ignored(self): - pmap_fun = jit(lambda x: api.pmap(lambda y: y ** 2, donate_argnums=0)(x)) - a = api.pmap(lambda x: x)(jnp.array([1])) - - # NOTE(mattjj): stopped raising error here and instead just ignored - # with self.assertRaisesRegex(ValueError, "nested.*not supported"): - # pmap_fun(a) - - pmap_fun(a) # doesn't crash - - -class NamedCallTest(jtu.JaxTestCase): - - def test_non_jaxtype_arg(self): - # For the test to fail without the invalid JaxType filter we need to pass - # in a valid JaxType that forces the invalid Jaxtype to be raised to an - # abstract value. 
- def f(not_a_jaxtype, a_jaxtype): - # then Jax needs to try and evaluate the abstractified non-JaxType - if not_a_jaxtype: - return a_jaxtype - return 0 - - f = api.named_call(f, name="test") - out = jax.jit(f, static_argnums=(0,))("not a Jaxtype", 1) - self.assertEqual(out, 1) - - @parameterized.parameters(jax.jit, jax.grad, jax.vmap, jax.remat) - def test_jax_transforms(self, transform): - f = jnp.sum - x = jnp.array([1.]) - - unnamed_out = transform(f)(x) - named_out = transform(api.named_call(f, name="test"))(x) - - self.assertEqual(unnamed_out, named_out) - - def test_static_argnums(self): - f = api.named_call(lambda x, y: y if x else None, name="test") - f = jax.jit(f, static_argnums=(0,)) - out = f(True, 5) - self.assertEqual(out, 5) + def test_static_argnums(self): + f = api.named_call(lambda x, y: y if x else None, name="test") + f = jax.jit(f, static_argnums=(0,)) + out = f(True, 5) + self.assertEqual(out, 5) def test_partial_eval(self): f = api.named_call(lambda x, y: y if x else None, name="test") @@ -11504,5 +7355,87 @@ def wsc_as_noop(ctx, operand, *args, **kwargs): self.assertNotIn("stablehlo.custom_call @Sharding", lowered_ir) +class InputSavedVJPTest(jtu.JaxTestCase): + + def test_basic(self): + def f(x, y): + return x * y + + primals = 2., 3. + y, f_vjp = api.si_vjp(f, [True, True], *primals) + arg_cts = f_vjp(1., *primals) + self.assertAllClose(y, 6.) + self.assertAllClose(arg_cts, (3., 2.)) + + def test_basic_pass_through_jit(self): + def f(x, y): + return x * y + + @jax.jit + def g(): + primals = 2., 3. + y, f_vjp = api.si_vjp(f, [True, True], *primals) + return y, f_vjp + + @jax.jit + def h(f_vjp): + return f_vjp(1., 2., 3.) + + y, f_vjp = g() + arg_cts = h(f_vjp) + self.assertAllClose(y, 6.) + self.assertAllClose(arg_cts, (3., 2.)) + + def test_basic_unused(self): + f = jnp.sin + primals = 3., + y, f_vjp = api.si_vjp(f, [True], *primals) + x_ct, = f_vjp(1., *primals) + self.assertAllClose(y, jnp.sin(3.)) + self.assertAllClose(x_ct, jnp.cos(3.)) + + with self.assertRaisesRegex(Exception, "not used by the backward pass: x"): + _ = api.si_vjp(f, [True], *primals, allow_unused=False) + + def test_basic_opaque(self): + f = jnp.sin + primals = 3., + with self.assertRaisesRegex(Exception, "the backward pass requires opaque"): + _ = api.si_vjp(f, [True], *primals, allow_opaque=False) + + def test_basic_pytree_error(self): + def f(x): + return [x['hi'] * x['bye']] + + y, f_vjp = api.si_vjp(f, [True], {'hi': 2., 'bye': 3.}) + arg_ct, = f_vjp([1.], {'hi': 2., 'bye': 3.}) + self.assertAllClose(y, [6.]) + self.assertAllClose(arg_ct, {'hi': 3., 'bye': 2.}) + + with self.assertRaisesRegex(ValueError, "but the structures differ"): + f_vjp(1., {'hi': 2.}) + + def test_fsdp(self): + # see https://github.com/jax-ml/jax/pull/27017 for why this is called "fsdp" + def f2(x, w): + x = 1. * x + x = x @ w + x = 2. * x + return x + + x = jnp.ones((3, 4)) + w = jnp.ones((4, 4)) + y, f2_sivjp = api.si_vjp(f2, [False, True], x, w) + y_grad = jnp.ones_like(y) + x_grad, w_grad = f2_sivjp(y_grad, w) + self.assertAllClose(x_grad, 2. * y_grad @ w.T) + self.assertAllClose(w_grad, 2. 
* x.T @ y_grad) + + def test_doesnt_leak_symbolic_zeros(self): + _, vjp = api.si_vjp(lambda x: 1., [False], 3.14) + ans, = vjp(1.0) + self.assertIsInstance(ans, jax.Array) + + if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/array_api_skips.txt b/tests/array_api_skips.txt index 2f8d4d1c666f..7781b93e7820 100644 --- a/tests/array_api_skips.txt +++ b/tests/array_api_skips.txt @@ -2,6 +2,7 @@ # finfo return type misalignment (https://github.com/data-apis/array-api/issues/405) array_api_tests/test_data_type_functions.py::test_finfo[float32] +array_api_tests/test_data_type_functions.py::test_finfo[complex64] # Test suite attempts in-place mutation: array_api_tests/test_array_object.py::test_setitem @@ -10,6 +11,26 @@ array_api_tests/test_creation_functions.py::test_asarray_arrays # Returns wrong zero sign array_api_tests/test_special_cases.py::test_unary[sign((x_i is -0 or x_i == +0)) -> 0] +array_api_tests/test_special_cases.py::test_iop[__imod__(x1_i is +0 and x2_i < 0) -> -0] +array_api_tests/test_special_cases.py::test_iop[__imod__(x1_i is -0 and x2_i > 0) -> +0] +array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is -0 and x2_i > 0) -> -0] +array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is -0 and x2_i < 0) -> +0] +array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is -0 and x2_i > 0) -> -0] +array_api_tests/test_special_cases.py::test_binary[__floordiv__(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0] +array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is -0 and x2_i < 0) -> +0] +array_api_tests/test_special_cases.py::test_binary[__floordiv__(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0] +array_api_tests/test_special_cases.py::test_binary[remainder(x1_i is -0 and x2_i > 0) -> +0] +array_api_tests/test_special_cases.py::test_binary[__mod__(x1_i is -0 and x2_i > 0) -> +0] +array_api_tests/test_special_cases.py::test_binary[floor_divide(isfinite(x1_i) and x1_i < 0 and x2_i is -infinity) -> +0] +array_api_tests/test_special_cases.py::test_binary[__mod__(x1_i is +0 and x2_i < 0) -> -0] +array_api_tests/test_special_cases.py::test_binary[remainder(x1_i is +0 and x2_i < 0) -> -0] +array_api_tests/test_special_cases.py::test_binary[__floordiv__(isfinite(x1_i) and x1_i < 0 and x2_i is -infinity) -> +0] +array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -0 and x2_i > 0) -> -0] +array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -0 and x2_i < 0) -> +0] +array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i < 0 and x2_i is -infinity) -> +0] + +# Array API expects default value for axis argument. 
+array_api_tests/test_indexing_functions.py::test_take_along_axis # Returns int32 when int64 is expected array_api_tests/test_searching_functions.py::test_searchsorted @@ -19,3 +40,40 @@ array_api_tests/test_operators_and_elementwise_functions.py::test_clip # JAX raises a ValueError rather than the expected IndexError for out-of-bound axis array_api_tests/test_manipulation_functions.py::test_expand_dims + +# Doesn't promote to uint64 +array_api_tests/test_statistical_functions.py::test_cumulative_prod + +# TODO(jakevdp): fix the following failures: + +# Returns NaN rather than inf +array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity] +array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is -infinity and isfinite(x2_i) and x2_i > 0) -> -infinity] +array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i < 0 and x2_i is +0) -> -infinity] +array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i > 0 and x2_i is +0) -> +infinity] +array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity] +array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i < 0 and x2_i is -0) -> +infinity] +array_api_tests/test_special_cases.py::test_binary[floor_divide(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0] +array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i > 0 and x2_i is -0) -> -infinity] +array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity] +array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i < 0 and x2_i is +0) -> -infinity] +array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity] +array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i < 0 and x2_i is -0) -> +infinity] +array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i > 0) -> -infinity] +array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity] +array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity] +array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i > 0 and x2_i is +0) -> +infinity] +array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i > 0 and x2_i is -0) -> -infinity] +array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i > 0 and x2_i is +0) -> +infinity] +array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity] +array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i > 0 and x2_i is -0) -> -infinity] +array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity] +array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i > 0) -> -infinity] +array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i < 0 and x2_i is -0) -> +infinity] +array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity] +array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i < 0 and x2_i is +0) -> -infinity] + +# Returns -1.0 rather than 0.0 
+array_api_tests/test_special_cases.py::test_binary[floor_divide(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0] +array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0] +array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0] \ No newline at end of file diff --git a/tests/array_api_test.py b/tests/array_api_test.py index 250eeb810872..8e4ba275fdd3 100644 --- a/tests/array_api_test.py +++ b/tests/array_api_test.py @@ -26,6 +26,7 @@ import jax.numpy as jnp from jax._src import config, test_util as jtu from jax._src.dtypes import _default_types, canonicalize_dtype +from jax._src import xla_bridge as xb ARRAY_API_NAMESPACE = jnp @@ -275,14 +276,18 @@ def build_dtype_dict(self, dtypes): def test_capabilities_info(self): capabilities = self.info.capabilities() - assert capabilities["boolean indexing"] + assert not capabilities["boolean indexing"] assert not capabilities["data-dependent shapes"] + assert capabilities["max dimensions"] == 64 def test_default_device_info(self): assert self.info.default_device() is None def test_devices_info(self): - assert self.info.devices() == jax.devices() + devices = set(self.info.devices()) + assert None in devices + for backend in xb.backends(): + assert devices.issuperset(jax.devices(backend)) def test_default_dtypes_info(self): _default_dtypes = { diff --git a/tests/array_extensibility_test.py b/tests/array_extensibility_test.py new file mode 100644 index 000000000000..6461cb54d73f --- /dev/null +++ b/tests/array_extensibility_test.py @@ -0,0 +1,586 @@ +# Copyright 2018 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
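+
+# This file verifies that jax.numpy APIs accept inputs implementing the
+# __jax_array__ protocol. Each function listed in NUMPY_APIS below is called
+# once with plain arrays and once with JaxArrayWrapper-wrapped arguments, and
+# the two results are checked for exact agreement; names in SKIPPED_APIS are
+# intentionally excluded from that enumeration.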
+ +import functools +from typing import Any, NamedTuple +from collections.abc import Callable + +from absl.testing import absltest +from absl.testing import parameterized +import numpy as np + +import jax +import jax.numpy as jnp +from jax.typing import ArrayLike +from jax._src import config +from jax._src import test_util as jtu + + +config.parse_flags_with_absl() + + +@functools.partial(jax.tree_util.register_dataclass, + data_fields=['x'], + meta_fields=[]) +class JaxArrayWrapper: + """Class that provides a __jax_array__ method.""" + x: ArrayLike + + def __init__(self, x: ArrayLike): + self.x = x + + def __jax_array__(self) -> jax.Array: + return jnp.asarray(self.x) + + +class DuckTypedArrayWithErroringJaxArray: + """Duck-typed array that provides a __jax_array__ method which fails.""" + shape = (2, 3) + dtype = np.dtype('float32') + + def __jax_array__(self): + raise ValueError("jax array was called.") + + +class NumPyAPI(NamedTuple): + fun: Callable[..., Any] + args: list[jax.ShapeDtypeStruct] + kwargs: dict[str, Any] + skip_on_devices: list[str] | None + + def name(self): + return self.fun.__name__ + + def make_args(self, rng): + rng = jtu.rand_default(rng) + return jax.tree.map(lambda arg: rng(arg.shape, arg.dtype), self.args) + + def with_skip_on_devices(self, disabled_devices: list[str]) -> 'NumPyAPI': + return self._replace(skip_on_devices=disabled_devices) + + @classmethod + def sig(cls, fun: Callable[..., Any], *args: Any, **kwargs: Any) -> 'NumPyAPI': + return cls(fun, args, kwargs, None) + + +class ShapeDtype: + """Shortcut for specifying ShapeDtypeStruct.""" + def __init__(self, dtype): + self.dtype = jax.dtypes.canonicalize_dtype(dtype) + def __getitem__(self, shape) -> jax.ShapeDtypeStruct: + if isinstance(shape, int): + shape = (shape,) + return jax.ShapeDtypeStruct(shape, self.dtype) + +Bool = ShapeDtype(bool) +Int = ShapeDtype(int) +UInt = ShapeDtype('uint32') +Uint8 = ShapeDtype('uint8') +Float = ShapeDtype(float) +Complex = ShapeDtype(complex) + + +# NumPy namespace objects skipped in the enumeration below, mainly because +# they are not functions or do not take arrays as positional arguments. 
+SKIPPED_APIS = [
+    'apply_along_axis',
+    'apply_over_axes',
+    'arange',
+    'array_str',
+    'array_repr',
+    'astype',
+    'bartlett',
+    'bfloat16',
+    'blackman',
+    'block',
+    'bool',
+    'bool_',
+    'broadcast_shapes',
+    'c_',
+    'can_cast',
+    'cdouble',
+    'character',
+    'complex128',
+    'complex64',
+    'complex_',
+    'complexfloating',
+    'csingle',
+    'diag_indices',
+    'double',
+    'dtype',
+    'e',
+    'einsum',
+    'einsum_path',
+    'euler_gamma',
+    'empty',
+    'eye',
+    'finfo',
+    'flexible',
+    'float_',
+    'float16',
+    'float32',
+    'float4_e2m1fn',
+    'float64',
+    'float8_e3m4',
+    'float8_e4m3',
+    'float8_e4m3b11fnuz',
+    'float8_e4m3fn',
+    'float8_e4m3fnuz',
+    'float8_e5m2',
+    'float8_e5m2fnuz',
+    'float8_e8m0fnu',
+    'floating',
+    'from_dlpack',
+    'frombuffer',
+    'fromfile',
+    'fromfunction',
+    'fromiter',
+    'frompyfunc',
+    'fromstring',
+    'full',
+    'generic',
+    'geomspace',
+    'get_printoptions',
+    'gradient',
+    'hamming',
+    'hanning',
+    'identity',
+    'iinfo',
+    'index_exp',
+    'indices',
+    'inexact',
+    'inf',
+    'int16',
+    'int2',
+    'int32',
+    'int4',
+    'int64',
+    'int8',
+    'int_',
+    'integer',
+    'isdtype',
+    'issubdtype',
+    'iterable',
+    'kaiser',
+    'kron',
+    'ix_',
+    'linalg',
+    'linspace',
+    'load',
+    'logspace',
+    'mask_indices',
+    'mgrid',
+    'nan',
+    'ndarray',
+    'newaxis',
+    'number',
+    'object_',
+    'ogrid',
+    'ones',
+    'pi',
+    'printoptions',
+    'promote_types',
+    'r_',
+    'result_type',
+    's_',
+    'save',
+    'savez',
+    'set_printoptions',
+    'signedinteger',
+    'single',
+    'tri',
+    'tril_indices',
+    'triu_indices',
+    'ufunc',
+    'uint',
+    'uint16',
+    'uint2',
+    'uint32',
+    'uint4',
+    'uint64',
+    'uint8',
+    'unsignedinteger',
+    'vectorize',
+    'zeros',
+]
+
+# TODO(jakevdp): commented APIs are ones which do not yet support
+# __jax_array__ on inputs. We should fix these!
+NUMPY_APIS = [
+    NumPyAPI.sig(jnp.abs, Float[5]),
+    NumPyAPI.sig(jnp.absolute, Float[5]),
+    NumPyAPI.sig(jnp.acos, Float[5]),
+    NumPyAPI.sig(jnp.acosh, Float[5]),
+    NumPyAPI.sig(jnp.add, Float[5], Float[5]),
+    NumPyAPI.sig(jnp.all, Bool[5]),
+    NumPyAPI.sig(jnp.allclose, Float[5], Float[5]),
+    NumPyAPI.sig(jnp.amax, Float[5]),
+    NumPyAPI.sig(jnp.amin, Float[5]),
+    NumPyAPI.sig(jnp.angle, Float[5]),
+    NumPyAPI.sig(jnp.any, Float[5]),
+    NumPyAPI.sig(jnp.append, Float[10], Float[()]),
+    NumPyAPI.sig(jnp.arccos, Float[5]),
+    NumPyAPI.sig(jnp.arccosh, Float[5]),
+    NumPyAPI.sig(jnp.arcsin, Float[5]),
+    NumPyAPI.sig(jnp.arcsinh, Float[5]),
+    NumPyAPI.sig(jnp.arctan, Float[5]),
+    NumPyAPI.sig(jnp.arctan2, Float[5], Float[5]),
+    NumPyAPI.sig(jnp.arctanh, Float[5]),
+    NumPyAPI.sig(jnp.argmax, Float[10]),
+    NumPyAPI.sig(jnp.argmin, Float[10]),
+    NumPyAPI.sig(jnp.argpartition, Float[10], kth=5),
+    NumPyAPI.sig(jnp.argsort, Float[10]),
+    NumPyAPI.sig(jnp.argwhere, Float[10]),
+    NumPyAPI.sig(jnp.around, Float[5]),
+    NumPyAPI.sig(jnp.array, Float[5]),
+    NumPyAPI.sig(jnp.array_equal, Float[5], Float[5]),
+    NumPyAPI.sig(jnp.array_equiv, Float[5], Float[5]),
+    NumPyAPI.sig(jnp.array_split, Float[9], indices_or_sections=3),
+    NumPyAPI.sig(jnp.asarray, Float[5]),
+    NumPyAPI.sig(jnp.asin, Float[5]),
+    NumPyAPI.sig(jnp.asinh, Float[5]),
+    NumPyAPI.sig(jnp.atan, Float[5]),
+    NumPyAPI.sig(jnp.atan2, Float[5], Float[5]),
+    NumPyAPI.sig(jnp.atanh, Float[5]),
+    NumPyAPI.sig(jnp.atleast_1d, Float[5]),
+    NumPyAPI.sig(jnp.atleast_2d, Float[5]),
+    NumPyAPI.sig(jnp.atleast_3d, Float[5]),
+    NumPyAPI.sig(jnp.average, Float[10]),
+    NumPyAPI.sig(jnp.bincount, Int[10]),
+    NumPyAPI.sig(jnp.bitwise_and, Int[5], Int[5]),
+    NumPyAPI.sig(jnp.bitwise_count, Int[5]),
+
NumPyAPI.sig(jnp.bitwise_invert, Int[5]), + NumPyAPI.sig(jnp.bitwise_left_shift, Int[5], Int[5]), + NumPyAPI.sig(jnp.bitwise_not, Int[5]), + NumPyAPI.sig(jnp.bitwise_or, Int[5], Int[5]), + NumPyAPI.sig(jnp.bitwise_right_shift, Int[5], Int[5]), + NumPyAPI.sig(jnp.bitwise_xor, Int[5], Int[5]), + NumPyAPI.sig(jnp.broadcast_arrays, Float[5]), + NumPyAPI.sig(jnp.broadcast_to, Float[()], shape=(10,)), + NumPyAPI.sig(jnp.cbrt, Float[5]), + NumPyAPI.sig(jnp.ceil, Float[5]), + NumPyAPI.sig(jnp.choose, Int[3], [Float[3], Float[3], Float[3]], mode='clip'), + NumPyAPI.sig(jnp.clip, Float[5]), + NumPyAPI.sig(jnp.column_stack, [Float[5], Float[5], Float[5]]), + NumPyAPI.sig(jnp.compress, Float[10], Bool[10]), + NumPyAPI.sig(jnp.concat, [Float[5], Float[5]]), + NumPyAPI.sig(jnp.concatenate, [Float[5], Float[5]]), + NumPyAPI.sig(jnp.conj, Float[5]), + NumPyAPI.sig(jnp.conjugate, Float[5]), + NumPyAPI.sig(jnp.convolve, Float[7], Float[3]), + NumPyAPI.sig(jnp.copy, Float[5]), + NumPyAPI.sig(jnp.copysign, Float[5], Float[5]), + NumPyAPI.sig(jnp.corrcoef, Float[7], Float[7]), + NumPyAPI.sig(jnp.correlate, Float[7], Float[3]), + NumPyAPI.sig(jnp.cos, Float[5]), + NumPyAPI.sig(jnp.cosh, Float[5]), + NumPyAPI.sig(jnp.count_nonzero, Float[10]), + NumPyAPI.sig(jnp.cov, Float[10]), + NumPyAPI.sig(jnp.cross, Float[3], Float[3]), + NumPyAPI.sig(jnp.cumprod, Float[5]), + NumPyAPI.sig(jnp.cumsum, Float[5]), + NumPyAPI.sig(jnp.cumulative_prod, Float[5]), + NumPyAPI.sig(jnp.cumulative_sum, Float[5]), + NumPyAPI.sig(jnp.deg2rad, Float[5]), + NumPyAPI.sig(jnp.degrees, Float[5]), + NumPyAPI.sig(jnp.delete, Float[5], Int[()]), + NumPyAPI.sig(jnp.diag, Float[5]), + NumPyAPI.sig(jnp.diag_indices_from, Float[5, 5]), + NumPyAPI.sig(jnp.diagflat, Float[5]), + NumPyAPI.sig(jnp.diagonal, Float[5, 5]), + NumPyAPI.sig(jnp.diff, Float[5]), + NumPyAPI.sig(jnp.digitize, Float[5], Float[5]), + NumPyAPI.sig(jnp.divide, Float[5], Float[5]), + NumPyAPI.sig(jnp.divmod, Float[5], Float[5]), + NumPyAPI.sig(jnp.dot, Float[5], Float[5]), + NumPyAPI.sig(jnp.dsplit, Float[3, 5, 6], indices_or_sections=2), + NumPyAPI.sig(jnp.dstack, [Float[3, 5, 1], Float[3, 5, 3]]), + NumPyAPI.sig(jnp.ediff1d, Float[5]), + NumPyAPI.sig(jnp.empty_like, Float[5]), + NumPyAPI.sig(jnp.equal, Float[5], Float[5]), + NumPyAPI.sig(jnp.exp, Float[5]), + NumPyAPI.sig(jnp.exp2, Float[5]), + NumPyAPI.sig(jnp.expand_dims, Float[5], axis=0), + NumPyAPI.sig(jnp.expm1, Float[5]), + NumPyAPI.sig(jnp.extract, Bool[5], Float[5]), + NumPyAPI.sig(jnp.fabs, Float[5]), + NumPyAPI.sig(jnp.fft.fft, Float[5]), + NumPyAPI.sig(jnp.fft.fft2, Float[5, 5]), + NumPyAPI.sig(jnp.fft.ifft, Float[5]), + NumPyAPI.sig(jnp.fft.ifft2, Float[5, 5]), + NumPyAPI.sig(jnp.fill_diagonal, Float[5, 5], Float[()], inplace=False), + NumPyAPI.sig(jnp.fix, Float[5]), + NumPyAPI.sig(jnp.flatnonzero, Float[5]), + NumPyAPI.sig(jnp.flip, Float[5]), + NumPyAPI.sig(jnp.fliplr, Float[5, 5]), + NumPyAPI.sig(jnp.flipud, Float[5, 5]), + NumPyAPI.sig(jnp.float_power, Float[5], Float[5]), + NumPyAPI.sig(jnp.floor, Float[5]), + NumPyAPI.sig(jnp.floor_divide, Float[5], Float[5]), + NumPyAPI.sig(jnp.fmax, Float[5], Float[5]), + NumPyAPI.sig(jnp.fmin, Float[5], Float[5]), + NumPyAPI.sig(jnp.fmod, Float[5], Float[5]), + NumPyAPI.sig(jnp.frexp, Float[5]), + NumPyAPI.sig(jnp.full_like, Float[5], Float[()]), + NumPyAPI.sig(jnp.gcd, Int[5], Int[5]), + NumPyAPI.sig(jnp.greater, Float[5], Float[5]), + NumPyAPI.sig(jnp.greater_equal, Float[5], Float[5]), + NumPyAPI.sig(jnp.heaviside, Float[5], Float[5]), + NumPyAPI.sig(jnp.histogram, 
Float[5]), + NumPyAPI.sig(jnp.histogram2d, Float[5], Float[5]), + NumPyAPI.sig(jnp.histogram_bin_edges, Float[5]), + NumPyAPI.sig(jnp.histogramdd, Float[5, 3]), + NumPyAPI.sig(jnp.hsplit, Float[3, 6], indices_or_sections=2), + NumPyAPI.sig(jnp.hstack, (Float[5], Float[5])), + NumPyAPI.sig(jnp.hypot, Float[5], Float[5]), + NumPyAPI.sig(jnp.i0, Float[5]), + NumPyAPI.sig(jnp.imag, Complex[5]), + NumPyAPI.sig(jnp.inner, Float[5], Float[5]), + NumPyAPI.sig(jnp.insert, Float[5], Int[()], Float[2]), + NumPyAPI.sig(jnp.interp, Float[10], Float[5], Float[5]), + NumPyAPI.sig(jnp.intersect1d, Int[5], Int[5]), + NumPyAPI.sig(jnp.invert, Int[5]), + NumPyAPI.sig(jnp.isclose, Float[5], Float[5]), + NumPyAPI.sig(jnp.iscomplex, Float[5]), + NumPyAPI.sig(jnp.iscomplexobj, Complex[5]), + NumPyAPI.sig(jnp.isfinite, Float[5]), + NumPyAPI.sig(jnp.isin, Int[5], Int[10]), + NumPyAPI.sig(jnp.isinf, Float[5]), + NumPyAPI.sig(jnp.isnan, Float[5]), + NumPyAPI.sig(jnp.isneginf, Float[5]), + NumPyAPI.sig(jnp.isposinf, Float[5]), + NumPyAPI.sig(jnp.isreal, Float[5]), + NumPyAPI.sig(jnp.isrealobj, Float[5]), + NumPyAPI.sig(jnp.isscalar, Float[()]), + NumPyAPI.sig(jnp.lcm, Int[5], Int[5]), + NumPyAPI.sig(jnp.ldexp, Float[5], Int[5]), + NumPyAPI.sig(jnp.left_shift, Int[5], Int[5]), + NumPyAPI.sig(jnp.less, Float[5], Float[5]), + NumPyAPI.sig(jnp.less_equal, Float[5], Float[5]), + NumPyAPI.sig(jnp.lexsort, [Float[5], Float[5]]), + NumPyAPI.sig(jnp.log, Float[5]), + NumPyAPI.sig(jnp.log10, Float[5]), + NumPyAPI.sig(jnp.log1p, Float[5]), + NumPyAPI.sig(jnp.log2, Float[5]), + NumPyAPI.sig(jnp.logaddexp, Float[5], Float[5]), + NumPyAPI.sig(jnp.logaddexp2, Float[5], Float[5]), + NumPyAPI.sig(jnp.logical_and, Int[5], Int[5]), + NumPyAPI.sig(jnp.logical_not, Int[5]), + NumPyAPI.sig(jnp.logical_or, Int[5], Int[5]), + NumPyAPI.sig(jnp.logical_xor, Int[5], Int[5]), + NumPyAPI.sig(jnp.matmul, Float[5, 5], Float[5]), + NumPyAPI.sig(jnp.matrix_transpose, Float[5, 6]), + NumPyAPI.sig(jnp.matvec, Float[5, 5], Float[5]), + NumPyAPI.sig(jnp.max, Float[5]), + NumPyAPI.sig(jnp.maximum, Float[5], Float[5]), + NumPyAPI.sig(jnp.mean, Float[5]), + NumPyAPI.sig(jnp.median, Float[5]), + NumPyAPI.sig(jnp.meshgrid, Float[5], Float[5]), + NumPyAPI.sig(jnp.min, Float[5]), + NumPyAPI.sig(jnp.minimum, Float[5], Float[5]), + NumPyAPI.sig(jnp.mod, Float[5], Float[5]), + NumPyAPI.sig(jnp.modf, Float[5]), + NumPyAPI.sig(jnp.moveaxis, Float[5, 3], source=0, destination=1), + NumPyAPI.sig(jnp.multiply, Float[5], Float[5]), + NumPyAPI.sig(jnp.nan_to_num, Float[5]), + NumPyAPI.sig(jnp.nanargmax, Float[5]), + NumPyAPI.sig(jnp.nanargmin, Float[5]), + NumPyAPI.sig(jnp.nancumprod, Float[5]), + NumPyAPI.sig(jnp.nancumsum, Float[5]), + NumPyAPI.sig(jnp.nanmax, Float[5]), + NumPyAPI.sig(jnp.nanmean, Float[5]), + NumPyAPI.sig(jnp.nanmedian, Float[5]), + NumPyAPI.sig(jnp.nanmin, Float[5]), + NumPyAPI.sig(jnp.nanpercentile, Float[5], q=75), + NumPyAPI.sig(jnp.nanprod, Float[5]), + NumPyAPI.sig(jnp.nanquantile, Float[5], q=0.75), + NumPyAPI.sig(jnp.nanstd, Float[5]), + NumPyAPI.sig(jnp.nansum, Float[5]), + NumPyAPI.sig(jnp.nanvar, Float[5]), + NumPyAPI.sig(jnp.ndim, Float[5]), + NumPyAPI.sig(jnp.negative, Float[5]), + NumPyAPI.sig(jnp.nextafter, Float[5], Float[5]), + NumPyAPI.sig(jnp.nonzero, Float[5]), + NumPyAPI.sig(jnp.not_equal, Float[5], Float[5]), + NumPyAPI.sig(jnp.ones_like, Float[5]), + NumPyAPI.sig(jnp.outer, Float[5], Float[5]), + NumPyAPI.sig(jnp.packbits, Int[5]), + NumPyAPI.sig(jnp.pad, Float[5], pad_width=2), + NumPyAPI.sig(jnp.partition, Float[5], kth=3), 
+ NumPyAPI.sig(jnp.percentile, Float[5], q=75), + NumPyAPI.sig(jnp.permute_dims, Float[3, 5], axes=(1, 0)), + NumPyAPI.sig(jnp.piecewise, Float[5], [Bool[5], Bool[5]], funclist=[jnp.sin, jnp.cos]), + NumPyAPI.sig(jnp.place, Float[5], Bool[5], Float[3], inplace=False), + NumPyAPI.sig(jnp.poly, Float[5]), + NumPyAPI.sig(jnp.polyadd, Float[5], Float[5]), + NumPyAPI.sig(jnp.polyder, Float[5]), + NumPyAPI.sig(jnp.polydiv, Float[5], Float[5]), + NumPyAPI.sig(jnp.polyfit, Float[5], Float[5], deg=2), + NumPyAPI.sig(jnp.polyint, Float[5]), + NumPyAPI.sig(jnp.polymul, Float[5], Float[5]), + NumPyAPI.sig(jnp.polysub, Float[5], Float[5]), + NumPyAPI.sig(jnp.polyval, Float[5], Float[10]), + NumPyAPI.sig(jnp.positive, Float[5]), + NumPyAPI.sig(jnp.pow, Float[5], Float[5]), + NumPyAPI.sig(jnp.power, Float[5], Float[5]), + NumPyAPI.sig(jnp.prod, Float[5]), + NumPyAPI.sig(jnp.ptp, Float[5]), + NumPyAPI.sig(jnp.put, Float[5], Int[()], Float[()], inplace=False), + NumPyAPI.sig(jnp.put_along_axis, Float[5], Int[1], Float[1], axis=0, inplace=False), + NumPyAPI.sig(jnp.quantile, Float[5], q=0.75), + NumPyAPI.sig(jnp.rad2deg, Float[5]), + NumPyAPI.sig(jnp.radians, Float[5]), + NumPyAPI.sig(jnp.ravel, Float[5]), + NumPyAPI.sig(jnp.ravel_multi_index, [Uint8[5], Uint8[5]], dims=(8, 9)), + NumPyAPI.sig(jnp.real, Complex[5]), + NumPyAPI.sig(jnp.reciprocal, Float[5]), + NumPyAPI.sig(jnp.remainder, Float[5], Float[5]), + NumPyAPI.sig(jnp.repeat, Float[5], repeats=np.array([2, 3, 1, 5, 4])), + NumPyAPI.sig(jnp.reshape, Float[6], shape=(2, 3)), + NumPyAPI.sig(jnp.resize, Float[6], new_shape=(2, 3)), + NumPyAPI.sig(jnp.right_shift, Int[5], Int[5]), + NumPyAPI.sig(jnp.rint, Float[5]), + NumPyAPI.sig(jnp.roll, Float[5], Int[1]), + NumPyAPI.sig(jnp.rollaxis, Float[5, 4], axis=1), + NumPyAPI.sig(jnp.roots, Float[5]).with_skip_on_devices(['tpu']), + NumPyAPI.sig(jnp.rot90, Float[5, 3]), + NumPyAPI.sig(jnp.round, Float[5]), + NumPyAPI.sig(jnp.searchsorted, Float[5], Float[5]), + NumPyAPI.sig(jnp.select, [Bool[5], Bool[5]], [Float[5], Float[5]], Float[()]), + NumPyAPI.sig(jnp.setdiff1d, Int[5], Int[5]), + NumPyAPI.sig(jnp.setxor1d, Int[5], Int[5]), + NumPyAPI.sig(jnp.shape, Float[5, 3]), + NumPyAPI.sig(jnp.sign, Float[5]), + NumPyAPI.sig(jnp.signbit, Float[5]), + NumPyAPI.sig(jnp.sin, Float[5]), + NumPyAPI.sig(jnp.sinc, Float[5]), + NumPyAPI.sig(jnp.sinh, Float[5]), + NumPyAPI.sig(jnp.size, Float[5]), + NumPyAPI.sig(jnp.sort, Float[5]), + NumPyAPI.sig(jnp.sort_complex, Complex[5]), + NumPyAPI.sig(jnp.spacing, Float[5]), + NumPyAPI.sig(jnp.split, Float[6], indices_or_sections=2), + NumPyAPI.sig(jnp.sqrt, Float[5]), + NumPyAPI.sig(jnp.square, Float[5]), + NumPyAPI.sig(jnp.squeeze, Float[5]), + NumPyAPI.sig(jnp.stack, [Float[2, 3], Float[2, 3]], axis=1), + NumPyAPI.sig(jnp.std, Float[5]), + NumPyAPI.sig(jnp.subtract, Float[5], Float[5]), + NumPyAPI.sig(jnp.sum, Float[5]), + NumPyAPI.sig(jnp.swapaxes, Float[3, 5], axis1=1, axis2=0), + NumPyAPI.sig(jnp.take, Float[5], Int[2]), + NumPyAPI.sig(jnp.take_along_axis, Float[5], Int[2], axis=0), + NumPyAPI.sig(jnp.tan, Float[5]), + NumPyAPI.sig(jnp.tanh, Float[5]), + NumPyAPI.sig(jnp.tensordot, Float[2, 3, 4], Float[3, 4, 5]), + NumPyAPI.sig(jnp.tile, Float[5], reps=(2,)), + NumPyAPI.sig(jnp.trace, Float[5, 5]), + NumPyAPI.sig(jnp.transpose, Float[5, 6]), + NumPyAPI.sig(jnp.trapezoid, Float[5]), + NumPyAPI.sig(jnp.tril, Float[5, 6]), + NumPyAPI.sig(jnp.tril_indices_from, Float[5, 6]), + NumPyAPI.sig(jnp.trim_zeros, Float[5]), + NumPyAPI.sig(jnp.triu, Float[5, 6]), + 
NumPyAPI.sig(jnp.triu_indices_from, Float[5, 6]), + NumPyAPI.sig(jnp.true_divide, Float[5], Float[5]), + NumPyAPI.sig(jnp.trunc, Float[5]), + NumPyAPI.sig(jnp.union1d, Int[5], Int[5]), + NumPyAPI.sig(jnp.unique, Int[10]), + NumPyAPI.sig(jnp.unique_all, Int[10]), + NumPyAPI.sig(jnp.unique_counts, Int[10]), + NumPyAPI.sig(jnp.unique_inverse, Int[10]), + NumPyAPI.sig(jnp.unique_values, Int[10]), + NumPyAPI.sig(jnp.unpackbits, Uint8[8]), + NumPyAPI.sig(jnp.unravel_index, Int[5], shape=(2, 3)), + NumPyAPI.sig(jnp.unstack, Float[5]), + NumPyAPI.sig(jnp.unwrap, Float[5]), + NumPyAPI.sig(jnp.vander, Float[5]), + NumPyAPI.sig(jnp.var, Float[5]), + NumPyAPI.sig(jnp.vdot, Float[5], Float[5]), + NumPyAPI.sig(jnp.vecdot, Float[5], Float[5]), + NumPyAPI.sig(jnp.vecmat, Float[5], Float[5, 3]), + NumPyAPI.sig(jnp.vsplit, Float[6], indices_or_sections=2), + NumPyAPI.sig(jnp.vstack, [Float[5], Float[2, 5]]), + NumPyAPI.sig(jnp.where, Bool[5], Float[5], Float[5]), + NumPyAPI.sig(jnp.zeros_like, Float[5]), +] + + +class JaxArrayTests(jtu.JaxTestCase): + @parameterized.named_parameters( + {'testcase_name': api.name(), 'api': api} for api in NUMPY_APIS) + def test_numpy_api_supports_jax_array(self, api): + if api.skip_on_devices and jtu.test_device_matches(api.skip_on_devices): + self.skipTest(f'{api.name()} not supported on {api.skip_on_devices}') + fun = api.fun + args = api.make_args(self.rng()) + wrapped_args = jax.tree.map(JaxArrayWrapper, args) + kwargs = api.kwargs + + expected = fun(*args, **kwargs) + wrapped = fun(*wrapped_args, **kwargs) + + self.assertAllClose(wrapped, expected, atol=0, rtol=0) + + @parameterized.named_parameters( + {'testcase_name': func.__name__, 'func': func} + for func in [jnp.zeros_like, jnp.ones_like, jnp.empty_like, jnp.full_like] + ) + def test_array_creation_from_duck_typed_array(self, func): + # Ensure that jnp.*_like prefers shape/dtype over __jax_array__ when + # both methods are available. 
+ if func is jnp.full_like: + func = functools.partial(func, fill_value=2.0) + obj = DuckTypedArrayWithErroringJaxArray() + + # The test relies on this failing + with self.assertRaises(ValueError): + jnp.asarray(obj) + + result = func(obj) + self.assertIsInstance(result, jax.Array) + self.assertEqual(result.shape, obj.shape) + self.assertEqual(result.dtype, obj.dtype) + + @parameterized.named_parameters( + {"testcase_name": "subscript-form", "args": ("jk,k->j", Float[5, 3], Float[3])}, + {"testcase_name": "index-form", "args": (Float[5, 3], (0, 1), Float[3], (1,), (0,))}, + ) + def test_einsum(self, args): + rng = jtu.rand_default(self.rng()) + def make_arg(arg): + if isinstance(arg, jax.ShapeDtypeStruct): + return rng(arg.shape, arg.dtype) + return arg + args = jax.tree.map(make_arg, args) + + def wrap_array(arg): + if isinstance(arg, (jax.Array, np.ndarray)): + return JaxArrayWrapper(arg) + return arg + wrapped_args = jax.tree.map(wrap_array, args) + + expected = jnp.einsum(*args) + actual = jnp.einsum(*wrapped_args) + + self.assertAllClose(actual, expected, atol=0, rtol=0) + + +@jtu.with_config(jax_disable_jit=True) +class JaxArrayTestsNoJit(JaxArrayTests): + pass + + +if __name__ == "__main__": + absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/array_interoperability_test.py b/tests/array_interoperability_test.py index 80a4d8ef5a25..a61c19ab4e0b 100644 --- a/tests/array_interoperability_test.py +++ b/tests/array_interoperability_test.py @@ -95,6 +95,10 @@ def setUp(self): message="Calling from_dlpack with a DLPack tensor", category=DeprecationWarning, ) + @jtu.ignore_warning( + message="jax.dlpack.to_dlpack was deprecated.*", + category=DeprecationWarning, + ) def testJaxRoundTrip(self, shape, dtype, copy, use_stream): rng = jtu.rand_default(self.rng()) np = rng(shape, dtype) @@ -107,35 +111,13 @@ def _check_copy(x: jax.Array, y: jax.Array, expect_copy): x = jax.device_put(np, jax.devices("cpu")[0]) device = jax.devices("gpu")[0] y = jax.device_put(x, device) - dl_device = y.__dlpack_device__() - if use_stream: - stream = tuple(y.devices())[0].get_stream_for_external_ready_events() - dlpack = jax.dlpack.to_dlpack(y, copy=copy, stream=stream) - else: - dlpack = jax.dlpack.to_dlpack(y, copy=copy) - z = jax.dlpack.from_dlpack(dlpack) + # TODO(parkers): Remove after setting 'stream' properly below. + jax.block_until_ready(y) + z = jax.dlpack.from_dlpack(y) self.assertEqual(z.devices(), {device}) self.assertAllClose(np.astype(x.dtype), z) - self.assertRaisesRegex(RuntimeError, - "DLPack tensor may be consumed at most once", - lambda: jax.dlpack.from_dlpack(dlpack)) - - if shape in nonempty_array_shapes: - _check_copy(y, z, bool(copy)) - # Check if the destination device can be specified - make_dlpack = lambda: x.__dlpack__(dl_device=dl_device, copy=copy) - if copy == False: - self.assertRaisesRegex(ValueError, "copy=False", make_dlpack) - return - - z = jax.dlpack.from_dlpack(make_dlpack()) - self.assertEqual(z.devices(), {device}) - self.assertAllClose(x, z) - - if shape in nonempty_array_shapes: - _check_copy(x, z, True) @jtu.sample_product( shape=all_shapes, @@ -149,6 +131,8 @@ def testJaxArrayRoundTrip(self, shape, dtype, gpu): raise unittest.SkipTest("Skipping GPU test case on CPU") device = jax.devices("gpu" if gpu else "cpu")[0] x = jax.device_put(np, device) + # TODO(parkers): Remove after setting 'stream' properly. 
+ jax.block_until_ready(x) y = jax.dlpack.from_dlpack(x) self.assertEqual(y.devices(), {device}) self.assertAllClose(np.astype(x.dtype), y) @@ -188,6 +172,10 @@ def testTensorFlowToJax(self, shape, dtype): dtype=dlpack_dtypes, ) @unittest.skipIf(not tf, "Test requires TensorFlow") + @jtu.ignore_warning( + message="jax.dlpack.to_dlpack was deprecated.*", + category=DeprecationWarning, + ) def testJaxToTensorFlow(self, shape, dtype): if (not config.enable_x64.value and dtype in [jnp.int64, jnp.uint64, jnp.float64]): @@ -198,11 +186,12 @@ def testJaxToTensorFlow(self, shape, dtype): rng = jtu.rand_default(self.rng()) np = rng(shape, dtype) x = jnp.array(np) + # TODO(parkers): Remove after setting 'stream' properly. + jax.block_until_ready(x) # TODO(b/171320191): this line works around a missing context initialization # bug in TensorFlow. _ = tf.add(1, 1) - dlpack = jax.dlpack.to_dlpack(x) - y = tf.experimental.dlpack.from_dlpack(dlpack) + y = tf.experimental.dlpack.from_dlpack(x.__dlpack__()) self.assertAllClose(np, y.numpy()) @unittest.skipIf(not tf, "Test requires TensorFlow") @@ -319,6 +308,8 @@ def testJaxToCuPy(self, shape, dtype): rng = jtu.rand_default(self.rng()) x = rng(shape, dtype) y = jnp.array(x) + # TODO(parkers): Remove after setting 'stream' properly. + jax.block_until_ready(y) z = cupy.asarray(y) self.assertEqual(y.__cuda_array_interface__["data"][0], z.__cuda_array_interface__["data"][0]) @@ -354,6 +345,8 @@ def testCaiToJax(self, shape, dtype): device = jax.devices('cuda')[-1] with jax.default_device(device): y = jnp.array(x, dtype=dtype) + # TODO(parkers): Remove after setting 'stream' properly below. + jax.block_until_ready(y) self.assertEqual(y.dtype, dtype) # Using a jax array CAI provider support to construct an object diff --git a/tests/array_test.py b/tests/array_test.py index cc8990828ded..b98f25abca7a 100644 --- a/tests/array_test.py +++ b/tests/array_test.py @@ -31,12 +31,11 @@ from jax._src.lib import xla_client as xc from jax._src.lib.mlir import dialects, ir from jax._src.util import safe_zip -from jax._src.mesh import AxisType +from jax._src.mesh import AxisType, AbstractMesh from jax._src.sharding import common_devices_indices_map from jax._src.sharding_impls import ( - _op_sharding_to_pos_sharding, pmap_sharding_devices_indices_map, - NamedSharding, GSPMDSharding, PositionalSharding, SdyDimSharding, - SdyArraySharding) + pmap_sharding_devices_indices_map, NamedSharding, GSPMDSharding, SdyDim, + SdyArray) from jax.experimental.pjit import pjit from jax.experimental import multihost_utils from jax.sharding import PartitionSpec as P @@ -368,8 +367,6 @@ def test_different_devices_in_arrays_than_sharding(self): array.ArrayImpl(core.ShapedArray(shape, np.float32), s, bufs, committed=True) def test_duplicated_devices_in_arrays(self): - if xc._version <= 274: - self.skipTest('Test requires jaxlib version 275') shape = (8, 2) mesh = jtu.create_mesh((1, 2), ('x', 'y')) # Sharding device ids = {0, 1} @@ -657,12 +654,15 @@ def f(x): output_shardings._to_xla_hlo_sharding(x_dummy.ndim), s._to_xla_hlo_sharding(x_dummy.ndim))) - # TODO(skyewm): remove this test when we can remove the workaround manual - # defragment API - @jtu.skip_on_devices('cpu') # defragment not implemented for TFRT CPU + # TODO(b/399879011): GPU is the only platform that has an implementation for + # this, which exists in py_client.cc. Ideally, this would be replaced with + # some kind of auto-defrag-on-OOM. 
+ @jtu.run_on_devices('gpu') def test_defragment(self): + # Since the GPU implementation is in py_client.cc, it cannot be exposed via + # the PjRt C API. if xb.using_pjrt_c_api(): - self.skipTest("Manual defragment not exposed via PJRT C API") + self.skipTest('Manual defragment not exposed via PJRT C API') # Create a few arrays global_mesh = jtu.create_mesh((jax.local_device_count(),), ('x',)) @@ -675,7 +675,7 @@ def test_defragment(self): # Delete one of them arr2.delete() - # Defragment + # Defragment. xb.get_backend().defragment() # Sanity check remaining arrays @@ -710,6 +710,13 @@ def test_process_allgather_single_host(self): self.assertEqual(out.shape, (1, x.shape[0])) self.assertArraysEqual(out, np.expand_dims(x, axis=0)) + def test_broadcast_one_to_all_single_host(self): + x = jnp.arange(8, dtype=jnp.uint8) + out = multihost_utils.broadcast_one_to_all(x) + self.assertEqual(out.shape, x.shape) + self.assertEqual(out.dtype, x.dtype) + self.assertArraysEqual(out, x) + @jtu.sample_product( dtype=jtu.dtypes.all, shape=[(), (10), (2, 3)], @@ -897,7 +904,7 @@ def test_op_sharding_indices(self, pspec): shape = (8, 4) mesh = jtu.create_mesh((4, 2), ('x', 'y')) mps = jax.sharding.NamedSharding(mesh, pspec) - ops = jax.sharding.GSPMDSharding( + ops = GSPMDSharding( list(mesh.devices.flat), mps._to_xla_hlo_sharding(len(shape))) self.assertDictEqual( ops.devices_indices_map(shape), mps.devices_indices_map(shape)) @@ -973,7 +980,7 @@ def test_gspmd_sharding_repr(self): op.tile_assignment_dimensions = [4, 1, 2] op.tile_assignment_devices = [0, 1, 2, 3, 4, 5, 6, 7] op.replicate_on_last_tile_dim = True - s = jax.sharding.GSPMDSharding(jax.devices(), op) + s = GSPMDSharding(jax.devices(), op) # memory kind also appears in the repr but only for TPU. self.assertIn( 'GSPMDSharding({devices=[4,1,2]0,1,2,3,4,5,6,7 ' @@ -981,93 +988,10 @@ def test_gspmd_sharding_repr(self): op2 = xc.OpSharding() op2.type = xc.OpSharding.Type.REPLICATED - s2 = jax.sharding.GSPMDSharding(jax.devices(), op2) + s2 = GSPMDSharding(jax.devices(), op2) # memory kind also appears in the repr but only for TPU. 
self.assertIn('GSPMDSharding({replicated}', repr(s2)) - def test_positional_sharding_fully_replicated(self): - sharding = PositionalSharding(jax.devices()) - jax.device_put(jnp.array(1), sharding.replicate()) # doesn't crash - - @parameterized.named_parameters( - ("mesh_x_y", P("x", "y"), (4, 2), (), False), - ("mesh_x", P("x"), (4, 2), (1,), False), - ("mesh_y", P("y"), (4, 2), (0,), True), - ("mesh_none_y", P(None, "y"), (4, 2), (0,), False), - ("mesh_none_x", P(None, "x"), (4, 2), (1,), True), - ("mesh_xy", P(("x", "y")), (8, 1), (), False), - ("mesh_fully_replicated", P(), (4, 2), None, False), - ) - def test_positional_sharding_op_sharding_lowering( - self, pspec, shape, axes, transpose): - value_shape = (8, 4) - - mesh = jtu.create_mesh((4, 2), ('x', 'y')) - mps = jax.sharding.NamedSharding(mesh, pspec) - devices = jax.local_devices()[:8] # Taking up to 8 devices - - devices_sharding = jax.sharding.PositionalSharding(devices) - devices_sharding = devices_sharding.reshape(shape).replicate(axes) - if transpose: - devices_sharding = devices_sharding.T - - op1 = mps._to_xla_hlo_sharding(len(value_shape)) - op2 = devices_sharding._to_xla_hlo_sharding(len(value_shape)) - - self.assertEqual(mps.shard_shape(value_shape), - devices_sharding.shard_shape(value_shape)) - self.assertTrue(op_shardings.are_op_shardings_equal(op1, op2)) - - def test_positional_sharding_aval_compatible(self): - if jax.device_count() < 2: - self.skipTest('Requires >=2 devices') - sharding = PositionalSharding(jax.devices()).reshape(1, jax.device_count()) - x = jax.random.uniform(jax.random.key(42), (256, 20, 1000)) - with self.assertRaisesRegex( - ValueError, - 'Sharding PositionalSharding.*is only valid for values of rank 2, but' - ' was applied to a value of rank 3'): - jax.lax.with_sharding_constraint(x, sharding) - - @parameterized.named_parameters( - ("2d_mesh_x_y", (4, 2), P("x", "y")), - ("2d_mesh_x", (4, 2), P("x")), - ("2d_mesh_y", (4, 2), P("y")), - ("2d_mesh_none_y", (4, 2), P(None, "y")), - ("2d_mesh_none_x", (4, 2), P(None, "x")), - ("2d_mesh_xy", (4, 2), P(("x", "y"))), - ("2d_mesh_none_xy", (4, 2), P(None, ("x", "y"))), - ("2d_mesh_x_none", (2, 1), P(('x',), None)), - ("2d_mesh_fully_replicated", (4, 2), P()), - ("3d_mesh_none_none_z", (2, 2, 2), P(None, None, 'z')), - ("3d_mesh_none_y_none", (2, 2, 2), P(None, 'y', None)), - ("3d_mesh_x_y_none", (2, 2, 2), P('x', 'y', None)), - ("3d_mesh_none_yz", (2, 2, 2), P(None, ('y', 'z'))), - ("3d_mesh_x_none_yz", (2, 2, 2), P('x', None, ('y', 'z'))), - ("3d_mesh_none_x_yz", (2, 2, 2), P(None, 'x', ('y', 'z'))), - ("3d_mesh_xy_z", (2, 2, 2), P(('x', 'y'), 'z')), - ("3d_mesh_xy_none_z", (2, 2, 2), P(('x', 'y'), None, 'z')), - ("3d_mesh_x_y_z", (2, 2, 2), P('x', 'y', 'z')), - ("3d_mesh_xz_y", (2, 2, 2), P(('x', 'z'), 'y')), - ("3d_mesh_xz_none_y", (2, 2, 2), P(('x', 'z'), None, 'y')), - ("3d_mesh_y_none_xz", (2, 2, 2), P('y', None, ('x', 'z'))), - ("3d_mesh_none_y_xz", (2, 2, 2), P(None, 'y', ('x', 'z'))), - ("3d_mesh2_none_none_z", (1, 2, 4), P(None, None, 'z')), - ("3d_mesh2_x_none_none", (1, 2, 4), P('x', None, None)), - ("3d_mesh_x_none_none", (2, 1, 1), P('x', None, None)), - ) - def test_positional_sharding_from_op_sharding(self, mesh_shape, pspec): - ndim = len(mesh_shape) - mesh = jtu.create_mesh( - mesh_shape, ('x', 'y') if ndim == 2 else ('x', 'y', 'z')) - mps = jax.sharding.NamedSharding(mesh, pspec) - original_op_sharding = mps._to_xla_hlo_sharding(ndim) - ps = _op_sharding_to_pos_sharding(original_op_sharding, - mps._device_assignment) - 
out_op_sharding = ps._to_xla_hlo_sharding(ndim) - self.assertTrue(op_shardings.are_op_shardings_equal( - original_op_sharding, out_op_sharding)) - @parameterized.named_parameters( ("2d_mesh_x", (1, 1), P("x", "y")), ("2d_mesh_x_y", (4, 2), P("x", "y")), @@ -1097,26 +1021,6 @@ def test_is_fully_replicated_named_sharding(self, mesh_shape, pspec): ops_ifr = op_shardings.is_op_sharding_replicated(mps_op_sharding) self.assertEqual(mps.is_fully_replicated, ops_ifr) - ps = _op_sharding_to_pos_sharding(mps_op_sharding, mps._device_assignment) - self.assertEqual(ps.is_fully_replicated, - op_shardings.is_op_sharding_replicated( - ps._to_xla_hlo_sharding(len(shape)))) - - def test_devices_sharding_respects_init_mesh_shape(self): - value_shape = (8, 4) - - mesh = jtu.create_mesh((4, 2), ('x', 'y')) - mps = jax.sharding.NamedSharding(mesh, P('x', 'y')) - - devices_sharding = jax.sharding.PositionalSharding(mesh.devices) - - op1 = mps._to_xla_hlo_sharding(len(value_shape)) - op2 = devices_sharding._to_xla_hlo_sharding(len(value_shape)) - - self.assertEqual(mps.shard_shape(value_shape), - devices_sharding.shard_shape(value_shape)) - self.assertTrue(op_shardings.are_op_shardings_equal(op1, op2)) - def test_pmap_sharding_repr(self): if jax.device_count() < 2: self.skipTest('Test needs >= 2 devices.') @@ -1124,13 +1028,6 @@ def test_pmap_sharding_repr(self): str(out.sharding) # doesn't crash repr(out.sharding) # doesn't crash - def test_positional_sharding_repr(self): - if jax.device_count() < 2: - self.skipTest('Test needs >= 2 devices.') - s = jax.sharding.PositionalSharding(jax.devices()).reshape(jax.device_count(), 1) - repr(s) # doesn't crash - str(s) # doesn't crash - def test_pspec_tuple(self): pspec = P('x', 'y', 'z') self.assertEqual(pspec, ('x', 'y', 'z')) @@ -1198,9 +1095,9 @@ def test_are_shardings_equivalent(self): op1 = xc.OpSharding() op1.type = xc.OpSharding.Type.REPLICATED - s6 = jax.sharding.GSPMDSharding([jax.devices()[0]], op1) + s6 = GSPMDSharding([jax.devices()[0]], op1) - s7 = jax.sharding.GSPMDSharding(jax.devices(), op1) + s7 = GSPMDSharding(jax.devices(), op1) # The OpSharding is replicated but the Sharding itself are on different # devices. 
@@ -1210,7 +1107,7 @@ def test_are_shardings_equivalent(self): op2.type = xc.OpSharding.Type.OTHER op2.tile_assignment_devices = [0, 1] op2.tile_assignment_dimensions = [2, 1] - s8 = jax.sharding.GSPMDSharding(list(mesh2.devices.flat), op2) + s8 = GSPMDSharding(list(mesh2.devices.flat), op2) self.assertTrue(s1.is_equivalent_to(s6, 2)) self.assertTrue(s5.is_equivalent_to(s8, 2)) @@ -1223,7 +1120,7 @@ def test_are_shardings_equivalent(self): op3.tile_assignment_devices = [0, 1] op3.tile_assignment_dimensions = [1, 1, 2] op3.replicate_on_last_tile_dim = True - s10 = jax.sharding.GSPMDSharding(list(mesh2.devices.flat), op3) + s10 = GSPMDSharding(list(mesh2.devices.flat), op3) self.assertTrue(s9.is_equivalent_to(s10, 2)) @@ -1301,6 +1198,18 @@ def f(x): with self.assertRaisesRegex(TypeError, msg): jax.jit(f)(x) + def test_make_array_from_single_device_arrays_tuple(self): + mesh = jtu.create_mesh((2, 2), ('x', 'y')) + shape = (8, 8) + s = jax.sharding.NamedSharding(mesh, P('x', 'y')) + inp_data = np.arange(math.prod(shape)).reshape(shape) + + arrays = tuple( + jax.device_put(inp_data[index], d) + for d, index in s.addressable_devices_indices_map(shape).items()) + + jax.make_array_from_single_device_arrays(shape, s, arrays) # doesn't crash + def test_make_array_from_single_device_arrays_bad_inputs(self): x = jnp.arange(10) mesh = jtu.create_mesh((2,), ('x',)) @@ -1363,6 +1272,16 @@ def test_mesh_axis_types_mismatch(self): jax.sharding.AbstractMesh((2, 1), ('x', 'y'), axis_types=jax.sharding.AxisType.Auto) + with self.assertRaisesRegex(TypeError, "axis_types.*must be of type"): + AbstractMesh((2,), ('x',), axis_types=("explicit",)) + + with self.assertRaisesRegex(TypeError, "axis_types.*must be of type"): + AbstractMesh((2,), ('x',), axis_types="explicit") + + with self.assertRaisesRegex(TypeError, "axis_types.*must be of type"): + AbstractMesh((2, 2), ('x', 'y'), + axis_types=("explicit", AxisType.Explicit)) + def test_make_mesh_axis_types(self): Auto, Explicit, Manual = AxisType.Auto, AxisType.Explicit, AxisType.Manual @@ -1378,6 +1297,9 @@ def test_make_mesh_axis_types(self): self.assertDictEqual( mesh._axis_types_dict, {AxisType.Auto: ('y',), AxisType.Explicit: ('x',), AxisType.Manual: ('z',)}) + self.assertEqual(mesh.explicit_axes, ('x',)) + self.assertEqual(mesh.auto_axes, ('y',)) + self.assertEqual(mesh.manual_axes, ('z',)) mesh = jax.make_mesh((1, 1, 1), ('x', 'y', 'z'), axis_types=(Explicit, Explicit, Manual)) @@ -1402,6 +1324,129 @@ def test_make_mesh_axis_types(self): self.assertNotEqual(mesh1, mesh2) self.assertNotEqual(hash(mesh1), hash(mesh2)) + def test_memory_kind_with_abstract_mesh(self): + abstract_mesh = AbstractMesh((2,), ('x',)) + ns = NamedSharding(abstract_mesh, P(), memory_kind='pinned_host') + self.assertEqual(ns.memory_kind, 'pinned_host') + + ns = NamedSharding(abstract_mesh, P()) + self.assertIsNone(ns.memory_kind) + + with self.assertRaisesRegex( + ValueError, 'Got invalid memory kind'): + NamedSharding(abstract_mesh, P(), memory_kind='weird_device') + + def test_pspec_unreduced(self): + pspec = P('a', 'b', None, unreduced={'c'}, reduced={'d'}) + self.assertEqual( + repr(pspec), + "PartitionSpec('a', 'b', None, unreduced={'c'}, reduced={'d'})") + + pspec1 = P('a', 'b', None, unreduced={'c'}) + self.assertEqual(repr(pspec1), + "PartitionSpec('a', 'b', None, unreduced={'c'})") + + pspec2 = P('a', 'b', None, unreduced={'c'}) + self.assertEqual(pspec1, pspec2) + + pspec3 = P('a', 'b', None, unreduced={'d'}) + self.assertNotEqual(pspec1, pspec3) + + out = P('x', 
unreduced={'z'}) + P('a', unreduced={'b'}) + self.assertEqual(out, P('x', 'a', unreduced={'z', 'b'})) + + pspec4 = P('x', unreduced={'y'}) + self.assertEqual(repr(pspec4), + "PartitionSpec('x', unreduced={'y'})") + + pspec5 = P(None, None, unreduced={'x'}) + self.assertEqual(repr(pspec5), + "PartitionSpec(None, None, unreduced={'x'})") + + pspec6 = P(None, unreduced={'x'}) + self.assertEqual(repr(pspec6), "PartitionSpec(None, unreduced={'x'})") + + pspec7 = P(unreduced={'x'}) + self.assertEqual(repr(pspec7), "PartitionSpec(unreduced={'x'})") + + with self.assertRaisesRegex( + TypeError, 'unreduced in `__add__` of PartitionSpec'): + P('x', unreduced={'z'}) + (None,) * 2 + + with self.assertRaisesRegex( + TypeError, "unreduced in `__radd__` of PartitionSpec"): + (None,) * 2 + P('x', unreduced={'y'}) + + with self.assertRaisesRegex( + ValueError, "partitions cannot overlap with unreduced"): + P('x', 'y', unreduced={'x'}) + + with self.assertRaisesRegex( + ValueError, "partitions cannot overlap with unreduced"): + P('x', None, 'y', unreduced={'z', 'y'}) + + def test_named_sharding_unreduced_error(self): + mesh = jtu.create_mesh((1, 1, 1), ('x', 'y', 'z')) + + with self.assertRaisesRegex( + ValueError, "Unreduced axes.*not found in mesh.*"): + NamedSharding(mesh, P('x', unreduced={'a'})) + + with self.assertRaisesRegex( + ValueError, "Unreduced axes can only refer to mesh axes.*Explicit"): + NamedSharding(mesh, P('x', unreduced={'y', 'z'})) + + with self.assertRaisesRegex( + ValueError, "unreduced cannot contain None.*"): + NamedSharding(mesh, P('x', unreduced={'y', None})) + + def test_hlo_sharding_get_axis_sizes(self): + op = xc.OpSharding() + op.type = xc.OpSharding.Type.OTHER + op.tile_assignment_dimensions = [6, 35] + op.iota_reshape_dims = [7, 10, 3] + op.iota_transpose_perm = [2, 1, 0] + s = GSPMDSharding(jax.devices(), op) + self.assertIn('{devices=[6,35]<=[7,10,3]T(2,1,0)}', repr(s)) + self.assertEqual(s._to_xla_hlo_sharding(2).get_axis_sizes(), [7, 2, 5, 3]) + + @parameterized.named_parameters( + ('2d_mesh_x_y', (4, 2), P('x', 'y')), + ('2d_mesh_x', (4, 2), P('x')), + ('2d_mesh_y', (4, 2), P('y')), + ('2d_mesh_none_y', (4, 2), P(None, 'y')), + ('2d_mesh_none_x', (4, 2), P(None, 'x')), + ('2d_mesh_xy', (4, 2), P(('x', 'y'))), + ('2d_mesh_none_xy', (4, 2), P(None, ('x', 'y'))), + ('2d_mesh_fully_replicated', (4, 2), P()), + ('2d_mesh_x_none', (2, 1), P(('x',), None)), + ('3d_mesh_none_none_z', (2, 2, 2), P(None, None, 'z')), + ('3d_mesh_none_y_none', (2, 2, 2), P(None, 'y', None)), + ('3d_mesh_x_y_none', (2, 2, 2), P('x', 'y', None)), + ('3d_mesh_none_yz', (2, 2, 2), P(None, ('y', 'z'))), + ('3d_mesh_x_none_yz', (2, 2, 2), P('x', None, ('y', 'z'))), + ('3d_mesh_none_x_yz', (2, 2, 2), P(None, 'x', ('y', 'z'))), + ('3d_mesh_xy_z', (2, 2, 2), P(('x', 'y'), 'z')), + ('3d_mesh_xy_none_z', (2, 2, 2), P(('x', 'y'), None, 'z')), + ('3d_mesh_x_y_z', (2, 2, 2), P('x', 'y', 'z')), + ('3d_mesh_xz_y', (2, 2, 2), P(('x', 'z'), 'y')), + ('3d_mesh_xz_none_y', (2, 2, 2), P(('x', 'z'), None, 'y')), + ('3d_mesh_y_none_xz', (2, 2, 2), P('y', None, ('x', 'z'))), + ('3d_mesh_none_y_xz', (2, 2, 2), P(None, 'y', ('x', 'z'))), + ('3d_mesh2_none_none_z', (1, 2, 4), P(None, None, 'z')), + ('3d_mesh2_x_none_none', (1, 2, 4), P('x', None, None)), + ('3d_mesh_x_none_none', (2, 1, 1), P('x', None, None)), + ) + def test_gspmd_sharding_shardy_lowering(self, mesh_shape, pspec): + ndim = len(mesh_shape) + mesh = jtu.create_mesh( + mesh_shape, ('x', 'y') if ndim == 2 else ('x', 'y', 'z') + ) + ns = 
jax.sharding.NamedSharding(mesh, pspec) + gs = GSPMDSharding(ns._device_assignment, ns._to_xla_hlo_sharding(ndim)) + out_sdy_sharding = gs._to_sdy_sharding(ndim) + self.assertTrue(out_sdy_sharding, ns._to_sdy_sharding(ndim)) + @jtu.with_config(jax_use_shardy_partitioner=True) class ShardyShardingTest(jtu.JaxTestCase): @@ -1412,12 +1457,12 @@ def test_long_axis_names(self): sdy_sharding = s._to_sdy_sharding(3) self.assertEqual( sdy_sharding, - SdyArraySharding( - mesh.shape_tuple, - [SdyDimSharding( - ('sequence', 'data'), True), - SdyDimSharding(('model',), True), - SdyDimSharding([], True)])) + SdyArray( + mesh_shape=mesh.shape_tuple, + dim_shardings=[SdyDim( + ('sequence', 'data'), False), + SdyDim(('model',), False), + SdyDim([], False)])) with ir.Context() as ctx: dialects.sdy.register_dialect(ctx) self.assertEqual( @@ -1432,11 +1477,11 @@ def test_unconstrained(self): sdy_sharding = s._to_sdy_sharding(3) self.assertEqual( sdy_sharding, - SdyArraySharding( - mesh.shape_tuple, - [SdyDimSharding([], True), - SdyDimSharding([], False), - SdyDimSharding(('x',), True)])) + SdyArray( + mesh_shape=mesh.shape_tuple, + dim_shardings=[SdyDim([], False), + SdyDim([], True), + SdyDim(('x',), False)])) with ir.Context() as ctx: dialects.sdy.register_dialect(ctx) self.assertEqual( diff --git a/tests/attrs_test.py b/tests/attrs_test.py index 2334a7b98f91..90083626fb8e 100644 --- a/tests/attrs_test.py +++ b/tests/attrs_test.py @@ -15,6 +15,7 @@ from __future__ import annotations from dataclasses import dataclass +import itertools as it from absl.testing import absltest from absl.testing import parameterized @@ -27,8 +28,9 @@ from jax._src import test_util as jtu from jax._src.util import safe_zip, safe_map -from jax.experimental import attrs -from jax.experimental.attrs import jax_setattr, jax_getattr +from jax._src import attrs +from jax.experimental.attrs import ( + jax_setattr, jax_getattr, jax_appendattr, Box, List) config.parse_flags_with_absl() @@ -66,6 +68,19 @@ def double_it() -> None: double_it() self.assertEqual(thing.x, 16.0) + def test_setattr_doesnt_leak(self): + thing = Thing(1.0) + + @jax.jit + def f(x): + jax_setattr(thing, 'x', x) + raise Exception + + try: f(1.) + except: pass + self.assertNotIsInstance(thing.x, jax.core.Tracer) + + @parameterized.parameters([True, False]) def test_jit_basic_tree(self, jit: bool): thing = Thing((1.0, 2.0)) @@ -260,6 +275,26 @@ def body(_, __): double_it_10() self.assertAllClose(thing.x, 1024., check_dtypes=False) + @parameterized.parameters([True, False]) + def test_scan_basic_pytree(self, jit): + class Thing: ... + thing = Thing() + thing.x = (1.0, 1.0) + + def double_it_10(): + def body(_, __): + cur_x, _ = jax_getattr(thing ,"x") + jax_setattr(thing, "x", (cur_x * 2.0, 3.0)) + return None, None + _, _ = jax.lax.scan(body, None, None, length=10) + + if jit: + double_it_10 = jax.jit(double_it_10) + + double_it_10() + self.assertAllClose(thing.x[0], 1024., check_dtypes=False) + self.assertAllClose(thing.x[1], 3., check_dtypes=False) + def test_scan_basic_consts_and_args(self): thing = Thing(1.0) @@ -360,6 +395,227 @@ def body(i, _): return i + 1, None _, _ = jax.lax.scan(body, 0, None, length=3) # don't crash + @parameterized.parameters([True, False]) + def test_setattr_doesnt_exist(self, jit): + class Thing: + ... 
+ thing = Thing() + + def f(x): + assert (not jit) or tracing_is_ok + jax_setattr(thing, 'x', x) + + if jit: + f = jax.jit(f) + + tracing_is_ok = True + self.assertFalse(hasattr(thing, 'x')) + f(1.0) + self.assertEqual(thing.x, 1.0) + f(2.0) + self.assertEqual(thing.x, 2.0) + + tracing_is_ok = False + f(3.0) + self.assertEqual(thing.x, 3.0) + + del thing.x + f(4.0) + self.assertEqual(thing.x, 4.0) + + tracing_is_ok = True + f(5) + self.assertEqual(thing.x, 5) + + def test_setattr_doesnt_exist_doesnt_leave_sentinel_around(self): + class Thing: + ... + thing = Thing() + + def f(x): + jax_setattr(thing, 'x', x) + + jax.make_jaxpr(f)(3.) + self.assertFalse(hasattr(thing, 'x')) + tracing_ok = True + f(0.0) + self.assertAllClose(thing.x, 0.) + tracing_ok = False + f(1.0) + self.assertAllClose(thing.x, 1.) + + @parameterized.parameters(it.product([False, True], repeat=2)) + def test_appendattr_basic(self, jit, initialized): + class Thing: + ... + thing = Thing() + + if initialized: + thing.x = jnp.arange(0.) + + def f(x): + assert (not jit) or tracing_ok + jax_appendattr(thing, 'x', x) + jax_appendattr(thing, 'x', x + 1) + + if jit: + f = jax.jit(f) + + tracing_ok = True + f(0.0) + self.assertAllClose(thing.x, jnp.array([0., 1.])) + tracing_ok = False + f(2.0) + self.assertAllClose(thing.x, jnp.array([0., 1., 2., 3.])) + f(4.0) + self.assertAllClose(thing.x, jnp.array([0., 1., 2., 3., 4., 5.])) + + @parameterized.parameters(it.product([False, True], repeat=2)) + def test_appendattr_constant(self, jit, initialized): + class Thing: ... + thing = Thing() + + if initialized: + thing.x = jnp.arange(0.) + + def f(): + assert (not jit) or tracing_ok + jax_appendattr(thing, 'x', 0.0) + jax_appendattr(thing, 'x', 1.0) + + if jit: + f = jax.jit(f) + + tracing_ok = True + f() + self.assertAllClose(thing.x, jnp.array([0., 1.])) + tracing_ok = False + f() + self.assertAllClose(thing.x, jnp.array([0., 1., 0., 1.])) + + @parameterized.parameters([True, False]) + def test_appendattr_getattr_errors(self, initialized): + class Thing: ... + thing = Thing() + + if initialized: + thing.x = jnp.arange(0.) + + @jax.jit + def f(x): + jax_appendattr(thing, 'x', x) + jax_getattr(thing, 'x') + + with self.assertRaisesRegex(TypeError, "can't read/write"): + f(1.0) + + @jax.jit + def g(x): + jax_setattr(thing, 'x', x) + jax_appendattr(thing, 'x', x) + + with self.assertRaisesRegex(TypeError, "can't append"): + g(1.0) + + if initialized: + self.assertNotIsInstance(thing.x, jax.core.Tracer) + else: + self.assertFalse(hasattr(thing, 'x')) + + @parameterized.parameters(it.product([False, True], repeat=2)) + def test_appendattr_dtype_disagreement(self, jit, initialized): + class Thing: ... + thing = Thing() + + if initialized: + thing.x = jnp.array([], 'float32') + + def f(x): + jax_appendattr(thing, 'x', x) + jax_appendattr(thing, 'x', x.astype('complex64')) + + if jit: + f = jax.jit(f) + + msg = "can only append to attr x with values of trailing shape " + msg += "float32" if initialized else "int32" + with self.assertRaisesRegex(TypeError, msg): + f(jnp.array(1, 'int32')) + + @parameterized.parameters(it.product([False, True], repeat=2)) + def test_appendattr_shape_disagreement(self, jit, initialized): + class Thing: ... 
+ thing = Thing() + + if initialized: + thing.x = jnp.array([]) + + def f(x): + jax_appendattr(thing, 'x', x) + jax_appendattr(thing, 'x', jnp.stack([x, x])) + + if jit: + f = jax.jit(f) + + msg = "can only append to attr x with values of trailing shape" + with self.assertRaisesRegex(TypeError, msg): + f(1) + + @parameterized.parameters(it.product([False, True], repeat=2)) + def test_appendattr_scan(self, jit, initialized): + class Thing: ... + thing = Thing() + + if initialized: + thing.x = jnp.array([]) + + def f(): + def body(c, x): + jax_appendattr(thing, 'x', 2 * x) + jax_appendattr(thing, 'x', 2 * x + 1) + return c, () + _, () = jax.lax.scan(body, 0, jnp.arange(3.)) + + if jit: + f = jax.jit(f) + + f() + + self.assertAllClose(thing.x, jnp.array([0., 1., 2., 3., 4., 5.])) + + @parameterized.parameters(it.product([False, True], repeat=2)) + def test_appendattr_scan_vjp(self, jit, initialized): + class Thing: ... + thing = Thing() + + if initialized: + thing.y_bar = jnp.array([]) + + def f(x): + def body(c, _): + return 0.5 * g(2 * c), () + y, _ = jax.lax.scan(body, x, (), length=5) + return y + + if jit: + f = jax.jit(f) + + @jax.custom_vjp + def g(x): + return x + + def g_fwd(x): + return g(x), None + + def g_bwd(_, y_bar): + jax_appendattr(thing, 'y_bar', y_bar) + return y_bar, + + g.defvjp(g_fwd, g_bwd) + jax.grad(f)(3.) + + self.assertAllClose(thing.y_bar, jnp.array([0.5] * 5)) + class AttrsJVPTest(jtu.JaxTestCase): @@ -500,6 +756,7 @@ def g_ref(x, x_dot, y, y_dot): self.assertAllClose(w_ddot, w_ddot_, check_dtypes=False) self.assertAllClose(z_ddot, z_ddot_, check_dtypes=False) + class AttrsLinTest(jtu.JaxTestCase): @parameterized.parameters([True, False]) @@ -666,5 +923,407 @@ def f_ref(x, y, z, w): check_dtypes=False) +class BoxTest(jtu.JaxTestCase): + + def test_jit_arg(self): + @jax.jit + def f(box, x): + assert tracing_ok + box.set(box.get() + x) + + tracing_ok = True + box1 = Box(1.0) + f(box1, 1.) + self.assertAllClose(box1.get(), 2.0) + + tracing_ok = False + box2 = Box(2.0) + f(box2, 2.) + self.assertAllClose(box2.get(), 4.0) + + def test_jit_arg_in_pytree(self): + @jax.jit + def f(dct, x): + assert tracing_ok + box = dct['box'] + box.set(box.get() + x) + + tracing_ok = True + box1 = Box(1.0) + f({'box': box1, 'a': 1.0}, 1.) + self.assertAllClose(box1.get(), 2.0) + + tracing_ok = False + box2 = Box(2.0) + f({'box': box2, 'a': 2.0}, 2.) + self.assertAllClose(box2.get(), 4.0) + + tracing_ok = True + box3 = Box(3) # int, dtype changed + f({'box': box3, 'a': 2.0}, 2.) 
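+    # Descriptive note (not in the original diff): box3 holds an int, so the
+    # dtype change forces a retrace (hence tracing_ok is True above) while the
+    # boxed value is still updated correctly.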
+ self.assertAllClose(box3.get(), 5.0) + + def test_jit_closure(self): + @jax.jit + def f(x): + box.set(box.get() + x) + + box = Box(1.0) + f(2.0) + self.assertAllClose(box.get(), 3.0) + + @jax.jit + def g(x): + f(x) + + g(3.0) + self.assertAllClose(box.get(), 6.0) + + def test_jit_closure_nested(self): + @jax.jit + def h(x): + box = Box(x) + + @jax.jit + def k(x): + box.set(box.get() + x) + + k(1.0) + k(1.0) + return box.get() + + ans = h(2.0) + self.assertAllClose(ans, 4.0) + + @parameterized.parameters([False, True]) + def test_jvp_closure_stop_gradient(self, jit): + box = Box(1.0) + + def f(x): + y = 2 * x + box.set(box.get() + jax.lax.stop_gradient(y)) + return y + + if jit: + f = jax.jit(f) + + y, y_dot = jax.jvp(f, (1.0,), (1.0,)) + self.assertAllClose(y, 2.0) + self.assertAllClose(y_dot, 2.0) + self.assertAllClose(box.get(), 3.0) + + @parameterized.parameters([False, True]) + def test_jvp_arg(self, jit): + def f(box, x): + box.set(box.get() + x) + return x + + if jit: + f = jax.jit(f) + + box = Box(5.0) + box_dot = Box(1.0) + y, y_dot = jax.jvp(f, (box, 2.), (box_dot, 1.)) + self.assertAllClose(y, 2.0) + self.assertAllClose(y_dot, 1.0) + self.assertAllClose(box.get(), 7.0) + self.assertAllClose(box_dot.get(), 2.0) + + @parameterized.parameters([False, True]) + def test_custom_vjp_plumbing(self, jit): + box = Box(0.0) + + @jax.custom_vjp + def foo(x): + return x + def foo_fwd(x): + return foo(x), None + def foo_bwd(_, g): + box.set(g) + return g, + foo.defvjp(foo_fwd, foo_bwd) + + def f(x): + x = 2 * x + x = foo(x) + x = 2 * x + return x + + if jit: + f = jax.jit(f) + + jax.grad(f)(1.0) + self.assertAllClose(box.get(), 2.0) + + @parameterized.parameters([False, True]) + def test_grad_closure_stop_gradient(self, jit): + box = Box(0.0) + + def f(x): + y = x * 2 + box.set(box.get() + jax.lax.stop_gradient(y)) + return y + + if jit: + f = jax.jit(f) + + g = jax.grad(f)(1.0) + self.assertAllClose(g, 2.0) + self.assertAllClose(box.get(), 2.0) + + @parameterized.parameters([False, True]) + def test_scan_basic(self, jit): + box = Box(1.0) + + def double_it_10(): + def body(_, __): + box.set(box.get() * 2) + return None, None + _, _ = jax.lax.scan(body, None, None, length=10) + + if jit: + double_it_10 = jax.jit(double_it_10) + + double_it_10() + self.assertAllClose(box.get(), 1024., check_dtypes=False) + + def test_error_passing_multiple_times_to_jit(self): + + @jax.jit + def f(box1, box2): + ... + + b = Box(1.0) + with self.assertRaisesRegex(ValueError, "a Box instance can't be passed"): + f(b, b) + + # TODO(mattjj): re-enable this test + # def test_error_returning_from_jit(self): + # @jax.jit + # def f(): + # return {'a': Box(1.0)} + + # with self.assertRaisesRegex(ValueError, "a Box instance can't be returned"): + # f() + + +class ListTest(jtu.JaxTestCase): + + def test_eager(self): + lst = List() + lst.append(1.0) + lst.append(2.0) + lst.append(3.0) + self.assertAllClose(lst.get(), [1., 2., 3.]) + + def test_jit_arg(self): + @jax.jit + def f(lst, x): + assert tracing_ok + lst.append(1.0) + lst.append(2.0) + lst.append({'c': x + 3.0}) + + tracing_ok = True + lst1 = List() + f(lst1, 0) + self.assertAllClose(lst1.get(), [1., 2., {'c': 3.}]) + + tracing_ok = False + lst2 = List() + lst2.append(0.) 
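+    # Descriptive note (not in the original diff): the second call reuses the
+    # cached trace (tracing_ok is False) and appends to the already-populated
+    # list.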
+ f(lst2, 1) + self.assertAllClose(lst2.get(), [0., 1., 2., {'c': 4.}]) + + def test_jit_closure(self): + lst = List() + + @jax.jit + def f(x): + assert tracing_ok + lst.append(1.0) + lst.append({'a': 2.0}) + lst.append(x + 3.0) + + tracing_ok = True + f(1) + self.assertAllClose(lst._val, [1., {'a': 2.}, 4.]) + + tracing_ok = False + f(2) + self.assertAllClose(lst.get(), [1., {'a': 2.}, 4., 1., {'a': 2.0}, 5.0]) + + def test_jit_closure_nested(self): + lst = List() + + @jax.jit + def h(x): + lst.append(x) + + @jax.jit + def k(x): + lst.append(x) + + k(1.0) + k(2.0) + + h(0.0) + self.assertAllClose(lst.get(), [0., 1., 2.]) + + @parameterized.parameters([False, True]) + def test_scan_basic(self, jit): + lst = List() + + def f(): + def body(_, x): + lst.append(2 * x) + lst.append(2 * x + 1) + return (), () + (), () = jax.lax.scan(body, (), jnp.arange(3.)) + + if jit: + f = jax.jit(f) + + f() + + self.assertAllClose(lst.get(), [0., 1., 2., 3., 4., 5.]) + + @parameterized.parameters([False, True]) + def test_scan_basic_hetero(self, jit): + lst = List() + + def f(): + def body(_, x): + lst.append(2 * x) + lst.append({'a': (2 * x + 1, 2 * x + 2)}) + return (), () + (), () = jax.lax.scan(body, (), jnp.arange(3.)) + + if jit: + f = jax.jit(f) + + f() + + expected = [ + 0., + {'a': (1., 2.)}, + 2., + {'a': (3., 4.)}, + 4., + {'a': (5., 6.)}, + ] + self.assertAllClose(lst.get(), expected) + + @parameterized.parameters([False, True]) + def test_get_basic(self, jit): + + def f(): + lst = List() + lst.append(1.) + lst.append(2.) + return lst.get() + + if jit: + f = jax.jit(f) + + lst = f() + self.assertAllClose(lst, [1., 2.]) + + def test_freeze_nonlocal_list(self): + lst = List() + + @jax.jit + def f(): + lst.get() + + with self.assertRaisesRegex(Exception, "can't read the value"): + f() + + def test_freeze_nonlocal_list_nested(self): + @jax.jit + def f(): + lst = List() + + @jax.jit + def g(): + lst.get() + + g() + + with self.assertRaisesRegex(Exception, "can't read the value"): + f() + + @parameterized.parameters([False, True]) + def test_append_after_get(self, jit): + def f(): + lst = List() + lst.append(1.) + lst.append(2.) + val = lst.get() + lst.append(3.) + return lst.get() + + if jit: + f = jax.jit(f) + + lst = f() + self.assertAllClose(lst, [1., 2., 3.]) + + def test_get_on_nonlocal_list_closure(self): + lst = List() + + @jax.jit + def f(): + lst.append(1.) + lst.append(2.) + with self.assertRaisesRegex(Exception, "can't read"): + val = lst.get() + + def test_get_on_nonlocal_list_arg(self): + lst = List() + + @jax.jit + def f(lst): + lst.append(1.) + lst.append(2.) + with self.assertRaisesRegex(Exception, "can't read"): + val = lst.get() + + @parameterized.parameters([False, True]) + def test_custom_vjp_plumbing(self, jit): + lst = List() + + @jax.custom_vjp + def foo(x): + return x + def foo_fwd(x): + return foo(x), None + def foo_bwd(_, g): + lst.append(g) + return g, + foo.defvjp(foo_fwd, foo_bwd) + + def f(x): + x = 2 * x + x = foo(x) + x = 2 * x + return x + + if jit: + f = jax.jit(f) + + jax.grad(f)(1.0) + self.assertAllClose(lst.get(), [2.0]) + + def test_error_passing_multiple_times_to_jit(self): + @jax.jit + def f(lst1, lst2): + ... 
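+      # Descriptive note (not in the original diff): the body is intentionally
+      # empty; passing the same List instance twice should be rejected
+      # regardless of what f does.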
+ + b = List([]) + with self.assertRaisesRegex(ValueError, "a List instance can't be passed"): + f(b, b) + + if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/batching_test.py b/tests/batching_test.py index f2a4e8c34fe3..393317bcbe77 100644 --- a/tests/batching_test.py +++ b/tests/batching_test.py @@ -1328,33 +1328,70 @@ def list_insert(lst: list[a], idx: int, val: a) -> list[a]: @jtu.thread_unsafe_test_class() # temporary registration isn't thread-safe class VmappableTest(jtu.JaxTestCase): - def test_basic(self): + @parameterized.parameters([False, True]) + def test_basic(self, jit): with temporarily_register_named_array_vmappable(): def f(x): return named_mul(x, x) + if jit: + f = jax.jit(f) x = NamedArray(['i', 'j'], jnp.arange(12.).reshape(3, 4)) g = jax.vmap(f, - in_axes=NamedMapSpec('i', 0), - out_axes=NamedMapSpec('i', 1), - axis_size=3) + in_axes=NamedMapSpec('i', 0), + out_axes=NamedMapSpec('i', 1), + axis_size=3) ans = g(x) expected = NamedArray(['j', 'i'], jnp.arange(12.).reshape(3, 4).T ** 2) self.assertEqual(ans.names, expected.names) self.assertAllClose(ans.data, expected.data) - def test_basic_jit(self): - with temporarily_register_named_array_vmappable(): - def f(x): - return named_mul(x, x) - - x = NamedArray(['i', 'j'], jnp.arange(12.).reshape(3, 4)) - ans = jax.jit(f)(x) - expected = NamedArray(['i', 'j'], jnp.arange(12.).reshape(3, 4) ** 2) - - self.assertEqual(ans.names, expected.names) - self.assertAllClose(ans.data, expected.data) + def test_to_elt_that_binds_primitives(self): + class A: + data: Array + def __init__(self, data): + self.data = data + def to_elt(cont, _, val, spec): + return cont(val.data + 1, spec) + def from_elt(cont, size, elt, spec): + assert False + + @jax.jit + def f(): + a = A(jnp.arange(3.)) + return jax.vmap(lambda x: x - 1, axis_size=3)(a) + + try: + batching.register_vmappable(A, int, int, to_elt, from_elt, None) + ans = f() + finally: + batching.unregister_vmappable(A) + + self.assertAllClose(ans, jnp.arange(3.)) + + def test_from_elt_that_binds_primitives(self): + class A: + data: Array + def __init__(self, data): + self.data = data + def to_elt(cont, _, val, spec): + return A(cont(val.data, spec)) + def from_elt(cont, size, elt, spec): + return A(cont(size, elt.data + 1, spec)) + + @jax.jit + def f(): + a = A(jnp.arange(3.)) + return jax.vmap(lambda x: x, axis_size=3)(a).data + + try: + batching.register_vmappable(A, int, int, to_elt, from_elt, None) + ans = f() + finally: + batching.unregister_vmappable(A) + + self.assertAllClose(ans, jnp.arange(3.) + 1) def test_types_with_same_spec(self): # We register NamedArray. diff --git a/tests/buffer_callback_test.py b/tests/buffer_callback_test.py new file mode 100644 index 000000000000..8bef4135f5d5 --- /dev/null +++ b/tests/buffer_callback_test.py @@ -0,0 +1,188 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
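+
+# Descriptive note (not in the original diff): these tests exercise
+# jax.experimental.buffer_callback, whose callbacks receive an execution
+# context plus output and input buffers, viewed below through NumPy,
+# DLPack, and __cuda_array_interface__.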
+ +from absl.testing import absltest +from absl.testing import parameterized +import numpy as np + +import jax +import jax.numpy as jnp +from jax._src import test_util as jtu +from jax.experimental import buffer_callback + +jax.config.parse_flags_with_absl() + + +class BufferCallbackTest(jtu.JaxTestCase): + + def setUp(self): + super().setUp() + if jtu.test_device_matches(["tpu"]): + self.skipTest("Not supported on TPU.") + + @parameterized.parameters(jtu.dtypes.all) + @jtu.run_on_devices("cpu") + def test_numpy(self, dtype): + def callback(ctx, out, arg): + with self.assertRaisesRegex( + jax.errors.JaxRuntimeError, "XLA FFI GPU context is not available" + ): + ctx.stream + + self.assertEqual(ctx.stage, buffer_callback.ExecutionStage.EXECUTE) + self.assertEqual(arg.shape, shape) + self.assertEqual(arg.dtype, dtype) + self.assertEqual(out.shape, shape) + self.assertEqual(out.dtype, dtype) + + self.assertFalse(arg.writeable) + self.assertTrue(out.writeable) + + x = np.asarray(arg) + self.assertArraysEqual(x, data) + + y = np.asarray(out) + self.assertEqual(x.dtype, y.dtype) + self.assertEqual(x.shape, y.shape) + y[...] = x + + rng = jtu.rand_default(self.rng()) + shape = (3, 4) + data = rng(shape, dtype) + fun = buffer_callback.buffer_callback( + callback, jax.ShapeDtypeStruct(data.shape, data.dtype) + ) + self.assertArraysEqual(fun(data), data) + + @parameterized.parameters(jtu.dtypes.all) + @jtu.run_on_devices("cpu") + def test_dlpack(self, dtype): + if dtype == jnp.bfloat16: + self.skipTest("Numpy's DLPack implementation does not support bfloat16") + + def callback(ctx, out, arg): + del ctx # unused + + x = np.from_dlpack(arg) + self.assertArraysEqual(x, data) + + y = np.from_dlpack(out) + self.assertEqual(x.dtype, y.dtype) + self.assertEqual(x.shape, y.shape) + + rng = jtu.rand_default(self.rng()) + shape = (3, 4) + data = rng(shape, dtype) + fun = buffer_callback.buffer_callback( + callback, jax.ShapeDtypeStruct(data.shape, data.dtype) + ) + + # We can't actually test the output because numpy doesn't support writable + # DLPack tensors. + jax.block_until_ready(fun(data)) + + @parameterized.product( + dtype=jtu.dtypes.all, command_buffer_compatible=[True, False] + ) + @jtu.run_on_devices("cuda") + def test_cuda_array_interface(self, dtype, command_buffer_compatible): + if command_buffer_compatible: + self.skipTest("Requires jaxlib extension version of at least 337.") + + def callback(ctx, out, arg): + ctx.stream # doesn't crash + + self.assertEqual(ctx.stage, buffer_callback.ExecutionStage.EXECUTE) + self.assertEqual(arg.shape, shape) + self.assertEqual(arg.dtype, dtype) + self.assertEqual(out.shape, shape) + self.assertEqual(out.dtype, dtype) + + obj = arg.__cuda_array_interface__ + self.assertEqual(obj["shape"], data.shape) + self.assertEqual(obj["typestr"], data.dtype.str) + + obj = out.__cuda_array_interface__ + self.assertEqual(obj["shape"], data.shape) + self.assertEqual(obj["typestr"], data.dtype.str) + + rng = jtu.rand_default(self.rng()) + shape = (3, 4) + data = rng(shape, dtype) + fun = buffer_callback.buffer_callback( + callback, jax.ShapeDtypeStruct(data.shape, data.dtype), + command_buffer_compatible=command_buffer_compatible, + ) + + # TODO: There's an XLA:GPU/CUDA bug that causes a segfault when + # instantiating an empty CUDA graph. Once that bug is fixed or worked + # around, add a test that checks that the Python callback is only executed + # once. 
+ jax.block_until_ready(fun(data)) + + @parameterized.parameters([ + "sequential", "sequential_unrolled", "expand_dims", "broadcast_all" + ]) + @jtu.run_on_devices("cpu") + def test_batching(self, vmap_method): + def callback(ctx, out, *args): + del ctx # unused + x = np.asarray(args[0]) + y = np.asarray(args[1]) + z = np.asarray(out) + z[...] = x + z[...] += y + + rng = jtu.rand_default(self.rng()) + shape = (3, 4) + x = rng(shape, jnp.float32) + y = rng(shape, jnp.float32) + fun = buffer_callback.buffer_callback( + callback, + jax.ShapeDtypeStruct(x.shape[1:], x.dtype), + vmap_method=vmap_method, + ) + self.assertArraysEqual(jax.vmap(fun)(x, y), x + y) + + @jtu.run_on_devices("cpu") + def test_input_output_aliases(self): + def callback(ctx, out, arg): + del ctx # unused + x = np.asarray(arg) + y = np.asarray(out) + self.assertEqual(x.ctypes.data, y.ctypes.data) + + rng = jtu.rand_default(self.rng()) + shape = (3, 4) + data = rng(shape, jnp.float32) + fun = buffer_callback.buffer_callback( + callback, jax.ShapeDtypeStruct(data.shape, data.dtype), + input_output_aliases={0: 0}, + ) + jax.block_until_ready(fun(data)) + + def test_side_effect(self): + def callback(*_): + nonlocal called + called = True + + called = False + fun = buffer_callback.buffer_callback( + callback, jax.ShapeDtypeStruct((), jnp.float32), has_side_effect=True) + jax.block_until_ready(fun()) + self.assertTrue(called) + + +if __name__ == "__main__": + absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/cache_key_test.py b/tests/cache_key_test.py index 2faa4dbaf9d4..35ac03011a97 100644 --- a/tests/cache_key_test.py +++ b/tests/cache_key_test.py @@ -83,9 +83,9 @@ def test_hash_accelerator_devices(self): self.assertEqual(dev_hash1, dev_hash2) acc_hash1 = self.get_hashed_value( - cache_key._hash_accelerator_config, devices, xla_bridge.get_backend()) + cache_key._hash_accelerator_config, devices) acc_hash2 = self.get_hashed_value( - cache_key._hash_accelerator_config, devices, xla_bridge.get_backend()) + cache_key._hash_accelerator_config, devices) self.assertEqual(acc_hash1, acc_hash2) def test_hash_platform(self): @@ -163,6 +163,8 @@ def test_different_computations(self): cache_key.get(computation2, devices, compile_options, backend), ) + # TODO(phawkins): this test flakes if test concurrency is enabled. + @jtu.thread_unsafe_test() def test_custom_partitioning_ptr_removal(self): def _partition(mesh, arg_shapes, result_shape): arg_shardings = jax.tree.map(lambda x: x.sharding, arg_shapes) @@ -178,7 +180,8 @@ def _cp_add(x, y): _cp_add.def_partition( infer_sharding_from_operands=_infer_sharding_from_operands, - partition=_partition) + partition=_partition, + sharding_rule='..., ... 
-> ...') devices = np.asarray(jax.devices()) with Mesh(devices, ('x',)) as m: diff --git a/tests/checkify_test.py b/tests/checkify_test.py index 6a1660b28578..050ac5314da3 100644 --- a/tests/checkify_test.py +++ b/tests/checkify_test.py @@ -23,14 +23,14 @@ from jax import lax from jax.experimental import checkify from jax.experimental import pjit -from jax.experimental import shard_map -from jax.sharding import NamedSharding +from jax._src import shard_map +from jax.sharding import NamedSharding, PartitionSpec as P from jax._src import array from jax._src import config from jax._src import core from jax._src import test_util as jtu from jax._src.checkify import JaxRuntimeError, FailedCheckError, ErrorEffect, OOBError -from jax._src.lib import xla_extension +from jax._src.lib import _jax import jax.numpy as jnp config.parse_flags_with_absl() @@ -475,12 +475,25 @@ def f(init_val): self.assertIsNotNone(err.get()) self.assertStartsWith(err.get(), "division by zero") + def test_checify_donation_no_forwarding(self): + mesh = jtu.create_mesh((2,), ('x',)) + + @checkify.checkify + @partial(jax.jit, donate_argnums=(0,)) + def f(x: jax.Array) -> jax.Array: + checkify.check(jnp.all(x > 0), "a") + return x + + x = jax.device_put(jnp.zeros(64, dtype="int32"), NamedSharding(mesh, P())) + err, y = f(x) + err, z = f(y) # doesn't crash + @jtu.skip_on_devices("tpu") def test_while_loop_body_and_cond_error(self): def while_cond(val): i, cond_val, _ = val - _ = jnp.sin(cond_val) - return i < 2 + j = jnp.sin(cond_val) + return i + (0. * j) < 2 # don't let the sin value be dead code def while_body(val): i, cond_val, body_val = val @@ -541,7 +554,7 @@ def g(x, y): self.assertStartsWith(b_err.get(), "division by zero") @parameterized.parameters(True, False) - def test_shard_map(self, check_rep): + def test_shard_map(self, check_vma): def f(x): # unary func return jax.lax.axis_index("dev") * x / x @@ -558,12 +571,12 @@ def g(x, y): x = array.make_array_from_callback(inp.shape, ps, lambda idx: inp[idx]) f = shard_map.shard_map( - f, mesh, in_specs=pspec, out_specs=pspec, check_rep=check_rep + f, mesh=mesh, in_specs=pspec, out_specs=pspec, check_vma=check_vma ) f = jax.jit(f, in_shardings=ps, out_shardings=ps) f = checkify.checkify(f, errors=checkify.float_checks) g = shard_map.shard_map( - g, mesh, in_specs=(pspec, pspec), out_specs=pspec, check_rep=check_rep + g, mesh=mesh, in_specs=(pspec, pspec), out_specs=pspec, check_vma=check_vma ) g = jax.jit(g, in_shardings=(ps, ps), out_shardings=ps) g = checkify.checkify(g, errors=checkify.float_checks) @@ -1215,7 +1228,7 @@ def while_body(s): with self.assertRaisesRegex(ValueError, "checkify-of-vmap-of-while"): checked_f(jnp.asarray([1., 2., 3.]), jnp.asarray([5., 2., 4.])) - # TODO(lenamartens): reenable assertions below. + # TODO(lenamartens): re-enable assertions below. # self.assertIsNotNone(err.get()) # self.assertStartsWith(err.get(), "division by zero") @@ -1244,7 +1257,7 @@ def fun(x): with self.assertRaisesRegex(ValueError, "checkify-of-vmap-of-while"): checked_f(jnp.arange(5)) - # TODO(lenamartens): reenable assertions below. + # TODO(lenamartens): re-enable assertions below. # self.assertIsNone(err.get()) def test_assert_cond_no_data_dependence(self): @@ -1374,9 +1387,9 @@ def f(x): checkify.check(x > 0, "x needs to be positive") return x - with self.assertRaisesRegex(xla_extension.XlaRuntimeError, + with self.assertRaisesRegex(_jax.XlaRuntimeError, "x needs to be positive"): - f(-1.) 
+ f(-1.).block_until_ready() if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/colocated_python_test.py b/tests/colocated_python_test.py index 52d494904fe6..452824cb6d52 100644 --- a/tests/colocated_python_test.py +++ b/tests/colocated_python_test.py @@ -16,9 +16,6 @@ import struct import tempfile import threading -import time -from typing import Sequence -import unittest from absl.testing import absltest from absl.testing import parameterized @@ -36,27 +33,9 @@ try: import cloudpickle # noqa + HAS_CLOUDPICKLE = True except (ModuleNotFoundError, ImportError): - raise unittest.SkipTest("tests depend on cloudpickle library") - - -def _colocated_cpu_devices( - devices: Sequence[jax.Device], -) -> Sequence[jax.Device]: - """Returns CPU devices colocated with the given devices.""" - try: - return colocated_python.colocated_cpu_devices(devices) - except (ValueError, AttributeError): - # PjRt-IFRT prepares CPU devices by its own. - # TODO(hyeontaek): Remove this fallback path once PjRt-IFRT prepares CPU - # devices by its own. - cpu_backend_devices = jax.local_devices(backend="cpu") - device_index_map = {device.id: i for i, device in enumerate(jax.devices())} - - available_devices = devices[: min(len(cpu_backend_devices), len(devices))] - return [ - cpu_backend_devices[device_index_map[d.id]] for d in available_devices - ] + HAS_CLOUDPICKLE = False _count_colocated_python_specialization_cache_miss = jtu.count_events( @@ -68,17 +47,35 @@ class ColocatedPythonTest(jtu.JaxTestCase): def setUp(self): super().setUp() + if not HAS_CLOUDPICKLE: + self.skipTest( + "ColocatedPythonTest depends on cloudpickle library" + ) if np.lib.NumpyVersion(np.__version__) < "2.0.0": self.skipTest( - "Serialization in Colocated Python needs StringDType, and thus" - " requires NumPy 2.0.0 or later" + "Serialization in Colocated Python needs StringDType, and thus" + " requires NumPy 2.0.0 or later" ) - def testMakeColocatedPythonProgram(self): + def test_colocated_cpu_devices(self): + mesh = jax.sharding.Mesh( + np.array(jax.local_devices()[:1]).reshape((1, 1)), ("x", "y") + ) + cpu_mesh1 = colocated_python.colocated_cpu_devices(mesh) + + cpu_devices = colocated_python.colocated_cpu_devices( + jax.local_devices()[:1] + ) + cpu_mesh2 = jax.sharding.Mesh( + np.array(cpu_devices).reshape((1, 1)), ("x", "y") + ) + self.assertEqual(cpu_mesh1, cpu_mesh2) + + def test_make_colocated_python_program(self): def add_one(x): return x + 1 - cpu_devices = _colocated_cpu_devices(jax.local_devices()) + cpu_devices = colocated_python.colocated_cpu_devices(jax.local_devices()) sharding = jax.sharding.SingleDeviceSharding(cpu_devices[0]) sds = jax.ShapeDtypeStruct((), jnp.int32, sharding=sharding) @@ -88,12 +85,12 @@ def add_one(x): ) del program - def testSimpleFunction(self): + def test_simple_function(self): @colocated_python.colocated_python def add_one(x): return x + 1 - cpu_devices = _colocated_cpu_devices(jax.local_devices()) + cpu_devices = colocated_python.colocated_cpu_devices(jax.local_devices()) x = np.array(1) x = jax.device_put(x, cpu_devices[0]) @@ -108,12 +105,12 @@ def add_one(x): self.assertEqual(out, np.array(2)) self.assertEqual(count(), 1) - def testSimpleFunctionWithTree(self): + def test_simple_function_with_tree(self): @colocated_python.colocated_python def add_one(x): return jax.tree.map(lambda x: x + 1, x) - cpu_devices = _colocated_cpu_devices(jax.local_devices()) + cpu_devices = colocated_python.colocated_cpu_devices(jax.local_devices()) x = [np.array(1), (np.array(2), 
{"v": np.array(3)})] x = jax.device_put(x, jax.sharding.SingleDeviceSharding(cpu_devices[0])) @@ -128,7 +125,7 @@ def add_one(x): self.assertEqual(out, [np.array(2), (np.array(3), {"v": np.array(4)})]) self.assertEqual(count(), 1) - def testEmptyInputFailsWithoutSpecialization(self): + def test_empty_input_fails_without_specialization(self): @colocated_python.colocated_python def make_zero(): return jnp.array(0) @@ -140,12 +137,12 @@ def make_zero(): ): _ = make_zero() - def testEmptyInputWithDevicesSpecialization(self): + def test_empty_input_with_devices_specialization(self): @colocated_python.colocated_python def make_zero(): return jnp.array(0) - cpu_devices = _colocated_cpu_devices(jax.local_devices()) + cpu_devices = colocated_python.colocated_cpu_devices(jax.local_devices()) with _count_colocated_python_specialization_cache_miss() as count: make_zero = make_zero.specialize(devices=cpu_devices[:1]) @@ -159,12 +156,12 @@ def make_zero(): self.assertEqual(out, np.array(0)) self.assertEqual(count(), 1) - def testInputPolymorphismWithoutOutSpecsFn(self): + def test_input_polymorphism_without_out_specs_fn(self): @colocated_python.colocated_python def add_one(x): return jax.tree.map(lambda x: x + 1, x) - cpu_devices = _colocated_cpu_devices(jax.local_devices()) + cpu_devices = colocated_python.colocated_cpu_devices(jax.local_devices()) x = np.array(1) x = jax.device_put(x, cpu_devices[0]) @@ -193,12 +190,12 @@ def add_one(x): self.assertEqual(out, [np.array(2), (np.array(3), {"v": np.array(4)})]) self.assertEqual(count(), 2) - def testInputPolymorphismAllowedWithOutSpecsFn(self): + def test_input_polymorphism_allowed_with_out_specs_fn(self): @colocated_python.colocated_python def add_one(x): return jax.tree.map(lambda x: x + 1, x) - cpu_devices = _colocated_cpu_devices(jax.local_devices()) + cpu_devices = colocated_python.colocated_cpu_devices(jax.local_devices()) x = np.array(1) x = jax.device_put(x, cpu_devices[0]) @@ -232,82 +229,108 @@ def add_one(x): ("on_main_thread", True), ("on_non_main_thread", False), ) - def testSequentialExecution(self, on_main_thread: bool): - cpu_devices = _colocated_cpu_devices(jax.local_devices()) + def test_sequential_execution(self, on_main_thread: bool): + cpu_devices = colocated_python.colocated_cpu_devices(jax.local_devices()) x = np.array(1) x = jax.device_put(x, cpu_devices[0]) - # Make sure that this input array is ready for use by the colocated Python - # function and does not disrupt elapsed time measurement. - jax.block_until_ready(x) @colocated_python.colocated_python - def sleep(x: jax.Array) -> jax.Array: - time.sleep(5) + def func0(x: jax.Array) -> jax.Array: + colocated_python._testing_global_state = 100 return x - # Specify out_specs_fn so that all executions are asynchronously dispatched. 
- sleep = sleep.specialize(out_specs_fn=lambda x: x) + @colocated_python.colocated_python + def func1(x: jax.Array) -> jax.Array: + assert "_testing_global_state" in colocated_python.__dict__ + assert colocated_python._testing_global_state == 100 + colocated_python._testing_global_state += 1 + return x - def sleep_twice_and_wait(x: jax.Array) -> None: - _ = sleep(x) - jax.block_until_ready(sleep(x)) + @colocated_python.colocated_python + def func2(x: jax.Array) -> jax.Array: + assert "_testing_global_state" in colocated_python.__dict__ + assert colocated_python._testing_global_state == 101 + return x - start_time = time.time() + @colocated_python.colocated_python + def cleanup(): + if "_testing_global_state" in colocated_python.__dict__: + del colocated_python._testing_global_state - # Two executions of `sleep` within `sleep_twice_and_wait` should run - # sequentially. - if on_main_thread: - sleep_twice_and_wait(x) - else: - t = threading.Thread(target=sleep_twice_and_wait, args=(x,)) - t.start() - t.join() + # Specify out_specs_fn so that their executions are asynchronously + # dispatched. + func0 = func0.specialize(out_specs_fn=lambda x: x) + func1 = func1.specialize(out_specs_fn=lambda x: x) + func2 = func2.specialize(out_specs_fn=lambda x: x) - elapsed_time = time.time() - start_time + # cleanup needs specialization because they do not have input arguments. + cleanup = cleanup.specialize(devices=cpu_devices[:1]) - # If sequential execution did not happen, elapsed time typically will be - # around 5 seconds. - self.assertGreaterEqual(elapsed_time, 10) + def calls(x: jax.Array) -> None: + # No explicit blocking before making the next call. + func0(x) + func1(x) + jax.block_until_ready(func2(x)) - def testConcurrentExecution(self): - cpu_devices = _colocated_cpu_devices(jax.local_devices()) + try: + # Executions in `calls` should run sequentially. + if on_main_thread: + calls(x) + else: + t = threading.Thread(target=calls, args=(x,)) + t.start() + t.join() + # Executions should succeed without an error. + finally: + cleanup() + + def test_concurrent_execution(self): + cpu_devices = colocated_python.colocated_cpu_devices(jax.local_devices()) x = np.array(1) x = jax.device_put(x, cpu_devices[0]) - # Make sure that this input array is ready for use by the colocated Python - # function and does not disrupt elapsed time measurement. - jax.block_until_ready(x) @colocated_python.colocated_python - def sleep(x: jax.Array) -> jax.Array: - time.sleep(5) + def init(x: jax.Array) -> jax.Array: + colocated_python._testing_global_state = threading.Barrier(3) return x - # Specify out_specs_fn so that all executions are asynchronously dispatched. - sleep = sleep.specialize(out_specs_fn=lambda x: x) - - def sleep_and_wait(x: jax.Array) -> None: - jax.block_until_ready(sleep(x)) + @colocated_python.colocated_python + def func(x: jax.Array) -> jax.Array: + assert "_testing_global_state" in colocated_python.__dict__ + colocated_python._testing_global_state.wait(timeout=5) + return x - start_time = time.time() + @colocated_python.colocated_python + def cleanup(): + if "_testing_global_state" in colocated_python.__dict__: + del colocated_python._testing_global_state - # All three executions of `sleep_and_wait` should run concurrently. - t1 = threading.Thread(target=sleep_and_wait, args=(x,)) - t2 = threading.Thread(target=sleep_and_wait, args=(x,)) - t1.start() - t2.start() - sleep_and_wait(x) - t1.join() - t2.join() + # Specify out_specs_fn so that their executions are asynchronously + # dispatched. 
+ func = func.specialize(out_specs_fn=lambda x: x) - elapsed_time = time.time() - start_time + # cleanup needs specialization because they do not have input arguments. + cleanup = cleanup.specialize(devices=cpu_devices[:1]) - self.assertGreaterEqual(elapsed_time, 5) - # If concurrent execution did not happen, elapsed time typically will be - # around 15 seconds. - self.assertLess(elapsed_time, 10) + try: + jax.block_until_ready(init(x)) + + # All func calls should run concurrently and enter/exit the barrier. + t1 = threading.Thread(target=func, args=(x,)) + t2 = threading.Thread(target=func, args=(x,)) + t3 = threading.Thread(target=func, args=(x,)) + t1.start() + t2.start() + t3.start() + t1.join() + t2.join() + t3.join() + # Executions should succeed without a deadlock. + finally: + cleanup() - def testInputsWithDifferentDeviceOrders(self): - cpu_devices = _colocated_cpu_devices(jax.local_devices())[:2] + def test_inputs_with_different_device_orders(self): + cpu_devices = colocated_python.colocated_cpu_devices(jax.local_devices())[:2] if len(cpu_devices) < 2: self.skipTest("Not enough CPU devices") @@ -348,7 +371,7 @@ def add(x: jax.Array, y: jax.Array) -> jax.Array: out = jax.device_get(out) np.testing.assert_equal(out, np.array([2 + 4, 0 + 8])) - def testModuleVariableAccess(self): + def test_module_variable_access(self): try: # The following pattern of storing and accessing non-serialized state in # the Python module is discouraged for storing user-defined state. @@ -372,7 +395,7 @@ def get_global_state(x: jax.Array) -> jax.Array: del x return colocated_python._testing_global_state - cpu_devices = _colocated_cpu_devices(jax.local_devices()) + cpu_devices = colocated_python.colocated_cpu_devices(jax.local_devices()) x = np.array(1) x = jax.device_put(x, cpu_devices[0]) y = np.array(2) @@ -389,8 +412,8 @@ def get_global_state(x: jax.Array) -> jax.Array: if "_testing_global_state" in colocated_python.__dict__: del colocated_python._testing_global_state - def testStringProcessing(self): - cpu_devices = _colocated_cpu_devices(jax.local_devices()) + def test_string_processing(self): + cpu_devices = colocated_python.colocated_cpu_devices(jax.local_devices()) if len(cpu_devices) < 2: self.skipTest(f"Need at least two CPU devices, got: {len(cpu_devices)}") @@ -430,8 +453,8 @@ def f(x): ), ) - def testBinaryDataProcessing(self): - cpu_devices = _colocated_cpu_devices(jax.local_devices()) + def test_binary_data_processing(self): + cpu_devices = colocated_python.colocated_cpu_devices(jax.local_devices()) if len(cpu_devices) < 1: self.skipTest("Need at least one CPU devices") @@ -472,8 +495,8 @@ def f(x): self.assertEqual(out_ints[0], 1002) self.assertEqual(out_ints[1], 1003) - def testDetectInvalidMeshDevice(self): - cpu_devices = _colocated_cpu_devices(jax.local_devices()) + def test_detect_invalid_mesh_device(self): + cpu_devices = colocated_python.colocated_cpu_devices(jax.local_devices()) if jax.local_devices()[0].id == cpu_devices[0].id: self.skipTest( "This test only works in a setup where accelerator and CPU devices" @@ -493,8 +516,8 @@ def make_zero() -> jax.Array: make_zero = make_zero.specialize(devices=cpu_devices) jax.block_until_ready(make_zero()) - def testObjectLifecycle(self): - cpu_devices = _colocated_cpu_devices(jax.local_devices()) + def test_object_lifecycle(self): + cpu_devices = colocated_python.colocated_cpu_devices(jax.local_devices()) sharding = jax.sharding.SingleDeviceSharding(cpu_devices[0]) @colocated_python.colocated_python_class @@ -565,8 +588,8 @@ def cleanup(): 
finally: cleanup() - def testStatefulObject(self): - cpu_devices = _colocated_cpu_devices(jax.local_devices()) + def test_stateful_object(self): + cpu_devices = colocated_python.colocated_cpu_devices(jax.local_devices()) @colocated_python.colocated_python_class class Value: @@ -597,8 +620,8 @@ def fetch(self, x: jax.Array) -> jax.Array: out = jax.device_get(value.fetch(x)) self.assertEqual(out, np.array(7)) - def testObjectWithCapturedSharding(self): - cpu_devices = _colocated_cpu_devices(jax.local_devices()) + def test_object_with_captured_sharding(self): + cpu_devices = colocated_python.colocated_cpu_devices(jax.local_devices()) if len(cpu_devices) < 2: self.skipTest(f"Need at least two CPU devices, got: {len(cpu_devices)}") diff --git a/tests/compilation_cache_test.py b/tests/compilation_cache_test.py index 3fcc0ab476bf..3f1bb7fab4b1 100644 --- a/tests/compilation_cache_test.py +++ b/tests/compilation_cache_test.py @@ -134,7 +134,7 @@ def test_get_no_executable(self): backend = xla_bridge.get_backend() key = cc.get_cache_key(computation, devices, compile_options, backend) executable, compile_time = cc.get_executable_and_time( - key, compile_options, backend) + key, compile_options, backend, xc.DeviceList(tuple(devices.flat))) self.assertIsNone(executable) self.assertIsNone(compile_time) @@ -145,15 +145,20 @@ def test_diff_executables(self): num_replicas=1, num_partitions=1 ) backend = xla_bridge.get_backend() - executable1 = backend.compile(computation1, compile_options) - executable2 = backend.compile(computation2, compile_options) + executable_devices = xc.DeviceList(tuple(backend.local_devices())) + executable1 = backend.compile_and_load( + computation1, executable_devices, compile_options) + executable2 = backend.compile_and_load( + computation2, executable_devices, compile_options) cc.put_executable_and_time( "key1", "computation1", executable1, backend, FAKE_COMPILE_TIME) cc.put_executable_and_time( "key2", "computation2", executable2, backend, FAKE_COMPILE_TIME) self.assertNotEqual( - cc.get_executable_and_time("key1", compile_options, backend)[0], - cc.get_executable_and_time("key2", compile_options, backend)[0] + cc.get_executable_and_time( + "key1", compile_options, backend, executable_devices)[0], + cc.get_executable_and_time( + "key2", compile_options, backend, executable_devices)[0] ) def test_put_executable(self): @@ -167,12 +172,14 @@ def test_put_executable(self): num_replicas=1, num_partitions=1 ) backend = xla_bridge.get_backend() - executable = backend.compile(str(computation), compile_options) + executable_devices = xc.DeviceList(tuple(devices.flat)) + executable = backend.compile_and_load( + str(computation), executable_devices, compile_options) key = cc.get_cache_key(computation, devices, compile_options, backend) cc.put_executable_and_time( key, "alambda", executable, backend, FAKE_COMPILE_TIME) executable_retrieved, compile_time_retrieved = cc.get_executable_and_time( - key, compile_options, backend) + key, compile_options, backend, executable_devices) inputs_to_executable = ( jnp.array(1, dtype=np.int32), jnp.array(2, dtype=np.int32), @@ -237,7 +244,7 @@ def test_enable_compilation_cache(self): g = jit(lambda x: x * 3) g(2) cache = cc._get_cache(backend) - self.assertIsNotNone(cache) # Cache should be initalized + self.assertIsNotNone(cache) # Cache should be initialized def test_xla_autofdo_profile_version(self): original_profile_version = config.jax_xla_profile_version.value @@ -344,7 +351,8 @@ def test_cache_saving_metric(self): 
config.persistent_cache_min_entry_size_bytes(0), ): durations = Counter() # Map metric name to time duration. - def append_metric_duration(metric, duration): + def append_metric_duration(metric, duration, **kwargs): + del kwargs durations[metric] += duration with jtu.register_event_duration_listener(append_metric_duration): @@ -562,8 +570,9 @@ def test_backend_serialization_deserialization(self): .runtime_executable() ) serialized_executable = backend.serialize_executable(executable) - deserialized_executable = backend.deserialize_executable( - serialized_executable, None) + deserialized_executable = backend.deserialize_executable( # type: ignore + serialized_executable, + xc.DeviceList(tuple(jax.local_devices(backend=backend))), None) self.assertEqual( executable.fingerprint, deserialized_executable.fingerprint) diff --git a/tests/core_test.py b/tests/core_test.py index c46d493bda54..334df2222b0c 100644 --- a/tests/core_test.py +++ b/tests/core_test.py @@ -203,6 +203,13 @@ def test_is_valid_jaxtype(self, dtype): else: self.assertFalse(core.valid_jaxtype(arr)) + def test_str_aval(self): + aval = ShapedArray((8, 2), np.int32) + self.assertEqual(str(aval), "int32[8,2]") + + aval = ShapedArray((8, 2), np.int32, weak_type=True) + self.assertEqual(str(aval), "~int32[8,2]") + @parameterized.named_parameters( (str(i), *spec) for i, spec in enumerate(test_specs)) def test_jit(self, f, args): @@ -357,15 +364,15 @@ def g_vmap(x): def test_dropvar_avals(self): def f(x): def body(c, _): - return c, None + x1, x2 = c + return (2 * x1, 2 * x2), None (x1, x2), _ = jax.lax.scan(body, (x, x), None, length=1) return [x2] aval = core.ShapedArray((), jnp.dtype('int32')) pval = pe.PartialVal.unknown(aval) jaxpr, _, _ = pe.trace_to_jaxpr_nounits( - lu.wrap_init(f, - debug_info=debug_info("test", f, (0,), {})), + lu.wrap_init(f, debug_info=debug_info("test", f, (0,), {})), [pval], False) dropvar, b = jaxpr.eqns[0].outvars self.assertEqual(dropvar.aval, aval) @@ -397,6 +404,12 @@ def setUp(self): lax_control_flow._initial_style_jaxpr.cache_clear() lax_control_flow.common._pad_jaxpr_constvars.cache_clear() + def tearDown(self): + super().tearDown() + lax_control_flow._initial_style_open_jaxpr.cache_clear() + lax_control_flow._initial_style_jaxpr.cache_clear() + lax_control_flow.common._pad_jaxpr_constvars.cache_clear() + def test_check_jaxpr_correct(self): jaxpr = make_jaxpr(lambda x: jnp.sin(x) + jnp.cos(x))(1.).jaxpr core.check_jaxpr(jaxpr) @@ -405,6 +418,7 @@ def test_check_jaxpr_cond_correct(self): jaxpr = make_jaxpr(lambda x: lax.switch(0, [jnp.sin, jnp.cos], x))(1.).jaxpr core.check_jaxpr(jaxpr) + @jtu.thread_unsafe_test() # in-place mutation of possibly-cached jaxpr def test_check_jaxpr_jit_invalid(self): jaxpr = make_jaxpr(jax.jit(lambda x, y: x + 1))(1., 2.).jaxpr pjit_eqn, = jaxpr.eqns @@ -414,6 +428,7 @@ def test_check_jaxpr_jit_invalid(self): '0 operands cannot call jaxpr with 2 inputs', lambda: core.check_jaxpr(jaxpr)) + @jtu.thread_unsafe_test() # in-place mutation of possibly-cached jaxpr def test_check_jaxpr_cond_invalid(self): jaxpr = make_jaxpr(lambda x: lax.switch(0, [jnp.sin, jnp.cos], x))(1.).jaxpr cond = next(eqn for eqn in jaxpr.eqns if eqn.primitive.name == 'cond') @@ -433,6 +448,7 @@ def f(c, x): jaxpr = make_jaxpr(partial(lax.scan, f))(c, xs).jaxpr core.check_jaxpr(jaxpr) + @jtu.thread_unsafe_test() # in-place mutation of possibly-cached jaxpr def test_check_jaxpr_invalid_long(self): # jaxprs can be large, and this tests that when large ones are printed for # context in jaxpr typechecking 
errors, they're not printed entirely @@ -464,6 +480,7 @@ def g(x): self.assertIn('while checking jaxpr:', msg) self.assertLess(msg.count('\n'), 200) + @jtu.thread_unsafe_test() # in-place mutation of possibly-cached jaxpr def test_check_jaxpr_eqn_mismatch(self): def f(x): return jnp.sin(x) + jnp.cos(x) @@ -487,7 +504,7 @@ def new_jaxpr(): self.assertRaisesRegex( core.JaxprTypeError, r"Value for variable 'b' inconsistently typed as f32\[\] " - r"for let-binder of type i32\[\]\n\nin equation:\n\nb:i32\[\] = sin a", + r"for let-binder of type i32\[\]\n\nin equation:\n\nb:i32\[\] = sin\ a", lambda: core.check_jaxpr(jaxpr)) jaxpr = new_jaxpr() @@ -496,7 +513,7 @@ def new_jaxpr(): self.assertRaisesRegex( core.JaxprTypeError, r"Value for variable 'b' inconsistently typed as f32\[\] " - r"for let-binder of type f32\[2,3\]\n\nin equation:\n\nb:f32\[2,3\] = sin a", + r"for let-binder of type f32\[2,3\]\n\nin equation:\n\nb:f32\[2,3\] = sin\ a", lambda: core.check_jaxpr(jaxpr)) def test_jaxpr_dropvar_from_jit_call(self): @@ -534,15 +551,6 @@ def f(x): assert isinstance(jaxpr.eqns[-1].outvars[0], core.DropVar) core.check_jaxpr(jaxpr) - def test_jaxpr_undefined_eqn_invar(self): - jaxpr = make_jaxpr(lambda x: jnp.sin(x) + jnp.cos(x))(1.).jaxpr - cos = next(eqn for eqn in jaxpr.eqns if eqn.primitive.name == 'cos') - cos.invars[0] = core.gensym(suffix='_test')(cos.invars[0].aval) - self.assertRaisesRegex( - core.JaxprTypeError, - r"Variable '.+_test' not defined\n\nin equation:", - lambda: core.check_jaxpr(jaxpr)) - @jtu.with_config(jax_dynamic_shapes=True) class DynamicShapesTest(jtu.JaxTestCase): diff --git a/tests/custom_api_test.py b/tests/custom_api_test.py new file mode 100644 index 000000000000..a53995751f36 --- /dev/null +++ b/tests/custom_api_test.py @@ -0,0 +1,4736 @@ +# Copyright 2018 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import collections +from collections.abc import Callable +import concurrent.futures +import functools +from functools import partial +import itertools as it +import re +import unittest +import textwrap + +from absl.testing import absltest, parameterized +import numpy as np + +import jax +import jax.numpy as jnp +from jax import float0, grad, jit +from jax import lax +from jax import tree_util +from jax.ad_checkpoint import checkpoint as new_checkpoint +import jax.custom_batching +import jax.custom_derivatives +import jax.custom_transpose +import jax.experimental.custom_dce +from jax.errors import UnexpectedTracerError + +from jax._src import api +from jax._src import api_util +from jax._src import config +from jax._src import core +from jax._src import custom_derivatives +from jax._src import test_util as jtu +from jax._src.interpreters import partial_eval as pe + +config.parse_flags_with_absl() + + +class CustomJVPTest(jtu.JaxTestCase): + + def test_basic(self): + @jax.custom_jvp + def f(x): + return jnp.sin(x) + def f_jvp(primals, tangents): + x, = primals + g, = tangents + return f(x), 2 * jnp.cos(x) * g + f.defjvp(f_jvp) + + x = 3. 
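+    # Descriptive note (not in the original diff): the rule deliberately
+    # reports twice the true derivative, so the jvp and grad checks below
+    # expect 2 * cos(x) rather than cos(x).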
+ self.assertAllClose(f(x), jnp.sin(x)) + self.assertAllClose(api.jvp(f, (x,), (1.,)), + (jnp.sin(x), 2 * jnp.cos(x))) + self.assertAllClose(api.grad(f)(x), 2 * jnp.cos(x)) + + def test_invariance(self): + @jax.custom_jvp + def f(x): + return jnp.cos(2 * x) / 2. + def f_jvp(primals, tangents): + x, = primals + g, = tangents + return (f(x), 3 * g) + f.defjvp(f_jvp) + def f2(x): + y, _ = api.jvp(f, (x,), (x,)) + return y + def f3(x): + y, _ = api.jvp(f2, (x,), (x,)) + return y + x = 1. + self.assertAllClose(api.jvp(f, (x,), (x,)), + api.jvp(f2, (x,), (x,)), + check_dtypes=False) + self.assertAllClose(api.jvp(f, (x,), (x,)), + api.jvp(f3, (x,), (x,)), + check_dtypes=False) + + def test_python_control_flow(self): + @jax.custom_jvp + def f(x): + if x > 0: + return jnp.sin(x) + else: + return jnp.cos(x) + def f_jvp(primals, tangents): + x, = primals + g, = tangents + if x > 0: + return f(x), 2 * g + else: + return f(x), 3 * g + f.defjvp(f_jvp) + x = 2. + self.assertAllClose(f(x), jnp.sin(x)) + self.assertAllClose(f(-x), jnp.cos(-x)) + self.assertAllClose(api.jvp(f, (x,), (1.,)), + (jnp.sin(x), 2.), + check_dtypes=False) + self.assertAllClose(api.jvp(f, (-x,), (1.,)), + (jnp.cos(-x), 3.), + check_dtypes=False) + self.assertAllClose(api.grad(f)(x), 2., check_dtypes=False) + self.assertAllClose(api.grad(f)(-x), 3., check_dtypes=False) + + def test_vmap(self): + @jax.custom_jvp + def f(x): + assert jnp.ndim(x) == 0 + return jnp.sin(x) + def f_jvp(primals, tangents): + x, = primals + g, = tangents + assert jnp.ndim(x) == jnp.ndim(g) == 0 + return f(x), 2 * jnp.cos(x) * g + f.defjvp(f_jvp) + + x = jnp.arange(3.) + xx = jnp.arange(6.).reshape(2, 3) + + # vmap of f + self.assertAllClose(api.vmap(f)(x), jnp.sin(x)) + self.assertAllClose(api.vmap(api.vmap(f))(xx), jnp.sin(xx)) + + # vmap of jvp of f + self.assertAllClose(api.vmap(lambda x: api.jvp(f, (x,), (x,)))(x), + (jnp.sin(x), 2 * jnp.cos(x) * x)) + self.assertAllClose(api.vmap(api.vmap(lambda x: api.jvp(f, (x,), (x,))))(xx), + (jnp.sin(xx), 2 * jnp.cos(xx) * xx)) + + # jvp of vmap of f + self.assertAllClose(api.jvp(api.vmap(f), (x,), (x,)), + (jnp.sin(x), 2 * jnp.cos(x) * x)) + self.assertAllClose(api.jvp(api.vmap(api.vmap(f)), (xx,), (xx,)), + (jnp.sin(xx), 2 * jnp.cos(xx) * xx)) + + # vmap of jvp of vmap of f + self.assertAllClose(api.vmap(lambda x: api.jvp(api.vmap(f), (x,), (x,)))(xx), + (jnp.sin(xx), 2 * jnp.cos(xx) * xx)) + + def test_jit(self): + @jax.custom_jvp + def f(x): + return jnp.sin(x) + def f_jvp(primals, tangents): + x, = primals + g, = tangents + return f(x), 2 * jnp.cos(x) * g + f.defjvp(f_jvp) + + x = 3. 
+ + # jit + self.assertAllClose(api.jit(f)(x), jnp.sin(x)) + self.assertAllClose(api.jit(api.jit(f))(x), jnp.sin(x)) + + # jit of jvp + self.assertAllClose(api.jit(lambda x: api.jvp(f, (x,), (x,)))(x), + (jnp.sin(x), 2 * jnp.cos(x) * x), + check_dtypes=False) + + # jvp of jit + self.assertAllClose(api.jvp(api.jit(f), (x,), (x,)), + (jnp.sin(x), 2 * jnp.cos(x) * x), + check_dtypes=False) + + def test_pytrees(self): + @jax.custom_jvp + def f(x): + return {'b': jnp.sin(x['a'])} + def f_jvp(primals, tangents): + x, = primals + g, = tangents + return f(x), {'b': 2 * jnp.cos(x['a']) * g['a']} + f.defjvp(f_jvp) + x = {'a': 3.} + self.assertAllClose(f(x)['b'], jnp.sin(x['a'])) + self.assertAllClose(api.jvp(f, (x,), (x,)), + ({'b': jnp.sin(x['a'])}, + {'b': 2 * jnp.cos(x['a']) * x['a']}), + check_dtypes=False) + + def test_kwargs(self): + # from https://github.com/jax-ml/jax/issues/1938 + @jax.custom_jvp + def my_fun(x, y, c=1.): + return c * (x + y) + def my_jvp(primals, tangents): + x, y, c = primals + t_x, t_y, t_c = tangents + return my_fun(x, y, c), t_c + my_fun.defjvp(my_jvp) + f = lambda x, y: jnp.square(my_fun(x, y, c=2.)).sum() + f(10., 5.) # doesn't crash + api.jvp(f, (10., 5.), (1., 1.)) # doesn't crash + + def test_initial_style(self): + @jax.custom_jvp + def f(x): + return 3 * x + def f_jvp(primals, tangents): + x, = primals + g, = tangents + return f(x), 2 * g + f.defjvp(f_jvp) + + def foo(x): + out, _ = lax.scan(lambda c, _: (f(c), None), x, None, length=1) + return out + + ans = api.grad(foo)(3.) + expected = 2. + self.assertAllClose(ans, expected, check_dtypes=False) + + ans = api.grad(api.jit(foo))(3.) + expected = 2. + self.assertAllClose(ans, expected, check_dtypes=False) + + ans = api.jit(api.grad(foo))(3.) + expected = 2. + self.assertAllClose(ans, expected, check_dtypes=False) + + ans = api.grad(api.grad(foo))(3.) + expected = 0. + self.assertAllClose(ans, expected, check_dtypes=False) + + ans = api.grad(api.grad(api.jit(foo)))(3.) + expected = 0. + self.assertAllClose(ans, expected, check_dtypes=False) + + ans = api.grad(api.jit(api.grad(foo)))(3.) + expected = 0. + self.assertAllClose(ans, expected, check_dtypes=False) + + ans = api.jit(api.grad(api.grad(foo)))(3.) + expected = 0. + self.assertAllClose(ans, expected, check_dtypes=False) + + def test_initial_style_vmap(self): + @jax.custom_jvp + def f(x): + assert jnp.ndim(x) == 0 + return 3 * x + def f_jvp(primals, tangents): + x, = primals + g, = tangents + return f(x), 2 * g + f.defjvp(f_jvp) + + def foo(x): + out, _ = lax.scan(lambda c, _: (f(c), None), x, None, length=1) + return out + + ans = api.vmap(foo)(jnp.ones(3)) + expected = 3. * jnp.ones(3) + self.assertAllClose(ans, expected, check_dtypes=False) + + ans = api.vmap(api.jit(foo))(jnp.ones(3)) + expected = 3. * jnp.ones(3) + self.assertAllClose(ans, expected, check_dtypes=False) + + ans = api.jit(api.vmap(foo))(jnp.ones(3)) + expected = 3. * jnp.ones(3) + self.assertAllClose(ans, expected, check_dtypes=False) + + ans = api.grad(lambda x: api.vmap(foo)(x).sum())(jnp.ones(3)) + expected = 2. * jnp.ones(3) + self.assertAllClose(ans, expected, check_dtypes=False) + + ans = api.grad(lambda x: api.vmap(api.jit(foo))(x).sum())(jnp.ones(3)) + expected = 2. * jnp.ones(3) + self.assertAllClose(ans, expected, check_dtypes=False) + + ans = api.grad(lambda x: api.jit(api.vmap(foo))(x).sum())(jnp.ones(3)) + expected = 2. 
* jnp.ones(3) + self.assertAllClose(ans, expected, check_dtypes=False) + + ans = api.grad(api.jit(lambda x: api.vmap(foo)(x).sum()))(jnp.ones(3)) + expected = 2. * jnp.ones(3) + self.assertAllClose(ans, expected, check_dtypes=False) + + ans = api.jit(api.grad(lambda x: api.vmap(foo)(x).sum()))(jnp.ones(3)) + expected = 2. * jnp.ones(3) + self.assertAllClose(ans, expected, check_dtypes=False) + + def test_initial_style_vmap_with_collective(self): + + @jax.custom_jvp + def f(x): + return lax.psum(x, 'foo') + + @f.defjvp + def f_jvp(xs, ts): + x, = xs + t, = ts + return lax.psum(x, 'foo'), t + + def g(x): + jaxpr = api.make_jaxpr(f)(x) + return core.eval_jaxpr(jaxpr.jaxpr, [], x)[0] + + v = api.vmap(lambda _, x: g(x), axis_name='foo', in_axes=(0, None), + out_axes=None)(jnp.arange(4.), 2.) + self.assertAllClose(v, 8.) + + def test_closed_over_tracers_error_message(self): + def f(x): + @jax.custom_jvp + def g(y): + return x + y + def g_jvp(primals, tangents): + return g(x), 2 * primals[0] + g.defjvp(g_jvp) + return g(1.) + + self.assertRaises(UnexpectedTracerError, lambda: api.jvp(f, (3.,), (1.,))) + self.assertRaises(UnexpectedTracerError, lambda: api.grad(f)(3.)) + + def test_nondiff_argnums(self): + @partial(jax.custom_jvp, nondiff_argnums=(0,)) + def app(f, x): + return f(x) + def app_jvp(f, primals, tangents): + (x,), (t,) = primals, tangents + return app(f, x), 3 * t + app.defjvp(app_jvp) + + ans = app(lambda x: 2 * x, 1) + expected = 2 + self.assertAllClose(ans, expected, check_dtypes=False) + + ans = api.jvp(lambda x: app(lambda y: 2 * y, x), (1.,), (1.,)) + expected = (2., 3.) + self.assertAllClose(ans, expected, check_dtypes=False) + + def test_nondiff_argnames(self): + @partial(jax.custom_jvp, nondiff_argnames=('f',)) + def app(f, x): + return f(x) + + def app_jvp(f, primals, tangents): + (x,), (t,) = primals, tangents + return app(f, x), 3 * t + + app.defjvp(app_jvp) + + ans = app(lambda x: 2 * x, 1) + expected = 2 + self.assertAllClose(ans, expected, check_dtypes=False) + + def test_nondiff_arg_jit_tracer(self): + # This test would pass with "final-style" JIT tracing, but that was + # misleading: it doesn't work with "initial-style" staging, i.e. control + # flow primitives like jax.lax.scan or even pjit. The behavior isn't very + # useful either: instead of using nondiff_argnums here, a user can just pass + # such inputs as ordinary arguments, and ignore the corresponding tangents. + # Then nondiff_argnums can be reserved for (1) non jaxtype data (like a + # string- or callable-valued argument which parameterizes the function or + # rule) or (2) static data (e.g. integers which parameterize shapes). + raise unittest.SkipTest("behavior no longer supported") + + @partial(jax.custom_jvp, nondiff_argnums=(0,)) + def f(x, y): + return x * y + def f_jvp(x, primals, tangents): + (y,), (t_y,) = primals, tangents + return f(x, y), 5 * t_y + f.defjvp(f_jvp) + + @jit + def g(x, y): + return f(x, y) + + ans = api.jvp(lambda y: g(2., y), (3.,), (1.,)) + expected = (6., 5.) 
+ self.assertAllClose(ans, expected, check_dtypes=False) + + def test_nondiff_arg_vmap_tracer(self): + @partial(jax.custom_jvp, nondiff_argnums=(0,)) + def f(x, y): + return x * y + def f_jvp(x, primals, tangents): + (y,), (t_y,) = primals, tangents + return f(x, y), 5 * t_y + f.defjvp(f_jvp) + + g = jax.vmap(f) + + ans = api.jvp(lambda y: g(jnp.array([2.]), y), + (jnp.array([3.]),), (jnp.array([1.]),)) + expected = (jnp.array([6.]), jnp.array([5.])) + self.assertAllClose(ans, expected, check_dtypes=False) + + def test_nondiff_arg_hiding_jvp_tracer(self): + def f(x): + @partial(jax.custom_jvp, nondiff_argnums=(0,)) + def g(h, x): + return h(x) + @g.defjvp + def g_jvp(h, primals, tangents): + x, = primals + t, = tangents + return g(h, x), 2. * t + h = lambda y: x + y # capture x + return g(h, x) + + with self.assertRaises(UnexpectedTracerError): + api.jvp(f, (2.,), (1.,)) + + def test_vmap_axes(self): + raise unittest.SkipTest("TODO") # TODO(mattjj): write test + + def test_pmap(self): + raise unittest.SkipTest("TODO") # TODO(mattjj): write test + + def test_missing_jvp_rule_error_message(self): + @jax.custom_jvp + def foo(x): + return x ** 2 + + self.assertRaisesRegex( + AttributeError, + r"No JVP defined for custom_jvp function foo using defjvp.", + lambda: foo(2)) + self.assertRaisesRegex( + AttributeError, + r"No JVP defined for custom_jvp function foo using defjvp.", + lambda: api.jvp(foo, (2.,), (1.,))) + self.assertRaisesRegex( + AttributeError, + r"No JVP defined for custom_jvp function foo using defjvp.", + lambda: api.grad(foo)(2.)) + + def test_jvp_rule_inconsistent_pytree_structures_error_message(self): + @jax.custom_jvp + def f(x): + return (x**2,) + + @f.defjvp + def foo_jvp(primals, tangents): + x, = primals + t, = tangents + return f(x), [2 * x * t, x] + + f(2.) # doesn't crash + self.assertRaisesRegex( + TypeError, + re.escape( + "Custom JVP rule foo_jvp for function f " + "must produce primal and tangent outputs " + "with equal container (pytree) structures, but got " + "{} and {} respectively.".format( + jax.tree.structure((1,)), + jax.tree.structure([1, 2])) + ), + lambda: api.jvp(f, (2.,), (1.,))) + + def test_primal_tangent_aval_disagreement_error_message(self): + @jax.custom_jvp + def f(x): + return x ** 2 + + @f.defjvp + def foo_jvp(primals, tangents): + x, = primals + t, = tangents + return f(x), jnp.reshape(t, (1,)) + + f(2.) # doesn't crash + self.assertRaisesRegex( + TypeError, + re.escape( + "Custom JVP rule must produce primal and tangent outputs " + "with corresponding shapes and dtypes. " + "Expected float32[] (tangent type of float32[]) but got float32[1]."), + lambda: api.jvp(f, (jnp.float32(2.),), (jnp.float32(1.),))) + + + def test_jvp_rule_doesnt_return_pair_error_message(self): + # https://github.com/jax-ml/jax/issues/2516 + + @jax.custom_jvp + def f(x): + return x ** 2 + + @f.defjvp + def foo_jvp(primals, tangents): + x, = primals + t, = tangents + return t + + f(2.) 
# doesn't crash + self.assertRaisesRegex( + TypeError, + re.escape( + "Custom JVP rule foo_jvp for function f " + "must produce a pair (list or tuple of length two) " + "representing primal and tangent outputs, but got 1.0"), + lambda: api.jvp(f, (2.,), (1.,))) + + def test_jvp_rule_primal_out_type_doesnt_match_primal_error_message(self): + # https://github.com/lucidrains/flash-attention-jax/issues/7 + + def scan_apply(f, x): + y, _ = jax.lax.scan(lambda x, _: (f(x), None), x, None, length=1) + return y + + @jax.custom_jvp + def f(x): + return x + + @f.defjvp + def f_jvp(primals, tangents): + (x,), (xdot,) = primals, tangents + return (x, x), (xdot, xdot) + + x = jnp.float32(1.) + self.assertRaisesRegex( + TypeError, + re.escape( + "Custom JVP rule f_jvp for function f must produce a pair " + "(list or tuple of length two) where the first element represents " + "the primal output (equal in value to the output of the " + "custom_jvp-decorated function f, and in particular of the " + "same container/pytree structure), but instead the JVP rule " + "output's first element had container/pytree structure:\n" + " (float32[], float32[])\n" + "while the custom_jvp-decorated function f had output " + "container/pytree structure:\n" + " float32[]." + ), + lambda: jax.jvp(lambda x: scan_apply(f, x), (x,), (x,))) + + @f.defjvp + def f_jvp2(primals, tangents): + (x,), (xdot,) = primals, tangents + return jnp.zeros((3, *x.shape), x.dtype), xdot + + self.assertRaisesRegex( + TypeError, + re.escape( + "Custom JVP rule f_jvp2 for function f must produce a pair " + "(list or tuple of length two) where the first element represents " + "the primal output (equal in value to the output of the " + "custom_jvp-decorated function f, and in particular " + "with leaves of the same shape/dtype), but instead the JVP rule " + "output's first element had shapes/dtypes of:\n" + " float32[3]\n" + "while the custom_jvp-decorated function f had output shapes/dtypes" + " of:\n" + " float32[]" + ), + lambda: jax.jvp(lambda x: scan_apply(f, x), (x,), (x,))) + + def test_multiple_rule_invocations(self): + @jax.custom_jvp + def expit(x): + return 1 / (1 + lax.exp(-x)) + + @expit.defjvp + def _expit_jvp(primals, tangents): + (x,), (t,) = primals, tangents + ans = expit(x) + t_out = t * ans * (1 - ans) + return ans, t_out + + def scanned_fun(c, _): + return [expit(c[0])] + [c[i-1] + c[i] for i in range(1, len(c))], None + + def foo(x): + zero = jnp.zeros_like(x) + c, _ = lax.scan(scanned_fun, [x, zero, zero, zero, zero], None, length=10) + return c[-1] + + # just make sure these don't crash + foo(3.) + grad(foo)(3.) 
+ grad(lambda x: jax.vmap(foo)(x).sum())(jnp.arange(3.)) + + def test_hard_stuff(self): + arr = jnp.ones((5, 2, 2)) + api.jit(jax.vmap(jnp.linalg.det))(arr) # doesn't crash + + def test_hard_stuff2(self): + @jax.custom_jvp + def f(x): + return np.zeros(x.shape, x.dtype) + + @f.defjvp + def f_jvp(primals, tangents): + x, = primals + t, = tangents + return f(x), t + + # don't crash + jax.jit(jax.vmap(f))(jnp.arange(3.)) + jax.jit(jax.vmap(jax.grad(f)))(jnp.arange(3.)) + jax.jit(jax.grad(lambda x: jax.vmap(f)(x).sum()))(jnp.arange(3.)) + jax.grad(lambda x: jax.vmap(f)(x).sum())(jnp.arange(3.)) + jax.jvp(jax.vmap(f), (jnp.arange(3.),), (jnp.ones(3),)) + + def test_hard_stuff3(self): + @jax.custom_jvp + def relu(x): + return jnp.maximum(x, 0) + + @relu.defjvp + def _relu_jvp(primals, tangents): + x, = primals + t, = tangents + return relu(x), lax.select(x > 0, t, lax.full_like(t, 0)) + + def scanned_fun(c, _): + return [relu(c[0])] + [c[i-1] + c[i] for i in range(1, len(c))], None + + def f(x): + zero = jnp.zeros_like(x) + c, _ = lax.scan(scanned_fun, [x, zero, zero, zero, zero], None, length=10) + return c[-1] + + # don't crash + jax.jit(jax.vmap(f))(jnp.arange(3.)) + jax.jit(jax.vmap(jax.grad(f)))(jnp.arange(3.)) + jax.jit(jax.grad(lambda x: jax.vmap(f)(x).sum()))(jnp.arange(3.)) + jax.grad(lambda x: jax.vmap(f)(x).sum())(jnp.arange(3.)) + jax.jvp(jax.jit(jax.vmap(f)), (jnp.arange(3.),), (jnp.ones(3),)) + + def test_eval_shape(self): + @jax.custom_jvp + def expit(x): + return 1 / (1 + lax.exp(-x)) + + @expit.defjvp + def _expit_jvp(primals, tangents): + (x,), (t,) = primals, tangents + ans = expit(x) + t_out = t * ans * (1 - ans) + return ans, t_out + + # don't crash + api.eval_shape(expit, jnp.ones((2, 3))) + api.eval_shape(api.grad(lambda x: expit(x).sum()), jnp.ones((2, 3))) + + def test_jaxpr_zeros(self): + # from https://github.com/jax-ml/jax/issues/2657 + @jax.custom_jvp + def f(A, b): + return A @ b + + def f_jvp(primals, tangents): + A, b = primals + dA, db = tangents + z = f(A, b) + dz = A @ db + dA @ b + return z, dz + + f.defjvp(f_jvp) + + def experiment(theta): + def step(q, _): + z = f(jnp.eye(3), jnp.ones(3) * theta) + q += z[0] + return q, q + + q = 0. + q, _ = lax.scan(step, q, None, 4) + return q + + grad(experiment)(1.) # doesn't crash + + def test_linear_in_scan(self): + @jax.custom_jvp + def f(x): + return -x + + @f.defjvp + def f_jvp(primals, tangents): + x, = primals + x_dot, = tangents + return f(x), f(x_dot) + + def foo(x): + out, _ = lax.scan(lambda c, _: (f(c), None), x, None, length=1) + return out + + ans = api.grad(foo)(3.) + expected = -1. + self.assertAllClose(ans, expected, check_dtypes=False) + + def test_custom_jvps_first_rule_is_none(self): + # https://github.com/jax-ml/jax/issues/3389 + @jax.custom_jvp + def f(x, y): + return x ** 2 * y + + f.defjvps(None, lambda x_dot, primal_out, x, y: 2 * x * y * x_dot) + ans = grad(f, 1)(2., 3.) # doesn't crash + expected = 12. 
+ self.assertAllClose(ans, expected, check_dtypes=False) + + def test_concurrent_initial_style(self): + # https://github.com/jax-ml/jax/issues/3843 + def unroll(param, sequence): + def scan_f(prev_state, inputs): + return prev_state, jax.nn.sigmoid(param * inputs) + return jnp.sum(jax.lax.scan(scan_f, None, sequence)[1]) + + def run(): + return jax.grad(unroll)(jnp.array(1.0), jnp.array([1.0])) + + expected = run() + + # we just don't want this to crash + n_workers = 2 + with concurrent.futures.ThreadPoolExecutor(max_workers=n_workers) as e: + futures = [] + for _ in range(n_workers): + futures.append(e.submit(run)) + results = [f.result() for f in futures] + for ans in results: + self.assertAllClose(ans, expected) + + def test_nondiff_argnums_vmap_tracer(self): + # https://github.com/jax-ml/jax/issues/3964 + @partial(jax.custom_jvp, nondiff_argnums=(0, 2)) + def sample(shape, param, seed): + return jax.random.uniform(key=seed, shape=shape, minval=param) + + @sample.defjvp + def sample_jvp(shape, seed, primals, tangents): + param, = primals + dparam, = tangents + dparam = jnp.broadcast_to(dparam, shape) + samples = sample(shape, param, seed) + return samples, samples * dparam # dummy jvp for proof of concept + + # check these don't crash + jax.vmap(lambda seed: sample((2,3), 1., seed))( + jax.random.split(jax.random.key(1), 10)) + jax.jvp(lambda x: sample((2, 3), x, jax.random.key(1)), + (1.,), (1.,)) + + def test_fun_with_nested_calls_2(self): + def call(f, *args): + f = jax.custom_jvp(f) + f.defjvp(lambda primals, tangents: (f(*primals), sum(tangents))) + return f(*args) + + def fun_with_nested_calls_2(x): + def bar(y): + def baz(w): + q = call(lambda x: y, x) + q = q + call(lambda: y) + q = q + call(lambda y: w + y, y) + q = call(lambda w: call(jnp.sin, x) * y, 1.0) + q + return q + return api.jit(baz)(x) + return call(bar, x) + + # test these don't crash + self.assertAllClose(api.jit(fun_with_nested_calls_2)(3.), + fun_with_nested_calls_2(3.)) + api.vmap(fun_with_nested_calls_2)(jnp.arange(3.)) + + def test_closure_with_vmap(self): + # https://github.com/jax-ml/jax/issues/3822 + alpha = np.float32(2.) 
+ + def sample(seed): + @jax.custom_jvp + def f(alpha): + return jax.random.gamma(seed, alpha, shape=[]) + + @f.defjvp + def f_jvp(primal, tangent): + alpha = primal + dalpha = tangent + sample = f(alpha) + partial_alpha = lax.random_gamma_grad(alpha, sample) + return sample, partial_alpha * dalpha + return f(alpha) + + api.vmap(sample)(jax.random.split(jax.random.key(1), 3)) # don't crash + + def test_closure_with_vmap2(self): + # https://github.com/jax-ml/jax/issues/8783 + def h(z): + def f(x): + @jax.custom_jvp + def g(y): + return x * y + + # NOTE: rule closes over vmap tracer + @g.defjvp + def g_jvp(primals, tangents): + (y,), (ydot,) = primals, tangents + return x * y, x * ydot + + return g(z) # NOTE: no vmapped arg + + return jax.vmap(f)(jnp.arange(3., dtype='float32')) + + primals, tangents = jax.jvp(h, (jnp.float32(1.),), (jnp.float32(2.),)) + self.assertAllClose(primals , jnp.arange(3., dtype='float32')) + self.assertAllClose(tangents, 2 * jnp.arange(3., dtype='float32')) + + def test_float0(self): + scalar_float0 = jnp.zeros((), dtype=float0) + @jax.custom_jvp + def f(x, y): + return x, y + def f_jvp(primals, _): + x, y = primals + return (x, y), (2., jax.custom_derivatives.zero_from_primal(y)) + f.defjvp(f_jvp) + + primals = (2., 3) + tangents = (np.ones(()), scalar_float0) + expected_tangents = (2., scalar_float0) + self.assertAllClose(api.jvp(f, primals, tangents), + (primals, expected_tangents)) + + def test_float0_initial_style(self): + scalar_float0 = jnp.zeros((), dtype=float0) + @jax.custom_jvp + def f(x, y): + return x, y + def f_jvp(primals, _): + x, y = primals + return (x, y), (2., jax.custom_derivatives.zero_from_primal(y)) + f.defjvp(f_jvp) + + def foo(x, y): + out, _ = lax.scan(lambda c, _: (f(*c), None), (x, y), None, length=1) + return out + + primals = (2., 3) + tangents = (np.ones(()), scalar_float0) + expected_tangents = (2., scalar_float0) + + self.assertAllClose(api.jvp(foo, primals, tangents), + (primals, expected_tangents)) + + def test_remat(self): + @jax.custom_jvp + def f(x): + return jnp.sin(x) + def f_jvp(primals, tangents): + x, = primals + g, = tangents + return f(x), 2 * jnp.cos(x) * g + f.defjvp(f_jvp) + + @jax.remat + def g(x): + return f(f(x)) + + ans = g(2.) + expected = np.sin(np.sin(2.)) + self.assertAllClose(ans, expected, check_dtypes=False) + + ans = api.grad(g)(2.) + expected = 4. * api.grad(lambda x: jnp.sin(jnp.sin(x)))(2.) + self.assertAllClose(ans, expected, check_dtypes=False) + + def test_remat_higher_order(self): + @jax.custom_jvp + def f(x): + return jnp.sin(x) + def f_jvp(primals, tangents): + x, = primals + g, = tangents + return f(x), 2 * jnp.cos(x) * g + f.defjvp(f_jvp) + + def g(x): + return f(f(x)) + + ans = api.grad(api.grad(new_checkpoint(g)))(2.) + expected = api.grad(api.grad(g))(2.) + self.assertAllClose(ans, expected, check_dtypes=False) + + ans = api.grad(new_checkpoint(api.grad(g)))(2.) + expected = api.grad(api.grad(g))(2.) + self.assertAllClose(ans, expected, check_dtypes=False) + + ans = api.grad(api.grad(api.grad(new_checkpoint(g))))(2.) + expected = api.grad(api.grad(api.grad(g)))(2.) + self.assertAllClose(ans, expected, check_dtypes=False) + + def test_initial_style_vmap_2(self): + # This is like test_initial_style_vmap except the primal function closes + # over an array constant. + y = jnp.arange(1., 4.) 
+ + @jax.custom_jvp + def f(x): + assert jnp.ndim(x) == 0 + return 3 * x * jnp.sum(y) + def f_jvp(primals, tangents): + x, = primals + g, = tangents + return f(x), 2 * g + f.defjvp(f_jvp) + + def foo(x): + out, _ = lax.scan(lambda c, _: (f(c), None), x, None, length=1) + return out + + ans = api.grad(lambda x: api.vmap(foo)(x).sum())(jnp.ones(3)) + expected = 2. * jnp.ones(3) + self.assertAllClose(ans, expected, check_dtypes=False) + + ans = api.grad(lambda x: api.vmap(api.jit(foo))(x).sum())(jnp.ones(3)) + expected = 2. * jnp.ones(3) + self.assertAllClose(ans, expected, check_dtypes=False) + + ans = api.grad(lambda x: api.jit(api.vmap(foo))(x).sum())(jnp.ones(3)) + expected = 2. * jnp.ones(3) + self.assertAllClose(ans, expected, check_dtypes=False) + + ans = api.grad(api.jit(lambda x: api.vmap(foo)(x).sum()))(jnp.ones(3)) + expected = 2. * jnp.ones(3) + self.assertAllClose(ans, expected, check_dtypes=False) + + ans = api.jit(api.grad(lambda x: api.vmap(foo)(x).sum()))(jnp.ones(3)) + expected = 2. * jnp.ones(3) + self.assertAllClose(ans, expected, check_dtypes=False) + + def test_custom_jvp_vmap_broadcasting_interaction(self): + # https://github.com/jax-ml/jax/issues/6452 + def f2(y, z): + v1 = z + v2 = jnp.sum(y) + z + return jnp.logaddexp(v1, v2) + + def f1(y, z): + v = api.vmap(lambda _y: f2(_y, z))(y) + return jnp.sum(v) + + y = jnp.ones((3, 2)) + f = lambda z: f1(y, z) + z = 0.1 + val, g = api.value_and_grad(f)(z) + self.assertEqual(val.shape, ()) + self.assertEqual(g.shape, ()) + + def test_custom_jvp_vmap_broadcasting_interaction_2(self): + # https://github.com/jax-ml/jax/issues/5849 + @jax.custom_jvp + def transform(box, R): + if jnp.isscalar(box) or box.size == 1: + return R * box + elif box.ndim == 2: + return jnp.einsum('ij,j->i', box, R) + raise ValueError() + + @transform.defjvp + def transform_jvp(primals, tangents): + box, R = primals + dbox, dR = tangents + return (transform(box, R), dR + transform(dbox, R)) + + def periodic_general(box): + def displacement_fn(Ra, Rb, **kwargs): + _box = kwargs.get('box', box) + return transform(_box, Ra - Rb) + + return displacement_fn + + N = 250 + + scalar_box = 1.0 + displacement = periodic_general(scalar_box) + + key = jax.random.key(0) + R = jax.random.uniform(key, (N, 2)) + + def energy_fn(box): + d = partial(displacement, box=box) + d = api.vmap(api.vmap(d, (None, 0)), (0, None)) + return jnp.sum(d(R, R) ** 2) + + self.assertEqual(grad(energy_fn)(scalar_box).shape, ()) + + def test_custom_jvp_implicit_broadcasting(self): + # https://github.com/jax-ml/jax/issues/6357 + if config.enable_x64.value: + raise unittest.SkipTest("test only applies when x64 is disabled") + + @jax.custom_jvp + def projection_unit_simplex(x: jax.Array) -> jax.Array: + """Projection onto the unit simplex.""" + s = 1.0 + n_features = x.shape[0] + u = jnp.sort(x)[::-1] + cssv = jnp.cumsum(u) - s + ind = jnp.arange(n_features, dtype=x.dtype) + 1 + cond = u - cssv / ind > 0 + idx = jnp.count_nonzero(cond) + threshold = cssv[idx - 1] / idx.astype(x.dtype) + return jax.nn.relu(x - threshold) + + + @projection_unit_simplex.defjvp + def projection_unit_simplex_jvp(primals, tangents): + x, = primals + x_dot, = tangents + primal_out = projection_unit_simplex(x) + supp = (primal_out > 0).astype(x_dot.dtype) + card = jnp.count_nonzero(supp).astype(x_dot.dtype) + tangent_out = supp * x_dot - (jnp.dot(supp, x_dot) / card) * supp + return primal_out, tangent_out + + rng = self.rng() + x = rng.rand(5).astype(np.float32) + + J_rev = jax.jacrev(projection_unit_simplex)(x) + 
J_fwd = jax.jacfwd(projection_unit_simplex)(x) + + p = projection_unit_simplex(x) + support = (p > 0).astype(jnp.float32) + cardinality = jnp.count_nonzero(support).astype(support.dtype) + J_true = jnp.diag(support) - jnp.outer(support, support) / cardinality + self.assertAllClose(J_true, J_fwd) + self.assertAllClose(J_true, J_rev) + + proj = jax.vmap(projection_unit_simplex) + + def fun(X): + return jnp.sum(proj(X) ** 2) + + rng = self.rng() + X = rng.rand(4, 5).astype(np.float32) + U = rng.rand(4, 5) + U /= np.sqrt(np.sum(U ** 2)) + U = U.astype(np.float32) + + eps = 1e-3 + dir_deriv_num = (fun(X + eps * U) - fun(X - eps * U)) / (2 * eps) + dir_deriv = jnp.vdot(jax.grad(fun)(X), U) + self.assertAllClose(dir_deriv, dir_deriv_num, atol=1e-3) + + def test_vmap_inside_defjvp(self): + # https://github.com/jax-ml/jax/issues/3201 + seed = 47 + key = jax.random.key(seed) + mat = jax.random.normal(key, (2, 3)) + + @jax.custom_jvp + def f(mat, aux): + num_rows, num_cols = mat.shape + return jnp.ones((num_rows, 1)) / num_cols + + @f.defjvp + def f_jvp(primals, tangents): + mat, aux = primals + vec, _ = tangents + output = f(*primals) + num_rows, num_cols = mat.shape + size = num_rows * num_cols + # ----- + bd_mat = mat.reshape(1, 1, num_rows, num_cols) + bd_mat = jnp.tile(bd_mat, reps=(num_rows, num_cols)) + bd_mat = bd_mat.reshape(size, num_rows, num_cols) + # ----- + rowsum = jnp.sum(mat, axis=1, keepdims=True) + colsum = jnp.sum(mat, axis=0, keepdims=True) + bd_rowsum = jnp.tile(rowsum, reps=(1, num_rows)) + bd_colsum = jnp.tile(colsum, reps=(num_cols, 1)) + # ----- + bd_vec = vec.reshape(size, 1) + # ----- + def operate(mx, val): + buf = 0 + for i in range(2): + buf = buf + jnp.matmul(mx, bd_colsum) / jnp.power(aux, i) + buf = jnp.matmul(bd_rowsum, buf) + return buf * val[None, :] + # ----- + # Vectorizing will raise a shape error + bd_buf = jax.vmap(operate, in_axes=(0, 0), out_axes=0)(bd_mat, bd_vec) + # ----- + bd_buf = bd_buf / aux + jvp = jnp.sum(bd_buf, axis=0) + jvp = jnp.mean(jvp, axis=1, keepdims=True) + # ----- + # JVP ends successfully, but still raises an error + return (output, jvp) + + jax.grad(lambda mat, aux: jnp.sum(f(mat, aux)))(mat, 0.5) # doesn't crash + + def test_custom_jvp_unbroadcasting(self): + # https://github.com/jax-ml/jax/issues/3056 + a = jnp.array([1., 1.]) + + @jax.custom_jvp + def f(x): + return a * x + + @f.defjvp + def f_jvp(primals, tangents): + x, = primals + dx, = tangents + return a * x, a * dx + + shape = grad(lambda x: jnp.sum(f(x)))(jnp.array(1.)).shape + self.assertEqual(shape, ()) + + def test_maybe_perturbed_internal_helper_function(self): + # This is a unit test for an internal API. We include it so as not to + # regress https://github.com/jax-ml/jax/issues/9567. For an explanation of + # this helper function, see https://github.com/jax-ml/jax/issues/6415.
+ def f(x): + def g(y, _): + z = y * x + self.assertTrue(custom_derivatives._maybe_perturbed(z)) + return y, None + g(1, None) + return lax.scan(g, 1, xs=None, length=1)[0] + + jax.jvp(f, (1.0,), (1.0,)) # assertions inside f + + def test_maybe_perturbed_int_regression(self): + # see https://github.com/jax-ml/jax/discussions/9951 + + @jax.jit + def f(): + x = jnp.array(1) + _, aux_args = custom_derivatives.closure_convert(lambda: x) + self.assertEmpty(aux_args) + f() + + def test_sinc_constant_function_batching(self): + # https://github.com/jax-ml/jax/pull/10756 + batch_data = jnp.arange(15.).reshape(5, 3) + + @jax.vmap + def f(x): + return jax.lax.map(jnp.sinc, x) + g = lambda param: f(param * batch_data).sum() + + @jax.vmap + def f_ref(x): + return jnp.stack([jnp.sinc(x_) for x_ in x]) + g_ref = lambda param: f_ref(param * batch_data).sum() + + grad = jax.grad(g )(0.1) # doesn't crash + grad_ref = jax.grad(g_ref)(0.1) + self.assertAllClose(grad, grad_ref, check_dtypes=False) + + @parameterized.named_parameters( + ('jit_vmap', True, True), + ('jit', True, False), + ('vmap', False, True), + ('', False, False), + ) + def test_symbolic_zero_custom_jvp(self, maybe_jit, maybe_vmap): + def f(static_scalar, static_array, dyn_scalar, dyn_array): + out1 = static_scalar + dyn_scalar + out2 = static_array + dyn_array + return out1, out2 + + def _pack(x): + return lax.broadcast(x, (1,)) + + def _unpack(x): + (x,) = x + return x + + def _vmap(fun): + def _fun(*args): + args = jax.tree.map(_pack, args) + out = jax.vmap(fun)(*args) + out = jax.tree.map(_unpack, out) + return out + return _fun + + f = jax.custom_jvp(f) + + @partial(f.defjvp, symbolic_zeros=True) + def f_jvp(primals, tangents): + static_scalar, *_ = primals + t_static, t_static_arr, t_dyn_scalar, t_dyn_array = tangents + self.assertIs(type(t_static) , jax.custom_derivatives.SymbolicZero) + self.assertIs(type(t_static_arr), jax.custom_derivatives.SymbolicZero) + self.assertEqual(t_static.shape, ()) + self.assertEqual(t_static_arr.shape, (2,)) + return f(*primals), (static_scalar + 90, t_dyn_array + 91) + + def g(dyn_scalar, dyn_array): + if maybe_vmap: + f_ = _vmap(f) + else: + f_ = f + return f_(1., jnp.array([2., 3.]), dyn_scalar, dyn_array) + + def run(primal_ins, tangent_ins): + return jax.jvp(g, primal_ins, tangent_ins) + + if maybe_jit: + run = jax.jit(run) + + primal_ins = (4., jnp.array([5., 6.])) + tangent_ins = (7., jnp.array([8., 9.])) + primal_outs, tangent_outs = run(primal_ins, tangent_ins) + primal_out1, primal_out2 = primal_outs + tangent_out1, tangent_out2 = tangent_outs + scalar_type = jax.Array if maybe_jit or maybe_vmap else float + self.assertIsInstance(primal_out1, scalar_type) + self.assertAllClose(primal_out1, 5.) + self.assertIsInstance(tangent_out1, scalar_type) + self.assertAllClose(tangent_out1, 91.) 
+ self.assertIsInstance(primal_out2, jax.Array) + self.assertArraysAllClose(primal_out2, jnp.array([7., 9.])) + self.assertIsInstance(tangent_out2, jax.Array) + self.assertArraysAllClose(tangent_out2, jnp.array([99., 100.])) + + def test_symbolic_zero_custom_jvp_vmap_output(self): + @jax.custom_jvp + def f(x, y): + return x * y + + @partial(f.defjvp, symbolic_zeros=True) + def f_jvp(primals, tangents): + x, y = primals + x_dot, y_dot = tangents + self.assertIs(type(y_dot), jax.custom_derivatives.SymbolicZero) + return f(x, y), y_dot + + jax.grad(lambda x, y: jax.vmap(f)(x, y).sum())(jnp.ones(3), jnp.ones(3)) + + def test_symbolic_zeros_memoization_caching(self): + # Tests multiple zero patterns for partial_eval._memoize, and also tests + # that we're okay with stores being occupied with equal values. + + @jax.custom_jvp + def f(x, y): + return x * y + + @partial(f.defjvp, symbolic_zeros=True) + def f_jvp(primals, tangents): + x, y = primals + x_dot, y_dot = tangents + return f(x, y), y_dot + + f_ = core.jaxpr_as_fun(jax.make_jaxpr(f)(2., 3.)) + _ = jax.linearize(f_, 2., 3.) + _ = jax.linearize(lambda x: f_(x, 3.), 2.) # don't crash! + + def test_symbolic_zeros_under_jit(self): + # https://github.com/jax-ml/jax/issues/14833 + Zero = jax.custom_derivatives.SymbolicZero + + @jax.custom_jvp + def f(x, y): + return x * y + + @partial(f.defjvp, symbolic_zeros=True) + def fjvp(primals, tangents): + x, y = primals + tx, ty = tangents + assert type(tx) is not Zero or type(ty) is not Zero + return f(x, y), ( + ty if type(tx) is Zero else + tx if type(ty) is Zero else + tx + ty) + + jax.jacfwd(jax.jit(f))(0.1, 0.2) # don't crash + + def test_custom_jvp_functools_partial(self): + def fun(x, y, a): + return x + y * a + + fun_wrapped = functools.partial(fun, a = 0.1) + + def jvp_fn(primals, tangents): + return jax.jvp(fun_wrapped, primals, tangents) + + fn = jax.custom_jvp(fun_wrapped) + fn.defjvp(jvp_fn) + + self.assertEqual((1.0, 0.1), jax.grad(lambda args: fn(*args))((1.0, 2.0))) + + def test_run_rules_more_than_once(self): + # https://github.com/jax-ml/jax/issues/16614 + + @jax.custom_jvp + def f(x, y): + return x + + @partial(f.defjvp, symbolic_zeros=True) + def f_jvp(primals, tangents): + x, _ = primals + x_dot, _ = tangents + return x, x_dot + + def body(x_y, _): + x, y = x_y + return (f(x, y), x), None + + @jax.grad + def g(x): + (out, _), _ = lax.scan(body, (x, 1.), xs=None, length=2) + return out + + g(1.) 
# doesn't crash + + def test_dce(self): + @jax.custom_jvp + def f(x, y): + return jnp.sin(x), x + jnp.cos(y) + + @f.defjvp + def f_jvp(primals, tangents): + x, y = primals + dx, dy = tangents + return f(x, y), (2.0 * jnp.cos(x) * dx, 1.5 * dx - 0.5 * jnp.sin(y) * dy) + + def check_jaxpr(jaxpr, used_outs, includes, excludes): + dce_jaxpr, _ = pe.dce_jaxpr(jaxpr, used_outs) + if not dce_jaxpr.eqns: + assert not includes + return + call_jaxpr = dce_jaxpr.eqns[0].params["call_jaxpr"] + for prim in includes: + assert any(eqn.primitive == prim for eqn in call_jaxpr.eqns) + for prim in excludes: + assert all(eqn.primitive != prim for eqn in call_jaxpr.eqns) + + x, y = 0.1, -1.3 + jaxpr = jax.make_jaxpr(f)(x, y).jaxpr + check_jaxpr(jaxpr, [True, True], [lax.sin_p, lax.cos_p], []) + check_jaxpr(jaxpr, [True, False], [lax.sin_p], [lax.cos_p]) + check_jaxpr(jaxpr, [False, True], [lax.cos_p], [lax.sin_p]) + check_jaxpr(jaxpr, [False, False], [], [lax.sin_p, lax.cos_p]) + + def dce_jaxpr_as_fun(jaxpr, used_outs): + jaxpr_, _ = pe.dce_jaxpr(jaxpr, used_outs) + fun = core.jaxpr_as_fun(pe.close_jaxpr(jaxpr_)) + return lambda *args: fun(*args)[0] + + f0 = dce_jaxpr_as_fun(jaxpr, [True, False]) + f1 = dce_jaxpr_as_fun(jaxpr, [False, True]) + self.assertAllClose( + api.jvp(f0, (x, y), (1.0, 0.0)), (f0(x, y), 2.0 * jnp.cos(x))) + self.assertAllClose( + api.jvp(f0, (x, y), (0.0, 1.0)), (f0(x, y), 0.0)) + self.assertAllClose( + api.jvp(f1, (x, y), (1.0, 0.0)), (f1(x, y), 1.5)) + self.assertAllClose( + api.jvp(f1, (x, y), (0.0, 1.0)), (f1(x, y), -0.5 * jnp.sin(y))) + + def test_resolve_kwargs_error_message(self): + @jax.custom_jvp + def f(x, y, *, z=None): + return jnp.sin(x), x + jnp.cos(y) + + @f.defjvp + def f_jvp(primals, tangents): + self.fail("should not be executed") + + with self.assertRaisesRegex( + TypeError, + r"The input arguments to the custom_jvp-decorated function f(.*)\n" + r"missing a required argument: 'y'" + ): + f(0.5) + + with self.assertRaisesRegex( + TypeError, + r"The input arguments to the custom_jvp-decorated function f(.*)\n" + "The following keyword arguments could not be resolved to positions: z" + ): + f(0.5, 0.1, z=1.0) + + def test_symbolic_zero_custom_jvp_vmap_doesnt_instantiate(self): + @jax.custom_jvp + def f(x, y): + return y + + def f_jvp(primals, tangents): + (x, y), (x_dot, y_dot) = primals, tangents + assert type(y_dot) is jax.custom_derivatives.SymbolicZero + return y, y_dot + + f.defjvp(f_jvp, symbolic_zeros=True) + + def g(x): + return f(x, f(x, 1.)) + + jax.jvp(jax.vmap(g), (jnp.ones(3),), (jnp.ones(3),)) # don't crash + + def test_symbolic_zero_under_vmap_of_jit(self): + # https://github.com/jax-ml/jax/issues/28144 + @jax.custom_jvp + def f(x): + return x + 1 + + @f.defjvp + def f_jvp(x, t): + (x,) = x + (t,) = t + z = jax.custom_derivatives.zero_from_primal(x, symbolic_zeros=True) + return f(x), z + + x = jnp.arange(3.0) + jax.jvp(jax.vmap(jax.jit(f)), (x,), (x,)) # doesn't crash + + def test_pretty_print(self): + @jax.custom_jvp + def f(x): + return x + 1 + + @f.defjvp + def f_jvp(primals, tangents): + return f(*primals), tangents[0] + + x = jnp.array([4.2], dtype=jnp.float32) + jaxpr = jax.make_jaxpr(f)(x) + actual = jaxpr.pretty_print(use_color=False) + expected = textwrap.dedent( + """ + { lambda ; a:f32[1]. let + b:f32[1] = custom_jvp_call[ + name=f + call_jaxpr={ lambda ; c:f32[1]. 
let d:f32[1] = add c 1.0:f32[] in (d,) } + jvp=f_jvp + symbolic_zeros=False + ] a + in (b,) } + """).strip() + self.assertEqual(actual, expected) + + + +class CustomVJPTest(jtu.JaxTestCase): + + def test_basic(self): + @jax.custom_vjp + def f(x): + return jnp.sin(x) + def f_fwd(x): + return f(x), jnp.cos(x) + def f_rev(cos_x, g): + return (2 * cos_x * g,) + f.defvjp(f_fwd, f_rev) + + x = 3. + self.assertAllClose(f(x), jnp.sin(x)) + self.assertAllClose(api.grad(f)(x), 2 * jnp.cos(x)) + self.assertAllClose(api.value_and_grad(f)(x), + (jnp.sin(x), 2 * jnp.cos(x))) + + def test_invariance(self): + @jax.custom_vjp + def f(x): + return jnp.cos(2 * x) / 2. + def f_fwd(x): + return (f(x), x) + def f_rev(x, g): + return (g * 3,) + f.defvjp(f_fwd, f_rev) + def f2(x): + y, _ = api.value_and_grad(f)(x) + return y + def f3(x): + y, _ = api.value_and_grad(f2)(x) + return y + x = 1. + self.assertAllClose(f(x), f2(x), check_dtypes=False) + self.assertAllClose(f(x), f3(x), check_dtypes=False) + self.assertAllClose(api.grad(f)(x), api.grad(f2)(x), + check_dtypes=False) + self.assertAllClose(api.grad(f)(x), api.grad(f3)(x), + check_dtypes=False) + + def test_python_control_flow(self): + @jax.custom_vjp + def f(x): + if x > 0: + return jnp.sin(x) + else: + return jnp.cos(x) + def f_fwd(x): + if x > 0: + return f(x), x + else: + return f(x), x + def f_rev(x, g): + if x > 0: + return (2 * g,) + else: + return (3 * g,) + f.defvjp(f_fwd, f_rev) + x = 2. + self.assertAllClose(f(x), jnp.sin(x)) + self.assertAllClose(f(-x), jnp.cos(-x)) + self.assertAllClose(api.value_and_grad(f)(x), (jnp.sin(x), 2.), + check_dtypes=False) + self.assertAllClose(api.value_and_grad(f)(-x), (jnp.cos(-x), 3.), + check_dtypes=False) + + def test_vmap(self): + @jax.custom_vjp + def f(x): + assert jnp.ndim(x) == 0 + return jnp.sin(x) + def f_fwd(x): + assert jnp.ndim(x) == 0 + return f(x), jnp.cos(x) + def f_rev(cos_x, g): + return (2 * cos_x * g,) + f.defvjp(f_fwd, f_rev) + + x = jnp.arange(3.) + xx = jnp.arange(6.).reshape(2, 3) + + # vmap of f + self.assertAllClose(api.vmap(f)(x), jnp.sin(x)) + self.assertAllClose(api.vmap(api.vmap(f))(xx), jnp.sin(xx)) + + # vmap of grad of f + self.assertAllClose(api.vmap(api.grad(f))(x), 2 * jnp.cos(x)) + self.assertAllClose(api.vmap(api.value_and_grad(f))(x), + (jnp.sin(x), 2 * jnp.cos(x))) + self.assertAllClose(api.vmap(api.vmap(api.grad(f)))(xx), 2 * jnp.cos(xx)) + self.assertAllClose(api.vmap(api.vmap(api.value_and_grad(f)))(xx), + (jnp.sin(xx), 2 * jnp.cos(xx))) + + # grad of vmap of f + self.assertAllClose(api.grad(lambda x: api.vmap(f)(x).sum())(x), + 2 * jnp.cos(x)) + self.assertAllClose(api.grad(lambda x: api.vmap(api.vmap(f))(x).sum())(xx), + 2 * jnp.cos(xx)) + + # vmap of grad of vmap of f + self.assertAllClose(api.vmap(api.grad(lambda x: api.vmap(f)(x).sum()))(xx), + 2 * jnp.cos(xx)) + + def test_jit(self): + @jax.custom_vjp + def f(x): + return jnp.sin(x) + def f_fwd(x): + return f(x), jnp.cos(x) + def f_rev(cos_x, g): + return (2 * cos_x * g,) + f.defvjp(f_fwd, f_rev) + + x = 3. 
+ + # jit + self.assertAllClose(api.jit(f)(x), jnp.sin(x)) + self.assertAllClose(api.jit(api.jit(f))(x), jnp.sin(x)) + + # jit of grad + self.assertAllClose(api.jit(api.grad(f))(x), 2 * jnp.cos(x), + check_dtypes=False) + + # grad of jit + self.assertAllClose(api.grad(api.jit(f))(x), 2 * jnp.cos(x), + check_dtypes=False) + + def test_pytrees(self): + @jax.custom_vjp + def f(x): + return {'b': jnp.sin(x['a'])} + def f_fwd(x): + return f(x), {'r': jnp.cos(x['a'])} + def f_bwd(res, g): + cos_x = res['r'] + return ({'a': 2 * cos_x * g['b']},) + f.defvjp(f_fwd, f_bwd) + x = {'a': 3.} + self.assertAllClose(f(x)['b'], jnp.sin(x['a'])) + self.assertAllClose(api.grad(lambda x: f(x)['b'])(x), + {'a': 2 * jnp.cos(x['a'])}) + + def test_jvp_error(self): + @jax.custom_vjp + def f(x): + return jnp.sin(x) + def f_fwd(x): + return f(x), jnp.cos(x) + def f_rev(cos_x, g): + return (2 * cos_x * g,) + f.defvjp(f_fwd, f_rev) + + self.assertRaisesRegex( + TypeError, + r"can't apply forward-mode autodiff \(jvp\) to a custom_vjp function.", + lambda: api.jvp(f, (3.,), (1.,))) + self.assertRaisesRegex( + TypeError, + r"can't apply forward-mode autodiff \(jvp\) to a custom_vjp function.", + lambda: api.jvp(api.vmap(f), (jnp.arange(3.),), (jnp.ones(3),))) + self.assertRaisesRegex( + TypeError, + r"can't apply forward-mode autodiff \(jvp\) to a custom_vjp function.", + lambda: api.jvp(jit(f), (3.,), (1.,))) + + def test_kwargs(self): + # from https://github.com/jax-ml/jax/issues/1938 + @jax.custom_vjp + def my_fun(x, y, c=1.): + return c * (x + y) + my_fun.defvjp(lambda x, y, c=1.: (my_fun(c, y, c), None), + lambda _, g: (g, g, g)) + f = lambda x, y: jnp.square(my_fun(x, y, c=2.)).sum() + f(10., 5.) # doesn't crash + api.grad(f)(10., 5.) # doesn't crash + + def test_initial_style(self): + @jax.custom_vjp + def f(x): + return jnp.sin(x) + def f_fwd(x): + return f(x), jnp.cos(x) + def f_rev(cos_x, g): + return (2 * cos_x * g,) + f.defvjp(f_fwd, f_rev) + + def foo(x): + out, _ = lax.scan(lambda c, _: (f(c), None), x, None, length=1) + return out + + ans = api.grad(foo)(3.) + expected = 2. * jnp.cos(3.) + self.assertAllClose(ans, expected, check_dtypes=False) + + ans = api.grad(api.grad(foo))(3.) + expected = -2. * jnp.sin(3.) + self.assertAllClose(ans, expected) + + def test_initial_style_vmap(self): + @jax.custom_vjp + def f(x): + assert jnp.ndim(x) == 0 + return 3 * x + def f_fwd(x): + return f(x), jnp.cos(x) + def f_rev(cos_x, g): + return (2 * cos_x * g,) + f.defvjp(f_fwd, f_rev) + + def foo(x): + out, _ = lax.scan(lambda c, _: (f(c), None), x, None, length=1) + return out + + ans = api.vmap(foo)(jnp.arange(3.)) + expected = 3. * jnp.arange(3.) + self.assertAllClose(ans, expected, check_dtypes=False) + + ans = api.grad(lambda x: api.vmap(foo)(x).sum())(jnp.arange(3.)) + expected = 2. * jnp.cos(jnp.arange(3.)) + self.assertAllClose(ans, expected, check_dtypes=False) + + def test_nondiff_argnums(self): + @partial(jax.custom_vjp, nondiff_argnums=(0,)) + def app(f, x): + return f(x) + def app_fwd(f, x): + return app(f, x), jnp.cos(x) + def app_rev(f, cos_x, g): + return (cos_x * g,) + app.defvjp(app_fwd, app_rev) + + ans = app(lambda x: 2 * x, 1) + expected = 2 + self.assertAllClose(ans, expected, check_dtypes=False) + + ans = api.value_and_grad(lambda x: app(lambda y: 2 * y, x))(1.) 
+ expected = (2., jnp.cos(1.)) + self.assertAllClose(ans, expected, check_dtypes=False) + + def test_nondiff_argnames(self): + @partial(jax.custom_vjp, nondiff_argnames=('f',)) + def app(f, x): + return f(x) + def app_fwd(f, x): + return app(f, x), jnp.cos(x) + def app_rev(f, cos_x, g): + return (cos_x * g,) + app.defvjp(app_fwd, app_rev) + + ans = app(lambda x: 2 * x, 1) + expected = 2 + self.assertAllClose(ans, expected, check_dtypes=False) + + ans = api.value_and_grad(lambda x: app(lambda y: 2 * y, x))(1.) + expected = (2., jnp.cos(1.)) + self.assertAllClose(ans, expected, check_dtypes=False) + + def test_nondiff_argnums_argnames(self): + @partial(jax.custom_vjp, nondiff_argnums=(0,), nondiff_argnames=('g',)) + def app(f, g, x): + return f(x) + g(x) + def app_fwd(f, g, x): + return app(f, g, x), jnp.cos(x) + def app_rev(f, g, cos_x, v): + return (cos_x * v,) + app.defvjp(app_fwd, app_rev) + + f = lambda x: 2 * x + g = lambda x: 2 * x + ans = app(f, g, 1) + expected = 4 + self.assertAllClose(ans, expected, check_dtypes=False) + + ans = api.value_and_grad(lambda x: app(f, g, x))(1.) + expected = (4., jnp.cos(1.)) + self.assertAllClose(ans, expected, check_dtypes=False) + + def test_closed_over_jit_tracer(self): + # See the comment in CustomJVPTest.test_nondiff_arg_jit_tracer. + raise unittest.SkipTest("behavior no longer supported") + + # This test is similar to test_nondiff_arg_tracer except it uses lexical + # closure rather than the nondiff_argnums mechanism. We decided to disallow + # tracers in nondiff_argnums to greatly simplify bookkeeping while still + # supporting the cases for which it is necessary. + def outer(x): + @jax.custom_vjp + def f(y): + return x * y + def f_fwd(y): + return f(y), jnp.cos(y) + def f_rev(cos_y, g): + return (cos_y * g,) + f.defvjp(f_fwd, f_rev) + return f + + @jit + def g(x, y): + return outer(x)(y) + + ans = g(2, 3.) + expected = 6. + self.assertAllClose(ans, expected, check_dtypes=False) + + ans = api.grad(g, 1)(2., 3.) + expected = jnp.cos(3.) + self.assertAllClose(ans, expected, check_dtypes=False) + + def test_closed_over_vmap_tracer(self): + def outer(x): + @jax.custom_vjp + def f(y): + return x * y + def f_fwd(y): + return f(y), jnp.cos(y) + def f_rev(cos_y, g): + return (cos_y * g,) + f.defvjp(f_fwd, f_rev) + return f + + @api.vmap + def g(x): + return outer(x)(3.) + + ans = g(np.arange(3.)) + expected = np.arange(3.) * 3 + self.assertAllClose(ans, expected, check_dtypes=False) + + def test_closed_over_tracer3(self): + def outer(x): + @jax.custom_vjp + def f(y): + return x * y + def f_fwd(y): + return f(y), (x, jnp.cos(y)) + def f_rev(res, g): + x, cos_y = res + return (cos_y * g * x,) + f.defvjp(f_fwd, f_rev) + return api.grad(f) + + @api.vmap + def g(x): + return outer(x)(3.) + + ans = g(np.arange(3.)) + expected = np.cos(3.) * np.arange(3.) + self.assertAllClose(ans, expected, check_dtypes=False) + + def test_nondiff_arg_tracer_error(self): + # This is similar to the old (now skipped) test_nondiff_arg_tracer, except + # we're testing for the error message that usage pattern now raises. + + @partial(jax.custom_vjp, nondiff_argnums=(0,)) + def f(x, y): + return x * y + def f_fwd(x, y): + return f(x, y), jnp.cos(y) + def f_rev(x, cos_y, g): + return (cos_y * g,) + f.defvjp(f_fwd, f_rev) + + @jit + def g(x, y): + return f(x, y) + + with self.assertRaisesRegex(UnexpectedTracerError, "custom_vjp"): + _ = g(2, 3.) + with self.assertRaisesRegex(UnexpectedTracerError, "custom_vjp"): + _ = api.grad(g, 1)(2., 3.) 
+ + def test_vmap_axes(self): + raise unittest.SkipTest("TODO") # TODO(mattjj): write test + + def test_pmap(self): + raise unittest.SkipTest("TODO") # TODO(mattjj): write test + + def test_missing_vjp_rule_error(self): + @jax.custom_vjp + def foo(x): + return x ** 2 + + self.assertRaisesRegex( + AttributeError, + r"No VJP defined for custom_vjp function foo using defvjp.", + lambda: foo(2)) + self.assertRaisesRegex( + AttributeError, + r"No VJP defined for custom_vjp function foo using defvjp.", + lambda: api.grad(foo)(2.)) + + def test_vjp_rule_inconsistent_pytree_structures_error(self): + @jax.custom_vjp + def f(x): + return x + + def foo_fwd(x): + return x, None + + def foo_bwd(_, g): + return (g, g) + + f.defvjp(foo_fwd, foo_bwd) + + f(2) # doesn't crash + self.assertRaisesRegex( + TypeError, + re.escape( + "Custom VJP bwd rule must produce an output with the same container " + "(pytree) structure as the args tuple of the primal function, " + "and in particular must produce a tuple of length equal to the " + "number of arguments to the primal function, but got bwd output " + "structure {} for primal input structure {}.".format( + jax.tree.structure((1, 1)), + jax.tree.structure((1,))) + ), + lambda: api.grad(f)(2.)) + + def test_vjp_bwd_returns_non_tuple_error(self): + @jax.custom_vjp + def f(x): + return x + + def foo_fwd(x): + return x, None + + def foo_bwd(_, g): + return 2. * g # Should be a tuple + + f.defvjp(foo_fwd, foo_bwd) + with self.assertRaisesRegex(TypeError, "Custom VJP bwd rule .* must produce a tuple"): + api.grad(f)(3.) + + def test_fwd_rule_primal_out_type_doesnt_match_primal_error_message(self): + # https://github.com/lucidrains/flash-attention-jax/issues/7 + + def scan_apply(f, x): + y, _ = jax.lax.scan(lambda x, _: (f(x), None), x, None, length=1) + return y + + @jax.custom_vjp + def f(x): + return x + + def f_fwd(x): + return (x, x), None + + def f_bwd(_, y_bar): + return (y_bar,) + + f.defvjp(f_fwd, f_bwd) + + self.assertRaisesRegex( + TypeError, + re.escape( + "Custom VJP fwd rule f_fwd for function f must produce a pair " + "(list or tuple of length two) where the first element represents " + "the primal output (equal to the output of the " + "custom_vjp-decorated function f) and the second element " + "represents residuals (i.e. values stored from the forward " + "pass for use on the backward pass), but instead the fwd rule " + "output's first element had container/pytree structure:\n" + " (float32[], float32[])\n" + "while the custom_vjp-decorated function f had output " + "container/pytree structure:\n" + " float32[]." + ), + lambda: jax.grad(lambda x: scan_apply(f, x))(jnp.float32(1.))) + + def f_fwd2(x): + return jnp.zeros((3, *x.shape), x.dtype), None + + def f_bwd2(_, y_bar): + return (y_bar,) + + f.defvjp(f_fwd2, f_bwd2) + + self.assertRaisesRegex( + TypeError, + re.escape( + "Custom VJP fwd rule f_fwd2 for function f must produce a pair " + "(list or tuple of length two) where the first element represents " + "the primal output (equal to the output of the " + "custom_vjp-decorated function f) and the second element " + "represents residuals (i.e. 
values stored from the forward " + "pass for use on the backward pass), but instead the fwd rule " + "output's first element had shapes/dtypes of:\n" + " float32[3]\n" + "while the custom_vjp-decorated function f had output " + "shapes/dtypes of:\n" + " float32[]" + ), + lambda: jax.grad(lambda x: scan_apply(f, x))(jnp.float32(1.))) + + def test_issue2511(self): + arr = jnp.ones((5, 2, 2)) + foo = lambda x: api.vmap(jnp.linalg.det, (0,))(x) + api.jit(foo)(arr) # doesn't crash + + def test_lowering_out_of_traces(self): + # https://github.com/jax-ml/jax/issues/2578 + + class F(collections.namedtuple("F", ["a"])): + def __call__(self, x): + return jax.nn.relu(self.a) * x + + @jax.jit + def g(f, x): + return f(x) + + jax.grad(g, argnums=(1,))(F(2.0), 0.) # doesn't crash + + def test_clip_gradient(self): + # https://github.com/jax-ml/jax/issues/2784 + @jax.custom_vjp + def _clip_gradient(lo, hi, x): + return x # identity function when not differentiating + + def clip_gradient_fwd(lo, hi, x): + return x, (lo, hi,) + + def clip_gradient_bwd(res, g): + lo, hi = res + return (None, None, jnp.clip(g, lo, hi),) + + _clip_gradient.defvjp(clip_gradient_fwd, clip_gradient_bwd) + + def clip_gradient(x): + lo = -0.1 + hi = x + 0.1 + return _clip_gradient(lo, hi, x) + + g = jax.grad(clip_gradient)(0.1) # doesn't crash + self.assertAllClose(g, jnp.array(0.2)) + + def test_nestable_vjp(self): + # Verify that https://github.com/jax-ml/jax/issues/3667 is resolved. + def f(x): + return x ** 2 + + @jax.custom_vjp + def g(x): + return f(x) + + def g_fwd(x): + y, f_vjp = api.vjp(f, x) + return y, f_vjp + + def g_bwd(f_vjp, y_bar): + return f_vjp(y_bar) + + g.defvjp(g_fwd, g_bwd) + + # Check that VJP can be nested in simple situations. For this to pass, + # vjp has to return a PyTree. + _, g_vjp = api.vjp(g, 1.0) + y, = g_vjp(1.0) + self.assertAllClose(y, jnp.array(2.0)) + + # Check that VJP can be nested in complex situations. For this to pass, + # vjp can't treat the closed-over tracer x as a static argument. + @jit + def z(x): + _, g_vjp = api.vjp(g, x) + return g_vjp + y, = z(1.0)(3.0) + self.assertAllClose(y, jnp.array(6.0)) + + def test_initial_style_vmap_2(self): + # https://github.com/jax-ml/jax/issues/4173 + x = jnp.ones((10, 3)) + + # Create the custom function + @jax.custom_vjp + def custom_fun(x): + return x.sum() + + def forward(x): + return x.sum(), (jnp.ones_like(x),) + + def backward(res, g): + return g * res[0], + + custom_fun.defvjp(forward, backward) + + def train_fun(x): + + def summed_fun(x): + return api.vmap(custom_fun)(x).sum() + + return api.grad(summed_fun)(x) + + def scan_body(carry, inputs): + x = carry + return carry, train_fun(x) + + scan_range = jnp.arange(4) + lax.scan(scan_body, x, scan_range) # don't crash + + def test_initial_style_vmap_3(self): + # This is like test_initial_style_vmap except the primal function closes + # over an array constant. + y = jnp.arange(1., 4.) + + @jax.custom_vjp + def f(x): + assert jnp.ndim(x) == 0 + return 3 * x * jnp.sum(y) + def f_fwd(x): + return f(x), jnp.cos(x) + def f_rev(cos_x, g): + return (2 * cos_x * g,) + f.defvjp(f_fwd, f_rev) + + def foo(x): + out, _ = lax.scan(lambda c, _: (f(c), None), x, None, length=1) + return out + + ans = api.vmap(foo)(jnp.arange(3.)) + expected = 3. * jnp.arange(3.) * 6 + self.assertAllClose(ans, expected, check_dtypes=False) + + ans = api.grad(lambda x: api.vmap(foo)(x).sum())(jnp.arange(3.)) + expected = 2. 
* jnp.cos(jnp.arange(3.)) + self.assertAllClose(ans, expected, check_dtypes=False) + + def test_initial_style_vmap_with_collective(self): + + @jax.custom_vjp + def f(x): + return lax.psum(x, 'foo') + + def f_fwd(x): + return lax.psum(x, 'foo'), None + + def f_bwd(res, dx): + return dx + f.defvjp(f_fwd, f_bwd) + + def g(x): + jaxpr = api.make_jaxpr(f)(x) + return core.eval_jaxpr(jaxpr.jaxpr, [], x)[0] + + out = api.vmap(lambda _, x: g(x), axis_name='foo', in_axes=(0, None), + out_axes=None)(jnp.arange(4.), 2.) + self.assertAllClose(out, 8.) + + def test_bwd_closes_over_tracer(self): + def f(y): + @jax.custom_vjp + def f(x): + return 2. * jnp.sin(x) + + def fwd(x): + return f(x), () + + def bwd(_, g): + return (2. * jnp.cos(y) * g,) # capture! + + f.defvjp(fwd, bwd) + + return jax.grad(f)(1.) + + ans = jax.jit(f)(2.) + self.assertAllClose(ans, 2. * jnp.cos(2.)) + + ans = jax.vmap(f)(jnp.arange(3.)) + self.assertAllClose(ans, 2. * jnp.cos(jnp.arange(3.))) + + ans = jax.jit(jax.vmap(f))(jnp.arange(3.)) + self.assertAllClose(ans, 2. * jnp.cos(jnp.arange(3.))) + + ans = jax.vmap(jax.jit(f))(jnp.arange(3.)) + self.assertAllClose(ans, 2. * jnp.cos(jnp.arange(3.))) + + ans = jax.grad(f)(4.) + self.assertAllClose(ans, -2. * jnp.sin(4.)) + + def test_fwd_closes_over_tracer(self): + def f(y): + @jax.custom_vjp + def f(x): + return 2. * jnp.sin(x) + + def fwd(x): + return f(x), y + + def bwd(y, g): + return (2. * jnp.cos(y) * g,) # capture! + + f.defvjp(fwd, bwd) + + return jax.grad(f)(1.) + + ans = jax.jit(f)(2.) + self.assertAllClose(ans, 2. * jnp.cos(2.)) + + ans = jax.vmap(f)(jnp.arange(3.)) + self.assertAllClose(ans, 2. * jnp.cos(jnp.arange(3.))) + + ans = jax.jit(jax.vmap(f))(jnp.arange(3.)) + self.assertAllClose(ans, 2. * jnp.cos(jnp.arange(3.))) + + ans = jax.vmap(jax.jit(f))(jnp.arange(3.)) + self.assertAllClose(ans, 2. * jnp.cos(jnp.arange(3.))) + + ans = jax.grad(f)(4.) + self.assertAllClose(ans, -2. * jnp.sin(4.)) + + def test_float0(self): + @jax.custom_vjp + def f(x, _): + return x + def f_fwd(x, _): + # we need a defined (non-float0) tangent to trigger the rule + return x, (2., 1) + def f_rev(*_): + return (2., 1) + f.defvjp(f_fwd, f_rev) + + x = 2. + y = 3 + self.assertEqual(api.grad(f, allow_int=True, argnums=(0, 1))(x, y), + (2., np.zeros(shape=(), dtype=float0))) + + def test_float0_initial_style(self): + @jax.custom_vjp + def f(x): + return x + def f_fwd(x): + return x, (2., x) + def f_rev(*_): + return ((2., jnp.zeros(shape=(), dtype=float0)),) + f.defvjp(f_fwd, f_rev) + + def foo(x, y): + out, _ = lax.scan(lambda c, _: (f(c), None), (x, y), None, length=1) + return out[0] + + x = 2. + y = 3 + self.assertEqual(api.grad(foo, allow_int=True, argnums=(0, 1))(x, y), + (2., np.zeros(shape=(), dtype=float0))) + + def test_remat(self): + @jax.custom_vjp + def f(x): + return jnp.sin(x) + def f_fwd(x): + return f(x), jnp.cos(x) + def f_rev(cos_x, g): + return (2 * cos_x * g,) + f.defvjp(f_fwd, f_rev) + + @jax.remat + def g(x): + return f(f(x)) + + ans = g(2.) + expected = np.sin(np.sin(2.)) + self.assertAllClose(ans, expected, check_dtypes=False) + + ans = api.grad(g)(2.) + expected = 4. * api.grad(lambda x: jnp.sin(jnp.sin(x)))(2.) + self.assertAllClose(ans, expected, check_dtypes=False) + + def test_remat_higher_order(self): + @jax.custom_vjp + def f(x): + return jnp.sin(x) + def f_fwd(x): + return f(x), jnp.cos(x) + def f_rev(cos_x, g): + return (2 * cos_x * g,) + f.defvjp(f_fwd, f_rev) + + def g(x): + return f(f(x)) + + ans = api.grad(api.grad(jax.remat(g)))(2.) 
+ expected = api.grad(api.grad(g))(2.) + self.assertAllClose(ans, expected, check_dtypes=False) + + ans = api.grad(jax.remat(api.grad(g)))(2.) + expected = api.grad(api.grad(g))(2.) + self.assertAllClose(ans, expected, check_dtypes=False) + + ans = api.grad(api.grad(api.grad(jax.remat(g))))(2.) + expected = api.grad(api.grad(api.grad(g)))(2.) + self.assertAllClose(ans, expected, check_dtypes=False) + + def test_bwd_nones(self): + @jax.custom_vjp + def f(x, y): + return x * jnp.sin(y) + def f_fwd(x, y): + return f(x, y), jnp.cos(y) + def f_rev(cos, g): + return (None, 2 * cos * g) + f.defvjp(f_fwd, f_rev) + + ans = api.grad(lambda x: f(x, x))(3.) + expected = 2 * jnp.cos(3.) + self.assertAllClose(ans, expected, check_dtypes=False) + + def test_bwd_nones_vmap(self): + @jax.custom_vjp + def f(x, y): + return x * jnp.sin(y) + def f_fwd(x, y): + return f(x, y), jnp.cos(y) + def f_rev(cos, g): + return (None, 2 * cos * g) + f.defvjp(f_fwd, f_rev) + + ans = api.grad(lambda x: api.vmap(f)(x, x).sum())(jnp.arange(3.)) + expected = 2 * jnp.cos(jnp.arange(3.)) + self.assertAllClose(ans, expected, check_dtypes=False) + + def test_bwd_nones_pytree(self): + @jax.custom_vjp + def f(xs, y): + x1, x2 = xs + return x1 * x2 * jnp.sin(y) + def f_fwd(xs, y): + return f(xs, y), jnp.cos(y) + def f_rev(cos, g): + return (None, 2 * cos * g) + f.defvjp(f_fwd, f_rev) + + ans = api.grad(lambda x: f((x, x), x))(3.) + expected = 2 * jnp.cos(3.) + self.assertAllClose(ans, expected, check_dtypes=False) + + def test_custom_vjp_closure_4521(self): + # https://github.com/jax-ml/jax/issues/4521 + @jax.custom_vjp + def g(x, y): + return None + def g_fwd(x, y): + return None, y + def g_bwd(residuals, z_bar): + assert False + + g.defvjp(g_fwd, g_bwd) + + def f(xs, y): + v_g = api.vmap(g, in_axes=(0, None), out_axes=None) + v_g(xs, y) + + def scan_body(xs, _): + y = jnp.zeros(1) + _, vjp_f = api.vjp(f, xs, y) + vjp_f(None) + return xs, None + + lax.scan(scan_body, jnp.ones(5), None, 100) # doesn't crash + + def test_float0_bwd_none(self): + @jax.custom_vjp + def f(i, x): + return jnp.sin(x) + def f_fwd(i, x): + return f(i, x), jnp.cos(x) + def f_rev(cos_x, g): + return (None, 2 * cos_x * g) + f.defvjp(f_fwd, f_rev) + + ans = api.grad(f, 1)(jnp.array([1, 2]), 3.) # doesn't crash + expected = 2 * jnp.cos(3.) 
+ self.assertAllClose(ans, expected, check_dtypes=False) + + def test_custom_gradient(self): + @jax.custom_gradient + def f(x): + return x ** 2, lambda g: (g * x,) + + self.assertAllClose(f(3.), 9., check_dtypes=False) + self.assertAllClose(api.grad(f)(3.), 3., check_dtypes=False) + self.assertAllClose(api.grad(api.grad(f))(3.), 1., check_dtypes=False) + + def test_custom_gradient_2(self): + @jax.custom_gradient + def f(x, y): + return x * y, lambda g: (y, x) + + self.assertAllClose(f(3., 4.), 12., check_dtypes=False) + self.assertAllClose(api.grad(f, argnums=(0, 1))(3., 4.), (4., 3.), + check_dtypes=False) + + def test_custom_gradient_3(self): + @jax.custom_gradient + def f(x): + vjp = lambda g: (jnp.cos(x) * jnp.arange(3., 6.),) + return jnp.sum(jnp.sin(x)), vjp + + self.assertAllClose(f(jnp.arange(3)), jnp.sum(jnp.sin(jnp.arange(3.))), + check_dtypes=False) + self.assertAllClose( + api.grad(f)(jnp.arange(3.)), + api.grad(lambda x: jnp.sum(jnp.sin(x)))(jnp.arange(3.)) * jnp.arange(3., 6.), + check_dtypes=False) + + def test_custom_gradient_can_return_singleton_value_in_vjp(self): + @jax.custom_gradient + def f(x): + return x ** 2, lambda g: g * x + + self.assertAllClose(f(3.), 9., check_dtypes=False) + self.assertAllClose(api.grad(f)(3.), 3., check_dtypes=False) + self.assertAllClose(api.grad(api.grad(f))(3.), 1., check_dtypes=False) + + def test_closure_convert(self): + def cos_after(fn, x): + converted_fn, aux_args = jax.closure_convert(fn, x) + self.assertLessEqual(len(aux_args), 1) + return _cos_after(converted_fn, x, *aux_args) + + @partial(jax.custom_vjp, nondiff_argnums=(0,)) + def _cos_after(fn, x, *args): + return jnp.cos(fn(x, *args)) + + def fwd(fn, x, *args): + y = _cos_after(fn, x, *args) + return y, (x, args) + + def rev(fn, res, g): + x, args = res + x_bar = 17. * x + args_bars = [42. * a for a in args] + return (x_bar, *args_bars) + + _cos_after.defvjp(fwd, rev) + + def dist(c, x): + return jnp.sum((x - c) ** 2.) + + def solve(c, x): + def closure(x): + return dist(c, x) + return cos_after(closure, x) + + c, x = 2. * jnp.ones(2), jnp.ones(2) + expected = jnp.cos(dist(c, x)) + self.assertAllClose(solve(c, x), expected, check_dtypes=False) + g_c, g_x = api.grad(solve, argnums=(0, 1))(c, x) + self.assertAllClose(g_c, 42. * c, check_dtypes=False) + self.assertAllClose(g_x, 17. * x, check_dtypes=False) + + def test_closure_convert_mixed_consts(self): + # Like test_closure_convert, but close over values that + # participate in AD as well as values that do not. + # See https://github.com/jax-ml/jax/issues/6415 + + def cos_after(fn, x): + converted_fn, aux_args = jax.closure_convert(fn, x) + self.assertLessEqual(len(aux_args), 1) + return _cos_after(converted_fn, x, *aux_args) + + @partial(jax.custom_vjp, nondiff_argnums=(0,)) + def _cos_after(fn, x, *args): + return jnp.cos(fn(x, *args)) + + def fwd(fn, x, *args): + y = _cos_after(fn, x, *args) + return y, (x, args) + + def rev(fn, res, g): + x, args = res + x_bar = 17. * x + args_bars = [42. * a for a in args] + return (x_bar, *args_bars) + + _cos_after.defvjp(fwd, rev) + + def dist(c, s, x): + return jnp.sum(s * (x - c) ** 2.) + + def solve(c, s, x): + def closure(x): + return dist(c, s, x) + return cos_after(closure, x) + + c, s, x = 2. * jnp.ones(2), 3. * jnp.ones(2), jnp.ones(2) + expected = jnp.cos(dist(c, s, x)) + self.assertAllClose(solve(c, s, x), expected, check_dtypes=False) + g_c, g_x = api.grad(solve, argnums=(0, 2))(c, s, x) + self.assertAllClose(g_c, 42. * c, check_dtypes=False) + self.assertAllClose(g_x, 17. 
* x, check_dtypes=False) + + def test_closure_convert_pytree_mismatch(self): + # See https://github.com/jax-ml/jax/issues/23588 + def f(x, z): + return z * x + + x, z = 2.0, 3.0 + _, vjp = api.vjp(f, x, z) + vjp_pure, vjp_aux_args = jax.closure_convert(vjp, x) + vjp_pure(x, *vjp_aux_args) + with self.assertRaisesRegex( + TypeError, "The inputs to the closure produced by closure_convert"): + vjp_pure(x, vjp_aux_args) + + def test_float0_cotangents_automatically_handled(self): + @jax.custom_vjp + def f(x, y): + return x + + def f_fwd(x, y): + return x, None + + def f_bwd(_, zbar): + return (0., 1) + + f.defvjp(f_fwd, f_bwd) + + jax.jit(lambda x: jax.vjp(f, 0., x)[1](1.))(1) # doesn't crash + + def test_custom_vjp_scan_batching_edge_case(self): + # https://github.com/jax-ml/jax/issues/5832 + @jax.custom_vjp + def mul(x, coeff): return x * coeff + def mul_fwd(x, coeff): return mul(x, coeff), (x, coeff) + def mul_bwd(res, g): + x, coeff = res + g_x = g * coeff + g_coeff = (x * g).sum() + return g_x, g_coeff + mul.defvjp(mul_fwd, mul_bwd) + + def scan_over_mul(x, coeff): + def f_(x, t): + return mul(x, coeff), None + y, _ = jax.lax.scan(f_, x, jnp.arange(3)) + return y + + key = jax.random.key(0) + key1, key2 = jax.random.split(key, 2) + x_batch = jax.random.normal(key1, (3, 2)) + covector_batch = jax.random.normal(key2, (3, 2)) + coeff = jnp.array(1., dtype=x_batch.dtype) + + batched_scan_over_mul = jax.vmap(scan_over_mul, in_axes=(0, None), out_axes=0) + res, vjp_fun = jax.vjp(batched_scan_over_mul, x_batch, coeff) + vjp_fun(covector_batch) # doesn't crash + + jtu.check_grads(batched_scan_over_mul, (x_batch, coeff), order=2, + modes=['rev']) + + def test_closure_with_vmap2(self): + # https://github.com/jax-ml/jax/issues/8783 + def h(z): + def f(x): + @jax.custom_vjp + def g(y): + return x * y + + def g_fwd(y): + return x * y, (x, x * y, y) + def g_rev(res, w_bar): + x, *_ = res + return (x * w_bar,) + g.defvjp(g_fwd, g_rev) + + return g(z) + + return jax.vmap(f)(jnp.arange(3., dtype='float32')).sum() + + jtu.check_grads(h, (jnp.float32(3.14),), order=1, modes=['rev']) + + def test_pytrees_not_required_to_contain_nones(self): + class A(list): + pass + + def unflatten(_, children): + assert children[0] is not None + return A(children) + + tree_util.register_pytree_node(A, lambda x: (x, None), unflatten) + + @jax.custom_vjp + def f(x): + return x[0] + def f_fwd(x): + return x[0], None + def f_bwd(_, g): + return A([g]), + f.defvjp(f_fwd, f_bwd) + + jax.grad(f)(A([1.])) # doesn't crash + + def test_vmap_vjp_called_twice(self): + # https://github.com/jax-ml/jax/pull/14728 + @jax.custom_vjp + def f(x): + return x + f.defvjp(lambda x: (x, None), lambda _, y_bar: (y_bar,)) + + _, f_vjp = jax.vjp(jax.vmap(f), jnp.array([3.])) + f_vjp(jnp.array([3.])) + f_vjp(jnp.array([3.])) # doesn't crash + + def test_symbolic_zero_custom_vjp_basic(self): + ZERO = jax.custom_derivatives.SymbolicZero + + @jax.custom_vjp + def f(x, y, z): + return x, x + + def fwd(x, y, z): + self.assertIsInstance(x, jax.custom_derivatives.CustomVJPPrimal) + self.assertIsInstance(y, jax.custom_derivatives.CustomVJPPrimal) + self.assertIsInstance(z, jax.custom_derivatives.CustomVJPPrimal) + self.assertTrue(x.perturbed) + self.assertFalse(y.perturbed) + self.assertFalse(z.perturbed) + return (x.value, x.value), None + + def fwd_all(x, y, z): + self.assertIsInstance(x, jax.custom_derivatives.CustomVJPPrimal) + self.assertIsInstance(y, jax.custom_derivatives.CustomVJPPrimal) + self.assertIsInstance(z, 
jax.custom_derivatives.CustomVJPPrimal) + self.assertTrue(x.perturbed) + self.assertTrue(y.perturbed) + self.assertTrue(z.perturbed) + return (x.value, x.value), None + + def bwd_all(_, g): + x1, x2 = g + self.assertFalse(type(x1) is ZERO) + self.assertFalse(type(x2) is ZERO) + return x1, x1, x2 + + def bwd_fst(_, g): + x1, x2 = g + self.assertFalse(type(x1) is ZERO) + self.assertIs(type(x2), ZERO) + return x1, x1, x2 + + def bwd_snd(_, g): + x1, x2 = g + self.assertIs(type(x1), ZERO) + self.assertFalse(type(x2) is ZERO) + return x1, x1, x2 + + x, y, z = 4., 5., 6. + i = np.array(7, np.int32) + zero = np.array(0.) + + f.defvjp(fwd, bwd_all, symbolic_zeros=True) + h = jax.jit(f) + jax.jacrev(h)(x, y, z) + jax.jacrev(lambda x: h(x, y, z))(x) + jax.jacrev(h, argnums=(0, 1, 2), allow_int=True)(x, i, i) + + f.defvjp(fwd_all, bwd_fst, symbolic_zeros=True) + fst_f = lambda *xs: f(*xs)[0] + _, vjp = jax.vjp(fst_f, x, y, z) + _, _, gz = vjp(x) + self.assertArraysAllClose(gz, zero) + + f.defvjp(fwd_all, bwd_snd, symbolic_zeros=True) + snd_f = lambda *xs: f(*xs)[1] + _, vjp = jax.vjp(snd_f, x, y, z) + gx, gy, _ = vjp(x) + self.assertArraysAllClose(gx, zero) + self.assertArraysAllClose(gy, zero) + + f.defvjp(fwd, bwd_snd, symbolic_zeros=True) + _, vjp = jax.vjp(lambda x: snd_f(x, y, z), x) + gx, = vjp(x) + self.assertArraysAllClose(gx, zero) + + def test_symbolic_zero_custom_vjp_bwd_shape_error(self): + @jax.custom_vjp + def f(x, y, z): + return x, y, z + + def fwd(x, y, z): + return f(x.value, y.value, z.value), None + + def bwd(_, gs): + x_bar, y_bar, z_bar = gs + return y_bar, x_bar, z_bar # swapped! + + f.defvjp(fwd, bwd, symbolic_zeros=True) + + with self.assertRaisesRegex( + ValueError, + r'Consider just returning a None here'): + jax.grad(lambda x, y, z: f(x, y, z)[2].sum())( + jnp.ones(1), jnp.ones(2), jnp.ones(3)) + + @parameterized.named_parameters( + ('jit_vmap', True, True), + ('jit', True, False), + ('vmap', False, True), + ('', False, False), + ) + def test_symbolic_zero_custom_vjp(self, maybe_jit, maybe_vmap): + # below: + # * static_scalar will be static in and out + # * static_array will be static in, but dynamic out + # * dyn_scalar and dyn_array will be dynamic in and out + + ZERO = jax.custom_derivatives.SymbolicZero + + def f(static_scalar, static_array, dyn_scalar, dyn_array): + out1 = static_scalar + dyn_scalar + out2 = static_array + dyn_array + return static_scalar, static_array, out1, out2 + + def _pack(x): + return lax.broadcast(x, (1,)) + + def _unpack(x): + (x,) = x + return x + + def _vmap(fun): + def _fun(*args): + args = jax.tree.map(_pack, args) + out = jax.vmap(fun)(*args) + out = jax.tree.map(_unpack, out) + return out + return _fun + + f = jax.custom_vjp(f) + + def fwd(*args): + xs, pert = [x.value for x in args], [x.perturbed for x in args] + self.assertFalse(pert[0]) + self.assertFalse(pert[1]) + self.assertTrue(pert[2]) + self.assertTrue(pert[3]) + return f(*xs), xs + + def bwd(res, g): + static_scalar, *_ = res + t_static, t_static_arr, t_dyn_scalar, t_dyn_array = g + self.assertIs(type(t_static), ZERO) + self.assertFalse(type(t_static_arr) is ZERO) + self.assertFalse(type(t_dyn_scalar) is ZERO) + self.assertFalse(type(t_dyn_array) is ZERO) + self.assertEqual(t_static.shape, ()) + self.assertEqual(t_static_arr.shape, (2,)) + return (static_scalar + 90, + t_static_arr + 91, + t_dyn_scalar + 92, + t_dyn_array + 93) + + f.defvjp(fwd, bwd, symbolic_zeros=True) + + def g(dyn_scalar, dyn_array): + if maybe_vmap: + f_ = _vmap(f) + else: + f_ = f + outs = f_(1., 
jnp.array([2., 3.]), dyn_scalar, dyn_array) + return outs[1:] + + def run(primal_ins, cotangent_outs): + primal_outs, vjp = jax.vjp(g, *primal_ins) + cotangent_ins = vjp(cotangent_outs) + return primal_outs, cotangent_ins + + if maybe_jit: + run = jax.jit(run) + + scalar_type = jax.Array if maybe_jit or maybe_vmap else float + primal_ins = (4., jnp.array([5., 6.])) + cotangent_outs = (jnp.array([10., 11.]), 7., jnp.array([8., 9.])) + primal_outs, cotangent_ins = run(primal_ins, cotangent_outs) + + primal_out1, primal_out2, primal_out3 = primal_outs + self.assertIsInstance(primal_out1, jax.Array) + self.assertAllClose(primal_out1, jnp.array([2., 3.])) + self.assertIsInstance(primal_out2, scalar_type) + self.assertAllClose(primal_out2, 5.) + self.assertIsInstance(primal_out3, jax.Array) + self.assertAllClose(primal_out3, jnp.array([7., 9.])) + + ct_in1, ct_in2 = cotangent_ins + self.assertIsInstance(ct_in1, scalar_type) + self.assertAllClose(ct_in1, 99.) + self.assertIsInstance(ct_in2, jax.Array) + self.assertArraysAllClose(ct_in2, jnp.array([101., 102.])) + + def test_symbolic_zero_custom_vjp_vmap_output(self): + @jax.custom_vjp + def f(x, y): + return x, y + + def fwd(x, y): + self.assertTrue(x.perturbed) + self.assertFalse(y.perturbed) + return f(x.value, y.value), None + + def bwd(_, g): + _, ct_y = g + self.assertIs(type(ct_y), jax.custom_derivatives.SymbolicZero) + return g + + f.defvjp(fwd, bwd, symbolic_zeros=True) + jax.grad(lambda x, y: jax.vmap(f)(x, y)[0].sum())(jnp.ones(3), jnp.ones(3)) + + def test_symbolic_zero_custom_vjp_custom_pytree(self): + tree_values = jax.custom_derivatives.custom_vjp_primal_tree_values + + @tree_util.register_pytree_node_class + class Box: + def __init__(self_, strict, val): + if strict: + # make sure we aren't getting special arguments that should only + # come up when symbolic_zeros is True + self.assertFalse(hasattr(val, 'perturbed')) + self_.strict = strict + self_.x = val + + def tree_flatten(self_): + return [self_.x], self_.strict + + @classmethod + def tree_unflatten(cls, strict, xs): + x, = xs + return cls(strict, x) + + x, y = Box(False, jnp.array(72.)), jnp.array(73.) + + @jax.custom_vjp + def f(box, y): + return box.x * y + + def fwd0(box, y): + self.assertTrue(box.x.perturbed) + self.assertFalse(y.perturbed) + box, y = map(tree_values, [box, y]) + return f(box, y), (box, y) + + def bwd0(res, g): + box, y = res + return y * g, box.x * g + + def fwd1(box, y): + self.assertFalse(box.x.perturbed) + self.assertTrue(y.perturbed) + box, y = map(tree_values, [box, y]) + return f(box, y), (box, y) + + def bwd1(res, g): + box, y = res + return y * g, box.x * g + + f.defvjp(fwd0, bwd0, symbolic_zeros=True) + jax.grad(f, argnums=0)(x, y) + f.defvjp(fwd1, bwd1, symbolic_zeros=True) + jax.grad(f, argnums=1)(x, y) + + def fwd_strict(box, y): + return f(box, y), (box, y) + + def bwd_strict(res, g): + box, y = res + return y * g, box.x * g + + f.defvjp(fwd_strict, bwd_strict) + jax.grad(f)(x, y) + + def test_symbolic_zeros_memoization_caching(self): + # Tests multiple zero patterns for partial_eval._memoize, and also tests + # that we're okay with stores being occupied with equal values. + @jax.custom_vjp + def f(x, y): + return x * y + + def f_fwd(x, y): + return x.value, None + + def f_bwd(_, z_bar): + return z_bar, None + + f.defvjp(f_fwd, f_bwd, symbolic_zeros=True) + + f_ = core.jaxpr_as_fun(jax.make_jaxpr(f)(2., 3.)) + _ = jax.linearize(f_, 2., 3.) + _ = jax.linearize(lambda x: f_(x, 3.), 2.) # don't crash! 
+ + def test_run_rules_more_than_once(self): + # https://github.com/jax-ml/jax/issues/16614 + + @jax.custom_vjp + def f(x, y): + return x + y + + def f_fwd(x, y): + if y.perturbed: + res = None + else: + res = [] + return x.value + y.value, res + + def f_bwd(res, ct): + return ct, ct + + f.defvjp(f_fwd, f_bwd, symbolic_zeros=True) + + def body(x_y, _): + x, y = x_y + return (f(x, y), x), None + + @jax.grad + def g(x): + (out, _), _ = lax.scan(body, (x, 1.), xs=None, length=2) + return out + + g(1.) # doesn't crash + + def test_nones_representing_zeros_in_subtrees_returned_by_bwd(self): + # https://github.com/jax-ml/jax/issues/8356 + @jax.custom_vjp + def f(x): + return x[0] + + def f_fwd(x): + return f(x), None + + def f_bwd(_, z_bar): + return (z_bar, (None, None)), + + f.defvjp(f_fwd, f_bwd) + + jax.grad(f)((1.0, (2.0, 3.0))) # don't crash + + def test_pytree_nones_returned_by_bwd(self): + @jax.custom_vjp + def f(x): + return x[0] + + def f_fwd(x): + return f(x), None + + def f_bwd(_, z_bar): + return (z_bar, (None, None)), + + f.defvjp(f_fwd, f_bwd) + + jax.grad(f)((1.0, (2.0, None))) # don't crash + + def test_bwd_rule_shape_mismatch(self): + @jax.custom_vjp + def foo(x, y): + return x + + def foo_fwd(x, y): + return x, None + + def foo_bwd(_, g): + return jnp.zeros(3), jnp.zeros(3) + + foo.defvjp(foo_fwd, foo_bwd) + + with self.assertRaisesRegex( + ValueError, + r'output\[1\] the bwd rule produced an output of shape/dtype float..\[3\]'): + jax.grad(lambda x, y: foo(x, y * y).sum(), 1)(jnp.ones(3), jnp.ones(4)) + + def test_bwd_rule_shape_mismatch_disable(self): + # TODO(mattjj): remove this test when the config option is removed + @jax.custom_vjp + def foo(x, y): + return x + + def foo_fwd(x, y): + return x, None + + def foo_bwd(_, g): + return jnp.zeros(3), jnp.zeros(3) + + foo.defvjp(foo_fwd, foo_bwd) + + with config.custom_vjp_disable_shape_check(True): + jax.grad(lambda x, y: foo(x, y).sum(), 1)(jnp.ones(3), jnp.ones(4)) + + def test_bwd_rule_can_produce_list_or_tuple(self): + @jax.custom_vjp + def f(x, y): + return x * y + + def f_fwd(x, y): + return f(x, y), (x, y) + + def f_bwd(xy, g): + x, y = xy + return [g * y, x * g] # list, not tuple + + f.defvjp(f_fwd, f_bwd) + + jax.grad(f)(1., 2.) 
# don't crash + + def test_optimize_remat(self): + def fun(x): + # This array is included to make sure that we handle consts appropriately + return np.array([1.0])*x + + def fwd(x): + return np.array([2.0])*x*x/np.array([1.0]), (2 * x,) + + x = jnp.linspace(0, 5.0, 10) + fwd = custom_derivatives.optimize_remat_of_custom_vjp_fwd( + fun, api_util.debug_info("custom_vjp fun", fun, (x,), {}), + fwd, api_util.debug_info("custom_vjp fwd", fwd, (x,), {})) + + self.assertAllClose(jax.jit(fwd)(x)[0], 2*x*x) # Shouldn't hit custom DCE + self.assertAllClose(jax.jit(lambda x: fwd(x)[0])(x), x) # Should be DCEed + + def test_optimize_remat_vmap(self): + def fun(x): + return (np.array([1.0])*x)[0] + def fwd(x): + return (np.array([2.0])*x*x/np.array([1.0]))[0], (2 * x,) + x = jnp.linspace(0, 5.0, 10) + fwd = custom_derivatives.optimize_remat_of_custom_vjp_fwd( + fun, api_util.debug_info("custom_vjp fun", fun, (x,), {}), + fwd, api_util.debug_info("custom_vjp fwd", fwd, (x,), {})) + self.assertAllClose(jax.jit(jax.vmap(fwd))(x)[0], 2*x*x) + self.assertAllClose(jax.jit(lambda x: jax.vmap(fwd)(x)[0])(x), x) + + def test_optimize_remat_cond(self): + def fun(x): + return x + def fwd(x): + return x*x, (2 * x,) + + x = jnp.linspace(0, 5.0, 10) + fwd = custom_derivatives.optimize_remat_of_custom_vjp_fwd( + fun, api_util.debug_info("custom_vjp fun", fun, (x,), {}), + fwd, api_util.debug_info("custom_vjp fwd", fwd, (x,), {})) + + def g(x): + return jax.lax.cond(True, fwd, lambda x: (2.0 * x, (x,)), x) + + self.assertAllClose(jax.jit(g)(x)[0], x*x) + self.assertAllClose(jax.jit(lambda x: g(x)[0])(x), x) + + def test_optimize_remat_jvp(self): + def fun(x): + return x**2 + def fwd_(x): + return x*x, (2 * x,) + + fwd = custom_derivatives.optimize_remat_of_custom_vjp_fwd( + fun, api_util.debug_info("custom_vjp fun", fun, (3.2,), {}), + fwd_, api_util.debug_info("custom_vjp fwd", fwd_, (3.2,), {})) + calc = jax.jvp(fwd, (3.2,), (1.0,)) + expected = jax.jvp(fwd_, (3.2,), (1.0,)) + self.assertAllClose(calc, expected) + + @jax.jit + def g(x, t): + (y, r), (y_dot, r_dot) = jax.jvp(fwd, (x,), (t,)) + return y, y_dot + calc = g(3.2, 1.0) + expected = jax.jvp(fun, (3.2,), (1.0,)) + self.assertAllClose(calc, expected) + + def test_optimize_remat_gh21303(self): + @jax.custom_vjp + def f(x): + return jnp.tan(x) + + def f_fwd(x): + return jnp.sin(x), (x,) + + def f_bwd(res, g): + x, = res + cos_x = jnp.cos(x) + return (cos_x * g,) + + f.defvjp(f_fwd, f_bwd, optimize_remat=True) + + def temp(x): + out = jax.remat(f)(x) + out = out ** 2 + return out + + v, g = jax.value_and_grad(temp)(3.2) + self.assertAllClose(v, jnp.tan(3.2)**2) + + def test_optimize_remat_multiple_args(self): + def f_(x, y): + return jnp.sin(x) * y + + @jax.custom_vjp + def f(x, y): + return f_(x, y) + + def f_fwd(x, y): + return f(x, y), (jnp.cos(x), jnp.sin(x), y) + + def f_bwd(res, g): + cos_x, sin_x, y = res + return (cos_x * g * y, sin_x * g) + + f.defvjp(f_fwd, f_bwd, optimize_remat=True) + x, y = 3.2, 1.0 + self.assertAllClose(jax.grad(f)(x, y), jax.grad(f_)(x, y)) + + def test_optimize_remat_kwargs(self): + @jax.custom_vjp + def f(x, y): + return jnp.sin(x) * y + + def f_fwd(x, y, *, keyword=False): + del keyword + return f(x, y), (jnp.cos(x), jnp.sin(x), y) + + def f_bwd(res, g): + cos_x, sin_x, y = res + return (cos_x * g * y, sin_x * g) + + f.defvjp(f_fwd, f_bwd, optimize_remat=True) + x, y = 3.2, 1.0 + jax.grad(f)(x, y) # Doesn't error + + def test_optimize_remat_custom_vmap(self): + # See https://github.com/jax-ml/jax/pull/23000 + @jax.custom_vjp + 
def f(x, y): + return jnp.sin(x) * y + + @jax.custom_batching.custom_vmap + def f_fwd(x, y): + return f(x, y), (jnp.cos(x), jnp.sin(x), y) + + @f_fwd.def_vmap + def f_fwd_vmap(_, in_batched, x, y): + # Insert a new const here to test the optimize_remat batching rule. + out = np.array([2.0])*f(x, y) + out_batched = (True, (True, True, True)) + return (out, (jnp.cos(x), jnp.sin(x), y)), out_batched + + def f_bwd(res, g): + cos_x, sin_x, y = res + return (cos_x * g * y, sin_x * g) + + f.defvjp(f_fwd, f_bwd, optimize_remat=True) + x, y = jnp.linspace(0.0, 1.0, 5), jnp.linspace(2.0, 5.0, 5) + jax.jit(jax.vmap(jax.grad(f)))(x, y) # Doesn't error + + def test_dce(self): + @jax.custom_vjp + def f(x, y): + return jnp.sin(x), x + jnp.cos(y) + + def f_fwd(x, y): + return f(x, y), (jnp.cos(x), jnp.sin(y)) + + def f_bwd(res, cts): + cos_x, sin_y = res + ct_a, ct_b = cts + return 2.0 * cos_x * ct_a + 1.5 * ct_b, -0.5 * sin_y * ct_b + + f.defvjp(f_fwd, f_bwd) + + def check_jaxpr(jaxpr, used_outs, includes, excludes): + dce_jaxpr, _ = pe.dce_jaxpr(jaxpr, used_outs) + if not dce_jaxpr.eqns: + assert not includes + return + call_jaxpr = dce_jaxpr.eqns[0].params["call_jaxpr"] + for prim in includes: + assert any(eqn.primitive == prim for eqn in call_jaxpr.eqns) + for prim in excludes: + assert all(eqn.primitive != prim for eqn in call_jaxpr.eqns) + + x, y = 0.1, -1.3 + jaxpr = jax.make_jaxpr(f)(x, y).jaxpr + check_jaxpr(jaxpr, [True, True], [lax.sin_p, lax.cos_p], []) + check_jaxpr(jaxpr, [True, False], [lax.sin_p], [lax.cos_p]) + check_jaxpr(jaxpr, [False, True], [lax.cos_p], [lax.sin_p]) + check_jaxpr(jaxpr, [False, False], [], [lax.sin_p, lax.cos_p]) + + def dce_jaxpr_as_fun(jaxpr, used_outs): + jaxpr_, _ = pe.dce_jaxpr(jaxpr, used_outs) + fun = core.jaxpr_as_fun(pe.close_jaxpr(jaxpr_)) + return lambda *args: fun(*args)[0] + + f0 = dce_jaxpr_as_fun(jaxpr, [True, False]) + f1 = dce_jaxpr_as_fun(jaxpr, [False, True]) + self.assertAllClose( + api.grad(f0, argnums=(0, 1))(x, y), (2.0 * jnp.cos(x), 0.0)) + self.assertAllClose( + api.grad(f1, argnums=(0, 1))(x, y), (1.5, -0.5 * jnp.sin(y))) + + def test_resolve_kwargs_error_message(self): + @jax.custom_vjp + def f(x, y, *, z=None): + return jnp.sin(x), x + jnp.cos(y) + + def f_fwd(x, y): + self.fail("should not be executed") + + def f_bwd(res, cts): + self.fail("should not be executed") + + f.defvjp(f_fwd, f_bwd) + + with self.assertRaisesRegex( + TypeError, + r"The input arguments to the custom_vjp-decorated function f(.*)\n" + r"missing a required argument: 'y'" + ): + f(0.5) + + with self.assertRaisesRegex( + TypeError, + r"The input arguments to the custom_vjp-decorated function f(.*)\n" + "The following keyword arguments could not be resolved to positions: z" + ): + f(0.5, 0.1, z=1.0) + + def test_pretty_print(self): + @jax.custom_vjp + def f(x): + return x + 1 + + def f_fwd(x): + return f(x), () + + def f_bwd(_, g): + return g + f.defvjp(f_fwd, f_bwd) + + x = jnp.array([4.2], dtype=jnp.float32) + jaxpr = jax.make_jaxpr(f)(x) + actual = jaxpr.pretty_print(use_color=False) + expected = textwrap.dedent( + """ + { lambda ; a:f32[1]. let + b:f32[1] = custom_vjp_call[ + name=f + bwd=f_bwd + call_jaxpr={ lambda ; c:f32[1]. 
let d:f32[1] = add c 1.0:f32[] in (d,) } + fwd=f_fwd + symbolic_zeros=False + ] a + in (b,) } + """).strip() + self.assertEqual(actual, expected) + + def test_custom_lin_pretty_print(self): + @jax.custom_vjp + def f(x): + return x + 1 + + def f_fwd(x): + return f(x), () + + def f_bwd(_, g): + return g + f.defvjp(f_fwd, f_bwd) + + x = jnp.array([4.2], dtype=jnp.float32) + jaxpr = jax.make_jaxpr(lambda x: jax.jvp(f, (x,), (x,)))(x) + jaxpr, _ = pe.dce_jaxpr(jaxpr.jaxpr, [False, True]) + actual = jaxpr.pretty_print(use_color=False) + expected = textwrap.dedent( + """ + { lambda ; a:f32[1]. let + b:f32[1] = custom_lin[ + bwd=f_bwd + in_zeros=[False] + num_res=0 + symbolic_zeros=False + ] a + in (b,) } + """).strip() + self.assertEqual(actual, expected) + + +def transpose_unary(f, x_example): + def transposed(y): + x, = api.linear_transpose(f, x_example)(y) + return x + return transposed + + +# This class wraps jax.custom_transpose.custom_transpose in order to pass in a +# particular tree of output type on each call. Otherwise it forwards +# all attribute access. +class _custom_transpose: + def __init__(self, out_types, fun): + self.out_types = out_types + self.fun = jax.custom_transpose.custom_transpose(fun) + + def __getattr__(self, name): + return getattr(self.fun, name) + + def __call__(self, *args): + return self.fun(self.out_types, *args) + + +# This function is meant to be used as a decorator that delegates to +# custom_transpose but makes it easy to specify output argument types +# by example. If used directly as a decorator (i.e. not invoked with +# example arguments), assumes a scalar-valued function. +# +# TODO(frostig): remove this (and its uses) once custom_transpose offers +# an option of inferring output types. +def custom_transpose(example_out): + if isinstance(example_out, Callable): + out_type = core.get_aval(0.).to_tangent_aval() + return _custom_transpose(out_type, example_out) + return partial( + _custom_transpose, + jax.tree.map( + lambda x: core.get_aval(x).to_tangent_aval(), example_out)) + + +class CustomTransposeTest(jtu.JaxTestCase): + + def test_linear_call(self): + def f(x, y): + def fn(r, x): return x / r + def tp(r, t): return t / r + return x + jax.custom_derivatives.linear_call(fn, tp, y, x) + + def f_ref(x, y): + return x + x / y + + x = jnp.ones(2) * 6. + y = jnp.ones(2) * 3. + self.assertAllClose(f(x, y), f_ref(x, y)) + + f1 = lambda x: f(x, y) + f1_ref = lambda x: f_ref(x, y) + self.assertAllClose(transpose_unary(f1, x)(x), + transpose_unary(f1_ref, x)(x)) + + def test_linear_call_incorrect_transpose(self): + def f(x, y): + def fn(r, x): return x / r + def tp(r, t): return t / (2. * r) # nb: not the true transpose + return x + jax.custom_derivatives.linear_call(fn, tp, y, x) + + def f_ref(x, y): + return x + x / y + + x = jnp.ones(2) * 6. + y = jnp.ones(2) * 3. + self.assertAllClose(f(x, y), f_ref(x, y)) + + f1 = lambda x: f(x, y) + f1_ref = lambda x: f_ref(x, 2. * y) # nb: double the reference divisor + self.assertAllClose(transpose_unary(f1, x)(x), + transpose_unary(f1_ref, x)(x)) + + def test_linear_call_transpose_transpose_transpose(self): + def fn(r, x): return x / r + def tp(r, t): return t / (2. * r) # nb: untrue transpose + def f_(x, y): + return x + jax.custom_derivatives.linear_call(fn, tp, y, x) + + x = jnp.ones(2) * 6. + y = jnp.ones(2) * 3. 
+ f = lambda x: f_(x, y) + ft = transpose_unary(f, x) + ftt = transpose_unary(ft, x) + fttt = transpose_unary(ftt, x) + self.assertAllClose(ft(x), x + tp(y, x)) + self.assertAllClose(f(x), ftt(x)) + self.assertAllClose(ft(x), fttt(x)) + + def test_linear_call_scalar_to_vector(self): + def f(c, x): + def fn(_, x): + return [x, x] + + def tp(_, t): + t1, t2 = t + return t1 + t2 + + return jax.custom_derivatives.linear_call(fn, tp, (), c * x) + + def f_ref(c, x): + return [c * x, c * x] + + c, x = 2., 3. + t = [4., 5.] + self.assertAllClose(f(c, x), f_ref(c, x)) + self.assertAllClose(transpose_unary(partial(f, c), x)(t), + transpose_unary(partial(f_ref, c), x)(t)) + + def test_linear_call_nested(self): + # identity function with an untrue transpose of 0 + def id_(x): + def f(_, x): return x + def t(_, t): return 0. + return jax.custom_derivatives.linear_call(f, t, (), x) + + # identity function with an untrue transpose of 7, and where both + # forward and transpose have custom transpositions that should + # never end up invoked. + def f(x): + def f_(_, x): return id_(x) + def t_(_, t): return id_(7.) + return jax.custom_derivatives.linear_call(f_, t_, (), x) + + x = 5. + id_t = transpose_unary(id_, x) + id_tt = transpose_unary(id_t, x) + ft = transpose_unary(f, x) + ftt = transpose_unary(ft, x) + fttt = transpose_unary(ftt, x) + + self.assertAllClose(id_(x), x) + self.assertAllClose(id_t(x), 0.) + self.assertAllClose(id_tt(x), x) + + self.assertAllClose(f(x), x) + self.assertAllClose(ft(x), 7.) + self.assertAllClose(ftt(x), x) + self.assertAllClose(fttt(x), 7.) + + def test_linear_call_jit(self): + def f(x, y): + def fn(r, x): return x / r + def tp(r, t): return t / r + return x + jax.custom_derivatives.linear_call(fn, tp, y, x) + + x = jnp.ones(2) * 6. + y = jnp.ones(2) * 3. + self.assertAllClose(f(x, y), jax.jit(f)(x, y)) + + f1 = lambda x: f(x, y) + self.assertAllClose(transpose_unary(f1, x)(x), + jax.jit(transpose_unary(f1, x))(x)) + + def test_linear_call_type_mismatch(self): + def f(x, y): + def fn(r, x): return x / r + def tp(r, t): return None + return x + jax.custom_derivatives.linear_call(fn, tp, y, x) + + x = jnp.ones(2) * 6. + y = jnp.ones(2) * 3. + f1 = lambda x: f(x, y) + with self.assertRaisesRegex(TypeError, "transpose output pytree"): + transpose_unary(f1, x)(x) + + def test_linear_call_recursion(self): + def f(x): + def fn(_, x): return x + def tp(_, t): return f(t) + return jax.custom_derivatives.linear_call(fn, tp, None, x) + jax.jit(f)(0.1) + + def test_linear_call_grad(self): + def f(x, y): + def fn(r, x): return x / r + def tp(r, t): return t / r + return x + jax.custom_derivatives.linear_call(fn, tp, y, x) + + def f_ref(x, y): + return x + x / y + + x = jnp.array(6.) + y = jnp.array(3.) + self.assertAllClose(jax.grad(f)(x, y), jax.grad(f_ref)(x, y)) + + def test_basic(self): + def f(x, y): + @custom_transpose(jnp.ones(2)) + def fn(r, x): return x / r + @fn.def_transpose + def tp(r, t): return t / r + + return x + fn(y, x) + + def f_ref(x, y): + return x + x / y + + x = jnp.ones(2) * 6. + y = jnp.ones(2) * 3. + self.assertAllClose(f(x, y), f_ref(x, y)) + + f1 = lambda x: f(x, y) + f1_ref = lambda x: f_ref(x, y) + self.assertAllClose(transpose_unary(f1, x)(x), + transpose_unary(f1_ref, x)(x)) + + def test_incorrect_transpose(self): + def f(x, y): + @custom_transpose(jnp.ones(2)) + def fn(r, x): return x / r + @fn.def_transpose + def tp(r, t): return t / (2. 
* r) # nb: not the true transpose + + return x + fn(y, x) + + def f_ref(x, y): + return x + x / y + + x = jnp.ones(2) * 6. + y = jnp.ones(2) * 3. + self.assertAllClose(f(x, y), f_ref(x, y)) + + f1 = lambda x: f(x, y) + f1_ref = lambda x: f_ref(x, 2. * y) # nb: double the reference divisor + self.assertAllClose(transpose_unary(f1, x)(x), + transpose_unary(f1_ref, x)(x)) + + def test_transpose_transpose_transpose(self): + @custom_transpose(jnp.ones(2)) + def fn(r, x): return x / r + @custom_transpose(jnp.ones(2)) + def tp(r, t): return t / (2. * r) # nb: untrue transpose + + fn.def_transpose(tp) + tp.def_transpose(fn) + + def f_(x, y): + return x + fn(y, x) + + x = jnp.ones(2) * 6. + y = jnp.ones(2) * 3. + f = lambda x: f_(x, y) + ft = transpose_unary(f, x) + ftt = transpose_unary(ft, x) + fttt = transpose_unary(ftt, x) + self.assertAllClose(ft(x), x + tp(y, x)) + self.assertAllClose(f(x), ftt(x)) + self.assertAllClose(ft(x), fttt(x)) + + def test_scalar_to_vector(self): + def f(c, x): + @custom_transpose([0., 0.]) + def fn(_, x): + return [x, x] + + @fn.def_transpose + def tp(_, t): + t1, t2 = t + return t1 + t2 + + return fn((), c * x) + + def f_ref(c, x): + return [c * x, c * x] + + c, x = 2., 3. + t = [4., 5.] + self.assertAllClose(f(c, x), f_ref(c, x)) + self.assertAllClose(transpose_unary(partial(f, c), x)(t), + transpose_unary(partial(f_ref, c), x)(t)) + + def test_nested(self): + # identity function with an untrue transpose of 0 + def id_(x): + f = custom_transpose(lambda _, x: x) + t = custom_transpose(lambda _, t: 0.) + f.def_transpose(t) + t.def_transpose(f) + return f((), x) + + # identity function with an untrue transpose of 7, and where both + # forward and transpose have custom transpositions that should + # never end up invoked. + def f(x): + f_ = custom_transpose(lambda _, x: id_(x)) + t_ = custom_transpose(lambda _, t: id_(7.)) + f_.def_transpose(t_) + t_.def_transpose(f_) + return f_((), x) + + x = 5. + id_t = transpose_unary(id_, x) + id_tt = transpose_unary(id_t, x) + ft = transpose_unary(f, x) + ftt = transpose_unary(ft, x) + fttt = transpose_unary(ftt, x) + + self.assertAllClose(id_(x), x) + self.assertAllClose(id_t(x), 0.) + self.assertAllClose(id_tt(x), x) + + self.assertAllClose(f(x), x) + self.assertAllClose(ft(x), 7.) + self.assertAllClose(ftt(x), x) + self.assertAllClose(fttt(x), 7.) + + def test_one_degree(self): + T = lambda f: transpose_unary(f, 0.) + + @custom_transpose + def f(_, z): return 2. * z + @f.def_transpose + def ft(_, z): return 3. * z + + f = partial(f, ()) + self.assertAllClose(2., f(1.)) + self.assertAllClose(3., T(f)(1.)) + self.assertAllClose(3., T(T(f))(1.)) + self.assertAllClose(3., T(T(T(f)))(1.)) + self.assertAllClose(3., T(T(T(T(f))))(1.)) # ... + + def test_two_degrees(self): + T = lambda f: transpose_unary(f, 0.) + + @custom_transpose + def f(_, z): return 2. * z + + @f.def_transpose + @custom_transpose + def ft(_, z): return 3. * z + + @ft.def_transpose + def ftt(_, z): return 7. * z + + f = partial(f, ()) + self.assertAllClose(2., f(1.)) + self.assertAllClose(3., T(f)(1.)) + self.assertAllClose(7., T(T(f))(1.)) + self.assertAllClose(7., T(T(T(f)))(1.)) + self.assertAllClose(7., T(T(T(T(f))))(1.)) # ... + + def test_symmetric(self): + T = lambda f: transpose_unary(f, 0.) + + @custom_transpose + def f(_, z): return 2. * z + @custom_transpose + def g(_, z): return 3. 
* z + + f.def_transpose(g) + g.def_transpose(f) + + f = partial(f, ()) + self.assertAllClose(2., f(1.)) + self.assertAllClose(3., T(f)(1.)) + self.assertAllClose(2., T(T(f))(1.)) + self.assertAllClose(3., T(T(T(f)))(1.)) + self.assertAllClose(2., T(T(T(T(f))))(1.)) # ... + + def test_recursive(self): + T = lambda f: transpose_unary(f, 0.) + + @custom_transpose + def f(c, z): return c * z + + @f.def_transpose + def ft(c, z): return f(c + 1., z) + + g = partial(f, 1.) + self.assertAllClose(1., g(1.)) + self.assertAllClose(2., T(g)(1.)) + self.assertAllClose(3., T(T(g))(1.)) + self.assertAllClose(4., T(T(T(g)))(1.)) + self.assertAllClose(5., T(T(T(T(g))))(1.)) # ... + + def test_jvp_lin(self): + def f(x, y): + @custom_transpose(jnp.ones(2)) + def fn(r, x): return x / r + @fn.def_transpose + def tp(r, t): return t / r + return x + fn(y, x) + + def f_ref(x, y): return x + x / y + + x, y, tx = 6., 3., 1. + g = lambda x: f(x, y) + g_ref = lambda x: f_ref(x, y) + self.assertAllClose(api.jvp(g, [x], [tx]), api.jvp(g_ref, [x], [tx])) + + def test_jvp_res(self): + raise unittest.SkipTest('unimplemented') # TODO(frostig) + + def f(x, y): + @custom_transpose(jnp.ones(2)) + def fn(r, x): return x / r + @fn.def_transpose + def tp(r, t): return t / r + return x + fn(y, x) + + def f_ref(x, y): return x + x / y + + x, y, ty = 6., 3., 1. + g = lambda y: f(x, y) + g_ref = lambda y: f_ref(x, y) + self.assertAllClose(api.jvp(g, [y], [ty]), api.jvp(g_ref, [y], [ty])) + + def test_jvp_both(self): + raise unittest.SkipTest('unimplemented') # TODO(frostig) + + def f(x, y): + @custom_transpose(jnp.ones(2)) + def fn(r, x): return x / r + @fn.def_transpose + def tp(r, t): return t / r + return x + fn(y, x) + + def f_ref(x, y): return x + x / y + + x, y, tx, ty = 6., 3., 1., 1. + self.assertAllClose(api.jvp(f, [x, y], [tx, ty]), + api.jvp(f_ref, [x, y], [tx, ty])) + + def test_make_jaxpr(self): + def f(x, y): + @custom_transpose(jnp.ones(2)) + def fn(r, x): return x / r + @fn.def_transpose + def tp(r, t): return 2 * t / r + + return x + fn(y, x) + + x = jnp.ones(2) * 6. + y = jnp.ones(2) * 3. + f_ = lambda x: f(x, y) + f_t = transpose_unary(f_, x) + + jaxpr = api.make_jaxpr(f_)(x) + self.assertIn('custom_transpose_call', str(jaxpr)) + + jaxpr_t = api.make_jaxpr(f_t)(x) + self.assertNotIn('custom_transpose_call', str(jaxpr_t)) + + def test_jit(self): + def f(x, y): + @custom_transpose(jnp.ones(2)) + def fn(r, x): return x / r + @fn.def_transpose + def tp(r, t): return 2 * t / r + + return x + fn(y, x) + + x = jnp.ones(2) * 6. + y = jnp.ones(2) * 3. + self.assertAllClose(f(x, y), jax.jit(f)(x, y)) + + f_ = lambda x: f(x, y) + f_t = transpose_unary(f_, x) + g_ = jax.jit(f_) + g_t = transpose_unary(g_, x) + self.assertAllClose(f_(x), jax.jit(f_)(x)) + self.assertAllClose(f_t(x), jax.jit(f_t)(x)) + self.assertAllClose(f_(x), g_(x)) + self.assertAllClose(f_t(x), g_t(x)) + + def test_jit_recursive(self): + def f(x, y): + @custom_transpose(jnp.ones(2)) + def fn(r, x): return x / r + @fn.def_transpose + def tp(r, t): return 2 * fn(r, t) + + return x + fn(y, x) + + x = jnp.ones(2) * 6. + y = jnp.ones(2) * 3. 
+ self.assertAllClose(f(x, y), jax.jit(f)(x, y)) + + f_ = lambda x: f(x, y) + f_t = transpose_unary(f_, x) + g_ = jax.jit(f_) + g_t = transpose_unary(g_, x) + self.assertAllClose(f_(x), jax.jit(f_)(x)) + self.assertAllClose(f_t(x), jax.jit(f_t)(x)) + self.assertAllClose(f_(x), g_(x)) + self.assertAllClose(f_t(x), g_t(x)) + + def test_cond(self): + def f(x, y): + @custom_transpose(jnp.ones(2)) + def fn(r, x): return x / r + @fn.def_transpose + def tp(r, t): return 2 * t / r + + return x + fn(y, x) + + def cond_wrap(f): + return lambda i, x: lax.cond(i > 0, f, lambda x: x, x) + + i = 7. + x = jnp.ones(2) * 6. + y = jnp.ones(2) * 3. + + f_ = lambda x: f(x, y) + f_t = transpose_unary(f_, x) + g_ = partial(cond_wrap(f_), i) + g_t = transpose_unary(g_, x) + + self.assertAllClose(f_(x), g_(x)) + self.assertAllClose(f_t(x), g_t(x)) + + def test_cond_recursive(self): + def f(x, y): + @custom_transpose(jnp.ones(2)) + def fn(r, x): return x / r + @fn.def_transpose + def tp(r, t): return 2 * fn(r, t) + + return x + fn(y, x) + + def cond_wrap(f): + return lambda i, x: lax.cond(i > 0, f, lambda x: x, x) + + i = 7. + x = jnp.ones(2) * 6. + y = jnp.ones(2) * 3. + + f_ = lambda x: f(x, y) + f_t = transpose_unary(f_, x) + g_ = partial(cond_wrap(f_), i) + g_t = transpose_unary(g_, x) + + self.assertAllClose(f_(x), g_(x)) + self.assertAllClose(f_t(x), g_t(x)) + + def test_compose_custom_jvp(self): + @jax.custom_jvp + def f(x): + return jnp.sin(x) + + @f.defjvp + def f_jvp(primals, tangents): + x, = primals + dx, = tangents + return f(x), g(x, dx) + + @custom_transpose + def g(x, dx): + return jnp.cos(x) * dx + + @g.def_transpose + def gt(x, t): + return jnp.cos(x) * t + + with config.use_direct_linearize(True): + self.assertAllClose(jax.grad(f)(0.5), jnp.cos(0.5)) + + def test_input_none(self): + # ref: https://github.com/jax-ml/jax/issues/29009 + @jax.custom_jvp + def f(x, y): return y + @f.defjvp + def f_jvp(p, t): return f(*p), g(p, t) + + @custom_transpose(jnp.float32(0)) + def g(r, x): return x[1] + @g.def_transpose + def gt(r, t): return None, jnp.zeros_like(r[1]) + + jax.grad(f, argnums=(1,))(None, jnp.float32(2)) # doesn't crash + + +class CustomDceTest(jtu.JaxTestCase): + + def test_basic(self): + @jax.experimental.custom_dce.custom_dce + def f(x): + return jnp.sin(x), jnp.cos(x) + + @f.def_dce + def rule(used_outs, x): + return ( + jnp.exp(x) if used_outs[0] else None, + jnp.sqrt(x) if used_outs[1] else None, + ) + + x = jnp.array(1.1234) + self.assertAllClose(jax.jit(lambda x: f(x)[0])(x), jnp.exp(x)) + self.assertAllClose(jax.jit(lambda x: f(x)[1])(x), jnp.sqrt(x)) + + def test_recursive(self): + @jax.experimental.custom_dce.custom_dce + def f(x): + return jnp.exp(x), 10 * jnp.sqrt(x) + + @f.def_dce + def f_dce(used_outs, x): + return [2 * v if used else None for used, v in zip(used_outs, f(x))] + + x = 1.1234 + expected = f(x) + self.assertAllClose(jax.jit(lambda x: f(x)[0])(x), 2 * expected[0]) + self.assertAllClose(jax.jit(lambda x: f(x)[1])(x), 2 * expected[1]) + + def test_multiple_rounds(self): + @jax.experimental.custom_dce.custom_dce + def f(x, y, z): + return jnp.sin(x), jnp.sin(y), jnp.sin(z) + + @f.def_dce + def rule(used_outs, x, y, z): + patterns.append(used_outs) + outs = [ + jnp.cos(v) if used else None for used, v in zip(used_outs, (x, y, z)) + ] + return outs + + patterns = [] + x, y, z = jnp.array(1.), jnp.array(2.), jnp.array(3.) 
+ jaxpr = jax.make_jaxpr(f)(x, y, z).jaxpr + new_jaxpr, used_ins = pe.dce_jaxpr(jaxpr, [True, False, True]) + assert used_ins == [True, False, True] + new_jaxpr, used_ins = pe.dce_jaxpr(new_jaxpr, [True, False]) + assert used_ins == [True, False] + assert patterns == [(True, False, True), (True, False, False)], patterns + + def test_batching(self): + @jax.experimental.custom_dce.custom_dce + def f(x, y): + return jnp.sin(x), jnp.sin(y) + + @f.def_dce + def rule(used_outs, x, y): + return ( + jnp.cos(x) if used_outs[0] else None, + jnp.cos(y) if used_outs[1] else None, + ) + + x = jnp.linspace(-0.1, 0.2, 5) + y = jnp.linspace(3.0, 4.0, 5) + self.assertAllClose(jax.vmap(f)(x, y), f(x, y)) + self.assertAllClose( + jax.jit(lambda *args: jax.vmap(f)(*args)[0])(x, y), jnp.cos(x) + ) + self.assertAllClose( + jax.vmap(jax.jit(lambda *args: f(*args)[0]))(x, y), jnp.cos(x) + ) + self.assertAllClose( + jax.jit(lambda *args: jax.vmap(f)(*args)[1])(x, y), jnp.cos(y) + ) + self.assertAllClose( + jax.vmap(jax.jit(lambda *args: f(*args)[1]))(x, y), jnp.cos(y) + ) + + def test_composes_with_custom_vjp(self): + # custom_dce must be the "outer" decorator (for now!) because custom_vjp + # doesn't pass through DCE. + @jax.experimental.custom_dce.custom_dce + @jax.custom_vjp + def f(x, y): + return jnp.sin(x) * y, x * jnp.sin(y) + + @f.def_dce + def f_dce_rule(used_outs, x, y): + return ( + jnp.cos(x) * y if used_outs[0] else None, + x * jnp.cos(y) if used_outs[1] else None, + ) + + def f_fwd(x, y): + return f(x, y), (x, jnp.cos(x), jnp.sin(x), y, jnp.cos(y), jnp.sin(y)) + + def f_bwd(res, g): + ga, gb = g + x, cos_x, sin_x, y, cos_y, sin_y = res + return (cos_x * ga * y + sin_y * gb, sin_x * ga + x * cos_y * gb) + + f.defvjp(f_fwd, f_bwd) + + x, y = jnp.array(1.), jnp.array(2.) + self.assertAllClose(jax.jit(lambda *args: f(*args)[0])(x, y), + jnp.cos(x) * y) + jax.grad(lambda *args: f(*args)[0])(x, y) # Doesn't crash. 
+ + def test_can_optimize_remat(self): + @jax.custom_vjp + def f(x): + return jnp.tan(x) + + @jax.experimental.custom_dce.custom_dce + def f_fwd(x): + return jnp.sin(x), (x,) + + @f_fwd.def_dce + def f_dce_rule(used_outs, x): + used_prim, used_res = used_outs + used_res, = used_res + if not used_res: + return f(x), None + prim, res = f_fwd(x) + return prim if used_prim else None, res + + def f_bwd(res, g): + x, = res + cos_x = jnp.cos(x) + return (cos_x * g,) + + f.defvjp(f_fwd, f_bwd) + + def temp(x): + out = jax.remat(f)(x) + out = out ** 2 + return out + + v, g = jax.value_and_grad(temp)(3.2) + self.assertAllClose(v, jnp.tan(3.2)**2) + + def test_static_argnums(self): + @partial(jax.experimental.custom_dce.custom_dce, static_argnums=(0,)) + def g(f, x): + return f(x), 10 * f(x) + + @g.def_dce + def g_dce(f, used_outs, x): # note: static_argnums are always passed first + self.assertTrue(callable(f)) + return [2 * v if used else None for used, v in zip(used_outs, g(f, x))] + + x = 1.1234 + f = lambda x: jnp.exp(x) + expected = g(f, x) + self.assertAllClose(jax.jit(lambda x: g(f, x)[0])(x), 2 * expected[0]) + self.assertAllClose(jax.jit(lambda x: g(f, x)[1])(x), 2 * expected[1]) + + def test_shape_mismatch_error(self): + @jax.experimental.custom_dce.custom_dce + def f(x): + return jnp.stack((x, x)), jnp.cos(x) + + @f.def_dce + def rule(used_outs, x): + return ( + jnp.exp(x) if used_outs[0] else None, + x.astype(jnp.int32) if used_outs[1] else None, + ) + + x = jnp.array(1.1234) + with self.assertRaisesRegex( + ValueError, + r'Custom DCE rule .* same shapes/dtypes .* output\[0\]', + ): + jax.jit(lambda x: f(x)[0])(x) + with self.assertRaisesRegex( + ValueError, + r'Custom DCE rule .* same shapes/dtypes .* output\[1\]', + ): + jax.jit(lambda x: f(x)[1])(x) + + def test_missing_output_error(self): + @jax.experimental.custom_dce.custom_dce + def f(x): + return jnp.sin(x), jnp.cos(x) + + @f.def_dce + def rule(used_outs, x): + return None, None + + x = jnp.array(1.1234) + with self.assertRaisesRegex( + ValueError, + r'Custom DCE rule .* produce values for all .* output\[0\]', + ): + jax.jit(lambda x: f(x)[0])(x) + + def test_consts(self): + @jax.experimental.custom_dce.custom_dce + def f(x): + return np.eye(1) * jnp.sin(x), jnp.cos(x) + + @f.def_dce + def rule(used_outs, x): + return ( + np.full((1, 1), 2.0) * jnp.exp(x) if used_outs[0] else None, + jnp.sqrt(x) if used_outs[1] else None, + ) + + x = jnp.array(1.1234) + expected = rule([True, True], x) + self.assertAllClose(jax.jit(lambda x: f(x)[0])(x), expected[0]) + self.assertAllClose(jax.jit(lambda x: f(x)[1])(x), expected[1]) + + def test_resolve_kwargs_error_message(self): + @jax.experimental.custom_dce.custom_dce + def f(x, y, *, z=None): + return jnp.sin(x) * y, x * jnp.sin(y) + + @f.def_dce + def f_dce_rule(used_outs, x, y): + self.fail("should not be executed") + + with self.assertRaisesRegex( + TypeError, + r"The input arguments to the custom_dce-decorated function f(.*)\n" + r"missing a required argument: 'y'" + ): + f(0.5) + + with self.assertRaisesRegex( + TypeError, + r"The input arguments to the custom_dce-decorated function f(.*)\n" + "The following keyword arguments could not be resolved to positions: z" + ): + f(0.5, 0.1, z=1.0) + + +class CustomVmapTest(jtu.JaxTestCase): + + def test_basic(self): + @jax.custom_batching.custom_vmap + def f(x): return jnp.sin(x) + + @f.def_vmap + def rule(axis_size, in_batched, xs): + xs_batched, = in_batched + self.assertEqual(xs_batched, True) + self.assertEqual(axis_size, xs.shape[0]) + 
return jnp.cos(xs), xs_batched + + x, xs = jnp.array(1.), jnp.arange(3) + y = f(x) + self.assertAllClose(y, jnp.sin(x)) + ys = api.vmap(f)(xs) + self.assertAllClose(ys, jnp.cos(xs)) + + @jax.numpy_dtype_promotion('standard') + def test_closure(self): + z = jnp.array([2., 1., 3.]) + + @jax.custom_batching.custom_vmap + def f(x): return z + jnp.sin(x) + + @f.def_vmap + def rule(axis_size, in_batched, *args): + self.assertEqual(len(in_batched), 1) + self.assertEqual(len(args), 1) + xs, = args + xs_batched, = in_batched + self.assertEqual(xs_batched, True) + self.assertEqual(axis_size, xs.shape[0]) + return z + jnp.cos(xs), xs_batched + + x, xs = jnp.array(1.), jnp.arange(3) + y = f(x) + self.assertAllClose(y, z + jnp.sin(x)) + ys = api.vmap(f)(xs) + self.assertAllClose(ys, z + jnp.cos(xs)) + + def test_rule_multi_output(self): + @jax.custom_batching.custom_vmap + def f(x): return jnp.sin(x), jnp.cos(x) + + @f.def_vmap + def rule(axis_size, in_batched, xs): + return (jnp.cos(xs), jnp.sin(xs)), tuple(in_batched * 2) + + x, xs = jnp.array(1.), jnp.arange(3) + y1, y2 = f(x) + self.assertAllClose(y1, jnp.sin(x)) + self.assertAllClose(y2, jnp.cos(x)) + ys1, ys2 = api.vmap(f)(xs) + self.assertAllClose(ys1, jnp.cos(xs)) + self.assertAllClose(ys2, jnp.sin(xs)) + + def test_nary(self): + @jax.custom_batching.custom_vmap + def f(x, y): return jnp.sin(x) + y ** 2. + + @f.def_vmap + def rule(axis_size, in_batched, xs, ys): + self.assertEqual(in_batched, [True, True]) + self.assertEqual(axis_size, 3) + self.assertEqual(axis_size, xs.shape[0]) + self.assertEqual(axis_size, ys.shape[0]) + return jnp.cos(xs) + ys ** 2., True + + xs, ys = jnp.arange(3.0), jnp.arange(3.0) + zs = api.vmap(f)(xs, ys) + self.assertAllClose(zs, jnp.cos(xs) + ys ** 2.) + + def test_nary_mixed_batching(self): + @jax.custom_batching.custom_vmap + def vector_dot(u, v): + self.assertEqual(u.ndim, 1) + self.assertEqual(v.ndim, 1) + return u @ v + + size = 4 + vlen = 3 + in_batched_log = [] + + @vector_dot.def_vmap + def vector_dot_vmap_rule(axis_size, in_batched, u, v): + in_batched_log.append(in_batched) + self.assertEqual(axis_size, size) + u_batched, v_batched = in_batched + if u_batched: + self.assertEqual(u.ndim, 2) + self.assertEqual(u.shape[0], size) + else: + self.assertEqual(u.ndim, 1) + self.assertEqual(u.shape[0], vlen) + if v_batched: + self.assertEqual(v.ndim, 2) + self.assertEqual(v.shape[0], size) + else: + self.assertEqual(v.ndim, 1) + self.assertEqual(v.shape[0], vlen) + if u_batched and v_batched: + out = jnp.sum(u * v, axis=1) + else: + out = u @ v if u_batched else v @ u + return out, u_batched or v_batched + + f = vector_dot + v = lambda *shape: jnp.ones(shape) + + y = api.vmap(f, in_axes=(0, None))(v(4, 3), v(3)) + self.assertAllClose(y, v(4, 3) @ v(3)) + y = api.vmap(f, in_axes=(1, None))(v(3, 4), v(3)) + self.assertAllClose(y, v(3, 4).T @ v(3)) + y = api.vmap(f, in_axes=(None, 0))(v(3), v(4, 3)) + self.assertAllClose(y, v(3) @ v(4, 3).T) + y = api.vmap(f, in_axes=(0, 0))(v(4, 3), v(4, 3)) + self.assertAllClose(y, jnp.sum(v(4, 3) * v(4, 3), axis=1)) + self.assertEqual(in_batched_log[0], [True, False]) + self.assertEqual(in_batched_log[1], [True, False]) + self.assertEqual(in_batched_log[2], [False, True]) + self.assertEqual(in_batched_log[3], [True, True]) + + def test_rule_input_signature(self): + @jax.custom_batching.custom_vmap + def f(x): return jnp.sin(x) + + rule_args = [] + + @f.def_vmap + def rule(axis_size, in_batched, xs): + rule_args.append((axis_size, in_batched)) + return jnp.cos(xs), in_batched[0] + + 
xs = jnp.arange(3) + _ = api.vmap(f)(xs) + (axis_size, in_batched), = rule_args + self.assertIs(type(axis_size), int) + self.assertIs(type(in_batched), list) + self.assertEqual(len(in_batched), 1) + + def test_rule_output_vs_batching_output_mismatch(self): + @jax.custom_batching.custom_vmap + def f(x): return jnp.sin(x) + + @f.def_vmap + def test_rule_abc(axis_size, in_batched, xs): + return [jnp.sin(xs), jnp.cos(xs)], in_batched + + xs = jnp.arange(3) + self.assertRaisesRegex( + ValueError, + 'structure of output value and output batching specification ' + r'returned by custom vmap rule \(test_rule_abc\) do not match.*', + lambda: api.vmap(f)(xs)) + + def test_rule_vs_call_output_mismatch(self): + @jax.custom_batching.custom_vmap + def f(x): return jnp.sin(x) + + @f.def_vmap + def test_rule_abc2(axis_size, in_batched, xs): + return [jnp.sin(xs)], in_batched + + xs = jnp.arange(3) + self.assertRaisesRegex( + ValueError, + r'structure of output returned by custom vmap rule \(test_rule_abc2\) ' + r'does not match that of original custom-vmapped function.*', + lambda: api.vmap(f)(xs)) + + def test_jvp_basic(self): + @jax.custom_batching.custom_vmap + def f(x): return jnp.sin(x) + + @f.def_vmap + def rule(axis_size, in_batched, xs): + self.assertEqual(axis_size, 3) + self.assertEqual(in_batched, [True]) + return jnp.cos(xs), in_batched[0] + + f_jvp = lambda x, tx: api.jvp(f, [x], [tx]) + + x, tx = jnp.array(1.), jnp.array(2.) + xs, txs = jnp.arange(3.), jnp.arange(3.) * 2. + + y, ty = f_jvp(x, tx) + self.assertAllClose(y, jnp.sin(x)) + self.assertAllClose(ty, jnp.cos(x) * tx) + + ys, tys = api.vmap(f_jvp)(xs, txs) + self.assertAllClose(ys, jnp.cos(xs)) + self.assertAllClose(tys, -jnp.sin(xs) * txs) + + ys, tys = api.jvp(api.vmap(f), [xs], [txs]) + self.assertAllClose(ys, jnp.cos(xs)) + self.assertAllClose(tys, -jnp.sin(xs) * txs) + + @jax.numpy_dtype_promotion('standard') + def test_jvp_closure(self): + z = jnp.array([2., 1., 3.]) + def bcast(x): return z + x - z + + @jax.custom_batching.custom_vmap + def f(x): return z + jnp.sin(x) + + @f.def_vmap + def rule(axis_size, in_batched, xs): + self.assertEqual(axis_size, 3) + self.assertEqual(in_batched, [True]) + return z + jnp.cos(xs), in_batched[0] + + f_jvp = lambda x, tx: api.jvp(f, [x], [tx]) + + x, tx = jnp.array(1.), jnp.array(2.) + xs, txs = jnp.arange(3.), jnp.arange(3.) * 2. + + y, ty = f_jvp(x, tx) + self.assertAllClose(y, z + jnp.sin(x)) + self.assertAllClose(ty, bcast(jnp.cos(x)) * tx) + + ys, tys = api.vmap(f_jvp)(xs, txs) + self.assertAllClose(ys, z + jnp.cos(xs)) + self.assertAllClose(tys, bcast(-jnp.sin(xs)) * txs) + + ys, tys = api.jvp(api.vmap(f), [xs], [txs]) + self.assertAllClose(ys, z + jnp.cos(xs)) + self.assertAllClose(tys, bcast(-jnp.sin(xs)) * txs) + + def test_jvp_nary(self): + @jax.custom_batching.custom_vmap + def f(x, y): return jnp.sin(x) + y + + @f.def_vmap + def rule(axis_size, in_batched, xs, ys): + self.assertEqual(axis_size, 3) + self.assertEqual(in_batched, [True, True]) + return jnp.cos(xs) + ys, True + + f_jvp = lambda x, y, tx, ty: api.jvp(f, [x, y], [tx, ty]) + + x, y, tx, ty = jnp.arange(4.) + xs, ys, txs, tys = 4. + jnp.arange(3. 
* 4).reshape((4, 3)) + + zs, tzs = api.vmap(f_jvp)(xs, ys, txs, tys) + self.assertAllClose(zs, jnp.cos(xs) + ys) + self.assertAllClose(tzs, -jnp.sin(xs) * txs + tys) + + zs, tzs = api.jvp(api.vmap(f), [xs, ys], [txs, tys]) + self.assertAllClose(zs, jnp.cos(xs) + ys) + self.assertAllClose(tzs, -jnp.sin(xs) * txs + tys) + + def test_jvp_extra_batched_tangents(self): + @jax.custom_batching.custom_vmap + def f(x): return jnp.sin(x) + + @f.def_vmap + def rule(axis_size, in_batched, xs): + self.assertEqual(axis_size, 3) + self.assertEqual(in_batched, [False]) + return jnp.cos(xs), in_batched[0] + + f_jvp = lambda x, tx: api.jvp(f, [x], [tx]) + + txs = 2. + jnp.arange(3.) + x = jnp.array(1, dtype=txs.dtype) + y, tys = api.vmap(f_jvp, in_axes=(None, 0), out_axes=(None, 0))(x, txs) + self.assertAllClose(y, jnp.cos(x)) + self.assertAllClose(tys, -jnp.sin(x) * txs) + + def test_jacfwd(self): + # jacfwd is another way to exercise extra-batched tangents + + @jax.custom_batching.custom_vmap + def f(x): return jnp.sin(x) + + @f.def_vmap + def rule(axis_size, in_batched, xs): + self.assertEqual(axis_size, 3) + self.assertEqual(in_batched, [False]) + return jnp.cos(xs), in_batched[0] + + x = jnp.arange(3.) + .72 + j = api.jacfwd(f)(x) + self.assertAllClose(j, -jnp.diag(jnp.sin(x))) + + def test_jvp_extra_batched_primals(self): + @jax.custom_batching.custom_vmap + def f(x): return jnp.sin(x) + + @f.def_vmap + def rule(axis_size, in_batched, xs): + self.assertEqual(axis_size, 3) + self.assertEqual(in_batched, [False]) + return jnp.cos(xs), in_batched[0] + + f_jvp = lambda x, tx: api.jvp(f, [x], [tx]) + + xs = jnp.arange(3.) + tx = jnp.array(4, dtype=xs.dtype) + ys, tys = api.vmap(f_jvp, in_axes=(0, None))(xs, tx) + self.assertAllClose(ys, jnp.cos(xs)) + self.assertAllClose(tys, -jnp.sin(xs) * tx) + + def test_jvp_extra_batched_primals_with_linear_vmap_rule(self): + # When a function is linear, its Jacobian is constant. JAX's JVP + # of linear functions takes advantage of this: when mapping over a + # batch of primals relative to a fixed (i.e. symbolically + # replicated) tangent, output tangents remain replicated as well + # (i.e. JAX will not broadcast them). This is true in general, and + # this test checks that vmapped JVPs continue to behave this way + # when custom_vmap is involved and the custom vmap rule is linear. + + @jax.custom_batching.custom_vmap + def f_linear(x): return 7. * x + + @f_linear.def_vmap + def linear_rule(axis_size, in_batched, xs): + return 11. * xs, in_batched[0] + + @jax.custom_batching.custom_vmap + def f_nonlinear(x): return jnp.sin(x) + + @f_nonlinear.def_vmap + def nonlinear_rule(axis_size, in_batched, xs): + return jnp.cos(xs), in_batched[0] + + f_lin_jvp = lambda x, tx: api.jvp(f_linear, [x], [tx]) + f_non_jvp = lambda x, tx: api.jvp(f_nonlinear, [x], [tx]) + xs = jnp.arange(3.) + tx = jnp.array(4., dtype=xs.dtype) + + # doesn't err + _ = api.vmap(f_lin_jvp, in_axes=(0, None), out_axes=(0, None))(xs, tx) + + # does err + self.assertRaisesRegex( + ValueError, "at vmap out_axes", + lambda: api.vmap( + f_non_jvp, in_axes=(0, None), out_axes=(0, None))(xs, tx)) + + def test_jvp_dataflow_violation(self): + # The jvp-of-custom-vmap machinery should not assume the standard + # dataflow constraint on the JVP of the custom vmap rule (primal + # outputs independent of tangent inputs). Both jvp and vmap are + # "forward" transformations under which, at present, we don't + # enforce the JVP dependence diagram. 
Because output primals can + # depend on input tangents, extra-batched input tangents can + # create batched output primals, as this test checks. + + @jax.custom_jvp + def cos_with_invalid_dataflow_jvp(x): return jnp.cos(x) + + @cos_with_invalid_dataflow_jvp.defjvp + def invalid_dataflow_jvp(x, tx): + [x], [tx] = x, tx + return jnp.cos(x * tx), tx + + @jax.custom_batching.custom_vmap + def f(x): return jnp.sin(x) + + @f.def_vmap + def rule(axis_size, in_batched, xs): + return cos_with_invalid_dataflow_jvp(xs), in_batched[0] + + f_jvp = lambda x, tx: api.jvp(f, [x], [tx]) + txs = 2. + jnp.arange(3.) + x = jnp.array(1, dtype=txs.dtype) + + # doesn't err + ys, tys = api.vmap(f_jvp, in_axes=(None, 0))(x, txs) + self.assertAllClose(ys, jnp.cos(x * txs)) + self.assertAllClose(tys, txs) + + # does err + self.assertRaisesRegex( + ValueError, "at vmap out_axes", + lambda: api.vmap( + f_jvp, in_axes=(None, 0), out_axes=(None, 0))(x, txs)) + + def test_tree(self): + tree_sin = partial(jax.tree.map, jnp.sin) + tree_cos = partial(jax.tree.map, jnp.cos) + + x, xs = jnp.array(1.), jnp.arange(3) + x = (x, [x + 1, x + 2], [x + 3], x + 4) + xs = (xs, [xs + 1, xs + 2], [xs + 3], xs + 4) + in_batched_ref = jax.tree.map(lambda _: True, x) + + @jax.custom_batching.custom_vmap + def f(xs): return tree_sin(xs) + + @f.def_vmap + def rule(axis_size, in_batched, xs): + self.assertEqual(in_batched, [in_batched_ref]) + sz, = {z.shape[0] for z in jax.tree.leaves(xs)} + self.assertEqual(axis_size, sz) + return tree_cos(xs), in_batched[0] + + y = f(x) + self.assertAllClose(y, tree_sin(x)) + ys = api.vmap(f)(xs) + self.assertAllClose(ys, tree_cos(xs)) + + def test_tree_with_nones(self): + tree_sin = partial(jax.tree.map, jnp.sin) + tree_cos = partial(jax.tree.map, jnp.cos) + + x, xs = jnp.array(1.), jnp.arange(3) + x = (x, [x + 1, None], [x + 3], None) + xs = (xs, [xs + 1, None], [xs + 3], None) + in_batched_ref = jax.tree.map(lambda _: True, x) + + @jax.custom_batching.custom_vmap + def f(xs): return tree_sin(xs) + + @f.def_vmap + def rule(axis_size, in_batched, xs): + self.assertEqual(in_batched, [in_batched_ref]) + sz, = {z.shape[0] for z in jax.tree.leaves(xs)} + self.assertEqual(axis_size, sz) + return tree_cos(xs), in_batched[0] + + y = f(x) + self.assertAllClose(y, tree_sin(x)) + ys = api.vmap(f)(xs) + self.assertAllClose(ys, tree_cos(xs)) + + def test_jit(self): + @jax.custom_batching.custom_vmap + def f(x): return jnp.sin(x) + + @f.def_vmap + def rule(axis_size, in_batched, xs): + self.assertEqual(in_batched, [True]) + self.assertEqual(axis_size, xs.shape[0]) + return jnp.cos(xs), in_batched[0] + + x, xs = jnp.array(1.), jnp.arange(3) + self.assertAllClose(f(x), jit(f)(x)) + self.assertAllClose(jit(api.vmap(f))(xs), api.vmap(f)(xs)) + self.assertAllClose(api.vmap(jit(f))(xs), api.vmap(f)(xs)) + + def test_sequential_vmap_basic(self): + @jax.custom_batching.sequential_vmap + def f(x): + return x + 1. + + def vmap_ref(xs): + return lax.map(f, xs) + + xs = jnp.arange(3.) + jaxpr = api.make_jaxpr(api.vmap(f))(xs) + jaxpr_ref = api.make_jaxpr(vmap_ref)(xs) + + self.assertEqual(str(jaxpr), str(jaxpr_ref)) + + def test_sequential_vmap_nary_same_batching(self): + @jax.custom_batching.sequential_vmap + def f(x, y): + return x + y + + def vmap_ref(xs, ys): + return lax.map(lambda args: f(*args), (xs, ys)) + + xs, ys = jnp.arange(3.), 4. + jnp.arange(3.) 
+ jaxpr = api.make_jaxpr(api.vmap(f))(xs, ys) + jaxpr_ref = api.make_jaxpr(vmap_ref)(xs, ys) + + self.assertEqual(str(jaxpr), str(jaxpr_ref)) + + def test_sequential_vmap_nary_mixed_batching(self): + @jax.custom_batching.sequential_vmap + def f(x, y): + return x + y + + def vmap_ref(xs, y): + return lax.map(lambda x: f(x, y), xs) + + xs, y = jnp.arange(3.), 4. + jaxpr = api.make_jaxpr(api.vmap(f, in_axes=(0, None)))(xs, y) + jaxpr_ref = api.make_jaxpr(vmap_ref)(xs, y) + + self.assertEqual(str(jaxpr), str(jaxpr_ref)) + + @parameterized.named_parameters( + ("1", 1), + ("8", 4), + ("12", 8), + ("16", 16), + ) + def test_batch_map_basic(self, batch_size: int): + def f(x): + self.assertEqual(x.shape, ()) + return x**2 + + x = np.arange(16) + y = jax.lax.map(f, x, batch_size=batch_size) + + np.testing.assert_array_equal(y, x**2) + + @parameterized.named_parameters( + ("1", 1), + ("8", 4), + ("12", 8), + ("16", 16), + ) + def test_batch_map_pytrees(self, batch_size: int): + f = lambda x: {'b': x['a'] ** 2} + inputs = {'a': np.arange(16)} + expected = np.arange(16) ** 2 + + outputs = jax.lax.map(f, inputs, batch_size=batch_size) + self.assertAllClose(outputs['b'], expected) + + outputs = jax.lax.map( + f, inputs, batch_size=batch_size + ) + self.assertAllClose(outputs['b'], expected) + + def test_batch_divides_axis(self): + def f(t): + x, a = t + self.assertEqual(x.shape, (4,)) + return (x + a)**2 + + x = jax.random.randint(jax.random.key(0), (16, 4), -10, 10) + a = jax.random.randint(jax.random.key(1), (16, 4), -10, 10) + + @jax.jit + def g(x, a): + return jax.lax.map(f, (x, a), batch_size=8) + + y = g(x, a) + + self.assertAllClose(y, (x + a)**2) + + def test_undefined_rule(self): + @jax.custom_batching.custom_vmap + def f(x): return jnp.sin(x) + + with self.assertRaisesRegex( + AttributeError, "No batching rule defined for custom_vmap function f"): + f(0.5) + + def test_kwargs(self): + @jax.custom_batching.custom_vmap + def f(x): return jnp.sin(x) + + @f.def_vmap + def rule(axis_size, in_batched, xs): + xs_batched, = in_batched + self.assertEqual(xs_batched, True) + self.assertEqual(axis_size, xs.shape[0]) + return jnp.cos(xs), xs_batched + + x, xs = jnp.array(1.), jnp.arange(3) + y = f(x=x) + self.assertAllClose(y, jnp.sin(x)) + ys = api.vmap(f)(x=xs) + self.assertAllClose(ys, jnp.cos(xs)) + + def test_partial_eval_raises(self): + @jax.custom_batching.custom_vmap + def f(x): + return jnp.sin(x) + + @f.def_vmap + def rule(axis_size, in_batched, xs): + del axis_size # unused + return jnp.cos(xs), in_batched[0] + + with self.assertRaisesRegex( + ValueError, + "Linearization failed to produce known values for all output primals", + ): + jax.grad(f)(0.5) + + def test_compose_custom_vjp(self): + @jax.custom_vjp + @jax.custom_batching.custom_vmap + def f(x, y): + return jnp.sin(x) * y + + @f.def_vmap + def f_vmap_rule(axis_size, in_batched, xs, ys): + return jnp.cos(xs) * ys, True + + def f_fwd(x, y): + return f(x, y), (jnp.cos(x), jnp.sin(x), y) + + def f_bwd(res, g): + cos_x, sin_x, y = res + return (cos_x * g * y, sin_x * g) + + f.defvjp(f_fwd, f_bwd) + + xs = jnp.linspace(0, 1, 5) + ys = jnp.linspace(-0.1, 0.1, 5) + self.assertAllClose(jax.vmap(f)(xs, ys), jnp.cos(xs) * ys) + jax.grad(f)(xs[0], ys[0]) # Doesn't crash. + + def test_compose_custom_vjp_bwd_rule(self): + # This tests the case where both the forward and backward rules are wrapped + # in custom_vmap. 
+ @jax.custom_batching.sequential_vmap + def fun_fwd(x, y): + return jnp.sin(x) * y, (x, y) + + @jax.custom_batching.sequential_vmap + def fun_bwd(res, ct): + x, y = res + return x * ct, y * ct + + fun = jax.custom_vjp(lambda *args: fun_fwd(*args)[0]) + fun.defvjp(fun_fwd, fun_bwd) + + xs = jnp.linspace(0, 1, 5) + y = jnp.array(0.5, dtype=xs.dtype) + f = jax.vmap(jax.jit(fun), in_axes=(0, None)) + out, f_vjp = jax.vjp(f, xs, y) + f_vjp(out) # Doesn't crash. + + def test_resolve_kwargs_error_message(self): + @jax.custom_batching.custom_vmap + def f(x, y, *, z=None): + return jnp.sin(x) * y + + @f.def_vmap + def f_vmap_rule(axis_size, in_batched, xs, ys): + self.fail("should not be executed") + + with self.assertRaisesRegex( + TypeError, + r"The input arguments to the custom_vmap-decorated function f(.*)\n" + r"missing a required argument: 'y'" + ): + f(0.5) + + with self.assertRaisesRegex( + TypeError, + r"The input arguments to the custom_vmap-decorated function f(.*)\n" + "The following keyword arguments could not be resolved to positions: z" + ): + f(0.5, 0.1, z=1.0) + + +class CustomApiTest(jtu.JaxTestCase): + """Test interactions among the custom_{vmap,jvp,vjp,transpose,*} APIs""" + + def test_method_forwarding(self): + @jax.custom_batching.custom_vmap + @jax.custom_jvp + @jax.custom_transpose.custom_transpose + def f(x): return 2. * x + + # none of these err: + @f.def_vmap + def f_batch(sz, b, xs): return 2. * xs + @f.defjvp + def f_jvp(x, tx): return 2. * x, 2. * tx + @f.def_transpose + def f_transpose(x): return 2. * x + + def test_def_method_forwarding_all_permutations(self): + for wraps in it.permutations([ + jax.custom_jvp, jax.custom_transpose.custom_transpose, jax.custom_batching.custom_vmap]): + f = lambda x: x + 1. + for wrap in wraps: + f = wrap(f) + for methods in it.permutations(['defjvp', 'def_vmap', 'def_transpose']): + for method in methods: + self.assertIsInstance(getattr(f, method), Callable) + + for decorators in it.permutations([ + jax.custom_vjp, jax.custom_transpose.custom_transpose, jax.custom_batching.custom_vmap]): + f = lambda x: x + 1. 
+ for decorator in decorators: + f = decorator(f) + for methods in it.permutations(['defvjp', 'def_vmap', 'def_transpose']): + for method in methods: + self.assertIsInstance(getattr(f, method), Callable) + + +if __name__ == '__main__': + absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/custom_partitioning_sharding_rule_test.py b/tests/custom_partitioning_sharding_rule_test.py index f22721910408..d7e93ddec5b2 100644 --- a/tests/custom_partitioning_sharding_rule_test.py +++ b/tests/custom_partitioning_sharding_rule_test.py @@ -383,7 +383,7 @@ def test_conversion_compound_then_individual(self): [result.result.type,]) self.assertEqual( str(mlir_rule), - "#sdy.op_sharding_rule<([ij])->([i, j]) {i=2, j=4}>") + "#sdy.op_sharding_rule<([ij])->([i, j]) {i=2, j=4}, custom>") def test_conversion_elementwise_rule_scalar_instance(self): opnd0 = self.create_tensor_value(()) @@ -399,7 +399,7 @@ def test_conversion_elementwise_rule_scalar_instance(self): [result.result.type,]) self.assertEqual( str(mlir_rule), - "#sdy.op_sharding_rule<([], [])->([])>") + "#sdy.op_sharding_rule<([], [])->([]), custom>") def test_conversion_elementwise_rule_2D_instance(self): opnd0 = self.create_tensor_value((16, 32)) @@ -415,7 +415,7 @@ def test_conversion_elementwise_rule_2D_instance(self): [result.result.type,]) self.assertEqual( str(mlir_rule), - "#sdy.op_sharding_rule<([i, j], [i, j])->([i, j]) {i=16, j=32}>") + "#sdy.op_sharding_rule<([i, j], [i, j])->([i, j]) {i=16, j=32}, custom>") def test_conversion_vector_scalar_add_2D_instance(self): opnd0 = self.create_tensor_value((16, 32)) @@ -431,7 +431,7 @@ def test_conversion_vector_scalar_add_2D_instance(self): [result.result.type,]) self.assertEqual( str(mlir_rule), - "#sdy.op_sharding_rule<([i, j], [])->([i, j]) {i=16, j=32}>") + "#sdy.op_sharding_rule<([i, j], [])->([i, j]) {i=16, j=32}, custom>") def test_conversion_reshape_rule(self): opnd0 = self.create_tensor_value((2, 4)) @@ -446,7 +446,7 @@ def test_conversion_reshape_rule(self): [result.result.type,]) self.assertEqual( str(mlir_rule), - "#sdy.op_sharding_rule<([i, j])->([ij]) {i=2, j=4}>") + "#sdy.op_sharding_rule<([i, j])->([ij]) {i=2, j=4}, custom>") def test_conversion_contracting_dim_matmul(self): opnd0 = self.create_tensor_value((16, 32)) @@ -462,7 +462,7 @@ def test_conversion_contracting_dim_matmul(self): [result.result.type,]) self.assertEqual( str(mlir_rule), - "#sdy.op_sharding_rule<([i, j], [j, k])->([i, k]) {i=16, j=32, k=8}>") + "#sdy.op_sharding_rule<([i, j], [j, k])->([i, k]) {i=16, j=32, k=8}, custom>") def test_conversion_multiple_batching_groups(self): @@ -479,7 +479,7 @@ def test_conversion_multiple_batching_groups(self): [result.result.type,]) self.assertEqual( str(mlir_rule), - "#sdy.op_sharding_rule<([i, j, k, l], [m, n, o, l, k])->([i, j, l, k]) {i=4, j=5, k=16, l=32, m=6, n=7, o=8}>") + "#sdy.op_sharding_rule<([i, j, k, l], [m, n, o, l, k])->([i, j, l, k]) {i=4, j=5, k=16, l=32, m=6, n=7, o=8}, custom>") if __name__ == "__main__": diff --git a/tests/debug_info_test.py b/tests/debug_info_test.py index a39b53c3ad16..41207f903b2a 100644 --- a/tests/debug_info_test.py +++ b/tests/debug_info_test.py @@ -30,7 +30,7 @@ from jax.experimental import checkify import jax.experimental.custom_dce from jax.experimental import pallas as pl -from jax.experimental.shard_map import shard_map +from jax._src.shard_map import shard_map import jax.numpy as jnp import jax.scipy as jsp @@ -46,6 +46,7 @@ from jax._src.compilation_cache import is_persistent_cache_enabled from 
jax._src.lax.control_flow import for_loop from jax._src.interpreters import mlir +from jax._src import util as util import numpy as np @@ -130,7 +131,7 @@ def _check_tracers_and_jaxprs(self, traceable: Any, mode. The debug infos in the nested Jaxprs are first converted to strings using `_debug_info_to_string` and then compared against `expected_jaxpr_debug_infos`. During this conversion, - we strip occurences of this test file name and a line number + we strip occurrences of this test file name and a line number (e.g., .*/debug_info_test.py:56) An element of `expected_jaxpr_debug_infos` can be a string, in which case it is compared by equality, or a `re.Pattern` (the result of `re.compile`) @@ -241,7 +242,7 @@ def my_f(x, y, z, w): dbg = api_util.debug_info("jit", my_f, (1, 2), dict(z=3, w=4)) self.assertRegex(dbg.func_src_info, r"^my_f at .*debug_info_test.py:\d+") self.assertEqual(dbg.func_name, "my_f") - self.assertEqual(dbg.arg_names, ("x", "y", "z", "w")) + self.assertEqual(dbg.arg_names, ("x", "y", "w", "z")) self.assertIsNone(dbg.result_paths) def test_debug_info_arg_passed_as_kwarg(self): @@ -261,23 +262,29 @@ def my_f(x_tree, *, y_tree): "y_tree['w']", "y_tree['z']")) def test_debug_info_with_statics(self): - def my_f(x, y, *, z, w): + def my_f(x, z, *, w, y): pass - dbg = api_util.debug_info("jit", my_f, (1, 2), dict(z=3, w=4), + dbg = api_util.debug_info("jit", my_f, (1,), dict(y=2, z=3, w=4), static_argnums=(1,), static_argnames=("w",)) - self.assertEqual(dbg.arg_names, ("x", "z")) + self.assertEqual(dbg.arg_names, ("x", "y", "z")) def test_debug_info_with_pytrees_and_statics(self): - def my_f(x, y, *, z, w): + def my_f(x, y, *, z, w, t): pass dbg = api_util.debug_info("jit", my_f, ((1, 2), (2, 3)), - dict(z=(3, 4), w=(5, 6)), + dict(z=(3, 4), w=(5, 6), t=7), + static_argnums=(1,), + static_argnames=("w",)) + self.assertEqual(dbg.arg_names, ("x[0]", "x[1]", "t", "z[0]", "z[1]")) + + dbg = api_util.debug_info("jit", my_f, ((1, 2),), + dict(z=(3, 4), w=(5, 6), t=7, y=3), static_argnums=(1,), static_argnames=("w",)) - self.assertEqual(dbg.arg_names, ("x[0]", "x[1]", "z[0]", "z[1]")) + self.assertEqual(dbg.arg_names, ("x[0]", "x[1]", "t", "y", "z[0]", "z[1]")) def test_debug_info_too_many_args(self): def my_f(x): @@ -287,15 +294,20 @@ def my_f(x): self.assertEqual(dbg.arg_names, ('args[0]', 'args[1]', 'args[2]', "kwargs['z']")) def test_debug_info_no_source_info_built_in(self): - # built-in function "int" does not have an inspect.Signature + # built-in function "max" does not have an inspect.Signature dbg = api_util.debug_info("jit", max, (1,), {}) self.assertEqual(dbg.func_src_info, "max") + self.assertEqual(dbg.func_name, "max") + self.assertEqual(dbg.func_filename, None) + self.assertEqual(dbg.func_lineno, None) self.assertEqual(dbg.arg_names, ("args[0]",)) def test_debug_info_lambda(self): # built-in function "int" does not have an inspect.Signature dbg = api_util.debug_info("jit", lambda my_arg: False, (1,), {}) self.assertRegex(dbg.func_src_info, r"^ at .*debug_info_test.py:\d+") + self.assertEndsWith(dbg.func_filename, "debug_info_test.py") + self.assertIsNotNone(dbg.func_lineno) self.assertEqual(dbg.arg_names, ("my_arg",)) def test_debug_info_save_wrapped_fun_source_info(self): @@ -380,66 +392,6 @@ def f(x): with self.assertRaisesRegex(TypeError, err_str): jax.jit(f)(jnp.int32) - @jtu.thread_unsafe_test() # logging is not thread-safe - def test_arg_names_cache_miss_explanations(self): - @jax.jit - def f(x, y): - return jnp.sin(x) * y['hi'] - - x = jnp.float32(1.) 
- y = {'hi': jnp.arange(3., dtype='float32')} - - expected_log_len = 1 if not is_persistent_cache_enabled() else 3 - - # print on first miss, not on hit - with config.explain_cache_misses(True): - with self.assertLogs(level='WARNING') as cm: - f(x, y) - f(x, y) - self.assertLen(cm.output, expected_log_len) - msg = cm.output[0] - self.assertIn('TRACING CACHE MISS', msg) - self.assertIn('never seen function', msg) - - # shape change - y_ = {'hi': jnp.arange(4, dtype='float32')} - with config.explain_cache_misses(True): - with self.assertLogs(level='WARNING') as cm: - f(x, y_) - self.assertLen(cm.output, expected_log_len) - msg = cm.output[0] - self.assertIn('never seen input type signature', msg) - self.assertIn('closest seen input type signature has 1 mismatches', msg) - self.assertIn('seen f32[3], but now given f32[4]', msg) - - # weak type change (assuming no x64) - if not config.enable_x64.value: - with config.explain_cache_misses(True): - with self.assertLogs(level='WARNING') as cm: - f(1., y) - self.assertLen(cm.output, expected_log_len) - msg = cm.output[0] - self.assertIn('weak_type=True', msg) - self.assertIn('https://jax.readthedocs.io/en/latest/type_promotion.html#weak-types', msg) - - # kwarg change - with config.explain_cache_misses(True): - with self.assertLogs(level='WARNING') as cm: - f(1, y=y) - self.assertLen(cm.output, expected_log_len) - msg = cm.output[0] - self.assertIn('never seen passing 1 positional args and 1 keyword args', msg) - - # tracing config change - with config.explain_cache_misses(True): - with self.assertLogs(level='WARNING') as cm: - with jax.numpy_rank_promotion('warn'): - f(x, y) - # depending on the backend, we may or may not get persistent cache warnings - self.assertTrue(1 <= len(cm.output) <= expected_log_len) - msg = cm.output[0] - self.assertIn("tracing context doesn't match", msg) - @jtu.thread_unsafe_test() # logging is not thread-safe def test_arg_names_cache_miss_explanations_new_function_in_loop(self): @jax.jit @@ -671,7 +623,7 @@ def my_g(b, d=1): tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[ # TODO(necula): result_paths? 
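# A standalone sketch of the cache-miss explanations covered by the test removed
# above, using the public config knob (assuming the `jax_explain_cache_misses`
# option name): when enabled, JAX logs a "TRACING CACHE MISS" warning explaining
# why a jitted function had to be retraced.
import jax
import jax.numpy as jnp

jax.config.update("jax_explain_cache_misses", True)

@jax.jit
def f(x):
  return jnp.sin(x)

f(jnp.arange(3.0, dtype=jnp.float32))  # first call: "never seen function"
f(jnp.arange(4.0, dtype=jnp.float32))  # shape change triggers an explained miss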
- "traced_for=jit, fun=my_f, arg_names=a, result_paths=", + "traced_for=jit, fun=my_f, arg_names=a, result_paths=result", "traced_for=jit, fun=my_g, arg_names=b, result_paths=result", ], expected_tracer_debug_infos=[ @@ -761,6 +713,122 @@ def f(x, y, *args, **kwargs): re.compile(r".*func.func public @main\(.*\{jax.result_info = \"result\"\}"), ]) + def test_jit_arg_names_with_out_of_order_kwargs(self): + tracer_spy = TracerSpy() + + # The shapes are different, to differentiate them easily + a1 = (np.float32(0),) # a hashable tuple, can be static + b2 = np.arange(2, dtype=np.float32) # b2 + z3 = np.arange(3, dtype=np.float32) + y4 = (np.float32(0.), np.float32(1.), np.float32(2.), np.float32(3.)) + x5 = np.arange(5, dtype=np.float32) + u6 = np.arange(6, dtype=np.float32) + t7 = np.arange(7, dtype=np.float32) + + def my_f(a1, b2, z3, y4, x5, *, u6, t7): + assert np.shape(a1[0]) == () + assert np.shape(b2) == (2,) + assert np.shape(z3) == (3,) + assert np.shape(y4) == (4,) + assert np.shape(x5) == (5,) + assert np.shape(u6) == (6,) + assert np.shape(t7) == (7,) + tracer_spy.append(b2) + tracer_spy.append(x5) + return a1[0] + b2[0] + z3[0] + y4[0] + x5[0] + u6[0] + t7[0] + + self._check_tracers_and_jaxprs( + jax.jit(my_f, static_argnums=(0,), static_argnames=("y4",)), + # Some positional args passed as keyword + a1, b2, x5=x5, y4=y4, z3=z3, t7=t7, u6=u6, + expected_jaxpr_debug_infos=[ + "traced_for=jit, fun=my_f, arg_names=b2,t7,u6,x5,z3, result_paths=result", + ], + tracer_spy=tracer_spy, + expected_tracer_debug_infos=[ + "traced_for=jit, fun=my_f, arg_names=b2,t7,u6,x5,z3, from b2", + "traced_for=jit, fun=my_f, arg_names=b2,t7,u6,x5,z3, from x5", + ], + expected_lowering_lines=[ + re.compile(r".*func.func public @main\(%arg0: tensor<2xf..> loc\(\"b2\"\)"), + re.compile(r".*func.func public @main\(.*%arg1: tensor<7xf..> loc\(\"t7\"\)"), + re.compile(r".*func.func public @main\(.*%arg2: tensor<6xf..> loc\(\"u6\"\)"), + re.compile(r".*func.func public @main\(.*%arg3: tensor<5xf..> loc\(\"x5\"\)"), + ] + ) + + tracer_spy.tracers = [] + util.clear_all_caches() + self._check_tracers_and_jaxprs( + jax.jit(my_f, static_argnames=("y4",)), + # Positional argument y4 is static and passed by kwarg + a1, b2, z3, x5=x5, y4=y4, t7=t7, u6=u6, + expected_jaxpr_debug_infos=[ + "traced_for=jit, fun=my_f, arg_names=a1[0],b2,z3,t7,u6,x5, result_paths=result", + ], + tracer_spy=tracer_spy, + expected_tracer_debug_infos=[ + "traced_for=jit, fun=my_f, arg_names=a1[0],b2,z3,t7,u6,x5, from b2", + "traced_for=jit, fun=my_f, arg_names=a1[0],b2,z3,t7,u6,x5, from x5", + ], + expected_lowering_lines=[ + re.compile(r".*func.func public @main\(%arg0: tensor loc\(\"a1\[0\]\"\)"), + re.compile(r".*func.func public @main\(.*%arg1: tensor<2xf..> loc\(\"b2\"\)"), + re.compile(r".*func.func public @main\(.*%arg2: tensor<3xf..> loc\(\"z3\"\)"), + re.compile(r".*func.func public @main\(.*%arg3: tensor<7xf..> loc\(\"t7\"\)"), + re.compile(r".*func.func public @main\(.*%arg4: tensor<6xf..> loc\(\"u6\"\)"), + re.compile(r".*func.func public @main\(.*%arg5: tensor<5xf..> loc\(\"x5\"\)"), + ] + ) + + tracer_spy.tracers = [] + util.clear_all_caches() + self._check_tracers_and_jaxprs( + jax.jit(my_f, static_argnames=("y4",)), + # Positional argument y4 is static (declared as static_argnames) + a1, b2, z3, y4, x5=x5, t7=t7, u6=u6, + expected_jaxpr_debug_infos=[ + "traced_for=jit, fun=my_f, arg_names=a1[0],b2,z3,t7,u6,x5, result_paths=result", + ], + tracer_spy=tracer_spy, + expected_tracer_debug_infos=[ + "traced_for=jit, fun=my_f, 
arg_names=a1[0],b2,z3,t7,u6,x5, from b2", + "traced_for=jit, fun=my_f, arg_names=a1[0],b2,z3,t7,u6,x5, from x5", + ], + expected_lowering_lines=[ + re.compile(r".*func.func public @main\(%arg0: tensor loc\(\"a1\[0\]\"\)"), + re.compile(r".*func.func public @main\(.*%arg1: tensor<2xf..> loc\(\"b2\"\)"), + re.compile(r".*func.func public @main\(.*%arg2: tensor<3xf..> loc\(\"z3\"\)"), + re.compile(r".*func.func public @main\(.*%arg3: tensor<7xf..> loc\(\"t7\"\)"), + re.compile(r".*func.func public @main\(.*%arg4: tensor<6xf..> loc\(\"u6\"\)"), + re.compile(r".*func.func public @main\(.*%arg5: tensor<5xf..> loc\(\"x5\"\)"), + ] + ) + + tracer_spy.tracers = [] + util.clear_all_caches() + self._check_tracers_and_jaxprs( + jax.jit(my_f, static_argnums=(3,)), + # Positional argument y4 is static (declared as static_argnums) + a1, b2, z3, y4, x5=x5, t7=t7, u6=u6, + expected_jaxpr_debug_infos=[ + "traced_for=jit, fun=my_f, arg_names=a1[0],b2,z3,t7,u6,x5, result_paths=result", + ], + tracer_spy=tracer_spy, + expected_tracer_debug_infos=[ + "traced_for=jit, fun=my_f, arg_names=a1[0],b2,z3,t7,u6,x5, from b2", + "traced_for=jit, fun=my_f, arg_names=a1[0],b2,z3,t7,u6,x5, from x5", + ], + expected_lowering_lines=[ + re.compile(r".*func.func public @main\(%arg0: tensor loc\(\"a1\[0\]\"\)"), + re.compile(r".*func.func public @main\(.*%arg1: tensor<2xf..> loc\(\"b2\"\)"), + re.compile(r".*func.func public @main\(.*%arg2: tensor<3xf..> loc\(\"z3\"\)"), + re.compile(r".*func.func public @main\(.*%arg3: tensor<7xf..> loc\(\"t7\"\)"), + re.compile(r".*func.func public @main\(.*%arg4: tensor<6xf..> loc\(\"u6\"\)"), + re.compile(r".*func.func public @main\(.*%arg5: tensor<5xf..> loc\(\"x5\"\)"), + ] + ) + def test_jit_result_info(self): def f(x, y, z): return {'a': x, 'b': [y]} @@ -794,7 +862,7 @@ def my_g(u, v): tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[ "traced_for=jit, fun=my_f, arg_names=x,y, result_paths=result", - "traced_for=jit, fun=my_g, arg_names=u,v, result_paths=result['c']" + "traced_for=jit, fun=my_g, arg_names=u,v, result_paths=result['c'],result['d']", ], expected_tracer_debug_infos=[ "traced_for=jit, fun=my_f, arg_names=x,y, from x", @@ -886,20 +954,27 @@ def my_g(u, v): return dict(c=u * v, d=v) return jax.jit(my_g)(y, x)["c"] + if config.use_direct_linearize.value: + expected_jaxpr_debug_infos = [ + "traced_for=jit, fun=, arg_names=x,y,res_ct, result_paths=result[0],result[1]", + # TODO(necula): result_paths + "traced_for=jit, fun=my_g, arg_names=u,v, result_paths=", + # TODO(necula): arg_names + "traced_for=jit, fun=my_g, arg_names=u,v,,, result_paths=result['c']", + ] + else: + expected_jaxpr_debug_infos = [ + "traced_for=jit, fun=, arg_names=x,y,res_ct, result_paths=result[0],result[1]", + "traced_for=jit, fun=my_g, arg_names=u,v, result_paths=result['c']", + # TODO(necula): arg_names + "traced_for=jit, fun=my_g, arg_names=,,u,v, result_paths=result['c'],result['d']", + ] self._check_tracers_and_jaxprs( jax.jit(lambda x, y, res_ct: jax.vjp(my_f, x, y)[1](res_ct)), 2., 3., 0.3, tracer_spy=tracer_spy, - expected_jaxpr_debug_infos=[ - "traced_for=jit, fun=, arg_names=x,y,res_ct, result_paths=result[0],result[1]", - # TODO(necula): result_paths - "traced_for=jit, fun=my_g, arg_names=u,v, result_paths=", - # TODO(necula): arg_names - "traced_for=jit, fun=my_g, arg_names=u,v,,, result_paths=," - if config.use_direct_linearize.value else - "traced_for=jit, fun=my_g, arg_names=,,u,v, result_paths=result['c'],result['d']", - ], + expected_jaxpr_debug_infos=expected_jaxpr_debug_infos, 
expected_tracer_debug_infos=[ # TODO(necula): missing debug info "None", @@ -1145,7 +1220,6 @@ def fn_tp(r, t): tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[ "traced_for=jit, fun=, arg_names=x, result_paths=result[0]['c']", - "traced_for=linear_call fun, fun=fn, arg_names=r,x['c'], result_paths=result['b']", "traced_for=linear_call fun_transpose, fun=fn_tp, arg_names=r,t['c'], result_paths=result['c']", ], expected_tracer_debug_infos=[ @@ -1312,24 +1386,37 @@ def the_grad(c, as_): _, pullback = jax.vjp(my_f, c, as_) return pullback((c, np.arange(3, dtype=c.dtype))) + if config.use_direct_linearize.value: + expected_jaxpr_debug_infos = [ + "traced_for=jit, fun=the_grad, arg_names=c,as_, result_paths=result[0],result[1]", + "traced_for=jit, fun=my_f, arg_names=x,as_, result_paths=,,", + "traced_for=for_loop, fun=f, arg_names=,,, result_paths=,", + "traced_for=for_loop, fun=f, arg_names=i,refs[0],refs[1],refs[2], result_paths=", + "traced_for=jit, fun=my_f, arg_names=as_,,, result_paths=result[0],result[1]", + "traced_for=checkpoint / remat, fun=to_remat, arg_names=,,, result_paths=,", + "traced_for=for_loop, fun=f, arg_names=,,,,,, result_paths=,", + "traced_for=for_loop, fun=f, arg_names=i,refs[0],refs[1],refs[2], result_paths=", + "traced_for=for_loop, fun=f, arg_names=,,,,,,,,,,,,,,, result_paths=,", + "traced_for=for_loop, fun=f, arg_names=,,,,,,,,,,, result_paths=", + ] + else: + expected_jaxpr_debug_infos = [ + "traced_for=jit, fun=the_grad, arg_names=c,as_, result_paths=result[0],result[1]", + "traced_for=jit, fun=my_f, arg_names=x,as_, result_paths=result[0],result[1]", + "traced_for=for_loop, fun=f, arg_names=,,, result_paths=,", + "traced_for=for_loop, fun=f, arg_names=i,refs[0],refs[1],refs[2], result_paths=", + "traced_for=jit, fun=my_f, arg_names=,,x,as_, result_paths=result[0],result[1]", + "traced_for=checkpoint / remat, fun=to_remat, arg_names=,,, result_paths=,", + "traced_for=for_loop, fun=f, arg_names=,,,,,, result_paths=,", + "traced_for=for_loop, fun=f, arg_names=i,refs[0],refs[1],refs[2], result_paths=", + "traced_for=for_loop, fun=f, arg_names=,,,,,,,,,,,,,,, result_paths=,", + "traced_for=for_loop, fun=f, arg_names=,,,,,,,,,,, result_paths=", + ] self._check_tracers_and_jaxprs( jax.jit(the_grad), c, as_, tracer_spy=tracer_spy, - expected_jaxpr_debug_infos=[ - "traced_for=jit, fun=the_grad, arg_names=c,as_, result_paths=result[0],result[1]", - # TODO(necula): arg names, bad result paths - "traced_for=jit, fun=my_f, arg_names=x,as_, result_paths=,,", - "traced_for=for_loop, fun=f, arg_names=i,refs[0],refs[1],refs[2], result_paths=", - "traced_for=for_loop, fun=f, arg_names=,,, result_paths=,", - "traced_for=for_loop, fun=f, arg_names=,,,,,, result_paths=,", - "traced_for=for_loop, fun=f, arg_names=,,,,,,,,,,, result_paths=", - "traced_for=for_loop, fun=f, arg_names=,,,,,,,,,,,,,,, result_paths=,", - "traced_for=checkpoint / remat, fun=to_remat, arg_names=,,, result_paths=,", - "traced_for=jit, fun=my_f, arg_names=as_,,, result_paths=" - if config.use_direct_linearize.value else - "traced_for=jit, fun=my_f, arg_names=,,x,as_, result_paths=", - ], + expected_jaxpr_debug_infos=expected_jaxpr_debug_infos, expected_tracer_debug_infos=[ "traced_for=jit, fun=the_grad, arg_names=c,as_, from c", "traced_for=scan, fun=f, arg_names=c,a, from c", @@ -1467,7 +1554,7 @@ def my_g(u, v): tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[ "traced_for=jit, fun=my_f, arg_names=x,y, result_paths=result", - "traced_for=jit, fun=my_g, arg_names=u,v, result_paths=result['c']", + 
"traced_for=jit, fun=my_g, arg_names=u,v, result_paths=result['c'],result['d']", ], expected_tracer_debug_infos=[ # TODO(necula): missing debug info @@ -1495,34 +1582,50 @@ def my_f(x): def test_pmap_with_arg_and_result_names(self): tracer_spy = TracerSpy() - x = np.ones((jax.device_count(),), dtype=np.float32) - def my_f(x, y, *args, a, **kwargs): - # y and kwargs[c] is dead + + # Use different shapes arguments to distinguish them in the HLO + def my_f(x0, y1, *args, b4, **kwargs): + assert np.shape(x0) == () + assert np.shape(y1) == (1,) + assert np.shape(args[0]) == (2,) + assert np.shape(args[1]) == (3,) + assert np.shape(b4) == (4,) + assert np.shape(kwargs["a5"]) == (5,) + assert np.shape(kwargs["c6"]) == (6,) + # kwargs[b5] is dead tracer_spy.append(args[1]) - s = x + a + args[1] + kwargs["d"] - return dict(u=s, v=x) + tracer_spy.append(b4) + tracer_spy.append(kwargs["c6"]) + s0 = x0 + y1[0] + b4[0] + args[1][0] + kwargs["c6"][0] + return dict(v1=jnp.broadcast_to(s0, (1,)), u0=s0) self._check_tracers_and_jaxprs( jax.pmap(my_f, static_broadcasted_argnums=(0,)), - 1., x, x, x, # x, y, args[0], args[1] - d=x, a=x, b=x, # kwargs + 1., # x0 + np.ones((jax.device_count(), 1), dtype=np.float32), # y1 + np.ones((jax.device_count(), 2), dtype=np.float32), # args[0] + np.ones((jax.device_count(), 3), dtype=np.float32), # args[1] + b4=np.ones((jax.device_count(), 4), dtype=np.float32), + a5=np.ones((jax.device_count(), 5), dtype=np.float32), + c6=np.ones((jax.device_count(), 6), dtype=np.float32), expected_jaxpr_debug_infos=[ - "traced_for=pmap, fun=my_f, arg_names=y,args[0],args[1],a,kwargs['b'],kwargs['d'], result_paths=result['u'],result['v']", + "traced_for=pmap, fun=my_f, arg_names=y1,args[0],args[1],kwargs['a5'],b4,kwargs['c6'], result_paths=result['u0'],result['v1']", ], tracer_spy=tracer_spy, expected_tracer_debug_infos=[ - "traced_for=pmap, fun=my_f, arg_names=y,args[0],args[1],a,kwargs['b'],kwargs['d'], from args[1]", + "traced_for=pmap, fun=my_f, arg_names=y1,args[0],args[1],kwargs['a5'],b4,kwargs['c6'], from args[1]", + "traced_for=pmap, fun=my_f, arg_names=y1,args[0],args[1],kwargs['a5'],b4,kwargs['c6'], from b4", + "traced_for=pmap, fun=my_f, arg_names=y1,args[0],args[1],kwargs['a5'],b4,kwargs['c6'], from kwargs['c6']", ], expected_lowering_lines=[ - # TODO(necula): we did not DCE y? 
- re.compile(r".*func.func public @main\(.*%arg0: tensor<1xf..> loc\(\"y\"\)"), - re.compile(r".*func.func public @main\(.*%arg1: tensor<1xf..> loc\(\"args\[0\]\"\)"), - re.compile(r".*func.func public @main\(.*%arg2: tensor<1xf..> loc\(\"args\[1\]\"\)"), - re.compile(r".*func.func public @main\(.*%arg3: tensor<1xf..> loc\(\"a\"\)"), - re.compile(r".*func.func public @main\(.*%arg4: tensor<1xf..> loc\(\"kwargs\['b'\]\"\)"), - re.compile(r".*func.func public @main\(.*%arg5: tensor<1xf..> loc\(\"kwargs\['d'\]\"\)"), - re.compile(r".*func.func public @main\(.* -> .*\{jax.result_info = \"result\['u'\]\"\}"), - re.compile(r".*func.func public @main\(.* -> .*\{jax.result_info = \"result\['v'\]\"\}"), + re.compile(r".*func.func public @main\(.*%arg0: tensor<1x1xf..> loc\(\"y1\"\)"), + re.compile(r".*func.func public @main\(.*%arg1: tensor<1x2xf..> loc\(\"args\[0\]\"\)"), + re.compile(r".*func.func public @main\(.*%arg2: tensor<1x3xf..> loc\(\"args\[1\]\"\)"), + re.compile(r".*func.func public @main\(.*%arg3: tensor<1x5xf..> loc\(\"kwargs\['a5'\]\"\)"), + re.compile(r".*func.func public @main\(.*%arg4: tensor<1x4xf..> loc\(\"b4\"\)"), + re.compile(r".*func.func public @main\(.*%arg5: tensor<1x6xf..> loc\(\"kwargs\['c6'\]\"\)"), + re.compile(r".*func.func public @main\(.* -> .*\{jax.result_info = \"result\['u0'\]\"\}"), + re.compile(r".*func.func public @main\(.* -> .*\{jax.result_info = \"result\['v1'\]\"\}"), ] ) @@ -1606,17 +1709,23 @@ def my_f(x): x = jax.random.uniform(jax.random.key(0), shape=(8, 4)) + if config.use_direct_linearize.value: + expected_jaxpr_debug_infos = [ + "traced_for=jit, fun=my_f, arg_names=x, result_paths=result", + "traced_for=jit, fun=my_f, arg_names=x, result_paths=,", + "traced_for=jit, fun=my_f, arg_names=x,, result_paths=result" + ] + else: + expected_jaxpr_debug_infos = [ + "traced_for=jit, fun=my_f, arg_names=x, result_paths=result", + "traced_for=jit, fun=my_f, arg_names=x, result_paths=,", + "traced_for=jit, fun=my_f, arg_names=,x, result_paths=result" + ] + self._check_tracers_and_jaxprs( jax.jit(jax.hessian(jax.jit(my_f))), x, - expected_jaxpr_debug_infos=[ - "traced_for=jit, fun=my_f, arg_names=x, result_paths=result", - # TODO(necula): arg_names and result_paths? 
- "traced_for=jit, fun=my_f, arg_names=x, result_paths=,,,", - "traced_for=jit, fun=my_f, arg_names=x,, result_paths=," - if config.use_direct_linearize.value else - "traced_for=jit, fun=my_f, arg_names=,x, result_paths=,", - ], + expected_jaxpr_debug_infos=expected_jaxpr_debug_infos, tracer_spy=tracer_spy, expected_tracer_debug_infos=[ "traced_for=jit, fun=my_f, arg_names=x, from x", @@ -1697,7 +1806,7 @@ def my_f(x): "traced_for=shard_map, fun=my_f, arg_names=,, result_paths=", ], expected_tracer_debug_infos=[ - "None" # TODO(necula): missing + "traced_for=shard_map, fun=my_f, arg_names=x, from x" ]) def test_remat_saved_residuals(self): diff --git a/tests/debug_nans_test.py b/tests/debug_nans_test.py index c80d23c416df..29fde318756c 100644 --- a/tests/debug_nans_test.py +++ b/tests/debug_nans_test.py @@ -22,7 +22,7 @@ from jax._src import test_util as jtu from jax import numpy as jnp from jax.experimental import pjit -from jax.experimental.shard_map import shard_map +from jax._src.shard_map import shard_map from jax.sharding import PartitionSpec as P jax.config.parse_flags_with_absl() diff --git a/tests/debugger_test.py b/tests/debugger_test.py index 419e7b18dfed..0d66cd47d8cc 100644 --- a/tests/debugger_test.py +++ b/tests/debugger_test.py @@ -43,6 +43,10 @@ def _format_multiline(text): foo = 2 +# This test is thread-unsafe because jax.effects_barrier() is global. This means +# that we can create a deadlock if running tests in multiple threads because we +# can introduce false dependencies via the effects barrier. +@jtu.thread_unsafe_test_class() class CliDebuggerTest(jtu.JaxTestCase): def setUp(self): diff --git a/tests/debugging_primitives_test.py b/tests/debugging_primitives_test.py index a8d59bc39e36..9c23f136b825 100644 --- a/tests/debugging_primitives_test.py +++ b/tests/debugging_primitives_test.py @@ -16,7 +16,7 @@ import textwrap import unittest -from absl.testing import absltest +from absl.testing import absltest, parameterized import jax from jax import lax from jax.experimental import pjit @@ -25,6 +25,7 @@ from jax._src import debugging from jax._src import dispatch from jax._src import test_util as jtu +from jax.sharding import PartitionSpec as P import jax.numpy as jnp import numpy as np @@ -274,6 +275,28 @@ def f(x): jax.effects_barrier() self.assertEqual(output(), "[1.23 2.35 0. ]\n") + @parameterized.parameters([False, True]) + def test_debug_print_in_unrolled_loop(self, use_jit): + def body(i, _): + jax.debug.print("{}", i) + if use_jit: + body = jax.jit(body) + @jax.jit + def f(): + return jax.lax.fori_loop(0, 4, body, None, unroll=2) + with jtu.capture_stdout() as output: + f() + jax.effects_barrier() + actual = tuple(sorted(map(int, output().splitlines()))) + self.assertEqual(actual, tuple(range(4))) + + def test_debug_print_extended_dtype(self): + def f(k): + jax.debug.print("{}", k) + with jtu.capture_stdout(): + f(jax.random.key(0)) # doesn't crash + jax.effects_barrier() + @jtu.thread_unsafe_test_class() # printing isn't thread-safe class DebugPrintTransformationTest(jtu.JaxTestCase): @@ -419,8 +442,6 @@ def f(x): with jtu.capture_stdout() as output: jax.linear_transpose(f, 1.)(1.) jax.effects_barrier() - # `debug_print` should be dropped by `partial_eval` because of no - # output data-dependence. 
self.assertEqual(output(), "") @jtu.sample_product(ordered=[False, True]) @@ -1120,6 +1141,28 @@ def test_visualize_pmap_sharding(self): """) self.assertEqual(output(), expected) + def test_visualize_sharding_shard_map(self): + mesh = jtu.create_mesh((2,), 'x') + + def f(): + a = jnp.zeros(1000) + debugging.visualize_array_sharding(a) + return a + + with jtu.capture_stdout() as output: + f() # doesn't crash + + with jtu.capture_stdout() as output: + jax.jit(f, out_shardings=jax.NamedSharding(mesh, P('x')))() # doesn't crash + + with jtu.capture_stdout() as output: + jax.shard_map(f, mesh=mesh, in_specs=P(None), out_specs=P("x"))() # doesn't crash + + with jtu.capture_stdout() as output: + jax.shard_map(f, mesh=mesh, in_specs=P(None), out_specs=P("x"), + check_vma=False)() # doesn't crash + + class InspectShardingTest(jtu.JaxTestCase): def test_inspect_sharding_is_called_in_pjit(self): @@ -1203,6 +1246,164 @@ def f_(x): f(arr) +def _get_output_set(output, num_lines): + """Return a set of strings where each string is num_lines.""" + output = output().strip().split("\n") + return { + "\n".join(output[i : i + num_lines]) + for i in range(0, len(output), num_lines) + } + + +@jtu.thread_unsafe_test_class() # printing isn't thread-safe +class PartitionedDebugCallbackTest(jtu.JaxTestCase): + + def setUp(self): + super().setUp() + if (jtu.device_under_test() not in ("cpu", "gpu")): + raise unittest.SkipTest( + f"Test requires CPU or GPU devices. Got {jtu.device_under_test()}" + ) + if len(jax.devices()) < 2: + raise unittest.SkipTest("Test requires >= 2 devices.") + + def tearDown(self): + super().tearDown() + dispatch.runtime_tokens.clear() + + def test_partitioned_debug_callback(self): + def f_(x): + debug_print("hello: {x}", x=x, partitioned=True) + + f = pjit.pjit(f_) + mesh = jtu.create_mesh((1, 1, 2,), ("x", "y", "z")) + s = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec("x", "y", "z")) + arr = jax.device_put(np.arange(24).reshape(2, 3, 4), s) + + with jtu.capture_stdout() as output: + with mesh: + f(arr) + jax.effects_barrier() + + expected = { + _format_multiline(""" + hello: [[[ 0 1] + [ 4 5] + [ 8 9]] + + [[12 13] + [16 17] + [20 21]]]"""), + _format_multiline(""" + hello: [[[ 2 3] + [ 6 7] + [10 11]] + + [[14 15] + [18 19] + [22 23]]]"""), + } + self.assertEqual(_get_output_set(output, 7), expected) + + def test_debug_print_batching(self): + @jax.vmap + def f_(x): + debug_print("hello: {}", x, partitioned=True) + + f = pjit.pjit(f_) + mesh = jtu.create_mesh((1, 1, 2), ("x", "y", "z")) + s = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec("x", "y", "z")) + arr = np.arange(24).reshape(2, 3, 4) + arr = jax.device_put(arr, s) + + with jtu.capture_stdout() as output: + with mesh: + f(arr) + jax.effects_barrier() + + expected = { + _format_multiline(""" + hello: [[0 1] + [4 5] + [8 9]]"""), + _format_multiline(""" + hello: [[ 2 3] + [ 6 7] + [10 11]]"""), + _format_multiline(""" + hello: [[14 15] + [18 19] + [22 23]]"""), + _format_multiline(""" + hello: [[12 13] + [16 17] + [20 21]]"""), + } + + self.assertEqual(_get_output_set(output, 3), expected) + + def test_debug_print_batching_with_diff_axes(self): + @functools.partial(jax.vmap, in_axes=(0, 1)) + def f_(x, y): + debug_print("hello: {} {}", x, y, partitioned=True) + + f = pjit.pjit(f_) + mesh = jtu.create_mesh((2,), ("x")) + s = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec("x")) + x = np.arange(4).reshape(2, 2) + x = jax.device_put(x, s) + y = np.arange(4).reshape(2, 2) + 6 + y = jax.device_put(y, 
s) + + with jtu.capture_stdout() as output: + with mesh: + f(x, y) + jax.effects_barrier() + + expected = { + "hello: [2 3] [9]", + "hello: [0 1] [6]", + "hello: [0 1] [8]", + "hello: [2 3] [7]", + } + + self.assertEqual(_get_output_set(output, 1), expected) + + def test_debug_print_with_nested_vmap(self): + @jax.vmap + @jax.vmap + def f_(x): + debug_print("hello: {}", x, partitioned=True) + + f = pjit.pjit(f_) + mesh = jtu.create_mesh((1, 1, 2), ("x", "y", "z")) + s = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec("x", "y", "z")) + arr = np.arange(24).reshape(2, 3, 4) + arr = jax.device_put(arr, s) + + with jtu.capture_stdout() as output: + with mesh: + f(arr) + jax.effects_barrier() + + expected = { + "hello: [14 15]", + "hello: [12 13]", + "hello: [18 19]", + "hello: [16 17]", + "hello: [22 23]", + "hello: [20 21]", + "hello: [2 3]", + "hello: [0 1]", + "hello: [6 7]", + "hello: [10 11]", + "hello: [4 5]", + "hello: [8 9]", + } + + self.assertEqual(_get_output_set(output, 1), expected) + + if not rich: del VisualizeShardingTest diff --git a/tests/distributed_initialize_test.py b/tests/distributed_initialize_test.py new file mode 100644 index 000000000000..33242a41a68e --- /dev/null +++ b/tests/distributed_initialize_test.py @@ -0,0 +1,44 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from absl.testing import absltest +import jax +from jax._src import test_util as jtu + +try: + import portpicker +except ImportError: + portpicker = None + +jax.config.parse_flags_with_absl() + + +@unittest.skipIf(not portpicker, "Test requires portpicker") +class DistributedInitializeTest(jtu.JaxTestCase): + + @jtu.skip_under_pytest( + """Side effects from jax.distributed.initialize conflict with other tests + in the same process. pytest runs multiple tests in the same process.""" + ) + def test_is_distributed_initialized(self): + port = portpicker.pick_unused_port() # type: ignore + self.assertFalse(jax.distributed.is_initialized()) + jax.distributed.initialize(f"localhost:{port}", 1, 0) + self.assertTrue(jax.distributed.is_initialized()) + + +if __name__ == "__main__": + absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/distributed_test.py b/tests/distributed_test.py index 3961932dfad0..ae72143fbe7d 100644 --- a/tests/distributed_test.py +++ b/tests/distributed_test.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import subprocess -import sys import threading import unittest @@ -43,7 +41,10 @@ def testInitializeAndShutdown(self): # concurrency to simulate multiple tasks. 
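# A standalone sketch of the sharding setup the partitioned debug-callback tests
# above rely on, written against the public API only: build a mesh over the local
# devices, place an array with NamedSharding, and optionally visualize the
# per-device layout (jax.debug.visualize_array_sharding requires `rich`).
import jax
import jax.numpy as jnp
from jax.sharding import NamedSharding, PartitionSpec as P

n = len(jax.devices())
mesh = jax.make_mesh((n,), ("x",))
arr = jax.device_put(jnp.zeros((n, 4)), NamedSharding(mesh, P("x")))
print(arr.sharding)
try:
  jax.debug.visualize_array_sharding(arr)
except Exception:
  pass  # `rich` is not installed; the visualization is optional here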
port = portpicker.pick_unused_port() jax.distributed.initialize( - coordinator_address=f"localhost:{port}", num_processes=1, process_id=0 + coordinator_address=f"localhost:{port}", + num_processes=1, + process_id=0, + cluster_detection_method="deactivate", ) jax.distributed.shutdown() @@ -57,7 +58,10 @@ def task(i): # We can't call the public APIs directly because they use global state. state = distributed.State() state.initialize( - coordinator_address=f"localhost:{port}", num_processes=n, process_id=i + coordinator_address=f"localhost:{port}", + num_processes=n, + process_id=i, + cluster_detection_method="deactivate", ) state.shutdown() @@ -67,22 +71,6 @@ def task(i): for thread in threads: thread.join() - def test_is_distributed_initialized(self): - # Run in subprocess to isolate side effects from jax.distributed.initialize which conflict with other - # tests. Unfortunately this can't be avoided by calling jax.distributed.shutdown, as the XLA backend - # will be warmed up, which yields a RuntimeError on subsequent calls to initialize. - port = portpicker.pick_unused_port() # type: ignore - cmd = f"""import jax; - assert not jax.distributed.is_initialized(); - jax.distributed.initialize('localhost:{port}', 1, 0); - assert jax.distributed.is_initialized(); - """.replace("\n", ' ') - - result = subprocess.run([sys.executable, "-c", cmd], capture_output=True) - self.assertEqual( - result.returncode, 0, msg=f"Test failed with:\n{result.stdout}\n{result.stderr}" - ) - if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/dtypes_test.py b/tests/dtypes_test.py index 87380443f4cb..d8fb30397b27 100644 --- a/tests/dtypes_test.py +++ b/tests/dtypes_test.py @@ -46,30 +46,19 @@ np.dtype('uint64')] unsigned_dtypes = list(np_unsigned_dtypes) -intn_dtypes = [np.dtype('int4'), np.dtype('uint4')] -signed_dtypes += [np.dtype('int4')] -unsigned_dtypes += [np.dtype('uint4')] -if dtypes.int2 is not None: - assert dtypes.uint2 is not None - intn_dtypes[:0] = [np.dtype('int2'), np.dtype('uint2')] - signed_dtypes[:0] = [np.dtype('int2')] - unsigned_dtypes[:0] = [np.dtype('uint2')] - -np_float_dtypes = [np.dtype('float16'), np.dtype('float32'), - np.dtype('float64')] +intn_dtypes = [np.dtype('int2'), np.dtype('uint2'), np.dtype('int4'), np.dtype('uint4')] +signed_dtypes += [np.dtype('int2'), np.dtype('int4')] +unsigned_dtypes += [np.dtype('uint2'), np.dtype('uint4')] + +np_float_dtypes = [np.dtype('float16'), np.dtype('float32'), np.dtype('float64')] float_dtypes = [np.dtype(dtypes.bfloat16)] + np_float_dtypes custom_float_dtypes = [np.dtype(dtypes.bfloat16)] fp8_dtypes = [np.dtype(dtypes.float8_e4m3b11fnuz), np.dtype(dtypes.float8_e4m3fn), np.dtype(dtypes.float8_e4m3fnuz), np.dtype(dtypes.float8_e5m2), - np.dtype(dtypes.float8_e5m2fnuz)] -if dtypes.float8_e3m4 is not None: - fp8_dtypes += [np.dtype(dtypes.float8_e3m4)] -if dtypes.float8_e4m3 is not None: - fp8_dtypes += [np.dtype(dtypes.float8_e4m3)] -if dtypes.float8_e8m0fnu is not None: - fp8_dtypes += [np.dtype(dtypes.float8_e8m0fnu)] + np.dtype(dtypes.float8_e5m2fnuz), np.dtype(dtypes.float8_e3m4), + np.dtype(dtypes.float8_e4m3), np.dtype(dtypes.float8_e8m0fnu)] float_dtypes += fp8_dtypes custom_float_dtypes += fp8_dtypes diff --git a/tests/error_check_test.py b/tests/error_check_test.py index b96c6281411f..e20017a39a9b 100644 --- a/tests/error_check_test.py +++ b/tests/error_check_test.py @@ -13,12 +13,16 @@ # limitations under the License. 
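# A standalone sketch of the sub-byte and float8 dtypes that the updated dtype
# lists above now treat as unconditionally available via jax.numpy (backed by
# ml_dtypes); int2/uint2 and the remaining float8 variants follow the same pattern.
import jax.numpy as jnp

x = jnp.arange(8, dtype=jnp.float32).astype(jnp.float8_e5m2)
y = jnp.array([1, -2, 3], dtype=jnp.int4)
print(x.dtype, y.dtype)  # float8_e5m2 int4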
+import traceback + from absl.testing import absltest from absl.testing import parameterized import jax from jax._src import config from jax._src import error_check +from jax._src import mesh as mesh_lib from jax._src import test_util as jtu +import jax.export import jax.numpy as jnp from jax.sharding import NamedSharding, PartitionSpec as P @@ -30,7 +34,9 @@ jtu.request_cpu_devices(4) -@jtu.with_config(jax_check_tracer_leaks=True) +# TODO: AOT tests fails with the tracer leak checker. +# Re-enable once https://github.com/jax-ml/jax/issues/27315 is fixed. +# @jtu.with_config(jax_check_tracer_leaks=True) class ErrorCheckTests(jtu.JaxTestCase): @parameterized.product(jit=[True, False]) @@ -107,6 +113,32 @@ def g(x): with self.assertRaisesRegex(JaxValueError, "x must be greater than 0 in g"): error_check.raise_if_error() + @parameterized.product(jit=[True, False]) + def test_error_includes_traceback(self, jit): + def function_that_triggers_error_for_traceback_test(x): + error_check.set_error_if( # This line must be included in the traceback. + x <= 0, "x must be greater than 0" + ) + return x + 1 + + if jit: + function_that_triggers_error_for_traceback_test = jax.jit( + function_that_triggers_error_for_traceback_test + ) + + x = jnp.zeros((4,), dtype=jnp.int32) + function_that_triggers_error_for_traceback_test(x) + + tb_string = "" + try: + error_check.raise_if_error() + except JaxValueError as e: + tb_string = traceback.format_tb(e.__traceback__) + tb_string = "".join(tb_string) + + self.assertIn("function_that_triggers_error_for_traceback_test", tb_string) + self.assertIn("This line must be included in the traceback", tb_string) + @parameterized.product(jit=[True, False]) def test_error_check_works_with_cond(self, jit): def f(x): @@ -193,7 +225,7 @@ def f(x): jax.jit(error_check.raise_if_error)() @parameterized.product(jit=[True, False]) - @jtu.with_user_mesh((2, 2), ("x", "y")) + @jtu.with_explicit_mesh((2, 2), ("x", "y")) def test_error_check_explicit_mode(self, mesh, jit): def f(x): error_check.set_error_if(x <= 0, "x must be greater than 0") @@ -202,13 +234,144 @@ def f(x): if jit: f = jax.jit(f) - sharding = NamedSharding(mesh, P("x", "y")) - x = jnp.full((4, 4), -1, dtype=jnp.int32, device=sharding) with error_check.error_checking_context(): + x = jnp.full((4, 4), -1, dtype=jnp.int32) + f(x) + with self.assertRaisesRegex(JaxValueError, "x must be greater than 0"): + error_check.raise_if_error() + + sharding = NamedSharding(mesh, P("x", "y")) + with error_check.error_checking_context(): + y = jnp.full((4, 4), -1, dtype=jnp.int32, device=sharding) + f(y) + with self.assertRaisesRegex(JaxValueError, "x must be greater than 0"): + error_check.raise_if_error() + + # The unsharded version of `f` should still be able to check errors after + # exiting the error checking context. + f(x) + with self.assertRaisesRegex(JaxValueError, "x must be greater than 0"): + error_check.raise_if_error() + + @parameterized.product(jit=[True, False]) + @jtu.with_explicit_mesh( + (2, 2), + ("x", "y"), + axis_types=(mesh_lib.AxisType.Auto, mesh_lib.AxisType.Auto), + ) + @jtu.ignore_warning( + message=( + "When at least one mesh axis of `pred` is in auto mode, calling" + " `set_error_if` will cause implicit communication between devices." + " To avoid this, consider converting the mesh axis in auto mode to" + " explicit mode." 
+ ), + category=RuntimeWarning, + ) + def test_error_check_auto_mode(self, jit, mesh): + def f(x): + error_check.set_error_if(x <= 0, "x must be greater than 0") + return x + 1 + + if jit: + f = jax.jit(f) + + with error_check.error_checking_context(): + sharding = NamedSharding(mesh, P("x", "y")) + x = jnp.full((4, 4), -1, dtype=jnp.int32, device=sharding) f(x) with self.assertRaisesRegex(JaxValueError, "x must be greater than 0"): error_check.raise_if_error() + def test_error_check_aot(self): + def run_export(): + def f(x): + error_check.set_error_if(x <= 0, "x must be greater than 0") + return x + 1 + + f = jax.jit(error_check.wrap_for_export(jax.jit(f))) + x = jax.ShapeDtypeStruct((), jnp.float32) + serialized = jax.export.export(f)(x).serialize() + return serialized + + def run_import(serialized): + f = jax.export.deserialize(serialized).call + f = jax.jit(error_check.unwrap_from_import(jax.jit(f))) + x = jnp.float32(-3.) + _ = f(x) + with self.assertRaisesRegex(JaxValueError, "x must be greater than 0"): + error_check.raise_if_error() + + serialized = run_export() + run_import(serialized) + + def test_error_check_aot_includes_traceback(self): + def run_export(): + def function_that_triggers_error_for_traceback_test(x): + error_check.set_error_if( # This line must be included in the traceback + x <= 0, "x must be greater than 0" + ) + return x + 1 + + f = jax.jit( + error_check.wrap_for_export( + jax.jit(function_that_triggers_error_for_traceback_test) + ) + ) + x = jax.ShapeDtypeStruct((), jnp.float32) + serialized = jax.export.export(f)(x).serialize() + return serialized + + def run_import(serialized): + f = jax.export.deserialize(serialized).call + f = jax.jit(error_check.unwrap_from_import(jax.jit(f))) + x = jnp.float32(-3.0) + _ = f(x) + + msg = "" + try: + error_check.raise_if_error() + except JaxValueError as e: + msg = str(e) + + self.assertIn("function_that_triggers_error_for_traceback_test", msg) + self.assertIn("This line must be included in the traceback", msg) + + serialized = run_export() + run_import(serialized) + + def test_error_check_aot_should_not_override_existing_error(self): + def f1(x): + error_check.set_error_if(x <= 0, "x must be greater than 0 in f1") + return x + 1 + + def run_export(): + def f2(x): + error_check.set_error_if(x <= 0, "x must be greater than 0 in f2") + return x + 1 + + f2 = jax.jit(error_check.wrap_for_export(jax.jit(f2))) + x = jax.ShapeDtypeStruct((), jnp.float32) + serialized = jax.export.export(f2)(x).serialize() + return serialized + + def run_import(serialized): + f2 = jax.export.deserialize(serialized).call + f2 = jax.jit(error_check.unwrap_from_import(jax.jit(f2))) + return f2 + + x = jnp.float32(-3.) + _ = f1(x) # check fails. so it should set error + + serialized = run_export() + f2 = run_import(serialized) + _ = f2(x) # check fails, but should not override the error + + with self.assertRaisesRegex( + JaxValueError, "x must be greater than 0 in f1" + ): + error_check.raise_if_error() + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/errors_test.py b/tests/errors_test.py index 25f29cfee224..356ca0713adf 100644 --- a/tests/errors_test.py +++ b/tests/errors_test.py @@ -13,7 +13,6 @@ # limitations under the License. 
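# A standalone sketch of the core jax._src.error_check flow exercised by the tests
# above: set_error_if records the first failing predicate inside (possibly jitted)
# code, and a later raise_if_error surfaces it as a JaxValueError.
import jax
import jax.numpy as jnp
from jax._src import error_check

@jax.jit
def f(x):
  error_check.set_error_if(x <= 0, "x must be greater than 0")
  return x + 1

f(jnp.zeros((4,), jnp.int32))
try:
  error_check.raise_if_error()
except error_check.JaxValueError as e:
  print(e)  # "x must be greater than 0" plus the recorded traceback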
import re -import sys import traceback from absl.testing import absltest @@ -46,10 +45,7 @@ def check_filtered_stack_trace(test, etype, f, frame_patterns=(), e = get_exception(etype, f) c = e.__cause__ if filter_mode == "quiet_remove_frames": - if sys.version_info >= (3, 11): - assert any("For simplicity" in x for x in e.__notes__) - else: - test.assertIsInstance(c, jax.errors.SimplifiedTraceback) + assert any("For simplicity" in x for x in e.__notes__) elif filter_mode == "remove_frames": test.assertIsInstance(c, traceback_util.UnfilteredStackTrace) else: @@ -393,12 +389,8 @@ def outer(x): ('', 'f = lambda: outer'), ('outer', 'raise TypeError')], filter_mode=filter_mode) e = get_exception(TypeError, f) # Uses the default JAX_TRACEBACK_FILTERING=auto - if sys.version_info >= (3, 11): - assert any("For simplicity" in x for x in e.__notes__) - self.assertIsInstance(e.__cause__, ValueError) - else: - self.assertIsInstance(e.__cause__, jax.errors.SimplifiedTraceback) - self.assertIsInstance(e.__cause__.__cause__, ValueError) + assert any("For simplicity" in x for x in e.__notes__) + self.assertIsInstance(e.__cause__, ValueError) def test_null_traceback(self, filter_mode): class TestA: pass @@ -424,14 +416,9 @@ def test_grad_norm(self): e = exc self.assertIsNot(e, None) self.assertIn("invalid value", str(e)) - if sys.version_info >= (3, 11): - self.assertIsInstance( - e.__cause__, - source_info_util.JaxStackTraceBeforeTransformation) - else: - self.assertIsInstance( - e.__cause__.__cause__, - source_info_util.JaxStackTraceBeforeTransformation) + self.assertIsInstance( + e.__cause__, + source_info_util.JaxStackTraceBeforeTransformation) class CustomErrorsTest(jtu.JaxTestCase): @@ -455,7 +442,7 @@ class FakeTracer(core.Tracer): ErrorClass = getattr(jax.errors, errorclass) err = ErrorClass(FakeTracer(None)) - self.assertIn(f'https://jax.readthedocs.io/en/latest/errors.html#jax.errors.{errorclass}', str(err)) + self.assertIn(f'https://docs.jax.dev/en/latest/errors.html#jax.errors.{errorclass}', str(err)) if __name__ == '__main__': diff --git a/tests/experimental_rnn_test.py b/tests/experimental_rnn_test.py index 7fa3b93f3c42..58f5291e9375 100644 --- a/tests/experimental_rnn_test.py +++ b/tests/experimental_rnn_test.py @@ -213,18 +213,11 @@ def f(k1, k2, k3, k4): k = jax.random.split(jax.random.PRNGKey(1), 4) stablehlo = jax.jit(f).lower(*k).as_text("stablehlo") - if jtu.jaxlib_version() <= (0, 5, 2): - self.assertIn('"\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\01\\00\\00\\00@\\03\\80\\00@\\01\\00\\00"', - stablehlo) - else: - self.assertIn('"\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\01\\00\\00\\00@\\03\\80\\00\\00\\00\\00\\00@\\01\\00\\00\\00\\00\\00\\00"', + self.assertIn('"\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\01\\00\\00\\00@\\03\\80\\00\\00\\00\\00\\00@\\01\\00\\00\\00\\00\\00\\00"', stablehlo) @jtu.run_on_devices("cuda") def test_no_workspace_overflow(self): - if jtu.jaxlib_version() <= (0, 5, 2): - self.skipTest("Older versions fail because of integer overflow.") - # Problem sizes known to cause overflows on older versions. 
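# A standalone sketch of the traceback-filtering behaviour the updated assertions
# above rely on now that Python >= 3.11 can be assumed: under the default
# filtering mode, JAX should attach a "For simplicity, ..." note to the exception
# instead of chaining a SimplifiedTraceback cause.
import jax

def boom():
  raise ValueError("boom")

@jax.jit
def f(x):
  return boom()

try:
  f(1.0)
except ValueError as e:
  # Should print True: one of the attached notes contains "For simplicity".
  print(any("For simplicity" in note for note in getattr(e, "__notes__", ())))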
batch_size, max_seq_length, input_size = 256, 500, 512 num_layers, hidden_size = 1, 256 diff --git a/tests/export_back_compat_test.py b/tests/export_back_compat_test.py index 9b457b8f27a5..e742654e6740 100644 --- a/tests/export_back_compat_test.py +++ b/tests/export_back_compat_test.py @@ -31,6 +31,7 @@ from jax._src.internal_test_util import export_back_compat_test_util as bctu +from jax._src.internal_test_util.export_back_compat_test_data import annotate_data_placement from jax._src.internal_test_util.export_back_compat_test_data import cpu_cholesky_lapack_potrf from jax._src.internal_test_util.export_back_compat_test_data import cpu_eig_lapack_geev from jax._src.internal_test_util.export_back_compat_test_data import cuda_eigh_cusolver_syev @@ -38,7 +39,6 @@ from jax._src.internal_test_util.export_back_compat_test_data import cpu_eigh_lapack_syev from jax._src.internal_test_util.export_back_compat_test_data import cpu_lu_lapack_getrf from jax._src.internal_test_util.export_back_compat_test_data import cuda_qr_cusolver_geqrf -from jax._src.internal_test_util.export_back_compat_test_data import rocm_qr_hipsolver_geqrf from jax._src.internal_test_util.export_back_compat_test_data import cpu_qr_lapack_geqrf from jax._src.internal_test_util.export_back_compat_test_data import cpu_schur_lapack_gees from jax._src.internal_test_util.export_back_compat_test_data import cpu_svd_lapack_gesdd @@ -51,6 +51,7 @@ from jax._src.internal_test_util.export_back_compat_test_data import cuda_lu_cusolver_getrf from jax._src.internal_test_util.export_back_compat_test_data import cuda_svd_cusolver_gesvd from jax._src.internal_test_util.export_back_compat_test_data import cuda_tridiagonal_cusolver_sytrd +from jax._src.internal_test_util.export_back_compat_test_data import cuda_tridiagonal_solve from jax._src.internal_test_util.export_back_compat_test_data import tpu_Eigh from jax._src.internal_test_util.export_back_compat_test_data import tpu_Lu from jax._src.internal_test_util.export_back_compat_test_data import tpu_ApproxTopK @@ -63,7 +64,7 @@ from jax._src.internal_test_util.export_back_compat_test_data import stablehlo_dynamic_approx_top_k from jax.experimental import pjit -from jax.experimental.shard_map import shard_map +from jax._src.shard_map import shard_map import jax.numpy as jnp from jax.sharding import Mesh @@ -120,7 +121,7 @@ def test_custom_call_coverage(self): targets_to_cover = set(_export._CUSTOM_CALL_TARGETS_GUARANTEED_STABLE) cpu_ffi_testdatas = [ cpu_cholesky_lapack_potrf.data_2024_05_31, - cpu_qr_lapack_geqrf.data_2024_08_22, + cpu_qr_lapack_geqrf.data_2025_04_02, cpu_eig_lapack_geev.data_2024_08_19, cpu_eigh_lapack_syev.data_2024_08_19, cpu_lu_lapack_getrf.data_2024_05_31, @@ -134,26 +135,16 @@ def test_custom_call_coverage(self): # stable covering_testdatas = [ *cpu_ffi_testdatas, - cpu_cholesky_lapack_potrf.data_2023_06_19, - cpu_eig_lapack_geev.data_2023_06_19, - cpu_eigh_lapack_syev.data_2023_03_17, - cpu_qr_lapack_geqrf.data_2023_03_17, cuda_threefry2x32.data_2024_07_30, - cpu_lu_lapack_getrf.data_2023_06_14, - cuda_lu_pivots_to_permutation.data_2024_08_08, + cuda_lu_pivots_to_permutation.data_2025_04_01, cuda_lu_cusolver_getrf.data_2024_08_19, cuda_qr_cusolver_geqrf.data_2024_09_26, cuda_eigh_cusolver_syev.data_2024_09_30, cuda_svd_cusolver_gesvd.data_2024_10_08, cpu_tridiagonal_solve_lapack_gtsv.data_2025_01_09, cuda_tridiagonal_cusolver_sytrd.data_2025_01_09, - rocm_qr_hipsolver_geqrf.data_2024_08_05, + cuda_tridiagonal_solve.data_2025_06_16, 
rocm_eigh_hipsolver_syev.data_2024_08_05, - cpu_schur_lapack_gees.data_2023_07_16, - cpu_svd_lapack_gesdd.data_2023_06_19, - cpu_triangular_solve_blas_trsm.data_2023_07_16, - cpu_hessenberg_lapack_gehrd.data_2024_08_30, - cpu_tridiagonal_lapack_sytrd_hetrd.data_2024_09_03, tpu_Eigh.data, tpu_Lu.data_2023_03_21, tpu_Qr.data_2023_03_17, tpu_Sharding.data_2023_03_16, tpu_ApproxTopK.data_2023_04_17, tpu_ApproxTopK.data_2023_05_16, @@ -163,6 +154,8 @@ def test_custom_call_coverage(self): stablehlo_dynamic_top_k.data_2023_07_16, stablehlo_dynamic_top_k.data_2023_08_11, # with shape_assertion stablehlo_dynamic_approx_top_k.data_2024_05_30, + annotate_data_placement.data_2025_04_07_tpu, + annotate_data_placement.data_2025_04_07_cuda, ] # Some of the above are nested structures. covering_testdatas = itertools.chain( @@ -175,6 +168,7 @@ def test_custom_call_coverage(self): covered_targets = covered_targets.union({ "tf.call_tf_function", # tested in jax2tf/tests/back_compat_tf_test.py "tpu_custom_call", # tested separately + "mosaic_gpu", # tested in pallas/export_back_compat_pallas_test.py "__gpu$xla.gpu.triton", # tested in pallas/export_back_compat_pallas_test.py # The following require ROCm to test "hip_lu_pivots_to_permutation", "hipsolver_getrf_ffi", @@ -212,10 +206,6 @@ def test_cpu_cholesky_lapack_potrf(self, dtype_name="f32"): data = self.load_testdata(info) self.run_one_test(func, data, rtol=rtol, atol=atol) - data = self.load_testdata(cpu_cholesky_lapack_potrf.data_2023_06_19[dtype_name]) - self.run_one_test(func, data, rtol=rtol, atol=atol, - expect_current_custom_calls=info["custom_call_targets"]) - @parameterized.named_parameters( dict(testcase_name=f"_dtype={dtype_name}", dtype_name=dtype_name) for dtype_name in ("f32", "f64", "c64", "c128")) @@ -276,10 +266,6 @@ def check_eigenvalue_is_in_array(eigenvalue, eigenvalues_array): data = self.load_testdata(info) self.run_one_test(func, data, rtol=rtol, atol=atol, check_results=check_eig_results) - data = self.load_testdata(cpu_eig_lapack_geev.data_2023_06_19[dtype_name]) - self.run_one_test(func, data, rtol=rtol, atol=atol, - check_results=check_eig_results, - expect_current_custom_calls=info["custom_call_targets"]) @staticmethod def eigh_input(shape, dtype): @@ -333,44 +319,6 @@ def test_cpu_eigh_lapack_syevd(self, dtype_name="f32"): self.run_one_test(func, data, rtol=rtol, atol=atol, check_results=partial(self.check_eigh_results, operand)) - # Legacy custom call test - data = self.load_testdata(cpu_eigh_lapack_syev.data_2023_03_17[dtype_name]) - self.run_one_test(func, data, rtol=rtol, atol=atol, - check_results=partial(self.check_eigh_results, operand), - expect_current_custom_calls=info["custom_call_targets"]) - - @parameterized.named_parameters( - dict(testcase_name=f"_dtype={dtype_name}_{variant}", - dtype_name=dtype_name, variant=variant) - for dtype_name in ("f32", "f64") - # We use different custom calls for sizes <= 32 - for variant in ["syevj", "syevd"]) - def test_gpu_eigh_solver_syev_legacy(self, dtype_name="f32", variant="syevj"): - if not config.enable_x64.value and dtype_name == "f64": - self.skipTest("Test disabled for x32 mode") - if jtu.test_device_matches(["rocm"]): - data = self.load_testdata(rocm_eigh_hipsolver_syev.data_2024_08_05[f"{dtype_name}_{variant}"]) - prefix = "hip" - elif jtu.test_device_matches(["cuda"]): - if _is_required_cusolver_version_satisfied(11600): - # The underlying problem is that this test assumes the workspace size can be - # queried from an older version of cuSOLVER and then be used in a 
newer one. - self.skipTest("Newer cuSOLVER expects a larger workspace than was serialized") - data = self.load_testdata(cuda_eigh_cusolver_syev.data_2023_03_17[f"{dtype_name}_{variant}"]) - prefix = "cu" - else: - self.skipTest("Unsupported platform") - # For lax.linalg.eigh - dtype = dict(f32=np.float32, f64=np.float64)[dtype_name] - size = dict(syevj=8, syevd=36)[variant] - rtol = dict(f32=1e-3, f64=1e-5)[dtype_name] - atol = dict(f32=1e-2, f64=1e-10)[dtype_name] - operand = CompatTest.eigh_input((size, size), dtype) - func = lambda: CompatTest.eigh_harness((size, size), dtype) - self.run_one_test(func, data, rtol=rtol, atol=atol, - check_results=partial(self.check_eigh_results, operand), - expect_current_custom_calls=[f"{prefix}solver_syevd_ffi"]) - @parameterized.named_parameters( dict(testcase_name=f"_dtype={dtype_name}", dtype_name=dtype_name) for dtype_name in ("f32", "f64", "c64", "c128")) @@ -411,14 +359,14 @@ def lu_pivots_to_permutation_harness(shape): def test_cuda_lu_pivots_to_permutation(self): shape = (2, 3, 4) func = lambda: CompatTest.lu_pivots_to_permutation_harness(shape) - data = self.load_testdata(cuda_lu_pivots_to_permutation.data_2024_08_08) + data = self.load_testdata(cuda_lu_pivots_to_permutation.data_2025_04_01) self.run_one_test(func, data) @parameterized.named_parameters( dict(testcase_name=f"_dtype={dtype_name}", dtype_name=dtype_name) for dtype_name in ("f32", "f64", "c64", "c128")) - def test_cuda_lu_lapack_getrf(self, dtype_name:str): + def test_cuda_lu_cusolver_getrf(self, dtype_name:str): if not config.enable_x64.value and dtype_name in ["f64", "c128"]: self.skipTest("Test disabled for x32 mode") dtype = dict(f32=np.float32, f64=np.float64, @@ -445,38 +393,10 @@ def test_cpu_qr_lapack_geqrf(self, dtype_name="f32"): c64=np.complex64, c128=np.complex128)[dtype_name] func = lambda: CompatTest.qr_harness((3, 3), dtype) - info = cpu_qr_lapack_geqrf.data_2024_08_22[dtype_name] + info = cpu_qr_lapack_geqrf.data_2025_04_02[dtype_name] data = self.load_testdata(info) self.run_one_test(func, data, rtol=rtol) - # TODO(b/369826500): Remove legacy custom call test after mid March 2025. - data = self.load_testdata(cpu_qr_lapack_geqrf.data_2023_03_17[dtype_name]) - self.run_one_test(func, data, rtol=rtol, - expect_current_custom_calls=info["custom_call_targets"]) - - # TODO(b/369826500): Remove legacy custom call test after mid March 2025. - @parameterized.named_parameters( - dict(testcase_name=f"_dtype={dtype_name}_{batched}", - dtype_name=dtype_name, batched=batched) - for dtype_name in ("f32",) - # For batched qr we use cublas_geqrf_batched/hipblas_geqrf_batched. 
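# A standalone sketch of the public jax.export serialize/deserialize round trip;
# the back-compat tests above check that previously serialized artifacts of this
# kind keep loading and running (`f` here is made up for illustration).
import jax
import jax.export
import jax.numpy as jnp

def f(x):
  return jnp.sin(x)

exported = jax.export.export(jax.jit(f))(jax.ShapeDtypeStruct((4,), jnp.float32))
blob = exported.serialize()              # a byte blob suitable for storing on disk
restored = jax.export.deserialize(blob)
print(restored.call(jnp.arange(4.0, dtype=jnp.float32)))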
- for batched in ("batched", "unbatched")) - def test_gpu_qr_solver_geqrf_legacy(self, dtype_name, batched): - if jtu.test_device_matches(["rocm"]): - data = self.load_testdata(rocm_qr_hipsolver_geqrf.data_2024_08_05[batched]) - prefix = "hip" - elif jtu.test_device_matches(["cuda"]): - data = self.load_testdata(cuda_qr_cusolver_geqrf.data_2023_03_18[batched]) - prefix = "cu" - else: - self.skipTest("Unsupported platform") - dtype = dict(f32=np.float32)[dtype_name] - rtol = dict(f32=1e-3)[dtype_name] - shape = dict(batched=(2, 3, 3), unbatched=(3, 3))[batched] - func = lambda: CompatTest.qr_harness(shape, dtype) - self.run_one_test(func, data, rtol=rtol, expect_current_custom_calls=[ - f"{prefix}solver_geqrf_ffi", f"{prefix}solver_orgqr_ffi"]) - @parameterized.named_parameters( dict(testcase_name=f"_dtype={dtype_name}", dtype_name=dtype_name) for dtype_name in ("f32", "f64", "c64", "c128")) @@ -551,14 +471,6 @@ def test_cpu_lu_lapack_getrf(self, dtype_name:str): check_results=partial(self.check_lu_results, operand, dtype=dtype)) - # TODO(b/357034884): Remove legacy custom call test after mid March 2025. - legacy_data = self.load_testdata( - cpu_lu_lapack_getrf.data_2023_06_14[dtype_name]) - self.run_one_test(func, legacy_data, rtol=rtol, atol=atol, - check_results=partial(self.check_lu_results, operand, - dtype=dtype), - expect_current_custom_calls=info["custom_call_targets"]) - def check_svd_results(self, input, res_run, res_exp, rtol=None, atol=None): # Following linalg_test.testSVD @@ -652,10 +564,6 @@ def check_schur_results(res_run, res_expected, *, rtol, atol): data = self.load_testdata(info) self.run_one_test(func, data, rtol=rtol, atol=atol, check_results=check_schur_results) - data = self.load_testdata(cpu_schur_lapack_gees.data_2023_07_16[dtype_name]) - self.run_one_test(func, data, rtol=rtol, atol=atol, - check_results=check_schur_results, - expect_current_custom_calls=info["custom_call_targets"]) @parameterized.named_parameters( dict(testcase_name=f"_dtype={dtype_name}", dtype_name=dtype_name) @@ -677,12 +585,6 @@ def func(operand): check_results=partial(self.check_svd_results, *data.inputs)) - data = self.load_testdata(cpu_svd_lapack_gesdd.data_2023_06_19[dtype_name]) - self.run_one_test(func, data, rtol=rtol, atol=atol, - check_results=partial(self.check_svd_results, - *data.inputs), - expect_current_custom_calls=info["custom_call_targets"]) - @parameterized.named_parameters( dict(testcase_name=f"_dtype={dtype_name}_algorithm={algorithm_name}", dtype_name=dtype_name, algorithm_name=algorithm_name) @@ -745,11 +647,6 @@ def check_triangular_solve_results(res_run, res_expected, *, rtol, atol): self.run_one_test(func, data, rtol=rtol, atol=atol, check_results=check_triangular_solve_results) - data = self.load_testdata(cpu_triangular_solve_blas_trsm.data_2023_07_16[dtype_name]) - self.run_one_test(func, data, rtol=rtol, atol=atol, - check_results=check_triangular_solve_results, - expect_current_custom_calls=info["custom_call_targets"]) - @parameterized.named_parameters( dict(testcase_name=f"_dtype={dtype_name}", dtype_name=dtype_name) for dtype_name in ("f32", "f64", "c64", "c128")) @@ -773,12 +670,6 @@ def func(): data = self.load_testdata(info) self.run_one_test(func, data, rtol=rtol, atol=atol) - data = self.load_testdata( - cpu_hessenberg_lapack_gehrd.data_2024_08_30[dtype_name] - ) - self.run_one_test(func, data, rtol=rtol, atol=atol, - expect_current_custom_calls=info["custom_call_targets"]) - @parameterized.named_parameters( dict(testcase_name=f"_dtype={dtype_name}", 
dtype_name=dtype_name) for dtype_name in ("f32", "f64", "c64", "c128")) @@ -802,12 +693,6 @@ def func(): data = self.load_testdata(info) self.run_one_test(func, data, rtol=rtol, atol=atol) - data = self.load_testdata( - cpu_tridiagonal_lapack_sytrd_hetrd.data_2024_09_03[dtype_name] - ) - self.run_one_test(func, data, rtol=rtol, atol=atol, - expect_current_custom_calls=info["custom_call_targets"]) - @parameterized.named_parameters( dict(testcase_name=f"_dtype={dtype_name}", dtype_name=dtype_name) for dtype_name in ("f32", "f64", "c64", "c128")) @@ -827,7 +712,7 @@ def test_cpu_tridiagonal_solve_lapack_gtsv(self, dtype_name): dict(testcase_name=f"_dtype={dtype_name}", dtype_name=dtype_name) for dtype_name in ("f32", "f64", "c64", "c128")) @jax.default_matmul_precision("float32") - def test_gpu_tridiagonal_solver_sytrd(self, dtype_name): + def test_gpu_tridiagonal_sytrd(self, dtype_name): if not config.enable_x64.value and dtype_name in ["f64", "c128"]: self.skipTest("Test disabled for x32 mode") @@ -842,7 +727,27 @@ def func(x): ) self.run_one_test(func, data, rtol=rtol, atol=atol) - def test_approx_top_k(self): + @parameterized.named_parameters( + dict(testcase_name=f"_dtype={dtype_name}", dtype_name=dtype_name) + for dtype_name in ("f32", "f64")) + @jax.default_matmul_precision("float32") + def test_gpu_tridiagonal_solve(self, dtype_name): + if not config.enable_x64.value and dtype_name == "f64": + self.skipTest("Test disabled for x32 mode") + + dtype = dict(f32=np.float32, f64=np.float64)[dtype_name] + def func(dl, d, du, b): + return lax.linalg.tridiagonal_solve(dl, d, du, b) + + rtol = dict(f32=1e-3, f64=1e-5)[dtype_name] + atol = dict(f32=1e-4, f64=1e-12)[dtype_name] + + data = self.load_testdata( + cuda_tridiagonal_solve.data_2025_06_16[dtype_name] + ) + self.run_one_test(func, data, atol=atol, rtol=rtol) + + def test_tpu_approx_top_k(self): def func(): x = np.array([3.0, 1.0, 4.0, 2.0, 5.0, 6.0, 7.0]) y = lax.approx_max_k(x, 3) @@ -859,7 +764,7 @@ def func(x): data = self.load_testdata(cuda_threefry2x32.data_2024_07_30) self.run_one_test(func, data) - def test_sharding(self): + def test_tpu_sharding(self): # Tests "Sharding", "SPMDShardToFullShape", "SPMDFullToShardShape" on TPU if not jtu.test_device_matches(["tpu"]) or len(jax.devices()) < 2: self.skipTest("Test runs only on TPU with at least 2 devices") @@ -873,14 +778,50 @@ def test_sharding(self): @partial(shard_map, mesh=mesh, in_specs=(P('a', None),), out_specs=P('a', None)) def func(x): # b: f32[2, 4] - axis_size = lax.psum(1, 'a') + axis_size = lax.axis_size('a') perm = [(j, (j + 1) % axis_size) for j in range(axis_size)] return lax.ppermute(x, 'a', perm=perm) - data = self.load_testdata(tpu_Sharding.data_2023_03_16) + data = tpu_Sharding.data_2023_03_16 + if jax.config.jax_use_shardy_partitioner: + data = data["shardy"] + else: + data = data["gspmd"] + data = self.load_testdata(data) with mesh: self.run_one_test(func, data) + @parameterized.named_parameters( + dict(testcase_name=f"_platform={platform}", platform=platform) + for platform in ("tpu", "gpu")) + def test_annotate_device_placement(self, platform): + if not jtu.test_device_matches([platform]): + self.skipTest(f"Test enabled only for {platform}") + + mesh = Mesh(jax.local_devices()[0:1], axis_names=("a")) + + dev_sharding = NS(mesh, P("a")) + host_sharding = NS(mesh, P("a"), memory_kind="pinned_host") + + @partial(jax.jit, + in_shardings=(dev_sharding, host_sharding), + out_shardings=host_sharding) + def func(x, y): + return x + y + + if platform == "tpu": + data = 
annotate_data_placement.data_2025_04_07_tpu + else: + data = annotate_data_placement.data_2025_04_07_cuda + + if jax.config.jax_use_shardy_partitioner: + data = data["shardy"] + else: + data = data["gspmd"] + data = self.load_testdata(data) + + self.run_one_test(func, data) + def test_tpu_stablehlo_dynamic_reduce_window_unary(self): # stablehlo.dynamic_reduce_window is used temporarily on TPU for a # reduce window with dynamic shapes. @@ -1044,15 +985,20 @@ def func(x): # x: f32[4, 4] @partial(shard_map, mesh=old_mesh, in_specs=(P('a', None),), out_specs=P('a', None)) def shard_map_func(x): # b: f32[2, 4] - axis_size = lax.psum(1, 'a') + axis_size = lax.axis_size('a') perm = [(j, (j + 1) % axis_size) for j in range(axis_size)] return lax.ppermute(x, 'a', perm=perm) x = jax.lax.with_sharding_constraint(x, NS(old_mesh, P('a', None))) return shard_map_func(x) - data = self.load_testdata(shardy_sharding_ops_with_different_meshes.data_2025_02_12) - with Mesh(devices, axis_names=('x')): - self.run_one_test(func, data) + data = [ + shardy_sharding_ops_with_different_meshes.data_2025_02_12, + shardy_sharding_ops_with_different_meshes.data_2025_04_14, + ] + + for d in data: + with Mesh(devices, axis_names=('x')): + self.run_one_test(func, self.load_testdata(d)) if __name__ == "__main__": diff --git a/tests/export_harnesses_multi_platform_test.py b/tests/export_harnesses_multi_platform_test.py index ef9d1e04c796..b91a1fc550bd 100644 --- a/tests/export_harnesses_multi_platform_test.py +++ b/tests/export_harnesses_multi_platform_test.py @@ -23,7 +23,6 @@ from collections.abc import Callable import math -import re from absl import logging from absl.testing import absltest @@ -39,13 +38,6 @@ from jax import random -def make_disjunction_regexp(*parts: str) -> re.Pattern[str]: - if not parts: - return re.compile("matches_no_test") - else: - return re.compile("(" + "|".join(parts) + ")") - - class PrimitiveTest(jtu.JaxTestCase): def setUp(self): @@ -84,10 +76,6 @@ def test_prim(self, harness: test_harnesses.Harness): self.skipTest("Eigenvalues are sorted and it is not correct to compare " "decompositions for equality.") - if (jtu.device_under_test() == "gpu" - and "tridiagonal_solve_" in harness.fullname): - self.skipTest("tridiagonal_solve_ is not yet guaranteed stable.") - if harness.params.get("enable_xla", False): self.skipTest("enable_xla=False is not relevant") @@ -98,11 +86,6 @@ def test_prim(self, harness: test_harnesses.Harness): for l in harness.jax_unimplemented: if l.filter(dtype=harness.dtype): unimplemented_platforms = unimplemented_platforms.union(l.devices) - # Some primitive lowering rules need the GPU backend to be able to create - # CUDA lowering. 
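The skips being dropped in this hunk concern lax.linalg.tridiagonal_solve, now covered by the cuda_tridiagonal_solve compat data added above. A minimal sketch of the call convention, with illustrative values (not taken from the serialized test data):

import jax.numpy as jnp
from jax import lax

n = 4
dl = jnp.array([0.0, 1.0, 1.0, 1.0], dtype=jnp.float32)  # sub-diagonal; dl[0] is ignored
d = jnp.full((n,), 4.0, dtype=jnp.float32)               # main diagonal
du = jnp.array([1.0, 1.0, 1.0, 0.0], dtype=jnp.float32)  # super-diagonal; du[-1] is ignored
b = jnp.ones((n, 1), dtype=jnp.float32)                  # right-hand sides, shape (n, k)
x = lax.linalg.tridiagonal_solve(dl, d, du, b)           # solves the tridiagonal system A @ x = b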
- if ("tridiagonal_solve_" in harness.fullname - and all(d.platform != "gpu" for d in self.devices)): - unimplemented_platforms.add("gpu") if unimplemented_platforms: logging.info("Harness is not implemented on %s", unimplemented_platforms) diff --git a/tests/export_test.py b/tests/export_test.py index 2b083f3121f4..ecdb470819c8 100644 --- a/tests/export_test.py +++ b/tests/export_test.py @@ -19,7 +19,6 @@ import dataclasses import functools import logging -import json import math import re import unittest @@ -30,13 +29,14 @@ from jax import numpy as jnp from jax import export from jax.experimental import pjit -from jax.experimental.shard_map import shard_map +from jax._src.shard_map import shard_map from jax.sharding import NamedSharding from jax.sharding import Mesh from jax.sharding import PartitionSpec as P from jax import tree_util from jax._src import config +from jax._src import compute_on from jax._src import core from jax._src import dtypes from jax._src import effects @@ -204,7 +204,14 @@ def test_basic(self): f = jnp.sin x = np.arange(4, dtype=np.float32) exp_f = get_exported(f)(x) + self.assertAllClose(f(x), exp_f.call(x)) + def test_basic_single_device_sharding(self): + device = jax.local_devices()[0] + s = jax.sharding.SingleDeviceSharding(device) + x = np.arange(16, dtype=np.float32).reshape(4, -1) + f = jax.jit(lambda x: x * 2., in_shardings=s, out_shardings=s) + exp_f = get_exported(f)(x) self.assertAllClose(f(x), exp_f.call(x)) def test_jit_static_arg(self): @@ -281,6 +288,18 @@ def test_unused_args(self): self.assertAllClose(f(x, y), exp_f.call(x, y)) + def test_override_lowering_rules(self): + @jax.jit + def f(x): + return jnp.sin(x) + + def my_lowering_rule(ctx, arg, **_): + return mlir.hlo.CosineOp(arg).results + + exp = get_exported(f, _override_lowering_rules=( + (lax.sin_p, my_lowering_rule),))(42.) + self.assertIn("stablehlo.cosine", exp.mlir_module()) + def test_pytree(self): a = np.arange(4, dtype=np.float32) b = np.arange(6, dtype=np.float32) @@ -397,7 +416,7 @@ def f(x1, x2): exp = export.export(jax.jit(f))(x1, x2) res = exp.call(x1, x2) self.assertEqual(tree_util.tree_structure(res), - tree_util.tree_structure(((x1, x2, x1, x2)))) + tree_util.tree_structure((x1, x2, x1, x2))) self.assertEqual(type(res[0]), type(x1)) self.assertEqual(type(res[1]), type(x2)) self.assertEqual(type(res[2]), type(x1)) @@ -410,6 +429,18 @@ def f(x1, x2): self.assertEqual(tree_util.tree_structure(res2), tree_util.tree_structure(res)) + @jtu.parameterized_filterable( + kwargs=[dict(impl=p) + for p in ("rbg", "unsafe_rbg", "threefry2x32")]) + def test_prng_keys(self, *, impl): + + key = jax.random.key(42, impl=impl) + @jax.jit + def f(key): + return key + exp_f = get_exported(jax.jit(f))(key) + self.assertEqual(f(key), exp_f.call(key)) + def test_error_wrong_intree(self): def f(a_b_pair, *, c): return jnp.sin(a_b_pair[0]) + jnp.cos(a_b_pair[1]) + c @@ -941,7 +972,7 @@ def outer(x): # x: outer_poly_spec "Using the following polymorphic shapes specifications: args[0].shape = (2*b + a, a, c + b + a). " "Obtained dimension variables: 'a' = 2 from specification 'a' for dimension args[0].shape[1] (= 2), " "'b' = 0 from specification '2*b + a' for dimension args[0].shape[0] (= 2), . " - "Please see https://jax.readthedocs.io/en/latest/export/shape_poly.html#shape-assertion-errors for more details." + "Please see https://docs.jax.dev/en/latest/export/shape_poly.html#shape-assertion-errors for more details." 
)), dict(shape=(3, 2, 6), # a = 2, b = 0.5, c = 4 - b is not integer poly_spec="(a + 2*b, a, a + b + c)", @@ -950,7 +981,7 @@ def outer(x): # x: outer_poly_spec "Division had remainder 1 when computing the value of 'b'. " "Using the following polymorphic shapes specifications: args[0].shape = (2*b + a, a, c + b + a). " "Obtained dimension variables: 'a' = 2 from specification 'a' for dimension args[0].shape[1] (= 2), . " - "Please see https://jax.readthedocs.io/en/latest/export/shape_poly.html#shape-assertion-errors for more details." + "Please see https://docs.jax.dev/en/latest/export/shape_poly.html#shape-assertion-errors for more details." )), dict(shape=(8, 2, 6), # a = 2, b = 3 - inconsistency poly_spec="(a + 2*b, a, a + b)", @@ -960,7 +991,7 @@ def outer(x): # x: outer_poly_spec "Using the following polymorphic shapes specifications: args[0].shape = (2*b + a, a, b + a). " "Obtained dimension variables: 'a' = 2 from specification 'a' for dimension args[0].shape[1] (= 2), " "'b' = 4 from specification 'b + a' for dimension args[0].shape[2] (= 6), . " - "Please see https://jax.readthedocs.io/en/latest/export/shape_poly.html#shape-assertion-errors for more details." + "Please see https://docs.jax.dev/en/latest/export/shape_poly.html#shape-assertion-errors for more details." )), dict(shape=(7, 2, 36), # a = 2, b = 3, c = 6 - cannot solve c poly_spec="(2 * a + b, a, c * c)", @@ -969,7 +1000,7 @@ def outer(x): # x: outer_poly_spec "We can only solve linear uni-variate constraints. " "Using the following polymorphic shapes specifications: args[0].shape = (b + 2*a, a, c^2). " "Unprocessed specifications: 'c^2' for dimension size args[0].shape[2]. " - "Please see https://jax.readthedocs.io/en/latest/export/shape_poly.html#dimension-variables-must-be-solvable-from-the-input-shapes for more details." + "Please see https://docs.jax.dev/en/latest/export/shape_poly.html#dimension-variables-must-be-solvable-from-the-input-shapes for more details." 
)), ]) def test_shape_constraints_errors(self, *, @@ -1312,7 +1343,7 @@ def test_shard_map_collective_permute(self, poly=None): shard_map, mesh=mesh, in_specs=(P("x", None),), out_specs=P("x", None)) def f_jax(b): # b: f32[2, 4] - axis_size = lax.psum(1, "x") + axis_size = lax.axis_size("x") perm = [(j, (j + 1) % axis_size) for j in range(axis_size)] return lax.ppermute(b, "x", perm=perm) @@ -1517,7 +1548,7 @@ def test_multi_platform(self): self.assertIn("jax.uses_shape_polymorphism = true", module_str) - # Call with argument placed on different plaforms + # Call with argument placed on different platforms for platform in self.platforms: x_device = jax.device_put(x, jax.devices(platform)[0]) res_exp = exp.call(x_device) @@ -1542,7 +1573,7 @@ def test_multi_platform_nested(self): count_sine = len(re.findall("stablehlo.sine", exp2_module_str)) self.assertEqual(1, count_sine) - # Call with argument placed on different plaforms + # Call with argument placed on different platforms for platform in self.platforms: if platform == "tpu": continue x_device = jax.device_put(x, jax.devices(platform)[0]) @@ -1685,7 +1716,7 @@ def f_jax(b): # b: f32[16 // DEVICES, 4] res_native = f_jax(a) exp = get_exported(f_jax, platforms=("cpu", "tpu", "cuda", "rocm"))(a) - # Call with argument placed on different plaforms + # Call with argument placed on different platforms for platform in self.platforms: run_devices = jax.devices(platform)[0:len(export_devices)] if len(run_devices) != len(export_devices): @@ -1695,6 +1726,22 @@ def f_jax(b): # b: f32[16 // DEVICES, 4] res_exp = exp.call(a_device) self.assertArraysAllClose(res_native, res_exp) + def test_compute_on_host(self): + operand = np.float32(0.) + + @jax.jit + @compute_on.compute_on("device_host") + def f_host(x): + # Adds 1 on CPU, which should be the result on all platforms because + # this code should always run on the host. 
+ return jax.lax.platform_dependent(x, + cpu=lambda x: x + np.float32(1.), + default=lambda x: x + np.float32(2.)) + + self.assertAllClose(np.float32(1.), f_host(operand)) + exp = get_exported(f_host, platforms=("cpu", "tpu", "cuda", "rocm"))(operand) + self.assertAllClose(np.float32(1.), exp.call(operand)) + @jtu.parameterized_filterable( kwargs=[ dict(v=v) @@ -1903,8 +1950,8 @@ def f_jax(x): @jtu.parameterized_filterable( kwargs=[ - {"m": 5, "k": 4, "n": 3, "group_sizes": [5]}, - {"m": 10, "k": 9, "n": 8, "group_sizes": [3, 7]}, + {"m": 64, "k": 4, "n": 3, "group_sizes": [5]}, + {"m": 64, "k": 9, "n": 8, "group_sizes": [3, 7]}, ]) def test_ragged_dot(self, m, k, n, group_sizes): def f_jax(x, y, gs): @@ -1961,6 +2008,32 @@ def f(x, y): r = jax.jit(exp.call, out_shardings=NamedSharding(old_mesh_0, P("old_b")))(a, b) self.assertAllClose(a + b, r) + def test_lower_wth_different_meshes_axis_names(self): + mesh1 = jtu.create_mesh((4, 2), ("a", "b")) + mesh2 = jtu.create_mesh((4, 2), ("x", "y")) + @jax.jit + def f(tree): + return tree['foo'] + tree['bar'] + + args = { + 'foo': jax.ShapeDtypeStruct( + (32, 32), dtype=np.float32, + sharding=NamedSharding(mesh1, P(None, "a"))), + 'bar': jax.ShapeDtypeStruct( + (32, 32), dtype=np.float32, + sharding=NamedSharding(mesh2, P("y"))), + } + + if config.use_shardy_partitioner.value: + with self.assertRaisesRegex( + ValueError, + r'Mesh for all inputs/outputs should be equal.*' + r"args\[0\]\['bar'\].*"): + get_exported(f)(args) + else: + get_exported(f)(args) + + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/ffi_test.py b/tests/ffi_test.py index 46aaefa8f521..77b0f823e125 100644 --- a/tests/ffi_test.py +++ b/tests/ffi_test.py @@ -28,6 +28,7 @@ from jax._src import config from jax._src import core +from jax._src import deprecations from jax._src import dispatch from jax._src import test_util as jtu from jax._src.interpreters import mlir @@ -35,7 +36,7 @@ from jax._src.lib import lapack from jax._src.lib.mlir.dialects import hlo from jax._src.lax import linalg as lax_linalg_internal -from jax.experimental.shard_map import shard_map +from jax._src.shard_map import shard_map jax.config.parse_flags_with_absl() jtu.request_cpu_devices(8) @@ -200,21 +201,6 @@ def test_ffi_call_batching(self, shape, vmap_method): else: self.assertArraysEqual(a, b) - @jtu.run_on_devices("gpu", "cpu") - def test_vectorized_deprecation(self): - x = self.rng().randn(3, 5, 4).astype(np.float32) - with self.assertWarns(DeprecationWarning): - ffi_call_geqrf(x, vectorized=True) - with self.assertWarns(DeprecationWarning): - jax.vmap(ffi_call_geqrf)(x) - - def test_backward_compat_syntax(self): - def fun(x): - return jax.ffi.ffi_call("test_ffi", x, x, param=0.5) - msg = "Calling ffi_call directly with input arguments is deprecated" - with self.assertDeprecationWarnsOrRaises("jax-ffi-call-args", msg): - jax.jit(fun).lower(jnp.ones(5)) - def test_input_output_aliases(self): def fun(x): return jax.ffi.ffi_call("test", x, input_output_aliases={0: 0})(x) @@ -299,8 +285,15 @@ def f(x): @jtu.run_on_devices("gpu", "cpu") @jtu.ignore_warning(category=DeprecationWarning) def test_extend_import_shim(self): + if deprecations.is_accelerated_attribute(jex.ffi, "ffi_call"): + self.skipTest("FFI call deprecation is accelerated.") ffi_call_geqrf(jnp.ones((4, 5), dtype=np.float32), _use_extend=True) + def test_extended_dtype_lowering(self): + def f(x): + return jax.ffi.ffi_call("edtype", (), has_side_effect=True)(x) + jax.jit(f).lower(jax.random.key(0)) # 
doesn't crash + def ffi_call_geqrf(x, _use_extend=False, **kwargs): if jtu.test_device_matches(["cpu"]): @@ -349,7 +342,7 @@ def test_shard_map(self): x = self.rng().randn(8, 4, 5).astype(np.float32) @partial(shard_map, mesh=mesh, in_specs=P("i"), out_specs=P("i"), - check_rep=False) + check_vma=False) def f(x): return batch_partitionable_ffi_call(x) diff --git a/tests/filecheck/custom_call.filecheck.py b/tests/filecheck/custom_call.filecheck.py index c6af4235ebb4..27cc904e59d8 100644 --- a/tests/filecheck/custom_call.filecheck.py +++ b/tests/filecheck/custom_call.filecheck.py @@ -19,7 +19,7 @@ from absl import app import jax -from jax.interpreters import mlir +from jax._src.interpreters import mlir from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import func as func_dialect import numpy as np diff --git a/tests/fused_attention_stablehlo_test.py b/tests/fused_attention_stablehlo_test.py index af0b18b02f37..394e5b4b0e8f 100644 --- a/tests/fused_attention_stablehlo_test.py +++ b/tests/fused_attention_stablehlo_test.py @@ -24,6 +24,7 @@ from jax._src import test_util as jtu from jax._src.cudnn.fused_attention_stablehlo import ( dot_product_attention, + paged_attention, check_is_flash_attention, check_cudnn_version, MaskType, @@ -618,6 +619,8 @@ def test_sdpa_packed_layout(self): return if cudnn_version < 90600: self.skipTest("Requires >= cuDNN 9.6.0") + if not jtu.is_cuda_compute_capability_at_least("9.0"): + self.skipTest("Requires at least Hopper arch") k1, k2, k3, k4 = jax.random.split(jax.random.key(0), 4) query = jax.random.normal( k1, (4, 512, 4, 64), dtype=jnp.bfloat16) @@ -737,6 +740,171 @@ def generate_segment_mask(segment_ids, dtype): self.assertArraysAllClose(key_grad_ref, key_grad, rtol=1e-2, atol=1e-2) self.assertArraysAllClose(value_grad_ref, value_grad, rtol=1e-2, atol=1e-2) + @jtu.run_on_devices("cuda") + def test_sdpa_residual(self): + k1, k2, k3, k4, k5 = jax.random.split(jax.random.key(0), 5) + query = jax.random.normal( + k1, (4, 1024, 4, 64), dtype=jnp.bfloat16) + key = jax.random.normal( + k2, (4, 1024, 4, 64), dtype=jnp.bfloat16) + value = jax.random.normal( + k3, (4, 1024, 4, 64), dtype=jnp.bfloat16) + grad = jax.random.normal( + k4, (4, 1024, 4, 64), dtype=jnp.bfloat16) + grad_stat = jax.random.normal( + k5, (4, 4, 1024), dtype=jnp.float32) + + devices = np.array(jax.local_devices()[:2]) + with Mesh(devices, ("dp")) as mesh: + qkv_spec = PartitionSpec("dp", None, None, None) + stat_spec = PartitionSpec("dp", None, None) + qkv_sharding = NamedSharding(mesh, qkv_spec) + stat_sharding = NamedSharding(mesh, stat_spec) + + query = jax.device_put(query, qkv_sharding) + key = jax.device_put(key, qkv_sharding) + value = jax.device_put(value, qkv_sharding) + grad = jax.device_put(grad, qkv_sharding) + grad_stat = jax.device_put(grad_stat, stat_sharding) + + jitted_sdpa_inference = jax.jit( + partial( + dot_product_attention, scale=1.0, mask_type=MaskType.NO_MASK, + dropout_rate=0, return_residual=True), + in_shardings=(qkv_sharding, qkv_sharding, qkv_sharding), + out_shardings=(qkv_sharding, stat_sharding) + ) + + outs = jitted_sdpa_inference(query, key, value) + assert len(outs) == 2 + + def train(query, key, value, grads): + outs, grad_fn = jax.vjp(partial( + dot_product_attention, scale=1.0, mask_type=MaskType.NO_MASK, + dropout_rate=0, return_residual=True), query, key, value) + return outs, grad_fn(grads) + jitted_sdpa_train = jax.jit(train, + in_shardings=(qkv_sharding, qkv_sharding, qkv_sharding, (qkv_sharding, stat_sharding)), + 
out_shardings=((qkv_sharding, stat_sharding), (qkv_sharding, qkv_sharding, qkv_sharding))) + outs = jitted_sdpa_train(query, key, value, (grad, grad_stat)) + assert len(outs) == 2 + + @jtu.sample_product( + batch_size=[4], + q_seq_len=[1, 1024], + kv_seq_len=[1024], + num_heads=[8], + head_dim=[64, 128], + block_size=[64, 128], + dtype=[jnp.float16, jnp.bfloat16] + ) + @jtu.run_on_devices("cuda") + def test_sdpa_paged_attention(self, batch_size, q_seq_len, kv_seq_len, + num_heads, head_dim, block_size, dtype): + try: + cudnn_version = check_cudnn_version() + except RuntimeError as e: + self.skipTest(str(e)) + return + if cudnn_version < 90500: + self.skipTest("Requires >= cuDNN 9.5.0") + + keys = jax.random.split(jax.random.key(0), 5) + blocks_per_batch = kv_seq_len // block_size + num_blocks = batch_size * blocks_per_batch + + # different q_seq_len for prefill and decode + q = jax.random.normal( + keys[0], (batch_size, q_seq_len, num_heads, head_dim), dtype=dtype) + k_container = jax.random.normal( + keys[1], (num_blocks, block_size, num_heads, head_dim), dtype=dtype) + v_container = jax.random.normal( + keys[2], (num_blocks, block_size, num_heads, head_dim), dtype=dtype) + page_table_k = jax.random.randint( + keys[3], (batch_size, 1, blocks_per_batch, 1), 0, num_blocks-1, dtype=jnp.int32) + page_table_v = jax.random.randint( + keys[4], (batch_size, 1, blocks_per_batch, 1), 0, num_blocks-1, dtype=jnp.int32) + # full page table + q_seqlen = jnp.full((batch_size,), q_seq_len, jnp.int32) + kv_seqlen = jnp.full((batch_size,), kv_seq_len, jnp.int32) + + def unpaged(paged, page_table): + output = jnp.zeros((batch_size, kv_seq_len, num_heads, head_dim), dtype=dtype) + for b in range(batch_size): + for block in range(blocks_per_batch): + block_idx = page_table[b, 0, block, 0] + output = output.at[ + b, block * block_size : (block + 1) * block_size, :, : + ].set(paged[block_idx, :, :, :]) + return output + + k = unpaged(k_container, page_table_k) + v = unpaged(v_container, page_table_v) + + sdpa_infer = jax.jit(partial( + paged_attention, scale=1.0, mask_type=MaskType.NO_MASK) + ) + sdpa_infer_ref = jax.jit(partial( + sdpa_ref, scale=1.0, mask_type=MaskType.NO_MASK, dropout_rate=0) + ) + + out = sdpa_infer(q, k_container, v_container, q_seqlen=q_seqlen, + kv_seqlen=kv_seqlen, page_table_k=page_table_k, page_table_v=page_table_v) + out_ref = sdpa_infer_ref(q, k, v) + self.assertArraysAllClose(out_ref, out_ref, rtol=1e-2, atol=1e-2) + + @jtu.run_on_devices("cuda") + def test_sdpa_mla(self): + if jax.device_count() < 4: + self.skipTest("Requires more than 4 devices.") + try: + cudnn_version = check_cudnn_version() + except RuntimeError as e: + self.skipTest(str(e)) + return + if cudnn_version < 91000: + self.skipTest("Requires >= cuDNN 9.10.0") + if not jtu.is_cuda_compute_capability_at_least("9.0"): + self.skipTest("Requires at least Hopper arch") + k1, k2, k3 = jax.random.split(jax.random.key(0), 3) + query = jax.random.normal( + k1, (4, 1024, 4, 128), dtype=jnp.bfloat16) + key = jax.random.normal( + k2, (4, 1024, 4, 128), dtype=jnp.bfloat16) + value = jax.random.normal( + k3, (4, 1024, 4, 64), dtype=jnp.bfloat16) + + devices = np.array(jax.local_devices()[:4]) + devices = devices.reshape((2, 2)) + with Mesh(devices, ("dp", "tp")) as mesh: + qkv_spec = PartitionSpec("dp", None, "tp", None) + qkv_sharding = NamedSharding(mesh, qkv_spec) + in_shardings = ( + qkv_sharding, qkv_sharding, qkv_sharding) + out_shardings = qkv_sharding + query = jax.device_put(query, qkv_sharding) + key = 
jax.device_put(key, qkv_sharding) + value = jax.device_put(value, qkv_sharding) + + jitted_sdpa_inference = jax.jit( + partial( + dot_product_attention, scale=1.0, mask_type=MaskType.NO_MASK, + dropout_rate=0), + in_shardings=in_shardings, + out_shardings=out_shardings + ) + + jitted_sdpa_inference_ref = jax.jit( + partial( + sdpa_ref, scale=1.0, mask_type=MaskType.NO_MASK, dropout_rate=0), + in_shardings=in_shardings, + out_shardings=out_shardings + ) + + out = jitted_sdpa_inference(query, key, value) + out_ref = jitted_sdpa_inference_ref(query, key, value) + self.assertArraysAllClose(out_ref, out, rtol=2e-2, atol=2e-2) + @jtu.run_on_devices("cuda") def test_layouts(self): if jax.device_count() < 4: @@ -783,15 +951,16 @@ def test_sdpa_utils(self): expected_pass = k query = jnp.empty((4, sql_q, 4, head_dim)) key = jnp.empty((4, sql_v, 4, head_dim)) + value = jnp.empty((4, sql_v, 4, head_dim)) if expected_pass: check_is_flash_attention( - query, key, AttentionLayout.BNTH.value, cudnn_version, has_bias, - is_training) + query, key, value, AttentionLayout.BNTH.value, cudnn_version, + has_bias, is_training) else: with self.assertRaises(NotImplementedError): check_is_flash_attention( - query, key, AttentionLayout.BNTH.value, cudnn_version, has_bias, - is_training) + query, key, value, AttentionLayout.BNTH.value, cudnn_version, + has_bias, is_training) @jtu.with_config(jax_numpy_dtype_promotion="standard") diff --git a/tests/generated_fun_test.py b/tests/generated_fun_test.py index cdfeeba6275b..67c19179bb8b 100644 --- a/tests/generated_fun_test.py +++ b/tests/generated_fun_test.py @@ -218,7 +218,7 @@ def check_all_close(xs, ys, tol=1e-3): def check_close(x, y, tol=1e-3): assert jnp.shape(x) == jnp.shape(y) - # TODO(dougalm): re-enable once we've tackled the less pendantic bugs + # TODO(dougalm): re-enable once we've tackled the less pedantic bugs # assert x.dtype == y.dtype assert jnp.allclose(x, y, rtol=tol, atol=tol), \ f"Value mismatch:\n{x}\n vs\n{y}\n" diff --git a/tests/gpu_memory_flags_test.py b/tests/gpu_memory_flags_test.py index 308fff257348..bada2bebc74e 100644 --- a/tests/gpu_memory_flags_test.py +++ b/tests/gpu_memory_flags_test.py @@ -29,7 +29,7 @@ class GpuMemoryAllocationTest(absltest.TestCase): @jtu.skip_under_pytest("Test must run in an isolated process") @unittest.skipIf( "XLA_PYTHON_CLIENT_ALLOCATOR" in os.environ, - "Test does not work if the python client allocator has been overriden", + "Test does not work if the python client allocator has been overridden", ) def test_gpu_memory_allocation(self): falsey_values = ("0", "False", "false") @@ -40,7 +40,7 @@ def test_gpu_memory_allocation(self): device = jax.devices()[0] mem_stats = device.memory_stats() self.assertEqual(mem_stats["pool_bytes"], 0) - x = jax.lax.add(1, 2) + x = jax.lax.add(1, 2).block_until_ready() mem_stats = device.memory_stats() if preallocate: diff --git a/tests/hijax_test.py b/tests/hijax_test.py new file mode 100644 index 000000000000..ccdbe7371b69 --- /dev/null +++ b/tests/hijax_test.py @@ -0,0 +1,781 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from dataclasses import dataclass +from functools import partial +import itertools as it +from typing import Any +import unittest + +from absl.testing import absltest, parameterized + +import jax +import jax.numpy as jnp + +from jax._src import config +from jax._src import core +from jax._src.interpreters import ad +from jax._src.interpreters import partial_eval as pe +from jax._src import ad_util +from jax._src import test_util as jtu +from jax._src.util import safe_zip, safe_map + +config.parse_flags_with_absl() + +map, unsafe_map = safe_map, map +zip, unsafe_zip = safe_zip, zip + +PyTreeDef = Any + + +# TODO(mattjj,dougalm): move HiPrimitive, Box, etc out of tests and into library +class HiPrimitive(core.Primitive): + def __init__(self, name): + self.name = name + ad.primitive_jvps[self] = self.jvp + ad.primitive_transposes[self] = self.transpose + pe.custom_staging_rules[self] = self.staging + + def staging(self, trace, source_info, *args, **kwargs): + trace.frame.is_high = True + return trace.default_process_primitive(self, args, kwargs, + source_info=source_info) + + def is_high(self, **params): + return True + + def abstract_eval(self, *arg_avals, **params): + assert False, "must override" + + def to_lojax(self, *lotypes_wrapped_in_hitypes, **params): + assert False, "must override" + + def jvp(self, primals, tangents, **params): + assert False, "must override" + + def transpose(self, *args, **params): + assert False # TODO + +class HijaxTest(jtu.JaxTestCase): + + def test_custom_types_and_primitive(self): + if config.enable_x64.value: raise unittest.SkipTest("no x64") + + @dataclass(frozen=True) + class MyArray: + arr: jax.Array # always f32 + + @dataclass(frozen=True) + class MyTy(core.AbstractValue): + def to_tangent_aval(self): + return MyTy() + def str_short(self, short_dtypes=False): + return 'MyTy' + def lo_ty(self): + return [core.ShapedArray((), jnp.dtype('float32'))] + def lower_val(self, hi_val: MyArray) -> list[jax.Array]: + return [hi_val.arr] + def raise_val(self, val) -> MyArray: + return MyArray(val) + + def __eq__(self, other): return isinstance(other, MyTy) + + def vspace_zero(self): + return MyArray(jnp.zeros((), 'float32')) + def vspace_add(self, x, y): + return add(x, y) + core.pytype_aval_mappings[MyArray] = lambda _: MyTy() + + class ToMy(HiPrimitive): + def is_high(self): return True + + def abstract_eval(_, lo_aval): + return MyTy(), set() + + def to_lojax(_, lo): + return MyArray(lo) + + def jvp(_, primals, tangents): + x, x_dot = *primals, *tangents + return to(x), to(x_dot) + + def transpose(self, out_bar, _): + return from_(out_bar), + + class FromMy(HiPrimitive): + def is_high(self): return True + + def abstract_eval(_, hi_aval): + return hi_aval.lo_ty()[0], set() + + def to_lojax(_, hi): + return hi.arr + + def jvp(_, primals, tangents): + x, x_dot = *primals, *tangents + return from_(x), from_(x_dot) + + def transpose(self, out_bar, _): + return to(out_bar), + + def to(x): return to_p.bind(x) + to_p = ToMy('to_my') + + def from_(x): return from_p.bind(x) + from_p = FromMy('from_my') + + def mul(x, y): return mul_p.bind(x, y) + def add(x, y): return add_p.bind(x, y) + + class MyMul(HiPrimitive): + def is_high(self): return True + + def abstract_eval(_, hi_x, hi_y): + if hi_x != hi_y: raise Exception + return hi_x, set() + + def to_lojax(_, hi_x, hi_y): + return MyArray(hi_x.arr * hi_y.arr) + + def jvp(_, 
primals, tangents): + (x, y), (x_dot, y_dot) = primals, tangents + return mul(x, y), add(mul(x, y_dot), mul(x_dot, y)) + + def transpose(self, out_bar, x, y): + assert ad.is_undefined_primal(x) ^ ad.is_undefined_primal(y) + if ad.is_undefined_primal(x): + return mul(out_bar, y), None + else: + return None, mul(x, out_bar) + + class MyAdd(HiPrimitive): + def is_high(self): return True + + def abstract_eval(_, hi_x, hi_y): + if hi_x != hi_y: raise Exception + return hi_x, set() + + def to_lojax(_, hi_x, hi_y): + return MyArray(hi_x.arr + hi_y.arr) + + def jvp(_, primals, tangents): + assert False # TODO + + def transpose(self, out_bar, x, y): + return out_bar, out_bar + + mul_p = MyMul('my_mul') + add_p = MyAdd('my_add') + + + @jax.jit + def f(x): + return to(from_(x)) + + # test basic to/from jit + a = MyArray(jnp.ones(())) + b = f(a) # don't crash + self.assertIsInstance(b, MyArray) + self.assertAllClose(b.arr, jnp.ones(())) + + # test basic to/from autodiff + b, b_dot = jax.jvp(f, (a,), (a,)) + self.assertIsInstance(b, MyArray) + self.assertIsInstance(b_dot, MyArray) + + # test mul jit and backward pass + + @jax.jit + def f(x): + return mul(x, x) + + b, f_vjp = jax.vjp(f, a) + self.assertIn('MyTy', str(f_vjp)) + a_grad, = f_vjp(b) + self.assertIsInstance(a_grad, MyArray) + self.assertAllClose(a_grad.arr, 2.0, check_dtypes=False) + + +def new_box(): + (), treedef = jax.tree.flatten(None) + return new_box_p.bind(treedef=treedef) + +def box_get(box): + tys = core.cur_qdd(box) + leaf_vals = box_get_p.bind(box, avals=tuple(tys.leaf_avals)) + return jax.tree.unflatten(tys.treedef, leaf_vals) + +def box_set(box, val): + leaves, treedef = jax.tree.flatten(val) + box_set_p.bind(box, *leaves, treedef=treedef) + +@dataclass(frozen=True) +class BoxTypeState(core.QuasiDynamicData): + leaf_avals: tuple[core.AbstractValue, ...] 
+ treedef: PyTreeDef + + def to_tangent_qdd(self): + return BoxTypeState(tuple(a.to_tangent_aval() for a in self.leaf_avals), + self.treedef) + + def normalize(self): + return BoxTypeState(tuple(a.normalize() for a in self.leaf_avals), + self.treedef) + +class BoxTy(core.AbstractValue): + has_qdd = True + + # forwarded to value + get = core.aval_method(box_get) + set = core.aval_method(box_set) + type_state = core.aval_method(core.cur_qdd) + + # aval interface: hashability and str_short + def __hash__(self): return hash(BoxTy) + def __eq__(self, other): return isinstance(other, BoxTy) + + def str_short(self, short_dtypes=False): + return 'BoxTy' + + # mutable interface + def lo_ty_qdd(self, box_state): + return [lo_ty for t in box_state.leaf_avals for lo_ty in t.lo_ty()] + + def new_from_loval(self, box_state: BoxTypeState, *lo_vals): + lo_vals_ = iter(lo_vals) + hi_vals = [hi_ty.raise_val(*it.islice(lo_vals_, len(hi_ty.lo_ty()))) + for hi_ty in box_state.leaf_avals] + assert next(lo_vals_, None) is None + return Box(jax.tree.unflatten(box_state.treedef, hi_vals)) # will be mutated + + def read_loval(self, box_state: BoxTypeState, box): + leaf_vals, treedef = jax.tree.flatten(box_get(box)) + assert treedef == box_state.treedef + return [lo_val for hi_ty, hi_val in zip(box_state.leaf_avals, leaf_vals) + for lo_val in hi_ty.lower_val(hi_val)] + + def update_from_loval(self, box_state: BoxTypeState, box, *lo_vals): + lo_vals_ = iter(lo_vals) + hi_vals = [hi_ty.raise_val(*it.islice(lo_vals_, len(hi_ty.lo_ty()))) + for hi_ty in box_state.leaf_avals] + assert next(lo_vals_, None) is None + box_set(box, jax.tree.unflatten(box_state.treedef, hi_vals)) + + def to_tangent_aval(self): + return BoxTy() + +class Box: # noqa: F811 + def __init__(self, val): + self._val = val + + def get(self): + return box_get(self) + + def set(self, val): + box_set(self, val) + + def cur_qdd(self): + return self.type_state() + + @property + def ty(self): + return BoxTy() + + def type_state(self): + leaves, treedef = jax.tree.flatten(self._val) + leaf_avals = tuple(map(core.typeof, leaves)) + return BoxTypeState(leaf_avals, treedef) +core.pytype_aval_mappings[Box] = lambda b: b.ty + + +class NewBox(HiPrimitive): + def is_high(self, *, treedef) -> bool: return True + + def abstract_eval(self, *, treedef): + leaves, treedef = jax.tree.flatten(None) + qdd = BoxTypeState(leaves, treedef) + return core.AvalQDD(BoxTy(), qdd), set() + + def to_lojax(_, *, treedef): + return Box(None) + + def jvp(_, primals, tangents, *, treedef): + assert False # TODO + + def transpose(_, *args, treedef): + assert False # TODO +new_box_p = NewBox('new_box') + + +class BoxSet(HiPrimitive): + multiple_results = True + + def is_high(self, *, treedef) -> bool: return True + + def abstract_eval(self, box_ty, *leaf_avals, treedef): + box_ty.mutable_qdd.update(BoxTypeState(leaf_avals, treedef)) + return [], set() # TODO better typechecking... 
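BoxTypeState and the box primitives lean on the pytree flatten/unflatten round trip; for reference, a small sketch using only the public jax.tree API (the example value is arbitrary, not taken from the tests):

import jax

val = {"a": 1.0, "b": [2.0, 3.0]}
leaves, treedef = jax.tree.flatten(val)         # leaves == [1.0, 2.0, 3.0]
restored = jax.tree.unflatten(treedef, leaves)  # rebuilds the original structure
assert restored == val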
+ + def to_lojax(_, box, *leaves, treedef): + box._val = jax.tree.unflatten(treedef, leaves) + return [] + + def jvp(_, primals, tangents, *, treedef): + box, *vals = primals + box_dot, *val_dots = tangents + if type(box_dot) is ad_util.Zero: + raise Exception("you're an idiot") + box_set_p.bind(box, *vals, treedef=treedef) + box_set_p.bind(box_dot, *val_dots, treedef=treedef) + return [], [] + + def transpose(_, *args, treedef): + assert False # TODO +box_set_p = BoxSet('box_set') + + +class BoxGet(HiPrimitive): + multiple_results = True + + def abstract_eval(self, box_ty, *, avals): + return avals, set() + + def to_lojax(_, box, *, avals): + return jax.tree.leaves(box._val) + + def jvp(_, primals, tangents, *, avals): + (box,), (box_dot,) = primals, tangents + return ( + box_get_p.bind(box, avals=avals), + box_get_p.bind(box_dot, avals=tuple(a.to_tangent_aval() for a in avals)) + ) + + def transpose(_, *args): + assert False # TODO +box_get_p = BoxGet('box_get') + + + +class BoxTest(jtu.JaxTestCase): + + @parameterized.parameters([False, True]) + def test_qdd(self, jit): + + val1 = 1.0 + val2 = jnp.arange(3) + + box1 = Box(val1) + + def f(box2): + assert core.cur_qdd(box2).leaf_avals == (core.typeof(val1),) + box2.set(val2) + assert core.cur_qdd(box2).leaf_avals == (core.typeof(val2),) + + box3 = new_box() + box3.set(val2) + assert core.cur_qdd(box3).leaf_avals == (core.typeof(val2),) + box3.set(val1) + assert core.cur_qdd(box3).leaf_avals == (core.typeof(val1),) + + assert core.cur_qdd(box1).leaf_avals == (core.typeof(val1),) + box1.set(val2) + assert core.cur_qdd(box1).leaf_avals == (core.typeof(val2),) + + return + + if jit: + f = jax.jit(f) + + f(Box(val1)) + + def test_jit_arg(self): + @jax.jit + def f(box, x): + assert tracing_ok + box.set(box.get() + x) + + tracing_ok = True + box1 = Box(1.0) + f(box1, 1.) + self.assertAllClose(box1.get(), 2.0) + + tracing_ok = False + box2 = Box(2.0) + f(box2, 2.) + self.assertAllClose(box2.get(), 4.0) + + def test_jit_arg2(self): + # set without get + + @jax.jit + def f(box, x): + box_set(box, x) + + box = Box(0.0) + f(box, 1.) + self.assertAllClose(box_get(box), 1.0, check_dtypes=False) + + def test_jit_arg_in_pytree(self): + @jax.jit + def f(dct, x): + assert tracing_ok + box = dct['box'] + box.set(box.get() + x) + + tracing_ok = True + box1 = Box(1.0) + f({'box': box1, 'a': 1.0}, 1.) + self.assertAllClose(box1.get(), 2.0) + + tracing_ok = False + box2 = Box(2.0) + f({'box': box2, 'a': 2.0}, 2.) + self.assertAllClose(box2.get(), 4.0) + + tracing_ok = True + box3 = Box(3) # int, dtype changed + f({'box': box3, 'a': 2.0}, 2.) 
+ self.assertAllClose(box3.get(), 5.0) + + def test_jit_closure(self): + box = Box(1.0) + + @jax.jit + def f(x): + assert tracing_ok + box.set(box.get() + x) + + tracing_ok = True + f(2.0) + self.assertAllClose(box.get(), 3.0) + tracing_ok = False + f(5.0) + self.assertAllClose(box.get(), 8.0) + + def test_jit_closure_nested(self): + box = Box(5.0) + + @jax.jit + def f(x): + box.set(box.get() + x) + + @jax.jit + def g(x): + f(x) + + g(3.0) + self.assertAllClose(box.get(), 8.0) + + def test_jit_closure_nested2(self): + @jax.jit + def h(x): + box = new_box() + box.set(x) + + @jax.jit + def k(x): + box.set(box.get() + x) + + k(1.0) + k(1.0) + return box.get() + + ans = h(2.0) + self.assertAllClose(ans, 4.0) + + def test_jit_closure_nested3(self): + box = new_box() + + @jax.jit + def h(x): + box.set(x) + + @jax.jit + def k(x): + box.set(box.get() + x) + + k(1.0) + k(1.0) + return box.get() + + ans = h(2.0) + self.assertAllClose(ans, 4.0) + + @parameterized.parameters([False, True]) + def test_jvp_closure_stop_gradient(self, jit): + box = Box(1.0) + + def f(x): + y = 2 * x + box.set(box.get() + jax.lax.stop_gradient(y)) + return y + + if jit: + f = jax.jit(f) + + y, y_dot = jax.jvp(f, (1.0,), (1.0,)) + self.assertAllClose(y, 2.0) + self.assertAllClose(y_dot, 2.0) + self.assertAllClose(box.get(), 3.0) + + @parameterized.parameters([False, True]) + def test_jvp_arg(self, jit): + def f(box, x): + box.set(box.get() + x) + return x + + if jit: + f = jax.jit(f) + + box = Box(5.0) + box_dot = Box(1.0) + y, y_dot = jax.jvp(f, (box, 2.), (box_dot, 1.)) + self.assertAllClose(y, 2.0) + self.assertAllClose(y_dot, 1.0) + self.assertAllClose(box.get(), 7.0) + self.assertAllClose(box_dot.get(), 2.0) + + @parameterized.parameters([False, True]) + def test_custom_vjp_plumbing(self, jit): + box = Box(0.0) + + @jax.custom_vjp + def foo(x): + return x + def foo_fwd(x): + return foo(x), None + def foo_bwd(_, g): + box.set(g) + return g, + foo.defvjp(foo_fwd, foo_bwd) + + def f(x): + x = 2 * x + x = foo(x) + x = 2 * x + return x + + if jit: + f = jax.jit(f) + + jax.grad(f)(1.0) + + self.assertAllClose(box.get(), 2.0) + + # TODO(mattjj,dougalm): make this work... 
+ # @parameterized.parameters([False, True]) + # def test_custom_vjp_plumbing_abstracted(self, jit): + # box = Box(0.0) + + # @jax.custom_vjp + # def foo(box, x): + # return x + # def foo_fwd(box, x): + # return x, box + # def foo_bwd(box, g): + # box.set(g) + # return None, g + # foo.defvjp(foo_fwd, foo_bwd) + + # def f(box, x): + # x = 2 * x + # x = foo(box, x) + # x = 2 * x + # return x + + # if jit: + # f = jax.jit(f) + + # jax.grad(partial(f, box))(1.0) + # self.assertAllClose(box.get(), 2.0) + + @parameterized.parameters([False, True]) + def test_grad_closure_stop_gradient(self, jit): + box = Box(0.0) + + def f(x): + y = x * 2 + box.set(box.get() + jax.lax.stop_gradient(y)) + return y + + if jit: + f = jax.jit(f) + + g = jax.grad(f)(1.0) + self.assertAllClose(g, 2.0) + self.assertAllClose(box.get(), 2.0) + + @parameterized.parameters([False, True]) + def test_scan_basic(self, jit): + box = Box(1.0) + + def double_it_10(): + def body(_, __): + box.set(box.get() * 2) + return None, None + _, _ = jax.lax.scan(body, None, None, length=10) + + if jit: + double_it_10 = jax.jit(double_it_10) + + double_it_10() + + self.assertAllClose(box.get(), 1024., check_dtypes=False) + + # TODO error-checking tests from attrs_test.py + + ### + + def test_box_autodiff(self): + if config.enable_x64.value: raise unittest.SkipTest("no x64") + + class StashTangents(HiPrimitive): + def is_high(self): + return True + + def abstract_eval(_, box_aval, x_aval): + del box_aval + return x_aval, set() + + def to_lojax(_, box, x): + return x + + def jvp(_, primals, tangents): + box, x = primals + _, x_dot = tangents + box_set(box, x_dot) + return x, x_dot + + def transpose(self, *args): + assert False # TODO + stash_tangents_p = StashTangents('stash_tangents') + + def stash_tangents(box, x): + return stash_tangents_p.bind(box, x) + + @jax.jit + def f(box, x): + x = stash_tangents(box, x) + return x + + box = Box(0.0) + jax.jvp(partial(f, box), (3.,), (5.,)) + self.assertAllClose(box_get(box), 5.0, check_dtypes=False) + + def test_type_changing_box(self): + box = Box(jnp.arange(1)) + box_set(box, jnp.arange(2)) + self.assertLen(box._val, 2) + + @jax.jit + def f(box, x): + box_set(box, x) + + f(box, jnp.arange(3)) + self.assertLen(box._val, 3) + f(box, jnp.arange(4)) + self.assertLen(box._val, 4) + + def test_pytree_box(self): + box = Box(None) + + @jax.jit + def f(box, x): + assert tracing_ok + val = box_get(box) + if val is None: + box_set(box, x) + else: + box_set(box, [x, x]) + + tracing_ok = True + f(box, 1.0) + self.assertAllClose(box_get(box), 1.0, check_dtypes=False) + f(box, 2.0) + self.assertAllClose(box_get(box), [2.0, 2.0], check_dtypes=False) + f(box, 3.0) + self.assertAllClose(box_get(box), [3.0, 3.0], check_dtypes=False) + tracing_ok = False + f(box, 4.0) + self.assertAllClose(box_get(box), [4.0, 4.0], check_dtypes=False) + + def test_pytree_of_hijaxtypes_box(self): + + @dataclass(frozen=True) + class MyArray: + arr: jax.Array # always f32 + + @dataclass(frozen=True) + class MyTy(core.AbstractValue): + has_qdd = False + + def to_tangent_aval(self): + return MyTy() + def str_short(self, short_dtypes=False): + return 'MyTy' + def lo_ty(self): + return [core.ShapedArray((), jnp.dtype('float32'))] + def lower_val(self, hi_val: MyArray) -> list[jax.Array]: + return [hi_val.arr] + def raise_val(self, val) -> MyArray: + return MyArray(val) + + def __eq__(self, other): return isinstance(other, MyTy) + + core.pytype_aval_mappings[MyArray] = lambda _: MyTy() + + box = Box([MyArray(jnp.float32(1)), + 
MyArray(jnp.float32(2))]) + + @jax.jit + def f(box): + a, b = box_get(box) + box_set(box, [b, a]) + + f(box) + val = box_get(box) + self.assertIsInstance(val, list) + self.assertLen(val, 2) + b_, a_ = val + self.assertIsInstance(a_, MyArray) + self.assertIsInstance(b_, MyArray) + self.assertAllClose(a_.arr, 1, check_dtypes=False) + self.assertAllClose(b_.arr, 2, check_dtypes=False) + + +class ListTy(core.AbstractValue): + has_qdd = True + + # forwarded to value + get = core.aval_method(box_get) + set = core.aval_method(box_set) + + # aval interface: hashability and str_short + def __hash__(self): return hash(BoxTy) + def __eq__(self, other): return isinstance(other, BoxTy) + + def str_short(self, short_dtypes=False): + return 'ListTy' + + # TODO + +class ListTest(jtu.JaxTestCase): + ... + + + +if __name__ == '__main__': + absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/infeed_test.py b/tests/infeed_test.py index 060502ae68cd..08052b041bae 100644 --- a/tests/infeed_test.py +++ b/tests/infeed_test.py @@ -36,7 +36,9 @@ def setUp(self): raise SkipTest("infeed not implemented in PJRT C API") super().setUp() - @jax.numpy_rank_promotion("allow") # Test explicitly exercises implicit rank promotion. + @jax.numpy_rank_promotion( + "allow" + ) # Test explicitly exercises implicit rank promotion. def testInfeed(self): raise SkipTest("skipping temporarily for stackless") @@ -44,13 +46,17 @@ def testInfeed(self): def f(x): token = lax.create_token(x) (y,), token = lax.infeed( - token, shape=(core.ShapedArray((3, 4), jnp.float32),)) + token, shape=(core.ShapedArray((3, 4), jnp.float32),) + ) (z,), _ = lax.infeed( - token, shape=(core.ShapedArray((3, 1, 1), jnp.float32),)) + token, shape=(core.ShapedArray((3, 1, 1), jnp.float32),) + ) return x + y + z x = np.float32(1.5) - y = np.reshape(np.arange(12, dtype=np.float32), (3, 4)) # self.rng().randn(3, 4).astype(np.float32) + y = np.reshape( + np.arange(12, dtype=np.float32), (3, 4) + ) # self.rng().randn(3, 4).astype(np.float32) z = self.rng().randn(3, 1, 1).astype(np.float32) device = jax.local_devices()[0] device.transfer_to_infeed((y,)) @@ -63,8 +69,11 @@ def testInfeedPytree(self): x = np.float32(1.5) y = np.reshape(np.arange(12, dtype=np.int16), (3, 4)) to_infeed = dict(a=x, b=y) - to_infeed_shape = dict(a=core.ShapedArray((), dtype=np.float32), - b=core.ShapedArray((3, 4), dtype=np.int16)) + to_infeed_shape = dict( + a=core.ShapedArray((), dtype=np.float32), + b=core.ShapedArray((3, 4), dtype=np.int16), + ) + @jax.jit def f(x): token = lax.create_token(x) @@ -77,14 +86,18 @@ def f(x): device.transfer_to_infeed(tuple(flat_to_infeed)) self.assertAllClose(f(x), to_infeed) - @jax.numpy_rank_promotion("allow") # Test explicitly exercises implicit rank promotion. + @jax.numpy_rank_promotion( + "allow" + ) # Test explicitly exercises implicit rank promotion. 
+ @jtu.ignore_warning( + category=DeprecationWarning, message=".*(infeed|outfeed) was deprecated.*" + ) def testInfeedThenOutfeed(self): @jax.jit def f(x): token = lax.create_token(x) - y, token = lax.infeed( - token, shape=core.ShapedArray((3, 4), jnp.float32)) + y, token = lax.infeed(token, shape=core.ShapedArray((3, 4), jnp.float32)) token = lax.outfeed(token, y + np.float32(1)) return x - 1 @@ -94,16 +107,21 @@ def f(x): execution.start() device = jax.local_devices()[0] device.transfer_to_infeed((y,)) - out, = device.transfer_from_outfeed( - xla_client.shape_from_pyval((y,)).with_major_to_minor_layout_if_absent()) + out = device.transfer_from_outfeed( + xla_client.Shape.array_shape( + xla_client.PrimitiveType.F32, (3, 4) + ).with_major_to_minor_layout_if_absent() + ) execution.join() self.assertAllClose(out, y + np.float32(1)) + @jtu.ignore_warning( + category=DeprecationWarning, message=".*(infeed|outfeed) was deprecated.*" + ) def testInfeedThenOutfeedInALoop(self): def doubler(_, token): - y, token = lax.infeed( - token, shape=core.ShapedArray((3, 4), jnp.float32)) + y, token = lax.infeed(token, shape=core.ShapedArray((3, 4), jnp.float32)) return lax.outfeed(token, y * np.float32(2)) @jax.jit @@ -119,11 +137,14 @@ def f(n): for _ in range(n): x = self.rng().randn(3, 4).astype(np.float32) device.transfer_to_infeed((x,)) - y, = device.transfer_from_outfeed(xla_client.shape_from_pyval((x,)) - .with_major_to_minor_layout_if_absent()) + y = device.transfer_from_outfeed( + xla_client.Shape.array_shape( + xla_client.PrimitiveType.F32, (3, 4) + ).with_major_to_minor_layout_if_absent() + ) self.assertAllClose(y, x * np.float32(2)) execution.join() -if __name__ == '__main__': +if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/jax_jit_test.py b/tests/jax_jit_test.py index 5946d557d4ba..cbf7c710f0e8 100644 --- a/tests/jax_jit_test.py +++ b/tests/jax_jit_test.py @@ -227,6 +227,25 @@ def fn(x): self.assertArraysEqual(v1, v1_expected) self.assertArraysEqual(v2, v2_expected) + def test_check_for_large_number_of_constants(self): + y = jnp.ones((128, 128)) + x = jnp.zeros((128,)) + + def jit_maker(): # need to ensure we lower at each test + def func(x): + return x @ y + return jax.jit(func) + + with self.assertWarnsRegex(UserWarning, "A large amount of constants were captured during lowering"): + with config.captured_constants_warn_bytes(y.nbytes): + jit_maker()(x) + + with self.assertNoWarnings(): + with config.captured_constants_warn_bytes(y.nbytes + 1): + jit_maker()(x) + + with config.captured_constants_warn_bytes(-1): + jit_maker()(x) if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/jax_numpy_error_test.py b/tests/jax_numpy_error_test.py new file mode 100644 index 000000000000..dba277289f42 --- /dev/null +++ b/tests/jax_numpy_error_test.py @@ -0,0 +1,278 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
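The new tests below drive the error-checking machinery through internal modules (jax._src.error_check and jax._src.numpy.error, both subject to change). A minimal sketch of the pattern the tests rely on, with an illustrative NaN-producing call:

import jax.numpy as jnp
from jax._src import error_check
from jax._src.numpy import error as jnp_error

with jnp_error.error_checking_behavior(nan="raise"):
    _ = jnp.log(jnp.float32(-1.0))    # produces a NaN, which records an error
    try:
        error_check.raise_if_error()  # raises JaxValueError; the message mentions "NaN"
    except error_check.JaxValueError:
        pass                          # handle or report the recorded error here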
+ +import operator + +from absl.testing import absltest +from absl.testing import parameterized +import jax +from jax._src import config +from jax._src import error_check +from jax._src import test_util as jtu +from jax._src.numpy import error as jnp_error +import jax.numpy as jnp + +config.parse_flags_with_absl() + + +JaxValueError = error_check.JaxValueError + + +class JaxNumpyErrorTests(jtu.JaxTestCase): + def setUp(self): + # TODO(b/408148001): Fix thread safety issue. + if jtu.TEST_NUM_THREADS.value > 1: + self.skipTest("Test does not work with multiple threads") + super().setUp() + + @parameterized.product(jit=[True, False]) + def test_set_error_if_nan(self, jit): + def f(x): + jnp_error._set_error_if_nan(x) + return x + + if jit: + f = jax.jit(f) + + x = jnp.full((4,), jnp.nan, dtype=jnp.float32) + + with jnp_error.error_checking_behavior(nan="ignore"): + _ = f(x) + error_check.raise_if_error() # should not raise error + + with jnp_error.error_checking_behavior(nan="raise"): + _ = f(x) + with self.assertRaisesRegex(JaxValueError, "NaN"): + error_check.raise_if_error() + + @parameterized.product(jit=[True, False]) + def test_set_error_if_divide_by_zero(self, jit): + def f(x, y): + jnp_error._set_error_if_divide_by_zero(y) + return x / y + + if jit: + f = jax.jit(f) + + x = jnp.arange(4, dtype=jnp.float32) + 1 + y = jnp.arange(4, dtype=jnp.float32) + + with jnp_error.error_checking_behavior(divide="ignore"): + _ = f(x, y) + error_check.raise_if_error() # should not raise error + + with jnp_error.error_checking_behavior(divide="raise"): + _ = f(x, y) + with self.assertRaisesRegex(JaxValueError, "Division by zero"): + error_check.raise_if_error() + + @parameterized.product(jit=[True, False]) + def test_error_category_oob_check(self, jit): + def f(x, start_indices, slice_sizes): + jnp_error._set_error_if_with_category( + jnp.logical_or( + start_indices < 0, + start_indices + jnp.array(slice_sizes, dtype=jnp.int32) + >= jnp.array(x.shape, dtype=jnp.int32), + ), + "Out of bounds in dynamic_slice", + category="oob", + ) + y = jax.lax.dynamic_slice( + x, start_indices, slice_sizes, allow_negative_indices=False + ) + return y + + if jit: + f = jax.jit(f, static_argnums=(2,)) + + x = jnp.arange(12).reshape(3, 4) + start_indices = jnp.array([0, -1], dtype=jnp.int32) + slice_sizes = (3, 4) + + with jnp_error.error_checking_behavior(oob="ignore"): + _ = f(x, start_indices, slice_sizes) + error_check.raise_if_error() # should not raise error + + with jnp_error.error_checking_behavior(oob="raise"): + _ = f(x, start_indices, slice_sizes) + with self.assertRaisesRegex( + JaxValueError, "Out of bounds in dynamic_slice", + ): + error_check.raise_if_error() + + def test_error_category_invalid_category(self): + with self.assertRaisesRegex(ValueError, "Invalid category"): + jnp_error._set_error_if_with_category( + jnp.isnan(jnp.float32(1.0)), "x is NaN", category="invalid" + ) + + @staticmethod + def nan_cases(cases): + for jit in (True, False): + for func, args_error, args_no_err in cases: + if not isinstance(args_error, tuple): + args_error = (args_error,) + if not isinstance(args_no_err, tuple): + args_no_err = (args_no_err,) + + jit_str = "jit" if jit else "nojit" + func_str = f"{func.__module__}.{func.__name__}" + name = f"_{jit_str}_{func_str}" + + yield name, jit, func, args_error, args_no_err + + @parameterized.named_parameters( + nan_cases(( + # List of all NaN-producing jax.numpy functions. 
+ # The first group of numbers is the input that will produce a NaN, and + # the second group is the input that will not produce a NaN. + # go/keep-sorted start + (jnp.acos, 2.0, 0.5), + (jnp.acosh, 0.5, 2.0), + (jnp.add, (jnp.inf, -jnp.inf), (0.0, 0.0)), + (jnp.arccos, 2.0, 0.5), + (jnp.arccosh, 0.5, 2.0), + (jnp.arcsin, -2.0, 0.5), + (jnp.arctanh, -2.0, 0.5), + (jnp.asin, -2.0, 0.5), + (jnp.atanh, -2.0, 0.5), + (jnp.cos, jnp.inf, 1.0), + (jnp.divide, (0.0, 0.0), (1.0, 1.0)), + (jnp.divmod, (1.0, 0.0), (1.0, 1.0)), + (jnp.float_power, (-1.0, 0.5), (1.0, 1.0)), + (jnp.fmod, (1.0, 0.0), (1.0, 1.0)), + (jnp.log, -1.0, 1.0), + (jnp.log10, -1.0, 1.0), + (jnp.log1p, -1.5, 1.0), + (jnp.log2, -1.0, 1.0), + (jnp.mod, (1.0, 0.0), (1.0, 1.0)), + (jnp.pow, (-1.0, 0.5), (1.0, 1.0)), + (jnp.power, (-1.0, 0.5), (1.0, 1.0)), + (jnp.remainder, (1.0, 0.0), (1.0, 1.0)), + (jnp.sin, jnp.inf, 1.0), + # TODO(https://github.com/jax-ml/jax/issues/27470): Not yet supported. + # (jnp.sinc, jnp.inf, 1.0), + (jnp.sqrt, -4.0, 4.0), + (jnp.subtract, (jnp.inf, jnp.inf), (0.0, 0.0)), + (jnp.tan, jnp.inf, 1.0), + (jnp.true_divide, (0.0, 0.0), (1.0, 1.0)), + (operator.add, (jnp.inf, -jnp.inf), (0.0, 0.0)), + (operator.mod, (1.0, 0.0), (1.0, 1.0)), + (operator.pow, (-1.0, 0.5), (1.0, 1.0)), + (operator.sub, (jnp.inf, jnp.inf), (0.0, 0.0)), + (operator.truediv, (0.0, 0.0), (1.0, 1.0)), + # go/keep-sorted end + )) + ) + def test_can_raise_nan_error(self, jit, f, args_err, args_no_err): + args_err = [jnp.float32(x) for x in args_err] + args_no_err = [jnp.float32(x) for x in args_no_err] + + if jit: + f = jax.jit(f) + + with jnp_error.error_checking_behavior(nan="raise"): + f(*args_no_err) + error_check.raise_if_error() # should not raise error + + f(*args_err) + with self.assertRaisesRegex(JaxValueError, "NaN"): + error_check.raise_if_error() + + INT_TYPES = (jnp.int32, jnp.uint32, jnp.int64, jnp.uint64, jnp.int16, + jnp.uint16, jnp.int8, jnp.uint8) + FLOAT_TYPES = (jnp.float32, jnp.float64, jnp.float16, jnp.bfloat16) + + @staticmethod + def divide_cases(cases): + for jit in (True, False): + for func, dtypes in cases: + for dtype in dtypes: + jit_str = "jit" if jit else "nojit" + func_str = f"{func.__module__}.{func.__name__}" + dtype_str = dtype.__name__ + name = f"_{jit_str}_{func_str}_{dtype_str}" + yield name, jit, func, dtype + + @parameterized.named_parameters( + divide_cases(( + # go/keep-sorted start + (jnp.divmod, FLOAT_TYPES + INT_TYPES), + (jnp.floor_divide, INT_TYPES), + (jnp.mod, FLOAT_TYPES + INT_TYPES), + (jnp.remainder, FLOAT_TYPES + INT_TYPES), + (jnp.true_divide, FLOAT_TYPES), + (operator.mod, FLOAT_TYPES + INT_TYPES), + (operator.truediv, FLOAT_TYPES), + # go/keep-sorted end + )) + ) + def test_can_raise_divide_by_zero_error(self, jit, div_func, dtype): + if not jax.config.x64_enabled and jnp.dtype(dtype).itemsize == 8: + self.skipTest("64-bit types require x64_enabled") + + args_err = (dtype(1), dtype(0)) + args_no_err = (dtype(1), dtype(1)) + + if jit: + div_func = jax.jit(div_func) + + with jnp_error.error_checking_behavior(divide="raise"): + div_func(*args_no_err) + error_check.raise_if_error() # should not raise error + + div_func(*args_err) + with self.assertRaisesRegex(JaxValueError, "Division by zero"): + error_check.raise_if_error() + + @parameterized.product(jit=[True, False]) + def test_can_raise_oob_error_take(self, jit): + def f(x, a): + return x[a] + + if jit: + f = jax.jit(f) + + x = jnp.arange(10) + a = jnp.int32(10) + + with jnp_error.error_checking_behavior(oob="ignore"): + f(x, a) + 
error_check.raise_if_error() # should not raise error + + with jnp_error.error_checking_behavior(oob="raise"): + f(x, a) + with self.assertRaisesRegex(JaxValueError, "Out of bounds"): + error_check.raise_if_error() + + def test_can_raise_oob_error_dynamic_slice(self): + def f(x, a): + return x[:, a:a+4] # dynamic indices are non-jittable + + x = jnp.arange(10).reshape(2, 5) + a = jnp.array(3, dtype=jnp.int32) + + with jnp_error.error_checking_behavior(oob="ignore"): + f(x, a) + error_check.raise_if_error() # should not raise error + + with jnp_error.error_checking_behavior(oob="raise"): + f(x, a) + with self.assertRaisesRegex(JaxValueError, "Out of bounds"): + error_check.raise_if_error() + + +if __name__ == "__main__": + absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/jax_to_ir_test.py b/tests/jax_to_ir_test.py index f600a08f5dc4..4eb8190b712f 100644 --- a/tests/jax_to_ir_test.py +++ b/tests/jax_to_ir_test.py @@ -114,15 +114,13 @@ def test_parse_shape_str(self): self.assertParsedShape('f32[]', [], jnp.float32) self.assertParsedShape('f32[1,2,3]', [1, 2, 3], jnp.float32) self.assertParsedShape('pred[1]', [1], jnp.bool_) - if hasattr(jnp, 'int2'): - self.assertParsedShape('s2[1]', [1], jnp.int2) + self.assertParsedShape('s2[1]', [1], jnp.int2) self.assertParsedShape('s4[1]', [1], jnp.int4) self.assertParsedShape('s8[1]', [1], jnp.int8) self.assertParsedShape('s16[1]', [1], jnp.int16) self.assertParsedShape('s32[1]', [1], jnp.int32) self.assertParsedShape('s64[1]', [1], jnp.int64) - if hasattr(jnp, 'uint2'): - self.assertParsedShape('u2[1]', [1], jnp.uint2) + self.assertParsedShape('u2[1]', [1], jnp.uint2) self.assertParsedShape('u4[1]', [1], jnp.uint4) self.assertParsedShape('u8[1]', [1], jnp.uint8) self.assertParsedShape('u16[1]', [1], jnp.uint16) diff --git a/tests/jaxpr_effects_test.py b/tests/jaxpr_effects_test.py index c331bfaf438a..420aa642a1d6 100644 --- a/tests/jaxpr_effects_test.py +++ b/tests/jaxpr_effects_test.py @@ -527,7 +527,7 @@ def log_value(x): @jax.jit def f(x): - return callback_p.bind(x, callback=log_value, effect=log_effect, out_avals=[]) + return callback_p.bind(x, callback=log_value, effect=log_effect, out_avals=()) f(2.) jax.effects_barrier() @@ -552,11 +552,11 @@ def f(x): # Expensive computation x = x.dot(x) x = jnp.log(x.sum()) - return callback_p.bind(x, callback=log_value, effect=log_effect, out_avals=[]) + return callback_p.bind(x, callback=log_value, effect=log_effect, out_avals=()) @jax.jit def g(x): - return callback_p.bind(x, callback=log_value, effect=log_effect, out_avals=[]) + return callback_p.bind(x, callback=log_value, effect=log_effect, out_avals=()) x = jax.device_put(jnp.ones((500, 500)), jax.devices()[0]) y = jax.device_put(3., jax.devices()[1]) @@ -579,7 +579,7 @@ def f(x): # Runs in a thread. 
res = jax.jit( lambda x: callback_p.bind( - x, callback=_noop, effect=log_effect, out_avals=[]) + x, callback=_noop, effect=log_effect, out_avals=()) )(x) tokens.append(dispatch.runtime_tokens.current_tokens[log_effect]) return res @@ -635,7 +635,7 @@ def log_value(x): @jax.pmap def f(x): callback_p.bind( - x, callback=log_value, effect=unordered_log_effect, out_avals=[]) + x, callback=log_value, effect=unordered_log_effect, out_avals=()) return x + 1 f(jnp.arange(2)).block_until_ready() jax.effects_barrier() @@ -947,7 +947,7 @@ def make_fun(index): def f(x): def body(y): input_effect(x, y, index=index) - return y + return 2 * y lax.while_loop(lambda _: True, body, y) return f jaxpr = jax.make_jaxpr(make_fun(0))(0) @@ -959,7 +959,7 @@ def body(y): def f(x): def body(y): input_effect(x, y, index=1) - return y + return 2 * y lax.while_loop(lambda _: (x > 0).all(), body, y) jaxpr = jax.make_jaxpr(f)(0) self.assertIn(InputEffect(0), jaxpr.effects) diff --git a/tests/lax_autodiff_test.py b/tests/lax_autodiff_test.py index a69f44f37754..a6398e402df9 100644 --- a/tests/lax_autodiff_test.py +++ b/tests/lax_autodiff_test.py @@ -28,7 +28,6 @@ from jax import dtypes from jax import lax from jax._src import test_util as jtu -from jax._src.util import NumpyComplexWarning from jax.test_util import check_grads jax.config.parse_flags_with_absl() @@ -205,14 +204,16 @@ class LaxAutodiffTest(jtu.JaxTestCase): )) def testOpGrad(self, op, rng_factory, shapes, dtype, order, tol): rng = rng_factory(self.rng()) - if jtu.test_device_matches(["cpu"]): + if jtu.test_device_matches(["cpu", "tpu"]): if op is lax.cosh and dtype == np.complex64: - tol = 3e-1 # 2nd-order gradients are noisy on CPU + tol = 3e-1 # 2nd-order gradients are noisy on CPU and TPU if jtu.test_device_matches(["tpu"]): if op is lax.pow: raise SkipTest("pow grad imprecise on tpu") if op is lax.cos: order = 1 # 2nd-order gradient is imprecise on TPU. + if op is lax.sin: + order = 1 # 2nd-order gradient is imprecise on TPUv5p. if op is lax.log: order = 1 # 2nd-order gradient is imprecise on TPU. @@ -242,7 +243,7 @@ def testConvertElementTypeGrad(self, from_dtype, to_dtype): jtu.tolerance(from_dtype, jtu.default_gradient_tolerance)) args = (rng((2, 3), from_dtype),) convert_element_type = lambda x: lax.convert_element_type(x, to_dtype) - convert_element_type = jtu.ignore_warning(category=NumpyComplexWarning)( + convert_element_type = jtu.ignore_warning(category=np.exceptions.ComplexWarning)( convert_element_type) check_grads(convert_element_type, args, 2, ["fwd", "rev"], tol, tol, eps=1.) 
diff --git a/tests/lax_control_flow_test.py b/tests/lax_control_flow_test.py index 3871a87a7a3e..42bc953a8236 100644 --- a/tests/lax_control_flow_test.py +++ b/tests/lax_control_flow_test.py @@ -33,12 +33,14 @@ from jax import random from jax._src import test_util as jtu from jax import tree_util -from jax._src.util import unzip2 +from jax._src.util import unzip2, split_list from jax.ad_checkpoint import checkpoint as new_checkpoint, checkpoint_policies import jax.numpy as jnp # scan tests use numpy import jax.scipy as jsp +from jax._src import dispatch from jax._src.lax import control_flow as lax_control_flow from jax._src.lax.control_flow import for_loop +from jax._src.interpreters import batching from jax._src.interpreters import mlir jax.config.parse_flags_with_absl() @@ -137,6 +139,36 @@ def scan_reference(f, init, xs): lambda ctx, x: mlir.hlo.CustomCallOp( [x.type], [x], call_target_name=mlir.ir.StringAttr.get("__testing_non_existent_custom_call")).results) +batching.primitive_batchers[prim_non_existent_custom_call] = ( + lambda batched_args, batch_dims: (prim_non_existent_custom_call.bind(batched_args[0]), + batch_dims[0])) + +# A JAX primitive that triggers error when lowering on unintended platforms +prim_with_lowering_error = core.Primitive("__testing_prim_with_lowering_error") +prim_with_lowering_error.def_abstract_eval(lambda x_aval, **_: x_aval) +def prim_with_lowering_error_lowering(platform: str, + ctx: mlir.LoweringRuleContext, x, *, + only_on: str): + if platform != only_on: + raise ValueError(f"prim_with_lowering_error with only_on={only_on} lowered for {platform}") + return mlir.hlo.SineOp(x).results +def prim_with_lowering_error_batch_rule(batched_args, batch_dims, **params): + xs, = batched_args + xs_bdim, = batch_dims + return prim_with_lowering_error.bind(xs, **params), xs_bdim + +batching.primitive_batchers[prim_with_lowering_error] = prim_with_lowering_error_batch_rule + +mlir.register_lowering( + prim_with_lowering_error, + partial(prim_with_lowering_error_lowering, "cpu"), + platform="cpu") +mlir.register_lowering( + prim_with_lowering_error, + partial(prim_with_lowering_error_lowering, "tpu"), + platform="tpu") +prim_with_lowering_error.def_impl(partial(dispatch.apply_primitive, + prim_with_lowering_error)) class LaxControlFlowTest(jtu.JaxTestCase): @@ -588,7 +620,6 @@ def test_fori_loop_returns_init_with_nonpositive_length( init = jnp.float32(10) self.assertEqual(fori_loop_with_static_upper_and_lower(init), init) - def testForiLoopBatched(self): def body_fun(i, loop_carry): x, y = loop_carry @@ -994,16 +1025,24 @@ def testCondTypeErrors(self): with self.assertRaisesRegex(TypeError, re.escape("Pred must be a scalar, got (1.0, 1.0) of type ")): lax.cond((1., 1.), lambda top: 2., lambda fop: 3., 1.) - with self.assertRaisesRegex(TypeError, - re.compile("true_fun output must have same type structure " - "as false_fun output, but there are differences:.*" - r"at output\['a'\], true_fun output has pytree leaf", re.DOTALL)): + + with self.assertRaisesRegex( + TypeError, + re.compile( + r"cond branch outputs must have the same pytree structure, but they" + r" differ:.*true_fun output at path \['a'\] is a pytree leaf but" + r" false_fun output at path \['a'\] is a ", + re.DOTALL)): lax.cond(True, lambda top: dict(a=2.), lambda fop: dict(a=(3., 3.)), 1.) + with self.assertRaisesRegex( TypeError, - "true_fun output and false_fun output must have identical types, got\n" - r"DIFFERENT ShapedArray\(float32\[1\]\) vs. 
" - r"ShapedArray\(float32\[\].*\)."): + re.compile( + r"cond branches must have equal output types but they differ.*The" + r" output of true_fun has type float32\[1\] but the corresponding" + r" output of false_fun has type float32\[\], so the shapes do not" + r" match", + re.DOTALL)): lax.cond(True, lambda top: jnp.array([1.], jnp.float32), lambda fop: jnp.float32(1.), @@ -1023,16 +1062,26 @@ def testSwitchErrors(self): with self.assertRaisesRegex(ValueError, re.escape("Empty branch sequence")): lax.switch(0, [], 1.) - with self.assertRaisesRegex(TypeError, - re.compile("branch 0 output must have same type structure " - "as branch 1 output, but there are differences:.*" - r"at output\['a'\], branch 0 output has pytree leaf", re.DOTALL)): + + with self.assertRaisesRegex( + TypeError, + re.compile( + "switch branch outputs must have the same pytree structure, but" + r" they differ.*branch 0 output at path \['a'\] is a pytree leaf" + r" but branch1 output at path \['a'\] is a , so" + r" their" + " Python types differ.", + re.DOTALL)): lax.switch(1, [lambda _: dict(a=2.), lambda _: dict(a=(3., 3.))], 1.) + with self.assertRaisesRegex( TypeError, - "branch 0 output and branch 1 output must have identical types, got\n" - r"{'a': 'DIFFERENT ShapedArray\(float32\[1\]\) " - r"vs. ShapedArray\(float32\[\].*\)'}."): + re.compile( + "switch branches must have equal output types but they differ.*The" + r" output of branch 0 at path \['a'\] has type float32\[1\] but the" + r" corresponding output of branch1 has type float32\[\], so the" + " shapes do not match", + re.DOTALL)): lax.switch(1, [lambda _: dict(a=jnp.array([1.], jnp.float32)), lambda _: dict(a=jnp.float32(1.))], 1.) @@ -1309,6 +1358,34 @@ def f(x): self.assertAllClose(ans, expected, check_dtypes=False) jtu.check_grads(f, (x,), order=2, modes=["fwd", "rev"]) + @parameterized.parameters(itertools.product(range(4), repeat=3)) + @jtu.run_on_devices("cpu") + def testSwitchGradWithForwarding(self, seed, num_input_fwd, num_output_fwd): + num_args = 3 + num_branches = 4 + rng = np.random.RandomState(seed) + in_perm = rng.permutation(num_args) + out_perm = rng.permutation(num_args) + + def branch(s, inputs): + inputs = [inputs[i] for i in in_perm] + outputs = inputs[:num_input_fwd] + [ + s * jnp.exp(inputs[i]) if i < num_output_fwd else jnp.sin(inputs[i]) + for i in range(num_args - num_input_fwd)] + return [outputs[i] for i in out_perm] + + branches = [partial(branch, i) for i in range(num_branches)] + + @jax.jit + def f_(idx, inputs): + idx = lax.convert_element_type(idx // 1, np.int32) + return lax.switch(idx, branches, inputs) + + for idx in range(num_branches): + f = partial(f_, idx) + jtu.check_grads(f, (jnp.arange(float(num_args)),), + order=1, modes=['fwd', 'rev'], atol=1e-2, rtol=1e-2) + def testSwitchGradWithWeakTypeMismatch(self): # issue #4696, PR #4896 dtype = dtypes.canonicalize_dtype(np.float64) dtype = jnp.float32 if dtype == jnp.float32 else jnp.float64 @@ -1333,7 +1410,7 @@ def f(x): @parameterized.named_parameters( {"testcase_name": f"_{name}", "cond": cond} for cond, name in COND_IMPLS) - def testCondGrad2(self, cond): + def testCondGrad2(self, cond=cond_with_new_checkpoint): def f_ref(x): z = jnp.array([1., 2.], x.dtype) * x if x[0] < 2 else jnp.sin(x) return z.sum() @@ -1725,15 +1802,20 @@ def f(c, a): c = rng.randn(4) if scan is scan_with_new_checkpoint2: + atol = {} rtol = {np.float64: 1e-12, np.float32: 1e-4} elif scan is scan_with_for: + atol = {} rtol = {np.float64: 1e-12, np.float32: 1e-4} else: + atol = {np.float64: 
1e-14} rtol = {np.float64: 1e-14, np.float32: 1e-4} ans = jax.linearize(lambda c, as_: scan(f, c, as_), c, as_)[1](c, as_) expected = jax.linearize(lambda c, as_: scan_reference(f, c, as_), c, as_)[1](c, as_) - self.assertAllClose(ans, expected, check_dtypes=False, rtol=rtol) + self.assertAllClose( + ans, expected, check_dtypes=False, atol=atol, rtol=rtol + ) @parameterized.named_parameters( {"testcase_name": f"_{jit_scan=}_{jit_f=}_impl={scan_name}", @@ -1896,7 +1978,7 @@ def plus_one(p, iter_idx): def testScanBodyOutputError(self): with self.assertRaisesRegex( TypeError, - re.escape("scan body output must be a pair, got ShapedArray(float32[]).")): + re.escape("scan body output must be a pair, got float32[].")): lax.scan(lambda c, x: np.float32(0.), 0, jnp.arange(5.)) def testScanMetadataError(self): @@ -1955,7 +2037,7 @@ def testScanBodyCarryTypeMismatchErrors(self): with self.assertRaisesRegex( TypeError, re.escape("function carry input and carry output must have equal " - "types (e.g. shapes and dtypes of arrays), but they differ:\n\n" + "types, but they differ:\n\n" "The input carry x has type int32[] but the corresponding " "output carry component has type float32[], so the dtypes do " "not match" @@ -1966,7 +2048,7 @@ def testScanBodyCarryTypeMismatchErrors(self): with self.assertRaisesRegex( TypeError, re.escape("function carry input and carry output must have equal " - "types (e.g. shapes and dtypes of arrays), but they differ:\n\n" + "types, but they differ:\n\n" "The input carry component x[1] has type int32[] but the " "corresponding output carry component has type float32[], " "so the dtypes do not match" @@ -1977,13 +2059,13 @@ def testScanBodyCarryTypeMismatchErrors(self): with self.assertRaisesRegex( TypeError, re.escape("function carry input and carry output must have equal " - "types (e.g. shapes and dtypes of arrays), but they differ:\n\n" + "types, but they differ:\n\n" " * the input carry component x[0] has type int32[] but the " "corresponding output carry component has type float32[], " "so the dtypes do not match;\n" " * the input carry component x[1] has type int32[] but the " "corresponding output carry component has type float32[1,1], " - "so the dtypes do not match and also the shapes do not match." + "so the dtypes do not match, and the shapes do not match." 
)): jax.lax.scan(lambda x, _: ((x[0].astype('float32'), x[1].astype('float32').reshape(1, 1), @@ -2192,7 +2274,7 @@ def body(x): def test_caches_depend_on_axis_env(self): # https://github.com/jax-ml/jax/issues/9187 - scanned_f = lambda _, __: (lax.psum(1, 'i'), None) + scanned_f = lambda _, __: (lax.axis_size('i'), None) f = lambda: lax.scan(scanned_f, 0, None, length=1)[0] ans = jax.vmap(f, axis_name='i', axis_size=2, out_axes=None)() self.assertEqual(ans, 2) @@ -2317,7 +2399,7 @@ def testWhileGradError(self, loop: str = "fori_inside_scan"): elif loop == "fori_inside_cond": func = lambda x: lax.cond( True, - x, lambda x: lax.fori_loop(x, x + 2., lambda i, c: c, x), + x, lambda x: lax.fori_loop(x, x + 2., lambda i, c: c * 2., x), 1., lambda x: x) elif loop == "fori_inside_scan": func = lambda x: lax.scan( @@ -2467,7 +2549,7 @@ def f(c, a): self.assertLess(len(scan_unrolled_hlo), len(scan_fully_unrolled_hlo)) # and the lowering should contain a while loop, unless the scan is fully - # unrolled + # unrolled self.assertIn("while(", scan_hlo) self.assertIn("while(", scan_unrolled_hlo) self.assertNotIn("while(", scan_fully_unrolled_hlo) @@ -2758,7 +2840,6 @@ def cond_fun(val): self.assertAllClose(deriv(my_pow)(3.0, 1), 1.0, check_dtypes=False) - def test_while_loop_fixed_point_with_batched_pred_and_consts(self): def f(i, x): def cond(carry): @@ -2861,18 +2942,13 @@ def f(x): x = np.arange(3, dtype=np.float32) lowered = jax.jit(f).lower(x) stablehlo = lowered.as_text() - self.assertIn("stablehlo.case", stablehlo) - self.assertIn("stablehlo.sine", stablehlo) - self.assertIn("stablehlo.cosine", stablehlo) - - # The HLO has been canonicalized and contains only the branch we need - hlo = lowered.as_text("hlo") + # The StableHLO contains only the branch we need if jtu.device_under_test() == "cpu": - self.assertIn(" sine", hlo) - self.assertNotIn(" cosine", hlo) + self.assertIn("stablehlo.sine", stablehlo) + self.assertNotIn("stablehlo.cosine", stablehlo) else: - self.assertNotIn(" sine", hlo) - self.assertIn(" cosine", hlo) + self.assertNotIn("stablehlo.sine", stablehlo) + self.assertIn("stablehlo.cosine", stablehlo) def test_platform_dependent_with_non_existent_custom_call(self): if not jtu.test_device_matches(["cpu"]): @@ -2895,8 +2971,7 @@ def f(x): x = np.arange(3, dtype=np.float32) hlo = str(jax.jit(f).lower(x).compiler_ir()) - occurrences = re.findall(prim_non_existent_custom_call.name, hlo) - self.assertLen(occurrences, 3) + self.assertNotIn(prim_non_existent_custom_call.name, hlo) res_eager = f(x) self.assertAllClose(res_eager, 3. * np.sin(x)) @@ -2912,6 +2987,26 @@ def f(x): res_grad = jax.grad(f)(1.) self.assertAllClose(res_grad, 3. 
* np.cos(1.)) + def test_platform_dependent_with_primitive_with_lowering_error(self): + if not jtu.test_device_matches(["cpu", "tpu"]): + self.skipTest("Only for CPU and TPU") + + def f(x): + return lax.platform_dependent( + x, + # Check that we only lower on the intended platform + cpu=lambda x: prim_with_lowering_error.bind(x, only_on="cpu"), + tpu=lambda x: prim_with_lowering_error.bind(x, only_on="tpu")) + + self.assertAllClose(np.sin(1.), f(1.)) # Eager + self.assertAllClose(np.sin(1.), jax.jit(f)(1.)) + self.assertAllClose(np.sin(1.), lax.cond(True, f, lambda x: x, 1.)) + self.assertAllClose(1., lax.cond(False, f, lambda x: x, 1.)) + self.assertAllClose((0., np.sin(np.arange(8.))), + lax.scan(lambda carry, x: (carry, f(x)), + 0., np.arange(8.))) + self.assertAllClose(np.sin(np.arange(8.)), jax.vmap(f)(np.arange(8.))) + def test_platform_dependent_multiple_identical_branches(self): x = np.arange(3, dtype=np.float32) def f(x): @@ -2921,13 +3016,14 @@ def f(x): tpu=jnp.sin, default=lambda x: x) res = f(x) + on_cpu_tpu = jtu.device_under_test() in ["cpu", "tpu"] self.assertAllClose( res, - np.sin(x) if jtu.device_under_test() in ["cpu", "tpu"] else x) - # We only lower the common branches once + np.sin(x) if on_cpu_tpu else x) + stablehlo = jax.jit(f).lower(x).as_text() sines = re.findall(r"stablehlo.sine", stablehlo) - self.assertEqual(1, len(sines)) + self.assertEqual(1 if on_cpu_tpu else 0, len(sines)) def test_platform_dependent_no_default(self): ctx = contextlib.ExitStack() @@ -2981,6 +3077,26 @@ def f(x): self.assertEqual(expect_a_dot, " dot(" in hlo) self.assertEqual(not expect_a_dot, " while(" in hlo) + def test_issue_29329(self): + + def outer_fn(x): + def inner_fn(x): + return jax.jit( + lambda x: lax.platform_dependent(x, + default=jnp.sin, + other=jnp.cos))(x) + + _, lin_fn = jax.linearize(inner_fn, x) + + def with_transpose(x): + grad = jax.linear_transpose(lin_fn, x)(x) + del grad + return x + + return jax.lax.cond(x[0][0] > 0., with_transpose, lambda x: x, x) + + jax.vmap(outer_fn)(jnp.ones((5, 10, 10))) + def test_scan_lowering_doesnt_introduce_singleton(self): b = 4 i = 2 @@ -3048,7 +3164,7 @@ def test_cond_memory_leak(self): def leak(): data = jax.device_put(np.zeros((1024), dtype=np.float32) + 1) def g(): - return jax.lax.cond( + return jax.lax.cond( True, lambda: data[0], # noqa: F821 lambda: data[1], # noqa: F821 @@ -3066,6 +3182,219 @@ def g(): leak() self.assertEqual(base, nbufs()) + def test_grad_remat_while_fixpoint(self): + @jax.remat + def f(x, y): + def cond(_): + return False + def body(c): + x, y = c + return (y, x) + x, y = jax.lax.while_loop(cond, body, (x, y)) + return x + y + jax.linearize(f, 1., 2.) 
# don't crash + + def test_while_readonly_carry_optimization(self): + # https://github.com/google/flax/issues/4700 + def foo(w, x, c_max): + def while_cond(val): + c, x, w = val + return c < c_max + + def while_body(val): + c, x, w = val + return c + 1, x @ w, w + + _, x, w = jax.lax.while_loop(while_cond, while_body, (0, x, w)) + return w, x + + w = jnp.ones((2, 2)) + xs = jnp.ones((4, 2)) + c_maxs = jnp.arange(4) + w_, _ = jax.vmap(foo, in_axes=(None, 0, 0), out_axes=(None, 0) + )(w, xs, c_maxs) # doesn't crash + self.assertAllClose(w, w_, check_dtypes=False) + + @parameterized.parameters(itertools.product(range(3), repeat=5)) + @jtu.run_on_devices("cpu") + def test_while_constification_correctness( + self, + seed, + num_body_consts, + num_inplace_fwds_cond_uses, + num_inplace_fwds_cond_doesnt_use, + num_noninplace_fwds): + + num_fwds = (num_inplace_fwds_cond_uses + num_inplace_fwds_cond_doesnt_use + + num_noninplace_fwds) + num_carry = num_fwds + 4 + + rng = np.random.RandomState(seed) + perm = rng.permutation(num_carry) + iperm = np.argsort(perm) + + body_consts = [rng.randn(3) for _ in range(num_body_consts)] + init_vals = list(rng.uniform(size=num_carry)) + + def cond_fun(c): + i, c = c + c = [c[i] for i in iperm] + c, _ = split_list(c, [num_inplace_fwds_cond_uses]) + return (i < 2) + (0. * jnp.array(sum(c))).astype(bool) + + def body_fun(c): + i, c = c + c = [c[i] for i in iperm] + inplace_fwds, noninplace_fwds, dont_fwd = split_list( + c, [num_inplace_fwds_cond_uses + num_inplace_fwds_cond_doesnt_use, + num_noninplace_fwds]) + dont_fwd = [jnp.sin(x) * sum(jnp.sum(c) for c in body_consts) + for x in dont_fwd] + new_c_perm = [*inplace_fwds, *dont_fwd, *noninplace_fwds] + new_c = [new_c_perm[i] for i in perm] + return (i + 1, new_c) + + i, outs = jax.lax.while_loop(cond_fun, body_fun, (0, init_vals)) + self.assertEqual(i, 2) + _, outs_ref = body_fun(body_fun((0, init_vals))) + self.assertAllClose(outs, outs_ref, check_dtypes=False) + + def test_while_constification_correctness_manually(self): + # regression test for a particular index-offset logic bug + + def cond_fun(c): + # cond doesn't use first or third element of the carry + _, i, _ = c + return i == 0 + + def body_fun(c): + # two body consts + for _ in range(2): jnp.sin(np.zeros(3)) + # first element of the carry is forwarded to third element of the carry + return 0., 1., c[0] + + outs = jax.lax.while_loop(cond_fun, body_fun, (5., 0., 3.14)) + self.assertAllClose(outs, (0., 1., 5.)) + + def test_scan_readonly_carry_optimization(self): + # https://github.com/google/flax/issues/4709 + def f(x, y): + def g(_, y): + y, _ = jax.lax.scan(lambda y, _: (y, None), y, None, length=1) + return y + return jax.lax.cond(x < 0, g, g, x, y) + xs = jnp.arange(3.) + y = 3. 
+ jax.vmap(f, (0, None), None)(xs, y) # don't crash + + @parameterized.parameters(itertools.product(range(3), repeat=4)) + @jtu.run_on_devices("cpu") + def test_scan_constification_correctness( + self, + seed, + num_body_consts, + num_inplace_fwds, + num_noninplace_fwds): + + num_fwds = num_inplace_fwds + num_noninplace_fwds + num_carry = num_fwds + 4 + num_xs = 2 + num_ys = 3 + + rng = np.random.RandomState(seed) + perm = rng.permutation(num_carry) + iperm = np.argsort(perm) + + body_consts = [rng.randn(3) for _ in range(num_body_consts)] + init_vals = list(rng.uniform(size=num_carry)) + + def body_fun(c, _): + c = [c[i] for i in iperm] + inplace_fwds, noninplace_fwds, dont_fwd = split_list( + c, [num_inplace_fwds, num_noninplace_fwds]) + dont_fwd = [jnp.sin(x) * sum(jnp.sum(c) for c in body_consts) + for x in dont_fwd] + new_c_perm = [*inplace_fwds, *dont_fwd, *noninplace_fwds] + new_c = [new_c_perm[i] for i in perm] + return new_c, [0 for _ in range(num_ys)] + + xs = [jnp.arange(2.) for _ in range(num_xs)] + outs = jax.lax.scan(body_fun, init_vals, xs)[0] + outs_ref = body_fun(body_fun(init_vals, [x[0] for x in xs])[0], [x[1] for x in xs])[0] + self.assertAllClose(outs, outs_ref, check_dtypes=False) + + @parameterized.parameters(itertools.product(range(3), repeat=4)) + @jtu.run_on_devices("cpu") + def test_scan_forwarding_correctness( + self, + seed, + num_body_consts, + num_const_fwds, + num_input_fwds): + + num_carry = num_const_fwds + 4 + num_xs = num_input_fwds + 2 + num_ys = num_xs + 1 + + rng = np.random.RandomState(seed) + carry_perm = rng.permutation(num_carry) + carry_iperm = np.argsort(carry_perm) + + xs_perm = rng.permutation(num_xs) + ys_perm = rng.permutation(num_ys) + f = np.arange(num_xs) + f = [f[i] if idx < num_input_fwds else None for idx, i in enumerate(xs_perm)] + f += [None] + in_fwd = [f[i] for i in ys_perm] + + body_consts = [rng.randn(3) for _ in range(num_body_consts)] + init_vals = list(rng.uniform(size=num_carry)) + + def body_fun(c, x): + c = [c[i] for i in carry_iperm] + carry_fwds, carry_dont_fwd = split_list(c, [num_const_fwds]) + carry_dont_fwd = [jnp.sin(x) * sum(jnp.sum(c) for c in body_consts) + for x in carry_dont_fwd] + new_c_perm = [*carry_fwds, *carry_dont_fwd] + new_c = [new_c_perm[i] for i in carry_perm] + + x = [x[i] for i in xs_perm] + x_fwd, x_dont_fwd = split_list(x, [num_input_fwds]) + x_dont_fwd = [jnp.cos(x) * sum(jnp.sum(c) for c in body_consts) + for x in x_dont_fwd] + y = [*x_fwd, *x_dont_fwd, 0] + y = [y[i] for i in ys_perm] + + return new_c, y + + xs = list(rng.uniform(size=(num_xs, 2))) + final, outs = jax.lax.scan(body_fun, init_vals, xs) + for f, y in zip(in_fwd, outs): + if f is not None: + self.assertAllClose(y, xs[f]) + + final_ref = body_fun(body_fun(init_vals, [x[0] for x in xs])[0], [x[1] for x in xs])[0] + self.assertAllClose(final, final_ref, check_dtypes=False) + + def test_scan_diff_of_print(self): + # ref: https://github.com/jax-ml/jax/issues/28738 + def f(c, _): + jax.debug.print("c = {c}", c=c, ordered=True) + return c + 1, None + def g(x): + return jax.lax.scan(f, x, length=2)[0] + jaxpr = jax.make_jaxpr(jax.value_and_grad(g))(1.0) + eqn_jaxpr = jaxpr.eqns[0].params["jaxpr"] + self.assertIn("debug_callback", [e.primitive.name for e in eqn_jaxpr.eqns]) + + def test_scan_input_to_output_forwarding(self): + def f(c, x): + return c + 1, x + def g(x): + return jax.lax.scan(f, 0, x) + jaxpr = jax.make_jaxpr(g)(jnp.arange(3.)) + self.assertLen(jaxpr.eqns[0].params["jaxpr"].jaxpr.outvars, 1) + if __name__ == '__main__': 
absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/lax_metal_test.py b/tests/lax_metal_test.py index 5f1781c3be06..ecbf908d2f09 100644 --- a/tests/lax_metal_test.py +++ b/tests/lax_metal_test.py @@ -48,7 +48,7 @@ from jax._src import test_util as jtu from jax._src.lax import lax as lax_internal -from jax._src.util import safe_zip, NumpyComplexWarning +from jax._src.util import safe_zip try: from jax_plugins import metal_plugin @@ -2099,11 +2099,11 @@ def testCumSumProd(self, axis, shape, dtype, out_dtype, op): np_op = getattr(np, op) rng = jtu.rand_default(self.rng()) np_fun = lambda arg: np_op(arg, axis=axis, dtype=out_dtype) - np_fun = jtu.ignore_warning(category=NumpyComplexWarning)(np_fun) + np_fun = jtu.ignore_warning(category=np.exceptions.ComplexWarning)(np_fun) np_fun = jtu.ignore_warning(category=RuntimeWarning, message="overflow encountered.*")(np_fun) jnp_fun = lambda arg: jnp_op(arg, axis=axis, dtype=out_dtype) - jnp_fun = jtu.ignore_warning(category=jnp.ComplexWarning)(jnp_fun) + jnp_fun = jtu.ignore_warning(category=np.exceptions.ComplexWarning)(jnp_fun) args_maker = lambda: [rng(shape, dtype)] @@ -2127,11 +2127,11 @@ def testNanCumSumProd(self, axis, shape, dtype, out_dtype, op): np_op = getattr(np, op) rng = jtu.rand_some_nan(self.rng()) np_fun = partial(np_op, axis=axis, dtype=out_dtype) - np_fun = jtu.ignore_warning(category=NumpyComplexWarning)(np_fun) + np_fun = jtu.ignore_warning(category=np.exceptions.ComplexWarning)(np_fun) np_fun = jtu.ignore_warning(category=RuntimeWarning, message="overflow encountered.*")(np_fun) jnp_fun = partial(jnp_op, axis=axis, dtype=out_dtype) - jnp_fun = jtu.ignore_warning(category=jnp.ComplexWarning)(jnp_fun) + jnp_fun = jtu.ignore_warning(category=np.exceptions.ComplexWarning)(jnp_fun) args_maker = lambda: [rng(shape, dtype)] @@ -3867,7 +3867,7 @@ def testItem(self, shape, dtype, num_args, use_tuple): self._CheckAgainstNumpy(np_op, jnp_op, args_maker) @jtu.sample_product( - # Final dimension must be a multiple of 16 to ensure compatibilty of all dtype pairs. + # Final dimension must be a multiple of 16 to ensure compatibility of all dtype pairs. shape=[(0,), (32,), (2, 16)], a_dtype=all_dtypes, dtype=(*all_dtypes, None) if config.enable_x64.value else all_dtypes, diff --git a/tests/lax_numpy_indexing_test.py b/tests/lax_numpy_indexing_test.py index 63a725ad3643..6364137cc1c6 100644 --- a/tests/lax_numpy_indexing_test.py +++ b/tests/lax_numpy_indexing_test.py @@ -35,7 +35,6 @@ from jax._src import test_util as jtu from jax._src import util from jax._src.lax import lax as lax_internal -from jax._src.util import NumpyComplexWarning config.parse_flags_with_absl() @@ -926,12 +925,20 @@ def testSimpleIndexingUsesSlice(self): self.assertEqual(jaxpr.jaxpr.eqns[-2].primitive, lax.slice_p) self.assertEqual(jaxpr.jaxpr.eqns[-1].primitive, lax.squeeze_p) - # Indexing with `Ellipsis` is not lowered to `gather`. + # Indexing with `Ellipsis` is not lowered to `gather` ... jaxpr = jax.make_jaxpr(lambda x: x[..., 0])(jnp.ones((3, 4, 5))) self.assertLen((jaxpr.jaxpr.eqns), 2) self.assertEqual(jaxpr.jaxpr.eqns[-2].primitive, lax.slice_p) self.assertEqual(jaxpr.jaxpr.eqns[-1].primitive, lax.squeeze_p) + # ... even when the ellipsis expands to no dimensions. 
+ jaxpr = jax.make_jaxpr(lambda x: x[..., 0:1])(jnp.ones((3,))) + self.assertLen((jaxpr.jaxpr.eqns), 1) + self.assertEqual(jaxpr.jaxpr.eqns[-1].primitive, lax.slice_p) + jaxpr = jax.make_jaxpr(lambda x: x[0:1, ...])(jnp.ones((3,))) + self.assertLen((jaxpr.jaxpr.eqns), 1) + self.assertEqual(jaxpr.jaxpr.eqns[-1].primitive, lax.slice_p) + # Simple reverses lower to lax.rev_p jaxpr = jax.make_jaxpr(lambda x: x[:, ::-1])(jnp.ones((3, 4))) self.assertEqual(len(jaxpr.jaxpr.eqns), 1) @@ -1132,6 +1139,47 @@ def testStrIndexingError(self): with self.assertRaisesRegex(TypeError, msg): jnp.zeros((2, 3))[:, 'abc'] + @jtu.sample_product( + mode=["promise_in_bounds", "fill", "clip", "drop"], + wrap_negative_indices=[True, False], + shape=[(5,), (10,)], + idx_shape=[(5,)], + ) + def testWrapNegativeIndices1D(self, mode, wrap_negative_indices, shape, idx_shape): + """Test the behavior of the wrap_negative_indices parameter in array.at[...].get()""" + fill_value = 99 + + data_rng = jtu.rand_default(self.rng()) + idx_rng = jtu.rand_uniform(self.rng(), low=-12, high=12) + + args_maker = lambda: [data_rng(shape, 'float32'), idx_rng(idx_shape, 'int32')] + + def jnp_fun(data, idx): + return jnp.array(data).at[idx].get( + mode=mode, + fill_value=fill_value, + wrap_negative_indices=wrap_negative_indices) + + def np_fun(data, idx): + if wrap_negative_indices: + idx = np.where(idx < 0, idx + len(data), idx) + out_of_bound = (idx < 0) | (idx >= len(data)) + safe_idx = np.where(out_of_bound, 0, idx) + result = data[safe_idx] + if mode in ["fill", "drop"]: + result = np.where(out_of_bound, fill_value, result) + elif mode in ["promise_in_bounds", "clip"]: + result = np.where(idx < 0, data[0], + np.where(idx >= len(data), data[-1], + result)) + else: + raise ValueError(f"Unrecognized mode {mode!r}") + return result + + tol = 1E-4 if jtu.test_device_matches(["tpu"]) else None + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, tol=tol) + self._CompileAndCheck(jnp_fun, args_maker, tol=tol) + def testIndexOutOfBounds(self): # https://github.com/jax-ml/jax/issues/2245 x = jnp.arange(5, dtype=jnp.int32) + 1 self.assertAllClose(x, x[:10]) @@ -1178,7 +1226,7 @@ def _check(x_type, y_type): out = x.at[0].set(y) self.assertEqual(x.dtype, out.dtype) - @jtu.ignore_warning(category=NumpyComplexWarning, + @jtu.ignore_warning(category=np.exceptions.ComplexWarning, message="Casting complex values to real") def _check_warns(x_type, y_type, msg): with self.assertWarnsRegex(FutureWarning, msg): @@ -1284,22 +1332,30 @@ class UpdateOps(enum.Enum): def np_fn(op, indexer, x, y): x = x.copy() - x[indexer] = { - UpdateOps.UPDATE: lambda: y, - UpdateOps.ADD: lambda: x[indexer] + y, - UpdateOps.SUB: lambda: x[indexer] - y, - UpdateOps.MUL: lambda: x[indexer] * y, - UpdateOps.DIV: jtu.ignore_warning(category=RuntimeWarning)( - lambda: x[indexer] / y.astype(x.dtype)), - UpdateOps.POW: jtu.ignore_warning(category=RuntimeWarning)( - lambda: x[indexer] ** y.astype(x.dtype)), - UpdateOps.MIN: lambda: np.minimum(x[indexer], y), - UpdateOps.MAX: lambda: np.maximum(x[indexer], y), - }[op]() + if op == UpdateOps.UPDATE: + x[indexer] = y + elif op == UpdateOps.ADD: + np.add.at(x, indexer, y) + elif op == UpdateOps.SUB: + np.subtract.at(x, indexer, y) + elif op == UpdateOps.MUL: + np.multiply.at(x, indexer, y) + elif op == UpdateOps.DIV: + with jtu.ignore_warning(category=RuntimeWarning): + np.divide.at(x, indexer, y) + elif op == UpdateOps.POW: + with jtu.ignore_warning(category=RuntimeWarning): + np.power.at(x, indexer, y) + elif op == UpdateOps.MIN: + 
np.minimum.at(x, indexer, y.astype(x.dtype)) + elif op == UpdateOps.MAX: + np.maximum.at(x, indexer, y.astype(x.dtype)) + else: + raise ValueError(f"{op=}") return x def jax_fn(op, indexer, x, y, indices_are_sorted=False, - unique_indices=False, mode=None): + unique_indices=False, mode=None, wrap_negative_indices=True): x = jnp.array(x) return { UpdateOps.UPDATE: x.at[indexer].set, @@ -1311,7 +1367,8 @@ def jax_fn(op, indexer, x, y, indices_are_sorted=False, UpdateOps.MIN: x.at[indexer].min, UpdateOps.MAX: x.at[indexer].max, }[op](y, indices_are_sorted=indices_are_sorted, - unique_indices=unique_indices, mode=mode) + unique_indices=unique_indices, mode=mode, + wrap_negative_indices=wrap_negative_indices) def dtypes(op): if op == UpdateOps.UPDATE: @@ -1424,6 +1481,52 @@ def testMixedAdvancedIndexing(self, name, shape, dtype, update_shape, self._CheckAgainstNumpy(np_fn, jax_fn, args_maker, tol=_update_tol(op)) self._CompileAndCheck(jax_fn, args_maker) + @jtu.sample_product( + op=UpdateOps, + mode=["fill", "clip"], + wrap_negative_indices=[True, False], + shape=[(5,), (10,)], + update_shape=[(5,)], + ) + def testWrapNegativeIndices1D(self, op, mode, wrap_negative_indices, shape, update_shape): + rng = jtu.rand_default(self.rng()) + idx_rng = jtu.rand_unique_int(self.rng(), high=shape[0]) + + def args_maker(): + data = rng(shape, 'float32').round(1) + update = rng(update_shape, 'float32').round(1) + # we need indices to be unique, so we generate unique values in [0, N) + # and then subtract N from half of them. To test out-of-bound behavior + # we push the bottom and top index out-of-bounds + idx = idx_rng(update_shape, 'int32') + idx = np.where(rng(update_shape, bool), idx, idx - shape[0]) + idx[idx == shape[0] - 1] = shape[0] + 2 # out-of-bound positive + idx[idx == -shape[0]] = -(shape[0] + 2) # out-of-bound negative + return data, idx, update + + def jnp_fun(data, idx, values): + return UpdateOps.jax_fn(op, idx, data, values, + mode=mode, + wrap_negative_indices=wrap_negative_indices) + + def np_fun(data, idx, values): + if wrap_negative_indices: + idx = np.where(idx < 0, idx + len(data), idx) + if mode in ["fill", "drop", "promise_in_bounds"]: + ok = (idx >= 0) & (idx < len(data)) + idx = idx[ok] + values = values[ok] + elif mode == "clip": + idx = np.where(idx < 0, 0, idx) + idx = np.where(idx >= len(data), len(data) - 1, idx) + else: + raise ValueError(f"Unrecognized mode {mode!r}") + return UpdateOps.np_fn(op, idx, data, values) + + tol = 1E-4 if jtu.test_device_matches(["tpu"]) else None + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, tol=tol) + self._CompileAndCheck(jnp_fun, args_maker, tol=tol) + @jtu.sample_product( [dict(name=name, mode=mode, shape=shape, indexer=indexer, update_shape=update_shape) diff --git a/tests/lax_numpy_reducers_test.py b/tests/lax_numpy_reducers_test.py index 0c3f1d1471fb..93aff25c6f8e 100644 --- a/tests/lax_numpy_reducers_test.py +++ b/tests/lax_numpy_reducers_test.py @@ -29,7 +29,6 @@ from jax._src import config from jax._src import dtypes from jax._src import test_util as jtu -from jax._src.util import NumpyComplexWarning config.parse_flags_with_absl() @@ -209,7 +208,7 @@ def testReducer(self, name, rng_factory, shape, dtype, out_dtype, np_op = getattr(np, name) jnp_op = getattr(jnp, name) rng = rng_factory(self.rng()) - @jtu.ignore_warning(category=NumpyComplexWarning) + @jtu.ignore_warning(category=np.exceptions.ComplexWarning) @jtu.ignore_warning(category=RuntimeWarning, message="Mean of empty slice.*") 
@jtu.ignore_warning(category=RuntimeWarning, @@ -225,7 +224,7 @@ def np_fun(x): return np_op(x_cast, axis, dtype=t, keepdims=keepdims) jnp_fun = lambda x: jnp_op(x, axis, dtype=out_dtype, keepdims=keepdims) - jnp_fun = jtu.ignore_warning(category=jnp.ComplexWarning)(jnp_fun) + jnp_fun = jtu.ignore_warning(category=np.exceptions.ComplexWarning)(jnp_fun) args_maker = lambda: [rng(shape, dtype)] tol_spec = {np.float16: 1e-2, np.int16: 2e-7, np.int32: 1E-3, np.uint32: 3e-7, np.float32: 1e-3, np.complex64: 1e-3, @@ -313,7 +312,7 @@ def testReducerInitial(self, name, rng_factory, shape, dtype, axis, is_bf16_nan_test = dtype == jnp.bfloat16 and rng_factory.__name__ == 'rand_some_nan' @jtu.ignore_warning(category=RuntimeWarning, message="Degrees of freedom <= 0 for slice.*") - @jtu.ignore_warning(category=NumpyComplexWarning) + @jtu.ignore_warning(category=np.exceptions.ComplexWarning) def np_fun(x): x = np.asarray(x) if inexact: @@ -324,7 +323,7 @@ def np_fun(x): return res.astype(_reducer_output_dtype(name, x.dtype)) jnp_fun = lambda x: jnp_op(x, axis, keepdims=keepdims, initial=initial) - jnp_fun = jtu.ignore_warning(category=jnp.ComplexWarning)(jnp_fun) + jnp_fun = jtu.ignore_warning(category=np.exceptions.ComplexWarning)(jnp_fun) args_maker = lambda: [rng(shape, dtype)] tol = {jnp.bfloat16: 3E-2} self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, rtol=tol, atol=tol) @@ -353,7 +352,7 @@ def testReducerPromoteInt(self, name, rng_factory, shape, dtype, axis, rng_factory.__name__ == 'rand_some_nan') @jtu.ignore_warning(category=RuntimeWarning, message="Degrees of freedom <= 0 for slice.*") - @jtu.ignore_warning(category=NumpyComplexWarning) + @jtu.ignore_warning(category=np.exceptions.ComplexWarning) def np_fun(x): x = np.asarray(x) if inexact: @@ -364,7 +363,7 @@ def np_fun(x): return res.astype(_reducer_output_dtype(name, x.dtype, promote_integers)) jnp_fun = lambda x: jnp_op(x, axis, keepdims=keepdims, initial=initial, promote_integers=promote_integers) - jnp_fun = jtu.ignore_warning(category=jnp.ComplexWarning)(jnp_fun) + jnp_fun = jtu.ignore_warning(category=np.exceptions.ComplexWarning)(jnp_fun) args_maker = lambda: [rng(shape, dtype)] tol = {jnp.bfloat16: 3E-2, jnp.float16: 5e-3} self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, rtol=tol) @@ -390,7 +389,7 @@ def testReducerNoInitialZeroDims(self, name, rng_factory, shape, dtype, axis, is_bf16_nan_test = dtype == jnp.bfloat16 and rng_factory.__name__ == 'rand_some_nan' @jtu.ignore_warning(category=RuntimeWarning, message="Degrees of freedom <= 0 for slice.*") - @jtu.ignore_warning(category=NumpyComplexWarning) + @jtu.ignore_warning(category=np.exceptions.ComplexWarning) def np_fun(x): x = np.asarray(x) if inexact: @@ -401,7 +400,7 @@ def np_fun(x): return res.astype(_reducer_output_dtype(name, x.dtype)) jnp_fun = lambda x: jnp_op(x, axis, keepdims=keepdims) - jnp_fun = jtu.ignore_warning(category=jnp.ComplexWarning)(jnp_fun) + jnp_fun = jtu.ignore_warning(category=np.exceptions.ComplexWarning)(jnp_fun) args_maker = lambda: [rng(shape, dtype)] tol = {jnp.bfloat16: 3E-2} self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, rtol=tol) @@ -436,7 +435,7 @@ def testReducerWhere(self, name, rng_factory, shape, dtype, axis, where = jtu.rand_bool(self.rng())(whereshape, np.bool_) @jtu.ignore_warning(category=RuntimeWarning, message="Degrees of freedom <= 0 for slice.*") - @jtu.ignore_warning(category=NumpyComplexWarning) + @jtu.ignore_warning(category=np.exceptions.ComplexWarning) def np_fun(x): x = np.asarray(x) if inexact: @@ -447,7 +446,7 @@ 
def np_fun(x): return res.astype(_reducer_output_dtype(name, x.dtype)) jnp_fun = lambda x: jnp_op(x, axis, keepdims=keepdims, initial=initial, where=where) - jnp_fun = jtu.ignore_warning(category=jnp.ComplexWarning)(jnp_fun) + jnp_fun = jtu.ignore_warning(category=np.exceptions.ComplexWarning)(jnp_fun) args_maker = lambda: [rng(shape, dtype)] self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, atol=tol, rtol=tol) self._CompileAndCheck(jnp_fun, args_maker) @@ -499,7 +498,7 @@ def testReducerWhereNoInitial(self, name, rng_factory, shape, dtype, axis, message="Mean of empty slice.*") @jtu.ignore_warning(category=RuntimeWarning, message="invalid value encountered.*") - @jtu.ignore_warning(category=NumpyComplexWarning) + @jtu.ignore_warning(category=np.exceptions.ComplexWarning) def np_fun(x): x = np.asarray(x) if inexact: @@ -510,7 +509,7 @@ def np_fun(x): return res jnp_fun = lambda x: jnp_op(x, axis, keepdims=keepdims, where=where) - jnp_fun = jtu.ignore_warning(category=jnp.ComplexWarning)(jnp_fun) + jnp_fun = jtu.ignore_warning(category=np.exceptions.ComplexWarning)(jnp_fun) args_maker = lambda: [rng(shape, dtype)] self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, atol=tol, rtol=tol) self._CompileAndCheck(jnp_fun, args_maker) @@ -574,7 +573,7 @@ def testStdOrVar(self, test_fns, shape, dtype, out_dtype, axis, ddof_correction, args_maker = self._GetArgsMaker(rng, [shape], [dtype]) @jtu.ignore_warning(category=RuntimeWarning, message="Degrees of freedom <= 0 for slice.") - @jtu.ignore_warning(category=NumpyComplexWarning) + @jtu.ignore_warning(category=np.exceptions.ComplexWarning) def np_fun(x): # setup ddof and correction kwargs excluding case when correction is not specified ddof_correction_kwargs = {"ddof": ddof} @@ -625,7 +624,7 @@ def testNanVar(self, shape, dtype, out_dtype, axis, ddof, keepdims): args_maker = self._GetArgsMaker(rng, [shape], [dtype]) @jtu.ignore_warning(category=RuntimeWarning, message="Degrees of freedom <= 0 for slice.") - @jtu.ignore_warning(category=NumpyComplexWarning) + @jtu.ignore_warning(category=np.exceptions.ComplexWarning) def np_fun(x): # Numpy fails with bfloat16 inputs out = np.nanvar(x.astype(np.float32 if dtype == dtypes.bfloat16 else dtype), @@ -834,7 +833,7 @@ def test_f16_mean(self, dtype): ], include_initial=[False, True], ) - @jtu.ignore_warning(category=NumpyComplexWarning) + @jtu.ignore_warning(category=np.exceptions.ComplexWarning) @jax.numpy_dtype_promotion('standard') # This test explicitly exercises mixed type promotion def testCumulativeSum(self, shape, axis, dtype, out_dtype, include_initial): rng = jtu.rand_some_zero(self.rng()) @@ -902,11 +901,11 @@ def testCumulativeSumBool(self): ], include_initial=[False, True], ) - @jtu.ignore_warning(category=NumpyComplexWarning) + @jtu.ignore_warning(category=np.exceptions.ComplexWarning) @jax.numpy_dtype_promotion('standard') # This test explicitly exercises mixed type promotion def testCumulativeProd(self, shape, axis, dtype, out_dtype, include_initial): - if jtu.is_device_tpu(6): - raise unittest.SkipTest("TODO(b/364258243): Test fails on TPU v6e") + if jtu.is_device_tpu_at_least(6): + raise unittest.SkipTest("TODO(b/364258243): Test fails on TPU v6+") rng = jtu.rand_some_zero(self.rng()) # We currently "cheat" to ensure we have JAX arrays, not NumPy arrays as diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index 98f10d9c02b3..60ad6a83701b 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -50,7 +50,7 @@ from jax._src import dtypes from jax._src import 
test_util as jtu from jax._src.lax import lax as lax_internal -from jax._src.util import safe_zip, NumpyComplexWarning, tuple_replace +from jax._src.util import safe_zip, tuple_update config.parse_flags_with_absl() @@ -968,6 +968,7 @@ def np_fun(lhs, rhs): @jtu.sample_product( dtype=[dt for dt in float_dtypes if dt not in [jnp.float16, jnp.bfloat16]], shape=[shape for shape in one_dim_array_shapes if shape != (1,)], + num_rhs=[1, 5], deg=[1, 2, 3], rcond=[None, -1, 10e-3, 10e-5, 10e-10], full=[False, True], @@ -975,12 +976,13 @@ def np_fun(lhs, rhs): cov=[False, True, "unscaled"], ) @jax.default_matmul_precision("float32") - def testPolyfit(self, shape, dtype, deg, rcond, full, w, cov): + def testPolyfit(self, shape, num_rhs, dtype, deg, rcond, full, w, cov): rng = jtu.rand_default(self.rng()) tol_spec = {np.float32: 1e-3, np.float64: 1e-13, np.complex64: 1e-5} tol = jtu.tolerance(dtype, tol_spec) _w = lambda a: abs(a) if w else None - args_maker = lambda: [rng(shape, dtype), rng(shape, dtype), rng(shape, dtype)] + rhs_shape = shape + (num_rhs,) if num_rhs > 1 else shape + args_maker = lambda: [rng(shape, dtype), rng(rhs_shape, dtype), rng(shape, dtype)] jnp_fun = lambda x, y, a: jnp.polyfit(x, y, deg=deg, rcond=rcond, full=full, w=_w(a), cov=cov) np_fun = jtu.ignore_warning( message="Polyfit may be poorly conditioned*")(lambda x, y, a: np.polyfit(x, y, deg=deg, rcond=rcond, full=full, w=_w(a), cov=cov)) @@ -1063,6 +1065,14 @@ def testClipDeprecatedArgs(self): "Passing arguments 'a', 'a_min' or 'a_max' to jax.numpy.clip is deprecated"): jnp.clip(jnp.arange(4), a_min=2, a_max=3) + def testClipUpperPrecedence(self): + a_min = 3 * np.ones(1) + a_max = 2 * np.ones(1) + x = 4 * np.ones(1) + y = jnp.clip(x, min=a_min, max=a_max) + assert y == a_max, f"Expected {y} to equal {a_max} when a_min > a_max." 
+ assert y == jnp.asarray(np.clip(x, a_min=a_min, a_max=a_max)) + def testHypotComplexInputError(self): rng = jtu.rand_default(self.rng()) x = rng((5,), dtype=jnp.complex64) @@ -1928,9 +1938,6 @@ def testDeleteMaskArray(self, shape, dtype, axis): rng = jtu.rand_default(self.rng()) mask_size = np.zeros(shape).size if axis is None else np.zeros(shape).shape[axis] mask = jtu.rand_int(self.rng(), low=0, high=2)(mask_size, bool) - if numpy_version == (1, 23, 0) and mask.shape == (1,): - # https://github.com/numpy/numpy/issues/21840 - self.skipTest("test fails for numpy v1.23.0") args_maker = lambda: [rng(shape, dtype)] np_fun = lambda arg: np.delete(arg, mask, axis=axis) jnp_fun = lambda arg: jnp.delete(arg, mask, axis=axis) @@ -2347,11 +2354,11 @@ def testCumSumProd(self, axis, shape, dtype, out_dtype, op): np_op = getattr(np, op) rng = jtu.rand_default(self.rng()) np_fun = lambda arg: np_op(arg, axis=axis, dtype=out_dtype) - np_fun = jtu.ignore_warning(category=NumpyComplexWarning)(np_fun) + np_fun = jtu.ignore_warning(category=np.exceptions.ComplexWarning)(np_fun) np_fun = jtu.ignore_warning(category=RuntimeWarning, message="overflow encountered.*")(np_fun) jnp_fun = lambda arg: jnp_op(arg, axis=axis, dtype=out_dtype) - jnp_fun = jtu.ignore_warning(category=jnp.ComplexWarning)(jnp_fun) + jnp_fun = jtu.ignore_warning(category=np.exceptions.ComplexWarning)(jnp_fun) args_maker = lambda: [rng(shape, dtype)] @@ -2375,11 +2382,11 @@ def testNanCumSumProd(self, axis, shape, dtype, out_dtype, op): np_op = getattr(np, op) rng = jtu.rand_some_nan(self.rng()) np_fun = partial(np_op, axis=axis, dtype=out_dtype) - np_fun = jtu.ignore_warning(category=NumpyComplexWarning)(np_fun) + np_fun = jtu.ignore_warning(category=np.exceptions.ComplexWarning)(np_fun) np_fun = jtu.ignore_warning(category=RuntimeWarning, message="overflow encountered.*")(np_fun) jnp_fun = partial(jnp_op, axis=axis, dtype=out_dtype) - jnp_fun = jtu.ignore_warning(category=jnp.ComplexWarning)(jnp_fun) + jnp_fun = jtu.ignore_warning(category=np.exceptions.ComplexWarning)(jnp_fun) args_maker = lambda: [rng(shape, dtype)] @@ -2754,6 +2761,28 @@ def np_fun(x1, x2): self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) self._CompileAndCheck(jnp_fun, args_maker) + @parameterized.parameters(*float_dtypes) + def testLdexpOverflow(self, dtype): + # Regression test for https://github.com/jax-ml/jax/issues/28040 + args_maker = lambda: [np.array(0.5, dtype), 1 << (dtypes.finfo(dtype).nexp - 1)] + def np_ldexp(x1, x2): + return np.ldexp(x1, x2).astype(x1.dtype) + self._CheckAgainstNumpy(np_ldexp, jnp.ldexp, args_maker) + self._CompileAndCheck(jnp.ldexp, args_maker) + + @parameterized.parameters(*float_dtypes) + def testLdexpExtremeValues(self, dtype): + # Regression test for https://github.com/jax-ml/jax/issues/28040 + def args_maker(): + info = dtypes.finfo(dtype) + span = int(np.log2(float(info.max)) - np.log2(float(info.tiny))) - 1 + return [np.array([info.tiny, info.max], dtype=dtype), + np.array([span, -span])] + def np_ldexp(x1, x2): + return np.ldexp(x1, x2).astype(x1.dtype) + self._CheckAgainstNumpy(np_ldexp, jnp.ldexp, args_maker) + self._CompileAndCheck(jnp.ldexp, args_maker) + @jtu.sample_product( rng_factory=[ jtu.rand_some_inf_and_nan, @@ -3496,11 +3525,6 @@ def testReshape(self, arg_shape, out_shape, dtype, order): self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) self._CompileAndCheck(jnp_fun, args_maker) - def testReshapeDeprecatedArgs(self): - msg = "The newshape argument to jnp.reshape was removed in JAX v0.4.36." 
- with self.assertRaisesRegex(TypeError, msg): - jnp.reshape(jnp.arange(4), newshape=(2, 2)) - @jtu.sample_product( [dict(arg_shape=arg_shape, out_shape=out_shape) for arg_shape, out_shape in [ @@ -3666,6 +3690,53 @@ def testAsarrayCopy(self, copy): self.assertArraysEqual(x_jax, func(x_np), check_dtypes=False) self.assertArraysEqual(x_jax, func(x_buf), check_dtypes=False) + @jtu.sample_product(numpy_array=[True, False]) + def testAsarrayWithCopyFalse(self, numpy_array): + x_jax = jnp.arange(4) + if numpy_array: + x = np.arange(4) + else: + x = make_python_array('l', [0, 1, 2, 3]) + device_error_msg = ('jnp.asarray: cannot convert object of type .* to JAX' + ' Array on platform={} with copy=False. Consider using' + ' copy=None or copy=True instead.') + + if jax.default_backend() != 'cpu': + # test accelerator devices - no support for copy=False + expected_platform = jax.local_devices()[0].platform + with self.assertRaisesRegex( + ValueError, device_error_msg.format(expected_platform)): + jnp.asarray(x, copy=False, device=jax.local_devices()[0]) + sharding = SingleDeviceSharding(jax.local_devices()[0]) + with self.assertRaisesRegex( + ValueError, device_error_msg.format(expected_platform)): + jnp.asarray(x, copy=False, device=sharding) + + # test None defaults to default backend - no support for copy=False + with self.assertRaisesRegex( + ValueError, device_error_msg.format(expected_platform)): + jnp.asarray(x, copy=False, device=None) + else: + self.assertArraysEqual(jnp.asarray(x, copy=False, device=None), x_jax, + check_dtypes=False) + + # test explicit CPU device or default CPU device context managers overwrite the default backend + x = make_python_array('l', [0, 1, 2, 3]) + for device in [jax.local_devices(backend='cpu')[0], + SingleDeviceSharding(jax.local_devices(backend='cpu')[0])]: + self.assertArraysEqual(jnp.asarray(x, copy=False, device=device), + x_jax, check_dtypes=False) + with jax.default_device('cpu'): + self.assertArraysEqual(jnp.asarray(x, copy=False), x_jax, + check_dtypes=False) + self.assertArraysEqual(jnp.asarray(x, copy=False, device=None), x_jax, + check_dtypes=False) + with jax.default_device(jax.local_devices(backend='cpu')[0]): + self.assertArraysEqual(jnp.asarray(x, copy=False), x_jax, + check_dtypes=False) + self.assertArraysEqual(jnp.asarray(x, copy=False, device=None), x_jax, + check_dtypes=False) + @jtu.ignore_warning(category=UserWarning, message="Explicitly requested dtype.*") def testArrayDtypeInference(self): def _check(obj, out_dtype, weak_type): @@ -3800,9 +3871,10 @@ def testArrayFromList(self): with self.assertRaisesRegex(OverflowError, "Python int too large.*"): jnp.array([0, val]) - def testArrayNoneWarning(self): - # TODO(jakevdp): make this an error after the deprecation period. 
- with self.assertWarnsRegex(FutureWarning, r"None encountered in jnp.array\(\)"): + def testArrayNone(self): + with self.assertRaisesRegex( + ValueError, 'None is not a valid value for jnp.array' + ): jnp.array([0.0, None]) def testIssue121(self): @@ -3891,6 +3963,24 @@ def testIsClose(self): key = jax.random.key(0) self.assertTrue(jnp.isclose(key, key)) + @jtu.sample_product( + atol=[0.0, 1E-4, np.inf], + rtol=[0.0, 1E-4, np.inf], + equal_nan=[True, False] + ) + def testIsCloseCornerCases(self, atol, rtol, equal_nan): + if jtu.numpy_version() < (2, 0, 0) and (np.isinf(atol) or np.isinf(rtol)): + self.skipTest("fails on older NumPy") + if jtu.numpy_version() >= (2, 3, 0) and (np.isinf(atol) or np.isinf(rtol)): + self.skipTest("NumPy 2.3.0 now throws warnings for inf atol/rtol") + vals = np.array([-np.nan, -np.inf, -1.00001, -1.0, -0.00001, -0.0, + 0.0, 0.00001, 1.0, 1.00001, np.inf, np.nan]) + x, y = np.meshgrid(vals, vals) + self.assertArraysEqual( + np.isclose(x, y, atol=atol, rtol=rtol, equal_nan=equal_nan), + jnp.isclose(x, y, atol=atol, rtol=rtol, equal_nan=equal_nan) + ) + @jtu.sample_product( x=[1, [1], [1, 1 + 1E-4], [1, np.nan]], y=[1, [1], [1, 1 + 1E-4], [1, np.nan]], @@ -4594,7 +4684,7 @@ def testRollaxis(self, shape, dtype, start, axis): self._CompileAndCheck(jnp_op, args_maker) @jtu.sample_product( - dtype=[np.uint8, np.bool_], + dtype=int_dtypes + unsigned_dtypes + bool_dtypes, bitorder=['big', 'little'], shape=[(1, 2, 3, 4)], axis=[None, 0, 1, -2, -1], @@ -6042,7 +6132,10 @@ def np_fun(a, i, v): dict(a_shape=a_shape, i_shape=i_shape, v_shape=v_shape, axis=axis) for a_shape in nonempty_array_shapes for axis in list(range(-len(a_shape), len(a_shape))) - for i_shape in [tuple_replace(a_shape, axis, J) for J in range(a_shape[axis] + 1)] + for i_shape in [ + tuple_update(a_shape, axis if axis >= 0 else axis + len(a_shape), J) + for J in range(a_shape[axis] + 1) + ] for v_shape in [(), (1,), i_shape] ] + [ dict(a_shape=a_shape, i_shape=i_shape, v_shape=v_shape, axis=None) @@ -6123,7 +6216,7 @@ def test_isdtype(self, dtype, kind): ], dtype=float_dtypes + int_dtypes, ) - @jtu.skip_on_devices("tpu") # TODO(jakevdp): fix and reenable this test. + @jtu.skip_on_devices("tpu") # TODO(jakevdp): fix and re-enable this test. @jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion. def test_trapezoid(self, yshape, xshape, dtype, dx, axis): rng = jtu.rand_default(self.rng()) @@ -6327,9 +6420,18 @@ def testGradLogaddexp2Complex(self, shapes, dtype): ) def testGradLdexp(self, n, dtype): rng = jtu.rand_default(self.rng()) - x = rng((), dtype) + x = rng((10,), dtype) check_grads(lambda x: jnp.ldexp(x, n), (x,), 1) + @jtu.sample_product( + n=range(-4, 5), + dtype=[jnp.float32, jnp.float64], + ) + def testGradFrexp(self, n, dtype): + rng = jtu.rand_default(self.rng()) + x = rng((10,), dtype) * 2 ** n + check_grads(lambda x: jnp.frexp(x)[0], (x,), 1) + class NumpySignaturesTest(jtu.JaxTestCase): diff --git a/tests/lax_numpy_ufuncs_test.py b/tests/lax_numpy_ufuncs_test.py index fd5050a5829b..f2155afb841d 100644 --- a/tests/lax_numpy_ufuncs_test.py +++ b/tests/lax_numpy_ufuncs_test.py @@ -56,7 +56,7 @@ def _jnp_ufunc_props(name): jnp_func = getattr(jnp, name) assert isinstance(jnp_func, jnp.ufunc) np_func = getattr(np, name) - dtypes = [np.dtype(c) for c in "Ffi?" if f"{c}{c}->{c}" in np_func.types or f"{c}->{c}" in np_func.types] + dtypes = [np.dtype(c) for c in "FfIi?" 
if f"{c}{c}->{c}" in np_func.types or f"{c}->{c}" in np_func.types] return [dict(name=name, dtype=dtype) for dtype in dtypes] @@ -242,6 +242,7 @@ def test_frompyfunc_reduce(self, func, nin, nout, identity, shape, axis, dtype): self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker) self._CompileAndCheck(jnp_fun, args_maker) + @jtu.sample_product( BINARY_UFUNCS_WITH_DTYPES, [{'shape': shape, 'axis': axis} @@ -324,6 +325,64 @@ def test_binary_ufunc_reduce_where(self, name, shape, axis, dtype): self._CheckAgainstNumpy(jnp_fun_reduce, np_fun_reduce, args_maker, tol=tol) self._CompileAndCheck(jnp_fun_reduce, args_maker) + @jtu.sample_product( + BINARY_UFUNCS_WITH_DTYPES, + [{'shape': shape, 'axis': axis} + for shape in nonscalar_shapes + for axis in [None, *range(-len(shape), len(shape))]], + ) + def test_binary_ufunc_reduce_initial(self, name, shape, axis, dtype): + jnp_fun = getattr(jnp, name) + np_fun = getattr(np, name) + + if jnp_fun.identity is None and axis is None and len(shape) > 1: + self.skipTest("Multiple-axis reduction over non-reorderable ufunc.") + + jnp_fun_reduce = lambda a, initial: jnp_fun.reduce(a, axis=axis, initial=initial) + np_fun_reduce = lambda a, initial: np_fun.reduce(a, axis=axis, initial=initial) + + rng = jtu.rand_default(self.rng()) + rng_initial = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(shape, dtype), rng_initial((), dtype)] + + tol = {np.float32: 1E-4} if jtu.test_device_matches(['tpu']) else None + + self._CheckAgainstNumpy(jnp_fun_reduce, np_fun_reduce, args_maker, tol=tol) + self._CompileAndCheck(jnp_fun_reduce, args_maker) + + @jtu.sample_product( + BINARY_UFUNCS_WITH_DTYPES, + [{'shape': shape, 'axis': axis} + for shape in nonscalar_shapes + for axis in [None, *range(-len(shape), len(shape))]], + ) + def test_binary_ufunc_reduce_where_initial(self, name, shape, axis, dtype): + jnp_fun = getattr(jnp, name) + np_fun = getattr(np, name) + + # Skip if the ufunc doesn't have an identity and we're doing a multi-axis reduction + if jnp_fun.identity is None and axis is None and len(shape) > 1: + self.skipTest("Multiple-axis reduction over non-reorderable ufunc.") + + jnp_fun_reduce = lambda a, where, initial: jnp_fun.reduce( + a, axis=axis, where=where, initial=initial) + np_fun_reduce = lambda a, where, initial: np_fun.reduce( + a, axis=axis, where=where, initial=initial) + + rng = jtu.rand_default(self.rng()) + rng_where = jtu.rand_bool(self.rng()) + rng_initial = jtu.rand_default(self.rng()) + args_maker = lambda: [ + rng(shape, dtype), + rng_where(shape, bool), + rng_initial((), dtype) + ] + + tol = {np.float32: 1E-4} if jtu.test_device_matches(['tpu']) else None + + self._CheckAgainstNumpy(jnp_fun_reduce, np_fun_reduce, args_maker, tol=tol) + self._CompileAndCheck(jnp_fun_reduce, args_maker) + @jtu.sample_product( SCALAR_FUNCS, [{'shape': shape, 'axis': axis} @@ -484,7 +543,7 @@ def test_binary_ufunc_reduceat(self, name, shape, axis, idx_shape, dtype): if (jnp_fun.nin, jnp_fun.nout) != (2, 1): self.skipTest(f"accumulate requires (nin, nout)=(2, 1); got {(jnp_fun.nin, jnp_fun.nout)=}") if name in ['add', 'multiply'] and dtype == bool: - # TODO(jakevdp): figure out how to fix thest cases. + # TODO(jakevdp): figure out how to fix test cases. 
self.skipTest(f"known failure for {name}.reduceat with {dtype=}") rng = jtu.rand_default(self.rng()) diff --git a/tests/lax_numpy_vectorize_test.py b/tests/lax_numpy_vectorize_test.py index 985dba484845..8fbd393dc3f8 100644 --- a/tests/lax_numpy_vectorize_test.py +++ b/tests/lax_numpy_vectorize_test.py @@ -287,6 +287,15 @@ def test_rank_promotion_error(self): with self.assertNoWarnings(): f2(rank2, rank1) + def test_non_scalar_outputs_and_default_signature(self): + def f(x): + self.assertEqual(np.shape(x), ()) + return x + jnp.linspace(-1, 1, out_dim) + + out_dim = 5 + self.assertEqual(jnp.vectorize(f)(0.5).shape, (out_dim,)) + self.assertEqual(jnp.vectorize(f)(jnp.ones(3)).shape, (3, out_dim)) + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/lax_scipy_special_functions_test.py b/tests/lax_scipy_special_functions_test.py index f4e4e4f48213..9fc7619f7145 100644 --- a/tests/lax_scipy_special_functions_test.py +++ b/tests/lax_scipy_special_functions_test.py @@ -157,6 +157,10 @@ def op_record(name, nargs, dtypes, rng_factory, test_grad, nondiff_argnums=(), t "hyp1f1", 3, float_dtypes, functools.partial(jtu.rand_uniform, low=0.5, high=30), True ), + op_record( + "hyp2f1", 4, float_dtypes, + functools.partial(jtu.rand_uniform, low=0.1, high=0.9), False + ), op_record("log_softmax", 1, float_dtypes, jtu.rand_default, True), op_record("softmax", 1, float_dtypes, jtu.rand_default, True), ] @@ -256,6 +260,17 @@ def testNdtriExtremeValues(self): self._CheckAgainstNumpy(osp_special.ndtri, lsp_special.ndtri, args_maker, rtol=rtol) self._CompileAndCheck(lsp_special.ndtri, args_maker, rtol=rtol) + @parameterized.parameters([True, False]) + def testNdtriDebugInfs(self, with_jit): + # ref: https://github.com/jax-ml/jax/issues/29328 + f = jax.jit(lsp_special.ndtri) if with_jit else lsp_special.ndtri + with jax.debug_infs(True): + f(0.5) # Doesn't crash + with self.assertRaisesRegex(FloatingPointError, "invalid value \\(inf\\)"): + f(1.0) + with self.assertRaisesRegex(FloatingPointError, "invalid value \\(inf\\)"): + f(0.0) + def testRelEntrExtremeValues(self): # Testing at the extreme values (bounds (0. and 1.) and outside the bounds). dtype = jnp.zeros(0).dtype # default float dtype. @@ -288,35 +303,84 @@ def testExpiDisableJit(self): self.assertAllClose(result_jit, result_nojit) def testGammaIncBoundaryValues(self): - dtype = jax.numpy.zeros(0).dtype # default float dtype. + dtype = jax.dtypes.canonicalize_dtype(float) nan = float('nan') inf = float('inf') if jtu.parse_version(scipy.__version__) >= (1, 16): - samples_slice = slice(None) + a_samples = [0, 0, 0, 1, nan, 1, nan, 0, 1, 1, nan] + x_samples = [0, 1, 2, 0, 1, nan, nan, inf, inf, -1, inf] else: # disable samples that contradict with scipy/scipy#22441 - samples_slice = slice(None, -1) - args_maker = lambda: [np.array([0, 0, 0, 1, nan, 1, nan, 0, 1, nan][samples_slice]).astype(dtype), - np.array([0, 1, 2, 0, 1, nan, nan, inf, inf, inf][samples_slice]).astype(dtype)] + a_samples = [0, 0, 0, 1, nan, 1, nan, 0, 1, 1] + x_samples = [0, 1, 2, 0, 1, nan, nan, inf, inf, -1] + + args_maker = lambda: (np.array(a_samples, dtype=dtype), np.array(x_samples, dtype=dtype)) + rtol = 1E-3 if jtu.test_device_matches(["tpu"]) else 1e-5 self._CheckAgainstNumpy(lsp_special.gammainc, osp_special.gammainc, args_maker, rtol=rtol) self._CompileAndCheck(lsp_special.gammainc, args_maker, rtol=rtol) def testGammaIncCBoundaryValues(self): - dtype = jax.numpy.zeros(0).dtype # default float dtype. 
+ dtype = jax.dtypes.canonicalize_dtype(float) nan = float('nan') inf = float('inf') if jtu.parse_version(scipy.__version__) >= (1, 16): - samples_slice = slice(None) + a_samples = [0, 0, 0, 1, nan, 1, nan, 0, 1, 1, nan] + x_samples = [0, 1, 2, 0, 1, nan, nan, inf, inf, -1, inf] else: # disable samples that contradict with scipy/scipy#22441 - samples_slice = slice(None, -1) - args_maker = lambda: [np.array([0, 0, 0, 1, nan, 1, nan, 0, 1, 1, nan][samples_slice]).astype(dtype), - np.array([0, 1, 2, 0, 1, nan, nan, inf, inf, -1, inf][samples_slice]).astype(dtype)] + a_samples = [0, 0, 0, 1, nan, 1, nan, 0, 1, 1] + x_samples = [0, 1, 2, 0, 1, nan, nan, inf, inf, -1] + + args_maker = lambda: (np.array(a_samples, dtype=dtype), np.array(x_samples, dtype=dtype)) + rtol = 1E-3 if jtu.test_device_matches(["tpu"]) else 1e-5 self._CheckAgainstNumpy(lsp_special.gammaincc, osp_special.gammaincc, args_maker, rtol=rtol) self._CompileAndCheck(lsp_special.gammaincc, args_maker, rtol=rtol) + def testBetaIncBoundaryValues(self): + dtype = jax.dtypes.canonicalize_dtype(float) + fi = jax.numpy.finfo(dtype) + nan = float('nan') + inf = float('inf') + tiny = fi.tiny + eps = fi.eps + if jtu.parse_version(scipy.__version__) >= (1, 16): + # TODO(pearu): enable tiny samples when a fix to scipy/scipy#22682 + # will be available + a_samples = [nan, -0.5, inf, 0, eps, 1, tiny][:-1] + b_samples = [nan, -0.5, inf, 0, eps, 1, tiny][:-1] + elif jtu.parse_version(scipy.__version__) >= (1, 12): + # disabled samples that contradict with scipy/scipy#22425 + a_samples = [nan, -0.5, 0.5] + b_samples = [nan, -0.5, 0.5] + else: + a_samples = [-0.5, 0.5] + b_samples = [-0.5, 0.5] + x_samples = [nan, -0.5, 0, 0.5, 1, 1.5] + + a_samples = np.array(a_samples, dtype=dtype) + b_samples = np.array(b_samples, dtype=dtype) + x_samples = np.array(x_samples, dtype=dtype) + + args_maker = lambda: np.meshgrid(a_samples, b_samples, x_samples) + + rtol = 1E-3 if jtu.test_device_matches(["tpu"]) else 5e-5 + self._CheckAgainstNumpy(osp_special.betainc, lsp_special.betainc, args_maker, rtol=rtol) + self._CompileAndCheck(lsp_special.betainc, args_maker, rtol=rtol) + + def testHyp2f1SpecialCases(self): + dtype = jax.dtypes.canonicalize_dtype(float) + + a_samples = np.array([0, 1, 1, 1, 1, 5, 5, 0.245, 0.45, 0.45, 2, 0.4, 0.32, 4, 4], dtype=dtype) + b_samples = np.array([1, 0, 1, 1, 1, 1, 1, 3, 0.7, 0.7, 1, 0.7, 0.76, 2, 3], dtype=dtype) + c_samples = np.array([1, 1, 0, 1, -1, 3, 3, 3, 0.45, 0.45, 5, 0.3, 0.11, 7, 7], dtype=dtype) + x_samples = np.array([1, 1, 1, 0, 1, 0.5, 1, 0.35, 0.35, 1.5, 1, 0.4, 0.95, 0.95, 0.95], dtype=dtype) + + args_maker = lambda: (a_samples, b_samples, c_samples, x_samples) + rtol = 1E-3 if jtu.test_device_matches(["tpu"]) else 5e-5 + self._CheckAgainstNumpy(osp_special.hyp2f1, lsp_special.hyp2f1, args_maker, rtol=rtol) + self._CompileAndCheck(lsp_special.hyp2f1, args_maker, rtol=rtol) if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/lax_scipy_spectral_dac_test.py b/tests/lax_scipy_spectral_dac_test.py index a09dcac5371c..4359318a7997 100644 --- a/tests/lax_scipy_spectral_dac_test.py +++ b/tests/lax_scipy_spectral_dac_test.py @@ -18,7 +18,7 @@ from jax import lax from jax import numpy as jnp from jax._src import test_util as jtu -from jax._src.lax import eigh as lax_eigh +from jax._src.tpu.linalg import eigh as lax_eigh from absl.testing import absltest diff --git a/tests/lax_scipy_test.py b/tests/lax_scipy_test.py index 388d053d9608..610bf5fabefd 100644 --- 
a/tests/lax_scipy_test.py +++ b/tests/lax_scipy_test.py @@ -113,7 +113,12 @@ def testLogSumExp(self, shapes, dtype, axis, keepdims, return_sign, use_b): if jnp.issubdtype(dtype, jnp.complexfloating) and scipy_version < (1, 13, 0): self.skipTest("logsumexp of complex input uses scipy 1.13.0 semantics.") - if not jtu.test_device_matches(["cpu"]): + if use_b and scipy_version >= (1, 15) and scipy_version < (1, 15, 3): + self.skipTest( + "TODO(https://github.com/scipy/scipy/issues/22903): logsumexp with a" + " b scale array is buggy in scipy 1.15" + ) + if not jtu.test_device_matches(["cpu", "gpu"]): rng = jtu.rand_some_inf_and_nan(self.rng()) else: rng = jtu.rand_default(self.rng()) @@ -339,8 +344,8 @@ def scipy_fun(z): ) @jtu.ignore_warning(category=DeprecationWarning, message=".*scipy.special.lpmn.*") def testLpmn(self, l_max, shape, dtype): - if jtu.is_device_tpu(6, "e"): - self.skipTest("TODO(b/364258243): fails on TPU v6e") + if jtu.is_device_tpu_at_least(6): + self.skipTest("TODO(b/364258243): fails on TPU v6+") rng = jtu.rand_uniform(self.rng(), low=-0.2, high=0.9) args_maker = lambda: [rng(shape, dtype)] @@ -461,8 +466,8 @@ def testSphHarmOrderOneDegreeOne(self): @jax.numpy_dtype_promotion('standard') # This test explicitly exercises dtype promotion def testSphHarmForJitAndAgainstNumpy(self, l_max, num_z, dtype): """Tests against JIT compatibility and Numpy.""" - if jtu.is_device_tpu(6, "e"): - self.skipTest("TODO(b/364258243): fails on TPU v6e") + if jtu.is_device_tpu_at_least(6): + self.skipTest("TODO(b/364258243): fails on TPU v6+") n_max = l_max shape = (num_z,) rng = jtu.rand_int(self.rng(), -l_max, l_max + 1) @@ -508,8 +513,8 @@ def testSphHarmCornerCaseWithWrongNmax(self): ) @jax.numpy_dtype_promotion('standard') # This test explicitly exercises dtype promotion def testSphHarmY(self, l_max, num_z, dtype): - if jtu.is_device_tpu(6, "e"): - self.skipTest("TODO(b/364258243): fails on TPU v6e") + if jtu.is_device_tpu_at_least(6): + self.skipTest("TODO(b/364258243): fails on TPU v6+") n_max = l_max shape = (num_z,) rng = jtu.rand_int(self.rng(), -l_max, l_max + 1) @@ -641,7 +646,7 @@ def test_spence(self, shape, dtype): ], dtype=float_dtypes + int_dtypes, ) - @jtu.skip_on_devices("tpu") # TODO(jakevdp): fix and reenable this test. + @jtu.skip_on_devices("tpu") # TODO(jakevdp): fix and re-enable this test. @jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion. 
def testIntegrateTrapezoid(self, yshape, xshape, dtype, dx, axis): rng = jtu.rand_default(self.rng()) diff --git a/tests/lax_test.py b/tests/lax_test.py index 8764caeb2e49..2be76913a59f 100644 --- a/tests/lax_test.py +++ b/tests/lax_test.py @@ -29,6 +29,7 @@ import jax from jax._src import core +from jax import export from jax import jvp, grad from jax import lax import jax.numpy as jnp @@ -47,9 +48,8 @@ from jax._src.interpreters import pxla from jax._src.internal_test_util import lax_test_util from jax._src.lax import lax as lax_internal -from jax._src.util import NumpyComplexWarning, safe_zip +from jax._src.util import safe_zip from jax._src.tree_util import tree_map -from jax._src.lib import xla_extension_version config.parse_flags_with_absl() @@ -1128,11 +1128,6 @@ def testDotAlgorithm(self, algorithm, dtype): raise SkipTest( f"The dot algorithm '{algorithm}' is not supported on CPU.") if jtu.test_device_matches(["gpu"]): - if (algorithm == lax.DotAlgorithmPreset.BF16_BF16_F32_X9 and - xla_extension_version < 320): - raise SkipTest( - f"The dot algorithm ${algorithm} requires XLA extension version " - ">= 320.") # GPU algorithm support is a little spotty. It is checked in # xla/service/algorithm_util.cc and the logic is copied here. if algorithm in { @@ -2636,6 +2631,11 @@ def reference_top_k(x): self._CheckAgainstNumpy(op, reference_top_k, args_maker) self._CompileAndCheck(op, args_maker) + def testTopKOverflow(self): + x = jax.ShapeDtypeStruct((2 ** 31 + 1,), np.dtype('bfloat16')) + with self.assertRaisesRegex(ValueError, "top_k returns int32 indices, which will overflow"): + jax.eval_shape(lambda x: jax.lax.top_k(x, 100), x) + @jtu.sample_product( [dict(lhs_shape=lhs_shape, rhs_shape=rhs_shape) for lhs_shape, rhs_shape in [((3, 2), (2, 4)), @@ -3627,6 +3627,37 @@ def f(x): g = jax.grad(f)(5.) 
# doesn't crash self.assertAllClose(g, 3., check_dtypes=False) + def test_shape_as_value_handles_static_shapes(self): + result = lax.shape_as_value(()) + self.assertArraysEqual(result, lax.full((0,), np.array(0, np.int64))) + + result = lax.shape_as_value((2,)) + self.assertArraysEqual(result, np.asarray((2,), np.int64)) + + result = lax.shape_as_value((2, 3)) + self.assertArraysEqual(result, np.asarray((2, 3), np.int64)) + + def test_shape_as_value_handles_polymorphic_shapes(self): + @jax.jit + def f(x): + return lax.shape_as_value(x.shape) + + exported = export.export(f)( + jax.ShapeDtypeStruct(export.symbolic_shape("a"), jnp.float32) + ) + result = exported.call(np.ones((1), dtype=np.float32)) + self.assertArraysEqual(result, np.asarray((1,), np.int64)) + result = exported.call(np.ones((2), dtype=np.float32)) + self.assertArraysEqual(result, np.asarray((2,), np.int64)) + + exported = export.export(f)( + jax.ShapeDtypeStruct(export.symbolic_shape("a, b"), jnp.float32) + ) + result = exported.call(np.ones((1, 2), dtype=np.float32)) + self.assertArraysEqual(result, np.asarray((1, 2), np.int64)) + result = exported.call(np.ones((3, 4), dtype=np.float32)) + self.assertArraysEqual(result, np.asarray((3, 4), np.int64)) + class LazyConstantTest(jtu.JaxTestCase): def _Check(self, make_const, expected): @@ -3718,7 +3749,7 @@ def testConvertElementReturnType(self, input_type, dtype, value, jit): @jtu.sample_product( dtype_in=lax_test_util.all_dtypes, dtype_out=lax_test_util.all_dtypes) - @jtu.ignore_warning(category=NumpyComplexWarning) + @jtu.ignore_warning(category=np.exceptions.ComplexWarning) def testConvertElementTypeAvoidsCopies(self, dtype_in, dtype_out): x = jax.device_put(np.zeros(5, dtype_in)) self.assertEqual(x.dtype, dtype_in) @@ -4369,7 +4400,7 @@ def _testOnComplexPlaneWorker(self, name, dtype, kind): # # In addition, the 1/3 middle parts of regions q1, q2, q3, q4, # neg, pos are tested separately as these don't contain extremely - # small or extremelly large values and functions on these regions + # small or extremely large values and functions on these regions # ought not to possess any incorrectness issues. s0, s1 = size_re, size_im @@ -4747,7 +4778,7 @@ def my_square(x): ValueError, "JVP rule for composite not implemented. You can use `jax.custom_jvp` " "to add support. See " - "https://jax.readthedocs.io/en/latest/_autosummary/jax.custom_jvp.html" + "https://docs.jax.dev/en/latest/_autosummary/jax.custom_jvp.html" ): jvp(my_square, (1.0,), (2.0,)) @@ -4760,7 +4791,7 @@ def my_square(x): ValueError, "JVP rule for composite not implemented. You can use `jax.custom_jvp` " "to add support. See " - "https://jax.readthedocs.io/en/latest/_autosummary/jax.custom_jvp.html" + "https://docs.jax.dev/en/latest/_autosummary/jax.custom_jvp.html" ): grad(my_square)(1.0) @@ -4802,10 +4833,10 @@ class RaggedTest(jtu.JaxTestCase): @jtu.sample_product( [ - {'m': 5, 'k': 4, 'n': 3, 'num_groups': 1}, - {'m': 10, 'k': 9, 'n': 8, 'num_groups': 2}, + {'m': 64, 'k': 4, 'n': 3, 'num_groups': 1}, + {'m': 64, 'k': 9, 'n': 8, 'num_groups': 2}, ], - dtype=jtu.dtypes.numeric, + dtype=jtu.dtypes.all_floating, ) def test_ragged_dot(self, m, k, n, num_groups, dtype): """Tests ragged_dot. @@ -4816,6 +4847,8 @@ def test_ragged_dot(self, m, k, n, num_groups, dtype): Raises: SkipTest: in the case dtype is not supported. 
""" + if (dtype == np.float16): + raise SkipTest(f"unsupported dtype for ragged_dot: {dtype}") lhs_shape = (m, k) rhs_shape = (num_groups, k, n) @@ -4837,6 +4870,25 @@ def group_sizes(m, num_groups): self._CheckAgainstNumpy( lax_reference.ragged_dot, lax.ragged_dot, args_maker) + @parameterized.parameters( + { "m": 5, "k": 4, "n": 3, "num_groups": 1}, + { "m": 10, "k": 9, "n": 8, "num_groups": 2}, + ) + def test_ragged_dot_unsupported( + self, m, k, n, num_groups): + lhs_shape = (m, k) + rhs_shape = (num_groups, k, n) + group_sizes_shape = (num_groups,) + + args_maker = lambda: [ + jnp.ones(lhs_shape, dtype=jnp.float32), + jnp.ones(rhs_shape, dtype=jnp.float32), + jnp.ones(group_sizes_shape, dtype=jnp.int32), + ] + if jtu.test_device_matches(["tpu"]): + with self.assertRaises(jax.errors.JaxRuntimeError): + self._CompileAndCheck(lax.ragged_dot, args_maker) + @parameterized.parameters( { "lhs_shape": lhs_shape, @@ -5055,10 +5107,69 @@ def test_ragged_dot_general_shape_inference_success( lhs = jnp.ones(lhs_shape, dtype=jnp.float32) rhs = jnp.ones(rhs_shape, dtype=jnp.float32) group_sizes = jnp.ones(group_sizes_shape, dtype=jnp.int32) - self.assertEqual( - lax.ragged_dot_general(lhs, rhs, group_sizes, ragged_dnums).shape, - out_shape, + if jtu.test_device_matches(["tpu"]): + actual_shape = lax_internal._ragged_dot_general_shape_rule( + lhs, rhs, group_sizes, ragged_dot_dimension_numbers=ragged_dnums, + precision=jax.lax.Precision.DEFAULT, + preferred_element_type=jnp.float32, + ) + else: + actual_shape = lax.ragged_dot_general( + lhs, rhs, group_sizes, ragged_dnums + ).shape + self.assertEqual(actual_shape, out_shape) + + @parameterized.product( + batch_size=[3, 5], + m=[128, 1024], + k=[128, 1024], + n=[128, 1024], + num_groups=[2, 4], + ) + def test_ragged_dot_general_vmap( + self, batch_size: int, m: int, k: int, n: int, num_groups: int + ): + if (jtu.test_device_matches(["tpu"])): + raise SkipTest("batched ragged_dot not yet supported on TPU") + + lhs_shape = (batch_size, m, k) + rhs_shape = (batch_size, num_groups, k, n) + dtype = jnp.float32 + + def make_group_sizes(m, num_groups): + ends_no_final = jnp.sort(self.rng().choice(m, size=num_groups - 1)) + ends = jnp.concatenate( + [ends_no_final, jnp.array([m], dtype=ends_no_final.dtype)]) + starts = jnp.concatenate( + [jnp.zeros(1, dtype=ends_no_final.dtype), ends_no_final]) + return ends - starts + + rng = jtu.rand_small(self.rng()) + args_maker = lambda: [ + rng(lhs_shape, dtype), + rng(rhs_shape, dtype), + jnp.array([make_group_sizes(m, num_groups) for _ in range(batch_size)]), + ] + lhs, rhs, group_sizes = args_maker() + + out_dtype = jnp.float32 + precision = jax.lax.Precision.HIGHEST + ragged_dot = partial( + jax.lax.ragged_dot, + preferred_element_type=out_dtype, + precision=precision, ) + tol = 1e-5 + + batch_res = jax.vmap(ragged_dot)(lhs, rhs, group_sizes) + for i in range(batch_size): + # The ragged_dot does not zero out the output in the case sum(group_sizes) + # < m, hence we need to compare only the valid part of the output. 
+ upper_bound = group_sizes[i].sum(axis=0) + ref_res = ragged_dot(lhs[i], rhs[i], group_sizes[i])[0:upper_bound, :] + self.assertArraysAllClose( + batch_res[i, 0:upper_bound, :], ref_res, rtol=tol, atol=tol + ) if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/layout_test.py b/tests/layout_test.py index b9062b8d21dc..ce0ca17b05de 100644 --- a/tests/layout_test.py +++ b/tests/layout_test.py @@ -21,9 +21,10 @@ import jax.numpy as jnp from jax.sharding import NamedSharding, PartitionSpec as P, SingleDeviceSharding from jax._src import config -from jax._src.layout import Layout, DeviceLocalLayout as DLL from jax._src import test_util as jtu from jax._src.util import safe_zip +from jax.experimental.layout import (with_layout_constraint, Format, + DeviceLocalLayout as DLL) from jax.experimental.compute_on import compute_on config.parse_flags_with_absl() @@ -50,22 +51,22 @@ def init(x, y): sds1 = jax.ShapeDtypeStruct(np_inp1.shape, np_inp1.dtype, sharding=s1) sds2 = jax.ShapeDtypeStruct(np_inp2.shape, np_inp2.dtype, sharding=s2) - lowered_apply = jax.jit(apply, in_shardings=Layout(DLL.AUTO), - out_shardings=Layout(DLL.AUTO)).lower(sds1, sds2) + lowered_apply = jax.jit(apply, in_shardings=Format(DLL.AUTO), + out_shardings=Format(DLL.AUTO)).lower(sds1, sds2) compiled_apply = lowered_apply.compile() - arg_layouts, kw_layouts = compiled_apply.input_layouts + arg_formats, kw_layouts = compiled_apply.input_formats self.assertEmpty(kw_layouts) - for i, o in zip(arg_layouts, compiled_apply.output_layouts): + for i, o in zip(arg_formats, compiled_apply.output_formats): self.assertEqual(i.device_local_layout.major_to_minor, o.device_local_layout.major_to_minor[::-1]) init_compiled = jax.jit( - init, out_shardings=arg_layouts).lower(sds1, sds2).compile() + init, out_shardings=arg_formats).lower(sds1, sds2).compile() - for i, o in zip(init_compiled.input_layouts[0], - init_compiled.output_layouts): + for i, o in zip(init_compiled.input_formats[0], + init_compiled.output_formats): self.assertEqual(i, o) arr1 = jax.device_put(np_inp1, s1) @@ -76,21 +77,21 @@ def init(x, y): init_compiled(arr1, arr2) self.assertEqual(init_count(), 1) - self.assertEqual(init_out[0].layout, init_compiled.output_layouts[0]) - self.assertEqual(init_out[1].layout, init_compiled.output_layouts[1]) + self.assertEqual(init_out[0].format, init_compiled.output_formats[0]) + self.assertEqual(init_out[1].format, init_compiled.output_formats[1]) with jtu.count_aot_jit_cpp_cache_miss() as apply_count: apply_out = compiled_apply(*init_out) compiled_apply(*init_out) self.assertEqual(apply_count(), 1) - self.assertEqual(apply_out[0].layout, compiled_apply.output_layouts[0]) - self.assertEqual(apply_out[1].layout, compiled_apply.output_layouts[1]) + self.assertEqual(apply_out[0].format, compiled_apply.output_formats[0]) + self.assertEqual(apply_out[1].format, compiled_apply.output_formats[1]) - self.assertTupleEqual(apply_out[0].layout.device_local_layout.major_to_minor, - init_out[0].layout.device_local_layout.major_to_minor[::-1]) - self.assertTupleEqual(apply_out[1].layout.device_local_layout.major_to_minor, - init_out[1].layout.device_local_layout.major_to_minor[::-1]) + self.assertTupleEqual(apply_out[0].format.device_local_layout.major_to_minor, + init_out[0].format.device_local_layout.major_to_minor[::-1]) + self.assertTupleEqual(apply_out[1].format.device_local_layout.major_to_minor, + init_out[1].format.device_local_layout.major_to_minor[::-1]) self.assertArraysEqual(init_out[0], np_inp1 * 2) 
self.assertArraysEqual(init_out[1], np_inp2 * 2) @@ -113,21 +114,21 @@ def f(x): out = compiled(arr) self.assertTupleEqual( - compiled.input_layouts[0][0].device_local_layout.major_to_minor[::-1], + compiled.input_formats[0][0].device_local_layout.major_to_minor[::-1], (2, 1, 0)) self.assertTupleEqual( - compiled.output_layouts.device_local_layout.major_to_minor[::-1], + compiled.output_formats.device_local_layout.major_to_minor[::-1], (2, 1, 0)) self.assertArraysEqual(out, np_inp.T) self.assertEqual(out.sharding, NamedSharding(mesh, P(None, 'y', 'x'))) - compiled_auto = jax.jit(f, in_shardings=Layout(DLL.AUTO), - out_shardings=Layout(DLL.AUTO)).lower(sds).compile() + compiled_auto = jax.jit(f, in_shardings=Format(DLL.AUTO), + out_shardings=Format(DLL.AUTO)).lower(sds).compile() self.assertTupleEqual( - compiled_auto.input_layouts[0][0].device_local_layout.major_to_minor[::-1], + compiled_auto.input_formats[0][0].device_local_layout.major_to_minor[::-1], (2, 1, 0)) self.assertTupleEqual( - compiled_auto.output_layouts.device_local_layout.major_to_minor[::-1], + compiled_auto.output_formats.device_local_layout.major_to_minor[::-1], (0, 1, 2)) with self.assertRaisesRegex( @@ -145,18 +146,18 @@ def test_in_layouts_out_layouts(self): def f(x): return x.T - compiled = jax.jit(f, in_shardings=Layout(), - out_shardings=Layout(DLL.AUTO)).lower(arr).compile() + compiled = jax.jit(f, in_shardings=Format(), + out_shardings=Format(DLL.AUTO)).lower(arr).compile() self.assertTupleEqual( - compiled.input_layouts[0][0].device_local_layout.major_to_minor[::-1], + compiled.input_formats[0][0].device_local_layout.major_to_minor[::-1], (1, 0)) self.assertTupleEqual( - compiled.output_layouts.device_local_layout.major_to_minor[::-1], + compiled.output_formats.device_local_layout.major_to_minor[::-1], (0, 1)) out = compiled(arr) self.assertArraysEqual(out, np_inp.T) - self.assertEqual(out.layout, compiled.output_layouts) + self.assertEqual(out.format, compiled.output_formats) self.assertEqual(out.sharding, NamedSharding(mesh, P('y', 'x'))) def test_sharding_and_layouts(self): @@ -165,15 +166,15 @@ def test_sharding_and_layouts(self): np_inp = np.arange(math.prod(shape)).reshape(shape) s = NamedSharding(mesh, P('x', 'y')) - compiled = jax.jit(lambda x: x.T, in_shardings=Layout(DLL.AUTO, s), - out_shardings=Layout(DLL.AUTO, s)).lower(np_inp).compile() + compiled = jax.jit(lambda x: x.T, in_shardings=Format(DLL.AUTO, s), + out_shardings=Format(DLL.AUTO, s)).lower(np_inp).compile() out = compiled(np_inp) self.assertTupleEqual( - compiled.input_layouts[0][0].device_local_layout.major_to_minor[::-1], + compiled.input_formats[0][0].device_local_layout.major_to_minor[::-1], (1, 0)) if not jtu.test_device_matches(['cpu']): self.assertTupleEqual( - compiled.output_layouts.device_local_layout.major_to_minor[::-1], + compiled.output_formats.device_local_layout.major_to_minor[::-1], (0, 1)) self.assertArraysEqual(out, np_inp.T) self.assertEqual(out.sharding, s) @@ -184,21 +185,21 @@ def f(x, y, z, a, b, c): shape = (8, 2) inps = [np.arange(math.prod(shape)).reshape(shape)] * 6 - compiled = jax.jit(f, in_shardings=Layout(DLL.AUTO), - out_shardings=Layout(DLL.AUTO)).lower(*inps).compile() - arg_layouts, _ = compiled.input_layouts + compiled = jax.jit(f, in_shardings=Format(DLL.AUTO), + out_shardings=Format(DLL.AUTO)).lower(*inps).compile() + arg_formats, _ = compiled.input_formats out1, out2 = compiled(*inps) - compiled2 = jax.jit(f, in_shardings=arg_layouts).lower(*inps).compile() + compiled2 = jax.jit(f, 
in_shardings=arg_formats).lower(*inps).compile() out3, out4 = compiled2(*inps) - for l1, l2 in safe_zip(arg_layouts, compiled2.input_layouts[0]): + for l1, l2 in safe_zip(arg_formats, compiled2.input_formats[0]): self.assertEqual(l1, l2) self.assertArraysEqual(out1, out3) self.assertArraysEqual(out2, out4) - arrs = [jax.device_put(i, l) for i, l in zip(inps, arg_layouts)] + arrs = [jax.device_put(i, l) for i, l in zip(inps, arg_formats)] out5, out6 = jax.jit(f)(*arrs) self.assertArraysEqual(out1, out5) self.assertArraysEqual(out2, out6) @@ -215,11 +216,11 @@ def test_no_error_dced_args(self): def f(x, y): return x * 2 - jf = jax.jit(f, in_shardings=Layout(DLL.AUTO, s), - out_shardings=Layout(DLL.AUTO, s)) + jf = jax.jit(f, in_shardings=Format(DLL.AUTO, s), + out_shardings=Format(DLL.AUTO, s)) compiled = jf.lower(np_inp, np_inp).compile() - arg_layouts, _ = compiled.input_layouts - arrs = [jax.device_put(i, l) for i, l in zip(arrs, arg_layouts)] + arg_formats, _ = compiled.input_formats + arrs = [jax.device_put(i, l) for i, l in zip(arrs, arg_formats)] compiled(*arrs) def test_aot_layout_mismatch(self): @@ -243,10 +244,10 @@ def f(x): with self.assertRaisesRegex( ValueError, 'Layout passed to jit does not match the layout on the respective arg'): - jax.jit(f, in_shardings=Layout(DLL.AUTO)).lower(arr) + jax.jit(f, in_shardings=Format(DLL.AUTO)).lower(arr) - compiled = jax.jit(f, in_shardings=Layout(DLL.AUTO), - out_shardings=Layout(DLL.AUTO)).lower(sds).compile() + compiled = jax.jit(f, in_shardings=Format(DLL.AUTO), + out_shardings=Format(DLL.AUTO)).lower(sds).compile() with self.assertRaisesRegex( ValueError, @@ -272,30 +273,30 @@ def test_device_put_concrete_layout(self): arr = jax.device_put(np_inp, s) compiled = jax.jit( - lambda x: x * 2, out_shardings=Layout(DLL.AUTO)).lower(arr).compile() - col = compiled.output_layouts + lambda x: x * 2, out_shardings=Format(DLL.AUTO)).lower(arr).compile() + col = compiled.output_formats out = jax.device_put(np_inp, col) - self.assertEqual(out.layout, col) + self.assertEqual(out.format, col) self.assertArraysEqual(out, np_inp) for s in out.addressable_shards: - self.assertEqual(out.layout.device_local_layout, - s.data.layout.device_local_layout) + self.assertEqual(out.format.device_local_layout, + s.data.format.device_local_layout) def test_device_put_non_concrete_layout_error(self): np_inp = np.arange(16).reshape(8, 2) - l1 = Layout(DLL.AUTO, SingleDeviceSharding(jax.devices()[0])) + l1 = Format(DLL.AUTO, SingleDeviceSharding(jax.devices()[0])) with self.assertRaisesRegex( ValueError, 'sharding and device_local_layout.*should be concrete'): jax.device_put(np_inp, l1) - l2 = Layout(DLL.AUTO) + l2 = Format(DLL.AUTO) with self.assertRaisesRegex( ValueError, 'sharding and device_local_layout.*should be concrete'): jax.device_put(np_inp, l2) - l3 = Layout(None, SingleDeviceSharding(jax.devices()[0])) + l3 = Format(None, SingleDeviceSharding(jax.devices()[0])) out = jax.device_put(np_inp, l3) self.assertArraysEqual(out, np_inp) self.assertTrue(out._committed) @@ -305,7 +306,7 @@ def invalid_layout_spec(self): compiled = jax.jit(lambda x: x).lower(x).compile() with self.assertRaisesRegex( ValueError, 'Sharding has to be concrete when layout.*'): - Layout(compiled.output_layouts[0], None) + Format(compiled.output_formats[0], None) def test_layout_on_sds(self): mesh = jtu.create_mesh((2, 1), ('x', 'y')) @@ -313,18 +314,18 @@ def test_layout_on_sds(self): np_inp = np.arange(16).reshape(8, 2) arr = jax.device_put(np_inp, s) - out_layout = jax.jit(jnp.sin, 
out_shardings=Layout(DLL.AUTO)).lower( - arr).compile().output_layouts + out_format = jax.jit(jnp.sin, out_shardings=Format(DLL.AUTO)).lower( + arr).compile().output_formats - sds = jax.ShapeDtypeStruct(arr.shape, arr.dtype, sharding=out_layout) - arg_layout, _ = jax.jit(lambda x: x * 2).lower(sds).compile().input_layouts - self.assertEqual(arg_layout[0], out_layout) + sds = jax.ShapeDtypeStruct(arr.shape, arr.dtype, sharding=out_format) + arg_format, _ = jax.jit(lambda x: x * 2).lower(sds).compile().input_formats + self.assertEqual(arg_format[0], out_format) with self.assertRaisesRegex( TypeError, 'DeviceLocalLayout.AUTO` cannot be used in place of a device-local' ' layout in a `ShapeDtypeStruct`'): - jax.ShapeDtypeStruct(arr.shape, arr.dtype, sharding=Layout(DLL.AUTO)) + jax.ShapeDtypeStruct(arr.shape, arr.dtype, sharding=Format(DLL.AUTO)) def test_make_array_from_callback(self): mesh = jtu.create_mesh((2, 1), ('x', 'y')) @@ -332,24 +333,24 @@ def test_make_array_from_callback(self): np_inp = np.arange(16).reshape(8, 2) sds = jax.ShapeDtypeStruct(np_inp.shape, np_inp.dtype, sharding=s) - layout = jax.jit(lambda x: x * 2).lower(sds).compile().output_layouts + format = jax.jit(lambda x: x * 2).lower(sds).compile().output_formats - out = jax.make_array_from_callback(np_inp.shape, layout, + out = jax.make_array_from_callback(np_inp.shape, format, lambda idx: np_inp[idx]) self.assertArraysEqual(out, np_inp) - self.assertEqual(out.layout, layout) + self.assertEqual(out.format, format) with self.assertRaisesRegex( TypeError, '`DeviceLocalLayout.AUTO` cannot be used in place of a device-local' ' layout'): - jax.make_array_from_callback(np_inp.shape, Layout(DLL.AUTO, s), + jax.make_array_from_callback(np_inp.shape, Format(DLL.AUTO, s), lambda idx: np_inp[idx]) with self.assertRaisesRegex( TypeError, 'sharding should be an instance of `jax.sharding`'): jax.make_array_from_callback( - np_inp.shape, Layout(None, None), lambda idx: np_inp[idx]) + np_inp.shape, Format(None, None), lambda idx: np_inp[idx]) def test_wsc_concrete_layout(self): mesh = jtu.create_mesh((2, 2), ('x', 'y')) @@ -366,12 +367,12 @@ def f(x): y = x.T # Constrain `y` to the original layout of `arr` because without it, # the layout of `y` would be the transpose of `arr`. - return jax.lax.with_sharding_constraint(y, Layout(custom_dll, s)) + return jax.lax.with_sharding_constraint(y, Format(custom_dll, s)) out = f(arr) - self.assertEqual(out.layout.device_local_layout.major_to_minor, + self.assertEqual(out.format.device_local_layout.major_to_minor, custom_dll.major_to_minor) - self.assertEqual(out.layout, arr.layout) + self.assertEqual(out.format, arr.format) self.assertArraysEqual(out, np_inp.T) def test_wsc_bfloat16_concrete_layout(self): @@ -389,12 +390,12 @@ def f(x): y = x.T # Constrain `y` to the original layout of `arr` because without it, # the layout of `y` would be the transpose of `arr`. 
- return jax.lax.with_sharding_constraint(y, Layout(custom_dll, s)) + return jax.lax.with_sharding_constraint(y, Format(custom_dll, s)) out = f(arr) - self.assertEqual(out.layout.device_local_layout.major_to_minor, + self.assertEqual(out.format.device_local_layout.major_to_minor, custom_dll.major_to_minor) - self.assertEqual(out.layout, arr.layout) + self.assertEqual(out.format, arr.format) self.assertArraysEqual(out, inp.T) def test_device_put_user_concrete_layout(self): @@ -403,8 +404,8 @@ def test_device_put_user_concrete_layout(self): dll = DLL(major_to_minor=(1, 0)) s = SingleDeviceSharding(jax.devices()[0]) - out = jax.device_put(np_inp, Layout(dll, s)) - self.assertEqual(out.layout.device_local_layout.major_to_minor, + out = jax.device_put(np_inp, Format(dll, s)) + self.assertEqual(out.format.device_local_layout.major_to_minor, dll.major_to_minor) self.assertArraysEqual(out, np_inp) @@ -416,18 +417,18 @@ def test_device_put_user_concrete_layout_multi_device(self): jnp_inp = jnp.arange(math.prod(shape)).reshape(shape) arr = jax.device_put(np_inp, s) - custom_layout = Layout(DLL(major_to_minor=(0, 1)), s) - out1 = jax.device_put(arr, custom_layout) + custom_format = Format(DLL(major_to_minor=(0, 1)), s) + out1 = jax.device_put(arr, custom_format) with jax.sharding.use_mesh(mesh): - out2 = jax.device_put(arr, custom_layout) - out3 = jax.device_put(jnp_inp, custom_layout) - out4 = jax.device_put(np_inp, custom_layout) + out2 = jax.device_put(arr, custom_format) + out3 = jax.device_put(jnp_inp, custom_format) + out4 = jax.device_put(np_inp, custom_format) for o in [out1, out2, out3, out4]: self.assertArraysEqual(o, np_inp) - self.assertEqual(o.layout.device_local_layout.major_to_minor, - custom_layout.device_local_layout.major_to_minor) + self.assertEqual(o.format.device_local_layout.major_to_minor, + custom_format.device_local_layout.major_to_minor) def test_concrete_layout_jit(self): mesh = jtu.create_mesh((2, 2), ('x', 'y')) @@ -440,16 +441,16 @@ def f(x): return x.T custom_dll = DLL(major_to_minor=(0, 1)) - f = jax.jit(f, out_shardings=Layout(custom_dll, s)) + f = jax.jit(f, out_shardings=Format(custom_dll, s)) out = f(arr) self.assertArraysEqual(out, np_inp.T) - self.assertEqual(out.layout.device_local_layout.major_to_minor, + self.assertEqual(out.format.device_local_layout.major_to_minor, custom_dll.major_to_minor) def test_compatible_aval_error(self): custom_dll = DLL(major_to_minor=(0, 1, 2)) - l = Layout(custom_dll, SingleDeviceSharding(jax.devices()[0])) + l = Format(custom_dll, SingleDeviceSharding(jax.devices()[0])) inp = np.arange(8) @partial(jax.jit, in_shardings=l) @@ -463,7 +464,7 @@ def f(x): def test_incompatible_aval_error_device_put(self): custom_dll = DLL(major_to_minor=(0, 1, 2)) - l = Layout(custom_dll, SingleDeviceSharding(jax.devices()[0])) + l = Format(custom_dll, SingleDeviceSharding(jax.devices()[0])) inp = np.arange(8) with self.assertRaisesRegex( @@ -481,19 +482,19 @@ def test_concrete_layout_in_shardings(self): custom_dll = DLL(major_to_minor=(0, 1)) @partial(jax.jit, - in_shardings=Layout(custom_dll, s), - out_shardings=Layout(DLL.AUTO)) + in_shardings=Format(custom_dll, s), + out_shardings=Format(DLL.AUTO)) def f(x): return x.T out = f(arr) self.assertArraysEqual(out, np_inp.T) - self.assertEqual(out.layout.device_local_layout.major_to_minor, + self.assertEqual(out.format.device_local_layout.major_to_minor, custom_dll.major_to_minor[::-1]) custom_dll2 = DLL(major_to_minor=(1, 0)) - @partial(jax.jit, in_shardings=Layout(custom_dll2, s)) + @partial(jax.jit, 
in_shardings=Format(custom_dll2, s)) def g(x): return x.T @@ -507,7 +508,7 @@ def test_in_layouts_jit_jnp_input(self): sharding = jax.sharding.SingleDeviceSharding(jax.devices()[0]) f = jax.jit(lambda x: x + 1, - in_shardings=Layout(major_last_layout, sharding)) + in_shardings=Format(major_last_layout, sharding)) arr = jnp.arange(8 * 128).reshape(8, 128) out = f(arr) @@ -532,9 +533,9 @@ def test_layout_donation(self): np_inp = np.arange(math.prod(shape)).reshape(shape) custom_dll = DLL(major_to_minor=(0, 1)) - arr = jax.device_put(np_inp, Layout(custom_dll, s)) + arr = jax.device_put(np_inp, Format(custom_dll, s)) - @partial(jax.jit, in_shardings=Layout(custom_dll, s), donate_argnums=0) + @partial(jax.jit, in_shardings=Format(custom_dll, s), donate_argnums=0) def f(x): return x @@ -549,7 +550,7 @@ def test_layout_donation_auto(self): arr = jax.device_put(np_inp, s) - @partial(jax.jit, out_shardings=Layout(DLL.AUTO), donate_argnums=0) + @partial(jax.jit, out_shardings=Format(DLL.AUTO), donate_argnums=0) def f(x): return x * x @@ -563,7 +564,7 @@ def test_layout_donation_matching_in_and_out(self): np_inp = np.arange(math.prod(shape)).reshape(shape) custom_dll = DLL(major_to_minor=(0, 1)) - l = Layout(custom_dll, s) + l = Format(custom_dll, s) arr = jax.device_put(np_inp, l) @partial(jax.jit, in_shardings=l, out_shardings=l, donate_argnums=0) @@ -581,7 +582,7 @@ def test_layout_donation_mismatching_in_and_out_fails(self): np_inp = np.arange(math.prod(shape), dtype=jnp.bfloat16).reshape(shape) custom_dll1 = DLL(major_to_minor=(1, 0), _tiling=((8,128), (2,1))) - l1 = Layout(custom_dll1, s) + l1 = Format(custom_dll1, s) arr = jax.device_put(np_inp, s) @partial(jax.jit, out_shardings=l1, donate_argnums=0) @@ -593,7 +594,7 @@ def f(x): self.assertFalse(arr.is_deleted()) def test_donation_error_on_auto(self): - @partial(jax.jit, donate_argnums=0, in_shardings=Layout(DLL.AUTO)) + @partial(jax.jit, donate_argnums=0, in_shardings=Format(DLL.AUTO)) def f(x): return x * 2 @@ -601,7 +602,7 @@ def f(x): ValueError, ".*Did you mean to set the.*output layout.*AUTO.*"): f(jnp.arange(8)) - @partial(jax.jit, donate_argnums=0, out_shardings=Layout(DLL.AUTO)) + @partial(jax.jit, donate_argnums=0, out_shardings=Format(DLL.AUTO)) def g(x): return x * 2 @@ -618,16 +619,16 @@ def test_sparsecore_compute(self): dll = DLL(major_to_minor=(0, 1), _tiling=((8,),)) s = SingleDeviceSharding(jax.devices()[0]) - sparse_layout = Layout(dll, s) - sparecore_arr = jax.device_put(inp, sparse_layout) - dense_layout = Layout(DLL(major_to_minor=(0, 1)), s) + sparse_format = Format(dll, s) + sparecore_arr = jax.device_put(inp, sparse_format) + dense_format = Format(DLL(major_to_minor=(0, 1)), s) @compute_on('tpu_sparsecore') @jax.jit def sparsecore_compute(x): return x * x - @partial(jax.jit, out_shardings=(dense_layout, sparse_layout)) + @partial(jax.jit, out_shardings=(dense_format, sparse_format)) def f(x, y): return x * 2, sparsecore_compute(y) @@ -644,8 +645,8 @@ def test_sparsecore_compute_twice(self): dll = DLL(major_to_minor=(0, 1), _tiling=((8,),)) s = SingleDeviceSharding(jax.devices()[0]) - sparse_layout = Layout(dll, s) - sparecore_arr = jax.device_put(inp, sparse_layout) + sparse_format = Format(dll, s) + sparecore_arr = jax.device_put(inp, sparse_format) @compute_on('tpu_sparsecore') @jax.jit @@ -657,7 +658,7 @@ def sparsecore_multiply(x, y): def sparsecore_add(x, y): return x + y - @partial(jax.jit, donate_argnums=0, out_shardings=sparse_layout) + @partial(jax.jit, donate_argnums=0, out_shardings=sparse_format) def 
f(x): return sparsecore_multiply(sparsecore_add(x, x) + 1, x) @@ -674,12 +675,12 @@ def test_sparsecore_and_host_compute(self): s = SingleDeviceSharding(jax.devices()[0]) sparse_dll = DLL(major_to_minor=(0, 1), _tiling=((8,),)) - sparse_layout = Layout(sparse_dll, s) - sparecore_arr = jax.device_put(inp, sparse_layout) + sparse_format = Format(sparse_dll, s) + sparecore_arr = jax.device_put(inp, sparse_format) host_dll = DLL(major_to_minor=(0, 1), _tiling=((1,),)) - host_layout = Layout(host_dll, s) - host_arr = jax.device_put(inp, host_layout) + host_format = Format(host_dll, s) + host_arr = jax.device_put(inp, host_format) @compute_on('tpu_sparsecore') @jax.jit @@ -693,8 +694,8 @@ def host_compute(x): @partial( jax.jit, - in_shardings=(sparse_layout, host_layout), - out_shardings=(sparse_layout, host_layout), + in_shardings=(sparse_format, host_format), + out_shardings=(sparse_format, host_format), ) def f(x, y): return sparsecore_compute(x), host_compute(y) @@ -708,9 +709,9 @@ def test_cpp_layout_cache_miss(self): np_inp = np.arange(math.prod(shape)).reshape(shape) arr = jax.device_put(np_inp, s) - arr_m2m = arr.layout.device_local_layout.major_to_minor - custom_layout = Layout(DLL(major_to_minor=arr_m2m[::-1]), s) - arr2 = jax.device_put(np_inp, custom_layout) + arr_m2m = arr.format.device_local_layout.major_to_minor + custom_format = Format(DLL(major_to_minor=arr_m2m[::-1]), s) + arr2 = jax.device_put(np_inp, custom_format) @jax.jit def f(x): @@ -730,9 +731,9 @@ def test_layout_donation_with_default_layout(self): shape = (16, 16) np_inp = np.arange(math.prod(shape)).reshape(shape) arr = jax.device_put(np_inp, s) - out_layout = Layout(arr.layout.device_local_layout, s) + out_format = Format(arr.format.device_local_layout, s) - @partial(jax.jit, out_shardings=out_layout, donate_argnums=0) + @partial(jax.jit, out_shardings=out_format, donate_argnums=0) def f(x): return x * 2 @@ -742,7 +743,37 @@ def f(x): out = f(arr) self.assertArraysEqual(out, np_inp * 2) - self.assertEqual(out.layout, out_layout) + self.assertEqual(out.format, out_format) + + def test_with_layout_constraint(self): + if not jtu.test_device_matches(['tpu']): + self.skipTest('Only works for TPU') + mesh = jtu.create_mesh((2, 2), ('x', 'y')) + shape = (16, 128) + s = NamedSharding(mesh, P('x')) + np_inp = np.arange(math.prod(shape)).reshape(shape) + arr = jax.device_put(np_inp, s) + + # Create a custom layout instead of using `arr.layout` to test the API. + custom_dll = DLL(major_to_minor=arr.format.dll.major_to_minor[::-1]) + + def f(x): + y = x.T + # Constrain `y` to the original layout of `arr` because without it, + # the layout of `y` would be the transpose of `arr`. + y = with_layout_constraint(y, custom_dll) + return y * 2 + + f(arr) # doesn't crash + + f = jax.jit(f) + out = f(arr) + self.assertEqual(out.format.device_local_layout.major_to_minor, + custom_dll.major_to_minor) + self.assertArraysEqual(out, np_inp.T * 2) + + lowered_text = f.lower(arr).as_text() + self.assertIn('LayoutConstraint', lowered_text) if __name__ == '__main__': diff --git a/tests/linalg_sharding_test.py b/tests/linalg_sharding_test.py index d8e1e6a16871..5d7b3b8a637b 100644 --- a/tests/linalg_sharding_test.py +++ b/tests/linalg_sharding_test.py @@ -14,7 +14,7 @@ import functools -from absl.testing import absltest +from absl.testing import absltest, parameterized import numpy as np import jax @@ -31,30 +31,22 @@ complex_types = jtu.dtypes.complex +# These functions are only supported on CPU. 
CPU_ONLY_FUN_AND_SHAPES = [ - # These functions are supported on GPU, but partitioning support will - # require updates to GSPMD, since they are lowered directly to HLO ops - # instead of custom calls on GPU. - (lax.linalg.cholesky, ((6, 6),)), - (lax.linalg.triangular_solve, ((6, 6), (4, 6))), - - # The GPU kernel for this function still uses an opaque descriptor to - # encode the input shapes so it is not partitionable. - # TODO(danfm): Update the kernel and enable this test on GPU. - (lax.linalg.tridiagonal_solve, ((6,), (6,), (6,), (6, 4))), - - # These functions are only supported on CPU. (lax.linalg.hessenberg, ((6, 6),)), (lax.linalg.schur, ((6, 6),)), ] CPU_AND_GPU_FUN_AND_SHAPES = [ + (lax.linalg.cholesky, ((6, 6),)), (lax.linalg.eig, ((6, 6),)), (lax.linalg.eigh, ((6, 6),)), (lax.linalg.lu, ((10, 6),)), (lax.linalg.qr, ((6, 6),)), (lax.linalg.svd, ((10, 6),)), + (lax.linalg.triangular_solve, ((6, 6), (4, 6))), (lax.linalg.tridiagonal, ((6, 6),)), + (lax.linalg.tridiagonal_solve, ((6,), (6,), (6,), (6, 4))), ] ALL_FUN_AND_SHAPES = CPU_ONLY_FUN_AND_SHAPES + CPU_AND_GPU_FUN_AND_SHAPES @@ -68,9 +60,19 @@ def setUp(self): self.skipTest("Requires multiple devices") def get_fun_and_shapes(self, fun_and_shapes, grad=False): - if (jtu.test_device_matches(["gpu"]) - and fun_and_shapes not in CPU_AND_GPU_FUN_AND_SHAPES): - self.skipTest(f"{fun_and_shapes[0].__name__} not supported on GPU") + if jtu.test_device_matches(["gpu"]): + if fun_and_shapes not in CPU_AND_GPU_FUN_AND_SHAPES: + self.skipTest( + f"Partitioning {fun_and_shapes[0].__name__} not supported on GPU.") + if (fun_and_shapes[0] in (lax.linalg.cholesky, lax.linalg.triangular_solve) + and not config.use_shardy_partitioner.value): + self.skipTest( + f"Partitioning {fun_and_shapes[0].__name__} only supported on GPU " + "when shardy is enabled.") + if fun_and_shapes[0] == lax.linalg.tridiagonal_solve: + self.skipTest( + f"Partitioning {fun_and_shapes[0].__name__} on GPU requires a " + "more recent jaxlib version.") if not grad: return fun_and_shapes @@ -79,10 +81,10 @@ def get_fun_and_shapes(self, fun_and_shapes, grad=False): self.skipTest(f"{fun.__name__} does not support differentation") if jtu.test_device_matches(["gpu"]) and fun in ( lax.linalg.eig, lax.linalg.lu, lax.linalg.qr - ): + ) and not config.use_shardy_partitioner.value: self.skipTest( f"JVP of {fun.__name__} uses triangular solve on GPU, which doesn't " - "support batch partitioning yet") + "support batch partitioning unless shardy is enabled.") if fun == lax.linalg.eig: fun = functools.partial( @@ -107,9 +109,8 @@ def arg_maker(shape): return x return tuple(arg_maker(shape) for shape in shapes) - @jtu.sample_product( - fun_and_shapes=ALL_FUN_AND_SHAPES, - dtype=float_types + complex_types, + @parameterized.product( + fun_and_shapes=ALL_FUN_AND_SHAPES, dtype=float_types + complex_types ) @jtu.run_on_devices("gpu", "cpu") def test_batch_axis_sharding(self, fun_and_shapes, dtype): @@ -124,20 +125,17 @@ def test_batch_axis_sharding(self, fun_and_shapes, dtype): expected = fun(*args) actual = fun_jit(*args_sharded) self.assertAllClose(actual, expected) - # TODO(danfm): Re-enable this check after diganosing non-determinism.
- # self.assertNotIn("all-", fun_jit.lower(*args_sharded).compile().as_text()) + self.assertNotIn("all-", fun_jit.lower(*args_sharded).compile().as_text()) vmap_fun = jax.vmap(fun) vmap_fun_jit = jax.jit(vmap_fun) actual = vmap_fun_jit(*args_sharded) self.assertAllClose(actual, expected) - # TODO(danfm): Re-enable this check after diganosing non-determinism. - # self.assertNotIn( - # "all-", vmap_fun_jit.lower(*args_sharded).compile().as_text()) + self.assertNotIn( + "all-", vmap_fun_jit.lower(*args_sharded).compile().as_text()) - @jtu.sample_product( - fun_and_shapes=ALL_FUN_AND_SHAPES, - dtype=float_types + complex_types, + @parameterized.product( + fun_and_shapes=ALL_FUN_AND_SHAPES, dtype=float_types + complex_types ) @jtu.run_on_devices("gpu", "cpu") def test_non_batch_axis_sharding(self, fun_and_shapes, dtype): @@ -155,9 +153,8 @@ def test_non_batch_axis_sharding(self, fun_and_shapes, dtype): self.assertIn( "all-gather", fun_jit.lower(*args_sharded).compile().as_text()) - @jtu.sample_product( - fun_and_shapes=ALL_FUN_AND_SHAPES, - dtype=float_types + complex_types, + @parameterized.product( + fun_and_shapes=ALL_FUN_AND_SHAPES, dtype=float_types + complex_types ) @jtu.run_on_devices("gpu", "cpu") def test_batch_axis_sharding_jvp(self, fun_and_shapes, dtype): @@ -181,14 +178,14 @@ def jvp_fun(primals, tangents): (primals_sharded, tangents), ]: _, actual = jvp_fun_jit(*args) - self.assertAllClose(actual, expected) - # TODO(danfm): Re-enable this check after diganosing non-determinism. - # hlo = jvp_fun_jit.lower(primals_sharded, tangents_sharded).compile() - # self.assertNotIn("all-", hlo.as_text()) - - @jtu.sample_product( - fun_and_shapes=ALL_FUN_AND_SHAPES, - dtype=float_types + complex_types, + self.assertAllClose(actual, expected, rtol={ + np.float32: 1e-4, np.float64: 1e-11, np.complex64: 1e-4, + np.complex128: 1e-11}) + hlo = jvp_fun_jit.lower(primals_sharded, tangents_sharded).compile() + self.assertNotIn("all-", hlo.as_text()) + + @parameterized.product( + fun_and_shapes=ALL_FUN_AND_SHAPES, dtype=float_types + complex_types ) @jtu.run_on_devices("gpu", "cpu") def test_batch_axis_sharding_vjp(self, fun_and_shapes, dtype): @@ -204,10 +201,11 @@ def test_batch_axis_sharding_vjp(self, fun_and_shapes, dtype): vjp_fun_jit = jax.jit(vjp_fun) expected = vjp_fun(tangents) actual = vjp_fun_jit(tangents_sharded) - self.assertAllClose(actual, expected) - # TODO(danfm): Re-enable this check after diganosing non-determinism. - # hlo = vjp_fun_jit.lower(tangents_sharded).compile() - # self.assertNotIn("all-", hlo.as_text()) + self.assertAllClose(actual, expected, rtol={ + np.float32: 1e-4, np.float64: 1e-11, np.complex64: 1e-4, + np.complex128: 1e-11}) + hlo = vjp_fun_jit.lower(tangents_sharded).compile() + self.assertNotIn("all-", hlo.as_text()) if __name__ == "__main__": diff --git a/tests/linalg_test.py b/tests/linalg_test.py index feab105ccbe2..c75927b26fd8 100644 --- a/tests/linalg_test.py +++ b/tests/linalg_test.py @@ -14,7 +14,7 @@ from functools import partial import itertools -from typing import Iterator +from collections.abc import Iterator from unittest import skipIf import numpy as np @@ -69,7 +69,9 @@ def _axis_for_ndim(ndim: int) -> Iterator[None | int | tuple[int, ...]]: def osp_linalg_toeplitz(c: np.ndarray, r: np.ndarray | None = None) -> np.ndarray: """scipy.linalg.toeplitz with v1.17+ batching semantics.""" - if scipy_version >= (1, 17, 0): + # TODO(dfm,jakevdp): Remove dev check after upstream PR is merged: + # https://github.com/scipy/scipy/issues/21466. 
+ if scipy_version >= (1, 17, 0) and "dev0" not in scipy.version.version: return scipy.linalg.toeplitz(c, r) elif r is None: c = np.atleast_1d(c) @@ -96,7 +98,7 @@ def args_maker(): a = rng(factor_shape, dtype) return [np.matmul(a, jnp.conj(T(a)))] - jnp_fun = partial(jnp.linalg.cholesky, upper=upper) + jnp_fun = partial(jnp.linalg.cholesky, upper=upper, symmetrize_input=True) def np_fun(x, upper=upper): # Upper argument added in NumPy 2.0.0 @@ -867,9 +869,6 @@ def testSVD(self, b, m, n, dtype, full_matrices, compute_uv, hermitian, algorith self.skipTest("Hermitian SVD doesn't support the algorithm parameter.") if not jtu.test_device_matches(["cpu", "gpu"]): self.skipTest("SVD algorithm selection only supported on CPU and GPU.") - # TODO(danfm): Remove this check after 0.5.2 is released. - if jtu.test_device_matches(["cpu"]) and jtu.jaxlib_version() <= (0, 5, 1): - self.skipTest("SVD algorithm selection on CPU requires a newer jaxlib version.") if jtu.test_device_matches(["cpu"]) and algorithm == lax.linalg.SvdAlgorithm.JACOBI: self.skipTest("Jacobi SVD not supported on GPU.") @@ -1064,7 +1063,7 @@ def testQrInvalidDtypeCPU(self, shape=(5, 6), dtype=np.float16): else: err, msg = Exception, "Unsupported dtype" with self.assertRaisesRegex(err, msg): - jnp.linalg.qr(arr) + jax.block_until_ready(jnp.linalg.qr(arr)) @jtu.sample_product( shape=[(10, 4, 5), (5, 3, 3), (7, 6, 4)], @@ -1233,6 +1232,13 @@ def testMatrixPower(self, shape, dtype, n): self._CompileAndCheck(partial(jnp.linalg.matrix_power, n=n), args_maker, rtol=1e-3) + def testMatrixPowerBool(self): + # Regression test for https://github.com/jax-ml/jax/issues/28603 + mat = np.array([[True,True], [False,True]]) + np_result = np.linalg.matrix_power(mat, 2) + jnp_result = jnp.linalg.matrix_power(mat, 2) + self.assertArraysEqual(np_result, jnp_result) + @jtu.sample_product( shape=[(3, ), (1, 2), (8, 5), (4, 4), (5, 5), (50, 50), (3, 4, 5), (2, 3, 4, 5)], @@ -2213,7 +2219,10 @@ def build_tri(dl, d, du): build_tri = jax.vmap(build_tri) a = build_tri(dl, d, du) - self.assertAllClose(a @ x, b, atol=5e-5, rtol=1e-4) + with jax.default_matmul_precision("float32"): + self.assertAllClose(a @ x, b, atol={ + np.float32: 1e-3, np.float64: 1e-10, np.complex64: 1e-3, + np.complex128: 1e-10}) def test_tridiagonal_solve_endpoints(self): # tridagonal_solve shouldn't depend on the endpoints being explicitly zero. 
@@ -2332,6 +2341,22 @@ def testSymmetricProduct(self, shape, dtype, symmetrize_output): self.assertAllClose( new_product_with_batching, old_product, atol=atol) + @jtu.sample_product( + n=[0, 1, 5, 10, 20], + kind=["symmetric", "lower", "upper"], + ) + @jax.default_matmul_precision("float32") + def testPascal(self, n, kind): + args_maker = lambda: [] + osp_fun = partial(osp.linalg.pascal, n=n, kind=kind, exact=False) + jsp_fun = partial(jsp.linalg.pascal, n=n, kind=kind) + self._CheckAgainstNumpy(osp_fun, + jsp_fun, args_maker, + atol=1e-3, + rtol=1e-2 if jtu.test_device_matches(['tpu']) else 1e-3, + check_dtypes=False) + self._CompileAndCheck(jsp_fun, args_maker) + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/lobpcg_test.py b/tests/lobpcg_test.py index fc2b0df849d1..76d6006432f4 100644 --- a/tests/lobpcg_test.py +++ b/tests/lobpcg_test.py @@ -272,7 +272,7 @@ def checkLobpcgMonotonicity(self, matrix_name, n, k, m, tol, dtype): self._possibly_plot(A, eigs, X, m, matrix_name) def _possibly_plot(self, A, eigs, X, m, matrix_name): - if not os.getenv('LOBPCG_EMIT_DEBUG_PLOTS'): + if os.getenv('LOBPCG_EMIT_DEBUG_PLOTS', '0') != '1': return if isinstance(A, (np.ndarray, jax.Array)): diff --git a/tests/memories_test.py b/tests/memories_test.py index 0ca973c4d221..1b56236c91c9 100644 --- a/tests/memories_test.py +++ b/tests/memories_test.py @@ -25,17 +25,18 @@ from jax import lax from jax._src import test_util as jtu from jax._src import xla_bridge as xb -from jax._src.layout import DeviceLocalLayout as DLL, Layout +from jax._src.layout import DeviceLocalLayout as DLL, Format from jax._src import config from jax.ad_checkpoint import checkpoint_name, checkpoint as new_checkpoint import jax.numpy as jnp from jax.ad_checkpoint import Offloadable, remat, Recompute from jax._src.sharding import common_devices_indices_map -from jax._src.sharding_impls import (NamedSharding, PositionalSharding, - SingleDeviceSharding, GSPMDSharding, - TransferToMemoryKind, PartitionSpec as P) +from jax._src.sharding_impls import ( + NamedSharding, SingleDeviceSharding, GSPMDSharding, + TransferToMemoryKind, PartitionSpec as P) +from jax._src.xla_metadata import set_xla_metadata from jax.experimental.compute_on import compute_on -from jax.experimental.shard_map import shard_map +from jax._src.shard_map import shard_map import numpy as np config.parse_flags_with_absl() @@ -66,7 +67,6 @@ def setUp(self): @parameterized.named_parameters( ("named_sharding", "named_sharding"), - ("positional_sharding", "positional_sharding"), ("single_device_sharding", "single_device_sharding"), ("gspmd_sharding", "gspmd_sharding"), ) @@ -75,9 +75,6 @@ def test_canonicalize_memory_kind(self, name): mesh = jtu.create_mesh((1,), "x") ns = NamedSharding(mesh, P("x")) self.assertEqual(ns.memory_kind, self._default_memory_kind) - elif name == "positional_sharding": - ps = PositionalSharding(jax.devices()) - self.assertEqual(ps.memory_kind, self._default_memory_kind) elif name == "single_device_sharding": ss = SingleDeviceSharding(jax.devices()[0]) self.assertEqual(ss.memory_kind, self._default_memory_kind) @@ -88,7 +85,6 @@ def test_canonicalize_memory_kind(self, name): @parameterized.named_parameters( ("named_sharding", "named_sharding"), - ("positional_sharding", "positional_sharding"), ("single_device_sharding", "single_device_sharding"), ("gspmd_sharding", "gspmd_sharding"), ) @@ -99,11 +95,6 @@ def test_wrong_memory_kind(self, name): ): mesh = jtu.create_mesh((1,), ("x",)) NamedSharding(mesh, 
P("x"), memory_kind="hbm") - elif name == "positional_sharding": - with self.assertRaisesRegex( - ValueError, "Could not find memory addressable by device.*" - ): - PositionalSharding(jax.devices(), memory_kind="gpu_hbm") elif name == "single_device_sharding": with self.assertRaisesRegex( ValueError, @@ -120,7 +111,6 @@ def test_wrong_memory_kind(self, name): @parameterized.named_parameters( ("named_sharding", "named_sharding"), - ("positional_sharding", "positional_sharding"), ("single_device_sharding", "single_device_sharding"), ("gspmd_sharding", "gspmd_sharding"), ) @@ -131,8 +121,6 @@ def test_correct_tpu_memory_kind(self, name): if name == "named_sharding": mesh = jtu.create_mesh((1,), ("x",)) NamedSharding(mesh, P("x"), memory_kind=self._default_memory_kind) - elif name == "positional_sharding": - PositionalSharding(jax.devices(), memory_kind=self._default_memory_kind) elif name == "single_device_sharding": SingleDeviceSharding(jax.devices()[0], memory_kind="unpinned_host") else: @@ -141,7 +129,6 @@ def test_correct_tpu_memory_kind(self, name): @parameterized.named_parameters( ("named_sharding", "named_sharding"), - ("positional_sharding", "positional_sharding"), ("single_device_sharding", "single_device_sharding"), ("gspmd_sharding", "gspmd_sharding"), ) @@ -151,10 +138,6 @@ def test_sharding_eq(self, name): s1 = NamedSharding(mesh, P("x")) s2 = NamedSharding(mesh, P("x"), memory_kind=self._default_memory_kind) self.assertEqual(s1, s2) - elif name == "positional_sharding": - s1 = PositionalSharding(jax.devices()) - s2 = PositionalSharding(jax.devices(), memory_kind=self._default_memory_kind) - self.assertEqual(s1, s2) elif name == "single_device_sharding": s1 = SingleDeviceSharding(jax.devices()[0]) s2 = SingleDeviceSharding(jax.devices()[0], memory_kind=self._default_memory_kind) @@ -655,7 +638,7 @@ def f(): @jtu.run_on_devices('tpu') def test_ragged_copy_on_host(self): mesh = jtu.create_mesh((2,), ('x')) - sharding = jax.sharding.NamedSharding(mesh, P(('x'))) + sharding = jax.sharding.NamedSharding(mesh, P('x')) cpu_sharding = sharding.with_memory_kind('pinned_host') num_pages = 512 * 1024 @@ -665,7 +648,7 @@ def test_ragged_copy_on_host(self): def write(x): return x.at[16 * 1024:].set(0) - x = shard_map(write, mesh, P(('x'),), P(('x')))(x) + x = shard_map(write, mesh=mesh, in_specs=P(('x'),), out_specs=P('x'))(x) chunk_size = 8 def inner(state): @@ -686,8 +669,8 @@ def foo(x): _, _, cpu_x = jax.lax.while_loop(cond, inner, (0, x, output)) return cpu_x - fn = jax.jit(shard_map(foo, mesh, P(('x'),), P(('x')), - check_rep=False), + fn = jax.jit(shard_map(foo, mesh=mesh, in_specs=P(('x'),), + out_specs=P('x'), check_vma=False), out_shardings=cpu_sharding) y = fn(x) jax.block_until_ready(y) @@ -756,9 +739,6 @@ def init(): def test_compute_no_inputs_host_replicated(self): if xb.backend_xla_version() is not None and xb.backend_xla_version() < 3: self.skipTest("This test requires an xla_version >= 3.") - if config.use_shardy_partitioner.value: - self.skipTest("XLA failure due to b/370786664 and b/366411266. 
" - "Enable when fixed.") mesh = jtu.create_mesh((4,), ('data')) tpu_sharding = NamedSharding(mesh, P('data')) @@ -794,6 +774,36 @@ def f(x): lowered_text = f.lower(jnp.arange(8)).as_text() self.assertIn('_xla_compute_type', lowered_text) + @functools.partial(jax.jit, out_shardings=out_s) + def h(x): + y = g(x) + return y * 3 + + out2 = h(inp) + self.assertArraysEqual(out2, inp * 6) + self.assertEqual(out2.sharding.memory_kind, "pinned_host") + + def test_compute_on_2d(self): + out_s = SingleDeviceSharding(jax.devices()[0], memory_kind="pinned_host") + + @compute_on("device_host") + @jax.jit + def g(x): + return x * 2 + + @jax.jit + def f(x): + y = g(x) + return y * 3 + + inp = jnp.arange(9943.0) + inp = jnp.reshape(inp, (61, 163)) + out = f(inp) + self.assertArraysEqual(out, inp * 6) + + lowered_text = f.lower(inp).as_text() + self.assertIn("_xla_compute_type", lowered_text) + @functools.partial(jax.jit, out_shardings=out_s) def h(x): y = g(x) @@ -1474,8 +1484,8 @@ def test_mem_kind_donation_pinned_host(self): s = NamedSharding(mesh, P(), memory_kind='pinned_host') s_dev = s.with_memory_kind('device') - @compute_on('device_host') @functools.partial(jax.jit, out_shardings=(s, s_dev), donate_argnums=(0, 1)) + @compute_on('device_host') def f(inp1, inp2): return inp1 * 2, inp2 * 2 @@ -1564,8 +1574,8 @@ def test_fn(x_in, y_in): y = jnp.reshape(y, (16, 64)) custom_dll = DLL(major_to_minor=(0, 1), _tiling=((8, 128),)) custom_dll_linear = DLL(major_to_minor=(0, 1), _tiling=((1,),)) - x = jax.device_put(x, Layout(custom_dll, sharding)) - y = jax.device_put(y, Layout(custom_dll_linear, p_sharding)) + x = jax.device_put(x, Format(custom_dll, sharding)) + y = jax.device_put(y, Format(custom_dll_linear, p_sharding)) x1 = jnp.arange(0, 1024, dtype=jnp.float32) x1 = jnp.reshape(x1, (16, 64)) @@ -1575,8 +1585,8 @@ def test_fn(x_in, y_in): jit_fn = jax.jit( test_fn, out_shardings=( - Layout(custom_dll, sharding), - Layout(custom_dll_linear, p_sharding), + Format(custom_dll, sharding), + Format(custom_dll_linear, p_sharding), ), ) x_out, y_out = jit_fn(x, y) @@ -1603,8 +1613,8 @@ def test_fn(x_in, y_in): y = jnp.reshape(y, (32, 64)) custom_dll = DLL(major_to_minor=(0, 1), _tiling=((8, 128),)) custom_dll_linear = DLL(major_to_minor=(0, 1), _tiling=((1,),)) - x = jax.device_put(x, Layout(custom_dll, sharding)) - y = jax.device_put(y, Layout(custom_dll_linear, p_sharding)) + x = jax.device_put(x, Format(custom_dll, sharding)) + y = jax.device_put(y, Format(custom_dll_linear, p_sharding)) x1 = jnp.arange(0, 2048, dtype=jnp.float32) x1 = jnp.reshape(x1, (32, 64)) @@ -1614,8 +1624,8 @@ def test_fn(x_in, y_in): jit_fn = jax.jit( test_fn, out_shardings=( - Layout(custom_dll, sharding), - Layout(custom_dll_linear, p_sharding), + Format(custom_dll, sharding), + Format(custom_dll_linear, p_sharding), ), ) x_out, y_out = jit_fn(x, y) @@ -1638,6 +1648,20 @@ def f(x): # 2 for `f` and `2` for `mul` (compute type changes for `mul`) self.assertEqual(count(), 4) + def test_compute_on_aot(self): + operand = np.float32(0.) + + @jax.jit + @compute_on("device_host") + def f_host(x): + # Adds 1 on CPU and adds 2 on other platforms + return jax.lax.platform_dependent(x, + cpu=lambda x: x + 1., + default=lambda x: x + 2.) + + self.assertAllClose(1., f_host(operand)) + self.assertAllClose(1., f_host.lower(operand).compile()(operand)) + def test_offload_take_host(self): # TODO(apaszke): Remove after 12 weeks have passed. 
if not jtu.if_cloud_tpu_at_least(2024, 12, 19): @@ -1661,9 +1685,78 @@ def peer_forward(x, experts, indices, scores): class StreamAnnotationTest(jtu.JaxTestCase): + def test_stream_annotation_single_instruction(self): + # E2E test for fix https://github.com/openxla/xla/pull/24269 + if not jtu.test_device_matches(["gpu"]): + self.skipTest("Stream annotation is only supported on GPU.") + + mesh = jtu.create_mesh((2,), ('x',)) + s = NamedSharding(mesh, P('x')) + np_inp = np.ones((8,)) + arr1 = jax.device_put(np_inp, s) + arr2 = jax.device_put(np_inp, s) + + @compute_on('gpu_stream:1') + @jax.jit + def g(x, y): + return x + y + + @jax.jit + def f(x, y): + return g(x, y) + + compiled_f = jax.jit(f).lower(arr1, arr2).compile() + compiled_text = compiled_f.as_text() + self.assertIn('call-start', compiled_text) + self.assertIn('_xla_stream_annotation="1"', compiled_text) + self.assertIn('wrapped_add', compiled_text) + self.assertArraysEqual(compiled_f(arr1, arr2), arr1 * 2) + + def test_streamed_gemm_overlap(self): + if not jtu.test_device_matches(["gpu"]): + self.skipTest("Stream annotation is only supported on GPU.") + + mesh = jtu.create_mesh((2,), ('x',)) + s = NamedSharding(mesh, P('x')) + + @compute_on('gpu_stream:1') + @jax.jit + def g(x, y): + return x @ y + + @compute_on('gpu_stream:2') + @jax.jit + def h(x, y): + return x @ y + + @jax.jit + @functools.partial( + jax.shard_map, mesh=mesh, in_specs=(P('x'), P('x')), + out_specs=P('x')) + def f(x, y): + with set_xla_metadata(_scheduling_group_id="1"): + a = g(x, y) + b = h(y, x) + return a + b + + np_input = np.ones((1024, 512)) + + arr1 = jax.device_put(np_input, s) + arr2 = jax.device_put(np_input, s) + + compiled_f = jax.jit(f).lower(arr1, arr2).compile() + compiled_text = compiled_f.as_text() + self.assertIn('call-start', compiled_text) + self.assertIn('_xla_stream_annotation="1"', compiled_text) + self.assertIn('call-start.1', compiled_text) + self.assertIn('_xla_stream_annotation="2"', compiled_text) + self.assertIn('_scheduling_group_id="1"', compiled_text) + self.assertArraysEqual(compiled_f(arr1, arr2), arr1 * 1024) + def test_stream_annotation_inside_shmap(self): if not jtu.test_device_matches(["gpu"]): self.skipTest("Stream annotation is only supported on GPU.") + mesh = jtu.create_mesh((2,), ('x',)) s = NamedSharding(mesh, P('x')) np_inp = np.ones((8,)) @@ -1673,22 +1766,27 @@ def test_stream_annotation_inside_shmap(self): @compute_on('gpu_stream:1') @jax.jit def g(x, y): - return x * y + return x * y + x @compute_on('gpu_stream:2') @jax.jit def h(x, y): - return x * y + return x * y + x def f(x, y): z = g(x, y) w = h(3 * x, 2 * y) return z + w - out = jax.jit(shard_map(f, mesh=mesh, in_specs=(P('x'), P('x')), - out_specs=P('x')))(arr1, arr2) - self.assertArraysEqual(out, arr1 * 7) - + compiled_f = jax.jit( + shard_map(f, mesh=mesh, in_specs=(P('x'), P('x')), + out_specs=P('x'))).lower(arr1, arr2).compile() + compiled_text = compiled_f.as_text() + self.assertIn('call-start', compiled_text) + self.assertIn('_xla_stream_annotation="1"', compiled_text) + self.assertIn('call-start.1', compiled_f.as_text()) + self.assertIn('_xla_stream_annotation="2"', compiled_text) + self.assertArraysEqual(compiled_f(arr1, arr2), arr1 * 11) class ActivationOffloadingTest(jtu.JaxTestCase): diff --git a/tests/mesh_utils_test.py b/tests/mesh_utils_test.py index 136b507942e7..28efb266b281 100644 --- a/tests/mesh_utils_test.py +++ b/tests/mesh_utils_test.py @@ -1,4 +1,4 @@ -# Copyright 2021 The TensorFlow Authors. All Rights Reserved. 
+# Copyright 2021 The JAX Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/mock_gpu_topology_test.py b/tests/mock_gpu_topology_test.py index 59c511ae61cf..8e409d6ed331 100644 --- a/tests/mock_gpu_topology_test.py +++ b/tests/mock_gpu_topology_test.py @@ -14,6 +14,7 @@ from absl.testing import absltest import jax +from jax._src import config from jax._src import test_util as jtu import jax.numpy as jnp from jax.sharding import NamedSharding @@ -49,13 +50,18 @@ def testMockWithSharding(self): f_lowered = f.lower(jnp.arange(16)) hlo = f_lowered.compiler_ir() + hlo_str = str(hlo) mocked_count = NUM_SLICES * NUM_HOSTS_PER_SLICE - self.assertIn(f'num_partitions = {mocked_count}', str(hlo)) - self.assertIn( - f'sharding = "{{devices=[{mocked_count}]<=[{mocked_count}]}}"', - str(hlo) - ) + self.assertIn(f'num_partitions = {mocked_count}', hlo_str) + + if config.use_shardy_partitioner.value: + expected_sharding = 'sharding = #sdy.sharding<@mesh, [{"x"}]>' + else: + expected_sharding = ( + f'sharding = "{{devices=[{mocked_count}]<=[{mocked_count}]}}"' + ) + self.assertIn(expected_sharding, hlo_str) if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/monitoring_test.py b/tests/monitoring_test.py index 52b53895c2cc..a50ddf6f4cc6 100644 --- a/tests/monitoring_test.py +++ b/tests/monitoring_test.py @@ -29,7 +29,7 @@ def tearDown(self): def test_record_event(self): events = [] - counters = {} # Map event names to frequency. + counters = {} # Map event names to frequency. def increment_event_counter(event): if event not in counters: counters[event] = 0 @@ -48,8 +48,9 @@ def increment_event_counter(event): "test_common_event": 2}) def test_record_event_durations(self): - durations = {} # Map event names to frequency. - def increment_event_duration(event, duration): + durations = {} # Map event names to frequency. + def increment_event_duration(event, duration, **kwargs): + del kwargs if event not in durations: durations[event] = 0. 
durations[event] += duration @@ -62,9 +63,33 @@ def increment_event_duration(event, duration): self.assertDictEqual(durations, {"test_short_event": 3, "test_long_event": 10}) + def test_record_scalar(self): + observed_keys = [] + observed_values = [] + + monitoring.register_scalar_listener( + lambda key, _, **kwargs: observed_keys.append(key), + ) + monitoring.register_scalar_listener( + lambda _, value, **kwargs: observed_values.append(value), + ) + + monitoring.record_scalar("test_unique_event", 1) + monitoring.record_scalar("test_common_event", 2.5) + monitoring.record_scalar("test_common_event", 5e5) + + self.assertListEqual( + observed_keys, + ["test_unique_event", "test_common_event", "test_common_event"], + ) + self.assertListEqual( + observed_values, + [1, 2.5, 5e5], + ) + def test_unregister_exist_callback_success(self): original_duration_listeners = jax_src_monitoring.get_event_duration_listeners() - callback = lambda event, durations: None + callback = lambda event, durations, **kwargs: None self.assertNotIn(callback, original_duration_listeners) monitoring.register_event_duration_secs_listener(callback) self.assertIn(callback, jax_src_monitoring.get_event_duration_listeners()) @@ -78,7 +103,7 @@ def test_unregister_exist_callback_success(self): jax_src_monitoring.get_event_duration_listeners()) def test_unregister_not_exist_callback_fail(self): - callback = lambda event, durations: None + callback = lambda event, durations, **kwargs: None self.assertNotIn(callback, jax_src_monitoring.get_event_duration_listeners()) @@ -88,7 +113,7 @@ def test_unregister_not_exist_callback_fail(self): def test_unregister_callback_index_in_range_success(self): original_duration_listeners = jax_src_monitoring.get_event_duration_listeners() - callback = lambda event, durations: None + callback = lambda event, durations, **kwargs: None self.assertNotIn(callback, original_duration_listeners) monitoring.register_event_duration_secs_listener(callback) self.assertIn(callback, jax_src_monitoring.get_event_duration_listeners()) @@ -114,7 +139,7 @@ def test_unregister_callback_index_out_of_range_fail(self): def test_get_event_duration_listeners_returns_a_copy(self): original_duration_listeners = jax_src_monitoring.get_event_duration_listeners() - callback = lambda event, durations: None + callback = lambda event, durations, **kwargs: None original_duration_listeners.append(callback) diff --git a/tests/mosaic/BUILD b/tests/mosaic/BUILD index 71b2b7d80570..ffaa0c3c843f 100644 --- a/tests/mosaic/BUILD +++ b/tests/mosaic/BUILD @@ -33,19 +33,55 @@ jax_multiplatform_test( name = "gpu_test", srcs = ["gpu_test.py"], enable_backends = [], - enable_configs = [ - "gpu_h100", - "gpu_h100x2", - ], + enable_configs = ["gpu_h100"], env = {"XLA_FLAGS": "--xla_gpu_autotune_level=0"}, shard_count = 16, tags = [ - "multiaccelerator", "noasan", # Times out. 
], deps = [ "//jax:mosaic_gpu", - ] + py_deps("absl/testing") + py_deps("numpy"), + ] + py_deps([ + "absl/testing", + "numpy", + "hypothesis", + ]), +) + +jax_multiplatform_test( + name = "gpu_test_multidevice", + srcs = ["gpu_test_multidevice.py"], + enable_backends = [], + enable_configs = ["gpu_h100x2"], + env = {"XLA_FLAGS": "--xla_gpu_autotune_level=0"}, + tags = ["multiaccelerator"], + deps = [ + "//jax:mosaic_gpu", + ] + py_deps([ + "absl/testing", + "numpy", + ]), +) + +jax_multiplatform_test( + name = "gpu_test_distributed", + srcs = ["gpu_test_distributed.py"], + args = [ + "--num_processes=2", + "--gpus_per_process=1", + ], + enable_backends = [], + enable_configs = ["gpu_h100x2"], + env = {"XLA_FLAGS": "--xla_gpu_autotune_level=0 --xla_gpu_experimental_enable_nvshmem=true"}, + tags = ["multiaccelerator"], + deps = [ + "//jax:experimental", + "//jax:mosaic_gpu", + "//jax:test_multiprocess", + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_py_test( @@ -75,7 +111,10 @@ jax_py_test( "//jax", "//jax:mosaic_gpu", "//jax:test_util", - ] + py_deps("absl/testing"), + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -87,7 +126,12 @@ jax_multiplatform_test( deps = [ "//jax:mosaic_gpu", "//jax/experimental/mosaic/gpu/examples:matmul", - ] + py_deps("absl/testing") + py_deps("numpy") + py_deps("hypothesis"), + "//jax/experimental/mosaic/gpu/examples:matmul_blackwell", + ] + py_deps([ + "absl/testing", + "numpy", + "hypothesis", + ]), ) jax_multiplatform_test( @@ -96,10 +140,16 @@ jax_multiplatform_test( enable_backends = [], enable_configs = ["gpu_h100"], main = "//jax/experimental/mosaic/gpu/examples:flash_attention.py", - tags = ["notap"], + tags = [ + "manual", + "notap", + ], deps = [ "//jax:mosaic_gpu", - ] + py_deps("numpy"), + ] + py_deps([ + "numpy", + "absl/testing", + ]), ) jax_multiplatform_test( diff --git a/tests/mosaic/gpu_dialect_test.py b/tests/mosaic/gpu_dialect_test.py index ba9d23fa5b4f..6a49181e10c8 100644 --- a/tests/mosaic/gpu_dialect_test.py +++ b/tests/mosaic/gpu_dialect_test.py @@ -14,7 +14,7 @@ # ============================================================================== """(Deviceless) tests for the Mosaic GPU MLIR dialect.""" -from typing import Callable +from collections.abc import Callable from absl.testing import parameterized import jax @@ -34,6 +34,7 @@ from jax.experimental.mosaic import gpu as mgpu from jax.experimental.mosaic.gpu import layouts from jax.experimental.mosaic.gpu import utils as mgpu_utils +from jax.experimental.mosaic.gpu import dialect_lowering as lowering _cext = mgpu.dialect._cext if mgpu.dialect is not None else None @@ -592,6 +593,128 @@ def test_wgmma_b_n_dim_not_equal_to_acc_n_dim(self): ): self.module.operation.verify() + def test_tiled_layout_attr_parsing(self): + with ir.InsertionPoint(self.module.body): + for layout in ( + mgpu.WGMMA_LAYOUT, + mgpu.WGMMA_ROW_LAYOUT, + mgpu.WGMMA_COL_LAYOUT, + mgpu.WGMMA_TRANSPOSED_LAYOUT, + ): + attr = layouts.to_tiled_layout_attr(layout) + parsed_layout = layouts.from_tiled_layout_attr(attr) + self.assertEqual(layout, parsed_layout) + + def test_broadcast_in_dim_ok(self): + with ir.InsertionPoint(self.module.body): + func.FuncOp.from_py_func( + ir.VectorType.get([64], ir.F32Type.get()), + name="broadcast_in_dim", + )( + lambda operand: mgpu.dialect.broadcast_in_dim( + ir.VectorType.get([64, 64], ir.F32Type.get()), + operand, + broadcast_dimensions=[0], + ) + ) + + self.assertTrue(self.module.operation.verify()) + + def test_broadcast_in_dim_no_0d(self): + 
with ir.InsertionPoint(self.module.body): + func.FuncOp.from_py_func( + ir.VectorType.get([], ir.F32Type.get()), + name="broadcast_in_dim", + )( + lambda operand: mgpu.dialect.broadcast_in_dim( + ir.VectorType.get([64], ir.F32Type.get()), + operand, + broadcast_dimensions=[], + ) + ) + + with self.assertRaisesRegex( + ir.MLIRError, + r"The input vector must have rank > 0", + ): + self.module.operation.verify() + + def test_broadcast_in_dim_no_input_larger_than_output(self): + with ir.InsertionPoint(self.module.body): + func.FuncOp.from_py_func( + ir.VectorType.get([64, 64], ir.F32Type.get()), + name="broadcast_in_dim", + )( + lambda operand: mgpu.dialect.broadcast_in_dim( + ir.VectorType.get([64], ir.F32Type.get()), + operand, + broadcast_dimensions=[], + ) + ) + + with self.assertRaisesRegex( + ir.MLIRError, + r"rank of the input vector must be smaller", + ): + self.module.operation.verify() + + def test_broadcast_in_dim_too_many_dims(self): + with ir.InsertionPoint(self.module.body): + func.FuncOp.from_py_func( + ir.VectorType.get([64], ir.F32Type.get()), + name="broadcast_in_dim", + )( + lambda operand: mgpu.dialect.broadcast_in_dim( + ir.VectorType.get([64, 64], ir.F32Type.get()), + operand, + broadcast_dimensions=[0, 1], + ) + ) + + with self.assertRaisesRegex( + ir.MLIRError, + r"size of the `broadcast_dimensions` attribute must be", + ): + self.module.operation.verify() + + def test_broadcast_in_dim_dim_oob(self): + with ir.InsertionPoint(self.module.body): + func.FuncOp.from_py_func( + ir.VectorType.get([64], ir.F32Type.get()), + name="broadcast_in_dim", + )( + lambda operand: mgpu.dialect.broadcast_in_dim( + ir.VectorType.get([64, 64], ir.F32Type.get()), + operand, + broadcast_dimensions=[2], + ) + ) + + with self.assertRaisesRegex( + ir.MLIRError, + r"must be in the range \[0, result.shape.rank", + ): + self.module.operation.verify() + + def test_broadcast_in_dim_dim_transpose(self): + with ir.InsertionPoint(self.module.body): + func.FuncOp.from_py_func( + ir.VectorType.get([64, 64, 64, 64], ir.F32Type.get()), + name="broadcast_in_dim", + )( + lambda operand: mgpu.dialect.broadcast_in_dim( + ir.VectorType.get([64, 64, 64, 64], ir.F32Type.get()), + operand, + broadcast_dimensions=[0, 1, 3, 2], + ) + ) + + with self.assertRaisesRegex( + ir.MLIRError, + r"`broadcast_dimensions` attribute must be strictly increasing", + ): + self.module.operation.verify() + class DialectLoweringTest(MosaicGpuTest): @@ -653,11 +776,14 @@ def test_initialize_barrier_op_lowering_rule(self): # One nvvm.mbarrier_init_shared is issued per barrier. self.assertLen(all_mbarrier_init_shared_ops, num_shape_elements) - # Each barrier has its count equal to the arrival count. + # Each barrier has its count equal to the arrival count times the + # warpgroup size. 
for op in all_mbarrier_init_shared_ops: count = op.count.owner.opview self.assertIsInstance(count, arith.ConstantOp) - self.assertEqual(count.literal_value, arrival_count) + self.assertEqual( + count.literal_value, arrival_count * mgpu_utils.WARPGROUP_SIZE + ) def test_lowering_vector_op_without_layout_fails(self): shape = (3, 4) @@ -810,8 +936,11 @@ def test_lowering_slice_smem_op(self): def body(): nonlocal offset i32 = ir.IntegerType.get_signless(32) + smem = ir.Attribute.parse("#gpu.address_space") + memref_ty = ir.MemRefType.get((4, 32), i32, memory_space=smem) offset = arith.constant(i32, shift) - mgpu.dialect.slice_smem(i32, offset) + op = mgpu.dialect.SliceSMEMOp(memref_ty, offset) + op.attributes["out_transforms"] = ir.ArrayAttr.get([ir.ArrayAttr.get([])]) with ir.InsertionPoint(self.module.body): func.FuncOp.from_py_func()(body) @@ -862,6 +991,69 @@ def test_lower_conversion_op_lowers_to_same_op(self, op, in_dtype, out_dtype): self.assertLen(conversion_ops, 1) self.assertEqual(conversion_ops[0].result.type, scalar_out_ty) + @parameterized.parameters( + (True, False, False), + (False, True, False), + (False, False, True), + ) + def test_custom_primitive_op_must_have_number_of_annotations_matching_operands_and_results( + self, omit_in_layouts, omit_in_transforms, omit_out_layouts + ): + vec_ty = ir.VectorType.get((4, 32), ir.BF16Type.get()) + out_layouts = [ + layouts.to_layout_attr( + mgpu.WGStridedFragLayout.from_shaped_type(vec_ty) + ) + ] + in_layouts = out_layouts * 2 + in_transforms = [ + ir.ArrayAttr.get([mgpu.dialect.SwizzleTransformAttr.get(128)]) + ] + + in_layouts = [] if omit_in_layouts else in_layouts + in_transforms = [] if omit_in_transforms else in_transforms + out_layouts = [] if omit_out_layouts else out_layouts + + def body(vec1, vec2, ref): + mgpu.dialect.custom_primitive( + [vec_ty], [vec1, vec2, ref], in_layouts, in_transforms, out_layouts + ) + + with ir.InsertionPoint(self.module.body): + smem = ir.Attribute.parse("#gpu.address_space") + ref_ty = ir.MemRefType.get((4, 32), ir.BF16Type.get(), memory_space=smem) + func.FuncOp.from_py_func(vec_ty, vec_ty, ref_ty)(body) + + if omit_in_layouts: + error = "layout for each vector operand" + elif omit_in_transforms: + error = "transforms for each memref operand in smem" + else: + assert omit_out_layouts + error = "layout for each result" + + with self.assertRaisesRegex(ir.MLIRError, error): + self.module.operation.verify() + + def test_memref_transforms_with_transpose(self): + with ir.InsertionPoint(self.module.body): + ty_in = ir.MemRefType.get( + (64, 128), + ir.BF16Type.get(), + memory_space=ir.Attribute.parse("#gpu.address_space"), + ) + ref = memref.alloc(ty_in, [], []) + + ref = mgpu_utils.memref_transpose(ref, (1, 0)) + # This tiling is applied to the transposed memref. 
+ transforms = [mgpu.TileTransform(tiling=(16, 32))] + + ref_transformed = lowering.reinterpret_smem_ref(ref, transforms) + ty_transformed = ir.MemRefType(ref_transformed.type) + self.assertEqual(ty_transformed.shape, [8, 2, 16, 32]) + strides, _ = ty_transformed.get_strides_and_offset() + self.assertEqual(strides, [512, 4096, 1, 16]) + if __name__ == "__main__": parameterized.absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/mosaic/gpu_layout_inference_test.py b/tests/mosaic/gpu_layout_inference_test.py index 36c8ff9cf47e..53c0aae14dc2 100644 --- a/tests/mosaic/gpu_layout_inference_test.py +++ b/tests/mosaic/gpu_layout_inference_test.py @@ -40,6 +40,13 @@ def _make_ir_context(): return context +def layout_cast(x: ir.Value, layout: mgpu.FragmentedLayout | ir.Attribute) -> ir.Value: + """Convenience wrapper around `mgpu.dialect.layout_cast`.""" + if isinstance(layout, mgpu.FragmentedLayout): + layout = layouts.to_layout_attr(layout) + return mgpu.dialect.layout_cast(x, layout) + + class LayoutInferenceTest(parameterized.TestCase): def setUp(self): @@ -50,6 +57,12 @@ def setUp(self): self.enter_context(ir.Location.unknown()) self.module = ir.Module.create() + def checkInLayouts(self, op, in_layouts): + self.assertSequenceEqual(op.attributes["in_layouts"], in_layouts) + + def checkOutLayouts(self, op, out_layouts): + self.assertSequenceEqual(op.attributes["out_layouts"], out_layouts) + def test_infer_strided_layout_default(self): shape = (16, 8) elt_type = ir.BF16Type.get() @@ -71,8 +84,8 @@ def body(a, b): mgpu.WGStridedFragLayout.from_shaped_type(ty) ) - self.assertSequenceEqual(add.attributes["in_layouts"], [layout, layout]) - self.assertSequenceEqual(add.attributes["out_layouts"], [layout]) + self.checkInLayouts(add, [layout, layout]) + self.checkOutLayouts(add, [layout]) def test_infer_strided_layout_from_shape_cast(self): shape = (16, 8) @@ -97,13 +110,13 @@ def body(x): mgpu.WGStridedFragLayout.from_shaped_type(dst_type) ) - self.assertSequenceEqual(op.attributes["in_layouts"], [in_layout]) - self.assertSequenceEqual(op.attributes["out_layouts"], [out_layout]) + self.checkInLayouts(op, [in_layout]) + self.checkOutLayouts(op, [out_layout]) # Ensure that we can recover the original layout. 
del op.attributes["in_layouts"] mgpu.infer_layout(self.module) - self.assertSequenceEqual(op.attributes["in_layouts"], [in_layout]) + self.checkInLayouts(op, [in_layout]) def test_infer_splat_layout_for_splat_constants(self): shape = (16, 8) @@ -124,17 +137,20 @@ def test_infer_splat_layout_for_splat_constants(self): layout = layouts.to_layout_attr(mgpu.WGSplatFragLayout(shape=shape)) self.assertEmpty(splat0.attributes["in_layouts"]) - self.assertSequenceEqual(splat0.attributes["out_layouts"], [layout]) + self.checkOutLayouts(splat0, [layout]) self.assertEmpty(splat1.attributes["in_layouts"]) - self.assertSequenceEqual(splat1.attributes["out_layouts"], [layout]) + self.checkOutLayouts(splat1, [layout]) - self.assertSequenceEqual(add.attributes["in_layouts"], [layout, layout]) - self.assertSequenceEqual(add.attributes["out_layouts"], [layout]) + self.checkInLayouts(add, [layout, layout]) + self.checkOutLayouts(add, [layout]) def test_infer_layout_from_consumer_for_non_splat_constant(self): shape = (16, 8) elt_type = ir.BF16Type.get() + layout = layouts.to_layout_attr( + mgpu.WGStridedFragLayout(shape=shape, vec_size=1) + ) with ir.InsertionPoint(self.module.body): ty = ir.VectorType.get(shape, elt_type) @@ -142,49 +158,41 @@ def test_infer_layout_from_consumer_for_non_splat_constant(self): ir.FloatAttr.get(elt_type, i) for i in range(shape[0] * shape[1]) ] c = arith.ConstantOp(ty, ir.DenseElementsAttr.get(attr_list, ty)) - add = arith.AddFOp(c, c) - - layout = layouts.to_layout_attr( - mgpu.WGStridedFragLayout(shape=shape, vec_size=1) - ) - add.attributes["in_layouts"] = ir.ArrayAttr.get([layout, layout]) + layout_cast(c, layout) mgpu.infer_layout(self.module) self.assertEmpty(c.attributes["in_layouts"]) - self.assertSequenceEqual(c.attributes["out_layouts"], [layout]) + self.checkOutLayouts(c, [layout]) @parameterized.parameters(True, False) def test_infer_splat_layout_for_vector_splat(self, rhs_splat): add = splat = None + shape = (16, 8) + layout = layouts.to_layout_attr(mgpu.WGSplatFragLayout(shape=shape)) def body(lhs, rhs): nonlocal add, splat + rhs = layout_cast(rhs, layout) if rhs_splat else rhs splat = vector.SplatOp(rhs.type, lhs) add = arith.AddFOp(splat.result, rhs) with ir.InsertionPoint(self.module.body): - shape = (16, 8) elt_type = ir.BF16Type.get() ty = ir.VectorType.get(shape, elt_type) - func_op = func.FuncOp.from_py_func(elt_type, ty)(body).func_op + func.FuncOp.from_py_func(elt_type, ty)(body) - layout = layouts.to_layout_attr(mgpu.WGSplatFragLayout(shape=shape)) - if rhs_splat: - func_op.attributes["in_layouts"] = ir.ArrayAttr.get([layout]) mgpu.infer_layout(self.module) self.assertEmpty(splat.attributes["in_layouts"]) - self.assertSequenceEqual(splat.attributes["out_layouts"], [layout]) + self.checkOutLayouts(splat, [layout]) - add_layout = layout - if not rhs_splat: - add_layout = layouts.to_layout_attr( - mgpu.WGStridedFragLayout.from_shaped_type(ty) - ) + add_layout = layout if rhs_splat else layouts.to_layout_attr( + mgpu.WGStridedFragLayout.from_shaped_type(ty) + ) - self.assertSequenceEqual(add.attributes["in_layouts"], [add_layout, add_layout]) - self.assertSequenceEqual(add.attributes["out_layouts"], [add_layout]) + self.checkInLayouts(add, [add_layout, add_layout]) + self.checkOutLayouts(add, [add_layout]) @parameterized.parameters( mgpu.WGSplatFragLayout(shape=(32, 4)), @@ -195,22 +203,120 @@ def test_pointwise_op_propagates_argument_layouts(self, layout): def body(lhs, rhs): nonlocal add + lhs = layout_cast(lhs, layout) + rhs = layout_cast(rhs, layout) add 
= arith.AddFOp(lhs, rhs) with ir.InsertionPoint(self.module.body): ty = ir.VectorType.get(layout.shape, ir.BF16Type.get()) func.FuncOp.from_py_func(ty, ty)(body) - [f] = self.module.body.operations + mgpu.infer_layout(self.module) + layout_attr = layouts.to_layout_attr(layout) - f.attributes["in_layouts"] = ir.ArrayAttr.get([layout_attr, layout_attr]) + self.checkInLayouts(add, [layout_attr, layout_attr]) + self.checkOutLayouts(add, [layout_attr]) + + def test_infer_layout_cast_layout(self): + add = cast = None + + shape = (128, 64) + splat_layout = layouts.to_layout_attr(mgpu.WGSplatFragLayout(shape=shape)) + wgmma_layout = layouts.to_layout_attr(mgpu.WGMMA_LAYOUT) + + def body(x): + nonlocal add, cast + x = mgpu.dialect.layout_cast(x, splat_layout) + add = arith.AddFOp(x, x) + cast = mgpu.dialect.LayoutCastOp(add.result, wgmma_layout) + + with ir.InsertionPoint(self.module.body): + elt_type = ir.BF16Type.get() + ty = ir.VectorType.get(shape, elt_type) + func.FuncOp.from_py_func(ty)(body) mgpu.infer_layout(self.module) + self.checkOutLayouts(add, [splat_layout]) + self.checkInLayouts(cast, [wgmma_layout]) + self.checkOutLayouts(cast, [wgmma_layout]) - self.assertSequenceEqual( - add.attributes["in_layouts"], [layout_attr, layout_attr] - ) - self.assertSequenceEqual(add.attributes["out_layouts"], [layout_attr]) + @parameterized.parameters( + (0, mgpu.WGMMA_ROW_LAYOUT, None, mgpu.WGMMA_ROW_LAYOUT, mgpu.WGMMA_LAYOUT), + (1, mgpu.WGMMA_COL_LAYOUT, None, mgpu.WGMMA_COL_LAYOUT, mgpu.WGMMA_LAYOUT), + (0, None, mgpu.WGMMA_LAYOUT, mgpu.WGMMA_ROW_LAYOUT, mgpu.WGMMA_LAYOUT), + (1, None, mgpu.WGMMA_LAYOUT, mgpu.WGMMA_COL_LAYOUT, mgpu.WGMMA_LAYOUT), + ) + def test_infer_broadcast_in_dim_layout( + self, broadcast_dim, in_cast, out_cast, in_layout, out_layout + ): + bcast = None + in_shape = (64,) + out_shape = (64, 64) + + def body(x): + nonlocal bcast + if in_cast is not None: + x = mgpu.dialect.LayoutCastOp(x, layouts.to_layout_attr(in_cast)) + + out_type = ir.VectorType.get(out_shape, ir.F32Type.get()) + bcast = mgpu.dialect.BroadcastInDimOp(out_type, x, [broadcast_dim]) + + if out_cast is not None: + mgpu.dialect.LayoutCastOp( + bcast.result, layouts.to_layout_attr(out_cast) + ) + + with ir.InsertionPoint(self.module.body): + ty = ir.VectorType.get(in_shape, ir.F32Type.get()) + func.FuncOp.from_py_func(ty)(body) + + mgpu.infer_layout(self.module) + self.checkInLayouts(bcast, [layouts.to_layout_attr(in_layout)]) + self.checkOutLayouts(bcast, [layouts.to_layout_attr(out_layout)]) + + @parameterized.parameters( + (1, mgpu.WGMMA_LAYOUT, None, None, mgpu.WGMMA_LAYOUT, mgpu.WGMMA_ROW_LAYOUT), + (0, mgpu.WGMMA_LAYOUT, None, None, mgpu.WGMMA_LAYOUT, mgpu.WGMMA_COL_LAYOUT), + (1, None, None, mgpu.WGMMA_ROW_LAYOUT, mgpu.WGMMA_LAYOUT, mgpu.WGMMA_ROW_LAYOUT), + (0, None, None, mgpu.WGMMA_COL_LAYOUT, mgpu.WGMMA_LAYOUT, mgpu.WGMMA_COL_LAYOUT), + (1, None, mgpu.WGMMA_ROW_LAYOUT, None, mgpu.WGMMA_LAYOUT, mgpu.WGMMA_ROW_LAYOUT), + (0, None, mgpu.WGMMA_COL_LAYOUT, None, mgpu.WGMMA_LAYOUT, mgpu.WGMMA_COL_LAYOUT), + (1, None, mgpu.WGMMA_ROW_LAYOUT, mgpu.WGMMA_ROW_LAYOUT, mgpu.WGMMA_LAYOUT, mgpu.WGMMA_ROW_LAYOUT), + (0, None, mgpu.WGMMA_COL_LAYOUT, mgpu.WGMMA_COL_LAYOUT, mgpu.WGMMA_LAYOUT, mgpu.WGMMA_COL_LAYOUT), + ) + def test_infer_multi_reduce_layout( + self, reduce_dim, in_cast, acc_cast, out_cast, in_layout, out_layout + ): + red = None + + in_shape = (64, 64) + out_shape = (64,) + + def body(x, acc): + nonlocal red + if in_cast is not None: + x = mgpu.dialect.LayoutCastOp(x, layouts.to_layout_attr(in_cast)) + if 
acc_cast is not None: + acc = mgpu.dialect.LayoutCastOp(acc, layouts.to_layout_attr(acc_cast)) + + kind = vector.CombiningKind.MAXIMUMF + red = vector.MultiDimReductionOp(kind, x, acc, [reduce_dim]) + + if out_cast is not None: + mgpu.dialect.LayoutCastOp( + red.result, layouts.to_layout_attr(out_cast) + ) + + with ir.InsertionPoint(self.module.body): + in_ty = ir.VectorType.get(in_shape, ir.F32Type.get()) + acc_ty = ir.VectorType.get(out_shape, ir.F32Type.get()) + func.FuncOp.from_py_func(in_ty, acc_ty)(body) + + mgpu.infer_layout(self.module) + in_layout_attr = layouts.to_layout_attr(in_layout) + out_layout_attr = layouts.to_layout_attr(out_layout) + self.checkInLayouts(red, [in_layout_attr, out_layout_attr]) + self.checkOutLayouts(red, [out_layout_attr]) def test_infer_layout_traverses_ops_correctly(self): shape = (16, 8) @@ -247,26 +353,23 @@ def body(a, b): def test_infer_layout_from_yield_op_in_layouts_for_for_op( self, shape, layout ): - add_op = for_op = yield_op = None + for_op = yield_op = None def body(lower_bound, upper_bound, step, a, b): nonlocal for_op for_op = scf.ForOp(lower_bound, upper_bound, step, [a, b]) [loop_a, loop_b] = list(for_op.inner_iter_args) with ir.InsertionPoint(for_op.body): - nonlocal add_op, yield_op - add_op = arith.AddFOp(loop_a, loop_b) - yield_op = scf.YieldOp([add_op.result, add_op.result]) + nonlocal yield_op + add = arith.addf(loop_a, loop_b) + add = layout_cast(add, layout) + yield_op = scf.YieldOp([add, add]) with ir.InsertionPoint(self.module.body): ab_type = ir.VectorType.get(shape, ir.BF16Type.get()) i32 = ir.IntegerType.get_signless(32) func.FuncOp.from_py_func(i32, i32, i32, ab_type, ab_type)(body) - add_op.attributes["out_layouts"] = ir.ArrayAttr.get( - [layouts.to_layout_attr(layout)] - ) - mgpu.infer_layout(self.module) if isinstance(layout, mgpu.WGSplatFragLayout): @@ -279,14 +382,14 @@ def body(lower_bound, upper_bound, step, a, b): mgpu.WGStridedFragLayout.from_shaped_type(ab_type) ) carry_layouts = [strided_layout, strided_layout] - self.assertSequenceEqual(yield_op.attributes["out_layouts"], []) - self.assertSequenceEqual(for_op.attributes["in_layouts"], carry_layouts) - self.assertSequenceEqual(for_op.attributes["out_layouts"], carry_layouts) + self.checkOutLayouts(yield_op, []) + self.checkInLayouts(for_op, carry_layouts) + self.checkOutLayouts(for_op, carry_layouts) else: carry_layouts = [layouts.to_layout_attr(layout)] * 2 - self.assertSequenceEqual(yield_op.attributes["out_layouts"], []) - self.assertSequenceEqual(for_op.attributes["in_layouts"], carry_layouts) - self.assertSequenceEqual(for_op.attributes["out_layouts"], carry_layouts) + self.checkOutLayouts(yield_op, []) + self.checkInLayouts(for_op, carry_layouts) + self.checkOutLayouts(for_op, carry_layouts) def test_infer_layout_from_body_op_to_yield_op_to_for_op(self): for_op = yield_op = None @@ -310,10 +413,49 @@ def body(lower_bound, upper_bound, step, a, b, c): mgpu.infer_layout(self.module) wgmma_layout = layouts.to_layout_attr(mgpu.WGMMA_LAYOUT) - self.assertSequenceEqual(yield_op.attributes["in_layouts"], [wgmma_layout]) - self.assertSequenceEqual(yield_op.attributes["out_layouts"], []) - self.assertSequenceEqual(for_op.attributes["in_layouts"], [wgmma_layout]) - self.assertSequenceEqual(for_op.attributes["out_layouts"], [wgmma_layout]) + self.checkInLayouts(yield_op, [wgmma_layout]) + self.checkOutLayouts(yield_op, []) + self.checkInLayouts(for_op, [wgmma_layout]) + self.checkOutLayouts(for_op, [wgmma_layout]) + + @parameterized.parameters( + ((), None, (), None), 
+ ((64, 32), mgpu.WGMMA_LAYOUT, (), None), + ((), None, (64, 32), mgpu.WGMMA_LAYOUT), + ((64,), mgpu.WGMMA_ROW_LAYOUT, (64, 32), mgpu.WGMMA_LAYOUT), + ) + def test_infer_while_op_layouts( + self, init_shape, init_layout, result_shape, result_layout + ): + f32 = ir.F32Type.get() + in_type = ir.VectorType.get(init_shape, f32) if init_shape else f32 + out_type = ir.VectorType.get(result_shape, f32) if result_shape else f32 + while_op = condition_op = yield_op = None + + def body(condition, init, result): + nonlocal while_op, condition_op, yield_op + init = layout_cast(init, init_layout) if init_layout else init + result = layout_cast(result, result_layout) if result_layout else result + while_op = scf.WhileOp([out_type], [init]) + before_block = while_op.before.blocks.append(init.type) + with ir.InsertionPoint(before_block): + condition_op = scf.ConditionOp(condition, [result]) + + after_block = while_op.after.blocks.append(out_type) + with ir.InsertionPoint(after_block): + yield_op = scf.YieldOp([init]) + + with ir.InsertionPoint(self.module.body): + i1 = ir.IntegerType.get_signless(1) + func.FuncOp.from_py_func(i1, in_type, out_type)(body) + + mgpu.infer_layout(self.module) + + if init_layout is not None or result_layout is not None: + init_layouts = [layouts.to_layout_attr(init_layout)] if init_layout else [] + result_layouts = [layouts.to_layout_attr(result_layout)] if result_layout else [] + self.checkInLayouts(while_op, init_layouts) + self.checkOutLayouts(while_op, result_layouts) def test_infer_layout_has_no_layout_for_non_vector_types(self): shape = (32, 4) @@ -349,98 +491,108 @@ def test_infer_layout_picks_non_splat_layout_over_splat_layout( self, layout ): add = None + shape = (32, 4) + splat_layout = layouts.to_layout_attr(mgpu.WGSplatFragLayout(shape)) + non_splat_layout = layouts.to_layout_attr(layout) def body(lhs, rhs): nonlocal add + lhs = layout_cast(lhs, non_splat_layout) + rhs = layout_cast(rhs, splat_layout) add = arith.AddFOp(lhs, rhs) with ir.InsertionPoint(self.module.body): - shape = (32, 4) elt_type = ir.BF16Type.get() ty = ir.VectorType.get(shape, elt_type) - - f = func.FuncOp.from_py_func(ty, ty)(body).func_op - - splat_layout = layouts.to_layout_attr(mgpu.WGSplatFragLayout(shape)) - non_splat_layout = layouts.to_layout_attr(layout) - - f.attributes["in_layouts"] = ir.ArrayAttr.get( - [non_splat_layout, splat_layout] - ) + func.FuncOp.from_py_func(ty, ty)(body) mgpu.infer_layout(self.module) - self.assertSequenceEqual( - add.attributes["in_layouts"], - [non_splat_layout, non_splat_layout], - ) - self.assertSequenceEqual(add.attributes["out_layouts"], [non_splat_layout]) + self.checkInLayouts(add, [non_splat_layout, non_splat_layout]) + self.checkOutLayouts(add, [non_splat_layout]) def test_infer_layout_preserves_splat_layouts_in_producers(self): add0 = add1 = None + shape = (32, 4) + splat_layout = layouts.to_layout_attr(mgpu.WGSplatFragLayout(shape)) + strided_layout = layouts.to_layout_attr( + mgpu.WGStridedFragLayout(shape, vec_size=1) + ) def body(lhs, rhs): nonlocal add0, add1 + lhs = layout_cast(lhs, splat_layout) + rhs = layout_cast(rhs, splat_layout) add0 = arith.AddFOp(lhs, rhs) - add1 = arith.AddFOp(add0.result, add0) + cast = layout_cast(add0, strided_layout) + add1 = arith.AddFOp(cast, cast) with ir.InsertionPoint(self.module.body): - shape = (32, 4) elt_type = ir.BF16Type.get() ty = ir.VectorType.get(shape, elt_type) - f = func.FuncOp.from_py_func(ty, ty)(body).func_op + func.FuncOp.from_py_func(ty, ty)(body) - splat_layout = 
layouts.to_layout_attr(mgpu.WGSplatFragLayout(shape)) - strided_layout = layouts.to_layout_attr( - mgpu.WGStridedFragLayout(shape, vec_size=1) - ) - f.attributes["in_layouts"] = ir.ArrayAttr.get([splat_layout, splat_layout]) - add1.attributes["out_layouts"] = ir.ArrayAttr.get([strided_layout]) mgpu.infer_layout(self.module) - self.assertSequenceEqual( - add0.attributes["in_layouts"], [splat_layout, splat_layout] - ) - self.assertSequenceEqual( - add1.attributes["in_layouts"], [strided_layout, strided_layout] - ) + self.checkInLayouts(add0, [splat_layout, splat_layout]) + self.checkOutLayouts(add0, [splat_layout]) + self.checkInLayouts(add1, [strided_layout, strided_layout]) + self.checkOutLayouts(add1, [strided_layout]) - self.assertSequenceEqual(add0.attributes["out_layouts"], [splat_layout]) - self.assertSequenceEqual(add1.attributes["out_layouts"], [strided_layout]) - - def test_infer_layout_propagates_func_layouts_to_ops(self): - add = None + def test_infer_layout_does_not_assign_default_layouts_to_func(self): def body(lhs, rhs): - nonlocal add - add = arith.AddFOp(lhs, rhs) + arith.AddFOp(lhs, rhs) with ir.InsertionPoint(self.module.body): shape = (32, 4) ty = ir.VectorType.get(shape, ir.BF16Type.get()) f = func.FuncOp.from_py_func(ty, ty)(body).func_op - splat_layout = layouts.to_layout_attr(mgpu.WGSplatFragLayout(shape)) - f.attributes["in_layouts"] = ir.ArrayAttr.get([splat_layout, splat_layout]) mgpu.infer_layout(self.module) + self.assertNotIn("in_layouts", f.attributes) + self.assertNotIn("out_layouts", f.attributes) - self.assertSequenceEqual( - add.attributes["in_layouts"], [splat_layout, splat_layout]) - self.assertSequenceEqual(add.attributes["out_layouts"], [splat_layout]) + def test_optimization_barrier_op_propagates_user_layouts(self): + add = optimization_barrier = None + wgmma_layout = layouts.to_layout_attr(mgpu.WGMMA_LAYOUT) - def test_infer_layout_does_not_assign_default_layouts_to_func(self): + def body(lhs, rhs): + nonlocal add, optimization_barrier + optimization_barrier = mgpu.dialect.OptimizationBarrierOp([lhs, rhs]) + lhs, rhs = optimization_barrier.results + add = arith.AddFOp(lhs, rhs) + add = layout_cast(add, wgmma_layout) + + with ir.InsertionPoint(self.module.body): + ty = ir.VectorType.get((32, 4), ir.BF16Type.get()) + func.FuncOp.from_py_func(ty, ty)(body) + + mgpu.infer_layout(self.module) + + self.checkInLayouts(optimization_barrier, [wgmma_layout, wgmma_layout]) + self.checkOutLayouts(optimization_barrier, [wgmma_layout, wgmma_layout]) + + def test_optimization_barrier_op_propagates_producer_layouts(self): + add = optimization_barrier = None + shape = (32, 4) + splat_layout = layouts.to_layout_attr(mgpu.WGSplatFragLayout(shape)) def body(lhs, rhs): - arith.AddFOp(lhs, rhs) + nonlocal add, optimization_barrier + lhs = layout_cast(lhs, splat_layout) + rhs = layout_cast(rhs, splat_layout) + add = arith.AddFOp(lhs, rhs) + optimization_barrier = mgpu.dialect.OptimizationBarrierOp([add]) with ir.InsertionPoint(self.module.body): - shape = (32, 4) ty = ir.VectorType.get(shape, ir.BF16Type.get()) - f = func.FuncOp.from_py_func(ty, ty)(body).func_op + func.FuncOp.from_py_func(ty, ty)(body) mgpu.infer_layout(self.module) - self.assertNotIn("in_layouts", f.attributes) - self.assertNotIn("out_layouts", f.attributes) + + self.checkInLayouts(optimization_barrier, [splat_layout]) + self.checkOutLayouts(optimization_barrier, [splat_layout]) if __name__ == "__main__": diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index e7bd7fad3798..d7db2b85d27a 
100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -20,18 +20,20 @@ import itertools import math import operator -import os +import sys import re import unittest from absl.testing import absltest, parameterized import jax from jax._src import config +from jax._src import lib as jaxlib from jax._src import test_util as jtu from jax._src.interpreters import mlir from jax._src.lib.mlir import ir from jax._src.lib.mlir import passmanager from jax._src.lib.mlir.dialects import arith +from jax._src.lib.mlir.dialects import cf from jax._src.lib.mlir.dialects import scf from jax._src.lib.mlir.dialects import vector from jax.experimental.mosaic.gpu import dialect as mgpu_dialect # pylint: disable=g-importing-member @@ -40,9 +42,10 @@ import jax.numpy as jnp import numpy as np try: - import jax._src.lib.mosaic_gpu # noqa: F401 + import jax._src.lib.mosaic_gpu as mosaic_gpu_lib # noqa: F401 HAS_MOSAIC_GPU = True except ImportError: + mosaic_gpu_lib = None HAS_MOSAIC_GPU = False class Dimension(enum.IntEnum): # Just to make parameterized tests expand ok @@ -53,6 +56,7 @@ class Dimension(enum.IntEnum): # Just to make parameterized tests expand ok import jax.experimental.mosaic.gpu as mgpu from jax.experimental.mosaic.gpu import core from jax.experimental.mosaic.gpu import launch_context + from jax.experimental.mosaic.gpu import layouts from jax.experimental.mosaic.gpu import utils as utils from jax.experimental.mosaic.gpu import profiler from jax.experimental.mosaic.gpu import inference_utils @@ -60,6 +64,12 @@ class Dimension(enum.IntEnum): # Just to make parameterized tests expand ok from jax._src.lib.mlir.dialects import gpu from jax._src.lib.mlir.dialects import llvm Dimension = gpu.Dimension +try: + import hypothesis as hp + import hypothesis.strategies as hps + jtu.setup_hypothesis() +except ImportError: + hp = hps = None # ruff: noqa: F405 @@ -85,20 +95,6 @@ def mlir_sum(elems): return total -@contextlib.contextmanager -def get_sass(): - prev_dump = os.environ.get("MOSAIC_GPU_DUMP_SASS", None) - os.environ["MOSAIC_GPU_DUMP_SASS"] = "1" - try: - with jtu.capture_stdout() as output: - yield output - finally: - if prev_dump is not None: - os.environ["MOSAIC_GPU_DUMP_SASS"] = prev_dump - else: - del os.environ["MOSAIC_GPU_DUMP_SASS"] - - def copy(src: ir.Value, dst: ir.Value, swizzle: int | None = None): index = ir.IndexType.get() thread_id = gpu.thread_id(gpu.Dimension.x) @@ -233,12 +229,22 @@ def setUp(self): super().setUp() self.prng = np.random.default_rng(1234) self.context = mlir.make_ir_context() - if mgpu_dialect is not None: - mgpu_dialect.register_dialect(self.context) + mgpu_dialect.register_dialect(self.context) self.enter_context(config.traceback_filtering("off")) self.enter_context(self.context) self.enter_context(ir.Location.unknown()) + @contextlib.contextmanager + def capture_stdout(self): + if "pytest" in sys.modules: + self.skipTest("pytest interacts badly with GPU stdout capture") + if mosaic_gpu_lib is None: + raise ValueError("Running tests but missing Mosaic GPU extension") + with jtu.capture_stdout() as stdout: + yield stdout + # We need to cudaDeviceSynchronize to make sure printfs are flushed. 
+ mosaic_gpu_lib._mosaic_gpu_ext._sync_all_devices() + class Sm90ATestCase(TestCase, jtu.CudaArchSpecificTest): @@ -382,17 +388,19 @@ def kernel(ctx, inp, out, _): ("add_1s", (5, 1, 2), (1, 1, 5, 1, 1, 2, 1, 1)), ("fold", (1, 5, 2, 1,), (1, 10, 1)), ("un", (1, 10, 1), (1, 5, 2, 1,)), + ("to_scalar", (1, 1, 1), ()), + ("from_scalar", (), (1, 1, 1)), ) def test_reshape(self, inp_shape, out_shape): def kernel(ctx, inp, out, _): copy(memref_reshape(inp, out_shape), out) - x = np.arange(math.prod(inp_shape), dtype=jnp.float32).reshape(*inp_shape) + x = np.arange(math.prod(inp_shape), dtype=jnp.float32).reshape(inp_shape) out_ty = jax.ShapeDtypeStruct(out_shape, jnp.float32) y = mgpu.as_gpu_kernel( kernel, (1, 1, 1), (128, 1, 1), x, out_ty, () )(x) - np.testing.assert_array_equal(y, x.reshape(*out_shape)) + np.testing.assert_array_equal(y, x.reshape(out_shape)) @parameterized.named_parameters([ ("packed", (4, 4, 4), (16, 4, 1), 1, 2, False), @@ -452,7 +460,7 @@ def test_scalar_argument(self, dtype): " values read from the 32-bit input buffer to sometimes" " (nondeterministically) contain garbage.") - scalar = 42 + scalar = dtype(42) expected = np.full((128, 128), scalar, dtype=dtype) def kernel(ctx, inp, out, _): @@ -489,24 +497,51 @@ def get_packed_shape(strides, shape): class WGMMALayoutTest(TestCase): - @parameterized.product(dtype=[jnp.float16, jnp.float32], - transposed_smem=[False, True]) - def test_store_untiled(self, dtype, transposed_smem): + @parameterized.product(dtype=[jnp.float16, jnp.float32]) + def test_store_untiled(self, dtype): def kernel(ctx, out, _): del ctx - if transposed_smem: - out = memref_transpose(out, (1, 0)) - iota_tensor(64, 64, dtype).store_untiled( - out, vector_store=not transposed_smem - ) + iota_tensor(64, 64, dtype).store_untiled(out, optimized=False) expected = np.arange(64 * 64, dtype=dtype).reshape(64, 64) - if transposed_smem: - expected = expected.T iota = mgpu.as_gpu_kernel( kernel, (1, 1, 1), (128, 1, 1), (), expected, () )() np.testing.assert_array_equal(iota, expected) + @parameterized.product( + dtype=[jnp.float8_e5m2fnuz, jnp.float8_e5m2, jnp.float8_e4m3b11fnuz, + jnp.float8_e4m3fn, jnp.float8_e4m3fnuz], + swizzle=(32, 64, 128), + num_col_tiles=(1, 2, 3), + ) + def test_load_and_store_tiled_f8(self, dtype, swizzle, num_col_tiles): + # We use a different test than `test_store_tiled` because converting + # `iota` to `f8` type requires additional specialized logic that is not + # yet available. 
+ col_tiling = swizzle + m = 128 + n = col_tiling * num_col_tiles + tiling = (64, col_tiling) + def kernel(ctx, inp, out, smem): + del ctx + smem_inp, smem_out = smem + copy(inp, smem_inp, swizzle=swizzle) + arr = mgpu.FragmentedArray.load_tiled(smem_inp, swizzle=swizzle) + arr.store_tiled(smem_out, swizzle=swizzle) + copy(smem_out, out, swizzle=swizzle) + expected = ( + jax.random.randint( + jax.random.key(42), (m * n,), -16, 15, dtype=jnp.int8 + ) + .reshape(m // tiling[0], tiling[0], n // tiling[1], tiling[1]) + .astype(dtype) + .transpose(0, 2, 1, 3) + ) + res = mgpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), expected, expected, (expected,) * 2 + )(expected) + np.testing.assert_array_equal(res, expected) + @parameterized.product( dtype=[jnp.float32, jnp.float16, jnp.int8], swizzle=(32, 64, 128), @@ -534,25 +569,95 @@ def kernel(ctx, out, smem): )() np.testing.assert_array_equal(iota, expected) - @parameterized.parameters(jnp.int8, jnp.int16, jnp.int32) - def test_sub_byte_conversion(self, jax_dtype_to): + @parameterized.product( + jax_dtype_to=( + jnp.int8, jnp.int16, jnp.int32, jnp.bfloat16, jnp.float8_e4m3fn, + ), + # Use different layouts to vary the size of the vector dimension. + layout=( + fa.WGMMA_LAYOUT, + fa.WGMMA_LAYOUT_UPCAST_2X, + fa.WGMMA_LAYOUT_UPCAST_4X, + ), + ) + def test_sub_byte_conversion(self, jax_dtype_to, layout: fa.TiledLayout): + if jax_dtype_to == jnp.int32 and layout.vector_length == 8: + self.skipTest( + "Raises: failed to prove that vector transfers don't cross swizzle" + " tile boundaries.") jax_dtype_from = jnp.int4 + if jnp.issubdtype(jax_dtype_to, jnp.integer): + is_signed = jnp.issubdtype(jax_dtype_to, jnp.signedinteger) + else: + is_signed = None def kernel(ctx, inp, out, smem): del ctx # Unused. smem_inp, smem_out = smem copy(inp, smem_inp, swizzle=16) - t = mgpu.FragmentedArray.load_tiled(smem_inp, is_signed=True, swizzle=16) - t = t.astype(utils.dtype_to_ir_type(jax_dtype_to), is_signed=True) + t = mgpu.FragmentedArray.load_tiled( + smem_inp, is_signed=True, swizzle=16, layout=layout + ) + t = t.astype(utils.dtype_to_ir_type(jax_dtype_to), is_signed=is_signed) t.store_tiled(smem_out, swizzle=32 * jnp.dtype(jax_dtype_to).itemsize) copy(smem_out, out, swizzle=32 * jnp.dtype(jax_dtype_to).itemsize) x = self.prng.integers( low=-8, high=7, size=(1, 1, 64, 64), dtype=np.int32 ).astype(jax_dtype_from) - y = x.astype(jax_dtype_to) + y = jax.lax.convert_element_type(x, jax_dtype_to) f = mgpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, y, (x, y)) np.testing.assert_array_equal(f(x), y) + @parameterized.parameters( + (jnp.float32, jnp.float8_e4m3fn), + (jnp.bfloat16, jnp.float8_e4m3fn) + ) + def test_f8_conversions(self, jax_dtype_from, jax_dtype_to): + mlir_dtype_to = utils.dtype_to_ir_type(jax_dtype_to) + def kernel(ctx, inp, out, smem): + del ctx + smem_from, smem_to = smem + copy(inp, smem_from, swizzle=128) + t = mgpu.FragmentedArray.load_tiled( + smem_from, + swizzle=128, + is_signed=None, + layout=fa.WGMMA_LAYOUT, + ) + t = t.astype(mlir_dtype_to, is_signed=utils.is_signed(jax_dtype_to)) + t.store_tiled(smem_to, swizzle=128) + copy(smem_to, out, swizzle=128) + + # These generative shenanigans are to ensure that we don't generate values + # that are too large for the target type. 
That is because the saturation + # behavior of the conversion is different between XLA and Mosaic GPU here + # (to use the NVIDIA internal, we allow Mosaic GPU to use the .satfinite + # modifier, which saturates to the largest finite value---while XLA would + # give us NaNs in this case). + max_finite_val = 0b111_1110 + + expected = jax.lax.bitcast_convert_type( + jax.random.randint( + jax.random.key(42), + (1, 1, 64, 128), + -max_finite_val, + max_finite_val + 1, + dtype=jnp.uint8, + ), + jax_dtype_to, + ) + x = expected.astype(jax_dtype_from) + + res = mgpu.as_gpu_kernel( + kernel, + (1, 1, 1), + (128, 1, 1), + x, + expected, + (x, expected), + )(x) + np.testing.assert_array_equal(res, expected) + @parameterized.product( jax_dtype_from_to=( (jnp.int8, jnp.bfloat16), @@ -643,6 +748,19 @@ def kernel(ctx, in_, out, smem): np.testing.assert_array_equal(iota, expected) +class I8Type: + """A type that represents a 8-bit signed integer. + + This is a workaround to bypass the fact that we don't have a proper 8-bit + integer type class available in MLIR, and can't instantiate types without a + MLIR context. + """ + + @staticmethod + def get(): # pylint: disable=no-method-argument + return ir.IntegerType.get_signless(8) + + class WGMMATest(TestCase): def setUp(self): @@ -653,7 +771,13 @@ def setUp(self): @parameterized.product( lhs_transpose=(False, True), rhs_transpose=(False, True), - in_mlir_dtype_cls=(ir.F16Type, ir.BF16Type, ir.F32Type), + in_mlir_dtype_cls=( + ir.F16Type, + ir.BF16Type, + ir.F32Type, + ir.Float8E5M2Type, + ir.Float8E4M3FNType, + ), m=(64, 128, 192), n=(64, 128, 192), k_steps=(1, 2), @@ -662,7 +786,67 @@ def setUp(self): rhs_tiling_kind=("large", "small", "small+no_transpose"), lhs_tiling_kind=("large", "small", "small+no_transpose"), ) - def test_wgmma_basic( + def test_wgmma_basic_float( + self, + lhs_transpose, + rhs_transpose, + in_mlir_dtype_cls, + m, + n, + k_steps, + swizzle, + jax_out_dtype, + rhs_tiling_kind, + lhs_tiling_kind, + ): + self._test_wgmma_basic( + m, + n, + k_steps, + in_mlir_dtype_cls, + lhs_transpose, + rhs_transpose, + swizzle, + jax_out_dtype, + rhs_tiling_kind, + lhs_tiling_kind, + ) + + @parameterized.product( + in_mlir_dtype_cls=(I8Type,), + m=(64, 128, 192), + n=(64, 128, 192), + k_steps=(1, 2), + swizzle=(32, 64, 128), + jax_out_dtype=(jnp.int32,), + rhs_tiling_kind=("large", "small", "small+no_transpose"), + lhs_tiling_kind=("large", "small"), + ) + def test_wgmma_basic_int( + self, + in_mlir_dtype_cls, + m, + n, + k_steps, + swizzle, + jax_out_dtype, + rhs_tiling_kind, + lhs_tiling_kind, + ): + self._test_wgmma_basic( + m, + n, + k_steps, + in_mlir_dtype_cls, + lhs_transpose=False, + rhs_transpose=True, + swizzle=swizzle, + jax_out_dtype=jax_out_dtype, + rhs_tiling_kind=rhs_tiling_kind, + lhs_tiling_kind=lhs_tiling_kind, + ) + + def _test_wgmma_basic( self, m, n, @@ -675,8 +859,12 @@ def test_wgmma_basic( rhs_tiling_kind, lhs_tiling_kind, ): - if jax_out_dtype == jnp.float16 and in_mlir_dtype_cls is not ir.F16Type: - self.skipTest("Only f16 input is supported for f16 output.") + if jax_out_dtype == jnp.int32 and in_mlir_dtype_cls != I8Type: + self.skipTest("s32 accumulator only supported with s8 inputs") + if jax_out_dtype != jnp.int32 and in_mlir_dtype_cls == I8Type: + self.skipTest("s8 inputs only supported with s32 accumulator") + if jax_out_dtype == jnp.float16 and in_mlir_dtype_cls in {ir.F32Type, ir.BF16Type}: + self.skipTest(f"{in_mlir_dtype_cls.get()} does not support f16 output.") if swizzle != 128 and lhs_transpose and 
lhs_tiling_kind == "large": self.skipTest("Transpose only supported in 128B swizzled WGMMA") if rhs_tiling_kind == "small+no_transpose" and not rhs_transpose: @@ -686,10 +874,10 @@ def test_wgmma_basic( in_mlir_dtype = in_mlir_dtype_cls.get() out_mlir_dtype = utils.dtype_to_ir_type(jax_out_dtype) + if (lhs_transpose or not rhs_transpose) and bytewidth(in_mlir_dtype) != 2: + self.skipTest("Transpose only supported in 16-bit WGMMA") if ir.F32Type.isinstance(in_mlir_dtype): # We actually use tf32 instead in_jax_dtype = jnp.float32 - if lhs_transpose or not rhs_transpose: - self.skipTest("Transpose only supported in 16-bit WGMMA") exponent_bits, mantissa_bits = 8, 10 # Use tf32 elif bytewidth(in_mlir_dtype) == 2: if n % 64 != 0: @@ -702,10 +890,21 @@ def test_wgmma_basic( exponent_bits, mantissa_bits = 8, 7 else: raise NotImplementedError(in_mlir_dtype) + elif in_mlir_dtype_cls == ir.Float8E5M2Type: + in_jax_dtype = jnp.float8_e5m2 + exponent_bits, mantissa_bits = 5, 2 + elif in_mlir_dtype_cls == ir.Float8E4M3FNType: + in_jax_dtype = jnp.float8_e4m3fn + exponent_bits, mantissa_bits = 4, 3 + elif in_mlir_dtype_cls == I8Type: + in_jax_dtype = jnp.int8 + exponent_bits = mantissa_bits = None else: raise NotImplementedError(in_mlir_dtype) nk_tile = swizzle // bytewidth(in_mlir_dtype) k = nk_tile * k_steps + if n % nk_tile: + self.skipTest("tiling does not divide N") assert m % 64 == 0 and n % nk_tile == 0 small_rhs_tile = rhs_tiling_kind != "large" @@ -739,7 +938,8 @@ def kernel(ctx, lhs, rhs, out, scratch): ) for i in range(2): barriers[i].wait() - init_acc = mgpu.WGMMAAccumulator.zero(m=m, n=n, dtype=out_mlir_dtype) + is_signed = True if ir.IntegerType.isinstance(in_mlir_dtype) else None + init_acc = mgpu.WGMMAAccumulator.zero(m=m, n=n, dtype=out_mlir_dtype, is_signed=is_signed) if lhs_transpose: perm = (0, 1, 3, 2) if transpose_lhs_tiles else (1, 0, 3, 2) lhs_smem = memref_transpose(lhs_smem, perm) @@ -749,16 +949,20 @@ def kernel(ctx, lhs, rhs, out, scratch): acc = mgpu.wgmma(init_acc, lhs_smem, rhs_smem, swizzle=swizzle) nvvm.wgmma_commit_group_sync_aligned() nvvm.wgmma_wait_group_sync_aligned(0) - acc.value.store_untiled(out) + acc.value.store_untiled(out, optimized=False) def quantize(x): # Quantize the input to avoid rounding when feeding the WGMMA return jax.lax.reduce_precision(x, exponent_bits, mantissa_bits) x_shape = (k, m) if lhs_transpose else (m, k) - x = quantize(self.prng.uniform(-1, 1, x_shape)).astype(in_jax_dtype) y_shape = (n, k) if rhs_transpose else (k, n) - y = quantize(self.prng.uniform(-1, 1, y_shape)).astype(in_jax_dtype) + if in_mlir_dtype_cls == I8Type: + x = self.prng.integers(-128, 127, x_shape).astype(in_jax_dtype) + y = self.prng.integers(-128, 127, y_shape).astype(in_jax_dtype) + else: + x = quantize(self.prng.uniform(-1, 1, x_shape)).astype(in_jax_dtype) + y = quantize(self.prng.uniform(-1, 1, y_shape)).astype(in_jax_dtype) out_shape = jax.ShapeDtypeStruct((m, n), jax_out_dtype) if transpose_rhs_tiles: rhs_tiling_t = rhs_tiling[::-1] if rhs_transpose else rhs_tiling @@ -781,6 +985,10 @@ def quantize(x): x32, y32 = x.astype(np.float32), y.astype(np.float32) ref = (x32.T if lhs_transpose else x32) @ (y32.T if rhs_transpose else y32) atol = 2e-2 if jax_out_dtype == jnp.float16 else 5e-6 + if ir.IntegerType.isinstance(in_mlir_dtype) and ir.IntegerType.isinstance(out_mlir_dtype): + atol = 0 + elif utils.bitwidth(in_mlir_dtype) == 8: + atol = 3e-2 np.testing.assert_allclose(z, ref, atol=atol) # TODO(apaszke): Add support for f32 @@ -821,7 +1029,7 @@ def kernel(ctx, 
rhs, out, rhs_smem): acc = mgpu.wgmma(init_acc, lhs_regs, rhs_smem, swizzle=swizzle) nvvm.wgmma_commit_group_sync_aligned() nvvm.wgmma_wait_group_sync_aligned(0) - acc.value.store_untiled(out) + acc.value.store_untiled(out, optimized=False) y_shape = (n, k) if rhs_transpose else (k, n) y = self.prng.uniform(-1, 1, y_shape).astype(dtype) @@ -881,7 +1089,7 @@ def kernel(ctx, rhs, out, smem): acc = mgpu.wgmma(init_acc, lhs_regs, rhs_smem, swizzle=swizzle) nvvm.wgmma_commit_group_sync_aligned() nvvm.wgmma_wait_group_sync_aligned(0) - acc.value.store_untiled(out) + acc.value.store_untiled(out, optimized=False) jax_dtype = jnp.float16 y_shape = (n, k) if rhs_transpose else (k, n) @@ -897,7 +1105,7 @@ def kernel(ctx, rhs, out, smem): ref = jax.lax.dot( x, (y.T if rhs_transpose else y), preferred_element_type=jnp.float32 ) - np.testing.assert_allclose(z, ref, rtol=5e-4, atol=0) + np.testing.assert_allclose(z, ref, rtol=1e-3, atol=0) class TCGen05Test(TestCase): @@ -908,19 +1116,164 @@ def setUp(self): if not any(jtu.is_cuda_compute_capability_equal(sm) for sm in capabilities): self.skipTest("Only works on GPU with capability sm_100a or sm_101a") + @parameterized.parameters([(jnp.float32, 1), (jnp.float16, 1), (jnp.float16, 2)]) + def test_load_store_tmem_swizzle(self, jax_dtype, packing): + swizzle = 128 + in_mlir_dtype = utils.dtype_to_ir_type(jax_dtype) + swizzle_elems = swizzle // bytewidth(in_mlir_dtype) + tiling = (8, swizzle_elems) + + def kernel(ctx, input, output, scratch): + smem, barrier, tmem = scratch + ctx.async_copy( + src_ref=input, + dst_ref=smem, + swizzle=swizzle, + gmem_transform=mgpu.TileTransform(tiling), + barrier=barrier, + ) + barrier.wait() + tmem.store(fa.FragmentedArray.load_tiled(smem, swizzle, layout=tcgen05.LAYOUT)) + tcgen05.commit_tmem() + tmem.load().store_tiled(smem, swizzle) + mgpu.commit_shared() + ctx.async_copy( + src_ref=smem, dst_ref=output, swizzle=swizzle, gmem_transform=mgpu.TileTransform(tiling), + ) + ctx.await_async_copy(0) + + x = self.prng.uniform(-1, 1, (128, 128)).astype(jax_dtype) + scratch_shape = [ + jax.ShapeDtypeStruct(tile_shape(x.shape, tiling), jax_dtype), + mgpu.TMABarrier(), + mgpu.TMEM(x.shape, jax_dtype, packing=packing), + ] + y = mgpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), x, x, scratch_shape + )(x) + np.testing.assert_array_equal(x, y) + + @parameterized.parameters([(jnp.float32, 1), (jnp.float16, 1), (jnp.float16, 2)]) + def test_load_store_tmem_native(self, jax_dtype, packing): + + def kernel(ctx, input, output, scratch): + smem, barrier, tmem = scratch + ctx.async_copy(src_ref=input, dst_ref=smem, barrier=barrier) + barrier.wait() + tmem.store(fa.FragmentedArray.load_untiled(smem, layout=tcgen05.TMEM_NATIVE_LAYOUT, optimized=False)) + tcgen05.commit_tmem() + tmem.load(tcgen05.TMEM_NATIVE_LAYOUT).store_untiled(smem, optimized=False) + mgpu.commit_shared() + ctx.async_copy(src_ref=smem, dst_ref=output) + ctx.await_async_copy(0) + + x = self.prng.uniform(-1, 1, (128, 128)).astype(jax_dtype) + scratch_shape = [ + jax.ShapeDtypeStruct(x.shape, jax_dtype), + mgpu.TMABarrier(), + mgpu.TMEM(x.shape, jax_dtype, packing=packing), + ] + y = mgpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), x, x, scratch_shape + )(x) + np.testing.assert_array_equal(x, y) + + @parameterized.parameters([ + (jnp.float32, 1, "130.0000"), + (jnp.float16, 1, "130.0000"), + (jnp.float16, 2, "[132.000000,133.000000]"), + ]) + @jtu.thread_unsafe_test() + def test_tmem_debug_print(self, jax_dtype, packing, expected): + swizzle = 128 + in_mlir_dtype = 
utils.dtype_to_ir_type(jax_dtype) + swizzle_elems = swizzle // bytewidth(in_mlir_dtype) + tiling = (8, swizzle_elems) + + def kernel(ctx, input, output, scratch): + smem, barrier, tmem = scratch + ctx.async_copy( + src_ref=input, + dst_ref=smem, + swizzle=swizzle, + gmem_transform=mgpu.TileTransform(tiling), + barrier=barrier, + ) + barrier.wait() + tmem.store(fa.FragmentedArray.load_tiled(smem, swizzle, layout=tcgen05.LAYOUT)) + tcgen05.commit_tmem() + tmem.slice(slice(None), slice(0, 8))._debug_print() + + x = jnp.arange(128 * 128, dtype=jax_dtype).reshape(128, 128) + scratch_shape = [ + jax.ShapeDtypeStruct(tile_shape(x.shape, tiling), jax_dtype), + mgpu.TMABarrier(), + mgpu.TMEM(x.shape, jax_dtype, packing=packing), + ] + with self.capture_stdout() as stdout: + mgpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), x, x, scratch_shape + )(x).block_until_ready() + self.assertIn("[1, 2]: " + expected, stdout()) + @parameterized.product( lhs_transpose=(False, True), rhs_transpose=(False, True), - in_jax_dtype=(jnp.float16, jnp.bfloat16), # TODO(apaszke): f32 - out_jax_dtype=(jnp.float32,), # TODO(apaszke): f16 accumulation + in_jax_dtype=(jnp.float16, jnp.bfloat16, jnp.float8_e5m2, jnp.float8_e4m3fn), # TODO(apaszke): f32 + out_jax_dtype=(jnp.float16, jnp.float32,), m=(128,), # TODO(apaszke): 64, 192, 256 n=(64, 128, 256, 512), # TODO(apaszke): 192, other non-power-of-2 - k_steps=(1, 2), swizzle=(32, 64, 128,), - rhs_transpose_tiles=(False, True), + ) + def test_mma_basic_float(self, **kwargs): + if kwargs["n"] * jnp.dtype(kwargs["in_jax_dtype"]).itemsize < kwargs["swizzle"]: + self.skipTest("swizzle too large for input") + self._basic_mma_test( + **kwargs, + k_steps=2, # Reducing to 1 can be helpful while debugging. + lhs_transpose_tiles=False, + rhs_transpose_tiles=False, + ) + + @parameterized.product( + lhs_transpose=(False, True), + rhs_transpose=(False, True), + in_jax_dtype=(jnp.int8,), + out_jax_dtype=(jnp.int32,), + m=(128,), # TODO(apaszke): 64, 192, 256 + n=(64, 128, 256, 512), # TODO(apaszke): 192, other non-power-of-2 + swizzle=(32, 64, 128,), + ) + def test_mma_basic_int(self, **kwargs): + if kwargs["n"] * jnp.dtype(kwargs["in_jax_dtype"]).itemsize < kwargs["swizzle"]: + self.skipTest("swizzle too large for input") + self._basic_mma_test( + **kwargs, + k_steps=2, # Reducing to 1 can be helpful while debugging. + lhs_transpose_tiles=False, + rhs_transpose_tiles=False, + ) + + @parameterized.product( + lhs_transpose=(False, True), + rhs_transpose=(False, True), + in_jax_dtype=(jnp.float16,), + out_jax_dtype=(jnp.float32,), + m=(128,), + n=(128, 512), + swizzle=(32, 64, 128,), lhs_transpose_tiles=(False, True), + rhs_transpose_tiles=(False, True), ) - def test_mma_basic( + def test_mma_transposed_tiles(self, **kwargs): + if not kwargs["lhs_transpose_tiles"] and not kwargs["rhs_transpose_tiles"]: + self.skipTest("This is already tested in test_mma_basic") + self._basic_mma_test( + **kwargs, + k_steps=2, # Reducing to 1 can be helpful while debugging. 
+ ) + + def _basic_mma_test( self, m, n, @@ -933,8 +1286,10 @@ def test_mma_basic( rhs_transpose_tiles, lhs_transpose_tiles, ): - if out_jax_dtype == jnp.float16 and in_jax_dtype != jnp.float16: - self.skipTest("Only f16 input is supported for f16 output.") + if out_jax_dtype != jnp.float32 and ( + in_jax_dtype == jnp.float32 or in_jax_dtype == jnp.bfloat16 + ): + self.skipTest("Only f32 output is supported for f32 and bf16 input.") in_mlir_dtype = utils.dtype_to_ir_type(in_jax_dtype) swizzle_elems = swizzle // bytewidth(in_mlir_dtype) @@ -979,18 +1334,13 @@ def kernel(ctx, lhs, rhs, out, scratch): ) tcgen05.commit_arrive(barriers[2]) barriers[2].wait(for_tensor_core=True) - acc[:].store_untiled(out) - - in_finfo = jnp.finfo(in_jax_dtype) - exponent_bits, mantissa_bits = in_finfo.nexp, in_finfo.nmant - def quantize(x): - # Quantize the input to avoid rounding when feeding the TensorCore - return jax.lax.reduce_precision(x, exponent_bits, mantissa_bits) + is_signed = True if jnp.issubdtype(in_jax_dtype, jnp.integer) else None + acc.load(is_signed=is_signed).store_untiled(out, optimized=False) x_shape = (k, m) if lhs_transpose else (m, k) - x = quantize(self.prng.uniform(-1, 1, x_shape)).astype(in_jax_dtype) + x = self.prng.uniform(-1, 1, x_shape).astype(in_jax_dtype) y_shape = (n, k) if rhs_transpose else (k, n) - y = quantize(self.prng.uniform(-1, 1, y_shape)).astype(in_jax_dtype) + y = self.prng.uniform(-1, 1, y_shape).astype(in_jax_dtype) out_shape = jax.ShapeDtypeStruct((m, n), out_jax_dtype) if rhs_transpose_tiles: rhs_smem_shape = ( @@ -1015,14 +1365,85 @@ def quantize(x): )(x, y) x32, y32 = x.astype(np.float32), y.astype(np.float32) ref = (x32.T if lhs_transpose else x32) @ (y32.T if rhs_transpose else y32) - atol = 2e-2 if out_jax_dtype == jnp.float16 else 5e-6 - np.testing.assert_allclose(z, ref, atol=atol) + atol = 2e-2 if out_jax_dtype == jnp.float16 else 2e-5 + rtol = 8e-4 if out_jax_dtype == jnp.float16 else 1e-7 + np.testing.assert_allclose(z, ref, atol=atol, rtol=rtol) + + @parameterized.product( + in_jax_dtype=(jnp.float16, jnp.bfloat16), # TODO(apaszke): f32 + out_jax_dtype=(jnp.float16, jnp.float32,), + m=(128,), # TODO(apaszke): 64, 192, 256 + n=(64, 128, 256), # TODO(apaszke): 192, other non-power-of-2 + ) + def test_mma_lhs_tmem(self, m, n, in_jax_dtype, out_jax_dtype): + swizzle = 128 + k_steps = 2 # Reducing to 1 can be helpful while debugging. 
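    # The swizzle bookkeeping used throughout these MMA tests follows one
    # pattern: a swizzle of S bytes spans S // itemsize elements of the
    # contracted dimension, the SMEM tiling is (8, S // itemsize), and k is
    # that tile width times k_steps. A self-contained sketch of the
    # arithmetic; this helper is illustrative only and is never called:
    def _swizzle_arithmetic_sketch(
        swizzle_bytes=128, dtype=jnp.float16, k_steps=2
    ):
      swizzle_elems = swizzle_bytes // jnp.dtype(dtype).itemsize  # 64 for f16
      tiling = (8, swizzle_elems)  # the (8, 64) SMEM tiling
      k = swizzle_elems * k_steps  # 128
      # A (128, k) operand therefore tiles into (128 // 8, k // 64, 8, 64).
      tiled_shape = (128 // tiling[0], k // tiling[1], *tiling)
      return k, tiling, tiled_shape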
+ if out_jax_dtype == jnp.float16 and in_jax_dtype != jnp.float16: + self.skipTest("Only f16 input is supported for f16 output.") + + in_mlir_dtype = utils.dtype_to_ir_type(in_jax_dtype) + swizzle_elems = swizzle // bytewidth(in_mlir_dtype) + k = swizzle_elems * k_steps + lhs_tiling = rhs_tiling = (8, swizzle_elems) + + def kernel(ctx, lhs, rhs, out, scratch): + lhs_smem, rhs_smem, barriers, acc, lhs_tmem = scratch + ctx.async_copy( + src_ref=lhs, + dst_ref=lhs_smem, + swizzle=swizzle, + gmem_transform=mgpu.TileTransform(lhs_tiling), + barrier=barriers[0], + ) + ctx.async_copy( + src_ref=rhs, + dst_ref=rhs_smem, + swizzle=swizzle, + gmem_transform=mgpu.TileTransform(rhs_tiling), + barrier=barriers[1], + ) + barriers[0].wait() + barriers[1].wait() + lhs_tmem.store( + fa.FragmentedArray.load_tiled( + lhs_smem, swizzle, layout=tcgen05.LAYOUT + ) + ) + tcgen05.commit_tmem() + with mgpu.single_thread(): + tcgen05.mma( + acc, lhs_tmem, rhs_smem, a_swizzle=swizzle, b_swizzle=swizzle, accumulate=False, + ) + tcgen05.commit_arrive(barriers[2]) + barriers[2].wait(for_tensor_core=True) + acc.load().store_untiled(out, optimized=False) + + x_shape = (m, k) + x = self.prng.uniform(-1, 1, x_shape).astype(in_jax_dtype) + y_shape = (k, n) + y = self.prng.uniform(-1, 1, y_shape).astype(in_jax_dtype) + out_shape = jax.ShapeDtypeStruct((m, n), out_jax_dtype) + scratch_shape = [ + jax.ShapeDtypeStruct(tile_shape(x_shape, lhs_tiling), in_jax_dtype), + jax.ShapeDtypeStruct(tile_shape(y_shape, rhs_tiling), in_jax_dtype), + mgpu.TMABarrier(3), + mgpu.TMEM((128, n), out_jax_dtype), + mgpu.TMEM((128, k), in_jax_dtype, packing=2), + ] + z = mgpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), (x, y), out_shape, scratch_shape + )(x, y) + x32, y32 = x.astype(np.float32), y.astype(np.float32) + ref = x32 @ y32 + atol = 2e-2 if out_jax_dtype == jnp.float16 else 2e-5 + rtol = 8e-4 if out_jax_dtype == jnp.float16 else 1e-7 + np.testing.assert_allclose(z, ref, atol=atol, rtol=rtol) @parameterized.product( lhs_transpose=(False, True), rhs_transpose=(False, True), - in_jax_dtype=(jnp.float16,), # TODO(apaszke): f32 - out_jax_dtype=(jnp.float32,), # TODO(apaszke): f16 accumulation + in_jax_dtype=(jnp.float16,), + out_jax_dtype=(jnp.float32,), m=(256,), # TODO(apaszke): 64, 192, 256 n=(128, 256, 512), # TODO(apaszke): 192, other non-power-of-2 k_steps=(1, 2), @@ -1061,7 +1482,102 @@ def kernel(ctx, lhs, rhs, out, scratch): gmem_transform=mgpu.TileTransform(tiling), barrier=barriers[0], collective=gpu.Dimension.x, - partitioned=1 if lhs_transpose else 0, # Split non-contracting dim. + partitioned=1 if lhs_transpose else 0, # Split non-contracting dim. + ) + ctx.async_copy( + src_ref=rhs, + dst_ref=rhs_smem, + swizzle=swizzle, + gmem_transform=mgpu.TileTransform(tiling), + barrier=barriers[1], + collective=gpu.Dimension.x, + partitioned=0 if rhs_transpose else 1, # Split non-contracting dim. 
+ ) + is_leader_thread = single_thread_predicate() + is_first_block = arith.cmpi(arith.CmpIPredicate.eq, block_id, c(0, index)) + with when(arith.andi(is_first_block, is_leader_thread)): + barriers[0].wait() + barriers[1].wait() + if lhs_transpose: + lhs_smem = memref_transpose(lhs_smem, (1, 0, 3, 2)) + if rhs_transpose: + rhs_smem = memref_transpose(rhs_smem, (1, 0, 3, 2)) + tcgen05.mma( + acc, lhs_smem, rhs_smem, a_swizzle=swizzle, b_swizzle=swizzle, accumulate=False, collective=True + ) + tcgen05.commit_arrive(barriers[2], collective=True, ctx=ctx) + barriers[2].wait(for_tensor_core=True) + m_slice = ds(arith.muli(block_id, c(m_block_tile, index)), m_block_tile) + acc.load().store_untiled(memref_slice(out, m_slice), optimized=False) + + in_finfo = jnp.finfo(in_jax_dtype) + exponent_bits, mantissa_bits = in_finfo.nexp, in_finfo.nmant + def quantize(x): + # Quantize the input to avoid rounding when feeding the TensorCore + return jax.lax.reduce_precision(x, exponent_bits, mantissa_bits) + + x_shape = (k, m) if lhs_transpose else (m, k) + x_block_shape = (k, m_block_tile) if lhs_transpose else (m_block_tile, k) + x = quantize(self.prng.uniform(-1, 1, x_shape)).astype(in_jax_dtype) + y_shape = (n, k) if rhs_transpose else (k, n) + y_block_shape = (n_block_tile, k) if rhs_transpose else (k, n_block_tile) + y = quantize(self.prng.uniform(-1, 1, y_shape)).astype(in_jax_dtype) + out_shape = jax.ShapeDtypeStruct((m, n), out_jax_dtype) + tmem_layout = tcgen05.TMEM_COLLECTIVE_N512_LAYOUT if n == 512 else None + scratch_shape = [ + jax.ShapeDtypeStruct(tile_shape(x_block_shape, tiling), in_jax_dtype), + jax.ShapeDtypeStruct(tile_shape(y_block_shape, tiling), in_jax_dtype), + mgpu.TMABarrier(3), + mgpu.TMEM((128, n), out_jax_dtype, collective=True, layout=tmem_layout), + ] + z = mgpu.as_gpu_kernel( + kernel, (2, 1, 1), (128, 1, 1), (x, y), out_shape, scratch_shape, cluster=(2, 1, 1) + )(x, y) + x32, y32 = x.astype(np.float32), y.astype(np.float32) + ref = (x32.T if lhs_transpose else x32) @ (y32.T if rhs_transpose else y32) + atol = 2e-2 if out_jax_dtype == jnp.float16 else 5e-6 + np.testing.assert_allclose(z, ref, atol=atol) + + @parameterized.product( + in_jax_dtype=(jnp.float16,), + out_jax_dtype=(jnp.float32,), + m=(256,), # TODO(apaszke): 64, 192, 256 + n=(128, 256,), # TODO(apaszke): 192, other non-power-of-2, 512 + k_steps=(2,), # Note: reducing to 1 can be useful for debugging. + swizzle=(32, 64, 128,), + ) + def test_mma_collective_lhs_tmem( + self, + m, + n, + k_steps, + swizzle, + in_jax_dtype, + out_jax_dtype, + ): + if out_jax_dtype == jnp.float16 and in_jax_dtype != jnp.float16: + raise self.skipTest("Only f16 input is supported for f16 output.") + + in_mlir_dtype = utils.dtype_to_ir_type(in_jax_dtype) + m_block_tile = m // 2 + n_block_tile = n // 2 + swizzle_elems = swizzle // bytewidth(in_mlir_dtype) + k = swizzle_elems * k_steps + index = ir.IndexType.get() + + tiling = (8, swizzle_elems) + + def kernel(ctx, lhs, rhs, out, scratch): + lhs_smem, rhs_smem, barriers, cluster_barrier, acc, lhs_tmem = scratch + block_id = gpu.cluster_block_id(gpu.Dimension.x) + ctx.async_copy( + src_ref=lhs, + dst_ref=lhs_smem, + swizzle=swizzle, + gmem_transform=mgpu.TileTransform(tiling), + barrier=barriers[0], + collective=gpu.Dimension.x, + partitioned=0, # Split non-contracting dim. 
) ctx.async_copy( src_ref=rhs, @@ -1070,49 +1586,77 @@ def kernel(ctx, lhs, rhs, out, scratch): gmem_transform=mgpu.TileTransform(tiling), barrier=barriers[1], collective=gpu.Dimension.x, - partitioned=0 if rhs_transpose else 1, # Split non-contracting dim. + partitioned=1, # Split non-contracting dim. ) + is_leader_thread = single_thread_predicate() is_first_block = arith.cmpi(arith.CmpIPredicate.eq, block_id, c(0, index)) + with when(arith.andi(is_first_block, is_leader_thread)): barriers[0].wait() + gpu.barrier() + # Because only block 1 waits on the TMA, we need a cluster barrier so + # that the SMEM updates are visible on block 2. + cluster_barrier.arrive() + cluster_barrier.wait() + lhs_tmem.store( + fa.FragmentedArray.load_tiled( + lhs_smem, swizzle, layout=tcgen05.LAYOUT + ) + ) + tcgen05.commit_tmem() + # Make sure TMEM has been loaded on both blocks. + cluster_barrier.arrive() + cluster_barrier.wait() + with when(arith.andi(is_first_block, is_leader_thread)): barriers[1].wait() - if lhs_transpose: - lhs_smem = memref_transpose(lhs_smem, (1, 0, 3, 2)) - if rhs_transpose: - rhs_smem = memref_transpose(rhs_smem, (1, 0, 3, 2)) tcgen05.mma( - acc, lhs_smem, rhs_smem, a_swizzle=swizzle, b_swizzle=swizzle, accumulate=False, collective=True + acc, + lhs_tmem, + rhs_smem, + a_swizzle=swizzle, + b_swizzle=swizzle, + accumulate=False, + collective=True, ) tcgen05.commit_arrive(barriers[2], collective=True, ctx=ctx) barriers[2].wait(for_tensor_core=True) m_slice = ds(arith.muli(block_id, c(m_block_tile, index)), m_block_tile) - acc[:].store_untiled(memref_slice(out, m_slice)) + acc.load().store_untiled(memref_slice(out, m_slice), optimized=False) in_finfo = jnp.finfo(in_jax_dtype) exponent_bits, mantissa_bits = in_finfo.nexp, in_finfo.nmant + def quantize(x): # Quantize the input to avoid rounding when feeding the TensorCore return jax.lax.reduce_precision(x, exponent_bits, mantissa_bits) - x_shape = (k, m) if lhs_transpose else (m, k) - x_block_shape = (k, m_block_tile) if lhs_transpose else (m_block_tile, k) + x_shape = (m, k) + x_block_shape = (m_block_tile, k) x = quantize(self.prng.uniform(-1, 1, x_shape)).astype(in_jax_dtype) - y_shape = (n, k) if rhs_transpose else (k, n) - y_block_shape = (n_block_tile, k) if rhs_transpose else (k, n_block_tile) + y_shape = (k, n) + y_block_shape = (k, n_block_tile) y = quantize(self.prng.uniform(-1, 1, y_shape)).astype(in_jax_dtype) out_shape = jax.ShapeDtypeStruct((m, n), out_jax_dtype) scratch_shape = [ jax.ShapeDtypeStruct(tile_shape(x_block_shape, tiling), in_jax_dtype), jax.ShapeDtypeStruct(tile_shape(y_block_shape, tiling), in_jax_dtype), mgpu.TMABarrier(3), + mgpu.ClusterBarrier(collective_dims=(gpu.Dimension.x,)), mgpu.TMEM((128, n), out_jax_dtype, collective=True), + mgpu.TMEM((128, k), in_jax_dtype, collective=True, packing=2), ] z = mgpu.as_gpu_kernel( - kernel, (2, 1, 1), (128, 1, 1), (x, y), out_shape, scratch_shape, cluster=(2, 1, 1) + kernel, + (2, 1, 1), + (128, 1, 1), + (x, y), + out_shape, + scratch_shape, + cluster=(2, 1, 1), )(x, y) x32, y32 = x.astype(np.float32), y.astype(np.float32) - ref = (x32.T if lhs_transpose else x32) @ (y32.T if rhs_transpose else y32) + ref = x32 @ y32 atol = 2e-2 if out_jax_dtype == jnp.float16 else 5e-6 np.testing.assert_allclose(z, ref, atol=atol) @@ -1140,7 +1684,7 @@ def kernel(ctx, dst, scratch): final_arr = arr + mgpu.FragmentedArray.load_strided( tmp, is_signed=False ) - final_arr.store_untiled(memref_slice(dst, 0)) + final_arr.store_untiled(memref_slice(dst, 0), optimized=False) scf.yield_([]) 
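      # The two warpgroup branches of this kernel exchange data through the
      # shared `tmp` buffer, using Barrier arrive()/wait() pairs as a
      # handshake: a warpgroup only reads `tmp` after the barrier signalling
      # that it has been written, and only overwrites `tmp` after the barrier
      # signalling that the previous value has been consumed.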
with ir.InsertionPoint(scf.IfOp(is_second_wg).then_block): barriers[0].wait() @@ -1151,7 +1695,7 @@ def kernel(ctx, dst, scratch): barriers[2].wait() # Synchronize this warpgroup before we overwrite tmp. arr.store_untiled(tmp) barriers[1].arrive() # Signal that tmp is ready. - final_arr.store_untiled(memref_slice(dst, 1)) + final_arr.store_untiled(memref_slice(dst, 1), optimized=False) scf.yield_([]) out_shape = jax.ShapeDtypeStruct((2, 128), jnp.int32) y = mgpu.as_gpu_kernel( @@ -1260,6 +1804,25 @@ def kernel(ctx, src, dst, smem): y = mgpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, x, smem)(x) np.testing.assert_array_equal(y, x) + def test_tma_with_1d_tiling(self): + swizzle = 128 + dtype = jnp.float16 + shape = (64, 128) + tiling = (1, swizzle // jnp.dtype(dtype).itemsize) + def kernel(ctx, dst, smem): + iota_tensor(*shape, dtype=dtype).store_tiled(smem, swizzle=swizzle) + ctx.async_copy( + src_ref=smem, + dst_ref=dst, + swizzle=swizzle, + gmem_transform=mgpu.TileTransform(tiling), + ) + ctx.await_async_copy(0) + x = np.arange(np.prod(shape), dtype=dtype).reshape(shape) + smem = jax.ShapeDtypeStruct(utils.tile_shape(shape, tiling), dtype) + y = mgpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), (), x, smem)() + np.testing.assert_array_equal(y, x) + @parameterized.named_parameters( ( f"_{''.join(map(str, collective_dims))}={collective_size}{'_' + ''.join(map(str, noncollective_dims)) if noncollective_dims else ''}", @@ -1579,7 +2142,7 @@ def run_kernel(shape): run_kernel([1] * 6) with self.assertRaisesRegex( - ValueError, "last dimension to be divisible by 16" + ValueError, "last dimension to be divisible by 128" ): run_kernel([23]) @@ -1612,7 +2175,7 @@ def kernel(ctx, dst, _): mlir_dtype = utils.dtype_to_ir_type(dtype) iota = iota_tensor(m, n, dtype) rhs = iota if scalar_rhs is None else c(scalar_rhs, mlir_dtype) - op(iota, rhs).store_untiled(dst) + op(iota, rhs).store_untiled(dst, optimized=False) out_shape = jax.ShapeDtypeStruct((m, n), dtype) result = mgpu.as_gpu_kernel( kernel, (1, 1, 1), (128, 1, 1), (), out_shape, () @@ -1658,7 +2221,7 @@ def test_division(self, op, dtype, m=64, n=32): def kernel(ctx, dst, _): iota = iota_tensor(m, n, dtype) - op(dtype(4.2).item() * iota, iota + 1).store_untiled(dst) + op(dtype(4.2).item() * iota, iota + 1).store_untiled(dst, optimized=False) out_shape = jax.ShapeDtypeStruct((m, n), dtype) result = mgpu.as_gpu_kernel( @@ -1688,22 +2251,46 @@ def kernel(ctx, dst, _): rhs = 0 if rhs_is_literal else iota + 1 res = op(iota, rhs) assert not res.is_signed - res.astype(i8, is_signed=False).store_untiled(dst) + res.astype(i8, is_signed=False).store_untiled(dst, optimized=False) out_shape = jax.ShapeDtypeStruct((m, n), jnp.int8) result = mgpu.as_gpu_kernel( kernel, (1, 1, 1), (128, 1, 1), (), out_shape, () )() iota = np.arange(m * n, dtype=dtype).reshape(m, n) - rhs = rhs = 0 if rhs_is_literal else iota + 1 + rhs = 0 if rhs_is_literal else iota + 1 np.testing.assert_array_equal(result, op(iota, rhs).astype(jnp.int8)) + def test_foreach_wgmma_row_array(self): + def kernel(ctx, out, smem): + del ctx, smem + x = iota_tensor(128, 128, jnp.float32) + row = x.reduce("add", 1) + # Test returning an array + row = row.foreach( + lambda x, _: arith.addf(x, c(1, row.mlir_dtype)), create_array=True + ) + # Test no array return + @row.foreach + def _(v, idx): + memref.store(v, out, idx) + + result = mgpu.as_gpu_kernel( + kernel, + grid=(1, 1, 1), + block=(128, 1, 1), + in_shape=(), + out_shape=jax.ShapeDtypeStruct(shape=(128,), dtype=jnp.float32), + 
smem_scratch_shape=(), + )() + iota = np.arange(128 * 128, dtype=jnp.float32).reshape(128, 128) + np.testing.assert_array_equal(result, iota.sum(axis=1) + 1) + def test_foreach(self): dtype = jnp.int32 swizzle = 128 - tile = 64, swizzle // jnp.dtype(dtype).itemsize + tiling = (8, swizzle // jnp.dtype(dtype).itemsize) shape = 128, 192 - tiled_shape = mgpu.tile_shape(shape, tile) mlir_dtype = utils.dtype_to_ir_type(dtype) cst = 9999 def causal(val, idx): @@ -1711,12 +2298,16 @@ def causal(val, idx): mask = arith.cmpi(arith.CmpIPredicate.uge, row, col) return arith.select(mask, val, c(cst, mlir_dtype)) - tiling = mgpu.TileTransform(tile) def kernel(ctx, dst, smem): x = iota_tensor(shape[0], shape[1], dtype) - x.foreach(causal, create_array=True, is_signed=False).store_untiled(smem) + x.foreach(causal, create_array=True, is_signed=False).store_tiled(smem, swizzle=128) mgpu.commit_shared() - ctx.async_copy(src_ref=smem, dst_ref=dst) + ctx.async_copy( + src_ref=smem, + dst_ref=dst, + gmem_transform=mgpu.TileTransform(tiling), + swizzle=128, + ) ctx.await_async_copy(0) iota = np.arange(np.prod(shape), dtype=dtype).reshape(*shape) @@ -1726,7 +2317,7 @@ def kernel(ctx, dst, smem): (128, 1, 1), (), jax.ShapeDtypeStruct(shape=shape, dtype=dtype), - jax.ShapeDtypeStruct(shape=shape, dtype=dtype), + jax.ShapeDtypeStruct(shape=mgpu.tile_shape(shape, tiling), dtype=dtype), )() expected = jnp.tril(iota) + jnp.triu(jnp.ones(shape), k=1) * cst np.testing.assert_array_equal(result, expected) @@ -1738,7 +2329,7 @@ def kernel(ctx, dst, smem): def test_bitwise(self, op, dtype, m=64, n=8): def kernel(ctx, dst, _): iota = iota_tensor(m, n, dtype) - op(iota, iota + 1).store_untiled(dst) + op(iota, iota + 1).store_untiled(dst, optimized=False) out_shape = jax.ShapeDtypeStruct((m, n), dtype) result = mgpu.as_gpu_kernel( @@ -1762,7 +2353,7 @@ def test_unary(self, ops, dtype, m=64, n=32): def kernel(ctx, dst, _): iota = iota_tensor(m, n, dtype) - op(iota).store_untiled(dst) + op(iota).store_untiled(dst, optimized=False) out_shape = jax.ShapeDtypeStruct((m, n), dtype) result = mgpu.as_gpu_kernel( @@ -1775,7 +2366,7 @@ def test_select(self, m=64, n=32): def kernel(ctx, dst, _): iota = iota_tensor(m, n, jnp.int32) - (iota < 16).select(iota * 2, iota * 3).store_untiled(dst) + (iota < 16).select(iota * 2, iota * 3).store_untiled(dst, optimized=False) out_shape = jax.ShapeDtypeStruct((m, n), jnp.int32) result = mgpu.as_gpu_kernel( @@ -1798,7 +2389,7 @@ def test_math(self, ops, approx, m=64, n=32): op, np_op = ops def kernel(ctx, dst, _): iota = iota_tensor(m, n, jnp.float32) - op(iota).store_untiled(dst) + op(iota).store_untiled(dst, optimized=False) out_shape = jax.ShapeDtypeStruct((m, n), jnp.float32) result = mgpu.as_gpu_kernel( kernel, (1, 1, 1), (128, 1, 1), (), out_shape, () @@ -1818,8 +2409,8 @@ def kernel(ctx, src, dst, scratch): src = mgpu.FragmentedArray.load_strided( src, is_signed=utils.is_signed(dtype) ) - acc = src.reduce_sum(scratch).broadcast((m,)) - acc.store_untiled(dst) + acc = src.reduce("add", (0, 1), scratch).broadcast((m,)) + acc.store_untiled(dst, optimized=False) in_shape = jax.ShapeDtypeStruct((m, n), dtype) out_shape = jax.ShapeDtypeStruct((m,), dtype) @@ -1838,16 +2429,20 @@ def kernel(ctx, src, dst, scratch): dtype=[jnp.float32, jnp.int32], m=[128], n=[32, 64], + reduce_both=[False, True], ) - def test_splat_reduce_sum(self, dtype, m, n): + def test_splat_reduce_sum(self, dtype, m, n, reduce_both): def kernel(ctx, dst, _): src = mgpu.FragmentedArray.splat( utils.c(1, 
utils.dtype_to_ir_type(dtype)), (m, n), is_signed=utils.is_signed(dtype), ) - acc = src.reduce_sum().broadcast((m,)) - acc.store_untiled(dst) + if reduce_both: + acc = src.reduce("add", (0, 1)).broadcast((m,)) + else: + acc = src.reduce("add", 1) + acc.store_untiled(dst, optimized=False) kernel_fn = mgpu.as_gpu_kernel( kernel, @@ -1857,7 +2452,17 @@ def kernel(ctx, dst, _): out_shape=jax.ShapeDtypeStruct((m,), dtype), smem_scratch_shape=(), ) - np.testing.assert_array_equal(kernel_fn(), jnp.full((m,), m * n * 1.0)) + result = m * n if reduce_both else n + np.testing.assert_array_equal(kernel_fn(), jnp.full((m,), result, dtype)) + + @parameterized.named_parameters( + ("wgmma_row", fa.WGMMA_LAYOUT, fa.WGMMA_ROW_LAYOUT, 1), + ("wgmma_col", fa.WGMMA_LAYOUT, fa.WGMMA_COL_LAYOUT, 0), + ("tcgen05_row", tcgen05.LAYOUT, tcgen05.ROW_LAYOUT, 1), + ("tcgen05_col", tcgen05.LAYOUT, tcgen05.COL_LAYOUT, 0), + ) + def test_layout_reduction_definition(self, layout, expected_reduced_layout, axis): + self.assertEqual(layout.reduce((axis,)), expected_reduced_layout) @parameterized.product( op=(arith.addf, arith.maximumf), @@ -1867,7 +2472,7 @@ def kernel(ctx, dst, _): def test_reduce(self, op, m=64, n=32): def kernel(ctx, dst, _): iota = iota_tensor(m, n, jnp.float32) - iota.reduce(op, axis=1).broadcast_minor(n).store_untiled(dst) + iota.reduce(op, axis=1).broadcast_minor(n).store_untiled(dst, optimized=False) out_shape = jax.ShapeDtypeStruct((m, n), jnp.float32) result = mgpu.as_gpu_kernel( kernel, (1, 1, 1), (128, 1, 1), (), out_shape, () @@ -1888,7 +2493,7 @@ def kernel(ctx, dst, _): cte = c(1, iota.mlir_dtype) cte_arr = mgpu.FragmentedArray.splat(cte, ()) cte_arr = cte_arr.reshape((1, 1)).broadcast((m, n)) - (iota + cte_arr).store_untiled(dst) + (iota + cte_arr).store_untiled(dst, optimized=False) out_shape = jax.ShapeDtypeStruct((m, n), jnp.float32) result = mgpu.as_gpu_kernel( kernel, (1, 1, 1), (128, 1, 1), (), out_shape, () @@ -1903,7 +2508,7 @@ def kernel(ctx, dst, _): t = mgpu.FragmentedArray.splat( v, (128,), mgpu.WGMMA_ROW_LAYOUT ) - t.broadcast_minor(32).store_untiled(dst) + t.broadcast_minor(32).store_untiled(dst, optimized=False) out_shape = jax.ShapeDtypeStruct((128, 32), jnp.float32) result = mgpu.as_gpu_kernel( kernel, (1, 1, 1), (128, 1, 1), (), out_shape, () @@ -1922,7 +2527,7 @@ def kernel(ctx, src, dst, _): assert isinstance(pi_arr_sq.layout, mgpu.WGStridedFragLayout) pi_arr_cube = pi_splat.broadcast(pi_arr.shape) * pi_arr_sq assert isinstance(pi_arr_cube.layout, mgpu.WGStridedFragLayout) - (pi_arr == pi_arr).select(pi_splat, pi_arr_cube).store_untiled(dst) + (pi_arr == pi_arr).select(pi_splat, pi_arr_cube).store_untiled(dst, optimized=False) out_shape = jax.ShapeDtypeStruct((128, 32), jnp.float32) inp = jnp.ones_like(out_shape) * 3.14 @@ -1946,20 +2551,63 @@ def kernel(ctx, *args): )(inp) np.testing.assert_array_equal(inp, result) - @parameterized.product(in_shape=((128,), (64,))) - def test_wgmma_row_load_store_with_layout(self, in_shape): + @parameterized.product( + in_shape=((1024,), (256,), (128,), (64,)), + dtype=(jnp.float16, jnp.float32), + swizzle=(16, 32, 64, 128) + ) + def test_wgmma_row_load_store_with_layout(self, in_shape, dtype, swizzle): + def kernel(ctx, gmem_input, gmem_output, smem): + smem_input, smem_output = smem + copy(gmem_input, smem_input, swizzle=swizzle) + t = mgpu.FragmentedArray.load_untiled( + smem_input, layout=mgpu.WGMMA_ROW_LAYOUT, swizzle=swizzle + ) + t.store_untiled(smem_output) + copy(smem_output, gmem_output) + + inp = out = self.prng.uniform(-1, 1, 
in_shape).astype(dtype) + result = mgpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), (inp,), out, [inp, out], + )(inp) + np.testing.assert_array_equal(inp, result) + + @parameterized.product( + in_shape=((128,), (64,)), + dtype=(jnp.float16, jnp.float32), + swizzle=(16, 32, 64, 128), + ) + def test_wgmma_col_load_store_with_layout(self, in_shape, dtype, swizzle): def kernel(ctx, *args): gmem_input, gmem_output, (smem_input, smem_output) = args - copy(gmem_input, smem_input) - t = mgpu.FragmentedArray.load_wgmma_row(smem_input) + copy(gmem_input, smem_input, swizzle=swizzle) + t = mgpu.FragmentedArray.load_untiled( + smem_input, swizzle=swizzle, layout=mgpu.WGMMA_COL_LAYOUT + ) t.store_untiled(smem_output) copy(smem_output, gmem_output) - inp = out = self.prng.uniform(-1, 1, in_shape).astype(jnp.float32) + inp = out = self.prng.uniform(-1, 1, in_shape).astype(dtype) result = mgpu.as_gpu_kernel( kernel, (1, 1, 1), (128, 1, 1), (inp,), out, [inp, out], )(inp) - np.testing.assert_array_equal(inp, result) + np.testing.assert_array_equal(result, inp) + + @parameterized.parameters((128, 128), (128, 64), (64, 128)) + def test_broadcast_major(self, m, n): + def kernel(ctx, gmem_input, gmem_output, _): + t = mgpu.FragmentedArray.load_untiled( + gmem_input, layout=mgpu.WGMMA_COL_LAYOUT, optimized=False + ) + t.broadcast_major(m).store_untiled(gmem_output, optimized=False) + + inp = self.prng.uniform(-1, 1, (n,)).astype(jnp.float16) + out_shape = jax.ShapeDtypeStruct((m, n), jnp.float16) + result = mgpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), (inp,), out_shape, inp + )(inp) + out_ref = jax.lax.broadcast_in_dim(inp, (m, n), (1,)) + np.testing.assert_array_equal(result, out_ref) def test_warp_tree_reduce(self): def kernel(ctx, out, *_): @@ -1988,7 +2636,7 @@ def kernel(ctx, inp, out, smem): del ctx, smem arr = mgpu.FragmentedArray.load_strided(inp, is_signed=True) assert ir.VectorType(arr.registers.flat[0].type).shape == [reg_length] - arr.astype(mlir_dtype_to).store_untiled(out) + arr.astype(mlir_dtype_to).store_untiled(out, optimized=False) x = jnp.arange(-128, 128, dtype=jax_dtype_from) x = jnp.tile(x, reg_length // 2) @@ -2059,12 +2707,27 @@ def kernel(ctx, inp, out, smem): f = mgpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, x, None) np.testing.assert_array_equal(f(x), x * 3) + def test_optimization_barrier_with_single_value(self): + shape = (64, 64) + value = 5.0 + dtype = jnp.float32 + def kernel(ctx, out, smem): + del ctx, smem + mlir_type = utils.dtype_to_ir_type(dtype) + arr = mgpu.FragmentedArray.splat(c(value, mlir_type), shape) + arr = mgpu.optimization_barrier(arr) + arr.store_untiled(out) + + out_shape = jax.ShapeDtypeStruct(shape, dtype) + f = mgpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), (), out_shape, ()) + np.testing.assert_array_equal(f(), jnp.full(shape, value, dtype=dtype)) + def test_convert_bool_to_u8(self): m, n = 128, 128 def kernel(ctx, dst, _): i8 = ir.IntegerType.get_signless(8) iota = iota_tensor(m, n, jnp.uint8) - (iota > 10).astype(i8, is_signed=False).store_untiled(dst) + (iota > 10).astype(i8, is_signed=False).store_untiled(dst, optimized=False) out_shape = jax.ShapeDtypeStruct((m, n), jnp.int8) result = mgpu.as_gpu_kernel( @@ -2119,19 +2782,6 @@ def kernel(ctx, src, dst, _): )) jax.block_until_ready(f(x)) - def test_multigpu(self): - if len(jax.devices()) < 2: - self.skipTest("Need at least 2 devices") - def kernel(ctx, src, dst, _): - mgpu.FragmentedArray.load_strided(src).store_untiled(dst) - x = np.arange(64 * 64, 
dtype=jnp.float32).reshape(64, 64) - f = jax.jit(mgpu.as_gpu_kernel( - kernel, (1, 1, 1), (128, 1, 1), x, x, () - )) - # Make sure we can invoke the same program on different devices. - for xd in (jax.device_put(x, d) for d in jax.devices()[:2]): - jax.block_until_ready(f(xd)) - class TorchTest(TestCase): @@ -2188,11 +2838,11 @@ def kernel(ctx, dst, _): # Note that WGMMA layouts are always (shape[0] // 64, shape[1] // 8, 2, 1) self.assertEqual( tiled.registers.shape, - (shape[0] // 64, shape[1] // 8, 1, 1, 2, 1, 1, 1, 1, 1), + (shape[0] // 64, shape[1] // 8, 1, 1, 2, 1, 1, 1, 1), ) self.assertEqual(tiled.shape, shape) self.assertEqual(tiled.mlir_dtype, iota.mlir_dtype) - tiled.store_untiled(dst) + tiled.store_untiled(dst, optimized=False) ty = jax.ShapeDtypeStruct(shape, dtype) f = mgpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), (), ty, ()) expected = np.arange(math.prod(shape), dtype=dtype).reshape(shape) @@ -2204,6 +2854,7 @@ def kernel(ctx, dst, _): num_col_tiles=[1, 2, 3], row_tiling=[8, 64], ) + @jtu.thread_unsafe_test() # Modifies ``os.environ``. def test_copy_tiled(self, dtype, swizzle, num_col_tiles, row_tiling): mlir_dtype = utils.dtype_to_ir_type(dtype) bw = bytewidth(mlir_dtype) @@ -2229,7 +2880,7 @@ def kernel(ctx, in_, out, smems): .transpose(0, 2, 1, 3) ) - with get_sass() as sass: + with jtu.set_env(MOSAIC_GPU_DUMP_SASS="1"), self.capture_stdout() as sass: iota = mgpu.as_gpu_kernel( kernel, (1, 1, 1), (128, 1, 1), expected, expected, [expected, expected, mgpu.TMABarrier()], @@ -2328,6 +2979,7 @@ def kernel(ctx, in_, out, smems): (fa.WGMMA_LAYOUT_UPCAST_2X, fa.WGMMA_LAYOUT, jnp.int4, jnp.int4, 0.5), (fa.WGMMA_LAYOUT_UPCAST_4X, fa.WGMMA_LAYOUT, jnp.int4, jnp.int4, 2), ) + @jtu.thread_unsafe_test() # Modifies ``os.environ``. def test_upcast_to_wgmma( self, start_layout, end_layout, in_dtype, cast_dtype, shfl_per_reg ): @@ -2371,8 +3023,9 @@ def tile(x, tiling): f = mgpu.as_gpu_kernel( kernel, (1, 1, 1), (128, 1, 1), xt, yt, [xt, yt, mgpu.TMABarrier()], ) - with get_sass() as sass: + with jtu.set_env(MOSAIC_GPU_DUMP_SASS="1"), self.capture_stdout() as sass: yt_kernel = f(xt) + jax.block_until_ready(yt_kernel) np.testing.assert_array_equal(yt_kernel, yt) self.assertEqual(sass().count("SHFL.BFLY"), regs_per_thread * shfl_per_reg) @@ -2430,7 +3083,7 @@ def set_in_transforms( in_transforms = [] smem_refs = filter(inference_utils.is_transformable_smem_memref, op.operands) # pylint: disable=undefined-variable - for _, result_transforms in jax.util.safe_zip(smem_refs, transforms): + for _, result_transforms in jax._src.util.safe_zip(smem_refs, transforms): in_transforms.append( ir.ArrayAttr.get([t.attr() for t in result_transforms]) ) @@ -2476,7 +3129,7 @@ def add(ctx, a, b, result, smem): in_shape=(jax_shape, jax_shape), out_shape=jax_shape, smem_scratch_shape=[], - thread_semantics=mgpu.ThreadSemantics.Warpgroup, + thread_semantics=mgpu.LoweringSemantics.Warpgroup, ) x = self.prng.uniform(-1, 1, shape).astype(dtype) @@ -2568,7 +3221,7 @@ def add( ): del ctx smem_ref, tma_barrier = smem - dialect_barrier = tma_barrier.as_dialect_barrier_memref() + dialect_barrier = tma_barrier.as_barrier_memref() elt_type = ir.MemRefType(in_gmem_ref.type).element_type memref_bytes = utils.bytewidth(elt_type) * math.prod( @@ -2592,7 +3245,7 @@ def add( ) set_in_transforms(load_op, [test_case.transforms]) - parities = memref.load(tma_barrier.phases, []) + parities = memref.load(tma_barrier.barrier_ref.phases, []) parity, _ = tma_barrier.update_parities(parities) mgpu_dialect.wait(dialect_barrier, 
parity) @@ -2621,7 +3274,7 @@ def add( jax_shape_sliced, core.TMABarrier(1), ], - thread_semantics=mgpu.ThreadSemantics.Warpgroup, + thread_semantics=mgpu.LoweringSemantics.Warpgroup, ) x = self.prng.uniform(-1, 1, test_case.shape).astype(dtype) @@ -2645,7 +3298,7 @@ def add( ): del ctx a_smem_ref, b_smem_ref, result_smem_ref, tma_barrier = smem - dialect_barrier = tma_barrier.as_dialect_barrier_memref() + dialect_barrier = tma_barrier.as_barrier_memref() memref_type = ir.MemRefType(a_gmem_ref.type) shape = memref_type.shape @@ -2677,7 +3330,7 @@ def add( collective=ir.ArrayAttr.get([]), ) - parities = memref.load(tma_barrier.phases, []) + parities = memref.load(tma_barrier.barrier_ref.phases, []) parity, _ = tma_barrier.update_parities(parities) mgpu_dialect.wait(dialect_barrier, parity) @@ -2720,7 +3373,7 @@ def add( spec, core.TMABarrier(1), ], - thread_semantics=mgpu.ThreadSemantics.Warpgroup, + thread_semantics=mgpu.LoweringSemantics.Warpgroup, ) x = self.prng.uniform(-1, 1, spec.shape).astype(dtype) @@ -2728,6 +3381,340 @@ def add( self.assertArraysEqual(jax.jit(kernel)(x, y), x + y + y) + @parameterized.parameters( + ((64,), (64, 128), [0]), + ((64,), (128, 64), [1]), + ) + def test_broadcast_in_dim(self, input_shape, output_shape, bcast_dims): + element_value = 42.0 + def body(ctx, result_gmem_ref, smem): + del ctx + result_smem_ref = smem[0] + + f32 = ir.F32Type.get() + zero_index = arith.constant(ir.IndexType.get(), 0) + + # Create input in registers + x_type = ir.VectorType.get(input_shape, f32) + c = arith.constant(f32, element_value) + x = vector.splat(x_type, c) + + # Computation + out_type = ir.VectorType.get(output_shape, f32) + expanded = mgpu_dialect.broadcast_in_dim(out_type, x, bcast_dims) + cast = mgpu_dialect.layout_cast( + expanded, layouts.to_layout_attr(fa.WGMMA_LAYOUT) + ) + + # Registers -> SMEM + vector.store(cast, result_smem_ref, [zero_index] * len(output_shape)) + + # SMEM -> GMEM + zero_i32 = arith.constant(ir.IntegerType.get_signless(32), 0) + mgpu_dialect.async_store( + source=result_smem_ref, + destination=result_gmem_ref, + indices=[zero_i32] * len(output_shape), + slice_lengths=output_shape, + ) + nvvm.cp_async_bulk_wait_group(0) + utils.warpgroup_barrier() + + dtype = jnp.float32 + kernel = mgpu.as_gpu_kernel( + body, + grid=(1, 1, 1), + block=(128, 1, 1), + in_shape=(), + out_shape=jax.ShapeDtypeStruct(output_shape, dtype), + smem_scratch_shape=[jax.ShapeDtypeStruct(output_shape, dtype)], + thread_semantics=mgpu.LoweringSemantics.Warpgroup, + ) + + x = np.full(input_shape, element_value, dtype=dtype) + self.assertArraysEqual( + jax.jit(kernel)(), jax.lax.broadcast_in_dim(x, output_shape, bcast_dims) + ) + + @parameterized.parameters( + (jnp.float32, 5.0, 2.0, vector.CombiningKind.ADD), + (jnp.float32, 5.0, 2.0, vector.CombiningKind.MAXIMUMF), + (jnp.float32, 5.0, 7.0, vector.CombiningKind.MAXIMUMF), + (jnp.int32, 5, 2, vector.CombiningKind.MAXSI), + (jnp.int32, -5, -2, vector.CombiningKind.MAXSI), + (jnp.int32, -2, -5, vector.CombiningKind.MAXSI), + (jnp.uint32, 5, 2, vector.CombiningKind.MAXUI), + (jnp.uint32, 2, 5, vector.CombiningKind.MAXUI), + # + # TODO(dasenov): Add tests for wgmma_col_layout output once + # fragmented_array.reduce supports that. 
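      # For reference, vector.multi_reduction folds the accumulator into the
      # reduced value, so with src of shape (128, 64) and acc of shape (128,)
      # the expected outputs checked at the end of this test are, in NumPy
      # terms:
      #   ADD kinds:  out = src.sum(axis=1) + acc
      #   MAX kinds:  out = np.maximum(src.max(axis=1), acc)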
+ ) + def test_vector_multi_dim_reduction( + self, + dtype, + input_value, + init_value, + kind, + ): + input_shape = (128, 64) + output_shape = (128,) + red_dims = [1] + + def body(ctx, result_gmem_ref, smem): + del ctx + result_smem_ref = smem[0] + + el_type = utils.dtype_to_ir_type(dtype) + zero_index = arith.constant(ir.IndexType.get(), 0) + + # Create source in registers + source_type = ir.VectorType.get(input_shape, el_type) + c = arith.constant(el_type, input_value) + source = vector.splat(source_type, c) + + # Create accumulator in registers + acc_type = ir.VectorType.get(output_shape, el_type) + c = arith.constant(el_type, init_value) + acc = vector.splat(acc_type, c) + + # Cast inputs + source = mgpu_dialect.layout_cast( + source, layouts.to_layout_attr(fa.WGMMA_LAYOUT) + ) + acc_layout = ( + fa.WGMMA_ROW_LAYOUT if red_dims[0] == 1 else fa.WGMMA_COL_LAYOUT + ) + acc = mgpu_dialect.layout_cast(acc, layouts.to_layout_attr(acc_layout)) + + # Computation + reduced = vector.multi_reduction(kind, source, acc, red_dims) + + # Registers -> SMEM + vector.store(reduced, result_smem_ref, [zero_index] * len(output_shape)) + + # SMEM -> GMEM + zero_i32 = arith.constant(ir.IntegerType.get_signless(32), 0) + mgpu_dialect.async_store( + source=result_smem_ref, + destination=result_gmem_ref, + indices=[zero_i32] * len(output_shape), + slice_lengths=output_shape, + ) + nvvm.cp_async_bulk_wait_group(0) + utils.warpgroup_barrier() + + kernel = mgpu.as_gpu_kernel( + body, + grid=(1, 1, 1), + block=(128, 1, 1), + in_shape=(), + out_shape=jax.ShapeDtypeStruct(output_shape, dtype), + smem_scratch_shape=[jax.ShapeDtypeStruct(output_shape, dtype)], + thread_semantics=mgpu.LoweringSemantics.Warpgroup, + ) + + source = np.full(input_shape, input_value, dtype=dtype) + acc = np.full(output_shape, init_value, dtype=dtype) + if kind == vector.CombiningKind.ADD: + red = jax.lax.reduce_sum(source, red_dims) + red = red + acc + else: + red = jax.lax.reduce_max(source, red_dims) + red = jax.lax.max(red, acc) + self.assertArraysEqual(jax.jit(kernel)(), red) + + @parameterized.parameters(fa.WGMMA_ROW_LAYOUT, fa.WGMMA_COL_LAYOUT) + def test_wgmma_row_col_store(self, in_layout): + element_value = 42.0 + shape = (64, ) + def body(ctx, result_gmem_ref, smem): + del ctx + result_smem_ref = smem[0] + + f32 = ir.F32Type.get() + zero_index = arith.constant(ir.IndexType.get(), 0) + + # Create input in registers + x_type = ir.VectorType.get(shape, f32) + c = arith.constant(f32, element_value) + x = vector.splat(x_type, c) + cast = mgpu_dialect.layout_cast(x, layouts.to_layout_attr(in_layout)) + + # Registers -> SMEM + vector.store(cast, result_smem_ref, [zero_index]) + + # SMEM -> GMEM + zero_i32 = arith.constant(ir.IntegerType.get_signless(32), 0) + mgpu_dialect.async_store( + source=result_smem_ref, + destination=result_gmem_ref, + indices=[zero_i32], + slice_lengths=shape, + ) + nvvm.cp_async_bulk_wait_group(0) + utils.warpgroup_barrier() + + dtype = jnp.float32 + kernel = mgpu.as_gpu_kernel( + body, + grid=(1, 1, 1), + block=(128, 1, 1), + in_shape=(), + out_shape=jax.ShapeDtypeStruct(shape, dtype), + smem_scratch_shape=[jax.ShapeDtypeStruct(shape, dtype)], + thread_semantics=mgpu.LoweringSemantics.Warpgroup, + ) + + x = np.full(shape, element_value, dtype=dtype) + self.assertArraysEqual(jax.jit(kernel)(), x) + + @parameterized.parameters( + # Positive offsets will be passsed as static offsets. + # Negative offsets will be converted to positive dynamic offsets. 
+ ((2, 3, 128, 64), (32, 64), [-1, 0, -96, 0], None, None, None), + ( + (3, 128, 64), + (32, 64), + [-2, -96, 0], + [32, 64], + mgpu_dialect.SwizzlingMode.k128ByteSwizzle, + None, + ), + ( + (128, 128), + (64,), + [-1, 64], + [64], + mgpu_dialect.SwizzlingMode.k128ByteSwizzle, + "Swizzle transforms .* if the minor dimension is unchanged.", + ), + ) + def test_subview( + self, + full_shape, + sub_shape, + offsets, + tiling, + swizzle, + error_regex, + ): + assert len(sub_shape) <= 2 + sizes = [1] * (len(full_shape) - len(sub_shape)) + list(sub_shape) + + def body( + ctx: launch_context.LaunchContext, + full_gmem_ref: ir.Value, + sub_gmem_ref: ir.Value, + smem: list[ir.Value], + ): + del ctx + full_smem_ref, tma_barrier = smem + dialect_barrier = tma_barrier.as_barrier_memref() + + operand_elt_type = ir.MemRefType(full_gmem_ref.type).element_type + mgpu_dialect.arrive_expect_tx( + barrier=dialect_barrier, + expect_tx=utils.bytewidth(operand_elt_type) * math.prod(full_shape), + ) + + zero_i32 = arith.constant(ir.IntegerType.get_signless(32), 0) + # GMEM -> SMEM + mgpu_dialect.async_load( + source=full_gmem_ref, + destination=full_smem_ref, + barrier=dialect_barrier, + indices=[zero_i32] * len(full_shape), + slice_lengths=full_shape, + collective=ir.ArrayAttr.get([]), + ) + + parities = memref.load(tma_barrier.barrier_ref.phases, []) + parity, _ = tma_barrier.update_parities(parities) + mgpu_dialect.wait(dialect_barrier, parity) + + # SubView + dynamic_offsets = [ + arith.constant(ir.IndexType.get(), -o) for o in offsets if o < 0 + ] + + full_ref_type = ir.MemRefType(full_smem_ref.type) + dynamic = ir.ShapedType.get_dynamic_stride_or_offset() + rhs_subview_ref_type = ir.MemRefType.get( + shape=sub_shape, + element_type=full_ref_type.element_type, + layout=ir.StridedLayoutAttr.get( + dynamic, [full_shape[-1], 1] if len(sub_shape) == 2 else [1] + ), + memory_space=full_ref_type.memory_space, + ) + sub_smem_ref = memref.SubViewOp( + result=rhs_subview_ref_type, + source=full_smem_ref, + offsets=dynamic_offsets, + sizes=None, + strides=None, + static_offsets=[(dynamic if o < 0 else o) for o in offsets], + static_sizes=sizes, + static_strides=[1] * len(sizes), + ).result + + transforms = [] + if tiling is not None: + transforms.append(mgpu_dialect.TileTransformAttr.get(tiling)) + if swizzle is not None: + transforms.append(mgpu_dialect.SwizzleTransformAttr.get(swizzle)) + + if transforms: + # TODO(dasenov): Remove this after the minimal jaxlib version is 0.6.2. + if jaxlib.version < (0, 6, 2): + self.skipTest("Test requires jaxlib version >= 0.6.2") + + sub_smem_ref = mgpu_dialect.with_transforms( + sub_smem_ref, + transforms=ir.ArrayAttr.get(transforms), + ) + + # SMEM -> GMEM + mgpu_dialect.async_store( + source=sub_smem_ref, + destination=sub_gmem_ref, + indices=[zero_i32] * len(sub_shape), + slice_lengths=sub_shape, + ) + nvvm.cp_async_bulk_wait_group(0) + + el_type = jnp.bfloat16 + full_jax_shape = jax.ShapeDtypeStruct(full_shape, el_type) + result_jax_shape = jax.ShapeDtypeStruct(sub_shape, el_type) + + def create_kernel(): + return mgpu.as_gpu_kernel( + body, + grid=(1, 1, 1), + block=(128, 1, 1), + in_shape=(full_jax_shape), + out_shape=result_jax_shape, + smem_scratch_shape=[full_jax_shape, core.TMABarrier(1)], + thread_semantics=mgpu.LoweringSemantics.Warpgroup, + ) + + if error_regex: + with self.assertRaisesRegex(NotImplementedError, error_regex): + # While we expect NotImplementedError here, the test is actually + # checking a restricted behaviour that should be a ValueError. 
However, + # our code cannot yet figure out the difference and raise the correct + # type. + create_kernel() + else: + prng_key = jax.random.key(1234) + x = jax.random.randint(prng_key, full_shape, 0, 10).astype(el_type) + + slicing = tuple(slice(abs(o), abs(o) + s) for o, s in zip(offsets, sizes)) + self.assertArraysEqual( + jax.jit(create_kernel())(x), + x[slicing].reshape(sub_shape), + ) + class MosaicGpuDialectSm90ATest(Sm90ATestCase, jtu.JaxTestCase): @@ -2750,8 +3737,8 @@ def test_wgmma_kernel_with_tma( if swizzle == mgpu_dialect.SwizzlingMode.kNoSwizzle: self.skipTest("No swizzle is not supported by wgmma") - if transpose_lhs or transpose_rhs: - self.skipTest("Transposes are not supported by transform inference yet.") + if transpose_lhs and load_a_in_registers: + self.skipTest("The A operand can only be transposed if it is in SMEM.") swizzle_elems = swizzle // np.dtype(jnp.bfloat16).itemsize tiling_m, tiling_n, tiling_k = 64, swizzle_elems, swizzle_elems @@ -2772,7 +3759,7 @@ def matmul( ): del ctx lhs_smem_ref, rhs_smem_ref, result_smem_ref, tma_barrier = smem - dialect_barrier = tma_barrier.as_dialect_barrier_memref() + dialect_barrier = tma_barrier.as_barrier_memref() operand_elt_type = ir.MemRefType(lhs_gmem_ref.type).element_type bytes_a = utils.bytewidth(operand_elt_type) * math.prod(lhs_shape) @@ -2802,7 +3789,7 @@ def matmul( collective=ir.ArrayAttr.get([]), ) - parities = memref.load(tma_barrier.phases, []) + parities = memref.load(tma_barrier.barrier_ref.phases, []) parity, _ = tma_barrier.update_parities(parities) mgpu_dialect.wait(dialect_barrier, parity) @@ -2868,7 +3855,7 @@ def matmul( result_jax_shape, core.TMABarrier(1), ], - thread_semantics=mgpu.ThreadSemantics.Warpgroup, + thread_semantics=mgpu.LoweringSemantics.Warpgroup, ) prng_key = jax.random.key(1234) @@ -2916,6 +3903,34 @@ def test_parse_indices_oob(self, indices): with self.assertRaisesRegex(IndexError, "out of bounds"): utils.parse_indices(indices, (2, 3, 4)) + @jtu.thread_unsafe_test() # Modifies ``os.environ``. + def test_assert(self): + if cf is None: + self.skipTest("``cf`` is not available") + + def kernel(ctx: mgpu.LaunchContext, x_ref, out, scratch) -> None: + del ctx, out # Unused. + # TODO(b/408271232): Use a False condition once the bug is fixed. + x = mgpu.FragmentedArray.load_strided(x_ref) + cond = x.reduce("add", 0, *scratch) != 42.0 + cf.assert_(cond.registers.item(), "OOOPS") + + f = mgpu.as_gpu_kernel( + kernel, + grid=(1, 1, 1), + block=(128, 1, 1), + in_shape=(jax.ShapeDtypeStruct((128,), jnp.float32),), + out_shape=jax.ShapeDtypeStruct((128,), jnp.float32), + smem_scratch_shape=(jax.ShapeDtypeStruct((4,), jnp.float32),), + ) + + with jtu.set_env(MOSAIC_GPU_DUMP_SASS="1"), self.capture_stdout() as sass: + jax.block_until_ready(f(jnp.ones((128,), jnp.float32))) + + # SASS doesn't seem to include the assertion message, so we are just + # checking that __assertfail appears in the symbol table for the kernel. 
+ self.assertIn("__assertfail", sass()) + class SerializationTest(absltest.TestCase): @@ -2931,5 +3946,187 @@ def test_pass_is_registered(self): pipeline.run(module.operation) +class ApiTest(TestCase): + + def test_inout(self): + def kernel(ctx, src, inout, dst, smem): + val = memref.load(inout, []) + gpu.barrier() + new_val = arith.constant(ir.IntegerType.get_signless(32), 42) + memref.store(new_val, inout, []) + x = mgpu.FragmentedArray.load_strided(src, is_signed=True) + (x + val).store_untiled(dst) + x = jnp.arange(128, dtype=jnp.int32) + y = jnp.asarray(2.0, dtype=jnp.int32) + kernel = mgpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), x, x, (), inout_shape=y, + ) + xo, yo = kernel(x, y) + np.testing.assert_array_equal(xo, x + 2.0) + np.testing.assert_array_equal(yo, jnp.asarray(42, dtype=jnp.int32)) + + +if hp is not None: + @hps.composite + def tiled_layouts( + draw, initial_tile, vector_transfer: bool = False + ) -> fa.TiledLayout: + assert all(t.bit_count() == 1 for t in initial_tile) + assert math.prod(initial_tile) >= 128 + tiles = [initial_tile] + dim_offset = len(initial_tile) + warp_dim = fa.Replicated(4) + if draw(hps.booleans()): + warp_dim = draw( + hps.sampled_from( + [i for i, t in enumerate(tiles[-1]) if t % 4 == 0] + ) + ) + warp_tile = list(tiles[-1]) + warp_tile[warp_dim] //= 4 + warp_dim += dim_offset + tiles.append(warp_tile) + dim_offset += len(tiles[-1]) + lane_dims = [fa.Replicated(2) if draw(hps.booleans()) else None for _ in range(5)] + for i, dim in enumerate(lane_dims): + if isinstance(dim, fa.Replicated): + continue + lane_dim = draw(hps.sampled_from( + [i for i, t in enumerate(tiles[-1]) if t % 2 == 0] + )) + lane_tile = list(tiles[-1]) + lane_tile[lane_dim] //= 2 + lane_dims[i] = dim_offset + lane_dim + tiles.append(lane_tile) + dim_offset += len(lane_tile) + # Permute lane dims so that they don't always partition the data in order. + lane_dims = draw(hps.permutations(lane_dims)) + if vector_transfer: + min_vector_dim = len(tiles[-1]) - 1 + else: + min_vector_dim = 0 + vector_dim = draw(hps.integers(min_vector_dim, len(tiles[-1]) - 1)) + vector_size = 2 ** draw( + hps.integers(0, tiles[-1][vector_dim].bit_length() - 1) + ) + vector_tile = list(tiles[-1]) + assert vector_tile[vector_dim] % vector_size == 0 + vector_tile[vector_dim] //= vector_size + tiles.append(vector_tile) + dim_offset += len(vector_tile) + vector_dim += dim_offset + dim_offset += len(vector_tile) # This is the remainder after tiling! 
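    # The adjustments below convert warp_dim, lane_dims and vector_dim from
    # the absolute positions accumulated via dim_offset while nesting tiles
    # into indices counted from the end of the fully tiled shape (negative
    # indices), which is the form passed to fa.TiledLayout; Replicated dims
    # are left untouched.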
+ + if not isinstance(warp_dim, fa.Replicated): + warp_dim = warp_dim - dim_offset + lane_dims = tuple( + d if isinstance(d, fa.Replicated) else d - dim_offset + for d in lane_dims + ) + vector_dim = vector_dim - dim_offset + return fa.TiledLayout( + tiling=fa.Tiling(tuple(map(tuple, tiles))), + warp_dim=warp_dim, + lane_dims=lane_dims, + vector_dim=vector_dim, + _check_canonical=False, + ).canonicalize() + + @hps.composite + def shape_and_tiled_layout( + draw, vector_transfer: bool = False + ) -> tuple[tuple[int, ...], fa.TiledLayout]: + rank = draw(hps.integers(2, 3)) + initial_tile = tuple( + draw(hps.sampled_from([1, 2, 4, 8, 16, 32, 64, 128])) + for _ in range(rank) + ) + hp.assume(128 <= math.prod(initial_tile) < 128 * 32) + shape = tuple(t * draw(hps.integers(1, 5)) for t in initial_tile) + hp.assume(math.prod(shape) <= 128 * 128) + layout = draw(tiled_layouts(initial_tile, vector_transfer=vector_transfer)) + return shape, layout + + class HypothesisTest(TestCase): + + def test_reduce(self): + @hps.composite + def strategy(draw): + shape, layout = draw(shape_and_tiled_layout(vector_transfer=True)) + rank = len(shape) + reduced_dims = draw(hps.sets(hps.integers(0, rank - 1), min_size=1)) + dtype = draw(hps.sampled_from([jnp.float32, jnp.bfloat16])) + return shape, layout, tuple(reduced_dims), dtype + + @hp.given(strategy()) + def run(args): + shape, layout, reduced_dims, dtype = args + out_shape = list(shape) + for d in sorted(reduced_dims, reverse=True): + del out_shape[d] + def kernel(ctx, src, dst, scratch): + del ctx + arr = fa.FragmentedArray.load_untiled(src, layout=layout, optimized=False) + arr.reduce("max", reduced_dims, scratch).store_untiled(dst, optimized=False) + x = jax.random.normal(jax.random.key(1234), shape, dtype) + out_type = jax.ShapeDtypeStruct(out_shape, dtype) + scratch_type = jax.ShapeDtypeStruct((2048,), dtype) + hp.assume(layout.vector_length <= 16) # Otherwise we run out of scratch + try: + result = mgpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), x, out_type, scratch_type + )(x) + except NotImplementedError: + hp.assume(False) + return + np.testing.assert_array_equal(result, x.max(reduced_dims)) + run() + + def test_slice(self): + i32 = ir.IntegerType.get_signless(32) + index = ir.IndexType.get() + + @hps.composite + def strategy(draw): + shape, layout = draw(shape_and_tiled_layout(vector_transfer=True)) + tiling = layout.base_tile_shape + tiled_shape = mgpu.tile_shape(shape, tiling)[:len(shape)] + def draw_slice(size, tile): + start = draw(hps.integers(0, size - 1)) + length = draw(hps.integers(1, size - start)) + return slice(start * tile, (start + length) * tile) + slices = tuple(map(draw_slice, tiled_shape, tiling)) + return shape, layout, slices + + basic_slices = (slice(128, 256), slice(16, 16 + 32)) + @hp.given(strategy()) + @hp.example(((256, 256), fa.WGMMA_LAYOUT, basic_slices)) + @hp.example(((256, 256), tcgen05.LAYOUT, basic_slices)) + @hp.example(((256, 256), tcgen05.TMEM_NATIVE_LAYOUT, basic_slices)) + def run(args): + shape, layout, slices = args + def kernel(ctx, dst, _): + def linear_index(*idxs): + total = arith.constant(index, 0) + stride = 1 + for i, size in zip(idxs[::-1], shape[::-1]): + total = arith.addi(total, arith.muli(i, c(stride, index))) + stride *= size + return arith.index_cast(i32, total) + x = mgpu.FragmentedArray.build( + shape, layout, linear_index, is_signed=True + ) + x[slices].store_untiled(dst, optimized=False) + + slice_shape = tuple(len(range(size)[s]) for s, size in zip(slices, shape)) + out_shape = 
jax.ShapeDtypeStruct(shape=slice_shape, dtype=jnp.int32) + result = mgpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), (), out_shape, () + )() + iota = np.arange(np.prod(shape), dtype=jnp.int32).reshape(*shape) + np.testing.assert_array_equal(result, iota[slices]) + run() + + if __name__ == "__main__": absltest.main(argv=["python"], testLoader=jtu.JaxTestLoader()) diff --git a/tests/mosaic/gpu_test_distributed.py b/tests/mosaic/gpu_test_distributed.py new file mode 100644 index 000000000000..c289b27c0be1 --- /dev/null +++ b/tests/mosaic/gpu_test_distributed.py @@ -0,0 +1,164 @@ +# Copyright 2025 The JAX Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import os + +from absl.testing import parameterized +import jax +from jax._src import config +from jax._src import test_util as jtu +from jax._src import test_multiprocess as jt_multiprocess +from jax._src.interpreters import mlir +from jax._src.lib.mlir import ir +from jax._src.lib.mlir.dialects import arith +from jax._src.lib.mlir.dialects import memref +from jax.experimental.mosaic.gpu import dialect as mgpu_dialect # pylint: disable=g-importing-member +from jax.experimental import shard +from jax.experimental import multihost_utils +import jax.numpy as jnp +import numpy as np +try: + import jax._src.lib.mosaic_gpu # noqa: F401 + HAS_MOSAIC_GPU = True +except ImportError: + HAS_MOSAIC_GPU = False +else: + import jax.experimental.mosaic.gpu as mgpu + + +# ruff: noqa: F405 +# pylint: disable=g-complex-comprehension +P = jax.sharding.PartitionSpec + + +class TestCase(parameterized.TestCase): + + def setUp(self): + if not HAS_MOSAIC_GPU: + self.skipTest("jaxlib built without Mosaic GPU") + if (not jtu.test_device_matches(["cuda"]) or + not jtu.is_cuda_compute_capability_at_least("9.0")): + self.skipTest("Only works on GPU with capability >= sm90") + if not mgpu.supports_cross_device_collectives(): + self.skipTest("NVSHMEM library unavailable.") + if os.environ.get("XLA_PYTHON_CLIENT_ALLOCATOR", "") == "platform": + self.skipTest("NVSHMEM doesn't work with the platform allocator.") + if jax.process_count() == 1: + self.skipTest("Test requires multiple processes.") + if jax.device_count() != jax.process_count(): + self.skipTest("Need 1 device per process") + super().setUp() + self.prng = np.random.default_rng(1234) + self.context = mlir.make_ir_context() + if mgpu_dialect is not None: + mgpu_dialect.register_dialect(self.context) + self.enter_context(config.traceback_filtering("off")) + self.enter_context(self.context) + self.enter_context(ir.Location.unknown()) + + +class ProfilerTest(TestCase): + + def test_get_device_id(self): + index = ir.IndexType.get() + def kernel(ctx, dst, _): + device_id = ctx.device_id() + memref.store(device_id, dst, [arith.constant(index, 0)]) + mesh = jax.make_mesh( + (jax.device_count(),), ("x",), axis_types=(jax.sharding.AxisType.Explicit,) + ) + with jax.sharding.use_mesh(mesh): + out_shape = 
jax.ShapeDtypeStruct((1,), jnp.int32) + y = jax.jit( + jax.shard_map( + mgpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), (), out_shape, () + ), + out_specs=P("x"), + check_vma=False, + ) + )() + y_np = multihost_utils.process_allgather(y, tiled=True) + np.testing.assert_array_equal(y_np, np.arange(jax.device_count())) + + def test_remote_async_copy(self): + i32 = ir.IntegerType.get_signless(32) + def kernel(ctx, src, dst, scratch): + tmp, barrier = scratch + other_device = arith.subi(arith.constant(i32, 1), ctx.device_id()) + ctx.async_copy(src_ref=src, dst_ref=tmp, barrier=barrier) + barrier.wait() + ctx.async_copy(src_ref=tmp, dst_ref=dst, gmem_peer_id=other_device) + ctx.await_async_copy(0) + mesh = jax.make_mesh( + (2,), ("x",), axis_types=(jax.sharding.AxisType.Explicit,) + ) + with jax.sharding.use_mesh(mesh): + x_np = np.arange(64 * 64, dtype=jnp.float32).reshape(64, 64) + x = shard.reshard(x_np, P("x")) + y = jax.jit( + jax.shard_map( + lambda x: mgpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), x, x, (x, mgpu.TMABarrier()) + )(x), + out_specs=P("x"), + check_vma=False, + ) + )(x) + y_np = multihost_utils.process_allgather(y, tiled=True) + np.testing.assert_array_equal( + y_np, np.concatenate(np.split(x_np, 2)[::-1], axis=0) + ) + + def test_remote_semaphore(self): + i32 = ir.IntegerType.get_signless(32) + def kernel(ctx, sem, _): + my_device = ctx.device_id() + other_device = arith.subi(arith.constant(i32, 1), my_device) + my_sem = mgpu.SemaphoreRef(mgpu.utils.memref_ptr(sem)) + other_dst = ctx.to_remote(sem, other_device) + other_sem = mgpu.SemaphoreRef(mgpu.utils.memref_ptr(other_dst)) + # We signal and wait a different amount on each device to make sure we're + # really communicating here. + other_sem.signal(arith.addi(arith.constant(i32, 1), other_device)) + @mgpu.fori(arith.addi(arith.constant(i32, 1), my_device), None) + def wait_loop(i, _): + my_sem.wait(1) + + mesh = jax.make_mesh( + (2,), ("x",), axis_types=(jax.sharding.AxisType.Explicit,) + ) + with jax.sharding.use_mesh(mesh): + sem = shard.reshard(jnp.zeros((1,), dtype=jnp.int32), P()) + out_sem = jax.jit( + jax.shard_map( + mgpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), (), (), (), inout_shape=sem + ), + out_specs=P("x"), + check_vma=False, + ) + )(sem) + out_sems = multihost_utils.process_allgather(out_sem, tiled=True) + np.testing.assert_array_equal(out_sems, np.zeros_like(out_sems)) + + +if __name__ == "__main__": + # This test doesn't work with the platform allocator, so we override it + # if it's ran alone. If it's part of a larger test suite and the platform + # allocator is used, setUp will skip the test. + os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.01' + os.environ['XLA_PYTHON_CLIENT_ALLOCATOR'] = 'default' + jt_multiprocess.main() diff --git a/tests/mosaic/gpu_test_multidevice.py b/tests/mosaic/gpu_test_multidevice.py new file mode 100644 index 000000000000..114409a5efd8 --- /dev/null +++ b/tests/mosaic/gpu_test_multidevice.py @@ -0,0 +1,74 @@ +# Copyright 2025 The JAX Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from absl.testing import absltest, parameterized +import jax +from jax._src import config +from jax._src import test_util as jtu +from jax._src.interpreters import mlir +from jax._src.lib.mlir import ir +from jax.experimental.mosaic.gpu import dialect as mgpu_dialect # pylint: disable=g-importing-member +import jax.numpy as jnp +import numpy as np +try: + import jax._src.lib.mosaic_gpu # noqa: F401 + HAS_MOSAIC_GPU = True +except ImportError: + HAS_MOSAIC_GPU = False +else: + import jax.experimental.mosaic.gpu as mgpu + + +# ruff: noqa: F405 +# pylint: disable=g-complex-comprehension +config.parse_flags_with_absl() + + +class TestCase(parameterized.TestCase): + + def setUp(self): + if not HAS_MOSAIC_GPU: + self.skipTest("jaxlib built without Mosaic GPU") + if (not jtu.test_device_matches(["cuda"]) or + not jtu.is_cuda_compute_capability_at_least("9.0")): + self.skipTest("Only works on GPU with capability >= sm90") + super().setUp() + self.prng = np.random.default_rng(1234) + self.context = mlir.make_ir_context() + if mgpu_dialect is not None: + mgpu_dialect.register_dialect(self.context) + self.enter_context(config.traceback_filtering("off")) + self.enter_context(self.context) + self.enter_context(ir.Location.unknown()) + + +class ProfilerTest(TestCase): + + def test_multigpu(self): + if len(jax.devices()) < 2: + self.skipTest("Need at least 2 devices") + def kernel(ctx, src, dst, _): + mgpu.FragmentedArray.load_strided(src).store_untiled(dst) + x = np.arange(64 * 64, dtype=jnp.float32).reshape(64, 64) + f = jax.jit(mgpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), x, x, () + )) + # Make sure we can invoke the same program on different devices. 
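# (Note, as a reading of the loop below: each jax.device_put commits the input to a
# different device, so the jitted wrapper is dispatched once per placement; the point,
# presumably, is that the compiled Mosaic GPU kernel is not tied to the device it was
# first built for.)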
+ for xd in (jax.device_put(x, d) for d in jax.devices()[:2]): + jax.block_until_ready(f(xd)) + + +if __name__ == "__main__": + absltest.main(argv=["python"], testLoader=jtu.JaxTestLoader()) diff --git a/tests/mosaic/gpu_transform_inference_test.py b/tests/mosaic/gpu_transform_inference_test.py index b7cd146dfdb6..b4a41a141bf6 100644 --- a/tests/mosaic/gpu_transform_inference_test.py +++ b/tests/mosaic/gpu_transform_inference_test.py @@ -19,12 +19,14 @@ from absl.testing import parameterized import jax from jax import numpy as jnp +from jax._src import lib as jaxlib from jax._src import config from jax._src import test_util as jtu from jax._src.interpreters import mlir as mlir_interpreter from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import arith from jax._src.lib.mlir.dialects import func +from jax._src.lib.mlir.dialects import memref from jax._src.lib.mlir.dialects import vector import jax.experimental.mosaic.gpu as mgpu from jax.experimental.mosaic.gpu import fragmented_array as fa @@ -418,6 +420,251 @@ def body(offset): with self.assertRaisesRegex(NotImplementedError, "Conflicting transforms"): mgpu.infer_transforms(self.module) + @parameterized.parameters([False, True]) + def test_infer_transforms_for_subview_op_propagates_undisturbed_tile_and_swizzle_transforms( + self, annotate_input + ): + subview_op = user_op = None + shape = (2, 64, 64) + elt_ty = ir.BF16Type.get() + smem = ir.Attribute.parse("#gpu.address_space") + + in_ref_ty = ir.MemRefType.get(shape, elt_ty, memory_space=smem) + out_ref_ty = ir.MemRefType.get(shape[2:], elt_ty, memory_space=smem) + + def body(in_ref): + nonlocal subview_op, user_op + subview_op = memref.SubViewOp( + out_ref_ty, + in_ref, + [], + [], + [], + static_offsets=[1, 0, 0], + static_sizes=[1, 64, 64], + static_strides=[1, 1, 1], + ) + user_op = memref.CastOp(out_ref_ty, subview_op.result) + + with ir.InsertionPoint(self.module.body): + f = func.FuncOp.from_py_func(in_ref_ty)(body).func_op + + transforms = ir.ArrayAttr.get([ + mgpu.dialect.TileTransformAttr.get((32, 16)), + mgpu.dialect.SwizzleTransformAttr.get(32), + ]) + + if annotate_input: + f.attributes["in_transforms"] = ir.ArrayAttr.get([transforms]) + else: + user_op.attributes["in_transforms"] = ir.ArrayAttr.get([transforms]) + + mgpu.infer_transforms(self.module) + + self.assertSequenceEqual( + inference_utils.in_transforms(subview_op), [transforms] + ) + self.assertSequenceEqual( + inference_utils.out_transforms(subview_op), [transforms] + ) + + def test_infer_transforms_sets_default_emptry_transforms(self): + async_load_op = None + shape = (64, 64) + elt_ty = ir.BF16Type.get() + + def body(gmem_ref, smem_ref, barrier): + nonlocal async_load_op + zero = arith.constant(ir.IntegerType.get_signless(32), 0) + async_load_op = mgpu.dialect.AsyncLoadOp( + source=gmem_ref, + destination=smem_ref, + barrier=barrier, + indices=[zero, zero], + slice_lengths=shape, + collective=ir.ArrayAttr.get([]), + ) + + with ir.InsertionPoint(self.module.body): + smem = ir.Attribute.parse("#gpu.address_space") + gmem_ty = ir.MemRefType.get(shape, elt_ty) + smem_ty = ir.MemRefType.get(shape, elt_ty, memory_space=smem) + barrier_ty = ir.Type.parse("!mosaic_gpu.barrier") + func.FuncOp.from_py_func(gmem_ty, smem_ty, barrier_ty)(body).func_op + + mgpu.infer_transforms(self.module) + [in_transform] = inference_utils.in_transforms(async_load_op) + self.assertSequenceEqual(in_transform, ir.ArrayAttr.get([])) + self.assertEmpty(inference_utils.out_transforms(async_load_op)) + + 
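# A minimal sketch, not the library's actual inference rule: the two SubViewOp tests in
# this file (the "propagates_undisturbed" case earlier and the "raises_on_disturbed"
# case below) appear to encode the requirement that a subview leave the minormost,
# swizzled dimension whole. The helper below is hypothetical and only mirrors the
# shapes used in these tests.
#
#   def minormost_dim_left_whole(src_shape, static_sizes):
#     return static_sizes[-1] == src_shape[-1]
#
#   minormost_dim_left_whole((2, 64, 64), [1, 64, 64])  # True  -> transforms propagate
#   minormost_dim_left_whole((2, 64, 64), [2, 64, 32])  # False -> NotImplementedError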
@parameterized.parameters([False, True]) + def test_infer_transforms_for_subview_op_raises_on_disturbed_transforms( + self, annotate_input + ): + subview_op = user_op = None + shape = (2, 64, 64) + elt_ty = ir.BF16Type.get() + smem = ir.Attribute.parse("#gpu.address_space<workgroup>") + + in_ref_ty = ir.MemRefType.get(shape, elt_ty, memory_space=smem) + out_ref_ty = ir.MemRefType.get((2, 64, 32), elt_ty, memory_space=smem) + + def body(in_ref): + nonlocal subview_op, user_op + subview_op = memref.SubViewOp( + out_ref_ty, + in_ref, + [], + [], + [], + static_offsets=[1, 0, 0], + static_sizes=[2, 64, 32], + static_strides=[1, 1, 1], + ) + user_op = memref.CastOp(out_ref_ty, subview_op.result) + + with ir.InsertionPoint(self.module.body): + f = func.FuncOp.from_py_func(in_ref_ty)(body).func_op + + transforms = ir.ArrayAttr.get([ + mgpu.dialect.TileTransformAttr.get((32, 16)), + mgpu.dialect.SwizzleTransformAttr.get(32), + ]) + + if annotate_input: + f.attributes["in_transforms"] = ir.ArrayAttr.get([transforms]) + else: + user_op.attributes["in_transforms"] = ir.ArrayAttr.get([transforms]) + + with self.assertRaises(NotImplementedError): + mgpu.infer_transforms(self.module) + + @parameterized.parameters([False, True]) + def test_infer_transforms_for_sibling_subviews_and_distant_op( + self, even_offsets + ): + # This test uses the following op tree extracted from this ragged dot + # kernel: + # https://github.com/jax-ml/jax/blob/main/jax/experimental/pallas/ops/gpu/ragged_dot_mgpu.py + # + # subview_op0 (slice = 64, 64) + # - subview_op1 (slice = 2, 64) + # - subview_op2 (slice = 4, 64, either at an even or odd offset) + # - subview_op3 (slice = 8, 64) + # - user_op0 (in_transforms = [tile(64, 64), swizzle(32)]) + # + # First the in_transforms of user_op0 have to be propagated up to + # subview_op0. Then they have to be propagated down and resolved. Finally + # all subview ops need to have the same transforms. + + # TODO(dasenov): Remove this after the minimal jaxlib version is 0.6.2.
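# (One way to read the `want` checked at the end of this test: the resolved leading
# tile extent must divide every sibling's leading slice size (64, 2, 4, 8) and static
# offset (0, 0, and 16 or 15), which gives 2 when the offsets are even and 1 when
# subview_op2 starts at row 15, while the trailing extent and swizzle stay at 64 and 32.)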
+ if jaxlib.version < (0, 6, 2): + self.skipTest("Test requires jaxlib version >= 0.6.2") + + subview_op0, subview_op1, subview_op2, subview_op3 = None, None, None, None + user_op0 = None + + source_shape = (64, 64) + elt_ty = ir.BF16Type.get() + smem = ir.Attribute.parse("#gpu.address_space") + source_ref_ty = ir.MemRefType.get(source_shape, elt_ty, memory_space=smem) + + slice1_shape = (2, 64) + slice2_shape = (4, 64) + slice3_shape = (8, 64) + + slice0_ref_ty = ir.MemRefType.get(source_shape, elt_ty, memory_space=smem) + slice1_ref_ty = ir.MemRefType.get(slice1_shape, elt_ty, memory_space=smem) + slice2_ref_ty = ir.MemRefType.get(slice2_shape, elt_ty, memory_space=smem) + slice3_ref_ty = ir.MemRefType.get(slice3_shape, elt_ty, memory_space=smem) + + def body(source_ref): + nonlocal subview_op0, subview_op1, subview_op2, subview_op3, user_op0 + + subview_op0 = memref.SubViewOp( + slice0_ref_ty, + source_ref, + [], # dynamic offsets + [], # dynamic sizes + [], # dynamic strides + static_offsets=[0, 0], + static_sizes=source_shape, + static_strides=[1, 1], + ) + + transforms_0 = ir.ArrayAttr.get([ + mgpu.dialect.TileTransformAttr.get((64, 64)), + mgpu.dialect.SwizzleTransformAttr.get(32), + ]) + user_op0 = mgpu.dialect.WithTransformsOp(subview_op0.result, transforms_0) + + subview_op1 = memref.SubViewOp( + slice1_ref_ty, + subview_op0, + [], # dynamic offsets + [], # dynamic sizes + [], # dynamic strides + static_offsets=[0, 0], + static_sizes=slice1_shape, + static_strides=[1, 1], + ) + + subview_op2 = memref.SubViewOp( + slice2_ref_ty, + subview_op0, + [], # dynamic offsets + [], # dynamic sizes + [], # dynamic strides + static_offsets=[16 if even_offsets else 15, 0], + static_sizes=slice2_shape, + static_strides=[1, 1], + ) + + # The following ops are just to test the dynamic offsets support. 
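# (For reference, the arith chain below evaluates to 64 - 32 = 32, then
# max(16, 32) = 32, 32 + 32 = 64, and 64 & 32 = 0, so subview_op3 starts at row 0 at
# runtime; what matters here is that the offset reaches transform inference as a
# dynamic value rather than a static one.)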
+ c = lambda x: arith.constant(ir.IntegerType.get_signless(32), x) + c64 = c(64) + c32 = c(32) + c16 = c(16) + subi = arith.subi(c64, c32) + maxsi = arith.maxsi(c16, subi) + addi = arith.addi(maxsi, subi) + andi = arith.andi(addi, maxsi) + idx = arith.index_cast(ir.IndexType.get(), andi) + subview_op3 = memref.SubViewOp( + slice3_ref_ty, + subview_op0, + [idx], # dynamic offsets + [], # dynamic sizes + [], # dynamic strides + static_offsets=[ir.ShapedType.get_dynamic_size(), 0], + static_sizes=slice3_shape, + static_strides=[1, 1], + ) + + with ir.InsertionPoint(self.module.body): + func.FuncOp.from_py_func(source_ref_ty)(body) + + mgpu.infer_transforms(self.module) + + want = ir.ArrayAttr.get([ + mgpu.dialect.TileTransformAttr.get((2 if even_offsets else 1, 64)), + mgpu.dialect.SwizzleTransformAttr.get(32), + ]) + + self.assertSequenceEqual(inference_utils.in_transforms(subview_op0), [want]) + self.assertSequenceEqual(inference_utils.out_transforms(subview_op0), [want]) + self.assertSequenceEqual(inference_utils.in_transforms(subview_op1), [want]) + self.assertSequenceEqual(inference_utils.out_transforms(subview_op1), [want]) + self.assertSequenceEqual(inference_utils.in_transforms(subview_op2), [want]) + self.assertSequenceEqual(inference_utils.out_transforms(subview_op2), [want]) + self.assertSequenceEqual(inference_utils.in_transforms(subview_op3), [want]) + self.assertSequenceEqual(inference_utils.out_transforms(subview_op3), [want]) + if __name__ == "__main__": parameterized.absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/mosaic/matmul_test.py b/tests/mosaic/matmul_test.py index d598d7d0c0ec..13082885710e 100644 --- a/tests/mosaic/matmul_test.py +++ b/tests/mosaic/matmul_test.py @@ -15,12 +15,19 @@ """Test different parameterizations of a matmul.""" import os -import unittest from absl.testing import absltest, parameterized from jax._src import config from jax._src import test_util as jtu +from jax._src.interpreters import mlir +from jax._src.lib.mlir import ir +from jax.experimental.mosaic.gpu import dialect as mgpu_dialect # pylint: disable=g-importing-member import jax.numpy as jnp +import numpy as np + +import hypothesis as hp +import hypothesis.strategies as hps + try: # We only import this to see if Mosaic is available. 
import jax.experimental.mosaic.gpu # noqa: F401 @@ -28,11 +35,7 @@ matmul = None else: from jax.experimental.mosaic.gpu.examples import matmul -try: - import hypothesis as hp - import hypothesis.strategies as hps -except (ModuleNotFoundError, ImportError): - raise unittest.SkipTest("these tests require hypothesis") + from jax.experimental.mosaic.gpu.examples import matmul_blackwell config.parse_flags_with_absl() @@ -48,15 +51,20 @@ def wrapper(self, seed): @jtu.with_config(jax_traceback_filtering="off") +@jtu.thread_unsafe_test_class() # hypothesis is not thread safe class MatmulTestCase(jtu.JaxTestCase): def setUp(self): super().setUp() if matmul is None: self.skipTest("Mosaic GPU not available.") - if (not jtu.test_device_matches(["cuda"]) or - not jtu.is_cuda_compute_capability_equal("9.0")): - self.skipTest("Only works on GPU with capability sm90a") + if not jtu.test_device_matches(["cuda"]): + self.skipTest("Test needs a GPU device") + self.context = mlir.make_ir_context() + mgpu_dialect.register_dialect(self.context) + self.enter_context(config.traceback_filtering("off")) + self.enter_context(self.context) + self.enter_context(ir.Location.unknown()) @parameterized.named_parameters( (f"_shard{i}", i) for i in range(5) @@ -64,7 +72,10 @@ def setUp(self): @seed_hypothesis @hp.settings(max_examples=100) # Add verbosity=hp.Verbosity.verbose to debug @hp.given(hps.data()) - def test_matmul(self, data): + def test_matmul_sm90(self, data): + if not jtu.is_cuda_compute_capability_equal("9.0"): + self.skipTest("Only works on GPU with capability sm90a") + in_dtype = data.draw( hps.sampled_from([jnp.float16, jnp.bfloat16, jnp.float32]), label="in_dtype", @@ -123,6 +134,73 @@ def test_matmul(self, data): hp.assume(False) raise e + @parameterized.named_parameters( + # TODO(apaszke): Increase shard count once we have more B200s in CI. + (f"_shard{i}", i) for i in range(1) + ) + @seed_hypothesis + @hp.settings(max_examples=100) # Add verbosity=hp.Verbosity.verbose to debug + @hp.given(hps.data()) + def test_matmul_sm100(self, data): + if not jtu.is_cuda_compute_capability_equal("10.0"): + self.skipTest("Only works on GPU with capability sm100a") + + dtype = data.draw( + hps.sampled_from([jnp.float16, jnp.bfloat16]), + label="dtype", + ) + m, n, k = ( + data.draw(hps.sampled_from([128, 256, 512, 2048, 8192]), label=d) for d in "mnk" + ) + max_concurrent_steps = data.draw( + hps.integers(2, 5), label="max_concurrent_steps" + ) + collective = data.draw(hps.booleans(), label="collective") + num_ctas = 2 if collective else 1 + hp.assume(not (m == 128 and collective)) # Too small for collective MMA. + tile_m = data.draw( + hps.sampled_from([t for t in [128] if t * num_ctas <= m]), label="tile_m" + ) + tmem_cols = 512 + tile_n = data.draw( + hps.sampled_from([ + t + for t in [64, 128, 256] + # We're double buffering TMEM in the kernel, hence the 2x. 
+ if t * num_ctas <= n and 2 * t * num_ctas <= tmem_cols + ]), + label="tile_n", + ) + grid_m = m // (num_ctas * tile_m) + grid_tile_m = data.draw(hps.sampled_from([1, 2, 4, 8, 16]), label="grid_tile_m") + hp.assume(grid_m % grid_tile_m == 0) + + try: + kernel = matmul_blackwell.build_kernel( + m, + k, + n, + dtype=dtype, + tile_m=tile_m, + tile_n=tile_n, + grid_tile_m=grid_tile_m, + max_concurrent_steps=max_concurrent_steps, + collective=collective, + ) + except ValueError as e: + if "Mosaic GPU kernel exceeds available shared memory" in str(e): + hp.assume(False) + raise + + ka, kb = jax.random.split(jax.random.key(0), 2) + a = jax.random.normal(key=ka, shape=(m, k), dtype=dtype) + b = jax.random.normal(key=kb, shape=(n, k), dtype=dtype) + out = kernel(a, b) + out_ref = jnp.dot(a, b.T) + np.testing.assert_allclose( + out, out_ref, atol=1e-3, rtol=1e-3 if k < 512 else 1e-2 + ) + if __name__ == "__main__": - absltest.main(testLoader=jtu.JaxTestLoader()) + absltest.main(argv=["python"], testLoader=jtu.JaxTestLoader()) diff --git a/tests/multi_device_test.py b/tests/multi_device_test.py index 38a37844ebf8..c4bcd98472d5 100644 --- a/tests/multi_device_test.py +++ b/tests/multi_device_test.py @@ -102,7 +102,7 @@ def test_computation_follows_data(self): self.assert_uncommitted_to_device(z3, devices[0]) - # A jitted computation with an device specification behaves as if the + # A jitted computation with a device specification behaves as if the # arguments are first device_put to the specified device. The result # will be committed on the specified. # The `device` parameter is experimental, and subject to change. diff --git a/tests/multiprocess_gpu_test.py b/tests/multiprocess_gpu_test.py index fe9922148ab4..c2ec44916745 100644 --- a/tests/multiprocess_gpu_test.py +++ b/tests/multiprocess_gpu_test.py @@ -82,9 +82,11 @@ def test_gpu_distributed_initialize(self): try: for proc in subprocesses: - out, _ = proc.communicate() + out, err = proc.communicate() self.assertEqual(proc.returncode, 0) - self.assertEqual(out, f'{num_gpus_per_task},{num_gpus}') + self.assertEqual( + out, f"{num_gpus_per_task},{num_gpus}", msg=f"Process failed:\n\n{err}", + ) finally: for proc in subprocesses: proc.kill() @@ -106,29 +108,17 @@ def test_distributed_jax_visible_devices(self): env["JAX_PORT"] = str(port) env["NUM_TASKS"] = str(num_tasks) env["TASK"] = str(task) - visible_devices = ",".join( - str((task * num_gpus_per_task) + i) for i in range(num_gpus_per_task)) - - if jtu.is_device_rocm(): - program = ( - 'import jax, os; ' - f'jax.config.update("jax_rocm_visible_devices", "{visible_devices}"); ' - 'jax.distributed.initialize(' - 'f\'localhost:{os.environ["JAX_PORT"]}\', ' - 'int(os.environ["NUM_TASKS"]), int(os.environ["TASK"])); ' - 's = jax.pmap(lambda x: jax.lax.psum(x, "i"), axis_name="i")(jax.numpy.ones(jax.local_device_count())); ' - 'print(f\'{jax.local_device_count()},{jax.device_count()},{s}\', end=""); ' - ) - else: - program = ( - 'import jax, os; ' - f'jax.config.update("jax_cuda_visible_devices", "{visible_devices}"); ' - 'jax.distributed.initialize(' - 'f\'localhost:{os.environ["JAX_PORT"]}\', ' - 'int(os.environ["NUM_TASKS"]), int(os.environ["TASK"])); ' - 's = jax.pmap(lambda x: jax.lax.psum(x, "i"), axis_name="i")(jax.numpy.ones(jax.local_device_count())); ' - 'print(f\'{jax.local_device_count()},{jax.device_count()},{s}\', end=""); ' - ) + visible_devices = [ + (task * num_gpus_per_task) + i for i in range(num_gpus_per_task) + ] + program = ( + 'import jax, os; ' + 'jax.distributed.initialize(' 
+ 'f\'localhost:{os.environ["JAX_PORT"]}\', ' + f'int(os.environ["NUM_TASKS"]), int(os.environ["TASK"]), {visible_devices}); ' + 's = jax.pmap(lambda x: jax.lax.psum(x, "i"), axis_name="i")(jax.numpy.ones(jax.local_device_count())); ' + 'print(f\'{jax.local_device_count()},{jax.device_count()},{s}\', end=""); ' + ) args = [sys.executable, "-c", program] proc = subprocess.Popen(args, env=env, stdout=subprocess.PIPE, stderr=subprocess.PIPE, universal_newlines=True) diff --git a/tests/mutable_array_test.py b/tests/mutable_array_test.py index e962653ed32d..ce35424cb418 100644 --- a/tests/mutable_array_test.py +++ b/tests/mutable_array_test.py @@ -16,6 +16,7 @@ from absl.testing import absltest from absl.testing import parameterized +from functools import partial import numpy as np import jax from jax._src import core @@ -28,6 +29,9 @@ config.parse_flags_with_absl() +jtu.request_cpu_devices(8) + + class MutableArrayTest(jtu.JaxTestCase): @parameterized.parameters([True, False]) @@ -49,6 +53,35 @@ def f(x_mut): jaxpr = jax.make_jaxpr(f)(x_mut) self.assertTrue(any(isinstance(e, RefEffect) for e in jaxpr.effects)) + def test_basic_aot(self): + @jax.jit + def f(x_mut): + x_mut[...] += 1. + x_mut[0] += 1 + x_mut[1] += 5 + + x_mut = core.mutable_array(jnp.zeros(3)) + f.lower(x_mut).compile()(x_mut) + self.assertAllClose(x_mut[...], jnp.array([2., 6., 1.]), + check_dtypes=False) + + def test_basic_sharded_aot(self): + mesh = jtu.create_mesh((2,), ('x',)) + arr = jax.device_put(np.arange(8.), NamedSharding(mesh, P('x'))) + + @jax.jit + def f(x_mut): + x_mut[...] += 1. + x_mut[0] += 1 + x_mut[1] += 5 + + x_mut = core.mutable_array(arr) + f.lower(x_mut).compile()(x_mut) + expected = np.arange(8.) + 1 + expected[0] += 1 + expected[1] += 5 + self.assertAllClose(x_mut[...], expected) + @parameterized.parameters([True, False]) def test_multiple_inputs_and_outputs(self, jit): def f(x_mut, y, z_mut, w): @@ -116,6 +149,18 @@ def f(y_mut, z): check_dtypes=False) self.assertAllClose(w, 10, check_dtypes=False) + @parameterized.parameters([True, False]) + def test_len_mutable_array(self, jit): + x_mut = core.mutable_array(jnp.zeros(3)) + + def f(): + return jnp.int32(len(x_mut)) + + if jit: + f = jax.jit(f) + + self.assertEqual(f(), 3) + @parameterized.parameters([True, False]) def test_internal_mutarray_basic(self, jit): def f(): @@ -180,14 +225,18 @@ def f(): x = f() self.assertArraysEqual(x, jnp.zeros(8)) - def test_grad_mutable_array(self): - @jax.jit + @parameterized.parameters([False, True]) + def test_grad_mutable_array(self, jit): + def f(x): x_ = core.mutable_array(x) x_[()] = x_[()] + x_[()] y = core.freeze(x_) return y + if jit: + f = jax.jit(f) + ans = jax.grad(f)(1.) expected = 2.0 self.assertAllClose(ans, expected, check_dtypes=False) @@ -227,6 +276,127 @@ def f(x_ref): x_ref = core.mutable_array(x) y = f(x_ref) + def test_vmap_basic(self): + @jax.vmap + def f(x): + x_ref = core.mutable_array(x) + x_ref[...] = x_ref[...] * x_ref[...] + return x_ref[...] + xs = jnp.arange(4.) + ys = f(xs) + self.assertAllClose(ys, xs ** 2, check_dtypes=False) + + def test_vmap_extensive_inputs(self): + def f(x_ref, val): + x_ref[...] += val + x_ref[...] += val + + xs_ref = core.mutable_array(jnp.array([0, 0, 0])) + vals = jnp.arange(3) + jax.vmap(f)(xs_ref, vals) + self.assertAllClose(xs_ref[...], 2 * vals, check_dtypes=False) + + def test_vmap_closed_over_read_only(self): + y_ref = core.mutable_array(1) + + def f(x_ref): + x_ref[...] += y_ref[...] + x_ref[...] += y_ref[...] 
+ + xs_ref = core.mutable_array(jnp.array([0, 0, 0])) + jax.vmap(f)(xs_ref) + self.assertAllClose(xs_ref[...], jnp.array([2, 2, 2]), check_dtypes=False) + + def test_implicit_bitcast_regression(self): + # https://github.com/jax-ml/jax/issues/27683 + v = core.mutable_array(jnp.array([0, 0, 0])) + with self.assertRaises(ValueError): + v[...] += 1.0 + + def test_implicit_cast_in_swap(self): + v = core.mutable_array(jnp.array(0, dtype='bfloat16')) + v[...] += 1.0 # don't crash + + def test_rng_key(self): + key = core.mutable_array(jax.random.key(0)) + # test read/write + key[...] = jax.random.fold_in(key[...], 1) # don't crash + + def test_scan_grad_doesnt_hoist_mutable_stuff(self): + x_ref = core.mutable_array(0) + + def f(x): + def body(c, _): + x_ref[...] += 1 + return c, () + x, () = jax.lax.scan(body, x, (), length=3) + return x + + jax.grad(f)(1.0) + self.assertAllClose(x_ref[...], 3, check_dtypes=False) + + def test_scan_grad_doesnt_hoist_mutable_stuff2(self): + x_ref = core.mutable_array(0) + const = jnp.arange(3) + const2 = jnp.zeros(()) + + def f(x): + def body(c, _): + x_ref[...] += const.sum() + return c + const2, () + x, () = jax.lax.scan(body, x, (), length=4) + return x + + jax.grad(f)(1.0) + self.assertAllClose(x_ref[...], 12, check_dtypes=False) + + @parameterized.parameters([False, True]) + def test_custom_vjp_grad_stats_plumbing(self, jit): + + @jax.custom_vjp + def gradient_history_calculator(x, ref): + del ref + return x + + def gradient_history_calculator_fwd(x, ref): + return x, ref + + def gradient_history_calculator_bwd(amax_history, grad_output): + amax_update = jnp.max(jnp.abs(grad_output)) + shifted = jnp.roll(amax_history[:], 1) + shifted = shifted.at[0].set(amax_update) + amax_history[:] = shifted + amax_from_history = jnp.max(amax_history[:]) + grad_output = grad_output / amax_from_history + return grad_output, None + + gradient_history_calculator.defvjp( + gradient_history_calculator_fwd, + gradient_history_calculator_bwd) + + class DotOp: + def __init__(self): + self.amax_history = core.mutable_array(jnp.zeros(5,)) + + def forward(self, x, y): + out = jnp.dot(x, y) + out = gradient_history_calculator(out, self.amax_history) + return out + + dot_op = DotOp() + x_top = jnp.ones((5,)) + y_top = jnp.ones((5,)) + + def loss(x, y): + return dot_op.forward(x, y).sum() + + if jit: + loss = jax.jit(loss) + + for i in range(3): + jax.grad(loss, (0,1))(x_top, y_top) + self.assertAllClose(dot_op.amax_history[:], jnp.zeros((5,)).at[:i+1].set(1.0), check_dtypes=False) + @jtu.with_config(jax_mutable_array_checks=True) class MutableArrayErrorsTest(jtu.JaxTestCase): @@ -319,6 +489,19 @@ def f(ref): ValueError, "custom_vjp primal function"): f(x_ref) + @parameterized.parameters([False, True]) + def test_return_from_custom_vjp_primal_nondiff_argnum(self, jit): + @partial(jax.custom_vjp, nondiff_argnums=(0,)) + def f(_, ref): + return ref + f.defvjp(lambda _, ref: ..., lambda *_: ...) + if jit: + f = jax.jit(f, static_argnums=0) + x_ref = core.mutable_array(0.) + with self.assertRaisesRegex( + ValueError, "custom_vjp primal function"): + f('hi', x_ref) + @parameterized.parameters([False, True]) def test_return_from_custom_vjp_fwd(self, jit): @jax.custom_vjp @@ -328,9 +511,23 @@ def f(x, ref): if jit: f = jax.jit(f) x_ref = core.mutable_array(0.) 
+ + jax.vjp(f, 3., x_ref) # returning input ref, okay + + @jax.custom_vjp + def g(x, ref): + return x + def g_fwd(x, _): + y_ref = core.mutable_array(0) + return x, y_ref + g.defvjp(g_fwd, lambda ref, g: g) + if jit: + g = jax.jit(g) + x_ref = core.mutable_array(0.) + with self.assertRaisesRegex( ValueError, "custom_vjp fwd function"): - jax.vjp(f, 3., x_ref) + jax.vjp(g, 3., x_ref) @parameterized.parameters([False, True]) def test_argument_aliases_custom_vjp_primal(self, jit): @@ -376,6 +573,16 @@ def false_fun(): out_false = f(False) self.assertAllClose(x_ref[...], 2.) + def test_vmap_closed_over_ref_write(self): + x_ref = core.mutable_array(jnp.zeros((), 'int32')) + + def f(val): + x_ref[...] += val + + vals = jnp.arange(3, dtype='int32') + with self.assertRaisesRegex(Exception, "unbatched mutable array"): + jax.vmap(f)(vals) + if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/nn_test.py b/tests/nn_test.py index ed016ec349ef..385b216aeb57 100644 --- a/tests/nn_test.py +++ b/tests/nn_test.py @@ -31,7 +31,6 @@ from jax._src.cudnn.scaled_matmul_stablehlo import ( quantize, shape_normalization, - BlockScaleConfig, ) from jax.test_util import check_grads from jax import nn @@ -110,17 +109,7 @@ def create_mxfp8_configs_if_available(): if _dtypes.float8_e8m0fnu is None: raise unittest.SkipTest("float8_e8m0fnu is not available.") - def _create_mxfp8_config(): - return BlockScaleConfig( - mode='mxfp8', - block_size=32, - data_type=jnp.float8_e4m3fn, - scale_type=jnp.float8_e8m0fnu, - global_scale=None, - infer_only=False - ) - - return [_create_mxfp8_config() for _ in range(3)] + return [nn.get_scaled_dot_general_config("mxfp8") for _ in range(3)] @jtu.with_config(jax_legacy_prng_key="allow", @@ -130,10 +119,9 @@ class NNFunctionsTest(jtu.JaxTestCase): contract=[160, 96], lhs_non_contract=[240, 100], dtype=[jnp.float16, jnp.bfloat16, jnp.float32], - impl=['cudnn',], ) - def testScaledMatmul(self, contract, lhs_non_contract, dtype, impl): - if impl == 'cudnn' and not _is_required_cudnn_version_satisfied("10.0", 90700): + def testScaledMatmul(self, contract, lhs_non_contract, dtype): + if not _is_required_cudnn_version_satisfied("10.0", 90700): raise unittest.SkipTest("CUDA or cuDNN versions are not compatible") # Check if float8_e8m0fnu is available configs = create_mxfp8_configs_if_available() @@ -153,11 +141,10 @@ def testScaledMatmul(self, contract, lhs_non_contract, dtype, impl): @parameterized.product( is_training=[True, False], output_type=[jnp.float16, jnp.bfloat16, jnp.float32], - impl=['cudnn',], ) def testScaledDotGeneral( - self, is_training, output_type, impl): - if impl == 'cudnn' and not _is_required_cudnn_version_satisfied("10.0", 90700): + self, is_training, output_type): + if not _is_required_cudnn_version_satisfied("10.0", 90700): raise unittest.SkipTest("CUDA or cuDNN versions are not compatible") configs = create_mxfp8_configs_if_available() @@ -422,6 +409,7 @@ def testSparseplusAndSparseSigmoid(self): jax.grad(nn.sparse_plus)(-2.), nn.sparse_sigmoid(-2.), check_dtypes=False) + @jtu.skip_on_flag("jax_skip_slow_tests", True) def testSquareplusGrad(self): check_grads(nn.squareplus, (1e-8,), order=4, rtol=1e-2 if jtu.test_device_matches(["tpu"]) else None) @@ -442,6 +430,7 @@ def testSquareplusGradNan(self): def testSquareplusZero(self, dtype): self.assertEqual(dtype(1), nn.squareplus(dtype(0), dtype(4))) + @jtu.skip_on_flag("jax_skip_slow_tests", True) def testMishGrad(self): check_grads(nn.mish, (1e-8,), order=4, rtol=1e-2 if 
jtu.test_device_matches(["tpu"]) else None) @@ -541,7 +530,7 @@ def gelu_reference(x): (jnp.float32, jnp.bfloat16, jnp.float16), (partial(nn.gelu, approximate=False), partial(nn.gelu, approximate=True), - nn.relu, nn.softplus, nn.sparse_plus, nn.sigmoid, nn.squareplus, nn.mish))) + nn.relu, nn.identity, nn.softplus, nn.sparse_plus, nn.sigmoid, nn.squareplus, nn.mish))) def testDtypeMatchesInput(self, dtype, fn): x = jnp.zeros((), dtype=dtype) out = fn(x) @@ -829,6 +818,12 @@ def testVarianceScalingError(self): ): initializer(rng, shape) + def testIdentity(self): + x = jnp.array([1., 2., 3.]) + self.assertAllClose(nn.identity(x), x, check_dtypes=False) + grad = jax.grad(nn.identity)(6.0) + self.assertEqual(grad, 1.) + def testAccidentalUpcasting(self): rng = random.PRNGKey(0) shape = (4, 4) diff --git a/tests/notebooks/colab_cpu.ipynb b/tests/notebooks/colab_cpu.ipynb index f5dcff837838..b49eed8a5e62 100644 --- a/tests/notebooks/colab_cpu.ipynb +++ b/tests/notebooks/colab_cpu.ipynb @@ -106,7 +106,6 @@ } ], "source": [ - "from jaxlib import xla_extension\n", "import jax\n", "key = jax.random.PRNGKey(1701)\n", "arr = jax.random.normal(key, (1000,))\n", diff --git a/tests/pallas/BUILD b/tests/pallas/BUILD index 987a3aa9d50a..678354a43b28 100644 --- a/tests/pallas/BUILD +++ b/tests/pallas/BUILD @@ -30,6 +30,11 @@ package( jax_generate_backend_suites() +test_suite( + name = "mosaic_gpu_tests", + tags = ["mosaic_gpu_test"], +) + jax_multiplatform_test( name = "pallas_test", srcs = [ @@ -54,7 +59,10 @@ jax_multiplatform_test( "//jax:pallas_gpu_ops", "//jax:pallas_tpu", "//jax:pallas_tpu_ops", - ] + py_deps("absl/testing") + py_deps("numpy"), + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -68,7 +76,10 @@ jax_multiplatform_test( "//jax:pallas_gpu_ops", "//jax:pallas_tpu", "//jax:pallas_tpu_ops", - ] + py_deps("absl/testing") + py_deps("numpy"), + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -88,7 +99,10 @@ jax_multiplatform_test( "//jax:pallas", "//jax:pallas_tpu", "//jax:pallas_tpu_ops", - ] + py_deps("absl/testing") + py_deps("numpy"), + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -107,11 +121,11 @@ jax_multiplatform_test( "gpu_a100_x32", "gpu_h100", "gpu_h100_x32", - "tpu_v6e_1x1", + "tpu_v6e", ], shard_count = { "cpu": 16, - "gpu": 16, + "gpu": 32, "tpu": 16, }, tags = [ @@ -124,7 +138,11 @@ jax_multiplatform_test( "//jax:pallas_gpu", # build_cleaner: keep "//jax:pallas_tpu", "//jax:pallas_tpu_ops", - ] + py_deps("absl/testing") + py_deps("hypothesis") + py_deps("numpy"), + ] + py_deps([ + "absl/testing", + "hypothesis", + "numpy", + ]), ) jax_multiplatform_test( @@ -149,9 +167,10 @@ jax_multiplatform_test( ], env = { "JAX_PALLAS_USE_MOSAIC_GPU": "1", - "JAX_PALLAS_VERBOSE_ERRORS": "0", }, + shard_count = 16, tags = [ + "mosaic_gpu_test", "noasan", # Times out. "nomsan", # Times out. "notsan", # Times out. 
@@ -161,7 +180,11 @@ jax_multiplatform_test( "//jax:pallas_gpu", # build_cleaner: keep "//jax:pallas_mosaic_gpu", # build_cleaner: keep "//jax:pallas_tpu", - ] + py_deps("absl/testing") + py_deps("hypothesis") + py_deps("numpy"), + ] + py_deps([ + "absl/testing", + "hypothesis", + "numpy", + ]), ) jax_multiplatform_test( @@ -181,7 +204,11 @@ jax_multiplatform_test( deps = [ "//jax:pallas", "//jax:pallas_tpu", - ] + py_deps("absl/testing") + py_deps("hypothesis") + py_deps("numpy"), + ] + py_deps([ + "absl/testing", + "hypothesis", + "numpy", + ]), ) jax_multiplatform_test( @@ -201,7 +228,10 @@ jax_multiplatform_test( "//jax:pallas_gpu_ops", "//jax:pallas_tpu", "//jax:pallas_tpu_ops", - ] + py_deps("absl/testing") + py_deps("numpy"), + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -214,14 +244,16 @@ jax_multiplatform_test( "gpu_h100_x32", "gpu_h100", ], - env = { - "JAX_PALLAS_USE_MOSAIC_GPU": "1", - "JAX_PALLAS_VERBOSE_ERRORS": "0", - }, + tags = [ + "mosaic_gpu_test", + ], deps = [ "//jax:pallas", "//jax:pallas_mosaic_gpu", # build_cleaner: keep - ] + py_deps("absl/testing") + py_deps("numpy"), + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -238,8 +270,9 @@ jax_multiplatform_test( "//jax:internal_export_back_compat_test_util", "//jax:pallas", "//jax:pallas_gpu", # build_cleaner: keep + "//jax:pallas_mosaic_gpu", # build_cleaner: keep "//jax:pallas_tpu_ops", # build_cleaner: keep - ], + ] + py_deps("absl/testing"), ) jax_py_test( @@ -252,7 +285,10 @@ jax_py_test( "//jax:pallas_gpu", # build_cleaner: keep "//jax:pallas_tpu", # build_cleaner: keep "//jax:test_util", - ] + jax_gpu_support_deps, + ] + jax_gpu_support_deps + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -272,7 +308,10 @@ jax_multiplatform_test( "//jax:pallas", "//jax:pallas_gpu", # build_cleaner: keep "//jax:pallas_tpu", # build_cleaner: keep - ], + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -293,7 +332,10 @@ jax_multiplatform_test( "//jax:pallas", "//jax:pallas_gpu", # build_cleaner: keep "//jax:pallas_tpu", # build_cleaner: keep - ], + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -306,7 +348,10 @@ jax_multiplatform_test( "//jax:pallas", "//jax:pallas_tpu", "//jax/_src/pallas/mosaic:random", - ] + py_deps("absl/testing") + py_deps("numpy"), + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -316,11 +361,15 @@ jax_multiplatform_test( ], enable_backends = [], enable_configs = [ - "tpu_v5e_4x2", + "tpu_v5e_x8", ], deps = [ "//jax:pallas_tpu_ops", - ] + py_deps("absl/testing") + py_deps("numpy") + py_deps("hypothesis"), + ] + py_deps([ + "absl/testing", + "numpy", + "hypothesis", + ]), ) jax_multiplatform_test( @@ -329,7 +378,7 @@ jax_multiplatform_test( "tpu_gmm_test.py", ], enable_backends = ["tpu"], - shard_count = 50, + shard_count = 5, tags = [ "noasan", # Times out. "nomsan", # Times out. 
@@ -353,13 +402,44 @@ jax_multiplatform_test( enable_backends = ["tpu"], enable_configs = [ "tpu_v5e", - "tpu_v5p_1x1", + "tpu_v5p", ], deps = [ "//jax:extend", "//jax:pallas_tpu", "//jax:pallas_tpu_ops", + ] + py_deps([ + "absl/testing", + "numpy", + ]), +) + +jax_multiplatform_test( + name = "gpu_pallas_distributed_test", + srcs = ["gpu_pallas_distributed_test.py"], + args = [ + "--num_processes=2", + "--gpus_per_process=1", + ], + enable_backends = [], + enable_configs = ["gpu_h100x2"], + env = { + "JAX_PALLAS_USE_MOSAIC_GPU": "1", + "XLA_FLAGS": "--xla_gpu_experimental_enable_nvshmem=true", + }, + tags = [ + "mosaic_gpu_test", + "multiaccelerator", ], + deps = [ + "//jax:extend", + "//jax:pallas_mosaic_gpu", + "//jax:test_multiprocess", + ] + py_deps([ + "portpicker", + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -376,7 +456,11 @@ jax_multiplatform_test( "//jax:pallas_gpu", # build_cleaner: keep "//jax:pallas_tpu", "//jax:pallas_tpu_ops", - ] + py_deps("absl/testing") + py_deps("hypothesis") + py_deps("numpy"), + ] + py_deps([ + "absl/testing", + "hypothesis", + "numpy", + ]), ) jax_multiplatform_test( @@ -384,16 +468,19 @@ jax_multiplatform_test( srcs = ["tpu_pallas_distributed_test.py"], enable_backends = ["tpu"], enable_configs = [ - "tpu_v5e_4x2", - "tpu_v5p_2x2", - "tpu_v4_2x2", - "tpu_v3_2x2", + "tpu_v5e_x8", + "tpu_v5p_x4", + "tpu_v4_x4", + "tpu_v3_x4", ], deps = [ "//jax:extend", "//jax:pallas_tpu", "//jax:pallas_tpu_ops", - ], + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -401,8 +488,8 @@ jax_multiplatform_test( srcs = ["tpu_pallas_pipeline_test.py"], enable_backends = ["tpu"], enable_configs = [ - "tpu_v5e_4x2", - "tpu_v5p_1x1", + "tpu_v5e_x8", + "tpu_v5p", ], shard_count = 5, tags = [ @@ -414,7 +501,11 @@ jax_multiplatform_test( "//jax:extend", "//jax:pallas_tpu", "//jax:pallas_tpu_ops", - ] + py_deps("hypothesis"), + ] + py_deps([ + "absl/testing", + "numpy", + "hypothesis", + ]), ) jax_multiplatform_test( @@ -422,12 +513,30 @@ jax_multiplatform_test( srcs = ["tpu_pallas_async_test.py"], enable_backends = ["tpu"], enable_configs = [ - "tpu_v5e_4x2", - "tpu_v5p_1x1", + "tpu_v5e_x8", + "tpu_v5p", ], deps = [ "//jax:pallas_tpu", + ] + py_deps([ + "absl/testing", + "numpy", + ]), +) + +jax_multiplatform_test( + name = "tpu_pallas_memory_space_test", + srcs = ["tpu_pallas_memory_space_test.py"], + enable_backends = ["tpu"], + enable_configs = [ + "tpu_v5p", ], + deps = [ + "//jax:pallas_tpu", + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -442,7 +551,10 @@ jax_multiplatform_test( deps = [ "//jax:extend", "//jax:pallas_tpu", - ], + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -452,14 +564,17 @@ jax_multiplatform_test( ], enable_backends = ["tpu"], enable_configs = [ - "tpu_v5p_2x2", + "tpu_v5p_x4", ], deps = [ "//jax:pallas", "//jax:pallas_tpu", "//jax:pallas_tpu_ops", "//jax/_src/pallas/mosaic:random", - ] + py_deps("absl/testing") + py_deps("numpy"), + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -470,9 +585,13 @@ jax_multiplatform_test( disable_configs = ["cpu_shardy"], enable_backends = ["cpu"], deps = [ + "//jax:experimental", "//jax:pallas", "//jax:pallas_tpu", - ] + py_deps("absl/testing") + py_deps("numpy"), + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -485,14 +604,17 @@ jax_multiplatform_test( deps = [ "//jax:pallas", "//jax:pallas_tpu", - ] + py_deps("absl/testing") + py_deps("numpy"), + 
] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( name = "tpu_paged_attention_kernel_test", srcs = ["tpu_paged_attention_kernel_test.py"], disable_configs = [ - "tpu_v5p_1x1", + "tpu_v5p", ], enable_backends = ["tpu"], shard_count = 5, @@ -503,14 +625,17 @@ jax_multiplatform_test( ], deps = [ "//jax:pallas_tpu_ops", - ] + py_deps("absl/testing") + py_deps("numpy"), + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( name = "tpu_ragged_paged_attention_test", srcs = ["tpu_ragged_paged_attention_test.py"], disable_configs = [ - "tpu_v5p_1x1", + "tpu_v5p", ], enable_backends = ["tpu"], shard_count = 24, @@ -521,7 +646,10 @@ jax_multiplatform_test( ], deps = [ "//jax:pallas_tpu_ops", - ] + py_deps("absl/testing") + py_deps("numpy"), + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -538,7 +666,29 @@ jax_multiplatform_test( ], deps = [ "//jax:pallas_tpu_ops", - ] + py_deps("absl/testing") + py_deps("numpy") + py_deps("hypothesis"), + ] + py_deps([ + "absl/testing", + "numpy", + "hypothesis", + ]), +) + +jax_multiplatform_test( + name = "tpu_splash_attention_kernel_sharded_test", + srcs = ["tpu_splash_attention_kernel_sharded_test.py"], + enable_configs = [ + "tpu_v5e_x8", + "tpu_v5p_x4", + ], + shard_count = 5, + deps = [ + "//jax:extend", + "//jax:pallas_tpu", + "//jax:pallas_tpu_ops", + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) # This test doesn't need a TPU; it only tests numpy-using helpers. @@ -551,7 +701,11 @@ jax_py_test( "//jax", "//jax:pallas_tpu_ops", "//jax:test_util", - ] + py_deps("absl/testing") + py_deps("numpy") + py_deps("hypothesis"), + ] + py_deps([ + "absl/testing", + "numpy", + "hypothesis", + ]), ) jax_multiplatform_test( @@ -569,7 +723,10 @@ jax_multiplatform_test( "//jax:pallas", "//jax:pallas_gpu", # build_cleaner: keep "//jax:pallas_gpu_ops", - ] + py_deps("absl/testing") + py_deps("numpy"), + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -592,7 +749,10 @@ jax_multiplatform_test( "//jax:pallas", "//jax:pallas_gpu", "//jax:pallas_gpu_ops", - ] + py_deps("absl/testing") + py_deps("numpy"), + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -610,7 +770,28 @@ jax_multiplatform_test( "//jax:pallas", "//jax:pallas_gpu", "//jax:pallas_gpu_ops", - ] + py_deps("absl/testing") + py_deps("numpy"), + ] + py_deps([ + "absl/testing", + "numpy", + ]), +) + +jax_multiplatform_test( + name = "triton_pallas_test", + srcs = [ + "triton_pallas_test.py", + ], + enable_backends = ["cpu"], + enable_configs = [ + "gpu_h100_x32", + ], + shard_count = 1, + deps = [ + "//jax:pallas", + "//jax:pallas_gpu", + ] + py_deps([ + "absl/testing", + ]), ) jax_multiplatform_test( @@ -626,7 +807,10 @@ jax_multiplatform_test( deps = [ "//jax:pallas", "//jax:pallas_mosaic_gpu", - ] + py_deps("absl/testing") + py_deps("numpy"), + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( @@ -638,10 +822,129 @@ jax_multiplatform_test( "gpu_h100", ], env = {"XLA_FLAGS": "--xla_gpu_autotune_level=0"}, + shard_count = 8, + tags = [ + "mosaic_gpu_test", + ], + deps = [ + "//jax:pallas", + "//jax:pallas_experimental_gpu_ops", + "//jax:pallas_mosaic_gpu", + ] + py_deps([ + "absl/testing", + "numpy", + ]), +) + +jax_multiplatform_test( + name = "mgpu_matmul_test", + srcs = ["mgpu_matmul_test.py"], + enable_backends = [], + enable_configs = ["gpu_b200"], + env = {"XLA_FLAGS": "--xla_gpu_autotune_level=0"}, + shard_count = 8, + tags = [ + "mosaic_gpu_test", 
+ # TODO(b/330364373): Remove when B200 is fully supported. + "notap", + ], + deps = [ + "//jax:pallas", + "//jax:pallas_experimental_gpu_ops", + "//jax:pallas_mosaic_gpu", + ] + py_deps([ + "absl/testing", + "numpy", + ]), +) + +jax_multiplatform_test( + name = "blackwell_matmul_mgpu_run", + srcs = ["//jax/experimental/pallas/ops/gpu:blackwell_matmul_mgpu.py"], + enable_backends = [], + enable_configs = ["gpu_b200"], + env = {"XLA_FLAGS": "--xla_gpu_autotune_level=0"}, + tags = [ + "manual", + "notap", + ], + deps = [ + "//jax:pallas", + "//jax:pallas_mosaic_gpu", + ] + py_deps([ + "absl/testing", + "numpy", + ]), +) + +jax_multiplatform_test( + name = "mgpu_ragged_dot_run", + srcs = ["//jax/experimental/pallas/ops/gpu:ragged_dot_mgpu.py"], + enable_backends = [], + enable_configs = [ + "gpu_h100_x32", + "gpu_h100", + ], + env = {"XLA_FLAGS": "--xla_gpu_autotune_level=0"}, + tags = [ + "manual", + "notap", + ], + deps = [ + "//jax:pallas", + "//jax:pallas_mosaic_gpu", + ] + py_deps("absl/testing") + py_deps("numpy"), +) + +jax_multiplatform_test( + name = "mgpu_ragged_dot_test", + srcs = ["mgpu_ragged_dot_test.py"], + enable_backends = [], + enable_configs = [ + "gpu_h100", + ], + shard_count = 12, + tags = [ + "mosaic_gpu_test", + "noasan", # Times out. + ], + deps = [ + "//jax:pallas", + "//jax:pallas_experimental_gpu_ops", + "//jax:pallas_mosaic_gpu", + ] + py_deps([ + "absl/testing", + "numpy", + ]), +) + +jax_multiplatform_test( + name = "mgpu_collective_matmul_test", + srcs = ["mgpu_collective_matmul_test.py"], + args = [ + "--num_processes=2", + "--gpus_per_process=1", + ], + enable_backends = [], + enable_configs = [ + "gpu_h100x2", + ], + env = { + "XLA_FLAGS": "--xla_gpu_experimental_enable_nvshmem=true", + "JAX_PALLAS_USE_MOSAIC_GPU": "1", + }, + shard_count = 4, + tags = [ + "manual", + "multiaccelerator", + "notap", + ], deps = [ + "//jax:experimental", "//jax:pallas", "//jax:pallas_experimental_gpu_ops", "//jax:pallas_mosaic_gpu", + "//jax:test_multiprocess", ] + py_deps("absl/testing") + py_deps("numpy"), ) @@ -663,14 +966,41 @@ jax_multiplatform_test( deps = [ "//jax:pallas", "//jax/_src/pallas/fuser", - ] + py_deps("absl/testing") + py_deps("numpy"), + ] + py_deps([ + "absl/testing", + "numpy", + ]), +) + +jax_multiplatform_test( + name = "fusion_test", + srcs = [ + "fusion_test.py", + ], + disable_configs = [ + "cpu", + "cpu_shardy", + ], + enable_backends = ["cpu"], + tags = [ + "noasan", + "nomsan", + "notsan", + ], + deps = [ + "//jax:pallas", + "//jax:pallas_fuser", + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) jax_multiplatform_test( - name = "tpu_fusable_matmul_test", - srcs = ["tpu_fusable_matmul_test.py"], + name = "tpu_fusible_matmul_test", + srcs = ["tpu_fusible_matmul_test.py"], disable_configs = [ - "tpu_v3_1x1", + "tpu_v3", "tpu_pjrt_c_api", "gpu_v100", "gpu_v100_x32", @@ -684,10 +1014,10 @@ jax_multiplatform_test( ], enable_backends = ["tpu"], enable_configs = [ - "tpu_v4_1x1", + "tpu_v4", "tpu_v5e", - "tpu_v5p_1x1", - "tpu_v6e_1x1", + "tpu_v5p", + "tpu_v6e", ], shard_count = 4, tags = [ @@ -699,5 +1029,8 @@ jax_multiplatform_test( "//jax:pallas_tpu", "//jax:pallas_tpu_ops", "//jax/_src/pallas/fuser", - ] + py_deps("absl/testing") + py_deps("numpy"), + ] + py_deps([ + "absl/testing", + "numpy", + ]), ) diff --git a/tests/pallas/export_back_compat_pallas_test.py b/tests/pallas/export_back_compat_pallas_test.py index addf14d73792..c37bbbfec2a0 100644 --- a/tests/pallas/export_back_compat_pallas_test.py +++ 
b/tests/pallas/export_back_compat_pallas_test.py @@ -17,6 +17,7 @@ update these tests. """ +import functools import math import unittest @@ -25,6 +26,7 @@ from jax._src import config from jax._src import test_util as jtu from jax._src.internal_test_util import export_back_compat_test_util as bctu +from jax._src.internal_test_util.export_back_compat_test_data.pallas import mosaic_gpu_add_one from jax._src.internal_test_util.export_back_compat_test_data.pallas import mosaic_matmul from jax._src.internal_test_util.export_back_compat_test_data.pallas import mosaic_semaphore_dma from jax._src.internal_test_util.export_back_compat_test_data.pallas import triton_add_one @@ -43,9 +45,6 @@ class CompatTest(bctu.CompatTestBase): def setUp(self): if jax.config.x64_enabled: self.skipTest("Only works in 32-bit") - if (jtu.test_device_matches(["cuda"]) and - not jtu.is_cuda_compute_capability_at_least("8.0")): - self.skipTest("Only works on GPUs with capability >= sm80") super().setUp() @unittest.skip("This test is checking backwards compatibility " @@ -53,6 +52,9 @@ def setUp(self): "compatibility for its IR, and we have since removed " "the corresponding custom call from the guaranteed stable list.") def test_triton_add_one(self): + if not jtu.is_cuda_compute_capability_at_least("8.0"): + self.skipTest("Only works on GPUs with capability >= sm80") + def func(x): def add_one(x_ref, o_ref): o_ref[0] = x_ref[0] + 1 @@ -65,6 +67,22 @@ def add_one(x_ref, o_ref): self.run_one_test(func, data) + def test_mosaic_gpu_add_one(self): + if not jtu.is_cuda_compute_capability_at_least("9.0"): + self.skipTest("Only works on GPUs with capability >= sm90") + + @functools.partial( + pl.pallas_call, + out_shape=jax.ShapeDtypeStruct((128 * 2,), jnp.float32), + grid=2, + backend="mosaic_gpu", + ) + def add_one(x_ref, o_ref): + o_ref[...] = x_ref[...] + 1 + + data = self.load_testdata(mosaic_gpu_add_one.data_2025_04_22) + self.run_one_test(add_one, data) + @jax.default_matmul_precision("bfloat16") def test_mosaic_matmul(self): # TODO(apaszke): Remove after 12 weeks have passed. 
diff --git a/tests/pallas/fuser_block_spec_test.py b/tests/pallas/fuser_block_spec_test.py index 1b3a215876ec..db242fb1e400 100644 --- a/tests/pallas/fuser_block_spec_test.py +++ b/tests/pallas/fuser_block_spec_test.py @@ -653,9 +653,12 @@ def f(): kernel_fn((0, 0, 3, 0), scalar_prefetch_values, ()), x ) - def test_broadcast_array(self): + @parameterized.parameters( + (False, False), (False, True), (True, False), (True, True) + ) + def test_broadcast_array(self, bcast0, bcast1): - x = jnp.ones((512, 512)) + x = jnp.ones((1 if bcast0 else 512, 1 if bcast1 else 512)) def f(): return jax.lax.broadcast_in_dim(x, (2, 2, 512, 512), (2, 3)) @@ -664,9 +667,47 @@ def f(): self.assertLen(new_values, 1) self.assertEmpty(scalar_prefetch_values) - block_spec = pl.BlockSpec( - (None, 1, 128, 128), lambda i, j, k, l: (i, j, k, l) + block_shape = (None, 1, 128, 128) + block_spec = pl.BlockSpec(block_shape, lambda i, j, k, l: (i, j, k, l)) + kernel_fn, (value_block_specs,), _ = block_spec_lib.pull_block_spec( + f2, + block_spec, + grid=(2, 2, 4, 4), + scalar_prefetch_handler=block_spec_lib.make_scalar_prefetch_handler(), + )(new_values) + self.assertLen(value_block_specs, 1) + x_index_map = value_block_specs[0].index_map + self.assertEqual( + x_index_map(0, 0, 1, 2), (0 if bcast0 else 1, 0 if bcast1 else 2) ) + self.assertEqual( + x_index_map(1, 2, 3, 3), (0 if bcast0 else 3, 0 if bcast1 else 3) + ) + + block_shape = (1 if bcast0 else 128, 1 if bcast1 else 128) + self.assertEqual(block_shape, value_block_specs[0].block_shape) + x = jnp.full(block_shape, fill_value=1.2345, dtype=jnp.float32) + y = jax.lax.broadcast_in_dim(x, (1, 128, 128), (1, 2)) + np.testing.assert_array_equal(kernel_fn((0, 0, 0, 0), (), (x,)), y) + np.testing.assert_array_equal(kernel_fn((1, 1, 0, 0), (), (x,)), y) + np.testing.assert_array_equal(kernel_fn((0, 0, 0, 1), (), (x,)), y) + np.testing.assert_array_equal(kernel_fn((0, 0, 1, 0), (), (x,)), y) + np.testing.assert_array_equal(kernel_fn((0, 0, 3, 0), (), (x,)), y) + + @parameterized.parameters(0, 1, 2, 3) + def test_broadcast_1d_array(self, bcast_dim): + full_shape = (2, 2, 512, 512) + x = jnp.ones((full_shape[bcast_dim],)) + + def f(): + return jax.lax.broadcast_in_dim(x, full_shape, (bcast_dim,)) + + f2, new_values, scalar_prefetch_values = block_spec_lib.get_fusion_values(f) + self.assertLen(new_values, 1) + self.assertEmpty(scalar_prefetch_values) + + block_shape = (None, 1, 128, 128) + block_spec = pl.BlockSpec(block_shape, lambda i, j, k, l: (i, j, k, l)) kernel_fn, (value_block_specs,), _ = block_spec_lib.pull_block_spec( f2, block_spec, @@ -674,26 +715,276 @@ def f(): scalar_prefetch_handler=block_spec_lib.make_scalar_prefetch_handler(), )(new_values) self.assertLen(value_block_specs, 1) - x_block_spec = value_block_specs[0] - self.assertEqual(x_block_spec.index_map(0, 0, 1, 2), (1, 2)) - self.assertEqual(x_block_spec.index_map(1, 2, 3, 3), (3, 3)) + x_index_map = value_block_specs[0].index_map + self.assertEqual(x_index_map(0, 0, 1, 2), ((0, 0, 1, 2)[bcast_dim],)) + self.assertEqual(x_index_map(1, 2, 3, 3), ((1, 2, 3, 3)[bcast_dim],)) - x = jnp.full((128, 128), fill_value=1.2345, dtype=jnp.float32) - np.testing.assert_array_equal( - kernel_fn((0, 0, 0, 0), scalar_prefetch_values, (x,)), x + if block_shape[bcast_dim] is None: + x = jnp.ones(()) + y = jax.lax.broadcast_in_dim(x, (1, 128, 128), ()) + else: + x = jnp.arange(block_shape[bcast_dim] or 1, dtype=jnp.float32) + y = jax.lax.broadcast_in_dim(x, (1, 128, 128), (bcast_dim - 1,)) + + 
np.testing.assert_array_equal(kernel_fn((0, 0, 0, 0), (), (x,)), y) + np.testing.assert_array_equal(kernel_fn((1, 1, 0, 0), (), (x,)), y) + np.testing.assert_array_equal(kernel_fn((0, 0, 0, 1), (), (x,)), y) + np.testing.assert_array_equal(kernel_fn((0, 0, 1, 0), (), (x,)), y) + np.testing.assert_array_equal(kernel_fn((0, 0, 3, 0), (), (x,)), y) + + def test_element_indexing(self): + + x = np.zeros((512, 512), dtype=np.float32) + + def f(): + return x + + f2, new_values, scalar_prefetch_values = block_spec_lib.get_fusion_values(f) + self.assertLen(new_values, 1) + self.assertEmpty(scalar_prefetch_values) + + # Block spec with an offset on the first dimension + block_spec = pl.BlockSpec( + (pl.Element(128, (0, 16)), 128), lambda i, j, k: (128 * i + 16, j) ) - np.testing.assert_array_equal( - kernel_fn((1, 1, 0, 0), scalar_prefetch_values, (x,)), x + kernel_fn, (value_block_specs,), _ = block_spec_lib.pull_block_spec( + f2, + block_spec, + grid=(1, 1, 1), + scalar_prefetch_handler=block_spec_lib.make_scalar_prefetch_handler(), + )(new_values) + self.assertLen(value_block_specs, 1) + self.assertEmpty(scalar_prefetch_values) + self.assertEqual( + value_block_specs[0].block_shape, (pl.Element(128, (0, 16)), 128) ) + self.assertEqual(value_block_specs[0].index_map(0, 1, 2), (16, 1)) + self.assertEqual(value_block_specs[0].index_map(1, 1, 2), (128 + 16, 1)) + + x_block = np.ones((128, 128), dtype=np.float32) np.testing.assert_array_equal( - kernel_fn((0, 0, 0, 1), scalar_prefetch_values, (x,)), x + kernel_fn( + (0, 0, 0), + scalar_prefetch_values, + (np.ones((128, 128), dtype=np.float32),), + ), + x_block, ) - np.testing.assert_array_equal( - kernel_fn((0, 0, 1, 0), scalar_prefetch_values, (x,)), x + + def test_basic_reshape_sublanes_to_lanes(self): + + def f(x): + return x.reshape((512, 2048)) + + in_type = jax.ShapeDtypeStruct((512, 16, 128), jnp.float32) + f2, new_values, scalar_prefetch_values = block_spec_lib.get_fusion_values( + f, in_type ) - np.testing.assert_array_equal( - kernel_fn((0, 0, 3, 0), scalar_prefetch_values, (x,)), x + self.assertEmpty(new_values) + self.assertEmpty(scalar_prefetch_values) + + block_spec = pl.BlockSpec((256, 1024), lambda i, j, k: (i, k)) + kernel_fn, (value_block_specs, x_block_spec), _ = ( + block_spec_lib.pull_block_spec( + f2, + block_spec, + grid=(2, 3, 4), + scalar_prefetch_handler=block_spec_lib.make_scalar_prefetch_handler(), + )(new_values, in_type) ) + self.assertEmpty(value_block_specs) + self.assertEqual(x_block_spec.index_map(0, 1, 2), (0, 2, 0)) + self.assertEqual(x_block_spec.index_map(3, 2, 1), (3, 1, 0)) + + x = jnp.arange((256 * 1024), dtype=jnp.float32).reshape((256, 8, 128)) + y = kernel_fn((0, 1, 2), scalar_prefetch_values, (), x) + np.testing.assert_array_equal(y, x.reshape((256, 1024))) + + def test_basic_reshape_lanes_to_sublanes(self): + + def f(x): + return x.reshape((512, 32, 128)) + + in_type = jax.ShapeDtypeStruct((512, 4096), jnp.float32) + f2, new_values, scalar_prefetch_values = block_spec_lib.get_fusion_values( + f, in_type + ) + self.assertEmpty(new_values) + self.assertEmpty(scalar_prefetch_values) + + block_spec = pl.BlockSpec((256, 8, 128), lambda i, j, k: (i, k, 0)) + kernel_fn, (value_block_specs, x_block_spec), _ = ( + block_spec_lib.pull_block_spec( + f2, + block_spec, + grid=(2, 3, 4), + scalar_prefetch_handler=block_spec_lib.make_scalar_prefetch_handler(), + )(new_values, in_type) + ) + self.assertEmpty(value_block_specs) + self.assertEqual(x_block_spec.index_map(0, 1, 2), (0, 2)) + 
self.assertEqual(x_block_spec.index_map(3, 2, 1), (3, 1)) + + x = jnp.arange((256 * 1024), dtype=jnp.float32).reshape((256, 1024)) + y = kernel_fn((0, 1, 2), scalar_prefetch_values, (), x) + np.testing.assert_array_equal(y, x.reshape((256, 8, 128))) + + block_spec = pl.BlockSpec((256, 4, 256), lambda i, j, k: (i, j, k)) + with self.assertRaises(NotImplementedError): + _ = block_spec_lib.pull_block_spec( + f2, + block_spec, + grid=(2, 3, 4), + scalar_prefetch_handler=block_spec_lib.make_scalar_prefetch_handler(), + )(new_values, in_type) + + def test_basic_swap(self): + value = jnp.arange((512 * 1024), dtype=jnp.int32).reshape((512, 1024)) * 2 + x = jnp.zeros((256, 512), dtype=jnp.int32) + + def outer(refs): + ref, y_ref = refs + + def f(x): + return ref.swap(x) + + in_type = jax.ShapeDtypeStruct((512, 1024), jnp.int32) + f2, new_values, scalar_prefetch_values = block_spec_lib.get_fusion_values( + f, in_type + ) + self.assertLen(new_values, 1) # Captures Ref + self.assertEmpty(scalar_prefetch_values) + + block_spec = pl.BlockSpec((256, 512), lambda i, j, k: (i, k)) + kernel_fn, (value_block_specs, x_block_spec), _ = ( + block_spec_lib.pull_block_spec( + f2, + block_spec, + grid=(2, 3, 4), + scalar_prefetch_handler=block_spec_lib.make_scalar_prefetch_handler(), + )(new_values, in_type) + ) + self.assertLen(value_block_specs, 1) + self.assertEqual(x_block_spec.index_map(0, 1, 2), (0, 2)) + self.assertEqual(x_block_spec.index_map(3, 2, 1), (3, 1)) + + y_ref[...] = kernel_fn((0, 1, 1), scalar_prefetch_values, (ref,), x) + + y = jnp.zeros((256, 512), jnp.int32) + _, y = pl.run_state(outer)((value, y)) + np.testing.assert_array_equal(y, value[:256, 512:1024]) + + def test_basic_get(self): + value = jnp.arange((512 * 1024), dtype=jnp.int32).reshape((512, 1024)) * 2 + + def outer(refs): + ref, y_ref = refs + + def f(): + return ref.get() + + block_spec = pl.BlockSpec((256, 512), lambda i, j, k: (i, k)) + kernel_fn, (), _ = block_spec_lib.pull_block_spec( + f, + block_spec, + grid=(2, 3, 4), + scalar_prefetch_handler=block_spec_lib.make_scalar_prefetch_handler(), + )() + y_ref[...] = kernel_fn((0, 1, 1), ()) + + y = jnp.zeros((256, 512), jnp.int32) + _, y = pl.run_state(outer)((value, y)) + np.testing.assert_array_equal(y, value[:256, 512:1024]) + + def test_get_with_squeezed_block_spec(self): + value = ( + jnp.arange((4 * 512 * 1024), dtype=jnp.int32).reshape((4, 512, 1024)) + * 2 + ) + + def outer(refs): + ref, y_ref = refs + + def f(): + return ref.get() + + block_spec = pl.BlockSpec( + (pl.Squeezed(), 256, 512), lambda i, j, k: (j, i, k) + ) + kernel_fn, (), _ = block_spec_lib.pull_block_spec( + f, + block_spec, + grid=(2, 3, 4), + scalar_prefetch_handler=block_spec_lib.make_scalar_prefetch_handler(), + )() + y_ref[...] = kernel_fn((0, 3, 1), ()) + + y = jnp.zeros((256, 512), jnp.int32) + _, y = pl.run_state(outer)((value, y)) + np.testing.assert_array_equal(y, value[3, :256, 512:1024]) + + def test_get_with_squeezed_indexer(self): + value = ( + jnp.arange((4 * 512 * 1024), dtype=jnp.int32).reshape((4, 512, 1024)) + * 2 + ) + + def outer(refs): + ref, y_ref = refs + + def f(): + return ref[3] + + block_spec = pl.BlockSpec((256, 512), lambda i, j, k: (i, k)) + kernel_fn, (), _ = block_spec_lib.pull_block_spec( + f, + block_spec, + grid=(2, 3, 4), + scalar_prefetch_handler=block_spec_lib.make_scalar_prefetch_handler(), + )() + y_ref[...] 
= kernel_fn((0, 2, 1), ()) + + y = jnp.zeros((256, 512), jnp.int32) + _, y = pl.run_state(outer)((value, y)) + np.testing.assert_array_equal(y, value[3, :256, 512:1024]) + + def test_random_noise(self): + key = jax.random.key(0, impl='threefry2x32') + + def f(key): + return jax.random.uniform(key, (512, 512), dtype=jnp.float32) + + f2, new_values, scalar_prefetch_values = block_spec_lib.get_fusion_values( + f, key + ) + self.assertEmpty(new_values) + self.assertEmpty(scalar_prefetch_values) + + block_spec = pl.BlockSpec((128, 256), lambda i, j: (i, j)) + kernel_fn, (value_block_specs, key_block_spec), _ = ( + block_spec_lib.pull_block_spec( + f2, + block_spec, + grid=(4, 2), + scalar_prefetch_handler=block_spec_lib.make_scalar_prefetch_handler(), + )(new_values, key) + ) + self.assertEmpty(value_block_specs) + self.assertEqual(key_block_spec.memory_space, pl.MemorySpace.KEY) + self.assertIsNone(key_block_spec.block_shape) + + @jax.jit + def gen(idx): + k = key + for i in idx: + k = jax.random.fold_in(k, i) + return jax.random.uniform(k, (128, 256), dtype=jnp.float32) + + for i in range(4): + for j in range(2): + out = kernel_fn((i, j), scalar_prefetch_values, (), key) + out_ref = gen((i, j)) + np.testing.assert_array_equal(out, out_ref) class PullBlockSpecHOPTest(jtu.JaxTestCase): @@ -769,6 +1060,41 @@ def f(x): kernel_fn((0, 0, 0, 0), scalar_prefetch_values, (), x), relu_x ) + def test_pull_block_spec_handles_closed_over_constants(self): + x = jnp.ones((2, 512, 512)) + i = jnp.array(1) + + def f(): + return x[i] + + f2, new_values, scalar_prefetch_values = block_spec_lib.get_fusion_values(f) + self.assertLen(new_values, 1) + self.assertLen(scalar_prefetch_values, 1) + + block_spec = pl.BlockSpec( + (None, 1, 128, 128), lambda i, j, k, l, _: (i, j, k, l) + ) + kernel_fn, (value_block_specs,), _ = block_spec_lib.pull_block_spec( + f2, + block_spec, + grid=(2, 2, 4, 4), + scalar_prefetch_handler=block_spec_lib.make_scalar_prefetch_handler(), + )(new_values) + self.assertLen(value_block_specs, 1) + scalar_prefetch_values = jax.tree.map( + lambda x: x[None], scalar_prefetch_values + ) + fn = lambda x: kernel_fn((0, 0, 0, 0), scalar_prefetch_values, x) + new_values_type = (jax.ShapeDtypeStruct((1, 128, 128), jnp.float32),) + # Try pulling again + # This should not raise an error. 
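# (Here `fn` closes over the previously pulled `kernel_fn` and the reshaped
#  scalar-prefetch values, i.e. it is itself a function with closed-over
#  constants, which is the case this test is named after.)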
+ _ = block_spec_lib.pull_block_spec( + fn, + block_spec, + grid=(1,), + scalar_prefetch_handler=block_spec_lib.make_scalar_prefetch_handler(), + )(new_values_type) + class PushBlockSpecTest(parameterized.TestCase): @@ -800,6 +1126,32 @@ def f(x): out_block_spec = block_spec_lib.push_block_spec(f, block_spec)(x_type) self.assertEqual(out_block_spec.block_shape, block_spec.block_shape) + def test_push_reshape_lanes_to_sublanes(self): + def f(x): + return x.reshape((512, 32, 128)) + + x_type = jax.ShapeDtypeStruct((512, 4096), jnp.float32) + block_spec = pl.BlockSpec( + (256, 1024), lambda i, j, k: (i, k) + ) + out_block_spec = block_spec_lib.push_block_spec(f, block_spec)(x_type) + self.assertEqual(out_block_spec.block_shape, (256, 8, 128)) + self.assertTupleEqual(out_block_spec.index_map(0, 1, 2), (0, 2, 0)) + self.assertEqual(out_block_spec.index_map(3, 2, 1), (3, 1, 0)) + + def f(x): + return x.reshape((512, 16, 256)) + + x_type = jax.ShapeDtypeStruct((512, 4096), jnp.float32) + block_spec = pl.BlockSpec( + (256, 1024), lambda i, j, k: (i, k) + ) + out_block_spec = block_spec_lib.push_block_spec(f, block_spec)(x_type) + self.assertEqual(out_block_spec.block_shape, (256, 4, 256)) + self.assertTupleEqual(out_block_spec.index_map(0, 1, 2), (0, 2, 0)) + self.assertEqual(out_block_spec.index_map(3, 2, 1), (3, 1, 0)) + + if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/pallas/fusion_test.py b/tests/pallas/fusion_test.py new file mode 100644 index 000000000000..4bd02345ca62 --- /dev/null +++ b/tests/pallas/fusion_test.py @@ -0,0 +1,234 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
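The new fusion_test.py below exercises the fuser.fusible / fuser.fuse pairing. As a minimal sketch of the calling convention, condensed from the basic-fusion test that follows (nothing beyond what that test shows is assumed):

import jax
import jax.numpy as jnp
from jax.experimental.pallas import fuser

@jax.jit
@fuser.fuse
@fuser.fusible
def identity(x_fn, y_fn):
  x = x_fn()  # input fusion: produces the value of x
  return x if y_fn is None else y_fn(x)  # output fusion, if one was provided

print(identity(jnp.ones((128, 128), jnp.float32)))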
+ +from absl.testing import absltest +import jax +from jax._src import test_util as jtu +from jax.experimental.pallas import fuser +import jax.numpy as jnp +import numpy as np + +jax.config.parse_flags_with_absl() + + +class FusionTest(jtu.JaxTestCase): + + def test_basic_fusion(self): + + @jax.jit + @fuser.fuse + @fuser.fusible + def f(x_fn, y_fn): + x = x_fn() + if y_fn is None: + y_fn = lambda x: x + return y_fn(x) + + x = jax.random.normal(jax.random.key(0), (128, 128), dtype=jnp.float32) + np.testing.assert_array_equal(f(x), x) + + def test_separate_output_fusions_trivial(self): + + @fuser.fusible(output_fusion_prefix=(True, True)) + def f(x_fn, y_fn, z_fns): + x = x_fn() + y = y_fn() + if z_fns is None: + z_fns = lambda x: x, lambda x: x + z_fn1, z_fn2 = z_fns + return z_fn1(x), z_fn2(y) + + @jax.jit + @fuser.fuse + def g(x, y): + x, y = f(x, y) + return x, y * 2 + + x = jax.random.normal(jax.random.key(0), (128, 128), dtype=jnp.float32) + y = jax.random.normal(jax.random.key(1), (1, 128), dtype=jnp.float32) + x_out, y_out = g(x, y) + np.testing.assert_array_equal(x_out, x) + np.testing.assert_array_equal(y_out, y * 2) + + def test_separate_output_fusions_should_error_if_not_disjoint(self): + + @fuser.fusible(output_fusion_prefix=(True, True)) + def f(x_fn, y_fn, z_fns): + x = x_fn() + y = y_fn() + if z_fns is None: + z_fns = lambda x: x, lambda x: x + z_fn1, z_fn2 = z_fns + return z_fn1(x), z_fn2(y) + + @jax.jit + @fuser.fuse + def g(x, y): + x_res, y_res = f(x, y) + return x_res + y_res + + x = jax.random.normal(jax.random.key(0), (128, 128), dtype=jnp.float32) + y = jax.random.normal(jax.random.key(1), (128, 128), dtype=jnp.float32) + + with self.assertRaisesRegex( + ValueError, + "Outputs must be disjoint in order to use separate output fusions", + ): + g(x, y) + + def test_separate_output_fusions_allows_permute(self): + + @fuser.fusible(output_fusion_prefix=(True, True)) + def f(x_fn, y_fn, z_fns): + x = x_fn() + y = y_fn() + if z_fns is None: + z_fns = lambda x: x, lambda x: x + z_fn1, z_fn2 = z_fns + return z_fn1(x), z_fn2(y) + + @jax.jit + @fuser.fuse + def g(x, y): + x_res, y_res = f(x, y) + return y_res * 2, x_res + + x = jax.random.normal(jax.random.key(0), (128, 128), dtype=jnp.float32) + y = jax.random.normal(jax.random.key(1), (1, 128), dtype=jnp.float32) + y_out, x_out = g(x, y) + np.testing.assert_array_equal(x_out, x) + np.testing.assert_array_equal(y_out, y * 2) + + def test_separate_output_fusions_with_nesting(self): + + @fuser.fusible(output_fusion_prefix=(True, True)) + def f(x_fn, y_fn, z_fns): + x = x_fn() + y = y_fn() + if z_fns is None: + z_fns = lambda x: x, lambda x: x + z_fn1, z_fn2 = z_fns + return z_fn1(x), z_fn2(y) + + @jax.jit + @fuser.fuse + def g(x, y): + x_res, y_res = f(x, y) + return (x_res * 2, x_res + x_res), y_res + + x = jax.random.normal(jax.random.key(0), (128, 128), dtype=jnp.float32) + y = jax.random.normal(jax.random.key(1), (1, 128), dtype=jnp.float32) + (x1_out, x2_out), y_out = g(x, y) + np.testing.assert_array_equal(x1_out, x * 2) + np.testing.assert_array_equal(x2_out, x + x) + np.testing.assert_array_equal(y_out, y) + + def test_separate_output_fusions_with_nesting_and_permutation(self): + + @fuser.fusible(output_fusion_prefix=(True, True)) + def f(x_fn, y_fn, z_fns): + x = x_fn() + y = y_fn() + if z_fns is None: + z_fns = lambda x: x, lambda x: x + z_fn1, z_fn2 = z_fns + return z_fn1(x), z_fn2(y) + + @jax.jit + @fuser.fuse + def g(x, y): + x_res, y_res = f(x, y) + return y_res, (x_res * 2, x_res + x_res) + + x = 
jax.random.normal(jax.random.key(0), (128, 128), dtype=jnp.float32) + y = jax.random.normal(jax.random.key(1), (1, 128), dtype=jnp.float32) + y_out, (x1_out, x2_out) = g(x, y) + np.testing.assert_array_equal(x1_out, x * 2) + np.testing.assert_array_equal(x2_out, x + x) + np.testing.assert_array_equal(y_out, y) + + def test_separate_output_fusions_with_deep_output_mask(self): + + @fuser.fusible(output_fusion_prefix=(True, (True, True))) + def f(x_fn, y_fn, z_fn, o_fns): + x = x_fn() + y = y_fn() + z = z_fn() + if o_fns is None: + o_fns = lambda x: x, (lambda x: x, lambda x: x) + o_fn1, (o_fn2, o_fn3) = o_fns + return o_fn1(x), (o_fn2(y), o_fn3(z)) + + @jax.jit + @fuser.fuse + def g(x, y, z): + x_res, (y_res, z_res) = f(x, y, z) + return (x_res * 2, (y_res, z_res + z_res)) + + x = jax.random.normal(jax.random.key(0), (128, 128), dtype=jnp.float32) + y = jax.random.normal(jax.random.key(1), (1, 128), dtype=jnp.float32) + z = jax.random.normal(jax.random.key(1), (128, 1), dtype=jnp.float32) + x_out, (y_out, z_out) = g(x, y, z) + np.testing.assert_array_equal(x_out, x * 2) + np.testing.assert_array_equal(y_out, y) + np.testing.assert_array_equal(z_out, z + z) + + def test_separate_output_fusions_with_reused_value(self): + + @fuser.fusible(output_fusion_prefix=(True, True)) + def f(x_fn, y_fn, z_fns): + x = x_fn() + y = y_fn() + if z_fns is None: + z_fns = lambda x: x, lambda x: x + z_fn1, z_fn2 = z_fns + return z_fn1(x), z_fn2(y) + + @jax.jit + @fuser.fuse + def g(x, y, a): + x_res, y_res = f(x, y) + return y_res + a, (x_res * 2, x_res + x_res + a) + + x = jax.random.normal(jax.random.key(0), (128, 128), dtype=jnp.float32) + y = jax.random.normal(jax.random.key(1), (1, 128), dtype=jnp.float32) + a = jax.random.normal(jax.random.key(1), (1, 128), dtype=jnp.float32) + y_out, (x1_out, x2_out) = g(x, y, a) + np.testing.assert_array_equal(x1_out, x * 2) + np.testing.assert_array_equal(x2_out, x + x + a) + np.testing.assert_array_equal(y_out, y + a) + + def test_empty_fusion(self): + + @fuser.fusible + def f(x_fn, y_fn): + x = x_fn() + if y_fn is None: + y_fn = lambda x: x + return y_fn(x) + + @jax.jit + @fuser.fuse + def g(x, a): + _ = f(x) + return a + + x = jax.random.normal(jax.random.key(0), (128, 128), dtype=jnp.float32) + a = jax.random.normal(jax.random.key(1), (128, 128), dtype=jnp.float32) + y_out = g(x, a) + np.testing.assert_array_equal(y_out, a) + + +if __name__ == "__main__": + absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/pallas/gpu_ops_test.py b/tests/pallas/gpu_ops_test.py index a33760cbfa86..edda5cf686db 100644 --- a/tests/pallas/gpu_ops_test.py +++ b/tests/pallas/gpu_ops_test.py @@ -153,7 +153,7 @@ def setUp(self): batch_size=(1, 2), seq_len=(128, 384), num_heads=(1, 2, 8), - head_dim=(32, 64, 128), + head_dim=(32, 64, 72, 128), block_sizes=( (("block_q", 128), ("block_k", 128)), (("block_q", 64), ("block_k", 64)), @@ -226,14 +226,14 @@ def impl(q, k, v): batch_size=(1, 2), seq_len=(128, 384), num_heads=(1, 2), - head_dim=(32, 64, 128,), + head_dim=(32, 64, 72, 128,), block_sizes=( ( ("block_q", 128), ("block_k", 128), - ("block_q_dkv", 128), - ("block_kv_dkv", 128), - ("block_q_dq", 128), + ("block_q_dkv", 32), + ("block_kv_dkv", 32), + ("block_q_dq", 32), ("block_kv_dq", 128), ), ( @@ -248,8 +248,8 @@ def impl(q, k, v): ("block_q", 64), ("block_k", 128), ("block_q_dkv", 64), - ("block_kv_dkv", 128), - ("block_q_dq", 128), + ("block_kv_dkv", 32), + ("block_q_dq", 32), ("block_kv_dq", 64), ), ), @@ -267,6 +267,17 @@ def test_fused_attention_bwd( causal, 
use_segment_ids, ): + if jtu.is_cuda_compute_capability_equal("8.0") and all([ + dict(block_sizes)["block_q"] == 128, + batch_size == 2, + num_heads == 2, + head_dim == 128, + causal, + not use_segment_ids + ]): + # TODO(b/416306534) + self.skipTest("Precision issues after CUDA 12.8.1 upgrade") + k1, k2, k3 = random.split(random.key(0), 3) q = random.normal( k1, (batch_size, seq_len, num_heads, head_dim), dtype=jnp.float16 @@ -302,6 +313,30 @@ def f_ref(q, k, v): self.assertAllClose(dk, dk_ref, atol=5e-2) self.assertAllClose(dv, dv_ref, atol=5e-2) + def test_return_residuals_not_differentiable(self): + batch_size, seq_len, num_heads, head_dim = 2, 128, 2, 128 + causal = False + k1, k2, k3 = random.split(random.key(0), 3) + q = random.normal( + k1, (batch_size, seq_len, num_heads, head_dim), dtype=jnp.float16 + ) + k = random.normal( + k2, (batch_size, seq_len, num_heads, head_dim), dtype=jnp.float16 + ) + v = random.normal( + k3, (batch_size, seq_len, num_heads, head_dim), dtype=jnp.float16 + ) + segment_ids = None + + def f(q, k, v): + return attention.mha(q, k, v, causal=causal, segment_ids=segment_ids, + interpret=self.INTERPRET, + return_residuals=True)[0].sum() + + with self.assertRaisesRegex(ValueError, "Kernel differentiation is not" + " supported if return_residuals is True."): + _ = jax.grad(f, argnums=(0, 1, 2))(q, k, v) + class FusedAttentionInterpretTest(FusedAttentionTest): INTERPRET = True diff --git a/tests/pallas/gpu_paged_attention_test.py b/tests/pallas/gpu_paged_attention_test.py index 081051f15dae..1b778c787a6d 100644 --- a/tests/pallas/gpu_paged_attention_test.py +++ b/tests/pallas/gpu_paged_attention_test.py @@ -44,9 +44,11 @@ def _generate_qkv( k_pages = jax.random.normal( k1, (num_kv_heads, total_pages, page_size, head_dim), dtype=dtype ) + k_pages = k_pages / jnp.linalg.norm(k_pages, axis=-1)[..., None] v_pages = jax.random.normal( k2, (num_kv_heads, total_pages, page_size, head_dim), dtype=dtype ) + v_pages = v_pages / jnp.linalg.norm(v_pages, axis=-1)[..., None] block_tables = jnp.arange( batch_size * max_num_blocks_per_seq, dtype=jnp.int32 @@ -54,6 +56,7 @@ def _generate_qkv( block_tables = jax.random.permutation(k3, block_tables, independent=True) block_tables = block_tables.reshape(batch_size, max_num_blocks_per_seq) q = jax.random.normal(k4, (batch_size, num_heads, head_dim), dtype=dtype) + q = q / jnp.linalg.norm(q, axis=-1)[..., None] return q, k_pages, v_pages, block_tables @@ -72,6 +75,17 @@ def fn(_block_tables, _pages): return out +def _quantize(x: jax.Array, dtype=jnp.int8): + if isinstance(dtype, jnp.floating): + max_val = jnp.astype(jnp.finfo(dtype).max, x.dtype) + else: + max_val = 127 + x_scale = jnp.max(jnp.abs(x), axis=-1) / (0.95 * max_val) + x_quant = (x / x_scale[..., None]) + if isinstance(dtype, jnp.floating): + x_quant = jnp.rint(x_quant) + return x_quant.astype(dtype), x_scale.astype(x.dtype) + @jtu.with_config(jax_traceback_filtering="off") class PallasBaseTest(jtu.JaxTestCase): @@ -93,7 +107,6 @@ def setUp(self): super().setUp() - class PagedAttentionKernelTest(PallasBaseTest): def setUp(self): @@ -154,6 +167,83 @@ def test_paged_attention( self.assertArraysAllClose(o, o_ref, rtol=5e-2, atol=5e-2) + @jtu.sample_product( + dtype=(jnp.float16,), + page_size=(8, 16, 32), + num_kv_heads=(1, 2), + q_kv_head_ratio=(2, 16, 20), + head_dim=(32, 64), + block_h=(16, 32), + pages_per_compute_block=(4, 8), + k_splits=(4, 16), + attn_logits_soft_cap=(None,), + quantize_k=(True, False), + quantize_v=(True, False), + quant_dtype=(jnp.float8_e5m2, 
jnp.float8_e4m3fn, jnp.int8), + ) + def test_quantized_paged_attention( + self, + dtype, + page_size, + num_kv_heads, + q_kv_head_ratio, + head_dim, + block_h, + pages_per_compute_block, + k_splits, + attn_logits_soft_cap, + quantize_k, + quantize_v, + quant_dtype, + ): + if not quantize_k and not quantize_v: + self.skipTest("Skipping since neither (k, v) quantization requested.") + if (quant_dtype == jnp.float8_e4m3fn + and not jtu.is_cuda_compute_capability_at_least("8.9")): + self.skipTest("Skipping since float8_e4m3fn is not supported on < sm89") + max_kv_len = 2048 + seq_lens = np.asarray([3, 256, 513, 1023, 2048], dtype=jnp.int32) + q, k_pages, v_pages, block_tables = _generate_qkv( + seq_lens.shape[0], + page_size, + max_kv_len, + num_kv_heads, + num_kv_heads * q_kv_head_ratio, + head_dim, + jax.random.key(0), + dtype, + ) + k = _reconstruct_kv(block_tables, k_pages) + v = _reconstruct_kv(block_tables, v_pages) + + k_, k_scales = (_quantize(k_pages, quant_dtype) + if quantize_k else (k_pages, None)) + v_, v_scales = (_quantize(k_pages, quant_dtype) + if quantize_v else (v_pages, None)) + + o = paged_attention.paged_attention( + q, + k_, + v_, + block_tables, + seq_lens, + k_scales_pages=k_scales, + v_scales_pages=v_scales, + block_h=block_h, + pages_per_compute_block=pages_per_compute_block, + k_splits=k_splits, + attn_logits_soft_cap=attn_logits_soft_cap, + interpret=self.INTERPRET, + ) + + o_ref = paged_attention.paged_attention_reference(q, k, v, lengths=seq_lens) + + error = (jnp.linalg.norm((o - o_ref).astype(jnp.float32), axis=-1) + / jnp.linalg.norm(o_ref.astype(jnp.float32))) + + admissible_error = 3e-1 + self.assertLessEqual(jnp.mean(error), admissible_error) + class PagedAttentionInterpretTest(PagedAttentionKernelTest): INTERPRET = True diff --git a/tests/pallas/gpu_pallas_distributed_test.py b/tests/pallas/gpu_pallas_distributed_test.py new file mode 100644 index 000000000000..163adc385b23 --- /dev/null +++ b/tests/pallas/gpu_pallas_distributed_test.py @@ -0,0 +1,152 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
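The distributed tests below wrap per-device Pallas kernels in shard_map over a 1-D mesh. A minimal standalone sketch of that wrapping pattern, assuming at least two visible devices and using a trivial per-shard function in place of a pallas_call:

import jax
import jax.numpy as jnp
from jax.experimental import shard_map

P = jax.sharding.PartitionSpec
mesh = jax.sharding.Mesh(jax.devices()[:2], ["x"])

def per_shard(x):
  # Stand-in for the pl.pallas_call invocations used in the tests.
  return x + jax.lax.axis_index("x").astype(x.dtype)

y = jax.jit(
    shard_map.shard_map(
        per_shard, mesh, in_specs=P("x"), out_specs=P("x"), check_rep=False
    )
)(jnp.arange(16.0))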
+ +"""Tests for distributed pallas GPU operations.""" + +import functools +import os + +import jax +from jax import lax +from jax._src import test_util as jtu +from jax._src import test_multiprocess as jt_multiprocess +from jax.experimental import pallas as pl +from jax.experimental import shard_map +from jax.experimental.pallas import mosaic_gpu as plgpu +import jax.experimental.mosaic.gpu as mgpu +import jax.numpy as jnp +import numpy as np + + +P = jax.sharding.PartitionSpec +partial = functools.partial + + +class PallasCallRemoteDMATest(jt_multiprocess.MultiProcessTest): + + def setUp(self): + if (not jtu.test_device_matches(["cuda"]) or + not jtu.is_cuda_compute_capability_at_least("9.0")): + self.skipTest("Only works on GPU with capability >= sm90") + if not mgpu.supports_cross_device_collectives(): + self.skipTest("NVSHMEM library unavailable.") + if jax.process_count() == 1: + self.skipTest("Test requires multiple processes.") + if os.environ.get("XLA_PYTHON_CLIENT_ALLOCATOR", "") == "platform": + self.skipTest("NVSHMEM doesn't work with the platform allocator.") + super().setUp() + + def test_basic_remote_dma(self): + if jax.process_index() > 2: + return # Only 2 processes needed. + def kernel(x_ref, y_ref, ready_sem, recv_sem): + other_dev_id = 1 - lax.axis_index('x') + y_ref[...] = x_ref[...] + pl.semaphore_signal(ready_sem, device_id=other_dev_id, + device_id_type=pl.DeviceIdType.LOGICAL) + pl.semaphore_wait(ready_sem) + neighbor_ptr = plgpu.remote_ref( + y_ref, other_dev_id, device_id_type=pl.DeviceIdType.LOGICAL + ) + neighbor_ptr[...] = x_ref[...] + pl.semaphore_signal(recv_sem, device_id=other_dev_id, + device_id_type=pl.DeviceIdType.LOGICAL) + pl.semaphore_wait(recv_sem) + + x = jnp.arange(2 * 8 * 128.0, dtype=jnp.float32).reshape((2 * 8, 128)) + def body(x): + return pl.pallas_call( + kernel, + in_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)], + out_specs=pl.BlockSpec(memory_space=plgpu.GMEM), + out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32), + scratch_shapes=[ + plgpu.SemaphoreType.REGULAR, + plgpu.SemaphoreType.REGULAR, + ], + )(x) + + devices = jax.devices()[:2] + mesh = jax.sharding.Mesh(devices, ['x']) + y = jax.jit( + shard_map.shard_map( + body, mesh, in_specs=P('x'), out_specs=P('x'), check_rep=False, + ) + )(x) + + expected = x[8:] if jax.process_index() == 0 else x[:8] + np.testing.assert_allclose(y.addressable_shards[0].data, expected) + + def test_wait_twice(self): + if jax.process_index() > 2: + return # Only 2 processes needed. + + def kernel(y_ref, sem): + other_dev_id = 1 - lax.axis_index('x') + pl.semaphore_signal(sem, 2, device_id=other_dev_id, + device_id_type=pl.DeviceIdType.LOGICAL) + pl.semaphore_wait(sem) + pl.semaphore_wait(sem) + y_ref[...] 
= jnp.ones_like(y_ref) + + kernel_call = pl.pallas_call( + kernel, + out_specs=pl.BlockSpec(memory_space=plgpu.GMEM), + out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32), + scratch_shapes=[plgpu.SemaphoreType.REGULAR], + ) + + devices = jax.devices()[:2] + mesh = jax.sharding.Mesh(devices, ['x']) + y = jax.jit( + shard_map.shard_map( + kernel_call, mesh, in_specs=(), out_specs=P(None), check_rep=False, + ) + )() + np.testing.assert_allclose(y, jnp.ones_like(y)) + + def test_permuted_mesh(self): + def kernel(y_ref, sem): + other_dev_id = 1 - lax.axis_index('x') + pl.semaphore_signal(sem, 1, device_id=other_dev_id, + device_id_type=pl.DeviceIdType.LOGICAL) + pl.semaphore_wait(sem) + + kernel_call = pl.pallas_call( + kernel, + out_specs=pl.BlockSpec(memory_space=plgpu.GMEM), + out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32), + scratch_shapes=[plgpu.SemaphoreType.REGULAR], + ) + mesh = jax.sharding.Mesh(jax.devices()[::-1], ['x']) # Reverse the devices. + f = jax.jit( + shard_map.shard_map( + kernel_call, mesh, in_specs=(), out_specs=P(None), check_rep=False, + ) + ) + msg = ( + 'Mosaic GPU only supports meshes with device ordering that follows' + ' row-major device ids.' + ) + with self.assertRaisesRegex(NotImplementedError, msg): + f() + + +if __name__ == '__main__': + # This test doesn't work with the platform allocator, so we override it + # if it's ran alone. If it's part of a larger test suite and the platform + # allocator is used, setUp will skip the test. + os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.01' + os.environ['XLA_PYTHON_CLIENT_ALLOCATOR'] = 'default' + jt_multiprocess.main() diff --git a/tests/pallas/indexing_test.py b/tests/pallas/indexing_test.py index c3f3fa6e80a8..932076645c1e 100644 --- a/tests/pallas/indexing_test.py +++ b/tests/pallas/indexing_test.py @@ -14,7 +14,6 @@ from __future__ import annotations import sys -import unittest from absl.testing import absltest from absl.testing import parameterized @@ -32,11 +31,7 @@ else: pltpu = None -try: - import hypothesis as hp -except (ModuleNotFoundError, ImportError): - raise unittest.SkipTest("tests depend on hypothesis library") - +import hypothesis as hp import hypothesis.extra.numpy as hnp import hypothesis.strategies as hps @@ -95,7 +90,7 @@ def array_indexer_strategy(draw, shape) -> jax.Array: @hps.composite def indexer_strategy(draw, dim, int_indexer_shape - ) -> int | Slice | jax.Array: + ) -> int | Slice | jax.Array: return draw(hps.one_of( int_indexer_strategy(dim), slice_indexer_strategy(dim), @@ -104,12 +99,12 @@ def indexer_strategy(draw, dim, int_indexer_shape @hps.composite -def nd_indexer_strategy(draw, shape) -> NDIndexer: +def nd_indices_strategy(draw, shape) -> tuple[int | Slice | jax.Array, ...]: num_indices = draw(hps.integers(min_value=0, max_value=len(shape))) int_indexer_shape = draw(hnp.array_shapes()) indices = tuple(draw(indexer_strategy(dim, int_indexer_shape)) for dim in shape[:num_indices]) - return NDIndexer.from_indices_shape(indices, shape) + return indices class PallasBaseTest(jtu.JaxTestCase): @@ -127,6 +122,7 @@ def pallas_call(cls, *args, **kwargs): return pl.pallas_call(*args, interpret=cls.INTERPRET, **kwargs) +@jtu.thread_unsafe_test_class() # hypothesis is not thread safe class IndexerTest(jtu.JaxTestCase): """These are unit tests for the indexer logic, not using pallas_call.""" @@ -217,12 +213,15 @@ def test_indexer_with_all_types(self): indices = (ds(0, 2), np.arange(5)[:, None], np.arange(4)[None]) indexer = NDIndexer.from_indices_shape(indices, shape) - 
self.assertTupleEqual(indexer.get_indexer_shape(), (5, 4, 2)) + self.assertTupleEqual(indexer.get_indexer_shape(), (2, 5, 4)) @hp.given(hps.data()) + @hp.settings(suppress_health_check=[hp.HealthCheck.too_slow]) # ASAN is slow def test_ndindexer(self, data): shape = data.draw(hnp.array_shapes()) - indexer = data.draw(nd_indexer_strategy(shape)) + indices = data.draw(nd_indices_strategy(shape)) + indexer = NDIndexer.from_indices_shape(indices, shape) + is_int_indexer = [not isinstance(idx, Slice) for idx in indexer.indices] rest_indexers, int_indexers = util.partition_list( is_int_indexer, indexer.indices @@ -234,18 +233,15 @@ def test_ndindexer(self, data): self.assertTupleEqual( indexer.int_indexer_shape, expected_int_indexer_shape ) + for idx in rest_indexers: self.assertIsInstance(idx, (np.ndarray, Slice)) if isinstance(idx, np.ndarray): self.assertTupleEqual(idx.shape, ()) self.assertEqual(idx.dtype, np.dtype("int32")) - rest_shape = tuple( - r.size for r in rest_indexers if not isinstance(r, np.ndarray) - ) - self.assertTupleEqual((*indexer.int_indexer_shape, *rest_shape), - indexer.get_indexer_shape()) +@jtu.thread_unsafe_test_class() # hypothesis is not thread safe class IndexerOpsTest(PallasBaseTest): def test_multi_indexing_interpreter_only(self): @@ -373,11 +369,13 @@ def permute_columns_in_row_kernel(left, right, new_left, new_right): def test_vmap_nd_indexing(self, data): self.skipTest("TODO(necula): enable this test; was in jax_triton.") vmap_shape = data.draw(hnp.array_shapes(min_dims=1, max_dims=3, min_side=2), - label="vmap_shape") + label="vmap_shape") el_shape = data.draw(hnp.array_shapes(min_dims=2), label="el_shape") # TODO(sharadmv,apaszke): enable rank 0 and rank 1 Refs # hp.assume(len(el_shape) >= 2) - nd_indexer = data.draw(nd_indexer_strategy(el_shape), label="nd_indexer") + nd_indexer = NDIndexer.from_indices_shape( + data.draw(nd_indices_strategy(el_shape), label="nd_indexer"), + el_shape) expected_shape = jax.eval_shape(lambda x: x[nd_indexer], jax.ShapeDtypeStruct(el_shape, jnp.float32)) @@ -390,7 +388,7 @@ def kernel(x_ref, y_ref): shape = el_shape for vmap_dim in vmap_shape[::-1]: index = data.draw(hps.integers(min_value=0, - max_value=max(0, len(shape) - 2)), + max_value=max(0, len(shape) - 2)), label="index") # hp.assume(index <= max(0, len(shape) - 2)) # TODO(sharadmv,apaszke): enable vmapping over batch axes in 2 minormost @@ -641,6 +639,34 @@ def kernel(x_ref, indices, y_ref): )(x, indices) self.assertAllClose(res[:, start : start + 1, :], x, atol=0., rtol=0.) + def test_scalar_load_from_vmem(self): + if not jtu.is_device_tpu_at_least(4): + self.skipTest("Requires TPU v4 or later") + def kernel(x_ref, o_ref, sem_ref): + o_ref[...] = jnp.zeros_like(o_ref) + scalar_val = x_ref[1, 2] + # Use scalar_val in both async_copy and store. 
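# (With the input below, scalar_val == x[1, 2] == 6: row 6 of the output is
#  filled with 6s by the store, and the async copy then replicates it into
#  row 7, matching the `expected` array constructed at the end of the test.)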
+ o_ref[scalar_val] = jnp.ones_like(o_ref[0]) * scalar_val + desc = pltpu.make_async_copy( + o_ref.at[scalar_val], + o_ref.at[scalar_val + 1], + sem_ref, + ) + desc.start() + desc.wait() + + x = jnp.array([[1, 2, 3], [4, 5, 6]], dtype=jnp.int32) + res = self.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct((8, 8, 128), jnp.int32), + grid=(1,), + scratch_shapes=[pltpu.SemaphoreType.DMA] + )(x) + expected = jnp.zeros_like(res) + expected = expected.at[6].set(jnp.ones((8, 128), jnp.int32) * 6) + expected = expected.at[7].set(jnp.ones((8, 128), jnp.int32) * 6) + self.assertArraysEqual(res, expected) + class IndexerOpsInterpretTest(IndexerOpsTest): INTERPRET = True @@ -662,18 +688,18 @@ class IndexerOpsInterpretTest(IndexerOpsTest): ((4, 3), lambda arr, a, b, c, d: arr[a, 2]), # slice + 1-D array ((4, 3), lambda arr, a, b, c, d: arr[a, :]), - # ((4, 3), lambda arr, a, b, c, d: arr[:, a]), + ((4, 3), lambda arr, a, b, c, d: arr[:, a]), ((6, 8, 3), lambda arr, a, b, c, d: arr[c, ::3]), - # ((8, 6, 3), lambda arr, a, b, c, d: arr[::3, c]), - # ((8, 8, 3), lambda arr, a, b, c, d: arr[::4, ::2, a]), - # ((8, 8, 3), lambda arr, a, b, c, d: arr[::4, a, ::2]), + ((8, 6, 3), lambda arr, a, b, c, d: arr[::3, c]), + ((8, 8, 3), lambda arr, a, b, c, d: arr[::4, ::2, a]), + ((8, 8, 3), lambda arr, a, b, c, d: arr[::4, a, ::2]), ((8, 8, 3, 7), lambda arr, a, b, c, d: arr[b, ::4, a, ::2]), ((3, 8, 8, 7), lambda arr, a, b, c, d: arr[b, a, ::4, ::2]), # ((8, 8, 3, 7), lambda arr, a, b, c, d: arr[::4, b, a, ::2]), ((16, 3, 6, 2), lambda arr, a, b, c, d: arr[::4, a, 1::2, b]), ((8, 8, 3, 6), lambda arr, a, b, c, d: arr[b, ::4, a, a]), # slice + array w/ broadcasting - ((8, 8, 3, 6), lambda arr, a, b, c, d: \ + ((8, 8, 3, 6), lambda arr, a, b, c, d: arr[b[:, None], ::4, a[None], a[:, None]]), # integer + slice + 1-D array ((5, 8, 8, 3), lambda arr, a, b, c, d: arr[2, ::4, ::2, a]), diff --git a/tests/pallas/mgpu_attention_test.py b/tests/pallas/mgpu_attention_test.py index cf8ed30925bf..f86793174c16 100644 --- a/tests/pallas/mgpu_attention_test.py +++ b/tests/pallas/mgpu_attention_test.py @@ -16,10 +16,13 @@ import os +import contextlib import numpy as np from absl.testing import absltest, parameterized from jax._src import config from jax._src import test_util as jtu +from jax._src.lib import cuda_versions +from jax._src.pallas import pallas_call import jax.numpy as jnp # pylint: disable=g-import-not-at-top @@ -47,6 +50,9 @@ def setUp(self): if (not jtu.test_device_matches(["cuda"]) or not jtu.is_cuda_compute_capability_equal("9.0")): self.skipTest("Only works on GPU with capability sm90a") + context_stack = contextlib.ExitStack() + context_stack.enter_context(pallas_call._PALLAS_USE_MOSAIC_GPU(True)) + self.addCleanup(context_stack.close) @parameterized.product( batch_size=(1, 4), @@ -58,10 +64,13 @@ def setUp(self): (4, 4), ), # MHA head_dim=(64, 128, 256), + blocks=((64, 64), (64, 128), (128, 64)), attention_impl=( attention_mgpu.attention, attention_mgpu.attention_with_pipeline_emitter, ), + save_residuals=(True,), + causal=(True, False,), ) def test_flash_attention( self, @@ -70,24 +79,105 @@ def test_flash_attention( kv_seq_len, num_q_and_kv_heads, head_dim, + blocks, attention_impl, + save_residuals, + causal, ): + cuda_runtime_version = cuda_versions.cuda_runtime_get_version() + # TODO(pobudzey): Undo when we upgrade to cuda 12.9.1. 
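# (cuda_runtime_get_version() follows the CUDA runtime convention of
#  major * 1000 + minor * 10, so 12080 corresponds to CUDA 12.8.)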
+ if causal and (cuda_runtime_version >= 12080 and cuda_runtime_version < 12091): + self.skipTest("Skipping because of ptxas miscompilation.") + + if causal and attention_impl == attention_mgpu.attention_with_pipeline_emitter: + self.skipTest("Pipeline emitter does not support causal attention.") + + if head_dim >= 256 and max(blocks) >= 128: + self.skipTest("Head dim too large for block sizes.") + num_q_heads, num_kv_heads = num_q_and_kv_heads + block_q, block_kv = blocks k1, k2, k3 = jax.random.split(jax.random.key(42), 3) q = jax.random.normal(k1, (batch_size, q_seq_len, num_q_heads, head_dim), jnp.float16) k = jax.random.normal(k2, (batch_size, kv_seq_len, num_kv_heads, head_dim), jnp.float16) v = jax.random.normal(k3, (batch_size, kv_seq_len, num_kv_heads, head_dim), jnp.float16) - out = attention_impl( + out, *res = attention_impl( q, k, v, attention_mgpu.TuningConfig( - block_q=64, block_kv=64, max_concurrent_steps=2 + block_q=block_q, block_kv=block_kv, max_concurrent_steps=2, causal=causal ), + save_residuals=save_residuals, ) - out_ref = attention_mgpu.attention_reference(q, k, v) + out_ref, *res_ref = attention_mgpu.attention_reference( + q, k, v, causal=causal, save_residuals=save_residuals) np.testing.assert_allclose(out, out_ref, atol=2e-3, rtol=1e-3) + if save_residuals: + (lse,) = res[0] + (lse_ref,) = res_ref[0] + np.testing.assert_allclose(lse, lse_ref, atol=2e-3, rtol=1e-3) + + @parameterized.product( + batch_size=(3,), + seq_lens=((512, 512), (3584, 4096)), + num_q_and_kv_heads=( + (4, 4), # MHA + (4, 1), # MQA + (6, 3), # GQA + ), + bwd_blocks = ( + (64, 64, 64, 64), + (64, 128, 128, 64), + (128, 128, 128, 128), + ), + head_dim=(64, 128, 256), + ) + def test_bwd_flash_attention( + self, + batch_size, + seq_lens, + num_q_and_kv_heads, + bwd_blocks, + head_dim, + ): + num_q_heads, num_kv_heads = num_q_and_kv_heads + kv_seq_len, q_seq_len = seq_lens + block_q_dq, block_kv_dq, block_q_dkv, block_kv_dkv = bwd_blocks + compute_wgs = 2 if head_dim <= 128 else 1 + k1, k2, k3 = jax.random.split(jax.random.key(42), 3) + q = jax.random.normal(k1, (batch_size, q_seq_len, num_q_heads, head_dim), jnp.float16) + k = jax.random.normal(k2, (batch_size, kv_seq_len, num_kv_heads, head_dim), jnp.float16) + v = jax.random.normal(k3, (batch_size, kv_seq_len, num_kv_heads, head_dim), jnp.float16) + + def f(q, k, v): + return attention_mgpu.attention( + q, + k, + v, + attention_mgpu.TuningConfig( + block_q=block_q_dq, block_kv=block_kv_dq, + max_concurrent_steps=2, compute_wgs_bwd=compute_wgs, + block_q_dkv=block_q_dkv, block_kv_dkv=block_kv_dkv, + block_q_dq=block_q_dq, block_kv_dq=block_kv_dq, + ) + ).sum() + + def f_ref(q, k, v): + return attention_mgpu.attention_reference(q, k, v).sum() + + try: + # TODO(pobudzey): Replace with `jtu.check_grads` when it's fixed. 
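# (Both f and f_ref reduce the attention output to a scalar with .sum(), so
#  jax.grad yields full dq/dk/dv cotangents that can be compared directly.)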
+ dq, dk, dv = jax.grad(f, argnums=(0, 1, 2))(q, k, v) + dq_ref, dk_ref, dv_ref = jax.grad(f_ref, argnums=(0, 1, 2))(q, k, v) + + self.assertAllClose(dq, dq_ref, atol=7e-2) + self.assertAllClose(dk, dk_ref, atol=7e-2) + self.assertAllClose(dv, dv_ref, atol=5e-2) + except ValueError as e: + if "exceeds available shared memory" in e.args[0]: + self.skipTest("Not enough SMEM for this configuration.") if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/pallas/mgpu_collective_matmul_test.py b/tests/pallas/mgpu_collective_matmul_test.py new file mode 100644 index 000000000000..b1b7e0ffd118 --- /dev/null +++ b/tests/pallas/mgpu_collective_matmul_test.py @@ -0,0 +1,138 @@ +# Copyright 2025 The JAX Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Test different parameterizations of our Mosaic GPU collective matmul.""" + +import contextlib +import functools +import os + +from absl.testing import parameterized # pylint: disable=g-multiple-import +import jax +from jax import lax +from jax import random +from jax._src import test_multiprocess as jt_multiprocess +from jax._src import test_util as jtu +from jax._src.pallas import pallas_call +from jax.experimental.mosaic import gpu as mgpu +from jax.experimental.pallas.ops.gpu import collective_matmul_mgpu +from jax.experimental import shard +import jax.numpy as jnp +import numpy as np + + +P = jax.sharding.PartitionSpec + + +@jtu.with_config(jax_traceback_filtering="off") +class CollectiveMatmulTestCase(jtu.JaxTestCase): + + def setUp(self): + super().setUp() + if collective_matmul_mgpu is None: + self.skipTest("Mosaic GPU not available.") + if (not jtu.test_device_matches(["cuda"]) or + not jtu.is_cuda_compute_capability_equal("9.0")): + self.skipTest("Only works on GPU with capability sm90a") + if not mgpu.supports_cross_device_collectives(): + self.skipTest("NVSHMEM library unavailable.") + if jax.process_count() == 1: + self.skipTest("Test requires multiple processes.") + if os.environ.get("XLA_PYTHON_CLIENT_ALLOCATOR", "") == "platform": + self.skipTest("NVSHMEM doesn't work with the platform allocator.") + context_stack = contextlib.ExitStack() + self.addCleanup(context_stack.close) + context_stack.enter_context(pallas_call._PALLAS_USE_MOSAIC_GPU(True)) + num_devices = jax.device_count() + mesh = jax.make_mesh( + (num_devices,), ("x",), axis_types=(jax.sharding.AxisType.Explicit,) + ) + context_stack.enter_context(jax.sharding.use_mesh(mesh)) + + @parameterized.product( + m_shard=(1024, 8192), + n_shard=(64, 128, 192), + k=(256, 8192), + block_m=(64, 128, 192), + block_n=(64, 128, 192), + block_k=(64, 128), + max_concurrent_steps=(2, 4), + dtype=(jnp.float16, jnp.bfloat16), + ) + def test_all_gather_lhs_matmul( + self, + m_shard, + n_shard, + k, + block_m, + block_n, + block_k, + max_concurrent_steps, + dtype, + ): + num_devices = jax.device_count() + lhs_smem_size = block_m * block_k * 
max_concurrent_steps * 2 + rhs_smem_size = block_k * block_n * max_concurrent_steps * 2 + # H100 SMEM limit is 228kB. + if lhs_smem_size + rhs_smem_size > 228_000: + self.skipTest("This configuration requires too much SMEM.") + if n_shard != block_n: + self.skipTest("n_shard must be equal to block_n for now.") + if n_shard % block_n: + self.skipTest("n_shard must be divisible by block_n for now.") + if m_shard % block_m: + self.skipTest("m_shard must be divisible by block_m for now.") + + k1, k2 = random.split(random.key(1234), num=2) + lhs = random.normal(k1, (num_devices * m_shard, k), dtype) + rhs = random.normal(k2, (k, num_devices * n_shard), dtype) + lhs = shard.reshard(lhs, P("x", None)) + rhs = shard.reshard(rhs, P(None, "x")) + + def run(body): + out = jax.jit( + jax.shard_map(body, out_specs=P(None, "x"), check_vma=False) + )(lhs, rhs) + # Gather output, for NumPy comparison on the host. + out = jax.shard_map( + lambda x: lax.all_gather(x, "x", axis=1, tiled=True), + out_specs=P(None), check_vma=False, + )(out) + return out + + out = run( + functools.partial( + collective_matmul_mgpu.all_gather_lhs_matmul, + axis_name="x", + block_m=block_m, + block_n=block_n, + block_k=block_k, + max_concurrent_steps=max_concurrent_steps, + dtype=dtype, + ) + ) + ref_out = run(lambda x, y: lax.all_gather(x, "x", axis=0, tiled=True) @ y) + np.testing.assert_allclose(out, ref_out) + + +if __name__ == "__main__": + # This test doesn't work with the platform allocator, so we override it + # if it's ran alone. If it's part of a larger test suite and the platform + # allocator is used, setUp will skip the test. + os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.01" + os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "default" + os.environ["XLA_FLAGS"] = ( + os.environ.get("XLA_FLAGS", "") + " --xla_gpu_autotune_level=0" + ) + jt_multiprocess.main() diff --git a/tests/pallas/mgpu_matmul_test.py b/tests/pallas/mgpu_matmul_test.py new file mode 100644 index 000000000000..5c52b0c77296 --- /dev/null +++ b/tests/pallas/mgpu_matmul_test.py @@ -0,0 +1,88 @@ +# Copyright 2025 The JAX Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Test different parameterizations of matrix multiplication.""" + +import contextlib +import os + +from absl.testing import absltest +from absl.testing import parameterized +from jax._src import config +from jax._src import test_util as jtu +from jax._src.pallas import pallas_call +import jax.numpy as jnp +import numpy as np + + +# pylint: disable=g-import-not-at-top +try: + # We only import this to see if Mosaic is available. 
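# (If the import raises ImportError, blackwell_matmul_mgpu is left as None and
#  setUp() below skips every test in this file.)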
+ import jax.experimental.mosaic.gpu # noqa: F401 +except ImportError: + blackwell_matmul_mgpu = None +else: + from jax.experimental.pallas.ops.gpu import blackwell_matmul_mgpu + + +config.parse_flags_with_absl() +os.environ["XLA_FLAGS"] = ( + os.environ.get("XLA_FLAGS", "") + " --xla_gpu_autotune_level=0") + + +@jtu.with_config(jax_traceback_filtering="off") +class MatrixMultiplicationSm100ATest(jtu.JaxTestCase): + + def setUp(self): + super().setUp() + if blackwell_matmul_mgpu is None: + self.skipTest("Mosaic GPU not available.") + if (not jtu.test_device_matches(["cuda"]) or + not jtu.is_cuda_compute_capability_equal("10.0")): + self.skipTest("Only works on GPU with capability sm100a") + context_stack = contextlib.ExitStack() + context_stack.enter_context(pallas_call._PALLAS_USE_MOSAIC_GPU(True)) + self.addCleanup(context_stack.close) + + @parameterized.product( + m=(1024, 4096), + k=(1024, 4096), + n=(1024, 4096), + dtype=(jnp.float16,), + ) + def test_matmul( + self, + m, + n, + k, + dtype, + ): + k1, k2, = jax.random.split(jax.random.key(42), 2) + a = jax.random.normal(k1, (m, k), dtype) + b = jax.random.normal(k2, (k, n), dtype) + + out = blackwell_matmul_mgpu.matmul_kernel( + a, + b, + blackwell_matmul_mgpu.TuningConfig( + tile_m=128, tile_n=128, tile_k=128, + max_concurrent_steps=2, + collective=False, + ), + ) + out_ref = a @ b + np.testing.assert_allclose(out, out_ref, atol=2e-3, rtol=1e-3) + +if __name__ == "__main__": + absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/pallas/mgpu_ragged_dot_test.py b/tests/pallas/mgpu_ragged_dot_test.py new file mode 100644 index 000000000000..e9137df1298a --- /dev/null +++ b/tests/pallas/mgpu_ragged_dot_test.py @@ -0,0 +1,114 @@ +# Copyright 2025 The JAX Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Test different parameterizations of our Mosaic GPU ragged dot kernel.""" + +import contextlib +import os + +from absl.testing import absltest, parameterized # pylint: disable=g-multiple-import +from jax import random +from jax._src import config +from jax._src import test_util as jtu +from jax._src.pallas import pallas_call +import jax.numpy as jnp +import numpy as np + +# pylint: disable=g-import-not-at-top +try: + # We only import this to see if Mosaic is available. 
+ import jax.experimental.mosaic.gpu # noqa: F401 +except ImportError: + ragged_dot = None +else: + from jax.experimental.pallas.ops.gpu import ragged_dot_mgpu + + +config.parse_flags_with_absl() + + +@jtu.with_config(jax_traceback_filtering="off") +class RaggedDotTestCase(jtu.JaxTestCase): + + def setUp(self): + super().setUp() + if ragged_dot_mgpu is None: + self.skipTest("Mosaic GPU not available.") + if (not jtu.test_device_matches(["cuda"]) or + not jtu.is_cuda_compute_capability_equal("9.0")): + self.skipTest("Only works on GPU with capability sm90a") + context_stack = contextlib.ExitStack() + context_stack.enter_context(pallas_call._PALLAS_USE_MOSAIC_GPU(True)) + self.addCleanup(context_stack.close) + + @parameterized.product( + block_m=(64, 128, 192), + block_n=(64, 128, 192), + block_k=(64, 128), + grid_block_n=(2, 4), + max_concurrent_steps=(2, 4), + num_groups=(1, 3, 16), + ) + def test_ragged_dot( + self, + block_m, + block_n, + block_k, + grid_block_n, + max_concurrent_steps, + num_groups, + ): + dtype = jnp.float16 + lhs_smem_size = block_m * block_k * max_concurrent_steps * 2 + rhs_smem_size = block_k * block_n * max_concurrent_steps * 2 + # H100 SMEM limit is 228kB. + if lhs_smem_size + rhs_smem_size > 228_000: + self.skipTest("This configuration requires too much SMEM.") + + m, k, n = 16 * 1024, 2048, 16 * 1024 + kx, ky, kz = random.split(random.key(1234), num=3) + + lhs = jax.random.normal(kx, (m, k), dtype) + rhs = jax.random.normal(ky, (num_groups, k, n), dtype) + group_boundaries = jax.lax.sort( + jax.random.randint(kz, (num_groups - 1,), 0, m, jnp.int32) + ) + group_starts = jax.lax.concatenate( + [jnp.array([0], dtype=jnp.int32), group_boundaries], 0 + ) + group_ends = jax.lax.concatenate( + [group_boundaries, jnp.array([m], dtype=jnp.int32)], 0 + ) + group_sizes = group_ends - group_starts + assert group_sizes.shape == (num_groups,) + + out = ragged_dot_mgpu.ragged_dot( + lhs, + rhs, + group_sizes=group_sizes, + block_m=block_m, + block_n=block_n, + block_k=block_k, + max_concurrent_steps=max_concurrent_steps, + grid_block_n=grid_block_n, + ) + out_ref = jax.lax.ragged_dot(lhs, rhs, group_sizes=group_sizes) + np.testing.assert_allclose(out, out_ref, atol=1e-3, rtol=1e-3) + + +if __name__ == "__main__": + os.environ["XLA_FLAGS"] = ( + os.environ.get("XLA_FLAGS", "") + " --xla_gpu_autotune_level=0" + ) + absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index b3c3ddb84e09..d13623e4238f 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -13,23 +13,39 @@ # limitations under the License. 
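The ragged-dot tests above validate the Mosaic GPU kernel against jax.lax.ragged_dot. For reference, its group_sizes convention in a toy-sized standalone sketch (sizes are illustrative only):

import jax
import jax.numpy as jnp

lhs = jnp.ones((8, 4), jnp.float32)                        # (m, k)
rhs = jnp.stack([jnp.ones((4, 5)), 2 * jnp.ones((4, 5))])  # (num_groups, k, n)
group_sizes = jnp.array([3, 5], jnp.int32)                 # rows 0-2 use group 0, rows 3-7 use group 1
out = jax.lax.ragged_dot(lhs, rhs, group_sizes=group_sizes)
assert out.shape == (8, 5)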
import contextlib +import dataclasses import functools +import itertools import math import operator import os import re +import sys import tempfile +import traceback +from typing import ClassVar from absl.testing import absltest from absl.testing import parameterized import jax +from jax import export from jax import lax +from jax._src import checkify from jax._src import test_util as jtu +from jax._src.pallas import core as pallas_core +from jax._src.pallas import pallas_call +from jax._src.pallas import primitives as pallas_primitives +from jax._src.pallas.mosaic_gpu import core as gpu_core +from jax._src.pallas.mosaic_gpu import lowering as mgpu_lowering from jax._src.pallas.mosaic_gpu import pipeline as mgpu_pipeline +from jax._src.pallas.mosaic_gpu import primitives as mgpu_primitives +from jax._src.state import types as state_types from jax.experimental import pallas as pl +import jax.experimental.mosaic.gpu as mgpu from jax.experimental.pallas import mosaic_gpu as plgpu import jax.numpy as jnp import numpy as np + try: from jax._src.lib import mosaic_gpu as mosaic_gpu_lib except ImportError: @@ -54,16 +70,48 @@ def _sum_same_dtype(x): return jnp.sum(x, dtype=x.dtype) -class PallasTest(jtu.JaxTestCase): +class PallasTestMetaclass(parameterized.TestGeneratorMetaclass): + + def __new__(mcs, *args, lowering_semantics=plgpu.LoweringSemantics.Lane): + cls = super().__new__(mcs, *args) + cls.LOWERING_SEMANTICS = lowering_semantics + return cls + + +class PallasTest(jtu.JaxTestCase, metaclass=PallasTestMetaclass): + LOWERING_SEMANTICS: ClassVar[plgpu.LoweringSemantics] def setUp(self): if not jtu.is_cuda_compute_capability_at_least("9.0"): self.skipTest("Only works on a GPU with capability >= sm90") + context_stack = contextlib.ExitStack() + context_stack.enter_context(pallas_call._PALLAS_USE_MOSAIC_GPU(True)) + self.addCleanup(context_stack.close) super().setUp() + def skip_if_wg_semantics(self): + if self.LOWERING_SEMANTICS == plgpu.LoweringSemantics.Warpgroup: + self.skipTest("Not supported under WG semantics") + + def kernel(self, *args, **kwargs): + compiler_params = dataclasses.replace( + kwargs.pop("compiler_params", plgpu.CompilerParams()), + lowering_semantics=self.LOWERING_SEMANTICS, + ) + return plgpu.kernel(*args, compiler_params=compiler_params, **kwargs) + + def pallas_call(self, *args, **kwargs): + compiler_params = dataclasses.replace( + kwargs.pop("compiler_params", plgpu.CompilerParams()), + lowering_semantics=self.LOWERING_SEMANTICS, + ) + return pl.pallas_call(*args, compiler_params=compiler_params, **kwargs) + @contextlib.contextmanager def capture_stdout(self): + if "pytest" in sys.modules: + self.skipTest("pytest interacts badly with GPU stdout capture") if mosaic_gpu_lib is None: raise ValueError("Running tests but missing Mosaic GPU extension") with jtu.capture_stdout() as stdout: @@ -79,6 +127,13 @@ def setUp(self): super().setUp() +class PallasSm100ATest(PallasTest, jtu.CudaArchSpecificTest): + + def setUp(self): + self.skip_unless_sm100a() + super().setUp() + + class PallasCallTest(PallasTest): @parameterized.product( @@ -93,17 +148,14 @@ class PallasCallTest(PallasTest): lax.log, ], approx_math=[True, False], - thread_semantics=[*plgpu.ThreadSemantics], ) - def test_unary_op(self, op, approx_math, thread_semantics): + def test_unary_op(self, op, approx_math): dtype = jnp.int32 if op is lax.bitwise_not else jnp.float32 @functools.partial( - pl.pallas_call, + self.pallas_call, out_shape=jax.ShapeDtypeStruct([256], dtype), - 
compiler_params=plgpu.GPUCompilerParams( - approx_math=approx_math, thread_semantics=thread_semantics - ), + compiler_params=plgpu.CompilerParams(approx_math=approx_math), ) def kernel(x_ref, o_ref): o_ref[...] = op(x_ref[...]) @@ -124,16 +176,10 @@ def kernel(x_ref, o_ref): jnp.maximum, ], dtype=[jnp.float32, jnp.int32, jnp.uint32], - thread_semantics=[*plgpu.ThreadSemantics], ) - def test_binary_op(self, op, dtype, thread_semantics): - + def test_binary_op(self, op, dtype): @functools.partial( - pl.pallas_call, - out_shape=jax.ShapeDtypeStruct([256], dtype), - compiler_params=plgpu.GPUCompilerParams( - thread_semantics=thread_semantics - ), + self.pallas_call, out_shape=jax.ShapeDtypeStruct([256], dtype) ) def kernel(x_ref, y_ref, o_ref): o_ref[...] = op(x_ref[...], y_ref[...]) @@ -154,16 +200,10 @@ def kernel(x_ref, y_ref, o_ref): ], # TODO(slebedev): Support integral types. dtype=[jnp.float32, jnp.int32, jnp.uint32], - thread_semantics=[*plgpu.ThreadSemantics], ) - def test_comparison_op(self, op, dtype, thread_semantics): - + def test_comparison_op(self, op, dtype): @functools.partial( - pl.pallas_call, - out_shape=jax.ShapeDtypeStruct([256], dtype), - compiler_params=plgpu.GPUCompilerParams( - thread_semantics=thread_semantics - ), + self.pallas_call, out_shape=jax.ShapeDtypeStruct([256], dtype) ) def kernel(o_ref): o_ref[...] = jnp.broadcast_to( @@ -173,8 +213,9 @@ def kernel(o_ref): np.testing.assert_array_equal(kernel(), jnp.full([256], op(42, 24), dtype)) def test_add_first(self): + @functools.partial( - pl.pallas_call, + self.pallas_call, out_shape=jax.ShapeDtypeStruct([256], jnp.float32), ) def kernel(x_ref, y_ref, o_ref): @@ -184,16 +225,10 @@ def kernel(x_ref, y_ref, o_ref): y = jnp.flip(x).reshape(1, 256) np.testing.assert_array_equal(kernel(x, y), x + y[0]) - @parameterized.product( - shape=[(128,), (128, 128)], thread_semantics=[*plgpu.ThreadSemantics] - ) - def test_reduce_sum(self, shape, thread_semantics): + @parameterized.product(shape=[(128,), (128, 128)]) + def test_reduce_sum(self, shape): @functools.partial( - pl.pallas_call, - out_shape=jax.ShapeDtypeStruct(shape, jnp.float32), - compiler_params=plgpu.GPUCompilerParams( - thread_semantics=thread_semantics - ), + self.pallas_call, out_shape=jax.ShapeDtypeStruct(shape, jnp.float32) ) def kernel(x_ref, o_ref): o_ref[...] 
= jnp.broadcast_to(_sum_same_dtype(x_ref[...]), o_ref.shape) @@ -202,11 +237,12 @@ def kernel(x_ref, o_ref): np.testing.assert_array_equal(kernel(x), jnp.sum(x)) def test_reshape(self): + self.skip_if_wg_semantics() + shape1, shape2 = (128,), (2, 16, 4) @functools.partial( - pl.pallas_call, - out_shape=jax.ShapeDtypeStruct(shape2, jnp.float32), + self.pallas_call, out_shape=jax.ShapeDtypeStruct(shape2, jnp.float32) ) def kernel(x_ref, out_ref): x_ref_reshaped = x_ref.reshape(shape2) @@ -217,14 +253,9 @@ def kernel(x_ref, out_ref): x = jnp.arange(math.prod(shape1)).astype(jnp.float32) np.testing.assert_array_equal(kernel(x), x.reshape(shape2)) - @parameterized.product(thread_semantics=[*plgpu.ThreadSemantics]) - def test_add_xy_indexed(self, thread_semantics): + def test_add_xy_indexed(self): @functools.partial( - pl.pallas_call, - out_shape=jax.ShapeDtypeStruct([128], jnp.float32), - compiler_params=plgpu.GPUCompilerParams( - thread_semantics=thread_semantics - ), + self.pallas_call, out_shape=jax.ShapeDtypeStruct([128], jnp.float32) ) def kernel(x_ref, y_ref, o_ref): idx = _sum_same_dtype(y_ref[...]) @@ -235,8 +266,9 @@ def kernel(x_ref, y_ref, o_ref): np.testing.assert_array_equal(kernel(x, y), x[jnp.sum(y)]) def test_add_one_grid(self): + @functools.partial( - pl.pallas_call, + self.pallas_call, in_specs=[pl.BlockSpec((128,), lambda *i: i)], out_specs=pl.BlockSpec((128,), lambda *i: i), out_shape=jax.ShapeDtypeStruct([128 * 2], jnp.float32), @@ -249,9 +281,8 @@ def kernel(x_ref, o_ref): np.testing.assert_array_equal(kernel(x), x + 1.0) def test_add_one_grid_with_scratch(self): - @functools.partial( - pl.pallas_call, + self.pallas_call, out_shape=jax.ShapeDtypeStruct([128 * 2], jnp.float32), in_specs=[pl.BlockSpec((128,), lambda *i: i)], out_specs=pl.BlockSpec((128,), lambda *i: i), @@ -269,11 +300,11 @@ def kernel(x_ref, o_ref, scratch_ref): def test_add_one_grid_pipelined(self, max_concurrent_steps): @functools.partial( - pl.pallas_call, + self.pallas_call, in_specs=[pl.BlockSpec((128, 16), lambda i, j: (i, j))], out_specs=pl.BlockSpec((128, 16), lambda i, j: (i, j)), out_shape=jax.ShapeDtypeStruct([128 * 2, 64], jnp.float32), - compiler_params=plgpu.GPUCompilerParams( + compiler_params=plgpu.CompilerParams( dimension_semantics=["parallel", "sequential"], max_concurrent_steps=max_concurrent_steps, ), @@ -288,10 +319,10 @@ def kernel(x_ref, o_ref): def test_add_one_grid_pipelined_program_id(self): @functools.partial( - pl.pallas_call, + self.pallas_call, out_specs=pl.BlockSpec((16, 16), lambda i, j: (i, j)), out_shape=jax.ShapeDtypeStruct([16, 64], jnp.int32), - compiler_params=plgpu.GPUCompilerParams( + compiler_params=plgpu.CompilerParams( dimension_semantics=["parallel", "sequential"], max_concurrent_steps=2, ), @@ -306,12 +337,13 @@ def kernel(o_ref): ) def test_add_one_grid_pipelined_sequential_invariant_output(self): + @functools.partial( - pl.pallas_call, + self.pallas_call, in_specs=[pl.BlockSpec((32, 16), lambda i, j: (i, j))], out_specs=pl.BlockSpec((32, 16), lambda i, j: (i, 0)), out_shape=jax.ShapeDtypeStruct([32 * 2, 64], jnp.float32), - compiler_params=plgpu.GPUCompilerParams( + compiler_params=plgpu.CompilerParams( dimension_semantics=["parallel", "sequential"], max_concurrent_steps=2, ), @@ -334,30 +366,106 @@ def kernel(x_ref, o_ref): @parameterized.parameters(jnp.float32, jnp.int32, jnp.uint32) def test_iota(self, dtype): + self.skip_if_wg_semantics() + dimension = 1 + @functools.partial( - pl.pallas_call, - out_shape=jax.ShapeDtypeStruct((128, 128), dtype), + 
self.pallas_call, out_shape=jax.ShapeDtypeStruct((128, 128), dtype) ) def kernel(o_ref): - o_ref[...] = plgpu.broadcasted_iota(dtype, (128, 128), dimension, layout=plgpu.Layout.WGMMA) + o_ref[...] = plgpu.broadcasted_iota( + dtype, o_ref.shape, dimension, layout=plgpu.Layout.WGMMA + ) - np.testing.assert_array_equal(kernel(), jax.lax.broadcasted_iota(dtype, (128, 128), dimension)) + np.testing.assert_array_equal( + kernel(), jax.lax.broadcasted_iota(dtype, (128, 128), dimension) + ) - @parameterized.product( - indexer=[..., slice(128), slice(None, 128)], - thread_semantics=[*plgpu.ThreadSemantics], - ) - def test_copy_smem_to_gmem(self, indexer, thread_semantics): + def test_inline_mgpu(self): + dtype = jnp.dtype(jnp.bfloat16) + self.skip_if_wg_semantics() + shape = (128, 128) + tile = (64, 128 // dtype.itemsize) + tiled_shape = mgpu.tile_shape(shape, tile) + tiled_shape_t = list(tiled_shape) + tiled_shape_t[0], tiled_shape_t[1] = tiled_shape_t[1], tiled_shape_t[0] + key = jax.random.key(0) + x = (jax.random.uniform(key, (2, *shape)) * 42).astype(dtype) + + transforms = ( + plgpu.TilingTransform(tile), + plgpu.TransposeTransform((0, 2, 1, 3, 4)), + plgpu.SwizzleTransform(128), + ) @functools.partial( - pl.pallas_call, + self.pallas_call, + out_shape=jax.ShapeDtypeStruct(shape, dtype), + in_specs=(pl.BlockSpec(memory_space=plgpu.GMEM),), + scratch_shapes=[ + plgpu.SMEM( + x.shape, + dtype, + transforms=transforms, + ), + plgpu.Barrier(), + ], + out_specs=pl.BlockSpec(memory_space=plgpu.GMEM), + ) + def kernel(x_ref, o_ref, smem_ref, barrier): + plgpu.copy_gmem_to_smem(x_ref, smem_ref, barrier) + plgpu.barrier_wait(barrier) + # Add an indexer at the end. + sliced_smem_ref = smem_ref.at[0] + @plgpu.inline_mgpu( + arg_types=(plgpu.RefType(( + plgpu.TilingTransform(tile), + plgpu.TransposeTransform((1, 0, 2, 3)), + plgpu.SwizzleTransform(128), + )),), + return_type=plgpu.ShapeDtypeStruct( + shape, dtype, layout=plgpu.Layout.WGMMA + ), + ) + def foo(ctx, smem_ref): + del ctx + assert smem_ref.type.shape == tiled_shape_t, (smem_ref.type, tiled_shape_t) + x = mgpu.FragmentedArray.load_tiled(smem_ref, swizzle=128) + y = mgpu.FragmentedArray.splat( + mgpu.c(1, x.mlir_dtype), shape=x.shape, layout=x.layout + ) + return (x + y) + + arr = foo(sliced_smem_ref) + @plgpu.inline_mgpu(arg_types=(plgpu.Layout.WGMMA, plgpu.RefType(transforms), plgpu.RefType())) + def store(ctx, arr, smem_ref, o_ref): + sliced_smem_ref = mgpu.memref_slice(smem_ref, (0,)) + arr.store_tiled(sliced_smem_ref, swizzle=128) + mgpu.commit_shared() + ctx.async_copy( + src_ref=sliced_smem_ref, + dst_ref=o_ref, + swizzle=128, + gmem_transform=( + mgpu.TileTransform(tile), + mgpu.TransposeTransform((1, 0, 2, 3)), + ), + ) + ctx.await_async_copy(0) + + # This time we slice inside the inline_mgpu body. + store(arr, smem_ref, o_ref) + + np.testing.assert_array_equal(kernel(x), x[0] + 1) + + @parameterized.product(indexer=[..., slice(128), slice(None, 128)]) + def test_copy_smem_to_gmem(self, indexer): + @functools.partial( + self.pallas_call, out_shape=jax.ShapeDtypeStruct([256], jnp.float32), out_specs=pl.BlockSpec(memory_space=plgpu.GMEM), scratch_shapes=[plgpu.SMEM((256,), jnp.float32)], - compiler_params=plgpu.GPUCompilerParams( - thread_semantics=thread_semantics - ), ) def kernel(x_ref, o_ref_gmem, scratch_ref): scratch_ref[...] = x_ref[...] 
+ 1 @@ -368,6 +476,29 @@ def kernel(x_ref, o_ref_gmem, scratch_ref): x = jnp.arange(256).astype(jnp.float32) np.testing.assert_array_equal(kernel(x)[indexer], x[indexer] + 1.0) + @parameterized.parameters(jnp.bfloat16, jnp.float16, jnp.float32) + def test_copy_smem_to_gmem_reduction(self, dtype): + @functools.partial( + pl.pallas_call, + grid=(200,), + in_specs=[pl.BlockSpec((128,), lambda *i: i), pl.BlockSpec(memory_space=plgpu.GMEM)], + out_specs=pl.BlockSpec(memory_space=plgpu.GMEM), + out_shape=jax.ShapeDtypeStruct([128], dtype), + scratch_shapes=[plgpu.SMEM((128,), dtype)], + input_output_aliases={1:0} + ) + def kernel(x_ref, o_ref_gmem, o_ref_gmem_alias, scratch_ref): + del o_ref_gmem_alias + scratch_ref[...] = x_ref[...] + plgpu.commit_smem() + plgpu.copy_smem_to_gmem(scratch_ref.at[...], o_ref_gmem.at[...], reduction_op="add") + plgpu.wait_smem_to_gmem(0) + x = jnp.ones(200 * 128).astype(dtype) # 200 blocks + output = jnp.zeros(128).astype(dtype) + output = kernel(x, output) + output_val = x.reshape(-1, 128).sum(axis=0) + np.testing.assert_array_equal(output, output_val) + @parameterized.named_parameters( {"testcase_name": "1d_none", "shape": (256,), "indexers": (slice(0, 128), slice(None, 32))}, @@ -377,8 +508,9 @@ def kernel(x_ref, o_ref_gmem, scratch_ref): "shape": (64, 64), "indexers": (4, slice(0, 64))}, ) def test_copy_smem_to_gmem_with_multiple_gmem_indexers(self, shape, indexers): + @functools.partial( - pl.pallas_call, + self.pallas_call, out_shape=jax.ShapeDtypeStruct(shape, jnp.float32), out_specs=pl.BlockSpec(memory_space=plgpu.GMEM), scratch_shapes=[plgpu.SMEM(shape, jnp.float32)], @@ -402,13 +534,14 @@ def kernel(x_ref, o_ref_gmem, scratch_ref): @parameterized.product(indexer=[..., slice(128), slice(None, 128)]) def test_copy_gmem_to_smem(self, indexer): + @functools.partial( - pl.pallas_call, + self.pallas_call, out_shape=jax.ShapeDtypeStruct([256], jnp.float32), in_specs=(pl.BlockSpec(memory_space=plgpu.GMEM),), scratch_shapes=[ plgpu.SMEM((256,), jnp.float32), - plgpu.Barrier(num_arrivals=1), + plgpu.Barrier(), ], ) def kernel(x_ref_gmem, o_ref, scratch_ref, barrier_ref): @@ -447,13 +580,15 @@ def kernel(x_ref_gmem, o_ref, scratch_ref, barrier_ref): }, ) def test_copy_gmem_to_smem_with_multiple_gmem_indexers(self, shape, indexers): + @functools.partial( - pl.pallas_call, + self.pallas_call, out_shape=jax.ShapeDtypeStruct(shape, jnp.float32), in_specs=(pl.BlockSpec(memory_space=plgpu.GMEM),), - scratch_shapes=[plgpu.SMEM(shape, jnp.float32), - plgpu.Barrier(num_arrivals=1), - ], + scratch_shapes=[ + plgpu.SMEM(shape, jnp.float32), + plgpu.Barrier(), + ], grid=(1,), ) def kernel(x_ref_gmem, o_ref, scratch_ref, barrier_ref): @@ -478,12 +613,12 @@ def kernel(x_ref_gmem, o_ref, scratch_ref, barrier_ref): def test_gmem_to_smem_with_multiple_smem_indexers(self): x = jax.random.uniform(jax.random.key(0), (2, 64, 64), dtype=jnp.float32) @functools.partial( - pl.pallas_call, + self.pallas_call, out_shape=jax.ShapeDtypeStruct([64, 64], jnp.float32), in_specs=(pl.BlockSpec(memory_space=plgpu.GMEM),), scratch_shapes=[ plgpu.SMEM(x.shape, jnp.float32), - plgpu.Barrier(num_arrivals=1), + plgpu.Barrier(), ], ) def extract_x0(x_ref_gmem, o_ref, scratch_ref, barrier_ref): @@ -495,21 +630,31 @@ def extract_x0(x_ref_gmem, o_ref, scratch_ref, barrier_ref): np.testing.assert_array_equal(extract_x0(x), x[0]) def test_gmem_to_smem_with_multiple_smem_indexers_and_transforms(self): + self.skip_if_wg_semantics() + x = jnp.arange(512 * 512, dtype=jnp.int32).reshape(512, 512) 
@functools.partial( - pl.pallas_call, + self.pallas_call, grid=(4, 4), out_shape=jax.ShapeDtypeStruct((256, 128), jnp.int32), - in_specs=(plgpu.GPUBlockSpec( - block_shape=(128, 128), - index_map=lambda i, j: (i, j), - memory_space=plgpu.SMEM, - transforms=(plgpu.TilingTransform((64, 32)), - plgpu.SwizzleTransform(128))),), - out_specs=(plgpu.GPUBlockSpec( - block_shape=(64, 32), - index_map=lambda i, j: (i, j), - memory_space=plgpu.SMEM,)), + in_specs=( + plgpu.BlockSpec( + block_shape=(128, 128), + index_map=lambda i, j: (i, j), + memory_space=plgpu.SMEM, + transforms=( + plgpu.TilingTransform((8, 32)), + plgpu.SwizzleTransform(128), + ), + ), + ), + out_specs=( + plgpu.BlockSpec( + block_shape=(64, 32), + index_map=lambda i, j: (i, j), + memory_space=plgpu.SMEM, + ) + ), ) def kernel(x_ref, o_ref): x_sliced = x_ref.at[0:64, 32:96].at[:, 0:32] # get x_ref[0:64, 32:64] @@ -521,13 +666,14 @@ def kernel(x_ref, o_ref): @parameterized.product(indexer=[0, 1, 2, 3]) def test_copy_gmem_to_smem_with_indexed_barrier(self, indexer): + @functools.partial( - pl.pallas_call, + self.pallas_call, out_shape=jax.ShapeDtypeStruct([128], jnp.float32), in_specs=(pl.BlockSpec(memory_space=plgpu.GMEM),), scratch_shapes=[ plgpu.SMEM((128,), jnp.float32), - plgpu.Barrier(num_arrivals=1, num_barriers=4), + plgpu.Barrier(num_barriers=4), ], ) def kernel(x_ref_gmem, o_ref, scratch_ref, barrier_ref): @@ -542,6 +688,8 @@ def kernel(x_ref_gmem, o_ref, scratch_ref, barrier_ref): @parameterized.named_parameters(("_g2s", False), ("_s2g", True)) def test_copy_with_transforms(self, to_smem): + self.skip_if_wg_semantics() + def kernel(x_ref, o_ref, barrier_ref): if to_smem: plgpu.copy_gmem_to_smem(x_ref, o_ref, barrier_ref) @@ -552,29 +700,29 @@ def kernel(x_ref, o_ref, barrier_ref): plgpu.wait_smem_to_gmem(0) in_spec = pl.BlockSpec(memory_space=plgpu.GMEM) - out_spec = plgpu.GPUBlockSpec( - (128, 128), - lambda: (0, 0), + out_spec = plgpu.BlockSpec( transforms=( - plgpu.TilingTransform((64, 32)), + plgpu.TilingTransform((8, 32)), plgpu.SwizzleTransform(128), ), memory_space=plgpu.SMEM, ) if not to_smem: in_spec, out_spec = out_spec, in_spec - f = pl.pallas_call( + f = self.pallas_call( kernel, out_shape=jax.ShapeDtypeStruct([128, 128], jnp.float32), in_specs=(in_spec,), out_specs=out_spec, - scratch_shapes=[plgpu.Barrier(num_arrivals=1)], + scratch_shapes=[plgpu.Barrier()], ) x = jnp.arange(128 * 128, dtype=jnp.float32).reshape(128, 128) np.testing.assert_array_equal(f(x), x) def test_scoped_copy_with_transforms(self): - ts = (plgpu.TilingTransform((64, 32)), plgpu.SwizzleTransform(128)) + self.skip_if_wg_semantics() + + ts = (plgpu.TilingTransform((8, 32)), plgpu.SwizzleTransform(128)) def kernel(x_ref, o_ref, barrier_ref): def body(tmp_ref): plgpu.copy_gmem_to_smem(x_ref, tmp_ref, barrier_ref) @@ -583,47 +731,89 @@ def body(tmp_ref): pl.run_scoped(body, plgpu.SMEM((128, 128), jnp.float32, transforms=ts)) in_spec = pl.BlockSpec(memory_space=plgpu.GMEM) - out_spec = plgpu.GPUBlockSpec( - (128, 128), lambda: (0, 0), transforms=ts, memory_space=plgpu.SMEM, + out_spec = plgpu.BlockSpec(transforms=ts, memory_space=plgpu.SMEM) + f = self.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct([128, 128], jnp.float32), + in_specs=(in_spec,), + out_specs=out_spec, + scratch_shapes=[plgpu.Barrier()], ) + x = jnp.arange(128 * 128, dtype=jnp.float32).reshape(128, 128) + np.testing.assert_array_equal(f(x), x * 2) + + def test_scoped_copy_with_user_transforms(self): + def kernel(x_ref, o_ref, barrier_ref): + def body(tmp_ref): + 
tmp_ref = plgpu.unswizzle_ref(tmp_ref, 128) + tmp_ref = plgpu.untile_ref(tmp_ref, (8, 32)) + plgpu.copy_gmem_to_smem(x_ref, tmp_ref, barrier_ref) + plgpu.barrier_wait(barrier_ref) + o_ref[...] = tmp_ref[...] * 2 + pl.run_scoped(body, plgpu.SMEM((16, 4, 8, 32), jnp.float32)) + + in_spec = pl.BlockSpec(memory_space=plgpu.GMEM) f = pl.pallas_call( kernel, out_shape=jax.ShapeDtypeStruct([128, 128], jnp.float32), in_specs=(in_spec,), - out_specs=out_spec, - scratch_shapes=[plgpu.Barrier(num_arrivals=1)], + scratch_shapes=[plgpu.Barrier()], ) x = jnp.arange(128 * 128, dtype=jnp.float32).reshape(128, 128) np.testing.assert_array_equal(f(x), x * 2) def test_copy_with_transforms_and_indexing(self): + self.skip_if_wg_semantics() + def kernel(x_ref, o_ref, barrier_ref): for i in range(2): plgpu.copy_gmem_to_smem(x_ref, o_ref.at[i], barrier_ref) plgpu.barrier_wait(barrier_ref) in_spec = pl.BlockSpec(memory_space=plgpu.GMEM) - out_spec = plgpu.GPUBlockSpec( - (2, 128, 128), - lambda: (0, 0, 0), + out_spec = plgpu.BlockSpec( transforms=( - plgpu.TilingTransform((64, 32)), + plgpu.TilingTransform((8, 32)), plgpu.TransposeTransform((0, 2, 1, 3, 4)), plgpu.SwizzleTransform(128), ), memory_space=plgpu.SMEM, ) - f = pl.pallas_call( + f = self.pallas_call( kernel, out_shape=jax.ShapeDtypeStruct([2, 128, 128], jnp.float32), in_specs=(in_spec,), out_specs=out_spec, - scratch_shapes=[plgpu.Barrier(num_arrivals=1)], + scratch_shapes=[plgpu.Barrier()], ) x = jnp.arange(128 * 128, dtype=jnp.float32).reshape(128, 128) np.testing.assert_array_equal(f(x), np.stack([x, x], axis=0)) + @parameterized.product( + src_memory_space=[plgpu.SMEM, plgpu.GMEM], + layout=[plgpu.Layout.WG_STRIDED((128,), vec_size=1), None, + ] + ) + def test_load_to_strided_layout_with_indexing(self, src_memory_space, layout): + self.skip_if_wg_semantics() + + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct([2, 128], jnp.float32), + in_specs=[pl.BlockSpec(memory_space=src_memory_space)], + out_specs=plgpu.BlockSpec(memory_space=plgpu.SMEM), + ) + def kernel(x_ref, o_ref): + for i in range(2): + x = plgpu.load(x_ref, (i,), layout=layout) + o_ref[i, ...] = x + + x = jnp.arange(2 * 128, dtype=jnp.float32).reshape(2, 128) + np.testing.assert_array_equal(kernel(x), x) + def test_indexing_before_transpose(self): + self.skip_if_wg_semantics() + def kernel(x_ref, o_ref, barrier_ref): for i in range(2): plgpu.copy_gmem_to_smem( @@ -632,23 +822,22 @@ def kernel(x_ref, o_ref, barrier_ref): plgpu.barrier_wait(barrier_ref) in_spec = pl.BlockSpec(memory_space=plgpu.GMEM) - out_spec = plgpu.GPUBlockSpec( - (2, 64, 2, 128), lambda: (0, 0, 0, 0), memory_space=plgpu.SMEM, - ) - f = pl.pallas_call( + out_spec = plgpu.BlockSpec(memory_space=plgpu.SMEM) + f = self.pallas_call( kernel, out_shape=jax.ShapeDtypeStruct([2, 64, 2, 128], jnp.float32), in_specs=(in_spec,), out_specs=out_spec, - scratch_shapes=[plgpu.Barrier(num_arrivals=1)], + scratch_shapes=[plgpu.Barrier()], ) x = jnp.arange(2 * 64 * 128, dtype=jnp.float32).reshape(2, 64, 128) xt = x.transpose((1, 0, 2)) np.testing.assert_array_equal(f(x), np.stack([xt, xt], axis=0)) def test_copy_gmem_to_smem_in_run_scoped(self): + @functools.partial( - pl.pallas_call, + self.pallas_call, out_shape=jax.ShapeDtypeStruct([256], jnp.float32), in_specs=(pl.BlockSpec(memory_space=plgpu.GMEM),), ) @@ -659,14 +848,15 @@ def inner_body(scratch_ref): plgpu.barrier_wait(barrier_ref) o_ref[...] = scratch_ref[...] 
+ 1 pl.run_scoped(inner_body, plgpu.SMEM((256,), jnp.float32)) - pl.run_scoped(body, plgpu.Barrier(num_arrivals=1)) + pl.run_scoped(body, plgpu.Barrier()) x = jnp.arange(256).astype(jnp.float32) np.testing.assert_array_equal(kernel(x), x + 1.0) def test_add_doubled_sum(self): + @functools.partial( - pl.pallas_call, + self.pallas_call, out_shape=jax.ShapeDtypeStruct([128], jnp.float32), ) def kernel(x_ref, o_ref): @@ -675,26 +865,6 @@ def kernel(x_ref, o_ref): x = jnp.arange(128).astype(jnp.float32) np.testing.assert_array_equal(kernel(x), x + x.sum()*2) - @parameterized.named_parameters( - ("rsqrt", jax.lax.rsqrt, ), - ("log", jax.lax.log, 5e-7), - ("exp", jax.lax.exp, ), - ("exp2", jax.lax.exp2, 5e-7), - ("logistic", jax.lax.logistic, ), - ("tanh", jax.lax.tanh, 5e-7), - ) - def test_approx_math_unary_op(self, unary_op, rtol=1e-7): - @functools.partial( - pl.pallas_call, - out_shape=jax.ShapeDtypeStruct([128], jnp.float32), - compiler_params=plgpu.GPUCompilerParams(approx_math=True), - ) - def kernel(x_ref, o_ref): - o_ref[...] = unary_op(x_ref[...]) - - x = jnp.arange(128).astype(jnp.float32) / 128 - np.testing.assert_allclose(kernel(x), unary_op(x), rtol=rtol, atol=1e-5) - @parameterized.product(input_factor=[0.001, 1, 10, 100, 100]) def test_layer_norm(self, input_factor): eps = 1e-5 @@ -702,7 +872,7 @@ def test_layer_norm(self, input_factor): beta = 1.0 @functools.partial( - pl.pallas_call, + self.pallas_call, out_shape=jax.ShapeDtypeStruct([256], jnp.float32), ) def layer_norm(x_ref, o_ref): @@ -730,8 +900,9 @@ def layer_norm_np(x): np.testing.assert_allclose(layer_norm(x), layer_norm_np(x), rtol=5e-5) def test_print(self): + @functools.partial( - pl.pallas_call, + self.pallas_call, out_shape=jax.ShapeDtypeStruct([256], jnp.float32), ) def kernel(x_ref, o_ref): @@ -744,16 +915,30 @@ def kernel(x_ref, o_ref): self.assertEqual(output(), "It works!\n") def test_print_wgmma_tiled_layout(self): + self.skip_if_wg_semantics() + shape = (128, 64) size = math.prod(shape) + + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct(shape, jnp.float32), + in_specs=[ + plgpu.BlockSpec( + transforms=( + plgpu.TilingTransform((8, 32)), + plgpu.SwizzleTransform(128), + ) + ) + ], + ) def kernel(x_ref, o_ref): + del o_ref # Unused. 
pl.debug_print("prefix {}", x_ref[...]) - spec = plgpu.GPUBlockSpec(shape, lambda: (0, 0), transforms=(plgpu.TilingTransform((64, 32)), plgpu.SwizzleTransform(128))) - x = jnp.arange(size, dtype=jnp.float32).reshape(shape) - f = pl.pallas_call(kernel, out_shape=x, in_specs=[spec], out_specs=spec) + x = jnp.arange(size, dtype=jnp.float32).reshape(shape) with self.capture_stdout() as get_output: - jax.block_until_ready(f(x)) + jax.block_until_ready(kernel(x)) output = get_output() results = re.findall(r"prefix \[(\d+), (\d+)\]: (\d+).?\d*", output) @@ -763,8 +948,10 @@ def kernel(x_ref, o_ref): self.assertEqual(v, i * shape[1] + j) def test_print_scalar(self): + self.skip_if_wg_semantics() + @functools.partial( - pl.pallas_call, + self.pallas_call, out_shape=jax.ShapeDtypeStruct([256], jnp.int32), ) def kernel(x_ref, o_ref): @@ -778,8 +965,10 @@ def kernel(x_ref, o_ref): self.assertIn(f"x.sum() = {x.sum()}", output()) def test_print_scalar_array(self): + self.skip_if_wg_semantics() + @functools.partial( - pl.pallas_call, + self.pallas_call, out_shape=jax.ShapeDtypeStruct([256], jnp.int32), ) def kernel(x_ref, o_ref): @@ -793,10 +982,12 @@ def kernel(x_ref, o_ref): self.assertIn(f"x.sum() = {x.sum() + 1}", output()) def test_print_array(self): + self.skip_if_wg_semantics() + in_shape = [2, 1, 64, 64] @functools.partial( - pl.pallas_call, + self.pallas_call, out_shape=jax.ShapeDtypeStruct(in_shape, jnp.int32), ) def kernel(x_ref, o_ref): @@ -809,11 +1000,58 @@ def kernel(x_ref, o_ref): self.assertIn("x: [1, 0, 43, 23]: 6871\n", output()) + @parameterized.parameters( + (plgpu.TilingTransform((1, 32)), plgpu.SwizzleTransform(128)), + (plgpu.TilingTransform((8, 32)), plgpu.SwizzleTransform(128)), + (), + ) + def test_get_swap_with_transforms(self, *transforms): + self.skip_if_wg_semantics() + + shape = (128, 128) + + @functools.partial( + self.pallas_call, + in_specs=[plgpu.BlockSpec(memory_space=plgpu.GMEM)], + out_specs=plgpu.BlockSpec(memory_space=plgpu.GMEM), + out_shape=jax.ShapeDtypeStruct(shape, jnp.int32), + scratch_shapes=[ + plgpu.SMEM(shape, jnp.int32, transforms=tuple(transforms)), + plgpu.Barrier(), + ] + ) + def kernel(x_ref, o_ref, scratch_ref, barrier_ref): + plgpu.copy_gmem_to_smem(x_ref, scratch_ref, barrier_ref) + plgpu.barrier_wait(barrier_ref) + scratch_ref[...] = scratch_ref[...] * 2 + plgpu.copy_smem_to_gmem(scratch_ref, o_ref) + plgpu.wait_smem_to_gmem(0) + + x = jnp.arange(math.prod(shape), dtype=jnp.int32).reshape(shape) + np.testing.assert_array_equal(kernel(x), x * 2) + + def test_check(self): + self.skip_if_wg_semantics() + + self.enter_context(pl.enable_debug_checks(True)) + + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct([256], jnp.int32), + ) + def kernel(x_ref, o_ref): + pl.debug_check(_sum_same_dtype(x_ref[...]) > 0, "x.sum() is negative") + o_ref[...] = x_ref[...] + + x = jnp.arange(256, dtype=jnp.int32) + np.testing.assert_array_equal(kernel(x), x) + def test_load_scalar(self): + @functools.partial( - pl.pallas_call, + self.pallas_call, out_shape=jax.ShapeDtypeStruct((128,), jnp.int32), - in_specs=[plgpu.GPUBlockSpec(memory_space=plgpu.GPUMemorySpace.GMEM)], + in_specs=[plgpu.BlockSpec(memory_space=plgpu.GMEM)], ) def kernel(x_ref, o_ref): o_ref[...] 
= jnp.broadcast_to(x_ref[10], (128,)) @@ -821,9 +1059,11 @@ def kernel(x_ref, o_ref): np.testing.assert_array_equal(kernel(jnp.arange(11, dtype=jnp.int32)), jnp.full((128,), 10, dtype=jnp.int32)) - @parameterized.product(thread_semantics=[*plgpu.ThreadSemantics]) - def test_run_scoped(self, thread_semantics): - + def test_run_scoped(self): + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32), + ) def kernel(x_ref, o_ref): def body(tmp_ref): self.assertEqual(tmp_ref.shape, (8, 128)) @@ -834,20 +1074,32 @@ def body(tmp_ref): self.assertEqual(tmp.shape, (8, 128)) o_ref[...] = tmp - inp = np.ones((8, 128), jnp.float32) - f = pl.pallas_call( - kernel, - out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32), - compiler_params=plgpu.GPUCompilerParams( - thread_semantics=thread_semantics - ), + x = np.ones((8, 128), jnp.float32) + np.testing.assert_array_equal(kernel(x), x + 1.0) + + def test_run_scoped_in_cond(self): + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct([256], jnp.int32), + in_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)], + out_specs=pl.BlockSpec(memory_space=plgpu.SMEM), ) - o = f(inp) - np.testing.assert_array_equal(o, inp + 1.0) + def kernel(x_ref_gmem, o_ref): + def scoped_kernel(barrier_ref): + plgpu.copy_gmem_to_smem(x_ref_gmem, o_ref, barrier_ref) + plgpu.barrier_wait(barrier_ref) + + def branch(): + pl.run_scoped(scoped_kernel, plgpu.Barrier()) + + jax.lax.cond(x_ref_gmem[0] % 2 == 0, branch, branch) + + x = jnp.full((256,), 1234, dtype=jnp.int32) + np.testing.assert_array_equal(kernel(x), x) def test_program_id(self): @functools.partial( - pl.pallas_call, + self.pallas_call, in_specs=(), out_specs=pl.BlockSpec((128,), lambda *i: i), out_shape=jax.ShapeDtypeStruct([128 * 2], jnp.int32), @@ -866,7 +1118,7 @@ def test_program_id_in_squashed_grid(self): # 3 CUDA grid dimensions. 
grid = (2, 3, 4, 5) @functools.partial( - pl.pallas_call, + self.pallas_call, in_specs=(), out_specs=pl.BlockSpec((1,) * len(grid) + (128,), lambda *i: (*i, 0)), out_shape=jax.ShapeDtypeStruct([*grid, 128], jnp.int32), @@ -887,7 +1139,7 @@ def kernel(o_ref): def test_program_id_in_block_spec(self): @functools.partial( - pl.pallas_call, + self.pallas_call, in_specs=(pl.BlockSpec((2, 128), lambda i: (pl.program_id(0), i)),), out_specs=pl.BlockSpec((2, 128), lambda i: (pl.program_id(0), i)), out_shape=jax.ShapeDtypeStruct([2, 128], jnp.int32), @@ -901,7 +1153,7 @@ def kernel(x_ref, o_ref): def test_num_programs(self): @functools.partial( - pl.pallas_call, + self.pallas_call, in_specs=(), out_specs=pl.BlockSpec((128,), lambda *i: i), out_shape=jax.ShapeDtypeStruct([128 * 2], jnp.int32), @@ -916,17 +1168,18 @@ def kernel(o_ref): ) def test_swizzled_blockspec_shapes(self): + self.skip_if_wg_semantics() - spec = plgpu.GPUBlockSpec( + spec = plgpu.BlockSpec( (128, 64), lambda *i: i, transforms=( - plgpu.TilingTransform((64, 64)), + plgpu.TilingTransform((8, 64)), plgpu.SwizzleTransform(128), ), ) @functools.partial( - pl.pallas_call, + self.pallas_call, in_specs=[spec], out_specs=spec, out_shape=jax.ShapeDtypeStruct((128, 128), jnp.float16), @@ -939,30 +1192,38 @@ def kernel(x_ref, o_ref): x = jnp.arange(128 * 128).astype(jnp.float16).reshape(128, 128) np.testing.assert_array_equal(kernel(x), x) - @parameterized.product( - force_while=[False, True], thread_semantics=[*plgpu.ThreadSemantics] - ) - def test_fori_loop_array(self, force_while, thread_semantics): + @parameterized.product(force_while=[False, True]) + def test_fori_loop_array(self, force_while): @functools.partial( - pl.pallas_call, - out_shape=jax.ShapeDtypeStruct([256], jnp.int32), - compiler_params=plgpu.GPUCompilerParams(thread_semantics=thread_semantics), + self.pallas_call, out_shape=jax.ShapeDtypeStruct([256], jnp.int32) ) def kernel(x_ref, o_ref): # Equivalent to x_ref[...] + 2 + 3. - o_ref[...] = _fori_loop(force_while, 2, 4, lambda i, x: x + i, x_ref[...]) + o_ref[...] = _fori_loop( + force_while, 2, 4, lambda i, x: x + i, x_ref[...] + ) x = jnp.arange(256, dtype=jnp.int32) np.testing.assert_array_equal(kernel(x), x + 2 + 3) - @parameterized.product( - force_while=[False, True], thread_semantics=[*plgpu.ThreadSemantics] - ) - def test_fori_loop_scalar(self, force_while, thread_semantics): + @parameterized.product(unroll=[1, 2]) + def test_fori_loop_array_unrolled(self, unroll): @functools.partial( - pl.pallas_call, - out_shape=jax.ShapeDtypeStruct([256], jnp.int32), - compiler_params=plgpu.GPUCompilerParams(thread_semantics=thread_semantics), + self.pallas_call, out_shape=jax.ShapeDtypeStruct([256], jnp.int32) + ) + def kernel(x_ref, o_ref): + # Equivalent to x_ref[...] + 2 + 3 + 4 + 5. + o_ref[...] = lax.fori_loop( + 2, 6, lambda i, x: x + i, x_ref[...], unroll=unroll + ) + + x = jnp.arange(256, dtype=jnp.int32) + np.testing.assert_array_equal(kernel(x), x + 2 + 3 + 4 + 5) + + @parameterized.product(force_while=[False, True]) + def test_fori_loop_scalar(self, force_while): + @functools.partial( + self.pallas_call, out_shape=jax.ShapeDtypeStruct([256], jnp.int32) ) def kernel(o_ref): # Equivalent to 2 + 3. 
@@ -974,9 +1235,8 @@ def kernel(o_ref): np.testing.assert_array_equal(kernel(), jnp.full([256], 5, jnp.int32)) def test_fori_loop_dynamic_bounds(self): - @functools.partial( - pl.pallas_call, + self.pallas_call, out_shape=jax.ShapeDtypeStruct([256], jnp.int32), grid=(1,) ) @@ -989,16 +1249,10 @@ def kernel(o_ref): np.testing.assert_array_equal(kernel(), jnp.full([256], 5, dtype=jnp.int32)) - @parameterized.product( - force_while=[False, True], thread_semantics=[*plgpu.ThreadSemantics] - ) - def test_fori_loop_tuple(self, force_while, thread_semantics): + @parameterized.product(force_while=[False, True]) + def test_fori_loop_tuple(self, force_while): @functools.partial( - pl.pallas_call, - out_shape=jax.ShapeDtypeStruct([256], jnp.int32), - compiler_params=plgpu.GPUCompilerParams( - thread_semantics=thread_semantics - ), + self.pallas_call, out_shape=jax.ShapeDtypeStruct([256], jnp.int32) ) def kernel(o_ref): def body(step, xs): @@ -1017,16 +1271,11 @@ def body(step, xs): kernel(), jnp.full([256], 3 * (0 + 1), jnp.int32) ) - @parameterized.product( - force_while=[False, True], thread_semantics=[*plgpu.ThreadSemantics] - ) - def test_fori_loop_indexed_store(self, force_while, thread_semantics): + @parameterized.product(force_while=[False, True]) + def test_fori_loop_indexed_store(self, force_while): @functools.partial( - pl.pallas_call, + self.pallas_call, out_shape=jax.ShapeDtypeStruct([4, 128], jnp.float32), - compiler_params=plgpu.GPUCompilerParams( - thread_semantics=thread_semantics - ), ) def kernel(x_ref, y_ref, o_ref): def body(idx, _): @@ -1039,17 +1288,9 @@ def body(idx, _): y = x + 1 np.testing.assert_array_equal(kernel(x, y), x + y) - @parameterized.product(thread_semantics=[*plgpu.ThreadSemantics]) - def test_while_loop(self, thread_semantics): - if thread_semantics == plgpu.ThreadSemantics.Warpgroup: - self.skipTest("WG lowering does not support reduce_sum_p needed for this test") - + def test_while_loop(self): @functools.partial( - pl.pallas_call, - out_shape=jax.ShapeDtypeStruct([128], jnp.int32), - compiler_params=plgpu.GPUCompilerParams( - thread_semantics=thread_semantics - ), + self.pallas_call, out_shape=jax.ShapeDtypeStruct([128], jnp.int32) ) def kernel(x_ref, o_ref): o_ref[...] = jnp.zeros(o_ref.shape, dtype=jnp.int32) @@ -1071,7 +1312,7 @@ def body(acc): def test_while_loop_layout_mismatch(self): @functools.partial( - pl.pallas_call, out_shape=jax.ShapeDtypeStruct([128], jnp.int32) + self.pallas_call, out_shape=jax.ShapeDtypeStruct([128], jnp.int32) ) def kernel(o_ref): def cond(acc): @@ -1084,18 +1325,28 @@ def body(acc): return plgpu.layout_cast( jnp.zeros(o_ref.shape, o_ref.dtype), plgpu.Layout.WGMMA_ROW ) - - _ = jax.lax.while_loop(cond, body, o_ref[...]) - - with self.assertRaisesRegex(ValueError, "has layout .*, when it should be"): - kernel() - - @parameterized.parameters([*plgpu.ThreadSemantics]) - def test_cond(self, thread_semantics): + # Cast explicitly to cause the mismatch, otherwise layout inference will + # succeed at constructing a working program. 
+ strided_input = plgpu.layout_cast( + o_ref[...], plgpu.Layout.WG_STRIDED(shape=(128,), vec_size=1) + ) + _ = jax.lax.while_loop(cond, body, strided_input) + + if self.LOWERING_SEMANTICS == plgpu.LoweringSemantics.Warpgroup: + with self.assertRaisesRegex( + NotImplementedError, + "Cannot convert from WGStridedFragLayout.* to TiledLayout", + ): + kernel() + else: + with self.assertRaisesRegex( + ValueError, "has layout .*, when it should be" + ): + kernel() + + def test_cond(self): @functools.partial( - pl.pallas_call, - out_shape=jax.ShapeDtypeStruct([256], jnp.int32), - compiler_params=plgpu.GPUCompilerParams(thread_semantics=thread_semantics), + self.pallas_call, out_shape=jax.ShapeDtypeStruct([256], jnp.int32) ) def kernel(x_ref, o_ref): jax.lax.cond( @@ -1111,27 +1362,49 @@ def kernel(x_ref, o_ref): self.assertIn("acc % 2", output()) - @parameterized.parameters([*plgpu.ThreadSemantics]) - def test_cond_returning_array(self, thread_semantics): + def test_cond_returning_array(self): @functools.partial( - pl.pallas_call, - out_shape=jax.ShapeDtypeStruct([256], jnp.int32), - compiler_params=plgpu.GPUCompilerParams( - thread_semantics=thread_semantics - ), + self.pallas_call, out_shape=jax.ShapeDtypeStruct([256], jnp.int32) ) def kernel(x_ref, o_ref): - acc = _sum_same_dtype(x_ref[...]) + acc_sum = _sum_same_dtype(x_ref[...]) acc2, acc = jax.lax.cond( - acc % 2 == 0, - lambda: (acc * 2, acc), - lambda: (acc, acc * 2), + acc_sum % 2 == 0, + lambda: (acc_sum * 2, x_ref[...]), + lambda: (acc_sum, x_ref[...]), ) - o_ref[...] = jnp.broadcast_to(acc + acc2, o_ref.shape) + o_ref[...] = jnp.broadcast_to(_sum_same_dtype(acc) + acc2, o_ref.shape) x = jnp.arange(256, dtype=jnp.int32) np.testing.assert_array_equal(kernel(x), jnp.broadcast_to(jnp.sum(x) * 3, [256])) + def test_tile_slicing(self): + # Not testing with warpgroup semantics, because we want to enforce a layout. + self.skip_if_wg_semantics() + + shape = (256, 128) + block_spec = plgpu.BlockSpec( + transforms=(plgpu.TilingTransform((8, 64)), plgpu.SwizzleTransform(128)) + ) + @functools.partial( + self.pallas_call, + in_specs=[block_spec], + out_specs=block_spec, + out_shape=jax.ShapeDtypeStruct((64, 64), jnp.uint16), + ) + def kernel(x_ref, o_ref): + def sum_tiles(row, acc): + row_slice = pl.ds(row * 64, 64) + for col in range(128 // 64): + acc += x_ref[row_slice, pl.ds(col * 64, 64)] + return acc + acc = plgpu.layout_cast(jnp.zeros((64, 64), jnp.uint16), plgpu.Layout.WGMMA) + o_ref[...] = _fori_loop(False, 0, 256 // 64, sum_tiles, acc) + + x = jnp.arange(math.prod(shape), dtype=jnp.uint16).reshape(shape) + y = x.reshape(256 // 64, 64, 128 // 64, 64).sum(axis=(0, 2), dtype=jnp.uint16) + np.testing.assert_array_equal(kernel(x), y) + def test_input_output_aliases(self): # Note that we're writing to the input pointer, which should alias b_ptr. def kernel(a_ref, b_ref): @@ -1139,16 +1412,18 @@ def kernel(a_ref, b_ref): a_ref[...] 
= jnp.ones_like(a_ref) a = np.zeros((64, 64), dtype=jnp.float32) - b = pl.pallas_call( + b = self.pallas_call( kernel, - in_specs=[plgpu.GPUBlockSpec(memory_space=plgpu.GPUMemorySpace.GMEM)], - out_specs=plgpu.GPUBlockSpec(memory_space=plgpu.GPUMemorySpace.GMEM), + in_specs=[plgpu.BlockSpec(memory_space=plgpu.GMEM)], + out_specs=plgpu.BlockSpec(memory_space=plgpu.GMEM), input_output_aliases={0: 0}, out_shape=a, )(a) np.testing.assert_array_equal(b, np.ones_like(a)) def test_slicing(self): + self.skip_if_wg_semantics() + left = upper = slice(None, 64) right = lower = slice(64, None) # We rotate the four quadrants of the input clockwise. @@ -1159,22 +1434,17 @@ def rotate(src, dst): dst[lower, left] = src[lower, right] x = jnp.arange(128 * 128).astype(jnp.float16).reshape(128, 128) - spec = plgpu.GPUBlockSpec( - (128, 128), - lambda: (0, 0), - transforms=( - plgpu.TilingTransform((64, 64)), - plgpu.SwizzleTransform(128), - ), + spec = plgpu.BlockSpec( + transforms=(plgpu.TilingTransform((8, 64)), plgpu.SwizzleTransform(128)) ) - f = pl.pallas_call(rotate, out_shape=x, in_specs=[spec], out_specs=spec) + f = self.pallas_call(rotate, out_shape=x, in_specs=[spec], out_specs=spec) expected = np.empty_like(x) rotate(x, expected) np.testing.assert_array_equal(f(x), expected) def test_layout_cast(self, shape=(256, 64)): @functools.partial( - pl.pallas_call, + self.pallas_call, out_shape=jax.ShapeDtypeStruct(shape, jnp.float32), ) def kernel(o_ref): @@ -1183,7 +1453,52 @@ def kernel(o_ref): x = jnp.full(shape, 42.0, jnp.float32) np.testing.assert_array_equal(kernel(), x) + @parameterized.parameters(False, True) + def test_wgmma_transposed_layout(self, store_transposed): + """Tests that the result of wgmma can be stored transposed using + the WGMMA_TRANSPOSED layout. + """ + + dtype = jnp.dtype(jnp.float16) + swizzle_elems = 128 // dtype.itemsize + shape = (128, 128) + @functools.partial( + pl.pallas_call, + out_shape=jax.ShapeDtypeStruct(shape, dtype), + out_specs=pl.BlockSpec(memory_space=plgpu.GMEM), + scratch_shapes=[ + plgpu.SMEM( + shape, dtype, + transforms=( + plgpu.TilingTransform((8, swizzle_elems)), + plgpu.SwizzleTransform(128), + ), + ) + ] + ) + def kernel(o_ref, smem): + iota = plgpu.broadcasted_iota( + dtype, o_ref.shape, 0, layout=plgpu.Layout.WGMMA + ) * o_ref.shape[0] + iota += plgpu.broadcasted_iota( + dtype, o_ref.shape, 1, layout=plgpu.Layout.WGMMA + ) + + smem_trns = plgpu.transpose_ref(smem, (1, 0)) + smem_trns[...] = plgpu.layout_cast(iota, plgpu.Layout.WGMMA_TRANSPOSED) + plgpu.commit_smem() + plgpu.copy_smem_to_gmem(smem_trns if store_transposed else smem, o_ref) + + x = jnp.arange(128 * 128, dtype=dtype).reshape((128, 128)).T + if store_transposed: + with self.assertRaises(ValueError): + kernel() + else: + np.testing.assert_array_equal(kernel(), x) + def test_profiler(self): + self.skip_if_wg_semantics() # Transform inference fails. + def kernel(x_ref, o_ref): with jax.named_scope("add"): with jax.named_scope("load"): o_ref[...]
= o with tempfile.TemporaryDirectory() as tmpdir: x = jnp.arange(256).astype(jnp.float32) - y = pl.pallas_call( + y = self.pallas_call( kernel, out_shape=jax.ShapeDtypeStruct([256], jnp.float32), - compiler_params=plgpu.GPUCompilerParams( + compiler_params=plgpu.CompilerParams( profile_space=16, profile_dir=tmpdir ), )(x) jax.block_until_ready(y) jax.effects_barrier() [name] = os.listdir(tmpdir) - with open(os.path.join(tmpdir, name), "r") as f: + with open(os.path.join(tmpdir, name)) as f: data = f.read() self.assertEqual(data.count('"name": "add"'), 2) self.assertEqual(data.count('"name": "load"'), 2) @@ -1221,20 +1536,13 @@ def kernel(x_ref, o_ref): (jnp.uint32, jnp.int32), (jnp.int32, jnp.uint32), ], - thread_semantics=[*plgpu.ThreadSemantics], ) - def test_bitcast_convert_type(self, dtypes, thread_semantics): + def test_bitcast_convert_type(self, dtypes): in_dtype, out_dtype = dtypes m, n = 16, 8 out_shape = jax.ShapeDtypeStruct((m, n), out_dtype) - @functools.partial( - pl.pallas_call, - out_shape=out_shape, - compiler_params=plgpu.GPUCompilerParams( - thread_semantics=thread_semantics - ), - ) + @functools.partial(self.pallas_call, out_shape=out_shape) def convert(x_ref, y_ref): y_ref[...] = jax.lax.bitcast_convert_type(x_ref[...], out_shape) @@ -1243,17 +1551,547 @@ def convert(x_ref, y_ref): convert(x), jax.lax.bitcast_convert_type(x, out_dtype) ) + def test_optimization_barrier(self): + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct((128,), jnp.float32), + ) + def kernel(x_ref, o_ref): + o_ref[...] = lax.optimization_barrier(x_ref[...]) + + x = jax.lax.iota(jnp.float32, 128) + np.testing.assert_array_equal(kernel(x), x) + + def test_optimization_barrier_multiple_inputs(self): + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct((128,), jnp.float32), + ) + def kernel(x_ref, y_ref, o_ref): + x, y = lax.optimization_barrier([x_ref[...], y_ref[...]]) + o_ref[...] = x + y + + x = jax.lax.iota(jnp.float32, 128) + y = jax.lax.iota(jnp.float32, 128) * 3 + np.testing.assert_array_equal(kernel(x, y), x + y) + + def test_smem_aliasing_works(self): + self.skip_if_wg_semantics() + + in_shape = (2, 256) + + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct([128], jnp.float32), + in_specs=[pl.BlockSpec(in_shape)], + out_specs=pl.BlockSpec((128,), memory_space=plgpu.GMEM), + scratch_shapes=[ + plgpu.RefUnion( + # Note: this test exposes internals that we don't particularly + # want to uphold for the sake of testing the functionality of the + # API. It's expected that this test might end up breaking in the + # future, e.g. if we decide to change our alignment requirements + # on SMEM refs---and that's OK. Users should explicitly NOT rely + # on this exact behaviour. + # + # Use a value larger than the number of bytes used for SMEM + # alignment (1024) in order to make sure that the second ref + # in the second group aliases the single ref in the first group. + plgpu.SMEM(in_shape, jnp.float32), + [ + plgpu.SMEM((256,), jnp.bfloat16), + # Add an arbitrary level of nesting to make sure that we + # support PyTrees. + [ + plgpu.SMEM( + (128,), + jnp.float32, + transforms=(plgpu.TilingTransform((64,)),), + ), + ] + ], + ) + ], + ) + def kernel(x_ref, o_ref128, aliased_ref): + smem_ref256, _, smem_ref128 = aliased_ref + # Ensure that extraction via index works the same as unfolding.
+ smem_ref128_2 = aliased_ref[2] + self.assertIsInstance(smem_ref128, state_types.TransformedRef) + self.assertIsInstance(smem_ref128_2, state_types.TransformedRef) + self.assertIs(smem_ref128.ref, smem_ref128_2.ref) + self.assertEqual(smem_ref128.transforms, smem_ref128_2.transforms) + extract_alias_transform, tile_transform = smem_ref128.transforms + # Ensure that the transforms provided in the scratch shapes have been + # passed correctly. + self.assertIsInstance(extract_alias_transform, gpu_core.ExtractAliasedRef) + self.assertIsInstance(tile_transform, gpu_core.UntileRef) + smem_ref256[...] = x_ref[...] + 1 + plgpu.commit_smem() + plgpu.copy_smem_to_gmem(smem_ref128, o_ref128) + + x = jnp.arange(512).astype(jnp.float32) + np.testing.assert_array_equal( + kernel(x.reshape(in_shape)).reshape((128,)), x[256 : 256 + 128] + 1 + ) + + def test_smem_aliasing_works_with_subbyte_dtypes(self): + self.skip_if_wg_semantics() + + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct([256], jnp.uint4), + in_specs=[pl.BlockSpec((128,))], + out_specs=pl.BlockSpec((256,), memory_space=plgpu.GMEM), + scratch_shapes=[ + plgpu.RefUnion( + # Note: this test exposes internals that we don't particularly + # want to uphold for the sake of testing the functionality of the + # API. It's expected that this test might end up breaking in the + # future, e.g. if we decide to change our alignment requirements + # on SMEM refs---and that's OK. Users should explicitly NOT rely + # on this exact behaviour. + # + # This allocation scheme is a bit complicated, but serves to + # test that + # 1. Refs are aligned correctly (currently to 1024 bytes); + # 2. (u)int4 references are not allocated more than 1 byte per + # 2 elements. + # The first group of refs serves to create two allocations, each + # aligned to 1024 bytes. The second group serves to create two + # allocations where the first one is exactly 1024 bytes, + # assuming 1 byte per 2 uint4 elements. As a result, if our + # implementation is correct, the second allocation of the second + # group should exactly alias the second allocation of the first + # group. + [ + plgpu.SMEM((128,), jnp.int8), + plgpu.SMEM((128,), jnp.int8), + ], + [plgpu.SMEM((2048,), jnp.uint4), plgpu.SMEM((256,), jnp.uint4)], + ) + ], + ) + def kernel(x_ref, o_refi4, aliased_ref): + _, smem_refi8, _, smem_refi4 = aliased_ref + smem_refi8[...] = x_ref[...]
+ plgpu.commit_smem() + plgpu.copy_smem_to_gmem(smem_refi4, o_refi4) + + def unpack_i4_as_i8(x): + x = x.reshape((128, 1)) + x_high = x >> 4 + x_low = x & 0xF + return jnp.concatenate([x_low, x_high], axis=-1).reshape((256,)) + + x = jnp.arange(128).astype(jnp.int8) + test_as_i8 = jax.lax.convert_element_type(kernel(x), new_dtype=jnp.int8) + np.testing.assert_array_equal(test_as_i8[:256], unpack_i4_as_i8(x)) + + def test_smem_aliasing_works_for_quantization(self): + self.skip_if_wg_semantics() + shape = (64, 256) + large_ty, small_ty = jnp.bfloat16, jnp.uint4 + large_swizzle = plgpu.SwizzleTransform(64 * jnp.finfo(large_ty).bits // 8) + small_swizzle = plgpu.SwizzleTransform(64 * jnp.iinfo(small_ty).bits // 8) + tiling = plgpu.TilingTransform((8, 64)) + + def kernel(x_gmem, o_gmem): + return pl.run_scoped( + functools.partial(scoped_kernel, x_gmem, o_gmem), + plgpu.RefUnion( + plgpu.SMEM(shape, large_ty, transforms=(tiling, large_swizzle)), + plgpu.SMEM(shape, small_ty, transforms=(tiling, small_swizzle)) + ), + plgpu.Barrier(num_barriers=1), + ) + + def scoped_kernel(x_gmem, o_gmem, aliased_ref, barrier): + ref_large_ty, ref_small_ty = aliased_ref + plgpu.copy_gmem_to_smem(x_gmem, ref_small_ty, barrier=barrier) + plgpu.barrier_wait(barrier) + ref_large_ty[...] = ref_small_ty[...].astype(ref_large_ty.dtype) * 3 + plgpu.commit_smem() + plgpu.copy_smem_to_gmem(ref_large_ty, o_gmem) + plgpu.wait_smem_to_gmem(0, wait_read_only=True) + + kernel_fn = self.pallas_call( + kernel, + in_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)], + out_specs=pl.BlockSpec(memory_space=plgpu.GMEM), + out_shape=jax.ShapeDtypeStruct(shape, large_ty), + grid=(1, 1), + ) + key = jax.random.key(42) + x = jax.random.randint(key, shape, 0, 4).astype(small_ty) + expected = x * 3 + np.testing.assert_array_equal(kernel_fn(x), expected) + + def test_assigning_to_ref_union_raises(self): + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct([128], jnp.float32), + in_specs=[pl.BlockSpec((128,))], + out_specs=pl.BlockSpec((128,), memory_space=plgpu.GMEM), + scratch_shapes=[plgpu.RefUnion(plgpu.SMEM((128,), jnp.float32))], + ) + def kernel(x_ref, o_ref128, aliased_ref): + aliased_ref[...] = x_ref[...] + 1 + plgpu.commit_smem() + plgpu.copy_smem_to_gmem(aliased_ref, o_ref128) + + with self.assertRaisesRegex(ValueError, "can't be assigned to"): + kernel(jnp.arange(128).astype(jnp.float32)) + + def test_loading_from_ref_union_works(self): + # `load_p` does not have a defined lowering for warpgroup semantics. + self.skip_if_wg_semantics() + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct([128], jnp.float32), + in_specs=[pl.BlockSpec((128,))] * 2, + out_specs=pl.BlockSpec((128,), memory_space=plgpu.GMEM), + scratch_shapes=[plgpu.RefUnion(plgpu.SMEM((128,), jnp.float32)), + plgpu.SMEM((128,), jnp.float32)], + ) + def kernel(x_ref, y_ref, o_ref128, ref_union, o_smem): + [aliased_ref] = ref_union + aliased_ref[...] = x_ref[...] + plgpu.commit_smem() + load_ref = lambda r: plgpu.load(r, (), layout=plgpu.Layout.TCGEN05_ROW) + # This is a regression test for b/423697560, where we used to fail to + # transform the dtype correctly when processing an aliased ref. + o_smem[...] 
= load_ref(aliased_ref) + load_ref(y_ref) + plgpu.commit_smem() + plgpu.copy_smem_to_gmem(o_smem, o_ref128) + + x, y = (jnp.arange(128).astype(jnp.float32) for _ in range(2)) + np.testing.assert_array_equal(kernel(x, y), x + y) + + @parameterized.parameters(1, 2, 3) + def test_nd_loop(self, sm_steps): + @functools.partial( + self.kernel, + out_shape=jax.ShapeDtypeStruct((sm_steps, 132, 128), jnp.int32), + grid=(132,), + grid_names=("sm",), + ) + def kernel(o_ref): + @plgpu.nd_loop((sm_steps, 4, 33), collective_axes="sm") + def _(idx): + assert len(idx) == 3 + # We need to use `mode="clip"`, because the indices are not static. + flat_idx = jnp.ravel_multi_index(idx, (sm_steps, 4, 33), mode="clip") + sm_step = lax.div( + flat_idx, lax.convert_element_type(lax.axis_size("sm"), jnp.int32) + ) + o_ref[sm_step, lax.axis_index("sm")] = lax.broadcast( + flat_idx, o_ref.shape[-1:] + ) + + result = kernel() + for sm_step in range(sm_steps): + np.testing.assert_array_equal( + result[sm_step], + jnp.tile((132 * sm_step + jnp.arange(132))[:, None], 128), + ) + + def test_lowering_error_context(self): + def body(x_ref, y_ref, barrier): + plgpu.copy_gmem_to_smem(x_ref, y_ref, barrier) + plgpu.barrier_wait(barrier) + + x = jnp.arange(127, dtype=jnp.int4) # Size is not a multiple of bytes + offending_line = "plgpu.copy_gmem_to_smem(x_ref, y_ref, barrier)" + try: + pl.pallas_call( + body, + in_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)], + out_specs=pl.BlockSpec(memory_space=plgpu.SMEM), + out_shape=x, + scratch_shapes=[plgpu.Barrier()], + )(x) + except: + # assertRaisesRegex does not let us match the traceback. + self.assertIn(offending_line, traceback.format_exc()) + else: + self.fail("Should have raised an exception") + + @parameterized.named_parameters( + ( + f"_{''.join(map(str, collective_dims))}={collective_size}{'_' + ''.join(map(str, noncollective_dims)) if noncollective_dims else ''}", + collective_dims, + noncollective_dims, + collective_size, + ) + for collective_dims in itertools.chain.from_iterable( + itertools.combinations("xyz", n) for n in range(1, 4) + ) + for noncollective_dims in itertools.chain.from_iterable( + itertools.combinations("xyz", n) for n in range(3) + ) + for collective_size in (1, 2, 4) + if all(d not in noncollective_dims for d in collective_dims) + ) + def test_tma_load_multicast(self, collective_dims, noncollective_dims, collective_dim_size): + """ + 1. Broadcast a GMEM slice to SMEM across collective CTAs. + 2. Send a SMEM slice from each collective CTA to reconstruct the GMEM slice. + It's not strictly necessary to use every collective CTA, but we use them + to test that the cluster axes are used correctly. + """ + + dtype = jnp.float16 + cluster = [1, 1, 1] + for d in collective_dims: + cluster["xyz".index(d)] = collective_dim_size + for d in noncollective_dims: + cluster["xyz".index(d)] = 2 + if math.prod(cluster) > 16: + self.skipTest("Cluster is too big.") + + collective_size = math.prod(cluster["xyz".index(d)] for d in collective_dims) + noncollective_size = math.prod(cluster) // collective_size + + swizzle = 128 + swizzle_elems = swizzle // jnp.dtype(dtype).itemsize + transforms = ( + plgpu.TilingTransform((8, swizzle_elems)), + plgpu.SwizzleTransform(swizzle), + ) + shape = (noncollective_size, collective_size * 8, swizzle_elems) + + def body(x_gmem, out_gmem, smem, tma_barrier): + # Compute the index in a subset of the cluster.
+ def cluster_id(axes): + idx, stride = 0, 1 + for d in sorted(axes): + idx += lax.axis_index(d) * stride + stride *= lax.axis_size(d) + return idx + + noncollective_idx = cluster_id(noncollective_dims) + collective_idx = cluster_id(collective_dims) + + plgpu.copy_gmem_to_smem( + x_gmem.at[noncollective_idx], + smem, + tma_barrier, + collective_axes=collective_dims) + plgpu.barrier_wait(tma_barrier) + + plgpu.commit_smem() + collective_slice = pl.ds(8 * collective_idx, 8) + plgpu.copy_smem_to_gmem( + smem.at[collective_slice], + out_gmem.at[noncollective_idx, collective_slice, :], + ) + plgpu.wait_smem_to_gmem(0) + + x = np.arange(np.prod(shape), dtype=dtype).reshape(shape) + kernel = plgpu.kernel( + body, + grid=cluster, + grid_names=("grid_x", "grid_y", "grid_z"), + cluster=cluster, + cluster_names=("x", "y", "z"), + out_shape=jax.ShapeDtypeStruct(shape, dtype), + scratch_shapes=( + plgpu.SMEM(shape[1:], dtype, transforms=transforms), + plgpu.Barrier(), + ) + ) + np.testing.assert_array_equal(kernel(x), x) + + +class PallasCallWarpPrimitiveSemanticsTest(PallasTest): + def setUp(self): + super().setUp() + if self.LOWERING_SEMANTICS != plgpu.LoweringSemantics.Lane: + self.skipTest("Test only works on Lane semantics") + + def test_axis_index(self): + warp_mesh = plgpu.WarpMesh(axis_name="warp") + @functools.partial(plgpu.kernel, + out_shape=jax.ShapeDtypeStruct((2, 128), jnp.int32)) + def kernel(y_ref): + def scope(ones_smem_ref, threes_smem_ref): + # Prepare data to copy. + ones_smem_ref[:] = jnp.ones((1, 128), jnp.int32) + threes_smem_ref[:] = jnp.ones((1, 128), jnp.int32) * 3 + plgpu.commit_smem() + @pl.core_map(warp_mesh) + def _(): + warp_id = lax.axis_index("warp") + # We cannot load/store inside of core_map, so we issue async + # copies instead to produce a testable result. + @pl.when(warp_id == 1) + def _(): + plgpu.copy_smem_to_gmem(ones_smem_ref, y_ref.at[0:1]) + @pl.when(warp_id == 3) + def _(): + plgpu.copy_smem_to_gmem(threes_smem_ref, y_ref.at[1:2]) + plgpu.wait_smem_to_gmem(0) + pl.run_scoped(scope, + plgpu.SMEM((1, 128), jnp.int32), + plgpu.SMEM((1, 128), jnp.int32) + ) + result = kernel() + expected = jnp.stack((jnp.ones((128,), jnp.int32), + jnp.ones((128,), jnp.int32) * 3), axis=0) + np.testing.assert_array_equal(result, expected) + + def test_errors_when_closing_over_array(self): + # We currently do not allow closing over arrays when mapping over + # a mesh, since we would need to present a view of the array local + # to each warp. + warp_mesh = plgpu.WarpMesh(axis_name="warp") + @functools.partial(plgpu.kernel, + out_shape=jax.ShapeDtypeStruct((32, 32), jnp.float32), + scratch_shapes=[plgpu.SMEM((32, 32), jnp.float32)]) + def kernel(out_ref, smem_ref): + arr = jnp.ones((32, 32), dtype=jnp.float32) + @pl.core_map(warp_mesh) + def _(): + smem_ref[...] = arr + 1 + plgpu.commit_smem() + plgpu.copy_smem_to_gmem(smem_ref, out_ref) + plgpu.wait_smem_to_gmem(0) + with self.assertRaisesRegex( + mgpu_lowering.LoweringError, + "Can only close over scalars and Refs when using core_map with " + "WarpMesh", + ): + kernel() + + def test_single_warp_scan(self): + warp_mesh = plgpu.WarpMesh(axis_name="warp") + @functools.partial(plgpu.kernel, + out_shape=jax.ShapeDtypeStruct((10, 128), jnp.int32)) + def kernel(y_ref): + def scope(smem_ref): + # Prepare data to copy. 
+ for i in range(10): + smem_ref[i, :] = jnp.ones_like(smem_ref.at[i]) * i + plgpu.commit_smem() + @pl.core_map(warp_mesh) + def _(): + warp_id = lax.axis_index("warp") + @pl.when(warp_id == 0) + def _(): + def loop_body(i, _): + _slice = pl.ds(i, 1) + plgpu.copy_smem_to_gmem(smem_ref.at[_slice], y_ref.at[_slice]) + lax.fori_loop(0, 10, loop_body, None) + plgpu.wait_smem_to_gmem(0) + pl.run_scoped(scope, plgpu.SMEM((10, 128), jnp.int32)) + result = kernel() + expected = jnp.stack( + [jnp.ones((128,), jnp.int32) * i for i in range(10)], axis=0) + np.testing.assert_array_equal(result, expected) + + def test_debug_print(self): + warp_mesh = plgpu.WarpMesh(axis_name="warp") + @functools.partial( + plgpu.kernel, + out_shape=jnp.zeros(128, np.int32), + ) + def kernel(ref): + ref[...] = ref[...] # Prevent kernel from being DCE'd + @pl.core_map(warp_mesh) + def _(): + warp_id = lax.axis_index("warp") + pl.debug_print("warp: {}", warp_id) + + with self.capture_stdout() as output: + jax.block_until_ready(kernel()) + self.assertEqual( + set(output().splitlines()), + { + "warp: 0", + "warp: 1", + "warp: 2", + "warp: 3", + }, + ) + + def test_copy_gmem_to_smem_from_different_warps(self): + # In this test, we issue a copy from warp 0 and await it in warp 1. + warp_mesh = plgpu.WarpMesh(axis_name="warp") + @functools.partial(plgpu.kernel, + out_shape=jax.ShapeDtypeStruct((32, 32), jnp.float32)) + def kernel(x_ref, y_ref): + def scope(smem_ref, tma_barrier): + @pl.core_map(warp_mesh) + def _(): + warp_id = lax.axis_index("warp") + @pl.when(warp_id == 0) + def _(): + plgpu.copy_gmem_to_smem(x_ref.at[32:64], smem_ref, tma_barrier) + + @pl.when(warp_id == 1) + def _(): + plgpu.barrier_wait(tma_barrier) + plgpu.copy_smem_to_gmem(smem_ref, y_ref) + plgpu.wait_smem_to_gmem(0) + pl.run_scoped(scope, + smem_ref=plgpu.SMEM((32, 32), jnp.float32), + tma_barrier=plgpu.Barrier()) + x = jax.random.uniform(jax.random.key(42), (64, 32), jnp.float32) + result = kernel(x) + np.testing.assert_array_equal(result, x[32:64]) + + +class PallasCallWGTest( + PallasCallTest, lowering_semantics=plgpu.LoweringSemantics.Warpgroup +): + ... + + def test_missing_primitive_lowerings_are_tracked(self): + # This test is a way to keep track of which primitives need to be adapted + # to using warpgroup semantics. Once the set is empty, we should be able to + # enable warpgroup semantics by default (assuming we haven't overspecialized + # lowerings).
+ rules = mgpu_lowering.mosaic_lowering_rules + wg_wg_lowered_primitives = set( + rules[(plgpu.LoweringSemantics.Warpgroup, + gpu_core.PrimitiveSemantics.Warpgroup)]) + lane_wg_lowered_primitives = set(rules[ + (plgpu.LoweringSemantics.Lane, gpu_core.PrimitiveSemantics.Warpgroup)]) + + actual_missing_primitives = (lane_wg_lowered_primitives - + wg_wg_lowered_primitives) + expected_missing_primitives = { + mgpu_primitives.inline_mgpu_p, + mgpu_primitives.broadcasted_iota_p, + mgpu_primitives.load_p, + mgpu_primitives.tcgen05_mma_p, + mgpu_primitives.commit_tmem_p, + lax.slice_p, + pallas_core.core_map_p, + pallas_primitives.semaphore_signal_p, + pallas_primitives.semaphore_wait_p, + pallas_primitives.semaphore_read_p, + checkify.check_p, + } + + self.assertSetEqual(actual_missing_primitives, expected_missing_primitives) + class PallasCallSm90ATest(PallasSm90ATest): @parameterized.parameters(False, True) def test_fori_loop_accumulator(self, force_while): - transforms = (plgpu.TilingTransform((64, 64)), plgpu.SwizzleTransform(128)) + if self.LOWERING_SEMANTICS == plgpu.LoweringSemantics.Lane: + transforms = (plgpu.TilingTransform((8, 64)), plgpu.SwizzleTransform(128)) + else: + transforms = () + @functools.partial( - pl.pallas_call, - in_specs=[plgpu.GPUBlockSpec((64, 64), lambda: (0, 0), transforms=transforms)], + self.pallas_call, + in_specs=[plgpu.BlockSpec((64, 64), transforms=transforms)], out_shape=jax.ShapeDtypeStruct((64, 64), jnp.float16), - out_specs=plgpu.GPUBlockSpec((64, 64), lambda: (0, 0)), + out_specs=plgpu.BlockSpec((64, 64)), ) def kernel(i_ref, o_ref): def scope(acc_ref): @@ -1263,7 +2101,8 @@ def scope(acc_ref): acc_ini = jnp.ones((64, 64), dtype=jnp.float16) np.testing.assert_array_equal(kernel(acc_ini), jnp.full((64, 64), 5, dtype=jnp.float16)) - def test_realistic_matmul(self): + @parameterized.product(lhs_transpose=[False, True], rhs_transpose=[False, True]) + def test_realistic_matmul(self, lhs_transpose, rhs_transpose): dtype = jnp.float16 swizzle = 128 elems_128b = swizzle // jnp.dtype(dtype).itemsize @@ -1273,7 +2112,11 @@ def test_realistic_matmul(self): m, k, n = grid_m * tile_m, grid_k * tile_k, grid_n * tile_n def kernel(a_ref, b_ref, o_ref, acc_ref): # Make sure tiling does not alter the shape of references + if lhs_transpose: + a_ref = plgpu.transpose_ref(a_ref, (1, 0)) assert a_ref.shape == (tile_m, tile_k) + if rhs_transpose: + b_ref = plgpu.transpose_ref(b_ref, (1, 0)) assert b_ref.shape == (tile_k, tile_n) assert o_ref.shape == acc_ref.shape == (tile_m, tile_n) plgpu.wgmma(acc_ref, a_ref, b_ref) @@ -1284,50 +2127,85 @@ def _epilogue(): plgpu.wgmma_wait(1) # We don't await the last WGMMA, hence delay_release=1 key1, key2 = jax.random.split(jax.random.key(42), 2) - a = jax.random.uniform(key1, shape=(m, k), dtype=dtype) - b = jax.random.uniform(key2, shape=(k, n), dtype=dtype) + a_shape = (k, m) if lhs_transpose else (m, k) + a = jax.random.uniform(key1, shape=a_shape, dtype=dtype) + b_shape = (n, k) if rhs_transpose else (k, n) + b = jax.random.uniform(key2, shape=b_shape, dtype=dtype) - res = pl.pallas_call( + if lhs_transpose: + lhs_spec = pl.BlockSpec( + (tile_k, tile_m), + lambda m, n, k: (k, m), + ) + else: + lhs_spec = pl.BlockSpec( + (tile_m, tile_k), + lambda m, n, k: (m, k), + ) + if rhs_transpose: + rhs_spec = pl.BlockSpec( + (tile_n, tile_k), + lambda m, n, k: (n, k), + ) + else: + rhs_spec = pl.BlockSpec( + (tile_k, tile_n), + lambda m, n, k: (k, n), + ) + out_spec = pl.BlockSpec( + (tile_m, tile_n), + lambda m, n, k: (m, n), + ) + + if 
self.LOWERING_SEMANTICS == plgpu.LoweringSemantics.Lane: + lhs_spec = plgpu.BlockSpec( + lhs_spec.block_shape, + lhs_spec.index_map, + transforms=( + plgpu.TilingTransform((8, elems_128b)), + plgpu.SwizzleTransform(128), + ), + ) + rhs_spec = plgpu.BlockSpec( + rhs_spec.block_shape, + rhs_spec.index_map, + transforms=( + plgpu.TilingTransform((8, elems_128b)), + plgpu.SwizzleTransform(128), + ), + ) + out_spec = plgpu.BlockSpec( + out_spec.block_shape, + out_spec.index_map, + transforms=( + plgpu.TilingTransform((8, elems_128b)), + plgpu.SwizzleTransform(128), + ), + ) + + res = self.pallas_call( kernel, - in_specs=[ - plgpu.GPUBlockSpec( - (tile_m, tile_k), - lambda m, n, k: (m, k), - transforms=( - plgpu.TilingTransform((64, elems_128b)), - plgpu.SwizzleTransform(128), - ), - ), - plgpu.GPUBlockSpec( - (tile_k, tile_n), - lambda m, n, k: (k, n), - transforms=( - plgpu.TilingTransform((elems_128b, elems_128b)), - plgpu.SwizzleTransform(128), - ), - ), - ], - out_specs=plgpu.GPUBlockSpec( - (tile_m, tile_n), - lambda m, n, k: (m, n), - transforms=( - plgpu.TilingTransform((64, elems_128b)), - plgpu.SwizzleTransform(128), - ), - ), + in_specs=[lhs_spec, rhs_spec], + out_specs=out_spec, out_shape=jax.ShapeDtypeStruct((m, n), jnp.float16), scratch_shapes=[plgpu.ACC((tile_m, tile_n), jnp.float32)], grid=(grid_m, grid_n, grid_k), - compiler_params=plgpu.GPUCompilerParams( + compiler_params=plgpu.CompilerParams( dimension_semantics=["parallel", "parallel", "sequential"], max_concurrent_steps=2, delay_release=1, ), )(a, b) - np.testing.assert_allclose(res, a @ b, rtol=1e-3) + np.testing.assert_allclose( + res, + (a.T if lhs_transpose else a) @ (b.T if rhs_transpose else b), + rtol=1e-3, + ) @parameterized.parameters(jnp.float16, jnp.float32) def test_wgmma(self, dtype): + self.skip_if_wg_semantics() + # TensorCores can only fuse transposes of 16-bit values, and RHS # is expected to be column major by default. 
rhs_transpose = jnp.dtype(dtype).itemsize != 2 @@ -1349,27 +2227,25 @@ def scope(acc_ref): b_shape = b_shape[::-1] b = jax.random.uniform(key2, shape=b_shape, dtype=dtype) - rhs_transforms = (plgpu.TilingTransform((elems_128b, elems_128b)),) - if rhs_transpose: - rhs_transforms += (plgpu.TransposeTransform((1, 0, 2, 3)),) - res = pl.pallas_call( + rhs_transforms = (plgpu.TilingTransform((8, elems_128b)),) + res = self.pallas_call( kernel, in_specs=[ - plgpu.GPUBlockSpec( + plgpu.BlockSpec( (64, 128), lambda i, j: (i, j), transforms=( - plgpu.TilingTransform((64, elems_128b)), + plgpu.TilingTransform((8, elems_128b)), plgpu.SwizzleTransform(128), ), ), - plgpu.GPUBlockSpec( + plgpu.BlockSpec( b_shape, lambda *i: i, transforms=(*rhs_transforms, plgpu.SwizzleTransform(128)), ), ], - out_specs=plgpu.GPUBlockSpec((64, 192), lambda *i: i), + out_specs=plgpu.BlockSpec((64, 192), lambda *i: i), out_shape=jax.ShapeDtypeStruct((64, 192), jnp.float32), grid=(1, 1), )(a, b) @@ -1388,14 +2264,15 @@ def scope(acc_ref): a = jax.random.uniform(key1, shape=(64, 128), dtype=jnp.float16) b = jax.random.uniform(key2, shape=(128, 192), dtype=jnp.float16) - transforms = (plgpu.TilingTransform((64, 64)), plgpu.SwizzleTransform(128)) - res = pl.pallas_call( + transforms = () + if self.LOWERING_SEMANTICS == plgpu.LoweringSemantics.Lane: + transforms = (plgpu.TilingTransform((8, 64)), plgpu.SwizzleTransform(128)) + res = self.pallas_call( kernel, in_specs=[ - plgpu.GPUBlockSpec((64, 128), lambda: (0, 0), transforms=transforms), - plgpu.GPUBlockSpec((128, 192), lambda: (0, 0), transforms=transforms), + plgpu.BlockSpec(transforms=transforms), + plgpu.BlockSpec(transforms=transforms), ], - out_specs=plgpu.GPUBlockSpec((64, 192), lambda: (0, 0)), out_shape=jax.ShapeDtypeStruct((64, 192), jnp.float32), )(a, b) np.testing.assert_allclose(res, a @ b, rtol=1e-3) @@ -1411,20 +2288,24 @@ def scope(acc_ref): b = jax.random.uniform(key2, shape=(128, 192), dtype=jnp.float16) i = jax.random.uniform(key3, shape=(64, 192), dtype=jnp.float16) * 10 - transforms = (plgpu.TilingTransform((64, 64)), plgpu.SwizzleTransform(128)) - res = pl.pallas_call( + if self.LOWERING_SEMANTICS == plgpu.LoweringSemantics.Lane: + transforms = (plgpu.TilingTransform((8, 64)), plgpu.SwizzleTransform(128)) + else: + transforms = () + res = self.pallas_call( kernel, in_specs=[ - plgpu.GPUBlockSpec((64, 128), lambda: (0, 0), transforms=transforms), - plgpu.GPUBlockSpec((128, 192), lambda: (0, 0), transforms=transforms), - plgpu.GPUBlockSpec((64, 192), lambda: (0, 0), transforms=transforms), + plgpu.BlockSpec(transforms=transforms), + plgpu.BlockSpec(transforms=transforms), + plgpu.BlockSpec(transforms=transforms), ], - out_specs=plgpu.GPUBlockSpec((64, 192), lambda: (0, 0)), out_shape=jax.ShapeDtypeStruct((64, 192), jnp.float16), )(a, b, i) np.testing.assert_allclose(res, i + a @ b, rtol=2e-3) def test_wgmma_sliced_ref(self): + self.skip_if_wg_semantics() # Needs WGMMA to support slices. 
+ def kernel(a_ref, b_ref, o_ref): def scope(acc_ref): plgpu.wgmma(acc_ref, a_ref.at[0], b_ref.at[0]) @@ -1436,30 +2317,23 @@ def scope(acc_ref): a = jax.random.uniform(key1, shape=(2, 64, 128), dtype=jnp.float16) b = jax.random.uniform(key2, shape=(2, 128, 192), dtype=jnp.float16) - res = pl.pallas_call( + transforms = () + if self.LOWERING_SEMANTICS == plgpu.LoweringSemantics.Lane: + transforms = (plgpu.TilingTransform((8, 64)), plgpu.SwizzleTransform(128)) + + res = self.pallas_call( kernel, in_specs=[ - plgpu.GPUBlockSpec( - (2, 64, 128), lambda: (0, 0, 0), - transforms=( - plgpu.TilingTransform((64, 64)), - plgpu.SwizzleTransform(128), - ), - ), - plgpu.GPUBlockSpec( - (2, 128, 192), lambda: (0, 0, 0), - transforms=( - plgpu.TilingTransform((64, 64)), - plgpu.SwizzleTransform(128), - ), - ), + plgpu.BlockSpec(transforms=transforms), + plgpu.BlockSpec(transforms=transforms), ], - out_specs=plgpu.GPUBlockSpec((64, 192), lambda: (0, 0)), out_shape=jax.ShapeDtypeStruct((64, 192), jnp.float32), )(a, b) np.testing.assert_allclose(res, a[0] @ b[0], rtol=1e-3) def test_wgmma_sliced_acc(self): + self.skip_if_wg_semantics() # Needs WGMMA to support slices. + swizzle = 128 elems_128b = swizzle // jnp.dtype(jnp.float16).itemsize def kernel(a_ref, b_ref, o_ref): @@ -1472,75 +2346,627 @@ def scope(acc_ref): key1, key2 = jax.random.split(jax.random.key(42), 2) a = jax.random.uniform(key1, shape=(64, 128), dtype=jnp.float16) b = jax.random.uniform(key2, shape=(128, 128), dtype=jnp.float16) - res = pl.pallas_call( + transforms = () + if self.LOWERING_SEMANTICS == plgpu.LoweringSemantics.Lane: + transforms = ( + plgpu.TilingTransform((8, elems_128b)), + plgpu.SwizzleTransform(128), + ) + res = self.pallas_call( kernel, in_specs=[ - plgpu.GPUBlockSpec( - (64, 128), - lambda i, j: (i, j), - transforms=( - plgpu.TilingTransform((64, elems_128b)), - plgpu.SwizzleTransform(128), - ), - ), - plgpu.GPUBlockSpec( - (128, 128), - lambda *i: i, - transforms=( - plgpu.TilingTransform((elems_128b, elems_128b)), - plgpu.SwizzleTransform(128), - ), - ), + plgpu.BlockSpec((64, 128), lambda *ij: ij, transforms=transforms), + plgpu.BlockSpec((128, 128), lambda *ij: ij, transforms=transforms), ], - out_specs=plgpu.GPUBlockSpec((64, 128), lambda *i: i), + out_specs=plgpu.BlockSpec((64, 128), lambda *ij: ij), out_shape=jax.ShapeDtypeStruct((64, 128), jnp.float32), grid=(1, 1), )(a, b) np.testing.assert_allclose(res, a @ b, rtol=1e-3) + @parameterized.product( + src_memory_space=[plgpu.SMEM, plgpu.GMEM], + layout=[plgpu.Layout.WGMMA_ROW, plgpu.Layout.WGMMA_COL], + m=[64, 128, 192], + ) + def test_load_to_wgmma_row_col_layout_with_indexing(self, src_memory_space, layout, m): + self.skip_if_wg_semantics() -class PipelineTest(PallasTest): + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct([2, m], jnp.float32), + in_specs=[pl.BlockSpec(memory_space=src_memory_space)], + out_specs=plgpu.BlockSpec(memory_space=plgpu.SMEM), + ) + def kernel(x_ref, o_ref): + for i in range(2): + x = plgpu.load( + x_ref, (i,), layout=layout, optimized=src_memory_space != plgpu.GMEM + ) + o_ref[i, ...] 
= x - def test_pipeline_mode(self): - def body(x_ref, y_ref, o_ref): - x = x_ref[:] - y = y_ref[:] - o_ref[:] = x + y + x = jnp.arange(2 * m, dtype=jnp.float32).reshape(2, m) + np.testing.assert_array_equal(kernel(x), x) - data_size = 64 * 256 - block_size = 256 + @parameterized.product( + src_memory_space=[plgpu.SMEM], + layout=[plgpu.Layout.WGMMA_ROW, plgpu.Layout.WGMMA_COL], + ) + def test_load_row_input_to_wgmma_with_transforms(self, src_memory_space, layout): + self.skip_if_wg_semantics() - x = jnp.arange(data_size, dtype=jnp.float32) - y = jnp.arange(data_size, dtype=jnp.float32) - in_specs = [ - pl.BlockSpec((block_size,), lambda *i: i, pipeline_mode=pl.Buffered(2)), - pl.BlockSpec((block_size,), lambda *i: i, pipeline_mode=pl.Buffered(1)) - ] - out_specs = pl.BlockSpec((block_size,), lambda *i: i) + m, k, n = 64, 128, 192 + key1, key2 = jax.random.split(jax.random.key(42), 2) + if layout == plgpu.Layout.WGMMA_ROW: + input_shape = (m,) + broadcast_dim = 0 + expand_dim = 1 + else: + input_shape = (k,) + broadcast_dim = 1 + expand_dim = 0 + a = jax.random.uniform(key1, shape=input_shape, dtype=jnp.float16) + b = jax.random.uniform(key2, shape=(k, n), dtype=jnp.float16) + def kernel(x_ref, y_ref, o_ref): + x = plgpu.load(x_ref, (), layout=layout) + x = lax.broadcast_in_dim(x, (m, k), [broadcast_dim]) - @jax.jit - def vadd(x, y): - return pl.pallas_call( - body, - out_shape=jax.ShapeDtypeStruct(x.shape, jnp.float32), - in_specs=in_specs, - out_specs=out_specs, - grid=data_size // block_size, - )(x, y) + def compute(acc_ref): + plgpu.wgmma(acc_ref, x, y_ref) + return acc_ref[...] - with self.assertRaisesRegex(Exception, "Pipeline mode is not supported"): - vadd(x, y) + out = pl.run_scoped(compute, plgpu.ACC((m, n), jnp.float32)) + o_ref[...] = out + f = self.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct([m, n], jnp.float32), + in_specs=( + pl.BlockSpec(memory_space=src_memory_space), + plgpu.BlockSpec( + transforms=( + plgpu.TilingTransform((8, 64)), + plgpu.SwizzleTransform(128), + ), + ), + ), + out_specs=plgpu.BlockSpec(memory_space=plgpu.SMEM), + ) - def test_manual(self): - max_concurrent_steps = 2 - num_steps = 4 + out_ref = ( + jnp.broadcast_to(jnp.expand_dims(a, axis=expand_dim), (m, k)) @ b + ) + np.testing.assert_allclose(f(a, b), out_ref, rtol=1e-3) + + +class PallasCallSm90AWGTest( + PallasCallSm90ATest, lowering_semantics=plgpu.LoweringSemantics.Warpgroup +): + ... + + +class PallasCallSm100ATest(PallasSm100ATest): + + @parameterized.parameters( + (False,), + (True,), + ) + def test_tmem(self, collective): + self.skip_if_wg_semantics() # TMEM read not wired up in the WG get rule. + swizzle_elems = 128 // jnp.dtype(jnp.float32).itemsize + transforms = ( + plgpu.TilingTransform((8, swizzle_elems)), + plgpu.SwizzleTransform(128), + ) + @functools.partial( + self.kernel, + out_shape=jnp.zeros((128, 128), jnp.float32), + scratch_shapes=[ + plgpu.TMEM((128, 128), jnp.float32, collective=collective), + plgpu.TMEM((128, 128), jnp.float32, collective=collective), + plgpu.SMEM((128, 128), jnp.float32, transforms=transforms), + plgpu.Barrier(), + ], + num_threads=1, + thread_name="x", + cluster=(2,) if collective else (), + cluster_names=("x",) if collective else (), + ) + def kernel(x_ref, y_ref, tmem_ref, tmem_ref2, smem_ref, barrier_ref): + plgpu.copy_gmem_to_smem(x_ref, smem_ref, barrier_ref) + plgpu.barrier_wait(barrier_ref) + # Exercise TMEM by roundtripping SMEM -> TMEM -> TMEM -> SMEM. + x_val = plgpu.load(smem_ref, (), layout=plgpu.Layout.TCGEN05) + tmem_ref[...] 
= x_val + 1 + plgpu.commit_tmem() + tmem_ref2[...] = tmem_ref[...] + plgpu.commit_tmem() + smem_ref[...] = tmem_ref2[...] + plgpu.commit_smem() + plgpu.copy_smem_to_gmem(smem_ref, y_ref) + plgpu.wait_smem_to_gmem(0) + + x = jax.random.uniform( + jax.random.key(0), shape=(128, 128), dtype=jnp.float32) + x_result = jax.block_until_ready(kernel(x)) + np.testing.assert_array_equal(x_result, x + 1) + + def test_tmem_column_slicing(self): + self.skip_if_wg_semantics() + swizzle_elems = 128 // jnp.dtype(jnp.float32).itemsize + transforms = ( + plgpu.TilingTransform((8, swizzle_elems)), + plgpu.SwizzleTransform(128), + ) + @functools.partial( + self.kernel, + out_shape=jnp.zeros((128, 128), jnp.float32), + scratch_shapes=[ + plgpu.TMEM((128, 256), jnp.float32), + plgpu.SMEM((128, 128), jnp.float32, transforms=transforms), + plgpu.Barrier(), + ], + num_threads=1, + thread_name="x", + ) + def kernel(x_ref, y_ref, tmem_ref, smem_ref, barrier_ref): + plgpu.copy_gmem_to_smem(x_ref, smem_ref, barrier_ref) + plgpu.barrier_wait(barrier_ref) + x_val = plgpu.load(smem_ref, (), layout=plgpu.Layout.TCGEN05) + tmem_slice = tmem_ref.at[:, 8:208].at[:, 0:128] + tmem_slice[...] = x_val + 1 + plgpu.commit_tmem() + smem_ref[...] = tmem_ref[:, 8:136] + plgpu.commit_smem() + plgpu.copy_smem_to_gmem(smem_ref, y_ref) + plgpu.wait_smem_to_gmem(0) + + x = jax.random.uniform( + jax.random.key(0), shape=(128, 128), dtype=jnp.float32) + x_result = jax.block_until_ready(kernel(x)) + np.testing.assert_array_equal(x_result, (x + 1)[:, 0:128]) + + @parameterized.parameters( + (jnp.sum,), + (jnp.max,) + ) + def test_reduce_with_tcgen05_layout(self, op): + self.skip_if_wg_semantics() + axis = -1 + swizzle_elems = 128 // jnp.dtype(jnp.float32).itemsize + transforms = ( + plgpu.TilingTransform((8, swizzle_elems)), + plgpu.SwizzleTransform(128), + ) + @functools.partial( + self.kernel, + out_shape=jnp.zeros((128,), jnp.float32), + scratch_shapes=[ + plgpu.SMEM((128, 128), jnp.float32, transforms=transforms), + plgpu.SMEM((128,), jnp.float32), + plgpu.Barrier(), + ], + num_threads=1, + thread_name="x", + ) + def kernel(x_ref, y_ref, smem_ref, smem_reduced_ref, barrier_ref): + plgpu.copy_gmem_to_smem(x_ref, smem_ref, barrier_ref) + plgpu.barrier_wait(barrier_ref) + x_val = plgpu.load(smem_ref, (), layout=plgpu.Layout.TCGEN05) + smem_reduced_ref[...] = op(x_val, axis=axis) + plgpu.commit_smem() + plgpu.copy_smem_to_gmem(smem_reduced_ref, y_ref) + plgpu.wait_smem_to_gmem(0) + + x = jax.random.uniform( + jax.random.key(0), shape=(128, 128), dtype=jnp.float32) + x_result = jax.block_until_ready(kernel(x)) + np.testing.assert_allclose(x_result, op(x, axis=axis), atol=1e-5) + + @parameterized.parameters((0,), (1,)) + def test_broadcast_in_dim_tcgen05_layout(self, axis): + self.skip_if_wg_semantics() + + @functools.partial( + self.kernel, + out_shape=jnp.zeros((128, 128), jnp.float32), + scratch_shapes=[ + plgpu.SMEM((128,), jnp.float32), + plgpu.SMEM((128, 128), jnp.float32), + plgpu.Barrier(), + ], + num_threads=1, + thread_name="x", + ) + def kernel(x_ref, y_ref, smem_ref, smem_out_ref, barrier_ref): + plgpu.copy_gmem_to_smem(x_ref, smem_ref, barrier_ref) + plgpu.barrier_wait(barrier_ref) + if axis == 0: + reduced = plgpu.load(smem_ref, (), layout=plgpu.Layout.TCGEN05_COL) + else: + reduced = plgpu.load(smem_ref, (), layout=plgpu.Layout.TCGEN05_ROW) + broadcasted = lax.broadcast_in_dim(reduced, (128, 128), [1 - axis]) + smem_out_ref[...] 
= broadcasted + plgpu.commit_smem() + plgpu.copy_smem_to_gmem(smem_out_ref, y_ref) + plgpu.wait_smem_to_gmem(0) + + x = jax.random.uniform(jax.random.key(0), shape=(128,), dtype=jnp.float32) + x_result = jax.block_until_ready(kernel(x)) + expected = jnp.expand_dims(x, axis=axis) + expected = jnp.broadcast_to(expected, (128, 128)) + np.testing.assert_array_equal(x_result, expected) + + @parameterized.product(shape=[(128, 128)], + swizzle=[128, 64, 32], + dtype=[jnp.float16, jnp.bfloat16], + lhs_tmem=[False, True], + transpose_rhs=[False, True], + transpose_lhs=[False, True]) + def test_simple_matmul(self, shape, swizzle, + dtype=jnp.float16, + lhs_tmem=False, + transpose_lhs=False, + transpose_rhs=False): + self.skip_if_wg_semantics() + if transpose_lhs and lhs_tmem: + self.skipTest("TMEM transpose not supported.") + # Test a matmul with a single block. + swizzle_elems = swizzle // jnp.dtype(dtype).itemsize + transforms = ( + plgpu.TilingTransform((8, swizzle_elems)), + plgpu.SwizzleTransform(swizzle), + ) + + def kernel(a_smem, b_smem, out_ref, acc_tmem, scratch_smem, barrier_ref, + a_tmem_ref): + if transpose_lhs: + a_smem = plgpu.transpose_ref(a_smem, (1, 0)) + if transpose_rhs: + b_smem = plgpu.transpose_ref(b_smem, (1, 0)) + if lhs_tmem: + lhs_ref = a_tmem_ref + lhs_ref[...] = plgpu.load(a_smem, (), layout=plgpu.Layout.TCGEN05) + plgpu.commit_tmem() + else: + lhs_ref = a_smem + plgpu.tcgen05_mma(acc_tmem, + lhs_ref, + b_smem, + barrier_ref, + accumulate=False) + plgpu.barrier_wait(barrier_ref) + scratch_smem[...] = acc_tmem[...].astype(dtype) + plgpu.commit_smem() + plgpu.copy_smem_to_gmem(scratch_smem, out_ref) + plgpu.wait_smem_to_gmem(0) + + scratch_shapes = [ + plgpu.TMEM(shape, jnp.float32, packed=False), + plgpu.SMEM(shape, dtype, transforms=transforms), + plgpu.Barrier(for_tensor_core=True), + ] + if lhs_tmem: + scratch_shapes.append(plgpu.TMEM(shape, dtype, packed=True)) + else: + scratch_shapes.append(None) + + f = self.pallas_call( + kernel, + in_specs=( + plgpu.BlockSpec(transforms=transforms, memory_space=plgpu.SMEM), + plgpu.BlockSpec(transforms=transforms, memory_space=plgpu.SMEM), + ), + out_specs=plgpu.BlockSpec(memory_space=plgpu.GMEM), + out_shape=jax.ShapeDtypeStruct(shape, dtype), + scratch_shapes=scratch_shapes, + ) + x = jax.random.uniform(jax.random.key(0), shape=shape, dtype=dtype) + y = jax.random.uniform(jax.random.key(1), shape=shape, dtype=dtype) + result = f(x, y) + if transpose_lhs: + x = jnp.transpose(x, (1, 0)) + if transpose_rhs: + y = jnp.transpose(y, (1, 0)) + expected = x @ y + np.testing.assert_allclose(result, expected, rtol=1e-3) + + def test_matmul_with_sliced_accumulator(self): + self.skip_if_wg_semantics() + dtype = jnp.bfloat16 + shape = (128, 128) + tmem_shape = (128, 2 * 128) + swizzle = 128 + + # Test a matmul with a single block. + swizzle_elems = swizzle // jnp.dtype(dtype).itemsize + transforms = ( + plgpu.TilingTransform((8, swizzle_elems)), + plgpu.SwizzleTransform(swizzle), + ) + + def kernel(a_smem, b_smem, out_ref, acc_tmem, scratch_smem, barrier_ref): + acc_tmem_slice = acc_tmem.at[slice(None), pl.dslice(0, 128)] + plgpu.tcgen05_mma(acc_tmem_slice, + a_smem, + b_smem, + barrier_ref, + accumulate=False) + plgpu.barrier_wait(barrier_ref) + scratch_smem[...] 
= acc_tmem_slice[...].astype(dtype) + plgpu.commit_smem() + plgpu.copy_smem_to_gmem(scratch_smem, out_ref) + plgpu.wait_smem_to_gmem(0) + + scratch_shapes = [ + plgpu.TMEM(tmem_shape, jnp.float32, packed=False), + plgpu.SMEM(shape, dtype, transforms=transforms), + plgpu.Barrier(for_tensor_core=True), + ] + + f = self.pallas_call( + kernel, + in_specs=( + plgpu.BlockSpec(transforms=transforms, memory_space=plgpu.SMEM), + plgpu.BlockSpec(transforms=transforms, memory_space=plgpu.SMEM), + ), + out_specs=plgpu.BlockSpec(memory_space=plgpu.GMEM), + out_shape=jax.ShapeDtypeStruct(shape, dtype), + scratch_shapes=scratch_shapes, + ) + x = jax.random.uniform(jax.random.key(0), shape=shape, dtype=dtype) + y = jax.random.uniform(jax.random.key(1), shape=shape, dtype=dtype) + result = f(x, y) + expected = x @ y + np.testing.assert_allclose(result, expected, rtol=1e-3) + + @parameterized.product( + m_n_k=[(256, 256, 256), (256, 128, 128), (256, 256, 64)], + swizzle=[128, 64, 32], + dtype=[jnp.float16, jnp.bfloat16], + lhs_tmem=[False, True], + ) + def test_simple_collective_matmul(self, m_n_k, swizzle, dtype, lhs_tmem): + self.skip_if_wg_semantics() + m, n, k = m_n_k + full_lhs_shape = (m, k) + full_rhs_shape = (k, n) + full_acc_shape = (m, n) + block_acc_shape = (m // 2, n) + block_lhs_shape = (m // 2, k) + block_rhs_shape = (k, n // 2) + # Test a collective (paired CTA) matmul on a single block. + swizzle_elems = swizzle // jnp.dtype(dtype).itemsize + transforms = ( + plgpu.TilingTransform((8, swizzle_elems)), + plgpu.SwizzleTransform(swizzle), + ) + + def kernel(a_gmem, b_gmem, out_gmem, a_smem, b_smem, + scratch_smem, acc_tmem, tma_barrier, mma_barrier, + cluster_barrier, lhs_tmem_ref): + cluster_idx = lax.axis_index("x") + slice_lhs = pl.ds(cluster_idx * block_lhs_shape[0], block_lhs_shape[0]) + slice_rhs = pl.ds(cluster_idx * block_rhs_shape[1], block_rhs_shape[1]) + + plgpu.copy_gmem_to_smem(a_gmem.at[slice_lhs, :], a_smem, tma_barrier) + plgpu.barrier_wait(tma_barrier) + plgpu.copy_gmem_to_smem(b_gmem.at[:, slice_rhs], b_smem, tma_barrier) + plgpu.barrier_wait(tma_barrier) + + if lhs_tmem: + lhs_ref = lhs_tmem_ref + lhs_ref[...] = plgpu.load(a_smem, (), layout=plgpu.Layout.TCGEN05) + plgpu.commit_tmem() + else: + lhs_ref = a_smem + + plgpu.barrier_arrive(cluster_barrier) + plgpu.barrier_wait(cluster_barrier) + + plgpu.tcgen05_mma( + acc_tmem, + lhs_ref, + b_smem, + mma_barrier, + accumulate=False, + collective_axis="x", + ) + plgpu.barrier_wait(mma_barrier) + scratch_smem[...] 
= acc_tmem[...].astype(dtype) + plgpu.commit_smem() + plgpu.copy_smem_to_gmem(scratch_smem, out_gmem.at[slice_lhs, :]) + plgpu.wait_smem_to_gmem(0) + + scratch_shapes = [ + plgpu.SMEM(block_lhs_shape, dtype, transforms=transforms), + plgpu.SMEM(block_rhs_shape, dtype, transforms=transforms), + plgpu.SMEM(block_acc_shape, dtype, transforms=transforms), + plgpu.TMEM(block_acc_shape, jnp.float32, collective=True), + plgpu.Barrier(), + plgpu.Barrier(for_tensor_core=True), + plgpu.ClusterBarrier(collective_axes=("x",)), + ] + if lhs_tmem: + scratch_shapes.append( + plgpu.TMEM(block_lhs_shape, dtype, collective=True, packed=True) + ) + else: + scratch_shapes.append(None) + + f = self.kernel( + kernel, + out_shape=jax.ShapeDtypeStruct(full_acc_shape, dtype), + grid=(1,), + grid_names=("_",), + cluster=(2,), + cluster_names=("x",), + scratch_shapes=scratch_shapes, + ) + x = jax.random.uniform(jax.random.key(0), shape=full_lhs_shape, dtype=dtype) + y = jax.random.uniform(jax.random.key(1), shape=full_rhs_shape, dtype=dtype) + result = f(x, y) + expected = x @ y + np.testing.assert_allclose(result, expected, rtol=1e-3) + + @parameterized.parameters((0,), (1,)) + def test_mma_barrier_indexing( + self, barrier_index, shape=(128, 128), swizzle=128, dtype=jnp.float16 + ): + self.skip_if_wg_semantics() + swizzle_elems = swizzle // jnp.dtype(dtype).itemsize + transforms = ( + plgpu.TilingTransform((8, swizzle_elems)), + plgpu.SwizzleTransform(swizzle), + ) + + def kernel(a_smem, b_smem, out_ref, acc_tmem, scratch_smem, barrier_ref): + plgpu.tcgen05_mma( + acc_tmem, + a_smem, + b_smem, + barrier_ref.at[barrier_index], + accumulate=False, + ) + plgpu.barrier_wait(barrier_ref.at[barrier_index]) + scratch_smem[...] = acc_tmem[...].astype(dtype) + plgpu.commit_smem() + plgpu.copy_smem_to_gmem(scratch_smem, out_ref) + plgpu.wait_smem_to_gmem(0) + + scratch_shapes = [ + plgpu.TMEM(shape, jnp.float32, packed=False), + plgpu.SMEM(shape, dtype, transforms=transforms), + plgpu.Barrier(num_barriers=2, for_tensor_core=True), + ] + f = self.pallas_call( + kernel, + in_specs=( + plgpu.BlockSpec(transforms=transforms, memory_space=plgpu.SMEM), + plgpu.BlockSpec(transforms=transforms, memory_space=plgpu.SMEM), + ), + out_specs=plgpu.BlockSpec(memory_space=plgpu.GMEM), + out_shape=jax.ShapeDtypeStruct(shape, dtype), + scratch_shapes=scratch_shapes, + ) + x = jax.random.uniform(jax.random.key(0), shape=shape, dtype=dtype) + y = jax.random.uniform(jax.random.key(1), shape=shape, dtype=dtype) + result = f(x, y) + expected = x @ y + np.testing.assert_allclose(result, expected, rtol=1e-3) + + def test_collective_partitioned_copy(self): + self.skip_if_wg_semantics() + block_size = (128, 128) + partitioned_block_size = (block_size[0] // 2, block_size[1]) + a = jax.random.uniform( + jax.random.key(0), shape=block_size, dtype=jnp.float32) + b = jax.random.uniform( + jax.random.key(1), shape=block_size, dtype=jnp.float32) + def kernel(a_gmem, b_gmem, out_gmem, + a_smem, b_smem, out_smem, + a_tma_barrier, b_tma_barrier, cluster_barrier): + cluster_idx = lax.axis_index("x") + out_slice = pl.ds(cluster_idx * partitioned_block_size[0], + partitioned_block_size[0]) + + @pl.core_map(plgpu.WarpMesh(axis_name="warp")) + def _per_warp(): + warp_id = lax.axis_index("warp") + @pl.when(warp_id == 0) + def _(): + plgpu.copy_gmem_to_smem( + a_gmem, + a_smem, + a_tma_barrier, + collective_axes="x", + partitioned_axis=0, + ) + plgpu.copy_gmem_to_smem( + b_gmem, + b_smem, + b_tma_barrier, + collective_axes="x", + partitioned_axis=0, + ) + # 
TODO(justinfu): Clean up this API where we need to explicitly wait + # only on the first block. + @pl.when(cluster_idx == 0) + def _(): + plgpu.barrier_wait(a_tma_barrier) + plgpu.barrier_wait(b_tma_barrier) + plgpu.barrier_arrive(cluster_barrier) + plgpu.barrier_wait(cluster_barrier) + out_smem[...] = a_smem[...] + b_smem[...] + plgpu.copy_smem_to_gmem(out_smem, out_gmem.at[out_slice]) + plgpu.wait_smem_to_gmem(0) + f = plgpu.kernel( + kernel, + out_shape=jax.ShapeDtypeStruct(block_size, jnp.float32), + grid=(1,), + grid_names=("_"), + cluster_names=("x",), + cluster=(2,), + scratch_shapes=( # type: ignore + plgpu.SMEM(partitioned_block_size, jnp.float32), + plgpu.SMEM(partitioned_block_size, jnp.float32), + plgpu.SMEM(partitioned_block_size, jnp.float32), + plgpu.Barrier(num_arrivals=1), + plgpu.Barrier(num_arrivals=1), + plgpu.ClusterBarrier(collective_axes=("x",)), + ), + ) + result = f(a, b) + np.testing.assert_array_equal(result, a + b) + + +class PallasCallSm100AWGTest( + PallasCallSm100ATest, lowering_semantics=plgpu.LoweringSemantics.Warpgroup +): + ... + + +class PipelineTest(PallasTest): + + def test_pipeline_mode(self): + def body(x_ref, y_ref, o_ref): + x = x_ref[:] + y = y_ref[:] + o_ref[:] = x + y + + data_size = 64 * 256 + block_size = 256 + + x = jnp.arange(data_size, dtype=jnp.float32) + y = jnp.arange(data_size, dtype=jnp.float32) + in_specs = [ + pl.BlockSpec((block_size,), lambda *i: i, pipeline_mode=pl.Buffered(2)), + pl.BlockSpec((block_size,), lambda *i: i, pipeline_mode=pl.Buffered(1)) + ] + out_specs = pl.BlockSpec((block_size,), lambda *i: i) + + @jax.jit + def vadd(x, y): + return self.pallas_call( + body, + out_shape=jax.ShapeDtypeStruct(x.shape, jnp.float32), + in_specs=in_specs, + out_specs=out_specs, + grid=data_size // block_size, + )(x, y) + + with self.assertRaisesRegex(Exception, "Pipeline mode is not supported"): + vadd(x, y) + + def test_manual(self): + max_concurrent_steps = 2 + num_steps = 4 def kernel(x_gmem, o_gmem): return pl.run_scoped( functools.partial(scoped_kernel, x_gmem, o_gmem), plgpu.SMEM((max_concurrent_steps, 32, 16), jnp.float32), plgpu.SMEM((max_concurrent_steps, 32, 16), jnp.float32), - plgpu.Barrier(1, num_barriers=max_concurrent_steps), + plgpu.Barrier(num_barriers=max_concurrent_steps), ) def scoped_kernel(x_gmem, o_gmem, x_smem, o_smem, barrier): @@ -1588,7 +3014,7 @@ def body(step, _): plgpu.wait_smem_to_gmem(0) x = jnp.arange(32 * 4 * 64).reshape(32 * 4, 64).astype(jnp.float32) - kernel_fn = pl.pallas_call( + kernel_fn = self.pallas_call( kernel, in_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)], out_specs=pl.BlockSpec(memory_space=plgpu.GMEM), @@ -1597,38 +3023,45 @@ def body(step, _): ) np.testing.assert_array_equal(kernel_fn(x), x + 1.0) - @parameterized.parameters( - ((),), - ((plgpu.TilingTransform((64, 32)), plgpu.SwizzleTransform(128)),), + @parameterized.product( + transforms=( + (), + (plgpu.TilingTransform((8, 32)), plgpu.SwizzleTransform(128)), + ), + repeats=(1, 3), ) - def test_emit(self, transforms): + def test_emit(self, transforms, repeats): + if transforms: + self.skip_if_wg_semantics() + num_steps = 4 def kernel(x_gmem, o_gmem): - plgpu.emit_pipeline( - kernel_body, - in_specs=[ - plgpu.GPUBlockSpec( - (64, 64), lambda i: (0, i), transforms=transforms - ) - ], - out_specs=[ - plgpu.GPUBlockSpec( - (64, 64), lambda i: (0, i), transforms=transforms - ) - ], - grid=(num_steps,), - max_concurrent_steps=2, - )(x_gmem, o_gmem) + for _ in range(repeats): + plgpu.emit_pipeline( + kernel_body, + in_specs=[ + 
plgpu.BlockSpec( + (64, 64), lambda i: (0, i), transforms=transforms + ) + ], + out_specs=[ + plgpu.BlockSpec( + (64, 64), lambda i: (0, i), transforms=transforms + ) + ], + grid=(num_steps,), + max_concurrent_steps=2, + )(x_gmem, o_gmem) - def kernel_body(x_smem, o_smem): + def kernel_body(_, x_smem, o_smem): # +1 for the indexing done by ``emit_pipeline`. self.assertLen(x_smem.transforms, len(transforms) + 1) o_smem[...] = x_smem[...] + 1.0 x = jnp.arange(64 * num_steps * 64) x = x.reshape(-1, num_steps * 64).astype(jnp.float32) - kernel_fn = pl.pallas_call( + kernel_fn = self.pallas_call( kernel, in_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)], out_specs=pl.BlockSpec(memory_space=plgpu.GMEM), @@ -1647,7 +3080,7 @@ def kernel(x_gmem, o_gmem): grid=(), )(x_gmem, o_gmem) - def nested_kernel(x_gmem, o_gmem): + def nested_kernel(_, x_gmem, o_gmem): plgpu.emit_pipeline( nested_kernel_body, in_specs=[pl.BlockSpec((32, 16), lambda i: (0, i))], @@ -1656,12 +3089,12 @@ def nested_kernel(x_gmem, o_gmem): max_concurrent_steps=2, )(x_gmem, o_gmem) - def nested_kernel_body(x_smem, o_smem): + def nested_kernel_body(_, x_smem, o_smem): o_smem[...] = x_smem[...] + 1.0 x = jnp.arange(32 * num_steps * 16) x = x.reshape(-1, num_steps * 16).astype(jnp.float32) - kernel_fn = pl.pallas_call( + kernel_fn = self.pallas_call( kernel, in_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)], out_specs=pl.BlockSpec(memory_space=plgpu.GMEM), @@ -1681,12 +3114,12 @@ def kernel(x_gmem, o_gmem): max_concurrent_steps=2, )(x_gmem, o_gmem) - def kernel_body(x_smem, o_smem): + def kernel_body(_, x_smem, o_smem): o_smem[...] = x_smem[...] + 1.0 x = jnp.arange(32 * num_steps * 16) x = x.reshape(-1, num_steps * 16).astype(jnp.float32) - kernel_fn = pl.pallas_call( + kernel_fn = self.pallas_call( kernel, in_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)], out_specs=pl.BlockSpec(memory_space=plgpu.GMEM), @@ -1714,12 +3147,12 @@ def kernel(x_gmem, o_gmem): max_concurrent_steps=2, )(x_gmem, o_gmem) - def kernel_body(x_smem, o_smem): + def kernel_body(_, x_smem, o_smem): o_smem[...] = x_smem[...] + 1.0 x = jnp.arange(num_steps1 * 32 * num_steps2 * 16) x = x.reshape(-1, num_steps2 * 16).astype(jnp.float32) - kernel_fn = pl.pallas_call( + kernel_fn = self.pallas_call( kernel, in_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)], out_specs=pl.BlockSpec(memory_space=plgpu.GMEM), @@ -1729,25 +3162,32 @@ def kernel_body(x_smem, o_smem): y = x + 1.0 np.testing.assert_array_equal(kernel_fn(x), y) - def test_emit_with_2d_grid(self): + @parameterized.product(static=[False, True], short=[False, True]) + def test_emit_with_2d_grid(self, static, short): num_steps1 = 4 num_steps2 = 5 + if short: + num_steps1 = num_steps2 = 1 def kernel(x_gmem, o_gmem): + grid = (num_steps1, num_steps2) + if static: + grid = jax.tree.map(jnp.asarray, grid) + plgpu.emit_pipeline( kernel_body, in_specs=[pl.BlockSpec((32, 16, 8), lambda i, j: (0, i, j))], out_specs=[pl.BlockSpec((32, 16, 8), lambda i, j: (0, i, j))], - grid=(num_steps1, num_steps2), + grid=grid, max_concurrent_steps=2, )(x_gmem, o_gmem) - def kernel_body(x_smem, o_smem): + def kernel_body(_, x_smem, o_smem): o_smem[...] = x_smem[...] 
+ 1.0 x = jnp.arange(32 * num_steps1 * 16 * num_steps2 * 8) x = x.reshape(-1, num_steps1 * 16, num_steps2 * 8).astype(jnp.float32) - kernel_fn = pl.pallas_call( + kernel_fn = self.pallas_call( kernel, in_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)], out_specs=pl.BlockSpec(memory_space=plgpu.GMEM), @@ -1755,10 +3195,43 @@ def kernel_body(x_smem, o_smem): ) np.testing.assert_array_equal(kernel_fn(x), x + 1.0) + def test_emit_with_carry(self): + num_steps = 4 + + def kernel(o_gmem): + plgpu.emit_pipeline( + kernel_body, + out_specs=[pl.BlockSpec((64, 64), lambda i: (0, i))], + grid=(num_steps,), + max_concurrent_steps=2, + init_carry=0, + )(o_gmem) + + def kernel_body(_, o_smem, carry): + o_smem[...] = lax.broadcast(carry, o_smem.shape) + return carry + 1 + + kernel_fn = self.pallas_call( + kernel, + out_specs=pl.BlockSpec(memory_space=plgpu.GMEM), + out_shape=jax.ShapeDtypeStruct((64, num_steps * 64), jnp.int32), + ) + np.testing.assert_array_equal( + kernel_fn(), jnp.tile(jnp.repeat(jnp.arange(num_steps), 64), (64, 1)) + ) + + +class PipelineWGTest( + PipelineTest, lowering_semantics=plgpu.LoweringSemantics.Warpgroup +): + ... + class PipelineSm90ATest(PallasSm90ATest): def test_realistic_matmul(self): + self.skip_if_wg_semantics() # Needs WGMMA to support slices. + dtype = jnp.float16 swizzle = 128 elems_128b = swizzle // jnp.dtype(dtype).itemsize @@ -1768,8 +3241,15 @@ def test_realistic_matmul(self): tile_k = elems_128b m, k, n = grid_m * tile_m, grid_k * tile_k, grid_n * tile_n + transforms = () + if self.LOWERING_SEMANTICS == plgpu.LoweringSemantics.Lane: + transforms = ( + plgpu.TilingTransform((8, elems_128b)), + plgpu.SwizzleTransform(128), + ) + def kernel(a_gmem, b_gmem, o_smem, acc): - def kernel_body(a_smem, b_smem): + def kernel_body(_, a_smem, b_smem): assert a_smem.shape == (tile_m, tile_k) assert b_smem.shape == (tile_k, tile_n) plgpu.wgmma(acc, a_smem, b_smem) @@ -1780,22 +3260,12 @@ def kernel_body(a_smem, b_smem): plgpu.emit_pipeline( kernel_body, in_specs=[ - plgpu.GPUBlockSpec( - (tile_m, tile_k), - lambda k: (pid_m, k), - transforms=( - plgpu.TilingTransform((64, elems_128b)), - plgpu.SwizzleTransform(128), - ), - ), - plgpu.GPUBlockSpec( - (tile_k, tile_n), - lambda k: (k, pid_n), - transforms=( - plgpu.TilingTransform((elems_128b, elems_128b)), - plgpu.SwizzleTransform(128), - ), - ), + plgpu.BlockSpec( + (tile_m, tile_k), lambda k: (pid_m, k), transforms=transforms + ), + plgpu.BlockSpec( + (tile_k, tile_n), lambda k: (k, pid_n), transforms=transforms + ), ], grid=(grid_k,), max_concurrent_steps=2, @@ -1808,19 +3278,14 @@ def kernel_body(a_smem, b_smem): a = jax.random.uniform(key1, shape=(m, k), dtype=dtype) b = jax.random.uniform(key2, shape=(k, n), dtype=dtype) - res = pl.pallas_call( + res = self.pallas_call( kernel, in_specs=[ pl.BlockSpec(memory_space=plgpu.GMEM), - pl.BlockSpec(memory_space=plgpu.GMEM) + pl.BlockSpec(memory_space=plgpu.GMEM), ], - out_specs=plgpu.GPUBlockSpec( - (tile_m, tile_n), - lambda m, n: (m, n), - transforms=( - plgpu.TilingTransform((64, elems_128b)), - plgpu.SwizzleTransform(128), - ), + out_specs=plgpu.BlockSpec( + (tile_m, tile_n), lambda m, n: (m, n), transforms=transforms ), out_shape=jax.ShapeDtypeStruct((m, n), jnp.float16), scratch_shapes=[plgpu.ACC((tile_m, tile_n), jnp.float32)], @@ -1829,17 +3294,23 @@ def kernel_body(a_smem, b_smem): np.testing.assert_array_equal(res, a @ b) +class PipelineSm90AWGTest( + PipelineSm90ATest, lowering_semantics=plgpu.LoweringSemantics.Warpgroup +): + ... 
+ + class WarpSpecializedPipelineTest(PallasTest): - @parameterized.product(m=[512], n=[512], + @parameterized.product(m=[512], n=[512], repeats=[1, 3], manual_consumed_barriers=[False, True]) - def test_pipelined_copy(self, m, n, manual_consumed_barriers): + def test_pipelined_copy(self, m, n, repeats, manual_consumed_barriers): + self.skip_if_wg_semantics() # Times out! + x = jax.random.uniform(jax.random.key(0), (m, n), dtype=jnp.float16) - o = jnp.zeros((m, n), dtype=jnp.float16) blk_m = blk_n = 64 - o_last_block = jnp.zeros((blk_m, blk_n), dtype=jnp.float16) - def copy_kernel(x_smem, o_smem, o_last_block_smem, *consumed_barriers): + def copy_kernel(_, x_smem, o_smem, o_last_block_smem, *consumed_barriers): # TODO(justinfu): Have each wg compute a separate slice # after multiple-indexers are supported. # This is currently a race, but the values written are the same. @@ -1848,109 +3319,115 @@ def copy_kernel(x_smem, o_smem, o_last_block_smem, *consumed_barriers): if manual_consumed_barriers: [x_barrier] = consumed_barriers plgpu.barrier_arrive(x_barrier) - block_spec = plgpu.GPUBlockSpec( - block_shape=(blk_m, blk_n), - index_map=lambda i, j: (i, j), - transforms=[], - ) - pipeline = mgpu_pipeline.emit_pipeline_warp_specialized( - copy_kernel, - grid=(m // blk_m, n // blk_n), - memory_registers=40, - max_concurrent_steps=2, - num_compute_wgs=2, - wg_axis="wg", - manual_consumed_barriers=manual_consumed_barriers, - in_specs=[block_spec], - out_specs=[block_spec, - # Create an index-invariant output. - plgpu.GPUBlockSpec(block_shape=(blk_m, blk_n), - index_map=lambda i, j: (0, 0)) - ], - ) - mesh = plgpu.GPUMesh(grid=(1,), num_threads=3, axis_names=("_", "wg")) - def run(refs): - @pl.core_map( - mesh, compiler_params=plgpu.GPUCompilerParams(approx_math=True) + + spec = pl.BlockSpec( + block_shape=(blk_m, blk_n), index_map=lambda i, j: (i, j) + ) + def body(*gmem_refs): + pipeline = mgpu_pipeline.emit_pipeline_warp_specialized( + copy_kernel, + grid=(m // blk_m, n // blk_n), + memory_registers=40, + max_concurrent_steps=2, + num_compute_wgs=2, + wg_axis="wg", + manual_consumed_barriers=manual_consumed_barriers, + in_specs=[spec], + out_specs=[ + spec, + # Create an index-invariant output. + pl.BlockSpec( + block_shape=(blk_m, blk_n), index_map=lambda i, j: (0, 0) + ), + ], ) - def _kernel_entry(): - pipeline(*refs) - @jax.jit - def run_function(x, o, o_last_block): - _, out, out_last = pl.run_state(run)((x, o, o_last_block)) - return (out, out_last) - out, out_last_block = run_function(x, o, o_last_block) + for _ in range(repeats): + pipeline(*gmem_refs) # Make sure we can run the pipeline multiple times + kernel = self.kernel( + body, + out_shape=( + jax.ShapeDtypeStruct((m, n), jnp.float16), + jax.ShapeDtypeStruct((blk_m, blk_n), jnp.float16), + ), + compiler_params=plgpu.CompilerParams(approx_math=True), + grid=(1,), + grid_names=("_",), + num_threads=3, + thread_name="wg", + ) + out, out_last_block = kernel(x) np.testing.assert_array_equal(out, x) np.testing.assert_array_equal(out_last_block, x[-blk_m:, -blk_n:]) - def test_elementwise_add(self, m=256, n=256, num_compute_wgs=2): + @parameterized.product( + m=[256, 64], n=[256, 64], num_compute_wgs=[1, 2], static=[False, True] + ) + def test_elementwise_add(self, m, n, num_compute_wgs, static): + self.skip_if_wg_semantics() # Crashes! 
+ blk_m = blk_n = 64 - x = jax.random.uniform(jax.random.key(0), (m, n), dtype=jnp.float32) - y = jax.random.uniform(jax.random.key(1), (m, n), dtype=jnp.float32) - o = jnp.zeros((m, n), dtype=jnp.float32) + spec = pl.BlockSpec( + block_shape=(blk_m, blk_n), index_map=lambda i, j: (i, j) + ) - def tiled_add_kernel(x_smem, y_smem, o_smem): + def tiled_add_kernel(_, x_smem, y_smem, o_smem): # TODO(justinfu): Have each wg compute a separate slice # after multiple-indexers are supported. # This is currently a race, but the values written are the same. o_smem[...] = x_smem[...] + y_smem[...] - pipeline = mgpu_pipeline.emit_pipeline_warp_specialized( - tiled_add_kernel, - grid=(m // blk_m, n // blk_n), - max_concurrent_steps=2, - num_compute_wgs=num_compute_wgs, - memory_registers=40, - wg_axis="wg", - in_specs=[ - plgpu.GPUBlockSpec( - block_shape=(blk_m, blk_n), - index_map=lambda i, j: (i, j), - transforms=[]), - plgpu.GPUBlockSpec( - block_shape=(blk_m, blk_n), - index_map=lambda i, j: (i, j), - transforms=[]), - ], - out_specs=[ - plgpu.GPUBlockSpec( - block_shape=(blk_m, blk_n), - index_map=lambda i, j: (i, j), - transforms=[])], - ) - mesh = plgpu.GPUMesh( - grid=(1,), num_threads=num_compute_wgs + 1, axis_names=("_", "wg") + def pipeline(*gmem_refs): + grid = (m // blk_m, n // blk_n) + if not static: + grid = jax.tree.map(jnp.asarray, grid) + return mgpu_pipeline.emit_pipeline_warp_specialized( + tiled_add_kernel, + grid=grid, + max_concurrent_steps=2, + num_compute_wgs=num_compute_wgs, + memory_registers=40, + wg_axis="wg", + in_specs=[spec, spec], + out_specs=[spec], + )(*gmem_refs) + + kernel = self.kernel( + pipeline, + out_shape=jax.ShapeDtypeStruct((m, n), jnp.float32), + compiler_params=plgpu.CompilerParams(approx_math=True), + grid=(1,), + grid_names=("_",), + num_threads=num_compute_wgs + 1, + thread_name="wg", ) - def run(refs): - @pl.core_map( - mesh, compiler_params=plgpu.GPUCompilerParams(approx_math=True) - ) - def _kernel_entry(): - pipeline(*refs) - @jax.jit - def run_function(x, y, o): - _, _, out = pl.run_state(run)((x, y, o)) - return out - out = run_function(x, y, o) - reference = x + y - np.testing.assert_allclose(out, reference, atol=1e-4) + x = jax.random.uniform(jax.random.key(0), (m, n), dtype=jnp.float32) + y = jax.random.uniform(jax.random.key(1), (m, n), dtype=jnp.float32) + np.testing.assert_allclose(kernel(x, y), x + y, atol=1e-4) def test_carry_accumulate(self, m=256, n=256, num_compute_wgs=2): blk_m = blk_n = 64 - x = jax.random.uniform(jax.random.key(0), (m, n), dtype=jnp.float32) - acc_init = jnp.zeros((blk_m, blk_n), dtype=jnp.float32) - def _scoped(acc_smem, x_gmem, acc_gmem): - def _compute_thread(): + @functools.partial( + self.kernel, + out_shape=jax.ShapeDtypeStruct((blk_m, blk_n), jnp.float32), + scratch_shapes=[ + plgpu.SMEM((blk_m, blk_n), jnp.float32), + ], + compiler_params=plgpu.CompilerParams(approx_math=True), + grid=(1,), + grid_names=("_",), + num_threads=num_compute_wgs + 1, + thread_name="wg", + ) + def kernel(x_gmem, acc_gmem, acc_smem): + def _compute_thread(pipeline_fn): # Cast the init value to the same layout as x_smem, so the pipeline loop # carry has a constant signature. o_acc = plgpu.layout_cast( jnp.full((blk_m, blk_n,), 0, dtype=jnp.float32), plgpu.Layout.WG_STRIDED((blk_m, blk_n), vec_size=2)) - carry_init = (o_acc,) # Pass control to the pipeline emitter and return the final carry. 
- final_carry = (yield carry_init) - o_final, = final_carry + o_final = pipeline_fn(o_acc) # Note that both compute WGs are doing identical work so the potential # race condition on the store here won't affect the result. acc_smem[...] = o_final @@ -1958,10 +3435,9 @@ def _compute_thread(): plgpu.copy_smem_to_gmem(acc_smem, acc_gmem) plgpu.wait_smem_to_gmem(0) - def tiled_acc_kernel(x_smem, carry): - o_carry, = carry - new_carry = x_smem[...] + o_carry - return (new_carry,) + def tiled_acc_kernel(_, x_smem, carry): + new_carry = x_smem[...] + carry + return new_carry pipeline = mgpu_pipeline.emit_pipeline_warp_specialized( tiled_acc_kernel, @@ -1970,79 +3446,68 @@ def tiled_acc_kernel(x_smem, carry): num_compute_wgs=num_compute_wgs, memory_registers=40, wg_axis="wg", - carry_coroutine=_compute_thread, + compute_context=_compute_thread, in_specs=[ - plgpu.GPUBlockSpec( - block_shape=(blk_m, blk_n), - index_map=lambda i, j: (i, j), - transforms=[]), + pl.BlockSpec( + block_shape=(blk_m, blk_n), index_map=lambda i, j: (i, j) + ) ], out_specs=[], ) pipeline(x_gmem) - mesh = plgpu.GPUMesh( - grid=(1,), - num_threads=num_compute_wgs + 1, - axis_names=("_", "wg",), - ) - def run(refs): - x_ref, acc_ref = refs - @pl.core_map(mesh) - def _kernel_entry(): - pl.run_scoped( - functools.partial(_scoped, x_gmem=x_ref, acc_gmem=acc_ref), - plgpu.SMEM((blk_m, blk_n), jnp.float32) - ) - @jax.jit - def run_function(x, acc): - _, out_acc = pl.run_state(run)((x, acc)) - return out_acc - out_acc = run_function(x, acc_init) + x = jax.random.uniform(jax.random.key(0), (m, n), dtype=jnp.float32) ref = jnp.sum(jnp.stack(np.split(x, m // blk_m, axis=0)), axis=0) ref = jnp.sum(jnp.stack(np.split(ref, n // blk_n, axis=1)), axis=0) - np.testing.assert_allclose(out_acc, ref, atol=1e-4) + np.testing.assert_allclose(kernel(x), ref, atol=1e-4) + + +class WarpSpecializedPipelineWGTest( + WarpSpecializedPipelineTest, + lowering_semantics=plgpu.LoweringSemantics.Warpgroup, +): + ... 
-class CoreMapTest(PallasTest): +class CoreMapTest(PallasTest, jtu.CudaArchSpecificTest): def test_multiple_wg(self): - mesh = plgpu.GPUMesh(num_threads=2, axis_names=("y",)) - @jax.jit - def f(): - @pl.run_state - def inner(y_ref): - @pl.core_map(mesh) - def kernel(): - wg_idx = jax.lax.axis_index("y") - y_ref[wg_idx] = jnp.broadcast_to(wg_idx, (128,)) - y_init = jnp.zeros((2, 128), np.int32) - return inner(y_init) + @functools.partial( + self.kernel, + out_shape=jnp.zeros((2, 128), np.int32), + num_threads=2, + thread_name="wg", + ) + def kernel(o_ref): + wg_idx = jax.lax.axis_index("wg") + o_ref[wg_idx] = jnp.broadcast_to(wg_idx, (128,)) + np.testing.assert_array_equal( - f(), np.repeat(np.arange(2), 128).reshape(2, 128) + kernel(), np.repeat(np.arange(2), 128).reshape(2, 128) ) def test_multiple_wg_with_grid(self): - mesh = plgpu.GPUMesh(grid=(2, 2), num_threads=2, axis_names=("x", "y", "wg")) - @jax.jit - def f(): - @pl.run_state - def inner(y_ref): - @pl.core_map(mesh) - def kernel(): - xy_idx = jax.lax.axis_index(("x", "y")) - yx_idx = jax.lax.axis_index(("y", "x")) - wg_idx = jax.lax.axis_index("wg") - num_wgs = jax.lax.psum(1, "wg") - y_ref[xy_idx, wg_idx] = jnp.broadcast_to( - yx_idx * num_wgs + wg_idx, (128,) - ) - y_init = jnp.zeros((4, 2, 128), np.int32) - return inner(y_init) + @functools.partial( + self.kernel, + out_shape=jnp.zeros((4, 2, 128), np.int32), + grid=(2, 2), + grid_names=("x", "y"), + num_threads=2, + thread_name="wg", + ) + def kernel(o_ref): + xy_idx = jax.lax.axis_index(("x", "y")) + yx_idx = jax.lax.axis_index(("y", "x")) + wg_idx = jax.lax.axis_index("wg") + num_wgs = jax.lax.axis_size("wg") + o_ref[xy_idx, wg_idx] = jnp.broadcast_to( + yx_idx * num_wgs + wg_idx, (128,) + ) + np.testing.assert_array_equal( - f(), np.repeat([0, 1, 4, 5, 2, 3, 6, 7], 128).reshape(4, 2, 128) + kernel(), np.repeat([0, 1, 4, 5, 2, 3, 6, 7], 128).reshape(4, 2, 128) ) def test_multiple_wg_with_squashed_grid(self): @@ -2053,104 +3518,366 @@ def test_multiple_wg_with_squashed_grid(self): y_dim = 5 z_dim = 7 num_threads = 2 - mesh = plgpu.GPUMesh(grid=(b, x_dim, y_dim, z_dim), - num_threads=num_threads, - axis_names=("b", "x", "y", "z", "wg")) - @jax.jit - def f(): - @pl.run_state - def inner(y_ref): - @pl.core_map(mesh) - def _(): - b_idx = jax.lax.axis_index("b") - x_idx = jax.lax.axis_index("x") - y_idx = jax.lax.axis_index("y") - z_idx = jax.lax.axis_index("z") - wg_idx = jax.lax.axis_index("wg") - bxyzw_idx = jax.lax.axis_index(("b", "x", "y", "z", "wg")) - y_ref[b_idx, x_idx, y_idx, z_idx, wg_idx] = jnp.broadcast_to( - bxyzw_idx, (128,) - ) - y_init = jnp.zeros((b, x_dim, y_dim, z_dim, num_threads, 128), np.int32) - return inner(y_init) - result = f()[:, :, :, :, :, 0] + @functools.partial( + self.kernel, + out_shape=jnp.zeros( + (b, x_dim, y_dim, z_dim, num_threads, 128), np.int32 + ), + grid=(b, x_dim, y_dim, z_dim), + grid_names=("b", "x", "y", "z"), + num_threads=num_threads, + thread_name="wg", + ) + def kernel(o_ref): + b_idx = jax.lax.axis_index("b") + x_idx = jax.lax.axis_index("x") + y_idx = jax.lax.axis_index("y") + z_idx = jax.lax.axis_index("z") + wg_idx = jax.lax.axis_index("wg") + bxyzw_idx = jax.lax.axis_index(("b", "x", "y", "z", "wg")) + o_ref[b_idx, x_idx, y_idx, z_idx, wg_idx] = jnp.broadcast_to( + bxyzw_idx, (128,) + ) + + result = kernel()[:, :, :, :, :, 0] ref = np.arange(b * x_dim * y_dim * z_dim * num_threads).reshape( - result.shape) + result.shape + ) np.testing.assert_array_equal(result, ref) - def test_cross_wg_barrier(self): - mesh = 
plgpu.GPUMesh(num_threads=2, axis_names=("wg",)) + self.skip_if_wg_semantics() # Times out! - @jax.jit - def f(): - @pl.run_state - def inner(y_ref): - @pl.core_map(mesh) - def kernel(): - def scoped(barrier): - plgpu.barrier_arrive(barrier) - plgpu.barrier_wait(barrier) - wg_idx = jax.lax.axis_index("wg") - y_ref[wg_idx] = jnp.broadcast_to(wg_idx, (128,)) - # Each warpgroup is a single logical thread! - pl.run_scoped(scoped, plgpu.Barrier(num_arrivals=2)) - y_init = jnp.zeros((2, 128), np.int32) - return inner(y_init) - np.testing.assert_array_equal(f(), np.repeat([0, 1], 128).reshape(2, 128)) + @functools.partial( + self.kernel, + out_shape=jnp.zeros((2, 128), np.int32), + # Each warpgroup is a single logical thread! + scratch_shapes=[plgpu.Barrier(num_arrivals=2)], + num_threads=2, + thread_name="wg", + ) + def kernel(o_ref, barrier): + plgpu.barrier_arrive(barrier) + plgpu.barrier_wait(barrier) + wg_idx = jax.lax.axis_index("wg") + o_ref[wg_idx] = jnp.broadcast_to(wg_idx, (128,)) + + np.testing.assert_array_equal( + kernel(), np.repeat([0, 1], 128).reshape(2, 128) + ) + + def test_cluster(self): + self.skip_if_wg_semantics() # Needs debug_print in the MGPU dialect. + + @functools.partial( + self.kernel, + out_shape=jnp.zeros(128, np.int32), + grid=(2,), + grid_names=("x",), + cluster=(2,), + cluster_names=("cluster",), + ) + def kernel(ref): + block_idx = jax.lax.axis_index("x") + cluster_idx = jax.lax.axis_index("cluster") + pl.debug_print("block: {} cluster: {}", block_idx, cluster_idx) + + ref[...] = ref[...] + + with self.capture_stdout() as output: + jax.block_until_ready(kernel()) + self.assertEqual( + set(output().splitlines()), + { + "block: 0 cluster: 0", + "block: 1 cluster: 0", + "block: 0 cluster: 1", + "block: 1 cluster: 1", + }, + ) + + def test_realistic_matmul_with_cluster(self): + self.skip_if_wg_semantics() # Needs WGMMA to support slices. + self.skip_unless_sm90a() # Requires WGMMA. + + dtype = jnp.float16 + swizzle = 128 + elems_128b = swizzle // jnp.dtype(dtype).itemsize + grid_m, grid_k, grid_n = 132, 10, 32 + # TODO(slebedev): Remove ``grid_tile_n`` to simplify the test. 
+ grid_tile_n = 4 + assert grid_n % grid_tile_n == 0 + cluster_m = 2 + cluster_n = 2 + cluster_tile_n = min(cluster_n, grid_tile_n) + tile_m = tile_n = 128 + assert tile_m % elems_128b == 0 + tile_k = elems_128b + m, k, n = grid_m * tile_m, grid_k * tile_k, grid_n * tile_n + + transforms = ( + plgpu.TilingTransform((8, elems_128b)), + plgpu.SwizzleTransform(128), + ) + + max_concurrent_steps = 2 + delay_release = 1 + + @functools.partial( + self.kernel, + out_shape=jax.ShapeDtypeStruct((m, n), dtype), + scratch_shapes=[ + plgpu.SMEM( + (max_concurrent_steps, tile_m, tile_k), + dtype, + transforms=transforms, + ), + plgpu.SMEM( + (max_concurrent_steps, tile_k, tile_n), + dtype, + transforms=transforms, + ), + plgpu.SMEM((tile_m, tile_n), dtype, transforms=transforms), + plgpu.ACC((tile_m, tile_n), jnp.float32), + plgpu.Barrier(num_arrivals=2, num_barriers=max_concurrent_steps), + plgpu.ClusterBarrier( + collective_axes=(("x", "z"), "y"), + num_barriers=max_concurrent_steps, + ), + ], + grid=(grid_tile_n, grid_m, grid_n // grid_tile_n), + grid_names=("tile_n", "m", "n"), + cluster=(cluster_tile_n, cluster_m, cluster_n // cluster_tile_n), + cluster_names=("x", "y", "z"), + ) + def kernel( + a_gmem, + b_gmem, + o_gmem, + a_smem, + b_smem, + o_smem, + acc, + barrier, + cluster_barrier, + ): + m_slice = pl.ds(lax.axis_index("m") * tile_m, tile_m) + n_slice = pl.ds( + (lax.axis_index("tile_n") + lax.axis_index("n") * grid_tile_n) + * tile_n, + tile_n, + ) + + def fetch(step, slot): + if not isinstance(slot, int): # Skip in initialization. + plgpu.barrier_arrive(cluster_barrier.at[slot]) + plgpu.barrier_wait(cluster_barrier.at[slot]) + + k_slice = pl.ds(step * tile_k, tile_k) + plgpu.copy_gmem_to_smem( + a_gmem.at[m_slice, k_slice], + a_smem.at[slot], + barrier.at[slot], + collective_axes=("x", "z"), + ) + plgpu.copy_gmem_to_smem( + b_gmem.at[k_slice, n_slice], + b_smem.at[slot], + barrier.at[slot], + collective_axes="y", + ) + + # Initialize the pipeline. + for slot in range(min(max_concurrent_steps, grid_k)): + fetch(slot, slot) + + def body(step, _): + slot = step % max_concurrent_steps + plgpu.barrier_wait(barrier.at[slot]) + + plgpu.wgmma(acc, a_smem.at[slot], b_smem.at[slot]) + plgpu.wgmma_wait(delay_release) + + fetch_step = step + (max_concurrent_steps - delay_release) + fetch_slot = lax.rem(fetch_step, max_concurrent_steps) + jax.lax.cond( + lax.bitwise_and(step >= delay_release, fetch_step < grid_k), + lambda: fetch(fetch_step, fetch_slot), + lambda: None, + ) + return () + + jax.lax.fori_loop(0, grid_k, body, ()) + + # Finalize the pipeline. + o_smem[...] = acc[...].astype(dtype) + plgpu.commit_smem() + plgpu.copy_smem_to_gmem(o_smem, o_gmem.at[m_slice, n_slice]) + plgpu.wait_smem_to_gmem(0) + + key1, key2 = jax.random.split(jax.random.key(42), 2) + a = jax.random.uniform(key1, shape=(m, k), dtype=dtype) + b = jax.random.uniform(key2, shape=(k, n), dtype=dtype) + np.testing.assert_array_equal(kernel(a, b), a @ b) + + +class CoreMapWGTest( + CoreMapTest, lowering_semantics=plgpu.LoweringSemantics.Warpgroup +): + ... + + +class PrettyPrintingTest(PallasTest): + + def test_load(self): + + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct([2, 128], jnp.float32), + in_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)], + out_specs=plgpu.BlockSpec(memory_space=plgpu.SMEM), + ) + def kernel(x_ref, o_ref): + for i in range(2): + x = plgpu.load(x_ref, (i,)) + o_ref[i, ...] 
= x + + _ = str(jax.make_jaxpr(kernel)(jax.ShapeDtypeStruct((2, 128), jnp.float32))) + + def test_copy_primitives(self): + num_steps = 4 + + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct((64, 64), jnp.float32), + in_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)], + out_specs=pl.BlockSpec(memory_space=plgpu.GMEM), + ) + def kernel(x_gmem, o_gmem): + # ``plgpu.emit_pipeline`` is implemented in terms of async copy and + # synchronization primitives. + plgpu.emit_pipeline( + kernel_body, + in_specs=[pl.BlockSpec((64, 64), lambda i: (0, i))], + out_specs=[ + pl.BlockSpec( + (64, 64), + lambda i: (0, i), + ) + ], + grid=(num_steps,), + max_concurrent_steps=2, + )(x_gmem, o_gmem) + + def kernel_body(_, x_smem, o_smem): + o_smem[...] = x_smem[...] + 1.0 + + _ = str(jax.make_jaxpr(kernel)(jax.ShapeDtypeStruct((64, 64), jnp.float32))) + + def test_wgmma(self): + transforms = () + if self.LOWERING_SEMANTICS == plgpu.LoweringSemantics.Lane: + transforms = (plgpu.TilingTransform((8, 64)), plgpu.SwizzleTransform(128)) + + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct((64, 192), jnp.float32), + in_specs=[ + plgpu.BlockSpec(transforms=transforms), + plgpu.BlockSpec(transforms=transforms), + ], + ) + def kernel(a_ref, b_ref, o_ref): + def scope(acc_ref): + plgpu.wgmma(acc_ref, a_ref[...], b_ref) + return acc_ref[...] + + o_ref[...] = pl.run_scoped(scope, plgpu.ACC((64, 192), jnp.float32)) + + _ = str( + jax.make_jaxpr(kernel)( + jax.ShapeDtypeStruct((64, 128), jnp.float16), + jax.ShapeDtypeStruct((128, 192), jnp.float16), + ) + ) + + +class ExportTest(PallasTest): + + def test_export_succeeds(self): + out_shape = jax.ShapeDtypeStruct([128], jnp.float32) + + @functools.partial(self.pallas_call, out_shape=out_shape) + def kernel(x_ref, o_ref): + o_ref[...] = x_ref[...] + 1.0 + + _ = export.export(kernel)(out_shape) class ExamplesTest(PallasTest): # Basic def test_stage0(self): - def body(l_ref, r_ref, o_ref): + x = jnp.arange(128 * 128, dtype=jnp.float16).reshape(128, 128) + + @functools.partial(self.kernel, out_shape=x) + def kernel(l_ref, r_ref, o_ref): o_ref[...] = l_ref[...] + r_ref[...] 
- x = jnp.arange(128 * 128, dtype=jnp.float16).reshape(128, 128) - out = plgpu.kernel(body, out_shape=x)(x, x) - np.testing.assert_allclose(out, x + x) + np.testing.assert_allclose(kernel(x, x), x + x) # Multi-block kernels def test_stage1(self): row_block = 64 - def body(l_ref, r_ref, o_ref): + x = jnp.arange(128 * 128, dtype=jnp.float16).reshape(128, 128) + + @functools.partial( + self.kernel, out_shape=x, grid=(2,), grid_names=("rows",) + ) + def kernel(l_ref, r_ref, o_ref): my_slice = pl.ds(lax.axis_index("rows") * row_block, row_block) o_ref[my_slice] = l_ref[my_slice] + r_ref[my_slice] - x = jnp.arange(128 * 128, dtype=jnp.float16).reshape(128, 128) - out = plgpu.kernel(body, out_shape=x, grid=(2,), axis_names=("rows",))(x, x) - np.testing.assert_allclose(out, x + x) + np.testing.assert_allclose(kernel(x, x), x + x) # Async copies def test_stage3(self): row_block, col_block = 64, 128 - def body(l_ref, r_ref, o_ref): + + @functools.partial( + self.kernel, + out_shape=jax.ShapeDtypeStruct((128, 128), jnp.float16), + scratch_shapes=[ + *([plgpu.SMEM((row_block, col_block), jnp.float16)] * 3), + plgpu.Barrier(num_arrivals=2), + ], + grid=(2,), + grid_names=("rows",), + ) + def kernel(l_ref, r_ref, o_ref, l_smem, r_smem, o_smem, barrier): my_slice = pl.ds(lax.axis_index("rows") * row_block, row_block) - def scoped(l_smem, r_smem, o_smem, barrier): - plgpu.copy_gmem_to_smem(l_ref.at[my_slice], l_smem, barrier) - plgpu.copy_gmem_to_smem(r_ref.at[my_slice], r_smem, barrier) - plgpu.barrier_wait(barrier) - o_smem[...] = l_smem[...] + r_smem[...] - plgpu.commit_smem() - plgpu.copy_smem_to_gmem(o_smem, o_ref.at[my_slice]) - plgpu.wait_smem_to_gmem(0) - pl.run_scoped( - scoped, - *([plgpu.SMEM((row_block, col_block), jnp.float16)] * 3), - plgpu.Barrier(num_arrivals=2), - ) + plgpu.copy_gmem_to_smem(l_ref.at[my_slice], l_smem, barrier) + plgpu.copy_gmem_to_smem(r_ref.at[my_slice], r_smem, barrier) + plgpu.barrier_wait(barrier) + o_smem[...] = l_smem[...] + r_smem[...] + plgpu.commit_smem() + plgpu.copy_smem_to_gmem(o_smem, o_ref.at[my_slice]) + plgpu.wait_smem_to_gmem(0) x = jnp.arange(128 * 128, dtype=jnp.float16).reshape(128, 128) - out = plgpu.kernel(body, out_shape=x, grid=(2,), axis_names=("rows",))(x, x) - np.testing.assert_allclose(out, x + x) + np.testing.assert_allclose(kernel(x, x), x + x) # Pipelining def test_stage4(self): row_block, col_block = 64, 32 - def body(l_ref, r_ref, o_ref): - def compute(l_smem, r_smem, o_smem): + x = jnp.arange(128 * 128, dtype=jnp.float16).reshape(128, 128) + + @functools.partial( + self.kernel, out_shape=x, grid=(2,), grid_names=("rows",) + ) + def kernel(l_ref, r_ref, o_ref): + def compute(_, l_smem, r_smem, o_smem): o_smem[...] = l_smem[...] + r_smem[...] r = lax.axis_index("rows") block = pl.BlockSpec((row_block, col_block), lambda c: (r, c)) @@ -2161,20 +3888,29 @@ def compute(l_smem, r_smem, o_smem): out_specs=[block], )(l_ref, r_ref, o_ref) - x = jnp.arange(128 * 128, dtype=jnp.float16).reshape(128, 128) - out = plgpu.kernel(body, out_shape=x, grid=(2,), axis_names=("rows",))(x, x) - np.testing.assert_allclose(out, x + x) + np.testing.assert_allclose(kernel(x, x), x + x) # Transforms def test_stage5(self): + self.skip_if_wg_semantics() # Needs WGMMA to support slices. 
+ row_block, col_block = 64, 32 - def body(l_ref, r_ref, o_ref): - def compute(l_smem, r_smem, o_smem): + x = jnp.arange(128 * 128, dtype=jnp.float16).reshape(128, 128) + + @functools.partial( + self.kernel, out_shape=x, grid=(2,), grid_names=("rows",) + ) + def kernel(l_ref, r_ref, o_ref): + def compute(_, l_smem, r_smem, o_smem): o_smem[...] = l_smem[...] + r_smem[...] r = lax.axis_index("rows") - block = plgpu.GPUBlockSpec( - (row_block, col_block), lambda c: (r, c), - transforms=(plgpu.TilingTransform((64, 32)), plgpu.SwizzleTransform(64)), + block = plgpu.BlockSpec( + (row_block, col_block), + lambda c: (r, c), + transforms=( + plgpu.TilingTransform((8, 32)), + plgpu.SwizzleTransform(64), + ), ) plgpu.emit_pipeline( compute, @@ -2183,40 +3919,140 @@ def compute(l_smem, r_smem, o_smem): out_specs=[block], )(l_ref, r_ref, o_ref) - x = jnp.arange(128 * 128, dtype=jnp.float16).reshape(128, 128) - out = plgpu.kernel(body, out_shape=x, grid=(2,), axis_names=("rows",))(x, x) - np.testing.assert_allclose(out, x + x) + np.testing.assert_allclose(kernel(x, x), x + x) + + +class SemaphoreTest(PallasTest): + + def test_lowering(self): + # This is a smoke test until we add support for lowering of semaphore ops. + def body(i_ref1, i_ref2, o_ref, sem_ref): + del i_ref2 # Only here to have a different number of inputs and outputs. + assert sem_ref.shape == (4,) + assert jnp.issubdtype(sem_ref.dtype, pl.semaphore) + o_ref[...] = i_ref1[...] + x = jnp.arange(128, dtype=jnp.float32).reshape((128,)) + kernel = self.pallas_call( + body, + out_shape=x, + scratch_shapes=[plgpu.SemaphoreType.REGULAR((4,))], + ) + text = jax.jit(kernel).lower(x, x).as_text() + self.assertIn( + r"output_operand_aliases =" + r" [#stablehlo.output_operand_alias]", + text, + ) + self.assertIn( + r"(tensor<128xf32>, tensor<128xf32>, tensor<4xi32>) ->" + r" (tensor<128xf32>, tensor<4xi32>)", + text, + ) + + def test_basic(self): + def body(o_ref, sem_ref): + assert jnp.issubdtype(sem_ref.dtype, pl.semaphore) + pl.semaphore_signal(sem_ref) + o_ref[...] = jnp.ones_like(o_ref) + pl.semaphore_wait(sem_ref) + kernel = plgpu.kernel( + body, + out_shape=jax.ShapeDtypeStruct((128,), jnp.float32), + scratch_shapes=[plgpu.SemaphoreType.REGULAR], + grid=(2,), + grid_names=("x",), + ) + text = jax.jit(kernel).lower().as_text() + np.testing.assert_array_equal(kernel(), jnp.ones((128,), jnp.float32)) + # The semaphore array is scaled up by the grid size. + self.assertIn( + r"(tensor<128xf32>, tensor<2xi32>) -> (tensor<128xf32>, tensor<2xi32>)", + text, + ) + + def test_with_profiler(self): + # Dealing with profiler and semaphores together is tricky because they both + # add extra outputs to the HLO op. + def body(o_ref, sem_ref): + assert jnp.issubdtype(sem_ref.dtype, pl.semaphore) + with jax.named_scope("output"): + o_ref[...] = jnp.ones_like(o_ref) + with tempfile.TemporaryDirectory() as tmp_dir: + kernel = plgpu.kernel( + body, + out_shape=jax.ShapeDtypeStruct((128,), jnp.float32), + scratch_shapes=[plgpu.SemaphoreType.REGULAR], + grid=(2,), + grid_names=("x",), + compiler_params=plgpu.CompilerParams(profile_space=32, profile_dir=tmp_dir), + ) + text = jax.jit(kernel).lower().as_text() + np.testing.assert_array_equal(kernel(), jnp.ones((128,), jnp.float32)) + self.assertIn( + r"(tensor<128xf32>, tensor<2xi32>) ->" + r" (tensor<128xf32>, tensor<2xi32>, tensor<512xui32>)", + text, + ) + + +class ExamplesWGTest( + ExamplesTest, lowering_semantics=plgpu.LoweringSemantics.Warpgroup +): + ... 
class ExamplesSm90ATest(PallasSm90ATest): # WGMMA def test_stage6(self): + self.skip_if_wg_semantics() # Needs WGMMA to support slices. + m_block = n_block = 64 k_block = 32 - def body(l_ref, r_ref, o_ref): - def compute(l_smem, r_smem, o_smem): + x = jnp.arange(128 * 128, dtype=jnp.float16).reshape(128, 128) + + @functools.partial( + self.kernel, out_shape=x, grid=(2, 2), grid_names=("m", "n") + ) + def kernel(l_ref, r_ref, o_ref): + def compute(_, l_smem, r_smem, o_smem): def do_wgmma(acc_ref): plgpu.wgmma(acc_ref, l_smem, r_smem) return acc_ref[...] o_smem[...] += pl.run_scoped(do_wgmma, plgpu.ACC((m_block, n_block), jnp.float16)) - m, n = lax.axis_index("m"), lax.axis_index("n") - lo_transforms = (plgpu.TilingTransform((64, 32)), plgpu.SwizzleTransform(64)) - r_transforms = (plgpu.TilingTransform((32, 32)), plgpu.SwizzleTransform(64)) + m = lax.axis_index("m") + n = lax.axis_index("n") + lo_transforms = (plgpu.TilingTransform((8, 32)), plgpu.SwizzleTransform(64)) + r_transforms = (plgpu.TilingTransform((8, 32)), plgpu.SwizzleTransform(64)) plgpu.emit_pipeline( compute, grid=(l_ref.shape[1] // k_block,), - in_specs=[plgpu.GPUBlockSpec((m_block, k_block), lambda k: (m, k), transforms=lo_transforms), - plgpu.GPUBlockSpec((k_block, n_block), lambda k: (k, n), transforms=r_transforms)], - out_specs=[plgpu.GPUBlockSpec((m_block, n_block), lambda k: (m, n), transforms=lo_transforms)], + in_specs=[ + plgpu.BlockSpec( + (m_block, k_block), lambda k: (m, k), transforms=lo_transforms + ), + plgpu.BlockSpec( + (k_block, n_block), lambda k: (k, n), transforms=r_transforms + ), + ], + out_specs=[ + plgpu.BlockSpec( + (m_block, n_block), lambda k: (m, n), transforms=lo_transforms + ) + ], )(l_ref, r_ref, o_ref) - x = jnp.arange(128 * 128, dtype=jnp.float16).reshape(128, 128) - out = plgpu.kernel(body, out_shape=x, grid=(2, 2), axis_names=("m", "n"))(x, x) - np.testing.assert_allclose(out, x @ x) + np.testing.assert_allclose(kernel(x, x), x @ x) # TODO(apaszke): Clusters and multicast +class ExamplesSm90AWGTest( + ExamplesSm90ATest, lowering_semantics=plgpu.LoweringSemantics.Warpgroup +): + ... 
+ + if __name__ == "__main__": absltest.main() diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py index 0fc375bf64a1..162152dc2a3f 100644 --- a/tests/pallas/ops_test.py +++ b/tests/pallas/ops_test.py @@ -17,7 +17,8 @@ import itertools import math import sys -from typing import Any, Callable +from typing import Any +from collections.abc import Callable import unittest from absl.testing import absltest @@ -30,6 +31,7 @@ from jax._src import linear_util as lu from jax._src import state from jax._src import test_util as jtu +from jax._src.pallas import pallas_call from jax.experimental import pallas as pl from jax.interpreters import partial_eval as pe import jax.numpy as jnp @@ -47,21 +49,18 @@ plgpu_triton = None pltpu = None -try: - import hypothesis as hp -except (ModuleNotFoundError, ImportError): - raise unittest.SkipTest("tests depend on hypothesis library") - +import hypothesis as hp import hypothesis.extra.numpy as hnp import hypothesis.strategies as hps + # There are many inherited redefinitions of _ # ruff: noqa: F811 jax.config.parse_flags_with_absl() jtu.setup_hypothesis(max_examples=50) -use_mosaic_gpu = jax.config.read("jax_pallas_use_mosaic_gpu") +use_mosaic_gpu = pallas_call._PALLAS_USE_MOSAIC_GPU.value intx = dtypes.canonicalize_dtype(jnp.int64) floatx = dtypes.canonicalize_dtype(jnp.float64) @@ -187,7 +186,7 @@ def select_n_strategy( else: pred_dtype = np.int32 pred = draw(arrays(shape=pred_shape, dtype=pred_dtype, - elements=allowed_elements)) + elements=allowed_elements)) cases = ( draw( arrays(shape=case_shape_dtype.shape, dtype=case_shape_dtype.dtype) @@ -203,7 +202,7 @@ def select_n_strategy( # TODO(sharadmv,apaszke): enable zero dim sizes # TODO(sharadmv,apaszke): enable one dim sizes ( - lax.neg_p, + lax.neg_p, {}, make_shape_dtype_strategy( min_rank=2, max_rank=3, @@ -213,7 +212,7 @@ def select_n_strategy( ), ), ( - lax.not_p, + lax.not_p, {}, make_shape_dtype_strategy( min_rank=2, max_rank=3, @@ -225,6 +224,7 @@ def select_n_strategy( *[ ( prim, + params, make_shape_dtype_strategy( min_rank=2, max_rank=3, @@ -233,23 +233,23 @@ def select_n_strategy( valid_dtypes=[jnp.dtype("float32")], ), ) - for prim in [ - lax.exp_p, - lax.tanh_p, - lax.logistic_p, - lax.rsqrt_p, - lax.log_p, - lax.exp2_p, - lax.abs_p, - lax.log1p_p, - lax.sin_p, - lax.sqrt_p, + for prim, params in [ + (lax.abs_p, {}), + (lax.exp_p, {"accuracy": None}), + (lax.tanh_p, {"accuracy": None}), + (lax.logistic_p, {"accuracy": None}), + (lax.rsqrt_p, {"accuracy": None}), + (lax.log_p, {"accuracy": None}), + (lax.exp2_p, {"accuracy": None}), + (lax.log1p_p, {"accuracy": None}), + (lax.sin_p, {"accuracy": None}), + (lax.sqrt_p, {"accuracy": None}), ] ], ] UNARY_FUNCTIONS = [ - (prim.name, prim.bind, strategy) for prim, strategy in UNARY_PRIMITIVES + (prim.name, functools.partial(prim.bind, **params), strategy) for prim, params, strategy in UNARY_PRIMITIVES ] + [ ( name, @@ -293,18 +293,19 @@ def setUp(self): def pallas_call(cls, *args, **kwargs): if jtu.test_device_matches(["cuda"]) and use_mosaic_gpu: assert plgpu_mgpu is not None - compiler_params = plgpu_mgpu.GPUCompilerParams( - thread_semantics=plgpu_mgpu.ThreadSemantics.Warpgroup + compiler_params = plgpu_mgpu.CompilerParams( + lowering_semantics=plgpu_mgpu.LoweringSemantics.Warpgroup ) kwargs["compiler_params"] = compiler_params return pl.pallas_call(*args, interpret=cls.INTERPRET, **kwargs) def skip_if_mosaic_gpu(self): - if jtu.test_device_matches(["cuda"]) and use_mosaic_gpu: + if jtu.test_device_matches(["gpu"]) and use_mosaic_gpu: 
self.skipTest("TODO: Mosaic GPU does not support this yet") +@jtu.thread_unsafe_test_class() # hypothesis is not thread safe class OpsTest(PallasBaseTest): @parameterized.named_parameters( @@ -329,7 +330,7 @@ def kernel(x_ref, y_ref, o_ref): x = jnp.full((8, 128), 4, dtype=dtype) y = jnp.full((8, 128), 2 if jnp.issubdtype(dtype, jnp.integer) else 2.0, - dtype=dtype) + dtype=dtype) np.testing.assert_allclose(kernel(x, y), fn(x, y)) @parameterized.named_parameters( @@ -560,7 +561,8 @@ def kernel(*refs): ) @hp.given(hps.data()) def test_unary_primitives(self, name, func, shape_dtype_strategy, data): - self.skip_if_mosaic_gpu() + if name in ["abs", "log1p", "pow2", "reciprocal", "relu", "sin", "sqrt"]: + self.skip_if_mosaic_gpu() if self.INTERPRET: self.skipTest("This hypothesis test is slow, even more so in interpret mode.") @@ -577,6 +579,12 @@ def test_unary_primitives(self, name, func, shape_dtype_strategy, data): def kernel(x_ref, y_ref): y_ref[...] = func(x_ref[...]) x_shape_dtype = data.draw(shape_dtype_strategy) + + sut_is_mosaic_gpu = jtu.test_device_matches(["gpu"]) and use_mosaic_gpu + if sut_is_mosaic_gpu: + hp.assume(math.prod(x_shape_dtype.shape) % 128 == 0) + hp.assume(x_shape_dtype.shape[-1] >= 16) + key = random.key(0) x = _random_value(key, x_shape_dtype) out = self.pallas_call(kernel, out_shape=x_shape_dtype)(x) @@ -587,10 +595,16 @@ def kernel(x_ref, y_ref): def test_cast_from_32bit(self, from_dtype, to_dtype, data): sut_is_mosaic_gpu = jtu.test_device_matches(["gpu"]) and use_mosaic_gpu if to_dtype in {"float8_e4m3b11fnuz", "float8_e5m2", "float8_e4m3fn"}: - if not jtu.test_device_matches(["tpu"]) or jtu.get_tpu_version() < 5: + if not jtu.test_device_matches(["tpu"]): self.skipTest("Not supported on this hardware") - if not jtu.if_cloud_tpu_at_least(2025, 3, 8): + if jtu.get_tpu_version() >= 5 and not jtu.if_cloud_tpu_at_least( + 2025, 3, 8 + ): self.skipTest("Test requires libtpu from 2025/3/8 or later") + if jtu.get_tpu_version() < 5 and not jtu.if_cloud_tpu_at_least( + 2025, 5, 15 + ): + self.skipTest("Test requires libtpu from 2025/5/15 or later") if from_dtype in {"int2", "uint2"} or to_dtype in {"int2", "uint2"}: if jtu.test_device_matches(["tpu"]) and not jtu.if_cloud_tpu_at_least( 2025, 4, 1 @@ -599,10 +613,17 @@ def test_cast_from_32bit(self, from_dtype, to_dtype, data): if from_dtype == to_dtype: self.skipTest("Unnecessary test") if jtu.is_device_tpu(version=4): - if to_dtype in {"int8", "uint8", "int4", "uint4", "int2", "uint2"}: + if to_dtype in {"int2", "uint2"}: self.skipTest("Not supported on this TPU generation") if to_dtype in {"int16", "uint16"} and not jtu.if_cloud_tpu_at_least(2025, 1, 18): self.skipTest("Test requires libtpu from 2025/1/18 or later") + if to_dtype in { + "int4", + "uint4", + "int8", + "uint8", + } and not jtu.if_cloud_tpu_at_least(2025, 5, 15): + self.skipTest("Test requires libtpu from 2025/5/15 or later") if jtu.test_device_matches(["tpu"]) and jtu.get_tpu_version() < 4: # Currently only casts between 32-bit types and to bf16 are supported. 
if to_dtype not in {"int32", "uint32", "float32", "bfloat16"}: @@ -666,18 +687,7 @@ def test_cast_from_sub_32bit(self, from_dtype, to_dtype, randomize): if jtu.is_device_tpu(version=4): allowed_v4_cats = {("int16", "int32"): (2025, 1, 18)} if ( - from_dtype - in { - "int16", - "int8", - "uint16", - "uint8", - "int4", - "uint4", - "int2", - "uint2", - } - or to_dtype in {"int8", "uint8", "int4", "uint4", "int2", "uint2"} + from_dtype in {"int2", "uint2"} or to_dtype in {"int2", "uint2"} ) and (from_dtype, to_dtype) not in allowed_v4_cats: self.skipTest("Not supported on this TPU generation") if minimum_libtpu_date := allowed_v4_cats.get((from_dtype, to_dtype), None): @@ -685,6 +695,12 @@ def test_cast_from_sub_32bit(self, from_dtype, to_dtype, randomize): self.skipTest("Test requires a newer libtpu") if to_dtype in {"int16", "uint16"} and not jtu.if_cloud_tpu_at_least(2025, 1, 18): self.skipTest("Test requires libtpu from 2025/1/18 or later") + if ( + to_dtype in {"int4", "uint4", "int8", "uint8"} + and from_dtype in {"int4", "uint4", "int8", "uint8"} + and not jtu.if_cloud_tpu_at_least(2025, 5, 15) + ): + self.skipTest("Test requires libtpu from 2025/5/15 or later") if jtu.test_device_matches(["tpu"]) and jtu.get_tpu_version() < 4: self.skipTest("Not supported on this TPU generation") if jtu.test_device_matches(["gpu"]) and ( @@ -712,10 +728,16 @@ def test_cast_from_sub_32bit(self, from_dtype, to_dtype, randomize): "float8_e5m2", "float8_e4m3fn", } or to_dtype in {"float8_e4m3b11fnuz", "float8_e5m2", "float8_e4m3fn"}: - if not jtu.test_device_matches(["tpu"]) or jtu.get_tpu_version() < 5: + if not jtu.test_device_matches(["tpu"]): self.skipTest("Not supported on this hardware") - if not jtu.if_cloud_tpu_at_least(2025, 3, 9): + if jtu.get_tpu_version() >= 5 and not jtu.if_cloud_tpu_at_least( + 2025, 3, 9 + ): self.skipTest("Test requires libtpu from 2025/3/9 or later") + if jtu.get_tpu_version() < 5 and not jtu.if_cloud_tpu_at_least( + 2025, 5, 15 + ): + self.skipTest("Test requires libtpu from 2025/5/15 or later") if from_dtype == "int2" and to_dtype == "bool": self.skipTest( "TODO(b/343490729): XLA compare(s2, s2) yields wrong results" @@ -1061,8 +1083,8 @@ def kernel(x_ref, o_ref): ( # fmt: off [jnp.expm1, jnp.log1p, jnp.cbrt, lax.rsqrt, jnp.tan, jnp.asin, - jnp.acos, jnp.atan, jnp.sinh, jnp.cosh, jnp.tanh, jnp.asinh, - jnp.acosh, jnp.atanh], + jnp.acos, jnp.atan, jnp.sinh, jnp.cosh, jnp.tanh, jnp.asinh, + jnp.acosh, jnp.atanh], # fmt: on ["bfloat16", "float32", "float64"], ), @@ -1086,7 +1108,7 @@ def test_elementwise(self, fn, dtype): self.skipTest("int16 and float16 are not supported on TPU") if ( fn in (jnp.ceil, jnp.floor, jnp.negative, jnp.exp, jnp.exp2, jnp.log, - jnp.sqrt, lax.rsqrt) + jnp.sqrt, lax.rsqrt) and dtype == "bfloat16" and not jtu.is_device_tpu_at_least(6) ): @@ -1285,7 +1307,6 @@ def kernel(x_ref, y_ref, o_ref): ) ) def test_comparison(self, fn, dtype): - self.skip_if_mosaic_gpu() if jtu.test_device_matches(["gpu"]) and dtype == jnp.bool_: self.skipTest("Not implemented on GPU.") @@ -1295,16 +1316,16 @@ def test_comparison(self, fn, dtype): @functools.partial( self.pallas_call, - out_shape=jax.ShapeDtypeStruct((8,), jnp.bool_), + out_shape=jax.ShapeDtypeStruct((128,), jnp.int32), ) def kernel(x_ref, y_ref, o_ref): - o_ref[:] = fn(x_ref[...], y_ref[...]) + o_ref[:] = fn(x_ref[...], y_ref[...]).astype(jnp.int32) - x = jnp.array([0, 3, -4, -6, 0, 5, 4, -7]).astype(dtype) - y = jnp.array([3, 1, -4, -5, 0, -2, 2, 4]).astype(dtype) + x = jnp.tile(jnp.array([0, 3, -4, -6, 0, 
5, 4, -7]).astype(dtype), 16) + y = jnp.tile(jnp.array([3, 1, -4, -5, 0, -2, 2, 4]).astype(dtype), 16) out = kernel(x, y) expected = fn(x, y) - self.assertArraysEqual(out, expected) + self.assertArraysEqual(out != 0, expected) @parameterized.named_parameters( (f"{fn.__name__}_{dtype.__name__}", fn, dtype) @@ -1314,8 +1335,6 @@ def kernel(x_ref, y_ref, o_ref): ) ) def test_comparison_scalar(self, fn, dtype): - self.skip_if_mosaic_gpu() - if jtu.test_device_matches(["tpu"]) and dtype == jnp.float16: self.skipTest("float16 is not supported on TPU") @@ -1325,6 +1344,9 @@ def test_comparison_scalar(self, fn, dtype): ): self.skipTest("Only works on GPUs with capability >= sm80") + if jtu.test_device_matches(["gpu"]) and dtype == jnp.bool_: + self.skip_if_mosaic_gpu() + @functools.partial( self.pallas_call, in_specs=( @@ -1332,17 +1354,17 @@ def test_comparison_scalar(self, fn, dtype): pl.BlockSpec(memory_space=smem_on_tpu()), ), out_specs=pl.BlockSpec(memory_space=smem_on_tpu()), - out_shape=jax.ShapeDtypeStruct((8,), jnp.bool_), + out_shape=jax.ShapeDtypeStruct((128,), jnp.int32), ) def kernel(x_ref, y_ref, o_ref): - for i in range(8): - o_ref[i] = fn(x_ref[i], y_ref[i]) + for i in range(128): + o_ref[i] = fn(x_ref[i], y_ref[i]).astype(jnp.int32) - x = jnp.array([0, 3, -4, -6, 0, 5, 4, -7]).astype(dtype) - y = jnp.array([3, 1, -4, -5, 0, -2, 2, 4]).astype(dtype) + x = jnp.tile(jnp.array([0, 3, -4, -6, 0, 5, 4, -7]).astype(dtype), 16) + y = jnp.tile(jnp.array([3, 1, -4, -5, 0, -2, 2, 4]).astype(dtype), 16) out = kernel(x, y) expected = fn(x, y) - self.assertArraysEqual(out, expected) + self.assertArraysEqual(out != 0, expected) def test_isnan(self): self.skip_if_mosaic_gpu() @@ -1464,7 +1486,7 @@ def kernel(x_ref, y_ref, o_ref): ( # fmt: off [jnp.bitwise_and, jnp.bitwise_or, jnp.bitwise_xor, - jnp.bitwise_left_shift, jnp.bitwise_right_shift], + jnp.bitwise_left_shift, jnp.bitwise_right_shift], # fmt: on ["int32", "uint32"], ), @@ -1510,10 +1532,10 @@ def test_binary_scalar(self, f, dtype): @functools.partial( self.pallas_call, - in_specs=[pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM), - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM), + in_specs=[pl.BlockSpec(memory_space=pltpu.SMEM), + pl.BlockSpec(memory_space=pltpu.SMEM), ], - out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM), + out_specs=pl.BlockSpec(memory_space=pltpu.SMEM), out_shape=jax.ShapeDtypeStruct((1,), dtype), ) def kernel(x_ref, y_ref, o_ref): @@ -1525,14 +1547,15 @@ def kernel(x_ref, y_ref, o_ref): np.testing.assert_allclose(f(x, y), kernel(x, y)) @parameterized.parameters( + ((32,), jnp.int32, 0), ((8, 4), jnp.int32, 0), ((8, 16), jnp.float32, 1), ((8, 16, 2), jnp.int8, 1), ) - def test_broadcasted_iota(self, shape, dtype, dimension): + def test_iota(self, shape, dtype, dimension): self.skip_if_mosaic_gpu() - if jtu.test_device_matches(["tpu"]): + if jtu.test_device_matches(["tpu"]) and dtype != jnp.int32: self.skipTest("Only 32-bit integer iota supported") f = lambda: jax.lax.broadcasted_iota(dtype, shape, dimension) @@ -1624,7 +1647,7 @@ def kernel(x_ref, o_ref): @unittest.skipIf( sys.platform == "win32", - "plgpu_triton.TritonCompilerParams unavailable on Windows", + "plgpu_triton.CompilerParams unavailable on Windows", ) def test_debug_print(self): self.skip_if_mosaic_gpu() @@ -1639,7 +1662,7 @@ def test_debug_print(self): @functools.partial( self.pallas_call, out_shape=jax.ShapeDtypeStruct((2,), jnp.float32), - compiler_params=plgpu_triton.TritonCompilerParams( + 
compiler_params=plgpu_triton.CompilerParams( num_warps=1, num_stages=1 ), ) @@ -1655,7 +1678,7 @@ def kernel(x_ref, o_ref): @unittest.skipIf( sys.platform == "win32", - "plgpu_triton.TritonCompilerParams unavailable on Windows", + "plgpu_triton.CompilerParams unavailable on Windows", ) def test_debug_print_with_values(self): if jtu.test_device_matches(["tpu"]): @@ -1668,7 +1691,7 @@ def test_debug_print_with_values(self): @functools.partial( self.pallas_call, out_shape=jax.ShapeDtypeStruct((2,), jnp.float32), - compiler_params=plgpu_triton.TritonCompilerParams( + compiler_params=plgpu_triton.CompilerParams( num_warps=1, num_stages=1 ), ) @@ -1739,6 +1762,27 @@ def f(x_ref, o_ref): expected = x.reshape(out_shape) np.testing.assert_allclose(f(x), expected) + def test_reshape_to_scalar(self): + self.skip_if_mosaic_gpu() + # Test reshapes from (1, 1) to (). + # Because TPUs distinguish between VREGs/SREGs this tests an implicit + # copy from VREG -> SREG that must be inserted by Pallas. + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct((8, 128), jnp.int32), + ) + def f(x_ref, o_ref): + o_ref[...] = jnp.zeros_like(o_ref) + vector_val = x_ref[1:2, 0:1] + scalar_val = jnp.reshape(vector_val, ()) + o_ref[scalar_val] = jnp.ones_like(o_ref[0]) * scalar_val + + in_shape = (4, 4) + x = jnp.arange(int(np.prod(in_shape)), dtype=jnp.int32).reshape(in_shape) + expected = jnp.zeros((8, 128), jnp.int32) + expected = expected.at[x[1, 0]].set(x[1, 0]) + np.testing.assert_allclose(f(x), expected) + def test_num_programs(self): self.skip_if_mosaic_gpu() @@ -1779,30 +1823,49 @@ def copyitem(x_ref, in_idx_ref, out_idx_ref, o_ref): np.testing.assert_allclose(out[oi], x[ii]) np.testing.assert_allclose(out[oi + 1 :], jnp.zeros_like(out[oi + 1 :])) - @parameterized.parameters( - ((), (2,), ()), - ((1,), (2,), (0,)), - ((1, 1), (2, 2), (0, 1)), - ((), (2, 2), ()), + @parameterized.product( + shape_spec=[ + ((), (2,), ()), + ((1,), (2,), (0,)), + ((1, 128), (8, 128), (0, 1)), # row broadcasting + ((), (2, 2), ()), + ], + dtype=[jnp.int32, jnp.int16, jnp.int8, jnp.bool_], ) - def test_broadcast_in_dim(self, in_shape, out_shape, dims): + def test_broadcast_in_dim(self, shape_spec, dtype): self.skip_if_mosaic_gpu() - # The Pallas TPU lowering currently supports only blocks of rank >= 1 + in_shape, out_shape, dims = shape_spec if jtu.test_device_matches(["tpu"]): - self.skipTest("Not supported on TPU") + if not in_shape: + self.skipTest( + "The Pallas TPU lowering currently supports only blocks of rank" + " >= 1" + ) + if dtype is jnp.bool_ and not jtu.if_cloud_tpu_at_least(2025, 6, 5): + self.skipTest("Requires libtpu built after 2025-06-05") + if ( + len(in_shape) == 1 + and len(out_shape) == 1 + and dtype not in {jnp.int32, jnp.bool_} + ): + self.skipTest("Unsupported tiling") @functools.partial( self.pallas_call, - out_shape=jax.ShapeDtypeStruct(out_shape, jnp.float32), + out_shape=jax.ShapeDtypeStruct(out_shape, dtype), ) def f(x_ref, o_ref): x = x_ref[...] o_ref[...] 
= jax.lax.broadcast_in_dim(x, out_shape, dims) - x = jnp.arange(int(np.prod(in_shape)), dtype=jnp.float32).reshape(in_shape) + x = ( + jnp.arange(math.prod(in_shape), dtype=jnp.int32) + .reshape(in_shape) + .astype(dtype) + ) expected = jax.lax.broadcast_in_dim(x, out_shape, dims) - np.testing.assert_allclose(f(x), expected) + np.testing.assert_array_equal(f(x), expected) @parameterized.product( lhs_and_rhs_shape=[ @@ -1865,6 +1928,15 @@ def test_dot(self, lhs_and_rhs_shape, dtype, trans_x, trans_y): > (256 * 256) * 2 ): self.skipTest("Shared memory size limit exceeded") + if (jax.local_devices()[0].device_kind == "NVIDIA L4" and + dtype == jnp.float32 and + lhs_and_rhs_shape in [ + ((128, 16), (128, 256)), + ((16, 128), (128, 256)), + ((16, 256), (256, 128)), + ((256, 16), (256, 128)), + ]): + self.skipTest("Shared memory size limit exceeded") if min(*lhs_shape, *rhs_shape) < 16: self.skipTest("All dimensions of lhs and rhs must be >= 16") if any(not is_power_of_two(x) for x in lhs_shape + rhs_shape): @@ -1886,7 +1958,7 @@ def dot(x_ref, y_ref, o_ref): # Pallas always accumulates in FP32, so we are explicit about # preferred_element_type here. expected = jnp.dot(x.T if trans_x else x, y.T if trans_y else y, - preferred_element_type=jnp.float32).astype(dtype) + preferred_element_type=jnp.float32).astype(dtype) np.testing.assert_allclose( out.astype(jnp.float32), expected.astype(jnp.float32), @@ -1936,7 +2008,7 @@ def test_masked_oob_load_store_slice(self): def masked_oob_load_store_slice(x_ref, mask_ref, start_idx_ref, o_ref): x = pl.load(x_ref, (pl.dslice(start_idx_ref[()], n)), mask=mask_ref[:], other=-1.) - pl.store(o_ref, (pl.dslice(None),), x) + o_ref[...] = x x = random.normal(random.key(0), (n,)) slice_start = random.randint(random.key(2), (), 1, n) @@ -2075,7 +2147,7 @@ def test_masked_oob_swap_slice(self): @functools.partial( self.pallas_call, out_shape=(jax.ShapeDtypeStruct((n,), floatx), - jax.ShapeDtypeStruct((m,), floatx)), + jax.ShapeDtypeStruct((m,), floatx)), input_output_aliases={0: 0, 1: 1}, ) def masked_oob_swap_slice(_, _2, mask_ref, start_idx_ref, x_ref, y_ref): @@ -2205,7 +2277,7 @@ def swap(_, lock_ref, out_ref): lock, out = swap(init_value) np.testing.assert_allclose(lock, new_value if cmp == init_value else - init_value) + init_value) np.testing.assert_allclose(out, init_value) @parameterized.parameters(1, 2, 3, 4, 8) @@ -2517,6 +2589,52 @@ def kernel(x_ref, out_ref): )(x) np.testing.assert_array_equal(out, np.diagonal(x)) + @parameterized.product( + # Skip some steps to just run less cases + # TODO(mvoz): Hypothesis? + x_dim_size=tuple(8 * i for i in range(1, 5)), + y_dim_size=tuple(8 * i for i in range(1, 5)), + z_dim_size=tuple(128 * i for i in range(1, 3)), + dtype=(jnp.float32,), + ) + def test_jnp_swapaxes_major_minor( + self, x_dim_size, y_dim_size, z_dim_size, dtype + ): + if jtu.test_device_matches(["gpu"]): + if any( + not is_power_of_two(x) for x in [x_dim_size, y_dim_size, z_dim_size] + ): + self.skipTest( + "the Pallas Triton lowering currently requires that all operations" + " have array arguments and results whose size is a power of 2." 
+ f" Encountered an array of shape ({x_dim_size}, {y_dim_size}," + f" {z_dim_size})" + ) + if x_dim_size * y_dim_size * z_dim_size * 4 > 32768: + self.skipTest( + "Mosaic GPU kernel exceeds available shared memory" + f" smem_bytes={x_dim_size * y_dim_size * z_dim_size * 4} > 32768" + ) + self.skip_if_mosaic_gpu() + if not jtu.if_cloud_tpu_at_least(2025, 5, 22): + self.skipTest("Requires libtpu built after 2025-5-22") + + x = jnp.arange(x_dim_size * y_dim_size * z_dim_size, dtype=dtype).reshape( + (x_dim_size, y_dim_size, z_dim_size) + ) + + def kernel(x_ref, out_ref): + out_ref[...] = jnp.swapaxes(x_ref[...], 0, 1) + + out = self.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct( + (y_dim_size, x_dim_size, z_dim_size), dtype + ), + )(x) + expected = jnp.swapaxes(x, 0, 1) + np.testing.assert_array_equal(out, expected) + class OpsInterpretTest(OpsTest): INTERPRET = True @@ -2571,15 +2689,15 @@ def body(x_ref): @parameterized.parameters(*[ (lambda: (pl.dslice(0, 4), slice(None), slice(None)), - "c:i32[4,3,2], a[:,:,:] <-"), + "c:i32[4,3,2], a[:,:,:] <-"), (lambda: (pl.dslice(0, 3), slice(None), slice(None)), - "c:i32[3,3,2], a[:3,:,:] <-"), + "c:i32[3,3,2], a[:3,:,:] <-"), (lambda: (pl.dslice(1, 3), slice(None), pl.dslice(0, 4)), - "c:i32[3,3,4], a[1:,:,:4] <-"), + "c:i32[3,3,4], a[1:,:,:4] <-"), (lambda: (jnp.arange(5), slice(None), pl.dslice(0, 4)), - "e:i32[5,3,4], a[b,:,:4] <-"), + "e:i32[5,3,4], a[b,:,:4] <-"), (lambda: (jnp.arange(5)[:, None], jnp.arange(3)[None], pl.dslice(4)), - "o:i32[5,3,4], a[m,n,:4] <-"), + "o:i32[5,3,4], a[m,n,:4] <-"), ]) def test_swap_pretty_print(self, expr, expected): def body(x_ref): diff --git a/tests/pallas/pallas_error_handling_test.py b/tests/pallas/pallas_error_handling_test.py index cd5ceecfc9a8..cc0f3f8ba7aa 100644 --- a/tests/pallas/pallas_error_handling_test.py +++ b/tests/pallas/pallas_error_handling_test.py @@ -16,13 +16,16 @@ import traceback from absl.testing import absltest +from absl.testing import parameterized import jax from jax import numpy as jnp from jax._src import config from jax._src import test_util as jtu +from jax._src.lib import xla_client from jax._src.pallas.mosaic import error_handling from jax.experimental import pallas as pl from jax.experimental.pallas import tpu as pltpu +import numpy as np config.parse_flags_with_absl() @@ -50,9 +53,9 @@ def test_non_singular_stride(self): grid_spec = pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=0, in_specs=[ - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM), + pl.BlockSpec(memory_space=pltpu.VMEM), ], - out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM), + out_specs=pl.BlockSpec(memory_space=pltpu.VMEM), ) @functools.partial(pl.pallas_call, out_shape=out_shape, grid_spec=grid_spec) @@ -92,20 +95,21 @@ def kernel_in_jitted_fn(x): tb_string = "".join(tb_string) self.assertEndsWith(tb_string, "x = input_ref[:, ::8]\n") - def test_invalid_smem_vmem_verification_error(self): + def test_index_with_f32_verification_error(self): input_arr = jax.random.uniform(jax.random.key(0), (2, 2), dtype=jnp.float32) out_shape = jax.ShapeDtypeStruct((1, 1), jnp.float32) grid_spec = pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=0, in_specs=[ - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM), + pl.BlockSpec(memory_space=pltpu.VMEM), ], - out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM), + out_specs=pl.BlockSpec(memory_space=pltpu.SMEM), ) @functools.partial(pl.pallas_call, out_shape=out_shape, grid_spec=grid_spec) def test_kernel(input_ref, output_ref): - 
output_ref[0, 0] = input_ref[0, 0] + idx = input_ref[0, 0] + output_ref[idx, 0] = input_ref[0, 0] # Test that a verification error is raised. This assert is a guard against # underlying changes in Pallas lowering. @@ -113,8 +117,8 @@ def test_kernel(input_ref, output_ref): # the test example to force a different error. with self.assertRaisesRegex( error_handling.VerificationError, - "'memref.store' op failed to verify that type of 'value' matches " - "element type of 'memref'", + "must be signless-integer-like or memref of signless-integer, " + "but got 'f32'" ): test_kernel(input_arr) @@ -125,7 +129,37 @@ def test_kernel(input_ref, output_ref): except error_handling.MosaicError as e: tb_string = traceback.format_tb(e.__traceback__) tb_string = "".join(tb_string) - self.assertEndsWith(tb_string, "output_ref[0, 0] = input_ref[0, 0]\n") + self.assertEndsWith(tb_string, "output_ref[idx, 0] = input_ref[0, 0]\n") + + @parameterized.parameters( + ((2048,), (256,)), + ((2048,), (512,)), + ) + def test_small_1d_block_spec_raises(self, total_shape, block_shape): + # https://github.com/jax-ml/jax/issues/25379 + dtype = jnp.float32 + + def kernel(x_ref, y_ref): + y_ref[...] = x_ref[...] * 2 + + x = jnp.arange(np.prod(total_shape), dtype=dtype).reshape(total_shape) + x_spec = pl.BlockSpec(block_shape, lambda *args: args) + fn = pl.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct(total_shape, dtype), + in_specs=[x_spec], + out_specs=x_spec, + grid=tuple(tot // blk for tot, blk in zip(total_shape, block_shape, + strict=True)), + ) + # Having a block size that is too small should raise a suggestion + # to increase the block size. + with self.assertRaisesRegex( + xla_client.XlaRuntimeError, + r"Try changing your kernel block shape to \([0-9,\s]+\) to align with" + " the XLA layout", + ): + fn(x) def test_parse_location_string(self): name, frames = error_handling.parse_location_string(LOCATION_TEST_STRING) diff --git a/tests/pallas/pallas_jumble_test.py b/tests/pallas/pallas_jumble_test.py index 509ef08a987f..0a2994a84a8f 100644 --- a/tests/pallas/pallas_jumble_test.py +++ b/tests/pallas/pallas_jumble_test.py @@ -354,7 +354,7 @@ def invoke_kernel(x): with self.assertRaisesRegex( ValueError, - "Ragged input shape must be evenly divisble by the grid" # noqa: W605 + "Ragged input shape must be evenly divisible by the grid" # noqa: W605 " size at the ragged dimension 2", ): jax.vmap( diff --git a/tests/pallas/pallas_test.py b/tests/pallas/pallas_test.py index 745c30ba98cb..cb61d5648912 100644 --- a/tests/pallas/pallas_test.py +++ b/tests/pallas/pallas_test.py @@ -30,11 +30,9 @@ from jax import random from jax._src import checkify from jax._src import config -from jax._src import core as jax_core from jax._src import dtypes from jax._src import test_util as jtu from jax._src.lax.control_flow.for_loop import for_loop -from jax._src.pallas import pallas_call from jax._src.pallas.pallas_call import _trace_kernel_to_jaxpr from jax.experimental import pallas as pl import jax.numpy as jnp @@ -128,8 +126,8 @@ def matmul_block_spec(x, y, *, bm, bn, bk, interpret, debug=False): def matmul_kernel(x_ref, y_ref, o_ref): acc = jnp.zeros(o_ref.shape, dtype=jnp.float32) def body(i, acc_ref): - x_block = pl.load(x_ref, (slice(None), pl.ds(i * bk, bk))) - y_block = pl.load(y_ref, (pl.ds(i * bk, bk), slice(None))) + x_block = x_ref[:, pl.ds(i * bk, bk)] + y_block = y_ref[pl.ds(i * bk, bk), :] acc_ref[:, :] += pl.dot(x_block, y_block) acc = for_loop(k // bk, body, acc).astype(o_ref.dtype) o_ref[:, :] = acc @@ -624,8 +622,9 
@@ def test_unused_ref(self): out_shape=jax.ShapeDtypeStruct((m, n), jnp.float32), ) def dummy(_, o_ref): - pl.store(o_ref, (jnp.arange(m)[:, None], jnp.arange(n)[None, :]), - jnp.ones_like(o_ref)) + o_ref[jnp.arange(m)[:, None], jnp.arange(n)[None, :]] = jnp.ones_like( + o_ref + ) key = random.key(0) x = random.normal(key, (m, n)) @@ -667,8 +666,7 @@ def test_using_pallas_slice(self): out_shape=out_shape, ) def slice_kernel(x_ref, y_ref): - x = pl.load(x_ref, (pl.dslice(0, 4), pl.dslice(0, 4))) - pl.store(y_ref, (pl.dslice(4), pl.dslice(4)), x) + y_ref[:4, :4] = x_ref[:4, :4] x = random.normal(random.key(0), (m, n)) y = slice_kernel(x) y_ref = x[:4] @@ -694,6 +692,22 @@ def f(x): self.assertEqual(f(x), 2.) self.assertEqual(trace_count, 1) + def test_pallas_call_under_disable_jit(self): + @functools.partial( + self.pallas_call, out_shape=jax.ShapeDtypeStruct((8,), jnp.float32), + ) + def add_one(x_ref, o_ref): + o_ref[...] = x_ref[...] + 1. + + x = jnp.arange(8, dtype=jnp.float32) + + result = add_one(x) + np.testing.assert_array_equal(result, x + 1.) + + with jax.disable_jit(): + result = add_one(x) + np.testing.assert_array_equal(result, x + 1.) + @parameterized.parameters( ("float32", None), ("float32", jax.lax.Precision.DEFAULT), @@ -702,6 +716,9 @@ def f(x): ("float32", jax.lax.DotAlgorithmPreset.DEFAULT), ("float32", jax.lax.DotAlgorithmPreset.F16_F16_F32), ("float32", jax.lax.DotAlgorithmPreset.BF16_BF16_F32), + ("float32", jax.lax.DotAlgorithmPreset.BF16_BF16_F32_X3), + ("float32", jax.lax.DotAlgorithmPreset.BF16_BF16_F32_X6), + ("float32", jax.lax.DotAlgorithmPreset.BF16_BF16_F32_X9), ("float32", jax.lax.DotAlgorithmPreset.TF32_TF32_F32), ("float32", jax.lax.DotAlgorithmPreset.TF32_TF32_F32_X3), ("float32", jax.lax.DotAlgorithmPreset.F32_F32_F32), @@ -731,7 +748,21 @@ def dot_kernel(x_ref, y_ref, o_ref): precision=jax.lax.Precision.HIGHEST, preferred_element_type=jnp.float32, ) - self.assertAllClose(dot_kernel(x, y), expected, atol=5e-2, rtol=5e-3) + if dtype == "bfloat16" or precision in ( + jax.lax.Precision.HIGHEST, + jax.lax.DotAlgorithmPreset.F32_F32_F32, + jax.lax.DotAlgorithmPreset.BF16_BF16_F32_X6, + jax.lax.DotAlgorithmPreset.BF16_BF16_F32_X9, + ): + atol = 5e-6 + elif precision in ( + jax.lax.DotAlgorithmPreset.BF16_BF16_F32_X3, + jax.lax.DotAlgorithmPreset.TF32_TF32_F32_X3, + ): + atol = 5e-4 + else: + atol = 5e-2 + self.assertAllClose(dot_kernel(x, y), expected, atol=atol, rtol=atol / 10) @parameterized.parameters(jnp.int8, jnp.uint8) def test_integer_dot(self, dtype): @@ -826,18 +857,37 @@ def dot_kernel(x_ref, y_ref, o_ref): self.assertAllClose(dot_kernel(x, y), expected) + @parameterized.parameters( + ((32,), 2, 0), ((32, 64), 4, 0), ((32, 16), 8, 1), ((32, 16, 2), 16, 1) + ) + def test_split(self, shape, num_parts, axis): + if jtu.test_device_matches(["tpu"]) and shape[axis] == num_parts: + self.skipTest("TPU doesn't support fully split axis.") + + x = jax.random.normal(jax.random.key(0), shape) + expected = jnp.split(x, num_parts, axis) + + @functools.partial(self.pallas_call, out_shape=expected) + def kernel(x_ref, *o_ref): + x_parts = jnp.split(x_ref[()], num_parts, axis) + for o_ref, x_part in zip(o_ref, x_parts): + o_ref[...] 
= x_part + + self.assertAllClose(kernel(x), expected) + + class PallasCallInterpretTest(PallasCallTest): INTERPRET = True -class PallasCallUnblockedIndexingTest(PallasBaseTest): +class PallasCallElementIndexingTest(PallasBaseTest): - def test_block_spec_unblocked(self): + def test_block_spec_element(self): def show_program_ids( - *, shape, block_shape, grid, indexing_mode: pl.IndexingMode + *, shape, block_shape, grid, ): def kernel(o1_ref): - assert o1_ref.shape == block_shape + assert o1_ref.shape == (8, 128) o1_ref[...] = jnp.full(o1_ref.shape, pl.program_id(0)) return self.pallas_call( @@ -845,16 +895,15 @@ def kernel(o1_ref): jax.ShapeDtypeStruct(shape, dtype=np.int32), grid=grid, out_specs=pl.BlockSpec( - block_shape, lambda i: (8 * i, 0), indexing_mode=indexing_mode + block_shape, lambda i: (8 * i, 0), ), )() # No padding pids = show_program_ids( shape=(16, 128), - block_shape=(8, 128), + block_shape=(pl.Element(8), pl.Element(128)), grid=(2,), - indexing_mode=pl.Unblocked(), ) expected_pids = np.array([[0] * 128] * 8 + [[1] * 128] * 8, dtype=np.int32) self.assertAllClose(pids, expected_pids) @@ -865,9 +914,8 @@ def kernel(o1_ref): # Only high padding pids = show_program_ids( shape=(14, 128), - block_shape=(8, 128), + block_shape=(pl.Element(8, (0, 2)), pl.Element(128, (0, 0))), grid=(2,), - indexing_mode=pl.Unblocked(((0, 2), (0, 0))), ) expected_pids = np.array([[0] * 128] * 8 + [[1] * 128] * 6, dtype=np.int32) self.assertAllClose(pids, expected_pids) @@ -876,15 +924,14 @@ def kernel(o1_ref): self.skipTest("TODO: low padding not supported yet") pids = show_program_ids( shape=(11, 128), - block_shape=(8, 128), + block_shape=(pl.Element(8, (3, 2)), pl.Element(128, (0, 0))), grid=(2,), - indexing_mode=pl.Unblocked(((3, 2), (0, 0))), ) expected_pids = np.array([[0] * 128] * 5 + [[1] * 128] * 6, dtype=np.int32) self.assertAllClose(pids, expected_pids) @parameterized.parameters("int32", "float32") - def test_block_spec_unblocked_padding_is_nan(self, dtype_name): + def test_block_spec_element_padding_is_nan(self, dtype_name): if not self.INTERPRET: self.skipTest("Only applicable for the interpret mode") @@ -899,7 +946,7 @@ def copy_kernel(x_ref, o_ref): grid=(1,), in_specs=[ pl.BlockSpec( - (6,), lambda i: 0, indexing_mode=pl.Unblocked(((1, 2),)) + (pl.Element(6, (1, 2)),), lambda i: 0, ) ], )(np.full((3,), 42, dtype=dtype)) @@ -913,7 +960,7 @@ def copy_kernel(x_ref, o_ref): ), ) - def test_unblocked_indexing(self): + def test_element_indexing(self): shape = (16 * 8, 128) result_ty = jax.ShapeDtypeStruct((15 * 8, 128), jnp.float32) @@ -926,7 +973,7 @@ def kernel(x_ref, o_ref): grid=(15,), in_specs=( pl.BlockSpec( - (2 * 8, 128), lambda i: (i * 8, 0), indexing_mode=pl.unblocked + (pl.Element(2 * 8), pl.Element(128)), lambda i: (i * 8, 0), ), ), out_specs=pl.BlockSpec((8, 128), lambda i: (i, 0)), @@ -955,9 +1002,8 @@ def kernel(x_ref, y_ref): grid=(1,), in_specs=( pl.BlockSpec( - (2 * 8, 128), + (pl.Element(2 * 8, (0, 8)), pl.Element(128)), lambda i: (0, 0), - indexing_mode=pl.Unblocked(((0, 8), (0, 0))), ), ), out_specs=pl.BlockSpec((8, 128), lambda i: (0, 0)), @@ -966,10 +1012,39 @@ def kernel(x_ref, y_ref): np.testing.assert_array_equal(y, x) -class PallasCallUnblockedIndexingInterpretTest(PallasCallUnblockedIndexingTest): +class PallasCallElementIndexingInterpretTest(PallasCallElementIndexingTest): INTERPRET = True +class PallasCallBoundedSliceIndexingTest(PallasBaseTest): + + def setUp(self): + super().setUp() + if not jtu.is_device_tpu(): + self.skipTest("Only applicable for TPU") + 
+ def test_block_spec_bounded_slice_static(self): + shape = (16, 8, 128) + def kernel(x_ref, o_ref): + o_ref[...] = x_ref[...] + + x = jnp.arange(np.prod(shape), dtype=np.int32).reshape(shape) + with self.assertRaisesRegex(NotImplementedError, + "Unsupported block dimension type:"): + _ = self.pallas_call( + kernel, + jax.ShapeDtypeStruct((8, 8, 128), dtype=np.int32), + grid=(1,), + in_specs=( + pl.BlockSpec( + (pl.BoundedSlice(8), 8, 128), lambda i: (pl.ds(4, 8), 0, 0), + ), + ), + out_specs=pl.BlockSpec( + (8, 8, 128), lambda i: (0, 0, 0), + ), + )(x) + class ApiErrorTest(PallasBaseTest): def test_pallas_call_kernel_args_mismatch(self): a = np.arange(256, dtype=np.int32) @@ -1022,10 +1097,10 @@ def test_pallas_call_in_specs_mismatch_inputs(self): pl.BlockSpec((4,), lambda: 0)]) with self.assertRaisesRegex( ValueError, - re.compile("Pytree for `in_specs` and inputs do not match. " + re.compile("Pytree for `in_specs` and `inputs` do not match. " "There are 1 mismatches, including:" ".* at \\[1\\], `in_specs` is a pytree leaf but " - "inputs is a.*", re.DOTALL)): + "`inputs` is a.*", re.DOTALL)): f(a, dict(a=a)) def test_pallas_call_index_map_wrong_number_of_arguments(self): @@ -1067,7 +1142,6 @@ def my_index_map(): "Currently returning 2 values."): f(dict(one=a, two=a)) - def test_pallas_call_index_map_wrong_return_type(self): a = np.arange(256, dtype=np.int32) def my_index_map(i): @@ -1181,6 +1255,28 @@ def test_pallas_call_input_output_aliases_errors(self): out_shape=[jax.ShapeDtypeStruct(x.shape, jnp.float32)], input_output_aliases={1: 0})(x, x) + def test_pallas_error_for_ref_to_jax(self): + m, n, k = 8, 16, 32 + + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct((m, n), jnp.float32), + ) + def dot_general_kernel(x_ref, y_ref, o_ref): + o_ref[...] 
= jax.lax.dot_general(x_ref, y_ref, (((2), (1)), ((1,), (2,)))) + + key1, key2 = random.split(random.key(0)) + x = random.normal(key1, (m, k), dtype=jnp.float32) + y = random.normal(key2, (k, n), dtype=jnp.float32) + with self.assertRaisesRegex( + ValueError, + r" Attempting to pass a Ref" + r" MemRef{float32\[8,32\]}" + r" to a primitive: dot_general - did you forget to unpack \(\[...\]\)" + r" the ref?", + ): + dot_general_kernel(x, y) + class ApiErrorInterpretTest(ApiErrorTest): INTERPRET = True @@ -1697,7 +1793,7 @@ def test_range_while_loop(self): def kernel(x_ref, r_ref): @pl.when(pl.program_id(0) == 0) def _(): - pl.store(r_ref, (0, 0), 0) + r_ref[0, 0] = 0 def cond(carry): i, j = carry @@ -1709,8 +1805,7 @@ def body(carry): sl = jax.lax.div(i, 128) l = jax.lax.rem(i, 128) v = x_ref[0, sl, l] - s = pl.load(r_ref, (0, 0)) - pl.store(r_ref, (0, 0), s + v) + r_ref[0, 0] += v return io + 1, j i = 128 @@ -1762,7 +1857,7 @@ def test_non_range_while_loop(self): def kernel(x_ref, r_ref): @pl.when(pl.program_id(0) == 0) def _(): - pl.store(r_ref, (0, 0), 0) + r_ref[0, 0] = 0 def cond(state): i, s = state @@ -1772,14 +1867,11 @@ def body(state): i, s = state sl = jax.lax.div(i, jnp.astype(128, i.dtype)) l = jax.lax.rem(i, jnp.astype(128, i.dtype)) - v = pl.load(x_ref, (0, sl, l)) + v = x_ref[0, sl, l] return i + 1, s + v i = jnp.int32(0) - s = pl.load(r_ref, (0, 0)) - - i, s = jax.lax.while_loop(cond, body, (i, s)) - pl.store(r_ref, (0, 0), s) + _, r_ref[0, 0] = jax.lax.while_loop(cond, body, (i, r_ref[0, 0])) x = jnp.arange(4096) x = jnp.reshape(x, [4, 8, 128]) @@ -2175,7 +2267,7 @@ def kernel(x_ref, y_ref): checkify.check(False, "second check failed") input_ = jnp.arange(4, dtype=jnp.int32) out_shape = jax.ShapeDtypeStruct(input_.shape, input_.dtype) - with pltpu.enable_runtime_assert(True): + with pl.enable_debug_checks(True): pallas_call = pl.pallas_call(kernel, out_shape=out_shape) pallas_call(input_) # This should log "second check failed" @@ -2185,11 +2277,10 @@ def test_runtime_assert_is_noop_when_not_enabled(self): self.skipTest("Runtime check only implemented on TPU.") def kernel(x_ref, y_ref): y_ref[...] = x_ref[...] - checkify.check(False, "failed check", - debug=True) # This check always fails. + pl.debug_check(False, "failed check") # This check always fails. input_ = jnp.arange(4, dtype=jnp.int32) out_shape = jax.ShapeDtypeStruct(input_.shape, input_.dtype) - with pltpu.enable_runtime_assert(False): + with pl.enable_debug_checks(False): pallas_call = pl.pallas_call(kernel, out_shape=out_shape) result = pallas_call(input_) np.testing.assert_allclose(result, input_) @@ -2379,8 +2470,8 @@ def kernel(x_ref, y_ref): def test_can_query_named_grid_size_in_kernel_via_psum(self): def kernel(x_ref, y_ref): - self.assertEqual(lax.psum(1, "i"), 2) - self.assertEqual(lax.psum(1, "j"), 4) + self.assertEqual(lax.axis_size("i"), 2) + self.assertEqual(lax.axis_size("j"), 4) y_ref[...] = x_ref[...] x = jnp.arange(4 * 16 * 128, dtype=np.int32).reshape((4, 16, 128)) @@ -2400,8 +2491,8 @@ def test_can_query_named_dynamic_grid_size_in_kernel_via_psum(self): self.skipTest("Not supported.") def kernel(x_ref, y_ref): - self.assertEqual(lax.psum(1, "i"), 2) - self.assertEqual(lax.psum(1, "j"), 4) + self.assertEqual(lax.axis_size("i"), 2) + self.assertEqual(lax.axis_size("j"), 4) y_ref[...] = x_ref[...] 
x = jnp.arange(4 * 8 * 128, dtype=np.int32).reshape((4, 8, 128)) @@ -2522,47 +2613,5 @@ class PallasCallNamedGridInterpretTest(PallasCallNamedGridTest): INTERPRET = True -def _find_pallas_call_in_jaxpr( - jaxpr: jax_core.Jaxpr) -> jax_core.JaxprEqn | None: - for eqn in jaxpr.eqns: - call_eqn = None - if eqn.primitive == pallas_call.pallas_call_p: - call_eqn = eqn - elif 'jaxpr' in eqn.params: - call_eqn = _find_pallas_call_in_jaxpr(eqn.params['jaxpr']) - if call_eqn is not None: - return call_eqn - return None - - -class PallasCompilerParamsTest(PallasBaseTest): - def test_triton_params_consistent_across_double_jit(self): - # Test for https://github.com/jax-ml/jax/issues/25714 - if not jtu.test_device_matches(["gpu"]): - self.skipTest("Triton backend only works on GPU.") - params = plgpu.TritonCompilerParams(num_warps=8) - - @jax.jit - @functools.partial( - self.pallas_call, out_shape=jax.ShapeDtypeStruct((), jnp.float32), - compiler_params=params) - def copy_kernel(x_ref, o_ref): - o_ref[...] = x_ref[...] - - @functools.partial(jax.jit, static_argnames=["z"]) - def plus_z(x, z): - return copy_kernel(x+z) - - x = 0. - extracted_params = _find_pallas_call_in_jaxpr( - plus_z.trace(x, 1).jaxpr).params["compiler_params"] - self.assertEqual(plus_z(0., 1.), 1.) - self.assertEqual(extracted_params["triton"]["num_warps"], 8) - extracted_params = _find_pallas_call_in_jaxpr( - plus_z.trace(x, 2).jaxpr).params["compiler_params"] - self.assertEqual(plus_z(0., 2.), 2.) - self.assertEqual(extracted_params["triton"]["num_warps"], 8) - - if __name__ == "__main__": absltest.main() diff --git a/tests/pallas/tpu_all_gather_test.py b/tests/pallas/tpu_all_gather_test.py index 98b3e5b40135..47168e1c35b4 100644 --- a/tests/pallas/tpu_all_gather_test.py +++ b/tests/pallas/tpu_all_gather_test.py @@ -25,114 +25,109 @@ import jax.numpy as jnp import numpy as np -try: - import hypothesis as hp - import hypothesis.strategies as hps - CAN_USE_HYPOTHESIS = True -except (ModuleNotFoundError, ImportError): - CAN_USE_HYPOTHESIS = False +import hypothesis as hp +import hypothesis.strategies as hps jax.config.parse_flags_with_absl() P = jax.sharding.PartitionSpec -if CAN_USE_HYPOTHESIS: - - hp.settings.register_profile( - "deterministic", - database=None, - derandomize=True, - deadline=None, - max_examples=50, - print_blob=True, - verbosity=hp.Verbosity.verbose, +hp.settings.register_profile( + "deterministic", + database=None, + derandomize=True, + deadline=None, + max_examples=50, + print_blob=True, + verbosity=hp.Verbosity.verbose, +) +hp.settings.load_profile("deterministic") + + +@hps.composite +def _array_shapes(draw): + # TODO(sharadmv, apaszke): enable this on a wider variety of shapes + valid_shapes = [ + (128, 128), + (256, 128), + (256, 512), + (256, 1024), + # TODO(sharadmv,apaszke): enable these shapes + # (256, 129), + # (129, 128), + # (64, 64), + # (1, 1), + ] + return draw(hps.sampled_from(valid_shapes)) + + +@hps.composite +def _array_dtypes(draw): + return draw( + hps.sampled_from([ + jnp.float32, + jnp.bfloat16, + jnp.int32, + # jnp.float16, # TODO(sharadmv,apaszke): enable float16 all gather + # jnp.int16, # TODO(sharadmv,apaszke): enable int16 all gather + # jnp.int8, # TODO(sharadmv,apaszke): enable int8 all gather + ]) ) - hp.settings.load_profile("deterministic") - - - @hps.composite - def _array_shapes(draw): - # TODO(sharadmv, apaszke): enable this on a wider variety of shapes - valid_shapes = [ - (128, 128), - (256, 128), - (256, 512), - (256, 1024), - # TODO(sharadmv,apaszke): enable these shapes 
- # (256, 129), - # (129, 128), - # (64, 64), - # (1, 1), - ] - return draw(hps.sampled_from(valid_shapes)) - - - @hps.composite - def _array_dtypes(draw): - return draw( - hps.sampled_from([ - jnp.float32, - jnp.bfloat16, - jnp.int32, - # jnp.float16, # TODO(sharadmv,apaszke): enable float16 all gather - # jnp.int16, # TODO(sharadmv,apaszke): enable int16 all gather - # jnp.int8, # TODO(sharadmv,apaszke): enable int8 all gather - ]) - ) - class AllGatherTest(jtu.JaxTestCase): - - def setUp(self): - if not jtu.test_device_matches(["tpu"]): - self.skipTest("Need TPU devices") - if not jtu.is_device_tpu(version=5, variant="e"): - # TODO(sharadmv,apaszke): expand support to more versions - self.skipTest("Currently only supported on TPU v5e") - - super().setUp() - - @hp.given(hps.booleans(), _array_shapes(), _array_dtypes()) - def test_all_gather_1d_mesh(self, is_vmem, shape, dtype): - if jax.device_count() < 2: - self.skipTest("Need more devices") - memory_space = pltpu.VMEM if is_vmem else pltpu.ANY - mesh_shape = (jax.device_count(),) - mesh = jax.sharding.Mesh( - mesh_utils.create_device_mesh(mesh_shape, jax.devices()), ["x"] - ) - leading, *rest = shape - shape = (mesh.shape["x"] * leading, *rest) - x = random.normal(random.key(0), shape, dtype=jnp.float32).astype(dtype) - x_sharded = jax.device_put(x, jax.sharding.NamedSharding(mesh, P("x"))) - y = all_gather.all_gather(x_sharded, mesh=mesh, axis_name="x", - memory_space=memory_space) - np.testing.assert_array_equal(y, x) - - @hp.given(hps.booleans(), _array_shapes(), _array_dtypes(), - hps.sampled_from(["x", "y"])) - def test_all_gather_2d_mesh(self, is_vmem, shape, dtype, - axis_name): - if jax.device_count() < 2: - self.skipTest("Need more devices") - if jax.device_count() % 2: - self.skipTest("Need an even number of devices") - memory_space = pltpu.VMEM if is_vmem else pltpu.ANY - mesh_shape = (2, jax.device_count() // 2) - mesh = jax.sharding.Mesh( - mesh_utils.create_device_mesh(mesh_shape, jax.devices()), ["x", "y"] - ) - if axis_name == "x": - sharding = jax.sharding.NamedSharding(mesh, P("x", None)) - else: - sharding = jax.sharding.NamedSharding(mesh, P("y", None)) - leading, *rest = shape - shape = (mesh.shape[axis_name] * leading, *rest) - x = random.normal(random.key(0), shape, dtype=jnp.float32).astype(dtype) - x_sharded = jax.device_put(x, sharding) - y = all_gather.all_gather(x_sharded, mesh=mesh, axis_name=axis_name, - memory_space=memory_space) - np.testing.assert_array_equal(y, x) +@jtu.thread_unsafe_test_class() # hypothesis is not thread safe +class AllGatherTest(jtu.JaxTestCase): + + def setUp(self): + if not jtu.test_device_matches(["tpu"]): + self.skipTest("Need TPU devices") + if not jtu.is_device_tpu(version=5, variant="e"): + # TODO(sharadmv,apaszke): expand support to more versions + self.skipTest("Currently only supported on TPU v5e") + + super().setUp() + + @hp.given(hps.booleans(), _array_shapes(), _array_dtypes()) + def test_all_gather_1d_mesh(self, is_vmem, shape, dtype): + if jax.device_count() < 2: + self.skipTest("Need more devices") + memory_space = pltpu.VMEM if is_vmem else pltpu.ANY + mesh_shape = (jax.device_count(),) + mesh = jax.sharding.Mesh( + mesh_utils.create_device_mesh(mesh_shape, jax.devices()), ["x"] + ) + leading, *rest = shape + shape = (mesh.shape["x"] * leading, *rest) + x = random.normal(random.key(0), shape, dtype=jnp.float32).astype(dtype) + x_sharded = jax.device_put(x, jax.sharding.NamedSharding(mesh, P("x"))) + y = all_gather.all_gather(x_sharded, mesh=mesh, axis_name="x", + 
memory_space=memory_space) + np.testing.assert_array_equal(y, x) + + @hp.given(hps.booleans(), _array_shapes(), _array_dtypes(), + hps.sampled_from(["x", "y"])) + def test_all_gather_2d_mesh(self, is_vmem, shape, dtype, + axis_name): + if jax.device_count() < 2: + self.skipTest("Need more devices") + if jax.device_count() % 2: + self.skipTest("Need an even number of devices") + memory_space = pltpu.VMEM if is_vmem else pltpu.ANY + mesh_shape = (2, jax.device_count() // 2) + mesh = jax.sharding.Mesh( + mesh_utils.create_device_mesh(mesh_shape, jax.devices()), ["x", "y"] + ) + if axis_name == "x": + sharding = jax.sharding.NamedSharding(mesh, P("x", None)) + else: + sharding = jax.sharding.NamedSharding(mesh, P("y", None)) + leading, *rest = shape + shape = (mesh.shape[axis_name] * leading, *rest) + x = random.normal(random.key(0), shape, dtype=jnp.float32).astype(dtype) + x_sharded = jax.device_put(x, sharding) + y = all_gather.all_gather(x_sharded, mesh=mesh, axis_name=axis_name, + memory_space=memory_space) + np.testing.assert_array_equal(y, x) if __name__ == "__main__": diff --git a/tests/pallas/tpu_fusable_matmul_test.py b/tests/pallas/tpu_fusible_matmul_test.py similarity index 92% rename from tests/pallas/tpu_fusable_matmul_test.py rename to tests/pallas/tpu_fusible_matmul_test.py index df7c1221bb0c..ae56d3db2f3a 100644 --- a/tests/pallas/tpu_fusable_matmul_test.py +++ b/tests/pallas/tpu_fusible_matmul_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Fusable matmul test.""" +"""Fusible matmul test.""" import functools from typing import Any @@ -71,10 +71,11 @@ def _(): def _(): acc = acc_ref[...].astype(out_dtype) z_values = jax.tree.map(lambda ref: ref.get(), z_value_refs) - o_ref[...] 
= z_fn(pids, scalar_prefetch, z_values, acc) + out = z_fn(pids, scalar_prefetch, z_values, acc) + jax.tree.map(lambda ref, x: ref.set(x), o_ref, out) -def _fusable_matmul( +def _fusible_matmul( x: fuser.Fusion[[], jax.Array], # pytype: disable=invalid-annotation y: fuser.Fusion[[], jax.Array], # pytype: disable=invalid-annotation z: fuser.Fusion[[jax.Array], jax.Array] | None, # pytype: disable=invalid-annotation @@ -174,12 +175,12 @@ def z_index_map(i, j, k, *_): y_value_block_specs, z_value_block_specs, ], - out_specs=z_out_block_spec, + out_specs=[z_out_block_spec], ), - compiler_params=pltpu.TPUCompilerParams( + compiler_params=pltpu.CompilerParams( dimension_semantics=dimension_semantics, ), - out_shape=z_out_type, + out_shape=[z_out_type], interpret=interpret, debug=debug, )( @@ -187,10 +188,10 @@ def z_index_map(i, j, k, *_): x_values, y_values, z_values, - ) + )[0] -def fusable_matmul( +def fusible_matmul( x: jax.Array, y: jax.Array, *, @@ -200,9 +201,9 @@ def fusable_matmul( debug: bool = False, interpret: bool = False, ) -> jax.Array: - return fuser.fusable( + return fuser.fusible( functools.partial( - _fusable_matmul, + _fusible_matmul, bm=bm, bk=bk, bn=bn, @@ -212,7 +213,7 @@ def fusable_matmul( )(x, y) -class FusableMatmulTest(jtu.JaxTestCase): +class FusibleMatmulTest(jtu.JaxTestCase): def setUp(self): if not jtu.is_device_tpu_at_least(4): @@ -225,7 +226,7 @@ def test_matmul(self, dtype): x = jax.random.normal(k0, (512, 512), dtype) y = jax.random.normal(k1, (512, 512), dtype) np.testing.assert_allclose( - jax.jit(fusable_matmul)(x, y), mm_ref(x, y), atol=5e-5 + jax.jit(fusible_matmul)(x, y), mm_ref(x, y), atol=5e-5 ) @parameterized.parameters('float32', 'bfloat16') @@ -237,7 +238,7 @@ def test_matmul_with_activation(self, dtype): @jax.jit @fuser.fuse def matmul_relu(x, y): - x = fusable_matmul(x, y) + x = fusible_matmul(x, y) x = jnp.maximum(x, 0.0) return x @@ -257,7 +258,7 @@ def test_matmul_with_bias(self, dtype): @jax.jit @fuser.fuse def matmul_bias(x, y, b): - x = fusable_matmul(x, y).astype(dtype) + b + x = fusible_matmul(x, y).astype(dtype) + b x = jnp.maximum(x, 0.0) return x @@ -276,7 +277,7 @@ def test_matmul_with_slice(self, dtype): @jax.jit @fuser.fuse def matmul_slice(x, y): - x = fusable_matmul(x, y[1]) + x = fusible_matmul(x, y[1]) return x np.testing.assert_allclose(matmul_slice(x, y), mm_ref(x, y[1]), atol=5e-5) @@ -290,7 +291,7 @@ def test_matmul_with_dynamic_slice(self, dtype): @jax.jit @fuser.fuse def matmul_slice(x, y, i): - x = fusable_matmul(x, y[i]) + x = fusible_matmul(x, y[i]) return x np.testing.assert_allclose( @@ -307,7 +308,7 @@ def test_matmul_with_dynamic_slice_bias(self, dtype): @jax.jit @fuser.fuse def matmul_slice(x, y, b, i, j): - x = fusable_matmul(x, y[j]).astype(dtype) + b[i] + x = fusible_matmul(x, y[j]).astype(dtype) + b[i] return x np.testing.assert_allclose( @@ -325,7 +326,7 @@ def test_matmul_with_multi_slice(self, dtype): @jax.jit @fuser.fuse def matmul_slice(x, y): - x = fusable_matmul(x, y[1, 1]) + x = fusible_matmul(x, y[1, 1]) return x np.testing.assert_allclose( @@ -341,7 +342,7 @@ def test_matmul_with_multiple_slices(self, dtype): @jax.jit @fuser.fuse def matmul_slice(x, y): - x = fusable_matmul(x, y[1][1]) + x = fusible_matmul(x, y[1][1]) return x np.testing.assert_allclose( @@ -357,7 +358,7 @@ def test_matmul_with_multiple_dynamic_slices(self, dtype): @jax.jit @fuser.fuse def matmul_slice(x, y, i, j): - x = fusable_matmul(x, y[i][j]) + x = fusible_matmul(x, y[i][j]) return x for i in range(2): @@ -375,7 +376,7 @@ def 
test_matmul_with_mixed_slices(self, dtype): @jax.jit @fuser.fuse def matmul_slice(x, y, i, j): - x = fusable_matmul(x, y[2][i, j]) + x = fusible_matmul(x, y[2][i, j]) return x for i in range(2): @@ -396,7 +397,7 @@ def test_matmul_with_multiple_mixed_slices_and_bias(self, dtype): @jax.jit @fuser.fuse def matmul_slice(x, y, b, i, j, k): - x = fusable_matmul(x[k][3], y[2][i, j]).astype(dtype) + x = fusible_matmul(x[k][3], y[2][i, j]).astype(dtype) return x + b[i, j] @jit_no_excess_precision @@ -415,7 +416,7 @@ def matmul_slice_ref(x, y, b, i, j, k): @parameterized.parameters('float32', 'bfloat16') def test_matmul_input_concat_output(self, dtype): - self.skipTest('select_n doesnt support more than 3 elements') + self.skipTest('select_n does not support more than 3 elements') # TODO(sharadmv): fix this test k0, k1, k2, k3 = jax.random.split(jax.random.key(0), 4) x = jax.random.normal(k0, (128, 128), dtype) @@ -427,7 +428,7 @@ def test_matmul_input_concat_output(self, dtype): @fuser.fuse def matmul_concat(x, ys): y = jnp.concatenate(ys, axis=1) - x = fusable_matmul(x, y) + x = fusible_matmul(x, y) return x @jax.jit @@ -453,7 +454,7 @@ def test_matmul_input_concat_contract(self, dtype): @fuser.fuse def matmul_concat(x, ys): y = jnp.concatenate(ys, axis=0) - x = fusable_matmul(x, y) + x = fusible_matmul(x, y) return x @jit_no_excess_precision @@ -481,7 +482,7 @@ def test_matmul_double_concat(self, dtype): def matmul_concat(x, ys, y3): y = jnp.concatenate(ys, axis=0) y = jnp.concatenate([y, y3], axis=1) - x = fusable_matmul(x, y) + x = fusible_matmul(x, y) return x @jit_no_excess_precision @@ -508,7 +509,7 @@ def test_matmul_slice_concat(self, dtype): @fuser.fuse def matmul_concat(x, y1, y2): y = jnp.concatenate([y1, y2[3]], axis=0) - x = fusable_matmul(x, y) + x = fusible_matmul(x, y) return x @jit_no_excess_precision @@ -533,7 +534,7 @@ def test_matmul_slice_concat_slice(self, dtype): @fuser.fuse def matmul_concat(x, y1, y2): y = jnp.concatenate([y1, y2[3]], axis=1)[1] - x = fusable_matmul(x, y) + x = fusible_matmul(x, y) return x @jit_no_excess_precision @@ -558,7 +559,7 @@ def test_matmul_dynamic_slice_concat(self, dtype): @fuser.fuse def matmul_concat(x, y1, y2, i, j): y = jnp.concatenate([y1, y2[i]], axis=1)[j] - x = fusable_matmul(x, y) + x = fusible_matmul(x, y) return x @jit_no_excess_precision @@ -584,7 +585,7 @@ def matmul(impl, x, y): return z impl = fuser.fuse( - functools.partial(matmul, functools.partial(fusable_matmul, bn=256)) + functools.partial(matmul, functools.partial(fusible_matmul, bn=256)) ) ref = functools.partial(matmul, mm_ref) @@ -606,7 +607,7 @@ def matmul(impl, x, y): return z impl = fuser.fuse( - functools.partial(matmul, functools.partial(fusable_matmul, bn=256)) + functools.partial(matmul, functools.partial(fusible_matmul, bn=256)) ) ref = functools.partial(matmul, mm_ref) @@ -628,7 +629,7 @@ def matmul(impl, x, y): return z impl = fuser.fuse( - functools.partial(matmul, functools.partial(fusable_matmul, bn=256)) + functools.partial(matmul, functools.partial(fusible_matmul, bn=256)) ) ref = functools.partial(matmul, mm_ref) @@ -650,7 +651,7 @@ def matmul(impl, x, y): return z impl = fuser.fuse( - functools.partial(matmul, functools.partial(fusable_matmul, bn=256)) + functools.partial(matmul, functools.partial(fusible_matmul, bn=256)) ) ref = functools.partial(matmul, mm_ref) @@ -672,7 +673,7 @@ def matmul(impl, x, y): return z impl = fuser.fuse( - functools.partial(matmul, functools.partial(fusable_matmul, bn=256)) + functools.partial(matmul, 
functools.partial(fusible_matmul, bn=256)) ) ref = functools.partial(matmul, mm_ref) @@ -694,7 +695,7 @@ def matmul(impl, x, y): return z impl = fuser.fuse( - functools.partial(matmul, functools.partial(fusable_matmul, bn=256)) + functools.partial(matmul, functools.partial(fusible_matmul, bn=256)) ) ref = functools.partial(matmul, mm_ref) @@ -715,7 +716,7 @@ def matmul(impl, x, y): return z impl = fuser.fuse( - functools.partial(matmul, functools.partial(fusable_matmul, bn=256)) + functools.partial(matmul, functools.partial(fusible_matmul, bn=256)) ) ref = functools.partial(matmul, mm_ref) @@ -737,7 +738,7 @@ def matmul(impl, x, y): return z impl = fuser.fuse( - functools.partial(matmul, functools.partial(fusable_matmul, bm=256)) + functools.partial(matmul, functools.partial(fusible_matmul, bm=256)) ) ref = functools.partial(matmul, mm_ref) @@ -759,7 +760,7 @@ def matmul(impl, x, y): return z impl = fuser.fuse( - functools.partial(matmul, functools.partial(fusable_matmul, bm=256)) + functools.partial(matmul, functools.partial(fusible_matmul, bm=256)) ) ref = functools.partial(matmul, mm_ref) @@ -781,7 +782,7 @@ def matmul(impl, x, y): return z.T impl = fuser.fuse( - functools.partial(matmul, functools.partial(fusable_matmul, bn=256)) + functools.partial(matmul, functools.partial(fusible_matmul, bn=256)) ) ref = functools.partial(matmul, mm_ref) @@ -803,7 +804,7 @@ def matmul(impl, x, y): return z.T * 2 impl = fuser.fuse( - functools.partial(matmul, functools.partial(fusable_matmul, bn=256)) + functools.partial(matmul, functools.partial(fusible_matmul, bn=256)) ) ref = functools.partial(matmul, mm_ref) @@ -866,7 +867,7 @@ def matmul(impl, x, y): impl = fuser.fuse( functools.partial( matmul, - fusable_matmul, + fusible_matmul, ) ) ref = functools.partial(matmul, dot_ref) @@ -892,7 +893,7 @@ def matmul(impl, x, y): out_ref = jit_no_excess_precision(ref)(x, y) - impl = fuser.fuse(functools.partial(matmul, fusable_matmul)) + impl = fuser.fuse(functools.partial(matmul, fusible_matmul)) out = jax.jit(impl)(x, y) self.assertAllClose(out, out_ref, atol=0) @@ -916,7 +917,7 @@ def matmul(impl, x, y): impl = fuser.fuse( functools.partial( matmul, - functools.partial(fusable_matmul, bk=256, bn=128), + functools.partial(fusible_matmul, bk=256, bn=128), ) ) out = jax.jit(impl)(x, y) @@ -924,7 +925,7 @@ def matmul(impl, x, y): atol = 0 if jtu.is_device_tpu_at_least(6): # 256 MXU changes some tols. 
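# Illustrative sketch (not part of the patch above): the pattern these tests
# exercise.  `fusible_matmul` registers the Pallas matmul as a fusible op via
# `fuser.fusible`, and `fuser.fuse` then pulls elementwise prologues/epilogues
# (slices, bias adds, activations, casts) into that kernel.  Assumes `fuser`,
# `fusible_matmul`, `jax`, `jnp` and `functools` are in scope as in this file.

@jax.jit
@fuser.fuse
def matmul_bias_relu(x, y, b):
  # The bias add and relu are fused into the matmul kernel rather than being
  # emitted as separate ops after it.
  z = fusible_matmul(x, y).astype(b.dtype) + b
  return jnp.maximum(z, 0.0)

# Block sizes can be overridden per call site, as several tests here do:
matmul_bn256 = functools.partial(fusible_matmul, bn=256)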
- atol = 1e-6 + atol = 1e-5 self.assertAllClose(out, out_ref, atol=atol) def test_matmul_f32_out_fused_downcast(self): @@ -952,7 +953,7 @@ def matmul(impl, x, y): functools.partial( matmul, functools.partial( - fusable_matmul, + fusible_matmul, bm=bm, bk=bk, bn=bn, @@ -989,7 +990,7 @@ def matmul(impl, x, y): functools.partial( matmul, functools.partial( - fusable_matmul, + fusible_matmul, bm=bm, bk=bk, bn=bn, @@ -1024,7 +1025,7 @@ def matmul(impl, x, y): functools.partial( matmul, functools.partial( - fusable_matmul, + fusible_matmul, bm=bm, bk=bk, bn=bn, diff --git a/tests/pallas/tpu_gmm_test.py b/tests/pallas/tpu_gmm_test.py index 9c416dabaeb1..6d8b0a244edb 100644 --- a/tests/pallas/tpu_gmm_test.py +++ b/tests/pallas/tpu_gmm_test.py @@ -24,12 +24,8 @@ import jax.numpy as jnp import numpy as np -try: - import hypothesis as hp - import hypothesis.strategies as hps - CAN_USE_HYPOTHESIS = True -except (ModuleNotFoundError, ImportError): - CAN_USE_HYPOTHESIS = False +import hypothesis as hp +import hypothesis.strategies as hps jax.config.parse_flags_with_absl() @@ -37,326 +33,326 @@ partial = functools.partial -if CAN_USE_HYPOTHESIS: - hp.settings.register_profile( - "deterministic", - database=None, - derandomize=True, - deadline=None, - max_examples=10, - print_blob=True, +hp.settings.register_profile( + "deterministic", + database=None, + derandomize=True, + deadline=None, + max_examples=10, + print_blob=True, +) +hp.settings.load_profile("deterministic") + +def seed_strategy() -> hps.SearchStrategy[int]: + return hps.integers(min_value=0, max_value=4) + +@hps.composite +def group_strategy( + draw: hps.DrawFn, + max_groups: int = 32, + max_stride: int = 32, + min_groups: int = 1, +) -> tuple[int, int]: + assert max_stride <= max_groups + + # Sample the number of groups owned by each shard. + group_stride = draw(hps.integers(min_value=1, max_value=max_stride)) + + # Sample the number of groups as a multiple of the stride to ensure that we + # have an equal number of groups per shard. Round down s.t. num_groups <= + # max_groups. + num_groups = group_stride * draw( + hps.integers(min_value=min_groups, max_value=max_groups // group_stride) ) - hp.settings.load_profile("deterministic") - - def seed_strategy() -> hps.SearchStrategy[int]: - return hps.integers(min_value=0, max_value=4) - - @hps.composite - def group_strategy( - draw: hps.DrawFn, - max_groups: int = 32, - max_stride: int = 32, - min_groups: int = 1, - ) -> tuple[int, int]: - assert max_stride <= max_groups - - # Sample the number of groups owned by each shard. - group_stride = draw(hps.integers(min_value=1, max_value=max_stride)) - - # Sample the number of groups as a multiple of the stride to ensure that we - # have an equal number of groups per shard. Round down s.t. num_groups <= - # max_groups. - num_groups = group_stride * draw( - hps.integers(min_value=min_groups, max_value=max_groups // group_stride) - ) - return num_groups, group_stride - - @hps.composite - def group_sizes_strategy( - draw: hps.DrawFn, m: int, num_groups: int - ) -> jnp.ndarray: - # Randomly sample the ends of the groups in the m-dimension. Let the fuzzer - # sample with replacement so that it's possible to get zero-sized groups. Get - # 'num_groups - 1' run ends. The final group will end at 'm'. 
- ends_no_final = np.sort( - np.array( - [ - draw(hps.integers(min_value=0, max_value=m)) - for _ in range(num_groups - 1) - ], - dtype=np.int32, - ), + return num_groups, group_stride + +@hps.composite +def group_sizes_strategy( + draw: hps.DrawFn, m: int, num_groups: int +) -> jnp.ndarray: + # Randomly sample the ends of the groups in the m-dimension. Let the fuzzer + # sample with replacement so that it's possible to get zero-sized groups. Get + # 'num_groups - 1' run ends. The final group will end at 'm'. + ends_no_final = np.sort( + np.array( + [ + draw(hps.integers(min_value=0, max_value=m)) + for _ in range(num_groups - 1) + ], + dtype=np.int32, + ), + ) + ends = np.concatenate([ends_no_final, np.array([m], dtype=np.int32)]) + + # Calculate the run starts by shifting ends 1 to the right. The first run + # starts at zero. + starts = np.concatenate([np.zeros(1, dtype=np.int32), ends_no_final]) + return jnp.array(ends - starts, dtype=jnp.int32) + +GROUPED_MATMUL_TESTS = ( + (128, 128, 128), # Small + (512, 2048, 256), # Big + (128, 8, 16), # Test partial tiles. +) + +def random_dense( + shape: tuple[int, ...], + key: jax.Array, + dtype: jnp.dtype, + limit: int | None = None, +) -> jnp.ndarray: + if limit is None: + limit = 1 / np.prod(shape) + x = jax.random.uniform(key, shape, dtype, minval=-limit, maxval=limit) # pylint: disable=invalid-unary-operand-type + return x.astype(jnp.bfloat16).astype(dtype) + +def dot( + lhs: jnp.ndarray, + rhs: jnp.ndarray, + transpose_lhs: bool = False, + transpose_rhs: bool = False, + preferred_element_type: jnp.dtype = jnp.float32, +) -> jnp.ndarray: + lhs = jnp.transpose(lhs) if transpose_lhs else lhs + rhs = jnp.transpose(rhs) if transpose_rhs else rhs + return jax.lax.dot(lhs, rhs, preferred_element_type=preferred_element_type) + +def reference_gmm( + lhs: jnp.ndarray, + rhs: jnp.ndarray, + group_sizes: jnp.ndarray, + preferred_element_type: jnp.dtype = jnp.float32, +) -> jnp.ndarray: + + start = 0 + out = [] + for i, size in enumerate(group_sizes): + result = dot( + lhs[start : start + size, :], + rhs[i, :, :], + preferred_element_type=preferred_element_type, ) - ends = np.concatenate([ends_no_final, np.array([m], dtype=np.int32)]) - # Calculate the run starts by shifting ends 1 to the right. The first run - # starts at zero. - starts = np.concatenate([np.zeros(1, dtype=np.int32), ends_no_final]) - return jnp.array(ends - starts, dtype=jnp.int32) - - GROUPED_MATMUL_TESTS = ( - (128, 128, 128), # Small - (512, 2048, 256), # Big - (128, 8, 16), # Test partial tiles. - ) + out.append(result) + start += group_sizes[i] + return jnp.concatenate(out, axis=0) + +def with_dtype_arguments(xs: tuple[Any, ...]) -> tuple[Any, ...]: + dtypes = [jnp.float32, jnp.bfloat16] + + result = [] + for x in xs: + for dtypes_tuple in itertools.product(dtypes, dtypes, dtypes): + result.append(x + dtypes_tuple) + return tuple(result) + +def with_transpose_argument(xs: tuple[Any, ...]) -> tuple[Any, ...]: + flags = [False, True] + result = [] + for x in xs: + for flag in flags: + result.append(x + (flag,)) + return tuple(result) + +def tolerances( + lhs_dtype: jnp.dtype, rhs_dtype: jnp.dtype, out_dtype: jnp.dtype +) -> tuple[float, float]: + if ( + lhs_dtype == jnp.bfloat16 + or rhs_dtype == jnp.bfloat16 + or out_dtype == jnp.bfloat16 + ): + return 1e-3, 1e-2 # atol, rtol + return 1e-3, 1e-5 # atol, rtol + +# TODO(tgale): Fix errors with strict dtype promotion. 
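# Illustrative sketch (not part of the patch above): how `group_sizes_strategy`
# and `reference_gmm` fit together, with hypothetical values (m=8, 4 groups).
import numpy as np

m = 8
draws = np.sort(np.array([3, 3, 6], dtype=np.int32))     # num_groups - 1 sampled run ends
ends = np.concatenate([draws, np.array([m], np.int32)])  # [3, 3, 6, 8]; last group ends at m
starts = np.concatenate([np.zeros(1, np.int32), draws])  # [0, 3, 3, 6]; ends shifted right by one
group_sizes = ends - starts                              # [3, 0, 3, 2]; zero-sized groups are allowed

# reference_gmm then multiplies each row-group of `lhs` by its own `rhs[i]`:
lhs = np.ones((m, 4), np.float32)
rhs = np.stack([np.full((4, 2), i, dtype=np.float32) for i in range(4)])
out, start = [], 0
for i, size in enumerate(group_sizes):
  out.append(lhs[start:start + size] @ rhs[i])
  start += size
out = np.concatenate(out, axis=0)  # shape (m, 2); row block i was multiplied by rhs[i]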
+@jtu.with_config(jax_numpy_dtype_promotion="standard") +@jtu.thread_unsafe_test_class() # hypothesis is not thread safe +class GroupedMatmulTest(jtu.JaxTestCase): + + def setUp(self): + if not jtu.test_device_matches(["tpu"]): + self.skipTest("Test requires TPU device.") + + super().setUp() + self.key = jax.random.PRNGKey(1234) + + def assert_allclose( + self, + out: jnp.ndarray, + expected_out: jnp.ndarray, + *, + atol: float = 1e-5, + rtol: float = 1e-5, + ): + self.assertEqual(out.dtype, expected_out.dtype) + np.testing.assert_allclose( + out.astype(jnp.float32), + expected_out.astype(jnp.float32), + atol=atol, + rtol=rtol, + ) - def random_dense( - shape: tuple[int, ...], - key: jax.Array, - dtype: jnp.dtype, - limit: int | None = None, - ) -> jnp.ndarray: - if limit is None: - limit = 1 / np.prod(shape) - x = jax.random.uniform(key, shape, dtype, minval=-limit, maxval=limit) # pylint: disable=invalid-unary-operand-type - return x.astype(jnp.bfloat16).astype(dtype) - - def dot( - lhs: jnp.ndarray, - rhs: jnp.ndarray, - transpose_lhs: bool = False, - transpose_rhs: bool = False, - preferred_element_type: jnp.dtype = jnp.float32, - ) -> jnp.ndarray: - lhs = jnp.transpose(lhs) if transpose_lhs else lhs - rhs = jnp.transpose(rhs) if transpose_rhs else rhs - return jax.lax.dot(lhs, rhs, preferred_element_type=preferred_element_type) - - def reference_gmm( - lhs: jnp.ndarray, - rhs: jnp.ndarray, - group_sizes: jnp.ndarray, - preferred_element_type: jnp.dtype = jnp.float32, - ) -> jnp.ndarray: - - start = 0 - out = [] - for i, size in enumerate(group_sizes): - result = dot( - lhs[start : start + size, :], - rhs[i, :, :], - preferred_element_type=preferred_element_type, - ) + def gmm_test( + self, + m: int, + k: int, + n: int, + data: hps.SearchStrategy[hps.DataObject], + interpret: bool = False, + ): + seed = data.draw(seed_strategy()) + num_groups, _ = data.draw(group_strategy(max_stride=1)) + lhs_dtype, rhs_dtype, out_dtype = ( + data.draw(hps.sampled_from([jnp.float32, jnp.bfloat16])) + for _ in range(3) + ) + transpose_rhs = data.draw(hps.booleans()) + + key = jax.random.key(seed) + k1, k2 = jax.random.split(key, 2) + lhs = random_dense((m, k), k1, lhs_dtype, limit=1) + rhs = random_dense((num_groups, k, n), k2, rhs_dtype, limit=1) + group_sizes = data.draw(group_sizes_strategy(m=m, num_groups=num_groups)) + + out, vjpfun = jax.vjp( + partial( + mblx.gmm, + preferred_element_type=out_dtype, + transpose_rhs=transpose_rhs, + interpret=interpret, + ), + lhs, + rhs.swapaxes(1, 2) if transpose_rhs else rhs, + group_sizes, + ) - out.append(result) - start += group_sizes[i] - return jnp.concatenate(out, axis=0) - - def with_dtype_arguments(xs: tuple[Any, ...]) -> tuple[Any, ...]: - dtypes = [jnp.float32, jnp.bfloat16] - - result = [] - for x in xs: - for dtypes_tuple in itertools.product(dtypes, dtypes, dtypes): - result.append(x + dtypes_tuple) - return tuple(result) - - def with_transpose_argument(xs: tuple[Any, ...]) -> tuple[Any, ...]: - flags = [False, True] - result = [] - for x in xs: - for flag in flags: - result.append(x + (flag,)) - return tuple(result) - - def tolerances( - lhs_dtype: jnp.dtype, rhs_dtype: jnp.dtype, out_dtype: jnp.dtype - ) -> tuple[float, float]: - if ( - lhs_dtype == jnp.bfloat16 - or rhs_dtype == jnp.bfloat16 - or out_dtype == jnp.bfloat16 - ): - return 1e-3, 1e-2 # atol, rtol - return 1e-3, 1e-5 # atol, rtol - - # TODO(tgale): Fix errors with strict dtype promotion. 
- @jtu.with_config(jax_numpy_dtype_promotion="standard") - class GroupedMatmulTest(jtu.JaxTestCase): - - def setUp(self): - if not jtu.test_device_matches(["tpu"]): - self.skipTest("Test requires TPU device.") - - super().setUp() - self.key = jax.random.PRNGKey(1234) - - def assert_allclose( - self, - out: jnp.ndarray, - expected_out: jnp.ndarray, - *, - atol: float = 1e-5, - rtol: float = 1e-5, - ): - self.assertEqual(out.dtype, expected_out.dtype) - np.testing.assert_allclose( - out.astype(jnp.float32), - expected_out.astype(jnp.float32), - atol=atol, - rtol=rtol, + def reference_fn(lhs, rhs, group_sizes, preferred_element_type): + rhs = rhs.swapaxes(1, 2) if transpose_rhs else rhs + return reference_gmm( + lhs, rhs, group_sizes, preferred_element_type=preferred_element_type ) - def gmm_test( - self, - m: int, - k: int, - n: int, - data: hps.SearchStrategy[hps.DataObject], - interpret: bool = False, - ): - seed = data.draw(seed_strategy()) - num_groups, _ = data.draw(group_strategy(max_stride=1)) - lhs_dtype, rhs_dtype, out_dtype = [ - data.draw(hps.sampled_from([jnp.float32, jnp.bfloat16])) - for _ in range(3) - ] - transpose_rhs = data.draw(hps.booleans()) - - key = jax.random.key(seed) - k1, k2 = jax.random.split(key, 2) - lhs = random_dense((m, k), k1, lhs_dtype, limit=1) - rhs = random_dense((num_groups, k, n), k2, rhs_dtype, limit=1) - group_sizes = data.draw(group_sizes_strategy(m=m, num_groups=num_groups)) - - out, vjpfun = jax.vjp( - partial( - mblx.gmm, - preferred_element_type=out_dtype, - transpose_rhs=transpose_rhs, - interpret=interpret, - ), - lhs, - rhs.swapaxes(1, 2) if transpose_rhs else rhs, - group_sizes, - ) + expected_out, reference_vjpfun = jax.vjp( + partial(reference_fn, preferred_element_type=out_dtype), + lhs, + rhs.swapaxes(1, 2) if transpose_rhs else rhs, + group_sizes, + ) + self.assertEqual(out.dtype, out_dtype) + self.assertEqual(expected_out.dtype, out_dtype) + + atol, rtol = tolerances(lhs_dtype, rhs_dtype, out_dtype) + self.assert_allclose(out, expected_out, atol=atol, rtol=rtol) + + cotangent = random_dense((m, n), k1, out_dtype, limit=1) + grad_lhs, grad_rhs, *_ = vjpfun(cotangent) + expected_grad_lhs, expected_grad_rhs, *_ = reference_vjpfun(cotangent) + self.assert_allclose(grad_lhs, expected_grad_lhs, atol=atol, rtol=rtol) + self.assert_allclose(grad_rhs, expected_grad_rhs, atol=atol, rtol=rtol) + + @parameterized.parameters(*GROUPED_MATMUL_TESTS) + @hp.given(hps.data()) + def test_gmm( + self, + m: int, + k: int, + n: int, + data: hps.SearchStrategy[hps.DataObject], + ): + self.gmm_test(m, k, n, data) + + # NOTE: Run fewer tests with interpret mode. We just want to sanity check that + # changes do not break running these kernels with interpret=True. 
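# Illustrative sketch (not part of the patch above): the hypothesis setup used
# by this file.  The "deterministic" profile keeps CI runs reproducible, and
# `hps.data()` lets each test draw from strategies lazily inside its body.
import hypothesis as hp
import hypothesis.strategies as hps

hp.settings.register_profile(
    "deterministic",
    database=None,
    derandomize=True,
    deadline=None,
    max_examples=10,
    print_blob=True,
)
hp.settings.load_profile("deterministic")

@hp.given(hps.data())
def check_group_draw(data):
  # Mirrors group_strategy: num_groups is a multiple of the per-shard stride,
  # so every shard owns the same number of groups.
  stride = data.draw(hps.integers(min_value=1, max_value=8))
  num_groups = stride * data.draw(hps.integers(min_value=1, max_value=32 // stride))
  assert num_groups % stride == 0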
+ @parameterized.parameters(*GROUPED_MATMUL_TESTS[0:1]) + @hp.given(hps.data()) + def test_gmm_interpret( + self, + m: int, + k: int, + n: int, + data: hps.SearchStrategy[hps.DataObject], + ): + self.skipTest("interpret mode with dynamic grids is unsupported") + self.gmm_test( + m, + k, + n, + data=data, + interpret=True, + ) - def reference_fn(lhs, rhs, group_sizes, preferred_element_type): - rhs = rhs.swapaxes(1, 2) if transpose_rhs else rhs - return reference_gmm( - lhs, rhs, group_sizes, preferred_element_type=preferred_element_type - ) + @parameterized.parameters(*GROUPED_MATMUL_TESTS) + @hp.given(hps.data()) + def test_gmm_sharded_groups( + self, + m: int, + k: int, + n: int, + data: hps.SearchStrategy[hps.DataObject], + ): + seed = data.draw(seed_strategy()) + num_groups, group_stride = data.draw(group_strategy()) + lhs_dtype, rhs_dtype, out_dtype = ( + data.draw(hps.sampled_from([jnp.float32, jnp.bfloat16])) + for _ in range(3) + ) - expected_out, reference_vjpfun = jax.vjp( - partial(reference_fn, preferred_element_type=out_dtype), + key = jax.random.key(seed) + k1, k2 = jax.random.split(key, 2) + lhs = random_dense((m, k), k1, lhs_dtype, limit=1) + rhs = random_dense((num_groups, k, n), k2, rhs_dtype, limit=1) + group_sizes = data.draw(group_sizes_strategy(m=m, num_groups=num_groups)) + + out, shard_vjpfun = jax.vjp( + partial(mblx.gmm, preferred_element_type=out_dtype), + lhs, + rhs[0:group_stride], + group_sizes, + ) + vjpfuns = [shard_vjpfun] + for group_offset in range(group_stride, num_groups, group_stride): + out, shard_vjpfun = jax.vjp( + lambda lhs, rhs, group_sizes, out: mblx.gmm( + lhs, + rhs, + group_sizes, + out_dtype, + group_offset=jnp.array(group_offset, dtype=jnp.int32), # pylint: disable=cell-var-from-loop + existing_out=out, + ), lhs, - rhs.swapaxes(1, 2) if transpose_rhs else rhs, + rhs[group_offset : group_offset + group_stride], group_sizes, + out, ) - self.assertEqual(out.dtype, out_dtype) - self.assertEqual(expected_out.dtype, out_dtype) - - atol, rtol = tolerances(lhs_dtype, rhs_dtype, out_dtype) - self.assert_allclose(out, expected_out, atol=atol, rtol=rtol) - - cotangent = random_dense((m, n), k1, out_dtype, limit=1) - grad_lhs, grad_rhs, *_ = vjpfun(cotangent) - expected_grad_lhs, expected_grad_rhs, *_ = reference_vjpfun(cotangent) - self.assert_allclose(grad_lhs, expected_grad_lhs, atol=atol, rtol=rtol) - self.assert_allclose(grad_rhs, expected_grad_rhs, atol=atol, rtol=rtol) - - @parameterized.parameters(*GROUPED_MATMUL_TESTS) - @hp.given(hps.data()) - def test_gmm( - self, - m: int, - k: int, - n: int, - data: hps.SearchStrategy[hps.DataObject], - ): - self.gmm_test(m, k, n, data) - - # NOTE: Run fewer tests with interpret mode. We just want to sanity check that - # changes do not break running these kernels with interpret=True. 
- @parameterized.parameters(*GROUPED_MATMUL_TESTS[0:1]) - @hp.given(hps.data()) - def test_gmm_interpret( - self, - m: int, - k: int, - n: int, - data: hps.SearchStrategy[hps.DataObject], - ): - self.skipTest("interpret mode with dynamic grids is unsupported") - self.gmm_test( - m, - k, - n, - data=data, - interpret=True, - ) + vjpfuns.append(shard_vjpfun) - @parameterized.parameters(*GROUPED_MATMUL_TESTS) - @hp.given(hps.data()) - def test_gmm_sharded_groups( - self, - m: int, - k: int, - n: int, - data: hps.SearchStrategy[hps.DataObject], + expected_out, reference_vjpfun = jax.vjp( + partial(reference_gmm, preferred_element_type=out_dtype), + lhs, + rhs, + group_sizes, + ) + self.assertEqual(out.dtype, out_dtype) + self.assertEqual(expected_out.dtype, out_dtype) + atol, rtol = tolerances(lhs_dtype, rhs_dtype, out_dtype) + self.assert_allclose(out, expected_out, atol=atol, rtol=rtol) + + cotangent = random_dense((m, n), k1, out_dtype, limit=1) + shard_grad_lhs, shard_grad_rhs, *_ = vjpfuns[0](cotangent) + grad_lhs = shard_grad_lhs + grad_rhs = [shard_grad_rhs] + for i, group_offset in enumerate( + range(group_stride, num_groups, group_stride) ): - seed = data.draw(seed_strategy()) - num_groups, group_stride = data.draw(group_strategy()) - lhs_dtype, rhs_dtype, out_dtype = [ - data.draw(hps.sampled_from([jnp.float32, jnp.bfloat16])) - for _ in range(3) - ] - - key = jax.random.key(seed) - k1, k2 = jax.random.split(key, 2) - lhs = random_dense((m, k), k1, lhs_dtype, limit=1) - rhs = random_dense((num_groups, k, n), k2, rhs_dtype, limit=1) - group_sizes = data.draw(group_sizes_strategy(m=m, num_groups=num_groups)) - - out, shard_vjpfun = jax.vjp( - partial(mblx.gmm, preferred_element_type=out_dtype), - lhs, - rhs[0:group_stride], - group_sizes, - ) - vjpfuns = [shard_vjpfun] - for group_offset in range(group_stride, num_groups, group_stride): - out, shard_vjpfun = jax.vjp( - lambda lhs, rhs, group_sizes, out: mblx.gmm( - lhs, - rhs, - group_sizes, - out_dtype, - group_offset=jnp.array(group_offset, dtype=jnp.int32), # pylint: disable=cell-var-from-loop - existing_out=out, - ), - lhs, - rhs[group_offset : group_offset + group_stride], - group_sizes, - out, - ) - vjpfuns.append(shard_vjpfun) - - expected_out, reference_vjpfun = jax.vjp( - partial(reference_gmm, preferred_element_type=out_dtype), - lhs, - rhs, - group_sizes, - ) - self.assertEqual(out.dtype, out_dtype) - self.assertEqual(expected_out.dtype, out_dtype) - atol, rtol = tolerances(lhs_dtype, rhs_dtype, out_dtype) - self.assert_allclose(out, expected_out, atol=atol, rtol=rtol) - - cotangent = random_dense((m, n), k1, out_dtype, limit=1) - shard_grad_lhs, shard_grad_rhs, *_ = vjpfuns[0](cotangent) - grad_lhs = shard_grad_lhs - grad_rhs = [shard_grad_rhs] - for i, group_offset in enumerate( - range(group_stride, num_groups, group_stride) - ): - shard_grad_lhs, shard_grad_rhs, *_ = vjpfuns[i + 1](cotangent) - grad_lhs += shard_grad_lhs - grad_rhs.append(shard_grad_rhs) - grad_rhs = jnp.concatenate(grad_rhs, axis=0) - expected_grad_lhs, expected_grad_rhs, *_ = reference_vjpfun(cotangent) - self.assert_allclose(grad_lhs, expected_grad_lhs, atol=atol, rtol=rtol) - self.assert_allclose(grad_rhs, expected_grad_rhs, atol=atol, rtol=rtol) + shard_grad_lhs, shard_grad_rhs, *_ = vjpfuns[i + 1](cotangent) + grad_lhs += shard_grad_lhs + grad_rhs.append(shard_grad_rhs) + grad_rhs = jnp.concatenate(grad_rhs, axis=0) + expected_grad_lhs, expected_grad_rhs, *_ = reference_vjpfun(cotangent) + self.assert_allclose(grad_lhs, expected_grad_lhs, 
atol=atol, rtol=rtol) + self.assert_allclose(grad_rhs, expected_grad_rhs, atol=atol, rtol=rtol) if __name__ == "__main__": diff --git a/tests/pallas/tpu_ops_test.py b/tests/pallas/tpu_ops_test.py index c8def2627462..a67e74d617b6 100644 --- a/tests/pallas/tpu_ops_test.py +++ b/tests/pallas/tpu_ops_test.py @@ -15,7 +15,6 @@ import functools import math import sys -import unittest from absl.testing import absltest from absl.testing import parameterized @@ -32,26 +31,43 @@ else: pltpu = None -try: - import hypothesis as hp -except (ModuleNotFoundError, ImportError): - raise unittest.SkipTest("tests depend on hypothesis library") - +import hypothesis as hp import hypothesis.strategies as hps + jax.config.parse_flags_with_absl() jtu.setup_hypothesis(max_examples=100) -_JAX_DTYPES = ( +_JAX_DTYPES_NO_BOOL = ( jnp.float32, jnp.bfloat16, jnp.int32, jnp.int16, jnp.int8, + jnp.int4, + jnp.float8_e5m2, +) + +_JAX_DTYPES = ( + *_JAX_DTYPES_NO_BOOL, jnp.bool_, ) +def rand( + shape: tuple[int, ...], dtype: np.dtype | jnp.dtype, seed: int = 1234 +) -> np.ndarray: + """A helper function to generate random data for testing.""" + rng = np.random.Generator(np.random.Philox(counter=0, key=seed)) + if jnp.issubdtype(dtype, jnp.floating): + return rng.normal(size=shape).astype(dtype) + if jnp.issubdtype(dtype, jnp.integer): + return rng.integers( + jnp.iinfo(dtype).min, jnp.iinfo(dtype).max, shape, dtype=np.int32 + ).astype(dtype) + raise NotImplementedError(f"Unsupported random data generation for {dtype=}") + + class PallasBaseTest(jtu.JaxTestCase): INTERPRET = False @@ -66,6 +82,7 @@ def pallas_call(cls, *args, **kwargs): return pl.pallas_call(*args, interpret=cls.INTERPRET, **kwargs) +@jtu.thread_unsafe_test_class() # hypothesis is not thread safe class OpsTest(PallasBaseTest): @parameterized.product( @@ -180,8 +197,18 @@ def kernel(x_ref, y_ref, out_ref): def test_row_broadcast(self, dtype): if not jtu.if_cloud_tpu_at_least(2025, 1, 10): self.skipTest("Requires libtpu built after 2025-01-10") - if not self.INTERPRET and jtu.get_tpu_version() < 5: - self.skipTest("Requires TPUv5+") + bitwidth = pallas_utils.dtype_bitwidth(dtype) + if not self.INTERPRET and jtu.get_tpu_version() < 4 and bitwidth < 8: + self.skipTest("Requires TPUv4+ for sub-byte types") + if ( + not self.INTERPRET + and jtu.get_tpu_version() == 4 + and bitwidth < 16 + and not jtu.if_cloud_tpu_at_least(2025, 6, 2) + ): + self.skipTest( + "Requires libtpu built after 2025-06-02 for bitwidth < 16 on TPUv4" + ) def kernel(x_ref, y_ref): y_ref[...] 
= jnp.broadcast_to(x_ref[pl.ds(3, 1)], y_ref.shape).astype(y_ref.dtype) m, n = 4, 1152 @@ -490,7 +517,71 @@ def kernel(x, out): expected = dot(x[:], jnp.ones((1, d), jnp.bfloat16)) np.testing.assert_array_equal(output, expected) + # We need to manually run the test with the env variable + # `export LIBTPU_INIT_ARGS="--xla_jf_bounds_check=true"` + def test_disable_bounds_check(self): + if not jtu.if_cloud_tpu_at_least(2025, 4, 16): + self.skipTest("Requires libtpu built after 2025-04-16") + if jtu.get_tpu_version() < 4: + self.skipTest("Requires TPUv4+") + src_shape = (8, 128) + tgt_shape = (16, 256) + + def kernel(src, tgt): + tgt[:] = pl.load(src, tuple(pl.ds(0, d) for d in tgt.shape)) + + x = jnp.arange(np.prod(src_shape), dtype=jnp.float32).reshape(src_shape) + run = pl.pallas_call( + kernel, + jax.ShapeDtypeStruct(tgt_shape, jnp.float32), + compiler_params=pltpu.CompilerParams(disable_bounds_checks=True), + ) + output = run(x) + np.testing.assert_array_equal( + output[tuple(slice(0, d) for d in src_shape)], x + ) + + # TODO(jevinjiang): we need to support strided load for bool. + @parameterized.product(dtype=_JAX_DTYPES_NO_BOOL) + @hp.given( + slice_start=hps.integers(0, 3), + slice_size=hps.integers(1, 3), + m=hps.integers(1, 32), + # Need to make sure the 2nd minor has no padding. + n=hps.sampled_from([1, 2, 4, 8, 16, 24, 32]), + ) + @hp.settings(max_examples=20) # 20 examples for each dtype. + def test_load_to_reshape(self, dtype, slice_start, slice_size, m, n): + if not jtu.if_cloud_tpu_at_least(2025, 5, 15): + self.skipTest("Requires libtpu built after 2025-05-15") + bitwidth = pallas_utils.dtype_bitwidth(dtype) + if jtu.get_tpu_version() < 4 and bitwidth != 32: + self.skipTest("Requires TPUv4+ for non-32-bit types") + if jtu.get_tpu_version() == 4 and bitwidth <= 8: + self.skipTest("Int8 is not supported on this target") + packing = 32 // bitwidth + n *= packing + slices = ( + slice(slice_start, slice_start + slice_size), + slice(slice_start, slice_start + m), + slice(None), + slice(None), + ) + inp_shape = (8, 64, n, 128) + out_shape = (slice_size, m, n * 128) + + def kernel(inp_ref, out_ref): + inp = inp_ref[slices] + out_ref[...] 
= inp.reshape(out_shape) + + inp = rand(inp_shape, dtype, seed=1234) + run = pl.pallas_call(kernel, jax.ShapeDtypeStruct(out_shape, dtype)) + output = run(inp) + expected = inp[slices].reshape(out_shape) + np.testing.assert_array_equal(output, expected) + +@jtu.thread_unsafe_test_class() # hypothesis is not thread safe class OpsInterpretTest(OpsTest): INTERPRET = True diff --git a/tests/pallas/tpu_paged_attention_kernel_test.py b/tests/pallas/tpu_paged_attention_kernel_test.py index 7fbccdb338d4..ac24fea1b45a 100644 --- a/tests/pallas/tpu_paged_attention_kernel_test.py +++ b/tests/pallas/tpu_paged_attention_kernel_test.py @@ -18,19 +18,176 @@ from jax._src import test_util as jtu from jax.experimental.pallas.ops.tpu import paged_attention from jax.experimental.pallas.ops.tpu.paged_attention import quantization_utils +from jax.experimental.pallas.ops.tpu.paged_attention import util import jax.numpy as jnp import numpy as np -jax.config.parse_flags_with_absl() +def _generate_qkv_simplest( + dtype: jnp.dtype, +) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array, jax.Array]: + """Generates queries with one query head, kv pages, and attention.""" + max_seq_len = 4 + seq_lens = jnp.asarray([max_seq_len // 2]) + assert seq_lens.shape == (1,) + + # q_shape = (batch_size=1, num_q_heads=1, head_dim=1) + queries = jnp.asarray([[[1.2]]], dtype) + assert queries.shape == (1, 1, 1) + + # kv_shape = (batch_size=1, num_kv_heads=1, max_seq_len=4, head_dim=1) + k_pages = jnp.asarray([[[[0.1], [0.2], [0.3], [0.4]]]], dtype) + v_pages = jnp.asarray([[[[4.0], [3.0], [2.0], [1.0]]]], dtype) + assert k_pages.shape == (1, 1, 4, 1) + assert v_pages.shape == k_pages.shape + + # q*k: [[[ [.12, .24, .36, .48] ]]] + # masked: [[[ [.12, .24, -inf, -inf] ]]] + # softmax: [[[ [.47, .53, 0, 0] ]]] + # softmax(q*k) * v: .47*4 + .53*3 + 0*... = 3.47 + attention = jnp.asarray([[[3.47]]], dtype) + assert attention.shape == queries.shape + return seq_lens, queries, k_pages, v_pages, attention + + +def _generate_qkv_with_one_q_head( + dtype: jnp.dtype, +) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array, jax.Array]: + """Generates queries with one query head, kv pages, and attention.""" + max_seq_len = 4 + seq_lens = jnp.asarray([max_seq_len - 1]) + assert seq_lens.shape == (1,) + + # q_shape = (batch_size=1, num_q_heads=1, head_dim=1) + queries = jnp.asarray([[[1.7]]], dtype) + assert queries.shape == (1, 1, 1) + + # kv_shape = (batch_size=1, num_kv_heads=1, max_seq_len=4, head_dim=1) + k_pages = jnp.asarray([[[[0.12], [0.23], [0.34], [0.45]]]], dtype) + v_pages = jnp.asarray([[[[4.32], [3.21], [2.10], [1.09]]]], dtype) + assert k_pages.shape == (1, 1, 4, 1) + assert v_pages.shape == k_pages.shape + + # q*k: [[[ [.204, .391, .578, .765] ]]] + # masked: [[[ [.204, .391, .578, -inf] ]]] + # softmax: [[[ [.273, .330, .397, 0] ]]] + # softmax(q*k) * v: .273*4.32 + .330*3.21 + .397*2.10 + 0*... 
= 3.0723 + attention = jnp.asarray([[[3.0723]]], dtype) + assert attention.shape == queries.shape + return seq_lens, queries, k_pages, v_pages, attention + + +def _generate_qkv_with_two_q_heads( + dtype: jnp.dtype, +) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array, jax.Array]: + """Generates queries with two query heads, kv pages, and attention.""" + max_seq_len = 4 + seq_lens = jnp.asarray([max_seq_len]) + assert seq_lens.shape == (1,) + + # q_shape = (batch_size=1, num_q_heads=2, head_dim=1) + queries = jnp.asarray([[[1.3], [9.7]]], dtype) + assert queries.shape == (1, 2, 1) + + # kv_shape = (batch_size=1, num_kv_heads=1, max_seq_len=4, head_dim=1) + k_pages = jnp.asarray([[[[0.12], [0.23], [0.34], [0.45]]]], dtype) + v_pages = jnp.asarray([[[[4.32], [3.21], [2.10], [1.09]]]], dtype) + assert k_pages.shape == (1, 1, 4, 1) + assert v_pages.shape == k_pages.shape + + # q*k: [[[ [ .156, .299, .442, .585], + # [1.164, 2.231, 3.298, 4.365] ]]] + # softmax: [[[ [ .199, .230, .265, .306], + # [ .027, .079, .229, .665] ]]] + # softmax(q*k) * v: .199*4.32 + .230*3.21 + .265*2.10 + .306*1.09 = 2.488 + # softmax(q*k) * v: .027*4.32 + .079*3.21 + .229*2.10 + .665*1.09 = 1.576 + attention = jnp.asarray([[[2.488], [1.576]]], dtype) + assert attention.shape == queries.shape + return seq_lens, queries, k_pages, v_pages, attention + + +def _generate_qkv_with_head_dim_two( + dtype: jnp.dtype, +) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array, jax.Array]: + """Generates queries, kv pages, and attention with head_dim=2.""" + max_seq_len = 4 + seq_lens = jnp.asarray([max_seq_len // 2]) + assert seq_lens.shape == (1,) + + # q_shape = (batch_size=1, num_q_heads=1, head_dim=2) + queries = jnp.asarray([[[1.2, 9.0]]], dtype) + assert queries.shape == (1, 1, 2) + + # kv_shape = (batch_size=1, num_kv_heads=1, max_seq_len=4, head_dim=2) + k_pages = jnp.asarray( + [[[[0.1, 0.2], [0.2, 0.3], [0.3, 0.4], [0.4, 0.5]]]], dtype + ) + v_pages = jnp.asarray( + [[[[4.0, 5.0], [3.0, 6.0], [2.0, 7.0], [1.0, 8.0]]]], dtype + ) + assert k_pages.shape == (1, 1, 4, 2) + assert v_pages.shape == k_pages.shape + + # q*k: [[[ [ 1.92, 2.94, 3.96, 4.98] ]]] + # masked: [[[ [ 1.92, 2.94, -inf, -inf] ]]] + # softmax: [[[ [ .265, .735, 0, 0] ]]] + # softmax(q*k) * v: .265*4 + 0.735*3 + 0*... = 3.265 + # softmax(q*k) * v: .265*5 + 0.735*6 + 0*... 
= 5.735 + attention = jnp.asarray([[[3.265, 5.735]]], dtype) + assert attention.shape == queries.shape + return seq_lens, queries, k_pages, v_pages, attention def _generate_qkv( + dtype: jnp.dtype, + case: int, +) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array, jax.Array]: + match case: + case 0: + return _generate_qkv_simplest(dtype) + case 1: + return _generate_qkv_with_one_q_head(dtype) + case 2: + return _generate_qkv_with_two_q_heads(dtype) + case 3: + return _generate_qkv_with_head_dim_two(dtype) + case _: + raise ValueError(f"Unsupported case: {case}") + + +@jtu.with_config(jax_numpy_dtype_promotion="standard") +class JaxGroupedQueryAttentionReferenceTest(jtu.JaxTestCase): + + @parameterized.product( + dtype=(jnp.float32, jnp.bfloat16), + case=(0, 1, 2, 3), + ) + def test_grouped_query_attention(self, dtype: jnp.dtype, case: int): + # generate queries, kv pages, and seq_lens + seq_lens, queries, k_pages, v_pages, expected = _generate_qkv(dtype, case) + jax.debug.print("seq_lens: {seq_lens}", seq_lens=seq_lens) + jax.debug.print("queries: {queries}", queries=queries) + jax.debug.print("k_pages: {k_pages}", k_pages=k_pages) + jax.debug.print("v_pages: {v_pages}", v_pages=v_pages) + jax.debug.print("expected: {expected}", expected=expected) + + # calculate grouped query attention + attention = util.grouped_query_attention_reference( + queries, k_pages, v_pages, seq_lens + ) + jax.debug.print("attention: {attention}", attention=attention) + + # compare the results + atol, rtol = (3e-3, 5e-3) if dtype == jnp.bfloat16 else (2e-4, 2e-4) + self.assertAllClose(attention, expected, atol=atol, rtol=rtol) + + +def _generate_random_qkv( seq_lens, page_size, max_seq_len, num_kv_heads, - num_heads, + num_q_heads, head_dim, prng_key, dtype=jnp.float32, @@ -55,7 +212,7 @@ def _generate_qkv( page_indices = jnp.arange(batch_size * pages_per_sequence, dtype=jnp.int32) page_indices = jax.random.permutation(k3, page_indices, independent=True) page_indices = page_indices.reshape(batch_size, pages_per_sequence) - q = jax.random.normal(k4, (batch_size, num_heads, head_dim), dtype=dtype) + q = jax.random.normal(k4, (batch_size, num_q_heads, head_dim), dtype=dtype) return q, k_pages, v_pages, page_indices @@ -64,7 +221,7 @@ def _reconstruct_kv(page_indices, pages): pages = quantization_utils.unquantize_from_int8(pages, dtype=jnp.float32) batch_size = page_indices.shape[0] - num_heads, _, _, head_dim = pages.shape + num_kv_heads, _, _, head_dim = pages.shape def per_sequence_page_gather(pages, page_indices): return jnp.take(pages, page_indices, 1) @@ -72,32 +229,7 @@ def per_sequence_page_gather(pages, page_indices): gathered = jax.vmap(per_sequence_page_gather, in_axes=(None, 0))( pages, page_indices ) - return gathered.reshape(batch_size, num_heads, -1, head_dim) - - -def _grouped_query_attention_reference(q, k, v, lengths, attn_logits_soft_cap): - batch_size, num_heads, head_dim = q.shape - _, num_kv_heads, max_seq_len, _ = k.shape - assert k.shape == v.shape - assert num_heads % num_kv_heads == 0 - q = q.reshape(batch_size, num_kv_heads, num_heads // num_kv_heads, head_dim) - - if isinstance(k, quantization_utils.QuantizedTensor): - k = quantization_utils.unquantize_from_int8(k, dtype=jnp.float32) - if isinstance(v, quantization_utils.QuantizedTensor): - v = quantization_utils.unquantize_from_int8(v, dtype=jnp.float32) - - logits = jnp.einsum( - "bhgd,bhtd->bhgt", q.astype(jnp.float32), k.astype(jnp.float32) - ) - if attn_logits_soft_cap is not None: - logits = jnp.tanh(logits / attn_logits_soft_cap) 
* attn_logits_soft_cap - mask = jnp.arange(max_seq_len)[None] < lengths[:, None] - mask_value = -0.7 * float(np.finfo(np.dtype("float32")).max) - logits = logits + jnp.where(mask, 0.0, mask_value)[:, None, None, :] - weights = jax.nn.softmax(logits, axis=-1) - o = jnp.einsum("bhgt,bhtd->bhgd", weights.astype(v.dtype), v) - return o.reshape(batch_size, num_heads, head_dim) + return gathered.reshape(batch_size, num_kv_heads, -1, head_dim) def _megacore_enabled(): @@ -149,7 +281,7 @@ def test_paged_attention( max_kv_len = 2048 block_size = 512 seq_lens = np.asarray([0, 3, 256, 513, 1023, 2048]) - q, k_pages, v_pages, page_indices = _generate_qkv( + q, k_pages, v_pages, page_indices = _generate_random_qkv( seq_lens, page_size, max_kv_len, @@ -172,8 +304,9 @@ def test_paged_attention( ) k = _reconstruct_kv(page_indices, k_pages) v = _reconstruct_kv(page_indices, v_pages) - o_ref = _grouped_query_attention_reference( - q, k, v, seq_lens, attn_logits_soft_cap) + o_ref = util.grouped_query_attention_reference( + q, k, v, seq_lens, attn_logits_soft_cap + ) if q_kv_head_ratio > 1: atol, rtol = 1e-2, 2e-2 @@ -188,4 +321,5 @@ def test_paged_attention( if __name__ == "__main__": + jax.config.config_with_absl() absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/pallas/tpu_pallas_async_test.py b/tests/pallas/tpu_pallas_async_test.py index 3dfc9bf1637a..68c51b63c183 100644 --- a/tests/pallas/tpu_pallas_async_test.py +++ b/tests/pallas/tpu_pallas_async_test.py @@ -22,7 +22,7 @@ from jax._src import test_util as jtu from jax._src.state import discharge as state_discharge from jax.experimental import pallas as pl -from jax.experimental import shard_map +from jax._src import shard_map from jax.experimental.pallas import tpu as pltpu import jax.numpy as jnp import numpy as np @@ -204,6 +204,9 @@ def setUp(self): super().setUp() if not jtu.is_device_tpu_at_least(4): self.skipTest('DMAs only guaranteed to work ou TPU v4+') + # TODO(subhankarshah): Remove after all required changes are in. 
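# Illustrative sketch (not part of the patch above): the hand-computed numbers
# in the `_generate_qkv_*` helpers can be reproduced with a few lines of NumPy.
# For the simplest case (one head, head_dim=1, seq_len=2 of a length-4 page):
import numpy as np

q = np.float32(1.2)
k = np.array([0.1, 0.2, 0.3, 0.4], np.float32)
v = np.array([4.0, 3.0, 2.0, 1.0], np.float32)
seq_len = 2

logits = q * k                       # [0.12, 0.24, 0.36, 0.48]
logits[seq_len:] = -np.inf           # mask out positions beyond seq_len
weights = np.exp(logits) / np.exp(logits).sum()  # ~[0.47, 0.53, 0.0, 0.0]
out = weights @ v                    # ~3.47, matching `attention` in _generate_qkv_simplest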
+ if not jtu.if_cloud_tpu_at_least(2025, 6, 30): + self.skipTest('Requires libtpu built after 2025-06-20') def test_basic_async_copy(self): @jax.jit @@ -398,7 +401,7 @@ def copy_start(x: jax.Array) -> tuple[jax.Array, Future]: def copy_start_kernel(x_ref, aliased_x_ref, o_ref, send_sem, recv_sem): del aliased_x_ref - axis_size = jax.lax.psum(1, axis_name) + axis_size = jax.lax.axis_size(axis_name) left_neighbor = jax.lax.rem( jax.lax.axis_index(axis_name) - 1 + axis_size, axis_size ) @@ -412,7 +415,7 @@ def copy_start_kernel(x_ref, aliased_x_ref, o_ref, send_sem, recv_sem): src_neighbor = right_neighbor dst_neighbor = left_neighbor barrier_sem = pltpu.get_barrier_semaphore() - pltpu.semaphore_signal(barrier_sem, device_id=src_neighbor, core_index=0) + pltpu.semaphore_signal(barrier_sem, device_id=src_neighbor) pltpu.semaphore_wait(barrier_sem, 1) pltpu.make_async_remote_copy( x_ref, o_ref, send_sem, recv_sem, device_id=dst_neighbor, @@ -436,7 +439,7 @@ def copy_start_kernel(x_ref, aliased_x_ref, o_ref, send_sem, recv_sem): pl.BlockSpec(memory_space=pltpu.SEMAPHORE), ), input_output_aliases={0: 0}, - compiler_params=pltpu.TPUCompilerParams( + compiler_params=pltpu.CompilerParams( collective_id=0, has_side_effects=True ), )(x) @@ -492,7 +495,7 @@ def copy_start(x: jax.Array) -> tuple[jax.Array, Future]: def copy_start_kernel(x_ref, aliased_x_ref, o_ref, left_sems, right_sems): del aliased_x_ref - axis_size = jax.lax.psum(1, axis_name) + axis_size = jax.lax.axis_size(axis_name) left_neighbor = jax.lax.rem( jax.lax.axis_index(axis_name) - 1 + axis_size, axis_size ) @@ -500,10 +503,8 @@ def copy_start_kernel(x_ref, aliased_x_ref, o_ref, left_sems, right_sems): jax.lax.axis_index(axis_name) + 1, axis_size ) barrier_sem = pltpu.get_barrier_semaphore() - pltpu.semaphore_signal(barrier_sem, device_id=left_neighbor, core_index=0) - pltpu.semaphore_signal( - barrier_sem, device_id=right_neighbor, core_index=0 - ) + pltpu.semaphore_signal(barrier_sem, device_id=left_neighbor) + pltpu.semaphore_signal(barrier_sem, device_id=right_neighbor) pltpu.semaphore_wait(barrier_sem, 2) assert x.shape[0] % 2 == 0, x.shape pltpu.make_async_remote_copy( @@ -539,7 +540,7 @@ def copy_start_kernel(x_ref, aliased_x_ref, o_ref, left_sems, right_sems): (pl.BlockSpec(memory_space=pltpu.SEMAPHORE),) * 2, ), input_output_aliases={0: 0}, - compiler_params=pltpu.TPUCompilerParams( + compiler_params=pltpu.CompilerParams( collective_id=0, has_side_effects=False ), )(x) @@ -625,7 +626,7 @@ def test_basic_remote_copy(self): @jax.jit @partial( shard_map.shard_map, mesh=mesh, in_specs=(P('x'),), out_specs=P('x'), - check_rep=False, + check_vma=False, ) def f(x): copy_start, send_done, recv_done = make_async_remote_copy('x') @@ -648,7 +649,7 @@ def test_multi_remote_copy(self): @jax.jit @partial( shard_map.shard_map, mesh=mesh, in_specs=(P('x'),), out_specs=P('x'), - check_rep=False, + check_vma=False, ) def f(x): copy_start, send_done, recv_done = make_async_remote_copy( @@ -681,7 +682,7 @@ def test_basic_collective_permute_loop(self): @jax.jit @partial( shard_map.shard_map, mesh=mesh, in_specs=(P('x'),), out_specs=P('x'), - check_rep=False, + check_vma=False, ) def f(x): copy_start, send_done, recv_done = make_async_remote_copy('x') @@ -706,7 +707,7 @@ def test_staggered_collective_permute_loop(self): @jax.jit @partial( shard_map.shard_map, mesh=mesh, in_specs=(P('x'),), out_specs=P('x'), - check_rep=False, + check_vma=False, ) def f(x): assert x.shape[0] == 1 @@ -739,7 +740,7 @@ def test_bidi_collective_permute_loop(self): 
@jax.jit @partial( shard_map.shard_map, mesh=mesh, in_specs=(P('x'),), out_specs=P('x'), - check_rep=False, + check_vma=False, ) def f(x): assert x.shape[0] == 1 @@ -832,6 +833,9 @@ def setUp(self): super().setUp() if not jtu.is_device_tpu_at_least(4): self.skipTest('DMAs only guaranteed to work ou TPU v4+') + # TODO(subhankarshah): Remove after all required changes are in. + if not jtu.if_cloud_tpu_at_least(2025, 6, 30): + self.skipTest('Requires libtpu built after 2025-06-20') def test_basic_stateful_async_copy(self): @jax.jit diff --git a/tests/pallas/tpu_pallas_distributed_test.py b/tests/pallas/tpu_pallas_distributed_test.py index f7d7daf1874f..4b7bc06463bd 100644 --- a/tests/pallas/tpu_pallas_distributed_test.py +++ b/tests/pallas/tpu_pallas_distributed_test.py @@ -22,7 +22,7 @@ from jax._src import test_util as jtu from jax.experimental import mesh_utils from jax.experimental import pallas as pl -from jax.experimental import shard_map +from jax._src import shard_map from jax.experimental.pallas import tpu as pltpu import jax.numpy as jnp import numpy as np @@ -44,15 +44,14 @@ def setUp(self): self.skipTest('Only works with TPU v5e.') @parameterized.named_parameters( - ('vmem', pltpu.TPUMemorySpace.VMEM), - ('hbm', pltpu.TPUMemorySpace.ANY), + ('vmem', pltpu.VMEM), + ('hbm', pltpu.ANY), ) def test_basic_remote_vmem_dma(self, mem): # Implements very simple collective permute def kernel(x_ref, y_ref): def body(ready_sem, send_sem, recv_sem): - dev_id = pltpu.device_id() - other_dev_id = 1 - dev_id + other_dev_id = 1 - lax.axis_index('x') pltpu.semaphore_signal(ready_sem, device_id=other_dev_id, device_id_type=pltpu.DeviceIdType.LOGICAL) pltpu.semaphore_wait(ready_sem) @@ -77,19 +76,66 @@ def body(x): kernel, in_specs=[pl.BlockSpec(memory_space=mem)], out_specs=pl.BlockSpec(memory_space=mem), - out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32), + out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32, vma=frozenset('x')), )(x) devices = jax.devices()[:2] mesh = jax.sharding.Mesh(devices, ['x']) - y = jax.jit( + f = jax.jit( shard_map.shard_map( - body, mesh, in_specs=P('x'), out_specs=P('x'), check_rep=False + body, mesh=mesh, in_specs=P('x'), out_specs=P('x'), ) - )(x) + ) + jaxpr = f.trace(x).jaxpr + self.assertNotIn('pvary', str(jaxpr)) + y = f(x) expected = jnp.concatenate([x[8:], x[:8]]) np.testing.assert_allclose(y, expected) + def test_vma_error(self): + def kernel(x_ref, y_ref): + def body(ready_sem, send_sem, recv_sem): + other_dev_id = 1 - lax.axis_index('x') + pltpu.semaphore_signal(ready_sem, device_id=other_dev_id, + device_id_type=pltpu.DeviceIdType.LOGICAL) + pltpu.semaphore_wait(ready_sem) + copy_done = pltpu.async_remote_copy( + x_ref, y_ref, send_sem, recv_sem, other_dev_id, + device_id_type=pltpu.DeviceIdType.LOGICAL, + ) + copy_done.wait_send() + copy_done.wait_recv() + + pl.run_scoped( + body, + pltpu.SemaphoreType.REGULAR, + pltpu.SemaphoreType.DMA, + pltpu.SemaphoreType.DMA, + ) + + x = jnp.arange(2 * 8 * 128.0).reshape((2 * 8, 128)) + + def body(x): + return pl.pallas_call( + kernel, + in_specs=[pl.BlockSpec(memory_space=pltpu.ANY)], + out_specs=pl.BlockSpec(memory_space=pltpu.ANY), + out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32), + )(x) + + devices = jax.devices()[:2] + mesh = jax.sharding.Mesh(devices, ['x']) + f = jax.jit( + shard_map.shard_map( + body, mesh=mesh, in_specs=P('x'), out_specs=P('x'), + ) + ) + with self.assertRaisesRegex( + ValueError, + 'When `check_vma=True` on `jax.shard_map`, `vma` on' + ' `jax.ShapeDtypeStruct` must not be 
`None`'): + f(x) + @parameterized.named_parameters( ('left', 'left'), ('right', 'right') @@ -99,7 +145,7 @@ def test_pallas_call_axis_index(self, direction): def kernel(x_ref, y_ref): def body(ready_sem, send_sem, recv_sem): my_id = lax.axis_index('x') - num_devices = lax.psum(1, 'x') + num_devices = lax.axis_size('x') if direction == 'right': neighbor = lax.rem(my_id + 1, num_devices) else: @@ -127,8 +173,8 @@ def body(ready_sem, send_sem, recv_sem): def body(x): return pl.pallas_call( kernel, - in_specs=[pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM)], - out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM), + in_specs=[pl.BlockSpec(memory_space=pltpu.VMEM)], + out_specs=pl.BlockSpec(memory_space=pltpu.VMEM), out_shape=x, )(x) @@ -137,7 +183,7 @@ def body(x): mesh = jax.sharding.Mesh(device_mesh, ['x']) y = jax.jit( shard_map.shard_map( - body, mesh, in_specs=P('x'), out_specs=P('x'), check_rep=False + body, mesh=mesh, in_specs=P('x'), out_specs=P('x'), check_vma=False ) )(x) if direction == 'right': @@ -153,7 +199,7 @@ def kernel(x_ref, y_ref): def body(ready_sem, send_sem, recv_sem): my_id = lax.axis_index('x') my_other_id = lax.axis_index('y') - axis_size = lax.psum(1, 'x') + axis_size = lax.axis_size('x') if direction == 'right': neighbor = lax.rem(my_id + 1, axis_size) else: @@ -181,8 +227,8 @@ def body(ready_sem, send_sem, recv_sem): def body(x): return pl.pallas_call( kernel, - in_specs=[pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM)], - out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM), + in_specs=[pl.BlockSpec(memory_space=pltpu.VMEM)], + out_specs=pl.BlockSpec(memory_space=pltpu.VMEM), out_shape=x, )(x) @@ -193,10 +239,10 @@ def body(x): y = jax.jit( shard_map.shard_map( body, - mesh, + mesh=mesh, in_specs=P('x', None), out_specs=P('x', None), - check_rep=False, + check_vma=False, ) )(x) if direction == 'right': @@ -209,7 +255,7 @@ def test_barrier_semaphore(self): def kernel(x_ref, y_ref): def body(ready_sem, send_sem, recv_sem): my_id = lax.axis_index('x') - num_devices = lax.psum(1, 'x') + num_devices = lax.axis_size('x') neighbor = lax.rem(my_id + 1, num_devices) barrier_sem = pltpu.get_barrier_semaphore() pltpu.semaphore_signal(barrier_sem, device_id=neighbor) @@ -233,10 +279,10 @@ def body(ready_sem, send_sem, recv_sem): def body(x): return pl.pallas_call( kernel, - in_specs=[pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM)], - out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM), + in_specs=[pl.BlockSpec(memory_space=pltpu.VMEM)], + out_specs=pl.BlockSpec(memory_space=pltpu.VMEM), out_shape=x, - compiler_params=dict(mosaic=dict(collective_id=0)), + compiler_params=pltpu.CompilerParams(collective_id=0), )(x) device_mesh = mesh_utils.create_device_mesh( @@ -244,7 +290,7 @@ def body(x): mesh = jax.sharding.Mesh(device_mesh, ['x']) y = jax.jit( shard_map.shard_map( - body, mesh, in_specs=P('x'), out_specs=P('x'), check_rep=False + body, mesh=mesh, in_specs=P('x'), out_specs=P('x'), check_vma=False ) )(x) expected = jnp.concatenate([x[-8:], x[:-8]]) @@ -292,7 +338,7 @@ def test_kernel(x_ref, grid_spec = pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=0, in_specs=[ - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + pl.BlockSpec(memory_space=pltpu.ANY), ], scratch_shapes=( [pltpu.SemaphoreType.DMA] * 2 @@ -317,7 +363,7 @@ def test_kernel(x_ref, mesh=mesh, in_specs=P(None, 'x'), out_specs=P(None, 'x'), - check_rep=False)) + check_vma=False)) result = compiled_func(sharded_arr) perm = tuple((src, permute_fn(src)) for src in 
range(num_devices)) @@ -376,9 +422,9 @@ def _(): grid_spec = pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=0, in_specs=[ - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM), + pl.BlockSpec(memory_space=pltpu.VMEM), ], - out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM), + out_specs=pl.BlockSpec(memory_space=pltpu.VMEM), scratch_shapes=( [pltpu.SemaphoreType.DMA] * 2 ) @@ -403,7 +449,7 @@ def _(): mesh=mesh, in_specs=P(None, 'x'), out_specs=P(None, 'x'), - check_rep=False)) + check_vma=False)) result_interpret = compiled_func(sharded_arr) kernel = pl.pallas_call( @@ -416,7 +462,7 @@ def _(): mesh=mesh, in_specs=P(None, 'x'), out_specs=P(None, 'x'), - check_rep=False)) + check_vma=False)) result_noninterpret = compiled_func(sharded_arr) np.testing.assert_allclose(result_interpret, result_noninterpret, @@ -468,11 +514,11 @@ def _(): grid_spec = pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=0, in_specs=[ - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM), + pl.BlockSpec(memory_space=pltpu.VMEM), ], out_specs=[ - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM), - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM), + pl.BlockSpec(memory_space=pltpu.VMEM), + pl.BlockSpec(memory_space=pltpu.VMEM), ], scratch_shapes=( [pltpu.SemaphoreType.DMA] * 2 @@ -498,7 +544,7 @@ def _(): mesh=mesh, in_specs=P(None, 'x'), out_specs=P(None, 'x'), - check_rep=False)) + check_vma=False)) result_interpret = compiled_func(sharded_arr) kernel = pl.pallas_call( @@ -511,7 +557,7 @@ def _(): mesh=mesh, in_specs=P(None, 'x'), out_specs=P(None, 'x'), - check_rep=False)) + check_vma=False)) result_noninterpret = compiled_func(sharded_arr) np.testing.assert_allclose(result_interpret, result_noninterpret, @@ -569,7 +615,7 @@ def _(i, _): previous_config = jax.config.read('jax_pallas_dump_promela_to') jax.config.update('jax_pallas_dump_promela_to', tmpdir) shard_map.shard_map( - kernel, mesh=mesh, in_specs=P('x'), out_specs=P(None), check_rep=False + kernel, mesh=mesh, in_specs=P('x'), out_specs=P(None), check_vma=False )(jnp.ones((8, 128, 128), jnp.float32)) jax.config.update('jax_pallas_dump_promela_to', previous_config) self.assertNotEmpty(os.listdir(tmpdir)) diff --git a/tests/pallas/tpu_pallas_interpret_distributed_test.py b/tests/pallas/tpu_pallas_interpret_distributed_test.py index 518c16ed2109..62772bd7e298 100644 --- a/tests/pallas/tpu_pallas_interpret_distributed_test.py +++ b/tests/pallas/tpu_pallas_interpret_distributed_test.py @@ -18,8 +18,6 @@ contains only tests that use shard_map. """ -import functools - from absl.testing import absltest from absl.testing import parameterized @@ -28,7 +26,7 @@ from jax._src import test_util as jtu import jax._src.pallas.mosaic.interpret as mosaic_interpret from jax.experimental import pallas as pl -from jax.experimental import shard_map +from jax._src import shard_map from jax.experimental.pallas import tpu as pltpu import jax.numpy as jnp @@ -40,6 +38,9 @@ P = jax.sharding.PartitionSpec +# TODO(jburnim): Figure out how to safely run different instance of TPU +# interpret mode in parallel, and then remove this decorator. +@jtu.thread_unsafe_test_class() class InterpretDistributedTest(jtu.JaxTestCase): def setUp(self): super().setUp() @@ -91,11 +92,11 @@ def right_permute_kernel(input_ref, output_ref, send_sem, recv_sem): out_shape = jax.ShapeDtypeStruct((8, 128), jnp.float32) grid_spec = pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=0, - # TPUMemorySpace.ANY will (usually) place the tensor in HBM. 
+ # MemorySpace.ANY will (usually) place the tensor in HBM. in_specs=[ - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + pl.BlockSpec(memory_space=pltpu.ANY), ], - out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + out_specs=pl.BlockSpec(memory_space=pltpu.ANY), scratch_shapes=( # We allocate DMA semaphores in scratch memory. [pltpu.SemaphoreType.DMA] * 2 @@ -105,8 +106,8 @@ def right_permute_kernel(input_ref, output_ref, send_sem, recv_sem): right_permute_kernel, out_shape=out_shape, grid_spec=grid_spec, - compiler_params=pltpu.TPUCompilerParams(collective_id=13), - interpret=mosaic_interpret.TPUInterpretParams( + compiler_params=pltpu.CompilerParams(collective_id=13), + interpret=pltpu.InterpretParams( dma_execution_mode=dma_execution_mode, detect_races=detect_races), ) # Wrap the kernel within a shard_map to call. @@ -116,7 +117,7 @@ def right_permute_kernel(input_ref, output_ref, send_sem, recv_sem): mesh=mesh, in_specs=partition, out_specs=partition, - check_rep=False, + check_vma=False, ) )(input_arr) @@ -205,10 +206,10 @@ def _(): grid_spec = pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=0, in_specs=[ - # TPUMemorySpace.ANY will (usually) place the tensor in HBM. - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + # MemorySpace.ANY will (usually) place the tensor in HBM. + pl.BlockSpec(memory_space=pltpu.ANY), ], - out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + out_specs=pl.BlockSpec(memory_space=pltpu.ANY), scratch_shapes=( # DMA semaphores are allocated in scratch memory. # We allocated one semaphore for a local HBM-VMEM copy, @@ -227,9 +228,9 @@ def _(): all_gather_kernel, out_shape=out_shape, grid_spec=grid_spec, - interpret=mosaic_interpret.TPUInterpretParams( + interpret=pltpu.InterpretParams( dma_execution_mode=dma_execution_mode, detect_races=detect_races), - compiler_params=pltpu.TPUCompilerParams(collective_id=0), + compiler_params=pltpu.CompilerParams(collective_id=0), ) # Wrap the kernel within a shard_map to call. 
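# Illustrative sketch (not part of the patch above): the API renames applied
# throughout these test files, shown together on a trivial kernel.
#   pltpu.TPUMemorySpace.VMEM / .ANY -> pltpu.VMEM / pltpu.ANY
#   pltpu.TPUCompilerParams          -> pltpu.CompilerParams
#   mosaic_interpret.TPUInterpretParams -> pltpu.InterpretParams
#   lax.psum(1, 'x')                 -> lax.axis_size('x')
#   shard_map(..., check_rep=False)  -> shard_map(..., mesh=..., check_vma=False)
# Mesh size, axis name, and shapes below are arbitrary.
import jax
import jax.numpy as jnp
from jax import lax
from jax._src import shard_map
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu

P = jax.sharding.PartitionSpec

def kernel(x_ref, o_ref):
  o_ref[...] = x_ref[...] + 1.0

def body(x):
  n = lax.axis_size('x')  # new spelling of lax.psum(1, 'x')
  out = pl.pallas_call(
      kernel,
      in_specs=[pl.BlockSpec(memory_space=pltpu.VMEM)],
      out_specs=pl.BlockSpec(memory_space=pltpu.VMEM),
      out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
      interpret=pltpu.InterpretParams(detect_races=True),
  )(x)
  return out * n

mesh = jax.make_mesh((jax.device_count(),), ('x',))
f = jax.jit(shard_map.shard_map(
    body, mesh=mesh, in_specs=P('x'), out_specs=P('x'), check_vma=False))
x = jnp.ones((jax.device_count() * 8, 128), jnp.float32)
y = f(x)  # per shard: (x + 1) * device_count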
@@ -239,7 +240,7 @@ def _(): mesh=mesh, in_specs=partition, out_specs=partition, - check_rep=False + check_vma=False ) )(input_arr) @@ -367,13 +368,13 @@ def _(): num_scalar_prefetch=0, in_specs=[ # Our input lives in VMEM - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM), + pl.BlockSpec(memory_space=pltpu.VMEM), ], out_specs=[ # Our output lives in VMEM - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM), + pl.BlockSpec(memory_space=pltpu.VMEM), # Our double-buffer lives in HBM - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + pl.BlockSpec(memory_space=pltpu.ANY), ], grid=(num_devices,), scratch_shapes=( @@ -387,9 +388,9 @@ def _(): all_reduce_kernel, out_shape=out_shape, grid_spec=grid_spec, - interpret=mosaic_interpret.TPUInterpretParams( + interpret=pltpu.InterpretParams( dma_execution_mode=dma_execution_mode, detect_races=detect_races), - compiler_params=pltpu.TPUCompilerParams(collective_id=0), + compiler_params=pltpu.CompilerParams(collective_id=0), ) pallas_result = jax.jit( @@ -398,7 +399,7 @@ def _(): mesh=mesh, in_specs=partition, out_specs=partition, - check_rep=False, + check_vma=False, ) )(input_arr) pallas_result = jax.block_until_ready(pallas_result)[0] @@ -649,11 +650,11 @@ def _(): grid_spec = pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=0, in_specs=[ - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM), + pl.BlockSpec(memory_space=pltpu.VMEM), ], out_specs=[ - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM), - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + pl.BlockSpec(memory_space=pltpu.VMEM), + pl.BlockSpec(memory_space=pltpu.ANY), ], grid=(num_devices, 2), scratch_shapes=( @@ -671,9 +672,9 @@ def pallas_reduce_scatter(input_arr): reduce_scatter_kernel, out_shape=out_shape, grid_spec=grid_spec, - interpret=mosaic_interpret.TPUInterpretParams( + interpret=pltpu.InterpretParams( dma_execution_mode=dma_execution_mode, detect_races=True), - compiler_params=pltpu.TPUCompilerParams(collective_id=7), + compiler_params=pltpu.CompilerParams(collective_id=7), )(input_arr)[0] pallas_result = jax.jit( @@ -682,7 +683,7 @@ def pallas_reduce_scatter(input_arr): mesh=mesh, in_specs=P(None, 'x'), out_specs=P('x', None), - check_rep=False, + check_vma=False, ) )(input_arr) pallas_result = jax.block_until_ready(pallas_result) @@ -742,7 +743,7 @@ def test_reduce_scatter_sum_with_emit_pipeline_example( inner_block_spec = pl.BlockSpec( index_map=lambda i, j: (i, j), block_shape=inner_block_size, - memory_space=pltpu.TPUMemorySpace.ANY, + memory_space=pltpu.ANY, ) LEFT = 0 @@ -954,11 +955,11 @@ def _(): grid_spec = pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=0, in_specs=[ - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + pl.BlockSpec(memory_space=pltpu.ANY), ], out_specs=[ - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=pltpu.ANY), ], grid=(num_devices, 2), scratch_shapes=( @@ -975,9 +976,9 @@ def pallas_reduce_scatter(input_arr): reduce_scatter_kernel, out_shape=out_shape, grid_spec=grid_spec, - interpret=mosaic_interpret.TPUInterpretParams( + interpret=pltpu.InterpretParams( dma_execution_mode=dma_execution_mode, detect_races=detect_races), - compiler_params=pltpu.TPUCompilerParams(collective_id=19), + compiler_params=pltpu.CompilerParams(collective_id=19), )(input_arr)[0] pallas_result = jax.jit( @@ -986,7 +987,7 @@ def pallas_reduce_scatter(input_arr): mesh=mesh, in_specs=P(None, 'x'), out_specs=P('x', None), - 
check_rep=False, + check_vma=False, ) )(input_arr) pallas_result = jax.block_until_ready(pallas_result) @@ -1017,19 +1018,6 @@ def test_race_detection(self): input_arr = jax.device_put(input_arr, sharding) def kernel(src_dst_ids_ref, x_ref, o_ref, send_sem, recv_sem): - # Barrier with all devices before doing any DMAs. - barrier_sem = pltpu.get_barrier_semaphore() - @functools.partial(jax.lax.fori_loop, 0, num_devices, init_val=None) - def _(i, _): - pltpu.semaphore_signal( - barrier_sem, - inc=1, - device_id=(jnp.int32(i),), - device_id_type=pltpu.DeviceIdType.MESH, - ) - return None - pltpu.semaphore_wait(barrier_sem, num_devices) - # Send the specified DMAs. my_id = lax.axis_index('x') src_dst_ids = src_dst_ids_ref[:] @@ -1071,13 +1059,12 @@ def run(src_dst_ids): kernel, out_shape=jax.ShapeDtypeStruct((8, 128), input_arr.dtype), in_specs=[ - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM), - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + pl.BlockSpec(memory_space=pltpu.SMEM), + pl.BlockSpec(memory_space=pltpu.ANY), ], - out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + out_specs=pl.BlockSpec(memory_space=pltpu.ANY), scratch_shapes=[pltpu.SemaphoreType.DMA, pltpu.SemaphoreType.DMA], - compiler_params=pltpu.TPUCompilerParams(collective_id=0), - interpret=mosaic_interpret.TPUInterpretParams( + interpret=pltpu.InterpretParams( dma_execution_mode='eager', detect_races=True, ), @@ -1085,7 +1072,7 @@ def run(src_dst_ids): mesh=mesh, in_specs=(P(None), P('x', None)), out_specs=P('x', None), - check_rep=False, + check_vma=False, )(src_dst_ids, input_arr) run(jnp.array([[0, 1], [1, 2], [2, 3]], jnp.int32)).block_until_ready() @@ -1095,6 +1082,67 @@ def run(src_dst_ids): run(jnp.array([[0, 1], [1, 2], [3, 2], [3, 0]], jnp.int32)).block_until_ready() self.assertTrue(mosaic_interpret.races.races_found) + @parameterized.parameters(1, 2, 4) + def test_shard_map_of_core_map(self, num_cores): + num_devices = jax.device_count() + partition = P('x', None) + mesh = jax.make_mesh((num_devices,), ('x',)) + sharding = jax.sharding.NamedSharding(mesh, partition) + + core_mesh = pltpu.create_tensorcore_mesh('core', num_cores=num_cores) + interpret = pltpu.InterpretParams(detect_races=True) + + @jax.jit + def f(x): + y = jnp.zeros_like(x) + def inner(refs): + x_ref, y_ref = refs + @pl.core_map(core_mesh, interpret=interpret) + def _(): + num_cores = jax.lax.axis_size('core') + slc_size = 16 // num_cores + def alloc(x_vmem_ref, y_vmem_ref, dma_sem, sem): + # Barrier so we deadlock unless the core_map is actually parallel. + for i in range(num_cores): + pl.semaphore_signal(sem, 1, core_index=i) + pl.semaphore_wait(sem, num_cores) + + core_index = jax.lax.axis_index('core') + slc = pl.ds(core_index * slc_size, slc_size) + pltpu.async_copy( + x_ref.at[slc], + x_vmem_ref, + dma_sem, + ).wait() + y = (x_vmem_ref[...] + num_cores * jax.lax.axis_index('x') + + core_index + 1) + y_vmem_ref[...] 
= y + pltpu.async_copy(y_vmem_ref, y_ref.at[slc], dma_sem).wait() + pl.run_scoped( + alloc, + pltpu.VMEM((slc_size, 128), x_ref.dtype), + pltpu.VMEM((slc_size, 128), y_ref.dtype), + pltpu.SemaphoreType.DMA, + pltpu.SemaphoreType.REGULAR, + ) + _, y = pl.run_state(inner)((x, y)) + return y + + x = jnp.arange(num_devices * 16 * 128, dtype=jnp.int32).reshape((-1, 128)) + y = jax.jit( + shard_map.shard_map(f, + mesh=mesh, + in_specs=partition, + out_specs=partition, + check_vma=False, + ) + )(x).block_until_ready() + expected_out = ( + x.reshape((num_devices, num_cores, -1, 128)) + 1 + + jnp.arange(num_devices, dtype=jnp.int32)[..., None, None, None] * num_cores + + jnp.arange(num_cores, dtype=jnp.int32)[None, ..., None, None] + ).reshape(x.shape) + np.testing.assert_array_equal(y, expected_out) if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/pallas/tpu_pallas_interpret_test.py b/tests/pallas/tpu_pallas_interpret_test.py index bc589855b836..725fdaa49c3d 100644 --- a/tests/pallas/tpu_pallas_interpret_test.py +++ b/tests/pallas/tpu_pallas_interpret_test.py @@ -18,23 +18,88 @@ contains only tests that do not use shard_map. """ +from collections.abc import Callable +import dataclasses +import functools +import threading + from absl.testing import absltest from absl.testing import parameterized - import jax from jax._src import test_util as jtu import jax._src.pallas.mosaic.interpret as mosaic_interpret from jax.experimental import pallas as pl from jax.experimental.pallas import tpu as pltpu import jax.numpy as jnp - import numpy as np jax.config.parse_flags_with_absl() +jax.config.update('jax_threefry_partitionable', True) + + +class CountStoreCallbacksContext: + """Wraps the I/O callback `store` into a callback that counts the number of calls to `store`.""" + + def __init__(self): + self._num_stores = 0 + self._saved = mosaic_interpret.store + + def __enter__(self): + def _store_callback(self, *args, **kwargs): + self._num_stores += 1 + return self._saved(*args, **kwargs) + + mosaic_interpret.store = functools.partial(_store_callback, self) + return self + + def __exit__(self, ty, value, traceback): + del ty, value, traceback + mosaic_interpret.store = self._saved + + @property + def num_stores(self): + return self._num_stores + + +@dataclasses.dataclass(frozen=True) +class ProcessedGridPoint(): + """Represents a grid point and the ID of the core that has processed it.""" + grid_point: tuple[int, ...] + core_id: int + + +class GridPointRecorderContext: + """Records grid points in the order in which they are procsessed.""" + + def __init__(self): + self._grid_points: list[ProcessedGridPoint] = [] + + def __enter__(self): + return self + def __exit__(self, ty, value, traceback): + ... + def get_recorder(self) -> Callable[[tuple[np.int32, ...], np.int32], None]: + def _recorder(grid_point, core_id): + processed_grid_point = ProcessedGridPoint( + tuple(int(coord) for coord in grid_point), int(core_id) + ) + self._grid_points.append(processed_grid_point) + + return _recorder + + @property + def grid_points(self) -> list[ProcessedGridPoint]: + return sorted(self._grid_points, key=lambda x: x.core_id) + + +# TODO(jburnim): Figure out how to safely run different instance of TPU +# interpret mode in parallel, and then remove this decorator. 
+@jtu.thread_unsafe_test_class() class InterpretTest(jtu.JaxTestCase): + def setUp(self): super().setUp() self.num_devices = jax.device_count() @@ -49,17 +114,18 @@ def matmul_kernel(x_ref, y_ref, z_ref): @jax.jit def matmul(x: jax.Array, y: jax.Array): return pl.pallas_call( - matmul_kernel, - out_shape=jax.ShapeDtypeStruct((x.shape[0], y.shape[1]), x.dtype), - grid=(2, 2), - in_specs=[ - pl.BlockSpec((x.shape[0] // 2, x.shape[1]), lambda i, j: (i, 0)), - pl.BlockSpec((y.shape[0], y.shape[1] // 2), lambda i, j: (0, j)) - ], - out_specs=pl.BlockSpec( - (x.shape[0] // 2, y.shape[1] // 2), lambda i, j: (i, j), - ), - interpret=mosaic_interpret.TPUInterpretParams(), + matmul_kernel, + out_shape=jax.ShapeDtypeStruct((x.shape[0], y.shape[1]), x.dtype), + grid=(2, 2), + in_specs=[ + pl.BlockSpec((x.shape[0] // 2, x.shape[1]), lambda i, j: (i, 0)), + pl.BlockSpec((y.shape[0], y.shape[1] // 2), lambda i, j: (0, j)), + ], + out_specs=pl.BlockSpec( + (x.shape[0] // 2, y.shape[1] // 2), + lambda i, j: (i, j), + ), + interpret=pltpu.InterpretParams(), )(x, y) k1, k2 = jax.random.split(jax.random.key(0)) @@ -68,11 +134,50 @@ def matmul(x: jax.Array, y: jax.Array): z = matmul(x, y) np.testing.assert_allclose(z, x @ y, atol=1e-4) + def test_scalar_prefetch_example(self): + def dynamic_slice_kernel(indices, x_ref, o_ref): + del indices + o_ref[...] = x_ref[...] + + @functools.partial(jax.jit, static_argnums=(2,)) + def block_dynamic_slice(x, starts, sizes): + grid_spec = pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=1, + grid=(1, 1), + in_specs=[ + pl.BlockSpec( + sizes, lambda i, j, block_idx: (block_idx[0], block_idx[1]) + ) + ], + out_specs=pl.BlockSpec(sizes, lambda *_: (0, 0)), + ) + + kernel = pl.pallas_call( + dynamic_slice_kernel, + grid_spec=grid_spec, + out_shape=jax.ShapeDtypeStruct(shape=sizes, dtype=x.dtype), + interpret=pltpu.InterpretParams(), + ) + block_idx = jnp.array([starts[0] // sizes[0], starts[1] // sizes[1]]) + return kernel(block_idx, x) + + shape = (512, 512) + x = jnp.reshape(jnp.arange(np.prod(shape), dtype=jnp.int32), shape) + result = block_dynamic_slice( + x, starts=jnp.array([128, 256]), sizes=(128, 128) + ) + ref = jax.lax.dynamic_slice( + x, start_indices=(128, 256), slice_sizes=(128, 128) + ) + diff = jnp.max(jnp.abs(result - ref)) + np.testing.assert_allclose(result, ref) + def test_dynamic_grid_and_aliasing(self): def kernel(s_ref, x_ref, o_ref): o_ref[...] = x_ref[...] + s_ref[0].astype(x_ref.dtype) iters = jax.random.randint(jax.random.key(0), (), 10, 20, dtype=jnp.int32) + @jax.jit def f(s, x): return pl.pallas_call( @@ -80,19 +185,62 @@ def f(s, x): out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype), grid=(iters,), in_specs=[ - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM), + pl.BlockSpec(memory_space=pltpu.SMEM), pl.BlockSpec(x.shape, lambda i: (0, 0)), ], out_specs=pl.BlockSpec(x.shape, lambda i: (0, 0)), input_output_aliases={1: 0}, - interpret=mosaic_interpret.TPUInterpretParams() + interpret=pltpu.InterpretParams(), )(s, x) s = jnp.array([1], dtype=jnp.int32) - x = jnp.arange(32 * 128.).reshape((32, 128)) + x = jnp.arange(32 * 128.0).reshape((32, 128)) y = f(s, x) + # NOTE: No matter how many times the kernel body is run, the kernel input + # buffer will only be written once by the pallas_call machinery, just + # before the first iteration. So the output will be x + 1 , despite the + # aliasing in HBM. 
np.testing.assert_allclose(y, x + 1.0) + def test_aliasing(self): + def kernel(x_ref, o_ref, s_ref): + @pl.when((pl.program_id(0) == 0) & (pl.program_id(1) == 0)) + def _(): + s_ref[0] = jnp.int32(0) + + s = s_ref[0] + s_ref[0] = s + 1 + o_ref[:] = x_ref[:] + s.astype(x_ref.dtype) + + x = jnp.zeros((4 * 8, 4 * 128)) + y = pl.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype), + grid=(4, 4), + in_specs=[ + pl.BlockSpec(block_shape=(8, 128), index_map=lambda i, j: (i, j)), + ], + out_specs=pl.BlockSpec( + block_shape=(8, 128), index_map=lambda i, j: (j, i) + ), + scratch_shapes=(pltpu.SMEM((1,), jnp.int32),), + input_output_aliases={0: 0}, + interpret=pltpu.InterpretParams(), + )(x) + + expected = np.zeros((4, 4)) + t = 0 + for i in range(4): + for j in range(4): + expected[j, i] = expected[i, j] + t + t += 1 + # NOTE: expected is + # [[0, 5, 10, 15], + # [1, 5, 15, 20], + # [2, 6, 10, 25], + # [3, 7, 11, 15]] + np.testing.assert_allclose(y[::8, ::128], expected) + @parameterized.parameters('eager', 'on_wait') def test_race_detection(self, dma_execution_mode): def kernel_without_race(x_ref, o_ref, t_ref, sem): @@ -109,28 +257,32 @@ def kernel_with_race(x_ref, o_ref, t_ref, sem): copy.wait() x = jnp.zeros((8, 128), jnp.float32) - y = pl.pallas_call(kernel_without_race, + y = pl.pallas_call( + kernel_without_race, out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype), - in_specs=[pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY)], + in_specs=[pl.BlockSpec(memory_space=pltpu.ANY)], scratch_shapes=[ pltpu.VMEM(x.shape, x.dtype), pltpu.SemaphoreType.DMA, ], - interpret=mosaic_interpret.TPUInterpretParams( - detect_races=True, dma_execution_mode=dma_execution_mode), + interpret=pltpu.InterpretParams( + detect_races=True, dma_execution_mode=dma_execution_mode + ), )(x).block_until_ready() self.assertFalse(mosaic_interpret.races.races_found) np.testing.assert_allclose(y, x + 1.0) - pl.pallas_call(kernel_with_race, + pl.pallas_call( + kernel_with_race, out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype), - in_specs=[pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY)], + in_specs=[pl.BlockSpec(memory_space=pltpu.ANY)], scratch_shapes=[ pltpu.VMEM(x.shape, x.dtype), pltpu.SemaphoreType.DMA, ], - interpret=mosaic_interpret.TPUInterpretParams( - detect_races=True, dma_execution_mode=dma_execution_mode), + interpret=pltpu.InterpretParams( + detect_races=True, dma_execution_mode=dma_execution_mode + ), )(x).block_until_ready() self.assertTrue(mosaic_interpret.races.races_found) @@ -142,7 +294,7 @@ def matmul(x: jax.Array, y: jax.Array): return pl.pallas_call( matmul_kernel, out_shape=jax.ShapeDtypeStruct((x.shape[0], y.shape[1]), x.dtype), - interpret=mosaic_interpret.TPUInterpretParams( + interpret=pltpu.InterpretParams( skip_floating_point_ops=True ), )(x, y) @@ -153,8 +305,8 @@ def matmul(x: jax.Array, y: jax.Array): z = jax.jit(matmul)(x, y) np.testing.assert_array_equal(z, jnp.full_like(z, jnp.inf)) - lowered = jax.jit(matmul).lower(x, y).as_text(dialect="stablehlo") - self.assertNotIn("dot_general", lowered) + lowered = jax.jit(matmul).lower(x, y).as_text(dialect='stablehlo') + self.assertNotIn('dot_general', lowered) @parameterized.parameters('nan', 'zero') def test_uninitialized_memory(self, uninitialized_memory): @@ -174,8 +326,9 @@ def kernel(o1_ref, o2_ref, o3_ref, t1_ref, t2_ref): pltpu.VMEM((8, 128), jnp.bfloat16), pltpu.VMEM((8, 128), jnp.int16), ], - interpret=mosaic_interpret.TPUInterpretParams( - uninitialized_memory=uninitialized_memory), + 
interpret=pltpu.InterpretParams( + uninitialized_memory=uninitialized_memory + ), )() if uninitialized_memory == 'nan': self.assertTrue(jnp.isnan(x).all()) @@ -186,6 +339,534 @@ def kernel(o1_ref, o2_ref, o3_ref, t1_ref, t2_ref): np.testing.assert_equal(np.array(y), 0) np.testing.assert_equal(np.array(z), 0) + def test_correct_number_of_stores(self): + def kernel(x_ref, s_ref, o_ref): + s = s_ref[0] + x_ref[:] += jax.lax.full_like(x_ref, s) + s_ref[0] = s + 1 + o_ref[:] = x_ref[:] + + def kernel_call(x, s): + return pl.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct((16, 256), jnp.float32), + grid=(2, 2), + in_specs=[ + pl.BlockSpec((8, 256), lambda i, j: (i, 0)), + pl.BlockSpec(memory_space=pltpu.SMEM), + ], + out_specs=pl.BlockSpec((8, 256), lambda i, j: (i, 0)), + interpret=pltpu.InterpretParams(), + )(x, s) + + with CountStoreCallbacksContext() as store_callbacks_counter: + result = jax.jit(kernel_call)( + jnp.zeros((16, 256), jnp.float32), jnp.zeros((1,), jnp.int32) + ) + np.testing.assert_allclose(result[::8, ::256], [[1.0], [5.0]]) + self.assertEqual(store_callbacks_counter.num_stores, 5) + + def test_randomization_of_parallel_dimensions(self): + def kernel(s_ref, o_ref): + s = s_ref[0] + s_ref[0] = s + 1 + o_ref[:] = jax.lax.full_like(o_ref, s) + + def kernel_call_dimensions_parallel_arbitrary(s, grid_point_recorder): + return pl.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct((32, 512), jnp.float32), + grid=(4, 4), + in_specs=[pl.BlockSpec(memory_space=pltpu.SMEM)], + out_specs=pl.BlockSpec((8, 128), lambda i, j: (i, j)), + interpret=pltpu.InterpretParams( + random_seed=12345, grid_point_recorder=grid_point_recorder + ), + compiler_params=pltpu.CompilerParams( + dimension_semantics=('parallel', 'arbitrary') + ), + )(s) + + with GridPointRecorderContext() as grid_point_recorder: + result = jax.jit( + kernel_call_dimensions_parallel_arbitrary, static_argnums=1 + )( + jnp.zeros((1,), jnp.int32), + grid_point_recorder.get_recorder(), + ) + np.testing.assert_allclose( + result[::8, ::128], + [ + [ 8.0, 9.0, 10.0, 11.0], + [12.0, 13.0, 14.0, 15.0], + [ 0.0, 1.0, 2.0, 3.0], + [ 4.0, 5.0, 6.0, 7.0], + ], + ) + self.assertListEqual( + grid_point_recorder.grid_points, + [ + ProcessedGridPoint((2, 0), 0), + ProcessedGridPoint((2, 1), 0), + ProcessedGridPoint((2, 2), 0), + ProcessedGridPoint((2, 3), 0), + ProcessedGridPoint((3, 0), 0), + ProcessedGridPoint((3, 1), 0), + ProcessedGridPoint((3, 2), 0), + ProcessedGridPoint((3, 3), 0), + ProcessedGridPoint((0, 0), 0), + ProcessedGridPoint((0, 1), 0), + ProcessedGridPoint((0, 2), 0), + ProcessedGridPoint((0, 3), 0), + ProcessedGridPoint((1, 0), 0), + ProcessedGridPoint((1, 1), 0), + ProcessedGridPoint((1, 2), 0), + ProcessedGridPoint((1, 3), 0), + ], + ) + + def test_dimensions_arbitrary_parallel_raises(self): + def kernel_call(s): + def kernel(s_ref, o_ref): + s = s_ref[0] + o_ref[0] = s + + return pl.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct((32, 512), jnp.float32), + grid=(4, 4), + in_specs=[pl.BlockSpec(memory_space=pltpu.SMEM)], + out_specs=pl.BlockSpec((8, 128), lambda i, j: (i, j)), + interpret=pltpu.InterpretParams(random_seed=12345), + compiler_params=pltpu.CompilerParams( + dimension_semantics=('arbitrary', 'parallel') + ), + )(s) + + with self.assertRaises(ValueError): + jax.jit(kernel_call)( + jnp.zeros((1,), jnp.int32), + ) + + def test_dynamic_parallel_dimension_raises(self): + def kernel(o_ref): + o_ref[0] = 42.0 + + @jax.jit + def kernel_call_dynamic_parallel_dimension(): + dim_size = 
jax.random.randint( + jax.random.key(0), (), 10, 20, dtype=jnp.int32 + ) + return pl.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct((1,), jnp.float32), + grid=(dim_size,), + in_specs=[], + out_specs=pl.BlockSpec((1,), lambda _: (0,)), + interpret=pltpu.InterpretParams(), + compiler_params=pltpu.CompilerParams( + dimension_semantics=('parallel',) + ), + )() + + with self.assertRaises(jax.errors.ConcretizationTypeError): + kernel_call_dynamic_parallel_dimension() + + @parameterized.parameters(1, 2, 4) + def test_core_map(self, num_cores): + mesh = pltpu.create_tensorcore_mesh('x', num_cores=num_cores) + interpret = pltpu.InterpretParams() + + @jax.jit + def f(x): + y = jnp.zeros_like(x) + def inner(refs): + x_ref, y_ref = refs + @pl.core_map(mesh, interpret=interpret) + def _(): + num_cores = jax.lax.axis_size('x') + slc_size = 16 // num_cores + def alloc(x_vmem_ref, y_vmem_ref, dma_sem, sem): + # Barrier so we deadlock unless the core_map is actually parallel. + for i in range(num_cores): + pl.semaphore_signal(sem, 1, core_index=i) + pl.semaphore_wait(sem, num_cores) + + core_index = jax.lax.axis_index('x') + slc = pl.ds(core_index * slc_size, slc_size) + pltpu.async_copy( + x_ref.at[slc], + x_vmem_ref, + dma_sem, + ).wait() + y = x_vmem_ref[...] + jax.lax.axis_index('x') + 1 + y_vmem_ref[...] = y + pltpu.async_copy(y_vmem_ref, y_ref.at[slc], dma_sem).wait() + pl.run_scoped( + alloc, + pltpu.VMEM((slc_size, 128), x_ref.dtype), + pltpu.VMEM((slc_size, 128), y_ref.dtype), + pltpu.SemaphoreType.DMA, + pltpu.SemaphoreType.REGULAR, + ) + _, y = pl.run_state(inner)((x, y)) + return y + x = jnp.arange(16 * 128, dtype=jnp.int32).reshape((16, 128)) + expected_out = ( + x.reshape((num_cores, -1, 128)) + 1 + + jnp.arange(num_cores, dtype=jnp.int32)[..., None, None] + ).reshape(x.shape) + y = f(x) + np.testing.assert_array_equal(y, expected_out) + + def test_two_cores_along_parallel_dimension_with_race(self): + def kernel(x_ref, o_ref, vmem_ref): + vmem_ref[...] = x_ref[...] + o_ref[...] = x_ref[...] + vmem_ref[...] + + x = jnp.ones((8, 128), jnp.float32) + y = pl.pallas_call( + kernel, + grid=(2,), + out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype), + in_specs=[pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY)], + scratch_shapes=[ + pltpu.VMEM(x.shape, x.dtype), + ], + interpret=pltpu.InterpretParams( + num_cores_per_device=2, + detect_races=True, + ), + compiler_params=pltpu.CompilerParams( + dimension_semantics=('parallel',), + ), + )(x).block_until_ready() + self.assertTrue(mosaic_interpret.races.races_found) + np.testing.assert_allclose(y, 2.0 * x) + + def test_two_cores_along_parallel_dimension_no_race(self): + def kernel(x_ref, o_ref, vmem_ref): + vmem_ref[...] = x_ref[...] + o_ref[...] = x_ref[...] + vmem_ref[...] + + x = jnp.ones((16, 128), jnp.float32) + y = pl.pallas_call( + kernel, + grid=(2,), + out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype), + out_specs=pl.BlockSpec( + (8, 128), + lambda i: (i, 0), + ), + in_specs=[ + pl.BlockSpec( + (8, 128), + lambda i: (i, 0), + ), + ], + scratch_shapes=[ + pltpu.VMEM((8, 128), x.dtype), + ], + interpret=pltpu.InterpretParams( + num_cores_per_device=2, + detect_races=True, + ), + compiler_params=pltpu.CompilerParams( + dimension_semantics=('parallel',) + ), + )(x).block_until_ready() + self.assertFalse(mosaic_interpret.races.races_found) + np.testing.assert_allclose(y, 2.0 * x) + + def test_parallel_dimension_and_multiple_cores(self): + def kernel(s_ref, in_ref, o_ref): + # NOTE: diff should be 0. + diff = in_ref[...] 
- jnp.float32(4 * pl.program_id(0) + pl.program_id(1)) + + s = s_ref[0] + s_ref[0] = s + 1 + o_ref[:] = jax.lax.full_like(o_ref, s) + diff + + def kernel_call(s, num_cores_per_device, grid_point_recorder): + block_input = jnp.repeat( + jnp.repeat( + jnp.arange(16, dtype=jnp.float32).reshape((4, 4)), 128, axis=1), + 8, axis=0) + return pl.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct((32, 512), jnp.float32), + grid=(4, 4), + in_specs=[ + pl.BlockSpec(memory_space=pltpu.SMEM), + pl.BlockSpec((8, 128), lambda i, j: (i, j)), + ], + out_specs=pl.BlockSpec((8, 128), lambda i, j: (i, j)), + interpret=pltpu.InterpretParams( + random_seed=12345, + num_cores_per_device=num_cores_per_device, + grid_point_recorder=grid_point_recorder, + detect_races=True, + ), + compiler_params=pltpu.CompilerParams( + dimension_semantics=('parallel', 'arbitrary') + ), + )(s, block_input) + + with self.subTest('num_cores_per_device=1'): + with GridPointRecorderContext() as grid_point_recorder: + result = jax.jit(kernel_call, static_argnums=(1, 2))( + jnp.zeros((1,), jnp.float32), 1, grid_point_recorder.get_recorder() + ) + np.testing.assert_allclose( + result[::8, ::128], + [ + [8.0, 9.0, 10.0, 11.0], + [12.0, 13.0, 14.0, 15.0], + [0.0, 1.0, 2.0, 3.0], + [4.0, 5.0, 6.0, 7.0], + ], + ) + self.assertListEqual( + grid_point_recorder.grid_points, + # parallel_subgrid_size = 4 + # num_parallel_points_per_core = (4 + 1 - 1) // 1 = 4 + # num_iterations_per_core = 4 * (16 // 4) = 16 + [ + ProcessedGridPoint((2, 0), 0), + ProcessedGridPoint((2, 1), 0), + ProcessedGridPoint((2, 2), 0), + ProcessedGridPoint((2, 3), 0), + ProcessedGridPoint((3, 0), 0), + ProcessedGridPoint((3, 1), 0), + ProcessedGridPoint((3, 2), 0), + ProcessedGridPoint((3, 3), 0), + ProcessedGridPoint((0, 0), 0), + ProcessedGridPoint((0, 1), 0), + ProcessedGridPoint((0, 2), 0), + ProcessedGridPoint((0, 3), 0), + ProcessedGridPoint((1, 0), 0), + ProcessedGridPoint((1, 1), 0), + ProcessedGridPoint((1, 2), 0), + ProcessedGridPoint((1, 3), 0), + ], + ) + + with self.subTest('num_cores_per_device=2'): + with GridPointRecorderContext() as grid_point_recorder: + result = jax.jit(kernel_call, static_argnums=(1, 2))( + jnp.zeros((1,), jnp.float32), 2, grid_point_recorder.get_recorder() + ) + np.testing.assert_allclose( + result[::8, ::128], + [ + [0.0, 1.0, 2.0, 3.0], + [4.0, 5.0, 6.0, 7.0], + [0.0, 1.0, 2.0, 3.0], + [4.0, 5.0, 6.0, 7.0], + ], + ) + self.assertListEqual( + grid_point_recorder.grid_points, + # parallel_subgrid_size = 4 + # num_parallel_points_per_core = (4 + 2 - 1) // 2 = 2 + # num_iterations_per_core = 2 * (16 // 4) = 8 + [ + ProcessedGridPoint((2, 0), 0), + ProcessedGridPoint((2, 1), 0), + ProcessedGridPoint((2, 2), 0), + ProcessedGridPoint((2, 3), 0), + ProcessedGridPoint((3, 0), 0), + ProcessedGridPoint((3, 1), 0), + ProcessedGridPoint((3, 2), 0), + ProcessedGridPoint((3, 3), 0), + ProcessedGridPoint((0, 0), 1), + ProcessedGridPoint((0, 1), 1), + ProcessedGridPoint((0, 2), 1), + ProcessedGridPoint((0, 3), 1), + ProcessedGridPoint((1, 0), 1), + ProcessedGridPoint((1, 1), 1), + ProcessedGridPoint((1, 2), 1), + ProcessedGridPoint((1, 3), 1), + ], + ) + + with self.subTest('num_cores_per_device=3'): + with GridPointRecorderContext() as grid_point_recorder: + result = jax.jit(kernel_call, static_argnums=(1, 2))( + jnp.zeros((1,), jnp.float32), 3, grid_point_recorder.get_recorder() + ) + np.testing.assert_allclose( + result[::8, ::128], + [ + [0.0, 1.0, 2.0, 3.0], + [4.0, 5.0, 6.0, 7.0], + [0.0, 1.0, 2.0, 3.0], + [4.0, 5.0, 6.0, 7.0], + ], + ) + 
self.assertListEqual( + grid_point_recorder.grid_points, + # parallel_subgrid_size = 4 + # num_parallel_points_per_core = (4 + 3 - 1) // 3 = 2 + # num_iterations_per_core = 2 * (16 // 4) = 8 + [ + ProcessedGridPoint((2, 0), 0), + ProcessedGridPoint((2, 1), 0), + ProcessedGridPoint((2, 2), 0), + ProcessedGridPoint((2, 3), 0), + ProcessedGridPoint((3, 0), 0), + ProcessedGridPoint((3, 1), 0), + ProcessedGridPoint((3, 2), 0), + ProcessedGridPoint((3, 3), 0), + ProcessedGridPoint((0, 0), 1), + ProcessedGridPoint((0, 1), 1), + ProcessedGridPoint((0, 2), 1), + ProcessedGridPoint((0, 3), 1), + ProcessedGridPoint((1, 0), 1), + ProcessedGridPoint((1, 1), 1), + ProcessedGridPoint((1, 2), 1), + ProcessedGridPoint((1, 3), 1), + ], + ) + + with self.subTest('num_cores_per_device=4'): + with GridPointRecorderContext() as grid_point_recorder: + result = jax.jit(kernel_call, static_argnums=(1, 2))( + jnp.zeros((1,), jnp.float32), 4, grid_point_recorder.get_recorder() + ) + np.testing.assert_allclose( + result[::8, ::128], + [ + [0.0, 1.0, 2.0, 3.0], + [0.0, 1.0, 2.0, 3.0], + [0.0, 1.0, 2.0, 3.0], + [0.0, 1.0, 2.0, 3.0], + ], + ) + self.assertListEqual( + grid_point_recorder.grid_points, + # parallel_subgrid_size = 4 + # num_parallel_points_per_core = (4 + 4 - 1) // 4 = 1 + # num_iterations_per_core = 1 * (16 // 4) = 4 + [ + ProcessedGridPoint((2, 0), 0), + ProcessedGridPoint((2, 1), 0), + ProcessedGridPoint((2, 2), 0), + ProcessedGridPoint((2, 3), 0), + ProcessedGridPoint((3, 0), 1), + ProcessedGridPoint((3, 1), 1), + ProcessedGridPoint((3, 2), 1), + ProcessedGridPoint((3, 3), 1), + ProcessedGridPoint((0, 0), 2), + ProcessedGridPoint((0, 1), 2), + ProcessedGridPoint((0, 2), 2), + ProcessedGridPoint((0, 3), 2), + ProcessedGridPoint((1, 0), 3), + ProcessedGridPoint((1, 1), 3), + ProcessedGridPoint((1, 2), 3), + ProcessedGridPoint((1, 3), 3), + ], + ) + + with self.subTest('num_cores_per_device=5'): + with GridPointRecorderContext() as grid_point_recorder: + result = jax.jit(kernel_call, static_argnums=(1, 2))( + jnp.zeros((1,), jnp.float32), 5, grid_point_recorder.get_recorder() + ) + np.testing.assert_allclose( + result[::8, ::128], + [ + [0.0, 1.0, 2.0, 3.0], + [0.0, 1.0, 2.0, 3.0], + [0.0, 1.0, 2.0, 3.0], + [0.0, 1.0, 2.0, 3.0], + ], + ) + self.assertListEqual( + grid_point_recorder.grid_points, + # parallel_subgrid_size = 4 + # num_parallel_points_per_core = (4 + 5 - 1) // 5 = 1 + # num_iterations_per_core = 1 * (16 // 4) = 4 + [ + ProcessedGridPoint((2, 0), 0), + ProcessedGridPoint((2, 1), 0), + ProcessedGridPoint((2, 2), 0), + ProcessedGridPoint((2, 3), 0), + ProcessedGridPoint((3, 0), 1), + ProcessedGridPoint((3, 1), 1), + ProcessedGridPoint((3, 2), 1), + ProcessedGridPoint((3, 3), 1), + ProcessedGridPoint((0, 0), 2), + ProcessedGridPoint((0, 1), 2), + ProcessedGridPoint((0, 2), 2), + ProcessedGridPoint((0, 3), 2), + ProcessedGridPoint((1, 0), 3), + ProcessedGridPoint((1, 1), 3), + ProcessedGridPoint((1, 2), 3), + ProcessedGridPoint((1, 3), 3), + ], + ) + + with self.subTest('num_cores_per_device=6'): + with GridPointRecorderContext() as grid_point_recorder: + result = jax.jit(kernel_call, static_argnums=(1, 2))( + jnp.zeros((1,), jnp.float32), 6, grid_point_recorder.get_recorder() + ) + np.testing.assert_allclose( + result[::8, ::128], + [ + [0.0, 1.0, 2.0, 3.0], + [0.0, 1.0, 2.0, 3.0], + [0.0, 1.0, 2.0, 3.0], + [0.0, 1.0, 2.0, 3.0], + ], + ) + self.assertListEqual( + grid_point_recorder.grid_points, + # parallel_subgrid_size = 4 + # num_parallel_points_per_core = (4 + 6 - 1) // 6 = 1 + # 
num_iterations_per_core = 1 * (16 // 4) = 4 + [ + ProcessedGridPoint((2, 0), 0), + ProcessedGridPoint((2, 1), 0), + ProcessedGridPoint((2, 2), 0), + ProcessedGridPoint((2, 3), 0), + ProcessedGridPoint((3, 0), 1), + ProcessedGridPoint((3, 1), 1), + ProcessedGridPoint((3, 2), 1), + ProcessedGridPoint((3, 3), 1), + ProcessedGridPoint((0, 0), 2), + ProcessedGridPoint((0, 1), 2), + ProcessedGridPoint((0, 2), 2), + ProcessedGridPoint((0, 3), 2), + ProcessedGridPoint((1, 0), 3), + ProcessedGridPoint((1, 1), 3), + ProcessedGridPoint((1, 2), 3), + ProcessedGridPoint((1, 3), 3), + ], + ) + + def test_thread_map(self): + barrier = threading.Barrier(8) + lock = threading.Lock() + concurrent_calls = [0] + max_concurrent_calls = [0] + + def _barrier(): + with lock: + concurrent_calls[0] += 1 + max_concurrent_calls[0] = max( + max_concurrent_calls[0], concurrent_calls[0]) + barrier.wait() + with lock: + concurrent_calls[0] -= 1 + + def f(core_index): + del core_index + jax.experimental.io_callback(_barrier, (), ordered=True) + + mosaic_interpret._thread_map(f, 8) + self.assertEqual(max_concurrent_calls[0], 8) -if __name__ == "__main__": +if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/pallas/tpu_pallas_memory_space_test.py b/tests/pallas/tpu_pallas_memory_space_test.py new file mode 100644 index 000000000000..d3c62e329047 --- /dev/null +++ b/tests/pallas/tpu_pallas_memory_space_test.py @@ -0,0 +1,111 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test TPU-specific uses of Pallas memory space APIs.""" + +import functools +from absl.testing import absltest +from absl.testing import parameterized +import jax +from jax._src import test_util as jtu +from jax.experimental import pallas as pl +from jax.experimental.pallas import tpu as pltpu +import jax.numpy as jnp +import numpy as np + + +jax.config.parse_flags_with_absl() +P = jax.sharding.PartitionSpec +partial = functools.partial + + +class TPUPallasMemorySpaceTest(jtu.JaxTestCase): + + def setUp(self): + super().setUp() + if not jtu.if_cloud_tpu_at_least(2025, 6, 10): + self.skipTest('Needs a newer libTPU') + if not jtu.is_device_tpu_at_least(5): + self.skipTest('Needs a newer TPU') + + @parameterized.parameters( + (pltpu.VMEM, 1), + (pltpu.HBM, 0), + (pltpu.ANY, None), + ) + def test_basic_input_memory_space_constraint(self, memory_space, color): + + def kernel(x_ref, y_ref): + y_ref[...] = x_ref[...] 
+ + def g(x): + return pl.pallas_call(kernel, out_shape=x)(x) + + @jax.jit + def f(x): + x = pltpu.with_memory_space_constraint(x, memory_space=memory_space) + if color is None: + self.assertIsNone(pltpu.get_memory_space(x)) + else: + self.assertEqual(pltpu.get_memory_space(x), memory_space) + x = g(x) + return x + + x = jnp.ones((8, 128), dtype=jnp.float32) + y = f(x) + np.testing.assert_array_equal(y, x) + hlo = jax.jit(f).lower(x).compile().as_text() + if color is None: + self.assertIn('"input_memory_space_colors":[]', hlo) + else: + self.assertIn( + f'"input_memory_space_colors":[{{"operand_index":"0","color":"{color}","shape_index":[]}}]', + hlo, + ) + + @parameterized.parameters( + (pltpu.VMEM, 1), + (pltpu.HBM, 0), + (pltpu.ANY, None), + ) + def test_basic_output_memory_space_constraint(self, memory_space, color): + if color is None: + memory_space = jax.ShapeDtypeStruct + + def kernel(x_ref, y_ref): + y_ref[...] = x_ref[...] + + def g(x): + return pl.pallas_call(kernel, out_shape=memory_space(x.shape, x.dtype))(x) + + @jax.jit + def f(x): + x = g(x) + return x + + x = jnp.ones((8, 128), dtype=jnp.float32) + y = f(x) + np.testing.assert_array_equal(y, x) + hlo = jax.jit(f).lower(x).compile().as_text() + if color is None: + self.assertIn('"output_memory_space_colors":[]', hlo) + else: + self.assertIn( + f'"output_memory_space_colors":[{{"color":"{color}","shape_index":[]}}]', + hlo, + ) + + +if __name__ == '__main__': + absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/pallas/tpu_pallas_pipeline_test.py b/tests/pallas/tpu_pallas_pipeline_test.py index 8e72c49e2598..515e4a3c26ea 100644 --- a/tests/pallas/tpu_pallas_pipeline_test.py +++ b/tests/pallas/tpu_pallas_pipeline_test.py @@ -17,34 +17,29 @@ import functools from absl.testing import absltest from absl.testing import parameterized +import hypothesis as hp +import hypothesis.strategies as hps import jax from jax import lax from jax._src import test_util as jtu from jax.experimental import mesh_utils from jax.experimental import pallas as pl -from jax.experimental import shard_map +from jax._src import shard_map from jax.experimental.pallas import tpu as pltpu import jax.numpy as jnp import numpy as np -try: - import hypothesis as hp - import hypothesis.strategies as hps - CAN_USE_HYPOTHESIS = True -except (ModuleNotFoundError, ImportError): - CAN_USE_HYPOTHESIS = False - - -if CAN_USE_HYPOTHESIS: - hp.settings.register_profile( - 'deterministic', - database=None, - derandomize=True, - deadline=None, - max_examples=200, - print_blob=True, - verbosity=hp.Verbosity.verbose, - ) - hp.settings.load_profile('deterministic') + + +hp.settings.register_profile( + 'deterministic', + database=None, + derandomize=True, + deadline=None, + max_examples=200, + print_blob=True, + verbosity=hp.Verbosity.verbose, +) +hp.settings.load_profile('deterministic') jax.config.parse_flags_with_absl() @@ -127,20 +122,16 @@ def _reduce_out(): class PallasCallPipelineTest(parameterized.TestCase): def setUp(self): - if jax.device_count() < 2: - self.skipTest('Only >=2 devices are supported.') if not jtu.is_device_tpu_at_least(5): self.skipTest('Only works with TPU v5') super().setUp() - @parameterized.named_parameters( - ('vmem', pltpu.TPUMemorySpace.VMEM), - ('hbm', pltpu.TPUMemorySpace.ANY), + @parameterized.product( + no_pipelining=[False, True], + use_sreg_for_state=[False, True], ) - def test_pipeline_matmul(self, memory_space): - # TODO(b/358121809): Re-enable this test once the bug is fixed. 
- self.skipTest('Broken test.') + def test_pipeline_matmul(self, no_pipelining, use_sreg_for_state): k1, k2 = jax.random.split(jax.random.key(0)) x = jax.random.uniform(k1, (512, 512)) y = jax.random.uniform(k2, (512, 512)) @@ -161,16 +152,18 @@ def matmul_kernel(x_ref, y_ref, z_ref): pl.BlockSpec((128, 128), lambda i, j, k: (k, j)), ], out_specs=pl.BlockSpec((128, 128), lambda i, j, k: (i, j)), + no_pipelining=no_pipelining, + use_sreg_for_state=use_sreg_for_state, )(x_ref, y_ref, z_ref) z = pl.pallas_call( matmul_kernel, out_shape=jax.ShapeDtypeStruct((512, 512), jnp.float32), in_specs=[ - pl.BlockSpec(memory_space=memory_space), - pl.BlockSpec(memory_space=memory_space), + pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=pltpu.ANY), ], - out_specs=pl.BlockSpec(memory_space=memory_space), + out_specs=pl.BlockSpec(memory_space=pltpu.ANY), ) jax.block_until_ready(z(x, y)) @@ -179,11 +172,11 @@ def matmul_kernel(x_ref, y_ref, z_ref): out = jax.block_until_ready(z(x, y)) expected_out = jax.block_until_ready(jnp.dot(x, y)) - np.testing.assert_allclose(out, expected_out) + np.testing.assert_allclose(out, expected_out, atol=5e-5) @parameterized.named_parameters( - ('vmem', pltpu.TPUMemorySpace.VMEM), - ('hbm', pltpu.TPUMemorySpace.ANY), + ('vmem', pltpu.VMEM), + ('hbm', pltpu.ANY), ) def test_double_pipeline_matmul(self, memory_space): # TODO(b/358121809): Re-enable this test once the bug is fixed. @@ -240,11 +233,11 @@ def setUp(self): super().setUp() @parameterized.named_parameters( - ('vmem', pltpu.TPUMemorySpace.VMEM, jnp.bfloat16, 2, 2, 2), - ('hbm', pltpu.TPUMemorySpace.ANY, jnp.bfloat16, 2, 2, 2), - ('hbm_float32', pltpu.TPUMemorySpace.ANY, jnp.float32, 2, 2, 2), - ('hbm_float32_112', pltpu.TPUMemorySpace.ANY, jnp.float32, 1, 1, 2), - ('hbm_float32_111', pltpu.TPUMemorySpace.ANY, jnp.float32, 1, 1, 1), + ('vmem', pltpu.VMEM, jnp.bfloat16, 2, 2, 2), + ('hbm', pltpu.ANY, jnp.bfloat16, 2, 2, 2), + ('hbm_float32', pltpu.ANY, jnp.float32, 2, 2, 2), + ('hbm_float32_112', pltpu.ANY, jnp.float32, 1, 1, 2), + ('hbm_float32_111', pltpu.ANY, jnp.float32, 1, 1, 1), ) def test_pipeline_latency_optimized_allgather_matmul( self, memory_space, out_dtype, n_tiles, m_tiles, k_tiles): @@ -486,7 +479,7 @@ def _wait_on_prev_dma(): + [pltpu.SemaphoreType.DMA] * 4 + inner_allocs ), - compiler_params=pltpu.TPUCompilerParams( + compiler_params=pltpu.CompilerParams( collective_id=0, # must set scoped vmem flag *larger* than below! 
e.g.: # flags.FLAGS.xla_tpu_scoped_vmem_limit_kib = 131072 @@ -502,7 +495,7 @@ def _wait_on_prev_dma(): ), in_specs=(P(None, 'x'), P(None, None)), out_specs=P(None, None), - check_rep=False, + check_vma=False, ) test = jax.jit(shard(kernel)) @@ -530,11 +523,11 @@ def reference(x, y): ) @parameterized.named_parameters( - ('vmem', pltpu.TPUMemorySpace.VMEM, jnp.bfloat16, 2, 2, 2), - ('hbm', pltpu.TPUMemorySpace.ANY, jnp.bfloat16, 2, 2, 2), - ('hbm_float32', pltpu.TPUMemorySpace.ANY, jnp.float32, 2, 2, 2), - ('hbm_float32_122', pltpu.TPUMemorySpace.ANY, jnp.float32, 1, 2, 2), - ('hbm_float32_121', pltpu.TPUMemorySpace.ANY, jnp.float32, 1, 2, 1), + ('vmem', pltpu.VMEM, jnp.bfloat16, 2, 2, 2), + ('hbm', pltpu.ANY, jnp.bfloat16, 2, 2, 2), + ('hbm_float32', pltpu.ANY, jnp.float32, 2, 2, 2), + ('hbm_float32_122', pltpu.ANY, jnp.float32, 1, 2, 2), + ('hbm_float32_121', pltpu.ANY, jnp.float32, 1, 2, 1), ) def test_pipeline_throughput_optimized_allgather_matmul( self, memory_space, out_dtype, n_tiles, m_tiles, k_tiles): @@ -720,20 +713,20 @@ def _wait_on_prev_dma(): pl.BlockSpec(memory_space=memory_space), pl.BlockSpec(memory_space=memory_space), ], - out_specs=[pl.BlockSpec(memory_space=memory_space), - pl.BlockSpec(memory_space=memory_space)], + out_specs=[ + pl.BlockSpec(memory_space=memory_space), + pl.BlockSpec(memory_space=memory_space), + ], grid=(outer_steps, 2), - scratch_shapes=[ - pltpu.VMEM((tm, tn), jnp.float32)] + scratch_shapes=[pltpu.VMEM((tm, tn), jnp.float32)] + [pltpu.SemaphoreType.DMA] * 4 - + inner_allocs + + inner_allocs, ), - compiler_params=dict( - mosaic=dict(collective_id=0, - # must set scoped vmem flag *larger* than below! e.g.: - # flags.FLAGS.xla_tpu_scoped_vmem_limit_kib = 131072 - vmem_limit_bytes=int(134217728 * 0.9) # 0.9 * 128MiB - ) + compiler_params=pltpu.CompilerParams( + collective_id=0, + # must set scoped vmem flag *larger* than below! e.g.: + # flags.FLAGS.xla_tpu_scoped_vmem_limit_kib = 131072 + vmem_limit_bytes=int(134217728 * 0.9), # 0.9 * 128MiB ), ) @@ -745,7 +738,7 @@ def _wait_on_prev_dma(): ), in_specs=(P(None, 'x'), P(None, None)), out_specs=P(None, None), - check_rep=False, + check_vma=False, ) test = jax.jit(shard(kernel)) @@ -773,11 +766,11 @@ def reference(x, y): ) @parameterized.named_parameters( - ('vmem', pltpu.TPUMemorySpace.VMEM, jnp.bfloat16, 2, 2, 2), - ('hbm', pltpu.TPUMemorySpace.ANY, jnp.bfloat16, 2, 2, 2), - ('hbm_float32', pltpu.TPUMemorySpace.ANY, jnp.float32, 2, 4, 2), - ('hbm_float32_112', pltpu.TPUMemorySpace.ANY, jnp.float32, 1, 1, 2), - ('hbm_float32_111', pltpu.TPUMemorySpace.ANY, jnp.float32, 1, 1, 1), + ('vmem', pltpu.VMEM, jnp.bfloat16, 2, 2, 2), + ('hbm', pltpu.ANY, jnp.bfloat16, 2, 2, 2), + ('hbm_float32', pltpu.ANY, jnp.float32, 2, 4, 2), + ('hbm_float32_112', pltpu.ANY, jnp.float32, 1, 1, 2), + ('hbm_float32_111', pltpu.ANY, jnp.float32, 1, 1, 1), ) def test_pipeline_latency_optimized_matmul_reducescatter( self, memory_space, out_dtype, n_tiles, m_tiles, k_tiles): @@ -1010,15 +1003,13 @@ def _loop_epilogue(): grid=(outer_steps, 2), scratch_shapes=[pltpu.VMEM((tm, tn), jnp.float32)] + [pltpu.SemaphoreType.DMA] * 4 - + inner_allocs + + inner_allocs, ), - compiler_params=dict( - mosaic=dict( - collective_id=0, - # must set scoped vmem flag *larger* than below! - # e.g. flags.FLAGS.xla_tpu_scoped_vmem_limit_kib = 131072 - vmem_limit_bytes=int(134217728 * 0.9) # 0.9 * 128MiB - ) + compiler_params=pltpu.CompilerParams( + collective_id=0, + # must set scoped vmem flag *larger* than below! + # e.g. 
flags.FLAGS.xla_tpu_scoped_vmem_limit_kib = 131072 + vmem_limit_bytes=int(134217728 * 0.9), # 0.9 * 128MiB ), ) @@ -1031,7 +1022,7 @@ def _loop_epilogue(): ), in_specs=(P(None, 'x'), P('x', None)), out_specs=P('x', None), - check_rep=False, + check_vma=False, ) test = jax.jit(shard(lambda x, y: kernel(x, y)[0, 0])) @@ -1062,11 +1053,11 @@ def reference(x, y): np.mean(np.abs(out - expected_out)) @parameterized.named_parameters( - ('vmem', pltpu.TPUMemorySpace.VMEM, jnp.bfloat16, 2, 2, 2), - ('hbm', pltpu.TPUMemorySpace.ANY, jnp.bfloat16, 2, 2, 2), - ('hbm_float32', pltpu.TPUMemorySpace.ANY, jnp.float32, 2, 4, 2), - ('hbm_float32_112', pltpu.TPUMemorySpace.ANY, jnp.float32, 1, 2, 2), - ('hbm_float32_111', pltpu.TPUMemorySpace.ANY, jnp.float32, 1, 2, 1), + ('vmem', pltpu.VMEM, jnp.bfloat16, 2, 2, 2), + ('hbm', pltpu.ANY, jnp.bfloat16, 2, 2, 2), + ('hbm_float32', pltpu.ANY, jnp.float32, 2, 4, 2), + ('hbm_float32_112', pltpu.ANY, jnp.float32, 1, 2, 2), + ('hbm_float32_111', pltpu.ANY, jnp.float32, 1, 2, 1), ) def test_pipeline_throughput_optimized_matmul_reducescatter( self, memory_space, out_dtype, n_tiles, m_tiles, k_tiles): @@ -1273,15 +1264,13 @@ def _prefetch_accumulator(): grid=(outer_steps, 2), scratch_shapes=[pltpu.VMEM((tm, tn), jnp.float32)] + [pltpu.SemaphoreType.DMA] * 4 - + inner_allocs + + inner_allocs, ), - compiler_params=dict( - mosaic=dict( - collective_id=0, - # must set scoped vmem flag *larger* than below! - # e.g. flags.FLAGS.xla_tpu_scoped_vmem_limit_kib = 131072 - vmem_limit_bytes=int(134217728 * 0.9) # 0.9 * 128MiB - ) + compiler_params=pltpu.CompilerParams( + collective_id=0, + # must set scoped vmem flag *larger* than below! + # e.g. flags.FLAGS.xla_tpu_scoped_vmem_limit_kib = 131072 + vmem_limit_bytes=int(134217728 * 0.9), # 0.9 * 128MiB ), ) @@ -1294,7 +1283,7 @@ def _prefetch_accumulator(): ), in_specs=(P(None, 'x'), P('x', None)), out_specs=P('x', None), - check_rep=False, + check_vma=False, ) test = jax.jit(shard(lambda x, y: kernel(x, y)[1])) @@ -1362,7 +1351,9 @@ def mul_kernel(iters_ref, x_ref, y_ref): out_specs=pl.BlockSpec(memory_space=pltpu.ANY), grid=(num_cores,), ), - compiler_params=dict(mosaic=dict(dimension_semantics=('parallel',))), + compiler_params=pltpu.CompilerParams( + dimension_semantics=('parallel',) + ), ) x = jax.random.uniform(jax.random.key(0), (640, 640)) np.testing.assert_allclose(func(jnp.array([5]), x), x * 2) @@ -1396,7 +1387,9 @@ def matmul_kernel(x_ref, y_ref): ], out_specs=pl.BlockSpec(memory_space=pltpu.ANY), grid=(num_cores,), - compiler_params=dict(mosaic=dict(dimension_semantics=('parallel',))), + compiler_params=pltpu.CompilerParams( + dimension_semantics=('parallel',) + ), ) np.testing.assert_allclose(func(x), x * 2) @@ -1445,109 +1438,238 @@ def matmul_kernel(x_ref, y_ref, z_ref, *, bm, bk, bn): ], out_specs=pl.BlockSpec(memory_space=pltpu.ANY), grid=(num_cores,), - compiler_params=dict(mosaic=dict(dimension_semantics=('parallel',))), + compiler_params=pltpu.CompilerParams( + dimension_semantics=('parallel',) + ), ) np.testing.assert_allclose(func(x, y), x @ y, atol=7e-5) -if CAN_USE_HYPOTHESIS: +@partial(jax.jit, static_argnames=['bm', 'bk', 'bn']) +def matmul(x: jax.Array, y: jax.Array, *, bm: int, bk: int, bn: int): + + m, k = x.shape + _, n = y.shape - @partial(jax.jit, static_argnames=['bm', 'bk', 'bn']) - def matmul(x: jax.Array, y: jax.Array, *, bm: int, bk: int, bn: int): + def kernel(x_hbm_ref, y_hbm_ref, o_hbm_ref): - m, k = x.shape - _, n = y.shape + grid = (pl.cdiv(m, bm), pl.cdiv(n, bn), pl.cdiv(k, bk)) - def 
kernel(x_hbm_ref, y_hbm_ref, o_hbm_ref): + def run(acc_scratch_ref): + pltpu.emit_pipeline( + partial(basic_matmul_kernel, acc_scratch_ref=acc_scratch_ref, k=k), + in_specs=[ + pl.BlockSpec((bm, bk), lambda i, j, k: (i, k)), + pl.BlockSpec((bk, bn), lambda i, j, k: (k, j)), + ], + out_specs=pl.BlockSpec((bm, bn), lambda i, j, k: (i, j)), + grid=grid, + core_axis=0, + dimension_semantics=( + pltpu.PARALLEL, + pltpu.PARALLEL, + pltpu.ARBITRARY, + ), + )(x_hbm_ref, y_hbm_ref, o_hbm_ref) + + accum_dtype = ( + jnp.float32 if jnp.issubdtype(x.dtype, jnp.floating) else jnp.int32 + ) + pl.run_scoped(run, pltpu.VMEM((bm, bn), accum_dtype)) + + num_cores = jax.devices()[0].num_cores + return pl.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct((m, n), x.dtype), + in_specs=[ + pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=pltpu.ANY), + ], + out_specs=pl.BlockSpec(memory_space=pltpu.ANY), + grid=(num_cores,), + )(x, y) + +@jtu.thread_unsafe_test_class() # hypothesis is not thread safe +class PaddedPipelineEmitterTest(parameterized.TestCase): - grid = (pl.cdiv(m, bm), pl.cdiv(n, bn), pl.cdiv(k, bk)) + def setUp(self): + super().setUp() + if not jtu.is_device_tpu_at_least(4): + self.skipTest('Only TPU v4+ allowed.') - def run(acc_scratch_ref): + @parameterized.named_parameters( + ('float32', 'float32'), ('bfloat16', 'bfloat16'), ('int8', 'int8') + ) + @hp.given( + hps.integers(1, 1024), + hps.integers(1, 1024), + hps.integers(1, 1024), + hps.sampled_from([8, 16, 32, 128, 256, 512]), + hps.sampled_from([128, 256, 512]), + hps.sampled_from([128, 256, 512]), + hps.integers(0, 4), + ) + def test_padded_matmul(self, dtype, m, k, n, bm, bk, bn, seed): + if dtype == 'int8' and jtu.is_device_tpu_at_least(6): + self.skipTest('Not implemented for TPU v6.') + + hp.assume(bm <= m) + hp.assume(bn <= n) + hp.assume(bk <= k) + if dtype == 'bfloat16': + hp.assume(bm >= 16) + if dtype == 'int8': + if not jtu.is_device_tpu_at_least(5): + self.skipTest('Only TPU v5+ allowed for int8.') + hp.assume(bm >= 32) + k1, k2 = jax.random.split(jax.random.key(seed)) + x = jax.random.normal(k1, (m, k), jnp.float32).astype(dtype) + y = jax.random.normal(k2, (k, n), jnp.float32).astype(dtype) + + out = matmul(x, y, bm=bm, bk=bk, bn=bn) + expected = x @ y + atol = rtol = 2.3e-5 + if dtype == 'bfloat16': + out = out.astype('float32') + expected = expected.astype('float32') + atol = rtol = 1e-2 + np.testing.assert_allclose(out, expected, atol=atol, rtol=rtol) + + +class PallasCallBoundedSliceIndexingTest(parameterized.TestCase): + + def test_block_spec_bounded_slice_invalid_index(self): + if not jtu.is_device_tpu(): + self.skipTest('Only works on TPU.') + shape = (16, 8, 128) + + def kernel(x_ref, o_ref): + o_ref[...] = x_ref[...] 
+ + def main(refs): + x_ref, y_ref = refs + + @pl.core_map(pltpu.create_tensorcore_mesh('core')) + def _(): pltpu.emit_pipeline( - partial(basic_matmul_kernel, acc_scratch_ref=acc_scratch_ref, k=k), - in_specs=[ - pl.BlockSpec((bm, bk), lambda i, j, k: (i, k)), - pl.BlockSpec((bk, bn), lambda i, j, k: (k, j)), - ], - out_specs=pl.BlockSpec((bm, bn), lambda i, j, k: (i, j)), - grid=grid, - core_axis=0, - dimension_semantics=( - pltpu.PARALLEL, - pltpu.PARALLEL, - pltpu.ARBITRARY, + kernel, + grid=(1,), + in_specs=( + pl.BlockSpec( + (pl.BoundedSlice(8), 8, 128), + lambda i: (0, 0, 0), # first index needs to be a pl.ds + ), + ), + out_specs=pl.BlockSpec( + (8, 8, 128), + lambda i: (0, 0, 0), ), - )(x_hbm_ref, y_hbm_ref, o_hbm_ref) + )(x_ref, y_ref) - accum_dtype = ( - jnp.float32 if jnp.issubdtype(x.dtype, jnp.floating) else jnp.int32 - ) - pl.run_scoped(run, pltpu.VMEM((bm, bn), accum_dtype)) + @jax.jit + def f(x): + y = jnp.ones((8, 8, 128), dtype=jnp.int32) + _, y = pl.run_state(main)((x, y)) + return y + with self.assertRaisesRegex( + ValueError, + 'Must return a pl.ds from the index_map for a BoundedSlice dimension.' + ): + f.trace(jax.ShapeDtypeStruct(shape, jnp.int32)) - num_cores = jax.devices()[0].num_cores - return pl.pallas_call( - kernel, - out_shape=jax.ShapeDtypeStruct((m, n), x.dtype), - in_specs=[ - pl.BlockSpec(memory_space=pltpu.ANY), - pl.BlockSpec(memory_space=pltpu.ANY), - ], - out_specs=pl.BlockSpec(memory_space=pltpu.ANY), - grid=(num_cores,), - )(x, y) + def test_block_spec_bounded_slice_static(self): + if not jtu.is_device_tpu(): + self.skipTest('Only works on TPU.') + if not jtu.is_device_tpu_at_least(4): + self.skipTest('Only works on TPU v4+') + shape = (16, 8, 128) - class PaddedPipelineEmitterTest(parameterized.TestCase): + def kernel(x_ref, o_ref): + o_ref[...] = x_ref[...] - def setUp(self): - super().setUp() - if not jtu.is_device_tpu_at_least(4): - self.skipTest('Only TPU v4+ allowed.') + def main(refs): + x_ref, y_ref = refs - @parameterized.named_parameters( - ('float32', 'float32'), ('bfloat16', 'bfloat16'), ('int8', 'int8') - ) - @hp.given( - hps.integers(1, 1024), - hps.integers(1, 1024), - hps.integers(1, 1024), - hps.sampled_from([8, 16, 32, 128, 256, 512]), - hps.sampled_from([128, 256, 512]), - hps.sampled_from([128, 256, 512]), - hps.integers(0, 4), - ) - def test_padded_matmul(self, dtype, m, k, n, bm, bk, bn, seed): - if dtype == 'int8' and jtu.is_device_tpu_at_least(6): - self.skipTest('Not implemented for TPU v6.') - - def align_up_to(x, y): - return (x + y - 1) // y * y - - hp.assume(bm <= m) - hp.assume(bn <= n) - hp.assume(bk <= k) - if dtype == 'bfloat16': - hp.assume(bm >= 16) - if dtype == 'int8': - if not jtu.is_device_tpu_at_least(5): - self.skipTest('Only TPU v5+ allowed for int8.') - hp.assume(bm >= 32) - # TODO(apaszke): Relax DMA restrictions and remove this. 
- packing = 4 // jnp.dtype(dtype).itemsize - if packing != 1: - m = align_up_to(m, 8 * packing) - k = align_up_to(k, 8 * packing) - k1, k2 = jax.random.split(jax.random.key(seed)) - x = jax.random.normal(k1, (m, k), jnp.float32).astype(dtype) - y = jax.random.normal(k2, (k, n), jnp.float32).astype(dtype) - - out = matmul(x, y, bm=bm, bk=bk, bn=bn) - expected = x @ y - atol = rtol = 2.3e-5 - if dtype == 'bfloat16': - out = out.astype('float32') - expected = expected.astype('float32') - atol = rtol = 1e-2 - np.testing.assert_allclose(out, expected, atol=atol, rtol=rtol) + @pl.core_map(pltpu.create_tensorcore_mesh('core')) + def _(): + pltpu.emit_pipeline( + kernel, + grid=(1,), + in_specs=( + pl.BlockSpec( + (pl.BoundedSlice(8), 8, 128), + lambda i: (pl.ds(4, 8), 0, 0), + ), + ), + out_specs=pl.BlockSpec( + (8, 8, 128), + lambda i: (0, 0, 0), + ), + )(x_ref, y_ref) + + x = jnp.arange(np.prod(shape), dtype=np.int32).reshape(shape) + + @jax.jit + def f(x): + y = jnp.ones((8, 8, 128), dtype=jnp.int32) + _, y = pl.run_state(main)((x, y)) + return y + + out = f(x) + np.testing.assert_allclose(out, x[4:12]) + + def test_block_spec_bounded_slice_dynamic(self): + if not jtu.is_device_tpu(): + self.skipTest('Only works on TPU.') + if not jtu.is_device_tpu_at_least(4): + self.skipTest('Only works on TPU v4+') + shape = (16, 8, 128) + + slices = jnp.array([[0, 3], [3, 8], [8, 11], [11, 16]], dtype=jnp.int32)[ + ::-1 + ] + + def kernel(x_ref, o_ref): + o_ref[...] = x_ref[...] + + def main(refs): + x_ref, y_ref, slices_ref = refs + + @pl.core_map(pltpu.create_tensorcore_mesh('core')) + def _(): + + @functools.partial( + pl.run_scoped, slices_smem=pltpu.SMEM(slices.shape, slices.dtype) + ) + def _(slices_smem): + pltpu.sync_copy(slices_ref, slices_smem) + def index_map(i): + return ( + pl.ds(slices_smem[i, 0], slices_smem[i, 1] - slices_smem[i, 0]), + 0, + 0, + ) + block_spec = pl.BlockSpec( + (pl.BoundedSlice(16), 8, 128), + index_map, + ) + pltpu.emit_pipeline( + kernel, + grid=(slices.shape[0],), + in_specs=(block_spec,), + out_specs=block_spec, + )(x_ref, y_ref) + + x = jnp.arange(np.prod(shape), dtype=np.int32).reshape(shape) + + @jax.jit + def f(x, slices): + y = pl.empty_like(x) + _, y, _ = pl.run_state(main)((x, y, slices)) + return y + + out = f(x, slices) + np.testing.assert_allclose(out, x) if __name__ == '__main__': diff --git a/tests/pallas/tpu_pallas_random_test.py b/tests/pallas/tpu_pallas_random_test.py index ca8edf7a269e..aea2d05d57b4 100644 --- a/tests/pallas/tpu_pallas_random_test.py +++ b/tests/pallas/tpu_pallas_random_test.py @@ -19,7 +19,7 @@ from jax._src import test_util as jtu from jax._src.pallas.mosaic import random as plrandom from jax.experimental import pallas as pl -from jax.experimental import shard_map +from jax._src import shard_map from jax.experimental.pallas import tpu as pltpu from jax.experimental.pallas.ops.tpu.random import philox # pylint: disable=unused-import # noqa: F401 from jax.experimental.pallas.ops.tpu.random import threefry # pylint: disable=unused-import # noqa: F401 @@ -117,7 +117,7 @@ def body(key_ref, o_ref): o_shape = jax.ShapeDtypeStruct((8, 128), jnp.float32) result = pl.pallas_call( body, - in_specs=[pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM)], + in_specs=[pl.BlockSpec(memory_space=pltpu.SMEM)], out_shape=o_shape, )(key) self.assertGreaterEqual(jnp.min(result), 0) @@ -135,7 +135,7 @@ def body(key_ref, o_ref): o_shape = jax.ShapeDtypeStruct((8, 128), jnp.float32) result = pl.pallas_call( body, - 
in_specs=[pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM)], + in_specs=[pl.BlockSpec(memory_space=pltpu.SMEM)], out_shape=o_shape, )(key) self.assertGreaterEqual(jnp.min(result), 0) @@ -143,7 +143,9 @@ def body(key_ref, o_ref): def test_key_data(self): def body(key_ref, o_ref): - o_ref[...] = jax.random.key_data(key_ref[...]) + x0, x1 = plrandom.unwrap_pallas_seed(key_ref[...]) + o_ref[0, 0] = x0 + o_ref[0, 1] = x1 rbg_key = jax_random.key(0, impl="rbg") key = plrandom.to_pallas_key(rbg_key) expected_key_data = jax.random.key_data(key) @@ -151,10 +153,11 @@ def body(key_ref, o_ref): expected_key_data.dtype) result = pl.pallas_call( body, - in_specs=[pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM)], + in_specs=[pl.BlockSpec(memory_space=pltpu.SMEM)], + out_specs=pl.BlockSpec(memory_space=pltpu.SMEM), out_shape=o_shape, )(key) - self.assertEqual(result, expected_key_data) + self.assertArraysEqual(result, expected_key_data) def test_fold_in(self): # Test that folding in a value results in different random numbers. @@ -174,7 +177,7 @@ def body(key_ref, o_ref): o_shape = jax.ShapeDtypeStruct((2, 8, 128), jnp.float32) result = pl.pallas_call( body, - in_specs=[pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM)], + in_specs=[pl.BlockSpec(memory_space=pltpu.SMEM)], out_shape=o_shape, )(key) result_a = result[0] @@ -208,7 +211,7 @@ def body(key_ref, o_ref): global_key = jax_random.key(0, impl="pallas_tpu") o_shape = jnp.ones((64, 512), dtype=jnp.float32) - key_spec = pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM) + key_spec = pl.BlockSpec(memory_space=pltpu.SMEM) out_spec = pl.BlockSpec((16, 128), lambda i, j: (i, j)) result_16x128 = pl.pallas_call( make_kernel_body(index_map=lambda i, j: (i, j)), @@ -254,7 +257,7 @@ def body(key_ref, o_ref): # TODO(justinfu): support passing keys into VMEM. 
result = pl.pallas_call( body, - in_specs=[pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM)], + in_specs=[pl.BlockSpec(memory_space=pltpu.VMEM)], out_shape=o_shape, )(jax.random.key_data(threefry_key)) jax_result = jax_random.uniform( @@ -303,6 +306,7 @@ def test_threefry_kernel_matches_jax_threefry_sharded(self, shape): mesh=mesh, in_specs=partition, out_specs=partition, + check_vma=False, ) jax_gen = generate(key_jax) pl_gen = generate(key_pallas) diff --git a/tests/pallas/tpu_pallas_state_test.py b/tests/pallas/tpu_pallas_state_test.py index 46f98c087110..0e0f80a0c2b5 100644 --- a/tests/pallas/tpu_pallas_state_test.py +++ b/tests/pallas/tpu_pallas_state_test.py @@ -117,6 +117,8 @@ def f_stateful(refs): x = pl.pallas_call( functools.partial(copy_kernel, x_ref, y_ref), + in_specs=[pl.BlockSpec(memory_space=pltpu.ANY)], + out_specs=pl.BlockSpec(memory_space=pltpu.ANY), scratch_shapes=[pltpu.SemaphoreType.DMA], out_shape=jax.ShapeDtypeStruct(x_ref.shape, x_ref.dtype), input_output_aliases={0: 0}, @@ -228,7 +230,7 @@ def inner(refs): x_ref, y_ref = refs @pl.core_map(mesh) def _(): - num_cores = jax.lax.psum(1, "x") + num_cores = jax.lax.axis_size("x") slc_size = 16 // num_cores def alloc(x_vmem_ref, y_vmem_ref, sem): core_index = jax.lax.axis_index("x") diff --git a/tests/pallas/tpu_pallas_test.py b/tests/pallas/tpu_pallas_test.py index 55831ff6af1d..f7d076965fd3 100644 --- a/tests/pallas/tpu_pallas_test.py +++ b/tests/pallas/tpu_pallas_test.py @@ -22,7 +22,7 @@ import math import re import sys -from typing import Callable +from collections.abc import Callable from absl.testing import absltest from absl.testing import parameterized import jax @@ -32,14 +32,14 @@ from jax._src import state from jax._src import test_util as jtu from jax._src.interpreters import partial_eval as pe -from jax._src.lib import xla_extension +from jax._src.lib import _jax from jax._src.pallas.pallas_call import _trace_kernel_to_jaxpr from jax._src.state import utils as state_utils from jax._src.state import discharge as state_discharge from jax.experimental import mesh_utils from jax.experimental import mosaic from jax.experimental import pallas as pl -from jax.experimental import shard_map +from jax._src import shard_map from jax.experimental.pallas import tpu as pltpu from jax.experimental.pallas.ops.tpu import example_kernel from jax.extend import linear_util as lu @@ -145,8 +145,7 @@ def body(_, x_ref, o_ref): x = jnp.arange(8 * 8 * 128, dtype=jnp.int32).reshape((8 * 8, 128)) def _x_transform(i, s_ref): - s = pl.load(s_ref, (i,)) - return (s, 0) + return (s_ref[i], 0) out = self.pallas_call( body, @@ -225,7 +224,7 @@ def kernel(s_refs, src, to_store, dst, *scratch_refs): assert s2.shape == (3,) assert s3 is None store_idx = s_ref[pl.program_id(0)] - pl.store(dst, (pl.dslice(store_idx, 1), slice(None)), to_store[...]) + dst[pl.dslice(store_idx, 1), :] = to_store[...] 
# Pass a pytree of scalar return kernel((s, np.arange(3, dtype=np.int32), None), x, to_store) @@ -281,7 +280,7 @@ def body(_, x_ref, o_ref): x = jnp.arange(2 * 8 * 8 * 128, dtype=jnp.int32).reshape((2, 8 * 8, 128)) def _x_transform(i, s_ref): - s = pl.load(s_ref, (i,)) + s = s_ref[i] return (s, 0) def f(x): @@ -423,7 +422,7 @@ def body(_, x_ref, o_ref): x = jnp.arange(8 * 8 * 128, dtype=jnp.int32).reshape((8 * 8, 128)) def _x_transform(i, s_ref): - s = pl.load(s_ref, (i,)) + s = s_ref[i] return (s, 0) s = s[None] @@ -457,7 +456,7 @@ def body(_, x_ref, o_ref): x = jnp.arange(2 * 8 * 8 * 128, dtype=jnp.int32).reshape((2, 8 * 8, 128)) def _x_transform(i, s_ref): - s = pl.load(s_ref, (i,)) + s = s_ref[i] return (s, 0) s = jnp.tile(s[None], [2, 1]) @@ -478,7 +477,7 @@ def kernel(s, x): ), grid=8, ), - compiler_params=pltpu.TPUCompilerParams( + compiler_params=pltpu.CompilerParams( allow_input_fusion=[False, True] ), )(s, x) @@ -842,7 +841,7 @@ def body(x_ref): kernel, grid_spec=pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=0, - out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM), + out_specs=pl.BlockSpec(memory_space=pltpu.SMEM), ), out_shape=jax.ShapeDtypeStruct((1,), jnp.int32), )() @@ -862,7 +861,7 @@ def body(x_ref): kernel, grid_spec=pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=0, - out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM), + out_specs=pl.BlockSpec(memory_space=pltpu.SMEM), ), out_shape=jax.ShapeDtypeStruct((2,), jnp.int32), )() @@ -881,7 +880,7 @@ def body(x_ref): kernel, grid_spec=pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=0, - out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM), + out_specs=pl.BlockSpec(memory_space=pltpu.VMEM), ), out_shape=jax.ShapeDtypeStruct((16, 128), jnp.int32), )() @@ -900,7 +899,7 @@ def body(x_ref): kernel, grid_spec=pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=0, - out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM), + out_specs=pl.BlockSpec(memory_space=pltpu.VMEM), ), out_shape=jax.ShapeDtypeStruct((17, 128), jnp.int32), )() @@ -1100,7 +1099,7 @@ def body(sems): y = jax.block_until_ready( self.pallas_call( kernel, - out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM), + out_specs=pl.BlockSpec(memory_space=pltpu.SMEM), out_shape=jax.ShapeDtypeStruct((m, n), jnp.int32), )() ) @@ -1123,7 +1122,7 @@ def kernel(x_hbm_ref, y_hbm_ref, sem_val_ref, dma_sem): in_specs=[pl.BlockSpec(memory_space=pl.ANY)], out_specs=[ pl.BlockSpec(memory_space=pl.ANY), - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM), + pl.BlockSpec(memory_space=pltpu.SMEM), ], scratch_shapes=[pltpu.SemaphoreType.DMA], ), @@ -1136,11 +1135,43 @@ def kernel(x_hbm_ref, y_hbm_ref, sem_val_ref, dma_sem): np.testing.assert_array_equal(y, x) np.testing.assert_array_equal(sem_val, 0) + def test_set_dma_priority(self): + if not jtu.if_cloud_tpu_at_least(2025, 4, 5): + self.skipTest('Needs a newer libTPU') + if jtu.get_tpu_version() < 5: + self.skipTest('Target does not support DMA prefetch between HBM and VMEM') + def kernel(x1, x2, y1, y2, scratch1, scratch2, sem1, sem2): + copy1 = pltpu.async_copy(x1, scratch1, sem1, priority=1) + copy2 = pltpu.async_copy(x2, scratch2, sem2, priority=0) + copy1.wait() + copy2.wait() + copy1 = pltpu.async_copy(scratch1, y1, sem1, priority=0) + copy2 = pltpu.async_copy(scratch2, y2, sem2, priority=1) + copy1.wait() + copy2.wait() + + shape = (8, 128) + dtype = jnp.int32 + x1 = jnp.arange(np.prod(shape), dtype=dtype).reshape(shape) + x2 = x1 + 1 + y1, y2 = self.pallas_call( + kernel, + 
grid_spec=pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=0, + in_specs=[pl.BlockSpec(memory_space=pl.ANY)] * 2, + scratch_shapes=[pltpu.VMEM(shape, dtype)] * 2 + + [pltpu.SemaphoreType.DMA] * 2, + out_specs=[pl.BlockSpec(memory_space=pl.ANY)] * 2, + ), + out_shape=[jax.ShapeDtypeStruct(shape, dtype)] * 2, + )(x1, x2) + np.testing.assert_array_equal(y1, x1) + np.testing.assert_array_equal(y2, x2) + def test_hbm_hbm_dma(self): def kernel(x_hbm_ref, y_hbm_ref): def body(sem): - pltpu.async_copy(x_hbm_ref.at[pl.ds(8), :], y_hbm_ref.at[:, pl.ds(128)], - sem).wait() + pltpu.async_copy(x_hbm_ref.at[:8, :], y_hbm_ref.at[:, :128], sem).wait() pl.run_scoped(body, pltpu.SemaphoreType.DMA) x = jnp.arange(8 * 128.).reshape((8, 128)) y = self.pallas_call( @@ -1193,6 +1224,10 @@ def test_output_dma_semaphore_ref(self): if self.INTERPRET: self.skipTest('TODO(sharadmv, justinfu): Add interpret support for DMA.') + # TODO(subhankarshah): Remove after all required changes are in. + if not jtu.if_cloud_tpu_at_least(2025, 6, 30): + self.skipTest('Requires libtpu built after 2025-06-20') + def kernel(x_hbm_ref, y_hbm_ref, sem_out): pltpu.make_async_copy( x_hbm_ref.at[pl.ds(8), :], y_hbm_ref.at[:, pl.ds(128)], sem_out @@ -1347,7 +1382,7 @@ def body(y_ref, sem): y = self.pallas_call( kernel, in_specs=[ - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM), + pl.BlockSpec(memory_space=pltpu.SMEM), ], out_specs=pl.BlockSpec(memory_space=pl.ANY), out_shape=jax.ShapeDtypeStruct((1, 2), jnp.float32), @@ -1364,9 +1399,9 @@ def body(sem): y = self.pallas_call( kernel, in_specs=[ - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM), + pl.BlockSpec(memory_space=pltpu.VMEM), ], - out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM), + out_specs=pl.BlockSpec(memory_space=pltpu.VMEM), out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32), )(x) np.testing.assert_allclose(y, x) @@ -1389,7 +1424,7 @@ def body(sem): in_specs=[ pl.BlockSpec(memory_space=pl.ANY), ], - out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM), + out_specs=pl.BlockSpec(memory_space=pltpu.VMEM), out_shape=jax.ShapeDtypeStruct((16, 128), jnp.float32), )(x) np.testing.assert_allclose(y, x) @@ -1412,7 +1447,7 @@ def body(sem): in_specs=[ pl.BlockSpec(memory_space=pl.ANY), ], - out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM), + out_specs=pl.BlockSpec(memory_space=pltpu.VMEM), out_shape=jax.ShapeDtypeStruct((16, 128), jnp.float32), )(x) np.testing.assert_allclose(y, x.reshape((16, 128))) @@ -1441,7 +1476,7 @@ def body(sem): in_specs=[ pl.BlockSpec(memory_space=pl.ANY), ], - out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM), + out_specs=pl.BlockSpec(memory_space=pltpu.VMEM), out_shape=jax.ShapeDtypeStruct((3, 16, 128), jnp.float32), )(x) np.testing.assert_allclose(y, x.reshape((3, 16, 128))) @@ -1468,7 +1503,7 @@ def body(sem): in_specs=[ pl.BlockSpec(memory_space=pl.ANY), ], - out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM), + out_specs=pl.BlockSpec(memory_space=pltpu.VMEM), out_shape=jax.ShapeDtypeStruct((16, 128), jnp.float32), )(x) @@ -1517,7 +1552,6 @@ def kernel(y_ref, scratch_ref): out_specs=pl.BlockSpec((None, 8, 128), lambda i: (i, 0, 0)), grid=(2,), ), - debug=True, out_shape=jax.ShapeDtypeStruct((2, 8, 128), jnp.int32), )() expected = jnp.broadcast_to(jnp.arange(2, dtype=jnp.int32)[..., None, None], @@ -1540,12 +1574,13 @@ def kernel(x_bbm_ref, y_ref, sem, dma_sem): ], scratch_shapes=[pltpu.SemaphoreType.REGULAR, pltpu.SemaphoreType.DMA], - 
out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM), + out_specs=pl.BlockSpec(memory_space=pltpu.VMEM), ), out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32), )(x) np.testing.assert_array_equal(y, x) + @jtu.thread_unsafe_test() # Uses a lot of TPU memory. def test_large_array_indexing(self): n = 6 dtype = jnp.bfloat16 @@ -1745,6 +1780,42 @@ def reduce(): reduce_value = jnp.sum(jnp.full(shape, x), dtype=dty) np.testing.assert_allclose(z, reduce_value) + def test_scalar_any_input(self): + if not jtu.is_device_tpu_at_least(4): + self.skipTest("Needs a newer TPU") + if not jtu.if_cloud_tpu_at_least(2025, 5, 1): + self.skipTest("Needs a newer libTPU") + def kernel(src, dst, sem): + pltpu.async_copy(src, dst, sem).wait() + + def run(src): + return pl.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct(src.shape, jnp.float32), + in_specs=[pl.BlockSpec(memory_space=pltpu.ANY)], + scratch_shapes=[pltpu.SemaphoreType.DMA], + out_specs=pl.BlockSpec(memory_space=pltpu.SMEM), + )(src) + x = jnp.full((1,), 3.1415, dtype=jnp.float32) + np.testing.assert_array_equal(run(x), x) + + def test_sum_in_smem(self): + if not jtu.if_cloud_tpu_at_least(2025, 4, 30): + self.skipTest("Needs a newer libTPU") + def kernel(x, out): + a = jnp.array(0, dtype=jnp.int32) + for i in range(4): + for j in range(4): + out[i, j] = a.astype(out.dtype) + a += x[i, j].astype(jnp.int32) + + x = jnp.ones((4, 4), jnp.int16) + spec = pl.BlockSpec(memory_space=pltpu.SMEM) + y = pl.pallas_call(kernel, in_specs=[spec], out_specs=spec, out_shape=x)(x) + np.testing.assert_array_equal( + y, jnp.arange(16, dtype=jnp.int32).reshape(4, 4) + ) + @parameterized.parameters([ dict( m=m, @@ -1842,16 +1913,16 @@ def kernel(x_ref, y_ref): y_ref[...] = x_ref[...] x = jnp.arange(np.prod(shape), dtype=np.float32).reshape(shape) - with self.assertRaises(xla_extension.XlaRuntimeError): + with self.assertRaises(_jax.XlaRuntimeError): self.pallas_call( kernel, out_shape=x, - compiler_params=pltpu.TPUCompilerParams(vmem_limit_bytes=256), + compiler_params=pltpu.CompilerParams(vmem_limit_bytes=256), )(x) self.pallas_call( kernel, out_shape=x, - compiler_params=pltpu.TPUCompilerParams(vmem_limit_bytes=int(2**18)), + compiler_params=pltpu.CompilerParams(vmem_limit_bytes=int(2**18)), )(x) def test_allow_input_fusion(self): @@ -1868,7 +1939,7 @@ def f(x, y): in_specs=[pl.BlockSpec((1, 128, 128), lambda i: (i, 0, 0))], out_specs=pl.BlockSpec((1, 128, 128), lambda i: (i, 0, 0)), out_shape=x, - compiler_params=pltpu.TPUCompilerParams(allow_input_fusion=[True]), + compiler_params=pltpu.CompilerParams(allow_input_fusion=[True]), )(z) x = jnp.arange(np.prod(shape), dtype=np.float32).reshape(shape) @@ -1896,7 +1967,7 @@ def kernel(x_ref, y_ref): self.pallas_call( kernel, out_shape=jax.ShapeDtypeStruct(shape, jnp.float32), - compiler_params=pltpu.TPUCompilerParams( + compiler_params=pltpu.CompilerParams( internal_scratch_in_bytes=requested_bytes, ), )(x) @@ -1950,6 +2021,40 @@ def kernel(x_ref, w_ref, o_ref): mosaic_nans = jnp.isnan(run(x, w)).sum() self.assertEqual(jax_nans, mosaic_nans) + @parameterized.product(in_dtype=[jnp.int4, jnp.int8, jnp.int16, jnp.int32]) + def test_scalar_load_upcast(self, in_dtype): + if not jtu.if_cloud_tpu_at_least(2025, 4, 25): + self.skipTest("Needs a newer libTPU") + if in_dtype == jnp.int4 and not jtu.is_device_tpu_at_least(4): + self.skipTest("Triggers an XLA bug") # TODO(b/413602952) + def kernel(x_ref, o_ref): + o_ref[0, 0] = x_ref[0, 0].astype(o_ref.dtype) + x = jnp.asarray([[-1]], dtype=in_dtype) + y = 
pl.pallas_call( + kernel, + in_specs=[pl.BlockSpec(memory_space=pltpu.SMEM)], + out_specs=pl.BlockSpec(memory_space=pltpu.SMEM), + out_shape=jax.ShapeDtypeStruct((1, 1), jnp.int32), + )(x) + self.assertEqual(y, x.astype(jnp.int32)) + + @parameterized.product(in_dtype=[jnp.int4, jnp.int8, jnp.int16, jnp.int32]) + def test_scalar_indirect_load(self, in_dtype): + if not jtu.if_cloud_tpu_at_least(2025, 4, 27): + self.skipTest("Needs a newer libTPU") + def kernel(x_ref, o_ref): + o_ref[0, 0] = x_ref[0, x_ref[0, 0].astype(jnp.int32)].astype(o_ref.dtype) + if in_dtype == jnp.int4 and not jtu.is_device_tpu_at_least(4): + self.skipTest("Triggers an XLA bug") # TODO(b/413602952) + x = jnp.asarray([[3, 0, 0, 1]], dtype=in_dtype) + y = pl.pallas_call( + kernel, + in_specs=[pl.BlockSpec(memory_space=pltpu.SMEM)], + out_specs=pl.BlockSpec(memory_space=pltpu.SMEM), + out_shape=jax.ShapeDtypeStruct((1, 1), jnp.int32), + )(x) + self.assertEqual(y, x[0, x[0, 0]].astype(jnp.int32)[None, None]) + def test_masked_store(self): shape = (16, 256) mask_shape = (10, 130) @@ -1984,6 +2089,39 @@ def body(scalar_ref, x_ref, o_ref): expected = expected.at[slices].set(x[slices]) np.testing.assert_array_equal(out, expected) + def test_custom_vjp(self): + + @jax.custom_vjp + def f(x): + return jnp.tanh(x) + def f_fwd(x): + return jnp.tanh(x) * 2, () + def f_bwd(_, g): + return (g * 2,) + + f.defvjp(f_fwd, f_bwd) + + def kernel(x_ref, dy_ref, y_ref, y_p_ref, dx_ref): + x = x_ref[...] + y_ref[...] = f(x) + y_p, f_vjp = jax.vjp(f, x) + y_p_ref[...] = y_p + dx_ref[...] = f_vjp(dy_ref[...])[0] + + x = jax.random.normal(jax.random.key(0), (8, 128), dtype=jnp.float32) + dy = jax.random.normal(jax.random.key(1), (8, 128), dtype=jnp.float32) + y, y_p, dx = pl.pallas_call( + kernel, + out_shape=( + jax.ShapeDtypeStruct((8, 128), jnp.float32), + jax.ShapeDtypeStruct((8, 128), jnp.float32), + jax.ShapeDtypeStruct((8, 128), jnp.float32), + ), + )(x, dy) + np.testing.assert_array_equal(y, f(x)) + np.testing.assert_array_equal(y_p, f(x) * 2) + np.testing.assert_array_equal(dx, dy * 2) + class PallasUXTest(PallasBaseTest): @@ -2031,7 +2169,6 @@ def _(): pl.BlockSpec((128, 128), lambda i, j, k: (k, j)), ], out_specs=pl.BlockSpec((128, 128), lambda i, j, k: (i, j)), - debug=True, ) ) )(x, y) @@ -2300,6 +2437,7 @@ def kernel(x_ref, y_ref): np.testing.assert_array_equal(y, x[8:16, :128]) +@jtu.thread_unsafe_test_class() # debug print test is not thread safe class PallasCallPrintTest(PallasBaseTest): def test_debug_print(self): @@ -2383,6 +2521,7 @@ def kernel(x_ref, o_ref): class PallasCallTraceTest(PallasBaseTest): + @jtu.thread_unsafe_test() # stdout redirection is not thread safe def test_trace_start_stop_match(self): def kernel(o_ref): with jax.named_scope('scope1'): @@ -2402,6 +2541,7 @@ def kernel(o_ref): self.assertEqual(num_start, 1) self.assertEqual(num_stop, 1) + @jtu.thread_unsafe_test() # stdout redirection is not thread safe def test_run_scoped(self): def kernel(o_ref): def scope1(): @@ -2527,8 +2667,8 @@ def kernel(x_ref, o_ref, send_sem, recv_sem): output_shape = jax.ShapeDtypeStruct((8, 128), jnp.bool_) grid_spec = pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=0, - in_specs=[pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM)], - out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM), + in_specs=[pl.BlockSpec(memory_space=pltpu.VMEM)], + out_specs=pl.BlockSpec(memory_space=pltpu.VMEM), grid=(1,), scratch_shapes=[pltpu.SemaphoreType.DMA] * 2, ) @@ -2549,7 +2689,7 @@ def kernel(x_ref, o_ref, send_sem, recv_sem): 
mesh=mesh, in_specs=P(None, 'x'), out_specs=P(None, 'x'), - check_rep=False + check_vma=False ) )(input_arr) @@ -2570,8 +2710,7 @@ def body(scalar_ref, x_ref, o_ref): x = jnp.arange(8 * 8 * 128, dtype=jnp.int32).reshape((8 * 8, 128)) def _x_transform(i, s_ref): - s = pl.load(s_ref, (i,)) - return (s, 0) + return (s_ref[i], 0) pallas_call = self.pallas_call( body, @@ -2668,19 +2807,19 @@ class PrettyPrintingTest(PallasBaseTest): @parameterized.parameters( ( lambda i: (i, pl.ds(0, 8), pl.ds(0, 128)), - 'dma_start c[d,:,:] -> e[...] f', + 'dma_start(p0) c[d,:,:] -> e[...] f', ), ( lambda i: (0, pl.ds(i, 8), pl.ds(0, 128)), - 'dma_start c[0,d:d+8,:] -> e[...] f', + 'dma_start(p0) c[0,d:d+8,:] -> e[...] f', ), ( lambda i: (i, pl.ds(2, 4), pl.ds(0, 100)), - 'dma_start c[d,2:6,:100] -> e[...] f', + 'dma_start(p0) c[d,2:6,:100] -> e[...] f', ), ( lambda i: (i, pl.ds(2, 6), pl.ds(4, 100)), - 'dma_start c[d,2:,4:104] -> e[...] f', + 'dma_start(p0) c[d,2:,4:104] -> e[...] f', ), ) def test_dma_custom_pretty_print(self, indexer, expected): @@ -2793,9 +2932,9 @@ def kernel(x_ref, out_ref): )(x) np.testing.assert_array_equal(out, state_utils.bitcast(x, jnp.uint32)) - @only_passes_in_interpret() - def test_roll_partial(self): - """b/337384645""" + def test_roll_partial_with_static_shift(self): + if not jtu.if_cloud_tpu_at_least(2025, 5, 15): + self.skipTest('Needs a newer libtpu') x = np.arange(8192, dtype=jnp.float32).reshape(128, 64) def kernel(x_ref, out_ref): @@ -2806,6 +2945,22 @@ def kernel(x_ref, out_ref): )(x) np.testing.assert_array_equal(out, np.roll(x, 3, 1)) + def test_roll_partial_with_dynamic_shift(self): + if not jtu.if_cloud_tpu_at_least(2025, 5, 15): + self.skipTest('Needs a newer libtpu') + if self.INTERPRET: + self.skipTest('Test only applies to non-interpret mode.') + x = np.arange(8192, dtype=jnp.float32).reshape(128, 64) + + def kernel(x_ref, out_ref): + amount = x_ref[0, 0].astype(jnp.int32) + out_ref[...] = pltpu.roll(x_ref[...], amount, 1) + + with self.assertRaisesRegex(Exception, 'unsupported unaligned shape'): + _ = self.pallas_call( + kernel, out_shape=jax.ShapeDtypeStruct((128, 64), jnp.float32) + )(x) + @only_passes_in_interpret() def test_retiling1(self): """b/352626602""" @@ -2904,6 +3059,231 @@ def kernel(x_ref, out_ref): out, np.zeros((8, 8, 2, 128), dtype=jnp.float32) ) + # (q, m, n) -> (q, m * n) where n % 128 == 0 + @parameterized.parameters( + (32, 16, 512, jnp.float32), + (24, 1, 512, jnp.uint32), + (3, 3, 256, jnp.uint32), + (9, 15, 256, jnp.float32), + (3, 2, 256, jnp.float32), + ) + def test_reshape_two_minor_dims_to_R2(self, q, m, n, dtype): + if not jtu.if_cloud_tpu_at_least(2025, 5, 23): + self.skipTest('Needs a newer libTPU') + def kernel(x_ref, y_ref): + y_ref[...] = x_ref[...].reshape( + x_ref.shape[0], x_ref.shape[1] * x_ref.shape[2] + ) + + x = np.arange(q * m * n, dtype=dtype).reshape(q, m, n) + out = self.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct((q, m * n), dtype), + )(x) + np.testing.assert_array_equal(out, x.reshape([q, m * n])) + + # (q, m, n, k) -> (q, m, n * k) where k % 128 == 0 + @parameterized.parameters( + (3, 8, 17, 512, jnp.float32), + (1, 8, 9, 256, jnp.float32), + (1, 8, 3, 256, jnp.uint32), + (10, 1, 4, 256, jnp.uint32), + (1, 2, 2, 256, jnp.float32), + ) + def test_reshape_two_minor_dims_to_R3(self, q, m, n, k, dtype): + if not jtu.if_cloud_tpu_at_least(2025, 5, 23): + self.skipTest('Needs a newer libTPU') + def kernel(x_ref, y_ref): + y_ref[...] 
= x_ref[...].reshape( + x_ref.shape[0], x_ref.shape[1], x_ref.shape[2] * x_ref.shape[3] + ) + + x = np.arange(q * m * n * k, dtype=dtype).reshape(q, m, n, k) + out = self.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct((q, m, n * k), dtype), + )(x) + np.testing.assert_array_equal(out, x.reshape([q, m, n * k])) + + # (p, q, m, n, k) -> (p, q * m * n * k) where k % 128 == 0 + @parameterized.parameters( + (5, 3, 8, 17, 512, jnp.float32), + (6, 1, 8, 9, 256, jnp.float32), + (16, 1, 8, 3, 256, jnp.uint32), + (3, 2, 1, 4, 256, jnp.uint32), + (1, 7, 2, 2, 256, jnp.float32), + ) + def test_reshape_four_minor_dims_to_R2(self, p, q, m, n, k, dtype): + if not jtu.if_cloud_tpu_at_least(2025, 5, 23): + self.skipTest('Needs a newer libTPU') + def kernel(x_ref, y_ref): + y_ref[...] = x_ref[...].reshape( + x_ref.shape[0], + x_ref.shape[1] * x_ref.shape[2] * x_ref.shape[3] * x_ref.shape[4], + ) + + x = np.arange(p * q * m * n * k, dtype=dtype).reshape(p, q, m, n, k) + out = self.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct((p, q * m * n * k), dtype), + )(x) + np.testing.assert_array_equal(out, x.reshape([p, q * m * n * k])) + + # (q, m, n, k) -> (q, m, 1, n * k) where k % 128 == 0 + def test_reshape_two_minor_dims_preserve_rank(self): + if not jtu.if_cloud_tpu_at_least(2025, 5, 23): + self.skipTest('Needs a newer libTPU') + def kernel(x_ref, y_ref): + y_ref[...] = ( + x_ref[...] + .reshape( + x_ref.shape[0], x_ref.shape[1], x_ref.shape[2] * x_ref.shape[3] + ) + .reshape( + x_ref.shape[0], 1, x_ref.shape[1], x_ref.shape[2] * x_ref.shape[3] + ) + ) + + q, m, n, k = 10, 1, 4, 256 + x = np.arange(q * m * n * k, dtype=jnp.float32).reshape(q, m, n, k) + out = self.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct((q, m, 1, n * k), jnp.float32), + )(x) + np.testing.assert_array_equal(out, x.reshape([q, m, 1, n * k])) + + # (q, m, n, k) -> (q * m, n * k) where k % 128 == 0 + @parameterized.parameters( + (3, 8, 17, 512, jnp.float32), + (1, 8, 9, 256, jnp.float32), + (1, 8, 3, 256, jnp.uint32), + (10, 1, 4, 256, jnp.uint32), + (1, 2, 2, 256, jnp.float32), + ) + def test_reshape_fold_two_leading_dims_and_two_minor_dims_R4_to_R2( + self, q, m, n, k, dtype + ): + if not jtu.if_cloud_tpu_at_least(2025, 5, 23): + self.skipTest('Needs a newer libTPU') + def kernel(x_ref, y_ref): + y_ref[...] = x_ref[...].reshape( + x_ref.shape[0] * x_ref.shape[1], x_ref.shape[2] * x_ref.shape[3] + ) + + x = np.arange(q * m * n * k, dtype=dtype).reshape(q, m, n, k) + out = self.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct((q * m, n * k), dtype), + )(x) + np.testing.assert_array_equal(out, x.reshape([q * m, n * k])) + + # (q * m, n, k) -> (q, m, n * k) where k % 128 == 0 + @parameterized.parameters( + (2, 2, 17, 512, jnp.float32), + (3, 2, 3, 256, jnp.float32), + (1, 5, 4, 384, jnp.uint32), + ) + def test_reshape_unfold_leading_dim_and_fold_two_minor_dims_R3_to_R3( + self, q, m, n, k, dtype + ): + if not jtu.if_cloud_tpu_at_least(2025, 5, 23): + self.skipTest('Needs a newer libTPU') + def kernel(x_ref, y_ref): + y_ref[...] 
= x_ref[...].reshape( + q, + m, + x_ref.shape[1] * x_ref.shape[2], + ) + + x = np.arange(q * m * n * k, dtype=dtype).reshape(q * m, n, k) + out = self.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct((q, m, n * k), dtype), + )(x) + np.testing.assert_array_equal(out, x.reshape([q, m, n * k])) + + # (q * m, n * k) -> (q, m, n, k) where k % 128 == 0 + @parameterized.parameters( + (2, 2, 17, 512, jnp.float32), + (3, 2, 3, 256, jnp.float32), + (1, 5, 4, 384, jnp.uint32), + ) + def test_reshape_unfold_leading_and_minor_dims_R2_to_R4( + self, q, m, n, k, dtype + ): + if not jtu.if_cloud_tpu_at_least(2025, 5, 23): + self.skipTest('Needs a newer libTPU') + def kernel(x_ref, y_ref): + y_ref[...] = x_ref[...].reshape(q, m, n, k) + + x = np.arange(q * m * n * k, dtype=dtype).reshape(q * m, n * k) + out = self.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct((q, m, n, k), dtype), + )(x) + np.testing.assert_array_equal(out, x.reshape([q, m, n, k])) + + # (q, m, n * k) -> (q * m, n, k) where k % 128 == 0 + @parameterized.parameters( + (2, 2, 17, 512, jnp.float32), + (3, 2, 8, 256, jnp.float32), + (1, 5, 4, 384, jnp.uint32), + ) + def test_reshape_fold_leading_dims_and_unfold_minor_dim( + self, q, m, n, k, dtype + ): + if not jtu.if_cloud_tpu_at_least(2025, 5, 23): + self.skipTest('Needs a newer libTPU') + def kernel(x_ref, y_ref): + y_ref[...] = x_ref[...].reshape(q * m, n, k) + + x = np.arange(q * m * n * k, dtype=dtype).reshape(q, m, n * k) + out = self.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct((q * m, n, k), dtype), + )(x) + np.testing.assert_array_equal(out, x.reshape([q * m, n, k])) + + # (q, m, n, k) -> (q, m * n, k) where k % 128 == 0 + @parameterized.parameters( + (2, 2, 17, 512, jnp.float32), + (3, 2, 8, 256, jnp.float32), + (1, 5, 4, 384, jnp.uint32), + ) + def test_reshape_fold_middle_dims(self, q, m, n, k, dtype): + if not jtu.if_cloud_tpu_at_least(2025, 5, 23): + self.skipTest('Needs a newer libTPU') + + def kernel(x_ref, y_ref): + y_ref[...] = x_ref[...].reshape(q, m * n, k) + + x = np.arange(q * m * n * k, dtype=dtype).reshape(q, m, n, k) + out = self.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct((q, m * n, k), dtype), + )(x) + np.testing.assert_array_equal(out, x.reshape([q, m * n, k])) + + # (q, m * n, k) -> (q, m, n, k) where k % 128 == 0 + @parameterized.parameters( + (2, 2, 17, 512, jnp.float32), + (3, 2, 8, 256, jnp.float32), + (1, 5, 4, 384, jnp.uint32), + ) + def test_reshape_unfold_middle_dims(self, q, m, n, k, dtype): + if not jtu.if_cloud_tpu_at_least(2025, 5, 23): + self.skipTest('Needs a newer libTPU') + + def kernel(x_ref, y_ref): + y_ref[...] = x_ref[...].reshape(q, m, n, k) + + x = np.arange(q * m * n * k, dtype=dtype).reshape(q, m * n, k) + out = self.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct((q, m, n, k), dtype), + )(x) + np.testing.assert_array_equal(out, x.reshape([q, m, n, k])) + class MiscellaneousInterpretTest(MiscellaneousTest): INTERPRET: bool = True diff --git a/tests/pallas/tpu_ragged_paged_attention_test.py b/tests/pallas/tpu_ragged_paged_attention_test.py index bffcebc5254b..eebc292ce3ab 100644 --- a/tests/pallas/tpu_ragged_paged_attention_test.py +++ b/tests/pallas/tpu_ragged_paged_attention_test.py @@ -13,14 +13,17 @@ # limitations under the License. 
import random + from absl.testing import absltest from absl.testing import parameterized import jax +from jax._src import dtypes from jax._src import test_util as jtu from jax.experimental.pallas.ops.tpu.ragged_paged_attention import ( + cdiv, + dynamic_validate_inputs, ragged_paged_attention, ref_ragged_paged_attention, - validate_inputs_on_runtime, ) import jax.numpy as jnp @@ -28,13 +31,8 @@ jax.config.parse_flags_with_absl() -def ceil_div(x, a): - assert a != 0 - return (x + a - 1) // a - - @jtu.with_config(jax_numpy_dtype_promotion="standard") -class PagedAttentionKernelTest(jtu.JaxTestCase): +class RaggedPagedAttentionKernelTest(jtu.JaxTestCase): def _test_ragged_paged_attention( self, @@ -42,7 +40,8 @@ def _test_ragged_paged_attention( num_heads, # [num_q_heads, num_kv_heads] head_dim, page_size, - dtype, + q_dtype, + kv_dtype, num_pages, *, num_kv_pages_per_block=8, @@ -50,6 +49,10 @@ def _test_ragged_paged_attention( vmem_limit_bytes=32 * 1024 * 1024, max_num_batched_tokens=512, max_num_seq=8, + sliding_window: int | None = None, + soft_cap: float | None = None, + k_scale: float | None = None, + v_scale: float | None = None, ): if not jtu.is_device_tpu_at_least(version=4): self.skipTest("Expect TPUv4+") @@ -63,73 +66,116 @@ def _test_ragged_paged_attention( max_num_batched_tokens = max(cu_q_lens[-1], max_num_batched_tokens) max_num_seq = max(len(seq_lens), max_num_seq) max_kv_len = max(kv_lens) - pages_per_seq = ceil_div(max_kv_len, page_size) + pages_per_seq = cdiv(max_kv_len, page_size) num_q_heads, num_kv_heads = num_heads - cu_q_lens = jnp.array(cu_q_lens, dtype=jnp.int32) - kv_lens = jnp.array(kv_lens, dtype=jnp.int32) - cu_q_lens = jnp.pad(cu_q_lens, (0, max_num_seq + 1 - cu_q_lens.shape[0])) - kv_lens = jnp.pad(kv_lens, (0, max_num_seq - kv_lens.shape[0])) prng_key = jax.random.key(1234) - k0, k1, k2, k3 = jax.random.split(prng_key, 4) + k0, k1 = jax.random.split(prng_key, 2) q = jax.random.normal( k0, (max_num_batched_tokens, num_q_heads, head_dim), - dtype=dtype, + dtype=q_dtype, ) - k_pages = jax.random.normal( - k1, - (num_pages, page_size, num_kv_heads, head_dim), - dtype=dtype, - ) - v_pages = jax.random.normal( - k2, - (num_pages, page_size, num_kv_heads, head_dim), - dtype=dtype, + page_cnt = 0 + page_indices_list = [] + kv_pages_list = [] + for kv_len in kv_lens: + if jnp.issubdtype(kv_dtype, jnp.integer): + # random.randint doesn't support int4, so we use jnp.int32 here and then + # convert to the desired dtype. 
+ kv = jax.random.normal( + k1, + (kv_len, num_kv_heads * 2, head_dim), + dtype=jnp.int32, + ) + kv = kv.astype(kv_dtype) + else: + kv = jax.random.normal( + k1, + (kv_len, num_kv_heads * 2, head_dim), + dtype=kv_dtype, + ) + kv = jnp.pad( + kv, + ((0, cdiv(kv_len, page_size) * page_size - kv_len), (0, 0), (0, 0)), + constant_values=jnp.nan, + ).reshape(-1, page_size, num_kv_heads * 2, head_dim) + indices = page_cnt + jnp.arange(kv.shape[0], dtype=jnp.int32) + indices = jnp.pad( + indices, + ((0, pages_per_seq - indices.shape[0]),), + constant_values=jnp.nan, + ) + page_indices_list.append(indices) + page_cnt += kv.shape[0] + kv_pages_list.append(kv) + + kv_pages = jnp.concatenate(kv_pages_list, axis=0) + kv_pages = jnp.pad( + kv_pages, + ((0, num_pages - kv_pages.shape[0]), (0, 0), (0, 0), (0, 0)), + constant_values=jnp.nan, ) - page_indices = jax.random.randint( - k3, (max_num_seq, pages_per_seq), 0, num_pages, dtype=jnp.int32 + page_indices = jnp.stack(page_indices_list, axis=0) + page_indices = jnp.pad( + page_indices, + ((0, max_num_seq - page_indices.shape[0]), (0, 0)), + constant_values=jnp.nan, ) - + cu_q_lens = jnp.array(cu_q_lens, dtype=jnp.int32) + cu_q_lens = jnp.pad(cu_q_lens, (0, max_num_seq + 1 - cu_q_lens.shape[0])) + kv_lens = jnp.array(kv_lens, dtype=jnp.int32) + kv_lens = jnp.pad(kv_lens, (0, max_num_seq - kv_lens.shape[0])) num_seqs = jnp.array([len(seq_lens)], dtype=jnp.int32) - validate_inputs_on_runtime( + dynamic_validate_inputs( q, - k_pages, - v_pages, + kv_pages, kv_lens, page_indices, cu_q_lens, num_seqs, + sliding_window=sliding_window, + soft_cap=soft_cap, ) + actual_num_q_tokens = cu_q_lens[num_seqs[0]] output = ragged_paged_attention( q, - k_pages, - v_pages, + kv_pages, kv_lens, page_indices, cu_q_lens, num_seqs=num_seqs, - num_kv_pages_per_block=num_kv_pages_per_block, + num_kv_pages_per_block=min(num_kv_pages_per_block, pages_per_seq), num_queries_per_block=num_queries_per_block, vmem_limit_bytes=vmem_limit_bytes, - )[: cu_q_lens[num_seqs[0]]] + sliding_window=sliding_window, + soft_cap=soft_cap, + k_scale=k_scale, + v_scale=v_scale, + )[:actual_num_q_tokens] expected = ref_ragged_paged_attention( q, - k_pages, - v_pages, + kv_pages, kv_lens, page_indices, cu_q_lens, num_seqs=num_seqs, + sliding_window=sliding_window, + soft_cap=soft_cap, + k_scale=k_scale, + v_scale=v_scale, ) + dtype_bits = dtypes.bit_width(jnp.dtype(kv_dtype)) tols = { - "float32": 0.15, - "bfloat16": 0.2, + 32: 0.15, + 16: 0.2, + 8: 0.2, + 4: 0.2, } - tol = tols[jnp.dtype(dtype).name] + tol = tols[dtype_bits] self.assertAllClose(output, expected, atol=tol, rtol=tol) @parameterized.product( @@ -148,9 +194,40 @@ def test_ragged_paged_attention_basic(self, dtype): head_dim, page_size, dtype, + dtype, num_pages, ) + # TODO: support int4 and int8 + @parameterized.product( + q_dtype=[jnp.bfloat16], + kv_dtype=[jnp.float8_e5m2, jnp.float8_e4m3fn], + kv_scales=[(0.5, 0.5), (None, None)], + ) + def test_ragged_paged_attention_quantized_kv_cache( + self, q_dtype, kv_dtype, kv_scales + ): + if not jtu.is_device_tpu_at_least(version=5): + self.skipTest("Expect TPUv5+") + seq_lens = [(192, 328), (128, 180), (64, 255)] + num_heads = (32, 8) + head_dim = 128 + page_size = 16 + num_pages = 1000 + k_scale, v_scale = kv_scales + + self._test_ragged_paged_attention( + seq_lens, + num_heads, + head_dim, + page_size, + q_dtype, + kv_dtype, + num_pages, + k_scale=k_scale, + v_scale=v_scale, + ) + @parameterized.product( dtype=[jnp.float32, jnp.bfloat16], ) @@ -184,6 +261,7 @@ def 
test_ragged_paged_attention_decode_only(self, dtype): head_dim, page_size, dtype, + dtype, num_pages, ) @@ -220,6 +298,7 @@ def test_ragged_paged_attention_prefill_only(self, dtype): head_dim, page_size, dtype, + dtype, num_pages, ) @@ -256,13 +335,14 @@ def test_ragged_paged_attention_mixed(self, dtype): head_dim, page_size, dtype, + dtype, num_pages, ) @parameterized.product( num_seqs=[1, 5, 16], # TODO(jevinjiang): Support more num_heads! - num_heads=[(32, 8), (32, 16), (12, 2), (4, 4)], + num_heads=[(32, 8), (32, 16), (12, 2), (4, 4), (8, 1)], dtype=[jnp.float32, jnp.bfloat16], num_kv_pages_per_block=[4, 8], num_queries_per_block=[32, 64], @@ -291,11 +371,137 @@ def test_ragged_paged_attention_complex( head_dim, page_size, dtype, + dtype, num_pages, num_kv_pages_per_block=num_kv_pages_per_block, num_queries_per_block=num_queries_per_block, ) + @parameterized.product( + num_kv_pages_per_block=[4, 8], + num_queries_per_block=[32, 64], + sliding_window=[None, 5, 128], + ) + def test_ragged_paged_attention_sliding_window( + self, + num_kv_pages_per_block, + num_queries_per_block, + sliding_window: int | None, + ): + num_seqs = 5 + num_heads = (4, 4) + dtype = jnp.float32 + seq_lens = [] + for _ in range(num_seqs): + q_len = random.randint(1, 100) + kv_len = q_len + random.randint(0, 50) + seq_lens.append((q_len, kv_len)) + # TODO(jevinjiang): Support non-128 head_dim! + head_dim = 128 + page_size = 16 + num_pages = 1000 + + self._test_ragged_paged_attention( + seq_lens, + num_heads, + head_dim, + page_size, + dtype, + dtype, + num_pages, + num_kv_pages_per_block=num_kv_pages_per_block, + num_queries_per_block=num_queries_per_block, + sliding_window=sliding_window, + ) + + @parameterized.product( + num_kv_pages_per_block=[4, 8], + num_queries_per_block=[32, 64], + soft_cap=[None, 50.0], + ) + def test_ragged_paged_attention_logit_soft_capping( + self, + num_kv_pages_per_block, + num_queries_per_block, + soft_cap: float | None, + ): + num_heads = (12, 2) + num_seqs = 2 + dtype = jnp.float32 + seq_lens = [] + for _ in range(num_seqs): + q_len = random.randint(1, 100) + kv_len = q_len + random.randint(0, 50) + seq_lens.append((q_len, kv_len)) + head_dim = 128 + page_size = 16 + num_pages = 1000 + + self._test_ragged_paged_attention( + seq_lens, + num_heads, + head_dim, + page_size, + dtype, + dtype, + num_pages, + num_kv_pages_per_block=num_kv_pages_per_block, + num_queries_per_block=num_queries_per_block, + soft_cap=soft_cap, + ) + + def test_ragged_paged_attention_sliding_window_should_be_positive(self): + dtype = jnp.float32 + seq_lens = [(192, 328), (128, 180), (64, 255)] + num_heads = (32, 8) + head_dim = 128 + page_size = 16 + num_pages = 1000 + + with self.assertRaisesRegex(ValueError, "must be positive"): + self._test_ragged_paged_attention( + seq_lens, + num_heads, + head_dim, + page_size, + dtype, + dtype, + num_pages, + sliding_window=0, + ) + + with self.assertRaisesRegex(ValueError, "must be positive"): + self._test_ragged_paged_attention( + seq_lens, + num_heads, + head_dim, + page_size, + dtype, + dtype, + num_pages, + sliding_window=-1, + ) + + def test_ragged_paged_attention_soft_cap_cannot_be_zero(self): + dtype = jnp.float32 + seq_lens = [(192, 328), (128, 180), (64, 255)] + num_heads = (32, 8) + head_dim = 128 + page_size = 16 + num_pages = 1000 + + with self.assertRaisesRegex(ValueError, "must not be 0.0"): + self._test_ragged_paged_attention( + seq_lens, + num_heads, + head_dim, + page_size, + dtype, + dtype, + num_pages, + soft_cap=0.0, + ) + if __name__ == "__main__": 
absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/pallas/tpu_splash_attention_kernel_sharded_test.py b/tests/pallas/tpu_splash_attention_kernel_sharded_test.py new file mode 100644 index 000000000000..9edd425f24dd --- /dev/null +++ b/tests/pallas/tpu_splash_attention_kernel_sharded_test.py @@ -0,0 +1,223 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for partitioning splash_attention.""" + +import functools +import math +from absl.testing import absltest, parameterized +import jax +from jax import random +from jax._src import test_util as jtu +from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel as splash +from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_mask as mask_lib +from jax._src.shard_map import shard_map +import jax.numpy as jnp +from jax.sharding import PartitionSpec +import numpy as np + +partial = functools.partial + +jax.config.parse_flags_with_absl() + + +@jtu.with_config(jax_traceback_filtering="off") +class PallasBaseTest(jtu.JaxTestCase): + INTERPRET = False + + def setUp(self): + super().setUp() + if not jtu.is_device_tpu(): + self.skipTest("Test requires TPU.") + + if len(jax.devices()) < 4: + self.skipTest("This test requires at least 4 devices.") + + def _assert_allclose(self, x, y, **kwargs): + if x.dtype == np.dtype(jnp.bfloat16): + x = x.astype(np.float32) + if y.dtype == np.dtype(jnp.bfloat16): + y = y.astype(np.float32) + self.assertEqual(x.dtype, y.dtype) + self.assertTupleEqual(x.shape, y.shape) + np.testing.assert_allclose(x, y, **kwargs) + + +def generate_mask(shape, num_heads, seed) -> np.ndarray: + assert num_heads >= 2 + assert shape > (64, 64) + + masks = [ + mask_lib.make_causal_mask(shape), + mask_lib.make_local_attention_mask(shape, window_size=(64, 64)), + ] + masks += [mask_lib.make_random_mask(shape, 0.8, seed)] * (num_heads - 2) + return np.stack(masks, axis=0) + + +class SplashAttentionShardingTest(PallasBaseTest): + + @parameterized.product( + topology=[(1, 1), (2, 1), (2, 2), (1, 2), (1, 4), (4, 1)], + num_heads=[2, 4, 16], + dtype=[jnp.bfloat16], + is_dynamic_mask=[False, True], + ) + def test_dynamic_mask_manual_partitioning_mha( + self, topology, num_heads, dtype, is_dynamic_mask + ): + k1, k2, k3 = random.split(random.key(0), 3) + seq_len = 1024 + head_dim = 128 + + head_shards, q_seq_shards = topology + num_devices = math.prod(topology) + + if head_shards > num_heads: + self.skipTest( + f"This test requires {num_heads} heads, but has only" + f" {head_shards} head shards available." + ) + + if len(jax.devices()) < num_devices: + self.skipTest( + f"This test requires {num_devices} devices, but has only" + f" {len(jax.devices())} devices available." 
+ ) + + q = random.uniform(k1, (num_heads, seq_len, head_dim), dtype=dtype) + k = random.uniform(k2, (num_heads, seq_len, head_dim), dtype=dtype) + v = random.uniform(k3, (num_heads, seq_len, head_dim), dtype=dtype) + + mask = generate_mask((seq_len, seq_len), num_heads, seed=0) + if is_dynamic_mask: + mask = jnp.array(mask) + + devices = np.asarray(jax.devices()[:num_devices]).reshape( + head_shards, q_seq_shards + ) + + mesh = jax.sharding.Mesh(devices, ("heads", "q_seq")) + q_spec = PartitionSpec( + "heads" if head_shards > 1 else None, + "q_seq" if q_seq_shards > 1 else None, + ) + kv_spec = PartitionSpec("heads" if head_shards > 1 else None, None) + kernel = splash.make_splash_mha( + mask, head_shards=head_shards, q_seq_shards=q_seq_shards + ) + kernel_spec = kernel.manual_sharding_spec( + jax.sharding.NamedSharding(mesh, q_spec) + ) + + @partial( + shard_map, + mesh=mesh, + in_specs=( + kernel_spec, + q_spec, + kv_spec, + kv_spec, + ), + out_specs=q_spec, + check_vma=False, + ) + def f(kernel, q, k, v): + return kernel(q, k, v) + + out = f(kernel, q, k, v) + out_ref = jax.vmap(splash.attention_reference)(mask, q, k, v, None) + self._assert_allclose(out, out_ref, rtol=3e-3, atol=3e-3) + + @parameterized.product( + topology=[(1, 1), (2, 1), (2, 2), (1, 2), (1, 4), (4, 1)], + num_heads=[2, 4], + dtype=[jnp.bfloat16], + is_dynamic_mask=[False, True], + ) + def test_dynamic_mask_manual_partitioning_mha_bwd( + self, topology, num_heads, dtype, is_dynamic_mask + ): + assert num_heads % 2 == 0 + k1, k2, k3, k4 = random.split(random.key(0), 4) + seq_len = 1024 + head_dim = 128 + + head_shards, q_seq_shards = topology + num_devices = math.prod(topology) + + if head_shards > num_heads: + self.skipTest( + f"This test requires {num_heads} heads, but has only" + f" {head_shards} head shards available." 
+ ) + + q = random.uniform(k1, (num_heads, seq_len, head_dim), dtype=dtype) + k = random.uniform(k2, (num_heads, seq_len, head_dim), dtype=dtype) + v = random.uniform(k3, (num_heads, seq_len, head_dim), dtype=dtype) + + mask = generate_mask((seq_len, seq_len), num_heads, seed=0) + if is_dynamic_mask: + mask = jnp.array(mask) + + devices = np.asarray(jax.devices()[:num_devices]).reshape( + head_shards, q_seq_shards + ) + + mesh = jax.sharding.Mesh(devices, ("heads", "q_seq")) + q_spec = PartitionSpec( + "heads" if head_shards > 1 else None, + "q_seq" if q_seq_shards > 1 else None, + ) + kv_spec = PartitionSpec("heads" if head_shards > 1 else None, None) + + kernel = splash.make_splash_mha( + mask, head_shards=head_shards, q_seq_shards=q_seq_shards + ) + kernel_spec = kernel.manual_sharding_spec( + jax.sharding.NamedSharding(mesh, q_spec) + ) + + @partial( + shard_map, + mesh=mesh, + in_specs=( + kernel_spec, + q_spec, + kv_spec, + kv_spec, + ), + out_specs=q_spec, + check_vma=False, + ) + def f(kernel, q, k, v): + return kernel(q, k, v) + + f_ref = jax.vmap(splash.attention_reference) + + out, out_vjp = jax.vjp(f, kernel, q, k, v) + out_ref, out_vjp_ref = jax.vjp(f_ref, mask, q, k, v, None) + self._assert_allclose(out, out_ref, rtol=3e-3, atol=3e-3) + + do = random.uniform(k4, out.shape, dtype=out.dtype) + _, dq, dk, dv = out_vjp(do) + _, dq_ref, dk_ref, dv_ref, _ = out_vjp_ref(do.astype(jnp.float32)) + + self.assertAllClose(dq, dq_ref, atol=5e-2) + self.assertAllClose(dk, dk_ref, atol=5e-2) + self.assertAllClose(dv, dv_ref, atol=5e-2) + + +if __name__ == "__main__": + absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/pallas/tpu_splash_attention_kernel_test.py b/tests/pallas/tpu_splash_attention_kernel_test.py index dfe0bcc0da3b..a494a62745d1 100644 --- a/tests/pallas/tpu_splash_attention_kernel_test.py +++ b/tests/pallas/tpu_splash_attention_kernel_test.py @@ -32,11 +32,9 @@ import numpy as np -try: - import hypothesis as hp - import hypothesis.strategies as hps -except (ModuleNotFoundError, ImportError): - raise unittest.SkipTest("these tests require hypothesis") +import hypothesis as hp +import hypothesis.strategies as hps + jax.config.parse_flags_with_absl() jtu.setup_hypothesis(max_examples=5) @@ -303,14 +301,6 @@ def attn_logits_soft_cap_strategy() -> hps.SearchStrategy[float | None]: return hps.one_of(hps.just(None), hps.floats(min_value=1.0, max_value=50.0)) -def to_dynamic_mask(mask: mask_lib.MultiHeadMask) -> jax.Array: - q_seq_len, kv_seq_len = mask.masks[0].shape - full_mask_slice = (slice(0, q_seq_len), slice(0, kv_seq_len)) - dynamic_mask = jnp.stack([m[full_mask_slice] for m in mask.masks], axis=0) - - return dynamic_mask - - @jtu.with_config(jax_traceback_filtering="off") class PallasBaseTest(jtu.JaxTestCase): INTERPRET = False @@ -337,6 +327,7 @@ def _assert_allclose(self, x, y, **kwargs): np.testing.assert_allclose(x, y, **kwargs) +@jtu.thread_unsafe_test_class() # hypothesis is not thread safe class SplashAttentionTest(PallasBaseTest): @parameterized.product( is_mqa=(False, True), @@ -384,7 +375,7 @@ def test_splash_attention(self, is_mqa, is_segmented, is_dynamic_mask, data): masks = data.draw(mha_mask_strategy(q_seq_len, kv_seq_len, num_q_heads)) mask = mask_lib.MultiHeadMask(tuple(m.get_mask() for m in masks)) if is_dynamic_mask: - mask = to_dynamic_mask(mask) + mask = jnp.array(mask[:, :, :]) block_sizes = data.draw(block_sizes_strategy(q_seq_len, kv_seq_len)) if is_mqa: @@ -460,7 +451,7 @@ def test_splash_attention_fwd( masks = 
data.draw(mha_mask_strategy(q_seq_len, kv_seq_len, num_q_heads)) mask = mask_lib.MultiHeadMask(tuple(m.get_mask() for m in masks)) if is_dynamic_mask: - mask = to_dynamic_mask(mask) + mask = jnp.array(mask[:, :, :]) block_sizes = data.draw(block_sizes_strategy(q_seq_len, kv_seq_len)) if is_mqa: attn_ref = splash.make_masked_mqa_reference(mask) @@ -522,9 +513,9 @@ def test_splash_attention_custom_bwd(self, is_segmented, data): masks = data.draw(mha_mask_strategy(q_seq_len, kv_seq_len, 1)) mask = jnp.array(masks[0].get_mask()[:, :]) attn_logits_soft_cap = data.draw(attn_logits_soft_cap_strategy(), - label="logit_cap") + label="logit_cap") attn_ref = partial(splash.attention_reference, mask, - attn_logits_soft_cap=attn_logits_soft_cap) + attn_logits_soft_cap=attn_logits_soft_cap) attn_custom = partial(splash.attention_reference_custom, mask, attn_logits_soft_cap=attn_logits_soft_cap) attn_custom_vanilla = partial(splash.attention_reference_custom, mask, @@ -532,7 +523,7 @@ def test_splash_attention_custom_bwd(self, is_segmented, data): attn_logits_soft_cap=attn_logits_soft_cap) o_ref, attn_vjp_ref = jax.vjp(attn_ref, q, k, v, segment_ids) q32, k32, v32 = jax.tree.map(lambda x: x.astype(jnp.float32), - (q, k, v)) + (q, k, v)) o_custom = attn_custom(q32, k32, v32, segment_ids) _, attn_vjp = jax.vjp(attn_custom, q32, k32, v32, segment_ids) _, attn_vanilla_vjp = jax.vjp(attn_custom_vanilla, q32, k32, v32, @@ -628,10 +619,10 @@ def test_splash_attention_bwd( masks = data.draw(mha_mask_strategy(q_seq_len, kv_seq_len, num_q_heads)) mask = mask_lib.MultiHeadMask(tuple(m.get_mask() for m in masks)) if use_dynamic_mask: - mask = to_dynamic_mask(mask) + mask = jnp.array(mask[:, :, :]) block_sizes = data.draw( block_sizes_strategy(q_seq_len, kv_seq_len, include_bwd_blocks=True, - use_fused_bwd_kernel=use_fused_bwd_kernel) + use_fused_bwd_kernel=use_fused_bwd_kernel) ) if is_mqa: attn_ref = splash.make_masked_mqa_reference(mask, backward_impl="custom") diff --git a/tests/pallas/tpu_splash_attention_mask_test.py b/tests/pallas/tpu_splash_attention_mask_test.py index f39b4d839340..e2e420edee8c 100644 --- a/tests/pallas/tpu_splash_attention_mask_test.py +++ b/tests/pallas/tpu_splash_attention_mask_test.py @@ -44,6 +44,15 @@ def _make_local_attention_mask(*args, **kwargs): return mask_lib.make_local_attention_mask(*args, **kwargs) +def _make_lazy_chunked_causal_mask(shape, chunk_size): + mask = mask_lib.ChunkedCausalMask(shape=shape, chunk_size=chunk_size) + return mask[:, :] + + +def _make_chunked_causal_mask(shape, chunk_size): + return mask_lib.make_chunk_attention_mask(shape=shape, chunk_size=chunk_size) + + class SplashAttentionMaskTest(jtu.JaxTestCase): @parameterized.parameters([_make_lazy_causal_mask, _make_causal_mask]) @@ -412,6 +421,181 @@ def test_lazy_local_mask_chunking( block_size, ) + @parameterized.parameters( + [_make_lazy_chunked_causal_mask, _make_chunked_causal_mask] + ) + def test_chunked_causal_mask(self, make_chunked_mask): + """Tests the chunked causal mask logic for various shapes and chunk sizes.""" + with self.subTest("unit"): + expected = np.array([[1]], dtype=np.bool_) + actual = make_chunked_mask(shape=(1, 1), chunk_size=1) + self.assertArraysEqual(actual, expected) + actual = make_chunked_mask(shape=(1, 1), chunk_size=2) + self.assertArraysEqual(actual, expected) + + with self.subTest("square_exact_chunks"): + # Chunk 0: [0, 1], Chunk 1: [2, 3] + expected = np.array( + [ + [1, 0, 0, 0], + [1, 1, 0, 0], + [0, 0, 1, 0], + [0, 0, 1, 1], + ], + dtype=np.bool_, + ) + actual = 
make_chunked_mask(shape=(4, 4), chunk_size=2) + self.assertArraysEqual(actual, expected) + + with self.subTest("square_uneven_chunks"): + expected = np.array( + [ + [1, 0, 0, 0, 0], + [1, 1, 0, 0, 0], + [1, 1, 1, 0, 0], + [0, 0, 0, 1, 0], + [0, 0, 0, 1, 1], + ], + dtype=np.bool_, + ) + actual = make_chunked_mask(shape=(5, 5), chunk_size=3) + self.assertArraysEqual(actual, expected) + + with self.subTest("wide_rectangle"): + expected = np.array( + [ + [1, 0, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0], + [1, 1, 1, 0, 0, 0], + [0, 0, 0, 1, 0, 0], + ], + dtype=np.bool_, + ) + actual = make_chunked_mask(shape=(4, 6), chunk_size=3) + self.assertArraysEqual(actual, expected) + + with self.subTest("tall_rectangle"): + expected = np.array( + [ + [1, 0, 0, 0], + [1, 1, 0, 0], + [1, 1, 1, 0], + [0, 0, 0, 1], + [0, 0, 0, 1], + [0, 0, 0, 1], + ], + dtype=np.bool_, + ) + actual = make_chunked_mask(shape=(6, 4), chunk_size=3) + self.assertArraysEqual(actual, expected) + + with self.subTest("chunk_size_1"): + # Should only allow self-attention q==k and chunk_size == 1 + expected = np.array( + [ + [1, 0, 0, 0], + [0, 1, 0, 0], + [0, 0, 1, 0], + [0, 0, 0, 1], + ], + dtype=np.bool_, + ) + actual = make_chunked_mask(shape=(4, 4), chunk_size=1) + self.assertArraysEqual(actual, expected) + + with self.subTest("chunk_size_greater_equal_seqlen"): + # Should behave like a normal causal mask + expected = np.array( + [ + [1, 0, 0, 0], + [1, 1, 0, 0], + [1, 1, 1, 0], + [1, 1, 1, 1], + ], + dtype=np.bool_, + ) + # Test chunk_size == seqlen + actual_eq = make_chunked_mask(shape=(4, 4), chunk_size=4) + self.assertArraysEqual(actual_eq, expected) + # Test chunk_size > seqlen + actual_gt = make_chunked_mask(shape=(4, 4), chunk_size=5) + self.assertArraysEqual(actual_gt, expected) + + @parameterized.product( + block_size=[(128, 128), (256, 128), (128, 256)], + shape=[(512, 512), (512, 1024), (1024, 512)], + chunk_size=[64, 128, 256, 512, 1024], + ) + def test_lazy_chunked_causal_mask_chunking( + self, + block_size: tuple[int, int], + shape: tuple[int, int], + chunk_size: int, + ): + """Compares lazy chunked mask evaluation against the dense version block-by-block.""" + q_len, kv_len = shape + # Adjust block size if it exceeds shape dimensions + adjusted_block_size = ( + min(block_size[0], q_len), + min(block_size[1], kv_len), + ) + + if ( + q_len % adjusted_block_size[0] != 0 + or kv_len % adjusted_block_size[1] != 0 + ): + self.skipTest( + f"Shape {shape} not divisible by block_size {adjusted_block_size}" + ) + + dense_mask = _make_chunked_causal_mask(shape=shape, chunk_size=chunk_size) + lazy_mask = mask_lib.ChunkedCausalMask(shape=shape, chunk_size=chunk_size) + self._compare_masks( + dense_mask, + lazy_mask, + adjusted_block_size, + ) + + def test_chunked_causal_mask_invalid_chunk_size(self): + """Tests that invalid chunk_size raises ValueError.""" + with self.assertRaises(ValueError): + mask_lib.ChunkedCausalMask(shape=(10, 10), chunk_size=0) + with self.assertRaises(ValueError): + mask_lib.ChunkedCausalMask(shape=(10, 10), chunk_size=-1) + with self.assertRaises(ValueError): + mask_lib.make_chunk_attention_mask(shape=(10, 10), chunk_size=0) + + def test_chunked_causal_mask_minimal_equality_hash(self): + """Tests for __eq__ and __hash__ of ChunkedCausalMask.""" + shape1, chunk_size1 = (128, 256), 16 + shape2, chunk_size2 = (128, 128), 32 # Different shape/chunk_size + + # Create three masks: two identical, one with different shape/chunk_size. 
+ mask1 = mask_lib.ChunkedCausalMask(shape=shape1, chunk_size=chunk_size1) + mask2 = mask_lib.ChunkedCausalMask(shape=shape1, chunk_size=chunk_size1) + mask_diff_shape = mask_lib.ChunkedCausalMask( + shape=shape2, chunk_size=chunk_size1 + ) + mask_diff_chunk = mask_lib.ChunkedCausalMask( + shape=shape1, chunk_size=chunk_size2 + ) + other_obj = object() + + # Test __eq__ + self.assertEqual(mask1, mask2) + self.assertNotEqual(mask1, mask_diff_shape) + self.assertNotEqual(mask1, mask_diff_chunk) + self.assertNotEqual(mask1, other_obj) + + # Test __hash__ of identical masks + self.assertEqual(hash(mask1), hash(mask2)) + + mask_set = {mask1, mask2, mask_diff_chunk} + self.assertLen(mask_set, 2) # mask1 and mask2 are duplicates + self.assertIn(mask1, mask_set) + self.assertIn(mask_diff_chunk, mask_set) + self.assertNotIn(mask_diff_shape, mask_set) + def test_using_logical_operators_raises_exception(self): mask_1 = mask_lib.NumpyMask( mask_lib.make_random_mask((256, 256), 0.5, seed=1) @@ -1064,7 +1248,8 @@ def test_local_mask(self, is_lazy_mask: bool): mask_info, mask_info_dkv, mask_function = self._process_mask( multi_head, block_shape ) - self.assertIsNone(mask_function) + if is_lazy_mask: + self.assertIsNotNone(mask_function) expected_partial_mask_blocks = self._stack( [ @@ -1108,10 +1293,12 @@ def test_local_mask(self, is_lazy_mask: bool): expected_mask_info = mask_info_lib.MaskInfo( expected_local_data_next, - expected_local_mask_next, + expected_local_mask_next if not is_lazy_mask else None, expected_local_block_mask, - expected_partial_mask_blocks, - None, + expected_partial_mask_blocks if not is_lazy_mask else None, + np.arange(sequence_lengths[0], dtype=np.int32) + if is_lazy_mask + else None, ) expected_local_data_next_dkv = np.array( @@ -1143,10 +1330,14 @@ def test_local_mask(self, is_lazy_mask: bool): expected_mask_info_dkv = mask_info_lib.MaskInfo( expected_local_data_next_dkv, - expected_local_mask_next_dkv, + expected_local_mask_next_dkv if not is_lazy_mask else None, expected_local_block_mask_dkv, - expected_partial_mask_blocks.swapaxes(-1, -2), - None, + expected_partial_mask_blocks.swapaxes(-1, -2) + if not is_lazy_mask + else None, + np.arange(sequence_lengths[0], dtype=np.int32) + if is_lazy_mask + else None, ) self._assert_mask_info_match(mask_info, expected_mask_info) @@ -1175,7 +1366,9 @@ def test_local_mask_narrow(self, is_lazy_mask: bool): mask_info, mask_info_dkv, mask_function = self._process_mask( multi_head, block_shape ) - self.assertIsNone(mask_function) + + if is_lazy_mask: + self.assertIsNotNone(mask_function) expected_partial_mask_blocks = self._stack( [ @@ -1216,10 +1409,12 @@ def test_local_mask_narrow(self, is_lazy_mask: bool): expected_mask_info = mask_info_lib.MaskInfo( expected_local_data_next, - expected_local_mask_next, + expected_local_mask_next if not is_lazy_mask else None, expected_local_block_mask, - expected_partial_mask_blocks, - None, + expected_partial_mask_blocks if not is_lazy_mask else None, + np.arange(sequence_lengths[0], dtype=np.int32) + if is_lazy_mask + else None, ) expected_local_data_next_dkv = np.array( @@ -1248,10 +1443,14 @@ def test_local_mask_narrow(self, is_lazy_mask: bool): expected_mask_info_dkv = mask_info_lib.MaskInfo( expected_local_data_next_dkv, - expected_local_mask_next_dkv, + expected_local_mask_next_dkv if not is_lazy_mask else None, expected_local_block_mask_dkv, - expected_partial_mask_blocks.swapaxes(-1, -2), - None, + expected_partial_mask_blocks.swapaxes(-1, -2) + if not is_lazy_mask + else None, + 
np.arange(sequence_lengths[0], dtype=np.int32) + if is_lazy_mask + else None, ) self._assert_mask_info_match(mask_info, expected_mask_info) @@ -2066,11 +2265,12 @@ def test_huge_mask2(self): multi_head, block_shape ) - self.assertIsNone(mask_function) + self.assertIsNotNone(mask_function) self.assertIsNotNone(mask_info.block_mask) self.assertIsNotNone(mask_info.data_next) - self.assertIsNotNone(mask_info.mask_next) - self.assertIsNotNone(mask_info.partial_mask_blocks) + self.assertIsNone(mask_info.mask_next) + self.assertIsNone(mask_info.partial_mask_blocks) + self.assertIsNotNone(mask_info.q_sequence) def test_process_invalid_mask(self): """Masks with of an all-0 row causes undefined softmax, reject them.""" @@ -2166,7 +2366,9 @@ def test_dynamic_mask(self, is_dkv: bool): self.assertArraysEqual(mask_info.block_mask, _expected_block_mask) self.assertArraysEqual( - mask_info.partial_mask_blocks, + mask_info.partial_mask_blocks.reshape( + -1, *mask_info.partial_mask_blocks.shape[-2:] + ), _expected_partial_mask_blocks, ) self.assertArraysEqual(mask_info.mask_next, _expected_mask_next) diff --git a/tests/pallas/triton_pallas_test.py b/tests/pallas/triton_pallas_test.py new file mode 100644 index 000000000000..fe13716705de --- /dev/null +++ b/tests/pallas/triton_pallas_test.py @@ -0,0 +1,80 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +"""Test the Triton dialect lowering for a variety of atomic operations.""" + +from absl.testing import absltest +from absl.testing import parameterized +import jax +from jax._src import config +from jax._src import dtypes +from jax._src import test_util as jtu +from jax._src.pallas.pallas_call import _trace_kernel_to_jaxpr +from jax.experimental import pallas as pl +import jax.numpy as jnp + +config.parse_flags_with_absl() + + +@jtu.with_config(jax_traceback_filtering="off") +class PallasBaseTest(jtu.JaxTestCase): + INTERPRET = False + + def setUp(self): + if jtu.test_device_matches(["cpu"]): + if not self.INTERPRET: + self.skipTest("On CPU the test works only in interpret mode") + elif jtu.test_device_matches(["gpu"]): + if not jtu.is_cuda_compute_capability_at_least("9.0"): + self.skipTest("Only works on GPU with capability >= sm90") + else: + self.skipTest("Test only works on CPU and GPU") + + super().setUp() + _trace_kernel_to_jaxpr.cache_clear() + + def pallas_call(self, *args, **kwargs): + return pl.pallas_call(*args, **kwargs, interpret=self.INTERPRET) + + +DTYPE_LIST = [jnp.float32, jnp.float16, jnp.bfloat16, + jnp.float8_e4m3fn, jnp.float8_e5m2] + + +class TritonPallasTest(PallasBaseTest): + INTERPRET = False + + @parameterized.product(src_dtype=DTYPE_LIST, dst_dtype=DTYPE_LIST) + def test_fp_dtype_cast(self, src_dtype, dst_dtype): + if src_dtype == dst_dtype: + self.skipTest("No need to test the same dtype") + if dtypes.bit_width(src_dtype) == 8 and dtypes.bit_width(dst_dtype) == 8: + self.skipTest("Not casting between 8-bit types") + + def body(x_ref, y_ref): + y_ref[...] 
= x_ref[...].astype(dst_dtype) + + x = 10 * jax.random.normal(jax.random.key(0), (64, 64), dtype=src_dtype) + y = self.pallas_call(body, + in_specs=[pl.BlockSpec((64, 64), lambda i: (0, 0))], + out_specs=pl.BlockSpec((64, 64), lambda i: (0, 0)), + out_shape=jax.ShapeDtypeStruct((64, 64), dst_dtype), + grid=(1,), + )(x) + self.assertEqual(y.dtype, dst_dtype) + self.assertArraysEqual(y, x.astype(dst_dtype)) + +if __name__ == "__main__": + absltest.main() diff --git a/tests/pgle_test.py b/tests/pgle_test.py index 7f9ea598d51b..78679adc962a 100644 --- a/tests/pgle_test.py +++ b/tests/pgle_test.py @@ -21,7 +21,7 @@ import tempfile import warnings -from absl.testing import absltest +from absl.testing import absltest, parameterized import jax from jax._src import api from jax._src import compilation_cache as cc @@ -65,7 +65,11 @@ def testPGLEProfilerGetFDOProfile(self): jax.jit, in_shardings=NamedSharding(mesh, PartitionSpec('x')), out_shardings=NamedSharding(mesh, PartitionSpec('x')), - compiler_options={'xla_gpu_enable_latency_hiding_scheduler': 'True'}, + compiler_options={ + 'xla_gpu_enable_latency_hiding_scheduler': 'True', + # Make sure that matmul is not emitted as Triton GEMM. + 'xla_gpu_enable_triton_gemm': 'False', + }, ) def f(x, y): return x @ y @@ -81,7 +85,7 @@ def f(x, y): pgle_profiler = profiler.PGLEProfiler(1, 90) with config.enable_pgle(False): with profiler.PGLEProfiler.trace(pgle_profiler): - compiled(x, y) + jax.block_until_ready(compiled(x, y)) fdo_profile = pgle_profiler.consume_fdo_profile() self.assertIsNotNone(fdo_profile) @@ -93,6 +97,8 @@ def testPGLEProfilerGetFDOProfileLarge(self): compiler_options = { 'xla_gpu_enable_latency_hiding_scheduler': 'True', + # Make sure that matmul is not emitted as Triton GEMM. + 'xla_gpu_enable_triton_gemm': 'False', } # TODO(b/37664749): Remove this flag once the bug is fixed. compiler_options['xla_gpu_enable_command_buffer'] = '' @@ -151,29 +157,31 @@ def f(x): with config.pgle_profiling_runs(2), config.enable_pgle(True): # Run 1: Module should be compiled without FDO. Two modules are expected - # One is the funtion f, the other one is multi slice module - with jtu.count_jit_compilation_cache_miss() as cache_miss_count: + # One is the function f, the other one is multi slice module + with jtu.count_pjit_cpp_cache_miss() as cache_miss_count: self.assertArraysEqual(f(x), expected) self.assertEqual(cache_miss_count(), 2) # Run 2: Second PGLE run. Profile should be empty. - with jtu.count_jit_compilation_cache_miss() as cache_miss_count: + with jtu.count_pjit_cpp_cache_miss() as cache_miss_count: self.assertArraysEqual(f(x), expected) self.assertEqual(cache_miss_count(), 2) fdo_profiles_before_pgle = self.get_fdo_profiles(dump_dir) - # One for before and one for after optimization. - self.assertLen(fdo_profiles_before_pgle, 2) + # One for before optimization, one after SPMD partitioning, and one + # after optimization. + self.assertLen(fdo_profiles_before_pgle, 3) # The FDO profile file should be empty. self.assertEqual( os.path.getsize(os.path.join(dump_dir, fdo_profiles_before_pgle[0])), 0) # Run 3: The module should be recompiled with FDO profiles - with jtu.count_jit_compilation_cache_miss() as cache_miss_count: + with jtu.count_pjit_cpp_cache_miss() as cache_miss_count: self.assertArraysEqual(f(x), expected) self.assertEqual(cache_miss_count(), 2) fdo_profiles_after_pgle = self.get_fdo_profiles(dump_dir) - # One for before and one for after optimization. 
- self.assertLen(fdo_profiles_after_pgle, 4) + # One more before optimizatiom, one more after SPMD partitioning, and + # one more after optimization. + self.assertLen(fdo_profiles_after_pgle, 6) for fdo_profile in fdo_profiles_after_pgle: if fdo_profile not in fdo_profiles_before_pgle: @@ -182,7 +190,7 @@ def f(x): ) # Run 4: Fast-path should be used after PGLE is done - with jtu.count_jit_compilation_cache_miss() as cache_miss_count: + with jtu.count_pjit_cpp_cache_miss() as cache_miss_count: self.assertArraysEqual(f(x), expected) self.assertLess(cache_miss_count(), 2) @@ -196,7 +204,8 @@ def f(x): f_lowered = f.lower(x) serialized, in_tree, out_tree = serialize(f_lowered.compile()) - compiled = deserialize_and_load(serialized, in_tree, out_tree) + compiled = deserialize_and_load( + serialized, in_tree, out_tree, execution_devices=jax.devices()[:1]) with config.pgle_profiling_runs(1), config.enable_pgle(True): # Run 1 @@ -321,7 +330,11 @@ def testPassingFDOProfile(self): jax.jit, in_shardings=NamedSharding(mesh, PartitionSpec('x')), out_shardings=NamedSharding(mesh, PartitionSpec('x')), - compiler_options={'xla_gpu_enable_latency_hiding_scheduler': 'True'}, + compiler_options={ + 'xla_gpu_enable_latency_hiding_scheduler': 'True', + # Make sure that matmul is not emitted as Triton GEMM. + 'xla_gpu_enable_triton_gemm': 'False', + }, ) def f(x, y): return x @ y @@ -468,5 +481,55 @@ def check_if_cache_hit(event): self.assertLen(w, 1) self.assertIn("PERSISTENT CACHE WRITE with key jit_h-", str(w[0].message)) + @parameterized.parameters([True, False]) + @jtu.thread_unsafe_test() + def testAutoPgleWithCommandBuffers(self, enable_compilation_cache): + with (config.pgle_profiling_runs(1), + config.enable_compilation_cache(enable_compilation_cache), + config.enable_pgle(True), + tempfile.TemporaryDirectory() as dump_dir, + tempfile.TemporaryDirectory() as cache_dir): + if enable_compilation_cache: + cc.reset_cache() + cc.set_cache_dir(cache_dir) + compiler_options = { + 'xla_dump_to': dump_dir, + # FUSION, see https://github.com/openxla/xla/issues/22459 + 'xla_gpu_enable_command_buffer': 1, + 'xla_gpu_graph_min_graph_size': 1, + } + @partial( + jax.jit, + compiler_options=compiler_options, + ) + def f(x): + return x * 2 + + x = jnp.arange(1) + expected = x * 2 + + # This is ugly, but it does not seem possible to get the AutoPGLE-recompiled + # executable text (.lower(x).compile().as_text() or similar). 
+ def get_new_hlo(): + additions = set(os.listdir(dump_dir)) - get_new_hlo.seen_files + get_new_hlo.seen_files |= additions + new_hlos = list(filter(lambda f: f.endswith("_gpu_after_optimizations.txt"), additions)) + assert len(new_hlos) == 1 + with open(os.path.join(dump_dir, new_hlos[0])) as ifile: + return ifile.read() + + get_new_hlo.seen_files = set() + + # Run 1 + self.assertArraysEqual(f(x), expected) + self.assertNotIn("command_buffer", get_new_hlo()) # b/376647494 workaround + # Run 2 + self.assertArraysEqual(f(x), expected) + self.assertIn("command_buffer", get_new_hlo()) # workaround disabled + + api.clear_caches() + pjit._pgle_profiler_dict.clear() + + if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/pickle_test.py b/tests/pickle_test.py index 185eebd90726..feaeb8db01c7 100644 --- a/tests/pickle_test.py +++ b/tests/pickle_test.py @@ -28,6 +28,7 @@ from jax.interpreters import pxla from jax._src import test_util as jtu from jax._src.lib import xla_client as xc +from jax._src.sharding_impls import GSPMDSharding import numpy as np @@ -174,6 +175,17 @@ def test_pickle_single_device_sharding(self): s = jax.sharding.SingleDeviceSharding(jax.devices()[0]) self.assertEqual(s, pickle.loads(pickle.dumps(s))) + def test_pickle_single_device_sharding_with_memory_kind(self): + for memory_kind in ( + *[memory.kind for memory in jax.devices()[0].addressable_memories()], + None, + ): + with self.subTest(memory_kind=memory_kind): + s = jax.sharding.SingleDeviceSharding( + jax.devices()[0], memory_kind=memory_kind + ) + self.assertEqual(s, pickle.loads(pickle.dumps(s))) + def test_pickle_pmap_sharding(self): ss = pxla.ShardingSpec( sharding=(pxla.Unstacked(8),), @@ -182,16 +194,40 @@ def test_pickle_pmap_sharding(self): self.assertEqual(s, pickle.loads(pickle.dumps(s))) def test_pickle_gspmd_sharding(self): - s = jax.sharding.GSPMDSharding.get_replicated(jax.devices()) + s = GSPMDSharding.get_replicated(jax.devices()) self.assertEqual(s, pickle.loads(pickle.dumps(s))) + def test_pickle_gspmd_sharding_with_memory_kind(self): + for memory_kind in ( + *[memory.kind for memory in jax.devices()[0].addressable_memories()], + None, + ): + with self.subTest(memory_kind=memory_kind): + s = GSPMDSharding.get_replicated(jax.devices(), memory_kind=memory_kind) + self.assertEqual(s, pickle.loads(pickle.dumps(s))) + @unittest.skipIf(cloudpickle is None, "Requires cloudpickle") def test_pickle_named_sharding(self): s = jax.sharding.NamedSharding( mesh=jax.sharding.Mesh(np.array(jax.devices()), 'd'), - spec=jax.sharding.PartitionSpec('d')) + spec=jax.sharding.PartitionSpec('d'), + ) self.assertEqual(s, pickle.loads(pickle.dumps(s))) + @unittest.skipIf(cloudpickle is None, 'Requires cloudpickle') + def test_pickle_named_sharding_with_memory_kind(self): + for memory_kind in ( + *[memory.kind for memory in jax.devices()[0].addressable_memories()], + None, + ): + with self.subTest(memory_kind=memory_kind): + s = jax.sharding.NamedSharding( + mesh=jax.sharding.Mesh(np.array(jax.devices()), 'd'), + spec=jax.sharding.PartitionSpec('d'), + memory_kind=memory_kind, + ) + self.assertEqual(s, pickle.loads(pickle.dumps(s))) + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 293b37a9fbc7..a814d25ba655 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -14,7 +14,7 @@ from collections import OrderedDict, namedtuple import re -from functools import partial +from functools import partial, wraps 
import logging import json import math @@ -42,27 +42,29 @@ from jax._src import prng from jax.sharding import PartitionSpec as P, Mesh from jax.experimental import multihost_utils -from jax.experimental.shard_map import shard_map +from jax._src.shard_map import shard_map from jax._src.compilation_cache import is_persistent_cache_enabled from jax.experimental.custom_partitioning import ( custom_partitioning, SdyShardingRule, BATCHING) +from jax.experimental import primal_tangent_dtype from jax._src import array from jax._src.sharding import Sharding, common_devices_indices_map from jax._src import op_shardings from jax._src import sharding_impls from jax._src.sharding_impls import ( - AUTO, UNSPECIFIED, NamedSharding, GSPMDSharding, PositionalSharding, + AUTO, UNSPECIFIED, NamedSharding, GSPMDSharding, SingleDeviceSharding, parse_flatten_op_sharding) from jax._src.pjit import (pjit, mesh_cast, auto_axes, explicit_axes, - use_auto_axes, use_explicit_axes, reshard) + use_auto_axes, use_explicit_axes, reshard, + _pjit_lower_cached) +from jax._src.layout import Format, DeviceLocalLayout as DLL from jax._src.named_sharding import DuplicateSpecError from jax._src import mesh as mesh_lib from jax._src.mesh import AxisType from jax._src.interpreters import pxla -from jax._src.lib.mlir import dialects from jax._src import xla_bridge from jax._src.lib import xla_client as xc -from jax._src.lib import xla_extension +from jax._src.lib import _jax from jax._src.util import curry, unzip2 config.parse_flags_with_absl() @@ -505,8 +507,6 @@ def f(x): self.assertIn("sharding={replicated}", hlo.as_hlo_text()) def testShardingConstraintWithArrayOpSharding(self): - if config.use_shardy_partitioner.value: - self.skipTest("Shardy doesn't support PositionalSharding") shape = (8, 8) mesh = jtu.create_mesh((2, 1), ('x', 'y')) s = NamedSharding(mesh, P(None)) @@ -903,8 +903,11 @@ def _dispatch(): def check_outfeed(x_fn): for didx, d in enumerate(devices): x = x_fn(didx) - y, = d.transfer_from_outfeed( - xc.shape_from_pyval((x,)).with_major_to_minor_layout_if_absent()) + y = d.transfer_from_outfeed( + xc.Shape.array_shape( + xc.PrimitiveType.F32, x.shape + ).with_major_to_minor_layout_if_absent() + ) self.assertAllClose(x, y, check_dtypes=True) logging.info('Transferring from outfeed for the pjit call') @@ -940,6 +943,18 @@ def testWithCustomPRNGKey(self): # Make sure this doesn't crash pjit(lambda x: x, in_shardings=None, out_shardings=None)(key) + def test_lower_with_wrapper_error(self): + @jax.jit + def f(x): + return x + + self.assertAllClose(1., f(1.)) + self.assertAllClose(1., f.lower(1.).compile()(1.)) + wrapped_f = wraps(f)(lambda x: f(x + 1)) + + with self.assertRaisesRegex(AttributeError, "has no attribute 'lower'"): + wrapped_f.lower(1.) + @jtu.with_mesh([('x', 2), ('y', 2)]) def testLowerCompile(self): @partial(pjit, @@ -1240,9 +1255,12 @@ def test_pretty_print_pjit_id(self): jaxpr.pretty_print(use_color=False), textwrap.dedent(""" { lambda ; a:f32[1]. let - pjit[name= jaxpr={ lambda ; a:f32[1] b:f32[1]. let in () }] a a - c:f32[1] = add a a - in (c,) } + b:f32[1] = pjit[ + name= + jaxpr={ lambda ; a:f32[1] c:f32[1]. let in (a,) } + ] a a + d:f32[1] = add a b + in (d,) } """).strip(), ) @@ -1257,7 +1275,7 @@ def test_pretty_print_with_constant_pjit_arg(self): b:f32[1] = pjit[ name= jaxpr={ lambda ; a:f32[1] c:f32[]. 
let b:f32[1] = mul a c in (b,) } - ] a 1.0 + ] a 1.0:f32[] in (b,) } """).strip(), ) @@ -1289,8 +1307,11 @@ def test_pretty_print_with_literal_outvar(self): jaxpr.pretty_print(use_color=False), textwrap.dedent(""" { lambda ; a:f32[1]. let - b:i32[] = pjit[name= jaxpr={ lambda ; a:f32[1]. let in (2,) }] a - in (b, a) } + b:i32[] c:f32[1] = pjit[ + name= + jaxpr={ lambda ; a:f32[1]. let in (2:i32[], a) } + ] a + in (b, c) } """).strip(), ) @@ -1336,19 +1357,19 @@ def f(x): self.assertEqual( jaxpr.pretty_print(use_color=False), textwrap.dedent(""" - let f = { lambda ; a:f32[1]. let in () } in - let f1 = { lambda ; b:f32[2]. let in () } in + let f = { lambda ; a:f32[1]. let in (a,) } in + let f1 = { lambda ; b:f32[2]. let in (b,) } in { lambda ; c:f32[1] d:f32[2]. let e:f32[2] = pjit[ name=g jaxpr={ lambda ; c:f32[1] d:f32[2]. let - pjit[name=f jaxpr=f] c - pjit[name=f jaxpr=f] c - g:f32[1] = mul c c - pjit[name=f jaxpr=f1] d - pjit[name=f jaxpr=f1] d - h:f32[2] = mul d d - e:f32[2] = add g h + g:f32[1] = pjit[name=f jaxpr=f] c + h:f32[1] = pjit[name=f jaxpr=f] c + i:f32[1] = mul g h + j:f32[2] = pjit[name=f jaxpr=f1] d + k:f32[2] = pjit[name=f jaxpr=f1] d + l:f32[2] = mul j k + e:f32[2] = add i l in (e,) } ] c d in (e,) } @@ -1394,6 +1415,16 @@ def test_zero_literal_equality(self): self.assertIn("stablehlo.constant dense<0.000000e+00>", ir) self.assertIn("stablehlo.constant dense<-0.000000e+00>", ir) + def test_device_put_copy_donate(self): + x = np.arange(1000) + y = jax.device_put(x, device=jax.devices()[0], may_alias=False, donate=False) + z = jax.device_put(y, device=jax.devices()[0], may_alias=False, donate=False) + a = jax.jit(lambda y: y * 2, donate_argnums=0)(y) + self.assertDeleted(y) + self.assertNotDeleted(z) + self.assertArraysEqual(a, x * 2) + + @jtu.pytest_mark_if_available('multiaccelerator') class CustomPartitionerTest(jtu.JaxTestCase): @@ -2275,13 +2306,13 @@ def add(x, y): return x + y out = add(a, b) - cache_info1 = pxla._cached_lowering_to_hlo.cache_info() + cache_info1 = _pjit_lower_cached.cache_info() self.assertIsInstance(out, array.ArrayImpl) self.assertArraysEqual(out, a + b) self.assertFalse(out._committed) out2 = add(out, out) - cache_info2 = pxla._cached_lowering_to_hlo.cache_info() + cache_info2 = _pjit_lower_cached.cache_info() self.assertIsInstance(out2, array.ArrayImpl) self.assertArraysEqual(out2, 2 * (a + b)) self.assertFalse(out2._committed) @@ -2291,7 +2322,7 @@ def add(x, y): c = jax.device_put(a, jax.devices()[0]) out3 = add(c, c) - cache_info3 = pxla._cached_lowering_to_hlo.cache_info() + cache_info3 = _pjit_lower_cached.cache_info() self.assertArraysEqual(out3, 2 * c) self.assertTrue(out3._committed) @@ -2477,6 +2508,20 @@ def test_pjit_committed_array_different_devices_variadic_args(self): r"\[1\].*"): pjit(lambda *x: x)(a, b) + def test_jit_no_forwarding(self): + mesh = jtu.create_mesh((2,), ('x',)) + + @partial(jax.jit, donate_argnums=(0,)) + def f(x): + return x, x * 2 + + x = jax.device_put(jnp.zeros(64, dtype="int32"), NamedSharding(mesh, P())) + jaxpr = jax.make_jaxpr(f)(x) + y = core.jaxpr_as_fun(jaxpr)(x) + self.assertTrue(x.is_deleted()) + self.assertFalse(y[0].is_deleted()) + self.assertFalse(y[1].is_deleted()) + def test_pjit_pytree_inp_device_assignment_mismatch(self): mesh = jtu.create_mesh((2, 2), ('x', 'y')) a = jax.device_put(np.array([1, 2, 3]), jax.devices()[0]) @@ -3223,6 +3268,17 @@ def g(x): jaxpr = jax.make_jaxpr(g)(3) self.assertNotIn('pjit', str(jaxpr)) + def test_pjit_inline_literal(self): + # 
https://github.com/jax-ml/jax/issues/27545 + def bar(x): + return jnp.array(1) + + def foo(x): + x = pjit(bar, inline=True)(x) + self.assertEqual(x.shape, ()) + + pjit(foo)(0) # doesn't crash + def test_pmap_in_axis_resources_error(self): pmap_out = jax.pmap(lambda x: x)(jnp.arange(jax.device_count())) self.assertIsInstance(pmap_out.sharding, jax.sharding.PmapSharding) @@ -3423,9 +3479,8 @@ def f(x, y): f(x_, y) self.assertLen(cm.output, expected_log_len) msg = cm.output[0] - self.assertIn('never seen input type signature', msg) - self.assertIn('closest seen input type signature has 1 mismatches', msg) - self.assertIn("seen f32[8]({}), but now given f32[8]({Auto: ('x',)})", msg) + self.assertIn("different input types", msg) + self.assertIn("at x, now f32[8]({Auto: ('x',)}) and before f32[8]({})", msg) def test_pjit_function_cache_cpp(self): def f(x): @@ -3438,6 +3493,7 @@ def f(x): pjit(f)(inp) self.assertEqual(count(), 1) + @jtu.thread_unsafe_test() # count_pjit_cpp_cache_miss is not thread-safe def test_pjit_no_global_cache_hit_axis_resources(self): mesh = jtu.create_mesh((1,), ('x',)) s = NamedSharding(mesh, P('x')) @@ -3563,6 +3619,9 @@ def test_device_put_grad(self): if jtu.is_device_tpu(5, 'e'): self.skipTest('TPU v5e does not support computations that run on a ' 'non-singleton subset of cores.') + if jtu.is_device_tpu(6, 'e'): + self.skipTest('TPU v6e does not support computations that run on a ' + 'non-singleton subset of cores.') def _test(fun, inp, np_inp, in_s): out = fun(inp) @@ -3612,20 +3671,17 @@ def g(x): @jtu.thread_unsafe_test() # cache_info isn't thread-safe def test_pjit_out_sharding_preserved(self): - if config.use_shardy_partitioner.value: - raise unittest.SkipTest("Shardy doesn't support PositionalSharding") mesh = jtu.create_mesh((2, 1), ('x', 'y')) ns = NamedSharding(mesh, P('x')) - ps = PositionalSharding(jax.devices()[:2]).reshape(2, 1) + gs = GSPMDSharding(jax.devices()[:2], ns._to_xla_hlo_sharding(2)) arr = jax.device_put(np.arange(8).reshape(8, 1), ns) - arr2 = jax.device_put(np.arange(8).reshape(8, 1), ps) + arr2 = jax.device_put(np.arange(8).reshape(8, 1), gs) def mul(x): return x * 2 f = pjit(mul, out_shardings=ns) - f2 = pjit(mul, out_shardings=ps) with jtu.count_pjit_cpp_cache_miss() as count: out = f(arr) @@ -3636,24 +3692,12 @@ def mul(x): self.assertIsInstance(out.sharding, NamedSharding) self.assertEqual(count(), 1) - with jtu.count_pjit_cpp_cache_miss() as count: - out2 = f2(arr) - cache_info2 = pxla._cached_compilation.cache_info() - self.assertIsInstance(out2.sharding, PositionalSharding) - - out2 = f2(arr) - self.assertIsInstance(out2.sharding, PositionalSharding) - self.assertEqual(count(), 1) - - self.assertEqual(cache_info2.hits, cache_info1.hits + 1) - self.assertEqual(cache_info2.misses, cache_info1.misses) - with jtu.count_jit_tracing_cache_miss() as tracing_count: out3 = jnp.squeeze(arr, axis=-1) self.assertIsInstance(out3.sharding, NamedSharding) out4 = jnp.squeeze(arr2, axis=-1) - self.assertIsInstance(out4.sharding, PositionalSharding) + self.assertIsInstance(out4.sharding, GSPMDSharding) self.assertEqual(tracing_count(), 2) @jtu.thread_unsafe_test() # cache_info isn't thread-safe @@ -3686,25 +3730,6 @@ def test_list_in_pspec(self): out = with_sharding_constraint(jnp.arange(8), P(['x'])) self.assertEqual(out.sharding, NamedSharding(mesh, P('x'))) - def test_sharding_preserved_trivial(self): - if config.use_shardy_partitioner.value: - raise unittest.SkipTest("Shardy doesn't support PositionalSharding") - mesh = jtu.create_mesh((2, 1), 
('x', 'y')) - ns = NamedSharding(mesh, P('x')) - ps = PositionalSharding(jax.devices()[:2]).reshape(2, 1) - - arr = jax.device_put(np.arange(8).reshape(8, 1), ns) - arr2 = jax.device_put(np.arange(8).reshape(8, 1), ps) - - def identity(x): - return x - - out = pjit(identity)(arr) - self.assertIsInstance(out.sharding, NamedSharding) - - out2 = pjit(identity)(arr2) - self.assertIsInstance(out2.sharding, PositionalSharding) - def test_wsc_error_on_none(self): with self.assertRaisesRegex( ValueError, @@ -3712,23 +3737,6 @@ def test_wsc_error_on_none(self): ' not allowed'): with_sharding_constraint(jnp.arange(8), None) - def test_sharding_preserved_aot(self): - mesh = jtu.create_mesh((2, 1), ('x', 'y')) - ns = NamedSharding(mesh, P('x')) - ps = PositionalSharding(jax.devices()[:2]).reshape(2, 1) - - arr = jax.device_put(np.arange(8).reshape(8, 1), ns) - arr2 = jax.device_put(np.arange(8).reshape(8, 1), ps) - - compiled = pjit(lambda x: x * 2).lower(arr).compile() - out = compiled(arr) - self.assertIsInstance(out.sharding, NamedSharding) - - out2 = compiled(arr2) - # The sharding won't be PositionalSharding since the pjit was already - # Compiled which bakes in the output sharding. - self.assertIsInstance(out2.sharding, NamedSharding) - def test_sharding_on_output_with_vmap(self): mesh = jtu.create_mesh((2, 1), ('x', 'y')) ns = NamedSharding(mesh, P('x')) @@ -3747,16 +3755,29 @@ def test_sharding_on_output_with_vmap(self): self.assertIsInstance(out3.sharding, NamedSharding) self.assertEqual(count(), 1) + @config.numpy_dtype_promotion('standard') + def test_mutable_array_closed_over_multi_device(self): + mesh = jtu.create_mesh((2,), ('x',)) + key_data = jax.random.key_data(jax.random.key(42)) + key_data_ref = core.mutable_array(key_data) + output_sharding = NamedSharding(mesh, P('x')) + + @partial(jax.jit, out_shardings=output_sharding) + def generate_random_numbers(): + key_val = key_data_ref[...] + outputs = jnp.arange(8, dtype=jnp.float32) + key_val[0] + return outputs + + generate_random_numbers() # doesn't crash + @jtu.thread_unsafe_test() # cache_info isn't thread-safe def test_jit_mul_sum_sharding_preserved(self): - if config.use_shardy_partitioner.value: - raise unittest.SkipTest("Shardy doesn't support PositionalSharding") mesh = jtu.create_mesh((2, 1), ('x', 'y')) ns = NamedSharding(mesh, P('x')) - ps = PositionalSharding(jax.devices()[:2]).reshape(2, 1) + gs = GSPMDSharding(tuple(mesh.devices.flat), ns._to_xla_hlo_sharding(2)) arr = jax.device_put(np.arange(8).reshape(8, 1), ns) - arr2 = jax.device_put(np.arange(8).reshape(8, 1), ps) + arr2 = jax.device_put(np.arange(8).reshape(8, 1), gs) f = jax.jit(lambda x: x * 2) @@ -3766,11 +3787,11 @@ def test_jit_mul_sum_sharding_preserved(self): with jtu.count_pjit_cpp_cache_miss() as cpp_count: out2 = f(arr2) - self.assertIsInstance(out2.sharding, PositionalSharding) + self.assertIsInstance(out2.sharding, GSPMDSharding) # This will hit the cpp cache. 
out3 = f(out2) - self.assertIsInstance(out3.sharding, PositionalSharding) + self.assertIsInstance(out3.sharding, GSPMDSharding) self.assertEqual(compilation_count(), 2) self.assertEqual(cpp_count(), 1) @@ -3818,8 +3839,6 @@ def test_none_out_sharding(self): self.assertEqual(out2.sharding.spec, P()) def test_sharding_preserved_apply_primitive(self): - if config.use_shardy_partitioner.value: - raise unittest.SkipTest("Shardy doesn't support PositionalSharding") mesh = jtu.create_mesh((2, 1), ('x', 'y')) ns = NamedSharding(mesh, P('x')) @@ -3828,10 +3847,10 @@ def test_sharding_preserved_apply_primitive(self): out = jnp.copy(arr) self.assertIsInstance(out.sharding, NamedSharding) - ps = PositionalSharding(jax.devices()[:2]).reshape(2, 1) - arr2 = jax.device_put(np.arange(8).reshape(8, 1), ps) + gs = GSPMDSharding(jax.devices()[:2], ns._to_xla_hlo_sharding(2)) + arr2 = jax.device_put(np.arange(8).reshape(8, 1), gs) out2 = jnp.copy(arr2) - self.assertIsInstance(out2.sharding, PositionalSharding) + self.assertIsInstance(out2.sharding, GSPMDSharding) arr3 = jnp.arange(8) out3 = jnp.copy(arr3) @@ -4250,6 +4269,13 @@ def make_keys(seeds): else: self.assertIn('unspecified_dims=[0,1,2]', lowered_text) + def test_wsc_with_scalar(self): + mesh = jtu.create_mesh((2,), 'x') + s = NamedSharding(mesh, P()) + out = jax.lax.with_sharding_constraint(1., s) + self.assertArraysEqual(out, 1.) + self.assertEqual(out.sharding, s) + def test_jit_partially_specified_shardings(self): mesh = jtu.create_mesh((2, 2), ('x', 'y')) @@ -4301,11 +4327,10 @@ def f(*args): f(inps) # doesn't crash def test_spmd_preserves_input_sharding_vmap_grad(self): - if config.use_shardy_partitioner.value: - self.skipTest("Shardy doesn't support PositionalSharding") # https://github.com/jax-ml/jax/issues/20710 n_devices = jax.device_count() - sharding = PositionalSharding(jax.devices()) + mesh = Mesh(jax.devices(), 'x') + sharding = NamedSharding(mesh, P('x')) def model(params, x): return x @ params @@ -4318,8 +4343,8 @@ def model(params, x): params = jnp.ones(feature_dim) # Shard data, replicate params - x = jax.device_put(x, sharding.reshape(n_devices, 1)) - params = jax.device_put(params, sharding.replicate(axis=0)) + x = jax.device_put(x, sharding) + params = jax.device_put(params, NamedSharding(mesh, P())) model(params, x) # doesn't crash @@ -4446,9 +4471,15 @@ def f(x): self.assertLen(traced.in_avals[0], 1) self.assertLen(traced.in_avals[1], 0) # empty kwarg + def test_in_out_shardings_unconstrained_error(self): + mesh = jtu.create_mesh((1,), ('x',)) + + with self.assertRaisesRegex( + ValueError, "Unconstrained dims are not allowed"): + jax.jit(lambda x: x, + in_shardings=NamedSharding(mesh, P(P.UNCONSTRAINED, 'x'))) + def test_empty_io_callback_under_shard_map(self): - if config.use_shardy_partitioner.value: - self.skipTest("TODO(b/384938613): Failing under shardy.") mesh = jtu.create_mesh((4,), 'i') def empty_callback(x): @@ -4460,7 +4491,7 @@ def _f(x, y): return x + y[..., jnp.newaxis] f = jax.jit(shard_map( - _f, mesh, in_specs=(P(None, 'i'), P(None)), + _f, mesh=mesh, in_specs=(P(None, 'i'), P(None)), out_specs=P(None, 'i'))) f(jnp.zeros((2, 16)), jnp.ones(2)) @@ -4478,7 +4509,7 @@ def _f(x, y): return x + y[..., jnp.newaxis] f = jax.jit(shard_map( - _f, mesh, in_specs=(P(None, 'i'), P(None)), + _f, mesh=mesh, in_specs=(P(None, 'i'), P(None)), out_specs=P(None, 'i'))) f(jnp.zeros((2, 16)), jnp.ones(2)) @@ -4904,6 +4935,91 @@ def g(x): else: self.assertIn("unspecified_dims=[0]", lowered_text) + def test_prng_key_wsc(self): + mesh 
= jtu.create_mesh((2,), 'x') + + @jax.jit + def f(x): + y = lax.with_sharding_constraint(x, NamedSharding(mesh, P())) + return y.T + f(jax.random.key(0)) # doesn't crash + + @jax.jit + def g(x): + return lax.with_sharding_constraint(x, NamedSharding(mesh, P())) + g(jax.random.key(1)) # doesn't crash + + def test_prng_key_wsc_multi_axes_sharding(self): + input_shape = (8, 4) + mesh = jtu.create_mesh((4, 2), ('x', 'y')) + spec = P('x', 'y') + + seeds, _ = create_array(input_shape, mesh, spec, dtype=np.uint32) + + @jax.jit + def make_keys(seeds): + make_key = partial(prng.random_seed, impl=prng.threefry_prng_impl) + return lax.with_sharding_constraint( + make_key(seeds), NamedSharding(mesh, P('x', 'y'))) + + out = make_keys(seeds) + self.assertTrue(jax.dtypes.issubdtype(out.dtype, jax.dtypes.prng_key)) + self.assertEqual(out.shape, input_shape) + jax.random.key_data(out) # doesn't crash + + def test_sds_update(self): + mesh = jtu.create_mesh((2, 1), ('x', 'y')) + s1 = jax.ShapeDtypeStruct((2, 2), jnp.int32) + s1_u = s1.update(shape=(4, 2), dtype=np.float32) + self.assertEqual(s1_u.shape, (4, 2)) + self.assertEqual(s1_u.dtype, np.float32) + self.assertFalse(s1_u.weak_type) + + s2 = jax.ShapeDtypeStruct((2, 2), jnp.int32) + s2_u = s2.update(shape=(4, 2), weak_type=True) + self.assertEqual(s2_u.shape, (4, 2)) + self.assertEqual(s2_u.dtype, np.int32) + self.assertTrue(s2_u.weak_type) + + s3 = jax.ShapeDtypeStruct((2, 2), jnp.int32, + sharding=NamedSharding(mesh, P())) + s3_u = s3.update(sharding=NamedSharding(mesh, P('x'))) + self.assertEqual(s3_u.sharding, NamedSharding(mesh, P('x'))) + + s32_u = s3.update(shape=(4, 2)) + self.assertEqual(s32_u.shape, (4, 2)) + self.assertEqual(s32_u.sharding, NamedSharding(mesh, P())) + + sh = NamedSharding(mesh, P()) + s4 = jax.ShapeDtypeStruct((2, 2), jnp.int32, + sharding=Format(DLL((0, 1)), sh)) + new_layout = Format(DLL((1, 0)), NamedSharding(mesh, P('x'))) + s4_u = s4.update(sharding=new_layout) + self.assertEqual(s4_u.sharding, new_layout.sharding) + self.assertEqual(s4_u.format, new_layout) + + with self.assertRaisesRegex(ValueError, "updating ShapeDtypeStruct"): + s4.update(sharding=NamedSharding(mesh, P('x'))) + + @jtu.with_explicit_mesh((2, 1), ('x', 'y'), axis_types=(AxisType.Auto,) * 2) + def test_sds_pspec_input(self, mesh): + inp = jax.ShapeDtypeStruct((2, 2), np.float32, sharding=P('x')) + lowered = jax.jit(lambda x: x * 2).lower(inp) + self.assertIn('num_partitions = 2', lowered.as_text()) + + np_inp = np.arange(4, dtype=np.float32).reshape(2, 2) + arr = jax.device_put(np_inp, P('x')) + out = lowered.compile()(arr) + self.assertArraysEqual(out, np_inp * 2) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x'))) + + def test_sds_pspec_no_mesh_ctx_error(self): + with self.assertRaisesRegex( + TypeError, + 'When specifying PartitionSpec to `ShapeDtypeStruct`, the context mesh' + ' cannot be empty'): + jax.ShapeDtypeStruct((2, 2), np.float32, sharding=P('x')) + def spec_regex(s): return str(s).replace(r"(", r"\(").replace(r")", r"\)") @@ -4917,7 +5033,7 @@ def check_wsc_in_lowered(self, text): else: self.assertIn('@Sharding', text) - @jtu.with_user_mesh((2, 2), ('x', 'y')) + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) def test_basic_mul(self, mesh): np_inp = np.arange(16.).reshape(8, 2) s = NamedSharding(mesh, P('x', 'y')) @@ -4962,7 +5078,7 @@ def g(x): jax.jit(jax.grad(g)).lower(sds) # doesn't crash - @jtu.with_user_mesh((2, 2), ('x', 'y')) + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) def test_fully_replicated_array_mul(self, mesh): 
np_inp1 = np.arange(16).reshape(8, 2) s = NamedSharding(mesh, P('x', 'y')) @@ -4995,11 +5111,13 @@ def g(x, y): return x * y with self.assertRaisesRegex( - TypeError, "mul got incompatible shardings for broadcasting"): + core.ShardingTypeError, + "mul got incompatible shardings for broadcasting"): g(arr1, jax.device_put(np_inp1, NamedSharding(mesh, P('y', 'x')))) with self.assertRaisesRegex( - TypeError, "mul got incompatible shardings for broadcasting"): + core.ShardingTypeError, + "mul got incompatible shardings for broadcasting"): g(arr1, jax.device_put(np_inp1, NamedSharding(mesh, P(('x', 'y'))))) @parameterized.named_parameters( @@ -5009,7 +5127,7 @@ def g(x, y): ('fsdp', P('x', None), P('x', None), P('x', None), 'all-gather'), ('half_tp', P(None, 'y'), P(None, 'y'), P(None, 'y'), 'all-gather'), ) - @jtu.with_user_mesh((2, 2), ('x', 'y')) + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) def test_dot_general(self, spec1, spec2, out_spec, collective_name, mesh): np_inp1 = np.arange(16.).reshape(8, 2) arr1 = jax.device_put(np_inp1, NamedSharding(mesh, spec1)) @@ -5051,7 +5169,7 @@ def g(x, y): self.assertEqual(out[1].sharding, arr2.sharding) @parameterized.parameters([True, False]) - @jtu.with_user_mesh((4,), ('x',)) + @jtu.with_explicit_mesh((4,), ('x',)) def test_dot_general_out_sharding(self, use_jit, mesh): np_inp1 = np.arange(16.).reshape(8, 2) arr1 = jax.device_put(np_inp1, NamedSharding(mesh, P('x', None))) @@ -5073,7 +5191,7 @@ def f(x, y): ValueError, 'PartitionSpec passed to einsum cannot contain axis names that are of' ' type Auto or Manual'): - auto_axes(f, out_shardings=P())(arr1, arr2) + auto_axes(f, out_sharding=P())(arr1, arr2) out = jax.grad(f, argnums=(0, 1))(arr1, arr2) self.assertEqual(out[0].sharding, arr1.sharding) @@ -5086,7 +5204,7 @@ def f(x, y): self.assertEqual(out[1].sharding, arr2.sharding) jaxpr = jitted_grad.trace(arr1, arr2).jaxpr - bwd_jaxpr = jaxpr.eqns[1] + bwd_jaxpr = jaxpr.eqns[-1] expected_spec = [('broadcast_in_dim', P('x', None)), ('dot_general', P('x', None)), ('transpose', P(None, 'x')), @@ -5098,16 +5216,16 @@ def f(x, y): @parameterized.named_parameters( ('fail1', P('x', None), P(None, 'x'), "dot_general operation.*produces an illegally sharded result", - TypeError), + core.ShardingTypeError), ('fail2', P('x', 'y'), P('x', 'y'), "dot_general requires contracting dimensions to have consistent sharding", - TypeError), + core.ShardingTypeError), ('contracting1', P('x', 'y'), P('y', None), - 'Contracting dimensions are sharded', ValueError), + 'Contracting dimensions are sharded', core.ShardingTypeError), ('other_half_tp', P(None, 'y'), P('y', None), - 'Contracting dimensions are sharded', ValueError), + 'Contracting dimensions are sharded', core.ShardingTypeError), ) - @jtu.with_user_mesh((2, 2), ('x', 'y')) + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) def test_dot_general_error(self, spec1, spec2, error_msg, error_type, mesh): np_inp1 = np.arange(16).reshape(8, 2) arr1 = jax.device_put(np_inp1, NamedSharding(mesh, spec1)) @@ -5120,26 +5238,26 @@ def f(x, y): with self.assertRaisesRegex(error_type, error_msg): f(arr1, arr2) - @jtu.with_user_mesh((2, 2, 1), ('x', 'y', 'z')) + @jtu.with_explicit_mesh((2, 2, 1), ('x', 'y', 'z')) def test_dot_general_batch_error(self, mesh): arr1 = jax.device_put(np.ones((8, 4, 2)), NamedSharding(mesh, P('x', 'y', 'z'))) arr2 = jax.device_put(np.ones((8, 2, 4)), NamedSharding(mesh, P('y', 'z', 'x'))) with self.assertRaisesRegex( - TypeError, + core.ShardingTypeError, 'dot_general requires lhs batch dimensions and rhs 
batch dimensions to' ' have the consistent sharding'): jax.lax.dot_general( arr1, arr2, dimension_numbers=(([2], [1]), ([0], [0]))) with self.assertRaisesRegex( - TypeError, + core.ShardingTypeError, 'dot_general requires lhs batch dimensions and rhs batch dimensions to' ' have the consistent sharding'): jnp.einsum('abc,acz->abz', arr1, arr2) - @jtu.with_user_mesh((2, 2), ('model', 'data')) + @jtu.with_explicit_mesh((2, 2), ('model', 'data')) def test_aval_repr(self, mesh): mesh = mesh.abstract_mesh aval = core.ShapedArray((128, 64), np.float32, @@ -5158,7 +5276,7 @@ def test_aval_repr(self, mesh): aval = aval.update(sharding=NamedSharding(mesh, P(('model', 'data'), None))) self.assertEqual(aval.str_short(), 'float32[128@(model,data),64]') - @jtu.with_user_mesh((2, 1), ('x', 'y')) + @jtu.with_explicit_mesh((2, 1), ('x', 'y')) def test_jnp_ones_mesh_context_eager(self, mesh): s = NamedSharding(mesh, P('x', None)) out = jnp.ones((8, 2), dtype=jnp.int32, device=s) @@ -5175,7 +5293,7 @@ def test_jnp_ones_mesh_context_eager(self, mesh): ('first2', 0, P(('x', 'y'), None), P(None), True), ('second2', 1, P(('x', 'y'), None), P(('x', 'y')), False), ) - @jtu.with_user_mesh((2, 2), ('x', 'y')) + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) def test_reduce_sum(self, axis, in_spec, out_spec, reduce, mesh): np_inp = np.arange(16).reshape(8, 2) s = NamedSharding(mesh, in_spec) @@ -5206,7 +5324,7 @@ def f(x): ('first2', 0, P(('x', 'y'), None), P(None), True), ('second2', 1, P(('x', 'y'), None), P(('x', 'y')), False), ) - @jtu.with_user_mesh((2, 2), ('x', 'y')) + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) def test_reduce_max(self, axis, in_spec, out_spec, reduce, mesh): np_inp = np.arange(16.).reshape(8, 2) s = NamedSharding(mesh, in_spec) @@ -5247,7 +5365,7 @@ def g(x): ('2', 2, P('x', 'y', None)), ('-1', -1, P('x', 'y', None)), ) - @jtu.with_user_mesh((2, 2), ('x', 'y')) + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) def test_broadcast_in_dim(self, axis, out_spec, mesh): np_inp = np.arange(16).reshape(8, 2) s = NamedSharding(mesh, P('x', 'y')) @@ -5273,7 +5391,7 @@ def f(x): ('3', 3), ('4', 4), ) - @jtu.with_user_mesh((2, 2), ('x', 'y')) + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) def test_integer_pow(self, pow, mesh): np_inp = np.arange(16).reshape(8, 2) s = NamedSharding(mesh, P('x', 'y')) @@ -5292,7 +5410,7 @@ def f(x): lowered_text = f.lower(arr).as_text() self.check_wsc_in_lowered(lowered_text) - @jtu.with_user_mesh((1,), 'x') + @jtu.with_explicit_mesh((1,), 'x') def test_broadcasting_nary_error(self, mesh): mesh2 = Mesh([jax.devices()[0]], 'y', axis_types=(mesh_lib.AxisType.Explicit,)) @@ -5308,7 +5426,7 @@ def f(x, y): ValueError, "For primitive.*context mesh.*aval mesh"): f(arr1, arr2) - @jtu.with_user_mesh((2, 2), ('x', 'y')) + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) def test_sin_unop(self, mesh): np_inp = np.arange(16.).reshape(8, 2) s = NamedSharding(mesh, P('x', 'y')) @@ -5326,7 +5444,7 @@ def f(x): lowered_text = f.lower(arr).as_text() self.check_wsc_in_lowered(lowered_text) - @jtu.with_user_mesh((2, 2), ('x', 'y')) + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) def test_jnp_array(self, mesh): np_inp = np.arange(16, dtype=jnp.int32).reshape(8, 2) s = NamedSharding(mesh, P('x', 'y')) @@ -5342,7 +5460,7 @@ def f(x): f(arr) - @jtu.with_user_mesh((2, 2, 1), ('x', 'y', 'z')) + @jtu.with_explicit_mesh((2, 2, 1), ('x', 'y', 'z')) def test_lax_transpose_rule(self, mesh): np_inp = np.arange(16).reshape(4, 2, 2) s = NamedSharding(mesh, P('x', 'y', 'z')) @@ -5361,7 +5479,7 @@ def f(x): lowered_text 
= f.lower(arr).as_text() self.check_wsc_in_lowered(lowered_text) - @jtu.with_user_mesh((2, 2), ('x', 'y')) + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) def test_broadcasted_iota_with_sharding(self, mesh): np_inp = np.arange(4) s = NamedSharding(mesh, P('x')) @@ -5386,7 +5504,7 @@ def g(x): _, out = g(arr) self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y'))) - @jtu.with_user_mesh((2, 2), ('x', 'y')) + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) def test_einsum_with_out_sharding(self, mesh): np_inp = np.arange(16.).reshape(8, 2) arr1 = jax.device_put(np_inp, NamedSharding(mesh, P('x', 'y'))) @@ -5431,7 +5549,7 @@ def h2(x, y): self.assertEqual(out[0].sharding, arr3.sharding) self.assertEqual(out[1].sharding, arr4.sharding) - @jtu.with_user_mesh((2, 2), ('x', 'y')) + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) def test_einsum_inverse(self, mesh): np_inp = np.arange(64.) @@ -5465,24 +5583,59 @@ def h2(x, y): self.assertEqual(out[0].sharding, arr1.sharding) self.assertEqual(out[1].sharding, arr2.sharding) - @parameterized.named_parameters( - ('1', (16, 1), (1, 16, 1), P('x', None), P(None, 'x', None), False), - ('2', (8, 2, 1), (1, 16, 1), P('x', None, None), P(None, 'x', None), True), - ('3', (8, 1), (1, 4, 2), P('x', None), P(None, None, 'x'), True), - ('4', (1, 4, 1, 6, 1), (1, 4, 6), - P(None, 'x', None, None, None), P(None, 'x', None), False), - ('5', (4, 6), (4, 6), P(None, 'x'), P(None, 'x'), False), + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) + def test_fully_replicated_reshape(self, mesh): + np_inp = np.arange(64).reshape(64, 1) + arr = jax.device_put(np_inp, P(('x', 'y'))) + + @jax.jit + def f(x): + x = reshard(x, P(None, None)) + return jax.lax.reshape(x, (2, 32, 1)) + + out = f(arr) + self.assertEqual(out.sharding, NamedSharding(mesh, P(None, None, None))) + self.assertArraysEqual(out, np_inp.reshape(2, 32, 1)) + + @parameterized.parameters( + (src_shape, dst_shape, src_spec, dst_spec, use_sharding_arg, fun) + for fun in [jnp.reshape, jax.lax.reshape] + for src_shape, dst_shape, src_spec, dst_spec, use_sharding_arg in [ + ((16, 1), (1, 16, 1), P('x', None), P(None, 'x', None), + False), + ((8, 2, 1), (1, 16, 1), P('x', None, None), + P(None, 'x', None), True), + ((8, 1), (1, 4, 2), P('x', None), P(None, None, 'x'), + True), + ((1, 4, 1, 6, 1), (1, 4, 6), + P(None, 'x', None, None, None), P(None, 'x', None), False), + ((4, 6), (4, 6), P(None, 'x'), P(None, 'x'), False), + ((1024, 4096), (1024, 2048, 2, 1, 1, 1, 1), + P('x', None), P('x', None, None, None, None, None, None), False), + ((1024, 4096, 32), (1024, 2048, 2, 1, 1, 32), + P('x', None, None), P('x', None, None, None, None, None), False), + ((1024, 4096), (1024, 1, 1, 4096), + P('x', None), P('x', None, None, None), False), + ((1024, 4096), (1024, 1, 1, 4096), + P(None, 'x'), P(None, None, None, 'x'), False), + ((1024, 2048, 2, 1, 1, 1), (1024, 4096), + P('x', None, None, None, None, None), P('x', None), False), + ((1024, 2048, 2, 1, 1, 1), (1024, 4096), + P(None, 'x', None, None, None, None), P(None, 'x'), False), + ] ) - @jtu.with_user_mesh((2,), ('x',)) + @jtu.with_explicit_mesh((2,), ('x',)) def test_reshape(self, src_shape, dst_shape, src_spec, dst_spec, - use_sharding_arg, mesh): + use_sharding_arg, fun, mesh): np_inp = np.arange(math.prod(src_shape), dtype=np.float32).reshape(src_shape) arr = jax.device_put(np_inp, NamedSharding(mesh, src_spec)) @partial(jax.jit, static_argnums=1) def f(x, new_sharding): - y = lax.reshape(x, dst_shape, out_sharding=new_sharding) + y = fun(x, dst_shape, 
out_sharding=new_sharding) + self.assertEqual(y.aval.sharding.spec, dst_spec) + self.assertEqual(y.shape, dst_shape) y = y * 2 self.assertEqual(y.aval.sharding.spec, dst_spec) return y @@ -5554,7 +5707,7 @@ def g(x): P(None, 'y', None, 'x'), None, 'This reshape is not supported' ), ) - @jtu.with_user_mesh((2, 2), ('x', 'y')) + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) def test_reshape_split_merge_one_axis(self, src_shape, dst_shape, src_spec, dst_spec, error_msg, mesh): np_inp = np.arange(math.prod(src_shape), @@ -5569,7 +5722,7 @@ def f(x): return y if error_msg: - with self.assertRaisesRegex(ValueError, error_msg): + with self.assertRaisesRegex(core.ShardingTypeError, error_msg): f(arr) else: out = f(arr) @@ -5586,7 +5739,7 @@ def g(x): out = jax.jit(jax.grad(g))(arr) self.assertEqual(out.sharding, arr.sharding) - @jtu.with_user_mesh((2, 2), ('x', 'y')) + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) def test_select(self, mesh): np_inp = np.arange(16).reshape(8, 2) s = NamedSharding(mesh, P('x', 'y')) @@ -5608,7 +5761,7 @@ def f(pred, on_true, on_false): arr3 = jax.device_put(np_inp, NamedSharding(mesh, P('y', 'x'))) with self.assertRaisesRegex( - TypeError, "select cases must have the same shardings"): + core.ShardingTypeError, "select cases must have the same shardings"): f(arr1 == arr2, arr1, arr3) def test_explicit_mode_no_context_mesh(self): @@ -5655,7 +5808,7 @@ def f(x): out = f(arr) self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None))) - @jtu.with_user_mesh((2, 2), ('x', 'y')) + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) def test_mesh_cast_reshard_error(self, mesh): np_inp = np.arange(16).reshape(8, 2) s = NamedSharding(mesh, P('x', 'y')) @@ -5682,7 +5835,7 @@ def g(x): ' mesh and the target mesh'): g(arr) - @jtu.with_user_mesh((2, 2), ('x', 'y'), + @jtu.with_explicit_mesh((2, 2), ('x', 'y'), axis_types=(AxisType.Explicit, AxisType.Auto)) def test_mesh_cast_explicit_data_movement_error(self, mesh): np_inp = np.arange(16).reshape(8, 2) @@ -5699,7 +5852,7 @@ def f(x): ValueError, 'Explicit data movement in mesh_cast is not allowed'): f(arr) - @jtu.with_user_mesh((2, 2), ('x', 'y')) + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) def test_shard_map_full_manual(self, mesh): np_inp = np.arange(16).reshape(8, 2) arr = jax.device_put(np_inp, NamedSharding(mesh, P('x', 'y'))) @@ -5725,7 +5878,7 @@ def f(x, y): self.assertArraysEqual(out, (np_inp * np_inp) * 2) self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y'))) - @jtu.with_user_mesh((2, 2), ('x', 'y')) + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) def test_shard_map_dot(self, mesh): np_inp = np.arange(16).reshape(8, 2) arr = jax.device_put(np_inp, NamedSharding(mesh, P('x', 'y'))) @@ -5753,7 +5906,15 @@ def f(x, y): self.assertArraysEqual(out, (np_inp @ np_inp.T) * 2) self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None))) - @jtu.with_user_mesh((2, 2), ('x', 'y')) + def test_full_like_eager_non_concrete_sharding(self): + s = NamedSharding(mesh_lib.AbstractMesh((2,), ('x',)), P('x')) + arr = jax.ShapeDtypeStruct((8, 2), np.float32, sharding=s) + out = jax.lax.full_like(arr, 0) + # The sharding is single device because the sharding of input `arr`` to + # full_like is not concrete. 
+ self.assertEqual(out.sharding, SingleDeviceSharding(jax.devices()[0])) + + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) def test_slice(self, mesh): np_inp = np.arange(16.).reshape(4, 4) arr = jax.device_put(np_inp, NamedSharding(mesh, P('x', None))) @@ -5778,13 +5939,13 @@ def g(x): out = jax.jit(jax.grad(g))(arr) self.assertEqual(out.sharding, arr.sharding) - with self.assertRaisesRegex(NotImplementedError, "slicing on sharded dims"): + with self.assertRaisesRegex(core.ShardingTypeError, "slicing on sharded dims"): f(jax.device_put(np_inp, NamedSharding(mesh, P('x', 'y')))) - with self.assertRaisesRegex(NotImplementedError, "slicing on sharded dims"): + with self.assertRaisesRegex(core.ShardingTypeError, "slicing on sharded dims"): f(jax.device_put(np_inp, NamedSharding(mesh, P(None, ('x', 'y'))))) - @jtu.with_user_mesh((2, 2), ('x', 'y')) + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) def test_squeeze(self, mesh): np_inp = np.arange(16.).reshape(4, 4, 1) arr = jax.device_put(np_inp, NamedSharding(mesh, P('x', None, None))) @@ -5810,7 +5971,7 @@ def g(x): out = jax.jit(jax.grad(g))(arr) self.assertEqual(out.sharding, arr.sharding) - @jtu.with_user_mesh((2, 2), ('x', 'y')) + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) def test_pad(self, mesh): np_inp = np.arange(8.) arr = jax.device_put(np_inp, NamedSharding(mesh, P('x'))) @@ -5842,17 +6003,17 @@ def g(x): out = jax.jit(jax.grad(g))(arr) self.assertEqual(out.sharding, arr.sharding) - with self.assertRaisesRegex(NotImplementedError, "padding on sharded dims"): + with self.assertRaisesRegex(core.ShardingTypeError, "padding on sharded dims"): f(arr, ((2, 3, 0), ), None) - with self.assertRaisesRegex(NotImplementedError, "padding on sharded dims"): + with self.assertRaisesRegex(core.ShardingTypeError, "padding on sharded dims"): f(arr, ((0, 3, 0), ), None) - with self.assertRaisesRegex(NotImplementedError, "padding on sharded dims"): + with self.assertRaisesRegex(core.ShardingTypeError, "padding on sharded dims"): arr = jax.device_put(np_inp, NamedSharding(mesh, P(('x', 'y')))) f(arr, ((4, 4, 1),), None) - @jtu.with_user_mesh((2, 1), ('x', 'y')) + @jtu.with_explicit_mesh((2, 1), ('x', 'y')) def test_concatenate(self, mesh): np_inp = np.arange(16.).reshape(4, 4) s = NamedSharding(mesh, P('x', 'y')) @@ -5879,7 +6040,7 @@ def f(x, y, method='jnp'): self.assertArraysEqual(out, np.concatenate([arr1, arr2], axis=1)) with self.assertRaisesRegex( - TypeError, "All operands should have the same sharding"): + core.ShardingTypeError, "All operands should have the same sharding"): arr3 = jax.device_put(np.arange(4.).reshape(4, 1), NamedSharding(mesh, P('x'))) f(arr1, arr3) @@ -5894,7 +6055,7 @@ def g(x, y): out = jax.jit(jax.grad(g))(arr1, arr2) self.assertEqual(out.sharding, s) - @jtu.with_user_mesh((2, 2), ('x', 'y')) + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) def test_scan(self, mesh): carry = jax.device_put(np.arange(16.).reshape(2, 8), NamedSharding(mesh, P(None, 'x'))) @@ -5932,7 +6093,7 @@ def g(carry, arr): ValueError, "0th dimension of all xs should be replicated"): f(carry, jax.device_put(arr, NamedSharding(mesh, P('x', None, None)))) - @jtu.with_user_mesh((2, 2), ('x', 'y')) + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) def test_argminmax(self, mesh): np_inp = np.arange(16.).reshape(8, 2) s = NamedSharding(mesh, P('x', 'y')) @@ -5953,7 +6114,7 @@ def f(x): self.assertEqual(out2.sharding, NamedSharding(mesh, P('x'))) self.check_wsc_in_lowered(f.lower(arr).as_text()) - @jtu.with_user_mesh((2, 2), ('x', 'y'), (mesh_lib.AxisType.Auto,) * 2) + 
@jtu.with_explicit_mesh((2, 2), ('x', 'y'), (mesh_lib.AxisType.Auto,) * 2) def test_only_auto(self, mesh): np_inp = np.arange(16.).reshape(8, 2) arr = jax.device_put(np_inp, NamedSharding(mesh, P('x', None))) @@ -6025,7 +6186,7 @@ def f(x, x2): "AxisTypes should be the same in a tuple subset of PartitionSpec"): NamedSharding(mesh2, P(('x', 'y'))) - @jtu.with_user_mesh((2, 2), ('x', 'y')) + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) def test_where_with_scalar(self, mesh): np_inp = np.arange(16.).reshape(8, 2) s = NamedSharding(mesh, P('x', 'y')) @@ -6035,7 +6196,7 @@ def test_where_with_scalar(self, mesh): self.assertArraysEqual(out, x) self.assertEqual(out.sharding, s) - @jtu.with_user_mesh((2, 2), ('x', 'y')) + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) def test_full_user_to_full_auto(self, mesh): np_inp = np.arange(16.).reshape(8, 2) s = NamedSharding(mesh, P('x', 'y')) @@ -6062,7 +6223,7 @@ def f(x): out2 = core.jaxpr_as_fun(jaxpr)(arr) self.assertEqual(out2[0].sharding, NamedSharding(mesh, P('x', None))) - @jtu.with_user_mesh((2, 2), ('x', 'y'), + @jtu.with_explicit_mesh((2, 2), ('x', 'y'), axis_types=(mesh_lib.AxisType.Auto,) * 2) def test_full_auto_to_full_user(self, mesh): np_inp = np.arange(16.).reshape(8, 2) @@ -6087,7 +6248,7 @@ def f(x): jaxpr = f.trace(arr).jaxpr core.jaxpr_as_fun(jaxpr)(arr) # doesn't crash - @jtu.with_user_mesh((2, 2), ('x', 'y')) + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) def test_full_user_to_auto_user_mix(self, mesh): np_inp = np.arange(16.).reshape(8, 2) s = NamedSharding(mesh, P('x', 'y')) @@ -6114,7 +6275,7 @@ def f(x): out2 = core.jaxpr_as_fun(jaxpr)(arr) self.assertEqual(out2[0].sharding, NamedSharding(mesh, P('x', None))) - @jtu.with_user_mesh((2, 1), ('x', 'y')) + @jtu.with_explicit_mesh((2, 1), ('x', 'y')) def test_user_auto_mix_error(self, mesh): np_inp = np.arange(16.).reshape(8, 2) s = NamedSharding(mesh, P('x', 'y')) @@ -6131,7 +6292,7 @@ def f(x, y): ValueError, "For primitive dot_general, context mesh.*aval mesh"): f(arr, arr.T) - @jtu.with_user_mesh((2, 2), ('x', 'y')) + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) def test_split(self, mesh): np_inp = np.arange(16.).reshape(8, 2) s = NamedSharding(mesh, P('x', 'y')) @@ -6147,7 +6308,7 @@ def f(x, sizes=(4, 4), axis=0): f(arr) self.check_wsc_in_lowered(f.lower(arr).as_text()) - with self.assertRaisesRegex(NotImplementedError, "split on sharded dims"): + with self.assertRaisesRegex(core.ShardingTypeError, "split on sharded dims"): f(arr, sizes=(1, 1), axis=1) def g(x): @@ -6160,7 +6321,7 @@ def g(x): out = jax.jit(jax.grad(g))(arr) self.assertEqual(out.sharding, s) - @jtu.with_user_mesh((2,), 'x') + @jtu.with_explicit_mesh((2,), 'x') def test_return_output_different_context(self, mesh): np_inp = np.arange(16).reshape(8, 2) s = NamedSharding(mesh, P('x')) @@ -6179,7 +6340,7 @@ def f(x): self.assertDictEqual(out.sharding.mesh._axis_types_dict, {AxisType.Auto: ('x',)}) - @jtu.with_user_mesh((2,), 'x') + @jtu.with_explicit_mesh((2,), 'x') def test_device_put_use_mesh(self, mesh): out = jax.device_put(np.arange(8), P('x')) self.assertArraysEqual(out, np.arange(8)) @@ -6192,7 +6353,7 @@ def test_device_put_no_use_mesh_error(self): ' passed to device_put'): jax.device_put(np.arange(8), P('x')) - @jtu.with_user_mesh((2,), 'x') + @jtu.with_explicit_mesh((2,), 'x') def test_inputs_different_context(self, mesh): np_inp = np.arange(16).reshape(8, 2) s = NamedSharding(mesh, P('x')) @@ -6214,7 +6375,7 @@ def f(x, y): self.assertDictEqual(out2.sharding.mesh._axis_types_dict, {AxisType.Auto: ('x',)}) - 
@jtu.with_user_mesh((2,), 'x') + @jtu.with_explicit_mesh((2,), 'x') def test_output_different_context_error(self, mesh): np_inp1 = np.arange(16).reshape(8, 2) arr1 = jax.device_put(np_inp1, NamedSharding(mesh, P('x', None))) @@ -6241,7 +6402,7 @@ def g(x, y): ValueError, "PartitionSpec.*cannot contain axis names.*Auto"): g(arr1, arr2) - @jtu.with_user_mesh((2, 2, 2), ('x', 'y', 'z'), + @jtu.with_explicit_mesh((2, 2, 2), ('x', 'y', 'z'), axis_types=(AxisType.Explicit, AxisType.Explicit, AxisType.Auto)) def test_out_sharding_mix_axis_types(self, mesh): @@ -6266,13 +6427,13 @@ def f(x): else: self.assertTrue(lowered_text.count("unspecified_dims=[1,2]") == 3) - @jtu.with_user_mesh((2, 2), ('x', 'y')) + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) def test_auto_mode_mix(self, mesh): np_inp = np.arange(16.).reshape(8, 2) s = NamedSharding(mesh, P('x', 'y')) arr = jax.device_put(np_inp, s) - @partial(auto_axes, axes='x', out_shardings=P('x', None)) + @partial(auto_axes, axes='x', out_sharding=P('x', None)) def h(y): self.assertEqual(y.aval.sharding.spec, P(None, 'y')) z = jnp.sin(y) @@ -6295,7 +6456,7 @@ def g(x): out2 = core.jaxpr_as_fun(jaxpr)(arr) self.assertEqual(out2[0].sharding, NamedSharding(mesh, P('x', None))) - @jtu.with_user_mesh((4,), ('x',)) + @jtu.with_explicit_mesh((4,), ('x',)) def test_concat_vmap(self, mesh): @jax.jit def _f(sharded_array, replicated_array): @@ -6326,7 +6487,7 @@ def test_aval_spec_explicit_auto_complete(self): out = core.ShapedArray((8, 2), jnp.int32, sharding=s) self.assertEqual(out.sharding.spec, P('x', None)) - @jtu.with_user_mesh((2, 2), ('x', 'y'), + @jtu.with_explicit_mesh((2, 2), ('x', 'y'), axis_types=(mesh_lib.AxisType.Auto,) * 2) def test_full_user_mode(self, mesh): np_inp = np.arange(16.).reshape(8, 2) @@ -6334,7 +6495,7 @@ def test_full_user_mode(self, mesh): arr = jax.device_put(np_inp, s) # No axes specified means full visible mode. 
- @partial(explicit_axes, in_shardings=P('x', 'y')) + @partial(explicit_axes, in_sharding=P('x', 'y')) def h(y): self.assertEqual(y.aval.sharding.spec, P('x', 'y')) z = jnp.sin(y) @@ -6356,12 +6517,12 @@ def f(x): jaxpr = f.trace(arr).jaxpr core.jaxpr_as_fun(jaxpr)(arr) # doesn't crash - @jtu.with_user_mesh((4,), ('data',)) + @jtu.with_explicit_mesh((4,), ('data',)) def test_intermediate_einsum(self, mesh): shape1 = (8, 32, 1, 16) shape2 = (8, 32, 1, 8) - np_inp1 = np.arange(math.prod(shape1)).reshape(shape1) - np_inp2 = np.arange(math.prod(shape2)).reshape(shape2) + np_inp1 = np.ones(math.prod(shape1)).reshape(shape1) + np_inp2 = np.ones(math.prod(shape2)).reshape(shape2) s = NamedSharding(mesh, P('data')) arr1 = jax.device_put(np_inp1, s) @@ -6380,16 +6541,16 @@ def f(x, y, z): self.assertEqual(out.shape, (16, 8, 16)) self.assertEqual(out.sharding, NamedSharding(mesh, P('data', None, None))) - @jtu.with_user_mesh((4,), ('data',)) + @jtu.with_explicit_mesh((4,), ('data',)) def test_intermediate_einsum_auto_complete_spec(self, mesh): s = NamedSharding(mesh, P('data')) shape1 = (8, 32, 2*16) shape2 = (8, 32, 2, 8) shape3 = (8, 32, 2, 8) - np_inp1 = np.arange(math.prod(shape1)).reshape(shape1) - np_inp2 = np.arange(math.prod(shape2)).reshape(shape2) - np_inp3 = np.arange(math.prod(shape3)).reshape(shape3) + np_inp1 = np.ones(math.prod(shape1)).reshape(shape1) + np_inp2 = np.ones(math.prod(shape2)).reshape(shape2) + np_inp3 = np.ones(math.prod(shape3)).reshape(shape3) arr1 = jax.device_put(np_inp1, s) arr2 = jax.device_put(np_inp2, s) @@ -6432,12 +6593,12 @@ def f(condition, x, y): f = jax.jit(f, in_shardings=(sharding, sharding, sharding)) f(condition, x, x).block_until_ready() - @jtu.with_user_mesh((4,), ('data',)) + @jtu.with_explicit_mesh((4,), ('data',)) def test_intermediate_einsum_conflict_error(self, mesh): shape1 = (8, 32, 1, 16) shape2 = (8, 32, 1, 8) - np_inp1 = np.arange(math.prod(shape1)).reshape(shape1) - np_inp2 = np.arange(math.prod(shape2)).reshape(shape2) + np_inp1 = np.ones(math.prod(shape1)).reshape(shape1) + np_inp2 = np.ones(math.prod(shape2)).reshape(shape2) arr1 = jax.device_put( np_inp1, NamedSharding(mesh, P(None, None, None, 'data'))) @@ -6452,11 +6613,11 @@ def f(x, y, z): # Errors out on the intermediate einsum: `bthj,bthD->bthjD` # because of a conflict with self.assertRaisesRegex( - TypeError, + core.ShardingTypeError, 'dot_general operation.*produces an illegally sharded result'): f(arr1, arr2, arr3) - @jtu.with_user_mesh((2, 2), ('x', 'y'), + @jtu.with_explicit_mesh((2, 2), ('x', 'y'), axis_types=(mesh_lib.AxisType.Explicit, mesh_lib.AxisType.Auto)) def test_mix_to_full_user_mode(self, mesh): @@ -6464,7 +6625,7 @@ def test_mix_to_full_user_mode(self, mesh): s = NamedSharding(mesh, P('x', 'y')) arr = jax.device_put(np_inp, s) - @partial(explicit_axes, axes='y', in_shardings=P('x', 'y')) + @partial(explicit_axes, axes='y', in_sharding=P('x', 'y')) def h(y): self.assertEqual(y.aval.sharding.spec, P('x', 'y')) z = jnp.sin(y) @@ -6483,14 +6644,14 @@ def f(x): out = f(arr) self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y'))) - @jtu.with_user_mesh((2, 2), ('x', 'y'), + @jtu.with_explicit_mesh((2, 2), ('x', 'y'), axis_types=(mesh_lib.AxisType.Auto,) * 2) def test_full_auto_to_partial_user(self, mesh): np_inp = np.arange(16.).reshape(8, 2) s = NamedSharding(mesh, P('x', 'y')) arr = jax.device_put(np_inp, s) - @partial(explicit_axes, axes='y', in_shardings=P(None, 'y')) + @partial(explicit_axes, axes='y', in_sharding=P(None, 'y')) def h(y): 
self.assertEqual(y.aval.sharding.spec, P(None, 'y')) z = jnp.sin(y) @@ -6509,7 +6670,7 @@ def f(x): out = f(arr) self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y'))) - @jtu.with_user_mesh((2, 2), ('x', 'y')) + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) def test_auto_gather_out_sharding(self, mesh): embed = jax.device_put(jnp.arange(128 * 8.).reshape(64, 16), jax.NamedSharding(mesh, P(None, 'x'))) @@ -6544,7 +6705,7 @@ def g(x, y): out = jax.jit(jax.grad(g))(embed, tok) self.assertEqual(out.sharding, embed.sharding) - @jtu.with_user_mesh((2, 2), ('x', 'y')) + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) def test_reshard_error(self, mesh): np_inp = np.arange(16.).reshape(8, 2) s = NamedSharding(mesh, P('x', 'y')) @@ -6604,7 +6765,7 @@ def test_auto_axes_top_level(self): arr1 = jax.device_put(np_inp, NamedSharding(mesh, P('x', 'y'))) arr2 = jax.device_put(np_inp.T, NamedSharding(mesh, P('y', 'x'))) - @partial(auto_axes, out_shardings=P('x', None)) + @partial(auto_axes, out_sharding=P('x', None)) def auto_matmul(arr1, arr2): return arr1 @ arr2 @@ -6626,7 +6787,7 @@ def test_explicit_axes_top_level(self): arr1 = jax.device_put(np_inp, NamedSharding(mesh, P('x', 'y'))) arr2 = jax.device_put(np_inp.T, NamedSharding(mesh, P('y', 'x'))) - @partial(explicit_axes, in_shardings=(P('x', None), P('x', None))) + @partial(explicit_axes, in_sharding=(P('x', None), P('x', None))) def jax_matmul(arr1, arr2): out = arr1 @ arr2 self.assertEqual(out.aval.sharding.spec, P('x', None)) @@ -6659,7 +6820,7 @@ def matmul_reshard(arr1, arr2): with jax.sharding.use_mesh(mesh): matmul_reshard(arr1, arr2) - @jtu.with_user_mesh((2, 2), ('x', 'y')) + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) def test_full_auto_outside_jit(self, mesh): np_inp = np.arange(16.).reshape(8, 2) s = NamedSharding(mesh, P('x', 'y')) @@ -6675,11 +6836,11 @@ def f(x): self.assertEqual(a.aval.sharding.spec, P(None, None)) return a - hf = auto_axes(f, axes=('x', 'y'), out_shardings=P('x', 'y')) + hf = auto_axes(f, axes=('x', 'y'), out_sharding=P('x', 'y')) out = hf(arr) self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y'))) - @jtu.with_user_mesh((2, 2), ('x', 'y'), + @jtu.with_explicit_mesh((2, 2), ('x', 'y'), axis_types=(AxisType.Auto,) * 2) def test_full_visible_outside_jit(self, mesh): np_inp = np.arange(16.).reshape(8, 2) @@ -6694,7 +6855,7 @@ def f(x): self.assertEqual(z.aval.sharding.spec, P('x', 'y')) return z - hf = explicit_axes(f, axes=('x', 'y'), in_shardings=P('x', 'y')) + hf = explicit_axes(f, axes=('x', 'y'), in_sharding=P('x', 'y')) out = hf(arr) # doesn't crash self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y'))) @@ -6732,7 +6893,7 @@ def f(x): self.assertTupleEqual(out2.sharding._device_assignment, tuple(mesh2.devices.flat)) - @jtu.with_user_mesh((2, 1), ('x', 'y')) + @jtu.with_explicit_mesh((2, 1), ('x', 'y')) def test_svd(self, mesh): np_inp = np.zeros([128, 128]) arr = jax.device_put(np_inp, NamedSharding(mesh, P(None, None))) @@ -6757,7 +6918,7 @@ def f(x, y): self.assertNotIn("mhlo.sharding", lowered_text) @parameterized.parameters(True, False) - @jtu.with_user_mesh((2, 2), ('x', 'y')) + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) def test_mul_vmap(self, use_jit, mesh): np_inp = np.arange(16.).reshape(8, 2) s = NamedSharding(mesh, P('x', 'y')) @@ -6794,7 +6955,7 @@ def g(x): self.assertEqual(out.sharding, arr.sharding) @parameterized.parameters(True, False) - @jtu.with_user_mesh((2, 2), ('x', 'y')) + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) def test_dot_general_vmap(self, use_jit, mesh): 
np_inp1 = np.arange(16.).reshape(4, 2, 2) np_inp2 = np.arange(16.).reshape(2, 4, 2) @@ -6813,7 +6974,7 @@ def f(x, y): self.assertEqual(out.shape, (2, 2, 4)) self.assertEqual(out.sharding, NamedSharding(mesh, P(None, 'y', 'x'))) - @jtu.with_user_mesh((2, 2), ('x', 'y')) + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) def test_reshape_vmap(self, mesh): np_inp = np.arange(16).reshape(2, 8) arr = jax.device_put(np_inp, NamedSharding(mesh, P(None, 'x'))) @@ -6829,7 +6990,7 @@ def f(x): self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None, 'y'))) @parameterized.parameters(True, False) - @jtu.with_user_mesh((2, 2), ('x', 'y')) + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) def test_shit_vmap_error_check(self, use_jit, mesh): np_inp = np.arange(16).reshape(8, 2) arr = jax.device_put(np_inp, NamedSharding(mesh, P('x', None))) @@ -6858,7 +7019,7 @@ def f(x, y): "Only one of spmd_axis_name or arrays sharded on.*spmd_axis_name"): jax.vmap(f, spmd_axis_name='y')(arr, arr) - @jtu.with_user_mesh((2,), ('x',)) + @jtu.with_explicit_mesh((2,), ('x',)) def test_unmapped_last_vmap(self, mesh): np_inp = np.arange(8) arr = jax.device_put(np_inp, NamedSharding(mesh, P('x',))) @@ -6871,7 +7032,7 @@ def f(x): self.assertEqual(out.shape, (4, 8)) self.assertEqual(out.sharding, NamedSharding(mesh, P(None, 'x'))) - @jtu.with_user_mesh((2,), ('x',), axis_types=AxisType.Auto) + @jtu.with_explicit_mesh((2,), ('x',), axis_types=AxisType.Auto) def test_shmap_close_over(self, mesh): const = jnp.arange(8) def f(): @@ -6881,7 +7042,7 @@ def f(): shmap_f() # doesn't crash jax.jit(shmap_f)() # doesn't crash - @jtu.with_user_mesh((2, 2), ('x', 'y'), + @jtu.with_explicit_mesh((2, 2), ('x', 'y'), axis_types=(AxisType.Auto,) * 2) def test_shmap_close_over_partial_auto(self, mesh): const = jnp.arange(8) @@ -6889,7 +7050,7 @@ def f(): return const * 2 shmap_f = shard_map(f, mesh=mesh, in_specs=(), out_specs=P('x'), - auto=frozenset({'y'})) + axis_names={'x'}) f = jax.jit(shmap_f) out = f() self.assertArraysEqual(out, jnp.concatenate([const * 2, const * 2])) @@ -6897,7 +7058,7 @@ def f(): jaxpr = f.trace().jaxpr self.assertIn('mesh_cast', str(jaxpr)) - @jtu.with_user_mesh((2, 1), ('x', 'y')) + @jtu.with_explicit_mesh((2, 1), ('x', 'y')) def test_wsc_error(self, mesh): s = NamedSharding(mesh, P('x')) with self.assertRaisesRegex( @@ -6911,8 +7072,11 @@ def test_wsc_error(self, mesh): "The spec of NamedSharding passed to with_sharding_constraint"): jax.lax.with_sharding_constraint(np.arange(8).reshape(4, 2), s) - s = NamedSharding(mesh, P()) - jax.lax.with_sharding_constraint(np.arange(8), s) + with self.assertRaisesRegex( + ValueError, + 'with_sharding_constraint cannot be used when all axes of the mesh are' + ' of type `Explicit`'): + jax.lax.with_sharding_constraint(np.arange(8), NamedSharding(mesh, P())) s = NamedSharding(Mesh(mesh.devices, mesh.axis_names, axis_types=(AxisType.Explicit, AxisType.Auto)), @@ -6944,7 +7108,7 @@ def f(x, y): "Using PartitionSpec when.*not under a mesh context.*is not allowed"): f(arr, arr2) - @jtu.with_user_mesh((2, 1), ('x', 'y'), + @jtu.with_explicit_mesh((2, 1), ('x', 'y'), axis_types=(AxisType.Auto,) * 2) def test_error_on_canonicalize_under_auto_mode(self, mesh): np_inp = np.arange(16).reshape(8, 2) @@ -6961,7 +7125,41 @@ def f(x, y): "PartitionSpec passed to einsum cannot contain axis names.*Auto.*Manual"): f(arr, arr2) - @jtu.with_user_mesh((2,), ('x',)) + def test_broadcasted_iota_mix_axes(self): + mesh = jtu.create_mesh( + (2, 2, 2), ('x', 'y', 'z'), + axis_types=(AxisType.Auto, 
AxisType.Explicit, AxisType.Explicit)) + yz_sharding = NamedSharding(mesh, P(('y', 'z'))) + + @jax.jit + def iota(): + out = jax.lax.broadcasted_iota( + dtype=jnp.int32, + shape=(16, 24), + dimension=1, + out_sharding=yz_sharding) + self.assertEqual(out.aval.sharding.spec, P(('y', 'z'), None)) + return out + + with jax.sharding.use_mesh(mesh): + out = iota() + self.assertEqual(out.sharding, yz_sharding) + + @jtu.with_explicit_mesh((2, 2, 2), ('x', 'y', 'z')) + def test_broadcast_to(self, mesh): + x = np.arange(24).reshape((1, 24)) + x = jax.device_put(x, P(None, ('y', 'z'))) + + @jax.jit + def f(x): + out = jnp.broadcast_to(x, (8, 24), out_sharding=P('x', ('y', 'z'))) + self.assertEqual(out.aval.sharding.spec, P('x', ('y', 'z'))) + return out + + out = f(x) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x', ('y', 'z')))) + + @jtu.with_explicit_mesh((2,), ('x',)) def test_cumsum(self, mesh): np_inp = np.arange(16).reshape(8, 2) arr = jax.device_put(np_inp, NamedSharding(mesh, P())) @@ -6974,6 +7172,18 @@ def f(x): self.assertArraysEqual(out, np.cumsum(np_inp)) self.assertEqual(out.sharding, NamedSharding(mesh, P(None))) + @jax.jit + def f(x): + x = jnp.expand_dims(x, 1) + self.assertEqual(x.aval.sharding.spec, P('x', None)) + out = jnp.cumsum(x, axis=1) + self.assertEqual(out.aval.sharding.spec, P('x', None)) + return out + + arr2 = jax.device_put(np.arange(8), P('x')) + out = f(arr2) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None))) + def test_device_put_under_use_mesh(self): mesh = jtu.create_mesh((2, 2), ('x', 'y')) x = jnp.zeros((4, 4), dtype=jnp.int32) @@ -7021,26 +7231,26 @@ def test_wsc_pspec_use_mesh(self, sharded_inp): self.assertArraysEqual(out2, np_inp) self.assertEqual(out2.sharding, NamedSharding(mesh, P('x', 'y'))) - @jtu.with_user_mesh((2, 1), ('x', 'y'), + @jtu.with_explicit_mesh((2, 1), ('x', 'y'), axis_types=(AxisType.Auto,) * 2) def test_axes_api_error_manual_to_auto_explicit(self, mesh): def g(x): return auto_axes(lambda a: a * 2, axes=('x', 'y'), - out_shardings=P('x', 'y'))(x) + out_sharding=P('x', 'y'))(x) with self.assertRaisesRegex( NotImplementedError, "Going from `Manual`.*to.*`Auto`.*`Explicit`"): jax.jit(shard_map(g, mesh=mesh, in_specs=P('x', 'y'), out_specs=P('x', 'y')) )(np.arange(16).reshape(8, 2)) - @jtu.with_user_mesh((2,), ('x',)) + @jtu.with_explicit_mesh((2,), ('x',)) def test_auto_axes_numpy_array(self, mesh): @jax.jit def f(x): self.assertTrue(x.aval.sharding.mesh._are_all_axes_auto) return x * 2 - out = auto_axes(f, out_shardings=P('x'))(np.arange(8)) + out = auto_axes(f, out_sharding=P('x'))(np.arange(8)) self.assertEqual(out.sharding, NamedSharding(mesh, P('x'))) self.assertArraysEqual(out, np.arange(8) * 2) @@ -7051,7 +7261,7 @@ def f(x): jtu.dtypes.all_integer + jtu.dtypes.all_unsigned), shape_and_spec=[((), P()), ((2,), P('x')), ((2, 4), P('x', 'y'))], ) - @jtu.with_user_mesh((2, 2), ('x', 'y')) + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) def test_bitcast_convert_type(self, from_dtype, to_dtype, shape_and_spec, mesh): shape, spec = shape_and_spec @@ -7090,7 +7300,31 @@ def f(x): self.assertEqual(out.shape, expected_shape) self.assertEqual(out.sharding, NamedSharding(mesh, expected_spec)) - def test_auto_axes_computation_follows_data_error(self): + @jtu.with_explicit_mesh((2,), ('x',)) + def test_dynamic_slice(self, mesh): + np_inp = np.arange(16., dtype=np.float32) + s = NamedSharding(mesh, P('x')) + arr = jax.device_put(np_inp, s) + + @jax.jit + def f(x): + y = lax.dynamic_slice_in_dim(x, jnp.array(1, 
dtype=np.int32), 2) + self.assertEqual(y.aval.sharding.spec, P('x')) + return y + + out = f(arr) + self.assertEqual(out.sharding, s) + + def g(x): + return jnp.sum(f(x)) + + out = jax.jit(jax.grad(g))(arr) + self.assertEqual(out.sharding, arr.sharding) + + out = jax.grad(g)(arr) + self.assertEqual(out.sharding, arr.sharding) + + def test_auto_axes_computation_follows_data(self): mesh = jtu.create_mesh((2,), ('x',), axis_types=(AxisType.Explicit,)) s = NamedSharding(mesh, P('x')) arr = jax.device_put(np.arange(8), s) @@ -7099,8 +7333,9 @@ def test_auto_axes_computation_follows_data_error(self): def f(x): return x * 2 - with self.assertRaisesRegex(ValueError, "Context mesh.*cannot be empty"): - auto_axes(f, out_shardings=s)(arr) + out = auto_axes(f, out_sharding=s)(arr) + self.assertEqual(out.sharding, s) + self.assertArraysEqual(out, arr * 2) def test_divisbility_aval_error(self): abstract_mesh = mesh_lib.AbstractMesh( @@ -7110,7 +7345,7 @@ def test_divisbility_aval_error(self): ValueError, 'does not evenly divide the dimension size'): core.ShapedArray((5, 2), jnp.int32, sharding=s) - @jtu.with_user_mesh((2, 2), ('x', 'y')) + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) def test_scan_unroll(self, mesh): np_inp = np.arange(64, dtype=jnp.float32).reshape(8, 8) arr = jax.device_put(np_inp, NamedSharding(mesh, P(None, 'y'))) @@ -7124,7 +7359,7 @@ def body(carry, x): f(carry, arr) # doesn't crash - @jtu.with_user_mesh((2,), ('x',)) + @jtu.with_explicit_mesh((2,), ('x',)) def test_reshard_with_np_array(self, mesh): out = reshard(np.arange(8), P('x')) self.assertEqual(out.sharding, NamedSharding(mesh, P('x'))) @@ -7135,6 +7370,7 @@ def f(x): out = f(np.arange(8)) self.assertEqual(out.sharding, NamedSharding(mesh, P('x'))) + @jtu.thread_unsafe_test() def test_set_mesh(self): mesh = jtu.create_mesh((2,), ('x',), axis_types=(AxisType.Explicit,)) try: @@ -7142,28 +7378,578 @@ def test_set_mesh(self): out = reshard(np.arange(8), P('x')) self.assertEqual(out.sharding, NamedSharding(mesh, P('x'))) finally: + self.assertIsNone(prev_mesh) jax.sharding.set_mesh(prev_mesh) - @jtu.with_user_mesh((2,), ('x',)) + @jtu.with_explicit_mesh((2,), ('x',)) def test_auto_axes_late_bind(self, mesh): @auto_axes def f(x): return x * 2 - out = f(np.arange(8), out_shardings=P('x')) + out = f(np.arange(8), out_sharding=P('x')) self.assertEqual(out.sharding, NamedSharding(mesh, P('x'))) self.assertArraysEqual(out, np.arange(8) * 2) - @jtu.with_user_mesh((2,), ('x',), axis_types=AxisType.Auto) + @jtu.with_explicit_mesh((2,), ('x',), axis_types=AxisType.Auto) def test_explicit_axes_late_bind(self, mesh): @explicit_axes def f(x): return x * 2 - out = f(np.arange(8), in_shardings=P('x')) + out = f(np.arange(8), in_sharding=P('x')) self.assertEqual(out.sharding, NamedSharding(mesh, P('x'))) self.assertArraysEqual(out, np.arange(8) * 2) + @jtu.with_explicit_mesh((2,), ('x',)) + def test_rng_bit_generator(self, mesh): + def f(key): + out = lax.rng_bit_generator(key, shape=(4, 8), out_sharding=P('x')) + self.assertEqual(out[0].aval.sharding.spec, P(None)) + self.assertEqual(out[1].aval.sharding.spec, P('x', None)) + return out + + key = np.array((1, 2, 3, 4)).astype(np.uint32) + out1 = f(key) + jit_f = jax.jit(f) + out2 = jit_f(key) + self.assertEqual(out1[0].shape, (4,)) + self.assertEqual(out1[1].shape, (4, 8)) + self.assertEqual(out2[0].sharding, NamedSharding(mesh, P())) + self.assertEqual(out2[1].sharding, NamedSharding(mesh, P('x', None))) + self.assertEqual(out1[0].sharding, out2[0].sharding) + 
self.assertEqual(out1[1].sharding, out2[1].sharding) + self.assertArraysEqual(out1[0], out2[0]) + self.assertArraysEqual(out1[1], out2[1]) + + @jtu.with_explicit_mesh((2,), ('x',)) + def test_fold_in(self, mesh): + key = jax.random.key(72) + key = jax.device_put(key, NamedSharding(mesh, P())) + + @jax.jit + def f(key): + f1 = jax.random.fold_in(key, 1) + self.assertEqual(jax.random.key_data(f1).aval.sharding.spec, P(None)) + return f1 + + out = f(key) + self.assertEqual(out.sharding, NamedSharding(mesh, P())) + + @parameterized.named_parameters( + ("bits", partial(jax.random.bits, shape=(8, 12)), P('x', 'y')), + ("uniform", partial(jax.random.uniform, shape=(8, 12)), P('x', 'y')), + ("normal", partial(jax.random.normal, shape=(8, 12)), P('x', 'y')), + ("randint", partial(jax.random.randint, shape=(8, 12), minval=0, maxval=10), + P('x', 'y')), + ("permutation_1d", partial(jax.random.permutation, x=8), P('x')), + ("permutation_2d", partial(jax.random.permutation, + x=np.arange(8 * 12).reshape(8, 12)), + P('x', 'y')), + ) + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) + def test_random_functions(self, fun, out_spec, mesh): + @jax.jit + def f(key): + out = fun(key, out_sharding=out_spec) + self.assertEqual(out.aval.sharding.spec, out_spec) + return out + + key = jax.random.key(1) + out = f(key) + self.assertEqual(out.sharding, NamedSharding(mesh, out_spec)) + + lowered_text = f.lower(key).as_text() + if config.use_shardy_partitioner.value: + self.assertIn('sdy.sharding_constraint', lowered_text) + if out_spec == P('x', 'y'): + self.assertIn('<@mesh, [{"x"}, {"y"}]>', lowered_text) + else: + assert out_spec == P('x') + self.assertIn('<@mesh, [{"x"}]>', lowered_text) + else: + if out_spec == P('x', 'y'): + self.assertIn('mhlo.sharding = "{devices=[2,2]<=[4]}"}', lowered_text) + else: + assert out_spec == P('x') + self.assertIn( + 'mhlo.sharding = "{devices=[2,2]<=[4] last_tile_dim_replicate}"}', + lowered_text) + + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) + def test_random_truncated_normal(self, mesh): + @jax.jit + def f(key, lower): + out = jax.random.truncated_normal(key, lower, 2., shape=(8, 12), + out_sharding=P('x', 'y')) + self.assertEqual(out.aval.sharding.spec, P('x', 'y')) + return out + + key = jax.random.key(1) + out = f(key, -1.) 
+ self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y'))) + + lowered_text = f.lower(key, -1.).as_text() + if config.use_shardy_partitioner.value: + self.assertIn('sdy.sharding_constraint', lowered_text) + self.assertIn('<@mesh, [{"x"}, {"y"}]>', lowered_text) + else: + self.assertIn('mhlo.sharding = "{devices=[2,2]<=[4]}"}', lowered_text) + + def test_random_normal_wo_mesh_context_error(self): + mesh = jtu.create_mesh((2, 2), ('x', 'y'), + axis_types=(AxisType.Explicit,) * 2) + s = NamedSharding(mesh, P('x', 'y')) + + @jax.jit + def f(key): + out = jax.random.normal(key, shape=(8, 12), out_sharding=s) + self.assertEqual(out.aval.sharding.spec, P('x', 'y')) + self.assertEqual(out.aval.sharding.mesh, mesh.abstract_mesh) + return out + + key = jax.random.key(1) + with self.assertRaisesRegex( + ValueError, + 'Length of device assignment.*is not equal to the size of the mesh'): + f(key) + + def test_random_normal_wo_mesh_context(self): + mesh = jtu.create_mesh((2, 2), ('x', 'y'), + axis_types=(AxisType.Explicit,) * 2) + s = NamedSharding(mesh, P('x', 'y')) + + @jax.jit + def f(arr, key): + out = jax.random.normal(key, shape=(8, 12), out_sharding=s) + self.assertEqual(out.aval.sharding.spec, P('x', 'y')) + return arr + out + + key = jax.random.key(1) + out = f(jax.device_put(np.arange(8 * 12.).reshape(8, 12), s), key) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y'))) + + def test_auto_axes_no_context_mesh(self): + mesh = jtu.create_mesh((2, 2), ('x', 'y'), axis_types=(AxisType.Explicit,) * 2) + np_inp = np.arange(16.).reshape(8, 2) + s = NamedSharding(mesh, P('x', 'y')) + arr = jax.device_put(np_inp, s) + + @partial(auto_axes, axes='x', + out_sharding=NamedSharding(mesh, P('x', 'y'))) + def h(y): + self.assertEqual(y.aval.sharding.spec, P(None, 'y')) + z = jnp.sin(y) + self.assertEqual(z.aval.sharding.spec, P(None, 'y')) + return z + + out = jax.jit(h)(arr) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y'))) + + out = h(arr) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y'))) + + def test_scan_with_random_key_inside_jit(self): + mesh = jtu.create_mesh((2,), ('x',)) + sharding = NamedSharding(mesh, P(None, 'x')) + + @jax.jit + def scan(xs): + def step(carry, x): + next_carry = jax.vmap(jax.random.fold_in)(carry, x) + next_carry = jnp.where(x % 2 == 0, carry, next_carry) + return next_carry, None + rng = jnp.broadcast_to(jax.random.key(0), xs.shape[1:]) + rng, _ = jax.lax.scan(step, rng, xs) + return rng + + xs = jnp.arange(8).reshape(2, 4) + scan(xs) + + xs = jax.device_put(xs, sharding) + scan(xs) + + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) + def test_select_batch(self, mesh): + y_sharding = NamedSharding(mesh, P('y', None)) + xy_sharding = NamedSharding(mesh, P('x', 'y', None)) + batch_a = jax.device_put(jnp.ones((4, 2, 3), dtype=jnp.float32), xy_sharding) + batch_b = jax.device_put(jnp.ones((4, 2, 2), dtype=jnp.int32), xy_sharding) + + out_s = NamedSharding(mesh, P('x', 'y', None, None)) + + def select(a, b): + c = a.at[b].get(out_sharding=y_sharding) + return c + + @jax.jit + def vmap_select(batch_a, batch_b): + out = jax.vmap(select)(batch_a, batch_b) + self.assertEqual(out.aval.sharding.spec, out_s.spec) + return out + + out = vmap_select(batch_a, batch_b) + self.assertEqual(out.sharding, out_s) + + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) + def test_where_vmap(self, mesh): + xy_sharding = NamedSharding(mesh, P('x', 'y', None)) + batch_a = jax.device_put(jnp.ones((4, 2, 3), dtype=jnp.float32), xy_sharding) + batch_b = 
jax.device_put(jnp.ones((4, 2, 3), dtype=jnp.bool), xy_sharding) + + def where(a, b): + out = jnp.where(b, a, 0) + return out + + @jax.jit + def vmap_where(batch_a, batch_b): + out = jax.vmap(where)(batch_a, batch_b) + self.assertEqual(out.aval.sharding.spec, xy_sharding.spec) + return out + + out = vmap_where(batch_a, batch_b) + self.assertEqual(out.sharding, xy_sharding) + + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) + def test_convert_element_type_vmap(self, mesh): + np_inp = np.arange(16).reshape(8, 2) + arr = jax.device_put(np_inp, P('x', 'y')) + am = mesh.abstract_mesh + + @jax.jit + @jax.vmap + def f(x): + y = lax_internal._convert_element_type( + x, jnp.bfloat16, sharding=NamedSharding(am, P('y'))) + self.assertEqual(y.aval.sharding.spec, P('y')) + return y + + out = f(arr) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y'))) + + @jtu.with_explicit_mesh((2,), ('x',)) + def test_jnp_repeat(self, mesh): + out = jnp.repeat(np.eye(3), np.array((2,2,2,)) - 1, axis=0) + self.assertEqual(out.sharding, NamedSharding(mesh, P(None, None))) + + a = jnp.eye(3) + out = jnp.repeat(a, np.array((2,2,2,)) - 1, axis=0) + self.assertEqual(out.sharding, a.sharding) + + a = jax.device_put(jnp.eye(4), P('x')) + out = jnp.repeat(a, np.array((2,2,2,2)) - 1, axis=0, out_sharding=P('x')) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None))) + + a = jax.device_put(jnp.eye(16).reshape(16, 16), P('x')) + @jax.jit + def f(x): + return jnp.repeat(x, 3, axis=-1) + f(a) + + @jtu.with_explicit_mesh((2,), ('x',)) + def test_scatter_gather(self, mesh): + x = np.random.uniform(size=(mesh.size * 2, 3)) + i = np.random.randint(0, x.shape[1], len(x)) + j = np.random.randint(0, x.shape[1], len(x)) + x = jax.device_put(x, P("x")) + i = jax.device_put(i, P("x")) + j = jax.device_put(j, P("x")) + + @jax.jit + def f1(x, i, j): + x_a_j = x.at[:, j].get(out_sharding=jax.typeof(i).sharding) + return x.at[:, i].set(x_a_j) + f1(x,i,j) # doesn't crash + + @jax.jit + @jax.vmap + def f2(x, i, j): + x_j = x.at[j].get(out_sharding=jax.typeof(x).sharding) + return x.at[i].set(x_j) + f2(x,i,j) # doesn't crash + + @jtu.with_explicit_mesh((4, 2), ('x', 'y')) + def test_conv_general_dilated(self, mesh): + arr = jax.device_put(np.zeros((16, 128, 8)), P('x', 'y')) + + @jax.jit + def f(x): + # Conv1D across sharded y-axis: + out = jax.lax.conv_general_dilated( + x, np.zeros((5, 8, 10)), + window_strides=(1,), padding='SAME', feature_group_count=1, + lhs_dilation=(1,), rhs_dilation=(1,), + dimension_numbers=('NWC', 'WIO', 'NWC')) + self.assertEqual(out.aval.sharding.spec, P('x', 'y', None)) + # Max pooling along sharded y-axis. 
+ out2 = jax.lax.reduce_window( + out, -np.inf, jax.lax.max, (1,2,1), (1,2,1), 'SAME') + self.assertEqual(out2.aval.sharding.spec, P('x', 'y', None)) + return out2 + + out = f(arr) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y', None))) + self.check_wsc_in_lowered(f.lower(arr).as_text()) + + jax.jit(jax.grad(lambda x: f(x).sum()))(arr) # doesn't crash + + with self.assertRaises(core.ShardingTypeError): + arr2 = jax.device_put(np.zeros((16, 128, 8)), P('x', None, 'y')) + f(arr2) + + @parameterized.named_parameters( + ('spec1', P('x', 'y', None)), + ('spec2', P('x', None, 'y')), + ('spec3', P(None, 'x', 'y')), + ('spec4', P(('x', 'y'), None, None)) + ) + @jtu.with_explicit_mesh((4, 2), ('x', 'y')) + def test_reduce_window(self, spec, mesh): + arr = jax.device_put(np.zeros((16, 128, 8)), spec) + + @jax.jit + def f(x): + out = jax.lax.reduce_window( + x, -np.inf, jax.lax.max, (1,2,1), (1,2,1), 'SAME') + self.assertEqual(out.aval.sharding.spec, spec) + return out + + out = f(arr) + self.assertEqual(out.sharding, NamedSharding(mesh, spec)) + self.check_wsc_in_lowered(f.lower(arr).as_text()) + + jax.jit(jax.grad(lambda x: f(x).sum()))(arr) # doesn't crash + + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) + def test_jnp_dot(self, mesh): + np_inp1 = np.arange(16).reshape(8, 2) + np_inp2 = np.arange(16).reshape(2, 8) + arr1 = jax.device_put(np_inp1, P('x', 'y')) + arr2 = jax.device_put(np_inp2, P('x', 'y')) + + @jax.jit + def f(x, y): + out = jnp.dot(x, y, out_sharding=P('x')) + self.assertEqual(out.aval.sharding.spec, P('x', None)) + return out + + out = f(arr1, arr2) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None))) + self.assertArraysEqual(out, np.dot(np_inp1, np_inp2)) + + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) + def test_jnp_ravel(self, mesh): + np_inp = np.arange(16).reshape(8, 2) + arr = jax.device_put(np_inp, P('x', 'y')) + + @jax.jit + def f(x): + out = jnp.ravel(x, out_sharding=P('x')) + self.assertEqual(out.aval.sharding.spec, P('x')) + return out + + out = f(arr) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x'))) + self.assertArraysEqual(out, np.ravel(np_inp)) + + @jtu.with_explicit_mesh((4, 2), ('x', 'y')) + def test_broadcast_forwarding(self, mesh): + arr = jax.device_put(np.zeros(()), P()) + + def f(x): + out = jax.lax.full_like(x, 1.0) + self.assertEqual(jax.typeof(out).sharding, jax.typeof(x).sharding) + return out + + f(arr) # doesn't crash + jax.jit(f)(arr) # doesn't crash + + @config.use_shardy_partitioner(True) + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) + def test_unreduced_basic(self, mesh): + np_inp = np.arange(16.).reshape(8, 2) + x = jax.device_put(np_inp, P('x', 'y')) + y = jax.device_put(np_inp.T, P('y', None)) + a = jax.device_put(np_inp, P('x', 'y')) + b = jax.device_put(np_inp.T, P('y', None)) + + @jax.jit + def f(x, y, a, b): + m1 = jnp.einsum('xy,yz->xz', x, y, out_sharding=P('x', unreduced={'y'})) + self.assertEqual(m1.aval.sharding.spec, P('x', None, unreduced={'y'})) + + m2 = jnp.einsum('xy,yz->xz', a, b, out_sharding=P('x', unreduced={'y'})) + self.assertEqual(m2.aval.sharding.spec, P('x', None, unreduced={'y'})) + + s = m1 + m2 # unreduced + self.assertEqual(s.aval.sharding.spec, P('x', None, unreduced={'y'})) + + out = reshard(s, P('x')) # reduce + self.assertEqual(out.aval.sharding.spec, P('x', None)) + return out + + out = f(x, y, a, b) + self.assertArraysEqual(out, (np_inp @ np_inp.T) + (np_inp @ np_inp.T)) + + traced = f.trace(x, y, a, b) + lowered_text = traced.lower().as_text() + 
self.assertIn('unreduced={"y"}', lowered_text) + self.assertEqual(lowered_text.count('unreduced={"y"}'), 3) + + f_bar = jax.jit(jax.grad(lambda x, y, a, b: f(x, y, a, b).sum(), + argnums=(0, 1, 2, 3))) + f_bar(x, y, a, b) # doesn't crash + + grad_jaxpr = f_bar.trace(x, y, a, b).jaxpr + reshard_eqn = grad_jaxpr.eqns[4].params['jaxpr'].eqns[0] + self.assertEqual(reshard_eqn.params['dst_sharding'].spec.reduced, + frozenset('y')) + self.assertEqual(reshard_eqn.params['dst_sharding'].spec.unreduced, + frozenset()) + + @jtu.with_explicit_mesh((2, 2, 1), ('x', 'y', 'z')) + def test_dot_general_unreduced_error(self, mesh): + np_inp = np.arange(16).reshape(8, 2) + # Case 1 + x = jax.device_put(np_inp, P('x', 'y')) + y = jax.device_put(np_inp.T, P('y', None)) + + @jax.jit + def f(x, y): + return jnp.einsum('xy,yz->xz', x, y, out_sharding=P('x', unreduced={'z'})) + with self.assertRaisesRegex( + core.ShardingTypeError, + "unreduced axes should be equal to the contracting specs"): + f.trace(x, y) + + # Case 2 + x = jax.device_put(np_inp, P('x', 'y')) + y = jax.device_put(np_inp.T, P(None, None)) + @jax.jit + def g(x, y): + return jnp.einsum('xy,yz->xz', x, y, out_sharding=P('x', unreduced={'y'})) + with self.assertRaisesRegex( + core.ShardingTypeError, + "lhs and rhs contracting dims should be sharded identically"): + g.trace(x, y) + + # Case 3 + x = jax.device_put(np_inp, P('x', None)) + y = jax.device_put(np_inp.T, P(None, None)) + + @jax.jit + def h(x, y): + return jnp.einsum('xy,yz->xz', x, y, out_sharding=P('x', unreduced={'y'})) + with self.assertRaisesRegex( + core.ShardingTypeError, + "unreduced axes should be equal to the contracting specs"): + h.trace(x, y) + + @jtu.with_explicit_mesh((2, 2, 1), ('x', 'y', 'z')) + def test_add_unreduced_error(self, mesh): + np_inp = np.arange(16).reshape(8, 2) + x = jax.device_put(np_inp, P('x', 'y')) + y = jax.device_put(np_inp.T, P('y', None)) + a = jax.device_put(np_inp, P('x', 'z')) + b = jax.device_put(np_inp.T, P('z', None)) + + @jax.jit + def f(x, y, a, b): + m1 = jnp.einsum('xy,yz->xz', x, y, out_sharding=P('x', unreduced={'y'})) + m2 = jnp.einsum('xy,yz->xz', a, b, out_sharding=P('x', unreduced={'z'})) + return m1 + m2 + + with self.assertRaisesRegex( + core.ShardingTypeError, + "lhs and rhs to `add` must be unreduced along the same mesh axes"): + f.trace(x, y, a, b) + + @jax.jit + def g(x, y): + m1 = jnp.einsum('xy,yz->xz', x, y, out_sharding=P('x', unreduced={'y'})) + m2 = jnp.einsum('xy,yz->xz', a, b, out_sharding=P('x')) + return m1 + m2 + + with self.assertRaisesRegex( + core.ShardingTypeError, "lhs is unreduced while rhs is not"): + g.trace(x, y) + + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) + def test_eval_shape(self, mesh): + np_inp = np.arange(16).reshape(8, 2) + arr = jax.device_put(np_inp, P('x', 'y')) + + @jax.jit + def f(x): + return x * 2 + + out = jax.eval_shape(f, arr) + self.assertIsInstance(out, jax.ShapeDtypeStruct) + self.assertEqual(out.sharding, + NamedSharding(mesh.abstract_mesh, P('x', 'y'))) + + @jtu.with_explicit_mesh((2,), ('x',)) + def test_he_normal(self, mesh): + init = jax.nn.initializers.he_normal(in_axis=0, out_axis=1) + key = jax.random.key(0) + out = init(key, (8, 2), jnp.float32, out_sharding=P('x')) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None))) + + @jtu.with_explicit_mesh((2,), ('x',)) + def test_nn_uniform(self, mesh): + init = jax.nn.initializers.uniform() + key = jax.random.key(0) + out = init(key, (8, 2), jnp.float32, out_sharding=P('x')) + self.assertEqual(out.sharding, 
NamedSharding(mesh, P('x', None))) + + @jtu.with_explicit_mesh((2,), ('x',)) + def test_nn_constant(self, mesh): + init = jax.nn.initializers.constant(-7) + key = jax.random.key(0) + out = init(key, (8, 2), jnp.float32, out_sharding=P('x')) + self.assertArraysEqual(out, jnp.full((8, 2), -7, dtype=jnp.float32)) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None))) + + @config.numpy_rank_promotion('allow') + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) + def test_lax_map(self, mesh): + def simple_func(w, x): + return jnp.sum(w * x, axis=-1) + + w = jax.device_put(np.arange(4, dtype=np.float32), P('x')) + x = jax.device_put(np.ones((4, 2, 4), dtype=np.float32), + P(None, 'y', None)) + + jax.lax.map(lambda _x: simple_func(w, _x), x) # doesn't crash + + jax.lax.map(lambda _x: simple_func(w, _x), x, batch_size=2) # doesn't crash + + @config.numpy_rank_promotion('allow') + @jtu.with_explicit_mesh((2,), ('x',)) + def test_lax_map_remainder(self, mesh): + def simple_func(w, x): + return jnp.sum(w * x, axis=-1) + + w = jax.device_put(np.arange(4, dtype=np.float32), P()) + x = jax.device_put(np.ones((5, 2, 4), dtype=np.float32), + P(None, 'x', None)) + + jax.lax.map(lambda _x: simple_func(w, _x), x, batch_size=2) # doesn't crash + + @jtu.with_explicit_mesh((2,), ('x',)) + def test_extended_dtypes(self, mesh): + dtype = primal_tangent_dtype(jnp.dtype('int8'), jnp.dtype('bfloat16')) + + @jax.jit + def f(x): + x = jax.lax.convert_element_type(x, dtype) + self.assertEqual(x.aval.sharding.spec, P('x')) + x = jax.lax.convert_element_type(x, 'int8') + self.assertEqual(x.aval.sharding.spec, P('x')) + + x = jax.device_put(jnp.arange(8, dtype='int8'), P('x',)) + f(x) # doesn't crash + @jtu.pytest_mark_if_available('multiaccelerator') class PJitErrorTest(jtu.JaxTestCase): @@ -7437,7 +8223,8 @@ def f(a, b, c): def test_named_sharding_of_none(self): mesh = jtu.create_mesh((2,), ('x',)) - with self.assertRaisesRegex(TypeError, 'Unexpected None'): + with self.assertRaisesRegex( + TypeError, '(Unexpected None|incompatible function arguments)'): jax.NamedSharding(mesh, None) @@ -7590,12 +8377,12 @@ def test_op_sharding_tuple_shardings(self): def test_hlo_sharding_iota_tile_error(self): self.assertRaisesRegex( - xla_extension.XlaRuntimeError, + _jax.XlaRuntimeError, 'INVALID_ARGUMENT: `dims` should not be empty.', lambda: xc.HloSharding.iota_tile(()) ) self.assertRaisesRegex( - xla_extension.XlaRuntimeError, + _jax.XlaRuntimeError, 'INVALID_ARGUMENT: Cannot reshape from', lambda: xc.HloSharding.iota_tile( (2, 2), @@ -7604,7 +8391,7 @@ def test_hlo_sharding_iota_tile_error(self): ), ) self.assertRaisesRegex( - xla_extension.XlaRuntimeError, + _jax.XlaRuntimeError, 'INVALID_ARGUMENT: `reshape_dims` and `transpose_perm` should have the' ' same size', lambda: xc.HloSharding.iota_tile( @@ -7613,7 +8400,7 @@ def test_hlo_sharding_iota_tile_error(self): ), ) self.assertRaisesWithLiteralMatch( - xla_extension.XlaRuntimeError, + _jax.XlaRuntimeError, 'INVALID_ARGUMENT: `subgroup_types`(3) should not have more dimensions ' 'than `dims`(2).', lambda: xc.HloSharding.iota_tile( @@ -7865,12 +8652,6 @@ def f(x, y): @jtu.with_config(jax_use_shardy_partitioner=True) class ShardyTest(jtu.JaxTestCase): - # TODO(bartchr): Once JAX is released with SDY, remove setUp. 
- def setUp(self): - if not dialects.sdy: - raise unittest.SkipTest('Shardy is not available.') - super().setUp() - def test_lowering_input_output_sharding(self): mesh = jtu.create_mesh((4, 2), ('x', 'y')) np_inp = np.arange(16).reshape(8, 2) @@ -7936,26 +8717,26 @@ def f(x, y): self.assertIn('sdy.mesh @mesh = <["x"=8]>', lowered_str) def test_array_sharding_repr_with_priority(self): - sharding = sharding_impls.SdyArraySharding( + sharding = sharding_impls.SdyArray( mesh_shape=(('data', 4), ('model', 8), ('expert', 2)), - dimension_shardings=[ - sharding_impls.SdyDimSharding(axes=['data', 'expert'], is_closed=True), - sharding_impls.SdyDimSharding(axes=['model'], is_closed=False, priority=2)]) - self.assertEqual(repr(sharding), "SdyArraySharding([{'data', 'expert'}, {'model', ?}p2])") + dim_shardings=[ + sharding_impls.SdyDim(axes=['data', 'expert'], is_open=False), + sharding_impls.SdyDim(axes=['model'], is_open=True, priority=2)]) + self.assertEqual(repr(sharding), "SdyArray([{'data', 'expert'}, {'model', ?}p2])") def test_array_sharding_repr_with_logical_ids(self): abstract_mesh = jax.sharding.AbstractMesh((4, 8, 2), ('x', 'y', 'z')) ns = NamedSharding(abstract_mesh, P(('x', 'y'), 'z', P.UNCONSTRAINED, None), _logical_device_ids=[4, 5, 6, 7, 0, 1, 2, 3]) self.assertEqual(repr(ns._to_sdy_sharding(4)), - "SdyArraySharding([{'x', 'y'}, {'z'}, {?}, {}], " + "SdyArray([{'x', 'y'}, {'z'}, {?}, {}], " "device_ids=[4, 5, 6, 7, 0, 1, 2, 3])") def test_dimension_sharding_repr(self): - dim_sharding = sharding_impls.SdyDimSharding( - axes=['data', 'model'], is_closed=False, priority=2) + dim_sharding = sharding_impls.SdyDim( + axes=['data', 'model'], is_open=True, priority=2) self.assertEqual(repr(dim_sharding), - "SdyDimSharding({'data', 'model', ?}p2)") + "SdyDim({'data', 'model', ?}p2)") def test_tensor_dialect(self): # While this doesn't emit any `mlir::TensorDialect` ops, some pass in the @@ -8020,5 +8801,54 @@ def f(x, y, static_arg0=1, static_arg1=2): self.assertArraysEqual(result, expected_result) self.assertEqual(result.sharding, NamedSharding(mesh, P(None, None, 'x'))) + def test_custom_partition_shardy_migration(self): + if jtu.is_cloud_tpu(): + raise unittest.SkipTest("Custom partitioning is not supported on libtpu.") + + def partition(mesh, arg_shapes, result_shape): + def lower_fn(x): + return x + + return ( + mesh, + lower_fn, + arg_shapes[0].sharding, + (arg_shapes[0].sharding,), + ) + + def infer_sharding_from_operands(mesh, arg_shapes, result_shape): + return arg_shapes[0].sharding + + def propagate_user_sharding(mesh, user_shape): + return user_shape.sharding + + @custom_partitioning + def f(x): + return x + + f.def_partition( + infer_sharding_from_operands=infer_sharding_from_operands, + partition=partition, + propagate_user_sharding=propagate_user_sharding, + ) + + mesh = jtu.create_mesh((4, 2), ('x', 'y')) + x = jax.device_put(np.arange(32 * 16).reshape(32, 16), + NamedSharding(mesh, P(None, 'x'))) + with self.assertRaisesRegex( + NotImplementedError, 'provide sharding_rule to migrate to Shardy'): + jax.jit(f)(x) + + def test_reshard_empty_mesh_error(self): + arr = jax.device_put(np.arange(8), jax.devices()[0]) + with self.assertRaisesRegex(ValueError, "nonempty mesh"): + reshard(arr, NamedSharding(mesh_lib.empty_abstract_mesh, P(None))) + + def test_reshard_none_sharding_error(self): + arr = jax.device_put(np.arange(8), jax.devices()[0]) + with self.assertRaisesRegex(ValueError, "non-None"): + reshard(arr, None) + + if __name__ == '__main__': 
absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/pmap_test.py b/tests/pmap_test.py index af2d03e2945d..84136e48ecb7 100644 --- a/tests/pmap_test.py +++ b/tests/pmap_test.py @@ -49,7 +49,7 @@ from jax._src.internal_test_util import lax_test_util from jax._src.interpreters import pxla from jax._src.lax import parallel -from jax._src.lib import xla_extension +from jax._src.lib import _jax from jax._src.util import safe_map, safe_zip config.parse_flags_with_absl() @@ -318,12 +318,12 @@ def test_jit_lower_compile_with_compiler_options_invalid(self): lowered = f.lower(x) self.assertRaisesRegex( - xla_extension.XlaRuntimeError, "No such compile option: 'invalid_key'", + _jax.XlaRuntimeError, "No such compile option: 'invalid_key'", lambda: lowered.compile( compiler_options={"invalid_key": "invalid_value"})) self.assertRaisesRegex( - xla_extension.XlaRuntimeError, "is not a valid bool value.", + _jax.XlaRuntimeError, "is not a valid bool value.", lambda: lowered.compile( compiler_options={"xla_embed_ir_in_executable": "invalid_value"})) @@ -356,7 +356,7 @@ def test_jit_lower_compile_with_compiler_options_multiple(self): # We should still error on invalid options after some valid compiles self.assertRaisesRegex( - xla_extension.XlaRuntimeError, "No such compile option: 'invalid_key'", + _jax.XlaRuntimeError, "No such compile option: 'invalid_key'", lambda: lowered.compile( compiler_options={"invalid_key": "invalid_value"})) @@ -499,7 +499,7 @@ def testReduceScatterReplicaGroupsTiled(self): def testTrees(self): ptranspose = lambda x, axis_name: lax.all_to_all(x, axis_name, 0, 0) def protate(x, axis_name): - n = lax.psum(1, axis_name) + n = lax.axis_size(axis_name) return lax.ppermute(x, axis_name, [(i, (i + 1) % n) for i in range(n)]) tree_f = lambda f: partial(jax.tree.map, f) @@ -1395,7 +1395,7 @@ def testNestedPmapConstantError(self): def testCollectiveConstant(self): device_count = jax.device_count() - f = self.pmap(lambda x: lax.psum(1, 'i'), 'i') + f = self.pmap(lambda x: lax.axis_size('i'), 'i') x = jnp.arange(device_count) ans = f(x) expected = np.repeat(device_count, device_count) @@ -1408,9 +1408,9 @@ def testCollectiveConstantNested(self): def f(x): @partial(self.pmap, axis_name='j') def g(y): - a = lax.psum(1, 'i') - b = lax.psum(1, 'j') - c = lax.psum(1, ('i', 'j')) + a = lax.axis_size('i') + b = lax.axis_size('j') + c = lax.axis_size(('i', 'j')) return a, b, c return g(x) @@ -3189,7 +3189,7 @@ class EagerPmapMixin: def setUp(self): super().setUp() stack = contextlib.ExitStack() - stack.enter_context(jtu.thread_local_config_context(jax_disable_jit=True, jax_eager_pmap=True)) + stack.enter_context(jtu.thread_local_config_context(jax_disable_jit=True)) stack.enter_context(jtu.ignore_warning( message="Some donated buffers were not usable", category=UserWarning)) self.addCleanup(stack.close) diff --git a/tests/pretty_printer_test.py b/tests/pretty_printer_test.py index d87708c9d91c..b4363be1c965 100644 --- a/tests/pretty_printer_test.py +++ b/tests/pretty_printer_test.py @@ -13,24 +13,90 @@ # limitations under the License. 
from absl.testing import absltest - -from jax._src import test_util as jtu from jax._src import pretty_printer as pp +from jax._src import test_util as jtu class PrettyPrinterTest(jtu.JaxTestCase): def testSourceMap(self): doc = pp.concat([ - pp.text("abc"), pp.source_map(pp.text("def"), 101), - pp.source_map(pp.concat([pp.text("gh"), pp.brk(""), pp.text("ijkl")]), 77), - pp.text("mn"), + pp.text("abc"), + pp.source_map(pp.text("def"), 101), + pp.source_map( + pp.concat([pp.text("gh"), pp.brk(""), pp.text("ijkl")]), 77 + ), + pp.text("mn"), ]) source_map = [] out = doc.format(width=8, source_map=source_map) self.assertEqual(out, "abcdefgh\nijklmn") self.assertEqual(source_map, [[(3, 6, 101), (6, 8, 77)], [(0, 4, 77)]]) + def testBasics(self): + self.assertEqual(pp.nil().format(), "") + self.assertEqual(pp.text("").format(), "") + self.assertEqual(pp.text("testing").format(), "testing") + self.assertEqual(pp.text("\n").format(), "\n") + self.assertEqual(pp.brk().format(), "\n") + # Group that fits will use the space from brk() + self.assertEqual(pp.group(pp.brk()).format(), " ") + # Group that doesn't fit (due to width=0) will use newline + self.assertEqual(pp.group(pp.brk()).format(width=0), "\n") + + # Custom break text + self.assertEqual(pp.group(pp.brk("-")).format(), "-") + self.assertEqual(pp.group(pp.brk("-")).format(width=0), "\n") + + # Concatenation + self.assertEqual((pp.text("a") + pp.text("b")).format(), "ab") + self.assertEqual(pp.concat([pp.text("a"), pp.text("b c")]).format(), "ab c") + + x = pp.text("x") + y = pp.text("y") + z = pp.text("z") + + # Join + # Join with a break that becomes a space when fitting + join_doc_space = pp.join( + pp.text(",") + pp.brk(), [pp.text("a"), pp.text("b"), pp.text("c")] + ) + self.assertEqual(pp.group(join_doc_space).format(), "a, b, c") + self.assertEqual(pp.group(join_doc_space).format(width=5), "a,\nb,\nc") + self.assertEqual(pp.join(pp.text(","), [x, y, z]).format(), "x,y,z") + + j = pp.join( + pp.brk(), [pp.text("xx"), pp.text("yy"), pp.text("zz"), pp.text("ww")] + ) + self.assertEqual(pp.group(j).format(width=3), "xx\nyy\nzz\nww") + self.assertEqual(pp.group(j).format(width=80), "xx yy zz ww") + + bx = pp.brk() + x + bxbx = bx + bx + bx4 = bxbx + bxbx + + # Horizontal-like (fits) + self.assertEqual(pp.group(bx).format(), " x") + self.assertEqual(pp.group(bxbx).format(), " x x") + self.assertEqual(pp.group(bx4).format(), " x x x x") + + # Vertical-like (forced by width) + self.assertEqual(pp.group(bx).format(width=0), "\nx") + self.assertEqual(pp.group(bxbx).format(width=0), "\nx\nx") + self.assertEqual(pp.group(bx4).format(width=0), "\nx\nx\nx\nx") + self.assertEqual(pp.group(bxbx).format(width=3), "\nx\nx") + + # Nesting + xbybz = x + pp.brk() + y + pp.brk() + z + self.assertEqual(pp.nest(2, pp.group(bx)).format(), " x") # Stays flat + self.assertEqual(pp.nest(2, pp.group(bxbx)).format(), " x x") # Stays flat + self.assertEqual(pp.nest(2, pp.group(bx)).format(width=0), "\n x") + self.assertEqual( + pp.nest(2, pp.nest(2, pp.group(bx))).format(width=0), "\n x" + ) + self.assertEqual(pp.nest(2, pp.group(xbybz)).format(width=0), "x\n y\n z") + self.assertEqual(pp.nest(2, pp.group(bxbx)).format(width=0), "\n x\n x") + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/profiler_test.py b/tests/profiler_test.py index 215e363e446d..82d3ec8437d7 100644 --- a/tests/profiler_test.py +++ b/tests/profiler_test.py @@ -44,11 +44,11 @@ profiler_client = None tf_profiler = None -TBP_ENABLED = False 
+XPROF_ENABLED = False try: - import tensorboard_plugin_profile - del tensorboard_plugin_profile - TBP_ENABLED = True + import xprof + del xprof + XPROF_ENABLED = True except ImportError: pass @@ -61,6 +61,12 @@ class ProfilerTest(unittest.TestCase): # check functional correctness. def setUp(self): + if sys.version_info >= (3, 14) and jtu.TEST_NUM_THREADS.value > 1: + # TODO(phawkins): try reenabling these after + # https://github.com/python/cpython/issues/132817 is fixed. Simply + # installing the profiler hook is unsafe if there are multiple threads. + self.skipTest("Profiler tests are not thread-safe under Python 3.14") + super().setUp() self.worker_start = threading.Event() self.profile_done = False @@ -107,6 +113,31 @@ def testProgrammaticProfiling(self): self.assertIn(b"/device:TPU", proto) self.assertIn(b"pxla.py", proto) + def testProgrammaticProfilingWithOptions(self): + with tempfile.TemporaryDirectory() as tmpdir: + try: + options = jax.profiler.ProfileOptions() + options.python_tracer_level = 0 + jax.profiler.start_trace(tmpdir, profiler_options=options) + jax.pmap(lambda x: jax.lax.psum(x + 1, "i"), axis_name="i")( + jnp.ones(jax.local_device_count()) + ) + finally: + jax.profiler.stop_trace() + + proto_path = glob.glob( + os.path.join(tmpdir, "**/*.xplane.pb"), recursive=True + ) + self.assertEqual(len(proto_path), 1) + with open(proto_path[0], "rb") as f: + proto = f.read() + # Verify that the serialized proto contains host and device traces, and + # does not contain Python traces. + self.assertIn(b"/host:CPU", proto) + if jtu.test_device_matches(["tpu"]): + self.assertIn(b"/device:TPU", proto) + self.assertNotIn(b"pxla.py", proto) + def testProgrammaticProfilingPathlib(self): with tempfile.TemporaryDirectory() as tmpdir_string: tmpdir = pathlib.Path(tmpdir_string) @@ -127,6 +158,29 @@ def testProgrammaticProfilingPathlib(self): self.assertIn(b"/device:TPU", proto) self.assertIn(b"pxla.py", proto) + def testProgrammaticProfilingWithOptionsPathlib(self): + with tempfile.TemporaryDirectory() as tmpdir_string: + tmpdir = pathlib.Path(tmpdir_string) + try: + options = jax.profiler.ProfileOptions() + options.advanced_configuration = {"tpu_trace_mode": "TRACE_ONLY_HOST"} + jax.profiler.start_trace(tmpdir, profiler_options=options) + jax.pmap(lambda x: jax.lax.psum(x + 1, "i"), axis_name="i")( + jnp.ones(jax.local_device_count()) + ) + finally: + jax.profiler.stop_trace() + + proto_path = tuple(tmpdir.rglob("*.xplane.pb")) + self.assertEqual(len(proto_path), 1) + proto = proto_path[0].read_bytes() + # Verify that the serialized proto contains host traces and does not + # contain TPU device traces. + self.assertIn(b"/host:CPU", proto) + if jtu.test_device_matches(["tpu"]): + self.assertNotIn(b"/device:TPU", proto) + self.assertIn(b"pxla.py", proto) + def testProfilerGetFDOProfile(self): # Tests stop_and_get_fod_profile could run. 
try: @@ -176,7 +230,7 @@ def testProgrammaticProfilingContextManager(self): def testProgrammaticGpuCuptiTracing(self): @jit def xy_plus_z(x, y, z): - return jnp.float32(jax.lax.batch_matmul(jnp.bfloat16(x), y)) + z + return jnp.float32(jax.lax.batch_matmul(jnp.bfloat16(x), y)) + z k = jax.random.key(0) s = 1, 16, 16 jax.devices() @@ -290,7 +344,7 @@ def on_profile(port, logdir, worker_start): self._check_xspace_pb_exist(logdir) @unittest.skipIf( - not (portpicker and profiler_client and tf_profiler and TBP_ENABLED), + not (portpicker and profiler_client and tf_profiler and XPROF_ENABLED), "Test requires tensorflow.profiler, portpicker and " "tensorboard_profile_plugin") def test_remote_profiler(self): diff --git a/tests/python_callback_test.py b/tests/python_callback_test.py index 05b4c8d7c0ff..eef45b3b412b 100644 --- a/tests/python_callback_test.py +++ b/tests/python_callback_test.py @@ -30,7 +30,7 @@ from jax._src import util from jax.experimental import io_callback from jax.experimental import pjit -from jax.experimental.shard_map import shard_map +from jax._src.shard_map import shard_map import jax.numpy as jnp from jax.sharding import Mesh import numpy as np @@ -585,6 +585,78 @@ def fun(x): self.assertAllClose(2 * x, fun(x)) self.assertEqual(count(), 1) + @parameterized.parameters("int2", "int4", "uint2", "uint4", "float4_e2m1fn") + def test_subbyte_operands(self, dtype: str): + if "2" in dtype and jtu.test_device_matches(["tpu"]): + self.skipTest( + "TODO(dsuo): TPU callbacks send SIGABRT for int2, uint2, and" + " float4_e2m1fn." + ) + def get(x): + return x + def f(x): + y = jax.pure_callback( + get, + jax.ShapeDtypeStruct((8,), dtype=dtype), + x, + ) + return y + x = np.arange(8, dtype=dtype) + np.testing.assert_array_equal(jax.jit(f)(x), np.arange(8, dtype=dtype)) + + @parameterized.parameters("int2", "int4", "uint2", "uint4", "float4_e2m1fn") + def test_subbyte_results(self, dtype: str): + if "2" in dtype and jtu.test_device_matches(["tpu"]): + self.skipTest( + "TODO(dsuo): TPU callbacks send SIGABRT for int2, uint2, and" + " float4_e2m1fn." + ) + def get(): + return np.arange(8, dtype=dtype) + + def f(): + y = jax.pure_callback( + get, + jax.ShapeDtypeStruct((8,), dtype) + ) + return y + + np.testing.assert_array_equal(jax.jit(f)(), np.arange(8, dtype=dtype)) + + @parameterized.parameters("int2", "int4", "uint2", "uint4", "float4_e2m1fn") + def test_non_default_stride_subbyte_results(self, dtype: str): + if "2" in dtype and jtu.test_device_matches(["tpu"]): + self.skipTest( + "TODO(dsuo): TPU callbacks send SIGABRT for int2, uint2, and" + " float4_e2m1fn." + ) + x = jnp.arange(24, dtype=dtype).reshape(2, 3, 4) + def callback(x): + return np.asfortranarray(x) + + @jax.jit + def f(x): + return jax.pure_callback( + callback, jax.ShapeDtypeStruct(x.shape, x.dtype), x + ) + + result = f(x) + np.testing.assert_array_equal(x, result) + + def test_non_default_stride(self): + x = jnp.arange(24, dtype=jnp.float32).reshape(2, 3, 4) + def callback(x): + return np.asfortranarray(x) + + @jax.jit + def f(x): + return jax.pure_callback( + callback, jax.ShapeDtypeStruct(x.shape, x.dtype), x + ) + + result = f(x) + np.testing.assert_array_equal(x, result) + class PureCallbackTest(jtu.JaxTestCase): @@ -787,7 +859,7 @@ def sin_jvp(xs, ts): def f(x): return sin(x) out = f(2.) 
- np.testing.assert_allclose(out, jnp.cos(2.)) + np.testing.assert_allclose(out, jnp.cos(2.), atol=1e-7) def test_callback_inside_of_cond(self): @@ -990,26 +1062,11 @@ def f(x): def test_vmap_method_raise(self): @jax.vmap def f(x): - # Setting vectorized to None disables the current default behavior of - # falling back on sequential. - return jax.pure_callback(np.sin, x, x, vectorized=None) + return jax.pure_callback(np.sin, x, x) with self.assertRaisesRegex(NotImplementedError, "vmap is only supported"): f(jnp.arange(4.)) - def test_deprecated_vectorized(self): - def f(x, **kwargs): - return jax.pure_callback(np.sin, x, x, **kwargs) - - with self.assertWarnsRegex(DeprecationWarning, "The default behavior"): - jax.vmap(f)(jnp.arange(4.0)) - - with self.assertWarnsRegex(DeprecationWarning, "The vectorized argument"): - f(jnp.arange(4.0), vectorized=True) - - with self.assertWarnsRegex(DeprecationWarning, "The vectorized argument"): - f(jnp.arange(4.0), vectorized=False) - def test_vmap_method_expand_dims(self): def callback(x, y): self.assertTupleEqual(x.shape, (4,)) @@ -1057,20 +1114,6 @@ def fun(x): result += fun(jnp.ones((500, 500), jnp.complex64))[1] jax.block_until_ready(result) # doesn't deadlock - def test_non_default_stride(self): - x = jnp.arange(24, dtype=jnp.float32).reshape(2, 3, 4) - def callback(x): - return np.asfortranarray(x) - - @jax.jit - def f(x): - return jax.pure_callback( - callback, jax.ShapeDtypeStruct(x.shape, x.dtype), x - ) - - result = f(x) - np.testing.assert_array_equal(x, result) - class IOCallbackTest(jtu.JaxTestCase): @@ -1313,11 +1356,18 @@ def f_base(i, x): jax.effects_barrier() self.assertEqual(_collected, expected) - def test_can_shard_io_callback_manually(self): - if config.use_shardy_partitioner.value: - self.skipTest("TODO(b/384938613): Failing under shardy.") + @parameterized.named_parameters( + dict(testcase_name='multi_device', + single_device=False), + dict(testcase_name='single_device', + single_device=True) + ) + def test_can_shard_io_callback_manually(self, single_device: bool): - mesh = Mesh(np.array(jax.devices()), axis_names=('x',)) + devices = jax.devices() + if single_device: + devices = devices[:1] + mesh = Mesh(np.array(devices), axis_names=('x',)) spec = jax.sharding.PartitionSpec('x') sharding = jax.sharding.NamedSharding(mesh, spec) diff --git a/tests/pytorch_interoperability_test.py b/tests/pytorch_interoperability_test.py index e41c4329b95b..4b8f58cd6982 100644 --- a/tests/pytorch_interoperability_test.py +++ b/tests/pytorch_interoperability_test.py @@ -77,8 +77,7 @@ def testJaxToTorch(self, shape, dtype): rng = jtu.rand_default(self.rng()) np = rng(shape, dtype) x = jnp.array(np) - dlpack = jax.dlpack.to_dlpack(x) - y = torch.utils.dlpack.from_dlpack(dlpack) + y = torch.utils.dlpack.from_dlpack(x) if dtype == jnp.bfloat16: # .numpy() doesn't work on Torch bfloat16 tensors. 
self.assertAllClose(np, diff --git a/tests/qdwh_test.py b/tests/qdwh_test.py index 91cc3a51f876..955e23374fee 100644 --- a/tests/qdwh_test.py +++ b/tests/qdwh_test.py @@ -18,7 +18,7 @@ import jax from jax._src import config from jax._src import test_util as jtu -from jax._src.lax import qdwh +from jax._src.tpu.linalg import qdwh import jax.numpy as jnp import numpy as np diff --git a/tests/ragged_collective_test.py b/tests/ragged_collective_test.py index 844892adc052..8b94b862419c 100644 --- a/tests/ragged_collective_test.py +++ b/tests/ragged_collective_test.py @@ -21,12 +21,13 @@ import jax import jax.ad_checkpoint from jax import lax +from jax import vmap from jax.sharding import PartitionSpec as P from jax._src import config from jax._src import test_util as jtu import jax.numpy as jnp -from jax.experimental.shard_map import shard_map +from jax._src.shard_map import shard_map config.parse_flags_with_absl() @@ -89,7 +90,7 @@ def test_ragged_all_to_all(self, axis_name, mesh_axes): P(axis_name, None), ), out_specs=P(axis_name), - check_rep=False, + check_vma=False, ) def fwd( operand, output, input_offsets, send_sizes, output_offsets, recv_sizes @@ -175,7 +176,7 @@ def test_ragged_all_to_all_grad(self, axis_name, mesh_axes): P(axis_name, None), ), out_specs=P(axis_name), - check_rep=False, + check_vma=False, ) def fwd( operand, output, input_offsets, send_sizes, output_offsets, recv_sizes @@ -256,7 +257,7 @@ def test_ragged_all_to_all_axis_index_groups(self, axis_name, mesh_axes): P(axis_name, None), ), out_specs=P(axis_name), - check_rep=False, + check_vma=False, ) def fwd( operand, output, input_offsets, send_sizes, output_offsets, recv_sizes @@ -345,7 +346,7 @@ def test_ragged_all_to_all_degenerate_groups(self, axis_name, mesh_axes): P(axis_name, None), ), out_specs=P(axis_name), - check_rep=False, + check_vma=False, ) def fwd( operand, output, input_offsets, send_sizes, output_offsets, recv_sizes @@ -381,6 +382,257 @@ def fwd( c, jnp.array([[0, 0, 1, 0], [0, 2, 3, 4]], dtype=jnp.int32) ) + def test_ragged_all_to_all_vmap_multi_dim_operand(self): + device_type = jax.devices()[0].platform + if device_type == 'tpu' and jtu.get_tpu_version() < 4: + raise unittest.SkipTest( + 'UNSUPPORTED: HLO opcode `ragged-all-to-all` is not supported by TPU' + f' v{jtu.get_tpu_version()}' + ) + + axis_name = 'x' + mesh_axes = dict(x=2) + mesh = jtu.create_mesh(tuple(mesh_axes.values()), tuple(mesh_axes.keys())) + data_sharding = P(axis_name, None, None) + operand_data = jnp.zeros((2, 2, 3), dtype=jnp.int32) + output_data = jnp.zeros((2, 2, 4), dtype=jnp.int32) + input_offsets_data = jnp.zeros((2, 2, 2), dtype=jnp.int32) + send_sizes_data = jnp.zeros((2, 2, 2), dtype=jnp.int32) + output_offsets_data = jnp.zeros((2, 2, 2), dtype=jnp.int32) + recv_sizes_data = jnp.zeros((2, 2, 2), dtype=jnp.int32) + + operand = jax.device_put(operand_data, jax.sharding.NamedSharding(mesh, data_sharding)) + output = jax.device_put(output_data, jax.sharding.NamedSharding(mesh, data_sharding)) + input_offsets = jax.device_put(input_offsets_data, jax.sharding.NamedSharding(mesh, data_sharding)) + send_sizes = jax.device_put(send_sizes_data, jax.sharding.NamedSharding(mesh, data_sharding)) + output_offsets = jax.device_put(output_offsets_data, jax.sharding.NamedSharding(mesh, data_sharding)) + recv_sizes = jax.device_put(recv_sizes_data, jax.sharding.NamedSharding(mesh, data_sharding)) + + @partial( + shard_map, + mesh=mesh, + in_specs=( + P(axis_name, None), + P(axis_name, None), + P(axis_name, None), + P(axis_name, None), + 
P(axis_name, None), + P(axis_name, None), + ), + out_specs=P(axis_name), + check_vma=False, + ) + def fwd( + operand, output, input_offsets, send_sizes, output_offsets, recv_sizes + ): + return lax.ragged_all_to_all( + operand=operand.reshape(operand.shape[1:]), + output=output.reshape(output.shape[1:]), + input_offsets=input_offsets.reshape(input_offsets.shape[1:]), + send_sizes=send_sizes.reshape(send_sizes.shape[1:]), + output_offsets=output_offsets.reshape(output_offsets.shape[1:]), + recv_sizes=recv_sizes.reshape(recv_sizes.shape[1:]), + axis_name=axis_name, + ) + + res = vmap( + fwd, in_axes=0, out_axes=0, axis_name='x' + )( + operand, output, input_offsets, send_sizes, output_offsets, recv_sizes + ) + self.assertEqual(res.shape, (2, 2, 4)) + + @parameterized.named_parameters( + dict( + testcase_name='_batch_0_data_shard_axis_0_input_0', + axis_name='x', + vmap_axis_name='y', + mesh_axes=dict(x=2, y=2), + vmap_batch_axis=0, + data_shard_axis=0, + input_config=0, + ), + dict( + testcase_name='_batch_0_data_shard_axis_1_input_0', + axis_name='x', + vmap_axis_name='y', + mesh_axes=dict(x=2, y=2), + vmap_batch_axis=0, + data_shard_axis=1, + input_config=0, + ), + dict( + testcase_name='_batch_1_data_shard_axis_0_input_1', + axis_name='x', + vmap_axis_name='y', + mesh_axes=dict(x=2, y=2), + vmap_batch_axis=1, + data_shard_axis=0, + input_config=1, + ), + dict( + testcase_name='_batch_1_data_shard_axis_1_input_1', + axis_name='x', + vmap_axis_name='y', + mesh_axes=dict(x=2, y=2), + vmap_batch_axis=1, + data_shard_axis=1, + input_config=1, + ), + ) + def test_ragged_all_to_all_vmap( + self, + axis_name, + vmap_axis_name, + mesh_axes, + vmap_batch_axis, + data_shard_axis, + input_config, + ): + device_type = jax.devices()[0].platform + if device_type == 'tpu' and jtu.get_tpu_version() < 4: + raise unittest.SkipTest( + 'UNSUPPORTED: HLO opcode `ragged-all-to-all` is not supported by TPU' + f' v{jtu.get_tpu_version()}' + ) + mesh = jtu.create_mesh(tuple(mesh_axes.values()), tuple(mesh_axes.keys())) + + def get_data_sharding(axis): + if axis == 0: + return P(axis_name, None, None) + elif axis == 1: + return P(None, axis_name, None) + else: + raise ValueError("Invalid data_shard_axis") + + data_sharding = get_data_sharding(data_shard_axis) + + if input_config == 0: + operand_data = jnp.array([[[1, 2, 3], [4, 5, 6]], + [[1, 2, 3], [4, 5, 6]]], dtype=jnp.int32) + send_sizes_data = jnp.array([[[1, 2], [1, 1]], + [[1, 2], [1, 1]]], dtype=jnp.int32) + output_offsets_data = jnp.array([[[0, 0], [1, 2]], + [[0, 0], [1, 2]]], dtype=jnp.int32) + recv_sizes_data = jnp.array([[[1, 1], [2, 1]], + [[1, 1], [2, 1]]], dtype=jnp.int32) + elif input_config == 1: + operand_data = jnp.array([[[1, 2, 3], [1, 2, 3]], + [[4, 5, 6], [4, 5, 6]]], dtype=jnp.int32) + send_sizes_data = jnp.array([[[1, 2], [1, 2]], + [[1, 1], [1, 1]]], dtype=jnp.int32) + output_offsets_data = jnp.array([[[0, 0], [0, 0]], + [[1, 2], [1, 2]]], dtype=jnp.int32) + recv_sizes_data = jnp.array([[[1, 1], [1, 1]], + [[2, 1], [2, 1]]], dtype=jnp.int32) + else: + raise ValueError("Invalid input config") + + output_data = jnp.zeros((2, 2, 4), dtype=jnp.int32) + input_offsets_data = jnp.array([[[0, 1], [0, 1]], + [[0, 1], [0, 1]]], dtype=jnp.int32) + + operand = jax.device_put(operand_data, jax.sharding.NamedSharding(mesh, data_sharding)) + output = jax.device_put(output_data, jax.sharding.NamedSharding(mesh, data_sharding)) + input_offsets = jax.device_put(input_offsets_data, jax.sharding.NamedSharding(mesh, data_sharding)) + send_sizes = 
jax.device_put(send_sizes_data, jax.sharding.NamedSharding(mesh, data_sharding)) + output_offsets = jax.device_put(output_offsets_data, jax.sharding.NamedSharding(mesh, data_sharding)) + recv_sizes = jax.device_put(recv_sizes_data, jax.sharding.NamedSharding(mesh, data_sharding)) + + @partial( + shard_map, + mesh=mesh, + in_specs=( + P(axis_name, None), + P(axis_name, None), + P(axis_name, None), + P(axis_name, None), + P(axis_name, None), + P(axis_name, None), + ), + out_specs=P(axis_name), + check_vma=False, + ) + def fwd( + operand, output, input_offsets, send_sizes, output_offsets, recv_sizes + ): + return lax.ragged_all_to_all( + operand=operand.reshape(operand.shape[1:]), + output=output.reshape(output.shape[1:]), + input_offsets=input_offsets.reshape(input_offsets.shape[1:]), + send_sizes=send_sizes.reshape(send_sizes.shape[1:]), + output_offsets=output_offsets.reshape(output_offsets.shape[1:]), + recv_sizes=recv_sizes.reshape(recv_sizes.shape[1:]), + axis_name=axis_name, + ) + + res = vmap( + fwd, in_axes=vmap_batch_axis, out_axes=0, axis_name=vmap_axis_name + )( + operand, output, input_offsets, send_sizes, output_offsets, recv_sizes + ) + expected_res = jnp.array([[[1, 4, 0, 0], [2, 3, 5, 0]], + [[1, 4, 0, 0], [2, 3, 5, 0]]], dtype=jnp.int32) + self.assertAllClose(res, expected_res) + + def test_ragged_all_to_all_vmap_unsupported_axis_index_groups(self): + device_type = jax.devices()[0].platform + if device_type == 'tpu' and jtu.get_tpu_version() < 4: + raise unittest.SkipTest( + 'UNSUPPORTED: HLO opcode `ragged-all-to-all` is not supported by TPU' + f' v{jtu.get_tpu_version()}' + ) + + axis_name = 'x' + mesh_axes = dict(x=2) + mesh = jtu.create_mesh(tuple(mesh_axes.values()), tuple(mesh_axes.keys())) + data_sharding = P(axis_name, None, None) + operand_data = jnp.zeros((2, 2, 3), dtype=jnp.int32) + output_data = jnp.zeros((2, 2, 4), dtype=jnp.int32) + input_offsets_data = jnp.zeros((2, 2, 2), dtype=jnp.int32) + send_sizes_data = jnp.zeros((2, 2, 2), dtype=jnp.int32) + output_offsets_data = jnp.zeros((2, 2, 2), dtype=jnp.int32) + recv_sizes_data = jnp.zeros((2, 2, 2), dtype=jnp.int32) + + operand = jax.device_put(operand_data, jax.sharding.NamedSharding(mesh, data_sharding)) + output = jax.device_put(output_data, jax.sharding.NamedSharding(mesh, data_sharding)) + input_offsets = jax.device_put(input_offsets_data, jax.sharding.NamedSharding(mesh, data_sharding)) + send_sizes = jax.device_put(send_sizes_data, jax.sharding.NamedSharding(mesh, data_sharding)) + output_offsets = jax.device_put(output_offsets_data, jax.sharding.NamedSharding(mesh, data_sharding)) + recv_sizes = jax.device_put(recv_sizes_data, jax.sharding.NamedSharding(mesh, data_sharding)) + + @partial( + shard_map, + mesh=mesh, + in_specs=( + P(axis_name, None), + P(axis_name, None), + P(axis_name, None), + P(axis_name, None), + P(axis_name, None), + P(axis_name, None), + ), + out_specs=P(axis_name), + check_vma=False, + ) + def fwd( + operand, output, input_offsets, send_sizes, output_offsets, recv_sizes + ): + return lax.ragged_all_to_all( + operand=operand.reshape(operand.shape[1:]), + output=output.reshape(output.shape[1:]), + input_offsets=input_offsets.reshape(input_offsets.shape[1:]), + send_sizes=send_sizes.reshape(send_sizes.shape[1:]), + output_offsets=output_offsets.reshape(output_offsets.shape[1:]), + recv_sizes=recv_sizes.reshape(recv_sizes.shape[1:]), + axis_name=axis_name, + axis_index_groups=[[0, 1]], + ) + + with self.assertRaisesWithLiteralMatch( + NotImplementedError, 'Please open a feature 
request!'): + vmap(fwd, in_axes=0, out_axes=0, axis_name='b')(operand, output, input_offsets, send_sizes, output_offsets, recv_sizes) + def test_ragged_all_to_all_errors(self): operand = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], dtype=jnp.float32) output = jnp.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], dtype=jnp.float32) diff --git a/tests/random_lax_test.py b/tests/random_lax_test.py index b6f8b4f132bf..9fe4d2ecbda3 100644 --- a/tests/random_lax_test.py +++ b/tests/random_lax_test.py @@ -46,7 +46,7 @@ @jtu.with_config(jax_legacy_prng_key='allow') -class LaxRandomTest(jtu.JaxTestCase): +class RandomTestBase(jtu.JaxTestCase): def _CheckCollisions(self, samples, nbits): fail_prob = 0.01 # conservative bound on statistical fail prob by Chebyshev @@ -110,6 +110,11 @@ def _CheckChiSquared(self, samples, pmf, *, pval=None): def make_key(self, seed): return random.PRNGKey(seed, impl='threefry2x32') + +class CommonRandomTest(RandomTestBase): + """ + Tests of common functionality that should be run with all PRNG impls. + """ @jtu.sample_product( num=(None, 6, (6,), (2, 3), (2, 3, 4)), ) @@ -164,6 +169,60 @@ def testRngRandint(self, dtype): self.assertTrue(np.all(lo <= samples)) self.assertTrue(np.all(samples < hi)) + def test_eval_shape_big_random_array(self): + def f(x): + return random.normal(self.make_key(x), (int(1e12),)) + with jax.enable_checks(False): # check_jaxpr will materialize array + jax.eval_shape(f, 0) # doesn't error + + @jtu.sample_product( + type_=["int", "np.array", "jnp.array"], + seed=[-1, 0, 1, (1 << 32) - 1, (1 << 63) - 1, np.uint64((1 << 64) - 1)], + ) + def test_prng_jit_invariance(self, seed, type_): + if type_ == "int" and seed == (1 << 64) - 1: + self.skipTest("Expected failure: Python int too large.") + if not config.enable_x64.value and seed > np.iinfo(np.int32).max: + self.skipTest("Expected failure: Python int too large.") + type_ = {"int": int, "np.array": np.array, "jnp.array": jnp.array}[type_] + args_maker = lambda: [type_(seed)] + f = lambda s: random.key_data(self.make_key(s)) + self._CompileAndCheck(f, args_maker) + + def test_prng_errors(self): + seed = np.iinfo(np.int64).max + 1 + with self.assertRaises(OverflowError): + self.make_key(seed) + with self.assertRaises(OverflowError): + jax.jit(self.make_key)(seed) + + def test_random_split_doesnt_device_put_during_tracing(self): + key = self.make_key(1).block_until_ready() + with jtu.count_device_put() as count: + jax.jit(random.split)(key) + self.assertLessEqual(count(), 1) # 1 for the argument device_put + + def test_large_prng(self): + # https://github.com/jax-ml/jax/issues/11010 + def f(): + return random.uniform( + self.make_key(3), (308000000, 128), dtype=jnp.bfloat16) + + # TODO(jakevdp): key reuse checks for this OOM because of slice masking. + # Can we fix this? + with jax.debug_key_reuse(False): + # just lower, don't run, takes too long + jax.jit(f).lower() + + +class DistributionsTest(RandomTestBase): + """ + Tests of distribution statistics that need only be run with the default PRNG. + + We limit this to the default PRNG to avoid repeated execution of very costly + tests. So long as the input bits are valid (as tested in BasicRandomTest) then + the distribution logic tested here will apply correctly. 
+ """ @jtu.sample_product(dtype=float_dtypes) def testNormal(self, dtype): key = lambda: self.make_key(0) @@ -227,8 +286,9 @@ def testTruncatedNormal(self, dtype): ], dtype=jtu.dtypes.floating + jtu.dtypes.integer, weighted=[True, False], + mode=[None, 'low', 'high'] ) - def testChoice(self, dtype, input_range_or_shape, shape, replace, weighted, axis): + def testChoice(self, dtype, input_range_or_shape, shape, replace, weighted, axis, mode): # This is the function API that we test against (note that self.rng().choice differs) np_choice = np.random.default_rng(0).choice p_dtype = dtypes.to_inexact_dtype(dtype) @@ -244,7 +304,7 @@ def testChoice(self, dtype, input_range_or_shape, shape, replace, weighted, axis p /= p.sum() else: p = None - rand = lambda key, x: random.choice(key, x, shape, replace, p, axis) + rand = lambda key, x: random.choice(key, x, shape, replace, p, axis, mode=mode) sample = rand(key(), x) if not is_range: self.assertEqual(dtype, sample.dtype) @@ -313,11 +373,13 @@ def testPermutationErrors(self): @jtu.sample_product( p=[0.1, 0.5, 0.9], dtype=jtu.dtypes.floating, + mode=[None, 'low', 'high'], ) - def testBernoulli(self, p, dtype): + def testBernoulli(self, p, dtype, mode): key = lambda: self.make_key(0) p = np.array(p, dtype=dtype) - rand = lambda key, p: random.bernoulli(key, p, (10000,)) + kwds = {} if mode is None else {'mode': mode} + rand = lambda key, p: random.bernoulli(key, p, (10000,), **kwds) crand = jax.jit(rand) uncompiled_samples = rand(key(), p) @@ -336,15 +398,16 @@ def testBernoulli(self, p, dtype): ] ], sample_shape=[(10000,), (5000, 2)], + mode=[None, 'low', 'high'], dtype=jtu.dtypes.floating, ) - def testCategorical(self, p, axis, dtype, sample_shape): + def testCategorical(self, p, axis, dtype, sample_shape, mode): key = lambda: self.make_key(0) p = np.array(p, dtype=dtype) logits = np.log(p) - 42 # test unnormalized out_shape = tuple(np.delete(logits.shape, axis)) shape = sample_shape + out_shape - rand = partial(random.categorical, shape=shape, axis=axis) + rand = partial(random.categorical, shape=shape, axis=axis, mode=mode) crand = jax.jit(rand) uncompiled_samples = rand(key(), logits) @@ -396,13 +459,29 @@ def testCategoricalWithoutReplacement(self, logits_shape, prefix_shape): counts = jax.vmap(partial(jnp.bincount, length=n_categories), 1)(flat) assert (counts <= 1).all() - def testBernoulliShape(self): key = self.make_key(0) with jax.numpy_rank_promotion('allow'): x = random.bernoulli(key, np.array([0.2, 0.3]), shape=(3, 2)) assert x.shape == (3, 2) + def testBernoulliSmallProbabilty(self): + # Regression test for https://github.com/jax-ml/jax/issues/28017 + key = jax.random.key(0) + + # Choose such that N * p is much less than 1. 
+ p = jnp.float32(1E-10) + N = int(1E8) + + # mode='low' fails for p<~1E-7 in float32 + samples = jax.random.bernoulli(key, p=p, shape=N, mode='low') + self.assertNotEqual(samples.sum(), 0) + + # mode='high' is good up to p<~1E-14 in float32 + samples = jax.random.bernoulli(key, p=p, shape=N, mode='high') + self.assertEqual(samples.sum(), 0) + + @jtu.sample_product( a=[0.2, 5.], b=[0.2, 5.], @@ -1071,39 +1150,6 @@ def testChoiceShapeIsNotSequenceError(self): with self.assertRaises(TypeError): random.choice(key, 5, 2, replace=True) - def test_eval_shape_big_random_array(self): - def f(x): - return random.normal(self.make_key(x), (int(1e12),)) - with jax.enable_checks(False): # check_jaxpr will materialize array - jax.eval_shape(f, 0) # doesn't error - - @jtu.sample_product( - type_=["int", "np.array", "jnp.array"], - seed=[-1, 0, 1, (1 << 32) - 1, (1 << 63) - 1, np.uint64((1 << 64) - 1)], - ) - def test_prng_jit_invariance(self, seed, type_): - if type_ == "int" and seed == (1 << 64) - 1: - self.skipTest("Expected failure: Python int too large.") - if not config.enable_x64.value and seed > np.iinfo(np.int32).max: - self.skipTest("Expected failure: Python int too large.") - type_ = {"int": int, "np.array": np.array, "jnp.array": jnp.array}[type_] - args_maker = lambda: [type_(seed)] - f = lambda s: random.key_data(self.make_key(s)) - self._CompileAndCheck(f, args_maker) - - def test_prng_errors(self): - seed = np.iinfo(np.int64).max + 1 - with self.assertRaises(OverflowError): - self.make_key(seed) - with self.assertRaises(OverflowError): - jax.jit(self.make_key)(seed) - - def test_random_split_doesnt_device_put_during_tracing(self): - key = self.make_key(1).block_until_ready() - with jtu.count_device_put() as count: - jax.jit(random.split)(key) - self.assertLessEqual(count(), 1) # 1 for the argument device_put - @jtu.sample_product(dtype=int_dtypes + uint_dtypes) def test_randint_bounds(self, dtype): min = np.iinfo(dtype).min @@ -1131,18 +1177,6 @@ def test_randint_out_of_range(self): self.assertGreater((r == 0).sum(), 0) self.assertGreater((r == 255).sum(), 0) - def test_large_prng(self): - # https://github.com/jax-ml/jax/issues/11010 - def f(): - return random.uniform( - self.make_key(3), (308000000, 128), dtype=jnp.bfloat16) - - # TODO(jakevdp): key reuse checks for this OOM because of slice masking. - # Can we fix this? 
- with jax.debug_key_reuse(False): - # just lower, don't run, takes too long - jax.jit(f).lower() - @jtu.sample_product(shape=[(3, 4)], logits_shape_base=[(3, 4), (3, 1), (1, 4)], axis=[-3, -2, -1, 0, 1, 2]) @@ -1461,7 +1495,7 @@ def _double_threefry_fold_in(key, data): tag='fry2') @jtu.with_config(jax_default_prng_impl='threefry2x32') -class LaxRandomWithCustomPRNGTest(LaxRandomTest): +class CustomPRNGTest(CommonRandomTest): def make_key(self, seed): return prng_internal.random_seed(seed, impl=double_threefry_prng_impl) @@ -1522,7 +1556,7 @@ def test_grad_of_prng_key(self): @jtu.with_config(jax_default_prng_impl='rbg') -class LaxRandomWithRBGPRNGTest(LaxRandomTest): +class RBGPRNGTest(CommonRandomTest): def make_key(self, seed): return random.PRNGKey(seed, impl='rbg') @@ -1634,7 +1668,7 @@ def test_randint_out_of_range(self): @jtu.with_config(jax_default_prng_impl='unsafe_rbg') -class LaxRandomWithUnsafeRBGPRNGTest(LaxRandomWithRBGPRNGTest): +class UnsafeRBGPRNGTest(RBGPRNGTest): def make_key(self, seed): return random.PRNGKey(seed, impl="unsafe_rbg") @@ -1648,24 +1682,6 @@ def test_vmap_split_mapped_key_values(self): self.assertArraysEqual(random.key_data(vmapped_keys), random.key_data(ref_keys)) -def _sampler_unimplemented_with_custom_prng(*args, **kwargs): - raise SkipTest('sampler only implemented for default RNG') - -for test_prefix in [ - 'testPoisson', - 'testPoissonBatched', - 'testPoissonShape', - 'testPoissonZeros', -]: - for attr in dir(LaxRandomTest): - if attr.startswith(test_prefix): - setattr(LaxRandomWithCustomPRNGTest, attr, - _sampler_unimplemented_with_custom_prng) - setattr(LaxRandomWithRBGPRNGTest, attr, - _sampler_unimplemented_with_custom_prng) - setattr(LaxRandomWithUnsafeRBGPRNGTest, attr, - _sampler_unimplemented_with_custom_prng) - if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/random_test.py b/tests/random_test.py index a51e387dca76..86a1622240dc 100644 --- a/tests/random_test.py +++ b/tests/random_test.py @@ -338,7 +338,7 @@ def testRandomDistributionValues(self, case, make_key): Any refactoring of random distributions that leads to non-trivial differences in this test should follow the procedure outlined at - https://jax.readthedocs.io/en/latest/api_compatibility.html#numerics-and-randomness + https://docs.jax.dev/en/latest/api_compatibility.html#numerics-and-randomness This includes: * Announcing the change in the CHANGELOG.md @@ -602,10 +602,26 @@ def assertKeysEqual(self, key1, key2): self.assertEqual(key1.dtype, key2.dtype) self.assertArraysEqual(random.key_data(key1), random.key_data(key2)) + def make_keys(self, *shape, seed=28): + seeds = seed + jnp.arange(math.prod(shape), dtype=jnp.uint32) + return jax.vmap(random.key)(seeds).reshape(shape) + def test_construction(self): key = random.key(42) self.assertIsInstance(key, prng_internal.PRNGKeyArray) + def test_numpy_construction(self): + key = random.wrap_key_data(np.array([42, 173], dtype=np.uint32), + impl='threefry2x32') + self.assertIsInstance(key, prng_internal.PRNGKeyArray) + self.assertIsInstance(key._base_array, jax.Array) + self.assertEqual(key._base_array.device, jax.devices()[0]) + self.assertEqual(key.device, jax.devices()[0]) + + def test_device_property(self): + key = random.key(42) + self.assertEqual(key.device, key._base_array.device) + def test_random_clone(self): # Here we test value semantics and compatibility with jit/vmap # key reuse semantics are tested in key_reuse_test.py @@ -632,10 +648,6 @@ def test_construction_upgrade_flag(self): key 
= random.PRNGKey(42) self.assertIsInstance(key, prng_internal.PRNGKeyArray) - def make_keys(self, *shape, seed=28): - seeds = seed + jnp.arange(math.prod(shape), dtype=jnp.uint32) - return jax.vmap(random.key)(seeds).reshape(shape) - def test_key_as_seed(self): key = self.make_keys() with self.assertRaisesRegex(TypeError, "PRNGKey accepts a scalar seed"): @@ -657,6 +669,11 @@ def test_non_integer_seed(self): with self.assertRaisesRegex(TypeError, "PRNG key seed must be an integer"): random.key(seed) + def test_nbytes_property(self): + key = self.make_keys() + self.assertEqual(key.nbytes, key._base_array.nbytes) + self.assertEqual(key.nbytes, key.itemsize * key.size) + def test_dtype_property(self): k1, k2 = self.make_keys(), self.make_keys() self.assertEqual(k1.dtype, k2.dtype) @@ -974,7 +991,7 @@ def callback(index): def test_make_array_from_single_device_arrays(self): devices = jax.devices() shape = (len(devices),) - mesh = jtu.create_mesh((len(devices),), ('x',)) + mesh = jtu.create_mesh((len(devices),), ('x',), iota_order=True) sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('x')) keys = random.split(random.key(0), len(devices)) arrays = [jax.device_put(keys[i:i + 1], device) for i, device in enumerate(devices)] diff --git a/tests/roofline_test.py b/tests/roofline_test.py index 564b4a9a1f9e..4dd8ca6c4759 100644 --- a/tests/roofline_test.py +++ b/tests/roofline_test.py @@ -14,11 +14,10 @@ from __future__ import annotations from functools import partial -from typing import Sequence +from collections.abc import Sequence from absl.testing import absltest import jax -from jax._src import mesh from jax._src import test_util as jtu from jax.experimental import roofline import jax.lax as lax @@ -29,6 +28,8 @@ jax.config.parse_flags_with_absl() jtu.request_cpu_devices(8) +_VERY_LARGE_NUMBER = 512 * 1024 + def create_inputs( *shardings: P, @@ -45,6 +46,31 @@ def create_inputs( return mesh, tuple(arrays) +def example_function(x): + return jnp.sin(x) + x**2 + + +@jax.custom_jvp +def example_custom_function(x): + """Example custom function. + + Small wrapper around `example_function`. We define `example_custom_function` + separately since we add the `@jax.custom_jvp` decorator and want to compare + its behavior to `example_function`'s in tests. + """ + return example_function(x) + + +@example_custom_function.defjvp +def example_custom_function_jvp(primals, tangents): + """Example custom function jvp. + + Normally this function would define a mathematically correct JVP, but its + definition has 0 effect on the roofline result, so we keep it very simple. 
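+  (It simply returns the input tangents unchanged.)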
+ """ + return example_custom_function(primals), tangents + + class RooflineTest(jtu.JaxTestCase): def setUp(self): @@ -465,18 +491,13 @@ def collective_matmul(a, b): ) def test_unary_ops(self, f, dtype): data = jnp.zeros((3, 8), dtype=dtype) - out, result = roofline.roofline( - f, - in_specs=(P()), - out_specs=P(), - )(data) - with self.subTest("flops"): - self.assertEqual(result.unfused_flops, 3 * 8) - with self.subTest("hbm_bytes"): - self.assertEqual( - result.unfused_hbm_bytes, - data.dtype.itemsize * 3 * 8 + out.dtype.itemsize * 3 * 8, - ) + out, result = roofline.roofline(f)(data) + + self.assertEqual(result.unfused_flops, 3 * 8) + self.assertEqual( + result.unfused_hbm_bytes, + data.dtype.itemsize * 3 * 8 + out.dtype.itemsize * 3 * 8, + ) def test_binary_ops(self): for f in [ @@ -495,12 +516,9 @@ def test_binary_ops(self): lambda a, b: jnp.minimum(a, b), lambda a, b: jnp.maximum(a, b), ]: - out, result = roofline.roofline( - f, - mesh=mesh.AbstractMesh((), ()), - in_specs=(P(), P()), - out_specs=P(), - )(jnp.zeros((3, 8), dtype=int), jnp.ones((3, 8), dtype=int)) + out, result = roofline.roofline(f)( + jnp.zeros((3, 8), dtype=int), jnp.ones((3, 8), dtype=int) + ) self.assertEqual(result.unfused_flops, 3 * 8) self.assertEqual( result.unfused_hbm_bytes, @@ -515,12 +533,7 @@ def test_broadcast(self): (2.0, jnp.ones((3, 8))), (jnp.zeros((3, 8)), 2.0), ]: - _, result = roofline.roofline( - lambda a, b: a + b, - mesh=mesh.AbstractMesh((), ()), - in_specs=(P(), P()), - out_specs=P(), - )(left, right) + _, result = roofline.roofline(lambda a, b: a + b)(left, right) self.assertEqual(result.unfused_flops, 3 * 8) def test_nested(self): @@ -531,27 +544,21 @@ def g(x): return g(x) + g(y) - _, result = roofline.roofline( - f, - mesh=mesh.AbstractMesh((), ()), - in_specs=(P(), P()), - out_specs=P(), - )(jnp.zeros((11, 4), dtype=int), jnp.ones((11, 4), dtype=int)) + _, result = roofline.roofline(f)( + jnp.zeros((11, 4), dtype=int), jnp.ones((11, 4), dtype=int) + ) self.assertEqual(result.unfused_flops, 3 * (11 * 4)) def test_no_mesh(self): - _, result = roofline.roofline( - lambda a, b: a + b, - in_specs=(P(), P()), - out_specs=P(), - )(jnp.zeros((3, 8), dtype=int), jnp.ones((3, 8), dtype=int)) + _, result = roofline.roofline(lambda a, b: a + b)( + jnp.zeros((3, 8), dtype=int), jnp.ones((3, 8), dtype=int) + ) self.assertEqual(result.unfused_flops, 3 * 8) def test_no_specs(self): - _, result = roofline.roofline( - lambda a, b: a + b, - mesh=mesh.AbstractMesh((), ()), - )(jnp.zeros((3, 8), dtype=int), jnp.ones((3, 8), dtype=int)) + _, result = roofline.roofline(lambda a, b: a + b)( + jnp.zeros((3, 8), dtype=int), jnp.ones((3, 8), dtype=int) + ) self.assertEqual(result.unfused_flops, 3 * 8) def test_no_mesh_and_no_specs(self): @@ -560,63 +567,109 @@ def test_no_mesh_and_no_specs(self): )(jnp.zeros((3, 8), dtype=int), jnp.ones((3, 8), dtype=int)) self.assertEqual(result.unfused_flops, 3 * 8) + @jtu.parameterized.product( + cumulative_function=[lax.cummax, lax.cummin, lax.cumprod, lax.cumsum], + axis=[0, 1, 2], + ) + def test_cumulative_ops(self, cumulative_function: int, axis: int): + f = lambda x: cumulative_function(operand=x, axis=axis) + x = jnp.zeros((3, 8, 15), dtype=int) + + _, result = roofline.roofline(f)(x) + + self.assertEqual(result.unfused_flops, x.shape[axis]) + self.assertEqual( + result.unfused_hbm_bytes, 2 * self._bytes_per_word * 3 * 8 * 15 + ) + + @jtu.parameterized.named_parameters( + dict(testcase_name="axis_0", axis=0), + dict(testcase_name="axis_1", axis=1), + 
dict(testcase_name="axis_2", axis=2), + ) + def test_cumlogsumexp_p_roofline(self, axis: int): + f = lambda x: lax.cumlogsumexp(operand=x, axis=axis) + x = jnp.zeros((3, 8, 15), dtype=int) + + _, result = roofline.roofline(f)(x) + + self.assertEqual(result.unfused_flops, 2 * x.shape[axis]) + self.assertEqual( + result.unfused_hbm_bytes, 2 * self._bytes_per_word * 3 * 8 * 15 + ) + def test_dot_general(self): - _, result = roofline.roofline( - lambda a, b: a @ b, - mesh=mesh.AbstractMesh((), ()), - in_specs=(P(), P()), - out_specs=P(), - )(jnp.zeros((3, 7), dtype=int), jnp.ones((7, 5), dtype=int)) + _, result = roofline.roofline(lambda a, b: a @ b)( + jnp.zeros((3, 7), dtype=int), jnp.ones((7, 5), dtype=int) + ) self.assertEqual(result.unfused_flops, 2 * 3 * 7 * 5) self.assertEqual( result.unfused_hbm_bytes, self._bytes_per_word * (3 * 7 + 7 * 5 + 3 * 5) ) - def get_conv_output_dim(self, i, k, pad_low, pad_high, stride): + def get_conv_output_dim(self, i, k, pad_low, pad_high, stride) -> int: return jnp.floor((i - k + pad_low + pad_high) / stride) + 1 - @jtu.parameterized.named_parameters( - dict( - testcase_name="simple", - window_strides=(1, 1), - padding=((0, 0), (0, 0)), - ), - dict( - testcase_name="padding", - window_strides=(1, 1), - padding=((1, 2), (3, 4)), - ), - dict( - testcase_name="window_strides", - window_strides=(2, 2), - padding=((0, 0), (0, 0)), - ), - dict( - testcase_name="window_strides_and_padding", - window_strides=(3, 3), - padding=((1, 2), (3, 4)), - ), + def get_conv_num_output_channels( + self, batch_group_count: int, feature_group_count: int + ) -> int: + if batch_group_count > 1: + return batch_group_count + elif feature_group_count > 1: + return feature_group_count + else: + return 1 + + @jtu.parameterized.product( + window_strides=[(1, 1), (2, 2)], + padding=[((0, 0), (0, 0)), ((1, 2), (3, 4))], + # batch must be divisible by batch_group_count, so we only include factors + # of batch_group_count. + batch=[6, 12], + batch_group_count=[1, 3], + # num_input_channels must be divisible by feature_group_count, so we only + # include factors of feature_group_count. 
+ num_input_channels=[6, 12], + feature_group_count=[1, 3], ) def test_conv_general_dilated_unfused_hbm_bytes( - self, window_strides: Sequence[int, int], padding: Sequence[int, int] + self, + window_strides: Sequence[int, int], + padding: Sequence[int, int], + batch: int, + batch_group_count: int, + num_input_channels: int, + feature_group_count: int, ): + if batch_group_count > 1 and feature_group_count > 1: + self.skipTest( + "batch_group_count and feature_group_count cannot both be > 1" + ) + + num_output_channels = self.get_conv_num_output_channels( + batch_group_count, feature_group_count + ) + + num_input_features = int(num_input_channels / feature_group_count) iw, ih = 100, 200 kw, kh = 7, 7 - input_data = jnp.zeros((1, 1, iw, ih), dtype=int) - kernel_data = jnp.ones((1, 1, kw, kh), dtype=int) + input_data = jnp.zeros((batch, num_input_channels, iw, ih), dtype=int) + kernel_data = jnp.ones( + (num_output_channels, num_input_features, kw, kh), dtype=int + ) conv = lambda a, b: lax.conv_general_dilated( - lhs=a, rhs=b, window_strides=window_strides, padding=padding + lhs=a, + rhs=b, + window_strides=window_strides, + padding=padding, + batch_group_count=batch_group_count, + feature_group_count=feature_group_count, ) - _, result = roofline.roofline( - conv, - mesh=mesh.AbstractMesh((), ()), - in_specs=(P(), P()), - out_specs=P(), - )(input_data, kernel_data) + _, result = roofline.roofline(conv)(input_data, kernel_data) - expected_input_size = 1 * 1 * iw * ih - expected_kernel_size = 1 * 1 * kw * kh + expected_input_size = batch * num_input_channels * iw * ih + expected_kernel_size = num_output_channels * num_input_features * kw * kh ow = self.get_conv_output_dim( iw, kw, padding[0][0], padding[0][1], window_strides[0] @@ -624,12 +677,14 @@ def test_conv_general_dilated_unfused_hbm_bytes( oh = self.get_conv_output_dim( ih, kh, padding[1][0], padding[1][1], window_strides[1] ) - expected_output_size = 1 * 1 * ow * oh + expected_output_shape = jnp.array( + (batch / batch_group_count, num_output_channels, ow, oh) + ) + expected_output_size = jnp.prod(expected_output_shape) # Bytes accessed is sum of inputs and output. expected_unfused_hbm_bytes = self._bytes_per_word * ( expected_input_size + expected_kernel_size + expected_output_size ) - # TODO(b/394648206): add subtest for unfused_flops once they are supported. self.assertEqual(result.unfused_hbm_bytes, expected_unfused_hbm_bytes) @jtu.parameterized.named_parameters( @@ -642,24 +697,22 @@ def test_conv_general_dilated_unfused_hbm_bytes( padding="SAME_LOWER", ), ) - def test_conv_general_dilated_padding_string_unfused_hbm_bytes(self, padding: str): - input_data = jnp.zeros((1, 1, 10, 20), dtype=int) + def test_conv_general_dilated_padding_string( + self, padding: str + ): + input_data = jnp.zeros((1, 1, 3, 3), dtype=int) kernel_data = jnp.ones((1, 1, 3, 3), dtype=int) conv = lambda a, b: lax.conv_general_dilated( lhs=a, rhs=b, window_strides=(1, 1), padding=padding ) - _, result = roofline.roofline( - conv, - mesh=mesh.AbstractMesh((), ()), - in_specs=(P(), P()), - out_specs=P(), - )(input_data, kernel_data) + _, result = roofline.roofline(conv)(input_data, kernel_data) - expected_input_size = 1 * 1 * 10 * 20 + # Test hbm bytes. + expected_input_size = 1 * 1 * 3 * 3 expected_kernel_size = 1 * 1 * 3 * 3 # Because of same{_lower} padding, output shape should equal to input shape. - # This may not be true for other `{feature, batch}`_group_count`s.c + # This may not be true for other `{feature, batch}`_group_count`s. 
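+    # (Here the input is 1x1x3x3 with window strides (1, 1), so the output also
+    # has 1 * 1 * 3 * 3 elements.)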
expected_output_size = expected_input_size # Bytes accessed is sum of inputs and output. expected_unfused_hbm_bytes = self._bytes_per_word * ( @@ -667,19 +720,28 @@ def test_conv_general_dilated_padding_string_unfused_hbm_bytes(self, padding: st ) self.assertEqual(result.unfused_hbm_bytes, expected_unfused_hbm_bytes) - def test_conv_general_dilated_padding_string_valid_unfused_hbm_bytes(self): + # Test flops. + # For spatial_valid_position_counts, we have 3x3 output with the following + # flops for each element: + # 4 6 4 + # 6 9 6 + # 4 6 4 + # Non_spatial_dims_factor = 1 because `{batch, feature}_group_count` are + # both equal to 1. + # Each FMA is 2 flops. + self.assertEqual( + result.unfused_flops, + 2 * (4 + 6 + 4 + 6 + 9 + 6 + 4 + 6 + 4), + ) + + def test_conv_general_dilated_padding_string_valid(self): input_data = jnp.zeros((1, 1, 10, 20), dtype=int) kernel_data = jnp.ones((1, 1, 3, 3), dtype=int) conv = lambda a, b: lax.conv_general_dilated( lhs=a, rhs=b, window_strides=(1, 1), padding="VALID" ) - _, result = roofline.roofline( - conv, - mesh=mesh.AbstractMesh((), ()), - in_specs=(P(), P()), - out_specs=P(), - )(input_data, kernel_data) + _, result = roofline.roofline(conv)(input_data, kernel_data) expected_input_size = 1 * 1 * 10 * 20 expected_kernel_size = 1 * 1 * 3 * 3 @@ -690,19 +752,91 @@ def test_conv_general_dilated_padding_string_valid_unfused_hbm_bytes(self): * self.get_conv_output_dim(10, 3, 0, 0, 1) * self.get_conv_output_dim(20, 3, 0, 0, 1) ) + # Bytes accessed is sum of inputs and output. expected_unfused_hbm_bytes = self._bytes_per_word * ( expected_input_size + expected_kernel_size + expected_output_size ) self.assertEqual(result.unfused_hbm_bytes, expected_unfused_hbm_bytes) + # Output shape is [1x1x8x18] and each output element requires (3x3) FMAs, + # and each FMA is 2 flops. + self.assertEqual( + result.unfused_flops, 2 * expected_output_size * 3 * 3 + ) + + @jtu.parameterized.named_parameters( + dict( + testcase_name="padding", + input_spatial_dim=1, + window_strides=[1], + padding=[(_VERY_LARGE_NUMBER - 1, _VERY_LARGE_NUMBER - 1)], + lhs_dilation=[1], + ), + dict( + testcase_name="input", + input_spatial_dim=_VERY_LARGE_NUMBER, + window_strides=[_VERY_LARGE_NUMBER - 1], + padding=[(0, 0)], + lhs_dilation=[_VERY_LARGE_NUMBER], + ), + ) + def test_conv_general_dilated_flops_very_large( + self, input_spatial_dim, window_strides, padding, lhs_dilation + ): + input_data = jnp.zeros((1, 1, input_spatial_dim), dtype=int) + kernel_data = jnp.ones((1, 1, _VERY_LARGE_NUMBER), dtype=int) + conv = lambda a, b: lax.conv_general_dilated( + lhs=a, + rhs=b, + window_strides=window_strides, + padding=padding, + lhs_dilation=lhs_dilation, + ) + _, result = roofline.roofline(conv)(input_data, kernel_data) + + self.assertEqual(result.unfused_flops, 2 * _VERY_LARGE_NUMBER) + + def test_conv_general_dilated_flops_feature_group_count(self): + feature_group_count = 120 + input_data = jnp.zeros((1, feature_group_count, 10, 20), dtype=int) + kernel_data = jnp.ones((feature_group_count, 1, 3, 3), dtype=int) + conv = lambda a, b: lax.conv_general_dilated( + lhs=a, + rhs=b, + window_strides=(1, 1), + padding=((0, 0), (0, 0)), + feature_group_count=feature_group_count, + ) + _, result = roofline.roofline(conv)(input_data, kernel_data) + + # Output shape is [1x120x8x18] and each output element requires (3x3) + # FMAs and one FMA is 2 flops. 
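+    # That is 2 * 120 * 8 * 18 * 3 * 3 = 311,040 unfused flops.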
+ self.assertEqual( + result.unfused_flops, 2 * 120 * 8 * 18 * 3 * 3 + ) + + def test_conv_general_dilated_flops_batch_group_count(self): + batch_group_count = 120 + input_data = jnp.zeros((batch_group_count, 1, 10, 20), dtype=int) + kernel_data = jnp.ones((batch_group_count, 1, 3, 3), dtype=int) + conv = lambda a, b: lax.conv_general_dilated( + lhs=a, + rhs=b, + window_strides=(1, 1), + padding=((0, 0), (0, 0)), + batch_group_count=batch_group_count, + ) + _, result = roofline.roofline(conv)(input_data, kernel_data) + + # Output shape is [120x1x8x18] and each output element requires (3x3) + # FMAs and one FMA is 2 flops. + self.assertEqual( + result.unfused_flops, 2 * 120 * 8 * 18 * 3 * 3 + ) + def test_reduce_sum_no_axis(self): - _, result = roofline.roofline( - lambda x: jnp.sum(x), - mesh=mesh.AbstractMesh((), ()), - in_specs=(P()), - out_specs=P(), - )(jnp.zeros((11, 4))) + _, result = roofline.roofline(lambda x: jnp.sum(x))(jnp.zeros((11, 4))) self.assertEqual(result.unfused_flops, 11 * 4 - 1) self.assertEqual( result.unfused_hbm_bytes, self._bytes_per_word * (11 * 4 + 1) @@ -715,17 +849,97 @@ def test_reduce_sum_with_axis(self): ([0, 1], 11 * 4 - 1, 11 * 4 + 1), ([], 0, 11 * 4 + 11 * 4), ]: - _, result = roofline.roofline( - lambda x: jnp.sum(x, axis=axis), - mesh=mesh.AbstractMesh((), ()), - in_specs=(P()), - out_specs=P(), - )(jnp.zeros((11, 4))) + _, result = roofline.roofline(lambda x: jnp.sum(x, axis=axis))( + jnp.zeros((11, 4)) + ) self.assertEqual(result.unfused_flops, expected_flops) self.assertEqual( result.unfused_hbm_bytes, self._bytes_per_word * expected_memory ) + def test_custom_jvp_call_p_roofline(self): + dummy_input = jnp.ones((3, 8)) + + _, base_result = roofline.roofline(example_function)(dummy_input) + _, custom_result = roofline.roofline(example_custom_function)(dummy_input) + + self.assertEqual(custom_result.unfused_flops, base_result.unfused_flops) + self.assertEqual( + custom_result.unfused_hbm_bytes, base_result.unfused_hbm_bytes + ) + + def test_custom_jvp_call_p_roofline_with_neg(self): + dummy_input = jnp.ones((3, 8)) + + def with_neg(f): + return lambda x: jax.lax.neg(f(x)) + + _, base_result = roofline.roofline(with_neg(example_function))(dummy_input) + _, custom_result = roofline.roofline(with_neg(example_custom_function))( + dummy_input + ) + + self.assertEqual(custom_result.unfused_flops, base_result.unfused_flops) + self.assertEqual( + custom_result.unfused_hbm_bytes, base_result.unfused_hbm_bytes + ) + + def test_gather_roofline(self): + operand = jnp.zeros((3, 3), dtype=jnp.int32) + indices = jnp.zeros((2, 1), dtype=jnp.int32) + + dimension_numbers = jax.lax.GatherDimensionNumbers( + offset_dims=(1,), + collapsed_slice_dims=(0,), + start_index_map=(0,), + ) + + f = lambda x, y: jax.lax.gather( + x, + y, + dimension_numbers=dimension_numbers, + slice_sizes=(1, 3), + ) + + _, result = roofline.roofline(f)(operand, indices) + + self.assertEqual(result.unfused_flops, 0) + # Expected bytes: + # operand: 2 * 3 * sizeof(int32) = 24 + # indices: 2 * 1 * sizeof(int32) = 8 + # output: 2 * 3 * sizeof(int32) = 24 + # total = 56 + self.assertEqual(result.unfused_hbm_bytes, 56) + + def test_gather_batching_dims_roofline(self): + operand = jnp.zeros((5, 3, 3), dtype=jnp.int32) + indices = jnp.zeros((5, 1), dtype=jnp.int32) + + dimension_numbers = jax.lax.GatherDimensionNumbers( + offset_dims=(1,), + collapsed_slice_dims=(1,), + start_index_map=(1,), + operand_batching_dims=(0,), + start_indices_batching_dims=(0,), + ) + + f = lambda x, y: jax.lax.gather( + x, + 
y, + dimension_numbers=dimension_numbers, + slice_sizes=(1, 1, 3), + ) + + _, result = roofline.roofline(f)(operand, indices) + + self.assertEqual(result.unfused_flops, 0) + # Expected bytes: + # operand: 5 * 3 * sizeof(int32) = 60 + # indices: 5 * 1 * sizeof(int32) = 20 + # output: 5 * 3 * sizeof(int32) = 60 + # total = 140 + self.assertEqual(result.unfused_hbm_bytes, 140) + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/scaled_matmul_stablehlo_test.py b/tests/scaled_matmul_stablehlo_test.py index 141839a19a08..9830d6fefff7 100644 --- a/tests/scaled_matmul_stablehlo_test.py +++ b/tests/scaled_matmul_stablehlo_test.py @@ -47,10 +47,10 @@ c_name = "__cudnn$blockScaledDot" expected_hlos = [ (c_name, "all-reduce", "f32[1,512,512]", "replica_groups={{0,1},{2,3}}"), - ("all-gather", "f8e4m3fn[1,512,512]", "replica_groups=[2,2]<=[4]", c_name), - ("all-gather", "f8e4m3fn[1,512,512]", "replica_groups=[2,2]<=[4]", c_name), + ("all-gather", "f8e4m3fn[512,512]", "replica_groups=[2,2]<=[4]", c_name), + ("all-gather", "f8e4m3fn[512,512]", "replica_groups=[2,2]<=[4]", c_name), (c_name,), - ("all-gather", "f8e4m3fn[1,256,1024]", "replica_groups=[2,2]<=[4]", c_name), + ("all-gather", "f8e4m3fn[256,1024]", "replica_groups=[2,2]<=[4]", c_name), (c_name, "reduce-scatter", "f32[2,256,512]", "replica_groups={{0,1},{2,3}}"), ("all-gather", "f8e4m3fn[2,512,1024]", "replica_groups=[2,2]<=[4]", c_name), ("all-gather", "f8e4m3fn[2,512,512]", "replica_groups=[2,2]<=[4]", c_name), @@ -174,7 +174,7 @@ def update_global_scale(config, new_global_scale): config.global_scale = new_global_scale return config -def generate_nvfp4_quantized_tensors(dot_config, output_type): +def generate_nvfp4_quantized_tensors(dot_config, output_type, enable_grad_clip=False): k1, k2 = jax.random.split(jax.random.key(0), 2) a_shape, b_shape, dimension_numbers = dot_config @@ -194,6 +194,11 @@ def generate_nvfp4_quantized_tensors(dot_config, output_type): amax_a = jnp.max(jnp.abs(a)).astype(jnp.float32) amax_b = jnp.max(jnp.abs(b)).astype(jnp.float32) + # To emulate calibrated amax + amax_sf = 0.9 if enable_grad_clip else 1.0 + amax_a *= amax_sf + amax_b *= amax_sf + # Update global scales data_max = jnp.finfo(block_scale_configs_nvfp4[0].data_type).max.astype( jnp.float32 @@ -275,9 +280,9 @@ def setUp(self): self.skipTest(str(e)) return if _dtypes.float8_e8m0fnu is None: - self.skipTest("Requries >= ml_dtypes 0.5.0 to support float8_e8m0fnu") + self.skipTest("Requires >= ml_dtypes 0.5.0 to support float8_e8m0fnu") if _dtypes.float4_e2m1fn is None: - self.skipTest("Requries >= ml_dtypes 0.5.0 to support float4_e2m1fn") + self.skipTest("Requires >= ml_dtypes 0.5.0 to support float4_e2m1fn") if cudnn_version < 90700: self.skipTest("Requires >= cuDNN 9.7.0") if not jtu.is_cuda_compute_capability_at_least("10.0"): @@ -468,7 +473,7 @@ def setUp(self): self.skipTest(str(e)) return if _dtypes.float8_e8m0fnu is None: - self.skipTest("Requries >= ml_dtypes 0.5.0 to support float8_e8m0fnu") + self.skipTest("Requires >= ml_dtypes 0.5.0 to support float8_e8m0fnu") if cudnn_version < 90700: self.skipTest("Requires >= cuDNN 9.7.0") if not jtu.is_cuda_compute_capability_at_least("10.0"): @@ -508,6 +513,68 @@ def fn(a): self.assertArraysAllClose(out_q, out_q_ref, rtol=1e-5, atol=1e-5) self.assertArraysAllClose(scale, scale_ref, rtol=1e-5, atol=1e-5) + @jtu.sample_product( + enable_grad_clip=[True, False], + configs=[ + # a_shape, b_shape, dimension_numbers + ((1, 128, 128), (1, 128, 128), (([2], [2]), ([0], [0]))), 
+ ((30, 64), (100, 64), (([1], [1]), ([], []))), + ] + ) + @jtu.run_on_devices("cuda") + def test_nvfp4_gradient_clip(self, enable_grad_clip, configs): + output_type = jnp.float32 + (a_raw, b_raw), (a_dq, b_dq), _, block_scale_configs = ( + generate_nvfp4_quantized_tensors(configs, output_type, enable_grad_clip) + ) + a_gs = block_scale_configs[0].global_scale + b_gs = block_scale_configs[1].global_scale + dimension_numbers = configs[2] + + scaled_dot_general = partial( + scaled_dot_general_wrapper, + configs=block_scale_configs + ) + + def fwd(a, b, use_normalized=False): + y = scaled_dot_general( + a, b, dimension_numbers, + preferred_element_type=output_type + ) + return jnp.sum(y) + + j_train = jax.jit(jax.value_and_grad(fwd, argnums=[0, 1])) + _, (x_grad, w_grad) = j_train(a_raw, b_raw) + + data_max = jnp.finfo(jnp.float4_e2m1fn).max.astype(output_type) + scale_max = jnp.finfo(jnp.float8_e4m3fn).max.astype(output_type) + prev_amax_a = a_gs * data_max * scale_max + prev_amax_b = b_gs * data_max * scale_max + + # Use a large value to ensure no clipping + threshold_a = prev_amax_a if enable_grad_clip else 1e9 + threshold_b = prev_amax_b if enable_grad_clip else 1e9 + + # Verify gradients are clipped to 0 where |input| > global_scale * MAX * SCALE_MAX + self.assertArraysEqual( + jnp.where(jnp.abs(a_raw) > threshold_a, x_grad, 0), + jnp.zeros_like(x_grad), + ) + self.assertArraysEqual( + jnp.where(jnp.abs(b_raw) > threshold_b, w_grad, 0), + jnp.zeros_like(w_grad), + ) + if enable_grad_clip: + # Verify gradients are preserved where |input| <= global_scale * MAX * SCALE_MAX + self.assertArraysEqual( + jnp.where(jnp.abs(a_raw) <= prev_amax_a, x_grad, 0), + x_grad, + ) + self.assertArraysEqual( + jnp.where(jnp.abs(b_raw) <= prev_amax_b, w_grad, 0), + w_grad, + ) + @jtu.sample_product( configs=[ # a_shape, b_shape, dimension_numbers, is_training @@ -567,6 +634,16 @@ def fwd(a, b, is_ref=False, use_normalized=False): out_ref, _ = j_train_fwd_ref(a_dq, b_dq) self.assertArraysAllClose(out, out_ref, rtol=1e-2, atol=1e-2) + def _grad_clip(amax, x, grad): + return jnp.where(jnp.abs(x) <= amax, grad, 0) + + data_max = jnp.finfo(jnp.float4_e2m1fn).max.astype(output_type) + scale_max = jnp.finfo(jnp.float8_e4m3fn).max.astype(output_type) + prev_amax_a = a_gs * data_max * scale_max + prev_amax_b = b_gs * data_max * scale_max + + x_grad_ref = _grad_clip(prev_amax_a, a_raw, x_grad_ref) + w_grad_ref = _grad_clip(prev_amax_b, b_raw, w_grad_ref) self.assertArraysAllClose(x_grad, x_grad_ref, rtol=1e-2, atol=1e1) self.assertArraysAllClose(w_grad, w_grad_ref, rtol=1e-2, atol=1e1) else: @@ -659,11 +736,11 @@ def test_dot_general_sharded(self, in_shardings): k1, k2 = jax.random.split(jax.random.key(0), 2) a = cast_to_representable( - jax.random.uniform(k1, a_shape, minval=-1.0), + jax.random.uniform(k1, a_shape, minval=-1.0, dtype=jnp.float32), self.block_scale_configs[0].data_type, ) b = cast_to_representable( - jax.random.uniform(k2, b_shape, minval=-1.0), + jax.random.uniform(k2, b_shape, minval=-1.0, dtype=jnp.float32), self.block_scale_configs[1].data_type, ) @@ -694,10 +771,6 @@ def fwd(a, b, is_ref=False): j_train = jax.jit(jax.value_and_grad(partial(fwd), argnums=[0, 1]), in_shardings=input_shardings) - hlo_text = j_train.lower(a, b).compile().as_text() - hlo_pattern = re.compile( - r".*".join([re.escape(x) for x in ("custom-call", c_name)]) - ) j_train_ref = jax.jit( jax.value_and_grad(partial(fwd, is_ref=True), argnums=[0, 1]), @@ -731,11 +804,11 @@ def test_dot_general_vmap(self, configs): 
dimension_numbers = (([1], [1]), ([], [])) a = cast_to_representable( - jax.random.uniform(k1, a_shape, minval=-1.0), + jax.random.uniform(k1, a_shape, minval=-1.0, dtype=jnp.float32), self.block_scale_configs[0].data_type, ) b = cast_to_representable( - jax.random.uniform(k2, b_shape, minval=-1.0), + jax.random.uniform(k2, b_shape, minval=-1.0, dtype=jnp.float32), self.block_scale_configs[1].data_type, ) diff --git a/tests/scipy_signal_test.py b/tests/scipy_signal_test.py index 11923257a9dd..b1c5d9c98fed 100644 --- a/tests/scipy_signal_test.py +++ b/tests/scipy_signal_test.py @@ -357,12 +357,11 @@ def testWelchWithDefaultStepArgsAgainstNumpy( if use_nperseg: kwargs['nperseg'] = nperseg if use_window: - kwargs['window'] = jnp.array(osp_signal.get_window('hann', nperseg), - dtype=dtypes.to_complex_dtype(dtype)) + kwargs['window'] = jnp.array(osp_signal.get_window('hann', nperseg)) if use_noverlap: kwargs['noverlap'] = noverlap - @jtu.ignore_warning(message="nperseg = 256 is greater than") + @jtu.ignore_warning(message="nperseg") def osp_fun(x): freqs, Pxx = osp_signal.welch(x, **kwargs) return freqs.astype(_real_dtype(dtype)), Pxx.astype(_real_dtype(dtype)) @@ -388,7 +387,7 @@ def osp_fun(x): ], dtype=default_dtypes, fs=[1.0, 16000.0], - window=['boxcar', 'triang', 'blackman', 'hamming', 'hann'], + window=['boxcar', 'triang', 'blackman', 'hamming', 'hann', 'USE_ARRAY'], onesided=[False, True], boundary=[False, True], ) @@ -399,6 +398,11 @@ def testIstftAgainstNumpy(self, *, shape, dtype, fs, window, nperseg, new_freq_len = (shape[freqaxis] - 1) * 2 shape = shape[:freqaxis] + (new_freq_len ,) + shape[freqaxis + 1:] + if window == 'USE_ARRAY': + # ensure dtype matches the expected dtype of `xsubs` within the implementation. + window = np.ones(nperseg, dtype=( + dtypes.to_floating_dtype(dtype) if onesided else dtypes.to_complex_dtype(dtype))) + kwds = dict(fs=fs, window=window, nperseg=nperseg, noverlap=noverlap, nfft=nfft, input_onesided=onesided, boundary=boundary, time_axis=timeaxis, freq_axis=freqaxis) diff --git a/tests/scipy_spatial_test.py b/tests/scipy_spatial_test.py index 3da98efce884..6b1c042b049e 100644 --- a/tests/scipy_spatial_test.py +++ b/tests/scipy_spatial_test.py @@ -123,8 +123,6 @@ def testRotationAsQuat(self, shape, dtype): shape=[(4,), (num_samples, 4)], ) def testRotationAsQuatCanonical(self, shape, dtype): - if scipy_version < (1, 11, 0): - self.skipTest("Scipy 1.11.0 added the `canonical` arg.") rng = jtu.rand_default(self.rng()) args_maker = lambda: (rng(shape, dtype),) jnp_fn = lambda q: jsp_Rotation.from_quat(q).as_quat(canonical=True) @@ -152,8 +150,6 @@ def testRotationAsQuatScalarFirst(self, shape, dtype): other_shape=[(num_samples, 4)], ) def testRotationConcatenate(self, shape, other_shape, dtype): - if scipy_version < (1, 8, 0): - self.skipTest("Scipy 1.8.0 needed for concatenate.") rng = jtu.rand_default(self.rng()) args_maker = lambda: (rng(shape, dtype), rng(other_shape, dtype),) jnp_fn = lambda q, o: jsp_Rotation.concatenate([jsp_Rotation.from_quat(q), jsp_Rotation.from_quat(o)]).as_rotvec() @@ -297,8 +293,6 @@ def testRotationInv(self, shape, dtype): shape=[(4,), (num_samples, 4)], ) def testRotationInvConjugate(self, shape, dtype): - if scipy_version < (1, 11, 0): - self.skipTest("Scipy prior to 1.11.0 used a negative conjugate.") rng = jtu.rand_default(self.rng()) args_maker = lambda: (rng(shape, dtype),) jnp_fn = lambda q: jsp_Rotation.from_quat(q).inv().as_quat() diff --git a/tests/scipy_stats_test.py b/tests/scipy_stats_test.py index 
796d4490daea..e9021b86bb7a 100644 --- a/tests/scipy_stats_test.py +++ b/tests/scipy_stats_test.py @@ -20,18 +20,15 @@ import numpy as np import scipy.stats as osp_stats -import scipy.version import jax import jax.numpy as jnp -from jax._src import dtypes, test_util as jtu +from jax._src import test_util as jtu from jax.scipy import stats as lsp_stats from jax.scipy.special import expit jax.config.parse_flags_with_absl() -scipy_version = jtu.parse_version(scipy.version.version) - all_shapes = [(), (4,), (3, 4), (3, 1), (1, 4), (2, 1, 4)] one_and_two_dim_shapes = [(4,), (3, 4), (3, 1), (1, 4)] @@ -217,9 +214,6 @@ def testBernoulliPpf(self, shapes, dtypes): scipy_fun = osp_stats.bernoulli.ppf lax_fun = lsp_stats.bernoulli.ppf - if scipy_version < (1, 9, 2): - self.skipTest("Scipy 1.9.2 needed for fix https://github.com/scipy/scipy/pull/17166.") - def args_maker(): q, p = map(rng, shapes, dtypes) q = expit(q) @@ -1664,9 +1658,6 @@ def evaluate_kde(kde, x): message="All axis-slices of one or more sample arguments are too small", ) def testMode(self, shape, dtype, axis, contains_nans, keepdims): - if scipy_version < (1, 9, 0) and keepdims != True: - self.skipTest("scipy < 1.9.0 only support keepdims == True") - if contains_nans: rng = jtu.rand_some_nan(self.rng()) else: @@ -1675,25 +1666,7 @@ def testMode(self, shape, dtype, axis, contains_nans, keepdims): def scipy_mode_wrapper(a, axis=0, nan_policy='propagate', keepdims=None): """Wrapper to manage the shape discrepancies between scipy and jax""" - if scipy_version < (1, 11, 0) and a.size == 0: - if keepdims: - if axis == None: - output_shape = tuple(1 for _ in a.shape) - else: - output_shape = tuple(1 if i == axis else s for i, s in enumerate(a.shape)) - else: - if axis == None: - output_shape = () - else: - output_shape = np.delete(np.array(a.shape, dtype=np.int64), axis) - t = dtypes.canonicalize_dtype(jax.numpy.float_) - return (np.full(output_shape, np.nan, dtype=t), - np.zeros(output_shape, dtype=t)) - - if scipy_version < (1, 9, 0): - result = osp_stats.mode(a, axis=axis, nan_policy=nan_policy) - else: - result = osp_stats.mode(a, axis=axis, nan_policy=nan_policy, keepdims=keepdims) + result = osp_stats.mode(a, axis=axis, nan_policy=nan_policy, keepdims=keepdims) if a.size != 0 and axis == None and keepdims == True: output_shape = tuple(1 for _ in a.shape) @@ -1748,11 +1721,10 @@ def testSEM(self, shape, dtype, axis, ddof, nan_policy, keepdims): rng = jtu.rand_default(self.rng()) args_maker = lambda: [rng(shape, dtype)] - kwds = {} if scipy_version < (1, 11) else {'keepdims': keepdims} scipy_fun = partial(osp_stats.sem, axis=axis, ddof=ddof, nan_policy=nan_policy, - **kwds) + keepdims=keepdims) lax_fun = partial(lsp_stats.sem, axis=axis, ddof=ddof, nan_policy=nan_policy, - **kwds) + keepdims=keepdims) tol_spec = {np.float32: 2e-4, np.float64: 5e-6} tol = jtu.tolerance(dtype, tol_spec) self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, diff --git a/tests/shape_poly_test.py b/tests/shape_poly_test.py index 6d1ffe744ed9..68aaf4e29553 100644 --- a/tests/shape_poly_test.py +++ b/tests/shape_poly_test.py @@ -961,7 +961,7 @@ def test_constraints_ge_complex_gen(self, self.assertEqual(bounds, _bounds(exp)) def test_constraints_ge_override(self): - # Some constaints override other + # Some constraints override other a, b = shape_poly.symbolic_shape("a, b", constraints=("a >= 5", "b <= 16", "a >= 10", "b <= 10")) @@ -979,7 +979,7 @@ def test_constraint_eq_0(self): self.assertIs(d, 5) def test_constraints_eq_1(self): - # 
Some constaints override other + # Some constraints override other a, b, c = shape_poly.symbolic_shape("a, b, c", constraints=("max(a, b) == c",)) self.assertEqual(_bounds(core.max_dim(a, b) - c + 3), (3, 3)) diff --git a/tests/shard_alike_test.py b/tests/shard_alike_test.py index 2ad3e089e662..aeb218b478ad 100644 --- a/tests/shard_alike_test.py +++ b/tests/shard_alike_test.py @@ -19,7 +19,7 @@ from jax._src import test_util as jtu from jax.sharding import NamedSharding, PartitionSpec as P from jax.experimental.shard_alike import shard_alike -from jax.experimental.shard_map import shard_map +from jax._src.shard_map import shard_map jax.config.parse_flags_with_absl() jtu.request_cpu_devices(8) @@ -146,7 +146,7 @@ def g(x): @jax.jit def f(x): y = x @ x.T - s_out = shard_map(g, mesh, in_specs=P('x', 'y'), + s_out = shard_map(g, mesh=mesh, in_specs=P('x', 'y'), out_specs=P(None, 'y'))(y) z = s_out.T @ s_out return shard_alike(y, z) diff --git a/tests/shard_map_test.py b/tests/shard_map_test.py index f8d5a11e842f..c1297fc49ea1 100644 --- a/tests/shard_map_test.py +++ b/tests/shard_map_test.py @@ -36,18 +36,19 @@ from jax._src import config from jax._src import core from jax._src import prng +from jax._src.shard_map import shard_map, smap from jax._src import test_util as jtu from jax._src.lib.mlir.dialects import sdy from jax._src.util import safe_zip, safe_map, partition_list, merge_lists from jax._src.ad_checkpoint import saved_residuals -from jax._src.mesh import AxisType +from jax._src.mesh import AxisType, get_abstract_mesh from jax._src.interpreters import partial_eval as pe from jax._src import linear_util as lu from jax._src import tree_util +from jax.custom_derivatives import SymbolicZero import jax.numpy as jnp from jax.experimental.custom_partitioning import custom_partitioning -from jax.experimental.shard_map import shard_map config.parse_flags_with_absl() @@ -57,14 +58,14 @@ zip, unsafe_zip = safe_zip, zip # Helper for some tests. -def create_inputs(a_sharding, b_sharding): +def create_inputs(a_sharding, b_sharding, dtype=None): mesh = jtu.create_mesh((2, 2, 2), ('x', 'y', 'z')) b, e, f = 8, 8, 8 # pylint: disable=invalid-name m1 = jax.device_put( - jnp.arange(b * e).reshape((b, e)), + jnp.arange(b * e, dtype=dtype).reshape((b, e)), jax.sharding.NamedSharding(mesh, a_sharding)) m2 = jax.device_put( - jnp.arange(e * f).reshape((e, f)), + jnp.arange(e * f, dtype=dtype).reshape((e, f)), jax.sharding.NamedSharding(mesh, b_sharding)) return mesh, m1, m2 @@ -82,7 +83,7 @@ def identity(x): def fwd(a): c = shard_map( identity, - mesh, + mesh=mesh, in_specs=(P('z', ('x', 'y')),), out_specs=P('z', ('x', 'y')))(a) return c @@ -94,17 +95,13 @@ def test_all_gather(self): mesh, a, _ = create_inputs(P('z', ('x', 'y')), P(None, None)) assert a.addressable_data(0).shape == (4, 2) - # NOTE(mattjj): to use out_specs=P(None, ('x', 'y')), we need to use - # all_gather_invariant primitive, which differs in its output replication - # type compared to all_gather. 
@jax.jit @partial(shard_map, mesh=mesh, in_specs=(P('z', ('x', 'y')),), out_specs=P('z', ('x', 'y'))) def fwd(a): - return ( - lax.all_gather(a, 'z', axis=0, tiled=True), - lax.all_gather(a, ('x', 'y'), axis=-1, tiled=True), - ) + return (lax.all_gather(a, 'z', axis=0, tiled=True), + lax.all_gather(a, ('x', 'y'), axis=-1, tiled=True)) + c, d = fwd(a) self.assertEqual(c.addressable_data(0).shape, (8, 2)) for i, a_shard in enumerate(np.split(a, 4, axis=1)): @@ -113,6 +110,64 @@ def fwd(a): for i, a_shard in enumerate(np.split(a, 2, axis=0)): self.assertAllClose(d.addressable_data(i), a_shard) + def test_all_gather_invariant_basic(self): + mesh = jtu.create_mesh((4,), 'x') + arr = jnp.arange(8.) + + @jax.jit + @shard_map(mesh=mesh, in_specs=P('x'), out_specs=P()) + def f(a): + out = lax.all_gather_invariant(a, 'x', tiled=True) + self.assertEqual(out.aval.vma, set()) + return out + + out = f(arr) + self.assertArraysEqual(out, arr) + + jtu.check_grads(f, (arr,), order=2) + + def g(x): + return f(x).sum() + out = jax.jit(jax.grad(g))(arr) + self.assertEqual(out.shape, (8,)) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x'))) + + def test_all_gather_invariant_complex(self): + mesh, a, _ = create_inputs(P('z', ('x', 'y')), P(None, None), + dtype=np.float32) + assert a.addressable_data(0).shape == (4, 2) + + @jax.jit + @shard_map(mesh=mesh, in_specs=(P('z', ('x', 'y')),), + out_specs=(P(None, ('x', 'y')), P('z'))) + def f(a): + c = lax.all_gather_invariant(a, 'z', axis=0, tiled=True) + self.assertEqual(jax.typeof(c).vma, {'x', 'y'}) + d = lax.all_gather_invariant(a, ('x', 'y'), axis=-1, tiled=True) + self.assertEqual(jax.typeof(d).vma, {'z'}) + return c, d + + c, d = f(a) + + self.assertEqual(c.addressable_data(0).shape, (8, 2)) + for i, a_shard in enumerate(np.split(a, 4, axis=1)): + self.assertAllClose(c.addressable_data(2 * i), a_shard) + + self.assertEqual(d.addressable_data(0).shape, (4, 8)) + for i, a_shard in enumerate(np.split(a, 2, axis=0)): + self.assertAllClose(d.addressable_data(i), a_shard) + + def g(x): + return f(x)[0].sum() + + out1 = jax.jit(jax.grad(g))(a) + self.assertEqual(out1.shape, (8, 8)) + self.assertEqual(out1.sharding, NamedSharding(mesh, P('z', ('x', 'y')))) + + out2 = jax.grad(g)(a) + self.assertEqual(out2.shape, (8, 8)) + self.assertEqual(out2.sharding, NamedSharding(mesh, P('z', ('x', 'y')))) + def test_all_gather_with_axis_index_groups(self): mesh, a, _ = create_inputs(P('x', ('y', 'z')), P(None, None)) @@ -217,13 +272,104 @@ def test_collective_permute(self): shard_map, mesh=mesh, in_specs=(P('x', None),), out_specs=P('x', None) ) def fwd(a): - axis_size = lax.psum(1, 'x') + axis_size = lax.axis_size('x') perm = [(j, (j + 1) % axis_size) for j in range(axis_size)] return lax.ppermute(a, 'x', perm=perm) c = fwd(a) self.assertAllClose(c[1, :], a[0, :]) + @jtu.run_on_devices("gpu") + def test_psend_precv_basic_with_no_deadlock_cycle(self): + mesh = jtu.create_mesh((8,), 'x') + a = jax.device_put( + jnp.arange(8 * 8).reshape((8, 8)), + jax.sharding.NamedSharding(mesh, P('x', None))) + weights = jax.random.uniform( + key=jax.random.key(0), shape=(8, 1), dtype=jnp.float32) + + @jax.jit + @partial( + jax.shard_map, mesh=mesh, in_specs=(P('x', None),), out_specs=P('x', None) + ) + def fwd(a): + return_dtype_and_shape = jax.ShapeDtypeStruct(a.shape, a.dtype) + + # We define the "forward edge" to be the device-to-device communication + # originating from device 0 in increasing indices. 
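+      # The perm below forms a linear chain 0 -> 1 -> ... -> 7 with no
+      # wrap-around; the matching (7, 0) back edge is issued separately below.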
+ fwd_token = jax.lax.psend( + a, + axis_name="x", + perm=[(0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, 6), (6, 7)], + ) + + data = jax.lax.precv( + fwd_token, + out_shape=return_dtype_and_shape, + axis_name="x", + perm=[(0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, 6), (6, 7)], + ) + + # Here we use an optimization barrier to enforce an arbitrary ordering of + # collectives. This will make sure compute happens after recv on the forward + # edge, and by extension will make sure the send on the back edge happens + # after the recv on the forward edge. Without this optimization barrier, the + # send on the backward edge might slip before the forward edge recv ops are + # completed, and will cause a deadlock. + weights_, _ = ( + jax.lax.optimization_barrier( + (weights, data) + ) + ) + res = jnp.dot(weights_, data) + + # send the compute result back to the first device + bwd_token = jax.lax.psend( + res, + axis_name="x", + perm=[(7, 0)], + ) + + bwd_data = jax.lax.precv( + bwd_token, + out_shape=return_dtype_and_shape, + axis_name="x", + perm=[(7, 0)] + ) + return bwd_data + + c = fwd(a) + self.assertEqual(c.shape, a.shape) + + @jtu.run_on_devices("gpu") + def test_psend_precv_reverse(self): + mesh = jtu.create_mesh((8,), 'x') + a = jax.device_put( + jnp.arange(8 * 8).reshape((8, 8)), + jax.sharding.NamedSharding(mesh, P('x', None))) + @jax.jit + @partial( + jax.shard_map, mesh=mesh, in_specs=(P('x', None),), out_specs=P('x', None) + ) + def fwd(a): + return_dtype_and_shape = jax.ShapeDtypeStruct(a.shape, a.dtype) + dummy_data = jax.lax.precv( + jax.lax.create_token(), + out_shape=return_dtype_and_shape, + axis_name="x", + perm=[(0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, 6), (6, 7)], + ) + + _ = jax.lax.psend( + dummy_data, + axis_name="x", + perm=[(0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, 6), (6, 7)], + ) + return dummy_data + + c = fwd(a) + self.assertAllClose(c, jnp.zeros_like(a)) + def test_collective_permute_with_multiple_axis_names(self): mesh = jtu.create_mesh((2, 2, 2), ('x', 'y', 'z')) a = jax.device_put( @@ -239,8 +385,8 @@ def test_collective_permute_with_multiple_axis_names(self): out_specs=P('x', ('y', 'z')), ) def fwd(a): - xy_axis_size = lax.psum(1, ('x', 'y')) - yz_axis_size = lax.psum(1, ('y', 'z')) + xy_axis_size = lax.axis_size(('x', 'y')) + yz_axis_size = lax.axis_size(('y', 'z')) xy_perm = [(j, (j + 1) % xy_axis_size) for j in range(xy_axis_size)] yz_perm = [(j, (j + 1) % yz_axis_size) for j in range(yz_axis_size)] return ( @@ -289,6 +435,39 @@ def fwd(a): c = fwd(a) assert (c == jnp.reshape(a.T, (1, 64))).all() + @parameterized.named_parameters( + dict( + testcase_name='_partial_replicated', replicate_on_axes='x', + ), + dict( + testcase_name='_fully_replicated', + replicate_on_axes=('x', 'y'), + ), + ) + @jtu.run_on_devices("gpu") + def test_pbroadcast(self, replicate_on_axes): + mesh = jtu.create_mesh((4, 2), ('x', 'y')) + sharded_axes = set(mesh.axis_names) - set(replicate_on_axes) + sharded_axes = None if not sharded_axes else list(sharded_axes) + in_out_sharding = jax.sharding.NamedSharding(mesh, P(sharded_axes, None)) + a = jax.device_put(jnp.arange(16).reshape((4, 4)), in_out_sharding) + + @jax.jit + @partial( + shard_map, + mesh=mesh, + in_specs=(in_out_sharding.spec,), + out_specs=in_out_sharding.spec, + check_vma=False, + ) + def fwd(x): + axis_index = lax.axis_index(replicate_on_axes) + x = jnp.where(axis_index == 0, x + 1, x) + return lax.pbroadcast(x, replicate_on_axes, source=0) + + c = fwd(a) # Don't crash + self.assertAllClose(c, a + 1) + def 
test_all_to_all_with_axis_index_groups(self): mesh = jtu.create_mesh((4,), ('x',)) a = jax.device_put( @@ -367,7 +546,7 @@ def f(x): def test_jvp_basic(self): mesh = jtu.create_mesh((2, 2), ('x', 'y')) - g = shard_map(lambda x: jnp.sin(jnp.cos(x)), mesh, + g = shard_map(lambda x: jnp.sin(jnp.cos(x)), mesh=mesh, in_specs=(P('x', 'y'),), out_specs=P('x', 'y')) args = np.arange(4 * 4.).reshape(4, 4), jtu.check_grads(g, args, 2, ['fwd']) @@ -375,7 +554,7 @@ def test_jvp_basic(self): def test_linearize_basic(self): mesh = jtu.create_mesh((2, 2), ('x', 'y')) - g = shard_map(lambda x: jnp.sin(jnp.cos(x)), mesh, + g = shard_map(lambda x: jnp.sin(jnp.cos(x)), mesh=mesh, in_specs=(P('x', 'y'),), out_specs=P('x', 'y')) x = np.arange(4 * 4.).reshape(4, 4) @@ -389,7 +568,7 @@ def test_linearize_basic(self): def test_linearize_basic_repres(self): mesh = jtu.create_mesh((2, 2), ('x', 'y')) - g = shard_map(lambda x: jax.lax.sin(jax.lax.cos(x)), mesh, + g = shard_map(lambda x: jax.lax.sin(jax.lax.cos(x)), mesh=mesh, in_specs=(P('x',),), out_specs=P('x',)) x = np.arange(4.) @@ -403,7 +582,7 @@ def test_linearize_basic_repres(self): def test_linearize_basic_repres_jit(self): mesh = jtu.create_mesh((2, 2), ('x', 'y')) - g = shard_map(lambda x: jnp.sin(jnp.cos(x)), mesh, + g = shard_map(lambda x: jnp.sin(jnp.cos(x)), mesh=mesh, in_specs=(P('x',),), out_specs=P('x',)) x = np.arange(4.) @@ -422,7 +601,7 @@ def test_replication_checker_eager(self): def f(x): return 2 * x def g(x): - return shard_map(f, mesh, in_specs=(P('x', 'y'),), out_specs=P(None, 'y'))(x) + return shard_map(f, mesh=mesh, in_specs=(P('x', 'y'),), out_specs=P(None, 'y'))(x) with self.assertRaisesRegex(ValueError, 'statically inferred'): g(x) @@ -430,26 +609,24 @@ def g(x): def f2(x): return jax.lax.psum(x, 'x') def g2(x): - return shard_map(f2, mesh, in_specs=(P('x', 'y'),), out_specs=P(None, 'y'))(x) + return shard_map(f2, mesh=mesh, in_specs=(P('x', 'y'),), out_specs=P(None, 'y'))(x) _ = g2(x) # doesn't crash def test_replication_checker_jit(self): mesh = jtu.create_mesh((2, 2), ('x', 'y')) x = np.arange(8 * 8.).reshape(8, 8) - def f(x): - return 2 * x def g(x): - return shard_map(f, mesh, in_specs=(P('x', 'y'),), out_specs=P(None, 'y'))(x) + return shard_map(lambda x: x * 2, mesh=mesh, in_specs=P('x', 'y'), + out_specs=P(None, 'y'))(x) with self.assertRaisesRegex(ValueError, 'statically inferred'): jax.jit(g)(x) - def f2(x): - return jax.lax.psum(x, 'x') def g2(x): - return shard_map(f2, mesh, in_specs=(P('x', 'y'),), out_specs=P(None, 'y'))(x) - _ = jax.jit(g2)(x) # doesn't crash + return shard_map(lambda x: jax.lax.psum(x, 'x'), mesh=mesh, + in_specs=P('x', 'y'), out_specs=P(None, 'y'))(x) + jax.jit(g2)(x) # doesn't crash def test_process_env_traces(self): mesh = Mesh(np.array(jax.devices()[:4]), ('x',)) @@ -457,7 +634,7 @@ def test_process_env_traces(self): def g(x): y = (3. 
* x).sum() - z = shard_map(lambda x: 2 * x * y, mesh, + z = shard_map(lambda x: 2 * x * y, mesh=mesh, in_specs=(P('x'),), out_specs=P('x'))(np.arange(8.)) return z @@ -475,13 +652,14 @@ def f(x): return -x def g(x): - return shard_map(f, mesh, in_specs=(P('x', 'y'),), out_specs=P('x', 'y'))(x) + return shard_map(f, mesh=mesh, in_specs=(P('x', 'y'),), out_specs=P('x', 'y'))(x) y = g(x) self.assertAllClose(y, -x, check_dtypes=False) def test_outer_jit_detects_shard_map_mesh(self): mesh = jtu.create_mesh((2, 2), ('x', 'y')) - f = shard_map(lambda x: x.reshape(1, *x.shape), mesh, P(), P('x')) + f = shard_map(lambda x: x.reshape(1, *x.shape), mesh=mesh, in_specs=P(), + out_specs=P('x')) _ = jax.jit(f)(jnp.array(2.0)) # doesn't crash def test_vmap_basic(self): @@ -489,7 +667,7 @@ def test_vmap_basic(self): x = jnp.arange(8 * 8.).reshape(8, 8) def g(x): - return shard_map(lambda x: 2. * x, mesh, + return shard_map(lambda x: 2. * x, mesh=mesh, in_specs=P('y'), out_specs=P('y'))(x) y = jax.vmap(g)(x) self.assertAllClose(y, 2 * x, check_dtypes=False) @@ -499,7 +677,7 @@ def test_vmap_basic_axis_name(self): x = jnp.arange(8 * 8.).reshape(8, 8) def g(x): - return shard_map(lambda x: 2. * x, mesh, + return shard_map(lambda x: 2. * x, mesh=mesh, in_specs=P('y'), out_specs=P('y'))(x) y = jax.vmap(g, axis_name='i')(x) self.assertAllClose(y, 2 * x, check_dtypes=False) @@ -509,7 +687,7 @@ def test_vmap_basic_axis_name_reuse_mesh_name(self): x = jnp.arange(8 * 8.).reshape(8, 8) def g(x): - return shard_map(lambda x: 2. * x, mesh, + return shard_map(lambda x: 2. * x, mesh=mesh, in_specs=P('y'), out_specs=P('y'))(x) y = jax.vmap(g, axis_name='x')(x) # NOTE reuse same 'x' as on mesh self.assertAllClose(y, 2 * x, check_dtypes=False) @@ -588,6 +766,32 @@ def f(): x = f() self.assertAllClose(x, jnp.arange(4), check_dtypes=False) + def test_optimize_remat(self): + mesh = jtu.create_mesh((4,), 'x') + + @jax.custom_vjp + def f(x): + return jnp.tan(x) + + def f_fwd(x): + return jax.lax.psum(x, 'x'), (x,) + + def f_bwd(res, g): + x, = res + cos_x = jnp.cos(x) + return (cos_x * g,) + + f.defvjp(f_fwd, f_bwd, optimize_remat=True) + + @jax.jit + @jax.shard_map(mesh=mesh, in_specs=P(), out_specs=P()) + def temp(x): + out = jax.remat(f)(x) + out = out ** 2 + return out + + jax.grad(lambda x: temp(x).sum())(jnp.arange(4.)) + def test_remat_basic(self): # this tests remat-of-shmap mesh = Mesh(np.array(jax.devices()[:4]), ('x',)) @@ -662,29 +866,37 @@ def test_check_rep_false_doesnt_hit_rep_rules(self): prim.def_impl(lambda: []) prim.def_abstract_eval(lambda: []) - @partial(shard_map, mesh=mesh, in_specs=(), out_specs=None, check_rep=True) + @partial(shard_map, mesh=mesh, in_specs=(), out_specs=None, check_vma=True) def f(): prim.bind() - with self.assertRaises(NotImplementedError): - f() - with self.assertRaises(NotImplementedError): - jax.jit(f)() - - @partial(shard_map, mesh=mesh, in_specs=(), out_specs=None, check_rep=False) + @partial(shard_map, mesh=mesh, in_specs=(), out_specs=None, check_vma=False) def f2(): prim.bind() f2() jax.jit(f2)() - @partial(shard_map, mesh=mesh, in_specs=(), out_specs=None, check_rep=False) + @partial(shard_map, mesh=mesh, in_specs=(), out_specs=None, check_vma=False) def f3(): jax.jit(prim.bind)() f3() jax.jit(f3)() + def test_multiple_result_primitive_with_none_sharding(self): + # https://github.com/jax-ml/jax/issues/27673 + xs = jnp.arange(20).reshape(2, 10) + mesh = jtu.create_mesh((2,), ("i",)) + y = shard_map( + lambda x: jnp.split(x.squeeze(), 2), + mesh=mesh, + in_specs=(None,), + 
out_specs=P("i"), + )(xs) + expected = jnp.repeat(xs, 2, axis=0).reshape(2, 2, 10) + self.assertArraysEqual(y, expected) + def test_vmap_spmd_axis_name(self): mesh = jtu.create_mesh((2, 2), ('x', 'y')) @@ -695,16 +907,56 @@ def f(x): x = jnp.arange(4 * 4).reshape(4, 4) jaxpr = jax.make_jaxpr(jax.vmap(f, spmd_axis_name='y'))(x).jaxpr e, = jaxpr.eqns - self.assertIn('in_names', e.params) - self.assertEqual(e.params['in_names'], ({0: ('y',), 1: ('x',)},)) - self.assertIn('out_names', e.params) - self.assertEqual(e.params['out_names'], ({0: ('y',), 1: ('x',)},)) + self.assertIn('in_specs', e.params) + self.assertEqual(e.params['in_specs'], (P('y', 'x'),)) + self.assertIn('out_specs', e.params) + self.assertEqual(e.params['out_specs'], (P('y', 'x'),)) + + def test_vmap_explicit_mesh_axis(self): + mesh = jtu.create_mesh( + (1, 2, 2), ('z', 'x', 'y'), axis_types=(AxisType.Explicit,) * 3) + + @shard_map(mesh=mesh, in_specs=P('y'), out_specs=P('y')) + def f(x): + return x + + x = jnp.arange(4 * 4).reshape(4, 4) + s = NamedSharding(mesh, P(('z', 'x'), 'y')) + x = jax.device_put(x, s) + + f = jax.jit(jax.vmap(f)) + out = f(x) + self.assertEqual(out.sharding, s) + + def test_vmap_explicit_mesh_axis_error(self): + mesh = jtu.create_mesh((2, 2), ('x', 'y'), + axis_types=(AxisType.Explicit,) * 2) + + @shard_map(mesh=mesh, in_specs=P('x'), out_specs=P('x')) + def f(x): + return x + + x = jnp.arange(4 * 4).reshape(4, 4) + s = NamedSharding(mesh, P('x', 'y')) + x = jax.device_put(x, s) + + f = jax.jit(jax.vmap(f)) + with self.assertRaisesRegex( + ValueError, "vmapped away explicit mesh axis cannot appear"): + f(x) + + f = jax.jit(jax.vmap(f, spmd_axis_name='y')) + with self.assertRaisesRegex( + ValueError, + 'Only one of spmd_axis_name or arrays sharded on `Explicit` mesh axis' + ' type is allowed'): + f(x) def test_vmap_of_grad_spmd_axis_name(self): mesh = jtu.create_mesh((2, 2), ('x', 'y')) @partial( - shard_map, mesh=mesh, in_specs=P('y'), out_specs=P(), check_rep=False + shard_map, mesh=mesh, in_specs=P('y'), out_specs=P(), check_vma=False ) def f(x): return jnp.sin(jnp.sum(x)) @@ -730,10 +982,10 @@ def f(x): x = jnp.arange(4 * 4).reshape(4, 4) jaxpr = jax.make_jaxpr(jax.vmap(f, spmd_axis_name=('x', 'y')))(x).jaxpr e, = jaxpr.eqns - self.assertIn('in_names', e.params) - self.assertEqual(e.params['in_names'], ({0: ('x', 'y',)},)) - self.assertIn('out_names', e.params) - self.assertEqual(e.params['out_names'], ({0: ('x', 'y',)},)) + self.assertIn('in_specs', e.params) + self.assertEqual(e.params['in_specs'][0], P(('x', 'y'))) + self.assertIn('out_specs', e.params) + self.assertEqual(e.params['out_specs'][0], P(('x', 'y'))) def test_nested_vmap_with_capture_spmd_axis_name(self): self.skipTest('https://github.com/jax-ml/jax/issues/23476') @@ -861,8 +1113,6 @@ def test_shmap_abstract_mesh_errors(self): @jtu.run_on_devices('cpu', 'gpu', 'tpu') @jtu.thread_unsafe_test() def test_debug_print_jit(self, jit): - if config.use_shardy_partitioner.value: - self.skipTest('TODO(b/384938613): Failing under shardy') mesh = Mesh(jax.devices(), ('i',)) @partial(shard_map, mesh=mesh, in_specs=P('i'), out_specs=P('i')) @@ -884,6 +1134,39 @@ def f(x): for i in range(len(jax.devices())): self.assertIn(f'instance {i} has value', output()) + def test_psum_transpose_non_zero_cts(self): + mesh = jtu.create_mesh((8,), 'x') + @shard_map(mesh=mesh, in_specs=P('x'), out_specs=(P('x'), P())) + def f1(x_block): + return x_block, jax.lax.psum(x_block, axis_name='x') + + x1 = jnp.arange(16.) 
+ f1(x1) # doesn't crash + + def f2(x_block): + y, _ = f1(x_block) + return y.sum() + + jax.jit(jax.grad(f2))(x1) # doesn't crash + jax.grad(f2)(x1) # doesn't crash + + @jtu.run_on_devices('cpu', 'gpu', 'tpu') + @jtu.thread_unsafe_test() + def test_debug_print_jit_partial_auto(self): + mesh = jtu.create_mesh((2,2), ('x', 'y')) + + @partial(shard_map, mesh=mesh, in_specs=P('x'), out_specs=P('x'), + axis_names=frozenset({'x'})) + def f(x): + idx = jax.lax.axis_index('x') + jax.debug.print("instance {i} has value x={x}", i=idx, x=x) + y = jnp.cos(x) + return y + + f = jax.jit(f) + x = jnp.arange(2 * len(jax.devices())) + f(x) # don't crash! + def test_debug_print_eager(self): mesh = Mesh(jax.devices(), ('i',)) @@ -930,9 +1213,22 @@ def f(key): dtype=jnp.int32) pspec = P('x') if config.enable_custom_prng.value else P('x', None) - g = shard_map(f, mesh, in_specs=(pspec,), out_specs=pspec) + g = shard_map(f, mesh=mesh, in_specs=(pspec,), out_specs=pspec) _ = g(sharded_rng) # don't crash! + def test_vma_out_specs_error_check(self): + mesh = jtu.create_mesh((2, 2, 2), ('x', 'y', 'z')) + @shard_map(mesh=mesh, in_specs=P('x', 'y', 'z'), out_specs=P('x')) + def f(x): + return x * 2 + + with self.assertRaisesRegex( + ValueError, + r".*out_specs is PartitionSpec\('x',\) which implies that the.*" + r' output value is only varying across mesh axes \{x\} and not \{y,z\},' + r' but it was inferred to be possibly varying over \{x,y,z\}.*'): + f(np.arange(16).reshape(4, 2, 2)) + def test_functools_partial_rank_error(self): mesh = jtu.create_mesh((4,), ('x',)) @@ -940,7 +1236,7 @@ def test_functools_partial_rank_error(self): def f(x): return x - g = shard_map(f, mesh, in_specs=(P('x', None),), out_specs=P('x',)) + g = shard_map(f, mesh=mesh, in_specs=(P('x', None),), out_specs=P('x',)) x = jnp.arange(4) with self.assertRaises(ValueError): g(x) @@ -950,14 +1246,14 @@ def test_in_specs_none_error(self): def f(x): return x - with self.assertRaisesRegex(TypeError, "but it was None"): - shard_map(f, mesh, in_specs=None, out_specs=P())(3.) + with self.assertRaisesRegex(TypeError, "but it was `None`"): + shard_map(f, mesh=mesh, in_specs=None, out_specs=P())(3.) # TODO(mattjj): enable this test once we fix the tree_map(f, None, 3.0) bug # with self.assertRaises(TypeError): - # shard_map(f, mesh, in_specs=(None,), out_specs=P())(3.) + # shard_map(f, mesh=mesh, in_specs=(None,), out_specs=P())(3.) - shard_map(f, mesh, in_specs=P(), out_specs=P())(3.) # doesn't crash + shard_map(f, mesh=mesh, in_specs=P(), out_specs=P())(3.) 
# doesn't crash def test_scan_rep_rule(self): mesh = jtu.create_mesh((2, 2,), ('x', 'y')) @@ -967,24 +1263,25 @@ def f(x, y, z): def body(c, _): c, *cs = c return (*cs, c), None + x = lax.pvary(x, ('x', 'y')) + y = lax.pvary(y, 'y') out, _ = jax.lax.scan(body, (x, y, z), None, length=3) return [jnp.expand_dims(a, 0) for a in out] x = jnp.arange(4) - # doesn't crash, because out_spec assumes no replication (and there is none) - shard_map(f, mesh, in_specs=(P(None), P('x'), P(('x', 'y'))), + shard_map(f, mesh=mesh, in_specs=(P(None), P('x'), P(('x', 'y'))), out_specs=P(('x', 'y')))(x, x, x) # does crash, because output incorrectly promises replication with self.assertRaisesRegex(ValueError, "require replication"): - shard_map(f, mesh, in_specs=(P(None), P('x'), P(('x', 'y'))), + shard_map(f, mesh=mesh, in_specs=(P(None), P('x'), P(('x', 'y'))), out_specs=P('x'))(x, x, x) with self.assertRaisesRegex(ValueError, "require replication"): - shard_map(f, mesh, in_specs=(P(None), P('x'), P(('x', 'y'))), + shard_map(f, mesh=mesh, in_specs=(P(None), P('x'), P(('x', 'y'))), out_specs=P('y'))(x, x, x) with self.assertRaisesRegex(ValueError, "require replication"): - shard_map(f, mesh, in_specs=(P(None), P('x'), P(('x', 'y'))), + shard_map(f, mesh=mesh, in_specs=(P(None), P('x'), P(('x', 'y'))), out_specs=P(None))(x, x, x) def g(x, y, z): @@ -995,12 +1292,65 @@ def body(c, _): return [jnp.expand_dims(a, 0) for a in out] # doesn't crash, because everything matches - shard_map(g, mesh, in_specs=(P(None), P('x'), P(('x', 'y'))), + shard_map(g, mesh=mesh, in_specs=(P(None), P('x'), P(('x', 'y'))), out_specs=[P(None), P('x'), P(('x', 'y'))])(x, x, x) # does crash, because the second guy is wrong with self.assertRaisesRegex(ValueError, "require replication"): - shard_map(g, mesh, in_specs=(P(None), P('x'), P(('x', 'y'))), + shard_map(g, mesh=mesh, in_specs=(P(None), P('x'), P(('x', 'y'))), + out_specs=[P(None), P(None), P(('x', 'y'))])(x, x, x) + + def test_while_rep_rule(self): + mesh = jtu.create_mesh((2, 2,), ('x', 'y')) + + def f(x, y, z): + x, y, z = x.sum(), y.sum(), z.sum() + def cond(c): + i, *_ = c + return i < 5 + def body(c): + i, c, *cs = c + return (i + 1, *cs, c) + x = lax.pvary(x, ('x', 'y')) + y = lax.pvary(y, 'y') + _, *out = jax.lax.while_loop(cond, body, (0, x, y, z)) + return [jnp.expand_dims(a, 0) for a in out] + + x = jnp.arange(4) + + # doesn't crash, because out_spec assumes no replication (and there is none) + shard_map(f, mesh=mesh, in_specs=(P(None), P('x'), P(('x', 'y'))), + out_specs=P(('x', 'y')))(x, x, x) + + # does crash, because output incorrectly promises replication + with self.assertRaisesRegex(ValueError, "require replication"): + shard_map(f, mesh=mesh, in_specs=(P(None), P('x'), P(('x', 'y'))), + out_specs=P('x'))(x, x, x) + with self.assertRaisesRegex(ValueError, "require replication"): + shard_map(f, mesh=mesh, in_specs=(P(None), P('x'), P(('x', 'y'))), + out_specs=P('y'))(x, x, x) + with self.assertRaisesRegex(ValueError, "require replication"): + shard_map(f, mesh=mesh, in_specs=(P(None), P('x'), P(('x', 'y'))), + out_specs=P(None))(x, x, x) + + def g(x, y, z): + x, y, z = x.sum(), y.sum(), z.sum() + def cond(c): + i, *_ = c + return i < 1 + def body(c): + i, *cs = c + return (i + 1, *cs) + _, *out = jax.lax.while_loop(cond, body, (0, x, y, z)) + return [jnp.expand_dims(a, 0) for a in out] + + # doesn't crash, because everything matches + shard_map(g, mesh=mesh, in_specs=(P(None), P('x'), P(('x', 'y'))), + out_specs=[P(None), P('x'), P(('x', 'y'))])(x, x, x) + + # 
does crash, because the second guy is wrong + with self.assertRaisesRegex(ValueError, "require replication"): + shard_map(g, mesh=mesh, in_specs=(P(None), P('x'), P(('x', 'y'))), out_specs=[P(None), P(None), P(('x', 'y'))])(x, x, x) def test_cond_rep_rule(self): @@ -1014,20 +1364,22 @@ def false_fun(x, y): return x + 1 return jax.lax.cond(True, true_fn, false_fun, x, y) - shard_map(f, mesh, in_specs=(P('x'), P('y')), out_specs=P('x'))(x, x) + shard_map(f, mesh=mesh, in_specs=(P('x'), P('y')), out_specs=P('x'))(x, x) + with self.assertRaisesRegex(ValueError, "require replication"): - shard_map(f, mesh, in_specs=(P('x'), P('y')), out_specs=P(None))(x, x) + shard_map(f, mesh=mesh, in_specs=(P('x'), P('y')), out_specs=P(None))(x, x) def f(x, y): def true_fn(x, y): - return x + return lax.pvary(x, 'y') def false_fun(x, y): - return y + return lax.pvary(y, 'x') return jax.lax.cond(True, true_fn, false_fun, x, y) - shard_map(f, mesh, in_specs=(P('x'), P('y')), out_specs=P(('x', 'y')))(x, x) + shard_map(f, mesh=mesh, in_specs=(P('x'), P('y')), out_specs=P(('x', 'y')))(x, x) + with self.assertRaisesRegex(ValueError, "require replication"): - shard_map(f, mesh, in_specs=(P('x'), P('y')), out_specs=P('x'))(x, x) + shard_map(f, mesh=mesh, in_specs=(P('x'), P('y')), out_specs=P('x'))(x, x) def f(x, y): def true_fn(x, y): @@ -1036,9 +1388,10 @@ def false_fun(x, y): return x + 1 return jax.lax.cond(jnp.any(x > 0), true_fn, false_fun, x, y) - shard_map(f, mesh, in_specs=(P('x'), P('y')), out_specs=P('x'))(x, x) + shard_map(f, mesh=mesh, in_specs=(P('x'), P('y')), out_specs=P('x'))(x, x) + with self.assertRaisesRegex(ValueError, "require replication"): - shard_map(f, mesh, in_specs=(P('x'), P('y')), out_specs=P(None))(x, x) + shard_map(f, mesh=mesh, in_specs=(P('x'), P('y')), out_specs=P(None))(x, x) def f(x, y): def true_fn(x, y): @@ -1047,9 +1400,8 @@ def false_fun(x, y): return x + 1 return jax.lax.cond(jnp.any(y > 0), true_fn, false_fun, x, y) - shard_map(f, mesh, in_specs=(P('x'), P('y')), out_specs=P(('x', 'y')))(x, x) - with self.assertRaisesRegex(ValueError, "require replication"): - shard_map(f, mesh, in_specs=(P('x'), P('y')), out_specs=P('x'))(x, x) + shard_map(f, mesh=mesh, in_specs=(P('x'), P('y')), out_specs=P(('x', 'y')))(x, x) + shard_map(f, mesh=mesh, in_specs=(P('x'), P('y')), out_specs=P('x'))(x, x) # https://github.com/jax-ml/jax/issues/24418 def f(a): @@ -1058,7 +1410,7 @@ def f(a): mesh = jtu.create_mesh((2,), ('x',)) a = jnp.array([True, False]) - shard_map(f, mesh, in_specs=P('x'), out_specs=P('x'))(a) + shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P('x'))(a) def test_switch_rep_rule(self): mesh = jtu.create_mesh((2, 2,), ('x', 'y')) @@ -1068,7 +1420,7 @@ def f(n, x, y): return jax.lax.switch( n, [lambda x, _: x, lambda x, _: x + 1, lambda x, _: x + 2], x, y) - shard_map(f, mesh, in_specs=(P(), P('x'), P('y')), out_specs=P('x'))(1, x, x) + shard_map(f, mesh=mesh, in_specs=(P(), P('x'), P('y')), out_specs=P('x'))(1, x, x) def test_eager_custom_jvp_basic(self): @jax.custom_jvp @@ -1081,7 +1433,7 @@ def foo_jvp(primals, tangents): return foo(x), 3. * x_dot mesh = jtu.create_mesh((4,), ('x',)) - g = shard_map(foo, mesh, in_specs=(P('x'),), out_specs=P('x')) + g = shard_map(foo, mesh=mesh, in_specs=(P('x'),), out_specs=P('x')) y, x_bar = jax.value_and_grad(lambda x: g(x).sum())(jnp.arange(4.)) self.assertAllClose(y, (2. * jnp.arange(4.)).sum()) self.assertAllClose(x_bar, 3. 
* jnp.ones(4), check_dtypes=False) @@ -1100,7 +1452,7 @@ def foo_bwd(_, y_bar): foo.defvjp(foo_fwd, foo_bwd) mesh = jtu.create_mesh((4,), ('x',)) - g = shard_map(foo, mesh, in_specs=(P('x'),), out_specs=P('x')) + g = shard_map(foo, mesh=mesh, in_specs=(P('x'),), out_specs=P('x')) y, x_bar = jax.value_and_grad(lambda x: g(x).sum())(jnp.arange(4.)) self.assertAllClose(y, (2. * jnp.arange(4.)).sum()) self.assertAllClose(x_bar, 3. * jnp.ones(4), check_dtypes=False) @@ -1114,7 +1466,7 @@ def foo(): foo = jax.jit(foo) mesh = jtu.create_mesh((4,), ('x',)) - ans = shard_map(foo, mesh, in_specs=(), out_specs=P('x'))() + ans = shard_map(foo, mesh=mesh, in_specs=(), out_specs=P('x'))() expected = jnp.arange(4.) self.assertAllClose(ans, expected, check_dtypes=False) @@ -1130,7 +1482,7 @@ def foo(): foo = jax.jit(foo) mesh = jtu.create_mesh((4, 2), ('i', 'j')) - ans1, ans2, ans3 = shard_map(foo, mesh, in_specs=(), + ans1, ans2, ans3 = shard_map(foo, mesh=mesh, in_specs=(), out_specs=P('i', 'j'))() expected1 = jnp.arange(4.)[:, None] + jnp.zeros((4, 2)) expected2 = jnp.arange(2.)[None, :] + jnp.zeros((4, 2)) @@ -1197,7 +1549,7 @@ def test_key_array_with_replicated_last_tile_dim(self): def f(rng): @partial(shard_map, mesh=mesh, in_specs=P('i'), out_specs=P('i'), - check_rep=False) + check_vma=False) def g(rng): return jnp.array([jax.random.normal(rng[0])]) return g(jax.random.split(rng, 4)) @@ -1236,7 +1588,8 @@ def test_returned_out_sharding(self): mesh = jtu.create_mesh((1, 2), ('x', 'y')) s = NamedSharding(mesh, P('x', 'y')) inp = jax.device_put(jnp.zeros((2, 2)), s) - out = shard_map(lambda x: x, mesh, P('x', 'y'), P('x', 'y'))(inp) + out = shard_map(lambda x: x, mesh=mesh, in_specs=P('x', 'y'), + out_specs=P('x', 'y'))(inp) self.assertEqual(out.sharding, s) self.assertArraysEqual(out, inp) @@ -1304,9 +1657,9 @@ def test_sharding_metadata_in_hlo_attrs(self): def foo(x): x = jnp.sin(x) - x = shard_map(lambda x: jnp.cos(x * y), mesh, + x = shard_map(lambda x: jnp.cos(x * y), mesh=mesh, in_specs=P('i'), out_specs=P('i'))(x) - x = shard_map(lambda x: jnp.cos(x * y), mesh, + x = shard_map(lambda x: jnp.cos(x * y), mesh=mesh, in_specs=P('i'), out_specs=P('i'))(x) return x @@ -1337,7 +1690,7 @@ def f(x): x)[0] * x mesh = jtu.create_mesh((4,), ('x',)) - g = shard_map(f, mesh, in_specs=(P('x'),), out_specs=P('x')) + g = shard_map(f, mesh=mesh, in_specs=(P('x'),), out_specs=P('x')) x = jnp.arange(4.) y = jax.jit(g)(x) # eager requires shmap to have ShardMapTrace.process_call self.assertAllClose(y, 2 * x * x, check_dtypes=True) @@ -1371,7 +1724,7 @@ def foo_jvp(primals, tangents): return foo(x), 2. 
* x_dot mesh = jtu.create_mesh((4,), ('x',)) - g = shard_map(lambda x: foo(x) * x, mesh, + g = shard_map(lambda x: foo(x) * x, mesh=mesh, in_specs=(P('x'),), out_specs=P('x')) if jit: g = jax.jit(g) @@ -1399,7 +1752,7 @@ def foo_bwd(_, y_bar): foo.defvjp(foo_fwd, foo_bwd) mesh = jtu.create_mesh((4,), ('x',)) - g = shard_map(lambda x: foo(x) * x, mesh, + g = shard_map(lambda x: foo(x) * x, mesh=mesh, in_specs=(P('x'),), out_specs=P('x')) if jit: g = jax.jit(g) @@ -1427,7 +1780,7 @@ def foo_bwd(_, y_bar): foo.defvjp(foo_fwd, foo_bwd) mesh = jtu.create_mesh((4,), ('x',)) - g = shard_map(lambda x: foo(x) * x, mesh, + g = shard_map(lambda x: foo(x) * x, mesh=mesh, in_specs=(P('x'),), out_specs=P('x')) if jit: g = jax.jit(g) @@ -1456,38 +1809,6 @@ def f(x): y = shard_f(x) self.assertEqual(x_spec, y.sharding.spec) - @parameterized.parameters([True, False]) - def test_rewrite_process_custom_vjp_call_match_less_replicated(self, jit): - @jax.custom_vjp - def foo(x, y): - del y - return 2. * x - - def foo_fwd(x, y): - return foo(x, y), y - - def foo_bwd(y, _): - return y, None # diff! x_bar less replicated than primal/tangent - - foo.defvjp(foo_fwd, foo_bwd) - - mesh = jtu.create_mesh((4,), ('x',)) - g = shard_map(lambda x, y: foo(x, y) * y, mesh, - in_specs=(P(), P('x')), out_specs=P('x')) - if jit: - g = jax.jit(g) - - x = jnp.arange(4.) - y = jnp.arange(4 * 4.) - - z = g(x, y) - self.assertAllClose(z, 2 * jnp.tile(x, (4,)) * y, check_dtypes=False) - - z_, x_bar = jax.value_and_grad(lambda x, y: g(x, y).sum())(x, y) - self.assertAllClose(z.sum(), z_, check_dtypes=False) - self.assertAllClose(x_bar, jnp.arange(16).reshape(4, 4).sum(0), - check_dtypes=False) - @parameterized.parameters([True, False]) def test_rewrite_custom_vjp_call_jaxpr(self, jit): @jax.custom_vjp @@ -1507,7 +1828,7 @@ def foo_scan(x): return y mesh = jtu.create_mesh((4,), ('x',)) - g = shard_map(lambda x: foo_scan(x) * x, mesh, + g = shard_map(lambda x: foo_scan(x) * x, mesh=mesh, in_specs=(P('x'),), out_specs=P('x')) if jit: g = jax.jit(g) @@ -1555,7 +1876,7 @@ def f(x): jaxpr = jax.make_jaxpr(jax.vjp(f, jnp.arange(1.))[1])(jnp.arange(4.)) e, = jaxpr.jaxpr.eqns e2, = e.params['jaxpr'].eqns - self.assertEqual(str(e2.primitive), 'psum2') + self.assertEqual(str(e2.primitive), 'psum_invariant') self.assertEqual(e2.params['axes'], ('x',)) def test_fanin_psum_transposes_to_fanout(self): @@ -1568,7 +1889,7 @@ def f(x): jaxpr = jax.make_jaxpr(jax.vjp(f, jnp.arange(4.))[1])(jnp.array([1.])) e, = jaxpr.jaxpr.eqns e1, = e.params['jaxpr'].eqns - self.assertEqual(str(e1.primitive), 'pbroadcast') + self.assertEqual(str(e1.primitive), 'pvary') def test_psum_with_implicit_fanout_self_transposes(self): mesh = jtu.create_mesh((4,), ('x',)) @@ -1580,8 +1901,8 @@ def f(x): jaxpr = jax.make_jaxpr(jax.vjp(f, jnp.arange(4.))[1])(jnp.arange(4.)) e, = jaxpr.jaxpr.eqns e1, e2 = e.params['jaxpr'].eqns - self.assertEqual(str(e1.primitive), 'psum2') - self.assertEqual(str(e2.primitive), 'pbroadcast') + self.assertEqual(str(e1.primitive), 'psum_invariant') + self.assertEqual(str(e2.primitive), 'pvary') def test_transpose_float0(self): mesh = jtu.create_mesh((4,), ('x',)) @@ -1612,7 +1933,7 @@ def g_bwd(vjp_fn, result): def f_shmapped(x, y): return jax.lax.psum(f(x, y).sum(), axis_name=('x')) - @partial(shard_map, mesh=mesh, check_rep=False, + @partial(shard_map, mesh=mesh, check_vma=False, in_specs=P('x'), out_specs=(P('x'), P())) def f_shmapped2(x, y): return g(x, y) @@ -1632,6 +1953,18 @@ def example(x, y): dx, dy = example(x, y) self.assertEqual(dy.dtype, 
jax.dtypes.float0) + def test_pvary(self): + mesh = jtu.create_mesh((4,), ('x',)) + + @partial(shard_map, mesh=mesh, in_specs=P(), out_specs=P('x')) + def f(x): + y = jax.lax.pvary(x, 'x') + self.assertEqual(y.aval.vma, {'x'}) + return y + + f(jnp.arange(8.)) + jax.grad(lambda x: f(x).sum())(jnp.arange(8.)) + def test_rewrite_binops(self): mesh = jtu.create_mesh((4,), ('x',)) @@ -1642,7 +1975,7 @@ def f(x, y): jaxpr = jax.make_jaxpr(f)(jnp.arange(1.), jnp.arange(4.)) e, = jaxpr.jaxpr.eqns e = e.params['jaxpr'].eqns[0] - self.assertEqual(e.primitive.name, 'pbroadcast') + self.assertEqual(e.primitive.name, 'pvary') self.assertEqual(e.params['axes'], ('x',)) def test_rewrite_scan(self): @@ -1650,16 +1983,17 @@ def test_rewrite_scan(self): @partial(shard_map, mesh=mesh, in_specs=P('x'), out_specs=P('x')) def f(x): - x, _ = jax.lax.scan(lambda x, _: (jax.lax.psum(x, 'x'), None), x, None, - length=2) + def g(x, _): + return lax.pvary(jax.lax.psum(x, 'x'), 'x'), None + x, _ = jax.lax.scan(g, x, None, length=2) return x jaxpr = jax.make_jaxpr(f)(jnp.arange(4.)) e, = jaxpr.jaxpr.eqns e, = e.params['jaxpr'].eqns e1, e2 = e.params['jaxpr'].eqns - self.assertEqual(e1.primitive.name, 'psum2') - self.assertEqual(e2.primitive.name, 'pbroadcast') + self.assertEqual(e1.primitive.name, 'psum_invariant') + self.assertEqual(e2.primitive.name, 'pvary') def test_check_rep_false_grads(self): if jtu.is_device_tpu(5, 'e'): @@ -1673,7 +2007,7 @@ def f(q, k, v): def body(q, k, v): return q * k[None, :] + v[None, :] - out = shard_map(body, mesh, check_rep=False, + out = shard_map(body, mesh=mesh, check_vma=False, in_specs=(q_spec, kv_spec, kv_spec,), out_specs=q_spec)(q, k, v) return out.sum() @@ -1698,7 +2032,7 @@ def foo(x): @partial(jax.remat, policy=lambda *args, **kwargs: True) def bar(x): return shard_map(foo, mesh=Mesh(jax.devices(), ['x']), in_specs=(P('x'),), - out_specs=P('x'), check_rep=False)(x) + out_specs=P('x'), check_vma=False)(x) jax.jit(jax.grad(lambda x: bar(x).sum()))(jnp.arange(8.)) # doesn't crash @@ -1706,7 +2040,7 @@ def bar(x): def test_res_forwarding_optimization(self, jit, remat): mesh = jtu.create_mesh((4,), ('i',)) - @partial(shard_map, mesh=mesh, in_specs=P('i'), out_specs=P('i')) + @shard_map(mesh=mesh, in_specs=P('i'), out_specs=P('i')) def f(x): return jax.lax.exp(x) if jit: @@ -1719,7 +2053,7 @@ def f(x): x = jnp.arange(16.) jaxpr_ = jax.make_jaxpr(jax.grad(g))(x) jaxpr, _ = pe.dce_jaxpr(jaxpr_.jaxpr, [True] * len(jaxpr_.out_avals)) - e1, _, e2 = jaxpr.eqns + e1, *_, e2 = jaxpr.eqns self.assertLen(e1.outvars, 1) # only primal output self.assertLen(e2.invars, 2) # res and cotangent inputs self.assertEqual(sum(e1.outvars[0] is v for v in e2.invars), 1) @@ -1729,7 +2063,7 @@ def test_res_forwarding_optimization_complex(self, jit, remat): # like the above test, but a different function `f` mesh = jtu.create_mesh((4,), ('i',)) - @partial(shard_map, mesh=mesh, in_specs=P('i'), out_specs=P('i')) + @shard_map(mesh=mesh, in_specs=P('i'), out_specs=P('i')) def f(x): return jax.lax.exp(x.sum()) + x, jax.lax.exp(x) if jit: @@ -1742,7 +2076,7 @@ def f(x): x = jnp.arange(16.) 
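# Reader aid, not part of the original patch: the residual-forwarding checks in this
# test (and the simpler one above) now unpack the grad jaxpr as
# `e1, *_, e2 = jaxpr.eqns`, tolerating extra equations between the forward and
# backward shard_map calls; the assertions still appear to verify that a forward
# output is reused as a residual by the backward call exactly once.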
jaxpr_ = jax.make_jaxpr(jax.grad(g))(x) jaxpr, _ = pe.dce_jaxpr(jaxpr_.jaxpr, [True] * len(jaxpr_.out_avals)) - e1, _, e2 = jaxpr.eqns + e1, *_, e2 = jaxpr.eqns self.assertLen(e1.outvars, 2) # one primal and one res output self.assertLen(e2.invars, 4) # two res and two cotangent inputs self.assertEqual(sum(e1.outvars[-1] is v for v in e2.invars), 1) @@ -1752,7 +2086,7 @@ def test_check_rep_failure_inside_rule(self, jit): mesh = jtu.create_mesh((4,), ('i',)) def loss(w, x): - @partial(shard_map, mesh=mesh, in_specs=P('i'), out_specs=P()) + @shard_map(mesh=mesh, in_specs=P('i'), out_specs=P()) def f(x): return jax.lax.psum(((w * x) ** 2).sum(), 'i') return f(x) @@ -1768,8 +2102,8 @@ def test_conv_general_dilated(self): dot = partial(lax.conv_general_dilated, window_strides=(), padding='VALID', dimension_numbers=('NC', 'IO', 'NC')) - @partial(shard_map, mesh=mesh, in_specs=(P(None, 'i'), P('i', None)), - out_specs=P(None, None)) + @shard_map(mesh=mesh, in_specs=(P(None, 'i'), P('i', None)), + out_specs=P(None, None)) def f(x, y): return lax.psum(dot(x, y), 'i') @@ -1809,7 +2143,7 @@ def f(*args): return args[0] @ args[1] shard_f = shard_map( - f, mesh, in_specs=(P('x', 'y', None), P('x', 'y', None)), out_specs=P('x', 'y')) + f, mesh=mesh, in_specs=(P('x', 'y', None), P('x', 'y', None)), out_specs=P('x', 'y')) with self.assertRaisesRegex(ValueError, "shard_map applied to the function 'f'"): shard_f(jnp.ones((8, 8)), jnp.ones((8, 8))) @@ -1852,7 +2186,8 @@ def test_approx_top_k(self): mesh = Mesh(np.array(jax.devices()[:2]), ('i',)) x = jnp.array([3.0, 1.0, 4.0, 2.0]) - _ = shard_map(lambda x: lax.approx_max_k(x, 2), mesh, P('i'), P('i'))(x) + _ = shard_map(lambda x: lax.approx_max_k(x, 2), mesh=mesh, in_specs=P('i'), + out_specs=P('i'))(x) def test_disable_jit(self): mesh = Mesh(np.array(jax.devices()[:2]), ('i',)) @@ -1898,10 +2233,10 @@ def g(x): @jax.jit def f(x): - x = shard_map(g, mesh, + x = shard_map(g, mesh=mesh, in_specs=P('i', None), out_specs=P('i', None), - auto=frozenset({'j'}))(x) + axis_names=frozenset({'i'}))(x) return jax.lax.with_sharding_constraint( x, jax.sharding.NamedSharding(mesh, P('i', 'j'))) @@ -1909,8 +2244,8 @@ def f(x): v = jax.device_put(v, jax.sharding.NamedSharding(mesh, P('i', 'j'))) if config.use_shardy_partitioner.value: self.assertIn( - 'in_shardings=[<@mesh, [{"i"}, {?}]>]' - ' out_shardings=[<@mesh, [{"i"}, {?}]>] manual_axes={"i"}', + 'in_shardings=[<@mesh, [{"i", ?}, {?}]>]' + ' out_shardings=[<@mesh, [{"i", ?}, {?}]>] manual_axes={"i"}', f.lower(v).as_text(), ) else: @@ -1935,10 +2270,10 @@ def g(x): @jax.jit def f(x): - x = shard_map(g, mesh, + x = shard_map(g, mesh=mesh, in_specs=P('i', None), out_specs=P('i', None), - auto=frozenset({'j'}))(x) + axis_names=frozenset({'i'}))(x) self.assertEqual(x.aval.sharding.spec, P('i', 'j')) return x @@ -1949,7 +2284,7 @@ def f(x): self.assertEqual(out.sharding, NamedSharding(mesh, P('i', 'j'))) self.assertAllClose(v * v, out, check_dtypes=False) - @jtu.with_user_mesh((2, 2), ('i', 'j')) + @jtu.with_explicit_mesh((2, 2), ('i', 'j')) def test_partial_auto_explicit(self, mesh): def g(x): self.assertDictEqual(x.aval.sharding.mesh._axis_types_dict, @@ -1961,10 +2296,7 @@ def g(x): @jax.jit def f(x): - x = shard_map(g, mesh, - in_specs=P('i', None), - out_specs=P('i', None), - auto=frozenset({'j'}))(x) + x = jax.shard_map(g, out_specs=P('i', None), axis_names=frozenset({'i'}))(x) self.assertEqual(x.aval.sharding.spec, P('i', 'j')) return x @@ -1993,7 +2325,7 @@ def h(x): jax.grad(h)(v) # doesn't crash 
jax.jit(jax.grad(h))(v) # doesn't crash - @jtu.with_user_mesh((2, 1, 2, 2), ('i', 'j', 'k', 'l')) + @jtu.with_explicit_mesh((2, 1, 2, 2), ('i', 'j', 'k', 'l')) def test_partial_auto_explicit_multi_explicit(self, mesh): def g(x): self.assertDictEqual(x.aval.sharding.mesh._axis_types_dict, @@ -2006,10 +2338,8 @@ def g(x): @jax.jit def f(x): - x = shard_map(g, mesh, - in_specs=P('i', 'j', None, None), - out_specs=P('i', 'j', None, None), - auto=frozenset({'k', 'l'}))(x) + x = jax.shard_map(g, out_specs=P('i', 'j', None, None), + axis_names=frozenset({'i', 'j'}))(x) self.assertEqual(x.aval.sharding.spec, P(('i', 'l'), ('j', 'k'), None, None)) return x @@ -2031,11 +2361,11 @@ def g(x): def f(x): return shard_map( g, - mesh, + mesh=mesh, in_specs=P(), out_specs=P(), - check_rep=False, - auto=frozenset({'i'}), + check_vma=False, + axis_names=frozenset({'j', 'k'}), )(x) v = jnp.arange(32.0).reshape(4, 8) @@ -2067,13 +2397,43 @@ def update_fn(params, batch): def grad_fn(batch): return jax.value_and_grad(loss_fn)(params, batch) return shard_map(grad_fn, mesh=mesh, in_specs=P("data"), out_specs=P(), - check_rep=False)(batch) + check_vma=False)(batch) arr_sharded = jax.device_put(jnp.arange(32.0).reshape(4, 8), NamedSharding(mesh, P())) params = jnp.copy(arr_sharded) update_fn(params, arr_sharded) # doesn't crash + @jtu.with_explicit_mesh((2,), ('x',)) + def test_close_over_explicit_sharded_input_error(self, mesh): + def simple_func(w, x): + return jnp.sum(w * x, axis=-1) + + w = jnp.ones((2, 4), dtype=np.float32) + x = jnp.ones((4, 4), dtype=np.float32) + + shard_map(simple_func, in_specs=(P(), P('x')), out_specs=P('x'))(w, x) + + with self.assertRaisesRegex( + NotImplementedError, + 'Closing over inputs to shard_map where the input is sharded on' + ' `Explicit` axes is not implemented'): + shard_map(lambda xi: simple_func(w, xi), + in_specs=P('x'), out_specs=P('x'))(x) + + def test_close_over_input_explict_ctx_mesh(self): + mesh = jtu.create_mesh((2,), 'x', axis_types=(AxisType.Explicit,)) + w = jnp.ones((2, 4), dtype=np.float32) + x = jnp.ones((4, 4), dtype=np.float32) + + def simple_func(w, x): + return jnp.sum(w * x, axis=-1) + + shard_map(simple_func, mesh=mesh, in_specs=(P(), P('x')), + out_specs=P('x'))(w, x) + shard_map(lambda xi: simple_func(w, xi), mesh=mesh, + in_specs=P('x'), out_specs=P('x'))(x) + def test_shmap_close_over_unused_params_vmap(self): mesh = jtu.create_mesh((2,), ("data",)) @@ -2085,7 +2445,7 @@ def update_fn(params, batch): def grad_fn(batch): return jax.value_and_grad(loss_fn)(params, batch) return shard_map(jax.vmap(grad_fn), mesh=mesh, in_specs=P("data"), - out_specs=P("data"), check_rep=False)(batch) + out_specs=P("data"), check_vma=False)(batch) arr_sharded = jax.device_put(jnp.arange(32.0).reshape(4, 8), NamedSharding(mesh, P())) @@ -2119,11 +2479,11 @@ def g(x): @jax.jit def f(x): - x = shard_map(g, mesh, + x = shard_map(g, mesh=mesh, in_specs=P('i', None), out_specs=P('i', None), - check_rep=False, - auto=frozenset({'j'}))(x) + check_vma=False, + axis_names=frozenset({'i'}))(x) return jax.lax.with_sharding_constraint( x, jax.sharding.NamedSharding(mesh, P('i', 'j'))) @@ -2142,17 +2502,17 @@ def g(x): @jax.jit def f(x): - x = shard_map(g, mesh, + x = shard_map(g, mesh=mesh, in_specs=P('i', None), out_specs=P('i', None), - check_rep=False, - auto=frozenset({'k'}))(x) + check_vma=False, + axis_names=frozenset({'i', 'j'}))(x) return jax.lax.with_sharding_constraint( x, jax.sharding.NamedSharding(mesh, P('i', 'j'))) v = jnp.arange(32.).reshape(4, 8) v = 
jax.device_put(v, jax.sharding.NamedSharding(mesh, P('i', 'j'))) - with self.assertRaisesRegex(ValueError, "to be a subset of mesh.axis_names"): + with self.assertRaisesRegex(ValueError, "contains a manual axes.*of mesh"): f(v) def test_partial_auto_error_wrong_in_specs(self): @@ -2165,11 +2525,11 @@ def g(x): @jax.jit def f(x): - x = shard_map(g, mesh, + x = shard_map(g, mesh=mesh, in_specs=P('i', 'j'), out_specs=P('i', None), - check_rep=False, - auto=frozenset({'j'}))(x) + check_vma=False, + axis_names=frozenset({'i'}))(x) return jax.lax.with_sharding_constraint( x, jax.sharding.NamedSharding(mesh, P('i', 'j'))) @@ -2178,28 +2538,83 @@ def f(x): with self.assertRaisesRegex(ValueError, "in_specs refers to 'j'"): f(v) - def test_nested_partial_auto(self): + def test_partial_auto_mismatch_mesh_error(self): mesh = jtu.create_mesh((2, 2), ('i', 'j')) + v = jnp.arange(32.).reshape(4, 8) + v = jax.device_put(v, jax.sharding.NamedSharding(mesh, P('i', 'j'))) def g(x): return x * x def h(x): - return shard_map(g, mesh, - in_specs=P(None, 'j'), - out_specs=P(None, 'j'))(x) + return shard_map(g, mesh=mesh, in_specs=P(None, 'j'), + out_specs=P(None, 'j'))(x) @jax.jit def f(x): - return shard_map(h, mesh, - in_specs=P('i', None), - out_specs=P('i', None), - check_rep=False, - auto=frozenset({'j'}))(x) + return shard_map(h, mesh=mesh, in_specs=P('i', None), + out_specs=P('i', None), check_vma=False, + axis_names=frozenset({'i'}))(x) + + with self.assertRaisesRegex( + ValueError, r"context mesh.*should match the mesh passed to shard_map"): + self.assertAllClose(v*v, f(v), check_dtypes=False) + def test_nested_partial_auto(self): + mesh = jtu.create_mesh((2, 2), ('i', 'j')) v = jnp.arange(32.).reshape(4, 8) v = jax.device_put(v, jax.sharding.NamedSharding(mesh, P('i', 'j'))) - self.assertAllClose(v*v, f(v), check_dtypes=False) + + def g(x): + return x * x + + def h(x): + return shard_map(g, in_specs=P(None, 'j'), out_specs=P(None, 'j'))(x) + + @jax.jit + def f(x): + return shard_map(h, in_specs=P('i', None), out_specs=P('i', None), + check_vma=False, axis_names=frozenset({'i'}))(x) + + with jax.sharding.use_mesh(mesh): + self.assertAllClose(v*v, f(v), check_dtypes=False) + + @parameterized.named_parameters( + ('0', 'x', 'y', {'x'}, {'x', 'y'}), + ('1', None, 'y', frozenset(), {'y'}), + ('2', 'x', None, {'x'}, {'x'}), + ('3', None, None, frozenset(), frozenset()), + ) + def test_nested_partial_auto_1d(self, dim1, dim2, outer_vma, inner_vma): + mesh = jtu.create_mesh((2, 2, 2), ('x', 'y', 'z')) + np_inp = np.arange(32.).reshape(4, 8) + arr = jax.device_put(np_inp, NamedSharding(mesh, P(dim1, dim2))) + + def g(x): + self.assertEqual(get_abstract_mesh().manual_axes, ('x', 'y')) + self.assertEqual(get_abstract_mesh().auto_axes, ('z',)) + self.assertEqual(x.aval.vma, inner_vma) + out = x * x + self.assertEqual(out.aval.vma, inner_vma) + return out + + def h(x): + self.assertEqual(get_abstract_mesh().manual_axes, ('x',)) + self.assertEqual(get_abstract_mesh().auto_axes, ('y', 'z')) + self.assertEqual(x.aval.vma, outer_vma) + out = shard_map(g, in_specs=P(None, dim2), + out_specs=P(None, dim2), axis_names={'y'})(x) + self.assertEqual(out.aval.vma, outer_vma) + return out + + @jax.jit + def f(x): + return shard_map(h, in_specs=P(dim1, None), + out_specs=P(dim1, None), axis_names={'x'})(x) + + with jax.sharding.use_mesh(mesh): + out = f(arr) + self.assertArraysEqual(out, np_inp * np_inp) def test_grad_nested_partial_auto(self): mesh = jtu.create_mesh((2, 2), ('i', 'j')) @@ -2210,22 +2625,19 @@ def g(x): def 
h(x): # auto: 'j', manual: 'i' - return shard_map(g, mesh, - in_specs=P(None, 'j'), - out_specs=P(None, 'j'))(x) + return shard_map(g, in_specs=P(None, 'j'), out_specs=P(None, 'j'))(x) @jax.jit def f(x): # auto: 'i', 'j' - return shard_map(h, mesh, - in_specs=P('i', None), - out_specs=P('i', None), - check_rep=False, - auto=frozenset({'j'}))(x).sum() + return shard_map(h, in_specs=P('i', None), out_specs=P('i', None), + check_vma=False, axis_names=frozenset({'i'}))(x).sum() v = jnp.arange(32.).reshape(4, 8) v = jax.device_put(v, jax.sharding.NamedSharding(mesh, P('i', 'j'))) - self.assertAllClose(v*2, jax.grad(f)(v), check_dtypes=False) + with jax.sharding.use_mesh(mesh): + out = jax.grad(f)(v) + self.assertAllClose(out, v * 2, check_dtypes=False) def test_grad_nested_partial_auto_with_residuals(self): mesh = jtu.create_mesh((2, 2), ('i', 'j')) @@ -2234,21 +2646,18 @@ def g(x): return x * x * x def h(x): - return shard_map(g, mesh, - in_specs=P(None, 'j'), - out_specs=P(None, 'j'))(x) + return shard_map(g, in_specs=P(None, 'j'), out_specs=P(None, 'j'))(x) @jax.jit def f(x): - return shard_map(h, mesh, - in_specs=P('i', None), - out_specs=P('i', None), - check_rep=False, - auto=frozenset({'j'}))(x).sum() + return shard_map(h, in_specs=P('i', None), out_specs=P('i', None), + check_vma=False, axis_names=frozenset({'i'}))(x).sum() v = jnp.arange(32.).reshape(4, 8) v = jax.device_put(v, jax.sharding.NamedSharding(mesh, P('i', 'j'))) - self.assertAllClose(v*v*3, jax.grad(f)(v), check_dtypes=False) + with jax.sharding.use_mesh(mesh): + out = jax.grad(f)(v) + self.assertAllClose(out, v * v * 3, check_dtypes=False) def test_axis_size_1_partial_auto(self): mesh = jtu.create_mesh((1, 2, 2), ('i', 'j', 'k')) @@ -2258,11 +2667,11 @@ def h(x): @jax.jit def f(x): - return shard_map(h, mesh, + return shard_map(h, mesh=mesh, in_specs=P('i', None), out_specs=P('i', None), - check_rep=False, - auto=frozenset({'j', 'k'}))(x) + check_vma=False, + axis_names=frozenset({'i'}))(x) v = jnp.arange(32.).reshape(4, 8) v = jax.device_put(v, jax.sharding.NamedSharding(mesh, P('i', 'j'))) @@ -2280,8 +2689,8 @@ def _make_zeros(): def f(): return shard_map( - h, mesh, in_specs=(), - out_specs=P('i'), check_rep=False, auto=frozenset({'j'}))() + h, mesh=mesh, in_specs=(), + out_specs=P('i'), check_vma=False, axis_names=frozenset({'i'}))() self.assertAllClose(jax.jit(f)(), jnp.zeros((2,))) @@ -2303,8 +2712,8 @@ def _make_zeros(): def f(): return shard_map( - h, mesh, in_specs=(), - out_specs=P('i'), check_rep=False, auto=frozenset({'j'}))() + h, mesh=mesh, in_specs=(), + out_specs=P('i'), check_vma=False, axis_names=frozenset({'i'}))() self.assertAllClose(jax.jit(f)(), jnp.zeros((2,))) @@ -2315,10 +2724,11 @@ def test_partial_auto_axis_index(self): @partial(jax.jit, out_shardings=out_sharding) def f(): return shard_map(lambda: jax.lax.axis_index('i').reshape(1,1), - mesh, in_specs=P('i', None), out_specs=P('i', None), - check_rep=False, auto=frozenset({'j'}))() + in_specs=P('i', None), out_specs=P('i', None), + check_vma=False, axis_names=frozenset({'i'}))() - self.assertAllClose(f(), np.arange(4, dtype=np.int32).reshape(-1, 1)) + with jax.sharding.use_mesh(mesh): + self.assertAllClose(f(), np.arange(4, dtype=np.int32).reshape(-1, 1)) def test_partial_auto_axis_index_degenerated_axis(self): mesh = jtu.create_mesh((1, 2), ('i', 'j')) @@ -2327,8 +2737,8 @@ def test_partial_auto_axis_index_degenerated_axis(self): @partial(jax.jit, out_shardings=out_sharding) def f(): return shard_map(lambda: jax.lax.axis_index('i').reshape(1, 
1), - mesh, in_specs=P('i', None), out_specs=P('i', None), - check_rep=False, auto=frozenset({'j'}))() + mesh=mesh, in_specs=P('i', None), out_specs=P('i', None), + check_vma=False, axis_names=frozenset({'i'}))() self.assertAllClose(f(), np.arange(1, dtype=np.int32).reshape(-1, 1)) @@ -2343,8 +2753,8 @@ def g(x): @jax.jit def f(x): return shard_map(g, - mesh, in_specs=P('i'), out_specs=P('i'), - check_rep=False, auto=frozenset({'j'}))(x) + mesh=mesh, in_specs=P('i'), out_specs=P('i'), + check_vma=False, axis_names=frozenset({'i'}))(x) y = f(x) # don't crash self.assertAllClose(y, jnp.array([6., 7., 0., 1., 2., 3., 4., 5.]), @@ -2363,8 +2773,8 @@ def f(x): # @jax.jit # def f(x): # return shard_map(g, - # mesh, in_specs=P('i', None), out_specs=P(None, 'i'), - # check_rep=False, auto=frozenset({'j'}))(x) + # mesh=mesh, in_specs=P('i', None), out_specs=P(None, 'i'), + # check_vma=False, axis_names=frozenset({'i'}))(x) # # f(x) # don't crash @@ -2380,11 +2790,11 @@ def g(x): @jax.jit def f(x): - return shard_map(g, - mesh, in_specs=P('i'), out_specs=None, - check_rep=False, auto=frozenset({'j'}))(x) + return shard_map(g, mesh=mesh, in_specs=P('i'), out_specs=None, + check_vma=False, axis_names=frozenset({'i'}))(x) - y = f(x) # don't crash + with jax.sharding.use_mesh(mesh): + f(x) # don't crash def test_partial_auto_of_random_keys(self): mesh = jtu.create_mesh((4, 2), ('i', 'j')) @@ -2393,8 +2803,8 @@ def test_partial_auto_of_random_keys(self): @jax.jit def f(x): return shard_map(lambda k: k, - mesh, in_specs=P('i'), out_specs=P('i'), - check_rep=False, auto=frozenset({'j'}))(keys) + mesh=mesh, in_specs=P('i'), out_specs=P('i'), + check_vma=False, axis_names=frozenset({'i'}))(keys) y = f(keys) # doesn't crash self.assertAllClose(jax.random.key_data(y), jax.random.key_data(keys), @@ -2407,21 +2817,26 @@ def test_partial_auto_of_random_keys_slice(self): @jax.jit def f(x): return shard_map(lambda k: k[0], - mesh, in_specs=P('i'), out_specs=P('i'), - check_rep=False, auto=frozenset({'j'}))(x) + mesh=mesh, in_specs=P('i'), out_specs=P('i'), + check_vma=False, axis_names=frozenset({'i'}))(x) f(keys) # doesn't crash + def test_grad_remat(self): + mesh = jtu.create_mesh((1, 1), ('i', 'j')) + args = [jnp.arange(6.).reshape(3, 2), jnp.arange(6.).reshape(3, 2, 1)] + + @partial(jax.remat, policy=lambda *_, **__: True) + @shard_map(mesh=mesh, in_specs=(P('j'), P('i')), out_specs=P('i', 'j')) + def f(x, y): + return jnp.dot(x, y) + jax.grad(lambda x, y: f(x, y).sum())(*args) + def test_vmap_grad_shmap_spmd_axis_name_residuals(self): # https://github.com/jax-ml/jax/pull/21032 mesh = jtu.create_mesh((4, 2), ('i', 'j')) - @partial( - shard_map, - mesh=mesh, - in_specs=P('j'), - out_specs=P('j'), - ) + @shard_map(mesh=mesh, in_specs=P('j'), out_specs=P('j')) def f(x): return jnp.sin(x) @@ -2434,12 +2849,7 @@ def test_vmap_grad_remat_shmap_spmd_axis_name_residuals(self): mesh = jtu.create_mesh((4, 2), ('i', 'j')) @partial(jax.remat, policy=lambda *_, **__: True) - @partial( - shard_map, - mesh=mesh, - in_specs=P('j'), - out_specs=P('j'), - ) + @partial(shard_map, mesh=mesh, in_specs=P('j'), out_specs=P('j')) def f(x): return jnp.sin(x) @@ -2454,8 +2864,8 @@ def test_grad_shmap_residuals_axis_names_in_mesh_order(self): @partial( shard_map, mesh=mesh, - in_specs=P('j'), - out_specs=P('j'), + in_specs=P(('i', 'k')), + out_specs=P(('i', 'k')), ) def f(x): return jnp.sin(x) @@ -2465,22 +2875,45 @@ def f(x): ir = jax.jit(jax.grad(lambda x: f(x).sum())).lower(xs) if config.use_shardy_partitioner.value: self.assertIn( - 
'out_shardings=[<@mesh, [{"i", "j", "k", "a"}]>]', ir.as_text() + 'out_shardings=[<@mesh, [{"i", "k"}]>]', ir.as_text() ) else: self.assertIn( - "{jax.result_info = \"[('i', 'j', 'k', 'a')]\"}", ir.as_text() + "{jax.result_info = \"[('i', 'k')]\"}", ir.as_text() ) + def test_dynamic_slice_transpose(self): + mesh = jtu.create_mesh((2,), ('x',)) + arr = np.arange(16., dtype=np.float32) + + @partial(shard_map, mesh=mesh, in_specs=P('x'), out_specs=P('x')) + def f(x): + return lax.dynamic_slice_in_dim(x, jnp.array(1, dtype=np.int32), 2) + + f(arr) # doesn't crash + jax.jit(f)(arr) # doesn't crash + + def g(x): + return jnp.sum(f(x)) + + jax.grad(g)(arr) # doesn't crash + jax.jit(jax.grad(g))(arr) # doesn't crash + + @parameterized.parameters([P()], [P('x')], [P(('x', 'y'))]) + def test_print_inside_shard_map(self, specs): + mesh = jtu.create_mesh((2, 2), ('x', 'y')) + x = jnp.arange(4.) + + @partial(shard_map, mesh=mesh, in_specs=specs, out_specs=specs) + def f(x): + print(x) + return 2 * x + f(x) # doesn't crash + def test_vmap_spmd_axis_name_error(self): mesh = jtu.create_mesh((4, 2), ('i', 'j')) - @partial( - shard_map, - mesh=mesh, - in_specs=P('i'), - out_specs=P('i'), - ) + @partial(shard_map, mesh=mesh, in_specs=P('i'), out_specs=P('i')) def f(x): return jnp.sin(x) @@ -2488,13 +2921,8 @@ def f(x): with self.assertRaisesRegex(ValueError, "spmd_axis_name cannot appear"): jax.vmap(f, spmd_axis_name='i')(xs) - @partial( - shard_map, - mesh=mesh, - in_specs=P('j'), - out_specs=P(('i', 'j')), - check_rep=False, - ) + @partial(shard_map, mesh=mesh, in_specs=P('j'), out_specs=P(('i', 'j')), + check_vma=False) def g(x): return jnp.sin(x) @@ -2512,11 +2940,11 @@ def f(o, x): return jnp.sin(x) obj = object() - y = shard_map(f, mesh, (None, P('i')), P('i'))(obj, x) + y = shard_map(f, mesh=mesh, in_specs=(None, P('i')), out_specs=P('i'))(obj, x) self.assertAllClose(y, jnp.sin(x), check_dtypes=False) obj = None - y = shard_map(f, mesh, (None, P('i')), P('i'))(None, x) + y = shard_map(f, mesh=mesh, in_specs=(None, P('i')), out_specs=P('i'))(None, x) self.assertAllClose(y, jnp.sin(x), check_dtypes=False) def f2(o, x): @@ -2525,7 +2953,7 @@ def f2(o, x): return jnp.sin(x) obj = {'a': object()} - y = shard_map(f2, mesh, ({'a': None}, P('i')), P('i'))(obj, x) + y = shard_map(f2, mesh=mesh, in_specs=({'a': None}, P('i')), out_specs=P('i'))(obj, x) self.assertAllClose(y, jnp.sin(x), check_dtypes=False) def f3(x, o): @@ -2533,11 +2961,11 @@ def f3(x, o): return jnp.sin(x) obj = object() - y = shard_map(f3, mesh, (P('i'), None), P('i'))(x, obj) + y = shard_map(f3, mesh=mesh, in_specs=(P('i'), None), out_specs=P('i'))(x, obj) self.assertAllClose(y, jnp.sin(x), check_dtypes=False) obj = None - y = shard_map(f3, mesh, (P('i'), None), P('i'))(x, obj) + y = shard_map(f3, mesh=mesh, in_specs=(P('i'), None), out_specs=P('i'))(x, obj) self.assertAllClose(y, jnp.sin(x), check_dtypes=False) def f4(o1, o2, x, o3): @@ -2550,7 +2978,8 @@ def f4(o1, o2, x, o3): obj1 = object() obj2 = (object(), object()) obj3 = object() - y = shard_map(f4, mesh, (None, None, P('i'), None), P('i'))(obj1, obj2, x, obj3) + y = shard_map(f4, mesh=mesh, in_specs=(None, None, P('i'), None), + out_specs=P('i'))(obj1, obj2, x, obj3) self.assertAllClose(y, jnp.sin(x), check_dtypes=False) def test_in_spec_none_divisibility_errors(self): @@ -2558,44 +2987,48 @@ def test_in_spec_none_divisibility_errors(self): x = jnp.arange(4).reshape(2, 2) with self.assertRaisesRegex(ValueError, 'divisible'): - shard_map(lambda *_: None, mesh, (None, P('i')), 
None)(object(), x) + shard_map(lambda *_: None, mesh=mesh, in_specs=(None, P('i')), + out_specs=None)(object(), x) with self.assertRaisesRegex(ValueError, 'divisible'): - shard_map(lambda *_: None, mesh, (P('i'), None), None)(x, object()) + shard_map(lambda *_: None, mesh=mesh, in_specs=(P('i'), None), + out_specs=None)(x, object()) with self.assertRaisesRegex(ValueError, 'divisible'): - shard_map(lambda *_: None, mesh, (P('i'), None), None - )(x, (object(), object())) + shard_map(lambda *_: None, mesh=mesh, in_specs=(P('i'), None), + out_specs=None)(x, (object(), object())) with self.assertRaisesRegex(ValueError, 'divisible'): - shard_map(lambda *_: None, mesh, (P('i'), (None, None)), None, - )(x, (object(), object())) + shard_map(lambda *_: None, mesh=mesh, in_specs=(P('i'), (None, None)), + out_specs=None)(x, (object(), object())) with self.assertRaisesRegex(ValueError, 'divisible'): - shard_map(lambda *_: None, mesh, ((None, None), P('i')), None, - )((object(), object()), x) + shard_map(lambda *_: None, mesh=mesh, in_specs=((None, None), P('i')), + out_specs=None)((object(), object()), x) def test_in_spec_none_rank_errors(self): mesh = jtu.create_mesh((4, 2), ('i', 'j')) x = jnp.arange(4) with self.assertRaisesRegex(ValueError, 'rank'): - shard_map(lambda *_: None, mesh, (None, P('i', 'j')), None)(object(), x) + shard_map(lambda *_: None, mesh=mesh, in_specs=(None, P('i', 'j')), + out_specs=None)(object(), x) with self.assertRaisesRegex(ValueError, 'rank'): - shard_map(lambda *_: None, mesh, (P('i', 'j'), None), None)(x, object()) + shard_map(lambda *_: None, mesh=mesh, in_specs=(P('i', 'j'), None), + out_specs=None)(x, object()) with self.assertRaisesRegex(ValueError, 'rank'): - shard_map(lambda *_: None, mesh, (P('i', 'j'), None), None - )(x, (object(), object())) + shard_map(lambda *_: None, mesh=mesh, in_specs=(P('i', 'j'), None), + out_specs=None)(x, (object(), object())) with self.assertRaisesRegex(ValueError, 'rank'): - shard_map(lambda *_: None, mesh, (P('i', 'j'), (None, None)), None, - )(x, (object(), object())) + shard_map(lambda *_: None, mesh=mesh, in_specs=(P('i', 'j'), (None, None)), + out_specs=None)(x, (object(), object())) with self.assertRaisesRegex(ValueError, 'rank'): - shard_map(lambda *_: None, mesh, ((None, None), P('i', 'j')), None, - )((object(), object()), x) + shard_map(lambda *_: None, mesh=mesh, in_specs=((None, None), P('i', 'j')), + out_specs=None)((object(), object()), x) def test_custom_linear_solve_rep_rules(self): # https://github.com/jax-ml/jax/issues/20162 @@ -2616,7 +3049,7 @@ def test_temporary_error_suppression_flag(self): def f(x, y): z = shard_map(lambda x, y: x + jax.lax.all_gather(y, 'i', tiled=True), mesh=mesh, in_specs=(P(None), P('i')), out_specs=P(None), - check_rep=False, + check_vma=False, )(x, y) return z @@ -2650,13 +3083,7 @@ def f(x, reduce_along, use_jit): @partial(shard_map, mesh=mesh, in_specs=P('x', 'y'), out_specs=out_spec) def g(x): result = lax.psum(x, axis_name=reduce_along) - def check_rep(result): - self.assertEqual( - jax.experimental.shard_map.get_replication(result), - set(reduce_along)) - return result - result = check_rep(result) - result = jax.vmap(check_rep)(result) + self.assertEqual(result.aval.vma, x.aval.vma - set(reduce_along)) return result if use_jit: return jax.jit(g)(x) @@ -2673,18 +3100,500 @@ def test_pmin(self): mesh = jtu.create_mesh((4,), ('i',)) x = jnp.arange(8., dtype=np.float32) y = shard_map(lambda x: jax.lax.pmin(x, 'i'), - mesh=mesh, in_specs=P('i'), out_specs=P() - )(x) # don't crash + 
mesh=mesh, in_specs=P('i'), out_specs=P())(x) # don't crash self.assertArraysEqual(y, np.array([0, 1], dtype=np.float32)) def test_pmax(self): mesh = jtu.create_mesh((4,), ('i',)) x = jnp.arange(8., dtype=np.float32) y = shard_map(lambda x: jax.lax.pmax(x, 'i'), - mesh=mesh, in_specs=P('i'), out_specs=P() - )(x) # don't crash + mesh=mesh, in_specs=P('i'), out_specs=P())(x) # don't crash self.assertArraysEqual(y, np.array([6, 7], dtype=np.float32)) + def test_pmax_vma_in_types(self): + mesh = jtu.create_mesh((4,), ('i',)) + x = jnp.arange(8., dtype=np.float32) + f = jax.jit(shard_map(lambda x: jax.lax.pmax(x, 'i'), mesh=mesh, + in_specs=P(), out_specs=P())) + jaxpr = f.trace(x).jaxpr + self.assertIn("pvary[axes=('i',)", str(jaxpr)) + f(x) # doesn't crash + + def test_mul_with_vma_in_types(self): + mesh = jtu.create_mesh((2,), ('x',)) + x = np.arange(8.) + + def f(x): + self.assertEqual(x.aval.vma, frozenset({'x'})) + out = x * 2 + self.assertEqual(out.aval.vma, frozenset({'x'})) + return out + + f = jax.jit(shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P('x'))) + jaxpr = f.trace(x).jaxpr + self.assertIn("pvary[axes=('x',)", str(jaxpr)) + out = f(x) + self.assertArraysEqual(out, x * 2) + + # TODO(yashkatariya): Enable grad test which requires adding psum_p support. + # def g(x, y): + # return jnp.sum(f(x, y)) + # print(jax.jit(jax.grad(g)).trace(x, y).jaxpr) + + def test_all_gather_with_vma_in_types(self): + mesh = jtu.create_mesh((2,), ('x',)) + x = np.arange(8.) + + def f(x): + self.assertEqual(x.aval.vma, frozenset()) + out = jax.lax.all_gather(x, 'x') + self.assertEqual(out.aval.vma, frozenset({'x'})) + return out + + f = jax.jit(shard_map(f, mesh=mesh, in_specs=P(), out_specs=P('x'))) + jaxpr = f.trace(x).jaxpr + self.assertIn("pvary[axes=('x',)", str(jaxpr)) + + f(x) # doesn't crash + + def test_rep_none_canonicalization(self): + # https://github.com/jax-ml/jax/issues/26621 + if config.use_shardy_partitioner.value: + self.skipTest('complex values fail under shardy') + N = 8 + xs = jnp.ones((8, N), dtype=jnp.int32) + variables = jax.random.normal(jax.random.key(1), (N, N), jnp.complex64) + mesh = jtu.create_mesh((2,), ('i',)) + in_specs = (P(), P("i"),) + out_specs = P("i") + + variables = jax.lax.with_sharding_constraint(variables, NamedSharding(mesh, P())) + xs = jax.lax.with_sharding_constraint(xs, NamedSharding(mesh, P('i'))) + + def fun(v, xs): + # Commenting this single line below makes everything work + v = jax.scipy.linalg.expm(v) + v = v.sum() + return v * xs.sum(axis=-1).astype(v.dtype) + + res = fun(variables, xs) + fun_shard_map = shard_map(fun, mesh=mesh, in_specs=in_specs, out_specs=out_specs) + res = fun_shard_map(variables, xs) # don't crash + + def test_rep_none_canonicalization_again(self): + # https://github.com/jax-ml/jax/issues/24762 + mesh = jtu.create_mesh((2,), ('i',)) + def f(x): + return jnp.insert(x, 0, 0)[None] + f = shard_map(f, mesh=mesh, in_specs=P('i'), out_specs=P('i')) + f(jnp.zeros(100)) # don't crash + + def test_custom_jvp_symbolic_zeros(self): + # https://github.com/jax-ml/jax/issues/26763 + mesh = jtu.create_mesh((4,), ('i',)) + @jax.custom_jvp + def f(a: jax.Array, b: jax.Array) -> jax.Array: + return a + b + + @partial(f.defjvp, symbolic_zeros=True) + def f_jvp(primals, tangents): + a, b = primals + a_dot, b_dot = tangents + y = f(a, b) + y_dot = jnp.zeros_like(y) + if not isinstance(a_dot, SymbolicZero): + y_dot += a_dot + if not isinstance(b_dot, SymbolicZero): + y_dot += b_dot + return y, y_dot + x = jax.random.normal(jax.random.key(0), 
(jax.device_count(), 20)) + A = jax.random.normal(jax.random.key(1), (jax.device_count(), 20)) + + g = shard_map(f, mesh=mesh, in_specs=P('i'), out_specs=P('i')) + jax.jvp(lambda x: g(x, A), (x,), (x,)) # don't crash + + def test_cond_pvary_errors(self): + mesh = jtu.create_mesh((1, 1), ('x', 'y')) + def f(x, y): + def true_fn(x, y): + return x + def false_fun(x, y): + return y + return jax.lax.cond(True, true_fn, false_fun, x, y) + x = jnp.arange(4.) + with self.assertRaisesRegex( + TypeError, + r"applying `jax.lax.pvary\(..., \('y',\)\)` to the output of true_fun"): + shard_map(f, mesh=mesh, in_specs=(P('x'), P('y')), out_specs=P(('x', 'y')))(x, x) + + def test_cond_pvary_errors_pytree(self): + mesh = jtu.create_mesh((1, 1), ('x', 'y')) + + def f(x, y): + def true_fn(x, y): + return x, y + def false_fun(x, y): + return y, x + return jax.lax.cond(True, true_fn, false_fun, x, y) + x = jnp.arange(4.) + with self.assertRaisesRegex( + TypeError, + r"applying `jax.lax.pvary\(..., \('y',\)\)` to the output of true_fun"): + shard_map(f, mesh=mesh, in_specs=(P('x'), P('y')), out_specs=P(('x', 'y')))(x, x) + + def test_scan_pvary_errors(self): + mesh = jtu.create_mesh((1, 1), ('i', 'j')) + x = jnp.arange(3.) + y = jnp.arange(3.) + + @partial(shard_map, mesh=mesh, in_specs=(P('i'), P()), out_specs=P('i')) + def f(x, y): + def body(carry, _): + c1, c2 = carry + return (c2, c1), () # swap the carry + (x_, y_), _ = jax.lax.scan(body, (x, y), (), length=2) + return x_, y_ + + with self.assertRaisesRegex( + TypeError, + r"This might be fixed by applying `jax.lax.pvary\(..., \('i',\)\)` to" + r' the initial'): + f(x, y) + + @partial(shard_map, mesh=mesh, in_specs=(P('i'), P()), out_specs=P('i')) + def g(x, y): + def body(carry, _): + c1, c2 = carry + return (c2, c1), () + y = jax.lax.pvary(y, 'i') # fix the issue + (x_, y_), _ = jax.lax.scan(body, (x, y), (), length=2) + return x_, y_ + + g(x, y) # doesn't crash + + def test_scan_pvary_errors2(self): + mesh = jtu.create_mesh((1, 1), ('i', 'j')) + x = jnp.arange(3.) + y = jnp.arange(3.) + z = jnp.arange(3.) 
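# Reader aid, not part of the original patch: under check_vma a scan or cond carry
# must keep the same set of varying mesh axes across iterations and branches. When a
# replicated value and a device-varying one are swapped through the carry, the fix
# suggested by the error message is jax.lax.pvary(value, axis_names), which marks
# the replicated value as varying over those axes before entering the loop, as the
# commented-out lines in `f` and the fixed version `g` below illustrate.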
+ + @partial(shard_map, mesh=mesh, in_specs=(P('i'), P(), P(('i', 'j'))), out_specs=P(('i', 'j'))) + def f(x, y, z): + def body(carry, _): + c1, c2, c3 = carry + return (c3, c1, c2), () # swap the carry + + # x = jax.lax.pvary(x, 'j') + # y = jax.lax.pvary(y, ('i', 'j')) + carry, _ = jax.lax.scan(body, (x, y, z), (), length=2) + return carry + + with self.assertRaisesRegex( + TypeError, + r'This might be fixed by:\n \* applying `jax.lax.pvary\(...,' + r" \('j',\)\)`"): + f(x, y, z) + + @partial(shard_map, mesh=mesh, in_specs=(P('i'), P(), P(('i', 'j'))), out_specs=P(('i', 'j'))) + def g(x, y, z): + def body(carry, _): + c1, c2, c3 = carry + return (c3, c1, c2), () # swap the carry + + x = jax.lax.pvary(x, 'j') # fix the issue + y = jax.lax.pvary(y, ('i', 'j')) + carry, _ = jax.lax.scan(body, (x, y, z), (), length=2) + return carry + + g(x, y, z) # doesn't crash + + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) + def test_shmap_full_manual_context_explicit(self, mesh): + np_inp = np.arange(16).reshape(8, 2) + arr = jax.device_put(np_inp, P('x', 'y')) + + @partial(jax.shard_map, out_specs=P('x', 'y')) + def f(x): + self.assertEqual(get_abstract_mesh().manual_axes, ('x', 'y')) + self.assertEqual(x.aval.vma, {'x', 'y'}) + out = x * 2 + self.assertEqual(out.aval.vma, {'x', 'y'}) + return out + + out = f(arr) + self.assertArraysEqual(out, np_inp * 2) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y'))) + jax.jit(f)(arr) # doesn't crash + + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) + def test_shmap_partial_manual_explicit(self, mesh): + np_inp = np.arange(16).reshape(8, 2) + arr = jax.device_put(np_inp, P('x', 'y')) + + @partial(jax.shard_map, axis_names=frozenset('x'), out_specs=P('x')) + def f(x): + self.assertEqual(get_abstract_mesh().manual_axes, ('x',)) + self.assertEqual(get_abstract_mesh().explicit_axes, ('y',)) + self.assertEqual(x.aval.sharding.spec, P(None, 'y')) + self.assertEqual(x.aval.vma, {'x'}) + out = x * 2 + self.assertEqual(out.aval.sharding.spec, P(None, 'y')) + self.assertEqual(out.aval.vma, {'x'}) + return out + + out = jax.jit(f)(arr) + self.assertArraysEqual(out, np_inp * 2) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y'))) + + @jtu.with_explicit_mesh((2, 2), ('x', 'y'), axis_types=(AxisType.Auto,) * 2) + def test_shmap_full_manual_context_auto(self, mesh): + np_inp = np.arange(16).reshape(8, 2) + arr = jax.device_put(np_inp, P('x', 'y')) + + @partial(jax.shard_map, in_specs=P('x', 'y'), out_specs=P('x', 'y')) + def f(x): + self.assertEqual(get_abstract_mesh().manual_axes, ('x', 'y')) + self.assertEqual(x.aval.vma, {'x', 'y'}) + out = x * 2 + self.assertEqual(out.aval.vma, {'x', 'y'}) + return out + + out = f(arr) + self.assertArraysEqual(out, np_inp * 2) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y'))) + jax.jit(f)(arr) # doesn't crash + + @jtu.with_explicit_mesh((2, 2), ('x', 'y'), axis_types=(AxisType.Auto,) * 2) + def test_shmap_partial_manual_auto(self, mesh): + np_inp = np.arange(16).reshape(8, 2) + arr = jax.device_put(np_inp, P('x', 'y')) + + @partial(jax.shard_map, axis_names=frozenset('x'), in_specs=P('x'), + out_specs=P('x')) + def f(x): + self.assertEqual(get_abstract_mesh().manual_axes, ('x',)) + self.assertEqual(get_abstract_mesh().auto_axes, ('y',)) + self.assertEqual(x.aval.vma, {'x'}) + out = x * 2 + self.assertEqual(out.aval.vma, {'x'}) + return out + + out = jax.jit(f)(arr) + self.assertArraysEqual(out, np_inp * 2) + + def test_no_mesh_context_error(self): + with self.assertRaisesRegex(ValueError, "The 
context mesh cannot be empty"): + jax.shard_map(lambda x: x, in_specs=P(), out_specs=P())(np.arange(8)) + + def test_pvary_in_shmap_of_grad(self): + mesh = jtu.create_mesh((2,), 'x') + + def g(x): + return jnp.mean(x ** 2) + + def f(x): + val, grad = jax.value_and_grad(g)(x) + return (jnp.atleast_1d(val), jnp.atleast_1d(grad)) + + jax.shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P('x') + )(jnp.ones(2,)) # doesn't crash + + def test_shmap_linearize_and_linearize_transpose_error(self): + mesh = jtu.create_mesh((2,), ('x',)) + + def f(x): + return jnp.mean(x ** 2) + + def m(p, t): + out_p, fwd = jax.linearize(f, p) + out_t = fwd(t) + bwd = jax.linear_transpose(fwd, p) + return bwd(out_t) + + with self.assertRaisesRegex( + ValueError, + r"applying `jax.lax.pvary\(..., \('x',\)\)` to the primal value passed"): + shard_map(partial(m, jnp.array([1.])), mesh=mesh, in_specs=P('x'), + out_specs=P('x'))(jnp.ones((2,))) # doesn't crash + + def m2(p, t): + p = jax.lax.pvary(p, 'x') # fixes the issue + out_p, fwd = jax.linearize(f, p) + out_t = fwd(t) + bwd = jax.linear_transpose(fwd, p) + return bwd(out_t) + + shard_map(partial(m2, jnp.array([1.])), mesh=mesh, in_specs=P('x'), + out_specs=P('x'))(jnp.ones((2,))) # doesn't crash + + @jtu.with_explicit_mesh((2, 2), ('x', 'y'), axis_types=(AxisType.Auto,) * 2) + def test_argmax_pvary(self, mesh): + @jax.shard_map(in_specs=P('x', 'y'), out_specs=P('x', 'y')) + def argmax_impl(x): + y = x.argmax(axis=-1, keepdims=1) + return y + + argmax_impl(jax.random.normal(jax.random.key(0), (1024, 1024))) # doesn't crash + + def test_smap(self): + mesh = jtu.create_mesh((2, 2, 2), ('x', 'y', 'z')) + np_inp = np.arange(32.).reshape(4, 8) + arr = jax.device_put(np_inp, NamedSharding(mesh, P('x', 'y'))) + + def g(x): + self.assertEqual(get_abstract_mesh().manual_axes, ('x', 'y')) + self.assertEqual(get_abstract_mesh().auto_axes, ('z',)) + self.assertEqual(x.aval.vma, {'x', 'y'}) + out = x * x + self.assertEqual(out.aval.vma, {'x', 'y'}) + return out + + def h(x): + self.assertEqual(get_abstract_mesh().manual_axes, ('x',)) + self.assertEqual(get_abstract_mesh().auto_axes, ('y', 'z')) + self.assertEqual(x.aval.vma, {'x'}) + out = smap(g, in_axes=0, out_axes=0, axis_name='y')(x) + self.assertEqual(out.aval.vma, {'x'}) + return out + + @jax.jit + def f(x): + return smap(h, in_axes=0, out_axes=0, axis_name='x')(x) + + with jax.sharding.use_mesh(mesh): + out = f(arr) + self.assertArraysEqual(out, np_inp * np_inp) + + @jtu.with_explicit_mesh((2, 2, 2), ('x', 'y', 'z')) + def test_smap_explicit(self, mesh): + np_inp = np.arange(32.).reshape(4, 8) + arr = jax.device_put(np_inp, P('x', 'y')) + + def g(x): + self.assertEqual(get_abstract_mesh().manual_axes, ('x', 'y')) + self.assertEqual(get_abstract_mesh().explicit_axes, ('z',)) + self.assertEqual(x.aval.vma, {'x', 'y'}) + out = x * x + self.assertEqual(out.aval.vma, {'x', 'y'}) + return out + + def h(x): + self.assertEqual(get_abstract_mesh().manual_axes, ('x',)) + self.assertEqual(get_abstract_mesh().explicit_axes, ('y', 'z')) + self.assertEqual(x.aval.vma, {'x'}) + out = smap(g, in_axes=0, out_axes=0, axis_name='y')(x) + self.assertEqual(out.aval.vma, {'x'}) + return out + + @jax.jit + def f(x): + return smap(h, out_axes=0, axis_name='x')(x) + + out = f(arr) + self.assertArraysEqual(out, np_inp * np_inp) + + @jtu.with_explicit_mesh((2,), ('x',), axis_types=(AxisType.Auto,)) + def test_smap_replicated(self, mesh): + @partial(smap, in_axes=None, out_axes=None, axis_name='x') + def f(x): + return x * 2 + out = f(np.arange(8)) 
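# Reader aid, not part of the original patch: as these new tests exercise it, `smap`
# behaves like a vmap-style wrapper over a single named mesh axis; in_axes/out_axes
# follow jax.vmap's convention, and None means the argument (or result) is treated
# as replicated over that axis rather than split, which is why the output just above
# comes back with spec P() while the sharded case further down gets P('data').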
+ self.assertArraysEqual(out, np.arange(8) * 2) + self.assertEqual(out.sharding, NamedSharding(mesh, P())) + + @jtu.with_explicit_mesh((2,), ('data',), axis_types=(AxisType.Auto,)) + def test_smap_replicated_sharded(self, mesh): + @partial(smap, in_axes=(None, 0), out_axes=(None, 0), axis_name='data') + def f(x, y): + return x * 2, y * 2 + + out1, out2 = f(np.arange(8), np.arange(8)) + self.assertArraysEqual(out1, np.arange(8) * 2) + self.assertEqual(out1.sharding, NamedSharding(mesh, P())) + self.assertArraysEqual(out2, np.arange(8) * 2) + self.assertEqual(out2.sharding, NamedSharding(mesh, P('data'))) + + @partial(smap, in_axes=(None, 0), out_axes=0, axis_name='data') + def g(x, y): + return x + y + + out = g(np.arange(4), np.arange(8)) + self.assertEqual(out.sharding, NamedSharding(mesh, P('data'))) + + @jtu.with_explicit_mesh((2,), ('x',), axis_types=(AxisType.Auto,)) + def test_smap_auto_error(self, mesh): + with self.assertRaisesRegex(TypeError, "in_axes was not specified"): + smap(lambda x: x * 2, out_axes=0, axis_name='x')(np.arange(4)) + + @jtu.with_explicit_mesh((2, 2), ('x', 'y'), + axis_types=(AxisType.Explicit, AxisType.Auto)) + def test_smap_auto_explicit(self, mesh): + def f(x): + self.assertEqual(x.aval.vma, {'x'}) + return x * 2 + + arr = jax.device_put(np.arange(4), P('x')) + out = jax.jit(smap(f, out_axes=0, axis_name='x'))(arr) + self.assertArraysEqual(out, np.arange(4) * 2) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x'))) + + def g(x): + self.assertEqual(x.aval.vma, {'y'}) + return x * 2 + + arr = jax.device_put(np.arange(4), P('y')) + out = jax.jit(smap(g, in_axes=0, out_axes=0, axis_name='y'))(arr) + self.assertArraysEqual(out, np.arange(4) * 2) + self.assertEqual(out.sharding, NamedSharding(mesh, P('y'))) + + @jtu.with_explicit_mesh((2, 2), ('x', 'y'), + axis_types=(AxisType.Explicit, AxisType.Auto)) + def test_smap_auto_explicit_nest(self, mesh): + def g(b): + self.assertEqual(b.aval.vma, {'x', 'y'}) + return jnp.sin(b) + + def f(a): + self.assertEqual(a.aval.vma, {'y'}) + b = a * 2 + return smap(g, in_axes=1, out_axes=1, axis_name='x')(b) + + arr = jax.device_put(np.arange(16).reshape(8, 2), P('y')) + jax.jit(smap(f, in_axes=0, out_axes=0, axis_name='y'))(arr) # doesn't crash + + @jtu.with_explicit_mesh((2, 2), ('x', 'y'), + axis_types=(AxisType.Explicit, AxisType.Auto)) + def test_smap_auto_explicit_nest_inner_none(self, mesh): + def g(b): + self.assertEqual(b.aval.vma, {'y'}) + return jnp.sin(b) + + def f(a): + self.assertEqual(a.aval.vma, {'y'}) + b = a * 2 + # Going manual over explicit axis `x` but in_axes is Infer and since + # input has no sharding, it will default to None. 
+ return smap(g, out_axes=1, axis_name='x')(b) + + arr = jax.device_put(np.arange(16).reshape(8, 2), P('y')) + jax.jit(smap(f, in_axes=0, out_axes=0, axis_name='y'))(arr) # doesn't crash + + @jtu.with_explicit_mesh((2, 2), ('x', 'y'), + axis_types=(AxisType.Explicit, AxisType.Auto)) + def test_smap_auto_explicit_nest_mesh_call_time(self, mesh): + @partial(smap, in_axes=1, out_axes=1, axis_name='x') + def g(b): + return jnp.sin(b) + + @partial(smap, in_axes=0, out_axes=0, axis_name='y') + def f(a): + self.assertEqual(a.aval.vma, {'y'}) + b = a * 2 + return g(b) + + arr = jax.device_put(np.arange(16).reshape(8, 2), P('y')) + jax.jit(f)(arr) # doesn't crash + class FunSpec(NamedTuple): name: str @@ -2937,7 +3846,8 @@ def make_mesh(mesh_shape): def test_eager_against_ref(self, fun, mesh, _, in_specs, out_specs, args, ref): mesh = self.make_mesh(mesh) args = map(jnp.array, args) - out = shard_map(fun, mesh, in_specs, out_specs)(*args) + out = shard_map(fun, mesh=mesh, in_specs=in_specs, + out_specs=out_specs)(*args) expected = ref(fun, mesh, in_specs, out_specs)(*args) self.assertAllClose(expected, out, check_dtypes=False) @@ -2946,7 +3856,8 @@ def test_eager_against_ref(self, fun, mesh, _, in_specs, out_specs, args, ref): def test_jit_against_ref(self, fun, mesh, _, in_specs, out_specs, args, ref): mesh = self.make_mesh(mesh) args = map(jnp.array, args) - out = jax.jit(shard_map(fun, mesh, in_specs, out_specs))(*args) + out = jax.jit(shard_map(fun, mesh=mesh, in_specs=in_specs, + out_specs=out_specs))(*args) expected = ref(fun, mesh, in_specs, out_specs)(*args) self.assertAllClose(expected, out, check_dtypes=False) @@ -2959,7 +3870,8 @@ def test_jit_against_ref(self, fun, mesh, _, in_specs, out_specs, args, ref): def test_grads(self, fun, mesh, jit, in_specs, out_specs, args, _, check_rep): mesh = self.make_mesh(mesh) args = map(jnp.array, args) - f = shard_map(fun, mesh, in_specs, out_specs, check_rep=check_rep) + f = shard_map(fun, mesh=mesh, in_specs=in_specs, + out_specs=out_specs, check_vma=check_rep) if jit: f = jax.jit(f) jtu.check_grads(f, args, order=2, atol=1e-2, rtol=1e-2) @@ -2990,7 +3902,7 @@ def test_vmap(self, bdims, fun, mesh, jit, in_specs, out_specs, args, ref): mesh = self.make_mesh(mesh) args = map(jnp.array, args) - f = shard_map(fun, mesh, in_specs, out_specs) + f = shard_map(fun, mesh=mesh, in_specs=in_specs, out_specs=out_specs) if jit: f = jax.jit(f) ans = jax.vmap(f, bdims)(*args) @@ -3042,7 +3954,7 @@ def g(*args): else: slices = map(jnp.stack, zip(*expected_slices)) expected = jax.tree.unflatten(treedef, slices) - tol = 1e-2 if jtu.test_device_matches(['tpu']) else None + tol = 1e-2 if jtu.test_device_matches(['gpu', 'tpu']) else None self.assertAllClose(ans, expected, check_dtypes=False, atol=tol, rtol=tol) @jtu.pytest_mark_if_available('multiaccelerator') @@ -3083,14 +3995,15 @@ def f(x): infer_sharding_from_operands=infer_sharding_from_operands, partition=partition, propagate_user_sharding=propagate_user_sharding, + sharding_rule='i -> i', ) @jax.jit def fwd(a): c = shard_map( f, - mesh, - check_rep=False, + mesh=mesh, + check_vma=False, in_specs=(P('z', ('x', 'y')),), out_specs=P('z', ('x', 'y')))(a) return c @@ -3107,14 +4020,102 @@ def g(x): @jax.jit def f(x): x = jax.lax.with_sharding_constraint(x, NamedSharding(mesh, P(('i', 'j')))) - re = shard_map(g, mesh, in_specs=P('i'), out_specs=P('i'), - check_rep=False, auto=frozenset({'j'}))(x) + re = shard_map(g, mesh=mesh, in_specs=P('i'), out_specs=P('i'), + check_vma=False, axis_names={'i'})(x) re = 
jax.lax.with_sharding_constraint(re, NamedSharding(mesh, P(('i', 'j')))) return re self.assertAllClose(f(jnp.arange(8.)), jnp.array([1., 5., 9., 13.])) +def smap_ref(f, in_axes, out_axes, axis_name, axis_size): + del axis_name # no collectives + def smapped(*args): + split_args = zip(*[split_arg(x, d, axis_size) for x, d in zip(args, in_axes)]) + split_result = [f(*xs) for xs in split_args] + return concat_result(split_result, out_axes) + return smapped + +def split_arg(x, d, axis_size): + if d is None: + x = np.tile(x, [axis_size] + [1] * (x.ndim - 1)) + return np.split(x, axis_size, d or 0) + +def concat_result(results, out_axes): + if not isinstance(results[0], (list, tuple)): + return results[0] if out_axes is None else np.concatenate(results, out_axes) + return [res[0] if d is None else np.concatenate(res, d) + for res, d in zip(zip(*results), out_axes)] + +def sample_smap() -> Chooser: + spec = yield fun_specs + mesh_shape = yield mesh_shapes + axis_names = ('i', 'j', 'k', 'l')[:len(mesh_shape)] + mesh = SimpleNamespace(shape=dict(zip(axis_names, mesh_shape)), + axis_names=axis_names) + axis_name = yield axis_names + body_in_types = yield (tys for tys in it.product(input_shapes, repeat=spec.num_inputs) + if not spec.valid_types or spec.valid_types(*tys)) + in_axes = yield from sample_in_axes(body_in_types) + out_rep = spec.out_rep(*[ax is None for ax in in_axes]) + body_out_type = jax.eval_shape(spec.fun, *body_in_types) + out_axes = yield from sample_out_axes(out_rep, body_out_type) + in_str = '(' + ','.join(jax.core.ShapedArray(t.shape, t.dtype).str_short() + for t in body_in_types) + ')' + name = f'{spec.name}_{mesh.shape}_{in_axes}_{out_axes}_{axis_name}_{in_str}' + in_types = [ty.update(shape=dilate_axis(ty.shape, d, mesh.shape[axis_name])) + for ty, d in zip(body_in_types, in_axes)] + args = [np.arange(ty.size, dtype=ty.dtype).reshape(ty.shape) / ty.size + for ty in in_types] + return name, spec, mesh.shape, in_axes, out_axes, axis_name, args + +def sample_in_axes(body_in_types) -> Chooser: + in_axes = [] + for ty in body_in_types: + in_axes.append((yield [None, *range(ty.ndim)])) + return tuple(in_axes) + +def sample_out_axes(out_rep, body_out_type) -> Chooser: + if not isinstance(body_out_type, (list, tuple)): + out_axes = yield [None] * out_rep + list(range(body_out_type.ndim)) + else: + out_axes_ = [] + for ty, r in zip(body_out_type, out_rep): + out_axes_.append((yield [None] * r + list(range(ty.ndim)))) + out_axes = tuple(out_axes_) + return out_axes + +def dilate_axis(shape: tuple[int, ...], i: int | None, size: int) -> tuple[int, ...]: + if i is None: + return shape + shp = list(shape) + shp[i] *= size + return tuple(shp) + +class SmapSystematicTest(jtu.JaxTestCase): + + @staticmethod + def make_mesh(mesh_shape): + return jtu.create_mesh(tuple(mesh_shape.values()), tuple(mesh_shape)) + + @parameterized.parameters( + sample(jtu.NUM_GENERATED_CASES.value, sample_smap)) + def test_against_ref(self, fun_spec, mesh_shape, in_axes, out_axes, axis_name, args): + fun = fun_spec.fun + mesh = self.make_mesh(mesh_shape) + args = map(jnp.array, args) + + with jax.sharding.use_mesh(mesh): + fun_ = smap(fun, in_axes=in_axes, out_axes=out_axes, axis_name=axis_name) + out = jax.jit(fun_)(*args) + + fun_ref = smap_ref(fun, in_axes=in_axes, out_axes=out_axes, axis_name=axis_name, + axis_size=mesh_shape[axis_name]) + expected = fun_ref(*args) + + self.assertAllClose(out, expected, check_dtypes=False) + + @jtu.with_config(jax_use_shardy_partitioner=True) # TODO(phawkins): enable this test 
unconditionally once shardy is the default. @unittest.skipIf(sdy is None, "shardy is not enabled") @@ -3133,7 +4134,7 @@ def test_shardy_collective_permute(self): shard_map, mesh=mesh, in_specs=(P('x', None),), out_specs=P('x', None) ) def fwd(a): - axis_size = lax.psum(1, 'x') + axis_size = lax.axis_size('x') perm = [(j, (j + 1) % axis_size) for j in range(axis_size)] return lax.ppermute(a, 'x', perm=perm) diff --git a/tests/sparse_bcoo_bcsr_test.py b/tests/sparse_bcoo_bcsr_test.py index e839bacbe5fc..f489d4551465 100644 --- a/tests/sparse_bcoo_bcsr_test.py +++ b/tests/sparse_bcoo_bcsr_test.py @@ -36,7 +36,7 @@ from jax.experimental.sparse import util as sparse_util import jax.numpy as jnp import jax.random -from jax.util import split_list +from jax._src.util import split_list import numpy as np jax.config.parse_flags_with_absl() @@ -603,7 +603,7 @@ def test_bcoo_batched_matmat_default_lowering( # with self.gpu_matmul_warning_context( # "bcoo_dot_general GPU lowering currently does not support this batch-mode computation.*"): matmat_default_lowering_fallback = sp_matmat(lhs_bcoo, rhs) - self.assertArraysEqual(matmat_expected, matmat_default_lowering_fallback) + self.assertArraysAllClose(matmat_expected, matmat_default_lowering_fallback) @jtu.run_on_devices("gpu") def test_bcoo_dot_general_oob_and_unsorted_indices_cusparse(self): @@ -974,6 +974,7 @@ def test_bcoo_spdot_general_nse(self, lhs_shape, rhs_shape): self.assertEqual(out.nse, expected_nse) @jtu.ignore_warning(message="bcoo_dot_general cusparse/hipsparse lowering not available") + @jtu.ignore_warning(category=sparse.CuSparseEfficiencyWarning) def test_bcoo_spdot_general_ad_bug(self): # Regression test for https://github.com/jax-ml/jax/issues/10163 A_indices = jnp.array([[0, 1], [0, 2], [1, 1], [1, 2], [1, 0]]) diff --git a/tests/sparse_nm_test.py b/tests/sparse_nm_test.py deleted file mode 100644 index 9ecf30eb6229..000000000000 --- a/tests/sparse_nm_test.py +++ /dev/null @@ -1,209 +0,0 @@ -# Copyright 2024 The JAX Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import math - -import numpy as np -from absl.testing import absltest -from absl.testing import parameterized - -import jax -import jax.numpy as jnp -from jax import dtypes -from jax._src import config -from jax._src import test_util as jtu -from jax.experimental.sparse import nm - -jax.config.parse_flags_with_absl() - - -class SpmmTest(jtu.JaxTestCase): - def setUp(self): - if not jtu.test_device_matches(["gpu"]): - self.skipTest("Only works on GPU") - if (jtu.test_device_matches(["cuda"]) and - not jtu.is_cuda_compute_capability_at_least("8.0")): - self.skipTest("Only works on GPUs with capability >= sm80") - super().setUp() - - # ----- Test different input shapes - @parameterized.product( - tile_m=(32, 128), - tile_n=(32, 128), - tile_k=(32, 128), - batch=(None, 5), - sparse_idx=(0, 1), - ) - @jtu.run_on_devices("gpu") - def test_shapes(self, tile_m, tile_n, tile_k, batch, sparse_idx): - # Build keyword arguments - kwargs = { - "dimension_numbers": (((1,), (1,)), (tuple(), tuple())), - "sparse_operand_idx": sparse_idx, - } - if batch: - kwargs["dimension_numbers"] = (((2,), (2,)), ((0,), (0,))) - - # Build input data - batch_dims = (batch,) if batch else tuple() - lhs = ( - (np.arange((batch or 1) * tile_m * tile_k) % 11) - .astype(dtypes.bfloat16) - .reshape(batch_dims + (tile_m, tile_k)) - ) - rhs = ( - (np.arange((batch or 1) * tile_n * tile_k) % 13) - .astype(dtypes.bfloat16) - .reshape(batch_dims + (tile_n, tile_k)) - ) - - # Build sparsity mask and metadata - sp = [lhs, rhs][sparse_idx] - mask = np.tile([True, False], math.prod(sp.shape) // 2).reshape(sp.shape) - sparse = sp[mask].reshape(sp.shape[:-1] + (sp.shape[-1] // 2,)) - meta = nm.nm_pack(mask) - - # Calculate sparse and dense dots - if sparse_idx == 0: - dot_sparse = nm.nm_spmm(sparse, rhs, meta, **kwargs) - dot_dense = jnp.einsum("...mk,...nk->...mn", (lhs * mask), rhs) - else: - dot_sparse = nm.nm_spmm(lhs, sparse, meta, **kwargs) - dot_dense = jnp.einsum("...mk,...nk->...mn", lhs, (rhs * mask)) - - # Verify the result - jtu.check_eq(dot_sparse, dot_dense.astype(dtypes.bfloat16)) - - # ----- Test different input types - @parameterized.product( - lhs_type=[jnp.int8, jnp.int16, jnp.float16, jnp.bfloat16], - rhs_type=[jnp.bfloat16], - output_type=[jnp.bfloat16, jnp.float32], - ) - @jtu.run_on_devices("gpu") - def test_types(self, lhs_type, rhs_type, output_type): - tile_m, tile_n, tile_k = 64, 32, 128 - - # Build input data - lhs = ( - (np.arange(tile_m * tile_k) % 17) - .astype(lhs_type) - .reshape((tile_m, tile_k)) - ) - rhs = ( - (np.arange(tile_k * tile_n) % 19) - .astype(rhs_type) - .reshape((tile_k, tile_n)) - ) - - # Build sparsity mask and metadata - mask = np.tile([True, False], tile_m * tile_k // 2).reshape(lhs.shape) - sparse = lhs[mask].reshape(tile_m, tile_k // 2) - meta = nm.nm_pack(mask) - - # Calculate sparse and dense dots - dot_sparse = nm.nm_spmm(sparse, rhs, meta, output_dtype=output_type) - dot_dense = (lhs * mask) @ rhs - - # Verify the result - jtu.check_close(dot_sparse, dot_dense.astype(output_type), rtol=0.01) - - # ----- Test validation - @jtu.run_on_devices("gpu") - def test_validate_nm_pack(self): - with self.assertRaisesRegex(TypeError, "Mask should be bool"): - nm.nm_pack(jnp.zeros(16, jnp.int8)) - with self.assertRaisesRegex( - TypeError, "Inner dimension size should be divisible by 16" - ): - nm.nm_pack(jnp.array([False] * 8)) - - @jtu.run_on_devices("gpu") - def test_validate_nm_spmm(self): - batch, tile_m, tile_n, tile_k = 2, 64, 32, 128 - lhs = jnp.zeros((batch, tile_m, tile_k // 2), 
dtype=jnp.bfloat16) - rhs = jnp.zeros((batch, tile_k, tile_n), dtype=jnp.bfloat16) - meta = jnp.zeros((batch, tile_m, tile_k // 16), dtype=jnp.uint16) - - if config.enable_x64.value: - with self.assertRaisesRegex(TypeError, "Unsupported lhs input type"): - nm.nm_spmm(jnp.zeros(lhs.shape, dtype=jnp.int64), rhs, meta) - with self.assertRaisesRegex(TypeError, "Unsupported rhs input type"): - nm.nm_spmm(lhs, jnp.zeros(rhs.shape, dtype=jnp.int64), meta) - with self.assertRaisesRegex(TypeError, "Unsupported output type"): - nm.nm_spmm(lhs, rhs, meta, output_dtype=jnp.int64) - - # Check dimension numbers - nm_spmm_with_dnums = lambda c, b: nm.nm_spmm( - lhs, rhs, meta, dimension_numbers=(c, b) - ) - with self.assertRaisesRegex( - TypeError, "Only single contracting dimension is supported" - ): - nm_spmm_with_dnums(((0, 2), (0, 1)), (tuple(), tuple())) - with self.assertRaisesRegex( - TypeError, "Incorrect dimension numbers for lhs" - ): - nm_spmm_with_dnums(((2,), (1,)), ((2,), (0,))) - with self.assertRaisesRegex( - TypeError, "Incorrect dimension numbers for rhs" - ): - nm_spmm_with_dnums(((2,), (1,)), ((0,), (1,))) - with self.assertRaisesRegex( - TypeError, "Only single non-contracting dimension is supported" - ): - nm_spmm_with_dnums(((2,), (1,)), (tuple(), tuple())) - with self.assertRaisesRegex( - TypeError, "Batch dimension sizes do not match" - ): - nm.nm_spmm( - lhs, - rhs.reshape(1, tile_k, tile_n * batch), - meta, - dimension_numbers=(((2,), (1,)), ((0,), (0,))), - ) - - # Check metadata - nm_spmm_with_meta = lambda m: nm.nm_spmm( - lhs, rhs, m, dimension_numbers=(((2,), (1,)), ((0,), (0,))) - ) - with self.assertRaisesRegex(TypeError, "Metadata must be uint16"): - nm_spmm_with_meta(jnp.zeros(meta.shape, dtype=jnp.uint8)) - with self.assertRaisesRegex( - TypeError, "Metadata shape must match the operand shape" - ): - nm_spmm_with_meta(meta.reshape(1, batch * tile_m, tile_k // 16)) - with self.assertRaisesRegex( - TypeError, - "Metadata must be exactly 8 times less than the contracting dimension" - " for 2:4 structured sparsity", - ): - nm_spmm_with_meta(jnp.repeat(meta, 2, axis=-1)) - with self.assertRaisesRegex( - TypeError, "Contracting dimension must be the minor one" - ): - nm.nm_spmm(lhs, rhs, meta, dimension_numbers=(((1,), (1,)), ((0,), (0,)))) - with self.assertRaisesRegex( - TypeError, "Contracting dimension sizes should have 2:4 ratio" - ): - nm.nm_spmm( - lhs, - jnp.repeat(rhs, 2, axis=1), - meta, - dimension_numbers=(((2,), (1,)), ((0,), (0,))), - ) - - -if __name__ == "__main__": - absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/sparse_test.py b/tests/sparse_test.py index eb8d70be1f05..1eeeae7c2749 100644 --- a/tests/sparse_test.py +++ b/tests/sparse_test.py @@ -16,6 +16,8 @@ from functools import partial import itertools import math +import os +from pathlib import Path from absl.testing import absltest from absl.testing import parameterized @@ -38,10 +40,19 @@ from jax._src import test_util as jtu from jax.interpreters import mlir import jax.numpy as jnp -from jax.util import split_list +from jax._src.util import split_list import numpy as np import scipy.sparse +def get_rocm_version(): + rocm_path = os.environ.get("ROCM_PATH", "/opt/rocm") + version_path = Path(rocm_path) / ".info" / "version" + if not version_path.exists(): + raise FileNotFoundError(f"Expected ROCm version file at {version_path}") + version_str = version_path.read_text().strip() + major, minor, *_ = version_str.split(".") + return int(major), int(minor) + 
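+# Illustrative use of this helper (the cuSPARSE skip checks below follow this
+# pattern, gating on ROCm releases older than 6.4, where beta == 0 mishandles
+# NaN propagation):
+#
+#   if jtu.is_device_rocm() and get_rocm_version() < (6, 4):
+#     self.skipTest("ROCm <6.4 bug: NaN propagation when beta==0")
+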
jax.config.parse_flags_with_absl() all_dtypes = jtu.dtypes.integer + jtu.dtypes.floating + jtu.dtypes.complex @@ -208,6 +219,14 @@ def test_csr_fromdense(self, shape, dtype): transpose=[True, False], ) def test_csr_matvec(self, shape, dtype, transpose): + if ( + jtu.is_device_rocm() and + get_rocm_version() < (6, 4) and + dtype in (jtu.dtypes.floating + jtu.dtypes.complex) + ): + # TODO: Remove this check when ROCm 6.4+ is the minimum supported version + self.skipTest("ROCm <6.4 bug: NaN propagation when beta==0 (fixed in ROCm 6.4.0)") + op = lambda M: M.T if transpose else M v_rng = jtu.rand_default(self.rng()) @@ -228,6 +247,14 @@ def test_csr_matvec(self, shape, dtype, transpose): transpose=[True, False], ) def test_csr_matmat(self, shape, dtype, transpose): + if ( + jtu.is_device_rocm() and + get_rocm_version() < (6, 4) and + dtype in (jtu.dtypes.floating + jtu.dtypes.complex) + ): + # TODO: Remove this check when ROCm 6.4+ is the minimum supported version + self.skipTest("ROCm <6.4 bug: NaN propagation when beta==0 (fixed in ROCm 6.4.0)") + op = lambda M: M.T if transpose else M B_rng = jtu.rand_default(self.rng()) @@ -1102,7 +1129,9 @@ def test_bcoo_to_bcsr_round_trip(self, shape, dtype, n_batch): _, bcoo_indices = sparse_bcoo._bcoo_fromdense(M, nse=nse, n_batch=n_batch, n_dense=n_dense) - bcoo_to_bcsr = partial(sparse_bcsr._bcoo_to_bcsr, shape=shape) + bcoo_to_bcsr = partial( + sparse_bcsr._bcoo_to_bcsr, shape=shape, index_dtype=bcoo_indices.dtype + ) args_maker_bcoo_to_bcsr = lambda: [bcoo_indices] self._CompileAndCheck(bcoo_to_bcsr, args_maker_bcoo_to_bcsr) @@ -1177,7 +1206,12 @@ def sparse_solve(data, indices, indptr, b): return sparse.linalg.spsolve(data, indices, indptr, b, tol, reorder) x = sparse_solve(data, indices, indptr, b) - self.assertAllClose(a @ x, b, rtol=1e-2, atol=1e-3) + self.assertAllClose( + jnp.matmul(a, x, precision=jax.lax.Precision.HIGHEST), + b, + rtol=1e-2, + atol=1e-3, + ) self._CompileAndCheck(sparse_solve, args_maker) @jtu.sample_product( diff --git a/tests/stack_test.py b/tests/stack_test.py index aa1a02793b1a..8ebfc3489ff5 100644 --- a/tests/stack_test.py +++ b/tests/stack_test.py @@ -16,7 +16,7 @@ import jax import jax.numpy as jnp -from jax._src.lax.stack import Stack +from jax._src.tpu.linalg.stack import Stack from jax._src import test_util as jtu diff --git a/tests/state_test.py b/tests/state_test.py index 60a7d8bc9f8a..9bbc68101443 100644 --- a/tests/state_test.py +++ b/tests/state_test.py @@ -28,6 +28,7 @@ from jax import lax from jax._src import core from jax._src import config +from jax._src import dtypes from jax._src import linear_util as lu from jax._src.interpreters import partial_eval as pe from jax._src import test_util as jtu @@ -36,13 +37,9 @@ import jax.numpy as jnp from jax._src.lax.control_flow import for_loop -try: - import hypothesis as hp - import hypothesis.extra.numpy as hnp - import hypothesis.strategies as hps - CAN_USE_HYPOTHESIS = True -except (ModuleNotFoundError, ImportError): - CAN_USE_HYPOTHESIS = False +import hypothesis as hp +import hypothesis.extra.numpy as hnp +import hypothesis.strategies as hps from jax._src.state.discharge import (run_state, run_state_reference, discharge_state) @@ -364,7 +361,7 @@ def body(x_ref): return [] jaxpr, _ , _, () = pe.trace_to_jaxpr_dynamic( wrap_init(body, 1), [shaped_array_ref((), jnp.int32)]) - self.assertIn("a[] <- 2", jaxpr.pretty_print(use_color=False)) + self.assertIn("a[] <- 2:i32[]", jaxpr.pretty_print(use_color=False)) def body(x_ref, val): x_ref[:, 0] = val @@ -380,7 
+377,7 @@ def body(x_ref): return [x] jaxpr, _ , _, () = pe.trace_to_jaxpr_dynamic( wrap_init(body, 1), [shaped_array_ref((), jnp.int32)]) - self.assertIn("b:i32[], a[] <- a[], 2", jaxpr.pretty_print(use_color=False)) + self.assertIn("b:i32[], a[] <- a[], 2:i32[]", jaxpr.pretty_print(use_color=False)) def body(x_ref, val): x = ref_swap(x_ref, (slice(None), 0), val) @@ -477,27 +474,17 @@ def g(r, rdot): op=[ lambda x_ref, indexer: [x_ref[indexer]], lambda x_ref, indexer: [ - ref_swap(x_ref, indexer, - jnp.ones(x_ref.shape, x_ref.dtype)[None][(0, - *indexer)])], + ref_swap(x_ref, indexer, jnp.ones_like(x_ref[indexer]))], lambda x_ref, indexer: ( - ref_addupdate(x_ref, indexer, - jnp.ones(x_ref.shape, x_ref.dtype)[None][(0, - *indexer)]) - or [jnp.ones(x_ref.shape, x_ref.dtype)[None][(0, *indexer)]]) + ref_addupdate(x_ref, indexer, jnp.ones_like(x_ref[indexer])) + or [jnp.ones_like(x_ref[indexer])]), ], ) def test_vmap(self, ref_shape, ref_bdim, idx_shape, indexed_dims, idx_bdims, out_bdim, op): - - float_ = (jnp.dtype('float64') if config.enable_x64.value else - jnp.dtype('float32')) - int_ = (jnp.dtype('int64') if config.enable_x64.value else - jnp.dtype('int32')) + intx = dtypes.canonicalize_dtype(jnp.int64) + floatx = dtypes.canonicalize_dtype(jnp.float64) axis_size = 7 - out_shape = tuple(d for d, b in zip(ref_shape, indexed_dims) if not b) - if any(indexed_dims): - out_shape = (*idx_shape, *out_shape) def maybe_insert(shape, idx): if idx is None: @@ -505,13 +492,13 @@ def maybe_insert(shape, idx): return tuple_insert(shape, idx, axis_size) batched_ref_shape = maybe_insert(ref_shape, ref_bdim) - ref_aval = shaped_array_ref(ref_shape, float_) - bat_ref_aval = shaped_array_ref(batched_ref_shape, float_) + ref_aval = shaped_array_ref(ref_shape, floatx) + bat_ref_aval = shaped_array_ref(batched_ref_shape, floatx) - idx_avals = [core.ShapedArray(idx_shape, int_) + idx_avals = [core.ShapedArray(idx_shape, intx) for _ in idx_bdims] bat_idx_avals = [ - core.ShapedArray(maybe_insert(idx_shape, idx_bdim), int_) + core.ShapedArray(maybe_insert(idx_shape, idx_bdim), intx) for idx_bdim in idx_bdims] def f(x_ref, *idxs): @@ -531,6 +518,7 @@ def f(x_ref, *idxs): wrap_init(f_batched, 1 + len(bat_idx_avals)), [bat_ref_aval, *bat_idx_avals]) jaxpr, consts = discharge_state(stateful_jaxpr, stateful_consts) discharge_of_vmap_ans = core.eval_jaxpr(jaxpr, consts, a, *idxs) + # vmap-of-discharge stateful_jaxpr, _, stateful_consts, () = pe.trace_to_jaxpr_dynamic( wrap_init(f, 1 + len(idx_avals)), [ref_aval, *idx_avals]) @@ -792,7 +780,7 @@ def body(i, st): lax.fori_loop(0, 5, body, init_val=()) return a_ref[...], b_ref[...] - ref = lambda x: AbstractRef(core.raise_to_shaped(core.get_aval(x))) + ref = lambda x: AbstractRef(core.get_aval(x)) f_jaxpr = jax.make_jaxpr(f)(ref(1.), ref(2.)) jaxpr, _ = discharge_state(f_jaxpr.jaxpr, (), should_discharge=[False, True]) # Effects on y_ref were discharged away but not the effects on x_ref @@ -806,294 +794,303 @@ def body(i, st): self.assertLen(jaxpr.outvars, 3) -if CAN_USE_HYPOTHESIS: - - def index_arrays(size, idx_shape): - valid_idx = hps.integers(min_value=-size, max_value=size - 1) - return hnp.arrays(np.int32, idx_shape, elements=valid_idx) - - Shape = tuple[int, ...] - - class IndexParam(NamedTuple): - ref_aval: shaped_array_ref - ref_shape: Shape - indexed_dims: list[bool] - idx_avals: tuple[core.ShapedArray, ...] 
- idx_shape: Shape - slice_aval: core.ShapedArray - slice_shape: Shape - - @hps.composite - def index_params(draw): - ref_shape = draw(hnp.array_shapes(max_dims=4, max_side=7), label='ref_shape') - indexed_dims = draw(hps.lists(hps.booleans(), - min_size=len(ref_shape), - max_size=len(ref_shape))) - idx_shape = draw(hnp.array_shapes(max_dims=3, max_side=5)) - if any(indexed_dims): - sliced_shape = (s for s, b in zip(ref_shape, indexed_dims) if not b) +def index_arrays(size, idx_shape): + valid_idx = hps.integers(min_value=-size, max_value=size - 1) + return hnp.arrays(np.int32, idx_shape, elements=valid_idx) + +Shape = tuple[int, ...] + +class IndexParam(NamedTuple): + ref_aval: shaped_array_ref + ref_shape: Shape + indexed_dims: list[bool] + idx_avals: tuple[core.ShapedArray, ...] + idx_shape: Shape + slice_aval: core.ShapedArray + slice_shape: Shape + +@hps.composite +def index_params(draw): + ref_shape = draw(hnp.array_shapes(max_dims=4, max_side=7), label='ref_shape') + indexed_dims = draw(hps.lists(hps.booleans(), + min_size=len(ref_shape), + max_size=len(ref_shape))) + idx_shape = draw(hnp.array_shapes(max_dims=3, max_side=5)) + if not any(indexed_dims): + slice_shape = ref_shape + else: + sliced_shape = tuple(s for s, b in zip(ref_shape, indexed_dims) if not b) + int_indexers_contiguous = bool( + np.all(np.diff(np.where(indexed_dims)[0]) == 1) + ) + if not int_indexers_contiguous: slice_shape = (*idx_shape, *sliced_shape) else: - slice_shape = ref_shape - ref_aval = shaped_array_ref(ref_shape, np.float32) - idx_avals = tuple(core.ShapedArray(idx_shape, np.int32) for _ in - range(sum(indexed_dims))) - slice_aval = core.ShapedArray(slice_shape, np.float32) - return IndexParam(ref_aval, ref_shape, indexed_dims, idx_avals, idx_shape, - slice_aval, slice_shape) - - class VmappableIndexParam(NamedTuple): - index_param: IndexParam - ref_bdim: int | None - non_slice_idx_bdims: tuple[int | None, ...] - slice_bdim: int - bat_ref_aval: shaped_array_ref - bat_ref_shape: Shape - bat_non_slice_idx_avals: tuple[core.ShapedArray, ...] - bat_non_slice_idx_shapes: tuple[Shape, ...] - bat_slice_aval: core.ShapedArray - bat_slice_shape: Shape - - def maybe_tuple_insert(t: tuple[Any, ...], idx: int | None, - val: Any) -> tuple[Any, ...]: - if idx is None: - return t - return tuple_insert(t, idx, val) - - @hps.composite - def vmappable_index_params(draw, *, op_type: str): - axis_size = draw(hps.integers(min_value=1, max_value=7), label='axis_size') - index_param: IndexParam = draw(index_params()) - non_slice_idx_bdims = tuple( - draw(hps.one_of( - hps.none(), - hps.integers(min_value=0, max_value=len(index_param.idx_shape)))) - for b in index_param.indexed_dims if b) - bat_non_slice_idx_shapes = tuple( - maybe_tuple_insert(index_param.idx_shape, idx_bdim, axis_size) - for idx_bdim in non_slice_idx_bdims) - if op_type == "swap": - # In a swap, the ref *must* be batched - ref_bdim = draw(hps.integers(min_value=0, - max_value=len(index_param.ref_shape))) - if any(idx_bdim is not None for idx_bdim in non_slice_idx_bdims): - # If it's a swap, if indices are batched, val must be batched. 
- slice_bdim = draw(hps.integers( - min_value=0, max_value=len(index_param.slice_shape))) - else: - slice_bdim = draw(hps.one_of(hps.none(), hps.integers( - min_value=0, max_value=len(index_param.slice_shape)))) - elif op_type == "get": - # In a get, the indices must be batched or ref is batched - if all(idx_bdim is None for idx_bdim in non_slice_idx_bdims): - ref_bdim = draw(hps.integers(min_value=0, - max_value=len(index_param.ref_shape))) - else: - ref_bdim = draw(hps.one_of(hps.none(), - hps.integers(min_value=0, max_value=len(index_param.ref_shape)))) + insert_pos = indexed_dims.index(True) + slice_shape = ( + *sliced_shape[:insert_pos], + *idx_shape, + *sliced_shape[insert_pos:], + ) + ref_aval = shaped_array_ref(ref_shape, np.float32) + idx_avals = tuple(core.ShapedArray(idx_shape, np.int32) for _ in + range(sum(indexed_dims))) + slice_aval = core.ShapedArray(slice_shape, np.float32) + return IndexParam(ref_aval, ref_shape, indexed_dims, idx_avals, idx_shape, + slice_aval, slice_shape) + +class VmappableIndexParam(NamedTuple): + index_param: IndexParam + ref_bdim: int | None + non_slice_idx_bdims: tuple[int | None, ...] + slice_bdim: int + bat_ref_aval: shaped_array_ref + bat_ref_shape: Shape + bat_non_slice_idx_avals: tuple[core.ShapedArray, ...] + bat_non_slice_idx_shapes: tuple[Shape, ...] + bat_slice_aval: core.ShapedArray + bat_slice_shape: Shape + +def maybe_tuple_insert(t: tuple[Any, ...], idx: int | None, + val: Any) -> tuple[Any, ...]: + if idx is None: + return t + return tuple_insert(t, idx, val) + +@hps.composite +def vmappable_index_params(draw, *, op_type: str): + axis_size = draw(hps.integers(min_value=1, max_value=7), label='axis_size') + index_param: IndexParam = draw(index_params()) + non_slice_idx_bdims = tuple( + draw(hps.one_of( + hps.none(), + hps.integers(min_value=0, max_value=len(index_param.idx_shape)))) + for b in index_param.indexed_dims if b) + bat_non_slice_idx_shapes = tuple( + maybe_tuple_insert(index_param.idx_shape, idx_bdim, axis_size) + for idx_bdim in non_slice_idx_bdims) + if op_type == "swap": + # In a swap, the ref *must* be batched + ref_bdim = draw(hps.integers(min_value=0, + max_value=len(index_param.ref_shape))) + if any(idx_bdim is not None for idx_bdim in non_slice_idx_bdims): + # If it's a swap, if indices are batched, val must be batched. 
slice_bdim = draw(hps.integers( min_value=0, max_value=len(index_param.slice_shape))) + else: + slice_bdim = draw(hps.one_of(hps.none(), hps.integers( + min_value=0, max_value=len(index_param.slice_shape)))) + elif op_type == "get": + # In a get, the indices must be batched or ref is batched + if all(idx_bdim is None for idx_bdim in non_slice_idx_bdims): + ref_bdim = draw(hps.integers(min_value=0, + max_value=len(index_param.ref_shape))) + else: + ref_bdim = draw(hps.one_of(hps.none(), + hps.integers(min_value=0, max_value=len(index_param.ref_shape)))) + slice_bdim = draw(hps.integers( + min_value=0, max_value=len(index_param.slice_shape))) + + bat_ref_shape = maybe_tuple_insert(index_param.ref_shape, ref_bdim, axis_size) + bat_ref_aval = shaped_array_ref(bat_ref_shape, np.float32) + bat_non_slice_idx_avals = tuple( + core.ShapedArray(shape, np.int32) for shape in bat_non_slice_idx_shapes) + bat_slice_shape = maybe_tuple_insert(index_param.slice_shape, slice_bdim, axis_size) + bat_slice_aval = core.ShapedArray(bat_slice_shape, np.float32) + return VmappableIndexParam(index_param, ref_bdim, non_slice_idx_bdims, + slice_bdim, bat_ref_aval, bat_ref_shape, + bat_non_slice_idx_avals, bat_non_slice_idx_shapes, + bat_slice_aval, bat_slice_shape) + +class GetVmapParams(NamedTuple): + vmap_index_param: VmappableIndexParam + bat_ref: np.ndarray + bat_idxs: tuple[np.ndarray, ...] + +@hps.composite +def get_vmap_params(draw): + vmap_index_param: VmappableIndexParam = draw( + vmappable_index_params(op_type="get")) + bat_ref = draw(hnp.arrays(np.float32, vmap_index_param.bat_ref_shape)) + bat_idx_shapes_ = iter(vmap_index_param.bat_non_slice_idx_shapes) + bat_idxs = tuple( + draw(index_arrays(size, next(bat_idx_shapes_))) + for size, indexed in zip( + vmap_index_param.index_param.ref_shape, + vmap_index_param.index_param.indexed_dims) + if indexed) + assert next(bat_idx_shapes_, None) is None + return GetVmapParams(vmap_index_param, bat_ref, bat_idxs) + +class SetVmapParams(NamedTuple): + vmap_index_param: VmappableIndexParam + bat_ref: np.ndarray + bat_val: np.ndarray + bat_idxs: tuple[np.ndarray, ...] 
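+  # Like GetVmapParams, but also carries the (possibly batched) value that is
+  # stored into the ref at the drawn indices.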
+ +@hps.composite +def set_vmap_params(draw): + vmap_index_param: VmappableIndexParam = draw(vmappable_index_params( + op_type="swap")) + bat_ref = draw(hnp.arrays(np.float32, vmap_index_param.bat_ref_shape)) + bat_idx_shapes_ = iter(vmap_index_param.bat_non_slice_idx_shapes) + bat_idxs = tuple( + draw(index_arrays(size, next(bat_idx_shapes_))) + for size, indexed in zip( + vmap_index_param.index_param.ref_shape, + vmap_index_param.index_param.indexed_dims) + if indexed) + assert next(bat_idx_shapes_, None) is None + bat_val = draw(hnp.arrays(np.float32, vmap_index_param.bat_slice_shape)) + return SetVmapParams(vmap_index_param, bat_ref, bat_val, bat_idxs) + +Indexer = tuple[Union[int, slice, np.ndarray]] + +def _unpack_idx(idx: Indexer + ) -> tuple[Sequence[int | np.ndarray], Sequence[bool]]: + indexed_dims = [type(i) != slice for i in idx] + non_slice_idx = [i for i, b in zip(idx, indexed_dims) if b] + return non_slice_idx, indexed_dims + +def _pack_idx(non_slice_idx: Sequence[int | np.ndarray], + indexed_dims: Sequence[bool]) -> Indexer: + idx_ = iter(non_slice_idx) + idx = tuple(next(idx_) if b else slice(None) for b in indexed_dims) + assert next(idx_, None) is None + return idx + +@jtu.thread_unsafe_test_class() # hypothesis isn't thread-safe +class StateHypothesisTest(jtu.JaxTestCase): + + @hp.given(get_vmap_params()) + @hp.settings(deadline=None, print_blob=True, + max_examples=jtu.NUM_GENERATED_CASES.value) + def test_get_vmap(self, get_vmap_param: GetVmapParams): + + indexed_dims = get_vmap_param.vmap_index_param.index_param.indexed_dims + + def f(ref, *non_slice_idx): + idx = _pack_idx(non_slice_idx, indexed_dims) + return [ref_get(ref, idx)] + ref_aval = get_vmap_param.vmap_index_param.index_param.ref_aval + bat_ref_aval = get_vmap_param.vmap_index_param.bat_ref_aval + bat_non_slice_idx_avals = get_vmap_param.vmap_index_param.bat_non_slice_idx_avals + ref_bdim = get_vmap_param.vmap_index_param.ref_bdim + idx_bdims = get_vmap_param.vmap_index_param.non_slice_idx_bdims + out_bdim = get_vmap_param.vmap_index_param.slice_bdim + non_slice_idx = get_vmap_param.bat_idxs + idx_avals = get_vmap_param.vmap_index_param.index_param.idx_avals + ref = get_vmap_param.bat_ref + + f_batched = jax.vmap(f, in_axes=(ref_bdim, *idx_bdims), out_axes=[out_bdim]) + stateful_jaxpr, _, stateful_consts, () = pe.trace_to_jaxpr_dynamic( + wrap_init(f_batched, 1 + len(bat_non_slice_idx_avals)), + [bat_ref_aval, *bat_non_slice_idx_avals]) + jaxpr, consts = discharge_state(stateful_jaxpr, stateful_consts) + discharge_of_vmap_ans = core.eval_jaxpr(jaxpr, consts, ref, *non_slice_idx) + + # vmap-of-discharge + stateful_jaxpr, _, stateful_consts, () = pe.trace_to_jaxpr_dynamic( + wrap_init(f, 1 + len(idx_avals)), [ref_aval, *idx_avals]) + jaxpr_, consts_ = discharge_state(stateful_jaxpr, stateful_consts) + f_batched = jax.vmap(partial(core.eval_jaxpr, jaxpr_, consts_), + in_axes=(ref_bdim, *idx_bdims), + out_axes=[out_bdim, ref_bdim]) + vmap_of_discharge_ans = f_batched(ref, *non_slice_idx) - bat_ref_shape = maybe_tuple_insert(index_param.ref_shape, ref_bdim, axis_size) - bat_ref_aval = shaped_array_ref(bat_ref_shape, np.float32) - bat_non_slice_idx_avals = tuple( - core.ShapedArray(shape, np.int32) for shape in bat_non_slice_idx_shapes) - bat_slice_shape = maybe_tuple_insert(index_param.slice_shape, slice_bdim, axis_size) - bat_slice_aval = core.ShapedArray(bat_slice_shape, np.float32) - return VmappableIndexParam(index_param, ref_bdim, non_slice_idx_bdims, - slice_bdim, bat_ref_aval, bat_ref_shape, - 
bat_non_slice_idx_avals, bat_non_slice_idx_shapes, - bat_slice_aval, bat_slice_shape) - - class GetVmapParams(NamedTuple): - vmap_index_param: VmappableIndexParam - bat_ref: np.ndarray - bat_idxs: tuple[np.ndarray, ...] - - @hps.composite - def get_vmap_params(draw): - vmap_index_param: VmappableIndexParam = draw( - vmappable_index_params(op_type="get")) - bat_ref = draw(hnp.arrays(np.float32, vmap_index_param.bat_ref_shape)) - bat_idx_shapes_ = iter(vmap_index_param.bat_non_slice_idx_shapes) - bat_idxs = tuple( - draw(index_arrays(size, next(bat_idx_shapes_))) - for size, indexed in zip( - vmap_index_param.index_param.ref_shape, - vmap_index_param.index_param.indexed_dims) - if indexed) - assert next(bat_idx_shapes_, None) is None - return GetVmapParams(vmap_index_param, bat_ref, bat_idxs) - - class SetVmapParams(NamedTuple): - vmap_index_param: VmappableIndexParam - bat_ref: np.ndarray - bat_val: np.ndarray - bat_idxs: tuple[np.ndarray, ...] - - @hps.composite - def set_vmap_params(draw): - vmap_index_param: VmappableIndexParam = draw(vmappable_index_params( - op_type="swap")) - bat_ref = draw(hnp.arrays(np.float32, vmap_index_param.bat_ref_shape)) - bat_idx_shapes_ = iter(vmap_index_param.bat_non_slice_idx_shapes) - bat_idxs = tuple( - draw(index_arrays(size, next(bat_idx_shapes_))) - for size, indexed in zip( - vmap_index_param.index_param.ref_shape, - vmap_index_param.index_param.indexed_dims) - if indexed) - assert next(bat_idx_shapes_, None) is None - bat_val = draw(hnp.arrays(np.float32, vmap_index_param.bat_slice_shape)) - return SetVmapParams(vmap_index_param, bat_ref, bat_val, bat_idxs) - - Indexer = tuple[Union[int, slice, np.ndarray]] - - def _unpack_idx(idx: Indexer - ) -> tuple[Sequence[int | np.ndarray], Sequence[bool]]: - indexed_dims = [type(i) != slice for i in idx] - non_slice_idx = [i for i, b in zip(idx, indexed_dims) if b] - return non_slice_idx, indexed_dims - - def _pack_idx(non_slice_idx: Sequence[int | np.ndarray], - indexed_dims: Sequence[bool]) -> Indexer: - idx_ = iter(non_slice_idx) - idx = tuple(next(idx_) if b else slice(None) for b in indexed_dims) - assert next(idx_, None) is None - return idx - - @jtu.thread_unsafe_test_class() # hypothesis isn't thread-safe - class StateHypothesisTest(jtu.JaxTestCase): - - @hp.given(get_vmap_params()) - @hp.settings(deadline=None, print_blob=True, - max_examples=jtu.NUM_GENERATED_CASES.value) - def test_get_vmap(self, get_vmap_param: GetVmapParams): - - indexed_dims = get_vmap_param.vmap_index_param.index_param.indexed_dims - - def f(ref, *non_slice_idx): - idx = _pack_idx(non_slice_idx, indexed_dims) - return [ref_get(ref, idx)] - ref_aval = get_vmap_param.vmap_index_param.index_param.ref_aval - bat_ref_aval = get_vmap_param.vmap_index_param.bat_ref_aval - bat_non_slice_idx_avals = get_vmap_param.vmap_index_param.bat_non_slice_idx_avals - ref_bdim = get_vmap_param.vmap_index_param.ref_bdim - idx_bdims = get_vmap_param.vmap_index_param.non_slice_idx_bdims - out_bdim = get_vmap_param.vmap_index_param.slice_bdim - non_slice_idx = get_vmap_param.bat_idxs - idx_avals = get_vmap_param.vmap_index_param.index_param.idx_avals - ref = get_vmap_param.bat_ref - - f_batched = jax.vmap(f, in_axes=(ref_bdim, *idx_bdims), out_axes=[out_bdim]) - stateful_jaxpr, _, stateful_consts, () = pe.trace_to_jaxpr_dynamic( - wrap_init(f_batched, 1 + len(bat_non_slice_idx_avals)), - [bat_ref_aval, *bat_non_slice_idx_avals]) - jaxpr, consts = discharge_state(stateful_jaxpr, stateful_consts) - discharge_of_vmap_ans = core.eval_jaxpr(jaxpr, consts, 
ref, *non_slice_idx) - - # vmap-of-discharge - stateful_jaxpr, _, stateful_consts, () = pe.trace_to_jaxpr_dynamic( - wrap_init(f, 1 + len(idx_avals)), [ref_aval, *idx_avals]) - jaxpr_, consts_ = discharge_state(stateful_jaxpr, stateful_consts) - f_batched = jax.vmap(partial(core.eval_jaxpr, jaxpr_, consts_), - in_axes=(ref_bdim, *idx_bdims), - out_axes=[out_bdim, ref_bdim]) - vmap_of_discharge_ans = f_batched(ref, *non_slice_idx) - - self.assertAllClose(discharge_of_vmap_ans, vmap_of_discharge_ans, - check_dtypes=False) - - - @hp.given(set_vmap_params()) - @hp.settings(deadline=None, print_blob=True, - max_examples=jtu.NUM_GENERATED_CASES.value) - def test_set_vmap(self, set_vmap_param: SetVmapParams): - if jtu.test_device_matches(["gpu"]): - self.skipTest("Scatter is nondeterministic on GPU") - indexed_dims = set_vmap_param.vmap_index_param.index_param.indexed_dims - - def f(ref, val, *non_slice_idx): - idx = _pack_idx(non_slice_idx, indexed_dims) - ref_set(ref, idx, val) - return [] - ref_aval = set_vmap_param.vmap_index_param.index_param.ref_aval - bat_ref_aval = set_vmap_param.vmap_index_param.bat_ref_aval - bat_non_slice_idx_avals = set_vmap_param.vmap_index_param.bat_non_slice_idx_avals - ref_bdim = set_vmap_param.vmap_index_param.ref_bdim - idx_bdims = set_vmap_param.vmap_index_param.non_slice_idx_bdims - non_slice_idx = set_vmap_param.bat_idxs - idx_avals = set_vmap_param.vmap_index_param.index_param.idx_avals - ref = set_vmap_param.bat_ref - val = set_vmap_param.bat_val - bat_val_aval = set_vmap_param.vmap_index_param.bat_slice_aval - val_aval = set_vmap_param.vmap_index_param.index_param.slice_aval - val_bdim = set_vmap_param.vmap_index_param.slice_bdim - - f_batched = jax.vmap(f, in_axes=(ref_bdim, val_bdim, *idx_bdims), - out_axes=[]) - stateful_jaxpr, _, stateful_consts, () = pe.trace_to_jaxpr_dynamic( - wrap_init(f_batched, 2 + len(bat_non_slice_idx_avals)), - [bat_ref_aval, bat_val_aval, *bat_non_slice_idx_avals]) - jaxpr, consts = discharge_state(stateful_jaxpr, stateful_consts) - discharge_of_vmap_ans = core.eval_jaxpr(jaxpr, consts, ref, val, *non_slice_idx) - - # vmap-of-discharge - stateful_jaxpr, _, stateful_consts, () = pe.trace_to_jaxpr_dynamic( - wrap_init(f, 2 + len(idx_avals)), [ref_aval, val_aval, *idx_avals]) - jaxpr_, consts_ = discharge_state(stateful_jaxpr, stateful_consts) - f_batched = jax.vmap(partial(core.eval_jaxpr, jaxpr_, consts_), - in_axes=(ref_bdim, val_bdim, *idx_bdims), - out_axes=[ref_bdim]) - vmap_of_discharge_ans = f_batched(ref, val, *non_slice_idx) - - self.assertAllClose(discharge_of_vmap_ans, vmap_of_discharge_ans, - check_dtypes=False) - - - @hp.given(set_vmap_params()) - @hp.settings(deadline=None, print_blob=True, - max_examples=jtu.NUM_GENERATED_CASES.value) - def test_addupdate_vmap(self, set_vmap_param: SetVmapParams): - - indexed_dims = set_vmap_param.vmap_index_param.index_param.indexed_dims - - def f(ref, val, *non_slice_idx): - idx = _pack_idx(non_slice_idx, indexed_dims) - ref_addupdate(ref, idx, val) - return [] - ref_aval = set_vmap_param.vmap_index_param.index_param.ref_aval - bat_ref_aval = set_vmap_param.vmap_index_param.bat_ref_aval - bat_non_slice_idx_avals = set_vmap_param.vmap_index_param.bat_non_slice_idx_avals - ref_bdim = set_vmap_param.vmap_index_param.ref_bdim - idx_bdims = set_vmap_param.vmap_index_param.non_slice_idx_bdims - non_slice_idx = set_vmap_param.bat_idxs - idx_avals = set_vmap_param.vmap_index_param.index_param.idx_avals - ref = set_vmap_param.bat_ref - val = set_vmap_param.bat_val - bat_val_aval = 
set_vmap_param.vmap_index_param.bat_slice_aval - val_aval = set_vmap_param.vmap_index_param.index_param.slice_aval - val_bdim = set_vmap_param.vmap_index_param.slice_bdim - - f_batched = jax.vmap(f, in_axes=(ref_bdim, val_bdim, *idx_bdims), - out_axes=[]) - stateful_jaxpr, _, stateful_consts, () = pe.trace_to_jaxpr_dynamic( - wrap_init(f_batched, 2 + len(bat_non_slice_idx_avals)), - [bat_ref_aval, bat_val_aval, *bat_non_slice_idx_avals]) - jaxpr, consts = discharge_state(stateful_jaxpr, stateful_consts) - discharge_of_vmap_ans = core.eval_jaxpr(jaxpr, consts, ref, val, *non_slice_idx) - - # vmap-of-discharge - stateful_jaxpr, _, stateful_consts, () = pe.trace_to_jaxpr_dynamic( - wrap_init(f, 2 + len(idx_avals)), [ref_aval, val_aval, *idx_avals]) - jaxpr_, consts_ = discharge_state(stateful_jaxpr, stateful_consts) - f_batched = jax.vmap(partial(core.eval_jaxpr, jaxpr_, consts_), - in_axes=(ref_bdim, val_bdim, *idx_bdims), - out_axes=[ref_bdim]) - vmap_of_discharge_ans = f_batched(ref, val, *non_slice_idx) - - self.assertAllClose(discharge_of_vmap_ans, vmap_of_discharge_ans, - check_dtypes=False) + self.assertAllClose(discharge_of_vmap_ans, vmap_of_discharge_ans, + check_dtypes=False) + + + @hp.given(set_vmap_params()) + @hp.settings(deadline=None, print_blob=True, + max_examples=jtu.NUM_GENERATED_CASES.value) + def test_set_vmap(self, set_vmap_param: SetVmapParams): + if jtu.test_device_matches(["gpu"]): + self.skipTest("Scatter is nondeterministic on GPU") + indexed_dims = set_vmap_param.vmap_index_param.index_param.indexed_dims + + def f(ref, val, *non_slice_idx): + idx = _pack_idx(non_slice_idx, indexed_dims) + ref_set(ref, idx, val) + return [] + ref_aval = set_vmap_param.vmap_index_param.index_param.ref_aval + bat_ref_aval = set_vmap_param.vmap_index_param.bat_ref_aval + bat_non_slice_idx_avals = set_vmap_param.vmap_index_param.bat_non_slice_idx_avals + ref_bdim = set_vmap_param.vmap_index_param.ref_bdim + idx_bdims = set_vmap_param.vmap_index_param.non_slice_idx_bdims + non_slice_idx = set_vmap_param.bat_idxs + idx_avals = set_vmap_param.vmap_index_param.index_param.idx_avals + ref = set_vmap_param.bat_ref + val = set_vmap_param.bat_val + bat_val_aval = set_vmap_param.vmap_index_param.bat_slice_aval + val_aval = set_vmap_param.vmap_index_param.index_param.slice_aval + val_bdim = set_vmap_param.vmap_index_param.slice_bdim + + f_batched = jax.vmap(f, in_axes=(ref_bdim, val_bdim, *idx_bdims), + out_axes=[]) + stateful_jaxpr, _, stateful_consts, () = pe.trace_to_jaxpr_dynamic( + wrap_init(f_batched, 2 + len(bat_non_slice_idx_avals)), + [bat_ref_aval, bat_val_aval, *bat_non_slice_idx_avals]) + jaxpr, consts = discharge_state(stateful_jaxpr, stateful_consts) + discharge_of_vmap_ans = core.eval_jaxpr(jaxpr, consts, ref, val, *non_slice_idx) + + # vmap-of-discharge + stateful_jaxpr, _, stateful_consts, () = pe.trace_to_jaxpr_dynamic( + wrap_init(f, 2 + len(idx_avals)), [ref_aval, val_aval, *idx_avals]) + jaxpr_, consts_ = discharge_state(stateful_jaxpr, stateful_consts) + f_batched = jax.vmap(partial(core.eval_jaxpr, jaxpr_, consts_), + in_axes=(ref_bdim, val_bdim, *idx_bdims), + out_axes=[ref_bdim]) + vmap_of_discharge_ans = f_batched(ref, val, *non_slice_idx) + + self.assertAllClose(discharge_of_vmap_ans, vmap_of_discharge_ans, + check_dtypes=False) + + + @hp.given(set_vmap_params()) + @hp.settings(deadline=None, print_blob=True, + max_examples=jtu.NUM_GENERATED_CASES.value) + def test_addupdate_vmap(self, set_vmap_param: SetVmapParams): + + indexed_dims = 
set_vmap_param.vmap_index_param.index_param.indexed_dims + + def f(ref, val, *non_slice_idx): + idx = _pack_idx(non_slice_idx, indexed_dims) + ref_addupdate(ref, idx, val) + return [] + ref_aval = set_vmap_param.vmap_index_param.index_param.ref_aval + bat_ref_aval = set_vmap_param.vmap_index_param.bat_ref_aval + bat_non_slice_idx_avals = set_vmap_param.vmap_index_param.bat_non_slice_idx_avals + ref_bdim = set_vmap_param.vmap_index_param.ref_bdim + idx_bdims = set_vmap_param.vmap_index_param.non_slice_idx_bdims + non_slice_idx = set_vmap_param.bat_idxs + idx_avals = set_vmap_param.vmap_index_param.index_param.idx_avals + ref = set_vmap_param.bat_ref + val = set_vmap_param.bat_val + bat_val_aval = set_vmap_param.vmap_index_param.bat_slice_aval + val_aval = set_vmap_param.vmap_index_param.index_param.slice_aval + val_bdim = set_vmap_param.vmap_index_param.slice_bdim + + f_batched = jax.vmap(f, in_axes=(ref_bdim, val_bdim, *idx_bdims), + out_axes=[]) + stateful_jaxpr, _, stateful_consts, () = pe.trace_to_jaxpr_dynamic( + wrap_init(f_batched, 2 + len(bat_non_slice_idx_avals)), + [bat_ref_aval, bat_val_aval, *bat_non_slice_idx_avals]) + jaxpr, consts = discharge_state(stateful_jaxpr, stateful_consts) + discharge_of_vmap_ans = core.eval_jaxpr(jaxpr, consts, ref, val, *non_slice_idx) + + # vmap-of-discharge + stateful_jaxpr, _, stateful_consts, () = pe.trace_to_jaxpr_dynamic( + wrap_init(f, 2 + len(idx_avals)), [ref_aval, val_aval, *idx_avals]) + jaxpr_, consts_ = discharge_state(stateful_jaxpr, stateful_consts) + f_batched = jax.vmap(partial(core.eval_jaxpr, jaxpr_, consts_), + in_axes=(ref_bdim, val_bdim, *idx_bdims), + out_axes=[ref_bdim]) + vmap_of_discharge_ans = f_batched(ref, val, *non_slice_idx) + + self.assertAllClose(discharge_of_vmap_ans, vmap_of_discharge_ans, + check_dtypes=False) class StateControlFlowTest(jtu.JaxTestCase): @@ -1139,7 +1136,7 @@ def false_fun(): y_ref[...] = 2. lax.cond(pred, true_fun, false_fun) return x_ref[...], y_ref[...] - ref = lambda x: AbstractRef(core.raise_to_shaped(core.get_aval(x))) + ref = lambda x: AbstractRef(core.get_aval(x)) f_jaxpr = jax.make_jaxpr(f0)(False, ref(3.), ref(4.)) jaxpr, _ = discharge_state(f_jaxpr.jaxpr, (), should_discharge=[False, False, True]) # Effects on y_ref were discharged away but not the effects on x_ref @@ -1631,216 +1628,218 @@ def _body(ref): jtu.check_grads(f, (0.5,), order=3) -if CAN_USE_HYPOTHESIS: - - class FuncSpec(NamedTuple): - fun: Callable[..., Any] - name: str - min_rank: int = 0 - max_rank: int = 4 - min_dim: int = 0 - max_dim: int = 4 - - def call(self, *args): - return run_state(self.fun)(*args) - - def ref(self, *args): - return run_state_reference(self.fun)(*args) - - def sin_stateful(refs): - x_ref, y_ref = refs - y_ref[...] = jnp.sin(x_ref[...]) - - sin_spec = FuncSpec(sin_stateful, "sin") - - def cos_stateful(refs): - x_ref, y_ref = refs - y_ref[...] = jnp.cos(x_ref[...]) - - cos_spec = FuncSpec(cos_stateful, "cos") - - def mul2_stateful(refs): - x_ref, y_ref = refs - y_ref[...] = x_ref[...] - y_ref[...] = y_ref[...] + x_ref[...] - - mul2_spec = FuncSpec(mul2_stateful, "mul2") - - def mul2_stateful_with_constant(refs): +class FuncSpec(NamedTuple): + fun: Callable[..., Any] + name: str + min_rank: int = 0 + max_rank: int = 4 + min_dim: int = 0 + max_dim: int = 4 + + def call(self, *args): + return run_state(self.fun)(*args) + + def ref(self, *args): + return run_state_reference(self.fun)(*args) + +def sin_stateful(refs): + x_ref, y_ref = refs + y_ref[...] 
= jnp.sin(x_ref[...]) + +sin_spec = FuncSpec(sin_stateful, "sin") + +def cos_stateful(refs): + x_ref, y_ref = refs + y_ref[...] = jnp.cos(x_ref[...]) + +cos_spec = FuncSpec(cos_stateful, "cos") + +def mul2_stateful(refs): + x_ref, y_ref = refs + y_ref[...] = x_ref[...] + y_ref[...] = y_ref[...] + x_ref[...] + +mul2_spec = FuncSpec(mul2_stateful, "mul2") + +def mul2_stateful_with_constant(refs): + x_ref, y_ref = refs + y_ref[...] = (2. * np.ones(x_ref.shape, x_ref.dtype)) * x_ref[...] + +mul2_constant_spec = FuncSpec(mul2_stateful_with_constant, "mul2_c") + +def crazy_identity_stateful(refs): + x_ref, y_ref = refs + x = x_ref[...] + x_ref[...] = (x + x) / 2 + y_ref[...] = x_ref[...] + y = y_ref[...] + y_ref[...] = (y + y) / 2 + +crazy_identity_spec = FuncSpec(crazy_identity_stateful, "id") + +def func_spec(depth: int = 4): + raw_specs = hps.sampled_from([sin_spec, cos_spec, mul2_spec, + mul2_constant_spec, crazy_identity_spec]) + if depth > 0: + return hps.one_of([raw_specs, nest_spec(depth - 1), add_spec(depth - 1), + compose_spec(depth - 1)]) + return raw_specs + +@hps.composite +def compose_spec(draw, depth): + f1 = draw(func_spec(depth)) + f2 = draw(func_spec(depth)) + def wrapped_impl(*args): + f1.fun(*args) + f2.fun(*args) + return FuncSpec(wrapped_impl, + f"({f2.name} . {f1.name})", + min_rank=max(f1.min_rank, f2.min_rank), + max_rank=min(f1.max_rank, f2.max_rank), + min_dim=max(f1.min_dim, f2.min_dim), + max_dim=min(f1.max_dim, f2.max_dim)) + +@hps.composite +def nest_spec(draw, depth): + f = draw(func_spec(depth)) + def wrapped_impl(refs): x_ref, y_ref = refs - y_ref[...] = (2. * np.ones(x_ref.shape, x_ref.dtype)) * x_ref[...] - - mul2_constant_spec = FuncSpec(mul2_stateful_with_constant, "mul2_c") - - def crazy_identity_stateful(refs): + x, y = x_ref[...], y_ref[...] + x, y = run_state(f.fun)((x, y)) + x_ref[...], y_ref[...] = x, y + return FuncSpec(wrapped_impl, + f"nest({f.name})", + min_rank=f.min_rank, + max_rank=f.max_rank, + min_dim=f.min_dim, + max_dim=f.max_dim) + + +@hps.composite +def add_spec(draw, depth): + f1 = draw(func_spec(depth)) + f2 = draw(func_spec(depth)) + def wrapped_impl(refs): x_ref, y_ref = refs - x = x_ref[...] - x_ref[...] = (x + x) / 2 - y_ref[...] = x_ref[...] - y = y_ref[...] - y_ref[...] = (y + y) / 2 - - crazy_identity_spec = FuncSpec(crazy_identity_stateful, "id") - - def func_spec(depth: int = 4): - raw_specs = hps.sampled_from([sin_spec, cos_spec, mul2_spec, - mul2_constant_spec, crazy_identity_spec]) - if depth > 0: - return hps.one_of([raw_specs, nest_spec(depth - 1), add_spec(depth - 1), - compose_spec(depth - 1)]) - return raw_specs - - @hps.composite - def compose_spec(draw, depth): - f1 = draw(func_spec(depth)) - f2 = draw(func_spec(depth)) - def wrapped_impl(*args): - f1.fun(*args) - f2.fun(*args) - return FuncSpec(wrapped_impl, - f"({f2.name} . {f1.name})", - min_rank=max(f1.min_rank, f2.min_rank), - max_rank=min(f1.max_rank, f2.max_rank), - min_dim=max(f1.min_dim, f2.min_dim), - max_dim=min(f1.max_dim, f2.max_dim)) - - @hps.composite - def nest_spec(draw, depth): - f = draw(func_spec(depth)) - def wrapped_impl(refs): - x_ref, y_ref = refs - x, y = x_ref[...], y_ref[...] - x, y = run_state(f.fun)((x, y)) - x_ref[...], y_ref[...] 
= x, y - return FuncSpec(wrapped_impl, - f"nest({f.name})", - min_rank=f.min_rank, - max_rank=f.max_rank, - min_dim=f.min_dim, - max_dim=f.max_dim) - - - @hps.composite - def add_spec(draw, depth): - f1 = draw(func_spec(depth)) - f2 = draw(func_spec(depth)) - def wrapped_impl(refs): - x_ref, y_ref = refs - x, y = x_ref[...], y_ref[...] - x1, y1 = run_state(f1.fun)((x, y)) - x2, y2 = run_state(f2.fun)((x, y)) - x_ref[...], y_ref[...] = x1 + x2, y1 + y2 - return FuncSpec(wrapped_impl, - f"({f2.name} + {f1.name})", - min_rank=max(f1.min_rank, f2.min_rank), - max_rank=min(f1.max_rank, f2.max_rank), - min_dim=max(f1.min_dim, f2.min_dim), - max_dim=min(f1.max_dim, f2.max_dim)) - - @jtu.thread_unsafe_test_class() # because of hypothesis - class RunStateHypothesisTest(jtu.JaxTestCase): - - @jax.legacy_prng_key('allow') - @hp.given(hps.data()) - @hp.settings(deadline=None, print_blob=True, - max_examples=jtu.NUM_GENERATED_CASES.value) - def test_jvp(self, data): - - spec = data.draw(func_spec()) - - def impl(x): - return spec.call((x, jnp.zeros_like(x)))[1] - - def ref(x): - return spec.ref((x, jnp.zeros_like(x)))[1] - - k1, k2 = random.split(random.PRNGKey(0)) - shape = data.draw(hnp.array_shapes(min_dims=spec.min_rank, - max_dims=spec.max_rank, min_side=spec.min_dim, - max_side=spec.max_dim)) - x = random.normal(k1, shape) - t = random.normal(k2, x.shape) - y, y_t = jax.jvp(impl, (x,), (t,)) - y_ref, y_ref_t = jax.jvp(ref, (x,), (t,)) - self.assertAllClose(y, y_ref) - self.assertAllClose(y_t, y_ref_t) - - @jax.legacy_prng_key('allow') - @hp.given(hps.data()) - @hp.settings(deadline=None, print_blob=True, - max_examples=jtu.NUM_GENERATED_CASES.value) - def test_linearize(self, data): - - spec = data.draw(func_spec()) - - def impl(x): - return spec.call((x, jnp.zeros_like(x)))[1] - - def ref(x): - return spec.ref((x, jnp.zeros_like(x)))[1] - - - k1, k2 = random.split(random.PRNGKey(0)) - shape = data.draw(hnp.array_shapes(min_dims=spec.min_rank, - max_dims=spec.max_rank, min_side=spec.min_dim, - max_side=spec.max_dim)) - x = random.normal(k1, shape) - y, impl_lin = jax.linearize(impl, x) - y_ref, ref_lin = jax.linearize(ref, x) - self.assertAllClose(y, y_ref, atol=1e-2, rtol=1e-2) - t = random.normal(k2, x.shape) - self.assertAllClose(impl_lin(t), ref_lin(t), atol=1e-2, rtol=1e-2) - - @jax.legacy_prng_key('allow') - @hp.given(hps.data()) - @hp.settings(deadline=None, print_blob=True, - max_examples=jtu.NUM_GENERATED_CASES.value) - def test_vjp(self, data): - - spec = data.draw(func_spec()) - - def impl(x): - return spec.call((x, jnp.zeros_like(x)))[1] - - def ref(x): - return spec.ref((x, jnp.zeros_like(x)))[1] - - - key, k1, k2 = random.split(random.PRNGKey(0), 3) - shape = data.draw(hnp.array_shapes(min_dims=spec.min_rank, - max_dims=spec.max_rank, min_side=spec.min_dim, - max_side=spec.max_dim)) - x = random.normal(k1, shape) - - # First order - y, impl_lin = jax.linearize(impl, x) - y_ref, ref_lin = jax.linearize(ref, x) - self.assertAllClose(y, y_ref) - t = random.normal(k2, x.shape) - self.assertAllClose(impl_lin(t), ref_lin(t)) - - y, impl_vjp = jax.vjp(impl, x) - y_ref, ref_vjp = jax.vjp(ref, x) - self.assertAllClose(y, y_ref) - t = random.normal(jax.random.clone(k2), x.shape) - y2 = random.normal(jax.random.clone(k1), y.shape) - self.assertAllClose(impl_vjp(t), ref_vjp(t)) - - # Second order - key, k1, k2 = random.split(key, 3) - t2 = random.normal(k2, t.shape) - - (x,), impl_lin2 = jax.linearize(impl_vjp, t2) - (x_ref,), ref_lin2 = jax.linearize(ref_vjp, t2) - self.assertAllClose(x, 
x_ref) - y2 = random.normal(k1, y.shape) - self.assertAllClose(impl_lin2(y2), ref_lin2(y2)) - - (x,), impl_vjp2 = jax.vjp(impl_vjp, t2) - (x_ref,), ref_vjp2 = jax.vjp(ref_vjp, t2) - self.assertAllClose(x, x_ref) - y2 = random.normal(jax.random.clone(k1), y.shape) - self.assertAllClose(impl_vjp2((y2,)), ref_vjp2((y2,))) + x, y = x_ref[...], y_ref[...] + x1, y1 = run_state(f1.fun)((x, y)) + x2, y2 = run_state(f2.fun)((x, y)) + x_ref[...], y_ref[...] = x1 + x2, y1 + y2 + return FuncSpec(wrapped_impl, + f"({f2.name} + {f1.name})", + min_rank=max(f1.min_rank, f2.min_rank), + max_rank=min(f1.max_rank, f2.max_rank), + min_dim=max(f1.min_dim, f2.min_dim), + max_dim=min(f1.max_dim, f2.max_dim)) + +@jtu.thread_unsafe_test_class() # because of hypothesis +class RunStateHypothesisTest(jtu.JaxTestCase): + + @jax.legacy_prng_key('allow') + @hp.given(hps.data()) + @hp.settings(deadline=None, print_blob=True, + max_examples=jtu.NUM_GENERATED_CASES.value) + def test_jvp(self, data): + + spec = data.draw(func_spec()) + + def impl(x): + return spec.call((x, jnp.zeros_like(x)))[1] + + def ref(x): + return spec.ref((x, jnp.zeros_like(x)))[1] + + k1, k2 = random.split(random.PRNGKey(0)) + shape = data.draw(hnp.array_shapes(min_dims=spec.min_rank, + max_dims=spec.max_rank, min_side=spec.min_dim, + max_side=spec.max_dim)) + x = random.normal(k1, shape) + t = random.normal(k2, x.shape) + y, y_t = jax.jvp(impl, (x,), (t,)) + y_ref, y_ref_t = jax.jvp(ref, (x,), (t,)) + self.assertAllClose(y, y_ref) + self.assertAllClose(y_t, y_ref_t) + + @jax.legacy_prng_key('allow') + @hp.given(hps.data()) + @hp.settings(deadline=None, print_blob=True, + max_examples=jtu.NUM_GENERATED_CASES.value) + def test_linearize(self, data): + + spec = data.draw(func_spec()) + + def impl(x): + return spec.call((x, jnp.zeros_like(x)))[1] + + def ref(x): + return spec.ref((x, jnp.zeros_like(x)))[1] + + + k1, k2 = random.split(random.PRNGKey(0)) + shape = data.draw(hnp.array_shapes(min_dims=spec.min_rank, + max_dims=spec.max_rank, min_side=spec.min_dim, + max_side=spec.max_dim)) + x = random.normal(k1, shape) + y, impl_lin = jax.linearize(impl, x) + y_ref, ref_lin = jax.linearize(ref, x) + self.assertAllClose(y, y_ref, atol=1e-2, rtol=1e-2) + t = random.normal(k2, x.shape) + self.assertAllClose(impl_lin(t), ref_lin(t), atol=1e-2, rtol=1e-2) + + @jax.legacy_prng_key('allow') + @hp.given(hps.data()) + @hp.settings(deadline=None, print_blob=True, + max_examples=jtu.NUM_GENERATED_CASES.value) + def test_vjp(self, data): + + spec = data.draw(func_spec()) + + def impl(x): + return spec.call((x, jnp.zeros_like(x)))[1] + + def ref(x): + return spec.ref((x, jnp.zeros_like(x)))[1] + + + key, k1, k2 = random.split(random.PRNGKey(0), 3) + shape = data.draw(hnp.array_shapes(min_dims=spec.min_rank, + max_dims=spec.max_rank, min_side=spec.min_dim, + max_side=spec.max_dim)) + x = random.normal(k1, shape) + + # First order + y, impl_lin = jax.linearize(impl, x) + y_ref, ref_lin = jax.linearize(ref, x) + self.assertAllClose(y, y_ref) + t = random.normal(k2, x.shape) + self.assertAllClose(impl_lin(t), ref_lin(t)) + + y, impl_vjp = jax.vjp(impl, x) + y_ref, ref_vjp = jax.vjp(ref, x) + self.assertAllClose(y, y_ref) + t = random.normal(jax.random.clone(k2), x.shape) + y2 = random.normal(jax.random.clone(k1), y.shape) + self.assertAllClose(impl_vjp(t), ref_vjp(t)) + + if jtu.SKIP_SLOW_TESTS.value: + # Skip second order tests if JAX_SKIP_SLOW_TESTS=true + return + + # Second order + key, k1, k2 = random.split(key, 3) + t2 = random.normal(k2, t.shape) + + (x,), 
impl_lin2 = jax.linearize(impl_vjp, t2) + (x_ref,), ref_lin2 = jax.linearize(ref_vjp, t2) + self.assertAllClose(x, x_ref) + y2 = random.normal(k1, y.shape) + self.assertAllClose(impl_lin2(y2), ref_lin2(y2)) + + (x,), impl_vjp2 = jax.vjp(impl_vjp, t2) + (x_ref,), ref_vjp2 = jax.vjp(ref_vjp, t2) + self.assertAllClose(x, x_ref) + y2 = random.normal(jax.random.clone(k1), y.shape) + self.assertAllClose(impl_vjp2((y2,)), ref_vjp2((y2,))) if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/svd_test.py b/tests/svd_test.py index 97f8176f8f94..d95a22e2f93c 100644 --- a/tests/svd_test.py +++ b/tests/svd_test.py @@ -20,7 +20,7 @@ import scipy.linalg as osp_linalg from jax._src import config from jax._src import test_util as jtu -from jax._src.lax import svd +from jax._src.tpu.linalg import svd from absl.testing import absltest @@ -166,7 +166,7 @@ def testSvdWithOnRankDeficientInputZeroColumns(self, m, r): np.testing.assert_almost_equal(diff, 1e-4, decimal=2) # Check that u and v are orthogonal. self.assertAllClose(u.T.conj() @ u, np.eye(m), atol=10 * _SVD_TEST_EPS) - self.assertAllClose(v.T.conj() @ v, np.eye(m), atol=11 * _SVD_TEST_EPS) + self.assertAllClose(v.T.conj() @ v, np.eye(m), atol=30 * _SVD_TEST_EPS) @jtu.sample_product( [dict(m=m, n=n) for m, n in zip([2, 8, 10, 20], [4, 6, 10, 18])], @@ -189,7 +189,9 @@ def testSingularValues(self, m, n, log_cond, full_matrices): osp_linalg_fn = functools.partial( osp_linalg.svd, full_matrices=full_matrices, compute_uv=compute_uv) - actual_s = svd.svd(a, full_matrices=full_matrices, compute_uv=compute_uv) + actual_s = svd.svd( + a, full_matrices=full_matrices, compute_uv=compute_uv + ).block_until_ready() expected_s = osp_linalg_fn(a) diff --git a/tests/tree_util_test.py b/tests/tree_util_test.py index e5e649d43d8a..a1e3ccbe265f 100644 --- a/tests/tree_util_test.py +++ b/tests/tree_util_test.py @@ -552,6 +552,18 @@ def testAllLeavesWithTrees(self, tree): def testAllLeavesWithLeaves(self, leaf): self.assertTrue(tree_util.all_leaves([leaf])) + @parameterized.parameters(*TREES) + def testAllLeavesWithTreesAndCustomIsLeaf(self, tree): + def is_leaf(t): + return tree_util.all_leaves([t]) + self.assertFalse(tree_util.all_leaves([tree], is_leaf=is_leaf)) + + @parameterized.parameters(*LEAVES) + def testAllLeavesWithLeavesAndCustomIsLeaf(self, leaf): + def is_leaf(t): + return tree_util.all_leaves([t]) + self.assertTrue(tree_util.all_leaves([leaf], is_leaf=is_leaf)) + @parameterized.parameters(*TREES) def testCompose(self, tree): treedef = tree_util.tree_structure(tree) @@ -615,6 +627,39 @@ def testTransposeWithCustomObject(self): FlatCache({"a": [3, 4], "b": [5, 6]})) self.assertEqual(expected, actual) + @parameterized.parameters(*TREES) + def testBroadcast(self, tree): + if isinstance(tree, FlatCache): + # The tree_map construction below fails for FlatCache, because + # the cached metadata becomes out of sync. 
+ self.skipTest("Test does not work properly for FlatCache.") + def make_inner(x): + return [x, x, x] + nested = tree_util.tree_map(make_inner, tree) + actual = tree_util.tree_broadcast(tree, nested) + self.assertEqual(actual, nested) + + def testBroadcastSimple(self): + prefix = (1, 2, 3) + full = (0, {'a': 0, 'b': 0}, (0, 0)) + actual = tree_util.tree_broadcast(prefix, full) + expected = (1, {'a': 2, 'b': 2}, (3, 3)) + self.assertEqual(actual, expected) + + def testBroadcastError(self): + prefix = (1, 2, 3) + full = (0, {'a': 0, 'b': 0}) + with self.assertRaisesRegex(ValueError, "pytree structure error"): + tree_util.tree_broadcast(prefix, full) + prefix = (1, 2) + full = (0, {'a': 0, 'b': 0}, (0, 0)) + with self.assertRaisesRegex(ValueError, "pytree structure error"): + tree_util.tree_broadcast(prefix, full) + prefix = (1, {'a': 0}) + full = (0, {'a': 0, 'b': 0}) + with self.assertRaisesRegex(ValueError, "pytree structure error"): + tree_util.tree_broadcast(prefix, full) + @parameterized.parameters([(*t, s) for t, s in zip(TREES, TREE_STRINGS)]) def testStringRepresentation(self, tree, correct_string): """Checks that the string representation of a tree works.""" @@ -746,7 +791,7 @@ def testTreeMapWithPathWithIsLeafArgument(self): y = (([3], jnp.array(0)), ([0], 7, [5, 6])) out = tree_util.tree_map_with_path( lambda kp, *xs: (kp[0].idx, *xs), x, y, - is_leaf=lambda n: isinstance(n, list)) + is_leaf=lambda _, n: isinstance(n, list), is_leaf_takes_path=True) self.assertEqual(out, (((0, 1, [3]), (0, 2, jnp.array(0))), (1, [3, 4, 5], ([0], 7, [5, 6])))) @@ -763,7 +808,11 @@ def is_empty(x): tree1 = {'a': 1, 'sub': [jnp.array((1, 2)), ATuple(foo=(), bar=[None])], 'obj': AnObject2(x=EmptyTuple(), y=0, z='constantdef')} - flattened, _ = tree_util.tree_flatten_with_path(tree1, is_empty) + + is_empty_new = lambda kp, x: is_empty(x) + flattened, _ = tree_util.tree_flatten_with_path( + tree1, is_empty_new, is_leaf_takes_path=True + ) strs = [f"{tree_util.keystr(kp)}: {x}" for kp, x in flattened] self.assertEqual( strs, @@ -777,6 +826,32 @@ def is_empty(x): ], ) + def testTreeFlattenWithPathWithIsLeafWithPathArgument(self): + x = ((1, 2), [3, {4: 4, 5: 5}]) + check_max_depth = lambda kp, _: len(kp) >= 2 + flattened, _ = tree_util.tree_flatten_with_path( + x, is_leaf=check_max_depth, is_leaf_takes_path=True + ) + self.assertEqual( + flattened, + [ + ((SequenceKey(0), SequenceKey(0),), 1), + ((SequenceKey(0), SequenceKey(1),), 2), + ((SequenceKey(1), SequenceKey(0),), 3), + ((SequenceKey(1), SequenceKey(1)), {4: 4, 5: 5}), + ], + ) + + def testTreeMapWithPathWithIsLeafWithPathArgument(self): + x = ((1, 2), [3, 4, 5]) + y = (([3], jnp.array(0)), ([0], 7, [5, 6])) + out = tree_util.tree_map_with_path( + lambda kp, *xs: (kp[0].idx, *xs), x, y, + is_leaf=lambda kp, n: isinstance(n, list), is_leaf_takes_path=True) + self.assertEqual(out, (((0, 1, [3]), + (0, 2, jnp.array(0))), + (1, [3, 4, 5], ([0], 7, [5, 6])))) + def testTreeFlattenWithPathBuiltin(self): x = (1, {"a": 2, "b": 3}) flattened = tree_util.tree_flatten_with_path(x) @@ -1005,6 +1080,24 @@ def testPickle(self): unpickled = pickle.loads(pickle.dumps(key)) self.assertEqual(key, unpickled) + def testEqualityErrorWithArrayAsStaticArg(self): + # Regression test for https://github.com/jax-ml/jax/issues/28659 + @tree_util.register_dataclass + @dataclasses.dataclass + class Tree: + x : jnp.ndarray = dataclasses.field(metadata={'static': True}) + + f = jax.jit(lambda x: x) + + msg = "Exception raised while checking equality of metadata fields of 
pytree." + + # First call succeeds, because there is no equality check. + f(Tree(jnp.arange(4))) + + # Second fall fails, because arrays are marked static and compared for equality. + with self.assertRaisesRegex(ValueError, msg): + f(Tree(jnp.arange(4))) + class StaticTest(parameterized.TestCase): @@ -1432,6 +1525,13 @@ def test_tree_transpose(self): tree_util.tree_transpose(outer_treedef, inner_treedef, obj) ) + def test_tree_broadcast(self): + prefix = (1, 2, 3) + full = (0, {'a': 0, 'b': 0}, (0, 0)) + actual = jax.tree.broadcast(prefix, full) + expected = (1, {'a': 2, 'b': 2}, (3, 3)) + self.assertEqual(actual, expected) + def test_tree_unflatten(self): leaves, treedef = jax.tree.flatten([1, 2, (3, 4)]) self.assertEqual( @@ -1449,9 +1549,10 @@ def test_tree_flatten_with_path(self): def test_tree_flatten_with_path_is_leaf(self): obj = [1, 2, (3, 4)] is_leaf = lambda x: isinstance(x, tuple) + is_leaf = lambda kp, x: isinstance(x, tuple) self.assertEqual( - jax.tree.flatten_with_path(obj, is_leaf=is_leaf), - tree_util.tree_flatten_with_path(obj, is_leaf=is_leaf), + jax.tree.flatten_with_path(obj, is_leaf, is_leaf_takes_path=True), + tree_util.tree_flatten_with_path(obj, is_leaf, is_leaf_takes_path=True), ) def test_tree_leaves_with_path(self): @@ -1464,9 +1565,14 @@ def test_tree_leaves_with_path(self): def test_tree_leaves_with_path_is_leaf(self): obj = [1, 2, (3, 4)] is_leaf = lambda x: isinstance(x, tuple) + is_leaf = lambda kp, x: isinstance(x, tuple) self.assertEqual( - jax.tree.leaves_with_path(obj, is_leaf=is_leaf), - tree_util.tree_leaves_with_path(obj, is_leaf=is_leaf), + jax.tree.leaves_with_path( + obj, is_leaf=is_leaf, is_leaf_takes_path=True + ), + tree_util.tree_leaves_with_path( + obj, is_leaf=is_leaf, is_leaf_takes_path=True + ), ) def test_tree_map_with_path(self): @@ -1483,9 +1589,14 @@ def test_tree_map_with_path_is_leaf(self): obj = [1, 2, (3, 4)] obj2 = [5, 6, (7, 8)] is_leaf = lambda x: isinstance(x, tuple) + is_leaf = lambda kp, x: isinstance(x, tuple) self.assertEqual( - jax.tree.map_with_path(func, obj, obj2, is_leaf=is_leaf), - tree_util.tree_map_with_path(func, obj, obj2, is_leaf=is_leaf), + jax.tree.map_with_path( + func, obj, obj2, is_leaf=is_leaf, is_leaf_takes_path=True + ), + tree_util.tree_map_with_path( + func, obj, obj2, is_leaf=is_leaf, is_leaf_takes_path=True + ), ) diff --git a/tests/typing_test.py b/tests/typing_test.py index 562c6c56d2d9..6ebfc627efc6 100644 --- a/tests/typing_test.py +++ b/tests/typing_test.py @@ -143,11 +143,7 @@ def f(x: Any) -> typing.Array | None: # - Confirm that types from *.pyi files are correctly pulled-in # - Confirm that non-trivial overloads are behaving as expected. # - import sys - if sys.version_info >= (3, 11): - from typing import assert_type # pytype: disable=not-supported-yet # py311-upgrade - else: - from typing_extensions import assert_type # pytype: disable=not-supported-yet + from typing import assert_type # pytype: disable=not-supported-yet # py311-upgrade mat = jnp.zeros((2, 5)) vals = jnp.arange(5) diff --git a/tests/unary_ops_accuracy_test.py b/tests/unary_ops_accuracy_test.py new file mode 100644 index 000000000000..91e8fe2d1dbf --- /dev/null +++ b/tests/unary_ops_accuracy_test.py @@ -0,0 +1,405 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit test for result accuracy for unary ops.""" + +from typing import Any, NamedTuple +from collections.abc import Callable + +from absl.testing import absltest +from absl.testing import parameterized +import jax +from jax._src import config +from jax._src import test_util as jtu +from jax._src.lax import lax +from jax._src.lib import _jax +from jax._src.lib.mlir import ir +from jax._src.lib.mlir.dialects import hlo +import jax.numpy as jnp +import numpy as np + + +config.parse_flags_with_absl() + + +class TolerancePair(NamedTuple): + high: lax.Tolerance | lax.AccuracyMode = lax.AccuracyMode.DEFAULT + low: lax.Tolerance | lax.AccuracyMode = lax.AccuracyMode.DEFAULT + + +def make_unary_test_cases( + testcase_name: str, + op: Callable[..., Any], + x: np.ndarray, + tp: TolerancePair = None, + min_error_val: float = 0.0, +): + """Creates a single test case.""" + return [{ + "testcase_name": testcase_name, + "op": op, + "x": x, + "tp": tp, + "min_error_val": min_error_val, + }] + + +UNARY_OPS = { + "exp": make_unary_test_cases( + "exp", + lax.exp, + np.arange(84.0, 88.0, dtype=np.float32), + TolerancePair( + high=lax.Tolerance(atol=2**-5, rtol=2**-5, ulps=2), + low=lax.Tolerance(atol=1.5 * 2**-8, rtol=2**-18, ulps=2), + ), + ), + "exp2": make_unary_test_cases( + "exp2", + lax.exp2, + np.arange(84.0, 88.0, dtype=np.float32), + TolerancePair( + high=lax.Tolerance(atol=2**-5, rtol=2**-5, ulps=2), + low=lax.Tolerance(atol=1.5 * 2**-8, rtol=2**-18, ulps=2), + ), + ), + "expm1": make_unary_test_cases( + "expm1", + lax.expm1, + np.arange(84.0, 88.0, dtype=np.float32), + TolerancePair( + high=lax.Tolerance(atol=2**-5, rtol=2**-5, ulps=2), + low=lax.Tolerance(atol=1.5 * 2**-8, rtol=2**-18, ulps=2), + ), + ), + "log": make_unary_test_cases( + "log", + lax.log, + np.linspace(1e28, 2e28, 10, dtype=np.float32), + TolerancePair( + high=lax.Tolerance(atol=0, rtol=2**-10, ulps=0), + low=lax.Tolerance(atol=2**-16, rtol=2**-20, ulps=0), + ), + 1.0, + ), + "log1p": make_unary_test_cases( + "log1p", + lax.log1p, + np.linspace(-9e-8, -8e-8, 10, dtype=np.float32), + TolerancePair( + high=lax.Tolerance(atol=0, rtol=2**-11, ulps=0), + low=lax.Tolerance(atol=0, rtol=2**-14, ulps=0), + ), + 1.0, + ), + "tanh": make_unary_test_cases( + "tanh", + lax.tanh, + np.linspace(5.83, 5.86, 10, dtype=np.float32), + TolerancePair( + high=lax.Tolerance(atol=2**-12, rtol=0, ulps=0), + low=lax.Tolerance(atol=2**-16, rtol=0, ulps=0), + ), + ), + "cos": make_unary_test_cases( + "cos", + lax.cos, + np.linspace(9.7e22, 9.8e22, 10, dtype=np.float32), + TolerancePair( + high=lax.Tolerance(atol=0, rtol=2**-10, ulps=0), + low=lax.Tolerance(atol=0, rtol=2**-30, ulps=0), + ), + ), + "sin": make_unary_test_cases( + "sin", + lax.sin, + np.linspace(9.7e22, 9.8e22, 10, dtype=np.float32), + TolerancePair( + high=lax.Tolerance(atol=0, rtol=2**-10, ulps=0), + low=lax.Tolerance(atol=0, rtol=2**-30, ulps=0), + ), + ), + "tan": make_unary_test_cases( + "tan", + lax.tan, + np.linspace(250.0, 252.0, 10, dtype=np.float32), + TolerancePair( + high=lax.Tolerance(atol=0, rtol=2**-10, ulps=0), + low=lax.Tolerance(atol=0, rtol=2**-30, ulps=0), + 
), + ), + "sqrt": make_unary_test_cases( + "sqrt", + lax.sqrt, + np.linspace(250.0, 252.0, 10, dtype=np.float32), + TolerancePair( + high=lax.Tolerance(atol=0, rtol=2**-10, ulps=0), + low=lax.Tolerance(atol=0, rtol=2**-30, ulps=0), + ), + ), + "rsqrt": make_unary_test_cases( + "rsqrt", + lax.rsqrt, + np.linspace(250.0, 252.0, 10, dtype=np.float32), + TolerancePair( + high=lax.Tolerance(atol=0, rtol=2**-10, ulps=0), + low=lax.Tolerance(atol=0, rtol=2**-30, ulps=0), + ), + ), +} + + +def generate_test_cases(op_names): + test_cases = [] + for op in op_names: + op_group = UNARY_OPS[op] + if op_group is None: + raise ValueError(f"No test cases found for op: {op}") + test_cases.extend(op_group) + return test_cases + + +class UnaryOpsAccuracyTest(jtu.JaxTestCase): + + def setUp(self): + if not jtu.stablehlo_version_at_least("1.10.0"): + self.skipTest("Test requires StableHLO v1.10.0 or higher.") + if not jtu.is_device_tpu(): + self.skipTest("Skipping test on non TPU devices.") + # TODO(b/412112097): Enable this test on TPU version 7 and above once + # accuracy analysis is done. + if jtu.get_tpu_version() >= 7: + self.skipTest("Accuracy analysis is not yet done on TPU version 7 and above.") + super().setUp() + + def test_result_accuracy_mode_attr(self): + with ir.Context() as context: + hlo.register_dialect(context) + attr = hlo.ResultAccuracyModeAttr.get("DEFAULT") + assert attr is not None + assert attr.value == "DEFAULT" + + def test_result_accuracy_attr(self): + with ir.Context() as context: + hlo.register_dialect(context) + attr = hlo.ResultAccuracyAttr.get( + atol=1e-5, rtol=0.0, ulps=1, mode="TOLERANCE" + ) + assert attr is not None + assert attr.mode == "TOLERANCE" + assert attr.atol == 1e-5 + assert attr.rtol == 0.0 + assert attr.ulps == 1 + + @parameterized.named_parameters( + *generate_test_cases(["exp", "expm1", "exp2", "log", "log1p", "tanh"]) + ) + def test_unary_ops_choose_impl(self, op, x, tp, **kwargs): + @jax.jit + def f_default(x): + y = op(x, accuracy=tp.high) + return y + + @jax.jit + def f_accurate(x): + y = op(x, accuracy=tp.low) + return y + + # Input values that would cause large differences between the two + # implementations. + diff = abs(f_default(x) - f_accurate(x)) + if jtu.get_tpu_version() >= 5 and op in [ + lax.tanh, + jnp.tanh, + lax.log, + jnp.log, + ]: + # From tpu version 5 and onwards, even with tighter tolerance, the high performant + # implementation for tanh is chosen because the chip implementation has improved accuracy. + self.assertTrue(jnp.all(diff == 0)) + else: + self.assertTrue(jnp.any(diff > 0)) + + @parameterized.named_parameters( + *generate_test_cases(["exp", "expm1", "exp2", "log", "log1p", "tanh"]) + ) + def test_unary_vmap(self, op, x, tp, min_error_val): + @jax.jit + def f(x, y): + diff = lambda val: abs( + op(val, accuracy=tp.high) - op(val, accuracy=tp.low) + ) + return diff(x), diff(y) + + diff_x, diff_y = jax.vmap(f, in_axes=(None, 0), out_axes=0)( + min_error_val, x + ) + # diff(min_error_val) should be 0 + self.assertTrue(jnp.all(diff_x == 0)) + # diff(x) should be > 0 + if jtu.get_tpu_version() >= 5 and op in [ + lax.tanh, + jnp.tanh, + lax.log, + jnp.log, + ]: + # From tpu version 5 and onwards, even with tighter tolerance, the high performant + # implementation for tanh and log is chosen because the chip implementation has improved accuracy. 
+ self.assertTrue(jnp.all(diff_y == 0)) + else: + self.assertTrue(jnp.any(diff_y > 0)) + + @parameterized.named_parameters( + *generate_test_cases(["exp", "expm1", "exp2"]) + ) + def test_diff_grad(self, op, x, tp, **kwargs): + @jax.jit + def f_default(x): + default_op = op(x, accuracy=tp.low) + return jnp.sum(default_op) + + f_default_grad = jax.grad(f_default) + + @jax.jit + def f_accurate(x): + high_op = op(x, accuracy=tp.high) + return jnp.sum(high_op) + + f_accurate_grad = jax.grad(f_accurate) + # Accuracy should be carried through to the gradient causing + # a large diff. + diff = abs(f_default_grad(x) - f_accurate_grad(x)) + self.assertTrue(jnp.any(diff > 0)) + + @parameterized.named_parameters( + *generate_test_cases(["log", "log1p", "tanh"]) + ) + def test_grad_unchanged(self, op, x, tp, **kwargs): + @jax.jit + def f(x): + return jnp.sum(op(x)) + + f_grad = jax.grad(f) + + @jax.jit + def f_default(x): + default_op = op(x, accuracy=tp.low) + return jnp.sum(default_op) + + f_default_grad = jax.grad(f_default) + + @jax.jit + def f_accurate(x): + high_op = op(x, accuracy=tp.high) + return jnp.sum(high_op) + + f_accurate_grad = jax.grad(f_accurate) + # Accuracy should be carried through to the gradient causing a large diff. + # Diff between f_default and f_accurate should follow diff(f_grad,f_default_grad). + expected_diff = abs(f_grad(x) - f_default_grad(x)) + if jnp.all(expected_diff > 0): + # Don't expect f_accurate_grad and f_default_grad to be equal. + self.assertFalse( + jnp.all(abs(f_default_grad(x) - f_accurate_grad(x)) == 0) + ) + elif jnp.all(expected_diff == 0): + # f_accurate_grad and f_default_grad should be equal. + diff = abs(f_default_grad(x) - f_accurate_grad(x)) + self.assertTrue(jnp.all(diff == 0)) + else: + raise ValueError("Unexpected diff: ", expected_diff) + + @parameterized.named_parameters( + *generate_test_cases(["cos", "sin", "tan", "sqrt", "rsqrt"]) + ) + def test_single_impl(self, op, x, tp, **kwargs): + @jax.jit + def f_tol(x): + return op(x, accuracy=tp.high) + + @jax.jit + def f(x): + return op(x) + + diff = abs(f_tol(x) - f(x)) + self.assertTrue(jnp.all(diff == 0)) + + @parameterized.named_parameters( + *generate_test_cases(["cos", "sin", "tan", "sqrt", "rsqrt"]) + ) + def test_default_grad(self, op, x, tp, **kwargs): + @jax.jit + def f_tol(x): + return jnp.sum(op(x, accuracy=tp.high)) + + @jax.jit + def f(x): + return jnp.sum(op(x)) + + self.assertTrue(jnp.all(abs(jax.grad(f_tol)(x) - jax.grad(f)(x)) == 0)) + + def test_invalid_accuracy(self): + with self.assertRaisesRegex( + ValueError, "At least one of atol, rtol, or ulps must be set." + ): + lax.exp(1.0, accuracy=lax.Tolerance(atol=0.0, rtol=0.0, ulps=0)) + with self.assertRaisesRegex(ValueError, "Tolerances must be non-negative."): + lax.exp(1.0, accuracy=lax.Tolerance(atol=-4e-10, rtol=0.0, ulps=0)) + + @parameterized.named_parameters( + *generate_test_cases([ + "exp", + "expm1", + "exp2", + "log", + "log1p", + "tanh", + "cos", + "sin", + "tan", + "sqrt", + "rsqrt", + ]) + ) + def test_low_tol(self, op, x, **kwargs): + with self.assertRaisesRegex( + _jax.XlaRuntimeError, "impl_type.ok()" + ): + op(x, accuracy=lax.Tolerance(atol=1e-60, rtol=1e-60, ulps=0)) + + def test_accuracy_jaxpr(self): + # Since accuracy is not set, the jaxpr should not contain "accuracy". + self.assertNotIn( + "accuracy", + str( + jax.make_jaxpr(lambda x: lax.exp(x, accuracy=None))( + np.arange(4.0, dtype=np.float32) + ) + ), + ) + # Set accuracy. 
+ self.assertIn( + "accuracy", + str( + jax.make_jaxpr( + lambda x: lax.exp( + x, accuracy=lax.Tolerance(atol=1e-60, rtol=1e-60, ulps=0) + ) + )(np.arange(4.0, dtype=np.float32)) + ), + ) + + +if __name__ == "__main__": + absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/util_test.py b/tests/util_test.py index 53414dae977f..544858fa089a 100644 --- a/tests/util_test.py +++ b/tests/util_test.py @@ -12,16 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. +from functools import partial import operator from absl.testing import absltest - import jax from jax import api_util from jax._src import linear_util as lu from jax._src import test_util as jtu from jax._src import util - from jax._src.util import weakref_lru_cache jax.config.parse_flags_with_absl() @@ -74,6 +73,64 @@ def kw_to_positional(f, store, factor, *args, **kwargs): self.assertEqual(dict(three=6, four=8), scaled_kwargs) self.assertEqual(2, out_thunk()) + def test_wrapped_fun_name(self): + def my_function(): + return + + with self.subTest("function"): + wrapped = lu.wrap_init( + my_function, + debug_info=api_util.debug_info("test", my_function, (), {}), + ) + self.assertEqual(wrapped.__name__, my_function.__name__) + + with self.subTest("default_partial"): + my_partial = partial(my_function) + wrapped = lu.wrap_init( + my_partial, + debug_info=api_util.debug_info("test", my_partial, (), {}), + ) + self.assertEqual(wrapped.__name__, my_function.__name__) + + with self.subTest("nested_default_partial"): + my_partial = partial(partial(my_function)) + wrapped = lu.wrap_init( + my_partial, + debug_info=api_util.debug_info("test", my_partial, (), {}), + ) + self.assertEqual(wrapped.__name__, my_function.__name__) + + with self.subTest("named_partial"): + my_partial = partial(my_function) + my_partial.__name__ = "my_partial" + wrapped = lu.wrap_init( + my_partial, + debug_info=api_util.debug_info("test", my_partial, (), {}), + ) + self.assertEqual(wrapped.__name__, my_partial.__name__) + + with self.subTest("lambda"): + l = lambda: my_function() + wrapped = lu.wrap_init( + l, + debug_info=api_util.debug_info("test", l, (), {}), + ) + self.assertEqual(wrapped.__name__, "") + + with self.subTest("unnamed_callable"): + + class MyCallable: + + def __call__(self): + return + + my_callable = MyCallable() + wrapped = lu.wrap_init( + my_callable, + debug_info=api_util.debug_info("test", my_callable, (), {}), + ) + self.assertEqual(wrapped.__name__, "") + def test_weakref_lru_cache(self): @weakref_lru_cache def example_cached_fn(key): @@ -186,17 +243,17 @@ def test_safe_zip_errors(self): util.safe_zip(lambda x: x) with self.assertRaisesRegex( - ValueError, r"safe_zip\(\) argument 2 is longer than argument 1" + ValueError, r"zip\(\) argument 2 is longer than argument 1" ): util.safe_zip(range(3), range(4)) with self.assertRaisesRegex( - ValueError, r"safe_zip\(\) argument 2 is shorter than argument 1" + ValueError, r"zip\(\) argument 2 is shorter than argument 1" ): util.safe_zip(range(7), range(2)) with self.assertRaisesRegex( - ValueError, r"safe_zip\(\) argument 2 is longer than argument 1" + ValueError, r"zip\(\) argument 2 is longer than argument 1" ): util.safe_zip((), range(3)) diff --git a/tests/version_test.py b/tests/version_test.py index b78e61ae024c..14da82df2e3e 100644 --- a/tests/version_test.py +++ b/tests/version_test.py @@ -143,6 +143,7 @@ def testBuildVersionFromEnvironment(self): JAX_NIGHTLY=None, JAXLIB_NIGHTLY=None): with assert_no_subprocess_call(): 
version = jax.version._get_version_for_build() + self.assertFalse(jax.version._is_prerelease()) self.assertEqual(version, base_version) self.assertValidVersion(version) @@ -150,6 +151,7 @@ def testBuildVersionFromEnvironment(self): JAX_NIGHTLY=None, JAXLIB_NIGHTLY=None): with assert_no_subprocess_call(): version = jax.version._get_version_for_build() + self.assertFalse(jax.version._is_prerelease()) self.assertEqual(version, base_version) self.assertValidVersion(version) @@ -183,6 +185,20 @@ def testBuildVersionFromEnvironment(self): ): with assert_no_subprocess_call(): version = jax.version._get_version_for_build() + self.assertTrue(jax.version._is_prerelease()) + self.assertEqual(version, f"{base_version}rc0") + self.assertValidVersion(version) + + with jtu.set_env( + JAX_RELEASE=None, + JAXLIB_RELEASE="1", + JAX_NIGHTLY=None, + JAXLIB_NIGHTLY=None, + WHEEL_VERSION_SUFFIX="rc0", + ): + with assert_no_subprocess_call(): + version = jax.version._get_version_for_build() + self.assertTrue(jax.version._is_prerelease()) self.assertEqual(version, f"{base_version}rc0") self.assertValidVersion(version) diff --git a/tests/xla_bridge_test.py b/tests/xla_bridge_test.py index 97e8765cc096..e87ad52ed89f 100644 --- a/tests/xla_bridge_test.py +++ b/tests/xla_bridge_test.py @@ -17,12 +17,12 @@ from absl import logging from absl.testing import absltest - from jax import version from jax._src import compiler from jax._src import config from jax._src import test_util as jtu from jax._src import xla_bridge as xb +from jax._src.lib import _profiler from jax._src.lib import xla_client as xc config.parse_flags_with_absl() @@ -35,18 +35,14 @@ class XlaBridgeTest(jtu.JaxTestCase): def test_set_device_assignment_no_partition(self): compile_options = compiler.get_compile_options( num_replicas=4, num_partitions=1, device_assignment=[0, 1, 2, 3]) - expected_device_assignment = ("Computations: 1 Replicas: 4\nComputation 0: " - "0 1 2 3 \n") - self.assertEqual(compile_options.device_assignment.__repr__(), - expected_device_assignment) + self.assertEqual(compile_options.device_assignment.replica_count(), 4) + self.assertEqual(compile_options.device_assignment.computation_count(), 1) def test_set_device_assignment_with_partition(self): compile_options = compiler.get_compile_options( num_replicas=2, num_partitions=2, device_assignment=[[0, 1], [2, 3]]) - expected_device_assignment = ("Computations: 2 Replicas: 2\nComputation 0: " - "0 2 \nComputation 1: 1 3 \n") - self.assertEqual(compile_options.device_assignment.__repr__(), - expected_device_assignment) + self.assertEqual(compile_options.device_assignment.replica_count(), 2) + self.assertEqual(compile_options.device_assignment.computation_count(), 2) def test_set_fdo_profile(self): compile_options = compiler.get_compile_options( @@ -136,13 +132,15 @@ def test_register_plugin(self): "name1:path1,name2:path2,name3" ) with mock.patch.object( - xc.profiler, "register_plugin_profiler", autospec=True + _profiler, "register_plugin_profiler", autospec=True ): xb.register_pjrt_plugin_factories_from_env() registration = xb._backend_factories["name1"] with mock.patch.object(xc, "make_c_api_client", autospec=True) as mock_make: with mock.patch.object( - xc, "pjrt_plugin_initialized", autospec=True, return_vale=True + xc, + "pjrt_plugin_initialized", + autospec=True, ): with mock.patch.object(xc, "initialize_pjrt_plugin", autospec=True): registration.factory() @@ -174,13 +172,15 @@ def test_register_plugin_with_config(self): ) with mock.patch.object(xc, 
"load_pjrt_plugin_dynamically", autospec=True): with mock.patch.object( - xc.profiler, "register_plugin_profiler", autospec=True + _profiler, "register_plugin_profiler", autospec=True ): xb.register_pjrt_plugin_factories_from_env() registration = xb._backend_factories["name1"] with mock.patch.object(xc, "make_c_api_client", autospec=True) as mock_make: with mock.patch.object( - xc, "pjrt_plugin_initialized", autospec=True, return_vale=True + xc, + "pjrt_plugin_initialized", + autospec=True, ): with mock.patch.object(xc, "initialize_pjrt_plugin", autospec=True): registration.factory() @@ -202,6 +202,28 @@ def test_register_plugin_with_config(self): mock_make.assert_called_once_with("name1", options, None) + def test_register_plugin_with_lazy_config(self): + options = {"bar": "baz"} + + def getopts(): + return options + + def make_c_api_client(plugin_name, new_options, *args, **kwargs): + for k in options: + self.assertEqual(new_options[k], options[k]) + + with mock.patch.object(xc, "load_pjrt_plugin_dynamically", autospec=True): + with mock.patch.object( + _profiler, "register_plugin_profiler", autospec=True + ): + xb.register_plugin("foo", options=getopts, library_path="/dev/null") + with mock.patch.object( + xc, "make_c_api_client", autospec=True, wraps=make_c_api_client + ) as mock_make: + with mock.patch.object(xc, "pjrt_plugin_initialized", autospec=True): + xb._backend_factories["foo"].factory() + mock_make.assert_called_once() + class GetBackendTest(jtu.JaxTestCase): diff --git a/tests/xla_metadata_test.py b/tests/xla_metadata_test.py index d141bc15c249..8ac54fd402d6 100644 --- a/tests/xla_metadata_test.py +++ b/tests/xla_metadata_test.py @@ -190,6 +190,39 @@ def while_fn(a): if "stablehlo.add" in line: self.assertIn('mhlo.frontend_attributes = {a = "c"}', line) + def test_cond_annotates_branches(self): + sin = jnp.sin + cos = jnp.cos + + @jax.jit + def f(x): + with set_xla_metadata(a="b"): + return jax.lax.cond(x < 0., sin, cos, x) + + hlo_lines = f.lower(1.).as_text().split("\n") + sin_hlo, = (line for line in hlo_lines if "stablehlo.sine" in line) + cos_hlo, = (line for line in hlo_lines if "stablehlo.cosine" in line) + self.assertIn('mhlo.frontend_attributes = {a = "b"}', sin_hlo) + self.assertIn('mhlo.frontend_attributes = {a = "b"}', cos_hlo) + + def test_cond_annotates_branches_and_none_unsets(self): + sin = jnp.sin + + def cos(x): + with set_xla_metadata(a=None): + return jnp.cos(x) + + @jax.jit + def f(x): + with set_xla_metadata(a="b"): + return jax.lax.cond(x < 0., sin, cos, x) + + hlo_lines = f.lower(1.).as_text().split("\n") + sin_hlo, = (line for line in hlo_lines if "stablehlo.sine" in line) + cos_hlo, = (line for line in hlo_lines if "stablehlo.cosine" in line) + self.assertIn( 'mhlo.frontend_attributes = {a = "b"}', sin_hlo) + self.assertNotIn('mhlo.frontend_attributes = {a = "b"}', cos_hlo) + def test_nested_jit(self): @jax.jit def f(x, y): @@ -255,10 +288,10 @@ def f2(x, y): with set_xla_metadata(a="b"): return (x + y, y * 2.0) - f_vmap_jaxpr = jax.make_jaxpr(jax.vmap(f2, in_axes=(0, None))) + f2_vmap = jax.vmap(f2, in_axes=(0, None)) self.assertIn( 'mhlo.frontend_attributes = {a = "b"}', - f_vmap_jaxpr.lower(jnp.arange(5.0), 1.0).as_text(), + jax.jit(f2_vmap).lower(jnp.arange(5.0), 1.0).as_text(), ) def test_multiple_instructions(self): diff --git a/third_party/repo.bzl b/third_party/repo.bzl index 17e0bbb03542..185c5a4294dc 100644 --- a/third_party/repo.bzl +++ b/third_party/repo.bzl @@ -129,7 +129,7 @@ def tf_http_archive(name, sha256, urls, **kwargs): 
"storage.googleapis.com", )]): fail("The first entry of tf_http_archive(urls) must be a mirror " + - "URL, preferrably mirror.tensorflow.org. Even if you don't have " + + "URL, preferably mirror.tensorflow.org. Even if you don't have " + "permission to mirror the file, please put the correctly " + "formatted mirror URL there anyway, because someone will come " + "along shortly thereafter and mirror the file.") diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 73bf2eb3850d..dccf8d47a6cb 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "df971129bd82e381954da0185b534220e21798a4" -XLA_SHA256 = "11e9a568320cf7e7d61819620fd369927527ecefb68d5d1154b1521456bbdb72" +XLA_COMMIT = "3d5ece64321630dade7ff733ae1353fc3c83d9cc" +XLA_SHA256 = "fbd20cf83bad78f66977fa7ff67a12e52964abae0b107ddd5486a0355643ec8a" def repo(): tf_http_archive(